├── LICENSE ├── README.md ├── cace ├── __init__.py ├── calculators │ ├── __init__.py │ └── cace_calculator.py ├── data │ ├── __init__.py │ ├── atomic_data.py │ └── neighborhood.py ├── models │ ├── __init__.py │ ├── atomistic.py │ └── combined.py ├── modules │ ├── __init__.py │ ├── angular.py │ ├── angular_tools.py │ ├── atomwise.py │ ├── blocks.py │ ├── cutoff.py │ ├── ewald.py │ ├── feature_mix.py │ ├── forces.py │ ├── grad.py │ ├── interaction.py │ ├── les_wrapper.py │ ├── polarization.py │ ├── preprocess.py │ ├── product_basis.py │ ├── radial.py │ ├── radial_transform.py │ ├── symmetrize_basis.py │ ├── transform.py │ ├── type.py │ └── utils.py ├── representations │ ├── __init__.py │ └── cace_representation.py ├── tasks │ ├── __init__.py │ ├── evaluate.py │ ├── lightning.py │ ├── load_data.py │ ├── loss.py │ └── train.py └── tools │ ├── __init__.py │ ├── io_utils.py │ ├── metric.py │ ├── output.py │ ├── parser_train.py │ ├── scatter.py │ ├── torch_geometric │ ├── README.md │ ├── __init__.py │ ├── batch.py │ ├── data.py │ ├── dataloader.py │ ├── dataset.py │ └── utils.py │ ├── torch_tools.py │ └── utils.py ├── examples ├── water_train.py ├── water_train_pl.py └── water_train_w_ft_pl.py ├── scripts ├── cace_alchemical_rep_v0.pth ├── compute_avg_e0.py ├── compute_cace_desc.py ├── split.py └── train.py ├── setup.py └── tests ├── ewald_nopbc.py ├── test-cace-representation-rotation.ipynb ├── test_angular.py ├── test_edge_encoder.py ├── test_elementwise_multiply_tensors.py ├── test_ewald.py ├── test_ewald_triclinic.py ├── test_neighborhood.py └── test_pol_triclinic.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Bingqing Cheng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Cartesian Atomic Cluster Expansion for Machine Learning Interatomic Potentials (CACE) 2 | 3 | ## Summary 4 | 5 | The Cartesian Atomic Cluster Expansion (CACE) is a new approach for developing machine learning interatomic potentials. This method utilizes Cartesian coordinates to provide a complete description of atomic environments, maintaining interaction body orders. It integrates low-dimensional embeddings of chemical elements with inter-atomic message passing. 
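As a quick illustration, a trained CACE model can be attached to an ASE `Atoms` object through the `CACECalculator`. The sketch below is only indicative: the model and structure file names are placeholders, and the output keys should be adjusted to whatever the trained model actually produces (see `scripts/train.py`).

```python
# Minimal sketch: evaluating a trained CACE model through ASE.
# 'water_model.pth' and 'water.xyz' are placeholder file names.
from ase.io import read
from cace.calculators import CACECalculator

atoms = read('water.xyz')
atoms.calc = CACECalculator(
    model_path='water_model.pth',   # a model saved with torch.save
    device='cpu',                   # or 'cuda'
    energy_units_to_eV=1.0,
    length_units_to_A=1.0,
    energy_key='CACE_energy',       # match these to your model's output keys
    forces_key='CACE_forces',
)
print(atoms.get_potential_energy())
print(atoms.get_forces())
```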
6 | 7 | ## Requirements 8 | 9 | - Python 3.6 or higher 10 | - NumPy 11 | - ASE (Atomic Simulation Environment) 12 | - PyTorch 13 | - matscipy 14 | 15 | ## Installation 16 | 17 | Please refer to the `setup.py` file for installation instructions. 18 | 19 | ## Usage 20 | 21 | Please refer to the `scripts/train.py`. 22 | 23 | More example scripts can be found in [https://github.com/BingqingCheng/cacefit]. 24 | 25 | ## License 26 | 27 | This project is licensed under the MIT License - see the LICENSE file for details. 28 | 29 | ## Citation 30 | 31 | ```text 32 | @article{cheng2024cartesian, 33 | title={Cartesian atomic cluster expansion for machine learning interatomic potentials}, 34 | author={Cheng, Bingqing}, 35 | journal={npj Computational Materials}, 36 | volume={10}, 37 | number={1}, 38 | pages={157}, 39 | year={2024}, 40 | publisher={Nature Publishing Group UK London} 41 | } 42 | ``` 43 | 44 | ## Contact 45 | 46 | For any queries regarding CACE, please contact Bingqing Cheng at tonicbq@gmail.com. 47 | 48 | -------------------------------------------------------------------------------- /cace/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['data', 'modules', 'tools', 'representations', 'models', 'tasks', 'calculators'] 2 | 3 | from . import data 4 | from . import modules 5 | from . import tools 6 | from . import representations 7 | from . import models 8 | from . import tasks 9 | from . import calculators 10 | -------------------------------------------------------------------------------- /cace/calculators/__init__.py: -------------------------------------------------------------------------------- 1 | from .cace_calculator import CACECalculator 2 | -------------------------------------------------------------------------------- /cace/calculators/cace_calculator.py: -------------------------------------------------------------------------------- 1 | # the CACE calculator for ASE 2 | 3 | from typing import Union, List 4 | 5 | import numpy as np 6 | import torch 7 | 8 | from ase.calculators.calculator import Calculator, all_changes 9 | from ase.stress import full_3x3_to_voigt_6_stress 10 | 11 | from ..tools import torch_geometric, torch_tools, to_numpy 12 | from ..data import AtomicData 13 | 14 | __all__ = ["CACECalculator"] 15 | 16 | class CACECalculator(Calculator): 17 | """CACE ASE Calculator 18 | args: 19 | model_path: str or nn.module, path to model 20 | device: str, device to run on (cuda or cpu) 21 | compute_stress: bool, whether to compute stress 22 | energy_key: str, key for energy in model output 23 | forces_key: str, key for forces in model output 24 | energy_units_to_eV: float, conversion factor from model energy units to eV 25 | length_units_to_A: float, conversion factor from model length units to Angstroms 26 | atomic_energies: dict, dictionary of atomic energies to add to model output 27 | """ 28 | 29 | def __init__( 30 | self, 31 | model_path: Union[str, torch.nn.Module], 32 | device: str, 33 | energy_units_to_eV: float = 1.0, 34 | length_units_to_A: float = 1.0, 35 | electric_field_unit: float = 1.0, 36 | compute_stress = False, 37 | energy_key: str = 'energy', 38 | forces_key: str = 'forces', 39 | stress_key: str = 'stress', 40 | bec_key: str = 'bec', 41 | external_field: Union[float,List[float]] = None, 42 | keep_neutral: bool = True, # to keep BEC sum to be neutral 43 | atomic_energies: dict = None, 44 | output_index: int = None, # only used for multi-output models 45 | **kwargs, 46 | ): 47 | 48 | 
Calculator.__init__(self, **kwargs) 49 | self.implemented_properties = [ 50 | "energy", 51 | "forces", 52 | "stress", 53 | ] 54 | 55 | self.results = {} 56 | 57 | if isinstance(model_path, str): 58 | self.model = torch.load(f=model_path, map_location=device) 59 | elif isinstance(model_path, torch.nn.Module): 60 | self.model = model_path 61 | else: 62 | raise ValueError("model_path must be a string or nn.Module") 63 | self.model.to(device) 64 | 65 | self.device = torch_tools.init_device(device) 66 | self.energy_units_to_eV = energy_units_to_eV 67 | self.length_units_to_A = length_units_to_A 68 | self.electric_field_unit = electric_field_unit 69 | 70 | try: 71 | self.cutoff = self.model.representation.cutoff 72 | except AttributeError: 73 | self.cutoff = self.model.models[0].representation.cutoff 74 | 75 | self.atomic_energies = atomic_energies 76 | 77 | self.compute_stress = compute_stress 78 | self.energy_key = energy_key 79 | self.forces_key = forces_key 80 | self.stress_key = stress_key 81 | self.bec_key = bec_key 82 | self.keep_neutral = keep_neutral 83 | 84 | if external_field is not None: 85 | if isinstance(external_field, float): 86 | self.external_field = external_field 87 | else: 88 | self.external_field = np.array(external_field) 89 | else: 90 | self.external_field = None 91 | self.output_index = output_index 92 | 93 | for param in self.model.parameters(): 94 | param.requires_grad = False 95 | 96 | def calculate(self, atoms=None, properties=None, system_changes=all_changes): 97 | """ 98 | Calculate properties. 99 | :param atoms: ase.Atoms object 100 | :param properties: [str], properties to be computed, used by ASE internally 101 | :param system_changes: [str], system changes since last calculation, used by ASE internally 102 | :return: 103 | """ 104 | # call to base-class to set atoms attribute 105 | Calculator.calculate(self, atoms) 106 | 107 | if not hasattr(self, "output_index"): 108 | self.output_index = None 109 | 110 | # prepare data 111 | data_loader = torch_geometric.dataloader.DataLoader( 112 | dataset=[ 113 | AtomicData.from_atoms( 114 | atoms, cutoff=self.cutoff 115 | ) 116 | ], 117 | batch_size=1, 118 | shuffle=False, 119 | drop_last=False, 120 | ) 121 | 122 | batch_base = next(iter(data_loader)).to(self.device) 123 | batch = batch_base.clone() 124 | output = self.model(batch.to_dict(), training=True, compute_stress=self.compute_stress, output_index=self.output_index) 125 | energy_output = to_numpy(output[self.energy_key]) 126 | forces_output = to_numpy(output[self.forces_key]) 127 | if self.external_field is not None and self.bec_key is not None: 128 | bec_output = to_numpy(output[self.bec_key]) 129 | # subtract atomic energies if available 130 | if self.atomic_energies: 131 | e0 = sum(self.atomic_energies.get(Z, 0) for Z in atoms.get_atomic_numbers()) 132 | else: 133 | e0 = 0.0 134 | self.results["energy"] = (energy_output + e0) * self.energy_units_to_eV 135 | self.results["forces"] = forces_output * self.energy_units_to_eV / self.length_units_to_A 136 | if self.external_field is not None: 137 | if isinstance(self.external_field, float): 138 | if self.keep_neutral: 139 | correction = -np.average(bec_output, axis=0) 140 | bec_output += correction 141 | forces_bec = bec_output * self.external_field * self.electric_field_unit 142 | self.results["forces"] += forces_bec # bec_output * self.external_field * self.electric_field_unit 143 | else: 144 | if self.keep_neutral: 145 | correction = -np.average(bec_output, axis=0) 146 | bec_output += correction 147 | forces_bec = 
bec_output @ self.external_field * self.electric_field_unit 148 | self.results["forces"] += forces_bec # bec_output * self.external_field * self.electric_field_unit 149 | 150 | if self.compute_stress and output[self.stress_key] is not None: 151 | stress = to_numpy(output[self.stress_key]) 152 | # stress has units eng / len^3: 153 | self.results["stress"] = ( 154 | stress * (self.energy_units_to_eV / self.length_units_to_A**3) 155 | )[0] 156 | self.results["stress"] = full_3x3_to_voigt_6_stress(self.results["stress"]) 157 | 158 | return self.results 159 | -------------------------------------------------------------------------------- /cace/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .atomic_data import AtomicData, default_data_key, get_data_loader 2 | from .neighborhood import get_neighborhood 3 | 4 | __all__ = ["AtomicData", "default_data_key", "get_neighborhood", "get_data_loader"] 5 | -------------------------------------------------------------------------------- /cace/data/atomic_data.py: -------------------------------------------------------------------------------- 1 | ########################################################################################### 2 | # Atomic Data Class for handling molecules as graphs 3 | # modified from MACE 4 | # This program is distributed under the MIT License (see MIT.md) 5 | ########################################################################################### 6 | 7 | from typing import Optional, Sequence, Dict 8 | from ase import Atoms 9 | import numpy as np 10 | #import torch_geometric 11 | from ..tools import torch_geometric 12 | import torch.nn as nn 13 | import torch.utils.data 14 | from ..tools import voigt_to_matrix 15 | 16 | from .neighborhood import get_neighborhood 17 | 18 | default_data_key = { 19 | "energy": "energy", 20 | "forces": "forces", 21 | "molecular_index": "molecular_index", 22 | "stress": "stress", 23 | "virials": "virials", 24 | "dipole": None, 25 | "charges": None, 26 | "weights": None, 27 | "energy_weight": None, 28 | "force_weight": None, 29 | "stress_weight": None, 30 | "virial_weight": None, 31 | } 32 | 33 | class AtomicData(torch_geometric.data.Data): 34 | atomic_numbers: torch.Tensor 35 | num_graphs: torch.Tensor 36 | num_nodes: torch.Tensor 37 | batch: torch.Tensor 38 | edge_index: torch.Tensor 39 | node_attrs: torch.Tensor 40 | n_atom_basis: torch.Tensor 41 | positions: torch.Tensor 42 | shifts: torch.Tensor 43 | unit_shifts: torch.Tensor 44 | cell: torch.Tensor 45 | forces: torch.Tensor 46 | molecular_index: torch.Tensor 47 | energy: torch.Tensor 48 | stress: torch.Tensor 49 | virials: torch.Tensor 50 | dipole: torch.Tensor 51 | charges: torch.Tensor 52 | weight: torch.Tensor 53 | energy_weight: torch.Tensor 54 | forces_weight: torch.Tensor 55 | stress_weight: torch.Tensor 56 | virials_weight: torch.Tensor 57 | 58 | def __init__( 59 | self, 60 | edge_index: torch.Tensor, # [2, n_edges], always sender -> receiver 61 | atomic_numbers: torch.Tensor, # [n_nodes] 62 | positions: torch.Tensor, # [n_nodes, 3] 63 | shifts: torch.Tensor, # [n_edges, 3], 64 | unit_shifts: torch.Tensor, # [n_edges, 3] 65 | num_nodes: Optional[torch.Tensor] = None, #[,] 66 | cell: Optional[torch.Tensor] = None, # [3,3] 67 | forces: Optional[torch.Tensor] = None, # [n_nodes, 3] 68 | molecular_index: Optional[torch.Tensor] = None, # [n_nodes] 69 | energy: Optional[torch.Tensor] = None, # [, ] 70 | stress: Optional[torch.Tensor] = None, # [1,3,3] 71 | virials: 
Optional[torch.Tensor] = None, # [1,3,3] 72 | additional_info: Optional[Dict] = None, 73 | #dipole: Optional[torch.Tensor], # [, 3] 74 | #charges: Optional[torch.Tensor], # [n_nodes, ] 75 | ): 76 | # Check shapes 77 | if num_nodes is None: 78 | num_nodes = atomic_numbers.shape[0] 79 | else: 80 | assert num_nodes == atomic_numbers.shape[0] 81 | assert edge_index.shape[0] == 2 and len(edge_index.shape) == 2 82 | assert positions.shape == (num_nodes, 3) 83 | assert shifts.shape[1] == 3 84 | assert unit_shifts.shape[1] == 3 85 | assert cell is None or cell.shape == (3, 3) 86 | assert forces is None or forces.shape == (num_nodes, 3) 87 | assert molecular_index is None or molecular_index.shape == (num_nodes,) 88 | assert energy is None or len(energy.shape) == 0 89 | assert stress is None or stress.shape == (1, 3, 3) 90 | assert virials is None or virials.shape == (1, 3, 3) 91 | #assert dipole is None or dipole.shape[-1] == 3 92 | #assert charges is None or charges.shape == (num_nodes,) 93 | # Aggregate data 94 | data = { 95 | "edge_index": edge_index, 96 | "positions": positions, 97 | "shifts": shifts, 98 | "unit_shifts": unit_shifts, 99 | "cell": cell, 100 | "atomic_numbers": atomic_numbers, 101 | "num_nodes": num_nodes, 102 | "forces": forces, 103 | "molecular_index": molecular_index, 104 | "energy": energy, 105 | "stress": stress, 106 | "virials": virials, 107 | #"dipole": dipole, 108 | #"charges": charges, 109 | } 110 | if additional_info is not None: 111 | data.update(additional_info) 112 | super().__init__(**data) 113 | 114 | @classmethod 115 | def from_atoms( 116 | cls, 117 | atoms: Atoms, 118 | cutoff: float, 119 | data_key: Dict[str, str] = None, 120 | atomic_energies: Optional[Dict[int, float]] = None, 121 | ) -> "AtomicData": 122 | if data_key is not None: 123 | data_key = default_data_key.update(data_key) 124 | data_key = default_data_key 125 | positions = atoms.get_positions() 126 | pbc = tuple(atoms.get_pbc()) 127 | cell = np.array(atoms.get_cell()) 128 | atomic_numbers = atoms.get_atomic_numbers() 129 | 130 | edge_index, shifts, unit_shifts = get_neighborhood( 131 | positions=positions, 132 | cutoff=cutoff, 133 | pbc=pbc, 134 | cell=cell 135 | ) 136 | 137 | try: 138 | energy = atoms.info.get(data_key["energy"], None) # eV 139 | except: 140 | # this ugly bit is for compatibility with newest ASE versions 141 | if data_key['energy'] == 'energy': 142 | energy = atoms.get_potential_energy() 143 | else: 144 | energy = None 145 | 146 | # subtract atomic energies if available 147 | if atomic_energies and energy is not None: 148 | energy -= sum(atomic_energies.get(Z, 0) for Z in atomic_numbers) 149 | try: 150 | forces = atoms.arrays.get(data_key["forces"], None) # eV / Ang 151 | except: 152 | if data_key['forces'] == 'forces': 153 | forces = atoms.get_forces() 154 | else: 155 | forces = None 156 | molecular_index = atoms.arrays.get(data_key["molecular_index"], None) # index of molecules 157 | stress = atoms.info.get(data_key["stress"], None) # eV / Ang 158 | virials = atoms.info.get(data_key["virials"], None) 159 | 160 | # process these to make tensors 161 | cell = ( 162 | torch.tensor(cell, dtype=torch.get_default_dtype()) 163 | if cell is not None 164 | else torch.tensor( 165 | 3 * [0.0, 0.0, 0.0], dtype=torch.get_default_dtype() 166 | ).view(3, 3) 167 | ) 168 | 169 | forces = ( 170 | torch.tensor(forces, dtype=torch.get_default_dtype()) 171 | if forces is not None 172 | else None 173 | ) 174 | 175 | molecular_index = ( 176 | torch.tensor(molecular_index, dtype=torch.long) 177 | if 
molecular_index is not None 178 | else None 179 | ) 180 | 181 | energy = ( 182 | torch.tensor(energy, dtype=torch.get_default_dtype()) 183 | if energy is not None 184 | else None 185 | ) 186 | stress = ( 187 | voigt_to_matrix( 188 | torch.tensor(stress, dtype=torch.get_default_dtype()) 189 | ).unsqueeze(0) 190 | if stress is not None 191 | else None 192 | ) 193 | virials = ( 194 | torch.tensor(virials, dtype=torch.get_default_dtype()).unsqueeze(0) 195 | if virials is not None 196 | else None 197 | ) 198 | 199 | # obtain additional info 200 | # enumerate the data_key and extract data 201 | additional_info = {} 202 | for key, kk in data_key.items(): 203 | if kk is None or key in ['energy', 'forces', 'stress', 'virials', 'molecular_index']: 204 | continue 205 | else: 206 | more_info = atoms.info.get(kk, None) 207 | if more_info is None: 208 | more_info = atoms.arrays.get(kk, None) 209 | more_info = ( 210 | torch.tensor(more_info, dtype=torch.get_default_dtype()) 211 | if more_info is not None 212 | else None 213 | ) 214 | additional_info[key] = more_info 215 | 216 | return cls( 217 | edge_index=torch.tensor(edge_index, dtype=torch.long), 218 | positions=torch.tensor(positions, dtype=torch.get_default_dtype()), 219 | shifts=torch.tensor(shifts, dtype=torch.get_default_dtype()), 220 | unit_shifts=torch.tensor(unit_shifts, dtype=torch.get_default_dtype()), 221 | cell=cell, 222 | atomic_numbers=torch.tensor(atomic_numbers, dtype=torch.long), 223 | num_nodes=atomic_numbers.shape[0], 224 | forces=forces, 225 | molecular_index=molecular_index, 226 | energy=energy, 227 | stress=stress, 228 | virials=virials, 229 | additional_info=additional_info, 230 | ) 231 | 232 | 233 | def get_data_loader( 234 | dataset: Sequence[AtomicData], 235 | batch_size: int, 236 | shuffle=True, 237 | drop_last=False, 238 | ) -> torch.utils.data.DataLoader: 239 | return torch_geometric.dataloader.DataLoader( 240 | dataset=dataset, 241 | batch_size=batch_size, 242 | shuffle=shuffle, 243 | drop_last=drop_last, 244 | ) 245 | -------------------------------------------------------------------------------- /cace/data/neighborhood.py: -------------------------------------------------------------------------------- 1 | ########################################################################################### 2 | # Neighborhood construction 3 | # modified from MACE 4 | # This program is distributed under the MIT License (see MIT.md) 5 | ########################################################################################### 6 | 7 | from typing import Optional, Tuple 8 | 9 | import ase.neighborlist 10 | import numpy as np 11 | 12 | 13 | from matscipy.neighbours import neighbour_list 14 | 15 | 16 | def get_neighborhood( 17 | positions: np.ndarray, # [num_positions, 3] 18 | cutoff: float, 19 | pbc: Optional[Tuple[bool, bool, bool]] = None, 20 | cell: Optional[np.ndarray] = None, # [3, 3] 21 | true_self_interaction=False, 22 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 23 | if pbc is None: 24 | pbc = (False, False, False) 25 | 26 | if cell is None or cell.any() == np.zeros((3, 3)).any(): 27 | cell = np.identity(3, dtype=float) 28 | 29 | assert len(pbc) == 3 and all(isinstance(i, (bool, np.bool_)) for i in pbc) 30 | assert cell.shape == (3, 3) 31 | 32 | pbc_x = pbc[0] 33 | pbc_y = pbc[1] 34 | pbc_z = pbc[2] 35 | identity = np.identity(3, dtype=float) 36 | max_positions = np.max(np.absolute(positions)) + 1 37 | # Extend cell in non-periodic directions 38 | # For models with more than 5 layers, the multiplicative constant needs to 
be increased. 39 | if not pbc_x: 40 | cell[:, 0] = max_positions * 5 * cutoff * identity[:, 0] 41 | if not pbc_y: 42 | cell[:, 1] = max_positions * 5 * cutoff * identity[:, 1] 43 | if not pbc_z: 44 | cell[:, 2] = max_positions * 5 * cutoff * identity[:, 2] 45 | 46 | sender, receiver, unit_shifts = neighbour_list( 47 | quantities="ijS", 48 | pbc=pbc, 49 | cell=cell, 50 | positions=positions, 51 | cutoff=float(cutoff), 52 | ) 53 | 54 | if not true_self_interaction: 55 | # Eliminate self-edges that don't cross periodic boundaries 56 | true_self_edge = sender == receiver 57 | true_self_edge &= np.all(unit_shifts == 0, axis=1) 58 | keep_edge = ~true_self_edge 59 | 60 | # Note: after eliminating self-edges, it can be that no edges remain in this system 61 | sender = sender[keep_edge] 62 | receiver = receiver[keep_edge] 63 | unit_shifts = unit_shifts[keep_edge] 64 | 65 | # Build output 66 | edge_index = np.stack((sender, receiver)) # [2, n_edges] 67 | 68 | # From the docs: With the shift vector S, the distances D between atoms can be computed from 69 | # D = positions[j]-positions[i]+S.dot(cell) 70 | shifts = np.dot(unit_shifts, cell) # [n_edges, 3] 71 | 72 | return edge_index, shifts, unit_shifts 73 | 74 | def get_neighborhood_ASE( 75 | positions: np.ndarray, # [num_positions, 3] 76 | cutoff: float, 77 | pbc: Optional[Tuple[bool, bool, bool]] = None, 78 | cell: Optional[np.ndarray] = None, # [3, 3] 79 | true_self_interaction=False, 80 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 81 | if pbc is None: 82 | pbc = (False, False, False) 83 | 84 | if cell is None or cell.any() == np.zeros((3, 3)).any(): 85 | cell = 1000. * np.identity(3, dtype=float) 86 | pbc = (False, False, False) 87 | 88 | assert len(pbc) == 3 and all(isinstance(i, (bool, np.bool_)) for i in pbc) 89 | assert cell.shape == (3, 3) 90 | 91 | """ 92 | ‘i’ : first atom index 93 | ‘j’ : second atom index 94 | ‘d’ : absolute distance 95 | ‘D’ : distance vector 96 | ‘S’ : shift vector (number of cell boundaries crossed by the bond between atom i and j). 
With the shift vector S, the distances D between atoms can be computed from: D = positions[j]-positions[i]+S.dot(cell) 97 | """ 98 | sender, receiver, unit_shifts = ase.neighborlist.primitive_neighbor_list( 99 | quantities="ijS", 100 | pbc=pbc, 101 | cell=cell, 102 | positions=positions, 103 | cutoff=cutoff, 104 | self_interaction=True, # we want edges from atom to itself in different periodic images 105 | use_scaled_positions=False, # positions are not scaled positions 106 | ) 107 | 108 | if not true_self_interaction: 109 | # Eliminate self-edges that don't cross periodic boundaries 110 | true_self_edge = sender == receiver 111 | true_self_edge &= np.all(unit_shifts == 0, axis=1) 112 | keep_edge = ~true_self_edge 113 | 114 | # Note: after eliminating self-edges, it can be that no edges remain in this system 115 | sender = sender[keep_edge] 116 | receiver = receiver[keep_edge] 117 | unit_shifts = unit_shifts[keep_edge] 118 | 119 | # Build output 120 | edge_index = np.stack((sender, receiver)) # [2, n_edges] 121 | 122 | # From the docs: With the shift vector S, the distances D between atoms can be computed from 123 | # D = positions[j]-positions[i]+S.dot(cell) 124 | shifts = np.dot(unit_shifts, cell) # [n_edges, 3] 125 | 126 | # sender = edge_index[0] receiver = edge_index[1] 127 | return edge_index, shifts, unit_shifts 128 | -------------------------------------------------------------------------------- /cace/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .atomistic import * 2 | from .combined import * 3 | -------------------------------------------------------------------------------- /cace/models/atomistic.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, List 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from ..modules import Transform 7 | from ..modules import Preprocess 8 | from ..tools import torch_geometric 9 | 10 | __all__ = ["AtomisticModel", "NeuralNetworkPotential"] 11 | 12 | 13 | class AtomisticModel(nn.Module): 14 | """ 15 | Base class for atomistic neural network models. 16 | """ 17 | 18 | def __init__( 19 | self, 20 | postprocessors: Optional[List[Transform]] = None, 21 | do_postprocessing: bool = False, 22 | ): 23 | """ 24 | Args: 25 | postprocessors: Post-processing transforms that may be 26 | initialized using the `datamodule`, but are not 27 | applied during training. 28 | do_postprocessing: If true, post-processing is activated. 
29 | """ 30 | super().__init__() 31 | self.do_postprocessing = do_postprocessing 32 | self.postprocessors = nn.ModuleList(postprocessors) 33 | self.required_derivatives: Optional[List[str]] = None 34 | self.model_outputs: Optional[List[str]] = None 35 | 36 | def collect_derivatives(self) -> List[str]: 37 | self.required_derivatives = None 38 | required_derivatives = set() 39 | for m in self.modules(): 40 | if ( 41 | hasattr(m, "required_derivatives") 42 | and m.required_derivatives is not None 43 | ): 44 | required_derivatives.update(m.required_derivatives) 45 | required_derivatives: List[str] = list(required_derivatives) 46 | self.required_derivatives = required_derivatives 47 | 48 | def collect_outputs(self) -> List[str]: 49 | self.model_outputs = None 50 | model_outputs = set() 51 | for m in self.modules(): 52 | if hasattr(m, "model_outputs") and m.model_outputs is not None: 53 | model_outputs.update(m.model_outputs) 54 | model_outputs: List[str] = list(model_outputs) 55 | self.model_outputs = model_outputs 56 | 57 | def initialize_derivatives( 58 | self, data: Dict[str, torch.Tensor] 59 | ) -> Dict[str, torch.Tensor]: 60 | for p in self.required_derivatives: 61 | if isinstance(data, torch_geometric.Batch): 62 | if p in data.to_dict().keys(): 63 | data[p].requires_grad_(True) 64 | else: 65 | if p in data.keys(): 66 | data[p].requires_grad_(True) 67 | return data 68 | 69 | def initialize_transforms(self, datamodule): 70 | for module in self.modules(): 71 | if isinstance(module, Transform): 72 | module.datamodule(datamodule) 73 | 74 | def postprocess(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 75 | if self.do_postprocessing: 76 | # apply postprocessing 77 | for pp in self.postprocessors: 78 | data = pp(data) 79 | return data 80 | 81 | def extract_outputs( 82 | self, data: Dict[str, torch.Tensor] 83 | ) -> Dict[str, torch.Tensor]: 84 | results = {k: data[k] for k in self.model_outputs} 85 | return results 86 | 87 | 88 | class NeuralNetworkPotential(AtomisticModel): 89 | """ 90 | A generic neural network potential class that sequentially applies a list of input 91 | modules, a representation module and a list of output modules. 92 | 93 | This can be flexibly configured for various, e.g. property prediction or potential 94 | energy sufaces with response properties. 95 | """ 96 | 97 | def __init__( 98 | self, 99 | representation: nn.Module = None, 100 | input_modules: List[nn.Module] = None, 101 | output_modules: List[nn.Module] = None, 102 | postprocessors: Optional[List[Transform]] = None, 103 | do_postprocessing: bool = False, 104 | keep_graph: bool = False, 105 | ): 106 | """ 107 | Args: 108 | representation: The module that builds representation from data. 109 | input_modules: Modules that are applied before representation, e.g. to 110 | modify input or add additional tensors for response properties. 111 | output_modules: Modules that predict output properties from the 112 | representation. 113 | postprocessors: Post-processing transforms that may be initialized using the 114 | `datamodule`, but are not applied during training. 115 | input_dtype_str: The dtype of real data. 116 | do_postprocessing: If true, post-processing is activated. 
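        Example (illustrative sketch only; the representation and the two output
        heads named here are assumptions about a typical configuration, not a
        fixed recipe)::

            model = NeuralNetworkPotential(
                representation=cace_representation,        # a CACE representation module
                output_modules=[energy_head, forces_head], # e.g. Atomwise plus a force module
            )
            output = model(batch.to_dict(), training=True)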
117 | """ 118 | super().__init__( 119 | postprocessors=postprocessors, 120 | do_postprocessing=do_postprocessing, 121 | ) 122 | self.representation = representation 123 | if input_modules is None: 124 | preprocessor = Preprocess() 125 | input_modules = [preprocessor] 126 | self.input_modules = nn.ModuleList(input_modules) 127 | self.output_modules = nn.ModuleList(output_modules) 128 | 129 | self.collect_derivatives() 130 | self.collect_outputs() 131 | 132 | self.keep_graph = keep_graph 133 | 134 | def add_module(self, module: nn.Module, module_type: str = "output"): 135 | if module_type == "input": 136 | self.input_modules.append(module) 137 | elif module_type == "output": 138 | self.output_modules.append(module) 139 | self.collect_derivatives() 140 | self.collect_outputs() 141 | else: 142 | raise ValueError(f"Unknown module type {module_type}") 143 | 144 | def remove_module(self, module_index: int, module_type: str = "output"): 145 | if module_type == "output": 146 | del self.output_modules[module_index] 147 | self.collect_derivatives() 148 | self.collect_outputs() 149 | else: 150 | raise ValueError(f"Unknown module type {module_type}") 151 | 152 | def forward(self, 153 | data: Dict[str, torch.Tensor], 154 | training: bool = False, 155 | compute_stress: bool = True, 156 | compute_virials: bool = False, 157 | output_index: int = None, # only used for multiple-head output 158 | ) -> Dict[str, torch.Tensor]: 159 | # initialize derivatives for response properties 160 | data = self.initialize_derivatives(data) 161 | 162 | if 'stress' in self.model_outputs or 'CACE_stress' in self.model_outputs: 163 | compute_stress = True 164 | for m in self.input_modules: 165 | data = m(data, compute_stress=compute_stress, compute_virials=compute_virials) 166 | 167 | if self.representation is not None: 168 | data = self.representation(data) 169 | 170 | for m in self.output_modules: 171 | if hasattr(self, "keep_graph"): 172 | training = training or self.keep_graph 173 | data = m(data, training=training, output_index=output_index) 174 | 175 | # apply postprocessing (if enabled) 176 | data = self.postprocess(data) 177 | 178 | results = self.extract_outputs(data) 179 | 180 | return results 181 | 182 | -------------------------------------------------------------------------------- /cace/models/combined.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Dict, List 4 | 5 | __all__ = ['CombinePotential'] 6 | 7 | class CombinePotential(nn.Module): 8 | def __init__( 9 | self, 10 | potentials: List[nn.Module], 11 | potential_keys: List[Dict], 12 | operation = None, 13 | ): 14 | """ 15 | Combine multiple potentials into a single potential. 16 | Args: 17 | potentials: List of potentials to combine. 18 | potential_keys: List of dictionaries with keys for each potential. 19 | e.g. [pot1, pot2] where 20 | pot1 = {'CACE_energy': 'CACE_energy_intra', 21 | 'CACE_forces': 'CACE_forces_intra', 22 | 'weight': 1. 23 | } 24 | 25 | pot2 = {'CACE_energy': 'CACE_energy_inter', 26 | 'CACE_forces': 'CACE_forces_inter', 27 | 'weight': 0.01, 28 | } 29 | out_keys: List of keys to output. Should be present in all potential_keys. 30 | operation: Operation to combine the outputs of the potentials. 
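        Example (sketch, assuming `model_intra` and `model_inter` are two potentials
        whose outputs use the keys listed in pot1 and pot2 above)::

            combined = CombinePotential([model_intra, model_inter], [pot1, pot2])
            output = combined(data)
            # output['CACE_energy'] = 1.0 * CACE_energy_intra + 0.01 * CACE_energy_inter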
31 | """ 32 | super().__init__() 33 | self.models = nn.ModuleList([potential for potential in potentials]) 34 | self.potential_keys = potential_keys 35 | self.required_derivatives = [] 36 | for potential in potentials: 37 | for d in potential.required_derivatives: 38 | if d not in self.required_derivatives: 39 | self.required_derivatives.append(d) 40 | 41 | self.out_keys = [] 42 | for key in potential_keys[0]: 43 | if all(key in potential_key for potential_key in self.potential_keys) and key != 'weight': 44 | self.out_keys.append(key) 45 | 46 | if operation is None: 47 | # Default operation (sum) 48 | self.operation = self.default_operation 49 | else: 50 | self.operation = operation 51 | 52 | def default_operation(self, my_list): 53 | return torch.stack(my_list).sum(0) 54 | 55 | 56 | def forward(self, 57 | data: Dict[str, torch.Tensor], 58 | training: bool = False, 59 | compute_stress: bool = False, 60 | compute_virials: bool = False, 61 | output_index: int = None, # only used for multiple-head output 62 | ) -> Dict[str, torch.Tensor]: 63 | results = {} 64 | output = {} 65 | for i, potential in enumerate(self.models): 66 | result = potential(data, training, compute_stress, compute_virials, output_index) 67 | results[i] = result 68 | output.update(result) 69 | 70 | for key in self.out_keys: 71 | values = [] 72 | for i, potential_key in enumerate(self.potential_keys): 73 | v_now = results[i][potential_key[key]] 74 | if 'weight' in potential_key: 75 | v_now *= potential_key['weight'] 76 | values.append(v_now) 77 | if values: 78 | output[key] = self.operation(values) 79 | return output 80 | -------------------------------------------------------------------------------- /cace/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .preprocess import * 2 | 3 | from .angular import * 4 | 5 | from .angular_tools import * 6 | 7 | from .symmetrize_basis import * 8 | 9 | from .product_basis import * 10 | 11 | from .cutoff import * 12 | 13 | from .radial import * 14 | 15 | from .radial_transform import * 16 | 17 | from .type import * 18 | 19 | from .utils import * 20 | 21 | from .blocks import * 22 | 23 | from .atomwise import * 24 | 25 | from .ewald import * 26 | 27 | from .forces import * 28 | 29 | from .interaction import * 30 | 31 | from .transform import * 32 | 33 | from .feature_mix import * 34 | 35 | # from .node_edge_former import * 36 | 37 | from .polarization import * 38 | 39 | from .grad import * 40 | 41 | from .les_wrapper import * 42 | -------------------------------------------------------------------------------- /cace/modules/angular.py: -------------------------------------------------------------------------------- 1 | ############################################### 2 | # This module contains functions to compute the angular part of the 3 | # edge basis functions 4 | ############################################### 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from math import factorial 10 | from collections import OrderedDict 11 | 12 | __all__=['AngularComponent', 'AngularComponent_GPU', 'make_lxlylz_list', 'make_lxlylz', 'make_l_dict', 'l_dict_to_lxlylz_list', 'compute_length_lxlylz', 'compute_length_lmax', 'compute_length_lmax_numerical', 'lxlylz_factorial_coef', 'lxlylz_factorial_coef_torch', 'l1l2_factorial_coef'] 13 | 14 | import torch 15 | import torch.nn as nn 16 | from collections import OrderedDict 17 | 18 | class AngularComponent(nn.Module): 19 | """ Angular component of the edge basis 
functions 20 | Optimized for CPU usage (use recursive formula) 21 | """ 22 | def __init__(self, l_max): 23 | super().__init__() 24 | self.l_max = l_max 25 | self.precompute_lxlylz() 26 | 27 | def precompute_lxlylz(self): 28 | self.lxlylz_dict = OrderedDict({l: [] for l in range(self.l_max + 1)}) 29 | self.lxlylz_dict[0] = [(0, 0, 0)] 30 | for l in range(1, self.l_max + 1): 31 | for prev_lxlylz_combination in self.lxlylz_dict[l - 1]: 32 | for i in range(3): 33 | lxlylz_combination = list(prev_lxlylz_combination) 34 | lxlylz_combination[i] += 1 35 | lxlylz_combination_tuple = tuple(lxlylz_combination) 36 | if lxlylz_combination_tuple not in self.lxlylz_dict[l]: 37 | self.lxlylz_dict[l].append(lxlylz_combination_tuple) 38 | self.lxlylz_list = self._convert_lxlylz_to_list() 39 | # get the start and the end index of the lxlylz_list for each l 40 | self.lxlylz_index = torch.zeros((self.l_max+1, 2), dtype=torch.long) 41 | for l in range(self.l_max+1): 42 | self.lxlylz_index[l, 0] = 0 if l == 0 else self.lxlylz_index[l-1, 1] 43 | self.lxlylz_index[l, 1] = compute_length_lmax(l) 44 | 45 | def forward(self, vectors: torch.Tensor) -> torch.Tensor: 46 | 47 | computed_values = {(0, 0, 0): torch.ones(vectors.size(0), device=vectors.device, dtype=vectors.dtype)} 48 | for l in range(1, self.l_max + 1): 49 | for lxlylz_combination in self.lxlylz_dict[l]: 50 | prev_lxlylz_combination = tuple(l - 1 if i == lxlylz_combination.index(max(lxlylz_combination)) else l for i, l in enumerate(lxlylz_combination)) 51 | i = lxlylz_combination.index(max(lxlylz_combination)) 52 | computed_values[lxlylz_combination] = computed_values[prev_lxlylz_combination] * vectors[:, i] 53 | 54 | computed_values_list = self._convert_computed_values_to_list(computed_values) 55 | return torch.stack(computed_values_list, dim=1) 56 | 57 | def _convert_lxlylz_to_list(self): 58 | lxlylz_list = [] 59 | for l, combinations in self.lxlylz_dict.items(): 60 | lxlylz_list.extend(combinations) 61 | return lxlylz_list 62 | 63 | def _convert_computed_values_to_list(self, computed_values): 64 | return [computed_values[comb] for comb in self.lxlylz_list] 65 | 66 | def get_lxlylz_list(self): 67 | if self.lxlylz_list is None: 68 | raise ValueError("You must call forward before getting lxlylz_list") 69 | return self.lxlylz_list 70 | 71 | def get_lxlylz_dict(self): 72 | return self.lxlylz_dict 73 | 74 | def get_lxlylz_index(self): 75 | return self.lxlylz_index 76 | 77 | def __repr__(self): 78 | return f"AngularComponent(l_max={self.l_max})" 79 | 80 | class AngularComponent_GPU(nn.Module): 81 | """ Angular component of the edge basis functions 82 | This version runs faster on gpus but slower on cpus 83 | The ordering of lxlylz_list is different from the CPU version 84 | """ 85 | def __init__(self, l_max): 86 | super().__init__() 87 | self.l_max = l_max 88 | self.lxlylz_dict = make_l_dict(l_max) 89 | self.lxlylz_list = l_dict_to_lxlylz_list(self.lxlylz_dict) 90 | 91 | def forward(self, vectors: torch.Tensor) -> torch.Tensor: 92 | 93 | lxlylz_tensor = torch.tensor(self.lxlylz_list, device=vectors.device, dtype=vectors.dtype) 94 | 95 | # Expand vectors and lxlylz_tensor for broadcasting 96 | vectors_expanded = vectors[:, None, :] # Shape: [N, 1, 3] 97 | lxlylz_expanded = lxlylz_tensor[None, :, :] # Shape: [1, M, 3] 98 | 99 | # Compute terms using broadcasting 100 | # Each vector component is raised to the power of corresponding lx, ly, lz 101 | # Somehow this is causing trouble on gpus when doing second order derivatives!!! 
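        # Each entry of `terms` below is one component raised to its power, i.e. [x**lx, y**ly, z**lz];
        # e.g. for lxlylz = (1, 0, 2) the product taken afterwards gives x**1 * y**0 * z**2 = x * z**2.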
102 | terms = vectors_expanded ** lxlylz_expanded # Shape: [N, M, 3] 103 | 104 | # Multiply across the last dimension (x^lx * y^ly * z^lz) for each term 105 | computed_terms = torch.prod(terms, dim=-1) # Shape: [N, M] 106 | # to avoid the mps problem with cumprod 107 | #computed_terms = terms[:, :, 0] * terms[:, :, 1] * terms[:, :, 2] 108 | 109 | return computed_terms 110 | 111 | def get_lxlylz_list(self): 112 | return self.lxlylz_list 113 | 114 | def get_lxlylz_dict(self): 115 | return self.lxlylz_dict 116 | 117 | def __repr__(self): 118 | return f"AngularComponent_GPU(l_max={self.l_max})" 119 | 120 | def make_lxlylz_list(l_max: int): 121 | """ 122 | make a list of lxlylz up to l_max 123 | """ 124 | l_dict = make_l_dict(l_max) 125 | return l_dict_to_lxlylz_list(l_dict) 126 | 127 | def l_index_select(l): 128 | """ select the index of the lxlylz_list based on l """ 129 | return np.arange(compute_length_lmax(l-1), compute_length_lmax(l)) 130 | 131 | def make_lxlylz(l): 132 | """ 133 | make a list of lxlylz such that lx + ly + lz = l 134 | """ 135 | lxlylz = [] 136 | for lx in range(l+1): 137 | for ly in range(l+1): 138 | lz = l - lx - ly 139 | if lz >= 0: 140 | lxlylz.append([lx, ly, lz]) 141 | #return torch.tensor(lxlylz, dtype=torch.int64) 142 | return lxlylz 143 | 144 | def make_l_dict(l_max): 145 | """ 146 | make a ordered dictionary of lxlylz list 147 | up to l_max 148 | """ 149 | l_dict = OrderedDict() 150 | for l in range(l_max+1): 151 | l_dict[l] = make_lxlylz(l) 152 | return l_dict 153 | 154 | def l_dict_to_lxlylz_list(l_dict): 155 | """ 156 | convert the ordered dictionary to a list of lxlylz 157 | """ 158 | lxlylz_list = [] 159 | for l in l_dict: 160 | lxlylz_list += l_dict[l] 161 | return lxlylz_list 162 | 163 | def compute_length_lxlylz(l): 164 | """ compute the length of the lxlylz list based on l """ 165 | return int((l+1)*(l+2)/2) 166 | 167 | def compute_length_lmax(l_max): 168 | """ compute the length of the lxlylz list based on l_max """ 169 | return int((l_max+1)*(l_max+2)*(l_max+3)/6) 170 | 171 | def compute_length_lmax_numerical(l_max): 172 | """ compute the length of the lxlylz list based on l_max numerically""" 173 | length = 0 174 | for l in range(l_max+1): 175 | length += compute_length_lxlylz(l) 176 | return length 177 | 178 | 179 | def l1l2_factorial_coef(l1, l2): 180 | # Ensure inputs are integers 181 | if not all(isinstance(n, int) for n in l1): 182 | raise ValueError("All elements of l1 must be integers.") 183 | if not all(isinstance(n, int) for n in l2): 184 | raise ValueError("All elements of l2 must be integers.") 185 | 186 | # Compute the multinomial coefficient 187 | result = 1 188 | for l1i, l2i in zip(l1, l2): 189 | result *= factorial(l1i + l2i) 190 | result /= factorial(l1i) 191 | result /= factorial(l2i) 192 | return result 193 | 194 | def lxlylz_factorial_coef(lxlylz): 195 | # Ensure inputs are integers 196 | if not all(isinstance(n, int) for n in lxlylz): 197 | raise ValueError("All elements of lxlylz must be integers.") 198 | 199 | # Sort the elements in descending order 200 | sorted_lxlylz = sorted(lxlylz, reverse=True) 201 | 202 | # Compute the sum l = lx + ly + lz 203 | l = sum(sorted_lxlylz) 204 | 205 | # Compute the multinomial coefficient 206 | result = factorial(l) 207 | for lxly in sorted_lxlylz: 208 | result //= factorial(lxly) 209 | 210 | return result 211 | 212 | def lxlylz_factorial_coef_torch(lxlylz) -> torch.Tensor: 213 | 214 | # Check if lxlylz is a tensor, if not, convert to tensor 215 | if not isinstance(lxlylz, torch.Tensor): 216 | 
lxlylz = torch.tensor(lxlylz, dtype=torch.int64) 217 | 218 | if not torch.all(lxlylz == lxlylz.int()): 219 | raise ValueError("All elements of lxlylz must be integers.") 220 | 221 | sorted_lxlylz, _ = torch.sort(lxlylz, descending=True) 222 | l = torch.sum(sorted_lxlylz) 223 | 224 | result = torch.tensor(1, dtype=torch.int) 225 | 226 | for i in torch.arange(int(sorted_lxlylz[0])): 227 | result = result * (l - i) 228 | result = (result / (i + 1)).floor() 229 | 230 | for i in torch.arange(1, len(sorted_lxlylz)): 231 | for j in torch.arange(int(sorted_lxlylz[i])): 232 | result = result * (l - sorted_lxlylz[:i].sum() - j) 233 | result = (result / (j + 1)).floor() 234 | 235 | return result 236 | -------------------------------------------------------------------------------- /cace/modules/atomwise.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union, Sequence, Callable, Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from .blocks import Dense, ResidualBlock, build_mlp 7 | from ..tools import scatter_sum 8 | 9 | __all__ = ["Atomwise"] 10 | 11 | class Atomwise(nn.Module): 12 | """ 13 | Predicts atom-wise contributions and accumulates global prediction, e.g. for the energy. 14 | 15 | If `aggregation_mode` is None, only the per-atom predictions will be returned. 16 | """ 17 | 18 | def __init__( 19 | self, 20 | n_in: Optional[int] = None, 21 | n_out: int = 1, 22 | n_hidden: Optional[Union[int, Sequence[int]]] = None, 23 | n_layers: int = 2, 24 | bias: bool = True, 25 | activation: Callable = F.silu, 26 | aggregation_mode: str = "sum", 27 | feature_key: Union[str, Sequence[int]] = 'node_feats', 28 | output_key: str = "energy", 29 | per_atom_output_key: Optional[str] = None, 30 | descriptor_output_key: Optional[str] = None, 31 | residual: bool = False, 32 | use_batchnorm: bool = False, 33 | add_linear_nn: bool = False, 34 | post_process: Optional[Callable] = None 35 | ): 36 | """ 37 | Args: 38 | n_in: input dimension of representation 39 | n_out: output dimension of target property (default: 1) 40 | n_hidden: size of hidden layers. 41 | If an integer, same number of node is used for all hidden layers resulting 42 | in a rectangular network. 43 | If None, the number of neurons is divided by two after each layer starting 44 | n_in resulting in a pyramidal network. 45 | n_layers: number of layers. 46 | aggregation_mode: one of {sum, avg} (default: sum) 47 | output_key: the key under which the result will be stored 48 | per_atom_output_key: If not None, the key under which the per-atom result will be stored 49 | residual: whether to use residual connections between layers 50 | use_batchnorm: whether to use batch normalization between layers 51 | add_linear_nn: whether to add a linear NN to the output of the MLP 52 | """ 53 | super().__init__() 54 | self.output_key = output_key 55 | self.model_outputs = [output_key] 56 | self.per_atom_output_key = per_atom_output_key 57 | self.descriptor_output_key = descriptor_output_key 58 | if self.per_atom_output_key is not None: 59 | self.model_outputs.append(self.per_atom_output_key) 60 | if self.descriptor_output_key is not None: 61 | self.model_outputs.append(self.descriptor_output_key) 62 | 63 | self.n_out = n_out 64 | 65 | if aggregation_mode is None and self.per_atom_output_key is None: 66 | raise ValueError( 67 | "If `aggregation_mode` is None, `per_atom_output_key` needs to be set," 68 | + " since no accumulated output will be returned!" 
69 | ) 70 | 71 | self.n_in = n_in 72 | self.n_out = n_out 73 | self.n_hidden = n_hidden 74 | self.n_layers = n_layers 75 | self.activation = activation 76 | self.aggregation_mode = aggregation_mode 77 | self.residual = residual 78 | self.use_batchnorm = use_batchnorm 79 | self.add_linear_nn = add_linear_nn 80 | self.post_process = post_process 81 | self.bias = bias 82 | self.feature_key = feature_key 83 | 84 | if n_in is not None: 85 | self.outnet = build_mlp( 86 | n_in=self.n_in, 87 | n_out=self.n_out, 88 | n_hidden=self.n_hidden, 89 | n_layers=self.n_layers, 90 | activation=self.activation, 91 | residual=self.residual, 92 | use_batchnorm=self.use_batchnorm, 93 | bias=self.bias, 94 | ) 95 | if self.add_linear_nn: 96 | self.linear_nn = Dense( 97 | self.n_in, 98 | self.n_out, 99 | bias=self.bias, 100 | activation=None, 101 | use_batchnorm=self.use_batchnorm, 102 | ) 103 | 104 | else: 105 | self.outnet = None 106 | 107 | def forward(self, 108 | data: Dict[str, torch.Tensor], 109 | training: bool = None, 110 | output_index: int = None, # only used for multi-head output 111 | ) -> Dict[str, torch.Tensor]: 112 | 113 | # check if self.feature_key exists, otherwise set default 114 | if not hasattr(self, "feature_key") or self.feature_key is None: 115 | self.feature_key = "node_feats" 116 | 117 | # reshape the feature vectors 118 | if isinstance(self.feature_key, str): 119 | if self.feature_key not in data: 120 | raise ValueError(f"Feature key {self.feature_key} not found in data dictionary.") 121 | features = data[self.feature_key] 122 | features = features.reshape(features.shape[0], -1) 123 | elif isinstance(self.feature_key, list): 124 | features = torch.cat([data[key].reshape(data[key].shape[0], -1) for key in self.feature_key], dim=-1) 125 | 126 | if self.n_in is None: 127 | self.n_in = features.shape[1] 128 | else: 129 | assert self.n_in == features.shape[1] 130 | 131 | if self.outnet == None: 132 | self.outnet = build_mlp( 133 | n_in=self.n_in, 134 | n_out=self.n_out, 135 | n_hidden=self.n_hidden, 136 | n_layers=self.n_layers, 137 | activation=self.activation, 138 | residual=self.residual, 139 | use_batchnorm=self.use_batchnorm, 140 | bias=self.bias, 141 | ) 142 | self.outnet = self.outnet.to(features.device) 143 | if self.add_linear_nn: 144 | self.linear_nn = Dense( 145 | self.n_in, 146 | self.n_out, 147 | bias=self.bias, 148 | activation=None, 149 | use_batchnorm=self.use_batchnorm, 150 | ) 151 | self.linear_nn = self.linear_nn.to(features.device) 152 | else: 153 | self.linear_nn = None 154 | 155 | # predict atomwise contributions 156 | y = self.outnet(features) 157 | if self.add_linear_nn: 158 | y += self.linear_nn(features) 159 | 160 | # accumulate the per-atom output if necessary 161 | if self.per_atom_output_key is not None: 162 | data[self.per_atom_output_key] = y 163 | 164 | if hasattr(self, "descriptor_output_key") and self.descriptor_output_key is not None: 165 | data[self.descriptor_output_key] = features 166 | 167 | # aggregate 168 | if self.aggregation_mode is not None: 169 | y = scatter_sum( 170 | src=y, 171 | index=data["batch"], 172 | dim=0) 173 | y = torch.squeeze(y, -1) 174 | 175 | if self.aggregation_mode == "avg": 176 | y = y / torch.bincount(data['batch']) 177 | 178 | if hasattr(self, "post_process") and self.post_process is not None: 179 | y = self.post_process(y) 180 | data[self.output_key] = y[:, output_index] if output_index is not None else y 181 | return data 182 | -------------------------------------------------------------------------------- 
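A minimal, self-contained sketch of the `Atomwise` head above, exercised in isolation; the feature width (16) and the random `node_feats` are placeholders for what the CACE representation would normally supply.

```python
# Sketch: an Atomwise head mapping per-node features to a per-structure energy.
import torch
from cace.modules import Atomwise

head = Atomwise(n_in=16, n_out=1,
                feature_key='node_feats',
                output_key='CACE_energy')

data = {
    'node_feats': torch.randn(5, 16),           # 5 atoms, 16 features each (placeholder)
    'batch': torch.zeros(5, dtype=torch.long),  # all atoms belong to structure 0
}
data = head(data)
print(data['CACE_energy'])  # summed per-structure prediction, shape [1]
```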
/cace/modules/blocks.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Union, Optional, Sequence 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | import numpy as np 7 | 8 | __all__ = ["build_mlp", "Dense", "ResidualBlock", "AtomicEnergiesBlock"] 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from typing import Union, Callable 14 | 15 | def build_mlp( 16 | n_in: int, 17 | n_out: int, 18 | n_hidden: Optional[Union[int, Sequence[int]]] = None, 19 | n_layers: int = 2, 20 | activation: Callable = F.silu, 21 | residual: bool = False, 22 | use_batchnorm: bool = False, 23 | bias: bool = True, 24 | last_zero_init: bool = False, 25 | ) -> nn.Module: 26 | """ 27 | Build multiple layer fully connected perceptron neural network. 28 | 29 | Args: 30 | n_in: number of input nodes. 31 | n_out: number of output nodes. 32 | n_hidden: number hidden layer nodes. 33 | If an integer, same number of node is used for all hidden layers resulting 34 | in a rectangular network. 35 | If None, the number of neurons is divided by two after each layer starting 36 | n_in resulting in a pyramidal network. 37 | n_layers: number of layers. 38 | activation: activation function. All hidden layers would 39 | the same activation function except the output layer that does not apply 40 | any activation function. 41 | residual: whether to use residual connections between layers 42 | use_batchnorm: whether to use batch normalization between layers 43 | """ 44 | # get list of number of nodes in input, hidden & output layers 45 | if n_hidden is None: 46 | c_neurons = n_in 47 | n_neurons = [] 48 | for i in range(n_layers): 49 | n_neurons.append(c_neurons) 50 | c_neurons = max(n_out, c_neurons // 2) 51 | n_neurons.append(n_out) 52 | else: 53 | # get list of number of nodes hidden layers 54 | if type(n_hidden) is int: 55 | n_hidden = [n_hidden] * (n_layers - 1) 56 | else: 57 | n_hidden = list(n_hidden) 58 | n_neurons = [n_in] + n_hidden + [n_out] 59 | 60 | if residual: 61 | if n_layers < 3 or n_layers % 2 == 0: 62 | raise ValueError("Residual networks require at least 3 layers and an odd number of layers") 63 | layers = [] 64 | # Create residual blocks every 2 layers 65 | for i in range(0, n_layers - 1, 2): 66 | in_features = n_neurons[i] 67 | out_features = n_neurons[min(i + 2, len(n_neurons) - 1)] 68 | layers.append( 69 | ResidualBlock( 70 | in_features, 71 | out_features, 72 | activation, 73 | skip_interval=2, 74 | use_batchnorm=use_batchnorm, 75 | ) 76 | ) 77 | else: 78 | # assign a Dense layer (with activation function) to each hidden layer 79 | layers = [ 80 | Dense(n_neurons[i], n_neurons[i + 1], activation=activation, use_batchnorm=use_batchnorm, bias=bias) 81 | for i in range(n_layers - 1) 82 | ] 83 | 84 | # assign a Dense layer (without activation function) to the output layer 85 | 86 | if last_zero_init: 87 | layers.append( 88 | Dense( 89 | n_neurons[-2], 90 | n_neurons[-1], 91 | activation=None, 92 | weight_init=torch.nn.init.zeros_, 93 | bias=bias, 94 | ) 95 | ) 96 | else: 97 | layers.append( 98 | Dense(n_neurons[-2], n_neurons[-1], activation=None, bias=bias) 99 | ) 100 | # put all layers together to make the network 101 | out_net = nn.Sequential(*layers) 102 | return out_net 103 | 104 | class Dense(nn.Module): 105 | def __init__( 106 | self, 107 | in_features: int, 108 | out_features: int, 109 | bias: bool = True, 110 | activation: Union[Callable, nn.Module] = nn.Identity(), 
111 | use_batchnorm: bool = False, 112 | ): 113 | """ 114 | Fully connected linear layer with an optional activation function and batch normalization. 115 | 116 | Args: 117 | in_features (int): Number of input features. 118 | out_features (int): Number of output features. 119 | bias (bool): If False, the layer will not have a bias term. 120 | activation (Callable or nn.Module): Activation function. Defaults to Identity. 121 | use_batchnorm (bool): If True, include a batch normalization layer. 122 | """ 123 | super().__init__() 124 | self.use_batchnorm = use_batchnorm 125 | 126 | # Dense layer 127 | self.linear = nn.Linear(in_features, out_features, bias) 128 | 129 | # Activation function 130 | self.activation = activation 131 | if self.activation is None: 132 | self.activation = nn.Identity() 133 | 134 | # Batch normalization layer 135 | if self.use_batchnorm: 136 | self.batchnorm = nn.BatchNorm1d(out_features) 137 | 138 | def forward(self, input: torch.Tensor): 139 | y = self.linear(input) 140 | if self.use_batchnorm: 141 | y = self.batchnorm(y) 142 | y = self.activation(y) 143 | return y 144 | 145 | 146 | class ResidualBlock(nn.Module): 147 | """ 148 | A residual block with flexible number of dense layers, optional batch normalization, 149 | and a skip connection. 150 | 151 | Args: 152 | in_features: Number of input features. 153 | out_features: Number of output features. 154 | activation: Activation function to be used in the dense layers. 155 | skip_interval: Number of layers between each skip connection. 156 | use_batchnorm: Boolean indicating whether to use batch normalization. 157 | """ 158 | def __init__(self, in_features, out_features, activation, skip_interval=2, use_batchnorm=True): 159 | super().__init__() 160 | self.skip_interval = skip_interval 161 | self.use_batchnorm = use_batchnorm 162 | self.layers = nn.ModuleList() 163 | 164 | # Skip connection with optional dimension matching and batch normalization 165 | if in_features != out_features: 166 | skip_layers = [Dense(in_features, out_features, activation=None)] 167 | if self.use_batchnorm: 168 | skip_layers.append(nn.BatchNorm1d(out_features)) 169 | self.skip = nn.Sequential(*skip_layers) 170 | else: 171 | self.skip = nn.Identity() 172 | 173 | # Create dense layers with optional batch normalization 174 | for _ in range(skip_interval): 175 | self.layers.append(Dense(in_features, out_features, activation=activation)) 176 | if self.use_batchnorm: 177 | self.layers.append(nn.BatchNorm1d(out_features)) 178 | in_features = out_features # Update in_features for the next layer 179 | 180 | def forward(self, x): 181 | identity = self.skip(x) 182 | out = x 183 | 184 | # Forward through dense layers with skip connections and optional batch normalization 185 | for i, layer in enumerate(self.layers): 186 | out = layer(out) 187 | if (i + 1) % self.skip_interval == 0: 188 | out += identity 189 | 190 | return out 191 | 192 | class AtomicEnergiesBlock(nn.Module): 193 | def __init__(self, nz:int, trainable=True, atomic_energies: Optional[Union[np.ndarray, torch.Tensor]]=None): 194 | super().__init__() 195 | if atomic_energies is None: 196 | atomic_energies = torch.zeros(nz) 197 | else: 198 | assert len(atomic_energies.shape) == 1 199 | 200 | if trainable: 201 | self.atomic_energies = nn.Parameter(atomic_energies) 202 | else: 203 | self.register_buffer("atomic_energies", atomic_energies, torch.get_default_dtype()) 204 | 205 | def forward( 206 | self, x: torch.Tensor # one-hot of elements [..., n_elements] 207 | ) -> torch.Tensor: # [..., ] 208 
| return torch.matmul(x, self.atomic_energies) 209 | 210 | def __repr__(self): 211 | formatted_energies = ", ".join([f"{x:.4f}" for x in self.atomic_energies]) 212 | return f"{self.__class__.__name__}(energies=[{formatted_energies}])" 213 | -------------------------------------------------------------------------------- /cace/modules/cutoff.py: -------------------------------------------------------------------------------- 1 | ########################################################################################### 2 | # Radial basis cutoff 3 | # modified from mace/mace/modules/radials.py and schnetpack/src/schnetpack/nn/cutoff.py 4 | # This program is distributed under the MIT License (see MIT.md) 5 | ########################################################################################### 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | 11 | __all__ = ["CosineCutoff", "MollifierCutoff", "PolynomialCutoff", "SwitchFunction"] 12 | 13 | def cosine_cutoff(input: torch.Tensor, cutoff: torch.Tensor): 14 | """ Behler-style cosine cutoff. 15 | 16 | .. math:: 17 | f(r) = \begin{cases} 18 | 0.5 \times \left[1 + \cos\left(\frac{\pi r}{r_\text{cutoff}}\right)\right] 19 | & r < r_\text{cutoff} \\ 20 | 0 & r \geqslant r_\text{cutoff} \\ 21 | \end{cases} 22 | 23 | Args: 24 | cutoff (float, optional): cutoff radius. 25 | 26 | """ 27 | 28 | # Compute values of cutoff function 29 | input_cut = 0.5 * (torch.cos(input * np.pi / cutoff) + 1.0) 30 | # Remove contributions beyond the cutoff radius 31 | input_cut *= (input < cutoff).float() 32 | return input_cut 33 | 34 | 35 | class CosineCutoff(nn.Module): 36 | r""" Behler-style cosine cutoff module. 37 | 38 | .. math:: 39 | f(r) = \begin{cases} 40 | 0.5 \times \left[1 + \cos\left(\frac{\pi r}{r_\text{cutoff}}\right)\right] 41 | & r < r_\text{cutoff} \\ 42 | 0 & r \geqslant r_\text{cutoff} \\ 43 | \end{cases} 44 | 45 | """ 46 | 47 | def __init__(self, cutoff: float): 48 | """ 49 | Args: 50 | cutoff (float, optional): cutoff radius. 51 | """ 52 | super().__init__() 53 | self.register_buffer("cutoff", torch.tensor(cutoff, dtype=torch.get_default_dtype())) 54 | 55 | def forward(self, input: torch.Tensor): 56 | return cosine_cutoff(input, self.cutoff) 57 | 58 | def __repr__(self): 59 | return f"{self.__class__.__name__}(cutoff={self.cutoff})" 60 | 61 | def mollifier_cutoff(input: torch.Tensor, cutoff: torch.Tensor, eps: torch.Tensor): 62 | r""" Mollifier cutoff scaled to have a value of 1 at :math:`r=0`. 63 | 64 | .. math:: 65 | f(r) = \begin{cases} 66 | \exp\left(1 - \frac{1}{1 - \left(\frac{r}{r_\text{cutoff}}\right)^2}\right) 67 | & r < r_\text{cutoff} \\ 68 | 0 & r \geqslant r_\text{cutoff} \\ 69 | \end{cases} 70 | 71 | Args: 72 | cutoff: Cutoff radius. 73 | eps: Offset added to distances for numerical stability. 74 | 75 | """ 76 | mask = (input + eps < cutoff).float() 77 | exponent = 1.0 - 1.0 / (1.0 - torch.pow(input * mask / cutoff, 2)) 78 | cutoffs = torch.exp(exponent) 79 | cutoffs = cutoffs * mask 80 | return cutoffs 81 | 82 | 83 | class MollifierCutoff(nn.Module): 84 | r""" Mollifier cutoff module scaled to have a value of 1 at :math:`r=0`. 85 | 86 | .. math:: 87 | f(r) = \begin{cases} 88 | \exp\left(1 - \frac{1}{1 - \left(\frac{r}{r_\text{cutoff}}\right)^2}\right) 89 | & r < r_\text{cutoff} \\ 90 | 0 & r \geqslant r_\text{cutoff} \\ 91 | \end{cases} 92 | """ 93 | 94 | def __init__(self, cutoff: float, eps: float = 1.0e-7): 95 | """ 96 | Args: 97 | cutoff: Cutoff radius. 
98 | eps: Offset added to distances for numerical stability. 99 | """ 100 | super().__init__() 101 | self.register_buffer("cutoff", torch.tensor(cutoff, dtype=torch.get_default_dtype())) 102 | self.register_buffer("eps", torch.tensor(eps, dtype=torch.get_default_dtype())) 103 | 104 | def forward(self, input: torch.Tensor): 105 | return mollifier_cutoff(input, self.cutoff, self.eps) 106 | 107 | def __repr__(self): 108 | return f"{self.__class__.__name__}(eps={self.eps}, cutoff={self.cutoff})" 109 | 110 | def _switch_component( 111 | x: torch.Tensor, ones: torch.Tensor, zeros: torch.Tensor 112 | ) -> torch.Tensor: 113 | """ 114 | Basic component of switching functions. 115 | 116 | Args: 117 | x (torch.Tensor): Switch functions. 118 | ones (torch.Tensor): Tensor with ones. 119 | zeros (torch.Tensor): Zero tensor 120 | 121 | Returns: 122 | torch.Tensor: Output tensor. 123 | """ 124 | x_ = torch.where(x <= 0, ones, x) 125 | return torch.where(x <= 0, zeros, torch.exp(-ones / x_)) 126 | 127 | 128 | class SwitchFunction(nn.Module): 129 | """ 130 | Decays from 1 to 0 between `switch_on` and `switch_off`. 131 | """ 132 | 133 | def __init__(self, switch_on: float, switch_off: float): 134 | """ 135 | 136 | Args: 137 | switch_on (float): Onset of switch. 138 | switch_off (float): Value from which on switch is 0. 139 | """ 140 | super(SwitchFunction, self).__init__() 141 | self.register_buffer("switch_on", torch.tensor(switch_on, dtype=torch.get_default_dtype())) 142 | self.register_buffer("switch_off", torch.tensor(switch_off, dtype=torch.get_default_dtype())) 143 | 144 | def forward(self, x: torch.Tensor) -> torch.Tensor: 145 | """ 146 | 147 | Args: 148 | x (torch.Tensor): tensor to which switching function should be applied to. 149 | 150 | Returns: 151 | torch.Tensor: switch output 152 | """ 153 | x = (x - self.switch_on) / (self.switch_off - self.switch_on) 154 | 155 | ones = torch.ones_like(x) 156 | zeros = torch.zeros_like(x) 157 | fp = _switch_component(x, ones, zeros) 158 | fm = _switch_component(1 - x, ones, zeros) 159 | 160 | f_switch = torch.where(x <= 0, ones, torch.where(x >= 1, zeros, fm / (fp + fm))) 161 | return f_switch 162 | 163 | def __repr__(self): 164 | return f"{self.__class__.__name__}(switch_on={self.switch_on}, switch_off={self.switch_off})" 165 | 166 | class PolynomialCutoff(nn.Module): 167 | """ 168 | Klicpera, J.; Groß, J.; Günnemann, S. Directional Message Passing for Molecular Graphs; ICLR 2020. 
169 | Equation (8) 170 | """ 171 | 172 | p: torch.Tensor 173 | cutoff: torch.Tensor 174 | 175 | def __init__(self, cutoff: float, p=6): 176 | super().__init__() 177 | self.register_buffer("p", torch.tensor(p, dtype=torch.get_default_dtype())) 178 | self.register_buffer( 179 | "cutoff", torch.tensor(cutoff, dtype=torch.get_default_dtype()) 180 | ) 181 | 182 | def forward(self, x: torch.Tensor) -> torch.Tensor: 183 | envelope = ( 184 | 1.0 185 | - ((self.p + 1.0) * (self.p + 2.0) / 2.0) * torch.pow(x / self.cutoff, self.p) 186 | + self.p * (self.p + 2.0) * torch.pow(x / self.cutoff, self.p + 1) 187 | - (self.p * (self.p + 1.0) / 2) * torch.pow(x / self.cutoff, self.p + 2) 188 | ) 189 | return envelope * (x < self.cutoff) 190 | 191 | def __repr__(self): 192 | return f"{self.__class__.__name__}(p={self.p}, cutoff={self.cutoff})" 193 | -------------------------------------------------------------------------------- /cace/modules/feature_mix.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Dict 4 | 5 | __all__ = ['FeatureAdd', 'FeatureInteract'] 6 | 7 | class FeatureAdd(nn.Module): 8 | """ 9 | A class for adding together different features of the data. 10 | """ 11 | def __init__(self, 12 | feature_keys: list, 13 | output_key: str): 14 | super().__init__() 15 | self.feature_keys = feature_keys 16 | self.output_key = output_key 17 | self.model_outputs = [output_key] 18 | 19 | def forward(self, data: Dict[str, torch.Tensor], **kwargs) -> Dict[str, torch.Tensor]: 20 | feature_shape = data[self.feature_keys[0]].shape 21 | result = torch.zeros_like(data[self.feature_keys[0]]) 22 | for feature_key in self.feature_keys: 23 | if data[feature_key].shape != feature_shape: 24 | raise ValueError(f"Feature {feature_key} has shape {data[feature_key].shape} but expected {feature_shape}") 25 | result += data[feature_key] 26 | data[self.output_key] = result 27 | return data 28 | 29 | 30 | class FeatureInteract(nn.Module): 31 | """ 32 | A class for interacting between two multidimensional features by reshaping, performing interaction, and reshaping back. 33 | """ 34 | def __init__(self, 35 | feature1_key: str, 36 | feature2_key: str, 37 | output_key: str): 38 | super().__init__() 39 | self.feature1_key = feature1_key 40 | self.feature2_key = feature2_key 41 | self.output_key = output_key 42 | self.model_outputs = [output_key] 43 | 44 | # Weights will be initialized during the forward pass 45 | self.weights = None 46 | 47 | def forward(self, data: Dict[str, torch.Tensor], **kwargs) -> Dict[str, torch.Tensor]: 48 | feature1 = data[self.feature1_key] # Shape: [n, A, B, C] 49 | feature2 = data[self.feature2_key] # Shape: [n, D, E] 50 | 51 | # Ensure the first dimensions match 52 | if feature1.shape[0] != feature2.shape[0]: 53 | raise ValueError(f"Feature1 has shape {feature1.shape} but feature2 has shape {feature2.shape}. 
Shapes must match.") 54 | 55 | # Save the original shape for reshaping back 56 | original_shape = feature1.shape # [n, A, B, C] 57 | 58 | # Reshape both features to [n, -1] (flatten all except the first dimension) 59 | n = feature1.shape[0] # The first dimension size (n) 60 | flattened_size1 = feature1.shape[1:].numel() # Product of A, B, C 61 | flattened_size2 = feature2.shape[1:].numel() # Product of D, E 62 | feature1_reshaped = feature1.view(n, -1) # [n, A * B * C] 63 | feature2_reshaped = feature2.view(n, -1) # [n, D * E] 64 | 65 | # Dynamically initialize weights based on the reshaped feature sizes during the forward pass 66 | if self.weights is None: 67 | self.weights = nn.Parameter(torch.randn(flattened_size2, flattened_size1, dtype=torch.get_default_dtype(), device=feature1.device)) 68 | 69 | # Perform the interaction using einsum on the reshaped tensors 70 | interaction_result = feature1_reshaped * torch.einsum('ij,jk->ik', feature2_reshaped, self.weights) 71 | 72 | # Reshape the result back to the original shape [n, A, B, C] 73 | interaction_result = interaction_result.view(*original_shape) 74 | 75 | # Save the result in the data dictionary 76 | data[self.output_key] = interaction_result 77 | return data 78 | 79 | -------------------------------------------------------------------------------- /cace/modules/forces.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import torch 3 | from torch import nn 4 | 5 | from .utils import get_outputs 6 | 7 | __all__ = ['Forces'] 8 | 9 | class Forces(nn.Module): 10 | """ 11 | Predicts forces and stress as response of the energy prediction 12 | 13 | """ 14 | 15 | def __init__( 16 | self, 17 | calc_forces: bool = True, 18 | calc_stress: bool = True, 19 | #calc_virials: bool = False, 20 | energy_key: str = 'energy', 21 | forces_key: str = 'forces', 22 | stress_key: str = 'stress', 23 | virials_key: str = 'virials', 24 | ): 25 | """ 26 | Args: 27 | calc_forces: If True, calculate atomic forces. 28 | calc_stress: If True, calculate the stress tensor. 29 | energy_key: Key of the energy in results. 30 | forces_key: Key of the forces in results. 31 | stress_key: Key of the stress in results. 
32 | """ 33 | super().__init__() 34 | self.calc_forces = calc_forces 35 | self.calc_stress = calc_stress 36 | #self.calc_virials = calc_virials 37 | self.energy_key = energy_key 38 | self.forces_key = forces_key 39 | self.stress_key = stress_key 40 | self.virial_key = virials_key 41 | self.model_outputs = [] 42 | if calc_forces: 43 | self.model_outputs.append(forces_key) 44 | if calc_stress: 45 | self.model_outputs.append(stress_key) 46 | 47 | self.required_derivatives = [] 48 | if self.calc_forces or self.calc_stress: 49 | self.required_derivatives.append('positions') 50 | 51 | def forward(self, data: Dict[str, torch.Tensor], training: bool = False, output_index: int = None) -> Dict[str, torch.Tensor]: 52 | forces, virials, stress = get_outputs( 53 | energy=data[self.energy_key][:, output_index] if output_index is not None and len(data[self.energy_key].shape) == 2 else data[self.energy_key], 54 | positions=data['positions'], 55 | displacement=data.get('displacement', None), 56 | cell=data.get('cell', None), 57 | training=training, 58 | compute_force=self.calc_forces, 59 | #compute_virials=self.calc_virials, 60 | compute_stress=self.calc_stress 61 | ) 62 | 63 | data[self.forces_key] = forces 64 | if self.virial_key is not None: 65 | data[self.virial_key] = virials 66 | if self.stress_key is not None: 67 | data[self.stress_key] = stress 68 | return data 69 | 70 | def __repr__(self): 71 | return ( 72 | f"{self.__class__.__name__} (calc_forces={self.calc_forces}, calc_stress={self.calc_stress},) " 73 | ) 74 | -------------------------------------------------------------------------------- /cace/modules/grad.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import torch 3 | from torch import nn 4 | 5 | __all__ = ['Grad'] 6 | 7 | class Grad(nn.Module): 8 | """ 9 | a wrapper for the gradient calculation 10 | 11 | """ 12 | 13 | def __init__( 14 | self, 15 | y_key: str, 16 | x_key: str, 17 | output_key: str = 'gradient', 18 | ): 19 | super().__init__() 20 | self.y_key = y_key 21 | self.x_key = x_key 22 | self.output_key = output_key 23 | self.required_derivatives = [self.x_key] 24 | self.model_outputs = [self.output_key] 25 | 26 | def forward(self, data: Dict[str, torch.Tensor], training: bool = True, output_index: int = None) -> Dict[str, torch.Tensor]: 27 | y = data[self.y_key] 28 | x = data[self.x_key] 29 | 30 | if y.is_complex(): 31 | get_imag = True 32 | else: 33 | get_imag = False 34 | 35 | if len(y.shape) == 1: 36 | grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(y)] 37 | gradient_real = torch.autograd.grad( 38 | outputs=[y], # [n_graphs, ] 39 | inputs=[x], # [n_nodes, 3] 40 | grad_outputs=grad_outputs, 41 | retain_graph=(training or get_imag), # Make sure the graph is not destroyed during training 42 | create_graph=training, # Create graph for second derivative 43 | allow_unused=True, # For complete dissociation turn to true 44 | )[0] # [n_nodes, 3] 45 | 46 | if get_imag: 47 | gradient_imag = torch.autograd.grad( 48 | outputs=[y/1j], # [n_graphs, ] 49 | inputs=[x], # [n_nodes, 3] 50 | grad_outputs=grad_outputs, 51 | retain_graph=training, # Make sure the graph is not destroyed during training 52 | create_graph=training, # Create graph for second derivative 53 | allow_unused=True, # For complete dissociation turn to true 54 | )[0] # [n_nodes, 3] 55 | else: 56 | gradient_imag = 0.0 57 | else: 58 | dim_y = y.shape[1] 59 | grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(y[:,0])] 60 | gradient_real = 
torch.stack([ 61 | torch.autograd.grad( 62 | outputs=[y[:,i]], # [n_graphs, ] 63 | inputs=[x], # [n_nodes, 3] 64 | grad_outputs=grad_outputs, 65 | retain_graph=(training or (i < dim_y - 1) or get_imag), # Make sure the graph is not destroyed during training 66 | create_graph=training, # Create graph for second derivative 67 | allow_unused=True, # For complete dissociation turn to true 68 | )[0] for i in range(dim_y) 69 | ], axis=2) # [n_nodes, 3, num_energy] 70 | # if y is complex, we need to calculate the imaginary part 71 | if get_imag: 72 | gradient_imag = torch.stack([ 73 | torch.autograd.grad( 74 | outputs=[y[:,i]/1j], # [n_graphs, ] 75 | inputs=[x], # [n_nodes, 3] 76 | grad_outputs=grad_outputs, 77 | retain_graph=(training or (i < dim_y - 1)), # Make sure the graph is not destroyed during training 78 | create_graph=training, # Create graph for second derivative 79 | allow_unused=True, # For complete dissociation turn to true 80 | )[0] for i in range(dim_y) 81 | ], axis=2) # [n_nodes, 3, num_energy] 82 | if get_imag: 83 | data[self.output_key] = gradient_real + 1j * gradient_imag 84 | else: 85 | data[self.output_key] = gradient_real 86 | 87 | return data 88 | 89 | def __repr__(self): 90 | return ( 91 | f"{self.__class__.__name__} (function={self.y_key}, variable={self.x_key},) " 92 | ) 93 | -------------------------------------------------------------------------------- /cace/modules/les_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Dict, Sequence, Union 4 | 5 | __all__ = ['LesWrapper'] 6 | 7 | class LesWrapper(nn.Module): 8 | """ 9 | A wrapper for the LES library that does long-range interactions and BECs 10 | Note that CACE has its own internal implementation of the LES algorithm 11 | so it is not necessary to use this wrapper in CACE. 
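    A minimal usage sketch (illustrative only; the tensor names and shapes below are
    assumptions, while the dictionary keys follow the defaults of this class):

        wrapper = LesWrapper(feature_key='node_feats', compute_energy=True)
        data = wrapper({'node_feats': feats,      # [n_nodes, feature_dim]
                        'positions': positions,   # [n_nodes, 3]
                        'cell': cell,             # viewable as [n_graphs, 3, 3]
                        'batch': batch})          # [n_nodes]
        e_lr = data['LES_energy']
        q_latent = data['LES_charge']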
12 | """ 13 | def __init__(self, 14 | feature_key: Union[str, Sequence[int]] = 'node_feats', 15 | energy_key: str = 'LES_energy', 16 | charge_key: str = 'LES_charge', 17 | bec_key: str = 'LES_BEC', 18 | compute_energy: bool = True, 19 | compute_bec: bool = False, 20 | bec_output_index: int = None, # option to compute BEC along one axis 21 | ): 22 | super().__init__() 23 | from les import Les 24 | self.les = Les(les_arguments={}) 25 | 26 | self.feature_key = feature_key 27 | self.energy_key = energy_key 28 | self.charge_key = charge_key 29 | self.bec_key = bec_key 30 | self.bec_output_index = bec_output_index 31 | 32 | self.compute_energy = compute_energy 33 | self.compute_bec = compute_bec 34 | self.model_outputs = [charge_key] 35 | if compute_energy: 36 | self.model_outputs.append(energy_key) 37 | if compute_bec: 38 | self.model_outputs.append(bec_key) 39 | self.required_derivatives = [] 40 | self.required_derivatives.append('cell') 41 | 42 | def set_compute_energy(self, compute_energy: bool): 43 | self.compute_energy = compute_energy 44 | 45 | def set_compute_bec(self, compute_bec: bool): 46 | self.compute_bec = compute_bec 47 | 48 | def set_bec_output_index(self, bec_output_index: int): 49 | self.bec_output_index = bec_output_index 50 | 51 | def forward(self, data: Dict[str, torch.Tensor], **kwargs) -> Dict[str, torch.Tensor]: 52 | 53 | # reshape the feature vectors 54 | if isinstance(self.feature_key, str): 55 | if self.feature_key not in data: 56 | raise ValueError(f"Feature key {self.feature_key} not found in data dictionary.") 57 | features = data[self.feature_key] 58 | features = features.reshape(features.shape[0], -1) 59 | elif isinstance(self.feature_key, list): 60 | features = torch.cat([data[key].reshape(data[key].shape[0], -1) for key in self.feature_key], dim=-1) 61 | 62 | result = self.les(desc=features, 63 | positions=data['positions'], 64 | cell=data['cell'].view(-1, 3, 3), 65 | batch=data["batch"], 66 | compute_energy=self.compute_energy, 67 | compute_bec=self.compute_bec, 68 | bec_output_index=self.bec_output_index, 69 | ) 70 | 71 | data[self.charge_key] = result['latent_charges'] 72 | if self.compute_energy: 73 | data[self.energy_key] = result['E_lr'] 74 | if self.compute_bec: 75 | data[self.bec_key] = result['BEC'] 76 | return data 77 | -------------------------------------------------------------------------------- /cace/modules/polarization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Dict 4 | 5 | __all__ = ['Polarization', 'Dephase', 'FixedCharge'] 6 | 7 | class Polarization(nn.Module): 8 | def __init__(self, 9 | charge_key: str = 'q', 10 | output_key: str = 'polarization', 11 | output_index: int = None, # 0, 1, 2 to select only one component 12 | phase_key: str = 'phase', 13 | remove_mean: bool = True, 14 | pbc: bool = False, 15 | normalization_factor: float = 1./9.48933, 16 | ): 17 | super().__init__() 18 | self.charge_key = charge_key 19 | self.output_key = output_key 20 | self.output_index = output_index 21 | self.phase_key = phase_key 22 | self.model_outputs = [output_key, phase_key] 23 | self.remove_mean = remove_mean 24 | self.pbc = pbc 25 | self.normalization_factor = normalization_factor 26 | 27 | def forward(self, data: Dict[str, torch.Tensor], training=True, output_index=None) -> torch.Tensor: 28 | 29 | if data["batch"] is None: 30 | n_nodes = data['positions'].shape[0] 31 | batch_now = torch.zeros(n_nodes, dtype=torch.int64, 
device=data['positions'].device) 32 | else: 33 | batch_now = data["batch"] 34 | 35 | box = data['cell'].view(-1, 3, 3) 36 | 37 | r = data['positions'] 38 | q = data[self.charge_key] 39 | 40 | if q.dim() == 1: 41 | q = q.unsqueeze(1) 42 | if self.remove_mean: 43 | q = q - torch.mean(q, dim=0, keepdim=True) 44 | 45 | # Check the input dimension 46 | n, d = r.shape 47 | assert d == 3, 'r dimension error' 48 | assert n == q.size(0), 'q dimension error' 49 | 50 | unique_batches = torch.unique(batch_now) # Get unique batch indices 51 | 52 | results = [] 53 | phases = [] 54 | for i in unique_batches: 55 | mask = batch_now == i # Create a mask for the i-th configuration 56 | r_now, q_now, box_now = r[mask], q[mask], box[i] 57 | box_diag = box[i].diagonal(dim1=-2, dim2=-1) 58 | if box_diag[0] < 1e-6 and box_diag[1] < 1e-6 and box_diag[2] < 1e-6 or self.pbc == False: 59 | # the box is not periodic, we use the direct sum 60 | polarization = torch.sum(q_now * r_now, dim=0) 61 | elif box_diag[0] > 0 and box_diag[1] > 0 and box_diag[2] > 0: 62 | polarization, phase = self.compute_pol_pbc(r_now, q_now, box_now) 63 | if self.output_index is not None: 64 | phase = phase[:,self.output_index] 65 | phases.append(phase) 66 | if self.output_index is not None: 67 | polarization = polarization[self.output_index] 68 | results.append(polarization * self.normalization_factor) 69 | data[self.output_key] = torch.stack(results, dim=0) 70 | if len(phases) > 0: 71 | data[self.phase_key] = torch.cat(phases, dim=0) 72 | else: 73 | data[self.phase_key] = 0.0 74 | return data 75 | 76 | def compute_pol_pbc(self, r_now, q_now, box_now): 77 | r_frac = torch.matmul(r_now, torch.linalg.inv(box_now)) 78 | phase = torch.exp(1j * 2.* torch.pi * r_frac) 79 | S = torch.sum(q_now * phase, dim=0) 80 | polarization = torch.matmul(box_now.to(S.dtype), 81 | S.unsqueeze(1)) / (1j * 2.* torch.pi) 82 | return polarization.reshape(-1), phase 83 | 84 | class Dephase(nn.Module): 85 | def __init__(self, 86 | input_key: str = None, 87 | phase_key: str = 'phase', 88 | output_key: str = 'dephased', 89 | input_index: int = None, 90 | ): 91 | super().__init__() 92 | self.input_key = input_key 93 | self.input_index = input_index 94 | self.output_key = output_key 95 | self.phase_key = phase_key 96 | self.model_outputs = [output_key] 97 | 98 | def forward(self, data: Dict[str, torch.Tensor], training=None, output_index=None) -> torch.Tensor: 99 | result = data[self.input_key] * data[self.phase_key].unsqueeze(1).conj() 100 | data[self.output_key] = result.real 101 | return data 102 | 103 | class FixedCharge(nn.Module): 104 | def __init__(self, 105 | atomic_numbers_key: str = 'atomic_numbers', 106 | output_key: str = 'q', 107 | charge_dict: Dict[int, float] = None, 108 | normalize: bool = True, 109 | ): 110 | super().__init__() 111 | self.charge_dict = charge_dict 112 | self.atomic_numbers_key = atomic_numbers_key 113 | self.output_key = output_key 114 | self.normalize = normalize 115 | self.normalization_factor = 9.48933 116 | 117 | def forward(self, data: Dict[str, torch.Tensor], training=None, output_index=None) -> torch.Tensor: 118 | atomic_numbers = data[self.atomic_numbers_key] 119 | charge = torch.tensor([self.charge_dict[atomic_number.item()] for atomic_number in atomic_numbers], device=atomic_numbers.device) 120 | if self.normalize: 121 | charge = charge * self.normalization_factor # to be consistent with the internal units and the Ewald sum 122 | data[self.output_key] = charge[:,None] 123 | return data 124 | 
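# A minimal smoke-test sketch of the modules above (illustrative only: the
# geometry, point charges and non-periodic cell below are assumed values,
# not reference data).
if __name__ == '__main__':
    # Single water-like molecule, non-periodic cell, one configuration in the batch.
    data = {
        'positions': torch.tensor([[0.00, 0.00, 0.00],
                                   [0.96, 0.00, 0.00],
                                   [-0.24, 0.93, 0.00]]),
        'cell': torch.zeros(3, 3),
        'batch': torch.zeros(3, dtype=torch.int64),
        'atomic_numbers': torch.tensor([8, 1, 1]),
    }
    # Assign fixed point charges per element (values are illustrative).
    data = FixedCharge(charge_dict={8: -0.8, 1: 0.4}, normalize=False)(data)
    # The cell is non-periodic, so the direct sum q_i * r_i is used.
    data = Polarization(charge_key='q', pbc=False)(data)
    print(data['polarization'])  # -> tensor of shape [1, 3]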
-------------------------------------------------------------------------------- /cace/modules/preprocess.py: -------------------------------------------------------------------------------- 1 | # Description: Preprocess the data for the model 2 | 3 | from typing import Dict 4 | import torch 5 | from torch import nn 6 | 7 | from .utils import get_symmetric_displacement 8 | 9 | __all__ = ["Preprocess"] 10 | 11 | class Preprocess(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def forward(self, data: Dict[str, torch.Tensor], compute_stress: bool = False, compute_virials: bool = False): 16 | 17 | try: 18 | num_graphs = data["ptr"].numel() - 1 19 | except: 20 | num_graphs = 1 21 | 22 | if compute_virials or compute_stress: 23 | ( 24 | data["positions"], 25 | data["shifts"], 26 | data["displacement"], 27 | data["cell"] 28 | ) = get_symmetric_displacement( 29 | positions=data["positions"], 30 | unit_shifts=data["unit_shifts"], 31 | cell=data["cell"], 32 | edge_index=data["edge_index"], 33 | num_graphs=num_graphs, 34 | batch=data["batch"], 35 | ) 36 | else: 37 | data["displacement"] = None 38 | 39 | return data 40 | 41 | 42 | -------------------------------------------------------------------------------- /cace/modules/product_basis.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from .angular_tools import find_combo_vectors_l1l2 5 | 6 | class AngularTensorProduct(nn.Module): 7 | def __init__(self, max_l: int, l_list: list): 8 | super().__init__() 9 | self.max_l = max_l 10 | 11 | # Convert elements of l_list to tuples for dictionary keys 12 | l_list_tuples = [tuple(l) for l in l_list] 13 | # Create a dictionary to map tuple to index 14 | l_list_indices = {l_tuple: i for i, l_tuple in enumerate(l_list_tuples)} 15 | vec_dict = find_combo_vectors_l1l2(self.max_l) 16 | 17 | self._get_indices(vec_dict, l_list_indices) 18 | 19 | def _get_indices(self, vec_dict, l_list_indices): 20 | self.indice_list = [] 21 | for i, (l3, l1l2_list) in enumerate(vec_dict.items()): 22 | l3_index = l_list_indices[tuple(l3)] 23 | for item in l1l2_list: 24 | prefactor = int(item[-1]) 25 | l1l2indices = [l_list_indices[tuple(lxlylz)] for lxlylz in item[:-1]] 26 | self.indice_list.append([l3_index, l1l2indices[0], l1l2indices[1], prefactor]) 27 | self.indice_tensor = torch.tensor(self.indice_list) 28 | 29 | def forward(self, edge_attr1: torch.Tensor, edge_attr2: torch.Tensor): 30 | 31 | num_edges, n_radial, n_angular, n_chanel = edge_attr1.size() 32 | edge_attr_new = torch.zeros((num_edges, n_radial, n_angular, n_chanel), 33 | dtype=edge_attr1.dtype, device=edge_attr1.device) 34 | 35 | for item in self.indice_tensor: 36 | l3_index, l1_index, l2_index, prefactor = item[0], item[1], item[2], item[3] 37 | edge_attr_new[:, :, l3_index, :] += prefactor * edge_attr1[:, :, l1_index, :] * edge_attr2[:, :, l2_index, :] 38 | 39 | return edge_attr_new 40 | -------------------------------------------------------------------------------- /cace/modules/radial.py: -------------------------------------------------------------------------------- 1 | ########################################################################################### 2 | # Radial basis 3 | # modified from mace/mace/modules/radials.py and schnetpack/src/schnetpack/nn/radials.py 4 | # This program is distributed under the MIT License (see MIT.md) 5 | ########################################################################################### 6 | 
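# For reference, the closed forms implemented below (summarized from the code in
# this file; see the class docstrings for the original sources) are:
#   BesselRBF:            e_n(r) = sqrt(2/r_c) * sin(n*pi*r/r_c) / r,  n = 1..n_rbf
#   GaussianRBF:          e_n(r) = exp(-(r - mu_n)^2 / (2*sigma^2)),   mu_n spaced from `start` to `cutoff`
#   GaussianRBFCentered:  e_n(r) = exp(-r^2 / (2*sigma_n^2)),          sigma_n spaced from `start` to `cutoff`
#   ExponentialDecayRBF:  e_n(r) = prefactor * exp(-r / r0_n)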
7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | 11 | __all__ = ["BesselRBF", "GaussianRBF", "GaussianRBFCentered", "ExponentialDecayRBF"] 12 | 13 | class BesselRBF(nn.Module): 14 | """ 15 | Sine for radial basis functions with coulomb decay (0th order bessel). 16 | 17 | References: 18 | 19 | .. [#dimenet] Klicpera, Groß, Günnemann: 20 | Directional message passing for molecular graphs. 21 | ICLR 2020 22 | Equation (7) 23 | """ 24 | 25 | def __init__(self, cutoff: float, n_rbf=8, trainable=False): 26 | super().__init__() 27 | 28 | self.n_rbf = n_rbf 29 | 30 | bessel_weights = ( 31 | np.pi 32 | / cutoff 33 | * torch.linspace( 34 | start=1.0, 35 | end=n_rbf, 36 | steps=n_rbf, 37 | dtype=torch.get_default_dtype(), 38 | ) 39 | ) 40 | if trainable: 41 | self.bessel_weights = nn.Parameter(bessel_weights) 42 | else: 43 | self.register_buffer("bessel_weights", bessel_weights) 44 | 45 | self.register_buffer( 46 | "cutoff", torch.tensor(cutoff, dtype=torch.get_default_dtype()) 47 | ) 48 | self.register_buffer( 49 | "prefactor", 50 | torch.tensor(np.sqrt(2.0 / cutoff), dtype=torch.get_default_dtype()), 51 | ) 52 | 53 | def forward(self, x: torch.Tensor) -> torch.Tensor: # [...,1] 54 | numerator = torch.sin(self.bessel_weights * x) # [..., n_rbf] 55 | return self.prefactor * (numerator / x) # [..., n_rbf] 56 | 57 | def __repr__(self): 58 | return ( 59 | f"{self.__class__.__name__}(cutoff={self.cutoff}, n_rbf={len(self.bessel_weights)}, " 60 | f"trainable={self.bessel_weights.requires_grad})" 61 | ) 62 | 63 | 64 | class ExponentialDecayRBF(nn.Module): 65 | """Exponential decay radial basis functions. 66 | y = prefactor * exp(-x / r0) 67 | """ 68 | def __init__( 69 | self, n_rbf: int, cutoff: float, prefactor: torch.tensor=torch.tensor(1.0), trainable: bool = False 70 | ): 71 | super().__init__() 72 | self.n_rbf = n_rbf 73 | 74 | # Convert prefactor to a tensor if it's not already one 75 | if not isinstance(prefactor, torch.Tensor): 76 | prefactor = torch.tensor(prefactor, dtype=torch.get_default_dtype()) 77 | 78 | if n_rbf == 1: 79 | r0 = torch.tensor(cutoff / 2.0, dtype=torch.get_default_dtype()) 80 | else: 81 | # compute offset and width of Gaussian functions 82 | r0 = torch.linspace(0, cutoff, n_rbf + 2, dtype=torch.get_default_dtype()) [1:-1] 83 | 84 | self.register_buffer("cutoff", torch.tensor(cutoff, dtype=torch.get_default_dtype())) 85 | 86 | if trainable: 87 | self.r0 = nn.Parameter(r0) 88 | self.prefactor = nn.Parameter(prefactor) 89 | else: 90 | self.register_buffer("r0", r0) 91 | self.register_buffer("prefactor", torch.tensor(prefactor, dtype=torch.get_default_dtype())) 92 | 93 | def forward(self, inputs: torch.Tensor): 94 | return self.prefactor * torch.exp(-inputs / self.r0) 95 | 96 | def __repr__(self): 97 | return ( 98 | f"{self.__class__.__name__}(prefactor={self.prefactor}, r0={self.r0}," 99 | f"trainable={self.r0.requires_grad})" 100 | ) 101 | 102 | def gaussian_rbf(inputs: torch.Tensor, offsets: torch.Tensor, widths: torch.Tensor): 103 | coeff = -0.5 / torch.pow(widths, 2) 104 | diff = inputs - offsets 105 | y = torch.exp(coeff * torch.pow(diff, 2)) 106 | return y 107 | 108 | 109 | class GaussianRBF(nn.Module): 110 | r"""Gaussian radial basis functions.""" 111 | 112 | def __init__( 113 | self, n_rbf: int, cutoff: float, start: float = 0.8, trainable: bool = False 114 | ): 115 | """ 116 | Args: 117 | n_rbf: total number of Gaussian functions, :math:`N_g`. 
118 | cutoff: center of last Gaussian function, :math:`\mu_{N_g}` 119 | start: center of first Gaussian function, :math:`\mu_0`. 120 | trainable: If True, widths and offset of Gaussian functions 121 | are adjusted during training process. 122 | """ 123 | super().__init__() 124 | self.n_rbf = n_rbf 125 | 126 | # compute offset and width of Gaussian functions 127 | offset = torch.linspace(start, cutoff, n_rbf, dtype=torch.get_default_dtype()) 128 | widths = torch.FloatTensor( 129 | torch.abs(offset[1] - offset[0]) * torch.ones_like(offset) 130 | ) 131 | 132 | self.register_buffer( 133 | "cutoff", torch.tensor(cutoff, dtype=torch.get_default_dtype()) 134 | ) 135 | 136 | if trainable: 137 | self.widths = nn.Parameter(widths) 138 | self.offsets = nn.Parameter(offset) 139 | else: 140 | self.register_buffer("widths", widths) 141 | self.register_buffer("offsets", offset) 142 | 143 | def forward(self, inputs: torch.Tensor): 144 | return gaussian_rbf(inputs, self.offsets, self.widths) 145 | 146 | def __repr__(self): 147 | return ( 148 | f"{self.__class__.__name__}(cutoff={self.cutoff}, n_rbf={self.n_rbf}, " 149 | f"trainable={self.widths.requires_grad})" 150 | ) 151 | 152 | class GaussianRBFCentered(nn.Module): 153 | r"""Gaussian radial basis functions centered at the origin.""" 154 | 155 | def __init__( 156 | self, n_rbf: int, cutoff: float, start: float = 1.0, trainable: bool = False 157 | ): 158 | """ 159 | Args: 160 | n_rbf: total number of Gaussian functions, :math:`N_g`. 161 | cutoff: width of last Gaussian function, :math:`\mu_{N_g}` 162 | start: width of first Gaussian function, :math:`\mu_0`. 163 | trainable: If True, widths of Gaussian functions 164 | are adjusted during training process. 165 | """ 166 | super().__init__() 167 | self.n_rbf = n_rbf 168 | 169 | # compute offset and width of Gaussian functions 170 | widths = torch.linspace(start, cutoff, n_rbf, dtype=torch.get_default_dtype()) 171 | offset = torch.zeros_like(widths) 172 | 173 | self.register_buffer( 174 | "cutoff", torch.tensor(cutoff, dtype=torch.get_default_dtype()) 175 | ) 176 | 177 | if trainable: 178 | self.widths = nn.Parameter(widths) 179 | self.offsets = nn.Parameter(offset) 180 | else: 181 | self.register_buffer("widths", widths) 182 | self.register_buffer("offsets", offset) 183 | 184 | def forward(self, inputs: torch.Tensor): 185 | return gaussian_rbf(inputs, self.offsets, self.widths) 186 | 187 | def __repr__(self): 188 | return ( 189 | f"{self.__class__.__name__}(cutoff={self.cutoff}, n_rbf={self.n_rbf}, " 190 | f"trainable={self.widths.requires_grad})" 191 | ) 192 | -------------------------------------------------------------------------------- /cace/modules/radial_transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from typing import Optional, List 5 | 6 | class SharedRadialLinearTransform(nn.Module): 7 | # TODO: this can be jitted, however, this causes trouble in saving the model 8 | def __init__(self, max_l: int, radial_dim: int, radial_embedding_dim: Optional[int] = None, channel_dim: Optional[int] = None): 9 | super().__init__() 10 | self.max_l = max_l 11 | self.radial_dim = radial_dim 12 | self.radial_embedding_dim = radial_embedding_dim or radial_dim 13 | self.channel_dim = channel_dim 14 | self.register_buffer('angular_dim_groups', torch.tensor(self._init_angular_dim_groups(max_l), dtype=torch.int64)) 15 | self.weights = self._initialize_weights(radial_dim, self.radial_embedding_dim, 
channel_dim) 16 | 17 | def __getstate__(self): 18 | # Return a dictionary of state items to be serialized. 19 | state = self.__dict__.copy() 20 | # Modify the state dictionary as needed, or return as is. 21 | return state 22 | 23 | def __setstate__(self, state): 24 | # Restore the state. 25 | self.__dict__.update(state) 26 | 27 | def _initialize_weights(self, radial_dim: int, radial_embedding_dim: int, channel_dim: int) -> nn.ParameterList: 28 | torch.manual_seed(0) 29 | # TODO: try other initialization 30 | if channel_dim is not None: 31 | return nn.ParameterList([ 32 | nn.Parameter(torch.rand([radial_dim, radial_embedding_dim, channel_dim])) for _ in self.angular_dim_groups 33 | ]) 34 | else: 35 | return nn.ParameterList([ 36 | nn.Parameter(torch.rand([radial_dim, radial_embedding_dim])) for _ in self.angular_dim_groups 37 | ]) 38 | 39 | def forward(self, x: torch.Tensor) -> torch.Tensor: 40 | 41 | n_nodes, radial_dim, angular_dim, embedding_dim = x.shape 42 | 43 | output = torch.zeros(n_nodes, self.radial_embedding_dim, angular_dim, embedding_dim, 44 | device=x.device, dtype=x.dtype) 45 | 46 | for index, weight in enumerate(self.weights): 47 | i_start = self.angular_dim_groups[index, 0] 48 | i_end = self.angular_dim_groups[index, 1] 49 | group = torch.arange(i_start, i_end) 50 | # Gather all angular dimensions for the current group 51 | group_x = x[:, :, group, :] # Shape: [n_nodes, radial_dim, len(group), embedding_dim] 52 | # Apply the transformation for the entire group at once 53 | if self.channel_dim: 54 | transformed_group = torch.einsum('ijkh,jmh->imkh', group_x, weight) 55 | else: 56 | transformed_group = torch.einsum('ijkh,jm->imkh', group_x, weight) 57 | # Assign to the output tensor for each angular dimension 58 | output[:, :, group, :] = transformed_group 59 | return output 60 | 61 | def _compute_length_lxlylz(self, l): 62 | return int((l+1)*(l+2)/2) 63 | 64 | def _init_angular_dim_groups(self, max_l): 65 | angular_dim_groups: List[int] = [] 66 | l_now = 0 67 | for l in range(max_l+1): 68 | l_list_atl = [l_now, l_now + self._compute_length_lxlylz(l)] 69 | angular_dim_groups.append(l_list_atl) 70 | l_now += self._compute_length_lxlylz(l) 71 | return angular_dim_groups 72 | -------------------------------------------------------------------------------- /cace/modules/transform.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Dict 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | __all__ = [ 7 | "Transform", 8 | "TransformException", 9 | ] 10 | 11 | 12 | class TransformException(Exception): 13 | pass 14 | 15 | 16 | class Transform(nn.Module): 17 | """ 18 | Base class for all transforms. 19 | The base class ensures that the reference to the data and datamodule attributes are 20 | initialized. 21 | Transforms can be used as pre- or post-processing layers. 22 | They can also be used for other parts of a model, that need to be 23 | initialized based on data. 24 | 25 | To implement a new transform, override the forward method. Preprocessors are applied 26 | to single examples, while postprocessors operate on batches. All transforms should 27 | return a modified `inputs` dictionary. 28 | 29 | """ 30 | 31 | def datamodule(self, value): 32 | """ 33 | Extract all required information from data module automatically when using 34 | PyTorch Lightning integration. The transform should also implement a way to 35 | set these things manually, to make it usable independent of PL. 
36 | 37 | Do not store the datamodule, as this does not work with torchscript conversion! 38 | """ 39 | pass 40 | 41 | def forward( 42 | self, 43 | inputs: Dict[str, torch.Tensor], 44 | ) -> Dict[str, torch.Tensor]: 45 | raise NotImplementedError 46 | 47 | def teardown(self): 48 | pass 49 | -------------------------------------------------------------------------------- /cace/modules/utils.py: -------------------------------------------------------------------------------- 1 | ########################################################################################### 2 | # Utilities 3 | # modified from MACE 4 | # This program is distributed under the MIT License (see MIT.md) 5 | ########################################################################################### 6 | 7 | from typing import List, Optional, Tuple 8 | 9 | import torch 10 | 11 | __all__ = ["get_outputs", "get_edge_vectors_and_lengths", "get_edge_node_type", "get_symmetric_displacement"] 12 | 13 | def compute_forces( 14 | energy: torch.Tensor, positions: torch.Tensor, training: bool = False 15 | ) -> torch.Tensor: 16 | # check the dimension of the energy tensor 17 | if len(energy.shape) == 1: 18 | grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(energy)] 19 | gradient = torch.autograd.grad( 20 | outputs=[energy], # [n_graphs, ] 21 | inputs=[positions], # [n_nodes, 3] 22 | grad_outputs=grad_outputs, 23 | retain_graph=training, # Make sure the graph is not destroyed during training 24 | create_graph=training, # Create graph for second derivative 25 | allow_unused=True, # For complete dissociation turn to true 26 | )[0] # [n_nodes, 3] 27 | else: 28 | num_energy = energy.shape[1] 29 | grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(energy[:,0])] 30 | gradient = torch.stack([ 31 | torch.autograd.grad( 32 | outputs=[energy[:,i]], # [n_graphs, ] 33 | inputs=[positions], # [n_nodes, 3] 34 | grad_outputs=grad_outputs, 35 | retain_graph=(training or (i < num_energy - 1)), # Make sure the graph is not destroyed during training 36 | create_graph=(training or (i < num_energy - 1)), # Create graph for second derivative 37 | allow_unused=True, # For complete dissociation turn to true 38 | )[0] for i in range(num_energy) 39 | ], axis=2) # [n_nodes, 3, num_energy] 40 | 41 | if gradient is None: 42 | return torch.zeros_like(positions) 43 | return -1 * gradient 44 | 45 | 46 | def compute_forces_virials( 47 | energy: torch.Tensor, 48 | positions: torch.Tensor, 49 | displacement: torch.Tensor, 50 | cell: torch.Tensor, 51 | training: bool = False, 52 | compute_stress: bool = False, 53 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: 54 | # check the dimension of the energy tensor 55 | if len(energy.shape) == 1: 56 | grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(energy)] 57 | gradient, virials = torch.autograd.grad( 58 | outputs=[energy], # [n_graphs, ] 59 | inputs=[positions, displacement], # [n_nodes, 3] 60 | grad_outputs=grad_outputs, 61 | retain_graph=training, # Make sure the graph is not destroyed during training 62 | create_graph=training, # Create graph for second derivative 63 | allow_unused=True, 64 | ) 65 | stress = torch.zeros_like(displacement) 66 | if compute_stress and virials is not None: 67 | cell = cell.view(-1, 3, 3) 68 | volume = torch.einsum( 69 | "zi,zi->z", 70 | cell[:, 0, :], 71 | torch.cross(cell[:, 1, :], cell[:, 2, :], dim=1), 72 | ).unsqueeze(-1) 73 | stress = virials / volume.view(-1, 1, 1) 74 | else: 75 | num_energy = energy.shape[1] 76 | 
grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(energy[:,0])] 77 | gradient_list, virials_list, stress_list = [], [], [] 78 | for i in range(num_energy): 79 | gradient, virials = torch.autograd.grad( 80 | outputs=[energy[:,i]], # [n_graphs, ] 81 | inputs=[positions, displacement], # [n_nodes, 3] 82 | grad_outputs=grad_outputs, 83 | retain_graph=(training or (i < num_energy - 1)), # Make sure the graph is not destroyed during training 84 | create_graph=(training or (i < num_energy - 1)), # Create graph for second derivative 85 | allow_unused=True, 86 | ) 87 | stress = torch.zeros_like(displacement) 88 | if compute_stress and virials is not None: 89 | cell = cell.view(-1, 3, 3) 90 | volume = torch.einsum( 91 | "zi,zi->z", 92 | cell[:, 0, :], 93 | torch.cross(cell[:, 1, :], cell[:, 2, :], dim=1), 94 | ).unsqueeze(-1) 95 | stress = virials / volume.view(-1, 1, 1) 96 | gradient_list.append(gradient) 97 | virials_list.append(virials) 98 | stress_list.append(stress) 99 | gradient = torch.stack(gradient_list, axis=2) 100 | virials = torch.stack(virials_list, axis=-1) 101 | stress = torch.stack(stress_list, axis=-1) 102 | 103 | if gradient is None: 104 | gradient = torch.zeros_like(positions) 105 | if virials is None: 106 | virials = torch.zeros((1, 3, 3)) 107 | 108 | return -1 * gradient, -1 * virials, stress 109 | 110 | def get_symmetric_displacement( 111 | positions: torch.Tensor, 112 | unit_shifts: torch.Tensor, 113 | cell: Optional[torch.Tensor], 114 | edge_index: torch.Tensor, 115 | num_graphs: int, 116 | batch: torch.Tensor, 117 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 118 | if cell is None: 119 | cell = torch.zeros( 120 | num_graphs * 3, 121 | 3, 122 | dtype=positions.dtype, 123 | device=positions.device, 124 | ) 125 | sender = edge_index[0] 126 | displacement = torch.zeros( 127 | (num_graphs, 3, 3), 128 | dtype=positions.dtype, 129 | device=positions.device, 130 | ) 131 | displacement.requires_grad_(True) 132 | symmetric_displacement = 0.5 * ( 133 | displacement + displacement.transpose(-1, -2) 134 | ) # From https://github.com/mir-group/nequip 135 | positions = positions + torch.einsum( 136 | "be,bec->bc", positions, symmetric_displacement[batch] 137 | ) 138 | cell = cell.view(-1, 3, 3) 139 | cell = cell + torch.matmul(cell, symmetric_displacement) 140 | shifts = torch.einsum( 141 | "be,bec->bc", 142 | unit_shifts, 143 | cell[batch[sender]], 144 | ) 145 | return positions, shifts, displacement, cell 146 | 147 | def get_outputs( 148 | energy: torch.Tensor, 149 | positions: torch.Tensor, 150 | displacement: Optional[torch.Tensor] = None, 151 | cell: Optional[torch.Tensor] = None, 152 | training: bool = False, 153 | compute_force: bool = True, 154 | compute_virials: bool = True, 155 | compute_stress: bool = True, 156 | ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: 157 | if (compute_virials or compute_stress) and displacement is not None: 158 | # forces come for free 159 | forces, virials, stress = compute_forces_virials( 160 | energy=energy, 161 | positions=positions, 162 | displacement=displacement, 163 | cell=cell, 164 | compute_stress=compute_stress, 165 | training=training, 166 | ) 167 | elif compute_force: 168 | forces, virials, stress = ( 169 | compute_forces(energy=energy, positions=positions, training=training), 170 | None, 171 | None, 172 | ) 173 | else: 174 | forces, virials, stress = (None, None, None) 175 | return forces, virials, stress 176 | 177 | def get_edge_vectors_and_lengths( 178 | positions: torch.Tensor, 
# [n_nodes, 3] 179 | edge_index: torch.Tensor, # [2, n_edges] 180 | shifts: torch.Tensor, # [n_edges, 3] 181 | normalize: bool = False, 182 | eps: float = 1e-9, 183 | ) -> Tuple[torch.Tensor, torch.Tensor]: 184 | sender = edge_index[0] 185 | receiver = edge_index[1] 186 | vectors = positions[receiver] - positions[sender] + shifts # [n_edges, 3] 187 | lengths = torch.linalg.norm(vectors, dim=-1, keepdim=True) # [n_edges, 1] 188 | if normalize: 189 | vectors_normed = vectors / (lengths + eps) 190 | return vectors_normed, lengths 191 | 192 | return vectors, lengths 193 | 194 | def get_edge_node_type( 195 | edge_index: torch.Tensor, # [2, n_edges] 196 | node_type: torch.Tensor, # [n_nodes, n_dims] 197 | node_type_2: torch.Tensor=None, # [n_nodes, n_dims] 198 | ) -> Tuple[torch.Tensor, torch.Tensor]: 199 | if node_type_2 is None: 200 | node_type_2 = node_type 201 | 202 | edge_type = torch.zeros([edge_index.shape[1], 2, node_type.shape[1]], 203 | dtype=node_type.dtype, device=node_type.device) 204 | sender_type = node_type[edge_index[0]] 205 | receiver_type = node_type_2[edge_index[1]] 206 | return sender_type, receiver_type # [n_edges, n_dims] 207 | 208 | -------------------------------------------------------------------------------- /cace/representations/__init__.py: -------------------------------------------------------------------------------- 1 | from .cace_representation import * 2 | -------------------------------------------------------------------------------- /cace/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from .load_data import * 2 | 3 | from .loss import * 4 | 5 | from .train import * 6 | 7 | from .evaluate import * 8 | 9 | from .lightning import * 10 | 11 | 12 | -------------------------------------------------------------------------------- /cace/tasks/load_data.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import logging 3 | from typing import Dict, List, Optional, Tuple, Sequence 4 | import numpy as np 5 | from ase import Atoms 6 | from ase.io import read 7 | from ..tools import torch_geometric 8 | from ..data import AtomicData 9 | 10 | __all__ = ["load_data_loader", "get_dataset_from_xyz", "random_train_valid_split"] 11 | 12 | @dataclasses.dataclass 13 | class SubsetAtoms: 14 | train: Atoms 15 | valid: Atoms 16 | test: Atoms 17 | cutoff: float 18 | data_key: Dict 19 | atomic_energies: Dict 20 | 21 | def load_data_loader( 22 | collection: SubsetAtoms, 23 | data_type: str, # ['train', 'valid', 'test'] 24 | batch_size: int, 25 | ): 26 | 27 | allowed_types = ['train', 'valid', 'test'] 28 | if data_type not in allowed_types: 29 | raise ValueError(f"Input value must be one of {allowed_types}, got {data_type}") 30 | 31 | cutoff = collection.cutoff 32 | data_key = collection.data_key 33 | atomic_energies = collection.atomic_energies 34 | 35 | if data_type == 'train': 36 | loader = torch_geometric.DataLoader( 37 | dataset=[ 38 | AtomicData.from_atoms(atoms, cutoff=cutoff, data_key=data_key, atomic_energies=atomic_energies) 39 | for atoms in collection.train 40 | ], 41 | batch_size=batch_size, 42 | shuffle=True, 43 | drop_last=True, 44 | ) 45 | elif data_type == 'valid': 46 | loader = torch_geometric.DataLoader( 47 | dataset=[ 48 | AtomicData.from_atoms(atoms, cutoff=cutoff, data_key=data_key, atomic_energies=atomic_energies) 49 | for atoms in collection.valid 50 | ], 51 | batch_size=batch_size, 52 | shuffle=False, 53 | drop_last=False, 54 | ) 55 | elif 
data_type == 'test': 56 | loader = torch_geometric.DataLoader( 57 | dataset=[ 58 | AtomicData.from_atoms(atoms, cutoff=cutoff, data_key=data_key, atomic_energies=atomic_energies) 59 | for atoms in collection.test 60 | ], 61 | batch_size=batch_size, 62 | shuffle=False, 63 | drop_last=False, 64 | ) 65 | return loader 66 | 67 | def get_dataset_from_xyz( 68 | train_path: str, 69 | cutoff: float, 70 | valid_path: str = None, 71 | valid_fraction: float = 0.1, 72 | test_path: str = None, 73 | seed: int = 1234, 74 | data_key: Dict[str, str] = None, 75 | atomic_energies: Dict[int, float] = None 76 | ) -> SubsetAtoms: 77 | """Load training and test dataset from xyz file""" 78 | all_train_configs = read(train_path, ":") 79 | if not isinstance(all_train_configs, list): 80 | all_train_configs = [all_train_configs] 81 | logging.info( 82 | f"Loaded {len(all_train_configs)} training configurations from '{train_path}'" 83 | ) 84 | if valid_path is not None: 85 | valid_configs = read(valid_path, ":") 86 | if not isinstance(valid_configs, list): 87 | valid_configs = [valid_configs] 88 | logging.info( 89 | f"Loaded {len(valid_configs)} validation configurations from '{valid_path}'" 90 | ) 91 | train_configs = all_train_configs 92 | else: 93 | logging.info( 94 | "Using random %s%% of training set for validation", 100 * valid_fraction 95 | ) 96 | train_configs, valid_configs = random_train_valid_split( 97 | all_train_configs, valid_fraction, seed 98 | ) 99 | 100 | test_configs = [] 101 | if test_path is not None: 102 | test_configs = read(test_path, ":") 103 | if not isinstance(test_configs, list): 104 | test_configs = [test_configs] 105 | logging.info( 106 | f"Loaded {len(test_configs)} test configurations from '{test_path}'" 107 | ) 108 | return ( 109 | SubsetAtoms(train=train_configs, valid=valid_configs, test=test_configs, cutoff=cutoff, data_key=data_key, atomic_energies=atomic_energies) 110 | ) 111 | 112 | def random_train_valid_split( 113 | items: Sequence, valid_fraction: float, seed: int 114 | ) -> Tuple[List, List]: 115 | assert 0.0 < valid_fraction < 1.0 116 | 117 | size = len(items) 118 | train_size = size - int(valid_fraction * size) 119 | 120 | indices = list(range(size)) 121 | rng = np.random.default_rng(seed) 122 | rng.shuffle(indices) 123 | 124 | return ( 125 | [items[i] for i in indices[:train_size]], 126 | [items[i] for i in indices[train_size:]], 127 | ) 128 | -------------------------------------------------------------------------------- /cace/tasks/loss.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Dict, Union, Callable 2 | import torch 3 | import torch.nn as nn 4 | 5 | __all__ = ["GetLoss", "GetRegularizationLoss", "GetVarianceLoss"] 6 | 7 | class GetLoss(nn.Module): 8 | """ 9 | Defines mappings to a loss function and weight for training 10 | """ 11 | 12 | def __init__( 13 | self, 14 | target_name: str, 15 | predict_name: Optional[str] = None, 16 | output_index: Optional[int] = None, # only used for multi-output models 17 | name: Optional[str] = None, 18 | loss_fn: Optional[nn.Module] = None, 19 | loss_weight: Union[float, Callable] = 1.0, # Union[float, Callable] means that the type can be either float or callable 20 | ): 21 | """ 22 | Args: 23 | target_name: Name of target in training batch. 
24 | name: name of the loss object 25 | loss_fn: function to compute the loss 26 | loss_weight: loss weight in the composite loss: $l = w_1 l_1 + \dots + w_n l_n$ 27 | This can be a float or a callable that takes in the loss_weight_args 28 | For example, if we want the loss weight to be dependent on the epoch number 29 | if training == True and a default value of 1.0 otherwise, 30 | loss_weight can be, e.g., lambda training, epoch: 1.0 if not training else epoch / 100 31 | """ 32 | super().__init__() 33 | self.target_name = target_name 34 | self.predict_name = predict_name or target_name 35 | self.output_index = output_index 36 | self.name = name or target_name 37 | self.loss_fn = loss_fn 38 | # the loss_weight can either be a float or a callable 39 | self.loss_weight = loss_weight 40 | 41 | def forward(self, 42 | pred: Dict[str, torch.Tensor], 43 | target: Optional[Dict[str, torch.Tensor]] = None, 44 | loss_args: Optional[Dict[str, torch.Tensor]] = None 45 | ): 46 | if self.loss_weight == 0 or self.loss_fn is None: 47 | return 0.0 48 | 49 | if isinstance(self.loss_weight, Callable): 50 | if loss_args is None: 51 | loss_weight = self.loss_weight() 52 | else: 53 | loss_weight = self.loss_weight(**loss_args) 54 | else: 55 | loss_weight = self.loss_weight 56 | 57 | if self.output_index is None: 58 | pred_tensor = pred[self.predict_name] 59 | else: 60 | pred_tensor = pred[self.predict_name][..., self.output_index] 61 | 62 | if pred_tensor.shape != target[self.target_name].shape: 63 | pred_tensor = pred_tensor.reshape(target[self.target_name].shape) 64 | 65 | if target is not None: 66 | target_tensor = target[self.target_name] 67 | elif self.predict_name != self.target_name: 68 | target_tensor = pred[self.target_name] 69 | else: 70 | raise ValueError("Target is None and predict_name is not equal to target_name") 71 | 72 | loss = loss_weight * self.loss_fn(pred_tensor, target_tensor) 73 | return loss 74 | 75 | def __repr__(self): 76 | return ( 77 | f"{self.__class__.__name__}(name={self.name}, loss_fn={self.loss_fn}, loss_weight={self.loss_weight})" 78 | ) 79 | 80 | class GetRegularizationLoss(nn.Module): 81 | def __init__(self, loss_weight, model): 82 | super().__init__() 83 | self.loss_weight = loss_weight 84 | self.model = model 85 | 86 | def forward(self, 87 | *args, 88 | ): 89 | regularization_loss = 0.0 90 | for param in self.model.parameters(): 91 | if param.requires_grad: 92 | regularization_loss += torch.norm(param, p=2) 93 | regularization_loss *= self.loss_weight 94 | return regularization_loss 95 | 96 | class GetVarianceLoss(nn.Module): 97 | def __init__( 98 | self, 99 | target_name: str, 100 | loss_weight: float = 1.0, 101 | name: Optional[str] = None, 102 | ): 103 | super().__init__() 104 | self.target_name = target_name 105 | self.loss_weight = loss_weight 106 | self.name = name 107 | 108 | def forward(self, pred: Dict[str, torch.Tensor], *args): 109 | 110 | # Compute the variance along the first dimension (across different entries) 111 | variances = torch.var(pred[self.target_name], dim=0) 112 | 113 | # Calculate the mean of the variances 114 | mean_variance = torch.mean(variances) 115 | mean_variance = mean_variance * self.loss_weight 116 | 117 | return mean_variance 118 | -------------------------------------------------------------------------------- /cace/tools/__init__.py: -------------------------------------------------------------------------------- 1 | from .torch_tools import ( 2 | elementwise_multiply_2tensors, 3 | elementwise_multiply_3tensors, 4 | to_numpy, 5 | 
voigt_to_matrix, 6 | init_device, 7 | tensor_dict_to_device, 8 | ) 9 | 10 | #from .slurm_distributed import * 11 | 12 | from .scatter import scatter_sum 13 | 14 | from .metric import * 15 | 16 | from .utils import ( 17 | compute_avg_num_neighbors, 18 | setup_logger, 19 | get_unique_atomic_number, 20 | compute_average_E0s 21 | ) 22 | 23 | from .output import batch_to_atoms 24 | 25 | from .parser_train import * 26 | 27 | from .io_utils import * 28 | -------------------------------------------------------------------------------- /cace/tools/io_utils.py: -------------------------------------------------------------------------------- 1 | # Description: Utility functions for saving and loading datasets 2 | 3 | #import h5py 4 | import numpy as np 5 | import torch 6 | 7 | __all__ = ['tensor_to_numpy', 'numpy_to_tensor', 'save_dataset', 'load_dataset'] 8 | 9 | # Function to convert tensors to numpy arrays 10 | def tensor_to_numpy(data): 11 | if isinstance(data, torch.Tensor): 12 | return data.numpy() 13 | elif isinstance(data, dict): 14 | return {k: tensor_to_numpy(v) for k, v in data.items()} 15 | elif isinstance(data, list): 16 | return [tensor_to_numpy(item) for item in data] 17 | else: 18 | return data 19 | 20 | # Function to convert numpy arrays back to tensors 21 | def numpy_to_tensor(data): 22 | if isinstance(data, np.ndarray): 23 | return torch.tensor(data) 24 | elif isinstance(data, dict): 25 | return {k: numpy_to_tensor(v) for k, v in data.items()} 26 | elif isinstance(data, list): 27 | return [numpy_to_tensor(item) for item in data] 28 | else: 29 | return data 30 | 31 | # Function to save the dataset to an HDF5 file 32 | def save_dataset(data, filename, shuffle=False): 33 | import h5py 34 | if shuffle: 35 | index = np.random.permutation(len(data)) # Shuffle the data 36 | else: 37 | index = np.arange(len(data)) 38 | with h5py.File(filename, 'w') as f: 39 | for i, index_now in enumerate(index): 40 | item = data[index_now] 41 | grp = f.create_group(str(i)) 42 | serializable_item = tensor_to_numpy(item) 43 | for k, v in serializable_item.items(): 44 | grp.create_dataset(k, data=v) 45 | print(f"Saved dataset with {len(data)} records to {filename}") 46 | 47 | def load_dataset(filename): 48 | import h5py 49 | all_data = [] 50 | with h5py.File(filename, 'r') as f: 51 | for key in f.keys(): 52 | grp = f[key] 53 | item = {} 54 | for k in grp.keys(): 55 | data = grp[k] 56 | if data.shape == (): # Check if the data is a scalar 57 | item[k] = torch.tensor(data[()]) # Access scalar value 58 | else: 59 | item[k] = torch.tensor(data[:]) # Access array value 60 | all_data.append(numpy_to_tensor(item)) 61 | print(f"Loaded dataset with {len(all_data)} records from {filename}") 62 | return all_data 63 | -------------------------------------------------------------------------------- /cace/tools/metric.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | import torch.nn as nn 4 | from typing import Optional, Dict, List 5 | 6 | from .torch_tools import to_numpy 7 | 8 | __all__ = ['Metrics', 'compute_loss_metrics'] 9 | 10 | def compute_loss_metrics(metric: str, y_true: torch.Tensor, y_pred: torch.Tensor): 11 | """ 12 | Compute the loss metrics 13 | current options: mse, mae, rmse, r2 14 | """ 15 | if metric == 'mse': 16 | return torch.mean((y_true - y_pred) ** 2) 17 | elif metric == 'mae': 18 | return torch.mean(torch.abs(y_true - y_pred)) 19 | elif metric == 'rmse': 20 | return torch.sqrt(torch.mean((y_true - y_pred) ** 2)) 21 | elif 
metric == 'r2': 22 | return 1 - torch.sum((y_true - y_pred) ** 2) / torch.sum((y_true - torch.mean(y_true)) ** 2) 23 | else: 24 | raise ValueError('Metric not implemented') 25 | 26 | class Metrics(nn.Module): 27 | """ 28 | Defines and calculate metrics to be logged. 29 | """ 30 | 31 | def __init__( 32 | self, 33 | target_name: str, 34 | predict_name: Optional[str] = None, 35 | output_index: Optional[int] = None, # used for multi-task learning 36 | name: Optional[str] = None, 37 | metric_keys: List[str] = ["mae", "rmse"], 38 | per_atom: bool = False, 39 | ): 40 | """ 41 | Args: 42 | target_name: name of the target in the dataset 43 | predict_name: name of the prediction in the model output 44 | name: name of the metrics 45 | metric_keys: list of metrics to be calculated 46 | per_atom: whether to calculate the metrics per atom 47 | """ 48 | 49 | super().__init__() 50 | self.target_name = target_name 51 | self.predict_name = predict_name or target_name 52 | self.output_index = output_index 53 | self.name = name or target_name 54 | 55 | self.per_atom = per_atom 56 | 57 | self.metric_keys = metric_keys 58 | self.logs = { 59 | "train": {'pred': [], 'target': []}, 60 | "val": {'pred': [], 'target': []}, 61 | "test": {'pred': [], 'target': []}, 62 | } 63 | 64 | def _collect_tensor(self, 65 | pred: Dict[str, torch.Tensor], 66 | target: Optional[Dict[str, torch.Tensor]] = None, 67 | ): 68 | pred_tensor = pred[self.predict_name].clone().detach() 69 | if len(pred_tensor.shape) > 2: 70 | pred_tensor = pred_tensor.reshape(pred_tensor.shape[0], -1) 71 | #print("pred_tensor", pred_tensor.shape) 72 | if self.output_index is not None: 73 | pred_tensor = pred_tensor[..., self.output_index] 74 | if target is not None: 75 | target_tensor = target[self.target_name].clone().detach() 76 | elif self.predict_name != self.target_name: 77 | target_tensor = pred[self.target_name].clone().detach() 78 | else: 79 | raise ValueError("Target is None and predict_name is not equal to target_name") 80 | #print("target_tensor:", target_tensor.shape) 81 | if self.per_atom: 82 | n_atoms = torch.bincount(target['batch']).clone().detach() 83 | pred_tensor = pred_tensor / n_atoms 84 | target_tensor = target_tensor / n_atoms 85 | return pred_tensor, target_tensor 86 | 87 | def forward(self, 88 | pred: Dict[str, torch.Tensor], 89 | target: Optional[Dict[str, torch.Tensor]] = None, 90 | ): 91 | pred_tensor, target_tensor = self._collect_tensor(pred, target) 92 | metrics_now = {} 93 | for metric in self.metric_keys: 94 | metrics_now[metric] = compute_loss_metrics(metric, target_tensor, pred_tensor) 95 | 96 | return metrics_now 97 | 98 | def update_metrics(self, subset: str, 99 | pred: Dict[str, torch.Tensor], 100 | target: Optional[Dict[str, torch.Tensor]] = None, 101 | ): 102 | pred_tensor, target_tensor = self._collect_tensor(pred, target) 103 | self.logs[subset]['pred'].append(pred_tensor) 104 | self.logs[subset]['target'].append(target_tensor) 105 | 106 | def retrieve_metrics(self, subset: str, clear: bool = True, print_log: bool = True): 107 | pred_tensor = torch.cat(self.logs[subset]['pred'], dim=0) 108 | target_tensor = torch.cat(self.logs[subset]['target'], dim=0) 109 | 110 | assert pred_tensor.shape == target_tensor.shape, f"pred_tensor.shape: {pred_tensor.shape}, target_tensor.shape: {target_tensor.shape}" 111 | 112 | if pred_tensor.shape[0] == 0: 113 | raise ValueError("No data in the logs") 114 | 115 | metrics_now = {} 116 | for metric in self.metric_keys: 117 | metric_mean = compute_loss_metrics(metric, target_tensor, 
pred_tensor) 118 | metrics_now[metric] = metric_mean 119 | if print_log: 120 | print( 121 | f'{subset}_{self.name}_{metric}: {metric_mean:.6f}', 122 | ) 123 | logging.info( 124 | f'{subset}_{self.name}_{metric}: {metric_mean:.6f}', 125 | ) 126 | if clear: 127 | self.clear_metrics(subset) 128 | 129 | return metrics_now 130 | 131 | def clear_metrics(self, subset: str): 132 | self.logs[subset]['pred'] = [] 133 | self.logs[subset]['target'] = [] 134 | 135 | def __repr__(self): 136 | return f'{self.__class__.__name__} name: {self.name}, metric_keys: {self.metric_keys}' 137 | -------------------------------------------------------------------------------- /cace/tools/output.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .torch_tools import to_numpy 3 | from typing import Dict, Optional 4 | import ase 5 | 6 | def batch_to_atoms(batched_data: Dict, 7 | pred_data: Optional[Dict] = None, 8 | output_file: str = None, 9 | energy_key: str = 'energy', 10 | forces_key: str = 'forces', 11 | cace_energy_key: str = 'energy', 12 | cace_forces_key: str = 'forces'): 13 | """ 14 | Create ASE Atoms objects from batched graph data and write to an XYZ file. 15 | 16 | Parameters: 17 | - batched_data (Dict): Batched data containing graph information. 18 | - pred_data (Dict): Predicted data. If not given, the pred_data name is assumed to also be the batched_data. 19 | - energy_key (str): Key for accessing energy information in batched_data. 20 | - forces_key (str): Key for accessing force information in batched_data. 21 | - cace_energy_key (str): Key for accessing CACE energy information. 22 | - cace_forces_key (str): Key for accessing CACE force information. 23 | - output_file (str): Name of the output file to write the Atoms objects. 
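
    Example (a minimal sketch, assuming `batch` is a collated batch from the data loader
    and `pred` is the corresponding model output dictionary; the 'CACE_*' keys mirror the
    output keys used in examples/water_train.py and are not required names):

        atoms_list = batch_to_atoms(batch, pred_data=pred,
                                    energy_key='energy', forces_key='forces',
                                    cace_energy_key='CACE_energy',
                                    cace_forces_key='CACE_forces',
                                    output_file='predictions.xyz')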
24 | """ 25 | 26 | if pred_data == None and energy_key != cace_energy_key: 27 | pred_data = batched_data 28 | atoms_list = [] 29 | batch = batched_data.batch 30 | num_graphs = batch.max().item() + 1 31 | 32 | for i in range(num_graphs): 33 | # Mask to extract nodes for each graph 34 | mask = batch == i 35 | 36 | # Extract node features, edge indices, etc., for each graph 37 | positions = to_numpy(batched_data['positions'][mask]) 38 | atomic_numbers = to_numpy(batched_data['atomic_numbers'][mask]) 39 | cell = to_numpy(batched_data['cell'][3*i:3*i+3]) 40 | 41 | energy = to_numpy(batched_data[energy_key][i]) 42 | forces = to_numpy(batched_data[forces_key][mask]) 43 | cace_energy = to_numpy(pred_data[cace_energy_key][i]) 44 | cace_forces = to_numpy(pred_data[cace_forces_key][mask]) 45 | 46 | # Set periodic boundary conditions if the cell is defined 47 | pbc = np.all(np.mean(cell, axis=0) > 0) 48 | 49 | # Create the Atoms object 50 | atoms = ase.Atoms(numbers=atomic_numbers, positions=positions, cell=cell, pbc=pbc) 51 | atoms.info[energy_key] = energy.item() if np.ndim(energy) == 0 else energy 52 | atoms.arrays[forces_key] = forces 53 | atoms.info[cace_energy_key] = cace_energy.item() if np.ndim(cace_energy) == 0 else cace_energy 54 | atoms.arrays[cace_forces_key] = cace_forces 55 | atoms_list.append(atoms) 56 | 57 | # Write all atoms to the output file 58 | if output_file: 59 | ase.io.write(output_file, atoms_list, append=True) 60 | return atoms_list 61 | -------------------------------------------------------------------------------- /cace/tools/parser_train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | __all__ = ['parse_arguments'] 4 | 5 | def parse_arguments(): 6 | parser = argparse.ArgumentParser(description='ML potential training configuration') 7 | 8 | parser.add_argument('--prefix', type=str, default='CACE_NNP', help='Prefix for the model name') 9 | 10 | # Dataset and Training Configuration 11 | parser.add_argument('--train_path', type=str, help='Path to the training dataset', required=True) 12 | parser.add_argument('--valid_path', type=str, default=None, help='Path to the training dataset') 13 | parser.add_argument('--valid_fraction', type=float, default=0.1, help='Fraction of data to use for validation') 14 | parser.add_argument('--energy_key', type=str, default='energy', help='Key for energy in the dataset') 15 | parser.add_argument('--forces_key', type=str, default='forces', help='Key for forces in the dataset') 16 | parser.add_argument('--cutoff', type=float, default=4.0, help='Cutoff radius for interactions') 17 | parser.add_argument('--batch_size', type=int, default=10, help='Batch size for training') 18 | parser.add_argument('--valid_batch_size', type=int, default=20, help='Batch size for validation') 19 | parser.add_argument('--use_device', type=str, default='cpu', help='Device to use for training') 20 | 21 | # Radial Basis Function (RBF) Configuration 22 | parser.add_argument('--n_rbf', type=int, default=6, help='Number of RBFs') 23 | parser.add_argument('--trainable_rbf', action='store_true', help='Whether the RBF parameters are trainable') 24 | 25 | # Cutoff Function Configuration 26 | parser.add_argument('--cutoff_fn', type=str, default='PolynomialCutoff', help='Type of cutoff function') 27 | parser.add_argument('--cutoff_fn_p', type=int, default=5, help='Polynomial degree for the PolynomialCutoff function') 28 | 29 | # Representation Configuration 30 | parser.add_argument('--zs', type=int, nargs='+', 
default=None, help='Atomic numbers considered in the model') 31 | parser.add_argument('--n_atom_basis', type=int, default=3, help='Number of atom basis functions') 32 | parser.add_argument('--n_radial_basis', type=int, default=8, help='Number of radial basis functions') 33 | parser.add_argument('--max_l', type=int, default=3, help='Maximum angular momentum quantum number') 34 | parser.add_argument('--max_nu', type=int, default=3, help='Maximum radial quantum number') 35 | parser.add_argument('--num_message_passing', type=int, default=1, help='Number of message passing steps') 36 | parser.add_argument('--embed_receiver_nodes', action='store_true', help='Whether to embed receiver nodes') 37 | 38 | # Atomwise Module Configuration 39 | parser.add_argument('--atomwise_layers', type=int, default=3, help='Number of layers in the atomwise module') 40 | parser.add_argument('--atomwise_hidden', type=int, nargs='+', default=[32, 16], help='Hidden units in each layer of the atomwise module') 41 | parser.add_argument('--atomwise_residual', action='store_false', help='Use residual connections in the atomwise module') 42 | parser.add_argument('--atomwise_batchnorm', action='store_false', help='Use batch normalization in the atomwise module') 43 | parser.add_argument('--atomwise_linear_nn', action='store_true', help='Add a linear neural network layer in the atomwise module') 44 | 45 | # Training Procedure Configuration 46 | parser.add_argument('--lr', type=float, default=1e-2, help='Learning rate') 47 | parser.add_argument('--scheduler_factor', type=float, default=0.8, help='Factor by which the learning rate is reduced') 48 | parser.add_argument('--scheduler_patience', type=int, default=10, help='Patience for the learning rate scheduler') 49 | parser.add_argument('--max_grad_norm', type=float, default=10, help='Max gradient norm for gradient clipping') 50 | parser.add_argument('--ema', action='store_true', help='Use exponential moving average of model parameters') 51 | parser.add_argument('--ema_start', type=int, default=10, help='Start using EMA after this many steps') 52 | parser.add_argument('--warmup_steps', type=int, default=10, help='Number of warmup steps for the optimizer') 53 | parser.add_argument('--epochs', type=int, default=200, help='Number of epochs for the first phase of training') 54 | parser.add_argument('--second_phase_epochs', type=int, default=100, help='Number of epochs for the second phase of training') 55 | parser.add_argument('--energy_loss_weight', type=float, default=1.0, help='Weight for the energy loss in phase 1') 56 | parser.add_argument('--force_loss_weight', type=float, default=1000.0, help='Weight for the force loss in both phases') 57 | parser.add_argument('--num_restart', type=int, default=5, help='Number of restarts for the training during phase 1') 58 | parser.add_argument('--second_phase_energy_loss_weight', type=float, default=1000.0, 59 | help='Weight for the energy loss in phase 2') 60 | 61 | return parser.parse_args() 62 | -------------------------------------------------------------------------------- /cace/tools/scatter.py: -------------------------------------------------------------------------------- 1 | """basic scatter_sum operations from torch_scatter from 2 | https://github.com/mir-group/pytorch_runstats/blob/main/torch_runstats/scatter_sum.py 3 | Using code from https://github.com/rusty1s/pytorch_scatter, but cut down to avoid a dependency. 
4 | PyTorch plans to move these features into the main repo, but until then, 5 | to make installation simpler, we need this pure python set of wrappers 6 | that don't require installing PyTorch C++ extensions. 7 | See https://github.com/pytorch/pytorch/issues/63780. 8 | """ 9 | 10 | from typing import Optional 11 | 12 | import torch 13 | 14 | 15 | def _broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): 16 | if dim < 0: 17 | dim = other.dim() + dim 18 | if src.dim() == 1: 19 | for _ in range(0, dim): 20 | src = src.unsqueeze(0) 21 | for _ in range(src.dim(), other.dim()): 22 | src = src.unsqueeze(-1) 23 | src = src.expand_as(other) 24 | return src 25 | 26 | 27 | @torch.jit.script 28 | def scatter_sum( 29 | src: torch.Tensor, 30 | index: torch.Tensor, 31 | dim: int = -1, 32 | out: Optional[torch.Tensor] = None, 33 | dim_size: Optional[int] = None, 34 | reduce: str = "sum", 35 | ) -> torch.Tensor: 36 | assert reduce == "sum" # for now, TODO 37 | index = _broadcast(index, src, dim) 38 | if out is None: 39 | size = list(src.size()) 40 | if dim_size is not None: 41 | size[dim] = dim_size 42 | elif index.numel() == 0: 43 | size[dim] = 0 44 | else: 45 | size[dim] = int(index.max()) + 1 46 | out = torch.zeros(size, dtype=src.dtype, device=src.device) 47 | return out.scatter_add_(dim, index, src) 48 | else: 49 | return out.scatter_add_(dim, index, src) 50 | 51 | 52 | @torch.jit.script 53 | def scatter_std( 54 | src: torch.Tensor, 55 | index: torch.Tensor, 56 | dim: int = -1, 57 | out: Optional[torch.Tensor] = None, 58 | dim_size: Optional[int] = None, 59 | unbiased: bool = True, 60 | ) -> torch.Tensor: 61 | if out is not None: 62 | dim_size = out.size(dim) 63 | 64 | if dim < 0: 65 | dim = src.dim() + dim 66 | 67 | count_dim = dim 68 | if index.dim() <= dim: 69 | count_dim = index.dim() - 1 70 | 71 | ones = torch.ones(index.size(), dtype=src.dtype, device=src.device) 72 | count = scatter_sum(ones, index, count_dim, dim_size=dim_size) 73 | 74 | index = _broadcast(index, src, dim) 75 | tmp = scatter_sum(src, index, dim, dim_size=dim_size) 76 | count = _broadcast(count, tmp, dim).clamp(1) 77 | mean = tmp.div(count) 78 | 79 | var = src - mean.gather(dim, index) 80 | var = var * var 81 | out = scatter_sum(var, index, dim, out, dim_size) 82 | 83 | if unbiased: 84 | count = count.sub(1).clamp_(1) 85 | out = out.div(count + 1e-6).sqrt() 86 | 87 | return out 88 | 89 | 90 | @torch.jit.script 91 | def scatter_mean( 92 | src: torch.Tensor, 93 | index: torch.Tensor, 94 | dim: int = -1, 95 | out: Optional[torch.Tensor] = None, 96 | dim_size: Optional[int] = None, 97 | ) -> torch.Tensor: 98 | out = scatter_sum(src, index, dim, out, dim_size) 99 | dim_size = out.size(dim) 100 | 101 | index_dim = dim 102 | if index_dim < 0: 103 | index_dim = index_dim + src.dim() 104 | if index.dim() <= index_dim: 105 | index_dim = index.dim() - 1 106 | 107 | ones = torch.ones(index.size(), dtype=src.dtype, device=src.device) 108 | count = scatter_sum(ones, index, index_dim, None, dim_size) 109 | count[count < 1] = 1 110 | count = _broadcast(count, out, dim) 111 | if out.is_floating_point(): 112 | out.true_divide_(count) 113 | else: 114 | out.div_(count, rounding_mode="floor") 115 | return out 116 | -------------------------------------------------------------------------------- /cace/tools/torch_geometric/README.md: -------------------------------------------------------------------------------- 1 | # Trimmed-down `pytorch_geometric` 2 | 3 | MACE uses 
the [`pytorch_geometric`](https://pytorch-geometric.readthedocs.io/en/latest/) [1, 2] framework. However, we only use a very limited subset of that library: the most basic graph data structures. 4 | 5 | We follow the same approach as NequIP (https://github.com/mir-group/nequip/tree/main/nequip) and copy their code here. 6 | 7 | To avoid adding a large number of unnecessary second-degree dependencies, and to simplify installation, we include and modify here the small subset of `torch_geometric` that is necessary for our code. 8 | 9 | We are grateful to the developers of PyTorch Geometric for their ongoing and very useful work on graph learning with PyTorch. 10 | 11 | [1] Fey, M., & Lenssen, J. E. (2019). Fast Graph Representation Learning with PyTorch Geometric (Version 2.0.1) [Computer software]. https://github.com/pyg-team/pytorch_geometric
12 | [2] https://arxiv.org/abs/1903.02428 -------------------------------------------------------------------------------- /cace/tools/torch_geometric/__init__.py: -------------------------------------------------------------------------------- 1 | from .batch import Batch 2 | from .data import Data 3 | from .dataloader import DataLoader 4 | from .dataset import Dataset 5 | 6 | __all__ = ["Batch", "Data", "Dataset", "DataLoader"] 7 | -------------------------------------------------------------------------------- /cace/tools/torch_geometric/batch.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | from typing import List 3 | 4 | import numpy as np 5 | import torch 6 | from torch import Tensor 7 | 8 | from .data import Data 9 | from .dataset import IndexType 10 | 11 | 12 | class Batch(Data): 13 | r"""A plain old python object modeling a batch of graphs as one big 14 | (disconnected) graph. With :class:`torch_geometric.data.Data` being the 15 | base class, all its methods can also be used here. 16 | In addition, single graphs can be reconstructed via the assignment vector 17 | :obj:`batch`, which maps each node to its respective graph identifier. 18 | """ 19 | 20 | def __init__(self, batch=None, ptr=None, **kwargs): 21 | super().__init__(**kwargs) 22 | 23 | for key, item in kwargs.items(): 24 | if key == "num_nodes": 25 | self.__num_nodes__ = item 26 | else: 27 | self[key] = item 28 | 29 | self.batch = batch 30 | self.ptr = ptr 31 | self.__data_class__ = Data 32 | self.__slices__ = None 33 | self.__cumsum__ = None 34 | self.__cat_dims__ = None 35 | self.__num_nodes_list__ = None 36 | self.__num_graphs__ = None 37 | 38 | @classmethod 39 | def from_data_list(cls, data_list, follow_batch=[], exclude_keys=[]): 40 | r"""Constructs a batch object from a python list holding 41 | :class:`torch_geometric.data.Data` objects. 42 | The assignment vector :obj:`batch` is created on the fly. 43 | Additionally, creates assignment batch vectors for each key in 44 | :obj:`follow_batch`. 45 | Will exclude any keys given in :obj:`exclude_keys`.""" 46 | 47 | # this is for the competibility of different pytorch versions 48 | try: 49 | keys = list(set(data_list[0].keys) - set(exclude_keys)) 50 | except: 51 | keys = list(set(data_list[0].keys()) - set(exclude_keys)) 52 | assert "batch" not in keys and "ptr" not in keys 53 | 54 | batch = cls() 55 | for key in data_list[0].__dict__.keys(): 56 | if key[:2] != "__" and key[-2:] != "__": 57 | batch[key] = None 58 | 59 | batch.__num_graphs__ = len(data_list) 60 | batch.__data_class__ = data_list[0].__class__ 61 | for key in keys + ["batch"]: 62 | batch[key] = [] 63 | batch["ptr"] = [0] 64 | 65 | device = None 66 | slices = {key: [0] for key in keys} 67 | cumsum = {key: [0] for key in keys} 68 | cat_dims = {} 69 | num_nodes_list = [] 70 | for i, data in enumerate(data_list): 71 | for key in keys: 72 | item = data[key] 73 | 74 | # Increase values by `cumsum` value. 75 | cum = cumsum[key][-1] 76 | if isinstance(item, Tensor) and item.dtype != torch.bool: 77 | if not isinstance(cum, int) or cum != 0: 78 | item = item + cum 79 | elif isinstance(item, (int, float)): 80 | item = item + cum 81 | 82 | # Gather the size of the `cat` dimension. 83 | size = 1 84 | cat_dim = data.__cat_dim__(key, data[key]) 85 | # 0-dimensional tensors have no dimension along which to 86 | # concatenate, so we set `cat_dim` to `None`. 
87 | if isinstance(item, Tensor) and item.dim() == 0: 88 | cat_dim = None 89 | cat_dims[key] = cat_dim 90 | 91 | # Add a batch dimension to items whose `cat_dim` is `None`: 92 | if isinstance(item, Tensor) and cat_dim is None: 93 | cat_dim = 0 # Concatenate along this new batch dimension. 94 | item = item.unsqueeze(0) 95 | device = item.device 96 | elif isinstance(item, Tensor): 97 | size = item.size(cat_dim) 98 | device = item.device 99 | 100 | batch[key].append(item) # Append item to the attribute list. 101 | 102 | slices[key].append(size + slices[key][-1]) 103 | inc = data.__inc__(key, item) 104 | if isinstance(inc, (tuple, list)): 105 | inc = torch.tensor(inc) 106 | cumsum[key].append(inc + cumsum[key][-1]) 107 | 108 | if key in follow_batch: 109 | if isinstance(size, Tensor): 110 | for j, size in enumerate(size.tolist()): 111 | tmp = f"{key}_{j}_batch" 112 | batch[tmp] = [] if i == 0 else batch[tmp] 113 | batch[tmp].append( 114 | torch.full((size,), i, dtype=torch.long, device=device) 115 | ) 116 | else: 117 | tmp = f"{key}_batch" 118 | batch[tmp] = [] if i == 0 else batch[tmp] 119 | batch[tmp].append( 120 | torch.full((size,), i, dtype=torch.long, device=device) 121 | ) 122 | 123 | if hasattr(data, "__num_nodes__"): 124 | num_nodes_list.append(data.__num_nodes__) 125 | else: 126 | num_nodes_list.append(None) 127 | 128 | num_nodes = data.num_nodes 129 | if num_nodes is not None: 130 | item = torch.full((num_nodes,), i, dtype=torch.long, device=device) 131 | batch.batch.append(item) 132 | batch.ptr.append(batch.ptr[-1] + num_nodes) 133 | 134 | batch.batch = None if len(batch.batch) == 0 else batch.batch 135 | batch.ptr = None if len(batch.ptr) == 1 else batch.ptr 136 | batch.__slices__ = slices 137 | batch.__cumsum__ = cumsum 138 | batch.__cat_dims__ = cat_dims 139 | batch.__num_nodes_list__ = num_nodes_list 140 | 141 | ref_data = data_list[0] 142 | for key in batch.keys: 143 | items = batch[key] 144 | item = items[0] 145 | cat_dim = ref_data.__cat_dim__(key, item) 146 | cat_dim = 0 if cat_dim is None else cat_dim 147 | if isinstance(item, Tensor): 148 | batch[key] = torch.cat(items, cat_dim) 149 | elif isinstance(item, (int, float)): 150 | batch[key] = torch.tensor(items) 151 | 152 | # if torch_geometric.is_debug_enabled(): 153 | # batch.debug() 154 | 155 | return batch.contiguous() 156 | 157 | def get_example(self, idx: int) -> Data: 158 | r"""Reconstructs the :class:`torch_geometric.data.Data` object at index 159 | :obj:`idx` from the batch object. 160 | The batch object must have been created via :meth:`from_data_list` in 161 | order to be able to reconstruct the initial objects.""" 162 | 163 | if self.__slices__ is None: 164 | raise RuntimeError( 165 | ( 166 | "Cannot reconstruct data list from batch because the batch " 167 | "object was not created using `Batch.from_data_list()`." 168 | ) 169 | ) 170 | 171 | data = self.__data_class__() 172 | idx = self.num_graphs + idx if idx < 0 else idx 173 | 174 | for key in self.__slices__.keys(): 175 | item = self[key] 176 | if self.__cat_dims__[key] is None: 177 | # The item was concatenated along a new batch dimension, 178 | # so just index in that dimension: 179 | item = item[idx] 180 | else: 181 | # Narrow the item based on the values in `__slices__`. 
182 | if isinstance(item, Tensor): 183 | dim = self.__cat_dims__[key] 184 | start = self.__slices__[key][idx] 185 | end = self.__slices__[key][idx + 1] 186 | item = item.narrow(dim, start, end - start) 187 | else: 188 | start = self.__slices__[key][idx] 189 | end = self.__slices__[key][idx + 1] 190 | item = item[start:end] 191 | item = item[0] if len(item) == 1 else item 192 | 193 | # Decrease its value by `cumsum` value: 194 | cum = self.__cumsum__[key][idx] 195 | if isinstance(item, Tensor): 196 | if not isinstance(cum, int) or cum != 0: 197 | item = item - cum 198 | elif isinstance(item, (int, float)): 199 | item = item - cum 200 | 201 | data[key] = item 202 | 203 | if self.__num_nodes_list__[idx] is not None: 204 | data.num_nodes = self.__num_nodes_list__[idx] 205 | 206 | return data 207 | 208 | def index_select(self, idx: IndexType) -> List[Data]: 209 | if isinstance(idx, slice): 210 | idx = list(range(self.num_graphs)[idx]) 211 | 212 | elif isinstance(idx, Tensor) and idx.dtype == torch.long: 213 | idx = idx.flatten().tolist() 214 | 215 | elif isinstance(idx, Tensor) and idx.dtype == torch.bool: 216 | idx = idx.flatten().nonzero(as_tuple=False).flatten().tolist() 217 | 218 | elif isinstance(idx, np.ndarray) and idx.dtype == np.int64: 219 | idx = idx.flatten().tolist() 220 | 221 | elif isinstance(idx, np.ndarray) and idx.dtype == np.bool: 222 | idx = idx.flatten().nonzero()[0].flatten().tolist() 223 | 224 | elif isinstance(idx, Sequence) and not isinstance(idx, str): 225 | pass 226 | 227 | else: 228 | raise IndexError( 229 | f"Only integers, slices (':'), list, tuples, torch.tensor and " 230 | f"np.ndarray of dtype long or bool are valid indices (got " 231 | f"'{type(idx).__name__}')" 232 | ) 233 | 234 | return [self.get_example(i) for i in idx] 235 | 236 | def __getitem__(self, idx): 237 | if isinstance(idx, str): 238 | return super().__getitem__(idx) 239 | elif isinstance(idx, (int, np.integer)): 240 | return self.get_example(idx) 241 | else: 242 | return self.index_select(idx) 243 | 244 | def to_data_list(self) -> List[Data]: 245 | r"""Reconstructs the list of :class:`torch_geometric.data.Data` objects 246 | from the batch object. 
247 | The batch object must have been created via :meth:`from_data_list` in 248 | order to be able to reconstruct the initial objects.""" 249 | return [self.get_example(i) for i in range(self.num_graphs)] 250 | 251 | @property 252 | def num_graphs(self) -> int: 253 | """Returns the number of graphs in the batch.""" 254 | if self.__num_graphs__ is not None: 255 | return self.__num_graphs__ 256 | elif self.ptr is not None: 257 | return self.ptr.numel() - 1 258 | elif self.batch is not None: 259 | return int(self.batch.max()) + 1 260 | else: 261 | raise ValueError 262 | -------------------------------------------------------------------------------- /cace/tools/torch_geometric/dataloader.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Mapping, Sequence 2 | from typing import List, Optional, Union 3 | 4 | import torch.utils.data 5 | from torch.utils.data.dataloader import default_collate 6 | 7 | from .batch import Batch 8 | from .data import Data 9 | from .dataset import Dataset 10 | 11 | class Collater: 12 | def __init__(self, follow_batch, exclude_keys): 13 | self.follow_batch = follow_batch 14 | self.exclude_keys = exclude_keys 15 | 16 | def __call__(self, batch): 17 | elem = batch[0] 18 | if isinstance(elem, Data): 19 | return Batch.from_data_list( 20 | batch, 21 | follow_batch=self.follow_batch, 22 | exclude_keys=self.exclude_keys, 23 | ) 24 | elif isinstance(elem, torch.Tensor): 25 | return default_collate(batch) 26 | elif isinstance(elem, float): 27 | return torch.tensor(batch, dtype=torch.float) 28 | elif isinstance(elem, int): 29 | return torch.tensor(batch) 30 | elif isinstance(elem, str): 31 | return batch 32 | elif isinstance(elem, Mapping): 33 | return {key: self([data[key] for data in batch]) for key in elem} 34 | elif isinstance(elem, tuple) and hasattr(elem, "_fields"): 35 | return type(elem)(*(self(s) for s in zip(*batch))) 36 | elif isinstance(elem, Sequence) and not isinstance(elem, str): 37 | return [self(s) for s in zip(*batch)] 38 | 39 | raise TypeError(f"DataLoader found invalid type: {type(elem)}") 40 | 41 | def collate(self, batch): # Deprecated... 42 | return self(batch) 43 | 44 | 45 | class DataLoader(torch.utils.data.DataLoader): 46 | r"""A data loader which merges data objects from a 47 | :class:`torch_geometric.data.Dataset` to a mini-batch. 48 | Data objects can be either of type :class:`~torch_geometric.data.Data` or 49 | :class:`~torch_geometric.data.HeteroData`. 50 | Args: 51 | dataset (Dataset): The dataset from which to load the data. 52 | batch_size (int, optional): How many samples per batch to load. 53 | (default: :obj:`1`) 54 | shuffle (bool, optional): If set to :obj:`True`, the data will be 55 | reshuffled at every epoch. (default: :obj:`False`) 56 | follow_batch (List[str], optional): Creates assignment batch 57 | vectors for each key in the list. (default: :obj:`None`) 58 | exclude_keys (List[str], optional): Will exclude each key in the 59 | list. (default: :obj:`None`) 60 | **kwargs (optional): Additional arguments of 61 | :class:`torch.utils.data.DataLoader`. 
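
    Example (a sketch of typical use in this repository, cf. scripts/compute_cace_desc.py;
    `dataset` is assumed to be a list of AtomicData graphs):

        loader = DataLoader(dataset, batch_size=10, shuffle=False, drop_last=False)
        for batch in loader:
            ...  # each `batch` is a Batch object collating several Data graphs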
62 | """ 63 | 64 | def __init__( 65 | self, 66 | dataset: Dataset, 67 | batch_size: int = 1, 68 | shuffle: bool = False, 69 | follow_batch: Optional[List[str]] = [None], 70 | exclude_keys: Optional[List[str]] = [None], 71 | **kwargs, 72 | ): 73 | if "collate_fn" in kwargs: 74 | del kwargs["collate_fn"] 75 | 76 | # Save for PyTorch Lightning < 1.6: 77 | self.follow_batch = follow_batch 78 | self.exclude_keys = exclude_keys 79 | 80 | super().__init__( 81 | dataset, 82 | batch_size, 83 | shuffle, 84 | collate_fn=Collater(follow_batch, exclude_keys), 85 | **kwargs, 86 | ) 87 | -------------------------------------------------------------------------------- /cace/tools/torch_geometric/dataset.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os.path as osp 3 | import re 4 | import warnings 5 | from collections.abc import Sequence 6 | from typing import Any, Callable, List, Optional, Tuple, Union 7 | 8 | import numpy as np 9 | import torch.utils.data 10 | from torch import Tensor 11 | 12 | from .data import Data 13 | from .utils import makedirs 14 | 15 | IndexType = Union[slice, Tensor, np.ndarray, Sequence] 16 | 17 | 18 | class Dataset(torch.utils.data.Dataset): 19 | r"""Dataset base class for creating graph datasets. 20 | See `here `__ for the accompanying tutorial. 22 | 23 | Args: 24 | root (string, optional): Root directory where the dataset should be 25 | saved. (optional: :obj:`None`) 26 | transform (callable, optional): A function/transform that takes in an 27 | :obj:`torch_geometric.data.Data` object and returns a transformed 28 | version. The data object will be transformed before every access. 29 | (default: :obj:`None`) 30 | pre_transform (callable, optional): A function/transform that takes in 31 | an :obj:`torch_geometric.data.Data` object and returns a 32 | transformed version. The data object will be transformed before 33 | being saved to disk. (default: :obj:`None`) 34 | pre_filter (callable, optional): A function that takes in an 35 | :obj:`torch_geometric.data.Data` object and returns a boolean 36 | value, indicating whether the data object should be included in the 37 | final dataset. 
(default: :obj:`None`) 38 | """ 39 | 40 | @property 41 | def raw_file_names(self) -> Union[str, List[str], Tuple]: 42 | r"""The name of the files to find in the :obj:`self.raw_dir` folder in 43 | order to skip the download.""" 44 | raise NotImplementedError 45 | 46 | @property 47 | def processed_file_names(self) -> Union[str, List[str], Tuple]: 48 | r"""The name of the files to find in the :obj:`self.processed_dir` 49 | folder in order to skip the processing.""" 50 | raise NotImplementedError 51 | 52 | def download(self): 53 | r"""Downloads the dataset to the :obj:`self.raw_dir` folder.""" 54 | raise NotImplementedError 55 | 56 | def process(self): 57 | r"""Processes the dataset to the :obj:`self.processed_dir` folder.""" 58 | raise NotImplementedError 59 | 60 | def len(self) -> int: 61 | raise NotImplementedError 62 | 63 | def get(self, idx: int) -> Data: 64 | r"""Gets the data object at index :obj:`idx`.""" 65 | raise NotImplementedError 66 | 67 | def __init__( 68 | self, 69 | root: Optional[str] = None, 70 | transform: Optional[Callable] = None, 71 | pre_transform: Optional[Callable] = None, 72 | pre_filter: Optional[Callable] = None, 73 | ): 74 | super().__init__() 75 | 76 | if isinstance(root, str): 77 | root = osp.expanduser(osp.normpath(root)) 78 | 79 | self.root = root 80 | self.transform = transform 81 | self.pre_transform = pre_transform 82 | self.pre_filter = pre_filter 83 | self._indices: Optional[Sequence] = None 84 | 85 | if "download" in self.__class__.__dict__.keys(): 86 | self._download() 87 | 88 | if "process" in self.__class__.__dict__.keys(): 89 | self._process() 90 | 91 | def indices(self) -> Sequence: 92 | return range(self.len()) if self._indices is None else self._indices 93 | 94 | @property 95 | def raw_dir(self) -> str: 96 | return osp.join(self.root, "raw") 97 | 98 | @property 99 | def processed_dir(self) -> str: 100 | return osp.join(self.root, "processed") 101 | 102 | @property 103 | def num_node_features(self) -> int: 104 | r"""Returns the number of features per node in the dataset.""" 105 | data = self[0] 106 | if hasattr(data, "num_node_features"): 107 | return data.num_node_features 108 | raise AttributeError( 109 | f"'{data.__class__.__name__}' object has no " 110 | f"attribute 'num_node_features'" 111 | ) 112 | 113 | @property 114 | def num_features(self) -> int: 115 | r"""Alias for :py:attr:`~num_node_features`.""" 116 | return self.num_node_features 117 | 118 | @property 119 | def num_edge_features(self) -> int: 120 | r"""Returns the number of features per edge in the dataset.""" 121 | data = self[0] 122 | if hasattr(data, "num_edge_features"): 123 | return data.num_edge_features 124 | raise AttributeError( 125 | f"'{data.__class__.__name__}' object has no " 126 | f"attribute 'num_edge_features'" 127 | ) 128 | 129 | @property 130 | def raw_paths(self) -> List[str]: 131 | r"""The filepaths to find in order to skip the download.""" 132 | files = to_list(self.raw_file_names) 133 | return [osp.join(self.raw_dir, f) for f in files] 134 | 135 | @property 136 | def processed_paths(self) -> List[str]: 137 | r"""The filepaths to find in the :obj:`self.processed_dir` 138 | folder in order to skip the processing.""" 139 | files = to_list(self.processed_file_names) 140 | return [osp.join(self.processed_dir, f) for f in files] 141 | 142 | def _download(self): 143 | if files_exist(self.raw_paths): # pragma: no cover 144 | return 145 | 146 | makedirs(self.raw_dir) 147 | self.download() 148 | 149 | def _process(self): 150 | f = osp.join(self.processed_dir, 
"pre_transform.pt") 151 | if osp.exists(f) and torch.load(f) != _repr(self.pre_transform): 152 | warnings.warn( 153 | f"The `pre_transform` argument differs from the one used in " 154 | f"the pre-processed version of this dataset. If you want to " 155 | f"make use of another pre-processing technique, make sure to " 156 | f"sure to delete '{self.processed_dir}' first" 157 | ) 158 | 159 | f = osp.join(self.processed_dir, "pre_filter.pt") 160 | if osp.exists(f) and torch.load(f) != _repr(self.pre_filter): 161 | warnings.warn( 162 | "The `pre_filter` argument differs from the one used in the " 163 | "pre-processed version of this dataset. If you want to make " 164 | "use of another pre-fitering technique, make sure to delete " 165 | "'{self.processed_dir}' first" 166 | ) 167 | 168 | if files_exist(self.processed_paths): # pragma: no cover 169 | return 170 | 171 | print("Processing...") 172 | 173 | makedirs(self.processed_dir) 174 | self.process() 175 | 176 | path = osp.join(self.processed_dir, "pre_transform.pt") 177 | torch.save(_repr(self.pre_transform), path) 178 | path = osp.join(self.processed_dir, "pre_filter.pt") 179 | torch.save(_repr(self.pre_filter), path) 180 | 181 | print("Done!") 182 | 183 | def __len__(self) -> int: 184 | r"""The number of examples in the dataset.""" 185 | return len(self.indices()) 186 | 187 | def __getitem__( 188 | self, 189 | idx: Union[int, np.integer, IndexType], 190 | ) -> Union["Dataset", Data]: 191 | r"""In case :obj:`idx` is of type integer, will return the data object 192 | at index :obj:`idx` (and transforms it in case :obj:`transform` is 193 | present). 194 | In case :obj:`idx` is a slicing object, *e.g.*, :obj:`[2:5]`, a list, a 195 | tuple, a PyTorch :obj:`LongTensor` or a :obj:`BoolTensor`, or a numpy 196 | :obj:`np.array`, will return a subset of the dataset at the specified 197 | indices.""" 198 | if ( 199 | isinstance(idx, (int, np.integer)) 200 | or (isinstance(idx, Tensor) and idx.dim() == 0) 201 | or (isinstance(idx, np.ndarray) and np.isscalar(idx)) 202 | ): 203 | data = self.get(self.indices()[idx]) 204 | data = data if self.transform is None else self.transform(data) 205 | return data 206 | 207 | else: 208 | return self.index_select(idx) 209 | 210 | def index_select(self, idx: IndexType) -> "Dataset": 211 | indices = self.indices() 212 | 213 | if isinstance(idx, slice): 214 | indices = indices[idx] 215 | 216 | elif isinstance(idx, Tensor) and idx.dtype == torch.long: 217 | return self.index_select(idx.flatten().tolist()) 218 | 219 | elif isinstance(idx, Tensor) and idx.dtype == torch.bool: 220 | idx = idx.flatten().nonzero(as_tuple=False) 221 | return self.index_select(idx.flatten().tolist()) 222 | 223 | elif isinstance(idx, np.ndarray) and idx.dtype == np.int64: 224 | return self.index_select(idx.flatten().tolist()) 225 | 226 | elif isinstance(idx, np.ndarray) and idx.dtype == np.bool: 227 | idx = idx.flatten().nonzero()[0] 228 | return self.index_select(idx.flatten().tolist()) 229 | 230 | elif isinstance(idx, Sequence) and not isinstance(idx, str): 231 | indices = [indices[i] for i in idx] 232 | 233 | else: 234 | raise IndexError( 235 | f"Only integers, slices (':'), list, tuples, torch.tensor and " 236 | f"np.ndarray of dtype long or bool are valid indices (got " 237 | f"'{type(idx).__name__}')" 238 | ) 239 | 240 | dataset = copy.copy(self) 241 | dataset._indices = indices 242 | return dataset 243 | 244 | def shuffle( 245 | self, 246 | return_perm: bool = False, 247 | ) -> Union["Dataset", Tuple["Dataset", Tensor]]: 248 | r"""Randomly 
shuffles the examples in the dataset. 249 | 250 | Args: 251 | return_perm (bool, optional): If set to :obj:`True`, will return 252 | the random permutation used to shuffle the dataset in addition. 253 | (default: :obj:`False`) 254 | """ 255 | perm = torch.randperm(len(self)) 256 | dataset = self.index_select(perm) 257 | return (dataset, perm) if return_perm is True else dataset 258 | 259 | def __repr__(self) -> str: 260 | arg_repr = str(len(self)) if len(self) > 1 else "" 261 | return f"{self.__class__.__name__}({arg_repr})" 262 | 263 | 264 | def to_list(value: Any) -> Sequence: 265 | if isinstance(value, Sequence) and not isinstance(value, str): 266 | return value 267 | else: 268 | return [value] 269 | 270 | 271 | def files_exist(files: List[str]) -> bool: 272 | # NOTE: We return `False` in case `files` is empty, leading to a 273 | # re-processing of files on every instantiation. 274 | return len(files) != 0 and all([osp.exists(f) for f in files]) 275 | 276 | 277 | def _repr(obj: Any) -> str: 278 | if obj is None: 279 | return "None" 280 | return re.sub("(<.*?)\\s.*(>)", r"\1\2", obj.__repr__()) 281 | -------------------------------------------------------------------------------- /cace/tools/torch_geometric/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import ssl 4 | import urllib 5 | import zipfile 6 | 7 | 8 | def makedirs(dir): 9 | os.makedirs(dir, exist_ok=True) 10 | 11 | 12 | def download_url(url, folder, log=True): 13 | r"""Downloads the content of an URL to a specific folder. 14 | 15 | Args: 16 | url (string): The url. 17 | folder (string): The folder. 18 | log (bool, optional): If :obj:`False`, will not print anything to the 19 | console. (default: :obj:`True`) 20 | """ 21 | 22 | filename = url.rpartition("/")[2].split("?")[0] 23 | path = osp.join(folder, filename) 24 | 25 | if osp.exists(path): # pragma: no cover 26 | if log: 27 | print("Using exist file", filename) 28 | return path 29 | 30 | if log: 31 | print("Downloading", url) 32 | 33 | makedirs(folder) 34 | 35 | context = ssl._create_unverified_context() 36 | data = urllib.request.urlopen(url, context=context) 37 | 38 | with open(path, "wb") as f: 39 | f.write(data.read()) 40 | 41 | return path 42 | 43 | 44 | def extract_zip(path, folder, log=True): 45 | r"""Extracts a zip archive to a specific folder. 46 | 47 | Args: 48 | path (string): The path to the tar archive. 49 | folder (string): The folder. 50 | log (bool, optional): If :obj:`False`, will not print anything to the 51 | console. 
(default: :obj:`True`) 52 | """ 53 | with zipfile.ZipFile(path, "r") as f: 54 | f.extractall(folder) 55 | -------------------------------------------------------------------------------- /cace/tools/torch_tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import logging 4 | from typing import Dict 5 | 6 | TensorDict = Dict[str, torch.Tensor] 7 | 8 | def elementwise_multiply_2tensors(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: 9 | """ 10 | Elementwise multiplication of two 2D tensors 11 | :param a: (N, A) tensor 12 | :param b: (N, B) tensor 13 | :return: (N, A, B) tensor 14 | """ 15 | # expand the dimenstions for broadcasting 16 | a_expanded = a.unsqueeze(2) 17 | b_expanded = b.unsqueeze(1) 18 | # multiply 19 | return a_expanded * b_expanded 20 | 21 | @torch.jit.script 22 | def elementwise_multiply_3tensors(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor: 23 | """ 24 | Elementwise multiplication of three 2D tensors 25 | :param a: (N, A) tensor 26 | :param b: (N, B) tensor 27 | :param c: (N, C) tensor 28 | :return: (N, A, B, C) tensor 29 | """ 30 | # expand the dimenstions for broadcasting 31 | a_expanded = a.unsqueeze(2).unsqueeze(3) 32 | b_expanded = b.unsqueeze(1).unsqueeze(3) 33 | c_expanded = c.unsqueeze(1).unsqueeze(2) 34 | # multiply 35 | # this is the same as torch.einsum('ni,nj,nk->nijk', a, b,c) 36 | # but a bit faster 37 | return a_expanded * b_expanded * c_expanded 38 | 39 | def tensor_dict_to_device(td: TensorDict, device: torch.device) -> TensorDict: 40 | return {k: v.to(device) if v is not None else None for k, v in td.items()} 41 | 42 | def to_numpy(t: torch.Tensor) -> np.ndarray: 43 | return t.cpu().detach().numpy() 44 | 45 | def init_device(device_str: str) -> torch.device: 46 | if device_str == "cuda": 47 | assert torch.cuda.is_available(), "No CUDA device available!" 48 | logging.info( 49 | f"CUDA version: {torch.version.cuda}, CUDA device: {torch.cuda.current_device()}" 50 | ) 51 | torch.cuda.init() 52 | return torch.device("cuda") 53 | if device_str == "mps": 54 | assert torch.backends.mps.is_available(), "No MPS backend is available!" 
55 | logging.info("Using MPS GPU acceleration") 56 | return torch.device("mps") 57 | 58 | logging.info("Using CPU") 59 | return torch.device("cpu") 60 | 61 | def voigt_to_matrix(t: torch.Tensor): 62 | """ 63 | Convert voigt notation to matrix notation 64 | :param t: (6,) tensor or (3, 3) tensor 65 | :return: (3, 3) tensor 66 | """ 67 | if t.shape == (3, 3): 68 | return t 69 | 70 | return torch.tensor( 71 | [[t[0], t[5], t[4]], [t[5], t[1], t[3]], [t[4], t[3], t[2]]], dtype=t.dtype 72 | ) 73 | 74 | def to_one_hot(indices: torch.Tensor, num_classes: int) -> torch.Tensor: 75 | """ 76 | Generates one-hot encoding with classes from 77 | :param indices: (N x 1) tensor 78 | :param num_classes: number of classes 79 | :param device: torch device 80 | :return: (N x num_classes) tensor 81 | """ 82 | shape = indices.shape[:-1] + (num_classes,) 83 | oh = torch.zeros(shape, device=indices.device).view(shape) 84 | 85 | # scatter_ is the in-place version of scatter 86 | oh.scatter_(dim=-1, index=indices, value=1) 87 | 88 | return oh.view(*shape) 89 | -------------------------------------------------------------------------------- /cace/tools/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import sys 5 | from typing import Any, Dict, Iterable, Optional, Sequence, Union, List 6 | from ase import Atoms 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from . import torch_geometric 12 | from .torch_tools import to_numpy 13 | 14 | class AtomicNumberTable: 15 | def __init__(self, zs: Sequence[int]): 16 | self.zs = zs 17 | 18 | def __len__(self) -> int: 19 | return len(self.zs) 20 | 21 | def __str__(self): 22 | return f"AtomicNumberTable: {tuple(s for s in self.zs)}" 23 | 24 | def index_to_z(self, index: int) -> int: 25 | return self.zs[index] 26 | 27 | def z_to_index(self, atomic_number: str) -> int: 28 | return self.zs.index(atomic_number) 29 | 30 | 31 | def get_atomic_number_table_from_zs(zs: Iterable[int]) -> AtomicNumberTable: 32 | z_set = set() 33 | for z in zs: 34 | z_set.add(z) 35 | return AtomicNumberTable(sorted(list(z_set))) 36 | 37 | 38 | def atomic_numbers_to_indices( 39 | atomic_numbers: np.ndarray, z_table: AtomicNumberTable 40 | ) -> np.ndarray: 41 | to_index_fn = np.vectorize(z_table.z_to_index) 42 | return to_index_fn(atomic_numbers) 43 | 44 | def compute_avg_num_neighbors(batches: Union[torch.utils.data.DataLoader, torch_geometric.data.Data, torch_geometric.batch.Batch]) -> float: 45 | num_neighbors = [] 46 | 47 | if isinstance(batches, torch_geometric.data.Data) or isinstance(batches, torch_geometric.batch.Batch): 48 | _, receivers = batches.edge_index 49 | _, counts = torch.unique(receivers, return_counts=True) 50 | num_neighbors.append(counts) 51 | elif isinstance(batches, torch.utils.data.DataLoader): 52 | for batch in batches: 53 | _, receivers = batch.edge_index 54 | _, counts = torch.unique(receivers, return_counts=True) 55 | num_neighbors.append(counts) 56 | 57 | avg_num_neighbors = torch.mean( 58 | torch.cat(num_neighbors, dim=0).type(torch.get_default_dtype()) 59 | ) 60 | return to_numpy(avg_num_neighbors).item() 61 | 62 | def get_unique_atomic_number(atoms_list: List[Atoms]) -> List[int]: 63 | """ 64 | Read a multi-frame XYZ file and return a list of unique atomic numbers 65 | present across all frames. 66 | 67 | Returns: 68 | list: List of unique atomic numbers. 
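
    Example (a sketch; note that the function takes a list of ASE Atoms objects,
    e.g. as returned by ase.io.read, rather than a file path):

        >>> from ase.io import read
        >>> frames = read('water.xyz', index=':')
        >>> zs = get_unique_atomic_number(frames)  # e.g. [1, 8] for water; order is not guaranteed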
69 | """ 70 | unique_atomic_numbers = set() 71 | 72 | for atoms in atoms_list: 73 | unique_atomic_numbers.update(atom.number for atom in atoms) 74 | 75 | return list(unique_atomic_numbers) 76 | 77 | def compute_average_E0s( 78 | atom_list: Atoms, zs: List[int] = None, energy_key: str = "energy" 79 | ) -> Dict[int, float]: 80 | """ 81 | Function to compute the average interaction energy of each chemical element 82 | returns dictionary of E0s 83 | """ 84 | len_xyz = len(atom_list) 85 | if zs is None: 86 | zs = get_unique_atomic_number(atom_list) 87 | # sort by atomic number 88 | zs.sort() 89 | len_zs = len(zs) 90 | 91 | A = np.zeros((len_xyz, len_zs)) 92 | B = np.zeros(len_xyz) 93 | for i in range(len_xyz): 94 | B[i] = atom_list[i].info[energy_key] 95 | for j, z in enumerate(zs): 96 | A[i, j] = np.count_nonzero(atom_list[i].get_atomic_numbers() == z) 97 | try: 98 | E0s = np.linalg.lstsq(A, B, rcond=None)[0] 99 | atomic_energies_dict = {} 100 | for i, z in enumerate(zs): 101 | atomic_energies_dict[z] = E0s[i] 102 | except np.linalg.LinAlgError: 103 | logging.warning( 104 | "Failed to compute E0s using least squares regression, using the same for all atoms" 105 | ) 106 | atomic_energies_dict = {} 107 | for i, z in enumerate(zs): 108 | atomic_energies_dict[z] = 0.0 109 | return atomic_energies_dict 110 | 111 | def setup_logger( 112 | level: Union[int, str] = logging.INFO, 113 | tag: Optional[str] = None, 114 | directory: Optional[str] = None, 115 | ): 116 | logger = logging.getLogger() 117 | logger.setLevel(level) 118 | 119 | formatter = logging.Formatter( 120 | "%(asctime)s.%(msecs)03d %(levelname)s: %(message)s", 121 | datefmt="%Y-%m-%d %H:%M:%S", 122 | ) 123 | 124 | ch = logging.StreamHandler(stream=sys.stdout) 125 | ch.setFormatter(formatter) 126 | logger.addHandler(ch) 127 | 128 | if (directory is not None) and (tag is not None): 129 | os.makedirs(name=directory, exist_ok=True) 130 | path = os.path.join(directory, tag + ".log") 131 | fh = logging.FileHandler(path) 132 | fh.setFormatter(formatter) 133 | 134 | logger.addHandler(fh) 135 | -------------------------------------------------------------------------------- /examples/water_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import sys 5 | sys.path.append('../cace/') 6 | 7 | import numpy as np 8 | #import matplotlib.pyplot as plt 9 | import torch 10 | import torch.nn as nn 11 | import logging 12 | 13 | import cace 14 | from cace.representations import Cace 15 | from cace.modules import CosineCutoff, MollifierCutoff, PolynomialCutoff 16 | from cace.modules import BesselRBF, GaussianRBF, GaussianRBFCentered 17 | 18 | from cace.models.atomistic import NeuralNetworkPotential 19 | from cace.tasks.train import TrainingTask 20 | 21 | torch.set_default_dtype(torch.float32) 22 | 23 | cace.tools.setup_logger(level='INFO') 24 | 25 | logging.info("reading data") 26 | collection = cace.tasks.get_dataset_from_xyz(train_path='../../datasets/water/water.xyz', 27 | valid_fraction=0.1, 28 | seed=1, 29 | energy_key='energy', 30 | forces_key='force', 31 | atomic_energies={1: -187.6043857100553, 8: -93.80219285502734} # avg 32 | ) 33 | cutoff = 5.5 34 | batch_size = 2 35 | 36 | train_loader = cace.tasks.load_data_loader(collection=collection, 37 | data_type='train', 38 | batch_size=batch_size, 39 | cutoff=cutoff) 40 | 41 | valid_loader = cace.tasks.load_data_loader(collection=collection, 42 | data_type='valid', 43 | batch_size=4, 44 | cutoff=cutoff) 45 | 46 | 
use_device = 'cuda' 47 | device = cace.tools.init_device(use_device) 48 | logging.info(f"device: {use_device}") 49 | 50 | 51 | logging.info("building CACE representation") 52 | radial_basis = BesselRBF(cutoff=cutoff, n_rbf=6, trainable=True) 53 | #cutoff_fn = CosineCutoff(cutoff=cutoff) 54 | cutoff_fn = PolynomialCutoff(cutoff=cutoff) 55 | 56 | cace_representation = Cace( 57 | zs=[1,8], 58 | n_atom_basis=3, 59 | embed_receiver_nodes=True, 60 | cutoff=cutoff, 61 | cutoff_fn=cutoff_fn, 62 | radial_basis=radial_basis, 63 | n_radial_basis=12, 64 | max_l=3, 65 | max_nu=3, 66 | num_message_passing=1, 67 | type_message_passing=['Bchi'], 68 | device=device, 69 | timeit=False 70 | ) 71 | 72 | cace_representation.to(device) 73 | logging.info(f"Representation: {cace_representation}") 74 | 75 | atomwise = cace.modules.atomwise.Atomwise(n_layers=3, 76 | output_key='CACE_energy', 77 | n_hidden=[32,16], 78 | use_batchnorm=False, 79 | add_linear_nn=True) 80 | 81 | 82 | forces = cace.modules.forces.Forces(energy_key='CACE_energy', 83 | forces_key='CACE_forces') 84 | 85 | logging.info("building CACE NNP") 86 | cace_nnp = NeuralNetworkPotential( 87 | input_modules=None, 88 | representation=cace_representation, 89 | output_modules=[atomwise, forces] 90 | ) 91 | 92 | #trainable_params = sum(p.numel() for p in cace_nnp.parameters() if p.requires_grad) 93 | #logging.info(f"Number of trainable parameters: {trainable_params}") 94 | 95 | cace_nnp.to(device) 96 | 97 | 98 | logging.info(f"First train loop:") 99 | energy_loss = cace.tasks.GetLoss( 100 | target_name='energy', 101 | predict_name='CACE_energy', 102 | loss_fn=torch.nn.MSELoss(), 103 | loss_weight=0.1 104 | ) 105 | 106 | force_loss = cace.tasks.GetLoss( 107 | target_name='forces', 108 | predict_name='CACE_forces', 109 | loss_fn=torch.nn.MSELoss(), 110 | loss_weight=1000 111 | ) 112 | 113 | from cace.tools import Metrics 114 | 115 | e_metric = Metrics( 116 | target_name='energy', 117 | predict_name='CACE_energy', 118 | name='e/atom', 119 | per_atom=True 120 | ) 121 | 122 | f_metric = Metrics( 123 | target_name='forces', 124 | predict_name='CACE_forces', 125 | name='f' 126 | ) 127 | 128 | # Example usage 129 | logging.info("creating training task") 130 | 131 | optimizer_args = {'lr': 1e-2} 132 | scheduler_args = {'mode': 'min', 'factor': 0.8, 'patience': 10} 133 | 134 | for i in range(5): 135 | task = TrainingTask( 136 | model=cace_nnp, 137 | losses=[energy_loss, force_loss], 138 | metrics=[e_metric, f_metric], 139 | device=device, 140 | optimizer_args=optimizer_args, 141 | #scheduler_cls=torch.optim.lr_scheduler.StepLR, 142 | scheduler_cls=torch.optim.lr_scheduler.ReduceLROnPlateau, 143 | scheduler_args=scheduler_args, 144 | max_grad_norm=10, 145 | ema=True, 146 | ema_start=10, 147 | warmup_steps=5, 148 | ) 149 | 150 | logging.info("training") 151 | task.fit(train_loader, valid_loader, epochs=40, screen_nan=False) 152 | 153 | task.save_model('water-model.pth') 154 | cace_nnp.to(device) 155 | 156 | logging.info(f"Second train loop:") 157 | energy_loss = cace.tasks.GetLoss( 158 | target_name='energy', 159 | predict_name='CACE_energy', 160 | loss_fn=torch.nn.MSELoss(), 161 | loss_weight=100 162 | ) 163 | 164 | task.update_loss([energy_loss, force_loss]) 165 | logging.info("training") 166 | task.fit(train_loader, valid_loader, epochs=100, screen_nan=False) 167 | 168 | 169 | task.save_model('water-model-2.pth') 170 | cace_nnp.to(device) 171 | 172 | logging.info(f"Finished") 173 | 174 | 175 | 176 | 177 | 178 | 
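# A hedged post-training sketch, not part of the original script: load the saved model
# through the ASE calculator interface and evaluate a single configuration. The
# energy/forces keys are assumed to match the output keys used above, and the atomic
# energies are the same averages that were subtracted when building the dataset.
from ase.io import read
from cace.calculators import CACECalculator

test_atoms = read('../../datasets/water/water.xyz', index=0)
test_atoms.calc = CACECalculator(model_path='water-model-2.pth',
                                 device=use_device,
                                 energy_key='CACE_energy',
                                 forces_key='CACE_forces',
                                 atomic_energies={1: -187.6043857100553, 8: -93.80219285502734})
logging.info(f"test energy: {test_atoms.get_potential_energy()}")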
-------------------------------------------------------------------------------- /examples/water_train_pl.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import torch 4 | from cace.tasks import LightningData, LightningTrainingTask 5 | 6 | on_cluster = False 7 | if 'SLURM_JOB_CPUS_PER_NODE' in os.environ.keys(): 8 | on_cluster = True 9 | if on_cluster: 10 | root = "/global/scratch/users/king1305/data/water.xyz" 11 | else: 12 | root = "/home/king1305/Apps/cacefit/fit-water/water.xyz" 13 | 14 | logs_directory = "lightning_logs" 15 | logs_name = "water_test" 16 | cutoff = 5.5 17 | avge0 = {1: -187.6043857100553, 8: -93.80219285502734} 18 | batch_size = 4 19 | training_epochs = 500 20 | data = LightningData(root,batch_size=batch_size,cutoff=cutoff,atomic_energies=avge0) 21 | 22 | from cace.representations import Cace 23 | from cace.modules import BesselRBF, GaussianRBF, GaussianRBFCentered 24 | from cace.modules import PolynomialCutoff 25 | 26 | #Model 27 | radial_basis = BesselRBF(cutoff=cutoff, n_rbf=6, trainable=True) 28 | cutoff_fn = PolynomialCutoff(cutoff=cutoff) 29 | 30 | representation = Cace( 31 | zs=[1,8], 32 | n_atom_basis=3, 33 | embed_receiver_nodes=True, 34 | cutoff=cutoff, 35 | cutoff_fn=cutoff_fn, 36 | radial_basis=radial_basis, 37 | n_radial_basis=12, 38 | max_l=3, 39 | max_nu=3, 40 | num_message_passing=1, 41 | type_message_passing=['Bchi'], 42 | args_message_passing={'Bchi': {'shared_channels': False, 'shared_l': False}}, 43 | avg_num_neighbors=1, 44 | timeit=False 45 | ) 46 | 47 | for batch in data.train_dataloader(): 48 | exdatabatch = batch 49 | break 50 | 51 | from cace.models import NeuralNetworkPotential 52 | from cace.modules.atomwise import Atomwise 53 | from cace.modules.forces import Forces 54 | 55 | atomwise = Atomwise(n_layers=3, 56 | output_key="pred_energy", 57 | n_hidden=[32,16], 58 | n_out=1, 59 | use_batchnorm=False, 60 | add_linear_nn=True) 61 | 62 | forces = Forces(energy_key="pred_energy", 63 | forces_key="pred_force") 64 | 65 | model = NeuralNetworkPotential( 66 | input_modules=None, 67 | representation=representation, 68 | output_modules=[atomwise,forces] 69 | ) 70 | 71 | #Losses 72 | from cace.tasks import GetLoss 73 | e_loss = GetLoss( 74 | target_name="energy", 75 | predict_name='pred_energy', 76 | loss_fn=torch.nn.MSELoss(), 77 | loss_weight=1, 78 | ) 79 | f_loss = GetLoss( 80 | target_name="force", 81 | predict_name='pred_force', 82 | loss_fn=torch.nn.MSELoss(), 83 | loss_weight=1000, 84 | ) 85 | losses = [e_loss,f_loss] 86 | 87 | #Metrics 88 | from cace.tools import Metrics 89 | e_metric = Metrics( 90 | target_name="energy", 91 | predict_name='pred_energy', 92 | name='e', 93 | metric_keys=["rmse"], 94 | per_atom=True, 95 | ) 96 | f_metric = Metrics( 97 | target_name="force", 98 | predict_name='pred_force', 99 | metric_keys=["rmse"], 100 | name='f', 101 | ) 102 | metrics = [e_metric,f_metric] 103 | 104 | #Init lazy layers 105 | for batch in data.train_dataloader(): 106 | exdatabatch = batch 107 | break 108 | model(exdatabatch) 109 | 110 | #Check for checkpoint and restart if found: 111 | chkpt = None 112 | dev_run = False 113 | if os.path.isdir(f"lightning_logs/{logs_name}"): 114 | latest_version = None 115 | num = 0 116 | while os.path.isdir(f"lightning_logs/{logs_name}/version_{num}"): 117 | latest_version = f"lightning_logs/{logs_name}/version_{num}" 118 | num += 1 119 | if latest_version: 120 | chkpt = glob.glob(f"{latest_version}/checkpoints/*.ckpt")[0] 121 | if chkpt: 122 | 
print("Checkpoint found!",chkpt) 123 | print("Restarting...") 124 | dev_run = False 125 | 126 | progress_bar = True 127 | if on_cluster: 128 | torch.set_float32_matmul_precision('medium') 129 | progress_bar = False 130 | task = LightningTrainingTask(model,losses=losses,metrics=metrics, 131 | logs_directory="lightning_logs",name=logs_name, 132 | scheduler_args={'mode': 'min', 'factor': 0.8, 'patience': 10}, 133 | optimizer_args={'lr': 0.01}, 134 | ) 135 | task.fit(data,dev_run=dev_run,max_epochs=training_epochs,chkpt=chkpt, progress_bar=progress_bar) 136 | 137 | -------------------------------------------------------------------------------- /examples/water_train_w_ft_pl.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import torch 4 | from cace.tasks import LightningData, LightningTrainingTask 5 | 6 | on_cluster = False 7 | if 'SLURM_JOB_CPUS_PER_NODE' in os.environ.keys(): 8 | on_cluster = True 9 | if on_cluster: 10 | root = "/global/scratch/users/king1305/data/water.xyz" 11 | else: 12 | root = "/home/king1305/Apps/cacefit/fit-water/water.xyz" 13 | 14 | logs_directory = "lightning_logs" 15 | logs_name = "water_test" 16 | cutoff = 5.5 17 | avge0 = {1: -187.6043857100553, 8: -93.80219285502734} 18 | batch_size = 4 19 | data = LightningData(root,batch_size=batch_size,cutoff=cutoff,atomic_energies=avge0) 20 | 21 | training_epochs = 500 22 | tuning_epochs = 100 23 | 24 | from cace.representations import Cace 25 | from cace.modules import BesselRBF, GaussianRBF, GaussianRBFCentered 26 | from cace.modules import PolynomialCutoff 27 | 28 | #Model 29 | radial_basis = BesselRBF(cutoff=cutoff, n_rbf=6, trainable=True) 30 | cutoff_fn = PolynomialCutoff(cutoff=cutoff) 31 | 32 | representation = Cace( 33 | zs=[1,8], 34 | n_atom_basis=3, 35 | embed_receiver_nodes=True, 36 | cutoff=cutoff, 37 | cutoff_fn=cutoff_fn, 38 | radial_basis=radial_basis, 39 | n_radial_basis=12, 40 | max_l=3, 41 | max_nu=3, 42 | num_message_passing=1, 43 | type_message_passing=['Bchi'], 44 | args_message_passing={'Bchi': {'shared_channels': False, 'shared_l': False}}, 45 | avg_num_neighbors=1, 46 | timeit=False 47 | ) 48 | 49 | for batch in data.train_dataloader(): 50 | exdatabatch = batch 51 | break 52 | 53 | from cace.models import NeuralNetworkPotential 54 | from cace.modules.atomwise import Atomwise 55 | from cace.modules.forces import Forces 56 | 57 | atomwise = Atomwise(n_layers=3, 58 | output_key="pred_energy", 59 | n_hidden=[32,16], 60 | n_out=1, 61 | use_batchnorm=False, 62 | add_linear_nn=True) 63 | 64 | forces = Forces(energy_key="pred_energy", 65 | forces_key="pred_force") 66 | 67 | model = NeuralNetworkPotential( 68 | input_modules=None, 69 | representation=representation, 70 | output_modules=[atomwise,forces] 71 | ) 72 | 73 | #Losses 74 | from cace.tasks import GetLoss 75 | e_loss = GetLoss( 76 | target_name="energy", 77 | predict_name='pred_energy', 78 | loss_fn=torch.nn.MSELoss(), 79 | loss_weight=1, 80 | ) 81 | f_loss = GetLoss( 82 | target_name="force", 83 | predict_name='pred_force', 84 | loss_fn=torch.nn.MSELoss(), 85 | loss_weight=1000, 86 | ) 87 | losses = [e_loss,f_loss] 88 | 89 | #Metrics 90 | from cace.tools import Metrics 91 | e_metric = Metrics( 92 | target_name="energy", 93 | predict_name='pred_energy', 94 | name='e', 95 | metric_keys=["rmse"], 96 | per_atom=True, 97 | ) 98 | f_metric = Metrics( 99 | target_name="force", 100 | predict_name='pred_force', 101 | metric_keys=["rmse"], 102 | name='f', 103 | ) 104 | metrics = 
[e_metric,f_metric] 105 | 106 | #Init lazy layers 107 | for batch in data.train_dataloader(): 108 | exdatabatch = batch 109 | break 110 | model(exdatabatch) 111 | 112 | #Check for checkpoint and restart if found: 113 | chkpt = None 114 | dev_run = False 115 | if os.path.isdir(f"lightning_logs/{logs_name}"): 116 | latest_version = None 117 | num = 0 118 | while os.path.isdir(f"lightning_logs/{logs_name}/version_{num}"): 119 | latest_version = f"lightning_logs/{logs_name}/version_{num}" 120 | num += 1 121 | if latest_version: 122 | chkpt = glob.glob(f"{latest_version}/checkpoints/*.ckpt")[0] 123 | if chkpt: 124 | print("Checkpoint found!",chkpt) 125 | print("Restarting...") 126 | dev_run = False 127 | 128 | progress_bar = True 129 | if on_cluster: 130 | torch.set_float32_matmul_precision('medium') 131 | progress_bar = False 132 | task = LightningTrainingTask(model,losses=losses,metrics=metrics, 133 | logs_directory="lightning_logs",name=logs_name, 134 | scheduler_args={'mode': 'min', 'factor': 0.8, 'patience': 10}, 135 | optimizer_args={'lr': 0.01}, 136 | ) 137 | task.fit(data,dev_run=dev_run,max_epochs=training_epochs,chkpt=chkpt, progress_bar=progress_bar) 138 | 139 | #If you want to do the "fine-tuning" w/higher energy loss: 140 | if tuning_epochs > 0: 141 | os.system(f"mv lightning_logs/{logs_name}/best_model.pth lightning_logs/{logs_name}/best_model_noft.pth") 142 | e_loss = GetLoss( 143 | target_name="energy", 144 | predict_name='pred_energy', 145 | loss_fn=torch.nn.MSELoss(), 146 | loss_weight=1000, 147 | ) 148 | losses = [e_loss,f_loss] 149 | 150 | #Search for checkpoint 151 | latest_version = None 152 | num = 0 153 | while os.path.isdir(f"lightning_logs/{logs_name}/version_{num}"): 154 | latest_version = f"lightning_logs/{logs_name}/version_{num}" 155 | num += 1 156 | if latest_version: 157 | chkpt = glob.glob(f"{latest_version}/checkpoints/*.ckpt")[0] 158 | task = LightningTrainingTask(model,losses=losses,metrics=metrics, 159 | logs_directory="lightning_logs",name=logs_name, 160 | scheduler_args={'mode': 'min', 'factor': 0.8, 'patience': 10}, 161 | optimizer_args={'lr': 0.01}, 162 | ) 163 | #Uses previous checkpoint params + lr: 164 | task.fit(data,dev_run=dev_run,max_epochs=training_epochs+tuning_epochs,chkpt=chkpt, 165 | progress_bar=progress_bar) 166 | # restart the trainer: 167 | #task.fit(data,dev_run=dev_run,max_epochs=tuning_epochs, 168 | # progress_bar=progress_bar) 169 | 170 | 171 | -------------------------------------------------------------------------------- /scripts/cace_alchemical_rep_v0.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BingqingCheng/cace/82af6fc87733d2f4bf729754313cf77f1293ce72/scripts/cace_alchemical_rep_v0.pth -------------------------------------------------------------------------------- /scripts/compute_avg_e0.py: -------------------------------------------------------------------------------- 1 | import cace 2 | import pickle 3 | from ase.io import read 4 | import sys 5 | 6 | if len(sys.argv) != 3: 7 | print('Usage: python compute_avg_e0.py xyzfile stride') 8 | sys.exit() 9 | 10 | stride = int(sys.argv[2]) 11 | 12 | # read the xyz file and compute the average E0s 13 | xyz = read(sys.argv[1], index=slice(0, None, stride)) 14 | avge0 = cace.tools.compute_average_E0s(xyz) 15 | 16 | print('Average E0s:', avge0) 17 | # save the avge0 dict to a file 18 | with open('avge0.pkl', 'wb') as f: 19 | pickle.dump(avge0, f) 20 | 
-------------------------------------------------------------------------------- /scripts/compute_cace_desc.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from tqdm import tqdm 4 | import numpy as np 5 | import torch 6 | import ase 7 | from ase import Atoms 8 | from ase.io import read,write 9 | 10 | import cace 11 | from cace.data import AtomicData 12 | from cace.representations.cace_representation import Cace 13 | from cace.tools import to_numpy 14 | from cace.tools import scatter_sum 15 | 16 | cutoff = 4.0 17 | batch_size = 10 18 | 19 | 20 | # Get the directory of the current script 21 | script_dir = os.path.dirname(os.path.abspath(__file__)) 22 | data_file_path = os.path.join(script_dir, 'cace_alchemical_rep_v0.pth') 23 | cace_repr = torch.load(data_file_path, map_location='cpu') 24 | 25 | data = read(sys.argv[1], ":") 26 | dataset=[ 27 | AtomicData.from_atoms( 28 | atom, cutoff=cutoff 29 | ) 30 | for atom in data 31 | ] 32 | 33 | data_loader = cace.tools.torch_geometric.dataloader.DataLoader( 34 | dataset, 35 | batch_size=batch_size, 36 | shuffle=False, 37 | drop_last=False, 38 | ) 39 | 40 | n_frame = 0 41 | for sampled_data in tqdm(data_loader): 42 | cace_result_more = cace_repr(sampled_data) 43 | avg_B = scatter_sum( 44 | src=cace_result_more['node_feats'], 45 | index=sampled_data["batch"], 46 | dim=0 47 | ) 48 | #print(avg_B.shape) 49 | #print(torch.bincount(sampled_data['batch'])) 50 | n_configs = avg_B.shape[0] 51 | avg_B_flat = to_numpy(avg_B.reshape(n_configs, -1) / torch.bincount(sampled_data['batch']).reshape(-1, 1)) 52 | #print(avg_B_flat.shape) 53 | for i in range(n_configs): 54 | data[i+n_frame].info['CACE_desc'] = avg_B_flat[i] 55 | n_frame += n_configs 56 | 57 | # check if sys.argv[2] exists 58 | if len(sys.argv) > 2: 59 | prefix = sys.argv[2] 60 | else: 61 | prefix = 'CACE_desc' 62 | write(prefix+'.xyz', data) 63 | 64 | 65 | 66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /scripts/split.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import pickle 3 | import torch 4 | import numpy as np 5 | import sys 6 | import argparse 7 | 8 | from ase.io import read 9 | from cace.data import AtomicData 10 | from cace.tools import save_dataset 11 | 12 | def main(): 13 | # Parse the command line arguments 14 | parser = argparse.ArgumentParser(description='Split a dataset into multiple parts') 15 | parser.add_argument('--input_file', type=str, help='Path to the input dataset') 16 | parser.add_argument('--num_splits', type=int, help='Number of splits', default=4) 17 | parser.add_argument('--output_prefix', type=str, help='Prefix for the output files', default='split') 18 | parser.add_argument('--shuffle', action='store_true', help='Shuffle the dataset before splitting') 19 | parser.add_argument('--cutoff', type=float, help='Cutoff radius for the atomic environment') 20 | parser.add_argument('--energy', type=str, help='Key for the energy data', default='energy') 21 | parser.add_argument('--forces', type=str, help='Key for the forces data', default='forces') 22 | parser.add_argument('--atomic_energies', type=str, help='file for the atomic energies', default=None) 23 | args = parser.parse_args() 24 | 25 | all_xyz_path = args.input_file 26 | num_splits = args.num_splits 27 | cutoff = args.cutoff 28 | data_key = {'energy': args.energy, 'forces': args.forces} 29 | atomic_energies = 
pickle.load(open(args.atomic_energies, 'rb')) if args.atomic_energies is not None else None 30 | 31 | for i in range(num_splits): 32 | all_xyz = read(all_xyz_path, index=slice(i, None, num_splits)) 33 | 34 | dataset=[ 35 | AtomicData.from_atoms(atoms, cutoff=cutoff, data_key=data_key, atomic_energies=atomic_energies).to_dict() # Convert to dictionary 36 | for atoms in all_xyz 37 | ] 38 | 39 | shuffle = args.shuffle # Set to True if you want to shuffle the data 40 | save_dataset(dataset, args.output_prefix+'_'+str(i)+'.h5') 41 | 42 | if __name__ == '__main__': 43 | main() 44 | 45 | """ 46 | load the dataset 47 | # Load the dataset 48 | loaded_dataset = load_dataset('test.h5') 49 | print(loaded_dataset[0].keys()) 50 | 51 | dataset = [ 52 | AtomicData(**data) # Convert to AtomicData object 53 | for data in loaded_dataset 54 | ] 55 | """ 56 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | import ase.io 4 | import cace 5 | import pickle 6 | import os 7 | from cace.representations import Cace 8 | from cace.modules import PolynomialCutoff, BesselRBF, Atomwise, Forces 9 | from cace.models.atomistic import NeuralNetworkPotential 10 | from cace.tasks.train import TrainingTask, GetLoss 11 | from cace.tools import Metrics, init_device, compute_average_E0s, setup_logger, get_unique_atomic_number 12 | from cace.tools import parse_arguments 13 | 14 | def main(): 15 | args = parse_arguments() 16 | 17 | setup_logger(level='INFO', tag=args.prefix, directory='./') 18 | device = init_device(args.use_device) 19 | 20 | if args.zs is None: 21 | xyz = ase.io.read(args.train_path, ':') 22 | args.zs = get_unique_atomic_number(xyz) 23 | 24 | # load the avge0 dict from a file if possible 25 | if os.path.exists('avge0.pkl'): 26 | with open('avge0.pkl', 'rb') as f: 27 | avge0 = pickle.load(f) 28 | else: 29 | # Load Dataset 30 | avge0 = compute_average_E0s(xyz) 31 | with open('avge0.pkl', 'wb') as f: 32 | pickle.dump(avge0, f) 33 | 34 | # Prepare Data Loaders 35 | collection = cace.tasks.get_dataset_from_xyz( 36 | train_path=args.train_path, 37 | valid_fraction=args.valid_fraction, 38 | data_key={'energy': args.energy_key, 'forces': args.forces_key}, 39 | atomic_energies=avge0, 40 | cutoff=args.cutoff) 41 | 42 | train_loader = cace.tasks.load_data_loader( 43 | collection=collection, 44 | data_type='train', 45 | batch_size=args.batch_size) 46 | 47 | valid_loader = cace.tasks.load_data_loader( 48 | collection=collection, 49 | data_type='valid', 50 | batch_size=args.valid_batch_size) 51 | 52 | # Configure CACE Representation 53 | cutoff_fn = PolynomialCutoff(cutoff=args.cutoff, p=args.cutoff_fn_p) 54 | radial_basis = BesselRBF(cutoff=args.cutoff, n_rbf=args.n_rbf, trainable=args.trainable_rbf) 55 | cace_representation = Cace( 56 | zs=args.zs, n_atom_basis=args.n_atom_basis, embed_receiver_nodes=args.embed_receiver_nodes, 57 | cutoff=args.cutoff, cutoff_fn=cutoff_fn, radial_basis=radial_basis, 58 | n_radial_basis=args.n_radial_basis, max_l=args.max_l, max_nu=args.max_nu, 59 | device=device, num_message_passing=args.num_message_passing) 60 | 61 | # Configure Atomwise Module 62 | atomwise = Atomwise( 63 | n_layers=args.atomwise_layers, n_hidden=args.atomwise_hidden, residual=args.atomwise_residual, 64 | use_batchnorm=args.atomwise_batchnorm, add_linear_nn=args.atomwise_linear_nn, 65 | output_key='CACE_energy') 66 | 67 | # Configure Forces Module 68 | 
forces = Forces(energy_key='CACE_energy', forces_key='CACE_forces') 69 | 70 | # Assemble Neural Network Potential 71 | cace_nnp = NeuralNetworkPotential(representation=cace_representation, output_modules=[atomwise, forces]).to(device) 72 | 73 | # Phase 1 Training Configuration 74 | optimizer_args = {'lr': args.lr} 75 | scheduler_args = {'mode': 'min', 'factor': args.scheduler_factor, 'patience': args.scheduler_patience} 76 | energy_loss = GetLoss( 77 | target_name='energy', 78 | predict_name='CACE_energy', 79 | loss_fn=torch.nn.MSELoss(), 80 | loss_weight=args.energy_loss_weight) 81 | force_loss = GetLoss( 82 | target_name='forces', 83 | predict_name='CACE_forces', 84 | loss_fn=torch.nn.MSELoss(), 85 | loss_weight=args.force_loss_weight) 86 | 87 | 88 | e_metric = Metrics( 89 | target_name='energy', 90 | predict_name='CACE_energy', 91 | name='e/atom', 92 | per_atom=True 93 | ) 94 | 95 | f_metric = Metrics( 96 | target_name='forces', 97 | predict_name='CACE_forces', 98 | name='f' 99 | ) 100 | 101 | for _ in range(args.num_restart): 102 | # Initialize and Fit Training Task for Phase 1 103 | task = TrainingTask( 104 | model=cace_nnp, losses=[energy_loss, force_loss], metrics=[e_metric, f_metric], 105 | device=device, optimizer_args=optimizer_args, scheduler_cls=torch.optim.lr_scheduler.ReduceLROnPlateau, 106 | scheduler_args=scheduler_args, max_grad_norm=args.max_grad_norm, ema=args.ema, 107 | ema_start=args.ema_start, warmup_steps=args.warmup_steps) 108 | 109 | task.fit(train_loader, valid_loader, epochs=int(args.epochs/args.num_restart), print_stride=0) 110 | task.save_model(args.prefix+'_phase_1.pth') 111 | 112 | # Phase 2 Training Adjustment 113 | energy_loss_2 = GetLoss('energy', 'CACE_energy', torch.nn.MSELoss(), args.second_phase_energy_loss_weight) 114 | task.update_loss([energy_loss_2, force_loss]) 115 | 116 | # Fit Training Task for Phase 2 117 | task.fit(train_loader, valid_loader, epochs=args.second_phase_epochs) 118 | task.save_model(args.prefix+'_phase_2.pth') 119 | 120 | if __name__ == '__main__': 121 | main() 122 | 123 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='CACE', 5 | version='0.1.0', 6 | author='Bingqing Cheng', 7 | author_email='tonicbq@gmail.com', 8 | description='Cartesian Atomic Cluster Expansion Machine Learning Potential', 9 | packages=find_packages(), 10 | install_requires=[ 11 | 'numpy', 12 | 'ase<=3.22.1', 13 | 'torch', 14 | 'matscipy', 15 | ], 16 | classifiers=[ 17 | 'Programming Language :: Python :: 3', 18 | 'License :: OSI Approved :: MIT License', 19 | 'Operating System :: OS Independent', 20 | ], 21 | python_requires='>=3.6', 22 | ) 23 | 24 | -------------------------------------------------------------------------------- /tests/ewald_nopbc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cace 3 | from cace.modules import EwaldPotential 4 | import sys 5 | torch.set_default_dtype(torch.float32) 6 | 7 | ep = EwaldPotential(dl=1., 8 | sigma=2, 9 | exponent=1, 10 | feature_key='q', 11 | aggregation_mode='sum', 12 | remove_self_interaction=True, 13 | compute_field=True) 14 | 15 | # set the same random seed for reproducibility 16 | torch.manual_seed(sys.argv[1]) 17 | r = torch.rand(10, 3) * 8 # Random positions in a 10x10x10 box 18 | q = torch.rand(10) * 2 # Random charges 19 | 20 | q -= 
torch.mean(q) 21 | box = torch.tensor([30.0, 30.0, 30.0], dtype=torch.float32) # Box dimensions 22 | 23 | #print(q) 24 | #exit() 25 | 26 | # Replicate the box 2x2x2 times 27 | 28 | ew_1, field_1 = ep.compute_potential_optimized(torch.tensor(r), torch.tensor(q).unsqueeze(1), torch.tensor(box), compute_field=True) 29 | ew_1_s, field_1_s = ep.compute_potential_realspace(torch.tensor(r), torch.tensor(q), compute_field=True) 30 | print(ew_1, ew_1_s) 31 | print(ew_1.shape, ew_1_s.shape) 32 | print(field_1, field_1_s) 33 | print(field_1.shape, field_1_s.shape) 34 | print(torch.sum(q.unsqueeze(1) * field_1 / 2), torch.sum(q.unsqueeze(1) * field_1_s / 2)) 35 | 36 | -------------------------------------------------------------------------------- /tests/test-cace-representation-rotation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "8e21a83f", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%load_ext autoreload\n", 11 | "%autoreload 2" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "id": "6f361918", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "import os\n", 22 | "os.environ['KMP_DUPLICATE_LIB_OK']='True'" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "id": "7ab78b11", 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "import numpy as np\n", 33 | "import torch\n", 34 | "import torch.nn as nn" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "id": "6c87d406", 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "import sys\n", 45 | "sys.path.append('../')" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "id": "11273bc5", 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "import cace\n", 56 | "from cace import data" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "id": "080d6fbd", 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "from scipy.spatial.transform import Rotation as R" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "id": "810cd94b", 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "torch.set_default_dtype(torch.float64)" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "id": "dc966daf", 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "from ase import Atoms" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "id": "61e728d9", 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "atom=Atoms(numbers=np.array([8, 1, 1, 1]),\n", 97 | " positions=np.array(\n", 98 | " [\n", 99 | " [0.0452, -2.02, 0.0452],\n", 100 | " [1.0145, 0.034, 0.0232],\n", 101 | " [0.0111, 1.041, -0.010],\n", 102 | " [-0.0111, -0.041, 0.510],\n", 103 | " ]\n", 104 | " ),\n", 105 | " pbc=False)\n", 106 | "\n", 107 | "# Created the rotated environment\n", 108 | "rot = R.from_euler(\"z\", 70, degrees=True).as_matrix()\n", 109 | "positions_rotated = np.array(rot @ atom.positions.T).T\n", 110 | "\n", 111 | "rot = R.from_euler(\"x\", 10.6, degrees=True).as_matrix()\n", 112 | "positions_rotated = np.array(rot @ positions_rotated.T).T\n", 113 | "\n", 114 | "rot = R.from_euler(\"y\", 190, degrees=True).as_matrix()\n", 115 | "positions_rotated = np.array(rot @ positions_rotated.T).T\n", 116 | "\n", 117 | 
"atom_rotated=Atoms(numbers=np.array([8, 1, 1, 1]),\n", 118 | " positions=positions_rotated,\n", 119 | " pbc=False)" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "id": "3aa4259a", 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "cutoff = 5.0" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "id": "4d068011", 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "atomic_data = data.AtomicData.from_atoms(atom, cutoff=cutoff)\n", 140 | "atomic_data2 = data.AtomicData.from_atoms(atom_rotated, cutoff=cutoff)" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": null, 146 | "id": "d2f316e0", 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "atomic_data.positions" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "id": "02c3623d", 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "atomic_data2.positions" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "id": "cc49098a", 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "from cace.representations.cace_representation import Cace" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": null, 176 | "id": "5de73a93", 177 | "metadata": {}, 178 | "outputs": [], 179 | "source": [ 180 | "atomic_data" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": null, 186 | "id": "6fdb68fe", 187 | "metadata": {}, 188 | "outputs": [], 189 | "source": [ 190 | "from cace.modules import CosineCutoff, MollifierCutoff, PolynomialCutoff\n", 191 | "from cace.modules import BesselRBF, GaussianRBF, GaussianRBFCentered, ExponentialDecayRBF" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": null, 197 | "id": "d9632c86", 198 | "metadata": {}, 199 | "outputs": [], 200 | "source": [ 201 | "radial_basis = BesselRBF(cutoff=cutoff, n_rbf=5, trainable=False)\n", 202 | "#radial_basis = ExponentialDecayRBF(n_rbf=4, cutoff=cutoff, prefactor=1, trainable=True)\n", 203 | "cutoff_fn = CosineCutoff(cutoff=cutoff)" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": null, 209 | "id": "a6f21fb8", 210 | "metadata": {}, 211 | "outputs": [], 212 | "source": [ 213 | "cace_representation = Cace(\n", 214 | " zs=[1,8],\n", 215 | " n_atom_basis=2,\n", 216 | " cutoff=cutoff,\n", 217 | " cutoff_fn=cutoff_fn,\n", 218 | " radial_basis=radial_basis,\n", 219 | " max_l=6,\n", 220 | " max_nu=4,\n", 221 | " num_message_passing=2,\n", 222 | " timeit=True\n", 223 | " )" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": null, 229 | "id": "e92da60c", 230 | "metadata": {}, 231 | "outputs": [], 232 | "source": [ 233 | "cace_result = cace_representation(atomic_data)" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": null, 239 | "id": "32c0bdcb", 240 | "metadata": {}, 241 | "outputs": [], 242 | "source": [ 243 | "cace_result2 = cace_representation(atomic_data2)" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": null, 249 | "id": "6a65daec", 250 | "metadata": {}, 251 | "outputs": [], 252 | "source": [ 253 | "features = cace_result['node_feats']" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": null, 259 | "id": "39e4436f", 260 | "metadata": {}, 261 | "outputs": [], 262 | "source": [ 263 | "features2 = cace_result2['node_feats']" 264 | ] 265 | }, 266 
| { 267 | "cell_type": "code", 268 | "execution_count": null, 269 | "id": "55cfd8e5", 270 | "metadata": {}, 271 | "outputs": [], 272 | "source": [ 273 | "torch.allclose(features, features2, rtol=1e-05, atol=1e-05)" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": null, 279 | "id": "582511b2", 280 | "metadata": {}, 281 | "outputs": [], 282 | "source": [ 283 | "features.shape" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": null, 289 | "id": "425f0a9e", 290 | "metadata": {}, 291 | "outputs": [], 292 | "source": [ 293 | "features[0,2,:,0,0]" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": null, 299 | "id": "88a1895e", 300 | "metadata": {}, 301 | "outputs": [], 302 | "source": [ 303 | "features[0,1,:,0,1]" 304 | ] 305 | }, 306 | { 307 | "cell_type": "code", 308 | "execution_count": null, 309 | "id": "5f7d3b23", 310 | "metadata": {}, 311 | "outputs": [], 312 | "source": [ 313 | "features[0,1,:,0,2]" 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": null, 319 | "id": "21467404", 320 | "metadata": {}, 321 | "outputs": [], 322 | "source": [ 323 | "features2[0,1,:,0,1]" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": null, 329 | "id": "f8619177", 330 | "metadata": {}, 331 | "outputs": [], 332 | "source": [] 333 | } 334 | ], 335 | "metadata": { 336 | "kernelspec": { 337 | "display_name": "Python 3 (ipykernel)", 338 | "language": "python", 339 | "name": "python3" 340 | }, 341 | "language_info": { 342 | "codemirror_mode": { 343 | "name": "ipython", 344 | "version": 3 345 | }, 346 | "file_extension": ".py", 347 | "mimetype": "text/x-python", 348 | "name": "python", 349 | "nbconvert_exporter": "python", 350 | "pygments_lexer": "ipython3", 351 | "version": "3.8.16" 352 | } 353 | }, 354 | "nbformat": 4, 355 | "nbformat_minor": 5 356 | } 357 | -------------------------------------------------------------------------------- /tests/test_angular.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../') 3 | import time 4 | import torch 5 | from cace.modules import AngularComponent, AngularComponent_GPU 6 | 7 | vectors = torch.rand(10000, 3) 8 | 9 | start_time = time.time() 10 | angular_func = AngularComponent(3) 11 | angular_component = angular_func(vectors) 12 | end_time = time.time() 13 | print(f"Execution time AngularComponent function: {end_time - start_time} seconds") 14 | 15 | 16 | start_time = time.time() 17 | angular_func_GPU = AngularComponent_GPU(3) 18 | angular_component_GPU = angular_func_GPU(vectors) 19 | end_time = time.time() 20 | print(f"Execution time AngularComponent_GPU function: {end_time - start_time} seconds") 21 | 22 | #not supposed to be the same as l_list is different 23 | #print(torch.allclose(angular_component, angular_component_GPU)) 24 | -------------------------------------------------------------------------------- /tests/test_edge_encoder.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../') 3 | import torch 4 | import cace 5 | 6 | from cace.modules import EdgeEncoder 7 | 8 | encoder = EdgeEncoder(directed=True) 9 | edges = torch.tensor([ 10 | [[0, 1], [0, 1]], 11 | [[0, 1], [1, 0]], 12 | [[1, 0], [0, 1]], 13 | [[1, 0], [1, 0]], 14 | ]) 15 | encoded_edges = encoder(edges) 16 | print("edges:", edges) 17 | print(encoded_edges) 18 | 19 | encoder = EdgeEncoder(directed=False) 20 | edges = 
torch.tensor([[[0, 0.2], [0.7, 0]], [[1, 0], [0, 1]],[[1, 0], [1, 0]], [[0.7, 0], [0.0, 0.2]]]) 21 | encoded_edges = encoder(edges) 22 | print("edges:", edges) 23 | print(encoded_edges) 24 | 25 | 26 | -------------------------------------------------------------------------------- /tests/test_elementwise_multiply_tensors.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../') 3 | import torch 4 | 5 | from cace.tools import elementwise_multiply_2tensors, elementwise_multiply_3tensors 6 | 7 | tensor1 = torch.rand([996, 20]) 8 | tensor2 = torch.rand([996, 8]) 9 | result = elementwise_multiply_2tensors(tensor1,tensor2) 10 | print(result.shape) 11 | assert torch.equal(result[0,0], tensor2[0] * tensor1[0,0]), "Tensors are not equal." 12 | 13 | tensor1 = torch.rand([996, 20]) 14 | tensor2 = torch.rand([996, 8]) 15 | tensor3 = torch.rand([996, 6]) 16 | result = elementwise_multiply_3tensors(tensor1,tensor2,tensor3) 17 | print(result.shape) 18 | # TODO: Fix this test 19 | #assert torch.equal(result[0,0], tensor2[0] * tensor1[0,0] * tensor3[0,0]), "Tensors are not equal." 20 | 21 | 22 | -------------------------------------------------------------------------------- /tests/test_ewald.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cace 3 | from cace.modules import EwaldPotential 4 | 5 | ep = EwaldPotential(dl=2, 6 | sigma=1, 7 | exponent=1, 8 | feature_key='q', 9 | aggregation_mode='sum') 10 | 11 | def replicate_box(r, q, box, nx=2, ny=2, nz=2): 12 | """Replicate the simulation box nx, ny, nz times in each direction.""" 13 | n_atoms = r.shape[0] 14 | replicated_r = [] 15 | replicated_q = [] 16 | 17 | for ix in range(nx): 18 | for iy in range(ny): 19 | for iz in range(nz): 20 | shift = torch.tensor([ix, iy, iz], dtype=r.dtype, device=r.device) * box 21 | replicated_r.append(r + shift) 22 | replicated_q.append(q) 23 | 24 | replicated_r = torch.cat(replicated_r) 25 | replicated_q = torch.cat(replicated_q) 26 | 27 | new_box = torch.tensor([nx, ny, nz], dtype=r.dtype, device=r.device) * box 28 | return replicated_r, replicated_q, new_box 29 | 30 | # set the same random seed for reproducibility 31 | torch.manual_seed(0) 32 | r = torch.rand(100, 3) * 10 # Random positions in a 10x10x10 box 33 | q = torch.rand(100) * 2 - 1 # Random charges 34 | 35 | #q -= torch.mean(q) 36 | box = torch.tensor([10.0, 10.0, 10.0], dtype=torch.float64) # Box dimensions 37 | 38 | #print(q) 39 | #exit() 40 | 41 | # Replicate the box 2x2x2 times 42 | replicated_r, replicated_q, new_box = replicate_box(r, q, box, nx=2, ny=2, nz=2) 43 | 44 | ew_1 = ep.compute_potential_optimized(torch.tensor(r), torch.tensor(q).unsqueeze(1), torch.tensor(box)) 45 | ew_1_s = ep.compute_potential(torch.tensor(r), torch.tensor(q).unsqueeze(1), torch.tensor(box)) 46 | ew_2 = ep.compute_potential_optimized(replicated_r, replicated_q.unsqueeze(1), new_box) / 8 47 | ew_2_s = ep.compute_potential(replicated_r, replicated_q.unsqueeze(1), new_box) / 8 48 | print(ew_1, ew_2, ew_1_s, ew_2_s) 49 | 50 | -------------------------------------------------------------------------------- /tests/test_ewald_triclinic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | sys.path.append('../') 4 | import cace 5 | from cace.modules import EwaldPotential 6 | 7 | ep = EwaldPotential(dl=1.5, 8 | sigma=1, 9 | exponent=1, 10 | feature_key='q', 11 | aggregation_mode='sum') 12 | 
13 | def replicate_box(r, q, box, nx=2, ny=2, nz=2): 14 | """Replicate the simulation box nx, ny, nz times in each direction.""" 15 | n_atoms = r.shape[0] 16 | replicated_r = [] 17 | replicated_q = [] 18 | 19 | for ix in range(nx): 20 | for iy in range(ny): 21 | for iz in range(nz): 22 | shift = torch.tensor([ix, iy, iz], dtype=r.dtype, device=r.device) * box 23 | replicated_r.append(r + shift) 24 | replicated_q.append(q) 25 | 26 | replicated_r = torch.cat(replicated_r) 27 | replicated_q = torch.cat(replicated_q) 28 | 29 | new_box = torch.tensor([nx, ny, nz], dtype=r.dtype, device=r.device) * box 30 | return replicated_r, replicated_q, new_box 31 | 32 | # set the same random seed for reproducibility 33 | torch.manual_seed(0) 34 | r = torch.rand(100, 3, ) * 10 # Random positions in a 10x10x10 box 35 | q = torch.rand(100) * 2 - 1 # Random charges 36 | 37 | box = torch.tensor([10.0, 10.0, 10.0], dtype=torch.float64) # Box dimensions 38 | box_3d = torch.tensor([[10.0, 0.0, 0.0], 39 | [0.0, 10.0, 0.0], 40 | [0.0, 0.0, 10.0]]) 41 | box_3d_2 = torch.tensor([[20.0, 0.0, 0.0], 42 | [0.0, 20.0, 0.0], 43 | [0.0, 0.0, 20.0]]) 44 | 45 | # Replicate the box 2x2x2 times 46 | replicated_r, replicated_q, new_box = replicate_box(r, q, box, nx=2, ny=2, nz=2) 47 | 48 | ew_1 = ep.compute_potential_optimized(torch.tensor(r), torch.tensor(q).unsqueeze(1), torch.tensor(box)) 49 | ew_1_s = ep.compute_potential(torch.tensor(r), torch.tensor(q).unsqueeze(1), torch.tensor(box)) 50 | ew_tri = ep.compute_potential_triclinic(torch.tensor(r), torch.tensor(q).unsqueeze(1), torch.tensor(box_3d)) 51 | ew_2 = ep.compute_potential_optimized(replicated_r, replicated_q.unsqueeze(1), new_box) 52 | ew_2_s = ep.compute_potential(replicated_r, replicated_q.unsqueeze(1), new_box) 53 | ew_2_tri = ep.compute_potential_triclinic(replicated_r.to(dtype=torch.float32), replicated_q.to(dtype=torch.float32).unsqueeze(1), box_3d_2) 54 | print('###cubic cell test###') 55 | print(ew_1[0], ew_1_s[0], ew_tri[0]) 56 | print(ew_2[0], ew_2_s[0], ew_2_tri[0]) 57 | print(ew_2[0]/ew_1[0], ew_2_s[0]/ew_1_s[0], ew_2_tri[0]/ew_tri[0]) 58 | 59 | 60 | #triclinic cell test 61 | def replicate_box_tri(r, q, box, nx=2, ny=2, nz=2): 62 | replicated_r = [] 63 | replicated_q = [] 64 | for ix in range(nx): 65 | for iy in range(ny): 66 | for iz in range(nz): 67 | shift = ix * box[0, :] + iy * box[1, :] + iz * box[2, :] 68 | replicated_r.append(r + shift) 69 | replicated_q.append(q) 70 | replicated_r = torch.cat(replicated_r, dim=0) 71 | replicated_q = torch.cat(replicated_q, dim=0) 72 | new_box = torch.stack([nx * box[0, :], ny * box[1, :], nz * box[2, :]], dim=0) 73 | return replicated_r, replicated_q, new_box 74 | 75 | 76 | box_tric = torch.tensor([[10.0, 2.0, 1.0], 77 | [0.0, 9.0, 1.5], 78 | [0.0, 0.0, 10.0]]) 79 | 80 | s_rand = torch.rand(100, 3) 81 | r_tric = torch.matmul(s_rand, box_tric) # Random positions in a triclinic box 82 | q_tric = torch.rand(100) * 2 - 1 83 | 84 | ep = cace.modules.EwaldPotential(dl=2, sigma=1, exponent=1, feature_key='q', aggregation_mode='sum') 85 | 86 | 87 | ew_tric = ep.compute_potential_triclinic(r_tric, q_tric.unsqueeze(1), box_tric) 88 | 89 | # Replicate the cell 2x2x2 times. 
90 | rep_r, rep_q, new_box = replicate_box_tri(r_tric, q_tric, box_tric, nx=2, ny=2, nz=2) 91 | ew_tric_rep = ep.compute_potential_triclinic(rep_r, rep_q.unsqueeze(1), new_box) 92 | 93 | print('###triclinic cell test###') 94 | print("Triclinic energy (original):", ew_tric[0]) 95 | print("Triclinic energy (replicated):", ew_tric_rep[0]) 96 | print("Ratio:", ew_tric_rep[0] / ew_tric[0]) -------------------------------------------------------------------------------- /tests/test_neighborhood.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../') 3 | import numpy as np 4 | from ase.io import read 5 | import cace 6 | #from cace import data 7 | 8 | atoms = read('../datasets/water.xyz','0') 9 | config = cace.data.config_from_atoms(atoms, energy_key ='TotEnergy', forces_key='force') 10 | 11 | edge_index, shifts, unit_shifts = cace.data.get_neighborhood( 12 | positions=config.positions, 13 | cutoff=5, 14 | cell=config.cell, 15 | pbc=config.pbc 16 | ) 17 | 18 | print(np.shape(edge_index)) 19 | 20 | # try if the same number of neighbors as the usual method 21 | from ase.neighborlist import NeighborList 22 | # Generate neighbor list with element-specific cutoffs 23 | cutoffs = [2.5 for atom in atoms] 24 | nl = NeighborList(cutoffs, skin=0.0, self_interaction=False, bothways=True) 25 | nl.update(atoms) 26 | 27 | # Store displacement vectors 28 | displacement_vectors = [] 29 | 30 | for i, atom in enumerate(atoms): 31 | indices, offsets = nl.get_neighbors(i) 32 | for j, offset in zip(indices, offsets): 33 | disp_vector = atoms[j].position - atom.position 34 | disp_vector += np.dot(offset, atoms.get_cell()) 35 | displacement_vectors.append(disp_vector) 36 | 37 | print(np.shape(displacement_vectors)) 38 | 39 | assert np.shape(edge_index)[1] == np.shape(displacement_vectors)[0] 40 | 41 | 42 | -------------------------------------------------------------------------------- /tests/test_pol_triclinic.py: -------------------------------------------------------------------------------- 1 | import torch, math 2 | import sys 3 | sys.path.append('../') 4 | import cace 5 | from cace.modules import Polarization 6 | 7 | #generate random data 8 | torch.random.manual_seed(0) 9 | r_now = torch.randn(192,3, dtype=torch.float64) 10 | q_now = torch.randn(192,1, dtype=torch.float64) 11 | 12 | #box for orignal function 13 | box_now = torch.tensor([10.,10.,10.]) 14 | 15 | #calculate polarization using original function 16 | factor_o = box_now / (1j * 2.* torch.pi) 17 | phase_o = torch.exp(1j * 2.* torch.pi * r_now / box_now) 18 | polarization_o = torch.sum(q_now * phase_o, dim=(0)) * factor_o 19 | print("Original polarization:", polarization_o) 20 | 21 | #calculate polarization using new function 22 | pol_class = Polarization(pbc=True) 23 | box_ortho = torch.tensor([[10.0, 0.0, 0.0], 24 | [0.0, 10.0, 0.0], 25 | [0.0, 0.0, 10.0]], dtype=torch.float64) 26 | pol_ortho, phase_ortho = pol_class.compute_pol_pbc(r_now, q_now, box_ortho) 27 | print("Modified polarization:", pol_ortho) 28 | print(torch.allclose(polarization_o, pol_ortho)) 29 | 30 | 31 | # Check that the polarization is equivalent to the rotated polarization in the triclinic cell 32 | 33 | #Rotation matrix around z-axis, 45 degrees 34 | theta = math.pi/4 35 | R = torch.tensor([[math.cos(theta), -math.sin(theta), 0], 36 | [math.sin(theta), math.cos(theta), 0], 37 | [0, 0, 1]], dtype=torch.float64) 38 | 39 | # Triclinic cell 40 | box_tric = torch.matmul(R, box_ortho) 41 | r_tric = torch.matmul(r_now, 
R) 42 | 43 | pol_tric, phase_tric = pol_class.compute_pol_pbc(r_tric, q_now, box_tric) 44 | pol_expected = torch.matmul(R.to(torch.complex128), pol_ortho.unsqueeze(1)).squeeze() 45 | 46 | print("Triclinic polarization:", pol_tric) 47 | print("Expected triclinic polarization from rotation:", pol_expected) 48 | print(torch.allclose(pol_tric, pol_expected)) --------------------------------------------------------------------------------