├── n_body_system ├── __init__.py ├── dataset │ ├── __init__.py │ ├── script.sh │ └── generate_dataset.py ├── se3_dynamics │ ├── __init__.py │ ├── utils │ │ ├── __init__.py │ │ └── utils_profiling.py │ ├── equivariant_attention │ │ ├── __init__.py │ │ ├── from_se3cnn │ │ │ ├── __init__.py │ │ │ ├── license.txt │ │ │ ├── cache_file.py │ │ │ ├── representations.py │ │ │ ├── SO3.py │ │ │ └── utils_steerable.py │ │ ├── ops.py │ │ └── fibers.py │ ├── dynamics.py │ └── models.py ├── dataloader.py ├── post_process.py ├── dataset_nbody.py └── model.py ├── ponita ├── __init__.py ├── nn │ ├── __init__.py │ ├── convnext.py │ ├── embedding.py │ └── conv.py ├── utils │ ├── __init__.py │ ├── to_from_sphere.py │ └── windowing.py ├── geometry │ ├── __init__.py │ ├── rotation_2d.py │ ├── repulsion.py │ ├── spherical_grid.py │ └── invariants.py ├── models │ ├── __init__.py │ └── ponita.py └── transforms │ ├── __init__.py │ ├── random_rotate.py │ ├── invariants.py │ └── position_orientation_graph.py ├── setup.py ├── lightning_wrappers ├── scheduler.py ├── nbody.py ├── mnist.py ├── qm9.py └── md17.py ├── LICENSE ├── .gitignore ├── main_qm9.py ├── main_nbody.py ├── main_mnist.py ├── main_md17.py └── README.md /n_body_system/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /n_body_system/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /n_body_system/se3_dynamics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /n_body_system/se3_dynamics/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /n_body_system/se3_dynamics/equivariant_attention/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ponita/__init__.py: -------------------------------------------------------------------------------- 1 | # ponita/__init__.py 2 | # Empty content 3 | -------------------------------------------------------------------------------- /n_body_system/se3_dynamics/equivariant_attention/from_se3cnn/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ponita/nn/__init__.py: -------------------------------------------------------------------------------- 1 | # ponita/nn/__init__.py 2 | # Empty content 3 | -------------------------------------------------------------------------------- /ponita/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # ponita/utils/__init__.py 2 | # Empty content 3 | -------------------------------------------------------------------------------- /ponita/geometry/__init__.py: -------------------------------------------------------------------------------- 1 | # ponita/geometry/__init__.py 2 | # Empty content 3 | -------------------------------------------------------------------------------- /ponita/models/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .ponita import Ponita, PonitaFiberBundle, PonitaPointCloud 2 | -------------------------------------------------------------------------------- /n_body_system/se3_dynamics/utils/utils_profiling.py: -------------------------------------------------------------------------------- 1 | try: 2 | profile 3 | except NameError: 4 | def profile(func): 5 | return func 6 | -------------------------------------------------------------------------------- /n_body_system/dataset/script.sh: -------------------------------------------------------------------------------- 1 | python generate_dataset.py --initial_vel 1 2 | python generate_dataset.py --initial_vel 0 --length 2000 --length_test 2000 3 | python generate_dataset.py --n_balls 15 --initial_vel 1 4 | 5 | -------------------------------------------------------------------------------- /ponita/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .position_orientation_graph import PositionOrientationGraph 2 | from .invariants import SEnInvariantAttributes 3 | from .random_rotate import RandomRotate 4 | 5 | __all__ = ('PositionOrientationGraph', 'SEnInvariantAttributes', 'RandomRotate') -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='ponita', 5 | version='0.1', 6 | packages=find_packages(), 7 | url='https://github.com/ebekkers/ponita.git', 8 | license='MIT', 9 | author='ebekkers', 10 | author_email='e.j.bekkers@uva.nl', 11 | description='Ponita: Fast, Expressive SE(n) Equivariant Networks through Weight-Sharing in Position-Orientation Space', 12 | python_requires=">=3.10.5", 13 | ) 14 | -------------------------------------------------------------------------------- /ponita/utils/to_from_sphere.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def vec_to_sphere(vec, ori_grid): 5 | return torch.einsum('bcd,nd->bnc', vec, ori_grid) # [num_nodes, num_ori, num_vec] 6 | 7 | def scalar_to_sphere(scalar, ori_grid): 8 | return scalar.unsqueeze(-2).repeat_interleave(ori_grid.shape[-2], dim=-2) # [num_nodes, num_ori, num_scalars] 9 | 10 | def sphere_to_vec(spherical_signal, ori_grid): 11 | return torch.einsum('bnc,nd->bcd',spherical_signal, ori_grid) / ori_grid.shape[-2] # [num_nodes, num_vec, n] 12 | 13 | def sphere_to_scalar(spherical_signal): 14 | return spherical_signal.mean(dim=-2) # [num_nodes, num_scalars] -------------------------------------------------------------------------------- /lightning_wrappers/scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler): 6 | 7 | def __init__(self, optimizer, warmup, max_iters): 8 | self.warmup = warmup 9 | self.max_num_iters = max_iters 10 | super().__init__(optimizer) 11 | 12 | def get_lr(self): 13 | lr_factor = self.get_lr_factor(epoch=self.last_epoch) 14 | return [base_lr * lr_factor for base_lr in self.base_lrs] 15 | 16 | def get_lr_factor(self, epoch): 17 | lr_factor = 0.5 * (1 + np.cos(np.pi * epoch / self.max_num_iters)) 18 | if epoch <= self.warmup: 19 | lr_factor *= (epoch + 1e-6) * 1.0 / (self.warmup + 1e-6) 20 | return lr_factor 
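A minimal usage sketch of the scheduler above (editorial addition; the model, optimizer, and step counts are placeholders):

    import torch
    from lightning_wrappers.scheduler import CosineWarmupScheduler

    model = torch.nn.Linear(8, 1)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = CosineWarmupScheduler(optimizer, warmup=10, max_iters=100)
    for step in range(100):
        optimizer.step()
        scheduler.step()  # LR ramps up linearly for 10 steps, then follows a cosine decay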
--------------------------------------------------------------------------------
/ponita/geometry/rotation_2d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Optional
3 | 
4 | 
5 | def random_so2_matrix(batch_size, device: Optional[str] = None) -> torch.Tensor:
6 |     # Generate random angles
7 |     angles = 2 * torch.pi * torch.rand(batch_size, device=device)
8 | 
9 |     # Calculate sin and cos for each angle
10 |     cos_vals = torch.cos(angles)
11 |     sin_vals = torch.sin(angles)
12 | 
13 |     # Construct the rotation matrices
14 |     rotation_matrices = torch.stack([cos_vals, -sin_vals, sin_vals, cos_vals], dim=1)
15 |     rotation_matrices = rotation_matrices.view(batch_size, 2, 2)
16 | 
17 |     return rotation_matrices
18 | 
19 | 
20 | def uniform_grid_s1(num_points: int, device: Optional[str] = None) -> torch.Tensor:
21 |     # Generate angles uniformly
22 |     # NOTE: the last element should be one portion before 2*pi.
23 |     angles = torch.linspace(
24 |         start=0,
25 |         end=2 * torch.pi - (2 * torch.pi / num_points),
26 |         steps=num_points,
27 |         device=device,
28 |     )
29 | 
30 |     # Calculate x and y coordinates
31 |     x = torch.cos(angles)
32 |     y = torch.sin(angles)
33 | 
34 |     return torch.stack((x, y), dim=1)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 Universiteit van Amsterdam, Erik J Bekkers
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /ponita/nn/convnext.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class ConvNext(torch.nn.Module): 5 | """ 6 | """ 7 | def __init__(self, channels, conv, act=torch.nn.GELU(), layer_scale=1e-6, widening_factor=4): 8 | super().__init__() 9 | 10 | self.conv = conv 11 | self.act_fn = act 12 | self.linear_1 = torch.nn.Linear(channels, widening_factor * channels) 13 | self.linear_2 = torch.nn.Linear(widening_factor * channels, channels) 14 | if layer_scale is not None: 15 | self.layer_scale = torch.nn.Parameter(torch.ones(channels) * layer_scale) 16 | else: 17 | self.register_buffer('layer_scale', None) 18 | self.norm = torch.nn.LayerNorm(channels) 19 | 20 | def forward(self, x, edge_index, edge_attr, **kwargs): 21 | """ 22 | """ 23 | input = x 24 | x = self.conv(x, edge_index, edge_attr, **kwargs) 25 | x = self.norm(x) 26 | x = self.linear_1(x) 27 | x = self.act_fn(x) 28 | x = self.linear_2(x) 29 | if self.layer_scale is not None: 30 | x = self.layer_scale * x 31 | if input.shape == x.shape: 32 | x = x + input 33 | return x -------------------------------------------------------------------------------- /n_body_system/se3_dynamics/equivariant_attention/from_se3cnn/license.txt: -------------------------------------------------------------------------------- 1 | the code in this folder was mostly obtained from https://github.com/mariogeiger/se3cnn/ 2 | 3 | which has the following license: 4 | 5 | MIT License 6 | 7 | Copyright (c) 2019 Mario Geiger 8 | 9 | Permission is hereby granted, free of charge, to any person obtaining a copy 10 | of this software and associated documentation files (the "Software"), to deal 11 | in the Software without restriction, including without limitation the rights 12 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 13 | copies of the Software, and to permit persons to whom the Software is 14 | furnished to do so, subject to the following conditions: 15 | 16 | The above copyright notice and this permission notice shall be included in all 17 | copies or substantial portions of the Software. 18 | 19 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 20 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 21 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 22 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 23 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 24 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 25 | SOFTWARE. 26 | -------------------------------------------------------------------------------- /ponita/utils/windowing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class PolynomialCutoff(torch.nn.Module): 5 | """ 6 | Klicpera, J.; Groß, J.; Günnemann, S. Directional Message Passing for Molecular Graphs; ICLR 2020. 
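    The forward pass below evaluates the polynomial envelope (editorial
    transcription of the code; x is the distance and p the polynomial order)

        u(x) = 1 - (p+1)(p+2)/2 * (x/r_max)^p
                 + p(p+2) * (x/r_max)^(p+1)
                 - p(p+1)/2 * (x/r_max)^(p+2)

    for x < r_max, and u(x) = 0 beyond the cutoff, i.e.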
7 | Equation (8) 8 | """ 9 | 10 | p: torch.Tensor 11 | r_max: torch.Tensor 12 | 13 | def __init__(self, r_max, p=6): 14 | super().__init__() 15 | if r_max is not None: 16 | self.register_buffer("p", torch.tensor(p, dtype=torch.get_default_dtype())) 17 | self.register_buffer("r_max", torch.tensor(r_max, dtype=torch.get_default_dtype())) 18 | else: 19 | self.r_max = None 20 | 21 | def forward(self, x: torch.Tensor) -> torch.Tensor: 22 | if self.r_max is not None: 23 | envelope = ( 24 | 1.0 25 | - ((self.p + 1.0) * (self.p + 2.0) / 2.0) * torch.pow(x / self.r_max, self.p) 26 | + self.p * (self.p + 2.0) * torch.pow(x / self.r_max, self.p + 1) 27 | - (self.p * (self.p + 1.0) / 2) * torch.pow(x / self.r_max, self.p + 2) 28 | ) 29 | return envelope * (x < self.r_max) 30 | else: 31 | return torch.ones_like(x) 32 | 33 | def __repr__(self): 34 | return f"{self.__class__.__name__}(p={self.p}, r_max={self.r_max})" 35 | 36 | -------------------------------------------------------------------------------- /n_body_system/se3_dynamics/equivariant_attention/ops.py: -------------------------------------------------------------------------------- 1 | from ..utils.utils_profiling import * # load before other local modules 2 | 3 | import math 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from typing import Dict, List, Tuple 10 | 11 | 12 | def shape_is(a, b, ignore_batch=1): 13 | """ 14 | check whether multi-dimensional array a has dimensions b; use in combination with assert 15 | 16 | :param a: multi dimensional array 17 | :param b: list of ints which indicate expected dimensions of a 18 | :param ignore_batch: if set to True, ignore first dimension of a 19 | :return: True or False 20 | """ 21 | if ignore_batch: 22 | shape_a = np.array(a.shape[1:]) 23 | else: 24 | shape_a = np.array(a.shape) 25 | shape_b = np.array(b) 26 | return np.array_equal(shape_a, shape_b) 27 | 28 | 29 | def norm_with_epsilon(input_tensor, axis=None, keep_dims=False, epsilon=0.0): 30 | """ 31 | Regularized norm 32 | 33 | Args: 34 | input_tensor: torch.Tensor 35 | 36 | Returns: 37 | torch.Tensor normed over axis 38 | """ 39 | # return torch.sqrt(torch.max(torch.reduce_sum(torch.square(input_tensor), axis=axis, keep_dims=keep_dims), epsilon)) 40 | keep_dims = bool(keep_dims) 41 | squares = torch.sum(input_tensor**2, axis=axis, keepdim=keep_dims) 42 | squares = torch.max(squares, torch.tensor([epsilon]).to(squares.device)) 43 | return torch.sqrt(squares) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | wandb 2 | *.ckpt 3 | backup 4 | datasets 5 | logs 6 | n_body_system/dataset/*.npy 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | pip-wheel-metadata/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 102 | __pypackages__/ 103 | 104 | # Celery stuff 105 | celerybeat-schedule 106 | celerybeat.pid 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .venv 114 | env/ 115 | venv/ 116 | ENV/ 117 | env.bak/ 118 | venv.bak/ 119 | 120 | # Spyder project settings 121 | .spyderproject 122 | .spyproject 123 | 124 | # Rope project settings 125 | .ropeproject 126 | 127 | # mkdocs documentation 128 | /site 129 | 130 | # mypy 131 | .mypy_cache/ 132 | .dmypy.json 133 | dmypy.json 134 | 135 | # Pyre type checker 136 | .pyre/ 137 | -------------------------------------------------------------------------------- /ponita/transforms/random_rotate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.transforms import BaseTransform 3 | from ponita.geometry.rotation import random_matrix as random_so3_matrix 4 | from ponita.geometry.rotation_2d import random_so2_matrix 5 | 6 | class RandomRotate(BaseTransform): 7 | """ 8 | A PyTorch Geometric transform that randomly rotates each point cloud in a batch of graphs 9 | by sampling rotations from a uniform distribution over the Special Orthogonal Group SO(3). 10 | 11 | Args: 12 | None 13 | """ 14 | 15 | def __init__(self, attr_list, n=3): 16 | super().__init__() 17 | self.attr_list = attr_list 18 | self.random_rotation_matrix_fn = random_so2_matrix if (n == 2) else random_so3_matrix 19 | 20 | def __call__(self, graph): 21 | """ 22 | Apply the random rotation transform to the input graph. 23 | 24 | Args: 25 | graph (torch_geometric.data.Data): Input graph containing position (graph.pos), 26 | and optionally, batch information (graph.batch). 27 | 28 | Returns: 29 | torch_geometric.data.Data: Updated graph with randomly rotated positions. 30 | The positions in the graph are rotated using a random rotation matrix 31 | sampled from a uniform distribution over SO(3). 
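        Example (editorial sketch; assumes `graph` is a torch_geometric.data.Data
        object whose listed attributes hold 3D vectors):
            >>> transform = RandomRotate(attr_list=['pos', 'x'], n=3)
            >>> graph = transform(graph)  # all listed attributes rotated by the same R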
32 |         """
33 |         rand_rot = self.random_rotation(graph)
34 |         return self.rotate_graph(graph, rand_rot)
35 | 
36 |     def rotate_graph(self, graph, rand_rot):
37 |         for attr in self.attr_list:
38 |             if hasattr(graph, attr):
39 |                 setattr(graph, attr, self.rotate_attr(getattr(graph, attr), rand_rot))
40 |         return graph
41 | 
42 |     def rotate_attr(self, attr, rand_rot):
43 |         rand_rot = rand_rot.type_as(attr)
44 |         if len(rand_rot.shape)==3:
45 |             if len(attr.shape)==2:
46 |                 return torch.einsum('bij,bj->bi', rand_rot, attr)
47 |             else:
48 |                 return torch.einsum('bij,bcj->bci', rand_rot, attr)
49 |         else:
50 |             if len(attr.shape)==2:
51 |                 return torch.einsum('ij,bj->bi', rand_rot, attr)
52 |             else:
53 |                 return torch.einsum('ij,bcj->bci', rand_rot, attr)
54 | 
55 |     def random_rotation(self, graph):
56 |         if graph.batch is not None:
57 |             batch_size = graph.batch.max() + 1
58 |             random_rotation_matrix = self.random_rotation_matrix_fn(batch_size).to(graph.batch.device)
59 |             random_rotation_matrix = random_rotation_matrix[graph.batch]
60 |         else:
61 |             random_rotation_matrix = self.random_rotation_matrix_fn(1)[0]
62 |         return random_rotation_matrix
63 | 
--------------------------------------------------------------------------------
/n_body_system/dataloader.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | 
4 | 
5 | class Dataloader():
6 |     def __init__(self, dataset, batch_size=1, slice=[0, 1e8], shuffle=True):
7 |         self.dataset = dataset
8 |         self.batch_size = batch_size
9 |         self.n_nodes = self.dataset.get_n_nodes()
10 |         self.edges = self.expand_edges(dataset.edges, batch_size, self.n_nodes)
11 |         self.idxs_permuted = list(range(len(self.dataset)))
12 |         self.shuffle = shuffle
13 |         self.slice = slice
14 |         if self.shuffle:
15 |             random.shuffle(self.idxs_permuted)
16 |         self.idx = 0
17 | 
18 |     def __iter__(self):
19 |         return self
20 | 
21 |     def expand_edges(self, edges, batch_size, n_nodes):
22 |         edges = [torch.LongTensor(edges[0]), torch.LongTensor(edges[1])]
23 |         if batch_size == 1:
24 |             return edges
25 |         elif batch_size > 1:
26 |             rows, cols = [], []
27 |             for i in range(batch_size):
28 |                 rows.append(edges[0] + n_nodes*i)
29 |                 cols.append(edges[1] + n_nodes*i)
30 |             edges = [torch.cat(rows), torch.cat(cols)]
31 |         return edges
32 | 
33 |     def __next__(self):
34 |         if self.idx > len(self.dataset) - self.batch_size:
35 |             self.idx = 0
36 |             #random.shuffle(self.dataset.graphs)
37 |             raise StopIteration # Done iterating.
38 |         else:
39 |             loc, vel, edge_attr, charges = self.dataset.data
40 |             idx_permuted = self.idxs_permuted[self.idx:self.idx + self.batch_size]
41 |             batched_data = loc[idx_permuted], vel[idx_permuted], edge_attr[idx_permuted], charges[idx_permuted]
42 |             # self.dataset.data is a 4-tuple (loc, vel, edge_attr, charges)
43 |             [loc_batch, vel_batch, edge_attr_batch, charges_batch] = self.cast_batch(list(batched_data))
44 | 
45 |             self.idx += self.batch_size
46 |             return loc_batch, vel_batch, edge_attr_batch, charges_batch
47 | 
48 |     def cast_batch(self, batched_data):
49 |         #loc_batch, vel_batch, edges_batch, loc_end_batch = batched_data
50 |         #if self.batch_size > 1:
51 |         #    raise Exception("To implement")
52 |         batched_data = [d.contiguous().view(-1, d.size(2)) for d in batched_data]
53 | 
54 |         return batched_data
55 |         #else:
56 |         #    return loc_batch[0], vel_batch[0], edges_batch[0], loc_end_batch[0]
57 | 
58 |     def __len__(self):
59 |         return len(self.dataset)
60 | 
61 |     def partition(self):
62 |         return self.dataset.partition
63 | 
64 | 
65 | if __name__ == "__main__":
66 |     '''
67 |     from dataset_nbody import NBodyDataset
68 | 
69 |     dataset_train = NBodyDataset(partition='train')
70 |     dataloader_train = Dataloader(dataset_train)
71 |     for i, (loc, vel, edges) in enumerate(dataset_train):
72 |         print(i)
73 |     '''
74 | 
75 | 
--------------------------------------------------------------------------------
/ponita/nn/embedding.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | 
4 | 
5 | class PolynomialFeatures(torch.nn.Module):
6 |     def __init__(self, degree):
7 |         super(PolynomialFeatures, self).__init__()
8 | 
9 |         self.degree = degree
10 | 
11 |     def forward(self, x):
12 | 
13 |         polynomial_list = [x]
14 |         for it in range(1, self.degree):
15 |             polynomial_list.append(torch.einsum('...i,...j->...ij', polynomial_list[-1], x).flatten(-2,-1))
16 |         return torch.cat(polynomial_list, -1)
17 | 
18 | 
19 | class RandomFourierFeatures(torch.nn.Module):
20 |     def __init__(self, out_dim, sigma, symmetric=None):
21 |         super(RandomFourierFeatures, self).__init__()
22 | 
23 |         self.out_dim = out_dim
24 |         if out_dim % 2 != 0:
25 |             self.compensation = 1
26 |         else:
27 |             self.compensation = 0
28 |         self.num_frequencies = int(out_dim / 2) + self.compensation
29 | 
30 |         if symmetric is None:
31 |             symmetric = [False] * len(sigma)
32 |         self.unconstraint_idx = [i for i, x in enumerate(symmetric) if not(x)]
33 |         self.constraint_idx = [i for i, x in enumerate(symmetric) if x]
34 | 
35 |         self.sigma = sigma
36 |         if len(self.unconstraint_idx) > 0:
37 |             self.frequencies_unconstraint = torch.stack([self.random_frequencies(self.sigma[i], self.num_frequencies) for i in self.unconstraint_idx],dim=0)
38 |         else:
39 |             self.frequencies_unconstraint = None
40 |         if len(self.constraint_idx) > 0:
41 |             self.frequencies_constraint = torch.stack([self.random_frequencies(self.sigma[i], 2 * self.num_frequencies) for i in self.constraint_idx],dim=0)
42 |         else:
43 |             self.frequencies_constraint = None
44 | 
45 |     def random_frequencies(self, sigma, num_frequencies):
46 |         if type(sigma)==float:
47 |             # Continuous frequencies, sigma is interpreted as the std of the gaussian distribution from which we sample
48 |             return torch.randn(num_frequencies) * math.sqrt(1/2) * sigma
49 |         elif type(sigma)==int:
50 |             # Integer frequencies, now sigma is interpreted as the integer band-limit (max frequency)
51 |             return torch.randint(-sigma, sigma,(num_frequencies,))
52 | 
53 |     def forward(self, x):
54 | 
55 |         # Mix unconstrained terms
56 |         if self.frequencies_unconstraint is not None:
57 |             unconstraint_proj = x[..., self.unconstraint_idx] @ self.frequencies_unconstraint.type_as(x)
58 |             out = torch.cat([unconstraint_proj.cos(),unconstraint_proj.sin()], dim=-1)
59 |         else:
60 |             out = x.new_ones((1, self.num_frequencies * 2))  # on the same device/dtype as x
61 | 
62 |         # Tensor product for constrained terms
63 |         if self.frequencies_constraint is not None:
64 |             out = out * torch.einsum('...d,di->...di', x[...,self.constraint_idx], self.frequencies_constraint.type_as(x)).cos().prod(dim=-2)
65 | 
66 |         # Crop to dimension (if necessary), and return output
67 |         if self.compensation:
68 |             out = out[..., :-1]
69 |         return out
--------------------------------------------------------------------------------
/ponita/geometry/repulsion.py:
--------------------------------------------------------------------------------
1 | # Taken from: https://github.com/ThijsKuipers1995/gconv
2 | """
3 | repulsion.py
4 | 
5 | Contains the repulsion model.
6 | """
7 | 
8 | 
9 | import torch
10 | from torch import Tensor
11 | 
12 | from tqdm import trange
13 | 
14 | from typing import Callable
15 | 
16 | 
17 | def columb_energy(d: Tensor, k: int = 2) -> Tensor:
18 |     """
19 |     Returns the Coulomb energy of the given input.
20 | 
21 |     Arguments:
22 |         - d: Tensor to calculate the Coulomb energy over.
23 |         - k: Exponent term of the Coulomb energy.
24 | 
25 |     Returns:
26 |         - Tensor containing the Coulomb energy.
27 |     """
28 |     return d ** (-k)
29 | 
30 | 
31 | def repulse(
32 |     grid: Tensor,
33 |     steps: int = 200,
34 |     step_size: float = 10,
35 |     metric_fn: Callable = lambda x, y: x - y,
36 |     transform_fn: Callable = lambda x: x,
37 |     energy_fn: Callable = columb_energy,
38 |     dist_normalization_constant: float = 1,
39 |     alpha: float = 0.001,
40 |     show_pbar: bool = True,
41 |     in_place: bool = False,
42 | ) -> Tensor:
43 |     """
44 |     Performs repulsion on grids defined on Sn. Repulsion will run on the
45 |     device that the grid is defined on.
46 | 
47 |     Arguments:
48 |         - grid: Tensor of shape (N, D) of grid elements
49 |           parameterized with a periodic parameterization, i.e.
50 |           spherical angles on S2.
51 |         - steps: Number of times the optimization step will be performed.
52 |         - step_size: Strength of each optimization step.
53 |         - metric_fn: Metric for which the energy will be minimized.
54 |         - transform_fn: Function that transforms the grid to the parameterization
55 |           used by the metric.
56 |         - energy_fn: Function used to calculate the energy. Defaults to the Coulomb energy.
57 |         - dist_normalization_constant: Optional constant for normalizing distances.
58 |           Set to the max distance of the metric, or else 1.
59 |         - show_pbar: If True, will show progress of the optimization procedure.
60 |         - in_place: If True, will update the grid in place.
61 | 
62 |     Returns:
63 |         - Tensor of shape (N, D) of the minimized grid.
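    Example (editorial sketch, using the default Euclidean metric on a periodic
    1D parameterization, e.g. angles on the circle S1):

        angles = 2 * torch.pi * torch.rand(20, 1)
        angles = repulse(angles, steps=100, step_size=0.1, show_pbar=False)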
64 | """ 65 | 66 | pbar = trange(steps, disable=not show_pbar, desc="Optimizing") 67 | 68 | grid = grid if in_place else grid.clone() 69 | grid.requires_grad = True 70 | 71 | optimizer = torch.optim.SGD([grid], lr=step_size) 72 | 73 | for epoch in pbar: 74 | optimizer.zero_grad(set_to_none=True) 75 | 76 | grid_transform = transform_fn(grid) 77 | 78 | dists = metric_fn(grid_transform[:, None], grid_transform).sort(dim=-1)[0][:, 1:] 79 | energy_matrix = energy_fn(dists / dist_normalization_constant) 80 | 81 | mean_total_energy = energy_matrix.mean() 82 | mean_total_energy.backward() 83 | grid.grad += (steps - epoch) / steps * alpha * torch.randn(grid.grad.shape, device=grid.device) 84 | 85 | optimizer.step() 86 | 87 | pbar.set_postfix_str(f"mean total energy: {mean_total_energy.item():.3f}") 88 | 89 | grid.requires_grad = False 90 | 91 | return grid.detach() -------------------------------------------------------------------------------- /n_body_system/se3_dynamics/equivariant_attention/from_se3cnn/cache_file.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Cache in files 3 | ''' 4 | from functools import wraps, lru_cache 5 | import pickle 6 | import gzip 7 | import os 8 | import sys 9 | import fcntl 10 | 11 | 12 | class FileSystemMutex: 13 | ''' 14 | Mutual exclusion of different **processes** using the file system 15 | ''' 16 | 17 | def __init__(self, filename): 18 | self.handle = None 19 | self.filename = filename 20 | 21 | def acquire(self): 22 | ''' 23 | Locks the mutex 24 | if it is already locked, it waits (blocking function) 25 | ''' 26 | self.handle = open(self.filename, 'w') 27 | fcntl.lockf(self.handle, fcntl.LOCK_EX) 28 | self.handle.write("{}\n".format(os.getpid())) 29 | self.handle.flush() 30 | 31 | def release(self): 32 | ''' 33 | Unlock the mutex 34 | ''' 35 | if self.handle is None: 36 | raise RuntimeError() 37 | fcntl.lockf(self.handle, fcntl.LOCK_UN) 38 | self.handle.close() 39 | self.handle = None 40 | 41 | def __enter__(self): 42 | self.acquire() 43 | 44 | def __exit__(self, exc_type, exc_value, traceback): 45 | self.release() 46 | 47 | 48 | def cached_dirpklgz(dirname, maxsize=128): 49 | ''' 50 | Cache a function with a directory 51 | 52 | :param dirname: the directory path 53 | :param maxsize: maximum size of the RAM cache (there is no limit for the directory cache) 54 | ''' 55 | 56 | def decorator(func): 57 | ''' 58 | The actual decorator 59 | ''' 60 | 61 | @lru_cache(maxsize=maxsize) 62 | @wraps(func) 63 | def wrapper(*args, **kwargs): 64 | ''' 65 | The wrapper of the function 66 | ''' 67 | try: 68 | os.makedirs(dirname) 69 | except FileExistsError: 70 | pass 71 | 72 | indexfile = os.path.join(dirname, "index.pkl") 73 | mutexfile = os.path.join(dirname, "mutex") 74 | 75 | with FileSystemMutex(mutexfile): 76 | try: 77 | with open(indexfile, "rb") as file: 78 | index = pickle.load(file) 79 | except FileNotFoundError: 80 | index = {} 81 | 82 | key = (args, frozenset(kwargs), func.__defaults__) 83 | 84 | try: 85 | filename = index[key] 86 | except KeyError: 87 | index[key] = filename = "{}.pkl.gz".format(len(index)) 88 | with open(indexfile, "wb") as file: 89 | pickle.dump(index, file) 90 | 91 | filepath = os.path.join(dirname, filename) 92 | 93 | try: 94 | with FileSystemMutex(mutexfile): 95 | with gzip.open(filepath, "rb") as file: 96 | result = pickle.load(file) 97 | except FileNotFoundError: 98 | print("compute {}... 
".format(filename), end="") 99 | sys.stdout.flush() 100 | result = func(*args, **kwargs) 101 | print("save {}... ".format(filename), end="") 102 | sys.stdout.flush() 103 | with FileSystemMutex(mutexfile): 104 | with gzip.open(filepath, "wb") as file: 105 | pickle.dump(result, file) 106 | print("done") 107 | return result 108 | 109 | return wrapper 110 | 111 | return decorator 112 | -------------------------------------------------------------------------------- /ponita/transforms/invariants.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.transforms import BaseTransform 3 | from ponita.geometry.invariants import invariant_attr_r2s1_fiber_bundle, invariant_attr_r2s1_point_cloud 4 | from ponita.geometry.invariants import invariant_attr_rn, invariant_attr_r3s2_fiber_bundle, invariant_attr_r3s2_point_cloud 5 | 6 | 7 | class SEnInvariantAttributes(BaseTransform): 8 | """ 9 | A PyTorch Geometric transform that adds invariant edge attributes to the input graph. 10 | The transformation includes pair-wise distances in position space (graph.dists) and 11 | invariant edge attributes between local orientations. 12 | 13 | Args: 14 | separable (bool): If True, computes spatial invariants for each orientation separately 15 | (no orientation interactions). If False, computes all pair-wise 16 | invariants between orientations in the receiving fiber related to 17 | those in the sending fiber. 18 | """ 19 | 20 | def __init__(self, separable=True, point_cloud=False): 21 | super().__init__() 22 | # Discretization of the orientation grid 23 | self.separable = separable 24 | self.point_cloud = point_cloud 25 | 26 | def __call__(self, graph): 27 | """ 28 | Apply the transform to the input graph. 29 | 30 | Args: 31 | graph (torch_geometric.data.Data): Input graph containing position (graph.pos), 32 | orientation (graph.ori), and edge index (graph.edge_index). 33 | 34 | Returns: 35 | torch_geometric.data.Data: Updated graph with added invariant edge attributes. 36 | Pair-wise distances in position space are stored in graph.dists. 37 | If separable is True, graph.attr contains spatial invariants for 38 | each orientation, and graph.attr_ori contains invariants between 39 | local orientations. If separable is False, graph.attr contains 40 | all pair-wise invariants between orientations. 
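        Example (editorial sketch): for an R3xS2 position-orientation graph with
        separable=True, the transform typically yields
            graph.attr        # [num_edges, num_ori, 2] spatial invariants
            graph.fiber_attr  # [num_ori, num_ori, 1] orientation invariant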
41 | """ 42 | # TODO: make more elegant 43 | if graph.n == 2: 44 | graph.dists = invariant_attr_rn(graph.pos[:,:graph.n], graph.edge_index) 45 | if self.point_cloud: 46 | if graph.pos.size(-1) == graph.n: 47 | graph.attr = graph.dists 48 | else: 49 | graph.attr = invariant_attr_r2s1_point_cloud(graph.pos, graph.edge_index) 50 | return graph 51 | else: 52 | if self.separable: 53 | graph.attr, graph.fiber_attr = invariant_attr_r2s1_fiber_bundle(graph.pos, graph.ori_grid, graph.edge_index, separable=True) 54 | else: 55 | graph.attr = invariant_attr_r2s1_fiber_bundle(graph.pos, graph.ori_grid, graph.edge_index, separable=False) 56 | return graph 57 | else: 58 | graph.dists = invariant_attr_rn(graph.pos[:,:graph.n], graph.edge_index) 59 | graph.attr = graph.dists 60 | if self.point_cloud: 61 | if graph.pos.size(-1) == graph.n: 62 | graph.attr = graph.dists 63 | else: 64 | graph.attr = invariant_attr_r3s2_point_cloud(graph.pos, graph.edge_index) 65 | return graph 66 | else: 67 | if self.separable: 68 | graph.attr, graph.fiber_attr = invariant_attr_r3s2_fiber_bundle(graph.pos, graph.ori_grid, graph.edge_index, separable=True) 69 | else: 70 | graph.attr = invariant_attr_r3s2_fiber_bundle(graph.pos, graph.ori_grid, graph.edge_index, separable=False) 71 | return graph -------------------------------------------------------------------------------- /n_body_system/post_process.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | def draw_result(lst_iter, loss_gnn, loss_egnn, loss_baseline, title): 4 | plt.plot(lst_iter, loss_gnn, '-b', label='gnn') 5 | plt.plot(lst_iter, loss_egnn, '-r', label='egnn') 6 | plt.plot(lst_iter, loss_baseline, '--g', label='baseline') 7 | 8 | plt.xlabel("n iteration") 9 | plt.legend(loc='upper left') 10 | plt.title(title) 11 | 12 | # save image 13 | # plt.savefig(title+".png") # should before show method 14 | 15 | # show 16 | plt.show() 17 | 18 | 19 | 20 | if __name__ == "__main__": 21 | 22 | 23 | gnn = 
[0.7161,0.4581,0.4035,0.0858,0.0301,0.0267,0.0257,0.0243,0.0235,0.0226,0.0218,0.0212,0.0209,0.0201,0.0199,0.0199,0.0198,0.0199,0.0196,0.0193,0.0192,0.0188,0.0184,0.0180,0.0183,0.0180,0.0179,0.0183,0.0175,0.0173,0.0182,0.0196,0.0171,0.0185,0.0187,0.0173,0.0155,0.0155,0.0183,0.0164,0.0146,0.0173,0.0172,0.0154,0.0181,0.0189,0.0186,0.0162,0.0183,0.0158,0.0156,0.0138,0.0206,0.0148,0.0141,0.0155,0.0155,0.0163,0.0127,0.0167,0.0133,0.0146,0.0194,0.0169,0.0145,0.0144,0.0129,0.0160,0.0165,0.0136,0.0123,0.0123,0.0124,0.0129,0.0122,0.0186,0.0133,0.0138,0.0123,0.0114,0.0131,0.0128,0.0123,0.0114,0.0116,0.0121,0.0121,0.0115,0.0119,0.0109,0.0112,0.0117,0.0110,0.0108,0.0110,0.0119,0.0114,0.0108,0.0121,0.0111,0.0111,0.0109,0.0100,0.0117,0.0106,0.0113,0.0118,0.0128,0.0101,0.0114,0.0106,0.0112,0.0104,0.0114,0.0097,0.0106,0.0105,0.0117,0.0103,0.0100,0.0097,0.0107,0.0097,0.0129,0.0102,0.0118,0.0119,0.0099,0.0104,0.0111,0.0095,0.0104,0.0122,0.0090,0.0114,0.0089,0.0095,0.0097,0.0088,0.0116,0.0103,0.0106,0.0100,0.0170,0.0085,0.0094,0.0107,0.0100,0.0092,0.0106,0.0084,0.0096,0.0093,0.0092,0.0102,0.0091,0.0105,0.0088,0.0110,0.0092,0.0124,0.0092,0.0099,0.0091,0.0096,0.0093,0.0085,0.0091,0.0086,0.0085,0.0078,0.0086,0.0085,0.0093,0.0084,0.0084,0.0093,0.0082,0.0073,0.0091,0.0090,0.0089,0.0095,0.0076,0.0078,0.0078,0.0077,0.0085,0.0084,0.0084,0.0085,0.0078,0.0079,0.0095,0.0072,0.0075,0.0075,0.0094,0.0081,0.0088,0.0077,0.0083,0.0080,0.0075,0.0069,0.0088,0.0080,0.0079,0.0076,0.0122,0.0084,0.0083,0.0082,0.0082,0.0078,0.0074,0.0092,0.0069,0.0089,0.0071,0.0079,0.0099,0.0079,0.0070,0.0086,0.0070,0.0093,0.0075,0.0078,0.0079,0.0085,0.0088,0.0104,0.0075,0.0079,0.0078,0.0080,0.0091,0.0089,0.0080,0.0079,0.0079,0.0076,0.0077,0.0071,0.0080,0.0073,0.0065,0.0082,0.0116,0.0075,0.0074,0.0073,0.0086,0.0068,0.0082,0.0079,0.0095,0.0071,0.0088,0.0104,0.0088,0.0072,0.0084,0.0068,0.0076,0.0080,0.0074,0.0077,0.0076,0.0070,0.0079,0.0079,0.0068,0.0084,0.0077,0.0100,0.0078,0.0084,0.0071,0.0072,0.0078,0.0067,0.0077,0.0075,0.0065,0.0076,0.0091,0.0074,0.0082,0.0071,0.0074,0.0072,0.0085,0.0073,0.0069,0.0067,0.0091,0.0085,0.0086,0.0077,0.0067,0.0077,0.0086,0.0074,0.0066,0.0075,0.0071,0.0078,0.0070,0.0074,0.0093,0.0074,0.0070,0.0092,0.0070,0.0071,0.0114,0.0074,0.0074,0.0071,0.0076,0.0090,0.0075,0.0073,0.0079,0.0103,0.0075,0.0081,0.0078,0.0089,0.0103,0.0071,0.0069,0.0076,0.0069,0.0067,0.0074,0.0082,0.0079,0.0081,0.0069,0.0070,0.0062,0.0072,0.0121,0.0066,0.0065,0.0075,0.0074,0.0076,0.0071,0.0078,0.0066,0.0080,0.0069,0.0066,0.0070,0.0072,0.0080,0.0078,0.0074,0.0084,0.0070,0.0078,0.0081,0.0071,0.0065,0.0071,0.0081,0.0094,0.0085,0.0070,0.0068,0.0080,0.0085,0.0070,0.0069,0.0075,0.0065,0.0071,0.0076,0.0077,0.0071,0.0066,0.0075,0.0086,0.0068,0.0065,0.0072,0.0073,0.0097,0.0066,0.0073,0.0075,0.0073,0.0067,0.0064,0.0065,0.0068] 24 | egnn = 
[0.1020,0.1055,0.1058,0.1029,0.0986,0.0948,0.0913,0.0883,0.0857,0.0836,0.0818,0.0804,0.0794,0.0786,0.0779,0.0773,0.0769,0.0767,0.0766,0.0765,0.0766,0.0770,0.0775,0.0784,0.0791,0.0798,0.0808,0.0823,0.0829,0.0846,0.0855,0.0866,0.0879,0.0893,0.0907,0.0923,0.0936,0.0947,0.0957,0.0963,0.0967,0.0974,0.0982,0.0989,0.0998,0.1008,0.1022,0.1051,0.1075,0.1093,0.1115,0.1142,0.1177,0.1219,0.1259,0.1295,0.1346,0.1396,0.1433,0.1462,0.1497,0.1530,0.1564,0.1591,0.1617,0.1643,0.1667,0.1688,0.1710,0.1730,0.1754,0.1780,0.1802,0.1824,0.1851,0.1879,0.1907,0.1933,0.1959,0.1984,0.2014,0.2045,0.2073,0.2101,0.2124,0.2148,0.2173,0.2196,0.2212,0.2227,0.2239,0.2252,0.2267,0.2283,0.2299,0.2314,0.2332,0.2352,0.2372,0.2391,0.2408,0.2424,0.2441,0.2455,0.2471,0.2489,0.2507,0.2525,0.2543,0.2561,0.2578,0.2598,0.2617,0.2635,0.2651,0.2667,0.2683,0.2694,0.2706] 25 | gnn = gnn[1:100] 26 | egnn = egnn[1:100] 27 | num_epochs = len(gnn) 28 | baseline = [0.1075] * num_epochs 29 | epochs = list(range(num_epochs)) 30 | draw_result(epochs, gnn, egnn, baseline, title="Comparison") 31 | -------------------------------------------------------------------------------- /n_body_system/dataset/generate_dataset.py: -------------------------------------------------------------------------------- 1 | from synthetic_sim import ChargedParticlesSim, SpringSim 2 | import time 3 | import numpy as np 4 | import argparse 5 | 6 | """ 7 | nbody: python -u generate_dataset.py --num-train 50000 --sample-freq 500 2>&1 | tee log_generating_100000.log & 8 | 9 | nbody_small: python -u generate_dataset.py --num-train 10000 --seed 43 --sufix small 2>&1 | tee log_generating_10000_small.log & 10 | 11 | """ 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--simulation', type=str, default='charged', 15 | help='What simulation to generate.') 16 | parser.add_argument('--num-train', type=int, default=10000, 17 | help='Number of training simulations to generate.') 18 | parser.add_argument('--num-valid', type=int, default=2000, 19 | help='Number of validation simulations to generate.') 20 | parser.add_argument('--num-test', type=int, default=2000, 21 | help='Number of test simulations to generate.') 22 | parser.add_argument('--length', type=int, default=5000, 23 | help='Length of trajectory.') 24 | parser.add_argument('--length_test', type=int, default=5000, 25 | help='Length of test set trajectory.') 26 | parser.add_argument('--sample-freq', type=int, default=100, 27 | help='How often to sample the trajectory.') 28 | parser.add_argument('--n_balls', type=int, default=5, 29 | help='Number of balls in the simulation.') 30 | parser.add_argument('--seed', type=int, default=42, 31 | help='Random seed.') 32 | parser.add_argument('--initial_vel', type=int, default=1, 33 | help='consider initial velocity') 34 | parser.add_argument('--sufix', type=str, default="", 35 | help='add a sufix to the name') 36 | 37 | args = parser.parse_args() 38 | 39 | initial_vel_norm = 0.5 40 | if not args.initial_vel: 41 | initial_vel_norm = 1e-16 42 | 43 | if args.simulation == 'springs': 44 | sim = SpringSim(noise_var=0.0, n_balls=args.n_balls) 45 | suffix = '_springs' 46 | elif args.simulation == 'charged': 47 | sim = ChargedParticlesSim(noise_var=0.0, n_balls=args.n_balls, vel_norm=initial_vel_norm) 48 | suffix = '_charged' 49 | else: 50 | raise ValueError('Simulation {} not implemented'.format(args.simulation)) 51 | 52 | suffix += str(args.n_balls) + "_initvel%d" % args.initial_vel + args.sufix 53 | np.random.seed(args.seed) 54 | 55 | print(suffix) 56 | 57 | 58 | def 
generate_dataset(num_sims, length, sample_freq): 59 | loc_all = list() 60 | vel_all = list() 61 | edges_all = list() 62 | charges_all = list() 63 | for i in range(num_sims): 64 | t = time.time() 65 | loc, vel, edges, charges = sim.sample_trajectory(T=length, 66 | sample_freq=sample_freq) 67 | if i % 100 == 0: 68 | print("Iter: {}, Simulation time: {}".format(i, time.time() - t)) 69 | loc_all.append(loc) 70 | vel_all.append(vel) 71 | edges_all.append(edges) 72 | charges_all.append(charges) 73 | 74 | charges_all = np.stack(charges_all) 75 | loc_all = np.stack(loc_all) 76 | vel_all = np.stack(vel_all) 77 | edges_all = np.stack(edges_all) 78 | 79 | return loc_all, vel_all, edges_all, charges_all 80 | 81 | if __name__ == "__main__": 82 | 83 | print("Generating {} training simulations".format(args.num_train)) 84 | loc_train, vel_train, edges_train, charges_train = generate_dataset(args.num_train, 85 | args.length, 86 | args.sample_freq) 87 | 88 | print("Generating {} validation simulations".format(args.num_valid)) 89 | loc_valid, vel_valid, edges_valid, charges_valid = generate_dataset(args.num_valid, 90 | args.length, 91 | args.sample_freq) 92 | 93 | print("Generating {} test simulations".format(args.num_test)) 94 | loc_test, vel_test, edges_test, charges_test = generate_dataset(args.num_test, 95 | args.length_test, 96 | args.sample_freq) 97 | 98 | np.save('loc_train' + suffix + '.npy', loc_train) 99 | np.save('vel_train' + suffix + '.npy', vel_train) 100 | np.save('edges_train' + suffix + '.npy', edges_train) 101 | np.save('charges_train' + suffix + '.npy', charges_train) 102 | 103 | np.save('loc_valid' + suffix + '.npy', loc_valid) 104 | np.save('vel_valid' + suffix + '.npy', vel_valid) 105 | np.save('edges_valid' + suffix + '.npy', edges_valid) 106 | np.save('charges_valid' + suffix + '.npy', charges_valid) 107 | 108 | np.save('loc_test' + suffix + '.npy', loc_test) 109 | np.save('vel_test' + suffix + '.npy', vel_test) 110 | np.save('edges_test' + suffix + '.npy', edges_test) 111 | np.save('charges_test' + suffix + '.npy', charges_test) 112 | 113 | -------------------------------------------------------------------------------- /n_body_system/dataset_nbody.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import random 4 | 5 | 6 | class NBodyDataset(): 7 | """ 8 | NBodyDataset 9 | 10 | """ 11 | def __init__(self, partition='train', max_samples=1e8, dataset_name="se3_transformer"): 12 | self.partition = partition 13 | if self.partition == 'val': 14 | self.sufix = 'valid' 15 | else: 16 | self.sufix = self.partition 17 | self.dataset_name = dataset_name 18 | if dataset_name == "nbody": 19 | self.sufix += "_charged5_initvel1" 20 | elif dataset_name == "nbody_small" or dataset_name == "nbody_small_out_dist": 21 | self.sufix += "_charged5_initvel1small" 22 | else: 23 | raise Exception("Wrong dataset name %s" % self.dataset_name) 24 | 25 | self.max_samples = int(max_samples) 26 | self.dataset_name = dataset_name 27 | self.data, self.edges = self.load() 28 | 29 | def load(self): 30 | loc = np.load('n_body_system/dataset/loc_' + self.sufix + '.npy') 31 | vel = np.load('n_body_system/dataset/vel_' + self.sufix + '.npy') 32 | edges = np.load('n_body_system/dataset/edges_' + self.sufix + '.npy') 33 | charges = np.load('n_body_system/dataset/charges_' + self.sufix + '.npy') 34 | 35 | loc, vel, edge_attr, edges, charges = self.preprocess(loc, vel, edges, charges) 36 | return (loc, vel, edge_attr, charges), edges 37 | 38 | 39 | 
def preprocess(self, loc, vel, edges, charges): 40 | # cast to torch and swap n_nodes <--> n_features dimensions 41 | loc, vel = torch.Tensor(loc).transpose(2, 3), torch.Tensor(vel).transpose(2, 3) 42 | n_nodes = loc.size(2) 43 | loc = loc[0:self.max_samples, :, :, :] # limit number of samples 44 | vel = vel[0:self.max_samples, :, :, :] # speed when starting the trajectory 45 | charges = charges[0:self.max_samples] 46 | edge_attr = [] 47 | 48 | #Initialize edges and edge_attributes 49 | rows, cols = [], [] 50 | for i in range(n_nodes): 51 | for j in range(n_nodes): 52 | if i != j: 53 | edge_attr.append(edges[:, i, j]) 54 | rows.append(i) 55 | cols.append(j) 56 | edges = [rows, cols] 57 | edge_attr = torch.Tensor(edge_attr).transpose(0, 1).unsqueeze(2) # swap n_nodes <--> batch_size and add nf dimension 58 | 59 | return torch.Tensor(loc), torch.Tensor(vel), torch.Tensor(edge_attr), edges, torch.Tensor(charges) 60 | 61 | def set_max_samples(self, max_samples): 62 | self.max_samples = int(max_samples) 63 | self.data, self.edges = self.load() 64 | ''' 65 | def preprocess_old(self, loc, vel, edges, charges): 66 | # cast to torch and swap n_nodes <--> n_features dimensions 67 | loc, vel = torch.Tensor(loc).transpose(2, 3), torch.Tensor(vel).transpose(2, 3) 68 | n_nodes = loc.size(2) 69 | loc0 = loc[0:self.max_samples, 0, :, :] # first location from the trajectory 70 | loc_last = loc[0:self.max_samples, -1, :, :] # last location from the trajectory 71 | vel = vel[0:self.max_samples, 0, :, :] # speed when starting the trajectory 72 | charges = charges[0:self.max_samples] 73 | edge_attr = [] 74 | 75 | #Initialize edges and edge_attributes 76 | rows, cols = [], [] 77 | for i in range(n_nodes): 78 | for j in range(n_nodes): 79 | if i != j: 80 | edge_attr.append(edges[:, i, j]) 81 | rows.append(i) 82 | cols.append(j) 83 | edges = [rows, cols] 84 | edge_attr = torch.Tensor(edge_attr).transpose(0, 1).unsqueeze(2) # swap n_nodes <--> batch_size and add nf dimension 85 | 86 | return torch.Tensor(loc0), torch.Tensor(vel), torch.Tensor(edge_attr), loc_last, edges, torch.Tensor(charges) 87 | ''' 88 | def get_n_nodes(self): 89 | return self.data[0].size(1) 90 | 91 | def __getitem__(self, i): 92 | loc, vel, edge_attr, charges = self.data 93 | loc, vel, edge_attr, charges = loc[i], vel[i], edge_attr[i], charges[i] 94 | 95 | if self.dataset_name == "nbody": 96 | frame_0, frame_T = 6, 8 97 | elif self.dataset_name == "nbody_small": 98 | frame_0, frame_T = 30, 40 99 | elif self.dataset_name == "nbody_small_out_dist": 100 | frame_0, frame_T = 20, 30 101 | else: 102 | raise Exception("Wrong dataset partition %s" % self.dataset_name) 103 | 104 | 105 | return loc[frame_0], vel[frame_0], edge_attr, charges, loc[frame_T] 106 | 107 | def __len__(self): 108 | return len(self.data[0]) 109 | 110 | def get_edges(self, batch_size, n_nodes): 111 | edges = [torch.LongTensor(self.edges[0]), torch.LongTensor(self.edges[1])] 112 | if batch_size == 1: 113 | return edges 114 | elif batch_size > 1: 115 | rows, cols = [], [] 116 | for i in range(batch_size): 117 | rows.append(edges[0] + n_nodes * i) 118 | cols.append(edges[1] + n_nodes * i) 119 | edges = [torch.cat(rows), torch.cat(cols)] 120 | return edges 121 | 122 | 123 | if __name__ == "__main__": 124 | NBodyDataset() 125 | -------------------------------------------------------------------------------- /ponita/geometry/spherical_grid.py: -------------------------------------------------------------------------------- 1 | # Taken from: https://github.com/ThijsKuipers1995/gconv 2 | 
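# (Editorial sketch, not part of the original file) Typical use:
#
#     ori_grid = uniform_grid_s2(12)   # (12, 3) unit vectors, approximately
#                                      # uniform w.r.t. geodesic distance on S2
#
# The grid starts from random points on S2 and is spread out by minimizing a
# Coulomb-like repulsion energy (see repulsion.py).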
3 | from math import pi
4 | from typing import Optional
5 | 
6 | import torch
7 | from torch import Tensor
8 | 
9 | # The helpers below are not defined in this file; they are assumed to be
10 | # provided by the (not listed) rotation module and by repulsion.py.
11 | from ponita.geometry import repulsion
12 | from ponita.geometry.rotation import (
13 |     euler_to_matrix,
14 |     euler_to_quat,
15 |     geodesic_distance_s2,
16 |     random_s2,
17 |     spherical_to_euclid,
18 |     spherical_to_euler,
19 |     spherical_to_euler_neg_gamma,
20 | )
21 | 
22 | 
23 | def uniform_grid_s2(
24 |     n: int,
25 |     parameterization: str = "euclidean",
26 |     set_alpha_as_neg_gamma: bool = False,
27 |     steps: int = 100,
28 |     step_size: float = 0.1,
29 |     show_pbar: bool = False,
30 |     device: Optional[str] = None,
31 | ) -> Tensor:
32 |     """
33 |     Creates a uniform grid of `n` rotations on S2. Rotations will be uniform
34 |     with respect to the geodesic distance.
35 | 
36 |     Arguments:
37 |         - n: Number of rotations in the grid.
38 |         - parameterization: Parameterization of the returned grid elements. Must
39 |           be either 'spherical', 'euclidean', 'quat', 'matrix', or 'euler'. Defaults to
40 |           'euclidean'.
41 |         - steps: Number of minimization steps.
42 |         - step_size: Strength of each minimization step. The default of 0.1 works well.
43 |         - show_pbar: If True, will show progress of the optimization procedure.
44 |         - device: Device on which energy minimization will be performed and on
45 |           which the output grid will be defined.
46 | 
47 |     Returns:
48 |         - Tensor containing the uniform grid (on S2, or on SO(3) for the 'euler',
49 |           'matrix', and 'quat' parameterizations).
50 |     """
51 |     add_alpha = False
52 |     to_so3_fn = (
53 |         spherical_to_euler_neg_gamma if set_alpha_as_neg_gamma else spherical_to_euler
54 |     )
55 | 
56 |     match parameterization.lower():
57 |         case "spherical":
58 |             param_fn = lambda x: x
59 |         case "euclidean":
60 |             param_fn = spherical_to_euclid
61 |         case "euler":
62 |             add_alpha = True
63 |             param_fn = lambda x: x
64 |         case "matrix":
65 |             add_alpha = True
66 |             param_fn = euler_to_matrix
67 |         case "quat":
68 |             add_alpha = True
69 |             param_fn = euler_to_quat
70 |         case _:
71 |             raise ValueError(f"Unknown parameterization: {parameterization}")
72 | 
73 |     grid = random_s2((n,), device=device)
74 | 
75 |     repulsion.repulse(
76 |         grid,
77 |         steps=steps,
78 |         step_size=step_size,
79 |         alpha=0.001,
80 |         metric_fn=geodesic_distance_s2,
81 |         transform_fn=spherical_to_euclid,
82 |         dist_normalization_constant=pi,
83 |         show_pbar=show_pbar,
84 |         in_place=True,
85 |     )
86 | 
87 |     grid = to_so3_fn(grid) if add_alpha else grid
88 | 
89 |     return param_fn(grid)
--------------------------------------------------------------------------------
/ponita/nn/conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch_geometric
3 | from typing import Optional
4 | 
5 | 
6 | class Conv(torch_geometric.nn.MessagePassing):
7 |     """
8 |     Message passing layer whose convolution kernel is a learned linear map of
9 |     the invariant edge attributes (attr_dim -> in_channels * out_channels / groups).
10 |     """
11 |     def __init__(self, in_channels, out_channels, attr_dim, bias=True, aggr="add", groups=1):
12 |         super().__init__(node_dim=0, aggr=aggr)
13 | 
14 |         # Check arguments
15 |         if groups==1:
16 |             self.depthwise = False
17 |         elif groups==in_channels and groups==out_channels:
18 |             self.depthwise = True
19 |         else:
20 |             raise ValueError('Invalid option for groups, should be groups=1 or groups=in_channels=out_channels (depth-wise separable)')
21 |         self.in_channels = in_channels
22 |         self.out_channels = out_channels
23 | 
24 |         # Construct kernel and bias
25 |         self.kernel = torch.nn.Linear(attr_dim, int(in_channels * out_channels / groups), bias=False)
26 |         if bias:
27 |             self.bias = torch.nn.Parameter(torch.empty(out_channels))
28 |             self.bias.data.zero_()
29 |         else:
30 |             self.register_parameter('bias', None)
31 | 
32 |         # Automatic re-initialization
33 |         self.register_buffer("callibrated", torch.tensor(False))
34 | 
35 |     def forward(self, x, edge_index, edge_attr, **kwargs):
36 |         """
37 |         Convolve node features x over edge_index with kernels sampled from edge_attr.
38 |         """
39 |         # Sample the convolution kernels
40 |         kernel = self.kernel(edge_attr)
41 | 
42 |         # Do the convolution
43 |         out = self.propagate(edge_index, x=x, kernel=kernel)
44 | 
45 |         # Re-calibrate the initialization
46 |         if self.training and not(self.callibrated):
47 |             self.callibrate(x.std(), out.std())
48 | 
49 |         # Add bias
50 |         if self.bias is not None:
51 |             return out + self.bias
52 |         else:
53 |             return out
54 | 
55 |     def message(self, x_j, kernel):
56 |         if self.depthwise:
57 |             return kernel * x_j
58 |         else:
59 |             return torch.einsum('boi,bi->bo', kernel.unflatten(-1, (self.out_channels, self.in_channels)), x_j)
60 | 
61 |     def callibrate(self, std_in, std_out):
62 |         print('Calibrating...')
63 |         with torch.no_grad():
64 |             self.kernel.weight.data = self.kernel.weight.data * std_in/std_out
65 |             self.callibrated = ~self.callibrated
66 | 
67 | 
68 | class FiberBundleConv(torch_geometric.nn.MessagePassing):
69 |     """
70 |     Separable fiber-bundle convolution: a spatial convolution applied per
71 |     orientation, optionally followed by a convolution over the orientation fiber.
72 |     """
73 |     def __init__(self, in_channels, out_channels, attr_dim, bias=True, aggr="add", separable=True, groups=1):
74 |         super().__init__(node_dim=0, aggr=aggr)
75 | 
76 |         # Check arguments
77 |         if groups==1:
78 |             self.depthwise = False
79 |         elif groups==in_channels and groups==out_channels:
80 |             self.depthwise = True
81 |         else:
82 |             raise ValueError('Invalid option for groups, should be groups=1 or groups=in_channels=out_channels (depth-wise separable)')
83 |         self.in_channels = in_channels
84 |         self.out_channels = out_channels
85 | 
86 |         # Construct kernels
87 |         self.separable = separable
88 |         if self.separable:
89 |             self.kernel = torch.nn.Linear(attr_dim, in_channels, bias=False)
90 |             self.fiber_kernel = torch.nn.Linear(attr_dim, int(in_channels * out_channels / groups), bias=False)
91 |         else:
92 |             self.kernel = torch.nn.Linear(attr_dim, int(in_channels * out_channels / groups), bias=False)
93 | 
94 |         # Construct bias
95 |         if bias:
96 |             self.bias = torch.nn.Parameter(torch.empty(out_channels))
97 |             self.bias.data.zero_()
98 |         else:
99 |             self.register_parameter('bias', None)
100 | 
101 |         # Automatic re-initialization
102 |         self.register_buffer("callibrated", torch.tensor(False))
103 | 
104 |     def forward(self, x, edge_index, edge_attr, fiber_attr=None, **kwargs):
105 |         """
106 |         Apply 1. the spatial convolution, then 2. the spherical (fiber) convolution.
107 |         """
108 |         # Do the convolutions: 1. Spatial conv, 2. Spherical conv
109 |         kernel = self.kernel(edge_attr)
110 |         x_1 = self.propagate(edge_index, x=x, kernel=kernel)
111 |         if self.separable:
112 |             fiber_kernel = self.fiber_kernel(fiber_attr)
113 |             if self.depthwise:
114 |                 x_2 = torch.einsum('boc,opc->bpc', x_1, fiber_kernel) / fiber_kernel.shape[-2]
115 |             else:
116 |                 x_2 = torch.einsum('boc,opdc->bpd', x_1, fiber_kernel.unflatten(-1, (self.out_channels, self.in_channels))) / fiber_kernel.shape[-2]
117 |         else:
118 |             x_2 = x_1
119 | 
120 |         # Re-calibrate the initialization
121 |         if self.training and not(self.callibrated):
122 |             self.callibrate(x.std(), x_1.std(), x_2.std())
123 | 
124 |         # Add bias
125 |         if self.bias is not None:
126 |             return x_2 + self.bias
127 |         else:
128 |             return x_2
129 | 
130 |     def message(self, x_j, kernel):
131 |         if self.separable:
132 |             return kernel * x_j
133 |         else:
134 |             if self.depthwise:
135 |                 return torch.einsum('bopc,boc->bpc', kernel, x_j)
136 |             else:
137 |                 return torch.einsum('bopdc,boc->bpd', kernel.unflatten(-1, (self.out_channels, self.in_channels)), x_j)
138 | 
139 |     def callibrate(self, std_in, std_1, std_2):
140 |         print('Calibrating...')
141 |         with torch.no_grad():
142 |             self.kernel.weight.data = self.kernel.weight.data * std_in/std_1
143 |             if self.separable:
144 |                 self.fiber_kernel.weight.data = self.fiber_kernel.weight.data * std_1/std_2
145 |             self.callibrated = ~self.callibrated
146 | 
--------------------------------------------------------------------------------
/n_body_system/se3_dynamics/equivariant_attention/fibers.py:
--------------------------------------------------------------------------------
1 | from ..utils.utils_profiling import * # load before other local modules
2 | 
3 | import math
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import copy
9 | 
10 | from typing import Dict, List, Tuple
11 | 
12 | 
13 | class Fiber(object):
14 |     """A Handy Data Structure for Fibers"""
15 |     def __init__(self, num_degrees: int=None, num_channels: int=None,
16 |                  structure: List[Tuple[int,int]]=None, dictionary=None):
17 |         if structure:
18 |             self.structure = structure
19 |         elif dictionary:
20 |             self.structure = [(dictionary[o], o) for o in sorted(dictionary.keys())]
21 |         else:
22 |             self.structure = [(num_channels, i) for i in range(num_degrees)]
23 | 
24 |         self.multiplicities, self.degrees = 
--------------------------------------------------------------------------------
/n_body_system/se3_dynamics/equivariant_attention/fibers.py:
--------------------------------------------------------------------------------
1 | from ..utils.utils_profiling import * # load before other local modules
2 | 
3 | import math
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import copy
9 | 
10 | from typing import Dict, List, Tuple
11 | 
12 | 
13 | class Fiber(object):
14 |     """A Handy Data Structure for Fibers"""
15 |     def __init__(self, num_degrees: int=None, num_channels: int=None,
16 |                  structure: List[Tuple[int,int]]=None, dictionary=None):
17 |         if structure:
18 |             self.structure = structure
19 |         elif dictionary:
20 |             self.structure = [(dictionary[o], o) for o in sorted(dictionary.keys())]
21 |         else:
22 |             self.structure = [(num_channels, i) for i in range(num_degrees)]
23 | 
24 |         self.multiplicities, self.degrees = zip(*self.structure)
25 |         self.max_degree = max(self.degrees)
26 |         self.min_degree = min(self.degrees)
27 |         self.structure_dict = {k: v for v, k in self.structure}
28 |         self.n_features = np.sum([i[0] * (2*i[1]+1) for i in self.structure])
29 | 
30 |         self.feature_indices = {}
31 |         idx = 0
32 |         for (num_channels, d) in self.structure:
33 |             length = num_channels * (2*d + 1)
34 |             self.feature_indices[d] = (idx, idx + length)
35 |             idx += length
36 | 
37 |     def copy_me(self, multiplicity: int=None):
38 |         s = copy.deepcopy(self.structure)
39 |         if multiplicity is not None:
40 |             # overwrite multiplicities
41 |             s = [(multiplicity, o) for m, o in s]
42 |         return Fiber(structure=s)
43 | 
44 |     @staticmethod
45 |     def combine(f1, f2):
46 |         new_dict = copy.deepcopy(f1.structure_dict)
47 |         for k, m in f2.structure_dict.items():
48 |             if k in new_dict.keys():
49 |                 new_dict[k] += m
50 |             else:
51 |                 new_dict[k] = m
52 |         structure = [(new_dict[k], k) for k in sorted(new_dict.keys())]
53 |         return Fiber(structure=structure)
54 | 
55 |     @staticmethod
56 |     def combine_max(f1, f2):
57 |         new_dict = copy.deepcopy(f1.structure_dict)
58 |         for k, m in f2.structure_dict.items():
59 |             if k in new_dict.keys():
60 |                 new_dict[k] = max(m, new_dict[k])
61 |             else:
62 |                 new_dict[k] = m
63 |         structure = [(new_dict[k], k) for k in sorted(new_dict.keys())]
64 |         return Fiber(structure=structure)
65 | 
66 |     @staticmethod
67 |     def combine_selectively(f1, f2):
68 |         # only use orders which occur in fiber f1
69 | 
70 |         new_dict = copy.deepcopy(f1.structure_dict)
71 |         for k in f1.degrees:
72 |             if k in f2.degrees:
73 |                 new_dict[k] += f2.structure_dict[k]
74 |         structure = [(new_dict[k], k) for k in sorted(new_dict.keys())]
75 |         return Fiber(structure=structure)
76 | 
77 |     @staticmethod
78 |     def combine_fibers(val1, struc1, val2, struc2):
79 |         """
80 |         combine two fibers
81 | 
82 |         :param val1/2: fiber tensors in dictionary form
83 |         :param struc1/2: structure of fiber
84 |         :return: fiber tensor in dictionary form
85 |         """
86 |         struc_out = Fiber.combine(struc1, struc2)
87 |         val_out = {}
88 |         for k in struc_out.degrees:
89 |             if k in struc1.degrees:
90 |                 if k in struc2.degrees:
91 |                     val_out[k] = torch.cat([val1[k], val2[k]], -2)
92 |                 else:
93 |                     val_out[k] = val1[k]
94 |             else:
95 |                 val_out[k] = val2[k]
96 |             assert val_out[k].shape[-2] == struc_out.structure_dict[k]
97 |         return val_out
98 | 
99 |     def __repr__(self):
100 |         return f"{self.structure}"
101 | 
102 | 
103 | 
104 | def get_fiber_dict(F, struc, mask=None, return_struc=False):
105 |     if mask is None: mask = struc
106 |     index = 0
107 |     fiber_dict = {}
108 |     first_dims = F.shape[:-1]
109 |     masked_dict = {}
110 |     for o, m in struc.structure_dict.items():
111 |         length = m * (2*o + 1)
112 |         if o in mask.degrees:
113 |             masked_dict[o] = m
114 |             fiber_dict[o] = F[...,index:index + length].view(list(first_dims) + [m, 2*o + 1])
115 |         index += length
116 |     assert F.shape[-1] == index
117 |     if return_struc:
118 |         return fiber_dict, Fiber(dictionary=masked_dict)
119 |     return fiber_dict
120 | 
121 | 
122 | def get_fiber_tensor(F, struc):
123 |     some_entry = tuple(F.values())[0]
124 |     first_dims = some_entry.shape[:-2]
125 |     res = some_entry.new_empty([*first_dims, struc.n_features])
126 |     index = 0
127 |     for o, m in struc.structure_dict.items():
128 |         length = m * (2*o + 1)
129 |         res[..., index: index + length] = F[o].view(*first_dims, length)
130 |         index += length
131 |     assert index == res.shape[-1]
132 |     return res
133 | 
134 | 
135 | def fiber2tensor(F, structure, squeeze=False):
136 |     if squeeze:
137 |         fibers = [F[f'{i}'].view(*F[f'{i}'].shape[:-2], -1) for i in structure.degrees]
138 |         fibers = torch.cat(fibers, -1)
139 |     else:
140 |         fibers = [F[f'{i}'].view(*F[f'{i}'].shape[:-2], -1, 1) for i in structure.degrees]
141 |         fibers = torch.cat(fibers, -2)
142 |     return fibers
143 | 
144 | 
145 | def fiber2head(F, h, structure, squeeze=False):
146 |     if squeeze:
147 |         fibers = [F[f'{i}'].view(*F[f'{i}'].shape[:-2], h, -1) for i in structure.degrees]
148 |         fibers = torch.cat(fibers, -1)
149 |     else:
150 |         fibers = [F[f'{i}'].view(*F[f'{i}'].shape[:-2], h, -1, 1) for i in structure.degrees]
151 |         fibers = torch.cat(fibers, -2)
152 |     return fibers
153 | 
154 | 
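# Editor's example (not part of the original file): Fiber(num_degrees=2,
# num_channels=4) describes features with 4 channels of degree 0 and 4 channels
# of degree 1, i.e.
#
#   f = Fiber(num_degrees=2, num_channels=4)
#   f.structure                                            # [(4, 0), (4, 1)]
#   f.n_features                                           # 4*1 + 4*3 = 16
#   Fiber.combine(f, Fiber(structure=[(2, 1)])).structure  # [(4, 0), (6, 1)]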
--------------------------------------------------------------------------------
/ponita/geometry/invariants.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | def invariant_attr_rn(pos, edge_index):
5 |     pos_send, pos_receive = pos[edge_index[0]], pos[edge_index[1]]  # [num_edges, n]
6 |     dists = (pos_send - pos_receive).norm(dim=-1, keepdim=True)  # [num_edges, 1]
7 |     return dists
8 | 
9 | 
10 | def invariant_attr_r3s2_fiber_bundle(pos, ori_grid, edge_index, separable=False):
11 |     pos_send, pos_receive = pos[edge_index[0]], pos[edge_index[1]]  # [num_edges, 3]
12 |     rel_pos = (pos_send - pos_receive)  # [num_edges, 3]
13 | 
14 |     # Convenient shape
15 |     rel_pos = rel_pos[:, None, :]  # [num_edges, 1, 3]
16 |     ori_grid_a = ori_grid[None,:,:]  # [1, num_ori, 3]
17 |     ori_grid_b = ori_grid[:, None,:]  # [num_ori, 1, 3]
18 | 
19 |     invariant1 = (rel_pos * ori_grid_a).sum(dim=-1, keepdim=True)  # [num_edges, num_ori, 1]
20 |     invariant2 = (rel_pos - invariant1 * ori_grid_a).norm(dim=-1, keepdim=True)  # [num_edges, num_ori, 1]
21 |     invariant3 = (ori_grid_a * ori_grid_b).sum(dim=-1, keepdim=True)  # [num_ori, num_ori, 1]
22 | 
23 |     # Note: We could apply the acos = pi/2 - asin, which is differentiable at -1 and 1
24 |     # But found that this mapping is unnecessary as it is monotonic and mostly linear
25 |     # anyway, except close to -1 and 1. Not applying the arccos worked just as well.
26 |     # invariant3 = torch.pi / 2 - torch.asin(invariant3.clamp(-1.,1.))
27 | 
28 |     if separable:
29 |         return torch.cat([invariant1, invariant2],dim=-1), invariant3  # [num_edges, num_ori, 2], [num_ori, num_ori, 1]
30 |     else:
31 |         invariant1 = invariant1[:,:,None,:].expand(-1,-1,ori_grid.shape[0],-1)  # [num_edges, num_ori, num_ori, 1]
32 |         invariant2 = invariant2[:,:,None,:].expand(-1,-1,ori_grid.shape[0],-1)  # [num_edges, num_ori, num_ori, 1]
33 |         invariant3 = invariant3[None,:,:,:].expand(invariant1.shape[0],-1,-1,-1)  # [num_edges, num_ori, num_ori, 1]
34 |         return torch.cat([invariant1, invariant2, invariant3],dim=-1)  # [num_edges, num_ori, num_ori, 3]
35 | 
36 | def invariant_attr_r3s2_point_cloud(pos, edge_index):
37 |     pos_send, pos_receive = pos[edge_index[0],:3], pos[edge_index[1],:3]  # [num_edges, 3]
38 |     ori_send, ori_receive = pos[edge_index[0],3:], pos[edge_index[1],3:]  # [num_edges, 3]
39 |     rel_pos = pos_send - pos_receive  # [num_edges, 3]
40 | 
41 |     invariant1 = torch.sum(rel_pos * ori_receive, dim=-1, keepdim=True)
42 |     invariant2 = (rel_pos - ori_receive * invariant1).norm(dim=-1, keepdim=True)
43 |     invariant3 = torch.sum(ori_send * ori_receive, dim=-1, keepdim=True)
44 | 
45 |     return torch.cat([invariant1, invariant2, invariant3],dim=-1)  # [num_edges, 3]
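# Editor's sanity check (illustrative, not in the original file): the attributes
# above are invariant under a global rotation applied to both positions and
# orientations, e.g.
#
#   R = torch.linalg.qr(torch.randn(3, 3)).Q   # random orthogonal matrix
#   pos_rot = torch.cat([pos[:, :3] @ R.T, pos[:, 3:] @ R.T], dim=-1)
#   invariant_attr_r3s2_point_cloud(pos_rot, edge_index) equals
#   invariant_attr_r3s2_point_cloud(pos, edge_index) up to numerical precision.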
46 | 
47 | def invariant_attr_r2s1_fiber_bundle(pos, ori_grid, edge_index, separable=False):
48 |     pos_send, pos_receive = pos[edge_index[0]], pos[edge_index[1]]  # [num_edges, 2]
49 |     rel_pos = (pos_send - pos_receive)  # [num_edges, 2]
50 | 
51 |     # Convenient shape
52 |     rel_pos = rel_pos[:, None, :]  # [num_edges, 1, 2]
53 |     ori_grid_a = ori_grid[None,:,:]  # [1, num_ori, 2]
54 |     ori_grid_b = ori_grid[:, None,:]  # [num_ori, 1, 2]
55 | 
56 |     # Note ori_grid consists of tuples (ori[0], ori[1]) = (cos t, sin t)
57 |     # A transposed rotation (cos t, sin t \\ - sin t, cos t) is then
58 |     # achieved as (ori[0], ori[1] \\ -ori[1], ori[0]):
59 |     invariant1 = (rel_pos[...,0] * ori_grid_a[...,0] + rel_pos[...,1] * ori_grid_a[...,1]).unsqueeze(-1)
60 |     invariant2 = (- rel_pos[...,0] * ori_grid_a[...,1] + rel_pos[...,1] * ori_grid_a[...,0]).unsqueeze(-1)
61 |     invariant3 = (ori_grid_a * ori_grid_b).sum(dim=-1, keepdim=True)  # [num_ori, num_ori, 1]
62 | 
63 |     # Note: We could apply the acos = pi/2 - asin, which is differentiable at -1 and 1
64 |     # But found that this mapping is unnecessary as it is monotonic and mostly linear
65 |     # anyway, except close to -1 and 1. Not applying the arccos worked just as well.
66 |     # invariant3 = torch.pi / 2 - torch.asin(invariant3.clamp(-1.,1.))
67 | 
68 |     if separable:
69 |         return torch.cat([invariant1, invariant2],dim=-1), invariant3  # [num_edges, num_ori, 2], [num_ori, num_ori, 1]
70 |     else:
71 |         invariant1 = invariant1[:,:,None,:].expand(-1,-1,ori_grid.shape[0],-1)  # [num_edges, num_ori, num_ori, 1]
72 |         invariant2 = invariant2[:,:,None,:].expand(-1,-1,ori_grid.shape[0],-1)  # [num_edges, num_ori, num_ori, 1]
73 |         invariant3 = invariant3[None,:,:,:].expand(invariant1.shape[0],-1,-1,-1)  # [num_edges, num_ori, num_ori, 1]
74 |         return torch.cat([invariant1, invariant2, invariant3],dim=-1)  # [num_edges, num_ori, num_ori, 3]
75 | 
76 | def invariant_attr_r2s1_point_cloud(pos, edge_index):
77 |     pos_send, pos_receive = pos[edge_index[0],:2], pos[edge_index[1],:2]  # [num_edges, 2]
78 |     ori_send, ori_receive = pos[edge_index[0],2:], pos[edge_index[1],2:]  # [num_edges, 2]
79 | 
80 |     rel_pos = pos_send - pos_receive  # [num_edges, 2]
81 |     invariant1 = rel_pos[...,0] * ori_receive[...,0] + rel_pos[...,1] * ori_receive[...,1]
82 |     invariant2 = - rel_pos[...,0] * ori_receive[...,1] + rel_pos[...,1] * ori_receive[...,0]
83 |     invariant3 = torch.sum(ori_send * ori_receive, dim=-1, keepdim=False)
84 | 
85 |     return torch.stack([invariant1, invariant2, invariant3],dim=-1)  # [num_edges, 3]
86 | 
--------------------------------------------------------------------------------
/lightning_wrappers/nbody.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchmetrics
3 | import pytorch_lightning as pl
4 | 
5 | from .scheduler import CosineWarmupScheduler
6 | from ponita.models.ponita import PonitaFiberBundle
7 | from ponita.transforms.random_rotate import RandomRotate
8 | 
9 | 
10 | class PONITA_NBODY(pl.LightningModule):
11 |     """Lightning wrapper for the PONITA model on the N-body task."""
12 | 
13 |     def __init__(self, args):
14 |         super().__init__()
15 | 
16 |         # Store some of the relevant args
17 |         self.lr = args.lr
18 |         self.weight_decay = args.weight_decay
19 |         self.epochs = args.epochs
20 |         self.warmup = args.warmup
21 |         if args.layer_scale == 0.:
22 |             args.layer_scale = None
23 | 
24 |         # For rotation augmentations during training and testing
25 |         self.train_augm = args.train_augm
26 |         self.rotation_transform = RandomRotate(['pos','vec','y'], n=3)
27 | 
28 |         # The metrics to log
29 |         self.train_metric = torchmetrics.MeanSquaredError()
30 |         self.valid_metric = torchmetrics.MeanSquaredError()
31 |         self.test_metric = torchmetrics.MeanSquaredError()
32 | 
33 |         # Input/output specifications:
34 |         in_channels_scalar = 1  # Charge
35 |         in_channels_vec = 1  # Velocity
36 |         out_channels_scalar = 0  # None
37 |         out_channels_vec = 1  # Output velocity
38 | 
39 |         # Make the model
40 |         self.model = PonitaFiberBundle(in_channels_scalar + in_channels_vec,
41 |                                        args.hidden_dim,
42 |                                        out_channels_scalar,
43 |                                        args.layers,
44 |                                        output_dim_vec = out_channels_vec,
45 |                                        radius=args.radius,
46 |                                        num_ori=args.num_ori,
47 |                                        basis_dim=args.basis_dim,
48 |                                        degree=args.degree,
49 |                                        widening_factor=args.widening_factor,
50 |                                        layer_scale=args.layer_scale,
51 |                                        task_level='node',
52 |                                        multiple_readouts=args.multiple_readouts)
53 | 
54 |     def forward(self, graph):
55 |         _, pred = self.model(graph)
56 |         return graph.pos + pred[..., 0, :]
57 | 
58 |     def training_step(self, graph):
59 |         if self.train_augm:
60 |             graph = self.rotation_transform(graph)
61 |         pos_pred = self(graph)
62 |         loss = torch.mean((pos_pred - graph.y)**2)
63 |         self.train_metric(pos_pred, graph.y)
64 |         return loss
65 
| 66 | def on_train_epoch_end(self): 67 | self.log("train MSE", self.train_metric, prog_bar=True) 68 | 69 | def validation_step(self, graph, batch_idx): 70 | pos_pred = self(graph) 71 | self.valid_metric(pos_pred, graph.y) 72 | 73 | def on_validation_epoch_end(self): 74 | self.log("valid MSE", self.valid_metric, prog_bar=True) 75 | 76 | def test_step(self, graph, batch_idx): 77 | pos_pred = self(graph) 78 | self.test_metric(pos_pred, graph.y) 79 | 80 | def on_test_epoch_end(self): 81 | self.log("test MSE", self.test_metric) 82 | 83 | def configure_optimizers(self): 84 | """ 85 | Adapted from: https://github.com/karpathy/minGPT/blob/master/mingpt/model.py 86 | This long function is unfortunately doing something very simple and is being very defensive: 87 | We are separating out all parameters of the model into two buckets: those that will experience 88 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights). 89 | We are then returning the PyTorch optimizer object. 90 | """ 91 | 92 | # separate out all parameters to those that will and won't experience regularizing weight decay 93 | decay = set() 94 | no_decay = set() 95 | whitelist_weight_modules = (torch.nn.Linear, ) 96 | blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) 97 | for mn, m in self.named_modules(): 98 | for pn, p in m.named_parameters(): 99 | fpn = '%s.%s' % (mn, pn) if mn else pn # full param name 100 | # random note: because named_modules and named_parameters are recursive 101 | # we will see the same tensors p many many times. but doing it this way 102 | # allows us to know which parent module any tensor p belongs to... 103 | if pn.endswith('bias'): 104 | # all biases will not be decayed 105 | no_decay.add(fpn) 106 | elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): 107 | # weights of whitelist modules will be weight decayed 108 | decay.add(fpn) 109 | elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): 110 | # weights of blacklist modules will NOT be weight decayed 111 | no_decay.add(fpn) 112 | elif pn.endswith('layer_scale'): 113 | no_decay.add(fpn) 114 | 115 | # validate that we considered every parameter 116 | param_dict = {pn: p for pn, p in self.named_parameters()} 117 | inter_params = decay & no_decay 118 | union_params = decay | no_decay 119 | assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) 120 | assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" 
\ 121 | % (str(param_dict.keys() - union_params), ) 122 | 123 | # create the pytorch optimizer object 124 | optim_groups = [ 125 | {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": self.weight_decay}, 126 | {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, 127 | ] 128 | optimizer = torch.optim.Adam(optim_groups, lr=self.lr) 129 | scheduler = CosineWarmupScheduler(optimizer, self.warmup, self.trainer.max_epochs) 130 | return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"} -------------------------------------------------------------------------------- /n_body_system/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from models.gcl import GCL, E_GCL, E_GCL_vel, GCL_rf_vel 4 | 5 | 6 | 7 | class GNN(nn.Module): 8 | def __init__(self, input_dim, hidden_nf, device='cpu', act_fn=nn.SiLU(), n_layers=4, attention=0, recurrent=False): 9 | super(GNN, self).__init__() 10 | self.hidden_nf = hidden_nf 11 | self.device = device 12 | self.n_layers = n_layers 13 | ### Encoder 14 | #self.add_module("gcl_0", GCL(self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_nf=1, act_fn=act_fn, attention=attention, recurrent=recurrent)) 15 | for i in range(0, n_layers): 16 | self.add_module("gcl_%d" % i, GCL(self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_nf=1, act_fn=act_fn, attention=attention, recurrent=recurrent)) 17 | 18 | self.decoder = nn.Sequential(nn.Linear(hidden_nf, hidden_nf), 19 | act_fn, 20 | nn.Linear(hidden_nf, 3)) 21 | self.embedding = nn.Sequential(nn.Linear(input_dim, hidden_nf)) 22 | self.to(self.device) 23 | 24 | 25 | def forward(self, nodes, edges, edge_attr=None): 26 | h = self.embedding(nodes) 27 | #h, _ = self._modules["gcl_0"](h, edges, edge_attr=edge_attr) 28 | for i in range(0, self.n_layers): 29 | h, _ = self._modules["gcl_%d" % i](h, edges, edge_attr=edge_attr) 30 | #return h 31 | return self.decoder(h) 32 | 33 | 34 | def get_velocity_attr(loc, vel, rows, cols): 35 | #return torch.cat([vel[rows], vel[cols]], dim=1) 36 | 37 | diff = loc[cols] - loc[rows] 38 | norm = torch.norm(diff, p=2, dim=1).unsqueeze(1) 39 | u = diff/norm 40 | va, vb = vel[rows] * u, vel[cols] * u 41 | va, vb = torch.sum(va, dim=1).unsqueeze(1), torch.sum(vb, dim=1).unsqueeze(1) 42 | return va 43 | 44 | class EGNN(nn.Module): 45 | def __init__(self, in_node_nf, in_edge_nf, hidden_nf, device='cpu', act_fn=nn.LeakyReLU(0.2), n_layers=4, coords_weight=1.0): 46 | super(EGNN, self).__init__() 47 | self.hidden_nf = hidden_nf 48 | self.device = device 49 | self.n_layers = n_layers 50 | #self.reg = reg 51 | ### Encoder 52 | #self.add_module("gcl_0", E_GCL(in_node_nf, self.hidden_nf, self.hidden_nf, edges_in_d=in_edge_nf, act_fn=act_fn, recurrent=False, coords_weight=coords_weight)) 53 | self.embedding = nn.Linear(in_node_nf, self.hidden_nf) 54 | for i in range(0, n_layers): 55 | self.add_module("gcl_%d" % i, E_GCL(self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_d=in_edge_nf, act_fn=act_fn, recurrent=True, coords_weight=coords_weight)) 56 | self.to(self.device) 57 | 58 | 59 | def forward(self, h, x, edges, edge_attr, vel=None): 60 | h = self.embedding(h) 61 | for i in range(0, self.n_layers): 62 | #if vel is not None: 63 | #vel_attr = get_velocity_attr(x, vel, edges[0], edges[1]) 64 | #edge_attr = torch.cat([edge_attr0, vel_attr], dim=1).detach() 65 | h, x, _ = self._modules["gcl_%d" % i](h, edges, x, edge_attr=edge_attr) 66 | return x 
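# Editor's note (illustrative, not in the original file): EGNN consumes node
# features, coordinates and an edge index pair and returns updated coordinates.
# A minimal call, assuming the [rows, cols] edge format of the original EGNN
# code (models.gcl is not shown in this repository dump):
#
#   model = EGNN(in_node_nf=1, in_edge_nf=1, hidden_nf=32)
#   h = torch.ones(5, 1)                                   # e.g. charges
#   x = torch.randn(5, 3)                                  # positions
#   edges = [torch.tensor([0, 0, 1, 2]), torch.tensor([1, 2, 0, 0])]
#   edge_attr = torch.ones(4, 1)
#   x_new = model(h, x, edges, edge_attr)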
67 | 
68 | 
69 | class EGNN_vel(nn.Module):
70 |     def __init__(self, in_node_nf, in_edge_nf, hidden_nf, device='cpu', act_fn=nn.SiLU(), n_layers=4, coords_weight=1.0, recurrent=False, norm_diff=False, tanh=False):
71 |         super(EGNN_vel, self).__init__()
72 |         self.hidden_nf = hidden_nf
73 |         self.device = device
74 |         self.n_layers = n_layers
75 |         #self.reg = reg
76 |         ### Encoder
77 |         #self.add_module("gcl_0", E_GCL(in_node_nf, self.hidden_nf, self.hidden_nf, edges_in_d=in_edge_nf, act_fn=act_fn, recurrent=False, coords_weight=coords_weight))
78 |         self.embedding = nn.Linear(in_node_nf, self.hidden_nf)
79 |         for i in range(0, n_layers):
80 |             self.add_module("gcl_%d" % i, E_GCL_vel(self.hidden_nf, self.hidden_nf, self.hidden_nf, edges_in_d=in_edge_nf, act_fn=act_fn, coords_weight=coords_weight, recurrent=recurrent, norm_diff=norm_diff, tanh=tanh))
81 |         self.to(self.device)
82 | 
83 | 
84 |     def forward(self, h, x, edges, vel, edge_attr):
85 |         h = self.embedding(h)
86 |         for i in range(0, self.n_layers):
87 |             h, x, _ = self._modules["gcl_%d" % i](h, edges, x, vel, edge_attr=edge_attr)
88 |         return x
89 | 
90 | class RF_vel(nn.Module):
91 |     def __init__(self, hidden_nf, edge_attr_nf=0, device='cpu', act_fn=nn.SiLU(), n_layers=4):
92 |         super(RF_vel, self).__init__()
93 |         self.hidden_nf = hidden_nf
94 |         self.device = device
95 |         self.n_layers = n_layers
96 |         #self.reg = reg
97 |         ### Encoder
98 |         #self.add_module("gcl_0", E_GCL(in_node_nf, self.hidden_nf, self.hidden_nf, edges_in_d=in_edge_nf, act_fn=act_fn, recurrent=False, coords_weight=coords_weight))
99 |         for i in range(0, n_layers):
100 |             self.add_module("gcl_%d" % i, GCL_rf_vel(nf=hidden_nf, edge_attr_nf=edge_attr_nf, act_fn=act_fn))
101 |         self.to(self.device)
102 | 
103 | 
104 |     def forward(self, vel_norm, x, edges, vel, edge_attr):
105 |         for i in range(0, self.n_layers):
106 |             x, _ = self._modules["gcl_%d" % i](x, vel_norm, vel, edges, edge_attr)
107 |         return x
108 | 
109 | class Baseline(nn.Module):
110 |     def __init__(self, device='cpu'):
111 |         super(Baseline, self).__init__()
112 |         self.dummy = nn.Linear(1, 1)
113 |         self.device = device
114 |         self.to(self.device)
115 | 
116 |     def forward(self, loc):
117 |         return loc
118 | 
119 | class Linear(nn.Module):
120 |     def __init__(self, input_nf, output_nf, device='cpu'):
121 |         super(Linear, self).__init__()
122 |         self.linear = nn.Linear(input_nf, output_nf)
123 |         self.device = device
124 |         self.to(self.device)
125 | 
126 |     def forward(self, input):
127 |         return self.linear(input)
128 | 
129 | class Linear_dynamics(nn.Module):
130 |     def __init__(self, device='cpu'):
131 |         super(Linear_dynamics, self).__init__()
132 |         self.time = nn.Parameter(torch.ones(1)*0.7)
133 |         self.device = device
134 |         self.to(self.device)
135 | 
136 |     def forward(self, x, v):
137 |         return x + v*self.time
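# Editor's example (not part of the original file): Linear_dynamics is a
# one-parameter Euler integrator, x_next = x + v * t, with a learnable time
# step t initialised at 0.7:
#
#   model = Linear_dynamics()
#   x, v = torch.zeros(5, 3), torch.ones(5, 3)
#   model(x, v)   # every entry equals 0.7 at initialisation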
--------------------------------------------------------------------------------
/n_body_system/se3_dynamics/dynamics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import dgl
4 | from torch import nn
5 | 
6 | from ..se3_dynamics.models import OurSE3Transformer, OursTFN
7 | from .utils.utils_profiling import *  # load before other local modules
8 | 
9 | 
10 | class OurDynamics(torch.nn.Module):
11 |     def __init__(self, n_particles, n_dimension, nf=16, n_layers=3, act_fn=nn.ReLU(), model="se3_transformer", num_degrees=4, div=1):
12 |         super().__init__()
13 |         #self._transformation = transformation
14 |         self._n_particles = n_particles
15 |         self._n_dimension = n_dimension
16 |         self._dim = self._n_particles * self._n_dimension
17 | 
18 |         if model == 'se3_transformer':
19 |             self.se3 = OurSE3Transformer(num_layers=n_layers,
20 |                                          num_channels=nf, edge_dim=0, act_fn=act_fn, num_degrees=num_degrees, div=div)
21 |         elif model == 'tfn':
22 |             self.se3 = OursTFN(num_layers=n_layers,
23 |                                num_channels=nf, edge_dim=0, div=1, act_fn=act_fn, num_degrees=num_degrees)
24 |         else:
25 |             raise ValueError("Unknown model '%s', expected 'se3_transformer' or 'tfn'" % model)
26 | 
27 | 
28 |         self.graph = None
29 | 
30 |     def forward(self, xs, vs, charges):
31 |         #n_batch = xs.shape[0]
32 |         xs = xs.view(-1, self._n_particles, self._n_dimension)
33 |         vs = vs.view(-1, self._n_particles, self._n_dimension)
34 | 
35 | 
36 |         # r = distance_vectors(xs)
37 |         # d = distances_from_vectors(r).unsqueeze(-1)
38 |         # rbfs = self._rbf_encoder(d)
39 |         # features = self._transformation(rbfs.view(-1, self._n_rbfs))
40 |         # potential = features.view(n_batch, -1).sum(-1)
41 |         #
42 |         # dxs = -1. * \
43 |         #     torch.autograd.grad(potential, xs, torch.ones_like(potential), create_graph=True, only_inputs=False)[0]
44 |         # return self._remove_mean(dxs)
45 | 
46 |         output = self.f(xs, vs, charges).view(-1, self._n_dimension)
47 | 
48 |         return output
49 | 
50 |     @profile
51 |     def f(self, xs, vs, charges):
52 |         """
53 |         :param xs:
54 |         :return: xs_outputs.size() = xs.size()
55 |         """
56 |         # xs.size() --> (batch_size: 64, num_nodes: 4, dim: 3)
57 | 
58 |         # features = xs.new_ones((xs.size(0), xs.size(1), 1))
59 |         # self.se3t()
60 | 
61 |         # 1. Transform xs to G
62 |         if self.graph is None:
63 |             self.graph = array_to_graph(xs)
64 |             self.graph.ndata['x'] = torch.zeros_like(self.graph.ndata['x'])
65 |             self.graph.edata['d'] = torch.zeros_like(self.graph.edata['d'])
66 | 
67 |             indices_src, indices_dst, _w = connect_fully(xs.size(1))  # [N, K]
68 |             self.indices_src = indices_src
69 |             self.indices_dst = indices_dst
70 | 
71 |         distance = xs[:, self.indices_dst] - xs[:, self.indices_src]
72 | 
73 |         # distance_gt = []
74 |         # for example in xs:
75 |         #     distance_gt.append(example[indices_dst] - example[indices_src])
76 |         #
77 |         # distance_gt = torch.stack(distance_gt)
78 |         #
79 |         # print('gt', distance_gt.size())
80 |         #
81 |         # print('distance_error', (distance_gt - distance).pow(2).mean())
82 |         #
83 |         # assert False
84 | 
85 |         #self.graph.ndata['x'] = xs.view(xs.size(0) * xs.size(1), 3)
86 |         self.graph.ndata['vel'] = vs.view(xs.size(0) * vs.size(1), 3).unsqueeze(1)
87 | 
88 |         self.graph.ndata['f1'] = self.graph.ndata['vel']  #torch.cat([self.graph.ndata['x'].unsqueeze(1), self.graph.ndata['vel']], dim=1)
89 | 
90 | 
91 |         self.graph.ndata['f'] = charges.unsqueeze(2)
92 |         self.graph.edata['d'] = distance.view(-1, 3)
93 | 
94 |         # G_gt = array_to_graph(xs)
95 |         # print((self.graph.edata['d'] - G_gt.edata['d']).pow(2).sum())
96 |         # assert False
97 | 
98 |         G = self.graph
99 | 
100 |         # 2. Transform G with se3t to G_out
101 |         G_out = self.se3(G)
102 | 
103 |         # 3. Transform G_out to out
104 | 
105 |         out = G_out['1'].view(xs.size())
106 | 
107 |         # out = xs  # TODO transform.
108 | 
109 | 
110 |         return out + xs
111 | 
112 | 
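# Editor's note (not in the original file): array_to_graph below builds one
# fully connected DGLGraph per batch element using connect_fully. As a worked
# example, connect_fully(3) returns the directed edge list of the complete
# graph on 3 nodes without self-loops:
#
#   src = [0, 0, 1, 1, 2, 2]
#   dst = [1, 2, 0, 2, 0, 1]
#   w   = [1, 1, 1, 1, 1, 1]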
113 | @profile
114 | def array_to_graph(xs):
115 |     B, N, D = xs.size()
116 |     ### create graph
117 |     # get u, v, and w
118 | 
119 |     # get neighbour indices here and use throughout entire network; this is a numpy function
120 |     indices_src, indices_dst, _w = connect_fully(xs.size(1))  # [N, K]
121 |     # indices_dst = indices_dst.flatten()  # [N*K]
122 |     # indices_src = np.repeat(np.array(range(N)), N-1)
123 | 
124 |     individual_graphs = []
125 |     for b in range(B):
126 |         example = xs[b]
127 | 
128 |         # example has shape [N, D=3]
129 | 
130 |         # Create graph (connections only, no bond or feature information yet)
131 |         G = dgl.DGLGraph((indices_src, indices_dst))
132 | 
133 |         ### add bond & feature information to graph
134 |         G.ndata['x'] = example  # node positions [N, ...]
135 |         G.ndata['f'] = example.new_ones(size=[N, 1, 1])
136 |         # G.ndata['f'] = torch.zeros_like(example)
137 |         # G.ndata['f'] = f[...,None]  # feature values [N, ...]
138 |         # the following two lines will use the same ordering as specified in dgl.DGLGraph((u, v))
139 |         # G.edata['w'] = w.astype(DTYPE)  # bond information
140 |         G.edata['d'] = example[indices_dst] - example[indices_src]  # relative position
141 | 
142 |         individual_graphs.append(G)
143 | 
144 |     batched_graph = dgl.batch(individual_graphs)
145 | 
146 |     # print(G.ndata['x'].size())
147 |     # print(G.edata['d'].size())
148 |     # print(batched_graph.ndata['x'].size())
149 |     # print(batched_graph.edata['d'].size())
150 |     #
151 |     # print(batched_graph)
152 |     # assert False
153 | 
154 |     return batched_graph
155 | 
156 | 
157 | def connect_fully(num_atoms):
158 |     """Convert to a fully connected graph"""
159 |     ## TO DO: For later, change to a single for loop
160 |     # Initialize all edges: no self-edges
161 |     adjacency = {}
162 |     for i in range(num_atoms):
163 |         for j in range(num_atoms):
164 |             if i != j:
165 |                 adjacency[(i, j)] = 1
166 | 
167 |     # Convert to numpy arrays
168 |     src = []
169 |     dst = []
170 |     w = []
171 |     for edge, weight in adjacency.items():
172 |         src.append(edge[0])
173 |         dst.append(edge[1])
174 |         w.append(weight)
175 | 
176 |     return np.array(src), np.array(dst), np.array(w)
177 | 
--------------------------------------------------------------------------------
/lightning_wrappers/mnist.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchmetrics
3 | import pytorch_lightning as pl
4 | 
5 | from .scheduler import CosineWarmupScheduler
6 | from ponita.models.ponita import Ponita
7 | from ponita.transforms.random_rotate import RandomRotate
8 | 
9 | 
10 | class PONITA_MNIST(pl.LightningModule):
11 |     """
12 |     Lightning wrapper for the PONITA model on the MNIST dataset.
13 | 
14 |     Args:
15 |         args (argparse.Namespace): The arguments from the command line.
16 | """ 17 | 18 | def __init__(self, args): 19 | super().__init__() 20 | 21 | # Store some of the relevant args 22 | self.lr = args.lr 23 | self.weight_decay = args.weight_decay 24 | self.epochs = args.epochs 25 | self.warmup = args.warmup 26 | if args.layer_scale == 0.: 27 | args.layer_scale = None 28 | 29 | # For rotation augmentations during training and testing 30 | self.train_augm = args.train_augm 31 | self.rotation_transform = RandomRotate(['pos'], n=2) 32 | 33 | # The metrics to log 34 | self.train_metric = torchmetrics.Accuracy(task='multiclass', num_classes=10) 35 | self.valid_metric = torchmetrics.Accuracy(task='multiclass', num_classes=10) 36 | self.test_metric = torchmetrics.Accuracy(task='multiclass', num_classes=10) 37 | 38 | # Input/output specifications: 39 | in_channels_scalar = 1 # gray value 40 | in_channels_vec = 0 # 41 | out_channels_scalar = 10 # The target 42 | out_channels_vec = 0 # 43 | 44 | # Make the model 45 | self.model = Ponita(in_channels_scalar + in_channels_vec, 46 | args.hidden_dim, 47 | out_channels_scalar, 48 | args.layers, 49 | output_dim_vec=out_channels_vec, 50 | radius=args.radius, 51 | num_ori=args.num_ori, 52 | basis_dim=args.basis_dim, 53 | degree=args.degree, 54 | widening_factor=args.widening_factor, 55 | layer_scale=args.layer_scale, 56 | task_level='graph', 57 | multiple_readouts=args.multiple_readouts, 58 | lift_graph=True) 59 | 60 | def forward(self, graph): 61 | # Only utilize the scalar (energy) prediction 62 | pred, _ = self.model(graph) 63 | return pred 64 | 65 | def training_step(self, graph): 66 | if self.train_augm: 67 | graph = self.rotation_transform(graph) 68 | pred = self(graph) 69 | pred = torch.nn.functional.log_softmax(pred, dim=-1) 70 | loss = torch.nn.functional.nll_loss(pred, graph.y) 71 | self.train_metric(pred, graph.y) 72 | return loss 73 | 74 | def on_train_epoch_end(self): 75 | self.log("train ACC", self.train_metric, prog_bar=True) 76 | 77 | def validation_step(self, graph, batch_idx): 78 | pred = self(graph) 79 | self.valid_metric(pred, graph.y) 80 | 81 | def on_validation_epoch_end(self): 82 | self.log("valid ACC", self.valid_metric, prog_bar=True) 83 | 84 | def test_step(self, graph, batch_idx): 85 | pred = self(graph) 86 | self.test_metric(pred, graph.y) 87 | 88 | def on_test_epoch_end(self): 89 | self.log("test ACC", self.test_metric, prog_bar=True) 90 | 91 | def configure_optimizers(self): 92 | """ 93 | Adapted from: https://github.com/karpathy/minGPT/blob/master/mingpt/model.py 94 | This long function is unfortunately doing something very simple and is being very defensive: 95 | We are separating out all parameters of the model into two buckets: those that will experience 96 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights). 97 | We are then returning the PyTorch optimizer object. 98 | """ 99 | 100 | # separate out all parameters to those that will and won't experience regularizing weight decay 101 | decay = set() 102 | no_decay = set() 103 | whitelist_weight_modules = (torch.nn.Linear, ) 104 | blacklist_weight_modules = (torch.nn.LazyBatchNorm1d, torch.nn.LayerNorm, torch.nn.Embedding) 105 | for mn, m in self.named_modules(): 106 | for pn, p in m.named_parameters(): 107 | fpn = '%s.%s' % (mn, pn) if mn else pn # full param name 108 | # random note: because named_modules and named_parameters are recursive 109 | # we will see the same tensors p many many times. but doing it this way 110 | # allows us to know which parent module any tensor p belongs to... 
111 | if pn.endswith('bias'): 112 | # all biases will not be decayed 113 | no_decay.add(fpn) 114 | elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): 115 | # weights of whitelist modules will be weight decayed 116 | decay.add(fpn) 117 | elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): 118 | # weights of blacklist modules will NOT be weight decayed 119 | no_decay.add(fpn) 120 | elif pn.endswith('layer_scale'): 121 | no_decay.add(fpn) 122 | 123 | # validate that we considered every parameter 124 | param_dict = {pn: p for pn, p in self.named_parameters()} 125 | inter_params = decay & no_decay 126 | union_params = decay | no_decay 127 | assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) 128 | assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ 129 | % (str(param_dict.keys() - union_params), ) 130 | 131 | # create the pytorch optimizer object 132 | optim_groups = [ 133 | {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": self.weight_decay}, 134 | {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, 135 | ] 136 | optimizer = torch.optim.Adam(optim_groups, lr=self.lr) 137 | scheduler = CosineWarmupScheduler(optimizer, self.warmup, self.trainer.max_epochs) 138 | return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"} 139 | -------------------------------------------------------------------------------- /main_qm9.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import torch 5 | from torch_geometric.datasets import QM9 6 | from torch_geometric.loader import DataLoader 7 | import pytorch_lightning as pl 8 | from lightning_wrappers.callbacks import EMA, EpochTimer 9 | from lightning_wrappers.qm9 import PONITA_QM9 10 | 11 | 12 | # TODO: do we need this? 13 | import torch.multiprocessing 14 | torch.multiprocessing.set_sharing_strategy('file_system') 15 | 16 | 17 | # ------------------------ Start of the main experiment script 18 | if __name__ == "__main__": 19 | parser = argparse.ArgumentParser() 20 | 21 | # ------------------------ Input arguments 22 | 23 | # Run parameters 24 | parser.add_argument('--epochs', type=int, default=1000, 25 | help='number of epochs') 26 | parser.add_argument('--warmup', type=int, default=10, 27 | help='number of epochs') 28 | parser.add_argument('--batch_size', type=int, default=96, 29 | help='Batch size. 
Does not scale with number of gpus.')
30 |     parser.add_argument('--lr', type=float, default=5e-4,
31 |                         help='learning rate')
32 |     parser.add_argument('--weight_decay', type=float, default=1e-10,
33 |                         help='weight decay')
34 |     parser.add_argument('--log', type=eval, default=True,
35 |                         help='logging flag')
36 |     parser.add_argument('--enable_progress_bar', type=eval, default=True,
37 |                         help='enable progress bar')
38 |     parser.add_argument('--num_workers', type=int, default=0,
39 |                         help='Num workers in dataloader')
40 |     parser.add_argument('--seed', type=int, default=0,
41 |                         help='Random seed')
42 | 
43 |     # Train settings
44 |     parser.add_argument('--train_augm', type=eval, default=True,
45 |                         help='whether or not to use random rotations during training')
46 | 
47 |     # Test settings
48 |     parser.add_argument('--repeats', type=int, default=5,
49 |                         help='number of repeated forward passes at test-time')
50 | 
51 |     # QM9 Dataset
52 |     parser.add_argument('--root', type=str, default="datasets/qm9",
53 |                         help='Data set location')
54 |     parser.add_argument('--target', type=str, default="alpha",
55 |                         help='QM9 target')
56 | 
57 |     # Graph connectivity settings
58 |     parser.add_argument('--radius', type=eval, default=1000.,
59 |                         help='radius for the radius graph construction')
60 |     parser.add_argument('--loop', type=eval, default=True,
61 |                         help='enable self interactions')
62 | 
63 |     # PONITA model settings
64 |     parser.add_argument('--num_ori', type=int, default=-1,
65 |                         help='num elements of spherical grid')
66 |     parser.add_argument('--hidden_dim', type=int, default=128,
67 |                         help='internal feature dimension')
68 |     parser.add_argument('--basis_dim', type=int, default=256,
69 |                         help='number of basis functions')
70 |     parser.add_argument('--degree', type=int, default=3,
71 |                         help='degree of the polynomial embedding')
72 |     parser.add_argument('--layers', type=int, default=5,
73 |                         help='Number of message passing layers')
74 |     parser.add_argument('--widening_factor', type=int, default=4,
75 |                         help='Widening factor in the ConvNext blocks')
76 |     parser.add_argument('--layer_scale', type=float, default=0,
77 |                         help='Initial layer scale factor in ConvNextBlock, 0 means do not use layer scale')
78 |     parser.add_argument('--multiple_readouts', type=eval, default=False,
79 |                         help='Whether or not to readout after every layer')
80 | 
81 |     # Parallel computing stuff
82 |     parser.add_argument('-g', '--gpus', default=1, type=int,
83 |                         help='number of gpus to use (assumes all are on one node)')
84 | 
85 |     # Arg parser
86 |     args = parser.parse_args()
87 | 
88 |     # ------------------------ Device settings
89 | 
90 |     if args.gpus > 0:
91 |         accelerator = "gpu"
92 |         devices = args.gpus
93 |     else:
94 |         accelerator = "cpu"
95 |         devices = "auto"
96 |     if args.num_workers == -1:
97 |         args.num_workers = os.cpu_count()
98 | 
99 |     # ------------------------ Dataset
100 | 
101 |     # Load the dataset and set the dataset specific settings
102 |     dataset = QM9(root=args.root)
103 | 
104 |     # Create train, val, test split (same random seed and splits as DimeNet)
105 |     random_state = np.random.RandomState(seed=42)
106 |     perm = torch.from_numpy(random_state.permutation(np.arange(130831)))
107 |     train_idx, val_idx, test_idx = perm[:110000], perm[110000:120000], perm[120000:]
108 |     datasets = {'train': dataset[train_idx], 'val': dataset[val_idx], 'test': dataset[test_idx]}
109 | 
110 |     # Select the right target
111 |     targets = ['mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0',
112 |                'U', 'H', 'G', 'Cv', 'U0_atom', 'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C']
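    # Editor's note: after the re-indexing below, y has 16 columns ordered as
    # targets[0:16]. Columns 7-10 ('U0', 'U', 'H', 'G') are taken from the
    # atomization-energy columns 12-15 of the raw data, so e.g. requesting
    # target 'U0' actually trains on U0_atom; 'A', 'B' and 'C' are dropped.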
113 |     idx = torch.tensor([0, 1, 2, 3, 4, 5, 6, 12, 13, 14, 15, 11, 12, 13, 14, 15])  # We will automatically replace U0 -> U0_atom etc.
114 |     dataset.data.y = dataset.data.y[:, idx]
115 |     dataset.data.y = dataset.data.y[:, targets.index(args.target)]
116 | 
117 |     # Make the dataloaders
118 |     dataloaders = {
119 |         split: DataLoader(dataset, batch_size=args.batch_size, shuffle=(split == 'train'), num_workers=args.num_workers)
120 |         for split, dataset in datasets.items()}
121 | 
122 |     # ------------------------ Load and initialize the model
123 |     model = PONITA_QM9(args)
124 |     model.set_dataset_statistics(datasets['train'])
125 | 
126 |     # ------------------------ Weights and Biases logger
127 |     if args.log:
128 |         logger = pl.loggers.WandbLogger(project="PONITA-QM9", name=args.target.replace(" ", "_"), config=args, save_dir='logs')
129 |     else:
130 |         logger = None
131 | 
132 |     # ------------------------ Set up the trainer
133 | 
134 |     # Seed
135 |     pl.seed_everything(args.seed, workers=True)
136 | 
137 |     # Pytorch lightning call backs
138 |     callbacks = [EMA(0.99),
139 |                  pl.callbacks.ModelCheckpoint(monitor='valid MAE', mode = 'min'),
140 |                  EpochTimer()]
141 |     if args.log: callbacks.append(pl.callbacks.LearningRateMonitor(logging_interval='epoch'))
142 | 
143 |     # Initialize the trainer
144 |     trainer = pl.Trainer(logger=logger, max_epochs=args.epochs, callbacks=callbacks, inference_mode=False,
145 |                          gradient_clip_val=0.5, accelerator=accelerator, devices=devices, enable_progress_bar=args.enable_progress_bar)
146 | 
147 |     # Do the training
148 |     trainer.fit(model, dataloaders['train'], dataloaders['val'])
149 | 
150 |     # And test
151 |     trainer.test(model, dataloaders['test'], ckpt_path = "best")
152 | 
--------------------------------------------------------------------------------
/lightning_wrappers/qm9.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torchmetrics
6 | import pytorch_lightning as pl
7 | 
8 | from .scheduler import CosineWarmupScheduler
9 | from ponita.models.ponita import Ponita
10 | from ponita.transforms.random_rotate import RandomRotate
11 | 
12 | 
13 | class PONITA_QM9(pl.LightningModule):
14 |     """Lightning wrapper for the PONITA model on the QM9 dataset.
15 |     """
16 | 
17 |     def __init__(self, args):
18 |         super().__init__()
19 | 
20 |         # Store some of the relevant args
21 |         self.repeats = args.repeats
22 |         self.lr = args.lr
23 |         self.weight_decay = args.weight_decay
24 |         self.epochs = args.epochs
25 |         self.warmup = args.warmup
26 |         if args.layer_scale == 0.:
27 |             args.layer_scale = None
28 | 
29 |         # For rotation augmentations during training and testing
30 |         self.train_augm = args.train_augm
31 |         self.rotation_transform = RandomRotate(['pos'], n=3)
32 | 
33 |         # Shift and scale before calibration
34 |         self.shift = 0.
35 |         self.scale = 1.
36 | 37 | # The metrics to log 38 | self.train_metric = torchmetrics.MeanAbsoluteError() 39 | self.valid_metric = torchmetrics.MeanAbsoluteError() 40 | self.test_metric = torchmetrics.MeanAbsoluteError() 41 | self.test_metrics = nn.ModuleList([torchmetrics.MeanAbsoluteError() for r in range(self.repeats)]) 42 | 43 | # Input/output specifications: 44 | in_channels_scalar = 11 # One-hot encoding 45 | in_channels_vec = 0 # 46 | out_channels_scalar = 1 # The target 47 | out_channels_vec = 0 # 48 | 49 | # Make the model 50 | self.model = Ponita(in_channels_scalar + in_channels_vec, 51 | args.hidden_dim, 52 | out_channels_scalar, 53 | args.layers, 54 | output_dim_vec=out_channels_vec, 55 | radius=args.radius, 56 | num_ori=args.num_ori, 57 | basis_dim=args.basis_dim, 58 | degree=args.degree, 59 | widening_factor=args.widening_factor, 60 | layer_scale=args.layer_scale, 61 | task_level='graph', 62 | multiple_readouts=args.multiple_readouts, 63 | lift_graph=True) 64 | 65 | def set_dataset_statistics(self, dataloader): 66 | print('Computing dataset statistics...') 67 | ys = [] 68 | for data in dataloader: 69 | ys.append(data.y) 70 | ys = np.concatenate(ys) 71 | self.shift = np.mean(ys) 72 | self.scale = np.std(ys) 73 | print('Mean and std of target are:', self.shift, '-', self.scale) 74 | 75 | def forward(self, graph): 76 | # Only utilize the scalar (energy) prediction 77 | pred, _ = self.model(graph) 78 | return pred.squeeze(-1) 79 | 80 | def training_step(self, graph): 81 | if self.train_augm: 82 | graph = self.rotation_transform(graph) 83 | pred = self(graph) 84 | # loss = torch.mean((pred - (graph.y - self.shift) / self.scale)**2) 85 | loss = torch.mean((pred - (graph.y - self.shift) / self.scale).abs()) 86 | self.train_metric(pred * self.scale + self.shift, graph.y) 87 | return loss 88 | 89 | def on_train_epoch_end(self): 90 | self.log("train MAE", self.train_metric, prog_bar=True) 91 | 92 | def validation_step(self, graph, batch_idx): 93 | pred = self(graph) 94 | self.valid_metric(pred * self.scale + self.shift, graph.y) 95 | 96 | def on_validation_epoch_end(self): 97 | self.log("valid MAE", self.valid_metric, prog_bar=True) 98 | 99 | def test_step(self, graph, batch_idx): 100 | pred = self(graph) 101 | self.test_metric(pred * self.scale + self.shift, graph.y) 102 | 103 | def on_test_epoch_end(self): 104 | self.log("test MAE", self.test_metric, prog_bar=True) 105 | 106 | def configure_optimizers(self): 107 | """ 108 | Adapted from: https://github.com/karpathy/minGPT/blob/master/mingpt/model.py 109 | This long function is unfortunately doing something very simple and is being very defensive: 110 | We are separating out all parameters of the model into two buckets: those that will experience 111 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights). 112 | We are then returning the PyTorch optimizer object. 113 | """ 114 | 115 | # separate out all parameters to those that will and won't experience regularizing weight decay 116 | decay = set() 117 | no_decay = set() 118 | whitelist_weight_modules = (torch.nn.Linear, ) 119 | blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) 120 | for mn, m in self.named_modules(): 121 | for pn, p in m.named_parameters(): 122 | fpn = '%s.%s' % (mn, pn) if mn else pn # full param name 123 | # random note: because named_modules and named_parameters are recursive 124 | # we will see the same tensors p many many times. 
but doing it this way 125 | # allows us to know which parent module any tensor p belongs to... 126 | if pn.endswith('bias'): 127 | # all biases will not be decayed 128 | no_decay.add(fpn) 129 | elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): 130 | # weights of whitelist modules will be weight decayed 131 | decay.add(fpn) 132 | elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): 133 | # weights of blacklist modules will NOT be weight decayed 134 | no_decay.add(fpn) 135 | elif pn.endswith('layer_scale'): 136 | no_decay.add(fpn) 137 | 138 | # validate that we considered every parameter 139 | param_dict = {pn: p for pn, p in self.named_parameters()} 140 | inter_params = decay & no_decay 141 | union_params = decay | no_decay 142 | assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) 143 | assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ 144 | % (str(param_dict.keys() - union_params), ) 145 | 146 | # create the pytorch optimizer object 147 | optim_groups = [ 148 | {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": self.weight_decay}, 149 | {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, 150 | ] 151 | optimizer = torch.optim.Adam(optim_groups, lr=self.lr) 152 | scheduler = CosineWarmupScheduler(optimizer, self.warmup, self.trainer.max_epochs) 153 | return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"} 154 | -------------------------------------------------------------------------------- /main_nbody.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from n_body_system.dataset_nbody import NBodyDataset 4 | from torch_geometric.loader import DataLoader 5 | from torch_geometric.data import Data 6 | from torch_geometric.transforms import RadiusGraph 7 | import pytorch_lightning as pl 8 | from lightning_wrappers.callbacks import EpochTimer, EMA 9 | from lightning_wrappers.nbody import PONITA_NBODY 10 | 11 | 12 | # ------------------------ Function to convert the nbody dataset to a dataloader for pytorch geometric graphs 13 | 14 | def make_pyg_loader(dataset, batch_size, shuffle, num_workers, radius, loop): 15 | data_list = [] 16 | radius = radius or 1000. 17 | radius_graph = RadiusGraph(radius, loop=loop, max_num_neighbors=1000) 18 | for data in dataset: 19 | loc, vel, edge_attr, charges, loc_end = data 20 | x = charges 21 | vec = vel[:,None,:] # [num_pts, num_channels=1, 3] 22 | # Build the graph 23 | graph = Data(pos=loc, x=x, vec=vec, y=loc_end) 24 | graph = radius_graph(graph) 25 | # Append to the database list 26 | data_list.append(graph) 27 | return DataLoader(data_list, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) 28 | 29 | 30 | # ------------------------ Start of the main experiment script 31 | 32 | if __name__ == "__main__": 33 | parser = argparse.ArgumentParser() 34 | 35 | # ------------------------ Input arguments 36 | 37 | # Run parameters 38 | parser.add_argument('--epochs', type=int, default=10000, 39 | help='number of epochs') 40 | parser.add_argument('--warmup', type=int, default=10, 41 | help='number of epochs') 42 | parser.add_argument('--batch_size', type=int, default=100, 43 | help='Batch size. 
Does not scale with number of gpus.')
44 |     parser.add_argument('--lr', type=float, default=1e-3,
45 |                         help='learning rate')
46 |     parser.add_argument('--weight_decay', type=float, default=1e-10,
47 |                         help='weight decay')
48 |     parser.add_argument('--log', type=eval, default=True,
49 |                         help='logging flag')
50 |     parser.add_argument('--enable_progress_bar', type=eval, default=False,
51 |                         help='enable progress bar')
52 |     parser.add_argument('--num_workers', type=int, default=0,
53 |                         help='Num workers in dataloader')
54 |     parser.add_argument('--seed', type=int, default=0,
55 |                         help='Random seed')
56 |     parser.add_argument('--val_interval', type=int, default=5, metavar='N',
57 |                         help='how many epochs to wait before logging validation')
58 | 
59 |     # Train settings
60 |     parser.add_argument('--train_augm', type=eval, default=True,
61 |                         help='whether or not to use random rotations during training')
62 | 
63 |     # nbody Dataset
64 |     parser.add_argument('--max_training_samples', type=int, default=3000, metavar='N',
65 |                         help='maximum amount of training samples')
66 |     parser.add_argument('--dataset', type=str, default="nbody_small", metavar='N',
67 |                         help='nbody_small, nbody')
68 | 
69 |     # Graph connectivity settings
70 |     parser.add_argument('--radius', type=eval, default=None,
71 |                         help='radius for the radius graph construction')
72 |     parser.add_argument('--loop', type=eval, default=True,
73 |                         help='enable self interactions')
74 | 
75 |     # PONITA model settings
76 |     parser.add_argument('--num_ori', type=int, default=16,
77 |                         help='num elements of spherical grid')
78 |     parser.add_argument('--hidden_dim', type=int, default=128,
79 |                         help='internal feature dimension')
80 |     parser.add_argument('--basis_dim', type=int, default=256,
81 |                         help='number of basis functions')
82 |     parser.add_argument('--degree', type=int, default=3,
83 |                         help='degree of the polynomial embedding')
84 |     parser.add_argument('--layers', type=int, default=5,
85 |                         help='Number of message passing layers')
86 |     parser.add_argument('--widening_factor', type=int, default=4,
87 |                         help='Widening factor in the ConvNext blocks')
88 |     parser.add_argument('--layer_scale', type=float, default=1e-6,
89 |                         help='Initial layer scale factor in ConvNextBlock, 0 means do not use layer scale')
90 |     parser.add_argument('--multiple_readouts', type=eval, default=True,
91 |                         help='Whether or not to readout after every layer')
92 | 
93 |     # Parallel computing stuff
94 |     parser.add_argument('-g', '--gpus', default=1, type=int,
95 |                         help='number of gpus to use (assumes all are on one node)')
96 | 
97 |     # Arg parser
98 |     args = parser.parse_args()
99 | 
100 |     # ------------------------ Device settings
101 | 
102 |     if args.gpus > 0:
103 |         accelerator = "gpu"
104 |         devices = args.gpus
105 |     else:
106 |         accelerator = "cpu"
107 |         devices = "auto"
108 |     if args.num_workers == -1:
109 |         args.num_workers = os.cpu_count()
110 | 
111 |     # ------------------------ Dataset
112 | 
113 |     # Load the dataset and set the dataset specific settings
114 |     dataset_train = NBodyDataset(partition='train', dataset_name=args.dataset,
115 |                                  max_samples=args.max_training_samples)
116 |     dataset_val = NBodyDataset(partition='val', dataset_name="nbody_small")
117 |     dataset_test = NBodyDataset(partition='test', dataset_name="nbody_small")
118 | 
119 |     # Make the dataloaders
120 |     datasets = {'train': dataset_train, 'valid': dataset_val, 'test': dataset_test}
121 |     dataloaders = {
122 |         split: make_pyg_loader(dataset,
123 |                                batch_size=args.batch_size,
124 |                                shuffle=(split == 'train'),
125 |                                num_workers=args.num_workers,
126 |                                radius=args.radius,
127 |                                loop=args.loop)
128 |         for split, dataset in datasets.items()}
129 | 
130 |     # ------------------------ Load and initialize the model
131 | 
132 |     model = PONITA_NBODY(args)
133 | 
134 |     # ------------------------ Weights and Biases logger
135 | 
136 |     if args.log:
137 |         logger = pl.loggers.WandbLogger(project="PONITA-" + args.dataset, name='siva', config=args, save_dir='logs')
138 |     else:
139 |         logger = None
140 | 
141 |     # ------------------------ Set up the trainer
142 | 
143 |     # Seed
144 |     pl.seed_everything(args.seed, workers=True)
145 | 
146 |     # Pytorch lightning call backs
147 |     callbacks = [pl.callbacks.ModelCheckpoint(monitor='valid MSE', mode = 'min'),
148 |                  EpochTimer()]
149 |     if args.log: callbacks.append(pl.callbacks.LearningRateMonitor(logging_interval='epoch'))
150 | 
151 |     # Initialize the trainer
152 |     trainer = pl.Trainer(logger=logger, max_epochs=args.epochs, callbacks=callbacks, gradient_clip_val=0.5,
153 |                          accelerator=accelerator, devices=devices, check_val_every_n_epoch=args.val_interval,
154 |                          enable_progress_bar=args.enable_progress_bar)
155 | 
156 |     # Do the training
157 |     trainer.fit(model, dataloaders['train'], dataloaders['valid'])
158 | 
159 |     # And test
160 |     trainer.test(model, dataloaders['test'], ckpt_path = "best")
161 | 
--------------------------------------------------------------------------------
/main_mnist.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import numpy as np
4 | import torch
5 | from torch_geometric.datasets import MNISTSuperpixels
6 | from torch_geometric.loader import DataLoader
7 | import pytorch_lightning as pl
8 | from lightning_wrappers.callbacks import EMA, EpochTimer
9 | from lightning_wrappers.mnist import PONITA_MNIST
10 | from torch_geometric.transforms import BaseTransform, KNNGraph, Compose
11 | 
12 | # TODO: do we need this?
13 | import torch.multiprocessing
14 | torch.multiprocessing.set_sharing_strategy('file_system')
15 | 
16 | 
17 | class Sparsify(BaseTransform):
18 |     def __init__(self, threshold=0.5):
19 |         super().__init__()
20 |         self.threshold = threshold
21 | 
22 |     def __call__(self, graph):
23 |         select = graph.x[:,0] > self.threshold
24 |         graph.x = graph.x[select]
25 |         graph.pos = graph.pos[select]
26 |         if graph.batch is not None:
27 |             graph.batch = graph.batch[select]
28 |         graph.edge_index = None
29 |         return graph
30 | 
31 | 
32 | class RemoveDuplicatePoints(BaseTransform):
33 |     def __init__(self):
34 |         super().__init__()
35 | 
36 |     def __call__(self, graph):
37 |         dists = (graph.pos[:,None,:] - graph.pos[None,:,:]).norm(dim=-1)
38 |         dists = dists + 100. * torch.tril(torch.ones_like(dists), diagonal=0)
39 |         min_dists = dists.min(dim=1)[0]
40 |         select = min_dists > 0.
41 |         graph.x = graph.x[select]
42 |         graph.pos = graph.pos[select]
43 |         graph.edge_index = None
44 |         return graph
45 | 
46 | 
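# Editor's sketch (illustrative, not in the original file): both transforms
# operate on PyG Data objects. Sparsify keeps only superpixels whose first
# feature channel exceeds the threshold and invalidates the cached edges, so it
# is typically composed with a graph constructor, e.g.
#
#   transform = Compose([Sparsify(0.5), KNNGraph(k=4, loop=False)])
#   sparse_graph = transform(graph)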
47 | # ------------------------ Start of the main experiment script
48 | if __name__ == "__main__":
49 |     parser = argparse.ArgumentParser()
50 | 
51 |     # ------------------------ Input arguments
52 | 
53 |     # Run parameters
54 |     parser.add_argument('--epochs', type=int, default=50,
55 |                         help='number of epochs')
56 |     parser.add_argument('--warmup', type=int, default=0,
57 |                         help='number of warmup epochs')
58 |     parser.add_argument('--batch_size', type=int, default=96,
59 |                         help='Batch size. Does not scale with number of gpus.')
60 |     parser.add_argument('--lr', type=float, default=5e-4,
61 |                         help='learning rate')
62 |     parser.add_argument('--weight_decay', type=float, default=1e-10,
63 |                         help='weight decay')
64 |     parser.add_argument('--log', type=eval, default=True,
65 |                         help='logging flag')
66 |     parser.add_argument('--enable_progress_bar', type=eval, default=True,
67 |                         help='enable progress bar')
68 |     parser.add_argument('--num_workers', type=int, default=0,
69 |                         help='Num workers in dataloader')
70 |     parser.add_argument('--seed', type=int, default=0,
71 |                         help='Random seed')
72 | 
73 |     # Train settings
74 |     parser.add_argument('--train_augm', type=eval, default=True,
75 |                         help='whether or not to use random rotations during training')
76 | 
77 |     # MNIST Dataset
78 |     parser.add_argument('--root', type=str, default="datasets/mnist",
79 |                         help='Data set location')
80 | 
81 |     # Graph connectivity settings
82 |     parser.add_argument('--radius', type=eval, default=None,
83 |                         help='radius for the radius graph construction')
84 |     parser.add_argument('--loop', type=eval, default=True,
85 |                         help='enable self interactions')
86 | 
87 |     # PONITA model settings
88 |     parser.add_argument('--num_ori', type=int, default=10,
89 |                         help='num elements of spherical grid')
90 |     parser.add_argument('--hidden_dim', type=int, default=128,
91 |                         help='internal feature dimension')
92 |     parser.add_argument('--basis_dim', type=int, default=256,
93 |                         help='number of basis functions')
94 |     parser.add_argument('--degree', type=int, default=3,
95 |                         help='degree of the polynomial embedding')
96 |     parser.add_argument('--layers', type=int, default=5,
97 |                         help='Number of message passing layers')
98 |     parser.add_argument('--widening_factor', type=int, default=4,
99 |                         help='Widening factor in the ConvNext blocks')
100 |     parser.add_argument('--layer_scale', type=float, default=0,
101 |                         help='Initial layer scale factor in ConvNextBlock, 0 means do not use layer scale')
102 |     parser.add_argument('--multiple_readouts', type=eval, default=False,
103 |                         help='Whether or not to readout after every layer')
104 | 
105 |     # Parallel computing stuff
106 |     parser.add_argument('-g', '--gpus', default=1, type=int,
107 |                         help='number of gpus to use (assumes all are on one node)')
108 | 
109 |     # Arg parser
110 |     args = parser.parse_args()
111 | 
112 |     # ------------------------ Device settings
113 | 
114 |     if args.gpus > 0:
115 |         accelerator = "gpu"
116 |         devices = args.gpus
117 |     else:
118 |         accelerator = "cpu"
119 |         devices = "auto"
120 |     if args.num_workers == -1:
121 |         args.num_workers = os.cpu_count()
122 | 
123 |     # ------------------------ Dataset
124 | 
125 |     # Load the dataset and set the dataset specific settings
126 |     # transform = Compose([RemoveDuplicatePoints(), KNNGraph(k=4, loop=False)])
127 |     transform = None
128 |     dataset_train = MNISTSuperpixels(root=args.root, train=True, transform=transform)
129 |     dataset_test = MNISTSuperpixels(root=args.root, train=False, transform=transform)
130 | 
131 |     # Create train, val, test splits
132 |     train_size = int(0.9 * len(dataset_train))
133 |     val_size = len(dataset_train) - train_size
134 |     dataset_train, dataset_val = torch.utils.data.random_split(dataset_train, [train_size, val_size])
135 |     datasets = {'train': dataset_train, 'val': dataset_val, 'test': dataset_test}
136 | 
137 | 
138 | 
139 |     # Make the dataloaders
140 |     dataloaders = {
141 |         split: DataLoader(dataset, batch_size=args.batch_size, shuffle=(split == 'train'), num_workers=args.num_workers)
142 |         for split, dataset in datasets.items()}
num_workers=args.num_workers) 142 | for split, dataset in datasets.items()} 143 | 144 | # ------------------------ Load and initialize the model 145 | model = PONITA_MNIST(args) 146 | 147 | # ------------------------ Weights and Biases logger 148 | if args.log: 149 | logger = pl.loggers.WandbLogger(project="PONITA-MNIST", name=None, config=args, save_dir='logs') 150 | else: 151 | logger = None 152 | 153 | # ------------------------ Set up the trainer 154 | 155 | # Seed 156 | pl.seed_everything(args.seed, workers=True) 157 | 158 | # Pytorch lightning call backs 159 | callbacks = [EMA(0.99), 160 | pl.callbacks.ModelCheckpoint(monitor='valid ACC', mode = 'max'), 161 | EpochTimer()] 162 | if args.log: callbacks.append(pl.callbacks.LearningRateMonitor(logging_interval='epoch')) 163 | 164 | # Initialize the trainer 165 | trainer = pl.Trainer(logger=logger, max_epochs=args.epochs, callbacks=callbacks, inference_mode=False, # Important for force computation via backprop 166 | gradient_clip_val=0.5, accelerator=accelerator, devices=devices, enable_progress_bar=args.enable_progress_bar) 167 | 168 | # Do the training 169 | trainer.fit(model, dataloaders['train'], dataloaders['val']) 170 | 171 | # And test 172 | trainer.test(model, dataloaders['test'], ckpt_path = "best") 173 | -------------------------------------------------------------------------------- /n_body_system/se3_dynamics/equivariant_attention/from_se3cnn/representations.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | import numpy as np 5 | from scipy.special import lpmv as lpmv_scipy 6 | 7 | 8 | def semifactorial(x): 9 | """Compute the semifactorial function x!!. 10 | 11 | x!! = x * (x-2) * (x-4) *... 12 | 13 | Args: 14 | x: positive int 15 | Returns: 16 | float for x!! 17 | """ 18 | y = 1. 19 | for n in range(x, 1, -2): 20 | y *= n 21 | return y 22 | 23 | 24 | def pochhammer(x, k): 25 | """Compute the pochhammer symbol (x)_k. 26 | 27 | (x)_k = x * (x+1) * (x+2) *...* (x+k-1) 28 | 29 | Args: 30 | x: positive int 31 | Returns: 32 | float for (x)_k 33 | """ 34 | xf = float(x) 35 | for n in range(x+1, x+k): 36 | xf *= n 37 | return xf 38 | 39 | def lpmv(l, m, x): 40 | """Associated Legendre function including Condon-Shortley phase. 41 | 42 | Args: 43 | m: int order 44 | l: int degree 45 | x: float argument tensor 46 | Returns: 47 | tensor of x-shape 48 | """ 49 | m_abs = abs(m) 50 | if m_abs > l: 51 | return torch.zeros_like(x) 52 | 53 | # Compute P_m^m 54 | yold = ((-1)**m_abs * semifactorial(2*m_abs-1)) * torch.pow(1-x*x, m_abs/2) 55 | 56 | # Compute P_{m+1}^m 57 | if m_abs != l: 58 | y = x * (2*m_abs+1) * yold 59 | else: 60 | y = yold 61 | 62 | # Compute P_{l}^m from recursion in P_{l-1}^m and P_{l-2}^m 63 | for i in range(m_abs+2, l+1): 64 | tmp = y 65 | # Inplace speedup 66 | y = ((2*i-1) / (i-m_abs)) * x * y 67 | y -= ((i+m_abs-1)/(i-m_abs)) * yold 68 | yold = tmp 69 | 70 | if m < 0: 71 | y *= ((-1)**m / pochhammer(l+m+1, -2*m)) 72 | 73 | return y 74 | 75 | def tesseral_harmonics(l, m, theta=0., phi=0.): 76 | """Tesseral spherical harmonic with Condon-Shortley phase. 77 | 78 | The Tesseral spherical harmonics are also known as the real spherical 79 | harmonics. 
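    Concretely, up to the normalization applied below, the value is built from
    the associated Legendre function P_l^{|m|}(cos(theta)) multiplied by
    cos(m*phi) for m > 0, by 1 for m = 0, and by sin(|m|*phi) for m < 0.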
80 | 81 | Args: 82 | l: int for degree 83 | m: int for order, where -l <= m < l 84 | theta: collatitude or polar angle 85 | phi: longitude or azimuth 86 | Returns: 87 | tensor of shape theta 88 | """ 89 | assert abs(m) <= l, "absolute value of order m must be <= degree l" 90 | 91 | N = np.sqrt((2*l+1) / (4*np.pi)) 92 | leg = lpmv(l, abs(m), torch.cos(theta)) 93 | if m == 0: 94 | return N*leg 95 | elif m > 0: 96 | Y = torch.cos(m*phi) * leg 97 | else: 98 | Y = torch.sin(abs(m)*phi) * leg 99 | N *= np.sqrt(2. / pochhammer(l-abs(m)+1, 2*abs(m))) 100 | Y *= N 101 | return Y 102 | 103 | class SphericalHarmonics(object): 104 | def __init__(self): 105 | self.leg = {} 106 | 107 | def clear(self): 108 | self.leg = {} 109 | 110 | def negative_lpmv(self, l, m, y): 111 | """Compute negative order coefficients""" 112 | if m < 0: 113 | y *= ((-1)**m / pochhammer(l+m+1, -2*m)) 114 | return y 115 | 116 | def lpmv(self, l, m, x): 117 | """Associated Legendre function including Condon-Shortley phase. 118 | 119 | Args: 120 | m: int order 121 | l: int degree 122 | x: float argument tensor 123 | Returns: 124 | tensor of x-shape 125 | """ 126 | # Check memoized versions 127 | m_abs = abs(m) 128 | if (l,m) in self.leg: 129 | return self.leg[(l,m)] 130 | elif m_abs > l: 131 | return None 132 | elif l == 0: 133 | self.leg[(l,m)] = torch.ones_like(x) 134 | return self.leg[(l,m)] 135 | 136 | # Check if on boundary else recurse solution down to boundary 137 | if m_abs == l: 138 | # Compute P_m^m 139 | y = (-1)**m_abs * semifactorial(2*m_abs-1) 140 | y *= torch.pow(1-x*x, m_abs/2) 141 | self.leg[(l,m)] = self.negative_lpmv(l, m, y) 142 | return self.leg[(l,m)] 143 | else: 144 | # Recursively precompute lower degree harmonics 145 | self.lpmv(l-1, m, x) 146 | 147 | # Compute P_{l}^m from recursion in P_{l-1}^m and P_{l-2}^m 148 | # Inplace speedup 149 | y = ((2*l-1) / (l-m_abs)) * x * self.lpmv(l-1, m_abs, x) 150 | if l - m_abs > 1: 151 | y -= ((l+m_abs-1)/(l-m_abs)) * self.leg[(l-2, m_abs)] 152 | #self.leg[(l, m_abs)] = y 153 | 154 | if m < 0: 155 | y = self.negative_lpmv(l, m, y) 156 | self.leg[(l,m)] = y 157 | 158 | return self.leg[(l,m)] 159 | 160 | def get_element(self, l, m, theta, phi): 161 | """Tesseral spherical harmonic with Condon-Shortley phase. 162 | 163 | The Tesseral spherical harmonics are also known as the real spherical 164 | harmonics. 165 | 166 | Args: 167 | l: int for degree 168 | m: int for order, where -l <= m < l 169 | theta: collatitude or polar angle 170 | phi: longitude or azimuth 171 | Returns: 172 | tensor of shape theta 173 | """ 174 | assert abs(m) <= l, "absolute value of order m must be <= degree l" 175 | 176 | N = np.sqrt((2*l+1) / (4*np.pi)) 177 | leg = self.lpmv(l, abs(m), torch.cos(theta)) 178 | if m == 0: 179 | return N*leg 180 | elif m > 0: 181 | Y = torch.cos(m*phi) * leg 182 | else: 183 | Y = torch.sin(abs(m)*phi) * leg 184 | N *= np.sqrt(2. / pochhammer(l-abs(m)+1, 2*abs(m))) 185 | Y *= N 186 | return Y 187 | 188 | def get(self, l, theta, phi, refresh=True): 189 | """Tesseral harmonic with Condon-Shortley phase. 190 | 191 | The Tesseral spherical harmonics are also known as the real spherical 192 | harmonics. 
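        This evaluates all orders m = -l, ..., l and stacks them along a new
        last dimension; intermediate associated Legendre values are memoized
        in self.leg and reused across orders (cleared first when refresh=True).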
193 | 194 | Args: 195 | l: int for degree 196 | theta: collatitude or polar angle 197 | phi: longitude or azimuth 198 | Returns: 199 | tensor of shape [*theta.shape, 2*l+1] 200 | """ 201 | results = [] 202 | if refresh: 203 | self.clear() 204 | for m in range(-l, l+1): 205 | results.append(self.get_element(l, m, theta, phi)) 206 | return torch.stack(results, -1) 207 | 208 | 209 | 210 | 211 | if __name__ == "__main__": 212 | from lie_learn.representations.SO3.spherical_harmonics import sh 213 | device = 'cuda' 214 | dtype = torch.float64 215 | bs = 32 216 | theta = 0.1*torch.randn(bs,1024,10, dtype=dtype) 217 | phi = 0.1*torch.randn(bs,1024,10, dtype=dtype) 218 | cu_theta = theta.to(device) 219 | cu_phi = phi.to(device) 220 | s0 = s1 = s2 = 0 221 | max_error = -1. 222 | 223 | sph_har = SphericalHarmonics() 224 | for l in range(10): 225 | for m in range(l, -l-1, -1): 226 | start = time.time() 227 | #y = tesseral_harmonics(l, m, theta, phi) 228 | y = sph_har.get_element(l, m, cu_theta, cu_phi).type(torch.float32) 229 | #y = sph_har.lpmv(l, m, phi) 230 | s0 += time.time() - start 231 | start = time.time() 232 | z = sh(l, m, theta, phi) 233 | #z = lpmv_scipy(m, l, phi).numpy() 234 | s1 += time.time() - start 235 | 236 | error = np.mean(np.abs((y.cpu().numpy() - z) / z)) 237 | max_error = max(max_error, error) 238 | print(f"l: {l}, m: {m} ", error) 239 | 240 | #start = time.time() 241 | #sph_har.get(l, theta, phi) 242 | #s2 += time.time() - start 243 | 244 | print('#################') 245 | 246 | print(f"Max error: {max_error}") 247 | print(f"Time diff: {s0/s1}") 248 | print(f"Total time: {s0}") 249 | #print(f"Time diff: {s2/s1}") 250 | -------------------------------------------------------------------------------- /main_md17.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | from torch_geometric.datasets import MD17 5 | from torch_geometric.loader import DataLoader 6 | from torch_geometric.transforms import BaseTransform, Compose, RadiusGraph 7 | import pytorch_lightning as pl 8 | from lightning_wrappers.callbacks import EMA, EpochTimer 9 | from lightning_wrappers.md17 import PONITA_MD17 10 | 11 | 12 | # ------------------------ Some transforms specific to the rMD17 tasks 13 | # One-hot encoding of atom type 14 | class OneHotTransform(BaseTransform): 15 | def __init__(self, k=None): 16 | super().__init__() 17 | self.k = k 18 | 19 | def __call__(self, graph): 20 | if self.k is None: 21 | graph.x = torch.nn.functional.one_hot(graph.z).float() 22 | else: 23 | graph.x = torch.nn.functional.one_hot(graph.z, self.k).squeeze().float() 24 | 25 | return graph 26 | # Unit conversion 27 | class Kcal2meV(BaseTransform): 28 | def __init__(self): 29 | # Kcal/mol to meV 30 | self.conversion = 43.3634 31 | 32 | def __call__(self, graph): 33 | graph.energy = graph.energy * self.conversion 34 | graph.force = graph.force * self.conversion 35 | return graph 36 | 37 | 38 | # ------------------------ Start of the main experiment script 39 | if __name__ == "__main__": 40 | parser = argparse.ArgumentParser() 41 | 42 | # ------------------------ Input arguments 43 | 44 | # Run parameters 45 | parser.add_argument('--epochs', type=int, default=5000, 46 | help='number of epochs') 47 | parser.add_argument('--warmup', type=int, default=100, 48 | help='number of epochs') 49 | parser.add_argument('--batch_size', type=int, default=5, 50 | help='Batch size. 
Does not scale with number of gpus.') 51 | parser.add_argument('--lr', type=float, default=5e-4, 52 | help='learning rate') 53 | parser.add_argument('--weight_decay', type=float, default=1e-16, 54 | help='weight decay') 55 | parser.add_argument('--log', type=eval, default=True, 56 | help='logging flag') 57 | parser.add_argument('--enable_progress_bar', type=eval, default=False, 58 | help='enable progress bar') 59 | parser.add_argument('--num_workers', type=int, default=0, 60 | help='Num workers in dataloader') 61 | parser.add_argument('--seed', type=int, default=0, 62 | help='Random seed') 63 | 64 | # Train settings 65 | parser.add_argument('--train_augm', type=eval, default=True, 66 | help='whether or not to use random rotations during training') 67 | parser.add_argument('--lambda_F', type=float, default=500.0, 68 | help='relative weight of the force loss (the energy loss is divided by this factor)') 69 | 70 | # Test settings 71 | parser.add_argument('--repeats', type=int, default=5, 72 | help='number of repeated forward passes at test-time') 73 | 74 | # MD17 Dataset 75 | parser.add_argument('--root', type=str, default="datasets", 76 | help='Data set location') 77 | parser.add_argument('--target', type=str, default="revised aspirin", 78 | help='MD17 target') 79 | 80 | # Graph connectivity settings 81 | parser.add_argument('--radius', type=eval, default=None, 82 | help='radius for the radius graph construction') 83 | parser.add_argument('--loop', type=eval, default=True, 84 | help='enable self interactions') 85 | 86 | # PONITA model settings 87 | parser.add_argument('--num_ori', type=int, default=20, 88 | help='num elements of spherical grid') 89 | parser.add_argument('--hidden_dim', type=int, default=128, 90 | help='internal feature dimension') 91 | parser.add_argument('--basis_dim', type=int, default=256, 92 | help='number of basis functions') 93 | parser.add_argument('--degree', type=int, default=3, 94 | help='degree of the polynomial embedding') 95 | parser.add_argument('--layers', type=int, default=5, 96 | help='Number of message passing layers') 97 | parser.add_argument('--widening_factor', type=int, default=4, 98 | help='Widening factor in the ConvNext blocks') 99 | parser.add_argument('--layer_scale', type=float, default=0, 100 | help='Initial layer scale factor in ConvNextBlock, 0 means do not use layer scale') 101 | parser.add_argument('--multiple_readouts', type=eval, default=True, 102 | help='Whether or not to readout after every layer') 103 | 104 | # Parallel computing stuff 105 | parser.add_argument('-g', '--gpus', default=1, type=int, 106 | help='number of gpus to use (assumes all are on one node)') 107 | 108 | # Arg parser 109 | args = parser.parse_args() 110 | 111 | # ------------------------ Device settings 112 | 113 | if args.gpus > 0: 114 | accelerator = "gpu" 115 | devices = args.gpus 116 | else: 117 | accelerator = "cpu" 118 | devices = "auto" 119 | if args.num_workers == -1: 120 | args.num_workers = os.cpu_count() 121 | 122 | # ------------------------ Dataset 123 | 124 | # Load the dataset and set the dataset specific settings 125 | transform = [Kcal2meV(), OneHotTransform(9), RadiusGraph((args.radius or 1000.), loop=args.loop, max_num_neighbors=1000)] 126 | dataset = MD17(root=args.root, name=args.target, transform=Compose(transform)) 127 | 128 | # Create train, val, test split 129 | test_idx = list(range(min(len(dataset),100000))) # The whole dataset consists of 100,000 samples 130 | train_idx = test_idx[::100] # Select every 100th sample for training 131 | del test_idx[::100]
# and remove these from the test set 132 | val_idx = train_idx[::20] # Select every 20th sample from the train set for validation 133 | del train_idx[::20] # and remove these from the train set 134 | 135 | # Dataset and loaders 136 | datasets = {'train': dataset[train_idx], 'val': dataset[val_idx], 'test': dataset[test_idx]} 137 | dataloaders = { 138 | split: DataLoader(dataset, batch_size=args.batch_size, shuffle=(split == 'train'), num_workers=args.num_workers) 139 | for split, dataset in datasets.items()} 140 | 141 | # ------------------------ Load and initialize the model 142 | model = PONITA_MD17(args) 143 | model.set_dataset_statistics(datasets['train']) 144 | 145 | # ------------------------ Weights and Biases logger 146 | if args.log: 147 | logger = pl.loggers.WandbLogger(project="PONITA-MD17", name=args.target.replace(" ", "_"), config=args, save_dir='logs') 148 | else: 149 | logger = None 150 | 151 | # ------------------------ Set up the trainer 152 | 153 | # Seed 154 | pl.seed_everything(args.seed, workers=True) 155 | 156 | # Pytorch lightning call backs 157 | callbacks = [EMA(0.99), 158 | pl.callbacks.ModelCheckpoint(monitor='valid MAE (energy)', mode = 'min'), 159 | EpochTimer()] 160 | if args.log: callbacks.append(pl.callbacks.LearningRateMonitor(logging_interval='epoch')) 161 | 162 | # Initialize the trainer 163 | trainer = pl.Trainer(logger=logger, max_epochs=args.epochs, callbacks=callbacks, inference_mode=False, # Important for force computation via backprop 164 | gradient_clip_val=0.5, accelerator=accelerator, devices=devices, enable_progress_bar=args.enable_progress_bar) 165 | 166 | # Do the training 167 | trainer.fit(model, dataloaders['train'], dataloaders['val']) 168 | 169 | # And test 170 | trainer.test(model, dataloaders['test'], ckpt_path = "best") 171 | -------------------------------------------------------------------------------- /ponita/transforms/position_orientation_graph.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.transforms import BaseTransform, RadiusGraph 3 | from torch_geometric.typing import SparseTensor 4 | from torch_geometric.utils import coalesce, remove_self_loops, add_self_loops 5 | from ponita.geometry.rotation import uniform_grid_s2, random_matrix 6 | from ponita.geometry.rotation_2d import uniform_grid_s1, random_so2_matrix 7 | from ponita.utils.to_from_sphere import scalar_to_sphere, vec_to_sphere 8 | import torch_geometric 9 | 10 | 11 | class PositionOrientationGraph(BaseTransform): 12 | """ 13 | A PyTorch Geometric transform that lifts a point cloud in position space, as stored in a graph data object, 14 | to a position-orientation space fiber bundle. The grid of orientations sampled in each fiber is shared over 15 | all nodes and node features (scalars and/or vectors) are locally lifted to this grid. 16 | 17 | Args: 18 | num_ori (int): Number of orientations used to discretize the sphere. 
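            For num_ori > 0 the graph is lifted to a fiber bundle with num_ori
            orientations per node, num_ori = -1 lifts every edge to a single
            position-orientation point, and num_ori = 0 leaves the graph as a
            plain position point cloud (see __call__ below).
        radius (float, optional): If given, the edge_index is (re)built as a
            radius graph over the base positions after lifting.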
19 | """ 20 | 21 | def __init__(self, num_ori, radius=None): 22 | super().__init__() 23 | 24 | # Discretization of the orientation grid 25 | self.num_ori = num_ori 26 | self.radius = radius 27 | 28 | # Grid type 29 | if num_ori > 0: 30 | # Somewhat redundant but this is done to be compatible with both 2d and 3d 31 | self.ori_grid_s1 = uniform_grid_s1(num_ori) 32 | self.ori_grid_s2 = uniform_grid_s2(num_ori) 33 | 34 | if radius is not None: 35 | self.transform = RadiusGraph(radius, loop=True, max_num_neighbors=1000) 36 | 37 | def __call__(self, graph): 38 | """ 39 | Apply the transform to the input graph. 40 | 41 | Args: 42 | graph (torch_geometric.data.Data): Input graph containing position (graph.pos), 43 | scalar features (optional, graph.x) with shape [num_nodes, num_features], 44 | and vector features (optional, graph.vec) with shape [num_nodes, num_vec_features, 3] 45 | 46 | Returns: 47 | torch_geometric.data.Data: Updated graph with added orientation information (graph.ori_grid with shape [num_ori, n]) 48 | and lifted feature (graph.f with shape [num_nodes, num_ori, num_vec + num_x]). 49 | """ 50 | if self.num_ori == -1: 51 | graph = self.to_po_point_cloud(graph) 52 | elif self.num_ori == 0: 53 | graph = self.to_p_point_cloud(graph) 54 | else: 55 | graph = self.to_po_fiber_bundle(graph) 56 | loop = True # Hard-code that self-interactions are always present 57 | if self.radius is not None: 58 | graph.edge_index = torch_geometric.nn.radius_graph(graph.pos[:,:graph.n], self.radius, graph.batch, loop, max_num_neighbors=1000) 59 | else: 60 | if loop: 61 | graph.edge_index = coalesce(add_self_loops(graph.edge_index)[0]) 62 | return graph 63 | 64 | def to_po_fiber_bundle(self, graph): 65 | """ 66 | Internal method to add orientation information to the input graph. 67 | 68 | Args: 69 | graph (torch_geometric.data.Data): Input graph containing position (graph.pos), 70 | scalar features (optional, graph.x with shape [num_nodes, num_scalars]), and 71 | vector features (optional, graph.vec with shape [num_nodes, num_vec, n]). 72 | 73 | Returns: 74 | torch_geometric.data.Data: Updated graph with added orientation information (graph.ori_grid) and features (graph.f). 
75 | """ 76 | graph.n = graph.pos.size(1) 77 | graph.ori_grid = (self.ori_grid_s1 if (graph.n == 2) else self.ori_grid_s2).type_as(graph.pos) 78 | graph.num_ori = self.num_ori 79 | 80 | 81 | # Lift input features to spheres 82 | inputs = [] 83 | if hasattr(graph, "x"): inputs.append(scalar_to_sphere(graph.x, graph.ori_grid)) 84 | if hasattr(graph, "vec"): inputs.append(vec_to_sphere(graph.vec, graph.ori_grid)) 85 | graph.x = torch.cat(inputs, dim=-1) # [num_nodes, num_ori, input_dim + input_dim_vec] 86 | 87 | # Return updated graph 88 | return graph 89 | 90 | def to_po_point_cloud(self, graph): 91 | graph.n = graph.pos.size(1) 92 | 93 | # ----------- The relevant items in the original graph 94 | 95 | # We should remove self-loops because those cannot define directions 96 | input_edge_index = remove_self_loops(coalesce(graph.edge_index, num_nodes=graph.num_nodes))[0] 97 | # The other relevant items 98 | pos = graph.pos 99 | batch = graph.batch 100 | source, target = input_edge_index 101 | 102 | # ----------- Lifted positions (each original edge now becomes a node) 103 | 104 | # Compute direction vectors from the edge_index 105 | pos_s, pos_t = pos[source], pos[target] 106 | dist = (pos_s - pos_t).norm(dim=-1, keepdim=True) 107 | ori_t = (pos_s - pos_t) / dist 108 | 109 | # Target position as base node 110 | graph.pos = torch.cat([pos_t, ori_t], dim=-1) # [4D, or 6D position-orientation element] 111 | 112 | # Each edge in the original graph will become a new node with the following index 113 | lifted_index = torch.arange(source.size(0), device=source.device) # lifted idx 114 | 115 | # ----------- Lift the edge_index 116 | 117 | # For the new_edge_index we do allow for self-interactions 118 | base_edge_index = coalesce(add_self_loops(input_edge_index)[0]) 119 | base_source = base_edge_index[0] 120 | base_target = base_edge_index[1] 121 | 122 | # The following is used as lookup table for connecting lifted idx to base idx 123 | # We use SparseTensor for this, which codes the triplets (row_idx, col_idx, value) 124 | # Here we define the triplet as (base_idx, the_original_node_sending_to_this, lifted_idx) 125 | # In particular the combination base_idx -> lifted_idx is going to be useful to lookup which 126 | # lifted nodes are associated with a base node 127 | num_base = pos.size(0) 128 | baseidx_source_liftidx = SparseTensor(row=target, col=source, value=lifted_index, sparse_sizes=(num_base, num_base)) 129 | 130 | # Determine the number of lifted_idx at each base node 131 | num_ori_at_base = baseidx_source_liftidx.set_value(None).sum(dim=1).to(torch.long) 132 | 133 | # We now take the base_edge_index as starting point 134 | # We take the base indices of this edge_index and look up which lifted points are 135 | # associated with these base indices. 
This will form the set of sending lifted indices 136 | lifted_source = baseidx_source_liftidx[base_source].storage.value() # [648000 = 72000 * 9] 137 | 138 | # Then check the base nodes at the receiving end 139 | base_target = base_target.repeat_interleave(num_ori_at_base[base_source]) # [64800] 140 | 141 | # Lookup all the lifted indices at the receiving (target) node 142 | lifted_target = baseidx_source_liftidx[base_target].storage.value() # [5832000 = 648000 * 9] 143 | 144 | # Repeat the lifted source the number of times that it has to send to the receiving target 145 | lifted_source = lifted_source.repeat_interleave(num_ori_at_base[base_target]) 146 | 147 | # Now we're done 148 | lifted_edge_index = torch.stack([lifted_source, lifted_target]) 149 | graph.edge_index = lifted_edge_index 150 | 151 | # ----------- Lift the batch 152 | 153 | if hasattr(graph, "batch"): 154 | if graph.batch is not None: 155 | graph.batch = batch[input_edge_index[1]].clone().contiguous() 156 | 157 | # ----------- Lift the scalar and vector features, overwrite x 158 | inputs = [] 159 | if hasattr(graph, "x"): 160 | inputs.append(graph.x[input_edge_index[1]]) 161 | if hasattr(graph, "vec"): 162 | inputs.append(torch.einsum('bcd,bd->bc', graph.vec[input_edge_index[1]], graph.pos[:,graph.n:])) 163 | graph.x = torch.cat(inputs, dim=-1) # [num_lifted_nodes, num_channels] 164 | 165 | # ----------- Utility to be able to project back to the base node (e.g. via scatter collect) 166 | 167 | graph.scatter_projection_index = input_edge_index[1] 168 | 169 | 170 | return graph 171 | 172 | def to_p_point_cloud(self, graph): 173 | graph.n = graph.pos.size(1) 174 | 175 | # Otherwise do nothing because the graph is already assumed to be a position point cloud 176 | # I.e., it already has graph.pos, graph.x, graph.batch and possibly graph.edge_index 177 | # However, we need to add scatter_projection_index for compatibility with the ponita method 178 | graph.scatter_projection_index = torch.arange(0,graph.pos.size(0)).type_as(graph.batch) 179 | 180 | return graph -------------------------------------------------------------------------------- /lightning_wrappers/md17.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torchmetrics 6 | from torch_geometric.data import Batch 7 | import pytorch_lightning as pl 8 | 9 | from .scheduler import CosineWarmupScheduler 10 | from ponita.models.ponita import Ponita 11 | from ponita.transforms.random_rotate import RandomRotate 12 | 13 | 14 | class PONITA_MD17(pl.LightningModule): 15 | """ 16 | """ 17 | 18 | def __init__(self, args): 19 | super().__init__() 20 | 21 | # Store some of the relevant args 22 | self.repeats = args.repeats 23 | self.lr = args.lr 24 | self.weight_decay = args.weight_decay 25 | self.epochs = args.epochs 26 | self.warmup = args.warmup 27 | self.lambda_F = args.lambda_F 28 | if args.layer_scale == 0.: 29 | args.layer_scale = None 30 | 31 | # For rotation augmentations during training and testing 32 | self.train_augm = args.train_augm 33 | self.rotation_transform = RandomRotate(['pos','force'], n=3) 34 | 35 | # Shift and scale before callibration 36 | self.shift = 0. 37 | self.scale = 1. 
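        # Note: shift/scale are calibrated in set_dataset_statistics below; the
        # model then regresses the standardized energy (energy - shift) / scale,
        # with shift the mean training energy and scale the root-mean-square of
        # the training forces, so that energy and force targets have comparable
        # magnitudes. Metrics are logged in the original units via
        # pred * scale + shift.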
38 | 39 | # The metrics to log 40 | self.train_metric = torchmetrics.MeanAbsoluteError() 41 | self.train_metric_force = torchmetrics.MeanAbsoluteError() 42 | self.valid_metric = torchmetrics.MeanAbsoluteError() 43 | self.valid_metric_force = torchmetrics.MeanAbsoluteError() 44 | self.test_metrics_energy = nn.ModuleList([torchmetrics.MeanAbsoluteError() for r in range(self.repeats)]) 45 | self.test_metrics_force = nn.ModuleList([torchmetrics.MeanAbsoluteError() for r in range(self.repeats)]) 46 | 47 | # Input/output specifications: 48 | in_channels_scalar = 9 # Charge, Velocity norm 49 | in_channels_vec = 0 # Velocity, rel_pos 50 | out_channels_scalar = 1 # None 51 | out_channels_vec = 0 # Output velocity 52 | 53 | # Make the model 54 | self.model = Ponita(in_channels_scalar + in_channels_vec, 55 | args.hidden_dim, 56 | out_channels_scalar, 57 | args.layers, 58 | output_dim_vec=out_channels_vec, 59 | radius=args.radius, 60 | num_ori=args.num_ori, 61 | basis_dim=args.basis_dim, 62 | degree=args.degree, 63 | widening_factor=args.widening_factor, 64 | layer_scale=args.layer_scale, 65 | task_level='graph', 66 | multiple_readouts=args.multiple_readouts, 67 | lift_graph=True) 68 | 69 | def set_dataset_statistics(self, dataset): 70 | ys = np.array([data.energy.item() for data in dataset]) 71 | forces = np.concatenate([data.force.numpy() for data in dataset]) 72 | self.shift = np.mean(ys) 73 | self.scale = np.sqrt(np.mean(forces**2)) 74 | self.min_dist = 1e10 75 | self.max_dist = 0 76 | for data in dataset: 77 | pos = data.pos 78 | edm = np.linalg.norm(pos[:,None,:] - pos[None,:,:],axis=-1) 79 | min_dist = np.min(edm + np.eye(edm.shape[0]) * 1e10) 80 | max_dist = np.max(edm) 81 | if min_dist < self.min_dist: 82 | self.min_dist = min_dist 83 | if max_dist > self.max_dist: 84 | self.max_dist = max_dist 85 | print('Min-max range of distances between atoms in the dataset:', self.min_dist, '-', self.max_dist) 86 | 87 | def forward(self, graph): 88 | # Only utilize the scalar (energy) prediction 89 | pred, _ = self.model(graph) 90 | return pred.squeeze(-1) 91 | 92 | @torch.enable_grad() 93 | def pred_energy_and_force(self, graph): 94 | graph.pos = torch.autograd.Variable(graph.pos, requires_grad=True) 95 | pos = graph.pos 96 | pred_energy = self(graph) 97 | sign = -1.0 98 | pred_force = sign * torch.autograd.grad( 99 | pred_energy, 100 | pos, 101 | grad_outputs=torch.ones_like(pred_energy), 102 | create_graph=True, 103 | retain_graph=True 104 | )[0] 105 | # Return result 106 | return pred_energy, pred_force 107 | 108 | def training_step(self, graph): 109 | if self.train_augm: 110 | graph = self.rotation_transform(graph) 111 | pred_energy, pred_force = self.pred_energy_and_force(graph) 112 | 113 | energy_loss = torch.mean((pred_energy - (graph.energy - self.shift) / self.scale)**2) 114 | force_loss = torch.mean(torch.sum((pred_force - graph.force / self.scale)**2,-1)) / 3. 
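        # With the default lambda_F = 500, the objective below is
        # energy_MSE / lambda_F + force_MSE, i.e. training is dominated by the
        # force term. Dividing the energy loss by lambda_F is, up to an overall
        # scaling of the loss, equivalent to weighting the force loss by
        # lambda_F as the argument's help string suggests.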
115 | loss = energy_loss / self.lambda_F + force_loss 116 | 117 | self.train_metric(pred_energy * self.scale + self.shift, graph.energy) 118 | self.train_metric_force(pred_force * self.scale, graph.force) 119 | 120 | return loss 121 | 122 | def on_train_epoch_end(self): 123 | self.log("train MAE (energy)", self.train_metric, prog_bar=True) 124 | self.log("train MAE (force)", self.train_metric_force, prog_bar=True) 125 | 126 | def validation_step(self, graph, batch_idx): 127 | pred_energy, pred_force = self.pred_energy_and_force(graph) 128 | self.valid_metric(pred_energy * self.scale + self.shift, graph.energy) 129 | self.valid_metric_force(pred_force * self.scale, graph.force) 130 | 131 | def on_validation_epoch_end(self): 132 | self.log("valid MAE (energy)", self.valid_metric, prog_bar=True) 133 | self.log("valid MAE (force)", self.valid_metric_force, prog_bar=True) 134 | 135 | def test_step(self, graph, batch_idx): 136 | # Repeat the prediction self.repeat number of times and average (makes sense due to random grids) 137 | batch_size = graph.batch.max() + 1 138 | batch_length = graph.batch.shape[0] 139 | graph_repeated = Batch.from_data_list([graph] * self.repeats) 140 | # Random rotate graph 141 | rot = self.rotation_transform.random_rotation(graph_repeated) 142 | graph_repeated = self.rotation_transform.rotate_graph(graph_repeated, rot) 143 | # Compute results 144 | pred_energy_repeated, pred_force_repeated = self.pred_energy_and_force(graph_repeated) 145 | # Unrotate results 146 | rot_T = rot.transpose(-2,-1) 147 | pred_force_repeated = self.rotation_transform.rotate_attr(pred_force_repeated, rot_T) 148 | # Unwrap predictions 149 | pred_energy_repeated = pred_energy_repeated.unflatten(0, (self.repeats, batch_size)) 150 | pred_force_repeated = pred_force_repeated.unflatten(0, (self.repeats, batch_length)) 151 | # Compute the averages 152 | for r in range(self.repeats): 153 | pred_energy, pred_force = pred_energy_repeated[:r+1].mean(0), pred_force_repeated[:r+1].mean(0) 154 | self.test_metrics_energy[r](pred_energy * self.scale + self.shift, graph.energy) 155 | self.test_metrics_force[r](pred_force * self.scale, graph.force) 156 | 157 | def on_test_epoch_end(self): 158 | for r in range(self.repeats): 159 | self.log("test MAE (energy) x"+str(r+1), self.test_metrics_energy[r]) 160 | self.log("test MAE (force) x"+str(r+1), self.test_metrics_force[r]) 161 | 162 | def configure_optimizers(self): 163 | """ 164 | Adapted from: https://github.com/karpathy/minGPT/blob/master/mingpt/model.py 165 | This long function is unfortunately doing something very simple and is being very defensive: 166 | We are separating out all parameters of the model into two buckets: those that will experience 167 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights). 168 | We are then returning the PyTorch optimizer object. 169 | """ 170 | 171 | # separate out all parameters to those that will and won't experience regularizing weight decay 172 | decay = set() 173 | no_decay = set() 174 | whitelist_weight_modules = (torch.nn.Linear, ) 175 | blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) 176 | for mn, m in self.named_modules(): 177 | for pn, p in m.named_parameters(): 178 | fpn = '%s.%s' % (mn, pn) if mn else pn # full param name 179 | # random note: because named_modules and named_parameters are recursive 180 | # we will see the same tensors p many many times. but doing it this way 181 | # allows us to know which parent module any tensor p belongs to... 
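                # Heuristic rationale (standard practice, not specific to this
                # repo): decay is applied only to weights that mix features
                # (Linear); biases, LayerNorm/Embedding weights and layer_scale
                # gains act as shifts/scales, and shrinking them toward zero
                # with L2 decay typically hurts more than it regularizes.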
182 | if pn.endswith('bias'): 183 | # all biases will not be decayed 184 | no_decay.add(fpn) 185 | elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): 186 | # weights of whitelist modules will be weight decayed 187 | decay.add(fpn) 188 | elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): 189 | # weights of blacklist modules will NOT be weight decayed 190 | no_decay.add(fpn) 191 | elif pn.endswith('layer_scale'): 192 | no_decay.add(fpn) 193 | 194 | # validate that we considered every parameter 195 | param_dict = {pn: p for pn, p in self.named_parameters()} 196 | inter_params = decay & no_decay 197 | union_params = decay | no_decay 198 | assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) 199 | assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ 200 | % (str(param_dict.keys() - union_params), ) 201 | 202 | # create the pytorch optimizer object 203 | optim_groups = [ 204 | {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": self.weight_decay}, 205 | {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, 206 | ] 207 | optimizer = torch.optim.Adam(optim_groups, lr=self.lr) 208 | scheduler = CosineWarmupScheduler(optimizer, self.warmup, self.trainer.max_epochs) 209 | return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"} -------------------------------------------------------------------------------- /n_body_system/se3_dynamics/equivariant_attention/from_se3cnn/SO3.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=C,E1101,E1102 2 | ''' 3 | Some functions related to SO3 and his usual representations 4 | 5 | Using ZYZ Euler angles parametrisation 6 | ''' 7 | import torch 8 | import math 9 | import numpy as np 10 | 11 | 12 | class torch_default_dtype: 13 | 14 | def __init__(self, dtype): 15 | self.saved_dtype = None 16 | self.dtype = dtype 17 | 18 | def __enter__(self): 19 | self.saved_dtype = torch.get_default_dtype() 20 | torch.set_default_dtype(self.dtype) 21 | 22 | def __exit__(self, exc_type, exc_value, traceback): 23 | torch.set_default_dtype(self.saved_dtype) 24 | 25 | 26 | def rot_z(gamma): 27 | ''' 28 | Rotation around Z axis 29 | ''' 30 | if not torch.is_tensor(gamma): 31 | gamma = torch.tensor(gamma, dtype=torch.get_default_dtype()) 32 | return torch.tensor([ 33 | [torch.cos(gamma), -torch.sin(gamma), 0], 34 | [torch.sin(gamma), torch.cos(gamma), 0], 35 | [0, 0, 1] 36 | ], dtype=gamma.dtype) 37 | 38 | 39 | def rot_y(beta): 40 | ''' 41 | Rotation around Y axis 42 | ''' 43 | if not torch.is_tensor(beta): 44 | beta = torch.tensor(beta, dtype=torch.get_default_dtype()) 45 | return torch.tensor([ 46 | [torch.cos(beta), 0, torch.sin(beta)], 47 | [0, 1, 0], 48 | [-torch.sin(beta), 0, torch.cos(beta)] 49 | ], dtype=beta.dtype) 50 | 51 | 52 | def rot(alpha, beta, gamma): 53 | ''' 54 | ZYZ Eurler angles rotation 55 | ''' 56 | return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma) 57 | 58 | 59 | def x_to_alpha_beta(x): 60 | ''' 61 | Convert point (x, y, z) on the sphere into (alpha, beta) 62 | ''' 63 | if not torch.is_tensor(x): 64 | x = torch.tensor(x, dtype=torch.get_default_dtype()) 65 | x = x / torch.norm(x) 66 | beta = torch.acos(x[2]) 67 | alpha = torch.atan2(x[1], x[0]) 68 | return (alpha, beta) 69 | 70 | 71 | # These functions (x_to_alpha_beta and rot) satisfies that 72 | # rot(*x_to_alpha_beta([x, 
y, z]), 0) @ np.array([[0], [0], [1]]) 73 | # is proportional to 74 | # [x, y, z] 75 | 76 | 77 | def irr_repr(order, alpha, beta, gamma, dtype=None): 78 | """ 79 | irreducible representation of SO3 80 | - compatible with compose and spherical_harmonics 81 | """ 82 | # from from_lielearn_SO3.wigner_d import wigner_D_matrix 83 | from lie_learn.representations.SO3.wigner_d import wigner_D_matrix 84 | # if order == 1: 85 | # # change of basis to have vector_field[x, y, z] = [vx, vy, vz] 86 | # A = np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0]]) 87 | # return A @ wigner_D_matrix(1, alpha, beta, gamma) @ A.T 88 | 89 | # TODO (non-essential): try to do everything in torch 90 | # return torch.tensor(wigner_D_matrix(torch.tensor(order), alpha, beta, gamma), dtype=torch.get_default_dtype() if dtype is None else dtype) 91 | return torch.tensor(wigner_D_matrix(order, np.array(alpha), np.array(beta), np.array(gamma)), dtype=torch.get_default_dtype() if dtype is None else dtype) 92 | 93 | 94 | # def spherical_harmonics(order, alpha, beta, dtype=None): 95 | # """ 96 | # spherical harmonics 97 | # - compatible with irr_repr and compose 98 | # """ 99 | # # from from_lielearn_SO3.spherical_harmonics import sh 100 | # from lie_learn.representations.SO3.spherical_harmonics import sh # real valued by default 101 | # 102 | # ################################################################################################################### 103 | # # ON ANGLE CONVENTION 104 | # # 105 | # # sh has following convention for angles: 106 | # # :param theta: the colatitude / polar angle, ranging from 0(North Pole, (X, Y, Z) = (0, 0, 1)) to pi(South Pole, (X, Y, Z) = (0, 0, -1)). 107 | # # :param phi: the longitude / azimuthal angle, ranging from 0 to 2 pi. 108 | # # 109 | # # this function therefore (probably) has the following convention for alpha and beta: 110 | # # beta = pi - theta; ranging from 0(South Pole, (X, Y, Z) = (0, 0, -1)) to pi(North Pole, (X, Y, Z) = (0, 0, 1)). 
111 | # # alpha = phi 112 | # # 113 | # ################################################################################################################### 114 | # 115 | # Y = torch.tensor([sh(order, m, theta=math.pi - beta, phi=alpha) for m in range(-order, order + 1)], dtype=torch.get_default_dtype() if dtype is None else dtype) 116 | # # if order == 1: 117 | # # # change of basis to have vector_field[x, y, z] = [vx, vy, vz] 118 | # # A = np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0]]) 119 | # # return A @ Y 120 | # return Y 121 | 122 | 123 | def compose(a1, b1, c1, a2, b2, c2): 124 | """ 125 | (a, b, c) = (a1, b1, c1) composed with (a2, b2, c2) 126 | """ 127 | comp = rot(a1, b1, c1) @ rot(a2, b2, c2) 128 | xyz = comp @ torch.tensor([0, 0, 1.]) 129 | a, b = x_to_alpha_beta(xyz) 130 | rotz = rot(0, -b, -a) @ comp 131 | c = torch.atan2(rotz[1, 0], rotz[0, 0]) 132 | return a, b, c 133 | 134 | 135 | def kron(x, y): 136 | assert x.ndimension() == 2 137 | assert y.ndimension() == 2 138 | return torch.einsum("ij,kl->ikjl", (x, y)).view(x.size(0) * y.size(0), x.size(1) * y.size(1)) 139 | 140 | 141 | ################################################################################ 142 | # Change of basis 143 | ################################################################################ 144 | 145 | 146 | def xyz_vector_basis_to_spherical_basis(): 147 | """ 148 | to convert a vector [x, y, z] transforming with rot(a, b, c) 149 | into a vector transforming with irr_repr(1, a, b, c) 150 | see assert for usage 151 | """ 152 | with torch_default_dtype(torch.float64): 153 | A = torch.tensor([[0, 1, 0], [0, 0, 1], [1, 0, 0]], dtype=torch.float64) 154 | assert all(torch.allclose(irr_repr(1, a, b, c) @ A, A @ rot(a, b, c)) for a, b, c in torch.rand(10, 3)) 155 | return A.type(torch.get_default_dtype()) 156 | 157 | 158 | def tensor3x3_repr(a, b, c): 159 | """ 160 | representation of 3x3 tensors 161 | T --> R T R^t 162 | """ 163 | r = rot(a, b, c) 164 | return kron(r, r) 165 | 166 | 167 | def tensor3x3_repr_basis_to_spherical_basis(): 168 | """ 169 | to convert a 3x3 tensor transforming with tensor3x3_repr(a, b, c) 170 | into its 1 + 3 + 5 component transforming with irr_repr(0, a, b, c), irr_repr(1, a, b, c), irr_repr(3, a, b, c) 171 | see assert for usage 172 | """ 173 | with torch_default_dtype(torch.float64): 174 | to1 = torch.tensor([ 175 | [1, 0, 0, 0, 1, 0, 0, 0, 1], 176 | ], dtype=torch.get_default_dtype()) 177 | assert all(torch.allclose(irr_repr(0, a, b, c) @ to1, to1 @ tensor3x3_repr(a, b, c)) for a, b, c in torch.rand(10, 3)) 178 | 179 | to3 = torch.tensor([ 180 | [0, 0, -1, 0, 0, 0, 1, 0, 0], 181 | [0, 1, 0, -1, 0, 0, 0, 0, 0], 182 | [0, 0, 0, 0, 0, 1, 0, -1, 0], 183 | ], dtype=torch.get_default_dtype()) 184 | assert all(torch.allclose(irr_repr(1, a, b, c) @ to3, to3 @ tensor3x3_repr(a, b, c)) for a, b, c in torch.rand(10, 3)) 185 | 186 | to5 = torch.tensor([ 187 | [0, 1, 0, 1, 0, 0, 0, 0, 0], 188 | [0, 0, 0, 0, 0, 1, 0, 1, 0], 189 | [-3**.5/3, 0, 0, 0, -3**.5/3, 0, 0, 0, 12**.5/3], 190 | [0, 0, 1, 0, 0, 0, 1, 0, 0], 191 | [1, 0, 0, 0, -1, 0, 0, 0, 0] 192 | ], dtype=torch.get_default_dtype()) 193 | assert all(torch.allclose(irr_repr(2, a, b, c) @ to5, to5 @ tensor3x3_repr(a, b, c)) for a, b, c in torch.rand(10, 3)) 194 | 195 | return to1.type(torch.get_default_dtype()), to3.type(torch.get_default_dtype()), to5.type(torch.get_default_dtype()) 196 | 197 | 198 | ################################################################################ 199 | # Tests 200 | 
################################################################################ 201 | 202 | 203 | def test_is_representation(rep): 204 | """ 205 | rep(Z(a1) Y(b1) Z(c1) Z(a2) Y(b2) Z(c2)) = rep(Z(a1) Y(b1) Z(c1)) rep(Z(a2) Y(b2) Z(c2)) 206 | """ 207 | with torch_default_dtype(torch.float64): 208 | a1, b1, c1, a2, b2, c2 = torch.rand(6) 209 | 210 | r1 = rep(a1, b1, c1) 211 | r2 = rep(a2, b2, c2) 212 | 213 | a, b, c = compose(a1, b1, c1, a2, b2, c2) 214 | r = rep(a, b, c) 215 | 216 | r_ = r1 @ r2 217 | 218 | d, r = (r - r_).abs().max(), r.abs().max() 219 | print(d.item(), r.item()) 220 | assert d < 1e-10 * r, d / r 221 | 222 | 223 | def _test_spherical_harmonics(order): 224 | """ 225 | This test tests that 226 | - irr_repr 227 | - compose 228 | - spherical_harmonics 229 | are compatible 230 | 231 | Y(Z(alpha) Y(beta) Z(gamma) x) = D(alpha, beta, gamma) Y(x) 232 | with x = Z(a) Y(b) eta 233 | """ 234 | with torch_default_dtype(torch.float64): 235 | a, b = torch.rand(2) 236 | alpha, beta, gamma = torch.rand(3) 237 | 238 | ra, rb, _ = compose(alpha, beta, gamma, a, b, 0) 239 | Yrx = spherical_harmonics(order, ra, rb) 240 | 241 | Y = spherical_harmonics(order, a, b) 242 | DrY = irr_repr(order, alpha, beta, gamma) @ Y 243 | 244 | d, r = (Yrx - DrY).abs().max(), Y.abs().max() 245 | print(d.item(), r.item()) 246 | assert d < 1e-10 * r, d / r 247 | 248 | 249 | def _test_change_basis_wigner_to_rot(): 250 | # from from_lielearn_SO3.wigner_d import wigner_D_matrix 251 | from lie_learn.representations.SO3.wigner_d import wigner_D_matrix 252 | 253 | with torch_default_dtype(torch.float64): 254 | A = torch.tensor([ 255 | [0, 1, 0], 256 | [0, 0, 1], 257 | [1, 0, 0] 258 | ], dtype=torch.float64) 259 | 260 | a, b, c = torch.rand(3) 261 | 262 | r1 = A.t() @ torch.tensor(wigner_D_matrix(1, a, b, c), dtype=torch.float64) @ A 263 | r2 = rot(a, b, c) 264 | 265 | d = (r1 - r2).abs().max() 266 | print(d.item()) 267 | assert d < 1e-10 268 | 269 | 270 | if __name__ == "__main__": 271 | from functools import partial 272 | 273 | print("Change of basis") 274 | xyz_vector_basis_to_spherical_basis() 275 | test_is_representation(tensor3x3_repr) 276 | tensor3x3_repr_basis_to_spherical_basis() 277 | 278 | print("Change of basis Wigner <-> rot") 279 | _test_change_basis_wigner_to_rot() 280 | _test_change_basis_wigner_to_rot() 281 | _test_change_basis_wigner_to_rot() 282 | 283 | print("Spherical harmonics are solution of Y(rx) = D(r) Y(x)") 284 | for l in range(7): 285 | _test_spherical_harmonics(l) 286 | 287 | print("Irreducible repr are indeed representations") 288 | for l in range(7): 289 | test_is_representation(partial(irr_repr, l)) 290 | -------------------------------------------------------------------------------- /n_body_system/se3_dynamics/models.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import numpy as np 4 | import torch 5 | 6 | # from dgl.nn.pytorch import GraphConv, NNConv 7 | from torch import nn 8 | from torch.nn import functional as F 9 | from typing import Dict, Tuple, List 10 | 11 | from .equivariant_attention.modules import GConvSE3, GNormSE3, get_basis_and_r, GSE3Res, GMaxPooling, GAvgPooling 12 | from .equivariant_attention.fibers import Fiber 13 | import time 14 | 15 | class TFN(nn.Module): 16 | """SE(3) equivariant GCN""" 17 | def __init__(self, num_layers: int, atom_feature_size: int, 18 | num_channels: int, num_nlayers: int=1, num_degrees: int=4, 19 | edge_dim: int=4, **kwargs): 20 | super().__init__() 21 | # Build the network 22 | 
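        # (Assumption based on usage in this file: Fiber(1, d) denotes d
        # channels of degree-0 (scalar) features, while Fiber(num_degrees, c)
        # denotes c channels at each degree 0..num_degrees-1; the `fibers`
        # dict below fixes the input/hidden/output feature types of the
        # equivariant blocks.)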
self.num_layers = num_layers 23 | self.num_nlayers = num_nlayers 24 | self.num_channels = num_channels 25 | self.num_degrees = num_degrees 26 | self.num_channels_out = num_channels*num_degrees 27 | self.edge_dim = edge_dim 28 | 29 | self.fibers = {'in': Fiber(1, atom_feature_size), 30 | 'mid': Fiber(num_degrees, self.num_channels), 31 | 'out': Fiber(1, self.num_channels_out)} 32 | 33 | blocks = self._build_gcn(self.fibers, 1) 34 | self.block0, self.block1, self.block2 = blocks 35 | print(self.block0) 36 | print(self.block1) 37 | print(self.block2) 38 | 39 | def _build_gcn(self, fibers, out_dim): 40 | 41 | block0 = [] 42 | fin = fibers['in'] 43 | for i in range(self.num_layers-1): 44 | block0.append(GConvSE3(fin, fibers['mid'], self_interaction=True, edge_dim=self.edge_dim)) 45 | block0.append(GNormSE3(fibers['mid'], num_layers=self.num_nlayers)) 46 | fin = fibers['mid'] 47 | block0.append(GConvSE3(fibers['mid'], fibers['out'], self_interaction=True, edge_dim=self.edge_dim)) 48 | 49 | 50 | block1 = [GMaxPooling()] 51 | 52 | block2 = [] 53 | block2.append(nn.Linear(self.num_channels_out, self.num_channels_out)) 54 | block2.append(nn.ReLU(inplace=True)) 55 | block2.append(nn.Linear(self.num_channels_out, out_dim)) 56 | 57 | return nn.ModuleList(block0), nn.ModuleList(block1), nn.ModuleList(block2) 58 | 59 | def forward(self, G): 60 | # Compute equivariant weight basis from relative positions 61 | basis, r = get_basis_and_r(G, self.num_degrees-1) 62 | 63 | # encoder (equivariant layers) 64 | h = {'0': G.ndata['f']} 65 | for layer in self.block0: 66 | h = layer(h, G=G, r=r, basis=basis) 67 | 68 | h = h['0'][...,-1] 69 | for layer in self.block1: 70 | h = layer(G, h) 71 | 72 | for layer in self.block2: 73 | h = layer(h) 74 | 75 | return h 76 | 77 | 78 | class OursTFN(nn.Module): 79 | """SE(3) equivariant GCN""" 80 | def __init__(self, num_layers: int, 81 | num_channels: int, num_nlayers: int=1, num_degrees: int=4, act_fn=nn.ReLU(), 82 | edge_dim: int=4, out_types={1: 1}, in_types={0: 1, 1: 1}, **kwargs): 83 | super().__init__() 84 | # Build the network 85 | self.num_layers = num_layers 86 | self.num_nlayers = num_nlayers 87 | self.num_channels = num_channels 88 | self.num_degrees = num_degrees 89 | self.num_channels_out = num_channels*num_degrees 90 | self.edge_dim = edge_dim 91 | self.act_fn = act_fn 92 | 93 | self.fibers = {'in': Fiber(dictionary=in_types), 94 | 'mid': Fiber(num_degrees, self.num_channels), 95 | 'out': Fiber(dictionary=out_types)} 96 | 97 | blocks = self._build_gcn(self.fibers, 1) 98 | self.block0 = blocks 99 | print(self.block0) 100 | 101 | def _build_gcn(self, fibers, out_dim): 102 | 103 | block0 = [] 104 | fin = fibers['in'] 105 | for i in range(self.num_layers-1): 106 | block0.append(GConvSE3(fin, fibers['mid'], self_interaction=True, edge_dim=self.edge_dim, act_fn=self.act_fn)) 107 | block0.append(GNormSE3(fibers['mid'], num_layers=self.num_nlayers, act_fn=self.act_fn)) 108 | fin = fibers['mid'] 109 | block0.append(GConvSE3(fibers['mid'], fibers['out'], self_interaction=True, edge_dim=self.edge_dim, act_fn=self.act_fn)) 110 | 111 | ''' 112 | block1 = [GMaxPooling()] 113 | 114 | block2 = [] 115 | block2.append(nn.Linear(self.num_channels_out, self.num_channels_out)) 116 | block2.append(nn.ReLU(inplace=True)) 117 | block2.append(nn.Linear(self.num_channels_out, out_dim)) 118 | ''' 119 | return nn.ModuleList(block0)#, nn.ModuleList(block1), nn.ModuleList(block2) 120 | 121 | def forward(self, G): 122 | # Compute equivariant weight basis from relative positions 123 | basis, r 
= get_basis_and_r(G, self.num_degrees-1) 124 | 125 | # encoder (equivariant layers) 126 | h = {'0': G.ndata['f'], '1': G.ndata['f1']} 127 | for layer in self.block0: 128 | h = layer(h, G=G, r=r, basis=basis) 129 | 130 | ''' 131 | h = h['0'][...,-1] 132 | for layer in self.block1: 133 | h = layer(G, h) 134 | 135 | for layer in self.block2: 136 | h = layer(h) 137 | ''' 138 | 139 | return h 140 | 141 | 142 | class SE3Transformer(nn.Module): 143 | """SE(3) equivariant GCN with attention""" 144 | def __init__(self, num_layers: int, atom_feature_size: int, 145 | num_channels: int, num_nlayers: int=1, num_degrees: int=4, 146 | edge_dim: int=4, div: float=4, pooling: str='avg', n_heads: int=1, **kwargs): 147 | super().__init__() 148 | # Build the network 149 | self.num_layers = num_layers 150 | self.num_nlayers = num_nlayers 151 | self.num_channels = num_channels 152 | self.num_degrees = num_degrees 153 | self.edge_dim = edge_dim 154 | self.div = div 155 | self.pooling = pooling 156 | self.n_heads = n_heads 157 | 158 | self.fibers = {'in': Fiber(1, atom_feature_size), 159 | 'mid': Fiber(num_degrees, self.num_channels), 160 | 'out': Fiber(1, num_degrees*self.num_channels)} 161 | 162 | blocks = self._build_gcn(self.fibers, 1) 163 | self.Gblock, self.FCblock = blocks 164 | print(self.Gblock) 165 | print(self.FCblock) 166 | 167 | def _build_gcn(self, fibers, out_dim): 168 | # Equivariant layers 169 | Gblock = [] 170 | fin = fibers['in'] 171 | for i in range(self.num_layers): 172 | Gblock.append(GSE3Res(fin, fibers['mid'], edge_dim=self.edge_dim, 173 | div=self.div, n_heads=self.n_heads)) 174 | Gblock.append(GNormSE3(fibers['mid'])) 175 | fin = fibers['mid'] 176 | Gblock.append(GConvSE3(fibers['mid'], fibers['out'], self_interaction=True, edge_dim=self.edge_dim)) 177 | 178 | # Pooling 179 | if self.pooling == 'avg': 180 | Gblock.append(GAvgPooling()) 181 | elif self.pooling == 'max': 182 | Gblock.append(GMaxPooling()) 183 | 184 | # FC layers 185 | FCblock = [] 186 | FCblock.append(nn.Linear(self.fibers['out'].n_features, self.fibers['out'].n_features)) 187 | FCblock.append(nn.ReLU(inplace=True)) 188 | FCblock.append(nn.Linear(self.fibers['out'].n_features, out_dim)) 189 | 190 | return nn.ModuleList(Gblock), nn.ModuleList(FCblock) 191 | 192 | def forward(self, G): 193 | # Compute equivariant weight basis from relative positions 194 | basis, r = get_basis_and_r(G, self.num_degrees-1) 195 | 196 | # encoder (equivariant layers) 197 | h = {'0': G.ndata['f']} 198 | for layer in self.Gblock: 199 | h = layer(h, G=G, r=r, basis=basis) 200 | 201 | for layer in self.FCblock: 202 | h = layer(h) 203 | 204 | return h 205 | 206 | 207 | class OurSE3Transformer(nn.Module): 208 | """SE(3) equivariant GCN with attention""" 209 | def __init__(self, num_layers: int, 210 | num_channels: int, num_nlayers: int=1, num_degrees: int=4, 211 | edge_dim: int=4, div: float=1, pooling: str='avg', 212 | n_heads: int=1, act_fn=nn.ReLU(), 213 | out_types={1: 1}, in_types={0: 1, 1:1}, **kwargs): 214 | super().__init__() 215 | # Build the network 216 | self.num_layers = num_layers 217 | self.num_nlayers = num_nlayers 218 | self.num_channels = num_channels 219 | self.num_degrees = num_degrees 220 | self.edge_dim = edge_dim 221 | self.div = div 222 | self.pooling = pooling 223 | self.n_heads = n_heads 224 | self.act_fn = act_fn 225 | self.fibers = {'in': Fiber(dictionary=in_types), 226 | 'mid': Fiber(num_degrees, self.num_channels), 227 | 'out': Fiber(dictionary=out_types)} 228 | 229 | self.Gblock = self._build_gcn(self.fibers, 1) 230 | # 
self.Gblock, self.FCblock = blocks 231 | print(self.Gblock) 232 | # print(self.FCblock) 233 | self.counter = 0 234 | self.scalar_trick = nn.Parameter(torch.ones(1)*0.01) 235 | 236 | def _build_gcn(self, fibers, out_dim): 237 | # Equivariant layers 238 | Gblock = [] 239 | fin = fibers['in'] 240 | for i in range(self.num_layers): 241 | Gblock.append(GSE3Res(fin, fibers['mid'], edge_dim=self.edge_dim, 242 | div=self.div, n_heads=self.n_heads, act_fn=self.act_fn, learnable_skip=False)) 243 | Gblock.append(GNormSE3(fibers['mid'], act_fn=self.act_fn)) 244 | fin = fibers['mid'] 245 | Gblock.append(GConvSE3(fibers['mid'], fibers['out'], self_interaction=True, edge_dim=self.edge_dim, act_fn=self.act_fn)) 246 | 247 | # Pooling 248 | # if self.pooling == 'avg': 249 | # Gblock.append(GAvgPooling()) 250 | # elif self.pooling == 'max': 251 | # Gblock.append(GMaxPooling()) 252 | 253 | # FC layers 254 | # FCblock = [] 255 | # FCblock.append(nn.Linear(self.fibers['out'].n_features, self.fibers['out'].n_features)) 256 | # FCblock.append(nn.ReLU(inplace=True)) 257 | # FCblock.append(nn.Linear(self.fibers['out'].n_features, out_dim)) 258 | 259 | return nn.ModuleList(Gblock) 260 | # return nn.ModuleList(Gblock), nn.ModuleList(FCblock) 261 | 262 | def forward(self, G): 263 | # Compute equivariant weight basis from relative positions 264 | # if torch.cuda.is_available(): 265 | # torch.cuda.synchronize() 266 | #t1 = time.time() 267 | basis, r = get_basis_and_r(G, self.num_degrees-1) 268 | # if torch.cuda.is_available(): 269 | # torch.cuda.synchronize() 270 | # t2 = time.time() 271 | 272 | 273 | self.counter += 1 274 | # encoder (equivariant layers) 275 | h = {'0': G.ndata['f'], '1': G.ndata['f1']} 276 | for layer in self.Gblock: 277 | h = layer(h, G=G, r=r, basis=basis) 278 | 279 | # if torch.cuda.is_available(): 280 | # torch.cuda.synchronize() 281 | # t3 = time.time() 282 | 283 | # for layer in self.FCblock: 284 | # h = layer(h) 285 | 286 | # print("Counter %i \t Time get_basis_and_r: %.3f \t Time forward %.3f" % (self.counter, t2 - t1, t3 - t1)) 287 | 288 | # 289 | for key in h: 290 | #std = torch.mean(h[key] * h[key]) 291 | # std = torch.std(h[key]) 292 | # print("STD %.5f" % (std.item())) 293 | h[key] = h[key] * self.scalar_trick 294 | 295 | return h 296 | -------------------------------------------------------------------------------- /ponita/models/ponita.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch_geometric.nn import global_add_pool 4 | from ponita.utils.to_from_sphere import sphere_to_scalar, sphere_to_vec 5 | from ponita.nn.embedding import PolynomialFeatures 6 | from ponita.utils.windowing import PolynomialCutoff 7 | from ponita.transforms import PositionOrientationGraph, SEnInvariantAttributes 8 | from torch_geometric.transforms import Compose 9 | from torch_scatter import scatter_mean 10 | from ponita.nn.conv import Conv, FiberBundleConv 11 | from ponita.nn.convnext import ConvNext 12 | from torch_geometric.transforms import BaseTransform, Compose, RadiusGraph 13 | 14 | 15 | # Wrapper to automatically switch between point cloud mode (num_ori = -1 or 0) and 16 | # bundle mode (num_ori > 0). 
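# A minimal usage sketch (hypothetical sizes; `graph` stands for a
# torch_geometric Data object with pos/x/batch attributes as expected by
# PositionOrientationGraph):
#
#   model = Ponita(input_dim=9, hidden_dim=128, output_dim=1, num_layers=5,
#                  radius=1000., num_ori=20)                   # -> PonitaFiberBundle
#   model = Ponita(input_dim=9, hidden_dim=128, output_dim=1, num_layers=5,
#                  radius=1000., num_ori=-1, lift_graph=True)  # -> PonitaPointCloud
#   scalar_out, vector_out = model(graph)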
def Ponita(input_dim, hidden_dim, output_dim, num_layers, output_dim_vec=0, radius=None,
           num_ori=20, basis_dim=None, degree=3, widening_factor=4, layer_scale=None,
           task_level='graph', multiple_readouts=True, lift_graph=False, **kwargs):
    # Select either FiberBundle mode or PointCloud mode
    PonitaClass = PonitaFiberBundle if (num_ori > 0) else PonitaPointCloud
    # Return the ponita object
    return PonitaClass(input_dim, hidden_dim, output_dim, num_layers, output_dim_vec=output_dim_vec,
                       radius=radius, num_ori=num_ori, basis_dim=basis_dim, degree=degree,
                       widening_factor=widening_factor, layer_scale=layer_scale, task_level=task_level,
                       multiple_readouts=multiple_readouts, lift_graph=lift_graph, **kwargs)


class PonitaFiberBundle(nn.Module):
    """ SE(n) equivariant convolutional network in fiber bundle mode (a grid of orientations per node) """
    def __init__(self,
                 input_dim,
                 hidden_dim,
                 output_dim,
                 num_layers,
                 output_dim_vec=0,
                 radius=None,
                 num_ori=20,
                 basis_dim=None,
                 degree=3,
                 widening_factor=4,
                 layer_scale=None,
                 task_level='graph',
                 multiple_readouts=True,
                 **kwargs):
        super().__init__()

        # Input/output settings
        self.output_dim, self.output_dim_vec = output_dim, output_dim_vec
        self.global_pooling = (task_level == 'graph')

        # For constructing the position-orientation graph and its invariants
        self.transform = Compose([PositionOrientationGraph(num_ori), SEnInvariantAttributes(separable=True)])

        # Activation function to use internally
        act_fn = torch.nn.GELU()

        # Kernel basis functions and spatial window
        basis_dim = hidden_dim if (basis_dim is None) else basis_dim
        self.basis_fn = nn.Sequential(PolynomialFeatures(degree), nn.LazyLinear(hidden_dim), act_fn, nn.Linear(hidden_dim, basis_dim), act_fn)
        self.fiber_basis_fn = nn.Sequential(PolynomialFeatures(degree), nn.LazyLinear(hidden_dim), act_fn, nn.Linear(hidden_dim, basis_dim), act_fn)
        self.windowing_fn = PolynomialCutoff(radius)

        # Initial node embedding
        self.x_embedder = nn.Linear(input_dim, hidden_dim, bias=False)

        # Make feedforward network
        self.interaction_layers = nn.ModuleList()
        self.read_out_layers = nn.ModuleList()
        for i in range(num_layers):
            conv = FiberBundleConv(hidden_dim, hidden_dim, basis_dim, groups=hidden_dim, separable=True)
            layer = ConvNext(hidden_dim, conv, act=act_fn, layer_scale=layer_scale, widening_factor=widening_factor)
            self.interaction_layers.append(layer)
            if multiple_readouts or i == (num_layers - 1):
                self.read_out_layers.append(nn.Linear(hidden_dim, output_dim + output_dim_vec))
            else:
                self.read_out_layers.append(None)

    def forward(self, graph):

        # Lift and compute invariants
        graph = self.transform(graph)

        # Sample the kernel basis and window the spatial kernel with a smooth cut-off
        kernel_basis = self.basis_fn(graph.attr) * self.windowing_fn(graph.dists).unsqueeze(-2)
        fiber_kernel_basis = self.fiber_basis_fn(graph.fiber_attr)

        # Initial feature embedding
        x = self.x_embedder(graph.x)

        # Interaction + readout layers
        readouts = []
        for interaction_layer, readout_layer in zip(self.interaction_layers, self.read_out_layers):
            x = interaction_layer(x, graph.edge_index, edge_attr=kernel_basis, fiber_attr=fiber_kernel_basis, batch=graph.batch)
            if readout_layer is not None:
                readouts.append(readout_layer(x))
        readout = sum(readouts) / len(readouts)

        # Read out the scalar and vector part of the output
        readout_scalar, readout_vec = torch.split(readout, [self.output_dim, self.output_dim_vec], dim=-1)

        # Read out scalar and vector predictions
        output_scalar = self.scalar_readout_fn(readout_scalar, graph.batch)
        output_vector = self.vec_readout_fn(readout_vec, graph.ori_grid, graph.batch)

        # Return predictions
        return output_scalar, output_vector

    def scalar_readout_fn(self, readout_scalar, batch):
        if self.output_dim > 0:
            output_scalar = sphere_to_scalar(readout_scalar)
            if self.global_pooling:
                output_scalar = global_add_pool(output_scalar, batch)
        else:
            output_scalar = None
        return output_scalar

    def vec_readout_fn(self, readout_vec, ori_grid, batch):
        if self.output_dim_vec > 0:
            output_vector = sphere_to_vec(readout_vec, ori_grid)
            if self.global_pooling:
                output_vector = global_add_pool(output_vector, batch)
        else:
            output_vector = None
        return output_vector


class PonitaPointCloud(nn.Module):
    """ SE(n) equivariant convolutional network in point cloud mode (plain positions in R^n, or lifted position-orientation points) """
    def __init__(self,
                 input_dim,
                 hidden_dim,
                 output_dim,
                 num_layers,
                 output_dim_vec=0,
                 radius=None,
                 num_ori=-1,
                 basis_dim=None,
                 degree=3,
                 widening_factor=4,
                 layer_scale=None,
                 task_level='graph',
                 multiple_readouts=True,
                 lift_graph=False,
                 **kwargs):
        super().__init__()

        # Input/output settings
        self.output_dim, self.output_dim_vec = output_dim, output_dim_vec
        self.global_pooling = (task_level == 'graph')

        # For constructing the position-orientation graph and its invariants
        self.lift_graph = lift_graph
        if lift_graph:
            self.transform = Compose([PositionOrientationGraph(num_ori, radius), SEnInvariantAttributes(separable=False, point_cloud=True)])

        # Activation function to use internally
        act_fn = torch.nn.GELU()

        # Kernel basis functions and spatial window
        basis_dim = hidden_dim if (basis_dim is None) else basis_dim
        self.basis_fn = nn.Sequential(PolynomialFeatures(degree), nn.LazyLinear(hidden_dim), act_fn, nn.Linear(hidden_dim, basis_dim), act_fn)
        self.windowing_fn = PolynomialCutoff(radius)

        # Initial node embedding
        self.x_embedder = nn.Linear(input_dim, hidden_dim, bias=False)

        # Make feedforward network
        self.interaction_layers = nn.ModuleList()
        self.read_out_layers = nn.ModuleList()
        for i in range(num_layers):
            conv = Conv(hidden_dim, hidden_dim, basis_dim, groups=hidden_dim)
            layer = ConvNext(hidden_dim, conv, act=act_fn, layer_scale=layer_scale, widening_factor=widening_factor)
            self.interaction_layers.append(layer)
            if multiple_readouts or i == (num_layers - 1):
                self.read_out_layers.append(nn.Linear(hidden_dim, output_dim + output_dim_vec))
            else:
                self.read_out_layers.append(None)

    def forward(self, graph):

        # Lift and compute invariants
        if self.lift_graph:
            graph = self.transform(graph)

        # Sample the kernel basis and window the spatial kernel with a smooth cut-off
        kernel_basis = self.basis_fn(graph.attr) * self.windowing_fn(graph.dists)

        # Initial feature embedding
        x = self.x_embedder(graph.x)

        # Interaction + readout layers
        readouts = []
        for interaction_layer, readout_layer in zip(self.interaction_layers, self.read_out_layers):
            x = interaction_layer(x, graph.edge_index, edge_attr=kernel_basis, batch=graph.batch)
            if readout_layer is not None:
                readouts.append(readout_layer(x))
        readout = sum(readouts) / len(readouts)

        # Read out the scalar and vector part of the output
        readout_scalar, readout_vec = torch.split(readout, [self.output_dim, self.output_dim_vec], dim=-1)

        # Read out scalar and vector predictions (for a position-orientation cloud,
        # collect all predictions that share the same base point in R^n)
        if hasattr(graph, 'scatter_projection_index'):
            output_scalar = self.scalar_readout_fn(readout_scalar, graph.batch, graph.scatter_projection_index)
            output_vector = self.vec_readout_fn(readout_vec, graph.pos, graph.batch, graph.scatter_projection_index)
        else:
            output_scalar = readout_scalar
            if self.global_pooling:
                output_scalar = global_add_pool(output_scalar, graph.batch)
            output_vector = None

        # Return predictions
        return output_scalar, output_vector

    def scalar_readout_fn(self, readout_scalar, batch, scatter_projection_index):
        if self.output_dim > 0:
            # Aggregate predictions toward the base position in R^n
            output_scalar = scatter_mean(readout_scalar, scatter_projection_index, dim=0)
            if self.global_pooling:
                batch_Rn = scatter_mean(batch, scatter_projection_index, dim=0).type_as(batch)
                output_scalar = global_add_pool(output_scalar, batch_Rn)
        else:
            output_scalar = None
        return output_scalar

    def vec_readout_fn(self, readout_vec, pos, batch, scatter_projection_index):
        if self.output_dim_vec > 0:
            # Scale each orientation with the predicted scalar and aggregate via scatter_mean
            _, ori = pos.split(pos.shape[-1] // 2, dim=-1)
            output_vector = scatter_mean(readout_vec[:, :, None] * ori[:, None, :], scatter_projection_index, dim=0)
            if self.global_pooling:
                batch_Rn = scatter_mean(batch, scatter_projection_index, dim=0).type_as(batch)
                output_vector = global_add_pool(output_vector, batch_Rn)
        else:
            output_vector = None
        return output_vector
--------------------------------------------------------------------------------
/n_body_system/se3_dynamics/equivariant_attention/from_se3cnn/utils_steerable.py:
--------------------------------------------------------------------------------
import torch
import math
import numpy as np
from .SO3 import irr_repr, torch_default_dtype
from .cache_file import cached_dirpklgz
from .representations import SphericalHarmonics

################################################################################
# Solving the constraint coming from the stabilizer of 0 and e
################################################################################
def get_matrix_kernel(A, eps=1e-10):
    '''
    Compute an orthonormal basis of the kernel (x_1, x_2, ...) of A:
        A x_i = 0
        scalar_product(x_i, x_j) = delta_ij

    :param A: matrix
    :return: matrix where each row is a basis vector of the kernel of A
    '''
    _u, s, v = torch.svd(A)

    # A = u @ torch.diag(s) @ v.t()
    kernel = v.t()[s < eps]
    return kernel


def get_matrices_kernel(As, eps=1e-10):
    '''
    Computes the common kernel of all the As matrices
    '''
    return get_matrix_kernel(torch.cat(As, dim=0), eps)


@cached_dirpklgz("cache/trans_Q")
def _basis_transformation_Q_J(J, order_in, order_out, version=3):  # pylint: disable=W0613
    """
    :param J: order of the spherical harmonics
    :param order_in: order of the input representation
    :param order_out: order of the output representation
    :return: one part of the Q^-1 matrix of the article
    """
    with torch_default_dtype(torch.float64):
        def _R_tensor(a, b, c): return kron(irr_repr(order_out, a, b, c), irr_repr(order_in, a, b, c))

        def _sylvester_submatrix(J, a, b, c):
            ''' generate Kronecker product matrix for solving the Sylvester equation in subspace J '''
            R_tensor = _R_tensor(a, b, c)  # [m_out * m_in, m_out * m_in]
            R_irrep_J = irr_repr(J, a, b, c)  # [m, m]
            return kron(R_tensor, torch.eye(R_irrep_J.size(0))) - \
                kron(torch.eye(R_tensor.size(0)), R_irrep_J.t())  # [(m_out * m_in) * m, (m_out * m_in) * m]

        random_angles = [
            [4.41301023, 5.56684102, 4.59384642],
            [4.93325116, 6.12697327, 4.14574096],
            [0.53878964, 4.09050444, 5.36539036],
            [2.16017393, 3.48835314, 5.55174441],
            [2.52385107, 0.2908958, 3.90040975]
        ]
        null_space = get_matrices_kernel([_sylvester_submatrix(J, a, b, c) for a, b, c in random_angles])
        assert null_space.size(0) == 1, null_space.size()  # unique subspace solution
        Q_J = null_space[0]  # [(m_out * m_in) * m]
        Q_J = Q_J.view((2 * order_out + 1) * (2 * order_in + 1), 2 * J + 1)  # [m_out * m_in, m]
        assert all(torch.allclose(_R_tensor(a, b, c) @ Q_J, Q_J @ irr_repr(J, a, b, c)) for a, b, c in torch.rand(4, 3))

    assert Q_J.dtype == torch.float64
    return Q_J  # [m_out * m_in, m]
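# Note on caching: the @cached_dirpklgz("cache/trans_Q") decorator memoizes the
# expensive null-space solve on disk, so each (J, order_in, order_out) triple is
# computed only once across runs.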
# @profile
def get_spherical_from_cartesian_torch(cartesian, divide_radius_by=1.0):

    ###################################################################################################################
    # ON ANGLE CONVENTION
    #
    # sh has the following convention for angles:
    # :param theta: the colatitude / polar angle, ranging from 0 (North Pole, (X, Y, Z) = (0, 0, 1)) to pi (South Pole, (X, Y, Z) = (0, 0, -1)).
    # :param phi: the longitude / azimuthal angle, ranging from 0 to 2 pi.
    #
    # the 3D steerable CNN code therefore (probably) has the following convention for alpha and beta:
    # beta = pi - theta; ranging from 0 (South Pole, (X, Y, Z) = (0, 0, -1)) to pi (North Pole, (X, Y, Z) = (0, 0, 1)).
    # alpha = phi
    ###################################################################################################################

    # initialise return array
    spherical = torch.zeros_like(cartesian)

    # indices for the return array
    ind_radius = 0
    ind_alpha = 1
    ind_beta = 2

    # the input components are interpreted in (y, z, x) order
    cartesian_x = 2
    cartesian_y = 0
    cartesian_z = 1

    # get projected radius in the xy plane
    r_xy = cartesian[..., cartesian_x] ** 2 + cartesian[..., cartesian_y] ** 2

    # get second angle; version 'elevation angle defined from Z-axis down'
    spherical[..., ind_beta] = torch.atan2(torch.sqrt(r_xy), cartesian[..., cartesian_z])

    # get angle in the x-y plane
    spherical[..., ind_alpha] = torch.atan2(cartesian[..., cartesian_y], cartesian[..., cartesian_x])

    # get overall radius
    if divide_radius_by == 1.0:
        spherical[..., ind_radius] = torch.sqrt(r_xy + cartesian[..., cartesian_z] ** 2)
    else:
        spherical[..., ind_radius] = torch.sqrt(r_xy + cartesian[..., cartesian_z] ** 2) / divide_radius_by

    return spherical
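# Worked example of the axis remapping above (an illustration, not from the repo):
# with input component order (y, z, x), cartesian = [0., 1., 0.] is the north
# pole (z = 1) and maps to spherical (radius, alpha, beta) = (1, 0, 0).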
# @profile
def get_spherical_from_cartesian(cartesian):

    # ON ANGLE CONVENTION: see get_spherical_from_cartesian_torch above;
    # this is the numpy counterpart with the same conventions.

    if torch.is_tensor(cartesian):
        cartesian = np.array(cartesian.cpu())

    # initialise return array
    spherical = np.zeros(cartesian.shape)

    # indices for the return array
    ind_radius = 0
    ind_alpha = 1
    ind_beta = 2

    # the input components are interpreted in (y, z, x) order
    cartesian_x = 2
    cartesian_y = 0
    cartesian_z = 1

    # get projected radius in the xy plane
    r_xy = cartesian[..., cartesian_x] ** 2 + cartesian[..., cartesian_y] ** 2

    # get overall radius
    spherical[..., ind_radius] = np.sqrt(r_xy + cartesian[..., cartesian_z] ** 2)

    # get second angle; version 'elevation angle defined from Z-axis down'
    spherical[..., ind_beta] = np.arctan2(np.sqrt(r_xy), cartesian[..., cartesian_z])

    # get angle in the x-y plane
    spherical[..., ind_alpha] = np.arctan2(cartesian[..., cartesian_y], cartesian[..., cartesian_x])

    return spherical


def test_coordinate_conversion():
    # p = (y, z, x) = (0, 0, -1) lies on the negative x-axis:
    # radius 1, alpha = atan2(0, -1) = pi, beta = atan2(1, 0) = pi/2
    p = np.array([0, 0, -1])
    expected = np.array([1, np.pi, np.pi / 2])
    assert np.allclose(get_spherical_from_cartesian(p), expected)
    return True


def spherical_harmonics(order, alpha, beta, dtype=None):
    """
    spherical harmonics
    - compatible with irr_repr and compose

    computation time: executing 1000 times with array length 1 took 0.29 seconds;
    executing it once with an array of length 1000 took 0.0022 seconds
    """
    # Y should have dimension 2*order + 1
    return SphericalHarmonics.get(order, theta=math.pi - beta, phi=alpha)


# @profile
def kron(a, b):
    """
    A part of the pylabyk library: numpytorch.py at https://github.com/yulkang/pylabyk

    Kronecker product of matrices a and b with leading batch dimensions.
    Batch dimensions are broadcast; the number of them must match.
    :type a: torch.Tensor
    :type b: torch.Tensor
    :rtype: torch.Tensor
    """
    siz1 = torch.Size(torch.tensor(a.shape[-2:]) * torch.tensor(b.shape[-2:]))
    res = a.unsqueeze(-1).unsqueeze(-3) * b.unsqueeze(-2).unsqueeze(-4)
    siz0 = res.shape[:-4]
    return res.reshape(siz0 + siz1)
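# Shape sketch for kron (illustrative): a with shape [B, 2, 3] and b with shape
# [B, 4, 5] give kron(a, b) of shape [B, 8, 15], i.e. the Kronecker product of
# the last two dimensions with batch broadcasting.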
def get_maximum_order_unary_only(per_layer_orders_and_multiplicities):
    """
    determine what spherical harmonics we need to pre-compute. if we have the
    unary term only, we need to compare all adjacent layers

    the spherical harmonics function depends purely on J (irrep order), which is defined by
    order_irreps = list(range(abs(order_in - order_out), order_in + order_out + 1))
    simplification: we only care about the maximum (in some circumstances that means we calculate a few lower
    order spherical harmonics which we won't actually need)

    :param per_layer_orders_and_multiplicities: nested list of lists of 2-tuples
    :return: integer indicating maximum order J
    """

    n_layers = len(per_layer_orders_and_multiplicities)

    # extract orders only
    per_layer_orders = []
    for i in range(n_layers):
        cur = per_layer_orders_and_multiplicities[i]
        cur = [o for (m, o) in cur]
        per_layer_orders.append(cur)

    track_max = 0
    # compare two (adjacent) layers at a time
    for i in range(n_layers - 1):
        cur = per_layer_orders[i]
        nex = per_layer_orders[i + 1]
        track_max = max(max(cur) + max(nex), track_max)

    return track_max


def get_maximum_order_with_pairwise(per_layer_orders_and_multiplicities):
    """
    determine what spherical harmonics we need to pre-compute. for pairwise
    interactions, this will just be twice the maximum order

    the spherical harmonics function depends purely on J (irrep order), which is defined by
    order_irreps = list(range(abs(order_in - order_out), order_in + order_out + 1))
    simplification: we only care about the maximum (in some circumstances that means we calculate a few lower
    order spherical harmonics which we won't actually need)

    :param per_layer_orders_and_multiplicities: nested list of lists of 2-tuples
    :return: integer indicating maximum order J
    """

    n_layers = len(per_layer_orders_and_multiplicities)

    track_max = 0
    for i in range(n_layers):
        cur = per_layer_orders_and_multiplicities[i]
        # extract orders only
        orders = [o for (m, o) in cur]
        track_max = max(track_max, max(orders))

    return 2 * track_max


def precompute_sh(r_ij, max_J):
    """
    pre-compute spherical harmonics up to order max_J

    :param r_ij: relative positions
    :param max_J: maximum order used in the entire network
    :return: dict where each entry has shape [B, N, K, 2J+1]
    """

    i_distance = 0
    i_alpha = 1
    i_beta = 2

    Y_Js = {}
    sh = SphericalHarmonics()

    for J in range(max_J + 1):
        # dimension [B, N, K, 2J+1]
        Y_Js[J] = sh.get(J, theta=math.pi - r_ij[..., i_beta], phi=r_ij[..., i_alpha], refresh=False)

    sh.clear()
    return Y_Js
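# Usage sketch (hypothetical shapes): with r_ij of shape [B, N, K, 3] holding
# (distance, alpha, beta) per neighbor, precompute_sh(r_ij, max_J=2) returns
# {0: [B, N, K, 1], 1: [B, N, K, 3], 2: [B, N, K, 5]}.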
class ScalarActivation3rdDim(torch.nn.Module):
    def __init__(self, n_dim, activation, bias=True):
        '''
        Can only be used with scalar fields [B, N, s] on the last dimension

        :param n_dim: number of scalar fields to apply the activation to
        :param bool bias: add a bias before applying the activation
        '''
        super().__init__()

        self.activation = activation

        if bias and n_dim > 0:
            self.bias = torch.nn.Parameter(torch.zeros(n_dim))
        else:
            self.bias = None

    def forward(self, input):
        '''
        :param input: [B, N, s]
        '''

        assert input.dim() == 3

        if self.bias is not None:
            x = input + self.bias.view(1, 1, -1)
        else:
            x = input
        x = self.activation(x)

        return x
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# ✨ 🐴 🔥 PONITA

ACCEPTED AT [ICLR 2024](https://openreview.net/forum?id=dPHLbUqGbr)!

MINIMAL DEPENDENCY PYTORCH IMPLEMENTATION CAN BE FOUND [HERE](https://github.com/ebekkers/ponita-torch)

JAX IMPLEMENTATION CAN BE FOUND [HERE](https://github.com/ebekkers/ponita-jax) (work in progress)

## What is this repository about?
This repository contains the code for the paper [Fast, Expressive SE(n) Equivariant Networks through Weight-Sharing in Position-Orientation Space](https://arxiv.org/abs/2310.02970). We propose **PONITA**: a simple, fully convolutional SE(n) equivariant architecture. We developed it primarily for 3D point-cloud data, but the method is also applicable to 2D point clouds and 2D/3D images/volumes (though not yet with this repo). PONITA is an equivariant model that does not require working with steerable/Clebsch-Gordan methods, but has the same capabilities in that __it can handle scalars and vectors__ equally well. Moreover, since it does not depend on Clebsch-Gordan tensor products, __PONITA is much faster__ than the typical steerable/tensor field network!

See below for results and code for benchmarks on 2D point clouds (**super-pixel MNIST**) and on 3D point clouds with vector attributes (**n-body**) and without (**MD17**), as well as an example of position-orientation space point clouds (**QM9**)! Results for equivariant generative modeling are in the paper (which will soon be updated with the MNIST and QM9 regression results as presented below).

## About the name
PONITA is an acronym for Position-Orientation space Networks based on InvarianT Attributes. We believe this acronym is apt for the method for two reasons. Firstly, PONITA sounds like "bonita" ✨, which means pretty in Spanish; we personally think the architecture is pretty and elegant. Secondly, [Ponyta](https://bulbapedia.bulbagarden.net/wiki/Ponyta_(Pok%C3%A9mon)) 🐴 🔥 is a fire Pokémon known to be very fast; our method is fast as well.

## About the implementation

### The ponita model
The PONITA model is provided in ```ponita/models/ponita.py``` and is designed to support point clouds in 2D and 3D position space $\mathbb{R}^2$ and $\mathbb{R}^3$ out of the box. It furthermore supports input point clouds in 2D and 3D position-orientation spaces $\mathbb{R}^2 \times S^1$ and $\mathbb{R}^3 \times S^2$.

The forward assumes a graph object with attributes
* ```graph.x```: the scalar features at each node,
* ```graph.pos```: their corresponding positions,
* ```graph.y```: the targets at node level (```task_level='node'```) or at graph level (```task_level='graph'```), and
* ```graph.edge_index```: which defines how each node is connected to the others (source to target). The model also has a ```radius``` option which, when specified, is used to construct a radius graph that will overwrite (or define) the ```edge_index```; it is also used to smoothly mask the message functions.

Optionally, the model can also take as input
* ```graph.vec```: vector features attached to each node.

The model will always return **two outputs**: scalar ```output_scalar``` and vector ```output_vector```; if no scalar or vector outputs are specified, the corresponding output is ```None```.
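Putting this together, a minimal usage sketch (the dimensions, ```radius```, and ```num_ori``` values below are illustrative, and the snippet has not been run against this repo):
```
import torch
from torch_geometric.data import Data
from ponita.models import Ponita

# Fiber bundle mode: 20 orientations per node, graph-level scalar + vector outputs
model = Ponita(input_dim=8, hidden_dim=64, output_dim=1, num_layers=4,
               output_dim_vec=1, radius=2.0, num_ori=20, task_level='graph')

num_nodes = 30
graph = Data(
    x=torch.randn(num_nodes, 8),                     # scalar features per node
    pos=torch.randn(num_nodes, 3),                   # 3D positions
    vec=torch.randn(num_nodes, 1, 3),                # optional vector features
    batch=torch.zeros(num_nodes, dtype=torch.long),  # single-graph batch
)
output_scalar, output_vector = model(graph)          # either may be None
```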
### The ponita layers

```PositionOrientationGraph(num_ori, radius)```, found in ```ponita\transforms\position_orientation_graph.py```, is a torch geometric transform that turns the input graph into a graph in position-orientation space. It has two modes:
* **Fiber bundle mode**: When ```num_ori > 0```, a regular fiber bundle (assigning a grid of ```num_ori``` orientations per position) is generated from the provided input point cloud in $\mathbb{R}^2$ or $\mathbb{R}^3$. Subsequently, the separable conv layers as described in the main body of the paper can be used. The fiber bundle interpretation is described in Appendix A.
* **Point cloud mode**:
  * When ```num_ori = 0```, the graph is treated as a point cloud on $\mathbb{R}^n$. Nothing really happens in this transform except for defining some variables for consistency with the position-orientation graphs.
  * When ```num_ori = -1```, the graph is lifted to a position-orientation graph using the provided ```graph.edge_index```. In this transformation each edge becomes a node with a certain position (the starting point of the edge) and orientation (the normalized direction vector from source to target position). An object ```graph.scatter_projection_index``` is also generated, to be able to reduce the lifted graph back to the original position point cloud representation. Note that this setting **requires** ```graph.edge_index``` to be present in the graph.

----
```SEnInvariantAttributes(separable, point_cloud)```, found in ```ponita\transforms\invariants.py```, is a torch geometric transform that adds invariant attributes to the graph. Only in fiber bundle mode can ```separable``` take on the value ```True```, as only then can spatial interactions be separated from those within the fiber. The ```separable``` option can also be set to ```False``` in bundle mode to do a full convolution over position-orientation space, but this requires quite some compute and memory. The option ```point_cloud``` switches between bundle mode (```point_cloud=False```) and point cloud mode (```point_cloud=True```). So, we have the following settings:
* ```separable=True, point_cloud=False```: generates ```graph.attr``` (shape = ```[num_edges, num_ori, 2]```) and ```graph.fiber_attr``` (shape = ```[num_ori, num_ori, 1]```).
* ```separable=False, point_cloud=False```: generates ```graph.attr``` (shape = ```[num_edges, num_ori, num_ori, 3]```)
* ```separable=False, point_cloud=True```: generates ```graph.attr``` (shape = ```[num_edges, d]```) with ```d=1``` for position graphs and ```d=3``` for position-orientation graphs in 2D or 3D.

----
```Conv(in_channels, out_channels, attr_dim, bias=True, aggr="add", groups=1)```, found in ```ponita\nn\conv.py```, implements convolution as a ```torch_geometric.nn.MessagePassing``` module. Considering the above graph constructions, this module operates on graphs in point cloud mode. The forward requires a graph with ```graph.x``` (the features), ```graph.edge_index```, and ```graph.attr``` (the pair-wise invariants, or embeddings thereof).
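As a sketch of how these transforms are used (this mirrors the ```Compose``` call inside ```PonitaFiberBundle```; the ```num_ori``` value is illustrative):
```
from torch_geometric.transforms import Compose
from ponita.transforms import PositionOrientationGraph, SEnInvariantAttributes

# Lift an R^n point cloud to a fiber bundle and attach the invariant attributes
lift = Compose([PositionOrientationGraph(num_ori=20),
                SEnInvariantAttributes(separable=True)])
graph = lift(graph)
# graph.attr:       [num_edges, num_ori, 2]   spatial invariants
# graph.fiber_attr: [num_ori, num_ori, 1]     within-fiber invariants
```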
----
```FiberBundleConv(in_channels, out_channels, attr_dim, bias=True, aggr="add", separable=True, groups=1)```, also found in ```ponita\nn\conv.py```, implements convolution as a ```torch_geometric.nn.MessagePassing``` module over the fiber bundle. Considering the above graph constructions, this module operates on graphs in fiber bundle mode (a grid of orientations at each node). The forward requires a graph with ```graph.x``` (the features), ```graph.edge_index``` (connecting the spatial base points), and ```graph.attr``` and ```graph.fiber_attr``` (the pair-wise spatial and orientation attributes, or embeddings thereof); see SEnInvariantAttributes.

----
```ConvNext(channels, conv, act=torch.nn.GELU(), layer_scale=1e-6, widening_factor=4)```, found in ```ponita\nn\convnext.py```, is a wrapper that turns the convolution layer ```conv``` into a ConvNext block.
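For intuition, building one block by hand looks roughly as follows (this mirrors the layer loop in ```PonitaPointCloud```; the dimensions are illustrative):
```
import torch.nn as nn
from ponita.nn.conv import Conv
from ponita.nn.convnext import ConvNext

hidden_dim, basis_dim = 128, 128
conv = Conv(hidden_dim, hidden_dim, basis_dim, groups=hidden_dim)  # depth-wise message passing
block = ConvNext(hidden_dim, conv, act=nn.GELU(), layer_scale=1e-6, widening_factor=4)
# x: [num_nodes, hidden_dim]; kernel_basis: an embedding of graph.attr
x = block(x, graph.edge_index, edge_attr=kernel_basis, batch=graph.batch)
```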
## Reproducing PONITA results

### MD17 (3D point clouds)

The paper results can be reproduced by running the following command for rMD17:

```python3 main_md17.py```

We did a sweep over all "revised *" targets and the seeds 0, 1, 2; otherwise the default settings in ```main_md17.py``` are used. Setting ```num_ori=0``` generates the PNITA results; ```num_ori=20``` generates the PONITA results. The table shows mean absolute errors (energies E in meV, forces F in meV/Å) compared to one of the seminal works on *steerable* equivariant convolutions: [NEQUIP](https://github.com/mir-group/nequip). Our methods, PNITA for position space convolutions and PONITA for position-orientation space convolutions, are based on invariants only. For comparison to other state-of-the-art methods, see our paper.

| Target | | NEQUIP | PNITA | PONITA
|-|-|-|-|-
| Aspirin | E | 2.3 | 4.7 | **1.7**
| | F | 8.2 | 16.3 | **5.8**
| Azobenzene | E | 0.7 | 3.2 | **0.7**
| | F | 2.9 | 12.2 | **2.3**
| Benzene | E | **0.04** | 0.2 | 0.17
| | F | **0.3** | 0.4 | **0.3**
| Ethanol | E | **0.4** | 0.7 | **0.4**
| | F | 2.8 | 4.1 | **2.5**
| Malonaldehyde | E | 0.8 | 0.9 | **0.6**
| | F | 5.1 | 5.1 | **4.0**
| Naphthalene | E | **0.2** | 1.1 | 0.3
| | F | **1.3** | 5.6 | **1.3**
| Paracetamol | E | 1.4 | 2.8 | **1.1**
| | F | 5.9 | 11.4 | **4.3**
| Salicylic acid | E | 0.7 | 1.7 | **0.7**
| | F | 4.0 | 8.6 | **3.3**
| Toluene | E | 0.3 | 0.6 | **0.3**
| | F | 1.6 | 3.4 | **1.3**
| Uracil | E | 0.4 | 0.9 | **0.4**
| | F | 3.1 | 5.6 | **2.4**
____
### N-body (3D point clouds with input vectors)

For the n-body experiments run

```python3 main_nbody.py```

We did a sweep over seeds 0, 1, 2 with the default parameters in ```main_nbody.py``` fixed. In this case we cannot set ```num_ori=0```, as we need to be able to handle input vectors. The mean squared error on the predicted future positions of the particles is given below. Here we compare to the seminal works on this problem which use invariant feature representations ([EGNN](https://github.com/vgsatorras/egnn)) and steerable representations ([SEGNN](https://github.com/RobDHess/Steerable-E3-GNN)). Again, in contrast to the steerable method SEGNN, ours works with invariants only and thus does not require specialized tensor product operations.

| Method | MSE
|-|-
| EGNN | 0.0070
| SEGNN | **0.0043**
| PONITA | **0.0043**

____
### QM9 (3D point clouds or position-orientation point clouds)

_The Equivariant Denoising Diffusion Experiments will be added soon!_

We also tested the PONITA architecture for molecular property prediction. For the QM9 regression experiments run

```python3 main_qm9.py```

Since QM9 provides an ```edge_index``` derived from the covalent bonds, we have an option to treat the **molecules as point clouds in position-orientation space**. This is what we did in the PONITA entry below, where we specified ```num_ori=-1``` in the code. As a baseline we compare against the seminal [DimeNet++](https://github.com/gasteigerjo/dimenet), which, like PONITA, processes the molecules via message passing over the edges. For PNITA, the position space method, we found that going too deep hurts performance; we obtained the best performance with ```layers=5``` and ```hidden_dim=128```. For PONITA we obtained the best results with ```layers=9``` and ```hidden_dim=256```. Further, we compared the $\mathbb{R}^3 \times S^2$ point cloud approach (```num_ori=-1```) with the fiber-bundle (spherical grid-based) approach with ```num_ori=16```. Otherwise the settings for the PNITA and PONITA results were the same, and all used a fully connected graph (```radius=1000```). The results are as follows.

| Target | Unit | DimeNet++ | PNITA | PONITA (```num_ori=-1```) | PONITA (```num_ori=16```)
|-|-|-|-|-|-
| $\mu$ | D | 0.0286 | 0.0207 | **0.0115** | 0.0121
| $\alpha$ | $a_0^3$ | 0.0469 | 0.0602 | 0.0446 | **0.0375**
| $\epsilon_{HOMO}$ | meV | 27.8 | 26.1 | 18.6 | **16.0**
| $\epsilon_{LUMO}$ | meV | 19.7 | 21.9 | 15.3 | **14.5**
| $\Delta \epsilon$ | meV | 34.8 | 43.4 | 33.5 | **30.4**
| $\langle R^2 \rangle$ | $a_0^2$ | 0.331 | **0.149** | 0.227 | 0.235
| ZPVE | meV | **1.29** | 1.53 | **1.29** | **1.29**
| $U_0$ | meV | **8.02** | 10.71 | 9.20 | 8.31
| $U$ | meV | **7.89** | 10.63 | 9.00 | 8.67
| $H$ | meV | 8.11 | 11.00 | 8.54 | **8.04**
| $G$ | meV | 0.0249 | 0.0112 | 0.0095 | **0.00863**
| $c_v$ | $\frac{\mathrm{cal}}{\mathrm{mol} \, \mathrm{K}}$ | **0.0249** | 0.0307 | 0.0250 | **0.0242**

____
### Super-pixel MNIST (2D point clouds)

To showcase the capability of the code to also handle 2D data, we tested on super-pixel MNIST.

```python3 main_mnist.py```

As a baseline for super-pixel MNIST we take Probabilistic Numeric Convolutional Neural Networks ([PNCNN](https://github.com/Qualcomm-AI-research/ProbabilisticNumericCNNs)), which at the time of writing is the state of the art on this task. For PONITA we use the fiber bundle mode with ```num_ori=10```. For both PNITA and PONITA we utilize the provided edge_index.
The results are as follows

| Method | Error Rate
|-|-
| PNCNN | 1.24 $\pm$ 0.12
| PNITA | 3.04 $\pm$ 0.09
| PONITA | **1.17** $\pm$ 0.11

## Conda environment
In order to run the code in this repository, install the following conda environment
```
conda create --yes --name ponita python=3.10 numpy scipy matplotlib
conda activate ponita
conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.7 -c pytorch -c nvidia -y
conda install pyg==2.3.1 -c pyg -y
pip3 install wandb
pip3 install pytorch_lightning==1.8.6
pip3 install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-1.13.1+cu117.html
```


## Acknowledgements
The experimental setup builds upon the code bases of the [EGNN repository](https://github.com/vgsatorras/egnn) and the [EDM repository](https://github.com/ehoogeboom/e3_diffusion_for_molecules). The grid construction code is adapted from the [Regular SE(3) Group Convolution](https://github.com/ThijsKuipers1995/gconv) library. We deeply thank the authors for open-sourcing their code. We are also very grateful to the developers of the amazing libraries [torch geometric](https://pytorch-geometric.readthedocs.io/en/latest/index.html), [pytorch lightning](https://lightning.ai/), and [weights and biases](https://wandb.ai/)!
--------------------------------------------------------------------------------