├── PaiNN
│   ├── __init__.py
│   ├── active_learning.py
│   ├── calculator.py
│   ├── data.py
│   ├── kernel.py
│   ├── model.py
│   └── select.py
├── README.md
├── scripts
│   ├── MD.traj
│   ├── arguments.toml
│   ├── gpu_info
│   ├── gpu_run.sh
│   ├── md_run.py
│   ├── runner_output.log
│   ├── train.py
│   └── water_O2.cif
├── setup.py
└── workflow
    ├── al_select.py
    ├── config.toml
    ├── flow.py
    ├── md_run.py
    ├── train.py
    └── vasp.py
/PaiNN/__init__.py:
--------------------------------------------------------------------------------
1 | from PaiNN import *
2 |
--------------------------------------------------------------------------------
/PaiNN/active_learning.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from collections import defaultdict
4 | from torch_scatter import scatter_mean
5 | from typing import List, Dict, Tuple, Optional
6 | from PaiNN.data import collate_atomsdata
7 | from PaiNN.select import *
8 | from PaiNN.kernel import *
9 |
10 | class FeatureExtractor(nn.Module):
11 | def __init__(self, model: nn.Module):
12 | super().__init__()
13 | self.model = model
14 | self._features = []
15 | self._grads = []
16 | self.hooks = []
17 | for name, layer in self.model.named_modules():
18 | if 'readout_mlp' in name and isinstance(layer, nn.Linear):
19 | self.hooks.append(layer.register_forward_pre_hook(self.save_feats_hook))
20 | self.hooks.append(layer.register_backward_hook(self.save_grads_hook))
21 |
22 | def save_feats_hook(self, _, in_feat):
23 | new_feat = torch.cat((in_feat[0].detach().clone(), torch.ones_like(in_feat[0][:, 0:1])), dim=-1)
24 | self._features.append(new_feat)
25 |
26 | def save_grads_hook(self, _, __, grad_output):
27 | self._grads.append(grad_output[0].detach().clone())
28 |
29 | def unhook(self):
30 | for hook in self.hooks:
31 | hook.remove()
32 |
33 | def forward(self, model_inputs: Dict[str, torch.Tensor]):
34 | self._features = []
35 | self._grads = []
36 | _ = self.model(model_inputs)
37 | return self._features, self._grads[::-1]
38 |
39 | class RandomProjections:
40 | """Store parameters of random projections"""
41 | def __init__(self, model: nn.Module, num_features: int):
42 | self.num_features = num_features
43 | if self.num_features > 0:
44 | self.in_feat_proj = [
45 | torch.randn(l.in_features +1, num_features, device=next(model.parameters()).device)
46 | for l in model.readout_mlp.children() if isinstance(l, nn.Linear)
47 | ]
48 | self.out_grad_proj = [
49 | torch.randn(l.out_features, num_features, device=next(model.parameters()).device)
50 | for l in model.readout_mlp.children() if isinstance(l, nn.Linear)
51 | ]
52 |
53 | class FeatureStatistics:
54 | """
 55 | Generate features and ensemble statistics for a dataset (pool or training), given a list of models
56 | """
57 | def __init__(
58 | self,
59 | models: List[nn.Module],
60 | dataset: torch.utils.data.Dataset,
61 | random_projections: List[RandomProjections],
62 | batch_size: int=8,
63 | ):
64 | self.models = models
65 | self.batch_size = batch_size
66 | self.dataset = dataset
67 | self.random_projections = random_projections
68 | self.device = next(models[0].parameters()).device
69 | self.g = None
70 | self.ens_stats = None
71 | self.Fisher = None
72 | self.F_reg_inv = None
73 |
74 | def _compute_ens_stats(self, model_inputs: Dict[str, torch.Tensor], labeled_data: bool=False) -> Dict[str, torch.Tensor]:
75 | """Compute energy variance, forces variance, energy absolute error, and forces absolute error"""
76 | ens_stats = defaultdict(list)
77 | predictions = defaultdict(list)
78 | for model in self.models:
79 | model_results = model(model_inputs)
80 | predictions['energy'].append(model_results["energy"].detach())
81 | predictions['forces'].append(model_results["forces"].detach())
82 |
83 | predictions = {k: torch.stack(v) for k, v in predictions.items()}
84 |
85 | image_idx = torch.arange(
86 | model_inputs['num_atoms'].shape[0],
87 | device=model_inputs['num_atoms'].device,
88 | )
89 | image_idx = torch.repeat_interleave(image_idx, model_inputs['num_atoms'])
90 |
91 | if len(self.models) > 1:
92 | E_var = torch.var(predictions['energy'], dim=0)
93 | F_var = torch.var(predictions['forces'], dim=0)
94 | F_var = scatter_mean(torch.mean(F_var, dim=-1), image_idx, dim=0)
95 | ens_stats['Energy-Var'] = E_var
96 | ens_stats['Forces-Var'] = F_var
97 |
98 | if labeled_data:
99 | E_AE = torch.abs(model_inputs['energy'] - torch.mean(predictions['energy'], dim=0))
100 | F_AE = torch.abs(model_inputs['forces'] - torch.mean(predictions['forces'], dim=0))
101 | F_AE = scatter_mean(torch.mean(F_AE, dim=-1), image_idx, dim=0)
102 |
103 | ens_stats['Energy-AE'] = E_AE
104 | ens_stats['Forces-AE'] = F_AE
105 |
106 | return ens_stats
107 |
108 | def _compute_features(
109 | self,
110 | feature_extractor: FeatureExtractor,
111 | model_inputs: Dict[str, torch.Tensor],
112 | random_projection: RandomProjections,
113 | kernel: str='ll-gradient',
114 | ) -> torch.Tensor:
115 | """
116 | Implements feature calculation and kernel transformation.
117 | Available features are:
118 | ll-gradient: last layer gradient feature, obtained from neural networks.
119 | full-gradient: All gradient information from NN, must use random projections kernel transformation.
120 | gnn: Features learned by message passing layers
121 | symmetry-function: Behler Parrinello symmetry function, can only be used for CUR. To be implemented.
122 | """
123 | image_idx = torch.arange(
124 | model_inputs['num_atoms'].shape[0],
125 | device=model_inputs['num_atoms'].device,
126 | )
127 | image_idx = torch.repeat_interleave(image_idx, model_inputs['num_atoms'])
128 |
129 | if kernel == 'full-gradient':
130 | assert random_projection.num_features != 0, "Error! Random projections must be provided!"
131 | feats, grads = feature_extractor(model_inputs)
132 | atomic_g = torch.zeros((image_idx.shape[0], random_projection.num_features), device=feats[0].device)
133 | for feat, grad, in_proj, out_proj in zip(
134 | feats,
135 | grads,
136 | random_projection.in_feat_proj,
137 | random_projection.out_grad_proj,
138 | ):
139 | atomic_g += (feat @ in_proj) * (grad @ out_proj)  # accumulate contributions from all readout layers
140 |
141 | g = torch.zeros(
142 | (model_inputs['num_atoms'].shape[0], atomic_g.shape[1]),
143 | dtype = atomic_g.dtype,
144 | device = atomic_g.device,
145 | ).index_add(0, image_idx, atomic_g)
146 | elif kernel == 'local_full-g':
147 | assert random_projection.num_features != 0, "Error! Random projections must be provided!"
148 | feats, grads = feature_extractor(model_inputs)
149 | atomic_g = torch.zeros((image_idx.shape[0], random_projection.num_features), device=feats[0].device)
150 | for feat, grad, in_proj, out_proj in zip(
151 | feats,
152 | grads,
153 | random_projection.in_feat_proj,
154 | random_projection.out_grad_proj,
155 | ):
156 | atomic_g += (feat @ in_proj) * (grad @ out_proj)  # accumulate contributions from all readout layers
157 | g = atomic_g
158 |
159 | elif kernel == 'll-gradient':
160 | feats, grads = feature_extractor(model_inputs)
161 | if random_projection.num_features != 0:
162 | atomic_g = (feats[-1] @ random_projection.in_feat_proj[-1]) *\
163 | (grads[-1] @ random_projection.out_grad_proj[-1])
164 | else:
165 | atomic_g = feats[-1][:, :-1]
166 |
167 | g = torch.zeros(
168 | (model_inputs['num_atoms'].shape[0], atomic_g.shape[1]),
169 | dtype = atomic_g.dtype,
170 | device = atomic_g.device,
171 | ).index_add(0, image_idx, atomic_g)
172 |
173 | elif kernel == 'local_ll-g':
174 | feats, grads = feature_extractor(model_inputs)
175 | if random_projection.num_features != 0:
176 | atomic_g = (feats[-1] @ random_projection.in_feat_proj[-1]) *\
177 | (grads[-1] @ random_projection.out_grad_proj[-1])
178 | else:
179 | atomic_g = feats[-1][:, :-1]
180 | g = atomic_g
181 |
182 | elif kernel == 'gnn':
183 | feats, grads = feature_extractor(model_inputs)
184 | if random_projection.num_features != 0:
185 | atomic_g = (feats[0] @ random_projection.in_feat_proj[0]) *\
186 | (grads[0] @ random_projection.out_grad_proj[0])
187 | else:
188 | atomic_g = feats[0][:, :-1]
189 |
190 | g = torch.zeros(
191 | (model_inputs['num_atoms'].shape[0], atomic_g.shape[1]),
192 | dtype = atomic_g.dtype,
193 | device = atomic_g.device,
194 | ).index_add(0, image_idx, atomic_g)
195 |
196 | elif kernel == 'local_gnn':
197 | feats, grads = feature_extractor(model_inputs)
198 | if random_projection.num_features != 0:
199 | atomic_g = (feats[0] @ random_projection.in_feat_proj[0]) *\
200 | (grads[0] @ random_projection.out_grad_proj[0])
201 | else:
202 | atomic_g = feats[0][:, :-1]
203 | g = atomic_g
204 |
205 | return g
206 |
207 | def _compute_fisher(self, g: torch.Tensor) -> torch.Tensor:
208 | return torch.einsum('mci, mcj -> mij', g, g)
209 |
210 | def get_features(
211 | self,
212 | dataset: Optional[torch.utils.data.Dataset]=None,
213 | kernel: str='full-gradient',
214 | ) -> torch.Tensor:
215 | """
216 | :return: Feature vector of ``shape=(n_models, n_structures, n_features)``.
217 | """
218 | if dataset is None:
219 | dataset = self.dataset
220 | else:
221 | self.dataset = dataset
222 | self.g = None
223 |
224 | if self.g is None:
225 | dataloader = torch.utils.data.DataLoader(
226 | dataset=dataset,
227 | batch_size=self.batch_size,
228 | collate_fn=collate_atomsdata,
229 | )
230 | global_g = []
231 | for i, model in enumerate(self.models):
232 | feat_extract = FeatureExtractor(model)
233 | model_g = []
234 | for batch in dataloader:
235 | batch = {k: v.to(self.device) for k, v in batch.items()}
236 | model_g.append(self._compute_features(
237 | feat_extract,
238 | batch,
239 | kernel=kernel,
240 | random_projection=self.random_projections[i],
241 | ))
242 | feat_extract.unhook()
243 | model_g = torch.cat(model_g)
244 | # Normalization
245 | model_g = (model_g - torch.mean(model_g, dim=0)) / torch.var(model_g, dim=0)
246 | global_g.append(model_g)
247 | # global_g.append(torch.cat(model_g))
248 |
249 | self.g = torch.stack(global_g)
250 |
251 | return self.g
252 |
253 | def get_num_atoms(
254 | self,
255 | dataset: Optional[torch.utils.data.Dataset]=None,
256 | ):
257 | if dataset is None:
258 | dataset = self.dataset
259 | else:
260 | self.dataset = dataset
261 | num_atoms = []
262 | dataloader = torch.utils.data.DataLoader(
263 | dataset=dataset,
264 | batch_size=self.batch_size,
265 | collate_fn=collate_atomsdata,
266 | )
267 | for batch in dataloader:
268 | batch = {k: v.to(self.device) for k, v in batch.items()}
269 | num_atoms.append(batch['num_atoms'])
270 |
271 | return torch.cat(num_atoms)
272 |
273 | def get_ens_stats(self, dataset: Optional[torch.utils.data.Dataset]=None) -> Dict[str, torch.Tensor]:
274 | """
275 | :return: Dict of energy statistics
276 | """
277 | if dataset is None:
278 | dataset = self.dataset
279 | else:
280 | self.dataset = dataset
281 | self.ens_stats = None
282 |
283 | if self.ens_stats is None:
284 | dataloader = torch.utils.data.DataLoader(
285 | dataset=dataset,
286 | batch_size=self.batch_size,
287 | collate_fn=collate_atomsdata,
288 | )
289 | ens_stats = []
290 | for batch in dataloader:
291 | batch = {k: v.to(self.device) for k, v in batch.items()}
292 | labeled_data = 'energy' in batch
293 | ens_stats.append(self._compute_ens_stats(batch, labeled_data))
294 |
295 | self.ens_stats = {k: torch.cat([ens[k] for ens in ens_stats]) for k in ens_stats[0].keys()}
296 |
297 | return self.ens_stats
298 |
299 | def get_fisher(self) -> torch.Tensor:
300 | if self.Fisher is None:
301 | self.Fisher = self._compute_fisher(self.get_features())
302 | return self.Fisher
303 |
304 | def get_F_reg_inv(self) -> torch.Tensor:
305 | """
306 | :return: Regularized inverse of the Fisher matrix of ``shape=(n_models, n_features, n_features)``.
307 | """
308 | if self.F_reg_inv is None:
309 | F = self.get_fisher()
310 | g = self.get_features()
311 | # empirical regularisation: lambda = trace(F) / number of feature entries
312 | lam = torch.einsum('mii -> m', F) / (g.shape[1] * g.shape[2])
313 | self.F_reg_inv = torch.linalg.inv(F + lam[:, None, None] * torch.eye(F.shape[1], device=F.device))
314 | return self.F_reg_inv
315 |
316 | class GeneralActiveLearning:
317 | """Provides methods for selecting batches during active learning.
318 |
319 | :param kernel: Name of the kernel, e.g. "full-g", "ll-g", "full-F_inv", "ll-F_inv", "qbc-energy", "qbc-force".
320 | "random" produces random selection and "ae-energy" and "ae-force" select by absolute errors
321 | on the pool data, which is only possible if the pool data is already labeled.
322 | :param selection: Selection method, one of "max_dist_greedy", "deterministic_CUR", "lcmd_greedy", "max_det_greedy" or "max_diag".
323 | :param n_random_features: If "n_random_features = 0", do not use random projections.
324 | Otherwise, use random projections of all linear-layer gradients.
325 | """
326 | def __init__(
327 | self,
328 | kernel = 'full-g',
329 | selection = 'max_diag',
330 | n_random_features = 0,
331 | ):
332 | self.kernel = kernel
333 | self.selection = selection
334 | self.n_random_features = n_random_features
335 |
336 | def select(
337 | self,
338 | models: List[nn.Module],
339 | datasets: Dict[str, torch.utils.data.Dataset],
340 | batch_size: int = 8,
341 | al_batch_size: int = 100,
342 | ):
343 | """
344 | models: pytorch models,
345 | datasets: a dictionary containing pool, train, and validation datasets,
346 | batch_size: batch size for extracting features,
347 | al_batch_size: active learning selection batch size
348 | """
349 | if (self.kernel == 'qbc-energy' or self.kernel == 'qbc-force' or self.kernel == 'ae-energy' or
350 | self.kernel == 'ae-force' or self.kernel == 'random') and self.selection != 'max_diag':
351 | raise RuntimeError(f'{self.kernel} kernel can only be used with max_diag selection method,'
352 | f' not with {self.selection}!')
353 | random_projections = [RandomProjections(model, self.n_random_features) for model in models]
354 |
355 | stats = {
356 | key: FeatureStatistics(models, ds, random_projections, batch_size)
357 | for key, ds in datasets.items()
358 | }
359 |
360 | if self.selection == 'max_dist_greedy':
361 | matrix = self._get_kernel_matrix(stats['pool'], stats['train'])
362 | idxs = max_dist_greedy(matrix=matrix, batch_size=al_batch_size, n_train=len(datasets['train']))
363 | elif self.selection == 'max_diag':
364 | matrix = self._get_kernel_matrix(stats['pool'])
365 | idxs = max_diag(matrix=matrix, batch_size=al_batch_size)
366 | elif self.selection == 'max_det_greedy':
367 | matrix = self._get_kernel_matrix(stats['pool'])
368 | idxs = max_det_greedy(matrix=matrix, batch_size=al_batch_size)
369 | elif self.selection == 'lcmd_greedy':
370 | matrix = self._get_kernel_matrix(stats['pool'], stats['train'])
371 | idxs = lcmd_greedy(matrix=matrix, batch_size=al_batch_size, n_train=len(datasets['train']))
372 | elif self.selection == 'max_det_greedy_local':
373 | matrix, num_atoms = self._get_kernel_matrix(stats['pool'])
374 | idxs = max_det_greedy_local(matrix=matrix, batch_size=al_batch_size, num_atoms=num_atoms)
375 | else:
376 | raise NotImplementedError(f"Unknown selection method '{self.selection}' for active learning!")
377 |
378 | return idxs.cpu().tolist()
379 |
380 |
381 | def _get_kernel_matrix(self, pool_stats: FeatureStatistics, train_stats: Optional[FeatureStatistics]=None) -> KernelMatrix:
382 | stats_list = [pool_stats] if train_stats is None else [pool_stats, train_stats]
383 |
384 | if self.kernel == 'full-g':
385 | return FeatureKernelMatrix(torch.cat([s.get_features(kernel='full-gradient') for s in stats_list], dim=1))
386 | elif self.kernel == 'll-g':
387 | return FeatureKernelMatrix(torch.cat([s.get_features(kernel='ll-gradient') for s in stats_list], dim=1))
388 | elif self.kernel == 'gnn':
389 | return FeatureKernelMatrix(torch.cat([s.get_features(kernel='gnn') for s in stats_list], dim=1))
390 | elif self.kernel == 'local_full-g':
391 | matrix = FeatureKernelMatrix(torch.cat([s.get_features(kernel='local_full-g') for s in stats_list], dim=1))
392 | num_atoms = torch.cat([s.get_num_atoms() for s in stats_list])
393 | return matrix, num_atoms
394 | elif self.kernel == 'local_ll-g':
395 | matrix = FeatureKernelMatrix(torch.cat([s.get_features(kernel='local_ll-g') for s in stats_list], dim=1))
396 | num_atoms = torch.cat([s.get_num_atoms() for s in stats_list])
397 | return matrix, num_atoms
398 | elif self.kernel == 'local_gnn':
399 | matrix = FeatureKernelMatrix(torch.cat([s.get_features(kernel='local_gnn') for s in stats_list], dim=1))
400 | num_atoms = torch.cat([s.get_num_atoms() for s in stats_list])
401 | return matrix, num_atoms
402 | elif self.kernel == 'full-F_inv':
403 | return FeatureCovKernelMatrix(torch.cat([s.get_features(kernel='full-gradient') for s in stats_list], dim=1),
404 | train_stats.get_F_reg_inv())
405 | elif self.kernel == 'll-F_inv':
406 | return FeatureCovKernelMatrix(torch.cat([s.get_features(kernel='ll-gradient') for s in stats_list], dim=1),
407 | train_stats.get_F_reg_inv())
408 | elif self.kernel == 'qbc-energy':
409 | return DiagonalKernelMatrix(pool_stats.get_ens_stats()['Energy-Var'])
410 | elif self.kernel == 'qbc-force':
411 | return DiagonalKernelMatrix(pool_stats.get_ens_stats()['Forces-Var'])
412 | elif self.kernel == 'ae-energy':
413 | return DiagonalKernelMatrix(pool_stats.get_ens_stats()['Energy-AE'])
414 | elif self.kernel == 'ae-force':
415 | return DiagonalKernelMatrix(pool_stats.get_ens_stats()['Forces-AE'])
416 | elif self.kernel == 'random':
417 | return DiagonalKernelMatrix(torch.rand([sum([len(s.dataset) for s in stats_list])]))
418 | else:
419 | raise RuntimeError(f"Unknown active learning kernel {self.kernel}!")
--------------------------------------------------------------------------------
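The snippet below is a minimal usage sketch of `GeneralActiveLearning` (not part of the repository). The trajectory paths, ensemble size, and hyperparameters are made-up placeholders, and trained weights would normally be loaded into each model before selection.

```python
import torch
from PaiNN.model import PainnModel
from PaiNN.data import AseDataset
from PaiNN.active_learning import GeneralActiveLearning

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Small ensemble of PaiNN models (loading trained state dicts is omitted here).
models = [
    PainnModel(num_interactions=3, hidden_state_size=128, cutoff=5.0).to(device)
    for _ in range(3)
]

# 'pool.traj' and 'train.traj' are placeholder trajectory files.
datasets = {
    "pool": AseDataset("pool.traj", cutoff=5.0),    # unlabeled candidate structures
    "train": AseDataset("train.traj", cutoff=5.0),  # current training set
}

al = GeneralActiveLearning(kernel="ll-g", selection="max_det_greedy", n_random_features=500)
selected = al.select(models, datasets, batch_size=8, al_batch_size=10)
print(selected)  # list of indices into the pool dataset
```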
/PaiNN/calculator.py:
--------------------------------------------------------------------------------
1 | from ase.calculators.calculator import Calculator, all_changes
2 | from PaiNN.data import AseDataReader
3 | import numpy as np
4 |
5 | class MLCalculator(Calculator):
6 | implemented_properties = ["energy", "forces"]
7 |
8 | def __init__(
9 | self,
10 | model,
11 | energy_scale=1.0,
12 | forces_scale=1.0,
13 | # stress_scale=1.0,
14 | **kwargs
15 | ):
16 | super().__init__(**kwargs)
17 |
18 | self.model = model
19 | self.model_device = next(model.parameters()).device
20 | self.cutoff = model.cutoff
21 | self.ase_data_reader = AseDataReader(self.cutoff)
22 | self.energy_scale = energy_scale
23 | self.forces_scale = forces_scale
24 | # self.stress_scale = stress_scale
25 |
26 | def calculate(self, atoms=None, properties=["energy"], system_changes=all_changes):
27 | """
28 | Args:
29 | atoms (ase.Atoms): ASE atoms object.
30 | properties (list of str): do not use this, no functionality
31 | system_changes (list of str): List of changes for ASE.
32 | """
33 | # First call original calculator to set atoms attribute
34 | # (see https://wiki.fysik.dtu.dk/ase/_modules/ase/calculators/calculator.html#Calculator)
35 | if atoms is not None:
36 | self.atoms = atoms.copy()
37 |
38 | model_inputs = self.ase_data_reader(self.atoms)
39 | model_inputs = {
40 | k: v.to(self.model_device) for (k, v) in model_inputs.items()
41 | }
42 |
43 | model_results = self.model(model_inputs)
44 |
45 | results = {}
46 |
47 | # Convert outputs to calculator format
48 | results["forces"] = (
49 | model_results["forces"].detach().cpu().numpy() * self.forces_scale
50 | )
51 | results["energy"] = (
52 | model_results["energy"][0].detach().cpu().numpy().item()
53 | * self.energy_scale
54 | )
55 | # results["stress"] = (
56 | # model_results["stress"][0].detach().cpu().numpy() * self.stress_scale
57 | # )
58 | # atoms.info["ll_out"] = {
59 | # k: v.detach().cpu().numpy() for k, v in model_results["ll_out"].items()
60 | # }
 61 | if model_results.get("fps") is not None:
62 | atoms.info["fps"] = model_results["fps"].detach().cpu().numpy()
63 |
64 | self.results = results
65 |
66 | class EnsembleCalculator(Calculator):
67 | implemented_properties = ["energy", "forces"]
68 |
69 | def __init__(
70 | self,
71 | models,
72 | energy_scale=1.0,
73 | forces_scale=1.0,
74 | # stress_scale=1.0,
75 | **kwargs
76 | ):
77 | super().__init__(**kwargs)
78 |
79 | self.models = models
80 | self.model_device = next(models[0].parameters()).device
81 | self.cutoff = models[0].cutoff
82 | self.ase_data_reader = AseDataReader(self.cutoff)
83 | self.energy_scale = energy_scale
84 | self.forces_scale = forces_scale
85 | # self.stress_scale = stress_scale
86 |
87 | def calculate(self, atoms=None, properties=["energy"], system_changes=all_changes):
88 | """
89 | Args:
90 | atoms (ase.Atoms): ASE atoms object.
91 | properties (list of str): do not use this, no functionality
92 | system_changes (list of str): List of changes for ASE.
93 | """
94 | # First call original calculator to set atoms attribute
95 | # (see https://wiki.fysik.dtu.dk/ase/_modules/ase/calculators/calculator.html#Calculator)
96 | if atoms is not None:
97 | self.atoms = atoms.copy()
98 |
99 | model_inputs = self.ase_data_reader(self.atoms)
100 | model_inputs = {
101 | k: v.to(self.model_device) for (k, v) in model_inputs.items()
102 | }
103 |
104 | predictions = {'energy': [], 'forces': []}
105 | for model in self.models:
106 | model_results = model(model_inputs)
107 | predictions['energy'].append(model_results["energy"][0].detach().cpu().numpy().item() * self.energy_scale)
108 | predictions['forces'].append(model_results["forces"].detach().cpu().numpy() * self.forces_scale)
109 |
110 | results = {"energy": np.mean(predictions['energy'])}
111 | results["forces"] = np.mean(np.stack(predictions['forces']), axis=0)
112 |
113 | ensemble = {
114 | 'energy_var': np.var(predictions['energy']),
115 | 'forces_var': np.var(np.stack(predictions['forces']), axis=0),
116 | 'forces_l2_var': np.var(np.linalg.norm(predictions['forces'], axis=2), axis=0),
117 | }
118 |
119 | results['ensemble'] = ensemble
120 |
121 | self.results = results
122 |
--------------------------------------------------------------------------------
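A minimal sketch of attaching `MLCalculator` to an ASE `Atoms` object. `best_model.pth` is a placeholder path; the checkpoint layout (`{"model": state_dict}`) follows the convention used in `scripts/md_run.py`.

```python
import torch
from ase.build import molecule
from PaiNN.model import PainnModel
from PaiNN.calculator import MLCalculator

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = PainnModel(num_interactions=3, hidden_state_size=128, cutoff=5.0).to(device)

# 'best_model.pth' is a placeholder; the hyperparameters above must match the checkpoint.
state_dict = torch.load("best_model.pth", map_location=device)
model.load_state_dict(state_dict["model"])

atoms = molecule("H2O")
atoms.calc = MLCalculator(model)
print(atoms.get_potential_energy(), atoms.get_forces().shape)
```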
/PaiNN/data.py:
--------------------------------------------------------------------------------
1 | from ase.io import read, write, Trajectory
2 | import torch
3 | from typing import List
4 | import asap3
5 | import numpy as np
6 | from scipy.spatial import distance_matrix
7 |
8 | # def ase_properties(atoms):
9 | # """Guess dataset format from an ASE atoms"""
10 | # atoms_prop = []
11 | #
12 | # if atoms.pbc.any():
13 | # atoms_prop.append('cell')
14 | #
15 | # try:
16 | # atoms.get_potential_energy()
17 | # atoms_prop.append('energy')
18 | # except:
19 | # pass
20 | #
21 | # try:
22 | # atoms.get_forces()
23 | # atoms_prop.append('forces')
24 | # except:
25 | # pass
26 | #
27 | # return atoms_prop
28 |
29 | class AseDataReader:
30 | def __init__(self, cutoff=5.0):
31 | self.cutoff = cutoff
32 |
33 | def __call__(self, atoms):
34 | atoms_data = {
35 | 'num_atoms': torch.tensor([atoms.get_global_number_of_atoms()]),
36 | 'elems': torch.tensor(atoms.numbers),
37 | 'coord': torch.tensor(atoms.positions, dtype=torch.float),
38 | }
39 |
40 | if atoms.pbc.any():
41 | pairs, n_diff = self.get_neighborlist(atoms)
42 | atoms_data['cell'] = torch.tensor(atoms.cell[:], dtype=torch.float)
43 | else:
44 | pairs, n_diff = self.get_neighborlist_simple(atoms)
45 |
46 | atoms_data['pairs'] = torch.from_numpy(pairs)
47 | atoms_data['n_diff'] = torch.from_numpy(n_diff).float()
48 | atoms_data['num_pairs'] = torch.tensor([pairs.shape[0]])
49 |
50 | try:
51 | energy = torch.tensor([atoms.get_potential_energy()], dtype=torch.float)
52 | atoms_data['energy'] = energy
53 | except (AttributeError, RuntimeError):
54 | pass
55 |
56 | try:
57 | forces = torch.tensor(atoms.get_forces(apply_constraint=False), dtype=torch.float)
58 | atoms_data['forces'] = forces
59 | except (AttributeError, RuntimeError):
60 | pass
61 |
62 | return atoms_data
63 |
64 |
65 | def get_neighborlist(self, atoms):
66 | nl = asap3.FullNeighborList(self.cutoff, atoms)
67 | pair_i_idx = []
68 | pair_j_idx = []
69 | n_diff = []
70 | for i in range(len(atoms)):
71 | indices, diff, _ = nl.get_neighbors(i)
72 | pair_i_idx += [i] * len(indices) # local index of pair i
73 | pair_j_idx.append(indices) # local index of pair j
74 | n_diff.append(diff)
75 |
76 | pair_j_idx = np.concatenate(pair_j_idx)
77 | pairs = np.stack((pair_i_idx, pair_j_idx), axis=1)
78 | n_diff = np.concatenate(n_diff)
79 |
80 | return pairs, n_diff
81 |
82 | def get_neighborlist_simple(self, atoms):
83 | pos = atoms.get_positions()
84 | dist_mat = distance_matrix(pos, pos)
85 | mask = dist_mat < self.cutoff
86 | np.fill_diagonal(mask, False)
87 | pairs = np.argwhere(mask)
88 | n_diff = pos[pairs[:, 1]] - pos[pairs[:, 0]]
89 |
90 | return pairs, n_diff
91 |
92 | class AseDataset(torch.utils.data.Dataset):
93 | def __init__(self, ase_db, cutoff=5.0, **kwargs):
94 | super().__init__(**kwargs)
95 |
96 | if isinstance(ase_db, str):
97 | self.db = Trajectory(ase_db)
98 | else:
99 | self.db = ase_db
100 |
101 | self.cutoff = cutoff
102 | self.atoms_reader = AseDataReader(cutoff)
103 |
104 | def __len__(self):
105 | return len(self.db)
106 |
107 | def __getitem__(self, idx):
108 | atoms = self.db[idx]
109 | atoms_data = self.atoms_reader(atoms)
110 | return atoms_data
111 |
112 | def cat_tensors(tensors: List[torch.Tensor]):
113 | if tensors[0].shape:
114 | return torch.cat(tensors)
115 | return torch.stack(tensors)
116 |
117 | def collate_atomsdata(atoms_data: List[dict], pin_memory=True):
118 | # convert from list of dicts to dict of lists
119 | dict_of_lists = {k: [dic[k] for dic in atoms_data] for k in atoms_data[0]}
120 | if pin_memory:
121 | pin = lambda x: x.pin_memory()
122 | else:
123 | pin = lambda x: x
124 |
125 | collated = {k: pin(cat_tensors(v)) for k, v in dict_of_lists.items()}
126 | return collated
127 |
--------------------------------------------------------------------------------
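A minimal sketch showing how `AseDataset` and `collate_atomsdata` plug into a PyTorch `DataLoader`; `train.traj` is a placeholder trajectory path.

```python
import torch
from PaiNN.data import AseDataset, collate_atomsdata

dataset = AseDataset("train.traj", cutoff=5.0)  # placeholder trajectory
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    # only pin memory when a GPU is available
    collate_fn=lambda batch: collate_atomsdata(batch, pin_memory=torch.cuda.is_available()),
)

batch = next(iter(loader))
# Per-atom and per-pair tensors are concatenated across the batch;
# 'num_atoms' and 'num_pairs' record the segment boundaries.
print(batch["num_atoms"], batch["coord"].shape, batch["pairs"].shape)
```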
/PaiNN/kernel.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class KernelMatrix:
4 | """Abstract kernel class used to calculate kernel matrix by giving a feature matrix"""
5 | def __init__(self, num_col: int):
6 | self.num_columns = num_col
7 |
8 | def get_number_of_columns(self) -> int:
9 | return self.num_columns
10 |
11 | def get_column(self, i: int) -> torch.Tensor:
12 | raise RuntimeError("Not implemented")
13 |
14 | def get_diag(self) -> torch.Tensor:
15 | raise RuntimeError("Not implemented")
16 |
17 | def get_sq_dists(self, i: int) -> torch.Tensor:
18 | diag = self.get_diag()
19 | return diag[i] + diag - 2 * self.get_column(i)
20 |
21 | class DiagonalKernelMatrix(KernelMatrix):
22 | """
 23 | Represents a diagonal kernel matrix, where get_column() and get_sq_dists() are not implemented.
24 |
25 | :param g: Diagonal of the kernel matrix.
26 | """
27 | def __init__(self, g: torch.Tensor):
28 | super().__init__(g.shape[0])
29 | self.diag = g
30 |
31 | def get_diag(self) -> torch.Tensor:
32 | return self.diag
33 |
34 | class FeatureKernelMatrix(KernelMatrix):
35 | """
36 | input: m x n x p matrix
37 | m: number of models
38 | n: number of entries
39 | p: dimensionality of features
40 | """
41 | def __init__(self, mat: torch.Tensor):
42 | super().__init__(mat.shape[1])
43 | self.mat = mat
44 | self.diag = torch.einsum('mbi, mbi -> mb', mat, mat)
45 |
46 | def get_column(self, i: int) -> torch.Tensor:
47 | return torch.mean(torch.einsum("mnp, mp -> mn", self.mat, self.mat[:, i, :]), dim=0)
48 |
49 | def get_diag(self) -> torch.Tensor:
50 | return torch.mean(self.diag, dim=0)
51 |
52 | class FeatureCovKernelMatrix(KernelMatrix):
53 | """
54 | input: m x n x p matrix mat, m x p x p covariance matrix
55 | m: number of models
56 | n: number of entries
57 | p: dimensionality of features
58 | """
59 | def __init__(self, g: torch.Tensor, cov_mat: torch.Tensor):
 60 | super().__init__(g.shape[1])
61 | self.g = g
62 | self.cov_mat = cov_mat
63 | self.cov_g = torch.einsum('mij, mbi -> mbj', self.cov_mat, g)
64 | self.diag = torch.einsum('mbi, mbi -> mb', self.cov_g, g)
65 |
66 | def get_diag(self) -> torch.Tensor:
67 | return torch.mean(self.diag, dim=0)
68 |
69 | def get_column(self, i: int) -> torch.Tensor:
70 | return torch.mean(torch.einsum('mbi, mi -> mb', self.g, self.cov_g[:, i, :]), dim=0)
71 |
--------------------------------------------------------------------------------
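A toy sketch of the `KernelMatrix` interface on random features (the tensor below is synthetic, not from the repository), following the `(n_models, n_structures, n_features)` convention documented above.

```python
import torch
from PaiNN.kernel import FeatureKernelMatrix

g = torch.randn(3, 10, 64)        # 3 models, 10 structures, 64 features
K = FeatureKernelMatrix(g)

print(K.get_number_of_columns())  # 10
print(K.get_diag().shape)         # model-averaged k(x_i, x_i), shape (10,)
print(K.get_column(0).shape)      # model-averaged k(x_i, x_0), shape (10,)
print(K.get_sq_dists(0).shape)    # squared kernel distances to structure 0, shape (10,)
```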
/PaiNN/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | def sinc_expansion(edge_dist: torch.Tensor, edge_size: int, cutoff: float):
5 | """
6 | calculate sinc radial basis function:
7 |
8 | sin(n *pi*d/d_cut)/d
9 | """
10 | n = torch.arange(edge_size, device=edge_dist.device) + 1
11 | return torch.sin(edge_dist.unsqueeze(-1) * n * torch.pi / cutoff) / edge_dist.unsqueeze(-1)
12 |
13 | def cosine_cutoff(edge_dist: torch.Tensor, cutoff: float):
14 | """
15 | Calculate cutoff value based on distance.
 16 | This uses the cosine Behler-Parrinello cutoff function:
17 |
18 | f(d) = 0.5*(cos(pi*d/d_cut)+1) for d < d_cut and 0 otherwise
19 | """
20 |
21 | return torch.where(
22 | edge_dist < cutoff,
23 | 0.5 * (torch.cos(torch.pi * edge_dist / cutoff) + 1),
24 | torch.tensor(0.0, device=edge_dist.device, dtype=edge_dist.dtype),
25 | )
26 |
27 | class PainnMessage(nn.Module):
28 | """Message function"""
29 | def __init__(self, node_size: int, edge_size: int, cutoff: float):
30 | super().__init__()
31 |
32 | self.edge_size = edge_size
33 | self.node_size = node_size
34 | self.cutoff = cutoff
35 |
36 | self.scalar_message_mlp = nn.Sequential(
37 | nn.Linear(node_size, node_size),
38 | nn.SiLU(),
39 | nn.Linear(node_size, node_size * 3),
40 | )
41 |
42 | self.filter_layer = nn.Linear(edge_size, node_size * 3)
43 |
44 | def forward(self, node_scalar, node_vector, edge, edge_diff, edge_dist):
45 | # remember to use v_j, s_j but not v_i, s_i
46 | filter_weight = self.filter_layer(sinc_expansion(edge_dist, self.edge_size, self.cutoff))
47 | filter_weight = filter_weight * cosine_cutoff(edge_dist, self.cutoff).unsqueeze(-1)
48 | scalar_out = self.scalar_message_mlp(node_scalar)
49 | filter_out = filter_weight * scalar_out[edge[:, 1]]
50 |
51 | gate_state_vector, gate_edge_vector, message_scalar = torch.split(
52 | filter_out,
53 | self.node_size,
54 | dim = 1,
55 | )
56 |
57 | # num_pairs * 3 * node_size, num_pairs * node_size
58 | message_vector = node_vector[edge[:, 1]] * gate_state_vector.unsqueeze(1)
59 | edge_vector = gate_edge_vector.unsqueeze(1) * (edge_diff / edge_dist.unsqueeze(-1)).unsqueeze(-1)
60 | message_vector = message_vector + edge_vector
61 |
62 | # sum message
63 | residual_scalar = torch.zeros_like(node_scalar)
64 | residual_vector = torch.zeros_like(node_vector)
65 | residual_scalar.index_add_(0, edge[:, 0], message_scalar)
66 | residual_vector.index_add_(0, edge[:, 0], message_vector)
67 |
68 | # new node state
69 | new_node_scalar = node_scalar + residual_scalar
70 | new_node_vector = node_vector + residual_vector
71 |
72 | return new_node_scalar, new_node_vector
73 |
74 | class PainnUpdate(nn.Module):
75 | """Update function"""
76 | def __init__(self, node_size: int):
77 | super().__init__()
78 |
79 | self.update_U = nn.Linear(node_size, node_size)
80 | self.update_V = nn.Linear(node_size, node_size)
81 |
82 | self.update_mlp = nn.Sequential(
83 | nn.Linear(node_size * 2, node_size),
84 | nn.SiLU(),
85 | nn.Linear(node_size, node_size * 3),
86 | )
87 |
88 | def forward(self, node_scalar, node_vector):
89 | Uv = self.update_U(node_vector)
90 | Vv = self.update_V(node_vector)
91 |
92 | Vv_norm = torch.linalg.norm(Vv, dim=1)
93 | mlp_input = torch.cat((Vv_norm, node_scalar), dim=1)
94 | mlp_output = self.update_mlp(mlp_input)
95 |
96 | a_vv, a_sv, a_ss = torch.split(
97 | mlp_output,
98 | node_vector.shape[-1],
99 | dim = 1,
100 | )
101 |
102 | delta_v = a_vv.unsqueeze(1) * Uv
103 | inner_prod = torch.sum(Uv * Vv, dim=1)
104 | delta_s = a_sv * inner_prod + a_ss
105 |
106 | return node_scalar + delta_s, node_vector + delta_v
107 |
108 | class PainnModel(nn.Module):
109 | """PainnModel without edge updating"""
110 | def __init__(
111 | self,
112 | num_interactions,
113 | hidden_state_size,
114 | cutoff,
115 | normalization=True,
116 | target_mean=[0.0],
117 | target_stddev=[1.0],
118 | atomwise_normalization=True,
119 | **kwargs,
120 | ):
121 | super().__init__()
122 |
123 | num_embedding = 119 # number of all elements
124 | self.cutoff = cutoff
125 | self.num_interactions = num_interactions
126 | self.hidden_state_size = hidden_state_size
127 | self.edge_embedding_size = 20
128 |
129 | # Setup atom embeddings
130 | self.atom_embedding = nn.Embedding(num_embedding, hidden_state_size)
131 |
132 | # Setup message-passing layers
133 | self.message_layers = nn.ModuleList(
134 | [
135 | PainnMessage(self.hidden_state_size, self.edge_embedding_size, self.cutoff)
136 | for _ in range(self.num_interactions)
137 | ]
138 | )
139 | self.update_layers = nn.ModuleList(
140 | [
141 | PainnUpdate(self.hidden_state_size)
142 | for _ in range(self.num_interactions)
143 | ]
144 | )
145 |
146 | # Setup readout function
147 | self.readout_mlp = nn.Sequential(
148 | nn.Linear(self.hidden_state_size, self.hidden_state_size),
149 | nn.SiLU(),
150 | nn.Linear(self.hidden_state_size, 1),
151 | )
152 |
153 | # Normalisation constants
154 | self.normalization = torch.nn.Parameter(
155 | torch.tensor(normalization), requires_grad=False
156 | )
157 | self.atomwise_normalization = torch.nn.Parameter(
158 | torch.tensor(atomwise_normalization), requires_grad=False
159 | )
160 | self.normalize_stddev = torch.nn.Parameter(
161 | torch.tensor(target_stddev[0]), requires_grad=False
162 | )
163 | self.normalize_mean = torch.nn.Parameter(
164 | torch.tensor(target_mean[0]), requires_grad=False
165 | )
166 |
167 | def forward(self, input_dict, compute_forces=True):
168 | num_atoms = input_dict['num_atoms']
169 | num_pairs = input_dict['num_pairs']
170 |
171 | # edge offset. Add offset to edges to get indices of pairs in a batch but not a structure
172 | edge = input_dict['pairs']
173 | edge_offset = torch.cumsum(
174 | torch.cat((torch.tensor([0],
175 | device=num_atoms.device,
176 | dtype=num_atoms.dtype,
177 | ), num_atoms[:-1])),
178 | dim=0
179 | )
180 | edge_offset = torch.repeat_interleave(edge_offset, num_pairs)
181 | edge = edge + edge_offset.unsqueeze(-1)
182 | edge_diff = input_dict['n_diff']
183 | if compute_forces:
184 | edge_diff.requires_grad_()
185 | edge_dist = torch.linalg.norm(edge_diff, dim=1)
186 |
187 | node_scalar = self.atom_embedding(input_dict['elems'])
188 | node_vector = torch.zeros((input_dict['coord'].shape[0], 3, self.hidden_state_size),
189 | device=edge_diff.device,
190 | dtype=edge_diff.dtype,
191 | )
192 |
193 | for message_layer, update_layer in zip(self.message_layers, self.update_layers):
194 | node_scalar, node_vector = message_layer(node_scalar, node_vector, edge, edge_diff, edge_dist)
195 | node_scalar, node_vector = update_layer(node_scalar, node_vector)
196 |
197 | node_scalar = self.readout_mlp(node_scalar)
198 | node_scalar.squeeze_()
199 |
200 | image_idx = torch.arange(input_dict['num_atoms'].shape[0],
201 | device=edge.device,
202 | )
203 | image_idx = torch.repeat_interleave(image_idx, num_atoms)
204 |
205 | energy = torch.zeros_like(input_dict['num_atoms']).float()
206 | energy.index_add_(0, image_idx, node_scalar)
207 |
208 | # Apply (de-)normalization
209 | if self.normalization:
210 | normalizer = self.normalize_stddev
211 | energy = normalizer * energy
212 | mean_shift = self.normalize_mean
213 | if self.atomwise_normalization:
214 | mean_shift = input_dict["num_atoms"] * mean_shift
215 | energy = energy + mean_shift
216 |
217 | result_dict = {'energy': energy}
218 |
219 | if compute_forces:
220 | dE_ddiff = torch.autograd.grad(
221 | energy,
222 | edge_diff,
223 | grad_outputs=torch.ones_like(energy),
224 | retain_graph=True,
225 | create_graph=True,
226 | )[0]
227 |
228 | # diff = R_j - R_i, so -dE/dR_j = -dE/ddiff and -dE/dR_i = dE/ddiff
229 | i_forces = torch.zeros_like(input_dict['coord']).index_add(0, edge[:, 0], dE_ddiff)
230 | j_forces = torch.zeros_like(input_dict['coord']).index_add(0, edge[:, 1], -dE_ddiff)
231 | forces = i_forces + j_forces
232 |
233 | result_dict['forces'] = forces
234 |
235 | return result_dict
236 |
237 | class PainnModel_predict(nn.Module):
238 | """PainnModel without edge updating"""
239 | def __init__(self, num_interactions, hidden_state_size, cutoff, **kwargs):
240 | super().__init__()
241 |
242 | num_embedding = 119 # number of all elements
243 | self.atom_embedding = nn.Embedding(num_embedding, hidden_state_size)
244 | self.cutoff = cutoff
245 | self.num_interactions = num_interactions
246 | self.hidden_state_size = hidden_state_size
247 | self.edge_embedding_size = 20
248 |
249 | self.message_layers = nn.ModuleList(
250 | [
251 | PainnMessage(self.hidden_state_size, self.edge_embedding_size, self.cutoff)
252 | for _ in range(self.num_interactions)
253 | ]
254 | )
255 |
256 | self.update_layers = nn.ModuleList(
257 | [
258 | PainnUpdate(self.hidden_state_size)
259 | for _ in range(self.num_interactions)
260 | ]
261 | )
262 |
263 | self.linear_1 = nn.Linear(self.hidden_state_size, self.hidden_state_size)
264 | self.silu = nn.SiLU()
265 | self.linear_2 = nn.Linear(self.hidden_state_size, 1)
266 | U_in_0 = torch.randn(self.hidden_state_size, 500) / 500 ** 0.5
267 | U_out_1 = torch.randn(self.hidden_state_size, 500) / 500 ** 0.5
268 | U_in_1 = torch.randn(self.hidden_state_size, 500) / 500 ** 0.5
269 | self.register_buffer('U_in_0', U_in_0)
270 | self.register_buffer('U_out_1', U_out_1)
271 | self.register_buffer('U_in_1', U_in_1)
272 |
273 | def forward(self, input_dict, compute_forces=True):
274 | # edge offset
275 | num_atoms = input_dict['num_atoms']
276 | num_pairs = input_dict['num_pairs']
277 |
278 | edge = input_dict['pairs']
279 | edge_offset = torch.cumsum(
280 | torch.cat((torch.tensor([0],
281 | device=num_atoms.device,
282 | dtype=num_atoms.dtype,
283 | ), num_atoms[:-1])),
284 | dim=0
285 | )
286 | edge_offset = torch.repeat_interleave(edge_offset, num_pairs)
287 | edge = edge + edge_offset.unsqueeze(-1)
288 | edge_diff = input_dict['n_diff']
289 | if compute_forces:
290 | edge_diff.requires_grad_()
291 | edge_dist = torch.linalg.norm(edge_diff, dim=1)
292 |
293 | node_scalar = self.atom_embedding(input_dict['elems'])
294 | node_vector = torch.zeros((input_dict['coord'].shape[0], 3, self.hidden_state_size),
295 | device=edge_diff.device,
296 | dtype=edge_diff.dtype,
297 | )
298 |
299 | for message_layer, update_layer in zip(self.message_layers, self.update_layers):
300 | node_scalar, node_vector = message_layer(node_scalar, node_vector, edge, edge_diff, edge_dist)
301 | node_scalar, node_vector = update_layer(node_scalar, node_vector)
302 |
303 | x0 = node_scalar
304 | z1 = self.linear_1(x0)
305 | z1.retain_grad()
306 | x1 = self.silu(z1)
307 | node_scalar = self.linear_2(x1)
308 |
309 | node_scalar.squeeze_()
310 |
311 | image_idx = torch.arange(input_dict['num_atoms'].shape[0],
312 | device=edge.device,
313 | )
314 | image_idx = torch.repeat_interleave(image_idx, num_atoms)
315 |
316 | energy = torch.zeros_like(input_dict['num_atoms']).float()
317 |
318 | energy.index_add_(0, image_idx, node_scalar)
319 | result_dict = {'energy': energy}
320 |
321 | if compute_forces:
322 | dE_ddiff = torch.autograd.grad(
323 | energy,
324 | edge_diff,
325 | grad_outputs=torch.ones_like(energy),
326 | retain_graph=True,
327 | create_graph=True,
328 | )[0]
329 |
330 | # diff = R_j - R_i, so -dE/dR_j = -dE/ddiff and -dE/dR_i = dE/ddiff
331 | i_forces = torch.zeros_like(input_dict['coord']).index_add(0, edge[:, 0], dE_ddiff)
332 | j_forces = torch.zeros_like(input_dict['coord']).index_add(0, edge[:, 1], -dE_ddiff)
333 | forces = i_forces + j_forces
334 |
335 | result_dict['forces'] = forces
336 |
337 | fps = torch.sum((x0.detach() @ self.U_in_0) * (z1.grad.detach() @ self.U_out_1) * 500 ** 0.5 + x1.detach() @ self.U_in_1, dim=0)
338 | # result_dict['ll_out'] = {
339 | # 'll_out_x0': x0.detach(),
340 | # 'll_out_z1': z1.grad.detach(),
341 | # 'll_out_x1': x1.detach(),
342 | # }
343 | result_dict['fps'] = fps
344 | del z1.grad
345 | return result_dict
346 |
--------------------------------------------------------------------------------
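A minimal forward-pass sketch for `PainnModel`, using `AseDataReader` to build the input dictionary for a single periodic structure. The copper cell is an arbitrary example; `asap3` is required for the periodic neighbor list.

```python
import torch
from ase.build import bulk
from PaiNN.data import AseDataReader
from PaiNN.model import PainnModel

atoms = bulk("Cu", "fcc", a=3.6).repeat((2, 2, 2))
inputs = AseDataReader(cutoff=5.0)(atoms)

model = PainnModel(num_interactions=3, hidden_state_size=128, cutoff=5.0)
results = model(inputs)  # compute_forces=True by default
print(results["energy"].shape, results["forces"].shape)  # (1,) and (n_atoms, 3)
```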
/PaiNN/select.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from PaiNN.kernel import KernelMatrix
3 |
4 | def max_diag(matrix: KernelMatrix, batch_size: int) -> torch.Tensor:
5 | """
6 | maximize uncertainty selection method
7 | """
8 | return torch.argsort(matrix.get_diag())[-batch_size:]
9 |
10 | def max_det_greedy(matrix: KernelMatrix, batch_size: int) -> torch.Tensor:
11 | vec_c = matrix.get_diag()
12 | batch_idxs = [torch.argmax(vec_c)]
13 |
14 | l_n = None
15 |
16 | for n in range(1, batch_size):
17 | opt_idx = batch_idxs[-1]
18 | l_n_T_l_n = 0.0 if l_n is None else torch.einsum('w,wc->c', l_n[:, opt_idx], l_n)
19 | mat_col = matrix.get_column(opt_idx)
20 | update = (1 / torch.sqrt(vec_c[opt_idx])) * (mat_col - l_n_T_l_n)
21 | vec_c = vec_c - update ** 2
22 | l_n = update.unsqueeze(0) if l_n is None else torch.concat((l_n, update.unsqueeze(0)))
23 | new_idx = torch.argmax(vec_c)
24 | if vec_c[new_idx] <= 1e-12 or new_idx in batch_idxs:
25 | break
26 | else:
27 | batch_idxs.append(new_idx)
28 |
29 | batch_idxs = torch.hstack(batch_idxs)
30 | return batch_idxs
31 |
32 | def max_det_greedy_local(matrix: KernelMatrix, batch_size: int, num_atoms: torch.Tensor) -> torch.Tensor:
33 | vec_c = matrix.get_diag()
34 | batch_idxs = [torch.argmax(vec_c)]
35 |
36 | l_n = None
37 | image_idx = torch.arange(
38 | num_atoms.shape[0],
39 | device=num_atoms.device,
40 | )
41 | image_idx = torch.repeat_interleave(image_idx, num_atoms)
42 |
43 | selected_idx = []
44 | n = 0
45 | while len(selected_idx) < batch_size:
46 | opt_idx = batch_idxs[-1]
47 | l_n_T_l_n = 0.0 if l_n is None else torch.einsum('w,wc->c', l_n[:, opt_idx], l_n)
48 | mat_col = matrix.get_column(opt_idx)
49 | update = (1 / torch.sqrt(vec_c[opt_idx])) * (mat_col - l_n_T_l_n)
50 | vec_c = vec_c - update ** 2
51 | l_n = update.unsqueeze(0) if l_n is None else torch.concat((l_n, update.unsqueeze(0)))
52 | new_idx = torch.argmax(vec_c)
53 | if vec_c[new_idx] <= 1e-12 or new_idx in batch_idxs:
54 | break
55 | else:
56 | batch_idxs.append(new_idx)
57 | if image_idx[new_idx] not in selected_idx:
58 | selected_idx.append(image_idx[new_idx])
59 |
60 | return torch.stack(selected_idx)
61 |
62 | def lcmd_greedy(matrix: KernelMatrix, batch_size: int, n_train: int) -> torch.Tensor:
63 | """
64 | Only accept matrix with double dtype!!!
65 | Selects batch elements by greedily picking those with the maximum distance in the largest cluster,
66 | including training points. Assumes that the last ``n_train`` columns of ``matrix`` correspond to training points.
67 |
68 | :param matrix: Kernel matrix.
69 | :param batch_size: Size of the AL batch.
70 | :param n_train: Number of training structures.
71 | :return: Indices of the selected structures.
72 | """
73 | # assumes that the matrix contains pool samples, optionally followed by train samples
74 | n_pool = matrix.get_number_of_columns() - n_train
75 | sq_dists = matrix.get_diag()
76 | batch_idxs = [n_pool if n_train > 0 else torch.argmax(sq_dists)]
77 | closest_idxs = torch.zeros(n_pool, dtype=int, device=sq_dists.device)
78 | min_sq_dists = matrix.get_sq_dists(batch_idxs[-1])[:n_pool]
79 |
80 | for i in range(1, batch_size + n_train):
81 | if i < n_train:
82 | batch_idxs.append(n_pool+i)
83 | else:
84 | bincount = torch.bincount(closest_idxs, weights=min_sq_dists, minlength=i)
85 | max_bincount = torch.max(bincount)
86 | new_idx = torch.argmax(torch.where(
87 | torch.gather(bincount, 0, closest_idxs) == max_bincount,
88 | min_sq_dists,
89 | torch.zeros_like(min_sq_dists)-float("Inf")))
90 | batch_idxs.append(new_idx)
91 | sq_dists = matrix.get_sq_dists(batch_idxs[-1])[:n_pool]
92 | new_min = sq_dists < min_sq_dists
93 | closest_idxs = torch.where(new_min, i, closest_idxs)
94 | min_sq_dists = torch.where(new_min, sq_dists, min_sq_dists)
95 |
96 | return torch.hstack(batch_idxs[n_train:])
97 |
 98 | def deterministic_CUR(matrix: KernelMatrix, batch_size: int, lambd: float=0.1, epsilon: float=1E-3) -> torch.Tensor:
99 | """
100 | CUR matrix decomposition, the matrix must be normalized.
101 | """
102 | n = matrix.num_columns
103 | W = torch.zeros(n, n)
104 | I = torch.eye(n, n)
105 | while True:
106 | W_t = W.clone()
107 | for i in range(matrix.num_columns):
108 | z = matrix.get_column(i) @ (I - W) + matrix.get_diag()[i] * W[i]
109 | coeff = 1 - lambd / torch.linalg.norm(z)
110 | W[i] = coeff * z if coeff > 0 else 0 * z
111 | if torch.linalg.norm(W - W_t) < epsilon:
112 | break
113 |
114 | return torch.argsort(torch.linalg.norm(W, dim=1))[-batch_size:]
115 |
--------------------------------------------------------------------------------
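A toy sketch of the greedy selection routines on a random feature kernel (the tensor below is synthetic):

```python
import torch
from PaiNN.kernel import FeatureKernelMatrix
from PaiNN.select import max_diag, max_det_greedy

g = torch.randn(2, 50, 32, dtype=torch.double)  # 2 models, 50 structures, 32 features
matrix = FeatureKernelMatrix(g)

print(max_diag(matrix, batch_size=5))        # 5 structures with the largest diagonal entries
print(max_det_greedy(matrix, batch_size=5))  # up to 5 structures picked greedily
```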
/README.md:
--------------------------------------------------------------------------------
  1 | # PaiNN-model introduction
  2 | This is a simple implementation of the [PaiNN](https://arxiv.org/abs/2102.03150) model and an active learning workflow for fitting interatomic potentials.
3 | The learned features or gradients in the model are used for active learning. Several selection methods are implemented.
  4 | All of the active learning code still needs to be tested.
5 | ## Documentation
6 | No documentation yet.
7 |
8 | ## Quick Start
9 |
10 | How to install
11 |
 12 | This code has only been tested with [**Python>=3.8.0**](https://www.python.org/) and [**PyTorch>=1.10**](https://pytorch.org/get-started/locally/).
 13 | Requirements: [PyTorch Scatter](https://github.com/rusty1s/pytorch_scatter) (if you want to use active learning),
 14 | [toml](https://toml.io/en/), and [myqueue](https://myqueue.readthedocs.io/en/latest/installation.html) (if you want to submit jobs automatically).
15 |
16 | ```bash
17 | $ conda install pytorch-scatter -c pyg
18 | $ conda install -c conda-forge toml
19 | $ python3 -m pip install myqueue
20 | $ conda install pytorch torchvision torchaudio pytorch-cuda=11.6 -c pytorch -c nvidia
21 | $ git clone https://github.com/Yangxinsix/PaiNN-model.git
22 | $ cd PaiNN-model
23 | $ python -m pip install -U .
24 | ```
25 |
26 |
27 |
28 |
29 | How to use
30 |
31 | * See `train.py` in `scripts` for training, and `md_run.py` for running MD simulations by using ASE.
32 | * See `al_select.py` for active learning.
33 | * See `flow.py` for distributing and submitting active learning jobs.
34 |
35 |
36 |
--------------------------------------------------------------------------------
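A minimal command-line sketch of the workflow described in the README above, assuming the dataset and split paths in `scripts/arguments.toml` have been adapted to your own files:

```bash
$ cd scripts
$ python3 train.py --cfg arguments.toml   # train a PaiNN model
$ python3 md_run.py                       # run an MD simulation with the trained model
```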
/scripts/MD.traj:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nityasagarjena/PaiNN-model/6ee9a59c3cd544b5e31d4936cb1e75e9bded6a6e/scripts/MD.traj
--------------------------------------------------------------------------------
/scripts/arguments.toml:
--------------------------------------------------------------------------------
1 | node_size = 40
2 | num_interactions = 5
3 | cutoff = 5.0
4 | split_file = "/home/energy/xinyang/work/active_learning_test/datasplits.json"
5 | output_dir = "model_output"
6 | dataset = "/home/energy/xinyang/work/active_learning_test/md17aspirin.traj"
7 | max_steps = 1000000
8 | device = "cuda"
9 | batch_size = 32
10 | initial_lr = 0.0001
11 | forces_weight = 0.99
12 | log_interval = 1000
13 | normalization = true
14 | atomwise_normalization = true
15 | stop_tolerance = 10
16 |
--------------------------------------------------------------------------------
/scripts/gpu_info:
--------------------------------------------------------------------------------
1 | Mon Aug 1 11:30:57 2022
2 | +-----------------------------------------------------------------------------+
3 | | NVIDIA-SMI 470.74 Driver Version: 470.74 CUDA Version: 11.4 |
4 | |-------------------------------+----------------------+----------------------+
5 | | GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
6 | | Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
7 | | | | MIG M. |
8 | |===============================+======================+======================|
9 | | 0 NVIDIA GeForce ... On | 00000000:3F:00.0 Off | N/A |
10 | | 30% 39C P8 24W / 350W | 1MiB / 24268MiB | 0% Default |
11 | | | | N/A |
12 | +-------------------------------+----------------------+----------------------+
13 |
14 | +-----------------------------------------------------------------------------+
15 | | Processes: |
16 | | GPU GI CI PID Type Process name GPU Memory |
17 | | ID ID Usage |
18 | |=============================================================================|
19 | | No running processes found |
20 | +-----------------------------------------------------------------------------+
21 |
--------------------------------------------------------------------------------
/scripts/gpu_run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -ex
2 |
3 | #SBATCH --mail-user=xinyang@dtu.dk
4 | #SBATCH --mail-type=END,FAIL
5 | #SBATCH --partition=sm3090
6 | #SBATCH -N 1 # Minimum of 1 node
  7 | #SBATCH -n 8 # 8 MPI processes per node
8 | #SBATCH --time=7-00:00:00
  9 | #SBATCH --job-name=PaiNN-training
10 | #SBATCH --output=runner_output.log
11 | #SBATCH --gres=gpu:RTX3090:1
12 |
13 | #module load ASE/3.22.0-intel-2020b
14 | #module load Python/3.8.6-GCCcore-10.2.0
15 |
16 | export MKL_NUM_THREADS=1
17 | export NUMEXPR_NUM_THREADS=1
18 | export OMP_NUM_THREADS=1
19 | export OPENBLAS_NUM_THREADS=1
20 |
21 | nvidia-smi > gpu_info
22 | ulimit -s unlimited
23 | python3 md_run.py
24 |
--------------------------------------------------------------------------------
/scripts/md_run.py:
--------------------------------------------------------------------------------
1 | from ase.md.langevin import Langevin
2 | from ase.calculators.plumed import Plumed
3 | from ase import units
4 | from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
5 | from ase.io import read, write, Trajectory
6 |
7 | import numpy as np
8 | import torch
9 | import sys
10 | import glob
11 |
12 | from PaiNN.data import AseDataset, collate_atomsdata
13 | from PaiNN.model import PainnModel_predict
14 | from PaiNN.calculator import MLCalculator
15 | from ase.constraints import FixAtoms
16 |
17 | # load model
18 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19 | # model_pth = glob.glob('/home/energy/xinyang/work/Au_MD/graphnn/ads_images/ensembles/*_layer/runs/model_outputs/best_model.pth')
20 | # # models = []
21 | # for each in model_pth:
22 | # node_size = int(each.split('/')[-4].split('_')[0])
23 | # num_inter = int(each.split('/')[-4].split('_')[2])
24 | # model = PainnModel(num_interactions=num_inter, hidden_state_size=node_size, cutoff=5.0)
25 | # model.to(device)
26 | # state_dict = torch.load(each)
27 | # model.load_state_dict(state_dict["model"])
28 | # models.append(model)
29 | #
30 | # encalc = EnsembleCalculator(models)
31 |
32 | # set md parameters
33 | #dataset="/home/energy/xinyang/work/Au_MD/training_loop/Au_larger/dataset_selector/dataset_repository/corrected_ads_images.traj"
34 | #images = read(dataset, ':')
35 | #indices = [i for i in range(len(images)) if images[i].info['system'] == '1OH']
36 | #atoms = images[np.random.choice(indices)]
37 | #atoms = read('MD.traj', -1)
38 | #cons = FixAtoms(mask=atoms.positions[:, 2] < 5.9)
39 | #atoms.set_constraint(cons)
40 |
41 | model = PainnModel_predict(num_interactions=3, hidden_state_size=128, cutoff=5.0)
42 | model.to(device)
43 | state_dict = torch.load('/home/energy/xinyang/work/Au_MD/graphnn/pure_water/runs/model_outputs/best_model.pth')
44 | new_names = ["linear_1.weight", "linear_1.bias", "linear_2.weight", "linear_2.bias"]
45 | old_names = ["readout_mlp.0.weight", "readout_mlp.0.bias", "readout_mlp.2.weight", "readout_mlp.2.bias"]
46 | for old, new in zip(old_names, new_names):
47 | state_dict['model'][new] = state_dict['model'].pop(old)
48 |
49 | state_dict["model"]["U_in_0"] = torch.randn(128, 500) / 500 ** 0.5
50 | state_dict["model"]["U_out_1"] = torch.randn(128, 500) / 500 ** 0.5
51 | state_dict["model"]["U_in_1"] = torch.randn(128, 500) / 500 ** 0.5
52 | model.load_state_dict(state_dict["model"])
53 | mlcalc = MLCalculator(model)
54 |
55 | atoms = read('water_O2.cif')
56 | atoms.calc = mlcalc
57 | atoms.get_potential_energy()
58 |
59 | #collect_traj = Trajectory('bad_struct.traj', 'a')
60 | steps = 0
61 | def printenergy(a=atoms): # store a reference to atoms in the definition.
62 | """Function to print the potential, kinetic and total energy."""
63 | epot = a.get_potential_energy()
64 | ekin = a.get_kinetic_energy()
65 | temp = ekin / (1.5 * units.kB) / a.get_global_number_of_atoms()
66 | global steps
67 | steps += 1
68 | with open('ensemble.log', 'a') as f:
69 | f.write(
70 | f"Steps={steps:12.3f} Epot={epot:12.3f} Ekin={ekin:12.3f} temperature={temp:8.2f}\n")
71 |
72 | #atoms.calc = encalc
73 | MaxwellBoltzmannDistribution(atoms, temperature_K=350)
74 | dyn = Langevin(atoms, 0.25 * units.fs, temperature_K=350, friction=0.1)
75 | dyn.attach(printenergy, interval=1)
76 |
77 | traj = Trajectory('MD.traj', 'w', atoms)
78 | dyn.attach(traj.write, interval=400)
79 | dyn.run(10000000)
80 |
--------------------------------------------------------------------------------
/scripts/runner_output.log:
--------------------------------------------------------------------------------
1 | + '[' -z '' ']'
2 | + case "$-" in
3 | + __lmod_vx=x
4 | + '[' -n x ']'
5 | + set +x
6 | Shell debugging temporarily silenced: export LMOD_SH_DBG_ON=1 for this output (/usr/share/lmod/lmod/init/bash)
7 | Shell debugging restarted
8 | + unset __lmod_vx
9 | + export MKL_NUM_THREADS=1
10 | + MKL_NUM_THREADS=1
11 | + export NUMEXPR_NUM_THREADS=1
12 | + NUMEXPR_NUM_THREADS=1
13 | + export OMP_NUM_THREADS=1
14 | + OMP_NUM_THREADS=1
15 | + export OPENBLAS_NUM_THREADS=1
16 | + OPENBLAS_NUM_THREADS=1
17 | + nvidia-smi
18 | + ulimit -s unlimited
19 | + python3 md_run.py
20 | slurmstepd: error: *** JOB 5233695 ON s002 CANCELLED AT 2022-08-01T11:34:52 ***
21 |
--------------------------------------------------------------------------------
/scripts/train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 | import json, os, sys, toml
4 | from pathlib import Path
5 | import argparse
6 | import logging
7 | import itertools
8 | import torch
9 | import time
10 |
11 | from PaiNN.data import AseDataset, collate_atomsdata
12 | from PaiNN.model import PainnModel
13 |
14 | def get_arguments(arg_list=None):
15 | parser = argparse.ArgumentParser(
16 | description="Train graph convolution network", fromfile_prefix_chars="+"
17 | )
18 | parser.add_argument(
19 | "--load_model",
20 | type=str,
21 | help="Load model parameters from previous run",
22 | )
23 | parser.add_argument(
24 | "--cutoff",
25 | type=float,
26 | help="Atomic interaction cutoff distance [�~E]",
27 | )
28 | parser.add_argument(
29 | "--split_file",
30 | type=str,
31 | help="Train/test/validation split file json",
32 | )
33 | parser.add_argument(
34 | "--num_interactions",
35 | type=int,
36 | help="Number of interaction layers used",
37 | )
38 | parser.add_argument(
39 | "--node_size", type=int, help="Size of hidden node states"
40 | )
41 | parser.add_argument(
42 | "--output_dir",
43 | type=str,
44 | help="Path to output directory",
45 | )
46 | parser.add_argument(
47 | "--dataset", type=str, help="Path to ASE trajectory",
48 | )
49 | parser.add_argument(
50 | "--max_steps",
51 | type=int,
52 | help="Maximum number of optimisation steps",
53 | )
54 | parser.add_argument(
55 | "--device",
56 | type=str,
57 | help="Set which device to use for training e.g. 'cuda' or 'cpu'",
58 | )
59 | parser.add_argument(
60 | "--batch_size", type=int, help="Number of molecules per minibatch",
61 | )
62 | parser.add_argument(
63 | "--initial_lr", type=float, help="Initial learning rate",
64 | )
65 | parser.add_argument(
66 | "--forces_weight",
67 | type=float,
68 | help="Tradeoff between training on forces (weight=1) and energy (weight=0)",
69 | )
70 | parser.add_argument(
71 | "--log_inverval",
72 | type=int,
73 | help="The interval of model evaluation",
74 | )
75 | parser.add_argument(
76 | "--normalization",
77 | action="store_true",
78 | help="Enable normalization of the model",
79 | )
80 | parser.add_argument(
81 | "--atomwise_normalization",
82 | action="store_true",
83 | help="Enable atomwise normalization",
84 | )
85 | parser.add_argument(
86 | "--stop_tolerance",
87 | type=int,
88 | help="Stop training when validation loss is larger than best loss for 'stop_tolerance' steps",
89 | )
90 | parser.add_argument(
91 | "--cfg",
92 | type=str,
93 | help="Path to config file. e.g. 'arguments.toml'"
94 | )
95 |
96 | return parser.parse_args(arg_list)
97 |
98 | def split_data(dataset, args):
99 | # Load or generate splits
100 | if args.split_file:
101 | with open(args.split_file, "r") as fp:
102 | splits = json.load(fp)
103 | else:
104 | datalen = len(dataset)
105 | num_validation = int(math.ceil(datalen * 0.10))
106 | indices = np.random.permutation(len(dataset))
107 | splits = {
108 | "train": indices[num_validation:].tolist(),
109 | "validation": indices[:num_validation].tolist(),
110 | }
111 |
112 | # Save split file
113 | with open(os.path.join(args.output_dir, "datasplits.json"), "w") as f:
114 | json.dump(splits, f)
115 |
116 | # Split the dataset
117 | datasplits = {}
118 | for key, indices in splits.items():
119 | datasplits[key] = torch.utils.data.Subset(dataset, indices)
120 | return datasplits
121 |
122 | def forces_criterion(predicted, target, reduction="mean"):
123 | # predicted, target are (bs, max_nodes, 3) tensors
124 | # node_count is (bs) tensor
125 | diff = predicted - target
126 | total_squared_norm = torch.linalg.norm(diff, dim=1) # bs
127 | if reduction == "mean":
128 | scalar = torch.mean(total_squared_norm)
129 | elif reduction == "sum":
130 | scalar = torch.sum(total_squared_norm)
131 | else:
132 | raise ValueError("Reduction must be 'mean' or 'sum'")
133 | return scalar
134 |
135 | def get_normalization(dataset, per_atom=True):
136 | # Use double precision to avoid overflows
137 | x_sum = torch.zeros(1, dtype=torch.double)
138 | x_2 = torch.zeros(1, dtype=torch.double)
139 | num_objects = 0
140 | for i, sample in enumerate(dataset):
141 | if i == 0:
142 | # Estimate "bias" from 1 sample
143 | # to avoid overflows for large valued datasets
144 | if per_atom:
145 | bias = sample["energy"] / sample["num_atoms"]
146 | else:
147 | bias = sample["energy"]
148 | x = sample["energy"]
149 | if per_atom:
150 | x = x / sample["num_atoms"]
151 | x -= bias
152 | x_sum += x
153 | x_2 += x ** 2.0
154 | num_objects += 1
155 | # Var(X) = E[X^2] - E[X]^2
156 | x_mean = x_sum / num_objects
157 | x_var = x_2 / num_objects - x_mean ** 2.0
158 | x_mean = x_mean + bias
159 |
160 | default_type = torch.get_default_dtype()
161 |
162 | return x_mean.type(default_type), torch.sqrt(x_var).type(default_type)
163 |
164 | def eval_model(model, dataloader, device, forces_weight):
165 | energy_running_ae = 0
166 | energy_running_se = 0
167 |
168 | forces_running_l2_ae = 0
169 | forces_running_l2_se = 0
170 | forces_running_c_ae = 0
171 | forces_running_c_se = 0
172 | forces_running_loss = 0
173 |
174 | running_loss = 0
175 | count = 0
176 | forces_count = 0
177 | criterion = torch.nn.MSELoss()
178 |
179 | for batch in dataloader:
180 | device_batch = {
181 | k: v.to(device=device, non_blocking=True) for k, v in batch.items()
182 | }
183 | out = model(device_batch)
184 |
185 | # counts
186 | count += batch["energy"].shape[0]
187 | forces_count += batch['forces'].shape[0]
188 |
189 | # use mean square loss here
190 | forces_loss = forces_criterion(out["forces"], device_batch["forces"]).item()
191 | energy_loss = criterion(out["energy"], device_batch["energy"]).item() #problem here
192 | total_loss = forces_weight * forces_loss + (1 - forces_weight) * energy_loss
193 | running_loss += total_loss * batch["energy"].shape[0]
194 |
195 | # energy errors
196 | outputs = {key: val.detach().cpu().numpy() for key, val in out.items()}
197 | energy_targets = batch["energy"].detach().cpu().numpy()
198 | energy_running_ae += np.sum(np.abs(energy_targets - outputs["energy"]), axis=0)
199 | energy_running_se += np.sum(
200 | np.square(energy_targets - outputs["energy"]), axis=0
201 | )
202 |
203 | # force errors
204 | forces_targets = batch["forces"].detach().cpu().numpy()
205 | forces_diff = forces_targets - outputs["forces"]
206 | forces_l2_norm = np.sqrt(np.sum(np.square(forces_diff), axis=1))
207 |
208 | forces_running_c_ae += np.sum(np.abs(forces_diff))
209 | forces_running_c_se += np.sum(np.square(forces_diff))
210 |
211 | forces_running_l2_ae += np.sum(np.abs(forces_l2_norm))
212 | forces_running_l2_se += np.sum(np.square(forces_l2_norm))
213 |
214 | energy_mae = energy_running_ae / count
215 | energy_rmse = np.sqrt(energy_running_se / count)
216 |
217 | forces_l2_mae = forces_running_l2_ae / forces_count
218 | forces_l2_rmse = np.sqrt(forces_running_l2_se / forces_count)
219 |
220 | forces_c_mae = forces_running_c_ae / (forces_count * 3)
221 | forces_c_rmse = np.sqrt(forces_running_c_se / (forces_count * 3))
222 |
223 | total_loss = running_loss / count
224 |
225 | evaluation = {
226 | "energy_mae": energy_mae,
227 | "energy_rmse": energy_rmse,
228 | "forces_l2_mae": forces_l2_mae,
229 | "forces_l2_rmse": forces_l2_rmse,
230 | "forces_mae": forces_c_mae,
231 | "forces_rmse": forces_c_rmse,
232 | "sqrt(total_loss)": np.sqrt(total_loss),
233 | }
234 |
235 | return evaluation
236 |
237 | def update_namespace(ns, d):
238 | for k, v in d.items():
239 | if not ns.__dict__.get(k):
240 | ns.__dict__[k] = v
241 |
242 | class EarlyStopping():
243 | def __init__(self, tolerance=5, min_delta=0):
244 |
245 | self.tolerance = tolerance
246 | self.min_delta = min_delta
247 | self.counter = 0
248 | self.early_stop = False
249 |
250 | def __call__(self, val_loss, best_loss):
251 | if best_loss < 1.0 and (val_loss - best_loss) > self.min_delta:
252 | self.counter +=1
253 | if self.counter >= self.tolerance:
254 | self.early_stop = True
255 |
256 | return self.early_stop
257 |
258 | def main():
259 | args = get_arguments()
260 | if args.cfg:
261 | with open(args.cfg, 'r') as f:
262 | params = toml.load(f)
263 | update_namespace(args, params)
264 |
265 | # Setup logging
266 | os.makedirs(args.output_dir, exist_ok=True)
267 | logging.basicConfig(
268 | level=logging.DEBUG,
269 | format="%(asctime)s [%(levelname)-5.5s] %(message)s",
270 | handlers=[
271 | logging.FileHandler(
272 | os.path.join(args.output_dir, "printlog.txt"), mode="w"
273 | ),
274 | logging.StreamHandler(),
275 | ],
276 | )
277 |
278 | # Save command line args
279 | with open(os.path.join(args.output_dir, "commandline_args.txt"), "w") as f:
280 | f.write("\n".join(sys.argv[1:]))
281 | # Save parsed command line arguments
282 | with open(os.path.join(args.output_dir, "arguments.json"), "w") as f:
283 | json.dump(vars(args), f)
284 |
285 | # Create device
286 | device = torch.device(args.device)
287 | # Put a tensor on the device before loading data
288 | # This way the GPU appears to be in use when other users run gpustat
289 | torch.tensor([0], device=device)
290 |
291 | # Setup dataset and loader
292 | logging.info("loading data %s", args.dataset)
293 | dataset = AseDataset(
294 | args.dataset,
295 | cutoff = args.cutoff,
296 | )
297 |
298 | with open(args.split_file, 'r') as f:
299 | splits = json.load(f)
300 |
301 | datasplits = {
302 | 'train': torch.utils.data.Subset(dataset, splits['train']),
303 | 'validation': torch.utils.data.Subset(dataset, splits['validation']),
304 | }
305 |
306 | train_loader = torch.utils.data.DataLoader(
307 | datasplits["train"],
308 | args.batch_size,
309 | sampler=torch.utils.data.RandomSampler(datasplits["train"]),
310 | collate_fn=collate_atomsdata,
311 | )
312 | val_loader = torch.utils.data.DataLoader(
313 | datasplits["validation"],
314 | args.batch_size,
315 | collate_fn=collate_atomsdata,
316 | )
317 |
318 | logging.info("Computing mean and variance")
319 | target_mean, target_stddev = get_normalization(
320 | datasplits["train"],
321 | per_atom=args.atomwise_normalization,
322 | )
323 | logging.debug("target_mean=%f, target_stddev=%f" % (target_mean, target_stddev))
324 |
325 | net = PainnModel(
326 | num_interactions=args.num_interactions,
327 | hidden_state_size=args.node_size,
328 | cutoff=args.cutoff,
329 | normalization=args.normalization,
330 | target_mean=target_mean.tolist(),
331 | target_stddev=target_stddev.tolist(),
332 | atomwise_normalization=args.atomwise_normalization,
333 | )
334 | net.to(device)
335 |
336 | optimizer = torch.optim.Adam(net.parameters(), lr=args.initial_lr)
337 | criterion = torch.nn.MSELoss()
338 | scheduler_fn = lambda step: 0.96 ** (step / 100000)
339 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_fn)
340 | early_stop = EarlyStopping(tolerance=args.stop_tolerance)
341 |
342 | running_loss = 0
343 | running_loss_count = 0
344 | best_val_loss = np.inf
345 | step = 0
346 | training_time = 0
347 |
348 | if args.load_model:
349 | state_dict = torch.load(args.load_model)
350 | net.load_state_dict(state_dict["model"])
351 | # step = state_dict["step"]
352 | best_val_loss = state_dict["best_val_loss"]
353 | # optimizer.load_state_dict(state_dict["optimizer"])
354 | # scheduler.load_state_dict(state_dict["scheduler"])
355 |
356 | for epoch in itertools.count():
357 | for batch_host in train_loader:
358 | start = time.time()
359 | # Transfer to 'device'
360 | batch = {
361 | k: v.to(device=device, non_blocking=True)
362 | for (k, v) in batch_host.items()
363 | }
364 | # Reset gradient
365 | optimizer.zero_grad()
366 |
367 | # Forward, backward and optimize
368 | outputs = net(
369 | batch, compute_forces=bool(args.forces_weight)
370 | )
371 | energy_loss = criterion(outputs["energy"], batch["energy"])
372 | if args.forces_weight:
373 | forces_loss = forces_criterion(outputs['forces'], batch['forces'])
374 | else:
375 | forces_loss = 0.0
376 | total_loss = (
377 | args.forces_weight * forces_loss
378 | + (1 - args.forces_weight) * energy_loss
379 | )
380 | total_loss.backward()
381 | optimizer.step()
382 | running_loss += total_loss.item() * batch["energy"].shape[0]
383 | running_loss_count += batch["energy"].shape[0]
384 | training_time += time.time() - start
385 |
386 | # print(step, loss_value)
387 | # Validate and save model
388 | if (step % args.log_interval == 0) or ((step + 1) == args.max_steps):
389 | eval_start = time.time()
390 | train_loss = running_loss / running_loss_count
391 | running_loss = running_loss_count = 0
392 |
393 | eval_dict = eval_model(net, val_loader, device, args.forces_weight)
394 | eval_formatted = ", ".join(
395 | ["%s=%g" % (k, v) for (k, v) in eval_dict.items()]
396 | )
397 |
398 | logging.info(
399 | "step=%d, %s, sqrt(train_loss)=%g, max memory used=%g, training time=%g min, eval time=%g min",
400 | step,
401 | eval_formatted,
402 | math.sqrt(train_loss),
403 | torch.cuda.max_memory_allocated() / 2**20,
404 | training_time / 60,
405 | (time.time() - eval_start) / 60
406 | )
407 | training_time = 0
408 | # Save checkpoint
409 | if not early_stop(eval_dict["sqrt(total_loss)"], best_val_loss):
410 | best_val_loss = eval_dict["sqrt(total_loss)"]
411 | torch.save(
412 | {
413 | "model": net.state_dict(),
414 | "optimizer": optimizer.state_dict(),
415 | "scheduler": scheduler.state_dict(),
416 | "step": step,
417 | "best_val_loss": best_val_loss,
418 | "node_size": args.node_size,
419 | "num_layer": args.num_interactions,
420 | "cutoff": args.cutoff,
421 | },
422 | os.path.join(args.output_dir, "best_model.pth"),
423 | )
424 | else:
425 | sys.exit(0)
426 |
427 | step += 1
428 |
429 | scheduler.step()
430 |
431 | if step >= args.max_steps:
432 | logging.info("Max steps reached, exiting")
433 | torch.save(
434 | {
435 | "model": net.state_dict(),
436 | "optimizer": optimizer.state_dict(),
437 | "scheduler": scheduler.state_dict(),
438 | "step": step,
439 | "best_val_loss": best_val_loss,
440 | "node_size": args.node_size,
441 | "num_layer": args.num_interactions,
442 | "cutoff": args.cutoff,
443 | },
444 | os.path.join(args.output_dir, "exit_model.pth"),
445 | )
446 | sys.exit(0)
447 |
448 | if __name__ == "__main__":
449 | main()
450 |
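451 | # A minimal sketch of a config file for `--cfg` (illustrative placeholder values,
452 | # not necessarily the arguments.toml shipped in scripts/): the keys simply mirror
453 | # the command line flags above and are merged into the argparse namespace by
454 | # update_namespace, with values already set on the command line taking precedence.
455 | #
456 | #   cutoff = 5.0
457 | #   num_interactions = 3
458 | #   node_size = 64
459 | #   output_dir = "model_output"
460 | #   dataset = "path/to/training_data.traj"
461 | #   split_file = "datasplits.json"
462 | #   max_steps = 1000000
463 | #   device = "cuda"
464 | #   batch_size = 12
465 | #   initial_lr = 0.0001
466 | #   forces_weight = 0.98
467 | #   log_interval = 2000
468 | #   stop_tolerance = 20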
--------------------------------------------------------------------------------
/scripts/water_O2.cif:
--------------------------------------------------------------------------------
1 |
2 | #======================================================================
3 | # CRYSTAL DATA
4 | #----------------------------------------------------------------------
5 | data_VESTA_phase_1
6 |
7 | _chemical_name_common ''
8 | _cell_length_a 12.600000
9 | _cell_length_b 12.600000
10 | _cell_length_c 12.600000
11 | _cell_angle_alpha 90.000000
12 | _cell_angle_beta 90.000000
13 | _cell_angle_gamma 90.000000
14 | _cell_volume 2000.376182
15 | _space_group_name_H-M_alt 'P 1'
16 | _space_group_IT_number 1
17 |
18 | loop_
19 | _space_group_symop_operation_xyz
20 | 'x, y, z'
21 |
22 | loop_
23 | _atom_site_label
24 | _atom_site_occupancy
25 | _atom_site_fract_x
26 | _atom_site_fract_y
27 | _atom_site_fract_z
28 | _atom_site_adp_type
29 | _atom_site_B_iso_or_equiv
30 | _atom_site_type_symbol
31 | O1 1.0 0.291790 0.511480 0.379270 Biso 1.000000 O
32 | O2 1.0 0.303320 0.251860 0.188060 Biso 1.000000 O
33 | O3 1.0 0.091950 0.077880 0.623230 Biso 1.000000 O
34 | O4 1.0 0.741690 0.283460 0.638230 Biso 1.000000 O
35 | O5 1.0 0.608390 0.865710 0.971920 Biso 1.000000 O
36 | O6 1.0 0.147250 0.738240 0.186290 Biso 1.000000 O
37 | O7 1.0 0.820100 0.652800 0.148430 Biso 1.000000 O
38 | O8 1.0 0.967270 0.822160 0.868330 Biso 1.000000 O
39 | O9 1.0 0.854540 0.474700 0.657050 Biso 1.000000 O
40 | O10 1.0 0.598820 0.804170 0.766680 Biso 1.000000 O
41 | O11 1.0 0.056380 0.392800 0.220790 Biso 1.000000 O
42 | O12 1.0 0.398640 0.670010 0.680850 Biso 1.000000 O
43 | O13 1.0 0.434050 0.759620 0.028000 Biso 1.000000 O
44 | O14 1.0 0.084340 0.659930 0.006530 Biso 1.000000 O
45 | O15 1.0 0.266970 0.136120 0.368690 Biso 1.000000 O
46 | O16 1.0 0.800830 0.861400 0.111570 Biso 1.000000 O
47 | O17 1.0 0.650660 0.990640 0.439370 Biso 1.000000 O
48 | O18 1.0 0.277250 0.881160 0.447140 Biso 1.000000 O
49 | O19 1.0 0.545150 0.745470 0.516140 Biso 1.000000 O
50 | O20 1.0 0.667440 0.200890 0.252520 Biso 1.000000 O
51 | O21 1.0 0.745940 0.616710 0.802360 Biso 1.000000 O
52 | O22 1.0 0.344900 0.643820 0.219180 Biso 1.000000 O
53 | O23 1.0 0.544780 0.357790 0.630150 Biso 1.000000 O
54 | O24 1.0 0.622730 0.994000 0.672940 Biso 1.000000 O
55 | O25 1.0 0.198100 0.690050 0.505930 Biso 1.000000 O
56 | O26 1.0 0.350940 0.438600 0.087820 Biso 1.000000 O
57 | O27 1.0 0.423310 0.027640 0.497210 Biso 1.000000 O
58 | O28 1.0 0.732950 0.325120 0.968620 Biso 1.000000 O
59 | O29 1.0 0.852830 0.610750 0.364470 Biso 1.000000 O
60 | O30 1.0 0.900960 0.738090 0.641740 Biso 1.000000 O
61 | O31 1.0 0.827740 0.203920 0.121470 Biso 1.000000 O
62 | O32 1.0 0.602460 0.419800 0.253630 Biso 1.000000 O
63 | O33 1.0 0.608760 0.508200 0.911240 Biso 1.000000 O
64 | O34 1.0 0.576780 0.125490 0.971200 Biso 1.000000 O
65 | O35 1.0 0.434780 0.464980 0.763000 Biso 1.000000 O
66 | O36 1.0 0.840880 0.199590 0.804530 Biso 1.000000 O
67 | O37 1.0 0.026680 0.375750 0.512530 Biso 1.000000 O
68 | O38 1.0 0.288210 0.006550 0.703740 Biso 1.000000 O
69 | O39 1.0 0.778590 0.157910 0.476440 Biso 1.000000 O
70 | O40 1.0 0.193300 0.330340 0.393520 Biso 1.000000 O
71 | O41 1.0 0.415060 0.180510 0.843790 Biso 1.000000 O
72 | O42 1.0 0.958050 0.505370 0.070110 Biso 1.000000 O
73 | O43 1.0 0.984680 0.106340 0.444620 Biso 1.000000 O
74 | O44 1.0 0.120370 0.428310 0.787520 Biso 1.000000 O
75 | O45 1.0 0.146340 0.923030 0.072790 Biso 1.000000 O
76 | O46 1.0 0.495170 0.514270 0.446370 Biso 1.000000 O
77 | O47 1.0 0.650940 0.915300 0.237120 Biso 1.000000 O
78 | O48 1.0 0.127520 0.990080 0.302090 Biso 1.000000 O
79 | O49 1.0 0.052490 0.224410 0.750720 Biso 1.000000 O
80 | O50 1.0 0.180310 0.870000 0.848560 Biso 1.000000 O
81 | O51 1.0 0.618490 0.589540 0.127830 Biso 1.000000 O
82 | O52 1.0 0.943120 0.398110 0.908100 Biso 1.000000 O
83 | O53 1.0 0.017050 0.735580 0.380950 Biso 1.000000 O
84 | O54 1.0 0.731110 0.646570 0.539590 Biso 1.000000 O
85 | O55 1.0 0.091030 0.632000 0.680020 Biso 1.000000 O
86 | O56 1.0 0.319380 0.378200 0.902550 Biso 1.000000 O
87 | O57 1.0 0.455120 0.793090 0.303840 Biso 1.000000 O
88 | O58 1.0 0.059320 0.192850 0.164590 Biso 1.000000 O
89 | O59 1.0 0.022450 0.889690 0.537430 Biso 1.000000 O
90 | O60 1.0 0.155680 0.215120 0.949720 Biso 1.000000 O
91 | O61 1.0 0.512920 0.070650 0.172800 Biso 1.000000 O
92 | O62 1.0 0.921120 0.019420 0.019760 Biso 1.000000 O
93 | O63 1.0 0.298480 0.712610 0.879390 Biso 1.000000 O
94 | O64 1.0 0.322320 0.050500 0.051740 Biso 1.000000 O
95 | O65 1.0 0.855430 0.395050 0.315430 Biso 1.000000 O
96 | O66 1.0 0.841040 0.980250 0.802780 Biso 1.000000 O
97 | O67 1.0 0.484000 0.157760 0.649970 Biso 1.000000 O
98 | H1 1.0 0.305630 0.555150 0.313630 Biso 1.000000 H
99 | H2 1.0 0.265060 0.559620 0.435020 Biso 1.000000 H
100 | H5 1.0 0.159910 0.033030 0.646240 Biso 1.000000 H
101 | H6 1.0 0.090930 0.135550 0.675230 Biso 1.000000 H
102 | H7 1.0 0.664560 0.304930 0.634650 Biso 1.000000 H
103 | H8 1.0 0.769850 0.245880 0.571220 Biso 1.000000 H
104 | H9 1.0 0.665830 0.839030 0.015530 Biso 1.000000 H
105 | H10 1.0 0.541340 0.840930 0.007500 Biso 1.000000 H
106 | H11 1.0 0.115370 0.700170 0.120160 Biso 1.000000 H
107 | H12 1.0 0.090230 0.745890 0.236890 Biso 1.000000 H
108 | H13 1.0 0.871000 0.602900 0.119030 Biso 1.000000 H
109 | H14 1.0 0.825620 0.646880 0.223650 Biso 1.000000 H
110 | H15 1.0 0.927680 0.780220 0.817450 Biso 1.000000 H
111 | H16 1.0 0.923300 0.888520 0.859390 Biso 1.000000 H
112 | H17 1.0 0.851710 0.524070 0.598260 Biso 1.000000 H
113 | H18 1.0 0.810520 0.412560 0.630660 Biso 1.000000 H
114 | H19 1.0 0.597890 0.811480 0.843280 Biso 1.000000 H
115 | H20 1.0 0.541050 0.755430 0.749370 Biso 1.000000 H
116 | H21 1.0 0.103310 0.410790 0.284950 Biso 1.000000 H
117 | H22 1.0 0.091620 0.325780 0.194580 Biso 1.000000 H
118 | H23 1.0 0.415760 0.589040 0.699600 Biso 1.000000 H
119 | H24 1.0 0.368830 0.696890 0.746910 Biso 1.000000 H
120 | H25 1.0 0.380910 0.740830 0.971520 Biso 1.000000 H
121 | H26 1.0 0.400620 0.760870 0.095660 Biso 1.000000 H
122 | H27 1.0 0.035060 0.598670 0.014540 Biso 1.000000 H
123 | H28 1.0 0.044070 0.717920 0.975760 Biso 1.000000 H
124 | H29 1.0 0.326880 0.117070 0.417390 Biso 1.000000 H
125 | H30 1.0 0.234630 0.196250 0.402590 Biso 1.000000 H
126 | H31 1.0 0.818480 0.785990 0.119840 Biso 1.000000 H
127 | H32 1.0 0.871750 0.889360 0.089770 Biso 1.000000 H
128 | H33 1.0 0.645350 0.957710 0.512430 Biso 1.000000 H
129 | H34 1.0 0.715280 0.034770 0.459660 Biso 1.000000 H
130 | H35 1.0 0.239840 0.919190 0.386110 Biso 1.000000 H
131 | H36 1.0 0.318320 0.941910 0.481050 Biso 1.000000 H
132 | H37 1.0 0.493650 0.748570 0.577710 Biso 1.000000 H
133 | H38 1.0 0.501030 0.766000 0.453700 Biso 1.000000 H
134 | H39 1.0 0.633180 0.269970 0.250600 Biso 1.000000 H
135 | H40 1.0 0.611980 0.143660 0.254440 Biso 1.000000 H
136 | H41 1.0 0.789430 0.575970 0.750640 Biso 1.000000 H
137 | H42 1.0 0.706960 0.663370 0.754550 Biso 1.000000 H
138 | H43 1.0 0.273770 0.681120 0.217190 Biso 1.000000 H
139 | H44 1.0 0.388890 0.698170 0.253430 Biso 1.000000 H
140 | H45 1.0 0.534910 0.401820 0.563230 Biso 1.000000 H
141 | H46 1.0 0.515780 0.409570 0.687880 Biso 1.000000 H
142 | H47 1.0 0.612070 0.932720 0.715520 Biso 1.000000 H
143 | H48 1.0 0.579800 0.055860 0.692010 Biso 1.000000 H
144 | H49 1.0 0.139660 0.712820 0.455740 Biso 1.000000 H
145 | H50 1.0 0.253990 0.738850 0.489590 Biso 1.000000 H
146 | H51 1.0 0.350320 0.507560 0.127920 Biso 1.000000 H
147 | H52 1.0 0.344750 0.379870 0.137830 Biso 1.000000 H
148 | H53 1.0 0.422960 0.074500 0.564730 Biso 1.000000 H
149 | H54 1.0 0.499190 0.023320 0.478950 Biso 1.000000 H
150 | H55 1.0 0.676310 0.273130 0.971090 Biso 1.000000 H
151 | H56 1.0 0.685370 0.393710 0.947910 Biso 1.000000 H
152 | H57 1.0 0.862750 0.534080 0.369810 Biso 1.000000 H
153 | H58 1.0 0.791300 0.622680 0.412600 Biso 1.000000 H
154 | H59 1.0 0.962010 0.687810 0.657170 Biso 1.000000 H
155 | H60 1.0 0.940900 0.792160 0.595790 Biso 1.000000 H
156 | H61 1.0 0.812020 0.264670 0.077760 Biso 1.000000 H
157 | H62 1.0 0.769340 0.204280 0.176850 Biso 1.000000 H
158 | H63 1.0 0.590790 0.475730 0.205570 Biso 1.000000 H
159 | H64 1.0 0.583260 0.449710 0.320980 Biso 1.000000 H
160 | H65 1.0 0.596040 0.539420 0.981970 Biso 1.000000 H
161 | H66 1.0 0.669780 0.544210 0.878230 Biso 1.000000 H
162 | H67 1.0 0.605900 0.064060 0.940180 Biso 1.000000 H
163 | H68 1.0 0.540240 0.103530 0.037930 Biso 1.000000 H
164 | H71 1.0 0.799400 0.236820 0.747820 Biso 1.000000 H
165 | H72 1.0 0.825640 0.235120 0.872570 Biso 1.000000 H
166 | H73 1.0 0.977920 0.363450 0.455480 Biso 1.000000 H
167 | H74 1.0 0.992270 0.366230 0.579600 Biso 1.000000 H
168 | H75 1.0 0.344470 0.031140 0.753450 Biso 1.000000 H
169 | H76 1.0 0.247830 0.948740 0.737190 Biso 1.000000 H
170 | H77 1.0 0.749910 0.171880 0.408440 Biso 1.000000 H
171 | H78 1.0 0.858000 0.145670 0.458050 Biso 1.000000 H
172 | H79 1.0 0.137400 0.329690 0.453420 Biso 1.000000 H
173 | H80 1.0 0.236050 0.394380 0.403930 Biso 1.000000 H
174 | H81 1.0 0.375730 0.243730 0.863210 Biso 1.000000 H
175 | H82 1.0 0.472460 0.169440 0.899690 Biso 1.000000 H
176 | H83 1.0 0.995180 0.470860 0.132060 Biso 1.000000 H
177 | H84 1.0 0.942550 0.449830 0.007970 Biso 1.000000 H
178 | H85 1.0 0.038170 0.106620 0.506000 Biso 1.000000 H
179 | H86 1.0 0.980220 0.032830 0.423360 Biso 1.000000 H
180 | H87 1.0 0.089730 0.361760 0.750200 Biso 1.000000 H
181 | H88 1.0 0.104490 0.477370 0.731510 Biso 1.000000 H
182 | H89 1.0 0.074070 0.932140 0.045200 Biso 1.000000 H
183 | H90 1.0 0.155060 0.851060 0.108180 Biso 1.000000 H
184 | H91 1.0 0.516470 0.586040 0.461250 Biso 1.000000 H
185 | H92 1.0 0.420520 0.515050 0.423790 Biso 1.000000 H
186 | H93 1.0 0.710200 0.899950 0.186010 Biso 1.000000 H
187 | H94 1.0 0.685760 0.926470 0.306020 Biso 1.000000 H
188 | H95 1.0 0.166840 0.059400 0.320700 Biso 1.000000 H
189 | H96 1.0 0.121500 0.983360 0.221300 Biso 1.000000 H
190 | H97 1.0 0.085470 0.231440 0.824430 Biso 1.000000 H
191 | H98 1.0 0.976310 0.199760 0.770130 Biso 1.000000 H
192 | H99 1.0 0.175900 0.906220 0.921080 Biso 1.000000 H
193 | H100 1.0 0.099450 0.856470 0.830380 Biso 1.000000 H
194 | H101 1.0 0.687180 0.627590 0.140360 Biso 1.000000 H
195 | H102 1.0 0.575160 0.649540 0.104550 Biso 1.000000 H
196 | H103 1.0 0.881050 0.422700 0.868250 Biso 1.000000 H
197 | H104 1.0 0.006450 0.418360 0.861550 Biso 1.000000 H
198 | H105 1.0 0.949340 0.693880 0.376620 Biso 1.000000 H
199 | H106 1.0 0.010580 0.808450 0.405310 Biso 1.000000 H
200 | H107 1.0 0.663720 0.685850 0.529150 Biso 1.000000 H
201 | H108 1.0 0.778590 0.701170 0.581370 Biso 1.000000 H
202 | H109 1.0 0.141280 0.646600 0.620090 Biso 1.000000 H
203 | H110 1.0 0.120100 0.679760 0.734890 Biso 1.000000 H
204 | H111 1.0 0.322660 0.404990 0.973970 Biso 1.000000 H
205 | H112 1.0 0.262670 0.420990 0.869420 Biso 1.000000 H
206 | H113 1.0 0.518750 0.833560 0.273900 Biso 1.000000 H
207 | H114 1.0 0.412880 0.843550 0.349950 Biso 1.000000 H
208 | H115 1.0 0.983010 0.204750 0.169990 Biso 1.000000 H
209 | H116 1.0 0.060770 0.136380 0.212390 Biso 1.000000 H
210 | H117 1.0 0.098320 0.897960 0.517210 Biso 1.000000 H
211 | H118 1.0 0.011430 0.961150 0.561150 Biso 1.000000 H
212 | H119 1.0 0.124320 0.216990 0.018450 Biso 1.000000 H
213 | H120 1.0 0.217590 0.260980 0.954920 Biso 1.000000 H
214 | H121 1.0 0.562630 0.006990 0.185000 Biso 1.000000 H
215 | H122 1.0 0.440790 0.048660 0.158590 Biso 1.000000 H
216 | H123 1.0 0.893550 0.077180 0.065530 Biso 1.000000 H
217 | H124 1.0 0.872330 0.005080 0.954480 Biso 1.000000 H
218 | H125 1.0 0.242900 0.772220 0.871600 Biso 1.000000 H
219 | H126 1.0 0.265120 0.653330 0.913550 Biso 1.000000 H
220 | H127 1.0 0.246130 0.035070 0.060400 Biso 1.000000 H
221 | H128 1.0 0.331580 0.078970 0.977530 Biso 1.000000 H
222 | H129 1.0 0.789380 0.388080 0.275210 Biso 1.000000 H
223 | H130 1.0 0.909150 0.380550 0.264760 Biso 1.000000 H
224 | H131 1.0 0.826460 0.060160 0.807490 Biso 1.000000 H
225 | H132 1.0 0.775010 0.949820 0.768010 Biso 1.000000 H
226 | H133 1.0 0.472880 0.171420 0.726340 Biso 1.000000 H
227 | H134 1.0 0.498630 0.231200 0.622540 Biso 1.000000 H
228 | O35 1.0 0.544780 0.504980 0.763000 Biso 1.000000 O
229 | O2 1.0 0.223320 0.171860 0.168060 Biso 1.000000 O
230 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | setup(
4 | name="PaiNN",
5 | version="1.0.0",
6 |     description="Library for implementing message passing neural networks in PyTorch",
7 | author="xinyang",
8 | author_email="xinyang@dtu.dk",
9 | url = "https://github.com/Yangxinsix/PaiNN-model",
10 | packages=["PaiNN"],
11 | )
12 |
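13 | # Installation sketch: from the repository root the package can be installed in
14 | # editable mode with `pip install -e .`. Only the PaiNN package itself is declared
15 | # in `packages`; the scripts/ and workflow/ directories are used in place.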
--------------------------------------------------------------------------------
/workflow/al_select.py:
--------------------------------------------------------------------------------
1 | from PaiNN.data import AseDataset, collate_atomsdata
2 | from PaiNN.model import PainnModel
3 | import torch
4 | import numpy as np
5 | from PaiNN.active_learning import GeneralActiveLearning
6 | import math
7 | import glob
8 | import json
9 | import argparse, toml
10 | from pathlib import Path
11 | from ase.io import read, write, Trajectory
12 |
13 | def setup_seed(seed):
14 | torch.manual_seed(seed)
15 | if torch.cuda.is_available():
16 | torch.cuda.manual_seed_all(seed)
17 | np.random.seed(seed)
18 | torch.backends.cudnn.deterministic = True
19 |
20 | def get_arguments(arg_list=None):
21 | parser = argparse.ArgumentParser(
22 | description="General Active Learning", fromfile_prefix_chars="+"
23 | )
24 | parser.add_argument(
25 | "--kernel",
26 | type=str,
27 | help="How to get features",
28 | )
29 | parser.add_argument(
30 | "--selection",
31 | type=str,
32 | help="Selection method, one of `max_dist_greedy`, `deterministic_CUR`, `lcmd_greedy`, `max_det_greedy` or `max_diag`",
33 | )
34 | parser.add_argument(
35 | "--n_random_features",
36 | type=int,
37 | help="If `n_random_features = 0`, do not use random projections.",
38 | )
39 | parser.add_argument(
40 | "--batch_size",
41 | type=int,
42 | help="How many data points should be selected",
43 | )
44 | parser.add_argument(
45 | "--load_model",
46 | type=str,
47 | help="Where to find the models",
48 | )
49 | parser.add_argument(
50 | "--dataset", type=str, help="Path to ASE trajectory",
51 | )
52 | parser.add_argument(
53 | "--split_file",
54 | type=str,
55 | help="Train/test/validation split file json",
56 | )
57 | parser.add_argument(
58 | "--pool_set", type=str, help="Path to MD trajectory obtained from machine learning potential",
59 | )
60 | parser.add_argument(
61 | "--training_set", type=str, help="Path to training set. Useful for pool/train based selection method",
62 | )
63 | parser.add_argument(
64 | "--device",
65 | type=str,
66 | help="Set which device to use for training e.g. 'cuda' or 'cpu'",
67 | )
68 | parser.add_argument(
69 | "--random_seed",
70 | type=int,
71 | help="Random seed for this run",
72 | )
73 | parser.add_argument(
74 | "--cfg",
75 | type=str,
76 | default="arguments.toml",
77 | help="Path to config file. e.g. 'arguments.toml'"
78 | )
79 |
80 | return parser.parse_args(arg_list)
81 |
82 | def update_namespace(ns, d):
83 | for k, v in d.items():
84 | if not ns.__dict__.get(k):
85 | ns.__dict__[k] = v
86 |
87 | def main():
88 | args = get_arguments()
89 | if args.cfg:
90 | with open(args.cfg, 'r') as f:
91 | params = toml.load(f)
92 | update_namespace(args, params)
93 |
94 | setup_seed(args.random_seed)
95 |
96 | # Load models
97 | model_pth = Path(args.load_model).rglob('*best_model.pth')
98 | models = []
99 | for each in model_pth:
100 | state_dict = torch.load(each)
101 | model = PainnModel(
102 | num_interactions=state_dict["num_layer"],
103 | hidden_state_size=state_dict["node_size"],
104 | cutoff=state_dict["cutoff"],
105 | )
106 | model.to(args.device)
107 | model.load_state_dict(state_dict["model"])
108 | models.append(model)
109 |
110 | # Load dataset
111 | if args.dataset:
112 | with open(args.split_file, 'r') as f:
113 | datasplits = json.load(f)
114 |
115 | dataset = AseDataset(args.dataset, cutoff=models[0].cutoff)
116 | data_dict = {
117 | 'pool': torch.utils.data.Subset(dataset, datasplits['pool']),
118 | 'train': torch.utils.data.Subset(dataset, datasplits['train']),
119 | }
120 | elif args.pool_set and args.train_set:
121 | if isinstance(args.pool_set, list):
122 | dataset = []
123 | for traj in args.pool_set:
124 | if Path(traj).stat().st_size > 0:
125 | dataset += read(traj, ':')
126 | else:
127 | dataset = args.pool_set
128 | data_dict = {
129 | 'pool': AseDataset(dataset, cutoff=models[0].cutoff),
130 | 'train': AseDataset(args.train_set, cutoff=models[0].cutoff),
131 | }
132 | else:
133 | raise RuntimeError("Please give valid pool data set for selection!")
134 |
135 | # raise error if the pool dataset is not large enough
136 | if len(data_dict['pool']) < args.batch_size * 5:
137 | raise RuntimeError(f"""The pool data set is not large enough for selection!
138 |             It should contain at least 5 times the batch size ({args.batch_size*5}) structures.
139 |             Check your MD simulation!""")
140 |
141 | # Select structures
142 | al = GeneralActiveLearning(
143 | kernel=args.kernel,
144 | selection=args.selection,
145 | n_random_features=args.n_random_features,
146 | )
147 | indices = al.select(models, data_dict, al_batch_size=args.batch_size)
148 | al_idx = [datasplits['pool'][i] for i in indices] if args.dataset else indices
149 | al_info = {
150 | 'kernel': args.kernel,
151 | 'selection': args.selection,
152 | 'dataset': args.dataset if args.dataset else args.pool_set,
153 | 'selected': al_idx,
154 | }
155 |
156 | with open('selected.json', 'w') as f:
157 | json.dump(al_info, f)
158 |
159 | # Update new data splits
160 | if args.dataset:
161 | pool_idx = np.delete(datasplits['pool'], indices)
162 | datasplits['pool'] = pool_idx.tolist()
163 | datasplits['train'] += al_idx
164 | with open(args.split_file, 'w') as f:
165 | json.dump(datasplits, f)
166 |
167 | if __name__ == "__main__":
168 | main()
169 |
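170 | # Sketch of the files this script writes, derived from the code above (indices and
171 | # paths are placeholders):
172 | #
173 | #   selected.json  ->  {"kernel": "full-g", "selection": "lcmd_greedy",
174 | #                       "dataset": ["md/iter_0/run/MD.traj", ...],
175 | #                       "selected": [12, 57, 101, ...]}
176 | #
177 | # When --dataset and --split_file are given, the selected indices are additionally
178 | # moved from the 'pool' split to the 'train' split and the split file is rewritten.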
--------------------------------------------------------------------------------
/workflow/config.toml:
--------------------------------------------------------------------------------
1 | # An example configuration file for active learning workflow
2 |
3 | [global]
4 | root = '.'
5 | random_seed = 3407
6 |
7 | [train]
8 | # Hyperparameters for PaiNN
9 | # load_model = '.'
10 | cutoff = 5.0
11 | # split_file = 'datasplits.json'
12 | val_ratio = 0.1
13 | num_interactions = 3
14 | node_size = 64
15 | output_dir = 'model_output'
16 | dataset = '/home/scratch3/xinyang/Au-facets/111_110.traj'
17 | max_steps = 1000000
18 | device = 'cuda'
19 | batch_size = 12
20 | initial_lr = 0.0001
21 | forces_weight = 0.98
22 | log_interval = 2000
23 | normalization = false
24 | atomwise_normalization = false
25 | stop_patience = 20
26 | plateau_scheduler = true # use the ReduceLROnPlateau scheduler to decrease the lr when learning plateaus
27 | random_seed = 3407
28 |
29 | [train.ensemble]
30 | # For training multiple models in parallel; hyperparameters not assigned here default to the values above
31 | #80_node_4_layer = {node_size = 80, num_interactions = 4, load_model = '/home/scratch3/xinyang/Au-facets/old_training/80_node_4_layer/model_output/best_model.pth'}
32 | #96_node_4_layer = {node_size = 96, num_interactions = 4, load_model = '/home/scratch3/xinyang/Au-facets/old_training/96_node_4_layer/model_output/best_model.pth'}
33 | 112_node_3_layer = {node_size = 112, num_interactions = 3, load_model = '/home/scratch3/xinyang/Au-facets/old_training/112_node_3_layer/model_output/best_model.pth'}
34 | 120_node_3_layer = {node_size = 120, num_interactions = 3, start_iteration = 7, stop_patience = 200}
35 | 128_node_3_layer = {node_size = 128, num_interactions = 3, load_model = '/home/scratch3/xinyang/Au-facets/old_training/128_node_3_layer/model_output/best_model.pth'}
36 | 136_node_3_layer = {node_size = 136, num_interactions = 3, start_iteration = 7, stop_patience = 200}
37 | 144_node_3_layer = {node_size = 144, num_interactions = 3, load_model = '/home/scratch3/xinyang/Au-facets/old_training/144_node_3_layer/model_output/best_model.pth'}
38 | 160_node_3_layer = {node_size = 160, num_interactions = 3, load_model = '/home/scratch3/xinyang/Au-facets/old_training/160_node_3_layer/model_output/best_model.pth'}
39 |
40 | [train.resource]
41 | nodename = 'sm3090'
42 | tmax = '7d' # Time limit for each job. For example: 1d (1 day), 2m (2 min), 5h (5 hours)
43 | cores = 8
44 |
45 | [MD]
46 | # Parameters for MD. It is usually better to customize these directly in the MD script.
47 | init_traj = '/home/scratch3/xinyang/md_mlp/Au_111_110/110_water/MD.traj'
48 | start_indice = -5
49 | # load_model = '/home/scratch3/xinyang/Au-facets/old_training/train' # will be assigned in the workflow
50 | time_step = 0.5
51 | temperature = 350
52 | max_steps = 2000000
53 | min_steps = 100000
54 | device = 'cuda'
55 | fix_under = 7.0
56 | dump_step = 100
57 | print_step = 1
58 | num_uncertain = 1000
59 | random_seed = 3407
60 |
61 | [MD.runs]
62 | # run multiple MD jobs in parallel
63 | [MD.runs.Au_111_water]
64 | init_traj = '/home/scratch3/xinyang/md_mlp/Au_111_110/111_water/MD.traj'
65 | fix_under = 7.0
66 | start_indice = -5
67 | min_steps = 50000
68 | dump_step = 50
69 |
70 | [MD.runs.Au_110_water]
71 | init_traj = '/home/scratch3/xinyang/md_mlp/Au_111_110/110_water/MD.traj'
72 | max_steps = 2000000 # this one is already good enough
73 | fix_under = 7.0
74 | start_indice = -5
75 | min_steps = 50000
76 | dump_step = 100
77 |
78 | [MD.runs.Au_111_1OH]
79 | init_traj = '/home/scratch3/xinyang/Au-facets/1OH/md/iter_0/111_1OH/MD.traj'
80 | fix_under = 7.0
81 | start_indice = -5
82 | min_steps = 50000
83 | dump_step = 30
84 | start_iteration = 3
85 |
86 | [MD.runs.Au_110_1OH]
87 | init_traj = '/home/scratch3/xinyang/Au-facets/1OH/md/iter_0/110_1OH/MD.traj'
88 | fix_under = 7.0
89 | start_indice = -5
90 | min_steps = 20000
91 | dump_step = 30
92 | start_iteration = 3
93 |
94 | [MD.runs.Au_111_1O2]
95 | init_traj = '/home/energy/xinyang/work/Au_MD/DFT_MD/111_MD/O2/111_O2_incomplete.traj'
96 | fix_under = 7.0
97 | start_indice = -5
98 | min_steps = 30000
99 | dump_step = 50
100 | start_iteration = 3
101 |
102 | [MD.runs.Au_110_1O2]
103 | init_traj = '/home/energy/xinyang/work/Au_MD/DFT_MD/110_MD/O2/110_O2_incomplete.traj'
104 | fix_under = 7.0
105 | start_indice = -5
106 | min_steps = 50000
107 | dump_step = 50
108 | start_iteration = 3
109 |
110 | [MD.resource]
111 | nodename = 'sm3090'
112 | tmax = '7d'
113 | cores = 8
114 |
115 | [select]
116 | kernel = "full-g" # Name of the kernel, e.g. "full-g", "ll-g", "full-F_inv", "ll-F_inv", "qbc-energy", "qbc-force", "ae-energy", "ae-force", "random"
117 | selection = "lcmd_greedy" # Selection method, one of "max_dist_greedy", "deterministic_CUR", "lcmd_greedy", "max_det_greedy" or "max_diag".
118 | n_random_features = 500 # If "n_random_features = 0", do not use random projections.
119 | batch_size = 100
120 | # load_model = '/home/scratch3/xinyang/Au-facets/old_training/train' # will be assigned in the workflow
121 | # dataset = 'md17aspirin.traj' # should not be assigned if using pool data set from MD
122 | # split_file = 'datasplits.json'
123 | # pool_set = # Useful when dataset and split_file are not assigned, can be a list or str
124 | train_set = '/home/scratch3/xinyang/Au-facets/111_110.traj'
125 | device = 'cuda'
126 | random_seed = 3407
127 |
128 | [select.runs]
129 | Au_110_water = {batch_size = 100} # this one is much faster, so using a larger batch size can save some time
130 | Au_111_water = {batch_size = 200}
131 | Au_110_1OH = {batch_size = 200, start_iteration = 3}
132 | Au_111_1OH = {batch_size = 200, start_iteration = 3}
133 | Au_110_1O2 = {batch_size = 200, start_iteration = 3}
134 | Au_111_1O2 = {batch_size = 200, start_iteration = 3}
135 |
136 |
137 | [select.resource]
138 | nodename = 'sm3090'
139 | tmax = '2d'
140 | cores = 8
141 |
142 | [labeling]
143 | # label_set = 'xxx.traj'
144 | train_set = '/home/scratch3/xinyang/Au-facets/111_110.traj'
145 | # pool_set # will be assigned in the workflow, can be a list
146 | # al_info # will be assigned in the workflow
147 | num_jobs = 2
148 |
149 | [labeling.VASP]
150 | # VASP parameters
151 | xc = 'PBE'
152 | gga = 'pe'
153 | system = 'ni'
154 | prec = 'normal'
155 | istart = 1
156 | icharg = 2
157 | npar = 4
158 | encut = 350
159 | algo = 'Fast'
160 | lreal = 'Auto'
161 | nelm = 1000
162 | nelmin = 5
163 | nelmdl = -5
164 | ediff = 1e-4
165 | ediffg = -0.01
166 | nsw = 0
167 | ibrion = 0
168 | potim = 1
169 | isif = 2
170 | ispin = 2
171 | ismear = 0
172 | sigma = 0.1
173 | lwave = true
174 | lcharg = false
175 | ivdw = 11
176 | lasph = true
177 | kpts = [2, 2, 1]
178 | gamma = false
179 | # kspacing = 0.5
180 |
181 | [labeling.runs]
182 |
183 | [labeling.runs.Au_111_water] # The key name should be the same as in MD
184 | gamma = true
185 | num_jobs = 6 # accelerate DFT labeling by splitting the job into several parts
186 |
187 | [labeling.runs.Au_110_water]
188 | gamma = false
189 | num_jobs = 2
190 |
191 | [labeling.runs.Au_111_1OH]
192 | gamma = true
193 | num_jobs = 6
194 | start_iteration = 3
195 |
196 | [labeling.runs.Au_110_1OH]
197 | gamma = false
198 | num_jobs = 2
199 | start_iteration = 3
200 |
201 | [labeling.runs.Au_111_1O2]
202 | gamma = true
203 | num_jobs = 6
204 | start_iteration = 3
205 |
206 | [labeling.runs.Au_110_1O2]
207 | gamma = false
208 | num_jobs = 2
209 | start_iteration = 3
210 |
211 | [labeling.resource]
212 | cores = 40
213 | nodename = 'xeon40'
214 | tmax = '2d'
215 |
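216 | # How these tables are resolved by workflow/flow.py: every non-table key defined at
217 | # the section level (e.g. under [MD]) is copied into each [MD.runs.*] entry unless
218 | # that entry already sets it, so a run only lists the values it overrides. With the
219 | # settings above, [MD.runs.Au_111_water] therefore also inherits temperature = 350
220 | # and max_steps = 2000000 from [MD], in addition to its own min_steps and dump_step.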
--------------------------------------------------------------------------------
/workflow/flow.py:
--------------------------------------------------------------------------------
1 | import json, toml, sys
2 | from pathlib import Path
3 | from myqueue.workflow import run
4 | from typing import List, Dict
5 | from ase.io import Trajectory, read, write
6 | import numpy as np
7 | import copy
8 |
9 | # args parsing
10 |
11 | with open('config.toml') as f:
12 | args = toml.load(f)
13 |
14 | # get absolute path
15 | name_list = [
16 | 'dataset',
17 | 'split_file',
18 | 'load_model',
19 | 'init_traj',
20 | 'pool_set',
21 | 'train_set',
22 | 'label_set',
23 | 'al_info',
24 | 'root'
25 | ]
26 | def get_absolute_path(d: dict):
27 | for k, v in d.items():
28 | if k in name_list and not Path(v).is_absolute():
29 | d[k] = str(Path(v).resolve())
30 | elif isinstance(v, dict):
31 | d[k] = get_absolute_path(v)
32 | return d
33 | args = get_absolute_path(args)
34 |
35 | # parsing training parameters
36 | train_params = {}
37 | if args['train'].get('ensemble'):
38 | for name, params in args['train']['ensemble'].items():
39 | for k, v in args['train'].items():
40 | if not isinstance(v, dict) and k not in params:
41 | params[k] = v
42 | train_params[name] = params
43 | else:
44 | params = {}
45 | for k, v in args['train'].items():
46 | if not isinstance(v, dict) and k not in params:
47 | params[k] = v
48 | train_params['model'] = params
49 | # train_resource = args['train']['resource']
50 |
51 | # parsing active learning parameters
52 | al_params = {}
53 | if args['select'].get('runs'):
54 | for name, params in args['select']['runs'].items():
55 | for k, v in args['select'].items():
56 | if not isinstance(v, dict) and k not in params:
57 | params[k] = v
58 | al_params[name] = params
59 | else:
60 | params = {}
61 | for k, v in args['select'].items():
62 | if not isinstance(v, dict) and k not in params:
63 | params[k] = v
64 | al_params['select'] = params
65 |
66 | # al_resource = args['select']['resource']
67 |
68 | # parsing MD parameters
69 | md_params = {}
70 | if args['MD'].get('runs'):
71 | for name, params in args['MD']['runs'].items():
72 | for k, v in args['MD'].items():
73 | if not isinstance(v, dict) and k not in params:
74 | params[k] = v
75 | md_params[name] = params
76 | else:
77 | params = {}
78 | for k, v in args['MD'].items():
79 | if not isinstance(v, dict) and k not in params:
80 | params[k] = v
81 | md_params['md_run'] = params
82 |
83 | # DFT labelling
84 | dft_params = {}
85 | tmp_params = {k: v for k, v in args['labeling'].items() if not isinstance(v, dict)}
86 | tmp_params['VASP'] = args['labeling']['VASP']
87 | if args['labeling'].get('runs'):
88 | for name, params in args['labeling']['runs'].items():
89 | new_params = copy.deepcopy(tmp_params)
90 | for k, v in params.items():
91 | if k in new_params['VASP']:
92 | new_params['VASP'][k] = v
93 | else:
94 | new_params[k] = v
95 | dft_params[name] = new_params
96 | else:
97 | dft_params['dft_run'] = tmp_params
98 |
99 | root = args['global']['root']
100 |
101 | def train_models(folder, deps, extra_args: List[str] = [], iteration: int=0):
102 | tasks = []
103 | node_info = args['train']['resource']
104 | # parse parameters
105 | for name, params in train_params.items():
106 | path = Path(f'{folder}/iter_{iteration}/{name}')
107 |
108 | if not params.get('start_iteration'):
109 | params['start_iteration'] = 0
110 | if iteration >= params['start_iteration']:
111 | if not path.is_dir():
112 | path.mkdir(parents=True)
113 |
114 | # load model
115 | if iteration > 0:
116 | load_model = f'{root}/{folder}/iter_{iteration-1}/{name}/{params["output_dir"]}/best_model.pth'
117 | if Path(load_model).is_file():
118 | params['load_model'] = load_model
119 | # elif iteration == 0:
120 | # params['load_model'] = f'/home/scratch3/xinyang/Au-facets/old_training/train/{name}/model_output/best_model.pth'
121 |
122 | with open(path / 'arguments.toml', 'w') as f:
123 | toml.dump(params, f)
124 |
125 | arguments = ['--cfg', 'arguments.toml']
126 | arguments += extra_args
127 |
128 | tasks.append(run(
129 | script=f'{root}/train.py',
130 | nodename='sm3090' if not node_info.get('nodename') else node_info['nodename'],
131 | cores=8 if not node_info.get('cores') else node_info['cores'],
132 | tmax='7d' if not node_info.get('tmax') else node_info['tmax'],
133 | args=arguments,
134 | folder=path,
135 | name=name,
136 | deps=deps,
137 | ))
138 |
139 | return tasks
140 |
141 | def active_learning(folder, deps, extra_args: List[str] = [], iteration: int=0):
142 | tasks = {}
143 | node_info = args['select']['resource']
144 | # parse parameters
145 | for name, params in al_params.items():
146 | path = Path(f'{folder}/iter_{iteration}/{name}')
147 | if not params.get('start_iteration'):
148 | params['start_iteration'] = 0
149 | if iteration >= params['start_iteration']:
150 | if not path.is_dir():
151 | path.mkdir(parents=True)
152 |
153 | params['load_model'] = f'{root}/train/iter_{iteration}'
154 | if not params.get('dataset'):
155 | params['pool_set'] = [f'{root}/md/iter_{iteration}/{name}/MD.traj', f'{root}/md/iter_{iteration}/{name}/warning_struct.traj']
156 |
157 | with open(path / 'arguments.toml', 'w') as f:
158 | toml.dump(params, f)
159 |
160 | arguments = ['--cfg', 'arguments.toml']
161 |
162 | tasks[name] = run(
163 | script=f'{root}/al_select.py',
164 | nodename='sm3090' if not node_info.get('nodename') else node_info['nodename'],
165 | cores=8 if not node_info.get('cores') else node_info['cores'],
166 | tmax='7d' if not node_info.get('tmax') else node_info['tmax'],
167 | args=arguments,
168 | folder=path,
169 | name=name,
170 | deps=[deps[name]],
171 | )
172 |
173 | return tasks
174 |
175 | def run_md(folder, deps=[], extra_args: List[str] = [], iteration: int=0):
176 | tasks = {}
177 | node_info = args['MD']['resource']
178 | for name, params in md_params.items():
179 | path = Path(f'{folder}/iter_{iteration}/{name}')
180 |
181 | if not params.get('start_iteration'):
182 | params['start_iteration'] = 0
183 | if iteration >= params['start_iteration']:
184 | if not path.is_dir():
185 | path.mkdir(parents=True)
186 | params['load_model'] = f'{root}/train/iter_{iteration}'
187 |
188 |
189 | if iteration > params['start_iteration']:
190 | params['init_traj'] = f'{root}/md/iter_{iteration-1}/{name}/MD.traj'
191 |
192 | with open(path / 'arguments.toml', 'w') as f:
193 | toml.dump(params, f)
194 |
195 | arguments = ['--cfg', 'arguments.toml']
196 |
197 | tasks[name] = run(
198 | script=f'{root}/md_run.py',
199 | nodename='sm3090' if not node_info.get('nodename') else node_info['nodename'],
200 | cores=8 if not node_info.get('cores') else node_info['cores'],
201 | tmax='7d' if not node_info.get('tmax') else node_info['tmax'],
202 | args=arguments,
203 | folder=path,
204 | name=name,
205 | deps=deps,
206 | )
207 |
208 | return tasks
209 |
210 | def run_dft(folder, deps={}, extra_args: List[str] = [], iteration: int=0):
211 | tasks = []
212 | node_info = args['labeling']['resource']
213 | for name, params in dft_params.items():
214 | path = Path(f'{folder}/iter_{iteration}/{name}')
215 | if not params.get('start_iteration'):
216 | params['start_iteration'] = 0
217 | if iteration >= params['start_iteration']:
218 | if not path.is_dir():
219 | path.mkdir(parents=True)
220 |
221 | # get images that need to be labeled
222 | params['system'] = name
223 | params['pool_set'] = [f'{root}/md/iter_{iteration}/{name}/MD.traj', f'{root}/md/iter_{iteration}/{name}/warning_struct.traj']
224 | params['al_info'] = f'{root}/select/iter_{iteration}/{name}/selected.json'
225 | with open(path / 'arguments.toml', 'w') as f:
226 | toml.dump(params, f)
227 |
228 | arguments = ['--cfg', 'arguments.toml']
229 |
230 | if params.get('num_jobs'):
231 | for i in range(params['num_jobs']):
232 | dft_arguments = ['--cfg', '../arguments.toml', '--job_order', f'{i}']
233 | dft_path = path / f'{i}'
234 | if not dft_path.is_dir():
235 | dft_path.mkdir(parents=True)
236 | tasks.append(run(
237 | script=f'{root}/vasp.py',
238 | nodename='xeon40' if not node_info.get('nodename') else node_info['nodename'],
239 | cores=40 if not node_info.get('cores') else node_info['cores'],
240 | tmax='50h' if not node_info.get('tmax') else node_info['tmax'],
241 | args=dft_arguments,
242 | folder=dft_path,
243 | name=name,
244 | deps=[deps[name]],
245 | ))
246 | else:
247 | tasks.append(run(
248 | script=f'{root}/vasp.py',
249 | nodename='xeon40' if not node_info.get('nodename') else node_info['nodename'],
250 | cores=40 if not node_info.get('cores') else node_info['cores'],
251 | tmax='50h' if not node_info.get('tmax') else node_info['tmax'],
252 | args=arguments,
253 | folder=path,
254 | name=name,
255 | deps=[deps[name]],
256 | ))
257 |
258 | return tasks
259 |
260 | def all_done(runs):
261 | return all([task.done for task in runs])
262 |
263 | def workflow():
264 | dft = []
265 | for iteration in range(9):
266 | # training part
267 | training = train_models('train', deps=dft, iteration=iteration)
268 |
269 | # data generating
270 | md = run_md('md', deps=training, iteration=iteration)
271 |
272 | # active learning selection
273 | select = active_learning('select', deps=md, iteration=iteration)
274 |
275 | # DFT labeling
276 | dft = run_dft('labeling', deps=select, iteration=iteration)
277 |
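278 | # A sketch of the dependency chain that workflow() builds per active-learning
279 | # iteration:
280 | #
281 | #   dft(iter n-1) -> train(iter n) -> md(iter n) -> select(iter n) -> dft(iter n)
282 | #
283 | # train_models returns a flat list of tasks, so every MD run waits for all models of
284 | # the iteration; run_md, active_learning and run_dft key their tasks by run name, so
285 | # each selection and labeling job only waits for the MD/selection task of the same name.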
--------------------------------------------------------------------------------
/workflow/md_run.py:
--------------------------------------------------------------------------------
1 | from ase.md.langevin import Langevin
2 | from ase.calculators.plumed import Plumed
3 | from ase import units
4 | from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
5 | from ase.io import read, write, Trajectory
6 |
7 | import numpy as np
8 | import torch
9 | import sys
10 | import glob
11 | import toml
12 | import argparse
13 | from pathlib import Path
14 | import logging
15 |
16 | from PaiNN.data import AseDataset, collate_atomsdata
17 | from PaiNN.model import PainnModel
18 | from PaiNN.calculator import MLCalculator, EnsembleCalculator
19 | from ase.constraints import FixAtoms
20 |
21 | def setup_seed(seed):
22 | torch.manual_seed(seed)
23 | if torch.cuda.is_available():
24 | torch.cuda.manual_seed_all(seed)
25 | np.random.seed(seed)
26 | torch.backends.cudnn.deterministic = True
27 |
28 | def get_arguments(arg_list=None):
29 | parser = argparse.ArgumentParser(
30 |         description="MD simulations driven by graph neural networks", fromfile_prefix_chars="+"
31 | )
32 | parser.add_argument(
33 | "--init_traj",
34 | type=str,
35 | help="Path to start configurations",
36 | )
37 | parser.add_argument(
38 | "--start_indice",
39 | type=int,
40 |         help="Index of the start configuration",
41 | )
42 | parser.add_argument(
43 | "--load_model",
44 | type=str,
45 | help="Where to find the models",
46 | )
47 | parser.add_argument(
48 | "--time_step",
49 | type=float,
50 | default=0.5,
51 | help="Time step of MD simulation",
52 | )
53 | parser.add_argument(
54 | "--max_steps",
55 | type=int,
56 | default=5000000,
57 | help="Maximum steps of MD",
58 | )
59 | parser.add_argument(
60 | "--min_steps",
61 | type=int,
62 | default=100000,
63 |         help="Minimum number of MD steps; an error is raised if MD stops before reaching it",
64 | )
65 | parser.add_argument(
66 | "--temperature",
67 | type=float,
68 | default=350.0,
69 |         help="Temperature of the MD simulation [K]",
70 | )
71 | parser.add_argument(
72 | "--fix_under",
73 | type=float,
74 | default=5.9,
75 | help="Fix atoms under the specified value",
76 | )
77 | parser.add_argument(
78 | "--dump_step",
79 | type=int,
80 | default=100,
81 |         help="Interval (in steps) for writing structures to the trajectory",
82 | )
83 | parser.add_argument(
84 | "--print_step",
85 | type=int,
86 | default=1,
87 |         help="Interval (in steps) for logging energies and ensemble uncertainties",
88 | )
89 | parser.add_argument(
90 | "--num_uncertain",
91 | type=int,
92 | default=1000,
93 | help="Stop MD when too many structures with large uncertainty are collected",
94 | )
95 | parser.add_argument(
96 | "--random_seed",
97 | type=int,
98 | help="Random seed for this run",
99 | )
100 | parser.add_argument(
101 | "--device",
102 | type=str,
103 | default='cuda',
104 | help="Set which device to use for running MD e.g. 'cuda' or 'cpu'",
105 | )
106 | parser.add_argument(
107 | "--cfg",
108 | type=str,
109 | default="arguments.toml",
110 | help="Path to config file. e.g. 'arguments.toml'"
111 | )
112 |
113 | return parser.parse_args(arg_list)
114 |
115 | def update_namespace(ns, d):
116 | for k, v in d.items():
117 | ns.__dict__[k] = v
118 |
119 | class CallsCounter:
120 | def __init__(self, func):
121 | self.calls = 0
122 | self.func = func
123 | def __call__(self, *args, **kwargs):
124 | self.calls += 1
125 | self.func(*args, **kwargs)
126 |
127 | def main():
128 | args = get_arguments()
129 | if args.cfg:
130 | with open(args.cfg, 'r') as f:
131 | params = toml.load(f)
132 | update_namespace(args, params)
133 |
134 | setup_seed(args.random_seed)
135 |
136 | # set logger
137 | logger = logging.getLogger(__file__)
138 | logger.setLevel(logging.DEBUG)
139 |
140 | runHandler = logging.FileHandler('md.log', mode='w')
141 | runHandler.setLevel(logging.DEBUG)
142 | runHandler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)7s - %(message)s"))
143 | errorHandler = logging.FileHandler('error.log', mode='w')
144 | errorHandler.setLevel(logging.WARNING)
145 | errorHandler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)7s - %(message)s"))
146 |
147 | logger.addHandler(runHandler)
148 | logger.addHandler(errorHandler)
149 | logger.addHandler(logging.StreamHandler())
150 | logger.warning = CallsCounter(logger.warning)
151 | logger.info = CallsCounter(logger.info)
152 |
153 | # load model
154 | model_pth = Path(args.load_model).rglob('*best_model.pth')
155 | models = []
156 | for each in model_pth:
157 | state_dict = torch.load(each)
158 | model = PainnModel(
159 | num_interactions=state_dict["num_layer"],
160 | hidden_state_size=state_dict["node_size"],
161 | cutoff=state_dict["cutoff"],
162 | )
163 | model.to(args.device)
164 | model.load_state_dict(state_dict["model"])
165 | models.append(model)
166 |
167 | encalc = EnsembleCalculator(models)
168 |
169 | # set up md start configuration
170 | images = read(args.init_traj, ':')
171 |     start_indice = np.random.choice(len(images)) if args.start_indice is None else args.start_indice
172 | logger.debug(f'MD starts from No.{start_indice} configuration in {args.init_traj}')
173 | atoms = images[start_indice]
174 | atoms.wrap()
175 | cons = FixAtoms(mask=atoms.positions[:, 2] < args.fix_under) if args.fix_under else []
176 | atoms.set_constraint(cons)
177 | atoms.calc = encalc
178 | atoms.get_potential_energy()
179 |
180 | collect_traj = Trajectory('warning_struct.traj', 'w')
181 | @CallsCounter
182 | def printenergy(a=atoms): # store a reference to atoms in the definition.
183 | """Function to print the potential, kinetic and total energy."""
184 | epot = a.get_potential_energy()
185 | ekin = a.get_kinetic_energy()
186 | temp = ekin / (1.5 * units.kB) / a.get_global_number_of_atoms()
187 | ensemble = a.calc.results['ensemble']
188 | energy_var = ensemble['energy_var']
189 | forces_var = np.mean(ensemble['forces_var'])
190 | forces_sd = np.mean(np.sqrt(ensemble['forces_var']))
191 | forces_l2_var = np.mean(ensemble['forces_l2_var'])
192 |
193 | if forces_sd > 0.2:
194 | logger.error("Too large uncertainty!")
195 | if logger.info.calls + logger.warning.calls > args.min_steps:
196 | sys.exit(0)
197 | else:
198 | sys.exit("Too large uncertainty!")
199 | elif forces_sd > 0.05:
200 | collect_traj.write(a)
201 | logger.warning("Steps={:10d} Epot={:12.3f} Ekin={:12.3f} temperature={:8.2f} energy_var={:10.6f} forces_var={:10.6f} forces_sd={:10.6f} forces_l2_var={:10.6f}".format(
202 | printenergy.calls * args.print_step,
203 | epot,
204 | ekin,
205 | temp,
206 | energy_var,
207 | forces_var,
208 | forces_sd,
209 | forces_l2_var,
210 | ))
211 | if logger.warning.calls > args.num_uncertain:
212 | logger.error(f"More than {args.num_uncertain} uncertain structures are collected!")
213 | if logger.info.calls + logger.warning.calls > args.min_steps:
214 | sys.exit(0)
215 | else:
216 | sys.exit(f"More than {args.num_uncertain} uncertain structures are collected!")
217 | else:
218 | logger.info("Steps={:10d} Epot={:12.3f} Ekin={:12.3f} temperature={:8.2f} energy_var={:10.6f} forces_var={:10.6f} forces_sd={:10.6f} forces_l2_var={:10.6f}".format(
219 | printenergy.calls * args.print_step,
220 | epot,
221 | ekin,
222 | temp,
223 | energy_var,
224 | forces_var,
225 | forces_sd,
226 | forces_l2_var,
227 | ))
228 |
229 | #atoms.calc = encalc
230 | if not np.any(atoms.get_momenta()):
231 | MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature)
232 | dyn = Langevin(atoms, args.time_step * units.fs, temperature_K=args.temperature, friction=0.1)
233 | dyn.attach(printenergy, interval=args.print_step)
234 |
235 | traj = Trajectory('MD.traj', 'w', atoms)
236 | dyn.attach(traj.write, interval=args.dump_step)
237 | dyn.run(args.max_steps)
238 |
239 | if __name__ == "__main__":
240 | main()
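241 |
242 | # Summary of the uncertainty handling above: the ensemble force standard deviation
243 | # (forces_sd) is evaluated every print_step.
244 | #   forces_sd > 0.05 -> the structure is written to warning_struct.traj for labeling
245 | #   forces_sd > 0.2, or more than num_uncertain flagged structures -> the run stops
246 | #                      (exit code 0 once enough steps have been logged to reach
247 | #                      min_steps, otherwise an error so the workflow notices it)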
--------------------------------------------------------------------------------
/workflow/train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 | import json, os, sys, toml
4 | from pathlib import Path
5 | import argparse
6 | import logging
7 | import itertools
8 | import torch
9 | import time
10 |
11 | from PaiNN.data import AseDataset, collate_atomsdata
12 | from PaiNN.model import PainnModel
13 |
14 | def setup_seed(seed):
15 | torch.manual_seed(seed)
16 | if torch.cuda.is_available():
17 | torch.cuda.manual_seed_all(seed)
18 | np.random.seed(seed)
19 | torch.backends.cudnn.deterministic = True
20 |
21 | def get_arguments(arg_list=None):
22 | parser = argparse.ArgumentParser(
23 | description="Train graph convolution network", fromfile_prefix_chars="+"
24 | )
25 | parser.add_argument(
26 | "--load_model",
27 | type=str,
28 | help="Load model parameters from previous run",
29 | )
30 | parser.add_argument(
31 | "--cutoff",
32 | type=float,
33 |         help="Atomic interaction cutoff distance [Å]",
34 | )
35 | parser.add_argument(
36 | "--split_file",
37 | type=str,
38 | help="Train/test/validation split file json",
39 | )
40 | parser.add_argument(
41 | "--val_ratio",
42 | type=float,
43 | help="Ratio of validation set. Only useful when 'split_file' is not assigned",
44 | )
45 | parser.add_argument(
46 | "--num_interactions",
47 | type=int,
48 | help="Number of interaction layers used",
49 | )
50 | parser.add_argument(
51 | "--node_size", type=int, help="Size of hidden node states"
52 | )
53 | parser.add_argument(
54 | "--output_dir",
55 | type=str,
56 | help="Path to output directory",
57 | )
58 | parser.add_argument(
59 | "--dataset", type=str, help="Path to ASE trajectory",
60 | )
61 | parser.add_argument(
62 | "--max_steps",
63 | type=int,
64 | help="Maximum number of optimisation steps",
65 | )
66 | parser.add_argument(
67 | "--device",
68 | type=str,
69 | help="Set which device to use for training e.g. 'cuda' or 'cpu'",
70 | )
71 | parser.add_argument(
72 | "--batch_size", type=int, help="Number of molecules per minibatch",
73 | )
74 | parser.add_argument(
75 | "--initial_lr", type=float, help="Initial learning rate",
76 | )
77 | parser.add_argument(
78 | "--forces_weight",
79 | type=float,
80 | help="Tradeoff between training on forces (weight=1) and energy (weight=0)",
81 | )
82 | parser.add_argument(
83 | "--log_inverval",
84 | type=int,
85 | help="The interval of model evaluation",
86 | )
87 | parser.add_argument(
88 | "--plateau_scheduler",
89 | action="store_true",
90 | help="Using ReduceLROnPlateau scheduler for decreasing learning rate when learning plateaus",
91 | )
92 | parser.add_argument(
93 | "--normalization",
94 | action="store_true",
95 | help="Enable normalization of the model",
96 | )
97 | parser.add_argument(
98 | "--atomwise_normalization",
99 | action="store_true",
100 | help="Enable atomwise normalization",
101 | )
102 | parser.add_argument(
103 | "--stop_patience",
104 | type=int,
105 | help="Stop training when validation loss is larger than best loss for 'stop_patience' steps",
106 | )
107 | parser.add_argument(
108 | "--random_seed",
109 | type=int,
110 | help="Random seed for this run",
111 | )
112 | parser.add_argument(
113 | "--cfg",
114 | type=str,
115 | help="Path to config file. e.g. 'arguments.toml'"
116 | )
117 |
118 | return parser.parse_args(arg_list)
119 |
120 | def split_data(dataset, args):
121 | # Load or generate splits
122 | if args.split_file:
123 | with open(args.split_file, "r") as fp:
124 | splits = json.load(fp)
125 | else:
126 | datalen = len(dataset)
127 | num_validation = int(math.ceil(datalen * args.val_ratio))
128 | indices = np.random.permutation(len(dataset))
129 | splits = {
130 | "train": indices[num_validation:].tolist(),
131 | "validation": indices[:num_validation].tolist(),
132 | }
133 |
134 | # Save split file
135 | with open(os.path.join(args.output_dir, "datasplits.json"), "w") as f:
136 | json.dump(splits, f)
137 |
138 | # Split the dataset
139 | datasplits = {}
140 | for key, indices in splits.items():
141 | datasplits[key] = torch.utils.data.Subset(dataset, indices)
142 | return datasplits
143 |
144 | def forces_criterion(predicted, target, reduction="mean"):
145 |     # predicted and target are (num_atoms, 3) force tensors
146 |     # the criterion is the mean (or sum) of the per-atom L2 norms of the force error
147 |     diff = predicted - target
148 |     force_error_norm = torch.linalg.norm(diff, dim=1)  # (num_atoms,)
149 |     if reduction == "mean":
150 |         scalar = torch.mean(force_error_norm)
151 |     elif reduction == "sum":
152 |         scalar = torch.sum(force_error_norm)
153 | else:
154 | raise ValueError("Reduction must be 'mean' or 'sum'")
155 | return scalar
156 |
157 | def get_normalization(dataset, per_atom=True):
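    |     """Return mean and standard deviation of the (optionally per-atom) energies,
    |     using a one-sample bias shift to keep the running sums numerically small."""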
158 | # Use double precision to avoid overflows
159 | x_sum = torch.zeros(1, dtype=torch.double)
160 | x_2 = torch.zeros(1, dtype=torch.double)
161 | num_objects = 0
162 | for i, sample in enumerate(dataset):
163 | if i == 0:
164 | # Estimate "bias" from 1 sample
165 | # to avoid overflows for large valued datasets
166 | if per_atom:
167 | bias = sample["energy"] / sample["num_atoms"]
168 | else:
169 | bias = sample["energy"]
170 | x = sample["energy"]
171 | if per_atom:
172 | x = x / sample["num_atoms"]
173 | x -= bias
174 | x_sum += x
175 | x_2 += x ** 2.0
176 | num_objects += 1
177 | # Var(X) = E[X^2] - E[X]^2
178 | x_mean = x_sum / num_objects
179 | x_var = x_2 / num_objects - x_mean ** 2.0
180 | x_mean = x_mean + bias
181 |
182 | default_type = torch.get_default_dtype()
183 |
184 | return x_mean.type(default_type), torch.sqrt(x_var).type(default_type)
185 |
186 | def eval_model(model, dataloader, device, forces_weight):
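    |     """Evaluate energy/force MAE and RMSE and the weighted total loss over a dataloader."""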
187 | energy_running_ae = 0
188 | energy_running_se = 0
189 |
190 | forces_running_l2_ae = 0
191 | forces_running_l2_se = 0
192 | forces_running_c_ae = 0
193 | forces_running_c_se = 0
194 | forces_running_loss = 0
195 |
196 | running_loss = 0
197 | count = 0
198 | forces_count = 0
199 | criterion = torch.nn.MSELoss()
200 |
201 | for batch in dataloader:
202 | device_batch = {
203 | k: v.to(device=device, non_blocking=True) for k, v in batch.items()
204 | }
205 | out = model(device_batch)
206 |
207 |         # counts: structures for the energy terms, atoms for the force terms
208 | count += batch["energy"].shape[0]
209 | forces_count += batch['forces'].shape[0]
210 |
211 | # use mean square loss here
212 | forces_loss = forces_criterion(out["forces"], device_batch["forces"]).item()
213 | energy_loss = criterion(out["energy"], device_batch["energy"]).item() #problem here
214 | total_loss = forces_weight * forces_loss + (1 - forces_weight) * energy_loss
215 | running_loss += total_loss * batch["energy"].shape[0]
216 |
217 | # energy errors
218 | outputs = {key: val.detach().cpu().numpy() for key, val in out.items()}
219 | energy_targets = batch["energy"].detach().cpu().numpy()
220 | energy_running_ae += np.sum(np.abs(energy_targets - outputs["energy"]), axis=0)
221 | energy_running_se += np.sum(
222 | np.square(energy_targets - outputs["energy"]), axis=0
223 | )
224 |
225 | # force errors
226 | forces_targets = batch["forces"].detach().cpu().numpy()
227 | forces_diff = forces_targets - outputs["forces"]
228 | forces_l2_norm = np.sqrt(np.sum(np.square(forces_diff), axis=1))
229 |
230 | forces_running_c_ae += np.sum(np.abs(forces_diff))
231 | forces_running_c_se += np.sum(np.square(forces_diff))
232 |
233 | forces_running_l2_ae += np.sum(np.abs(forces_l2_norm))
234 | forces_running_l2_se += np.sum(np.square(forces_l2_norm))
235 |
236 | energy_mae = energy_running_ae / count
237 | energy_rmse = np.sqrt(energy_running_se / count)
238 |
239 | forces_l2_mae = forces_running_l2_ae / forces_count
240 | forces_l2_rmse = np.sqrt(forces_running_l2_se / forces_count)
241 |
242 | forces_c_mae = forces_running_c_ae / (forces_count * 3)
243 | forces_c_rmse = np.sqrt(forces_running_c_se / (forces_count * 3))
244 |
245 | total_loss = running_loss / count
246 |
247 | evaluation = {
248 | "energy_mae": energy_mae,
249 | "energy_rmse": energy_rmse,
250 | "forces_l2_mae": forces_l2_mae,
251 | "forces_l2_rmse": forces_l2_rmse,
252 | "forces_mae": forces_c_mae,
253 | "forces_rmse": forces_c_rmse,
254 | "sqrt(total_loss)": np.sqrt(total_loss),
255 | }
256 |
257 | return evaluation
258 |
259 | def update_namespace(ns, d):
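    |     """Fill namespace attributes from the config dict only where the command line left them unset (falsy)."""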
260 | for k, v in d.items():
261 | if not ns.__dict__.get(k):
262 | ns.__dict__[k] = v
263 |
264 | class EarlyStopping():
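    |     """Flag early stopping once the validation loss has exceeded the best loss more than
    |     `patience` times; note that the counter is not reset when the loss improves again."""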
265 | def __init__(self, patience=5, min_delta=0):
266 |
267 | self.patience = patience
268 | self.min_delta = min_delta
269 | self.counter = 0
270 | self.early_stop = False
271 |
272 | def __call__(self, val_loss, best_loss):
273 | if val_loss - best_loss > self.min_delta:
274 | self.counter +=1
275 | if self.counter >= self.patience:
276 | self.early_stop = True
277 |
278 | return self.early_stop
279 |
280 | def main():
281 | args = get_arguments()
282 | if args.cfg:
283 | with open(args.cfg, 'r') as f:
284 | params = toml.load(f)
285 | update_namespace(args, params)
286 |
287 | # Setup random seed
288 | setup_seed(args.random_seed)
289 |
290 | # Setup logging
291 | os.makedirs(args.output_dir, exist_ok=True)
292 | logging.basicConfig(
293 | level=logging.DEBUG,
294 | format="%(asctime)s [%(levelname)-5.5s] %(message)s",
295 | handlers=[
296 | logging.FileHandler(
297 | os.path.join(args.output_dir, "printlog.txt"), mode="w"
298 | ),
299 | logging.StreamHandler(),
300 | ],
301 | )
302 |
303 | # Save command line args
304 | with open(os.path.join(args.output_dir, "commandline_args.txt"), "w") as f:
305 | f.write("\n".join(sys.argv[1:]))
306 | # Save parsed command line arguments
307 | with open(os.path.join(args.output_dir, "arguments.json"), "w") as f:
308 | json.dump(vars(args), f)
309 |
310 | # Create device
311 | device = torch.device(args.device)
312 | # Put a tensor on the device before loading data
313 | # This way the GPU appears to be in use when other users run gpustat
314 | torch.tensor([0], device=device)
315 |
316 | # Setup dataset and loader
317 | logging.info("loading data %s", args.dataset)
318 | dataset = AseDataset(
319 | args.dataset,
320 | cutoff = args.cutoff,
321 | )
322 |
323 | datasplits = split_data(dataset, args)
324 |
325 | train_loader = torch.utils.data.DataLoader(
326 | datasplits["train"],
327 | args.batch_size,
328 | sampler=torch.utils.data.RandomSampler(datasplits["train"]),
329 | collate_fn=collate_atomsdata,
330 | )
331 | val_loader = torch.utils.data.DataLoader(
332 | datasplits["validation"],
333 | args.batch_size,
334 | collate_fn=collate_atomsdata,
335 | )
336 |
337 | logging.info('Dataset size: {}, training set size: {}, validation set size: {}'.format(
338 | len(dataset),
339 | len(datasplits["train"]),
340 | len(datasplits["validation"]),
341 | ))
342 |
343 | if args.normalization:
344 | logging.info("Computing mean and variance")
345 | target_mean, target_stddev = get_normalization(
346 | datasplits["train"],
347 | per_atom=args.atomwise_normalization,
348 | )
349 | logging.debug("target_mean=%f, target_stddev=%f" % (target_mean, target_stddev))
350 |
351 | net = PainnModel(
352 | num_interactions=args.num_interactions,
353 | hidden_state_size=args.node_size,
354 | cutoff=args.cutoff,
355 | normalization=args.normalization,
356 | target_mean=target_mean.tolist() if args.normalization else [0.0],
357 | target_stddev=target_stddev.tolist() if args.normalization else [1.0],
358 | atomwise_normalization=args.atomwise_normalization,
359 | )
360 | net.to(device)
361 |
362 | optimizer = torch.optim.Adam(net.parameters(), lr=args.initial_lr)
363 | criterion = torch.nn.MSELoss()
364 | if args.plateau_scheduler:
365 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=10)
366 | else:
367 | scheduler_fn = lambda step: 0.96 ** (step / 100000)
368 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_fn)
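    |         # LambdaLR scales the initial lr by 0.96 ** (step / 100000), a slow exponential decay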
369 | early_stop = EarlyStopping(patience=args.stop_patience)
370 |
371 | running_loss = 0
372 | running_loss_count = 0
373 | # used for smoothing loss
374 | prev_loss = None
375 | best_val_loss = np.inf
376 | step = 0
377 | training_time = 0
378 |
379 | if args.load_model:
380 | logging.info(f"Load model from {args.load_model}")
381 | state_dict = torch.load(args.load_model)
382 | net.load_state_dict(state_dict["model"])
383 | # step = state_dict["step"]
384 | # best_val_loss = state_dict["best_val_loss"]
385 | # optimizer.load_state_dict(state_dict["optimizer"])
386 | scheduler.load_state_dict(state_dict["scheduler"])
387 |
388 | for epoch in itertools.count():
389 | for batch_host in train_loader:
390 | start = time.time()
391 | # Transfer to 'device'
392 | batch = {
393 | k: v.to(device=device, non_blocking=True)
394 | for (k, v) in batch_host.items()
395 | }
396 | # Reset gradient
397 | optimizer.zero_grad()
398 |
399 | # Forward, backward and optimize
400 | outputs = net(
401 | batch, compute_forces=bool(args.forces_weight)
402 | )
403 | energy_loss = criterion(outputs["energy"], batch["energy"])
404 | if args.forces_weight:
405 | forces_loss = forces_criterion(outputs['forces'], batch['forces'])
406 | else:
407 | forces_loss = 0.0
408 | total_loss = (
409 | args.forces_weight * forces_loss
410 | + (1 - args.forces_weight) * energy_loss
411 | )
412 | total_loss.backward()
413 | optimizer.step()
414 | running_loss += total_loss.item() * batch["energy"].shape[0]
415 | running_loss_count += batch["energy"].shape[0]
416 | training_time += time.time() - start
417 |
418 | # print(step, loss_value)
419 | # Validate and save model
420 | if (step % args.log_interval == 0) or ((step + 1) == args.max_steps):
421 | eval_start = time.time()
422 | train_loss = running_loss / running_loss_count
423 | running_loss = running_loss_count = 0
424 |
425 | eval_dict = eval_model(net, val_loader, device, args.forces_weight)
426 | eval_formatted = ", ".join(
427 | ["{}={:.3f}".format(k, v) for (k, v) in eval_dict.items()]
428 | )
429 |             # smooth the validation loss with an exponential moving average (0.9 current, 0.1 previous)
430 |             eval_loss = np.square(eval_dict["sqrt(total_loss)"])
431 |             smooth_loss = eval_loss if prev_loss is None else 0.9 * eval_loss + 0.1 * prev_loss
432 | prev_loss = smooth_loss
433 |
434 | logging.info(
435 | "step={}, {}, sqrt(train_loss)={:.3f}, sqrt(smooth_loss)={:.3f}, patience={:3d}, training time={:.3f} min, eval time={:.3f} min".format(
436 | step,
437 | eval_formatted,
438 | math.sqrt(train_loss),
439 | math.sqrt(smooth_loss),
440 | early_stop.counter,
441 | training_time / 60,
442 | (time.time() - eval_start) / 60,
443 | )
444 | )
445 | training_time = 0
446 | # reduce learning rate
447 | if args.plateau_scheduler:
448 | scheduler.step(smooth_loss)
449 | # Save checkpoint
450 | if not early_stop(math.sqrt(smooth_loss), best_val_loss):
451 | best_val_loss = math.sqrt(smooth_loss)
452 | torch.save(
453 | {
454 | "model": net.state_dict(),
455 | "optimizer": optimizer.state_dict(),
456 | "scheduler": scheduler.state_dict(),
457 | "step": step,
458 | "best_val_loss": best_val_loss,
459 | "node_size": args.node_size,
460 | "num_layer": args.num_interactions,
461 | "cutoff": args.cutoff,
462 | },
463 | os.path.join(args.output_dir, "best_model.pth"),
464 | )
465 | else:
466 | sys.exit(0)
467 |
468 | step += 1
469 |
470 | if not args.plateau_scheduler:
471 | scheduler.step()
472 |
473 | if step >= args.max_steps:
474 | logging.info("Max steps reached, exiting")
475 | torch.save(
476 | {
477 | "model": net.state_dict(),
478 | "optimizer": optimizer.state_dict(),
479 | "scheduler": scheduler.state_dict(),
480 | "step": step,
481 | "best_val_loss": best_val_loss,
482 | "node_size": args.node_size,
483 | "num_layer": args.num_interactions,
484 | "cutoff": args.cutoff,
485 | },
486 | os.path.join(args.output_dir, "exit_model.pth"),
487 | )
488 | sys.exit(0)
489 |
490 | if __name__ == "__main__":
491 | main()
492 |
--------------------------------------------------------------------------------
/workflow/vasp.py:
--------------------------------------------------------------------------------
1 | from ase.calculators.vasp import Vasp
2 | from ase.io import read, write, Trajectory
3 | from shutil import copy
4 | import os, subprocess
5 | import numpy as np
6 | import argparse
7 | import json
8 | import toml
9 | from pathlib import Path
10 |
11 | def get_arguments(arg_list=None):
12 | parser = argparse.ArgumentParser(
13 | description="General Active Learning", fromfile_prefix_chars="+"
14 | )
15 | parser.add_argument(
16 | "--label_set",
17 | type=str,
18 | help="Path to trajectory to be labeled by DFT",
19 | )
20 | parser.add_argument(
21 | "--train_set",
22 | type=str,
23 | help="Path to existing training data set",
24 | )
25 | parser.add_argument(
26 | "--pool_set",
27 | type=str,
28 | help="Path to MD trajectory obtained from machine learning potential",
29 | )
30 | parser.add_argument(
31 | "--al_info",
32 | type=str,
33 | help="Path to json file that stores indices selected in pool set",
34 | )
35 | parser.add_argument(
36 | "--num_jobs",
37 | type=int,
38 | help="Number of DFT jobs",
39 | )
40 | parser.add_argument(
41 | "--job_order",
42 | type=int,
43 | help="Split DFT jobs to several different parts",
44 | )
45 | parser.add_argument(
46 | "--cfg",
47 | type=str,
48 | default="arguments.toml",
49 | help="Path to config file. e.g. 'arguments.toml'"
50 | )
51 |
52 | return parser.parse_args(arg_list)
53 |
54 | def update_namespace(ns, d):
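    |     """Copy top-level scalar values from the config dict onto the namespace;
    |     nested tables (e.g. the VASP settings) are read separately below."""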
55 | for k, v in d.items():
56 | if not isinstance(v, dict):
57 | ns.__dict__[k] = v
58 |
59 | def main():
60 |     # set environment variables (assigned via os.environ so the ASE Vasp calculator can read them)
61 |     os.environ['ASE_VASP_VDW'] = '/home/energy/modules/software/VASP/vasp-potpaw-5.4'
62 |     os.environ['VASP_PP_PATH'] = '/home/energy/modules/software/VASP/vasp-potpaw-5.4'
63 |     os.environ['ASE_VASP_COMMAND'] = 'mpirun vasp_std'
64 |
65 | args = get_arguments()
66 | if args.cfg:
67 | with open(args.cfg, 'r') as f:
68 | params = toml.load(f)
69 | update_namespace(args, params)
70 |
71 | # get images and set parameters
72 | if args.label_set:
73 | images = read(args.label_set, index = ':')
74 | elif args.pool_set:
75 | if isinstance(args.pool_set, list):
76 | pool_traj = []
77 | for pool_path in args.pool_set:
78 | if Path(pool_path).stat().st_size > 0:
79 | pool_traj += read(pool_path, ':')
80 | else:
81 | pool_traj = Trajectory(args.pool_set)
82 | with open(args.al_info) as f:
83 | indices = json.load(f)["selected"]
84 | if args.num_jobs:
85 | split_idx = np.array_split(indices, args.num_jobs)
86 | indices = split_idx[args.job_order]
87 | images = [pool_traj[i] for i in indices]
88 | else:
89 |         raise RuntimeError('Valid configurations for the DFT calculation must be provided!')
90 |
91 | vasp_params = params['VASP']
92 | calc = Vasp(**vasp_params)
93 | traj = Trajectory('dft_structures.traj', mode = 'a')
94 | check_result = False
95 | unconverged = Trajectory('unconverged.traj', mode = 'a')
96 | unconverged_idx = []
97 | for i, atoms in enumerate(images):
98 |         atoms.set_pbc([True, True, True])
99 |         atoms.calc = calc
100 | atoms.get_potential_energy()
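    |         # Count the SCF loops reported in OUTCAR; more loops than NELM is taken as a
    |         # sign that the electronic minimization did not converge.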
101 | steps = int(subprocess.getoutput('grep LOOP OUTCAR | wc -l'))
102 | if steps <= vasp_params['nelm']:
103 | traj.write(atoms)
104 | else:
105 | check_result = True
106 | unconverged.write(atoms)
107 | unconverged_idx.append(i)
108 | copy('OSZICAR', 'OSZICAR_{}'.format(i))
109 |
110 | traj.close()
111 | # write to training set
112 | if check_result:
113 | raise RuntimeError(f"DFT calculations of {unconverged_idx} are not converged!")
114 |
115 | if args.train_set:
116 | train_traj = Trajectory(args.train_set, mode = 'a')
117 | images = read('dft_structures.traj', ':')
118 | for atoms in images:
119 | atoms.info['system'] = args.system
120 | atoms.info['path'] = str(Path('dft_structures.traj').resolve())
121 | train_traj.write(atoms)
122 |
123 | os.remove('WAVECAR')
124 |
125 | if __name__ == "__main__":
126 | main()
127 |
--------------------------------------------------------------------------------