├── .github
│   └── workflows
│       ├── python-publish.yml
│       └── test.yml
├── .gitignore
├── LICENSE
├── README.md
├── gotennet.png
├── gotennet_pytorch
│   ├── __init__.py
│   ├── gotennet.py
│   └── tensor_typing.py
├── pyproject.toml
└── tests
    └── test_gotennet.py

/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 | 
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 | 
9 | name: Upload Python Package
10 | 
11 | on:
12 |   release:
13 |     types: [published]
14 | 
15 | jobs:
16 |   deploy:
17 | 
18 |     runs-on: ubuntu-latest
19 | 
20 |     steps:
21 |     - uses: actions/checkout@v2
22 |     - name: Set up Python
23 |       uses: actions/setup-python@v2
24 |       with:
25 |         python-version: '3.x'
26 |     - name: Install dependencies
27 |       run: |
28 |         python -m pip install --upgrade pip
29 |         pip install build
30 |     - name: Build package
31 |       run: python -m build
32 |     - name: Publish package
33 |       uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
34 |       with:
35 |         user: __token__
36 |         password: ${{ secrets.PYPI_API_TOKEN }}
37 | 
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: Tests the examples in README
2 | on: push
3 | 
4 | jobs:
5 |   test:
6 |     runs-on: ubuntu-latest
7 |     steps:
8 |       - uses: actions/checkout@v4
9 |       - name: Install Python
10 |         uses: actions/setup-python@v4
11 |       - name: Install the latest version of rye
12 |         uses: eifinger/setup-rye@v2
13 |       - name: Use UV instead of pip
14 |         run: rye config --set-bool behavior.use-uv=true
15 |       - name: Install dependencies
16 |         run: |
17 |           rye sync
18 |       - name: Run pytest
19 |         run: rye run pytest tests/
20 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 | 
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/
163 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 Phil Wang
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | ## GotenNet - Pytorch
4 | 
5 | Implementation of GotenNet, a new SOTA 3d equivariant transformer, in Pytorch
6 | 
7 | I know a lot of researchers have moved on from geometric learning after Alphafold3. However, I just cannot help but wonder. Hedging my bets
8 | 
9 | The official repository has been released [here](https://github.com/sarpaykent/GotenNet/)
10 | 
11 | ## Install
12 | 
13 | ```bash
14 | $ pip install gotennet-pytorch
15 | ```
16 | 
17 | ## Usage
18 | 
19 | ```python
20 | import torch
21 | torch.set_default_dtype(torch.float64) # recommended for equivariant network training
22 | 
23 | from gotennet_pytorch import GotenNet
24 | 
25 | model = GotenNet(
26 |     dim = 256,
27 |     max_degree = 2,
28 |     depth = 1,
29 |     heads = 2,
30 |     dim_head = 32,
31 |     dim_edge_refinement = 256,
32 |     return_coors = False
33 | )
34 | 
35 | atom_ids = torch.randint(0, 14, (1, 12)) # negative atom indices will be assumed to be padding - length of molecule is thus `(atom_ids >= 0).sum(dim = -1)`
36 | coors = torch.randn(1, 12, 3)
37 | adj_mat = torch.randint(0, 2, (1, 12, 12)).bool()
38 | 
39 | invariant, high_degree_feats = model(atom_ids, adj_mat = adj_mat, coors = coors) # with `return_coors = False`, the second output is the list of higher degree steerable features rather than coordinates
40 | ```
41 | 
42 | ## Citations
43 | 
44 | ```bibtex
45 | @inproceedings{anonymous2024rethinking,
46 |     title = {Rethinking Efficient 3D Equivariant Graph Neural Networks},
47 |     author = {Anonymous},
48 |     booktitle = {Submitted to The Thirteenth International Conference on Learning Representations},
49 |     year = {2024},
50 |     url = {https://openreview.net/forum?id=5wxCQDtbMo},
51 |     note = {under review}
52 | }
53 | ```
54 | 
55 | ```bibtex
56 | @inproceedings{Zhou2024ValueRL,
57 |     title = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
58 |     author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
59 |     year = {2024},
60 |     url = {https://api.semanticscholar.org/CorpusID:273532030}
61 | }
62 | ```
63 | 
64 | ```bibtex
65 | @article{Zhu2024HyperConnections,
66 |     title = {Hyper-Connections},
67 |     author = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
68 |     journal = {ArXiv},
69 |     year = {2024},
70 |     volume = {abs/2409.19606},
71 |     url = {https://api.semanticscholar.org/CorpusID:272987528}
72 | }
73 | ```
74 | 
--------------------------------------------------------------------------------
/gotennet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/gotennet-pytorch/20a63dfabd5c9c71effd9e643941dfca99868a85/gotennet.png
--------------------------------------------------------------------------------
/gotennet_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from gotennet_pytorch.gotennet import GotenNet, torch_default_dtype
2 | 
--------------------------------------------------------------------------------
/gotennet_pytorch/gotennet.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | 
3 | from functools import partial
4 | from contextlib import contextmanager
5 | from collections.abc import Sequence
6 | 
7 | import torch
8 | from torch import nn, cat, Tensor, einsum
9 | from torch.nn import Linear, Sequential, Module, ModuleList, ParameterList
10 | 
11 | import einx
12 | from einx import get_at
13 | 
14 | from einops import rearrange, repeat, reduce
15 | from einops.layers.torch import Rearrange
16 | 
17 | from e3nn.o3 import spherical_harmonics
18 | 
19 | from gotennet_pytorch.tensor_typing import Float, Int, Bool
20 | 
21 | from hyper_connections import get_init_and_expand_reduce_stream_functions
22 | 
23 | from x_transformers import Attention
24 | 
25 | # ein notation
26 | 
27 | # 
b - batch 28 | # h - heads 29 | # n - sequence 30 | # m - sequence (neighbors) 31 | # i, j - source and target sequence 32 | # d - feature 33 | # m - order of each degree 34 | # l - degree 35 | 36 | # helper functions 37 | 38 | def exists(v): 39 | return v is not None 40 | 41 | def default(v, d): 42 | return v if exists(v) else d 43 | 44 | def max_neg_value(t): 45 | return -torch.finfo(t.dtype).max 46 | 47 | def mask_from_lens(lens, total_len): 48 | seq = torch.arange(total_len, device = lens.device) 49 | return einx.less('n, b -> b n', seq, lens) 50 | 51 | def softclamp(t, value = 50.): 52 | return (t / value).tanh() * value 53 | 54 | @contextmanager 55 | def torch_default_dtype(dtype): 56 | prev_dtype = torch.get_default_dtype() 57 | torch.set_default_dtype(dtype) 58 | yield 59 | torch.set_default_dtype(prev_dtype) 60 | 61 | # normalization 62 | 63 | LayerNorm = partial(nn.LayerNorm, bias = False) 64 | 65 | class HighDegreeNorm(Module): 66 | def __init__(self, dim, eps = 1e-6): 67 | super().__init__() 68 | self.eps = eps 69 | self.scale = dim ** 0.5 70 | self.gamma = nn.Parameter(torch.ones(dim, 1)) 71 | 72 | def forward(self, x): 73 | norms = x.norm(dim = -1, keepdim = True) 74 | den = norms.norm(dim = -2, keepdim = True) * self.scale 75 | return x / den.clamp(min = self.eps) * self.gamma 76 | 77 | # radial basis function 78 | 79 | class Radial(Module): 80 | def __init__( 81 | self, 82 | dim, 83 | radial_hidden_dim = 64 84 | ): 85 | super().__init__() 86 | 87 | hidden = radial_hidden_dim 88 | 89 | self.rp = Sequential( 90 | Rearrange('... -> ... 1'), 91 | Linear(1, hidden), 92 | nn.SiLU(), 93 | LayerNorm(hidden), 94 | Linear(hidden, hidden), 95 | nn.SiLU(), 96 | LayerNorm(hidden), 97 | Linear(hidden, dim) 98 | ) 99 | 100 | def forward(self, x): 101 | return self.rp(x) 102 | 103 | # node scalar feat init 104 | # eq (1) and (2) 105 | 106 | class NodeScalarFeatInit(Module): 107 | def __init__( 108 | self, 109 | num_atoms, 110 | dim, 111 | accept_embed = False, 112 | radial_hidden_dim = 64 113 | ): 114 | super().__init__() 115 | self.atom_embed = nn.Embedding(num_atoms, dim) if not accept_embed else nn.Linear(dim, dim) 116 | self.neighbor_atom_embed = nn.Embedding(num_atoms, dim) if not accept_embed else nn.Linear(dim, dim) 117 | 118 | self.rel_dist_mlp = Radial( 119 | dim = dim, 120 | radial_hidden_dim = radial_hidden_dim 121 | ) 122 | 123 | self.to_node_feats = Sequential( 124 | Linear(dim * 2, dim), 125 | LayerNorm(dim), 126 | nn.SiLU(), 127 | Linear(dim, dim) 128 | ) 129 | 130 | def forward( 131 | self, 132 | atoms: Int['b n'] | Float['b n d'], 133 | rel_dist: Float['b n m'], 134 | adj_mat: Bool['b n m'] | None = None, 135 | neighbor_indices: Int['b n m'] | None = None, 136 | neighbor_mask: Bool['b n m'] | None = None, 137 | mask: Bool['b n'] | None = None, 138 | ) -> Float['b n d']: 139 | 140 | dtype = rel_dist.dtype 141 | batch, seq, device = *atoms.shape[:2], atoms.device 142 | 143 | if not exists(adj_mat): 144 | adj_mat = torch.ones_like(rel_dist, device = device, dtype = dtype) 145 | 146 | if atoms.dtype in (torch.int, torch.long): 147 | atoms = atoms.masked_fill(atoms < 0, 0) 148 | 149 | embeds = self.atom_embed(atoms) 150 | 151 | rel_dist_feats = self.rel_dist_mlp(rel_dist) 152 | 153 | if exists(neighbor_indices): 154 | if exists(neighbor_mask): 155 | rel_dist_feats = einx.where('b i j, b i j d, -> b i j d', neighbor_mask, rel_dist_feats, 0.) 156 | else: 157 | if exists(mask): 158 | rel_dist_feats = einx.where('b j, b i j d, -> b i j d', mask, rel_dist_feats, 0.) 
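        # aggregate neighbor information for eq (2): embed each neighboring atom, modulate it by the radial distance features computed above, then sum over neighbors weighted by the adjacency matrix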
159 | 160 | neighbor_embeds = self.neighbor_atom_embed(atoms) 161 | 162 | if exists(neighbor_indices): 163 | neighbor_embeds = get_at('b [n] d, b i j -> b i j d', neighbor_embeds, neighbor_indices) 164 | 165 | neighbor_feats = einsum('b i j, b i j d, b i j d -> b i d', adj_mat.type(dtype), rel_dist_feats, neighbor_embeds) 166 | else: 167 | neighbor_feats = einsum('b i j, b i j d, b j d -> b i d', adj_mat.type(dtype), rel_dist_feats, neighbor_embeds) 168 | 169 | self_and_neighbor = torch.cat((embeds, neighbor_feats), dim = -1) 170 | 171 | return self.to_node_feats(self_and_neighbor) 172 | 173 | # edge scalar feat init 174 | # eq (3) 175 | 176 | class EdgeScalarFeatInit(Module): 177 | def __init__( 178 | self, 179 | dim, 180 | expansion_factor = 4., 181 | ): 182 | super().__init__() 183 | 184 | dim_inner = int(dim * expansion_factor) 185 | 186 | self.rel_dist_mlp = Sequential( 187 | Rearrange('... -> ... 1'), 188 | nn.Linear(1, dim_inner, bias = False), 189 | LayerNorm(dim_inner), 190 | nn.SiLU(), 191 | nn.Linear(dim_inner, dim, bias = False) 192 | ) 193 | 194 | def forward( 195 | self, 196 | h: Float['b n d'], 197 | rel_dist: Float['b n m'], 198 | neighbor_indices: Int['b n m'] | None = None 199 | ) -> Float['b n m d']: 200 | 201 | if exists(neighbor_indices): 202 | h_neighbors = get_at('b [n] d, b i j -> b i j d', h, neighbor_indices) 203 | 204 | outer_sum_feats = einx.add('b i d, b i j d -> b i j d', h, h_neighbors) 205 | else: 206 | outer_sum_feats = einx.add('b i d, b j d -> b i j d', h, h) 207 | 208 | rel_dist_feats = self.rel_dist_mlp(rel_dist) 209 | 210 | return outer_sum_feats + rel_dist_feats 211 | 212 | # equivariant feedforward 213 | # section 3.5 214 | 215 | class EquivariantFeedForward(Module): 216 | def __init__( 217 | self, 218 | dim, 219 | max_degree, 220 | mlp_expansion_factor = 2., 221 | layernorm_input = False # no mention of this in the paper, but think there should be one based on my own intuition 222 | ): 223 | """ 224 | following eq 13 225 | """ 226 | super().__init__() 227 | assert max_degree > 1 228 | self.max_degree = max_degree 229 | 230 | mlp_dim = int(mlp_expansion_factor * dim * 2) 231 | 232 | self.projs = ParameterList([nn.Parameter(torch.randn(dim, dim)) for _ in range(max_degree)]) 233 | 234 | self.mlps = ModuleList([ 235 | Sequential( 236 | LayerNorm(dim * 2) if layernorm_input else nn.Identity(), 237 | Linear(dim * 2, mlp_dim), 238 | nn.SiLU(), 239 | Linear(mlp_dim, dim * 2) 240 | ) 241 | for _ in range(max_degree) 242 | ]) 243 | 244 | def forward( 245 | self, 246 | h: Float['b n d'], 247 | x: Sequence[Float['b n d _'], ...] 248 | ): 249 | assert len(x) == self.max_degree 250 | 251 | h_residual = 0. 252 | x_residuals = [] 253 | 254 | for one_degree, proj, mlp in zip(x, self.projs, self.mlps): 255 | 256 | # make higher degree tensor invariant through norm on `m` axis and then concat -> mlp -> split 257 | 258 | proj_one_degree = einsum('... d m, ... d e -> ... e m', one_degree, proj) 259 | 260 | normed_invariant = proj_one_degree.norm(dim = -1) 261 | 262 | mlp_inp = torch.cat((h, normed_invariant), dim = - 1) 263 | mlp_out = mlp(mlp_inp) 264 | 265 | m1, m2 = mlp_out.chunk(2, dim = -1) # named m1, m2 in equation 13, one is the residual for h, the other modulates the projected higher degree tensor for its residual 266 | 267 | modulated_one_degree = einx.multiply('... d m, ... d -> ... 
d m', proj_one_degree, m2) 268 | 269 | # aggregate residuals 270 | 271 | h_residual = h_residual + m1 272 | 273 | x_residuals.append(modulated_one_degree) 274 | 275 | # return residuals 276 | 277 | return h_residual, x_residuals 278 | 279 | # hierarchical tensor refinement 280 | # section 3.4 281 | 282 | class HierarchicalTensorRefinement(Module): 283 | def __init__( 284 | self, 285 | dim, 286 | dim_edge_refinement, # they made this value much higher for MD22 task. so it is an important hparam for certain more difficult tasks 287 | max_degree, 288 | norm_edge_proj_input = True # this was not in the paper, but added or else network explodes at around depth 4 289 | ): 290 | super().__init__() 291 | assert max_degree > 0 292 | 293 | 294 | # in paper, queries share the same projection, but each higher degree has its own key projection 295 | 296 | self.to_queries = nn.Parameter(torch.randn(dim, dim_edge_refinement)) 297 | 298 | self.to_keys = ParameterList([nn.Parameter(torch.randn(dim, dim_edge_refinement)) for _ in range(max_degree)]) 299 | 300 | # the two weight matrices 301 | # one for mixing the inner product between all queries and keys across degrees above 302 | # the other for refining the t_ij passed down from the layer before as a residual 303 | 304 | self.residue_update = nn.Linear(dim, dim, bias = False) 305 | self.edge_proj = nn.Linear(dim_edge_refinement * max_degree, dim, bias = False) 306 | 307 | # norm not in diagram or paper, added to prevent t_ij from exploding 308 | 309 | self.norm = LayerNorm(dim_edge_refinement * max_degree) if norm_edge_proj_input else nn.Identity() 310 | 311 | def forward( 312 | self, 313 | t_ij: Float['b n m d'], 314 | x: Sequence[Float['b n d _'], ...], 315 | neighbor_indices: Int['b n m'] | None = None, 316 | ) -> Float['b n m d']: 317 | 318 | # eq (10) 319 | 320 | queries = [einsum('... d m, ... d e -> ... e m', one_degree, self.to_queries) for one_degree in x] 321 | 322 | keys = [einsum('... d m, ... d e -> ... e m', one_degree, to_keys) for one_degree, to_keys in zip(x, self.to_keys)] 323 | 324 | # eq (11) 325 | 326 | if exists(neighbor_indices): 327 | keys = [get_at('b [n] d m, b i j -> b i j d m', one_degree_key, neighbor_indices) for one_degree_key in keys] 328 | 329 | inner_product = [einsum('... i d m, ... i j d m -> ... i j d', one_degree_query, one_degree_key) for one_degree_query, one_degree_key in zip(queries, keys)] 330 | else: 331 | inner_product = [einsum('... i d m, ... j d m -> ... 
i j d', one_degree_query, one_degree_key) for one_degree_query, one_degree_key in zip(queries, keys)] 332 | 333 | w_ij = cat(inner_product, dim = -1) 334 | 335 | # this was not in the paper, but added or else network explodes at around depth 4 336 | 337 | w_ij = self.norm(w_ij) 338 | 339 | # eq (12) 340 | 341 | edge_proj_out = self.edge_proj(w_ij) 342 | edge_proj_out = torch.sigmoid(edge_proj_out) 343 | 344 | residue_update_out = self.residue_update(t_ij) 345 | 346 | return edge_proj_out * residue_update_out 347 | 348 | # geometry-aware tensor attention 349 | # section 3.3 350 | 351 | class GeometryAwareTensorAttention(Module): 352 | def __init__( 353 | self, 354 | dim, 355 | max_degree, 356 | dim_head = None, 357 | heads = 8, 358 | softclamp_value = 50., 359 | mlp_expansion_factor = 2., 360 | only_init_high_degree_feats = False, # if set to True, only returns high degree steerable features eq (4) in section 3.2 361 | learned_value_residual_mix = False, 362 | ): 363 | super().__init__() 364 | self.only_init_high_degree_feats = only_init_high_degree_feats 365 | 366 | assert max_degree > 0 367 | self.max_degree = max_degree 368 | 369 | dim_head = default(dim_head, dim) 370 | 371 | # for some reason, there is no mention of attention heads, will just improvise 372 | 373 | dim_inner = heads * dim_head 374 | 375 | self.split_heads = Rearrange('b ... (h d) -> b h ... d', h = heads) 376 | self.merge_heads = Rearrange('b h ... d -> b ... (h d)') 377 | 378 | # eq (5) - layernorms are present in the diagram in figure 2. but not mentioned in the equations.. 379 | 380 | self.to_hi = LayerNorm(dim) 381 | self.to_hj = LayerNorm(dim) 382 | 383 | self.to_queries = Linear(dim, dim_inner, bias = False) 384 | self.to_keys = Linear(dim, dim_inner, bias = False) 385 | 386 | dim_mlp_inner = int(mlp_expansion_factor * dim_inner) 387 | 388 | # attention softclamping, used in Gemma 389 | 390 | self.softclamp = partial(softclamp, value = softclamp_value) 391 | 392 | # S contains two parts of L_max (one to modulate each degree of r_ij, another to modulate each X_j, then one final to modulate h). incidentally, this overlaps with eq. (m = 2 * L + 1), causing much confusion, cleared up in openreview 393 | 394 | self.S = (1, max_degree, max_degree) if not only_init_high_degree_feats else (max_degree,) 395 | S = sum(self.S) 396 | 397 | self.to_values = Sequential( 398 | Linear(dim, dim_mlp_inner), 399 | nn.SiLU(), 400 | Linear(dim_mlp_inner, S * dim_inner), 401 | Rearrange('... (s d) -> ... s d', s = S) 402 | ) 403 | 404 | # value residual, iclr 2024 paper that is certain to take off 405 | 406 | self.to_value_residual_mix = Sequential( 407 | Linear(dim, heads, bias = False), 408 | nn.Sigmoid(), 409 | Rearrange('b n h -> b h n 1 1') 410 | ) if learned_value_residual_mix else None 411 | 412 | # eq (6) second half: t_ij -> edge scalar features 413 | 414 | self.to_edge_keys = Sequential( 415 | Linear(dim, S * dim_inner, bias = False), # W_re 416 | nn.SiLU(), # σ_k - never indicated in paper. just use Silu 417 | Rearrange('... (s d) -> ... s d', s = S) 418 | ) 419 | 420 | # eq (7) - todo, handle cutoff radius 421 | 422 | self.to_edge_values = nn.Sequential( # t_ij modulating weighted sum 423 | Linear(dim, S * dim_inner, bias = False), 424 | Rearrange('... (s d) -> ... s d', s = S) 425 | ) 426 | 427 | self.post_attn_h_values = Sequential( 428 | Linear(dim, dim_mlp_inner), 429 | nn.SiLU(), 430 | Linear(dim_mlp_inner, S * dim_inner), 431 | Rearrange('... (s d) -> ... 
s d', s = S) 432 | ) 433 | 434 | # alphafold styled gating 435 | 436 | self.to_gates = Sequential( 437 | Linear(dim, heads * S), 438 | nn.Sigmoid(), 439 | Rearrange('b i (h s) -> b h i 1 s 1', s = S), 440 | ) 441 | 442 | # combine heads 443 | 444 | self.combine_heads = Sequential( 445 | Linear(dim_inner, dim, bias = False) 446 | ) 447 | 448 | def forward( 449 | self, 450 | h: Float['b n d'], 451 | t_ij: Float['b n m d'], 452 | r_ij: Sequence[Float['b n m _'], ...], 453 | x: Sequence[Float['b n d _'], ...] | None = None, 454 | mask: Bool['b n'] | None = None, 455 | neighbor_indices: Int['b n m'] | None = None, 456 | neighbor_mask: Bool['b n m'] | None = None, 457 | return_value_residual = False, 458 | value_residuals: Tuple[Tensor, Tensor] | None = None 459 | ): 460 | # validation 461 | 462 | assert exists(x) ^ self.only_init_high_degree_feats 463 | 464 | if not self.only_init_high_degree_feats: 465 | assert len(x) == self.max_degree 466 | 467 | assert len(r_ij) == self.max_degree 468 | 469 | # eq (5) 470 | 471 | hi = self.to_hi(h) 472 | hj = self.to_hj(h) 473 | 474 | queries = self.to_queries(hi) 475 | keys = self.to_keys(hj) 476 | 477 | # unsure why values are split into two, with one elementwise-multiplied with the edge values coming from t_ij 478 | # need another pair of eyes to double check 479 | 480 | values = self.to_values(hj) 481 | post_attn_values = self.post_attn_h_values(hj) 482 | 483 | # edge keys and values 484 | 485 | edge_keys = self.to_edge_keys(t_ij) 486 | edge_values = self.to_edge_values(t_ij) 487 | 488 | # split out attention heads 489 | 490 | queries, keys, values, post_attn_values, edge_keys, edge_values = map(self.split_heads, (queries, keys, values, post_attn_values, edge_keys, edge_values)) 491 | 492 | # value residual mixing 493 | 494 | next_value_residuals = (values, post_attn_values, edge_values) 495 | 496 | if exists(self.to_value_residual_mix): 497 | assert exists(value_residuals) 498 | 499 | value_residual, post_attn_values_residual, edge_values_residual = value_residuals 500 | 501 | mix = self.to_value_residual_mix(hi) 502 | 503 | values = values.lerp(value_residual, mix) 504 | post_attn_values = post_attn_values.lerp(post_attn_values_residual, mix) 505 | 506 | if exists(neighbor_indices): 507 | mix = get_at('b h [n] ..., b i j -> b h i j ...', mix, neighbor_indices) 508 | else: 509 | mix = rearrange(mix, 'b h j ... -> b h 1 j ...') 510 | 511 | edge_values = edge_values.lerp(edge_values_residual, mix) 512 | 513 | # account for neighbor logic 514 | 515 | if exists(neighbor_indices): 516 | keys = get_at('b h [n] ..., b i j -> b h i j ...', keys, neighbor_indices) 517 | values = get_at('b h [n] ..., b i j -> b h i j ...', values, neighbor_indices) 518 | post_attn_values = get_at('b h [n] ..., b i j -> b h i j ...', post_attn_values, neighbor_indices) 519 | 520 | # eq (6) 521 | 522 | # unsure why there is a k-dimension in the paper math notation, in addition to i, j 523 | 524 | if exists(neighbor_indices): 525 | keys = einx.multiply('... i j d, ... i j s d -> ... i j s d', keys, edge_keys) 526 | else: 527 | keys = einx.multiply('... j d, ... i j s d -> ... i j s d', keys, edge_keys) 528 | 529 | # similarities 530 | 531 | sim = einsum('... i d, ... i j s d -> ... 
i j s', queries, keys) 532 | 533 | # soft clamping - used successfully in gemma to prevent attention logit overflows 534 | 535 | sim = self.softclamp(sim) 536 | 537 | # masking 538 | 539 | if exists(neighbor_indices): 540 | if exists(neighbor_mask): 541 | sim = einx.where('b i j, b h i j s, -> b h i j s', neighbor_mask, sim, max_neg_value(sim)) 542 | else: 543 | if exists(mask): 544 | sim = einx.where('b j, b h i j s, -> b h i j s', mask, sim, max_neg_value(sim)) 545 | 546 | # attend 547 | 548 | attn = sim.softmax(dim = -2) 549 | 550 | # aggregate values 551 | 552 | if exists(neighbor_indices): 553 | sea_ij = einsum('... i j s, ... i j s d -> ... i j s d', attn, values) 554 | else: 555 | sea_ij = einsum('... i j s, ... j s d -> ... i j s d', attn, values) 556 | 557 | # eq (7) 558 | 559 | if exists(neighbor_indices): 560 | sea_ij = sea_ij + einx.multiply('... i j s d, ... i j s d -> ... i j s d', edge_values, post_attn_values) 561 | else: 562 | sea_ij = sea_ij + einx.multiply('... i j s d, ... j s d -> ... i j s d', edge_values, post_attn_values) 563 | 564 | # alphafold style gating 565 | 566 | out = sea_ij * self.to_gates(hi) 567 | 568 | # combine heads - not in paper for some reason, but attention heads mentioned, so must be necessary? 569 | 570 | out = self.merge_heads(out) 571 | 572 | out = self.combine_heads(out) 573 | 574 | # maybe eq (4) and early return 575 | 576 | if self.only_init_high_degree_feats: 577 | x_ij_init = [einsum('... i j m, ... i j d -> ... i d m', one_r_ij, one_r_ij_scale) for one_r_ij, one_r_ij_scale in zip(r_ij, out.unbind(dim = -2))] 578 | return x_ij_init 579 | 580 | # split out all the O's (eq 7 second half) 581 | 582 | h_scales, r_ij_scales, x_scales = out.split(self.S, dim = -2) 583 | 584 | # modulate with invariant scales and sum residuals 585 | 586 | h_residual = reduce(h_scales, 'b i j 1 d -> b i d', 'sum') 587 | x_residuals = [] 588 | 589 | for one_degree, one_r_ij, one_degree_scale, one_r_ij_scale in zip(x, r_ij, x_scales.unbind(dim = -2), r_ij_scales.unbind(dim = -2)): 590 | 591 | r_ij_residual = einsum('b i j m, b i j d -> b i d m', one_r_ij, one_r_ij_scale) 592 | 593 | if exists(neighbor_indices): 594 | one_degree_neighbors = get_at('b [n] d m, b i j -> b i j d m', one_degree, neighbor_indices) 595 | 596 | x_ij_residual = einsum('b i j d m, b i j d -> b i d m', one_degree_neighbors, one_degree_scale) 597 | 598 | else: 599 | x_ij_residual = einsum('b j d m, b i j d -> b i d m', one_degree, one_degree_scale) 600 | 601 | x_residuals.append(r_ij_residual + x_ij_residual) 602 | 603 | out = (h_residual, x_residuals) 604 | 605 | if not return_value_residual: 606 | return out 607 | 608 | return out, next_value_residuals 609 | 610 | # full attention 611 | 612 | class InvariantAttention(Module): 613 | def __init__( 614 | self, 615 | dim, 616 | **attn_kwargs 617 | ): 618 | super().__init__() 619 | self.norm = nn.RMSNorm(dim) 620 | self.attn = Attention(dim, **attn_kwargs) 621 | 622 | def forward(self, h, mask = None): 623 | h = self.norm(h) 624 | return self.attn(h, mask = mask) 625 | 626 | # main class 627 | 628 | class GotenNet(Module): 629 | def __init__( 630 | self, 631 | dim, 632 | depth, 633 | max_degree, 634 | dim_edge_refinement = None, 635 | accept_embed = False, 636 | num_atoms = 14, 637 | heads = 8, 638 | dim_head = None, 639 | cutoff_radius = None, 640 | invariant_full_attn = False, 641 | invariant_attn_use_flash = False, 642 | full_attn_kwargs: dict = dict(), 643 | max_neighbors = float('inf'), 644 | mlp_expansion_factor = 2., 645 | 
edge_init_mlp_expansion_factor = 4., 646 | ff_kwargs: dict = dict(), 647 | return_coors = True, 648 | proj_invariant_dim = None, 649 | final_norm = True, 650 | add_value_residual = True, 651 | num_residual_streams = 4, 652 | htr_kwargs: dict = dict() 653 | ): 654 | super().__init__() 655 | self.accept_embed = accept_embed 656 | 657 | assert max_degree > 0 658 | self.max_degree = max_degree 659 | 660 | dim_edge_refinement = default(dim_edge_refinement, dim) 661 | 662 | # hyper connections, applied to invariant h for starters 663 | 664 | init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1) 665 | 666 | # only consider neighbors less than `cutoff_radius`, in paper, they used ~ 5 angstroms 667 | # can further randomly select from eligible neighbors with `max_neighbors` 668 | 669 | self.cutoff_radius = cutoff_radius 670 | self.max_neighbors = max_neighbors 671 | 672 | # node and edge feature init 673 | 674 | self.node_init = NodeScalarFeatInit(num_atoms, dim, accept_embed = accept_embed) 675 | self.edge_init = EdgeScalarFeatInit(dim, expansion_factor = edge_init_mlp_expansion_factor) 676 | 677 | self.high_degree_init = GeometryAwareTensorAttention( 678 | dim, 679 | max_degree = max_degree, 680 | dim_head = dim_head, 681 | heads = heads, 682 | mlp_expansion_factor = mlp_expansion_factor, 683 | only_init_high_degree_feats = True 684 | ) 685 | 686 | # layers, thus deep learning 687 | 688 | self.layers = ModuleList([]) 689 | self.residual_fns = ModuleList([]) 690 | 691 | for layer_index in range(depth): 692 | is_first = layer_index == 0 693 | 694 | self.layers.append(ModuleList([ 695 | HierarchicalTensorRefinement(dim, dim_edge_refinement, max_degree, **htr_kwargs), 696 | InvariantAttention(dim = dim, flash = invariant_attn_use_flash, **full_attn_kwargs) if invariant_full_attn else None, 697 | GeometryAwareTensorAttention(dim, max_degree, dim_head, heads, mlp_expansion_factor, learned_value_residual_mix = add_value_residual and not is_first), 698 | EquivariantFeedForward(dim, max_degree, mlp_expansion_factor), 699 | ])) 700 | 701 | self.residual_fns.append(ModuleList([ 702 | init_hyper_conn(dim = dim) if invariant_full_attn else None, 703 | init_hyper_conn(dim = dim), 704 | init_hyper_conn(dim = dim), 705 | ])) 706 | 707 | # not mentioned in paper, but transformers need a final norm 708 | 709 | self.final_norm = final_norm 710 | 711 | if final_norm: 712 | self.h_final_norm = nn.LayerNorm(dim) 713 | 714 | self.x_final_norms = ModuleList([HighDegreeNorm(dim) for _ in range(max_degree)]) 715 | 716 | # maybe project invariant 717 | 718 | self.proj_invariant = None 719 | 720 | if exists(proj_invariant_dim): 721 | self.proj_invariant = Linear(dim, proj_invariant_dim, bias = False) 722 | 723 | # maybe project to coordinates 724 | 725 | self.proj_to_coors = Sequential( 726 | Rearrange('... d m -> ... m d'), 727 | Linear(dim, 1, bias = False), 728 | Rearrange('... 
1 -> ...') 729 | ) if return_coors else None 730 | 731 | def forward( 732 | self, 733 | atoms: Int['b n'] | Float['b n d'], 734 | coors: Float['b n 3'], 735 | adj_mat: Bool['b n n'] | None = None, 736 | lens: Int['b'] | None = None, 737 | mask: Bool['b n'] | None = None 738 | ): 739 | assert (atoms.dtype in (torch.int, torch.long)) ^ self.accept_embed 740 | 741 | batch, seq_len, device = *atoms.shape[:2], atoms.device 742 | 743 | assert not (exists(lens) and exists(mask)), '`lens` and `masks` cannot be both passed in' 744 | 745 | if exists(lens): 746 | mask = mask_from_lens(lens, seq_len) 747 | 748 | # also allow for negative atom indices in place of `lens` or `mask` 749 | 750 | if atoms.dtype in (torch.int, torch.long): 751 | atom_mask = atoms >= 0 752 | mask = default(mask, atom_mask) 753 | 754 | rel_pos = einx.subtract('b i c, b j c -> b i j c', coors, coors) 755 | rel_dist = rel_pos.norm(dim = -1) 756 | 757 | # process adjacency matrix 758 | 759 | if exists(adj_mat): 760 | eye = torch.eye(seq_len, device = device, dtype = torch.bool) 761 | adj_mat = adj_mat & ~eye # remove self from adjacency matrix 762 | 763 | # figure out neighbors, if needed 764 | 765 | neighbor_indices: Int['b n m'] | None = None 766 | neighbor_mask: Bool['b n m'] | None = None 767 | 768 | if exists(self.cutoff_radius): 769 | 770 | if exists(mask): 771 | rel_dist = einx.where('b j, b i j, -> b i j', mask, rel_dist, 1e6) 772 | 773 | is_neighbor = (rel_dist <= self.cutoff_radius).float() 774 | 775 | max_eligible_neighbors = is_neighbor.sum(dim = -1).long().amax().item() 776 | max_neighbors = min(max_eligible_neighbors, self.max_neighbors) 777 | 778 | noised_is_neighbor = is_neighbor + torch.rand_like(is_neighbor) * 1e-3 779 | neighbor_indices = noised_is_neighbor.topk(k = max_neighbors, dim = -1).indices 780 | 781 | if exists(adj_mat): 782 | adj_mat = adj_mat.gather(-1, neighbor_indices) 783 | 784 | neighbor_dist = rel_dist.gather(-1, neighbor_indices) 785 | neighbor_mask = neighbor_dist <= self.cutoff_radius 786 | 787 | rel_dist = neighbor_dist 788 | rel_pos = rel_pos.gather(-2, repeat(neighbor_indices, '... -> ... 
c', c = 3)) 789 | 790 | # initialization 791 | 792 | h = self.node_init(atoms, rel_dist, adj_mat, mask = mask, neighbor_indices = neighbor_indices, neighbor_mask = neighbor_mask) 793 | 794 | t_ij = self.edge_init(h, rel_dist, neighbor_indices = neighbor_indices) 795 | 796 | # constitute r_ij from section 3.1 797 | 798 | r_ij = [] 799 | 800 | for degree in range(1, self.max_degree + 1): 801 | one_degree_r_ij = spherical_harmonics(degree, rel_pos, normalize = True, normalization = 'norm') 802 | r_ij.append(one_degree_r_ij) 803 | 804 | # init the high degrees 805 | 806 | x = self.high_degree_init(h, t_ij, r_ij, mask = mask, neighbor_indices = neighbor_indices, neighbor_mask = neighbor_mask) 807 | 808 | # value residual 809 | 810 | value_residuals = None 811 | 812 | # maybe expand invariant h residual stream 813 | 814 | h = self.expand_streams(h) 815 | 816 | # go through the layers 817 | 818 | for (htr, maybe_h_attn, attn, ff), (maybe_h_attn_residual_fn, attn_residual_fn, ff_residual_fn) in zip(self.layers, self.residual_fns): 819 | 820 | # hierarchical tensor refinement 821 | 822 | t_ij = htr(t_ij, x, neighbor_indices = neighbor_indices) + t_ij 823 | 824 | # maybe full flash attention across invariants 825 | 826 | if exists(maybe_h_attn): 827 | h, add_h_attn_residual = maybe_h_attn_residual_fn(h) 828 | 829 | h = maybe_h_attn(h, mask = mask) 830 | 831 | h = add_h_attn_residual(h) 832 | 833 | # followed by attention, but of course 834 | 835 | h, add_attn_residual = attn_residual_fn(h) 836 | 837 | (h_residual, x_residuals), next_value_residuals = attn(h, t_ij, r_ij, x, mask = mask, neighbor_indices = neighbor_indices, neighbor_mask = neighbor_mask, value_residuals = value_residuals, return_value_residual = True) 838 | 839 | # add attention residuals 840 | 841 | h = add_attn_residual(h_residual) 842 | 843 | x = [*map(sum, zip(x, x_residuals))] 844 | 845 | # handle value residual 846 | 847 | value_residuals = default(value_residuals, next_value_residuals) 848 | 849 | # feedforward 850 | 851 | h, add_ff_residual = ff_residual_fn(h) 852 | 853 | h_residual, x_residuals = ff(h, x) 854 | 855 | # add feedforward residuals 856 | 857 | h = add_ff_residual(h_residual) 858 | 859 | x = [*map(sum, zip(x, x_residuals))] 860 | 861 | h = self.reduce_streams(h) 862 | 863 | # maybe final norms 864 | 865 | if self.final_norm: 866 | h = self.h_final_norm(h) 867 | x = [norm(one_degree) for one_degree, norm in zip(x, self.x_final_norms)] 868 | 869 | # maybe transform invariant h 870 | 871 | if exists(self.proj_invariant): 872 | h = self.proj_invariant(h) 873 | 874 | # return h and x if `return_coors = False` 875 | 876 | if not exists(self.proj_to_coors): 877 | return h, x 878 | 879 | degree1, *_ = x 880 | 881 | coors_out = self.proj_to_coors(degree1) 882 | 883 | return h, coors_out 884 | -------------------------------------------------------------------------------- /gotennet_pytorch/tensor_typing.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | 3 | from jaxtyping import ( 4 | Float, 5 | Int, 6 | Bool 7 | ) 8 | 9 | # jaxtyping is a misnomer, works for pytorch 10 | 11 | class TorchTyping: 12 | def __init__(self, abstract_dtype): 13 | self.abstract_dtype = abstract_dtype 14 | 15 | def __getitem__(self, shapes: str): 16 | return self.abstract_dtype[Tensor, shapes] 17 | 18 | Float = TorchTyping(Float) 19 | Int = TorchTyping(Int) 20 | Bool = TorchTyping(Bool) 21 | 22 | __all__ = [ 23 | Float, 24 | Int, 25 | Bool 26 | ] 27 | 
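A quick aside, not part of the repository: the shim above simply curries the jaxtyping annotators with `torch.Tensor`, so the shorthand used throughout `gotennet.py` reads as a plain jaxtyping annotation. A minimal sketch (the `center_coors` helper is hypothetical, purely for illustration):

```python
from torch import randn
from gotennet_pytorch.tensor_typing import Float

# Float['b n 3'] expands to jaxtyping.Float[torch.Tensor, 'b n 3']
def center_coors(coors: Float['b n 3']) -> Float['b n 3']:
    # subtract the centroid over the sequence (atom) dimension
    return coors - coors.mean(dim = 1, keepdim = True)

coors = randn(2, 16, 3)
centered = center_coors(coors)
```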
-------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "gotennet-pytorch" 3 | version = "0.3.1" 4 | description = "GotenNet in Pytorch" 5 | authors = [ 6 | { name = "Phil Wang", email = "lucidrains@gmail.com" } 7 | ] 8 | readme = "README.md" 9 | requires-python = ">= 3.9" 10 | license = { file = "LICENSE" } 11 | keywords = [ 12 | 'artificial intelligence', 13 | 'deep learning', 14 | 'transformers', 15 | 'attention mechanism', 16 | 'se3 equivariance', 17 | 'molecules' 18 | ] 19 | classifiers=[ 20 | 'Development Status :: 4 - Beta', 21 | 'Intended Audience :: Developers', 22 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 23 | 'License :: OSI Approved :: MIT License', 24 | 'Programming Language :: Python :: 3.8', 25 | ] 26 | 27 | dependencies = [ 28 | 'e3nn', 29 | 'einx>=0.3.0', 30 | 'einops>=0.8.0', 31 | 'jaxtyping', 32 | 'hyper-connections>=0.1.0', 33 | 'x-transformers>=1.44.4', 34 | 'torch>=2.4', 35 | ] 36 | 37 | [project.urls] 38 | Homepage = "https://pypi.org/project/gotennet-pytorch/" 39 | Repository = "https://github.com/lucidrains/gotennet-pytorch" 40 | 41 | [project.optional-dependencies] 42 | examples = ["tqdm", "numpy"] 43 | 44 | [build-system] 45 | requires = ["hatchling"] 46 | build-backend = "hatchling.build" 47 | 48 | [tool.rye] 49 | managed = true 50 | dev-dependencies = [ 51 | "ruff>=0.4.2", 52 | "pytest>=8.2.0", 53 | "pytest-examples>=0.0.10", 54 | "pytest-cov>=5.0.0", 55 | ] 56 | 57 | [tool.pytest.ini_options] 58 | pythonpath = ["."] 59 | 60 | [tool.hatch.metadata] 61 | allow-direct-references = true 62 | 63 | [tool.hatch.build.targets.wheel] 64 | packages = ["gotennet_pytorch"] 65 | -------------------------------------------------------------------------------- /tests/test_gotennet.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch import sin, cos, stack 4 | from einops import rearrange 5 | 6 | # random rotations 7 | 8 | def rot_z(gamma): 9 | c = cos(gamma) 10 | s = sin(gamma) 11 | z = torch.zeros_like(gamma) 12 | o = torch.ones_like(gamma) 13 | 14 | out = stack(( 15 | c, -s, z, 16 | s, c, z, 17 | z, z, o 18 | ), dim = -1) 19 | 20 | return rearrange(out, '... (r1 r2) -> ... r1 r2', r1 = 3) 21 | 22 | def rot_y(beta): 23 | c = cos(beta) 24 | s = sin(beta) 25 | z = torch.zeros_like(beta) 26 | o = torch.ones_like(beta) 27 | 28 | out = stack(( 29 | c, z, s, 30 | z, o, z, 31 | -s, z, c 32 | ), dim = -1) 33 | 34 | return rearrange(out, '... (r1 r2) -> ... 
r1 r2', r1 = 3) 35 | 36 | def rot(alpha, beta, gamma): 37 | return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma) 38 | 39 | # testing 40 | 41 | from gotennet_pytorch.gotennet import GotenNet, torch_default_dtype 42 | 43 | @torch_default_dtype(torch.float64) 44 | @pytest.mark.parametrize('invariant_full_attn', (False, True)) 45 | def test_invariant(invariant_full_attn): 46 | model = GotenNet( 47 | dim = 256, 48 | max_degree = 2, 49 | depth = 4, 50 | heads = 2, 51 | dim_head = 32, 52 | dim_edge_refinement = 256, 53 | invariant_full_attn = invariant_full_attn, 54 | return_coors = False, 55 | ff_kwargs = dict( 56 | layernorm_input = True 57 | ) 58 | ) 59 | 60 | random_rotation = rot(*torch.randn(3)) 61 | 62 | atom_ids = torch.randint(0, 14, (1, 12)) 63 | coors = torch.randn(1, 12, 3) 64 | adj_mat = torch.randint(0, 2, (1, 12, 12)).bool() 65 | mask = torch.randint(0, 2, (1, 12)).bool() 66 | 67 | inv1, _ = model(atom_ids, adj_mat = adj_mat, coors = coors, mask = mask) 68 | inv2, _ = model(atom_ids, adj_mat = adj_mat, coors = coors @ random_rotation, mask = mask) 69 | 70 | assert torch.allclose(inv1, inv2, atol = 1e-5) 71 | 72 | @torch_default_dtype(torch.float64) 73 | @pytest.mark.parametrize('num_residual_streams', (1, 4)) 74 | @pytest.mark.parametrize('invariant_full_attn', (False, True)) 75 | def test_equivariant( 76 | num_residual_streams, 77 | invariant_full_attn 78 | ): 79 | 80 | model = GotenNet( 81 | dim = 256, 82 | max_degree = 2, 83 | depth = 4, 84 | heads = 2, 85 | dim_head = 32, 86 | dim_edge_refinement = 256, 87 | return_coors = True, 88 | invariant_full_attn = invariant_full_attn, 89 | ff_kwargs = dict( 90 | layernorm_input = True 91 | ), 92 | num_residual_streams = num_residual_streams 93 | ) 94 | 95 | random_rotation = rot(*torch.randn(3)) 96 | 97 | atom_ids = torch.randint(0, 14, (1, 12)) 98 | coors = torch.randn(1, 12, 3) 99 | adj_mat = torch.randint(0, 2, (1, 12, 12)).bool() 100 | mask = torch.randint(0, 2, (1, 12)).bool() 101 | 102 | _, coors1 = model(atom_ids, adj_mat = adj_mat, coors = coors, mask = mask) 103 | _, coors2 = model(atom_ids, adj_mat = adj_mat, coors = coors @ random_rotation, mask = mask) 104 | 105 | assert torch.allclose(coors1 @ random_rotation, coors2, atol = 1e-5) 106 | 107 | @torch_default_dtype(torch.float64) 108 | def test_equivariant_with_atom_feats(): 109 | 110 | model = GotenNet( 111 | dim = 256, 112 | max_degree = 2, 113 | depth = 4, 114 | heads = 2, 115 | dim_head = 32, 116 | dim_edge_refinement = 256, 117 | accept_embed = True, 118 | return_coors = True 119 | ) 120 | 121 | random_rotation = rot(*torch.randn(3)) 122 | 123 | atom_feats = torch.randn((1, 12, 256)) 124 | coors = torch.randn(1, 12, 3) 125 | adj_mat = torch.randint(0, 2, (1, 12, 12)).bool() 126 | mask = torch.randint(0, 2, (1, 12)).bool() 127 | 128 | _, coors1 = model(atom_feats, adj_mat = adj_mat, coors = coors, mask = mask) 129 | _, coors2 = model(atom_feats, adj_mat = adj_mat, coors = coors @ random_rotation, mask = mask) 130 | 131 | assert torch.allclose(coors1 @ random_rotation, coors2, atol = 1e-5) 132 | 133 | @torch_default_dtype(torch.float64) 134 | @pytest.mark.parametrize('num_residual_streams', (1, 4)) 135 | def test_equivariant_neighbors(num_residual_streams): 136 | 137 | model = GotenNet( 138 | dim = 256, 139 | max_degree = 2, 140 | depth = 4, 141 | heads = 2, 142 | dim_head = 32, 143 | cutoff_radius = 5., 144 | dim_edge_refinement = 256, 145 | return_coors = True, 146 | num_residual_streams = num_residual_streams, 147 | ff_kwargs = dict( 148 | layernorm_input = True 149 | ) 150 
| ) 151 | 152 | random_rotation = rot(*torch.randn(3)) 153 | 154 | atom_ids = torch.randint(0, 14, (1, 6)) 155 | coors = torch.randn(1, 6, 3) 156 | adj_mat = torch.randint(0, 2, (1, 6, 6)).bool() 157 | mask = torch.randint(0, 2, (1, 6)).bool() 158 | 159 | _, coors1 = model(atom_ids, adj_mat = adj_mat, coors = coors, mask = mask) 160 | _, coors2 = model(atom_ids, adj_mat = adj_mat, coors = coors @ random_rotation, mask = mask) 161 | 162 | out1 = coors1 @ random_rotation 163 | out2 = coors2 164 | 165 | assert torch.allclose(out1, out2, atol = 1e-5) 166 | --------------------------------------------------------------------------------
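Not part of the repository, but to complement the README example: a minimal sketch of the radius-cutoff path exercised by `test_equivariant_neighbors` above, with coordinates returned. All hyperparameters here are illustrative; the ~5 angstrom cutoff follows the comment in `gotennet.py`.

```python
import torch
from gotennet_pytorch import GotenNet, torch_default_dtype

# float64 recommended when checking equivariance, as in the tests above

with torch_default_dtype(torch.float64):
    model = GotenNet(
        dim = 128,
        max_degree = 2,
        depth = 2,
        heads = 2,
        dim_head = 32,
        cutoff_radius = 5.,   # only neighbors within ~5 angstroms are attended to
        max_neighbors = 16,   # optionally cap the number of neighbors per atom
        return_coors = True
    )

    atom_ids = torch.randint(0, 14, (2, 32))
    coors = torch.randn(2, 32, 3)
    mask = torch.ones(2, 32).bool()

    invariant, coors_out = model(atom_ids, coors = coors, mask = mask)

    # invariant: (2, 32, 128) scalar features, coors_out: (2, 32, 3) coordinates
```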