├── nbody
│   ├── __init__.py
│   ├── requirements.txt
│   ├── utils.py
│   ├── datasets.py
│   └── data
│       ├── generate_dataset.py
│       └── synthetic_sim.py
├── setup.py
├── requirements.txt
├── .gitignore
├── egnn_jax
│   ├── __init__.py
│   ├── utils.py
│   ├── variants.py
│   └── egnn.py
├── pyproject.toml
├── setup.cfg
├── .pre-commit-config.yaml
├── LICENSE
├── .github
│   └── workflows
│       └── build_branch.yaml
├── README.md
└── validate.py

/nbody/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup

if __name__ == "__main__":
    setup()
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
dm-haiku==0.0.9
jax[cuda]==0.4.10
jraph==0.0.6.dev0
numpy>=1.23.4
optax==0.1.5
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# experiments
*.npy
wandb/
*.out
datasets/

# dev
.vscode
__pycache__/
*.pyc
venv

# build
*.whl
build/
dist/
*.egg-info/
--------------------------------------------------------------------------------
/nbody/requirements.txt:
--------------------------------------------------------------------------------
--extra-index-url https://download.pytorch.org/whl/cpu

--find-links https://data.pyg.org/whl/torch-1.13.1+cpu.html
matplotlib>=3.4.2

torch==1.13.1
tqdm>=4.64.1
wandb>=0.13.5
--------------------------------------------------------------------------------
/egnn_jax/__init__.py:
--------------------------------------------------------------------------------
from .egnn import EGNN, EGNNLayer
from .variants import EGNN_vel, EGNNLayer_vel

__all__ = [
    "EGNN",
    "EGNNLayer",
    "EGNN_vel",
    "EGNNLayer_vel",
]

__version__ = "0.3"
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"


[tool.ruff]
ignore = ["F811", "E731"]
exclude = [
    ".git",
    "venv",
]
line-length = 88
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
[metadata]
name = egnn_jax
version = attr: egnn_jax.__version__
author = Gianluca Galletti
author_email = g.galletti@tum.de
description = E(3) GNN in jax
long_description = file: README.md
long_description_content_type = text/markdown
classifiers =
    Programming Language :: Python :: 3.8


[options]
packages = egnn_jax
python_requires = >=3.8
install_requires =
    dm_haiku==0.0.9
    jax==0.4.10
    jaxlib==0.4.10
    jraph==0.0.6.dev0
    numpy>=1.23.4
    optax==0.1.3
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
default_language_version:
  python: python3
exclude: |
  (?x)^(
      datasets/|
      data/|
  )$
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.0.1
    hooks:
      - id: check-merge-conflict
      - id: check-added-large-files
      - id: check-docstring-first
      - id: check-json
      - id: check-toml
      - id: check-xml
      - id: check-yaml
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: requirements-txt-fixer
  - repo: https://github.com/pycqa/isort
    rev: 5.12.0
    hooks:
      - id: isort
        args: [ --profile, black ]
  - repo: https://github.com/ambv/black
    rev: 22.3.0
    hooks:
      - id: black
  - repo: https://github.com/charliermarsh/ruff-pre-commit
    rev: 'v0.0.265'
    hooks:
      - id: ruff
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Gianluca Galletti

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/.github/workflows/build_branch.yaml:
--------------------------------------------------------------------------------
name: Build

on: push

jobs:
  # tests:
  #   name: Tests
  #   runs-on: ubuntu-latest
  #   steps:
  #     - uses: actions/checkout@master
  #     - name: Set up Python
  #       uses: actions/setup-python@v3
  #       with:
  #         python-version: "3.10"
  #     - name: Install dependencies
  #       run: >-
  #         python -m
  #         pip install
  #         -r requirements.txt
  #         --user
  #     - name: Install pytest
  #       run: >-
  #         python -m
  #         pip install
  #         pytest
  #         --user
  #     - name: Run tests
  #       run: >-
  #         python -m pytest tests/

  build-publish:
    name: Build and publish
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@master
      - name: Set up Python
        uses: actions/setup-python@v3
        with:
          python-version: "3.10"
      - name: Install build tools
        run: >-
          python -m
          pip install
          build
          --user
      - name: Build wheel
        run: >-
          python -m
          build
          --sdist
          --wheel
          --outdir dist/
      - name: Publish to PyPI
        if: startsWith(github.ref, 'refs/tags')
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          password: ${{ secrets.PYPI_TOKEN }}
--------------------------------------------------------------------------------
/egnn_jax/utils.py:
--------------------------------------------------------------------------------
from typing import Callable, Iterable, Optional

import haiku as hk
import jax
import jax.numpy as jnp


class LinearXav(hk.Linear):
    """Linear layer with Xavier init, to avoid passing 'w_init' everywhere."""

    def __init__(
        self,
        output_size: int,
        with_bias: bool = True,
        w_init: Optional[hk.initializers.Initializer] = None,
        b_init: Optional[hk.initializers.Initializer] = None,
        name: Optional[str] = None,
    ):
        if w_init is None:
            w_init = hk.initializers.VarianceScaling(1.0, "fan_avg", "uniform")
        super().__init__(output_size, with_bias, w_init, b_init, name)


class MLPXav(hk.nets.MLP):
    """MLP with Xavier init, to avoid passing 'w_init' everywhere."""

    def __init__(
        self,
        output_sizes: Iterable[int],
        with_bias: bool = True,
        w_init: Optional[hk.initializers.Initializer] = None,
        b_init: Optional[hk.initializers.Initializer] = None,
        # NOTE: default to relu; hk.nets.MLP would crash when called with activation=None
        activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.relu,
        activate_final: bool = False,
        name: Optional[str] = None,
    ):
        if w_init is None:
            w_init = hk.initializers.VarianceScaling(1.0, "fan_avg", "uniform")
        if not with_bias:
            b_init = None
        super().__init__(
            output_sizes,
            w_init,
            b_init,
            with_bias,
            activation,
            activate_final,
            name,
        )
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# E(n) Equivariant GNN in jax
Reimplementation of [EGNN](https://arxiv.org/abs/2102.09844) in jax. Original work by Victor Garcia Satorras, Emiel Hogeboom and Max Welling.
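
## Usage
The models are haiku modules, so they have to be wrapped with `hk.transform`. A minimal sketch (the toy graph, sizes and random keys below are made up for illustration):
```
import haiku as hk
import jax
import jax.numpy as jnp
import jraph

from egnn_jax import EGNN

# toy fully connected graph with 3 nodes and 8 node features
graph = jraph.GraphsTuple(
    nodes=jnp.ones((3, 8)),
    edges=None,
    senders=jnp.array([0, 0, 1, 1, 2, 2]),
    receivers=jnp.array([1, 2, 0, 2, 0, 1]),
    n_node=jnp.array([3]),
    n_edge=jnp.array([6]),
    globals=None,
)
pos = jax.random.normal(jax.random.PRNGKey(1), (3, 3))


@hk.without_apply_rng
@hk.transform
def egnn(graph, pos):
    return EGNN(hidden_size=32, output_size=8)(graph, pos)


params = egnn.init(jax.random.PRNGKey(0), graph, pos)
h, pos = egnn.apply(params, graph, pos)  # updated node features and positions
```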

## Installation
```
python -m pip install egnn-jax
```

Or clone this repository and build locally
```
git clone https://github.com/gerkone/egnn-jax
cd egnn-jax
python -m pip install -e .
```

### GPU support
Upgrade `jax` to the GPU version
```
pip install --upgrade "jax[cuda]==0.4.10" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

## Validation
The charged N-body experiment from the original paper is included for validation. Times are __model only__ on batches of 100 graphs, in (global) single precision.

|                  | MSE   | Inference [ms]* |
|------------------|-------|-----------------|
| [torch (original)](https://github.com/vgsatorras/egnn) | .0071 | 8.27 |
| jax (ours)       | .0084 | 0.94 |

\* remeasured (Quadro RTX 4000)

### Validation install

The N-body experiments are included only in the GitHub repo, so the repository needs to be cloned first.
```
git clone https://github.com/gerkone/egnn-jax
```

They are adapted from the original implementation, so `torch` and `torch_geometric` are additionally needed (CPU versions are enough).
```
python -m pip install -r nbody/requirements.txt
```

### Validation usage
The charged N-body dataset has to be generated locally from the directory [/nbody/data](/nbody/data).
```
python -u generate_dataset.py --num-train 3000 --seed 43 --sufix small
```
Then the model can be trained and evaluated with
```
python validate.py --epochs=1000 --batch-size=100 --lr=1e-4 --weight-decay=1e-12
```

## Acknowledgements
This implementation heavily borrows from the [original pytorch code](https://github.com/vgsatorras/egnn).
--------------------------------------------------------------------------------
/egnn_jax/variants.py:
--------------------------------------------------------------------------------
from typing import Callable, Optional, Tuple

import haiku as hk
import jax
import jax.numpy as jnp
import jraph
from haiku.initializers import UniformScaling

from egnn_jax import EGNN, EGNNLayer
from egnn_jax.utils import LinearXav


class EGNNLayer_vel(EGNNLayer):
    def __init__(
        self,
        hidden_size: int,
        *args,
        act_fn: Callable = jax.nn.relu,
        dt: float = 0.001,
        **kwargs,
    ):
        super().__init__(
            *args,
            hidden_size=hidden_size,
            act_fn=act_fn,
            dt=dt,
            **kwargs,
        )
        # velocity integrator network
        net = [LinearXav(hidden_size), act_fn]
        net += [LinearXav(1, with_bias=False, w_init=UniformScaling(dt))]
        self._vel_correction_mlp = hk.Sequential(net)

    def __call__(
        self,
        graph: jraph.GraphsTuple,
        pos: jnp.ndarray,
        vel: jnp.ndarray,
        edge_attribute: Optional[jnp.ndarray] = None,
        node_attribute: Optional[jnp.ndarray] = None,
    ) -> Tuple[jraph.GraphsTuple, jnp.ndarray]:
        graph, pos = super().__call__(graph, pos, edge_attribute, node_attribute)
        shift = self._vel_correction_mlp(graph.nodes) * vel
        pos = pos + jnp.clip(shift, -100, 100)
        return graph, pos


class EGNN_vel(EGNN):
    def __call__(
        self,
        graph: jraph.GraphsTuple,
        pos: jnp.ndarray,
        vel: jnp.ndarray,
        edge_attribute: Optional[jnp.ndarray] = None,
        node_attribute: Optional[jnp.ndarray] = None,
    ) -> jnp.ndarray:
        # input node embedding
        h = LinearXav(self._hidden_size, name="embedding")(graph.nodes)
        graph = graph._replace(nodes=h)
        # message passing
        for n in range(self._num_layers):
            graph, pos = EGNNLayer_vel(
                layer_num=n,
                hidden_size=self._hidden_size,
                output_size=self._hidden_size,
                act_fn=self._act_fn,
                residual=self._residual,
                attention=self._attention,
                normalize=self._normalize,
                tanh=self._tanh,
            )(graph, pos, vel, edge_attribute, node_attribute)
        return pos
--------------------------------------------------------------------------------
/nbody/utils.py:
--------------------------------------------------------------------------------
from typing import Callable, Dict, List, Tuple

import jax.numpy as jnp
import jraph
import numpy as np
from torch.utils.data import DataLoader

from .datasets import NBodyDataset


def NbodyGraphTransform(
    n_nodes: int,
    batch_size: int,
) -> Callable:
    """
    Build a function that converts a batch from the torch DataLoader into a
    jraph.GraphsTuple.
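
    Args:
        n_nodes: Number of nodes per graph (5 bodies in the charged system).
        batch_size: Maximum number of graphs per batch.

    Returns:
        Callable that maps a dataloader batch to (graph, props, target).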
    """

    # charged system is a connected graph
    full_edge_indices = jnp.array(
        [
            (i + n_nodes * b, j + n_nodes * b)
            for b in range(batch_size)
            for i in range(n_nodes)
            for j in range(n_nodes)
            if i != j
        ]
    ).T

    def _to_jraph(
        data: List,
    ) -> Tuple[jraph.GraphsTuple, Dict[str, jnp.ndarray], jnp.ndarray]:
        props = {}
        pos, vel, edge_attribute, _, targets = data

        cur_batch = int(pos.shape[0] / n_nodes)

        edge_indices = full_edge_indices[:, : n_nodes * (n_nodes - 1) * cur_batch]
        senders, receivers = edge_indices[0], edge_indices[1]

        # squared relative distances between particles
        pos_dist = jnp.sum((pos[senders] - pos[receivers]) ** 2, axis=-1)[:, None]
        props["edge_attribute"] = jnp.concatenate([edge_attribute, pos_dist], axis=-1)
        props["pos"] = pos
        props["vel"] = vel

        graph = jraph.GraphsTuple(
            # velocity magnitude as node features (scalar)
            nodes=jnp.sqrt(jnp.sum(vel**2, axis=-1))[:, None],
            edges=None,
            senders=senders,
            receivers=receivers,
            n_node=jnp.array([n_nodes] * cur_batch),
            n_edge=jnp.array([len(senders) // cur_batch] * cur_batch),
            globals=None,
        )

        return (
            graph,
            props,
            targets,
        )

    return _to_jraph


def numpy_collate(batch):
    # recursively stack torch dataset samples into plain numpy arrays (jax-friendly)
    if isinstance(batch[0], (tuple, list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.vstack(batch)


def setup_nbody_data(args) -> Tuple[DataLoader, DataLoader, DataLoader, Callable]:
    dataset_train = NBodyDataset(partition="train", max_samples=args.max_samples)
    dataset_val = NBodyDataset(partition="val")
    dataset_test = NBodyDataset(partition="test")

    graph_transform = NbodyGraphTransform(n_nodes=5, batch_size=args.batch_size)

    loader_train = DataLoader(
        dataset_train,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        collate_fn=numpy_collate,
    )
    loader_val = DataLoader(
        dataset_val,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
        collate_fn=numpy_collate,
    )
    loader_test = DataLoader(
        dataset_test,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
        collate_fn=numpy_collate,
    )

    return loader_train, loader_val, loader_test, graph_transform
--------------------------------------------------------------------------------
/nbody/datasets.py:
--------------------------------------------------------------------------------
import numpy as np
import torch


class NBodyDataset:
    def __init__(self, partition="train", max_samples=1e8, dataset_name="nbody_small"):
        self.partition = partition
        if self.partition == "val":
            self.sufix = "valid"
        else:
            self.sufix = self.partition
        self.dataset_name = dataset_name
        if dataset_name == "nbody":
            self.sufix += "_charged5_initvel1"
        elif dataset_name == "nbody_small" or dataset_name == "nbody_small_out_dist":
            self.sufix += "_charged5_initvel1small"
        else:
            raise Exception("Wrong dataset name %s" % self.dataset_name)

        self.max_samples = int(max_samples)
        self.dataset_name = dataset_name
        self.data, self.edges = self.load()

    def load(self):
        loc = np.load("nbody/data/loc_" + self.sufix + ".npy")
        vel = np.load("nbody/data/vel_" + self.sufix + ".npy")
        edges = np.load("nbody/data/edges_" + self.sufix + ".npy")
        charges = np.load("nbody/data/charges_" + self.sufix + ".npy")

        loc, vel, edge_attr, edges, charges = self.preprocess(loc, vel, edges, charges)
        return (loc, vel, edge_attr, charges), edges

    def preprocess(self, loc, vel, edges, charges):
        # cast to torch and swap n_nodes <--> n_features dimensions
        loc, vel = torch.Tensor(loc).transpose(2, 3), torch.Tensor(vel).transpose(2, 3)
        n_nodes = loc.size(2)
        loc = loc[0 : self.max_samples, :, :, :]  # limit number of samples
        vel = vel[0 : self.max_samples, :, :, :]  # speed when starting the trajectory
        charges = charges[0 : self.max_samples]
        edge_attr = []

        # Initialize edges and edge_attributes
        rows, cols = [], []
        for i in range(n_nodes):
            for j in range(n_nodes):
                if i != j:
                    edge_attr.append(edges[:, i, j])
                    rows.append(i)
                    cols.append(j)
        edges = [rows, cols]
        # swap n_nodes <--> batch_size and add nf dimension
        edge_attr = np.array(edge_attr)
        edge_attr = torch.Tensor(edge_attr).transpose(0, 1).unsqueeze(2)

        return (
            torch.Tensor(loc),
            torch.Tensor(vel),
            torch.Tensor(edge_attr),
            edges,
            torch.Tensor(charges),
        )

    def set_max_samples(self, max_samples):
        self.max_samples = int(max_samples)
        self.data, self.edges = self.load()

    def get_n_nodes(self):
        return self.data[0].size(1)

    def __getitem__(self, i):
        loc, vel, edge_attr, charges = self.data
        loc, vel, edge_attr, charges = loc[i], vel[i], edge_attr[i], charges[i]

        if self.dataset_name == "nbody":
            frame_0, frame_T = 6, 8
        elif self.dataset_name == "nbody_small":
            frame_0, frame_T = 30, 40
        elif self.dataset_name == "nbody_small_out_dist":
            frame_0, frame_T = 20, 30
        else:
            raise Exception("Wrong dataset name %s" % self.dataset_name)

        return loc[frame_0], vel[frame_0], edge_attr, charges, loc[frame_T]

    def __len__(self):
        return len(self.data[0])

    def get_edges(self, batch_size, n_nodes):
        edges = [torch.LongTensor(self.edges[0]), torch.LongTensor(self.edges[1])]
        if batch_size == 1:
            return edges
        elif batch_size > 1:
            rows, cols = [], []
            for i in range(batch_size):
                rows.append(edges[0] + n_nodes * i)
                cols.append(edges[1] + n_nodes * i)
            edges = [torch.cat(rows), torch.cat(cols)]
        return edges
--------------------------------------------------------------------------------
/nbody/data/generate_dataset.py:
--------------------------------------------------------------------------------
"""
nbody: python -u generate_dataset.py --num-train 50000 --sample-freq 500
nbody_small: python -u generate_dataset.py --num-train 3000 --seed 43 --sufix small
"""

import argparse
import time

import numpy as np
from synthetic_sim import ChargedParticlesSim

parser = argparse.ArgumentParser()
parser.add_argument(
    "--simulation", type=str, default="charged", help="What simulation to generate."
)
parser.add_argument(
    "--num-train",
    type=int,
    default=10000,
    help="Number of training simulations to generate.",
)
parser.add_argument(
    "--num-valid",
    type=int,
    default=2000,
    help="Number of validation simulations to generate.",
)
parser.add_argument(
    "--num-test", type=int, default=2000, help="Number of test simulations to generate."
)
parser.add_argument("--length", type=int, default=5000, help="Length of trajectory.")
parser.add_argument(
    "--length_test", type=int, default=5000, help="Length of test set trajectory."
)
parser.add_argument(
    "--sample-freq", type=int, default=100, help="How often to sample the trajectory."
)
parser.add_argument(
    "--n_balls", type=int, default=5, help="Number of balls in the simulation."
)
parser.add_argument("--seed", type=int, default=42, help="Random seed.")
parser.add_argument(
    "--initial_vel", type=int, default=1, help="Consider initial velocity."
)
parser.add_argument("--sufix", type=str, default="", help="Add a suffix to the name.")

args = parser.parse_args()

initial_vel_norm = 0.5
if not args.initial_vel:
    initial_vel_norm = 1e-16

sim = ChargedParticlesSim(
    noise_var=0.0, n_balls=args.n_balls, vel_norm=initial_vel_norm
)

suffix = "_charged" + str(args.n_balls) + "_initvel%d" % args.initial_vel + args.sufix
np.random.seed(args.seed)

print(suffix)


def generate_dataset(num_sims, length, sample_freq):
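    """Simulate num_sims trajectories and stack them into arrays.

    Returns:
        loc, vel: (num_sims, T_save, 3, n_balls) positions and velocities.
        edges: (num_sims, n_balls, n_balls) pairwise charge products.
        charges: (num_sims, n_balls, 1) sampled charges.
    """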
    loc_all = list()
    vel_all = list()
    edges_all = list()
    charges_all = list()
    for i in range(num_sims):
        t = time.time()
        loc, vel, edges, charges = sim.sample_trajectory(
            T=length, sample_freq=sample_freq
        )
        if i % 100 == 0:
            print("Iter: {}, Simulation time: {}".format(i, time.time() - t))
        loc_all.append(loc)
        vel_all.append(vel)
        edges_all.append(edges)
        charges_all.append(charges)

    charges_all = np.stack(charges_all)
    loc_all = np.stack(loc_all)
    vel_all = np.stack(vel_all)
    edges_all = np.stack(edges_all)

    return loc_all, vel_all, edges_all, charges_all


if __name__ == "__main__":
    print("Generating {} training simulations".format(args.num_train))
    loc_train, vel_train, edges_train, charges_train = generate_dataset(
        args.num_train, args.length, args.sample_freq
    )

    print("Generating {} validation simulations".format(args.num_valid))
    loc_valid, vel_valid, edges_valid, charges_valid = generate_dataset(
        args.num_valid, args.length, args.sample_freq
    )

    print("Generating {} test simulations".format(args.num_test))
    loc_test, vel_test, edges_test, charges_test = generate_dataset(
        args.num_test, args.length_test, args.sample_freq
    )

    np.save("loc_train" + suffix + ".npy", loc_train)
    np.save("vel_train" + suffix + ".npy", vel_train)
    np.save("edges_train" + suffix + ".npy", edges_train)
    np.save("charges_train" + suffix + ".npy", charges_train)

    np.save("loc_valid" + suffix + ".npy", loc_valid)
    np.save("vel_valid" + suffix + ".npy", vel_valid)
    np.save("edges_valid" + suffix + ".npy", edges_valid)
    np.save("charges_valid" + suffix + ".npy", charges_valid)

    np.save("loc_test" + suffix + ".npy", loc_test)
    np.save("vel_test" + suffix + ".npy", vel_test)
    np.save("edges_test" + suffix + ".npy", edges_test)
    np.save("charges_test" + suffix + ".npy", charges_test)
--------------------------------------------------------------------------------
/validate.py:
--------------------------------------------------------------------------------
import argparse
import time
from functools import partial
from typing import Callable, Dict, Iterable, Tuple

import haiku as hk
import jax
import jax.numpy as jnp
import jraph
import optax

from egnn_jax.variants import EGNN_vel
from nbody.utils import setup_nbody_data

key = jax.random.PRNGKey(1337)


@partial(jax.jit, static_argnames=["model_fn"])
def mse(
    params: hk.Params,
    graph: jraph.GraphsTuple,
    props: Dict[str, jnp.ndarray],
    target: jnp.ndarray,
    model_fn: Callable,
) -> float:
    pred = model_fn(
        params,
        graph,
        **props,
    )
    assert target.shape == pred.shape
    return (jnp.power(pred - target, 2)).mean()


@partial(jax.jit, static_argnames=["loss_fn", "opt_update"])
def update(
    params: hk.Params,
    graph: jraph.GraphsTuple,
    props: Dict[str, jnp.ndarray],
    target: jnp.ndarray,
    opt_state: optax.OptState,
    loss_fn: Callable,
    opt_update: Callable,
) -> Tuple[float, hk.Params, optax.OptState]:
    loss, grads = jax.value_and_grad(loss_fn)(params, graph, props, target)
    updates, opt_state = opt_update(grads, opt_state, params)
    return loss, optax.apply_updates(params, updates), opt_state


def evaluate(
    loader: Iterable,
    params: hk.Params,
    loss_fn: Callable,
    graph_transform: Callable,
) -> Tuple[float, float]:
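    """Run the loss over a full dataloader.

    Returns:
        Average model-only inference time per batch (ms) and average loss.
    """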
    eval_loss = 0.0
    eval_times = 0.0
    for data in loader:
        graph, props, target = graph_transform(data)
        eval_start = time.perf_counter_ns()
        loss = jax.lax.stop_gradient(loss_fn(params, graph, props, target))
        eval_loss += jax.block_until_ready(loss)
        eval_times += (time.perf_counter_ns() - eval_start) / 1e6

    return eval_times / len(loader), eval_loss / len(loader)


def train(egnn, loader_train, loader_val, loader_test, graph_transform, args):
    init_graph, init_props, _ = graph_transform(next(iter(loader_train)))
    params = egnn.init(key, init_graph, **init_props)

    n_params = hk.data_structures.tree_size(params)
    print(f"Starting {args.epochs} epochs with {n_params} parameters.")

    opt_init, opt_update = optax.adamw(
        learning_rate=args.lr, weight_decay=args.weight_decay
    )

    loss_fn = partial(mse, model_fn=egnn.apply)
    update_fn = partial(update, loss_fn=loss_fn, opt_update=opt_update)
    eval_fn = partial(evaluate, loss_fn=loss_fn, graph_transform=graph_transform)

    opt_state = opt_init(params)
    avg_time = []
    best_val = 1e10

    for e in range(args.epochs):
        train_loss = 0.0
        train_start = time.perf_counter_ns()
        for data in loader_train:
            graph, props, target = graph_transform(data)
            loss, params, opt_state = update_fn(
                params=params,
                graph=graph,
                props=props,
                target=target,
                opt_state=opt_state,
            )
            train_loss += loss
        train_time = (time.perf_counter_ns() - train_start) / 1e6
        train_loss /= len(loader_train)
        print(
            f"[Epoch {e+1:>4}] train loss {train_loss:.6f}, epoch {train_time:.2f}ms",
            end="",
        )
        if e % args.val_freq == 0:
            eval_time, val_loss = eval_fn(loader_val, params)
            avg_time.append(eval_time)
            tag = ""
            if val_loss < best_val:
                best_val = val_loss
                _, test_loss_ckp = eval_fn(loader_test, params)
                tag = " (BEST)"
            print(f" - val loss {val_loss:.6f}{tag}, infer {eval_time:.2f}ms", end="")

        print()

    _, test_loss = eval_fn(loader_test, params)
    # ignore compilation time
    avg_time = avg_time[2:]
    avg_time = sum(avg_time) / len(avg_time)
    print(
        "Training done.\n"
        f"Final test loss {test_loss:.6f} - checkpoint test loss {test_loss_ckp:.6f}.\n"
        f"Average (model) eval time {avg_time:.2f}ms"
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # model options
    parser.add_argument("--hidden-size", type=int, default=64)
    parser.add_argument("--num-layers", type=int, default=4)

    # data options
    parser.add_argument(
        "--dataset-name", type=str, default="small", choices=["small", "default"]
    )
    parser.add_argument("--max-samples", type=int, default=3000)
    parser.add_argument("--neighbours", type=int, default=5)
    parser.add_argument("--target", type=str, default="pos", choices=["pos", "force"])

    # training options
    parser.add_argument("--batch-size", type=int, default=100)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--weight-decay", type=float, default=1e-12)
    parser.add_argument("--val-freq", type=int, default=10)

    args = parser.parse_args()

    loader_train, loader_val, loader_test, graph_transform = setup_nbody_data(args)

    egnn = lambda graph, pos, vel, edge_attribute: EGNN_vel(
        hidden_size=args.hidden_size,
        output_size=args.hidden_size,
        num_layers=args.num_layers,
        residual=True,
    )(graph, pos, vel, edge_attribute)

    egnn = hk.without_apply_rng(hk.transform(egnn))

    train(egnn, loader_train, loader_val, loader_test, graph_transform, args)
--------------------------------------------------------------------------------
/nbody/data/synthetic_sim.py:
--------------------------------------------------------------------------------
import numpy as np


class ChargedParticlesSim(object):
    def __init__(
        self,
        n_balls=5,
        box_size=5.0,
        loc_std=1.0,
        vel_norm=0.5,
        interaction_strength=1.0,
        noise_var=0.0,
    ):
        self.n_balls = n_balls
        self.box_size = box_size
        self.loc_std = loc_std * (float(n_balls) / 5.0) ** (1 / 3)
        print(self.loc_std)
        self.vel_norm = vel_norm
        self.interaction_strength = interaction_strength
        self.noise_var = noise_var

        self._charge_types = np.array([-1.0, 0.0, 1.0])
        self._delta_T = 0.001
        self._max_F = 0.1 / self._delta_T
        self.dim = 3

    def _l2(self, A, B):
        """
        Input: A is a Nxd matrix
               B is a Mxd matrix
        Output: dist is a NxM matrix where dist[i,j] is the squared norm
            between A[i,:] and B[j,:], i.e. dist[i,j] = ||A[i,:]-B[j,:]||^2
        """
        A_norm = (A**2).sum(axis=1).reshape(A.shape[0], 1)
        B_norm = (B**2).sum(axis=1).reshape(1, B.shape[0])
        dist = A_norm + B_norm - 2 * A.dot(B.transpose())
        return dist

    def _energy(self, loc, vel, edges):
        # disables division by zero warning, since I fix it with fill_diagonal
        with np.errstate(divide="ignore"):
            K = 0.5 * (vel**2).sum()
            U = 0
            for i in range(loc.shape[1]):
                for j in range(loc.shape[1]):
                    if i != j:
                        r = loc[:, i] - loc[:, j]
                        dist = np.sqrt((r**2).sum())
                        U += 0.5 * self.interaction_strength * edges[i, j] / dist
            return U + K

    def _clamp(self, loc, vel):
        """
        :param loc: dim x N location at one time stamp
        :param vel: dim x N velocity at one time stamp
        :return: location and velocity after hitting the walls and elastically
            bouncing back
        """
        assert np.all(loc < self.box_size * 3)
        assert np.all(loc > -self.box_size * 3)

        over = loc > self.box_size
        loc[over] = 2 * self.box_size - loc[over]
        assert np.all(loc <= self.box_size)

        # assert(np.all(vel[over]>0))
        vel[over] = -np.abs(vel[over])

        under = loc < -self.box_size
        loc[under] = -2 * self.box_size - loc[under]
        # assert (np.all(vel[under] < 0))
        assert np.all(loc >= -self.box_size)
        vel[under] = np.abs(vel[under])

        return loc, vel

    def sample_trajectory(
        self, T=10000, sample_freq=10, charge_prob=[1.0 / 2, 0, 1.0 / 2]
    ):
        n = self.n_balls
        assert T % sample_freq == 0
        T_save = int(T / sample_freq - 1)
        diag_mask = np.ones((n, n), dtype=bool)
        np.fill_diagonal(diag_mask, 0)
        counter = 0
        # Sample charges and the resulting edges (pairwise charge products)
        charges = np.random.choice(
            self._charge_types, size=(self.n_balls, 1), p=charge_prob
        )
        edges = charges.dot(charges.transpose())
        # Initialize location and velocity
        loc = np.zeros((T_save, self.dim, n))
        vel = np.zeros((T_save, self.dim, n))
        loc_next = np.random.randn(self.dim, n) * self.loc_std
        vel_next = np.random.randn(self.dim, n)
        v_norm = np.sqrt((vel_next**2).sum(axis=0)).reshape(1, -1)
        vel_next = vel_next * self.vel_norm / v_norm
        loc[0, :, :], vel[0, :, :] = self._clamp(loc_next, vel_next)

        # disables division by zero warning, since I fix it with fill_diagonal
        with np.errstate(divide="ignore"):
            # half step leapfrog
            l2_dist_power3 = np.power(
                self._l2(loc_next.transpose(), loc_next.transpose()), 3.0 / 2.0
            )

            # size of forces up to a 1/|r| factor
            # since I later multiply by an unnormalized r vector
            forces_size = self.interaction_strength * edges / l2_dist_power3
            np.fill_diagonal(
                forces_size, 0
            )  # self forces are zero (fixes division by zero)
            assert np.abs(forces_size[diag_mask]).min() > 1e-10
            F = (
                forces_size.reshape(1, n, n)
                * np.concatenate(
                    (
                        np.subtract.outer(loc_next[0, :], loc_next[0, :]).reshape(
                            1, n, n
                        ),
                        np.subtract.outer(loc_next[1, :], loc_next[1, :]).reshape(
                            1, n, n
                        ),
                        np.subtract.outer(loc_next[2, :], loc_next[2, :]).reshape(
                            1, n, n
                        ),
                    )
                )
            ).sum(axis=-1)
            F[F > self._max_F] = self._max_F
            F[F < -self._max_F] = -self._max_F

            vel_next += self._delta_T * F
            # run leapfrog
            for i in range(1, T):
                loc_next += self._delta_T * vel_next
                # loc_next, vel_next = self._clamp(loc_next, vel_next)

                if i % sample_freq == 0:
                    loc[counter, :, :], vel[counter, :, :] = loc_next, vel_next
                    counter += 1

                l2_dist_power3 = np.power(
                    self._l2(loc_next.transpose(), loc_next.transpose()), 3.0 / 2.0
                )
                forces_size = self.interaction_strength * edges / l2_dist_power3
                np.fill_diagonal(forces_size, 0)
                # assert (np.abs(forces_size[diag_mask]).min() > 1e-10)

                F = (
                    forces_size.reshape(1, n, n)
                    * np.concatenate(
                        (
                            np.subtract.outer(loc_next[0, :], loc_next[0, :]).reshape(
                                1, n, n
                            ),
                            np.subtract.outer(loc_next[1, :], loc_next[1, :]).reshape(
                                1, n, n
                            ),
                            np.subtract.outer(loc_next[2, :], loc_next[2, :]).reshape(
                                1, n, n
                            ),
                        )
                    )
                ).sum(axis=-1)
                F[F > self._max_F] = self._max_F
                F[F < -self._max_F] = -self._max_F
                vel_next += self._delta_T * F
            # Add noise to observations
            loc += np.random.randn(T_save, self.dim, self.n_balls) * self.noise_var
            vel += np.random.randn(T_save, self.dim, self.n_balls) * self.noise_var
            return loc, vel, edges, charges
--------------------------------------------------------------------------------
/egnn_jax/egnn.py:
--------------------------------------------------------------------------------
from typing import Any, Callable, Optional, Tuple

import haiku as hk
import jax
import jax.numpy as jnp
import jraph
from haiku.initializers import UniformScaling
from jax.tree_util import Partial

from egnn_jax.utils import LinearXav, MLPXav


class EGNNLayer(hk.Module):
    """EGNN layer.

    Args:
        layer_num: layer number
        hidden_size: hidden size
        output_size: output size
        blocks: number of blocks in the node and edge MLPs
        act_fn: activation function
        pos_aggregate_fn: position aggregation function
        msg_aggregate_fn: message aggregation function
        residual: whether to use residual connections
        attention: whether to use attention
        normalize: whether to normalize the coordinates
        tanh: whether to use tanh in the position update
        dt: position update step size
        eps: small number to avoid division by zero
    """

    def __init__(
        self,
        layer_num: int,
        hidden_size: int,
        output_size: int,
        blocks: int = 1,
        act_fn: Callable = jax.nn.silu,
        pos_aggregate_fn: Optional[Callable] = jraph.segment_sum,
        msg_aggregate_fn: Optional[Callable] = jraph.segment_sum,
        residual: bool = True,
        attention: bool = False,
        normalize: bool = False,
        tanh: bool = False,
        dt: float = 0.001,
        eps: float = 1e-8,
    ):
        super().__init__(f"layer_{layer_num}")

        # message network
        self._edge_mlp = MLPXav(
            [hidden_size] * blocks + [hidden_size],
            activation=act_fn,
            activate_final=True,
        )

        # update network
        self._node_mlp = MLPXav(
            [hidden_size] * blocks + [output_size],
            activation=act_fn,
            activate_final=False,
        )

        # position update network
        net = [LinearXav(hidden_size), act_fn]
        # NOTE: from https://github.com/vgsatorras/egnn/blob/main/models/gcl.py#L254
        net += [LinearXav(1, with_bias=False, w_init=UniformScaling(dt))]
        if tanh:
            net.append(jax.nn.tanh)
        self._pos_correction_mlp = hk.Sequential(net)

        # attention
        self._attention_mlp = None
        if attention:
            self._attention_mlp = hk.Sequential(
                [LinearXav(hidden_size), jax.nn.sigmoid]
            )

        self.pos_aggregate_fn = pos_aggregate_fn
        self.msg_aggregate_fn = msg_aggregate_fn
        self._residual = residual
        self._normalize = normalize
        self._eps = eps

    def _pos_update(
        self,
        pos: jnp.ndarray,
        graph: jraph.GraphsTuple,
        coord_diff: jnp.ndarray,
    ) -> jnp.ndarray:
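        """Aggregate the position updates (x_i - x_j) * phi_x(m_ij) over the edges of each node."""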
        trans = coord_diff * self._pos_correction_mlp(graph.edges)
        # NOTE: clipping was in the original code
        trans = jnp.clip(trans, -100, 100)
        return self.pos_aggregate_fn(trans, graph.senders, num_segments=pos.shape[0])

    def _message(
        self,
        radial: jnp.ndarray,
        edge_attribute: jnp.ndarray,
        edge_features: Any,
        incoming: jnp.ndarray,
        outgoing: jnp.ndarray,
        globals_: Any,
    ) -> jnp.ndarray:
        _ = edge_features
        _ = globals_
        msg = jnp.concatenate([incoming, outgoing, radial], axis=-1)
        if edge_attribute is not None:
            msg = jnp.concatenate([msg, edge_attribute], axis=-1)
        msg = self._edge_mlp(msg)
        if self._attention_mlp:
            att = self._attention_mlp(msg)
            msg = msg * att
        return msg

    def _update(
        self,
        node_attribute: jnp.ndarray,
        nodes: jnp.ndarray,
        senders: Any,
        msg: jnp.ndarray,
        globals_: Any,
    ) -> jnp.ndarray:
        _ = senders
        _ = globals_
        x = jnp.concatenate([nodes, msg], axis=-1)
        if node_attribute is not None:
            x = jnp.concatenate([x, node_attribute], axis=-1)
        x = self._node_mlp(x)
        if self._residual:
            x = nodes + x
        return x

    def _coord2radial(
        self, graph: jraph.GraphsTuple, coord: jnp.ndarray
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        coord_diff = coord[graph.senders] - coord[graph.receivers]
        radial = jnp.sum(coord_diff**2, 1)[:, jnp.newaxis]
        if self._normalize:
            norm = jnp.sqrt(radial)
            coord_diff = coord_diff / (norm + self._eps)
        return radial, coord_diff

    def __call__(
        self,
        graph: jraph.GraphsTuple,
        pos: jnp.ndarray,
        edge_attribute: Optional[jnp.ndarray] = None,
        node_attribute: Optional[jnp.ndarray] = None,
    ) -> Tuple[jraph.GraphsTuple, jnp.ndarray]:
        """
        Apply EGNN layer.

        Args:
            graph: Graph from previous step
            pos: Node position, updated separately
            edge_attribute: Edge attribute (optional)
            node_attribute: Node attribute (optional)
        """
        radial, coord_diff = self._coord2radial(graph, pos)

        graph = jraph.GraphNetwork(
            update_edge_fn=Partial(self._message, radial, edge_attribute),
            update_node_fn=Partial(self._update, node_attribute),
            aggregate_edges_for_nodes_fn=self.msg_aggregate_fn,
        )(graph)

        pos = pos + self._pos_update(pos, graph, coord_diff)

        return graph, pos


class EGNN(hk.Module):
    r"""
    E(n) Graph Neural Network (https://arxiv.org/abs/2102.09844).

    Original implementation: https://github.com/vgsatorras/egnn
    """

    def __init__(
        self,
        hidden_size: int,
        output_size: int,
        act_fn: Callable = jax.nn.silu,
        num_layers: int = 4,
        residual: bool = True,
        attention: bool = False,
        normalize: bool = False,
        tanh: bool = False,
    ):
        r"""
        Initialize the network.

        Args:
            hidden_size: Number of hidden features
            output_size: Number of features for 'h' at the output
            act_fn: Non-linearity
            num_layers: Number of layers for the EGNN
            residual: Use residual connections, we recommend not changing this one
            attention: Whether to use attention
            normalize: Normalizes the coordinate messages such that
                x^{l+1}_i = x^{l}_i + \sum (x_i - x_j) \phi_x(m_{ij}) / \|x_i - x_j\|.
                It may help with stability or generalization. Not used in the paper.
            tanh: Sets a tanh activation function at the output of \phi_x(m_{ij}),
                which bounds it. This improves stability but may decrease accuracy.
                Not used in the paper.
        """
        super().__init__()

        self._hidden_size = hidden_size
        self._output_size = output_size
        self._act_fn = act_fn
        self._num_layers = num_layers
        self._residual = residual
        self._attention = attention
        self._normalize = normalize
        self._tanh = tanh

    def __call__(
        self,
        graph: jraph.GraphsTuple,
        pos: jnp.ndarray,
        edge_attribute: Optional[jnp.ndarray] = None,
        node_attribute: Optional[jnp.ndarray] = None,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Apply EGNN.

        Args:
            graph: Input graph
            pos: Node position
            edge_attribute: Edge attribute (optional)
            node_attribute: Node attribute (optional)

        Returns:
            Tuple of updated node features and positions
        """
        # input node embedding
        h = LinearXav(self._hidden_size, name="embedding")(graph.nodes)
        graph = graph._replace(nodes=h)
        # message passing
        for n in range(self._num_layers):
            graph, pos = EGNNLayer(
                layer_num=n,
                hidden_size=self._hidden_size,
                output_size=self._hidden_size,
                act_fn=self._act_fn,
                residual=self._residual,
                attention=self._attention,
                normalize=self._normalize,
                tanh=self._tanh,
            )(graph, pos, edge_attribute=edge_attribute, node_attribute=node_attribute)
        # node readout
        h = LinearXav(self._output_size, name="readout")(graph.nodes)
        return h, pos
--------------------------------------------------------------------------------