├── PaiNN ├── __init__.py ├── active_learning.py ├── calculator.py ├── data.py ├── kernel.py ├── model.py └── select.py ├── README.md ├── scripts ├── MD.traj ├── arguments.toml ├── gpu_info ├── gpu_run.sh ├── md_run.py ├── runner_output.log ├── train.py └── water_O2.cif ├── setup.py └── workflow ├── al_select.py ├── config.toml ├── flow.py ├── md_run.py ├── train.py └── vasp.py /PaiNN/__init__.py: -------------------------------------------------------------------------------- 1 | from PaiNN import * 2 | -------------------------------------------------------------------------------- /PaiNN/active_learning.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from collections import defaultdict 4 | from torch_scatter import scatter_mean 5 | from typing import List, Dict, Tuple, Optional 6 | from PaiNN.data import collate_atomsdata 7 | from PaiNN.select import * 8 | from PaiNN.kernel import * 9 | 10 | class FeatureExtractor(nn.Module): 11 | def __init__(self, model: nn.Module): 12 | super().__init__() 13 | self.model = model 14 | self._features = [] 15 | self._grads = [] 16 | self.hooks = [] 17 | for name, layer in self.model.named_modules(): 18 | if 'readout_mlp' in name and isinstance(layer, nn.Linear): 19 | self.hooks.append(layer.register_forward_pre_hook(self.save_feats_hook)) 20 | self.hooks.append(layer.register_backward_hook(self.save_grads_hook)) 21 | 22 | def save_feats_hook(self, _, in_feat): 23 | new_feat = torch.cat((in_feat[0].detach().clone(), torch.ones_like(in_feat[0][:, 0:1])), dim=-1) 24 | self._features.append(new_feat) 25 | 26 | def save_grads_hook(self, _, __, grad_output): 27 | self._grads.append(grad_output[0].detach().clone()) 28 | 29 | def unhook(self): 30 | for hook in self.hooks: 31 | hook.remove() 32 | 33 | def forward(self, model_inputs: Dict[str, torch.Tensor]): 34 | self._features = [] 35 | self._grads = [] 36 | _ = self.model(model_inputs) 37 | return self._features, self._grads[::-1] 38 | 39 | class RandomProjections: 40 | """Store parameters of random projections""" 41 | def __init__(self, model: nn.Module, num_features: int): 42 | self.num_features = num_features 43 | if self.num_features > 0: 44 | self.in_feat_proj = [ 45 | torch.randn(l.in_features +1, num_features, device=next(model.parameters()).device) 46 | for l in model.readout_mlp.children() if isinstance(l, nn.Linear) 47 | ] 48 | self.out_grad_proj = [ 49 | torch.randn(l.out_features, num_features, device=next(model.parameters()).device) 50 | for l in model.readout_mlp.children() if isinstance(l, nn.Linear) 51 | ] 52 | 53 | class FeatureStatistics: 54 | """ 55 | Generate features by giving models, pool, and training dataset 56 | """ 57 | def __init__( 58 | self, 59 | models: List[nn.Module], 60 | dataset: torch.utils.data.Dataset, 61 | random_projections: List[RandomProjections], 62 | batch_size: int=8, 63 | ): 64 | self.models = models 65 | self.batch_size = batch_size 66 | self.dataset = dataset 67 | self.random_projections = random_projections 68 | self.device = next(models[0].parameters()).device 69 | self.g = None 70 | self.ens_stats = None 71 | self.Fisher = None 72 | self.F_reg_inv = None 73 | 74 | def _compute_ens_stats(self, model_inputs: Dict[str, torch.Tensor], labeled_data: bool=False) -> Dict[str, torch.Tensor]: 75 | """Compute energy variance, forces variance, energy absolute error, and forces absolute error""" 76 | ens_stats = defaultdict(list) 77 | predictions = defaultdict(list) 78 | for model in 
self.models: 79 | model_results = model(model_inputs) 80 | predictions['energy'].append(model_results["energy"].detach()) 81 | predictions['forces'].append(model_results["forces"].detach()) 82 | 83 | predictions = {k: torch.stack(v) for k, v in predictions.items()} 84 | 85 | image_idx = torch.arange( 86 | model_inputs['num_atoms'].shape[0], 87 | device=model_inputs['num_atoms'].device, 88 | ) 89 | image_idx = torch.repeat_interleave(image_idx, model_inputs['num_atoms']) 90 | 91 | if len(self.models) > 1: 92 | E_var = torch.var(predictions['energy'], dim=0) 93 | F_var = torch.var(predictions['forces'], dim=0) 94 | F_var = scatter_mean(torch.mean(F_var, dim=-1), image_idx, dim=0) 95 | ens_stats['Energy-Var'] = E_var 96 | ens_stats['Forces-Var'] = F_var 97 | 98 | if labeled_data: 99 | E_AE = torch.abs(model_inputs['energy'] - torch.mean(predictions['energy'], dim=0)) 100 | F_AE = torch.abs(model_inputs['forces'] - torch.mean(predictions['forces'], dim=0)) 101 | F_AE = scatter_mean(torch.mean(F_AE, dim=-1), image_idx, dim=0) 102 | 103 | ens_stats['Energy-AE'] = E_AE 104 | ens_stats['Forces-AE'] = F_AE 105 | 106 | return ens_stats 107 | 108 | def _compute_features( 109 | self, 110 | feature_extractor: FeatureExtractor, 111 | model_inputs: torch.tensor, 112 | random_projection: RandomProjections, 113 | kernel: str='ll-gradient', 114 | ) -> torch.Tensor: 115 | """ 116 | Implementing features calculation and kernel transformation. 117 | Available features are: 118 | ll-gradient: last layer gradient feature, obtained from neural networks. 119 | full-gradient: All gradient information from NN, must use random projections kernel transformation. 120 | gnn: Features learned by message passing layers 121 | symmetry-function: Behler Parrinello symmetry function, can only be used for CUR. To be implemented. 122 | """ 123 | image_idx = torch.arange( 124 | model_inputs['num_atoms'].shape[0], 125 | device=model_inputs['num_atoms'].device, 126 | ) 127 | image_idx = torch.repeat_interleave(image_idx, model_inputs['num_atoms']) 128 | 129 | if kernel == 'full-gradient': 130 | assert random_projection.num_features != 0, "Error! Random projections must be provided!" 131 | feats, grads = feature_extractor(model_inputs) 132 | atomic_g = torch.zeros((image_idx.shape[0], random_projection.num_features)) 133 | for feat, grad, in_proj, out_proj in zip( 134 | feats, 135 | grads, 136 | random_projection.in_feat_proj, 137 | random_projection.out_grad_proj, 138 | ): 139 | atomic_g = (feat @ in_proj) * (grad @ out_proj) 140 | 141 | g = torch.zeros( 142 | (model_inputs['num_atoms'].shape[0], atomic_g.shape[1]), 143 | dtype = atomic_g.dtype, 144 | device = atomic_g.device, 145 | ).index_add(0, image_idx, atomic_g) 146 | elif kernel == 'local_full-g': 147 | assert random_projection.num_features != 0, "Error! Random projections must be provided!" 
148 | feats, grads = feature_extractor(model_inputs) 149 | atomic_g = torch.zeros((image_idx.shape[0], random_projection.num_features)) 150 | for feat, grad, in_proj, out_proj in zip( 151 | feats, 152 | grads, 153 | random_projection.in_feat_proj, 154 | random_projection.out_grad_proj, 155 | ): 156 | atomic_g = (feat @ in_proj) * (grad @ out_proj) 157 | g = atomic_g 158 | 159 | elif kernel == 'll-gradient': 160 | feats, grads = feature_extractor(model_inputs) 161 | if random_projection.num_features != 0: 162 | atomic_g = (feats[-1] @ random_projection.in_feat_proj[-1]) *\ 163 | (grads[-1] @ random_projection.out_grad_proj[-1]) 164 | else: 165 | atomic_g = feats[-1][:, :-1] 166 | 167 | g = torch.zeros( 168 | (model_inputs['num_atoms'].shape[0], atomic_g.shape[1]), 169 | dtype = atomic_g.dtype, 170 | device = atomic_g.device, 171 | ).index_add(0, image_idx, atomic_g) 172 | 173 | elif kernel == 'local_ll-g': 174 | feats, grads = feature_extractor(model_inputs) 175 | if random_projection.num_features != 0: 176 | atomic_g = (feats[-1] @ random_projection.in_feat_proj[-1]) *\ 177 | (grads[-1] @ random_projection.out_grad_proj[-1]) 178 | else: 179 | atomic_g = feats[-1][:, :-1] 180 | g = atomic_g 181 | 182 | elif kernel == 'gnn': 183 | feats, grads = feature_extractor(model_inputs) 184 | if random_projection.num_features != 0: 185 | atomic_g = (feats[0] @ random_projection.in_feat_proj[0]) *\ 186 | (grads[0] @ random_projection.out_grad_proj[0]) 187 | else: 188 | atomic_g = feats[0][:, :-1] 189 | 190 | g = torch.zeros( 191 | (model_inputs['num_atoms'].shape[0], atomic_g.shape[1]), 192 | dtype = atomic_g.dtype, 193 | device = atomic_g.device, 194 | ).index_add(0, image_idx, atomic_g) 195 | 196 | elif kernel == 'local_gnn': 197 | feats, grads = feature_extractor(model_inputs) 198 | if random_projection.num_features != 0: 199 | atomic_g = (feats[0] @ random_projection.in_feat_proj[0]) *\ 200 | (grads[0] @ random_projection.out_grad_proj[0]) 201 | else: 202 | atomic_g = feats[0][:, :-1] 203 | g = atomic_g 204 | 205 | return g 206 | 207 | def _compute_fisher(self, g: torch.Tensor) -> torch.Tensor: 208 | return torch.einsum('mci, mcj -> mij', g, g) 209 | 210 | def get_features( 211 | self, 212 | dataset: Optional[torch.utils.data.Dataset]=None, 213 | kernel: str='full-gradient', 214 | ) -> torch.Tensor: 215 | """ 216 | :return: Feature vector of ``shape=(n_models, n_structures, n_features)``. 
217 | """ 218 | if dataset == None: 219 | dataset = self.dataset 220 | else: 221 | self.dataset = dataset 222 | self.g = None 223 | 224 | if self.g == None: 225 | dataloader = torch.utils.data.DataLoader( 226 | dataset=dataset, 227 | batch_size=self.batch_size, 228 | collate_fn=collate_atomsdata, 229 | ) 230 | global_g = [] 231 | for i, model in enumerate(self.models): 232 | feat_extract = FeatureExtractor(model) 233 | model_g = [] 234 | for batch in dataloader: 235 | batch = {k: v.to(self.device) for k, v in batch.items()} 236 | model_g.append(self._compute_features( 237 | feat_extract, 238 | batch, 239 | kernel=kernel, 240 | random_projection=self.random_projections[i], 241 | )) 242 | feat_extract.unhook() 243 | model_g = torch.cat(model_g) 244 | # Normalization 245 | model_g = (model_g - torch.mean(model_g, dim=0)) / torch.var(model_g, dim=0) 246 | global_g.append(model_g) 247 | # global_g.append(torch.cat(model_g)) 248 | 249 | self.g = torch.stack(global_g) 250 | 251 | return self.g 252 | 253 | def get_num_atoms( 254 | self, 255 | dataset: Optional[torch.utils.data.Dataset]=None, 256 | ): 257 | if dataset == None: 258 | dataset = self.dataset 259 | else: 260 | self.dataset = dataset 261 | num_atoms = [] 262 | dataloader = torch.utils.data.DataLoader( 263 | dataset=dataset, 264 | batch_size=self.batch_size, 265 | collate_fn=collate_atomsdata, 266 | ) 267 | for batch in dataloader: 268 | batch = {k: v.to(self.device) for k, v in batch.items()} 269 | num_atoms.append(batch['num_atoms']) 270 | 271 | return torch.cat(num_atoms) 272 | 273 | def get_ens_stats(self, dataset: Optional[torch.utils.data.Dataset]=None) -> Dict[str, torch.Tensor]: 274 | """ 275 | :return: Dict of energy statistics 276 | """ 277 | if dataset == None: 278 | dataset = self.dataset 279 | else: 280 | self.dataset = dataset 281 | self.ens_stats = None 282 | 283 | if self.ens_stats == None: 284 | dataloader = torch.utils.data.DataLoader( 285 | dataset=dataset, 286 | batch_size=self.batch_size, 287 | collate_fn=collate_atomsdata, 288 | ) 289 | ens_stats = [] 290 | for batch in dataloader: 291 | batch = {k: v.to(self.device) for k, v in batch.items()} 292 | labeled_data = True if 'energy' in batch.keys() else False 293 | ens_stats.append(self._compute_ens_stats(batch, labeled_data)) 294 | 295 | self.ens_stats = {k: torch.cat([ens[k] for ens in ens_stats]) for k in ens_stats[0].keys()} 296 | 297 | return self.ens_stats 298 | 299 | def get_fisher(self) -> torch.Tensor: 300 | if self.Fisher is None: 301 | self.Fisher = self._compute_fisher(self.get_features()) 302 | return self.Fisher 303 | 304 | def get_F_inv(self) -> torch.Tensor: 305 | """ 306 | :return: Regularized inverse of Fisher matrix of "shape=(n_models, n_features, n_features)". 307 | """ 308 | if self.F_reg_inv is None: 309 | F = self.get_features() 310 | g = self.get_g() 311 | # empirical regularisation 312 | lam = torch.linalg.trace(F) / (g.shape[1] * g.shape[2]) 313 | self.F_train_reg_inv = torch.linalg.inv(F + lam * torch.eye(F.shape[1])) 314 | return self.F_train_reg_inv 315 | 316 | class GeneralActiveLearning: 317 | """Provides methods for selecting batches during active learning. 318 | 319 | :param kernel: Name of the kernel, e.g. "full-g", "ll-g", "full-F_inv", "ll-F_inv", "qbc-energy", "qbc-force". 320 | "random" produces random selection and "ae-energy" and "ae-force" select by absolute errors 321 | on the pool data, which is only possible if the pool data is already labeled. 
322 | :param selection: Selection method, one of "max_dist_greedy", "deterministic_CUR", "lcmd_greedy", "max_det_greedy" or "max_diag". 323 | :param n_random_features: If "n_random_features = 0", do not use random projections. 324 | Otherwise, use random projections of all linear-layer gradients. 325 | """ 326 | def __init__( 327 | self, 328 | kernel = 'full-g', 329 | selection = 'max_diag', 330 | n_random_features = 0, 331 | ): 332 | self.kernel = kernel 333 | self.selection = selection 334 | self.n_random_features = n_random_features 335 | 336 | def select( 337 | self, 338 | models: List[nn.Module], 339 | datasets: Dict[str, torch.utils.data.Dataset], 340 | batch_size: int = 8, 341 | al_batch_size: int = 100, 342 | ): 343 | """ 344 | models: pytorch models, 345 | dataset: a dictionary containing pool, train, and validation dataset, 346 | batch_size: batch size for extracting features, 347 | al_batch_size: active learning selection batch size 348 | """ 349 | if (self.kernel == 'qbc-energy' or self.kernel == 'qbc-force' or self.kernel == 'ae-energy' or 350 | self.kernel == 'ae-force' or self.kernel == 'random') and self.selection != 'max_diag': 351 | raise RuntimeError(f'{self.kernel} kernel can only be used with max_diag selection method,' 352 | f' not with {self.selection}!') 353 | random_projections = [RandomProjections(model, self.n_random_features) for model in models] 354 | 355 | stats = { 356 | key: FeatureStatistics(models, ds, random_projections, batch_size) 357 | for key, ds in datasets.items() 358 | } 359 | 360 | if self.selection == 'max_dist_greedy': 361 | matrix = self._get_kernel_matrix(stats['pool'], stats['train']) 362 | idxs = max_dist_greedy(matrix=matrix, batch_size=al_batch_size, n_train=len(datasets['train'])) 363 | elif self.selection == 'max_diag': 364 | matrix = self._get_kernel_matrix(stats['pool']) 365 | idxs = max_diag(matrix=matrix, batch_size=al_batch_size) 366 | elif self.selection == 'max_det_greedy': 367 | matrix = self._get_kernel_matrix(stats['pool']) 368 | idxs = max_det_greedy(matrix=matrix, batch_size=al_batch_size) 369 | elif self.selection == 'lcmd_greedy': 370 | matrix = self._get_kernel_matrix(stats['pool'], stats['train']) 371 | idxs = lcmd_greedy(matrix=matrix, batch_size=al_batch_size, n_train=len(datasets['train'])) 372 | elif self.selection == 'max_det_greedy_local': 373 | matrix, num_atoms = self._get_kernel_matrix(stats['pool']) 374 | idxs = max_det_greedy_local(matrix=matrix, batch_size=al_batch_size, num_atoms=num_atoms) 375 | else: 376 | raise NotImplementedError(f"Unknown selection method '{self.selection}' for active learning!") 377 | 378 | return idxs.cpu().tolist() 379 | 380 | 381 | def _get_kernel_matrix(self, pool_stats: FeatureStatistics, train_stats: Optional[FeatureStatistics]=None) -> KernelMatrix: 382 | stats_list = [pool_stats] if train_stats == None else [pool_stats, train_stats] 383 | 384 | if self.kernel == 'full-g': 385 | return FeatureKernelMatrix(torch.cat([s.get_features(kernel='full-gradient') for s in stats_list], dim=1)) 386 | elif self.kernel == 'll-g': 387 | return FeatureKernelMatrix(torch.cat([s.get_features(kernel='ll-gradient') for s in stats_list], dim=1)) 388 | elif self.kernel == 'gnn': 389 | return FeatureKernelMatrix(torch.cat([s.get_features(kernel='gnn') for s in stats_list], dim=1)) 390 | elif self.kernel == 'local_full-g': 391 | matrix = FeatureKernelMatrix(torch.cat([s.get_features(kernel='local_full-g') for s in stats_list], dim=1)) 392 | num_atoms = torch.cat([s.get_num_atoms() for s in stats_list]) 
393 | return matrix, num_atoms 394 | elif self.kernel == 'local_ll-g': 395 | matrix = FeatureKernelMatrix(torch.cat([s.get_features(kernel='local_ll-g') for s in stats_list], dim=1)) 396 | num_atoms = torch.cat([s.get_num_atoms() for s in stats_list]) 397 | return matrix, num_atoms 398 | elif self.kernel == 'local_gnn': 399 | matrix = FeatureKernelMatrix(torch.cat([s.get_features(kernel='local_gnn') for s in stats_list], dim=1)) 400 | num_atoms = torch.cat([s.get_num_atoms() for s in stats_list]) 401 | return matrix, num_atoms 402 | elif self.kernel == 'full-F_inv': 403 | return FeatureCovKernelMatrix(torch.cat([s.get_features(kernel='full-gradient') for s in stats_list], dim=1), 404 | train_stats.get_F_reg_inv()) 405 | elif self.kernel == 'll-F_inv': 406 | return FeatureCovKernelMatrix(torch.cat([s.get_features(kernel='ll-gradient') for s in stats_list], dim=1), 407 | train_stats.get_F_reg_inv()) 408 | elif self.kernel == 'qbc-energy': 409 | return DiagonalKernelMatrix(pool_stats.get_ens_stats()['Energy-Var']) 410 | elif self.kernel == 'qbc-force': 411 | return DiagonalKernelMatrix(pool_stats.get_ens_stats()['Forces-Var']) 412 | elif self.kernel == 'ae-energy': 413 | return DiagonalKernelMatrix(pool_stats.get_ens_stats()['Energy-AE']) 414 | elif self.kernel == 'ae-force': 415 | return DiagonalKernelMatrix(pool_stats.get_ens_stats()['Forces-AE']) 416 | elif self.kernel == 'random': 417 | return DiagonalKernelMatrix(torch.rand([sum([len(s.dataset) for s in stats_list])])) 418 | else: 419 | raise RuntimeError(f"Unknown active learning kernel {self.kernel}!") -------------------------------------------------------------------------------- /PaiNN/calculator.py: -------------------------------------------------------------------------------- 1 | from ase.calculators.calculator import Calculator, all_changes 2 | from PaiNN.data import AseDataReader 3 | import numpy as np 4 | 5 | class MLCalculator(Calculator): 6 | implemented_properties = ["energy", "forces"] 7 | 8 | def __init__( 9 | self, 10 | model, 11 | energy_scale=1.0, 12 | forces_scale=1.0, 13 | # stress_scale=1.0, 14 | **kwargs 15 | ): 16 | super().__init__(**kwargs) 17 | 18 | self.model = model 19 | self.model_device = next(model.parameters()).device 20 | self.cutoff = model.cutoff 21 | self.ase_data_reader = AseDataReader(self.cutoff) 22 | self.energy_scale = energy_scale 23 | self.forces_scale = forces_scale 24 | # self.stress_scale = stress_scale 25 | 26 | def calculate(self, atoms=None, properties=["energy"], system_changes=all_changes): 27 | """ 28 | Args: 29 | atoms (ase.Atoms): ASE atoms object. 30 | properties (list of str): do not use this, no functionality 31 | system_changes (list of str): List of changes for ASE. 
32 | """ 33 | # First call original calculator to set atoms attribute 34 | # (see https://wiki.fysik.dtu.dk/ase/_modules/ase/calculators/calculator.html#Calculator) 35 | if atoms is not None: 36 | self.atoms = atoms.copy() 37 | 38 | model_inputs = self.ase_data_reader(self.atoms) 39 | model_inputs = { 40 | k: v.to(self.model_device) for (k, v) in model_inputs.items() 41 | } 42 | 43 | model_results = self.model(model_inputs) 44 | 45 | results = {} 46 | 47 | # Convert outputs to calculator format 48 | results["forces"] = ( 49 | model_results["forces"].detach().cpu().numpy() * self.forces_scale 50 | ) 51 | results["energy"] = ( 52 | model_results["energy"][0].detach().cpu().numpy().item() 53 | * self.energy_scale 54 | ) 55 | # results["stress"] = ( 56 | # model_results["stress"][0].detach().cpu().numpy() * self.stress_scale 57 | # ) 58 | # atoms.info["ll_out"] = { 59 | # k: v.detach().cpu().numpy() for k, v in model_results["ll_out"].items() 60 | # } 61 | if model_results.get("fps"): 62 | atoms.info["fps"] = model_results["fps"].detach().cpu().numpy() 63 | 64 | self.results = results 65 | 66 | class EnsembleCalculator(Calculator): 67 | implemented_properties = ["energy", "forces"] 68 | 69 | def __init__( 70 | self, 71 | models, 72 | energy_scale=1.0, 73 | forces_scale=1.0, 74 | # stress_scale=1.0, 75 | **kwargs 76 | ): 77 | super().__init__(**kwargs) 78 | 79 | self.models = models 80 | self.model_device = next(models[0].parameters()).device 81 | self.cutoff = models[0].cutoff 82 | self.ase_data_reader = AseDataReader(self.cutoff) 83 | self.energy_scale = energy_scale 84 | self.forces_scale = forces_scale 85 | # self.stress_scale = stress_scale 86 | 87 | def calculate(self, atoms=None, properties=["energy"], system_changes=all_changes): 88 | """ 89 | Args: 90 | atoms (ase.Atoms): ASE atoms object. 91 | properties (list of str): do not use this, no functionality 92 | system_changes (list of str): List of changes for ASE. 
93 | """ 94 | # First call original calculator to set atoms attribute 95 | # (see https://wiki.fysik.dtu.dk/ase/_modules/ase/calculators/calculator.html#Calculator) 96 | if atoms is not None: 97 | self.atoms = atoms.copy() 98 | 99 | model_inputs = self.ase_data_reader(self.atoms) 100 | model_inputs = { 101 | k: v.to(self.model_device) for (k, v) in model_inputs.items() 102 | } 103 | 104 | predictions = {'energy': [], 'forces': []} 105 | for model in self.models: 106 | model_results = model(model_inputs) 107 | predictions['energy'].append(model_results["energy"][0].detach().cpu().numpy().item() * self.energy_scale) 108 | predictions['forces'].append(model_results["forces"].detach().cpu().numpy() * self.forces_scale) 109 | 110 | results = {"energy": np.mean(predictions['energy'])} 111 | results["forces"] = np.mean(np.stack(predictions['forces']), axis=0) 112 | 113 | ensemble = { 114 | 'energy_var': np.var(predictions['energy']), 115 | 'forces_var': np.var(np.stack(predictions['forces']), axis=0), 116 | 'forces_l2_var': np.var(np.linalg.norm(predictions['forces'], axis=2), axis=0), 117 | } 118 | 119 | results['ensemble'] = ensemble 120 | 121 | self.results = results 122 | -------------------------------------------------------------------------------- /PaiNN/data.py: -------------------------------------------------------------------------------- 1 | from ase.io import read, write, Trajectory 2 | import torch 3 | from typing import List 4 | import asap3 5 | import numpy as np 6 | from scipy.spatial import distance_matrix 7 | 8 | # def ase_properties(atoms): 9 | # """Guess dataset format from an ASE atoms""" 10 | # atoms_prop = [] 11 | # 12 | # if atoms.pbc.any(): 13 | # atoms_prop.append('cell') 14 | # 15 | # try: 16 | # atoms.get_potential_energy() 17 | # atoms_prop.append('energy') 18 | # except: 19 | # pass 20 | # 21 | # try: 22 | # atoms.get_forces() 23 | # atoms_prop.append('forces') 24 | # except: 25 | # pass 26 | # 27 | # return atoms_prop 28 | 29 | class AseDataReader: 30 | def __init__(self, cutoff=5.0): 31 | self.cutoff = cutoff 32 | 33 | def __call__(self, atoms): 34 | atoms_data = { 35 | 'num_atoms': torch.tensor([atoms.get_global_number_of_atoms()]), 36 | 'elems': torch.tensor(atoms.numbers), 37 | 'coord': torch.tensor(atoms.positions, dtype=torch.float), 38 | } 39 | 40 | if atoms.pbc.any(): 41 | pairs, n_diff = self.get_neighborlist(atoms) 42 | atoms_data['cell'] = torch.tensor(atoms.cell[:], dtype=torch.float) 43 | else: 44 | pairs, n_diff = self.get_neighborlist_simple(atoms) 45 | 46 | atoms_data['pairs'] = torch.from_numpy(pairs) 47 | atoms_data['n_diff'] = torch.from_numpy(n_diff).float() 48 | atoms_data['num_pairs'] = torch.tensor([pairs.shape[0]]) 49 | 50 | try: 51 | energy = torch.tensor([atoms.get_potential_energy()], dtype=torch.float) 52 | atoms_data['energy'] = energy 53 | except (AttributeError, RuntimeError): 54 | pass 55 | 56 | try: 57 | forces = torch.tensor(atoms.get_forces(apply_constraint=False), dtype=torch.float) 58 | atoms_data['forces'] = forces 59 | except (AttributeError, RuntimeError): 60 | pass 61 | 62 | return atoms_data 63 | 64 | 65 | def get_neighborlist(self, atoms): 66 | nl = asap3.FullNeighborList(self.cutoff, atoms) 67 | pair_i_idx = [] 68 | pair_j_idx = [] 69 | n_diff = [] 70 | for i in range(len(atoms)): 71 | indices, diff, _ = nl.get_neighbors(i) 72 | pair_i_idx += [i] * len(indices) # local index of pair i 73 | pair_j_idx.append(indices) # local index of pair j 74 | n_diff.append(diff) 75 | 76 | pair_j_idx = np.concatenate(pair_j_idx) 77 | 
pairs = np.stack((pair_i_idx, pair_j_idx), axis=1) 78 | n_diff = np.concatenate(n_diff) 79 | 80 | return pairs, n_diff 81 | 82 | def get_neighborlist_simple(self, atoms): 83 | pos = atoms.get_positions() 84 | dist_mat = distance_matrix(pos, pos) 85 | mask = dist_mat < self.cutoff 86 | np.fill_diagonal(mask, False) 87 | pairs = np.argwhere(mask) 88 | n_diff = pos[pairs[:, 1]] - pos[pairs[:, 0]] 89 | 90 | return pairs, n_diff 91 | 92 | class AseDataset(torch.utils.data.Dataset): 93 | def __init__(self, ase_db, cutoff=5.0, **kwargs): 94 | super().__init__(**kwargs) 95 | 96 | if isinstance(ase_db, str): 97 | self.db = Trajectory(ase_db) 98 | else: 99 | self.db = ase_db 100 | 101 | self.cutoff = cutoff 102 | self.atoms_reader = AseDataReader(cutoff) 103 | 104 | def __len__(self): 105 | return len(self.db) 106 | 107 | def __getitem__(self, idx): 108 | atoms = self.db[idx] 109 | atoms_data = self.atoms_reader(atoms) 110 | return atoms_data 111 | 112 | def cat_tensors(tensors: List[torch.Tensor]): 113 | if tensors[0].shape: 114 | return torch.cat(tensors) 115 | return torch.stack(tensors) 116 | 117 | def collate_atomsdata(atoms_data: List[dict], pin_memory=True): 118 | # convert from list of dicts to dict of lists 119 | dict_of_lists = {k: [dic[k] for dic in atoms_data] for k in atoms_data[0]} 120 | if pin_memory: 121 | pin = lambda x: x.pin_memory() 122 | else: 123 | pin = lambda x: x 124 | 125 | collated = {k: cat_tensors(v) for k, v in dict_of_lists.items()} 126 | return collated 127 | -------------------------------------------------------------------------------- /PaiNN/kernel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class KernelMatrix: 4 | """Abstract kernel class used to calculate kernel matrix by giving a feature matrix""" 5 | def __init__(self, num_col: int): 6 | self.num_columns = num_col 7 | 8 | def get_number_of_columns(self) -> int: 9 | return self.num_columns 10 | 11 | def get_column(self, i: int) -> torch.Tensor: 12 | raise RuntimeError("Not implemented") 13 | 14 | def get_diag(self) -> torch.Tensor: 15 | raise RuntimeError("Not implemented") 16 | 17 | def get_sq_dists(self, i: int) -> torch.Tensor: 18 | diag = self.get_diag() 19 | return diag[i] + diag - 2 * self.get_column(i) 20 | 21 | class DiagonalKernelMatrix(KernelMatrix): 22 | """ 23 | Represents a diagonal kernel matrix, where get_column() and get_sq_dists() is not implemented. 24 | 25 | :param g: Diagonal of the kernel matrix. 
26 | """ 27 | def __init__(self, g: torch.Tensor): 28 | super().__init__(g.shape[0]) 29 | self.diag = g 30 | 31 | def get_diag(self) -> torch.Tensor: 32 | return self.diag 33 | 34 | class FeatureKernelMatrix(KernelMatrix): 35 | """ 36 | input: m x n x p matrix 37 | m: number of models 38 | n: number of entries 39 | p: dimensionality of features 40 | """ 41 | def __init__(self, mat: torch.Tensor): 42 | super().__init__(mat.shape[1]) 43 | self.mat = mat 44 | self.diag = torch.einsum('mbi, mbi -> mb', mat, mat) 45 | 46 | def get_column(self, i: int) -> torch.Tensor: 47 | return torch.mean(torch.einsum("mnp, mp -> mn", self.mat, self.mat[:, i, :]), dim=0) 48 | 49 | def get_diag(self) -> torch.Tensor: 50 | return torch.mean(self.diag, dim=0) 51 | 52 | class FeatureCovKernelMatrix(KernelMatrix): 53 | """ 54 | input: m x n x p matrix mat, m x p x p covariance matrix 55 | m: number of models 56 | n: number of entries 57 | p: dimensionality of features 58 | """ 59 | def __init__(self, g: torch.Tensor, cov_mat: torch.Tensor): 60 | super().__init__(mat.shape[1]) 61 | self.g = g 62 | self.cov_mat = cov_mat 63 | self.cov_g = torch.einsum('mij, mbi -> mbj', self.cov_mat, g) 64 | self.diag = torch.einsum('mbi, mbi -> mb', self.cov_g, g) 65 | 66 | def get_diag(self) -> torch.Tensor: 67 | return torch.mean(self.diag, dim=0) 68 | 69 | def get_column(self, i: int) -> torch.Tensor: 70 | return torch.mean(torch.einsum('mbi, mi -> mb', self.g, self.cov_g[:, i, :]), dim=0) 71 | -------------------------------------------------------------------------------- /PaiNN/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | def sinc_expansion(edge_dist: torch.Tensor, edge_size: int, cutoff: float): 5 | """ 6 | calculate sinc radial basis function: 7 | 8 | sin(n *pi*d/d_cut)/d 9 | """ 10 | n = torch.arange(edge_size, device=edge_dist.device) + 1 11 | return torch.sin(edge_dist.unsqueeze(-1) * n * torch.pi / cutoff) / edge_dist.unsqueeze(-1) 12 | 13 | def cosine_cutoff(edge_dist: torch.Tensor, cutoff: float): 14 | """ 15 | Calculate cutoff value based on distance. 
16 | This uses the cosine Behler-Parinello cutoff function: 17 | 18 | f(d) = 0.5*(cos(pi*d/d_cut)+1) for d < d_cut and 0 otherwise 19 | """ 20 | 21 | return torch.where( 22 | edge_dist < cutoff, 23 | 0.5 * (torch.cos(torch.pi * edge_dist / cutoff) + 1), 24 | torch.tensor(0.0, device=edge_dist.device, dtype=edge_dist.dtype), 25 | ) 26 | 27 | class PainnMessage(nn.Module): 28 | """Message function""" 29 | def __init__(self, node_size: int, edge_size: int, cutoff: float): 30 | super().__init__() 31 | 32 | self.edge_size = edge_size 33 | self.node_size = node_size 34 | self.cutoff = cutoff 35 | 36 | self.scalar_message_mlp = nn.Sequential( 37 | nn.Linear(node_size, node_size), 38 | nn.SiLU(), 39 | nn.Linear(node_size, node_size * 3), 40 | ) 41 | 42 | self.filter_layer = nn.Linear(edge_size, node_size * 3) 43 | 44 | def forward(self, node_scalar, node_vector, edge, edge_diff, edge_dist): 45 | # remember to use v_j, s_j but not v_i, s_i 46 | filter_weight = self.filter_layer(sinc_expansion(edge_dist, self.edge_size, self.cutoff)) 47 | filter_weight = filter_weight * cosine_cutoff(edge_dist, self.cutoff).unsqueeze(-1) 48 | scalar_out = self.scalar_message_mlp(node_scalar) 49 | filter_out = filter_weight * scalar_out[edge[:, 1]] 50 | 51 | gate_state_vector, gate_edge_vector, message_scalar = torch.split( 52 | filter_out, 53 | self.node_size, 54 | dim = 1, 55 | ) 56 | 57 | # num_pairs * 3 * node_size, num_pairs * node_size 58 | message_vector = node_vector[edge[:, 1]] * gate_state_vector.unsqueeze(1) 59 | edge_vector = gate_edge_vector.unsqueeze(1) * (edge_diff / edge_dist.unsqueeze(-1)).unsqueeze(-1) 60 | message_vector = message_vector + edge_vector 61 | 62 | # sum message 63 | residual_scalar = torch.zeros_like(node_scalar) 64 | residual_vector = torch.zeros_like(node_vector) 65 | residual_scalar.index_add_(0, edge[:, 0], message_scalar) 66 | residual_vector.index_add_(0, edge[:, 0], message_vector) 67 | 68 | # new node state 69 | new_node_scalar = node_scalar + residual_scalar 70 | new_node_vector = node_vector + residual_vector 71 | 72 | return new_node_scalar, new_node_vector 73 | 74 | class PainnUpdate(nn.Module): 75 | """Update function""" 76 | def __init__(self, node_size: int): 77 | super().__init__() 78 | 79 | self.update_U = nn.Linear(node_size, node_size) 80 | self.update_V = nn.Linear(node_size, node_size) 81 | 82 | self.update_mlp = nn.Sequential( 83 | nn.Linear(node_size * 2, node_size), 84 | nn.SiLU(), 85 | nn.Linear(node_size, node_size * 3), 86 | ) 87 | 88 | def forward(self, node_scalar, node_vector): 89 | Uv = self.update_U(node_vector) 90 | Vv = self.update_V(node_vector) 91 | 92 | Vv_norm = torch.linalg.norm(Vv, dim=1) 93 | mlp_input = torch.cat((Vv_norm, node_scalar), dim=1) 94 | mlp_output = self.update_mlp(mlp_input) 95 | 96 | a_vv, a_sv, a_ss = torch.split( 97 | mlp_output, 98 | node_vector.shape[-1], 99 | dim = 1, 100 | ) 101 | 102 | delta_v = a_vv.unsqueeze(1) * Uv 103 | inner_prod = torch.sum(Uv * Vv, dim=1) 104 | delta_s = a_sv * inner_prod + a_ss 105 | 106 | return node_scalar + delta_s, node_vector + delta_v 107 | 108 | class PainnModel(nn.Module): 109 | """PainnModel without edge updating""" 110 | def __init__( 111 | self, 112 | num_interactions, 113 | hidden_state_size, 114 | cutoff, 115 | normalization=True, 116 | target_mean=[0.0], 117 | target_stddev=[1.0], 118 | atomwise_normalization=True, 119 | **kwargs, 120 | ): 121 | super().__init__() 122 | 123 | num_embedding = 119 # number of all elements 124 | self.cutoff = cutoff 125 | self.num_interactions = 
num_interactions 126 | self.hidden_state_size = hidden_state_size 127 | self.edge_embedding_size = 20 128 | 129 | # Setup atom embeddings 130 | self.atom_embedding = nn.Embedding(num_embedding, hidden_state_size) 131 | 132 | # Setup message-passing layers 133 | self.message_layers = nn.ModuleList( 134 | [ 135 | PainnMessage(self.hidden_state_size, self.edge_embedding_size, self.cutoff) 136 | for _ in range(self.num_interactions) 137 | ] 138 | ) 139 | self.update_layers = nn.ModuleList( 140 | [ 141 | PainnUpdate(self.hidden_state_size) 142 | for _ in range(self.num_interactions) 143 | ] 144 | ) 145 | 146 | # Setup readout function 147 | self.readout_mlp = nn.Sequential( 148 | nn.Linear(self.hidden_state_size, self.hidden_state_size), 149 | nn.SiLU(), 150 | nn.Linear(self.hidden_state_size, 1), 151 | ) 152 | 153 | # Normalisation constants 154 | self.normalization = torch.nn.Parameter( 155 | torch.tensor(normalization), requires_grad=False 156 | ) 157 | self.atomwise_normalization = torch.nn.Parameter( 158 | torch.tensor(atomwise_normalization), requires_grad=False 159 | ) 160 | self.normalize_stddev = torch.nn.Parameter( 161 | torch.tensor(target_stddev[0]), requires_grad=False 162 | ) 163 | self.normalize_mean = torch.nn.Parameter( 164 | torch.tensor(target_mean[0]), requires_grad=False 165 | ) 166 | 167 | def forward(self, input_dict, compute_forces=True): 168 | num_atoms = input_dict['num_atoms'] 169 | num_pairs = input_dict['num_pairs'] 170 | 171 | # edge offset. Add offset to edges to get indices of pairs in a batch but not a structure 172 | edge = input_dict['pairs'] 173 | edge_offset = torch.cumsum( 174 | torch.cat((torch.tensor([0], 175 | device=num_atoms.device, 176 | dtype=num_atoms.dtype, 177 | ), num_atoms[:-1])), 178 | dim=0 179 | ) 180 | edge_offset = torch.repeat_interleave(edge_offset, num_pairs) 181 | edge = edge + edge_offset.unsqueeze(-1) 182 | edge_diff = input_dict['n_diff'] 183 | if compute_forces: 184 | edge_diff.requires_grad_() 185 | edge_dist = torch.linalg.norm(edge_diff, dim=1) 186 | 187 | node_scalar = self.atom_embedding(input_dict['elems']) 188 | node_vector = torch.zeros((input_dict['coord'].shape[0], 3, self.hidden_state_size), 189 | device=edge_diff.device, 190 | dtype=edge_diff.dtype, 191 | ) 192 | 193 | for message_layer, update_layer in zip(self.message_layers, self.update_layers): 194 | node_scalar, node_vector = message_layer(node_scalar, node_vector, edge, edge_diff, edge_dist) 195 | node_scalar, node_vector = update_layer(node_scalar, node_vector) 196 | 197 | node_scalar = self.readout_mlp(node_scalar) 198 | node_scalar.squeeze_() 199 | 200 | image_idx = torch.arange(input_dict['num_atoms'].shape[0], 201 | device=edge.device, 202 | ) 203 | image_idx = torch.repeat_interleave(image_idx, num_atoms) 204 | 205 | energy = torch.zeros_like(input_dict['num_atoms']).float() 206 | energy.index_add_(0, image_idx, node_scalar) 207 | 208 | # Apply (de-)normalization 209 | if self.normalization: 210 | normalizer = self.normalize_stddev 211 | energy = normalizer * energy 212 | mean_shift = self.normalize_mean 213 | if self.atomwise_normalization: 214 | mean_shift = input_dict["num_atoms"] * mean_shift 215 | energy = energy + mean_shift 216 | 217 | result_dict = {'energy': energy} 218 | 219 | if compute_forces: 220 | dE_ddiff = torch.autograd.grad( 221 | energy, 222 | edge_diff, 223 | grad_outputs=torch.ones_like(energy), 224 | retain_graph=True, 225 | create_graph=True, 226 | )[0] 227 | 228 | # diff = R_j - R_i, so -dE/dR_j = -dE/ddiff, -dE/R_i = dE/ddiff 229 | 
i_forces = torch.zeros_like(input_dict['coord']).index_add(0, edge[:, 0], dE_ddiff) 230 | j_forces = torch.zeros_like(input_dict['coord']).index_add(0, edge[:, 1], -dE_ddiff) 231 | forces = i_forces + j_forces 232 | 233 | result_dict['forces'] = forces 234 | 235 | return result_dict 236 | 237 | class PainnModel_predict(nn.Module): 238 | """PainnModel without edge updating""" 239 | def __init__(self, num_interactions, hidden_state_size, cutoff, **kwargs): 240 | super().__init__() 241 | 242 | num_embedding = 119 # number of all elements 243 | self.atom_embedding = nn.Embedding(num_embedding, hidden_state_size) 244 | self.cutoff = cutoff 245 | self.num_interactions = num_interactions 246 | self.hidden_state_size = hidden_state_size 247 | self.edge_embedding_size = 20 248 | 249 | self.message_layers = nn.ModuleList( 250 | [ 251 | PainnMessage(self.hidden_state_size, self.edge_embedding_size, self.cutoff) 252 | for _ in range(self.num_interactions) 253 | ] 254 | ) 255 | 256 | self.update_layers = nn.ModuleList( 257 | [ 258 | PainnUpdate(self.hidden_state_size) 259 | for _ in range(self.num_interactions) 260 | ] 261 | ) 262 | 263 | self.linear_1 = nn.Linear(self.hidden_state_size, self.hidden_state_size) 264 | self.silu = nn.SiLU() 265 | self.linear_2 = nn.Linear(self.hidden_state_size, 1) 266 | U_in_0 = torch.randn(self.hidden_state_size, 500) / 500 ** 0.5 267 | U_out_1 = torch.randn(self.hidden_state_size, 500) / 500 ** 0.5 268 | U_in_1 = torch.randn(self.hidden_state_size, 500) / 500 ** 0.5 269 | self.register_buffer('U_in_0', U_in_0) 270 | self.register_buffer('U_out_1', U_out_1) 271 | self.register_buffer('U_in_1', U_in_1) 272 | 273 | def forward(self, input_dict, compute_forces=True): 274 | # edge offset 275 | num_atoms = input_dict['num_atoms'] 276 | num_pairs = input_dict['num_pairs'] 277 | 278 | edge = input_dict['pairs'] 279 | edge_offset = torch.cumsum( 280 | torch.cat((torch.tensor([0], 281 | device=num_atoms.device, 282 | dtype=num_atoms.dtype, 283 | ), num_atoms[:-1])), 284 | dim=0 285 | ) 286 | edge_offset = torch.repeat_interleave(edge_offset, num_pairs) 287 | edge = edge + edge_offset.unsqueeze(-1) 288 | edge_diff = input_dict['n_diff'] 289 | if compute_forces: 290 | edge_diff.requires_grad_() 291 | edge_dist = torch.linalg.norm(edge_diff, dim=1) 292 | 293 | node_scalar = self.atom_embedding(input_dict['elems']) 294 | node_vector = torch.zeros((input_dict['coord'].shape[0], 3, self.hidden_state_size), 295 | device=edge_diff.device, 296 | dtype=edge_diff.dtype, 297 | ) 298 | 299 | for message_layer, update_layer in zip(self.message_layers, self.update_layers): 300 | node_scalar, node_vector = message_layer(node_scalar, node_vector, edge, edge_diff, edge_dist) 301 | node_scalar, node_vector = update_layer(node_scalar, node_vector) 302 | 303 | x0 = node_scalar 304 | z1 = self.linear_1(x0) 305 | z1.retain_grad() 306 | x1 = self.silu(z1) 307 | node_scalar = self.linear_2(x1) 308 | 309 | node_scalar.squeeze_() 310 | 311 | image_idx = torch.arange(input_dict['num_atoms'].shape[0], 312 | device=edge.device, 313 | ) 314 | image_idx = torch.repeat_interleave(image_idx, num_atoms) 315 | 316 | energy = torch.zeros_like(input_dict['num_atoms']).float() 317 | 318 | energy.index_add_(0, image_idx, node_scalar) 319 | result_dict = {'energy': energy} 320 | 321 | if compute_forces: 322 | dE_ddiff = torch.autograd.grad( 323 | energy, 324 | edge_diff, 325 | grad_outputs=torch.ones_like(energy), 326 | retain_graph=True, 327 | create_graph=True, 328 | )[0] 329 | 330 | # diff = R_j - R_i, so -dE/dR_j 
= -dE/ddiff, -dE/R_i = dE/ddiff 331 | i_forces = torch.zeros_like(input_dict['coord']).index_add(0, edge[:, 0], dE_ddiff) 332 | j_forces = torch.zeros_like(input_dict['coord']).index_add(0, edge[:, 1], -dE_ddiff) 333 | forces = i_forces + j_forces 334 | 335 | result_dict['forces'] = forces 336 | 337 | fps = torch.sum((x0.detach() @ self.U_in_0) * (z1.grad.detach() @ self.U_out_1) * 500 ** 0.5 + x1.detach() @ self.U_in_1, dim=0) 338 | # result_dict['ll_out'] = { 339 | # 'll_out_x0': x0.detach(), 340 | # 'll_out_z1': z1.grad.detach(), 341 | # 'll_out_x1': x1.detach(), 342 | # } 343 | result_dict['fps'] = fps 344 | del z1.grad 345 | return result_dict 346 | -------------------------------------------------------------------------------- /PaiNN/select.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PaiNN.kernel import KernelMatrix 3 | 4 | def max_diag(matrix: KernelMatrix, batch_size: int) -> torch.Tensor: 5 | """ 6 | maximize uncertainty selection method 7 | """ 8 | return torch.argsort(matrix.get_diag())[-batch_size:] 9 | 10 | def max_det_greedy(matrix: KernelMatrix, batch_size: int) -> torch.Tensor: 11 | vec_c = matrix.get_diag() 12 | batch_idxs = [torch.argmax(vec_c)] 13 | 14 | l_n = None 15 | 16 | for n in range(1, batch_size): 17 | opt_idx = batch_idxs[-1] 18 | l_n_T_l_n = 0.0 if l_n is None else torch.einsum('w,wc->c', l_n[:, opt_idx], l_n) 19 | mat_col = matrix.get_column(opt_idx) 20 | update = (1 / torch.sqrt(vec_c[opt_idx])) * (mat_col - l_n_T_l_n) 21 | vec_c = vec_c - update ** 2 22 | l_n = update.unsqueeze(0) if l_n is None else torch.concat((l_n, update.unsqueeze(0))) 23 | new_idx = torch.argmax(vec_c) 24 | if vec_c[new_idx] <= 1e-12 or new_idx in batch_idxs: 25 | break 26 | else: 27 | batch_idxs.append(new_idx) 28 | 29 | batch_idxs = torch.hstack(batch_idxs) 30 | return batch_idxs 31 | 32 | def max_det_greedy_local(matrix: KernelMatrix, batch_size: int, num_atoms: torch.Tensor) -> torch.Tensor: 33 | vec_c = matrix.get_diag() 34 | batch_idxs = [torch.argmax(vec_c)] 35 | 36 | l_n = None 37 | image_idx = torch.arange( 38 | num_atoms.shape[0], 39 | device=num_atoms.device, 40 | ) 41 | image_idx = torch.repeat_interleave(image_idx, num_atoms) 42 | 43 | selected_idx = [] 44 | n = 0 45 | while len(selected_idx) < batch_size: 46 | opt_idx = batch_idxs[-1] 47 | l_n_T_l_n = 0.0 if l_n is None else torch.einsum('w,wc->c', l_n[:, opt_idx], l_n) 48 | mat_col = matrix.get_column(opt_idx) 49 | update = (1 / torch.sqrt(vec_c[opt_idx])) * (mat_col - l_n_T_l_n) 50 | vec_c = vec_c - update ** 2 51 | l_n = update.unsqueeze(0) if l_n is None else torch.concat((l_n, update.unsqueeze(0))) 52 | new_idx = torch.argmax(vec_c) 53 | if vec_c[new_idx] <= 1e-12 or new_idx in batch_idxs: 54 | break 55 | else: 56 | batch_idxs.append(new_idx) 57 | if image_idx[new_idx] not in selected_idx: 58 | selected_idx.append(image_idx[new_idx]) 59 | 60 | return torch.stack(selected_idx) 61 | 62 | def lcmd_greedy(matrix: KernelMatrix, batch_size: int, n_train: int) -> torch.Tensor: 63 | """ 64 | Only accept matrix with double dtype!!! 65 | Selects batch elements by greedily picking those with the maximum distance in the largest cluster, 66 | including training points. Assumes that the last ``n_train`` columns of ``matrix`` correspond to training points. 67 | 68 | :param matrix: Kernel matrix. 69 | :param batch_size: Size of the AL batch. 70 | :param n_train: Number of training structures. 71 | :return: Indices of the selected structures. 
72 | """ 73 | # assumes that the matrix contains pool samples, optionally followed by train samples 74 | n_pool = matrix.get_number_of_columns() - n_train 75 | sq_dists = matrix.get_diag() 76 | batch_idxs = [n_pool if n_train > 0 else torch.argmax(sq_dists)] 77 | closest_idxs = torch.zeros(n_pool, dtype=int, device=sq_dists.device) 78 | min_sq_dists = matrix.get_sq_dists(batch_idxs[-1])[:n_pool] 79 | 80 | for i in range(1, batch_size + n_train): 81 | if i < n_train: 82 | batch_idxs.append(n_pool+i) 83 | else: 84 | bincount = torch.bincount(closest_idxs, weights=min_sq_dists, minlength=i) 85 | max_bincount = torch.max(bincount) 86 | new_idx = torch.argmax(torch.where( 87 | torch.gather(bincount, 0, closest_idxs) == max_bincount, 88 | min_sq_dists, 89 | torch.zeros_like(min_sq_dists)-float("Inf"))) 90 | batch_idxs.append(new_idx) 91 | sq_dists = matrix.get_sq_dists(batch_idxs[-1])[:n_pool] 92 | new_min = sq_dists < min_sq_dists 93 | closest_idxs = torch.where(new_min, i, closest_idxs) 94 | min_sq_dists = torch.where(new_min, sq_dists, min_sq_dists) 95 | 96 | return torch.hstack(batch_idxs[n_train:]) 97 | 98 | def deterministic_CUR(matrix: KernelMatrix, batch_size: int, lambd: float=0.1, eposilon: float=1E-3) -> torch.Tensor: 99 | """ 100 | CUR matrix decomposition, the matrix must be normalized. 101 | """ 102 | n = matrix.num_columns 103 | W = torch.zeros(n, n) 104 | I = torch.eye(n, n) 105 | while True: 106 | W_t = W 107 | for i in range(matrix.num_columns): 108 | z = matrix.get_column(i) @ (I - W) + matrix.get_diag()[i] * W[i] 109 | coeff = 1 - lambd / torch.linalg.norm(z) 110 | W[i] = coeff * z if coeff > 0 else 0 * z 111 | if torch.linalg.norm(W - W_t) < eposilon: 112 | break 113 | 114 | return torch.argsort(torch.linalg.norm(W, dim=1))[-batch_size:] 115 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | #
PaiNN-model introduction
2 | This is a simple implementation of the [PaiNN](https://arxiv.org/abs/2102.03150) model together with an active-learning workflow for fitting interatomic potentials. 3 | The features or gradients learned by the model are used for active learning, and several selection methods are implemented. 4 | All of the active-learning code still needs to be tested. 5 | ##
Documentation
6 | No documentation yet. 7 | 8 | ##
Quick Start
9 |
10 | How to install 11 | 12 | This code has only been tested with [**Python>=3.8.0**](https://www.python.org/) and [**PyTorch>=1.10**](https://pytorch.org/get-started/locally/). 13 | Requirements: [PyTorch Scatter](https://github.com/rusty1s/pytorch_scatter) (only needed for active learning), 14 | [toml](https://toml.io/en/), and [myqueue](https://myqueue.readthedocs.io/en/latest/installation.html) (only needed for submitting jobs automatically). 15 | 16 | ```bash 17 | $ conda install pytorch-scatter -c pyg 18 | $ conda install -c conda-forge toml 19 | $ python3 -m pip install myqueue 20 | $ conda install pytorch torchvision torchaudio pytorch-cuda=11.6 -c pytorch -c nvidia 21 | $ git clone https://github.com/Yangxinsix/PaiNN-model.git 22 | $ cd PaiNN-model 23 | $ python -m pip install -U . 24 | ``` 25 | 26 |
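To verify the installation (assuming the commands above finished without errors), a quick sanity check could look like this; the hyperparameter values passed to `PainnModel` are arbitrary placeholders:

```python
# Minimal sanity check: the core dependencies import and (optionally) a GPU is visible.
import torch
import torch_scatter  # only required for the active-learning part
from PaiNN.model import PainnModel

print(torch.__version__, "CUDA available:", torch.cuda.is_available())
print(PainnModel(num_interactions=3, hidden_state_size=64, cutoff=5.0))
```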
27 | 28 |
29 | How to use 30 | 31 | * See `train.py` in `scripts` for training, and `md_run.py` for running MD simulations with ASE. 32 | * See `al_select.py` in `workflow` for active learning. 33 | * See `flow.py` in `workflow` for distributing and submitting active-learning jobs. 34 | 35 |
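Below are two minimal usage sketches. They are illustrative only: the checkpoint path, structure file, trajectory names, and model hyperparameters (`num_interactions`, `hidden_state_size`, `cutoff`) are placeholders and must match your own training run.

```python
import torch
from ase.io import read
from PaiNN.model import PainnModel
from PaiNN.calculator import MLCalculator

# Rebuild the network with the same hyperparameters it was trained with (placeholders here).
model = PainnModel(num_interactions=3, hidden_state_size=128, cutoff=5.0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Checkpoints written by scripts/train.py keep the weights under the "model" key.
state_dict = torch.load("model_outputs/best_model.pth", map_location=device)
model.load_state_dict(state_dict["model"])

# Use the trained model as an ASE calculator.
atoms = read("water_O2.cif")
atoms.calc = MLCalculator(model)
print(atoms.get_potential_energy())
print(atoms.get_forces())
```

Selecting new structures with the active-learning module follows the same pattern, reusing the `model` loaded above; a single trained model is enough for the feature-based kernels:

```python
from PaiNN.data import AseDataset
from PaiNN.active_learning import GeneralActiveLearning

datasets = {
    "pool": AseDataset("pool.traj", cutoff=5.0),    # unlabeled candidate structures
    "train": AseDataset("train.traj", cutoff=5.0),  # current training set
}

al = GeneralActiveLearning(kernel="ll-g", selection="max_det_greedy", n_random_features=0)
# Returns indices into the pool dataset of the structures to label next.
selected = al.select([model], datasets, batch_size=8, al_batch_size=10)
print(selected)
```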
36 | -------------------------------------------------------------------------------- /scripts/MD.traj: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nityasagarjena/PaiNN-model/6ee9a59c3cd544b5e31d4936cb1e75e9bded6a6e/scripts/MD.traj -------------------------------------------------------------------------------- /scripts/arguments.toml: -------------------------------------------------------------------------------- 1 | node_size = 40 2 | num_interactions = 5 3 | cutoff = 5.0 4 | split_file = "/home/energy/xinyang/work/active_learning_test/datasplits.json" 5 | output_dir = "model_output" 6 | dataset = "/home/energy/xinyang/work/active_learning_test/md17aspirin.traj" 7 | max_steps = 1000000 8 | device = "cuda" 9 | batch_size = 32 10 | initial_lr = 0.0001 11 | forces_weight = 0.99 12 | log_interval = 1000 13 | normalization = true 14 | atomwise_normalization = true 15 | stop_tolerance = 10 16 | -------------------------------------------------------------------------------- /scripts/gpu_info: -------------------------------------------------------------------------------- 1 | Mon Aug 1 11:30:57 2022 2 | +-----------------------------------------------------------------------------+ 3 | | NVIDIA-SMI 470.74 Driver Version: 470.74 CUDA Version: 11.4 | 4 | |-------------------------------+----------------------+----------------------+ 5 | | GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC | 6 | | Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. | 7 | | | | MIG M. | 8 | |===============================+======================+======================| 9 | | 0 NVIDIA GeForce ... On | 00000000:3F:00.0 Off | N/A | 10 | | 30% 39C P8 24W / 350W | 1MiB / 24268MiB | 0% Default | 11 | | | | N/A | 12 | +-------------------------------+----------------------+----------------------+ 13 | 14 | +-----------------------------------------------------------------------------+ 15 | | Processes: | 16 | | GPU GI CI PID Type Process name GPU Memory | 17 | | ID ID Usage | 18 | |=============================================================================| 19 | | No running processes found | 20 | +-----------------------------------------------------------------------------+ 21 | -------------------------------------------------------------------------------- /scripts/gpu_run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -ex 2 | 3 | #SBATCH --mail-user=xinyang@dtu.dk 4 | #SBATCH --mail-type=END,FAIL 5 | #SBATCH --partition=sm3090 6 | #SBATCH -N 1 # Minimum of 1 node 7 | #SBATCH -n 8 # 10 MPI processes per node 8 | #SBATCH --time=7-00:00:00 9 | #SBATCH --job=PaiNN-training 10 | #SBATCH --output=runner_output.log 11 | #SBATCH --gres=gpu:RTX3090:1 12 | 13 | #module load ASE/3.22.0-intel-2020b 14 | #module load Python/3.8.6-GCCcore-10.2.0 15 | 16 | export MKL_NUM_THREADS=1 17 | export NUMEXPR_NUM_THREADS=1 18 | export OMP_NUM_THREADS=1 19 | export OPENBLAS_NUM_THREADS=1 20 | 21 | nvidia-smi > gpu_info 22 | ulimit -s unlimited 23 | python3 md_run.py 24 | -------------------------------------------------------------------------------- /scripts/md_run.py: -------------------------------------------------------------------------------- 1 | from ase.md.langevin import Langevin 2 | from ase.calculators.plumed import Plumed 3 | from ase import units 4 | from ase.md.velocitydistribution import MaxwellBoltzmannDistribution 5 | from ase.io import read, write, Trajectory 6 | 7 | import 
numpy as np 8 | import torch 9 | import sys 10 | import glob 11 | 12 | from PaiNN.data import AseDataset, collate_atomsdata 13 | from PaiNN.model import PainnModel_predict 14 | from PaiNN.calculator import MLCalculator 15 | from ase.constraints import FixAtoms 16 | 17 | # load model 18 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 19 | # model_pth = glob.glob('/home/energy/xinyang/work/Au_MD/graphnn/ads_images/ensembles/*_layer/runs/model_outputs/best_model.pth') 20 | # # models = [] 21 | # for each in model_pth: 22 | # node_size = int(each.split('/')[-4].split('_')[0]) 23 | # num_inter = int(each.split('/')[-4].split('_')[2]) 24 | # model = PainnModel(num_interactions=num_inter, hidden_state_size=node_size, cutoff=5.0) 25 | # model.to(device) 26 | # state_dict = torch.load(each) 27 | # model.load_state_dict(state_dict["model"]) 28 | # models.append(model) 29 | # 30 | # encalc = EnsembleCalculator(models) 31 | 32 | # set md parameters 33 | #dataset="/home/energy/xinyang/work/Au_MD/training_loop/Au_larger/dataset_selector/dataset_repository/corrected_ads_images.traj" 34 | #images = read(dataset, ':') 35 | #indices = [i for i in range(len(images)) if images[i].info['system'] == '1OH'] 36 | #atoms = images[np.random.choice(indices)] 37 | #atoms = read('MD.traj', -1) 38 | #cons = FixAtoms(mask=atoms.positions[:, 2] < 5.9) 39 | #atoms.set_constraint(cons) 40 | 41 | model = PainnModel_predict(num_interactions=3, hidden_state_size=128, cutoff=5.0) 42 | model.to(device) 43 | state_dict = torch.load('/home/energy/xinyang/work/Au_MD/graphnn/pure_water/runs/model_outputs/best_model.pth') 44 | new_names = ["linear_1.weight", "linear_1.bias", "linear_2.weight", "linear_2.bias"] 45 | old_names = ["readout_mlp.0.weight", "readout_mlp.0.bias", "readout_mlp.2.weight", "readout_mlp.2.bias"] 46 | for old, new in zip(old_names, new_names): 47 | state_dict['model'][new] = state_dict['model'].pop(old) 48 | 49 | state_dict["model"]["U_in_0"] = torch.randn(128, 500) / 500 ** 0.5 50 | state_dict["model"]["U_out_1"] = torch.randn(128, 500) / 500 ** 0.5 51 | state_dict["model"]["U_in_1"] = torch.randn(128, 500) / 500 ** 0.5 52 | model.load_state_dict(state_dict["model"]) 53 | mlcalc = MLCalculator(model) 54 | 55 | atoms = read('water_O2.cif') 56 | atoms.calc = mlcalc 57 | atoms.get_potential_energy() 58 | 59 | #collect_traj = Trajectory('bad_struct.traj', 'a') 60 | steps = 0 61 | def printenergy(a=atoms): # store a reference to atoms in the definition. 
62 | """Function to print the potential, kinetic and total energy.""" 63 | epot = a.get_potential_energy() 64 | ekin = a.get_kinetic_energy() 65 | temp = ekin / (1.5 * units.kB) / a.get_global_number_of_atoms() 66 | global steps 67 | steps += 1 68 | with open('ensemble.log', 'a') as f: 69 | f.write( 70 | f"Steps={steps:12.3f} Epot={epot:12.3f} Ekin={ekin:12.3f} temperature={temp:8.2f}\n") 71 | 72 | #atoms.calc = encalc 73 | MaxwellBoltzmannDistribution(atoms, temperature_K=350) 74 | dyn = Langevin(atoms, 0.25 * units.fs, temperature_K=350, friction=0.1) 75 | dyn.attach(printenergy, interval=1) 76 | 77 | traj = Trajectory('MD.traj', 'w', atoms) 78 | dyn.attach(traj.write, interval=400) 79 | dyn.run(10000000) 80 | -------------------------------------------------------------------------------- /scripts/runner_output.log: -------------------------------------------------------------------------------- 1 | + '[' -z '' ']' 2 | + case "$-" in 3 | + __lmod_vx=x 4 | + '[' -n x ']' 5 | + set +x 6 | Shell debugging temporarily silenced: export LMOD_SH_DBG_ON=1 for this output (/usr/share/lmod/lmod/init/bash) 7 | Shell debugging restarted 8 | + unset __lmod_vx 9 | + export MKL_NUM_THREADS=1 10 | + MKL_NUM_THREADS=1 11 | + export NUMEXPR_NUM_THREADS=1 12 | + NUMEXPR_NUM_THREADS=1 13 | + export OMP_NUM_THREADS=1 14 | + OMP_NUM_THREADS=1 15 | + export OPENBLAS_NUM_THREADS=1 16 | + OPENBLAS_NUM_THREADS=1 17 | + nvidia-smi 18 | + ulimit -s unlimited 19 | + python3 md_run.py 20 | slurmstepd: error: *** JOB 5233695 ON s002 CANCELLED AT 2022-08-01T11:34:52 *** 21 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import json, os, sys, toml 4 | from pathlib import Path 5 | import argparse 6 | import logging 7 | import itertools 8 | import torch 9 | import time 10 | 11 | from PaiNN.data import AseDataset, collate_atomsdata 12 | from PaiNN.model import PainnModel 13 | 14 | def get_arguments(arg_list=None): 15 | parser = argparse.ArgumentParser( 16 | description="Train graph convolution network", fromfile_prefix_chars="+" 17 | ) 18 | parser.add_argument( 19 | "--load_model", 20 | type=str, 21 | help="Load model parameters from previous run", 22 | ) 23 | parser.add_argument( 24 | "--cutoff", 25 | type=float, 26 | help="Atomic interaction cutoff distance [�~E]", 27 | ) 28 | parser.add_argument( 29 | "--split_file", 30 | type=str, 31 | help="Train/test/validation split file json", 32 | ) 33 | parser.add_argument( 34 | "--num_interactions", 35 | type=int, 36 | help="Number of interaction layers used", 37 | ) 38 | parser.add_argument( 39 | "--node_size", type=int, help="Size of hidden node states" 40 | ) 41 | parser.add_argument( 42 | "--output_dir", 43 | type=str, 44 | help="Path to output directory", 45 | ) 46 | parser.add_argument( 47 | "--dataset", type=str, help="Path to ASE trajectory", 48 | ) 49 | parser.add_argument( 50 | "--max_steps", 51 | type=int, 52 | help="Maximum number of optimisation steps", 53 | ) 54 | parser.add_argument( 55 | "--device", 56 | type=str, 57 | help="Set which device to use for training e.g. 
'cuda' or 'cpu'", 58 | ) 59 | parser.add_argument( 60 | "--batch_size", type=int, help="Number of molecules per minibatch", 61 | ) 62 | parser.add_argument( 63 | "--initial_lr", type=float, help="Initial learning rate", 64 | ) 65 | parser.add_argument( 66 | "--forces_weight", 67 | type=float, 68 | help="Tradeoff between training on forces (weight=1) and energy (weight=0)", 69 | ) 70 | parser.add_argument( 71 | "--log_inverval", 72 | type=int, 73 | help="The interval of model evaluation", 74 | ) 75 | parser.add_argument( 76 | "--normalization", 77 | action="store_true", 78 | help="Enable normalization of the model", 79 | ) 80 | parser.add_argument( 81 | "--atomwise_normalization", 82 | action="store_true", 83 | help="Enable atomwise normalization", 84 | ) 85 | parser.add_argument( 86 | "--stop_tolerance", 87 | type=int, 88 | help="Stop training when validation loss is larger than best loss for 'stop_tolerance' steps", 89 | ) 90 | parser.add_argument( 91 | "--cfg", 92 | type=str, 93 | help="Path to config file. e.g. 'arguments.toml'" 94 | ) 95 | 96 | return parser.parse_args(arg_list) 97 | 98 | def split_data(dataset, args): 99 | # Load or generate splits 100 | if args.split_file: 101 | with open(args.split_file, "r") as fp: 102 | splits = json.load(fp) 103 | else: 104 | datalen = len(dataset) 105 | num_validation = int(math.ceil(datalen * 0.10)) 106 | indices = np.random.permutation(len(dataset)) 107 | splits = { 108 | "train": indices[num_validation:].tolist(), 109 | "validation": indices[:num_validation].tolist(), 110 | } 111 | 112 | # Save split file 113 | with open(os.path.join(args.output_dir, "datasplits.json"), "w") as f: 114 | json.dump(splits, f) 115 | 116 | # Split the dataset 117 | datasplits = {} 118 | for key, indices in splits.items(): 119 | datasplits[key] = torch.utils.data.Subset(dataset, indices) 120 | return datasplits 121 | 122 | def forces_criterion(predicted, target, reduction="mean"): 123 | # predicted, target are (bs, max_nodes, 3) tensors 124 | # node_count is (bs) tensor 125 | diff = predicted - target 126 | total_squared_norm = torch.linalg.norm(diff, dim=1) # bs 127 | if reduction == "mean": 128 | scalar = torch.mean(total_squared_norm) 129 | elif reduction == "sum": 130 | scalar = torch.sum(total_squared_norm) 131 | else: 132 | raise ValueError("Reduction must be 'mean' or 'sum'") 133 | return scalar 134 | 135 | def get_normalization(dataset, per_atom=True): 136 | # Use double precision to avoid overflows 137 | x_sum = torch.zeros(1, dtype=torch.double) 138 | x_2 = torch.zeros(1, dtype=torch.double) 139 | num_objects = 0 140 | for i, sample in enumerate(dataset): 141 | if i == 0: 142 | # Estimate "bias" from 1 sample 143 | # to avoid overflows for large valued datasets 144 | if per_atom: 145 | bias = sample["energy"] / sample["num_atoms"] 146 | else: 147 | bias = sample["energy"] 148 | x = sample["energy"] 149 | if per_atom: 150 | x = x / sample["num_atoms"] 151 | x -= bias 152 | x_sum += x 153 | x_2 += x ** 2.0 154 | num_objects += 1 155 | # Var(X) = E[X^2] - E[X]^2 156 | x_mean = x_sum / num_objects 157 | x_var = x_2 / num_objects - x_mean ** 2.0 158 | x_mean = x_mean + bias 159 | 160 | default_type = torch.get_default_dtype() 161 | 162 | return x_mean.type(default_type), torch.sqrt(x_var).type(default_type) 163 | 164 | def eval_model(model, dataloader, device, forces_weight): 165 | energy_running_ae = 0 166 | energy_running_se = 0 167 | 168 | forces_running_l2_ae = 0 169 | forces_running_l2_se = 0 170 | forces_running_c_ae = 0 171 | forces_running_c_se = 0 
172 | forces_running_loss = 0 173 | 174 | running_loss = 0 175 | count = 0 176 | forces_count = 0 177 | criterion = torch.nn.MSELoss() 178 | 179 | for batch in dataloader: 180 | device_batch = { 181 | k: v.to(device=device, non_blocking=True) for k, v in batch.items() 182 | } 183 | out = model(device_batch) 184 | 185 | # counts 186 | count += batch["energy"].shape[0] 187 | forces_count += batch['forces'].shape[0] 188 | 189 | # use mean square loss here 190 | forces_loss = forces_criterion(out["forces"], device_batch["forces"]).item() 191 | energy_loss = criterion(out["energy"], device_batch["energy"]).item() #problem here 192 | total_loss = forces_weight * forces_loss + (1 - forces_weight) * energy_loss 193 | running_loss += total_loss * batch["energy"].shape[0] 194 | 195 | # energy errors 196 | outputs = {key: val.detach().cpu().numpy() for key, val in out.items()} 197 | energy_targets = batch["energy"].detach().cpu().numpy() 198 | energy_running_ae += np.sum(np.abs(energy_targets - outputs["energy"]), axis=0) 199 | energy_running_se += np.sum( 200 | np.square(energy_targets - outputs["energy"]), axis=0 201 | ) 202 | 203 | # force errors 204 | forces_targets = batch["forces"].detach().cpu().numpy() 205 | forces_diff = forces_targets - outputs["forces"] 206 | forces_l2_norm = np.sqrt(np.sum(np.square(forces_diff), axis=1)) 207 | 208 | forces_running_c_ae += np.sum(np.abs(forces_diff)) 209 | forces_running_c_se += np.sum(np.square(forces_diff)) 210 | 211 | forces_running_l2_ae += np.sum(np.abs(forces_l2_norm)) 212 | forces_running_l2_se += np.sum(np.square(forces_l2_norm)) 213 | 214 | energy_mae = energy_running_ae / count 215 | energy_rmse = np.sqrt(energy_running_se / count) 216 | 217 | forces_l2_mae = forces_running_l2_ae / forces_count 218 | forces_l2_rmse = np.sqrt(forces_running_l2_se / forces_count) 219 | 220 | forces_c_mae = forces_running_c_ae / (forces_count * 3) 221 | forces_c_rmse = np.sqrt(forces_running_c_se / (forces_count * 3)) 222 | 223 | total_loss = running_loss / count 224 | 225 | evaluation = { 226 | "energy_mae": energy_mae, 227 | "energy_rmse": energy_rmse, 228 | "forces_l2_mae": forces_l2_mae, 229 | "forces_l2_rmse": forces_l2_rmse, 230 | "forces_mae": forces_c_mae, 231 | "forces_rmse": forces_c_rmse, 232 | "sqrt(total_loss)": np.sqrt(total_loss), 233 | } 234 | 235 | return evaluation 236 | 237 | def update_namespace(ns, d): 238 | for k, v in d.items(): 239 | if not ns.__dict__.get(k): 240 | ns.__dict__[k] = v 241 | 242 | class EarlyStopping(): 243 | def __init__(self, tolerance=5, min_delta=0): 244 | 245 | self.tolerance = tolerance 246 | self.min_delta = min_delta 247 | self.counter = 0 248 | self.early_stop = False 249 | 250 | def __call__(self, val_loss, best_loss): 251 | if best_loss < 1.0 and (val_loss - best_loss) > self.min_delta: 252 | self.counter +=1 253 | if self.counter >= self.tolerance: 254 | self.early_stop = True 255 | 256 | return self.early_stop 257 | 258 | def main(): 259 | args = get_arguments() 260 | if args.cfg: 261 | with open(args.cfg, 'r') as f: 262 | params = toml.load(f) 263 | update_namespace(args, params) 264 | 265 | # Setup logging 266 | os.makedirs(args.output_dir, exist_ok=True) 267 | logging.basicConfig( 268 | level=logging.DEBUG, 269 | format="%(asctime)s [%(levelname)-5.5s] %(message)s", 270 | handlers=[ 271 | logging.FileHandler( 272 | os.path.join(args.output_dir, "printlog.txt"), mode="w" 273 | ), 274 | logging.StreamHandler(), 275 | ], 276 | ) 277 | 278 | # Save command line args 279 | with open(os.path.join(args.output_dir, 
"commandline_args.txt"), "w") as f: 280 | f.write("\n".join(sys.argv[1:])) 281 | # Save parsed command line arguments 282 | with open(os.path.join(args.output_dir, "arguments.json"), "w") as f: 283 | json.dump(vars(args), f) 284 | 285 | # Create device 286 | device = torch.device(args.device) 287 | # Put a tensor on the device before loading data 288 | # This way the GPU appears to be in use when other users run gpustat 289 | torch.tensor([0], device=device) 290 | 291 | # Setup dataset and loader 292 | logging.info("loading data %s", args.dataset) 293 | dataset = AseDataset( 294 | args.dataset, 295 | cutoff = args.cutoff, 296 | ) 297 | 298 | with open(args.split_file, 'r') as f: 299 | splits = json.load(f) 300 | 301 | datasplits = { 302 | 'train': torch.utils.data.Subset(dataset, splits['train']), 303 | 'validation': torch.utils.data.Subset(dataset, splits['validation']), 304 | } 305 | 306 | train_loader = torch.utils.data.DataLoader( 307 | datasplits["train"], 308 | args.batch_size, 309 | sampler=torch.utils.data.RandomSampler(datasplits["train"]), 310 | collate_fn=collate_atomsdata, 311 | ) 312 | val_loader = torch.utils.data.DataLoader( 313 | datasplits["validation"], 314 | args.batch_size, 315 | collate_fn=collate_atomsdata, 316 | ) 317 | 318 | logging.info("Computing mean and variance") 319 | target_mean, target_stddev = get_normalization( 320 | datasplits["train"], 321 | per_atom=args.atomwise_normalization, 322 | ) 323 | logging.debug("target_mean=%f, target_stddev=%f" % (target_mean, target_stddev)) 324 | 325 | net = PainnModel( 326 | num_interactions=args.num_interactions, 327 | hidden_state_size=args.node_size, 328 | cutoff=args.cutoff, 329 | normalization=args.normalization, 330 | target_mean=target_mean.tolist(), 331 | target_stddev=target_stddev.tolist(), 332 | atomwise_normalization=args.atomwise_normalization, 333 | ) 334 | net.to(device) 335 | 336 | optimizer = torch.optim.Adam(net.parameters(), lr=args.initial_lr) 337 | criterion = torch.nn.MSELoss() 338 | scheduler_fn = lambda step: 0.96 ** (step / 100000) 339 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_fn) 340 | early_stop = EarlyStopping(tolerance=args.stop_tolerance) 341 | 342 | running_loss = 0 343 | running_loss_count = 0 344 | best_val_loss = np.inf 345 | step = 0 346 | training_time = 0 347 | 348 | if args.load_model: 349 | state_dict = torch.load(args.load_model) 350 | net.load_state_dict(state_dict["model"]) 351 | # step = state_dict["step"] 352 | best_val_loss = state_dict["best_val_loss"] 353 | # optimizer.load_state_dict(state_dict["optimizer"]) 354 | # scheduler.load_state_dict(state_dict["scheduler"]) 355 | 356 | for epoch in itertools.count(): 357 | for batch_host in train_loader: 358 | start = time.time() 359 | # Transfer to 'device' 360 | batch = { 361 | k: v.to(device=device, non_blocking=True) 362 | for (k, v) in batch_host.items() 363 | } 364 | # Reset gradient 365 | optimizer.zero_grad() 366 | 367 | # Forward, backward and optimize 368 | outputs = net( 369 | batch, compute_forces=bool(args.forces_weight) 370 | ) 371 | energy_loss = criterion(outputs["energy"], batch["energy"]) 372 | if args.forces_weight: 373 | forces_loss = forces_criterion(outputs['forces'], batch['forces']) 374 | else: 375 | forces_loss = 0.0 376 | total_loss = ( 377 | args.forces_weight * forces_loss 378 | + (1 - args.forces_weight) * energy_loss 379 | ) 380 | total_loss.backward() 381 | optimizer.step() 382 | running_loss += total_loss.item() * batch["energy"].shape[0] 383 | running_loss_count += 
batch["energy"].shape[0] 384 | training_time += time.time() - start 385 | 386 | # print(step, loss_value) 387 | # Validate and save model 388 | if (step % args.log_interval == 0) or ((step + 1) == args.max_steps): 389 | eval_start = time.time() 390 | train_loss = running_loss / running_loss_count 391 | running_loss = running_loss_count = 0 392 | 393 | eval_dict = eval_model(net, val_loader, device, args.forces_weight) 394 | eval_formatted = ", ".join( 395 | ["%s=%g" % (k, v) for (k, v) in eval_dict.items()] 396 | ) 397 | 398 | logging.info( 399 | "step=%d, %s, sqrt(train_loss)=%g, max memory used=%g, training time=%g min, eval time=%g min", 400 | step, 401 | eval_formatted, 402 | math.sqrt(train_loss), 403 | torch.cuda.max_memory_allocated() / 2**20, 404 | training_time / 60, 405 | (time.time() - eval_start) / 60 406 | ) 407 | training_time = 0 408 | # Save checkpoint 409 | if not early_stop(eval_dict["sqrt(total_loss)"], best_val_loss): 410 | best_val_loss = eval_dict["sqrt(total_loss)"] 411 | torch.save( 412 | { 413 | "model": net.state_dict(), 414 | "optimizer": optimizer.state_dict(), 415 | "scheduler": scheduler.state_dict(), 416 | "step": step, 417 | "best_val_loss": best_val_loss, 418 | "node_size": args.node_size, 419 | "num_layer": args.num_interactions, 420 | "cutoff": args.cutoff, 421 | }, 422 | os.path.join(args.output_dir, "best_model.pth"), 423 | ) 424 | else: 425 | sys.exit(0) 426 | 427 | step += 1 428 | 429 | scheduler.step() 430 | 431 | if step >= args.max_steps: 432 | logging.info("Max steps reached, exiting") 433 | torch.save( 434 | { 435 | "model": net.state_dict(), 436 | "optimizer": optimizer.state_dict(), 437 | "scheduler": scheduler.state_dict(), 438 | "step": step, 439 | "best_val_loss": best_val_loss, 440 | "node_size": args.node_size, 441 | "num_layer": args.num_interactions, 442 | "cutoff": args.cutoff, 443 | }, 444 | os.path.join(args.output_dir, "exit_model.pth"), 445 | ) 446 | sys.exit(0) 447 | 448 | if __name__ == "__main__": 449 | main() 450 | -------------------------------------------------------------------------------- /scripts/water_O2.cif: -------------------------------------------------------------------------------- 1 | 2 | #====================================================================== 3 | # CRYSTAL DATA 4 | #---------------------------------------------------------------------- 5 | data_VESTA_phase_1 6 | 7 | _chemical_name_common '' 8 | _cell_length_a 12.600000 9 | _cell_length_b 12.600000 10 | _cell_length_c 12.600000 11 | _cell_angle_alpha 90.000000 12 | _cell_angle_beta 90.000000 13 | _cell_angle_gamma 90.000000 14 | _cell_volume 2000.376182 15 | _space_group_name_H-M_alt 'P 1' 16 | _space_group_IT_number 1 17 | 18 | loop_ 19 | _space_group_symop_operation_xyz 20 | 'x, y, z' 21 | 22 | loop_ 23 | _atom_site_label 24 | _atom_site_occupancy 25 | _atom_site_fract_x 26 | _atom_site_fract_y 27 | _atom_site_fract_z 28 | _atom_site_adp_type 29 | _atom_site_B_iso_or_equiv 30 | _atom_site_type_symbol 31 | O1 1.0 0.291790 0.511480 0.379270 Biso 1.000000 O 32 | O2 1.0 0.303320 0.251860 0.188060 Biso 1.000000 O 33 | O3 1.0 0.091950 0.077880 0.623230 Biso 1.000000 O 34 | O4 1.0 0.741690 0.283460 0.638230 Biso 1.000000 O 35 | O5 1.0 0.608390 0.865710 0.971920 Biso 1.000000 O 36 | O6 1.0 0.147250 0.738240 0.186290 Biso 1.000000 O 37 | O7 1.0 0.820100 0.652800 0.148430 Biso 1.000000 O 38 | O8 1.0 0.967270 0.822160 0.868330 Biso 1.000000 O 39 | O9 1.0 0.854540 0.474700 0.657050 Biso 1.000000 O 40 | O10 1.0 0.598820 0.804170 0.766680 Biso 1.000000 O 41 | 
O11 1.0 0.056380 0.392800 0.220790 Biso 1.000000 O 42 | O12 1.0 0.398640 0.670010 0.680850 Biso 1.000000 O 43 | O13 1.0 0.434050 0.759620 0.028000 Biso 1.000000 O 44 | O14 1.0 0.084340 0.659930 0.006530 Biso 1.000000 O 45 | O15 1.0 0.266970 0.136120 0.368690 Biso 1.000000 O 46 | O16 1.0 0.800830 0.861400 0.111570 Biso 1.000000 O 47 | O17 1.0 0.650660 0.990640 0.439370 Biso 1.000000 O 48 | O18 1.0 0.277250 0.881160 0.447140 Biso 1.000000 O 49 | O19 1.0 0.545150 0.745470 0.516140 Biso 1.000000 O 50 | O20 1.0 0.667440 0.200890 0.252520 Biso 1.000000 O 51 | O21 1.0 0.745940 0.616710 0.802360 Biso 1.000000 O 52 | O22 1.0 0.344900 0.643820 0.219180 Biso 1.000000 O 53 | O23 1.0 0.544780 0.357790 0.630150 Biso 1.000000 O 54 | O24 1.0 0.622730 0.994000 0.672940 Biso 1.000000 O 55 | O25 1.0 0.198100 0.690050 0.505930 Biso 1.000000 O 56 | O26 1.0 0.350940 0.438600 0.087820 Biso 1.000000 O 57 | O27 1.0 0.423310 0.027640 0.497210 Biso 1.000000 O 58 | O28 1.0 0.732950 0.325120 0.968620 Biso 1.000000 O 59 | O29 1.0 0.852830 0.610750 0.364470 Biso 1.000000 O 60 | O30 1.0 0.900960 0.738090 0.641740 Biso 1.000000 O 61 | O31 1.0 0.827740 0.203920 0.121470 Biso 1.000000 O 62 | O32 1.0 0.602460 0.419800 0.253630 Biso 1.000000 O 63 | O33 1.0 0.608760 0.508200 0.911240 Biso 1.000000 O 64 | O34 1.0 0.576780 0.125490 0.971200 Biso 1.000000 O 65 | O35 1.0 0.434780 0.464980 0.763000 Biso 1.000000 O 66 | O36 1.0 0.840880 0.199590 0.804530 Biso 1.000000 O 67 | O37 1.0 0.026680 0.375750 0.512530 Biso 1.000000 O 68 | O38 1.0 0.288210 0.006550 0.703740 Biso 1.000000 O 69 | O39 1.0 0.778590 0.157910 0.476440 Biso 1.000000 O 70 | O40 1.0 0.193300 0.330340 0.393520 Biso 1.000000 O 71 | O41 1.0 0.415060 0.180510 0.843790 Biso 1.000000 O 72 | O42 1.0 0.958050 0.505370 0.070110 Biso 1.000000 O 73 | O43 1.0 0.984680 0.106340 0.444620 Biso 1.000000 O 74 | O44 1.0 0.120370 0.428310 0.787520 Biso 1.000000 O 75 | O45 1.0 0.146340 0.923030 0.072790 Biso 1.000000 O 76 | O46 1.0 0.495170 0.514270 0.446370 Biso 1.000000 O 77 | O47 1.0 0.650940 0.915300 0.237120 Biso 1.000000 O 78 | O48 1.0 0.127520 0.990080 0.302090 Biso 1.000000 O 79 | O49 1.0 0.052490 0.224410 0.750720 Biso 1.000000 O 80 | O50 1.0 0.180310 0.870000 0.848560 Biso 1.000000 O 81 | O51 1.0 0.618490 0.589540 0.127830 Biso 1.000000 O 82 | O52 1.0 0.943120 0.398110 0.908100 Biso 1.000000 O 83 | O53 1.0 0.017050 0.735580 0.380950 Biso 1.000000 O 84 | O54 1.0 0.731110 0.646570 0.539590 Biso 1.000000 O 85 | O55 1.0 0.091030 0.632000 0.680020 Biso 1.000000 O 86 | O56 1.0 0.319380 0.378200 0.902550 Biso 1.000000 O 87 | O57 1.0 0.455120 0.793090 0.303840 Biso 1.000000 O 88 | O58 1.0 0.059320 0.192850 0.164590 Biso 1.000000 O 89 | O59 1.0 0.022450 0.889690 0.537430 Biso 1.000000 O 90 | O60 1.0 0.155680 0.215120 0.949720 Biso 1.000000 O 91 | O61 1.0 0.512920 0.070650 0.172800 Biso 1.000000 O 92 | O62 1.0 0.921120 0.019420 0.019760 Biso 1.000000 O 93 | O63 1.0 0.298480 0.712610 0.879390 Biso 1.000000 O 94 | O64 1.0 0.322320 0.050500 0.051740 Biso 1.000000 O 95 | O65 1.0 0.855430 0.395050 0.315430 Biso 1.000000 O 96 | O66 1.0 0.841040 0.980250 0.802780 Biso 1.000000 O 97 | O67 1.0 0.484000 0.157760 0.649970 Biso 1.000000 O 98 | H1 1.0 0.305630 0.555150 0.313630 Biso 1.000000 H 99 | H2 1.0 0.265060 0.559620 0.435020 Biso 1.000000 H 100 | H5 1.0 0.159910 0.033030 0.646240 Biso 1.000000 H 101 | H6 1.0 0.090930 0.135550 0.675230 Biso 1.000000 H 102 | H7 1.0 0.664560 0.304930 0.634650 Biso 1.000000 H 103 | H8 1.0 0.769850 0.245880 0.571220 Biso 1.000000 H 104 | H9 1.0 0.665830 0.839030 
0.015530 Biso 1.000000 H 105 | H10 1.0 0.541340 0.840930 0.007500 Biso 1.000000 H 106 | H11 1.0 0.115370 0.700170 0.120160 Biso 1.000000 H 107 | H12 1.0 0.090230 0.745890 0.236890 Biso 1.000000 H 108 | H13 1.0 0.871000 0.602900 0.119030 Biso 1.000000 H 109 | H14 1.0 0.825620 0.646880 0.223650 Biso 1.000000 H 110 | H15 1.0 0.927680 0.780220 0.817450 Biso 1.000000 H 111 | H16 1.0 0.923300 0.888520 0.859390 Biso 1.000000 H 112 | H17 1.0 0.851710 0.524070 0.598260 Biso 1.000000 H 113 | H18 1.0 0.810520 0.412560 0.630660 Biso 1.000000 H 114 | H19 1.0 0.597890 0.811480 0.843280 Biso 1.000000 H 115 | H20 1.0 0.541050 0.755430 0.749370 Biso 1.000000 H 116 | H21 1.0 0.103310 0.410790 0.284950 Biso 1.000000 H 117 | H22 1.0 0.091620 0.325780 0.194580 Biso 1.000000 H 118 | H23 1.0 0.415760 0.589040 0.699600 Biso 1.000000 H 119 | H24 1.0 0.368830 0.696890 0.746910 Biso 1.000000 H 120 | H25 1.0 0.380910 0.740830 0.971520 Biso 1.000000 H 121 | H26 1.0 0.400620 0.760870 0.095660 Biso 1.000000 H 122 | H27 1.0 0.035060 0.598670 0.014540 Biso 1.000000 H 123 | H28 1.0 0.044070 0.717920 0.975760 Biso 1.000000 H 124 | H29 1.0 0.326880 0.117070 0.417390 Biso 1.000000 H 125 | H30 1.0 0.234630 0.196250 0.402590 Biso 1.000000 H 126 | H31 1.0 0.818480 0.785990 0.119840 Biso 1.000000 H 127 | H32 1.0 0.871750 0.889360 0.089770 Biso 1.000000 H 128 | H33 1.0 0.645350 0.957710 0.512430 Biso 1.000000 H 129 | H34 1.0 0.715280 0.034770 0.459660 Biso 1.000000 H 130 | H35 1.0 0.239840 0.919190 0.386110 Biso 1.000000 H 131 | H36 1.0 0.318320 0.941910 0.481050 Biso 1.000000 H 132 | H37 1.0 0.493650 0.748570 0.577710 Biso 1.000000 H 133 | H38 1.0 0.501030 0.766000 0.453700 Biso 1.000000 H 134 | H39 1.0 0.633180 0.269970 0.250600 Biso 1.000000 H 135 | H40 1.0 0.611980 0.143660 0.254440 Biso 1.000000 H 136 | H41 1.0 0.789430 0.575970 0.750640 Biso 1.000000 H 137 | H42 1.0 0.706960 0.663370 0.754550 Biso 1.000000 H 138 | H43 1.0 0.273770 0.681120 0.217190 Biso 1.000000 H 139 | H44 1.0 0.388890 0.698170 0.253430 Biso 1.000000 H 140 | H45 1.0 0.534910 0.401820 0.563230 Biso 1.000000 H 141 | H46 1.0 0.515780 0.409570 0.687880 Biso 1.000000 H 142 | H47 1.0 0.612070 0.932720 0.715520 Biso 1.000000 H 143 | H48 1.0 0.579800 0.055860 0.692010 Biso 1.000000 H 144 | H49 1.0 0.139660 0.712820 0.455740 Biso 1.000000 H 145 | H50 1.0 0.253990 0.738850 0.489590 Biso 1.000000 H 146 | H51 1.0 0.350320 0.507560 0.127920 Biso 1.000000 H 147 | H52 1.0 0.344750 0.379870 0.137830 Biso 1.000000 H 148 | H53 1.0 0.422960 0.074500 0.564730 Biso 1.000000 H 149 | H54 1.0 0.499190 0.023320 0.478950 Biso 1.000000 H 150 | H55 1.0 0.676310 0.273130 0.971090 Biso 1.000000 H 151 | H56 1.0 0.685370 0.393710 0.947910 Biso 1.000000 H 152 | H57 1.0 0.862750 0.534080 0.369810 Biso 1.000000 H 153 | H58 1.0 0.791300 0.622680 0.412600 Biso 1.000000 H 154 | H59 1.0 0.962010 0.687810 0.657170 Biso 1.000000 H 155 | H60 1.0 0.940900 0.792160 0.595790 Biso 1.000000 H 156 | H61 1.0 0.812020 0.264670 0.077760 Biso 1.000000 H 157 | H62 1.0 0.769340 0.204280 0.176850 Biso 1.000000 H 158 | H63 1.0 0.590790 0.475730 0.205570 Biso 1.000000 H 159 | H64 1.0 0.583260 0.449710 0.320980 Biso 1.000000 H 160 | H65 1.0 0.596040 0.539420 0.981970 Biso 1.000000 H 161 | H66 1.0 0.669780 0.544210 0.878230 Biso 1.000000 H 162 | H67 1.0 0.605900 0.064060 0.940180 Biso 1.000000 H 163 | H68 1.0 0.540240 0.103530 0.037930 Biso 1.000000 H 164 | H71 1.0 0.799400 0.236820 0.747820 Biso 1.000000 H 165 | H72 1.0 0.825640 0.235120 0.872570 Biso 1.000000 H 166 | H73 1.0 0.977920 0.363450 0.455480 Biso 
1.000000 H 167 | H74 1.0 0.992270 0.366230 0.579600 Biso 1.000000 H 168 | H75 1.0 0.344470 0.031140 0.753450 Biso 1.000000 H 169 | H76 1.0 0.247830 0.948740 0.737190 Biso 1.000000 H 170 | H77 1.0 0.749910 0.171880 0.408440 Biso 1.000000 H 171 | H78 1.0 0.858000 0.145670 0.458050 Biso 1.000000 H 172 | H79 1.0 0.137400 0.329690 0.453420 Biso 1.000000 H 173 | H80 1.0 0.236050 0.394380 0.403930 Biso 1.000000 H 174 | H81 1.0 0.375730 0.243730 0.863210 Biso 1.000000 H 175 | H82 1.0 0.472460 0.169440 0.899690 Biso 1.000000 H 176 | H83 1.0 0.995180 0.470860 0.132060 Biso 1.000000 H 177 | H84 1.0 0.942550 0.449830 0.007970 Biso 1.000000 H 178 | H85 1.0 0.038170 0.106620 0.506000 Biso 1.000000 H 179 | H86 1.0 0.980220 0.032830 0.423360 Biso 1.000000 H 180 | H87 1.0 0.089730 0.361760 0.750200 Biso 1.000000 H 181 | H88 1.0 0.104490 0.477370 0.731510 Biso 1.000000 H 182 | H89 1.0 0.074070 0.932140 0.045200 Biso 1.000000 H 183 | H90 1.0 0.155060 0.851060 0.108180 Biso 1.000000 H 184 | H91 1.0 0.516470 0.586040 0.461250 Biso 1.000000 H 185 | H92 1.0 0.420520 0.515050 0.423790 Biso 1.000000 H 186 | H93 1.0 0.710200 0.899950 0.186010 Biso 1.000000 H 187 | H94 1.0 0.685760 0.926470 0.306020 Biso 1.000000 H 188 | H95 1.0 0.166840 0.059400 0.320700 Biso 1.000000 H 189 | H96 1.0 0.121500 0.983360 0.221300 Biso 1.000000 H 190 | H97 1.0 0.085470 0.231440 0.824430 Biso 1.000000 H 191 | H98 1.0 0.976310 0.199760 0.770130 Biso 1.000000 H 192 | H99 1.0 0.175900 0.906220 0.921080 Biso 1.000000 H 193 | H100 1.0 0.099450 0.856470 0.830380 Biso 1.000000 H 194 | H101 1.0 0.687180 0.627590 0.140360 Biso 1.000000 H 195 | H102 1.0 0.575160 0.649540 0.104550 Biso 1.000000 H 196 | H103 1.0 0.881050 0.422700 0.868250 Biso 1.000000 H 197 | H104 1.0 0.006450 0.418360 0.861550 Biso 1.000000 H 198 | H105 1.0 0.949340 0.693880 0.376620 Biso 1.000000 H 199 | H106 1.0 0.010580 0.808450 0.405310 Biso 1.000000 H 200 | H107 1.0 0.663720 0.685850 0.529150 Biso 1.000000 H 201 | H108 1.0 0.778590 0.701170 0.581370 Biso 1.000000 H 202 | H109 1.0 0.141280 0.646600 0.620090 Biso 1.000000 H 203 | H110 1.0 0.120100 0.679760 0.734890 Biso 1.000000 H 204 | H111 1.0 0.322660 0.404990 0.973970 Biso 1.000000 H 205 | H112 1.0 0.262670 0.420990 0.869420 Biso 1.000000 H 206 | H113 1.0 0.518750 0.833560 0.273900 Biso 1.000000 H 207 | H114 1.0 0.412880 0.843550 0.349950 Biso 1.000000 H 208 | H115 1.0 0.983010 0.204750 0.169990 Biso 1.000000 H 209 | H116 1.0 0.060770 0.136380 0.212390 Biso 1.000000 H 210 | H117 1.0 0.098320 0.897960 0.517210 Biso 1.000000 H 211 | H118 1.0 0.011430 0.961150 0.561150 Biso 1.000000 H 212 | H119 1.0 0.124320 0.216990 0.018450 Biso 1.000000 H 213 | H120 1.0 0.217590 0.260980 0.954920 Biso 1.000000 H 214 | H121 1.0 0.562630 0.006990 0.185000 Biso 1.000000 H 215 | H122 1.0 0.440790 0.048660 0.158590 Biso 1.000000 H 216 | H123 1.0 0.893550 0.077180 0.065530 Biso 1.000000 H 217 | H124 1.0 0.872330 0.005080 0.954480 Biso 1.000000 H 218 | H125 1.0 0.242900 0.772220 0.871600 Biso 1.000000 H 219 | H126 1.0 0.265120 0.653330 0.913550 Biso 1.000000 H 220 | H127 1.0 0.246130 0.035070 0.060400 Biso 1.000000 H 221 | H128 1.0 0.331580 0.078970 0.977530 Biso 1.000000 H 222 | H129 1.0 0.789380 0.388080 0.275210 Biso 1.000000 H 223 | H130 1.0 0.909150 0.380550 0.264760 Biso 1.000000 H 224 | H131 1.0 0.826460 0.060160 0.807490 Biso 1.000000 H 225 | H132 1.0 0.775010 0.949820 0.768010 Biso 1.000000 H 226 | H133 1.0 0.472880 0.171420 0.726340 Biso 1.000000 H 227 | H134 1.0 0.498630 0.231200 0.622540 Biso 1.000000 H 228 | O35 1.0 0.544780 0.504980 
0.763000 Biso 1.000000 O 229 | O2 1.0 0.223320 0.171860 0.168060 Biso 1.000000 O 230 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name="PaiNN", 5 | version="1.0.0", 6 | description="Library for implementation of message passing neural networks in Pytorch", 7 | author="xinyang", 8 | author_email="xinyang@dtu.dk", 9 | url = "https://github.com/Yangxinsix/PaiNN-model", 10 | packages=["PaiNN"], 11 | ) 12 | -------------------------------------------------------------------------------- /workflow/al_select.py: -------------------------------------------------------------------------------- 1 | from PaiNN.data import AseDataset, collate_atomsdata 2 | from PaiNN.model import PainnModel 3 | import torch 4 | import numpy as np 5 | from PaiNN.active_learning import GeneralActiveLearning 6 | import math 7 | import glob 8 | import json 9 | import argparse, toml 10 | from pathlib import Path 11 | from ase.io import read, write, Trajectory 12 | 13 | def setup_seed(seed): 14 | torch.manual_seed(seed) 15 | if torch.cuda.is_available(): 16 | torch.cuda.manual_seed_all(seed) 17 | np.random.seed(seed) 18 | torch.backends.cudnn.deterministic = True 19 | 20 | def get_arguments(arg_list=None): 21 | parser = argparse.ArgumentParser( 22 | description="General Active Learning", fromfile_prefix_chars="+" 23 | ) 24 | parser.add_argument( 25 | "--kernel", 26 | type=str, 27 | help="How to get features", 28 | ) 29 | parser.add_argument( 30 | "--selection", 31 | type=str, 32 | help="Selection method, one of `max_dist_greedy`, `deterministic_CUR`, `lcmd_greedy`, `max_det_greedy` or `max_diag`", 33 | ) 34 | parser.add_argument( 35 | "--n_random_features", 36 | type=int, 37 | help="If `n_random_features = 0`, do not use random projections.", 38 | ) 39 | parser.add_argument( 40 | "--batch_size", 41 | type=int, 42 | help="How many data points should be selected", 43 | ) 44 | parser.add_argument( 45 | "--load_model", 46 | type=str, 47 | help="Where to find the models", 48 | ) 49 | parser.add_argument( 50 | "--dataset", type=str, help="Path to ASE trajectory", 51 | ) 52 | parser.add_argument( 53 | "--split_file", 54 | type=str, 55 | help="Train/test/validation split file json", 56 | ) 57 | parser.add_argument( 58 | "--pool_set", type=str, help="Path to MD trajectory obtained from machine learning potential", 59 | ) 60 | parser.add_argument( 61 | "--training_set", type=str, help="Path to training set. Useful for pool/train based selection method", 62 | ) 63 | parser.add_argument( 64 | "--device", 65 | type=str, 66 | help="Set which device to use for training e.g. 'cuda' or 'cpu'", 67 | ) 68 | parser.add_argument( 69 | "--random_seed", 70 | type=int, 71 | help="Random seed for this run", 72 | ) 73 | parser.add_argument( 74 | "--cfg", 75 | type=str, 76 | default="arguments.toml", 77 | help="Path to config file. e.g. 
'arguments.toml'" 78 | ) 79 | 80 | return parser.parse_args(arg_list) 81 | 82 | def update_namespace(ns, d): 83 | for k, v in d.items(): 84 | if not ns.__dict__.get(k): 85 | ns.__dict__[k] = v 86 | 87 | def main(): 88 | args = get_arguments() 89 | if args.cfg: 90 | with open(args.cfg, 'r') as f: 91 | params = toml.load(f) 92 | update_namespace(args, params) 93 | 94 | setup_seed(args.random_seed) 95 | 96 | # Load models 97 | model_pth = Path(args.load_model).rglob('*best_model.pth') 98 | models = [] 99 | for each in model_pth: 100 | state_dict = torch.load(each) 101 | model = PainnModel( 102 | num_interactions=state_dict["num_layer"], 103 | hidden_state_size=state_dict["node_size"], 104 | cutoff=state_dict["cutoff"], 105 | ) 106 | model.to(args.device) 107 | model.load_state_dict(state_dict["model"]) 108 | models.append(model) 109 | 110 | # Load dataset 111 | if args.dataset: 112 | with open(args.split_file, 'r') as f: 113 | datasplits = json.load(f) 114 | 115 | dataset = AseDataset(args.dataset, cutoff=models[0].cutoff) 116 | data_dict = { 117 | 'pool': torch.utils.data.Subset(dataset, datasplits['pool']), 118 | 'train': torch.utils.data.Subset(dataset, datasplits['train']), 119 | } 120 | elif args.pool_set and args.train_set: 121 | if isinstance(args.pool_set, list): 122 | dataset = [] 123 | for traj in args.pool_set: 124 | if Path(traj).stat().st_size > 0: 125 | dataset += read(traj, ':') 126 | else: 127 | dataset = args.pool_set 128 | data_dict = { 129 | 'pool': AseDataset(dataset, cutoff=models[0].cutoff), 130 | 'train': AseDataset(args.train_set, cutoff=models[0].cutoff), 131 | } 132 | else: 133 | raise RuntimeError("Please give valid pool data set for selection!") 134 | 135 | # raise error if the pool dataset is not large enough 136 | if len(data_dict['pool']) < args.batch_size * 5: 137 | raise RuntimeError(f"""The pool data set is not large enough for selection! 138 | It should be larger than 10 times batch size ({args.batch_size*10}). 139 | Check you MD simulation!""") 140 | 141 | # Select structures 142 | al = GeneralActiveLearning( 143 | kernel=args.kernel, 144 | selection=args.selection, 145 | n_random_features=args.n_random_features, 146 | ) 147 | indices = al.select(models, data_dict, al_batch_size=args.batch_size) 148 | al_idx = [datasplits['pool'][i] for i in indices] if args.dataset else indices 149 | al_info = { 150 | 'kernel': args.kernel, 151 | 'selection': args.selection, 152 | 'dataset': args.dataset if args.dataset else args.pool_set, 153 | 'selected': al_idx, 154 | } 155 | 156 | with open('selected.json', 'w') as f: 157 | json.dump(al_info, f) 158 | 159 | # Update new data splits 160 | if args.dataset: 161 | pool_idx = np.delete(datasplits['pool'], indices) 162 | datasplits['pool'] = pool_idx.tolist() 163 | datasplits['train'] += al_idx 164 | with open(args.split_file, 'w') as f: 165 | json.dump(datasplits, f) 166 | 167 | if __name__ == "__main__": 168 | main() 169 | -------------------------------------------------------------------------------- /workflow/config.toml: -------------------------------------------------------------------------------- 1 | # An example configuration file for active learning workflow 2 | 3 | [global] 4 | root = '.' 5 | random_seed = 3407 6 | 7 | [train] 8 | # Hyperparameters for PaiNN 9 | # load_model = '.' 
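# Note: keys in this section are merged into the argparse namespace of train.py via
# update_namespace(), which only fills values that are still unset, so flags passed on
# the command line generally take precedence over this file.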
10 | cutoff = 5.0
11 | # split_file = 'datasplits.json'
12 | val_ratio = 0.1
13 | num_interactions = 3
14 | node_size = 64
15 | output_dir = 'model_output'
16 | dataset = '/home/scratch3/xinyang/Au-facets/111_110.traj'
17 | max_steps = 1000000
18 | device = 'cuda'
19 | batch_size = 12
20 | initial_lr = 0.0001
21 | forces_weight = 0.98
22 | log_interval = 2000
23 | normalization = false
24 | atomwise_normalization = false
25 | stop_patience = 20
26 | plateau_scheduler = true # use a ReduceLROnPlateau scheduler to decrease the lr when learning plateaus
27 | random_seed = 3407
28 | 
29 | [train.ensemble]
30 | # For training multiple models in parallel; any hyperparameter not assigned here defaults to the value above
31 | #80_node_4_layer = {node_size = 80, num_interactions = 4, load_model = '/home/scratch3/xinyang/Au-facets/old_training/80_node_4_layer/model_output/best_model.pth'}
32 | #96_node_4_layer = {node_size = 96, num_interactions = 4, load_model = '/home/scratch3/xinyang/Au-facets/old_training/96_node_4_layer/model_output/best_model.pth'}
33 | 112_node_3_layer = {node_size = 112, num_interactions = 3, load_model = '/home/scratch3/xinyang/Au-facets/old_training/112_node_3_layer/model_output/best_model.pth'}
34 | 120_node_3_layer = {node_size = 120, num_interactions = 3, start_iteration = 7, stop_patience = 200}
35 | 128_node_3_layer = {node_size = 128, num_interactions = 3, load_model = '/home/scratch3/xinyang/Au-facets/old_training/128_node_3_layer/model_output/best_model.pth'}
36 | 136_node_3_layer = {node_size = 136, num_interactions = 3, start_iteration = 7, stop_patience = 200}
37 | 144_node_3_layer = {node_size = 144, num_interactions = 3, load_model = '/home/scratch3/xinyang/Au-facets/old_training/144_node_3_layer/model_output/best_model.pth'}
38 | 160_node_3_layer = {node_size = 160, num_interactions = 3, load_model = '/home/scratch3/xinyang/Au-facets/old_training/160_node_3_layer/model_output/best_model.pth'}
39 | 
40 | [train.resource]
41 | nodename = 'sm3090'
42 | tmax = '7d' # Time limit for each job. For example: 1d (1 day), 2m (2 min), 5h (5 hours)
43 | cores = 8
44 | 
45 | [MD]
46 | # Parameters for MD. It is better to customize your parameters in the MD script.
47 | init_traj = '/home/scratch3/xinyang/md_mlp/Au_111_110/110_water/MD.traj' 48 | start_indice = -5 49 | # load_model = '/home/scratch3/xinyang/Au-facets/old_training/train' # will be assigned in the workflow 50 | time_step = 0.5 51 | temperature = 350 52 | max_steps = 2000000 53 | min_steps = 100000 54 | device = 'cuda' 55 | fix_under = 7.0 56 | dump_step = 100 57 | print_step = 1 58 | num_uncertain = 1000 59 | random_seed = 3407 60 | 61 | [MD.runs] 62 | # run multiple MD jobs in parallel 63 | [MD.runs.Au_111_water] 64 | init_traj = '/home/scratch3/xinyang/md_mlp/Au_111_110/111_water/MD.traj' 65 | fix_under = 7.0 66 | start_indice = -5 67 | min_steps = 50000 68 | dump_step = 50 69 | 70 | [MD.runs.Au_110_water] 71 | init_traj = '/home/scratch3/xinyang/md_mlp/Au_111_110/110_water/MD.traj' 72 | max_steps = 2000000 # this one is already good enough 73 | fix_under = 7.0 74 | start_indice = -5 75 | min_steps = 50000 76 | dump_step = 100 77 | 78 | [MD.runs.Au_111_1OH] 79 | init_traj = '/home/scratch3/xinyang/Au-facets/1OH/md/iter_0/111_1OH/MD.traj' 80 | fix_under = 7.0 81 | start_indice = -5 82 | min_steps = 50000 83 | dump_step = 30 84 | start_iteration = 3 85 | 86 | [MD.runs.Au_110_1OH] 87 | init_traj = '/home/scratch3/xinyang/Au-facets/1OH/md/iter_0/110_1OH/MD.traj' 88 | fix_under = 7.0 89 | start_indice = -5 90 | min_steps = 20000 91 | dump_step = 30 92 | start_iteration = 3 93 | 94 | [MD.runs.Au_111_1O2] 95 | init_traj = '/home/energy/xinyang/work/Au_MD/DFT_MD/111_MD/O2/111_O2_incomplete.traj' 96 | fix_under = 7.0 97 | start_indice = -5 98 | min_steps = 30000 99 | dump_step = 50 100 | start_iteration = 3 101 | 102 | [MD.runs.Au_110_1O2] 103 | init_traj = '/home/energy/xinyang/work/Au_MD/DFT_MD/110_MD/O2/110_O2_incomplete.traj' 104 | fix_under = 7.0 105 | start_indice = -5 106 | min_steps = 50000 107 | dump_step = 50 108 | start_iteration = 3 109 | 110 | [MD.resource] 111 | nodename = 'sm3090' 112 | tmax = '7d' 113 | cores = 8 114 | 115 | [select] 116 | kernel = "full-g" # Name of the kernel, e.g. "full-g", "ll-g", "full-F_inv", "ll-F_inv", "qbc-energy", "qbc-force", "ae-energy", "ae-force", "random" 117 | selection = "lcmd_greedy" # Selection method, one of "max_dist_greedy", "deterministic_CUR", "lcmd_greedy", "max_det_greedy" or "max_diag". 118 | n_random_features = 500 # If "n_random_features = 0", do not use random projections. 
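# Note: batch_size in [select] is the number of structures selected from the pool in each
# active-learning iteration (al_select.py --batch_size), not a training minibatch size.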
119 | batch_size = 100 120 | # load_model = '/home/scratch3/xinyang/Au-facets/old_training/train' # will be assigned in the workflow 121 | # dataset = 'md17aspirin.traj' # should not be assigned if using pool data set from MD 122 | # split_file = 'datasplits.json' 123 | # pool_set = # Useful when dataset and split_file are not assigned, can be a list or str 124 | train_set = '/home/scratch3/xinyang/Au-facets/111_110.traj' 125 | device = 'cuda' 126 | random_seed = 3407 127 | 128 | [select.runs] 129 | Au_110_water = {batch_size = 100} # this one is much faster so use larger batch size can save some time 130 | Au_111_water = {batch_size = 200} 131 | Au_110_1OH = {batch_size = 200, start_iteration = 3} 132 | Au_111_1OH = {batch_size = 200, start_iteration = 3} 133 | Au_110_1O2 = {batch_size = 200, start_iteration = 3} 134 | Au_111_1O2 = {batch_size = 200, start_iteration = 3} 135 | 136 | 137 | [select.resource] 138 | nodename = 'sm3090' 139 | tmax = '2d' 140 | cores = 8 141 | 142 | [labeling] 143 | # label_set = 'xxx.traj' 144 | train_set = '/home/scratch3/xinyang/Au-facets/111_110.traj' 145 | # pool_set # will be assigned in the workflow, can be a list 146 | # al_info # will be assigned in the workflow 147 | num_jobs = 2 148 | 149 | [labeling.VASP] 150 | # VASP parameters 151 | xc = 'PBE' 152 | gga = 'pe' 153 | system = 'ni' 154 | prec = 'normal' 155 | istart = 1 156 | icharg = 2 157 | npar = 4 158 | encut = 350 159 | algo = 'Fast' 160 | lreal = 'Auto' 161 | nelm = 1000 162 | nelmin = 5 163 | nelmdl = -5 164 | ediff = 1e-4 165 | ediffg = -0.01 166 | nsw = 0 167 | ibrion = 0 168 | potim = 1 169 | isif = 2 170 | ispin = 2 171 | ismear = 0 172 | sigma = 0.1 173 | lwave = true 174 | lcharg = false 175 | ivdw = 11 176 | lasph = true 177 | kpts = [2, 2, 1] 178 | gamma = false 179 | # kspacing = 0.5 180 | 181 | [labeling.runs] 182 | 183 | [labeling.runs.Au_111_water] # The key name should be the same to MD 184 | gamma = true 185 | num_jobs = 6 # accelerate DFT labeling by spliting the job to several different parts 186 | 187 | [labeling.runs.Au_110_water] 188 | gamma = false 189 | num_jobs = 2 190 | 191 | [labeling.runs.Au_111_1OH] 192 | gamma = true 193 | num_jobs = 6 194 | start_iteration = 3 195 | 196 | [labeling.runs.Au_110_1OH] 197 | gamma = false 198 | num_jobs = 2 199 | start_iteration = 3 200 | 201 | [labeling.runs.Au_111_1O2] 202 | gamma = true 203 | num_jobs = 6 204 | start_iteration = 3 205 | 206 | [labeling.runs.Au_110_1O2] 207 | gamma = false 208 | num_jobs = 2 209 | start_iteration = 3 210 | 211 | [labeling.resource] 212 | cores = 40 213 | nodename = 'xeon40' 214 | tmax = '2d' 215 | -------------------------------------------------------------------------------- /workflow/flow.py: -------------------------------------------------------------------------------- 1 | import json, toml, sys 2 | from pathlib import Path 3 | from myqueue.workflow import run 4 | from typing import List, Dict 5 | from ase.io import Trajectory, read, write 6 | import numpy as np 7 | import copy 8 | 9 | # args parsing 10 | 11 | with open('config.toml') as f: 12 | args = toml.load(f) 13 | 14 | # get absolute path 15 | name_list = [ 16 | 'dataset', 17 | 'split_file', 18 | 'load_model', 19 | 'init_traj', 20 | 'pool_set', 21 | 'train_set', 22 | 'label_set', 23 | 'al_info', 24 | 'root' 25 | ] 26 | def get_absolute_path(d: dict): 27 | for k, v in d.items(): 28 | if k in name_list and not Path(v).is_absolute(): 29 | d[k] = str(Path(v).resolve()) 30 | elif isinstance(v, dict): 31 | d[k] = get_absolute_path(v) 32 | 
return d 33 | args = get_absolute_path(args) 34 | 35 | # parsing training parameters 36 | train_params = {} 37 | if args['train'].get('ensemble'): 38 | for name, params in args['train']['ensemble'].items(): 39 | for k, v in args['train'].items(): 40 | if not isinstance(v, dict) and k not in params: 41 | params[k] = v 42 | train_params[name] = params 43 | else: 44 | params = {} 45 | for k, v in args['train'].items(): 46 | if not isinstance(v, dict) and k not in params: 47 | params[k] = v 48 | train_params['model'] = params 49 | # train_resource = args['train']['resource'] 50 | 51 | # parsing active learning parameters 52 | al_params = {} 53 | if args['select'].get('runs'): 54 | for name, params in args['select']['runs'].items(): 55 | for k, v in args['select'].items(): 56 | if not isinstance(v, dict) and k not in params: 57 | params[k] = v 58 | al_params[name] = params 59 | else: 60 | params = {} 61 | for k, v in args['select'].items(): 62 | if not isinstance(v, dict) and k not in params: 63 | params[k] = v 64 | al_params['select'] = params 65 | 66 | # al_resource = args['select']['resource'] 67 | 68 | # parsing MD parameters 69 | md_params = {} 70 | if args['MD'].get('runs'): 71 | for name, params in args['MD']['runs'].items(): 72 | for k, v in args['MD'].items(): 73 | if not isinstance(v, dict) and k not in params: 74 | params[k] = v 75 | md_params[name] = params 76 | else: 77 | params = {} 78 | for k, v in args['MD'].items(): 79 | if not isinstance(v, dict) and k not in params: 80 | params[k] = v 81 | md_params['md_run'] = params 82 | 83 | # DFT labelling 84 | dft_params = {} 85 | tmp_params = {k: v for k, v in args['labeling'].items() if not isinstance(v, dict)} 86 | tmp_params['VASP'] = args['labeling']['VASP'] 87 | if args['labeling'].get('runs'): 88 | for name, params in args['labeling']['runs'].items(): 89 | new_params = copy.deepcopy(tmp_params) 90 | for k, v in params.items(): 91 | if k in new_params['VASP']: 92 | new_params['VASP'][k] = v 93 | else: 94 | new_params[k] = v 95 | dft_params[name] = new_params 96 | else: 97 | dft_params['dft_run'] = tmp_params 98 | 99 | root = args['global']['root'] 100 | 101 | def train_models(folder, deps, extra_args: List[str] = [], iteration: int=0): 102 | tasks = [] 103 | node_info = args['train']['resource'] 104 | # parse parameters 105 | for name, params in train_params.items(): 106 | path = Path(f'{folder}/iter_{iteration}/{name}') 107 | 108 | if not params.get('start_iteration'): 109 | params['start_iteration'] = 0 110 | if iteration >= params['start_iteration']: 111 | if not path.is_dir(): 112 | path.mkdir(parents=True) 113 | 114 | # load model 115 | if iteration > 0: 116 | load_model = f'{root}/{folder}/iter_{iteration-1}/{name}/{params["output_dir"]}/best_model.pth' 117 | if Path(load_model).is_file(): 118 | params['load_model'] = load_model 119 | # elif iteration == 0: 120 | # params['load_model'] = f'/home/scratch3/xinyang/Au-facets/old_training/train/{name}/model_output/best_model.pth' 121 | 122 | with open(path / 'arguments.toml', 'w') as f: 123 | toml.dump(params, f) 124 | 125 | arguments = ['--cfg', 'arguments.toml'] 126 | arguments += extra_args 127 | 128 | tasks.append(run( 129 | script=f'{root}/train.py', 130 | nodename='sm3090' if not node_info.get('nodename') else node_info['nodename'], 131 | cores=8 if not node_info.get('cores') else node_info['cores'], 132 | tmax='7d' if not node_info.get('tmax') else node_info['tmax'], 133 | args=arguments, 134 | folder=path, 135 | name=name, 136 | deps=deps, 137 | )) 138 | 139 | return 
tasks 140 | 141 | def active_learning(folder, deps, extra_args: List[str] = [], iteration: int=0): 142 | tasks = {} 143 | node_info = args['select']['resource'] 144 | # parse parameters 145 | for name, params in al_params.items(): 146 | path = Path(f'{folder}/iter_{iteration}/{name}') 147 | if not params.get('start_iteration'): 148 | params['start_iteration'] = 0 149 | if iteration >= params['start_iteration']: 150 | if not path.is_dir(): 151 | path.mkdir(parents=True) 152 | 153 | params['load_model'] = f'{root}/train/iter_{iteration}' 154 | if not params.get('dataset'): 155 | params['pool_set'] = [f'{root}/md/iter_{iteration}/{name}/MD.traj', f'{root}/md/iter_{iteration}/{name}/warning_struct.traj'] 156 | 157 | with open(path / 'arguments.toml', 'w') as f: 158 | toml.dump(params, f) 159 | 160 | arguments = ['--cfg', 'arguments.toml'] 161 | 162 | tasks[name] = run( 163 | script=f'{root}/al_select.py', 164 | nodename='sm3090' if not node_info.get('nodename') else node_info['nodename'], 165 | cores=8 if not node_info.get('cores') else node_info['cores'], 166 | tmax='7d' if not node_info.get('tmax') else node_info['tmax'], 167 | args=arguments, 168 | folder=path, 169 | name=name, 170 | deps=[deps[name]], 171 | ) 172 | 173 | return tasks 174 | 175 | def run_md(folder, deps=[], extra_args: List[str] = [], iteration: int=0): 176 | tasks = {} 177 | node_info = args['MD']['resource'] 178 | for name, params in md_params.items(): 179 | path = Path(f'{folder}/iter_{iteration}/{name}') 180 | 181 | if not params.get('start_iteration'): 182 | params['start_iteration'] = 0 183 | if iteration >= params['start_iteration']: 184 | if not path.is_dir(): 185 | path.mkdir(parents=True) 186 | params['load_model'] = f'{root}/train/iter_{iteration}' 187 | 188 | 189 | if iteration > params['start_iteration']: 190 | params['init_traj'] = f'{root}/md/iter_{iteration-1}/{name}/MD.traj' 191 | 192 | with open(path / 'arguments.toml', 'w') as f: 193 | toml.dump(params, f) 194 | 195 | arguments = ['--cfg', 'arguments.toml'] 196 | 197 | tasks[name] = run( 198 | script=f'{root}/md_run.py', 199 | nodename='sm3090' if not node_info.get('nodename') else node_info['nodename'], 200 | cores=8 if not node_info.get('cores') else node_info['cores'], 201 | tmax='7d' if not node_info.get('tmax') else node_info['tmax'], 202 | args=arguments, 203 | folder=path, 204 | name=name, 205 | deps=deps, 206 | ) 207 | 208 | return tasks 209 | 210 | def run_dft(folder, deps={}, extra_args: List[str] = [], iteration: int=0): 211 | tasks = [] 212 | node_info = args['labeling']['resource'] 213 | for name, params in dft_params.items(): 214 | path = Path(f'{folder}/iter_{iteration}/{name}') 215 | if not params.get('start_iteration'): 216 | params['start_iteration'] = 0 217 | if iteration >= params['start_iteration']: 218 | if not path.is_dir(): 219 | path.mkdir(parents=True) 220 | 221 | # get images that need to be labeled 222 | params['system'] = name 223 | params['pool_set'] = [f'{root}/md/iter_{iteration}/{name}/MD.traj', f'{root}/md/iter_{iteration}/{name}/warning_struct.traj'] 224 | params['al_info'] = f'{root}/select/iter_{iteration}/{name}/selected.json' 225 | with open(path / 'arguments.toml', 'w') as f: 226 | toml.dump(params, f) 227 | 228 | arguments = ['--cfg', 'arguments.toml'] 229 | 230 | if params.get('num_jobs'): 231 | for i in range(params['num_jobs']): 232 | dft_arguments = ['--cfg', '../arguments.toml', '--job_order', f'{i}'] 233 | dft_path = path / f'{i}' 234 | if not dft_path.is_dir(): 235 | dft_path.mkdir(parents=True) 236 | 
tasks.append(run( 237 | script=f'{root}/vasp.py', 238 | nodename='xeon40' if not node_info.get('nodename') else node_info['nodename'], 239 | cores=40 if not node_info.get('cores') else node_info['cores'], 240 | tmax='50h' if not node_info.get('tmax') else node_info['tmax'], 241 | args=dft_arguments, 242 | folder=dft_path, 243 | name=name, 244 | deps=[deps[name]], 245 | )) 246 | else: 247 | tasks.append(run( 248 | script=f'{root}/vasp.py', 249 | nodename='xeon40' if not node_info.get('nodename') else node_info['nodename'], 250 | cores=40 if not node_info.get('cores') else node_info['cores'], 251 | tmax='50h' if not node_info.get('tmax') else node_info['tmax'], 252 | args=arguments, 253 | folder=path, 254 | name=name, 255 | deps=[deps[name]], 256 | )) 257 | 258 | return tasks 259 | 260 | def all_done(runs): 261 | return all([task.done for task in runs]) 262 | 263 | def workflow(): 264 | dft = [] 265 | for iteration in range(9): 266 | # training part 267 | training = train_models('train', deps=dft, iteration=iteration) 268 | 269 | # data generating 270 | md = run_md('md', deps=training, iteration=iteration) 271 | 272 | # active learning selection 273 | select = active_learning('select', deps=md, iteration=iteration) 274 | 275 | # DFT labeling 276 | dft = run_dft('labeling', deps=select, iteration=iteration) 277 | -------------------------------------------------------------------------------- /workflow/md_run.py: -------------------------------------------------------------------------------- 1 | from ase.md.langevin import Langevin 2 | from ase.calculators.plumed import Plumed 3 | from ase import units 4 | from ase.md.velocitydistribution import MaxwellBoltzmannDistribution 5 | from ase.io import read, write, Trajectory 6 | 7 | import numpy as np 8 | import torch 9 | import sys 10 | import glob 11 | import toml 12 | import argparse 13 | from pathlib import Path 14 | import logging 15 | 16 | from PaiNN.data import AseDataset, collate_atomsdata 17 | from PaiNN.model import PainnModel 18 | from PaiNN.calculator import MLCalculator, EnsembleCalculator 19 | from ase.constraints import FixAtoms 20 | 21 | def setup_seed(seed): 22 | torch.manual_seed(seed) 23 | if torch.cuda.is_available(): 24 | torch.cuda.manual_seed_all(seed) 25 | np.random.seed(seed) 26 | torch.backends.cudnn.deterministic = True 27 | 28 | def get_arguments(arg_list=None): 29 | parser = argparse.ArgumentParser( 30 | description="MD simulations drive by graph neural networks", fromfile_prefix_chars="+" 31 | ) 32 | parser.add_argument( 33 | "--init_traj", 34 | type=str, 35 | help="Path to start configurations", 36 | ) 37 | parser.add_argument( 38 | "--start_indice", 39 | type=int, 40 | help="Indice of the start configuration", 41 | ) 42 | parser.add_argument( 43 | "--load_model", 44 | type=str, 45 | help="Where to find the models", 46 | ) 47 | parser.add_argument( 48 | "--time_step", 49 | type=float, 50 | default=0.5, 51 | help="Time step of MD simulation", 52 | ) 53 | parser.add_argument( 54 | "--max_steps", 55 | type=int, 56 | default=5000000, 57 | help="Maximum steps of MD", 58 | ) 59 | parser.add_argument( 60 | "--min_steps", 61 | type=int, 62 | default=100000, 63 | help="Minimum steps of MD, raise error if not reached", 64 | ) 65 | parser.add_argument( 66 | "--temperature", 67 | type=float, 68 | default=350.0, 69 | help="Maximum time steps of MD", 70 | ) 71 | parser.add_argument( 72 | "--fix_under", 73 | type=float, 74 | default=5.9, 75 | help="Fix atoms under the specified value", 76 | ) 77 | parser.add_argument( 78 | 
"--dump_step", 79 | type=int, 80 | default=100, 81 | help="Fix atoms under the specified value", 82 | ) 83 | parser.add_argument( 84 | "--print_step", 85 | type=int, 86 | default=1, 87 | help="Fix atoms under the specified value", 88 | ) 89 | parser.add_argument( 90 | "--num_uncertain", 91 | type=int, 92 | default=1000, 93 | help="Stop MD when too many structures with large uncertainty are collected", 94 | ) 95 | parser.add_argument( 96 | "--random_seed", 97 | type=int, 98 | help="Random seed for this run", 99 | ) 100 | parser.add_argument( 101 | "--device", 102 | type=str, 103 | default='cuda', 104 | help="Set which device to use for running MD e.g. 'cuda' or 'cpu'", 105 | ) 106 | parser.add_argument( 107 | "--cfg", 108 | type=str, 109 | default="arguments.toml", 110 | help="Path to config file. e.g. 'arguments.toml'" 111 | ) 112 | 113 | return parser.parse_args(arg_list) 114 | 115 | def update_namespace(ns, d): 116 | for k, v in d.items(): 117 | ns.__dict__[k] = v 118 | 119 | class CallsCounter: 120 | def __init__(self, func): 121 | self.calls = 0 122 | self.func = func 123 | def __call__(self, *args, **kwargs): 124 | self.calls += 1 125 | self.func(*args, **kwargs) 126 | 127 | def main(): 128 | args = get_arguments() 129 | if args.cfg: 130 | with open(args.cfg, 'r') as f: 131 | params = toml.load(f) 132 | update_namespace(args, params) 133 | 134 | setup_seed(args.random_seed) 135 | 136 | # set logger 137 | logger = logging.getLogger(__file__) 138 | logger.setLevel(logging.DEBUG) 139 | 140 | runHandler = logging.FileHandler('md.log', mode='w') 141 | runHandler.setLevel(logging.DEBUG) 142 | runHandler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)7s - %(message)s")) 143 | errorHandler = logging.FileHandler('error.log', mode='w') 144 | errorHandler.setLevel(logging.WARNING) 145 | errorHandler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)7s - %(message)s")) 146 | 147 | logger.addHandler(runHandler) 148 | logger.addHandler(errorHandler) 149 | logger.addHandler(logging.StreamHandler()) 150 | logger.warning = CallsCounter(logger.warning) 151 | logger.info = CallsCounter(logger.info) 152 | 153 | # load model 154 | model_pth = Path(args.load_model).rglob('*best_model.pth') 155 | models = [] 156 | for each in model_pth: 157 | state_dict = torch.load(each) 158 | model = PainnModel( 159 | num_interactions=state_dict["num_layer"], 160 | hidden_state_size=state_dict["node_size"], 161 | cutoff=state_dict["cutoff"], 162 | ) 163 | model.to(args.device) 164 | model.load_state_dict(state_dict["model"]) 165 | models.append(model) 166 | 167 | encalc = EnsembleCalculator(models) 168 | 169 | # set up md start configuration 170 | images = read(args.init_traj, ':') 171 | start_indice = np.random.choice(len(images)) if args.start_indice == None else args.start_indice 172 | logger.debug(f'MD starts from No.{start_indice} configuration in {args.init_traj}') 173 | atoms = images[start_indice] 174 | atoms.wrap() 175 | cons = FixAtoms(mask=atoms.positions[:, 2] < args.fix_under) if args.fix_under else [] 176 | atoms.set_constraint(cons) 177 | atoms.calc = encalc 178 | atoms.get_potential_energy() 179 | 180 | collect_traj = Trajectory('warning_struct.traj', 'w') 181 | @CallsCounter 182 | def printenergy(a=atoms): # store a reference to atoms in the definition. 
183 | """Function to print the potential, kinetic and total energy.""" 184 | epot = a.get_potential_energy() 185 | ekin = a.get_kinetic_energy() 186 | temp = ekin / (1.5 * units.kB) / a.get_global_number_of_atoms() 187 | ensemble = a.calc.results['ensemble'] 188 | energy_var = ensemble['energy_var'] 189 | forces_var = np.mean(ensemble['forces_var']) 190 | forces_sd = np.mean(np.sqrt(ensemble['forces_var'])) 191 | forces_l2_var = np.mean(ensemble['forces_l2_var']) 192 | 193 | if forces_sd > 0.2: 194 | logger.error("Too large uncertainty!") 195 | if logger.info.calls + logger.warning.calls > args.min_steps: 196 | sys.exit(0) 197 | else: 198 | sys.exit("Too large uncertainty!") 199 | elif forces_sd > 0.05: 200 | collect_traj.write(a) 201 | logger.warning("Steps={:10d} Epot={:12.3f} Ekin={:12.3f} temperature={:8.2f} energy_var={:10.6f} forces_var={:10.6f} forces_sd={:10.6f} forces_l2_var={:10.6f}".format( 202 | printenergy.calls * args.print_step, 203 | epot, 204 | ekin, 205 | temp, 206 | energy_var, 207 | forces_var, 208 | forces_sd, 209 | forces_l2_var, 210 | )) 211 | if logger.warning.calls > args.num_uncertain: 212 | logger.error(f"More than {args.num_uncertain} uncertain structures are collected!") 213 | if logger.info.calls + logger.warning.calls > args.min_steps: 214 | sys.exit(0) 215 | else: 216 | sys.exit(f"More than {args.num_uncertain} uncertain structures are collected!") 217 | else: 218 | logger.info("Steps={:10d} Epot={:12.3f} Ekin={:12.3f} temperature={:8.2f} energy_var={:10.6f} forces_var={:10.6f} forces_sd={:10.6f} forces_l2_var={:10.6f}".format( 219 | printenergy.calls * args.print_step, 220 | epot, 221 | ekin, 222 | temp, 223 | energy_var, 224 | forces_var, 225 | forces_sd, 226 | forces_l2_var, 227 | )) 228 | 229 | #atoms.calc = encalc 230 | if not np.any(atoms.get_momenta()): 231 | MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature) 232 | dyn = Langevin(atoms, args.time_step * units.fs, temperature_K=args.temperature, friction=0.1) 233 | dyn.attach(printenergy, interval=args.print_step) 234 | 235 | traj = Trajectory('MD.traj', 'w', atoms) 236 | dyn.attach(traj.write, interval=args.dump_step) 237 | dyn.run(args.max_steps) 238 | 239 | if __name__ == "__main__": 240 | main() -------------------------------------------------------------------------------- /workflow/train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import json, os, sys, toml 4 | from pathlib import Path 5 | import argparse 6 | import logging 7 | import itertools 8 | import torch 9 | import time 10 | 11 | from PaiNN.data import AseDataset, collate_atomsdata 12 | from PaiNN.model import PainnModel 13 | 14 | def setup_seed(seed): 15 | torch.manual_seed(seed) 16 | if torch.cuda.is_available(): 17 | torch.cuda.manual_seed_all(seed) 18 | np.random.seed(seed) 19 | torch.backends.cudnn.deterministic = True 20 | 21 | def get_arguments(arg_list=None): 22 | parser = argparse.ArgumentParser( 23 | description="Train graph convolution network", fromfile_prefix_chars="+" 24 | ) 25 | parser.add_argument( 26 | "--load_model", 27 | type=str, 28 | help="Load model parameters from previous run", 29 | ) 30 | parser.add_argument( 31 | "--cutoff", 32 | type=float, 33 | help="Atomic interaction cutoff distance [�~E]", 34 | ) 35 | parser.add_argument( 36 | "--split_file", 37 | type=str, 38 | help="Train/test/validation split file json", 39 | ) 40 | parser.add_argument( 41 | "--val_ratio", 42 | type=float, 43 | help="Ratio of validation 
set. Only useful when 'split_file' is not assigned",
44 |     )
45 |     parser.add_argument(
46 |         "--num_interactions",
47 |         type=int,
48 |         help="Number of interaction layers used",
49 |     )
50 |     parser.add_argument(
51 |         "--node_size", type=int, help="Size of hidden node states"
52 |     )
53 |     parser.add_argument(
54 |         "--output_dir",
55 |         type=str,
56 |         help="Path to output directory",
57 |     )
58 |     parser.add_argument(
59 |         "--dataset", type=str, help="Path to ASE trajectory",
60 |     )
61 |     parser.add_argument(
62 |         "--max_steps",
63 |         type=int,
64 |         help="Maximum number of optimisation steps",
65 |     )
66 |     parser.add_argument(
67 |         "--device",
68 |         type=str,
69 |         help="Set which device to use for training e.g. 'cuda' or 'cpu'",
70 |     )
71 |     parser.add_argument(
72 |         "--batch_size", type=int, help="Number of molecules per minibatch",
73 |     )
74 |     parser.add_argument(
75 |         "--initial_lr", type=float, help="Initial learning rate",
76 |     )
77 |     parser.add_argument(
78 |         "--forces_weight",
79 |         type=float,
80 |         help="Tradeoff between training on forces (weight=1) and energy (weight=0)",
81 |     )
82 |     parser.add_argument(
83 |         "--log_interval",
84 |         type=int,
85 |         help="The interval of model evaluation",
86 |     )
87 |     parser.add_argument(
88 |         "--plateau_scheduler",
89 |         action="store_true",
90 |         help="Use a ReduceLROnPlateau scheduler to decrease the learning rate when learning plateaus",
91 |     )
92 |     parser.add_argument(
93 |         "--normalization",
94 |         action="store_true",
95 |         help="Enable normalization of the model",
96 |     )
97 |     parser.add_argument(
98 |         "--atomwise_normalization",
99 |         action="store_true",
100 |         help="Enable atomwise normalization",
101 |     )
102 |     parser.add_argument(
103 |         "--stop_patience",
104 |         type=int,
105 |         help="Stop training when validation loss is larger than best loss for 'stop_patience' steps",
106 |     )
107 |     parser.add_argument(
108 |         "--random_seed",
109 |         type=int,
110 |         help="Random seed for this run",
111 |     )
112 |     parser.add_argument(
113 |         "--cfg",
114 |         type=str,
115 |         help="Path to config file. e.g. 
'arguments.toml'" 116 | ) 117 | 118 | return parser.parse_args(arg_list) 119 | 120 | def split_data(dataset, args): 121 | # Load or generate splits 122 | if args.split_file: 123 | with open(args.split_file, "r") as fp: 124 | splits = json.load(fp) 125 | else: 126 | datalen = len(dataset) 127 | num_validation = int(math.ceil(datalen * args.val_ratio)) 128 | indices = np.random.permutation(len(dataset)) 129 | splits = { 130 | "train": indices[num_validation:].tolist(), 131 | "validation": indices[:num_validation].tolist(), 132 | } 133 | 134 | # Save split file 135 | with open(os.path.join(args.output_dir, "datasplits.json"), "w") as f: 136 | json.dump(splits, f) 137 | 138 | # Split the dataset 139 | datasplits = {} 140 | for key, indices in splits.items(): 141 | datasplits[key] = torch.utils.data.Subset(dataset, indices) 142 | return datasplits 143 | 144 | def forces_criterion(predicted, target, reduction="mean"): 145 | # predicted, target are (bs, max_nodes, 3) tensors 146 | # node_count is (bs) tensor 147 | diff = predicted - target 148 | total_squared_norm = torch.linalg.norm(diff, dim=1) # bs 149 | if reduction == "mean": 150 | scalar = torch.mean(total_squared_norm) 151 | elif reduction == "sum": 152 | scalar = torch.sum(total_squared_norm) 153 | else: 154 | raise ValueError("Reduction must be 'mean' or 'sum'") 155 | return scalar 156 | 157 | def get_normalization(dataset, per_atom=True): 158 | # Use double precision to avoid overflows 159 | x_sum = torch.zeros(1, dtype=torch.double) 160 | x_2 = torch.zeros(1, dtype=torch.double) 161 | num_objects = 0 162 | for i, sample in enumerate(dataset): 163 | if i == 0: 164 | # Estimate "bias" from 1 sample 165 | # to avoid overflows for large valued datasets 166 | if per_atom: 167 | bias = sample["energy"] / sample["num_atoms"] 168 | else: 169 | bias = sample["energy"] 170 | x = sample["energy"] 171 | if per_atom: 172 | x = x / sample["num_atoms"] 173 | x -= bias 174 | x_sum += x 175 | x_2 += x ** 2.0 176 | num_objects += 1 177 | # Var(X) = E[X^2] - E[X]^2 178 | x_mean = x_sum / num_objects 179 | x_var = x_2 / num_objects - x_mean ** 2.0 180 | x_mean = x_mean + bias 181 | 182 | default_type = torch.get_default_dtype() 183 | 184 | return x_mean.type(default_type), torch.sqrt(x_var).type(default_type) 185 | 186 | def eval_model(model, dataloader, device, forces_weight): 187 | energy_running_ae = 0 188 | energy_running_se = 0 189 | 190 | forces_running_l2_ae = 0 191 | forces_running_l2_se = 0 192 | forces_running_c_ae = 0 193 | forces_running_c_se = 0 194 | forces_running_loss = 0 195 | 196 | running_loss = 0 197 | count = 0 198 | forces_count = 0 199 | criterion = torch.nn.MSELoss() 200 | 201 | for batch in dataloader: 202 | device_batch = { 203 | k: v.to(device=device, non_blocking=True) for k, v in batch.items() 204 | } 205 | out = model(device_batch) 206 | 207 | # counts 208 | count += batch["energy"].shape[0] 209 | forces_count += batch['forces'].shape[0] 210 | 211 | # use mean square loss here 212 | forces_loss = forces_criterion(out["forces"], device_batch["forces"]).item() 213 | energy_loss = criterion(out["energy"], device_batch["energy"]).item() #problem here 214 | total_loss = forces_weight * forces_loss + (1 - forces_weight) * energy_loss 215 | running_loss += total_loss * batch["energy"].shape[0] 216 | 217 | # energy errors 218 | outputs = {key: val.detach().cpu().numpy() for key, val in out.items()} 219 | energy_targets = batch["energy"].detach().cpu().numpy() 220 | energy_running_ae += np.sum(np.abs(energy_targets - 
outputs["energy"]), axis=0) 221 | energy_running_se += np.sum( 222 | np.square(energy_targets - outputs["energy"]), axis=0 223 | ) 224 | 225 | # force errors 226 | forces_targets = batch["forces"].detach().cpu().numpy() 227 | forces_diff = forces_targets - outputs["forces"] 228 | forces_l2_norm = np.sqrt(np.sum(np.square(forces_diff), axis=1)) 229 | 230 | forces_running_c_ae += np.sum(np.abs(forces_diff)) 231 | forces_running_c_se += np.sum(np.square(forces_diff)) 232 | 233 | forces_running_l2_ae += np.sum(np.abs(forces_l2_norm)) 234 | forces_running_l2_se += np.sum(np.square(forces_l2_norm)) 235 | 236 | energy_mae = energy_running_ae / count 237 | energy_rmse = np.sqrt(energy_running_se / count) 238 | 239 | forces_l2_mae = forces_running_l2_ae / forces_count 240 | forces_l2_rmse = np.sqrt(forces_running_l2_se / forces_count) 241 | 242 | forces_c_mae = forces_running_c_ae / (forces_count * 3) 243 | forces_c_rmse = np.sqrt(forces_running_c_se / (forces_count * 3)) 244 | 245 | total_loss = running_loss / count 246 | 247 | evaluation = { 248 | "energy_mae": energy_mae, 249 | "energy_rmse": energy_rmse, 250 | "forces_l2_mae": forces_l2_mae, 251 | "forces_l2_rmse": forces_l2_rmse, 252 | "forces_mae": forces_c_mae, 253 | "forces_rmse": forces_c_rmse, 254 | "sqrt(total_loss)": np.sqrt(total_loss), 255 | } 256 | 257 | return evaluation 258 | 259 | def update_namespace(ns, d): 260 | for k, v in d.items(): 261 | if not ns.__dict__.get(k): 262 | ns.__dict__[k] = v 263 | 264 | class EarlyStopping(): 265 | def __init__(self, patience=5, min_delta=0): 266 | 267 | self.patience = patience 268 | self.min_delta = min_delta 269 | self.counter = 0 270 | self.early_stop = False 271 | 272 | def __call__(self, val_loss, best_loss): 273 | if val_loss - best_loss > self.min_delta: 274 | self.counter +=1 275 | if self.counter >= self.patience: 276 | self.early_stop = True 277 | 278 | return self.early_stop 279 | 280 | def main(): 281 | args = get_arguments() 282 | if args.cfg: 283 | with open(args.cfg, 'r') as f: 284 | params = toml.load(f) 285 | update_namespace(args, params) 286 | 287 | # Setup random seed 288 | setup_seed(args.random_seed) 289 | 290 | # Setup logging 291 | os.makedirs(args.output_dir, exist_ok=True) 292 | logging.basicConfig( 293 | level=logging.DEBUG, 294 | format="%(asctime)s [%(levelname)-5.5s] %(message)s", 295 | handlers=[ 296 | logging.FileHandler( 297 | os.path.join(args.output_dir, "printlog.txt"), mode="w" 298 | ), 299 | logging.StreamHandler(), 300 | ], 301 | ) 302 | 303 | # Save command line args 304 | with open(os.path.join(args.output_dir, "commandline_args.txt"), "w") as f: 305 | f.write("\n".join(sys.argv[1:])) 306 | # Save parsed command line arguments 307 | with open(os.path.join(args.output_dir, "arguments.json"), "w") as f: 308 | json.dump(vars(args), f) 309 | 310 | # Create device 311 | device = torch.device(args.device) 312 | # Put a tensor on the device before loading data 313 | # This way the GPU appears to be in use when other users run gpustat 314 | torch.tensor([0], device=device) 315 | 316 | # Setup dataset and loader 317 | logging.info("loading data %s", args.dataset) 318 | dataset = AseDataset( 319 | args.dataset, 320 | cutoff = args.cutoff, 321 | ) 322 | 323 | datasplits = split_data(dataset, args) 324 | 325 | train_loader = torch.utils.data.DataLoader( 326 | datasplits["train"], 327 | args.batch_size, 328 | sampler=torch.utils.data.RandomSampler(datasplits["train"]), 329 | collate_fn=collate_atomsdata, 330 | ) 331 | val_loader = torch.utils.data.DataLoader( 332 | 
datasplits["validation"], 333 | args.batch_size, 334 | collate_fn=collate_atomsdata, 335 | ) 336 | 337 | logging.info('Dataset size: {}, training set size: {}, validation set size: {}'.format( 338 | len(dataset), 339 | len(datasplits["train"]), 340 | len(datasplits["validation"]), 341 | )) 342 | 343 | if args.normalization: 344 | logging.info("Computing mean and variance") 345 | target_mean, target_stddev = get_normalization( 346 | datasplits["train"], 347 | per_atom=args.atomwise_normalization, 348 | ) 349 | logging.debug("target_mean=%f, target_stddev=%f" % (target_mean, target_stddev)) 350 | 351 | net = PainnModel( 352 | num_interactions=args.num_interactions, 353 | hidden_state_size=args.node_size, 354 | cutoff=args.cutoff, 355 | normalization=args.normalization, 356 | target_mean=target_mean.tolist() if args.normalization else [0.0], 357 | target_stddev=target_stddev.tolist() if args.normalization else [1.0], 358 | atomwise_normalization=args.atomwise_normalization, 359 | ) 360 | net.to(device) 361 | 362 | optimizer = torch.optim.Adam(net.parameters(), lr=args.initial_lr) 363 | criterion = torch.nn.MSELoss() 364 | if args.plateau_scheduler: 365 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=10) 366 | else: 367 | scheduler_fn = lambda step: 0.96 ** (step / 100000) 368 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_fn) 369 | early_stop = EarlyStopping(patience=args.stop_patience) 370 | 371 | running_loss = 0 372 | running_loss_count = 0 373 | # used for smoothing loss 374 | prev_loss = None 375 | best_val_loss = np.inf 376 | step = 0 377 | training_time = 0 378 | 379 | if args.load_model: 380 | logging.info(f"Load model from {args.load_model}") 381 | state_dict = torch.load(args.load_model) 382 | net.load_state_dict(state_dict["model"]) 383 | # step = state_dict["step"] 384 | # best_val_loss = state_dict["best_val_loss"] 385 | # optimizer.load_state_dict(state_dict["optimizer"]) 386 | scheduler.load_state_dict(state_dict["scheduler"]) 387 | 388 | for epoch in itertools.count(): 389 | for batch_host in train_loader: 390 | start = time.time() 391 | # Transfer to 'device' 392 | batch = { 393 | k: v.to(device=device, non_blocking=True) 394 | for (k, v) in batch_host.items() 395 | } 396 | # Reset gradient 397 | optimizer.zero_grad() 398 | 399 | # Forward, backward and optimize 400 | outputs = net( 401 | batch, compute_forces=bool(args.forces_weight) 402 | ) 403 | energy_loss = criterion(outputs["energy"], batch["energy"]) 404 | if args.forces_weight: 405 | forces_loss = forces_criterion(outputs['forces'], batch['forces']) 406 | else: 407 | forces_loss = 0.0 408 | total_loss = ( 409 | args.forces_weight * forces_loss 410 | + (1 - args.forces_weight) * energy_loss 411 | ) 412 | total_loss.backward() 413 | optimizer.step() 414 | running_loss += total_loss.item() * batch["energy"].shape[0] 415 | running_loss_count += batch["energy"].shape[0] 416 | training_time += time.time() - start 417 | 418 | # print(step, loss_value) 419 | # Validate and save model 420 | if (step % args.log_interval == 0) or ((step + 1) == args.max_steps): 421 | eval_start = time.time() 422 | train_loss = running_loss / running_loss_count 423 | running_loss = running_loss_count = 0 424 | 425 | eval_dict = eval_model(net, val_loader, device, args.forces_weight) 426 | eval_formatted = ", ".join( 427 | ["{}={:.3f}".format(k, v) for (k, v) in eval_dict.items()] 428 | ) 429 | # loss smoothing 430 | eval_loss = np.square(eval_dict["sqrt(total_loss)"]) 431 | 
smooth_loss = eval_loss if prev_loss == None else 0.9 * eval_loss + 0.1 * prev_loss 432 | prev_loss = smooth_loss 433 | 434 | logging.info( 435 | "step={}, {}, sqrt(train_loss)={:.3f}, sqrt(smooth_loss)={:.3f}, patience={:3d}, training time={:.3f} min, eval time={:.3f} min".format( 436 | step, 437 | eval_formatted, 438 | math.sqrt(train_loss), 439 | math.sqrt(smooth_loss), 440 | early_stop.counter, 441 | training_time / 60, 442 | (time.time() - eval_start) / 60, 443 | ) 444 | ) 445 | training_time = 0 446 | # reduce learning rate 447 | if args.plateau_scheduler: 448 | scheduler.step(smooth_loss) 449 | # Save checkpoint 450 | if not early_stop(math.sqrt(smooth_loss), best_val_loss): 451 | best_val_loss = math.sqrt(smooth_loss) 452 | torch.save( 453 | { 454 | "model": net.state_dict(), 455 | "optimizer": optimizer.state_dict(), 456 | "scheduler": scheduler.state_dict(), 457 | "step": step, 458 | "best_val_loss": best_val_loss, 459 | "node_size": args.node_size, 460 | "num_layer": args.num_interactions, 461 | "cutoff": args.cutoff, 462 | }, 463 | os.path.join(args.output_dir, "best_model.pth"), 464 | ) 465 | else: 466 | sys.exit(0) 467 | 468 | step += 1 469 | 470 | if not args.plateau_scheduler: 471 | scheduler.step() 472 | 473 | if step >= args.max_steps: 474 | logging.info("Max steps reached, exiting") 475 | torch.save( 476 | { 477 | "model": net.state_dict(), 478 | "optimizer": optimizer.state_dict(), 479 | "scheduler": scheduler.state_dict(), 480 | "step": step, 481 | "best_val_loss": best_val_loss, 482 | "node_size": args.node_size, 483 | "num_layer": args.num_interactions, 484 | "cutoff": args.cutoff, 485 | }, 486 | os.path.join(args.output_dir, "exit_model.pth"), 487 | ) 488 | sys.exit(0) 489 | 490 | if __name__ == "__main__": 491 | main() 492 | -------------------------------------------------------------------------------- /workflow/vasp.py: -------------------------------------------------------------------------------- 1 | from ase.calculators.vasp import Vasp 2 | from ase.io import read, write, Trajectory 3 | from shutil import copy 4 | import os, subprocess 5 | import numpy as np 6 | import argparse 7 | import json 8 | import toml 9 | from pathlib import Path 10 | 11 | def get_arguments(arg_list=None): 12 | parser = argparse.ArgumentParser( 13 | description="General Active Learning", fromfile_prefix_chars="+" 14 | ) 15 | parser.add_argument( 16 | "--label_set", 17 | type=str, 18 | help="Path to trajectory to be labeled by DFT", 19 | ) 20 | parser.add_argument( 21 | "--train_set", 22 | type=str, 23 | help="Path to existing training data set", 24 | ) 25 | parser.add_argument( 26 | "--pool_set", 27 | type=str, 28 | help="Path to MD trajectory obtained from machine learning potential", 29 | ) 30 | parser.add_argument( 31 | "--al_info", 32 | type=str, 33 | help="Path to json file that stores indices selected in pool set", 34 | ) 35 | parser.add_argument( 36 | "--num_jobs", 37 | type=int, 38 | help="Number of DFT jobs", 39 | ) 40 | parser.add_argument( 41 | "--job_order", 42 | type=int, 43 | help="Split DFT jobs to several different parts", 44 | ) 45 | parser.add_argument( 46 | "--cfg", 47 | type=str, 48 | default="arguments.toml", 49 | help="Path to config file. e.g. 
'arguments.toml'" 50 | ) 51 | 52 | return parser.parse_args(arg_list) 53 | 54 | def update_namespace(ns, d): 55 | for k, v in d.items(): 56 | if not isinstance(v, dict): 57 | ns.__dict__[k] = v 58 | 59 | def main(): 60 | # set environment variables via os.environ so the ASE Vasp calculator can see them (os.putenv would not update os.environ) 61 | os.environ['ASE_VASP_VDW'] = '/home/energy/modules/software/VASP/vasp-potpaw-5.4' 62 | os.environ['VASP_PP_PATH'] = '/home/energy/modules/software/VASP/vasp-potpaw-5.4' 63 | os.environ['ASE_VASP_COMMAND'] = 'mpirun vasp_std' 64 | 65 | args = get_arguments() 66 | if args.cfg: 67 | with open(args.cfg, 'r') as f: 68 | params = toml.load(f) 69 | update_namespace(args, params) 70 | 71 | # get images and set parameters 72 | if args.label_set: 73 | images = read(args.label_set, index = ':') 74 | elif args.pool_set: 75 | if isinstance(args.pool_set, list): 76 | pool_traj = [] 77 | for pool_path in args.pool_set: 78 | if Path(pool_path).stat().st_size > 0: 79 | pool_traj += read(pool_path, ':') 80 | else: 81 | pool_traj = Trajectory(args.pool_set) 82 | with open(args.al_info) as f: 83 | indices = json.load(f)["selected"] 84 | if args.num_jobs: 85 | split_idx = np.array_split(indices, args.num_jobs) 86 | indices = split_idx[args.job_order] 87 | images = [pool_traj[i] for i in indices] 88 | else: 89 | raise RuntimeError('Valid configurations for the DFT calculation must be provided!') 90 | 91 | vasp_params = params['VASP'] 92 | calc = Vasp(**vasp_params) 93 | traj = Trajectory('dft_structures.traj', mode = 'a') 94 | check_result = False 95 | unconverged = Trajectory('unconverged.traj', mode = 'a') 96 | unconverged_idx = [] 97 | for i, atoms in enumerate(images): 98 | atoms.set_pbc([True,True,True]) 99 | atoms.calc = calc 100 | atoms.get_potential_energy() 101 | steps = int(subprocess.getoutput('grep LOOP OUTCAR | wc -l')) # count SCF 'LOOP' lines in OUTCAR as a convergence heuristic 102 | if steps <= vasp_params['nelm']: 103 | traj.write(atoms) 104 | else: 105 | check_result = True 106 | unconverged.write(atoms) 107 | unconverged_idx.append(i) 108 | copy('OSZICAR', 'OSZICAR_{}'.format(i)) 109 | 110 | traj.close() 111 | # write to training set 112 | if check_result: 113 | raise RuntimeError(f"DFT calculations of {unconverged_idx} are not converged!") 114 | 115 | if args.train_set: 116 | train_traj = Trajectory(args.train_set, mode = 'a') 117 | images = read('dft_structures.traj', ':') 118 | for atoms in images: 119 | atoms.info['system'] = args.system 120 | atoms.info['path'] = str(Path('dft_structures.traj').resolve()) 121 | train_traj.write(atoms) 122 | 123 | os.remove('WAVECAR') 124 | 125 | if __name__ == "__main__": 126 | main() 127 | --------------------------------------------------------------------------------
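Usage note: scripts/train.py writes checkpoints ("best_model.pth" / "exit_model.pth") that bundle the model weights with the "node_size", "num_layer", and "cutoff" hyperparameters, as shown in the torch.save calls above. Below is a minimal sketch of reloading such a checkpoint for inference; it assumes that PainnModel is exported from PaiNN/model.py and that its normalization-related constructor arguments have defaults when omitted.

import torch
from PaiNN.model import PainnModel  # assumed import path for the model class

# Load a checkpoint written by scripts/train.py; the keys mirror the torch.save call above.
state_dict = torch.load("best_model.pth", map_location="cpu")

# Rebuild the network from the stored hyperparameters. The normalization arguments
# (normalization, target_mean, target_stddev, atomwise_normalization) are omitted here
# and assumed to take defaults; re-supply them if they were used during training.
net = PainnModel(
    num_interactions=state_dict["num_layer"],
    hidden_state_size=state_dict["node_size"],
    cutoff=state_dict["cutoff"],
)
net.load_state_dict(state_dict["model"])
net.eval()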