├── experiments
│   ├── __init__.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── plot_utils.py
│   │   └── train_utils.py
│   ├── fig
│   │   ├── kchains.png
│   │   ├── rotsym.png
│   │   ├── incompleteness.png
│   │   └── axes-of-expressivity.png
│   ├── rotsym.ipynb
│   ├── kchains.ipynb
│   └── incompleteness.ipynb
├── models
│   ├── layers
│   │   ├── __init__.py
│   │   ├── tfn_layer.py
│   │   ├── egnn_layer.py
│   │   ├── gvp_layer.py
│   │   └── spherenet_layer.py
│   ├── __init__.py
│   ├── mace_modules
│   │   ├── __init__.py
│   │   ├── radial.py
│   │   ├── irreps_tools.py
│   │   ├── cg.py
│   │   ├── symmetric_contraction.py
│   │   └── blocks.py
│   ├── schnet.py
│   ├── egnn.py
│   ├── dimenet.py
│   ├── spherenet.py
│   ├── gvpgnn.py
│   ├── tfn.py
│   └── mace.py
├── LICENSE
├── .gitignore
└── README.md
/experiments/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/layers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/experiments/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/experiments/fig/kchains.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaitjo/geometric-gnn-dojo/HEAD/experiments/fig/kchains.png
--------------------------------------------------------------------------------
/experiments/fig/rotsym.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaitjo/geometric-gnn-dojo/HEAD/experiments/fig/rotsym.png
--------------------------------------------------------------------------------
/experiments/fig/incompleteness.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaitjo/geometric-gnn-dojo/HEAD/experiments/fig/incompleteness.png
--------------------------------------------------------------------------------
/experiments/fig/axes-of-expressivity.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaitjo/geometric-gnn-dojo/HEAD/experiments/fig/axes-of-expressivity.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from models.schnet import SchNetModel
2 | from models.dimenet import DimeNetPPModel
3 | from models.spherenet import SphereNetModel
4 | from models.egnn import EGNNModel
5 | from models.gvpgnn import GVPGNNModel
6 | from models.tfn import TFNModel
7 | from models.mace import MACEModel
8 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Chaitanya K. Joshi, Simon V. Mathis
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 |
5 | # C extensions
6 | *.so
7 |
8 | # Distribution / packaging
9 | .Python
10 | env/
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | *.egg-info/
23 | .installed.cfg
24 | *.egg
25 |
26 | # PyInstaller
27 | # Usually these files are written by a python script from a template
28 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
29 | *.manifest
30 | *.spec
31 |
32 | # Installer logs
33 | pip-log.txt
34 | pip-delete-this-directory.txt
35 |
36 | # Unit test / coverage reports
37 | htmlcov/
38 | .tox/
39 | .coverage
40 | .coverage.*
41 | .cache
42 | nosetests.xml
43 | coverage.xml
44 | *.cover
45 |
46 | # Translations
47 | *.mo
48 | *.pot
49 |
50 | # Django stuff:
51 | *.log
52 |
53 | # Sphinx documentation
54 | docs/_build/
55 |
56 | # PyBuilder
57 | target/
58 |
59 | # DotEnv configuration
60 | .env
61 |
62 | # Database
63 | *.db
64 | *.rdb
65 |
66 | # Pycharm
67 | .idea
68 |
69 | # VS Code
70 | .vscode/
71 |
72 | # Spyder
73 | .spyproject/
74 |
75 | # Jupyter NB Checkpoints
76 | .ipynb_checkpoints/
77 |
78 | # exclude data from source control by default
79 | io
80 | io/
81 |
82 | # Mac OS-specific storage files
83 | .DS_Store
84 |
85 | # vim
86 | *.swp
87 | *.swo
88 |
89 | # Mypy cache
90 | .mypy_cache/
91 |
92 | # Other
93 | .history
94 | .history/
95 | logs/
96 | venv*/
97 | venv_gpu/
98 | lightning_logs
99 | .DS_Store
--------------------------------------------------------------------------------
/experiments/utils/plot_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch_geometric.utils import to_networkx
3 | import matplotlib.pyplot as plt
4 |
5 |
def plot_2d(data, lim=10):
    """Draw a graph's nodes and edges in the xy-plane and show the figure.

    Only the first two columns of ``data.pos`` are used. Nodes are coloured by
    ``data.atoms`` with the "rainbow" colormap; edges are grey line segments.

    Args:
        data: graph object accepted by ``to_networkx`` that also exposes
            ``pos`` and ``atoms`` tensors (both must support ``.numpy()``).
        lim: both axes span ``[-lim, lim]``.
    """
    graph = to_networkx(data)
    coords = data.pos.numpy()

    # Planar coordinates of every node (sorted node order) and of both
    # endpoints of every edge.
    node_pts = np.array([coords[node, :2] for node in sorted(graph)])
    edge_pts = np.array(
        [(coords[src, :2], coords[dst, :2]) for src, dst in graph.edges()]
    )

    fig = plt.figure()
    ax = fig.add_subplot(111)

    ax.scatter(*node_pts.T, s=100, c=data.atoms.numpy(), cmap="rainbow")
    for segment in edge_pts:
        # segment is [[x_u, y_u], [x_v, y_v]]; its transpose is (xs, ys).
        ax.plot(*segment.T, color="tab:gray")

    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_xlim([-lim, lim])
    ax.set_ylim([-lim, lim])
    ax.set_aspect('equal', 'box')

    plt.show()
42 |
43 |
def plot_3d(data, lim=10):
    """Draw a graph's nodes and edges in 3D and show the figure.

    Nodes are placed at ``data.pos`` and coloured by ``data.atoms`` with the
    "rainbow" colormap; edges are grey line segments.

    Args:
        data: graph object accepted by ``to_networkx`` that also exposes
            ``pos`` and ``atoms`` tensors (both must support ``.numpy()``).
        lim: the x, y and z axes all span ``[-lim, lim]``.
    """
    graph = to_networkx(data)
    coords = data.pos.numpy()

    # Full coordinates of every node (sorted node order) and of both endpoints
    # of every edge.
    node_pts = np.array([coords[node] for node in sorted(graph)])
    edge_pts = np.array([(coords[src], coords[dst]) for src, dst in graph.edges()])

    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")

    # matplotlib's 3D scatter shades markers by depth by default.
    ax.scatter(*node_pts.T, s=100, c=data.atoms.numpy(), cmap="rainbow")
    for segment in edge_pts:
        # segment holds the two endpoints as rows; its transpose is (xs, ys, zs).
        ax.plot(*segment.T, color="tab:gray")

    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")
    ax.set_xlim([-lim, lim])
    ax.set_ylim([-lim, lim])
    ax.set_zlim([-lim, lim])

    plt.show()
81 |
--------------------------------------------------------------------------------
/models/mace_modules/__init__.py:
--------------------------------------------------------------------------------
1 | ###########################################################################################
2 | # This directory contains an implementation of MACE, with minor adaptations
3 | #
4 | # Paper: MACE: Higher Order Equivariant Message Passing Neural Networks
5 | # for Fast and Accurate Force Fields, Batatia et al.
6 | #
7 | # Original repository: https://github.com/ACEsuit/mace
8 | ###########################################################################################
9 |
10 | from typing import Callable, Dict, Optional, Type
11 |
12 | import torch
13 |
14 | from .blocks import (
15 | AgnosticNonlinearInteractionBlock,
16 | AgnosticResidualNonlinearInteractionBlock,
17 | AtomicEnergiesBlock,
18 | EquivariantProductBasisBlock,
19 | InteractionBlock,
20 | LinearNodeEmbeddingBlock,
21 | LinearReadoutBlock,
22 | NonLinearReadoutBlock,
23 | RadialEmbeddingBlock,
24 | RealAgnosticInteractionBlock,
25 | RealAgnosticResidualInteractionBlock,
26 | ResidualElementDependentInteractionBlock,
27 | ScaleShiftBlock,
28 | )
29 | from .radial import BesselBasis, PolynomialCutoff
30 | from .symmetric_contraction import SymmetricContraction
31 |
# Registry of interaction-block implementations keyed by class name, so the
# block type can be chosen from a string.
interaction_classes: Dict[str, Type[InteractionBlock]] = {
    "AgnosticNonlinearInteractionBlock": AgnosticNonlinearInteractionBlock,
    "ResidualElementDependentInteractionBlock": ResidualElementDependentInteractionBlock,
    "AgnosticResidualNonlinearInteractionBlock": AgnosticResidualNonlinearInteractionBlock,
    "RealAgnosticResidualInteractionBlock": RealAgnosticResidualInteractionBlock,
    "RealAgnosticInteractionBlock": RealAgnosticInteractionBlock,
}

# Gate non-linearities selectable by name; the string "None" maps to no gate.
gate_dict: Dict[str, Optional[Callable]] = {
    "abs": torch.abs,
    "tanh": torch.tanh,
    "silu": torch.nn.functional.silu,
    "None": None,
}

