├── .flake8 ├── .gitignore ├── README.md ├── bench.py ├── nequip_jax ├── __init__.py ├── filter_layers.py ├── nequip.py ├── nequip_escn.py └── radial.py ├── setup.cfg ├── setup.py └── test.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-complexity = 21 3 | select = B,C,E,F,W,T4,B9 4 | ignore = E741, E203, W503, C901, E501 5 | exclude = .eggs,*.egg,build,dist,docs/_build,notebooks 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.egg-info 3 | *.ipynb 4 | .vscode -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Installation 2 | 3 | Install directly from GitHub with: 4 | 5 | ``` 6 | pip install git+https://github.com/mariogeiger/nequip-jax 7 | ``` 8 | 9 | 10 | ## Usage 11 | 12 | ### Original Nequip 13 | 14 | ```python 15 | from nequip_jax import NEQUIPLayerFlax # Flax version 16 | from nequip_jax import NEQUIPLayerHaiku # Haiku version 17 | ``` 18 | 19 | Look at [test.py](test.py) for an example of how to stack the layers. 20 | 21 | ### Optimization using ESCN 22 | 23 | Optimization for large `L`, based on [this paper](https://arxiv.org/pdf/2302.03655.pdf). 24 | This implementation additionally supports parity.
25 | 26 | ```python 27 | from nequip_jax import NEQUIPESCNLayerFlax # Flax version 28 | from nequip_jax import NEQUIPESCNLayerHaiku # Haiku version 29 | ``` 30 | -------------------------------------------------------------------------------- /bench.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | 4 | import e3nn_jax as e3nn 5 | import flax 6 | import jax 7 | import jax.numpy as jnp 8 | import jraph 9 | 10 | from nequip_jax import NEQUIPLayerFlax 11 | 12 | 13 | class NEQUIP(flax.linen.Module): 14 | @flax.linen.compact 15 | def __call__(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple: 16 | positions = graph.nodes["positions"] 17 | species = graph.nodes["species"] 18 | 19 | vectors = e3nn.IrrepsArray( 20 | "1o", positions[graph.receivers] - positions[graph.senders] 21 | ) 22 | node_feats = flax.linen.Embed(num_embeddings=5, features=32)(species) 23 | node_feats = e3nn.IrrepsArray(f"{node_feats.shape[1]}x0e", node_feats) 24 | 25 | for _ in range(3): 26 | layer = NEQUIPLayerFlax( 27 | avg_num_neighbors=1.0, 28 | output_irreps=64 * e3nn.Irreps.spherical_harmonics(args.max_ell), 29 | max_ell=2 * args.max_ell, 30 | ) 31 | node_feats = layer( 32 | vectors, 33 | node_feats, 34 | species, 35 | graph.senders, 36 | graph.receivers, 37 | ) 38 | 39 | return node_feats 40 | 41 | 42 | def main(): 43 | model = NEQUIP() 44 | 45 | n_nodes = 256 46 | n_edges = 4096 47 | 48 | graph = jraph.GraphsTuple( 49 | nodes={ 50 | "positions": jax.random.normal(jax.random.PRNGKey(0), (n_nodes, 3)), 51 | "species": jax.random.randint(jax.random.PRNGKey(1), (n_nodes,), 0, 5), 52 | }, 53 | edges=None, 54 | globals=None, 55 | senders=jax.random.randint(jax.random.PRNGKey(2), (n_edges,), 0, n_nodes), 56 | receivers=jax.random.randint(jax.random.PRNGKey(3), (n_edges,), 0, n_nodes), 57 | n_node=jnp.array([n_nodes]), 58 | n_edge=jnp.array([n_edges]), 59 | ) 60 | 61 | w = jax.jit(model.init)(jax.random.PRNGKey(0), graph) 62 | 
print("number of parameters:", sum(x.size for x in jax.tree_util.tree_leaves(w))) 63 | 64 | apply = jax.jit(model.apply) 65 | 66 | print("compiling forward pass") 67 | apply(w, graph) 68 | 69 | print("running forward pass") 70 | t0 = time.time() 71 | apply(w, graph).array.block_until_ready() 72 | t1 = time.time() 73 | print(f"took {t1 - t0} seconds") 74 | 75 | bwr = jax.jit(jax.grad(lambda w, graph: apply(w, graph).array.sum())) 76 | print("compiling backward pass") 77 | bwr(w, graph) 78 | 79 | print("running backward pass") 80 | t0 = time.time() 81 | g = bwr(w, graph) 82 | jax.tree_util.tree_map(lambda x: x.block_until_ready(), g) 83 | t1 = time.time() 84 | print(f"took {t1 - t0} seconds") 85 | 86 | 87 | if __name__ == "__main__": 88 | parser = argparse.ArgumentParser() 89 | parser.add_argument("--max-ell", type=int, default=3) 90 | args = parser.parse_args() 91 | main() 92 | -------------------------------------------------------------------------------- /nequip_jax/__init__.py: -------------------------------------------------------------------------------- 1 | from .nequip import NEQUIPLayerFlax, NEQUIPLayerHaiku 2 | from .nequip_escn import NEQUIPESCNLayerFlax, NEQUIPESCNLayerHaiku 3 | from .filter_layers import filter_layers 4 | from .radial import default_radial_basis, simple_smooth_radial_basis 5 | 6 | __version__ = "1.1.0" 7 | 8 | __all__ = [ 9 | "NEQUIPLayerFlax", 10 | "NEQUIPLayerHaiku", 11 | "NEQUIPESCNLayerFlax", 12 | "NEQUIPESCNLayerHaiku", 13 | "filter_layers", 14 | "default_radial_basis", 15 | "simple_smooth_radial_basis", 16 | ] 17 | -------------------------------------------------------------------------------- /nequip_jax/filter_layers.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import e3nn_jax as e3nn 4 | 5 | 6 | def filter_layers( 7 | layer_irreps: List[e3nn.Irreps], max_ell: Optional[int] 8 | ) -> List[e3nn.Irreps]: 9 | filtered = 
[e3nn.Irreps(layer_irreps[-1])] 10 | for irreps in reversed(layer_irreps[:-1]): 11 | irreps = e3nn.Irreps(irreps) 12 | if max_ell is not None: 13 | lmax = max_ell 14 | else: 15 | lmax = max(irreps.lmax, filtered[0].lmax) 16 | irreps = irreps.filter( 17 | keep=e3nn.tensor_product( 18 | filtered[0], 19 | e3nn.Irreps.spherical_harmonics(lmax=lmax), 20 | ).regroup() 21 | ) 22 | filtered.insert(0, irreps) 23 | return filtered 24 | -------------------------------------------------------------------------------- /nequip_jax/nequip.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional, Union 2 | 3 | import e3nn_jax as e3nn 4 | import flax 5 | import haiku as hk 6 | import jax 7 | import jax.numpy as jnp 8 | 9 | from .radial import default_radial_basis 10 | 11 | 12 | class NEQUIPLayerFlax(flax.linen.Module): 13 | avg_num_neighbors: float 14 | num_species: int = 1 15 | max_ell: int = 3 16 | output_irreps: e3nn.Irreps = 64 * e3nn.Irreps("0e + 1o + 2e") 17 | even_activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.silu 18 | odd_activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.tanh 19 | gate_activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.silu 20 | mlp_activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.silu 21 | mlp_n_hidden: int = 64 22 | mlp_n_layers: int = 2 23 | radial_basis: Callable[[jnp.ndarray, int], jnp.ndarray] = default_radial_basis 24 | n_radial_basis: int = 8 25 | 26 | @flax.linen.compact 27 | def __call__( 28 | self, 29 | vectors: e3nn.IrrepsArray, 30 | node_feats: e3nn.IrrepsArray, 31 | node_specie: jnp.ndarray, 32 | senders: jnp.ndarray, 33 | receivers: jnp.ndarray, 34 | ): 35 | return _impl( 36 | e3nn.flax.Linear, 37 | e3nn.flax.MultiLayerPerceptron, 38 | self, 39 | vectors, 40 | node_feats, 41 | node_specie, 42 | senders, 43 | receivers, 44 | ) 45 | 46 | 47 | class NEQUIPLayerHaiku(hk.Module): 48 | def __init__( 49 | self, 50 | avg_num_neighbors: float, 51 
| num_species: int = 1, 52 | max_ell: int = 3, 53 | output_irreps: e3nn.Irreps = 64 * e3nn.Irreps("0e + 1o + 2e"), 54 | even_activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.silu, 55 | odd_activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.tanh, 56 | gate_activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.silu, 57 | mlp_activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.silu, 58 | mlp_n_hidden: int = 64, 59 | mlp_n_layers: int = 2, 60 | radial_basis: Callable[[jnp.ndarray, int], jnp.ndarray] = default_radial_basis, 61 | n_radial_basis: int = 8, 62 | name: Optional[str] = None, 63 | ): 64 | super().__init__(name) 65 | self.avg_num_neighbors = avg_num_neighbors 66 | self.num_species = num_species 67 | self.max_ell = max_ell 68 | self.output_irreps = output_irreps 69 | self.even_activation = even_activation 70 | self.odd_activation = odd_activation 71 | self.gate_activation = gate_activation 72 | self.mlp_activation = mlp_activation 73 | self.mlp_n_hidden = mlp_n_hidden 74 | self.mlp_n_layers = mlp_n_layers 75 | self.radial_basis = radial_basis 76 | self.n_radial_basis = n_radial_basis 77 | 78 | def __call__( 79 | self, 80 | vectors: e3nn.IrrepsArray, 81 | node_feats: e3nn.IrrepsArray, 82 | node_specie: jnp.ndarray, 83 | senders: jnp.ndarray, 84 | receivers: jnp.ndarray, 85 | ): 86 | return _impl( 87 | e3nn.haiku.Linear, 88 | e3nn.haiku.MultiLayerPerceptron, 89 | self, 90 | vectors, 91 | node_feats, 92 | node_specie, 93 | senders, 94 | receivers, 95 | ) 96 | 97 | 98 | def _impl( 99 | Linear: Callable, 100 | MultiLayerPerceptron: Callable, 101 | self: Union[NEQUIPLayerFlax, NEQUIPLayerHaiku], 102 | vectors: e3nn.IrrepsArray, # [n_edges, 3] 103 | node_feats: e3nn.IrrepsArray, # [n_nodes, irreps] 104 | node_specie: jnp.ndarray, # [n_nodes] int between 0 and num_species-1 105 | senders: jnp.ndarray, # [n_edges] 106 | receivers: jnp.ndarray, # [n_edges] 107 | ): 108 | node_feats = e3nn.as_irreps_array(node_feats) 109 | 110 | num_nodes = 
node_feats.shape[0] 111 | num_edges = vectors.shape[0] 112 | assert vectors.shape == (num_edges, 3) 113 | assert node_feats.shape == (num_nodes, node_feats.irreps.dim) 114 | assert node_specie.shape == (num_nodes,) 115 | assert senders.shape == (num_edges,) 116 | assert receivers.shape == (num_edges,) 117 | 118 | # we regroup the target irreps to make sure that gate activation 119 | # has the same irreps as the target 120 | output_irreps = e3nn.Irreps(self.output_irreps).regroup() 121 | 122 | messages = Linear(node_feats.irreps, name="linear_up")(node_feats)[senders] 123 | 124 | # Angular part 125 | messages = e3nn.concatenate( 126 | [ 127 | messages.filter(output_irreps + "0e"), 128 | e3nn.tensor_product( 129 | messages, 130 | e3nn.spherical_harmonics( 131 | [l for l in range(1, self.max_ell + 1)], 132 | vectors, 133 | normalize=True, 134 | normalization="component", 135 | ), 136 | filter_ir_out=output_irreps + "0e", 137 | ), 138 | ] 139 | ).regroup() 140 | assert messages.shape == (num_edges, messages.irreps.dim) 141 | 142 | # Radial part 143 | with jax.ensure_compile_time_eval(): 144 | assert abs(self.mlp_activation(0.0)) < 1e-6 145 | lengths = e3nn.norm(vectors).array 146 | mix = MultiLayerPerceptron( 147 | self.mlp_n_layers * (self.mlp_n_hidden,) + (messages.irreps.num_irreps,), 148 | self.mlp_activation, 149 | output_activation=False, 150 | )(self.radial_basis(lengths[:, 0], self.n_radial_basis)) 151 | 152 | # Discard 0 length edges that come from graph padding 153 | mix = jnp.where(lengths == 0.0, 0.0, mix) 154 | assert mix.shape == (num_edges, messages.irreps.num_irreps) 155 | 156 | # Product of radial and angular part 157 | messages = messages * mix 158 | assert messages.shape == (num_edges, messages.irreps.dim) 159 | 160 | # Skip connection 161 | irreps = output_irreps.filter(keep=messages.irreps) 162 | num_nonscalar = irreps.filter(drop="0e + 0o").num_irreps 163 | irreps = irreps + e3nn.Irreps(f"{num_nonscalar}x0e").simplify() 164 | 165 | skip = Linear( 
166 | irreps, 167 | num_indexed_weights=self.num_species, 168 | name="skip_tp", 169 | force_irreps_out=True, 170 | )(node_specie, node_feats) 171 | 172 | # Message passing 173 | node_feats = e3nn.scatter_sum(messages, dst=receivers, output_size=num_nodes) 174 | node_feats = node_feats / jnp.sqrt(self.avg_num_neighbors) 175 | 176 | node_feats = Linear(irreps, name="linear_down")(node_feats) 177 | 178 | node_feats = node_feats + skip 179 | assert node_feats.shape == (num_nodes, node_feats.irreps.dim) 180 | 181 | node_feats = e3nn.gate( 182 | node_feats, 183 | even_act=self.even_activation, 184 | odd_act=self.odd_activation, 185 | even_gate_act=self.gate_activation, 186 | ) 187 | 188 | return node_feats 189 | -------------------------------------------------------------------------------- /nequip_jax/nequip_escn.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional, Union 2 | 3 | import e3nn_jax as e3nn 4 | import flax 5 | import haiku as hk 6 | import jax 7 | import jax.numpy as jnp 8 | from e3nn_jax.experimental.linear_shtp import LinearSHTP 9 | 10 | from .radial import default_radial_basis 11 | 12 | 13 | class NEQUIPESCNLayerFlax(flax.linen.Module): 14 | avg_num_neighbors: float 15 | num_species: int = 1 16 | output_irreps: e3nn.Irreps = 64 * e3nn.Irreps("0e + 1o + 2e") 17 | even_activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.silu 18 | odd_activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.tanh 19 | gate_activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.silu 20 | mlp_activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.silu 21 | mlp_n_hidden: int = 64 22 | mlp_n_layers: int = 2 23 | radial_basis: Callable[[jnp.ndarray, int], jnp.ndarray] = default_radial_basis 24 | n_radial_basis: int = 8 25 | 26 | @flax.linen.compact 27 | def __call__( 28 | self, 29 | vectors: e3nn.IrrepsArray, 30 | node_feats: e3nn.IrrepsArray, 31 | node_specie: jnp.ndarray, 32 | senders: 
jnp.ndarray, 33 | receivers: jnp.ndarray, 34 | ): 35 | return _impl( 36 | e3nn.flax.Linear, 37 | e3nn.flax.MultiLayerPerceptron, 38 | self, 39 | vectors, 40 | node_feats, 41 | node_specie, 42 | senders, 43 | receivers, 44 | ) 45 | 46 | 47 | class NEQUIPESCNLayerHaiku(hk.Module): 48 | def __init__( 49 | self, 50 | avg_num_neighbors: float, 51 | num_species: int = 1, 52 | output_irreps: e3nn.Irreps = 64 * e3nn.Irreps("0e + 1o + 2e"), 53 | even_activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.silu, 54 | odd_activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.tanh, 55 | gate_activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.silu, 56 | mlp_activation: Callable[[jnp.ndarray], jnp.ndarray] = jax.nn.silu, 57 | mlp_n_hidden: int = 64, 58 | mlp_n_layers: int = 2, 59 | radial_basis: Callable[[jnp.ndarray, int], jnp.ndarray] = default_radial_basis, 60 | n_radial_basis: int = 8, 61 | name: Optional[str] = None, 62 | ): 63 | super().__init__(name) 64 | self.avg_num_neighbors = avg_num_neighbors 65 | self.num_species = num_species 66 | self.output_irreps = output_irreps 67 | self.even_activation = even_activation 68 | self.odd_activation = odd_activation 69 | self.gate_activation = gate_activation 70 | self.mlp_activation = mlp_activation 71 | self.mlp_n_hidden = mlp_n_hidden 72 | self.mlp_n_layers = mlp_n_layers 73 | self.radial_basis = radial_basis 74 | self.n_radial_basis = n_radial_basis 75 | 76 | def __call__( 77 | self, 78 | vectors: e3nn.IrrepsArray, 79 | node_feats: e3nn.IrrepsArray, 80 | node_specie: jnp.ndarray, 81 | senders: jnp.ndarray, 82 | receivers: jnp.ndarray, 83 | ): 84 | return _impl( 85 | e3nn.haiku.Linear, 86 | e3nn.haiku.MultiLayerPerceptron, 87 | self, 88 | vectors, 89 | node_feats, 90 | node_specie, 91 | senders, 92 | receivers, 93 | ) 94 | 95 | 96 | def _impl( 97 | Linear: Callable, 98 | MultiLayerPerceptron: Callable, 99 | self: Union[NEQUIPESCNLayerFlax, NEQUIPESCNLayerHaiku], 100 | vectors: e3nn.IrrepsArray, # [n_edges, 3] 101 | 
node_feats: e3nn.IrrepsArray, # [n_nodes, irreps] 102 | node_specie: jnp.ndarray, # [n_nodes] int between 0 and num_species-1 103 | senders: jnp.ndarray, # [n_edges] 104 | receivers: jnp.ndarray, # [n_edges] 105 | ): 106 | num_nodes = node_feats.shape[0] 107 | num_edges = vectors.shape[0] 108 | assert vectors.shape == (num_edges, 3) 109 | assert node_feats.shape == (num_nodes, node_feats.irreps.dim) 110 | assert node_specie.shape == (num_nodes,) 111 | assert senders.shape == (num_edges,) 112 | assert receivers.shape == (num_edges,) 113 | 114 | # we regroup the target irreps to make sure that gate activation 115 | # has the same irreps as the target 116 | output_irreps = e3nn.Irreps(self.output_irreps).regroup() 117 | 118 | messages = Linear(node_feats.irreps, name="linear_up")(node_feats)[senders] 119 | 120 | conv = LinearSHTP(output_irreps, mix=False) 121 | w_unused = conv.init(jax.random.PRNGKey(0), messages[0], vectors[0]) 122 | w_unused_flat = flatten(w_unused) 123 | 124 | # Radial part 125 | with jax.ensure_compile_time_eval(): 126 | assert abs(self.mlp_activation(0.0)) < 1e-6 127 | lengths = e3nn.norm(vectors).array 128 | mix = MultiLayerPerceptron( 129 | self.mlp_n_layers * (self.mlp_n_hidden,) + (w_unused_flat.size,), 130 | self.mlp_activation, 131 | output_activation=False, 132 | )(self.radial_basis(lengths[:, 0], self.n_radial_basis)) 133 | 134 | # Discard 0 length edges that come from graph padding 135 | mix = jnp.where(lengths == 0.0, 0.0, mix) 136 | assert mix.shape == (num_edges, w_unused_flat.size) 137 | 138 | w = jax.vmap(unflatten, (0, None))(mix, w_unused) 139 | messages = jax.vmap(conv.apply)(w, messages, vectors) 140 | assert messages.shape == (num_edges, messages.irreps.dim) 141 | 142 | # Skip connection 143 | irreps = output_irreps.filter(keep=messages.irreps) 144 | num_nonscalar = irreps.filter(drop="0e + 0o").num_irreps 145 | irreps = irreps + e3nn.Irreps(f"{num_nonscalar}x0e").simplify() 146 | 147 | skip = Linear( 148 | irreps, 149 | 
num_indexed_weights=self.num_species, 150 | name="skip_tp", 151 | force_irreps_out=True, 152 | )(node_specie, node_feats) 153 | 154 | # Message passing 155 | node_feats = e3nn.scatter_sum(messages, dst=receivers, output_size=num_nodes) 156 | node_feats = node_feats / jnp.sqrt(self.avg_num_neighbors) 157 | 158 | node_feats = Linear(irreps, name="linear_down")(node_feats) 159 | 160 | node_feats = node_feats + skip 161 | assert node_feats.shape == (num_nodes, node_feats.irreps.dim) 162 | 163 | node_feats = e3nn.gate( 164 | node_feats, 165 | even_act=self.even_activation, 166 | odd_act=self.odd_activation, 167 | even_gate_act=self.gate_activation, 168 | ) 169 | 170 | return node_feats 171 | 172 | 173 | def flatten(w):  # ravel every leaf of the pytree w and concatenate them into one 1-D array 174 | return jnp.concatenate([x.ravel() for x in jax.tree_util.tree_leaves(w)]) 175 | 176 | 177 | def unflatten(array, template):  # inverse of flatten: slice array into pieces shaped like template's leaves and rebuild template's tree structure 178 | lst = [] 179 | start = 0  # offset into array of the next leaf's first element 180 | for x in jax.tree_util.tree_leaves(template): 181 | lst.append(array[start : start + x.size].reshape(x.shape)) 182 | start += x.size 183 | return jax.tree_util.tree_unflatten(jax.tree_util.tree_structure(template), lst) 184 | -------------------------------------------------------------------------------- /nequip_jax/radial.py: -------------------------------------------------------------------------------- 1 | import e3nn_jax as e3nn 2 | 3 | 4 | def default_radial_basis(r, n: int): 5 | """Default radial basis function.""" 6 | return e3nn.bessel(r, n) * e3nn.poly_envelope(5, 2)(r)[:, None]  # [:, None] lets the envelope values broadcast across the trailing axis of the bessel output 7 | 8 | 9 | def simple_smooth_radial_basis(r, n: int):  # alternative radial basis: n "smooth_finite" functions from e3nn.soft_one_hot_linspace over start=0.0 .. end=1.0 10 | return e3nn.soft_one_hot_linspace( 11 | r, 12 | start=0.0, 13 | end=1.0, 14 | number=n, 15 | basis="smooth_finite", 16 | start_zero=False, 17 | end_zero=True, 18 | ) 19 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = nequip_jax 3 | version = attr: nequip_jax.__version__ 4 | long_description =
file: README.md 5 | long_description_content_type = text/markdown 6 | classifiers = 7 | Programming Language :: Python :: 3.8 8 | 9 | [options] 10 | packages = find: 11 | python_requires = >=3.8 12 | install_requires = 13 | e3nn_jax 14 | flax 15 | numpy 16 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | if __name__ == "__main__": 4 | setup() 5 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import e3nn_jax as e3nn 2 | import flax 3 | import haiku as hk 4 | import jax 5 | import jax.numpy as jnp 6 | import jraph 7 | 8 | from nequip_jax import ( 9 | NEQUIPLayerFlax, 10 | NEQUIPLayerHaiku, 11 | NEQUIPESCNLayerFlax, 12 | NEQUIPESCNLayerHaiku, 13 | filter_layers, 14 | ) 15 | 16 | 17 | def dummy_graph(): 18 | return jraph.GraphsTuple( 19 | nodes={ 20 | "positions": jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]), 21 | "species": jnp.array([0, 1]), 22 | }, 23 | edges=None, 24 | globals=None, 25 | senders=jnp.array([0, 1]), 26 | receivers=jnp.array([1, 0]), 27 | n_node=jnp.array([2]), 28 | n_edge=jnp.array([2]), 29 | ) 30 | 31 | 32 | def test_nequip_flax(): 33 | class NEQUIP(flax.linen.Module): 34 | @flax.linen.compact 35 | def __call__(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple: 36 | positions = graph.nodes["positions"] 37 | species = graph.nodes["species"] 38 | 39 | vectors = e3nn.IrrepsArray( 40 | "1o", positions[graph.receivers] - positions[graph.senders] 41 | ) 42 | node_feats = flax.linen.Embed(num_embeddings=5, features=32)(species) 43 | node_feats = e3nn.IrrepsArray(f"{node_feats.shape[1]}x0e", node_feats) 44 | 45 | layers_irreps = ["16x0e + 16x1o + 16x1e"] * 2 + ["0e"] 46 | layers_irreps = filter_layers(layers_irreps, max_ell=3) 47 | for irreps in layers_irreps: 
48 | layer = NEQUIPLayerFlax( 49 | avg_num_neighbors=1.0, 50 | output_irreps=irreps, 51 | max_ell=3, 52 | ) 53 | node_feats = layer( 54 | vectors, 55 | node_feats, 56 | species, 57 | graph.senders, 58 | graph.receivers, 59 | ) 60 | 61 | return node_feats 62 | 63 | graph = dummy_graph() 64 | 65 | model = NEQUIP() 66 | w = model.init(jax.random.PRNGKey(0), graph) 67 | 68 | apply = jax.jit(model.apply) 69 | apply(w, graph) 70 | apply(w, graph) 71 | 72 | 73 | def test_nequip_haiku(): 74 | @hk.without_apply_rng 75 | @hk.transform 76 | def model(graph: jraph.GraphsTuple) -> jraph.GraphsTuple: 77 | positions = graph.nodes["positions"] 78 | species = graph.nodes["species"] 79 | 80 | vectors = e3nn.IrrepsArray( 81 | "1o", positions[graph.receivers] - positions[graph.senders] 82 | ) 83 | node_feats = hk.Embed(vocab_size=5, embed_dim=32)(species) 84 | node_feats = e3nn.IrrepsArray(f"{node_feats.shape[1]}x0e", node_feats) 85 | 86 | layers_irreps = ["16x0e + 16x1o + 16x1e"] * 2 + ["0e"] 87 | layers_irreps = filter_layers(layers_irreps, max_ell=3) 88 | for irreps in layers_irreps: 89 | layer = NEQUIPLayerHaiku( 90 | avg_num_neighbors=1.0, 91 | output_irreps=irreps, 92 | max_ell=3, 93 | ) 94 | node_feats = layer( 95 | vectors, 96 | node_feats, 97 | species, 98 | graph.senders, 99 | graph.receivers, 100 | ) 101 | 102 | return node_feats 103 | 104 | graph = dummy_graph() 105 | 106 | w = model.init(jax.random.PRNGKey(0), graph) 107 | 108 | apply = jax.jit(model.apply) 109 | apply(w, graph) 110 | apply(w, graph) 111 | 112 | 113 | def test_nequip_escn_flax(): 114 | class NEQUIP(flax.linen.Module): 115 | @flax.linen.compact 116 | def __call__(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple: 117 | positions = graph.nodes["positions"] 118 | species = graph.nodes["species"] 119 | 120 | vectors = e3nn.IrrepsArray( 121 | "1o", positions[graph.receivers] - positions[graph.senders] 122 | ) 123 | node_feats = flax.linen.Embed(num_embeddings=5, features=32)(species) 124 | node_feats = 
e3nn.IrrepsArray(f"{node_feats.shape[1]}x0e", node_feats) 125 | 126 | layers_irreps = ["16x0e + 16x1o + 16x1e"] * 2 + ["0e"] 127 | layers_irreps = filter_layers(layers_irreps, max_ell=None) 128 | for irreps in layers_irreps: 129 | layer = NEQUIPESCNLayerFlax( 130 | avg_num_neighbors=1.0, 131 | output_irreps=irreps, 132 | ) 133 | node_feats = layer( 134 | vectors, 135 | node_feats, 136 | species, 137 | graph.senders, 138 | graph.receivers, 139 | ) 140 | 141 | return node_feats 142 | 143 | graph = dummy_graph() 144 | 145 | model = NEQUIP() 146 | w = model.init(jax.random.PRNGKey(0), graph) 147 | 148 | apply = jax.jit(model.apply) 149 | apply(w, graph) 150 | apply(w, graph) 151 | 152 | 153 | def test_nequip_escn_haiku(): 154 | @hk.without_apply_rng 155 | @hk.transform 156 | def model(graph: jraph.GraphsTuple) -> jraph.GraphsTuple: 157 | positions = graph.nodes["positions"] 158 | species = graph.nodes["species"] 159 | 160 | vectors = e3nn.IrrepsArray( 161 | "1o", positions[graph.receivers] - positions[graph.senders] 162 | ) 163 | node_feats = hk.Embed(vocab_size=5, embed_dim=32)(species) 164 | node_feats = e3nn.IrrepsArray(f"{node_feats.shape[1]}x0e", node_feats) 165 | 166 | layers_irreps = ["16x0e + 16x1o + 16x1e"] * 2 + ["0e"] 167 | layers_irreps = filter_layers(layers_irreps, max_ell=None) 168 | for irreps in layers_irreps: 169 | layer = NEQUIPESCNLayerHaiku( 170 | avg_num_neighbors=1.0, 171 | output_irreps=irreps, 172 | ) 173 | node_feats = layer( 174 | vectors, 175 | node_feats, 176 | species, 177 | graph.senders, 178 | graph.receivers, 179 | ) 180 | 181 | return node_feats 182 | 183 | graph = dummy_graph() 184 | 185 | w = model.init(jax.random.PRNGKey(0), graph) 186 | 187 | apply = jax.jit(model.apply) 188 | apply(w, graph) 189 | apply(w, graph) 190 | --------------------------------------------------------------------------------