├── .gitignore ├── LICENSE ├── README.org ├── cdm ├── __init__.py ├── charge_trainer.py ├── chg_utils.py ├── models │ ├── __init__.py │ ├── chargemodel.py │ ├── gemnet_charge.py │ ├── gemnet_oc_charge.py │ ├── painn_charge.py │ ├── schnet_charge.py │ └── scn_charge.py ├── tests │ ├── test_probe_graph_adder.py │ └── test_structure └── utils │ ├── chg_db.py │ ├── dir_of_chgcars.py │ ├── inference.py │ ├── preprocessing.py │ ├── probe_graph.py │ └── vasp.py ├── configs ├── common │ └── common.yml ├── datasets │ ├── 10k.yml │ ├── 1k.yml │ ├── 33k.yml │ └── 3k.yml ├── models │ ├── painn │ │ └── painn-small.yml │ └── schnet │ │ ├── schnet-large.yml │ │ └── schnet-small.yml ├── optimizers │ └── adam-standard.yml └── template.yml ├── data ├── 10k_split.txt ├── 1k_split.txt ├── 3k_split.txt ├── test.db ├── train0.db ├── train1.db ├── train2.db ├── train3.db ├── train4.db └── val.db ├── notebooks ├── Inference.ipynb ├── charge-dataset-creation.ipynb ├── training.ipynb └── wandb-sweep.ipynb └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | runs/ 2 | checkpoints/ 3 | logs/ 4 | __pycache__/ 5 | */__pycache__/ 6 | wandb/ 7 | .ipynb_checkpoints/ 8 | *.egg-info/ 9 | dev/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Ethan Sunshine 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be 
included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.org: -------------------------------------------------------------------------------- 1 | Tools to build charge density models using [[https://github.com/Open-Catalyst-Project/ocp][ocpmodels]] 2 | 3 | Many important ideas are taken from [[https://github.com/peterbjorgensen/DeepDFT][DeepDFT]]. Historically, much of the code was also borrowed from this repository, but that is no longer the case. 
4 | -------------------------------------------------------------------------------- /cdm/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import * 2 | from .charge_trainer import ChargeTrainer 3 | from .utils.dir_of_chgcars import ChgcarDataset -------------------------------------------------------------------------------- /cdm/charge_trainer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | import torch 7 | import torch_geometric 8 | from tqdm import tqdm 9 | 10 | from torch.utils.data import DataLoader 11 | from torch.nn.parallel.distributed import DistributedDataParallel 12 | 13 | from torch_geometric.data import Batch 14 | 15 | from ocpmodels.common import distutils 16 | from ocpmodels.common.registry import registry 17 | from ocpmodels.modules.normalizer import Normalizer 18 | from ocpmodels.trainers.base_trainer import BaseTrainer 19 | from ocpmodels.common.utils import pyg2_data_transform 20 | 21 | from ocpmodels.common.data_parallel import ( 22 | BalancedBatchSampler, 23 | OCPDataParallel, 24 | ParallelCollater, 25 | ) 26 | 27 | from ocpmodels.modules.loss import AtomwiseL2Loss, DDPLoss, L2MAELoss 28 | from ocpmodels.modules.evaluator import * 29 | 30 | from cdm.utils.probe_graph import ProbeGraphAdder 31 | from cdm import models 32 | 33 | 34 | @registry.register_trainer("charge") 35 | class ChargeTrainer(BaseTrainer): 36 | """ 37 | Trainer class for charge density prediction task. 38 | 39 | .. note:: 40 | 41 | Examples of configurations for task, model, dataset and optimizer 42 | can be found in configs/ 43 | 44 | 45 | Args: 46 | task (dict): Task configuration. 47 | model (dict): Model configuration. 48 | dataset (dict): Dataset configuration. The dataset needs to be a SinglePointLMDB dataset. 49 | optimizer (dict): Optimizer configuration. 
            identifier (str): Experiment identifier that is appended to log directory.
            run_dir (str, optional): Path to the run directory where logs are to be saved.
                (default: :obj:`None`)
            is_debug (bool, optional): Run in debug mode.
                (default: :obj:`False`)
            is_hpo (bool, optional): Run hyperparameter optimization with Ray Tune.
                (default: :obj:`False`)
            print_every (int, optional): Frequency of printing logs.
                (default: :obj:`100`)
            seed (int, optional): Random number seed.
                (default: :obj:`None`)
            logger (str, optional): Type of logger to be used.
                (default: :obj:`tensorboard`)
            local_rank (int, optional): Local rank of the process, only applicable for distributed training.
                (default: :obj:`0`)
            amp (bool, optional): Run using automatic mixed precision.
                (default: :obj:`False`)
            slurm (dict): Slurm configuration. Currently just for keeping track.
                (default: :obj:`{}`)
    """

    def __init__(
        self,
        task,
        model,
        dataset,
        optimizer,
        identifier,
        normalizer=None,
        timestamp_id=None,
        run_dir=None,
        is_debug=False,
        is_hpo=False,
        print_every=100,
        log_every=100,
        seed=None,
        logger="wandb",
        local_rank=0,
        amp=False,
        cpu=False,
        slurm={},
        noddp=False,
        name=None,
        trainer = 'charge',
    ):

        # Delegate generic trainer setup (config assembly, datasets, model,
        # optimizer, logger) to the ocpmodels BaseTrainer. Note we register
        # under name='s2ef' so the base class follows its s2ef code paths;
        # self.name is overwritten to 'charge' immediately afterwards.
        super().__init__(
            task=task,
            model=model,
            dataset=dataset,
            optimizer=optimizer,
            identifier=identifier,
            normalizer=normalizer,
            timestamp_id=timestamp_id,
            run_dir=run_dir,
            is_debug=is_debug,
            is_hpo=is_hpo,
            print_every=print_every,
            seed=seed,
            logger=logger,
            local_rank=local_rank,
            amp=amp,
            cpu=cpu,
            name='s2ef',
            slurm=slurm,
            noddp=noddp,
        )

        self.evaluator = ChargeEvaluator()
        self.name = 'charge'
        self.log_every = log_every
        # Used to scale the step counter in train() so logged steps are
        # comparable across single- and multi-GPU runs.
        self.num_devices = self.config['gpus'] if (self.config['gpus'] > 0) else 1

    def load_loss(self):
        """Build the charge loss from config ('optim.loss_charge', default 'mae')
        and wrap it in DDPLoss for distributed reduction."""
        self.loss_fn = {}
        self.loss_fn['charge'] = self.config['optim'].get('loss_charge', 'mae')

        for loss, loss_name in self.loss_fn.items():
            if loss_name in ['l1', 'mae']:
                self.loss_fn[loss] = torch.nn.L1Loss()
            elif loss_name in ['l2', 'mse']:
                self.loss_fn[loss] = torch.nn.MSELoss()
            elif loss_name == 'l2mae':
                self.loss_fn[loss] = L2MAELoss()
            elif loss_name == 'atomwisel2':
                self.loss_fn[loss] = AtomwiseL2Loss()
            elif loss_name == 'normed_mae':
                self.loss_fn[loss] = NormedMAELoss()
            else:
                raise NotImplementedError(
                    f'Unknown loss function name: {loss_name}'
                )

            # Wrap for correct reduction under DistributedDataParallel.
            self.loss_fn[loss] = DDPLoss(self.loss_fn[loss], reduction='sum')


    def load_task(self):
        """Charge density is a single scalar target per probe point."""
        logging.info(f"Loading dataset: {self.config['task']['dataset']}")
        self.num_targets = 1

    @torch.no_grad()
    def predict(
        self, loader, per_image=True, results_file=None, disable_tqdm=False
    ):
        """Run inference over `loader` (a DataLoader or a single pyg Batch).

        Returns a dict with 'id' and 'charge' lists when per_image=True;
        with per_image=False, returns after the FIRST batch with the raw
        (detached) prediction tensor — intended for single-batch use.
        """
        if distutils.is_master() and not disable_tqdm:
            logging.info('Predicting on test.')
        assert isinstance(
            loader,
            (
                torch.utils.data.dataloader.DataLoader,
                torch_geometric.data.Batch,
            ),
        )
        rank = distutils.get_rank()

        # Allow passing a bare Batch; wrap so it iterates like a loader.
        if isinstance(loader, torch_geometric.data.Batch):
            loader = [[loader]]

        self.model.eval()
        if self.ema:
            # Predict with EMA weights; restored after the loop.
            self.ema.store()
            self.ema.copy_to()

        if self.normalizers is not None and 'target' in self.normalizers:
            self.normalizers['target'].to(self.device)
        predictions = {'id': [], 'charge': []}

        for i, batch in tqdm(
            enumerate(loader),
            total=len(loader),
            position=rank,
            desc='device {}'.format(rank),
            disable=disable_tqdm,
        ):

            # Probe graphs are stored as serialized data lists; re-hydrate
            # them into a pyg Batch before the forward pass.
            if hasattr(batch[0], 'probe_data'):
                for subbatch in batch:
                    subbatch.probe_data = [pyg2_data_transform(x) for x in subbatch.probe_data]
                    subbatch.probe_data = Batch.from_data_list(subbatch.probe_data)

            with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                out = self._forward(batch)

            if self.normalizers is not None and 'target' in self.normalizers:
                out['charge'] = self.normalizers['target'].denorm(
                    out['charge']
                )

            if per_image:
                predictions['id'].extend(
                    [str(i) for i in batch[0].sid.tolist()]
                )
                predictions['charge'].extend(out['charge'].tolist())
            else:
                # NOTE(review): early return after the first batch skips
                # save_results and ema.restore() — presumably intentional for
                # the single-Batch path; confirm EMA state is not needed after.
                predictions['charge'] = out['charge'].detach()
                return predictions

        self.save_results(predictions, results_file, keys=['charge'])

        if self.ema:
            self.ema.restore()

        return predictions

    def train(self, disable_eval_tqdm=False):
        """Main optimization loop: iterate epochs/batches, compute loss,
        backprop, log metrics, periodically validate and checkpoint."""
        eval_every = self.config['optim'].get(
            'eval_every', len(self.train_loader)
        )
        primary_metric = self.config['task'].get(
            'primary_metric', self.evaluator.task_primary_metric[self.name]
        )
        self.best_val_metric = 1e9

        # Calculate start_epoch from step instead of loading the epoch number
        # to prevent inconsistencies due to different batch size in checkpoint.
        start_epoch = self.step // len(self.train_loader)

        for epoch_int in range(
            start_epoch, self.config['optim']['max_epochs']
        ):
            self.train_sampler.set_epoch(epoch_int)
            # Resume mid-epoch when restarting from a checkpoint.
            skip_steps = self.step % len(self.train_loader)
            train_loader_iter = iter(self.train_loader)

            for i in range(skip_steps, len(self.train_loader)):

                self.epoch = epoch_int + (i + 1) / len(self.train_loader)
                # Scale by device count so logged steps match multi-GPU runs.
                self.step = (epoch_int * len(self.train_loader) + i + 1) * self.num_devices
                self.model.train()

                # Get a batch.
                batch = next(train_loader_iter)

                # Re-hydrate pre-computed probe graphs (see predict()).
                if hasattr(batch[0], 'probe_data'):
                    for subbatch in batch:
                        subbatch.probe_data = [pyg2_data_transform(x) for x in subbatch.probe_data]
                        subbatch.probe_data = Batch.from_data_list(subbatch.probe_data)

                # Forward, loss, backward.
                with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                    out = self._forward(batch)
                    loss = self._compute_loss(out, batch)
                loss = self.scaler.scale(loss) if self.scaler else loss

                # Skip the update on degenerate outputs (all-zero or inf)
                # to avoid poisoning the weights with bad gradients.
                if torch.sum(out['charge']) != 0:
                    if not torch.any(torch.isinf(out['charge'])):
                        self._backward(loss)

                # NOTE(review): `scale` is computed but never used below.
                scale = self.scaler.get_scale() if self.scaler else 1.0

                # Compute metrics.
                self.metrics = self._compute_metrics(
                    out,
                    batch,
                    self.evaluator,
                    metrics={},
                )

                # Log metrics.
                if (self.step % self.log_every == 0) or (self.step % self.config['cmd']['print_every'] == 0):
                    log_dict = {k: self.metrics[k]['metric'] for k in self.metrics}
                    log_dict.update(
                        {
                            'lr': self.scheduler.get_lr(),
                            'epoch': self.epoch,
                            'step': self.step,
                        }
                    )
                    if (
                        self.step % self.config['cmd']['print_every'] == 0
                        and distutils.is_master()
                        and not self.is_hpo
                    ):
                        log_str = [
                            '{}: {:.2e}'.format(k, v) for k, v in log_dict.items()
                        ]
                        print(', '.join(log_str))
                        self.metrics = {}

                    if self.logger is not None:
                        self.logger.log(
                            log_dict,
                            step=self.step,
                            split='train',
                        )

                # Evaluate on val set after every `eval_every` iterations.
                if self.step % eval_every == 0:
                    self.save(
                        checkpoint_file='checkpoint.pt', training_state=True
                    )

                    if self.val_loader is not None:
                        val_metrics = self.validate(
                            split='val',
                            disable_tqdm=disable_eval_tqdm,
                        )
                        if (
                            val_metrics[
                                self.evaluator.task_primary_metric[self.name]
                            ]['metric']
                            < self.best_val_metric
                        ):
                            self.best_val_metric = val_metrics[
                                self.evaluator.task_primary_metric[self.name]
                            ]['metric']
                            self.save(
                                metrics=val_metrics,
                                checkpoint_file='best_checkpoint.pt',
                                training_state=False,
                            )
                            # On a new best model, also dump test-set
                            # predictions to disk.
                            if self.test_loader is not None:
                                self.predict(
                                    self.test_loader,
                                    results_file='predictions',
                                    disable_tqdm=False,
                                )

                        if self.is_hpo:
                            self.hpo_update(
                                self.epoch,
                                self.step,
                                self.metrics,
                                val_metrics,
                            )

                if self.scheduler.scheduler_type == 'ReduceLROnPlateau':
                    # Plateau scheduler steps on the validation metric only.
                    if self.step % eval_every == 0:
                        self.scheduler.step(
                            metrics=val_metrics[primary_metric]['metric'],
                        )
                else:
                    self.scheduler.step()

            torch.cuda.empty_cache()

        self.train_dataset.close_db()
        if self.config.get('val_dataset', False):
            self.val_dataset.close_db()
        if self.config.get('test_dataset', False):
            self.test_dataset.close_db()

    def _forward(self, batch_list):
        """Run the model; flatten a trailing singleton dim so the output is a
        1-D tensor of per-probe charge predictions."""
        output = self.model(batch_list)

        if output.shape[-1] == 1:
            output = output.view(-1)

        return {
            'charge': output,
        }

    def _compute_loss(self, out, batch_list):
        """Loss between predicted and (optionally normalized) target charge."""
        charge_target = torch.cat(
            [batch.probe_data.target for batch in batch_list], dim=0
        )

        if self.normalizer.get('normalize_labels', False):
            target_normed = self.normalizers['target'].norm(charge_target)
        else:
            target_normed = charge_target

        loss =
self.loss_fn['charge'](out['charge'], target_normed)
        return loss

    def _compute_metrics(self, out, batch_list, evaluator, metrics={}):
        """Denormalize predictions (if needed) and accumulate eval metrics.

        NOTE(review): mutable default `metrics={}` — callers always pass an
        explicit dict, but confirm before relying on the default.
        """
        charge_target = torch.cat(
            [batch.probe_data.target for batch in batch_list], dim=0
        )

        if self.normalizer.get('normalize_labels', False):
            out['charge'] = self.normalizers['target'].denorm(out['charge'])

        metrics = evaluator.eval(
            out,
            {'charge': charge_target},
            prev_metrics=metrics,
        )

        return metrics

    def load_model(self):
        # Build model
        if distutils.is_master():
            logging.info(f"Loading model: {self.config['model']}")

        # TODO: deprecated, remove. (bond_feat_dim is assigned but unused here)
        bond_feat_dim = None
        bond_feat_dim = self.config['model_attributes'].get(
            'num_gaussians', 50
        )

        loader = self.train_loader or self.val_loader or self.test_loader

        self.model = registry.get_model_class(self.config['model'])(
            **self.config['model_attributes'],
        ).to(self.device)

        if distutils.is_master():
            logging.info(
                f'Loaded {self.model.__class__.__name__} with '
                f'{self.model.num_params} parameters.'
            )

        if self.logger is not None:
            self.logger.watch(self.model)

        # Wrap for single-node data parallelism, then (optionally) DDP.
        self.model = OCPDataParallel(
            self.model,
            output_device=self.device,
            num_gpus=1 if not self.cpu else 0,
        )
        if distutils.initialized() and not self.config['noddp']:
            self.model = DistributedDataParallel(
                self.model, device_ids=[self.device]
            )

    @torch.no_grad()
    def validate(self, split='val', disable_tqdm=False):
        """Evaluate on the 'val' or 'test' split and return aggregated
        (all-reduced across ranks) metrics."""
        if distutils.is_master():
            logging.info(f'Evaluating on {split}.')
        if self.is_hpo:
            disable_tqdm = True

        self.model.eval()
        if self.ema:
            self.ema.store()
            self.ema.copy_to()

        evaluator, metrics = ChargeEvaluator(task='charge'), {}
        rank = distutils.get_rank()

        loader = self.val_loader if split == 'val' else self.test_loader

        for i, batch in tqdm(
            enumerate(loader),
            total=len(loader),
            position=rank,
            desc='device {}'.format(rank),
            disable=disable_tqdm,
        ):
            # Re-hydrate pre-computed probe graphs into a pyg Batch.
            if hasattr(batch[0], 'probe_data'):
                for subbatch in batch:
                    subbatch.probe_data = [pyg2_data_transform(x) for x in subbatch.probe_data]
                    subbatch.probe_data = Batch.from_data_list(subbatch.probe_data)

            # Forward.
            with torch.cuda.amp.autocast(enabled=self.scaler is not None):
                out = self._forward(batch)
            loss = self._compute_loss(out, batch)

            # Compute metrics.
            metrics = self._compute_metrics(out, batch, evaluator, metrics)
            metrics = evaluator.update('loss', loss.item(), metrics)

        # All-reduce totals/counts across ranks, then recompute the means
        # so distributed evaluation matches single-process evaluation.
        aggregated_metrics = {}
        for k in metrics:
            aggregated_metrics[k] = {
                'total': distutils.all_reduce(
                    metrics[k]['total'], average=False, device=self.device
                ),
                'numel': distutils.all_reduce(
                    metrics[k]['numel'], average=False, device=self.device
                ),
            }
            aggregated_metrics[k]['metric'] = (
                aggregated_metrics[k]['total'] / aggregated_metrics[k]['numel']
            )
        metrics = aggregated_metrics

        log_dict = {k: metrics[k]['metric'] for k in metrics}
        log_dict.update({'epoch': self.epoch})
        if distutils.is_master():
            log_str = ['{}: {:.4f}'.format(k, v) for k, v in log_dict.items()]
            logging.info(', '.join(log_str))

        # Make plots.
        if self.logger is not None:
            self.logger.log(
                log_dict,
                step=self.step,
                split=split,
            )

        if self.ema:
            self.ema.restore()

        return metrics

    def load_datasets(self):
        """Build train/val/test datasets, samplers, loaders and the target
        normalizer. Batch sizes are forced to 1 (one structure per batch)."""

        self.config['optim']['batch_size'] = 1
        self.config['optim']['eval_batch_size'] = 1

        self.parallel_collater = ParallelCollater(
            0 if self.cpu else 1,
            self.config['model_attributes'].get('otf_graph', False),
        )

        self.train_loader = self.val_loader = self.test_loader = None

        if self.config.get('dataset', None):
            self.train_dataset = registry.get_dataset_class(
                self.config['task']['dataset']
            )(self.config['dataset'])
            self.train_sampler = self.get_sampler(
                self.train_dataset,
                self.config['optim']['batch_size'],
                shuffle=True,
            )
            self.train_loader = self.get_dataloader(
                self.train_dataset,
                self.train_sampler,
            )

            if self.config.get('val_dataset', None):
                self.val_dataset = registry.get_dataset_class(
                    self.config['task']['dataset']
                )(self.config['val_dataset'])
                self.val_sampler = self.get_sampler(
                    self.val_dataset,
                    self.config['optim'].get(
                        'eval_batch_size', self.config['optim']['batch_size']
                    ),
                    shuffle=False,
                )
                self.val_loader = self.get_dataloader(
                    self.val_dataset,
                    self.val_sampler,
                )

            if self.config.get('test_dataset', None):
                self.test_dataset = registry.get_dataset_class(
                    self.config['task']['dataset']
                )(self.config['test_dataset'])
                self.test_sampler = self.get_sampler(
                    self.test_dataset,
                    self.config['optim'].get(
                        'eval_batch_size', self.config['optim']['batch_size']
                    ),
                    shuffle=False,
                )
                self.test_loader = self.get_dataloader(
                    self.test_dataset,
                    self.test_sampler,
                )

        # Normalizer for the dataset.
        # Compute mean, std of training set labels.
        self.normalizers = {}
        if self.normalizer.get('normalize_labels', False):
            if 'target_mean' in self.normalizer:
                self.normalizers['target'] = Normalizer(
                    mean=self.normalizer['target_mean'],
                    std=self.normalizer['target_std'],
                    device=self.device,
                )
            else:
                self.normalizers["target"] = Normalizer(
                    tensor=self.train_loader.dataset.data.y[
                        self.train_loader.dataset.__indices__
                    ],
                    device=self.device,
                )

    def get_dataloader(self, dataset, sampler):
        """Wrap a dataset + batch sampler in a DataLoader with the parallel
        collater. NOTE(review): prefetch_factor is hardcoded to 6; consider
        making it configurable via optim config."""
        loader = DataLoader(
            dataset,
            collate_fn=self.parallel_collater,
            num_workers=self.config["optim"]["num_workers"],
            pin_memory=True,
            batch_sampler=sampler,
            prefetch_factor = 6,
        )
        return loader


class ChargeEvaluator(Evaluator):
    """Evaluator for the charge-density task.

    Registers the charge metrics/attributes on the (class-level) ocpmodels
    Evaluator tables and dispatches metric functions by name.
    """
    def __init__(self, task = 'charge'):
        # Only the 'charge' task is supported; the `task` argument is kept
        # for interface compatibility with the base Evaluator.
        self.task = 'charge'

        self.task_metrics['charge'] = [
            'norm_charge_mae',
            'norm_charge_rmse',
            'charge_mae',
            'charge_mse',
            'true_density',
            'total_charge_ratio',
        ]

        self.task_attributes['charge'] = ['charge']
        self.task_primary_metric['charge'] = 'norm_charge_mae'

        self.metric_fn = self.task_metrics[task]

    def eval(self, prediction, target, prev_metrics={}):
        for attr in self.task_attributes[self.task]:
            assert attr in prediction
            assert attr in target
            assert prediction[attr].shape == target[attr].shape

        metrics = prev_metrics

        for fn in self.task_metrics[self.task]:
            # NOTE(review): built-in eval() resolves the metric name to the
            # module-level function of the same name. Works, but fragile —
            # a lookup table (dict of callables) would be safer.
            res = eval(fn)(prediction, target)
            metrics = self.update(fn, res, metrics)

        return metrics

class NormedMAELoss(torch.nn.Module):
    """MAE normalized by the L1 mass of the target: sum|p-t| / sum|t|."""
    def __init__(self):
        super().__init__()

    def forward(
        self,
        prediction,
        target,
    ):
        return torch.sum(torch.abs(prediction - target)) \
            / torch.sum(torch.abs(target))

def absolute_error(prediction, target):
    """Elementwise |p-t| in Evaluator metric format (metric/total/numel)."""
    error = torch.abs(prediction - target)
    return {
        'metric': (torch.mean(error)).item(),
        'total': torch.sum(error).item(),
        'numel': prediction.numel(),
    }

def squared_error(prediction, target):
    """Elementwise (p-t)^2 in Evaluator metric format."""
    error = torch.abs(prediction - target) **2
    return {
        'metric': (torch.mean(error)).item(),
        'total': torch.sum(error).item(),
        'numel': prediction.numel(),
    }

def charge_mae(prediction, target):
    return absolute_error(prediction['charge'], target['charge'])

def charge_mse(prediction, target):
    return squared_error(prediction['charge'].float(), target['charge'].float())


def total_charge_ratio(prediction, target):
    """Ratio of total predicted charge to total target charge (1.0 = exact
    charge conservation)."""
    error = torch.sum(prediction['charge']) / torch.sum(target['charge'])
    return {
        'metric': torch.mean(error).item(),
        'total': torch.sum(error).item(),
        'numel': error.numel()
    }

def true_density(prediction, target):
    """Mean/sum of the target density itself (reference scale for errors)."""
    return {
        'metric': torch.mean(target['charge']).item(),
        'total': torch.sum(target['charge']).item(),
        'numel': target['charge'].numel(),
    }

def norm_charge_mae(prediction, target):
    """MAE normalized by total |target| charge (primary metric)."""
    error = torch.sum(torch.abs(prediction['charge'] - target['charge'])) \
        / torch.sum(torch.abs(target['charge']))
    return {
        'metric': error.item(),
        'total': error.item(),
        'numel': 1,
    }

def norm_charge_rmse(prediction, target):
    """Root of summed squared error normalized by total |target| charge."""
    error = torch.sqrt(torch.sum(torch.square(prediction['charge'] - target['charge'])) \
        / torch.sum(torch.abs(target['charge'])))
    return {
        'metric': error.item(),
        'total': error.item(),
        'numel': 1,
    }
--------------------------------------------------------------------------------
/cdm/chg_utils.py:
--------------------------------------------------------------------------------
import lmdb
import pickle
import numpy as np
import torch
import os

from tqdm import tqdm

from torch_geometric.data import Data

from ocpmodels.preprocessing import AtomsToGraphs
from ocpmodels.datasets import data_list_collater

from ase import Atoms
from ase.calculators.vasp import VaspChargeDensity
from ase import neighborlist as nbl

def build_charge_lmdb(inpath, outpath, use_tqdm = False, loud=False, probe_graph_adder = None, stride = 1, cutoff = 6):
    '''
    A function used to build LMDB datasets from a directory of VASP calculations
    Supports pre-computation of probe graphs by passing in a ProbeGraphAdder object
    '''
    a2g = AtomsToGraphs(
        max_neigh = 100,
        radius = cutoff,
        r_energy = False,
        r_forces = False,
        r_distances = False,
        r_fixed = False,
    )

    db = lmdb.open(
        os.path.join(outpath, 'charge.lmdb'),
        map_size=1099511627776 * 2,  # 2 TiB virtual map; LMDB only commits used pages
        subdir=False,
        meminit=False,
        map_async=True,
    )


    paths = 
os.listdir(inpath) 42 | if use_tqdm: 43 | paths = tqdm(paths) 44 | 45 | for fid, directory in enumerate(paths): 46 | if loud: 47 | print(directory) 48 | 49 | try: 50 | vcd = VaspChargeDensity(os.path.join(inpath, directory, 'CHGCAR')) 51 | atoms = vcd.atoms[-1] 52 | dens = vcd.chg[-1] 53 | 54 | if stride != 1: 55 | dens = dens[::stride, ::stride, ::stride] 56 | 57 | data_object = a2g.convert(atoms) 58 | data_object.charge_density = dens 59 | 60 | if probe_graph_adder is not None: 61 | data_object = probe_graph_adder(object) 62 | 63 | txn = db.begin(write = True) 64 | txn.put(f"{fid}".encode("ascii"), pickle.dumps(data_object,protocol=-1)) 65 | txn.commit() 66 | except: 67 | print('Exception occured for:', directory) 68 | 69 | 70 | txn = db.begin(write = True) 71 | txn.put(f'length'.encode('ascii'), pickle.dumps(fid + 1, protocol=-1)) 72 | txn.commit() 73 | 74 | 75 | db.sync() 76 | db.close() 77 | 78 | 79 | class charge_density: 80 | ''' 81 | Class was formerly used to convert between CHGCAR and .cube formats 82 | Likely to be deprecated in the future 83 | ''' 84 | def __init__(self, inpath=None, spin_polarized = False): 85 | self.spin_polarized = spin_polarized 86 | 87 | if self.spin_polarized == True: 88 | raise NotImplementedError 89 | 90 | self.atoms = [] 91 | 92 | if inpath == None: 93 | self.cell = [[1, 0, 0], 94 | [0, 1, 0], 95 | [0, 0, 1]] 96 | self.charge = [[[]]] 97 | if spin_polarized: 98 | self.polarization = [[[]]] 99 | 100 | elif inpath[-6:] == 'CHGCAR': 101 | self.read_CHGCAR(inpath) 102 | 103 | elif inpath[-5:] == '.cube': 104 | self.read_cube(inpath) 105 | 106 | else: 107 | print('Error: Unknown file type. 
Currently support filetypes are:') 108 | print('CHGCAR, .cube') 109 | raise NotImplementedError 110 | 111 | 112 | def read_CHGCAR(self, inpath): 113 | with open(inpath) as CHGCAR: 114 | lines = CHGCAR.readlines() 115 | 116 | self.name = lines[0][:-1] 117 | 118 | v1, v2, v3 = [lines[i].split() for i in (2, 3, 4)] 119 | v1 = [float(i) for i in v1] 120 | v2 = [float(i) for i in v2] 121 | v3 = [float(i) for i in v3] 122 | self.cell = [v1, v2, v3] 123 | self.vol = np.dot(np.cross(v1,v2),v3) 124 | 125 | self.atom_types = lines[5].split() 126 | atom_counts = lines[6].split() 127 | self.atom_counts = [int(i) for i in atom_counts] 128 | self.n_atoms = sum(self.atom_counts) 129 | 130 | k = 0 131 | 132 | for j, element in enumerate(self.atom_types): 133 | for i in range(self.atom_counts[j]): 134 | rel_coords = lines[8+k].split() 135 | rel_coords = [float(i) for i in rel_coords] 136 | 137 | coords = np.array(self.cell).T.dot(rel_coords).tolist() 138 | 139 | self.atoms.append({'Num': k, 140 | 'Name': element, 141 | 'pos': coords, 142 | 'rel_pos': rel_coords}) 143 | k += 1 144 | 145 | dims = lines[9+self.n_atoms].split() 146 | self.grid = [int(i) for i in dims] 147 | 148 | i = 10+self.n_atoms 149 | chgs = [] 150 | 151 | while lines[i].split()[0] != 'augmentation': 152 | chgs.extend(lines[i].split()) 153 | i += 1 154 | 155 | chgs = [float(x) for x in chgs] 156 | 157 | self.charge = np.reshape(chgs, self.grid) 158 | self.charge /= self.vol 159 | 160 | for line in lines[i:]: 161 | tokens = line.split() 162 | if tokens[0] == 'augmentation': 163 | k = int(tokens[-2]) - 1 164 | self.atoms[k]['aug'] = [] 165 | else: 166 | self.atoms[k]['aug'].extend([float(j) for j in line.split()]) 167 | 168 | 169 | def write_CHGCAR(self, outpath): 170 | out = '' 171 | out += self.name + '\n' 172 | out += ' 1.000000000000000 \n' 173 | out += f'{self.cell[0][0]:>13.6f}{self.cell[0][1]:>12.6f}{self.cell[0][2]:>12.6f}\n' 174 | out += 
f'{self.cell[1][0]:>13.6f}{self.cell[1][1]:>12.6f}{self.cell[1][2]:>12.6f}\n' 175 | out += f'{self.cell[2][0]:>13.6f}{self.cell[2][1]:>12.6f}{self.cell[2][2]:>12.6f}\n' 176 | for x in self.atom_types: 177 | out += f' {x:<2}' 178 | out += '\n' 179 | for x in self.atom_counts: 180 | out += f'{x:>6}' 181 | out += '\nDirect\n' 182 | 183 | for atom in self.atoms: 184 | out += f' {atom["rel_pos"][0]:.6f} {atom["rel_pos"][1]:.6f} {atom["rel_pos"][2]:.6f}\n' 185 | 186 | out += f' \n{self.grid[0]:>5}{self.grid[1]:>5}{self.grid[2]:>5}\n' 187 | 188 | chgs = np.reshape(self.charge, np.prod(self.grid)) * self.vol 189 | 190 | line = '' 191 | for i, chg in enumerate(chgs): 192 | line = line + ' ' 193 | if chg >= 1e-12: 194 | exp = int(np.log10(chg)) 195 | if chg >= 1: 196 | exp += 1 197 | line = line + f'{(chg/10**exp):.11f}' + 'E' + f'{exp:+03}' 198 | elif chg <= -1e-12: 199 | exp = int(np.log10(-chg)) 200 | if chg <= -1: 201 | exp += 1 202 | line = line + '-' + f'{(-chg/10**exp):.11f}'[1:] + 'E' + f'{exp:+03}' 203 | else: 204 | line = line + '0.00000000000E+00' 205 | if (i+1) % 5 == 0: 206 | line = line + '\n' 207 | out += line 208 | line = '' 209 | 210 | if line != '': 211 | out += line + ' \n' 212 | 213 | for k, atom in enumerate(self.atoms): 214 | line = '' 215 | out += f'augmentation occupancies{k+1:>4} '+str(len(atom['aug']))+'\n' 216 | for i, aug in enumerate(atom['aug']): 217 | line = line + ' ' 218 | if aug >= 1e-32: 219 | exp = int(np.log10(aug)) 220 | if aug >= 1: 221 | exp += 1 222 | line = line + ' ' + f'{(aug/10**exp):.7f}' + 'E' + f'{exp:+03}' 223 | elif aug <= -1e-32: 224 | exp = int(np.log10(-aug)) 225 | if aug <= -1: 226 | exp += 1 227 | line = line + '-0' + f'{(-aug/10**exp):.7f}'[1:] + 'E' + f'{exp:+03}' 228 | else: 229 | line = line + ' 0.0000000E+00' 230 | if (i+1) % 5 == 0: 231 | line = line + '\n' 232 | out += line 233 | line = '' 234 | if line != '': 235 | out += line + '\n' 236 | 237 | with open(outpath, 'w') as file: 238 | file.write(out) 239 | 240 | 
241 | def read_cube(self, inpath): 242 | with open(inpath) as cube: 243 | lines = cube.readlines() 244 | 245 | self.name = lines[0][:-1] 246 | self.n_atoms = int(lines[2].split()[0]) 247 | 248 | dim1 = int(lines[3].split()[0]) 249 | dim2 = int(lines[4].split()[0]) 250 | dim3 = int(lines[5].split()[0]) 251 | 252 | self.grid = [dim1, dim2, dim3] 253 | 254 | v1 = dim1 * np.array([float(x) for x in lines[3].split()[1:]]) * 0.529177 # Converting Bohr to Angstrom 255 | v2 = dim2 * np.array([float(x) for x in lines[4].split()[1:]]) * 0.529177 # Converting Bohr to Angstrom 256 | v3 = dim3 * np.array([float(x) for x in lines[5].split()[1:]]) * 0.529177 # Converting Bohr to Angstrom 257 | 258 | self.cell = [v1.tolist(), v2.tolist(), v3.tolist()] 259 | self.vol = np.dot(np.cross(v1, v2), v3) 260 | 261 | element_counts_dict = {} 262 | self.atoms = [] 263 | 264 | for i in range(self.n_atoms): 265 | line = lines[i + 6].split() 266 | element = int(line[0]) 267 | element = elements_lookup[element - 1] 268 | p1 = float(line[2]) * 0.529177 # Converting Bohr to Angstrom 269 | p2 = float(line[3]) * 0.529177 # Converting Bohr to Angstrom 270 | p3 = float(line[4]) * 0.529177 # Converting Bohr to Angstrom 271 | 272 | coords = [p1, p2, p3] 273 | rel_coords = np.linalg.inv(np.array(self.cell).T).dot(coords) 274 | for i, x in enumerate(rel_coords): 275 | if x < 0: 276 | rel_coords[i] += 1 277 | 278 | 279 | if element in element_counts_dict: 280 | element_counts_dict[element] += 1 281 | else: 282 | element_counts_dict[element] = 1 283 | 284 | self.atoms.append({'Num': i, 285 | 'Name': element, 286 | 'pos': coords, 287 | 'rel_pos': rel_coords.tolist(), 288 | 'aug':[]}) 289 | 290 | atom_types = [] 291 | atom_counts = [] 292 | 293 | for key, value in element_counts_dict.items(): 294 | atom_types.append(key) 295 | atom_counts.append(value) 296 | self.atom_types, self.atom_counts = atom_types, atom_counts 297 | 298 | chgs = [float(lines[6 + self.n_atoms + i]) for i in range(np.prod(self.grid))] 
def __repr__(self):
    """Concise, human-readable summary of the charge-density object."""
    summary = [
        'Charge Density Object:',
        f'Name: {self.name}',
        f'# of Atoms: {self.n_atoms}',
        f'Charge Points Grid: {self.grid[0]} {self.grid[1]} {self.grid[2]}',
    ]
    return '\n'.join(summary) + '\n'
@registry.register_model('charge_model')
class ChargeModel(torch.nn.Module):
    """Joint atom/probe message-passing model for charge-density prediction.

    An "atom" GNN builds per-atom representations which are injected into a
    "probe" GNN that predicts the electron density at probe points.
    """

    def __init__(
        self,
        atom_model_config,
        probe_model_config,
        otf_pga_config=None,
        include_atomic_edges=False,
        enforce_zero_for_disconnected_probes=False,
        enforce_charge_conservation=False,
        freeze_atomic=False,
        name='charge_model',
    ):
        """
        Args:
            atom_model_config: config dict for the atom model; may contain a
                'checkpoint' path whose stored config and weights are used.
            probe_model_config: config dict for the probe model; may contain
                a 'checkpoint' path whose stored config and weights are used.
            otf_pga_config: kwargs for the on-the-fly ProbeGraphAdder.
                Defaults to {'implementation': 'RGPBC'}.
            include_atomic_edges: accepted for config compatibility;
                currently unused in this constructor.
            enforce_zero_for_disconnected_probes: zero the prediction for
                probes with no edges to any atom.
            enforce_charge_conservation: rescale predictions to sum to the
                target total charge.
            freeze_atomic: run the atom model under torch.no_grad().
        """
        super().__init__()

        # BUG FIX: the default was a shared mutable dict literal; use a
        # None sentinel instead.
        if otf_pga_config is None:
            otf_pga_config = {'implementation': 'RGPBC'}

        self.regress_forces = False
        self.enforce_zero_for_disconnected_probes = enforce_zero_for_disconnected_probes
        self.enforce_charge_conservation = enforce_charge_conservation
        self.freeze_atomic = freeze_atomic

        probe_final_mlp = True

        # Initialize atom message-passing model. If a checkpoint is given,
        # its stored model config takes precedence over the passed config.
        if 'checkpoint' in atom_model_config:
            cfg = torch.load(
                atom_model_config['checkpoint'],
                map_location=torch.device('cpu'),
            )['config']['model_attributes']
        else:
            cfg = atom_model_config

        self.atom_message_model = registry.get_model_class(atom_model_config['name'])(
            **cfg,
            atomic=True,
            probe=False,
        )

        if 'checkpoint' in atom_model_config:
            self.load_checkpoint(
                checkpoint_path=atom_model_config['checkpoint'],
                atomic=True,
            )

        # Initialize probe message-passing model (same checkpoint logic).
        if 'checkpoint' in probe_model_config:
            cfg = torch.load(
                probe_model_config['checkpoint'],
                map_location=torch.device('cpu'),
            )['config']['model_attributes']
        else:
            cfg = probe_model_config

        self.probe_message_model = registry.get_model_class(probe_model_config['name'])(
            **cfg,
            atomic=False,
            probe=True,
        )

        if 'checkpoint' in probe_model_config:
            self.load_checkpoint(
                checkpoint_path=probe_model_config['checkpoint'],
                probe=True,
            )

        # If the hidden sizes differ, learn a reduction from the atom-model
        # width to the probe-model width.
        if self.atom_message_model.hidden_channels != self.probe_message_model.hidden_channels:
            self.reduce_atom_representations = True
            self.atom_reduction = torch.nn.Sequential(
                torch.nn.Linear(
                    self.atom_message_model.hidden_channels,
                    self.atom_message_model.hidden_channels,
                ),
                torch.nn.Sigmoid(),
                torch.nn.Linear(
                    self.atom_message_model.hidden_channels,
                    self.probe_message_model.hidden_channels,
                ),
            )
        else:
            self.reduce_atom_representations = False

        # The probe model consumes the deepest atom representations, so the
        # atom model must have at least as many interaction blocks.
        assert self.atom_message_model.num_interactions >= self.probe_message_model.num_interactions

        # Compatibility for specific models: SCN predicts densities directly,
        # so no final MLP is needed.
        if probe_model_config['name'] == 'scn_charge':
            probe_final_mlp = False

        if probe_final_mlp:
            self.probe_output_function = torch.nn.Sequential(
                torch.nn.Linear(self.probe_message_model.hidden_channels,
                                self.probe_message_model.hidden_channels),
                torch.nn.ELU(),
                torch.nn.Linear(self.probe_message_model.hidden_channels, 1),
            )

        self.otf_pga = ProbeGraphAdder(**otf_pga_config)

    @conditional_grad(torch.enable_grad())
    def forward(self, data):
        """Predict per-probe charge density for a batch of structures."""
        # Ensure data has probe points (on-the-fly probe graph construction).
        data = self.otf_pga(data)
        data.probe_data = [pyg2_data_transform(data.probe_data)]
        data.probe_data = Batch.from_data_list(data.probe_data)

        atom_representations = self.forward_atomic(data)
        return self.forward_probe(data.probe_data, atom_representations)

    @conditional_grad(torch.enable_grad())
    def forward_atomic(self, data):
        """Run the atom model; returns a list of per-block representations."""
        if self.freeze_atomic:
            with torch.no_grad():
                atom_representations = self.atom_message_model(data)
        else:
            atom_representations = self.atom_message_model(data)

        if self.reduce_atom_representations:
            atom_representations = [
                self.atom_reduction(rep).float() for rep in atom_representations
            ]

        return atom_representations

    @conditional_grad(torch.enable_grad())
    def forward_probe(self, data, atom_representations):
        """Run the probe model and apply the optional output constraints."""
        # Give the probe model only as many atom representations as it has
        # interaction blocks (the deepest ones).
        data.atom_representations = atom_representations[-self.probe_message_model.num_interactions:]

        probe_results = self.probe_message_model(data)

        if hasattr(self, 'probe_output_function'):
            probe_results = self.probe_output_function(probe_results).flatten()

        probe_results = torch.nan_to_num(probe_results)

        if self.enforce_zero_for_disconnected_probes:
            # Probe nodes are marked with atomic number 0.
            is_probe = data.atomic_numbers == 0
            _, _, is_not_isolated = remove_isolated_nodes(
                data.edge_index, num_nodes=len(data.atomic_numbers)
            )
            is_isolated = ~is_not_isolated

            if torch.all(is_isolated):
                warnings.warn('All probes are isolated - not enforcing zero charge constraint')
            else:
                probe_results[is_isolated[is_probe]] = torch.zeros_like(
                    probe_results[is_isolated[is_probe]]
                )

        if self.enforce_charge_conservation:
            if torch.sum(probe_results) == 0:
                warnings.warn('Charge prediction is 0 - cannot enforce charge conservation!')
            else:
                data.total_target = data.total_target.to(probe_results.device)
                probe_results *= data.total_target / torch.sum(probe_results)

        return probe_results

    @property
    def num_params(self):
        """Total number of parameters across both sub-models."""
        return sum(p.numel() for p in self.parameters())

    def load_checkpoint(
        self,
        checkpoint_path,
        atomic=False,
        probe=False,
    ):
        """Load weights from an OCP checkpoint into one sub-model.

        Args:
            checkpoint_path: path to a torch checkpoint file.
            atomic: load into the atom message-passing model.
            probe: load into the probe message-passing model.

        Raises:
            FileNotFoundError: if checkpoint_path does not exist.
            ValueError: if neither atomic nor probe is True.
        """
        if not os.path.isfile(checkpoint_path):
            # BUG FIX: the original raised FileNotFoundError(errno.ENOENT, ...)
            # without importing errno, which itself raised a NameError.
            raise FileNotFoundError(
                f'Checkpoint file not found: {checkpoint_path}'
            )

        # BUG FIX: `model` was left unbound when neither flag was set.
        # Checking `probe` first preserves the original precedence (probe
        # won when both flags were passed).
        if probe:
            model = self.probe_message_model
        elif atomic:
            model = self.atom_message_model
        else:
            raise ValueError('Specify atomic=True or probe=True')

        # BUG FIX: self.cpu / self.device were never set on this module and
        # raised AttributeError; default to CPU unless they were provided
        # externally.
        if getattr(self, 'cpu', True):
            map_location = torch.device('cpu')
        else:
            map_location = self.device
        checkpoint = torch.load(checkpoint_path, map_location=map_location)

        # Match the "module." count in the keys of model and checkpoint state_dict
        # DataParallel model has 1 "module.", DistributedDataParallel has 2 "module."
        # Not using either of the above two would have no "module."
        ckpt_key_count = next(iter(checkpoint["state_dict"])).count("module")
        mod_key_count = next(iter(model.state_dict())).count("module")
        key_count_diff = mod_key_count - ckpt_key_count

        if key_count_diff > 0:
            new_dict = {
                key_count_diff * "module." + k: v
                for k, v in checkpoint["state_dict"].items()
            }
        elif key_count_diff < 0:
            new_dict = {
                k[len("module.") * abs(key_count_diff):]: v
                for k, v in checkpoint["state_dict"].items()
            }
        else:
            new_dict = checkpoint["state_dict"]

        # HACK TO LOAD OLD SCHNET CHECKPOINTS
        if 'atomic_mass' in new_dict:
            del new_dict['atomic_mass']

        load_state_dict(model, new_dict, strict=False)
@registry.register_model("gemnet_t_charge")
class GemNetT_charge(GemNetT):
    """GemNet-T variant usable as either the atom or the probe model.

    In atomic mode, ``forward`` returns the per-block node embeddings; in
    probe mode, precomputed atom representations are injected before each
    interaction block and the probe-node embeddings are returned.
    """

    def __init__(
        self,
        name='gemnet_t_charge',
        num_spherical=16,
        num_radial=16,
        num_blocks=6,
        emb_size_atom=64,
        emb_size_edge=64,
        emb_size_trip=64,
        emb_size_rbf=64,
        emb_size_cbf=64,
        emb_size_bil_trip=64,
        num_before_skip=3,
        num_after_skip=3,
        num_concat=3,
        num_atom=3,
        **kwargs,
    ):
        # Consume the custom flags before delegating to the stock GemNetT.
        self.atomic = kwargs.pop('atomic')
        self.probe = kwargs.pop('probe')

        if self.probe:
            # Probe nodes use atomic number 0, so the embedding table is
            # shifted by one (see forward); 84 covers the element range.
            kwargs['num_elements'] = 84

        super().__init__(
            num_atoms=1,
            bond_feat_dim=1,
            num_targets=1,
            # BUG FIX: `num_spherical` was hard-coded to 16 here, silently
            # ignoring the constructor argument.
            num_spherical=num_spherical,
            num_radial=num_radial,
            num_blocks=num_blocks,
            emb_size_atom=emb_size_atom,
            emb_size_edge=emb_size_edge,
            emb_size_trip=emb_size_trip,
            emb_size_rbf=emb_size_rbf,
            emb_size_cbf=emb_size_cbf,
            emb_size_bil_trip=emb_size_bil_trip,
            num_before_skip=num_before_skip,
            num_after_skip=num_after_skip,
            num_concat=num_concat,
            num_atom=num_atom,
            otf_graph=False,
            **kwargs,
        )

        # Attributes read by ChargeModel to pair atom and probe models.
        self.num_interactions = self.num_blocks
        self.hidden_channels = self.atom_emb.emb_size

    @conditional_grad(torch.enable_grad())
    def forward(self, data):
        pos = data.pos
        batch = data.batch
        atomic_numbers = data.atomic_numbers.long()

        if self.regress_forces and not self.direct_forces:
            pos.requires_grad_(True)

        data.natoms = data.natoms.to(data.pos.device)
        data.neighbors = data.neighbors.to(data.pos.device)

        (
            edge_index,
            neighbors,
            D_st,
            V_st,
            id_swap,
            id3_ba,
            id3_ca,
            id3_ragged_idx,
        ) = self.generate_interaction_graph(data)
        idx_s, idx_t = edge_index

        # Calculate triplet angles
        cosφ_cab = inner_product_normalized(V_st[id3_ca], V_st[id3_ba])
        rad_cbf3, cbf3 = self.cbf_basis3(D_st, cosφ_cab, id3_ca)

        rbf = self.radial_basis(D_st)

        # Embedding block. In probe mode, shift atomic numbers by +1 so the
        # probe "element" 0 has its own embedding row.
        if self.atomic:
            h = self.atom_emb(atomic_numbers)
        if self.probe:
            h = self.atom_emb(atomic_numbers + 1)
        # (nAtoms, emb_size_atom)
        m = self.edge_emb(h, rbf, idx_s, idx_t)  # (nEdges, emb_size_edge)

        rbf3 = self.mlp_rbf3(rbf)
        cbf3 = self.mlp_cbf3(rad_cbf3, cbf3, id3_ca, id3_ragged_idx)

        rbf_h = self.mlp_rbf_h(rbf)
        rbf_out = self.mlp_rbf_out(rbf)

        if self.atomic:
            # Collect the node embedding after every interaction block.
            h_list = []
            for i in range(self.num_blocks):
                h, m = self.int_blocks[i](
                    h=h,
                    m=m,
                    rbf3=rbf3,
                    cbf3=cbf3,
                    id3_ragged_idx=id3_ragged_idx,
                    id_swap=id_swap,
                    id3_ba=id3_ba,
                    id3_ca=id3_ca,
                    rbf_h=rbf_h,
                    idx_s=idx_s,
                    idx_t=idx_t,
                )  # (nAtoms, emb_size_atom), (nEdges, emb_size_edge)
                h_list.append(h)
            return h_list

        if self.probe:
            atom_indices = torch.nonzero(data.atomic_numbers).flatten()
            probe_indices = (data.atomic_numbers == 0).nonzero().flatten()

            for i in range(self.num_blocks):
                # Overwrite atom nodes with the precomputed representations
                # from the atom model before each block.
                h[atom_indices] = data.atom_representations[i]
                h, m = self.int_blocks[i](
                    h=h,
                    m=m,
                    rbf3=rbf3,
                    cbf3=cbf3,
                    id3_ragged_idx=id3_ragged_idx,
                    id_swap=id_swap,
                    id3_ba=id3_ba,
                    id3_ca=id3_ca,
                    rbf_h=rbf_h,
                    idx_s=idx_s,
                    idx_t=idx_t,
                )  # (nAtoms, emb_size_atom), (nEdges, emb_size_edge)
            return h[probe_indices]
@registry.register_model("gemnet_oc_charge")
class GemNet_OC_charge(GemNetOC):
    """GemNet-OC variant usable as either the atom or the probe model.

    In atomic mode, ``forward`` returns the per-block node embeddings; in
    probe mode, precomputed atom representations are injected before each
    interaction block and the probe-node embeddings are returned.
    """

    def __init__(
        self,
        name='gemnet_oc_charge',
        num_spherical=16,
        num_radial=16,
        num_blocks=6,
        emb_size_atom=32,
        emb_size_edge=32,
        # NOTE(review): emb_size_trip is accepted for config compatibility
        # but is not forwarded to GemNetOC (which takes trip_in/trip_out).
        emb_size_trip=32,
        emb_size_rbf=32,
        emb_size_cbf=32,
        emb_size_sbf=32,
        emb_size_trip_in=32,
        emb_size_trip_out=32,
        emb_size_quad_in=32,
        emb_size_quad_out=32,
        emb_size_aint_in=32,
        emb_size_aint_out=32,
        num_before_skip=3,
        num_after_skip=3,
        num_concat=3,
        num_atom=3,
        num_output_afteratom=3,
        **kwargs,
    ):
        # Consume the custom flags before delegating to the stock GemNetOC.
        self.atomic = kwargs.pop('atomic')
        self.probe = kwargs.pop('probe')

        if self.probe:
            # Probe nodes use atomic number 0, so the embedding table is
            # shifted by one (see forward); 84 covers the element range.
            kwargs['num_elements'] = 84

        kwargs['otf_graph'] = False

        super().__init__(
            num_atoms=1,
            bond_feat_dim=1,
            num_targets=1,
            num_spherical=num_spherical,
            num_radial=num_radial,
            num_blocks=num_blocks,
            emb_size_atom=emb_size_atom,
            emb_size_edge=emb_size_edge,
            emb_size_trip_in=emb_size_trip_in,
            emb_size_trip_out=emb_size_trip_out,
            emb_size_quad_in=emb_size_quad_in,
            emb_size_quad_out=emb_size_quad_out,
            emb_size_aint_in=emb_size_aint_in,
            emb_size_aint_out=emb_size_aint_out,
            emb_size_rbf=emb_size_rbf,
            emb_size_cbf=emb_size_cbf,
            emb_size_sbf=emb_size_sbf,
            num_before_skip=num_before_skip,
            num_after_skip=num_after_skip,
            num_concat=num_concat,
            num_atom=num_atom,
            num_output_afteratom=num_output_afteratom,
            **kwargs,
        )

        # Attributes read by ChargeModel to pair atom and probe models.
        self.num_interactions = self.num_blocks
        self.hidden_channels = self.atom_emb.emb_size

    @conditional_grad(torch.enable_grad())
    def forward(self, data):
        pos = data.pos
        batch = data.batch
        atomic_numbers = data.atomic_numbers.long()
        num_atoms = atomic_numbers.shape[0]

        if self.regress_forces and not self.direct_forces:
            pos.requires_grad_(True)

        data.natoms = data.natoms.to(data.pos.device)
        # GemNetOC's graph construction expects tags; all nodes are
        # treated alike here.
        data.tags = torch.ones_like(data.atomic_numbers)
        data.neighbors = data.neighbors.to(data.pos.device)

        (
            main_graph,
            a2a_graph,
            a2ee2a_graph,
            qint_graph,
            id_swap,
            trip_idx_e2e,
            trip_idx_a2e,
            trip_idx_e2a,
            quad_idx,
        ) = self.get_graphs_and_indices(data)
        _, idx_t = main_graph["edge_index"]

        (
            basis_rad_raw,
            basis_atom_update,
            basis_output,
            bases_qint,
            bases_e2e,
            bases_a2e,
            bases_e2a,
            basis_a2a_rad,
        ) = self.get_bases(
            main_graph=main_graph,
            a2a_graph=a2a_graph,
            a2ee2a_graph=a2ee2a_graph,
            qint_graph=qint_graph,
            trip_idx_e2e=trip_idx_e2e,
            trip_idx_a2e=trip_idx_a2e,
            trip_idx_e2a=trip_idx_e2a,
            quad_idx=quad_idx,
            num_atoms=num_atoms,
        )

        # Embedding block. In probe mode, shift atomic numbers by +1 so the
        # probe "element" 0 has its own embedding row.
        if self.atomic:
            h = self.atom_emb(atomic_numbers)
        if self.probe:
            h = self.atom_emb(atomic_numbers + 1)
        # (nAtoms, emb_size_atom)

        m = self.edge_emb(h, basis_rad_raw, main_graph["edge_index"])
        # (nEdges, emb_size_edge)

        if self.atomic:
            # Collect the node embedding after every interaction block.
            h_list = []
            for i in range(self.num_blocks):
                h, m = self.int_blocks[i](
                    h=h,
                    m=m,
                    bases_qint=bases_qint,
                    bases_e2e=bases_e2e,
                    bases_a2e=bases_a2e,
                    bases_e2a=bases_e2a,
                    basis_a2a_rad=basis_a2a_rad,
                    basis_atom_update=basis_atom_update,
                    edge_index_main=main_graph["edge_index"],
                    a2ee2a_graph=a2ee2a_graph,
                    a2a_graph=a2a_graph,
                    id_swap=id_swap,
                    trip_idx_e2e=trip_idx_e2e,
                    trip_idx_a2e=trip_idx_a2e,
                    trip_idx_e2a=trip_idx_e2a,
                    quad_idx=quad_idx,
                )  # (nAtoms, emb_size_atom), (nEdges, emb_size_edge)
                h_list.append(h)
            return h_list

        if self.probe:
            atom_indices = torch.nonzero(data.atomic_numbers).flatten()
            probe_indices = (data.atomic_numbers == 0).nonzero().flatten()

            for i in range(self.num_blocks):
                # Overwrite atom nodes with the precomputed representations
                # from the atom model before each block.
                h[atom_indices] = data.atom_representations[i]

                h, m = self.int_blocks[i](
                    h=h,
                    m=m,
                    bases_qint=bases_qint,
                    bases_e2e=bases_e2e,
                    bases_a2e=bases_a2e,
                    bases_e2a=bases_e2a,
                    basis_a2a_rad=basis_a2a_rad,
                    basis_atom_update=basis_atom_update,
                    edge_index_main=main_graph["edge_index"],
                    a2ee2a_graph=a2ee2a_graph,
                    a2a_graph=a2a_graph,
                    id_swap=id_swap,
                    trip_idx_e2e=trip_idx_e2e,
                    trip_idx_a2e=trip_idx_a2e,
                    trip_idx_e2a=trip_idx_e2a,
                    quad_idx=quad_idx,
                )  # (nAtoms, emb_size_atom), (nEdges, emb_size_edge)
            return h[probe_indices]
@registry.register_model("painn_charge")
class PaiNN_Charge(PaiNN):
    """PaiNN variant usable as either the atom or the probe model.

    In atomic mode, ``forward`` returns the per-layer scalar features; in
    probe mode, precomputed atom representations are injected before each
    layer and the probe-node features are returned.
    """

    def __init__(
        self,
        name='painn_charge',
        **kwargs,
    ):
        # Consume the custom flags before delegating to the stock PaiNN.
        self.atomic = kwargs.pop('atomic')
        self.probe = kwargs.pop('probe')

        if self.probe:
            # Probe nodes use atomic number 0, so the embedding table is
            # shifted by one (see forward); 84 covers the element range.
            kwargs['num_elements'] = 84

        super().__init__(
            num_atoms=1,
            bond_feat_dim=1,
            num_targets=1,
            otf_graph=False,
            **kwargs,
        )

        # Attribute read by ChargeModel to pair atom and probe models.
        self.num_interactions = self.num_layers

        # The stock PaiNN energy/force heads are unused here.
        del self.out_energy
        del self.out_forces

    @conditional_grad(torch.enable_grad())
    def forward(self, data):
        data.pos = data.pos.float()
        pos = data.pos

        batch = data.batch
        z = data.atomic_numbers.long()

        data.natoms = data.natoms.to(data.cell.device)
        data.neighbors = data.neighbors.to(data.cell.device)

        (
            edge_index,
            neighbors,
            edge_dist,
            edge_vector,
            id_swap,
        ) = self.generate_graph_values(data)

        assert z.dim() == 1 and z.dtype == torch.long

        edge_rbf = self.radial_basis(edge_dist)  # rbf * envelope

        #### Interaction blocks ###############################################

        if self.atomic:
            x = self.atom_emb(z)
            vec = torch.zeros(x.size(0), 3, x.size(1), device=x.device)

            # Collect the scalar features after every layer.
            x_list = []
            for i in range(self.num_layers):
                dx, dvec = self.message_layers[i](
                    x, vec, edge_index, edge_rbf, edge_vector
                )

                x = x + dx
                vec = vec + dvec
                x = x * self.inv_sqrt_2

                dx, dvec = self.update_layers[i](x, vec)

                x = x + dx
                vec = vec + dvec
                x = getattr(self, "upd_out_scalar_scale_%d" % i)(x)

                x_list.append(x)

            return x_list

        if self.probe:
            x = self.atom_emb(z + 1)  # +1 allows for probe embeddings
            vec = torch.zeros(x.size(0), 3, x.size(1), device=x.device)

            atom_indices = torch.nonzero(z).flatten()
            probe_indices = (z == 0).nonzero().flatten()

            for i in range(self.num_layers):
                # Overwrite atom nodes with the precomputed representations
                # from the atom model before each layer.
                x[atom_indices] = data.atom_representations[i]

                dx, dvec = self.message_layers[i](
                    x, vec, edge_index, edge_rbf, edge_vector
                )

                x = x + dx
                vec = vec + dvec
                x = x * self.inv_sqrt_2

                dx, dvec = self.update_layers[i](x, vec)

                x = x + dx
                vec = vec + dvec
                x = getattr(self, "upd_out_scalar_scale_%d" % i)(x)

            return x[probe_indices]
@registry.register_model("schnet_charge")
class schnet_charge(SchNet):
    """SchNet variant usable as either the atom or the probe model.

    In atomic mode, ``forward`` returns the per-interaction node features;
    in probe mode, precomputed atom representations are injected before
    each interaction and the probe-node features are returned.
    """

    def __init__(
        self,
        name='schnet_charge',
        **kwargs,
    ):
        # Consume the custom flags before delegating to the stock SchNet.
        self.atomic = kwargs.pop('atomic')
        self.probe = kwargs.pop('probe')

        super().__init__(
            num_atoms=1,
            bond_feat_dim=1,
            num_targets=1,
            otf_graph=False,
            **kwargs,
        )

        # The stock SchNet output MLP is unused here.
        if hasattr(self, 'lin1'):
            del self.lin1
        if hasattr(self, 'lin2'):
            del self.lin2

    @conditional_grad(torch.enable_grad())
    def forward(self, data):
        z = data.atomic_numbers.long()
        pos = data.pos
        batch = data.batch

        (
            edge_index,
            edge_weight,
            distance_vec,
            cell_offsets,
            _,  # cell offset distances
            neighbors,
        ) = self.generate_graph(data)

        if self.use_pbc:
            assert z.dim() == 1 and z.dtype == torch.long

        edge_attr = self.distance_expansion(edge_weight)

        # NOTE(review): unlike the other *_charge models, no +1 shift is
        # applied for probe nodes here — presumably the SchNet embedding
        # already accommodates z == 0; confirm against the base model.
        h = self.embedding(z)

        if self.atomic:
            # Collect the node features after every interaction.
            h_list = []
            for interaction in self.interactions:
                h = h + interaction(h, edge_index, edge_weight, edge_attr)
                h_list.append(h)
            return h_list

        if self.probe:
            atom_indices = torch.nonzero(data.atomic_numbers).flatten()
            probe_indices = (data.atomic_numbers == 0).nonzero().flatten()

            edge_weight = edge_weight.float()
            edge_attr = edge_attr.float()

            for layer, interaction in enumerate(self.interactions):
                # Overwrite atom nodes with the precomputed representations
                # from the atom model before each interaction.
                h[atom_indices] = data.atom_representations[layer]
                h = h + interaction(h, edge_index, edge_weight, edge_attr)

            return h[probe_indices]
def __init__(
    self,
    name='scn_charge',
    **kwargs,
):
    """Set up an SCN configured as either the atom or the probe model.

    Pops the custom 'atomic'/'probe' flags before delegating the remaining
    kwargs to the stock SCN constructor.
    """
    # Consume the custom flags before delegating to the stock SCN.
    self.atomic = kwargs.pop('atomic')
    self.probe = kwargs.pop('probe')

    # Probe graphs can be very dense, so default to an effectively
    # unlimited neighbor count unless the caller overrides it.
    if 'max_num_neighbors' not in kwargs:
        kwargs['max_num_neighbors'] = 10000
        # BUG FIX: removed leftover debug output (print('ping')).

    if 'show_timing_info' not in kwargs:
        kwargs['show_timing_info'] = False

    super().__init__(
        num_atoms=1,
        bond_feat_dim=1,
        num_targets=1,
        otf_graph=False,
        **kwargs,
    )
@conditional_grad(torch.enable_grad())
def _forward_helper(self, data):
    """Shared SCN forward pass for both atomic and probe modes.

    In atomic mode, returns the list of per-layer spherical node
    embeddings. In probe mode, injects precomputed atom representations
    after each edge block and returns the predicted density at probe
    nodes (nodes with atomic number 0).
    """
    atomic_numbers = data.atomic_numbers.long()

    num_atoms = len(atomic_numbers)
    pos = data.pos

    # Necessary for _rank_edge_distances
    data.edge_index = sort_edge_index(data.edge_index.flipud()).flipud()

    (
        edge_index,
        edge_distance,
        edge_distance_vec,
        cell_offsets,
        _,  # cell offset distances
        neighbors,
    ) = self.generate_graph(data)

    ###############################################################
    # Initialize data structures
    ###############################################################

    # Calculate which message block each edge should use. Based on edge distance rank.
    edge_rank = self._rank_edge_distances(
        edge_distance, edge_index, self.max_num_neighbors,
    )

    # Reorder edges so that they are grouped by distance rank (lowest to highest)
    last_cutoff = -0.1
    message_block_idx = torch.zeros(len(edge_distance), device=pos.device)
    edge_distance_reorder = torch.tensor([], device=self.device)
    edge_index_reorder = torch.tensor([], device=self.device)
    edge_distance_vec_reorder = torch.tensor([], device=self.device)
    # cutoff_index[i]..cutoff_index[i+1] delimits the edges of resolution i
    # in the reordered arrays.
    cutoff_index = torch.tensor([0], device=self.device)
    for i in range(self.num_resolutions):
        # Edges whose rank falls between the previous and current cutoff
        # belong to resolution i.
        mask = torch.logical_and(
            edge_rank.gt(last_cutoff), edge_rank.le(self.cutoff_list[i])
        )
        last_cutoff = self.cutoff_list[i]
        message_block_idx.masked_fill_(mask, i)
        edge_distance_reorder = torch.cat(
            [
                edge_distance_reorder,
                torch.masked_select(edge_distance, mask),
            ],
            dim=0,
        )
        edge_index_reorder = torch.cat(
            [
                edge_index_reorder,
                torch.masked_select(
                    edge_index, mask.view(1, -1).repeat(2, 1)
                ).view(2, -1),
            ],
            dim=1,
        )
        edge_distance_vec_mask = torch.masked_select(
            edge_distance_vec, mask.view(-1, 1).repeat(1, 3)
        ).view(-1, 3)
        edge_distance_vec_reorder = torch.cat(
            [edge_distance_vec_reorder, edge_distance_vec_mask], dim=0
        )
        cutoff_index = torch.cat(
            [
                cutoff_index,
                torch.tensor(
                    [len(edge_distance_reorder)], device=self.device
                ),
            ],
            dim=0,
        )

    edge_index = edge_index_reorder.long()
    edge_distance = edge_distance_reorder
    edge_distance_vec = edge_distance_vec_reorder

    # Compute 3x3 rotation matrix per edge
    edge_rot_mat = self._init_edge_rot_mat(
        data, edge_index, edge_distance_vec
    )

    # Initialize the WignerD matrices and other values for spherical harmonic calculations
    for i in range(self.num_resolutions):
        self.sphharm_list[i].InitWignerDMatrix(
            edge_rot_mat[cutoff_index[i] : cutoff_index[i + 1]],
        )

    ###############################################################
    # Initialize node embeddings
    ###############################################################

    # Init per node representations using an atomic number based embedding
    x = torch.zeros(
        num_atoms,
        self.sphere_basis,
        self.sphere_channels,
        device=pos.device,
    )
    # Only the l=0 (first basis) component is seeded from the embedding.
    x[:, 0, :] = self.sphere_embedding(atomic_numbers)

    ###############################################################
    # Update spherical node embeddings
    ###############################################################

    if self.atomic:
        # Collect the node embedding after every edge block; the first
        # block replaces x rather than adding to it.
        atom_representations = []
        for i, interaction in enumerate(self.edge_blocks):
            if i > 0:
                x = x + interaction(
                    x, atomic_numbers, edge_distance, edge_index, cutoff_index
                )
                atom_representations.append(x)
            else:
                x = interaction(
                    x, atomic_numbers, edge_distance, edge_index, cutoff_index
                )
                atom_representations.append(x)
        return atom_representations

    if self.probe:
        atom_indices = torch.nonzero(data.atomic_numbers).flatten()
        probe_indices = (data.atomic_numbers == 0).nonzero().flatten()

        for i, interaction in enumerate(self.edge_blocks):
            # NOTE(review): here the precomputed atom representations are
            # written AFTER each block (the other *_charge models write
            # them before) — confirm this ordering is intentional.
            if i > 0:
                x = x + interaction(
                    x, atomic_numbers, edge_distance, edge_index, cutoff_index
                )
                x[atom_indices] = data.atom_representations[i]
            else:
                x = interaction(
                    x, atomic_numbers, edge_distance, edge_index, cutoff_index
                )
                x[atom_indices] = data.atom_representations[i]

        ###############################################################
        # Predict electron density
        ###############################################################

        # Create a roughly evenly distributed point sampling of the sphere
        sphere_points = CalcSpherePoints(
            self.num_sphere_samples, x.device
        ).detach()
        sphharm_weights = o3.spherical_harmonics(
            torch.arange(0, self.lmax + 1).tolist(), sphere_points, False
        ).detach()

        # Density estimation: project spherical features onto the sampled
        # points, push each sample through the energy MLP, and average.
        node_energy = torch.einsum(
            "abc, pb->apc", x, sphharm_weights
        ).contiguous()
        node_energy = node_energy.view(-1, self.sphere_channels)
        node_energy = self.act(self.energy_fc1(node_energy))
        node_energy = self.act(self.energy_fc2(node_energy))
        node_energy = self.energy_fc3(node_energy)
        node_energy = node_energy.view(-1, self.num_sphere_samples, 1)
        node_density = torch.sum(node_energy, dim=1) / self.num_sphere_samples

        return node_density[probe_indices]
ocpmodels.models.base import BaseModel 11 | from ocpmodels.preprocessing.atoms_to_graphs import AtomsToGraphs 12 | from ocpmodels.datasets import data_list_collater 13 | 14 | from cdm.utils.probe_graph import calculate_grid_pos, ProbeGraphAdder, get_edges_from_choice 15 | 16 | def old_calculate_grid_pos(shape, cell): 17 | ngridpts = np.array(shape) 18 | grid_pos = np.meshgrid( 19 | np.arange(ngridpts[0]) / shape[0], 20 | np.arange(ngridpts[1]) / shape[1], 21 | np.arange(ngridpts[2]) / shape[2], 22 | indexing="ij", 23 | ) 24 | grid_pos = np.stack(grid_pos, 3) 25 | grid_pos = np.dot(grid_pos, cell) 26 | 27 | return torch.tensor(grid_pos, dtype=torch.float) 28 | 29 | def end_to_end_graph_gen(device): 30 | structure = molecule('H') 31 | structure.cell = [[10, 0, 0], 32 | [0, 10, 0], 33 | [0, 0, 10]] 34 | 35 | cell = torch.tensor(np.array(structure.cell), dtype=torch.float) 36 | 37 | structure.positions = [[0, 9, 0]] 38 | 39 | a2g = AtomsToGraphs() 40 | data = a2g.convert(structure) 41 | data.charge_density = [[[0]]] 42 | 43 | data.to(device) 44 | 45 | pga = ProbeGraphAdder(num_probes = 1) 46 | data = pga(data) 47 | 48 | model = BaseModel() 49 | model.otf_graph = False 50 | 51 | ( 52 | edge_index, 53 | edge_weight, 54 | distance_vec, 55 | cell_offsets, 56 | cell_offset_distances, 57 | neighbors, 58 | ) = model.generate_graph( 59 | data_list_collater([data.probe_data]), 60 | cutoff = 1000, 61 | max_neighbors=100, 62 | use_pbc = True, 63 | otf_graph = False, 64 | ) 65 | 66 | if edge_weight.item() == 19: 67 | print('Offsets are likely in the wrong direction!') 68 | 69 | return edge_weight.item() 70 | 71 | def test_calculate_grid_pos(): 72 | # Base case 73 | shape = [2, 2, 2] 74 | cell = [[[1, 0, 0], 75 | [0, 1, 0], 76 | [0, 0, 1]]] 77 | cell = torch.tensor(cell, dtype=torch.float) 78 | 79 | assert torch.allclose(calculate_grid_pos(shape, cell), old_calculate_grid_pos(shape, cell)) 80 | 81 | # Non-uniform spacing 82 | shape = [20, 5, 2] 83 | cell = [[[1, 0, 0], 84 | [0, 1, 
0], 85 | [0, 0, 1]]] 86 | cell = torch.tensor(cell, dtype=torch.float) 87 | 88 | assert torch.allclose(calculate_grid_pos(shape, cell), old_calculate_grid_pos(shape, cell)) 89 | 90 | # Rectangular cell 91 | shape = [7, 7, 7] 92 | cell = [[[3, 0, 0], 93 | [0, 4, 0], 94 | [0, 0, 5]]] 95 | cell = torch.tensor(cell, dtype=torch.float) 96 | 97 | assert torch.allclose(calculate_grid_pos(shape, cell), old_calculate_grid_pos(shape, cell)) 98 | 99 | # Skew cell 100 | shape = [31, 16, 44] 101 | cell = [[[1, 5, 0], 102 | [0, 4, 0], 103 | [0, 1, 6]]] 104 | cell = torch.tensor(cell, dtype=torch.float) 105 | 106 | assert torch.allclose(calculate_grid_pos(shape, cell), old_calculate_grid_pos(shape, cell)) 107 | 108 | # Edge case 109 | shape = [1, 1, 1] 110 | cell = [[[1, 0, 0], 111 | [0, 1, 0], 112 | [0, 0, 1]]] 113 | cell = torch.tensor(cell, dtype=torch.float) 114 | 115 | assert torch.allclose(calculate_grid_pos(shape, cell), old_calculate_grid_pos(shape, cell)) 116 | 117 | 118 | def test_get_edges_from_choice(): 119 | vcd = VaspChargeDensity( 120 | os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_structure") 121 | ) 122 | 123 | atoms = vcd.atoms[0] 124 | dens = vcd.chg[0] 125 | cell = torch.tensor(np.array([atoms.cell.array]), dtype=torch.float) 126 | grid_pos = calculate_grid_pos(dens.shape, cell) 127 | 128 | probe_choice = np.random.randint(np.prod(grid_pos.shape[0:3]), size = 100) 129 | probe_choice = np.unravel_index(probe_choice, grid_pos.shape[0:3]) 130 | 131 | cutoff = 6 132 | 133 | out1 = get_edges_from_choice( 134 | probe_choice, 135 | grid_pos, 136 | atom_pos = torch.tensor(atoms.get_positions()), 137 | cell = cell, 138 | cutoff= cutoff, 139 | include_atomic_edges = False, 140 | implementation = 'ASE', 141 | ) 142 | 143 | out2 = get_edges_from_choice( 144 | probe_choice, 145 | grid_pos, 146 | atom_pos = torch.tensor(atoms.get_positions()), 147 | cell = cell, 148 | cutoff = cutoff, 149 | include_atomic_edges = False, 150 | implementation = 'RGPBC', 151 | 
) 152 | 153 | features1 = torch.cat( 154 | (out1[0], out1[1].T), dim=0 155 | ).T 156 | 157 | features2 = torch.cat( 158 | (out2[0], out2[1].T), dim=0 159 | ).T 160 | 161 | # Convert rows of tensors to sets. The order of edges is not guaranteed 162 | features1 = {tuple(x.tolist()) for x in features1} 163 | features2 = {tuple(x.tolist()) for x in features2} 164 | 165 | # Ensure sets are not empty 166 | assert len(features1) > 0 167 | assert len(features2) > 0 168 | 169 | # Ensure sets are the same 170 | assert features1 == features2 171 | 172 | assert (out1[2] == out2[2]).all() 173 | 174 | def test_end_to_end_graph_gen_cpu(): 175 | assert end_to_end_graph_gen('cpu') == 1 176 | 177 | def test_end_to_end_graph_gen_cuda(): 178 | if torch.cuda.is_available(): 179 | assert end_to_end_graph_gen('cuda') == 1 180 | else: 181 | warnings.warn('cannot test graph generation on cuda') 182 | pass 183 | 184 | if __name__ == "__main__": 185 | test_calculate_grid_pos() 186 | print('Pass: test_calculate_grid_pos') 187 | 188 | test_get_edges_from_choice() 189 | print('Pass: test_get_edges_from_choice') 190 | 191 | test_end_to_end_graph_gen_cpu() 192 | print('Pass: test_end_to_end_graph_gen_cpu') 193 | 194 | test_end_to_end_graph_gen_cuda() 195 | print('Pass: test_end_to_end_graph_gen_cuda') 196 | 197 | -------------------------------------------------------------------------------- /cdm/utils/chg_db.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import ase 4 | import bisect 5 | import warnings 6 | 7 | from tqdm.notebook import tqdm 8 | from pathlib import Path 9 | from multiprocessing import Pool 10 | 11 | from ocpmodels.common.registry import registry 12 | from ocpmodels.datasets.ase_datasets import AseDBDataset 13 | 14 | from cdm.utils.preprocessing import VaspChargeDensity 15 | 16 | @registry.register_dataset('charge_db') 17 | class ChgDBDataset(AseDBDataset): 18 | ''' 19 | An alternative database 
format based on ASE databases 20 | One way to create such a database is with the "write_charge_db" 21 | script in this file. This script requires VASP CHGCARs. 22 | ''' 23 | def __getitem__(self, idx): 24 | # Handle slicing 25 | if isinstance(idx, slice): 26 | return [self[i] for i in range(*idx.indices(len(self.ids)))] 27 | 28 | # Get atoms object via derived class method 29 | atoms = self.get_atoms_object(self.ids[idx]) 30 | 31 | # Transform atoms object 32 | if self.atoms_transform is not None: 33 | atoms = self.atoms_transform( 34 | atoms, **self.config.get("atoms_transform_args", {}) 35 | ) 36 | 37 | if "sid" in atoms.info: 38 | sid = atoms.info["sid"] 39 | else: 40 | sid = torch.tensor([idx]) 41 | 42 | # Convert to data object 43 | data_object = self.a2g.convert(atoms, sid) 44 | 45 | data_object.pbc = torch.tensor(atoms.pbc) 46 | data_object.charge_density = atoms.info["charge_density"] 47 | 48 | if isinstance(data_object.charge_density, list): 49 | data_object.charge_density = torch.tensor(data_object.charge_density) 50 | 51 | # Transform data object 52 | if self.transform is not None: 53 | data_object = self.transform( 54 | data_object, **self.config.get("transform_args", {}) 55 | ) 56 | 57 | return data_object 58 | 59 | def get_atoms_object(self, idx): 60 | # Figure out which db this should be indexed from. 
61 | db_idx = bisect.bisect(self._idlen_cumulative, idx) 62 | 63 | # Extract index of element within that db 64 | el_idx = idx 65 | if db_idx != 0: 66 | el_idx = idx - self._idlen_cumulative[db_idx - 1] 67 | assert el_idx >= 0 68 | 69 | atoms_row = self.dbs[db_idx]._get_row(self.db_ids[db_idx][el_idx]) 70 | atoms = atoms_row.toatoms(add_additional_information = True) 71 | 72 | if isinstance(atoms_row.data, dict): 73 | atoms.info.update(atoms_row.data) 74 | 75 | return atoms 76 | 77 | def write_charge_db( 78 | CHGCARs, 79 | ase_db, 80 | num_workers = 1, 81 | ): 82 | if not isinstance(CHGCARs, list): 83 | path = Path(CHGCARs) 84 | ids = sorted(path.glob("*/CHGCAR")) 85 | else: 86 | ids = CHGCARs 87 | 88 | for idx in tqdm(ids): 89 | try: 90 | vcd = VaspChargeDensity(idx) 91 | atoms = vcd.atoms[-1] 92 | dens = vcd.chg[-1] 93 | except: 94 | print("Exception occured for: ", idx) 95 | 96 | try: 97 | ase_db.write(atoms, data = {'charge_density': dens}) 98 | except TypeError: 99 | warnings.warn("Failed to write tensor to database. Trying again as a list!") 100 | ase_db.write(atoms, data = {'charge_density': dens.tolist()}) -------------------------------------------------------------------------------- /cdm/utils/dir_of_chgcars.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from pathlib import Path 5 | 6 | from torch.utils.data import Dataset 7 | 8 | from ocpmodels.common.registry import registry 9 | from ocpmodels.preprocessing import AtomsToGraphs 10 | 11 | from cdm.utils.preprocessing import VaspChargeDensity 12 | 13 | @registry.register_dataset('dir_of_chgcars') 14 | class ChgcarDataset(Dataset): 15 | ''' 16 | This Dataset is used to process the outputs of VASP calculations directly. 17 | The following directory structure is expected: 18 | 19 | src/ 20 | ├─ calculation_1/ 21 | ├─ CHGCAR 22 | ├─ ... 23 | ├─ calculation_2/ 24 | ├─ CHGCAR 25 | ├─ ... 26 | ├─ ... 
27 | 28 | ''' 29 | 30 | def __init__(self, config): 31 | super(ChgcarDataset, self).__init__() 32 | self.config = config 33 | 34 | self.path = Path(self.config['src']) 35 | if self.path.is_file(): 36 | raise Exception('The specified src is not a directory') 37 | split = config.get('split') 38 | if split is not None: 39 | f = open(split, "r") 40 | split = f.readlines() 41 | self.id = sorted([Path( 42 | str(self.path) + '/' + 43 | str(i.rstrip('\n')) + '/CHGCAR' ) for i in split]) 44 | else: 45 | self.id = sorted(self.path.glob('*/CHGCAR')) 46 | 47 | self.a2g = AtomsToGraphs( 48 | max_neigh = 1000, 49 | radius = 8, 50 | r_energy = False, 51 | r_forces = False, 52 | r_distances = False, 53 | r_fixed = False, 54 | r_pbc = False, 55 | ) 56 | 57 | self.transform = config.get('transform') 58 | 59 | def __len__(self): 60 | return len(self.id) 61 | 62 | def __getitem__(self, idx): 63 | try: 64 | vcd = VaspChargeDensity(self.id[idx]) 65 | atoms = vcd.atoms[-1] 66 | dens = vcd.chg[-1] 67 | except: 68 | print('Exception occured for: ', self.id[idx]) 69 | 70 | data_object = self.a2g.convert(atoms) 71 | data_object.charge_density = dens 72 | 73 | if self.transform is not None: 74 | data_object = self.transform(data_object) 75 | 76 | return data_object -------------------------------------------------------------------------------- /cdm/utils/inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import warnings 3 | import numpy as np 4 | 5 | from tqdm import tqdm 6 | 7 | from torch_geometric.data import Batch 8 | 9 | from ocpmodels.preprocessing import AtomsToGraphs 10 | from ocpmodels.datasets import data_list_collater 11 | 12 | from cdm.utils.probe_graph import ProbeGraphAdder 13 | 14 | def inference( 15 | atoms, 16 | model, 17 | grid = (100, 100, 100), 18 | atom_cutoff = 6, 19 | probe_cutoff = 6, 20 | batch_size = 10000, 21 | use_tqdm = True, 22 | device='cuda', 23 | total_density = None,): 24 | 25 | if device is 
'cuda' and (torch.cuda.is_available() == False): 26 | warnings.warn('Cuda not available: running on CPU. Set device="cpu" to avoid this warning') 27 | device = 'cpu' 28 | 29 | a2g = AtomsToGraphs( 30 | max_neigh = len(atoms.get_atomic_numbers())**2, 31 | radius = atom_cutoff, 32 | r_energy = False, 33 | r_forces = False, 34 | r_distances = False, 35 | r_fixed = False, 36 | ) 37 | 38 | data_object = a2g.convert(atoms) 39 | 40 | data_object.charge_density = torch.zeros(grid) 41 | if total_density is not None: 42 | data_object.charge_density[0,0,0] = total_density 43 | data_object.to(device) 44 | model.to(device) 45 | 46 | pga = ProbeGraphAdder( 47 | num_probes = batch_size, 48 | cutoff = probe_cutoff, 49 | include_atomic_edges = False, 50 | mode = 'specify', 51 | stride = 1, 52 | implementation = 'RGPBC', 53 | ) 54 | 55 | total_probes = np.prod(grid) 56 | num_blocks = int(np.ceil(total_probes / batch_size)) 57 | slice_start = 0 58 | preds = torch.tensor([], device = device) 59 | sequence = np.arange(total_probes) 60 | np.random.shuffle(sequence) 61 | 62 | loop = range(num_blocks) 63 | if use_tqdm: 64 | loop = tqdm(loop) 65 | 66 | for i in loop: 67 | data_object.probe_data = 0 68 | if i == (num_blocks - 1): 69 | with torch.no_grad(): 70 | data_object = pga( 71 | data_object, 72 | specify_probes = sequence[slice_start:], 73 | ) 74 | else: 75 | with torch.no_grad(): 76 | data_object = pga( 77 | data_object, 78 | specify_probes = sequence[(i*batch_size):((i+1)*batch_size)] 79 | ) 80 | 81 | batch = data_list_collater([data_object.clone().detach()]) 82 | batch.probe_data = Batch.from_data_list([data_object.probe_data]) 83 | 84 | with torch.no_grad(): 85 | preds = torch.cat((preds, model(batch))) 86 | 87 | slice_start += batch_size 88 | torch.cuda.empty_cache() 89 | 90 | out = torch.zeros_like(preds) 91 | 92 | out[sequence] = preds 93 | 94 | return torch.reshape(out, grid) 95 | -------------------------------------------------------------------------------- 
/cdm/utils/preprocessing.py: -------------------------------------------------------------------------------- 1 | from ase.calculators.vasp import VaspChargeDensity 2 | import numpy as np 3 | import re 4 | 5 | class VaspChargeDensity(VaspChargeDensity): 6 | def read(self, filename, read_augs = True): 7 | """Re-implemenation of ASE functionality 8 | The adjustments support the new, faster _read_chg 9 | 10 | """ 11 | import ase.io.vasp as aiv 12 | fd = open(filename) 13 | self.atoms = [] 14 | self.chg = [] 15 | self.chgdiff = [] 16 | self.aug = '' 17 | self.augdiff = '' 18 | while True: 19 | try: 20 | atoms = aiv.read_vasp(fd) 21 | except (IOError, ValueError, IndexError): 22 | # Probably an empty line, or we tried to read the 23 | # augmentation occupancies in CHGCAR 24 | break 25 | fd.readline() 26 | ngr = fd.readline().split() 27 | shape = (int(ngr[0]), int(ngr[1]), int(ngr[2])) 28 | chg = self._read_chg(fd, shape, atoms.get_volume()) 29 | self.chg.append(chg) 30 | self.atoms.append(atoms) 31 | # Check if the file has a spin-polarized charge density part, and 32 | # if so, read it in. 33 | fl = fd.tell() 34 | 35 | if not read_augs: 36 | break 37 | 38 | # First check if the file has an augmentation charge part (CHGCAR 39 | # file.) 
40 | line1 = fd.readline() 41 | if line1 == '': 42 | break 43 | elif line1.find('augmentation') != -1: 44 | augs = [line1] 45 | while True: 46 | line2 = fd.readline() 47 | if line2.split() == ngr: 48 | self.aug = ''.join(augs) 49 | augs = [] 50 | chgdiff = np.empty(ng) 51 | self._read_chg(fd, chgdiff, atoms.get_volume()) 52 | self.chgdiff.append(chgdiff) 53 | elif line2 == '': 54 | break 55 | else: 56 | augs.append(line2) 57 | if len(self.aug) == 0: 58 | self.aug = ''.join(augs) 59 | augs = [] 60 | else: 61 | self.augdiff = ''.join(augs) 62 | augs = [] 63 | elif line1.split() == ngr: 64 | chgdiff = np.empty(ng) 65 | self._read_chg(fd, chgdiff, atoms.get_volume()) 66 | self.chgdiff.append(chgdiff) 67 | else: 68 | fd.seek(fl) 69 | fd.close() 70 | self.aug_string_to_dict() 71 | 72 | def _read_chg(self, fobj, shape, volume): 73 | """Replaces ASE's method 74 | This implementation is approximately 2x faster. 75 | This is important because reading data from disk can 76 | be a limiting factor for training speed. 77 | 78 | """ 79 | chg = np.fromfile(fobj, count = np.prod(shape), sep=' ') 80 | 81 | chg = chg.reshape(shape, order = 'F') 82 | chg /= volume 83 | 84 | return chg 85 | 86 | def write(self, filename): 87 | self.aug_dict_to_string() 88 | super().write(filename) 89 | 90 | def aug_dict_to_string(self): 91 | texts = re.split('.{10}E.{3}', self.aug) 92 | augs = [] 93 | for i in range(len(self.aug_dict)): 94 | augs = [*augs, *(self.aug_dict[str(i+1)])] 95 | 96 | out = texts[0] 97 | 98 | for text, number in zip(texts[1:], augs): 99 | if number > 0: 100 | number = f'{10*number:.6E}' 101 | number = number.split('.') 102 | number = number[0] + number[1] 103 | number = ' 0.' + number 104 | elif number == 0: 105 | number = ' 0.0000000E+00' 106 | else: 107 | number = f'{10*number:.6E}' 108 | number = number.split('.') 109 | number = number[0][1] + number[1] 110 | number = '-0.' 
+ number 111 | 112 | out += number 113 | out += text 114 | 115 | self.aug = out 116 | 117 | def aug_string_to_dict(self): 118 | self.aug_dict = {} 119 | split = [x.split() for x in self.aug.split('augmentation occupancies')[1:]] 120 | for row in split: 121 | label = row[0] 122 | augmentations = [float(x) for x in row[2:]] 123 | self.aug_dict[label] = augmentations 124 | 125 | def zero_augmentations(self): 126 | ''' 127 | Set all augmentation occupancies to zero 128 | ''' 129 | for key, value in self.aug_dict.items(): 130 | value = [0 for i in value] 131 | self.aug_dict[key] = value 132 | 133 | def __eq__(self, other): 134 | if self.aug_dict != other.aug_dict: 135 | return False 136 | if self.atoms != other.atoms: 137 | return False 138 | 139 | for self_chg, other_chg in zip(self.chg, other.chg): 140 | if not (self_chg == other_chg).all(): 141 | return False 142 | 143 | return True -------------------------------------------------------------------------------- /cdm/utils/probe_graph.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from torch_geometric.data import Data 5 | 6 | from ase import Atoms 7 | from ase.calculators.vasp import VaspChargeDensity 8 | from ase import neighborlist as nbl 9 | 10 | from ocpmodels.preprocessing import AtomsToGraphs 11 | from ocpmodels.datasets import data_list_collater 12 | 13 | class ProbeGraphAdder(): 14 | ''' 15 | A class that is used to add probe graphs to data objects. 16 | The data object must have an attribute "charge_density" which is 17 | a 3-dimensional tensor of charge density values 18 | 19 | Alternatively, this object can be called twice: 20 | The first time pre-selects random probe points (mode: "preselect") 21 | The second time computes the graph (mode: "preselected") 22 | This can be useful in certain situations such as 23 | selecting the probes on CPU and computing the graph on GPU. 
24 | This reduces the amount of data that must be sent to the GPU. 25 | ''' 26 | def __init__( 27 | self, 28 | num_probes=10000, 29 | cutoff=5, 30 | include_atomic_edges=False, 31 | mode = 'random', 32 | slice_start = 0, 33 | stride = 1, 34 | implementation = 'RGPBC', 35 | assert_min_edges = 1, 36 | ): 37 | self.num_probes = num_probes 38 | self.cutoff = cutoff 39 | self.include_atomic_edges = include_atomic_edges 40 | self.mode = mode 41 | self.slice_start = slice_start 42 | self.stride = stride 43 | self.implementation = implementation 44 | self.assert_min_edges = assert_min_edges 45 | 46 | def handle_density_array( 47 | self, 48 | density, 49 | cell 50 | ): 51 | 52 | if not torch.is_tensor(density): 53 | density = torch.tensor(np.array(density), dtype = torch.float) 54 | 55 | if self.stride != 1: 56 | assert (self.stride == 2) or (self.stride == 4) 57 | density = density[::self.stride, ::self.stride, ::self.stride] 58 | 59 | grid_pos = calculate_grid_pos(density.shape, cell) 60 | 61 | return density, grid_pos 62 | 63 | 64 | 65 | def __call__( 66 | self, 67 | data_object, 68 | slice_start = None, 69 | specify_probes = None, 70 | num_probes = None, 71 | mode = None, 72 | use_tqdm = False, 73 | ): 74 | 75 | if self.implementation == 'SKIP': 76 | return data_object 77 | 78 | # Check if probe graph has been precomputed or preselected 79 | if hasattr(data_object, 'probe_data'): 80 | if hasattr(data_object.probe_data, 'edge_index') \ 81 | and hasattr(data_object.probe_data, 'cell_offsets'): 82 | return data_object 83 | 84 | elif hasattr(data_object.probe_data, 'atomic_numbers') \ 85 | and hasattr(data_object.probe_data, 'pos') \ 86 | and hasattr(data_object.probe_data, 'target'): 87 | mode = 'preselected' 88 | 89 | 90 | # Handle batching 91 | if type(data_object.natoms) is not int: 92 | if len(data_object.natoms) > 1: 93 | raise Exception( 94 | 'Batch size >1 is not supported. 
\ 95 | It is recommended to instead increase the number of probes per structure' 96 | ) 97 | 98 | ''' 99 | data_list = data_object.to_data_list() 100 | batches = [data_list_collater([data]) for data in data_list] 101 | probe_data = [self(batch).probe_data for batch in batches] 102 | probe_data = data_list_collater(probe_data) 103 | data_object.probe_data = probe_data 104 | ''' 105 | return data_object 106 | 107 | # Use default options if none have been passed in 108 | if slice_start is None: 109 | slice_start = self.slice_start 110 | if num_probes is None: 111 | num_probes = self.num_probes 112 | if mode is None: 113 | mode = self.mode 114 | 115 | probe_data = Data() 116 | 117 | if mode == 'random': 118 | density, grid_pos = self.handle_density_array( 119 | data_object.charge_density, 120 | data_object.cell, 121 | ) 122 | 123 | probe_edges = torch.tensor([[]]) 124 | 125 | while probe_edges.shape[1] < self.assert_min_edges: 126 | probe_choice = np.random.randint( 127 | np.prod(grid_pos.shape[-5:-2]), 128 | size = num_probes, 129 | ) 130 | 131 | probe_choice = np.unravel_index( 132 | probe_choice, 133 | grid_pos.shape[-5:-2], 134 | ) 135 | 136 | out = get_edges_from_choice( 137 | probe_choice, 138 | grid_pos, 139 | atom_pos = data_object.pos, 140 | cell = data_object.cell, 141 | cutoff = self.cutoff, 142 | include_atomic_edges = self.include_atomic_edges, 143 | implementation = self.implementation, 144 | ) 145 | 146 | probe_edges, probe_offsets, probe_pos = out 147 | 148 | atomic_numbers = torch.clone(data_object.atomic_numbers.detach()) 149 | atomic_numbers = torch.cat((atomic_numbers, torch.zeros(num_probes, device = atomic_numbers.device))) 150 | 151 | probe_data.target = (density.reshape(density.shape[-3:])[probe_choice[0:3]]).to(atomic_numbers.device) 152 | probe_data.total_target = torch.sum(density) * torch.numel(probe_data.target) / torch.numel(density) 153 | 154 | elif mode == 'specify': 155 | density, grid_pos = self.handle_density_array( 156 | 
data_object.charge_density, 157 | data_object.cell, 158 | ) 159 | 160 | num_probes = len(specify_probes) 161 | probe_choice = np.unravel_index(specify_probes, grid_pos.shape[-5:-2]) 162 | out = get_edges_from_choice( 163 | probe_choice, 164 | grid_pos, 165 | atom_pos = data_object.pos, 166 | cell = data_object.cell, 167 | cutoff = self.cutoff, 168 | include_atomic_edges = self.include_atomic_edges, 169 | implementation = self.implementation, 170 | ) 171 | probe_edges, probe_offsets, probe_pos = out 172 | atomic_numbers = torch.cat((data_object.atomic_numbers, torch.zeros(num_probes, device = data_object.atomic_numbers.device))) 173 | 174 | probe_data.target = (density.reshape(density.shape[-3:])[probe_choice[0:3]]).to(atomic_numbers.device) 175 | probe_data.total_target = torch.sum(density) * torch.numel(probe_data.target) / torch.numel(density) 176 | 177 | elif mode == 'slice': 178 | density, grid_pos = self.handle_density_array( 179 | data_object.charge_density, 180 | data_object.cell, 181 | ) 182 | probe_choice = np.arange(slice_start, slice_start + num_probes, step=1) 183 | probe_choice = np.unravel_index(probe_choice, grid_pos.shape[-5:-2]) 184 | out = get_edges_from_choice( 185 | probe_choice, 186 | grid_pos, 187 | atom_pos = data_object.pos, 188 | cell = data_object.cell, 189 | cutoff = self.cutoff, 190 | include_atomic_edges = self.include_atomic_edges, 191 | implementation = self.implementation, 192 | ) 193 | probe_edges, probe_offsets, probe_pos = out 194 | atomic_numbers = torch.cat((data_object.atomic_numbers, torch.zeros(num_probes, device = data_object.atomic_numbers.device))) 195 | 196 | probe_data.target = (density.reshape(density.shape[-3:])[probe_choice[0:3]]).to(atomic_numbers.device) 197 | probe_data.total_target = torch.sum(density) * torch.numel(probe_data.target) / torch.numel(density) 198 | 199 | elif mode == 'all': 200 | density, grid_pos = self.handle_density_array( 201 | data_object.charge_density, 202 | data_object.cell, 203 | ) 204 | 
205 | total_probes = np.prod(density.shape) 206 | num_blocks = int(np.ceil(total_probes / num_probes)) 207 | 208 | probe_edges = torch.tensor([], device = data_object.edge_index.device) 209 | probe_offsets = torch.tensor([], device = data_object.cell_offsets.device) 210 | atomic_numbers = torch.clone(data_object.atomic_numbers.detach()) 211 | probe_pos = torch.tensor([], device = data_object.pos.device) 212 | 213 | loop = range(num_blocks) 214 | if use_tqdm: 215 | loop = tqdm(loop) 216 | 217 | for i in loop: 218 | if i == num_blocks - 1: 219 | probe_choice = np.arange(i * num_probes, total_probes, step = 1) 220 | else: 221 | probe_choice = np.arange(i * num_probes, (i+1)*num_probes, step = 1) 222 | 223 | probe_choice = np.unravel_index(probe_choice, grid_pos.shape[-5:-2]) 224 | out = get_edges_from_choice( 225 | probe_choice, 226 | grid_pos, 227 | atom_pos = data_object.pos, 228 | cell = data_object.cell, 229 | cutoff = self.cutoff, 230 | include_atomic_edges = self.include_atomic_edges, 231 | implementation = self.implementation, 232 | ) 233 | new_edges, new_offsets, new_pos = out 234 | 235 | new_edges[1] += i*num_probes 236 | probe_edges = torch.cat((probe_edges, new_edges), dim=1) 237 | probe_offsets = torch.cat((probe_offsets, new_offsets)) 238 | atomic_numbers = torch.cat((atomic_numbers, torch.zeros(new_pos.shape[0], device = atomic_numbers.device))) 239 | probe_pos = torch.cat((probe_pos, new_pos)) 240 | 241 | probe_choice = np.arange(0, np.prod(grid_pos.shape[-5:-2]), step=1) 242 | probe_choice = np.unravel_index(probe_choice, grid_pos.shape[-5:-2]) 243 | 244 | probe_data.target = (density.reshape(density.shape[-3:])[probe_choice[0:3]]).to(atomic_numbers.device) 245 | probe_data.total_target = torch.sum(density) * torch.numel(probe_data.target) / torch.numel(density) 246 | 247 | elif mode == 'preselected': 248 | probe_data = data_object.probe_data 249 | atom_pos = data_object.pos 250 | probe_pos = probe_data.pos[len(data_object.pos):] 251 | 252 | if 
self.implementation == 'ASE': 253 | neighborlist = AseNeighborListWrapper( 254 | self.cutoff, 255 | atom_pos, 256 | probe_pos, 257 | data_object.cell, 258 | ) 259 | 260 | elif self.implementation == 'RGPBC': 261 | neighborlist = RadiusGraphPBCWrapper( 262 | self.cutoff, 263 | atom_pos, 264 | probe_pos, 265 | data_object.cell) 266 | 267 | probe_edges, probe_offsets = neighborlist.get_all_neighbors( 268 | self.cutoff, 269 | self.include_atomic_edges, 270 | ) 271 | 272 | atomic_numbers = probe_data.atomic_numbers 273 | 274 | elif mode == 'preselect': 275 | density, grid_pos = self.handle_density_array( 276 | data_object.charge_density, 277 | data_object.cell, 278 | ) 279 | 280 | probe_edges = torch.tensor([[]]) 281 | probe_choice = np.random.randint( 282 | np.prod(grid_pos.shape[-5:-2]), 283 | size = num_probes, 284 | ) 285 | 286 | probe_choice = np.unravel_index( 287 | probe_choice, 288 | grid_pos.shape[-5:-2], 289 | ) 290 | 291 | grid_pos = grid_pos.reshape((*grid_pos.shape[-5:-2], 3)) 292 | probe_pos = grid_pos[probe_choice[0:3]] 293 | 294 | atomic_numbers = torch.clone(data_object.atomic_numbers.detach()) 295 | probe_data.atomic_numbers = torch.cat((atomic_numbers, torch.zeros(num_probes, device = atomic_numbers.device))) 296 | 297 | probe_data.target = (density.reshape(density.shape[-3:])[probe_choice[0:3]]).to(atomic_numbers.device) 298 | probe_data.natoms = torch.LongTensor([int(len(probe_data.atomic_numbers))]) 299 | probe_data.cell = data_object.cell 300 | 301 | 302 | probe_data.pos = torch.cat((data_object.pos, probe_pos)) 303 | 304 | probe_data.total_target = torch.sum(density) * torch.numel(probe_data.target) / torch.numel(density) 305 | 306 | data_object.probe_data = probe_data 307 | 308 | return data_object 309 | 310 | else: 311 | raise RuntimeError('Mode '+mode+' is not recognized.') 312 | 313 | # Add attributes to probe_data object 314 | probe_data.cell = data_object.cell 315 | probe_data.atomic_numbers = atomic_numbers 316 | probe_data.natoms = 
torch.LongTensor([int(len(atomic_numbers))]) 317 | probe_data.pos = torch.cat((data_object.pos, probe_pos)) 318 | 319 | probe_data.edge_index = probe_edges.long() 320 | 321 | probe_data.cell_offsets = -probe_offsets 322 | 323 | probe_data.neighbors = torch.LongTensor([probe_data.edge_index.shape[1]]) 324 | 325 | # Add probe_data object to overall data object 326 | data_object.probe_data = probe_data 327 | 328 | return data_object 329 | 330 | class AseNeighborListWrapper: 331 | """ 332 | Wrapper around ASE neighborlist 333 | Modified from DeepDFT 334 | """ 335 | 336 | def __init__( 337 | self, 338 | cutoff, 339 | atom_pos, 340 | probe_pos, 341 | cell 342 | ): 343 | atoms = Atoms(numbers = [1] * len(atom_pos), 344 | positions = atom_pos.cpu().detach().numpy(), 345 | cell = cell.cpu().detach().numpy()[0], 346 | pbc = [True, True, True]) 347 | 348 | probe_atoms = Atoms(numbers = [0] * len(probe_pos), positions = probe_pos) 349 | atoms_with_probes = atoms.copy() 350 | atoms_with_probes.extend(probe_atoms) 351 | 352 | atoms = atoms_with_probes 353 | 354 | self.neighborlist = nbl.NewPrimitiveNeighborList( 355 | cutoff, skin=0.0, self_interaction=False, bothways=True 356 | ) 357 | 358 | self.neighborlist.build( 359 | atoms.get_pbc(), atoms.get_cell(), atoms.get_positions() 360 | ) 361 | 362 | self.cutoff = cutoff 363 | self.atoms_positions = atoms.get_positions() 364 | self.atoms_cell = atoms.get_cell() 365 | 366 | is_probe = atoms.get_atomic_numbers() == 0 367 | self.num_atoms = len(atoms.get_positions()[~is_probe]) 368 | self.atomic_numbers = atoms.get_atomic_numbers() 369 | 370 | def get_neighbors(self, i, cutoff): 371 | assert ( 372 | cutoff == self.cutoff 373 | ), "Cutoff must be the same as used to initialise the neighborlist" 374 | 375 | indices, offsets = self.neighborlist.get_neighbors(i) 376 | 377 | offsets = offsets 378 | 379 | return indices, offsets 380 | 381 | def get_all_neighbors(self, cutoff, include_atomic_edges): 382 | probe_edges = [] 383 | 
probe_offsets = [] 384 | results = [self.neighborlist.get_neighbors(i) for i in range(self.num_atoms)] 385 | 386 | for i, (neigh_idx, neigh_offset) in enumerate(results): 387 | if not include_atomic_edges: 388 | neigh_atomic_species = self.atomic_numbers[neigh_idx] 389 | neigh_is_probe = neigh_atomic_species == 0 390 | neigh_idx = neigh_idx[neigh_is_probe] 391 | neigh_offset = neigh_offset[neigh_is_probe] 392 | 393 | atom_index = np.ones_like(neigh_idx) * i 394 | edges = np.stack((atom_index, neigh_idx), axis = 1) 395 | probe_edges.append(edges) 396 | probe_offsets.append(neigh_offset) 397 | 398 | edge_index = torch.tensor(np.concatenate(probe_edges, axis=0)).T 399 | 400 | cell_offsets = torch.tensor(np.concatenate(probe_offsets, axis=0)) 401 | 402 | return edge_index, cell_offsets 403 | 404 | class RadiusGraphPBCWrapper: 405 | """ 406 | Wraps a modified version of the neighbor-finding algorithm from ocp 407 | (ocp.ocpmodels.common.utils.radius_graph_pbc) 408 | The modifications restrict the neighbor-finding to atom-probe edges, 409 | which is more efficient for our purposes. 
410 | """ 411 | def __init__(self, radius, atom_pos, probe_pos, cell, pbc = [True, True, True]): 412 | self.cutoff = radius 413 | atom_indices = torch.arange(0, len(atom_pos), device = atom_pos.device) 414 | probe_indices = torch.arange(len(atom_pos), len(atom_pos)+len(probe_pos), device = probe_pos.device) 415 | batch_size = 1 416 | 417 | num_atoms = len(atom_pos) 418 | num_probes = len(probe_pos) 419 | num_total = num_atoms + num_probes 420 | num_combos = num_atoms * num_probes 421 | 422 | indices = np.arange(0, num_total, 1) 423 | 424 | index1 = torch.repeat_interleave(atom_indices, repeats=num_probes) 425 | index2 = probe_indices.repeat(num_atoms) 426 | 427 | pos1 = atom_pos[index1] 428 | pos2 = probe_pos[index2 - num_atoms] 429 | 430 | cross_a2a3 = torch.cross(cell[:, 1], cell[:, 2], dim=-1) 431 | cell_vol = torch.sum(cell[:, 0] * cross_a2a3, dim=-1, keepdim=True) 432 | 433 | if pbc[0]: 434 | inv_min_dist_a1 = torch.norm(cross_a2a3 / cell_vol, p=2, dim=-1) 435 | rep_a1 = torch.ceil(radius * inv_min_dist_a1) 436 | else: 437 | rep_a1 = cell.new_zeros(1) 438 | 439 | if pbc[1]: 440 | cross_a3a1 = torch.cross(cell[:, 2], cell[:, 0], dim=-1) 441 | inv_min_dist_a2 = torch.norm(cross_a3a1 / cell_vol, p=2, dim=-1) 442 | rep_a2 = torch.ceil(radius * inv_min_dist_a2) 443 | else: 444 | rep_a2 = cell.new_zeros(1) 445 | 446 | if pbc[2]: 447 | cross_a1a2 = torch.cross(cell[:, 0], cell[:, 1], dim=-1) 448 | inv_min_dist_a3 = torch.norm(cross_a1a2 / cell_vol, p=2, dim=-1) 449 | rep_a3 = torch.ceil(radius * inv_min_dist_a3) 450 | else: 451 | rep_a3 = cell.new_zeros(1) 452 | 453 | # Take the max over all images for uniformity. This is essentially padding. 454 | # Note that this can significantly increase the number of computed distances 455 | # if the required repetitions are very different between images 456 | # (which they usually are). Changing this to sparse (scatter) operations 457 | # might be worth the effort if this function becomes a bottleneck. 
458 | max_rep = [rep_a1.max(), rep_a2.max(), rep_a3.max()] 459 | 460 | # Tensor of unit cells 461 | cells_per_dim = [ 462 | torch.arange(-rep, rep + 1, dtype=torch.float, device = cell.device) 463 | for rep in max_rep 464 | ] 465 | unit_cell = torch.cartesian_prod(*cells_per_dim) 466 | 467 | num_cells = len(unit_cell) 468 | unit_cell_per_atom = unit_cell.view(1, num_cells, 3).repeat( 469 | len(index2), 1, 1 470 | ) 471 | unit_cell = torch.transpose(unit_cell, 0, 1) 472 | unit_cell_batch = unit_cell.view(1, 3, num_cells).expand( 473 | batch_size, -1, -1 474 | ) 475 | 476 | # Compute the x, y, z positional offsets for each cell in each image 477 | data_cell = torch.transpose(cell, 1, 2) 478 | pbc_offsets = torch.bmm(data_cell, unit_cell_batch) 479 | pbc_offsets_per_atom = torch.repeat_interleave( 480 | pbc_offsets, num_combos, dim=0 481 | ) 482 | 483 | # Expand the positions and indices for the 9 cells 484 | pos1 = pos1.view(-1, 3, 1).expand(-1, -1, num_cells) 485 | pos2 = pos2.view(-1, 3, 1).expand(-1, -1, num_cells) 486 | index1 = index1.view(-1, 1).repeat(1, num_cells).view(-1) 487 | index2 = index2.view(-1, 1).repeat(1, num_cells).view(-1) 488 | # Add the PBC offsets for the second atom 489 | pos2 = pos2 + pbc_offsets_per_atom 490 | 491 | # Compute the squared distance between atoms 492 | atom_distance_sqr = torch.sum((pos1 - pos2) ** 2, dim=1) 493 | atom_distance_sqr = atom_distance_sqr.view(-1) 494 | 495 | # Remove pairs that are too far apart 496 | mask_within_radius = torch.le(atom_distance_sqr, radius * radius) 497 | 498 | # Remove pairs with the same atoms (distance = 0.0) 499 | mask_not_same = torch.gt(atom_distance_sqr, 0.0001) 500 | mask = torch.logical_and(mask_within_radius, mask_not_same) 501 | index1 = torch.masked_select(index1, mask) 502 | index2 = torch.masked_select(index2, mask) 503 | 504 | unit_cell = torch.masked_select( 505 | unit_cell_per_atom.view(-1, 3), mask.view(-1, 1).expand(-1, 3) 506 | ) 507 | unit_cell = unit_cell.view(-1, 3) 508 | 
atom_distance_sqr = torch.masked_select(atom_distance_sqr, mask) 509 | 510 | self.edge_index = torch.stack((index1, index2)) 511 | 512 | self.offsets = unit_cell 513 | 514 | def get_all_neighbors(self, cutoff, include_atomic_edges = False): 515 | assert ( 516 | cutoff == self.cutoff 517 | ), "Cutoff must be the same as used to initialise the neighborlist" 518 | 519 | if include_atomic_edges: 520 | raise NotImplementedError 521 | 522 | return self.edge_index.to(torch.int64), self.offsets 523 | 524 | def calculate_grid_pos(shape, cell): 525 | # Ensure proper dimensions of cell 526 | if len(cell.shape) > 2: 527 | if cell.shape[0] == 1: 528 | cell = cell[0, :, :] 529 | else: 530 | raise NotImplementedError('calculate_grid_pos does not yet support batch sizes > 1') 531 | else: 532 | raise RuntimeError('Invalid unit cell definition for calculate_grid_pos') 533 | 534 | # Compute grid positions 535 | grid_pos = torch.cartesian_prod( 536 | torch.linspace(0, 1, shape[-3]+1, device = cell.device)[:-1], 537 | torch.linspace(0, 1, shape[-2]+1, device = cell.device)[:-1], 538 | torch.linspace(0, 1, shape[-1]+1, device = cell.device)[:-1], 539 | ) 540 | 541 | grid_pos = torch.mm(grid_pos, cell) 542 | grid_pos = grid_pos.reshape((*shape, 1, 3)) 543 | return grid_pos 544 | 545 | def get_edges_from_choice(probe_choice, grid_pos, atom_pos, cell, cutoff, include_atomic_edges, implementation): 546 | """ 547 | Given a list of chosen probes, compute all edges between the probes and atoms. 548 | Portions from DeepDFT 549 | """ 550 | grid_pos = grid_pos.reshape((*grid_pos.shape[-5:-2], 3)) 551 | probe_pos = grid_pos[probe_choice[0:3]] 552 | 553 | if implementation == 'ASE': 554 | neighborlist = AseNeighborListWrapper(cutoff, atom_pos, probe_pos, cell) 555 | 556 | elif implementation == 'RGPBC': 557 | neighborlist = RadiusGraphPBCWrapper(cutoff, atom_pos, probe_pos, cell) 558 | 559 | else: 560 | raise NotImplementedError('Unsupported implementation. 
Please choose from: ASE, RGPBC') 561 | 562 | edge_index, cell_offsets = neighborlist.get_all_neighbors(cutoff, include_atomic_edges) 563 | 564 | return edge_index, cell_offsets, probe_pos 565 | -------------------------------------------------------------------------------- /cdm/utils/vasp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import re 4 | import torch 5 | 6 | import ase 7 | from ase.calculators.vasp import Vasp 8 | from ase.db import connect 9 | 10 | from tqdm.notebook import tqdm 11 | 12 | from cdm.utils.inference import inference 13 | from cdm.utils.preprocessing import VaspChargeDensity 14 | 15 | def make_VASP_inputs( 16 | database_path, 17 | path): 18 | """ 19 | Make VASP inputs for all entries in a database. 20 | This script relies on a working ASE setup (i.e. the appropriate psuedopotentials are available). 21 | Once the directories are created you can run the VASP calculations according to your HPC system. 22 | 23 | Args: 24 | database_path (str): Path to database. 25 | path (str): Path to directory where vasp inputs will be written. 
26 | 27 | Returns: 28 | None 29 | """ 30 | database = connect(database_path) 31 | 32 | for entry in database.select(): 33 | atoms = entry.toatoms() 34 | 35 | # Setup vasp calculator 36 | atoms.calc = Vasp( 37 | encut=350, 38 | xc='rpbe', 39 | gga='RP', 40 | lcharg=True, 41 | lwave=False, 42 | nelm=120, 43 | algo='Normal', 44 | ) 45 | 46 | atoms.calc.directory = os.path.join(path, str(entry.id)) 47 | 48 | # Write vasp inputs 49 | atoms.calc.write_input(atoms) 50 | 51 | def write_CHGCAR_like( 52 | model, 53 | input_CHGCAR_path, 54 | output_CHGCAR_path, 55 | batch_size = 100000, 56 | use_tqdm = False, 57 | device = 'cuda', 58 | ): 59 | vcd = VaspChargeDensity(input_CHGCAR_path) 60 | vcd.chg[0] = inference( 61 | atoms = vcd.atoms[0], 62 | model = model, 63 | grid = vcd.chg[0].shape, 64 | atom_cutoff = model.atom_message_model.cutoff, 65 | probe_cutoff = model.probe_message_model.cutoff, 66 | batch_size = batch_size, 67 | use_tqdm = use_tqdm, 68 | device = device, 69 | total_density = torch.sum(torch.tensor(vcd.chg[0])), 70 | ) 71 | vcd.write(output_CHGCAR_path) 72 | 73 | def setup_VASP_experiment( 74 | input_augs_path, 75 | output_path, 76 | input_density_path = None, 77 | model = None, 78 | batch_size = 100000, 79 | device = 'cuda', 80 | ): 81 | paths = os.listdir(input_augs_path) 82 | assert set(paths) == set(os.listdir(output_path)) 83 | 84 | if model is None: 85 | assert input_density_path is not None 86 | assert set(paths) == set(os.listdir(input_density_path)) 87 | 88 | for i in tqdm(paths): 89 | vcd = VaspChargeDensity(os.path.join(input_density_path, i, 'CHGCAR')) 90 | vcd.aug_dict = VaspChargeDensity(os.path.join(input_augs_path, i, 'CHGCAR')).aug_dict 91 | vcd.write(os.path.join(output_path, i, 'CHGCAR')) 92 | 93 | elif (input_density_path is None) or (input_density_path == input_augs_path): 94 | assert model is not None 95 | 96 | for i in tqdm(paths): 97 | write_CHGCAR_like( 98 | model = model, 99 | input_CHGCAR_path = os.path.join(input_augs_path, i, 
'CHGCAR'), 100 | output_CHGCAR_path = os.path.join(output_path, i, 'CHGCAR'), 101 | batch_size = batch_size, 102 | device = device, 103 | use_tqdm = True, 104 | ) 105 | else: 106 | assert set(paths) == set(os.listdir(input_density_path)) 107 | 108 | for i in tqdm(paths): 109 | 110 | vcd = VaspChargeDensity(os.path.join(input_density_path, i, 'CHGCAR')) 111 | vcd.chg[0] = inference( 112 | atoms = vcd.atoms[0], 113 | model = model, 114 | grid = vcd.chg[0].shape, 115 | atom_cutoff = model.atom_message_model.cutoff, 116 | probe_cutoff = modle.probe_message_model.cutoff, 117 | batch_size = batch_size, 118 | use_tqdm = False, 119 | device = device, 120 | total_density = torch.sum(torch.tensor(vcd.chg[0])), 121 | ) 122 | vcd.aug_dict = VaspChargeDensity(os.path.join(input_augs_path, i, 'CHGCAR')).aug_dict 123 | vcd.write(os.path.join(output_path, i, 'CHGCAR')) 124 | 125 | def kubernetes_VASP_batch( 126 | path, 127 | base_params, 128 | exp_tag, 129 | namespace, 130 | cpu_req = 4, 131 | cpu_lim = 16, 132 | mem_req = '8Gi', 133 | mem_lim = '16Gi', 134 | VASP_mp_threads = 16, 135 | select_only = None, 136 | ): 137 | ''' 138 | args: 139 | path (str): the directory that contains your VASP directories 140 | base_params: a yaml of parameters obtained from your template file 141 | or a path to the template file 142 | exp_tag: a string to identify this experiment/run/batch 143 | cpu_req: how many CPU cores to request for each individual VASP job 144 | cpu_lim: how many CPU cores to limit the job to (if extra cores are available on the node) 145 | mem_req: how much memory to reserve for each individual VASP job 146 | mem_lim: how much memory to limit the job to (if extra memory is avaialable on the node) 147 | VASP_mp_threads: how many threads to launch for your VASP calculation. 148 | Should be between (or equal to) cpu_req and cpu_lim 149 | 150 | This script is designed to help submit calculations to Laikapack, a 151 | Kubernetes cluster at CMU. 
It is likely that your HPC setup is 152 | different. In that case, you should use whatever submission system is 153 | recommended for VASP calculations on your resources. 154 | 155 | If None is passed in for any of the resource values, this script will not attempt to replace that 156 | value in the template file. In this case, your template file should specify the resource amount. 157 | 158 | ''' 159 | 160 | if select_only is not None: 161 | paths = select_only 162 | else: 163 | paths = os.listdir(path) 164 | 165 | if isinstance(base_params, str): 166 | with open(base_params, 'r') as stream: 167 | base_params = yaml.safe_load(stream) 168 | 169 | params = base_params.copy() 170 | params['metadata']['namespace'] = namespace 171 | params['spec']['template']['spec']['containers'] = [params['spec']['template']['spec']['containers']] 172 | container = params['spec']['template']['spec']['containers'][0] 173 | 174 | # Set the resource usage, which is assumed to be the same for all jobs 175 | if cpu_lim is not None: 176 | container['resources']['limits']['cpu'] = cpu_lim 177 | if cpu_req is not None: 178 | container['resources']['requests']['cpu'] = cpu_req 179 | if VASP_mp_threads is not None: 180 | container['args'][0] = re.sub('-np \d+', '-np '+str(VASP_mp_threads), container['args'][0]) 181 | 182 | if mem_lim is not None: 183 | container['resources']['limits']['memory'] = mem_lim 184 | if mem_req is not None: 185 | container['resources']['requests']['memory'] = mem_req 186 | 187 | for fid, directory in enumerate(tqdm(paths)): 188 | # Set the metadata for each job 189 | params['metadata']['name'] = exp_tag + '-' + re.sub('_', '-', directory) 190 | container['workingDir'] = path + '/' + directory 191 | container['name'] = exp_tag + '-' + re.sub('_', '-', directory) 192 | 193 | # Write the job specification file 194 | with open('job.yaml', 'w') as config_file: 195 | yaml.dump(params, config_file, default_flow_style = False) 196 | 197 | # Submit the job 198 | 
os.system('kubectl apply -f job.yaml > /dev/null') 199 | 200 | def read_experiment( 201 | path, 202 | loud = False, 203 | skip = None, 204 | ): 205 | paths = os.listdir(path) 206 | 207 | if skip is not None: 208 | for i in skip: 209 | paths.remove(i) 210 | 211 | results = {} 212 | 213 | for fid, directory in enumerate(tqdm(paths)): 214 | if loud: 215 | print(directory) 216 | 217 | if directory not in []: 218 | 219 | try: 220 | with open(path + '/' + directory + '/' + 'vasp.out') as vaspout: 221 | 222 | ncg = 0 223 | DAVS = [line for line in vaspout if ('DAV' in line)] 224 | for idx, line in enumerate(DAVS): 225 | line = line.split(' ') 226 | line = [string for string in line if (string != '')] 227 | DAVS[idx] = line 228 | ncg += int(line[5]) 229 | 230 | n = len(DAVS) 231 | E = float(DAVS[-1][2]) 232 | 233 | results[directory] = {'num_scf_steps': n, 'Energy': E, 'ncg': ncg} 234 | except Exception as exception: 235 | print(f'Error: could not read vasp.out from: {path}/{directory}') 236 | print(str(exception)) 237 | 238 | return results 239 | 240 | def compare_VASP_experiments( 241 | baseline, 242 | trial, 243 | skip = None, 244 | ): 245 | baseline = read_experiment(baseline, skip = skip) 246 | trial = read_experiment(trial, skip = skip) 247 | 248 | assert set(baseline.keys()) == set(trial.keys()) 249 | 250 | E_diffs = [] 251 | for i in baseline.keys(): 252 | E_diffs.append(trial[i]['Energy'] - baseline[i]['Energy']) 253 | ncg_fraction = sum(x['ncg'] for x in trial.values()) / sum(x['ncg'] for x in baseline.values()) 254 | indiv_ncg_fractions = [x['ncg'] / y['ncg'] for x, y in zip(trial.values(), baseline.values())] 255 | 256 | faster_fraction = len([x for x in trial.keys() if trial[x]['ncg'] < baseline[x]['ncg']]) / len(trial.keys()) 257 | slower_fraction = len([x for x in trial.keys() if trial[x]['ncg'] > baseline[x]['ncg']]) / len(trial.keys()) 258 | equal_fraction = len([x for x in trial.keys() if trial[x]['ncg'] == baseline[x]['ncg']]) / len(trial.keys()) 259 | 
260 | return { 261 | 'faster_fraction': faster_fraction, 262 | 'slower_fraction': slower_fraction, 263 | 'equal_fraction': equal_fraction, 264 | 'ncg_fraction': ncg_fraction, 265 | 'E_diffs': E_diffs, 266 | 'indiv_ncg_fractions': indiv_ncg_fractions, 267 | } 268 | -------------------------------------------------------------------------------- /configs/common/common.yml: -------------------------------------------------------------------------------- 1 | trainer: cdm.charge_trainer.ChargeTrainer 2 | local_rank: 0 3 | amp: True 4 | seed: 2 5 | 6 | task: 7 | description: Predicting electron density from atomic positions 8 | strict_load: False -------------------------------------------------------------------------------- /configs/datasets/10k.yml: -------------------------------------------------------------------------------- 1 | task: 2 | dataset: dir_of_chgcars 3 | 4 | dataset: 5 | - src: /home/jovyan/shared-scratch/ethan/density/33k_sample/ 6 | split: /home/jovyan/shared-scratch/ethan/density/10k_split.txt 7 | - src: /home/jovyan/shared-scratch/ethan/val_dataset_gen/val_ood_both/val -------------------------------------------------------------------------------- /configs/datasets/1k.yml: -------------------------------------------------------------------------------- 1 | task: 2 | dataset: dir_of_chgcars 3 | 4 | dataset: 5 | - src: /home/jovyan/shared-scratch/ethan/density/33k_sample/ 6 | split: /home/jovyan/shared-scratch/ethan/density/1k_split.txt 7 | - src: /home/jovyan/shared-scratch/ethan/val_dataset_gen/val_ood_both/val -------------------------------------------------------------------------------- /configs/datasets/33k.yml: -------------------------------------------------------------------------------- 1 | task: 2 | dataset: dir_of_chgcars 3 | 4 | dataset: 5 | - src: /home/jovyan/shared-scratch/ethan/density/33k_sample/ 6 | - src: /home/jovyan/shared-scratch/ethan/val_dataset_gen/val_ood_both/val 
-------------------------------------------------------------------------------- /configs/datasets/3k.yml: -------------------------------------------------------------------------------- 1 | task: 2 | dataset: dir_of_chgcars 3 | 4 | dataset: 5 | - src: /home/jovyan/shared-scratch/ethan/density/33k_sample/ 6 | split: /home/jovyan/shared-scratch/ethan/density/3k_split.txt 7 | - src: /home/jovyan/shared-scratch/ethan/val_dataset_gen/val_ood_both/val -------------------------------------------------------------------------------- /configs/models/painn/painn-small.yml: -------------------------------------------------------------------------------- 1 | model: 2 | name: charge_model 3 | enforce_zero_for_disconnected_probes: True 4 | enforce_charge_conservation: True 5 | freeze_atomic: False 6 | 7 | atom_model_config: 8 | name: painn_charge 9 | num_layers: 3 10 | hidden_channels: 128 11 | num_rbf: 32 12 | cutoff: 6 13 | 14 | probe_model_config: 15 | name: painn_charge 16 | num_layers: 2 17 | hidden_channels: 128 18 | max_neighbors: 20000 19 | num_rbf: 32 20 | cutoff: 6 21 | 22 | otf_pga_config: 23 | num_probes: 30000 24 | cutoff: 6 25 | assert_min_edges: 200 -------------------------------------------------------------------------------- /configs/models/schnet/schnet-large.yml: -------------------------------------------------------------------------------- 1 | model: 2 | name: charge_model 3 | enforce_zero_for_disconnected_probes: True 4 | enforce_charge_conservation: True 5 | freeze_atomic: False 6 | 7 | atom_model_config: 8 | name: schnet_charge 9 | num_interactions: 6 10 | hidden_channels: 256 11 | cutoff: 6 12 | 13 | probe_model_config: 14 | name: schnet_charge 15 | num_interactions: 6 16 | hidden_channels: 256 17 | cutoff: 6 18 | 19 | otf_pga_config: 20 | num_probes: 60000 21 | cutoff: 6 22 | assert_min_edges: 10 -------------------------------------------------------------------------------- /configs/models/schnet/schnet-small.yml: 
--------------------------------------------------------------------------------
model:
  name: charge_model
  enforce_zero_for_disconnected_probes: False
  enforce_charge_conservation: True
  freeze_atomic: False

  atom_model_config:
    name: schnet_charge
    num_interactions: 5
    hidden_channels: 128
    num_filters: 32
    cutoff: 5

  probe_model_config:
    name: schnet_charge
    num_interactions: 3
    hidden_channels: 128
    num_filters: 32
    cutoff: 5

  otf_pga_config:
    num_probes: 20000
    cutoff: 5
--------------------------------------------------------------------------------
/configs/optimizers/adam-standard.yml:
--------------------------------------------------------------------------------
optim:
  optimizer: Adam
  num_workers: 6
  lr_initial: 0.00001
  scheduler: ReduceLROnPlateau
  mode: min
  factor: 0.97
  patience: 1
  max_epochs: 1000
  loss_charge: normed_mae
  load_balancing: False
  eval_every: 40000
--------------------------------------------------------------------------------
/configs/template.yml:
--------------------------------------------------------------------------------
includes:
  - /home/jovyan/charge-density-models/configs/common/common.yml
  - /home/jovyan/charge-density-models/configs/datasets/3k.yml
  - /home/jovyan/charge-density-models/configs/models/schnet/schnet-large.yml
  - /home/jovyan/charge-density-models/configs/optimizers/adam-standard.yml
--------------------------------------------------------------------------------
/data/1k_split.txt:
--------------------------------------------------------------------------------
random1333062_168
random1597616_105
random1397381_274
random1616922_68
random1348777_130
random1651049_86
random1276563_25
random1430227_287
random1212243_213
random1264690_172
random1291727_52
random1011949_33
random1333869_159
14 | random1490761_84 15 | random1655634_57 16 | random1684061_79 17 | random1612072_126 18 | random1053359_51 19 | random1479616_16 20 | random1344801_342 21 | random1670338_96 22 | random1682395_205 23 | random1603291_143 24 | random1659376_53 25 | random1373870_125 26 | random1525044_9 27 | random1655025_68 28 | random1039153_160 29 | random1478658_54 30 | random1304301_192 31 | random1383713_243 32 | random1287653_121 33 | random1641729_289 34 | random1101472_75 35 | random1159326_8 36 | random1192158_260 37 | random1400624_313 38 | random1233254_195 39 | random1545729_22 40 | random1535709_116 41 | random1592415_86 42 | random1050808_328 43 | random1265469_107 44 | random1438365_106 45 | random1041525_211 46 | random1376277_88 47 | random1000808_129 48 | random1537308_39 49 | random1531655_95 50 | random1419711_159 51 | random1597143_15 52 | random1254568_108 53 | random1109203_98 54 | random1474573_233 55 | random1013854_55 56 | random1599540_114 57 | random1274958_14 58 | random1172919_148 59 | random1466540_113 60 | random1404961_357 61 | random1285112_84 62 | random1581334_34 63 | random1525119_121 64 | random1531517_45 65 | random1179544_218 66 | random1211394_67 67 | random1190076_38 68 | random1459161_82 69 | random1498603_289 70 | random1501303_282 71 | random1676738_124 72 | random1186585_335 73 | random1462988_160 74 | random1376635_196 75 | random1364311_177 76 | random1407076_557 77 | random1502865_653 78 | random1287901_68 79 | random1679856_30 80 | random1098751_32 81 | random1642172_112 82 | random1056268_243 83 | random1360892_110 84 | random1574143_53 85 | random1436093_115 86 | random1517321_293 87 | random1537334_150 88 | random1445174_45 89 | random1441245_203 90 | random1238209_378 91 | random1227726_223 92 | random1252891_60 93 | random1563941_55 94 | random1469440_45 95 | random1439975_144 96 | random1574173_36 97 | random1315450_244 98 | random1670891_103 99 | random1634526_163 100 | random1394725_228 101 | random1288643_278 102 | 
random1192074_109 103 | random1639675_102 104 | random1627431_14 105 | random1200587_126 106 | random1574080_208 107 | random1569241_225 108 | random1461641_146 109 | random1451510_114 110 | random1433886_60 111 | random1419595_52 112 | random1343004_221 113 | random1634849_60 114 | random1407393_55 115 | random1148956_165 116 | random1580721_102 117 | random1151041_94 118 | random1496082_85 119 | random1272580_59 120 | random1347543_198 121 | random1128731_54 122 | random1526669_56 123 | random1022000_237 124 | random1501916_166 125 | random1528759_52 126 | random1211049_339 127 | random1309401_117 128 | random1305147_230 129 | random1467680_210 130 | random1545236_208 131 | random153372_10 132 | random1249775_124 133 | random1339043_27 134 | random1600801_198 135 | random1556093_153 136 | random1387459_119 137 | random1393293_100 138 | random1182895_54 139 | random1173798_156 140 | random1355964_28 141 | random1539730_313 142 | random1647840_209 143 | random1673865_273 144 | random1556384_34 145 | random1547516_193 146 | random1477672_444 147 | random1196482_227 148 | random1236996_41 149 | random1630548_166 150 | random1652866_273 151 | random1527986_240 152 | random1094717_73 153 | random1111045_170 154 | random1628671_235 155 | random1251964_176 156 | random1412133_4 157 | random1294787_247 158 | random1500370_164 159 | random1668596_39 160 | random1276388_145 161 | random1337371_19 162 | random1248376_49 163 | random1496931_84 164 | random1397435_305 165 | random1671739_483 166 | random1220125_18 167 | random1256873_19 168 | random1484158_58 169 | random1350130_215 170 | random1205446_15 171 | random1649669_14 172 | random1521785_521 173 | random1508186_256 174 | random1192469_414 175 | random1251790_4 176 | random1205847_162 177 | random1363094_465 178 | random1331799_64 179 | random1505722_194 180 | random1499220_96 181 | random1203480_110 182 | random1288780_145 183 | random1197142_120 184 | random1494147_40 185 | random1616928_178 186 | random1597272_89 
187 | random1451553_222 188 | random1503516_90 189 | random1170857_16 190 | random1436494_265 191 | random1635935_4 192 | random1489508_579 193 | random1330029_362 194 | random1065239_13 195 | random1569858_57 196 | random1635380_6 197 | random1034171_42 198 | random1319150_150 199 | random1171205_110 200 | random1462111_126 201 | random1281565_89 202 | random1302093_203 203 | random1488549_198 204 | random1346344_471 205 | random1577011_151 206 | random1379003_33 207 | random1519696_196 208 | random1634731_2 209 | random1329127_23 210 | random1248069_26 211 | random1584020_240 212 | random109729_31 213 | random1301784_29 214 | random1414684_60 215 | random1361173_75 216 | random1554514_61 217 | random1578500_574 218 | random1507635_183 219 | random1356042_21 220 | random1036401_190 221 | random1525043_263 222 | random1510290_247 223 | random1355869_92 224 | random1664931_128 225 | random1567256_107 226 | random1362481_176 227 | random1623368_190 228 | random1202176_4 229 | random1253624_131 230 | random1539036_109 231 | random1411934_304 232 | random1670123_82 233 | random1654803_324 234 | random1321660_203 235 | random1626349_32 236 | random1628281_32 237 | random1681210_30 238 | random1668262_48 239 | random1161662_95 240 | random1042029_129 241 | random1594343_168 242 | random1556483_61 243 | random1603048_360 244 | random1346377_146 245 | random1193048_126 246 | random1403176_28 247 | random1051741_108 248 | random1225725_57 249 | random1674236_119 250 | random1670590_137 251 | random1559931_193 252 | random1349800_281 253 | random1008007_87 254 | random1554177_65 255 | random1628629_302 256 | random1382364_102 257 | random1294912_93 258 | random1314542_11 259 | random1503802_50 260 | random1368900_4 261 | random1306400_173 262 | random1301941_243 263 | random1021108_39 264 | random1484388_295 265 | random1186075_46 266 | random1498251_328 267 | random1304052_186 268 | random1636314_73 269 | random1335937_183 270 | random1630812_37 271 | random1015147_99 272 | 
random1389842_145 273 | random1350498_325 274 | random1346663_69 275 | random1205731_63 276 | random1652682_147 277 | random1465409_42 278 | random1280009_71 279 | random1570473_129 280 | random1002110_45 281 | random1295646_115 282 | random1113452_163 283 | random1173207_70 284 | random1295012_137 285 | random1329433_41 286 | random1665406_109 287 | random1500087_206 288 | random1402833_273 289 | random1460797_148 290 | random1480988_293 291 | random1392911_61 292 | random1600601_129 293 | random1665925_62 294 | random1389157_154 295 | random1468984_401 296 | random1382020_137 297 | random1330889_52 298 | random1654278_62 299 | random1075550_79 300 | random1380638_144 301 | random1253866_97 302 | random1617465_26 303 | random1411714_189 304 | random1525856_48 305 | random1330242_268 306 | random1383761_25 307 | random1626233_59 308 | random1232603_2 309 | random1217128_7 310 | random1208359_2 311 | random1297675_109 312 | random1606838_93 313 | random1013220_205 314 | random1511911_4 315 | random1392728_195 316 | random1661686_201 317 | random1415496_179 318 | random1386517_46 319 | random1155911_107 320 | random1615400_102 321 | random1469208_250 322 | random1476191_199 323 | random1302765_143 324 | random1673253_80 325 | random1488681_16 326 | random1619805_699 327 | random1502929_127 328 | random1529083_199 329 | random1330119_153 330 | random1011673_225 331 | random1420253_174 332 | random1590090_7 333 | random1392968_85 334 | random1144568_191 335 | random1220421_49 336 | random1523809_32 337 | random1646767_142 338 | random1041725_74 339 | random1536958_148 340 | random1657319_44 341 | random1460906_69 342 | random1254205_15 343 | random1245098_183 344 | random1400907_120 345 | random1555909_108 346 | random1431348_92 347 | random1382367_180 348 | random1606719_58 349 | random1139017_40 350 | random1418273_238 351 | random1441712_347 352 | random1610585_137 353 | random1276563_195 354 | random1263564_58 355 | random1086970_58 356 | random1623296_106 357 | 
random1355735_161 358 | random1241160_399 359 | random1027047_110 360 | random1658985_23 361 | random1325845_135 362 | random1446210_108 363 | random1629395_29 364 | random1622763_686 365 | random1626418_186 366 | random1335321_81 367 | random1243444_191 368 | random1303198_84 369 | random1218018_157 370 | random1241613_12 371 | random1088720_218 372 | random1030742_81 373 | random1315060_181 374 | random1556722_287 375 | random1558064_47 376 | random1207009_219 377 | random1670807_37 378 | random1409895_137 379 | random1021016_183 380 | random1201579_251 381 | random1283271_203 382 | random1357766_49 383 | random1547640_26 384 | random1195777_184 385 | random1498093_19 386 | random1510891_276 387 | random1430658_12 388 | random1501831_297 389 | random1163532_103 390 | random1459702_540 391 | random1309660_53 392 | random1502628_110 393 | random1063014_279 394 | random1350798_46 395 | random1598976_407 396 | random1430820_94 397 | random1411782_59 398 | random1434476_192 399 | random1396816_66 400 | random1498635_33 401 | random1322022_103 402 | random1299881_32 403 | random1654006_66 404 | random1120856_39 405 | random1081661_160 406 | random1548205_121 407 | random1415413_347 408 | random1627848_70 409 | random1174291_77 410 | random1320947_146 411 | random1642010_287 412 | random1481822_341 413 | random1231747_215 414 | random1557878_63 415 | random1122676_1 416 | random1590978_150 417 | random1190729_181 418 | random1460834_96 419 | random1355776_376 420 | random1528213_150 421 | random1192307_22 422 | random1557527_191 423 | random1184687_250 424 | random1222752_127 425 | random1387639_230 426 | random1303686_100 427 | random1289928_166 428 | random1605974_89 429 | random1517803_11 430 | random1449597_74 431 | random1282046_472 432 | random1225142_133 433 | random1032044_106 434 | random1432951_113 435 | random1555345_102 436 | random1333173_73 437 | random1672033_330 438 | random1435350_239 439 | random1659354_105 440 | random1454955_41 441 | 
random1462995_182 442 | random1201606_214 443 | random1205006_10 444 | random1635386_175 445 | random1383238_106 446 | random1406296_438 447 | random1553858_229 448 | random1568453_19 449 | random1650429_163 450 | random1153339_34 451 | random1159663_209 452 | random1283904_75 453 | random1196702_87 454 | random1214859_83 455 | random1644117_113 456 | random1455432_121 457 | random1317620_30 458 | random1545120_147 459 | random1463162_173 460 | random1226560_51 461 | random1577092_126 462 | random1608375_186 463 | random1281108_3 464 | random1320044_70 465 | random1533040_32 466 | random1197154_29 467 | random1658393_210 468 | random1596088_192 469 | random1083051_146 470 | random1568656_106 471 | random1622137_54 472 | random1517264_76 473 | random1682374_146 474 | random1484997_73 475 | random1571550_17 476 | random1059420_94 477 | random1179829_287 478 | random1423935_58 479 | random1256404_153 480 | random1393696_140 481 | random1257348_62 482 | random1513678_102 483 | random1524280_38 484 | random1616351_199 485 | random1674356_330 486 | random1346966_13 487 | random1492463_140 488 | random1368637_174 489 | random1350686_258 490 | random1424922_21 491 | random1653223_138 492 | random1258914_182 493 | random1184670_51 494 | random1448006_85 495 | random1624745_19 496 | random1530386_246 497 | random1174969_41 498 | random1299901_17 499 | random1319888_105 500 | random1466841_69 501 | random1086442_133 502 | random1093042_184 503 | random1450445_137 504 | random1189674_216 505 | random1506417_156 506 | random1451199_169 507 | random1637902_48 508 | random1570185_0 509 | random1122602_53 510 | random1256019_974 511 | random1467695_41 512 | random1302318_143 513 | random1248435_81 514 | random1569990_123 515 | random1061875_147 516 | random1523174_129 517 | random1239053_110 518 | random1198289_267 519 | random1246103_474 520 | random1135316_64 521 | random1517198_93 522 | random1548205_237 523 | random1039282_64 524 | random1251193_21 525 | random1407748_553 526 
| random1632393_322 527 | random1585064_96 528 | random1646079_2 529 | random1205825_148 530 | random1677232_153 531 | random1328728_68 532 | random1406965_19 533 | random1030630_170 534 | random1640615_105 535 | random1589490_124 536 | random1173132_61 537 | random1508378_69 538 | random1473837_260 539 | random1211978_124 540 | random1605785_136 541 | random1475459_74 542 | random1225913_168 543 | random1106009_135 544 | random1480393_61 545 | random1228100_105 546 | random1433872_169 547 | random1361767_23 548 | random1335967_4 549 | random1536436_14 550 | random1645430_21 551 | random1424448_110 552 | random1602207_27 553 | random1438722_51 554 | random1441697_89 555 | random1186017_139 556 | random1182402_72 557 | random1473064_186 558 | random1398334_419 559 | random1511765_10 560 | random1273046_156 561 | random1674251_36 562 | random1406960_244 563 | random1463574_75 564 | random1013259_67 565 | random1011512_163 566 | random1494712_75 567 | random1357578_330 568 | random1405238_276 569 | random1216849_85 570 | random1305299_13 571 | random1074870_157 572 | random1495704_44 573 | random1414571_66 574 | random1467077_44 575 | random1210070_216 576 | random1313069_194 577 | random1306066_145 578 | random1233247_277 579 | random1603801_48 580 | random1574406_49 581 | random1227177_60 582 | random1472679_255 583 | random1276158_72 584 | random1454001_76 585 | random1558149_31 586 | random1620120_338 587 | random1545818_315 588 | random1069571_226 589 | random1528603_29 590 | random1505313_103 591 | random1226734_2 592 | random1067295_49 593 | random1152633_80 594 | random1284515_69 595 | random1017796_57 596 | random1650580_101 597 | random1180952_272 598 | random1679718_312 599 | random1498403_71 600 | random1560021_194 601 | random1036713_58 602 | random1661400_465 603 | random1228876_93 604 | random1353291_78 605 | random1294577_97 606 | random1219541_56 607 | random1318797_108 608 | random1079699_131 609 | random1190847_64 610 | random1203359_44 611 | 
random1540033_150 612 | random1081528_99 613 | random1354856_5 614 | random1027004_151 615 | random1210045_279 616 | random1676023_83 617 | random1406748_95 618 | random1491047_7 619 | random1218984_30 620 | random1396091_118 621 | random1278094_99 622 | random1443765_111 623 | random1413291_62 624 | random1375449_54 625 | random1563803_142 626 | random1393480_22 627 | random1594710_56 628 | random1648145_191 629 | random1375125_102 630 | random1518417_244 631 | random1527317_28 632 | random1428304_238 633 | random1597884_107 634 | random1218112_147 635 | random1188775_75 636 | random1669866_239 637 | random1604193_421 638 | random1333232_4 639 | random1296622_168 640 | random1317734_23 641 | random1650250_19 642 | random1558542_2 643 | random1460793_106 644 | random1479211_79 645 | random1577706_54 646 | random1478679_195 647 | random1577448_77 648 | random1390981_41 649 | random1527103_423 650 | random1588269_122 651 | random1478161_496 652 | random1562049_28 653 | random1488790_8 654 | random1569074_212 655 | random1581816_119 656 | random1342205_35 657 | random1276011_207 658 | random1478229_402 659 | random1537040_167 660 | random1366015_558 661 | random1559558_558 662 | random1629993_13 663 | random1262446_58 664 | random1480438_55 665 | random1100084_73 666 | random1633332_21 667 | random1309826_127 668 | random1427128_35 669 | random1339463_62 670 | random1632555_110 671 | random1420264_141 672 | random1138817_178 673 | random1410846_31 674 | random1044859_66 675 | random1581356_194 676 | random1339110_182 677 | random1226134_66 678 | random1447093_205 679 | random1280726_148 680 | random1591243_381 681 | random1358519_38 682 | random1176875_182 683 | random1453136_153 684 | random1471134_149 685 | random1508171_156 686 | random1369480_36 687 | random1150354_107 688 | random1225106_245 689 | random1353200_114 690 | random1067264_89 691 | random1608486_103 692 | random1423707_10 693 | random1406019_13 694 | random1375439_117 695 | random1438750_323 696 | 
random1629440_2 697 | random1289391_211 698 | random1346043_516 699 | random1028998_91 700 | random1492009_29 701 | random1514467_63 702 | random1162538_49 703 | random1172725_20 704 | random1606910_63 705 | random1390775_95 706 | random1161428_194 707 | random1595636_170 708 | random1473968_166 709 | random1682565_70 710 | random1475166_85 711 | random1336832_123 712 | random1456128_128 713 | random1153564_142 714 | random1294159_243 715 | random1165723_173 716 | random1615717_190 717 | random1587774_72 718 | random1047977_141 719 | random1066886_59 720 | random1445232_36 721 | random1653907_665 722 | random1280884_64 723 | random1572603_149 724 | random1146320_99 725 | random1439581_30 726 | random1196549_189 727 | random1609673_100 728 | random1102704_131 729 | random1189198_178 730 | random1473525_126 731 | random1070020_155 732 | random1019209_106 733 | random1632443_322 734 | random1509359_395 735 | random1180952_447 736 | random1508200_98 737 | random1070212_76 738 | random1511587_60 739 | random1462991_99 740 | random1295243_140 741 | random1514578_73 742 | random1324735_40 743 | random1199464_13 744 | random1337946_177 745 | random1017070_120 746 | random1273042_200 747 | random1373171_31 748 | random1394266_139 749 | random1387247_201 750 | random1652216_124 751 | random155825_36 752 | random1592626_38 753 | random1212211_461 754 | random1028046_71 755 | random1244592_148 756 | random1370851_488 757 | random1555121_23 758 | random1210746_53 759 | random1309827_48 760 | random1431067_140 761 | random1413384_5 762 | random1656184_415 763 | random1506424_123 764 | random1519254_407 765 | random1436128_135 766 | random1187565_196 767 | random1417640_69 768 | random1335415_69 769 | random1042427_162 770 | random1456892_179 771 | random1659655_167 772 | random1096339_112 773 | random1332944_58 774 | random1312708_102 775 | random1365090_179 776 | random1218499_152 777 | random1286661_94 778 | random1269109_15 779 | random1087796_106 780 | random1504010_68 781 | 
random151803_109 782 | random1341801_98 783 | random1678804_140 784 | random1507519_196 785 | random1541857_277 786 | random1152134_45 787 | random1045948_122 788 | random1453116_159 789 | random1301929_303 790 | random1497937_16 791 | random1136246_57 792 | random1639242_126 793 | random1377801_495 794 | random1024805_57 795 | random1658876_170 796 | random1412051_263 797 | random1280764_18 798 | random108395_180 799 | random1573281_133 800 | random1636400_5 801 | random1349654_133 802 | random1319183_176 803 | random1184268_0 804 | random1644953_84 805 | random1391768_302 806 | random1204533_89 807 | random1611541_205 808 | random1459387_22 809 | random1618450_73 810 | random1421846_145 811 | random1576815_195 812 | random1238959_12 813 | random1637775_311 814 | random1204985_287 815 | random1606059_71 816 | random1451498_268 817 | random1607935_135 818 | random1463883_280 819 | random1234335_210 820 | random1053479_122 821 | random1456467_284 822 | random1431953_19 823 | random1066728_122 824 | random1217463_110 825 | random1427187_178 826 | random1281529_41 827 | random1601586_104 828 | random1070269_100 829 | random1680033_189 830 | random1640381_79 831 | random1372861_368 832 | random1315657_195 833 | random1539323_90 834 | random1602529_165 835 | random1523340_223 836 | random16327_137 837 | random1203861_204 838 | random1642522_181 839 | random1236245_65 840 | random1517121_116 841 | random1075498_164 842 | random1421802_146 843 | random1495183_143 844 | random1419287_246 845 | random1252863_215 846 | random1665242_138 847 | random1244318_365 848 | random1405476_27 849 | random1531888_93 850 | random1545448_144 851 | random1628185_36 852 | random1622215_351 853 | random1221097_65 854 | random1433685_44 855 | random1276630_324 856 | random1009815_98 857 | random1605528_102 858 | random1606829_407 859 | random1442846_80 860 | random1666916_81 861 | random1641678_168 862 | random1338001_356 863 | random1463092_26 864 | random1245477_206 865 | random1449401_1 
866 | random1260792_185 867 | random1083058_41 868 | random1566197_77 869 | random1431025_124 870 | random1407333_330 871 | random1369716_70 872 | random1327093_260 873 | random1417108_288 874 | random1680367_116 875 | random1335660_139 876 | random1423631_255 877 | random1033391_22 878 | random1292392_337 879 | random1373575_165 880 | random1263383_195 881 | random1506170_234 882 | random1441661_75 883 | random1509375_17 884 | random1647329_455 885 | random1028699_35 886 | random1175288_36 887 | random1028748_55 888 | random1682172_59 889 | random1673253_40 890 | random1386923_159 891 | random1491649_31 892 | random1391623_156 893 | random1224244_102 894 | random1486993_400 895 | random1486722_98 896 | random1093465_60 897 | random1633417_115 898 | random1301386_8 899 | random1230754_75 900 | random1665377_217 901 | random1310231_52 902 | random1526694_244 903 | random1469353_130 904 | random1432348_97 905 | random1532416_50 906 | random1619630_35 907 | random1519064_45 908 | random1645845_96 909 | random1288643_11 910 | random1543861_15 911 | random1660035_40 912 | random1373415_120 913 | random1222111_170 914 | random1038152_178 915 | random1489836_20 916 | random1386734_27 917 | random1530595_282 918 | random1295612_19 919 | random1196793_123 920 | random1417535_456 921 | random1614149_210 922 | random1655666_114 923 | random1458570_25 924 | random1326300_285 925 | random1365109_225 926 | random1515612_143 927 | random1251368_130 928 | random1332646_303 929 | random1492644_296 930 | random1238995_99 931 | random1292595_107 932 | random1434711_73 933 | random1025232_147 934 | random1000738_94 935 | random1556574_217 936 | random1532782_292 937 | random1627667_186 938 | random1014307_67 939 | random1468021_51 940 | random1342921_172 941 | random1164285_189 942 | random1670140_322 943 | random1222713_164 944 | random1528402_192 945 | random1261191_303 946 | random126977_52 947 | random1375588_20 948 | random1522262_57 949 | random1187517_142 950 | 
random1635351_170 951 | random1202227_37 952 | random1236090_132 953 | random1149456_125 954 | random1309080_184 955 | random1514499_139 956 | random1204622_163 957 | random1498925_211 958 | random1295190_132 959 | random1606312_365 960 | random1264010_276 961 | random1399316_69 962 | random1222678_2 963 | random1227180_120 964 | random1178726_135 965 | random1009651_192 966 | random1205774_14 967 | random1464793_385 968 | random1556330_53 969 | random1435509_60 970 | random1547706_84 971 | random1196923_343 972 | random1291145_124 973 | random1667353_61 974 | random1683920_123 975 | random1485121_335 976 | random1452258_151 977 | random1488000_3 978 | random1239766_47 979 | random1313135_54 980 | random1379874_23 981 | random1424991_170 982 | random1606312_271 983 | random1040436_116 984 | random1442350_39 985 | random1622038_210 986 | random1303021_171 987 | random1438775_119 988 | random1385708_301 989 | random1651726_165 990 | random1087032_251 991 | random1559040_330 992 | random1328345_83 993 | random1664504_254 994 | random1348210_50 995 | random1682421_110 996 | random1622763_121 997 | random1493940_65 998 | random1646417_46 999 | random1330628_147 1000 | random1193236_105 1001 | -------------------------------------------------------------------------------- /data/test.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ulissigroup/charge-density-models/bac64ff98dcfb015d19be1cbe9d3ae435600f999/data/test.db -------------------------------------------------------------------------------- /data/train0.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ulissigroup/charge-density-models/bac64ff98dcfb015d19be1cbe9d3ae435600f999/data/train0.db -------------------------------------------------------------------------------- /data/train1.db: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ulissigroup/charge-density-models/bac64ff98dcfb015d19be1cbe9d3ae435600f999/data/train1.db -------------------------------------------------------------------------------- /data/train2.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ulissigroup/charge-density-models/bac64ff98dcfb015d19be1cbe9d3ae435600f999/data/train2.db -------------------------------------------------------------------------------- /data/train3.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ulissigroup/charge-density-models/bac64ff98dcfb015d19be1cbe9d3ae435600f999/data/train3.db -------------------------------------------------------------------------------- /data/train4.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ulissigroup/charge-density-models/bac64ff98dcfb015d19be1cbe9d3ae435600f999/data/train4.db -------------------------------------------------------------------------------- /data/val.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ulissigroup/charge-density-models/bac64ff98dcfb015d19be1cbe9d3ae435600f999/data/val.db -------------------------------------------------------------------------------- /notebooks/Inference.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "2d900c26", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch\n", 11 | "import os\n", 12 | "import yaml\n", 13 | "import matplotlib.pyplot as plt\n", 14 | "import numpy as np\n", 15 | "import time\n", 16 | "\n", 17 | "from tqdm import tqdm\n", 18 | "from torch_geometric.data import Batch, Data\n", 19 | "from pymatgen.core.sites import PeriodicSite\n", 20 | "from 
pymatgen.io.ase import AseAtomsAdaptor\n", 21 | "from ase import neighborlist as nbl\n", 22 | "from ase import Atoms\n", 23 | "from ase.calculators.vasp import VaspChargeDensity\n", 24 | "\n", 25 | "from ocpmodels.common import logger\n", 26 | "from ocpmodels.common.registry import registry\n", 27 | "from ocpmodels.common.utils import setup_logging\n", 28 | "from ocpmodels.preprocessing import AtomsToGraphs\n", 29 | "from ocpmodels.datasets import data_list_collater\n", 30 | "\n", 31 | "import cdm.models\n", 32 | "from cdm.charge_trainer import ChargeTrainer\n", 33 | "from cdm.utils.probe_graph import ProbeGraphAdder\n", 34 | "from cdm.utils.inference import inference\n", 35 | "\n", 36 | "setup_logging()" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "id": "b482da38", 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "def make_parity_plot(x, y, LOG):\n", 47 | " plt.scatter(x, y, \n", 48 | " color='blue', \n", 49 | " alpha = 0.1,\n", 50 | " s=1.5,\n", 51 | " #label='Predictions',\n", 52 | " )\n", 53 | "\n", 54 | " plt.gcf().set_dpi(200)\n", 55 | " plt.axis('square')\n", 56 | "\n", 57 | " if LOG:\n", 58 | " plt.gca().set_xscale('log')\n", 59 | " plt.gca().set_yscale('log')\n", 60 | "\n", 61 | " plt.plot([0, torch.max(x)+1], [0, torch.max(x)+1], label='Parity line', color='red')\n", 62 | " plt.xlabel('Ground truth electron density\\nelectrons per cubic Angstrom')\n", 63 | " plt.ylabel('Predicted electron density\\nelectrons per cubic Angstrom')\n", 64 | " plt.xlim([1e-10, torch.max(x)+1])\n", 65 | " plt.ylim([1e-10, torch.max(x)+1])\n", 66 | " plt.show()" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "id": "101fac48", 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "if torch.cuda.is_available():\n", 77 | " print(\"True\")\n", 78 | "else:\n", 79 | " print(\"False\")\n", 80 | " torch.set_num_threads(8)" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | 
"execution_count": null, 86 | "id": "371724ea", 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "model_config = {\n", 91 | " 'name': 'charge_model',\n", 92 | " 'num_interactions': 4,\n", 93 | " 'atom_channels': 64,\n", 94 | " 'probe_channels': 64,\n", 95 | " 'enforce_zero_for_disconnected_probes': True,\n", 96 | " 'enforce_charge_conservation': True,\n", 97 | " \n", 98 | " 'atom_model_config': {\n", 99 | " 'name': 'schnet_charge',\n", 100 | " 'num_filters':64,\n", 101 | " 'num_gaussians':64,\n", 102 | " 'cutoff':5,\n", 103 | " },\n", 104 | " \n", 105 | " 'probe_model_config': {\n", 106 | " 'name': 'schnet_charge',\n", 107 | " 'num_filters':32,\n", 108 | " 'num_gaussians':32,\n", 109 | " 'cutoff':4,\n", 110 | " },\n", 111 | "}\n", 112 | "\n", 113 | "model = cdm.models.ChargeModel(**model_config)\n", 114 | "\n", 115 | "path = '../runs/checkpoints/2022-11-01-18-54-56-Approximate Charge Conservation, 100k/checkpoint.pt'\n", 116 | "state_dict = torch.load(path)['state_dict']\n", 117 | "\n", 118 | "sd = {}\n", 119 | "\n", 120 | "for x in state_dict.items():\n", 121 | " sd[x[0][7:]] = x[1]\n", 122 | "\n", 123 | "model.load_state_dict(sd)" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "id": "2c40cbac", 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "path = '../cdm/tests/test_structure'\n", 134 | "\n", 135 | "vcd = VaspChargeDensity(path) \n", 136 | "atoms = vcd.atoms[-1]\n", 137 | "dens = vcd.chg[-1]\n", 138 | "grid = dens.shape\n", 139 | "\n", 140 | "target = torch.tensor(dens)\n", 141 | "\n", 142 | "print(atoms)\n", 143 | "print(grid)" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "id": "aaf19215", 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "pred = inference(\n", 154 | " atoms, \n", 155 | " model, \n", 156 | " grid, \n", 157 | " atom_cutoff = 5,\n", 158 | " probe_cutoff = 4,\n", 159 | " batch_size = 1000,\n", 160 | " use_tqdm = 
True,\n", 161 | " device = 'cuda',\n", 162 | " total_density = torch.sum(target)\n", 163 | ")\n", 164 | "\n", 165 | "pred = pred.to('cpu')" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "id": "9d55683e", 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "make_parity_plot(target.flatten(), pred.flatten(), LOG=False)\n", 176 | "make_parity_plot(target.flatten(), pred.flatten(), LOG=True)" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": null, 182 | "id": "72922033", 183 | "metadata": {}, 184 | "outputs": [], 185 | "source": [ 186 | "print(torch.mean(torch.abs(pred - target)).item())" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": null, 192 | "id": "9ad62fac", 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [ 196 | "print(torch.mean(pred).item())\n", 197 | "print(torch.mean(target).item())\n", 198 | "\n", 199 | "print((torch.mean(pred).item() - torch.mean(target).item()) / torch.mean(target).item())" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "id": "3ba06998", 206 | "metadata": {}, 207 | "outputs": [], 208 | "source": [ 209 | "print(torch.std(pred).item())\n", 210 | "print(torch.std(target).item())" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": null, 216 | "id": "1dcd00f6", 217 | "metadata": {}, 218 | "outputs": [], 219 | "source": [] 220 | } 221 | ], 222 | "metadata": { 223 | "kernelspec": { 224 | "display_name": "Python 3 (ipykernel)", 225 | "language": "python", 226 | "name": "python3" 227 | }, 228 | "language_info": { 229 | "codemirror_mode": { 230 | "name": "ipython", 231 | "version": 3 232 | }, 233 | "file_extension": ".py", 234 | "mimetype": "text/x-python", 235 | "name": "python", 236 | "nbconvert_exporter": "python", 237 | "pygments_lexer": "ipython3", 238 | "version": "3.9.13" 239 | } 240 | }, 241 | "nbformat": 4, 242 | "nbformat_minor": 5 243 | } 244 | 
-------------------------------------------------------------------------------- /notebooks/charge-dataset-creation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "136ab9a1", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from cdm.chg_utils import build_charge_lmdb\n", 11 | "from ocpmodels.datasets import LmdbDataset\n", 12 | "from ocpmodels.datasets import data_list_collater\n", 13 | "from ase.atoms import Atoms\n", 14 | "import numpy as np" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "id": "9a4adf97", 21 | "metadata": { 22 | "scrolled": false 23 | }, 24 | "outputs": [], 25 | "source": [ 26 | "build_charge_lmdb(\n", 27 | " inpath = '../shared-scratch/ethan/density/1k_sample/train', \n", 28 | " outpath = '../charge-data/d/train',\n", 29 | " use_tqdm=True,\n", 30 | " stride = 1,\n", 31 | " probe_graph_adder = None,\n", 32 | " cutoff = 6\n", 33 | ")" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "id": "106a3907", 40 | "metadata": { 41 | "scrolled": true 42 | }, 43 | "outputs": [], 44 | "source": [ 45 | "build_charge_lmdb(\n", 46 | " inpath = '../shared-scratch/ethan/density/1k_sample/val', \n", 47 | " outpath = '../charge-data/d/val',\n", 48 | " use_tqdm=True,\n", 49 | " stride = 1,\n", 50 | " probe_graph_adder = None,\n", 51 | " cutoff = 6\n", 52 | ")" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "id": "fa2b49f0", 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "dataset = LmdbDataset({'src':'../charge-data/1k-no-probe-graphs/val'})" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "id": "5a811fc0", 69 | "metadata": { 70 | "scrolled": true 71 | }, 72 | "outputs": [], 73 | "source": [ 74 | "print(len(dataset))\n", 75 | "\n", 76 | "print(dataset[0])" 77 | ] 78 | }, 
79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "id": "e80a1123", 83 | "metadata": { 84 | "scrolled": true 85 | }, 86 | "outputs": [], 87 | "source": [ 88 | "batch = data_list_collater([dataset[0], dataset[1], dataset[2]])" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "id": "e9961826", 95 | "metadata": { 96 | "scrolled": true 97 | }, 98 | "outputs": [], 99 | "source": [ 100 | "print(batch.natoms)\n", 101 | "print(batch.charge_density)" 102 | ] 103 | } 104 | ], 105 | "metadata": { 106 | "kernelspec": { 107 | "display_name": "Python 3 (ipykernel)", 108 | "language": "python", 109 | "name": "python3" 110 | }, 111 | "language_info": { 112 | "codemirror_mode": { 113 | "name": "ipython", 114 | "version": 3 115 | }, 116 | "file_extension": ".py", 117 | "mimetype": "text/x-python", 118 | "name": "python", 119 | "nbconvert_exporter": "python", 120 | "pygments_lexer": "ipython3", 121 | "version": "3.9.13" 122 | } 123 | }, 124 | "nbformat": 4, 125 | "nbformat_minor": 5 126 | } 127 | -------------------------------------------------------------------------------- /notebooks/training.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "7b90f92f", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%env OMP_NUM_THREADS = 1" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "id": "53cded00", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import cdm\n", 21 | "\n", 22 | "from ocpmodels.common import logger\n", 23 | "from ocpmodels.common.registry import registry\n", 24 | "from ocpmodels.common.utils import setup_logging\n", 25 | "\n", 26 | "setup_logging()" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "id": "55fad1b2", 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "task = {\n", 37 | " 
'description': 'Predicting electron density from atomic positions',\n", 38 | " 'dataset': 'lmdb',\n", 39 | "}" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "id": "01a2bd0d", 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "model = {\n", 50 | " 'name': 'charge_model',\n", 51 | " 'enforce_zero_for_disconnected_probes': True,\n", 52 | " 'enforce_charge_conservation': True,\n", 53 | " 'freeze_atomic': False,\n", 54 | " \n", 55 | " 'atom_model_config': {\n", 56 | " 'name': 'schnet_charge',\n", 57 | " },\n", 58 | " \n", 59 | " 'probe_model_config': {\n", 60 | " 'name': 'schnet_charge',\n", 61 | " 'num_interactions': 3,\n", 62 | " 'cutoff': 5,\n", 63 | " },\n", 64 | " \n", 65 | " 'otf_pga_config': {\n", 66 | " 'num_probes': 100000,\n", 67 | " 'cutoff': 6,\n", 68 | " }\n", 69 | "}" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "id": "c9e9cb4b", 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "optimizer = {\n", 80 | " 'optimizer': 'Adam',\n", 81 | " 'num_workers': 7,\n", 82 | " 'lr_initial': 5e-5,\n", 83 | " 'scheduler': \"ReduceLROnPlateau\",\n", 84 | " 'mode': \"min\",\n", 85 | " 'factor': 0.96,\n", 86 | " 'patience': 1,\n", 87 | " 'max_epochs': 1000,\n", 88 | " 'loss_charge': 'normed_mae'\n", 89 | "}" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "id": "34904aeb", 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "dataset = [\n", 100 | " {'src': 'path/to/train'}, \n", 101 | " {'src': 'path/to/val'},\n", 102 | "]" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "id": "10d82621", 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "trainer_config = {\n", 113 | " 'trainer': 'charge',\n", 114 | " 'identifier': 'Electron Density Prediction with SchNet',\n", 115 | " 'is_debug': True,\n", 116 | " 'run_dir': '../runs/',\n", 117 | " 'print_every': 1,\n", 118 | " 'seed': 
2,\n", 119 | " 'logger': 'wandb',\n", 120 | " 'local_rank': 0,\n", 121 | " 'amp': True,\n", 122 | "}" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": null, 128 | "id": "b2b0585e", 129 | "metadata": { 130 | "scrolled": false 131 | }, 132 | "outputs": [], 133 | "source": [ 134 | "trainer = registry.get_trainer_class(trainer_config['trainer'])(\n", 135 | " task = task,\n", 136 | " model = model,\n", 137 | " dataset = dataset,\n", 138 | " optimizer = optimizer, \n", 139 | " **trainer_config\n", 140 | ")" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": null, 146 | "id": "7f659ca1", 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "trainer.model.module" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "id": "cc08a533", 157 | "metadata": { 158 | "scrolled": true 159 | }, 160 | "outputs": [], 161 | "source": [ 162 | "trainer.train()" 163 | ] 164 | } 165 | ], 166 | "metadata": { 167 | "kernelspec": { 168 | "display_name": "Python 3 (ipykernel)", 169 | "language": "python", 170 | "name": "python3" 171 | }, 172 | "language_info": { 173 | "codemirror_mode": { 174 | "name": "ipython", 175 | "version": 3 176 | }, 177 | "file_extension": ".py", 178 | "mimetype": "text/x-python", 179 | "name": "python", 180 | "nbconvert_exporter": "python", 181 | "pygments_lexer": "ipython3", 182 | "version": "3.9.13" 183 | } 184 | }, 185 | "nbformat": 4, 186 | "nbformat_minor": 5 187 | } 188 | -------------------------------------------------------------------------------- /notebooks/wandb-sweep.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "d0d02ead", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch\n", 11 | "import os\n", 12 | "import cdm.models\n", 13 | "from cdm.charge_trainer import ChargeTrainer\n", 14 | "from 
ocpmodels.common import logger\n", 15 | "from ocpmodels.common.registry import registry\n", 16 | "from ocpmodels.common.utils import setup_logging\n", 17 | "from cdm.chg_utils import ProbeGraphAdder\n", 18 | "setup_logging()\n", 19 | "\n", 20 | "import matplotlib.pyplot as plt\n", 21 | "from torch_geometric.data import Batch" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "id": "bace2cd5", 28 | "metadata": { 29 | "scrolled": true 30 | }, 31 | "outputs": [], 32 | "source": [ 33 | "# !pip install wandb --upgrade\n", 34 | "import wandb\n", 35 | "wandb.login()\n", 36 | "import pprint" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "id": "525a06a2", 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "sweep_config = {\n", 47 | " 'method':'bayes',\n", 48 | " 'metric':{'name':'val/charge_mae', 'goal':'minimize'},\n", 49 | " 'parameters':{\n", 50 | " \n", 51 | " 'num_interactions':{\n", 52 | " 'distribution':'int_uniform',\n", 53 | " 'max': 6,\n", 54 | " 'min': 1,\n", 55 | " },\n", 56 | " \n", 57 | " 'atom_channels':{\n", 58 | " 'distribution':'q_log_uniform_values',\n", 59 | " 'min': 16,\n", 60 | " 'max': 128,\n", 61 | " 'q': 8,\n", 62 | " },\n", 63 | " \n", 64 | " 'probe_channels':{\n", 65 | " 'distribution':'q_log_uniform_values',\n", 66 | " 'min': 16,\n", 67 | " 'max': 128,\n", 68 | " 'q': 8,\n", 69 | " },\n", 70 | " \n", 71 | " 'batch_size':{\n", 72 | " 'distribution':'q_log_uniform_values',\n", 73 | " 'min': 1,\n", 74 | " 'max': 16,\n", 75 | " 'q': 2,\n", 76 | " },\n", 77 | " \n", 78 | " 'atom_filters':{\n", 79 | " 'distribution':'q_log_uniform_values',\n", 80 | " 'min': 8,\n", 81 | " 'max': 128,\n", 82 | " 'q': 8,\n", 83 | " },\n", 84 | " \n", 85 | " 'probe_filters':{\n", 86 | " 'distribution':'q_log_uniform_values',\n", 87 | " 'min': 8,\n", 88 | " 'max': 128,\n", 89 | " 'q': 8,\n", 90 | " },\n", 91 | " \n", 92 | " 'atom_gaussians':{\n", 93 | " 
'distribution':'q_log_uniform_values',\n", 94 | " 'min': 8,\n", 95 | " 'max': 32,\n", 96 | " 'q': 8,\n", 97 | " },\n", 98 | " \n", 99 | " 'probe_gaussians':{\n", 100 | " 'distribution':'q_log_uniform_values',\n", 101 | " 'min': 8,\n", 102 | " 'max': 128,\n", 103 | " 'q': 8,\n", 104 | " },\n", 105 | " \n", 106 | " 'cutoff': {\n", 107 | " 'distribution':'int_uniform',\n", 108 | " 'max': 4,\n", 109 | " 'min': 2,\n", 110 | " },\n", 111 | " }\n", 112 | "}\n", 113 | "\n", 114 | "pprint.pprint(sweep_config)" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "id": "9c97b17d", 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "sweep_id = wandb.sweep(sweep_config, project=\"charge-density-models-sweeps\")\n", 125 | "print(sweep_id)" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "id": "54250beb", 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "def train(config=None):\n", 136 | " with wandb.init(config=config):\n", 137 | " config = wandb.config\n", 138 | " \n", 139 | " task = {\n", 140 | " 'dataset': 'lmdb',\n", 141 | " 'description': 'Initial test of training on charges',\n", 142 | " 'type': 'regression',\n", 143 | " 'metric': ['charge_mse', 'charge_mae'],\n", 144 | " 'primary_metric': 'charge_mae',\n", 145 | " 'labels': ['charge_vals'],\n", 146 | " }\n", 147 | " \n", 148 | " model = {\n", 149 | " 'name': 'charge_model',\n", 150 | " 'num_interactions': config.num_interactions,\n", 151 | " 'atom_channels': config.atom_channels,\n", 152 | " 'probe_channels': config.probe_channels,\n", 153 | "\n", 154 | " 'atom_model_config': {\n", 155 | " 'name': 'schnet_charge',\n", 156 | " 'num_filters':config.atom_filters,\n", 157 | " 'num_gaussians':config.atom_gaussians,\n", 158 | " 'cutoff':6,\n", 159 | " },\n", 160 | "\n", 161 | " 'probe_model_config': {\n", 162 | " 'name': 'schnet_charge',\n", 163 | " 'num_filters':config.probe_filters,\n", 164 | " 
'num_gaussians':config.probe_gaussians,\n", 165 | " 'cutoff':config.cutoff,\n", 166 | " },\n", 167 | " }\n", 168 | " \n", 169 | " optimizer = {\n", 170 | " 'optimizer': 'Adam',\n", 171 | " 'batch_size': config.batch_size,\n", 172 | " 'eval_batch_size': 10,\n", 173 | " 'num_workers': 1,\n", 174 | " 'lr_initial': 5e-4,\n", 175 | " 'scheduler': \"ReduceLROnPlateau\",\n", 176 | " 'mode': \"min\",\n", 177 | " 'factor': 0.96,\n", 178 | " 'patience': 1,\n", 179 | " 'max_epochs': 300,\n", 180 | " }\n", 181 | " \n", 182 | " dataset = [\n", 183 | " {'src': '../chg/100/train', 'normalize_labels': False}, # train set \n", 184 | " {'src': '../chg/100/val'}, # val set (optional)\n", 185 | " # {'src': train_src} # test set (optional - writes predictions to disk)\n", 186 | " ]\n", 187 | " \n", 188 | " trainer_config = {\n", 189 | " 'trainer': 'charge',\n", 190 | " 'identifier': 'sweep_run',\n", 191 | " 'is_debug': False,\n", 192 | " 'run_dir': './runs/',\n", 193 | " 'print_every': 1,\n", 194 | " 'seed': 2,\n", 195 | " 'logger': 'wandb',\n", 196 | " 'local_rank': 0,\n", 197 | " 'amp': True,\n", 198 | "\n", 199 | " 'cutoff': config.cutoff,\n", 200 | " }\n", 201 | " \n", 202 | " trainer = registry.get_trainer_class(\n", 203 | " trainer_config['trainer'])(task = task,\n", 204 | " model = model,\n", 205 | " dataset = dataset,\n", 206 | " optimizer = optimizer,\n", 207 | " **trainer_config)\n", 208 | " \n", 209 | " trainer.train()" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": null, 215 | "id": "285553ff", 216 | "metadata": { 217 | "scrolled": true 218 | }, 219 | "outputs": [], 220 | "source": [ 221 | "wandb.agent('charge-density-models-sweeps/########', train, count=100)" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": null, 227 | "id": "76099a62", 228 | "metadata": {}, 229 | "outputs": [], 230 | "source": [] 231 | } 232 | ], 233 | "metadata": { 234 | "kernelspec": { 235 | "display_name": "Python 3 (ipykernel)", 236 | 
"language": "python", 237 | "name": "python3" 238 | }, 239 | "language_info": { 240 | "codemirror_mode": { 241 | "name": "ipython", 242 | "version": 3 243 | }, 244 | "file_extension": ".py", 245 | "mimetype": "text/x-python", 246 | "name": "python", 247 | "nbconvert_exporter": "python", 248 | "pygments_lexer": "ipython3", 249 | "version": "3.9.13" 250 | } 251 | }, 252 | "nbformat": 4, 253 | "nbformat_minor": 5 254 | } 255 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="charge-density-models", 5 | version="0.0.0", 6 | description="Tools to build charge density models using ocpmodels", 7 | url="https://github.com/ulissigroup/charge-density-models", 8 | packages=find_packages(), 9 | include_package_data=True, 10 | ) --------------------------------------------------------------------------------