├── experiments
│   ├── qm9
│   │   ├── __init__.py
│   │   ├── README.md
│   │   ├── utils.py
│   │   └── dataset.py
│   ├── nbody
│   │   ├── __init__.py
│   │   ├── README.md
│   │   ├── data
│   │   │   ├── generate_dataset.py
│   │   │   └── synthetic_sim.py
│   │   ├── datasets.py
│   │   └── utils.py
│   ├── requirements.txt
│   ├── __init__.py
│   └── train.py
├── setup.py
├── pyproject.toml
├── requirements.txt
├── .gitignore
├── segnn_jax
│   ├── config.py
│   ├── __init__.py
│   ├── graph_utils.py
│   ├── irreps_computer.py
│   ├── segnn.py
│   └── blocks.py
├── setup.cfg
├── .pre-commit-config.yaml
├── LICENSE
├── .github
│   └── workflows
│       └── build_branch.yaml
├── tests
│   ├── test_blocks.py
│   ├── conftest.py
│   └── test_segnn.py
├── README.md
└── validate.py
/experiments/qm9/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/experiments/nbody/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | if __name__ == "__main__":
4 | setup()
5 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
2 | dm-haiku>=0.0.9
3 | e3nn-jax==0.17.5
4 | jax[cuda]
5 | jraph==0.0.6.dev0
6 | numpy>=1.23.4
7 | optax==0.1.7
8 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # experiments
2 | *.npy
3 | wandb/
4 | *.out
5 | datasets/
6 |
7 | # dev
8 | .vscode
9 | __pycache__/
10 | *.pyc
11 | venv
12 |
13 | # build
14 | *.whl
15 | build/
16 | dist/
17 | *.egg-info/
18 |
--------------------------------------------------------------------------------
/experiments/qm9/README.md:
--------------------------------------------------------------------------------
1 | # QM9 experiment
2 | Simple baseline experiments from the SEGNN paper. The dataset implementation is taken directly from the [original torch implementation](https://github.com/RobDHess/Steerable-E3-GNN/blob/main/qm9/dataset.py).
3 |
--------------------------------------------------------------------------------
/experiments/nbody/README.md:
--------------------------------------------------------------------------------
1 | # N-body experiments (charged and gravity)
2 | Simple baseline experiments from the SEGNN paper.
3 |
4 | This code is adapted from the N-body implementation of [EGNN](https://github.com/vgsatorras/egnn) and borrows heavily from it. Used under the MIT license.
5 |
--------------------------------------------------------------------------------
/experiments/requirements.txt:
--------------------------------------------------------------------------------
1 | --extra-index-url https://download.pytorch.org/whl/cpu
2 |
3 | --find-links https://data.pyg.org/whl/torch-1.13.1+cpu.html
4 |
5 | e3nn==0.5.0
6 | matplotlib>=3.6.2
7 | rdkit==2022.9.2
8 | torch==1.13.1
9 | torch-cluster==1.6.0
10 | torch-geometric==2.1.0
11 | torch-scatter==2.1.0
12 | torch-sparse==0.6.15
13 | tqdm>=4.64.1
14 | wandb>=0.13.5
15 |
--------------------------------------------------------------------------------
/segnn_jax/config.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 |
3 | __conf = {
4 | "gradient_normalization": "element", # "element" or "path"
5 | "path_normalization": "element", # "element" or "path"
6 | "default_dtype": jnp.float32,
7 | "o3_layer": "tpl", # "tpl" (tp + Linear) or "fctp" (FullyConnected) or "scn" (SCN)
8 | }
9 |
10 |
11 | def config(key):
12 | return __conf[key]
13 |
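14 | # Example (sketch): settings are read via config(), e.g.
15 | #   from segnn_jax.config import config
16 | #   dtype = config("default_dtype")  # jnp.float32 unless changed above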
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | name = segnn_jax
3 | version = attr: segnn_jax.__version__
4 | author = Gianluca Galletti
5 | author_email = g.galletti@tum.de
6 | description = Steerable E(3) GNN in jax
7 | long_description = file: README.md
8 | long_description_content_type = text/markdown
9 | classifiers =
10 | Programming Language :: Python :: 3.8
11 |
12 |
13 | [options]
14 | packages = segnn_jax
15 | python_requires = >=3.8
16 | install_requires =
17 | dm_haiku>=0.0.9
18 | e3nn_jax==0.17.5
19 | jax
20 | jaxlib
21 | jraph==0.0.6.dev0
22 | numpy>=1.23.4
23 | optax==0.1.3
24 |
--------------------------------------------------------------------------------
/segnn_jax/__init__.py:
--------------------------------------------------------------------------------
1 | from .blocks import (
2 | O3TensorProduct,
3 | O3TensorProductFC,
4 | O3TensorProductGate,
5 | O3TensorProductSCN,
6 | )
7 | from .graph_utils import SteerableGraphsTuple
8 | from .irreps_computer import balanced_irreps, weight_balanced_irreps
9 | from .segnn import SEGNN, SEGNNLayer
10 |
11 | __all__ = [
12 | "SEGNN",
13 | "SEGNNLayer",
14 | "O3TensorProduct",
15 | "O3TensorProductFC",
16 | "O3TensorProductSCN",
17 | "O3TensorProductGate",
18 | "weight_balanced_irreps",
19 | "balanced_irreps",
20 | "SteerableGraphsTuple",
21 | ]
22 |
23 | __version__ = "0.7"
24 |
--------------------------------------------------------------------------------
/experiments/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Tuple
2 |
3 | from torch.utils.data import DataLoader
4 |
5 | from .nbody.utils import setup_nbody_data
6 | from .qm9.utils import setup_qm9_data
7 | from .train import train
8 |
9 | __all__ = ["setup_data", "train"]
10 |
11 |
12 | __setup_conf = {
13 | "qm9": setup_qm9_data,
14 | "charged": setup_nbody_data,
15 | "gravity": setup_nbody_data,
16 | }
17 |
18 |
19 | def setup_data(args) -> Tuple[DataLoader, DataLoader, DataLoader, Callable, Callable]:
20 | assert args.dataset in [
21 | "qm9",
22 | "charged",
23 | "gravity",
24 | ], f"Unknown dataset {args.dataset}"
25 | setup_fn = __setup_conf[args.dataset]
26 | return setup_fn(args)
27 |
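28 | # Example (sketch): validate.py consumes this as
29 | #   loader_train, loader_val, loader_test, graph_transform, eval_trn = setup_data(args)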
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | default_language_version:
2 | python: python3
3 | exclude: |
4 | (?x)^(
5 | datasets/.*|
6 | data/.*
7 | )$
8 | repos:
9 | - repo: https://github.com/pre-commit/pre-commit-hooks
10 | rev: v4.0.1
11 | hooks:
12 | - id: check-merge-conflict
13 | - id: check-added-large-files
14 | - id: check-docstring-first
15 | - id: check-json
16 | - id: check-toml
17 | - id: check-xml
18 | - id: check-yaml
19 | - id: trailing-whitespace
20 | - id: end-of-file-fixer
21 | - id: requirements-txt-fixer
22 | - repo: https://github.com/pycqa/isort
23 | rev: 5.12.0
24 | hooks:
25 | - id: isort
26 | args: [ --profile, black ]
27 | - repo: https://github.com/ambv/black
28 | rev: 22.3.0
29 | hooks:
30 | - id: black
31 | - repo: https://github.com/charliermarsh/ruff-pre-commit
32 | rev: 'v0.0.265'
33 | hooks:
34 | - id: ruff
35 | exclude: ^datasets/|^data/|^.git/|^venv/
36 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Gianluca Galletti
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/segnn_jax/graph_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, NamedTuple, Optional
2 |
3 | import e3nn_jax as e3nn
4 | import jax.numpy as jnp
5 | import jax.tree_util as tree
6 | import jraph
7 |
8 |
9 | class SteerableGraphsTuple(NamedTuple):
10 | """Pack (steerable) node and edge attributes with jraph.GraphsTuple."""
11 |
12 | graph: jraph.GraphsTuple
13 | node_attributes: Optional[e3nn.IrrepsArray] = None
14 | edge_attributes: Optional[e3nn.IrrepsArray] = None
15 | # NOTE: additional_message_features is kept in a separate field, otherwise it would
16 | # get updated by jraph.GraphNetwork. The actual graph edges are used only for the messages.
17 | additional_message_features: Optional[e3nn.IrrepsArray] = None
18 |
19 |
20 | def pooling(
21 | graph: jraph.GraphsTuple,
22 | aggregate_fn: Callable = jraph.segment_sum,
23 | ) -> e3nn.IrrepsArray:
24 | """Pools over graph nodes with the specified aggregation.
25 |
26 | Args:
27 | graph: Input graph
28 | aggregate_fn: function used to pool over the nodes
29 |
30 | Returns:
31 | The pooled graph nodes.
32 | """
33 | n_graphs = graph.n_node.shape[0]
34 | graph_idx = jnp.arange(n_graphs)
35 | # Equivalent to jnp.sum(n_node), but jittable
36 | sum_n_node = tree.tree_leaves(graph.nodes)[0].shape[0]
37 | batch = jnp.repeat(graph_idx, graph.n_node, axis=0, total_repeat_length=sum_n_node)
38 | return e3nn.IrrepsArray(
39 | graph.nodes.irreps, aggregate_fn(graph.nodes.array, batch, n_graphs)
40 | )
41 |
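42 | # Example (sketch): graph readout with mean aggregation, e.g.
43 | #   readout = pooling(st_graph.graph, aggregate_fn=jraph.segment_mean)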
--------------------------------------------------------------------------------
/.github/workflows/build_branch.yaml:
--------------------------------------------------------------------------------
1 | name: Build
2 |
3 | on: push
4 |
5 | jobs:
6 | tests:
7 | name: Tests
8 | runs-on: ubuntu-latest
9 | steps:
10 | - uses: actions/checkout@master
11 | - name: Set up Python
12 | uses: actions/setup-python@v3
13 | with:
14 | python-version: "3.10"
15 | - name: Install dependencies
16 | run: >-
17 | python -m
18 | pip install
19 | -r requirements.txt
20 | --user
21 | - name: Install latest e3nn-jax version
22 | run: >-
23 | python -m
24 | pip install -U
25 | e3nn-jax
26 | --user
27 | - name: Install pytest
28 | run: >-
29 | python -m
30 | pip install
31 | pytest
32 | --user
33 | - name: Run tests
34 | run: >-
35 | python -m pytest tests/
36 |
37 | build-publish:
38 | name: Build and publish
39 | runs-on: ubuntu-latest
40 | steps:
41 | - uses: actions/checkout@master
42 | - name: Set up Python
43 | uses: actions/setup-python@v3
44 | with:
45 | python-version: "3.10"
46 | - name: Install build tools
47 | run: >-
48 | python -m
49 | pip install
50 | build
51 | --user
52 | - name: Build wheel
53 | run: >-
54 | python -m
55 | build
56 | --sdist
57 | --wheel
58 | --outdir dist/
59 | - name: Publish to PyPI
60 | if: startsWith(github.ref, 'refs/tags')
61 | uses: pypa/gh-action-pypi-publish@release/v1
62 | with:
63 | password: ${{ secrets.PYPI_TOKEN }}
64 |
--------------------------------------------------------------------------------
/tests/test_blocks.py:
--------------------------------------------------------------------------------
1 | import e3nn_jax as e3nn
2 | import haiku as hk
3 | import pytest
4 | from conftest import assert_equivariant
5 |
6 | from segnn_jax import (
7 | O3TensorProduct,
8 | O3TensorProductFC,
9 | O3TensorProductGate,
10 | O3TensorProductSCN,
11 | )
12 |
13 |
14 | @pytest.mark.parametrize("biases", [False, True])
15 | @pytest.mark.parametrize(
16 | "O3Layer", [O3TensorProduct, O3TensorProductFC, O3TensorProductSCN]
17 | )
18 | def test_linear_layers(key, biases, O3Layer):
19 | def f(x1, x2):
20 | return O3Layer("1x1o", biases=biases)(x1, x2)
21 |
22 | f = hk.without_apply_rng(hk.transform(f))
23 |
24 | v = e3nn.normal("1x1o", key, (5,))
25 | params = f.init(key, v, v)
26 |
27 | def wrapper(x1, x2):
28 | return f.apply(params, x1, x2)
29 |
30 | assert_equivariant(
31 | wrapper,
32 | key,
33 | e3nn.normal("1x1o", key, (5,)),
34 | e3nn.normal("1x1o", key, (5,)),
35 | )
36 |
37 |
38 | @pytest.mark.parametrize("biases", [False, True])
39 | @pytest.mark.parametrize(
40 | "O3Layer", [O3TensorProduct, O3TensorProductFC, O3TensorProductSCN]
41 | )
42 | def test_gated_layers(key, biases, O3Layer):
43 | def f(x1, x2):
44 | return O3TensorProductGate("1x1o", biases=biases, o3_layer=O3Layer)(x1, x2)
45 |
46 | f = hk.without_apply_rng(hk.transform(f))
47 |
48 | v = e3nn.normal("1x1o", key, (5,))
49 | params = f.init(key, v, v)
50 |
51 | def wrapper(x1, x2):
52 | return f.apply(params, x1, x2)
53 |
54 | assert_equivariant(
55 | wrapper,
56 | key,
57 | e3nn.normal("1x1o", key, (5,)),
58 | e3nn.normal("1x1o", key, (5,)),
59 | )
60 |
61 |
62 | if __name__ == "__main__":
63 | pytest.main()
64 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import e3nn_jax as e3nn
4 | import jax
5 | import jax.numpy as jnp
6 | import jraph
7 | import pytest
8 |
9 | from segnn_jax import SteerableGraphsTuple
10 |
11 | os.environ["CUDA_VISIBLE_DEVICES"] = ""
12 | os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
13 | os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
14 |
15 |
16 | @pytest.fixture
17 | def dummy_graph():
18 | def _rand_graph(n_graphs: int = 1, attr_irreps: str = "1x0e + 1x1o"):
19 | attr_irreps = e3nn.Irreps(attr_irreps)
20 | return SteerableGraphsTuple(
21 | graph=jraph.GraphsTuple(
22 | nodes=e3nn.IrrepsArray("1x1o", jnp.ones((n_graphs * 5, 3))),
23 | edges=None,
24 | senders=jnp.zeros((10 * n_graphs,), dtype=jnp.int32),
25 | receivers=jnp.zeros((10 * n_graphs,), dtype=jnp.int32),
26 | n_node=jnp.array([5] * n_graphs),
27 | n_edge=jnp.array([10] * n_graphs),
28 | globals=None,
29 | ),
30 | additional_message_features=None,
31 | edge_attributes=None,
32 | node_attributes=e3nn.IrrepsArray(
33 | attr_irreps, jnp.ones((n_graphs * 5, attr_irreps.dim))
34 | ),
35 | )
36 |
37 | return _rand_graph
38 |
39 |
40 | @pytest.fixture
41 | def key():
42 | return jax.random.PRNGKey(0)
43 |
44 |
45 | def assert_equivariant(fun, key, *args, **kwargs):
46 | try:
47 | from e3nn_jax.utils import assert_equivariant as _assert_equivariant
48 |
49 | _assert_equivariant(fun, key, *args, **kwargs)
50 | except (ImportError, AttributeError):
51 | from e3nn_jax.util import assert_equivariant as _assert_equivariant
52 |
53 | _assert_equivariant(fun, key, args_in=args, **kwargs)
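54 |
55 | # NOTE: assert_equivariant moved from e3nn_jax.util to e3nn_jax.utils (with a small
56 | # signature change) across e3nn-jax versions; this shim keeps the tests working on both.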
54 |
--------------------------------------------------------------------------------
/tests/test_segnn.py:
--------------------------------------------------------------------------------
1 | import e3nn_jax as e3nn
2 | import haiku as hk
3 | import pytest
4 | from conftest import assert_equivariant
5 |
6 | from segnn_jax import (
7 | SEGNN,
8 | O3TensorProduct,
9 | O3TensorProductFC,
10 | O3TensorProductSCN,
11 | weight_balanced_irreps,
12 | )
13 |
14 |
15 | @pytest.mark.parametrize("task", ["graph", "node"])
16 | @pytest.mark.parametrize("norm", ["none", "instance"])
17 | @pytest.mark.parametrize(
18 | "O3Layer", [O3TensorProduct, O3TensorProductFC, O3TensorProductSCN]
19 | )
20 | def test_segnn_equivariance(key, dummy_graph, task, norm, O3Layer):
21 | scn = O3Layer == O3TensorProductSCN
22 |
23 | hidden_irreps = weight_balanced_irreps(
24 | 8, e3nn.Irreps.spherical_harmonics(1), use_sh=not scn
25 | )
26 |
27 | def segnn(x):
28 | return SEGNN(
29 | hidden_irreps=hidden_irreps,
30 | output_irreps=e3nn.Irreps("1x1o"),
31 | num_layers=1,
32 | task=task,
33 | norm=norm,
34 | o3_layer=O3Layer,
35 | )(x)
36 |
37 | segnn = hk.without_apply_rng(hk.transform_with_state(segnn))
38 |
39 | if scn:
40 | attr_irreps = e3nn.Irreps("1x1o")
41 | else:
42 | attr_irreps = e3nn.Irreps("1x0e+1x1o")
43 |
44 | graph = dummy_graph(attr_irreps=attr_irreps)
45 | params, segnn_state = segnn.init(key, graph)
46 |
47 | def wrapper(x):
48 | if scn:
49 | attrs = e3nn.IrrepsArray(attr_irreps, x.array)
50 | else:
51 | attrs = e3nn.spherical_harmonics(attr_irreps, x, normalize=True)
52 | st_graph = graph._replace(
53 | graph=graph.graph._replace(nodes=x),
54 | node_attributes=attrs,
55 | )
56 | y, _ = segnn.apply(params, segnn_state, st_graph)
57 | return e3nn.IrrepsArray("1x1o", y)
58 |
59 | assert_equivariant(wrapper, key, e3nn.normal("1x1o", key, (5,)))
60 |
61 |
62 | if __name__ == "__main__":
63 | pytest.main()
64 |
--------------------------------------------------------------------------------
/segnn_jax/irreps_computer.py:
--------------------------------------------------------------------------------
1 | from math import prod
2 |
3 | from e3nn_jax import Irreps
4 |
5 |
6 | def balanced_irreps(lmax: int, feature_size: int, use_sh: bool = True) -> Irreps:
7 | """Allocates irreps uniformely up until level lmax with budget feature_size."""
8 | irreps = ["0e"]
9 | n_irreps = 1 + (lmax if use_sh else lmax * 2)
10 | total_dim = 0
11 | for level in range(1, lmax + 1):
12 | dim = 2 * level + 1
13 | multi = int(feature_size / dim / n_irreps)
14 | if multi == 0:
15 | break
16 | if use_sh:
17 | irreps.append(f"{multi}x{level}{'e' if (level % 2) == 0 else 'o'}")
18 | total_dim += multi * dim
19 | else:
20 | irreps.append(f"{multi}x{level}e+{multi}x{level}o")
21 | total_dim += multi * dim * 2
22 |
23 | # add scalars to fill missing dimensions
24 | irreps[0] = f"{feature_size - total_dim}x{irreps[0]}"
25 |
26 | return Irreps("+".join(irreps))
27 |
28 |
29 | def weight_balanced_irreps(
30 | scalar_units: int, irreps_right: Irreps, use_sh: bool = True, lmax: int = None
31 | ) -> Irreps:
32 | """
33 | Determines irreps_left such that the parametrized tensor product
34 | Linear(tensor_product(irreps_left, irreps_right))
35 | has (at least) scalar_units weights.
36 |
37 | Args:
38 | scalar_units: number of desired weights
39 | irreps_right: irreps of the right tensor
40 | use_sh: whether to use spherical harmonics
41 | lmax: maximum level of spherical harmonics
42 | """
43 | # irrep order
44 | if lmax is None:
45 | lmax = irreps_right.lmax
46 | # linear layer with square weight matrix
47 | linear_weights = scalar_units**2
48 | # increase hidden multiplicities until there are enough weights
49 | n = 0
50 | while True:
51 | n += 1
52 | if use_sh:
53 | irreps_left = (
54 | (Irreps.spherical_harmonics(lmax) * n).sort().irreps.simplify()
55 | )
56 | else:
57 | irreps_left = balanced_irreps(lmax, n)
58 | # number of paths
59 | tp_weights = sum(
60 | prod([irreps_left[i_1].mul ** 2, irreps_right[i_2].mul])
61 | for i_1, (_, ir_1) in enumerate(irreps_left)
62 | for i_2, (_, ir_2) in enumerate(irreps_right)
63 | for _, (_, ir_out) in enumerate(irreps_left)
64 | if ir_out in ir_1 * ir_2
65 | )
66 | if tp_weights >= linear_weights:
67 | break
68 | return Irreps(irreps_left)
69 |
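70 | # Example (sketch): hidden irreps for 64 scalar units against l=1 attributes, e.g.
71 | #   weight_balanced_irreps(64, Irreps.spherical_harmonics(1))
72 | # returns "Nx0e + Nx1o" with N chosen so the tensor product has at least 64**2 weights.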
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Steerable E(3) GNN in jax
2 | Reimplementation of [SEGNN](https://arxiv.org/abs/2110.02905) in jax. Original work by Johannes Brandstetter, Rob Hesselink, Elise van der Pol, Erik Bekkers and Max Welling.
3 |
4 | ## Why jax?
5 | **40-50% faster** inference and training compared to the [original torch implementation](https://github.com/RobDHess/Steerable-E3-GNN). Additionally, the jax implementation interoperates with JAX-MD.
6 |
7 | ## Installation
8 | ```
9 | python -m pip install segnn-jax
10 | ```
11 |
12 | Or clone this repository and build locally
13 | ```
14 | python -m pip install -e .
15 | ```
16 |
17 | ### GPU support
18 | Upgrade `jax` to the gpu version
19 | ```
20 | pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
21 | ```
22 |
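23 | ## Basic usage
24 | A minimal sketch of how a model is assembled, mirroring `tests/test_segnn.py`. The input `st_graph` is a `SteerableGraphsTuple` (see `segnn_jax/graph_utils.py`) and is assumed to be built beforehand, e.g. with the transforms in `experiments/`.
25 | ```
26 | import e3nn_jax as e3nn
27 | import haiku as hk
28 | import jax
29 |
30 | from segnn_jax import SEGNN, weight_balanced_irreps
31 |
32 | # hidden irreps with at least 64**2 tensor product weights
33 | hidden_irreps = weight_balanced_irreps(64, e3nn.Irreps.spherical_harmonics(1))
34 |
35 | def model(st_graph):
36 |     return SEGNN(
37 |         hidden_irreps=hidden_irreps,
38 |         output_irreps=e3nn.Irreps("1x1o"),
39 |         num_layers=4,
40 |         task="graph",
41 |     )(st_graph)
42 |
43 | model = hk.without_apply_rng(hk.transform_with_state(model))
44 |
45 | # st_graph: a pre-built SteerableGraphsTuple
46 | params, state = model.init(jax.random.PRNGKey(0), st_graph)
47 | y, state = model.apply(params, state, st_graph)
48 | ```
49 |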
50 | ## Validation
51 | The N-body (charged and gravity) and QM9 experiments from the original paper are included for completeness.
52 |
53 | ### Results
54 | Charged is run on 5 bodies, gravity on 100 bodies. QM9 has graphs of variable size, so in jax the samples are padded to the maximum size. The loss is MSE for charged and gravity, and MAE for QM9.
55 |
56 | Times are remeasured on a Quadro RTX 4000, __model only__, on batches of 100 graphs, in (global) single precision.
57 |
58 | |  | torch (original) loss | torch (original) inference [ms] | jax (ours) loss | jax (ours) inference [ms] |
59 | | --- | --- | --- | --- | --- |
60 | | charged (position) | .0043 | 21.22 | .0045 | 3.77 |
61 | | gravity (position) | .265 | 60.55 | .264 | 41.72 |
62 | | QM9 (alpha) | .066* | 82.53 | .082 | 105.98** |
63 |
64 | \* rerun under the same conditions
65 |
66 | \*\* padded (naive)
67 |
68 | ### Validation install
69 |
70 | The experiments are only included in the GitHub repo, so it needs to be cloned first.
71 | ```
72 | git clone https://github.com/gerkone/segnn-jax
73 | ```
74 |
75 | They are adapted from the original implementation, so `torch` and `torch_geometric` are additionally needed (CPU versions are enough).
76 | ```
77 | python -m pip install -r experiments/requirements.txt
78 | ```
79 |
80 | ### Datasets
81 | QM9 is automatically downloaded and processed when running the respective experiment.
82 |
83 | The N-body datasets have to be generated locally from the directory [experiments/nbody/data](experiments/nbody/data) (it will take some time, especially for `gravity`).
84 | #### Charged dataset (5 bodies, 10000 training samples)
85 | ```
86 | python3 -u generate_dataset.py --simulation=charged --seed=43
87 | ```
88 | #### Gravity dataset (100 bodies, 10000 training samples)
89 | ```
90 | python3 -u generate_dataset.py --simulation=gravity --n-balls=100 --seed=43
91 | ```
92 |
93 | ### Notes
94 | On `jax<=0.4.6`, the `jit`-`pjit` merge can be deactivated, which makes training faster (on N-body). This looks like an issue with data loading and the validation training loop; it does not affect SEGNN itself.
95 | ```
96 | export JAX_JIT_PJIT_API_MERGE=0
97 | ```
98 |
99 | ### Usage
100 | #### N-body (charged)
101 | ```
102 | python validate.py --dataset=charged --epochs=200 --max-samples=3000 --lmax-hidden=1 --lmax-attributes=1 --layers=4 --units=64 --norm=none --batch-size=100 --lr=5e-3 --weight-decay=1e-12
103 | ```
104 |
105 | #### N-body (gravity)
106 | ```
107 | python validate.py --dataset=gravity --epochs=100 --target=pos --max-samples=10000 --lmax-hidden=1 --lmax-attributes=1 --layers=4 --units=64 --norm=none --batch-size=100 --lr=5e-3 --weight-decay=1e-12 --neighbours=5 --n-bodies=100
108 | ```
109 |
110 | #### QM9
111 | ```
112 | python validate.py --dataset=qm9 --epochs=1000 --target=alpha --lmax-hidden=2 --lmax-attributes=3 --layers=7 --units=128 --norm=instance --batch-size=128 --lr=5e-4 --weight-decay=1e-8 --lr-scheduling
113 | ```
114 |
115 | (configurations used in validation)
116 |
117 |
118 | ## Acknowledgments
119 | - [e3nn_jax](https://github.com/e3nn/e3nn-jax) made this reimplementation possible.
120 | - [Artur Toshev](https://github.com/arturtoshev) and [Johannes Brandstetter](https://github.com/brandstetter-johannes), for support.
121 |
--------------------------------------------------------------------------------
/experiments/nbody/data/generate_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Generate charged and gravity datasets.
3 |
4 | charged: python3 generate_dataset.py --simulation=charged --num-train=10000 --seed=43
5 | gravity: python3 generate_dataset.py --simulation=gravity --num-train=10000 --n-balls=100 --seed=43
6 | """
7 | import argparse
8 | import time
9 |
10 | import numpy as np
11 | from synthetic_sim import ChargedParticlesSim, GravitySim
12 |
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument(
15 | "--simulation", type=str, default="charged", help="What simulation to generate."
16 | )
17 | parser.add_argument(
18 | "--num-train",
19 | type=int,
20 | default=10000,
21 | help="Number of training simulations to generate.",
22 | )
23 | parser.add_argument(
24 | "--num-valid",
25 | type=int,
26 | default=2000,
27 | help="Number of validation simulations to generate.",
28 | )
29 | parser.add_argument(
30 | "--num-test", type=int, default=2000, help="Number of test simulations to generate."
31 | )
32 | parser.add_argument("--length", type=int, default=5000, help="Length of trajectory.")
33 | parser.add_argument(
34 | "--length-test", type=int, default=5000, help="Length of test set trajectory."
35 | )
36 | parser.add_argument(
37 | "--sample-freq", type=int, default=100, help="How often to sample the trajectory."
38 | )
39 | parser.add_argument(
40 | "--n-balls", type=int, default=5, help="Number of balls in the simulation."
41 | )
42 | parser.add_argument("--seed", type=int, default=42, help="Random seed.")
43 | parser.add_argument(
44 | "--initial-vel", type=int, default=1, help="consider initial velocity"
45 | )
46 |
47 | args = parser.parse_args()
48 |
49 | initial_vel_norm = 0.5
50 | if not args.initial_vel:
51 | initial_vel_norm = 1e-16
52 |
53 | if args.simulation == "charged":
54 | sim = ChargedParticlesSim(
55 | noise_var=0.0, n_balls=args.n_balls, vel_norm=initial_vel_norm
56 | )
57 | suffix = "_charged"
58 | elif args.simulation == "gravity":
59 | sim = GravitySim(noise_var=0.0, n_balls=args.n_balls, vel_norm=initial_vel_norm)
60 | suffix = "_gravity"
61 | else:
62 | raise ValueError("Simulation {} not implemented".format(args.simulation))
63 |
64 | suffix += str(args.n_balls) + "_initvel%d" % args.initial_vel
65 | np.random.seed(args.seed)
66 |
67 | print(suffix)
68 |
69 |
70 | def generate_dataset(num_sims, length, sample_freq):
71 | loc_all = list()
72 | vel_all = list()
73 | edges_all = list()
74 | charges_all = list()
75 | for i in range(num_sims):
76 | t = time.time()
77 | loc, vel, edges, charges = sim.sample_trajectory(
78 | T=length, sample_freq=sample_freq
79 | )
80 |
81 | loc_all.append(loc)
82 | vel_all.append(vel)
83 | edges_all.append(edges)
84 | charges_all.append(charges)
85 |
86 | if i % 100 == 0:
87 | print("Iter: {}, Simulation time: {}".format(i, time.time() - t))
88 |
89 | charges_all = np.stack(charges_all)
90 | loc_all = np.stack(loc_all)
91 | vel_all = np.stack(vel_all)
92 | edges_all = np.stack(edges_all)
93 |
94 | return loc_all, vel_all, edges_all, charges_all
95 |
96 |
97 | if __name__ == "__main__":
98 | print("Generating {} training simulations".format(args.num_train))
99 | loc_train, vel_train, edges_train, charges_train = generate_dataset(
100 | args.num_train, args.length, args.sample_freq
101 | )
102 |
103 | print("Generating {} validation simulations".format(args.num_valid))
104 | loc_valid, vel_valid, edges_valid, charges_valid = generate_dataset(
105 | args.num_valid, args.length, args.sample_freq
106 | )
107 |
108 | print("Generating {} test simulations".format(args.num_test))
109 | loc_test, vel_test, edges_test, charges_test = generate_dataset(
110 | args.num_test, args.length_test, args.sample_freq
111 | )
112 |
113 | np.save("loc_train" + suffix + ".npy", loc_train)
114 | np.save("vel_train" + suffix + ".npy", vel_train)
115 | np.save("edges_train" + suffix + ".npy", edges_train)
116 | np.save("q_train" + suffix + ".npy", charges_train)
117 |
118 | np.save("loc_valid" + suffix + ".npy", loc_valid)
119 | np.save("vel_valid" + suffix + ".npy", vel_valid)
120 | np.save("edges_valid" + suffix + ".npy", edges_valid)
121 | np.save("q_valid" + suffix + ".npy", charges_valid)
122 |
123 | np.save("loc_test" + suffix + ".npy", loc_test)
124 | np.save("vel_test" + suffix + ".npy", vel_test)
125 | np.save("edges_test" + suffix + ".npy", edges_test)
126 | np.save("q_test" + suffix + ".npy", charges_test)
127 |
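128 | # NOTE: the .npy files are written to the current working directory;
129 | # experiments/nbody/datasets.py expects them in experiments/nbody/data/
130 | # under this exact suffix naming, so run this script from that directory.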
--------------------------------------------------------------------------------
/experiments/qm9/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Tuple
2 |
3 | import e3nn_jax as e3nn
4 | import jax.numpy as jnp
5 | import jraph
6 | from torch_geometric.data import Data
7 | from torch_geometric.loader import DataLoader
8 |
9 | from segnn_jax import SteerableGraphsTuple
10 |
11 | from .dataset import QM9
12 |
13 |
14 | def QM9GraphTransform(
15 | args,
16 | max_batch_nodes: int,
17 | max_batch_edges: int,
18 | train_trn: Callable,
19 | ) -> Callable:
20 | """
21 | Build a function that converts torch DataBatch into SteerableGraphsTuple.
22 |
23 | Mostly a quick fix out of laziness: rewriting QM9 in jax is not trivial.
24 | """
25 | attribute_irreps = e3nn.Irreps.spherical_harmonics(args.lmax_attributes)
26 |
27 | def _to_steerable_graph(
28 | data: Data, training: bool = True
29 | ) -> Tuple[SteerableGraphsTuple, jnp.ndarray]:
30 | ptr = jnp.array(data.ptr)
31 | senders = jnp.array(data.edge_index[0])
32 | receivers = jnp.array(data.edge_index[1])
33 | graph = jraph.GraphsTuple(
34 | nodes=e3nn.IrrepsArray(args.node_irreps, jnp.array(data.x)),
35 | edges=None,
36 | senders=senders,
37 | receivers=receivers,
38 | n_node=jnp.diff(ptr),
39 | n_edge=jnp.diff(jnp.sum(senders[:, jnp.newaxis] < ptr, axis=0)),
40 | globals=None,
41 | )
42 | # pad for jax static shapes (jraph requires at least one extra padding graph/node, hence the +1)
43 | node_attr_pad = ((0, max_batch_nodes - jnp.sum(graph.n_node) + 1), (0, 0))
44 | edge_attr_pad = ((0, max_batch_edges - jnp.sum(graph.n_edge) + 1), (0, 0))
45 | graph = jraph.pad_with_graphs(
46 | graph,
47 | n_node=max_batch_nodes + 1,
48 | n_edge=max_batch_edges + 1,
49 | n_graph=graph.n_node.shape[0] + 1,
50 | )
51 |
52 | node_attributes = e3nn.IrrepsArray(
53 | attribute_irreps, jnp.pad(jnp.array(data.node_attr), node_attr_pad)
54 | )
55 | # scalar attribute to 1 by default
56 | node_attributes = e3nn.IrrepsArray(
57 | node_attributes.irreps, node_attributes.array.at[:, 0].set(1.0)
58 | )
59 |
60 | additional_message_features = e3nn.IrrepsArray(
61 | args.additional_message_irreps,
62 | jnp.pad(jnp.array(data.additional_message_features), edge_attr_pad),
63 | )
64 | edge_attributes = e3nn.IrrepsArray(
65 | attribute_irreps, jnp.pad(jnp.array(data.edge_attr), edge_attr_pad)
66 | )
67 |
68 | st_graph = SteerableGraphsTuple(
69 | graph=graph,
70 | node_attributes=node_attributes,
71 | edge_attributes=edge_attributes,
72 | additional_message_features=additional_message_features,
73 | )
74 |
75 | # pad targets
76 | target = jnp.array(data.y)
77 | if args.task == "node":
78 | target = jnp.pad(target, [(0, max_batch_nodes - target.shape[0] - 1)])
79 | if args.task == "graph":
80 | target = jnp.append(target, 0)
81 |
82 | # normalize targets
83 | if training and train_trn is not None:
84 | target = train_trn(target)
85 |
86 | return st_graph, target
87 |
88 | return _to_steerable_graph
89 |
90 |
91 | def setup_qm9_data(
92 | args,
93 | ) -> Tuple[DataLoader, DataLoader, DataLoader, Callable, Callable]:
94 | dataset_train = QM9(
95 | "datasets",
96 | args.target,
97 | args.radius,
98 | partition="train",
99 | lmax_attr=args.lmax_attributes,
100 | feature_type=args.feature_type,
101 | )
102 | dataset_val = QM9(
103 | "datasets",
104 | args.target,
105 | args.radius,
106 | partition="valid",
107 | lmax_attr=args.lmax_attributes,
108 | feature_type=args.feature_type,
109 | )
110 | dataset_test = QM9(
111 | "datasets",
112 | args.target,
113 | args.radius,
114 | partition="test",
115 | lmax_attr=args.lmax_attributes,
116 | feature_type=args.feature_type,
117 | )
118 |
119 | # 0.8 (un)safety factor: rare oversized batches exceed the padding and force a re-jit
120 | max_batch_nodes = int(0.8 * sum(dataset_test.top_n_nodes(args.batch_size)))
121 | max_batch_edges = int(0.8 * sum(dataset_test.top_n_edges(args.batch_size)))
122 |
123 | target_mean, target_mad = dataset_train.calc_stats()
124 |
125 | def remove_offsets(t):
126 | return (t - target_mean) / target_mad
127 |
128 | # not great and very slow due to huge padding
129 | loader_train = DataLoader(
130 | dataset_train,
131 | batch_size=args.batch_size,
132 | shuffle=True,
133 | drop_last=True,
134 | )
135 | loader_val = DataLoader(
136 | dataset_val,
137 | batch_size=args.batch_size,
138 | shuffle=False,
139 | drop_last=True,
140 | )
141 | loader_test = DataLoader(
142 | dataset_test,
143 | batch_size=args.batch_size,
144 | shuffle=False,
145 | drop_last=True,
146 | )
147 |
148 | to_graphs_tuple = QM9GraphTransform(
149 | args,
150 | max_batch_nodes=max_batch_nodes,
151 | max_batch_edges=max_batch_edges,
152 | train_trn=remove_offsets,
153 | )
154 |
155 | def add_offsets(p):
156 | return p * target_mad + target_mean
157 |
158 | return loader_train, loader_val, loader_test, to_graphs_tuple, add_offsets
159 |
--------------------------------------------------------------------------------
/experiments/train.py:
--------------------------------------------------------------------------------
1 | import time
2 | from functools import partial
3 | from typing import Callable, Optional, Tuple
4 |
5 | import haiku as hk
6 | import jax
7 | import jax.numpy as jnp
8 | import jraph
9 | import optax
10 | from jax import jit
11 |
12 | from segnn_jax import SteerableGraphsTuple
13 |
14 |
15 | @partial(jit, static_argnames=["model_fn", "criterion", "do_mask", "eval_trn"])
16 | def loss_fn_wrapper(
17 | params: hk.Params,
18 | state: hk.State,
19 | st_graph: SteerableGraphsTuple,
20 | target: jnp.ndarray,
21 | model_fn: Callable,
22 | criterion: Callable,
23 | do_mask: bool = True,
24 | eval_trn: Optional[Callable] = None,
25 | ) -> Tuple[float, hk.State]:
26 | pred, state = model_fn(params, state, st_graph)
27 | if eval_trn is not None:
28 | pred = eval_trn(pred)
29 |
30 | if do_mask:
31 | if target.shape == st_graph.graph.nodes.shape:
32 | mask = jraph.get_node_padding_mask(st_graph.graph)
33 | else:
34 | mask = jraph.get_graph_padding_mask(st_graph.graph)
35 | # broadcast mask for vector targets
36 | if len(pred.shape) == 2:
37 | mask = mask[:, jnp.newaxis]
38 | else:
39 | mask = jnp.ones_like(target)
40 |
41 | target = target * mask
42 | pred = pred * mask
43 |
44 | assert target.shape == pred.shape
45 | return jnp.sum(criterion(pred, target)) / jnp.count_nonzero(mask), state
46 |
47 |
48 | @partial(jit, static_argnames=["loss_fn", "opt_update"])
49 | def update(
50 | params: hk.Params,
51 | state: hk.State,
52 | graph: SteerableGraphsTuple,
53 | target: jnp.ndarray,
54 | opt_state: optax.OptState,
55 | loss_fn: Callable,
56 | opt_update: Callable,
57 | ) -> Tuple[float, hk.Params, hk.State, optax.OptState]:
58 | (loss, state), grads = jax.value_and_grad(loss_fn, has_aux=True)(
59 | params, state, graph, target
60 | )
61 | updates, opt_state = opt_update(grads, opt_state, params)
62 | return loss, optax.apply_updates(params, updates), state, opt_state
63 |
64 |
65 | def evaluate(
66 | loader,
67 | params: hk.Params,
68 | state: hk.State,
69 | loss_fn: Callable,
70 | graph_transform: Callable,
71 | ) -> Tuple[float, float]:
72 | eval_loss = 0.0
73 | eval_times = 0.0
74 | for data in loader:
75 | graph, target = graph_transform(data, training=False)
76 | eval_start = time.perf_counter_ns()
77 | loss, _ = jax.lax.stop_gradient(loss_fn(params, state, graph, target))
78 | eval_loss += jax.block_until_ready(loss)
79 | eval_times += (time.perf_counter_ns() - eval_start) / 1e6
80 | return eval_times / len(loader), eval_loss / len(loader)
81 |
82 |
83 | def train(
84 | key,
85 | segnn,
86 | loader_train,
87 | loader_val,
88 | loader_test,
89 | loss_fn,
90 | eval_loss_fn,
91 | graph_transform,
92 | args,
93 | ):
94 | init_graph, _ = graph_transform(next(iter(loader_train)))
95 | params, segnn_state = segnn.init(key, init_graph)
96 |
97 | print(
98 | f"Starting {args.epochs} epochs "
99 | f"with {hk.data_structures.tree_size(params)} parameters.\n"
100 | "Jitting..."
101 | )
102 |
103 | total_steps = args.epochs * len(loader_train)
104 |
105 | # set up learning rate and optimizer
106 | learning_rate = args.lr
107 | if args.lr_scheduling:
108 | learning_rate = optax.piecewise_constant_schedule(
109 | learning_rate,
110 | boundaries_and_scales={
111 | int(total_steps * 0.7): 0.1,
112 | int(total_steps * 0.9): 0.1,
113 | },
114 | )
115 | opt_init, opt_update = optax.adamw(
116 | learning_rate=learning_rate, weight_decay=args.weight_decay
117 | )
118 |
119 | model_fn = jit(segnn.apply)
120 |
121 | loss_fn = partial(loss_fn, model_fn=model_fn)
122 | eval_loss_fn = partial(eval_loss_fn, model_fn=model_fn)
123 | update_fn = partial(update, loss_fn=loss_fn, opt_update=opt_update)
124 | eval_fn = partial(evaluate, loss_fn=eval_loss_fn, graph_transform=graph_transform)
125 |
126 | opt_state = opt_init(params)
127 | avg_time = []
128 | best_val = 1e10
129 |
130 | for e in range(args.epochs):
131 | train_loss = 0.0
132 | epoch_start = time.perf_counter_ns()
133 | for data in loader_train:
134 | graph, target = graph_transform(data)
135 | loss, params, segnn_state, opt_state = update_fn(
136 | params=params,
137 | state=segnn_state,
138 | graph=graph,
139 | target=target,
140 | opt_state=opt_state,
141 | )
142 | train_loss += jax.block_until_ready(loss)
143 | train_loss /= len(loader_train)
144 | epoch_time = (time.perf_counter_ns() - epoch_start) / 1e9
145 |
146 | print(
147 | f"[Epoch {e+1:>4}] train loss {train_loss:.6f}, epoch {epoch_time:.2f}s",
148 | end="",
149 | )
150 | if e % args.val_freq == 0:
151 | eval_time, val_loss = eval_fn(loader_val, params, segnn_state)
152 | avg_time.append(eval_time)
153 | tag = ""
154 | if val_loss < best_val:
155 | best_val = val_loss
156 | tag = " (best)"
157 | _, test_loss_ckp = eval_fn(loader_test, params, segnn_state)
158 | print(f" - val loss {val_loss:.6f}{tag}, infer {eval_time:.2f}ms", end="")
159 |
160 | print()
161 |
162 | test_loss = 0
163 | _, test_loss = eval_fn(loader_test, params, segnn_state)
164 | # ignore compilation time
165 | avg_time = avg_time[1:] if len(avg_time) > 1 else avg_time
166 | avg_time = sum(avg_time) / len(avg_time)
167 | print(
168 | "Training done.\n"
169 | f"Final test loss {test_loss:.6f} - checkpoint test loss {test_loss_ckp:.6f}.\n"
170 | f"Average (model) eval time {avg_time:.2f}ms"
171 | )
172 |
--------------------------------------------------------------------------------
/experiments/nbody/datasets.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import pathlib
3 | from abc import ABC, abstractmethod
4 | from typing import Sequence, Tuple, Union
5 |
6 | import numpy as np
7 |
8 | DATA_DIR = "data"
9 |
10 |
11 | class BaseDataset(ABC):
12 | """Abstract n-body dataset class."""
13 |
14 | def __init__(
15 | self,
16 | data_type,
17 | partition="train",
18 | max_samples=1e8,
19 | dataset_name="small",
20 | n_bodies=5,
21 | normalize=False,
22 | ):
23 | self.partition = partition
24 | if self.partition == "val":
25 | self.suffix = "valid"
26 | else:
27 | self.suffix = self.partition
28 | self.dataset_name = dataset_name
29 | self.suffix += f"_{data_type}{n_bodies}_initvel1"
30 | self.data_type = data_type
31 | self.max_samples = int(max_samples)
32 | self.normalize = normalize
33 |
34 | self.data = None
35 |
36 | def get_n_nodes(self):
37 | return self.data[0].shape[2]
38 |
39 | def _get_partition_frames(self) -> Tuple[int, int]:
40 | if self.dataset_name == "default":
41 | frame_0, frame_target = 6, 8
42 | elif self.dataset_name == "small":
43 | frame_0, frame_target = 30, 40
44 | elif self.dataset_name == "small_out_dist":
45 | frame_0, frame_target = 20, 30
46 | else:
47 | raise Exception("Wrong dataset partition %s" % self.dataset_name)
48 |
49 | return frame_0, frame_target
50 |
51 | def __len__(self) -> int:
52 | return len(self.data[0])
53 |
54 | def _load(self) -> Tuple[np.ndarray, ...]:
55 | filepath = pathlib.Path(__file__).parent.resolve()
56 |
57 | loc = np.load(osp.join(filepath, DATA_DIR, "loc_" + self.suffix + ".npy"))
58 | vel = np.load(osp.join(filepath, DATA_DIR, "vel_" + self.suffix + ".npy"))
59 | edges = np.load(osp.join(filepath, DATA_DIR, "edges_" + self.suffix + ".npy"))
60 | q = np.load(osp.join(filepath, DATA_DIR, "q_" + self.suffix + ".npy"))
61 |
62 | return loc, vel, edges, q
63 |
64 | def _normalize(self, x: np.ndarray) -> np.ndarray:
65 | std = x.std(axis=0)
66 | x = x - x.mean(axis=0)
67 | return np.divide(x, std, out=x, where=std != 0)
68 |
69 | @abstractmethod
70 | def load(self):
71 | raise NotImplementedError
72 |
73 | @abstractmethod
74 | def preprocess(self, *args) -> Tuple[np.ndarray, ...]:
75 | raise NotImplementedError
76 |
77 |
78 | class ChargedDataset(BaseDataset):
79 | """N-body charged dataset class."""
80 |
81 | def __init__(
82 | self,
83 | partition="train",
84 | max_samples=1e8,
85 | dataset_name="small",
86 | n_bodies=5,
87 | normalize=False,
88 | ):
89 | super().__init__(
90 | "charged", partition, max_samples, dataset_name, n_bodies, normalize
91 | )
92 | self.data, self.edges = self.load()
93 |
94 | def preprocess(self, *args) -> Tuple[np.ndarray, ...]:
95 | # swap n_nodes - n_features dimensions
96 | loc, vel, edges, charges = args
97 | loc, vel = np.transpose(loc, (0, 1, 3, 2)), np.transpose(vel, (0, 1, 3, 2))
98 | n_nodes = loc.shape[2]
99 | loc = loc[0 : self.max_samples, :, :, :] # limit number of samples
100 | vel = vel[0 : self.max_samples, :, :, :] # speed when starting the trajectory
101 | charges = charges[0 : self.max_samples]
102 | edge_attr = []
103 |
104 | # Initialize edges and edge_attributes
105 | rows, cols = [], []
106 | for i in range(n_nodes):
107 | for j in range(n_nodes):
108 | if i != j:
109 | edge_attr.append(edges[:, i, j])
110 | rows.append(i)
111 | cols.append(j)
112 | edges = [rows, cols]
113 | # swap n_nodes - batch_size and add nf dimension
114 | edge_attr = np.array(edge_attr).T
115 | edge_attr = np.expand_dims(edge_attr, 2)
116 |
117 | if self.normalize:
118 | loc = self._normalize(loc)
119 | vel = self._normalize(vel)
120 | charges = self._normalize(charges)
121 |
122 | return loc, vel, edge_attr, edges, charges
123 |
124 | def load(self):
125 | loc, vel, edges, q = self._load()
126 |
127 | loc, vel, edge_attr, edges, charges = self.preprocess(loc, vel, edges, q)
128 | return (loc, vel, edge_attr, charges), edges
129 |
130 | def __getitem__(self, i: Union[Sequence, int]) -> Tuple[np.ndarray, ...]:
131 | frame_0, frame_target = self._get_partition_frames()
132 |
133 | loc, vel, edge_attr, charges = self.data
134 |
135 | loc, vel, edge_attr, charges, target_loc = (
136 | loc[i, frame_0],
137 | vel[i, frame_0],
138 | edge_attr[i],
139 | charges[i],
140 | loc[i, frame_target],
141 | )
142 |
143 | if not isinstance(i, int):
144 | # flatten batch and nodes dimensions
145 | loc = loc.reshape(-1, *loc.shape[2:])
146 | vel = vel.reshape(-1, *vel.shape[2:])
147 | edge_attr = edge_attr.reshape(-1, *edge_attr.shape[2:])
148 | charges = charges.reshape(-1, *charges.shape[2:])
149 | target_loc = target_loc.reshape(-1, *target_loc.shape[2:])
150 |
151 | return loc, vel, edge_attr, charges, target_loc
152 |
153 |
154 | class GravityDataset(BaseDataset):
155 | """N-body gravity dataset class."""
156 |
157 | def __init__(
158 | self,
159 | partition="train",
160 | max_samples=1e8,
161 | dataset_name="small",
162 | n_bodies=100,
163 | neighbours=6,
164 | target="pos",
165 | normalize=False,
166 | ):
167 | super().__init__(
168 | "gravity", partition, max_samples, dataset_name, n_bodies, normalize
169 | )
170 | assert target in ["pos", "force"]
171 | self.neighbours = int(neighbours)
172 | self.target = target
173 | self.data = self.load()
174 |
175 | def preprocess(self, *args) -> Tuple[np.ndarray, ...]:
176 | loc, vel, force, mass = args
177 | # NOTE this was in the original paper but does not look right
178 | # loc = np.transpose(loc, (0, 1, 3, 2))
179 | # vel = np.transpose(vel, (0, 1, 3, 2))
180 | # force = np.transpose(force, (0, 1, 3, 2))
181 | loc = loc[0 : self.max_samples, :, :, :] # limit number of samples
182 | vel = vel[0 : self.max_samples, :, :, :] # speed when starting the trajectory
183 | force = force[0 : self.max_samples, :, :, :]
184 |
185 | if self.normalize:
186 | loc = self._normalize(loc)
187 | vel = self._normalize(vel)
188 | force = self._normalize(force)
189 | mass = self._normalize(mass)
190 |
191 | return loc, vel, force, mass
192 |
193 | def load(self):
194 | loc, vel, edges, q = self._load()
195 |
196 | self.num_nodes = loc.shape[-1]
197 |
198 | loc, vel, force, mass = self.preprocess(loc, vel, edges, q)
199 | return (loc, vel, force, mass)
200 |
201 | def __getitem__(self, i: Union[Sequence, int]) -> Tuple[np.ndarray, ...]:
202 | frame_0, frame_target = self._get_partition_frames()
203 |
204 | loc, vel, force, mass = self.data
205 | if self.target == "pos":
206 | y = loc[i, frame_target]
207 | elif self.target == "force":
208 | y = force[i, frame_target]
209 | loc, vel, force, mass = (loc[i, frame_0], vel[i, frame_0], force[i], mass[i])
210 |
211 | if not isinstance(i, int):
212 | # flatten batch and nodes dimensions
213 | loc = loc.reshape(-1, *loc.shape[2:])
214 | vel = vel.reshape(-1, *vel.shape[2:])
215 | force = force.reshape(-1, *force.shape[2:])
216 | mass = mass.reshape(-1, *mass.shape[2:])
217 | y = y.reshape(-1, *y.shape[2:])
218 |
219 | return loc, vel, force, mass, y
220 |
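221 | # Example (sketch): datasets are indexed with ints or index arrays, e.g.
222 | #   dataset = GravityDataset(partition="train", n_bodies=100, target="pos")
223 | #   loc, vel, force, mass, y = dataset[0]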
--------------------------------------------------------------------------------
/validate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | from functools import partial
4 |
5 | import e3nn_jax as e3nn
6 | import haiku as hk
7 | import jax
8 | import jax.numpy as jnp
9 | import wandb
10 |
11 | from experiments import setup_data, train
12 | from segnn_jax import SEGNN, weight_balanced_irreps
13 |
14 | key = jax.random.PRNGKey(1337)
15 |
16 |
17 | if __name__ == "__main__":
18 | parser = argparse.ArgumentParser()
19 | # Run parameters
20 | parser.add_argument("--epochs", type=int, default=100, help="Number of epochs")
21 | parser.add_argument(
22 | "--batch-size",
23 | type=int,
24 | default=128,
25 | help="Batch size (number of graphs).",
26 | )
27 | parser.add_argument("--lr", type=float, default=5e-4, help="Learning rate")
28 | parser.add_argument(
29 | "--lr-scheduling",
30 | action="store_true",
31 | help="Use learning rate scheduling",
32 | )
33 | parser.add_argument(
34 | "--weight-decay", type=float, default=1e-12, help="Weight decay"
35 | )
36 | parser.add_argument(
37 | "--dataset",
38 | type=str,
39 | choices=["qm9", "charged", "gravity"],
40 | help="Dataset name",
41 | )
42 | parser.add_argument(
43 | "--max-samples",
44 | type=int,
45 | default=3000,
46 | help="Maximum number of samples in nbody dataset",
47 | )
48 | parser.add_argument(
49 | "--val-freq",
50 | type=int,
51 | default=10,
52 | help="Evaluation frequency (number of epochs)",
53 | )
54 |
55 | # nbody parameters
56 | parser.add_argument(
57 | "--target",
58 | type=str,
59 | default="pos",
60 | help="Target. e.g. pos, force (gravity), alpha (qm9)",
61 | )
62 | parser.add_argument(
63 | "--neighbours",
64 | type=int,
65 | default=20,
66 | help="Number of connected nearest neighbours",
67 | )
68 | parser.add_argument(
69 | "--n-bodies",
70 | type=int,
71 | default=5,
72 | help="Number of bodies in the dataset",
73 | )
74 | parser.add_argument(
75 | "--dataset-name",
76 | type=str,
77 | default="small",
78 | choices=["small", "default", "small_out_dist"],
79 | help="Name of nbody data partition: default (200 steps), small (1000 steps)",
80 | )
81 |
82 | # qm9 parameters
83 | parser.add_argument(
84 | "--radius",
85 | type=float,
86 | default=2.0,
87 | help="Radius (Angstrom) between which atoms to add links.",
88 | )
89 | parser.add_argument(
90 | "--feature-type",
91 | type=str,
92 | default="one_hot",
93 | choices=["one_hot", "cormorant", "gilmer"],
94 | help="Type of input feature",
95 | )
96 |
97 | # Model parameters
98 | parser.add_argument(
99 | "--units", type=int, default=64, help="Number of values in the hidden layers"
100 | )
101 | parser.add_argument(
102 | "--lmax-hidden",
103 | type=int,
104 | default=1,
105 | help="Max degree of hidden representations.",
106 | )
107 | parser.add_argument(
108 | "--lmax-attributes",
109 | type=int,
110 | default=1,
111 | help="Max degree of geometric attribute embedding",
112 | )
113 | parser.add_argument(
114 | "--layers", type=int, default=7, help="Number of message passing layers"
115 | )
116 | parser.add_argument(
117 | "--blocks", type=int, default=2, help="Number of layers in steerable MLPs."
118 | )
119 | parser.add_argument(
120 | "--norm",
121 | type=str,
122 | default="none",
123 | choices=["instance", "batch", "none"],
124 | help="Normalisation type",
125 | )
126 | parser.add_argument(
127 | "--double-precision",
128 | action="store_true",
129 | help="Use double precision in model",
130 | )
131 | parser.add_argument(
132 | "--scn",
133 | action="store_true",
134 | help="Train SEGNN with the eSCN optimization",
135 | )
136 |
137 | # wandb parameters
138 | parser.add_argument(
139 | "--wandb",
140 | action="store_true",
141 | help="Activate weights and biases logging",
142 | )
143 | parser.add_argument(
144 | "--wandb-project",
145 | type=str,
146 | default="segnn",
147 | help="Weights and biases project",
148 | )
149 | parser.add_argument(
150 | "--wandb-entity",
151 | type=str,
152 | default="",
153 | help="Weights and biases entity",
154 | )
155 |
156 | args = parser.parse_args()
157 |
158 | # if specified set jax in double precision
159 | jax.config.update("jax_enable_x64", args.double_precision)
160 |
161 | # connect to wandb
162 | if args.wandb:
163 | wandb_name = "_".join(
164 | [
165 | args.wandb_project,
166 | args.dataset,
167 | args.target,
168 | str(int(time.time())),
169 | ]
170 | )
171 | wandb.init(
172 | project=args.wandb_project,
173 | name=wandb_name,
174 | config=args,
175 | entity=args.wandb_entity,
176 | )
177 |
178 | # feature representations
179 | if args.dataset == "qm9":
180 | args.task = "graph"
181 | if args.feature_type == "one_hot":
182 | args.node_irreps = e3nn.Irreps("5x0e")
183 | elif args.feature_type == "cormorant":
184 | args.node_irreps = e3nn.Irreps("15x0e")
185 | elif args.feature_type == "gilmer":
186 | args.node_irreps = e3nn.Irreps("11x0e")
187 | args.output_irreps = e3nn.Irreps("1x0e")
188 | args.additional_message_irreps = e3nn.Irreps("1x0e")
189 | assert not args.scn, "eSCN not implemented for qm9"
190 | elif args.dataset in ["charged", "gravity"]:
191 | args.task = "node"
192 | args.node_irreps = e3nn.Irreps("2x1o + 1x0e")
193 | args.output_irreps = e3nn.Irreps("1x1o")
194 | args.additional_message_irreps = e3nn.Irreps("2x0e")
195 |
196 | # Create hidden irreps
197 | if not args.scn:
198 | attr_irreps = e3nn.Irreps.spherical_harmonics(args.lmax_attributes)
199 | else:
200 | attr_irreps = e3nn.Irrep(f"{args.lmax_attributes}y")
201 |
202 | hidden_irreps = weight_balanced_irreps(
203 | scalar_units=args.units,
204 | irreps_right=attr_irreps,
205 | use_sh=(not args.scn),
206 | lmax=args.lmax_hidden,
207 | )
208 |
209 | args.o3_layer = "scn" if args.scn else "tpl"
210 | del args.scn
211 |
212 | # build model
213 | def segnn(x):
214 | return SEGNN(
215 | hidden_irreps=hidden_irreps,
216 | output_irreps=args.output_irreps,
217 | num_layers=args.layers,
218 | task=args.task,
219 | pool="avg",
220 | blocks_per_layer=args.blocks,
221 | norm=args.norm,
222 | o3_layer=args.o3_layer,
223 | )(x)
224 |
225 | segnn = hk.without_apply_rng(hk.transform_with_state(segnn))
226 |
227 | loader_train, loader_val, loader_test, graph_transform, eval_trn = setup_data(args)
228 |
229 | if args.dataset == "qm9":
230 | from experiments.train import loss_fn_wrapper
231 |
232 | def _mae(p, t):
233 | return jnp.abs(p - t)
234 |
235 | train_loss = partial(loss_fn_wrapper, criterion=_mae)
236 | eval_loss = partial(loss_fn_wrapper, criterion=_mae, eval_trn=eval_trn)
237 | if args.dataset in ["charged", "gravity"]:
238 | from experiments.train import loss_fn_wrapper
239 |
240 | def _mse(p, t):
241 | return jnp.power(p - t, 2)
242 |
243 | train_loss = partial(loss_fn_wrapper, criterion=_mse, do_mask=False)
244 | eval_loss = partial(loss_fn_wrapper, criterion=_mse, do_mask=False)
245 |
246 | train(
247 | key,
248 | segnn,
249 | loader_train,
250 | loader_val,
251 | loader_test,
252 | train_loss,
253 | eval_loss,
254 | graph_transform,
255 | args,
256 | )
257 |
--------------------------------------------------------------------------------
/experiments/nbody/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, List, Optional, Tuple
2 |
3 | import e3nn_jax as e3nn
4 | import jax
5 | import jax.numpy as jnp
6 | import jax.tree_util as tree
7 | import numpy as np
8 | import torch
9 | from jraph import GraphsTuple, segment_mean
10 | from torch.utils.data import DataLoader
11 | from torch_geometric.nn import knn_graph
12 |
13 | from segnn_jax import SteerableGraphsTuple
14 |
15 | from .datasets import ChargedDataset, GravityDataset
16 |
17 |
18 | def O3Transform(
19 | node_features_irreps: e3nn.Irreps,
20 | edge_features_irreps: e3nn.Irreps,
21 | lmax_attributes: int,
22 | scn: bool = False,
23 | ) -> Callable:
24 | """
25 | Build a transformation function that adds (nbody) O3 attributes to a graph.
26 | """
27 | if not scn:
28 | attribute_irreps = e3nn.Irreps.spherical_harmonics(lmax_attributes)
29 | else:
30 | attribute_irreps = e3nn.Irrep("1o")
31 |
32 | @jax.jit
33 | def _o3_transform(
34 | st_graph: SteerableGraphsTuple,
35 | loc: jnp.ndarray,
36 | vel: jnp.ndarray,
37 | charges: jnp.ndarray,
38 | ) -> SteerableGraphsTuple:
39 | graph = st_graph.graph
40 | prod_charges = charges[graph.senders] * charges[graph.receivers]
41 | rel_pos = loc[graph.senders] - loc[graph.receivers]
42 | edge_dist = jnp.sqrt(jnp.power(rel_pos, 2).sum(1, keepdims=True))
43 |
44 | msg_features = e3nn.IrrepsArray(
45 | edge_features_irreps,
46 | jnp.concatenate((edge_dist, prod_charges), axis=-1),
47 | )
48 |
49 | vel_abs = jnp.sqrt(jnp.power(vel, 2).sum(1, keepdims=True))
50 | mean_loc = loc.mean(1, keepdims=True)
51 |
52 | nodes = e3nn.IrrepsArray(
53 | node_features_irreps,
54 | jnp.concatenate((loc - mean_loc, vel, vel_abs), axis=-1),
55 | )
56 |
57 | if not scn:
58 | edge_attributes = e3nn.spherical_harmonics(
59 | attribute_irreps, rel_pos, normalize=True, normalization="integral"
60 | )
61 | vel_embedding = e3nn.spherical_harmonics(
62 | attribute_irreps, vel, normalize=True, normalization="integral"
63 | )
64 | else:
65 | edge_attributes = e3nn.IrrepsArray(attribute_irreps, rel_pos)
66 | vel_embedding = e3nn.IrrepsArray(attribute_irreps, vel)
67 |
68 | # scatter edge attributes
69 | sum_n_node = tree.tree_leaves(nodes)[0].shape[0]
70 | node_attributes = (
71 | tree.tree_map(
72 | lambda e: segment_mean(e, graph.receivers, sum_n_node),
73 | edge_attributes,
74 | )
75 | + vel_embedding
76 | )
77 | if not scn:
78 | # scalar attribute to 1 by default
79 | node_attributes = e3nn.IrrepsArray(
80 | node_attributes.irreps, node_attributes.array.at[:, 0].set(1.0)
81 | )
82 |
83 | return SteerableGraphsTuple(
84 | graph=GraphsTuple(
85 | nodes=nodes,
86 | edges=None,
87 | senders=graph.senders,
88 | receivers=graph.receivers,
89 | n_node=graph.n_node,
90 | n_edge=graph.n_edge,
91 | globals=graph.globals,
92 | ),
93 | node_attributes=node_attributes,
94 | edge_attributes=edge_attributes,
95 | additional_message_features=msg_features,
96 | )
97 |
98 | return _o3_transform
99 |
100 |
101 | def NbodyGraphTransform(
102 | transform: Callable,
103 | dataset_name: str,
104 | n_nodes: int,
105 | batch_size: int,
106 | neighbours: Optional[int] = 6,
107 | relative_target: bool = False,
108 | ) -> Callable:
109 | """
110 | Build a function that converts torch DataBatch into SteerableGraphsTuple.
111 | """
112 |
113 | if dataset_name == "charged":
114 | # charged system is a connected graph
115 | full_edge_indices = jnp.array(
116 | [
117 | (i + n_nodes * b, j + n_nodes * b)
118 | for b in range(batch_size)
119 | for i in range(n_nodes)
120 | for j in range(n_nodes)
121 | if i != j
122 | ]
123 | ).T
124 |
125 | def _to_steerable_graph(
126 | data: List, training: bool = True
127 | ) -> Tuple[SteerableGraphsTuple, jnp.ndarray]:
128 | _ = training
129 | loc, vel, _, q, targets = data
130 |
131 | cur_batch = int(loc.shape[0] / n_nodes)
132 |
133 | if dataset_name == "charged":
134 | edge_indices = full_edge_indices[:, : n_nodes * (n_nodes - 1) * cur_batch]
135 | senders, receivers = edge_indices[0], edge_indices[1]
136 | if dataset_name == "gravity":
137 | batch = torch.arange(0, cur_batch)
138 | batch = batch.repeat_interleave(n_nodes).long()
139 | edge_indices = knn_graph(torch.from_numpy(np.array(loc)), neighbours, batch)
140 |             # the two index rows come out of knn_graph already switched by default
141 | senders, receivers = jnp.array(edge_indices[0]), jnp.array(edge_indices[1])
142 |
143 | st_graph = SteerableGraphsTuple(
144 | graph=GraphsTuple(
145 | nodes=None,
146 | edges=None,
147 | senders=senders,
148 | receivers=receivers,
149 | n_node=jnp.array([n_nodes] * cur_batch),
150 | n_edge=jnp.array([len(senders) // cur_batch] * cur_batch),
151 | globals=None,
152 | )
153 | )
154 | st_graph = transform(st_graph, loc, vel, q)
155 |
156 | # relative shift as target
157 | if relative_target:
158 | targets = targets - loc
159 |
160 | return st_graph, targets
161 |
162 | return _to_steerable_graph
163 |
164 |
165 | def numpy_collate(batch):
166 | if isinstance(batch[0], np.ndarray):
167 | return jnp.vstack(batch)
168 | elif isinstance(batch[0], (tuple, list)):
169 | transposed = zip(*batch)
170 | return [numpy_collate(samples) for samples in transposed]
171 | else:
172 | return jnp.array(batch)
173 |
174 |
175 | def setup_nbody_data(
176 | args,
177 | ) -> Tuple[DataLoader, DataLoader, DataLoader, Callable, Optional[Callable]]:
178 | if args.dataset == "charged":
179 | dataset_train = ChargedDataset(
180 | partition="train",
181 | dataset_name=args.dataset_name,
182 | max_samples=args.max_samples,
183 | n_bodies=args.n_bodies,
184 | )
185 | dataset_val = ChargedDataset(
186 | partition="val",
187 | dataset_name=args.dataset_name,
188 | n_bodies=args.n_bodies,
189 | )
190 | dataset_test = ChargedDataset(
191 | partition="test",
192 | dataset_name=args.dataset_name,
193 | n_bodies=args.n_bodies,
194 | )
195 |
196 | if args.dataset == "gravity":
197 | dataset_train = GravityDataset(
198 | partition="train",
199 | dataset_name=args.dataset_name,
200 | max_samples=args.max_samples,
201 | neighbours=args.neighbours,
202 | target=args.target,
203 | n_bodies=args.n_bodies,
204 | )
205 | dataset_val = GravityDataset(
206 | partition="val",
207 | dataset_name=args.dataset_name,
208 | neighbours=args.neighbours,
209 | target=args.target,
210 | n_bodies=args.n_bodies,
211 | )
212 | dataset_test = GravityDataset(
213 | partition="test",
214 | dataset_name=args.dataset_name,
215 | neighbours=args.neighbours,
216 | target=args.target,
217 | n_bodies=args.n_bodies,
218 | )
219 |
220 | o3_transform = O3Transform(
221 | args.node_irreps,
222 | args.additional_message_irreps,
223 | args.lmax_attributes,
224 | scn=args.o3_layer == "scn",
225 | )
226 | graph_transform = NbodyGraphTransform(
227 | transform=o3_transform,
228 | n_nodes=args.n_bodies,
229 | batch_size=args.batch_size,
230 | neighbours=args.neighbours,
231 | relative_target=(args.target == "pos"),
232 | dataset_name=args.dataset,
233 | )
234 |
235 | loader_train = DataLoader(
236 | dataset_train,
237 | batch_size=args.batch_size,
238 | shuffle=True,
239 | drop_last=True,
240 | collate_fn=numpy_collate,
241 | )
242 | loader_val = DataLoader(
243 | dataset_val,
244 | batch_size=args.batch_size,
245 | shuffle=False,
246 | drop_last=False,
247 | collate_fn=numpy_collate,
248 | )
249 | loader_test = DataLoader(
250 | dataset_test,
251 | batch_size=args.batch_size,
252 | shuffle=False,
253 | drop_last=False,
254 | collate_fn=numpy_collate,
255 | )
256 |
257 | return loader_train, loader_val, loader_test, graph_transform, None
258 |
--------------------------------------------------------------------------------
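A minimal usage sketch of the two transforms above, assuming a single batch of 5 fully connected charged bodies. The irreps match the features concatenated in _o3_transform (nodes: loc - mean, vel, |vel| -> "2x1o + 1x0e"; messages: distance and charge product -> "2x0e"); the random inputs and variable names are purely illustrative:

    import e3nn_jax as e3nn
    import jax.numpy as jnp
    import numpy as np

    o3_transform = O3Transform(
        e3nn.Irreps("2x1o + 1x0e"), e3nn.Irreps("2x0e"), lmax_attributes=1
    )
    to_graph = NbodyGraphTransform(
        transform=o3_transform, dataset_name="charged", n_nodes=5, batch_size=1
    )

    loc = jnp.asarray(np.random.randn(5, 3))           # toy positions
    vel = jnp.asarray(np.random.randn(5, 3))           # toy velocities
    q = jnp.sign(jnp.asarray(np.random.randn(5, 1)))   # toy +-1 charges
    # targets are passed through unchanged (here simply the positions)
    st_graph, target = to_graph((loc, vel, None, q, loc))
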
/experiments/nbody/data/synthetic_sim.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class ChargedParticlesSim(object):
5 | def __init__(
6 | self,
7 | n_balls=5,
8 | box_size=5.0,
9 | loc_std=1.0,
10 | vel_norm=0.5,
11 | interaction_strength=1.0,
12 | noise_var=0.0,
13 | ):
14 | self.n_balls = n_balls
15 | self.box_size = box_size
16 |         # location std scales with the cube root of the number of bodies
17 |         self.loc_std = loc_std * (float(n_balls) / 5.0) ** (1 / 3)
18 |
19 | self.vel_norm = vel_norm
20 | self.interaction_strength = interaction_strength
21 | self.noise_var = noise_var
22 |
23 | self._charge_types = np.array([-1.0, 0.0, 1.0])
24 | self._delta_T = 0.001
25 | self._max_F = 0.1 / self._delta_T
26 | self.dim = 3
27 |
28 | def _l2(self, A, B):
29 | """
30 |         Input: A is an Nxd matrix
31 |                B is an Mxd matrix
32 | Output: dist is a NxM matrix where dist[i,j] is the square norm
33 | between A[i,:] and B[j,:]
34 | i.e. dist[i,j] = ||A[i,:]-B[j,:]||^2
35 | """
36 | A_norm = (A**2).sum(axis=1).reshape(A.shape[0], 1)
37 | B_norm = (B**2).sum(axis=1).reshape(1, B.shape[0])
38 | dist = A_norm + B_norm - 2 * A.dot(B.transpose())
39 | return dist
40 |
41 | def _energy(self, loc, vel, edges):
42 | # disables division by zero warning, since I fix it with fill_diagonal
43 | with np.errstate(divide="ignore"):
44 | K = 0.5 * (vel**2).sum()
45 | U = 0
46 | for i in range(loc.shape[1]):
47 | for j in range(loc.shape[1]):
48 | if i != j:
49 | r = loc[:, i] - loc[:, j]
50 | dist = np.sqrt((r**2).sum())
51 | U += 0.5 * self.interaction_strength * edges[i, j] / dist
52 | return U + K
53 |
54 | def _clamp(self, loc, vel):
55 | """
56 |         :param loc: 3xN location at one time stamp
57 |         :param vel: 3xN velocity at one time stamp
58 |         :return: location and velocity after the particles hit the walls and
59 |             bounce back in an elastic collision
60 | """
61 | assert np.all(loc < self.box_size * 3)
62 | assert np.all(loc > -self.box_size * 3)
63 |
64 | over = loc > self.box_size
65 | loc[over] = 2 * self.box_size - loc[over]
66 | assert np.all(loc <= self.box_size)
67 |
68 | # assert(np.all(vel[over]>0))
69 | vel[over] = -np.abs(vel[over])
70 |
71 | under = loc < -self.box_size
72 | loc[under] = -2 * self.box_size - loc[under]
73 | # assert (np.all(vel[under] < 0))
74 | assert np.all(loc >= -self.box_size)
75 | vel[under] = np.abs(vel[under])
76 |
77 | return loc, vel
78 |
79 | def sample_trajectory(self, T=10000, sample_freq=10, charge_prob=None):
80 | if charge_prob is None:
81 | charge_prob = [1.0 / 2, 0, 1.0 / 2]
82 | n = self.n_balls
83 | assert T % sample_freq == 0
84 | T_save = int(T / sample_freq - 1)
85 | diag_mask = np.ones((n, n), dtype=bool)
86 | np.fill_diagonal(diag_mask, 0)
87 | counter = 0
88 | # Sample edges
89 | charges = np.random.choice(
90 | self._charge_types, size=(self.n_balls, 1), p=charge_prob
91 | )
92 | edges = charges.dot(charges.transpose())
93 | # Initialize location and velocity
94 | loc = np.zeros((T_save, self.dim, n))
95 | vel = np.zeros((T_save, self.dim, n))
96 | loc_next = np.random.randn(self.dim, n) * self.loc_std
97 | vel_next = np.random.randn(self.dim, n)
98 | v_norm = np.sqrt((vel_next**2).sum(axis=0)).reshape(1, -1)
99 | vel_next = vel_next * self.vel_norm / v_norm
100 | loc[0, :, :], vel[0, :, :] = self._clamp(loc_next, vel_next)
101 |
102 | # disables division by zero warning, since I fix it with fill_diagonal
103 | with np.errstate(divide="ignore"):
104 | # half step leapfrog
105 | l2_dist_power3 = np.power(
106 | self._l2(loc_next.transpose(), loc_next.transpose()), 3.0 / 2.0
107 | )
108 |
109 | # size of forces up to a 1/|r| factor
110 | # since I later multiply by an unnormalized r vector
111 | forces_size = self.interaction_strength * edges / l2_dist_power3
112 | np.fill_diagonal(
113 | forces_size, 0
114 | ) # self forces are zero (fixes division by zero)
115 | assert np.abs(forces_size[diag_mask]).min() > 1e-10
116 | F = (
117 | forces_size.reshape(1, n, n)
118 | * np.concatenate(
119 | (
120 | np.subtract.outer(loc_next[0, :], loc_next[0, :]).reshape(
121 | 1, n, n
122 | ),
123 | np.subtract.outer(loc_next[1, :], loc_next[1, :]).reshape(
124 | 1, n, n
125 | ),
126 | np.subtract.outer(loc_next[2, :], loc_next[2, :]).reshape(
127 | 1, n, n
128 | ),
129 | )
130 | )
131 | ).sum(axis=-1)
132 | F[F > self._max_F] = self._max_F
133 | F[F < -self._max_F] = -self._max_F
134 |
135 | vel_next += self._delta_T * F
136 | # run leapfrog
137 | for i in range(1, T):
138 | loc_next += self._delta_T * vel_next
139 | # loc_next, vel_next = self._clamp(loc_next, vel_next)
140 |
141 | if i % sample_freq == 0:
142 | loc[counter, :, :], vel[counter, :, :] = loc_next, vel_next
143 | counter += 1
144 |
145 | l2_dist_power3 = np.power(
146 | self._l2(loc_next.transpose(), loc_next.transpose()), 3.0 / 2.0
147 | )
148 | forces_size = self.interaction_strength * edges / l2_dist_power3
149 | np.fill_diagonal(forces_size, 0)
150 | # assert (np.abs(forces_size[diag_mask]).min() > 1e-10)
151 |
152 | F = (
153 | forces_size.reshape(1, n, n)
154 | * np.concatenate(
155 | (
156 | np.subtract.outer(loc_next[0, :], loc_next[0, :]).reshape(
157 | 1, n, n
158 | ),
159 | np.subtract.outer(loc_next[1, :], loc_next[1, :]).reshape(
160 | 1, n, n
161 | ),
162 | np.subtract.outer(loc_next[2, :], loc_next[2, :]).reshape(
163 | 1, n, n
164 | ),
165 | )
166 | )
167 | ).sum(axis=-1)
168 | F[F > self._max_F] = self._max_F
169 | F[F < -self._max_F] = -self._max_F
170 | vel_next += self._delta_T * F
171 | # Add noise to observations
172 | loc += np.random.randn(T_save, self.dim, self.n_balls) * self.noise_var
173 | vel += np.random.randn(T_save, self.dim, self.n_balls) * self.noise_var
174 | return loc, vel, edges, charges
175 |
176 |
177 | class GravitySim(object):
178 | def __init__(
179 | self,
180 | n_balls=100,
181 | loc_std=1,
182 | vel_norm=0.5,
183 | interaction_strength=1,
184 | noise_var=0,
185 | dt=0.001,
186 | softening=0.1,
187 | ):
188 | self.n_balls = n_balls
189 | self.loc_std = loc_std
190 | self.vel_norm = vel_norm
191 | self.interaction_strength = interaction_strength
192 | self.noise_var = noise_var
193 | self.dt = dt
194 | self.softening = softening
195 |
196 | self.dim = 3
197 |
198 | def compute_acceleration(self, pos, mass, G, softening):
199 | # positions r = [x,y,z] for all particles
200 | x = pos[:, 0:1]
201 | y = pos[:, 1:2]
202 | z = pos[:, 2:3]
203 |
204 | # matrix that stores all pairwise particle separations: r_j - r_i
205 | dx = x.T - x
206 | dy = y.T - y
207 | dz = z.T - z
208 |
209 |         # matrix that stores 1/r^3 for all pairwise particle separations
210 | inv_r3 = dx**2 + dy**2 + dz**2 + softening**2
211 | inv_r3[inv_r3 > 0] = inv_r3[inv_r3 > 0] ** (-1.5)
212 |
213 | ax = G * (dx * inv_r3) @ mass
214 | ay = G * (dy * inv_r3) @ mass
215 | az = G * (dz * inv_r3) @ mass
216 |
217 | # pack together the acceleration components
218 | a = np.hstack((ax, ay, az))
219 | return a
220 |
221 | def _energy(self, pos, vel, mass, G):
222 | # Kinetic Energy:
223 | KE = 0.5 * np.sum(np.sum(mass * vel**2))
224 |
225 | # Potential Energy:
226 |
227 | # positions r = [x,y,z] for all particles
228 | x = pos[:, 0:1]
229 | y = pos[:, 1:2]
230 | z = pos[:, 2:3]
231 |
232 | # matrix that stores all pairwise particle separations: r_j - r_i
233 | dx = x.T - x
234 | dy = y.T - y
235 | dz = z.T - z
236 |
237 |         # matrix that stores 1/r for all pairwise particle separations
238 | inv_r = np.sqrt(dx**2 + dy**2 + dz**2)
239 | inv_r[inv_r > 0] = 1.0 / inv_r[inv_r > 0]
240 |
241 | # sum over upper triangle, to count each interaction only once
242 | PE = G * np.sum(np.sum(np.triu(-(mass * mass.T) * inv_r, 1)))
243 |
244 | return KE, PE, KE + PE
245 |
246 | def sample_trajectory(self, T=10000, sample_freq=10):
247 | assert T % sample_freq == 0
248 |
249 | T_save = int(T / sample_freq)
250 |
251 | N = self.n_balls
252 |
253 | pos_save = np.zeros((T_save, N, self.dim))
254 | vel_save = np.zeros((T_save, N, self.dim))
255 | force_save = np.zeros((T_save, N, self.dim))
256 |
257 | # Specific sim parameters
258 | mass = np.ones((N, 1))
259 | t = 0
260 | pos = np.random.randn(N, self.dim) # randomly selected positions and velocities
261 | vel = np.random.randn(N, self.dim)
262 |
263 | # Convert to Center-of-Mass frame
264 | vel -= np.mean(mass * vel, 0) / np.mean(mass)
265 |
266 | # calculate initial gravitational accelerations
267 | acc = self.compute_acceleration(
268 | pos, mass, self.interaction_strength, self.softening
269 | )
270 |
271 | for i in range(T):
272 | if i % sample_freq == 0:
273 | pos_save[int(i / sample_freq)] = pos
274 | vel_save[int(i / sample_freq)] = vel
275 | force_save[int(i / sample_freq)] = acc * mass
276 |
277 | # (1/2) kick
278 | vel += acc * self.dt / 2.0
279 |
280 | # drift
281 | pos += vel * self.dt
282 |
283 | # update accelerations
284 | acc = self.compute_acceleration(
285 | pos, mass, self.interaction_strength, self.softening
286 | )
287 |
288 | # (1/2) kick
289 | vel += acc * self.dt / 2.0
290 |
291 | # update time
292 | t += self.dt
293 |
294 | # Add noise to observations
295 | pos_save += np.random.randn(T_save, N, self.dim) * self.noise_var
296 | vel_save += np.random.randn(T_save, N, self.dim) * self.noise_var
297 | force_save += np.random.randn(T_save, N, self.dim) * self.noise_var
298 | return pos_save, vel_save, force_save, mass
299 |
--------------------------------------------------------------------------------
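For reference, a short sketch of driving the two simulators above with toy sizes; the shape comments follow sample_trajectory (T_save = T / sample_freq - 1 for the charged system, T / sample_freq for gravity):

    import numpy as np

    np.random.seed(42)
    charged = ChargedParticlesSim(n_balls=5)
    loc, vel, edges, charges = charged.sample_trajectory(T=5000, sample_freq=100)
    print(loc.shape)              # (49, 3, 5): (T_save, dim, n_balls)

    gravity = GravitySim(n_balls=10)
    pos, vel, force, mass = gravity.sample_trajectory(T=5000, sample_freq=100)
    print(pos.shape, mass.shape)  # (50, 10, 3) and (10, 1)
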
/segnn_jax/segnn.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, List, Optional, Union
2 |
3 | import e3nn_jax as e3nn
4 | import haiku as hk
5 | import jax.numpy as jnp
6 | import jraph
7 | from e3nn_jax import IrrepsArray
8 | from jax.tree_util import Partial
9 |
10 | from .blocks import O3_LAYERS, O3TensorProduct, O3TensorProductGate, TensorProduct
11 | from .config import config
12 | from .graph_utils import SteerableGraphsTuple, pooling
13 |
14 |
15 | def O3Embedding(
16 | embed_irreps: e3nn.Irreps,
17 | embed_edges: bool = True,
18 | O3Layer: TensorProduct = O3TensorProduct,
19 | ) -> Callable:
20 | """Linear steerable embedding.
21 |
22 |     Embeds the graph nodes into the representation space given by embed_irreps.
23 |
24 | Args:
25 | embed_irreps: Output representation
26 | embed_edges: If true also embed edges/message passing features
27 | O3Layer: Type of tensor product layer to use
28 |
29 | Returns:
30 | Function to embed graph nodes (and optionally edges)
31 | """
32 |
33 | def _embedding(
34 | st_graph: SteerableGraphsTuple,
35 | ) -> SteerableGraphsTuple:
36 | graph = st_graph.graph
37 | nodes = O3Layer(embed_irreps, name="embedding_nodes")(
38 | graph.nodes, st_graph.node_attributes
39 | )
40 | st_graph = st_graph._replace(graph=graph._replace(nodes=nodes))
41 |
42 |         # NOTE edge embedding is not in the original paper but can give good results
43 | if embed_edges:
44 | additional_message_features = O3Layer(
45 | embed_irreps,
46 | name="embedding_msg_features",
47 | )(
48 | st_graph.additional_message_features,
49 | st_graph.edge_attributes,
50 | )
51 | st_graph = st_graph._replace(
52 | additional_message_features=additional_message_features
53 | )
54 |
55 | return st_graph
56 |
57 | return _embedding
58 |
59 |
60 | def O3Decoder(
61 | latent_irreps: e3nn.Irreps,
62 | output_irreps: e3nn.Irreps,
63 | blocks: int = 1,
64 | task: str = "graph",
65 | pool: Optional[str] = "avg",
66 | pooled_irreps: Optional[e3nn.Irreps] = None,
67 | O3Layer: TensorProduct = O3TensorProduct,
68 | ):
69 | """Steerable pooler and decoder.
70 |
71 | Args:
72 | latent_irreps: Representation from the previous block
73 | output_irreps: Output representation
74 | blocks: Number of tensor product blocks in the decoder
75 | task: Specifies where the output is located. Either 'graph' or 'node'
76 | pool: Pooling method to use. One of 'avg', 'sum', 'none', None
77 | pooled_irreps: Pooled irreps. When left None the original implementation is used
78 | O3Layer: Type of tensor product layer to use
79 |
80 | Returns:
81 |         Function that decodes the latent feature space into the output space.
82 | """
83 |
84 | assert task in ["node", "graph"], f"Unknown task {task}"
85 | assert pool in ["avg", "sum", "none", None], f"Unknown pooling '{pool}'"
86 |
87 | # NOTE: original implementation restricted final layers to pooled_irreps.
88 | # This way gates cannot be applied in the post pool block when returning vectors,
89 | # because the gating scalars cannot be reached.
90 | if pooled_irreps is None:
91 | pooled_irreps = (output_irreps * latent_irreps.num_irreps).regroup()
92 |
93 | def _decoder(st_graph: SteerableGraphsTuple):
94 | nodes = st_graph.graph.nodes
95 | # pre pool block
96 | for i in range(blocks):
97 | nodes = O3TensorProductGate(
98 | latent_irreps, name=f"prepool_{i}", o3_layer=O3Layer
99 | )(nodes, st_graph.node_attributes)
100 |
101 | if task == "node":
102 | nodes = O3Layer(output_irreps, name="output")(
103 | nodes, st_graph.node_attributes
104 | )
105 |
106 | if task == "graph":
107 | # pool over graph
108 | nodes = O3Layer(pooled_irreps, name=f"prepool_{blocks}")(
109 | nodes, st_graph.node_attributes
110 | )
111 |
112 | # pooling layer
113 | if pool == "avg":
114 | pool_fn = jraph.segment_mean
115 |             elif pool == "sum":
116 |                 pool_fn = jraph.segment_sum
117 |             # NOTE: pool='none'/None is not valid for graph tasks (pool_fn unset)
118 | nodes = pooling(st_graph.graph._replace(nodes=nodes), aggregate_fn=pool_fn)
119 |
120 | # post pool mlp (not steerable)
121 | for i in range(blocks):
122 | nodes = O3TensorProductGate(
123 | pooled_irreps, name=f"postpool_{i}", o3_layer=O3TensorProduct
124 | )(nodes)
125 | nodes = O3TensorProduct(output_irreps, name="output")(nodes)
126 |
127 | return nodes
128 |
129 | return _decoder
130 |
131 |
132 | class SEGNNLayer(hk.Module):
133 | """
134 | Steerable E(3) equivariant layer.
135 |
136 | Applies a message passing step (GN) with equivariant message and update functions.
137 | """
138 |
139 | def __init__(
140 | self,
141 | output_irreps: e3nn.Irreps,
142 | layer_num: int,
143 | blocks: int = 2,
144 | norm: Optional[str] = None,
145 | aggregate_fn: Optional[Callable] = jraph.segment_sum,
146 | residual: bool = True,
147 | O3Layer: TensorProduct = O3TensorProduct,
148 | ):
149 | """
150 | Initialize the layer.
151 |
152 | Args:
153 | output_irreps: Layer output representation
154 | layer_num: Numbering of the layer
155 | blocks: Number of tensor product blocks in the layer
156 |             norm: Normalization type. Either None, 'instance' or 'batch'
157 | aggregate_fn: Message aggregation function. Defaults to sum.
158 | residual: If true, use residual connections
159 | O3Layer: Type of tensor product layer to use
160 | """
161 | super().__init__(f"layer_{layer_num}")
162 | assert norm in ["batch", "instance", "none", None], f"Unknown norm '{norm}'"
163 | self._output_irreps = output_irreps
164 | self._blocks = blocks
165 | self._norm = norm
166 | self._aggregate_fn = aggregate_fn
167 | self._residual = residual
168 |
169 | self._O3Layer = O3Layer
170 |
171 | def _message(
172 | self,
173 | edge_attribute: IrrepsArray,
174 | additional_message_features: IrrepsArray,
175 | edge_features: Any,
176 | incoming: IrrepsArray,
177 | outgoing: IrrepsArray,
178 | globals_: Any,
179 | ) -> IrrepsArray:
180 | """Steerable equivariant message function."""
181 | _ = globals_
182 | _ = edge_features
183 | # create messages
184 | msg = e3nn.concatenate([incoming, outgoing], axis=-1)
185 | if additional_message_features is not None:
186 | msg = e3nn.concatenate([msg, additional_message_features], axis=-1)
187 |         # message mlp (phi_m in the paper) steered by edge attributes
188 | for i in range(self._blocks):
189 | msg = O3TensorProductGate(
190 | self._output_irreps, name=f"tp_{i}", o3_layer=self._O3Layer
191 | )(msg, edge_attribute)
192 | # NOTE: original implementation only applied batch norm to messages
193 | if self._norm == "batch":
194 | msg = e3nn.haiku.BatchNorm(irreps=self._output_irreps)(msg)
195 | return msg
196 |
197 | def _update(
198 | self,
199 | node_attribute: IrrepsArray,
200 | nodes: IrrepsArray,
201 | senders: Any,
202 | msg: IrrepsArray,
203 | globals_: Any,
204 | ) -> IrrepsArray:
205 | """Steerable equivariant update function."""
206 | _ = senders
207 | _ = globals_
208 | x = e3nn.concatenate([nodes, msg], axis=-1)
209 |         # update mlp (phi_f in the paper) steered by node attributes
210 | for i in range(self._blocks - 1):
211 | x = O3TensorProductGate(
212 | self._output_irreps, name=f"tp_{i}", o3_layer=self._O3Layer
213 | )(x, node_attribute)
214 | # last update layer without activation
215 | update = self._O3Layer(self._output_irreps, name=f"tp_{self._blocks - 1}")(
216 | x, node_attribute
217 | )
218 | # residual connection
219 | if self._residual:
220 | nodes += update
221 | else:
222 | nodes = update
223 | # message norm
224 | if self._norm in ["batch", "instance"]:
225 | nodes = e3nn.haiku.BatchNorm(
226 | irreps=self._output_irreps,
227 | instance=(self._norm == "instance"),
228 | )(nodes)
229 | return nodes
230 |
231 | def __call__(self, st_graph: SteerableGraphsTuple) -> SteerableGraphsTuple:
232 | """Perform a message passing step.
233 |
234 | Args:
235 | st_graph: Input graph
236 |
237 | Returns:
238 | The updated graph
239 | """
240 | # NOTE node_attributes, edge_attributes and additional_message_features
241 | # are never updated within the message passing layers
242 | return st_graph._replace(
243 | graph=jraph.GraphNetwork(
244 | update_node_fn=Partial(self._update, st_graph.node_attributes),
245 | update_edge_fn=Partial(
246 | self._message,
247 | st_graph.edge_attributes,
248 | st_graph.additional_message_features,
249 | ),
250 | aggregate_edges_for_nodes_fn=self._aggregate_fn,
251 | )(st_graph.graph)
252 | )
253 |
254 |
255 | class SEGNN(hk.Module):
256 | """Steerable E(3) equivariant network.
257 |
258 | Original paper https://arxiv.org/abs/2110.02905.
259 | """
260 |
261 | def __init__(
262 | self,
263 | hidden_irreps: Union[List[e3nn.Irreps], e3nn.Irreps],
264 | output_irreps: e3nn.Irreps,
265 | num_layers: int,
266 | norm: Optional[str] = None,
267 | pool: Optional[str] = "avg",
268 | task: Optional[str] = "graph",
269 | blocks_per_layer: int = 2,
270 | embed_msg_features: bool = False,
271 | o3_layer: Optional[Union[str, TensorProduct]] = None,
272 | ):
273 | """
274 | Initialize the network.
275 |
276 | Args:
277 | hidden_irreps: Feature representation in the hidden layers
278 | output_irreps: Output representation.
279 | num_layers: Number of message passing layers
280 | norm: Normalization type. Either None, 'instance' or 'batch'
281 | pool: Pooling mode (only for graph-wise tasks)
282 | task: Specifies where the output is located. Either 'graph' or 'node'
283 |             blocks_per_layer: Number of tensor product blocks in each message passing layer
284 | embed_msg_features: Set to true to also embed edges/message passing features
285 | o3_layer: Tensor product layer type. "tpl", "fctp", "scn" or a custom layer
286 | """
287 | super().__init__()
288 |
289 | if not isinstance(output_irreps, e3nn.Irreps):
290 | output_irreps = e3nn.Irreps(output_irreps)
291 | if not isinstance(hidden_irreps, e3nn.Irreps):
292 | hidden_irreps = e3nn.Irreps(hidden_irreps)
293 |
294 | self._hidden_irreps = hidden_irreps
295 | self._num_layers = num_layers
296 |
297 | self._embed_msg_features = embed_msg_features
298 | self._norm = norm
299 | self._blocks_per_layer = blocks_per_layer
300 |
301 | # layer type
302 | if o3_layer is None:
303 | o3_layer = config("o3_layer")
304 | if isinstance(o3_layer, str):
305 | assert o3_layer in O3_LAYERS, f"Unknown O3 layer {o3_layer}."
306 | self._O3Layer = O3_LAYERS[o3_layer]
307 | else:
308 | self._O3Layer = o3_layer
309 |
310 | self._embedding = O3Embedding(
311 | self._hidden_irreps,
312 | O3Layer=self._O3Layer,
313 | embed_edges=self._embed_msg_features,
314 | )
315 |
316 | pooled_irreps = None
317 | if task == "graph" and "0e" not in output_irreps:
318 | # NOTE: different from original. This way proper gates are always applied
319 | pooled_irreps = hidden_irreps
320 |
321 | self._decoder = O3Decoder(
322 | latent_irreps=self._hidden_irreps,
323 | output_irreps=output_irreps,
324 | O3Layer=self._O3Layer,
325 | task=task,
326 | pool=pool,
327 | pooled_irreps=pooled_irreps,
328 | )
329 |
330 |     def __call__(self, st_graph: SteerableGraphsTuple) -> jnp.ndarray:
331 | # node (and edge) embedding
332 | st_graph = self._embedding(st_graph)
333 |
334 | # message passing
335 | for n in range(self._num_layers):
336 | st_graph = SEGNNLayer(
337 |                 self._hidden_irreps, n, self._blocks_per_layer, self._norm, O3Layer=self._O3Layer
338 | )(st_graph)
339 |
340 | # decoder/pooler
341 | nodes = self._decoder(st_graph)
342 |
343 | return jnp.squeeze(nodes.array)
344 |
--------------------------------------------------------------------------------
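Since SEGNN is a haiku module, the forward pass has to be wrapped in hk.transform. A minimal sketch, assuming st_graph is a SteerableGraphsTuple such as the one built in the n-body sketch earlier; the irreps and layer count are illustrative:

    import e3nn_jax as e3nn
    import haiku as hk
    import jax

    def forward(st_graph):
        return SEGNN(
            hidden_irreps=e3nn.Irreps("32x0e + 8x1o"),
            output_irreps=e3nn.Irreps("1x1o"),
            num_layers=4,
            task="node",
        )(st_graph)

    model = hk.without_apply_rng(hk.transform(forward))
    params = model.init(jax.random.PRNGKey(0), st_graph)
    prediction = model.apply(params, st_graph)  # squeezed per-node vectors
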
/experiments/qm9/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | from typing import List, Optional
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn.functional as F
8 | from e3nn.o3 import Irreps, spherical_harmonics
9 | from torch_geometric.data import Data, InMemoryDataset, download_url, extract_zip
10 | from torch_geometric.nn import radius_graph
11 | from torch_scatter import scatter
12 | from tqdm import tqdm
13 |
14 | HAR2EV = 27.211386246
15 | KCALMOL2EV = 0.04336414
16 |
17 | conversion = torch.tensor(
18 | [
19 | 1.0,
20 | 1.0,
21 | HAR2EV,
22 | HAR2EV,
23 | HAR2EV,
24 | 1.0,
25 | HAR2EV,
26 | HAR2EV,
27 | HAR2EV,
28 | HAR2EV,
29 | HAR2EV,
30 | 1.0,
31 | KCALMOL2EV,
32 | KCALMOL2EV,
33 | KCALMOL2EV,
34 | KCALMOL2EV,
35 | 1.0,
36 | 1.0,
37 | 1.0,
38 | ]
39 | )
40 |
41 | atomrefs = {
42 | 6: [0.0, 0.0, 0.0, 0.0, 0.0],
43 | 7: [-13.61312172, -1029.86312267, -1485.30251237, -2042.61123593, -2713.48485589],
44 | 8: [-13.5745904, -1029.82456413, -1485.26398105, -2042.5727046, -2713.44632457],
45 | 9: [-13.54887564, -1029.79887659, -1485.2382935, -2042.54701705, -2713.42063702],
46 | 10: [-13.90303183, -1030.25891228, -1485.71166277, -2043.01812778, -2713.88796536],
47 | 11: [0.0, 0.0, 0.0, 0.0, 0.0],
48 | }
49 |
50 |
51 | targets = [
52 | "mu",
53 | "alpha",
54 | "homo",
55 | "lumo",
56 | "gap",
57 | "r2",
58 | "zpve",
59 | "U0",
60 | "U",
61 | "H",
62 | "G",
63 | "Cv",
64 | "U0_atom",
65 | "U_atom",
66 | "H_atom",
67 | "G_atom",
68 | "A",
69 | "B",
70 | "C",
71 | ]
72 |
73 | thermo_targets = ["U", "U0", "H", "G"]
74 |
75 |
76 | class TargetGetter(object):
77 | """Gets relevant target"""
78 |
79 | def __init__(self, target):
80 | self.target = target
81 | self.target_idx = targets.index(target)
82 |
83 | def __call__(self, data):
84 | # Specify target.
85 | data.y = data.y[0, self.target_idx]
86 | return data
87 |
88 |
89 | class QM9(InMemoryDataset):
90 | r"""The QM9 dataset from the `"MoleculeNet: A Benchmark for Molecular
91 |     Machine Learning" <https://arxiv.org/abs/1703.00564>`_ paper, consisting of
92 | about 130,000 molecules with 19 regression targets.
93 | Each molecule includes complete spatial information for the single low
94 | energy conformation of the atoms in the molecule.
95 | In addition, we provide the atom features from the `"Neural Message
96 |     Passing for Quantum Chemistry" <https://arxiv.org/abs/1704.01212>`_ paper."""
97 |
98 | raw_url = (
99 | "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/"
100 | "molnet_publish/qm9.zip"
101 | )
102 | raw_url2 = "https://ndownloader.figshare.com/files/3195404"
103 | processed_url = "https://data.pyg.org/datasets/qm9_v3.zip"
104 |
105 | def __init__(
106 | self, root, target, radius, partition, lmax_attr, feature_type="one_hot"
107 | ):
108 | assert feature_type in [
109 | "one_hot",
110 | "cormorant",
111 | "gilmer",
112 | ], "Please use valid features"
113 | assert target in targets
114 | assert partition in ["train", "valid", "test"]
115 | self.root = osp.abspath(osp.join(root, "qm9"))
116 | self.target = target
117 | self.radius = radius
118 | self.partition = partition
119 | self.feature_type = feature_type
120 | self.lmax_attr = lmax_attr
121 | self.attr_irreps = Irreps.spherical_harmonics(lmax_attr)
122 | transform = TargetGetter(self.target)
123 |
124 | super().__init__(self.root, transform)
125 |
126 | self.data, self.slices = torch.load(self.processed_paths[0])
127 |
128 | def calc_stats(self):
129 | ys = np.array([data.y.item() for data in self])
130 | mean = np.mean(ys)
131 | mad = np.mean(np.abs(ys - mean))
132 | return mean, mad
133 |
134 | def atomref(self, target) -> Optional[torch.Tensor]:
135 | if target in atomrefs:
136 | out = torch.zeros(100)
137 | out[torch.tensor([1, 6, 7, 8, 9])] = torch.tensor(atomrefs[target])
138 | return out.view(-1, 1)
139 | return None
140 |
141 | @property
142 | def raw_file_names(self) -> Optional[List[str]]:
143 | try:
144 | import rdkit # noqa
145 |
146 | return ["gdb9.sdf", "gdb9.sdf.csv", "uncharacterized.txt"]
147 | except ImportError:
148 | import sys
149 |
150 | print("Please install rdkit")
151 | sys.exit(1)
152 |
153 | @property
154 | def processed_file_names(self) -> List[str]:
155 | return [
156 | "_".join(
157 | [
158 | self.partition,
159 | "r=" + str(np.round(self.radius, 2)),
160 | self.feature_type,
161 | "l=" + str(self.lmax_attr),
162 | ]
163 | )
164 | + ".pt"
165 | ]
166 |
167 | def download(self):
168 |         print("Downloading", self.raw_url, "to", self.raw_dir)
169 | try:
170 | import rdkit # noqa
171 |
172 | file_path = download_url(self.raw_url, self.raw_dir)
173 | extract_zip(file_path, self.raw_dir)
174 | os.unlink(file_path)
175 |
176 | file_path = download_url(self.raw_url2, self.raw_dir)
177 | os.rename(
178 | osp.join(self.raw_dir, "3195404"),
179 | osp.join(self.raw_dir, "uncharacterized.txt"),
180 | )
181 | except ImportError:
182 | path = download_url(self.processed_url, self.raw_dir)
183 | extract_zip(path, self.raw_dir)
184 | os.unlink(path)
185 |
186 | def process(self):
187 | try:
188 | from rdkit import Chem, RDLogger
189 | from rdkit.Chem.rdchem import HybridizationType
190 |
191 | RDLogger.DisableLog("rdApp.*")
192 | except ImportError:
193 | print("Please install rdkit")
194 | return
195 |
196 | print(
197 | "Processing",
198 | self.partition,
199 | "with radius=" + str(np.round(self.radius, 2)) + ",",
200 | "l_attr=" + str(self.lmax_attr),
201 | "and",
202 | self.feature_type,
203 | "features.",
204 | )
205 | types = {"H": 0, "C": 1, "N": 2, "O": 3, "F": 4}
206 |
207 | with open(self.raw_paths[1], "r") as f:
208 | target = f.read().split("\n")[1:-1]
209 | target = [[float(x) for x in line.split(",")[1:20]] for line in target]
210 | target = torch.tensor(target, dtype=torch.float)
211 | target = torch.cat([target[:, 3:], target[:, :3]], dim=-1)
212 | target = target * conversion.view(1, -1)
213 |
214 | with open(self.raw_paths[2], "r") as f:
215 | skip = [int(x.split()[0]) - 1 for x in f.read().split("\n")[9:-2]]
216 |
217 | suppl = Chem.SDMolSupplier(self.raw_paths[0], removeHs=False, sanitize=False)
218 | data_list = []
219 |
220 | # Create splits identical to Cormorant
221 | Nmols = len(suppl) - len(skip)
222 | Ntrain = 100000
223 | Ntest = int(0.1 * Nmols)
224 | Nvalid = Nmols - (Ntrain + Ntest)
225 |
226 | np.random.seed(0)
227 | data_perm = np.random.permutation(Nmols)
228 | train, valid, test = np.split(data_perm, [Ntrain, Ntrain + Nvalid])
229 | indices = {"train": train, "valid": valid, "test": test}
230 |
231 | # Add a very ugly second index to align with Cormorant splits.
232 | j = 0
233 | for i, mol in enumerate(tqdm(suppl)):
234 | if i in skip:
235 | continue
236 | if j not in indices[self.partition]:
237 | j += 1
238 | continue
239 | j += 1
240 |
241 | N = mol.GetNumAtoms()
242 |
243 | pos = suppl.GetItemText(i).split("\n")[4 : 4 + N]
244 | pos = [[float(x) for x in line.split()[:3]] for line in pos]
245 | pos = torch.tensor(pos, dtype=torch.float)
246 |
247 | edge_index = radius_graph(pos, r=self.radius, loop=False)
248 |
249 | type_idx = []
250 | atomic_number = []
251 | aromatic = []
252 | sp = []
253 | sp2 = []
254 | sp3 = []
255 | num_hs = []
256 | for atom in mol.GetAtoms():
257 | type_idx.append(types[atom.GetSymbol()])
258 | atomic_number.append(atom.GetAtomicNum())
259 | aromatic.append(1 if atom.GetIsAromatic() else 0)
260 | hybridization = atom.GetHybridization()
261 | sp.append(1 if hybridization == HybridizationType.SP else 0)
262 | sp2.append(1 if hybridization == HybridizationType.SP2 else 0)
263 | sp3.append(1 if hybridization == HybridizationType.SP3 else 0)
264 |
265 | z = torch.tensor(atomic_number, dtype=torch.long)
266 |
267 | if self.feature_type == "one_hot":
268 | x = F.one_hot(torch.tensor(type_idx), num_classes=len(types)).float()
269 | elif self.feature_type == "cormorant":
270 | one_hot = F.one_hot(torch.tensor(type_idx), num_classes=len(types))
271 | x = get_cormorant_features(one_hot, z, 2, z.max())
272 | elif self.feature_type == "gilmer":
273 | row, col = edge_index
274 | hs = (z == 1).to(torch.float)
275 | num_hs = scatter(hs[row], col, dim_size=N).tolist()
276 |
277 | x1 = F.one_hot(torch.tensor(type_idx), num_classes=len(types))
278 | x2 = (
279 | torch.tensor(
280 | [atomic_number, aromatic, sp, sp2, sp3, num_hs],
281 | dtype=torch.float,
282 | )
283 | .t()
284 | .contiguous()
285 | )
286 | x = torch.cat([x1.to(torch.float), x2], dim=-1)
287 |
288 | y = target[i].unsqueeze(0)
289 | name = mol.GetProp("_Name")
290 |
291 | edge_attr, node_attr, edge_dist = self.get_O3_attr(
292 | edge_index, pos, self.attr_irreps
293 | )
294 |
295 | data = Data(
296 | x=x,
297 | pos=pos,
298 | edge_index=edge_index,
299 | edge_attr=edge_attr,
300 | node_attr=node_attr,
301 | additional_message_features=edge_dist,
302 | y=y,
303 | name=name,
304 | index=i,
305 | )
306 | data_list.append(data)
307 |
308 | torch.save(self.collate(data_list), self.processed_paths[0])
309 |
310 | def get_O3_attr(self, edge_index, pos, attr_irreps):
311 | """
312 | Creates spherical harmonic edge attributes and node attributes for the SEGNN.
313 | """
314 | rel_pos = (
315 | pos[edge_index[0]] - pos[edge_index[1]]
316 |         ) # pos_j - pos_i (note: edge_index stores tuples like (j, i))
317 |         edge_dist = rel_pos.pow(2).sum(-1, keepdims=True)  # squared distance
318 | edge_attr = spherical_harmonics(
319 | attr_irreps, rel_pos, normalize=True, normalization="component"
320 |         ) # component-normalized spherical harmonics
321 | node_attr = scatter(edge_attr, edge_index[1], dim=0, reduce="mean")
322 | return edge_attr, node_attr, edge_dist
323 |
324 | def top_n_nodes(self, n: int) -> List[int]:
325 |         """Returns the n largest node counts in the dataset."""
326 | return [int(k) for k in torch.topk(torch.diff(self.slices["x"]), n)[0]]
327 |
328 | def top_n_edges(self, n: int) -> List[int]:
329 |         """Returns the n largest edge counts in the dataset."""
330 | return [int(k) for k in torch.topk(torch.diff(self.slices["edge_attr"]), n)[0]]
331 |
332 |
333 | def get_cormorant_features(one_hot, charges, charge_power, charge_scale):
334 | """Create input features as described in section 7.3 of https://arxiv.org/pdf/1906.04015.pdf"""
335 | charge_tensor = (charges.unsqueeze(-1) / charge_scale).pow(
336 | torch.arange(charge_power + 1.0, dtype=torch.float32)
337 | )
338 | charge_tensor = charge_tensor.view(charges.shape + (1, charge_power + 1))
339 | atom_scalars = (one_hot.unsqueeze(-1) * charge_tensor).view(
340 | charges.shape[:2] + (-1,)
341 | )
342 | return atom_scalars
343 |
344 |
345 | if __name__ == "__main__":
346 | dataset = QM9("datasets", "alpha", 2.0, "train", 3, feature_type="one_hot")
347 | print("length", len(dataset))
348 | ys = np.array([data.y.item() for data in dataset])
349 | mean, mad = dataset.calc_stats()
350 |
351 | for item in dataset:
352 | print(item.edge_index)
353 | break
354 |
355 | print("mean", mean, "mad", mad)
356 | import matplotlib.pyplot as plt
357 |
358 | plt.subplot(121)
359 | plt.title(dataset.target)
360 | plt.hist(ys)
361 | plt.subplot(122)
362 | plt.title(dataset.target + " standardised")
363 | plt.hist((ys - mean) / mad)
364 | plt.show()
365 |
--------------------------------------------------------------------------------
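As a quick shape check for get_cormorant_features above, a toy example with three atoms (H, C, O); the one-hot indices follow the types dict used in process, and all values are illustrative:

    import torch
    import torch.nn.functional as F

    z = torch.tensor([1, 6, 8])  # atomic numbers of H, C, O
    one_hot = F.one_hot(torch.tensor([0, 1, 3]), num_classes=5)
    feats = get_cormorant_features(one_hot, z, charge_power=2, charge_scale=z.max())
    print(feats.shape)  # (3, 15): 5 one-hot classes x (charge_power + 1) powers
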
/segnn_jax/blocks.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from abc import ABC, abstractmethod
3 | from typing import Callable, Optional, Tuple, Union
4 |
5 | import e3nn_jax as e3nn
6 | import haiku as hk
7 | import jax
8 | import jax.numpy as jnp
9 | from e3nn_jax import IrrepsArray
10 |
11 | try:
12 | from e3nn_jax.experimental import linear_shtp as escn
13 | except ImportError:
14 | escn = None
15 | try:
16 | from e3nn_jax import FunctionalFullyConnectedTensorProduct
17 | except ImportError:
18 | from e3nn_jax.legacy import FunctionalFullyConnectedTensorProduct # type: ignore
19 |
20 | from .config import config
21 |
22 | InitFn = Callable[[str, Tuple[int, ...], float, jnp.dtype], jnp.ndarray]
23 | TensorProductFn = Callable[[IrrepsArray, IrrepsArray], IrrepsArray]
24 |
25 |
26 | def uniform_init(
27 | name: str,
28 | path_shape: Tuple[int, ...],
29 | weight_std: float,
30 | dtype: jnp.dtype = config("default_dtype"),
31 | ) -> jnp.ndarray:
32 | return hk.get_parameter(
33 | name,
34 | shape=path_shape,
35 | dtype=dtype,
36 | init=hk.initializers.RandomUniform(minval=-weight_std, maxval=weight_std),
37 | )
38 |
39 |
40 | class TensorProduct(hk.Module, ABC):
41 | """O(3) equivariant linear parametrized tensor product layer."""
42 |
43 | def __init__(
44 | self,
45 | output_irreps: e3nn.Irreps,
46 | *,
47 | biases: bool = True,
48 | name: Optional[str] = None,
49 | init_fn: Optional[InitFn] = None,
50 | gradient_normalization: Optional[Union[str, float]] = None,
51 | path_normalization: Optional[Union[str, float]] = None,
52 | ):
53 | """Initialize the tensor product.
54 |
55 | Args:
56 | output_irreps: Output representation
57 |             biases: If set to true, will add biases
58 | name: Name of the linear layer params
59 | init_fn: Weight initialization function. Default is uniform.
60 | gradient_normalization: Gradient normalization method. Default is "path"
61 | NOTE: gradient_normalization="element" is the default in torch and haiku.
62 | path_normalization: Path normalization method. Default is "element"
63 | """
64 | super().__init__(name=name)
65 |
66 | if not isinstance(output_irreps, e3nn.Irreps):
67 | output_irreps = e3nn.Irreps(output_irreps)
68 | self.output_irreps = output_irreps
69 |
70 | # tp weight init
71 | if not init_fn:
72 | init_fn = uniform_init
73 | self.get_parameter = init_fn
74 |
75 | if not gradient_normalization:
76 | gradient_normalization = config("gradient_normalization")
77 | if not path_normalization:
78 | path_normalization = config("path_normalization")
79 | self._gradient_normalization = gradient_normalization
80 | self._path_normalization = path_normalization
81 |
82 | self.biases = biases and "0e" in self.output_irreps
83 |
84 | def _check_input(
85 | self, x: IrrepsArray, y: Optional[IrrepsArray] = None
86 | ) -> Tuple[IrrepsArray, IrrepsArray]:
87 |         if y is None:
88 | y = IrrepsArray("1x0e", jnp.ones((1, 1), dtype=x.dtype))
89 |
90 | if x.irreps.lmax == 0 and y.irreps.lmax == 0 and self.output_irreps.lmax > 0:
91 | warnings.warn(
92 | f"The specified output irreps ({self.output_irreps}) are not scalars "
93 | "but both operands are. This can have undesired behaviour (NaN). Try "
94 | "redistributing them into scalars or choose higher orders."
95 | )
96 |
97 | return x, y
98 |
99 | @abstractmethod
100 | def __call__(
101 | self, x: IrrepsArray, y: Optional[IrrepsArray] = None, **kwargs
102 | ) -> IrrepsArray:
103 | """Applies an O(3) equivariant linear parametrized tensor product layer.
104 |
105 | Args:
106 | x (IrrepsArray): Left tensor
107 |             y (IrrepsArray): Right tensor. If None it defaults to a unit scalar (1x0e).
108 |
109 | Returns:
110 | The output to the weighted tensor product (IrrepsArray).
111 | """
112 | raise NotImplementedError
113 |
114 |
115 | class O3TensorProduct(TensorProduct):
116 | """O(3) equivariant linear parametrized tensor product layer.
117 |
118 | Original O3TensorProduct version that uses tensor_product + Linear instead of
119 | FullyConnectedTensorProduct.
120 | From e3nn 0.19.2 (https://github.com/e3nn/e3nn-jax/releases/tag/0.19.2), this is
121 | as fast as FullyConnectedTensorProduct.
122 | """
123 |
124 | def __init__(
125 | self,
126 | output_irreps: e3nn.Irreps,
127 | *,
128 | biases: bool = True,
129 | name: Optional[str] = None,
130 | init_fn: Optional[InitFn] = None,
131 | gradient_normalization: Optional[Union[str, float]] = "element",
132 | path_normalization: Optional[Union[str, float]] = None,
133 | ):
134 | super().__init__(
135 | output_irreps,
136 | biases=biases,
137 | name=name,
138 | init_fn=init_fn,
139 | gradient_normalization=gradient_normalization,
140 | path_normalization=path_normalization,
141 | )
142 |
143 | self._linear = e3nn.haiku.Linear(
144 | self.output_irreps,
145 | get_parameter=self.get_parameter,
146 | biases=self.biases,
147 | gradient_normalization=self._gradient_normalization,
148 | path_normalization=self._path_normalization,
149 | )
150 |
151 | def _check_input(
152 | self, x: IrrepsArray, y: Optional[IrrepsArray] = None
153 | ) -> Tuple[IrrepsArray, IrrepsArray]:
154 | x, y = super()._check_input(x, y)
155 | miss = self.output_irreps.filter(drop=e3nn.tensor_product(x.irreps, y.irreps))
156 | if len(miss) > 0:
157 | warnings.warn(f"Output irreps: '{miss}' are unreachable and were ignored.")
158 | return x, y
159 |
160 | def __call__(self, x: IrrepsArray, y: Optional[IrrepsArray] = None) -> IrrepsArray:
161 | x, y = self._check_input(x, y)
162 | # tensor product + linear
163 | tp = self._linear(e3nn.tensor_product(x, y))
164 | return tp
165 |
166 |
167 | class O3TensorProductFC(TensorProduct):
168 | """
169 | O(3) equivariant linear parametrized tensor product layer.
170 |
171 | Functionally the same as O3TensorProduct, but uses FullyConnectedTensorProduct and
172 |     is slightly slower (~5-10%) than tensor_product + Linear.
173 | """
174 |
175 | def _build_tensor_product(
176 | self, left_irreps: e3nn.Irreps, right_irreps: e3nn.Irreps
177 | ) -> Callable:
178 | """Build the tensor product function."""
179 | tp = FunctionalFullyConnectedTensorProduct(
180 | left_irreps,
181 | right_irreps,
182 | self.output_irreps,
183 | gradient_normalization=self._gradient_normalization,
184 | path_normalization=self._path_normalization,
185 | )
186 | ws = [
187 | self.get_parameter(
188 | name=(
189 | f"w[{ins.i_in1},{ins.i_in2},{ins.i_out}] "
190 | f"{tp.irreps_in1[ins.i_in1]},"
191 | f"{tp.irreps_in2[ins.i_in2]},"
192 | f"{tp.irreps_out[ins.i_out]}"
193 | ),
194 | path_shape=ins.path_shape,
195 | weight_std=ins.weight_std,
196 | )
197 | for ins in tp.instructions
198 | ]
199 |
200 | def _tensor_product(x, y, **kwargs):
201 | out = tp.left_right(ws, x, y, **kwargs)
202 | # same as out.rechunk(self.output_irreps) but works with older e3nn versions
203 | return IrrepsArray(self.output_irreps, out.array)
204 |
205 | # naive broadcasting wrapper
206 | # TODO: not the best
207 |         def _tp_wrapper(*args):
208 |             leading_shape = jnp.broadcast_shapes(*(arg.shape[:-1] for arg in args))
209 |             args = [arg.broadcast_to(leading_shape + (-1,)) for arg in args]
210 |             f = _tensor_product
211 |             for _ in range(len(leading_shape)):
212 |                 f = jax.vmap(f)
213 |             return f(*args)
214 | return _tp_wrapper
215 |
216 | def _build_biases(self) -> Callable:
217 | """Build the add bias function."""
218 | b = [
219 | self.get_parameter(
220 | f"b[{i_out}] {self.output_irreps}",
221 | path_shape=(mul_ir.dim,),
222 | weight_std=1 / jnp.sqrt(mul_ir.dim),
223 | )
224 | for i_out, mul_ir in enumerate(self.output_irreps)
225 | if mul_ir.ir.is_scalar()
226 | ]
227 | b = IrrepsArray(f"{self.output_irreps.count('0e')}x0e", jnp.concatenate(b))
228 |
229 | # TODO: could be improved
230 | def _bias_wrapper(x: IrrepsArray) -> IrrepsArray:
231 | scalars = x.filter("0e")
232 | other = x.filter(drop="0e")
233 | return e3nn.concatenate(
234 | [scalars + b.broadcast_to(scalars.shape), other], axis=1
235 | )
236 |
237 | return _bias_wrapper
238 |
239 | def __call__(
240 | self, x: IrrepsArray, y: Optional[IrrepsArray] = None, **kwargs
241 | ) -> IrrepsArray:
242 | x, y = self._check_input(x, y)
243 |
244 | tp = self._build_tensor_product(x.irreps, y.irreps)
245 | output = tp(x, y, **kwargs)
246 |
247 | if self.biases:
248 | # add biases
249 | bias_fn = self._build_biases()
250 | return bias_fn(output)
251 |
252 | return output
253 |
254 |
255 | class O3TensorProductSCN(TensorProduct):
256 | """
257 | O(3) equivariant linear parametrized tensor product layer.
258 |
259 | O3TensorProduct with eSCN optimization for larger spherical harmonic orders. Should
260 | be used without spherical harmonics on the inputs.
261 | """
262 |
263 | def __init__(
264 | self,
265 | output_irreps: e3nn.Irreps,
266 | *,
267 | biases: bool = True,
268 | name: Optional[str] = None,
269 | init_fn: Optional[InitFn] = None,
270 | gradient_normalization: Optional[Union[str, float]] = None,
271 | path_normalization: Optional[Union[str, float]] = None,
272 | ):
273 | super().__init__(
274 | output_irreps,
275 | biases=biases,
276 | name=name,
277 | init_fn=init_fn,
278 | gradient_normalization=gradient_normalization,
279 | path_normalization=path_normalization,
280 | )
281 |
282 | if escn is None:
283 | raise ImportError(
284 | "eSCN is available from e3nn-jax>=0.17.3. "
285 | f"Your version: {e3nn.__version__}"
286 | )
287 |
288 | self._linear = e3nn.haiku.Linear(
289 | self.output_irreps,
290 | get_parameter=self.get_parameter,
291 | biases=self.biases,
292 | gradient_normalization=self._gradient_normalization,
293 | path_normalization=self._path_normalization,
294 | )
295 |
296 | def _check_input(
297 | self, x: IrrepsArray, y: Optional[IrrepsArray] = None
298 | ) -> Tuple[IrrepsArray, IrrepsArray]:
299 |         if y is None:
300 | raise ValueError("eSCN cannot be used without the right input.")
301 | return super()._check_input(x, y)
302 |
303 | def __call__(self, x: IrrepsArray, y: Optional[IrrepsArray] = None) -> IrrepsArray:
304 |         """Apply the layer. y must not be embedded in spherical harmonics."""
305 | x, y = self._check_input(x, y)
306 | # TODO make this work for e3nn-jax<0.19.1 (without e3nn.utils.vmap)
307 | try:
308 | from e3nn_jax.utils import vmap as e3nn_vmap
309 | except ImportError:
310 |             raise NotImplementedError("requires e3nn_jax.utils.vmap (e3nn-jax>=0.19.1)")
311 | shtp = e3nn_vmap(escn.shtp, in_axes=(0, 0, None))
312 | tp = shtp(x, y, self.output_irreps)
313 | return self._linear(tp)
314 |
315 |
316 | O3_LAYERS = {
317 | "tpl": O3TensorProduct,
318 | "fctp": O3TensorProductFC,
319 | "scn": O3TensorProductSCN,
320 | }
321 |
322 |
323 | def O3TensorProductGate(
324 | output_irreps: e3nn.Irreps,
325 | *,
326 | biases: bool = True,
327 | scalar_activation: Optional[Callable] = None,
328 | gate_activation: Optional[Callable] = None,
329 | name: Optional[str] = None,
330 | init_fn: Optional[InitFn] = None,
331 | o3_layer: Optional[Union[str, TensorProduct]] = None,
332 | ) -> TensorProductFn:
333 | """Non-linear (gated) O(3) equivariant linear tensor product layer.
334 |
335 | The tensor product lifts the input representation to have gating scalars.
336 |
337 | Args:
338 | output_irreps: Output representation
339 | biases: Add biases
340 | scalar_activation: Activation function for scalars
341 |         gate_activation: Activation function for the higher-order gates
342 | name: Name of the linear layer params
343 | o3_layer: Tensor product layer type. "tpl", "fctp", "scn" or a custom layer
344 |
345 | Returns:
346 | Function that applies the gated tensor product layer
347 | """
348 |
349 | if not isinstance(output_irreps, e3nn.Irreps):
350 | output_irreps = e3nn.Irreps(output_irreps)
351 |
352 | # lift output with gating scalars
353 | gate_irreps = e3nn.Irreps(
354 | f"{output_irreps.num_irreps - output_irreps.count('0e')}x0e"
355 | )
356 |
357 | if o3_layer is None:
358 | o3_layer = config("o3_layer")
359 |
360 | if isinstance(o3_layer, str):
361 | assert o3_layer in O3_LAYERS, f"Unknown O3 layer {o3_layer}."
362 | O3Layer = O3_LAYERS[o3_layer]
363 | else:
364 | O3Layer = o3_layer
365 |
366 | tensor_product = O3Layer(
367 | (gate_irreps + output_irreps).regroup(),
368 | biases=biases,
369 | name=name,
370 | init_fn=init_fn,
371 | )
372 |
373 | if not scalar_activation:
374 | scalar_activation = jax.nn.silu
375 | if not gate_activation:
376 | gate_activation = jax.nn.sigmoid
377 |
378 | def _gated_tensor_product(
379 | x: IrrepsArray, y: Optional[IrrepsArray] = None, **kwargs
380 | ) -> IrrepsArray:
381 | tp = tensor_product(x, y, **kwargs)
382 | # skip gate if the gating scalars are not reachable
383 | if len(gate_irreps.filter(drop=tp.irreps)) > 0:
384 | return tp
385 | else:
386 | return e3nn.gate(tp, scalar_activation, odd_gate_act=gate_activation)
387 |
388 | return _gated_tensor_product
389 |
--------------------------------------------------------------------------------
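The blocks above are haiku modules as well, so they must be called inside hk.transform. A minimal sketch of the gated layer, with illustrative irreps and batch shapes:

    import e3nn_jax as e3nn
    import haiku as hk
    import jax

    def forward(x, y):
        return O3TensorProductGate(e3nn.Irreps("8x0e + 4x1o"), name="gated_tp")(x, y)

    model = hk.without_apply_rng(hk.transform(forward))
    x = e3nn.normal(e3nn.Irreps("8x0e + 4x1o"), jax.random.PRNGKey(0), (16,))
    y = e3nn.normal(e3nn.Irreps("1x0e + 1x1o"), jax.random.PRNGKey(1), (16,))
    params = model.init(jax.random.PRNGKey(2), x, y)
    out = model.apply(params, x, y)  # IrrepsArray with irreps "8x0e + 4x1o"
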