# Public API for ``from models.mace_modules import *``. Every entry must be a
# name actually defined or imported in this module: a missing name makes the
# star-import raise AttributeError. (The upstream MACE list also named models,
# losses and statistics helpers that this adaptation does not include.)
__all__ = [
    "AtomicEnergiesBlock",
    "RadialEmbeddingBlock",
    "LinearNodeEmbeddingBlock",
    "LinearReadoutBlock",
    "EquivariantProductBasisBlock",
    "ScaleShiftBlock",
    "InteractionBlock",
    "NonLinearReadoutBlock",
    "PolynomialCutoff",
    "BesselBasis",
    "SymmetricContraction",
    "interaction_classes",
    "gate_dict",
]
70 |
--------------------------------------------------------------------------------
/models/mace_modules/radial.py:
--------------------------------------------------------------------------------
1 | ###########################################################################################
2 | # Radial basis and cutoff
3 | # Authors: Ilyes Batatia, Gregor Simm
4 | # This program is distributed under the MIT License (see MIT.md)
5 | ###########################################################################################
6 |
7 | import numpy as np
8 | import torch
9 | from e3nn.util.jit import compile_mode
10 |
11 |
class BesselBasis(torch.nn.Module):
    """
    Klicpera, J.; Groß, J.; Günnemann, S. Directional Message Passing for Molecular Graphs; ICLR 2020.
    Equation (7)

    Radial Bessel basis. ``forward(x)`` returns
    ``sqrt(2 / r_max) * sin(w_n * x) / x`` for frequencies
    ``w_n = n * pi / r_max``, ``n = 1 .. num_basis``. The frequencies are a
    learnable ``Parameter`` when ``trainable`` is True, otherwise a buffer.
    """

    def __init__(self, r_max: float, num_basis=8, trainable=False):
        super().__init__()
        dtype = torch.get_default_dtype()

        # n * pi / r_max for n = 1 .. num_basis
        frequencies = np.pi / r_max * torch.linspace(
            start=1.0, end=num_basis, steps=num_basis, dtype=dtype
        )
        if not trainable:
            self.register_buffer("bessel_weights", frequencies)
        else:
            self.bessel_weights = torch.nn.Parameter(frequencies)

        self.register_buffer("r_max", torch.tensor(r_max, dtype=dtype))
        self.register_buffer(
            "prefactor", torch.tensor(np.sqrt(2.0 / r_max), dtype=dtype)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # [..., 1]
        # x broadcasts against the num_basis frequencies -> [..., num_basis].
        # NOTE(review): x == 0 gives 0/0 = nan; presumably only positive
        # distances are passed in — confirm at call sites.
        return self.prefactor * (torch.sin(self.bessel_weights * x) / x)

    def __repr__(self):
        n_basis = len(self.bessel_weights)
        is_trainable = self.bessel_weights.requires_grad
        return (
            f"{self.__class__.__name__}(r_max={self.r_max}, num_basis={n_basis}, "
            f"trainable={is_trainable})"
        )
53 |
54 |
class PolynomialCutoff(torch.nn.Module):
    """
    Klicpera, J.; Groß, J.; Günnemann, S. Directional Message Passing for Molecular Graphs; ICLR 2020.
    Equation (8)

    Polynomial envelope of degree ``p``. With ``u = x / r_max`` it evaluates
    ``1 - (p+1)(p+2)/2 * u^p + p(p+2) * u^(p+1) - p(p+1)/2 * u^(p+2)``
    and multiplies by the mask ``x < r_max``, so the output is zero for
    ``x >= r_max``.
    """

    p: torch.Tensor
    r_max: torch.Tensor

    def __init__(self, r_max: float, p=6):
        super().__init__()
        dtype = torch.get_default_dtype()
        self.register_buffer("p", torch.tensor(p, dtype=dtype))
        self.register_buffer("r_max", torch.tensor(r_max, dtype=dtype))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        p = self.p
        u = x / self.r_max
        coef_p = (p + 1.0) * (p + 2.0) / 2.0
        coef_p1 = p * (p + 2.0)
        coef_p2 = p * (p + 1.0) / 2
        envelope = (
            1.0
            - coef_p * torch.pow(u, p)
            + coef_p1 * torch.pow(u, p + 1)
            - coef_p2 * torch.pow(u, p + 2)
        )
        inside_cutoff = x < self.r_max
        return envelope * inside_cutoff

    def __repr__(self):
        return f"{self.__class__.__name__}(p={self.p}, r_max={self.r_max})"
82 |
--------------------------------------------------------------------------------
/models/schnet.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | from torch.nn import functional as F
5 | from torch_geometric.nn import SchNet
6 | from torch_geometric.nn import global_add_pool, global_mean_pool
7 |
8 |
class SchNetModel(SchNet):
    """
    SchNet model from "Schnet - a deep learning architecture for molecules and materials".

    This class extends the SchNet base class for PyG. It reuses the base
    class's embedding, distance expansion, interaction blocks and ``lin1``,
    replaces ``lin2`` so the head emits ``out_dim`` values, and pools node
    features per graph with a configurable global pooling function.
    """
    def __init__(
        self,
        hidden_channels: int = 128,
        in_dim: int = 1,
        out_dim: int = 1,
        num_filters: int = 128,
        num_layers: int = 6,
        num_gaussians: int = 50,
        cutoff: float = 10,
        max_num_neighbors: int = 32,
        pool: str = 'sum'
    ):
        """
        Initializes an instance of the SchNetModel class with the provided parameters.

        Parameters:
        - hidden_channels (int): Number of channels in the hidden layers (default: 128)
        - in_dim (int): Accepted for interface parity with the other models; not
          used here — the node embedding is the one built by the SchNet base class (default: 1)
        - out_dim (int): Output dimension of the model (default: 1)
        - num_filters (int): Number of filters used in convolutional layers (default: 128)
        - num_layers (int): Number of interaction blocks (default: 6)
        - num_gaussians (int): Number of Gaussian functions used for radial filters (default: 50)
        - cutoff (float): Cutoff distance for interactions (default: 10)
        - max_num_neighbors (int): Maximum number of neighboring atoms to consider (default: 32)
        - pool (str): Global pooling method, "sum" or "mean" (default: "sum")
        """
        super().__init__(
            hidden_channels,
            num_filters,
            num_layers,
            num_gaussians,
            cutoff,
            interaction_graph=None,
            max_num_neighbors=max_num_neighbors,
            readout=pool,
            dipole=False,
            mean=None,
            std=None,
            atomref=None
        )

        # Pick the per-graph readout; an unknown name raises KeyError.
        pooling_fns = {"mean": global_mean_pool, "sum": global_add_pool}
        self.pool = pooling_fns[pool]

        # Replace the base class's final layer so the head emits out_dim values.
        self.lin2 = torch.nn.Linear(hidden_channels // 2, out_dim)

    def forward(self, batch):
        """Return one ``out_dim`` prediction per graph in ``batch``."""
        src, dst = batch.edge_index
        dist = (batch.pos[src] - batch.pos[dst]).norm(dim=-1)
        dist_feats = self.distance_expansion(dist)

        x = self.embedding(batch.atoms)  # (n,) -> (n, d)
        for block in self.interactions:
            # Residual message passing update: (n, d) -> (n, d)
            x = x + block(x, batch.edge_index, dist, dist_feats)

        graph_repr = self.pool(x, batch.batch)  # (n, d) -> (batch_size, d)
        return self.lin2(self.act(self.lin1(graph_repr)))  # (batch_size, out_dim)
81 |
--------------------------------------------------------------------------------
/models/egnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 | from torch_geometric.nn import global_add_pool, global_mean_pool
4 |
5 | from models.layers.egnn_layer import EGNNLayer
6 |
7 |
class EGNNModel(torch.nn.Module):
    """
    E-GNN model from "E(n) Equivariant Graph Neural Networks".

    Embeds integer atom types, applies a stack of ``EGNNLayer`` updates to both
    node features and coordinates, pools per graph, and predicts with either a
    linear head over ``[features, coordinates]`` (equivariant tasks) or a
    two-layer MLP over features only (invariant tasks).
    """
    def __init__(
        self,
        num_layers: int = 5,
        emb_dim: int = 128,
        in_dim: int = 1,
        out_dim: int = 1,
        activation: str = "relu",
        norm: str = "layer",
        aggr: str = "sum",
        pool: str = "sum",
        residual: bool = True,
        equivariant_pred: bool = False
    ):
        """
        Initializes an instance of the EGNNModel class with the provided parameters.

        Parameters:
        - num_layers (int): Number of layers in the model (default: 5)
        - emb_dim (int): Dimension of the node embeddings (default: 128)
        - in_dim (int): Number of distinct atom types for the input embedding (default: 1)
        - out_dim (int): Output dimension of the model (default: 1)
        - activation (str): Activation function to be used (default: "relu")
        - norm (str): Normalization method to be used (default: "layer")
        - aggr (str): Aggregation method to be used (default: "sum")
        - pool (str): Global pooling method, "sum" or "mean" (default: "sum")
        - residual (bool): Whether to use residual connections on node features (default: True)
        - equivariant_pred (bool): Whether it is an equivariant prediction task (default: False)
        """
        super().__init__()
        self.equivariant_pred = equivariant_pred
        self.residual = residual

        # Submodules are created in this order: embedding, layers, predictor.
        self.emb_in = torch.nn.Embedding(in_dim, emb_dim)
        self.convs = torch.nn.ModuleList(
            [EGNNLayer(emb_dim, activation, norm, aggr) for _ in range(num_layers)]
        )

        pooling_fns = {"mean": global_mean_pool, "sum": global_add_pool}
        self.pool = pooling_fns[pool]

        if not equivariant_pred:
            # Invariant head: MLP over pooled scalar features only.
            self.pred = torch.nn.Sequential(
                torch.nn.Linear(emb_dim, emb_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(emb_dim, out_dim)
            )
        else:
            # Equivariant head: linear map over pooled [features, coordinates].
            self.pred = torch.nn.Linear(emb_dim + 3, out_dim)

    def forward(self, batch):
        """Return one ``out_dim`` prediction per graph in ``batch``."""
        h = self.emb_in(batch.atoms)  # (n,) -> (n, d)
        pos = batch.pos  # (n, 3)

        for conv in self.convs:
            # Both updates are computed from the current (h, pos) before either is replaced.
            h_update, pos = conv(h, pos, batch.edge_index)
            h = (h + h_update) if self.residual else h_update

        node_feats = torch.cat([h, pos], dim=-1) if self.equivariant_pred else h
        return self.pred(self.pool(node_feats, batch.batch))  # (batch_size, out_dim)
88 |
--------------------------------------------------------------------------------
/models/mace_modules/irreps_tools.py:
--------------------------------------------------------------------------------
1 | ###########################################################################################
2 | # Elementary tools for handling irreducible representations
3 | # Authors: Ilyes Batatia, Gregor Simm
4 | # This program is distributed under the MIT License (see MIT.md)
5 | ###########################################################################################
6 |
7 | from typing import List, Tuple
8 |
9 | import torch
10 | from e3nn import o3
11 | from e3nn.util.jit import compile_mode
12 |
13 |
14 | # Based on mir-group/nequip
def tp_out_irreps_with_instructions(
    irreps1: o3.Irreps, irreps2: o3.Irreps, target_irreps: o3.Irreps
) -> Tuple[o3.Irreps, List]:
    """Build sorted output irreps and "uvu" instructions for a weighted tensor product.

    Each pair (irreps1[i], irreps2[j]) contributes one output entry for every
    irrep in ``ir_in * ir_edge`` that is also in ``target_irreps``; that entry
    keeps the multiplicity of irreps1[i]. The output irreps are then sorted and
    each instruction's output index is remapped to the sorted position.

    Returns:
        (sorted output irreps, list of (i_in1, i_in2, i_out, "uvu", True)).
    """
    out_entries: List[Tuple[int, o3.Irrep]] = []
    raw_instructions = []
    for i, (mul, ir_in) in enumerate(irreps1):
        for j, (_, ir_edge) in enumerate(irreps2):
            for ir_out in ir_in * ir_edge:  # | l1 - l2 | <= l <= l1 + l2
                if ir_out not in target_irreps:
                    continue
                # Output index is the position this entry is about to take.
                raw_instructions.append((i, j, len(out_entries), "uvu", True))
                out_entries.append((mul, ir_out))

    # Sorting lets the downstream o3.Linear simplify the irreps; permut maps
    # each unsorted index to its sorted position.
    irreps_out, permut, _ = o3.Irreps(out_entries).sort()

    remapped = [
        (in1, in2, permut[out_idx], mode, train)
        for in1, in2, out_idx, mode, train in raw_instructions
    ]
    return irreps_out, remapped
43 |
44 |
def linear_out_irreps(irreps: o3.Irreps, target_irreps: o3.Irreps) -> o3.Irreps:
    """Give each irrep of ``irreps`` the multiplicity it has in ``target_irreps``.

    Assumes simplified irreps: for each irrep the first matching entry of
    ``target_irreps`` supplies the multiplicity.

    Raises:
        RuntimeError: if an irrep of ``irreps`` does not occur in ``target_irreps``.
    """
    matched = []
    for _, ir_in in irreps:
        target_mul = next(
            (mul for mul, ir_out in target_irreps if ir_in == ir_out), None
        )
        if target_mul is None:
            raise RuntimeError(f"{ir_in} not in {target_irreps}")
        matched.append((target_mul, ir_in))
    return o3.Irreps(matched)
61 |
62 |
@compile_mode("script")
class reshape_irreps(torch.nn.Module):
    """Reshape a flat irreps tensor into per-irrep (batch, mul, ir.dim) blocks.

    The input is split column-wise following ``irreps``; each block is reshaped
    to (batch, mul, ir.dim) and the blocks are concatenated along the last
    axis. Concatenation requires every entry of ``irreps`` to share the same
    multiplicity.
    """

    def __init__(self, irreps: o3.Irreps) -> None:
        super().__init__()
        self.irreps = irreps

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        batch, _ = tensor.shape
        start = 0
        blocks = []
        for mul, ir in self.irreps:
            width = mul * ir.dim
            field = tensor[:, start : start + width]  # [batch, mul * ir.dim]
            start += width
            blocks.append(field.reshape(batch, mul, ir.dim))
        return torch.cat(blocks, dim=-1)
80 |
81 |
def irreps2gate(irreps):
    """Partition irreps into the (scalars, gates, gated) triple used by e3nn's Gate.

    Even scalars (l == 0, p == +1) go to ``irreps_scalars``; everything else is
    gated. For every gated entry a "0e" gate with the same multiplicity is
    created. All three results are simplified.

    Returns:
        (irreps_scalars, irreps_gates, irreps_gated)
    """
    scalars, gated = [], []
    for mul, ir in irreps:
        is_even_scalar = ir.l == 0 and ir.p == 1
        (scalars if is_even_scalar else gated).append((mul, ir))

    irreps_scalars = o3.Irreps(scalars).simplify()
    irreps_gated = o3.Irreps(gated).simplify()

    gate_ir = '0e' if irreps_gated.dim > 0 else None
    irreps_gates = o3.Irreps([(mul, gate_ir) for mul, _ in irreps_gated]).simplify()
    return irreps_scalars, irreps_gates, irreps_gated
98 |
--------------------------------------------------------------------------------
/models/layers/tfn_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_scatter import scatter
3 | import e3nn
4 |
5 | from models.mace_modules.irreps_tools import irreps2gate
6 |
7 |
class TensorProductConvLayer(torch.nn.Module):
    """Tensor Field Network GNN Layer in e3nn

    Implements a Tensor Field Network equivariant GNN layer for higher-order tensors, using e3nn.
    Implementation adapted from: https://github.com/gcorso/DiffDock/

    Paper: Tensor Field Networks, Thomas, Smidt et al.
    """
    def __init__(
        self,
        in_irreps,
        out_irreps,
        sh_irreps,
        edge_feats_dim,
        mlp_dim,
        aggr="add",
        batch_norm=False,
        gate=False,
    ):
        """
        Args:
            in_irreps: (e3nn.o3.Irreps) Input irreps dimensions
            out_irreps: (e3nn.o3.Irreps) Output irreps dimensions
            sh_irreps: (e3nn.o3.Irreps) Spherical harmonic irreps dimensions
            edge_feats_dim: (int) Edge feature dimensions
            mlp_dim: (int) Hidden dimension of MLP for computing tensor product weights
            aggr: (str) Message passing aggregator (torch_scatter reduce name)
            batch_norm: (bool) Whether to apply equivariant batch norm
            gate: (bool) Whether to apply gated non-linearity
        """
        super().__init__()
        self.in_irreps = in_irreps
        self.out_irreps = out_irreps
        self.sh_irreps = sh_irreps
        self.edge_feats_dim = edge_feats_dim
        self.aggr = aggr

        if gate:
            # Optionally apply gated non-linearity
            irreps_scalars, irreps_gates, irreps_gated = irreps2gate(
                e3nn.o3.Irreps(out_irreps)
            )
            act_scalars = [torch.nn.functional.silu for _ in irreps_scalars]
            act_gates = [torch.sigmoid for _ in irreps_gates]
            if irreps_gated.num_irreps == 0:
                # Nothing to gate: apply SiLU directly. e3nn's Activation
                # expects one activation per entry of its input irreps, so
                # repeat SiLU for every entry rather than passing one element.
                irreps_act = e3nn.o3.Irreps(out_irreps)
                self.gate = e3nn.nn.Activation(
                    irreps_act, acts=[torch.nn.functional.silu] * len(irreps_act)
                )
            else:
                self.gate = e3nn.nn.Gate(
                    irreps_scalars,
                    act_scalars,  # scalar
                    irreps_gates,
                    act_gates,  # gates (scalars)
                    irreps_gated,  # gated tensors
                )
                # The gate consumes extra gate scalars, so the tensor product
                # must produce the gate's input irreps.
                self.out_irreps = out_irreps = self.gate.irreps_in
        else:
            self.gate = None

        # Tensor product over edges to construct messages
        self.tp = e3nn.o3.FullyConnectedTensorProduct(
            in_irreps, sh_irreps, out_irreps, shared_weights=False
        )

        # MLP that maps edge features to per-edge tensor product weights
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(edge_feats_dim, mlp_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(mlp_dim, self.tp.weight_numel),
        )

        # Optional equivariant batch norm
        self.batch_norm = e3nn.nn.BatchNorm(out_irreps) if batch_norm else None

    def forward(self, node_attr, edge_index, edge_sh, edge_feat):
        """Compute messages from ``node_attr[dst]`` and aggregate them onto ``src``.

        Returns a tensor with one row per node (``node_attr.shape[0]`` rows).
        """
        src, dst = edge_index
        # Compute messages
        tp = self.tp(node_attr[dst], edge_sh, self.fc(edge_feat))
        # Aggregate messages. dim_size keeps one output row per node even when
        # the highest-indexed nodes receive no messages; without it the output
        # would be truncated to max(src) + 1 rows.
        out = scatter(tp, src, dim=0, dim_size=node_attr.shape[0], reduce=self.aggr)
        # Optionally apply gated non-linearity and/or batch norm
        if self.gate is not None:
            out = self.gate(out)
        if self.batch_norm is not None:
            out = self.batch_norm(out)
        return out
94 |
--------------------------------------------------------------------------------
/models/dimenet.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Union
2 |
3 | import torch
4 | from torch.nn import functional as F
5 | from torch_geometric.nn import DimeNetPlusPlus
6 | from torch_scatter import scatter
7 |
8 |
class DimeNetPPModel(DimeNetPlusPlus):
    """
    DimeNet model from "Directional message passing for molecular graphs".

    This class extends the DimeNetPlusPlus base class for PyG.
    """
    def __init__(
        self,
        hidden_channels: int = 128,
        in_dim: int = 1,
        out_dim: int = 1,
        num_layers: int = 4,
        int_emb_size: int = 64,
        basis_emb_size: int = 8,
        out_emb_channels: int = 256,
        num_spherical: int = 7,
        num_radial: int = 6,
        cutoff: float = 10,
        max_num_neighbors: int = 32,
        envelope_exponent: int = 5,
        num_before_skip: int = 1,
        num_after_skip: int = 2,
        num_output_layers: int = 3,
        act: Union[str, Callable] = 'swish'
    ):
        """
        Initializes an instance of the DimeNetPPModel class with the provided parameters.

        Parameters:
        - hidden_channels (int): Number of channels in the hidden layers (default: 128)
        - in_dim (int): Accepted for interface parity with the other models; not used here (default: 1)
        - out_dim (int): Output dimension of the model (default: 1)
        - num_layers (int): Number of interaction blocks (default: 4)
        - int_emb_size (int): Embedding size for interaction features (default: 64)
        - basis_emb_size (int): Embedding size for basis functions (default: 8)
        - out_emb_channels (int): Number of channels in the output embeddings (default: 256)
        - num_spherical (int): Number of spherical harmonics (default: 7)
        - num_radial (int): Number of radial basis functions (default: 6)
        - cutoff (float): Cutoff distance for interactions (default: 10)
        - max_num_neighbors (int): Maximum number of neighboring atoms to consider (default: 32)
        - envelope_exponent (int): Exponent of the envelope function (default: 5)
        - num_before_skip (int): Number of layers before the skip connections (default: 1)
        - num_after_skip (int): Number of layers after the skip connections (default: 2)
        - num_output_layers (int): Number of output layers (default: 3)
        - act (Union[str, Callable]): Activation function (default: 'swish')

        Note:
        - The `act` parameter can be either a string naming a built-in activation function,
          or a callable used as a custom activation function.
        """
        super().__init__(
            hidden_channels,
            out_dim,
            num_layers,
            int_emb_size,
            basis_emb_size,
            out_emb_channels,
            num_spherical,
            num_radial,
            cutoff,
            max_num_neighbors,
            envelope_exponent,
            num_before_skip,
            num_after_skip,
            num_output_layers,
            act
        )

    def forward(self, batch):
        """Return one ``out_dim`` prediction per graph.

        Reads ``batch.atoms``, ``batch.pos``, ``batch.edge_index`` and, when
        present, ``batch.batch`` (graph index per node). Without ``batch.batch``
        all node outputs are summed into a single prediction.
        """
        num_nodes = batch.atoms.size(0)
        i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = self.triplets(
            batch.edge_index, num_nodes=num_nodes)

        # Calculate distances.
        dist = (batch.pos[i] - batch.pos[j]).pow(2).sum(dim=-1).sqrt()

        # Calculate angles.
        pos_i = batch.pos[idx_i]
        pos_ji, pos_ki = batch.pos[idx_j] - pos_i, batch.pos[idx_k] - pos_i
        a = (pos_ji * pos_ki).sum(dim=-1)
        # Explicit dim=-1: without it torch.cross uses the first dimension of
        # size 3, which is the triplet dimension when there are exactly 3 triplets.
        b = torch.cross(pos_ji, pos_ki, dim=-1).norm(dim=-1)
        angle = torch.atan2(b, a)

        rbf = self.rbf(dist)
        sbf = self.sbf(dist, angle, idx_kj)

        # Embedding block.
        x = self.emb(batch.atoms, rbf, i, j)
        P = self.output_blocks[0](x, rbf, i, num_nodes=num_nodes)

        # Interaction blocks.
        for interaction_block, output_block in zip(self.interaction_blocks,
                                                   self.output_blocks[1:]):
            x = interaction_block(x, rbf, sbf, idx_kj, idx_ji)
            # num_nodes keeps every output block at one row per node, so the
            # accumulation cannot hit a shape mismatch when trailing nodes
            # have no edges.
            P += output_block(x, rbf, i, num_nodes=num_nodes)

        graph_index = getattr(batch, "batch", None)
        if graph_index is None:
            return P.sum(dim=0)
        return scatter(P, graph_index, dim=0)
106 |
--------------------------------------------------------------------------------
/experiments/utils/train_utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | import random
3 | from tqdm.autonotebook import tqdm # from tqdm import tqdm
4 | import numpy as np
5 | from sklearn.metrics import accuracy_score
6 |
7 | import torch
8 | import torch.nn.functional as F
9 |
10 |
def seed(seed=0):
    """Seed every RNG used in the experiments and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global RNG, and PyTorch's CPU and
    CUDA (current and all) generators, then disables cuDNN autotuning and
    forces deterministic cuDNN kernels.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
19 |
20 |
def train(model, train_loader, optimizer, device):
    """Run one epoch of cross-entropy training.

    Returns the loss averaged over graphs: each batch's mean loss is weighted
    by ``batch.num_graphs`` and the total is divided by the dataset size.
    """
    model.train()
    running_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        batch_loss = F.cross_entropy(model(data), data.y)
        batch_loss.backward()
        running_loss += batch_loss.item() * data.num_graphs
        optimizer.step()
    return running_loss / len(train_loader.dataset)
33 |
34 |
def eval(model, loader, device):
    """Return the classification accuracy of ``model`` over ``loader``, in percent.

    The predicted class is the argmax over dim 1 of the model output. Labels
    and predictions are flattened to 1-D before comparison. Gradients are
    disabled for the whole pass.

    Raises:
        ValueError: if the number of predictions and labels differ.
    """
    model.eval()
    y_pred = []
    y_true = []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            # torch-native argmax instead of np.argmax on a torch tensor,
            # which relied on NumPy's duck-typing fallback.
            y_pred.append(model(batch).argmax(dim=1).cpu())
            y_true.append(batch.y.cpu())
    pred = torch.cat(y_pred, dim=0).reshape(-1)
    true = torch.cat(y_true, dim=0).reshape(-1)
    if pred.numel() != true.numel():
        raise ValueError(
            f"Found {pred.numel()} predictions but {true.numel()} labels"
        )
    correct = (pred == true).sum().item()
    return correct / true.numel() * 100  # return percentage
48 |
49 |
def _run_experiment(model, train_loader, val_loader, test_loader, n_epochs=100, verbose=True, device='cpu'):
    """Train `model` once for `n_epochs` epochs and report validation/test accuracy.

    Optimisation uses Adam (lr=1e-4) with a ReduceLROnPlateau scheduler stepped on
    validation accuracy. The test set is evaluated only on epochs where validation
    accuracy is greater than or equal to the best seen so far. The returned test
    accuracy therefore belongs to the latest best-validation epoch.

    Args:
        model: torch.nn.Module whose output is fed to `F.cross_entropy` against
            `batch.y` (see `train`). It is moved to `device`.
        train_loader, val_loader, test_loader: batch iterables consumed by `train` / `eval`.
        n_epochs: number of training epochs; must be at least 1.
        verbose: if True, print the parameter count, progress every 10 epochs and a summary.
        device: device passed to `model.to` and to `train` / `eval`.

    Returns:
        Tuple `(best_val_acc, test_acc, train_time, perf_per_epoch)`. `perf_per_epoch`
        holds one `(test_acc, val_acc, epoch, model class name)` tuple per epoch.

    Raises:
        ValueError: if `n_epochs` < 1. No epoch would run, so there would be no
            accuracy to report (previously this surfaced as an UnboundLocalError).
    """
    if n_epochs < 1:
        raise ValueError(f"n_epochs must be at least 1, got {n_epochs}")

    # numel() yields an int for every parameter, including 0-dim ones
    # (np.prod of an empty shape list returns the float 1.0).
    total_param = sum(param.numel() for param in model.parameters())
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.9, patience=25, min_lr=0.00001)

    if verbose:
        print(f"Running experiment for {type(model).__name__}.")
        print(f'Total parameters: {total_param}')
        print("\nStart training:")

    best_val_acc = None
    perf_per_epoch = []  # Track Test/Val performance vs. epoch (for plotting)
    # Learning rate in effect for the upcoming epoch; refreshed after each scheduler step.
    lr = optimizer.param_groups[0]['lr']
    start = time.time()
    for epoch in range(1, n_epochs + 1):
        # Train model for one epoch, return avg. training loss
        loss = train(model, train_loader, optimizer, device)

        # Evaluate model on validation set
        val_acc = eval(model, val_loader, device)

        if best_val_acc is None or val_acc >= best_val_acc:
            # Evaluate model on test set if validation metric improves
            test_acc = eval(model, test_loader, device)
            best_val_acc = val_acc

        if epoch % 10 == 0 and verbose:
            print(f'Epoch: {epoch:03d}, LR: {lr:.5f}, Loss: {loss:.5f}, '
                  f'Val Acc: {val_acc:.3f}, Test Acc: {test_acc:.3f}')

        perf_per_epoch.append((test_acc, val_acc, epoch, type(model).__name__))
        scheduler.step(val_acc)
        lr = optimizer.param_groups[0]['lr']

    train_time = time.time() - start
    if verbose:
        print(f"\nDone! Training took {train_time:.2f}s. Best validation accuracy: {best_val_acc:.3f}, corresponding test accuracy: {test_acc:.3f}.")

    return best_val_acc, test_acc, train_time, perf_per_epoch
96 |
97 |
def run_experiment(model, train_loader, val_loader, test_loader, n_epochs=100, n_times=100, verbose=False, device='cpu'):
    """Repeat `_run_experiment` `n_times` times and print mean ± std of the results.

    Run `idx` is seeded with `seed(idx)`. Every run starts from the weights (and
    buffers) `model` had when this function was called. Without this, run k would
    simply continue training the model left over from run k-1, and the averaged
    numbers would not describe independent runs. On return, `model` holds the
    weights trained in the last run.

    Args:
        model: torch.nn.Module to train; forwarded to `_run_experiment`.
        train_loader, val_loader, test_loader: forwarded to `_run_experiment`.
        n_epochs: epochs per run.
        n_times: number of repeated runs.
        verbose: forwarded to `_run_experiment`.
        device: forwarded to `_run_experiment`.

    Returns:
        Tuple of lists `(best_val_acc_list, test_acc_list, train_time_list)`, one entry per run.
    """
    print(f"Running experiment for {type(model).__name__} ({device}).")

    # Snapshot the untrained state once so each repetition restarts from it.
    initial_state = copy.deepcopy(model.state_dict())

    best_val_acc_list = []
    test_acc_list = []
    train_time_list = []
    for idx in tqdm(range(n_times)):
        seed(idx)  # set random seed
        model.load_state_dict(initial_state)
        best_val_acc, test_acc, train_time, _ = _run_experiment(
            model, train_loader, val_loader, test_loader, n_epochs, verbose, device)
        best_val_acc_list.append(best_val_acc)
        test_acc_list.append(test_acc)
        train_time_list.append(train_time)

    print(f'\nDone! Averaged over {n_times} runs: \n '
          f'- Training time: {np.mean(train_time_list):.2f}s ± {np.std(train_time_list):.2f}. \n '
          f'- Best validation accuracy: {np.mean(best_val_acc_list):.3f} ± {np.std(best_val_acc_list):.3f}. \n'
          f'- Test accuracy: {np.mean(test_acc_list):.1f} ± {np.std(test_acc_list):.1f}. \n')

    return best_val_acc_list, test_acc_list, train_time_list
117 |
--------------------------------------------------------------------------------
/models/mace_modules/cg.py:
--------------------------------------------------------------------------------
1 | ###########################################################################################
2 | # Higher Order Real Clebsch Gordan (based on e3nn by Mario Geiger)
3 | # Authors: Ilyes Batatia
4 | # This program is distributed under the MIT License (see MIT.md)
5 | ###########################################################################################
6 |
7 | import collections
8 | from typing import List, Union
9 |
10 | import torch
11 | from e3nn import o3
12 |
13 | # Based on e3nn
14 |
# Path records used by `_wigner_nj`:
#   _TP    - one coupling step: `op` holds the irreps involved, `args` the two input paths.
#   _INPUT - the slice [start, stop) of input number `tensor`.
_TP = collections.namedtuple("_TP", ["op", "args"])
_INPUT = collections.namedtuple("_INPUT", ["tensor", "start", "stop"])
17 |
18 |
def _wigner_nj(
    irrepss: List[o3.Irreps],
    normalization: str = "component",
    filter_ir_mid=None,
    dtype=None,
):
    """Recursively build coupling tensors for the tensor product of several irreps.

    Base case (a single entry in ``irrepss``): for every copy of every irrep, return
    the rows of an identity matrix that select that copy's slice of the input.

    Recursive case: every path produced for ``irrepss[:-1]`` is coupled with each
    irrep of ``irrepss[-1]`` through ``o3.wigner_3j``. This yields one entry per output
    irrep ``ir_out`` in ``ir_left * ir`` and per multiplicity copy of the right irrep.

    Args:
        irrepss: irreps (anything ``o3.Irreps`` accepts) to couple, left to right.
        normalization: ``"component"`` scales each Wigner 3j block by
            ``sqrt(ir_out.dim)``; ``"norm"`` scales it by ``sqrt(ir_left.dim * ir.dim)``;
            any other value applies no scaling.
        filter_ir_mid: optional collection of irreps (anything ``o3.Irrep`` accepts).
            When given, coupled irreps not in it are dropped at every coupling step,
            including the outermost one. It is not applied in the base case.
        dtype: dtype passed to ``torch.eye``, ``o3.wigner_3j`` and ``torch.zeros``.

    Returns:
        List of ``(ir_out, path, tensor)`` tuples sorted by ``ir_out``.
        ``path`` is an ``_INPUT`` (base case) or a nested ``_TP`` recording how the
        entry was built. ``tensor`` has shape
        ``(ir_out.dim, irrepss[0].dim, ..., irrepss[-1].dim)``.
    """
    irrepss = [o3.Irreps(irreps) for irreps in irrepss]
    if filter_ir_mid is not None:
        filter_ir_mid = [o3.Irrep(ir) for ir in filter_ir_mid]

    if len(irrepss) == 1:
        # Base case: one selector matrix per irrep copy, shape (ir.dim, irreps.dim).
        (irreps,) = irrepss
        ret = []
        e = torch.eye(irreps.dim, dtype=dtype)
        i = 0  # running offset into the flattened input
        for mul, ir in irreps:
            for _ in range(mul):
                sl = slice(i, i + ir.dim)
                ret += [(ir, _INPUT(0, sl.start, sl.stop), e[sl])]
                i += ir.dim
        return ret

    *irrepss_left, irreps_right = irrepss
    ret = []
    for ir_left, path_left, C_left in _wigner_nj(
        irrepss_left,
        normalization=normalization,
        filter_ir_mid=filter_ir_mid,
        dtype=dtype,
    ):
        i = 0  # offset of the current irrep block inside irreps_right
        for mul, ir in irreps_right:
            for ir_out in ir_left * ir:
                if filter_ir_mid is not None and ir_out not in filter_ir_mid:
                    continue

                C = o3.wigner_3j(ir_out.l, ir_left.l, ir.l, dtype=dtype)
                if normalization == "component":
                    C *= ir_out.dim ** 0.5
                if normalization == "norm":
                    C *= ir_left.dim ** 0.5 * ir.dim ** 0.5

                # Contract the left path's tensor (flattened over its input axes)
                # with the 3j symbol, then restore one axis per left input.
                C = torch.einsum("jk,ijl->ikl", C_left.flatten(1), C)
                C = C.reshape(
                    ir_out.dim, *(irreps.dim for irreps in irrepss_left), ir.dim
                )
                for u in range(mul):
                    # Embed C at the slice of the u-th copy of `ir` in irreps_right;
                    # all other positions of the last axis stay zero.
                    E = torch.zeros(
                        ir_out.dim,
                        *(irreps.dim for irreps in irrepss_left),
                        irreps_right.dim,
                        dtype=dtype,
                    )
                    sl = slice(i + u * ir.dim, i + (u + 1) * ir.dim)
                    E[..., sl] = C
                    ret += [
                        (
                            ir_out,
                            _TP(
                                op=(ir_left, ir, ir_out),
                                args=(
                                    path_left,
                                    _INPUT(len(irrepss_left), sl.start, sl.stop),
                                ),
                            ),
                            E,
                        )
                    ]
            i += mul * ir.dim
    return sorted(ret, key=lambda x: x[0])
89 |
90 |
def U_matrix_real(
    irreps_in: Union[str, o3.Irreps],
    irreps_out: Union[str, o3.Irreps],
    correlation: int,
    normalization: str = "component",
    filter_ir_mid=None,
    dtype=None,
):
    """Group the coupling tensors of ``correlation`` copies of ``irreps_in`` by output irrep.

    Args:
        irreps_in: irreps of each of the ``correlation`` identical inputs.
        irreps_out: only coupling paths whose output irrep is in these irreps are kept.
        correlation: number of copies of ``irreps_in`` to couple.
        normalization: forwarded to ``_wigner_nj``.
        filter_ir_mid: forwarded to ``_wigner_nj``. Overwritten (the caller's value is
            ignored) when ``correlation == 4``.
        dtype: forwarded to ``_wigner_nj``.

    Returns:
        Flat list ``[ir_0, U_0, ir_1, U_1, ...]``. Each ``U_k`` stacks, along a new
        last axis, the squeezed coupling tensors of all consecutive paths whose output
        irrep is ``ir_k``.

    NOTE(review): if no path's irrep is in ``irreps_out``, ``last_ir`` is never bound
    and the final ``out += [last_ir, stack]`` raises UnboundLocalError. An empty
    ``wigners`` list raises IndexError at ``wigners[0]``.
    """
    irreps_out = o3.Irreps(irreps_out)
    irrepss = [o3.Irreps(irreps_in)] * correlation
    if correlation == 4:
        # Hard-coded intermediate filter: l = 0..11 with parity (-1)^l.
        # Presumably a cost/size restriction for the 4-body case -- TODO confirm.
        filter_ir_mid = [
            (0, 1),
            (1, -1),
            (2, 1),
            (3, -1),
            (4, 1),
            (5, -1),
            (6, 1),
            (7, -1),
            (8, 1),
            (9, -1),
            (10, 1),
            (11, -1),
        ]
    wigners = _wigner_nj(irrepss, normalization, filter_ir_mid, dtype)
    current_ir = wigners[0][0]
    out = []
    # NOTE(review): starts as an empty 1-D default-dtype tensor; the first torch.cat
    # relies on torch accepting empty 1-D tensors in cat -- confirm dtype handling
    # when `dtype` differs from the default.
    stack = torch.tensor([])

    # `wigners` is sorted by irrep, so paths with the same irrep are consecutive.
    for ir, _, base_o3 in wigners:
        if ir in irreps_out and ir == current_ir:
            # Same irrep as the current group: append along the last axis.
            stack = torch.cat((stack, base_o3.squeeze().unsqueeze(-1)), dim=-1)
            last_ir = current_ir
        elif ir in irreps_out and ir != current_ir:
            # New irrep: flush the previous group (if any) and start a new one.
            if len(stack) != 0:
                out += [last_ir, stack]
            stack = base_o3.squeeze().unsqueeze(-1)
            current_ir, last_ir = ir, ir
        else:
            # Irrep not requested in irreps_out: skip it but track the current irrep.
            current_ir = ir
    out += [last_ir, stack]
    return out
134 |
--------------------------------------------------------------------------------
/models/spherenet.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 |
3 | import torch
4 | from torch.nn import functional as F
5 | from torch_scatter import scatter
6 |
7 | from models.layers.spherenet_layer import *
8 |
9 |
class SphereNetModel(torch.nn.Module):
    """
    SphereNet model from "Spherical Message Passing for 3D Molecular Graphs".

    Graph-level outputs are obtained by summing (scatter-add over `batch`) the
    node-level outputs of the last node-update block. The virtual-node (`update_u`)
    blocks are constructed but not used in `forward`.
    """
    def __init__(
        self,
        cutoff: float = 10,
        num_layers: int = 4,
        hidden_channels: int = 128,
        in_dim: int = 1,
        out_dim: int = 1,
        int_emb_size: int = 64,
        basis_emb_size_dist: int = 8,
        basis_emb_size_angle: int = 8,
        basis_emb_size_torsion: int = 8,
        out_emb_channels: int = 128,
        num_spherical: int = 7,
        num_radial: int = 6,
        envelope_exponent: int = 5,
        num_before_skip: int = 1,
        num_after_skip: int = 2,
        num_output_layers: int = 2,
        act: Callable = swish,
        output_init: str = 'GlorotOrthogonal',
        use_node_features: bool = True
    ):
        """
        Build the SphereNet blocks.

        Parameters:
        - cutoff (float): cutoff distance passed to the basis embedding (default: 10)
        - num_layers (int): number of interaction layers (default: 4)
        - hidden_channels (int): width of the hidden edge/node features (default: 128)
        - in_dim (int): input dimension; not used by the blocks built here (default: 1)
        - out_dim (int): output dimension of each node-update block (default: 1)
        - int_emb_size (int): embedding size inside the edge-update blocks (default: 64)
        - basis_emb_size_dist (int): embedding size for the distance basis (default: 8)
        - basis_emb_size_angle (int): embedding size for the angle basis (default: 8)
        - basis_emb_size_torsion (int): embedding size for the torsion basis (default: 8)
        - out_emb_channels (int): hidden width of the node-update blocks (default: 128)
        - num_spherical (int): number of spherical basis functions (default: 7)
        - num_radial (int): number of radial basis functions (default: 6)
        - envelope_exponent (int): exponent of the envelope function (default: 5)
        - num_before_skip (int): layers before the skip connection in edge updates (default: 1)
        - num_after_skip (int): layers after the skip connection in edge updates (default: 2)
        - num_output_layers (int): layers in each node-update block (default: 2)
        - act (Callable): activation function (default: swish)
        - output_init (str): initialisation scheme of the output layers (default: 'GlorotOrthogonal')
        - use_node_features (bool): whether the initial edge embedding uses node features (default: True)
        """
        super().__init__()

        self.cutoff = cutoff

        def make_node_block():
            # Node-update block shared in shape by the initial embedding and every layer.
            return update_v(hidden_channels, out_emb_channels, out_dim, num_output_layers, act, output_init)

        def make_edge_block():
            return update_e(
                hidden_channels, int_emb_size, basis_emb_size_dist, basis_emb_size_angle,
                basis_emb_size_torsion, num_spherical, num_radial, num_before_skip,
                num_after_skip, act,
            )

        # NOTE: construction order is kept fixed on purpose -- it determines how the
        # global RNG is consumed while modules are created.
        self.init_e = init(num_radial, hidden_channels, act, use_node_features=use_node_features)
        self.init_v = make_node_block()
        self.init_u = update_u()
        self.emb = emb(num_spherical, num_radial, self.cutoff, envelope_exponent)

        self.update_vs = torch.nn.ModuleList(make_node_block() for _ in range(num_layers))
        self.update_es = torch.nn.ModuleList(make_edge_block() for _ in range(num_layers))
        self.update_us = torch.nn.ModuleList(update_u() for _ in range(num_layers))

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the embedding, edge-update and node-update blocks (update_u blocks excluded)."""
        blocks = (self.init_e, self.init_v, self.emb, *self.update_es, *self.update_vs)
        for block in blocks:
            block.reset_parameters()

    def forward(self, batch_data):
        """Return one output row per graph for a batch exposing `atoms`, `pos`, `batch` and `edge_index`."""
        atoms = batch_data.atoms
        positions = batch_data.pos
        graph_index = batch_data.batch
        dist, angle, torsion, i, j, idx_kj, idx_ji = xyz_to_dat(
            positions, batch_data.edge_index, atoms.size(0), use_torsion=True)

        basis = self.emb(dist, angle, torsion, idx_kj)

        # Initial edge and node features (virtual-node/graph features are disabled).
        edge_feat = self.init_e(atoms, basis, i, j)
        node_feat = self.init_v(edge_feat, i)

        for edge_block, node_block in zip(self.update_es, self.update_vs):
            edge_feat = edge_block(edge_feat, basis, idx_kj, idx_ji)
            node_feat = node_block(edge_feat, i)

        return scatter(node_feat, graph_index, dim=0, reduce='add')
111 |
--------------------------------------------------------------------------------
/models/gvpgnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 | from torch_geometric.nn import global_add_pool, global_mean_pool
4 |
5 | from models.mace_modules.blocks import RadialEmbeddingBlock
6 | import models.layers.gvp_layer as gvp
7 |
8 |
class GVPGNNModel(torch.nn.Module):
    """
    GVP-GNN model from "Equivariant Graph Neural Networks for 3D Macromolecular Structure".

    Node and edge features are (scalar, vector) tuples processed by Geometric Vector
    Perceptron layers; a global pooling and a predictor head produce one output per graph.
    """
    def __init__(
        self,
        r_max: float = 10.0,
        num_bessel: int = 8,
        num_polynomial_cutoff: int = 5,
        num_layers: int = 5,
        in_dim=1,
        out_dim=1,
        s_dim: int = 128,
        v_dim: int = 16,
        s_dim_edge: int = 32,
        v_dim_edge: int = 1,
        pool: str = "sum",
        residual: bool = True,
        equivariant_pred: bool = False
    ):
        """
        Build the GVP-GNN.

        Parameters:
        - r_max (float): cutoff radius of the radial (Bessel) embedding (default: 10.0)
        - num_bessel (int): number of Bessel basis functions (default: 8)
        - num_polynomial_cutoff (int): polynomial cutoff setting forwarded to RadialEmbeddingBlock (default: 5)
        - num_layers (int): number of GVP convolution layers (default: 5)
        - in_dim (int): number of atom types for the input embedding (default: 1)
        - out_dim (int): output dimension (default: 1)
        - s_dim (int): scalar channels per node (default: 128)
        - v_dim (int): vector channels per node (default: 16)
        - s_dim_edge (int): scalar channels per edge (default: 32)
        - v_dim_edge (int): vector channels per edge (default: 1)
        - pool (str): global pooling, "sum" or "mean" (default: "sum")
        - residual (bool): use residual connections in the conv layers (default: True)
        - equivariant_pred (bool): predict from scalar and vector features instead of scalars only (default: False)
        """
        super().__init__()

        self.r_max = r_max
        self.num_layers = num_layers
        self.equivariant_pred = equivariant_pred
        self.s_dim = s_dim
        self.v_dim = v_dim

        conv_activations = (F.relu, None)
        node_dims = (s_dim, v_dim)
        edge_dims = (s_dim_edge, v_dim_edge)

        # NOTE: module construction order is kept fixed (it determines RNG consumption).

        # Node embedding: atom type -> scalars -> (scalars, vectors)
        self.emb_in = torch.nn.Embedding(in_dim, s_dim)
        self.W_v = torch.nn.Sequential(
            gvp.LayerNorm((s_dim, 0)),
            gvp.GVP((s_dim, 0), node_dims,
                    activations=(None, None), vector_gate=True),
        )

        # Edge embedding: (radial basis, unit direction) -> (scalars, vectors)
        self.radial_embedding = RadialEmbeddingBlock(
            r_max=r_max,
            num_bessel=num_bessel,
            num_polynomial_cutoff=num_polynomial_cutoff,
        )
        radial_dim = self.radial_embedding.out_dim
        self.W_e = torch.nn.Sequential(
            gvp.LayerNorm((radial_dim, 1)),
            gvp.GVP((radial_dim, 1), edge_dims,
                    activations=(None, None), vector_gate=True),
        )

        # Stack of GNN layers
        self.layers = torch.nn.ModuleList(
            gvp.GVPConvLayer(
                node_dims, edge_dims,
                activations=conv_activations, vector_gate=True,
                residual=residual,
            )
            for _ in range(num_layers)
        )

        # Global pooling/readout function
        pooling_fns = {"mean": global_mean_pool, "sum": global_add_pool}
        self.pool = pooling_fns[pool]

        if not self.equivariant_pred:
            # MLP predictor for invariant tasks using only scalar features
            self.pred = torch.nn.Sequential(
                torch.nn.Linear(s_dim, s_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(s_dim, out_dim),
            )
        else:
            # Linear predictor on merged scalar + flattened vector features
            self.pred = torch.nn.Linear(s_dim + v_dim * 3, out_dim)

    def forward(self, batch):
        """Return predictions of shape (batch_size, out_dim) for a batch with `atoms`, `pos`, `edge_index`, `batch`."""
        row, col = batch.edge_index[0], batch.edge_index[1]
        displacement = batch.pos[row] - batch.pos[col]  # [n_edges, 3]
        distance = torch.linalg.norm(displacement, dim=-1, keepdim=True)  # [n_edges, 1]

        node_scalars = self.emb_in(batch.atoms)  # (n,) -> (n, d)
        # Zero-length edges give 0/0; nan_to_num maps those directions to 0.
        edge_input = (
            self.radial_embedding(distance),
            torch.nan_to_num(torch.div(displacement, distance)).unsqueeze_(-2),
        )

        h_V = self.W_v(node_scalars)
        h_E = self.W_e(edge_input)

        for conv in self.layers:
            h_V = conv(h_V, batch.edge_index, h_E)

        pooled = self.pool(gvp._merge(*h_V), batch.batch)  # (n, d) -> (batch_size, d)

        if self.equivariant_pred:
            return self.pred(pooled)
        # Invariant prediction uses the scalar channels only
        return self.pred(pooled[:, :self.s_dim])  # (batch_size, out_dim)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ⚔️ Geometric GNN Dojo
2 |
3 | *Geometric GNN Dojo* is a pedagogical resource for beginners and experts to explore the design space of **Graph Neural Networks for geometric graphs**.
4 |
5 | Check out the accompanying paper ['On the Expressive Power of Geometric Graph Neural Networks'](https://arxiv.org/abs/2301.09308), which studies the expressivity and theoretical limits of geometric GNNs.
6 | > Chaitanya K. Joshi*, Cristian Bodnar*, Simon V. Mathis, Taco Cohen, and Pietro Liò. On the Expressive Power of Geometric Graph Neural Networks. *International Conference on Machine Learning*.
7 | >
8 | >[PDF](https://arxiv.org/pdf/2301.09308.pdf) | [Slides](https://www.chaitjo.com/publication/joshi-2023-expressive/Geometric_GNNs_Slides.pdf) | [Video](https://youtu.be/5ulJMtpiKGc)
9 |
10 | ❓**New to geometric GNNs:** try our practical notebook on [*Geometric GNNs 101*](geometric_gnn_101.ipynb), prepared for MPhil students at the University of Cambridge.
11 |
12 |
13 |
14 |
15 |
16 | ## Architectures
17 |
18 | The `/models` directory provides unified implementations of several popular geometric GNN architectures:
19 | - Invariant GNNs: [SchNet](https://arxiv.org/abs/1706.08566), [DimeNet](https://arxiv.org/abs/2003.03123), [SphereNet](https://arxiv.org/abs/2102.05013)
20 | - Equivariant GNNs using cartesian vectors: [E(n) Equivariant GNN](https://proceedings.mlr.press/v139/satorras21a.html), [GVP-GNN](https://arxiv.org/abs/2009.01411)
21 | - Equivariant GNNs using spherical tensors: [Tensor Field Network](https://arxiv.org/abs/1802.08219), [MACE](http://arxiv.org/abs/2206.07697)
22 | - 🔥 Your new geometric GNN architecture?
23 |
24 |
25 |
26 | ## Experiments
27 |
28 | The `/experiments` directory contains notebooks with synthetic experiments to highlight practical challenges in building powerful geometric GNNs:
29 | - `kchains.ipynb`: Distinguishing k-chains, which test a model's ability to **propagate geometric information** non-locally and demonstrate oversquashing with increased depth/longer chains.
30 | - `rotsym.ipynb`: Rotationally symmetric structures, which test a layer's ability to **identify neighbourhood orientation** and highlight the utility of higher order tensors in equivariant GNNs.
31 | - `incompleteness.ipynb`: Counterexamples from [Pozdnyakov et al.](https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.125.166001), which test a layer's ability to create **distinguishing fingerprints for local neighbourhoods** and highlight the need for higher body order of local scalarisation (distances, angles, and beyond).
32 |
33 |
34 |
35 | ## Installation
36 |
37 | ```bash
38 | # Create new conda environment
39 | conda create --prefix ./env python=3.8
40 | conda activate ./env
41 |
42 | # Install PyTorch (Check CUDA version for GPU!)
43 | #
44 | # Option 1: CPU
45 | conda install pytorch==1.12.0 -c pytorch
46 | #
47 | # Option 2: GPU, CUDA 11.3
48 | # conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
49 |
50 | # Install dependencies
51 | conda install matplotlib pandas networkx
52 | conda install jupyterlab -c conda-forge
53 | pip install e3nn==0.4.4 ipdb ase
54 |
55 | # Install PyG (Check CPU/GPU/MacOS)
56 | #
57 | # Option 1: CPU, MacOS
58 | pip install torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-1.12.0+cpu.html
59 | pip install torch-geometric
60 | #
61 | # Option 2: GPU, CUDA 11.3
62 | # pip install torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-1.12.1+cu113.html
63 | # pip install torch-geometric
64 | #
65 | # Option 3: CPU/GPU, but may not work on MacOS
66 | # conda install pyg -c pyg
67 | ```
68 |
69 |
70 | ## Directory Structure and Usage
71 |
72 | ```
73 | .
74 | ├── README.md
75 | |
76 | ├── geometric_gnn_101.ipynb # A gentle introduction to Geometric GNNs
77 | |
78 | ├── experiments # Synthetic experiments
79 | | |
80 | │ ├── kchains.ipynb # Experiment on k-chains
81 | │ ├── rotsym.ipynb # Experiment on rotationally symmetric structures
82 | │ ├── incompleteness.ipynb # Experiment on counterexamples from Pozdnyakov et al.
83 | | └── utils # Helper functions for training, plotting, etc.
84 | |
85 | └── models # Geometric GNN models library
86 | |
87 | ├── schnet.py # SchNet model
88 | ├── dimenet.py # DimeNet model
89 | ├── spherenet.py # SphereNet model
90 | ├── egnn.py # E(n) Equivariant GNN model
91 | ├── gvpgnn.py # GVP-GNN model
92 | ├── tfn.py # Tensor Field Network model
93 | ├── mace.py # MACE model
94 | ├── layers # Layers for each model
95 | └── mace_modules # Modules and layers for MACE
96 | ```
97 |
98 |
99 | ## Contact
100 |
101 | Authors: Chaitanya K. Joshi (chaitanya.joshi@cl.cam.ac.uk), Simon V. Mathis (simon.mathis@cl.cam.ac.uk).
102 | We welcome your questions and feedback via email or GitHub Issues.
103 |
104 |
105 | ## Citation
106 |
107 | ```
108 | @inproceedings{joshi2023expressive,
109 | title={On the Expressive Power of Geometric Graph Neural Networks},
110 | author={Joshi, Chaitanya K. and Bodnar, Cristian and Mathis, Simon V. and Cohen, Taco and Liò, Pietro},
111 | booktitle={International Conference on Machine Learning},
112 | year={2023},
113 | }
114 | ```
--------------------------------------------------------------------------------
/models/layers/egnn_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import Linear, ReLU, SiLU, Sequential
3 | from torch_geometric.nn import MessagePassing, global_add_pool, global_mean_pool
4 | from torch_scatter import scatter
5 |
6 |
class EGNNLayer(MessagePassing):
    r"""E(n) Equivariant GNN Layer

    Paper: E(n) Equivariant Graph Neural Networks, Satorras et al.

    Returns updated node features `h` (MLP over `h` and aggregated messages) and
    updated coordinates `pos` (`pos` plus the mean of message-scaled displacements).
    """
    def __init__(self, emb_dim, activation="relu", norm="layer", aggr="add"):
        r"""
        Args:
            emb_dim: (int) - hidden dimension `d`
            activation: (str) - non-linearity within MLPs (swish/relu)
            norm: (str) - normalisation layer (layer/batch)
            aggr: (str) - aggregation function `\oplus` (sum/mean/max)
        """
        # Set the aggregation function
        super().__init__(aggr=aggr)

        self.emb_dim = emb_dim
        self.activation = {"swish": SiLU(), "relu": ReLU()}[activation]
        self.norm = {"layer": torch.nn.LayerNorm, "batch": torch.nn.BatchNorm1d}[norm]

        # MLP `\psi_h` for computing messages `m_ij`
        self.mlp_msg = Sequential(
            Linear(2 * emb_dim + 1, emb_dim),
            self.norm(emb_dim),
            self.activation,
            Linear(emb_dim, emb_dim),
            self.norm(emb_dim),
            self.activation,
        )
        # MLP `\psi_x` for computing messages `\overrightarrow{m}_ij`
        self.mlp_pos = Sequential(
            Linear(emb_dim, emb_dim), self.norm(emb_dim), self.activation, Linear(emb_dim, 1)
        )
        # MLP `\phi` for computing updated node features `h_i^{l+1}`
        self.mlp_upd = Sequential(
            Linear(2 * emb_dim, emb_dim),
            self.norm(emb_dim),
            self.activation,
            Linear(emb_dim, emb_dim),
            self.norm(emb_dim),
            self.activation,
        )

    def forward(self, h, pos, edge_index):
        """
        Args:
            h: (n, d) - initial node features
            pos: (n, 3) - initial node coordinates
            edge_index: (2, e) - pairs of edges (i, j)
        Returns:
            out: [(n, d), (n, 3)] - updated node features and node coordinates
        """
        out = self.propagate(edge_index, h=h, pos=pos)
        return out

    def message(self, h_i, h_j, pos_i, pos_j):
        # Compute messages from both endpoint features and the edge length
        pos_diff = pos_i - pos_j
        dists = torch.norm(pos_diff, dim=-1).unsqueeze(1)
        msg = torch.cat([h_i, h_j, dists], dim=-1)
        msg = self.mlp_msg(msg)
        # Scale magnitude of displacement vector
        pos_diff = pos_diff * self.mlp_pos(msg)
        # NOTE: some papers divide pos_diff by (dists + 1) to stabilise model.
        # NOTE: lucidrains clamps pos_diff between some [-n, +n], also for stability.
        return msg, pos_diff

    def aggregate(self, inputs, index, dim_size=None):
        msgs, pos_diffs = inputs
        # `dim_size` (the number of target nodes, supplied by `propagate`) keeps the
        # output at one row per node even when the highest-indexed nodes receive no
        # messages; otherwise scatter would size the output as index.max() + 1.
        # Aggregate messages
        msg_aggr = scatter(msgs, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr)
        # Aggregate displacement vectors
        pos_aggr = scatter(pos_diffs, index, dim=self.node_dim, dim_size=dim_size, reduce="mean")
        return msg_aggr, pos_aggr

    def update(self, aggr_out, h, pos):
        msg_aggr, pos_aggr = aggr_out
        upd_out = self.mlp_upd(torch.cat([h, msg_aggr], dim=-1))
        upd_pos = pos + pos_aggr
        return upd_out, upd_pos

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})"
90 |
91 |
class MPNNLayer(MessagePassing):
    """Vanilla message passing layer over node features only (no coordinates)."""
    def __init__(self, emb_dim, activation="relu", norm="layer", aggr="add"):
        r"""Vanilla Message Passing GNN layer

        Args:
            emb_dim: (int) - hidden dimension `d`
            activation: (str) - non-linearity within MLPs (swish/relu)
            norm: (str) - normalisation layer (layer/batch)
            aggr: (str) - aggregation function `\oplus` (sum/mean/max)
        """
        # Set the aggregation function
        super().__init__(aggr=aggr)

        self.emb_dim = emb_dim
        self.activation = {"swish": SiLU(), "relu": ReLU()}[activation]
        self.norm = {"layer": torch.nn.LayerNorm, "batch": torch.nn.BatchNorm1d}[norm]

        # MLP `\psi_h` for computing messages `m_ij`
        self.mlp_msg = Sequential(
            Linear(2 * emb_dim, emb_dim),
            self.norm(emb_dim),
            self.activation,
            Linear(emb_dim, emb_dim),
            self.norm(emb_dim),
            self.activation,
        )
        # MLP `\phi` for computing updated node features `h_i^{l+1}`
        self.mlp_upd = Sequential(
            Linear(2 * emb_dim, emb_dim),
            self.norm(emb_dim),
            self.activation,
            Linear(emb_dim, emb_dim),
            self.norm(emb_dim),
            self.activation,
        )

    def forward(self, h, edge_index):
        """
        Args:
            h: (n, d) - initial node features
            edge_index: (2, e) - pairs of edges (i, j)
        Returns:
            out: (n, d) - updated node features
        """
        out = self.propagate(edge_index, h=h)
        return out

    def message(self, h_i, h_j):
        # Compute messages from both endpoint features
        msg = torch.cat([h_i, h_j], dim=-1)
        msg = self.mlp_msg(msg)
        return msg

    def aggregate(self, inputs, index, dim_size=None):
        # Aggregate messages; `dim_size` (number of target nodes, supplied by
        # `propagate`) keeps one row per node even if trailing nodes get no messages.
        msg_aggr = scatter(inputs, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr)
        return msg_aggr

    def update(self, aggr_out, h):
        upd_out = self.mlp_upd(torch.cat([h, aggr_out], dim=-1))
        return upd_out

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})"
156 |
--------------------------------------------------------------------------------
/models/tfn.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | from torch.nn import functional as F
5 | from torch_geometric.nn import global_add_pool, global_mean_pool
6 | import e3nn
7 |
8 | from models.mace_modules.blocks import RadialEmbeddingBlock
9 | from models.layers.tfn_layer import TensorProductConvLayer
10 |
11 |
12 | class TFNModel(torch.nn.Module):
13 | """
14 | Tensor Field Network model from "Tensor Field Networks".
15 | """
16 | def __init__(
17 | self,
18 | r_max: float = 10.0,
19 | num_bessel: int = 8,
20 | num_polynomial_cutoff: int = 5,
21 | max_ell: int = 2,
22 | num_layers: int = 5,
23 | emb_dim: int = 64,
24 | hidden_irreps: Optional[e3nn.o3.Irreps] = None,
25 | mlp_dim: int = 256,
26 | in_dim: int = 1,
27 | out_dim: int = 1,
28 | aggr: str = "sum",
29 | pool: str = "sum",
30 | gate: bool = True,
31 | batch_norm: bool = False,
32 | residual: bool = True,
33 | equivariant_pred: bool = False
34 | ):
35 | """
36 | Parameters:
37 | - r_max (float): Maximum distance for Bessel basis functions (default: 10.0)
38 | - num_bessel (int): Number of Bessel basis functions (default: 8)
39 | - num_polynomial_cutoff (int): Number of polynomial cutoff basis functions (default: 5)
40 | - max_ell (int): Maximum degree of spherical harmonics basis functions (default: 2)
41 | - num_layers (int): Number of layers in the model (default: 5)
42 | - emb_dim (int): Scalar feature embedding dimension (default: 64)
43 | - hidden_irreps (Optional[e3nn.o3.Irreps]): Hidden irreps (default: None)
44 | - mlp_dim (int): Dimension of MLP for computing tensor product weights (default: 256)
45 | - in_dim (int): Input dimension of the model (default: 1)
46 | - out_dim (int): Output dimension of the model (default: 1)
47 | - aggr (str): Aggregation method to be used (default: "sum")
48 | - pool (str): Global pooling method to be used (default: "sum")
49 | - gate (bool): Whether to use gated equivariant non-linearity (default: True)
50 | - batch_norm (bool): Whether to use batch normalization (default: False)
51 | - residual (bool): Whether to use residual connections (default: True)
52 | - equivariant_pred (bool): Whether it is an equivariant prediction task (default: False)
53 |
54 | Note:
55 | - If `hidden_irreps` is None, the irreps for the intermediate features are computed
56 | using `emb_dim` and `max_ell`.
57 | - The `equivariant_pred` parameter determines whether it is an equivariant prediction task.
58 | If set to True, equivariant prediction will be performed.
59 | - At present, only one of `gate` and `batch_norm` can be True.
60 | """
61 | super().__init__()
62 |
63 | self.r_max = r_max
64 | self.max_ell = max_ell
65 | self.num_layers = num_layers
66 | self.emb_dim = emb_dim
67 | self.mlp_dim = mlp_dim
68 | self.residual = residual
69 | self.batch_norm = batch_norm
70 | self.gate = gate
71 | self.hidden_irreps = hidden_irreps
72 | self.equivariant_pred = equivariant_pred
73 |
74 | # Edge embedding
75 | self.radial_embedding = RadialEmbeddingBlock(
76 | r_max=r_max,
77 | num_bessel=num_bessel,
78 | num_polynomial_cutoff=num_polynomial_cutoff,
79 | )
80 | sh_irreps = e3nn.o3.Irreps.spherical_harmonics(max_ell)
81 | self.spherical_harmonics = e3nn.o3.SphericalHarmonics(
82 | sh_irreps, normalize=True, normalization="component"
83 | )
84 |
85 | # Embedding lookup for initial node features
86 | self.emb_in = torch.nn.Embedding(in_dim, emb_dim)
87 |
88 | # Set hidden irreps if none are provided
89 | if hidden_irreps is None:
90 | hidden_irreps = (sh_irreps * emb_dim).sort()[0].simplify()
91 | # Note: This defaults to O(3) equivariant layers
92 | # It is possible to use SO(3) equivariance by passing the appropriate irreps
93 |
94 | self.convs = torch.nn.ModuleList()
95 | # First conv layer: scalar only -> tensor
96 | self.convs.append(
97 | TensorProductConvLayer(
98 | in_irreps=e3nn.o3.Irreps(f'{emb_dim}x0e'),
99 | out_irreps=hidden_irreps,
100 | sh_irreps=sh_irreps,
101 | edge_feats_dim=self.radial_embedding.out_dim,
102 | mlp_dim=mlp_dim,
103 | aggr=aggr,
104 | batch_norm=batch_norm,
105 | gate=gate,
106 | )
107 | )
108 | # Intermediate conv layers: tensor -> tensor
109 | for _ in range(num_layers - 1):
110 | conv = TensorProductConvLayer(
111 | in_irreps=hidden_irreps,
112 | out_irreps=hidden_irreps,
113 | sh_irreps=sh_irreps,
114 | edge_feats_dim=self.radial_embedding.out_dim,
115 | mlp_dim=mlp_dim,
116 | aggr=aggr,
117 | batch_norm=batch_norm,
118 | gate=gate,
119 | )
120 | self.convs.append(conv)
121 |
122 | # Global pooling/readout function
123 | self.pool = {"mean": global_mean_pool, "sum": global_add_pool}[pool]
124 |
125 | if self.equivariant_pred:
126 | # Linear predictor for equivariant tasks using geometric features
127 | self.pred = torch.nn.Linear(hidden_irreps.dim, out_dim)
128 | else:
129 | # MLP predictor for invariant tasks using only scalar features
130 | self.pred = torch.nn.Sequential(
131 | torch.nn.Linear(emb_dim, emb_dim),
132 | torch.nn.ReLU(),
133 | torch.nn.Linear(emb_dim, out_dim)
134 | )
135 |
def forward(self, batch):
    """Run the TFN model on a batch of graphs.

    Parameters:
    - batch: graph batch exposing `atoms` (node type indices), `pos` (node
      coordinates), `edge_index` (source row, target row) and `batch`
      (node-to-graph assignment) attributes.

    Returns:
    - Output of `self.pred` applied to the pooled graph features
      ((batch_size, out_dim) per the model's annotations).
    """
    edge_index = batch.edge_index

    # Initial node features from the atom-type embedding: (n,) -> (n, d)
    node_feats = self.emb_in(batch.atoms)

    # Relative position of every edge and its length
    rel_pos = batch.pos[edge_index[0]] - batch.pos[edge_index[1]]  # [n_edges, 3]
    dist = torch.linalg.norm(rel_pos, dim=-1, keepdim=True)  # [n_edges, 1]

    # Angular (spherical harmonics) and radial edge embeddings
    edge_attrs = self.spherical_harmonics(rel_pos)
    edge_radial = self.radial_embedding(dist)

    for layer in self.convs:
        msg = layer(node_feats, edge_index, edge_attrs, edge_radial)
        if self.residual:
            # Right-pad the previous features with zeros so they match the
            # layer output width (the first layer widens scalars to hidden irreps)
            width_gap = msg.shape[-1] - node_feats.shape[-1]
            node_feats = msg + F.pad(node_feats, (0, width_gap))
        else:
            node_feats = msg

    # Graph-level readout: (n, d) -> (batch_size, d)
    graph_feats = self.pool(node_feats, batch.batch)

    if self.equivariant_pred:
        return self.pred(graph_feats)
    # Invariant tasks only read the leading `emb_dim` (scalar) channels
    return self.pred(graph_feats[:, :self.emb_dim])
161 |
--------------------------------------------------------------------------------
/experiments/rotsym.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "metadata": {},
7 | "source": [
8 | "# Identifying neighbourhood orientation: rotationally symmetric structures\n",
9 | "\n",
10 | "*Background:*\n",
11 | "Rotationally equivariant geometric GNNs aggregate local geometric information via summing together the neighbourhood geometric features, which are either **cartesian vectors** or **higher order spherical tensors**. \n",
12 | "The ideal geometric GNN would injectively aggregate local geometric information to perfectly identify neighbourhood identities, orientations, etc.\n",
13 | "In practice, the choice of basis (cartesian vs. spherical) comes with tradeoffs between tractability and empirical performance.\n",
14 | "\n",
15 | "*Experiment:*\n",
16 | "In this notebook, we study how rotational symmetries interact with tensor order in equivariant GNNs. \n",
17 | "We evaluate equivariant layers on their ability to distinguish the orientation of **structures with rotational symmetry**. \n",
18 | "An [$L$-fold symmetric structure](https://en.wikipedia.org/wiki/Rotational_symmetry) does not change when rotated by an angle $\\frac{2\\pi}{L}$ around a point (in 2D) or axis (3D).\n",
19 | "We consider two *distinct* rotated versions of each $L$-fold symmetric structure and train single layer equivariant GNNs to classify the two orientations using the updated geometric features.\n",
20 | "\n",
21 | "\n",
22 | "\n",
23 | "*Result:*\n",
24 | "- **We find that layers using order $L$ tensors are unable to identify the orientation of structures with rotational symmetry higher than $L$-fold.** This observation may be attributed to **spherical harmonics**, which serve as an orthonormal basis for spherical tensor features and exhibit rotational symmetry themselves.\n",
25 | "- Layers such as E-GNN and GVP-GNN using **Cartesian vectors** (corresponding to tensor order 1) are popular because working with higher-order tensors can be computationally intractable for many applications. However, E-GNN and GVP-GNN are particularly poor at discriminating the orientation of rotationally symmetric structures. "
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": null,
31 | "metadata": {},
32 | "outputs": [],
33 | "source": [
34 | "%load_ext autoreload\n",
35 | "%autoreload 2\n",
36 | "\n",
37 | "import sys\n",
38 | "sys.path.append('../')\n",
39 | "\n",
40 | "import random\n",
41 | "import math\n",
42 | "import torch\n",
43 | "import torch_geometric\n",
44 | "from torch_geometric.data import Data\n",
45 | "from torch_geometric.loader import DataLoader\n",
46 | "from torch_geometric.utils import to_undirected\n",
47 | "import e3nn\n",
48 | "from functools import partial\n",
49 | "\n",
50 | "print(\"PyTorch version {}\".format(torch.__version__))\n",
51 | "print(\"PyG version {}\".format(torch_geometric.__version__))\n",
52 | "print(\"e3nn version {}\".format(e3nn.__version__))\n",
53 | "\n",
54 | "from experiments.utils.plot_utils import plot_2d\n",
55 | "from experiments.utils.train_utils import run_experiment\n",
56 | "from models import SchNetModel, DimeNetPPModel, SphereNetModel, EGNNModel, GVPGNNModel, TFNModel, MACEModel\n",
57 | "\n",
58 | "# Set the device\n",
59 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
60 | "print(f\"Using device: {device}\")"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": null,
66 | "metadata": {},
67 | "outputs": [],
68 | "source": [
69 | "def create_rotsym_envs(fold=3):\n",
70 | " dataset = []\n",
71 | "\n",
72 | " # Environment 0\n",
73 | " atoms = torch.LongTensor([ 0 ] + [ 0 ] * fold)\n",
74 | " edge_index = torch.LongTensor( [ [0] * fold, [i for i in range(1, fold+1)] ] )\n",
75 | " x = torch.Tensor([1,0,0])\n",
76 | " pos = [\n",
77 | " torch.Tensor([0,0,0]), # origin\n",
78 | " x, # first spoke \n",
79 | " ]\n",
80 | " for count in range(1, fold):\n",
81 | " R = e3nn.o3.matrix_z(torch.Tensor([2*math.pi/fold * count])).squeeze(0)\n",
82 | " pos.append(x @ R.T)\n",
83 | " pos = torch.stack(pos)\n",
84 | " y = torch.LongTensor([0]) # Label 0\n",
85 | " data1 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n",
86 | " data1.edge_index = to_undirected(data1.edge_index)\n",
87 | " dataset.append(data1)\n",
88 | " \n",
89 | " # Environment 1\n",
90 | " q = 2*math.pi/(fold + random.randint(1, fold))\n",
91 | " assert q < 2*math.pi/fold\n",
92 | " Q = e3nn.o3.matrix_z(torch.Tensor([q])).squeeze(0)\n",
93 | " pos = pos @ Q.T\n",
94 | " y = torch.LongTensor([1]) # Label 1\n",
95 | " data2 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n",
96 | " data2.edge_index = to_undirected(data2.edge_index)\n",
97 | " dataset.append(data2)\n",
98 | " \n",
99 | " return dataset"
100 | ]
101 | },
102 | {
103 | "cell_type": "code",
104 | "execution_count": null,
105 | "metadata": {},
106 | "outputs": [],
107 | "source": [
108 | "fold = 5\n",
109 | "\n",
110 | "# Create dataset\n",
111 | "dataset = create_rotsym_envs(fold)\n",
112 | "for data in dataset:\n",
113 | " plot_2d(data, lim=1)\n",
114 | "\n",
115 | "# Create dataloaders\n",
116 | "dataloader = DataLoader(dataset, batch_size=1, shuffle=True)\n",
117 | "val_loader = DataLoader(dataset, batch_size=1, shuffle=False)\n",
118 | "test_loader = DataLoader(dataset, batch_size=1, shuffle=False)"
119 | ]
120 | },
121 | {
122 | "cell_type": "code",
123 | "execution_count": null,
124 | "metadata": {},
125 | "outputs": [],
126 | "source": [
127 | "# Set parameters\n",
128 | "model_name = \"tfn\"\n",
129 | "correlation = 2\n",
130 | "max_ell = 5\n",
131 | "\n",
132 | "model = {\n",
133 | " \"schnet\": SchNetModel,\n",
134 | " \"dimenet\": DimeNetPPModel,\n",
135 | " \"spherenet\": SphereNetModel,\n",
136 | " \"egnn\": partial(EGNNModel, equivariant_pred=True),\n",
137 | " \"gvp\": partial(GVPGNNModel, equivariant_pred=True),\n",
138 | " \"tfn\": partial(TFNModel, max_ell=max_ell, equivariant_pred=True),\n",
139 | " \"mace\": partial(MACEModel, max_ell=max_ell, correlation=correlation, equivariant_pred=True),\n",
140 | "}[model_name](num_layers=1, in_dim=1, out_dim=2)\n",
141 | "\n",
142 | "best_val_acc, test_acc, train_time = run_experiment(\n",
143 | " model, \n",
144 | " dataloader,\n",
145 | " val_loader, \n",
146 | " test_loader,\n",
147 | " n_epochs=100,\n",
148 | " n_times=10,\n",
149 | " device=device,\n",
150 | " verbose=False\n",
151 | ")"
152 | ]
153 | },
154 | {
155 | "cell_type": "code",
156 | "execution_count": null,
157 | "metadata": {},
158 | "outputs": [],
159 | "source": []
160 | }
161 | ],
162 | "metadata": {
163 | "kernelspec": {
164 | "display_name": "Python 3",
165 | "language": "python",
166 | "name": "python3"
167 | },
168 | "language_info": {
169 | "codemirror_mode": {
170 | "name": "ipython",
171 | "version": 3
172 | },
173 | "file_extension": ".py",
174 | "mimetype": "text/x-python",
175 | "name": "python",
176 | "nbconvert_exporter": "python",
177 | "pygments_lexer": "ipython3",
178 | "version": "3.8.16"
179 | },
180 | "orig_nbformat": 4,
181 | "vscode": {
182 | "interpreter": {
183 | "hash": "94aa676993820a604ac86f7af94f5432e989a749d5dd43e18f9507de2e8c2897"
184 | }
185 | }
186 | },
187 | "nbformat": 4,
188 | "nbformat_minor": 2
189 | }
190 |
--------------------------------------------------------------------------------
/models/mace_modules/symmetric_contraction.py:
--------------------------------------------------------------------------------
1 | ###########################################################################################
2 | # Implementation of the symmetric contraction algorithm presented in the MACE paper
3 | # (Batatia et al., MACE: Higher Order Equivariant Message Passing Neural Networks for Fast and Accurate Force Fields, Eq. 10 and 11)
4 | # Authors: Ilyes Batatia
5 | # This program is distributed under the MIT License (see MIT.md)
6 | ###########################################################################################
7 |
8 | from typing import Dict, Optional, Union
9 |
10 | import torch
11 | import torch.fx
12 | from e3nn import o3
13 | from e3nn.util.codegen import CodeGenMixin
14 | from e3nn.util.jit import compile_mode
15 | from opt_einsum import contract
16 |
17 | from .cg import U_matrix_real
18 |
19 |
@compile_mode("script")
class SymmetricContraction(CodeGenMixin, torch.nn.Module):
    """Symmetric tensor contraction (MACE, Eq. 10 and 11) producing `irreps_out`.

    One `Contraction` submodule is built per entry of `irreps_out`; `forward` runs
    each of them on the same inputs and concatenates the results along the last
    dimension, in the order of `irreps_out`.

    Parameters:
    - irreps_in: irreps of the input features.
    - irreps_out: irreps of the output features.
    - correlation: maximum correlation order. Either a single int shared by every
      output irrep, or a dict with one order per output irrep. Dict keys may be the
      `irreps_out` entry itself, its string form (e.g. "64x0e") or the string of the
      bare irrep (e.g. "0e").
    - irrep_normalization / path_normalization: validated against the allowed values
      but not used anywhere else in this implementation.
    - internal_weights: forwarded to every `Contraction` (None -> True).
    - shared_weights: forwarded to every `Contraction` as `weights`.
    - element_dependent: forwarded to every `Contraction` (None -> True).
    - num_elements: forwarded to every `Contraction`.

    Raises:
    - KeyError: if `correlation` is a dict without an entry for some output irrep.
    """

    def __init__(
        self,
        irreps_in: o3.Irreps,
        irreps_out: o3.Irreps,
        correlation: Union[int, Dict[str, int]],
        irrep_normalization: str = "component",
        path_normalization: str = "element",
        internal_weights: Optional[bool] = None,
        shared_weights: Optional[torch.Tensor] = None,
        element_dependent: Optional[bool] = None,
        num_elements: Optional[int] = None,
    ) -> None:
        super().__init__()

        if irrep_normalization is None:
            irrep_normalization = "component"

        if path_normalization is None:
            path_normalization = "element"

        # Only validated; neither setting is used further in this implementation
        assert irrep_normalization in ["component", "norm", "none"]
        assert path_normalization in ["element", "path", "none"]

        self.irreps_in = o3.Irreps(irreps_in)
        self.irreps_out = o3.Irreps(irreps_out)

        del irreps_in, irreps_out

        # Resolve one correlation order per output irrep. A dict is looked up per
        # irrep rather than being broadcast as if it were a single order.
        if isinstance(correlation, dict):
            per_irrep = {}
            for irrep_out in self.irreps_out:
                for key in (irrep_out, str(irrep_out), str(irrep_out.ir)):
                    if key in correlation:
                        per_irrep[irrep_out] = correlation[key]
                        break
                else:
                    raise KeyError(
                        f"No correlation order given for output irrep {irrep_out}"
                    )
            correlation = per_irrep
        else:
            correlation = {irrep_out: correlation for irrep_out in self.irreps_out}

        # NOTE(review): with internal_weights=True this requires a truthy
        # shared_weights, and bool() of a multi-element tensor raises -- confirm
        # the intended contract before relying on internal_weights=True.
        assert shared_weights or not internal_weights

        if internal_weights is None:
            internal_weights = True

        if element_dependent is None:
            element_dependent = True

        self.internal_weights = internal_weights
        self.shared_weights = shared_weights

        del internal_weights, shared_weights

        # One contraction per output irrep, keyed by its string form (e.g. "64x0e")
        self.contractions = torch.nn.ModuleDict()
        for irrep_out in self.irreps_out:
            self.contractions[str(irrep_out)] = Contraction(
                irreps_in=self.irreps_in,
                irrep_out=o3.Irreps(str(irrep_out.ir)),
                correlation=correlation[irrep_out],
                internal_weights=self.internal_weights,
                element_dependent=element_dependent,
                num_elements=num_elements,
                weights=self.shared_weights,
            )

    def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
        """Apply each per-irrep contraction to (x, y) and concatenate on the last dim."""
        outs = []
        for irrep in self.irreps_out:
            outs.append(self.contractions[str(irrep)](x, y))
        return torch.cat(outs, dim=-1)
86 |
87 |
class Contraction(torch.nn.Module):
    """Symmetric contraction of the input features into a single output irrep.

    For every correlation order nu = 1..correlation a coupling tensor is obtained from
    `U_matrix_real` and registered as buffer `U_matrix_{nu}`. One weight tensor per
    order lives in `self.weights` under the key `str(nu)`. `forward` contracts the
    highest order first and then folds in each lower order with `opt_einsum.contract`.

    Parameters:
    - irreps_in: irreps of the input features; the channel count is the
      multiplicity of 0e in `irreps_in`.
    - irrep_out: output irrep as an `o3.Irreps` (SymmetricContraction passes one irrep).
    - correlation: maximum correlation order.
    - internal_weights: if True, trainable weights are created here; otherwise
      `weights` is registered as a buffer.
    - element_dependent: if True, weights carry a leading element axis that is
      contracted against the `y` argument of `forward`.
    - num_elements: size of that element axis (element-dependent internal weights only).
    - weights: externally provided weights, used when `internal_weights` is False.
    """

    def __init__(
        self,
        irreps_in: o3.Irreps,
        irrep_out: o3.Irreps,
        correlation: int,
        internal_weights: bool = True,
        element_dependent: bool = True,
        num_elements: Optional[int] = None,
        weights: Optional[torch.Tensor] = None,
    ) -> None:
        super().__init__()

        self.element_dependent = element_dependent
        # Number of channels = multiplicity of the 0e irrep in the input
        self.num_features = irreps_in.count((0, 1))
        # Input irreps with multiplicities dropped (one entry per irrep)
        self.coupling_irreps = o3.Irreps([irrep.ir for irrep in irreps_in])
        self.correlation = correlation
        dtype = torch.get_default_dtype()
        for nu in range(1, correlation + 1):
            # Only the last element returned by U_matrix_real is kept for order nu
            U_matrix = U_matrix_real(
                irreps_in=self.coupling_irreps,
                irreps_out=irrep_out,
                correlation=nu,
                dtype=dtype,
            )[-1]
            self.register_buffer(f"U_matrix_{nu}", U_matrix)

        if element_dependent:
            # Tensor contraction equations; weights carry an extra axis `e` shared with
            # `y` ("be") -- presumably e = element, b = node, c = channel (confirm)
            self.equation_main = "...ik,ekc,bci,be -> bc..."
            self.equation_weighting = "...k,ekc,be->bc..."
            self.equation_contract = "bc...i,bci->bc..."
            if internal_weights:
                # Create weight for product basis
                self.weights = torch.nn.ParameterDict({})
                for i in range(1, correlation + 1):
                    num_params = self.U_tensors(i).size()[-1]
                    w = torch.nn.Parameter(
                        torch.randn(num_elements, num_params, self.num_features)
                        / num_params
                    )
                    self.weights[str(i)] = w
            else:
                # NOTE(review): forward indexes self.weights[str(order)], which a plain
                # tensor buffer does not support -- confirm the external-weights path
                self.register_buffer("weights", weights)

        else:
            # Tensor contraction equations (no element axis; `y` is not used)
            self.equation_main = "...ik,kc,bci -> bc..."
            self.equation_weighting = "...k,kc->c..."
            self.equation_contract = "bc...i,bci->bc..."
            if internal_weights:
                # Create weight for product basis
                self.weights = torch.nn.ParameterDict({})
                for i in range(1, correlation + 1):
                    num_params = self.U_tensors(i).size()[-1]
                    w = torch.nn.Parameter(
                        torch.randn(num_params, self.num_features) / num_params
                    )
                    self.weights[str(i)] = w
            else:
                # NOTE(review): see the external-weights note in the branch above
                self.register_buffer("weights", weights)

    def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
        """Contract `x` (and `y` when element dependent) up to `self.correlation`.

        Returns a tensor whose dimensions after the leading one are flattened together.
        """
        # Highest order first; each lower order adds its weighted coupling tensor
        # and then contracts one more copy of x.
        if self.element_dependent:
            out = contract(
                self.equation_main,
                self.U_tensors(self.correlation),
                self.weights[str(self.correlation)],
                x,
                y,
            )  # TODO: use optimize library and cuTENSOR # pylint: disable=fixme
            for corr in range(self.correlation - 1, 0, -1):
                c_tensor = contract(
                    self.equation_weighting,
                    self.U_tensors(corr),
                    self.weights[str(corr)],
                    y,
                )
                c_tensor = c_tensor + out
                out = contract(self.equation_contract, c_tensor, x)

        else:
            out = contract(
                self.equation_main,
                self.U_tensors(self.correlation),
                self.weights[str(self.correlation)],
                x,
            )  # TODO: use optimize library and cuTENSOR # pylint: disable=fixme
            for corr in range(self.correlation - 1, 0, -1):
                c_tensor = contract(
                    self.equation_weighting,
                    self.U_tensors(corr),
                    self.weights[str(corr)],
                )
                c_tensor = c_tensor + out
                out = contract(self.equation_contract, c_tensor, x)
        # (b, c, ...) -> (b, c * prod(...)). flatten avoids allocating a size tensor on
        # every call and also works when the contraction result is non-contiguous.
        return out.flatten(start_dim=1)

    def U_tensors(self, nu):
        """Return the registered coupling tensor buffer for correlation order `nu`."""
        return self._buffers[f"U_matrix_{nu}"]
189 |
--------------------------------------------------------------------------------
/experiments/kchains.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "metadata": {},
7 | "source": [
8 | "# Propagating geometric information: $k$-chains\n",
9 | "\n",
10 | "*Background:*\n",
11 | "In geometric GNNs, **geometric information**, such as the relative orientation of local neighbourhoods, is propagated via summing features from multiple layers in fixed dimensional spaces. \n",
12 | "The ideal architecture can be run for any number of layers to perfectly propagate geometric information without loss of information.\n",
13 | "In practice, stacking geometric GNN layers may lead to distortion or **loss of information from distant nodes**.\n",
14 | "\n",
15 | "*Experiment:*\n",
16 | "To study the practical implications of depth in propagating geometric information beyond local neighbourhoods, we consider **$k$-chain geometric graphs** which generalise the examples from [Schütt et al., 2021](https://arxiv.org/abs/2102.03150). \n",
17 | "Each $k$-chain consists of $k+2$ nodes: $k$ nodes arranged in a line plus $2$ end points, and the two $k$-chains in a pair differ only in the orientation of these end points.\n",
18 | "Thus, $k$-chain graphs are $(\\lfloor \\frac{k}{2} \\rfloor + 1)$-hop distinguishable, and $(\\lfloor \\frac{k}{2} \\rfloor + 1)$ geometric GNN iterations should be theoretically sufficient to distinguish them.\n",
19 | "In this notebook, we train equivariant and invariant geometric GNNs with an increasing number of layers to distinguish $k$-chains.\n",
20 | "\n",
21 | "\n",
22 | "\n",
23 | "*Results:*\n",
24 | "- Despite the supposed simplicity of the task, especially for small chain lengths, we find that popular equivariant GNNs such as E-GNN and TFN may require **more iterations** than theoretically sufficient.\n",
25 | "- Notably, as the length of the chain gets larger than $k=4$, all equivariant GNNs tended to lose performance and required more than $(\\lfloor \\frac{k}{2} \\rfloor + 1)$ iterations to solve the task.\n",
26 | "- Invariant GNNs are **unable** to distinguish $k$-chains.\n",
27 | "\n",
28 | "These results point to preliminary evidence of the **oversquashing** phenomenon when geometric information is propagated across multiple layers using fixed dimensional feature spaces.\n",
29 | "These issues are most evident for E-GNN, which uses a single vector feature to aggregate and propagate geometric information."
30 | ]
31 | },
32 | {
33 | "cell_type": "code",
34 | "execution_count": null,
35 | "metadata": {},
36 | "outputs": [],
37 | "source": [
38 | "%load_ext autoreload\n",
39 | "%autoreload 2\n",
40 | "\n",
41 | "import sys\n",
42 | "sys.path.append('../')\n",
43 | "\n",
44 | "import torch\n",
45 | "import torch_geometric\n",
46 | "from torch_geometric.data import Data\n",
47 | "from torch_geometric.loader import DataLoader\n",
48 | "from torch_geometric.utils import to_undirected\n",
49 | "import e3nn\n",
50 | "from functools import partial\n",
51 | "\n",
52 | "print(\"PyTorch version {}\".format(torch.__version__))\n",
53 | "print(\"PyG version {}\".format(torch_geometric.__version__))\n",
54 | "print(\"e3nn version {}\".format(e3nn.__version__))\n",
55 | "\n",
56 | "from experiments.utils.plot_utils import plot_3d\n",
57 | "from experiments.utils.train_utils import run_experiment\n",
58 | "from models import SchNetModel, DimeNetPPModel, SphereNetModel, EGNNModel, GVPGNNModel, TFNModel, MACEModel\n",
59 | "\n",
60 | "# Set the device\n",
61 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
62 | "print(f\"Using device: {device}\")"
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": null,
68 | "metadata": {},
69 | "outputs": [],
70 | "source": [
71 | "def create_kchains(k):\n",
72 | " assert k >= 2\n",
73 | " \n",
74 | " dataset = []\n",
75 | "\n",
76 | " # Graph 0\n",
77 | " atoms = torch.LongTensor( [0] + [0] + [0]*(k-1) + [0] )\n",
78 | " edge_index = torch.LongTensor( [ [i for i in range((k+2) - 1)], [i for i in range(1, k+2)] ] )\n",
79 | " pos = torch.FloatTensor(\n",
80 | " [[-4, -3, 0]] + \n",
81 | " [[0, 5*i , 0] for i in range(k)] + \n",
82 | " [[4, 5*(k-1) + 3, 0]]\n",
83 | " )\n",
84 | " center_of_mass = torch.mean(pos, dim=0)\n",
85 | " pos = pos - center_of_mass\n",
86 | " y = torch.LongTensor([0]) # Label 0\n",
87 | " data1 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n",
88 | " data1.edge_index = to_undirected(data1.edge_index)\n",
89 | " dataset.append(data1)\n",
90 | " \n",
91 | " # Graph 1\n",
92 | " atoms = torch.LongTensor( [0] + [0] + [0]*(k-1) + [0] )\n",
93 | " edge_index = torch.LongTensor( [ [i for i in range((k+2) - 1)], [i for i in range(1, k+2)] ] )\n",
94 | " pos = torch.FloatTensor(\n",
95 | " [[4, -3, 0]] + \n",
96 | " [[0, 5*i , 0] for i in range(k)] + \n",
97 | " [[4, 5*(k-1) + 3, 0]]\n",
98 | " )\n",
99 | " center_of_mass = torch.mean(pos, dim=0)\n",
100 | " pos = pos - center_of_mass\n",
101 | " y = torch.LongTensor([1]) # Label 1\n",
102 | " data2 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n",
103 | " data2.edge_index = to_undirected(data2.edge_index)\n",
104 | " dataset.append(data2)\n",
105 | " \n",
106 | " return dataset"
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "execution_count": null,
112 | "metadata": {},
113 | "outputs": [],
114 | "source": [
115 | "k = 4\n",
116 | "\n",
117 | "# Create dataset\n",
118 | "dataset = create_kchains(k=k)\n",
119 | "for data in dataset:\n",
120 | " plot_3d(data, lim=5*k)\n",
121 | "\n",
122 | "# Create dataloaders\n",
123 | "dataloader = DataLoader(dataset, batch_size=1, shuffle=True)\n",
124 | "val_loader = DataLoader(dataset, batch_size=2, shuffle=False)\n",
125 | "test_loader = DataLoader(dataset, batch_size=2, shuffle=False)"
126 | ]
127 | },
128 | {
129 | "cell_type": "code",
130 | "execution_count": null,
131 | "metadata": {},
132 | "outputs": [],
133 | "source": [
134 | "# Set model\n",
135 | "model_name = \"gvp\"\n",
136 | "\n",
137 | "for num_layers in range(k // 2 , k + 3):\n",
138 | "\n",
139 | " print(f\"\\nNumber of layers: {num_layers}\")\n",
140 | " \n",
141 | " correlation = 2\n",
142 | " model = {\n",
143 | " \"schnet\": SchNetModel,\n",
144 | " \"dimenet\": DimeNetPPModel,\n",
145 | " \"spherenet\": SphereNetModel,\n",
146 | " \"egnn\": EGNNModel,\n",
147 | " \"gvp\": partial(GVPGNNModel, s_dim=32, v_dim=1),\n",
148 | " \"tfn\": TFNModel,\n",
149 | " \"mace\": partial(MACEModel, correlation=correlation),\n",
150 | " }[model_name](num_layers=num_layers, in_dim=1, out_dim=2)\n",
151 | " \n",
152 | " best_val_acc, test_acc, train_time = run_experiment(\n",
153 | " model, \n",
154 | " dataloader,\n",
155 | " val_loader, \n",
156 | " test_loader,\n",
157 | " n_epochs=100,\n",
158 | " n_times=10,\n",
159 | " device=device,\n",
160 | " verbose=False\n",
161 | " )"
162 | ]
163 | }
164 | ],
165 | "metadata": {
166 | "kernelspec": {
167 | "display_name": "Python 3",
168 | "language": "python",
169 | "name": "python3"
170 | },
171 | "language_info": {
172 | "codemirror_mode": {
173 | "name": "ipython",
174 | "version": 3
175 | },
176 | "file_extension": ".py",
177 | "mimetype": "text/x-python",
178 | "name": "python",
179 | "nbconvert_exporter": "python",
180 | "pygments_lexer": "ipython3",
181 | "version": "3.8.16"
182 | },
183 | "orig_nbformat": 4,
184 | "vscode": {
185 | "interpreter": {
186 | "hash": "94aa676993820a604ac86f7af94f5432e989a749d5dd43e18f9507de2e8c2897"
187 | }
188 | }
189 | },
190 | "nbformat": 4,
191 | "nbformat_minor": 2
192 | }
193 |
--------------------------------------------------------------------------------
/models/mace.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | from torch.nn import functional as F
5 | from torch_geometric.nn import global_add_pool, global_mean_pool
6 | import e3nn
7 |
8 | from models.mace_modules.irreps_tools import reshape_irreps
9 | from models.mace_modules.blocks import (
10 | EquivariantProductBasisBlock,
11 | RadialEmbeddingBlock,
12 | )
13 | from models.layers.tfn_layer import TensorProductConvLayer
14 |
15 |
class MACEModel(torch.nn.Module):
    """
    MACE model from "MACE: Higher Order Equivariant Message Passing Neural Networks".
    """
    def __init__(
        self,
        r_max: float = 10.0,
        num_bessel: int = 8,
        num_polynomial_cutoff: int = 5,
        max_ell: int = 2,
        correlation: int = 3,
        num_layers: int = 5,
        emb_dim: int = 64,
        hidden_irreps: Optional[e3nn.o3.Irreps] = None,
        mlp_dim: int = 256,
        in_dim: int = 1,
        out_dim: int = 1,
        aggr: str = "sum",
        pool: str = "sum",
        batch_norm: bool = True,
        residual: bool = True,
        equivariant_pred: bool = False
    ):
        """
        Parameters:
        - r_max (float): Maximum distance for Bessel basis functions (default: 10.0)
        - num_bessel (int): Number of Bessel basis functions (default: 8)
        - num_polynomial_cutoff (int): Number of polynomial cutoff basis functions (default: 5)
        - max_ell (int): Maximum degree of spherical harmonics basis functions (default: 2)
        - correlation (int): Local correlation order = body order - 1 (default: 3)
        - num_layers (int): Number of layers in the model; at least one layer is always
          built (default: 5)
        - emb_dim (int): Scalar feature embedding dimension (default: 64)
        - hidden_irreps (Optional[e3nn.o3.Irreps]): Hidden irreps, as `e3nn.o3.Irreps` or
          anything `e3nn.o3.Irreps` accepts, e.g. "64x0e + 64x1o" (default: None)
        - mlp_dim (int): Dimension of MLP for computing tensor product weights (default: 256)
        - in_dim (int): Input dimension of the model (default: 1)
        - out_dim (int): Output dimension of the model (default: 1)
        - aggr (str): Aggregation method to be used (default: "sum")
        - pool (str): Global pooling method to be used, "sum" or "mean" (default: "sum")
        - batch_norm (bool): Whether to use batch normalization (default: True)
        - residual (bool): Whether to use residual connections (default: True)
        - equivariant_pred (bool): Whether it is an equivariant prediction task (default: False)

        Note:
        - If `hidden_irreps` is None, the irreps for the intermediate features are computed
          using `emb_dim` and `max_ell`. The resolved irreps are stored in `self.hidden_irreps`.
        - For invariant prediction the readout uses the first `emb_dim` pooled channels,
          so custom `hidden_irreps` should start with `emb_dim` scalar (0e) channels.
        - The `equivariant_pred` parameter determines whether it is an equivariant prediction task.
          If set to True, equivariant prediction will be performed.
        """
        super().__init__()

        self.r_max = r_max
        self.max_ell = max_ell
        self.num_layers = num_layers
        self.emb_dim = emb_dim
        self.mlp_dim = mlp_dim
        self.residual = residual
        self.batch_norm = batch_norm
        self.equivariant_pred = equivariant_pred

        # Edge embedding
        self.radial_embedding = RadialEmbeddingBlock(
            r_max=r_max,
            num_bessel=num_bessel,
            num_polynomial_cutoff=num_polynomial_cutoff,
        )
        sh_irreps = e3nn.o3.Irreps.spherical_harmonics(max_ell)
        self.spherical_harmonics = e3nn.o3.SphericalHarmonics(
            sh_irreps, normalize=True, normalization="component"
        )

        # Embedding lookup for initial node features
        self.emb_in = torch.nn.Embedding(in_dim, emb_dim)

        # Set hidden irreps if none are provided
        if hidden_irreps is None:
            hidden_irreps = (sh_irreps * emb_dim).sort()[0].simplify()
            # Note: This defaults to O(3) equivariant layers
            # It is possible to use SO(3) equivariance by passing the appropriate irreps
        else:
            # Normalise user input (e.g. a string) so `.dim` etc. work below
            hidden_irreps = e3nn.o3.Irreps(hidden_irreps)
        # Keep the irreps actually used (never None)
        self.hidden_irreps = hidden_irreps

        self.convs = torch.nn.ModuleList()
        self.prods = torch.nn.ModuleList()
        self.reshapes = torch.nn.ModuleList()

        # Layer 0 maps scalars -> hidden irreps and is always built; the remaining
        # num_layers - 1 layers map hidden -> hidden. Per layer, modules are created in
        # the order conv -> reshape -> prod.
        scalar_irreps = e3nn.o3.Irreps(f'{emb_dim}x0e')
        for layer_idx in range(max(num_layers, 1)):
            self.convs.append(
                TensorProductConvLayer(
                    in_irreps=scalar_irreps if layer_idx == 0 else hidden_irreps,
                    out_irreps=hidden_irreps,
                    sh_irreps=sh_irreps,
                    edge_feats_dim=self.radial_embedding.out_dim,
                    mlp_dim=mlp_dim,
                    aggr=aggr,
                    batch_norm=batch_norm,
                    gate=False,
                )
            )
            self.reshapes.append(reshape_irreps(hidden_irreps))
            self.prods.append(
                EquivariantProductBasisBlock(
                    node_feats_irreps=hidden_irreps,
                    target_irreps=hidden_irreps,
                    correlation=correlation,
                    element_dependent=False,
                    num_elements=in_dim,
                    use_sc=residual
                )
            )

        # Global pooling/readout function
        self.pool = {"mean": global_mean_pool, "sum": global_add_pool}[pool]

        if self.equivariant_pred:
            # Linear predictor for equivariant tasks using geometric features
            self.pred = torch.nn.Linear(hidden_irreps.dim, out_dim)
        else:
            # MLP predictor for invariant tasks using only scalar features
            self.pred = torch.nn.Sequential(
                torch.nn.Linear(emb_dim, emb_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(emb_dim, out_dim)
            )

    def forward(self, batch):
        """Run MACE on a batch exposing `atoms`, `pos`, `edge_index` and `batch`.

        Returns the prediction head output, (batch_size, out_dim).
        """
        # Node embedding
        h = self.emb_in(batch.atoms)  # (n,) -> (n, d)

        # Edge features
        vectors = batch.pos[batch.edge_index[0]] - batch.pos[batch.edge_index[1]]  # [n_edges, 3]
        lengths = torch.linalg.norm(vectors, dim=-1, keepdim=True)  # [n_edges, 1]

        edge_sh = self.spherical_harmonics(vectors)
        edge_feats = self.radial_embedding(lengths)

        for conv, reshape, prod in zip(self.convs, self.reshapes, self.prods):
            # Message passing layer
            h_update = conv(h, batch.edge_index, edge_sh, edge_feats)

            # Zero-pad the previous features to the layer output width; they are passed
            # as `sc` to the product basis block (built with use_sc=residual)
            sc = F.pad(h, (0, h_update.shape[-1] - h.shape[-1]))
            h = prod(reshape(h_update), sc, None)

        out = self.pool(h, batch.batch)  # (n, d) -> (batch_size, d)

        if not self.equivariant_pred:
            # Select only scalars for invariant prediction
            out = out[:, :self.emb_dim]

        return self.pred(out)  # (batch_size, out_dim)
191 |
--------------------------------------------------------------------------------
/models/layers/gvp_layer.py:
--------------------------------------------------------------------------------
1 | ###########################################################################################
2 | # Implementation of Geometric Vector Perceptron layers
3 | #
4 | # Papers:
5 | # (1) Learning from Protein Structure with Geometric Vector Perceptrons,
6 | # by B Jing, S Eismann, P Suriana, RJL Townshend, and RO Dror
7 | # (2) Equivariant Graph Neural Networks for 3D Macromolecular Structure,
8 | # by B Jing, S Eismann, P Soni, and RO Dror
9 | #
10 | # Original repository: https://github.com/drorlab/gvp-pytorch
11 | ###########################################################################################
12 |
13 | import functools
14 | import torch
15 | import torch.nn.functional as F
16 | from torch import nn
17 | import torch_scatter
18 | from torch_geometric.nn import MessagePassing
19 |
20 |
def tuple_sum(*args):
    """
    Add up any number of (s, V) tuples component by component.

    The i-th entry of the result is the sum of the i-th entries of all inputs.
    """
    return tuple(sum(components) for components in zip(*args))
26 |
27 |
def tuple_cat(*args, dim=-1):
    """
    Concatenate any number of (s, V) tuples, scalars with scalars and vectors with vectors.

    :param dim: concatenation axis, interpreted relative to the scalar-channel
                tensors. It is normalised against the scalar rank, so `dim=-1`
                lands on axis `-2` of the vector-channel tensors.
    """
    axis = dim % len(args[0][0].shape)
    scalars = [s for s, _ in args]
    vectors = [v for _, v in args]
    return torch.cat(scalars, dim=axis), torch.cat(vectors, dim=axis)
40 |
41 |
def tuple_index(x, idx):
    """
    Select rows of an (s, V) tuple along the first dimension.

    :param idx: anything accepted as an index by a `torch.Tensor`
    """
    scalars, vectors = x[0], x[1]
    return scalars[idx], vectors[idx]
49 |
50 |
def randn(n, dims, device="cpu"):
    """
    Draw a random (s, V) tuple with standard-normal entries.

    :param n: number of data points
    :param dims: (n_scalar, n_vector) channel counts

    :return: (s, V) where s has shape (n, n_scalar) and V has shape (n, n_vector, 3);
             s is sampled before V.
    """
    n_scalar, n_vector = dims[0], dims[1]
    scalars = torch.randn(n, n_scalar, device=device)
    vectors = torch.randn(n, n_vector, 3, device=device)
    return scalars, vectors
64 |
65 |
66 | def _norm_no_nan(x, axis=-1, keepdims=False, eps=1e-8, sqrt=True):
67 | """
68 | L2 norm of tensor clamped above a minimum value `eps`.
69 |
70 | :param sqrt: if `False`, returns the square of the L2 norm
71 | """
72 | out = torch.clamp(torch.sum(torch.square(x), axis, keepdims), min=eps)
73 | return torch.sqrt(out) if sqrt else out
74 |
75 |
76 | def _split(x, nv):
77 | '''
78 | Splits a merged representation of (s, V) back into a tuple.
79 | Should be used only with `_merge(s, V)` and only if the tuple
80 | representation cannot be used.
81 |
82 | :param x: the `torch.Tensor` returned from `_merge`
83 | :param nv: the number of vector channels in the input to `_merge`
84 | '''
85 | s = x[..., :-3 * nv]
86 | v = x[..., -3 * nv:].contiguous().view(x.shape[0], nv, 3)
87 | return s, v
88 |
89 |
90 | def _merge(s, v):
91 | '''
92 | Merges a tuple (s, V) into a single `torch.Tensor`, where the
93 | vector channels are flattened and appended to the scalar channels.
94 | Should be used only if the tuple representation cannot be used.
95 | Use `_split(x, nv)` to reverse.
96 | '''
97 | v = v.contiguous().view(v.shape[0], v.shape[1] * 3)
98 | return torch.cat([s, v], -1)
99 |
100 |
class GVP(nn.Module):
    """
    Geometric Vector Perceptron. See manuscript and README.md
    for more details.

    :param in_dims: tuple (n_scalar, n_vector)
    :param out_dims: tuple (n_scalar, n_vector)
    :param h_dim: intermediate number of vector channels, optional;
        defaults to max(input n_vector, output n_vector)
    :param activations: tuple of functions (scalar_act, vector_act);
        either entry may be `None` to skip that activation
    :param vector_gate: whether to use vector gating.
        (vector_act will be used as sigma^+ in vector gating if `True`)
    """

    def __init__(
        self,
        in_dims,
        out_dims,
        h_dim=None,
        activations=(F.relu, torch.sigmoid),
        vector_gate=True,
    ):
        super(GVP, self).__init__()
        self.si, self.vi = in_dims
        self.so, self.vo = out_dims
        self.vector_gate = vector_gate
        if self.vi:
            self.h_dim = h_dim or max(self.vi, self.vo)
            # The vector channel-mixing maps have no bias, because adding a
            # fixed vector would not commute with rotations.
            self.wh = nn.Linear(self.vi, self.h_dim, bias=False)
            # The scalar update sees the input scalars plus the hidden-vector norms.
            self.ws = nn.Linear(self.h_dim + self.si, self.so)
            if self.vo:
                self.wv = nn.Linear(self.h_dim, self.vo, bias=False)
                if self.vector_gate:
                    self.wsv = nn.Linear(self.so, self.vo)
        else:
            self.ws = nn.Linear(self.si, self.so)

        self.scalar_act, self.vector_act = activations
        # Zero-size parameter kept so state_dict keys stay compatible with
        # existing checkpoints.
        self.dummy_param = nn.Parameter(torch.empty(0))

    def forward(self, x):
        """
        :param x: tuple (s, V) of `torch.Tensor`, with s of shape
            [..., n_scalar] and V of shape [..., n_vector, 3];
            or, if the input has 0 vector channels, a single `torch.Tensor`
        :return: tuple (s, V) of `torch.Tensor`;
            or, if the output has 0 vector channels, a single `torch.Tensor`
        """
        if self.vi:
            s, v = x
            # Put the xyz axis before the channel axis so nn.Linear mixes channels.
            v = torch.transpose(v, -1, -2)
            vh = self.wh(v)
            vn = _norm_no_nan(vh, axis=-2)
            s = self.ws(torch.cat([s, vn], -1))
            if self.vo:
                v = self.wv(vh)
                v = torch.transpose(v, -1, -2)
                if self.vector_gate:
                    # The gate is computed from the updated scalars, before
                    # scalar_act is applied below.
                    gate = (
                        self.wsv(self.vector_act(s)) if self.vector_act else self.wsv(s)
                    )
                    v = v * torch.sigmoid(gate).unsqueeze(-1)
                elif self.vector_act:
                    v = v * self.vector_act(_norm_no_nan(v, axis=-1, keepdims=True))
        else:
            s = self.ws(x)
            if self.vo:
                # There is no vector input, so emit zero vectors that match s
                # in leading shape, device and dtype. A float32 default would
                # clash with float64 models.
                v = s.new_zeros(*s.shape[:-1], self.vo, 3)
        if self.scalar_act:
            s = self.scalar_act(s)

        return (s, v) if self.vo else s
171 |
172 |
173 | class _VDropout(nn.Module):
174 | """
175 | Vector channel dropout where the elements of each
176 | vector channel are dropped together.
177 | """
178 |
179 | def __init__(self, drop_rate):
180 | super(_VDropout, self).__init__()
181 | self.drop_rate = drop_rate
182 | self.dummy_param = nn.Parameter(torch.empty(0))
183 |
184 | def forward(self, x):
185 | """
186 | :param x: `torch.Tensor` corresponding to vector channels
187 | """
188 | device = self.dummy_param.device
189 | if not self.training:
190 | return x
191 | mask = torch.bernoulli(
192 | (1 - self.drop_rate) * torch.ones(x.shape[:-1], device=device)
193 | ).unsqueeze(-1)
194 | x = mask * x / (1 - self.drop_rate)
195 | return x
196 |
197 |
class Dropout(nn.Module):
    """
    Combined dropout for tuples (s, V).

    Applies elementwise dropout to the scalar channels and whole-channel
    dropout (`_VDropout`) to the vector channels, both with the same
    `drop_rate`. Takes tuples (s, V) as input and as output. A bare tensor
    is treated as scalar channels only.
    """

    def __init__(self, drop_rate):
        super(Dropout, self).__init__()
        self.sdropout = nn.Dropout(drop_rate)
        self.vdropout = _VDropout(drop_rate)

    def forward(self, x):
        """
        :param x: tuple (s, V) of `torch.Tensor`,
            or single `torch.Tensor`
            (will be assumed to be scalar channels)
        :return: same structure as `x`, with dropout applied
        """
        # Use isinstance rather than an exact type check, so tensor subclasses
        # (e.g. nn.Parameter) are also treated as scalar-only input instead
        # of being unpacked as an (s, V) tuple.
        if isinstance(x, torch.Tensor):
            return self.sdropout(x)
        s, v = x
        return self.sdropout(s), self.vdropout(v)
219 |
220 |
class LayerNorm(nn.Module):
    """
    Combined LayerNorm for tuples (s, V).

    Scalars go through a standard `nn.LayerNorm`. Vectors are divided by the
    root-mean-square of their per-channel norms; this part has no learned
    parameters. When the vector dimension is 0, a bare scalar tensor is
    accepted and returned.
    """

    def __init__(self, dims):
        super().__init__()
        self.s, self.v = dims
        self.scalar_norm = nn.LayerNorm(self.s)

    def forward(self, x):
        """
        :param x: tuple (s, V) of `torch.Tensor`, or a single scalar
            `torch.Tensor` when this module was built with 0 vector channels
        """
        if self.v:
            s, v = x
            sq_norms = _norm_no_nan(v, axis=-1, keepdims=True, sqrt=False)
            rms = sq_norms.mean(dim=-2, keepdim=True).sqrt()
            return self.scalar_norm(s), v / rms
        return self.scalar_norm(x)
244 |
245 |
class GVPConv(MessagePassing):
    """
    Graph convolution / message passing with Geometric Vector Perceptrons.
    Takes in a graph with node and edge embeddings,
    and returns new node embeddings.

    This does NOT do residual updates and pointwise feedforward layers
    ---see `GVPConvLayer`.

    Each message is built from the concatenation
    (s_j, V_j) ++ edge_attr ++ (s_i, V_i) and passed through `message_func`.
    The first GVP therefore takes input dims
    (2 * n_scalar + edge n_scalar, 2 * n_vector + edge n_vector).

    :param in_dims: input node embedding dimensions (n_scalar, n_vector)
    :param out_dims: output node embedding dimensions (n_scalar, n_vector)
    :param edge_dims: input edge embedding dimensions (n_scalar, n_vector)
    :param n_layers: number of GVPs in the message function
        (NOTE(review): any value other than 1 builds at least two GVPs,
        so n_layers <= 0 behaves like 2)
    :param module_list: preconstructed message function, overrides n_layers
    :param aggr: should be "add" if some incoming edges are masked, as in
        a masked autoregressive decoder architecture, otherwise "mean"
    :param activations: tuple of functions (scalar_act, vector_act) to use in GVPs
    :param vector_gate: whether to use vector gating.
        (vector_act will be used as sigma^+ in vector gating if `True`)

    NOTE(review): the output vector dim must be non-zero. With n_vector == 0
    the final GVP returns a bare tensor, and `_merge(*message)` would unpack
    it incorrectly.
    """

    def __init__(
        self,
        in_dims,
        out_dims,
        edge_dims,
        n_layers=3,
        module_list=None,
        aggr="mean",
        activations=(F.relu, torch.sigmoid),
        vector_gate=True,
    ):
        super(GVPConv, self).__init__(aggr=aggr)
        self.si, self.vi = in_dims
        self.so, self.vo = out_dims
        self.se, self.ve = edge_dims

        # Every GVP in the message function shares the same activations and gating.
        GVP_ = functools.partial(GVP, activations=activations, vector_gate=vector_gate)

        module_list = module_list or []
        if not module_list:
            if n_layers == 1:
                # A single GVP maps the concatenated message straight to
                # out_dims, with no activations.
                module_list.append(
                    GVP_(
                        (2 * self.si + self.se, 2 * self.vi + self.ve),
                        (self.so, self.vo),
                        activations=(None, None),
                    )
                )
            else:
                # Input GVP -> (n_layers - 2) hidden GVPs -> linear output GVP.
                module_list.append(
                    GVP_((2 * self.si + self.se, 2 * self.vi + self.ve), out_dims)
                )
                for i in range(n_layers - 2):
                    module_list.append(GVP_(out_dims, out_dims))
                module_list.append(GVP_(out_dims, out_dims, activations=(None, None)))
        self.message_func = nn.Sequential(*module_list)

    def forward(self, x, edge_index, edge_attr):
        """
        :param x: tuple (s, V) of `torch.Tensor`
        :param edge_index: array of shape [2, n_edges]
        :param edge_attr: tuple (s, V) of `torch.Tensor`
        :return: tuple (s, V) of aggregated messages per node, with
            V reshaped to [n_nodes, out n_vector, 3]
        """
        x_s, x_v = x
        # Vectors are passed to `propagate` flattened to [n_nodes, 3 * n_vector].
        # Presumably this is so PyG can gather and aggregate them like ordinary
        # 2-D features; `message` restores the xyz axis.
        message = self.propagate(
            edge_index,
            s=x_s,
            v=x_v.contiguous().view(x_v.shape[0], x_v.shape[1] * 3),
            edge_attr=edge_attr,
        )
        # The aggregated output is a merged [s | flattened V] tensor; split it back.
        return _split(message, self.vo)

    def message(self, s_i, v_i, s_j, v_j, edge_attr):
        """
        Build one message per edge.

        NOTE(review): the `_i` / `_j` suffixes follow PyG's `MessagePassing`
        convention. Under the default flow, `_j` is the source node
        (edge_index[0]) and `_i` the target (edge_index[1]); confirm if the
        flow is changed. These parameter names must not be renamed.

        :return: merged [s | flattened V] tensor per edge (see `_merge`)
        """
        # Undo the flattening done in `forward`.
        v_j = v_j.view(v_j.shape[0], v_j.shape[1] // 3, 3)
        v_i = v_i.view(v_i.shape[0], v_i.shape[1] // 3, 3)
        message = tuple_cat((s_j, v_j), edge_attr, (s_i, v_i))
        message = self.message_func(message)
        # Flatten again so the aggregation step sees a single tensor.
        return _merge(*message)
325 |
326 |
class GVPConvLayer(nn.Module):
    """
    Full graph convolution / message passing layer with
    Geometric Vector Perceptrons. Residually updates node embeddings with
    aggregated incoming messages, applies a pointwise feedforward
    network to node embeddings, and returns updated node embeddings.

    To only compute the aggregated messages, see `GVPConv`.

    :param node_dims: node embedding dimensions (n_scalar, n_vector)
    :param edge_dims: input edge embedding dimensions (n_scalar, n_vector)
    :param n_message: number of GVPs to use in message function
    :param n_feedforward: number of GVPs to use in feedforward function
    :param drop_rate: drop probability in all dropout layers
    :param autoregressive: if `True`, this `GVPConvLayer` will be used
        with a different set of input node embeddings for messages
        where src >= dst
    :param activations: tuple of functions (scalar_act, vector_act) to use in GVPs
    :param vector_gate: whether to use vector gating.
        (vector_act will be used as sigma^+ in vector gating if `True`)
    :param residual: if `True` (default), each sub-layer (message passing,
        then feedforward) updates `x` as `norm(x + dropout(dh))`. If `False`,
        the sub-layer output `dh` replaces `x` directly, with no dropout or norm.
    """

    def __init__(
        self,
        node_dims,
        edge_dims,
        n_message=3,
        n_feedforward=2,
        drop_rate=0.1,
        autoregressive=False,
        activations=(F.relu, torch.sigmoid),
        vector_gate=True,
        residual=True,
    ):
        super(GVPConvLayer, self).__init__()
        # Autoregressive mode sums messages and normalises by in-degree
        # manually in `forward`, because messages come from two separate
        # conv calls.
        self.conv = GVPConv(
            node_dims,
            node_dims,
            edge_dims,
            n_message,
            aggr="add" if autoregressive else "mean",
            activations=activations,
            vector_gate=vector_gate,
        )
        GVP_ = functools.partial(GVP, activations=activations, vector_gate=vector_gate)
        self.norm = nn.ModuleList([LayerNorm(node_dims) for _ in range(2)])
        self.dropout = nn.ModuleList([Dropout(drop_rate) for _ in range(2)])

        ff_func = []
        if n_feedforward == 1:
            ff_func.append(GVP_(node_dims, node_dims, activations=(None, None)))
        else:
            # Widen to 4x scalars / 2x vectors, then project back linearly.
            hid_dims = 4 * node_dims[0], 2 * node_dims[1]
            ff_func.append(GVP_(node_dims, hid_dims))
            ff_func.extend(GVP_(hid_dims, hid_dims) for _ in range(n_feedforward - 2))
            ff_func.append(GVP_(hid_dims, node_dims, activations=(None, None)))
        self.ff_func = nn.Sequential(*ff_func)
        self.residual = residual

    def forward(self, x, edge_index, edge_attr, autoregressive_x=None, node_mask=None):
        """
        :param x: tuple (s, V) of `torch.Tensor`
        :param edge_index: array of shape [2, n_edges]
        :param edge_attr: tuple (s, V) of `torch.Tensor`
        :param autoregressive_x: tuple (s, V) of `torch.Tensor`.
            If not `None`, will be used as src node embeddings
            for forming messages where src >= dst. The current node
            embeddings `x` will still be the base of the update and the
            pointwise feedforward.
        :param node_mask: array of type `bool` to index into the first
            dim of node embeddings (s, V). If not `None`, only
            these nodes will be updated; the others are returned unchanged.
        :return: tuple (s, V) of updated node embeddings. The input tensors
            in `x` are not modified in place.
        """

        if autoregressive_x is not None:
            src, dst = edge_index
            mask = src < dst
            edge_index_forward = edge_index[:, mask]
            edge_index_backward = edge_index[:, ~mask]
            edge_attr_forward = tuple_index(edge_attr, mask)
            edge_attr_backward = tuple_index(edge_attr, ~mask)

            dh = tuple_sum(
                self.conv(x, edge_index_forward, edge_attr_forward),
                self.conv(autoregressive_x, edge_index_backward, edge_attr_backward),
            )

            # Count incoming edges per destination node. Clamping at 1 avoids
            # dividing by zero for nodes with no incoming edges.
            count = (
                torch_scatter.scatter_add(
                    torch.ones_like(dst), dst, dim_size=dh[0].size(0)
                )
                .clamp(min=1)
                .unsqueeze(-1)
            )

            # Turn the summed messages into a mean over incoming edges.
            dh = dh[0] / count, dh[1] / count.unsqueeze(-1)

        else:
            dh = self.conv(x, edge_index, edge_attr)

        if node_mask is not None:
            x_ = x
            x, dh = tuple_index(x, node_mask), tuple_index(dh, node_mask)

        x = self.norm[0](tuple_sum(x, self.dropout[0](dh))) if self.residual else dh

        dh = self.ff_func(x)
        x = self.norm[1](tuple_sum(x, self.dropout[1](dh))) if self.residual else dh

        if node_mask is not None:
            # Scatter the updated rows into copies rather than the caller's
            # tensors. Writing in place would mutate the input and raises an
            # autograd error on leaf tensors that require grad.
            s_full, v_full = x_[0].clone(), x_[1].clone()
            s_full[node_mask], v_full[node_mask] = x[0], x[1]
            x = s_full, v_full
        return x
--------------------------------------------------------------------------------
/experiments/incompleteness.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "metadata": {},
7 | "source": [
8 | "# Identifying neighbourhood fingerprints: counterexamples from [Pozdnyakov et al., 2020](https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.125.166001)\n",
9 | "\n",
10 | "*Background:*\n",
11 | "Geometric GNNs identify local neighbourhoods around nodes via **'neighbourhood fingerprints'**, where local geometric information from subsets of neighbours is aggregated to compute invariant scalars. \n",
12 | "The number of neighbours involved in computing the scalars is termed the **body order**.\n",
13 | "The ideal neighbourhood fingerprint would perfectly identify neighbourhoods, which requires arbitrarily high body order.\n",
14 | "\n",
15 | "*Experiment:*\n",
16 | "To demonstrate the practical implications of scalarisation body order, we evaluate geometric GNN layers on their ability to discriminate counterexamples from [Pozdnyakov et al., 2020](https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.125.166001).\n",
17 | "Each counterexample consists of a pair of local neighbourhoods that are **indistinguishable** when comparing their set of $k$-body scalars, i.e. geometric GNN layers with body order $k$ cannot distinguish the neighbourhoods.\n",
18 | "The 3-body counterexample corresponds to Fig.1(b) in Pozdnyakov et al., 2020, 4-body chiral to Fig.2(e), and 4-body non-chiral to Fig.2(f); the 2-body counterexample is based on the two local neighbourhoods in the running example from our paper.\n",
19 | "In this notebook, we train single-layer geometric GNNs to distinguish the counterexamples using updated scalar features. \n",
20 | "\n",
21 | ""
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": null,
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "%load_ext autoreload\n",
31 | "%autoreload 2\n",
32 | "\n",
33 | "import sys\n",
34 | "sys.path.append('../')\n",
35 | "\n",
36 | "import torch\n",
37 | "import torch_geometric\n",
38 | "from torch_geometric.data import Data\n",
39 | "from torch_geometric.loader import DataLoader\n",
40 | "from torch_geometric.utils import to_undirected\n",
41 | "import e3nn\n",
42 | "from functools import partial\n",
43 | "\n",
44 | "print(\"PyTorch version {}\".format(torch.__version__))\n",
45 | "print(\"PyG version {}\".format(torch_geometric.__version__))\n",
46 | "print(\"e3nn version {}\".format(e3nn.__version__))\n",
47 | "\n",
48 | "from experiments.utils.plot_utils import plot_3d\n",
49 | "from experiments.utils.train_utils import run_experiment\n",
50 | "from models import SchNetModel, DimeNetPPModel, SphereNetModel, EGNNModel, GVPGNNModel, TFNModel, MACEModel\n",
51 | "\n",
52 | "# Set the device\n",
53 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
54 | "print(f\"Using device: {device}\")"
55 | ]
56 | },
57 | {
58 | "attachments": {},
59 | "cell_type": "markdown",
60 | "metadata": {},
61 | "source": [
62 | "# Two-body counterexample\n",
63 | "\n",
64 | "Pair of local neighbourhoods that are indistinguishable when comparing their set of $2$-body scalars, i.e. the unordered set of pairwise distances."
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": null,
70 | "metadata": {},
71 | "outputs": [],
72 | "source": [
73 | "def create_two_body_envs():\n",
74 | " dataset = []\n",
75 | "\n",
76 | " # Environment 0\n",
77 | " atoms = torch.LongTensor([ 0, 0, 0 ])\n",
78 | " edge_index = torch.LongTensor([ [0, 0], [1, 2] ])\n",
79 | " pos = torch.FloatTensor([ \n",
80 | " [0, 0, 0],\n",
81 | " [5, 0, 0],\n",
82 | " [3, 0, 4]\n",
83 | " ])\n",
84 | " y = torch.LongTensor([0]) # Label 0\n",
85 | " data1 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n",
86 | " data1.edge_index = to_undirected(data1.edge_index)\n",
87 | " dataset.append(data1)\n",
88 | " \n",
89 | " # Environment 1\n",
90 | " atoms = torch.LongTensor([ 0, 0, 0 ])\n",
91 | " edge_index = torch.LongTensor([ [0, 0], [1, 2] ])\n",
92 | " pos = torch.FloatTensor([ \n",
93 | " [0, 0, 0],\n",
94 | " [5, 0, 0],\n",
95 | " [-5, 0, 0]\n",
96 | " ])\n",
97 | " y = torch.LongTensor([1]) # Label 1\n",
98 | " data2 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n",
99 | " data2.edge_index = to_undirected(data2.edge_index)\n",
100 | " dataset.append(data2)\n",
101 | " \n",
102 | " return dataset\n",
103 | "\n",
104 | "# Create dataset\n",
105 | "dataset = create_two_body_envs()\n",
106 | "for data in dataset:\n",
107 | " plot_3d(data, lim=5)\n",
108 | "\n",
109 | "# Create dataloaders\n",
110 | "dataloader = DataLoader(dataset, batch_size=1, shuffle=True)\n",
111 | "val_loader = DataLoader(dataset, batch_size=1, shuffle=False)\n",
112 | "test_loader = DataLoader(dataset, batch_size=1, shuffle=False)"
113 | ]
114 | },
115 | {
116 | "cell_type": "code",
117 | "execution_count": null,
118 | "metadata": {},
119 | "outputs": [],
120 | "source": [
121 | "# Set model\n",
122 | "model_name = \"mace\"\n",
123 | "\n",
124 | "correlation = 2\n",
125 | "model = {\n",
126 | " \"schnet\": SchNetModel,\n",
127 | " \"dimenet\": DimeNetPPModel,\n",
128 | " \"spherenet\": SphereNetModel,\n",
129 | " \"egnn\": EGNNModel,\n",
130 | " \"gvp\": GVPGNNModel,\n",
131 | " \"tfn\": TFNModel,\n",
132 | " \"mace\": partial(MACEModel, correlation=correlation),\n",
133 | "}[model_name](num_layers=1, in_dim=1, out_dim=2)\n",
134 | "\n",
135 | "best_val_acc, test_acc, train_time = run_experiment(\n",
136 | " model, \n",
137 | " dataloader,\n",
138 | " val_loader, \n",
139 | " test_loader,\n",
140 | " n_epochs=100,\n",
141 | " n_times=10,\n",
142 | " device=device,\n",
143 | " verbose=False\n",
144 | ")"
145 | ]
146 | },
147 | {
148 | "attachments": {},
149 | "cell_type": "markdown",
150 | "metadata": {},
151 | "source": [
152 | "# Three-body counterexample\n",
153 | "\n",
154 | "Pair of local neighbourhoods that are indistinguishable when comparing their set of $3$-body scalars, i.e. the unordered set of pairwise distances as well as angles."
155 | ]
156 | },
157 | {
158 | "cell_type": "code",
159 | "execution_count": null,
160 | "metadata": {},
161 | "outputs": [],
162 | "source": [
163 | "def create_three_body_envs():\n",
164 | " dataset = []\n",
165 | "\n",
166 | " a_x, a_y, a_z = 5, 0, 5\n",
167 | " b_x, b_y, b_z = 5, 5, 5\n",
168 | " c_x, c_y, c_z = 0, 5, 5\n",
169 | " \n",
170 | " # Environment 0\n",
171 | " atoms = torch.LongTensor([ 0, 0, 0, 0, 0 ])\n",
172 | " edge_index = torch.LongTensor([ [0, 0, 0, 0], [1, 2, 3, 4] ])\n",
173 | " pos = torch.FloatTensor([ \n",
174 | " [0, 0, 0],\n",
175 | " [a_x, a_y, a_z],\n",
176 | " [+b_x, +b_y, b_z],\n",
177 | " [-b_x, -b_y, b_z],\n",
178 | " [c_x, +c_y, c_z],\n",
179 | " ])\n",
180 | " y = torch.LongTensor([0]) # Label 0\n",
181 | " data1 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n",
182 | " data1.edge_index = to_undirected(data1.edge_index)\n",
183 | " dataset.append(data1)\n",
184 | " \n",
185 | " # Environment 1\n",
186 | " atoms = torch.LongTensor([ 0, 0, 0, 0, 0 ])\n",
187 | " edge_index = torch.LongTensor([ [0, 0, 0, 0], [1, 2, 3, 4] ])\n",
188 | " pos = torch.FloatTensor([ \n",
189 | " [0, 0, 0],\n",
190 | " [a_x, a_y, a_z],\n",
191 | " [+b_x, +b_y, b_z],\n",
192 | " [-b_x, -b_y, b_z],\n",
193 | " [c_x, -c_y, c_z],\n",
194 | " ])\n",
195 | " y = torch.LongTensor([1]) # Label 1\n",
196 | " data2 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n",
197 | " data2.edge_index = to_undirected(data2.edge_index)\n",
198 | " dataset.append(data2)\n",
199 | " \n",
200 | " return dataset\n",
201 | "\n",
202 | "# Create dataset\n",
203 | "dataset = create_three_body_envs()\n",
204 | "for data in dataset:\n",
205 | " plot_3d(data, lim=5)\n",
206 | "\n",
207 | "# Create dataloaders\n",
208 | "dataloader = DataLoader(dataset, batch_size=1, shuffle=True)\n",
209 | "val_loader = DataLoader(dataset, batch_size=1, shuffle=False)\n",
210 | "test_loader = DataLoader(dataset, batch_size=1, shuffle=False)"
211 | ]
212 | },
213 | {
214 | "cell_type": "code",
215 | "execution_count": null,
216 | "metadata": {},
217 | "outputs": [],
218 | "source": [
219 | "# Set model\n",
220 | "model_name = \"mace\"\n",
221 | "\n",
222 | "correlation = 3\n",
223 | "model = {\n",
224 | " \"schnet\": SchNetModel,\n",
225 | " \"dimenet\": DimeNetPPModel,\n",
226 | " \"spherenet\": SphereNetModel,\n",
227 | " \"egnn\": EGNNModel,\n",
228 | " \"gvp\": GVPGNNModel,\n",
229 | " \"tfn\": TFNModel,\n",
230 | " \"mace\": partial(MACEModel, correlation=correlation),\n",
231 | "}[model_name](num_layers=1, in_dim=1, out_dim=2)\n",
232 | "\n",
233 | "best_val_acc, test_acc, train_time = run_experiment(\n",
234 | " model, \n",
235 | " dataloader,\n",
236 | " val_loader, \n",
237 | " test_loader,\n",
238 | " n_epochs=100,\n",
239 | " n_times=10,\n",
240 | " device=device,\n",
241 | " verbose=False\n",
242 | ")"
243 | ]
244 | },
245 | {
246 | "attachments": {},
247 | "cell_type": "markdown",
248 | "metadata": {},
249 | "source": [
250 | "# Four-body non-chiral counterexample\n",
251 | "\n",
252 | "Pair of local neighbourhoods that are indistinguishable when comparing their set of $4$-body scalars without considering chirality/handedness, i.e. the unordered set of pairwise distances, angles, and quadruplet scalars."
253 | ]
254 | },
255 | {
256 | "cell_type": "code",
257 | "execution_count": null,
258 | "metadata": {},
259 | "outputs": [],
260 | "source": [
261 | "def create_four_body_nonchiral_envs():\n",
262 | " dataset = []\n",
263 | "\n",
264 | " a1_x, a1_y, a1_z = 3, 2, -4\n",
265 | " a2_x, a2_y, a2_z = 0, 2, 5\n",
266 | " a3_x, a3_y, a3_z = 0, 2, -5\n",
267 | " b1_x, b1_y, b1_z = 3, -2, -4\n",
268 | " b2_x, b2_y, b2_z = 0, -2, 5\n",
269 | " b3_x, b3_y, b3_z = 0, -2, -5\n",
270 | " c_x, c_y, c_z = 0, 5, 0\n",
271 | "\n",
272 | " angle = 1 * torch.pi / 10 # random angle\n",
273 | " Q = e3nn.o3.matrix_y(torch.tensor(angle)).numpy()\n",
274 | "\n",
275 | " # Environment 0\n",
276 | " atoms = torch.LongTensor([ 0, 0, 0, 0, 0, 0, 0, 0 ])\n",
277 | " edge_index = torch.LongTensor([ [0, 0, 0, 0, 0, 0, 0], [1, 2, 3, 4, 5, 6, 7] ])\n",
278 | " pos = torch.FloatTensor([ \n",
279 | " [0, 0, 0],\n",
280 | " [a1_x, a1_y, a1_z],\n",
281 | " [a2_x, a2_y, a2_z],\n",
282 | " [a3_x, a3_y, a3_z],\n",
283 | " [b1_x, b1_y, b1_z] @ Q,\n",
284 | " [b2_x, b2_y, b2_z] @ Q,\n",
285 | " [b3_x, b3_y, b3_z] @ Q,\n",
286 | " [c_x, +c_y, c_z],\n",
287 | " ]) #.to(dtype=torch.float64)\n",
288 | " y = torch.LongTensor([0]) # Label 0\n",
289 | " data1 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n",
290 | " data1.edge_index = to_undirected(data1.edge_index)\n",
291 | " dataset.append(data1)\n",
292 | " \n",
293 | " # Environment 1\n",
294 | " atoms = torch.LongTensor([ 0, 0, 0, 0, 0, 0, 0, 0 ])\n",
295 | " edge_index = torch.LongTensor([ [0, 0, 0, 0, 0, 0, 0], [1, 2, 3, 4, 5, 6, 7] ])\n",
296 | " pos = torch.FloatTensor([ \n",
297 | " [0, 0, 0],\n",
298 | " [a1_x, a1_y, a1_z],\n",
299 | " [a2_x, a2_y, a2_z],\n",
300 | " [a3_x, a3_y, a3_z],\n",
301 | " [b1_x, b1_y, b1_z] @ Q,\n",
302 | " [b2_x, b2_y, b2_z] @ Q,\n",
303 | " [b3_x, b3_y, b3_z] @ Q,\n",
304 | " [c_x, -c_y, c_z],\n",
305 | " ]) #.to(dtype=torch.float64)\n",
306 | " y = torch.LongTensor([1]) # Label 1\n",
307 | " data2 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n",
308 | " data2.edge_index = to_undirected(data2.edge_index)\n",
309 | " dataset.append(data2)\n",
310 | " \n",
311 | " return dataset\n",
312 | "\n",
313 | "# Create dataset\n",
314 | "dataset = create_four_body_nonchiral_envs()\n",
315 | "for data in dataset:\n",
316 | " plot_3d(data, lim=5)\n",
317 | "\n",
318 | "# Create dataloaders\n",
319 | "dataloader = DataLoader(dataset, batch_size=1, shuffle=True)\n",
320 | "val_loader = DataLoader(dataset, batch_size=1, shuffle=False)\n",
321 | "test_loader = DataLoader(dataset, batch_size=1, shuffle=False)"
322 | ]
323 | },
324 | {
325 | "cell_type": "code",
326 | "execution_count": null,
327 | "metadata": {},
328 | "outputs": [],
329 | "source": [
330 | "# Set model\n",
331 | "model_name = \"mace\"\n",
332 | "\n",
333 | "correlation = 4\n",
334 | "model = {\n",
335 | " \"schnet\": SchNetModel,\n",
336 | " \"dimenet\": DimeNetPPModel,\n",
337 | " \"spherenet\": SphereNetModel,\n",
338 | " \"egnn\": EGNNModel,\n",
339 | " \"gvp\": GVPGNNModel,\n",
340 | " \"tfn\": TFNModel,\n",
341 | " \"mace\": partial(MACEModel, correlation=correlation),\n",
342 | "}[model_name](num_layers=1, in_dim=1, out_dim=2)\n",
343 | "\n",
344 | "best_val_acc, test_acc, train_time = run_experiment(\n",
345 | " model, \n",
346 | " dataloader,\n",
347 | " val_loader, \n",
348 | " test_loader,\n",
349 | " n_epochs=100,\n",
350 | " n_times=10,\n",
351 | " device=device,\n",
352 | " verbose=False\n",
353 | ")"
354 | ]
355 | },
356 | {
357 | "attachments": {},
358 | "cell_type": "markdown",
359 | "metadata": {},
360 | "source": [
361 | "# Four-body chiral counterexample\n",
362 | "\n",
363 | "Pair of local neighbourhoods that are indistinguishable when comparing their set of $4$-body scalars when considering chirality/handedness, i.e. the unordered set of pairwise distances, angles, and quadruplet scalars."
364 | ]
365 | },
366 | {
367 | "cell_type": "code",
368 | "execution_count": null,
369 | "metadata": {},
370 | "outputs": [],
371 | "source": [
372 | "def create_four_body_chiral_envs():\n",
373 | " dataset = []\n",
374 | "\n",
375 | " a1_x, a1_y, a1_z = 3, 0, -4\n",
376 | " a2_x, a2_y, a2_z = 0, 0, 5\n",
377 | " a3_x, a3_y, a3_z = 0, 0, -5\n",
378 | " c_x, c_y, c_z = 0, 5, 0\n",
379 | "\n",
380 | " # Environment 0\n",
381 | " atoms = torch.LongTensor([ 0, 0, 0, 0, 0 ])\n",
382 | " edge_index = torch.LongTensor([ [0, 0, 0, 0], [1, 2, 3, 4] ])\n",
383 | " pos = torch.FloatTensor([ \n",
384 | " [0, 0, 0],\n",
385 | " [a1_x, a1_y, a1_z],\n",
386 | " [a2_x, a2_y, a2_z],\n",
387 | " [a3_x, a3_y, a3_z],\n",
388 | " [c_x, +c_y, c_z],\n",
389 | " ])\n",
390 | " y = torch.LongTensor([0]) # Label 0\n",
391 | " data1 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n",
392 | " data1.edge_index = to_undirected(data1.edge_index)\n",
393 | " dataset.append(data1)\n",
394 | " \n",
395 | " # Environment 1\n",
396 | " atoms = torch.LongTensor([ 0, 0, 0, 0, 0 ])\n",
397 | " edge_index = torch.LongTensor([ [0, 0, 0, 0], [1, 2, 3, 4] ])\n",
398 | " pos = torch.FloatTensor([ \n",
399 | " [0, 0, 0],\n",
400 | " [a1_x, a1_y, a1_z],\n",
401 | " [a2_x, a2_y, a2_z],\n",
402 | " [a3_x, a3_y, a3_z],\n",
403 | " [c_x, -c_y, c_z],\n",
404 | " ])\n",
405 | " y = torch.LongTensor([1]) # Label 1\n",
406 | " data2 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n",
407 | " data2.edge_index = to_undirected(data2.edge_index)\n",
408 | " dataset.append(data2)\n",
409 | " \n",
410 | " return dataset\n",
411 | "\n",
412 | "# Create dataset\n",
413 | "dataset = create_four_body_chiral_envs()\n",
414 | "for data in dataset:\n",
415 | " plot_3d(data, lim=5)\n",
416 | "\n",
417 | "# Create dataloaders\n",
418 | "dataloader = DataLoader(dataset, batch_size=1, shuffle=True)\n",
419 | "val_loader = DataLoader(dataset, batch_size=2, shuffle=False)\n",
420 | "test_loader = DataLoader(dataset, batch_size=2, shuffle=False)"
421 | ]
422 | },
423 | {
424 | "cell_type": "code",
425 | "execution_count": null,
426 | "metadata": {},
427 | "outputs": [],
428 | "source": [
429 | "# Set model\n",
430 | "model_name = \"tfn\"\n",
431 | "\n",
432 | "correlation = 4\n",
433 | "model = {\n",
434 | " \"schnet\": SchNetModel,\n",
435 | " \"dimenet\": DimeNetPPModel,\n",
436 | " \"spherenet\": SphereNetModel,\n",
437 | " \"egnn\": EGNNModel,\n",
438 | " \"gvp\": GVPGNNModel,\n",
439 | " \"tfn\": partial(TFNModel, hidden_irreps=e3nn.o3.Irreps(f'64x0e + 64x0o + 64x1e + 64x1o + 64x2e + 64x2o')),\n",
440 | " \"mace\": partial(MACEModel, correlation=correlation, hidden_irreps=e3nn.o3.Irreps(f'32x0e + 32x0o + 32x1e + 32x1o + 32x2e + 32x2o')),\n",
441 | "}[model_name](num_layers=1, in_dim=1, out_dim=2)\n",
442 | "\n",
443 | "best_val_acc, test_acc, train_time = run_experiment(\n",
444 | " model, \n",
445 | " dataloader,\n",
446 | " val_loader, \n",
447 | " test_loader,\n",
448 | " n_epochs=100,\n",
449 | " n_times=10,\n",
450 | " device=device,\n",
451 | " verbose=False\n",
452 | ")"
453 | ]
454 | }
455 | ],
456 | "metadata": {
457 | "kernelspec": {
458 | "display_name": "Python 3",
459 | "language": "python",
460 | "name": "python3"
461 | },
462 | "language_info": {
463 | "codemirror_mode": {
464 | "name": "ipython",
465 | "version": 3
466 | },
467 | "file_extension": ".py",
468 | "mimetype": "text/x-python",
469 | "name": "python",
470 | "nbconvert_exporter": "python",
471 | "pygments_lexer": "ipython3",
472 | "version": "3.8.16"
473 | },
474 | "orig_nbformat": 4,
475 | "vscode": {
476 | "interpreter": {
477 | "hash": "94aa676993820a604ac86f7af94f5432e989a749d5dd43e18f9507de2e8c2897"
478 | }
479 | }
480 | },
481 | "nbformat": 4,
482 | "nbformat_minor": 2
483 | }
484 |
--------------------------------------------------------------------------------
/models/mace_modules/blocks.py:
--------------------------------------------------------------------------------
1 | ###########################################################################################
2 | # Elementary Block for Building O(3) Equivariant Higher Order Message Passing Neural Network
3 | # Authors: Ilyes Batatia, Gregor Simm
4 | # This program is distributed under the MIT License (see MIT.md)
5 | ###########################################################################################
6 |
7 | from abc import ABC, abstractmethod
8 | from typing import Callable, Dict, Optional, Tuple, Union
9 |
10 | import numpy as np
11 | import torch.nn.functional
12 | from e3nn import nn, o3
13 |
14 | # from mace.tools.scatter import scatter_sum
15 | from torch_scatter import scatter_sum
16 |
17 | from .irreps_tools import (
18 | linear_out_irreps,
19 | reshape_irreps,
20 | tp_out_irreps_with_instructions,
21 | )
22 | from .radial import BesselBasis, PolynomialCutoff
23 | from .symmetric_contraction import SymmetricContraction
24 |
25 |
class LinearNodeEmbeddingBlock(torch.nn.Module):
    """
    Embed node attributes with a single `o3.Linear` map
    from `irreps_in` to `irreps_out`.
    """

    def __init__(self, irreps_in: o3.Irreps, irreps_out: o3.Irreps):
        super().__init__()
        self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=irreps_out)

    def forward(self, node_attrs: torch.Tensor) -> torch.Tensor:
        """
        :param node_attrs: [n_nodes, irreps] node attributes
        :return: embedded node features in `irreps_out`
        """
        embedded = self.linear(node_attrs)
        return embedded
35 |
36 |
class LinearReadoutBlock(torch.nn.Module):
    """
    Linear readout: one `o3.Linear` map from `irreps_in` to `irreps_out`
    (defaults to "0e").
    """

    def __init__(self, irreps_in: o3.Irreps, irreps_out: o3.Irreps = o3.Irreps("0e")):
        super().__init__()
        self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=irreps_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map node features [n_nodes, irreps] to [n_nodes, irreps_out]."""
        readout = self.linear(x)
        return readout
44 |
45 |
class NonLinearReadoutBlock(torch.nn.Module):
    """
    Two-layer readout: linear -> activation -> linear.

    :param irreps_in: irreps of the incoming node features
    :param MLP_irreps: irreps of the hidden layer
    :param gate: activation applied via e3nn `nn.Activation` on the hidden layer
    :param irreps_out: irreps of the output (defaults to "0e")
    """

    def __init__(
        self, irreps_in: o3.Irreps, MLP_irreps: o3.Irreps,
        gate: Optional[Callable], irreps_out: o3.Irreps = o3.Irreps("0e")
    ):
        super().__init__()
        self.hidden_irreps = MLP_irreps
        hidden = self.hidden_irreps
        self.linear_1 = o3.Linear(irreps_in=irreps_in, irreps_out=hidden)
        self.non_linearity = nn.Activation(irreps_in=hidden, acts=[gate])
        self.linear_2 = o3.Linear(irreps_in=hidden, irreps_out=irreps_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map node features [n_nodes, irreps] to [n_nodes, irreps_out]."""
        hidden = self.linear_1(x)
        hidden = self.non_linearity(hidden)
        return self.linear_2(hidden)
60 |
61 |
class AtomicEnergiesBlock(torch.nn.Module):
    """Per-element reference energies looked up through a one-hot encoding.

    Args:
        atomic_energies: 1-D array or tensor with one energy per element.
            It is copied (and detached) into a buffer of the default dtype,
            so later changes to the caller's object do not affect the block.
    """

    atomic_energies: torch.Tensor

    def __init__(self, atomic_energies: Union[np.ndarray, torch.Tensor]):
        super().__init__()
        assert len(atomic_energies.shape) == 1

        # ``torch.tensor(<Tensor>)`` emits a copy-construct UserWarning, so use
        # ``as_tensor`` (works for both ndarray and Tensor) and take an explicit
        # detached copy to keep the buffer independent of the input.
        energies = torch.as_tensor(
            atomic_energies, dtype=torch.get_default_dtype()
        ).detach().clone()
        self.register_buffer("atomic_energies", energies)  # [n_elements, ]

    def forward(
        self, x: torch.Tensor  # one-hot of elements [..., n_elements]
    ) -> torch.Tensor:  # [..., ]
        return torch.matmul(x, self.atomic_energies)

    def __repr__(self):
        formatted_energies = ", ".join([f"{x:.4f}" for x in self.atomic_energies])
        return f"{self.__class__.__name__}(energies=[{formatted_energies}])"
82 |
83 |
class RadialEmbeddingBlock(torch.nn.Module):
    """Expand edge lengths in a Bessel basis damped by a polynomial cutoff.

    Args:
        r_max: cutoff radius shared by the basis and the cutoff function.
        num_bessel: number of Bessel basis functions (also ``out_dim``).
        num_polynomial_cutoff: exponent ``p`` of the polynomial cutoff.
    """

    def __init__(self, r_max: float, num_bessel: int, num_polynomial_cutoff: int):
        super().__init__()
        self.bessel_fn = BesselBasis(r_max=r_max, num_basis=num_bessel)
        self.cutoff_fn = PolynomialCutoff(r_max=r_max, p=num_polynomial_cutoff)
        self.out_dim = num_bessel

    def forward(
        self, edge_lengths: torch.Tensor,  # [n_edges, 1]
    ):
        basis = self.bessel_fn(edge_lengths)  # [n_edges, n_basis]
        envelope = self.cutoff_fn(edge_lengths)  # [n_edges, 1]
        return basis * envelope  # [n_edges, n_basis]
97 |
98 |
class EquivariantProductBasisBlock(torch.nn.Module):
    """Symmetric-contraction product basis followed by an equivariant linear update.

    Args:
        node_feats_irreps: irreps of the incoming node features.
        target_irreps: irreps produced by the contraction and the linear layer.
        correlation: correlation order(s) forwarded to ``SymmetricContraction``.
        element_dependent: forwarded to ``SymmetricContraction``.
        use_sc: whether ``forward`` adds the residual ``sc`` to its output.
        batch_norm: if True, apply ``e3nn.nn.BatchNorm`` after the linear layer.
        num_elements: forwarded to ``SymmetricContraction``.
    """

    def __init__(
        self,
        node_feats_irreps: o3.Irreps,
        target_irreps: o3.Irreps,
        correlation: Union[int, Dict[str, int]],
        element_dependent: bool = True,
        use_sc: bool = True,
        batch_norm: bool = False,
        num_elements: Optional[int] = None,
    ) -> None:
        super().__init__()

        self.use_sc = use_sc
        self.symmetric_contractions = SymmetricContraction(
            irreps_in=node_feats_irreps,
            irreps_out=target_irreps,
            correlation=correlation,
            element_dependent=element_dependent,
            num_elements=num_elements,
        )
        # Update linear
        self.linear = o3.Linear(
            target_irreps, target_irreps, internal_weights=True, shared_weights=True,
        )
        self.batch_norm = nn.BatchNorm(target_irreps) if batch_norm else None

    def forward(
        self, node_feats: torch.Tensor, sc: Optional[torch.Tensor], node_attrs: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Contract, mix and (optionally) normalise node features.

        Args:
            node_feats: node features fed to the symmetric contraction.
            sc: residual to add when ``use_sc`` is True, or ``None``. Some
                interaction blocks (e.g. ``RealAgnosticInteractionBlock``)
                return ``None`` for it; that is now skipped instead of
                raising a ``TypeError`` on ``out + None``.
            node_attrs: node attributes forwarded to the contraction.

        Returns:
            Updated node features with irreps ``target_irreps``.
        """
        node_feats = self.symmetric_contractions(node_feats, node_attrs)
        out = self.linear(node_feats)
        # Explicit None check instead of relying on Module truthiness.
        if self.batch_norm is not None:
            out = self.batch_norm(out)
        if self.use_sc and sc is not None:
            out = out + sc
        return out
136 |
137 |
class InteractionBlock(ABC, torch.nn.Module):
    """Abstract base class for interaction (message-passing) blocks.

    The constructor stores the irreps and the neighbour-count normaliser on the
    instance, then hands off to ``_setup`` so subclasses can build their layers
    from those attributes. Subclasses must implement ``_setup`` and ``forward``.
    """

    def __init__(
        self,
        node_attrs_irreps: o3.Irreps,
        node_feats_irreps: o3.Irreps,
        edge_attrs_irreps: o3.Irreps,
        edge_feats_irreps: o3.Irreps,
        target_irreps: o3.Irreps,
        hidden_irreps: o3.Irreps,
        avg_num_neighbors: float,
    ) -> None:
        super().__init__()
        settings = {
            "node_attrs_irreps": node_attrs_irreps,
            "node_feats_irreps": node_feats_irreps,
            "edge_attrs_irreps": edge_attrs_irreps,
            "edge_feats_irreps": edge_feats_irreps,
            "target_irreps": target_irreps,
            "hidden_irreps": hidden_irreps,
            "avg_num_neighbors": avg_num_neighbors,
        }
        for attr_name, attr_value in settings.items():
            setattr(self, attr_name, attr_value)

        # Subclass hook: build sub-modules from the attributes set above.
        self._setup()

    @abstractmethod
    def _setup(self) -> None:
        """Create the block's sub-modules (called at the end of ``__init__``)."""
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        node_attrs: torch.Tensor,
        node_feats: torch.Tensor,
        edge_attrs: torch.Tensor,
        edge_feats: torch.Tensor,
        edge_index: torch.Tensor,
    ) -> torch.Tensor:
        """Compute updated node features from node/edge inputs."""
        raise NotImplementedError
174 |
175 |
# Scalar activations keyed by +1 / -1; presumably the irrep parity
# (even -> SiLU, odd -> Tanh, an odd function) — confirm at the call sites.
nonlinearities = dict(
    [
        (1, torch.nn.SiLU()),
        (-1, torch.nn.Tanh()),
    ]
)
177 |
178 |
class TensorProductWeightsBlock(torch.nn.Module):
    """Element-dependent bilinear map producing per-edge tensor-product weights.

    Holds a learnable tensor ``weights`` of shape
    ``(num_elements, num_edge_feats, num_feats_out)`` initialised with Xavier
    uniform. ``forward`` contracts it with the edge features and the node
    attributes, computing
    ``out[b, k] = sum_{e, a} edge_feats[b, e] * attrs[b, a] * weights[a, e, k]``.
    """

    def __init__(self, num_elements: int, num_edge_feats: int, num_feats_out: int):
        super().__init__()

        weight_shape = (num_elements, num_edge_feats, num_feats_out)
        initial = torch.empty(weight_shape, dtype=torch.get_default_dtype())
        torch.nn.init.xavier_uniform_(initial)
        self.weights = torch.nn.Parameter(initial)

    def forward(
        self,
        sender_or_receiver_node_attrs: torch.Tensor,  # assumes that the node attributes are one-hot encoded
        edge_feats: torch.Tensor,
    ):
        operands = (edge_feats, sender_or_receiver_node_attrs, self.weights)
        return torch.einsum("be, ba, aek -> bk", *operands)

    def __repr__(self):
        return (
            f'{self.__class__.__name__}(shape=({", ".join(str(s) for s in self.weights.shape)}), '
            f"weights={np.prod(self.weights.shape)})"
        )
204 |
205 |
class ResidualElementDependentInteractionBlock(InteractionBlock):
    """Interaction block with element-dependent convolution weights and a residual.

    Per edge j -> i, the message is ``conv_tp(linear_up(h)[j], edge_attrs, w)``.
    The weights ``w`` come from ``TensorProductWeightsBlock`` applied to the
    sender's node attributes and the edge features. Messages are summed at the
    receiver, passed through ``linear`` and divided by ``avg_num_neighbors``.
    The residual ``skip_tp(h, node_attrs)`` of the *input* features is then added.
    """

    def _setup(self) -> None:
        # Channel-mixing linear applied to node features before the convolution.
        self.linear_up = o3.Linear(
            self.node_feats_irreps,
            self.node_feats_irreps,
            internal_weights=True,
            shared_weights=True,
        )
        # TensorProduct
        # node features (x) edge attributes -> irreps_mid. Weights are supplied
        # per edge at call time, hence shared/internal weights are disabled.
        irreps_mid, instructions = tp_out_irreps_with_instructions(
            self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps
        )
        self.conv_tp = o3.TensorProduct(
            self.node_feats_irreps,
            self.edge_attrs_irreps,
            irreps_mid,
            instructions=instructions,
            shared_weights=False,
            internal_weights=False,
        )
        # Per-edge TP weights from (sender node attributes, edge features).
        self.conv_tp_weights = TensorProductWeightsBlock(
            num_elements=self.node_attrs_irreps.num_irreps,
            num_edge_feats=self.edge_feats_irreps.num_irreps,
            num_feats_out=self.conv_tp.weight_numel,
        )

        # Linear
        irreps_mid = irreps_mid.simplify()
        self.irreps_out = linear_out_irreps(irreps_mid, self.target_irreps)
        self.irreps_out = self.irreps_out.simplify()
        self.linear = o3.Linear(
            irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True
        )

        # Selector TensorProduct
        # Residual branch: input node features combined with node attributes.
        self.skip_tp = o3.FullyConnectedTensorProduct(
            self.node_feats_irreps, self.node_attrs_irreps, self.irreps_out
        )

    def forward(
        self,
        node_attrs: torch.Tensor,
        node_feats: torch.Tensor,
        edge_attrs: torch.Tensor,
        edge_feats: torch.Tensor,
        edge_index: torch.Tensor,
    ) -> torch.Tensor:
        sender, receiver = edge_index
        num_nodes = node_feats.shape[0]
        # Residual is computed from the features *before* linear_up.
        sc = self.skip_tp(node_feats, node_attrs)
        node_feats = self.linear_up(node_feats)
        tp_weights = self.conv_tp_weights(node_attrs[sender], edge_feats)
        mji = self.conv_tp(
            node_feats[sender], edge_attrs, tp_weights
        )  # [n_edges, irreps]
        # dim_size keeps nodes without incoming edges in the output.
        message = scatter_sum(
            src=mji, index=receiver, dim=0, dim_size=num_nodes
        )  # [n_nodes, irreps]
        message = self.linear(message) / self.avg_num_neighbors
        return message + sc  # [n_nodes, irreps]
266 |
267 |
class AgnosticNonlinearInteractionBlock(InteractionBlock):
    """Interaction block with element-agnostic, MLP-generated convolution weights.

    Convolution weights come from an MLP of the edge features alone. The
    aggregated, linearly-mixed and ``avg_num_neighbors``-normalised message is
    finally combined with the node attributes through ``skip_tp``. There is no
    residual connection.
    """

    def _setup(self) -> None:
        # Channel-mixing linear applied to node features before the convolution.
        self.linear_up = o3.Linear(
            self.node_feats_irreps,
            self.node_feats_irreps,
            internal_weights=True,
            shared_weights=True,
        )
        # TensorProduct
        # Weights are supplied per edge at call time (shared/internal disabled).
        irreps_mid, instructions = tp_out_irreps_with_instructions(
            self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps
        )
        self.conv_tp = o3.TensorProduct(
            self.node_feats_irreps,
            self.edge_attrs_irreps,
            irreps_mid,
            instructions=instructions,
            shared_weights=False,
            internal_weights=False,
        )

        # Convolution weights
        # MLP: edge features -> 3 hidden layers of width 64 (SiLU) -> TP weights.
        input_dim = self.edge_feats_irreps.num_irreps
        self.conv_tp_weights = nn.FullyConnectedNet(
            [input_dim] + 3 * [64] + [self.conv_tp.weight_numel], torch.nn.SiLU(),
        )

        # Linear
        irreps_mid = irreps_mid.simplify()
        self.irreps_out = linear_out_irreps(irreps_mid, self.target_irreps)
        self.irreps_out = self.irreps_out.simplify()
        self.linear = o3.Linear(
            irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True
        )

        # Selector TensorProduct
        # Applied to the aggregated message (irreps_out), not to the input features.
        self.skip_tp = o3.FullyConnectedTensorProduct(
            self.irreps_out, self.node_attrs_irreps, self.irreps_out
        )

    def forward(
        self,
        node_attrs: torch.Tensor,
        node_feats: torch.Tensor,
        edge_attrs: torch.Tensor,
        edge_feats: torch.Tensor,
        edge_index: torch.Tensor,
    ) -> torch.Tensor:
        sender, receiver = edge_index
        num_nodes = node_feats.shape[0]
        tp_weights = self.conv_tp_weights(edge_feats)
        node_feats = self.linear_up(node_feats)
        mji = self.conv_tp(
            node_feats[sender], edge_attrs, tp_weights
        )  # [n_edges, irreps]
        # dim_size keeps nodes without incoming edges in the output.
        message = scatter_sum(
            src=mji, index=receiver, dim=0, dim_size=num_nodes
        )  # [n_nodes, irreps]
        message = self.linear(message) / self.avg_num_neighbors
        message = self.skip_tp(message, node_attrs)
        return message  # [n_nodes, irreps]
329 |
330 |
class AgnosticResidualNonlinearInteractionBlock(InteractionBlock):
    """Interaction block with MLP-generated convolution weights and a residual.

    Same message path as ``AgnosticNonlinearInteractionBlock``, except that
    ``skip_tp`` is applied to the *input* node features and added to the
    normalised message as a residual.
    """

    def _setup(self) -> None:
        # First linear
        self.linear_up = o3.Linear(
            self.node_feats_irreps,
            self.node_feats_irreps,
            internal_weights=True,
            shared_weights=True,
        )
        # TensorProduct
        # Weights are supplied per edge at call time (shared/internal disabled).
        irreps_mid, instructions = tp_out_irreps_with_instructions(
            self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps
        )
        self.conv_tp = o3.TensorProduct(
            self.node_feats_irreps,
            self.edge_attrs_irreps,
            irreps_mid,
            instructions=instructions,
            shared_weights=False,
            internal_weights=False,
        )

        # Convolution weights
        # MLP: edge features -> 3 hidden layers of width 64 (SiLU) -> TP weights.
        input_dim = self.edge_feats_irreps.num_irreps
        self.conv_tp_weights = nn.FullyConnectedNet(
            [input_dim] + 3 * [64] + [self.conv_tp.weight_numel], torch.nn.SiLU(),
        )

        # Linear
        irreps_mid = irreps_mid.simplify()
        self.irreps_out = linear_out_irreps(irreps_mid, self.target_irreps)
        self.irreps_out = self.irreps_out.simplify()
        self.linear = o3.Linear(
            irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True
        )

        # Selector TensorProduct
        # Residual branch: input node features combined with node attributes.
        self.skip_tp = o3.FullyConnectedTensorProduct(
            self.node_feats_irreps, self.node_attrs_irreps, self.irreps_out
        )

    def forward(
        self,
        node_attrs: torch.Tensor,
        node_feats: torch.Tensor,
        edge_attrs: torch.Tensor,
        edge_feats: torch.Tensor,
        edge_index: torch.Tensor,
    ) -> torch.Tensor:
        sender, receiver = edge_index
        num_nodes = node_feats.shape[0]
        # Residual is computed from the features *before* linear_up.
        sc = self.skip_tp(node_feats, node_attrs)
        node_feats = self.linear_up(node_feats)
        tp_weights = self.conv_tp_weights(edge_feats)
        mji = self.conv_tp(
            node_feats[sender], edge_attrs, tp_weights
        )  # [n_edges, irreps]
        # dim_size keeps nodes without incoming edges in the output.
        message = scatter_sum(
            src=mji, index=receiver, dim=0, dim_size=num_nodes
        )  # [n_nodes, irreps]
        message = self.linear(message) / self.avg_num_neighbors
        message = message + sc
        return message  # [n_nodes, irreps]
394 |
395 |
class RealAgnosticInteractionBlock(InteractionBlock):
    """Element-agnostic interaction block without a residual.

    The output irreps equal ``target_irreps``. ``forward`` returns
    ``(reshaped_message, None)``; the ``None`` takes the place of the residual
    returned by ``RealAgnosticResidualInteractionBlock``.
    """

    def _setup(self) -> None:
        # First linear
        self.linear_up = o3.Linear(
            self.node_feats_irreps,
            self.node_feats_irreps,
            internal_weights=True,
            shared_weights=True,
        )
        # TensorProduct
        # Weights are supplied per edge at call time (shared/internal disabled).
        irreps_mid, instructions = tp_out_irreps_with_instructions(
            self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps,
        )
        self.conv_tp = o3.TensorProduct(
            self.node_feats_irreps,
            self.edge_attrs_irreps,
            irreps_mid,
            instructions=instructions,
            shared_weights=False,
            internal_weights=False,
        )

        # Convolution weights
        # MLP: edge features -> 3 hidden layers of width 64 (SiLU) -> TP weights.
        input_dim = self.edge_feats_irreps.num_irreps
        self.conv_tp_weights = nn.FullyConnectedNet(
            [input_dim] + 3 * [64] + [self.conv_tp.weight_numel], torch.nn.SiLU(),
        )

        # Linear
        # Project straight onto target_irreps (no linear_out_irreps filtering here).
        irreps_mid = irreps_mid.simplify()
        self.irreps_out = self.target_irreps
        self.linear = o3.Linear(
            irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True
        )

        # Selector TensorProduct
        # Applied to the aggregated message, not to the input features.
        self.skip_tp = o3.FullyConnectedTensorProduct(
            self.irreps_out, self.node_attrs_irreps, self.irreps_out
        )
        self.reshape = reshape_irreps(self.irreps_out)

    def forward(
        self,
        node_attrs: torch.Tensor,
        node_feats: torch.Tensor,
        edge_attrs: torch.Tensor,
        edge_feats: torch.Tensor,
        edge_index: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        sender, receiver = edge_index
        num_nodes = node_feats.shape[0]

        node_feats = self.linear_up(node_feats)
        tp_weights = self.conv_tp_weights(edge_feats)
        mji = self.conv_tp(
            node_feats[sender], edge_attrs, tp_weights
        )  # [n_edges, irreps]
        # dim_size keeps nodes without incoming edges in the output.
        message = scatter_sum(
            src=mji, index=receiver, dim=0, dim_size=num_nodes
        )  # [n_nodes, irreps]
        message = self.linear(message) / self.avg_num_neighbors
        message = self.skip_tp(message, node_attrs)
        # Second element is the (absent) residual; callers must handle None.
        return (
            self.reshape(message),
            None,
        )  # [n_nodes, channels, (lmax + 1)**2]
462 |
463 |
class RealAgnosticResidualInteractionBlock(InteractionBlock):
    """Element-agnostic interaction block that returns its residual separately.

    ``forward`` returns ``(reshaped_message, sc)``, where
    ``sc = skip_tp(node_feats, node_attrs)`` has ``hidden_irreps``. The residual
    is NOT added here; presumably the caller adds it later (e.g.
    ``EquivariantProductBasisBlock`` with ``use_sc=True``) — confirm in the model.
    """

    def _setup(self) -> None:

        # First linear
        self.linear_up = o3.Linear(
            self.node_feats_irreps,
            self.node_feats_irreps,
            internal_weights=True,
            shared_weights=True,
        )
        # TensorProduct
        # Weights are supplied per edge at call time (shared/internal disabled).
        irreps_mid, instructions = tp_out_irreps_with_instructions(
            self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps,
        )
        self.conv_tp = o3.TensorProduct(
            self.node_feats_irreps,
            self.edge_attrs_irreps,
            irreps_mid,
            instructions=instructions,
            shared_weights=False,
            internal_weights=False,
        )

        # Convolution weights
        # MLP: edge features -> 3 hidden layers of width 64 (SiLU) -> TP weights.
        input_dim = self.edge_feats_irreps.num_irreps
        self.conv_tp_weights = nn.FullyConnectedNet(
            [input_dim] + 3 * [64] + [self.conv_tp.weight_numel], torch.nn.SiLU(),
        )

        # Linear
        irreps_mid = irreps_mid.simplify()
        self.irreps_out = self.target_irreps
        self.linear = o3.Linear(
            irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True
        )

        # Selector TensorProduct
        # Residual branch: input node features (x) node attributes -> hidden_irreps.
        self.skip_tp = o3.FullyConnectedTensorProduct(
            self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps
        )
        self.reshape = reshape_irreps(self.irreps_out)

    def forward(
        self,
        node_attrs: torch.Tensor,
        node_feats: torch.Tensor,
        edge_attrs: torch.Tensor,
        edge_feats: torch.Tensor,
        edge_index: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        sender, receiver = edge_index
        num_nodes = node_feats.shape[0]

        # Residual is computed from the features *before* linear_up.
        sc = self.skip_tp(node_feats, node_attrs)
        node_feats = self.linear_up(node_feats)
        tp_weights = self.conv_tp_weights(edge_feats)
        mji = self.conv_tp(
            node_feats[sender], edge_attrs, tp_weights
        )  # [n_edges, irreps]
        # dim_size keeps nodes without incoming edges in the output.
        message = scatter_sum(
            src=mji, index=receiver, dim=0, dim_size=num_nodes
        )  # [n_nodes, irreps]
        message = self.linear(message) / self.avg_num_neighbors
        return (
            self.reshape(message),
            sc,
        )  # [n_nodes, channels, (lmax + 1)**2]
531 |
532 |
class ScaleShiftBlock(torch.nn.Module):
    """Affine map ``scale * x + shift``.

    Both constants are stored as non-trainable buffers in the default dtype,
    so they move with the module across devices and appear in its state_dict.
    """

    def __init__(self, scale: float, shift: float):
        super().__init__()
        default_dtype = torch.get_default_dtype()
        for buffer_name, buffer_value in (("scale", scale), ("shift", shift)):
            self.register_buffer(
                buffer_name, torch.tensor(buffer_value, dtype=default_dtype)
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scaled = self.scale * x
        return scaled + self.shift

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(scale={self.scale:.6f}, shift={self.shift:.6f})"
        )
550 |
--------------------------------------------------------------------------------
/models/layers/spherenet_layer.py:
--------------------------------------------------------------------------------
1 | ################################################################
2 | # Implementation of SphereNet layers
3 | #
4 | # Paper: Spherical Message Passing for 3D Graph Networks
5 | # by Y Liu, L Wang, M Liu, X Zhang, B Oztekin, and S Ji
6 | #
7 | # Original repository: https://github.com/divelab/DIG
8 | ################################################################
9 |
10 | import torch
11 | from torch import nn
12 | from torch.nn import Linear, Embedding
13 | from torch_geometric.nn.inits import glorot_orthogonal
14 | from torch_scatter import scatter
15 | from math import sqrt
16 |
17 | import numpy as np
18 | from scipy.optimize import brentq
19 | from scipy import special as sp
20 | import torch
21 | from math import pi as PI
22 |
23 | import sympy as sym
24 |
25 | import torch
26 | from torch_scatter import scatter
27 | from torch_sparse import SparseTensor
28 | from math import pi as PI
29 |
def swish(x):
    """Swish activation: ``x * sigmoid(x)``, computed elementwise."""
    gate = torch.sigmoid(x)
    return x * gate
32 |
class emb(torch.nn.Module):
    """Bundle of SphereNet geometric embeddings.

    Wraps a distance (``dist_emb``), an angle (``angle_emb``) and a torsion
    (``torsion_emb``) basis built with the same radial/spherical sizes, cutoff
    and envelope exponent.
    """

    def __init__(self, num_spherical, num_radial, cutoff, envelope_exponent):
        super(emb, self).__init__()
        self.dist_emb = dist_emb(num_radial, cutoff, envelope_exponent)
        self.angle_emb = angle_emb(num_spherical, num_radial, cutoff, envelope_exponent)
        self.torsion_emb = torsion_emb(num_spherical, num_radial, cutoff, envelope_exponent)
        self.reset_parameters()

    def reset_parameters(self):
        # Only the distance embedding owns a learnable parameter (``freq``).
        self.dist_emb.reset_parameters()

    def forward(self, dist, angle, torsion, idx_kj):
        """Return ``(distance_basis, angle_basis, torsion_basis)`` in that order."""
        return (
            self.dist_emb(dist),
            self.angle_emb(dist, angle, idx_kj),
            self.torsion_emb(dist, angle, torsion, idx_kj),
        )
49 |
class ResidualLayer(torch.nn.Module):
    """Two-layer residual MLP: ``x + act(lin2(act(lin1(x))))``.

    Both linear layers are square (``hidden_channels`` in and out) and are
    re-initialised with Glorot-orthogonal weights (scale 2.0) and zero biases.
    """

    def __init__(self, hidden_channels, act=swish):
        super(ResidualLayer, self).__init__()
        self.act = act
        self.lin1 = Linear(hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, hidden_channels)

        self.reset_parameters()

    def reset_parameters(self):
        for layer in (self.lin1, self.lin2):
            glorot_orthogonal(layer.weight, scale=2.0)
            layer.bias.data.fill_(0)

    def forward(self, x):
        hidden = self.act(self.lin1(x))
        update = self.act(self.lin2(hidden))
        return x + update
67 |
68 |
class init(torch.nn.Module):
    """Initial SphereNet edge embeddings from node embeddings and distance basis.

    Args:
        num_radial: width of the radial (distance) basis ``rbf``.
        hidden_channels: width of node and edge embeddings.
        act: activation function.
        use_node_features: if True, embed integer node types with an
            ``Embedding(95, hidden_channels)`` table (presumably atomic numbers
            below 95 — confirm). If False, every node shares one learned
            vector ``node_embedding``.
    """

    def __init__(self, num_radial, hidden_channels, act=swish, use_node_features=True):
        super(init, self).__init__()
        self.act = act
        self.use_node_features = use_node_features
        if self.use_node_features:
            self.emb = Embedding(95, hidden_channels)
        else:  # option to use no node features and a learned embedding vector for each node instead
            # Values are filled in by reset_parameters() below.
            self.node_embedding = nn.Parameter(torch.empty((hidden_channels,)))
        self.lin_rbf_0 = Linear(num_radial, hidden_channels)
        self.lin = Linear(3 * hidden_channels, hidden_channels)
        self.lin_rbf_1 = nn.Linear(num_radial, hidden_channels, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        if self.use_node_features:
            self.emb.weight.data.uniform_(-sqrt(3), sqrt(3))
        else:
            # Previously never re-initialised here, so reset_parameters() left
            # the shared node embedding at its trained values.
            nn.init.normal_(self.node_embedding)
        self.lin_rbf_0.reset_parameters()
        self.lin.reset_parameters()
        glorot_orthogonal(self.lin_rbf_1.weight, scale=2.0)

    def forward(self, x, emb, i, j):
        """Compute initial edge embeddings.

        Args:
            x: node types (when ``use_node_features``); otherwise only
                ``x.shape[0]`` (the number of nodes) is used.
            emb: tuple whose first element is the distance basis ``rbf``.
            i, j: per-edge node indices used to gather endpoint embeddings.

        Returns:
            ``(e1, e2)``, where ``e1 = act(lin([x_i, x_j, act(lin_rbf_0(rbf))]))``
            and ``e2 = lin_rbf_1(rbf) * e1``.
        """
        rbf, _, _ = emb
        if self.use_node_features:
            x = self.emb(x)
        else:
            x = self.node_embedding[None, :].expand(x.shape[0], -1)
        rbf0 = self.act(self.lin_rbf_0(rbf))
        e1 = self.act(self.lin(torch.cat([x[i], x[j], rbf0], dim=-1)))
        e2 = self.lin_rbf_1(rbf) * e1

        return e1, e2
102 |
103 |
class update_e(torch.nn.Module):
    """SphereNet edge-embedding update (DimeNet-style interaction).

    Each edge embedding is updated from its own features (``lin_ji`` branch)
    and from a message built out of other edges. That message path:
    ``lin_kj`` features scaled by the distance basis, projected down to
    ``int_emb_size``, gathered per triplet with ``idx_kj``, scaled by the angle
    (``sbf``) and torsion (``t``) bases, summed back into edges with ``idx_ji``,
    then projected up. Residual stacks follow, before and after a skip
    connection to the input edge features.

    Args:
        hidden_channels: width of edge embeddings.
        int_emb_size: width of the down-projected triplet interaction space.
        basis_emb_size_dist, basis_emb_size_angle, basis_emb_size_torsion:
            bottleneck widths used to embed the three bases.
        num_spherical, num_radial: basis sizes; the angle basis has
            ``num_spherical * num_radial`` features and the torsion basis
            ``num_spherical**2 * num_radial``.
        num_before_skip, num_after_skip: number of ``ResidualLayer``s before /
            after the skip connection.
        act: activation function.
    """

    def __init__(self, hidden_channels, int_emb_size, basis_emb_size_dist, basis_emb_size_angle, basis_emb_size_torsion, num_spherical, num_radial,
                 num_before_skip, num_after_skip, act=swish):
        super(update_e, self).__init__()
        self.act = act
        # Two-stage (bottleneck) projections of the distance / angle / torsion bases.
        self.lin_rbf1 = nn.Linear(num_radial, basis_emb_size_dist, bias=False)
        self.lin_rbf2 = nn.Linear(basis_emb_size_dist, hidden_channels, bias=False)
        self.lin_sbf1 = nn.Linear(num_spherical * num_radial, basis_emb_size_angle, bias=False)
        self.lin_sbf2 = nn.Linear(basis_emb_size_angle, int_emb_size, bias=False)
        self.lin_t1 = nn.Linear(num_spherical * num_spherical * num_radial, basis_emb_size_torsion, bias=False)
        self.lin_t2 = nn.Linear(basis_emb_size_torsion, int_emb_size, bias=False)
        # Used only to build the second output e2 from the raw distance basis.
        self.lin_rbf = nn.Linear(num_radial, hidden_channels, bias=False)

        self.lin_kj = nn.Linear(hidden_channels, hidden_channels)
        self.lin_ji = nn.Linear(hidden_channels, hidden_channels)

        self.lin_down = nn.Linear(hidden_channels, int_emb_size, bias=False)
        self.lin_up = nn.Linear(int_emb_size, hidden_channels, bias=False)

        self.layers_before_skip = torch.nn.ModuleList([
            ResidualLayer(hidden_channels, act)
            for _ in range(num_before_skip)
        ])
        self.lin = nn.Linear(hidden_channels, hidden_channels)
        self.layers_after_skip = torch.nn.ModuleList([
            ResidualLayer(hidden_channels, act)
            for _ in range(num_after_skip)
        ])

        self.reset_parameters()

    def reset_parameters(self):
        """Glorot-orthogonal (scale 2.0) weights everywhere; zero biases where present."""
        glorot_orthogonal(self.lin_rbf1.weight, scale=2.0)
        glorot_orthogonal(self.lin_rbf2.weight, scale=2.0)
        glorot_orthogonal(self.lin_sbf1.weight, scale=2.0)
        glorot_orthogonal(self.lin_sbf2.weight, scale=2.0)
        glorot_orthogonal(self.lin_t1.weight, scale=2.0)
        glorot_orthogonal(self.lin_t2.weight, scale=2.0)

        glorot_orthogonal(self.lin_kj.weight, scale=2.0)
        self.lin_kj.bias.data.fill_(0)
        glorot_orthogonal(self.lin_ji.weight, scale=2.0)
        self.lin_ji.bias.data.fill_(0)

        glorot_orthogonal(self.lin_down.weight, scale=2.0)
        glorot_orthogonal(self.lin_up.weight, scale=2.0)

        for res_layer in self.layers_before_skip:
            res_layer.reset_parameters()
        glorot_orthogonal(self.lin.weight, scale=2.0)
        self.lin.bias.data.fill_(0)
        for res_layer in self.layers_after_skip:
            res_layer.reset_parameters()

        glorot_orthogonal(self.lin_rbf.weight, scale=2.0)

    def forward(self, x, emb, idx_kj, idx_ji):
        """Update edge embeddings.

        Args:
            x: tuple ``(e1, e2)`` of edge embeddings; only ``e1`` is read.
            emb: tuple ``(rbf, sbf, t)`` of distance, angle and torsion bases.
            idx_kj: per-triplet index of the edge whose features are gathered.
            idx_ji: per-triplet index of the edge that receives the message.

        Returns:
            ``(e1, e2)`` with ``e2 = lin_rbf(rbf) * e1``.
        """
        rbf0, sbf, t = emb
        x1,_ = x

        x_ji = self.act(self.lin_ji(x1))
        x_kj = self.act(self.lin_kj(x1))

        # Modulate edge features by the (projected) distance basis.
        rbf = self.lin_rbf1(rbf0)
        rbf = self.lin_rbf2(rbf)
        x_kj = x_kj * rbf

        x_kj = self.act(self.lin_down(x_kj))

        # Gather per triplet, then modulate by the angle basis.
        sbf = self.lin_sbf1(sbf)
        sbf = self.lin_sbf2(sbf)
        x_kj = x_kj[idx_kj] * sbf

        # ...and by the torsion basis.
        t = self.lin_t1(t)
        t = self.lin_t2(t)
        x_kj = x_kj * t

        # Sum triplet messages back into their target edges (one row per edge).
        x_kj = scatter(x_kj, idx_ji, dim=0, dim_size=x1.size(0))
        x_kj = self.act(self.lin_up(x_kj))

        e1 = x_ji + x_kj
        for layer in self.layers_before_skip:
            e1 = layer(e1)
        # Skip connection back to the input edge embeddings.
        e1 = self.act(self.lin(e1)) + x1
        for layer in self.layers_after_skip:
            e1 = layer(e1)
        e2 = self.lin_rbf(rbf0) * e1

        return e1, e2
193 |
194 |
class update_v(torch.nn.Module):
    """SphereNet node update: sum edge embeddings into nodes, then an MLP.

    Args:
        hidden_channels: width of the incoming edge embeddings ``e2``.
        out_emb_channels: width of the hidden MLP layers.
        out_channels: output width per node.
        num_output_layers: number of hidden ``Linear`` + activation layers.
        act: activation applied after each hidden layer.
        output_init: ``'zeros'`` or ``'GlorotOrthogonal'`` initialisation of the
            final projection; any other value keeps PyTorch's default init.
    """

    def __init__(self, hidden_channels, out_emb_channels, out_channels, num_output_layers, act, output_init):
        super(update_v, self).__init__()
        self.act = act
        self.output_init = output_init

        self.lin_up = nn.Linear(hidden_channels, out_emb_channels, bias=True)
        self.lins = torch.nn.ModuleList()
        for _ in range(num_output_layers):
            self.lins.append(nn.Linear(out_emb_channels, out_emb_channels))
        self.lin = nn.Linear(out_emb_channels, out_channels, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        glorot_orthogonal(self.lin_up.weight, scale=2.0)
        for lin in self.lins:
            glorot_orthogonal(lin.weight, scale=2.0)
            lin.bias.data.fill_(0)
        if self.output_init == 'zeros':
            self.lin.weight.data.fill_(0)
        if self.output_init == 'GlorotOrthogonal':
            glorot_orthogonal(self.lin.weight, scale=2.0)

    def forward(self, e, i, num_nodes=None):
        """Aggregate edge embeddings per node and apply the output MLP.

        Args:
            e: tuple ``(e1, e2)`` of edge embeddings; only ``e2`` is used.
            i: per-edge index of the node each edge is summed into.
            num_nodes: optional number of nodes. When given, the output has
                exactly ``num_nodes`` rows even if the highest-indexed nodes
                have no edges. When ``None`` (default, previous behaviour), the
                row count is inferred from ``i`` (``i.max() + 1``), which drops
                trailing edge-less nodes.

        Returns:
            Node outputs of width ``out_channels``.
        """
        _, e2 = e
        v = scatter(e2, i, dim=0, dim_size=num_nodes)
        v = self.lin_up(v)
        for lin in self.lins:
            v = self.act(lin(v))
        v = self.lin(v)
        return v
227 |
228 |
class update_u(torch.nn.Module):
    """Graph-level accumulator: adds per-graph sums of node outputs to ``u``.

    Note that ``u`` is updated in place (``+=``) and the same tensor is returned.
    """

    def __init__(self):
        super(update_u, self).__init__()

    def forward(self, u, v, batch):
        pooled = scatter(v, batch, dim=0)
        u += pooled
        return u
236 |
237 | # Based on the code from: https://github.com/klicperajo/dimenet,
238 | # https://github.com/rusty1s/pytorch_geometric/blob/master/torch_geometric/nn/models/dimenet_utils.py
239 |
240 |
def Jn(r, n):
    """Spherical Bessel function of the first kind ``j_n(r)``.

    Uses ``j_n(r) = sqrt(pi / (2 r)) * J_{n + 1/2}(r)``, where ``J`` is the
    ordinary Bessel function of the first kind (``scipy.special.jv``).
    """
    half_integer_order = n + 0.5
    prefactor = np.sqrt(np.pi / (2 * r))
    return prefactor * sp.jv(half_integer_order, r)
243 |
244 |
def Jn_zeros(n, k):
    """First ``k`` positive zeros of the spherical Bessel functions ``j_0 .. j_{n-1}``.

    Row 0 holds the zeros of ``j_0 = sin(x)/x``, i.e. ``pi, 2 pi, ..., k pi``.
    Each higher order is found with ``brentq`` between consecutive zeros of the
    previous order (the zeros interlace). ``k + n - 1`` zeros are tracked so
    that enough brackets remain for later orders. Returns a float32 array of
    shape ``(n, k)``.
    """
    table = np.zeros((n, k), dtype='float32')
    table[0] = np.arange(1, k + 1) * np.pi
    brackets = np.arange(1, k + n) * np.pi
    found = np.zeros(k + n - 1, dtype='float32')
    for order in range(1, n):
        for idx in range(k + n - 1 - order):
            # Writing found[idx] is safe even once ``brackets is found``:
            # brackets[idx] has already been read and brackets[idx + 1] is untouched.
            found[idx] = brentq(Jn, brackets[idx], brackets[idx + 1], (order, ))
        brackets = found
        table[order][:k] = found[:k]

    return table
258 |
259 |
def spherical_bessel_formulas(n):
    """Symbolic spherical Bessel functions ``j_0 .. j_{n-1}`` in the variable ``x``.

    Built with Rayleigh's formula ``j_l(x) = (-x)^l (1/x d/dx)^l (sin x / x)``.
    At least one expression (``sin(x)/x``) is always returned.
    """
    x = sym.symbols('x')

    formulas = [sym.sin(x) / x]
    current = sym.sin(x) / x
    for order in range(1, n):
        # Apply one more (1/x d/dx) to the running expression.
        derivative = sym.diff(current, x) / x
        formulas.append(sym.simplify(derivative * (-x)**order))
        current = sym.simplify(derivative)
    return formulas
270 |
271 |
def bessel_basis(n, k):
    """Normalised spherical Bessel radial basis as sympy expressions in ``x``.

    Returns ``basis[l][i]`` for ``l < n``, ``i < k``, equal to
    ``N_{l,i} * j_l(z_{l,i} * x)``. Here ``z_{l,i}`` is the i-th positive zero of
    ``j_l`` (from ``Jn_zeros``) and ``N_{l,i} = 1 / sqrt(0.5 * j_{l+1}(z_{l,i})**2)``
    (presumably chosen for unit norm on ``[0, 1]`` with weight ``x**2``).
    """
    roots = Jn_zeros(n, k)
    norms = [
        1 / np.array([0.5 * Jn(roots[order, i], order + 1)**2 for i in range(k)])**0.5
        for order in range(n)
    ]

    radial_forms = spherical_bessel_formulas(n)
    x = sym.symbols('x')
    return [
        [
            sym.simplify(norms[order][i] *
                         radial_forms[order].subs(x, roots[order, i] * x))
            for i in range(k)
        ]
        for order in range(n)
    ]
294 |
295 |
def sph_harm_prefactor(k, m):
    """Normalisation constant of the spherical harmonic of degree ``k`` and order ``m``.

    Returns ``sqrt((2k + 1) * (k - |m|)! / (4 * pi * (k + |m|)!))``; the value
    depends only on ``|m|``.
    """
    # ``np.math`` was only a deprecated alias of the stdlib ``math`` module and
    # was removed in NumPy 2.0, so call ``math.factorial`` directly.
    from math import factorial

    return ((2 * k + 1) * factorial(k - abs(m)) /
            (4 * np.pi * factorial(k + abs(m))))**0.5
299 |
300 |
def associated_legendre_polynomials(k, zero_m_only=True):
    """Symbolic associated Legendre polynomials ``P_l^m(z)`` for ``l < k``.

    Returns a ragged list ``P[l][m]`` (``0 <= m <= l``) of sympy expressions in
    ``z``. Unset entries stay as the int ``0``. ``m = 0`` terms use Bonnet's
    recursion. When ``zero_m_only`` is False, the ``m > 0`` terms are also filled
    via ``P_m^m = (1 - 2m) P_{m-1}^{m-1}`` and the standard upward recursion in
    ``l``. These omit the ``(1 - z^2)^{m/2}`` factor, which ``real_sph_harm``
    supplies through its ``sin^m(theta)`` terms.

    Args:
        k: number of degrees to build (``l = 0 .. k - 1``). ``k == 0`` yields
            an empty list; ``k == 1`` yields ``[[1]]``.
        zero_m_only: only build the ``m = 0`` column.
    """
    z = sym.symbols('z')
    P_l_m = [[0] * (j + 1) for j in range(k)]
    if k == 0:
        return P_l_m

    P_l_m[0][0] = 1
    # Degree 1 exists only when k >= 2; the old ``k > 0`` guard raised an
    # IndexError for k == 1.
    if k > 1:
        P_l_m[1][0] = z

    for j in range(2, k):
        P_l_m[j][0] = sym.simplify(((2 * j - 1) * z * P_l_m[j - 1][0] -
                                    (j - 1) * P_l_m[j - 2][0]) / j)
    if not zero_m_only:
        for i in range(1, k):
            P_l_m[i][i] = sym.simplify((1 - 2 * i) * P_l_m[i - 1][i - 1])
            if i + 1 < k:
                P_l_m[i + 1][i] = sym.simplify(
                    (2 * i + 1) * z * P_l_m[i][i])
            for j in range(i + 2, k):
                P_l_m[j][i] = sym.simplify(
                    ((2 * j - 1) * z * P_l_m[j - 1][i] -
                     (i + j - 1) * P_l_m[j - 2][i]) / (j - i))

    return P_l_m
324 |
325 |
def real_sph_harm(l, zero_m_only=False, spherical_coordinates=True):
    """
    Build sympy expressions for the real spherical harmonics of degree 0 .. l-1.

    Returns a ragged list ``Y[d]`` of length ``2*d + 1`` per degree ``d``, indexed as
    ``[m=0, m=1, ..., m=d, m=-d, ..., m=-1]``. The negative orders are stored via
    negative Python indices: ``Y[d][-j]`` holds ``m = -j``. When ``zero_m_only``
    is True, only ``Y[d][0]`` is computed and the other slots keep the
    placeholder string ``'0'``.

    Variables are either cartesian coordinates x, y, z on the unit sphere, or
    (when ``spherical_coordinates``) the angles ``theta`` and ``phi``.
    """
    if not zero_m_only:
        # C_m / S_m are the real / imaginary parts of (x + i y)^m, built by the
        # recurrence (C + iS)(x + iy); on the sphere they become
        # sin^m(theta) cos(m phi) and sin^m(theta) sin(m phi).
        x = sym.symbols('x')
        y = sym.symbols('y')
        S_m = [x*0]
        C_m = [1+0*x]
        # S_m = [0]
        # C_m = [1]
        for i in range(1, l):
            x = sym.symbols('x')
            y = sym.symbols('y')
            S_m += [x*S_m[i-1] + y*C_m[i-1]]
            C_m += [x*C_m[i-1] - y*S_m[i-1]]

    P_l_m = associated_legendre_polynomials(l, zero_m_only)
    if spherical_coordinates:
        theta = sym.symbols('theta')
        z = sym.symbols('z')
        for i in range(len(P_l_m)):
            for j in range(len(P_l_m[i])):
                # Plain int entries (constants / unset placeholders) have no .subs.
                if type(P_l_m[i][j]) != int:
                    P_l_m[i][j] = P_l_m[i][j].subs(z, sym.cos(theta))
        if not zero_m_only:
            phi = sym.symbols('phi')
            for i in range(len(S_m)):
                S_m[i] = S_m[i].subs(x, sym.sin(
                    theta)*sym.cos(phi)).subs(y, sym.sin(theta)*sym.sin(phi))
            for i in range(len(C_m)):
                C_m[i] = C_m[i].subs(x, sym.sin(
                    theta)*sym.cos(phi)).subs(y, sym.sin(theta)*sym.sin(phi))

    Y_func_l_m = [['0']*(2*j + 1) for j in range(l)]
    for i in range(l):
        Y_func_l_m[i][0] = sym.simplify(sph_harm_prefactor(i, 0) * P_l_m[i][0])

    if not zero_m_only:
        for i in range(1, l):
            for j in range(1, i + 1):
                Y_func_l_m[i][j] = sym.simplify(
                    2**0.5 * sph_harm_prefactor(i, j) * C_m[j] * P_l_m[i][j])
        for i in range(1, l):
            for j in range(1, i + 1):
                # Negative index: m = -j is stored at position 2*i + 1 - j.
                Y_func_l_m[i][-j] = sym.simplify(
                    2**0.5 * sph_harm_prefactor(i, -j) * S_m[j] * P_l_m[i][j])

    return Y_func_l_m
376 |
377 |
class Envelope(torch.nn.Module):
    """Smooth polynomial cutoff envelope for scaled distances ``x = d / cutoff``.

    ``u(x) = 1/x + a x^(p-1) + b x^p + c x^(p+1)`` with ``p = exponent + 1``.
    The coefficients make ``u(1) = 0``. Inputs with ``x >= 1`` (at or beyond the
    cutoff) are mapped to exactly ``0``; without that, the polynomial keeps
    growing past the cutoff.

    Args:
        exponent: envelope exponent; ``p = exponent + 1``.
    """

    def __init__(self, exponent):
        super(Envelope, self).__init__()
        self.p = exponent + 1
        self.a = -(self.p + 1) * (self.p + 2) / 2
        self.b = self.p * (self.p + 2)
        self.c = -self.p * (self.p + 1) / 2

    def forward(self, x):
        p, a, b, c = self.p, self.a, self.b, self.c
        x_pow_p0 = x.pow(p - 1)
        x_pow_p1 = x_pow_p0 * x
        x_pow_p2 = x_pow_p1 * x
        env = 1. / x + a * x_pow_p0 + b * x_pow_p1 + c * x_pow_p2
        # Hard zero beyond the cutoff. torch.where (rather than multiplying by a
        # 0/1 mask) also avoids inf * 0 = nan if the polynomial overflows there.
        return torch.where(x < 1.0, env, torch.zeros_like(env))
392 |
393 |
class dist_emb(torch.nn.Module):
    """Radial sine basis of edge distances, damped by ``Envelope``.

    For scaled distances ``x = dist / cutoff`` the output is
    ``envelope(x) * sin(freq * x)``, with learnable frequencies initialised to
    ``n * pi`` for ``n = 1 .. num_radial``. The output has shape
    ``dist.shape + (num_radial,)``.

    Args:
        num_radial: number of basis functions.
        cutoff: distance scale.
        envelope_exponent: exponent passed to ``Envelope``.
    """

    def __init__(self, num_radial, cutoff=5.0, envelope_exponent=5):
        super(dist_emb, self).__init__()
        self.cutoff = cutoff
        self.envelope = Envelope(envelope_exponent)

        # Values are filled in by reset_parameters().
        self.freq = torch.nn.Parameter(torch.empty(num_radial))

        self.reset_parameters()

    def reset_parameters(self):
        # Write in place so the Parameter keeps its dtype and device. The previous
        # ``self.freq.data = torch.arange(...).float()...`` swapped in a new
        # CPU float32 tensor, which breaks after ``.to(device)`` / ``.double()``.
        with torch.no_grad():
            self.freq.copy_(
                torch.arange(
                    1,
                    self.freq.numel() + 1,
                    dtype=self.freq.dtype,
                    device=self.freq.device,
                ) * PI
            )

    def forward(self, dist):
        dist = dist.unsqueeze(-1) / self.cutoff
        return self.envelope(dist) * (self.freq * dist).sin()
410 |
411 |
class angle_emb(torch.nn.Module):
    """Spherical Fourier-Bessel basis over (distance, angle) for triplets.

    Produces ``num_spherical * num_radial`` features: each radial function of
    ``bessel_basis`` evaluated at ``dist / cutoff`` (gathered with ``idx_kj``)
    times the ``m = 0`` real spherical harmonic of the same degree, evaluated at
    ``angle``.

    Args:
        num_spherical: number of angular degrees (``l = 0 .. num_spherical - 1``).
        num_radial: number of radial roots per degree (at most 64).
        cutoff: distance scale; distances are divided by it.
        envelope_exponent: accepted for interface compatibility; not used.
    """

    def __init__(self, num_spherical, num_radial, cutoff=5.0,
                 envelope_exponent=5):
        super(angle_emb, self).__init__()
        assert num_radial <= 64
        self.num_spherical = num_spherical
        self.num_radial = num_radial
        self.cutoff = cutoff

        bessel_forms = bessel_basis(num_spherical, num_radial)
        # Only the m = 0 harmonics are used below. Request just those rather
        # than building and sympy-simplifying every m != 0 term; the m = 0
        # expressions are identical either way.
        sph_harm_forms = real_sph_harm(num_spherical, zero_m_only=True)
        self.sph_funcs = []
        self.bessel_funcs = []

        x, theta = sym.symbols('x theta')
        modules = {'sin': torch.sin, 'cos': torch.cos}
        for i in range(num_spherical):
            if i == 0:
                # Y_0^0 is a constant, so lambdify would yield a plain number.
                # Broadcast it to the input's shape so torch.stack gets tensors.
                sph1 = sym.lambdify([theta], sph_harm_forms[i][0], modules)(0)
                self.sph_funcs.append(lambda x: torch.zeros_like(x) + sph1)
            else:
                sph = sym.lambdify([theta], sph_harm_forms[i][0], modules)
                self.sph_funcs.append(sph)
            for j in range(num_radial):
                bessel = sym.lambdify([x], bessel_forms[i][j], modules)
                self.bessel_funcs.append(bessel)

    def forward(self, dist, angle, idx_kj):
        """Evaluate the basis.

        Args:
            dist: per-edge distances.
            angle: angles; assumes one entry per element of ``idx_kj`` — TODO confirm.
            idx_kj: indices selecting which edge's distance each row uses.

        Returns:
            Tensor of shape ``[len(idx_kj), num_spherical * num_radial]``.
        """
        dist = dist / self.cutoff
        rbf = torch.stack([f(dist) for f in self.bessel_funcs], dim=1)

        cbf = torch.stack([f(angle) for f in self.sph_funcs], dim=1)

        n, k = self.num_spherical, self.num_radial
        out = (rbf[idx_kj].view(-1, n, k) * cbf.view(-1, n, 1)).view(-1, n * k)
        return out
450 |
451 |
class torsion_emb(torch.nn.Module):
    """Torsion embedding: radial basis forms combined with angular forms of
    ``(angle, phi)``.

    For every degree ``l < num_spherical`` this builds ``2l + 1`` angular
    functions from ``real_sph_harm(..., zero_m_only=False)`` (``num_spherical
    ** 2`` in total) and ``num_radial`` radial functions from
    ``bessel_basis``, all compiled with sympy's ``lambdify``.

    Args:
        num_spherical: Number of degrees ``l``.
        num_radial: Number of radial functions per degree (asserted <= 64).
        cutoff: Value the input distances are divided by.
        envelope_exponent: Accepted for interface compatibility; unused.
    """

    def __init__(self, num_spherical, num_radial, cutoff=5.0,
                 envelope_exponent=5):
        super(torsion_emb, self).__init__()
        assert num_radial <= 64
        self.num_spherical = num_spherical
        self.num_radial = num_radial
        self.cutoff = cutoff

        radial_exprs = bessel_basis(num_spherical, num_radial)
        angular_exprs = real_sph_harm(num_spherical, zero_m_only=False)
        self.sph_funcs = []
        self.bessel_funcs = []

        x_sym, theta_sym, phi_sym = sym.symbols('x theta phi')
        backend = {'sin': torch.sin, 'cos': torch.cos}
        for degree in range(self.num_spherical):
            if degree == 0:
                # The degree-0 form is constant; broadcast it to the inputs'
                # shape. The default argument binds the function eagerly.
                const_fn = sym.lambdify([theta_sym, phi_sym], angular_exprs[degree][0], backend)
                self.sph_funcs.append(
                    lambda t, p, _f=const_fn: torch.zeros_like(t) + torch.zeros_like(p) + _f(0, 0)
                )
            else:
                # Forms for this degree live at positions 0 .. 2*degree
                # (presumably orders m = -degree .. degree).
                self.sph_funcs.extend(
                    sym.lambdify([theta_sym, phi_sym], angular_exprs[degree][slot], backend)
                    for slot in range(2 * degree + 1)
                )
            self.bessel_funcs.extend(
                sym.lambdify([x_sym], radial_exprs[degree][root], backend)
                for root in range(self.num_radial)
            )

    def forward(self, dist, angle, phi, idx_kj):
        """Combine per-edge radial values with per-item angular values.

        Args:
            dist: Edge distances (assumed 1-D; TODO confirm).
            angle: Polar angles, aligned with ``phi`` and ``idx_kj``.
            phi: Torsion angles.
            idx_kj: Index selecting, for each item, the edge whose radial
                values are used.

        Returns:
            Tensor viewed as ``(-1, num_spherical**2 * num_radial)``.
        """
        n, k = self.num_spherical, self.num_radial
        scaled = dist / self.cutoff
        rbf = torch.stack([fn(scaled) for fn in self.bessel_funcs], dim=1)
        cbf = torch.stack([fn(angle, phi) for fn in self.sph_funcs], dim=1)

        # NOTE(review): the n*n angular columns are ordered by degree
        # (1, 3, 5, ... per degree), but this view pairs column r*n + s with
        # radial degree s, so degrees are not matched one-to-one. Kept as-is
        # to preserve the model's outputs.
        radial = rbf[idx_kj].view(-1, 1, n, k)
        angular = cbf.view(-1, n, n, 1)
        return (radial * angular).view(-1, n * n * k)
491 |
492 |
493 | # Based on the code from: https://github.com/klicperajo/dimenet,
494 | # https://github.com/rusty1s/pytorch_geometric/blob/master/torch_geometric/nn/models/dimenet.py
495 |
def xyz_to_dat(pos, edge_index, num_nodes, use_torsion=False):
    """
    Compute distances, angles and (optionally) torsions from node positions.

    Edges are read as ``j -> i`` from ``edge_index = (j, i)``. Triplets
    ``k -> j -> i`` are formed from pairs of edges that share node ``j``,
    excluding the back-tracking case ``k == i``.

    Args:
        pos: Node positions indexed by node id. The cross products below
            require 3-D coordinates in the last dimension.
        edge_index: Edge index of the graph, unpacked as ``(j, i)``.
        num_nodes: Number of nodes in the graph.
        use_torsion: If set to :obj:`True`, also compute torsion angles.
            (default: :obj:`False`)

    Returns:
        ``(dist, angle, torsion, i, j, idx_kj, idx_ji)`` if ``use_torsion``,
        otherwise ``(dist, angle, i, j, idx_kj, idx_ji)``, where

        - ``dist``: per-edge length ``|pos[i] - pos[j]|``;
        - ``angle``: per-triplet angle in ``[0, pi]`` between
          ``pos[i] - pos[j]`` and ``pos[k] - pos[j]``;
        - ``torsion``: per-triplet value in ``(0, 2*pi]``. For each node
          ``n`` with an edge ``n -> j`` (``n != i``), take the signed angle
          about the ``j -> i`` axis between the planes spanned by
          ``(ji, jk)`` and ``(ji, jn)``; the result is the minimum of these;
        - ``i``, ``j``: target and source node of each edge;
        - ``idx_kj``, ``idx_ji``: edge ids of the ``k -> j`` and ``j -> i``
          edges of each triplet.
    """
    j, i = edge_index  # edges j -> i

    # Edge lengths (one per edge).
    dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt()

    # adj_t[target, source] = edge id. adj_t[j] therefore lists, for every
    # edge j -> i, all edges k -> j that end at its source node j.
    value = torch.arange(j.size(0), device=j.device)
    adj_t = SparseTensor(row=i, col=j, value=value, sparse_sizes=(num_nodes, num_nodes))
    adj_t_row = adj_t[j]
    num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long)

    # Node indices (k->j->i) for triplets, dropping k == i.
    idx_i = i.repeat_interleave(num_triplets)
    idx_j = j.repeat_interleave(num_triplets)
    idx_k = adj_t_row.storage.col()
    mask = idx_i != idx_k
    idx_i, idx_j, idx_k = idx_i[mask], idx_j[mask], idx_k[mask]

    # Edge indices (k->j, j->i) for triplets.
    idx_kj = adj_t_row.storage.value()[mask]
    idx_ji = adj_t_row.storage.row()[mask]

    # Angles in [0, pi]: atan2 of a non-negative |cross| and the dot product.
    pos_ji = pos[idx_i] - pos[idx_j]
    pos_jk = pos[idx_k] - pos[idx_j]
    a = (pos_ji * pos_jk).sum(dim=-1)  # cos_angle * |pos_ji| * |pos_jk|
    # dim=-1 must be explicit: without it torch.cross uses the FIRST
    # dimension of size 3, i.e. dim 0 whenever there are exactly 3 triplets.
    b = torch.cross(pos_ji, pos_jk, dim=-1).norm(dim=-1)  # sin_angle * |pos_ji| * |pos_jk|
    angle = torch.atan2(b, a)

    if not use_torsion:
        return dist, angle, i, j, idx_kj, idx_ji

    # Torsion candidates: for each kept triplet (k->j->i), every node n with
    # an edge n -> j, excluding n == i.
    idx_batch = torch.arange(idx_i.size(0), device=j.device)
    idx_k_n = adj_t[idx_j].storage.col()
    # Number of edges ending at j, per (unmasked) triplet, filtered with the
    # triplet mask so it lines up with idx_i / idx_j / idx_k.
    num_triplets_t = num_triplets.repeat_interleave(num_triplets)[mask]
    idx_i_t = idx_i.repeat_interleave(num_triplets_t)
    idx_j_t = idx_j.repeat_interleave(num_triplets_t)
    idx_k_t = idx_k.repeat_interleave(num_triplets_t)
    idx_batch_t = idx_batch.repeat_interleave(num_triplets_t)
    mask_t = idx_i_t != idx_k_n
    idx_i_t = idx_i_t[mask_t]
    idx_j_t = idx_j_t[mask_t]
    idx_k_t = idx_k_t[mask_t]
    idx_k_n = idx_k_n[mask_t]
    idx_batch_t = idx_batch_t[mask_t]

    # Signed angle between the planes (ji, jk) and (ji, jn) about the j->i axis.
    pos_j0 = pos[idx_k_t] - pos[idx_j_t]
    pos_ji = pos[idx_i_t] - pos[idx_j_t]
    pos_jn = pos[idx_k_n] - pos[idx_j_t]
    dist_ji = pos_ji.pow(2).sum(dim=-1).sqrt()
    plane1 = torch.cross(pos_ji, pos_j0, dim=-1)
    plane2 = torch.cross(pos_ji, pos_jn, dim=-1)
    a = (plane1 * plane2).sum(dim=-1)  # cos_angle * |plane1| * |plane2|
    b = (torch.cross(plane1, plane2, dim=-1) * pos_ji).sum(dim=-1) / dist_ji
    torsion1 = torch.atan2(b, a)  # -pi to pi
    torsion1[torsion1 <= 0] += 2 * PI  # mapped to (0, 2*pi]
    # Keep the smallest candidate per triplet. dim_size pins the output
    # length to the number of triplets so it always aligns with `angle`.
    torsion = scatter(torsion1, idx_batch_t, reduce='min', dim_size=idx_i.size(0))

    return dist, angle, torsion, i, j, idx_kj, idx_ji
565 |
--------------------------------------------------------------------------------