├── LICENSE
├── README.md
├── cace
│   ├── __init__.py
│   ├── calculators
│   │   ├── __init__.py
│   │   └── cace_calculator.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── atomic_data.py
│   │   └── neighborhood.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── atomistic.py
│   │   └── combined.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── angular.py
│   │   ├── angular_tools.py
│   │   ├── atomwise.py
│   │   ├── blocks.py
│   │   ├── cutoff.py
│   │   ├── ewald.py
│   │   ├── feature_mix.py
│   │   ├── forces.py
│   │   ├── grad.py
│   │   ├── interaction.py
│   │   ├── les_wrapper.py
│   │   ├── polarization.py
│   │   ├── preprocess.py
│   │   ├── product_basis.py
│   │   ├── radial.py
│   │   ├── radial_transform.py
│   │   ├── symmetrize_basis.py
│   │   ├── transform.py
│   │   ├── type.py
│   │   └── utils.py
│   ├── representations
│   │   ├── __init__.py
│   │   └── cace_representation.py
│   ├── tasks
│   │   ├── __init__.py
│   │   ├── evaluate.py
│   │   ├── lightning.py
│   │   ├── load_data.py
│   │   ├── loss.py
│   │   └── train.py
│   └── tools
│       ├── __init__.py
│       ├── io_utils.py
│       ├── metric.py
│       ├── output.py
│       ├── parser_train.py
│       ├── scatter.py
│       ├── torch_geometric
│       │   ├── README.md
│       │   ├── __init__.py
│       │   ├── batch.py
│       │   ├── data.py
│       │   ├── dataloader.py
│       │   ├── dataset.py
│       │   └── utils.py
│       ├── torch_tools.py
│       └── utils.py
├── examples
│   ├── water_train.py
│   ├── water_train_pl.py
│   └── water_train_w_ft_pl.py
├── scripts
│   ├── cace_alchemical_rep_v0.pth
│   ├── compute_avg_e0.py
│   ├── compute_cace_desc.py
│   ├── split.py
│   └── train.py
├── setup.py
└── tests
    ├── ewald_nopbc.py
    ├── test-cace-representation-rotation.ipynb
    ├── test_angular.py
    ├── test_edge_encoder.py
    ├── test_elementwise_multiply_tensors.py
    ├── test_ewald.py
    ├── test_ewald_triclinic.py
    ├── test_neighborhood.py
    └── test_pol_triclinic.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Bingqing Cheng
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Cartesian Atomic Cluster Expansion for Machine Learning Interatomic Potentials (CACE)
2 |
3 | ## Summary
4 |
5 | The Cartesian Atomic Cluster Expansion (CACE) is a new approach for developing machine learning interatomic potentials. This method utilizes Cartesian coordinates to provide a complete description of atomic environments, maintaining interaction body orders. It integrates low-dimensional embeddings of chemical elements with inter-atomic message passing.
6 |
7 | ## Requirements
8 |
9 | - Python 3.6 or higher
10 | - NumPy
11 | - ASE (Atomic Simulation Environment)
12 | - PyTorch
13 | - matscipy
14 |
15 | ## Installation
16 |
17 | Please refer to the `setup.py` file for installation instructions.
18 |
19 | ## Usage
20 |
21 | Please refer to `scripts/train.py`.
22 |
23 | More example scripts can be found in the [cacefit repository](https://github.com/BingqingCheng/cacefit).
24 |
25 | ## License
26 |
27 | This project is licensed under the MIT License - see the LICENSE file for details.
28 |
29 | ## Citation
30 |
31 | ```text
32 | @article{cheng2024cartesian,
33 | title={Cartesian atomic cluster expansion for machine learning interatomic potentials},
34 | author={Cheng, Bingqing},
35 | journal={npj Computational Materials},
36 | volume={10},
37 | number={1},
38 | pages={157},
39 | year={2024},
40 | publisher={Nature Publishing Group UK London}
41 | }
42 | ```
43 |
44 | ## Contact
45 |
46 | For any queries regarding CACE, please contact Bingqing Cheng at tonicbq@gmail.com.
47 |
48 |
--------------------------------------------------------------------------------
/cace/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = ['data', 'modules', 'tools', 'representations', 'models', 'tasks', 'calculators']
2 |
3 | from . import data
4 | from . import modules
5 | from . import tools
6 | from . import representations
7 | from . import models
8 | from . import tasks
9 | from . import calculators
10 |
--------------------------------------------------------------------------------
/cace/calculators/__init__.py:
--------------------------------------------------------------------------------
1 | from .cace_calculator import CACECalculator
2 |
--------------------------------------------------------------------------------
/cace/calculators/cace_calculator.py:
--------------------------------------------------------------------------------
1 | # the CACE calculator for ASE
2 |
3 | from typing import Union, List
4 |
5 | import numpy as np
6 | import torch
7 |
8 | from ase.calculators.calculator import Calculator, all_changes
9 | from ase.stress import full_3x3_to_voigt_6_stress
10 |
11 | from ..tools import torch_geometric, torch_tools, to_numpy
12 | from ..data import AtomicData
13 |
14 | __all__ = ["CACECalculator"]
15 |
16 | class CACECalculator(Calculator):
17 | """CACE ASE Calculator
18 | args:
19 |         model_path: str or nn.Module, path to the saved model or the model object itself
20 | device: str, device to run on (cuda or cpu)
21 | compute_stress: bool, whether to compute stress
22 | energy_key: str, key for energy in model output
23 | forces_key: str, key for forces in model output
24 | energy_units_to_eV: float, conversion factor from model energy units to eV
25 | length_units_to_A: float, conversion factor from model length units to Angstroms
26 | atomic_energies: dict, dictionary of atomic energies to add to model output
27 | """
28 |
29 | def __init__(
30 | self,
31 | model_path: Union[str, torch.nn.Module],
32 | device: str,
33 | energy_units_to_eV: float = 1.0,
34 | length_units_to_A: float = 1.0,
35 | electric_field_unit: float = 1.0,
36 | compute_stress = False,
37 | energy_key: str = 'energy',
38 | forces_key: str = 'forces',
39 | stress_key: str = 'stress',
40 | bec_key: str = 'bec',
41 | external_field: Union[float,List[float]] = None,
42 |         keep_neutral: bool = True,  # constrain the Born effective charges (BEC) to sum to zero
43 | atomic_energies: dict = None,
44 | output_index: int = None, # only used for multi-output models
45 | **kwargs,
46 | ):
47 |
48 | Calculator.__init__(self, **kwargs)
49 | self.implemented_properties = [
50 | "energy",
51 | "forces",
52 | "stress",
53 | ]
54 |
55 | self.results = {}
56 |
57 | if isinstance(model_path, str):
58 | self.model = torch.load(f=model_path, map_location=device)
59 | elif isinstance(model_path, torch.nn.Module):
60 | self.model = model_path
61 | else:
62 | raise ValueError("model_path must be a string or nn.Module")
63 | self.model.to(device)
64 |
65 | self.device = torch_tools.init_device(device)
66 | self.energy_units_to_eV = energy_units_to_eV
67 | self.length_units_to_A = length_units_to_A
68 | self.electric_field_unit = electric_field_unit
69 |
70 | try:
71 | self.cutoff = self.model.representation.cutoff
72 | except AttributeError:
73 | self.cutoff = self.model.models[0].representation.cutoff
74 |
75 | self.atomic_energies = atomic_energies
76 |
77 | self.compute_stress = compute_stress
78 | self.energy_key = energy_key
79 | self.forces_key = forces_key
80 | self.stress_key = stress_key
81 | self.bec_key = bec_key
82 | self.keep_neutral = keep_neutral
83 |
84 | if external_field is not None:
85 | if isinstance(external_field, float):
86 | self.external_field = external_field
87 | else:
88 | self.external_field = np.array(external_field)
89 | else:
90 | self.external_field = None
91 | self.output_index = output_index
92 |
93 | for param in self.model.parameters():
94 | param.requires_grad = False
95 |
96 | def calculate(self, atoms=None, properties=None, system_changes=all_changes):
97 | """
98 | Calculate properties.
99 | :param atoms: ase.Atoms object
100 | :param properties: [str], properties to be computed, used by ASE internally
101 | :param system_changes: [str], system changes since last calculation, used by ASE internally
102 | :return:
103 | """
104 | # call to base-class to set atoms attribute
105 | Calculator.calculate(self, atoms)
106 |
107 | if not hasattr(self, "output_index"):
108 | self.output_index = None
109 |
110 | # prepare data
111 | data_loader = torch_geometric.dataloader.DataLoader(
112 | dataset=[
113 | AtomicData.from_atoms(
114 | atoms, cutoff=self.cutoff
115 | )
116 | ],
117 | batch_size=1,
118 | shuffle=False,
119 | drop_last=False,
120 | )
121 |
122 | batch_base = next(iter(data_loader)).to(self.device)
123 | batch = batch_base.clone()
124 | output = self.model(batch.to_dict(), training=True, compute_stress=self.compute_stress, output_index=self.output_index)
125 | energy_output = to_numpy(output[self.energy_key])
126 | forces_output = to_numpy(output[self.forces_key])
127 | if self.external_field is not None and self.bec_key is not None:
128 | bec_output = to_numpy(output[self.bec_key])
129 | # subtract atomic energies if available
130 | if self.atomic_energies:
131 | e0 = sum(self.atomic_energies.get(Z, 0) for Z in atoms.get_atomic_numbers())
132 | else:
133 | e0 = 0.0
134 | self.results["energy"] = (energy_output + e0) * self.energy_units_to_eV
135 | self.results["forces"] = forces_output * self.energy_units_to_eV / self.length_units_to_A
136 | if self.external_field is not None:
137 | if isinstance(self.external_field, float):
138 | if self.keep_neutral:
139 | correction = -np.average(bec_output, axis=0)
140 | bec_output += correction
141 | forces_bec = bec_output * self.external_field * self.electric_field_unit
142 | self.results["forces"] += forces_bec # bec_output * self.external_field * self.electric_field_unit
143 | else:
144 | if self.keep_neutral:
145 | correction = -np.average(bec_output, axis=0)
146 | bec_output += correction
147 | forces_bec = bec_output @ self.external_field * self.electric_field_unit
148 | self.results["forces"] += forces_bec # bec_output * self.external_field * self.electric_field_unit
149 |
150 | if self.compute_stress and output[self.stress_key] is not None:
151 | stress = to_numpy(output[self.stress_key])
152 | # stress has units eng / len^3:
153 | self.results["stress"] = (
154 | stress * (self.energy_units_to_eV / self.length_units_to_A**3)
155 | )[0]
156 | self.results["stress"] = full_3x3_to_voigt_6_stress(self.results["stress"])
157 |
158 | return self.results
159 |
--------------------------------------------------------------------------------
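For orientation, a minimal usage sketch (not part of the repository) for this calculator. The model file name below is a placeholder, and the `energy_key`/`forces_key` arguments must match the output keys of the saved model (e.g. `CACE_energy`/`CACE_forces` for a typical CACE potential):

```python
from ase.build import molecule
from cace.calculators import CACECalculator

# "water_model.pth" is a hypothetical model file saved earlier with torch.save()
calc = CACECalculator(
    model_path="water_model.pth",
    device="cpu",
    energy_key="CACE_energy",
    forces_key="CACE_forces",
)

atoms = molecule("H2O")
atoms.calc = calc
print(atoms.get_potential_energy())  # total energy in eV
print(atoms.get_forces())            # forces in eV/Angstrom
```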
/cace/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .atomic_data import AtomicData, default_data_key, get_data_loader
2 | from .neighborhood import get_neighborhood
3 |
4 | __all__ = ["AtomicData", "default_data_key", "get_neighborhood", "get_data_loader"]
5 |
--------------------------------------------------------------------------------
/cace/data/atomic_data.py:
--------------------------------------------------------------------------------
1 | ###########################################################################################
2 | # Atomic Data Class for handling molecules as graphs
3 | # modified from MACE
4 | # This program is distributed under the MIT License (see MIT.md)
5 | ###########################################################################################
6 |
7 | from typing import Optional, Sequence, Dict
8 | from ase import Atoms
9 | import numpy as np
10 | #import torch_geometric
11 | from ..tools import torch_geometric
12 | import torch.nn as nn
13 | import torch.utils.data
14 | from ..tools import voigt_to_matrix
15 |
16 | from .neighborhood import get_neighborhood
17 |
18 | default_data_key = {
19 | "energy": "energy",
20 | "forces": "forces",
21 | "molecular_index": "molecular_index",
22 | "stress": "stress",
23 | "virials": "virials",
24 | "dipole": None,
25 | "charges": None,
26 | "weights": None,
27 | "energy_weight": None,
28 | "force_weight": None,
29 | "stress_weight": None,
30 | "virial_weight": None,
31 | }
32 |
33 | class AtomicData(torch_geometric.data.Data):
34 | atomic_numbers: torch.Tensor
35 | num_graphs: torch.Tensor
36 | num_nodes: torch.Tensor
37 | batch: torch.Tensor
38 | edge_index: torch.Tensor
39 | node_attrs: torch.Tensor
40 | n_atom_basis: torch.Tensor
41 | positions: torch.Tensor
42 | shifts: torch.Tensor
43 | unit_shifts: torch.Tensor
44 | cell: torch.Tensor
45 | forces: torch.Tensor
46 | molecular_index: torch.Tensor
47 | energy: torch.Tensor
48 | stress: torch.Tensor
49 | virials: torch.Tensor
50 | dipole: torch.Tensor
51 | charges: torch.Tensor
52 | weight: torch.Tensor
53 | energy_weight: torch.Tensor
54 | forces_weight: torch.Tensor
55 | stress_weight: torch.Tensor
56 | virials_weight: torch.Tensor
57 |
58 | def __init__(
59 | self,
60 | edge_index: torch.Tensor, # [2, n_edges], always sender -> receiver
61 | atomic_numbers: torch.Tensor, # [n_nodes]
62 | positions: torch.Tensor, # [n_nodes, 3]
63 | shifts: torch.Tensor, # [n_edges, 3],
64 | unit_shifts: torch.Tensor, # [n_edges, 3]
65 | num_nodes: Optional[torch.Tensor] = None, #[,]
66 | cell: Optional[torch.Tensor] = None, # [3,3]
67 | forces: Optional[torch.Tensor] = None, # [n_nodes, 3]
68 | molecular_index: Optional[torch.Tensor] = None, # [n_nodes]
69 | energy: Optional[torch.Tensor] = None, # [, ]
70 | stress: Optional[torch.Tensor] = None, # [1,3,3]
71 | virials: Optional[torch.Tensor] = None, # [1,3,3]
72 | additional_info: Optional[Dict] = None,
73 | #dipole: Optional[torch.Tensor], # [, 3]
74 | #charges: Optional[torch.Tensor], # [n_nodes, ]
75 | ):
76 | # Check shapes
77 | if num_nodes is None:
78 | num_nodes = atomic_numbers.shape[0]
79 | else:
80 | assert num_nodes == atomic_numbers.shape[0]
81 | assert edge_index.shape[0] == 2 and len(edge_index.shape) == 2
82 | assert positions.shape == (num_nodes, 3)
83 | assert shifts.shape[1] == 3
84 | assert unit_shifts.shape[1] == 3
85 | assert cell is None or cell.shape == (3, 3)
86 | assert forces is None or forces.shape == (num_nodes, 3)
87 | assert molecular_index is None or molecular_index.shape == (num_nodes,)
88 | assert energy is None or len(energy.shape) == 0
89 | assert stress is None or stress.shape == (1, 3, 3)
90 | assert virials is None or virials.shape == (1, 3, 3)
91 | #assert dipole is None or dipole.shape[-1] == 3
92 | #assert charges is None or charges.shape == (num_nodes,)
93 | # Aggregate data
94 | data = {
95 | "edge_index": edge_index,
96 | "positions": positions,
97 | "shifts": shifts,
98 | "unit_shifts": unit_shifts,
99 | "cell": cell,
100 | "atomic_numbers": atomic_numbers,
101 | "num_nodes": num_nodes,
102 | "forces": forces,
103 | "molecular_index": molecular_index,
104 | "energy": energy,
105 | "stress": stress,
106 | "virials": virials,
107 | #"dipole": dipole,
108 | #"charges": charges,
109 | }
110 | if additional_info is not None:
111 | data.update(additional_info)
112 | super().__init__(**data)
113 |
114 | @classmethod
115 | def from_atoms(
116 | cls,
117 | atoms: Atoms,
118 | cutoff: float,
119 | data_key: Dict[str, str] = None,
120 | atomic_energies: Optional[Dict[int, float]] = None,
121 | ) -> "AtomicData":
122 | if data_key is not None:
123 |             default_data_key.update(data_key)  # dict.update works in place and returns None
124 | data_key = default_data_key
125 | positions = atoms.get_positions()
126 | pbc = tuple(atoms.get_pbc())
127 | cell = np.array(atoms.get_cell())
128 | atomic_numbers = atoms.get_atomic_numbers()
129 |
130 | edge_index, shifts, unit_shifts = get_neighborhood(
131 | positions=positions,
132 | cutoff=cutoff,
133 | pbc=pbc,
134 | cell=cell
135 | )
136 |
137 | try:
138 | energy = atoms.info.get(data_key["energy"], None) # eV
139 |         except Exception:
140 | # this ugly bit is for compatibility with newest ASE versions
141 | if data_key['energy'] == 'energy':
142 | energy = atoms.get_potential_energy()
143 | else:
144 | energy = None
145 |
146 | # subtract atomic energies if available
147 | if atomic_energies and energy is not None:
148 | energy -= sum(atomic_energies.get(Z, 0) for Z in atomic_numbers)
149 | try:
150 | forces = atoms.arrays.get(data_key["forces"], None) # eV / Ang
151 |         except Exception:
152 | if data_key['forces'] == 'forces':
153 | forces = atoms.get_forces()
154 | else:
155 | forces = None
156 | molecular_index = atoms.arrays.get(data_key["molecular_index"], None) # index of molecules
157 |         stress = atoms.info.get(data_key["stress"], None)  # eV / Ang^3
158 | virials = atoms.info.get(data_key["virials"], None)
159 |
160 | # process these to make tensors
161 | cell = (
162 | torch.tensor(cell, dtype=torch.get_default_dtype())
163 | if cell is not None
164 | else torch.tensor(
165 | 3 * [0.0, 0.0, 0.0], dtype=torch.get_default_dtype()
166 | ).view(3, 3)
167 | )
168 |
169 | forces = (
170 | torch.tensor(forces, dtype=torch.get_default_dtype())
171 | if forces is not None
172 | else None
173 | )
174 |
175 | molecular_index = (
176 | torch.tensor(molecular_index, dtype=torch.long)
177 | if molecular_index is not None
178 | else None
179 | )
180 |
181 | energy = (
182 | torch.tensor(energy, dtype=torch.get_default_dtype())
183 | if energy is not None
184 | else None
185 | )
186 | stress = (
187 | voigt_to_matrix(
188 | torch.tensor(stress, dtype=torch.get_default_dtype())
189 | ).unsqueeze(0)
190 | if stress is not None
191 | else None
192 | )
193 | virials = (
194 | torch.tensor(virials, dtype=torch.get_default_dtype()).unsqueeze(0)
195 | if virials is not None
196 | else None
197 | )
198 |
199 | # obtain additional info
200 | # enumerate the data_key and extract data
201 | additional_info = {}
202 | for key, kk in data_key.items():
203 | if kk is None or key in ['energy', 'forces', 'stress', 'virials', 'molecular_index']:
204 | continue
205 | else:
206 | more_info = atoms.info.get(kk, None)
207 | if more_info is None:
208 | more_info = atoms.arrays.get(kk, None)
209 | more_info = (
210 | torch.tensor(more_info, dtype=torch.get_default_dtype())
211 | if more_info is not None
212 | else None
213 | )
214 | additional_info[key] = more_info
215 |
216 | return cls(
217 | edge_index=torch.tensor(edge_index, dtype=torch.long),
218 | positions=torch.tensor(positions, dtype=torch.get_default_dtype()),
219 | shifts=torch.tensor(shifts, dtype=torch.get_default_dtype()),
220 | unit_shifts=torch.tensor(unit_shifts, dtype=torch.get_default_dtype()),
221 | cell=cell,
222 | atomic_numbers=torch.tensor(atomic_numbers, dtype=torch.long),
223 | num_nodes=atomic_numbers.shape[0],
224 | forces=forces,
225 | molecular_index=molecular_index,
226 | energy=energy,
227 | stress=stress,
228 | virials=virials,
229 | additional_info=additional_info,
230 | )
231 |
232 |
233 | def get_data_loader(
234 | dataset: Sequence[AtomicData],
235 | batch_size: int,
236 | shuffle=True,
237 | drop_last=False,
238 | ) -> torch.utils.data.DataLoader:
239 | return torch_geometric.dataloader.DataLoader(
240 | dataset=dataset,
241 | batch_size=batch_size,
242 | shuffle=shuffle,
243 | drop_last=drop_last,
244 | )
245 |
--------------------------------------------------------------------------------
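A small sketch (assuming the defaults above) of turning ASE structures into `AtomicData` graphs and batching them; attribute names on the batch follow the vendored `torch_geometric` conventions:

```python
from ase.build import molecule
from cace.data import AtomicData, get_data_loader

frames = [molecule("H2O"), molecule("CH4")]
dataset = [AtomicData.from_atoms(atoms, cutoff=4.0) for atoms in frames]

loader = get_data_loader(dataset, batch_size=2, shuffle=False)
batch = next(iter(loader))
print(batch.positions.shape)   # [n_atoms_total, 3]
print(batch.edge_index.shape)  # [2, n_edges] across both molecules
print(batch.batch)             # which graph each atom belongs to
```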
/cace/data/neighborhood.py:
--------------------------------------------------------------------------------
1 | ###########################################################################################
2 | # Neighborhood construction
3 | # modified from MACE
4 | # This program is distributed under the MIT License (see MIT.md)
5 | ###########################################################################################
6 |
7 | from typing import Optional, Tuple
8 |
9 | import ase.neighborlist
10 | import numpy as np
11 |
12 |
13 | from matscipy.neighbours import neighbour_list
14 |
15 |
16 | def get_neighborhood(
17 | positions: np.ndarray, # [num_positions, 3]
18 | cutoff: float,
19 | pbc: Optional[Tuple[bool, bool, bool]] = None,
20 | cell: Optional[np.ndarray] = None, # [3, 3]
21 | true_self_interaction=False,
22 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
23 | if pbc is None:
24 | pbc = (False, False, False)
25 |
26 |     if cell is None or not cell.any():  # treat an all-zero cell as "no cell"
27 | cell = np.identity(3, dtype=float)
28 |
29 | assert len(pbc) == 3 and all(isinstance(i, (bool, np.bool_)) for i in pbc)
30 | assert cell.shape == (3, 3)
31 |
32 | pbc_x = pbc[0]
33 | pbc_y = pbc[1]
34 | pbc_z = pbc[2]
35 | identity = np.identity(3, dtype=float)
36 | max_positions = np.max(np.absolute(positions)) + 1
37 | # Extend cell in non-periodic directions
38 | # For models with more than 5 layers, the multiplicative constant needs to be increased.
39 | if not pbc_x:
40 | cell[:, 0] = max_positions * 5 * cutoff * identity[:, 0]
41 | if not pbc_y:
42 | cell[:, 1] = max_positions * 5 * cutoff * identity[:, 1]
43 | if not pbc_z:
44 | cell[:, 2] = max_positions * 5 * cutoff * identity[:, 2]
45 |
46 | sender, receiver, unit_shifts = neighbour_list(
47 | quantities="ijS",
48 | pbc=pbc,
49 | cell=cell,
50 | positions=positions,
51 | cutoff=float(cutoff),
52 | )
53 |
54 | if not true_self_interaction:
55 | # Eliminate self-edges that don't cross periodic boundaries
56 | true_self_edge = sender == receiver
57 | true_self_edge &= np.all(unit_shifts == 0, axis=1)
58 | keep_edge = ~true_self_edge
59 |
60 | # Note: after eliminating self-edges, it can be that no edges remain in this system
61 | sender = sender[keep_edge]
62 | receiver = receiver[keep_edge]
63 | unit_shifts = unit_shifts[keep_edge]
64 |
65 | # Build output
66 | edge_index = np.stack((sender, receiver)) # [2, n_edges]
67 |
68 | # From the docs: With the shift vector S, the distances D between atoms can be computed from
69 | # D = positions[j]-positions[i]+S.dot(cell)
70 | shifts = np.dot(unit_shifts, cell) # [n_edges, 3]
71 |
72 | return edge_index, shifts, unit_shifts
73 |
74 | def get_neighborhood_ASE(
75 | positions: np.ndarray, # [num_positions, 3]
76 | cutoff: float,
77 | pbc: Optional[Tuple[bool, bool, bool]] = None,
78 | cell: Optional[np.ndarray] = None, # [3, 3]
79 | true_self_interaction=False,
80 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
81 | if pbc is None:
82 | pbc = (False, False, False)
83 |
84 |     if cell is None or not cell.any():  # treat an all-zero cell as "no cell"
85 | cell = 1000. * np.identity(3, dtype=float)
86 | pbc = (False, False, False)
87 |
88 | assert len(pbc) == 3 and all(isinstance(i, (bool, np.bool_)) for i in pbc)
89 | assert cell.shape == (3, 3)
90 |
91 | """
92 | ‘i’ : first atom index
93 | ‘j’ : second atom index
94 | ‘d’ : absolute distance
95 | ‘D’ : distance vector
96 | ‘S’ : shift vector (number of cell boundaries crossed by the bond between atom i and j). With the shift vector S, the distances D between atoms can be computed from: D = positions[j]-positions[i]+S.dot(cell)
97 | """
98 | sender, receiver, unit_shifts = ase.neighborlist.primitive_neighbor_list(
99 | quantities="ijS",
100 | pbc=pbc,
101 | cell=cell,
102 | positions=positions,
103 | cutoff=cutoff,
104 | self_interaction=True, # we want edges from atom to itself in different periodic images
105 | use_scaled_positions=False, # positions are not scaled positions
106 | )
107 |
108 | if not true_self_interaction:
109 | # Eliminate self-edges that don't cross periodic boundaries
110 | true_self_edge = sender == receiver
111 | true_self_edge &= np.all(unit_shifts == 0, axis=1)
112 | keep_edge = ~true_self_edge
113 |
114 | # Note: after eliminating self-edges, it can be that no edges remain in this system
115 | sender = sender[keep_edge]
116 | receiver = receiver[keep_edge]
117 | unit_shifts = unit_shifts[keep_edge]
118 |
119 | # Build output
120 | edge_index = np.stack((sender, receiver)) # [2, n_edges]
121 |
122 | # From the docs: With the shift vector S, the distances D between atoms can be computed from
123 | # D = positions[j]-positions[i]+S.dot(cell)
124 | shifts = np.dot(unit_shifts, cell) # [n_edges, 3]
125 |
126 | # sender = edge_index[0] receiver = edge_index[1]
127 | return edge_index, shifts, unit_shifts
128 |
--------------------------------------------------------------------------------
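A quick sketch of the neighbour-list output for a non-periodic toy system (three collinear atoms, no cell):

```python
import numpy as np
from cace.data import get_neighborhood

# three atoms on a line, 1 Angstrom apart, no periodic boundary conditions
positions = np.array([[0.0, 0.0, 0.0],
                      [1.0, 0.0, 0.0],
                      [2.0, 0.0, 0.0]])
edge_index, shifts, unit_shifts = get_neighborhood(positions, cutoff=1.5)

print(edge_index)   # each pair within 1.5 A appears twice: sender -> receiver
print(shifts)       # all zeros here, since no edge crosses a cell boundary
```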
/cace/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .atomistic import *
2 | from .combined import *
3 |
--------------------------------------------------------------------------------
/cace/models/atomistic.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional, List
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from ..modules import Transform
7 | from ..modules import Preprocess
8 | from ..tools import torch_geometric
9 |
10 | __all__ = ["AtomisticModel", "NeuralNetworkPotential"]
11 |
12 |
13 | class AtomisticModel(nn.Module):
14 | """
15 | Base class for atomistic neural network models.
16 | """
17 |
18 | def __init__(
19 | self,
20 | postprocessors: Optional[List[Transform]] = None,
21 | do_postprocessing: bool = False,
22 | ):
23 | """
24 | Args:
25 | postprocessors: Post-processing transforms that may be
26 | initialized using the `datamodule`, but are not
27 | applied during training.
28 | do_postprocessing: If true, post-processing is activated.
29 | """
30 | super().__init__()
31 | self.do_postprocessing = do_postprocessing
32 | self.postprocessors = nn.ModuleList(postprocessors)
33 | self.required_derivatives: Optional[List[str]] = None
34 | self.model_outputs: Optional[List[str]] = None
35 |
36 | def collect_derivatives(self) -> List[str]:
37 | self.required_derivatives = None
38 | required_derivatives = set()
39 | for m in self.modules():
40 | if (
41 | hasattr(m, "required_derivatives")
42 | and m.required_derivatives is not None
43 | ):
44 | required_derivatives.update(m.required_derivatives)
45 | required_derivatives: List[str] = list(required_derivatives)
46 | self.required_derivatives = required_derivatives
47 |
48 | def collect_outputs(self) -> List[str]:
49 | self.model_outputs = None
50 | model_outputs = set()
51 | for m in self.modules():
52 | if hasattr(m, "model_outputs") and m.model_outputs is not None:
53 | model_outputs.update(m.model_outputs)
54 | model_outputs: List[str] = list(model_outputs)
55 | self.model_outputs = model_outputs
56 |
57 | def initialize_derivatives(
58 | self, data: Dict[str, torch.Tensor]
59 | ) -> Dict[str, torch.Tensor]:
60 | for p in self.required_derivatives:
61 | if isinstance(data, torch_geometric.Batch):
62 | if p in data.to_dict().keys():
63 | data[p].requires_grad_(True)
64 | else:
65 | if p in data.keys():
66 | data[p].requires_grad_(True)
67 | return data
68 |
69 | def initialize_transforms(self, datamodule):
70 | for module in self.modules():
71 | if isinstance(module, Transform):
72 | module.datamodule(datamodule)
73 |
74 | def postprocess(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
75 | if self.do_postprocessing:
76 | # apply postprocessing
77 | for pp in self.postprocessors:
78 | data = pp(data)
79 | return data
80 |
81 | def extract_outputs(
82 | self, data: Dict[str, torch.Tensor]
83 | ) -> Dict[str, torch.Tensor]:
84 | results = {k: data[k] for k in self.model_outputs}
85 | return results
86 |
87 |
88 | class NeuralNetworkPotential(AtomisticModel):
89 | """
90 | A generic neural network potential class that sequentially applies a list of input
91 | modules, a representation module and a list of output modules.
92 |
93 |     This can be flexibly configured for various tasks, e.g. property prediction or potential
94 |     energy surfaces with response properties.
95 | """
96 |
97 | def __init__(
98 | self,
99 | representation: nn.Module = None,
100 | input_modules: List[nn.Module] = None,
101 | output_modules: List[nn.Module] = None,
102 | postprocessors: Optional[List[Transform]] = None,
103 | do_postprocessing: bool = False,
104 | keep_graph: bool = False,
105 | ):
106 | """
107 | Args:
108 | representation: The module that builds representation from data.
109 | input_modules: Modules that are applied before representation, e.g. to
110 | modify input or add additional tensors for response properties.
111 | output_modules: Modules that predict output properties from the
112 | representation.
113 | postprocessors: Post-processing transforms that may be initialized using the
114 | `datamodule`, but are not applied during training.
115 |             keep_graph: If true, output modules are called as in training mode, so the autograd graph is kept.
116 | do_postprocessing: If true, post-processing is activated.
117 | """
118 | super().__init__(
119 | postprocessors=postprocessors,
120 | do_postprocessing=do_postprocessing,
121 | )
122 | self.representation = representation
123 | if input_modules is None:
124 | preprocessor = Preprocess()
125 | input_modules = [preprocessor]
126 | self.input_modules = nn.ModuleList(input_modules)
127 | self.output_modules = nn.ModuleList(output_modules)
128 |
129 | self.collect_derivatives()
130 | self.collect_outputs()
131 |
132 | self.keep_graph = keep_graph
133 |
134 | def add_module(self, module: nn.Module, module_type: str = "output"):
135 | if module_type == "input":
136 | self.input_modules.append(module)
137 | elif module_type == "output":
138 | self.output_modules.append(module)
139 | self.collect_derivatives()
140 | self.collect_outputs()
141 | else:
142 | raise ValueError(f"Unknown module type {module_type}")
143 |
144 | def remove_module(self, module_index: int, module_type: str = "output"):
145 | if module_type == "output":
146 | del self.output_modules[module_index]
147 | self.collect_derivatives()
148 | self.collect_outputs()
149 | else:
150 | raise ValueError(f"Unknown module type {module_type}")
151 |
152 | def forward(self,
153 | data: Dict[str, torch.Tensor],
154 | training: bool = False,
155 | compute_stress: bool = True,
156 | compute_virials: bool = False,
157 | output_index: int = None, # only used for multiple-head output
158 | ) -> Dict[str, torch.Tensor]:
159 | # initialize derivatives for response properties
160 | data = self.initialize_derivatives(data)
161 |
162 | if 'stress' in self.model_outputs or 'CACE_stress' in self.model_outputs:
163 | compute_stress = True
164 | for m in self.input_modules:
165 | data = m(data, compute_stress=compute_stress, compute_virials=compute_virials)
166 |
167 | if self.representation is not None:
168 | data = self.representation(data)
169 |
170 | for m in self.output_modules:
171 | if hasattr(self, "keep_graph"):
172 | training = training or self.keep_graph
173 | data = m(data, training=training, output_index=output_index)
174 |
175 | # apply postprocessing (if enabled)
176 | data = self.postprocess(data)
177 |
178 | results = self.extract_outputs(data)
179 |
180 | return results
181 |
182 |
--------------------------------------------------------------------------------
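To illustrate the input-module / representation / output-module pipeline, here is a self-contained sketch with a stand-in representation; the real CACE representation lives in `cace/representations/cace_representation.py` and is not reproduced here:

```python
import torch
import torch.nn as nn
from cace.models import NeuralNetworkPotential
from cace.modules import Atomwise

class ToyRepresentation(nn.Module):
    """Stand-in: any module that writes 'node_feats' into the data dict."""
    def __init__(self, n_feats: int = 8, cutoff: float = 4.0):
        super().__init__()
        self.n_feats = n_feats
        self.cutoff = cutoff  # the ASE calculator reads model.representation.cutoff
    def forward(self, data):
        data['node_feats'] = torch.randn(data['positions'].shape[0], self.n_feats)
        return data

model = NeuralNetworkPotential(
    representation=ToyRepresentation(),
    input_modules=[],  # skip the default Preprocess to keep this sketch self-contained
    output_modules=[Atomwise(n_in=8, output_key='CACE_energy')],
)

data = {'positions': torch.randn(5, 3), 'batch': torch.zeros(5, dtype=torch.long)}
print(model(data)['CACE_energy'])  # one (random) energy for the single structure
print(model.model_outputs)         # ['CACE_energy']
```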
/cace/models/combined.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from typing import Dict, List
4 |
5 | __all__ = ['CombinePotential']
6 |
7 | class CombinePotential(nn.Module):
8 | def __init__(
9 | self,
10 | potentials: List[nn.Module],
11 | potential_keys: List[Dict],
12 | operation = None,
13 | ):
14 | """
15 | Combine multiple potentials into a single potential.
16 | Args:
17 | potentials: List of potentials to combine.
18 | potential_keys: List of dictionaries with keys for each potential.
19 | e.g. [pot1, pot2] where
20 | pot1 = {'CACE_energy': 'CACE_energy_intra',
21 | 'CACE_forces': 'CACE_forces_intra',
22 | 'weight': 1.
23 | }
24 |
25 | pot2 = {'CACE_energy': 'CACE_energy_inter',
26 | 'CACE_forces': 'CACE_forces_inter',
27 | 'weight': 0.01,
28 | }
29 |             out_keys (derived): keys present in every potential_keys entry (excluding 'weight') are combined into the output.
30 | operation: Operation to combine the outputs of the potentials.
31 | """
32 | super().__init__()
33 | self.models = nn.ModuleList([potential for potential in potentials])
34 | self.potential_keys = potential_keys
35 | self.required_derivatives = []
36 | for potential in potentials:
37 | for d in potential.required_derivatives:
38 | if d not in self.required_derivatives:
39 | self.required_derivatives.append(d)
40 |
41 | self.out_keys = []
42 | for key in potential_keys[0]:
43 | if all(key in potential_key for potential_key in self.potential_keys) and key != 'weight':
44 | self.out_keys.append(key)
45 |
46 | if operation is None:
47 | # Default operation (sum)
48 | self.operation = self.default_operation
49 | else:
50 | self.operation = operation
51 |
52 | def default_operation(self, my_list):
53 | return torch.stack(my_list).sum(0)
54 |
55 |
56 | def forward(self,
57 | data: Dict[str, torch.Tensor],
58 | training: bool = False,
59 | compute_stress: bool = False,
60 | compute_virials: bool = False,
61 | output_index: int = None, # only used for multiple-head output
62 | ) -> Dict[str, torch.Tensor]:
63 | results = {}
64 | output = {}
65 | for i, potential in enumerate(self.models):
66 | result = potential(data, training, compute_stress, compute_virials, output_index)
67 | results[i] = result
68 | output.update(result)
69 |
70 | for key in self.out_keys:
71 | values = []
72 | for i, potential_key in enumerate(self.potential_keys):
73 | v_now = results[i][potential_key[key]]
74 | if 'weight' in potential_key:
75 |                     v_now = v_now * potential_key['weight']  # avoid modifying the stored result in place
76 | values.append(v_now)
77 | if values:
78 | output[key] = self.operation(values)
79 | return output
80 |
--------------------------------------------------------------------------------
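A self-contained sketch of how `CombinePotential` merges two sub-potentials; the `ToyPotential` class below is a stand-in for trained `NeuralNetworkPotential` instances and only exists for this example:

```python
import torch
import torch.nn as nn
from cace.models import CombinePotential

class ToyPotential(nn.Module):
    """Stand-in potential that returns a fixed energy under a given output key."""
    def __init__(self, key: str, value: float):
        super().__init__()
        self.key = key
        self.value = value
        self.required_derivatives = []
    def forward(self, data, training=False, compute_stress=False,
                compute_virials=False, output_index=None):
        return {self.key: torch.tensor([self.value])}

pot1_keys = {'CACE_energy': 'CACE_energy_intra', 'weight': 1.0}
pot2_keys = {'CACE_energy': 'CACE_energy_inter', 'weight': 0.01}

combined = CombinePotential(
    [ToyPotential('CACE_energy_intra', -10.0), ToyPotential('CACE_energy_inter', -100.0)],
    [pot1_keys, pot2_keys],
)
out = combined({})
print(out['CACE_energy'])  # tensor([-11.]) = 1.0 * (-10.0) + 0.01 * (-100.0)
```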
/cace/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .preprocess import *
2 |
3 | from .angular import *
4 |
5 | from .angular_tools import *
6 |
7 | from .symmetrize_basis import *
8 |
9 | from .product_basis import *
10 |
11 | from .cutoff import *
12 |
13 | from .radial import *
14 |
15 | from .radial_transform import *
16 |
17 | from .type import *
18 |
19 | from .utils import *
20 |
21 | from .blocks import *
22 |
23 | from .atomwise import *
24 |
25 | from .ewald import *
26 |
27 | from .forces import *
28 |
29 | from .interaction import *
30 |
31 | from .transform import *
32 |
33 | from .feature_mix import *
34 |
35 | # from .node_edge_former import *
36 |
37 | from .polarization import *
38 |
39 | from .grad import *
40 |
41 | from .les_wrapper import *
42 |
--------------------------------------------------------------------------------
/cace/modules/angular.py:
--------------------------------------------------------------------------------
1 | ###############################################
2 | # This module contains functions to compute the angular part of the
3 | # edge basis functions
4 | ###############################################
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | from math import factorial
10 | from collections import OrderedDict
11 |
12 | __all__=['AngularComponent', 'AngularComponent_GPU', 'make_lxlylz_list', 'make_lxlylz', 'make_l_dict', 'l_dict_to_lxlylz_list', 'compute_length_lxlylz', 'compute_length_lmax', 'compute_length_lmax_numerical', 'lxlylz_factorial_coef', 'lxlylz_factorial_coef_torch', 'l1l2_factorial_coef']
13 |
14 | import torch
15 | import torch.nn as nn
16 | from collections import OrderedDict
17 |
18 | class AngularComponent(nn.Module):
19 | """ Angular component of the edge basis functions
20 | Optimized for CPU usage (use recursive formula)
21 | """
22 | def __init__(self, l_max):
23 | super().__init__()
24 | self.l_max = l_max
25 | self.precompute_lxlylz()
26 |
27 | def precompute_lxlylz(self):
28 | self.lxlylz_dict = OrderedDict({l: [] for l in range(self.l_max + 1)})
29 | self.lxlylz_dict[0] = [(0, 0, 0)]
30 | for l in range(1, self.l_max + 1):
31 | for prev_lxlylz_combination in self.lxlylz_dict[l - 1]:
32 | for i in range(3):
33 | lxlylz_combination = list(prev_lxlylz_combination)
34 | lxlylz_combination[i] += 1
35 | lxlylz_combination_tuple = tuple(lxlylz_combination)
36 | if lxlylz_combination_tuple not in self.lxlylz_dict[l]:
37 | self.lxlylz_dict[l].append(lxlylz_combination_tuple)
38 | self.lxlylz_list = self._convert_lxlylz_to_list()
39 | # get the start and the end index of the lxlylz_list for each l
40 | self.lxlylz_index = torch.zeros((self.l_max+1, 2), dtype=torch.long)
41 | for l in range(self.l_max+1):
42 | self.lxlylz_index[l, 0] = 0 if l == 0 else self.lxlylz_index[l-1, 1]
43 | self.lxlylz_index[l, 1] = compute_length_lmax(l)
44 |
45 | def forward(self, vectors: torch.Tensor) -> torch.Tensor:
46 |
47 | computed_values = {(0, 0, 0): torch.ones(vectors.size(0), device=vectors.device, dtype=vectors.dtype)}
48 | for l in range(1, self.l_max + 1):
49 | for lxlylz_combination in self.lxlylz_dict[l]:
50 | prev_lxlylz_combination = tuple(l - 1 if i == lxlylz_combination.index(max(lxlylz_combination)) else l for i, l in enumerate(lxlylz_combination))
51 | i = lxlylz_combination.index(max(lxlylz_combination))
52 | computed_values[lxlylz_combination] = computed_values[prev_lxlylz_combination] * vectors[:, i]
53 |
54 | computed_values_list = self._convert_computed_values_to_list(computed_values)
55 | return torch.stack(computed_values_list, dim=1)
56 |
57 | def _convert_lxlylz_to_list(self):
58 | lxlylz_list = []
59 | for l, combinations in self.lxlylz_dict.items():
60 | lxlylz_list.extend(combinations)
61 | return lxlylz_list
62 |
63 | def _convert_computed_values_to_list(self, computed_values):
64 | return [computed_values[comb] for comb in self.lxlylz_list]
65 |
66 | def get_lxlylz_list(self):
67 | if self.lxlylz_list is None:
68 | raise ValueError("You must call forward before getting lxlylz_list")
69 | return self.lxlylz_list
70 |
71 | def get_lxlylz_dict(self):
72 | return self.lxlylz_dict
73 |
74 | def get_lxlylz_index(self):
75 | return self.lxlylz_index
76 |
77 | def __repr__(self):
78 | return f"AngularComponent(l_max={self.l_max})"
79 |
80 | class AngularComponent_GPU(nn.Module):
81 | """ Angular component of the edge basis functions
82 | This version runs faster on gpus but slower on cpus
83 | The ordering of lxlylz_list is different from the CPU version
84 | """
85 | def __init__(self, l_max):
86 | super().__init__()
87 | self.l_max = l_max
88 | self.lxlylz_dict = make_l_dict(l_max)
89 | self.lxlylz_list = l_dict_to_lxlylz_list(self.lxlylz_dict)
90 |
91 | def forward(self, vectors: torch.Tensor) -> torch.Tensor:
92 |
93 | lxlylz_tensor = torch.tensor(self.lxlylz_list, device=vectors.device, dtype=vectors.dtype)
94 |
95 | # Expand vectors and lxlylz_tensor for broadcasting
96 | vectors_expanded = vectors[:, None, :] # Shape: [N, 1, 3]
97 | lxlylz_expanded = lxlylz_tensor[None, :, :] # Shape: [1, M, 3]
98 |
99 | # Compute terms using broadcasting
100 | # Each vector component is raised to the power of corresponding lx, ly, lz
101 | # Somehow this is causing trouble on gpus when doing second order derivatives!!!
102 | terms = vectors_expanded ** lxlylz_expanded # Shape: [N, M, 3]
103 |
104 | # Multiply across the last dimension (x^lx * y^ly * z^lz) for each term
105 | computed_terms = torch.prod(terms, dim=-1) # Shape: [N, M]
106 | # to avoid the mps problem with cumprod
107 | #computed_terms = terms[:, :, 0] * terms[:, :, 1] * terms[:, :, 2]
108 |
109 | return computed_terms
110 |
111 | def get_lxlylz_list(self):
112 | return self.lxlylz_list
113 |
114 | def get_lxlylz_dict(self):
115 | return self.lxlylz_dict
116 |
117 | def __repr__(self):
118 | return f"AngularComponent_GPU(l_max={self.l_max})"
119 |
120 | def make_lxlylz_list(l_max: int):
121 | """
122 | make a list of lxlylz up to l_max
123 | """
124 | l_dict = make_l_dict(l_max)
125 | return l_dict_to_lxlylz_list(l_dict)
126 |
127 | def l_index_select(l):
128 | """ select the index of the lxlylz_list based on l """
129 | return np.arange(compute_length_lmax(l-1), compute_length_lmax(l))
130 |
131 | def make_lxlylz(l):
132 | """
133 | make a list of lxlylz such that lx + ly + lz = l
134 | """
135 | lxlylz = []
136 | for lx in range(l+1):
137 | for ly in range(l+1):
138 | lz = l - lx - ly
139 | if lz >= 0:
140 | lxlylz.append([lx, ly, lz])
141 | #return torch.tensor(lxlylz, dtype=torch.int64)
142 | return lxlylz
143 |
144 | def make_l_dict(l_max):
145 | """
146 | make a ordered dictionary of lxlylz list
147 | up to l_max
148 | """
149 | l_dict = OrderedDict()
150 | for l in range(l_max+1):
151 | l_dict[l] = make_lxlylz(l)
152 | return l_dict
153 |
154 | def l_dict_to_lxlylz_list(l_dict):
155 | """
156 | convert the ordered dictionary to a list of lxlylz
157 | """
158 | lxlylz_list = []
159 | for l in l_dict:
160 | lxlylz_list += l_dict[l]
161 | return lxlylz_list
162 |
163 | def compute_length_lxlylz(l):
164 | """ compute the length of the lxlylz list based on l """
165 | return int((l+1)*(l+2)/2)
166 |
167 | def compute_length_lmax(l_max):
168 | """ compute the length of the lxlylz list based on l_max """
169 | return int((l_max+1)*(l_max+2)*(l_max+3)/6)
170 |
171 | def compute_length_lmax_numerical(l_max):
172 | """ compute the length of the lxlylz list based on l_max numerically"""
173 | length = 0
174 | for l in range(l_max+1):
175 | length += compute_length_lxlylz(l)
176 | return length
177 |
178 |
179 | def l1l2_factorial_coef(l1, l2):
180 | # Ensure inputs are integers
181 | if not all(isinstance(n, int) for n in l1):
182 | raise ValueError("All elements of l1 must be integers.")
183 | if not all(isinstance(n, int) for n in l2):
184 | raise ValueError("All elements of l2 must be integers.")
185 |
186 | # Compute the multinomial coefficient
187 | result = 1
188 | for l1i, l2i in zip(l1, l2):
189 | result *= factorial(l1i + l2i)
190 | result /= factorial(l1i)
191 | result /= factorial(l2i)
192 | return result
193 |
194 | def lxlylz_factorial_coef(lxlylz):
195 | # Ensure inputs are integers
196 | if not all(isinstance(n, int) for n in lxlylz):
197 | raise ValueError("All elements of lxlylz must be integers.")
198 |
199 | # Sort the elements in descending order
200 | sorted_lxlylz = sorted(lxlylz, reverse=True)
201 |
202 | # Compute the sum l = lx + ly + lz
203 | l = sum(sorted_lxlylz)
204 |
205 | # Compute the multinomial coefficient
206 | result = factorial(l)
207 | for lxly in sorted_lxlylz:
208 | result //= factorial(lxly)
209 |
210 | return result
211 |
212 | def lxlylz_factorial_coef_torch(lxlylz) -> torch.Tensor:
213 |
214 | # Check if lxlylz is a tensor, if not, convert to tensor
215 | if not isinstance(lxlylz, torch.Tensor):
216 | lxlylz = torch.tensor(lxlylz, dtype=torch.int64)
217 |
218 | if not torch.all(lxlylz == lxlylz.int()):
219 | raise ValueError("All elements of lxlylz must be integers.")
220 |
221 | sorted_lxlylz, _ = torch.sort(lxlylz, descending=True)
222 | l = torch.sum(sorted_lxlylz)
223 |
224 | result = torch.tensor(1, dtype=torch.int)
225 |
226 | for i in torch.arange(int(sorted_lxlylz[0])):
227 | result = result * (l - i)
228 | result = (result / (i + 1)).floor()
229 |
230 | for i in torch.arange(1, len(sorted_lxlylz)):
231 | for j in torch.arange(int(sorted_lxlylz[i])):
232 | result = result * (l - sorted_lxlylz[:i].sum() - j)
233 | result = (result / (j + 1)).floor()
234 |
235 | return result
236 |
--------------------------------------------------------------------------------
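A short sketch of the angular basis: for `l_max = 3` there are (l_max+1)(l_max+2)(l_max+3)/6 = 20 monomials x^lx y^ly z^lz with lx+ly+lz <= l_max, and `lxlylz_factorial_coef` gives the multinomial prefactor l!/(lx! ly! lz!):

```python
import torch
from cace.modules import AngularComponent, compute_length_lmax, lxlylz_factorial_coef

l_max = 3
angular = AngularComponent(l_max)

vectors = torch.randn(5, 3)        # 5 edge vectors
basis = angular(vectors)           # monomials x^lx * y^ly * z^lz per edge
print(basis.shape)                 # torch.Size([5, 20])
print(compute_length_lmax(l_max))  # 20

# multinomial prefactor, e.g. (lx, ly, lz) = (1, 1, 1): 3!/(1! 1! 1!) = 6
print(lxlylz_factorial_coef((1, 1, 1)))  # 6
```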
/cace/modules/atomwise.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Union, Sequence, Callable, Optional
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from .blocks import Dense, ResidualBlock, build_mlp
7 | from ..tools import scatter_sum
8 |
9 | __all__ = ["Atomwise"]
10 |
11 | class Atomwise(nn.Module):
12 | """
13 | Predicts atom-wise contributions and accumulates global prediction, e.g. for the energy.
14 |
15 | If `aggregation_mode` is None, only the per-atom predictions will be returned.
16 | """
17 |
18 | def __init__(
19 | self,
20 | n_in: Optional[int] = None,
21 | n_out: int = 1,
22 | n_hidden: Optional[Union[int, Sequence[int]]] = None,
23 | n_layers: int = 2,
24 | bias: bool = True,
25 | activation: Callable = F.silu,
26 | aggregation_mode: str = "sum",
27 | feature_key: Union[str, Sequence[int]] = 'node_feats',
28 | output_key: str = "energy",
29 | per_atom_output_key: Optional[str] = None,
30 | descriptor_output_key: Optional[str] = None,
31 | residual: bool = False,
32 | use_batchnorm: bool = False,
33 | add_linear_nn: bool = False,
34 | post_process: Optional[Callable] = None
35 | ):
36 | """
37 | Args:
38 | n_in: input dimension of representation
39 | n_out: output dimension of target property (default: 1)
40 | n_hidden: size of hidden layers.
41 |                 If an integer, the same number of nodes is used for all hidden layers,
42 |                 resulting in a rectangular network.
43 |                 If None, the number of neurons is halved after each layer, starting
44 |                 from n_in, resulting in a pyramidal network.
45 | n_layers: number of layers.
46 | aggregation_mode: one of {sum, avg} (default: sum)
47 | output_key: the key under which the result will be stored
48 | per_atom_output_key: If not None, the key under which the per-atom result will be stored
49 | residual: whether to use residual connections between layers
50 | use_batchnorm: whether to use batch normalization between layers
51 | add_linear_nn: whether to add a linear NN to the output of the MLP
52 | """
53 | super().__init__()
54 | self.output_key = output_key
55 | self.model_outputs = [output_key]
56 | self.per_atom_output_key = per_atom_output_key
57 | self.descriptor_output_key = descriptor_output_key
58 | if self.per_atom_output_key is not None:
59 | self.model_outputs.append(self.per_atom_output_key)
60 | if self.descriptor_output_key is not None:
61 | self.model_outputs.append(self.descriptor_output_key)
62 |
63 | self.n_out = n_out
64 |
65 | if aggregation_mode is None and self.per_atom_output_key is None:
66 | raise ValueError(
67 | "If `aggregation_mode` is None, `per_atom_output_key` needs to be set,"
68 | + " since no accumulated output will be returned!"
69 | )
70 |
71 | self.n_in = n_in
72 | self.n_out = n_out
73 | self.n_hidden = n_hidden
74 | self.n_layers = n_layers
75 | self.activation = activation
76 | self.aggregation_mode = aggregation_mode
77 | self.residual = residual
78 | self.use_batchnorm = use_batchnorm
79 | self.add_linear_nn = add_linear_nn
80 | self.post_process = post_process
81 | self.bias = bias
82 | self.feature_key = feature_key
83 |
84 | if n_in is not None:
85 | self.outnet = build_mlp(
86 | n_in=self.n_in,
87 | n_out=self.n_out,
88 | n_hidden=self.n_hidden,
89 | n_layers=self.n_layers,
90 | activation=self.activation,
91 | residual=self.residual,
92 | use_batchnorm=self.use_batchnorm,
93 | bias=self.bias,
94 | )
95 | if self.add_linear_nn:
96 | self.linear_nn = Dense(
97 | self.n_in,
98 | self.n_out,
99 | bias=self.bias,
100 | activation=None,
101 | use_batchnorm=self.use_batchnorm,
102 | )
103 |
104 | else:
105 | self.outnet = None
106 |
107 | def forward(self,
108 | data: Dict[str, torch.Tensor],
109 | training: bool = None,
110 | output_index: int = None, # only used for multi-head output
111 | ) -> Dict[str, torch.Tensor]:
112 |
113 | # check if self.feature_key exists, otherwise set default
114 | if not hasattr(self, "feature_key") or self.feature_key is None:
115 | self.feature_key = "node_feats"
116 |
117 | # reshape the feature vectors
118 | if isinstance(self.feature_key, str):
119 | if self.feature_key not in data:
120 | raise ValueError(f"Feature key {self.feature_key} not found in data dictionary.")
121 | features = data[self.feature_key]
122 | features = features.reshape(features.shape[0], -1)
123 | elif isinstance(self.feature_key, list):
124 | features = torch.cat([data[key].reshape(data[key].shape[0], -1) for key in self.feature_key], dim=-1)
125 |
126 | if self.n_in is None:
127 | self.n_in = features.shape[1]
128 | else:
129 | assert self.n_in == features.shape[1]
130 |
131 |         if self.outnet is None:
132 | self.outnet = build_mlp(
133 | n_in=self.n_in,
134 | n_out=self.n_out,
135 | n_hidden=self.n_hidden,
136 | n_layers=self.n_layers,
137 | activation=self.activation,
138 | residual=self.residual,
139 | use_batchnorm=self.use_batchnorm,
140 | bias=self.bias,
141 | )
142 | self.outnet = self.outnet.to(features.device)
143 | if self.add_linear_nn:
144 | self.linear_nn = Dense(
145 | self.n_in,
146 | self.n_out,
147 | bias=self.bias,
148 | activation=None,
149 | use_batchnorm=self.use_batchnorm,
150 | )
151 | self.linear_nn = self.linear_nn.to(features.device)
152 | else:
153 | self.linear_nn = None
154 |
155 | # predict atomwise contributions
156 | y = self.outnet(features)
157 | if self.add_linear_nn:
158 | y += self.linear_nn(features)
159 |
160 | # accumulate the per-atom output if necessary
161 | if self.per_atom_output_key is not None:
162 | data[self.per_atom_output_key] = y
163 |
164 | if hasattr(self, "descriptor_output_key") and self.descriptor_output_key is not None:
165 | data[self.descriptor_output_key] = features
166 |
167 | # aggregate
168 | if self.aggregation_mode is not None:
169 | y = scatter_sum(
170 | src=y,
171 | index=data["batch"],
172 | dim=0)
173 | y = torch.squeeze(y, -1)
174 |
175 | if self.aggregation_mode == "avg":
176 | y = y / torch.bincount(data['batch'])
177 |
178 | if hasattr(self, "post_process") and self.post_process is not None:
179 | y = self.post_process(y)
180 | data[self.output_key] = y[:, output_index] if output_index is not None else y
181 | return data
182 |
--------------------------------------------------------------------------------
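A minimal sketch of the `Atomwise` head on random node features (the feature size and output keys here are arbitrary):

```python
import torch
from cace.modules import Atomwise

# per-atom energies from 16-dimensional node features, summed per structure
energy_head = Atomwise(
    n_in=16,
    n_out=1,
    n_layers=2,
    feature_key='node_feats',
    output_key='CACE_energy',
    per_atom_output_key='CACE_energy_atom',
)

data = {
    'node_feats': torch.randn(7, 16),              # 7 atoms in total
    'batch': torch.tensor([0, 0, 0, 1, 1, 1, 1]),  # two structures in the batch
}
data = energy_head(data)
print(data['CACE_energy'].shape)       # torch.Size([2]): one energy per structure
print(data['CACE_energy_atom'].shape)  # torch.Size([7, 1]): per-atom contributions
```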
/cace/modules/blocks.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Union, Optional, Sequence
2 |
3 | import torch
4 | from torch import nn
5 | import torch.nn.functional as F
6 | import numpy as np
7 |
8 | __all__ = ["build_mlp", "Dense", "ResidualBlock", "AtomicEnergiesBlock"]
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | from typing import Union, Callable
14 |
15 | def build_mlp(
16 | n_in: int,
17 | n_out: int,
18 | n_hidden: Optional[Union[int, Sequence[int]]] = None,
19 | n_layers: int = 2,
20 | activation: Callable = F.silu,
21 | residual: bool = False,
22 | use_batchnorm: bool = False,
23 | bias: bool = True,
24 | last_zero_init: bool = False,
25 | ) -> nn.Module:
26 | """
27 | Build multiple layer fully connected perceptron neural network.
28 |
29 | Args:
30 | n_in: number of input nodes.
31 | n_out: number of output nodes.
32 |         n_hidden: number of hidden layer nodes.
33 |             If an integer, the same number of nodes is used for all hidden layers,
34 |             resulting in a rectangular network.
35 |             If None, the number of neurons is halved after each layer, starting
36 |             from n_in, resulting in a pyramidal network.
37 |         n_layers: number of layers.
38 |         activation: activation function. All hidden layers use the same
39 |             activation function; the output layer does not apply any
40 |             activation function.
41 | residual: whether to use residual connections between layers
42 | use_batchnorm: whether to use batch normalization between layers
43 | """
44 | # get list of number of nodes in input, hidden & output layers
45 | if n_hidden is None:
46 | c_neurons = n_in
47 | n_neurons = []
48 | for i in range(n_layers):
49 | n_neurons.append(c_neurons)
50 | c_neurons = max(n_out, c_neurons // 2)
51 | n_neurons.append(n_out)
52 | else:
53 | # get list of number of nodes hidden layers
54 | if type(n_hidden) is int:
55 | n_hidden = [n_hidden] * (n_layers - 1)
56 | else:
57 | n_hidden = list(n_hidden)
58 | n_neurons = [n_in] + n_hidden + [n_out]
59 |
60 | if residual:
61 | if n_layers < 3 or n_layers % 2 == 0:
62 | raise ValueError("Residual networks require at least 3 layers and an odd number of layers")
63 | layers = []
64 | # Create residual blocks every 2 layers
65 | for i in range(0, n_layers - 1, 2):
66 | in_features = n_neurons[i]
67 | out_features = n_neurons[min(i + 2, len(n_neurons) - 1)]
68 | layers.append(
69 | ResidualBlock(
70 | in_features,
71 | out_features,
72 | activation,
73 | skip_interval=2,
74 | use_batchnorm=use_batchnorm,
75 | )
76 | )
77 | else:
78 | # assign a Dense layer (with activation function) to each hidden layer
79 | layers = [
80 | Dense(n_neurons[i], n_neurons[i + 1], activation=activation, use_batchnorm=use_batchnorm, bias=bias)
81 | for i in range(n_layers - 1)
82 | ]
83 |
84 | # assign a Dense layer (without activation function) to the output layer
85 |
86 | if last_zero_init:
87 | layers.append(
88 | Dense(
89 | n_neurons[-2],
90 | n_neurons[-1],
91 | activation=None,
92 | weight_init=torch.nn.init.zeros_,
93 | bias=bias,
94 | )
95 | )
96 | else:
97 | layers.append(
98 | Dense(n_neurons[-2], n_neurons[-1], activation=None, bias=bias)
99 | )
100 | # put all layers together to make the network
101 | out_net = nn.Sequential(*layers)
102 | return out_net
103 |
104 | class Dense(nn.Module):
105 | def __init__(
106 | self,
107 | in_features: int,
108 | out_features: int,
109 | bias: bool = True,
110 | activation: Union[Callable, nn.Module] = nn.Identity(),
111 | use_batchnorm: bool = False,
112 | ):
113 | """
114 | Fully connected linear layer with an optional activation function and batch normalization.
115 |
116 | Args:
117 | in_features (int): Number of input features.
118 | out_features (int): Number of output features.
119 | bias (bool): If False, the layer will not have a bias term.
120 | activation (Callable or nn.Module): Activation function. Defaults to Identity.
121 | use_batchnorm (bool): If True, include a batch normalization layer.
122 | """
123 | super().__init__()
124 | self.use_batchnorm = use_batchnorm
125 |
126 | # Dense layer
127 | self.linear = nn.Linear(in_features, out_features, bias)
128 |
129 | # Activation function
130 | self.activation = activation
131 | if self.activation is None:
132 | self.activation = nn.Identity()
133 |
134 | # Batch normalization layer
135 | if self.use_batchnorm:
136 | self.batchnorm = nn.BatchNorm1d(out_features)
137 |
138 | def forward(self, input: torch.Tensor):
139 | y = self.linear(input)
140 | if self.use_batchnorm:
141 | y = self.batchnorm(y)
142 | y = self.activation(y)
143 | return y
144 |
145 |
146 | class ResidualBlock(nn.Module):
147 | """
148 | A residual block with flexible number of dense layers, optional batch normalization,
149 | and a skip connection.
150 |
151 | Args:
152 | in_features: Number of input features.
153 | out_features: Number of output features.
154 | activation: Activation function to be used in the dense layers.
155 | skip_interval: Number of layers between each skip connection.
156 | use_batchnorm: Boolean indicating whether to use batch normalization.
157 | """
158 | def __init__(self, in_features, out_features, activation, skip_interval=2, use_batchnorm=True):
159 | super().__init__()
160 | self.skip_interval = skip_interval
161 | self.use_batchnorm = use_batchnorm
162 | self.layers = nn.ModuleList()
163 |
164 | # Skip connection with optional dimension matching and batch normalization
165 | if in_features != out_features:
166 | skip_layers = [Dense(in_features, out_features, activation=None)]
167 | if self.use_batchnorm:
168 | skip_layers.append(nn.BatchNorm1d(out_features))
169 | self.skip = nn.Sequential(*skip_layers)
170 | else:
171 | self.skip = nn.Identity()
172 |
173 | # Create dense layers with optional batch normalization
174 | for _ in range(skip_interval):
175 | self.layers.append(Dense(in_features, out_features, activation=activation))
176 | if self.use_batchnorm:
177 | self.layers.append(nn.BatchNorm1d(out_features))
178 | in_features = out_features # Update in_features for the next layer
179 |
180 | def forward(self, x):
181 | identity = self.skip(x)
182 | out = x
183 |
184 | # Forward through dense layers with skip connections and optional batch normalization
185 | for i, layer in enumerate(self.layers):
186 | out = layer(out)
187 | if (i + 1) % self.skip_interval == 0:
188 | out += identity
189 |
190 | return out
191 |
192 | class AtomicEnergiesBlock(nn.Module):
193 | def __init__(self, nz:int, trainable=True, atomic_energies: Optional[Union[np.ndarray, torch.Tensor]]=None):
194 | super().__init__()
195 |         if atomic_energies is None:
196 |             atomic_energies = torch.zeros(nz)
197 |         else:
198 |             atomic_energies = torch.as_tensor(atomic_energies, dtype=torch.get_default_dtype())
199 |             assert atomic_energies.ndim == 1
200 |         if trainable:
201 |             self.atomic_energies = nn.Parameter(atomic_energies)
202 |         else:
203 |             self.register_buffer("atomic_energies", atomic_energies)
204 |
205 | def forward(
206 | self, x: torch.Tensor # one-hot of elements [..., n_elements]
207 | ) -> torch.Tensor: # [..., ]
208 | return torch.matmul(x, self.atomic_energies)
209 |
210 | def __repr__(self):
211 | formatted_energies = ", ".join([f"{x:.4f}" for x in self.atomic_energies])
212 | return f"{self.__class__.__name__}(energies=[{formatted_energies}])"
213 |
--------------------------------------------------------------------------------
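A brief usage sketch for the building blocks above (illustrative only; it assumes the package is installed and that Dense and AtomicEnergiesBlock live in cace/modules/blocks.py as the tree suggests):

import torch
from cace.modules.blocks import Dense, AtomicEnergiesBlock

# a fully connected layer with a SiLU activation, applied to 5 feature vectors of size 8
dense = Dense(in_features=8, out_features=16, activation=torch.nn.SiLU())
print(dense(torch.randn(5, 8)).shape)     # torch.Size([5, 16])

# per-element reference energies selected through one-hot element encodings
e0 = AtomicEnergiesBlock(nz=2, trainable=False,
                         atomic_energies=torch.tensor([-13.6, -75.0]))
one_hot = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]])  # three atoms, two elements
print(e0(one_hot))                        # one reference energy per atom
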
/cace/modules/cutoff.py:
--------------------------------------------------------------------------------
1 | ###########################################################################################
2 | # Radial basis cutoff
3 | # modified from mace/mace/modules/radials.py and schnetpack/src/schnetpack/nn/cutoff.py
4 | # This program is distributed under the MIT License (see MIT.md)
5 | ###########################################################################################
6 |
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 |
11 | __all__ = ["CosineCutoff", "MollifierCutoff", "PolynomialCutoff", "SwitchFunction"]
12 |
13 | def cosine_cutoff(input: torch.Tensor, cutoff: torch.Tensor):
14 |     r""" Behler-style cosine cutoff.
15 |
16 | .. math::
17 | f(r) = \begin{cases}
18 | 0.5 \times \left[1 + \cos\left(\frac{\pi r}{r_\text{cutoff}}\right)\right]
19 | & r < r_\text{cutoff} \\
20 | 0 & r \geqslant r_\text{cutoff} \\
21 | \end{cases}
22 |
23 |     Args:
24 |         input (torch.Tensor): input distances.
25 |         cutoff (torch.Tensor): cutoff radius.
26 | """
27 |
28 | # Compute values of cutoff function
29 | input_cut = 0.5 * (torch.cos(input * np.pi / cutoff) + 1.0)
30 | # Remove contributions beyond the cutoff radius
31 | input_cut *= (input < cutoff).float()
32 | return input_cut
33 |
34 |
35 | class CosineCutoff(nn.Module):
36 | r""" Behler-style cosine cutoff module.
37 |
38 | .. math::
39 | f(r) = \begin{cases}
40 | 0.5 \times \left[1 + \cos\left(\frac{\pi r}{r_\text{cutoff}}\right)\right]
41 | & r < r_\text{cutoff} \\
42 | 0 & r \geqslant r_\text{cutoff} \\
43 | \end{cases}
44 |
45 | """
46 |
47 | def __init__(self, cutoff: float):
48 | """
49 | Args:
50 |             cutoff (float): cutoff radius.
51 | """
52 | super().__init__()
53 | self.register_buffer("cutoff", torch.tensor(cutoff, dtype=torch.get_default_dtype()))
54 |
55 | def forward(self, input: torch.Tensor):
56 | return cosine_cutoff(input, self.cutoff)
57 |
58 | def __repr__(self):
59 | return f"{self.__class__.__name__}(cutoff={self.cutoff})"
60 |
61 | def mollifier_cutoff(input: torch.Tensor, cutoff: torch.Tensor, eps: torch.Tensor):
62 | r""" Mollifier cutoff scaled to have a value of 1 at :math:`r=0`.
63 |
64 | .. math::
65 | f(r) = \begin{cases}
66 | \exp\left(1 - \frac{1}{1 - \left(\frac{r}{r_\text{cutoff}}\right)^2}\right)
67 | & r < r_\text{cutoff} \\
68 | 0 & r \geqslant r_\text{cutoff} \\
69 | \end{cases}
70 |
71 | Args:
72 | cutoff: Cutoff radius.
73 | eps: Offset added to distances for numerical stability.
74 |
75 | """
76 | mask = (input + eps < cutoff).float()
77 | exponent = 1.0 - 1.0 / (1.0 - torch.pow(input * mask / cutoff, 2))
78 | cutoffs = torch.exp(exponent)
79 | cutoffs = cutoffs * mask
80 | return cutoffs
81 |
82 |
83 | class MollifierCutoff(nn.Module):
84 | r""" Mollifier cutoff module scaled to have a value of 1 at :math:`r=0`.
85 |
86 | .. math::
87 | f(r) = \begin{cases}
88 | \exp\left(1 - \frac{1}{1 - \left(\frac{r}{r_\text{cutoff}}\right)^2}\right)
89 | & r < r_\text{cutoff} \\
90 | 0 & r \geqslant r_\text{cutoff} \\
91 | \end{cases}
92 | """
93 |
94 | def __init__(self, cutoff: float, eps: float = 1.0e-7):
95 | """
96 | Args:
97 | cutoff: Cutoff radius.
98 | eps: Offset added to distances for numerical stability.
99 | """
100 | super().__init__()
101 | self.register_buffer("cutoff", torch.tensor(cutoff, dtype=torch.get_default_dtype()))
102 | self.register_buffer("eps", torch.tensor(eps, dtype=torch.get_default_dtype()))
103 |
104 | def forward(self, input: torch.Tensor):
105 | return mollifier_cutoff(input, self.cutoff, self.eps)
106 |
107 | def __repr__(self):
108 | return f"{self.__class__.__name__}(eps={self.eps}, cutoff={self.cutoff})"
109 |
110 | def _switch_component(
111 | x: torch.Tensor, ones: torch.Tensor, zeros: torch.Tensor
112 | ) -> torch.Tensor:
113 | """
114 | Basic component of switching functions.
115 |
116 | Args:
117 | x (torch.Tensor): Switch functions.
118 | ones (torch.Tensor): Tensor with ones.
119 | zeros (torch.Tensor): Zero tensor
120 |
121 | Returns:
122 | torch.Tensor: Output tensor.
123 | """
124 | x_ = torch.where(x <= 0, ones, x)
125 | return torch.where(x <= 0, zeros, torch.exp(-ones / x_))
126 |
127 |
128 | class SwitchFunction(nn.Module):
129 | """
130 | Decays from 1 to 0 between `switch_on` and `switch_off`.
131 | """
132 |
133 | def __init__(self, switch_on: float, switch_off: float):
134 | """
135 |
136 | Args:
137 | switch_on (float): Onset of switch.
138 | switch_off (float): Value from which on switch is 0.
139 | """
140 | super(SwitchFunction, self).__init__()
141 | self.register_buffer("switch_on", torch.tensor(switch_on, dtype=torch.get_default_dtype()))
142 | self.register_buffer("switch_off", torch.tensor(switch_off, dtype=torch.get_default_dtype()))
143 |
144 | def forward(self, x: torch.Tensor) -> torch.Tensor:
145 | """
146 |
147 | Args:
148 | x (torch.Tensor): tensor to which switching function should be applied to.
149 |
150 | Returns:
151 | torch.Tensor: switch output
152 | """
153 | x = (x - self.switch_on) / (self.switch_off - self.switch_on)
154 |
155 | ones = torch.ones_like(x)
156 | zeros = torch.zeros_like(x)
157 | fp = _switch_component(x, ones, zeros)
158 | fm = _switch_component(1 - x, ones, zeros)
159 |
160 | f_switch = torch.where(x <= 0, ones, torch.where(x >= 1, zeros, fm / (fp + fm)))
161 | return f_switch
162 |
163 | def __repr__(self):
164 | return f"{self.__class__.__name__}(switch_on={self.switch_on}, switch_off={self.switch_off})"
165 |
166 | class PolynomialCutoff(nn.Module):
167 | """
168 | Klicpera, J.; Groß, J.; Günnemann, S. Directional Message Passing for Molecular Graphs; ICLR 2020.
169 | Equation (8)
170 | """
171 |
172 | p: torch.Tensor
173 | cutoff: torch.Tensor
174 |
175 | def __init__(self, cutoff: float, p=6):
176 | super().__init__()
177 | self.register_buffer("p", torch.tensor(p, dtype=torch.get_default_dtype()))
178 | self.register_buffer(
179 | "cutoff", torch.tensor(cutoff, dtype=torch.get_default_dtype())
180 | )
181 |
182 | def forward(self, x: torch.Tensor) -> torch.Tensor:
183 | envelope = (
184 | 1.0
185 | - ((self.p + 1.0) * (self.p + 2.0) / 2.0) * torch.pow(x / self.cutoff, self.p)
186 | + self.p * (self.p + 2.0) * torch.pow(x / self.cutoff, self.p + 1)
187 | - (self.p * (self.p + 1.0) / 2) * torch.pow(x / self.cutoff, self.p + 2)
188 | )
189 | return envelope * (x < self.cutoff)
190 |
191 | def __repr__(self):
192 | return f"{self.__class__.__name__}(p={self.p}, cutoff={self.cutoff})"
193 |
--------------------------------------------------------------------------------
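A brief usage sketch for the cutoff envelopes in cace/modules/cutoff.py above (illustrative only; the imports assume the package is installed as cace):

import torch
from cace.modules.cutoff import CosineCutoff, PolynomialCutoff

r = torch.linspace(0.1, 6.0, 5)        # sample interatomic distances
cos_fc = CosineCutoff(cutoff=5.0)
poly_fc = PolynomialCutoff(cutoff=5.0, p=6)

# both envelopes are close to 1 at small r and decay smoothly to 0 at the cutoff radius
print(cos_fc(r))
print(poly_fc(r))
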
/cace/modules/feature_mix.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from typing import Dict
4 |
5 | __all__ = ['FeatureAdd', 'FeatureInteract']
6 |
7 | class FeatureAdd(nn.Module):
8 | """
9 | A class for adding together different features of the data.
10 | """
11 | def __init__(self,
12 | feature_keys: list,
13 | output_key: str):
14 | super().__init__()
15 | self.feature_keys = feature_keys
16 | self.output_key = output_key
17 | self.model_outputs = [output_key]
18 |
19 | def forward(self, data: Dict[str, torch.Tensor], **kwargs) -> Dict[str, torch.Tensor]:
20 | feature_shape = data[self.feature_keys[0]].shape
21 | result = torch.zeros_like(data[self.feature_keys[0]])
22 | for feature_key in self.feature_keys:
23 | if data[feature_key].shape != feature_shape:
24 | raise ValueError(f"Feature {feature_key} has shape {data[feature_key].shape} but expected {feature_shape}")
25 | result += data[feature_key]
26 | data[self.output_key] = result
27 | return data
28 |
29 |
30 | class FeatureInteract(nn.Module):
31 | """
32 | A class for interacting between two multidimensional features by reshaping, performing interaction, and reshaping back.
33 | """
34 | def __init__(self,
35 | feature1_key: str,
36 | feature2_key: str,
37 | output_key: str):
38 | super().__init__()
39 | self.feature1_key = feature1_key
40 | self.feature2_key = feature2_key
41 | self.output_key = output_key
42 | self.model_outputs = [output_key]
43 |
44 | # Weights will be initialized during the forward pass
45 | self.weights = None
46 |
47 | def forward(self, data: Dict[str, torch.Tensor], **kwargs) -> Dict[str, torch.Tensor]:
48 | feature1 = data[self.feature1_key] # Shape: [n, A, B, C]
49 | feature2 = data[self.feature2_key] # Shape: [n, D, E]
50 |
51 | # Ensure the first dimensions match
52 | if feature1.shape[0] != feature2.shape[0]:
53 | raise ValueError(f"Feature1 has shape {feature1.shape} but feature2 has shape {feature2.shape}. Shapes must match.")
54 |
55 | # Save the original shape for reshaping back
56 | original_shape = feature1.shape # [n, A, B, C]
57 |
58 | # Reshape both features to [n, -1] (flatten all except the first dimension)
59 | n = feature1.shape[0] # The first dimension size (n)
60 | flattened_size1 = feature1.shape[1:].numel() # Product of A, B, C
61 | flattened_size2 = feature2.shape[1:].numel() # Product of D, E
62 | feature1_reshaped = feature1.view(n, -1) # [n, A * B * C]
63 | feature2_reshaped = feature2.view(n, -1) # [n, D * E]
64 |
65 | # Dynamically initialize weights based on the reshaped feature sizes during the forward pass
66 | if self.weights is None:
67 | self.weights = nn.Parameter(torch.randn(flattened_size2, flattened_size1, dtype=torch.get_default_dtype(), device=feature1.device))
68 |
69 | # Perform the interaction using einsum on the reshaped tensors
70 | interaction_result = feature1_reshaped * torch.einsum('ij,jk->ik', feature2_reshaped, self.weights)
71 |
72 | # Reshape the result back to the original shape [n, A, B, C]
73 | interaction_result = interaction_result.view(*original_shape)
74 |
75 | # Save the result in the data dictionary
76 | data[self.output_key] = interaction_result
77 | return data
78 |
79 |
--------------------------------------------------------------------------------
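A brief usage sketch for the feature-mixing modules in cace/modules/feature_mix.py above (illustrative only; the tensor shapes and dictionary keys are made up for the example):

import torch
from cace.modules.feature_mix import FeatureAdd, FeatureInteract

data = {
    "feat_a": torch.randn(4, 3, 2),    # two features with identical shapes
    "feat_b": torch.randn(4, 3, 2),
    "scalars": torch.randn(4, 5),      # a second feature sharing only the leading (batch) dimension
}

add = FeatureAdd(feature_keys=["feat_a", "feat_b"], output_key="feat_sum")
data = add(data)                       # data["feat_sum"] = feat_a + feat_b

mix = FeatureInteract(feature1_key="feat_sum", feature2_key="scalars", output_key="mixed")
data = mix(data)                       # data["mixed"] has the same shape as feat_sum
print(data["feat_sum"].shape, data["mixed"].shape)
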
/cace/modules/forces.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | import torch
3 | from torch import nn
4 |
5 | from .utils import get_outputs
6 |
7 | __all__ = ['Forces']
8 |
9 | class Forces(nn.Module):
10 |     """
11 |     Predicts forces and stress as the response of the energy prediction
12 |     to the atomic positions and the cell deformation.
13 |     """
14 |
15 | def __init__(
16 | self,
17 | calc_forces: bool = True,
18 | calc_stress: bool = True,
19 | #calc_virials: bool = False,
20 | energy_key: str = 'energy',
21 | forces_key: str = 'forces',
22 | stress_key: str = 'stress',
23 | virials_key: str = 'virials',
24 | ):
25 | """
26 | Args:
27 | calc_forces: If True, calculate atomic forces.
28 | calc_stress: If True, calculate the stress tensor.
29 | energy_key: Key of the energy in results.
30 | forces_key: Key of the forces in results.
31 | stress_key: Key of the stress in results.
32 | """
33 | super().__init__()
34 | self.calc_forces = calc_forces
35 | self.calc_stress = calc_stress
36 | #self.calc_virials = calc_virials
37 | self.energy_key = energy_key
38 | self.forces_key = forces_key
39 | self.stress_key = stress_key
40 | self.virial_key = virials_key
41 | self.model_outputs = []
42 | if calc_forces:
43 | self.model_outputs.append(forces_key)
44 | if calc_stress:
45 | self.model_outputs.append(stress_key)
46 |
47 | self.required_derivatives = []
48 | if self.calc_forces or self.calc_stress:
49 | self.required_derivatives.append('positions')
50 |
51 | def forward(self, data: Dict[str, torch.Tensor], training: bool = False, output_index: int = None) -> Dict[str, torch.Tensor]:
52 | forces, virials, stress = get_outputs(
53 | energy=data[self.energy_key][:, output_index] if output_index is not None and len(data[self.energy_key].shape) == 2 else data[self.energy_key],
54 | positions=data['positions'],
55 | displacement=data.get('displacement', None),
56 | cell=data.get('cell', None),
57 | training=training,
58 | compute_force=self.calc_forces,
59 | #compute_virials=self.calc_virials,
60 | compute_stress=self.calc_stress
61 | )
62 |
63 | data[self.forces_key] = forces
64 | if self.virial_key is not None:
65 | data[self.virial_key] = virials
66 | if self.stress_key is not None:
67 | data[self.stress_key] = stress
68 | return data
69 |
70 | def __repr__(self):
71 | return (
72 |             f"{self.__class__.__name__}(calc_forces={self.calc_forces}, calc_stress={self.calc_stress})"
73 | )
74 |
--------------------------------------------------------------------------------
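A brief usage sketch for the Forces module above, using a toy quadratic energy in place of a trained model (illustrative only; in a real workflow the energy comes from the CACE model, with positions that require gradients):

import torch
from cace.modules.forces import Forces

positions = torch.randn(4, 3, requires_grad=True)
data = {
    "positions": positions,
    "energy": (positions ** 2).sum().unsqueeze(0),  # toy energy for a single structure, shape [1]
}

forces_module = Forces(calc_forces=True, calc_stress=False)
data = forces_module(data)
print(torch.allclose(data["forces"], -2 * positions))  # True: F = -dE/dr
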
/cace/modules/grad.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional
2 | import torch
3 | from torch import nn
4 |
5 | __all__ = ['Grad']
6 |
7 | class Grad(nn.Module):
8 |     """
9 |     A wrapper that computes the gradient of the quantity stored under y_key
10 |     with respect to the quantity stored under x_key (handles complex-valued y).
11 |     """
12 |
13 | def __init__(
14 | self,
15 | y_key: str,
16 | x_key: str,
17 | output_key: str = 'gradient',
18 | ):
19 | super().__init__()
20 | self.y_key = y_key
21 | self.x_key = x_key
22 | self.output_key = output_key
23 | self.required_derivatives = [self.x_key]
24 | self.model_outputs = [self.output_key]
25 |
26 | def forward(self, data: Dict[str, torch.Tensor], training: bool = True, output_index: int = None) -> Dict[str, torch.Tensor]:
27 | y = data[self.y_key]
28 | x = data[self.x_key]
29 |
30 | if y.is_complex():
31 | get_imag = True
32 | else:
33 | get_imag = False
34 |
35 | if len(y.shape) == 1:
36 | grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(y)]
37 | gradient_real = torch.autograd.grad(
38 | outputs=[y], # [n_graphs, ]
39 | inputs=[x], # [n_nodes, 3]
40 | grad_outputs=grad_outputs,
41 | retain_graph=(training or get_imag), # Make sure the graph is not destroyed during training
42 | create_graph=training, # Create graph for second derivative
43 | allow_unused=True, # For complete dissociation turn to true
44 | )[0] # [n_nodes, 3]
45 |
46 | if get_imag:
47 | gradient_imag = torch.autograd.grad(
48 | outputs=[y/1j], # [n_graphs, ]
49 | inputs=[x], # [n_nodes, 3]
50 | grad_outputs=grad_outputs,
51 | retain_graph=training, # Make sure the graph is not destroyed during training
52 | create_graph=training, # Create graph for second derivative
53 | allow_unused=True, # For complete dissociation turn to true
54 | )[0] # [n_nodes, 3]
55 | else:
56 | gradient_imag = 0.0
57 | else:
58 | dim_y = y.shape[1]
59 | grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(y[:,0])]
60 | gradient_real = torch.stack([
61 | torch.autograd.grad(
62 | outputs=[y[:,i]], # [n_graphs, ]
63 | inputs=[x], # [n_nodes, 3]
64 | grad_outputs=grad_outputs,
65 | retain_graph=(training or (i < dim_y - 1) or get_imag), # Make sure the graph is not destroyed during training
66 | create_graph=training, # Create graph for second derivative
67 | allow_unused=True, # For complete dissociation turn to true
68 | )[0] for i in range(dim_y)
69 | ], axis=2) # [n_nodes, 3, num_energy]
70 | # if y is complex, we need to calculate the imaginary part
71 | if get_imag:
72 | gradient_imag = torch.stack([
73 | torch.autograd.grad(
74 | outputs=[y[:,i]/1j], # [n_graphs, ]
75 | inputs=[x], # [n_nodes, 3]
76 | grad_outputs=grad_outputs,
77 | retain_graph=(training or (i < dim_y - 1)), # Make sure the graph is not destroyed during training
78 | create_graph=training, # Create graph for second derivative
79 | allow_unused=True, # For complete dissociation turn to true
80 | )[0] for i in range(dim_y)
81 | ], axis=2) # [n_nodes, 3, num_energy]
82 | if get_imag:
83 | data[self.output_key] = gradient_real + 1j * gradient_imag
84 | else:
85 | data[self.output_key] = gradient_real
86 |
87 | return data
88 |
89 | def __repr__(self):
90 | return (
91 |             f"{self.__class__.__name__}(function={self.y_key}, variable={self.x_key})"
92 | )
93 |
--------------------------------------------------------------------------------
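A brief usage sketch for the Grad wrapper above, differentiating a toy quantity with respect to the tensor it was computed from (illustrative only):

import torch
from cace.modules.grad import Grad

x = torch.randn(4, 3, requires_grad=True)
data = {"x": x, "y": (x ** 2).sum(dim=-1)}      # y has shape [4], one value per node

grad = Grad(y_key="y", x_key="x", output_key="dy_dx")
data = grad(data)
print(torch.allclose(data["dy_dx"], 2 * x))     # True: d(sum_k x_k^2)/dx = 2x
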
/cace/modules/les_wrapper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from typing import Dict, Sequence, Union
4 |
5 | __all__ = ['LesWrapper']
6 |
7 | class LesWrapper(nn.Module):
8 |     """
9 |     A wrapper for the LES library that computes long-range interactions and Born effective charges (BECs).
10 |     Note that CACE has its own internal implementation of the LES algorithm,
11 |     so it is not strictly necessary to use this wrapper in CACE.
12 |     """
13 | def __init__(self,
14 |                  feature_key: Union[str, Sequence[str]] = 'node_feats',
15 | energy_key: str = 'LES_energy',
16 | charge_key: str = 'LES_charge',
17 | bec_key: str = 'LES_BEC',
18 | compute_energy: bool = True,
19 | compute_bec: bool = False,
20 | bec_output_index: int = None, # option to compute BEC along one axis
21 | ):
22 | super().__init__()
23 | from les import Les
24 | self.les = Les(les_arguments={})
25 |
26 | self.feature_key = feature_key
27 | self.energy_key = energy_key
28 | self.charge_key = charge_key
29 | self.bec_key = bec_key
30 | self.bec_output_index = bec_output_index
31 |
32 | self.compute_energy = compute_energy
33 | self.compute_bec = compute_bec
34 | self.model_outputs = [charge_key]
35 | if compute_energy:
36 | self.model_outputs.append(energy_key)
37 | if compute_bec:
38 | self.model_outputs.append(bec_key)
39 | self.required_derivatives = []
40 | self.required_derivatives.append('cell')
41 |
42 | def set_compute_energy(self, compute_energy: bool):
43 | self.compute_energy = compute_energy
44 |
45 | def set_compute_bec(self, compute_bec: bool):
46 | self.compute_bec = compute_bec
47 |
48 | def set_bec_output_index(self, bec_output_index: int):
49 | self.bec_output_index = bec_output_index
50 |
51 | def forward(self, data: Dict[str, torch.Tensor], **kwargs) -> Dict[str, torch.Tensor]:
52 |
53 | # reshape the feature vectors
54 | if isinstance(self.feature_key, str):
55 | if self.feature_key not in data:
56 | raise ValueError(f"Feature key {self.feature_key} not found in data dictionary.")
57 | features = data[self.feature_key]
58 | features = features.reshape(features.shape[0], -1)
59 | elif isinstance(self.feature_key, list):
60 | features = torch.cat([data[key].reshape(data[key].shape[0], -1) for key in self.feature_key], dim=-1)
61 |
62 | result = self.les(desc=features,
63 | positions=data['positions'],
64 | cell=data['cell'].view(-1, 3, 3),
65 | batch=data["batch"],
66 | compute_energy=self.compute_energy,
67 | compute_bec=self.compute_bec,
68 | bec_output_index=self.bec_output_index,
69 | )
70 |
71 | data[self.charge_key] = result['latent_charges']
72 | if self.compute_energy:
73 | data[self.energy_key] = result['E_lr']
74 | if self.compute_bec:
75 | data[self.bec_key] = result['BEC']
76 | return data
77 |
--------------------------------------------------------------------------------
/cace/modules/polarization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from typing import Dict
4 |
5 | __all__ = ['Polarization', 'Dephase', 'FixedCharge']
6 |
7 | class Polarization(nn.Module):
8 | def __init__(self,
9 | charge_key: str = 'q',
10 | output_key: str = 'polarization',
11 | output_index: int = None, # 0, 1, 2 to select only one component
12 | phase_key: str = 'phase',
13 | remove_mean: bool = True,
14 | pbc: bool = False,
15 | normalization_factor: float = 1./9.48933,
16 | ):
17 | super().__init__()
18 | self.charge_key = charge_key
19 | self.output_key = output_key
20 | self.output_index = output_index
21 | self.phase_key = phase_key
22 | self.model_outputs = [output_key, phase_key]
23 | self.remove_mean = remove_mean
24 | self.pbc = pbc
25 | self.normalization_factor = normalization_factor
26 |
27 |     def forward(self, data: Dict[str, torch.Tensor], training=True, output_index=None) -> Dict[str, torch.Tensor]:
28 |
29 | if data["batch"] is None:
30 | n_nodes = data['positions'].shape[0]
31 | batch_now = torch.zeros(n_nodes, dtype=torch.int64, device=data['positions'].device)
32 | else:
33 | batch_now = data["batch"]
34 |
35 | box = data['cell'].view(-1, 3, 3)
36 |
37 | r = data['positions']
38 | q = data[self.charge_key]
39 |
40 | if q.dim() == 1:
41 | q = q.unsqueeze(1)
42 | if self.remove_mean:
43 | q = q - torch.mean(q, dim=0, keepdim=True)
44 |
45 | # Check the input dimension
46 | n, d = r.shape
47 | assert d == 3, 'r dimension error'
48 | assert n == q.size(0), 'q dimension error'
49 |
50 | unique_batches = torch.unique(batch_now) # Get unique batch indices
51 |
52 | results = []
53 | phases = []
54 | for i in unique_batches:
55 | mask = batch_now == i # Create a mask for the i-th configuration
56 | r_now, q_now, box_now = r[mask], q[mask], box[i]
57 | box_diag = box[i].diagonal(dim1=-2, dim2=-1)
58 |             if (box_diag[0] < 1e-6 and box_diag[1] < 1e-6 and box_diag[2] < 1e-6) or not self.pbc:
59 | # the box is not periodic, we use the direct sum
60 | polarization = torch.sum(q_now * r_now, dim=0)
61 | elif box_diag[0] > 0 and box_diag[1] > 0 and box_diag[2] > 0:
62 | polarization, phase = self.compute_pol_pbc(r_now, q_now, box_now)
63 | if self.output_index is not None:
64 | phase = phase[:,self.output_index]
65 | phases.append(phase)
66 | if self.output_index is not None:
67 | polarization = polarization[self.output_index]
68 | results.append(polarization * self.normalization_factor)
69 | data[self.output_key] = torch.stack(results, dim=0)
70 | if len(phases) > 0:
71 | data[self.phase_key] = torch.cat(phases, dim=0)
72 | else:
73 | data[self.phase_key] = 0.0
74 | return data
75 |
76 | def compute_pol_pbc(self, r_now, q_now, box_now):
77 | r_frac = torch.matmul(r_now, torch.linalg.inv(box_now))
78 | phase = torch.exp(1j * 2.* torch.pi * r_frac)
79 | S = torch.sum(q_now * phase, dim=0)
80 | polarization = torch.matmul(box_now.to(S.dtype),
81 | S.unsqueeze(1)) / (1j * 2.* torch.pi)
82 | return polarization.reshape(-1), phase
83 |
84 | class Dephase(nn.Module):
85 | def __init__(self,
86 | input_key: str = None,
87 | phase_key: str = 'phase',
88 | output_key: str = 'dephased',
89 | input_index: int = None,
90 | ):
91 | super().__init__()
92 | self.input_key = input_key
93 | self.input_index = input_index
94 | self.output_key = output_key
95 | self.phase_key = phase_key
96 | self.model_outputs = [output_key]
97 |
98 |     def forward(self, data: Dict[str, torch.Tensor], training=None, output_index=None) -> Dict[str, torch.Tensor]:
99 | result = data[self.input_key] * data[self.phase_key].unsqueeze(1).conj()
100 | data[self.output_key] = result.real
101 | return data
102 |
103 | class FixedCharge(nn.Module):
104 | def __init__(self,
105 | atomic_numbers_key: str = 'atomic_numbers',
106 | output_key: str = 'q',
107 | charge_dict: Dict[int, float] = None,
108 | normalize: bool = True,
109 | ):
110 | super().__init__()
111 | self.charge_dict = charge_dict
112 | self.atomic_numbers_key = atomic_numbers_key
113 | self.output_key = output_key
114 | self.normalize = normalize
115 | self.normalization_factor = 9.48933
116 |
117 |     def forward(self, data: Dict[str, torch.Tensor], training=None, output_index=None) -> Dict[str, torch.Tensor]:
118 | atomic_numbers = data[self.atomic_numbers_key]
119 | charge = torch.tensor([self.charge_dict[atomic_number.item()] for atomic_number in atomic_numbers], device=atomic_numbers.device)
120 | if self.normalize:
121 | charge = charge * self.normalization_factor # to be consistent with the internal units and the Ewald sum
122 | data[self.output_key] = charge[:,None]
123 | return data
124 |
--------------------------------------------------------------------------------
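A brief usage sketch for the non-periodic branch of the Polarization module above, using two point charges (illustrative only; the dictionary keys mirror those read in the forward pass):

import torch
from cace.modules.polarization import Polarization

data = {
    "positions": torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]),
    "q": torch.tensor([1.0, -1.0]),    # +q and -q separated along z
    "cell": torch.zeros(3, 3),         # zero cell, treated as non-periodic
    "batch": None,                     # a single structure, no batching
}

pol = Polarization(charge_key="q", remove_mean=False, pbc=False)
data = pol(data)
print(data["polarization"])            # proportional to sum_i q_i r_i (a dipole along -z)
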
/cace/modules/preprocess.py:
--------------------------------------------------------------------------------
1 | # Description: Preprocess the data for the model
2 |
3 | from typing import Dict
4 | import torch
5 | from torch import nn
6 |
7 | from .utils import get_symmetric_displacement
8 |
9 | __all__ = ["Preprocess"]
10 |
11 | class Preprocess(nn.Module):
12 | def __init__(self):
13 | super().__init__()
14 |
15 | def forward(self, data: Dict[str, torch.Tensor], compute_stress: bool = False, compute_virials: bool = False):
16 |
17 |         try:
18 |             num_graphs = data["ptr"].numel() - 1
19 |         except (KeyError, AttributeError):
20 |             num_graphs = 1
21 |
22 | if compute_virials or compute_stress:
23 | (
24 | data["positions"],
25 | data["shifts"],
26 | data["displacement"],
27 | data["cell"]
28 | ) = get_symmetric_displacement(
29 | positions=data["positions"],
30 | unit_shifts=data["unit_shifts"],
31 | cell=data["cell"],
32 | edge_index=data["edge_index"],
33 | num_graphs=num_graphs,
34 | batch=data["batch"],
35 | )
36 | else:
37 | data["displacement"] = None
38 |
39 | return data
40 |
41 |
42 |
--------------------------------------------------------------------------------
/cace/modules/product_basis.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from .angular_tools import find_combo_vectors_l1l2
5 |
6 | class AngularTensorProduct(nn.Module):
7 | def __init__(self, max_l: int, l_list: list):
8 | super().__init__()
9 | self.max_l = max_l
10 |
11 | # Convert elements of l_list to tuples for dictionary keys
12 | l_list_tuples = [tuple(l) for l in l_list]
13 | # Create a dictionary to map tuple to index
14 | l_list_indices = {l_tuple: i for i, l_tuple in enumerate(l_list_tuples)}
15 | vec_dict = find_combo_vectors_l1l2(self.max_l)
16 |
17 | self._get_indices(vec_dict, l_list_indices)
18 |
19 | def _get_indices(self, vec_dict, l_list_indices):
20 | self.indice_list = []
21 | for i, (l3, l1l2_list) in enumerate(vec_dict.items()):
22 | l3_index = l_list_indices[tuple(l3)]
23 | for item in l1l2_list:
24 | prefactor = int(item[-1])
25 | l1l2indices = [l_list_indices[tuple(lxlylz)] for lxlylz in item[:-1]]
26 | self.indice_list.append([l3_index, l1l2indices[0], l1l2indices[1], prefactor])
27 | self.indice_tensor = torch.tensor(self.indice_list)
28 |
29 | def forward(self, edge_attr1: torch.Tensor, edge_attr2: torch.Tensor):
30 |
31 |         num_edges, n_radial, n_angular, n_channel = edge_attr1.size()
32 |         edge_attr_new = torch.zeros((num_edges, n_radial, n_angular, n_channel),
33 |                                     dtype=edge_attr1.dtype, device=edge_attr1.device)
34 |
35 | for item in self.indice_tensor:
36 | l3_index, l1_index, l2_index, prefactor = item[0], item[1], item[2], item[3]
37 | edge_attr_new[:, :, l3_index, :] += prefactor * edge_attr1[:, :, l1_index, :] * edge_attr2[:, :, l2_index, :]
38 |
39 | return edge_attr_new
40 |
--------------------------------------------------------------------------------
/cace/modules/radial.py:
--------------------------------------------------------------------------------
1 | ###########################################################################################
2 | # Radial basis
3 | # modified from mace/mace/modules/radials.py and schnetpack/src/schnetpack/nn/radials.py
4 | # This program is distributed under the MIT License (see MIT.md)
5 | ###########################################################################################
6 |
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 |
11 | __all__ = ["BesselRBF", "GaussianRBF", "GaussianRBFCentered", "ExponentialDecayRBF"]
12 |
13 | class BesselRBF(nn.Module):
14 | """
15 |     Sine-based radial basis functions with 1/r decay (zeroth-order spherical Bessel functions).
16 |
17 | References:
18 |
19 | .. [#dimenet] Klicpera, Groß, Günnemann:
20 | Directional message passing for molecular graphs.
21 | ICLR 2020
22 | Equation (7)
23 | """
24 |
25 | def __init__(self, cutoff: float, n_rbf=8, trainable=False):
26 | super().__init__()
27 |
28 | self.n_rbf = n_rbf
29 |
30 | bessel_weights = (
31 | np.pi
32 | / cutoff
33 | * torch.linspace(
34 | start=1.0,
35 | end=n_rbf,
36 | steps=n_rbf,
37 | dtype=torch.get_default_dtype(),
38 | )
39 | )
40 | if trainable:
41 | self.bessel_weights = nn.Parameter(bessel_weights)
42 | else:
43 | self.register_buffer("bessel_weights", bessel_weights)
44 |
45 | self.register_buffer(
46 | "cutoff", torch.tensor(cutoff, dtype=torch.get_default_dtype())
47 | )
48 | self.register_buffer(
49 | "prefactor",
50 | torch.tensor(np.sqrt(2.0 / cutoff), dtype=torch.get_default_dtype()),
51 | )
52 |
53 | def forward(self, x: torch.Tensor) -> torch.Tensor: # [...,1]
54 | numerator = torch.sin(self.bessel_weights * x) # [..., n_rbf]
55 | return self.prefactor * (numerator / x) # [..., n_rbf]
56 |
57 | def __repr__(self):
58 | return (
59 | f"{self.__class__.__name__}(cutoff={self.cutoff}, n_rbf={len(self.bessel_weights)}, "
60 | f"trainable={self.bessel_weights.requires_grad})"
61 | )
62 |
63 |
64 | class ExponentialDecayRBF(nn.Module):
65 | """Exponential decay radial basis functions.
66 | y = prefactor * exp(-x / r0)
67 | """
68 | def __init__(
69 |         self, n_rbf: int, cutoff: float, prefactor: torch.Tensor = torch.tensor(1.0), trainable: bool = False
70 | ):
71 | super().__init__()
72 | self.n_rbf = n_rbf
73 |
74 | # Convert prefactor to a tensor if it's not already one
75 | if not isinstance(prefactor, torch.Tensor):
76 | prefactor = torch.tensor(prefactor, dtype=torch.get_default_dtype())
77 |
78 | if n_rbf == 1:
79 | r0 = torch.tensor(cutoff / 2.0, dtype=torch.get_default_dtype())
80 | else:
81 |             # place the decay lengths r0 evenly within the cutoff
82 |             r0 = torch.linspace(0, cutoff, n_rbf + 2, dtype=torch.get_default_dtype())[1:-1]
83 |
84 | self.register_buffer("cutoff", torch.tensor(cutoff, dtype=torch.get_default_dtype()))
85 |
86 | if trainable:
87 | self.r0 = nn.Parameter(r0)
88 | self.prefactor = nn.Parameter(prefactor)
89 | else:
90 | self.register_buffer("r0", r0)
91 |             self.register_buffer("prefactor", prefactor)  # already converted to a tensor above
92 |
93 | def forward(self, inputs: torch.Tensor):
94 | return self.prefactor * torch.exp(-inputs / self.r0)
95 |
96 | def __repr__(self):
97 | return (
98 | f"{self.__class__.__name__}(prefactor={self.prefactor}, r0={self.r0},"
99 | f"trainable={self.r0.requires_grad})"
100 | )
101 |
102 | def gaussian_rbf(inputs: torch.Tensor, offsets: torch.Tensor, widths: torch.Tensor):
103 | coeff = -0.5 / torch.pow(widths, 2)
104 | diff = inputs - offsets
105 | y = torch.exp(coeff * torch.pow(diff, 2))
106 | return y
107 |
108 |
109 | class GaussianRBF(nn.Module):
110 | r"""Gaussian radial basis functions."""
111 |
112 | def __init__(
113 | self, n_rbf: int, cutoff: float, start: float = 0.8, trainable: bool = False
114 | ):
115 | """
116 | Args:
117 | n_rbf: total number of Gaussian functions, :math:`N_g`.
118 | cutoff: center of last Gaussian function, :math:`\mu_{N_g}`
119 | start: center of first Gaussian function, :math:`\mu_0`.
120 | trainable: If True, widths and offset of Gaussian functions
121 | are adjusted during training process.
122 | """
123 | super().__init__()
124 | self.n_rbf = n_rbf
125 |
126 | # compute offset and width of Gaussian functions
127 | offset = torch.linspace(start, cutoff, n_rbf, dtype=torch.get_default_dtype())
128 |         widths = (
129 |             torch.abs(offset[1] - offset[0]) * torch.ones_like(offset)
130 |         )
131 |
132 | self.register_buffer(
133 | "cutoff", torch.tensor(cutoff, dtype=torch.get_default_dtype())
134 | )
135 |
136 | if trainable:
137 | self.widths = nn.Parameter(widths)
138 | self.offsets = nn.Parameter(offset)
139 | else:
140 | self.register_buffer("widths", widths)
141 | self.register_buffer("offsets", offset)
142 |
143 | def forward(self, inputs: torch.Tensor):
144 | return gaussian_rbf(inputs, self.offsets, self.widths)
145 |
146 | def __repr__(self):
147 | return (
148 | f"{self.__class__.__name__}(cutoff={self.cutoff}, n_rbf={self.n_rbf}, "
149 | f"trainable={self.widths.requires_grad})"
150 | )
151 |
152 | class GaussianRBFCentered(nn.Module):
153 | r"""Gaussian radial basis functions centered at the origin."""
154 |
155 | def __init__(
156 | self, n_rbf: int, cutoff: float, start: float = 1.0, trainable: bool = False
157 | ):
158 | """
159 | Args:
160 | n_rbf: total number of Gaussian functions, :math:`N_g`.
161 | cutoff: width of last Gaussian function, :math:`\mu_{N_g}`
162 | start: width of first Gaussian function, :math:`\mu_0`.
163 | trainable: If True, widths of Gaussian functions
164 | are adjusted during training process.
165 | """
166 | super().__init__()
167 | self.n_rbf = n_rbf
168 |
169 | # compute offset and width of Gaussian functions
170 | widths = torch.linspace(start, cutoff, n_rbf, dtype=torch.get_default_dtype())
171 | offset = torch.zeros_like(widths)
172 |
173 | self.register_buffer(
174 | "cutoff", torch.tensor(cutoff, dtype=torch.get_default_dtype())
175 | )
176 |
177 | if trainable:
178 | self.widths = nn.Parameter(widths)
179 | self.offsets = nn.Parameter(offset)
180 | else:
181 | self.register_buffer("widths", widths)
182 | self.register_buffer("offsets", offset)
183 |
184 | def forward(self, inputs: torch.Tensor):
185 | return gaussian_rbf(inputs, self.offsets, self.widths)
186 |
187 | def __repr__(self):
188 | return (
189 | f"{self.__class__.__name__}(cutoff={self.cutoff}, n_rbf={self.n_rbf}, "
190 | f"trainable={self.widths.requires_grad})"
191 | )
192 |
--------------------------------------------------------------------------------
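A brief usage sketch for the radial bases above, expanding a batch of edge lengths into radial features (illustrative only):

import torch
from cace.modules.radial import BesselRBF, GaussianRBF

lengths = torch.rand(10, 1) * 4.0 + 0.5    # edge lengths, shape [n_edges, 1]

bessel = BesselRBF(cutoff=5.0, n_rbf=8)
gauss = GaussianRBF(n_rbf=8, cutoff=5.0)

# each basis expands a scalar distance into n_rbf radial features
print(bessel(lengths).shape)               # torch.Size([10, 8])
print(gauss(lengths).shape)                # torch.Size([10, 8])
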
/cace/modules/radial_transform.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from typing import Optional, List
5 |
6 | class SharedRadialLinearTransform(nn.Module):
7 | # TODO: this can be jitted, however, this causes trouble in saving the model
8 | def __init__(self, max_l: int, radial_dim: int, radial_embedding_dim: Optional[int] = None, channel_dim: Optional[int] = None):
9 | super().__init__()
10 | self.max_l = max_l
11 | self.radial_dim = radial_dim
12 | self.radial_embedding_dim = radial_embedding_dim or radial_dim
13 | self.channel_dim = channel_dim
14 | self.register_buffer('angular_dim_groups', torch.tensor(self._init_angular_dim_groups(max_l), dtype=torch.int64))
15 | self.weights = self._initialize_weights(radial_dim, self.radial_embedding_dim, channel_dim)
16 |
17 | def __getstate__(self):
18 | # Return a dictionary of state items to be serialized.
19 | state = self.__dict__.copy()
20 | # Modify the state dictionary as needed, or return as is.
21 | return state
22 |
23 | def __setstate__(self, state):
24 | # Restore the state.
25 | self.__dict__.update(state)
26 |
27 | def _initialize_weights(self, radial_dim: int, radial_embedding_dim: int, channel_dim: int) -> nn.ParameterList:
28 | torch.manual_seed(0)
29 | # TODO: try other initialization
30 | if channel_dim is not None:
31 | return nn.ParameterList([
32 | nn.Parameter(torch.rand([radial_dim, radial_embedding_dim, channel_dim])) for _ in self.angular_dim_groups
33 | ])
34 | else:
35 | return nn.ParameterList([
36 | nn.Parameter(torch.rand([radial_dim, radial_embedding_dim])) for _ in self.angular_dim_groups
37 | ])
38 |
39 | def forward(self, x: torch.Tensor) -> torch.Tensor:
40 |
41 | n_nodes, radial_dim, angular_dim, embedding_dim = x.shape
42 |
43 | output = torch.zeros(n_nodes, self.radial_embedding_dim, angular_dim, embedding_dim,
44 | device=x.device, dtype=x.dtype)
45 |
46 | for index, weight in enumerate(self.weights):
47 | i_start = self.angular_dim_groups[index, 0]
48 | i_end = self.angular_dim_groups[index, 1]
49 | group = torch.arange(i_start, i_end)
50 | # Gather all angular dimensions for the current group
51 | group_x = x[:, :, group, :] # Shape: [n_nodes, radial_dim, len(group), embedding_dim]
52 | # Apply the transformation for the entire group at once
53 | if self.channel_dim:
54 | transformed_group = torch.einsum('ijkh,jmh->imkh', group_x, weight)
55 | else:
56 | transformed_group = torch.einsum('ijkh,jm->imkh', group_x, weight)
57 | # Assign to the output tensor for each angular dimension
58 | output[:, :, group, :] = transformed_group
59 | return output
60 |
61 | def _compute_length_lxlylz(self, l):
62 | return int((l+1)*(l+2)/2)
63 |
64 | def _init_angular_dim_groups(self, max_l):
65 | angular_dim_groups: List[int] = []
66 | l_now = 0
67 | for l in range(max_l+1):
68 | l_list_atl = [l_now, l_now + self._compute_length_lxlylz(l)]
69 | angular_dim_groups.append(l_list_atl)
70 | l_now += self._compute_length_lxlylz(l)
71 | return angular_dim_groups
72 |
--------------------------------------------------------------------------------
/cace/modules/transform.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Dict
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | __all__ = [
7 | "Transform",
8 | "TransformException",
9 | ]
10 |
11 |
12 | class TransformException(Exception):
13 | pass
14 |
15 |
16 | class Transform(nn.Module):
17 | """
18 | Base class for all transforms.
19 | The base class ensures that the reference to the data and datamodule attributes are
20 | initialized.
21 | Transforms can be used as pre- or post-processing layers.
22 | They can also be used for other parts of a model, that need to be
23 | initialized based on data.
24 |
25 | To implement a new transform, override the forward method. Preprocessors are applied
26 | to single examples, while postprocessors operate on batches. All transforms should
27 | return a modified `inputs` dictionary.
28 |
29 | """
30 |
31 | def datamodule(self, value):
32 | """
33 | Extract all required information from data module automatically when using
34 | PyTorch Lightning integration. The transform should also implement a way to
35 | set these things manually, to make it usable independent of PL.
36 |
37 | Do not store the datamodule, as this does not work with torchscript conversion!
38 | """
39 | pass
40 |
41 | def forward(
42 | self,
43 | inputs: Dict[str, torch.Tensor],
44 | ) -> Dict[str, torch.Tensor]:
45 | raise NotImplementedError
46 |
47 | def teardown(self):
48 | pass
49 |
--------------------------------------------------------------------------------
/cace/modules/utils.py:
--------------------------------------------------------------------------------
1 | ###########################################################################################
2 | # Utilities
3 | # modified from MACE
4 | # This program is distributed under the MIT License (see MIT.md)
5 | ###########################################################################################
6 |
7 | from typing import List, Optional, Tuple
8 |
9 | import torch
10 |
11 | __all__ = ["get_outputs", "get_edge_vectors_and_lengths", "get_edge_node_type", "get_symmetric_displacement"]
12 |
13 | def compute_forces(
14 | energy: torch.Tensor, positions: torch.Tensor, training: bool = False
15 | ) -> torch.Tensor:
16 | # check the dimension of the energy tensor
17 | if len(energy.shape) == 1:
18 | grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(energy)]
19 | gradient = torch.autograd.grad(
20 | outputs=[energy], # [n_graphs, ]
21 | inputs=[positions], # [n_nodes, 3]
22 | grad_outputs=grad_outputs,
23 | retain_graph=training, # Make sure the graph is not destroyed during training
24 | create_graph=training, # Create graph for second derivative
25 | allow_unused=True, # For complete dissociation turn to true
26 | )[0] # [n_nodes, 3]
27 | else:
28 | num_energy = energy.shape[1]
29 | grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(energy[:,0])]
30 | gradient = torch.stack([
31 | torch.autograd.grad(
32 | outputs=[energy[:,i]], # [n_graphs, ]
33 | inputs=[positions], # [n_nodes, 3]
34 | grad_outputs=grad_outputs,
35 | retain_graph=(training or (i < num_energy - 1)), # Make sure the graph is not destroyed during training
36 | create_graph=(training or (i < num_energy - 1)), # Create graph for second derivative
37 | allow_unused=True, # For complete dissociation turn to true
38 | )[0] for i in range(num_energy)
39 | ], axis=2) # [n_nodes, 3, num_energy]
40 |
41 | if gradient is None:
42 | return torch.zeros_like(positions)
43 | return -1 * gradient
44 |
45 |
46 | def compute_forces_virials(
47 | energy: torch.Tensor,
48 | positions: torch.Tensor,
49 | displacement: torch.Tensor,
50 | cell: torch.Tensor,
51 | training: bool = False,
52 | compute_stress: bool = False,
53 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
54 | # check the dimension of the energy tensor
55 | if len(energy.shape) == 1:
56 | grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(energy)]
57 | gradient, virials = torch.autograd.grad(
58 | outputs=[energy], # [n_graphs, ]
59 | inputs=[positions, displacement], # [n_nodes, 3]
60 | grad_outputs=grad_outputs,
61 | retain_graph=training, # Make sure the graph is not destroyed during training
62 | create_graph=training, # Create graph for second derivative
63 | allow_unused=True,
64 | )
65 | stress = torch.zeros_like(displacement)
66 | if compute_stress and virials is not None:
67 | cell = cell.view(-1, 3, 3)
68 | volume = torch.einsum(
69 | "zi,zi->z",
70 | cell[:, 0, :],
71 | torch.cross(cell[:, 1, :], cell[:, 2, :], dim=1),
72 | ).unsqueeze(-1)
73 | stress = virials / volume.view(-1, 1, 1)
74 | else:
75 | num_energy = energy.shape[1]
76 | grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(energy[:,0])]
77 | gradient_list, virials_list, stress_list = [], [], []
78 | for i in range(num_energy):
79 | gradient, virials = torch.autograd.grad(
80 | outputs=[energy[:,i]], # [n_graphs, ]
81 | inputs=[positions, displacement], # [n_nodes, 3]
82 | grad_outputs=grad_outputs,
83 | retain_graph=(training or (i < num_energy - 1)), # Make sure the graph is not destroyed during training
84 | create_graph=(training or (i < num_energy - 1)), # Create graph for second derivative
85 | allow_unused=True,
86 | )
87 | stress = torch.zeros_like(displacement)
88 | if compute_stress and virials is not None:
89 | cell = cell.view(-1, 3, 3)
90 | volume = torch.einsum(
91 | "zi,zi->z",
92 | cell[:, 0, :],
93 | torch.cross(cell[:, 1, :], cell[:, 2, :], dim=1),
94 | ).unsqueeze(-1)
95 | stress = virials / volume.view(-1, 1, 1)
96 | gradient_list.append(gradient)
97 | virials_list.append(virials)
98 | stress_list.append(stress)
99 | gradient = torch.stack(gradient_list, axis=2)
100 | virials = torch.stack(virials_list, axis=-1)
101 | stress = torch.stack(stress_list, axis=-1)
102 |
103 | if gradient is None:
104 | gradient = torch.zeros_like(positions)
105 | if virials is None:
106 | virials = torch.zeros((1, 3, 3))
107 |
108 | return -1 * gradient, -1 * virials, stress
109 |
110 | def get_symmetric_displacement(
111 | positions: torch.Tensor,
112 | unit_shifts: torch.Tensor,
113 | cell: Optional[torch.Tensor],
114 | edge_index: torch.Tensor,
115 | num_graphs: int,
116 | batch: torch.Tensor,
117 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
118 | if cell is None:
119 | cell = torch.zeros(
120 | num_graphs * 3,
121 | 3,
122 | dtype=positions.dtype,
123 | device=positions.device,
124 | )
125 | sender = edge_index[0]
126 | displacement = torch.zeros(
127 | (num_graphs, 3, 3),
128 | dtype=positions.dtype,
129 | device=positions.device,
130 | )
131 | displacement.requires_grad_(True)
132 | symmetric_displacement = 0.5 * (
133 | displacement + displacement.transpose(-1, -2)
134 | ) # From https://github.com/mir-group/nequip
135 | positions = positions + torch.einsum(
136 | "be,bec->bc", positions, symmetric_displacement[batch]
137 | )
138 | cell = cell.view(-1, 3, 3)
139 | cell = cell + torch.matmul(cell, symmetric_displacement)
140 | shifts = torch.einsum(
141 | "be,bec->bc",
142 | unit_shifts,
143 | cell[batch[sender]],
144 | )
145 | return positions, shifts, displacement, cell
146 |
147 | def get_outputs(
148 | energy: torch.Tensor,
149 | positions: torch.Tensor,
150 | displacement: Optional[torch.Tensor] = None,
151 | cell: Optional[torch.Tensor] = None,
152 | training: bool = False,
153 | compute_force: bool = True,
154 | compute_virials: bool = True,
155 | compute_stress: bool = True,
156 | ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
157 | if (compute_virials or compute_stress) and displacement is not None:
158 | # forces come for free
159 | forces, virials, stress = compute_forces_virials(
160 | energy=energy,
161 | positions=positions,
162 | displacement=displacement,
163 | cell=cell,
164 | compute_stress=compute_stress,
165 | training=training,
166 | )
167 | elif compute_force:
168 | forces, virials, stress = (
169 | compute_forces(energy=energy, positions=positions, training=training),
170 | None,
171 | None,
172 | )
173 | else:
174 | forces, virials, stress = (None, None, None)
175 | return forces, virials, stress
176 |
177 | def get_edge_vectors_and_lengths(
178 | positions: torch.Tensor, # [n_nodes, 3]
179 | edge_index: torch.Tensor, # [2, n_edges]
180 | shifts: torch.Tensor, # [n_edges, 3]
181 | normalize: bool = False,
182 | eps: float = 1e-9,
183 | ) -> Tuple[torch.Tensor, torch.Tensor]:
184 | sender = edge_index[0]
185 | receiver = edge_index[1]
186 | vectors = positions[receiver] - positions[sender] + shifts # [n_edges, 3]
187 | lengths = torch.linalg.norm(vectors, dim=-1, keepdim=True) # [n_edges, 1]
188 | if normalize:
189 | vectors_normed = vectors / (lengths + eps)
190 | return vectors_normed, lengths
191 |
192 | return vectors, lengths
193 |
194 | def get_edge_node_type(
195 | edge_index: torch.Tensor, # [2, n_edges]
196 | node_type: torch.Tensor, # [n_nodes, n_dims]
197 | node_type_2: torch.Tensor=None, # [n_nodes, n_dims]
198 | ) -> Tuple[torch.Tensor, torch.Tensor]:
199 | if node_type_2 is None:
200 | node_type_2 = node_type
201 |
202 |     # the sender and receiver types are gathered directly from the node types;
203 |     # no intermediate edge-type tensor is needed
204 | sender_type = node_type[edge_index[0]]
205 | receiver_type = node_type_2[edge_index[1]]
206 | return sender_type, receiver_type # [n_edges, n_dims]
207 |
208 |
--------------------------------------------------------------------------------
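A brief usage sketch for get_edge_vectors_and_lengths above on a three-atom toy graph (illustrative only):

import torch
from cace.modules.utils import get_edge_vectors_and_lengths

positions = torch.tensor([[0.0, 0.0, 0.0],
                          [1.0, 0.0, 0.0],
                          [0.0, 2.0, 0.0]])
edge_index = torch.tensor([[0, 0, 1],      # senders
                           [1, 2, 2]])     # receivers
shifts = torch.zeros(3, 3)                 # no periodic images in this toy example

vectors, lengths = get_edge_vectors_and_lengths(positions, edge_index, shifts)
print(vectors)                             # displacement vector per edge
print(lengths.squeeze(-1))                 # [1.0, 2.0, sqrt(5)]
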
/cace/representations/__init__.py:
--------------------------------------------------------------------------------
1 | from .cace_representation import *
2 |
--------------------------------------------------------------------------------
/cace/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | from .load_data import *
2 |
3 | from .loss import *
4 |
5 | from .train import *
6 |
7 | from .evaluate import *
8 |
9 | from .lightning import *
10 |
11 |
12 |
--------------------------------------------------------------------------------
/cace/tasks/load_data.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import logging
3 | from typing import Dict, List, Optional, Tuple, Sequence
4 | import numpy as np
5 | from ase import Atoms
6 | from ase.io import read
7 | from ..tools import torch_geometric
8 | from ..data import AtomicData
9 |
10 | __all__ = ["load_data_loader", "get_dataset_from_xyz", "random_train_valid_split"]
11 |
12 | @dataclasses.dataclass
13 | class SubsetAtoms:
14 |     train: List[Atoms]
15 |     valid: List[Atoms]
16 |     test: List[Atoms]
17 | cutoff: float
18 | data_key: Dict
19 | atomic_energies: Dict
20 |
21 | def load_data_loader(
22 | collection: SubsetAtoms,
23 | data_type: str, # ['train', 'valid', 'test']
24 | batch_size: int,
25 | ):
26 |
27 | allowed_types = ['train', 'valid', 'test']
28 | if data_type not in allowed_types:
29 | raise ValueError(f"Input value must be one of {allowed_types}, got {data_type}")
30 |
31 | cutoff = collection.cutoff
32 | data_key = collection.data_key
33 | atomic_energies = collection.atomic_energies
34 |
35 | if data_type == 'train':
36 | loader = torch_geometric.DataLoader(
37 | dataset=[
38 | AtomicData.from_atoms(atoms, cutoff=cutoff, data_key=data_key, atomic_energies=atomic_energies)
39 | for atoms in collection.train
40 | ],
41 | batch_size=batch_size,
42 | shuffle=True,
43 | drop_last=True,
44 | )
45 | elif data_type == 'valid':
46 | loader = torch_geometric.DataLoader(
47 | dataset=[
48 | AtomicData.from_atoms(atoms, cutoff=cutoff, data_key=data_key, atomic_energies=atomic_energies)
49 | for atoms in collection.valid
50 | ],
51 | batch_size=batch_size,
52 | shuffle=False,
53 | drop_last=False,
54 | )
55 | elif data_type == 'test':
56 | loader = torch_geometric.DataLoader(
57 | dataset=[
58 | AtomicData.from_atoms(atoms, cutoff=cutoff, data_key=data_key, atomic_energies=atomic_energies)
59 | for atoms in collection.test
60 | ],
61 | batch_size=batch_size,
62 | shuffle=False,
63 | drop_last=False,
64 | )
65 | return loader
66 |
67 | def get_dataset_from_xyz(
68 | train_path: str,
69 | cutoff: float,
70 | valid_path: str = None,
71 | valid_fraction: float = 0.1,
72 | test_path: str = None,
73 | seed: int = 1234,
74 | data_key: Dict[str, str] = None,
75 | atomic_energies: Dict[int, float] = None
76 | ) -> SubsetAtoms:
77 |     """Load training, validation, and test datasets from xyz files"""
78 | all_train_configs = read(train_path, ":")
79 | if not isinstance(all_train_configs, list):
80 | all_train_configs = [all_train_configs]
81 | logging.info(
82 | f"Loaded {len(all_train_configs)} training configurations from '{train_path}'"
83 | )
84 | if valid_path is not None:
85 | valid_configs = read(valid_path, ":")
86 | if not isinstance(valid_configs, list):
87 | valid_configs = [valid_configs]
88 | logging.info(
89 | f"Loaded {len(valid_configs)} validation configurations from '{valid_path}'"
90 | )
91 | train_configs = all_train_configs
92 | else:
93 | logging.info(
94 | "Using random %s%% of training set for validation", 100 * valid_fraction
95 | )
96 | train_configs, valid_configs = random_train_valid_split(
97 | all_train_configs, valid_fraction, seed
98 | )
99 |
100 | test_configs = []
101 | if test_path is not None:
102 | test_configs = read(test_path, ":")
103 | if not isinstance(test_configs, list):
104 | test_configs = [test_configs]
105 | logging.info(
106 | f"Loaded {len(test_configs)} test configurations from '{test_path}'"
107 | )
108 | return (
109 | SubsetAtoms(train=train_configs, valid=valid_configs, test=test_configs, cutoff=cutoff, data_key=data_key, atomic_energies=atomic_energies)
110 | )
111 |
112 | def random_train_valid_split(
113 | items: Sequence, valid_fraction: float, seed: int
114 | ) -> Tuple[List, List]:
115 | assert 0.0 < valid_fraction < 1.0
116 |
117 | size = len(items)
118 | train_size = size - int(valid_fraction * size)
119 |
120 | indices = list(range(size))
121 | rng = np.random.default_rng(seed)
122 | rng.shuffle(indices)
123 |
124 | return (
125 | [items[i] for i in indices[:train_size]],
126 | [items[i] for i in indices[train_size:]],
127 | )
128 |
--------------------------------------------------------------------------------
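A brief usage sketch for random_train_valid_split above; the resulting lists are what get_dataset_from_xyz packs into SubsetAtoms before the data loaders are built (illustrative only; the integer list stands in for a list of ase.Atoms configurations):

from cace.tasks.load_data import random_train_valid_split

items = list(range(10))
train, valid = random_train_valid_split(items, valid_fraction=0.2, seed=1234)
print(len(train), len(valid))             # 8 2
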
/cace/tasks/loss.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Dict, Union, Callable
2 | import torch
3 | import torch.nn as nn
4 |
5 | __all__ = ["GetLoss", "GetRegularizationLoss", "GetVarianceLoss"]
6 |
7 | class GetLoss(nn.Module):
8 | """
9 | Defines mappings to a loss function and weight for training
10 | """
11 |
12 | def __init__(
13 | self,
14 | target_name: str,
15 | predict_name: Optional[str] = None,
16 | output_index: Optional[int] = None, # only used for multi-output models
17 | name: Optional[str] = None,
18 | loss_fn: Optional[nn.Module] = None,
19 | loss_weight: Union[float, Callable] = 1.0, # Union[float, Callable] means that the type can be either float or callable
20 | ):
21 | """
22 | Args:
23 | target_name: Name of target in training batch.
24 | name: name of the loss object
25 | loss_fn: function to compute the loss
26 | loss_weight: loss weight in the composite loss: $l = w_1 l_1 + \dots + w_n l_n$
27 |                 This can be a float or a callable that takes the loss_weight_args.
28 |                 For example, to make the loss weight depend on the epoch number during
29 |                 training and default to 1.0 otherwise, loss_weight can be, e.g.,
30 |                 lambda training, epoch: 1.0 if not training else epoch / 100
31 | """
32 | super().__init__()
33 | self.target_name = target_name
34 | self.predict_name = predict_name or target_name
35 | self.output_index = output_index
36 | self.name = name or target_name
37 | self.loss_fn = loss_fn
38 | # the loss_weight can either be a float or a callable
39 | self.loss_weight = loss_weight
40 |
41 | def forward(self,
42 | pred: Dict[str, torch.Tensor],
43 | target: Optional[Dict[str, torch.Tensor]] = None,
44 | loss_args: Optional[Dict[str, torch.Tensor]] = None
45 | ):
46 | if self.loss_weight == 0 or self.loss_fn is None:
47 | return 0.0
48 |
49 | if isinstance(self.loss_weight, Callable):
50 | if loss_args is None:
51 | loss_weight = self.loss_weight()
52 | else:
53 | loss_weight = self.loss_weight(**loss_args)
54 | else:
55 | loss_weight = self.loss_weight
56 |
57 | if self.output_index is None:
58 | pred_tensor = pred[self.predict_name]
59 | else:
60 | pred_tensor = pred[self.predict_name][..., self.output_index]
61 | 
62 |         if target is not None:
63 |             target_tensor = target[self.target_name]
64 |         elif self.predict_name != self.target_name:
65 |             target_tensor = pred[self.target_name]
66 |         else:
67 |             raise ValueError("Target is None and predict_name is not equal to target_name")
68 | 
69 |         if pred_tensor.shape != target_tensor.shape:
70 |             pred_tensor = pred_tensor.reshape(target_tensor.shape)
71 | 
72 |         loss = loss_weight * self.loss_fn(pred_tensor, target_tensor)
73 | return loss
74 |
75 | def __repr__(self):
76 | return (
77 | f"{self.__class__.__name__}(name={self.name}, loss_fn={self.loss_fn}, loss_weight={self.loss_weight})"
78 | )
79 |
80 | class GetRegularizationLoss(nn.Module):
81 | def __init__(self, loss_weight, model):
82 | super().__init__()
83 | self.loss_weight = loss_weight
84 | self.model = model
85 |
86 | def forward(self,
87 | *args,
88 | ):
89 | regularization_loss = 0.0
90 | for param in self.model.parameters():
91 | if param.requires_grad:
92 | regularization_loss += torch.norm(param, p=2)
93 | regularization_loss *= self.loss_weight
94 | return regularization_loss
95 |
96 | class GetVarianceLoss(nn.Module):
97 | def __init__(
98 | self,
99 | target_name: str,
100 | loss_weight: float = 1.0,
101 | name: Optional[str] = None,
102 | ):
103 | super().__init__()
104 | self.target_name = target_name
105 | self.loss_weight = loss_weight
106 | self.name = name
107 |
108 | def forward(self, pred: Dict[str, torch.Tensor], *args):
109 |
110 | # Compute the variance along the first dimension (across different entries)
111 | variances = torch.var(pred[self.target_name], dim=0)
112 |
113 | # Calculate the mean of the variances
114 | mean_variance = torch.mean(variances)
115 | mean_variance = mean_variance * self.loss_weight
116 |
117 | return mean_variance
118 |
--------------------------------------------------------------------------------
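A brief usage sketch for GetLoss above, wrapping a mean-squared-error loss on an energy-like target (illustrative only; the dictionary keys are made up for the example):

import torch
import torch.nn as nn
from cace.tasks.loss import GetLoss

energy_loss = GetLoss(
    target_name="energy",
    predict_name="pred_energy",
    loss_fn=nn.MSELoss(),
    loss_weight=1.0,
)

pred = {"pred_energy": torch.tensor([1.0, 2.0, 3.0])}
target = {"energy": torch.tensor([1.5, 2.0, 2.5])}
print(energy_loss(pred, target))           # weighted MSE between prediction and target
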
/cace/tools/__init__.py:
--------------------------------------------------------------------------------
1 | from .torch_tools import (
2 | elementwise_multiply_2tensors,
3 | elementwise_multiply_3tensors,
4 | to_numpy,
5 | voigt_to_matrix,
6 | init_device,
7 | tensor_dict_to_device,
8 | )
9 |
10 | #from .slurm_distributed import *
11 |
12 | from .scatter import scatter_sum
13 |
14 | from .metric import *
15 |
16 | from .utils import (
17 | compute_avg_num_neighbors,
18 | setup_logger,
19 | get_unique_atomic_number,
20 | compute_average_E0s
21 | )
22 |
23 | from .output import batch_to_atoms
24 |
25 | from .parser_train import *
26 |
27 | from .io_utils import *
28 |
--------------------------------------------------------------------------------
/cace/tools/io_utils.py:
--------------------------------------------------------------------------------
1 | # Description: Utility functions for saving and loading datasets
2 |
3 | #import h5py
4 | import numpy as np
5 | import torch
6 |
7 | __all__ = ['tensor_to_numpy', 'numpy_to_tensor', 'save_dataset', 'load_dataset']
8 |
9 | # Function to convert tensors to numpy arrays
10 | def tensor_to_numpy(data):
11 | if isinstance(data, torch.Tensor):
12 | return data.numpy()
13 | elif isinstance(data, dict):
14 | return {k: tensor_to_numpy(v) for k, v in data.items()}
15 | elif isinstance(data, list):
16 | return [tensor_to_numpy(item) for item in data]
17 | else:
18 | return data
19 |
20 | # Function to convert numpy arrays back to tensors
21 | def numpy_to_tensor(data):
22 | if isinstance(data, np.ndarray):
23 | return torch.tensor(data)
24 | elif isinstance(data, dict):
25 | return {k: numpy_to_tensor(v) for k, v in data.items()}
26 | elif isinstance(data, list):
27 | return [numpy_to_tensor(item) for item in data]
28 | else:
29 | return data
30 |
31 | # Function to save the dataset to an HDF5 file
32 | def save_dataset(data, filename, shuffle=False):
33 | import h5py
34 | if shuffle:
35 | index = np.random.permutation(len(data)) # Shuffle the data
36 | else:
37 | index = np.arange(len(data))
38 | with h5py.File(filename, 'w') as f:
39 | for i, index_now in enumerate(index):
40 | item = data[index_now]
41 | grp = f.create_group(str(i))
42 | serializable_item = tensor_to_numpy(item)
43 | for k, v in serializable_item.items():
44 | grp.create_dataset(k, data=v)
45 | print(f"Saved dataset with {len(data)} records to {filename}")
46 |
47 | def load_dataset(filename):
48 | import h5py
49 | all_data = []
50 | with h5py.File(filename, 'r') as f:
51 | for key in f.keys():
52 | grp = f[key]
53 | item = {}
54 | for k in grp.keys():
55 | data = grp[k]
56 | if data.shape == (): # Check if the data is a scalar
57 | item[k] = torch.tensor(data[()]) # Access scalar value
58 | else:
59 | item[k] = torch.tensor(data[:]) # Access array value
60 | all_data.append(numpy_to_tensor(item))
61 | print(f"Loaded dataset with {len(all_data)} records from {filename}")
62 | return all_data
63 |
--------------------------------------------------------------------------------
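A minimal usage sketch for the helpers above: round-tripping a processed dataset through HDF5. `dataset` is a stand-in for any list of dictionaries of tensors (e.g. the items produced by the data-loading pipeline); the filename is illustrative.

    from cace.tools.io_utils import save_dataset, load_dataset

    # `dataset`: list of dicts mapping string keys to torch tensors
    save_dataset(dataset, 'water_processed.h5', shuffle=True)
    reloaded = load_dataset('water_processed.h5')
    assert len(reloaded) == len(dataset)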
/cace/tools/metric.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch
3 | import torch.nn as nn
4 | from typing import Optional, Dict, List
5 |
6 | from .torch_tools import to_numpy
7 |
8 | __all__ = ['Metrics', 'compute_loss_metrics']
9 |
10 | def compute_loss_metrics(metric: str, y_true: torch.Tensor, y_pred: torch.Tensor):
11 | """
12 | Compute the loss metrics
13 | current options: mse, mae, rmse, r2
14 | """
15 | if metric == 'mse':
16 | return torch.mean((y_true - y_pred) ** 2)
17 | elif metric == 'mae':
18 | return torch.mean(torch.abs(y_true - y_pred))
19 | elif metric == 'rmse':
20 | return torch.sqrt(torch.mean((y_true - y_pred) ** 2))
21 | elif metric == 'r2':
22 | return 1 - torch.sum((y_true - y_pred) ** 2) / torch.sum((y_true - torch.mean(y_true)) ** 2)
23 | else:
24 | raise ValueError('Metric not implemented')
25 |
26 | class Metrics(nn.Module):
27 | """
28 | Defines and calculates the metrics to be logged.
29 | """
30 |
31 | def __init__(
32 | self,
33 | target_name: str,
34 | predict_name: Optional[str] = None,
35 | output_index: Optional[int] = None, # used for multi-task learning
36 | name: Optional[str] = None,
37 | metric_keys: List[str] = ["mae", "rmse"],
38 | per_atom: bool = False,
39 | ):
40 | """
41 | Args:
42 | target_name: name of the target in the dataset
43 | predict_name: name of the prediction in the model output
44 | name: name of the metrics
45 | metric_keys: list of metrics to be calculated
46 | per_atom: whether to calculate the metrics per atom
47 | """
48 |
49 | super().__init__()
50 | self.target_name = target_name
51 | self.predict_name = predict_name or target_name
52 | self.output_index = output_index
53 | self.name = name or target_name
54 |
55 | self.per_atom = per_atom
56 |
57 | self.metric_keys = metric_keys
58 | self.logs = {
59 | "train": {'pred': [], 'target': []},
60 | "val": {'pred': [], 'target': []},
61 | "test": {'pred': [], 'target': []},
62 | }
63 |
64 | def _collect_tensor(self,
65 | pred: Dict[str, torch.Tensor],
66 | target: Optional[Dict[str, torch.Tensor]] = None,
67 | ):
68 | pred_tensor = pred[self.predict_name].clone().detach()
69 | if len(pred_tensor.shape) > 2:
70 | pred_tensor = pred_tensor.reshape(pred_tensor.shape[0], -1)
71 | #print("pred_tensor", pred_tensor.shape)
72 | if self.output_index is not None:
73 | pred_tensor = pred_tensor[..., self.output_index]
74 | if target is not None:
75 | target_tensor = target[self.target_name].clone().detach()
76 | elif self.predict_name != self.target_name:
77 | target_tensor = pred[self.target_name].clone().detach()
78 | else:
79 | raise ValueError("Target is None and predict_name is the same as target_name, so no reference values are available")
80 | #print("target_tensor:", target_tensor.shape)
81 | if self.per_atom:
82 | n_atoms = torch.bincount(target['batch']).clone().detach()
83 | pred_tensor = pred_tensor / n_atoms
84 | target_tensor = target_tensor / n_atoms
85 | return pred_tensor, target_tensor
86 |
87 | def forward(self,
88 | pred: Dict[str, torch.Tensor],
89 | target: Optional[Dict[str, torch.Tensor]] = None,
90 | ):
91 | pred_tensor, target_tensor = self._collect_tensor(pred, target)
92 | metrics_now = {}
93 | for metric in self.metric_keys:
94 | metrics_now[metric] = compute_loss_metrics(metric, target_tensor, pred_tensor)
95 |
96 | return metrics_now
97 |
98 | def update_metrics(self, subset: str,
99 | pred: Dict[str, torch.Tensor],
100 | target: Optional[Dict[str, torch.Tensor]] = None,
101 | ):
102 | pred_tensor, target_tensor = self._collect_tensor(pred, target)
103 | self.logs[subset]['pred'].append(pred_tensor)
104 | self.logs[subset]['target'].append(target_tensor)
105 |
106 | def retrieve_metrics(self, subset: str, clear: bool = True, print_log: bool = True):
107 | pred_tensor = torch.cat(self.logs[subset]['pred'], dim=0)
108 | target_tensor = torch.cat(self.logs[subset]['target'], dim=0)
109 |
110 | assert pred_tensor.shape == target_tensor.shape, f"pred_tensor.shape: {pred_tensor.shape}, target_tensor.shape: {target_tensor.shape}"
111 |
112 | if pred_tensor.shape[0] == 0:
113 | raise ValueError("No data in the logs")
114 |
115 | metrics_now = {}
116 | for metric in self.metric_keys:
117 | metric_mean = compute_loss_metrics(metric, target_tensor, pred_tensor)
118 | metrics_now[metric] = metric_mean
119 | if print_log:
120 | print(
121 | f'{subset}_{self.name}_{metric}: {metric_mean:.6f}',
122 | )
123 | logging.info(
124 | f'{subset}_{self.name}_{metric}: {metric_mean:.6f}',
125 | )
126 | if clear:
127 | self.clear_metrics(subset)
128 |
129 | return metrics_now
130 |
131 | def clear_metrics(self, subset: str):
132 | self.logs[subset]['pred'] = []
133 | self.logs[subset]['target'] = []
134 |
135 | def __repr__(self):
136 | return f'{self.__class__.__name__} name: {self.name}, metric_keys: {self.metric_keys}'
137 |
--------------------------------------------------------------------------------
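A minimal sketch of accumulating validation metrics over an epoch with the `Metrics` class, assuming `model(batch)` returns a dictionary containing 'CACE_energy' as in examples/water_train.py; `model` and `valid_loader` are placeholders.

    from cace.tools import Metrics

    e_metric = Metrics(
        target_name='energy',
        predict_name='CACE_energy',
        name='e/atom',
        metric_keys=['mae', 'rmse'],
        per_atom=True,
    )

    for batch in valid_loader:
        pred = model(batch)
        e_metric.update_metrics('val', pred, batch)

    # prints/logs val_e/atom_mae and val_e/atom_rmse, then clears the buffers
    epoch_metrics = e_metric.retrieve_metrics('val', clear=True)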
/cace/tools/output.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .torch_tools import to_numpy
3 | from typing import Dict, Optional
4 | import ase
5 |
6 | def batch_to_atoms(batched_data: Dict,
7 | pred_data: Optional[Dict] = None,
8 | output_file: str = None,
9 | energy_key: str = 'energy',
10 | forces_key: str = 'forces',
11 | cace_energy_key: str = 'energy',
12 | cace_forces_key: str = 'forces'):
13 | """
14 | Create ASE Atoms objects from batched graph data and write to an XYZ file.
15 |
16 | Parameters:
17 | - batched_data (Dict): Batched data containing graph information.
18 | - pred_data (Dict): Predicted data. If not given, batched_data is also used as pred_data.
19 | - energy_key (str): Key for accessing energy information in batched_data.
20 | - forces_key (str): Key for accessing force information in batched_data.
21 | - cace_energy_key (str): Key for accessing CACE energy information.
22 | - cace_forces_key (str): Key for accessing CACE force information.
23 | - output_file (str): Name of the output file to write the Atoms objects.
24 | """
25 |
26 | if pred_data is None:
27 | pred_data = batched_data
28 | atoms_list = []
29 | batch = batched_data.batch
30 | num_graphs = batch.max().item() + 1
31 |
32 | for i in range(num_graphs):
33 | # Mask to extract nodes for each graph
34 | mask = batch == i
35 |
36 | # Extract node features, edge indices, etc., for each graph
37 | positions = to_numpy(batched_data['positions'][mask])
38 | atomic_numbers = to_numpy(batched_data['atomic_numbers'][mask])
39 | cell = to_numpy(batched_data['cell'][3*i:3*i+3])
40 |
41 | energy = to_numpy(batched_data[energy_key][i])
42 | forces = to_numpy(batched_data[forces_key][mask])
43 | cace_energy = to_numpy(pred_data[cace_energy_key][i])
44 | cace_forces = to_numpy(pred_data[cace_forces_key][mask])
45 |
46 | # Set periodic boundary conditions if the cell is defined
47 | pbc = np.all(np.mean(cell, axis=0) > 0)
48 |
49 | # Create the Atoms object
50 | atoms = ase.Atoms(numbers=atomic_numbers, positions=positions, cell=cell, pbc=pbc)
51 | atoms.info[energy_key] = energy.item() if np.ndim(energy) == 0 else energy
52 | atoms.arrays[forces_key] = forces
53 | atoms.info[cace_energy_key] = cace_energy.item() if np.ndim(cace_energy) == 0 else cace_energy
54 | atoms.arrays[cace_forces_key] = cace_forces
55 | atoms_list.append(atoms)
56 |
57 | # Write all atoms to the output file
58 | if output_file:
59 | ase.io.write(output_file, atoms_list, append=True)
60 | return atoms_list
61 |
--------------------------------------------------------------------------------
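A minimal sketch of writing one batch of predictions back to an extended-XYZ file with `batch_to_atoms`, assuming the batch carries reference 'energy'/'forces' and the model output dictionary `pred` carries 'CACE_energy'/'CACE_forces'; `batch`, `pred`, and the filename are placeholders.

    from cace.tools import batch_to_atoms

    atoms_list = batch_to_atoms(
        batched_data=batch,
        pred_data=pred,
        output_file='predictions.xyz',
        energy_key='energy',
        forces_key='forces',
        cace_energy_key='CACE_energy',
        cace_forces_key='CACE_forces',
    )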
/cace/tools/parser_train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | __all__ = ['parse_arguments']
4 |
5 | def parse_arguments():
6 | parser = argparse.ArgumentParser(description='ML potential training configuration')
7 |
8 | parser.add_argument('--prefix', type=str, default='CACE_NNP', help='Prefix for the model name')
9 |
10 | # Dataset and Training Configuration
11 | parser.add_argument('--train_path', type=str, help='Path to the training dataset', required=True)
12 | parser.add_argument('--valid_path', type=str, default=None, help='Path to the validation dataset')
13 | parser.add_argument('--valid_fraction', type=float, default=0.1, help='Fraction of data to use for validation')
14 | parser.add_argument('--energy_key', type=str, default='energy', help='Key for energy in the dataset')
15 | parser.add_argument('--forces_key', type=str, default='forces', help='Key for forces in the dataset')
16 | parser.add_argument('--cutoff', type=float, default=4.0, help='Cutoff radius for interactions')
17 | parser.add_argument('--batch_size', type=int, default=10, help='Batch size for training')
18 | parser.add_argument('--valid_batch_size', type=int, default=20, help='Batch size for validation')
19 | parser.add_argument('--use_device', type=str, default='cpu', help='Device to use for training')
20 |
21 | # Radial Basis Function (RBF) Configuration
22 | parser.add_argument('--n_rbf', type=int, default=6, help='Number of RBFs')
23 | parser.add_argument('--trainable_rbf', action='store_true', help='Whether the RBF parameters are trainable')
24 |
25 | # Cutoff Function Configuration
26 | parser.add_argument('--cutoff_fn', type=str, default='PolynomialCutoff', help='Type of cutoff function')
27 | parser.add_argument('--cutoff_fn_p', type=int, default=5, help='Polynomial degree for the PolynomialCutoff function')
28 |
29 | # Representation Configuration
30 | parser.add_argument('--zs', type=int, nargs='+', default=None, help='Atomic numbers considered in the model')
31 | parser.add_argument('--n_atom_basis', type=int, default=3, help='Number of atom basis functions')
32 | parser.add_argument('--n_radial_basis', type=int, default=8, help='Number of radial basis functions')
33 | parser.add_argument('--max_l', type=int, default=3, help='Maximum angular momentum quantum number')
34 | parser.add_argument('--max_nu', type=int, default=3, help='Maximum correlation order (body order)')
35 | parser.add_argument('--num_message_passing', type=int, default=1, help='Number of message passing steps')
36 | parser.add_argument('--embed_receiver_nodes', action='store_true', help='Whether to embed receiver nodes')
37 |
38 | # Atomwise Module Configuration
39 | parser.add_argument('--atomwise_layers', type=int, default=3, help='Number of layers in the atomwise module')
40 | parser.add_argument('--atomwise_hidden', type=int, nargs='+', default=[32, 16], help='Hidden units in each layer of the atomwise module')
41 | parser.add_argument('--atomwise_residual', action='store_false', help='Disable residual connections in the atomwise module (enabled by default)')
42 | parser.add_argument('--atomwise_batchnorm', action='store_false', help='Disable batch normalization in the atomwise module (enabled by default)')
43 | parser.add_argument('--atomwise_linear_nn', action='store_true', help='Add a linear neural network layer in the atomwise module')
44 |
45 | # Training Procedure Configuration
46 | parser.add_argument('--lr', type=float, default=1e-2, help='Learning rate')
47 | parser.add_argument('--scheduler_factor', type=float, default=0.8, help='Factor by which the learning rate is reduced')
48 | parser.add_argument('--scheduler_patience', type=int, default=10, help='Patience for the learning rate scheduler')
49 | parser.add_argument('--max_grad_norm', type=float, default=10, help='Max gradient norm for gradient clipping')
50 | parser.add_argument('--ema', action='store_true', help='Use exponential moving average of model parameters')
51 | parser.add_argument('--ema_start', type=int, default=10, help='Start using EMA after this many steps')
52 | parser.add_argument('--warmup_steps', type=int, default=10, help='Number of warmup steps for the optimizer')
53 | parser.add_argument('--epochs', type=int, default=200, help='Number of epochs for the first phase of training')
54 | parser.add_argument('--second_phase_epochs', type=int, default=100, help='Number of epochs for the second phase of training')
55 | parser.add_argument('--energy_loss_weight', type=float, default=1.0, help='Weight for the energy loss in phase 1')
56 | parser.add_argument('--force_loss_weight', type=float, default=1000.0, help='Weight for the force loss in both phases')
57 | parser.add_argument('--num_restart', type=int, default=5, help='Number of restarts for the training during phase 1')
58 | parser.add_argument('--second_phase_energy_loss_weight', type=float, default=1000.0,
59 | help='Weight for the energy loss in phase 2')
60 |
61 | return parser.parse_args()
62 |
--------------------------------------------------------------------------------
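A minimal sketch of consuming these options from a training script; the script name and command line are illustrative, not a description of scripts/train.py itself.

    # my_train.py (hypothetical)
    from cace.tools import parse_arguments

    args = parse_arguments()
    print(args.train_path, args.cutoff, args.batch_size, args.use_device)

Invoked, for example, as:

    python my_train.py --train_path water.xyz --cutoff 5.5 --batch_size 10 --use_device cuda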
/cace/tools/scatter.py:
--------------------------------------------------------------------------------
1 | """basic scatter_sum operations from torch_scatter from
2 | https://github.com/mir-group/pytorch_runstats/blob/main/torch_runstats/scatter_sum.py
3 | Using code from https://github.com/rusty1s/pytorch_scatter, but cut down to avoid a dependency.
4 | PyTorch plans to move these features into the main repo, but until then,
5 | to make installation simpler, we need this pure python set of wrappers
6 | that don't require installing PyTorch C++ extensions.
7 | See https://github.com/pytorch/pytorch/issues/63780.
8 | """
9 |
10 | from typing import Optional
11 |
12 | import torch
13 |
14 |
15 | def _broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
16 | if dim < 0:
17 | dim = other.dim() + dim
18 | if src.dim() == 1:
19 | for _ in range(0, dim):
20 | src = src.unsqueeze(0)
21 | for _ in range(src.dim(), other.dim()):
22 | src = src.unsqueeze(-1)
23 | src = src.expand_as(other)
24 | return src
25 |
26 |
27 | @torch.jit.script
28 | def scatter_sum(
29 | src: torch.Tensor,
30 | index: torch.Tensor,
31 | dim: int = -1,
32 | out: Optional[torch.Tensor] = None,
33 | dim_size: Optional[int] = None,
34 | reduce: str = "sum",
35 | ) -> torch.Tensor:
36 | assert reduce == "sum" # for now, TODO
37 | index = _broadcast(index, src, dim)
38 | if out is None:
39 | size = list(src.size())
40 | if dim_size is not None:
41 | size[dim] = dim_size
42 | elif index.numel() == 0:
43 | size[dim] = 0
44 | else:
45 | size[dim] = int(index.max()) + 1
46 | out = torch.zeros(size, dtype=src.dtype, device=src.device)
47 | return out.scatter_add_(dim, index, src)
48 | else:
49 | return out.scatter_add_(dim, index, src)
50 |
51 |
52 | @torch.jit.script
53 | def scatter_std(
54 | src: torch.Tensor,
55 | index: torch.Tensor,
56 | dim: int = -1,
57 | out: Optional[torch.Tensor] = None,
58 | dim_size: Optional[int] = None,
59 | unbiased: bool = True,
60 | ) -> torch.Tensor:
61 | if out is not None:
62 | dim_size = out.size(dim)
63 |
64 | if dim < 0:
65 | dim = src.dim() + dim
66 |
67 | count_dim = dim
68 | if index.dim() <= dim:
69 | count_dim = index.dim() - 1
70 |
71 | ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
72 | count = scatter_sum(ones, index, count_dim, dim_size=dim_size)
73 |
74 | index = _broadcast(index, src, dim)
75 | tmp = scatter_sum(src, index, dim, dim_size=dim_size)
76 | count = _broadcast(count, tmp, dim).clamp(1)
77 | mean = tmp.div(count)
78 |
79 | var = src - mean.gather(dim, index)
80 | var = var * var
81 | out = scatter_sum(var, index, dim, out, dim_size)
82 |
83 | if unbiased:
84 | count = count.sub(1).clamp_(1)
85 | out = out.div(count + 1e-6).sqrt()
86 |
87 | return out
88 |
89 |
90 | @torch.jit.script
91 | def scatter_mean(
92 | src: torch.Tensor,
93 | index: torch.Tensor,
94 | dim: int = -1,
95 | out: Optional[torch.Tensor] = None,
96 | dim_size: Optional[int] = None,
97 | ) -> torch.Tensor:
98 | out = scatter_sum(src, index, dim, out, dim_size)
99 | dim_size = out.size(dim)
100 |
101 | index_dim = dim
102 | if index_dim < 0:
103 | index_dim = index_dim + src.dim()
104 | if index.dim() <= index_dim:
105 | index_dim = index.dim() - 1
106 |
107 | ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
108 | count = scatter_sum(ones, index, index_dim, None, dim_size)
109 | count[count < 1] = 1
110 | count = _broadcast(count, out, dim)
111 | if out.is_floating_point():
112 | out.true_divide_(count)
113 | else:
114 | out.div_(count, rounding_mode="floor")
115 | return out
116 |
--------------------------------------------------------------------------------
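A minimal usage sketch for `scatter_sum`: pooling per-atom energies into per-structure energies with a batch index, which is the typical way atom-wise outputs are reduced over a batch.

    import torch
    from cace.tools import scatter_sum

    atom_energies = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5])
    batch_index = torch.tensor([0, 0, 1, 1, 1])  # first two atoms -> structure 0, rest -> structure 1

    per_structure = scatter_sum(atom_energies, batch_index, dim=0, dim_size=2)
    # tensor([0.3000, 1.2000])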
/cace/tools/torch_geometric/README.md:
--------------------------------------------------------------------------------
1 | # Trimmed-down `pytorch_geometric`
2 |
3 | CACE uses the [`pytorch_geometric`](https://pytorch-geometric.readthedocs.io/en/latest/) framework [1, 2]. However, we only use a very limited subset of that library: the most basic graph data structures.
4 |
5 | We follow the same approach as NequIP (https://github.com/mir-group/nequip/tree/main/nequip) and copy their code here.
6 |
7 | To avoid adding a large number of unnecessary second-degree dependencies, and to simplify installation, we include and modify here the small subset of `torch_geometric` that is necessary for our code.
8 |
9 | We are grateful to the developers of PyTorch Geometric for their ongoing and very useful work on graph learning with PyTorch.
10 |
11 | [1] Fey, M., & Lenssen, J. E. (2019). Fast Graph Representation Learning with PyTorch Geometric (Version 2.0.1) [Computer software]. https://github.com/pyg-team/pytorch_geometric
12 | [2] https://arxiv.org/abs/1903.02428
--------------------------------------------------------------------------------
/cace/tools/torch_geometric/__init__.py:
--------------------------------------------------------------------------------
1 | from .batch import Batch
2 | from .data import Data
3 | from .dataloader import DataLoader
4 | from .dataset import Dataset
5 |
6 | __all__ = ["Batch", "Data", "Dataset", "DataLoader"]
7 |
--------------------------------------------------------------------------------
/cace/tools/torch_geometric/batch.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Sequence
2 | from typing import List
3 |
4 | import numpy as np
5 | import torch
6 | from torch import Tensor
7 |
8 | from .data import Data
9 | from .dataset import IndexType
10 |
11 |
12 | class Batch(Data):
13 | r"""A plain old python object modeling a batch of graphs as one big
14 | (disconnected) graph. With :class:`torch_geometric.data.Data` being the
15 | base class, all its methods can also be used here.
16 | In addition, single graphs can be reconstructed via the assignment vector
17 | :obj:`batch`, which maps each node to its respective graph identifier.
18 | """
19 |
20 | def __init__(self, batch=None, ptr=None, **kwargs):
21 | super().__init__(**kwargs)
22 |
23 | for key, item in kwargs.items():
24 | if key == "num_nodes":
25 | self.__num_nodes__ = item
26 | else:
27 | self[key] = item
28 |
29 | self.batch = batch
30 | self.ptr = ptr
31 | self.__data_class__ = Data
32 | self.__slices__ = None
33 | self.__cumsum__ = None
34 | self.__cat_dims__ = None
35 | self.__num_nodes_list__ = None
36 | self.__num_graphs__ = None
37 |
38 | @classmethod
39 | def from_data_list(cls, data_list, follow_batch=[], exclude_keys=[]):
40 | r"""Constructs a batch object from a python list holding
41 | :class:`torch_geometric.data.Data` objects.
42 | The assignment vector :obj:`batch` is created on the fly.
43 | Additionally, creates assignment batch vectors for each key in
44 | :obj:`follow_batch`.
45 | Will exclude any keys given in :obj:`exclude_keys`."""
46 |
47 | # this is for compatibility with different pytorch versions
48 | try:
49 | keys = list(set(data_list[0].keys) - set(exclude_keys))
50 | except:
51 | keys = list(set(data_list[0].keys()) - set(exclude_keys))
52 | assert "batch" not in keys and "ptr" not in keys
53 |
54 | batch = cls()
55 | for key in data_list[0].__dict__.keys():
56 | if key[:2] != "__" and key[-2:] != "__":
57 | batch[key] = None
58 |
59 | batch.__num_graphs__ = len(data_list)
60 | batch.__data_class__ = data_list[0].__class__
61 | for key in keys + ["batch"]:
62 | batch[key] = []
63 | batch["ptr"] = [0]
64 |
65 | device = None
66 | slices = {key: [0] for key in keys}
67 | cumsum = {key: [0] for key in keys}
68 | cat_dims = {}
69 | num_nodes_list = []
70 | for i, data in enumerate(data_list):
71 | for key in keys:
72 | item = data[key]
73 |
74 | # Increase values by `cumsum` value.
75 | cum = cumsum[key][-1]
76 | if isinstance(item, Tensor) and item.dtype != torch.bool:
77 | if not isinstance(cum, int) or cum != 0:
78 | item = item + cum
79 | elif isinstance(item, (int, float)):
80 | item = item + cum
81 |
82 | # Gather the size of the `cat` dimension.
83 | size = 1
84 | cat_dim = data.__cat_dim__(key, data[key])
85 | # 0-dimensional tensors have no dimension along which to
86 | # concatenate, so we set `cat_dim` to `None`.
87 | if isinstance(item, Tensor) and item.dim() == 0:
88 | cat_dim = None
89 | cat_dims[key] = cat_dim
90 |
91 | # Add a batch dimension to items whose `cat_dim` is `None`:
92 | if isinstance(item, Tensor) and cat_dim is None:
93 | cat_dim = 0 # Concatenate along this new batch dimension.
94 | item = item.unsqueeze(0)
95 | device = item.device
96 | elif isinstance(item, Tensor):
97 | size = item.size(cat_dim)
98 | device = item.device
99 |
100 | batch[key].append(item) # Append item to the attribute list.
101 |
102 | slices[key].append(size + slices[key][-1])
103 | inc = data.__inc__(key, item)
104 | if isinstance(inc, (tuple, list)):
105 | inc = torch.tensor(inc)
106 | cumsum[key].append(inc + cumsum[key][-1])
107 |
108 | if key in follow_batch:
109 | if isinstance(size, Tensor):
110 | for j, size in enumerate(size.tolist()):
111 | tmp = f"{key}_{j}_batch"
112 | batch[tmp] = [] if i == 0 else batch[tmp]
113 | batch[tmp].append(
114 | torch.full((size,), i, dtype=torch.long, device=device)
115 | )
116 | else:
117 | tmp = f"{key}_batch"
118 | batch[tmp] = [] if i == 0 else batch[tmp]
119 | batch[tmp].append(
120 | torch.full((size,), i, dtype=torch.long, device=device)
121 | )
122 |
123 | if hasattr(data, "__num_nodes__"):
124 | num_nodes_list.append(data.__num_nodes__)
125 | else:
126 | num_nodes_list.append(None)
127 |
128 | num_nodes = data.num_nodes
129 | if num_nodes is not None:
130 | item = torch.full((num_nodes,), i, dtype=torch.long, device=device)
131 | batch.batch.append(item)
132 | batch.ptr.append(batch.ptr[-1] + num_nodes)
133 |
134 | batch.batch = None if len(batch.batch) == 0 else batch.batch
135 | batch.ptr = None if len(batch.ptr) == 1 else batch.ptr
136 | batch.__slices__ = slices
137 | batch.__cumsum__ = cumsum
138 | batch.__cat_dims__ = cat_dims
139 | batch.__num_nodes_list__ = num_nodes_list
140 |
141 | ref_data = data_list[0]
142 | for key in batch.keys:
143 | items = batch[key]
144 | item = items[0]
145 | cat_dim = ref_data.__cat_dim__(key, item)
146 | cat_dim = 0 if cat_dim is None else cat_dim
147 | if isinstance(item, Tensor):
148 | batch[key] = torch.cat(items, cat_dim)
149 | elif isinstance(item, (int, float)):
150 | batch[key] = torch.tensor(items)
151 |
152 | # if torch_geometric.is_debug_enabled():
153 | # batch.debug()
154 |
155 | return batch.contiguous()
156 |
157 | def get_example(self, idx: int) -> Data:
158 | r"""Reconstructs the :class:`torch_geometric.data.Data` object at index
159 | :obj:`idx` from the batch object.
160 | The batch object must have been created via :meth:`from_data_list` in
161 | order to be able to reconstruct the initial objects."""
162 |
163 | if self.__slices__ is None:
164 | raise RuntimeError(
165 | (
166 | "Cannot reconstruct data list from batch because the batch "
167 | "object was not created using `Batch.from_data_list()`."
168 | )
169 | )
170 |
171 | data = self.__data_class__()
172 | idx = self.num_graphs + idx if idx < 0 else idx
173 |
174 | for key in self.__slices__.keys():
175 | item = self[key]
176 | if self.__cat_dims__[key] is None:
177 | # The item was concatenated along a new batch dimension,
178 | # so just index in that dimension:
179 | item = item[idx]
180 | else:
181 | # Narrow the item based on the values in `__slices__`.
182 | if isinstance(item, Tensor):
183 | dim = self.__cat_dims__[key]
184 | start = self.__slices__[key][idx]
185 | end = self.__slices__[key][idx + 1]
186 | item = item.narrow(dim, start, end - start)
187 | else:
188 | start = self.__slices__[key][idx]
189 | end = self.__slices__[key][idx + 1]
190 | item = item[start:end]
191 | item = item[0] if len(item) == 1 else item
192 |
193 | # Decrease its value by `cumsum` value:
194 | cum = self.__cumsum__[key][idx]
195 | if isinstance(item, Tensor):
196 | if not isinstance(cum, int) or cum != 0:
197 | item = item - cum
198 | elif isinstance(item, (int, float)):
199 | item = item - cum
200 |
201 | data[key] = item
202 |
203 | if self.__num_nodes_list__[idx] is not None:
204 | data.num_nodes = self.__num_nodes_list__[idx]
205 |
206 | return data
207 |
208 | def index_select(self, idx: IndexType) -> List[Data]:
209 | if isinstance(idx, slice):
210 | idx = list(range(self.num_graphs)[idx])
211 |
212 | elif isinstance(idx, Tensor) and idx.dtype == torch.long:
213 | idx = idx.flatten().tolist()
214 |
215 | elif isinstance(idx, Tensor) and idx.dtype == torch.bool:
216 | idx = idx.flatten().nonzero(as_tuple=False).flatten().tolist()
217 |
218 | elif isinstance(idx, np.ndarray) and idx.dtype == np.int64:
219 | idx = idx.flatten().tolist()
220 |
221 | elif isinstance(idx, np.ndarray) and idx.dtype == np.bool_:
222 | idx = idx.flatten().nonzero()[0].flatten().tolist()
223 |
224 | elif isinstance(idx, Sequence) and not isinstance(idx, str):
225 | pass
226 |
227 | else:
228 | raise IndexError(
229 | f"Only integers, slices (':'), list, tuples, torch.tensor and "
230 | f"np.ndarray of dtype long or bool are valid indices (got "
231 | f"'{type(idx).__name__}')"
232 | )
233 |
234 | return [self.get_example(i) for i in idx]
235 |
236 | def __getitem__(self, idx):
237 | if isinstance(idx, str):
238 | return super().__getitem__(idx)
239 | elif isinstance(idx, (int, np.integer)):
240 | return self.get_example(idx)
241 | else:
242 | return self.index_select(idx)
243 |
244 | def to_data_list(self) -> List[Data]:
245 | r"""Reconstructs the list of :class:`torch_geometric.data.Data` objects
246 | from the batch object.
247 | The batch object must have been created via :meth:`from_data_list` in
248 | order to be able to reconstruct the initial objects."""
249 | return [self.get_example(i) for i in range(self.num_graphs)]
250 |
251 | @property
252 | def num_graphs(self) -> int:
253 | """Returns the number of graphs in the batch."""
254 | if self.__num_graphs__ is not None:
255 | return self.__num_graphs__
256 | elif self.ptr is not None:
257 | return self.ptr.numel() - 1
258 | elif self.batch is not None:
259 | return int(self.batch.max()) + 1
260 | else:
261 | raise ValueError
262 |
--------------------------------------------------------------------------------
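A minimal sketch of collating two toy graphs into a `Batch` and recovering them, assuming the trimmed `Data` class behaves like upstream `pytorch_geometric`; the field names `x` and `pos` are generic placeholders, not the keys used by cace's own data objects.

    import torch
    from cace.tools.torch_geometric import Batch, Data

    g1 = Data(x=torch.rand(3, 4), pos=torch.rand(3, 3))
    g2 = Data(x=torch.rand(2, 4), pos=torch.rand(2, 3))

    batch = Batch.from_data_list([g1, g2])  # concatenates fields and builds `batch`/`ptr`
    graphs = batch.to_data_list()           # reconstructs the original Data objects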
/cace/tools/torch_geometric/dataloader.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Mapping, Sequence
2 | from typing import List, Optional, Union
3 |
4 | import torch.utils.data
5 | from torch.utils.data.dataloader import default_collate
6 |
7 | from .batch import Batch
8 | from .data import Data
9 | from .dataset import Dataset
10 |
11 | class Collater:
12 | def __init__(self, follow_batch, exclude_keys):
13 | self.follow_batch = follow_batch
14 | self.exclude_keys = exclude_keys
15 |
16 | def __call__(self, batch):
17 | elem = batch[0]
18 | if isinstance(elem, Data):
19 | return Batch.from_data_list(
20 | batch,
21 | follow_batch=self.follow_batch,
22 | exclude_keys=self.exclude_keys,
23 | )
24 | elif isinstance(elem, torch.Tensor):
25 | return default_collate(batch)
26 | elif isinstance(elem, float):
27 | return torch.tensor(batch, dtype=torch.float)
28 | elif isinstance(elem, int):
29 | return torch.tensor(batch)
30 | elif isinstance(elem, str):
31 | return batch
32 | elif isinstance(elem, Mapping):
33 | return {key: self([data[key] for data in batch]) for key in elem}
34 | elif isinstance(elem, tuple) and hasattr(elem, "_fields"):
35 | return type(elem)(*(self(s) for s in zip(*batch)))
36 | elif isinstance(elem, Sequence) and not isinstance(elem, str):
37 | return [self(s) for s in zip(*batch)]
38 |
39 | raise TypeError(f"DataLoader found invalid type: {type(elem)}")
40 |
41 | def collate(self, batch): # Deprecated...
42 | return self(batch)
43 |
44 |
45 | class DataLoader(torch.utils.data.DataLoader):
46 | r"""A data loader which merges data objects from a
47 | :class:`torch_geometric.data.Dataset` to a mini-batch.
48 | Data objects can be either of type :class:`~torch_geometric.data.Data` or
49 | :class:`~torch_geometric.data.HeteroData`.
50 | Args:
51 | dataset (Dataset): The dataset from which to load the data.
52 | batch_size (int, optional): How many samples per batch to load.
53 | (default: :obj:`1`)
54 | shuffle (bool, optional): If set to :obj:`True`, the data will be
55 | reshuffled at every epoch. (default: :obj:`False`)
56 | follow_batch (List[str], optional): Creates assignment batch
57 | vectors for each key in the list. (default: :obj:`None`)
58 | exclude_keys (List[str], optional): Will exclude each key in the
59 | list. (default: :obj:`None`)
60 | **kwargs (optional): Additional arguments of
61 | :class:`torch.utils.data.DataLoader`.
62 | """
63 |
64 | def __init__(
65 | self,
66 | dataset: Dataset,
67 | batch_size: int = 1,
68 | shuffle: bool = False,
69 | follow_batch: Optional[List[str]] = [None],
70 | exclude_keys: Optional[List[str]] = [None],
71 | **kwargs,
72 | ):
73 | if "collate_fn" in kwargs:
74 | del kwargs["collate_fn"]
75 |
76 | # Save for PyTorch Lightning < 1.6:
77 | self.follow_batch = follow_batch
78 | self.exclude_keys = exclude_keys
79 |
80 | super().__init__(
81 | dataset,
82 | batch_size,
83 | shuffle,
84 | collate_fn=Collater(follow_batch, exclude_keys),
85 | **kwargs,
86 | )
87 |
--------------------------------------------------------------------------------
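A minimal usage sketch for this `DataLoader`, assuming `data_list` is a list of `Data` objects such as those produced by cace's data pipeline.

    from cace.tools.torch_geometric import DataLoader

    loader = DataLoader(data_list, batch_size=4, shuffle=True)
    for batch in loader:
        # each `batch` is a Batch object, i.e. one big disconnected graph
        ...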
/cace/tools/torch_geometric/dataset.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import os.path as osp
3 | import re
4 | import warnings
5 | from collections.abc import Sequence
6 | from typing import Any, Callable, List, Optional, Tuple, Union
7 |
8 | import numpy as np
9 | import torch.utils.data
10 | from torch import Tensor
11 |
12 | from .data import Data
13 | from .utils import makedirs
14 |
15 | IndexType = Union[slice, Tensor, np.ndarray, Sequence]
16 |
17 |
18 | class Dataset(torch.utils.data.Dataset):
19 | r"""Dataset base class for creating graph datasets.
20 | See the `pytorch_geometric` documentation for the accompanying tutorial.
22 |
23 | Args:
24 | root (string, optional): Root directory where the dataset should be
25 | saved. (default: :obj:`None`)
26 | transform (callable, optional): A function/transform that takes in an
27 | :obj:`torch_geometric.data.Data` object and returns a transformed
28 | version. The data object will be transformed before every access.
29 | (default: :obj:`None`)
30 | pre_transform (callable, optional): A function/transform that takes in
31 | an :obj:`torch_geometric.data.Data` object and returns a
32 | transformed version. The data object will be transformed before
33 | being saved to disk. (default: :obj:`None`)
34 | pre_filter (callable, optional): A function that takes in an
35 | :obj:`torch_geometric.data.Data` object and returns a boolean
36 | value, indicating whether the data object should be included in the
37 | final dataset. (default: :obj:`None`)
38 | """
39 |
40 | @property
41 | def raw_file_names(self) -> Union[str, List[str], Tuple]:
42 | r"""The name of the files to find in the :obj:`self.raw_dir` folder in
43 | order to skip the download."""
44 | raise NotImplementedError
45 |
46 | @property
47 | def processed_file_names(self) -> Union[str, List[str], Tuple]:
48 | r"""The name of the files to find in the :obj:`self.processed_dir`
49 | folder in order to skip the processing."""
50 | raise NotImplementedError
51 |
52 | def download(self):
53 | r"""Downloads the dataset to the :obj:`self.raw_dir` folder."""
54 | raise NotImplementedError
55 |
56 | def process(self):
57 | r"""Processes the dataset to the :obj:`self.processed_dir` folder."""
58 | raise NotImplementedError
59 |
60 | def len(self) -> int:
61 | raise NotImplementedError
62 |
63 | def get(self, idx: int) -> Data:
64 | r"""Gets the data object at index :obj:`idx`."""
65 | raise NotImplementedError
66 |
67 | def __init__(
68 | self,
69 | root: Optional[str] = None,
70 | transform: Optional[Callable] = None,
71 | pre_transform: Optional[Callable] = None,
72 | pre_filter: Optional[Callable] = None,
73 | ):
74 | super().__init__()
75 |
76 | if isinstance(root, str):
77 | root = osp.expanduser(osp.normpath(root))
78 |
79 | self.root = root
80 | self.transform = transform
81 | self.pre_transform = pre_transform
82 | self.pre_filter = pre_filter
83 | self._indices: Optional[Sequence] = None
84 |
85 | if "download" in self.__class__.__dict__.keys():
86 | self._download()
87 |
88 | if "process" in self.__class__.__dict__.keys():
89 | self._process()
90 |
91 | def indices(self) -> Sequence:
92 | return range(self.len()) if self._indices is None else self._indices
93 |
94 | @property
95 | def raw_dir(self) -> str:
96 | return osp.join(self.root, "raw")
97 |
98 | @property
99 | def processed_dir(self) -> str:
100 | return osp.join(self.root, "processed")
101 |
102 | @property
103 | def num_node_features(self) -> int:
104 | r"""Returns the number of features per node in the dataset."""
105 | data = self[0]
106 | if hasattr(data, "num_node_features"):
107 | return data.num_node_features
108 | raise AttributeError(
109 | f"'{data.__class__.__name__}' object has no "
110 | f"attribute 'num_node_features'"
111 | )
112 |
113 | @property
114 | def num_features(self) -> int:
115 | r"""Alias for :py:attr:`~num_node_features`."""
116 | return self.num_node_features
117 |
118 | @property
119 | def num_edge_features(self) -> int:
120 | r"""Returns the number of features per edge in the dataset."""
121 | data = self[0]
122 | if hasattr(data, "num_edge_features"):
123 | return data.num_edge_features
124 | raise AttributeError(
125 | f"'{data.__class__.__name__}' object has no "
126 | f"attribute 'num_edge_features'"
127 | )
128 |
129 | @property
130 | def raw_paths(self) -> List[str]:
131 | r"""The filepaths to find in order to skip the download."""
132 | files = to_list(self.raw_file_names)
133 | return [osp.join(self.raw_dir, f) for f in files]
134 |
135 | @property
136 | def processed_paths(self) -> List[str]:
137 | r"""The filepaths to find in the :obj:`self.processed_dir`
138 | folder in order to skip the processing."""
139 | files = to_list(self.processed_file_names)
140 | return [osp.join(self.processed_dir, f) for f in files]
141 |
142 | def _download(self):
143 | if files_exist(self.raw_paths): # pragma: no cover
144 | return
145 |
146 | makedirs(self.raw_dir)
147 | self.download()
148 |
149 | def _process(self):
150 | f = osp.join(self.processed_dir, "pre_transform.pt")
151 | if osp.exists(f) and torch.load(f) != _repr(self.pre_transform):
152 | warnings.warn(
153 | f"The `pre_transform` argument differs from the one used in "
154 | f"the pre-processed version of this dataset. If you want to "
155 | f"make use of another pre-processing technique, make sure to "
156 | f"sure to delete '{self.processed_dir}' first"
157 | )
158 |
159 | f = osp.join(self.processed_dir, "pre_filter.pt")
160 | if osp.exists(f) and torch.load(f) != _repr(self.pre_filter):
161 | warnings.warn(
162 | "The `pre_filter` argument differs from the one used in the "
163 | "pre-processed version of this dataset. If you want to make "
164 | "use of another pre-fitering technique, make sure to delete "
165 | "'{self.processed_dir}' first"
166 | )
167 |
168 | if files_exist(self.processed_paths): # pragma: no cover
169 | return
170 |
171 | print("Processing...")
172 |
173 | makedirs(self.processed_dir)
174 | self.process()
175 |
176 | path = osp.join(self.processed_dir, "pre_transform.pt")
177 | torch.save(_repr(self.pre_transform), path)
178 | path = osp.join(self.processed_dir, "pre_filter.pt")
179 | torch.save(_repr(self.pre_filter), path)
180 |
181 | print("Done!")
182 |
183 | def __len__(self) -> int:
184 | r"""The number of examples in the dataset."""
185 | return len(self.indices())
186 |
187 | def __getitem__(
188 | self,
189 | idx: Union[int, np.integer, IndexType],
190 | ) -> Union["Dataset", Data]:
191 | r"""In case :obj:`idx` is of type integer, will return the data object
192 | at index :obj:`idx` (and transforms it in case :obj:`transform` is
193 | present).
194 | In case :obj:`idx` is a slicing object, *e.g.*, :obj:`[2:5]`, a list, a
195 | tuple, a PyTorch :obj:`LongTensor` or a :obj:`BoolTensor`, or a numpy
196 | :obj:`np.array`, will return a subset of the dataset at the specified
197 | indices."""
198 | if (
199 | isinstance(idx, (int, np.integer))
200 | or (isinstance(idx, Tensor) and idx.dim() == 0)
201 | or (isinstance(idx, np.ndarray) and np.isscalar(idx))
202 | ):
203 | data = self.get(self.indices()[idx])
204 | data = data if self.transform is None else self.transform(data)
205 | return data
206 |
207 | else:
208 | return self.index_select(idx)
209 |
210 | def index_select(self, idx: IndexType) -> "Dataset":
211 | indices = self.indices()
212 |
213 | if isinstance(idx, slice):
214 | indices = indices[idx]
215 |
216 | elif isinstance(idx, Tensor) and idx.dtype == torch.long:
217 | return self.index_select(idx.flatten().tolist())
218 |
219 | elif isinstance(idx, Tensor) and idx.dtype == torch.bool:
220 | idx = idx.flatten().nonzero(as_tuple=False)
221 | return self.index_select(idx.flatten().tolist())
222 |
223 | elif isinstance(idx, np.ndarray) and idx.dtype == np.int64:
224 | return self.index_select(idx.flatten().tolist())
225 |
226 | elif isinstance(idx, np.ndarray) and idx.dtype == np.bool_:
227 | idx = idx.flatten().nonzero()[0]
228 | return self.index_select(idx.flatten().tolist())
229 |
230 | elif isinstance(idx, Sequence) and not isinstance(idx, str):
231 | indices = [indices[i] for i in idx]
232 |
233 | else:
234 | raise IndexError(
235 | f"Only integers, slices (':'), list, tuples, torch.tensor and "
236 | f"np.ndarray of dtype long or bool are valid indices (got "
237 | f"'{type(idx).__name__}')"
238 | )
239 |
240 | dataset = copy.copy(self)
241 | dataset._indices = indices
242 | return dataset
243 |
244 | def shuffle(
245 | self,
246 | return_perm: bool = False,
247 | ) -> Union["Dataset", Tuple["Dataset", Tensor]]:
248 | r"""Randomly shuffles the examples in the dataset.
249 |
250 | Args:
251 | return_perm (bool, optional): If set to :obj:`True`, will return
252 | the random permutation used to shuffle the dataset in addition.
253 | (default: :obj:`False`)
254 | """
255 | perm = torch.randperm(len(self))
256 | dataset = self.index_select(perm)
257 | return (dataset, perm) if return_perm is True else dataset
258 |
259 | def __repr__(self) -> str:
260 | arg_repr = str(len(self)) if len(self) > 1 else ""
261 | return f"{self.__class__.__name__}({arg_repr})"
262 |
263 |
264 | def to_list(value: Any) -> Sequence:
265 | if isinstance(value, Sequence) and not isinstance(value, str):
266 | return value
267 | else:
268 | return [value]
269 |
270 |
271 | def files_exist(files: List[str]) -> bool:
272 | # NOTE: We return `False` in case `files` is empty, leading to a
273 | # re-processing of files on every instantiation.
274 | return len(files) != 0 and all([osp.exists(f) for f in files])
275 |
276 |
277 | def _repr(obj: Any) -> str:
278 | if obj is None:
279 | return "None"
280 | return re.sub("(<.*?)\\s.*(>)", r"\1\2", obj.__repr__())
281 |
--------------------------------------------------------------------------------
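A minimal sketch of the surface a concrete subclass of this abstract `Dataset` has to provide; `InMemoryListDataset` is a hypothetical name, not a class shipped with cace.

    from cace.tools.torch_geometric import Dataset

    class InMemoryListDataset(Dataset):
        """Wraps an existing list of Data objects; no download/process step."""
        def __init__(self, data_list, transform=None):
            self.data_list = data_list
            super().__init__(root=None, transform=transform)

        def len(self):
            return len(self.data_list)

        def get(self, idx):
            return self.data_list[idx]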
/cace/tools/torch_geometric/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import ssl
4 | import urllib
5 | import zipfile
6 |
7 |
8 | def makedirs(dir):
9 | os.makedirs(dir, exist_ok=True)
10 |
11 |
12 | def download_url(url, folder, log=True):
13 | r"""Downloads the content of an URL to a specific folder.
14 |
15 | Args:
16 | url (string): The url.
17 | folder (string): The folder.
18 | log (bool, optional): If :obj:`False`, will not print anything to the
19 | console. (default: :obj:`True`)
20 | """
21 |
22 | filename = url.rpartition("/")[2].split("?")[0]
23 | path = osp.join(folder, filename)
24 |
25 | if osp.exists(path): # pragma: no cover
26 | if log:
27 | print("Using exist file", filename)
28 | return path
29 |
30 | if log:
31 | print("Downloading", url)
32 |
33 | makedirs(folder)
34 |
35 | context = ssl._create_unverified_context()
36 | data = urllib.request.urlopen(url, context=context)
37 |
38 | with open(path, "wb") as f:
39 | f.write(data.read())
40 |
41 | return path
42 |
43 |
44 | def extract_zip(path, folder, log=True):
45 | r"""Extracts a zip archive to a specific folder.
46 |
47 | Args:
48 | path (string): The path to the zip archive.
49 | folder (string): The folder.
50 | log (bool, optional): If :obj:`False`, will not print anything to the
51 | console. (default: :obj:`True`)
52 | """
53 | with zipfile.ZipFile(path, "r") as f:
54 | f.extractall(folder)
55 |
--------------------------------------------------------------------------------
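A minimal usage sketch for the two helpers above; the URL is a placeholder.

    from cace.tools.torch_geometric.utils import download_url, extract_zip

    path = download_url('https://example.com/dataset.zip', 'raw_data')  # placeholder URL
    extract_zip(path, 'raw_data')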
/cace/tools/torch_tools.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import logging
4 | from typing import Dict
5 |
6 | TensorDict = Dict[str, torch.Tensor]
7 |
8 | def elementwise_multiply_2tensors(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
9 | """
10 | Elementwise multiplication of two 2D tensors
11 | :param a: (N, A) tensor
12 | :param b: (N, B) tensor
13 | :return: (N, A, B) tensor
14 | """
15 | # expand the dimensions for broadcasting
16 | a_expanded = a.unsqueeze(2)
17 | b_expanded = b.unsqueeze(1)
18 | # multiply
19 | return a_expanded * b_expanded
20 |
21 | @torch.jit.script
22 | def elementwise_multiply_3tensors(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
23 | """
24 | Elementwise multiplication of three 2D tensors
25 | :param a: (N, A) tensor
26 | :param b: (N, B) tensor
27 | :param c: (N, C) tensor
28 | :return: (N, A, B, C) tensor
29 | """
30 | # expand the dimensions for broadcasting
31 | a_expanded = a.unsqueeze(2).unsqueeze(3)
32 | b_expanded = b.unsqueeze(1).unsqueeze(3)
33 | c_expanded = c.unsqueeze(1).unsqueeze(2)
34 | # multiply
35 | # this is the same as torch.einsum('ni,nj,nk->nijk', a, b,c)
36 | # but a bit faster
37 | return a_expanded * b_expanded * c_expanded
38 |
39 | def tensor_dict_to_device(td: TensorDict, device: torch.device) -> TensorDict:
40 | return {k: v.to(device) if v is not None else None for k, v in td.items()}
41 |
42 | def to_numpy(t: torch.Tensor) -> np.ndarray:
43 | return t.cpu().detach().numpy()
44 |
45 | def init_device(device_str: str) -> torch.device:
46 | if device_str == "cuda":
47 | assert torch.cuda.is_available(), "No CUDA device available!"
48 | logging.info(
49 | f"CUDA version: {torch.version.cuda}, CUDA device: {torch.cuda.current_device()}"
50 | )
51 | torch.cuda.init()
52 | return torch.device("cuda")
53 | if device_str == "mps":
54 | assert torch.backends.mps.is_available(), "No MPS backend is available!"
55 | logging.info("Using MPS GPU acceleration")
56 | return torch.device("mps")
57 |
58 | logging.info("Using CPU")
59 | return torch.device("cpu")
60 |
61 | def voigt_to_matrix(t: torch.Tensor):
62 | """
63 | Convert voigt notation to matrix notation
64 | :param t: (6,) tensor or (3, 3) tensor
65 | :return: (3, 3) tensor
66 | """
67 | if t.shape == (3, 3):
68 | return t
69 |
70 | return torch.tensor(
71 | [[t[0], t[5], t[4]], [t[5], t[1], t[3]], [t[4], t[3], t[2]]], dtype=t.dtype
72 | )
73 |
74 | def to_one_hot(indices: torch.Tensor, num_classes: int) -> torch.Tensor:
75 | """
76 | Generates a one-hot encoding with `num_classes` classes from `indices`
77 | :param indices: (N x 1) tensor
78 | :param num_classes: number of classes
79 | :param device: torch device
80 | :return: (N x num_classes) tensor
81 | """
82 | shape = indices.shape[:-1] + (num_classes,)
83 | oh = torch.zeros(shape, device=indices.device).view(shape)
84 |
85 | # scatter_ is the in-place version of scatter
86 | oh.scatter_(dim=-1, index=indices, value=1)
87 |
88 | return oh.view(*shape)
89 |
--------------------------------------------------------------------------------
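A minimal usage sketch for the broadcasting helpers above, showing the shapes produced by the outer products.

    import torch
    from cace.tools import elementwise_multiply_2tensors, elementwise_multiply_3tensors

    a = torch.rand(10, 4)  # (N, A)
    b = torch.rand(10, 5)  # (N, B)
    c = torch.rand(10, 6)  # (N, C)

    ab = elementwise_multiply_2tensors(a, b)      # (10, 4, 5)
    abc = elementwise_multiply_3tensors(a, b, c)  # (10, 4, 5, 6)
    assert torch.allclose(abc, torch.einsum('ni,nj,nk->nijk', a, b, c))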
/cace/tools/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | import sys
5 | from typing import Any, Dict, Iterable, Optional, Sequence, Union, List
6 | from ase import Atoms
7 |
8 | import numpy as np
9 | import torch
10 |
11 | from . import torch_geometric
12 | from .torch_tools import to_numpy
13 |
14 | class AtomicNumberTable:
15 | def __init__(self, zs: Sequence[int]):
16 | self.zs = zs
17 |
18 | def __len__(self) -> int:
19 | return len(self.zs)
20 |
21 | def __str__(self):
22 | return f"AtomicNumberTable: {tuple(s for s in self.zs)}"
23 |
24 | def index_to_z(self, index: int) -> int:
25 | return self.zs[index]
26 |
27 | def z_to_index(self, atomic_number: int) -> int:
28 | return self.zs.index(atomic_number)
29 |
30 |
31 | def get_atomic_number_table_from_zs(zs: Iterable[int]) -> AtomicNumberTable:
32 | z_set = set()
33 | for z in zs:
34 | z_set.add(z)
35 | return AtomicNumberTable(sorted(list(z_set)))
36 |
37 |
38 | def atomic_numbers_to_indices(
39 | atomic_numbers: np.ndarray, z_table: AtomicNumberTable
40 | ) -> np.ndarray:
41 | to_index_fn = np.vectorize(z_table.z_to_index)
42 | return to_index_fn(atomic_numbers)
43 |
44 | def compute_avg_num_neighbors(batches: Union[torch.utils.data.DataLoader, torch_geometric.data.Data, torch_geometric.batch.Batch]) -> float:
45 | num_neighbors = []
46 |
47 | if isinstance(batches, torch_geometric.data.Data) or isinstance(batches, torch_geometric.batch.Batch):
48 | _, receivers = batches.edge_index
49 | _, counts = torch.unique(receivers, return_counts=True)
50 | num_neighbors.append(counts)
51 | elif isinstance(batches, torch.utils.data.DataLoader):
52 | for batch in batches:
53 | _, receivers = batch.edge_index
54 | _, counts = torch.unique(receivers, return_counts=True)
55 | num_neighbors.append(counts)
56 |
57 | avg_num_neighbors = torch.mean(
58 | torch.cat(num_neighbors, dim=0).type(torch.get_default_dtype())
59 | )
60 | return to_numpy(avg_num_neighbors).item()
61 |
62 | def get_unique_atomic_number(atoms_list: List[Atoms]) -> List[int]:
63 | """
64 | Return a list of the unique atomic numbers present across a list of
65 | ASE Atoms objects (e.g. the frames of a multi-frame XYZ file).
66 |
67 | Returns:
68 | list: List of unique atomic numbers.
69 | """
70 | unique_atomic_numbers = set()
71 |
72 | for atoms in atoms_list:
73 | unique_atomic_numbers.update(atom.number for atom in atoms)
74 |
75 | return list(unique_atomic_numbers)
76 |
77 | def compute_average_E0s(
78 | atom_list: List[Atoms], zs: Optional[List[int]] = None, energy_key: str = "energy"
79 | ) -> Dict[int, float]:
80 | """
81 | Function to compute the average interaction energy of each chemical element
82 | returns dictionary of E0s
83 | """
84 | len_xyz = len(atom_list)
85 | if zs is None:
86 | zs = get_unique_atomic_number(atom_list)
87 | # sort by atomic number
88 | zs.sort()
89 | len_zs = len(zs)
90 |
91 | A = np.zeros((len_xyz, len_zs))
92 | B = np.zeros(len_xyz)
93 | for i in range(len_xyz):
94 | B[i] = atom_list[i].info[energy_key]
95 | for j, z in enumerate(zs):
96 | A[i, j] = np.count_nonzero(atom_list[i].get_atomic_numbers() == z)
97 | try:
98 | E0s = np.linalg.lstsq(A, B, rcond=None)[0]
99 | atomic_energies_dict = {}
100 | for i, z in enumerate(zs):
101 | atomic_energies_dict[z] = E0s[i]
102 | except np.linalg.LinAlgError:
103 | logging.warning(
104 | "Failed to compute E0s using least squares regression, using the same for all atoms"
105 | )
106 | atomic_energies_dict = {}
107 | for i, z in enumerate(zs):
108 | atomic_energies_dict[z] = 0.0
109 | return atomic_energies_dict
110 |
111 | def setup_logger(
112 | level: Union[int, str] = logging.INFO,
113 | tag: Optional[str] = None,
114 | directory: Optional[str] = None,
115 | ):
116 | logger = logging.getLogger()
117 | logger.setLevel(level)
118 |
119 | formatter = logging.Formatter(
120 | "%(asctime)s.%(msecs)03d %(levelname)s: %(message)s",
121 | datefmt="%Y-%m-%d %H:%M:%S",
122 | )
123 |
124 | ch = logging.StreamHandler(stream=sys.stdout)
125 | ch.setFormatter(formatter)
126 | logger.addHandler(ch)
127 |
128 | if (directory is not None) and (tag is not None):
129 | os.makedirs(name=directory, exist_ok=True)
130 | path = os.path.join(directory, tag + ".log")
131 | fh = logging.FileHandler(path)
132 | fh.setFormatter(formatter)
133 |
134 | logger.addHandler(fh)
135 |
--------------------------------------------------------------------------------
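A minimal sketch of estimating per-element reference energies from a training set with the helpers above, for example to obtain the `atomic_energies` dictionary used in the example scripts; the file path is illustrative and each frame is assumed to store its total energy under info['energy'].

    from ase.io import read
    from cace.tools import compute_average_E0s, get_unique_atomic_number

    frames = read('water.xyz', ':')
    zs = sorted(get_unique_atomic_number(frames))
    e0s = compute_average_E0s(frames, zs=zs, energy_key='energy')
    # e.g. {1: ..., 8: ...}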
/examples/water_train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | import sys
5 | sys.path.append('../cace/')
6 |
7 | import numpy as np
8 | #import matplotlib.pyplot as plt
9 | import torch
10 | import torch.nn as nn
11 | import logging
12 |
13 | import cace
14 | from cace.representations import Cace
15 | from cace.modules import CosineCutoff, MollifierCutoff, PolynomialCutoff
16 | from cace.modules import BesselRBF, GaussianRBF, GaussianRBFCentered
17 |
18 | from cace.models.atomistic import NeuralNetworkPotential
19 | from cace.tasks.train import TrainingTask
20 |
21 | torch.set_default_dtype(torch.float32)
22 |
23 | cace.tools.setup_logger(level='INFO')
24 |
25 | logging.info("reading data")
26 | collection = cace.tasks.get_dataset_from_xyz(train_path='../../datasets/water/water.xyz',
27 | valid_fraction=0.1,
28 | seed=1,
29 | energy_key='energy',
30 | forces_key='force',
31 | atomic_energies={1: -187.6043857100553, 8: -93.80219285502734} # avg
32 | )
33 | cutoff = 5.5
34 | batch_size = 2
35 |
36 | train_loader = cace.tasks.load_data_loader(collection=collection,
37 | data_type='train',
38 | batch_size=batch_size,
39 | cutoff=cutoff)
40 |
41 | valid_loader = cace.tasks.load_data_loader(collection=collection,
42 | data_type='valid',
43 | batch_size=4,
44 | cutoff=cutoff)
45 |
46 | use_device = 'cuda'
47 | device = cace.tools.init_device(use_device)
48 | logging.info(f"device: {use_device}")
49 |
50 |
51 | logging.info("building CACE representation")
52 | radial_basis = BesselRBF(cutoff=cutoff, n_rbf=6, trainable=True)
53 | #cutoff_fn = CosineCutoff(cutoff=cutoff)
54 | cutoff_fn = PolynomialCutoff(cutoff=cutoff)
55 |
56 | cace_representation = Cace(
57 | zs=[1,8],
58 | n_atom_basis=3,
59 | embed_receiver_nodes=True,
60 | cutoff=cutoff,
61 | cutoff_fn=cutoff_fn,
62 | radial_basis=radial_basis,
63 | n_radial_basis=12,
64 | max_l=3,
65 | max_nu=3,
66 | num_message_passing=1,
67 | type_message_passing=['Bchi'],
68 | device=device,
69 | timeit=False
70 | )
71 |
72 | cace_representation.to(device)
73 | logging.info(f"Representation: {cace_representation}")
74 |
75 | atomwise = cace.modules.atomwise.Atomwise(n_layers=3,
76 | output_key='CACE_energy',
77 | n_hidden=[32,16],
78 | use_batchnorm=False,
79 | add_linear_nn=True)
80 |
81 |
82 | forces = cace.modules.forces.Forces(energy_key='CACE_energy',
83 | forces_key='CACE_forces')
84 |
85 | logging.info("building CACE NNP")
86 | cace_nnp = NeuralNetworkPotential(
87 | input_modules=None,
88 | representation=cace_representation,
89 | output_modules=[atomwise, forces]
90 | )
91 |
92 | #trainable_params = sum(p.numel() for p in cace_nnp.parameters() if p.requires_grad)
93 | #logging.info(f"Number of trainable parameters: {trainable_params}")
94 |
95 | cace_nnp.to(device)
96 |
97 |
98 | logging.info(f"First train loop:")
99 | energy_loss = cace.tasks.GetLoss(
100 | target_name='energy',
101 | predict_name='CACE_energy',
102 | loss_fn=torch.nn.MSELoss(),
103 | loss_weight=0.1
104 | )
105 |
106 | force_loss = cace.tasks.GetLoss(
107 | target_name='forces',
108 | predict_name='CACE_forces',
109 | loss_fn=torch.nn.MSELoss(),
110 | loss_weight=1000
111 | )
112 |
113 | from cace.tools import Metrics
114 |
115 | e_metric = Metrics(
116 | target_name='energy',
117 | predict_name='CACE_energy',
118 | name='e/atom',
119 | per_atom=True
120 | )
121 |
122 | f_metric = Metrics(
123 | target_name='forces',
124 | predict_name='CACE_forces',
125 | name='f'
126 | )
127 |
128 | # Example usage
129 | logging.info("creating training task")
130 |
131 | optimizer_args = {'lr': 1e-2}
132 | scheduler_args = {'mode': 'min', 'factor': 0.8, 'patience': 10}
133 |
134 | for i in range(5):
135 | task = TrainingTask(
136 | model=cace_nnp,
137 | losses=[energy_loss, force_loss],
138 | metrics=[e_metric, f_metric],
139 | device=device,
140 | optimizer_args=optimizer_args,
141 | #scheduler_cls=torch.optim.lr_scheduler.StepLR,
142 | scheduler_cls=torch.optim.lr_scheduler.ReduceLROnPlateau,
143 | scheduler_args=scheduler_args,
144 | max_grad_norm=10,
145 | ema=True,
146 | ema_start=10,
147 | warmup_steps=5,
148 | )
149 |
150 | logging.info("training")
151 | task.fit(train_loader, valid_loader, epochs=40, screen_nan=False)
152 |
153 | task.save_model('water-model.pth')
154 | cace_nnp.to(device)
155 |
156 | logging.info(f"Second train loop:")
157 | energy_loss = cace.tasks.GetLoss(
158 | target_name='energy',
159 | predict_name='CACE_energy',
160 | loss_fn=torch.nn.MSELoss(),
161 | loss_weight=100
162 | )
163 |
164 | task.update_loss([energy_loss, force_loss])
165 | logging.info("training")
166 | task.fit(train_loader, valid_loader, epochs=100, screen_nan=False)
167 |
168 |
169 | task.save_model('water-model-2.pth')
170 | cace_nnp.to(device)
171 |
172 | logging.info(f"Finished")
173 |
174 |
175 |
176 |
177 |
178 |
--------------------------------------------------------------------------------
/examples/water_train_pl.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import torch
4 | from cace.tasks import LightningData, LightningTrainingTask
5 |
6 | on_cluster = False
7 | if 'SLURM_JOB_CPUS_PER_NODE' in os.environ.keys():
8 | on_cluster = True
9 | if on_cluster:
10 | root = "/global/scratch/users/king1305/data/water.xyz"
11 | else:
12 | root = "/home/king1305/Apps/cacefit/fit-water/water.xyz"
13 |
14 | logs_directory = "lightning_logs"
15 | logs_name = "water_test"
16 | cutoff = 5.5
17 | avge0 = {1: -187.6043857100553, 8: -93.80219285502734}
18 | batch_size = 4
19 | training_epochs = 500
20 | data = LightningData(root,batch_size=batch_size,cutoff=cutoff,atomic_energies=avge0)
21 |
22 | from cace.representations import Cace
23 | from cace.modules import BesselRBF, GaussianRBF, GaussianRBFCentered
24 | from cace.modules import PolynomialCutoff
25 |
26 | #Model
27 | radial_basis = BesselRBF(cutoff=cutoff, n_rbf=6, trainable=True)
28 | cutoff_fn = PolynomialCutoff(cutoff=cutoff)
29 |
30 | representation = Cace(
31 | zs=[1,8],
32 | n_atom_basis=3,
33 | embed_receiver_nodes=True,
34 | cutoff=cutoff,
35 | cutoff_fn=cutoff_fn,
36 | radial_basis=radial_basis,
37 | n_radial_basis=12,
38 | max_l=3,
39 | max_nu=3,
40 | num_message_passing=1,
41 | type_message_passing=['Bchi'],
42 | args_message_passing={'Bchi': {'shared_channels': False, 'shared_l': False}},
43 | avg_num_neighbors=1,
44 | timeit=False
45 | )
46 |
47 | for batch in data.train_dataloader():
48 | exdatabatch = batch
49 | break
50 |
51 | from cace.models import NeuralNetworkPotential
52 | from cace.modules.atomwise import Atomwise
53 | from cace.modules.forces import Forces
54 |
55 | atomwise = Atomwise(n_layers=3,
56 | output_key="pred_energy",
57 | n_hidden=[32,16],
58 | n_out=1,
59 | use_batchnorm=False,
60 | add_linear_nn=True)
61 |
62 | forces = Forces(energy_key="pred_energy",
63 | forces_key="pred_force")
64 |
65 | model = NeuralNetworkPotential(
66 | input_modules=None,
67 | representation=representation,
68 | output_modules=[atomwise,forces]
69 | )
70 |
71 | #Losses
72 | from cace.tasks import GetLoss
73 | e_loss = GetLoss(
74 | target_name="energy",
75 | predict_name='pred_energy',
76 | loss_fn=torch.nn.MSELoss(),
77 | loss_weight=1,
78 | )
79 | f_loss = GetLoss(
80 | target_name="force",
81 | predict_name='pred_force',
82 | loss_fn=torch.nn.MSELoss(),
83 | loss_weight=1000,
84 | )
85 | losses = [e_loss,f_loss]
86 |
87 | #Metrics
88 | from cace.tools import Metrics
89 | e_metric = Metrics(
90 | target_name="energy",
91 | predict_name='pred_energy',
92 | name='e',
93 | metric_keys=["rmse"],
94 | per_atom=True,
95 | )
96 | f_metric = Metrics(
97 | target_name="force",
98 | predict_name='pred_force',
99 | metric_keys=["rmse"],
100 | name='f',
101 | )
102 | metrics = [e_metric,f_metric]
103 |
104 | #Init lazy layers
105 | for batch in data.train_dataloader():
106 | exdatabatch = batch
107 | break
108 | model(exdatabatch)
109 |
110 | #Check for checkpoint and restart if found:
111 | chkpt = None
112 | dev_run = False
113 | if os.path.isdir(f"lightning_logs/{logs_name}"):
114 | latest_version = None
115 | num = 0
116 | while os.path.isdir(f"lightning_logs/{logs_name}/version_{num}"):
117 | latest_version = f"lightning_logs/{logs_name}/version_{num}"
118 | num += 1
119 | if latest_version:
120 |         chkpt = next(iter(glob.glob(f"{latest_version}/checkpoints/*.ckpt")), None)
121 | if chkpt:
122 | print("Checkpoint found!",chkpt)
123 | print("Restarting...")
124 | dev_run = False
125 |
126 | progress_bar = True
127 | if on_cluster:
128 | torch.set_float32_matmul_precision('medium')
129 | progress_bar = False
130 | task = LightningTrainingTask(model,losses=losses,metrics=metrics,
131 | logs_directory="lightning_logs",name=logs_name,
132 | scheduler_args={'mode': 'min', 'factor': 0.8, 'patience': 10},
133 | optimizer_args={'lr': 0.01},
134 | )
135 | task.fit(data,dev_run=dev_run,max_epochs=training_epochs,chkpt=chkpt, progress_bar=progress_bar)
136 |
137 |
--------------------------------------------------------------------------------
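Once `task.fit` returns, a quick sanity check is to push a single validation batch through the network and inspect the predicted keys. The sketch below assumes `LightningData` exposes the standard `val_dataloader()` hook of a Lightning data module; whether `model` already holds the trained weights at this point depends on how `LightningTrainingTask` wraps it, so treat it as illustrative only.

```python
# Hedged sanity check, reusing `data` and `model` from the script above.
# Assumption: LightningData provides the usual val_dataloader() hook.
batch = next(iter(data.val_dataloader()))
out = model(batch)
print(out['pred_energy'].shape, out['pred_force'].shape)
```
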
/examples/water_train_w_ft_pl.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import torch
4 | from cace.tasks import LightningData, LightningTrainingTask
5 |
6 | on_cluster = False
7 | if 'SLURM_JOB_CPUS_PER_NODE' in os.environ.keys():
8 | on_cluster = True
9 | if on_cluster:
10 | root = "/global/scratch/users/king1305/data/water.xyz"
11 | else:
12 | root = "/home/king1305/Apps/cacefit/fit-water/water.xyz"
13 |
14 | logs_directory = "lightning_logs"
15 | logs_name = "water_test"
16 | cutoff = 5.5
17 | avge0 = {1: -187.6043857100553, 8: -93.80219285502734}
18 | batch_size = 4
19 | data = LightningData(root,batch_size=batch_size,cutoff=cutoff,atomic_energies=avge0)
20 |
21 | training_epochs = 500
22 | tuning_epochs = 100
23 |
24 | from cace.representations import Cace
25 | from cace.modules import BesselRBF, GaussianRBF, GaussianRBFCentered
26 | from cace.modules import PolynomialCutoff
27 |
28 | #Model
29 | radial_basis = BesselRBF(cutoff=cutoff, n_rbf=6, trainable=True)
30 | cutoff_fn = PolynomialCutoff(cutoff=cutoff)
31 |
32 | representation = Cace(
33 | zs=[1,8],
34 | n_atom_basis=3,
35 | embed_receiver_nodes=True,
36 | cutoff=cutoff,
37 | cutoff_fn=cutoff_fn,
38 | radial_basis=radial_basis,
39 | n_radial_basis=12,
40 | max_l=3,
41 | max_nu=3,
42 | num_message_passing=1,
43 | type_message_passing=['Bchi'],
44 | args_message_passing={'Bchi': {'shared_channels': False, 'shared_l': False}},
45 | avg_num_neighbors=1,
46 | timeit=False
47 | )
48 |
49 | for batch in data.train_dataloader():
50 | exdatabatch = batch
51 | break
52 |
53 | from cace.models import NeuralNetworkPotential
54 | from cace.modules.atomwise import Atomwise
55 | from cace.modules.forces import Forces
56 |
57 | atomwise = Atomwise(n_layers=3,
58 | output_key="pred_energy",
59 | n_hidden=[32,16],
60 | n_out=1,
61 | use_batchnorm=False,
62 | add_linear_nn=True)
63 |
64 | forces = Forces(energy_key="pred_energy",
65 | forces_key="pred_force")
66 |
67 | model = NeuralNetworkPotential(
68 | input_modules=None,
69 | representation=representation,
70 | output_modules=[atomwise,forces]
71 | )
72 |
73 | #Losses
74 | from cace.tasks import GetLoss
75 | e_loss = GetLoss(
76 | target_name="energy",
77 | predict_name='pred_energy',
78 | loss_fn=torch.nn.MSELoss(),
79 | loss_weight=1,
80 | )
81 | f_loss = GetLoss(
82 | target_name="force",
83 | predict_name='pred_force',
84 | loss_fn=torch.nn.MSELoss(),
85 | loss_weight=1000,
86 | )
87 | losses = [e_loss,f_loss]
88 |
89 | #Metrics
90 | from cace.tools import Metrics
91 | e_metric = Metrics(
92 | target_name="energy",
93 | predict_name='pred_energy',
94 | name='e',
95 | metric_keys=["rmse"],
96 | per_atom=True,
97 | )
98 | f_metric = Metrics(
99 | target_name="force",
100 | predict_name='pred_force',
101 | metric_keys=["rmse"],
102 | name='f',
103 | )
104 | metrics = [e_metric,f_metric]
105 |
106 | #Init lazy layers
107 | for batch in data.train_dataloader():
108 | exdatabatch = batch
109 | break
110 | model(exdatabatch)
111 |
112 | #Check for checkpoint and restart if found:
113 | chkpt = None
114 | dev_run = False
115 | if os.path.isdir(f"lightning_logs/{logs_name}"):
116 | latest_version = None
117 | num = 0
118 | while os.path.isdir(f"lightning_logs/{logs_name}/version_{num}"):
119 | latest_version = f"lightning_logs/{logs_name}/version_{num}"
120 | num += 1
121 | if latest_version:
122 |         chkpt = next(iter(glob.glob(f"{latest_version}/checkpoints/*.ckpt")), None)
123 | if chkpt:
124 | print("Checkpoint found!",chkpt)
125 | print("Restarting...")
126 | dev_run = False
127 |
128 | progress_bar = True
129 | if on_cluster:
130 | torch.set_float32_matmul_precision('medium')
131 | progress_bar = False
132 | task = LightningTrainingTask(model,losses=losses,metrics=metrics,
133 | logs_directory="lightning_logs",name=logs_name,
134 | scheduler_args={'mode': 'min', 'factor': 0.8, 'patience': 10},
135 | optimizer_args={'lr': 0.01},
136 | )
137 | task.fit(data,dev_run=dev_run,max_epochs=training_epochs,chkpt=chkpt, progress_bar=progress_bar)
138 |
139 | #If you want to do the "fine-tuning" w/higher energy loss:
140 | if tuning_epochs > 0:
141 | os.system(f"mv lightning_logs/{logs_name}/best_model.pth lightning_logs/{logs_name}/best_model_noft.pth")
142 | e_loss = GetLoss(
143 | target_name="energy",
144 | predict_name='pred_energy',
145 | loss_fn=torch.nn.MSELoss(),
146 | loss_weight=1000,
147 | )
148 | losses = [e_loss,f_loss]
149 |
150 | #Search for checkpoint
151 | latest_version = None
152 | num = 0
153 | while os.path.isdir(f"lightning_logs/{logs_name}/version_{num}"):
154 | latest_version = f"lightning_logs/{logs_name}/version_{num}"
155 | num += 1
156 | if latest_version:
157 |         chkpt = next(iter(glob.glob(f"{latest_version}/checkpoints/*.ckpt")), None)
158 | task = LightningTrainingTask(model,losses=losses,metrics=metrics,
159 | logs_directory="lightning_logs",name=logs_name,
160 | scheduler_args={'mode': 'min', 'factor': 0.8, 'patience': 10},
161 | optimizer_args={'lr': 0.01},
162 | )
163 | #Uses previous checkpoint params + lr:
164 | task.fit(data,dev_run=dev_run,max_epochs=training_epochs+tuning_epochs,chkpt=chkpt,
165 | progress_bar=progress_bar)
166 |     # alternatively, restart the trainer from scratch for the tuning phase:
167 | #task.fit(data,dev_run=dev_run,max_epochs=tuning_epochs,
168 | # progress_bar=progress_bar)
169 |
170 |
171 |
--------------------------------------------------------------------------------
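The fine-tuning branch above first renames the existing `lightning_logs/<name>/best_model.pth`. To reload such a saved model later, the sketch below covers both storage conventions, since it is an assumption here whether `best_model.pth` holds the full serialized module or only a state dict.

```python
# Hedged reload sketch, reusing `model` and `logs_name` from the script above.
import torch

obj = torch.load(f"lightning_logs/{logs_name}/best_model.pth", map_location='cpu')
if isinstance(obj, dict):   # looks like a state_dict
    model.load_state_dict(obj)
    reloaded = model
else:                       # a fully serialized module
    reloaded = obj
reloaded.eval()
```
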
/scripts/cace_alchemical_rep_v0.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BingqingCheng/cace/82af6fc87733d2f4bf729754313cf77f1293ce72/scripts/cace_alchemical_rep_v0.pth
--------------------------------------------------------------------------------
/scripts/compute_avg_e0.py:
--------------------------------------------------------------------------------
1 | import cace
2 | import pickle
3 | from ase.io import read
4 | import sys
5 |
6 | if len(sys.argv) != 3:
7 | print('Usage: python compute_avg_e0.py xyzfile stride')
8 | sys.exit()
9 |
10 | stride = int(sys.argv[2])
11 |
12 | # read the xyz file and compute the average E0s
13 | xyz = read(sys.argv[1], index=slice(0, None, stride))
14 | avge0 = cace.tools.compute_average_E0s(xyz)
15 |
16 | print('Average E0s:', avge0)
17 | # save the avge0 dict to a file
18 | with open('avge0.pkl', 'wb') as f:
19 | pickle.dump(avge0, f)
20 |
--------------------------------------------------------------------------------
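The pickled dictionary maps atomic numbers to average isolated-atom energies; `scripts/train.py` reads `avge0.pkl` directly, and the Lightning examples pass such a dictionary as `atomic_energies`. A minimal reuse sketch, with the data path and hyperparameters as placeholders:

```python
# Hedged reuse sketch: feed the saved E0 dictionary into a data module,
# mirroring examples/water_train_pl.py. Path, batch size and cutoff are placeholders.
import pickle
from cace.tasks import LightningData

with open('avge0.pkl', 'rb') as f:
    avge0 = pickle.load(f)   # e.g. {1: ..., 8: ...} for water

data = LightningData('water.xyz', batch_size=4, cutoff=5.5, atomic_energies=avge0)
```
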
/scripts/compute_cace_desc.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | from tqdm import tqdm
4 | import numpy as np
5 | import torch
6 | import ase
7 | from ase import Atoms
8 | from ase.io import read,write
9 |
10 | import cace
11 | from cace.data import AtomicData
12 | from cace.representations.cace_representation import Cace
13 | from cace.tools import to_numpy
14 | from cace.tools import scatter_sum
15 |
16 | cutoff = 4.0
17 | batch_size = 10
18 |
19 |
20 | # Get the directory of the current script
21 | script_dir = os.path.dirname(os.path.abspath(__file__))
22 | data_file_path = os.path.join(script_dir, 'cace_alchemical_rep_v0.pth')
23 | cace_repr = torch.load(data_file_path, map_location='cpu')
24 |
25 | data = read(sys.argv[1], ":")
26 | dataset=[
27 | AtomicData.from_atoms(
28 | atom, cutoff=cutoff
29 | )
30 | for atom in data
31 | ]
32 |
33 | data_loader = cace.tools.torch_geometric.dataloader.DataLoader(
34 | dataset,
35 | batch_size=batch_size,
36 | shuffle=False,
37 | drop_last=False,
38 | )
39 |
40 | n_frame = 0
41 | for sampled_data in tqdm(data_loader):
42 | cace_result_more = cace_repr(sampled_data)
43 | avg_B = scatter_sum(
44 | src=cace_result_more['node_feats'],
45 | index=sampled_data["batch"],
46 | dim=0
47 | )
48 | #print(avg_B.shape)
49 | #print(torch.bincount(sampled_data['batch']))
50 | n_configs = avg_B.shape[0]
51 | avg_B_flat = to_numpy(avg_B.reshape(n_configs, -1) / torch.bincount(sampled_data['batch']).reshape(-1, 1))
52 | #print(avg_B_flat.shape)
53 | for i in range(n_configs):
54 | data[i+n_frame].info['CACE_desc'] = avg_B_flat[i]
55 | n_frame += n_configs
56 |
57 | # check if sys.argv[2] exists
58 | if len(sys.argv) > 2:
59 | prefix = sys.argv[2]
60 | else:
61 | prefix = 'CACE_desc'
62 | write(prefix+'.xyz', data)
63 |
64 |
65 |
66 |
67 |
68 |
69 |
--------------------------------------------------------------------------------
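The script stores one flattened descriptor vector per configuration in `atoms.info['CACE_desc']` and writes an extended-XYZ file named after the chosen prefix. Collecting those descriptors back into a single array, e.g. for clustering or dimensionality reduction, could look like the sketch below (the file name matches the default prefix above).

```python
# Hedged follow-up sketch: gather the per-frame descriptors written above
# into an (n_frames, n_features) array.
import numpy as np
from ase.io import read

frames = read('CACE_desc.xyz', ':')
X = np.stack([np.asarray(frame.info['CACE_desc']) for frame in frames])
print(X.shape)
```
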
/scripts/split.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import pickle
3 | import torch
4 | import numpy as np
5 | import sys
6 | import argparse
7 |
8 | from ase.io import read
9 | from cace.data import AtomicData
10 | from cace.tools import save_dataset
11 |
12 | def main():
13 | # Parse the command line arguments
14 | parser = argparse.ArgumentParser(description='Split a dataset into multiple parts')
15 | parser.add_argument('--input_file', type=str, help='Path to the input dataset')
16 | parser.add_argument('--num_splits', type=int, help='Number of splits', default=4)
17 | parser.add_argument('--output_prefix', type=str, help='Prefix for the output files', default='split')
18 | parser.add_argument('--shuffle', action='store_true', help='Shuffle the dataset before splitting')
19 | parser.add_argument('--cutoff', type=float, help='Cutoff radius for the atomic environment')
20 | parser.add_argument('--energy', type=str, help='Key for the energy data', default='energy')
21 | parser.add_argument('--forces', type=str, help='Key for the forces data', default='forces')
22 | parser.add_argument('--atomic_energies', type=str, help='file for the atomic energies', default=None)
23 | args = parser.parse_args()
24 |
25 | all_xyz_path = args.input_file
26 | num_splits = args.num_splits
27 | cutoff = args.cutoff
28 | data_key = {'energy': args.energy, 'forces': args.forces}
29 | atomic_energies = pickle.load(open(args.atomic_energies, 'rb')) if args.atomic_energies is not None else None
30 |
31 | for i in range(num_splits):
32 | all_xyz = read(all_xyz_path, index=slice(i, None, num_splits))
33 |
34 | dataset=[
35 | AtomicData.from_atoms(atoms, cutoff=cutoff, data_key=data_key, atomic_energies=atomic_energies).to_dict() # Convert to dictionary
36 | for atoms in all_xyz
37 | ]
38 |
39 |         shuffle = args.shuffle # NOTE: parsed but not applied here; frames are saved in read order
40 | save_dataset(dataset, args.output_prefix+'_'+str(i)+'.h5')
41 |
42 | if __name__ == '__main__':
43 | main()
44 |
45 | """
46 | load the dataset
47 | # Load the dataset
48 | loaded_dataset = load_dataset('test.h5')
49 | print(loaded_dataset[0].keys())
50 |
51 | dataset = [
52 | AtomicData(**data) # Convert to AtomicData object
53 | for data in loaded_dataset
54 | ]
55 | """
56 |
--------------------------------------------------------------------------------
/scripts/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import logging
3 | import ase.io
4 | import cace
5 | import pickle
6 | import os
7 | from cace.representations import Cace
8 | from cace.modules import PolynomialCutoff, BesselRBF, Atomwise, Forces
9 | from cace.models.atomistic import NeuralNetworkPotential
10 | from cace.tasks.train import TrainingTask, GetLoss
11 | from cace.tools import Metrics, init_device, compute_average_E0s, setup_logger, get_unique_atomic_number
12 | from cace.tools import parse_arguments
13 |
14 | def main():
15 | args = parse_arguments()
16 |
17 | setup_logger(level='INFO', tag=args.prefix, directory='./')
18 | device = init_device(args.use_device)
19 |
20 | if args.zs is None:
21 | xyz = ase.io.read(args.train_path, ':')
22 | args.zs = get_unique_atomic_number(xyz)
23 |
24 | # load the avge0 dict from a file if possible
25 | if os.path.exists('avge0.pkl'):
26 | with open('avge0.pkl', 'rb') as f:
27 | avge0 = pickle.load(f)
28 | else:
29 |         # otherwise compute the average E0s from the training set
30 |         avge0 = compute_average_E0s(ase.io.read(args.train_path, ':'))  # xyz may not have been read above
31 | with open('avge0.pkl', 'wb') as f:
32 | pickle.dump(avge0, f)
33 |
34 | # Prepare Data Loaders
35 | collection = cace.tasks.get_dataset_from_xyz(
36 | train_path=args.train_path,
37 | valid_fraction=args.valid_fraction,
38 | data_key={'energy': args.energy_key, 'forces': args.forces_key},
39 | atomic_energies=avge0,
40 | cutoff=args.cutoff)
41 |
42 | train_loader = cace.tasks.load_data_loader(
43 | collection=collection,
44 | data_type='train',
45 | batch_size=args.batch_size)
46 |
47 | valid_loader = cace.tasks.load_data_loader(
48 | collection=collection,
49 | data_type='valid',
50 | batch_size=args.valid_batch_size)
51 |
52 | # Configure CACE Representation
53 | cutoff_fn = PolynomialCutoff(cutoff=args.cutoff, p=args.cutoff_fn_p)
54 | radial_basis = BesselRBF(cutoff=args.cutoff, n_rbf=args.n_rbf, trainable=args.trainable_rbf)
55 | cace_representation = Cace(
56 | zs=args.zs, n_atom_basis=args.n_atom_basis, embed_receiver_nodes=args.embed_receiver_nodes,
57 | cutoff=args.cutoff, cutoff_fn=cutoff_fn, radial_basis=radial_basis,
58 | n_radial_basis=args.n_radial_basis, max_l=args.max_l, max_nu=args.max_nu,
59 | device=device, num_message_passing=args.num_message_passing)
60 |
61 | # Configure Atomwise Module
62 | atomwise = Atomwise(
63 | n_layers=args.atomwise_layers, n_hidden=args.atomwise_hidden, residual=args.atomwise_residual,
64 | use_batchnorm=args.atomwise_batchnorm, add_linear_nn=args.atomwise_linear_nn,
65 | output_key='CACE_energy')
66 |
67 | # Configure Forces Module
68 | forces = Forces(energy_key='CACE_energy', forces_key='CACE_forces')
69 |
70 | # Assemble Neural Network Potential
71 | cace_nnp = NeuralNetworkPotential(representation=cace_representation, output_modules=[atomwise, forces]).to(device)
72 |
73 | # Phase 1 Training Configuration
74 | optimizer_args = {'lr': args.lr}
75 | scheduler_args = {'mode': 'min', 'factor': args.scheduler_factor, 'patience': args.scheduler_patience}
76 | energy_loss = GetLoss(
77 | target_name='energy',
78 | predict_name='CACE_energy',
79 | loss_fn=torch.nn.MSELoss(),
80 | loss_weight=args.energy_loss_weight)
81 | force_loss = GetLoss(
82 | target_name='forces',
83 | predict_name='CACE_forces',
84 | loss_fn=torch.nn.MSELoss(),
85 | loss_weight=args.force_loss_weight)
86 |
87 |
88 | e_metric = Metrics(
89 | target_name='energy',
90 | predict_name='CACE_energy',
91 | name='e/atom',
92 | per_atom=True
93 | )
94 |
95 | f_metric = Metrics(
96 | target_name='forces',
97 | predict_name='CACE_forces',
98 | name='f'
99 | )
100 |
101 | for _ in range(args.num_restart):
102 | # Initialize and Fit Training Task for Phase 1
103 | task = TrainingTask(
104 | model=cace_nnp, losses=[energy_loss, force_loss], metrics=[e_metric, f_metric],
105 | device=device, optimizer_args=optimizer_args, scheduler_cls=torch.optim.lr_scheduler.ReduceLROnPlateau,
106 | scheduler_args=scheduler_args, max_grad_norm=args.max_grad_norm, ema=args.ema,
107 | ema_start=args.ema_start, warmup_steps=args.warmup_steps)
108 |
109 | task.fit(train_loader, valid_loader, epochs=int(args.epochs/args.num_restart), print_stride=0)
110 | task.save_model(args.prefix+'_phase_1.pth')
111 |
112 | # Phase 2 Training Adjustment
113 | energy_loss_2 = GetLoss('energy', 'CACE_energy', torch.nn.MSELoss(), args.second_phase_energy_loss_weight)
114 | task.update_loss([energy_loss_2, force_loss])
115 |
116 | # Fit Training Task for Phase 2
117 | task.fit(train_loader, valid_loader, epochs=args.second_phase_epochs)
118 | task.save_model(args.prefix+'_phase_2.pth')
119 |
120 | if __name__ == '__main__':
121 | main()
122 |
123 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='CACE',
5 | version='0.1.0',
6 | author='Bingqing Cheng',
7 | author_email='tonicbq@gmail.com',
8 | description='Cartesian Atomic Cluster Expansion Machine Learning Potential',
9 | packages=find_packages(),
10 | install_requires=[
11 | 'numpy',
12 | 'ase<=3.22.1',
13 | 'torch',
14 | 'matscipy',
15 | ],
16 | classifiers=[
17 | 'Programming Language :: Python :: 3',
18 | 'License :: OSI Approved :: MIT License',
19 | 'Operating System :: OS Independent',
20 | ],
21 | python_requires='>=3.6',
22 | )
23 |
24 |
--------------------------------------------------------------------------------
/tests/ewald_nopbc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import cace
3 | from cace.modules import EwaldPotential
4 | import sys
5 | torch.set_default_dtype(torch.float32)
6 |
7 | ep = EwaldPotential(dl=1.,
8 | sigma=2,
9 | exponent=1,
10 | feature_key='q',
11 | aggregation_mode='sum',
12 | remove_self_interaction=True,
13 | compute_field=True)
14 |
15 | # set the same random seed for reproducibility
16 | torch.manual_seed(int(sys.argv[1]))
17 | r = torch.rand(10, 3) * 8 # Random positions in an 8x8x8 region
18 | q = torch.rand(10) * 2 # Random charges
19 |
20 | q -= torch.mean(q)
21 | box = torch.tensor([30.0, 30.0, 30.0], dtype=torch.float32) # Box dimensions
22 |
23 | #print(q)
24 | #exit()
25 |
26 | # Compare the periodic Ewald routine with the real-space (no-PBC) evaluation
27 |
28 | ew_1, field_1 = ep.compute_potential_optimized(torch.tensor(r), torch.tensor(q).unsqueeze(1), torch.tensor(box), compute_field=True)
29 | ew_1_s, field_1_s = ep.compute_potential_realspace(torch.tensor(r), torch.tensor(q), compute_field=True)
30 | print(ew_1, ew_1_s)
31 | print(ew_1.shape, ew_1_s.shape)
32 | print(field_1, field_1_s)
33 | print(field_1.shape, field_1_s.shape)
34 | print(torch.sum(q.unsqueeze(1) * field_1 / 2), torch.sum(q.unsqueeze(1) * field_1_s / 2))
35 |
36 |
--------------------------------------------------------------------------------
/tests/test-cace-representation-rotation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "8e21a83f",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "%load_ext autoreload\n",
11 | "%autoreload 2"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": null,
17 | "id": "6f361918",
18 | "metadata": {},
19 | "outputs": [],
20 | "source": [
21 | "import os\n",
22 | "os.environ['KMP_DUPLICATE_LIB_OK']='True'"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": null,
28 | "id": "7ab78b11",
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "import numpy as np\n",
33 | "import torch\n",
34 | "import torch.nn as nn"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": null,
40 | "id": "6c87d406",
41 | "metadata": {},
42 | "outputs": [],
43 | "source": [
44 | "import sys\n",
45 | "sys.path.append('../')"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": null,
51 | "id": "11273bc5",
52 | "metadata": {},
53 | "outputs": [],
54 | "source": [
55 | "import cace\n",
56 | "from cace import data"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": null,
62 | "id": "080d6fbd",
63 | "metadata": {},
64 | "outputs": [],
65 | "source": [
66 | "from scipy.spatial.transform import Rotation as R"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": null,
72 | "id": "810cd94b",
73 | "metadata": {},
74 | "outputs": [],
75 | "source": [
76 | "torch.set_default_dtype(torch.float64)"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": null,
82 | "id": "dc966daf",
83 | "metadata": {},
84 | "outputs": [],
85 | "source": [
86 | "from ase import Atoms"
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": null,
92 | "id": "61e728d9",
93 | "metadata": {},
94 | "outputs": [],
95 | "source": [
96 | "atom=Atoms(numbers=np.array([8, 1, 1, 1]),\n",
97 | " positions=np.array(\n",
98 | " [\n",
99 | " [0.0452, -2.02, 0.0452],\n",
100 | " [1.0145, 0.034, 0.0232],\n",
101 | " [0.0111, 1.041, -0.010],\n",
102 | " [-0.0111, -0.041, 0.510],\n",
103 | " ]\n",
104 | " ),\n",
105 | " pbc=False)\n",
106 | "\n",
107 | "# Created the rotated environment\n",
108 | "rot = R.from_euler(\"z\", 70, degrees=True).as_matrix()\n",
109 | "positions_rotated = np.array(rot @ atom.positions.T).T\n",
110 | "\n",
111 | "rot = R.from_euler(\"x\", 10.6, degrees=True).as_matrix()\n",
112 | "positions_rotated = np.array(rot @ positions_rotated.T).T\n",
113 | "\n",
114 | "rot = R.from_euler(\"y\", 190, degrees=True).as_matrix()\n",
115 | "positions_rotated = np.array(rot @ positions_rotated.T).T\n",
116 | "\n",
117 | "atom_rotated=Atoms(numbers=np.array([8, 1, 1, 1]),\n",
118 | " positions=positions_rotated,\n",
119 | " pbc=False)"
120 | ]
121 | },
122 | {
123 | "cell_type": "code",
124 | "execution_count": null,
125 | "id": "3aa4259a",
126 | "metadata": {},
127 | "outputs": [],
128 | "source": [
129 | "cutoff = 5.0"
130 | ]
131 | },
132 | {
133 | "cell_type": "code",
134 | "execution_count": null,
135 | "id": "4d068011",
136 | "metadata": {},
137 | "outputs": [],
138 | "source": [
139 | "atomic_data = data.AtomicData.from_atoms(atom, cutoff=cutoff)\n",
140 | "atomic_data2 = data.AtomicData.from_atoms(atom_rotated, cutoff=cutoff)"
141 | ]
142 | },
143 | {
144 | "cell_type": "code",
145 | "execution_count": null,
146 | "id": "d2f316e0",
147 | "metadata": {},
148 | "outputs": [],
149 | "source": [
150 | "atomic_data.positions"
151 | ]
152 | },
153 | {
154 | "cell_type": "code",
155 | "execution_count": null,
156 | "id": "02c3623d",
157 | "metadata": {},
158 | "outputs": [],
159 | "source": [
160 | "atomic_data2.positions"
161 | ]
162 | },
163 | {
164 | "cell_type": "code",
165 | "execution_count": null,
166 | "id": "cc49098a",
167 | "metadata": {},
168 | "outputs": [],
169 | "source": [
170 | "from cace.representations.cace_representation import Cace"
171 | ]
172 | },
173 | {
174 | "cell_type": "code",
175 | "execution_count": null,
176 | "id": "5de73a93",
177 | "metadata": {},
178 | "outputs": [],
179 | "source": [
180 | "atomic_data"
181 | ]
182 | },
183 | {
184 | "cell_type": "code",
185 | "execution_count": null,
186 | "id": "6fdb68fe",
187 | "metadata": {},
188 | "outputs": [],
189 | "source": [
190 | "from cace.modules import CosineCutoff, MollifierCutoff, PolynomialCutoff\n",
191 | "from cace.modules import BesselRBF, GaussianRBF, GaussianRBFCentered, ExponentialDecayRBF"
192 | ]
193 | },
194 | {
195 | "cell_type": "code",
196 | "execution_count": null,
197 | "id": "d9632c86",
198 | "metadata": {},
199 | "outputs": [],
200 | "source": [
201 | "radial_basis = BesselRBF(cutoff=cutoff, n_rbf=5, trainable=False)\n",
202 | "#radial_basis = ExponentialDecayRBF(n_rbf=4, cutoff=cutoff, prefactor=1, trainable=True)\n",
203 | "cutoff_fn = CosineCutoff(cutoff=cutoff)"
204 | ]
205 | },
206 | {
207 | "cell_type": "code",
208 | "execution_count": null,
209 | "id": "a6f21fb8",
210 | "metadata": {},
211 | "outputs": [],
212 | "source": [
213 | "cace_representation = Cace(\n",
214 | " zs=[1,8],\n",
215 | " n_atom_basis=2,\n",
216 | " cutoff=cutoff,\n",
217 | " cutoff_fn=cutoff_fn,\n",
218 | " radial_basis=radial_basis,\n",
219 | " max_l=6,\n",
220 | " max_nu=4,\n",
221 | " num_message_passing=2,\n",
222 | " timeit=True\n",
223 | " )"
224 | ]
225 | },
226 | {
227 | "cell_type": "code",
228 | "execution_count": null,
229 | "id": "e92da60c",
230 | "metadata": {},
231 | "outputs": [],
232 | "source": [
233 | "cace_result = cace_representation(atomic_data)"
234 | ]
235 | },
236 | {
237 | "cell_type": "code",
238 | "execution_count": null,
239 | "id": "32c0bdcb",
240 | "metadata": {},
241 | "outputs": [],
242 | "source": [
243 | "cace_result2 = cace_representation(atomic_data2)"
244 | ]
245 | },
246 | {
247 | "cell_type": "code",
248 | "execution_count": null,
249 | "id": "6a65daec",
250 | "metadata": {},
251 | "outputs": [],
252 | "source": [
253 | "features = cace_result['node_feats']"
254 | ]
255 | },
256 | {
257 | "cell_type": "code",
258 | "execution_count": null,
259 | "id": "39e4436f",
260 | "metadata": {},
261 | "outputs": [],
262 | "source": [
263 | "features2 = cace_result2['node_feats']"
264 | ]
265 | },
266 | {
267 | "cell_type": "code",
268 | "execution_count": null,
269 | "id": "55cfd8e5",
270 | "metadata": {},
271 | "outputs": [],
272 | "source": [
273 | "torch.allclose(features, features2, rtol=1e-05, atol=1e-05)"
274 | ]
275 | },
276 | {
277 | "cell_type": "code",
278 | "execution_count": null,
279 | "id": "582511b2",
280 | "metadata": {},
281 | "outputs": [],
282 | "source": [
283 | "features.shape"
284 | ]
285 | },
286 | {
287 | "cell_type": "code",
288 | "execution_count": null,
289 | "id": "425f0a9e",
290 | "metadata": {},
291 | "outputs": [],
292 | "source": [
293 | "features[0,2,:,0,0]"
294 | ]
295 | },
296 | {
297 | "cell_type": "code",
298 | "execution_count": null,
299 | "id": "88a1895e",
300 | "metadata": {},
301 | "outputs": [],
302 | "source": [
303 | "features[0,1,:,0,1]"
304 | ]
305 | },
306 | {
307 | "cell_type": "code",
308 | "execution_count": null,
309 | "id": "5f7d3b23",
310 | "metadata": {},
311 | "outputs": [],
312 | "source": [
313 | "features[0,1,:,0,2]"
314 | ]
315 | },
316 | {
317 | "cell_type": "code",
318 | "execution_count": null,
319 | "id": "21467404",
320 | "metadata": {},
321 | "outputs": [],
322 | "source": [
323 | "features2[0,1,:,0,1]"
324 | ]
325 | },
326 | {
327 | "cell_type": "code",
328 | "execution_count": null,
329 | "id": "f8619177",
330 | "metadata": {},
331 | "outputs": [],
332 | "source": []
333 | }
334 | ],
335 | "metadata": {
336 | "kernelspec": {
337 | "display_name": "Python 3 (ipykernel)",
338 | "language": "python",
339 | "name": "python3"
340 | },
341 | "language_info": {
342 | "codemirror_mode": {
343 | "name": "ipython",
344 | "version": 3
345 | },
346 | "file_extension": ".py",
347 | "mimetype": "text/x-python",
348 | "name": "python",
349 | "nbconvert_exporter": "python",
350 | "pygments_lexer": "ipython3",
351 | "version": "3.8.16"
352 | }
353 | },
354 | "nbformat": 4,
355 | "nbformat_minor": 5
356 | }
357 |
--------------------------------------------------------------------------------
/tests/test_angular.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../')
3 | import time
4 | import torch
5 | from cace.modules import AngularComponent, AngularComponent_GPU
6 |
7 | vectors = torch.rand(10000, 3)
8 |
9 | start_time = time.time()
10 | angular_func = AngularComponent(3)
11 | angular_component = angular_func(vectors)
12 | end_time = time.time()
13 | print(f"Execution time AngularComponent function: {end_time - start_time} seconds")
14 |
15 |
16 | start_time = time.time()
17 | angular_func_GPU = AngularComponent_GPU(3)
18 | angular_component_GPU = angular_func_GPU(vectors)
19 | end_time = time.time()
20 | print(f"Execution time AngularComponent_GPU function: {end_time - start_time} seconds")
21 |
22 | #not supposed to be the same as l_list is different
23 | #print(torch.allclose(angular_component, angular_component_GPU))
24 |
--------------------------------------------------------------------------------
/tests/test_edge_encoder.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../')
3 | import torch
4 | import cace
5 |
6 | from cace.modules import EdgeEncoder
7 |
8 | encoder = EdgeEncoder(directed=True)
9 | edges = torch.tensor([
10 | [[0, 1], [0, 1]],
11 | [[0, 1], [1, 0]],
12 | [[1, 0], [0, 1]],
13 | [[1, 0], [1, 0]],
14 | ])
15 | encoded_edges = encoder(edges)
16 | print("edges:", edges)
17 | print(encoded_edges)
18 |
19 | encoder = EdgeEncoder(directed=False)
20 | edges = torch.tensor([[[0, 0.2], [0.7, 0]], [[1, 0], [0, 1]],[[1, 0], [1, 0]], [[0.7, 0], [0.0, 0.2]]])
21 | encoded_edges = encoder(edges)
22 | print("edges:", edges)
23 | print(encoded_edges)
24 |
25 |
26 |
--------------------------------------------------------------------------------
/tests/test_elementwise_multiply_tensors.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../')
3 | import torch
4 |
5 | from cace.tools import elementwise_multiply_2tensors, elementwise_multiply_3tensors
6 |
7 | tensor1 = torch.rand([996, 20])
8 | tensor2 = torch.rand([996, 8])
9 | result = elementwise_multiply_2tensors(tensor1,tensor2)
10 | print(result.shape)
11 | assert torch.equal(result[0,0], tensor2[0] * tensor1[0,0]), "Tensors are not equal."
12 |
13 | tensor1 = torch.rand([996, 20])
14 | tensor2 = torch.rand([996, 8])
15 | tensor3 = torch.rand([996, 6])
16 | result = elementwise_multiply_3tensors(tensor1,tensor2,tensor3)
17 | print(result.shape)
18 | # TODO: Fix this test
19 | #assert torch.equal(result[0,0], tensor2[0] * tensor1[0,0] * tensor3[0,0]), "Tensors are not equal."
20 |
21 |
22 |
--------------------------------------------------------------------------------
/tests/test_ewald.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import cace
3 | from cace.modules import EwaldPotential
4 |
5 | ep = EwaldPotential(dl=2,
6 | sigma=1,
7 | exponent=1,
8 | feature_key='q',
9 | aggregation_mode='sum')
10 |
11 | def replicate_box(r, q, box, nx=2, ny=2, nz=2):
12 | """Replicate the simulation box nx, ny, nz times in each direction."""
13 | n_atoms = r.shape[0]
14 | replicated_r = []
15 | replicated_q = []
16 |
17 | for ix in range(nx):
18 | for iy in range(ny):
19 | for iz in range(nz):
20 | shift = torch.tensor([ix, iy, iz], dtype=r.dtype, device=r.device) * box
21 | replicated_r.append(r + shift)
22 | replicated_q.append(q)
23 |
24 | replicated_r = torch.cat(replicated_r)
25 | replicated_q = torch.cat(replicated_q)
26 |
27 | new_box = torch.tensor([nx, ny, nz], dtype=r.dtype, device=r.device) * box
28 | return replicated_r, replicated_q, new_box
29 |
30 | # set the same random seed for reproducibility
31 | torch.manual_seed(0)
32 | r = torch.rand(100, 3) * 10 # Random positions in a 10x10x10 box
33 | q = torch.rand(100) * 2 - 1 # Random charges
34 |
35 | #q -= torch.mean(q)
36 | box = torch.tensor([10.0, 10.0, 10.0], dtype=torch.float64) # Box dimensions
37 |
38 | #print(q)
39 | #exit()
40 |
41 | # Replicate the box 2x2x2 times
42 | replicated_r, replicated_q, new_box = replicate_box(r, q, box, nx=2, ny=2, nz=2)
43 |
44 | ew_1 = ep.compute_potential_optimized(torch.tensor(r), torch.tensor(q).unsqueeze(1), torch.tensor(box))
45 | ew_1_s = ep.compute_potential(torch.tensor(r), torch.tensor(q).unsqueeze(1), torch.tensor(box))
46 | ew_2 = ep.compute_potential_optimized(replicated_r, replicated_q.unsqueeze(1), new_box) / 8
47 | ew_2_s = ep.compute_potential(replicated_r, replicated_q.unsqueeze(1), new_box) / 8
48 | print(ew_1, ew_2, ew_1_s, ew_2_s)
49 |
50 |
--------------------------------------------------------------------------------
/tests/test_ewald_triclinic.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import sys
3 | sys.path.append('../')
4 | import cace
5 | from cace.modules import EwaldPotential
6 |
7 | ep = EwaldPotential(dl=1.5,
8 | sigma=1,
9 | exponent=1,
10 | feature_key='q',
11 | aggregation_mode='sum')
12 |
13 | def replicate_box(r, q, box, nx=2, ny=2, nz=2):
14 | """Replicate the simulation box nx, ny, nz times in each direction."""
15 | n_atoms = r.shape[0]
16 | replicated_r = []
17 | replicated_q = []
18 |
19 | for ix in range(nx):
20 | for iy in range(ny):
21 | for iz in range(nz):
22 | shift = torch.tensor([ix, iy, iz], dtype=r.dtype, device=r.device) * box
23 | replicated_r.append(r + shift)
24 | replicated_q.append(q)
25 |
26 | replicated_r = torch.cat(replicated_r)
27 | replicated_q = torch.cat(replicated_q)
28 |
29 | new_box = torch.tensor([nx, ny, nz], dtype=r.dtype, device=r.device) * box
30 | return replicated_r, replicated_q, new_box
31 |
32 | # set the same random seed for reproducibility
33 | torch.manual_seed(0)
34 | r = torch.rand(100, 3) * 10 # Random positions in a 10x10x10 box
35 | q = torch.rand(100) * 2 - 1 # Random charges
36 |
37 | box = torch.tensor([10.0, 10.0, 10.0], dtype=torch.float64) # Box dimensions
38 | box_3d = torch.tensor([[10.0, 0.0, 0.0],
39 | [0.0, 10.0, 0.0],
40 | [0.0, 0.0, 10.0]])
41 | box_3d_2 = torch.tensor([[20.0, 0.0, 0.0],
42 | [0.0, 20.0, 0.0],
43 | [0.0, 0.0, 20.0]])
44 |
45 | # Replicate the box 2x2x2 times
46 | replicated_r, replicated_q, new_box = replicate_box(r, q, box, nx=2, ny=2, nz=2)
47 |
48 | ew_1 = ep.compute_potential_optimized(torch.tensor(r), torch.tensor(q).unsqueeze(1), torch.tensor(box))
49 | ew_1_s = ep.compute_potential(torch.tensor(r), torch.tensor(q).unsqueeze(1), torch.tensor(box))
50 | ew_tri = ep.compute_potential_triclinic(torch.tensor(r), torch.tensor(q).unsqueeze(1), torch.tensor(box_3d))
51 | ew_2 = ep.compute_potential_optimized(replicated_r, replicated_q.unsqueeze(1), new_box)
52 | ew_2_s = ep.compute_potential(replicated_r, replicated_q.unsqueeze(1), new_box)
53 | ew_2_tri = ep.compute_potential_triclinic(replicated_r.to(dtype=torch.float32), replicated_q.to(dtype=torch.float32).unsqueeze(1), box_3d_2)
54 | print('###cubic cell test###')
55 | print(ew_1[0], ew_1_s[0], ew_tri[0])
56 | print(ew_2[0], ew_2_s[0], ew_2_tri[0])
57 | print(ew_2[0]/ew_1[0], ew_2_s[0]/ew_1_s[0], ew_2_tri[0]/ew_tri[0])
58 |
59 |
60 | #triclinic cell test
61 | def replicate_box_tri(r, q, box, nx=2, ny=2, nz=2):
62 | replicated_r = []
63 | replicated_q = []
64 | for ix in range(nx):
65 | for iy in range(ny):
66 | for iz in range(nz):
67 | shift = ix * box[0, :] + iy * box[1, :] + iz * box[2, :]
68 | replicated_r.append(r + shift)
69 | replicated_q.append(q)
70 | replicated_r = torch.cat(replicated_r, dim=0)
71 | replicated_q = torch.cat(replicated_q, dim=0)
72 | new_box = torch.stack([nx * box[0, :], ny * box[1, :], nz * box[2, :]], dim=0)
73 | return replicated_r, replicated_q, new_box
74 |
75 |
76 | box_tric = torch.tensor([[10.0, 2.0, 1.0],
77 | [0.0, 9.0, 1.5],
78 | [0.0, 0.0, 10.0]])
79 |
80 | s_rand = torch.rand(100, 3)
81 | r_tric = torch.matmul(s_rand, box_tric) # Random positions in a triclinic box
82 | q_tric = torch.rand(100) * 2 - 1
83 |
84 | ep = cace.modules.EwaldPotential(dl=2, sigma=1, exponent=1, feature_key='q', aggregation_mode='sum')
85 |
86 |
87 | ew_tric = ep.compute_potential_triclinic(r_tric, q_tric.unsqueeze(1), box_tric)
88 |
89 | # Replicate the cell 2x2x2 times.
90 | rep_r, rep_q, new_box = replicate_box_tri(r_tric, q_tric, box_tric, nx=2, ny=2, nz=2)
91 | ew_tric_rep = ep.compute_potential_triclinic(rep_r, rep_q.unsqueeze(1), new_box)
92 |
93 | print('###triclinic cell test###')
94 | print("Triclinic energy (original):", ew_tric[0])
95 | print("Triclinic energy (replicated):", ew_tric_rep[0])
96 | print("Ratio:", ew_tric_rep[0] / ew_tric[0])
--------------------------------------------------------------------------------
/tests/test_neighborhood.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../')
3 | import numpy as np
4 | from ase.io import read
5 | import cace
6 | #from cace import data
7 |
8 | atoms = read('../datasets/water.xyz','0')
9 | config = cace.data.config_from_atoms(atoms, energy_key ='TotEnergy', forces_key='force')
10 |
11 | edge_index, shifts, unit_shifts = cace.data.get_neighborhood(
12 | positions=config.positions,
13 | cutoff=5,
14 | cell=config.cell,
15 | pbc=config.pbc
16 | )
17 |
18 | print(np.shape(edge_index))
19 |
20 | # try if the same number of neighbors as the usual method
21 | from ase.neighborlist import NeighborList
22 | # Generate neighbor list with per-atom radii of 2.5, i.e. a 5.0 pair cutoff matching the one above
23 | cutoffs = [2.5 for atom in atoms]
24 | nl = NeighborList(cutoffs, skin=0.0, self_interaction=False, bothways=True)
25 | nl.update(atoms)
26 |
27 | # Store displacement vectors
28 | displacement_vectors = []
29 |
30 | for i, atom in enumerate(atoms):
31 | indices, offsets = nl.get_neighbors(i)
32 | for j, offset in zip(indices, offsets):
33 | disp_vector = atoms[j].position - atom.position
34 | disp_vector += np.dot(offset, atoms.get_cell())
35 | displacement_vectors.append(disp_vector)
36 |
37 | print(np.shape(displacement_vectors))
38 |
39 | assert np.shape(edge_index)[1] == np.shape(displacement_vectors)[0]
40 |
41 |
42 |
--------------------------------------------------------------------------------
/tests/test_pol_triclinic.py:
--------------------------------------------------------------------------------
1 | import torch, math
2 | import sys
3 | sys.path.append('../')
4 | import cace
5 | from cace.modules import Polarization
6 |
7 | #generate random data
8 | torch.random.manual_seed(0)
9 | r_now = torch.randn(192,3, dtype=torch.float64)
10 | q_now = torch.randn(192,1, dtype=torch.float64)
11 |
12 | #box for the original function
13 | box_now = torch.tensor([10.,10.,10.])
14 |
15 | #calculate polarization using original function
16 | factor_o = box_now / (1j * 2.* torch.pi)
17 | phase_o = torch.exp(1j * 2.* torch.pi * r_now / box_now)
18 | polarization_o = torch.sum(q_now * phase_o, dim=(0)) * factor_o
19 | print("Original polarization:", polarization_o)
20 |
21 | #calculate polarization using new function
22 | pol_class = Polarization(pbc=True)
23 | box_ortho = torch.tensor([[10.0, 0.0, 0.0],
24 | [0.0, 10.0, 0.0],
25 | [0.0, 0.0, 10.0]], dtype=torch.float64)
26 | pol_ortho, phase_ortho = pol_class.compute_pol_pbc(r_now, q_now, box_ortho)
27 | print("Modified polarization:", pol_ortho)
28 | print(torch.allclose(polarization_o, pol_ortho))
29 |
30 |
31 | # Check that the polarization is equivalent to the rotated polarization in the triclinic cell
32 |
33 | #Rotation matrix around z-axis, 45 degrees
34 | theta = math.pi/4
35 | R = torch.tensor([[math.cos(theta), -math.sin(theta), 0],
36 | [math.sin(theta), math.cos(theta), 0],
37 | [0, 0, 1]], dtype=torch.float64)
38 |
39 | # Triclinic cell
40 | box_tric = torch.matmul(R, box_ortho)
41 | r_tric = torch.matmul(r_now, R)
42 |
43 | pol_tric, phase_tric = pol_class.compute_pol_pbc(r_tric, q_now, box_tric)
44 | pol_expected = torch.matmul(R.to(torch.complex128), pol_ortho.unsqueeze(1)).squeeze()
45 |
46 | print("Triclinic polarization:", pol_tric)
47 | print("Expected triclinic polarization from rotation:", pol_expected)
48 | print(torch.allclose(pol_tric, pol_expected))
--------------------------------------------------------------------------------