├── experiments ├── __init__.py ├── utils │ ├── __init__.py │ ├── plot_utils.py │ └── train_utils.py ├── fig │ ├── kchains.png │ ├── rotsym.png │ ├── incompleteness.png │ └── axes-of-expressivity.png ├── rotsym.ipynb ├── kchains.ipynb └── incompleteness.ipynb ├── models ├── layers │ ├── __init__.py │ ├── tfn_layer.py │ ├── egnn_layer.py │ ├── gvp_layer.py │ └── spherenet_layer.py ├── __init__.py ├── mace_modules │ ├── __init__.py │ ├── radial.py │ ├── irreps_tools.py │ ├── cg.py │ ├── symmetric_contraction.py │ └── blocks.py ├── schnet.py ├── egnn.py ├── dimenet.py ├── spherenet.py ├── gvpgnn.py ├── tfn.py └── mace.py ├── LICENSE ├── .gitignore └── README.md /experiments/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/layers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /experiments/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /experiments/fig/kchains.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaitjo/geometric-gnn-dojo/HEAD/experiments/fig/kchains.png -------------------------------------------------------------------------------- /experiments/fig/rotsym.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaitjo/geometric-gnn-dojo/HEAD/experiments/fig/rotsym.png -------------------------------------------------------------------------------- /experiments/fig/incompleteness.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/chaitjo/geometric-gnn-dojo/HEAD/experiments/fig/incompleteness.png -------------------------------------------------------------------------------- /experiments/fig/axes-of-expressivity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaitjo/geometric-gnn-dojo/HEAD/experiments/fig/axes-of-expressivity.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.schnet import SchNetModel 2 | from models.dimenet import DimeNetPPModel 3 | from models.spherenet import SphereNetModel 4 | from models.egnn import EGNNModel 5 | from models.gvpgnn import GVPGNNModel 6 | from models.tfn import TFNModel 7 | from models.mace import MACEModel 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Chaitanya K. Joshi, Simon V. Mathis 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # C extensions 6 | *.so 7 | 8 | # Distribution / packaging 9 | .Python 10 | env/ 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | *.egg-info/ 23 | .installed.cfg 24 | *.egg 25 | 26 | # PyInstaller 27 | # Usually these files are written by a python script from a template 28 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 29 | *.manifest 30 | *.spec 31 | 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | 36 | # Unit test / coverage reports 37 | htmlcov/ 38 | .tox/ 39 | .coverage 40 | .coverage.* 41 | .cache 42 | nosetests.xml 43 | coverage.xml 44 | *.cover 45 | 46 | # Translations 47 | *.mo 48 | *.pot 49 | 50 | # Django stuff: 51 | *.log 52 | 53 | # Sphinx documentation 54 | docs/_build/ 55 | 56 | # PyBuilder 57 | target/ 58 | 59 | # DotEnv configuration 60 | .env 61 | 62 | # Database 63 | *.db 64 | *.rdb 65 | 66 | # Pycharm 67 | .idea 68 | 69 | # VS Code 70 | .vscode/ 71 | 72 | # Spyder 73 | .spyproject/ 74 | 75 | # Jupyter NB Checkpoints 76 | .ipynb_checkpoints/ 77 | 78 | # exclude data from source control by default 79 | io 80 | io/ 81 | 82 | # Mac OS-specific storage files 83 | .DS_Store 84 | 85 | # vim 86 | *.swp 87 | *.swo 88 | 89 | # Mypy cache 90 | .mypy_cache/ 91 | 92 | # Other 93 | .history 94 | .history/ 95 | logs/ 96 | venv*/ 97 | venv_gpu/ 98 | lightning_logs 99 | 
.DS_Store -------------------------------------------------------------------------------- /experiments/utils/plot_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch_geometric.utils import to_networkx 3 | import matplotlib.pyplot as plt 4 | 5 | 6 | def plot_2d(data, lim=10): 7 | # The graph to visualize 8 | G = to_networkx(data) 9 | pos = data.pos.numpy() 10 | 11 | # Extract node and edge positions from the layout 12 | node_xyz = np.array([pos[v, :2] for v in sorted(G)]) 13 | edge_xyz = np.array([(pos[u, :2], pos[v, :2]) for u, v in G.edges()]) 14 | 15 | # Create the 2D figure 16 | fig = plt.figure() 17 | ax = fig.add_subplot(111) 18 | 19 | # Plot the nodes - alpha is scaled by "depth" automatically 20 | ax.scatter(*node_xyz.T, s=100, c=data.atoms.numpy(), cmap="rainbow") 21 | 22 | # Plot the edges 23 | for vizedge in edge_xyz: 24 | ax.plot(*vizedge.T, color="tab:gray") 25 | 26 | # Turn gridlines off 27 | # ax.grid(False) 28 | 29 | # Suppress tick labels 30 | # for dim in (ax.xaxis, ax.yaxis, ax.zaxis): 31 | # dim.set_ticks([]) 32 | 33 | # Set axes labels and limits 34 | ax.set_xlabel("x") 35 | ax.set_ylabel("y") 36 | ax.set_xlim([-lim, lim]) 37 | ax.set_ylim([-lim, lim]) 38 | ax.set_aspect('equal', 'box') 39 | 40 | # fig.tight_layout() 41 | plt.show() 42 | 43 | 44 | def plot_3d(data, lim=10): 45 | # The graph to visualize 46 | G = to_networkx(data) 47 | pos = data.pos.numpy() 48 | 49 | # Extract node and edge positions from the layout 50 | node_xyz = np.array([pos[v] for v in sorted(G)]) 51 | edge_xyz = np.array([(pos[u], pos[v]) for u, v in G.edges()]) 52 | 53 | # Create the 3D figure 54 | fig = plt.figure() 55 | ax = fig.add_subplot(111, projection="3d") 56 | 57 | # Plot the nodes - alpha is scaled by "depth" automatically 58 | ax.scatter(*node_xyz.T, s=100, c=data.atoms.numpy(), cmap="rainbow") 59 | 60 | # Plot the edges 61 | for vizedge in edge_xyz: 62 | ax.plot(*vizedge.T, 
color="tab:gray") 63 | 64 | # Turn gridlines off 65 | # ax.grid(False) 66 | 67 | # Suppress tick labels 68 | # for dim in (ax.xaxis, ax.yaxis, ax.zaxis): 69 | # dim.set_ticks([]) 70 | 71 | # Set axes labels and limits 72 | ax.set_xlabel("x") 73 | ax.set_ylabel("y") 74 | ax.set_zlabel("z") 75 | ax.set_xlim([-lim, lim]) 76 | ax.set_ylim([-lim, lim]) 77 | ax.set_zlim([-lim, lim]) 78 | 79 | # fig.tight_layout() 80 | plt.show() 81 | -------------------------------------------------------------------------------- /models/mace_modules/__init__.py: -------------------------------------------------------------------------------- 1 | ########################################################################################### 2 | # This directory contains an implementation of MACE, with minor adaptations 3 | # 4 | # Paper: MACE: Higher Order Equivariant Message Passing Neural Networks 5 | # for Fast and Accurate Force Fields, Batatia et al. 6 | # 7 | # Orginal repository: https://github.com/ACEsuit/mace 8 | ########################################################################################### 9 | 10 | from typing import Callable, Dict, Optional, Type 11 | 12 | import torch 13 | 14 | from .blocks import ( 15 | AgnosticNonlinearInteractionBlock, 16 | AgnosticResidualNonlinearInteractionBlock, 17 | AtomicEnergiesBlock, 18 | EquivariantProductBasisBlock, 19 | InteractionBlock, 20 | LinearNodeEmbeddingBlock, 21 | LinearReadoutBlock, 22 | NonLinearReadoutBlock, 23 | RadialEmbeddingBlock, 24 | RealAgnosticInteractionBlock, 25 | RealAgnosticResidualInteractionBlock, 26 | ResidualElementDependentInteractionBlock, 27 | ScaleShiftBlock, 28 | ) 29 | from .radial import BesselBasis, PolynomialCutoff 30 | from .symmetric_contraction import SymmetricContraction 31 | 32 | interaction_classes: Dict[str, Type[InteractionBlock]] = { 33 | "AgnosticNonlinearInteractionBlock": AgnosticNonlinearInteractionBlock, 34 | "ResidualElementDependentInteractionBlock": 
ResidualElementDependentInteractionBlock, 35 | "AgnosticResidualNonlinearInteractionBlock": AgnosticResidualNonlinearInteractionBlock, 36 | "RealAgnosticResidualInteractionBlock": RealAgnosticResidualInteractionBlock, 37 | "RealAgnosticInteractionBlock": RealAgnosticInteractionBlock, 38 | } 39 | 40 | gate_dict: Dict[str, Optional[Callable]] = { 41 | "abs": torch.abs, 42 | "tanh": torch.tanh, 43 | "silu": torch.nn.functional.silu, 44 | "None": None, 45 | } 46 | 47 | __all__ = [ 48 | "AtomicEnergiesBlock", 49 | "RadialEmbeddingBlock", 50 | "LinearNodeEmbeddingBlock", 51 | "LinearReadoutBlock", 52 | "EquivariantProductBasisBlock", 53 | "ScaleShiftBlock", 54 | "InteractionBlock", 55 | "NonLinearReadoutBlock", 56 | "PolynomialCutoff", 57 | "BesselBasis", 58 | "MACE", 59 | "ScaleShiftMACE", 60 | "BOTNet", 61 | "ScaleShiftBOTNet", 62 | "EnergyForcesLoss", 63 | "WeightedEnergyForcesLoss", 64 | "WeightedForcesLoss", 65 | "SymmetricContraction", 66 | "interaction_classes", 67 | "compute_mean_std_atomic_inter_energy", 68 | "compute_avg_num_neighbors", 69 | ] 70 | -------------------------------------------------------------------------------- /models/mace_modules/radial.py: -------------------------------------------------------------------------------- 1 | ########################################################################################### 2 | # Radial basis and cutoff 3 | # Authors: Ilyes Batatia, Gregor Simm 4 | # This program is distributed under the MIT License (see MIT.md) 5 | ########################################################################################### 6 | 7 | import numpy as np 8 | import torch 9 | from e3nn.util.jit import compile_mode 10 | 11 | 12 | class BesselBasis(torch.nn.Module): 13 | """ 14 | Klicpera, J.; Groß, J.; Günnemann, S. Directional Message Passing for Molecular Graphs; ICLR 2020. 
15 | Equation (7) 16 | """ 17 | 18 | def __init__(self, r_max: float, num_basis=8, trainable=False): 19 | super().__init__() 20 | 21 | bessel_weights = ( 22 | np.pi 23 | / r_max 24 | * torch.linspace( 25 | start=1.0, 26 | end=num_basis, 27 | steps=num_basis, 28 | dtype=torch.get_default_dtype(), 29 | ) 30 | ) 31 | if trainable: 32 | self.bessel_weights = torch.nn.Parameter(bessel_weights) 33 | else: 34 | self.register_buffer("bessel_weights", bessel_weights) 35 | 36 | self.register_buffer( 37 | "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype()) 38 | ) 39 | self.register_buffer( 40 | "prefactor", 41 | torch.tensor(np.sqrt(2.0 / r_max), dtype=torch.get_default_dtype()), 42 | ) 43 | 44 | def forward(self, x: torch.Tensor) -> torch.Tensor: # [..., 1] 45 | numerator = torch.sin(self.bessel_weights * x) # [..., num_basis] 46 | return self.prefactor * (numerator / x) 47 | 48 | def __repr__(self): 49 | return ( 50 | f"{self.__class__.__name__}(r_max={self.r_max}, num_basis={len(self.bessel_weights)}, " 51 | f"trainable={self.bessel_weights.requires_grad})" 52 | ) 53 | 54 | 55 | class PolynomialCutoff(torch.nn.Module): 56 | """ 57 | Klicpera, J.; Groß, J.; Günnemann, S. Directional Message Passing for Molecular Graphs; ICLR 2020. 
58 | Equation (8) 59 | """ 60 | 61 | p: torch.Tensor 62 | r_max: torch.Tensor 63 | 64 | def __init__(self, r_max: float, p=6): 65 | super().__init__() 66 | self.register_buffer("p", torch.tensor(p, dtype=torch.get_default_dtype())) 67 | self.register_buffer( 68 | "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype()) 69 | ) 70 | 71 | def forward(self, x: torch.Tensor) -> torch.Tensor: 72 | envelope = ( 73 | 1.0 74 | - ((self.p + 1.0) * (self.p + 2.0) / 2.0) * torch.pow(x / self.r_max, self.p) 75 | + self.p * (self.p + 2.0) * torch.pow(x / self.r_max, self.p + 1) 76 | - (self.p * (self.p + 1.0) / 2) * torch.pow(x / self.r_max, self.p + 2) 77 | ) 78 | return envelope * (x < self.r_max) 79 | 80 | def __repr__(self): 81 | return f"{self.__class__.__name__}(p={self.p}, r_max={self.r_max})" 82 | -------------------------------------------------------------------------------- /models/schnet.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | from torch_geometric.nn import SchNet 6 | from torch_geometric.nn import global_add_pool, global_mean_pool 7 | 8 | 9 | class SchNetModel(SchNet): 10 | """ 11 | SchNet model from "Schnet - a deep learning architecture for molecules and materials". 12 | 13 | This class extends the SchNet base class for PyG. 14 | """ 15 | def __init__( 16 | self, 17 | hidden_channels: int = 128, 18 | in_dim: int = 1, 19 | out_dim: int = 1, 20 | num_filters: int = 128, 21 | num_layers: int = 6, 22 | num_gaussians: int = 50, 23 | cutoff: float = 10, 24 | max_num_neighbors: int = 32, 25 | pool: str = 'sum' 26 | ): 27 | """ 28 | Initializes an instance of the SchNetModel class with the provided parameters. 
29 | 30 | Parameters: 31 | - hidden_channels (int): Number of channels in the hidden layers (default: 128) 32 | - in_dim (int): Input dimension of the model (default: 1) 33 | - out_dim (int): Output dimension of the model (default: 1) 34 | - num_filters (int): Number of filters used in convolutional layers (default: 128) 35 | - num_layers (int): Number of convolutional layers in the model (default: 6) 36 | - num_gaussians (int): Number of Gaussian functions used for radial filters (default: 50) 37 | - cutoff (float): Cutoff distance for interactions (default: 10) 38 | - max_num_neighbors (int): Maximum number of neighboring atoms to consider (default: 32) 39 | - pool (str): Global pooling method to be used (default: "sum") 40 | """ 41 | super().__init__( 42 | hidden_channels, 43 | num_filters, 44 | num_layers, 45 | num_gaussians, 46 | cutoff, 47 | interaction_graph=None, 48 | max_num_neighbors=max_num_neighbors, 49 | readout=pool, 50 | dipole=False, 51 | mean=None, 52 | std=None, 53 | atomref=None 54 | ) 55 | 56 | # Global pooling/readout function 57 | self.pool = {"mean": global_mean_pool, "sum": global_add_pool}[pool] 58 | 59 | # Overwrite atom embedding and final predictor 60 | self.lin2 = torch.nn.Linear(hidden_channels // 2, out_dim) 61 | 62 | def forward(self, batch): 63 | 64 | h = self.embedding(batch.atoms) # (n,) -> (n, d) 65 | 66 | row, col = batch.edge_index 67 | edge_weight = (batch.pos[row] - batch.pos[col]).norm(dim=-1) 68 | edge_attr = self.distance_expansion(edge_weight) 69 | 70 | for interaction in self.interactions: 71 | # # Message passing layer: (n, d) -> (n, d) 72 | h = h + interaction(h, batch.edge_index, edge_weight, edge_attr) 73 | 74 | out = self.pool(h, batch.batch) # (n, d) -> (batch_size, d) 75 | 76 | out = self.lin1(out) 77 | out = self.act(out) 78 | out = self.lin2(out) # (batch_size, out_dim) 79 | 80 | return out 81 | -------------------------------------------------------------------------------- /models/egnn.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | from torch_geometric.nn import global_add_pool, global_mean_pool 4 | 5 | from models.layers.egnn_layer import EGNNLayer 6 | 7 | 8 | class EGNNModel(torch.nn.Module): 9 | """ 10 | E-GNN model from "E(n) Equivariant Graph Neural Networks". 11 | """ 12 | def __init__( 13 | self, 14 | num_layers: int = 5, 15 | emb_dim: int = 128, 16 | in_dim: int = 1, 17 | out_dim: int = 1, 18 | activation: str = "relu", 19 | norm: str = "layer", 20 | aggr: str = "sum", 21 | pool: str = "sum", 22 | residual: bool = True, 23 | equivariant_pred: bool = False 24 | ): 25 | """ 26 | Initializes an instance of the EGNNModel class with the provided parameters. 27 | 28 | Parameters: 29 | - num_layers (int): Number of layers in the model (default: 5) 30 | - emb_dim (int): Dimension of the node embeddings (default: 128) 31 | - in_dim (int): Input dimension of the model (default: 1) 32 | - out_dim (int): Output dimension of the model (default: 1) 33 | - activation (str): Activation function to be used (default: "relu") 34 | - norm (str): Normalization method to be used (default: "layer") 35 | - aggr (str): Aggregation method to be used (default: "sum") 36 | - pool (str): Global pooling method to be used (default: "sum") 37 | - residual (bool): Whether to use residual connections (default: True) 38 | - equivariant_pred (bool): Whether it is an equivariant prediction task (default: False) 39 | """ 40 | super().__init__() 41 | self.equivariant_pred = equivariant_pred 42 | self.residual = residual 43 | 44 | # Embedding lookup for initial node features 45 | self.emb_in = torch.nn.Embedding(in_dim, emb_dim) 46 | 47 | # Stack of GNN layers 48 | self.convs = torch.nn.ModuleList() 49 | for _ in range(num_layers): 50 | self.convs.append(EGNNLayer(emb_dim, activation, norm, aggr)) 51 | 52 | # Global pooling/readout function 53 | self.pool = {"mean": global_mean_pool, 
"sum": global_add_pool}[pool] 54 | 55 | if self.equivariant_pred: 56 | # Linear predictor for equivariant tasks using geometric features 57 | self.pred = torch.nn.Linear(emb_dim + 3, out_dim) 58 | else: 59 | # MLP predictor for invariant tasks using only scalar features 60 | self.pred = torch.nn.Sequential( 61 | torch.nn.Linear(emb_dim, emb_dim), 62 | torch.nn.ReLU(), 63 | torch.nn.Linear(emb_dim, out_dim) 64 | ) 65 | 66 | def forward(self, batch): 67 | 68 | h = self.emb_in(batch.atoms) # (n,) -> (n, d) 69 | pos = batch.pos # (n, 3) 70 | 71 | for conv in self.convs: 72 | # Message passing layer 73 | h_update, pos_update = conv(h, pos, batch.edge_index) 74 | 75 | # Update node features (n, d) -> (n, d) 76 | h = h + h_update if self.residual else h_update 77 | 78 | # Update node coordinates (no residual) (n, 3) -> (n, 3) 79 | pos = pos_update 80 | 81 | if not self.equivariant_pred: 82 | # Select only scalars for invariant prediction 83 | out = self.pool(h, batch.batch) # (n, d) -> (batch_size, d) 84 | else: 85 | out = self.pool(torch.cat([h, pos], dim=-1), batch.batch) 86 | 87 | return self.pred(out) # (batch_size, out_dim) 88 | -------------------------------------------------------------------------------- /models/mace_modules/irreps_tools.py: -------------------------------------------------------------------------------- 1 | ########################################################################################### 2 | # Elementary tools for handling irreducible representations 3 | # Authors: Ilyes Batatia, Gregor Simm 4 | # This program is distributed under the MIT License (see MIT.md) 5 | ########################################################################################### 6 | 7 | from typing import List, Tuple 8 | 9 | import torch 10 | from e3nn import o3 11 | from e3nn.util.jit import compile_mode 12 | 13 | 14 | # Based on mir-group/nequip 15 | def tp_out_irreps_with_instructions( 16 | irreps1: o3.Irreps, irreps2: o3.Irreps, target_irreps: o3.Irreps 17 
| ) -> Tuple[o3.Irreps, List]: 18 | trainable = True 19 | 20 | # Collect possible irreps and their instructions 21 | irreps_out_list: List[Tuple[int, o3.Irreps]] = [] 22 | instructions = [] 23 | for i, (mul, ir_in) in enumerate(irreps1): 24 | for j, (_, ir_edge) in enumerate(irreps2): 25 | for ir_out in ir_in * ir_edge: # | l1 - l2 | <= l <= l1 + l2 26 | if ir_out in target_irreps: 27 | k = len(irreps_out_list) # instruction index 28 | irreps_out_list.append((mul, ir_out)) 29 | instructions.append((i, j, k, "uvu", trainable)) 30 | 31 | # We sort the output irreps of the tensor product so that we can simplify them 32 | # when they are provided to the second o3.Linear 33 | irreps_out = o3.Irreps(irreps_out_list) 34 | irreps_out, permut, _ = irreps_out.sort() 35 | 36 | # Permute the output indexes of the instructions to match the sorted irreps: 37 | instructions = [ 38 | (i_in1, i_in2, permut[i_out], mode, train) 39 | for i_in1, i_in2, i_out, mode, train in instructions 40 | ] 41 | 42 | return irreps_out, instructions 43 | 44 | 45 | def linear_out_irreps(irreps: o3.Irreps, target_irreps: o3.Irreps) -> o3.Irreps: 46 | # Assuming simplified irreps 47 | irreps_mid = [] 48 | for _, ir_in in irreps: 49 | found = False 50 | 51 | for mul, ir_out in target_irreps: 52 | if ir_in == ir_out: 53 | irreps_mid.append((mul, ir_out)) 54 | found = True 55 | break 56 | 57 | if not found: 58 | raise RuntimeError(f"{ir_in} not in {target_irreps}") 59 | 60 | return o3.Irreps(irreps_mid) 61 | 62 | 63 | @compile_mode("script") 64 | class reshape_irreps(torch.nn.Module): 65 | def __init__(self, irreps: o3.Irreps) -> None: 66 | super().__init__() 67 | self.irreps = irreps 68 | 69 | def forward(self, tensor: torch.Tensor) -> torch.Tensor: 70 | ix = 0 71 | out = [] 72 | batch, _ = tensor.shape 73 | for mul, ir in self.irreps: 74 | d = ir.dim 75 | field = tensor[:, ix : ix + mul * d] # [batch, sample, mul * repr] 76 | ix += mul * d 77 | field = field.reshape(batch, mul, d) 78 | out.append(field) 
79 | return torch.cat(out, dim=-1) 80 | 81 | 82 | def irreps2gate(irreps): 83 | irreps_scalars = [] 84 | irreps_gated = [] 85 | for mul, ir in irreps: 86 | if ir.l == 0 and ir.p == 1: 87 | irreps_scalars.append((mul, ir)) 88 | else: 89 | irreps_gated.append((mul, ir)) 90 | irreps_scalars = o3.Irreps(irreps_scalars).simplify() 91 | irreps_gated = o3.Irreps(irreps_gated).simplify() 92 | if irreps_gated.dim > 0: 93 | ir = '0e' 94 | else: 95 | ir = None 96 | irreps_gates = o3.Irreps([(mul, ir) for mul, _ in irreps_gated]).simplify() 97 | return irreps_scalars, irreps_gates, irreps_gated 98 | -------------------------------------------------------------------------------- /models/layers/tfn_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_scatter import scatter 3 | import e3nn 4 | 5 | from models.mace_modules.irreps_tools import irreps2gate 6 | 7 | 8 | class TensorProductConvLayer(torch.nn.Module): 9 | """Tensor Field Network GNN Layer in e3nn 10 | 11 | Implements a Tensor Field Network equivariant GNN layer for higher-order tensors, using e3nn. 12 | Implementation adapted from: https://github.com/gcorso/DiffDock/ 13 | 14 | Paper: Tensor Field Networks, Thomas, Smidt et al. 
15 | """ 16 | def __init__( 17 | self, 18 | in_irreps, 19 | out_irreps, 20 | sh_irreps, 21 | edge_feats_dim, 22 | mlp_dim, 23 | aggr="add", 24 | batch_norm=False, 25 | gate=False, 26 | ): 27 | """ 28 | Args: 29 | in_irreps: (e3nn.o3.Irreps) Input irreps dimensions 30 | out_irreps: (e3nn.o3.Irreps) Output irreps dimensions 31 | sh_irreps: (e3nn.o3.Irreps) Spherical harmonic irreps dimensions 32 | edge_feats_dim: (int) Edge feature dimensions 33 | mlp_dim: (int) Hidden dimension of MLP for computing tensor product weights 34 | aggr: (str) Message passing aggregator 35 | batch_norm: (bool) Whether to apply equivariant batch norm 36 | gate: (bool) Whether to apply gated non-linearity 37 | """ 38 | super().__init__() 39 | self.in_irreps = in_irreps 40 | self.out_irreps = out_irreps 41 | self.sh_irreps = sh_irreps 42 | self.edge_feats_dim = edge_feats_dim 43 | self.aggr = aggr 44 | 45 | if gate: 46 | # Optionally apply gated non-linearity 47 | irreps_scalars, irreps_gates, irreps_gated = irreps2gate( 48 | e3nn.o3.Irreps(out_irreps) 49 | ) 50 | act_scalars = [torch.nn.functional.silu for _, ir in irreps_scalars] 51 | act_gates = [torch.sigmoid for _, ir in irreps_gates] 52 | if irreps_gated.num_irreps == 0: 53 | self.gate = e3nn.nn.Activation(out_irreps, acts=[torch.nn.functional.silu]) 54 | else: 55 | self.gate = e3nn.nn.Gate( 56 | irreps_scalars, 57 | act_scalars, # scalar 58 | irreps_gates, 59 | act_gates, # gates (scalars) 60 | irreps_gated, # gated tensors 61 | ) 62 | # Output irreps for the tensor product must be updated 63 | self.out_irreps = out_irreps = self.gate.irreps_in 64 | else: 65 | self.gate = None 66 | 67 | # Tensor product over edges to construct messages 68 | self.tp = e3nn.o3.FullyConnectedTensorProduct( 69 | in_irreps, sh_irreps, out_irreps, shared_weights=False 70 | ) 71 | 72 | # MLP used to compute weights of tensor product 73 | self.fc = torch.nn.Sequential( 74 | torch.nn.Linear(edge_feats_dim, mlp_dim), 75 | torch.nn.ReLU(), 76 | 
torch.nn.Linear(mlp_dim, self.tp.weight_numel), 77 | ) 78 | 79 | # Optional equivariant batch norm 80 | self.batch_norm = e3nn.nn.BatchNorm(out_irreps) if batch_norm else None 81 | 82 | def forward(self, node_attr, edge_index, edge_sh, edge_feat): 83 | src, dst = edge_index 84 | # Compute messages 85 | tp = self.tp(node_attr[dst], edge_sh, self.fc(edge_feat)) 86 | # Aggregate messages 87 | out = scatter(tp, src, dim=0, reduce=self.aggr) 88 | # Optionally apply gated non-linearity and/or batch norm 89 | if self.gate: 90 | out = self.gate(out) 91 | if self.batch_norm: 92 | out = self.batch_norm(out) 93 | return out 94 | -------------------------------------------------------------------------------- /models/dimenet.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Union 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | from torch_geometric.nn import DimeNetPlusPlus 6 | from torch_scatter import scatter 7 | 8 | 9 | class DimeNetPPModel(DimeNetPlusPlus): 10 | """ 11 | DimeNet model from "Directional message passing for molecular graphs". 12 | 13 | This class extends the DimeNetPlusPlus base class for PyG. 14 | """ 15 | def __init__( 16 | self, 17 | hidden_channels: int = 128, 18 | in_dim: int = 1, 19 | out_dim: int = 1, 20 | num_layers: int = 4, 21 | int_emb_size: int = 64, 22 | basis_emb_size: int = 8, 23 | out_emb_channels: int = 256, 24 | num_spherical: int = 7, 25 | num_radial: int = 6, 26 | cutoff: float = 10, 27 | max_num_neighbors: int = 32, 28 | envelope_exponent: int = 5, 29 | num_before_skip: int = 1, 30 | num_after_skip: int = 2, 31 | num_output_layers: int = 3, 32 | act: Union[str, Callable] = 'swish' 33 | ): 34 | """ 35 | Initializes an instance of the DimeNetPPModel class with the provided parameters. 
36 | 37 | Parameters: 38 | - hidden_channels (int): Number of channels in the hidden layers (default: 128) 39 | - in_dim (int): Input dimension of the model (default: 1) 40 | - out_dim (int): Output dimension of the model (default: 1) 41 | - num_layers (int): Number of layers in the model (default: 4) 42 | - int_emb_size (int): Embedding size for interaction features (default: 64) 43 | - basis_emb_size (int): Embedding size for basis functions (default: 8) 44 | - out_emb_channels (int): Number of channels in the output embeddings (default: 256) 45 | - num_spherical (int): Number of spherical harmonics (default: 7) 46 | - num_radial (int): Number of radial basis functions (default: 6) 47 | - cutoff (float): Cutoff distance for interactions (default: 10) 48 | - max_num_neighbors (int): Maximum number of neighboring atoms to consider (default: 32) 49 | - envelope_exponent (int): Exponent of the envelope function (default: 5) 50 | - num_before_skip (int): Number of layers before the skip connections (default: 1) 51 | - num_after_skip (int): Number of layers after the skip connections (default: 2) 52 | - num_output_layers (int): Number of output layers (default: 3) 53 | - act (Union[str, Callable]): Activation function (default: 'swish' or callable) 54 | 55 | Note: 56 | - The `act` parameter can be either a string representing a built-in activation function, 57 | or a callable object that serves as a custom activation function. 58 | """ 59 | super().__init__( 60 | hidden_channels, 61 | out_dim, 62 | num_layers, 63 | int_emb_size, 64 | basis_emb_size, 65 | out_emb_channels, 66 | num_spherical, 67 | num_radial, 68 | cutoff, 69 | max_num_neighbors, 70 | envelope_exponent, 71 | num_before_skip, 72 | num_after_skip, 73 | num_output_layers, 74 | act 75 | ) 76 | 77 | def forward(self, batch): 78 | 79 | i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = self.triplets( 80 | batch.edge_index, num_nodes=batch.atoms.size(0)) 81 | 82 | # Calculate distances. 
83 | dist = (batch.pos[i] - batch.pos[j]).pow(2).sum(dim=-1).sqrt() 84 | 85 | # Calculate angles. 86 | pos_i = batch.pos[idx_i] 87 | pos_ji, pos_ki = batch.pos[idx_j] - pos_i, batch.pos[idx_k] - pos_i 88 | a = (pos_ji * pos_ki).sum(dim=-1) 89 | b = torch.cross(pos_ji, pos_ki).norm(dim=-1) 90 | angle = torch.atan2(b, a) 91 | 92 | rbf = self.rbf(dist) 93 | sbf = self.sbf(dist, angle, idx_kj) 94 | 95 | # Embedding block. 96 | x = self.emb(batch.atoms, rbf, i, j) 97 | P = self.output_blocks[0](x, rbf, i, num_nodes=batch.pos.size(0)) 98 | 99 | # Interaction blocks. 100 | for interaction_block, output_block in zip(self.interaction_blocks, 101 | self.output_blocks[1:]): 102 | x = interaction_block(x, rbf, sbf, idx_kj, idx_ji) 103 | P += output_block(x, rbf, i) 104 | 105 | return P.sum(dim=0) if batch is None else scatter(P, batch.batch, dim=0) 106 | -------------------------------------------------------------------------------- /experiments/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import random 3 | from tqdm.autonotebook import tqdm # from tqdm import tqdm 4 | import numpy as np 5 | from sklearn.metrics import accuracy_score 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | 11 | def seed(seed=0): 12 | random.seed(seed) 13 | np.random.seed(seed) 14 | torch.manual_seed(seed) 15 | torch.cuda.manual_seed(seed) 16 | torch.cuda.manual_seed_all(seed) 17 | torch.backends.cudnn.deterministic = True 18 | torch.backends.cudnn.benchmark = False 19 | 20 | 21 | def train(model, train_loader, optimizer, device): 22 | model.train() 23 | loss_all = 0 24 | for batch in train_loader: 25 | batch = batch.to(device) 26 | optimizer.zero_grad() 27 | y_pred = model(batch) 28 | loss = F.cross_entropy(y_pred, batch.y) 29 | loss.backward() 30 | loss_all += loss.item() * batch.num_graphs 31 | optimizer.step() 32 | return loss_all / len(train_loader.dataset) 33 | 34 | 35 | def eval(model, loader, 
device): 36 | model.eval() 37 | y_pred = [] 38 | y_true = [] 39 | for batch in loader: 40 | batch = batch.to(device) 41 | with torch.no_grad(): 42 | y_pred.append(model(batch).detach().cpu()) 43 | y_true.append(batch.y.detach().cpu()) 44 | return accuracy_score( 45 | torch.concat(y_true, dim=0), 46 | np.argmax(torch.concat(y_pred, dim=0), axis=1) 47 | ) * 100 # return percentage 48 | 49 | 50 | def _run_experiment(model, train_loader, val_loader, test_loader, n_epochs=100, verbose=True, device='cpu'): 51 | total_param = 0 52 | for param in model.parameters(): 53 | total_param += np.prod(list(param.data.size())) 54 | model = model.to(device) 55 | 56 | optimizer = torch.optim.Adam(model.parameters(), lr=0.0001) 57 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 58 | optimizer, mode='max', factor=0.9, patience=25, min_lr=0.00001) 59 | 60 | if verbose: 61 | print(f"Running experiment for {type(model).__name__}.") 62 | # print("\nModel architecture:") 63 | # print(model) 64 | print(f'Total parameters: {total_param}') 65 | print("\nStart training:") 66 | 67 | best_val_acc = None 68 | perf_per_epoch = [] # Track Test/Val performace vs. epoch (for plotting) 69 | t = time.time() 70 | for epoch in range(1, n_epochs+1): 71 | # Train model for one epoch, return avg. 
training loss 72 | loss = train(model, train_loader, optimizer, device) 73 | 74 | # Evaluate model on validation set 75 | val_acc = eval(model, val_loader, device) 76 | 77 | if best_val_acc is None or val_acc >= best_val_acc: 78 | # Evaluate model on test set if validation metric improves 79 | test_acc = eval(model, test_loader, device) 80 | best_val_acc = val_acc 81 | 82 | if epoch % 10 == 0 and verbose: 83 | print(f'Epoch: {epoch:03d}, LR: {lr:.5f}, Loss: {loss:.5f}, ' 84 | f'Val Acc: {val_acc:.3f}, Test Acc: {test_acc:.3f}') 85 | 86 | perf_per_epoch.append((test_acc, val_acc, epoch, type(model).__name__)) 87 | scheduler.step(val_acc) 88 | lr = optimizer.param_groups[0]['lr'] 89 | 90 | t = time.time() - t 91 | train_time = t 92 | if verbose: 93 | print(f"\nDone! Training took {train_time:.2f}s. Best validation accuracy: {best_val_acc:.3f}, corresponding test accuracy: {test_acc:.3f}.") 94 | 95 | return best_val_acc, test_acc, train_time, perf_per_epoch 96 | 97 | 98 | def run_experiment(model, train_loader, val_loader, test_loader, n_epochs=100, n_times=100, verbose=False, device='cpu'): 99 | print(f"Running experiment for {type(model).__name__} ({device}).") 100 | 101 | best_val_acc_list = [] 102 | test_acc_list = [] 103 | train_time_list = [] 104 | for idx in tqdm(range(n_times)): 105 | seed(idx) # set random seed 106 | best_val_acc, test_acc, train_time, _ = _run_experiment(model, train_loader, val_loader, test_loader, n_epochs, verbose, device) 107 | best_val_acc_list.append(best_val_acc) 108 | test_acc_list.append(test_acc) 109 | train_time_list.append(train_time) 110 | 111 | print(f'\nDone! Averaged over {n_times} runs: \n ' 112 | f'- Training time: {np.mean(train_time_list):.2f}s ± {np.std(train_time_list):.2f}. \n ' 113 | f'- Best validation accuracy: {np.mean(best_val_acc_list):.3f} ± {np.std(best_val_acc_list):.3f}. \n' 114 | f'- Test accuracy: {np.mean(test_acc_list):.1f} ± {np.std(test_acc_list):.1f}. 
\n') 115 | 116 | return best_val_acc_list, test_acc_list, train_time_list 117 | -------------------------------------------------------------------------------- /models/mace_modules/cg.py: -------------------------------------------------------------------------------- 1 | ########################################################################################### 2 | # Higher Order Real Clebsch Gordan (based on e3nn by Mario Geiger) 3 | # Authors: Ilyes Batatia 4 | # This program is distributed under the MIT License (see MIT.md) 5 | ########################################################################################### 6 | 7 | import collections 8 | from typing import List, Union 9 | 10 | import torch 11 | from e3nn import o3 12 | 13 | # Based on e3nn 14 | 15 | _TP = collections.namedtuple("_TP", "op, args") 16 | _INPUT = collections.namedtuple("_INPUT", "tensor, start, stop") 17 | 18 | 19 | def _wigner_nj( 20 | irrepss: List[o3.Irreps], 21 | normalization: str = "component", 22 | filter_ir_mid=None, 23 | dtype=None, 24 | ): 25 | irrepss = [o3.Irreps(irreps) for irreps in irrepss] 26 | if filter_ir_mid is not None: 27 | filter_ir_mid = [o3.Irrep(ir) for ir in filter_ir_mid] 28 | 29 | if len(irrepss) == 1: 30 | (irreps,) = irrepss 31 | ret = [] 32 | e = torch.eye(irreps.dim, dtype=dtype) 33 | i = 0 34 | for mul, ir in irreps: 35 | for _ in range(mul): 36 | sl = slice(i, i + ir.dim) 37 | ret += [(ir, _INPUT(0, sl.start, sl.stop), e[sl])] 38 | i += ir.dim 39 | return ret 40 | 41 | *irrepss_left, irreps_right = irrepss 42 | ret = [] 43 | for ir_left, path_left, C_left in _wigner_nj( 44 | irrepss_left, 45 | normalization=normalization, 46 | filter_ir_mid=filter_ir_mid, 47 | dtype=dtype, 48 | ): 49 | i = 0 50 | for mul, ir in irreps_right: 51 | for ir_out in ir_left * ir: 52 | if filter_ir_mid is not None and ir_out not in filter_ir_mid: 53 | continue 54 | 55 | C = o3.wigner_3j(ir_out.l, ir_left.l, ir.l, dtype=dtype) 56 | if normalization == "component": 57 | C *= 
ir_out.dim ** 0.5 58 | if normalization == "norm": 59 | C *= ir_left.dim ** 0.5 * ir.dim ** 0.5 60 | 61 | C = torch.einsum("jk,ijl->ikl", C_left.flatten(1), C) 62 | C = C.reshape( 63 | ir_out.dim, *(irreps.dim for irreps in irrepss_left), ir.dim 64 | ) 65 | for u in range(mul): 66 | E = torch.zeros( 67 | ir_out.dim, 68 | *(irreps.dim for irreps in irrepss_left), 69 | irreps_right.dim, 70 | dtype=dtype, 71 | ) 72 | sl = slice(i + u * ir.dim, i + (u + 1) * ir.dim) 73 | E[..., sl] = C 74 | ret += [ 75 | ( 76 | ir_out, 77 | _TP( 78 | op=(ir_left, ir, ir_out), 79 | args=( 80 | path_left, 81 | _INPUT(len(irrepss_left), sl.start, sl.stop), 82 | ), 83 | ), 84 | E, 85 | ) 86 | ] 87 | i += mul * ir.dim 88 | return sorted(ret, key=lambda x: x[0]) 89 | 90 | 91 | def U_matrix_real( 92 | irreps_in: Union[str, o3.Irreps], 93 | irreps_out: Union[str, o3.Irreps], 94 | correlation: int, 95 | normalization: str = "component", 96 | filter_ir_mid=None, 97 | dtype=None, 98 | ): 99 | irreps_out = o3.Irreps(irreps_out) 100 | irrepss = [o3.Irreps(irreps_in)] * correlation 101 | if correlation == 4: 102 | filter_ir_mid = [ 103 | (0, 1), 104 | (1, -1), 105 | (2, 1), 106 | (3, -1), 107 | (4, 1), 108 | (5, -1), 109 | (6, 1), 110 | (7, -1), 111 | (8, 1), 112 | (9, -1), 113 | (10, 1), 114 | (11, -1), 115 | ] 116 | wigners = _wigner_nj(irrepss, normalization, filter_ir_mid, dtype) 117 | current_ir = wigners[0][0] 118 | out = [] 119 | stack = torch.tensor([]) 120 | 121 | for ir, _, base_o3 in wigners: 122 | if ir in irreps_out and ir == current_ir: 123 | stack = torch.cat((stack, base_o3.squeeze().unsqueeze(-1)), dim=-1) 124 | last_ir = current_ir 125 | elif ir in irreps_out and ir != current_ir: 126 | if len(stack) != 0: 127 | out += [last_ir, stack] 128 | stack = base_o3.squeeze().unsqueeze(-1) 129 | current_ir, last_ir = ir, ir 130 | else: 131 | current_ir = ir 132 | out += [last_ir, stack] 133 | return out 134 | -------------------------------------------------------------------------------- 
/models/spherenet.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | from torch_scatter import scatter 6 | 7 | from models.layers.spherenet_layer import * 8 | 9 | 10 | class SphereNetModel(torch.nn.Module): 11 | """ 12 | SphereNet model from "Spherical Message Passing for 3D Molecular Graphs". 13 | """ 14 | def __init__( 15 | self, 16 | cutoff: float = 10, 17 | num_layers: int = 4, 18 | hidden_channels: int = 128, 19 | in_dim: int = 1, 20 | out_dim: int = 1, 21 | int_emb_size: int = 64, 22 | basis_emb_size_dist: int = 8, 23 | basis_emb_size_angle: int = 8, 24 | basis_emb_size_torsion: int = 8, 25 | out_emb_channels: int = 128, 26 | num_spherical: int = 7, 27 | num_radial: int = 6, 28 | envelope_exponent: int = 5, 29 | num_before_skip: int = 1, 30 | num_after_skip: int = 2, 31 | num_output_layers: int = 2, 32 | act: Callable = swish, 33 | output_init: str = 'GlorotOrthogonal', 34 | use_node_features: bool = True 35 | ): 36 | """ 37 | Initializes an instance of the SphereNetModel class with the following parameters: 38 | 39 | Parameters: 40 | - cutoff (int): Cutoff distance for interactions (default: 10) 41 | - num_layers (int): Number of layers in the model (default: 4) 42 | - hidden_channels (int): Number of channels in the hidden layers (default: 128) 43 | - in_dim (int): Input dimension of the model (default: 1) 44 | - out_dim (int): Output dimension of the model (default: 1) 45 | - int_emb_size (int): Embedding size for interaction features (default: 64) 46 | - basis_emb_size_dist (int): Embedding size for distance basis functions (default: 8) 47 | - basis_emb_size_angle (int): Embedding size for angle basis functions (default: 8) 48 | - basis_emb_size_torsion (int): Embedding size for torsion basis functions (default: 8) 49 | - out_emb_channels (int): Number of channels in the output embeddings (default: 128) 50 | - num_spherical 
(int): Number of spherical harmonics (default: 7) 51 | - num_radial (int): Number of radial basis functions (default: 6) 52 | - envelope_exponent (int): Exponent of the envelope function (default: 5) 53 | - num_before_skip (int): Number of layers before the skip connections (default: 1) 54 | - num_after_skip (int): Number of layers after the skip connections (default: 2) 55 | - num_output_layers (int): Number of output layers (default: 2) 56 | - act (function): Activation function (default: swish) 57 | - output_init (str): Initialization method for the output layer (default: 'GlorotOrthogonal') 58 | - use_node_features (bool): Whether to use node features (default: True) 59 | """ 60 | super().__init__() 61 | 62 | self.cutoff = cutoff 63 | 64 | self.init_e = init(num_radial, hidden_channels, act, use_node_features=use_node_features) 65 | self.init_v = update_v(hidden_channels, out_emb_channels, out_dim, num_output_layers, act, output_init) 66 | self.init_u = update_u() 67 | self.emb = emb(num_spherical, num_radial, self.cutoff, envelope_exponent) 68 | 69 | self.update_vs = torch.nn.ModuleList([ 70 | update_v(hidden_channels, out_emb_channels, out_dim, num_output_layers, act, output_init) for _ in range(num_layers)]) 71 | 72 | self.update_es = torch.nn.ModuleList([ 73 | update_e(hidden_channels, int_emb_size, basis_emb_size_dist, basis_emb_size_angle, basis_emb_size_torsion, num_spherical, num_radial, num_before_skip, num_after_skip,act) for _ in range(num_layers)]) 74 | 75 | self.update_us = torch.nn.ModuleList([update_u() for _ in range(num_layers)]) 76 | 77 | self.reset_parameters() 78 | 79 | def reset_parameters(self): 80 | self.init_e.reset_parameters() 81 | self.init_v.reset_parameters() 82 | self.emb.reset_parameters() 83 | for update_e in self.update_es: 84 | update_e.reset_parameters() 85 | for update_v in self.update_vs: 86 | update_v.reset_parameters() 87 | 88 | 89 | def forward(self, batch_data): 90 | z, pos, batch = batch_data.atoms, batch_data.pos, 
batch_data.batch 91 | edge_index = batch_data.edge_index 92 | num_nodes = z.size(0) 93 | dist, angle, torsion, i, j, idx_kj, idx_ji = xyz_to_dat(pos, edge_index, num_nodes, use_torsion=True) 94 | 95 | emb = self.emb(dist, angle, torsion, idx_kj) 96 | 97 | # Initialize edge, node, graph features 98 | e = self.init_e(z, emb, i, j) 99 | v = self.init_v(e, i) 100 | # Disable virutal node trick 101 | # u = self.init_u(torch.zeros_like(scatter(v, batch, dim=0)), v, batch) 102 | 103 | for update_e, update_v, update_u in zip(self.update_es, self.update_vs, self.update_us): 104 | e = update_e(e, emb, idx_kj, idx_ji) 105 | v = update_v(e, i) 106 | # Disable virutal node trick 107 | # u = update_u(u, v, batch) 108 | 109 | out = scatter(v, batch, dim=0, reduce='add') 110 | return out 111 | -------------------------------------------------------------------------------- /models/gvpgnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | from torch_geometric.nn import global_add_pool, global_mean_pool 4 | 5 | from models.mace_modules.blocks import RadialEmbeddingBlock 6 | import models.layers.gvp_layer as gvp 7 | 8 | 9 | class GVPGNNModel(torch.nn.Module): 10 | """ 11 | GVP-GNN model from "Equivariant Graph Neural Networks for 3D Macromolecular Structure". 12 | """ 13 | def __init__( 14 | self, 15 | r_max: float = 10.0, 16 | num_bessel: int = 8, 17 | num_polynomial_cutoff: int = 5, 18 | num_layers: int = 5, 19 | in_dim=1, 20 | out_dim=1, 21 | s_dim: int = 128, 22 | v_dim: int = 16, 23 | s_dim_edge: int = 32, 24 | v_dim_edge: int = 1, 25 | pool: str = "sum", 26 | residual: bool = True, 27 | equivariant_pred: bool = False 28 | ): 29 | """ 30 | Initializes an instance of the GVPGNNModel class with the provided parameters. 
31 | 32 | Parameters: 33 | - r_max (float): Maximum distance for Bessel basis functions (default: 10.0) 34 | - num_bessel (int): Number of Bessel basis functions (default: 8) 35 | - num_polynomial_cutoff (int): Number of polynomial cutoff basis functions (default: 5) 36 | - num_layers (int): Number of layers in the model (default: 5) 37 | - in_dim (int): Input dimension of the model (default: 1) 38 | - out_dim (int): Output dimension of the model (default: 1) 39 | - s_dim (int): Dimension of the node state embeddings (default: 128) 40 | - v_dim (int): Dimension of the node vector embeddings (default: 16) 41 | - s_dim_edge (int): Dimension of the edge state embeddings (default: 32) 42 | - v_dim_edge (int): Dimension of the edge vector embeddings (default: 1) 43 | - pool (str): Global pooling method to be used (default: "sum") 44 | - residual (bool): Whether to use residual connections (default: True) 45 | - equivariant_pred (bool): Whether it is an equivariant prediction task (default: False) 46 | """ 47 | super().__init__() 48 | 49 | self.r_max = r_max 50 | self.num_layers = num_layers 51 | self.equivariant_pred = equivariant_pred 52 | self.s_dim = s_dim 53 | self.v_dim = v_dim 54 | 55 | activations = (F.relu, None) 56 | _DEFAULT_V_DIM = (s_dim, v_dim) 57 | _DEFAULT_E_DIM = (s_dim_edge, v_dim_edge) 58 | 59 | # Node embedding 60 | self.emb_in = torch.nn.Embedding(in_dim, s_dim) 61 | self.W_v = torch.nn.Sequential( 62 | gvp.LayerNorm((s_dim, 0)), 63 | gvp.GVP((s_dim, 0), _DEFAULT_V_DIM, 64 | activations=(None, None), vector_gate=True) 65 | ) 66 | 67 | # Edge embedding 68 | self.radial_embedding = RadialEmbeddingBlock( 69 | r_max=r_max, 70 | num_bessel=num_bessel, 71 | num_polynomial_cutoff=num_polynomial_cutoff, 72 | ) 73 | self.W_e = torch.nn.Sequential( 74 | gvp.LayerNorm((self.radial_embedding.out_dim, 1)), 75 | gvp.GVP((self.radial_embedding.out_dim, 1), _DEFAULT_E_DIM, 76 | activations=(None, None), vector_gate=True) 77 | ) 78 | 79 | # Stack of GNN layers 80 | 
self.layers = torch.nn.ModuleList( 81 | gvp.GVPConvLayer( 82 | _DEFAULT_V_DIM, _DEFAULT_E_DIM, 83 | activations=activations, vector_gate=True, 84 | residual=residual 85 | ) 86 | for _ in range(num_layers) 87 | ) 88 | 89 | # Global pooling/readout function 90 | self.pool = {"mean": global_mean_pool, "sum": global_add_pool}[pool] 91 | 92 | if self.equivariant_pred: 93 | # Linear predictor for equivariant tasks using geometric features 94 | self.pred = torch.nn.Linear(s_dim + v_dim * 3, out_dim) 95 | else: 96 | # MLP predictor for invariant tasks using only scalar features 97 | self.pred = torch.nn.Sequential( 98 | torch.nn.Linear(s_dim, s_dim), 99 | torch.nn.ReLU(), 100 | torch.nn.Linear(s_dim, out_dim) 101 | ) 102 | 103 | def forward(self, batch): 104 | 105 | # Edge features 106 | vectors = batch.pos[batch.edge_index[0]] - batch.pos[batch.edge_index[1]] # [n_edges, 3] 107 | lengths = torch.linalg.norm(vectors, dim=-1, keepdim=True) # [n_edges, 1] 108 | 109 | h_V = self.emb_in(batch.atoms) # (n,) -> (n, d) 110 | h_E = ( 111 | self.radial_embedding(lengths), 112 | torch.nan_to_num(torch.div(vectors, lengths)).unsqueeze_(-2) 113 | ) 114 | 115 | h_V = self.W_v(h_V) 116 | h_E = self.W_e(h_E) 117 | 118 | for layer in self.layers: 119 | h_V = layer(h_V, batch.edge_index, h_E) 120 | 121 | out = self.pool(gvp._merge(*h_V), batch.batch) # (n, d) -> (batch_size, d) 122 | 123 | if not self.equivariant_pred: 124 | # Select only scalars for invariant prediction 125 | out = out[:,:self.s_dim] 126 | 127 | return self.pred(out) # (batch_size, out_dim) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ⚔️ Geometric GNN Dojo 2 | 3 | *Geometric GNN Dojo* is a pedagogical resource for beginners and experts to explore the design space of **Graph Neural Networks for geometric graphs**. 
4 | 5 | Check out the accompanying paper ['On the Expressive Power of Geometric Graph Neural Networks'](https://arxiv.org/abs/2301.09308), which studies the expressivity and theoretical limits of geometric GNNs. 6 | > Chaitanya K. Joshi*, Cristian Bodnar*, Simon V. Mathis, Taco Cohen, and Pietro Liò. On the Expressive Power of Geometric Graph Neural Networks. *International Conference on Machine Learning*. 7 | > 8 | >[PDF](https://arxiv.org/pdf/2301.09308.pdf) | [Slides](https://www.chaitjo.com/publication/joshi-2023-expressive/Geometric_GNNs_Slides.pdf) | [Video](https://youtu.be/5ulJMtpiKGc) 9 | 10 | ❓**New to geometric GNNs:** try our practical notebook on [*Geometric GNNs 101*](geometric_gnn_101.ipynb), prepared for MPhil students at the University of Cambridge. 11 | 12 | 13 | Open In Colab (recommended!) 14 | 15 | 16 | ## Architectures 17 | 18 | The `/models` directory provides unified implementations of several popular geometric GNN architectures: 19 | - Invariant GNNs: [SchNet](https://arxiv.org/abs/1706.08566), [DimeNet](https://arxiv.org/abs/2003.03123), [SphereNet](https://arxiv.org/abs/2102.05013) 20 | - Equivariant GNNs using cartesian vectors: [E(n) Equivariant GNN](https://proceedings.mlr.press/v139/satorras21a.html), [GVP-GNN](https://arxiv.org/abs/2009.01411) 21 | - Equivariant GNNs using spherical tensors: [Tensor Field Network](https://arxiv.org/abs/1802.08219), [MACE](http://arxiv.org/abs/2206.07697) 22 | - 🔥 Your new geometric GNN architecture? 23 | 24 |
25 | 26 | ## Experiments 27 | 28 | The `/experiments` directory contains notebooks with synthetic experiments to highlight practical challenges in building powerful geometric GNNs: 29 | - `kchains.ipynb`: Distinguishing k-chains, which test a model's ability to **propagate geometric information** non-locally and demonstrate oversquashing with increased depth/longer chains. 30 | - `rotsym.ipynb`: Rotationally symmetric structures, which test a layer's ability to **identify neighbourhood orientation** and highlight the utility of higher order tensors in equivariant GNNs. 31 | - `incompleteness.ipynb`: Counterexamples from [Pozdnyakov et al.](https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.125.166001), which test a layer's ability to create **distinguishing fingerprints for local neighbourhoods** and highlight the need for higher body order of local scalarisation (distances, angles, and beyond). 32 | 33 | 34 | 35 | ## Installation 36 | 37 | ```bash 38 | # Create new conda environment 39 | conda create --prefix ./env python=3.8 40 | conda activate ./env 41 | 42 | # Install PyTorch (Check CUDA version for GPU!) 
43 | # 44 | # Option 1: CPU 45 | conda install pytorch==1.12.0 -c pytorch 46 | # 47 | # Option 2: GPU, CUDA 11.3 48 | # conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch 49 | 50 | # Install dependencies 51 | conda install matplotlib pandas networkx 52 | conda install jupyterlab -c conda-forge 53 | pip install e3nn==0.4.4 ipdb ase 54 | 55 | # Install PyG (Check CPU/GPU/MacOS) 56 | # 57 | # Option 1: CPU, MacOS 58 | pip install torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-1.12.0+cpu.html 59 | pip install torch-geometric 60 | # 61 | # Option 2: GPU, CUDA 11.3 62 | # pip install torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-1.12.1+cu113.html 63 | # pip install torch-geometric 64 | # 65 | # Option 3: CPU/GPU, but may not work on MacOS 66 | # conda install pyg -c pyg 67 | ``` 68 | 69 | 70 | ## Directory Structure and Usage 71 | 72 | ``` 73 | . 74 | ├── README.md 75 | | 76 | ├── geometric_gnn_101.ipynb # A gentle introduction to Geometric GNNs 77 | | 78 | ├── experiments # Synthetic experiments 79 | | | 80 | │ ├── kchains.ipynb # Experiment on k-chains 81 | │ ├── rotsym.ipynb # Experiment on rotationally symmetric structures 82 | │ ├── incompleteness.ipynb # Experiment on counterexamples from Pozdnyakov et al. 83 | | └── utils # Helper functions for training, plotting, etc. 84 | | 85 | └── models # Geometric GNN models library 86 | | 87 | ├── schnet.py # SchNet model 88 | ├── dimenet.py # DimeNet model 89 | ├── spherenet.py # SphereNet model 90 | ├── egnn.py # E(n) Equivariant GNN model 91 | ├── gvpgnn.py # GVP-GNN model 92 | ├── tfn.py # Tensor Field Network model 93 | ├── mace.py # MACE model 94 | ├── layers # Layers for each model 95 | └── mace_modules # Modules and layers for MACE 96 | ``` 97 | 98 | 99 | ## Contact 100 | 101 | Authors: Chaitanya K. Joshi (chaitanya.joshi@cl.cam.ac.uk), Simon V. Mathis (simon.mathis@cl.cam.ac.uk).
102 | We welcome your questions and feedback via email or GitHub Issues. 103 | 104 | 105 | ## Citation 106 | 107 | ``` 108 | @inproceedings{joshi2023expressive, 109 | title={On the Expressive Power of Geometric Graph Neural Networks}, 110 | author={Joshi, Chaitanya K. and Bodnar, Cristian and Mathis, Simon V. and Cohen, Taco and Liò, Pietro}, 111 | booktitle={International Conference on Machine Learning}, 112 | year={2023}, 113 | } 114 | ``` -------------------------------------------------------------------------------- /models/layers/egnn_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Linear, ReLU, SiLU, Sequential 3 | from torch_geometric.nn import MessagePassing, global_add_pool, global_mean_pool 4 | from torch_scatter import scatter 5 | 6 | 7 | class EGNNLayer(MessagePassing): 8 | """E(n) Equivariant GNN Layer 9 | 10 | Paper: E(n) Equivariant Graph Neural Networks, Satorras et al. 11 | """ 12 | def __init__(self, emb_dim, activation="relu", norm="layer", aggr="add"): 13 | """ 14 | Args: 15 | emb_dim: (int) - hidden dimension `d` 16 | activation: (str) - non-linearity within MLPs (swish/relu) 17 | norm: (str) - normalisation layer (layer/batch) 18 | aggr: (str) - aggregation function `\oplus` (sum/mean/max) 19 | """ 20 | # Set the aggregation function 21 | super().__init__(aggr=aggr) 22 | 23 | self.emb_dim = emb_dim 24 | self.activation = {"swish": SiLU(), "relu": ReLU()}[activation] 25 | self.norm = {"layer": torch.nn.LayerNorm, "batch": torch.nn.BatchNorm1d}[norm] 26 | 27 | # MLP `\psi_h` for computing messages `m_ij` 28 | self.mlp_msg = Sequential( 29 | Linear(2 * emb_dim + 1, emb_dim), 30 | self.norm(emb_dim), 31 | self.activation, 32 | Linear(emb_dim, emb_dim), 33 | self.norm(emb_dim), 34 | self.activation, 35 | ) 36 | # MLP `\psi_x` for computing messages `\overrightarrow{m}_ij` 37 | self.mlp_pos = Sequential( 38 | Linear(emb_dim, emb_dim), self.norm(emb_dim), 
self.activation, Linear(emb_dim, 1) 39 | ) 40 | # MLP `\phi` for computing updated node features `h_i^{l+1}` 41 | self.mlp_upd = Sequential( 42 | Linear(2 * emb_dim, emb_dim), 43 | self.norm(emb_dim), 44 | self.activation, 45 | Linear(emb_dim, emb_dim), 46 | self.norm(emb_dim), 47 | self.activation, 48 | ) 49 | 50 | def forward(self, h, pos, edge_index): 51 | """ 52 | Args: 53 | h: (n, d) - initial node features 54 | pos: (n, 3) - initial node coordinates 55 | edge_index: (e, 2) - pairs of edges (i, j) 56 | Returns: 57 | out: [(n, d),(n,3)] - updated node features 58 | """ 59 | out = self.propagate(edge_index, h=h, pos=pos) 60 | return out 61 | 62 | def message(self, h_i, h_j, pos_i, pos_j): 63 | # Compute messages 64 | pos_diff = pos_i - pos_j 65 | dists = torch.norm(pos_diff, dim=-1).unsqueeze(1) 66 | msg = torch.cat([h_i, h_j, dists], dim=-1) 67 | msg = self.mlp_msg(msg) 68 | # Scale magnitude of displacement vector 69 | pos_diff = pos_diff * self.mlp_pos(msg) 70 | # NOTE: some papers divide pos_diff by (dists + 1) to stabilise model. 71 | # NOTE: lucidrains clamps pos_diff between some [-n, +n], also for stability. 
72 | return msg, pos_diff 73 | 74 | def aggregate(self, inputs, index): 75 | msgs, pos_diffs = inputs 76 | # Aggregate messages 77 | msg_aggr = scatter(msgs, index, dim=self.node_dim, reduce=self.aggr) 78 | # Aggregate displacement vectors 79 | pos_aggr = scatter(pos_diffs, index, dim=self.node_dim, reduce="mean") 80 | return msg_aggr, pos_aggr 81 | 82 | def update(self, aggr_out, h, pos): 83 | msg_aggr, pos_aggr = aggr_out 84 | upd_out = self.mlp_upd(torch.cat([h, msg_aggr], dim=-1)) 85 | upd_pos = pos + pos_aggr 86 | return upd_out, upd_pos 87 | 88 | def __repr__(self) -> str: 89 | return f"{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})" 90 | 91 | 92 | class MPNNLayer(MessagePassing): 93 | def __init__(self, emb_dim, activation="relu", norm="layer", aggr="add"): 94 | """Vanilla Message Passing GNN layer 95 | 96 | Args: 97 | emb_dim: (int) - hidden dimension `d` 98 | activation: (str) - non-linearity within MLPs (swish/relu) 99 | norm: (str) - normalisation layer (layer/batch) 100 | aggr: (str) - aggregation function `\oplus` (sum/mean/max) 101 | """ 102 | # Set the aggregation function 103 | super().__init__(aggr=aggr) 104 | 105 | self.emb_dim = emb_dim 106 | self.activation = {"swish": SiLU(), "relu": ReLU()}[activation] 107 | self.norm = {"layer": torch.nn.LayerNorm, "batch": torch.nn.BatchNorm1d}[norm] 108 | 109 | # MLP `\psi_h` for computing messages `m_ij` 110 | self.mlp_msg = Sequential( 111 | Linear(2 * emb_dim, emb_dim), 112 | self.norm(emb_dim), 113 | self.activation, 114 | Linear(emb_dim, emb_dim), 115 | self.norm(emb_dim), 116 | self.activation, 117 | ) 118 | # MLP `\phi` for computing updated node features `h_i^{l+1}` 119 | self.mlp_upd = Sequential( 120 | Linear(2 * emb_dim, emb_dim), 121 | self.norm(emb_dim), 122 | self.activation, 123 | Linear(emb_dim, emb_dim), 124 | self.norm(emb_dim), 125 | self.activation, 126 | ) 127 | 128 | def forward(self, h, edge_index): 129 | """ 130 | Args: 131 | h: (n, d) - initial node features 132 
| edge_index: (e, 2) - pairs of edges (i, j) 133 | Returns: 134 | out: (n, d) - updated node features 135 | """ 136 | out = self.propagate(edge_index, h=h) 137 | return out 138 | 139 | def message(self, h_i, h_j): 140 | # Compute messages 141 | msg = torch.cat([h_i, h_j], dim=-1) 142 | msg = self.mlp_msg(msg) 143 | return msg 144 | 145 | def aggregate(self, inputs, index): 146 | # Aggregate messages 147 | msg_aggr = scatter(inputs, index, dim=self.node_dim, reduce=self.aggr) 148 | return msg_aggr 149 | 150 | def update(self, aggr_out, h): 151 | upd_out = self.mlp_upd(torch.cat([h, aggr_out], dim=-1)) 152 | return upd_out 153 | 154 | def __repr__(self) -> str: 155 | return f"{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})" 156 | -------------------------------------------------------------------------------- /models/tfn.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | from torch_geometric.nn import global_add_pool, global_mean_pool 6 | import e3nn 7 | 8 | from models.mace_modules.blocks import RadialEmbeddingBlock 9 | from models.layers.tfn_layer import TensorProductConvLayer 10 | 11 | 12 | class TFNModel(torch.nn.Module): 13 | """ 14 | Tensor Field Network model from "Tensor Field Networks". 
15 | """ 16 | def __init__( 17 | self, 18 | r_max: float = 10.0, 19 | num_bessel: int = 8, 20 | num_polynomial_cutoff: int = 5, 21 | max_ell: int = 2, 22 | num_layers: int = 5, 23 | emb_dim: int = 64, 24 | hidden_irreps: Optional[e3nn.o3.Irreps] = None, 25 | mlp_dim: int = 256, 26 | in_dim: int = 1, 27 | out_dim: int = 1, 28 | aggr: str = "sum", 29 | pool: str = "sum", 30 | gate: bool = True, 31 | batch_norm: bool = False, 32 | residual: bool = True, 33 | equivariant_pred: bool = False 34 | ): 35 | """ 36 | Parameters: 37 | - r_max (float): Maximum distance for Bessel basis functions (default: 10.0) 38 | - num_bessel (int): Number of Bessel basis functions (default: 8) 39 | - num_polynomial_cutoff (int): Number of polynomial cutoff basis functions (default: 5) 40 | - max_ell (int): Maximum degree of spherical harmonics basis functions (default: 2) 41 | - num_layers (int): Number of layers in the model (default: 5) 42 | - emb_dim (int): Scalar feature embedding dimension (default: 64) 43 | - hidden_irreps (Optional[e3nn.o3.Irreps]): Hidden irreps (default: None) 44 | - mlp_dim (int): Dimension of MLP for computing tensor product weights (default: 256) 45 | - in_dim (int): Input dimension of the model (default: 1) 46 | - out_dim (int): Output dimension of the model (default: 1) 47 | - aggr (str): Aggregation method to be used (default: "sum") 48 | - pool (str): Global pooling method to be used (default: "sum") 49 | - gate (bool): Whether to use gated equivariant non-linearity (default: True) 50 | - batch_norm (bool): Whether to use batch normalization (default: False) 51 | - residual (bool): Whether to use residual connections (default: True) 52 | - equivariant_pred (bool): Whether it is an equivariant prediction task (default: False) 53 | 54 | Note: 55 | - If `hidden_irreps` is None, the irreps for the intermediate features are computed 56 | using `emb_dim` and `max_ell`. 
57 | - The `equivariant_pred` parameter determines whether it is an equivariant prediction task. 58 | If set to True, equivariant prediction will be performed. 59 | - At present, only one of `gate` and `batch_norm` can be True. 60 | """ 61 | super().__init__() 62 | 63 | self.r_max = r_max 64 | self.max_ell = max_ell 65 | self.num_layers = num_layers 66 | self.emb_dim = emb_dim 67 | self.mlp_dim = mlp_dim 68 | self.residual = residual 69 | self.batch_norm = batch_norm 70 | self.gate = gate 71 | self.hidden_irreps = hidden_irreps 72 | self.equivariant_pred = equivariant_pred 73 | 74 | # Edge embedding 75 | self.radial_embedding = RadialEmbeddingBlock( 76 | r_max=r_max, 77 | num_bessel=num_bessel, 78 | num_polynomial_cutoff=num_polynomial_cutoff, 79 | ) 80 | sh_irreps = e3nn.o3.Irreps.spherical_harmonics(max_ell) 81 | self.spherical_harmonics = e3nn.o3.SphericalHarmonics( 82 | sh_irreps, normalize=True, normalization="component" 83 | ) 84 | 85 | # Embedding lookup for initial node features 86 | self.emb_in = torch.nn.Embedding(in_dim, emb_dim) 87 | 88 | # Set hidden irreps if none are provided 89 | if hidden_irreps is None: 90 | hidden_irreps = (sh_irreps * emb_dim).sort()[0].simplify() 91 | # Note: This defaults to O(3) equivariant layers 92 | # It is possible to use SO(3) equivariance by passing the appropriate irreps 93 | 94 | self.convs = torch.nn.ModuleList() 95 | # First conv layer: scalar only -> tensor 96 | self.convs.append( 97 | TensorProductConvLayer( 98 | in_irreps=e3nn.o3.Irreps(f'{emb_dim}x0e'), 99 | out_irreps=hidden_irreps, 100 | sh_irreps=sh_irreps, 101 | edge_feats_dim=self.radial_embedding.out_dim, 102 | mlp_dim=mlp_dim, 103 | aggr=aggr, 104 | batch_norm=batch_norm, 105 | gate=gate, 106 | ) 107 | ) 108 | # Intermediate conv layers: tensor -> tensor 109 | for _ in range(num_layers - 1): 110 | conv = TensorProductConvLayer( 111 | in_irreps=hidden_irreps, 112 | out_irreps=hidden_irreps, 113 | sh_irreps=sh_irreps, 114 | 
edge_feats_dim=self.radial_embedding.out_dim, 115 | mlp_dim=mlp_dim, 116 | aggr=aggr, 117 | batch_norm=batch_norm, 118 | gate=gate, 119 | ) 120 | self.convs.append(conv) 121 | 122 | # Global pooling/readout function 123 | self.pool = {"mean": global_mean_pool, "sum": global_add_pool}[pool] 124 | 125 | if self.equivariant_pred: 126 | # Linear predictor for equivariant tasks using geometric features 127 | self.pred = torch.nn.Linear(hidden_irreps.dim, out_dim) 128 | else: 129 | # MLP predictor for invariant tasks using only scalar features 130 | self.pred = torch.nn.Sequential( 131 | torch.nn.Linear(emb_dim, emb_dim), 132 | torch.nn.ReLU(), 133 | torch.nn.Linear(emb_dim, out_dim) 134 | ) 135 | 136 | def forward(self, batch): 137 | # Node embedding 138 | h = self.emb_in(batch.atoms) # (n,) -> (n, d) 139 | 140 | # Edge features 141 | vectors = batch.pos[batch.edge_index[0]] - batch.pos[batch.edge_index[1]] # [n_edges, 3] 142 | lengths = torch.linalg.norm(vectors, dim=-1, keepdim=True) # [n_edges, 1] 143 | 144 | edge_sh = self.spherical_harmonics(vectors) 145 | edge_feats = self.radial_embedding(lengths) 146 | 147 | for conv in self.convs: 148 | # Message passing layer 149 | h_update = conv(h, batch.edge_index, edge_sh, edge_feats) 150 | 151 | # Update node features 152 | h = h_update + F.pad(h, (0, h_update.shape[-1] - h.shape[-1])) if self.residual else h_update 153 | 154 | out = self.pool(h, batch.batch) # (n, d) -> (batch_size, d) 155 | 156 | if not self.equivariant_pred: 157 | # Select only scalars for invariant prediction 158 | out = out[:,:self.emb_dim] 159 | 160 | return self.pred(out) # (batch_size, out_dim) 161 | -------------------------------------------------------------------------------- /experiments/rotsym.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "# Identifying neighbourhood orientation: 
rotationally symmetric structures\n", 9 | "\n", 10 | "*Background:*\n", 11 | "Rotationally equivariant geometric GNNs aggregate local geometric information via summing together the neighbourhood geometric features, which are either **cartesian vectors** or **higher order spherical tensors**. \n", 12 | "The ideal geometric GNN would injectively aggregate local geometric information to perfectly identify neighbourhood identities, orientations, etc.\n", 13 | "In practice, the choice of basis (cartesian vs. spherical) comes with tradeoffs between tractability and empirical performance.\n", 14 | "\n", 15 | "*Experiment:*\n", 16 | "In this notebook, we study how rotational symmetries interact with tensor order in equivariant GNNs. \n", 17 | "We evaluate equivariant layers on their ability to distinguish the orientation of **structures with rotational symmetry**. \n", 18 | "An [$L$-fold symmetric structure](https://en.wikipedia.org/wiki/Rotational_symmetry) does not change when rotated by an angle $\\frac{2\\pi}{L}$ around a point (in 2D) or axis (3D).\n", 19 | "We consider two *distinct* rotated versions of each $L$-fold symmetric structure and train single layer equivariant GNNs to classify the two orientations using the updated geometric features.\n", 20 | "\n", 21 | "![Rotationally symmetric structures](fig/rotsym.png)\n", 22 | "\n", 23 | "*Result:*\n", 24 | "- **We find that layers using order $L$ tensors are unable to identify the orientation of structures with rotational symmetry higher than $L$-fold.** This observation may be attributed to **spherical harmonics**, which serve as an orthonormal basis for spherical tensor features and exhibit rotational symmetry themselves.\n", 25 | "- Layers such as E-GNN and GVP-GNN using **cartesian vectors** (corresponding to tensor order 1) are popular as working with higher order tensors can be computationally intractable for many applications. 
However, E-GNN and GVP-GNN are particularly poor at disciminating orientation of rotationally symmetric structures. " 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "%load_ext autoreload\n", 35 | "%autoreload 2\n", 36 | "\n", 37 | "import sys\n", 38 | "sys.path.append('../')\n", 39 | "\n", 40 | "import random\n", 41 | "import math\n", 42 | "import torch\n", 43 | "import torch_geometric\n", 44 | "from torch_geometric.data import Data\n", 45 | "from torch_geometric.loader import DataLoader\n", 46 | "from torch_geometric.utils import to_undirected\n", 47 | "import e3nn\n", 48 | "from functools import partial\n", 49 | "\n", 50 | "print(\"PyTorch version {}\".format(torch.__version__))\n", 51 | "print(\"PyG version {}\".format(torch_geometric.__version__))\n", 52 | "print(\"e3nn version {}\".format(e3nn.__version__))\n", 53 | "\n", 54 | "from experiments.utils.plot_utils import plot_2d\n", 55 | "from experiments.utils.train_utils import run_experiment\n", 56 | "from models import SchNetModel, DimeNetPPModel, SphereNetModel, EGNNModel, GVPGNNModel, TFNModel, MACEModel\n", 57 | "\n", 58 | "# Set the device\n", 59 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", 60 | "print(f\"Using device: {device}\")" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "def create_rotsym_envs(fold=3):\n", 70 | " dataset = []\n", 71 | "\n", 72 | " # Environment 0\n", 73 | " atoms = torch.LongTensor([ 0 ] + [ 0 ] * fold)\n", 74 | " edge_index = torch.LongTensor( [ [0] * fold, [i for i in range(1, fold+1)] ] )\n", 75 | " x = torch.Tensor([1,0,0])\n", 76 | " pos = [\n", 77 | " torch.Tensor([0,0,0]), # origin\n", 78 | " x, # first spoke \n", 79 | " ]\n", 80 | " for count in range(1, fold):\n", 81 | " R = e3nn.o3.matrix_z(torch.Tensor([2*math.pi/fold * count])).squeeze(0)\n", 82 | " 
pos.append(x @ R.T)\n", 83 | " pos = torch.stack(pos)\n", 84 | " y = torch.LongTensor([0]) # Label 0\n", 85 | " data1 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n", 86 | " data1.edge_index = to_undirected(data1.edge_index)\n", 87 | " dataset.append(data1)\n", 88 | " \n", 89 | " # Environment 1\n", 90 | " q = 2*math.pi/(fold + random.randint(1, fold))\n", 91 | " assert q < 2*math.pi/fold\n", 92 | " Q = e3nn.o3.matrix_z(torch.Tensor([q])).squeeze(0)\n", 93 | " pos = pos @ Q.T\n", 94 | " y = torch.LongTensor([1]) # Label 1\n", 95 | " data2 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n", 96 | " data2.edge_index = to_undirected(data2.edge_index)\n", 97 | " dataset.append(data2)\n", 98 | " \n", 99 | " return dataset" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "fold = 5\n", 109 | "\n", 110 | "# Create dataset\n", 111 | "dataset = create_rotsym_envs(fold)\n", 112 | "for data in dataset:\n", 113 | " plot_2d(data, lim=1)\n", 114 | "\n", 115 | "# Create dataloaders\n", 116 | "dataloader = DataLoader(dataset, batch_size=1, shuffle=True)\n", 117 | "val_loader = DataLoader(dataset, batch_size=1, shuffle=False)\n", 118 | "test_loader = DataLoader(dataset, batch_size=1, shuffle=False)" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "# Set parameters\n", 128 | "model_name = \"tfn\"\n", 129 | "correlation = 2\n", 130 | "max_ell = 5\n", 131 | "\n", 132 | "model = {\n", 133 | " \"schnet\": SchNetModel,\n", 134 | " \"dimenet\": DimeNetPPModel,\n", 135 | " \"spherenet\": SphereNetModel,\n", 136 | " \"egnn\": partial(EGNNModel, equivariant_pred=True),\n", 137 | " \"gvp\": partial(GVPGNNModel, equivariant_pred=True),\n", 138 | " \"tfn\": partial(TFNModel, max_ell=max_ell, equivariant_pred=True),\n", 139 | " \"mace\": partial(MACEModel, max_ell=max_ell, 
correlation=correlation, equivariant_pred=True),\n", 140 | "}[model_name](num_layers=1, in_dim=1, out_dim=2)\n", 141 | "\n", 142 | "best_val_acc, test_acc, train_time = run_experiment(\n", 143 | " model, \n", 144 | " dataloader,\n", 145 | " val_loader, \n", 146 | " test_loader,\n", 147 | " n_epochs=100,\n", 148 | " n_times=10,\n", 149 | " device=device,\n", 150 | " verbose=False\n", 151 | ")" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [] 160 | } 161 | ], 162 | "metadata": { 163 | "kernelspec": { 164 | "display_name": "Python 3", 165 | "language": "python", 166 | "name": "python3" 167 | }, 168 | "language_info": { 169 | "codemirror_mode": { 170 | "name": "ipython", 171 | "version": 3 172 | }, 173 | "file_extension": ".py", 174 | "mimetype": "text/x-python", 175 | "name": "python", 176 | "nbconvert_exporter": "python", 177 | "pygments_lexer": "ipython3", 178 | "version": "3.8.16" 179 | }, 180 | "orig_nbformat": 4, 181 | "vscode": { 182 | "interpreter": { 183 | "hash": "94aa676993820a604ac86f7af94f5432e989a749d5dd43e18f9507de2e8c2897" 184 | } 185 | } 186 | }, 187 | "nbformat": 4, 188 | "nbformat_minor": 2 189 | } 190 | -------------------------------------------------------------------------------- /models/mace_modules/symmetric_contraction.py: -------------------------------------------------------------------------------- 1 | ########################################################################################### 2 | # Implementation of the symmetric contraction algorithm presented in the MACE paper 3 | # (Batatia et al, MACE: Higher Order Equivariant Message Passing Neural Networks for Fast and Accurate Force Fields , Eq.10 and 11) 4 | # Authors: Ilyes Batatia 5 | # This program is distributed under the MIT License (see MIT.md) 6 | ########################################################################################### 7 | 8 | from typing import Dict, Optional, 
Union 9 | 10 | import torch 11 | import torch.fx 12 | from e3nn import o3 13 | from e3nn.util.codegen import CodeGenMixin 14 | from e3nn.util.jit import compile_mode 15 | from opt_einsum import contract 16 | 17 | from .cg import U_matrix_real 18 | 19 | 20 | @compile_mode("script") 21 | class SymmetricContraction(CodeGenMixin, torch.nn.Module): 22 | def __init__( 23 | self, 24 | irreps_in: o3.Irreps, 25 | irreps_out: o3.Irreps, 26 | correlation: Union[int, Dict[str, int]], 27 | irrep_normalization: str = "component", 28 | path_normalization: str = "element", 29 | internal_weights: Optional[bool] = None, 30 | shared_weights: Optional[torch.Tensor] = None, 31 | element_dependent: Optional[bool] = None, 32 | num_elements: Optional[int] = None, 33 | ) -> None: 34 | super().__init__() 35 | 36 | if irrep_normalization is None: 37 | irrep_normalization = "component" 38 | 39 | if path_normalization is None: 40 | path_normalization = "element" 41 | 42 | assert irrep_normalization in ["component", "norm", "none"] 43 | assert path_normalization in ["element", "path", "none"] 44 | 45 | self.irreps_in = o3.Irreps(irreps_in) 46 | self.irreps_out = o3.Irreps(irreps_out) 47 | 48 | del irreps_in, irreps_out 49 | 50 | if not isinstance(correlation, tuple): 51 | corr = correlation 52 | correlation = {} 53 | for irrep_out in self.irreps_out: 54 | correlation[irrep_out] = corr 55 | 56 | assert shared_weights or not internal_weights 57 | 58 | if internal_weights is None: 59 | internal_weights = True 60 | 61 | if element_dependent is None: 62 | element_dependent = True 63 | 64 | self.internal_weights = internal_weights 65 | self.shared_weights = shared_weights 66 | 67 | del internal_weights, shared_weights 68 | 69 | self.contractions = torch.nn.ModuleDict() 70 | for irrep_out in self.irreps_out: 71 | self.contractions[str(irrep_out)] = Contraction( 72 | irreps_in=self.irreps_in, 73 | irrep_out=o3.Irreps(str(irrep_out.ir)), 74 | correlation=correlation[irrep_out], 75 | 
internal_weights=self.internal_weights, 76 | element_dependent=element_dependent, 77 | num_elements=num_elements, 78 | weights=self.shared_weights, 79 | ) 80 | 81 | def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): 82 | outs = [] 83 | for irrep in self.irreps_out: 84 | outs.append(self.contractions[str(irrep)](x, y)) 85 | return torch.cat(outs, dim=-1) 86 | 87 | 88 | class Contraction(torch.nn.Module): 89 | def __init__( 90 | self, 91 | irreps_in: o3.Irreps, 92 | irrep_out: o3.Irreps, 93 | correlation: int, 94 | internal_weights: bool = True, 95 | element_dependent: bool = True, 96 | num_elements: Optional[int] = None, 97 | weights: Optional[torch.Tensor] = None, 98 | ) -> None: 99 | super().__init__() 100 | 101 | self.element_dependent = element_dependent 102 | self.num_features = irreps_in.count((0, 1)) 103 | self.coupling_irreps = o3.Irreps([irrep.ir for irrep in irreps_in]) 104 | self.correlation = correlation 105 | dtype = torch.get_default_dtype() 106 | for nu in range(1, correlation + 1): 107 | U_matrix = U_matrix_real( 108 | irreps_in=self.coupling_irreps, 109 | irreps_out=irrep_out, 110 | correlation=nu, 111 | dtype=dtype, 112 | )[-1] 113 | self.register_buffer(f"U_matrix_{nu}", U_matrix) 114 | 115 | if element_dependent: 116 | # Tensor contraction equations 117 | self.equation_main = "...ik,ekc,bci,be -> bc..." 118 | self.equation_weighting = "...k,ekc,be->bc..." 119 | self.equation_contract = "bc...i,bci->bc..." 120 | if internal_weights: 121 | # Create weight for product basis 122 | self.weights = torch.nn.ParameterDict({}) 123 | for i in range(1, correlation + 1): 124 | num_params = self.U_tensors(i).size()[-1] 125 | w = torch.nn.Parameter( 126 | torch.randn(num_elements, num_params, self.num_features) 127 | / num_params 128 | ) 129 | self.weights[str(i)] = w 130 | else: 131 | self.register_buffer("weights", weights) 132 | 133 | else: 134 | # Tensor contraction equations 135 | self.equation_main = "...ik,kc,bci -> bc..." 
136 | self.equation_weighting = "...k,kc->c..." 137 | self.equation_contract = "bc...i,bci->bc..." 138 | if internal_weights: 139 | # Create weight for product basis 140 | self.weights = torch.nn.ParameterDict({}) 141 | for i in range(1, correlation + 1): 142 | num_params = self.U_tensors(i).size()[-1] 143 | w = torch.nn.Parameter( 144 | torch.randn(num_params, self.num_features) / num_params 145 | ) 146 | self.weights[str(i)] = w 147 | else: 148 | self.register_buffer("weights", weights) 149 | 150 | def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): 151 | if self.element_dependent: 152 | out = contract( 153 | self.equation_main, 154 | self.U_tensors(self.correlation), 155 | self.weights[str(self.correlation)], 156 | x, 157 | y, 158 | ) # TODO: use optimize library and cuTENSOR # pylint: disable=fixme 159 | for corr in range(self.correlation - 1, 0, -1): 160 | c_tensor = contract( 161 | self.equation_weighting, 162 | self.U_tensors(corr), 163 | self.weights[str(corr)], 164 | y, 165 | ) 166 | c_tensor = c_tensor + out 167 | out = contract(self.equation_contract, c_tensor, x) 168 | 169 | else: 170 | out = contract( 171 | self.equation_main, 172 | self.U_tensors(self.correlation), 173 | self.weights[str(self.correlation)], 174 | x, 175 | ) # TODO: use optimize library and cuTENSOR # pylint: disable=fixme 176 | for corr in range(self.correlation - 1, 0, -1): 177 | c_tensor = contract( 178 | self.equation_weighting, 179 | self.U_tensors(corr), 180 | self.weights[str(corr)], 181 | ) 182 | c_tensor = c_tensor + out 183 | out = contract(self.equation_contract, c_tensor, x) 184 | resize_shape = torch.prod(torch.tensor(out.shape[1:])) 185 | return out.view(out.shape[0], resize_shape) 186 | 187 | def U_tensors(self, nu): 188 | return self._buffers[f"U_matrix_{nu}"] 189 | -------------------------------------------------------------------------------- /experiments/kchains.ipynb: -------------------------------------------------------------------------------- 1 | { 
2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "# Propagating geometric information: $k$-chains\n", 9 | "\n", 10 | "*Background:*\n", 11 | "In geometric GNNs, **geometric information**, such as the relative orientation of local neighbourhoods, is propagated via summing features from multiple layers in fixed dimensional spaces. \n", 12 | "The ideal architecture can be run for any number of layers to perfectly propagate geometric information without loss of information.\n", 13 | "In practice, stacking geometric GNN layers may lead to distortion or **loss of information from distant nodes**.\n", 14 | "\n", 15 | "*Experiment:*\n", 16 | "To study the practical implications of depth in propagating geometric information beyond local neighbourhoods, we consider **$k$-chain geometric graphs** which generalise the examples from [Schütt et al., 2021](https://arxiv.org/abs/2102.03150). \n", 17 | "Each pair of $k$-chains consists of $k+2$ nodes with $k$ nodes arranged in a line and differentiated by the orientation of the $2$ end points.\n", 18 | "Thus, $k$-chain graphs are $(\\lfloor \\frac{k}{2} \\rfloor + 1)$-hop distinguishable, and $(\\lfloor \\frac{k}{2} \\rfloor + 1)$ geometric GNN iterations should be theoretically sufficient to distinguish them.\n", 19 | "In this notebook, we train equivariant and invariant geometric GNNs with an increasing number of layers to distinguish $k$-chains.\n", 20 | "\n", 21 | "![k-chains](fig/kchains.png)\n", 22 | "\n", 23 | "*Results:*\n", 24 | "- Despite the supposed simplicity of the task, especially for small chain lengths, we find that popular equivariant GNNs such as E-GNN and TFN may require **more iterations** than theoretically sufficient.\n", 25 | "- Notably, as the length of the chain gets larger than $k=4$, all equivariant GNNs tended to lose performance and required more than $(\\lfloor \\frac{k}{2} \\rfloor + 1)$ iterations to solve the task.\n", 26 | "- 
Invariant GNNs are **unable** to distinguish $k$-chains.\n", 27 | "\n", 28 | "These results point to preliminary evidence of the **oversquashing** phenomenon when geometric information is propagated across multiple layers using fixed dimensional feature spaces.\n", 29 | "These issues are most evident for E-GNN, which uses a single vector feature to aggregate and propagate geometric information." 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "%load_ext autoreload\n", 39 | "%autoreload 2\n", 40 | "\n", 41 | "import sys\n", 42 | "sys.path.append('../')\n", 43 | "\n", 44 | "import torch\n", 45 | "import torch_geometric\n", 46 | "from torch_geometric.data import Data\n", 47 | "from torch_geometric.loader import DataLoader\n", 48 | "from torch_geometric.utils import to_undirected\n", 49 | "import e3nn\n", 50 | "from functools import partial\n", 51 | "\n", 52 | "print(\"PyTorch version {}\".format(torch.__version__))\n", 53 | "print(\"PyG version {}\".format(torch_geometric.__version__))\n", 54 | "print(\"e3nn version {}\".format(e3nn.__version__))\n", 55 | "\n", 56 | "from experiments.utils.plot_utils import plot_3d\n", 57 | "from experiments.utils.train_utils import run_experiment\n", 58 | "from models import SchNetModel, DimeNetPPModel, SphereNetModel, EGNNModel, GVPGNNModel, TFNModel, MACEModel\n", 59 | "\n", 60 | "# Set the device\n", 61 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", 62 | "print(f\"Using device: {device}\")" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "def create_kchains(k):\n", 72 | " assert k >= 2\n", 73 | " \n", 74 | " dataset = []\n", 75 | "\n", 76 | " # Graph 0\n", 77 | " atoms = torch.LongTensor( [0] + [0] + [0]*(k-1) + [0] )\n", 78 | " edge_index = torch.LongTensor( [ [i for i in range((k+2) - 1)], [i for i in range(1, k+2)] 
] )\n", 79 | " pos = torch.FloatTensor(\n", 80 | " [[-4, -3, 0]] + \n", 81 | " [[0, 5*i , 0] for i in range(k)] + \n", 82 | " [[4, 5*(k-1) + 3, 0]]\n", 83 | " )\n", 84 | " center_of_mass = torch.mean(pos, dim=0)\n", 85 | " pos = pos - center_of_mass\n", 86 | " y = torch.LongTensor([0]) # Label 0\n", 87 | " data1 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n", 88 | " data1.edge_index = to_undirected(data1.edge_index)\n", 89 | " dataset.append(data1)\n", 90 | " \n", 91 | " # Graph 1\n", 92 | " atoms = torch.LongTensor( [0] + [0] + [0]*(k-1) + [0] )\n", 93 | " edge_index = torch.LongTensor( [ [i for i in range((k+2) - 1)], [i for i in range(1, k+2)] ] )\n", 94 | " pos = torch.FloatTensor(\n", 95 | " [[4, -3, 0]] + \n", 96 | " [[0, 5*i , 0] for i in range(k)] + \n", 97 | " [[4, 5*(k-1) + 3, 0]]\n", 98 | " )\n", 99 | " center_of_mass = torch.mean(pos, dim=0)\n", 100 | " pos = pos - center_of_mass\n", 101 | " y = torch.LongTensor([1]) # Label 1\n", 102 | " data2 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n", 103 | " data2.edge_index = to_undirected(data2.edge_index)\n", 104 | " dataset.append(data2)\n", 105 | " \n", 106 | " return dataset" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "k = 4\n", 116 | "\n", 117 | "# Create dataset\n", 118 | "dataset = create_kchains(k=k)\n", 119 | "for data in dataset:\n", 120 | " plot_3d(data, lim=5*k)\n", 121 | "\n", 122 | "# Create dataloaders\n", 123 | "dataloader = DataLoader(dataset, batch_size=1, shuffle=True)\n", 124 | "val_loader = DataLoader(dataset, batch_size=2, shuffle=False)\n", 125 | "test_loader = DataLoader(dataset, batch_size=2, shuffle=False)" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "metadata": {}, 132 | "outputs": [], 133 | "source": [ 134 | "# Set model\n", 135 | "model_name = \"gvp\"\n", 136 | "\n", 137 | "for num_layers in range(k // 2 , 
k + 3):\n", 138 | "\n", 139 | " print(f\"\\nNumber of layers: {num_layers}\")\n", 140 | " \n", 141 | " correlation = 2\n", 142 | " model = {\n", 143 | " \"schnet\": SchNetModel,\n", 144 | " \"dimenet\": DimeNetPPModel,\n", 145 | " \"spherenet\": SphereNetModel,\n", 146 | " \"egnn\": EGNNModel,\n", 147 | " \"gvp\": partial(GVPGNNModel, s_dim=32, v_dim=1),\n", 148 | " \"tfn\": TFNModel,\n", 149 | " \"mace\": partial(MACEModel, correlation=correlation),\n", 150 | " }[model_name](num_layers=num_layers, in_dim=1, out_dim=2)\n", 151 | " \n", 152 | " best_val_acc, test_acc, train_time = run_experiment(\n", 153 | " model, \n", 154 | " dataloader,\n", 155 | " val_loader, \n", 156 | " test_loader,\n", 157 | " n_epochs=100,\n", 158 | " n_times=10,\n", 159 | " device=device,\n", 160 | " verbose=False\n", 161 | " )" 162 | ] 163 | } 164 | ], 165 | "metadata": { 166 | "kernelspec": { 167 | "display_name": "Python 3", 168 | "language": "python", 169 | "name": "python3" 170 | }, 171 | "language_info": { 172 | "codemirror_mode": { 173 | "name": "ipython", 174 | "version": 3 175 | }, 176 | "file_extension": ".py", 177 | "mimetype": "text/x-python", 178 | "name": "python", 179 | "nbconvert_exporter": "python", 180 | "pygments_lexer": "ipython3", 181 | "version": "3.8.16" 182 | }, 183 | "orig_nbformat": 4, 184 | "vscode": { 185 | "interpreter": { 186 | "hash": "94aa676993820a604ac86f7af94f5432e989a749d5dd43e18f9507de2e8c2897" 187 | } 188 | } 189 | }, 190 | "nbformat": 4, 191 | "nbformat_minor": 2 192 | } 193 | -------------------------------------------------------------------------------- /models/mace.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | from torch_geometric.nn import global_add_pool, global_mean_pool 6 | import e3nn 7 | 8 | from models.mace_modules.irreps_tools import reshape_irreps 9 | from models.mace_modules.blocks import ( 10 | 
EquivariantProductBasisBlock, 11 | RadialEmbeddingBlock, 12 | ) 13 | from models.layers.tfn_layer import TensorProductConvLayer 14 | 15 | 16 | class MACEModel(torch.nn.Module): 17 | """ 18 | MACE model from "MACE: Higher Order Equivariant Message Passing Neural Networks". 19 | """ 20 | def __init__( 21 | self, 22 | r_max: float = 10.0, 23 | num_bessel: int = 8, 24 | num_polynomial_cutoff: int = 5, 25 | max_ell: int = 2, 26 | correlation: int = 3, 27 | num_layers: int = 5, 28 | emb_dim: int = 64, 29 | hidden_irreps: Optional[e3nn.o3.Irreps] = None, 30 | mlp_dim: int = 256, 31 | in_dim: int = 1, 32 | out_dim: int = 1, 33 | aggr: str = "sum", 34 | pool: str = "sum", 35 | batch_norm: bool = True, 36 | residual: bool = True, 37 | equivariant_pred: bool = False 38 | ): 39 | """ 40 | Parameters: 41 | - r_max (float): Maximum distance for Bessel basis functions (default: 10.0) 42 | - num_bessel (int): Number of Bessel basis functions (default: 8) 43 | - num_polynomial_cutoff (int): Number of polynomial cutoff basis functions (default: 5) 44 | - max_ell (int): Maximum degree of spherical harmonics basis functions (default: 2) 45 | - correlation (int): Local correlation order = body order - 1 (default: 3) 46 | - num_layers (int): Number of layers in the model (default: 5) 47 | - emb_dim (int): Scalar feature embedding dimension (default: 64) 48 | - hidden_irreps (Optional[e3nn.o3.Irreps]): Hidden irreps (default: None) 49 | - mlp_dim (int): Dimension of MLP for computing tensor product weights (default: 256) 50 | - in_dim (int): Input dimension of the model (default: 1) 51 | - out_dim (int): Output dimension of the model (default: 1) 52 | - aggr (str): Aggregation method to be used (default: "sum") 53 | - pool (str): Global pooling method to be used (default: "sum") 54 | - batch_norm (bool): Whether to use batch normalization (default: True) 55 | - residual (bool): Whether to use residual connections (default: True) 56 | - equivariant_pred (bool): Whether it is an 
equivariant prediction task (default: False) 57 | 58 | Note: 59 | - If `hidden_irreps` is None, the irreps for the intermediate features are computed 60 | using `emb_dim` and `max_ell`. 61 | - The `equivariant_pred` parameter determines whether it is an equivariant prediction task. 62 | If set to True, equivariant prediction will be performed. 63 | """ 64 | super().__init__() 65 | 66 | self.r_max = r_max 67 | self.max_ell = max_ell 68 | self.num_layers = num_layers 69 | self.emb_dim = emb_dim 70 | self.mlp_dim = mlp_dim 71 | self.residual = residual 72 | self.batch_norm = batch_norm 73 | self.hidden_irreps = hidden_irreps 74 | self.equivariant_pred = equivariant_pred 75 | 76 | # Edge embedding 77 | self.radial_embedding = RadialEmbeddingBlock( 78 | r_max=r_max, 79 | num_bessel=num_bessel, 80 | num_polynomial_cutoff=num_polynomial_cutoff, 81 | ) 82 | sh_irreps = e3nn.o3.Irreps.spherical_harmonics(max_ell) 83 | self.spherical_harmonics = e3nn.o3.SphericalHarmonics( 84 | sh_irreps, normalize=True, normalization="component" 85 | ) 86 | 87 | # Embedding lookup for initial node features 88 | self.emb_in = torch.nn.Embedding(in_dim, emb_dim) 89 | 90 | # Set hidden irreps if none are provided 91 | if hidden_irreps is None: 92 | hidden_irreps = (sh_irreps * emb_dim).sort()[0].simplify() 93 | # Note: This defaults to O(3) equivariant layers 94 | # It is possible to use SO(3) equivariance by passing the appropriate irreps 95 | 96 | self.convs = torch.nn.ModuleList() 97 | self.prods = torch.nn.ModuleList() 98 | self.reshapes = torch.nn.ModuleList() 99 | 100 | # First layer: scalar only -> tensor 101 | self.convs.append( 102 | TensorProductConvLayer( 103 | in_irreps=e3nn.o3.Irreps(f'{emb_dim}x0e'), 104 | out_irreps=hidden_irreps, 105 | sh_irreps=sh_irreps, 106 | edge_feats_dim=self.radial_embedding.out_dim, 107 | mlp_dim=mlp_dim, 108 | aggr=aggr, 109 | batch_norm=batch_norm, 110 | gate=False, 111 | ) 112 | ) 113 | self.reshapes.append(reshape_irreps(hidden_irreps)) 114 | 
self.prods.append( 115 | EquivariantProductBasisBlock( 116 | node_feats_irreps=hidden_irreps, 117 | target_irreps=hidden_irreps, 118 | correlation=correlation, 119 | element_dependent=False, 120 | num_elements=in_dim, 121 | use_sc=residual 122 | ) 123 | ) 124 | 125 | # Intermediate layers: tensor -> tensor 126 | for _ in range(num_layers - 1): 127 | self.convs.append( 128 | TensorProductConvLayer( 129 | in_irreps=hidden_irreps, 130 | out_irreps=hidden_irreps, 131 | sh_irreps=sh_irreps, 132 | edge_feats_dim=self.radial_embedding.out_dim, 133 | mlp_dim=mlp_dim, 134 | aggr=aggr, 135 | batch_norm=batch_norm, 136 | gate=False, 137 | ) 138 | ) 139 | self.reshapes.append(reshape_irreps(hidden_irreps)) 140 | self.prods.append( 141 | EquivariantProductBasisBlock( 142 | node_feats_irreps=hidden_irreps, 143 | target_irreps=hidden_irreps, 144 | correlation=correlation, 145 | element_dependent=False, 146 | num_elements=in_dim, 147 | use_sc=residual 148 | ) 149 | ) 150 | 151 | # Global pooling/readout function 152 | self.pool = {"mean": global_mean_pool, "sum": global_add_pool}[pool] 153 | 154 | if self.equivariant_pred: 155 | # Linear predictor for equivariant tasks using geometric features 156 | self.pred = torch.nn.Linear(hidden_irreps.dim, out_dim) 157 | else: 158 | # MLP predictor for invariant tasks using only scalar features 159 | self.pred = torch.nn.Sequential( 160 | torch.nn.Linear(emb_dim, emb_dim), 161 | torch.nn.ReLU(), 162 | torch.nn.Linear(emb_dim, out_dim) 163 | ) 164 | 165 | def forward(self, batch): 166 | # Node embedding 167 | h = self.emb_in(batch.atoms) # (n,) -> (n, d) 168 | 169 | # Edge features 170 | vectors = batch.pos[batch.edge_index[0]] - batch.pos[batch.edge_index[1]] # [n_edges, 3] 171 | lengths = torch.linalg.norm(vectors, dim=-1, keepdim=True) # [n_edges, 1] 172 | 173 | edge_sh = self.spherical_harmonics(vectors) 174 | edge_feats = self.radial_embedding(lengths) 175 | 176 | for conv, reshape, prod in zip(self.convs, self.reshapes, self.prods): 177 
| # Message passing layer 178 | h_update = conv(h, batch.edge_index, edge_sh, edge_feats) 179 | 180 | # Update node features 181 | sc = F.pad(h, (0, h_update.shape[-1] - h.shape[-1])) 182 | h = prod(reshape(h_update), sc, None) 183 | 184 | out = self.pool(h, batch.batch) # (n, d) -> (batch_size, d) 185 | 186 | if not self.equivariant_pred: 187 | # Select only scalars for invariant prediction 188 | out = out[:,:self.emb_dim] 189 | 190 | return self.pred(out) # (batch_size, out_dim) 191 | -------------------------------------------------------------------------------- /models/layers/gvp_layer.py: -------------------------------------------------------------------------------- 1 | ########################################################################################### 2 | # Implementation of Geometric Vector Perceptron layers 3 | # 4 | # Papers: 5 | # (1) Learning from Protein Structure with Geometric Vector Perceptrons, 6 | # by B Jing, S Eismann, P Suriana, RJL Townshend, and RO Dror 7 | # (2) Equivariant Graph Neural Networks for 3D Macromolecular Structure, 8 | # by B Jing, S Eismann, P Soni, and RO Dror 9 | # 10 | # Orginal repository: https://github.com/drorlab/gvp-pytorch 11 | ########################################################################################### 12 | 13 | import functools 14 | import torch 15 | import torch.nn.functional as F 16 | from torch import nn 17 | import torch_scatter 18 | from torch_geometric.nn import MessagePassing 19 | 20 | 21 | def tuple_sum(*args): 22 | """ 23 | Sums any number of tuples (s, V) elementwise. 24 | """ 25 | return tuple(map(sum, zip(*args))) 26 | 27 | 28 | def tuple_cat(*args, dim=-1): 29 | """ 30 | Concatenates any number of tuples (s, V) elementwise. 31 | 32 | :param dim: dimension along which to concatenate when viewed 33 | as the `dim` index for the scalar-channel tensors. 34 | This means that `dim=-1` will be applied as 35 | `dim=-2` for the vector-channel tensors. 
36 | """ 37 | dim %= len(args[0][0].shape) 38 | s_args, v_args = list(zip(*args)) 39 | return torch.cat(s_args, dim=dim), torch.cat(v_args, dim=dim) 40 | 41 | 42 | def tuple_index(x, idx): 43 | """ 44 | Indexes into a tuple (s, V) along the first dimension. 45 | 46 | :param idx: any object which can be used to index into a `torch.Tensor` 47 | """ 48 | return x[0][idx], x[1][idx] 49 | 50 | 51 | def randn(n, dims, device="cpu"): 52 | """ 53 | Returns random tuples (s, V) drawn elementwise from a normal distribution. 54 | 55 | :param n: number of data points 56 | :param dims: tuple of dimensions (n_scalar, n_vector) 57 | 58 | :return: (s, V) with s.shape = (n, n_scalar) and 59 | V.shape = (n, n_vector, 3) 60 | """ 61 | return torch.randn(n, dims[0], device=device), torch.randn( 62 | n, dims[1], 3, device=device 63 | ) 64 | 65 | 66 | def _norm_no_nan(x, axis=-1, keepdims=False, eps=1e-8, sqrt=True): 67 | """ 68 | L2 norm of tensor clamped above a minimum value `eps`. 69 | 70 | :param sqrt: if `False`, returns the square of the L2 norm 71 | """ 72 | out = torch.clamp(torch.sum(torch.square(x), axis, keepdims), min=eps) 73 | return torch.sqrt(out) if sqrt else out 74 | 75 | 76 | def _split(x, nv): 77 | ''' 78 | Splits a merged representation of (s, V) back into a tuple. 79 | Should be used only with `_merge(s, V)` and only if the tuple 80 | representation cannot be used. 81 | 82 | :param x: the `torch.Tensor` returned from `_merge` 83 | :param nv: the number of vector channels in the input to `_merge` 84 | ''' 85 | s = x[..., :-3 * nv] 86 | v = x[..., -3 * nv:].contiguous().view(x.shape[0], nv, 3) 87 | return s, v 88 | 89 | 90 | def _merge(s, v): 91 | ''' 92 | Merges a tuple (s, V) into a single `torch.Tensor`, where the 93 | vector channels are flattened and appended to the scalar channels. 94 | Should be used only if the tuple representation cannot be used. 95 | Use `_split(x, nv)` to reverse. 
96 | ''' 97 | v = v.contiguous().view(v.shape[0], v.shape[1] * 3) 98 | return torch.cat([s, v], -1) 99 | 100 | 101 | class GVP(nn.Module): 102 | """ 103 | Geometric Vector Perceptron. See manuscript and README.md 104 | for more details. 105 | 106 | :param in_dims: tuple (n_scalar, n_vector) 107 | :param out_dims: tuple (n_scalar, n_vector) 108 | :param h_dim: intermediate number of vector channels, optional 109 | :param activations: tuple of functions (scalar_act, vector_act) 110 | :param vector_gate: whether to use vector gating. 111 | (vector_act will be used as sigma^+ in vector gating if `True`) 112 | """ 113 | 114 | def __init__( 115 | self, 116 | in_dims, 117 | out_dims, 118 | h_dim=None, 119 | activations=(F.relu, torch.sigmoid), 120 | vector_gate=True, 121 | ): 122 | super(GVP, self).__init__() 123 | self.si, self.vi = in_dims 124 | self.so, self.vo = out_dims 125 | self.vector_gate = vector_gate 126 | if self.vi: 127 | self.h_dim = h_dim or max(self.vi, self.vo) 128 | self.wh = nn.Linear(self.vi, self.h_dim, bias=False) 129 | self.ws = nn.Linear(self.h_dim + self.si, self.so) 130 | if self.vo: 131 | self.wv = nn.Linear(self.h_dim, self.vo, bias=False) 132 | if self.vector_gate: 133 | self.wsv = nn.Linear(self.so, self.vo) 134 | else: 135 | self.ws = nn.Linear(self.si, self.so) 136 | 137 | self.scalar_act, self.vector_act = activations 138 | self.dummy_param = nn.Parameter(torch.empty(0)) 139 | 140 | def forward(self, x): 141 | """ 142 | :param x: tuple (s, V) of `torch.Tensor`, 143 | or (if vectors_in is 0), a single `torch.Tensor` 144 | :return: tuple (s, V) of `torch.Tensor`, 145 | or (if vectors_out is 0), a single `torch.Tensor` 146 | """ 147 | if self.vi: 148 | s, v = x 149 | v = torch.transpose(v, -1, -2) 150 | vh = self.wh(v) 151 | vn = _norm_no_nan(vh, axis=-2) 152 | s = self.ws(torch.cat([s, vn], -1)) 153 | if self.vo: 154 | v = self.wv(vh) 155 | v = torch.transpose(v, -1, -2) 156 | if self.vector_gate: 157 | gate = ( 158 | 
self.wsv(self.vector_act(s)) if self.vector_act else self.wsv(s) 159 | ) 160 | v = v * torch.sigmoid(gate).unsqueeze(-1) 161 | elif self.vector_act: 162 | v = v * self.vector_act(_norm_no_nan(v, axis=-1, keepdims=True)) 163 | else: 164 | s = self.ws(x) 165 | if self.vo: 166 | v = torch.zeros(s.shape[0], self.vo, 3, device=self.dummy_param.device) 167 | if self.scalar_act: 168 | s = self.scalar_act(s) 169 | 170 | return (s, v) if self.vo else s 171 | 172 | 173 | class _VDropout(nn.Module): 174 | """ 175 | Vector channel dropout where the elements of each 176 | vector channel are dropped together. 177 | """ 178 | 179 | def __init__(self, drop_rate): 180 | super(_VDropout, self).__init__() 181 | self.drop_rate = drop_rate 182 | self.dummy_param = nn.Parameter(torch.empty(0)) 183 | 184 | def forward(self, x): 185 | """ 186 | :param x: `torch.Tensor` corresponding to vector channels 187 | """ 188 | device = self.dummy_param.device 189 | if not self.training: 190 | return x 191 | mask = torch.bernoulli( 192 | (1 - self.drop_rate) * torch.ones(x.shape[:-1], device=device) 193 | ).unsqueeze(-1) 194 | x = mask * x / (1 - self.drop_rate) 195 | return x 196 | 197 | 198 | class Dropout(nn.Module): 199 | """ 200 | Combined dropout for tuples (s, V). 201 | Takes tuples (s, V) as input and as output. 202 | """ 203 | 204 | def __init__(self, drop_rate): 205 | super(Dropout, self).__init__() 206 | self.sdropout = nn.Dropout(drop_rate) 207 | self.vdropout = _VDropout(drop_rate) 208 | 209 | def forward(self, x): 210 | """ 211 | :param x: tuple (s, V) of `torch.Tensor`, 212 | or single `torch.Tensor` 213 | (will be assumed to be scalar channels) 214 | """ 215 | if type(x) is torch.Tensor: 216 | return self.sdropout(x) 217 | s, v = x 218 | return self.sdropout(s), self.vdropout(v) 219 | 220 | 221 | class LayerNorm(nn.Module): 222 | """ 223 | Combined LayerNorm for tuples (s, V). 224 | Takes tuples (s, V) as input and as output. 
225 | """ 226 | 227 | def __init__(self, dims): 228 | super(LayerNorm, self).__init__() 229 | self.s, self.v = dims 230 | self.scalar_norm = nn.LayerNorm(self.s) 231 | 232 | def forward(self, x): 233 | """ 234 | :param x: tuple (s, V) of `torch.Tensor`, 235 | or single `torch.Tensor` 236 | (will be assumed to be scalar channels) 237 | """ 238 | if not self.v: 239 | return self.scalar_norm(x) 240 | s, v = x 241 | vn = _norm_no_nan(v, axis=-1, keepdims=True, sqrt=False) 242 | vn = torch.sqrt(torch.mean(vn, dim=-2, keepdim=True)) 243 | return self.scalar_norm(s), v / vn 244 | 245 | 246 | class GVPConv(MessagePassing): 247 | """ 248 | Graph convolution / message passing with Geometric Vector Perceptrons. 249 | Takes in a graph with node and edge embeddings, 250 | and returns new node embeddings. 251 | 252 | This does NOT do residual updates and pointwise feedforward layers 253 | ---see `GVPConvLayer`. 254 | 255 | :param in_dims: input node embedding dimensions (n_scalar, n_vector) 256 | :param out_dims: output node embedding dimensions (n_scalar, n_vector) 257 | :param edge_dims: input edge embedding dimensions (n_scalar, n_vector) 258 | :param n_layers: number of GVPs in the message function 259 | :param module_list: preconstructed message function, overrides n_layers 260 | :param aggr: should be "add" if some incoming edges are masked, as in 261 | a masked autoregressive decoder architecture, otherwise "mean" 262 | :param activations: tuple of functions (scalar_act, vector_act) to use in GVPs 263 | :param vector_gate: whether to use vector gating. 
264 | (vector_act will be used as sigma^+ in vector gating if `True`) 265 | """ 266 | 267 | def __init__( 268 | self, 269 | in_dims, 270 | out_dims, 271 | edge_dims, 272 | n_layers=3, 273 | module_list=None, 274 | aggr="mean", 275 | activations=(F.relu, torch.sigmoid), 276 | vector_gate=True, 277 | ): 278 | super(GVPConv, self).__init__(aggr=aggr) 279 | self.si, self.vi = in_dims 280 | self.so, self.vo = out_dims 281 | self.se, self.ve = edge_dims 282 | 283 | GVP_ = functools.partial(GVP, activations=activations, vector_gate=vector_gate) 284 | 285 | module_list = module_list or [] 286 | if not module_list: 287 | if n_layers == 1: 288 | module_list.append( 289 | GVP_( 290 | (2 * self.si + self.se, 2 * self.vi + self.ve), 291 | (self.so, self.vo), 292 | activations=(None, None), 293 | ) 294 | ) 295 | else: 296 | module_list.append( 297 | GVP_((2 * self.si + self.se, 2 * self.vi + self.ve), out_dims) 298 | ) 299 | for i in range(n_layers - 2): 300 | module_list.append(GVP_(out_dims, out_dims)) 301 | module_list.append(GVP_(out_dims, out_dims, activations=(None, None))) 302 | self.message_func = nn.Sequential(*module_list) 303 | 304 | def forward(self, x, edge_index, edge_attr): 305 | """ 306 | :param x: tuple (s, V) of `torch.Tensor` 307 | :param edge_index: array of shape [2, n_edges] 308 | :param edge_attr: tuple (s, V) of `torch.Tensor` 309 | """ 310 | x_s, x_v = x 311 | message = self.propagate( 312 | edge_index, 313 | s=x_s, 314 | v=x_v.contiguous().view(x_v.shape[0], x_v.shape[1] * 3), 315 | edge_attr=edge_attr, 316 | ) 317 | return _split(message, self.vo) 318 | 319 | def message(self, s_i, v_i, s_j, v_j, edge_attr): 320 | v_j = v_j.view(v_j.shape[0], v_j.shape[1] // 3, 3) 321 | v_i = v_i.view(v_i.shape[0], v_i.shape[1] // 3, 3) 322 | message = tuple_cat((s_j, v_j), edge_attr, (s_i, v_i)) 323 | message = self.message_func(message) 324 | return _merge(*message) 325 | 326 | 327 | class GVPConvLayer(nn.Module): 328 | """ 329 | Full graph convolution / message 
passing layer with 330 | Geometric Vector Perceptrons. Residually updates node embeddings with 331 | aggregated incoming messages, applies a pointwise feedforward 332 | network to node embeddings, and returns updated node embeddings. 333 | 334 | To only compute the aggregated messages, see `GVPConv`. 335 | 336 | :param node_dims: node embedding dimensions (n_scalar, n_vector) 337 | :param edge_dims: input edge embedding dimensions (n_scalar, n_vector) 338 | :param n_message: number of GVPs to use in message function 339 | :param n_feedforward: number of GVPs to use in feedforward function 340 | :param drop_rate: drop probability in all dropout layers 341 | :param autoregressive: if `True`, this `GVPConvLayer` will be used 342 | with a different set of input node embeddings for messages 343 | where src >= dst 344 | :param activations: tuple of functions (scalar_act, vector_act) to use in GVPs 345 | :param vector_gate: whether to use vector gating. 346 | (vector_act will be used as sigma^+ in vector gating if `True`) 347 | """ 348 | 349 | def __init__( 350 | self, 351 | node_dims, 352 | edge_dims, 353 | n_message=3, 354 | n_feedforward=2, 355 | drop_rate=0.1, 356 | autoregressive=False, 357 | activations=(F.relu, torch.sigmoid), 358 | vector_gate=True, 359 | residual=True, 360 | ): 361 | super(GVPConvLayer, self).__init__() 362 | self.conv = GVPConv( 363 | node_dims, 364 | node_dims, 365 | edge_dims, 366 | n_message, 367 | aggr="add" if autoregressive else "mean", 368 | activations=activations, 369 | vector_gate=vector_gate, 370 | ) 371 | GVP_ = functools.partial(GVP, activations=activations, vector_gate=vector_gate) 372 | self.norm = nn.ModuleList([LayerNorm(node_dims) for _ in range(2)]) 373 | self.dropout = nn.ModuleList([Dropout(drop_rate) for _ in range(2)]) 374 | 375 | ff_func = [] 376 | if n_feedforward == 1: 377 | ff_func.append(GVP_(node_dims, node_dims, activations=(None, None))) 378 | else: 379 | hid_dims = 4 * node_dims[0], 2 * node_dims[1] 380 | 
ff_func.append(GVP_(node_dims, hid_dims)) 381 | ff_func.extend(GVP_(hid_dims, hid_dims) for _ in range(n_feedforward - 2)) 382 | ff_func.append(GVP_(hid_dims, node_dims, activations=(None, None))) 383 | self.ff_func = nn.Sequential(*ff_func) 384 | self.residual = residual 385 | 386 | def forward(self, x, edge_index, edge_attr, autoregressive_x=None, node_mask=None): 387 | """ 388 | :param x: tuple (s, V) of `torch.Tensor` 389 | :param edge_index: array of shape [2, n_edges] 390 | :param edge_attr: tuple (s, V) of `torch.Tensor` 391 | :param autoregressive_x: tuple (s, V) of `torch.Tensor`. 392 | If not `None`, will be used as src node embeddings 393 | for forming messages where src >= dst. The corrent node 394 | embeddings `x` will still be the base of the update and the 395 | pointwise feedforward. 396 | :param node_mask: array of type `bool` to index into the first 397 | dim of node embeddings (s, V). If not `None`, only 398 | these nodes will be updated. 399 | """ 400 | 401 | if autoregressive_x is not None: 402 | src, dst = edge_index 403 | mask = src < dst 404 | edge_index_forward = edge_index[:, mask] 405 | edge_index_backward = edge_index[:, ~mask] 406 | edge_attr_forward = tuple_index(edge_attr, mask) 407 | edge_attr_backward = tuple_index(edge_attr, ~mask) 408 | 409 | dh = tuple_sum( 410 | self.conv(x, edge_index_forward, edge_attr_forward), 411 | self.conv(autoregressive_x, edge_index_backward, edge_attr_backward), 412 | ) 413 | 414 | count = ( 415 | torch_scatter.scatter_add( 416 | torch.ones_like(dst), dst, dim_size=dh[0].size(0) 417 | ) 418 | .clamp(min=1) 419 | .unsqueeze(-1) 420 | ) 421 | 422 | dh = dh[0] / count, dh[1] / count.unsqueeze(-1) 423 | 424 | else: 425 | dh = self.conv(x, edge_index, edge_attr) 426 | 427 | if node_mask is not None: 428 | x_ = x 429 | x, dh = tuple_index(x, node_mask), tuple_index(dh, node_mask) 430 | 431 | x = self.norm[0](tuple_sum(x, self.dropout[0](dh))) if self.residual else dh 432 | 433 | dh = self.ff_func(x) 434 | x 
= self.norm[1](tuple_sum(x, self.dropout[1](dh))) if self.residual else dh 435 | 436 | if node_mask is not None: 437 | x_[0][node_mask], x_[1][node_mask] = x[0], x[1] 438 | x = x_ 439 | return x -------------------------------------------------------------------------------- /experiments/incompleteness.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "# Identifying neighbourhood fingerprints: counterexamples from [Pozdnyakov et al., 2020](https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.125.166001)\n", 9 | "\n", 10 | "*Background:*\n", 11 | "Geometric GNNs identify local neighbourhoods around nodes via **'neighbourhood fingerprints'**, where local geometric information from subsets of neighbours is aggregated to compute invariant scalars. \n", 12 | "The number of neighbours involved in computing the scalars is termed the **body order**.\n", 13 | "The ideal neighbourhood fingerprint would perfectly identify neighbourhoods, which requires arbitrarily high body order.\n", 14 | "\n", 15 | "*Experiment:*\n", 16 | "To demonstrate the practical implications of scalarisation body order, we evaluate geometric GNN layers on their ability to discriminate counterexamples from [Pozdnyakov et al., 2020](https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.125.166001).\n", 17 | "Each counterexample consists of a pair of local neighbourhoods that are **indistinguishable** when comparing their set of $k$-body scalars, i.e.
geometric GNN layers with body order $k$ cannot distinguish the neighbourhoods.\n", 18 | "The 3-body counterexample corresponds to Fig.1(b) in Pozdnyakov et al., 2020, 4-body chiral to Fig.2(e), and 4-body non-chiral to Fig.2(f); the 2-body counterexample is based on the two local neighbourhoods in the running example from our paper.\n", 19 | "In this notebook, we train single layer geometric GNNs to distinguish the counterexamples using updated scalar features. \n", 20 | "\n", 21 | "![Counterexamples from Pozdnyakov et al., 2020](fig/incompleteness.png)" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "%load_ext autoreload\n", 31 | "%autoreload 2\n", 32 | "\n", 33 | "import sys\n", 34 | "sys.path.append('../')\n", 35 | "\n", 36 | "import torch\n", 37 | "import torch_geometric\n", 38 | "from torch_geometric.data import Data\n", 39 | "from torch_geometric.loader import DataLoader\n", 40 | "from torch_geometric.utils import to_undirected\n", 41 | "import e3nn\n", 42 | "from functools import partial\n", 43 | "\n", 44 | "print(\"PyTorch version {}\".format(torch.__version__))\n", 45 | "print(\"PyG version {}\".format(torch_geometric.__version__))\n", 46 | "print(\"e3nn version {}\".format(e3nn.__version__))\n", 47 | "\n", 48 | "from experiments.utils.plot_utils import plot_3d\n", 49 | "from experiments.utils.train_utils import run_experiment\n", 50 | "from models import SchNetModel, DimeNetPPModel, SphereNetModel, EGNNModel, GVPGNNModel, TFNModel, MACEModel\n", 51 | "\n", 52 | "# Set the device\n", 53 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", 54 | "print(f\"Using device: {device}\")" 55 | ] 56 | }, 57 | { 58 | "attachments": {}, 59 | "cell_type": "markdown", 60 | "metadata": {}, 61 | "source": [ 62 | "# Two-body counterexample\n", 63 | "\n", 64 | "Pair of local neighbourhoods that are indistinguishable when comparing their set of $2$-body 
scalars, i.e. the unordered set of pairwise distances." 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "def create_two_body_envs():\n", 74 | " dataset = []\n", 75 | "\n", 76 | " # Environment 0\n", 77 | " atoms = torch.LongTensor([ 0, 0, 0 ])\n", 78 | " edge_index = torch.LongTensor([ [0, 0], [1, 2] ])\n", 79 | " pos = torch.FloatTensor([ \n", 80 | " [0, 0, 0],\n", 81 | " [5, 0, 0],\n", 82 | " [3, 0, 4]\n", 83 | " ])\n", 84 | " y = torch.LongTensor([0]) # Label 0\n", 85 | " data1 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n", 86 | " data1.edge_index = to_undirected(data1.edge_index)\n", 87 | " dataset.append(data1)\n", 88 | " \n", 89 | " # Environment 1\n", 90 | " atoms = torch.LongTensor([ 0, 0, 0 ])\n", 91 | " edge_index = torch.LongTensor([ [0, 0], [1, 2] ])\n", 92 | " pos = torch.FloatTensor([ \n", 93 | " [0, 0, 0],\n", 94 | " [5, 0, 0],\n", 95 | " [-5, 0, 0]\n", 96 | " ])\n", 97 | " y = torch.LongTensor([1]) # Label 1\n", 98 | " data2 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n", 99 | " data2.edge_index = to_undirected(data2.edge_index)\n", 100 | " dataset.append(data2)\n", 101 | " \n", 102 | " return dataset\n", 103 | "\n", 104 | "# Create dataset\n", 105 | "dataset = create_two_body_envs()\n", 106 | "for data in dataset:\n", 107 | " plot_3d(data, lim=5)\n", 108 | "\n", 109 | "# Create dataloaders\n", 110 | "dataloader = DataLoader(dataset, batch_size=1, shuffle=True)\n", 111 | "val_loader = DataLoader(dataset, batch_size=1, shuffle=False)\n", 112 | "test_loader = DataLoader(dataset, batch_size=1, shuffle=False)" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "# Set model\n", 122 | "model_name = \"mace\"\n", 123 | "\n", 124 | "correlation = 2\n", 125 | "model = {\n", 126 | " \"schnet\": SchNetModel,\n", 127 | " \"dimenet\": 
DimeNetPPModel,\n", 128 | " \"spherenet\": SphereNetModel,\n", 129 | " \"egnn\": EGNNModel,\n", 130 | " \"gvp\": GVPGNNModel,\n", 131 | " \"tfn\": TFNModel,\n", 132 | " \"mace\": partial(MACEModel, correlation=correlation),\n", 133 | "}[model_name](num_layers=1, in_dim=1, out_dim=2)\n", 134 | "\n", 135 | "best_val_acc, test_acc, train_time = run_experiment(\n", 136 | " model, \n", 137 | " dataloader,\n", 138 | " val_loader, \n", 139 | " test_loader,\n", 140 | " n_epochs=100,\n", 141 | " n_times=10,\n", 142 | " device=device,\n", 143 | " verbose=False\n", 144 | ")" 145 | ] 146 | }, 147 | { 148 | "attachments": {}, 149 | "cell_type": "markdown", 150 | "metadata": {}, 151 | "source": [ 152 | "# Three-body counterexample\n", 153 | "\n", 154 | "Pair of local neighbourhoods that are indistinguishable when comparing their set of $3$-body scalars, i.e. the unordered set of pairwise distances as well as angles." 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "def create_three_body_envs():\n", 164 | " dataset = []\n", 165 | "\n", 166 | " a_x, a_y, a_z = 5, 0, 5\n", 167 | " b_x, b_y, b_z = 5, 5, 5\n", 168 | " c_x, c_y, c_z = 0, 5, 5\n", 169 | " \n", 170 | " # Environment 0\n", 171 | " atoms = torch.LongTensor([ 0, 0, 0, 0, 0 ])\n", 172 | " edge_index = torch.LongTensor([ [0, 0, 0, 0], [1, 2, 3, 4] ])\n", 173 | " pos = torch.FloatTensor([ \n", 174 | " [0, 0, 0],\n", 175 | " [a_x, a_y, a_z],\n", 176 | " [+b_x, +b_y, b_z],\n", 177 | " [-b_x, -b_y, b_z],\n", 178 | " [c_x, +c_y, c_z],\n", 179 | " ])\n", 180 | " y = torch.LongTensor([0]) # Label 0\n", 181 | " data1 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n", 182 | " data1.edge_index = to_undirected(data1.edge_index)\n", 183 | " dataset.append(data1)\n", 184 | " \n", 185 | " # Environment 1\n", 186 | " atoms = torch.LongTensor([ 0, 0, 0, 0, 0 ])\n", 187 | " edge_index = torch.LongTensor([ [0, 0, 0, 0], [1, 2, 3, 4] 
])\n", 188 | " pos = torch.FloatTensor([ \n", 189 | " [0, 0, 0],\n", 190 | " [a_x, a_y, a_z],\n", 191 | " [+b_x, +b_y, b_z],\n", 192 | " [-b_x, -b_y, b_z],\n", 193 | " [c_x, -c_y, c_z],\n", 194 | " ])\n", 195 | " y = torch.LongTensor([1]) # Label 1\n", 196 | " data2 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n", 197 | " data2.edge_index = to_undirected(data2.edge_index)\n", 198 | " dataset.append(data2)\n", 199 | " \n", 200 | " return dataset\n", 201 | "\n", 202 | "# Create dataset\n", 203 | "dataset = create_three_body_envs()\n", 204 | "for data in dataset:\n", 205 | " plot_3d(data, lim=5)\n", 206 | "\n", 207 | "# Create dataloaders\n", 208 | "dataloader = DataLoader(dataset, batch_size=1, shuffle=True)\n", 209 | "val_loader = DataLoader(dataset, batch_size=1, shuffle=False)\n", 210 | "test_loader = DataLoader(dataset, batch_size=1, shuffle=False)" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": null, 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [ 219 | "# Set model\n", 220 | "model_name = \"mace\"\n", 221 | "\n", 222 | "correlation = 3\n", 223 | "model = {\n", 224 | " \"schnet\": SchNetModel,\n", 225 | " \"dimenet\": DimeNetPPModel,\n", 226 | " \"spherenet\": SphereNetModel,\n", 227 | " \"egnn\": EGNNModel,\n", 228 | " \"gvp\": GVPGNNModel,\n", 229 | " \"tfn\": TFNModel,\n", 230 | " \"mace\": partial(MACEModel, correlation=correlation),\n", 231 | "}[model_name](num_layers=1, in_dim=1, out_dim=2)\n", 232 | "\n", 233 | "best_val_acc, test_acc, train_time = run_experiment(\n", 234 | " model, \n", 235 | " dataloader,\n", 236 | " val_loader, \n", 237 | " test_loader,\n", 238 | " n_epochs=100,\n", 239 | " n_times=10,\n", 240 | " device=device,\n", 241 | " verbose=False\n", 242 | ")" 243 | ] 244 | }, 245 | { 246 | "attachments": {}, 247 | "cell_type": "markdown", 248 | "metadata": {}, 249 | "source": [ 250 | "# Four-body non-chiral counterexample\n", 251 | "\n", 252 | "Pair of local neighbourhoods that are 
indistinguishable when comparing their set of $4$-body scalars without considering chirality/handedness, i.e. the unordered set of pairwise distances, angles, and quadruplet scalars." 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": null, 258 | "metadata": {}, 259 | "outputs": [], 260 | "source": [ 261 | "def create_four_body_nonchiral_envs():\n", 262 | " dataset = []\n", 263 | "\n", 264 | " a1_x, a1_y, a1_z = 3, 2, -4\n", 265 | " a2_x, a2_y, a2_z = 0, 2, 5\n", 266 | " a3_x, a3_y, a3_z = 0, 2, -5\n", 267 | " b1_x, b1_y, b1_z = 3, -2, -4\n", 268 | " b2_x, b2_y, b2_z = 0, -2, 5\n", 269 | " b3_x, b3_y, b3_z = 0, -2, -5\n", 270 | " c_x, c_y, c_z = 0, 5, 0\n", 271 | "\n", 272 | " angle = 1 * torch.pi / 10 # random angle\n", 273 | " Q = e3nn.o3.matrix_y(torch.tensor(angle)).numpy()\n", 274 | "\n", 275 | " # Environment 0\n", 276 | " atoms = torch.LongTensor([ 0, 0, 0, 0, 0, 0, 0, 0 ])\n", 277 | " edge_index = torch.LongTensor([ [0, 0, 0, 0, 0, 0, 0], [1, 2, 3, 4, 5, 6, 7] ])\n", 278 | " pos = torch.FloatTensor([ \n", 279 | " [0, 0, 0],\n", 280 | " [a1_x, a1_y, a1_z],\n", 281 | " [a2_x, a2_y, a2_z],\n", 282 | " [a3_x, a3_y, a3_z],\n", 283 | " [b1_x, b1_y, b1_z] @ Q,\n", 284 | " [b2_x, b2_y, b2_z] @ Q,\n", 285 | " [b3_x, b3_y, b3_z] @ Q,\n", 286 | " [c_x, +c_y, c_z],\n", 287 | " ]) #.to(dtype=torch.float64)\n", 288 | " y = torch.LongTensor([0]) # Label 0\n", 289 | " data1 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n", 290 | " data1.edge_index = to_undirected(data1.edge_index)\n", 291 | " dataset.append(data1)\n", 292 | " \n", 293 | " # Environment 1\n", 294 | " atoms = torch.LongTensor([ 0, 0, 0, 0, 0, 0, 0, 0 ])\n", 295 | " edge_index = torch.LongTensor([ [0, 0, 0, 0, 0, 0, 0], [1, 2, 3, 4, 5, 6, 7] ])\n", 296 | " pos = torch.FloatTensor([ \n", 297 | " [0, 0, 0],\n", 298 | " [a1_x, a1_y, a1_z],\n", 299 | " [a2_x, a2_y, a2_z],\n", 300 | " [a3_x, a3_y, a3_z],\n", 301 | " [b1_x, b1_y, b1_z] @ Q,\n", 302 | " [b2_x, b2_y, b2_z] @ 
Q,\n", 303 | " [b3_x, b3_y, b3_z] @ Q,\n", 304 | " [c_x, -c_y, c_z],\n", 305 | " ]) #.to(dtype=torch.float64)\n", 306 | " y = torch.LongTensor([1]) # Label 1\n", 307 | " data2 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n", 308 | " data2.edge_index = to_undirected(data2.edge_index)\n", 309 | " dataset.append(data2)\n", 310 | " \n", 311 | " return dataset\n", 312 | "\n", 313 | "# Create dataset\n", 314 | "dataset = create_four_body_nonchiral_envs()\n", 315 | "for data in dataset:\n", 316 | " plot_3d(data, lim=5)\n", 317 | "\n", 318 | "# Create dataloaders\n", 319 | "dataloader = DataLoader(dataset, batch_size=1, shuffle=True)\n", 320 | "val_loader = DataLoader(dataset, batch_size=1, shuffle=False)\n", 321 | "test_loader = DataLoader(dataset, batch_size=1, shuffle=False)" 322 | ] 323 | }, 324 | { 325 | "cell_type": "code", 326 | "execution_count": null, 327 | "metadata": {}, 328 | "outputs": [], 329 | "source": [ 330 | "# Set model\n", 331 | "model_name = \"mace\"\n", 332 | "\n", 333 | "correlation = 4\n", 334 | "model = {\n", 335 | " \"schnet\": SchNetModel,\n", 336 | " \"dimenet\": DimeNetPPModel,\n", 337 | " \"spherenet\": SphereNetModel,\n", 338 | " \"egnn\": EGNNModel,\n", 339 | " \"gvp\": GVPGNNModel,\n", 340 | " \"tfn\": TFNModel,\n", 341 | " \"mace\": partial(MACEModel, correlation=correlation),\n", 342 | "}[model_name](num_layers=1, in_dim=1, out_dim=2)\n", 343 | "\n", 344 | "best_val_acc, test_acc, train_time = run_experiment(\n", 345 | " model, \n", 346 | " dataloader,\n", 347 | " val_loader, \n", 348 | " test_loader,\n", 349 | " n_epochs=100,\n", 350 | " n_times=10,\n", 351 | " device=device,\n", 352 | " verbose=False\n", 353 | ")" 354 | ] 355 | }, 356 | { 357 | "attachments": {}, 358 | "cell_type": "markdown", 359 | "metadata": {}, 360 | "source": [ 361 | "# Four-body chiral counterexample\n", 362 | "\n", 363 | "Pair of local neighbourhoods that are indistinguishable when comparing their set of $4$-body scalars when considering 
chirality/handedness, i.e. the unordered set of pairwise distances, angles, and quadruplet scalars." 364 | ] 365 | }, 366 | { 367 | "cell_type": "code", 368 | "execution_count": null, 369 | "metadata": {}, 370 | "outputs": [], 371 | "source": [ 372 | "def create_four_body_chiral_envs():\n", 373 | " dataset = []\n", 374 | "\n", 375 | " a1_x, a1_y, a1_z = 3, 0, -4\n", 376 | " a2_x, a2_y, a2_z = 0, 0, 5\n", 377 | " a3_x, a3_y, a3_z = 0, 0, -5\n", 378 | " c_x, c_y, c_z = 0, 5, 0\n", 379 | "\n", 380 | " # Environment 0\n", 381 | " atoms = torch.LongTensor([ 0, 0, 0, 0, 0 ])\n", 382 | " edge_index = torch.LongTensor([ [0, 0, 0, 0], [1, 2, 3, 4] ])\n", 383 | " pos = torch.FloatTensor([ \n", 384 | " [0, 0, 0],\n", 385 | " [a1_x, a1_y, a1_z],\n", 386 | " [a2_x, a2_y, a2_z],\n", 387 | " [a3_x, a3_y, a3_z],\n", 388 | " [c_x, +c_y, c_z],\n", 389 | " ])\n", 390 | " y = torch.LongTensor([0]) # Label 0\n", 391 | " data1 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n", 392 | " data1.edge_index = to_undirected(data1.edge_index)\n", 393 | " dataset.append(data1)\n", 394 | " \n", 395 | " # Environment 1\n", 396 | " atoms = torch.LongTensor([ 0, 0, 0, 0, 0 ])\n", 397 | " edge_index = torch.LongTensor([ [0, 0, 0, 0], [1, 2, 3, 4] ])\n", 398 | " pos = torch.FloatTensor([ \n", 399 | " [0, 0, 0],\n", 400 | " [a1_x, a1_y, a1_z],\n", 401 | " [a2_x, a2_y, a2_z],\n", 402 | " [a3_x, a3_y, a3_z],\n", 403 | " [c_x, -c_y, c_z],\n", 404 | " ])\n", 405 | " y = torch.LongTensor([1]) # Label 1\n", 406 | " data2 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n", 407 | " data2.edge_index = to_undirected(data2.edge_index)\n", 408 | " dataset.append(data2)\n", 409 | " \n", 410 | " return dataset\n", 411 | "\n", 412 | "# Create dataset\n", 413 | "dataset = create_four_body_chiral_envs()\n", 414 | "for data in dataset:\n", 415 | " plot_3d(data, lim=5)\n", 416 | "\n", 417 | "# Create dataloaders\n", 418 | "dataloader = DataLoader(dataset, batch_size=1, shuffle=True)\n", 419 | 
"val_loader = DataLoader(dataset, batch_size=2, shuffle=False)\n", 420 | "test_loader = DataLoader(dataset, batch_size=2, shuffle=False)" 421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "execution_count": null, 426 | "metadata": {}, 427 | "outputs": [], 428 | "source": [ 429 | "# Set model\n", 430 | "model_name = \"tfn\"\n", 431 | "\n", 432 | "correlation = 4\n", 433 | "model = {\n", 434 | " \"schnet\": SchNetModel,\n", 435 | " \"dimenet\": DimeNetPPModel,\n", 436 | " \"spherenet\": SphereNetModel,\n", 437 | " \"egnn\": EGNNModel,\n", 438 | " \"gvp\": GVPGNNModel,\n", 439 | " \"tfn\": partial(TFNModel, hidden_irreps=e3nn.o3.Irreps(f'64x0e + 64x0o + 64x1e + 64x1o + 64x2e + 64x2o')),\n", 440 | " \"mace\": partial(MACEModel, correlation=correlation, hidden_irreps=e3nn.o3.Irreps(f'32x0e + 32x0o + 32x1e + 32x1o + 32x2e + 32x2o')),\n", 441 | "}[model_name](num_layers=1, in_dim=1, out_dim=2)\n", 442 | "\n", 443 | "best_val_acc, test_acc, train_time = run_experiment(\n", 444 | " model, \n", 445 | " dataloader,\n", 446 | " val_loader, \n", 447 | " test_loader,\n", 448 | " n_epochs=100,\n", 449 | " n_times=10,\n", 450 | " device=device,\n", 451 | " verbose=False\n", 452 | ")" 453 | ] 454 | } 455 | ], 456 | "metadata": { 457 | "kernelspec": { 458 | "display_name": "Python 3", 459 | "language": "python", 460 | "name": "python3" 461 | }, 462 | "language_info": { 463 | "codemirror_mode": { 464 | "name": "ipython", 465 | "version": 3 466 | }, 467 | "file_extension": ".py", 468 | "mimetype": "text/x-python", 469 | "name": "python", 470 | "nbconvert_exporter": "python", 471 | "pygments_lexer": "ipython3", 472 | "version": "3.8.16" 473 | }, 474 | "orig_nbformat": 4, 475 | "vscode": { 476 | "interpreter": { 477 | "hash": "94aa676993820a604ac86f7af94f5432e989a749d5dd43e18f9507de2e8c2897" 478 | } 479 | } 480 | }, 481 | "nbformat": 4, 482 | "nbformat_minor": 2 483 | } 484 | -------------------------------------------------------------------------------- 
/models/mace_modules/blocks.py: -------------------------------------------------------------------------------- 1 | ########################################################################################### 2 | # Elementary Block for Building O(3) Equivariant Higher Order Message Passing Neural Network 3 | # Authors: Ilyes Batatia, Gregor Simm 4 | # This program is distributed under the MIT License (see MIT.md) 5 | ########################################################################################### 6 | 7 | from abc import ABC, abstractmethod 8 | from typing import Callable, Dict, Optional, Tuple, Union 9 | 10 | import numpy as np 11 | import torch.nn.functional 12 | from e3nn import nn, o3 13 | 14 | # from mace.tools.scatter import scatter_sum 15 | from torch_scatter import scatter_sum 16 | 17 | from .irreps_tools import ( 18 | linear_out_irreps, 19 | reshape_irreps, 20 | tp_out_irreps_with_instructions, 21 | ) 22 | from .radial import BesselBasis, PolynomialCutoff 23 | from .symmetric_contraction import SymmetricContraction 24 | 25 | 26 | class LinearNodeEmbeddingBlock(torch.nn.Module): 27 | def __init__(self, irreps_in: o3.Irreps, irreps_out: o3.Irreps): 28 | super().__init__() 29 | self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=irreps_out) 30 | 31 | def forward( 32 | self, node_attrs: torch.Tensor, # [n_nodes, irreps] 33 | ): 34 | return self.linear(node_attrs) 35 | 36 | 37 | class LinearReadoutBlock(torch.nn.Module): 38 | def __init__(self, irreps_in: o3.Irreps, irreps_out: o3.Irreps = o3.Irreps("0e")): 39 | super().__init__() 40 | self.linear = o3.Linear(irreps_in=irreps_in, irreps_out=irreps_out) 41 | 42 | def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] 43 | return self.linear(x) # [n_nodes, irreps_out] 44 | 45 | 46 | class NonLinearReadoutBlock(torch.nn.Module): 47 | def __init__( 48 | self, irreps_in: o3.Irreps, MLP_irreps: o3.Irreps, 49 | gate: Optional[Callable], irreps_out: o3.Irreps = o3.Irreps("0e") 
50 | ): 51 | super().__init__() 52 | self.hidden_irreps = MLP_irreps 53 | self.linear_1 = o3.Linear(irreps_in=irreps_in, irreps_out=self.hidden_irreps) 54 | self.non_linearity = nn.Activation(irreps_in=self.hidden_irreps, acts=[gate]) 55 | self.linear_2 = o3.Linear(irreps_in=self.hidden_irreps, irreps_out=irreps_out) 56 | 57 | def forward(self, x: torch.Tensor) -> torch.Tensor: # [n_nodes, irreps] # [..., ] 58 | x = self.non_linearity(self.linear_1(x)) 59 | return self.linear_2(x) # [n_nodes, irreps_out] 60 | 61 | 62 | class AtomicEnergiesBlock(torch.nn.Module): 63 | atomic_energies: torch.Tensor 64 | 65 | def __init__(self, atomic_energies: Union[np.ndarray, torch.Tensor]): 66 | super().__init__() 67 | assert len(atomic_energies.shape) == 1 68 | 69 | self.register_buffer( 70 | "atomic_energies", 71 | torch.tensor(atomic_energies, dtype=torch.get_default_dtype()), 72 | ) # [n_elements, ] 73 | 74 | def forward( 75 | self, x: torch.Tensor # one-hot of elements [..., n_elements] 76 | ) -> torch.Tensor: # [..., ] 77 | return torch.matmul(x, self.atomic_energies) 78 | 79 | def __repr__(self): 80 | formatted_energies = ", ".join([f"{x:.4f}" for x in self.atomic_energies]) 81 | return f"{self.__class__.__name__}(energies=[{formatted_energies}])" 82 | 83 | 84 | class RadialEmbeddingBlock(torch.nn.Module): 85 | def __init__(self, r_max: float, num_bessel: int, num_polynomial_cutoff: int): 86 | super().__init__() 87 | self.bessel_fn = BesselBasis(r_max=r_max, num_basis=num_bessel) 88 | self.cutoff_fn = PolynomialCutoff(r_max=r_max, p=num_polynomial_cutoff) 89 | self.out_dim = num_bessel 90 | 91 | def forward( 92 | self, edge_lengths: torch.Tensor, # [n_edges, 1] 93 | ): 94 | bessel = self.bessel_fn(edge_lengths) # [n_edges, n_basis] 95 | cutoff = self.cutoff_fn(edge_lengths) # [n_edges, 1] 96 | return bessel * cutoff # [n_edges, n_basis] 97 | 98 | 99 | class EquivariantProductBasisBlock(torch.nn.Module): 100 | def __init__( 101 | self, 102 | node_feats_irreps: o3.Irreps, 103 
| target_irreps: o3.Irreps, 104 | correlation: Union[int, Dict[str, int]], 105 | element_dependent: bool = True, 106 | use_sc: bool = True, 107 | batch_norm: bool = False, 108 | num_elements: Optional[int] = None, 109 | ) -> None: 110 | super().__init__() 111 | 112 | self.use_sc = use_sc 113 | self.symmetric_contractions = SymmetricContraction( 114 | irreps_in=node_feats_irreps, 115 | irreps_out=target_irreps, 116 | correlation=correlation, 117 | element_dependent=element_dependent, 118 | num_elements=num_elements, 119 | ) 120 | # Update linear 121 | self.linear = o3.Linear( 122 | target_irreps, target_irreps, internal_weights=True, shared_weights=True, 123 | ) 124 | self.batch_norm = nn.BatchNorm(target_irreps) if batch_norm else None 125 | 126 | def forward( 127 | self, node_feats: torch.Tensor, sc: Optional[torch.Tensor], node_attrs: Optional[torch.Tensor] 128 | ) -> torch.Tensor: 129 | node_feats = self.symmetric_contractions(node_feats, node_attrs) 130 | out = self.linear(node_feats) 131 | if self.batch_norm: 132 | out = self.batch_norm(out) 133 | if self.use_sc: 134 | out = out + sc 135 | return out 136 | 137 | 138 | class InteractionBlock(ABC, torch.nn.Module): 139 | def __init__( 140 | self, 141 | node_attrs_irreps: o3.Irreps, 142 | node_feats_irreps: o3.Irreps, 143 | edge_attrs_irreps: o3.Irreps, 144 | edge_feats_irreps: o3.Irreps, 145 | target_irreps: o3.Irreps, 146 | hidden_irreps: o3.Irreps, 147 | avg_num_neighbors: float, 148 | ) -> None: 149 | super().__init__() 150 | self.node_attrs_irreps = node_attrs_irreps 151 | self.node_feats_irreps = node_feats_irreps 152 | self.edge_attrs_irreps = edge_attrs_irreps 153 | self.edge_feats_irreps = edge_feats_irreps 154 | self.target_irreps = target_irreps 155 | self.hidden_irreps = hidden_irreps 156 | self.avg_num_neighbors = avg_num_neighbors 157 | 158 | self._setup() 159 | 160 | @abstractmethod 161 | def _setup(self) -> None: 162 | raise NotImplementedError 163 | 164 | @abstractmethod 165 | def forward( 166 | 
self, 167 | node_attrs: torch.Tensor, 168 | node_feats: torch.Tensor, 169 | edge_attrs: torch.Tensor, 170 | edge_feats: torch.Tensor, 171 | edge_index: torch.Tensor, 172 | ) -> torch.Tensor: 173 | raise NotImplementedError 174 | 175 | 176 | nonlinearities = {1: torch.nn.SiLU(), -1: torch.nn.Tanh()} 177 | 178 | 179 | class TensorProductWeightsBlock(torch.nn.Module): 180 | def __init__(self, num_elements: int, num_edge_feats: int, num_feats_out: int): 181 | super().__init__() 182 | 183 | weights = torch.empty( 184 | (num_elements, num_edge_feats, num_feats_out), 185 | dtype=torch.get_default_dtype(), 186 | ) 187 | torch.nn.init.xavier_uniform_(weights) 188 | self.weights = torch.nn.Parameter(weights) 189 | 190 | def forward( 191 | self, 192 | sender_or_receiver_node_attrs: torch.Tensor, # assumes that the node attributes are one-hot encoded 193 | edge_feats: torch.Tensor, 194 | ): 195 | return torch.einsum( 196 | "be, ba, aek -> bk", edge_feats, sender_or_receiver_node_attrs, self.weights 197 | ) 198 | 199 | def __repr__(self): 200 | return ( 201 | f'{self.__class__.__name__}(shape=({", ".join(str(s) for s in self.weights.shape)}), ' 202 | f"weights={np.prod(self.weights.shape)})" 203 | ) 204 | 205 | 206 | class ResidualElementDependentInteractionBlock(InteractionBlock): 207 | def _setup(self) -> None: 208 | self.linear_up = o3.Linear( 209 | self.node_feats_irreps, 210 | self.node_feats_irreps, 211 | internal_weights=True, 212 | shared_weights=True, 213 | ) 214 | # TensorProduct 215 | irreps_mid, instructions = tp_out_irreps_with_instructions( 216 | self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps 217 | ) 218 | self.conv_tp = o3.TensorProduct( 219 | self.node_feats_irreps, 220 | self.edge_attrs_irreps, 221 | irreps_mid, 222 | instructions=instructions, 223 | shared_weights=False, 224 | internal_weights=False, 225 | ) 226 | self.conv_tp_weights = TensorProductWeightsBlock( 227 | num_elements=self.node_attrs_irreps.num_irreps, 228 | 
num_edge_feats=self.edge_feats_irreps.num_irreps, 229 | num_feats_out=self.conv_tp.weight_numel, 230 | ) 231 | 232 | # Linear 233 | irreps_mid = irreps_mid.simplify() 234 | self.irreps_out = linear_out_irreps(irreps_mid, self.target_irreps) 235 | self.irreps_out = self.irreps_out.simplify() 236 | self.linear = o3.Linear( 237 | irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True 238 | ) 239 | 240 | # Selector TensorProduct 241 | self.skip_tp = o3.FullyConnectedTensorProduct( 242 | self.node_feats_irreps, self.node_attrs_irreps, self.irreps_out 243 | ) 244 | 245 | def forward( 246 | self, 247 | node_attrs: torch.Tensor, 248 | node_feats: torch.Tensor, 249 | edge_attrs: torch.Tensor, 250 | edge_feats: torch.Tensor, 251 | edge_index: torch.Tensor, 252 | ) -> torch.Tensor: 253 | sender, receiver = edge_index 254 | num_nodes = node_feats.shape[0] 255 | sc = self.skip_tp(node_feats, node_attrs) 256 | node_feats = self.linear_up(node_feats) 257 | tp_weights = self.conv_tp_weights(node_attrs[sender], edge_feats) 258 | mji = self.conv_tp( 259 | node_feats[sender], edge_attrs, tp_weights 260 | ) # [n_edges, irreps] 261 | message = scatter_sum( 262 | src=mji, index=receiver, dim=0, dim_size=num_nodes 263 | ) # [n_nodes, irreps] 264 | message = self.linear(message) / self.avg_num_neighbors 265 | return message + sc # [n_nodes, irreps] 266 | 267 | 268 | class AgnosticNonlinearInteractionBlock(InteractionBlock): 269 | def _setup(self) -> None: 270 | self.linear_up = o3.Linear( 271 | self.node_feats_irreps, 272 | self.node_feats_irreps, 273 | internal_weights=True, 274 | shared_weights=True, 275 | ) 276 | # TensorProduct 277 | irreps_mid, instructions = tp_out_irreps_with_instructions( 278 | self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps 279 | ) 280 | self.conv_tp = o3.TensorProduct( 281 | self.node_feats_irreps, 282 | self.edge_attrs_irreps, 283 | irreps_mid, 284 | instructions=instructions, 285 | shared_weights=False, 286 | 
internal_weights=False, 287 | ) 288 | 289 | # Convolution weights 290 | input_dim = self.edge_feats_irreps.num_irreps 291 | self.conv_tp_weights = nn.FullyConnectedNet( 292 | [input_dim] + 3 * [64] + [self.conv_tp.weight_numel], torch.nn.SiLU(), 293 | ) 294 | 295 | # Linear 296 | irreps_mid = irreps_mid.simplify() 297 | self.irreps_out = linear_out_irreps(irreps_mid, self.target_irreps) 298 | self.irreps_out = self.irreps_out.simplify() 299 | self.linear = o3.Linear( 300 | irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True 301 | ) 302 | 303 | # Selector TensorProduct 304 | self.skip_tp = o3.FullyConnectedTensorProduct( 305 | self.irreps_out, self.node_attrs_irreps, self.irreps_out 306 | ) 307 | 308 | def forward( 309 | self, 310 | node_attrs: torch.Tensor, 311 | node_feats: torch.Tensor, 312 | edge_attrs: torch.Tensor, 313 | edge_feats: torch.Tensor, 314 | edge_index: torch.Tensor, 315 | ) -> torch.Tensor: 316 | sender, receiver = edge_index 317 | num_nodes = node_feats.shape[0] 318 | tp_weights = self.conv_tp_weights(edge_feats) 319 | node_feats = self.linear_up(node_feats) 320 | mji = self.conv_tp( 321 | node_feats[sender], edge_attrs, tp_weights 322 | ) # [n_edges, irreps] 323 | message = scatter_sum( 324 | src=mji, index=receiver, dim=0, dim_size=num_nodes 325 | ) # [n_nodes, irreps] 326 | message = self.linear(message) / self.avg_num_neighbors 327 | message = self.skip_tp(message, node_attrs) 328 | return message # [n_nodes, irreps] 329 | 330 | 331 | class AgnosticResidualNonlinearInteractionBlock(InteractionBlock): 332 | def _setup(self) -> None: 333 | # First linear 334 | self.linear_up = o3.Linear( 335 | self.node_feats_irreps, 336 | self.node_feats_irreps, 337 | internal_weights=True, 338 | shared_weights=True, 339 | ) 340 | # TensorProduct 341 | irreps_mid, instructions = tp_out_irreps_with_instructions( 342 | self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps 343 | ) 344 | self.conv_tp = o3.TensorProduct( 345 | 
self.node_feats_irreps, 346 | self.edge_attrs_irreps, 347 | irreps_mid, 348 | instructions=instructions, 349 | shared_weights=False, 350 | internal_weights=False, 351 | ) 352 | 353 | # Convolution weights 354 | input_dim = self.edge_feats_irreps.num_irreps 355 | self.conv_tp_weights = nn.FullyConnectedNet( 356 | [input_dim] + 3 * [64] + [self.conv_tp.weight_numel], torch.nn.SiLU(), 357 | ) 358 | 359 | # Linear 360 | irreps_mid = irreps_mid.simplify() 361 | self.irreps_out = linear_out_irreps(irreps_mid, self.target_irreps) 362 | self.irreps_out = self.irreps_out.simplify() 363 | self.linear = o3.Linear( 364 | irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True 365 | ) 366 | 367 | # Selector TensorProduct 368 | self.skip_tp = o3.FullyConnectedTensorProduct( 369 | self.node_feats_irreps, self.node_attrs_irreps, self.irreps_out 370 | ) 371 | 372 | def forward( 373 | self, 374 | node_attrs: torch.Tensor, 375 | node_feats: torch.Tensor, 376 | edge_attrs: torch.Tensor, 377 | edge_feats: torch.Tensor, 378 | edge_index: torch.Tensor, 379 | ) -> torch.Tensor: 380 | sender, receiver = edge_index 381 | num_nodes = node_feats.shape[0] 382 | sc = self.skip_tp(node_feats, node_attrs) 383 | node_feats = self.linear_up(node_feats) 384 | tp_weights = self.conv_tp_weights(edge_feats) 385 | mji = self.conv_tp( 386 | node_feats[sender], edge_attrs, tp_weights 387 | ) # [n_edges, irreps] 388 | message = scatter_sum( 389 | src=mji, index=receiver, dim=0, dim_size=num_nodes 390 | ) # [n_nodes, irreps] 391 | message = self.linear(message) / self.avg_num_neighbors 392 | message = message + sc 393 | return message # [n_nodes, irreps] 394 | 395 | 396 | class RealAgnosticInteractionBlock(InteractionBlock): 397 | def _setup(self) -> None: 398 | # First linear 399 | self.linear_up = o3.Linear( 400 | self.node_feats_irreps, 401 | self.node_feats_irreps, 402 | internal_weights=True, 403 | shared_weights=True, 404 | ) 405 | # TensorProduct 406 | irreps_mid, instructions = 
tp_out_irreps_with_instructions( 407 | self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps, 408 | ) 409 | self.conv_tp = o3.TensorProduct( 410 | self.node_feats_irreps, 411 | self.edge_attrs_irreps, 412 | irreps_mid, 413 | instructions=instructions, 414 | shared_weights=False, 415 | internal_weights=False, 416 | ) 417 | 418 | # Convolution weights 419 | input_dim = self.edge_feats_irreps.num_irreps 420 | self.conv_tp_weights = nn.FullyConnectedNet( 421 | [input_dim] + 3 * [64] + [self.conv_tp.weight_numel], torch.nn.SiLU(), 422 | ) 423 | 424 | # Linear 425 | irreps_mid = irreps_mid.simplify() 426 | self.irreps_out = self.target_irreps 427 | self.linear = o3.Linear( 428 | irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True 429 | ) 430 | 431 | # Selector TensorProduct 432 | self.skip_tp = o3.FullyConnectedTensorProduct( 433 | self.irreps_out, self.node_attrs_irreps, self.irreps_out 434 | ) 435 | self.reshape = reshape_irreps(self.irreps_out) 436 | 437 | def forward( 438 | self, 439 | node_attrs: torch.Tensor, 440 | node_feats: torch.Tensor, 441 | edge_attrs: torch.Tensor, 442 | edge_feats: torch.Tensor, 443 | edge_index: torch.Tensor, 444 | ) -> Tuple[torch.Tensor, torch.Tensor]: 445 | sender, receiver = edge_index 446 | num_nodes = node_feats.shape[0] 447 | 448 | node_feats = self.linear_up(node_feats) 449 | tp_weights = self.conv_tp_weights(edge_feats) 450 | mji = self.conv_tp( 451 | node_feats[sender], edge_attrs, tp_weights 452 | ) # [n_edges, irreps] 453 | message = scatter_sum( 454 | src=mji, index=receiver, dim=0, dim_size=num_nodes 455 | ) # [n_nodes, irreps] 456 | message = self.linear(message) / self.avg_num_neighbors 457 | message = self.skip_tp(message, node_attrs) 458 | return ( 459 | self.reshape(message), 460 | None, 461 | ) # [n_nodes, channels, (lmax + 1)**2] 462 | 463 | 464 | class RealAgnosticResidualInteractionBlock(InteractionBlock): 465 | def _setup(self) -> None: 466 | 467 | # First linear 468 | self.linear_up 
= o3.Linear( 469 | self.node_feats_irreps, 470 | self.node_feats_irreps, 471 | internal_weights=True, 472 | shared_weights=True, 473 | ) 474 | # TensorProduct 475 | irreps_mid, instructions = tp_out_irreps_with_instructions( 476 | self.node_feats_irreps, self.edge_attrs_irreps, self.target_irreps, 477 | ) 478 | self.conv_tp = o3.TensorProduct( 479 | self.node_feats_irreps, 480 | self.edge_attrs_irreps, 481 | irreps_mid, 482 | instructions=instructions, 483 | shared_weights=False, 484 | internal_weights=False, 485 | ) 486 | 487 | # Convolution weights 488 | input_dim = self.edge_feats_irreps.num_irreps 489 | self.conv_tp_weights = nn.FullyConnectedNet( 490 | [input_dim] + 3 * [64] + [self.conv_tp.weight_numel], torch.nn.SiLU(), 491 | ) 492 | 493 | # Linear 494 | irreps_mid = irreps_mid.simplify() 495 | self.irreps_out = self.target_irreps 496 | self.linear = o3.Linear( 497 | irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True 498 | ) 499 | 500 | # Selector TensorProduct 501 | self.skip_tp = o3.FullyConnectedTensorProduct( 502 | self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps 503 | ) 504 | self.reshape = reshape_irreps(self.irreps_out) 505 | 506 | def forward( 507 | self, 508 | node_attrs: torch.Tensor, 509 | node_feats: torch.Tensor, 510 | edge_attrs: torch.Tensor, 511 | edge_feats: torch.Tensor, 512 | edge_index: torch.Tensor, 513 | ) -> Tuple[torch.Tensor, torch.Tensor]: 514 | sender, receiver = edge_index 515 | num_nodes = node_feats.shape[0] 516 | 517 | sc = self.skip_tp(node_feats, node_attrs) 518 | node_feats = self.linear_up(node_feats) 519 | tp_weights = self.conv_tp_weights(edge_feats) 520 | mji = self.conv_tp( 521 | node_feats[sender], edge_attrs, tp_weights 522 | ) # [n_edges, irreps] 523 | message = scatter_sum( 524 | src=mji, index=receiver, dim=0, dim_size=num_nodes 525 | ) # [n_nodes, irreps] 526 | message = self.linear(message) / self.avg_num_neighbors 527 | return ( 528 | self.reshape(message), 529 | sc, 530 | 
) # [n_nodes, channels, (lmax + 1)**2] 531 | 532 | 533 | class ScaleShiftBlock(torch.nn.Module): 534 | def __init__(self, scale: float, shift: float): 535 | super().__init__() 536 | self.register_buffer( 537 | "scale", torch.tensor(scale, dtype=torch.get_default_dtype()) 538 | ) 539 | self.register_buffer( 540 | "shift", torch.tensor(shift, dtype=torch.get_default_dtype()) 541 | ) 542 | 543 | def forward(self, x: torch.Tensor) -> torch.Tensor: 544 | return self.scale * x + self.shift 545 | 546 | def __repr__(self): 547 | return ( 548 | f"{self.__class__.__name__}(scale={self.scale:.6f}, shift={self.shift:.6f})" 549 | ) 550 | -------------------------------------------------------------------------------- /models/layers/spherenet_layer.py: -------------------------------------------------------------------------------- 1 | ################################################################ 2 | # Implementation of SphereNet layers 3 | # 4 | # Paper: Spherical Message Passing for 3D Graph Networks 5 | # by Y Liu, L Wang, M Liu, X Zhang, B Oztekin, and S Ji 6 | # 7 | # Orginal repository: https://github.com/divelab/DIG 8 | ################################################################ 9 | 10 | import torch 11 | from torch import nn 12 | from torch.nn import Linear, Embedding 13 | from torch_geometric.nn.inits import glorot_orthogonal 14 | from torch_scatter import scatter 15 | from math import sqrt 16 | 17 | import numpy as np 18 | from scipy.optimize import brentq 19 | from scipy import special as sp 20 | import torch 21 | from math import pi as PI 22 | 23 | import sympy as sym 24 | 25 | import torch 26 | from torch_scatter import scatter 27 | from torch_sparse import SparseTensor 28 | from math import pi as PI 29 | 30 | def swish(x): 31 | return x * torch.sigmoid(x) 32 | 33 | class emb(torch.nn.Module): 34 | def __init__(self, num_spherical, num_radial, cutoff, envelope_exponent): 35 | super(emb, self).__init__() 36 | self.dist_emb = dist_emb(num_radial, cutoff, 
envelope_exponent) 37 | self.angle_emb = angle_emb(num_spherical, num_radial, cutoff, envelope_exponent) 38 | self.torsion_emb = torsion_emb(num_spherical, num_radial, cutoff, envelope_exponent) 39 | self.reset_parameters() 40 | 41 | def reset_parameters(self): 42 | self.dist_emb.reset_parameters() 43 | 44 | def forward(self, dist, angle, torsion, idx_kj): 45 | dist_emb = self.dist_emb(dist) 46 | angle_emb = self.angle_emb(dist, angle, idx_kj) 47 | torsion_emb = self.torsion_emb(dist, angle, torsion, idx_kj) 48 | return dist_emb, angle_emb, torsion_emb 49 | 50 | class ResidualLayer(torch.nn.Module): 51 | def __init__(self, hidden_channels, act=swish): 52 | super(ResidualLayer, self).__init__() 53 | self.act = act 54 | self.lin1 = Linear(hidden_channels, hidden_channels) 55 | self.lin2 = Linear(hidden_channels, hidden_channels) 56 | 57 | self.reset_parameters() 58 | 59 | def reset_parameters(self): 60 | glorot_orthogonal(self.lin1.weight, scale=2.0) 61 | self.lin1.bias.data.fill_(0) 62 | glorot_orthogonal(self.lin2.weight, scale=2.0) 63 | self.lin2.bias.data.fill_(0) 64 | 65 | def forward(self, x): 66 | return x + self.act(self.lin2(self.act(self.lin1(x)))) 67 | 68 | 69 | class init(torch.nn.Module): 70 | def __init__(self, num_radial, hidden_channels, act=swish, use_node_features=True): 71 | super(init, self).__init__() 72 | self.act = act 73 | self.use_node_features = use_node_features 74 | if self.use_node_features: 75 | self.emb = Embedding(95, hidden_channels) 76 | else: # option to use no node features and a learned embedding vector for each node instead 77 | self.node_embedding = nn.Parameter(torch.empty((hidden_channels,))) 78 | nn.init.normal_(self.node_embedding) 79 | self.lin_rbf_0 = Linear(num_radial, hidden_channels) 80 | self.lin = Linear(3 * hidden_channels, hidden_channels) 81 | self.lin_rbf_1 = nn.Linear(num_radial, hidden_channels, bias=False) 82 | self.reset_parameters() 83 | 84 | def reset_parameters(self): 85 | if self.use_node_features: 86 | 
self.emb.weight.data.uniform_(-sqrt(3), sqrt(3)) 87 | self.lin_rbf_0.reset_parameters() 88 | self.lin.reset_parameters() 89 | glorot_orthogonal(self.lin_rbf_1.weight, scale=2.0) 90 | 91 | def forward(self, x, emb, i, j): 92 | rbf,_,_ = emb 93 | if self.use_node_features: 94 | x = self.emb(x) 95 | else: 96 | x = self.node_embedding[None, :].expand(x.shape[0], -1) 97 | rbf0 = self.act(self.lin_rbf_0(rbf)) 98 | e1 = self.act(self.lin(torch.cat([x[i], x[j], rbf0], dim=-1))) 99 | e2 = self.lin_rbf_1(rbf) * e1 100 | 101 | return e1, e2 102 | 103 | 104 | class update_e(torch.nn.Module): 105 | def __init__(self, hidden_channels, int_emb_size, basis_emb_size_dist, basis_emb_size_angle, basis_emb_size_torsion, num_spherical, num_radial, 106 | num_before_skip, num_after_skip, act=swish): 107 | super(update_e, self).__init__() 108 | self.act = act 109 | self.lin_rbf1 = nn.Linear(num_radial, basis_emb_size_dist, bias=False) 110 | self.lin_rbf2 = nn.Linear(basis_emb_size_dist, hidden_channels, bias=False) 111 | self.lin_sbf1 = nn.Linear(num_spherical * num_radial, basis_emb_size_angle, bias=False) 112 | self.lin_sbf2 = nn.Linear(basis_emb_size_angle, int_emb_size, bias=False) 113 | self.lin_t1 = nn.Linear(num_spherical * num_spherical * num_radial, basis_emb_size_torsion, bias=False) 114 | self.lin_t2 = nn.Linear(basis_emb_size_torsion, int_emb_size, bias=False) 115 | self.lin_rbf = nn.Linear(num_radial, hidden_channels, bias=False) 116 | 117 | self.lin_kj = nn.Linear(hidden_channels, hidden_channels) 118 | self.lin_ji = nn.Linear(hidden_channels, hidden_channels) 119 | 120 | self.lin_down = nn.Linear(hidden_channels, int_emb_size, bias=False) 121 | self.lin_up = nn.Linear(int_emb_size, hidden_channels, bias=False) 122 | 123 | self.layers_before_skip = torch.nn.ModuleList([ 124 | ResidualLayer(hidden_channels, act) 125 | for _ in range(num_before_skip) 126 | ]) 127 | self.lin = nn.Linear(hidden_channels, hidden_channels) 128 | self.layers_after_skip = torch.nn.ModuleList([ 129 | 
ResidualLayer(hidden_channels, act) 130 | for _ in range(num_after_skip) 131 | ]) 132 | 133 | self.reset_parameters() 134 | 135 | def reset_parameters(self): 136 | glorot_orthogonal(self.lin_rbf1.weight, scale=2.0) 137 | glorot_orthogonal(self.lin_rbf2.weight, scale=2.0) 138 | glorot_orthogonal(self.lin_sbf1.weight, scale=2.0) 139 | glorot_orthogonal(self.lin_sbf2.weight, scale=2.0) 140 | glorot_orthogonal(self.lin_t1.weight, scale=2.0) 141 | glorot_orthogonal(self.lin_t2.weight, scale=2.0) 142 | 143 | glorot_orthogonal(self.lin_kj.weight, scale=2.0) 144 | self.lin_kj.bias.data.fill_(0) 145 | glorot_orthogonal(self.lin_ji.weight, scale=2.0) 146 | self.lin_ji.bias.data.fill_(0) 147 | 148 | glorot_orthogonal(self.lin_down.weight, scale=2.0) 149 | glorot_orthogonal(self.lin_up.weight, scale=2.0) 150 | 151 | for res_layer in self.layers_before_skip: 152 | res_layer.reset_parameters() 153 | glorot_orthogonal(self.lin.weight, scale=2.0) 154 | self.lin.bias.data.fill_(0) 155 | for res_layer in self.layers_after_skip: 156 | res_layer.reset_parameters() 157 | 158 | glorot_orthogonal(self.lin_rbf.weight, scale=2.0) 159 | 160 | def forward(self, x, emb, idx_kj, idx_ji): 161 | rbf0, sbf, t = emb 162 | x1,_ = x 163 | 164 | x_ji = self.act(self.lin_ji(x1)) 165 | x_kj = self.act(self.lin_kj(x1)) 166 | 167 | rbf = self.lin_rbf1(rbf0) 168 | rbf = self.lin_rbf2(rbf) 169 | x_kj = x_kj * rbf 170 | 171 | x_kj = self.act(self.lin_down(x_kj)) 172 | 173 | sbf = self.lin_sbf1(sbf) 174 | sbf = self.lin_sbf2(sbf) 175 | x_kj = x_kj[idx_kj] * sbf 176 | 177 | t = self.lin_t1(t) 178 | t = self.lin_t2(t) 179 | x_kj = x_kj * t 180 | 181 | x_kj = scatter(x_kj, idx_ji, dim=0, dim_size=x1.size(0)) 182 | x_kj = self.act(self.lin_up(x_kj)) 183 | 184 | e1 = x_ji + x_kj 185 | for layer in self.layers_before_skip: 186 | e1 = layer(e1) 187 | e1 = self.act(self.lin(e1)) + x1 188 | for layer in self.layers_after_skip: 189 | e1 = layer(e1) 190 | e2 = self.lin_rbf(rbf0) * e1 191 | 192 | return e1, e2 193 | 194 
class update_v(torch.nn.Module):
    """Output block: pool edge features onto nodes, then apply an MLP.

    Args:
        hidden_channels: Size of the incoming edge features ``e2``.
        out_emb_channels: Width of the hidden layers of the output MLP.
        out_channels: Size of the per-node output.
        num_output_layers: Number of hidden ``Linear`` + ``act`` layers.
        act: Activation applied after every hidden layer.
        output_init: ``'zeros'`` or ``'GlorotOrthogonal'`` initialisation of the
            final (bias-free) projection; any other value keeps PyTorch's
            default initialisation.
    """

    def __init__(self, hidden_channels, out_emb_channels, out_channels, num_output_layers, act, output_init):
        super(update_v, self).__init__()
        self.act = act
        self.output_init = output_init

        self.lin_up = nn.Linear(hidden_channels, out_emb_channels, bias=True)
        self.lins = torch.nn.ModuleList()
        for _ in range(num_output_layers):
            self.lins.append(nn.Linear(out_emb_channels, out_emb_channels))
        self.lin = nn.Linear(out_emb_channels, out_channels, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        glorot_orthogonal(self.lin_up.weight, scale=2.0)
        for lin in self.lins:
            glorot_orthogonal(lin.weight, scale=2.0)
            lin.bias.data.fill_(0)
        if self.output_init == 'zeros':
            self.lin.weight.data.fill_(0)
        if self.output_init == 'GlorotOrthogonal':
            glorot_orthogonal(self.lin.weight, scale=2.0)

    def forward(self, e, i):
        """Return per-node outputs.

        Args:
            e: Pair ``(e1, e2)`` of edge features; only ``e2`` is used.
            i: Node index of every edge (presumably the receiving node of
                edge j -> i), used to pool ``e2`` onto nodes.
        """
        _, e2 = e
        # NOTE(review): no ``dim_size`` is passed, so the result has
        # ``i.max() + 1`` rows; nodes with a higher index than any edge
        # target would be missing. Confirm callers rule this out.
        v = scatter(e2, i, dim=0)
        v = self.lin_up(v)
        for lin in self.lins:
            v = self.act(lin(v))
        v = self.lin(v)
        return v


class update_u(torch.nn.Module):
    """Accumulate per-node outputs ``v`` into per-graph outputs ``u``."""

    def __init__(self):
        super(update_u, self).__init__()

    def forward(self, u, v, batch):
        # NOTE: ``u`` is updated in place (``+=``) and also returned, so the
        # caller's tensor is modified.
        u += scatter(v, batch, dim=0)
        return u

# Based on the code from: https://github.com/klicperajo/dimenet,
# https://github.com/rusty1s/pytorch_geometric/blob/master/torch_geometric/nn/models/dimenet_utils.py


def Jn(r, n):
    """Spherical Bessel function of the first kind j_n(r), via J_{n + 1/2}."""
    return np.sqrt(np.pi / (2 * r)) * sp.jv(n + 0.5, r)


def Jn_zeros(n, k):
    """Return the first ``k`` positive roots of j_0, ..., j_{n-1}.

    Roots of order ``i`` are located with ``brentq`` inside the brackets
    formed by consecutive roots of order ``i - 1``.

    Returns:
        float32 array of shape ``(n, k)``; row ``i`` holds the roots of j_i.
    """
    zerosj = np.zeros((n, k), dtype='float32')
    # Roots of j_0(x) = sin(x) / x are m * pi.
    zerosj[0] = np.arange(1, k + 1) * np.pi
    points = np.arange(1, k + n) * np.pi
    racines = np.zeros(k + n - 1, dtype='float32')
    for i in range(1, n):
        for j in range(k + n - 1 - i):
            racines[j] = brentq(Jn, points[j], points[j + 1], (i,))
        # NOTE: from here on ``points`` aliases ``racines``. That is safe,
        # because racines[j] is only written after points[j] and
        # points[j + 1] have been read for the bracket.
        points = racines
        zerosj[i][:k] = racines[:k]

    return zerosj


def spherical_bessel_formulas(n):
    """Return sympy expressions for j_0(x), ..., j_{n-1}(x).

    Uses Rayleigh's formula j_l(x) = (-x)^l (1/x d/dx)^l (sin x / x).
    """
    x = sym.symbols('x')

    f = [sym.sin(x) / x]
    a = sym.sin(x) / x
    for i in range(1, n):
        b = sym.diff(a, x) / x
        f += [sym.simplify(b * (-x)**i)]
        a = sym.simplify(b)
    return f


def bessel_basis(n, k):
    """Return normalised spherical-Bessel radial basis functions (sympy).

    Entry ``[l][i]`` is ``c_{l,i} * j_l(z_{l,i} * x)``, where ``z_{l,i}`` is
    the ``i``-th root of ``j_l`` and
    ``c_{l,i} = 1 / sqrt(0.5 * j_{l+1}(z_{l,i})**2)``.
    """
    zeros = Jn_zeros(n, k)
    normalizer = []
    for order in range(n):
        normalizer_tmp = []
        for i in range(k):
            normalizer_tmp += [0.5 * Jn(zeros[order, i], order + 1)**2]
        normalizer_tmp = 1 / np.array(normalizer_tmp)**0.5
        normalizer += [normalizer_tmp]

    f = spherical_bessel_formulas(n)
    x = sym.symbols('x')
    bess_basis = []
    for order in range(n):
        bess_basis_tmp = []
        for i in range(k):
            bess_basis_tmp += [
                sym.simplify(normalizer[order][i] *
                             f[order].subs(x, zeros[order, i] * x))
            ]
        bess_basis += [bess_basis_tmp]
    return bess_basis


def sph_harm_prefactor(k, m):
    """Normalisation constant of the spherical harmonic of degree ``k``, order ``m``.

    Returns sqrt((2k + 1) * (k - |m|)! / (4 * pi * (k + |m|)!)).
    """
    # ``np.math`` was only an alias of the stdlib ``math`` module and was
    # removed in NumPy 2.0, so call the stdlib directly.
    from math import factorial

    return ((2 * k + 1) * factorial(k - abs(m)) /
            (4 * np.pi * factorial(k + abs(m))))**0.5


def associated_legendre_polynomials(k, zero_m_only=True):
    """Return sympy expressions ``P_l_m[l][m]`` for ``0 <= m <= l < k`` in ``z``.

    With ``zero_m_only`` set, only the ``m = 0`` column is filled; other
    entries keep the integer placeholder ``0``. The ``m > 0`` recursion
    starts from ``P_m^m = (1 - 2m) P_{m-1}^{m-1}``. It therefore omits the
    ``(1 - z**2)**(m/2)`` factor, which the caller must supply.

    Returns an empty list for ``k == 0``.
    """
    z = sym.symbols('z')
    P_l_m = [[0] * (j + 1) for j in range(k)]
    if k == 0:
        return P_l_m

    P_l_m[0][0] = 1
    # Degree 1 only exists when k >= 2; previously ``k > 0`` raised
    # IndexError for k == 1.
    if k > 1:
        P_l_m[1][0] = z

    for j in range(2, k):
        P_l_m[j][0] = sym.simplify(((2 * j - 1) * z * P_l_m[j - 1][0] -
                                    (j - 1) * P_l_m[j - 2][0]) / j)
    if not zero_m_only:
        for i in range(1, k):
            P_l_m[i][i] = sym.simplify((1 - 2 * i) * P_l_m[i - 1][i - 1])
            if i + 1 < k:
                P_l_m[i + 1][i] = sym.simplify(
                    (2 * i + 1) * z * P_l_m[i][i])
            for j in range(i + 2, k):
                P_l_m[j][i] = sym.simplify(
                    ((2 * j - 1) * z * P_l_m[j - 1][i] -
                     (i + j - 1) * P_l_m[j - 2][i]) / (j - i))

    return P_l_m
def real_sph_harm(l, zero_m_only=False, spherical_coordinates=True):
    """
    Computes sympy expressions of the real spherical harmonics up to degree l (excluded).
    Variables are either cartesian coordinates x,y,z on the unit sphere or spherical coordinates phi and theta.

    Returns:
        A list ``Y`` where ``Y[i]`` has length ``2 * i + 1`` for each degree ``i < l``.
        ``Y[i][0]`` is the m = 0 harmonic, ``Y[i][j]`` (``1 <= j <= i``) is built from the
        cosine term ``C_j`` and ``Y[i][-j]`` from the sine term ``S_j``. When
        ``zero_m_only`` is true only ``Y[i][0]`` is computed; the other entries keep the
        placeholder string ``'0'``.
    """
    if not zero_m_only:
        x = sym.symbols('x')
        y = sym.symbols('y')
        # C_m + i*S_m == (x + i*y)**m, built one multiplication at a time.
        S_m = [x*0]
        C_m = [1+0*x]
        for i in range(1, l):
            S_m += [x*S_m[i-1] + y*C_m[i-1]]
            C_m += [x*C_m[i-1] - y*S_m[i-1]]

    P_l_m = associated_legendre_polynomials(l, zero_m_only)
    if spherical_coordinates:
        theta = sym.symbols('theta')
        z = sym.symbols('z')
        for row in P_l_m:
            for j, expr in enumerate(row):
                # Plain-int entries (the constant 1 and unfilled 0 slots) have no ``subs``.
                if not isinstance(expr, int):
                    row[j] = expr.subs(z, sym.cos(theta))
        if not zero_m_only:
            phi = sym.symbols('phi')
            for i in range(len(S_m)):
                S_m[i] = S_m[i].subs(x, sym.sin(theta)*sym.cos(phi)).subs(y, sym.sin(theta)*sym.sin(phi))
            for i in range(len(C_m)):
                C_m[i] = C_m[i].subs(x, sym.sin(theta)*sym.cos(phi)).subs(y, sym.sin(theta)*sym.sin(phi))

    Y_func_l_m = [['0']*(2*j + 1) for j in range(l)]
    for i in range(l):
        Y_func_l_m[i][0] = sym.simplify(sph_harm_prefactor(i, 0) * P_l_m[i][0])

    if not zero_m_only:
        for i in range(1, l):
            for j in range(1, i + 1):
                Y_func_l_m[i][j] = sym.simplify(
                    2**0.5 * sph_harm_prefactor(i, j) * C_m[j] * P_l_m[i][j])
        for i in range(1, l):
            for j in range(1, i + 1):
                Y_func_l_m[i][-j] = sym.simplify(
                    2**0.5 * sph_harm_prefactor(i, -j) * S_m[j] * P_l_m[i][j])

    return Y_func_l_m


class Envelope(torch.nn.Module):
    """Polynomial cutoff envelope evaluated at x = d / cutoff.

    u(x) = 1/x + a*x**(p-1) + b*x**p + c*x**(p+1) with p = exponent + 1.
    The coefficients are chosen so that u(1) == 0.
    """

    def __init__(self, exponent):
        super(Envelope, self).__init__()
        self.p = exponent + 1
        self.a = -(self.p + 1) * (self.p + 2) / 2
        self.b = self.p * (self.p + 2)
        self.c = -self.p * (self.p + 1) / 2

    def forward(self, x):
        p, a, b, c = self.p, self.a, self.b, self.c
        x_pow_p0 = x.pow(p - 1)
        x_pow_p1 = x_pow_p0 * x
        x_pow_p2 = x_pow_p1 * x
        return 1. / x + a * x_pow_p0 + b * x_pow_p1 + c * x_pow_p2


class dist_emb(torch.nn.Module):
    """Radial basis envelope(d / cutoff) * sin(n * pi * d / cutoff), n = 1..num_radial.

    Input ``dist`` gets a trailing feature axis of size ``num_radial``.
    """

    def __init__(self, num_radial, cutoff=5.0, envelope_exponent=5):
        super(dist_emb, self).__init__()
        self.cutoff = cutoff
        self.envelope = Envelope(envelope_exponent)

        # Trainable frequencies; filled with n * pi by reset_parameters().
        self.freq = torch.nn.Parameter(torch.empty(num_radial))

        self.reset_parameters()

    def reset_parameters(self):
        self.freq.data = torch.arange(1, self.freq.numel() + 1).float().mul_(PI)

    def forward(self, dist):
        dist = dist.unsqueeze(-1) / self.cutoff
        return self.envelope(dist) * (self.freq * dist).sin()


class angle_emb(torch.nn.Module):
    """Spherical-Bessel x (m = 0) spherical-harmonic basis for triplet angles.

    Each triplet gets ``num_spherical * num_radial`` features: the radial basis
    of the distance gathered with ``idx_kj`` times the angular basis of ``angle``.
    """

    def __init__(self, num_spherical, num_radial, cutoff=5.0,
                 envelope_exponent=5):
        super(angle_emb, self).__init__()
        assert num_radial <= 64
        self.num_spherical = num_spherical
        self.num_radial = num_radial
        self.cutoff = cutoff
        # ``envelope_exponent`` is unused: no envelope is applied in this embedding.

        bessel_forms = bessel_basis(num_spherical, num_radial)
        # Only Y[i][0] is used below, so skip building the m != 0 harmonics
        # (Y[i][0] is computed identically either way).
        sph_harm_forms = real_sph_harm(num_spherical, zero_m_only=True)
        self.sph_funcs = []
        self.bessel_funcs = []

        x, theta = sym.symbols('x theta')
        modules = {'sin': torch.sin, 'cos': torch.cos}
        for i in range(num_spherical):
            if i == 0:
                # Y_0^0 is constant: evaluate once and broadcast it to the input's shape.
                sph1 = sym.lambdify([theta], sph_harm_forms[i][0], modules)(0)
                self.sph_funcs.append(lambda x: torch.zeros_like(x) + sph1)
            else:
                sph = sym.lambdify([theta], sph_harm_forms[i][0], modules)
                self.sph_funcs.append(sph)
            for j in range(num_radial):
                bessel = sym.lambdify([x], bessel_forms[i][j], modules)
                self.bessel_funcs.append(bessel)

    def forward(self, dist, angle, idx_kj):
        dist = dist / self.cutoff
        rbf = torch.stack([f(dist) for f in self.bessel_funcs], dim=1)

        cbf = torch.stack([f(angle) for f in self.sph_funcs], dim=1)

        n, k = self.num_spherical, self.num_radial
        out = (rbf[idx_kj].view(-1, n, k) * cbf.view(-1, n, 1)).view(-1, n * k)
        return out


class torsion_emb(torch.nn.Module):
    """Spherical-Bessel x real spherical-harmonic basis for (angle, torsion) pairs.

    Each triplet gets ``num_spherical**2 * num_radial`` features: the radial basis
    of the distance gathered with ``idx_kj`` times every harmonic of degree
    ``< num_spherical`` evaluated at ``(angle, phi)``.
    """

    def __init__(self, num_spherical, num_radial, cutoff=5.0,
                 envelope_exponent=5):
        super(torsion_emb, self).__init__()
        assert num_radial <= 64
        self.num_spherical = num_spherical
        self.num_radial = num_radial
        self.cutoff = cutoff
        # ``envelope_exponent`` is unused: no envelope is applied in this embedding.

        bessel_forms = bessel_basis(num_spherical, num_radial)
        sph_harm_forms = real_sph_harm(num_spherical, zero_m_only=False)
        self.sph_funcs = []
        self.bessel_funcs = []

        x = sym.symbols('x')
        theta = sym.symbols('theta')
        phi = sym.symbols('phi')
        modules = {'sin': torch.sin, 'cos': torch.cos}
        for i in range(self.num_spherical):
            if i == 0:
                # Y_0^0 is constant; broadcast it to the shape of both inputs.
                sph1 = sym.lambdify([theta, phi], sph_harm_forms[i][0], modules)
                self.sph_funcs.append(lambda x, y: torch.zeros_like(x) + torch.zeros_like(y) + sph1(0,0))
            else:
                # k + i walks list positions 0..2i of Y[i] in storage order:
                # m = 0, then m = 1..i, then m = -i..-1.
                for k in range(-i, i + 1):
                    sph = sym.lambdify([theta, phi], sph_harm_forms[i][k+i], modules)
                    self.sph_funcs.append(sph)
            for j in range(self.num_radial):
                bessel = sym.lambdify([x], bessel_forms[i][j], modules)
                self.bessel_funcs.append(bessel)

    def forward(self, dist, angle, phi, idx_kj):
        dist = dist / self.cutoff
        rbf = torch.stack([f(dist) for f in self.bessel_funcs], dim=1)
        cbf = torch.stack([f(angle, phi) for f in self.sph_funcs], dim=1)

        n, k = self.num_spherical, self.num_radial
        out = (rbf[idx_kj].view(-1, 1, n, k) * cbf.view(-1, n, n, 1)).view(-1, n * n * k)
        return out


# Based on the code from: https://github.com/klicperajo/dimenet,
# https://github.com/rusty1s/pytorch_geometric/blob/master/torch_geometric/nn/models/dimenet.py

def xyz_to_dat(pos, edge_index, num_nodes, use_torsion = False):
    """
    Compute the distance, angle, and torsion from geometric information.

    Args:
        pos: Geometric information for every node in the graph.
        edge_index: Edge index of the graph (edges j -> i).
        num_nodes: Number of nodes in the graph.
        use_torsion: If set to :obj:`True`, will return distance, angle and torsion, otherwise only return distance and angle (also return some useful index). (default: :obj:`False`)

    Returns:
        ``(dist, angle, torsion, i, j, idx_kj, idx_ji)`` if ``use_torsion`` else
        ``(dist, angle, i, j, idx_kj, idx_ji)``. ``dist`` is per edge; ``angle`` and
        ``torsion`` are per triplet k -> j -> i; ``idx_kj`` / ``idx_ji`` are the edge
        ids of each triplet's k -> j and j -> i edges.
    """
    j, i = edge_index  # j->i

    # Calculate distances. # number of edges
    dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt()

    value = torch.arange(j.size(0), device=j.device)
    adj_t = SparseTensor(row=i, col=j, value=value, sparse_sizes=(num_nodes, num_nodes))
    adj_t_row = adj_t[j]
    num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long)

    # Node indices (k->j->i) for triplets.
    idx_i = i.repeat_interleave(num_triplets)
    idx_j = j.repeat_interleave(num_triplets)
    idx_k = adj_t_row.storage.col()
    mask = idx_i != idx_k
    idx_i, idx_j, idx_k = idx_i[mask], idx_j[mask], idx_k[mask]

    # Edge indices (k-j, j->i) for triplets.
    idx_kj = adj_t_row.storage.value()[mask]
    idx_ji = adj_t_row.storage.row()[mask]

    # Calculate angles. 0 to pi
    pos_ji = pos[idx_i] - pos[idx_j]
    pos_jk = pos[idx_k] - pos[idx_j]
    a = (pos_ji * pos_jk).sum(dim=-1)  # cos_angle * |pos_ji| * |pos_jk|
    # dim=-1 is explicit: without it torch.cross picks the FIRST size-3 dim,
    # which is the triplet axis when there are exactly 3 triplets.
    b = torch.cross(pos_ji, pos_jk, dim=-1).norm(dim=-1)  # sin_angle * |pos_ji| * |pos_jk|
    angle = torch.atan2(b, a)

    if use_torsion:
        # Prepare torsion idxes.
        idx_batch = torch.arange(len(idx_i), device=j.device)
        idx_k_n = adj_t[idx_j].storage.col()
        repeat = num_triplets
        num_triplets_t = num_triplets.repeat_interleave(repeat)[mask]
        idx_i_t = idx_i.repeat_interleave(num_triplets_t)
        idx_j_t = idx_j.repeat_interleave(num_triplets_t)
        idx_k_t = idx_k.repeat_interleave(num_triplets_t)
        idx_batch_t = idx_batch.repeat_interleave(num_triplets_t)
        mask = idx_i_t != idx_k_n
        idx_i_t, idx_j_t, idx_k_t, idx_k_n, idx_batch_t = idx_i_t[mask], idx_j_t[mask], idx_k_t[mask], idx_k_n[mask], idx_batch_t[mask]

        # Calculate torsions.
        pos_j0 = pos[idx_k_t] - pos[idx_j_t]
        pos_ji = pos[idx_i_t] - pos[idx_j_t]
        pos_jk = pos[idx_k_n] - pos[idx_j_t]
        dist_ji = pos_ji.pow(2).sum(dim=-1).sqrt()
        plane1 = torch.cross(pos_ji, pos_j0, dim=-1)
        plane2 = torch.cross(pos_ji, pos_jk, dim=-1)
        a = (plane1 * plane2).sum(dim=-1)  # cos_angle * |plane1| * |plane2|
        b = (torch.cross(plane1, plane2, dim=-1) * pos_ji).sum(dim=-1) / dist_ji
        torsion1 = torch.atan2(b, a)  # -pi to pi
        torsion1[torsion1 <= 0] += 2*PI  # 0 to 2pi
        # Keep the smallest torsion over all candidate neighbours of each triplet.
        torsion = scatter(torsion1, idx_batch_t, reduce='min')

        return dist, angle, torsion, i, j, idx_kj, idx_ji

    else:
        return dist, angle, i, j, idx_kj, idx_ji

# --------------------------------------------------------------------------------