├── experiments
│   ├── qm9
│   │   ├── __init__.py
│   │   ├── README.md
│   │   ├── utils.py
│   │   └── dataset.py
│   ├── nbody
│   │   ├── __init__.py
│   │   ├── README.md
│   │   ├── data
│   │   │   ├── generate_dataset.py
│   │   │   └── synthetic_sim.py
│   │   ├── datasets.py
│   │   └── utils.py
│   ├── requirements.txt
│   ├── __init__.py
│   └── train.py
├── setup.py
├── pyproject.toml
├── requirements.txt
├── .gitignore
├── segnn_jax
│   ├── config.py
│   ├── __init__.py
│   ├── graph_utils.py
│   ├── irreps_computer.py
│   ├── segnn.py
│   └── blocks.py
├── setup.cfg
├── .pre-commit-config.yaml
├── LICENSE
├── .github
│   └── workflows
│       └── build_branch.yaml
├── tests
│   ├── test_blocks.py
│   ├── conftest.py
│   └── test_segnn.py
├── README.md
└── validate.py

/experiments/qm9/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/experiments/nbody/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | 
3 | if __name__ == "__main__":
4 |     setup()
5 | 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0"]
3 | build-backend = "setuptools.build_meta"
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
2 | dm-haiku>=0.0.9
3 | e3nn-jax==0.17.5
4 | jax[cuda]
5 | jraph==0.0.6.dev0
6 | numpy>=1.23.4
7 | optax==0.1.7
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # experiments
2 | *.npy
3 | wandb/
4 | *.out
5 | datasets/
6 | 
7 | # dev
8 | .vscode
9 | __pycache__/
10 | *.pyc
11 | venv
12 | 
13 | # build
14 | *.whl
15 | build/
16 | dist/
17 | *.egg-info/
18 | 
--------------------------------------------------------------------------------
/experiments/qm9/README.md:
--------------------------------------------------------------------------------
1 | # QM9 experiment
2 | Simple baseline experiments from the SEGNN paper. The dataset implementation is taken directly from the [original torch implementation](https://github.com/RobDHess/Steerable-E3-GNN/blob/main/qm9/dataset.py).
3 | 
--------------------------------------------------------------------------------
/experiments/nbody/README.md:
--------------------------------------------------------------------------------
1 | # N-body experiments (charged and gravity)
2 | Simple baseline experiments from the SEGNN paper.
3 | 
4 | This code is adapted from, and borrows heavily from, the N-body implementation of [EGNN](https://github.com/vgsatorras/egnn). Used under the MIT license.
5 | -------------------------------------------------------------------------------- /experiments/requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cpu 2 | 3 | --find-links https://data.pyg.org/whl/torch-1.13.1+cpu.html 4 | 5 | e3nn==0.5.0 6 | matplotlib>=3.6.2 7 | rdkit==2022.9.2 8 | torch==1.13.1 9 | torch-cluster==1.6.0 10 | torch-geometric==2.1.0 11 | torch-scatter==2.1.0 12 | torch-sparse==0.6.15 13 | tqdm>=4.64.1 14 | wandb>=0.13.5 15 | -------------------------------------------------------------------------------- /segnn_jax/config.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | 3 | __conf = { 4 | "gradient_normalization": "element", # "element" or "path" 5 | "path_normalization": "element", # "element" or "path" 6 | "default_dtype": jnp.float32, 7 | "o3_layer": "tpl", # "tpl" (tp + Linear) or "fctp" (FullyConnected) or "scn" (SCN) 8 | } 9 | 10 | 11 | def config(key): 12 | return __conf[key] 13 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = segnn_jax 3 | version = attr: segnn_jax.__version__ 4 | author = Gianluca Galletti 5 | author_email = g.galletti@tum.de 6 | description = Steerable E(3) GNN in jax 7 | long_description = file: README.md 8 | long_description_content_type = text/markdown 9 | classifiers = 10 | Programming Language :: Python :: 3.8 11 | 12 | 13 | [options] 14 | packages = segnn_jax 15 | python_requires = >=3.8 16 | install_requires = 17 | dm_haiku>=0.0.9 18 | e3nn_jax==0.17.5 19 | jax 20 | jaxlib 21 | jraph==0.0.6.dev0 22 | numpy>=1.23.4 23 | optax==0.1.3 24 | -------------------------------------------------------------------------------- /segnn_jax/__init__.py: -------------------------------------------------------------------------------- 1 | from .blocks import ( 2 | O3TensorProduct, 3 | O3TensorProductFC, 4 | O3TensorProductGate, 5 | O3TensorProductSCN, 6 | ) 7 | from .graph_utils import SteerableGraphsTuple 8 | from .irreps_computer import balanced_irreps, weight_balanced_irreps 9 | from .segnn import SEGNN, SEGNNLayer 10 | 11 | __all__ = [ 12 | "SEGNN", 13 | "SEGNNLayer", 14 | "O3TensorProduct", 15 | "O3TensorProductFC", 16 | "O3TensorProductSCN", 17 | "O3TensorProductGate", 18 | "weight_balanced_irreps", 19 | "balanced_irreps", 20 | "SteerableGraphsTuple", 21 | ] 22 | 23 | __version__ = "0.7" 24 | -------------------------------------------------------------------------------- /experiments/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Tuple 2 | 3 | from torch.utils.data import DataLoader 4 | 5 | from .nbody.utils import setup_nbody_data 6 | from .qm9.utils import setup_qm9_data 7 | from .train import train 8 | 9 | __all__ = ["setup_data", "train"] 10 | 11 | 12 | __setup_conf = { 13 | "qm9": setup_qm9_data, 14 | "charged": setup_nbody_data, 15 | "gravity": setup_nbody_data, 16 | } 17 | 18 | 19 | def setup_data(args) -> Tuple[DataLoader, DataLoader, DataLoader, Callable, Callable]: 20 | assert args.dataset in [ 21 | "qm9", 22 | "charged", 23 | "gravity", 24 | ], f"Unknown dataset {args.dataset}" 25 | setup_fn = __setup_conf[args.dataset] 26 | return setup_fn(args) 27 | -------------------------------------------------------------------------------- 
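For reference, a tiny usage sketch of the global configuration in `segnn_jax/config.py` above; the keys are exactly the ones defined in the `__conf` dictionary:

```python
import jax.numpy as jnp

from segnn_jax.config import config

# look up the package-wide defaults by key
assert config("default_dtype") == jnp.float32
print(config("o3_layer"))                # "tpl": tensor product + Linear
print(config("gradient_normalization"))  # "element" or "path"
```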
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | default_language_version:
2 |     python: python3
3 | exclude: |
4 |     (?x)^(
5 |         datasets/|
6 |         data/|
7 |     )$
8 | repos:
9 |   - repo: https://github.com/pre-commit/pre-commit-hooks
10 |     rev: v4.0.1
11 |     hooks:
12 |       - id: check-merge-conflict
13 |       - id: check-added-large-files
14 |       - id: check-docstring-first
15 |       - id: check-json
16 |       - id: check-toml
17 |       - id: check-xml
18 |       - id: check-yaml
19 |       - id: trailing-whitespace
20 |       - id: end-of-file-fixer
21 |       - id: requirements-txt-fixer
22 |   - repo: https://github.com/pycqa/isort
23 |     rev: 5.12.0
24 |     hooks:
25 |       - id: isort
26 |         args: [ --profile, black ]
27 |   - repo: https://github.com/ambv/black
28 |     rev: 22.3.0
29 |     hooks:
30 |       - id: black
31 |   - repo: https://github.com/charliermarsh/ruff-pre-commit
32 |     rev: 'v0.0.265'
33 |     hooks:
34 |       - id: ruff
35 |         exclude: ^datasets/|^data/|^.git/|^venv/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) [2022] [Gianluca Galletti]
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/segnn_jax/graph_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, NamedTuple, Optional
2 | 
3 | import e3nn_jax as e3nn
4 | import jax.numpy as jnp
5 | import jax.tree_util as tree
6 | import jraph
7 | 
8 | 
9 | class SteerableGraphsTuple(NamedTuple):
10 |     """Pack (steerable) node and edge attributes with jraph.GraphsTuple."""
11 | 
12 |     graph: jraph.GraphsTuple
13 |     node_attributes: Optional[e3nn.IrrepsArray] = None
14 |     edge_attributes: Optional[e3nn.IrrepsArray] = None
15 |     # NOTE: additional_message_features is in a separate field otherwise it would get
16 |     # updated by jraph.GraphNetwork. Actual graph edges are used only for the messages.
17 |     additional_message_features: Optional[e3nn.IrrepsArray] = None
18 | 
19 | 
20 | def pooling(
21 |     graph: jraph.GraphsTuple,
22 |     aggregate_fn: Callable = jraph.segment_sum,
23 | ) -> e3nn.IrrepsArray:
24 |     """Pools over graph nodes with the specified aggregation.
25 | 
26 |     Args:
27 |         graph: Input graph
28 |         aggregate_fn: aggregation function used to pool over the nodes
29 | 
30 |     Returns:
31 |         The pooled graph nodes.
32 | """ 33 | n_graphs = graph.n_node.shape[0] 34 | graph_idx = jnp.arange(n_graphs) 35 | # Equivalent to jnp.sum(n_node), but jittable 36 | sum_n_node = tree.tree_leaves(graph.nodes)[0].shape[0] 37 | batch = jnp.repeat(graph_idx, graph.n_node, axis=0, total_repeat_length=sum_n_node) 38 | return e3nn.IrrepsArray( 39 | graph.nodes.irreps, aggregate_fn(graph.nodes.array, batch, n_graphs) 40 | ) 41 | -------------------------------------------------------------------------------- /.github/workflows/build_branch.yaml: -------------------------------------------------------------------------------- 1 | name: Build 2 | 3 | on: push 4 | 5 | jobs: 6 | tests: 7 | name: Tests 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/checkout@master 11 | - name: Set up docker image 12 | uses: actions/setup-python@v3 13 | with: 14 | python-version: "3.10" 15 | - name: Install dependencies 16 | run: >- 17 | python -m 18 | pip install 19 | -r requirements.txt 20 | --user 21 | - name: Install latest e3nn-jax version 22 | run: >- 23 | python -m 24 | pip install -U 25 | e3nn-jax 26 | --user 27 | - name: Install pytest 28 | run: >- 29 | python -m 30 | pip install 31 | pytest 32 | --user 33 | - name: Run tests 34 | run: >- 35 | python -m pytest tests/ 36 | 37 | build-publish: 38 | name: Build and publish 39 | runs-on: ubuntu-latest 40 | steps: 41 | - uses: actions/checkout@master 42 | - name: Set up docker image 43 | uses: actions/setup-python@v3 44 | with: 45 | python-version: "3.10" 46 | - name: Install build tools 47 | run: >- 48 | python -m 49 | pip install 50 | build 51 | --user 52 | - name: Build wheel 53 | run: >- 54 | python -m 55 | build 56 | --sdist 57 | --wheel 58 | --outdir dist/ 59 | - name: Publish to PyPI 60 | if: startsWith(github.ref, 'refs/tags') 61 | uses: pypa/gh-action-pypi-publish@release/v1 62 | with: 63 | password: ${{ secrets.PYPI_TOKEN }} 64 | -------------------------------------------------------------------------------- /tests/test_blocks.py: -------------------------------------------------------------------------------- 1 | import e3nn_jax as e3nn 2 | import haiku as hk 3 | import pytest 4 | from conftest import assert_equivariant 5 | 6 | from segnn_jax import ( 7 | O3TensorProduct, 8 | O3TensorProductFC, 9 | O3TensorProductGate, 10 | O3TensorProductSCN, 11 | ) 12 | 13 | 14 | @pytest.mark.parametrize("biases", [False, True]) 15 | @pytest.mark.parametrize( 16 | "O3Layer", [O3TensorProduct, O3TensorProductFC, O3TensorProductSCN] 17 | ) 18 | def test_linear_layers(key, biases, O3Layer): 19 | def f(x1, x2): 20 | return O3Layer("1x1o", biases=biases)(x1, x2) 21 | 22 | f = hk.without_apply_rng(hk.transform(f)) 23 | 24 | v = e3nn.normal("1x1o", key, (5,)) 25 | params = f.init(key, v, v) 26 | 27 | def wrapper(x1, x2): 28 | return f.apply(params, x1, x2) 29 | 30 | assert_equivariant( 31 | wrapper, 32 | key, 33 | e3nn.normal("1x1o", key, (5,)), 34 | e3nn.normal("1x1o", key, (5,)), 35 | ) 36 | 37 | 38 | @pytest.mark.parametrize("biases", [False, True]) 39 | @pytest.mark.parametrize( 40 | "O3Layer", [O3TensorProduct, O3TensorProductFC, O3TensorProductSCN] 41 | ) 42 | def test_gated_layers(key, biases, O3Layer): 43 | def f(x1, x2): 44 | return O3TensorProductGate("1x1o", biases=biases, o3_layer=O3Layer)(x1, x2) 45 | 46 | f = hk.without_apply_rng(hk.transform(f)) 47 | 48 | v = e3nn.normal("1x1o", key, (5,)) 49 | params = f.init(key, v, v) 50 | 51 | def wrapper(x1, x2): 52 | return f.apply(params, x1, x2) 53 | 54 | assert_equivariant( 55 | wrapper, 56 | key, 57 | e3nn.normal("1x1o", key, 
(5,)), 58 | e3nn.normal("1x1o", key, (5,)), 59 | ) 60 | 61 | 62 | if __name__ == "__main__": 63 | pytest.main() 64 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import e3nn_jax as e3nn 4 | import jax 5 | import jax.numpy as jnp 6 | import jraph 7 | import pytest 8 | 9 | from segnn_jax import SteerableGraphsTuple 10 | 11 | os.environ["CUDA_VISIBLE_DEVICES"] = "" 12 | os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" 13 | os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" 14 | 15 | 16 | @pytest.fixture 17 | def dummy_graph(): 18 | def _rand_graph(n_graphs: int = 1, attr_irreps: str = "1x0e + 1x1o"): 19 | attr_irreps = e3nn.Irreps(attr_irreps) 20 | return SteerableGraphsTuple( 21 | graph=jraph.GraphsTuple( 22 | nodes=e3nn.IrrepsArray("1x1o", jnp.ones((n_graphs * 5, 3))), 23 | edges=None, 24 | senders=jnp.zeros((10 * n_graphs,), dtype=jnp.int32), 25 | receivers=jnp.zeros((10 * n_graphs,), dtype=jnp.int32), 26 | n_node=jnp.array([5] * n_graphs), 27 | n_edge=jnp.array([10] * n_graphs), 28 | globals=None, 29 | ), 30 | additional_message_features=None, 31 | edge_attributes=None, 32 | node_attributes=e3nn.IrrepsArray( 33 | attr_irreps, jnp.ones((n_graphs * 5, attr_irreps.dim)) 34 | ), 35 | ) 36 | 37 | return _rand_graph 38 | 39 | 40 | @pytest.fixture 41 | def key(): 42 | return jax.random.PRNGKey(0) 43 | 44 | 45 | def assert_equivariant(fun, key, *args, **kwargs): 46 | try: 47 | from e3nn_jax.utils import assert_equivariant as _assert_equivariant 48 | 49 | _assert_equivariant(fun, key, *args, **kwargs) 50 | except (ImportError, AttributeError): 51 | from e3nn_jax.util import assert_equivariant as _assert_equivariant 52 | 53 | _assert_equivariant(fun, key, args_in=args, **kwargs) 54 | -------------------------------------------------------------------------------- /tests/test_segnn.py: -------------------------------------------------------------------------------- 1 | import e3nn_jax as e3nn 2 | import haiku as hk 3 | import pytest 4 | from conftest import assert_equivariant 5 | 6 | from segnn_jax import ( 7 | SEGNN, 8 | O3TensorProduct, 9 | O3TensorProductFC, 10 | O3TensorProductSCN, 11 | weight_balanced_irreps, 12 | ) 13 | 14 | 15 | @pytest.mark.parametrize("task", ["graph", "node"]) 16 | @pytest.mark.parametrize("norm", ["none", "instance"]) 17 | @pytest.mark.parametrize( 18 | "O3Layer", [O3TensorProduct, O3TensorProductFC, O3TensorProductSCN] 19 | ) 20 | def test_segnn_equivariance(key, dummy_graph, task, norm, O3Layer): 21 | scn = O3Layer == O3TensorProductSCN 22 | 23 | hidden_irreps = weight_balanced_irreps( 24 | 8, e3nn.Irreps.spherical_harmonics(1), use_sh=not scn 25 | ) 26 | 27 | def segnn(x): 28 | return SEGNN( 29 | hidden_irreps=hidden_irreps, 30 | output_irreps=e3nn.Irreps("1x1o"), 31 | num_layers=1, 32 | task=task, 33 | norm=norm, 34 | o3_layer=O3Layer, 35 | )(x) 36 | 37 | segnn = hk.without_apply_rng(hk.transform_with_state(segnn)) 38 | 39 | if scn: 40 | attr_irreps = e3nn.Irreps("1x1o") 41 | else: 42 | attr_irreps = e3nn.Irreps("1x0e+1x1o") 43 | 44 | graph = dummy_graph(attr_irreps=attr_irreps) 45 | params, segnn_state = segnn.init(key, graph) 46 | 47 | def wrapper(x): 48 | if scn: 49 | attrs = e3nn.IrrepsArray(attr_irreps, x.array) 50 | else: 51 | attrs = e3nn.spherical_harmonics(attr_irreps, x, normalize=True) 52 | st_graph = graph._replace( 53 | graph=graph.graph._replace(nodes=x), 54 | node_attributes=attrs, 55 | ) 
56 |         y, _ = segnn.apply(params, segnn_state, st_graph)
57 |         return e3nn.IrrepsArray("1x1o", y)
58 | 
59 |     assert_equivariant(wrapper, key, e3nn.normal("1x1o", key, (5,)))
60 | 
61 | 
62 | if __name__ == "__main__":
63 |     pytest.main()
64 | 
--------------------------------------------------------------------------------
/segnn_jax/irreps_computer.py:
--------------------------------------------------------------------------------
1 | from math import prod
2 | 
3 | from e3nn_jax import Irreps
4 | 
5 | 
6 | def balanced_irreps(lmax: int, feature_size: int, use_sh: bool = True) -> Irreps:
7 |     """Allocates irreps uniformly up to level lmax with budget feature_size."""
8 |     irreps = ["0e"]
9 |     n_irreps = 1 + (lmax if use_sh else lmax * 2)
10 |     total_dim = 0
11 |     for level in range(1, lmax + 1):
12 |         dim = 2 * level + 1
13 |         multi = int(feature_size / dim / n_irreps)
14 |         if multi == 0:
15 |             break
16 |         if use_sh:
17 |             irreps.append(f"{multi}x{level}{'e' if (level % 2) == 0 else 'o'}")
18 |             total_dim += multi * dim
19 |         else:
20 |             irreps.append(f"{multi}x{level}e+{multi}x{level}o")
21 |             total_dim += multi * dim * 2
22 | 
23 |     # add scalars to fill the remaining budget (total_dim is accumulated above)
24 |     irreps[0] = f"{feature_size - total_dim}x{irreps[0]}"
25 | 
26 |     return Irreps("+".join(irreps))
27 | 
28 | 
29 | def weight_balanced_irreps(
30 |     scalar_units: int, irreps_right: Irreps, use_sh: bool = True, lmax: int = None
31 | ) -> Irreps:
32 |     """
33 |     Determines irreps_left such that the parametrized tensor product
34 |         Linear(tensor_product(irreps_left, irreps_right))
35 |     has (at least) scalar_units weights.
36 | 
37 |     Args:
38 |         scalar_units: number of desired weights
39 |         irreps_right: irreps of the right tensor
40 |         use_sh: whether to use spherical harmonics
41 |         lmax: maximum level of spherical harmonics
42 |     """
43 |     # irrep order
44 |     if lmax is None:
45 |         lmax = irreps_right.lmax
46 |     # linear layer with square weight matrix
47 |     linear_weights = scalar_units**2
48 |     # raise the hidden feature size until there are enough weights
49 |     n = 0
50 |     while True:
51 |         n += 1
52 |         if use_sh:
53 |             irreps_left = (
54 |                 (Irreps.spherical_harmonics(lmax) * n).sort().irreps.simplify()
55 |             )
56 |         else:
57 |             irreps_left = balanced_irreps(lmax, n)
58 |         # number of weights in the tensor product paths
59 |         tp_weights = sum(
60 |             prod([irreps_left[i_1].mul ** 2, irreps_right[i_2].mul])
61 |             for i_1, (_, ir_1) in enumerate(irreps_left)
62 |             for i_2, (_, ir_2) in enumerate(irreps_right)
63 |             for _, (_, ir_out) in enumerate(irreps_left)
64 |             if ir_out in ir_1 * ir_2
65 |         )
66 |         if tp_weights >= linear_weights:
67 |             break
68 |     return Irreps(irreps_left)
--------------------------------------------------------------------------------
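For illustration, a brief usage sketch of the two helpers above (a sketch with arbitrary sizes, not part of the package):

```python
import e3nn_jax as e3nn

from segnn_jax import balanced_irreps, weight_balanced_irreps

# steerable attributes: spherical harmonics up to lmax=1, i.e. "1x0e+1x1o"
attr_irreps = e3nn.Irreps.spherical_harmonics(1)

# hidden irreps such that Linear(tensor_product(hidden, attr_irreps)) has at
# least 64**2 weights, mirroring a scalar MLP with 64 hidden units
hidden_irreps = weight_balanced_irreps(64, attr_irreps, use_sh=True)

# alternatively, spread a budget of 64 feature dimensions uniformly up to lmax=2
hidden_alt = balanced_irreps(lmax=2, feature_size=64)

print(hidden_irreps, hidden_alt)
```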
/README.md:
--------------------------------------------------------------------------------
1 | # Steerable E(3) GNN in jax
2 | Reimplementation of [SEGNN](https://arxiv.org/abs/2110.02905) in jax. Original work by Johannes Brandstetter, Rob Hesselink, Elise van der Pol, Erik Bekkers and Max Welling.
3 | 
4 | ## Why jax?
5 | **40-50% faster** inference and training compared to the [original torch implementation](https://github.com/RobDHess/Steerable-E3-GNN). It is also compatible with JAX-MD.
6 | 
7 | ## Installation
8 | ```
9 | python -m pip install segnn-jax
10 | ```
11 | 
12 | Or clone this repository and build locally
13 | ```
14 | python -m pip install -e .
15 | ```
16 | 
17 | ### GPU support
18 | Upgrade `jax` to the GPU version
19 | ```
20 | pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
21 | ```
22 | 
23 | ## Validation
24 | The N-body (charged and gravity) and QM9 datasets from the original paper are included for completeness.
25 | 
26 | ### Results
27 | Charged is run on 5 bodies, gravity on 100 bodies. QM9 has graphs of variable sizes, so in jax samples are padded to the maximum size (a short sketch of the padding follows the table). Loss is MSE for charged and gravity and MAE for QM9.
28 | 
29 | Times were remeasured on a Quadro RTX 4000, __model only__, on batches of 100 graphs, in (global) single precision.
30 | 
|                    | torch (original) loss | torch (original) inference [ms] | jax (ours) loss | jax (ours) inference [ms] |
|--------------------|-----------------------|---------------------------------|-----------------|---------------------------|
| charged (position) | .0043                 | 21.22                           | .0045           | 3.77                      |
| gravity (position) | .265                  | 60.55                           | .264            | 1.72                      |
| QM9 (alpha)        | .066*                 | 82.53                           | .082            | 105.98**                  |
66 | * rerun on same conditions
67 | 
68 | ** padded (naive)
69 | 
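Since jitted jax code requires static shapes, the variable-size QM9 batches are padded before they reach the model. A minimal, self-contained sketch of the idea with `jraph.pad_with_graphs` (the budgets of 8 nodes/edges are illustrative; `experiments/qm9/utils.py` derives the real ones from the largest graphs in the dataset):

```python
import jax.numpy as jnp
import jraph

# toy batch: a single graph with 3 nodes and 2 edges
graph = jraph.GraphsTuple(
    nodes=jnp.ones((3, 4)),
    edges=None,
    senders=jnp.array([0, 1]),
    receivers=jnp.array([1, 2]),
    n_node=jnp.array([3]),
    n_edge=jnp.array([2]),
    globals=None,
)

# pad nodes/edges/graphs to fixed budgets so jitted code always sees the same
# shapes; the leftover capacity goes into one extra (empty) padding graph
padded = jraph.pad_with_graphs(graph, n_node=8, n_edge=8, n_graph=2)

# boolean mask that is False on the padding graph, used to mask the loss
# (the same helper appears in experiments/train.py)
mask = jraph.get_graph_padding_mask(padded)
```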
70 | ### Validation install
71 | 
72 | The experiments are only included in the GitHub repo, so it needs to be cloned first.
73 | ```
74 | git clone https://github.com/gerkone/segnn-jax
75 | ```
76 | 
77 | They are adapted from the original implementation, so `torch` and `torch_geometric` are additionally needed (the CPU versions are enough).
78 | ```
79 | python -m pip install -r experiments/requirements.txt
80 | ```
81 | 
82 | ### Datasets
83 | QM9 is automatically downloaded and processed when running the respective experiment.
84 | 
85 | The N-body datasets have to be generated locally from the directory [experiments/nbody/data](experiments/nbody/data) (it will take some time, especially for n-body `gravity`).
86 | #### Charged dataset (5 bodies, 10000 training samples)
87 | ```
88 | python3 -u generate_dataset.py --simulation=charged --seed=43
89 | ```
90 | #### Gravity dataset (100 bodies, 10000 training samples)
91 | ```
92 | python3 -u generate_dataset.py --simulation=gravity --n-balls=100 --seed=43
93 | ```
94 | 
95 | ### Notes
96 | On `jax<=0.4.6`, the `jit`-`pjit` merge can be deactivated, which makes training faster (on nbody). This looks like an issue with data loading and the validation loop implementation; it does not affect SEGNN itself.
97 | ```
98 | export JAX_JIT_PJIT_API_MERGE=0
99 | ```
100 | 
101 | ### Usage
102 | #### N-body (charged)
103 | ```
104 | python validate.py --dataset=charged --epochs=200 --max-samples=3000 --lmax-hidden=1 --lmax-attributes=1 --layers=4 --units=64 --norm=none --batch-size=100 --lr=5e-3 --weight-decay=1e-12
105 | ```
106 | 
107 | #### N-body (gravity)
108 | ```
109 | python validate.py --dataset=gravity --epochs=100 --target=pos --max-samples=10000 --lmax-hidden=1 --lmax-attributes=1 --layers=4 --units=64 --norm=none --batch-size=100 --lr=5e-3 --weight-decay=1e-12 --neighbours=5 --n-bodies=100
110 | ```
111 | 
112 | #### QM9
113 | ```
114 | python validate.py --dataset=qm9 --epochs=1000 --target=alpha --lmax-hidden=2 --lmax-attributes=3 --layers=7 --units=128 --norm=instance --batch-size=128 --lr=5e-4 --weight-decay=1e-8 --lr-scheduling
115 | ```
116 | 
117 | (configurations used in validation)
118 | 
119 | 
120 | 
121 | ## Acknowledgments
122 | - [e3nn_jax](https://github.com/e3nn/e3nn-jax) made this reimplementation possible.
123 | - [Artur Toshev](https://github.com/arturtoshev) and [Johannes Brandstetter](https://github.com/brandstetter-johannes), for support.
124 | 
--------------------------------------------------------------------------------
/experiments/nbody/data/generate_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Generate charged and gravity datasets.
3 | 
4 | charged: python3 generate_dataset.py --simulation=charged --num-train=10000 --seed=43
5 | gravity: python3 generate_dataset.py --simulation=gravity --num-train=10000 --n-balls=100 --seed=43
6 | """
7 | import argparse
8 | import time
9 | 
10 | import numpy as np
11 | from synthetic_sim import ChargedParticlesSim, GravitySim
12 | 
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument(
15 |     "--simulation", type=str, default="charged", help="What simulation to generate."
16 | ) 17 | parser.add_argument( 18 | "--num-train", 19 | type=int, 20 | default=10000, 21 | help="Number of training simulations to generate.", 22 | ) 23 | parser.add_argument( 24 | "--num-valid", 25 | type=int, 26 | default=2000, 27 | help="Number of validation simulations to generate.", 28 | ) 29 | parser.add_argument( 30 | "--num-test", type=int, default=2000, help="Number of test simulations to generate." 31 | ) 32 | parser.add_argument("--length", type=int, default=5000, help="Length of trajectory.") 33 | parser.add_argument( 34 | "--length-test", type=int, default=5000, help="Length of test set trajectory." 35 | ) 36 | parser.add_argument( 37 | "--sample-freq", type=int, default=100, help="How often to sample the trajectory." 38 | ) 39 | parser.add_argument( 40 | "--n-balls", type=int, default=5, help="Number of balls in the simulation." 41 | ) 42 | parser.add_argument("--seed", type=int, default=42, help="Random seed.") 43 | parser.add_argument( 44 | "--initial-vel", type=int, default=1, help="consider initial velocity" 45 | ) 46 | 47 | args = parser.parse_args() 48 | 49 | initial_vel_norm = 0.5 50 | if not args.initial_vel: 51 | initial_vel_norm = 1e-16 52 | 53 | if args.simulation == "charged": 54 | sim = ChargedParticlesSim( 55 | noise_var=0.0, n_balls=args.n_balls, vel_norm=initial_vel_norm 56 | ) 57 | suffix = "_charged" 58 | elif args.simulation == "gravity": 59 | sim = GravitySim(noise_var=0.0, n_balls=args.n_balls, vel_norm=initial_vel_norm) 60 | suffix = "_gravity" 61 | else: 62 | raise ValueError("Simulation {} not implemented".format(args.simulation)) 63 | 64 | suffix += str(args.n_balls) + "_initvel%d" % args.initial_vel 65 | np.random.seed(args.seed) 66 | 67 | print(suffix) 68 | 69 | 70 | def generate_dataset(num_sims, length, sample_freq): 71 | loc_all = list() 72 | vel_all = list() 73 | edges_all = list() 74 | charges_all = list() 75 | for i in range(num_sims): 76 | t = time.time() 77 | loc, vel, edges, charges = sim.sample_trajectory( 78 | T=length, sample_freq=sample_freq 79 | ) 80 | 81 | loc_all.append(loc) 82 | vel_all.append(vel) 83 | edges_all.append(edges) 84 | charges_all.append(charges) 85 | 86 | if i % 100 == 0: 87 | print("Iter: {}, Simulation time: {}".format(i, time.time() - t)) 88 | 89 | charges_all = np.stack(charges_all) 90 | loc_all = np.stack(loc_all) 91 | vel_all = np.stack(vel_all) 92 | edges_all = np.stack(edges_all) 93 | 94 | return loc_all, vel_all, edges_all, charges_all 95 | 96 | 97 | if __name__ == "__main__": 98 | print("Generating {} training simulations".format(args.num_train)) 99 | loc_train, vel_train, edges_train, charges_train = generate_dataset( 100 | args.num_train, args.length, args.sample_freq 101 | ) 102 | 103 | print("Generating {} validation simulations".format(args.num_valid)) 104 | loc_valid, vel_valid, edges_valid, charges_valid = generate_dataset( 105 | args.num_valid, args.length, args.sample_freq 106 | ) 107 | 108 | print("Generating {} test simulations".format(args.num_test)) 109 | loc_test, vel_test, edges_test, charges_test = generate_dataset( 110 | args.num_test, args.length_test, args.sample_freq 111 | ) 112 | 113 | np.save("loc_train" + suffix + ".npy", loc_train) 114 | np.save("vel_train" + suffix + ".npy", vel_train) 115 | np.save("edges_train" + suffix + ".npy", edges_train) 116 | np.save("q_train" + suffix + ".npy", charges_train) 117 | 118 | np.save("loc_valid" + suffix + ".npy", loc_valid) 119 | np.save("vel_valid" + suffix + ".npy", vel_valid) 120 | np.save("edges_valid" + suffix + ".npy", edges_valid) 121 | 
np.save("q_valid" + suffix + ".npy", charges_valid) 122 | 123 | np.save("loc_test" + suffix + ".npy", loc_test) 124 | np.save("vel_test" + suffix + ".npy", vel_test) 125 | np.save("edges_test" + suffix + ".npy", edges_test) 126 | np.save("q_test" + suffix + ".npy", charges_test) 127 | -------------------------------------------------------------------------------- /experiments/qm9/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Tuple 2 | 3 | import e3nn_jax as e3nn 4 | import jax.numpy as jnp 5 | import jraph 6 | from torch_geometric.data import Data 7 | from torch_geometric.loader import DataLoader 8 | 9 | from segnn_jax import SteerableGraphsTuple 10 | 11 | from .dataset import QM9 12 | 13 | 14 | def QM9GraphTransform( 15 | args, 16 | max_batch_nodes: int, 17 | max_batch_edges: int, 18 | train_trn: Callable, 19 | ) -> Callable: 20 | """ 21 | Build a function that converts torch DataBatch into SteerableGraphsTuple. 22 | 23 | Mostly a quick fix out of lazyness. Rewriting QM9 in jax is not trivial. 24 | """ 25 | attribute_irreps = e3nn.Irreps.spherical_harmonics(args.lmax_attributes) 26 | 27 | def _to_steerable_graph( 28 | data: Data, training: bool = True 29 | ) -> Tuple[SteerableGraphsTuple, jnp.array]: 30 | ptr = jnp.array(data.ptr) 31 | senders = jnp.array(data.edge_index[0]) 32 | receivers = jnp.array(data.edge_index[1]) 33 | graph = jraph.GraphsTuple( 34 | nodes=e3nn.IrrepsArray(args.node_irreps, jnp.array(data.x)), 35 | edges=None, 36 | senders=senders, 37 | receivers=receivers, 38 | n_node=jnp.diff(ptr), 39 | n_edge=jnp.diff(jnp.sum(senders[:, jnp.newaxis] < ptr, axis=0)), 40 | globals=None, 41 | ) 42 | # pad for jax static shapes 43 | node_attr_pad = ((0, max_batch_nodes - jnp.sum(graph.n_node) + 1), (0, 0)) 44 | edge_attr_pad = ((0, max_batch_edges - jnp.sum(graph.n_edge) + 1), (0, 0)) 45 | graph = jraph.pad_with_graphs( 46 | graph, 47 | n_node=max_batch_nodes + 1, 48 | n_edge=max_batch_edges + 1, 49 | n_graph=graph.n_node.shape[0] + 1, 50 | ) 51 | 52 | node_attributes = e3nn.IrrepsArray( 53 | attribute_irreps, jnp.pad(jnp.array(data.node_attr), node_attr_pad) 54 | ) 55 | # scalar attribute to 1 by default 56 | node_attributes = e3nn.IrrepsArray( 57 | node_attributes.irreps, node_attributes.array.at[:, 0].set(1.0) 58 | ) 59 | 60 | additional_message_features = e3nn.IrrepsArray( 61 | args.additional_message_irreps, 62 | jnp.pad(jnp.array(data.additional_message_features), edge_attr_pad), 63 | ) 64 | edge_attributes = e3nn.IrrepsArray( 65 | attribute_irreps, jnp.pad(jnp.array(data.edge_attr), edge_attr_pad) 66 | ) 67 | 68 | st_graph = SteerableGraphsTuple( 69 | graph=graph, 70 | node_attributes=node_attributes, 71 | edge_attributes=edge_attributes, 72 | additional_message_features=additional_message_features, 73 | ) 74 | 75 | # pad targets 76 | target = jnp.array(data.y) 77 | if args.task == "node": 78 | target = jnp.pad(target, [(0, max_batch_nodes - target.shape[0] - 1)]) 79 | if args.task == "graph": 80 | target = jnp.append(target, 0) 81 | 82 | # normalize targets 83 | if training and train_trn is not None: 84 | target = train_trn(target) 85 | 86 | return st_graph, target 87 | 88 | return _to_steerable_graph 89 | 90 | 91 | def setup_qm9_data( 92 | args, 93 | ) -> Tuple[DataLoader, DataLoader, DataLoader, Callable, Callable]: 94 | dataset_train = QM9( 95 | "datasets", 96 | args.target, 97 | args.radius, 98 | partition="train", 99 | lmax_attr=args.lmax_attributes, 100 | feature_type=args.feature_type, 101 
| ) 102 | dataset_val = QM9( 103 | "datasets", 104 | args.target, 105 | args.radius, 106 | partition="valid", 107 | lmax_attr=args.lmax_attributes, 108 | feature_type=args.feature_type, 109 | ) 110 | dataset_test = QM9( 111 | "datasets", 112 | args.target, 113 | args.radius, 114 | partition="test", 115 | lmax_attr=args.lmax_attributes, 116 | feature_type=args.feature_type, 117 | ) 118 | 119 | # 0.8 (un)safety factor for rejitting 120 | max_batch_nodes = int(0.8 * sum(dataset_test.top_n_nodes(args.batch_size))) 121 | max_batch_edges = int(0.8 * sum(dataset_test.top_n_edges(args.batch_size))) 122 | 123 | target_mean, target_mad = dataset_train.calc_stats() 124 | 125 | def remove_offsets(t): 126 | return (t - target_mean) / target_mad 127 | 128 | # not great and very slow due to huge padding 129 | loader_train = DataLoader( 130 | dataset_train, 131 | batch_size=args.batch_size, 132 | shuffle=True, 133 | drop_last=True, 134 | ) 135 | loader_val = DataLoader( 136 | dataset_val, 137 | batch_size=args.batch_size, 138 | shuffle=False, 139 | drop_last=True, 140 | ) 141 | loader_test = DataLoader( 142 | dataset_test, 143 | batch_size=args.batch_size, 144 | shuffle=False, 145 | drop_last=True, 146 | ) 147 | 148 | to_graphs_tuple = QM9GraphTransform( 149 | args, 150 | max_batch_nodes=max_batch_nodes, 151 | max_batch_edges=max_batch_edges, 152 | train_trn=remove_offsets, 153 | ) 154 | 155 | def add_offsets(p): 156 | return p * target_mad + target_mean 157 | 158 | return loader_train, loader_val, loader_test, to_graphs_tuple, add_offsets 159 | -------------------------------------------------------------------------------- /experiments/train.py: -------------------------------------------------------------------------------- 1 | import time 2 | from functools import partial 3 | from typing import Callable, Tuple 4 | 5 | import haiku as hk 6 | import jax 7 | import jax.numpy as jnp 8 | import jraph 9 | import optax 10 | from jax import jit 11 | 12 | from segnn_jax import SteerableGraphsTuple 13 | 14 | 15 | @partial(jit, static_argnames=["model_fn", "criterion", "do_mask", "eval_trn"]) 16 | def loss_fn_wrapper( 17 | params: hk.Params, 18 | state: hk.State, 19 | st_graph: SteerableGraphsTuple, 20 | target: jnp.ndarray, 21 | model_fn: Callable, 22 | criterion: Callable, 23 | do_mask: bool = True, 24 | eval_trn: Callable = None, 25 | ) -> Tuple[float, hk.State]: 26 | pred, state = model_fn(params, state, st_graph) 27 | if eval_trn is not None: 28 | pred = eval_trn(pred) 29 | 30 | if do_mask: 31 | if target.shape == st_graph.graph.nodes.shape: 32 | mask = jraph.get_node_padding_mask(st_graph.graph) 33 | else: 34 | mask = jraph.get_graph_padding_mask(st_graph.graph) 35 | # broadcast mask for vector targets 36 | if len(pred.shape) == 2: 37 | mask = mask[:, jnp.newaxis] 38 | else: 39 | mask = jnp.ones_like(target) 40 | 41 | target = target * mask 42 | pred = pred * mask 43 | 44 | assert target.shape == pred.shape 45 | return jnp.sum(criterion(pred, target)) / jnp.count_nonzero(mask), state 46 | 47 | 48 | @partial(jit, static_argnames=["loss_fn", "opt_update"]) 49 | def update( 50 | params: hk.Params, 51 | state: hk.State, 52 | graph: SteerableGraphsTuple, 53 | target: jnp.ndarray, 54 | opt_state: optax.OptState, 55 | loss_fn: Callable, 56 | opt_update: Callable, 57 | ) -> Tuple[float, hk.Params, hk.State, optax.OptState]: 58 | (loss, state), grads = jax.value_and_grad(loss_fn, has_aux=True)( 59 | params, state, graph, target 60 | ) 61 | updates, opt_state = opt_update(grads, opt_state, params) 62 | return loss, 
optax.apply_updates(params, updates), state, opt_state 63 | 64 | 65 | def evaluate( 66 | loader, 67 | params: hk.Params, 68 | state: hk.State, 69 | loss_fn: Callable, 70 | graph_transform: Callable, 71 | ) -> Tuple[float, float]: 72 | eval_loss = 0.0 73 | eval_times = 0.0 74 | for data in loader: 75 | graph, target = graph_transform(data, training=False) 76 | eval_start = time.perf_counter_ns() 77 | loss, _ = jax.lax.stop_gradient(loss_fn(params, state, graph, target)) 78 | eval_loss += jax.block_until_ready(loss) 79 | eval_times += (time.perf_counter_ns() - eval_start) / 1e6 80 | return eval_times / len(loader), eval_loss / len(loader) 81 | 82 | 83 | def train( 84 | key, 85 | segnn, 86 | loader_train, 87 | loader_val, 88 | loader_test, 89 | loss_fn, 90 | eval_loss_fn, 91 | graph_transform, 92 | args, 93 | ): 94 | init_graph, _ = graph_transform(next(iter(loader_train))) 95 | params, segnn_state = segnn.init(key, init_graph) 96 | 97 | print( 98 | f"Starting {args.epochs} epochs " 99 | f"with {hk.data_structures.tree_size(params)} parameters.\n" 100 | "Jitting..." 101 | ) 102 | 103 | total_steps = args.epochs * len(loader_train) 104 | 105 | # set up learning rate and optimizer 106 | learning_rate = args.lr 107 | if args.lr_scheduling: 108 | learning_rate = optax.piecewise_constant_schedule( 109 | learning_rate, 110 | boundaries_and_scales={ 111 | int(total_steps * 0.7): 0.1, 112 | int(total_steps * 0.9): 0.1, 113 | }, 114 | ) 115 | opt_init, opt_update = optax.adamw( 116 | learning_rate=learning_rate, weight_decay=args.weight_decay 117 | ) 118 | 119 | model_fn = jit(segnn.apply) 120 | 121 | loss_fn = partial(loss_fn, model_fn=model_fn) 122 | eval_loss_fn = partial(eval_loss_fn, model_fn=model_fn) 123 | update_fn = partial(update, loss_fn=loss_fn, opt_update=opt_update) 124 | eval_fn = partial(evaluate, loss_fn=eval_loss_fn, graph_transform=graph_transform) 125 | 126 | opt_state = opt_init(params) 127 | avg_time = [] 128 | best_val = 1e10 129 | 130 | for e in range(args.epochs): 131 | train_loss = 0.0 132 | epoch_start = time.perf_counter_ns() 133 | for data in loader_train: 134 | graph, target = graph_transform(data) 135 | loss, params, segnn_state, opt_state = update_fn( 136 | params=params, 137 | state=segnn_state, 138 | graph=graph, 139 | target=target, 140 | opt_state=opt_state, 141 | ) 142 | train_loss += jax.block_until_ready(loss) 143 | train_loss /= len(loader_train) 144 | epoch_time = (time.perf_counter_ns() - epoch_start) / 1e9 145 | 146 | print( 147 | f"[Epoch {e+1:>4}] train loss {train_loss:.6f}, epoch {epoch_time:.2f}s", 148 | end="", 149 | ) 150 | if e % args.val_freq == 0: 151 | eval_time, val_loss = eval_fn(loader_val, params, segnn_state) 152 | avg_time.append(eval_time) 153 | tag = "" 154 | if val_loss < best_val: 155 | best_val = val_loss 156 | tag = " (best)" 157 | _, test_loss_ckp = eval_fn(loader_test, params, segnn_state) 158 | print(f" - val loss {val_loss:.6f}{tag}, infer {eval_time:.2f}ms", end="") 159 | 160 | print() 161 | 162 | test_loss = 0 163 | _, test_loss = eval_fn(loader_test, params, segnn_state) 164 | # ignore compilation time 165 | avg_time = avg_time[1:] if len(avg_time) > 1 else avg_time 166 | avg_time = sum(avg_time) / len(avg_time) 167 | print( 168 | "Training done.\n" 169 | f"Final test loss {test_loss:.6f} - checkpoint test loss {test_loss_ckp:.6f}.\n" 170 | f"Average (model) eval time {avg_time:.2f}ms" 171 | ) 172 | -------------------------------------------------------------------------------- /experiments/nbody/datasets.py: 
-------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import pathlib 3 | from abc import ABC, abstractmethod 4 | from typing import Sequence, Tuple, Union 5 | 6 | import numpy as np 7 | 8 | DATA_DIR = "data" 9 | 10 | 11 | class BaseDataset(ABC): 12 | """Abstract n-body dataset class.""" 13 | 14 | def __init__( 15 | self, 16 | data_type, 17 | partition="train", 18 | max_samples=1e8, 19 | dataset_name="small", 20 | n_bodies=5, 21 | normalize=False, 22 | ): 23 | self.partition = partition 24 | if self.partition == "val": 25 | self.suffix = "valid" 26 | else: 27 | self.suffix = self.partition 28 | self.dataset_name = dataset_name 29 | self.suffix += f"_{data_type}{n_bodies}_initvel1" 30 | self.data_type = data_type 31 | self.max_samples = int(max_samples) 32 | self.normalize = normalize 33 | 34 | self.data = None 35 | 36 | def get_n_nodes(self): 37 | return self.data[0].shape[2] 38 | 39 | def _get_partition_frames(self) -> Tuple[int, int]: 40 | if self.dataset_name == "default": 41 | frame_0, frame_target = 6, 8 42 | elif self.dataset_name == "small": 43 | frame_0, frame_target = 30, 40 44 | elif self.dataset_name == "small_out_dist": 45 | frame_0, frame_target = 20, 30 46 | else: 47 | raise Exception("Wrong dataset partition %s" % self.dataset_name) 48 | 49 | return frame_0, frame_target 50 | 51 | def __len__(self) -> int: 52 | return len(self.data[0]) 53 | 54 | def _load(self) -> Tuple[np.ndarray, ...]: 55 | filepath = pathlib.Path(__file__).parent.resolve() 56 | 57 | loc = np.load(osp.join(filepath, DATA_DIR, "loc_" + self.suffix + ".npy")) 58 | vel = np.load(osp.join(filepath, DATA_DIR, "vel_" + self.suffix + ".npy")) 59 | edges = np.load(osp.join(filepath, DATA_DIR, "edges_" + self.suffix + ".npy")) 60 | q = np.load(osp.join(filepath, DATA_DIR, "q_" + self.suffix + ".npy")) 61 | 62 | return loc, vel, edges, q 63 | 64 | def _normalize(self, x: np.ndarray) -> np.ndarray: 65 | std = x.std(axis=0) 66 | x = x - x.mean(axis=0) 67 | return np.divide(x, std, out=x, where=std != 0) 68 | 69 | @abstractmethod 70 | def load(self): 71 | raise NotImplementedError 72 | 73 | @abstractmethod 74 | def preprocess(self, *args) -> Tuple[np.ndarray, ...]: 75 | raise NotImplementedError 76 | 77 | 78 | class ChargedDataset(BaseDataset): 79 | """N-body charged dataset class.""" 80 | 81 | def __init__( 82 | self, 83 | partition="train", 84 | max_samples=1e8, 85 | dataset_name="small", 86 | n_bodies=5, 87 | normalize=False, 88 | ): 89 | super().__init__( 90 | "charged", partition, max_samples, dataset_name, n_bodies, normalize 91 | ) 92 | self.data, self.edges = self.load() 93 | 94 | def preprocess(self, *args) -> Tuple[np.ndarray, ...]: 95 | # swap n_nodes - n_features dimensions 96 | loc, vel, edges, charges = args 97 | loc, vel = np.transpose(loc, (0, 1, 3, 2)), np.transpose(vel, (0, 1, 3, 2)) 98 | n_nodes = loc.shape[2] 99 | loc = loc[0 : self.max_samples, :, :, :] # limit number of samples 100 | vel = vel[0 : self.max_samples, :, :, :] # speed when starting the trajectory 101 | charges = charges[0 : self.max_samples] 102 | edge_attr = [] 103 | 104 | # Initialize edges and edge_attributes 105 | rows, cols = [], [] 106 | for i in range(n_nodes): 107 | for j in range(n_nodes): 108 | if i != j: 109 | edge_attr.append(edges[:, i, j]) 110 | rows.append(i) 111 | cols.append(j) 112 | edges = [rows, cols] 113 | # swap n_nodes - batch_size and add nf dimension 114 | edge_attr = np.array(edge_attr).T 115 | edge_attr = np.expand_dims(edge_attr, 2) 116 | 117 | if 
self.normalize: 118 | loc = self._normalize(loc) 119 | vel = self._normalize(vel) 120 | charges = self._normalize(charges) 121 | 122 | return loc, vel, edge_attr, edges, charges 123 | 124 | def load(self): 125 | loc, vel, edges, q = self._load() 126 | 127 | loc, vel, edge_attr, edges, charges = self.preprocess(loc, vel, edges, q) 128 | return (loc, vel, edge_attr, charges), edges 129 | 130 | def __getitem__(self, i: Union[Sequence, int]) -> Tuple[np.ndarray, ...]: 131 | frame_0, frame_target = self._get_partition_frames() 132 | 133 | loc, vel, edge_attr, charges = self.data 134 | 135 | loc, vel, edge_attr, charges, target_loc = ( 136 | loc[i, frame_0], 137 | vel[i, frame_0], 138 | edge_attr[i], 139 | charges[i], 140 | loc[i, frame_target], 141 | ) 142 | 143 | if not isinstance(i, int): 144 | # flatten batch and nodes dimensions 145 | loc = loc.reshape(-1, *loc.shape[2:]) 146 | vel = vel.reshape(-1, *vel.shape[2:]) 147 | edge_attr = edge_attr.reshape(-1, *edge_attr.shape[2:]) 148 | charges = charges.reshape(-1, *charges.shape[2:]) 149 | target_loc = target_loc.reshape(-1, *target_loc.shape[2:]) 150 | 151 | return loc, vel, edge_attr, charges, target_loc 152 | 153 | 154 | class GravityDataset(BaseDataset): 155 | """N-body gravity dataset class.""" 156 | 157 | def __init__( 158 | self, 159 | partition="train", 160 | max_samples=1e8, 161 | dataset_name="small", 162 | n_bodies=100, 163 | neighbours=6, 164 | target="pos", 165 | normalize=False, 166 | ): 167 | super().__init__( 168 | "gravity", partition, max_samples, dataset_name, n_bodies, normalize 169 | ) 170 | assert target in ["pos", "force"] 171 | self.neighbours = int(neighbours) 172 | self.target = target 173 | self.data = self.load() 174 | 175 | def preprocess(self, *args) -> Tuple[np.ndarray, ...]: 176 | loc, vel, force, mass = args 177 | # NOTE this was in the original paper but does not look right 178 | # loc = np.transpose(loc, (0, 1, 3, 2)) 179 | # vel = np.transpose(vel, (0, 1, 3, 2)) 180 | # force = np.transpose(force, (0, 1, 3, 2)) 181 | loc = loc[0 : self.max_samples, :, :, :] # limit number of samples 182 | vel = vel[0 : self.max_samples, :, :, :] # speed when starting the trajectory 183 | force = force[0 : self.max_samples, :, :, :] 184 | 185 | if self.normalize: 186 | loc = self._normalize(loc) 187 | vel = self._normalize(vel) 188 | force = self._normalize(force) 189 | mass = self._normalize(mass) 190 | 191 | return loc, vel, force, mass 192 | 193 | def load(self): 194 | loc, vel, edges, q = self._load() 195 | 196 | self.num_nodes = loc.shape[-1] 197 | 198 | loc, vel, force, mass = self.preprocess(loc, vel, edges, q) 199 | return (loc, vel, force, mass) 200 | 201 | def __getitem__(self, i: Union[Sequence, int]) -> Tuple[np.ndarray, ...]: 202 | frame_0, frame_target = self._get_partition_frames() 203 | 204 | loc, vel, force, mass = self.data 205 | if self.target == "pos": 206 | y = loc[i, frame_target] 207 | elif self.target == "force": 208 | y = force[i, frame_target] 209 | loc, vel, force, mass = (loc[i, frame_0], vel[i, frame_0], force[i], mass[i]) 210 | 211 | if not isinstance(i, int): 212 | # flatten batch and nodes dimensions 213 | loc = loc.reshape(-1, *loc.shape[2:]) 214 | vel = vel.reshape(-1, *vel.shape[2:]) 215 | force = force.reshape(-1, *force.shape[2:]) 216 | mass = mass.reshape(-1, *mass.shape[2:]) 217 | y = y.reshape(-1, *y.shape[2:]) 218 | 219 | return loc, vel, force, mass, y 220 | -------------------------------------------------------------------------------- /validate.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | from functools import partial 4 | 5 | import e3nn_jax as e3nn 6 | import haiku as hk 7 | import jax 8 | import jax.numpy as jnp 9 | import wandb 10 | 11 | from experiments import setup_data, train 12 | from segnn_jax import SEGNN, weight_balanced_irreps 13 | 14 | key = jax.random.PRNGKey(1337) 15 | 16 | 17 | if __name__ == "__main__": 18 | parser = argparse.ArgumentParser() 19 | # Run parameters 20 | parser.add_argument("--epochs", type=int, default=100, help="Number of epochs") 21 | parser.add_argument( 22 | "--batch-size", 23 | type=int, 24 | default=128, 25 | help="Batch size (number of graphs).", 26 | ) 27 | parser.add_argument("--lr", type=float, default=5e-4, help="Learning rate") 28 | parser.add_argument( 29 | "--lr-scheduling", 30 | action="store_true", 31 | help="Use learning rate scheduling", 32 | ) 33 | parser.add_argument( 34 | "--weight-decay", type=float, default=1e-12, help="Weight decay" 35 | ) 36 | parser.add_argument( 37 | "--dataset", 38 | type=str, 39 | choices=["qm9", "charged", "gravity"], 40 | help="Dataset name", 41 | ) 42 | parser.add_argument( 43 | "--max-samples", 44 | type=int, 45 | default=3000, 46 | help="Maximum number of samples in nbody dataset", 47 | ) 48 | parser.add_argument( 49 | "--val-freq", 50 | type=int, 51 | default=10, 52 | help="Evaluation frequency (number of epochs)", 53 | ) 54 | 55 | # nbody parameters 56 | parser.add_argument( 57 | "--target", 58 | type=str, 59 | default="pos", 60 | help="Target. e.g. pos, force (gravity), alpha (qm9)", 61 | ) 62 | parser.add_argument( 63 | "--neighbours", 64 | type=int, 65 | default=20, 66 | help="Number of connected nearest neighbours", 67 | ) 68 | parser.add_argument( 69 | "--n-bodies", 70 | type=int, 71 | default=5, 72 | help="Number of bodies in the dataset", 73 | ) 74 | parser.add_argument( 75 | "--dataset-name", 76 | type=str, 77 | default="small", 78 | choices=["small", "default", "small_out_dist"], 79 | help="Name of nbody data partition: default (200 steps), small (1000 steps)", 80 | ) 81 | 82 | # qm9 parameters 83 | parser.add_argument( 84 | "--radius", 85 | type=float, 86 | default=2.0, 87 | help="Radius (Angstrom) between which atoms to add links.", 88 | ) 89 | parser.add_argument( 90 | "--feature-type", 91 | type=str, 92 | default="one_hot", 93 | choices=["one_hot", "cormorant", "gilmer"], 94 | help="Type of input feature", 95 | ) 96 | 97 | # Model parameters 98 | parser.add_argument( 99 | "--units", type=int, default=64, help="Number of values in the hidden layers" 100 | ) 101 | parser.add_argument( 102 | "--lmax-hidden", 103 | type=int, 104 | default=1, 105 | help="Max degree of hidden representations.", 106 | ) 107 | parser.add_argument( 108 | "--lmax-attributes", 109 | type=int, 110 | default=1, 111 | help="Max degree of geometric attribute embedding", 112 | ) 113 | parser.add_argument( 114 | "--layers", type=int, default=7, help="Number of message passing layers" 115 | ) 116 | parser.add_argument( 117 | "--blocks", type=int, default=2, help="Number of layers in steerable MLPs." 
118 |     )
119 |     parser.add_argument(
120 |         "--norm",
121 |         type=str,
122 |         default="none",
123 |         choices=["instance", "batch", "none"],
124 |         help="Normalisation type",
125 |     )
126 |     parser.add_argument(
127 |         "--double-precision",
128 |         action="store_true",
129 |         help="Use double precision in model",
130 |     )
131 |     parser.add_argument(
132 |         "--scn",
133 |         action="store_true",
134 |         help="Train SEGNN with the eSCN optimization",
135 |     )
136 | 
137 |     # wandb parameters
138 |     parser.add_argument(
139 |         "--wandb",
140 |         action="store_true",
141 |         help="Activate weights and biases logging",
142 |     )
143 |     parser.add_argument(
144 |         "--wandb-project",
145 |         type=str,
146 |         default="segnn",
147 |         help="Weights and biases project",
148 |     )
149 |     parser.add_argument(
150 |         "--wandb-entity",
151 |         type=str,
152 |         default="",
153 |         help="Weights and biases entity",
154 |     )
155 | 
156 |     args = parser.parse_args()
157 | 
158 |     # if specified, set jax to double precision
159 |     jax.config.update("jax_enable_x64", args.double_precision)
160 | 
161 |     # connect to wandb
162 |     if args.wandb:
163 |         wandb_name = "_".join(
164 |             [
165 |                 args.wandb_project,
166 |                 args.dataset,
167 |                 args.target,
168 |                 str(int(time.time())),
169 |             ]
170 |         )
171 |         wandb.init(
172 |             project=args.wandb_project,
173 |             name=wandb_name,
174 |             config=args,
175 |             entity=args.wandb_entity,
176 |         )
177 | 
178 |     # feature representations
179 |     if args.dataset == "qm9":
180 |         args.task = "graph"
181 |         if args.feature_type == "one_hot":
182 |             args.node_irreps = e3nn.Irreps("5x0e")
183 |         elif args.feature_type == "cormorant":
184 |             args.node_irreps = e3nn.Irreps("15x0e")
185 |         elif args.feature_type == "gilmer":
186 |             args.node_irreps = e3nn.Irreps("11x0e")
187 |         args.output_irreps = e3nn.Irreps("1x0e")
188 |         args.additional_message_irreps = e3nn.Irreps("1x0e")
189 |         assert not args.scn, "eSCN not implemented for qm9"
190 |     elif args.dataset in ["charged", "gravity"]:
191 |         args.task = "node"
192 |         args.node_irreps = e3nn.Irreps("2x1o + 1x0e")
193 |         args.output_irreps = e3nn.Irreps("1x1o")
194 |         args.additional_message_irreps = e3nn.Irreps("2x0e")
195 | 
196 |     # Create hidden irreps
197 |     if not args.scn:
198 |         attr_irreps = e3nn.Irreps.spherical_harmonics(args.lmax_attributes)
199 |     else:
200 |         attr_irreps = e3nn.Irrep(f"{args.lmax_attributes}y")
201 | 
202 |     hidden_irreps = weight_balanced_irreps(
203 |         scalar_units=args.units,
204 |         irreps_right=attr_irreps,
205 |         use_sh=(not args.scn),
206 |         lmax=args.lmax_hidden,
207 |     )
208 | 
209 |     args.o3_layer = "scn" if args.scn else "tpl"
210 |     del args.scn
211 | 
212 |     # build model
213 |     def segnn(x):
214 |         return SEGNN(
215 |             hidden_irreps=hidden_irreps,
216 |             output_irreps=args.output_irreps,
217 |             num_layers=args.layers,
218 |             task=args.task,
219 |             pool="avg",
220 |             blocks_per_layer=args.blocks,
221 |             norm=args.norm,
222 |             o3_layer=args.o3_layer,
223 |         )(x)
224 | 
225 |     segnn = hk.without_apply_rng(hk.transform_with_state(segnn))
226 | 
227 |     loader_train, loader_val, loader_test, graph_transform, eval_trn = setup_data(args)
228 | 
229 |     if args.dataset == "qm9":
230 |         from experiments.train import loss_fn_wrapper
231 | 
232 |         def _mae(p, t):
233 |             return jnp.abs(p - t)
234 | 
235 |         train_loss = partial(loss_fn_wrapper, criterion=_mae)
236 |         eval_loss = partial(loss_fn_wrapper, criterion=_mae, eval_trn=eval_trn)
237 |     if args.dataset in ["charged", "gravity"]:
238 |         from experiments.train import loss_fn_wrapper
239 | 
240 |         def _mse(p, t):
241 |             return jnp.power(p - t, 2)
242 | 
243 |         train_loss =
partial(loss_fn_wrapper, criterion=_mse, do_mask=False) 244 | eval_loss = partial(loss_fn_wrapper, criterion=_mse, do_mask=False) 245 | 246 | train( 247 | key, 248 | segnn, 249 | loader_train, 250 | loader_val, 251 | loader_test, 252 | train_loss, 253 | eval_loss, 254 | graph_transform, 255 | args, 256 | ) 257 | -------------------------------------------------------------------------------- /experiments/nbody/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, List, Optional, Tuple 2 | 3 | import e3nn_jax as e3nn 4 | import jax 5 | import jax.numpy as jnp 6 | import jax.tree_util as tree 7 | import numpy as np 8 | import torch 9 | from jraph import GraphsTuple, segment_mean 10 | from torch.utils.data import DataLoader 11 | from torch_geometric.nn import knn_graph 12 | 13 | from segnn_jax import SteerableGraphsTuple 14 | 15 | from .datasets import ChargedDataset, GravityDataset 16 | 17 | 18 | def O3Transform( 19 | node_features_irreps: e3nn.Irreps, 20 | edge_features_irreps: e3nn.Irreps, 21 | lmax_attributes: int, 22 | scn: bool = False, 23 | ) -> Callable: 24 | """ 25 | Build a transformation function that includes (nbody) O3 attributes to a graph. 26 | """ 27 | if not scn: 28 | attribute_irreps = e3nn.Irreps.spherical_harmonics(lmax_attributes) 29 | else: 30 | attribute_irreps = e3nn.Irrep("1o") 31 | 32 | @jax.jit 33 | def _o3_transform( 34 | st_graph: SteerableGraphsTuple, 35 | loc: jnp.ndarray, 36 | vel: jnp.ndarray, 37 | charges: jnp.ndarray, 38 | ) -> SteerableGraphsTuple: 39 | graph = st_graph.graph 40 | prod_charges = charges[graph.senders] * charges[graph.receivers] 41 | rel_pos = loc[graph.senders] - loc[graph.receivers] 42 | edge_dist = jnp.sqrt(jnp.power(rel_pos, 2).sum(1, keepdims=True)) 43 | 44 | msg_features = e3nn.IrrepsArray( 45 | edge_features_irreps, 46 | jnp.concatenate((edge_dist, prod_charges), axis=-1), 47 | ) 48 | 49 | vel_abs = jnp.sqrt(jnp.power(vel, 2).sum(1, keepdims=True)) 50 | mean_loc = loc.mean(1, keepdims=True) 51 | 52 | nodes = e3nn.IrrepsArray( 53 | node_features_irreps, 54 | jnp.concatenate((loc - mean_loc, vel, vel_abs), axis=-1), 55 | ) 56 | 57 | if not scn: 58 | edge_attributes = e3nn.spherical_harmonics( 59 | attribute_irreps, rel_pos, normalize=True, normalization="integral" 60 | ) 61 | vel_embedding = e3nn.spherical_harmonics( 62 | attribute_irreps, vel, normalize=True, normalization="integral" 63 | ) 64 | else: 65 | edge_attributes = e3nn.IrrepsArray(attribute_irreps, rel_pos) 66 | vel_embedding = e3nn.IrrepsArray(attribute_irreps, vel) 67 | 68 | # scatter edge attributes 69 | sum_n_node = tree.tree_leaves(nodes)[0].shape[0] 70 | node_attributes = ( 71 | tree.tree_map( 72 | lambda e: segment_mean(e, graph.receivers, sum_n_node), 73 | edge_attributes, 74 | ) 75 | + vel_embedding 76 | ) 77 | if not scn: 78 | # scalar attribute to 1 by default 79 | node_attributes = e3nn.IrrepsArray( 80 | node_attributes.irreps, node_attributes.array.at[:, 0].set(1.0) 81 | ) 82 | 83 | return SteerableGraphsTuple( 84 | graph=GraphsTuple( 85 | nodes=nodes, 86 | edges=None, 87 | senders=graph.senders, 88 | receivers=graph.receivers, 89 | n_node=graph.n_node, 90 | n_edge=graph.n_edge, 91 | globals=graph.globals, 92 | ), 93 | node_attributes=node_attributes, 94 | edge_attributes=edge_attributes, 95 | additional_message_features=msg_features, 96 | ) 97 | 98 | return _o3_transform 99 | 100 | 101 | def NbodyGraphTransform( 102 | transform: Callable, 103 | dataset_name: str, 104 | n_nodes: int, 105 | 
batch_size: int, 106 | neighbours: Optional[int] = 6, 107 | relative_target: bool = False, 108 | ) -> Callable: 109 | """ 110 | Build a function that converts torch DataBatch into SteerableGraphsTuple. 111 | """ 112 | 113 | if dataset_name == "charged": 114 | # charged system is a connected graph 115 | full_edge_indices = jnp.array( 116 | [ 117 | (i + n_nodes * b, j + n_nodes * b) 118 | for b in range(batch_size) 119 | for i in range(n_nodes) 120 | for j in range(n_nodes) 121 | if i != j 122 | ] 123 | ).T 124 | 125 | def _to_steerable_graph( 126 | data: List, training: bool = True 127 | ) -> Tuple[SteerableGraphsTuple, jnp.ndarray]: 128 | _ = training 129 | loc, vel, _, q, targets = data 130 | 131 | cur_batch = int(loc.shape[0] / n_nodes) 132 | 133 | if dataset_name == "charged": 134 | edge_indices = full_edge_indices[:, : n_nodes * (n_nodes - 1) * cur_batch] 135 | senders, receivers = edge_indices[0], edge_indices[1] 136 | if dataset_name == "gravity": 137 | batch = torch.arange(0, cur_batch) 138 | batch = batch.repeat_interleave(n_nodes).long() 139 | edge_indices = knn_graph(torch.from_numpy(np.array(loc)), neighbours, batch) 140 | # switched by default 141 | senders, receivers = jnp.array(edge_indices[0]), jnp.array(edge_indices[1]) 142 | 143 | st_graph = SteerableGraphsTuple( 144 | graph=GraphsTuple( 145 | nodes=None, 146 | edges=None, 147 | senders=senders, 148 | receivers=receivers, 149 | n_node=jnp.array([n_nodes] * cur_batch), 150 | n_edge=jnp.array([len(senders) // cur_batch] * cur_batch), 151 | globals=None, 152 | ) 153 | ) 154 | st_graph = transform(st_graph, loc, vel, q) 155 | 156 | # relative shift as target 157 | if relative_target: 158 | targets = targets - loc 159 | 160 | return st_graph, targets 161 | 162 | return _to_steerable_graph 163 | 164 | 165 | def numpy_collate(batch): 166 | if isinstance(batch[0], np.ndarray): 167 | return jnp.vstack(batch) 168 | elif isinstance(batch[0], (tuple, list)): 169 | transposed = zip(*batch) 170 | return [numpy_collate(samples) for samples in transposed] 171 | else: 172 | return jnp.array(batch) 173 | 174 | 175 | def setup_nbody_data( 176 | args, 177 | ) -> Tuple[DataLoader, DataLoader, DataLoader, Callable, Callable]: 178 | if args.dataset == "charged": 179 | dataset_train = ChargedDataset( 180 | partition="train", 181 | dataset_name=args.dataset_name, 182 | max_samples=args.max_samples, 183 | n_bodies=args.n_bodies, 184 | ) 185 | dataset_val = ChargedDataset( 186 | partition="val", 187 | dataset_name=args.dataset_name, 188 | n_bodies=args.n_bodies, 189 | ) 190 | dataset_test = ChargedDataset( 191 | partition="test", 192 | dataset_name=args.dataset_name, 193 | n_bodies=args.n_bodies, 194 | ) 195 | 196 | if args.dataset == "gravity": 197 | dataset_train = GravityDataset( 198 | partition="train", 199 | dataset_name=args.dataset_name, 200 | max_samples=args.max_samples, 201 | neighbours=args.neighbours, 202 | target=args.target, 203 | n_bodies=args.n_bodies, 204 | ) 205 | dataset_val = GravityDataset( 206 | partition="val", 207 | dataset_name=args.dataset_name, 208 | neighbours=args.neighbours, 209 | target=args.target, 210 | n_bodies=args.n_bodies, 211 | ) 212 | dataset_test = GravityDataset( 213 | partition="test", 214 | dataset_name=args.dataset_name, 215 | neighbours=args.neighbours, 216 | target=args.target, 217 | n_bodies=args.n_bodies, 218 | ) 219 | 220 | o3_transform = O3Transform( 221 | args.node_irreps, 222 | args.additional_message_irreps, 223 | args.lmax_attributes, 224 | scn=args.o3_layer == "scn", 225 | ) 226 | 
) 226 | graph_transform = NbodyGraphTransform( 227 | transform=o3_transform, 228 | n_nodes=args.n_bodies, 229 | batch_size=args.batch_size, 230 | neighbours=args.neighbours, 231 | relative_target=(args.target == "pos"), 232 | dataset_name=args.dataset, 233 | ) 234 | 235 | loader_train = DataLoader( 236 | dataset_train, 237 | batch_size=args.batch_size, 238 | shuffle=True, 239 | drop_last=True, 240 | collate_fn=numpy_collate, 241 | ) 242 | loader_val = DataLoader( 243 | dataset_val, 244 | batch_size=args.batch_size, 245 | shuffle=False, 246 | drop_last=False, 247 | collate_fn=numpy_collate, 248 | ) 249 | loader_test = DataLoader( 250 | dataset_test, 251 | batch_size=args.batch_size, 252 | shuffle=False, 253 | drop_last=False, 254 | collate_fn=numpy_collate, 255 | ) 256 | 257 | return loader_train, loader_val, loader_test, graph_transform, None 258 | -------------------------------------------------------------------------------- /experiments/nbody/data/synthetic_sim.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class ChargedParticlesSim(object): 5 | def __init__( 6 | self, 7 | n_balls=5, 8 | box_size=5.0, 9 | loc_std=1.0, 10 | vel_norm=0.5, 11 | interaction_strength=1.0, 12 | noise_var=0.0, 13 | ): 14 | self.n_balls = n_balls 15 | self.box_size = box_size 16 | # scale loc_std with the number of bodies (5 is the reference system size) 17 | self.loc_std = loc_std * (float(n_balls) / 5.0) ** (1 / 3) 18 | print(f"loc_std: {self.loc_std}") 19 | self.vel_norm = vel_norm 20 | self.interaction_strength = interaction_strength 21 | self.noise_var = noise_var 22 | 23 | self._charge_types = np.array([-1.0, 0.0, 1.0]) 24 | self._delta_T = 0.001 25 | self._max_F = 0.1 / self._delta_T 26 | self.dim = 3 27 | 28 | def _l2(self, A, B): 29 | """ 30 | Input: A is a Nxd matrix 31 | B is a Mxd matrix 32 | Output: dist is a NxM matrix where dist[i,j] is the square norm 33 | between A[i,:] and B[j,:] 34 | i.e. 
dist[i,j] = ||A[i,:]-B[j,:]||^2 35 | """ 36 | A_norm = (A**2).sum(axis=1).reshape(A.shape[0], 1) 37 | B_norm = (B**2).sum(axis=1).reshape(1, B.shape[0]) 38 | dist = A_norm + B_norm - 2 * A.dot(B.transpose()) 39 | return dist 40 | 41 | def _energy(self, loc, vel, edges): 42 | # disables division by zero warning, since I fix it with fill_diagonal 43 | with np.errstate(divide="ignore"): 44 | K = 0.5 * (vel**2).sum() 45 | U = 0 46 | for i in range(loc.shape[1]): 47 | for j in range(loc.shape[1]): 48 | if i != j: 49 | r = loc[:, i] - loc[:, j] 50 | dist = np.sqrt((r**2).sum()) 51 | U += 0.5 * self.interaction_strength * edges[i, j] / dist 52 | return U + K 53 | 54 | def _clamp(self, loc, vel): 55 | """ 56 | :param loc: dim x N location at one time stamp 57 | :param vel: dim x N velocity at one time stamp 58 | :return: location and velocity after hitting the walls and 59 | elastically colliding with them 60 | """ 61 | assert np.all(loc < self.box_size * 3) 62 | assert np.all(loc > -self.box_size * 3) 63 | 64 | over = loc > self.box_size 65 | loc[over] = 2 * self.box_size - loc[over] 66 | assert np.all(loc <= self.box_size) 67 | 68 | # assert(np.all(vel[over]>0)) 69 | vel[over] = -np.abs(vel[over]) 70 | 71 | under = loc < -self.box_size 72 | loc[under] = -2 * self.box_size - loc[under] 73 | # assert (np.all(vel[under] < 0)) 74 | assert np.all(loc >= -self.box_size) 75 | vel[under] = np.abs(vel[under]) 76 | 77 | return loc, vel 78 | 79 | def sample_trajectory(self, T=10000, sample_freq=10, charge_prob=None): 80 | if charge_prob is None: 81 | charge_prob = [1.0 / 2, 0, 1.0 / 2] 82 | n = self.n_balls 83 | assert T % sample_freq == 0 84 | T_save = int(T / sample_freq - 1) 85 | diag_mask = np.ones((n, n), dtype=bool) 86 | np.fill_diagonal(diag_mask, 0) 87 | counter = 0 88 | # Sample edges 89 | charges = np.random.choice( 90 | self._charge_types, size=(self.n_balls, 1), p=charge_prob 91 | ) 92 | edges = charges.dot(charges.transpose()) 93 | # Initialize location and velocity 94 | loc = np.zeros((T_save, self.dim, n)) 95 | vel = np.zeros((T_save, self.dim, n)) 96 | loc_next = np.random.randn(self.dim, n) * self.loc_std 97 | vel_next = np.random.randn(self.dim, n) 98 | v_norm = np.sqrt((vel_next**2).sum(axis=0)).reshape(1, -1) 99 | vel_next = vel_next * self.vel_norm / v_norm 100 | loc[0, :, :], vel[0, :, :] = self._clamp(loc_next, vel_next) 101 | 102 | # disables division by zero warning, since I fix it with fill_diagonal 103 | with np.errstate(divide="ignore"): 104 | # half step leapfrog 105 | l2_dist_power3 = np.power( 106 | self._l2(loc_next.transpose(), loc_next.transpose()), 3.0 / 2.0 107 | ) 108 | 109 | # size of forces up to a 1/|r| factor 110 | # since I later multiply by an unnormalized r vector 111 | forces_size = self.interaction_strength * edges / l2_dist_power3 112 | np.fill_diagonal( 113 | forces_size, 0 114 | ) # self forces are zero (fixes division by zero) 115 | assert np.abs(forces_size[diag_mask]).min() > 1e-10 116 | F = ( 117 | forces_size.reshape(1, n, n) 118 | * np.concatenate( 119 | ( 120 | np.subtract.outer(loc_next[0, :], loc_next[0, :]).reshape( 121 | 1, n, n 122 | ), 123 | np.subtract.outer(loc_next[1, :], loc_next[1, :]).reshape( 124 | 1, n, n 125 | ), 126 | np.subtract.outer(loc_next[2, :], loc_next[2, :]).reshape( 127 | 1, n, n 128 | ), 129 | ) 130 | ) 131 | ).sum(axis=-1) 132 | F[F > self._max_F] = self._max_F 133 | F[F < -self._max_F] = -self._max_F 134 | 135 | vel_next += self._delta_T * F 136 | # run leapfrog 137 | for i in range(1, T): 138 | loc_next += 
self._delta_T * vel_next 139 | # loc_next, vel_next = self._clamp(loc_next, vel_next) 140 | 141 | if i % sample_freq == 0: 142 | loc[counter, :, :], vel[counter, :, :] = loc_next, vel_next 143 | counter += 1 144 | 145 | l2_dist_power3 = np.power( 146 | self._l2(loc_next.transpose(), loc_next.transpose()), 3.0 / 2.0 147 | ) 148 | forces_size = self.interaction_strength * edges / l2_dist_power3 149 | np.fill_diagonal(forces_size, 0) 150 | # assert (np.abs(forces_size[diag_mask]).min() > 1e-10) 151 | 152 | F = ( 153 | forces_size.reshape(1, n, n) 154 | * np.concatenate( 155 | ( 156 | np.subtract.outer(loc_next[0, :], loc_next[0, :]).reshape( 157 | 1, n, n 158 | ), 159 | np.subtract.outer(loc_next[1, :], loc_next[1, :]).reshape( 160 | 1, n, n 161 | ), 162 | np.subtract.outer(loc_next[2, :], loc_next[2, :]).reshape( 163 | 1, n, n 164 | ), 165 | ) 166 | ) 167 | ).sum(axis=-1) 168 | F[F > self._max_F] = self._max_F 169 | F[F < -self._max_F] = -self._max_F 170 | vel_next += self._delta_T * F 171 | # Add noise to observations 172 | loc += np.random.randn(T_save, self.dim, self.n_balls) * self.noise_var 173 | vel += np.random.randn(T_save, self.dim, self.n_balls) * self.noise_var 174 | return loc, vel, edges, charges 175 | 176 | 177 | class GravitySim(object): 178 | def __init__( 179 | self, 180 | n_balls=100, 181 | loc_std=1, 182 | vel_norm=0.5, 183 | interaction_strength=1, 184 | noise_var=0, 185 | dt=0.001, 186 | softening=0.1, 187 | ): 188 | self.n_balls = n_balls 189 | self.loc_std = loc_std 190 | self.vel_norm = vel_norm 191 | self.interaction_strength = interaction_strength 192 | self.noise_var = noise_var 193 | self.dt = dt 194 | self.softening = softening 195 | 196 | self.dim = 3 197 | 198 | def compute_acceleration(self, pos, mass, G, softening): 199 | # positions r = [x,y,z] for all particles 200 | x = pos[:, 0:1] 201 | y = pos[:, 1:2] 202 | z = pos[:, 2:3] 203 | 204 | # matrix that stores all pairwise particle separations: r_j - r_i 205 | dx = x.T - x 206 | dy = y.T - y 207 | dz = z.T - z 208 | 209 | # matrix that stores 1/r^3 for all particle pairwise particle separations 210 | inv_r3 = dx**2 + dy**2 + dz**2 + softening**2 211 | inv_r3[inv_r3 > 0] = inv_r3[inv_r3 > 0] ** (-1.5) 212 | 213 | ax = G * (dx * inv_r3) @ mass 214 | ay = G * (dy * inv_r3) @ mass 215 | az = G * (dz * inv_r3) @ mass 216 | 217 | # pack together the acceleration components 218 | a = np.hstack((ax, ay, az)) 219 | return a 220 | 221 | def _energy(self, pos, vel, mass, G): 222 | # Kinetic Energy: 223 | KE = 0.5 * np.sum(np.sum(mass * vel**2)) 224 | 225 | # Potential Energy: 226 | 227 | # positions r = [x,y,z] for all particles 228 | x = pos[:, 0:1] 229 | y = pos[:, 1:2] 230 | z = pos[:, 2:3] 231 | 232 | # matrix that stores all pairwise particle separations: r_j - r_i 233 | dx = x.T - x 234 | dy = y.T - y 235 | dz = z.T - z 236 | 237 | # matrix that stores 1/r for all particle pairwise particle separations 238 | inv_r = np.sqrt(dx**2 + dy**2 + dz**2) 239 | inv_r[inv_r > 0] = 1.0 / inv_r[inv_r > 0] 240 | 241 | # sum over upper triangle, to count each interaction only once 242 | PE = G * np.sum(np.sum(np.triu(-(mass * mass.T) * inv_r, 1))) 243 | 244 | return KE, PE, KE + PE 245 | 246 | def sample_trajectory(self, T=10000, sample_freq=10): 247 | assert T % sample_freq == 0 248 | 249 | T_save = int(T / sample_freq) 250 | 251 | N = self.n_balls 252 | 253 | pos_save = np.zeros((T_save, N, self.dim)) 254 | vel_save = np.zeros((T_save, N, self.dim)) 255 | force_save = np.zeros((T_save, N, self.dim)) 256 | 257 | # Specific sim 
parameters 258 | mass = np.ones((N, 1)) 259 | t = 0 260 | pos = np.random.randn(N, self.dim) # randomly selected positions and velocities 261 | vel = np.random.randn(N, self.dim) 262 | 263 | # Convert to Center-of-Mass frame 264 | vel -= np.mean(mass * vel, 0) / np.mean(mass) 265 | 266 | # calculate initial gravitational accelerations 267 | acc = self.compute_acceleration( 268 | pos, mass, self.interaction_strength, self.softening 269 | ) 270 | 271 | for i in range(T): 272 | if i % sample_freq == 0: 273 | pos_save[int(i / sample_freq)] = pos 274 | vel_save[int(i / sample_freq)] = vel 275 | force_save[int(i / sample_freq)] = acc * mass 276 | 277 | # (1/2) kick 278 | vel += acc * self.dt / 2.0 279 | 280 | # drift 281 | pos += vel * self.dt 282 | 283 | # update accelerations 284 | acc = self.compute_acceleration( 285 | pos, mass, self.interaction_strength, self.softening 286 | ) 287 | 288 | # (1/2) kick 289 | vel += acc * self.dt / 2.0 290 | 291 | # update time 292 | t += self.dt 293 | 294 | # Add noise to observations 295 | pos_save += np.random.randn(T_save, N, self.dim) * self.noise_var 296 | vel_save += np.random.randn(T_save, N, self.dim) * self.noise_var 297 | force_save += np.random.randn(T_save, N, self.dim) * self.noise_var 298 | return pos_save, vel_save, force_save, mass 299 | -------------------------------------------------------------------------------- /segnn_jax/segnn.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, List, Optional, Union 2 | 3 | import e3nn_jax as e3nn 4 | import haiku as hk 5 | import jax.numpy as jnp 6 | import jraph 7 | from e3nn_jax import IrrepsArray 8 | from jax.tree_util import Partial 9 | 10 | from .blocks import O3_LAYERS, O3TensorProduct, O3TensorProductGate, TensorProduct 11 | from .config import config 12 | from .graph_utils import SteerableGraphsTuple, pooling 13 | 14 | 15 | def O3Embedding( 16 | embed_irreps: e3nn.Irreps, 17 | embed_edges: bool = True, 18 | O3Layer: TensorProduct = O3TensorProduct, 19 | ) -> Callable: 20 | """Linear steerable embedding. 21 | 22 | Embeds the graph nodes in the representation space :param embed_irreps:. 23 | 24 | Args: 25 | embed_irreps: Output representation 26 | embed_edges: If true also embed edges/message passing features 27 | O3Layer: Type of tensor product layer to use 28 | 29 | Returns: 30 | Function to embed graph nodes (and optionally edges) 31 | """ 32 | 33 | def _embedding( 34 | st_graph: SteerableGraphsTuple, 35 | ) -> SteerableGraphsTuple: 36 | graph = st_graph.graph 37 | nodes = O3Layer(embed_irreps, name="embedding_nodes")( 38 | graph.nodes, st_graph.node_attributes 39 | ) 40 | st_graph = st_graph._replace(graph=graph._replace(nodes=nodes)) 41 | 42 | # NOTE edge embedding is not in the original paper but can get good results 43 | if embed_edges: 44 | additional_message_features = O3Layer( 45 | embed_irreps, 46 | name="embedding_msg_features", 47 | )( 48 | st_graph.additional_message_features, 49 | st_graph.edge_attributes, 50 | ) 51 | st_graph = st_graph._replace( 52 | additional_message_features=additional_message_features 53 | ) 54 | 55 | return st_graph 56 | 57 | return _embedding 58 | 59 | 60 | def O3Decoder( 61 | latent_irreps: e3nn.Irreps, 62 | output_irreps: e3nn.Irreps, 63 | blocks: int = 1, 64 | task: str = "graph", 65 | pool: Optional[str] = "avg", 66 | pooled_irreps: Optional[e3nn.Irreps] = None, 67 | O3Layer: TensorProduct = O3TensorProduct, 68 | ): 69 | """Steerable pooler and decoder. 
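A minimal usage sketch (hypothetical irreps, not from the repo; like every module here it must be created inside an hk.transform): decoder = O3Decoder(latent_irreps=e3nn.Irreps("32x0e + 8x1o"), output_irreps=e3nn.Irreps("1x1o"), task="node"); nodes_out = decoder(st_graph)  # per-node "1x1o" vector outputs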
70 | 71 | Args: 72 | latent_irreps: Representation from the previous block 73 | output_irreps: Output representation 74 | blocks: Number of tensor product blocks in the decoder 75 | task: Specifies where the output is located. Either 'graph' or 'node' 76 | pool: Pooling method to use. One of 'avg', 'sum', 'none', None 77 | pooled_irreps: Pooled irreps. When left None the original implementation is used 78 | O3Layer: Type of tensor product layer to use 79 | 80 | Returns: 81 | Decoded latent feature space to output space. 82 | """ 83 | 84 | assert task in ["node", "graph"], f"Unknown task {task}" 85 | assert pool in ["avg", "sum", "none", None], f"Unknown pooling '{pool}'" 86 | 87 | # NOTE: original implementation restricted final layers to pooled_irreps. 88 | # This way gates cannot be applied in the post pool block when returning vectors, 89 | # because the gating scalars cannot be reached. 90 | if pooled_irreps is None: 91 | pooled_irreps = (output_irreps * latent_irreps.num_irreps).regroup() 92 | 93 | def _decoder(st_graph: SteerableGraphsTuple): 94 | nodes = st_graph.graph.nodes 95 | # pre pool block 96 | for i in range(blocks): 97 | nodes = O3TensorProductGate( 98 | latent_irreps, name=f"prepool_{i}", o3_layer=O3Layer 99 | )(nodes, st_graph.node_attributes) 100 | 101 | if task == "node": 102 | nodes = O3Layer(output_irreps, name="output")( 103 | nodes, st_graph.node_attributes 104 | ) 105 | 106 | if task == "graph": 107 | # pool over graph 108 | nodes = O3Layer(pooled_irreps, name=f"prepool_{blocks}")( 109 | nodes, st_graph.node_attributes 110 | ) 111 | 112 | # pooling layer 113 | if pool == "avg": 114 | pool_fn = jraph.segment_mean 115 | if pool == "sum": 116 | pool_fn = jraph.segment_sum 117 | 118 | nodes = pooling(st_graph.graph._replace(nodes=nodes), aggregate_fn=pool_fn) 119 | 120 | # post pool mlp (not steerable) 121 | for i in range(blocks): 122 | nodes = O3TensorProductGate( 123 | pooled_irreps, name=f"postpool_{i}", o3_layer=O3TensorProduct 124 | )(nodes) 125 | nodes = O3TensorProduct(output_irreps, name="output")(nodes) 126 | 127 | return nodes 128 | 129 | return _decoder 130 | 131 | 132 | class SEGNNLayer(hk.Module): 133 | """ 134 | Steerable E(3) equivariant layer. 135 | 136 | Applies a message passing step (GN) with equivariant message and update functions. 137 | """ 138 | 139 | def __init__( 140 | self, 141 | output_irreps: e3nn.Irreps, 142 | layer_num: int, 143 | blocks: int = 2, 144 | norm: Optional[str] = None, 145 | aggregate_fn: Optional[Callable] = jraph.segment_sum, 146 | residual: bool = True, 147 | O3Layer: TensorProduct = O3TensorProduct, 148 | ): 149 | """ 150 | Initialize the layer. 151 | 152 | Args: 153 | output_irreps: Layer output representation 154 | layer_num: Numbering of the layer 155 | blocks: Number of tensor product blocks in the layer 156 | norm: Normalization type. Either be None, 'instance' or 'batch' 157 | aggregate_fn: Message aggregation function. Defaults to sum. 
158 | residual: If true, use residual connections 159 | O3Layer: Type of tensor product layer to use 160 | """ 161 | super().__init__(f"layer_{layer_num}") 162 | assert norm in ["batch", "instance", "none", None], f"Unknown norm '{norm}'" 163 | self._output_irreps = output_irreps 164 | self._blocks = blocks 165 | self._norm = norm 166 | self._aggregate_fn = aggregate_fn 167 | self._residual = residual 168 | 169 | self._O3Layer = O3Layer 170 | 171 | def _message( 172 | self, 173 | edge_attribute: IrrepsArray, 174 | additional_message_features: IrrepsArray, 175 | edge_features: Any, 176 | incoming: IrrepsArray, 177 | outgoing: IrrepsArray, 178 | globals_: Any, 179 | ) -> IrrepsArray: 180 | """Steerable equivariant message function.""" 181 | _ = globals_ 182 | _ = edge_features 183 | # create messages 184 | msg = e3nn.concatenate([incoming, outgoing], axis=-1) 185 | if additional_message_features is not None: 186 | msg = e3nn.concatenate([msg, additional_message_features], axis=-1) 187 | # message mlp (phi_m in the paper) steered by edge attributes 188 | for i in range(self._blocks): 189 | msg = O3TensorProductGate( 190 | self._output_irreps, name=f"tp_{i}", o3_layer=self._O3Layer 191 | )(msg, edge_attribute) 192 | # NOTE: original implementation only applied batch norm to messages 193 | if self._norm == "batch": 194 | msg = e3nn.haiku.BatchNorm(irreps=self._output_irreps)(msg) 195 | return msg 196 | 197 | def _update( 198 | self, 199 | node_attribute: IrrepsArray, 200 | nodes: IrrepsArray, 201 | senders: Any, 202 | msg: IrrepsArray, 203 | globals_: Any, 204 | ) -> IrrepsArray: 205 | """Steerable equivariant update function.""" 206 | _ = senders 207 | _ = globals_ 208 | x = e3nn.concatenate([nodes, msg], axis=-1) 209 | # update mlp (phi_f in the paper) steered by node attributes 210 | for i in range(self._blocks - 1): 211 | x = O3TensorProductGate( 212 | self._output_irreps, name=f"tp_{i}", o3_layer=self._O3Layer 213 | )(x, node_attribute) 214 | # last update layer without activation 215 | update = self._O3Layer(self._output_irreps, name=f"tp_{self._blocks - 1}")( 216 | x, node_attribute 217 | ) 218 | # residual connection 219 | if self._residual: 220 | nodes += update 221 | else: 222 | nodes = update 223 | # node feature norm 224 | if self._norm in ["batch", "instance"]: 225 | nodes = e3nn.haiku.BatchNorm( 226 | irreps=self._output_irreps, 227 | instance=(self._norm == "instance"), 228 | )(nodes) 229 | return nodes 230 | 231 | def __call__(self, st_graph: SteerableGraphsTuple) -> SteerableGraphsTuple: 232 | """Perform a message passing step. 233 | 234 | Args: 235 | st_graph: Input graph 236 | 237 | Returns: 238 | The updated graph 239 | """ 240 | # NOTE node_attributes, edge_attributes and additional_message_features 241 | # are never updated within the message passing layers 242 | return st_graph._replace( 243 | graph=jraph.GraphNetwork( 244 | update_node_fn=Partial(self._update, st_graph.node_attributes), 245 | update_edge_fn=Partial( 246 | self._message, 247 | st_graph.edge_attributes, 248 | st_graph.additional_message_features, 249 | ), 250 | aggregate_edges_for_nodes_fn=self._aggregate_fn, 251 | )(st_graph.graph) 252 | ) 253 | 254 | 255 | class SEGNN(hk.Module): 256 | """Steerable E(3) equivariant network. 257 | 258 | Original paper https://arxiv.org/abs/2110.02905. 
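Example (a sketch with hypothetical irreps and a pre-built SteerableGraphsTuple ``st_graph``; assumes ``import jax`` and ``import haiku as hk``):

    def model_fn(st_graph):
        return SEGNN(
            hidden_irreps="32x0e + 16x1o + 8x2e",  # hypothetical
            output_irreps="1x1o",
            num_layers=4,
            task="node",
        )(st_graph)

    model = hk.without_apply_rng(hk.transform(model_fn))
    params = model.init(jax.random.PRNGKey(0), st_graph)
    preds = model.apply(params, st_graph)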
259 | """ 260 | 261 | def __init__( 262 | self, 263 | hidden_irreps: Union[List[e3nn.Irreps], e3nn.Irreps], 264 | output_irreps: e3nn.Irreps, 265 | num_layers: int, 266 | norm: Optional[str] = None, 267 | pool: Optional[str] = "avg", 268 | task: Optional[str] = "graph", 269 | blocks_per_layer: int = 2, 270 | embed_msg_features: bool = False, 271 | o3_layer: Optional[Union[str, TensorProduct]] = None, 272 | ): 273 | """ 274 | Initialize the network. 275 | 276 | Args: 277 | hidden_irreps: Feature representation in the hidden layers 278 | output_irreps: Output representation. 279 | num_layers: Number of message passing layers 280 | norm: Normalization type. Either None, 'instance' or 'batch' 281 | pool: Pooling mode (only for graph-wise tasks) 282 | task: Specifies where the output is located. Either 'graph' or 'node' 283 | blocks_per_layer: Number of tensor product blocks in each message passing 284 | embed_msg_features: Set to true to also embed edges/message passing features 285 | o3_layer: Tensor product layer type. "tpl", "fctp", "scn" or a custom layer 286 | """ 287 | super().__init__() 288 | 289 | if not isinstance(output_irreps, e3nn.Irreps): 290 | output_irreps = e3nn.Irreps(output_irreps) 291 | if not isinstance(hidden_irreps, e3nn.Irreps): 292 | hidden_irreps = e3nn.Irreps(hidden_irreps) 293 | 294 | self._hidden_irreps = hidden_irreps 295 | self._num_layers = num_layers 296 | 297 | self._embed_msg_features = embed_msg_features 298 | self._norm = norm 299 | self._blocks_per_layer = blocks_per_layer 300 | 301 | # layer type 302 | if o3_layer is None: 303 | o3_layer = config("o3_layer") 304 | if isinstance(o3_layer, str): 305 | assert o3_layer in O3_LAYERS, f"Unknown O3 layer {o3_layer}." 306 | self._O3Layer = O3_LAYERS[o3_layer] 307 | else: 308 | self._O3Layer = o3_layer 309 | 310 | self._embedding = O3Embedding( 311 | self._hidden_irreps, 312 | O3Layer=self._O3Layer, 313 | embed_edges=self._embed_msg_features, 314 | ) 315 | 316 | pooled_irreps = None 317 | if task == "graph" and "0e" not in output_irreps: 318 | # NOTE: different from original. 
This way proper gates are always applied 319 | pooled_irreps = hidden_irreps 320 | 321 | self._decoder = O3Decoder( 322 | latent_irreps=self._hidden_irreps, 323 | output_irreps=output_irreps, 324 | O3Layer=self._O3Layer, 325 | task=task, 326 | pool=pool, 327 | pooled_irreps=pooled_irreps, 328 | ) 329 | 330 | def __call__(self, st_graph: SteerableGraphsTuple) -> jnp.array: 331 | # node (and edge) embedding 332 | st_graph = self._embedding(st_graph) 333 | 334 | # message passing 335 | for n in range(self._num_layers): 336 | st_graph = SEGNNLayer( 337 | output_irreps=self._hidden_irreps, layer_num=n, norm=self._norm 338 | )(st_graph) 339 | 340 | # decoder/pooler 341 | nodes = self._decoder(st_graph) 342 | 343 | return jnp.squeeze(nodes.array) 344 | -------------------------------------------------------------------------------- /experiments/qm9/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | from typing import List, Optional 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | from e3nn.o3 import Irreps, spherical_harmonics 9 | from torch_geometric.data import Data, InMemoryDataset, download_url, extract_zip 10 | from torch_geometric.nn import radius_graph 11 | from torch_scatter import scatter 12 | from tqdm import tqdm 13 | 14 | HAR2EV = 27.211386246 15 | KCALMOL2EV = 0.04336414 16 | 17 | conversion = torch.tensor( 18 | [ 19 | 1.0, 20 | 1.0, 21 | HAR2EV, 22 | HAR2EV, 23 | HAR2EV, 24 | 1.0, 25 | HAR2EV, 26 | HAR2EV, 27 | HAR2EV, 28 | HAR2EV, 29 | HAR2EV, 30 | 1.0, 31 | KCALMOL2EV, 32 | KCALMOL2EV, 33 | KCALMOL2EV, 34 | KCALMOL2EV, 35 | 1.0, 36 | 1.0, 37 | 1.0, 38 | ] 39 | ) 40 | 41 | atomrefs = { 42 | 6: [0.0, 0.0, 0.0, 0.0, 0.0], 43 | 7: [-13.61312172, -1029.86312267, -1485.30251237, -2042.61123593, -2713.48485589], 44 | 8: [-13.5745904, -1029.82456413, -1485.26398105, -2042.5727046, -2713.44632457], 45 | 9: [-13.54887564, -1029.79887659, -1485.2382935, -2042.54701705, -2713.42063702], 46 | 10: [-13.90303183, -1030.25891228, -1485.71166277, -2043.01812778, -2713.88796536], 47 | 11: [0.0, 0.0, 0.0, 0.0, 0.0], 48 | } 49 | 50 | 51 | targets = [ 52 | "mu", 53 | "alpha", 54 | "homo", 55 | "lumo", 56 | "gap", 57 | "r2", 58 | "zpve", 59 | "U0", 60 | "U", 61 | "H", 62 | "G", 63 | "Cv", 64 | "U0_atom", 65 | "U_atom", 66 | "H_atom", 67 | "G_atom", 68 | "A", 69 | "B", 70 | "C", 71 | ] 72 | 73 | thermo_targets = ["U", "U0", "H", "G"] 74 | 75 | 76 | class TargetGetter(object): 77 | """Gets relevant target""" 78 | 79 | def __init__(self, target): 80 | self.target = target 81 | self.target_idx = targets.index(target) 82 | 83 | def __call__(self, data): 84 | # Specify target. 85 | data.y = data.y[0, self.target_idx] 86 | return data 87 | 88 | 89 | class QM9(InMemoryDataset): 90 | r"""The QM9 dataset from the `"MoleculeNet: A Benchmark for Molecular 91 | Machine Learning" `_ paper, consisting of 92 | about 130,000 molecules with 19 regression targets. 93 | Each molecule includes complete spatial information for the single low 94 | energy conformation of the atoms in the molecule. 
95 | In addition, we provide the atom features from the `"Neural Message 96 | Passing for Quantum Chemistry" `_ paper.""" 97 | 98 | raw_url = ( 99 | "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/" 100 | "molnet_publish/qm9.zip" 101 | ) 102 | raw_url2 = "https://ndownloader.figshare.com/files/3195404" 103 | processed_url = "https://data.pyg.org/datasets/qm9_v3.zip" 104 | 105 | def __init__( 106 | self, root, target, radius, partition, lmax_attr, feature_type="one_hot" 107 | ): 108 | assert feature_type in [ 109 | "one_hot", 110 | "cormorant", 111 | "gilmer", 112 | ], "Please use valid features" 113 | assert target in targets 114 | assert partition in ["train", "valid", "test"] 115 | self.root = osp.abspath(osp.join(root, "qm9")) 116 | self.target = target 117 | self.radius = radius 118 | self.partition = partition 119 | self.feature_type = feature_type 120 | self.lmax_attr = lmax_attr 121 | self.attr_irreps = Irreps.spherical_harmonics(lmax_attr) 122 | transform = TargetGetter(self.target) 123 | 124 | super().__init__(self.root, transform) 125 | 126 | self.data, self.slices = torch.load(self.processed_paths[0]) 127 | 128 | def calc_stats(self): 129 | ys = np.array([data.y.item() for data in self]) 130 | mean = np.mean(ys) 131 | mad = np.mean(np.abs(ys - mean)) 132 | return mean, mad 133 | 134 | def atomref(self, target) -> Optional[torch.Tensor]: 135 | if target in atomrefs: 136 | out = torch.zeros(100) 137 | out[torch.tensor([1, 6, 7, 8, 9])] = torch.tensor(atomrefs[target]) 138 | return out.view(-1, 1) 139 | return None 140 | 141 | @property 142 | def raw_file_names(self) -> Optional[List[str]]: 143 | try: 144 | import rdkit # noqa 145 | 146 | return ["gdb9.sdf", "gdb9.sdf.csv", "uncharacterized.txt"] 147 | except ImportError: 148 | import sys 149 | 150 | print("Please install rdkit") 151 | sys.exit(1) 152 | 153 | @property 154 | def processed_file_names(self) -> List[str]: 155 | return [ 156 | "_".join( 157 | [ 158 | self.partition, 159 | "r=" + str(np.round(self.radius, 2)), 160 | self.feature_type, 161 | "l=" + str(self.lmax_attr), 162 | ] 163 | ) 164 | + ".pt" 165 | ] 166 | 167 | def download(self): 168 | print("i'm downloading", self.raw_dir, self.raw_url) 169 | try: 170 | import rdkit # noqa 171 | 172 | file_path = download_url(self.raw_url, self.raw_dir) 173 | extract_zip(file_path, self.raw_dir) 174 | os.unlink(file_path) 175 | 176 | file_path = download_url(self.raw_url2, self.raw_dir) 177 | os.rename( 178 | osp.join(self.raw_dir, "3195404"), 179 | osp.join(self.raw_dir, "uncharacterized.txt"), 180 | ) 181 | except ImportError: 182 | path = download_url(self.processed_url, self.raw_dir) 183 | extract_zip(path, self.raw_dir) 184 | os.unlink(path) 185 | 186 | def process(self): 187 | try: 188 | from rdkit import Chem, RDLogger 189 | from rdkit.Chem.rdchem import HybridizationType 190 | 191 | RDLogger.DisableLog("rdApp.*") 192 | except ImportError: 193 | print("Please install rdkit") 194 | return 195 | 196 | print( 197 | "Processing", 198 | self.partition, 199 | "with radius=" + str(np.round(self.radius, 2)) + ",", 200 | "l_attr=" + str(self.lmax_attr), 201 | "and", 202 | self.feature_type, 203 | "features.", 204 | ) 205 | types = {"H": 0, "C": 1, "N": 2, "O": 3, "F": 4} 206 | 207 | with open(self.raw_paths[1], "r") as f: 208 | target = f.read().split("\n")[1:-1] 209 | target = [[float(x) for x in line.split(",")[1:20]] for line in target] 210 | target = torch.tensor(target, dtype=torch.float) 211 | target = torch.cat([target[:, 3:], target[:, :3]], dim=-1) 212 | target 
= target * conversion.view(1, -1) 213 | 214 | with open(self.raw_paths[2], "r") as f: 215 | skip = [int(x.split()[0]) - 1 for x in f.read().split("\n")[9:-2]] 216 | 217 | suppl = Chem.SDMolSupplier(self.raw_paths[0], removeHs=False, sanitize=False) 218 | data_list = [] 219 | 220 | # Create splits identical to Cormorant 221 | Nmols = len(suppl) - len(skip) 222 | Ntrain = 100000 223 | Ntest = int(0.1 * Nmols) 224 | Nvalid = Nmols - (Ntrain + Ntest) 225 | 226 | np.random.seed(0) 227 | data_perm = np.random.permutation(Nmols) 228 | train, valid, test = np.split(data_perm, [Ntrain, Ntrain + Nvalid]) 229 | indices = {"train": train, "valid": valid, "test": test} 230 | 231 | # Add a very ugly second index to align with Cormorant splits. 232 | j = 0 233 | for i, mol in enumerate(tqdm(suppl)): 234 | if i in skip: 235 | continue 236 | if j not in indices[self.partition]: 237 | j += 1 238 | continue 239 | j += 1 240 | 241 | N = mol.GetNumAtoms() 242 | 243 | pos = suppl.GetItemText(i).split("\n")[4 : 4 + N] 244 | pos = [[float(x) for x in line.split()[:3]] for line in pos] 245 | pos = torch.tensor(pos, dtype=torch.float) 246 | 247 | edge_index = radius_graph(pos, r=self.radius, loop=False) 248 | 249 | type_idx = [] 250 | atomic_number = [] 251 | aromatic = [] 252 | sp = [] 253 | sp2 = [] 254 | sp3 = [] 255 | num_hs = [] 256 | for atom in mol.GetAtoms(): 257 | type_idx.append(types[atom.GetSymbol()]) 258 | atomic_number.append(atom.GetAtomicNum()) 259 | aromatic.append(1 if atom.GetIsAromatic() else 0) 260 | hybridization = atom.GetHybridization() 261 | sp.append(1 if hybridization == HybridizationType.SP else 0) 262 | sp2.append(1 if hybridization == HybridizationType.SP2 else 0) 263 | sp3.append(1 if hybridization == HybridizationType.SP3 else 0) 264 | 265 | z = torch.tensor(atomic_number, dtype=torch.long) 266 | 267 | if self.feature_type == "one_hot": 268 | x = F.one_hot(torch.tensor(type_idx), num_classes=len(types)).float() 269 | elif self.feature_type == "cormorant": 270 | one_hot = F.one_hot(torch.tensor(type_idx), num_classes=len(types)) 271 | x = get_cormorant_features(one_hot, z, 2, z.max()) 272 | elif self.feature_type == "gilmer": 273 | row, col = edge_index 274 | hs = (z == 1).to(torch.float) 275 | num_hs = scatter(hs[row], col, dim_size=N).tolist() 276 | 277 | x1 = F.one_hot(torch.tensor(type_idx), num_classes=len(types)) 278 | x2 = ( 279 | torch.tensor( 280 | [atomic_number, aromatic, sp, sp2, sp3, num_hs], 281 | dtype=torch.float, 282 | ) 283 | .t() 284 | .contiguous() 285 | ) 286 | x = torch.cat([x1.to(torch.float), x2], dim=-1) 287 | 288 | y = target[i].unsqueeze(0) 289 | name = mol.GetProp("_Name") 290 | 291 | edge_attr, node_attr, edge_dist = self.get_O3_attr( 292 | edge_index, pos, self.attr_irreps 293 | ) 294 | 295 | data = Data( 296 | x=x, 297 | pos=pos, 298 | edge_index=edge_index, 299 | edge_attr=edge_attr, 300 | node_attr=node_attr, 301 | additional_message_features=edge_dist, 302 | y=y, 303 | name=name, 304 | index=i, 305 | ) 306 | data_list.append(data) 307 | 308 | torch.save(self.collate(data_list), self.processed_paths[0]) 309 | 310 | def get_O3_attr(self, edge_index, pos, attr_irreps): 311 | """ 312 | Creates spherical harmonic edge attributes and node attributes for the SEGNN. 
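Edge attributes are the spherical harmonics of the relative edge positions; node attributes are the mean of the attributes of the incident edges (a scatter-mean over edge_index[1]), as computed below.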
313 | """ 314 | rel_pos = ( 315 | pos[edge_index[0]] - pos[edge_index[1]] 316 | ) # pos_j - pos_i (note in edge_index stores tuples like (j,i)) 317 | edge_dist = rel_pos.pow(2).sum(-1, keepdims=True) 318 | edge_attr = spherical_harmonics( 319 | attr_irreps, rel_pos, normalize=True, normalization="component" 320 | ) # Unnormalised for now 321 | node_attr = scatter(edge_attr, edge_index[1], dim=0, reduce="mean") 322 | return edge_attr, node_attr, edge_dist 323 | 324 | def top_n_nodes(self, n: int) -> List[int]: 325 | """Returns the largest n nodes in the dataset.""" 326 | return [int(k) for k in torch.topk(torch.diff(self.slices["x"]), n)[0]] 327 | 328 | def top_n_edges(self, n: int) -> List[int]: 329 | """Returns the largest n edge in the dataset.""" 330 | return [int(k) for k in torch.topk(torch.diff(self.slices["edge_attr"]), n)[0]] 331 | 332 | 333 | def get_cormorant_features(one_hot, charges, charge_power, charge_scale): 334 | """Create input features as described in section 7.3 of https://arxiv.org/pdf/1906.04015.pdf""" 335 | charge_tensor = (charges.unsqueeze(-1) / charge_scale).pow( 336 | torch.arange(charge_power + 1.0, dtype=torch.float32) 337 | ) 338 | charge_tensor = charge_tensor.view(charges.shape + (1, charge_power + 1)) 339 | atom_scalars = (one_hot.unsqueeze(-1) * charge_tensor).view( 340 | charges.shape[:2] + (-1,) 341 | ) 342 | return atom_scalars 343 | 344 | 345 | if __name__ == "__main__": 346 | dataset = QM9("datasets", "alpha", 2.0, "train", 3, feature_type="one_hot") 347 | print("length", len(dataset)) 348 | ys = np.array([data.y.item() for data in dataset]) 349 | mean, mad = dataset.calc_stats() 350 | 351 | for item in dataset: 352 | print(item.edge_index) 353 | break 354 | 355 | print("mean", mean, "mad", mad) 356 | import matplotlib.pyplot as plt 357 | 358 | plt.subplot(121) 359 | plt.title(dataset.target) 360 | plt.hist(ys) 361 | plt.subplot(122) 362 | plt.title(dataset.target + " standardised") 363 | plt.hist((ys - mean) / mad) 364 | plt.show() 365 | -------------------------------------------------------------------------------- /segnn_jax/blocks.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from abc import ABC, abstractmethod 3 | from typing import Callable, Optional, Tuple, Union 4 | 5 | import e3nn_jax as e3nn 6 | import haiku as hk 7 | import jax 8 | import jax.numpy as jnp 9 | from e3nn_jax import IrrepsArray 10 | 11 | try: 12 | from e3nn_jax.experimental import linear_shtp as escn 13 | except ImportError: 14 | escn = None 15 | try: 16 | from e3nn_jax import FunctionalFullyConnectedTensorProduct 17 | except ImportError: 18 | from e3nn_jax.legacy import FunctionalFullyConnectedTensorProduct # type: ignore 19 | 20 | from .config import config 21 | 22 | InitFn = Callable[[str, Tuple[int, ...], float, jnp.dtype], jnp.ndarray] 23 | TensorProductFn = Callable[[IrrepsArray, IrrepsArray], IrrepsArray] 24 | 25 | 26 | def uniform_init( 27 | name: str, 28 | path_shape: Tuple[int, ...], 29 | weight_std: float, 30 | dtype: jnp.dtype = config("default_dtype"), 31 | ) -> jnp.ndarray: 32 | return hk.get_parameter( 33 | name, 34 | shape=path_shape, 35 | dtype=dtype, 36 | init=hk.initializers.RandomUniform(minval=-weight_std, maxval=weight_std), 37 | ) 38 | 39 | 40 | class TensorProduct(hk.Module, ABC): 41 | """O(3) equivariant linear parametrized tensor product layer.""" 42 | 43 | def __init__( 44 | self, 45 | output_irreps: e3nn.Irreps, 46 | *, 47 | biases: bool = True, 48 | name: Optional[str] = None, 49 | 
init_fn: Optional[InitFn] = None, 50 | gradient_normalization: Optional[Union[str, float]] = None, 51 | path_normalization: Optional[Union[str, float]] = None, 52 | ): 53 | """Initialize the tensor product. 54 | 55 | Args: 56 | output_irreps: Output representation 57 | biases: If set to true, adds biases 58 | name: Name of the linear layer params 59 | init_fn: Weight initialization function. Default is uniform. 60 | gradient_normalization: Gradient normalization method. Default is "path" 61 | NOTE: gradient_normalization="element" is the default in torch and haiku. 62 | path_normalization: Path normalization method. Default is "element" 63 | """ 64 | super().__init__(name=name) 65 | 66 | if not isinstance(output_irreps, e3nn.Irreps): 67 | output_irreps = e3nn.Irreps(output_irreps) 68 | self.output_irreps = output_irreps 69 | 70 | # tp weight init 71 | if not init_fn: 72 | init_fn = uniform_init 73 | self.get_parameter = init_fn 74 | 75 | if not gradient_normalization: 76 | gradient_normalization = config("gradient_normalization") 77 | if not path_normalization: 78 | path_normalization = config("path_normalization") 79 | self._gradient_normalization = gradient_normalization 80 | self._path_normalization = path_normalization 81 | 82 | self.biases = biases and "0e" in self.output_irreps 83 | 84 | def _check_input( 85 | self, x: IrrepsArray, y: Optional[IrrepsArray] = None 86 | ) -> Tuple[IrrepsArray, IrrepsArray]: 87 | if not y: 88 | y = IrrepsArray("1x0e", jnp.ones((1, 1), dtype=x.dtype)) 89 | 90 | if x.irreps.lmax == 0 and y.irreps.lmax == 0 and self.output_irreps.lmax > 0: 91 | warnings.warn( 92 | f"The specified output irreps ({self.output_irreps}) are not scalars " 93 | "but both operands are. This can have undesired behaviour (NaN). Try " 94 | "redistributing them into scalars or choosing higher orders." 95 | ) 96 | 97 | return x, y 98 | 99 | @abstractmethod 100 | def __call__( 101 | self, x: IrrepsArray, y: Optional[IrrepsArray] = None, **kwargs 102 | ) -> IrrepsArray: 103 | """Applies an O(3) equivariant linear parametrized tensor product layer. 104 | 105 | Args: 106 | x (IrrepsArray): Left tensor 107 | y (IrrepsArray): Right tensor. If None it defaults to np.ones. 108 | 109 | Returns: 110 | The output of the weighted tensor product (IrrepsArray). 111 | """ 112 | raise NotImplementedError 113 | 114 | 115 | class O3TensorProduct(TensorProduct): 116 | """O(3) equivariant linear parametrized tensor product layer. 117 | 118 | Original O3TensorProduct version that uses tensor_product + Linear instead of 119 | FullyConnectedTensorProduct. 120 | From e3nn 0.19.2 (https://github.com/e3nn/e3nn-jax/releases/tag/0.19.2), this is 121 | as fast as FullyConnectedTensorProduct. 
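A call sketch (hypothetical irreps; x and y are IrrepsArrays, and the layer must be created inside an hk.transform):

    tp = O3TensorProduct("32x0e + 8x1o", name="tp")
    out = tp(x, y)  # tensor product of x and y, linearly mapped to "32x0e + 8x1o"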
122 | """ 123 | 124 | def __init__( 125 | self, 126 | output_irreps: e3nn.Irreps, 127 | *, 128 | biases: bool = True, 129 | name: Optional[str] = None, 130 | init_fn: Optional[InitFn] = None, 131 | gradient_normalization: Optional[Union[str, float]] = "element", 132 | path_normalization: Optional[Union[str, float]] = None, 133 | ): 134 | super().__init__( 135 | output_irreps, 136 | biases=biases, 137 | name=name, 138 | init_fn=init_fn, 139 | gradient_normalization=gradient_normalization, 140 | path_normalization=path_normalization, 141 | ) 142 | 143 | self._linear = e3nn.haiku.Linear( 144 | self.output_irreps, 145 | get_parameter=self.get_parameter, 146 | biases=self.biases, 147 | gradient_normalization=self._gradient_normalization, 148 | path_normalization=self._path_normalization, 149 | ) 150 | 151 | def _check_input( 152 | self, x: IrrepsArray, y: Optional[IrrepsArray] = None 153 | ) -> Tuple[IrrepsArray, IrrepsArray]: 154 | x, y = super()._check_input(x, y) 155 | miss = self.output_irreps.filter(drop=e3nn.tensor_product(x.irreps, y.irreps)) 156 | if len(miss) > 0: 157 | warnings.warn(f"Output irreps: '{miss}' are unreachable and were ignored.") 158 | return x, y 159 | 160 | def __call__(self, x: IrrepsArray, y: Optional[IrrepsArray] = None) -> IrrepsArray: 161 | x, y = self._check_input(x, y) 162 | # tensor product + linear 163 | tp = self._linear(e3nn.tensor_product(x, y)) 164 | return tp 165 | 166 | 167 | class O3TensorProductFC(TensorProduct): 168 | """ 169 | O(3) equivariant linear parametrized tensor product layer. 170 | 171 | Functionally the same as O3TensorProduct, but uses FullyConnectedTensorProduct and 172 | is slightly slower (~5-10%) than tensor_product + Linear. 173 | """ 174 | 175 | def _build_tensor_product( 176 | self, left_irreps: e3nn.Irreps, right_irreps: e3nn.Irreps 177 | ) -> Callable: 178 | """Build the tensor product function.""" 179 | tp = FunctionalFullyConnectedTensorProduct( 180 | left_irreps, 181 | right_irreps, 182 | self.output_irreps, 183 | gradient_normalization=self._gradient_normalization, 184 | path_normalization=self._path_normalization, 185 | ) 186 | ws = [ 187 | self.get_parameter( 188 | name=( 189 | f"w[{ins.i_in1},{ins.i_in2},{ins.i_out}] " 190 | f"{tp.irreps_in1[ins.i_in1]}," 191 | f"{tp.irreps_in2[ins.i_in2]}," 192 | f"{tp.irreps_out[ins.i_out]}" 193 | ), 194 | path_shape=ins.path_shape, 195 | weight_std=ins.weight_std, 196 | ) 197 | for ins in tp.instructions 198 | ] 199 | 200 | def _tensor_product(x, y, **kwargs): 201 | out = tp.left_right(ws, x, y, **kwargs) 202 | # same as out.rechunk(self.output_irreps) but works with older e3nn versions 203 | return IrrepsArray(self.output_irreps, out.array) 204 | 205 | # naive broadcasting wrapper 206 | # TODO: not the best 207 | def _tp_wrapper(*args): 208 | leading_shape = jnp.broadcast_shapes(*(arg.shape[:-1] for arg in args)) 209 | args = [arg.broadcast_to(leading_shape + (-1,)) for arg in args] 210 | f = _tensor_product  # nest one vmap per leading (batch) dimension 211 | for _ in range(len(leading_shape)): f = jax.vmap(f) 212 | return f(*args) 213 | 214 | return _tp_wrapper 215 | 216 | def _build_biases(self) -> Callable: 217 | """Build the add bias function.""" 218 | b = [ 219 | self.get_parameter( 220 | f"b[{i_out}] {self.output_irreps}", 221 | path_shape=(mul_ir.dim,), 222 | weight_std=1 / jnp.sqrt(mul_ir.dim), 223 | ) 224 | for i_out, mul_ir in enumerate(self.output_irreps) 225 | if mul_ir.ir.is_scalar() 226 | ] 227 | b = IrrepsArray(f"{self.output_irreps.count('0e')}x0e", jnp.concatenate(b)) 228 | 229 | # TODO: could be improved 230 | def 
_bias_wrapper(x: IrrepsArray) -> IrrepsArray: 231 | scalars = x.filter("0e") 232 | other = x.filter(drop="0e") 233 | return e3nn.concatenate( 234 | [scalars + b.broadcast_to(scalars.shape), other], axis=1 235 | ) 236 | 237 | return _bias_wrapper 238 | 239 | def __call__( 240 | self, x: IrrepsArray, y: Optional[IrrepsArray] = None, **kwargs 241 | ) -> IrrepsArray: 242 | x, y = self._check_input(x, y) 243 | 244 | tp = self._build_tensor_product(x.irreps, y.irreps) 245 | output = tp(x, y, **kwargs) 246 | 247 | if self.biases: 248 | # add biases 249 | bias_fn = self._build_biases() 250 | return bias_fn(output) 251 | 252 | return output 253 | 254 | 255 | class O3TensorProductSCN(TensorProduct): 256 | """ 257 | O(3) equivariant linear parametrized tensor product layer. 258 | 259 | O3TensorProduct with eSCN optimization for larger spherical harmonic orders. Should 260 | be used without spherical harmonics on the inputs. 261 | """ 262 | 263 | def __init__( 264 | self, 265 | output_irreps: e3nn.Irreps, 266 | *, 267 | biases: bool = True, 268 | name: Optional[str] = None, 269 | init_fn: Optional[InitFn] = None, 270 | gradient_normalization: Optional[Union[str, float]] = None, 271 | path_normalization: Optional[Union[str, float]] = None, 272 | ): 273 | super().__init__( 274 | output_irreps, 275 | biases=biases, 276 | name=name, 277 | init_fn=init_fn, 278 | gradient_normalization=gradient_normalization, 279 | path_normalization=path_normalization, 280 | ) 281 | 282 | if escn is None: 283 | raise ImportError( 284 | "eSCN is available from e3nn-jax>=0.17.3. " 285 | f"Your version: {e3nn.__version__}" 286 | ) 287 | 288 | self._linear = e3nn.haiku.Linear( 289 | self.output_irreps, 290 | get_parameter=self.get_parameter, 291 | biases=self.biases, 292 | gradient_normalization=self._gradient_normalization, 293 | path_normalization=self._path_normalization, 294 | ) 295 | 296 | def _check_input( 297 | self, x: IrrepsArray, y: Optional[IrrepsArray] = None 298 | ) -> Tuple[IrrepsArray, IrrepsArray]: 299 | if not y: 300 | raise ValueError("eSCN cannot be used without the right input.") 301 | return super()._check_input(x, y) 302 | 303 | def __call__(self, x: IrrepsArray, y: Optional[IrrepsArray] = None) -> IrrepsArray: 304 | """Apply the layer. y must not be embedded into spherical harmonics.""" 305 | x, y = self._check_input(x, y) 306 | # TODO make this work for e3nn-jax<0.19.1 (without e3nn.utils.vmap) 307 | try: 308 | from e3nn_jax.utils import vmap as e3nn_vmap 309 | except ImportError: 310 | raise NotImplementedError("O3TensorProductSCN requires e3nn_jax.utils.vmap (e3nn-jax>=0.19.1).") 311 | shtp = e3nn_vmap(escn.shtp, in_axes=(0, 0, None)) 312 | tp = shtp(x, y, self.output_irreps) 313 | return self._linear(tp) 314 | 315 | 316 | O3_LAYERS = { 317 | "tpl": O3TensorProduct, 318 | "fctp": O3TensorProductFC, 319 | "scn": O3TensorProductSCN, 320 | } 321 | 322 | 323 | def O3TensorProductGate( 324 | output_irreps: e3nn.Irreps, 325 | *, 326 | biases: bool = True, 327 | scalar_activation: Optional[Callable] = None, 328 | gate_activation: Optional[Callable] = None, 329 | name: Optional[str] = None, 330 | init_fn: Optional[InitFn] = None, 331 | o3_layer: Optional[Union[str, TensorProduct]] = None, 332 | ) -> TensorProductFn: 333 | """Non-linear (gated) O(3) equivariant linear tensor product layer. 334 | 335 | The tensor product lifts the input representation to have gating scalars. 
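For example (hypothetical multiplicities): for output_irreps = "8x0e + 8x1o" the internal tensor product emits 8 extra "0e" gating scalars, which e3nn.gate then consumes to gate the "8x1o" part:

    gated = O3TensorProductGate("8x0e + 8x1o", name="gate")
    out = gated(x, y)  # out has irreps 8x0e + 8x1o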
336 | 337 | Args: 338 | output_irreps: Output representation 339 | biases: If true, add biases 340 | scalar_activation: Activation function for scalars 341 | gate_activation: Gate activation function for the higher-order irreps 342 | name: Name of the linear layer params 343 | o3_layer: Tensor product layer type. "tpl", "fctp", "scn" or a custom layer 344 | 345 | Returns: 346 | Function that applies the gated tensor product layer 347 | """ 348 | 349 | if not isinstance(output_irreps, e3nn.Irreps): 350 | output_irreps = e3nn.Irreps(output_irreps) 351 | 352 | # lift output with gating scalars 353 | gate_irreps = e3nn.Irreps( 354 | f"{output_irreps.num_irreps - output_irreps.count('0e')}x0e" 355 | ) 356 | 357 | if o3_layer is None: 358 | o3_layer = config("o3_layer") 359 | 360 | if isinstance(o3_layer, str): 361 | assert o3_layer in O3_LAYERS, f"Unknown O3 layer {o3_layer}." 362 | O3Layer = O3_LAYERS[o3_layer] 363 | else: 364 | O3Layer = o3_layer 365 | 366 | tensor_product = O3Layer( 367 | (gate_irreps + output_irreps).regroup(), 368 | biases=biases, 369 | name=name, 370 | init_fn=init_fn, 371 | ) 372 | 373 | if not scalar_activation: 374 | scalar_activation = jax.nn.silu 375 | if not gate_activation: 376 | gate_activation = jax.nn.sigmoid 377 | 378 | def _gated_tensor_product( 379 | x: IrrepsArray, y: Optional[IrrepsArray] = None, **kwargs 380 | ) -> IrrepsArray: 381 | tp = tensor_product(x, y, **kwargs) 382 | # skip gate if the gating scalars are not reachable 383 | if len(gate_irreps.filter(drop=tp.irreps)) > 0: 384 | return tp 385 | else: 386 | return e3nn.gate(tp, scalar_activation, odd_gate_act=gate_activation) 387 | 388 | return _gated_tensor_product 389 | --------------------------------------------------------------------------------