├── gpnn ├── __init__.py ├── data │ ├── __init__.py │ ├── datamodule.py │ ├── dataset.py │ ├── transforms.py │ └── builders.py ├── models │ ├── __init__.py │ └── model.py ├── train │ ├── __init__.py │ └── trainer.py ├── predict │ ├── __init__.py │ └── predict.py ├── system │ ├── __init__.py │ └── system.py └── fingerprint │ ├── __init__.py │ ├── symmetry_functions.py │ ├── neighbors.py │ └── fingerprint.py ├── .gitignore ├── models └── gpnn_nmc │ └── last.ckpt ├── configs ├── model │ └── default.yaml ├── config.yaml ├── inference │ └── default.yaml ├── train │ └── default.yaml └── data │ ├── NMC.yaml │ └── template.yaml ├── requirements.txt ├── scripts ├── train.py ├── build.py └── predict.py ├── setup.py ├── LICENSE ├── CODE_OF_CONDUCT.md └── README.md /gpnn/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /gpnn/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /gpnn/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import MLP -------------------------------------------------------------------------------- /gpnn/train/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import Trainer -------------------------------------------------------------------------------- /gpnn/predict/__init__.py: -------------------------------------------------------------------------------- 1 | from .predict import CIFPredictor 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/ 2 | gpnn.egg-info/** 3 | **/NMC_test.yaml -------------------------------------------------------------------------------- /gpnn/system/__init__.py: -------------------------------------------------------------------------------- 1 | from .system import ChargeDensitySystem, fractional_to_cartesian -------------------------------------------------------------------------------- /gpnn/fingerprint/__init__.py: -------------------------------------------------------------------------------- 1 | from .fingerprint import Fingerprint, GaussianSymmetryFunctions, Neighbors 2 | -------------------------------------------------------------------------------- /models/gpnn_nmc/last.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GenerativeMaterials/OpenGPNN/HEAD/models/gpnn_nmc/last.ckpt -------------------------------------------------------------------------------- /configs/model/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: gpnn.models.MLP 2 | h5_path: ${data.h5_path} 3 | hidden_dim: 300 4 | n_layers: 3 5 | optim: 6 | lr: ${train.lr} 7 | epochs: ${train.epochs} -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cupy-cuda11x==13.0.0 2 | dask==2023.11.0 3 | h5py==3.8.0 4 | hydra-core==1.3.2 5 | matplotlib==3.8.3 6 | mp-pyrho==0.3.0 7 | omegaconf==2.3.0 8 | pandas==1.5.3 9 | plotly==5.13.0 10 | pybader==0.3.12 11 | pymatgen>=2024.3.1 12 | 
pytorch-lightning>=2.2.0 13 | pyyaml==6.0 14 | scikit-learn==1.2.1 15 | scipy==1.10.0 16 | seaborn==0.13.2 17 | sparse==0.15.1 18 | torch==2.2.0 19 | tqdm==4.64.1 20 | wandb==0.16.3 21 | -------------------------------------------------------------------------------- /configs/config.yaml: -------------------------------------------------------------------------------- 1 | version: 0.0.1 2 | 3 | # Choose a descriptive name for the experiment 4 | expname: gpnn_nmc 5 | 6 | # Path to the root of the project 7 | scratch_folder: user_defined_scratch_folder 8 | 9 | defaults: 10 | - data: NMC 11 | - train: default 12 | - model: default 13 | - inference: default 14 | - input: null 15 | - _self_ 16 | 17 | hydra: 18 | run: 19 | dir: ${scratch_folder}/../logging/gpnn_v${version}/hydra/${now:%Y-%m-%d}/${expname} 20 | -------------------------------------------------------------------------------- /configs/inference/default.yaml: -------------------------------------------------------------------------------- 1 | 2 | # Change to the directory path containing the CIF files you will run inference on 3 | cif_dir: /home/jake/newvol/data/NMC/inference/scale_test 4 | 5 | # Change to the location you want to save the resulting CHGCARs 6 | save_dir: /home/jake/newvol/data/NMC/inference/chgcars 7 | 8 | # Do not modify 9 | model_path: ../models/${expname} 10 | 11 | # Update with the shape of the charge density grid you want to use for inference 12 | shape: [56, 56, 140] 13 | 14 | # "gpu" or "cpu" 15 | device: gpu -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import hydra 4 | from omegaconf import DictConfig 5 | import pytorch_lightning as L 6 | from gpnn.train.trainer import Trainer 7 | 8 | @hydra.main(config_path="../configs", config_name="config", version_base=None) 9 | def main(cfg: DictConfig): 10 | L.seed_everything(cfg.train.seed) 11 | 12 | # Run training 13 | trainer: Trainer = hydra.utils.instantiate( 14 | cfg.train.trainer, cfg=cfg, _recursive_=False 15 | ) 16 | trainer.fit() 17 | trainer.test() 18 | 19 | if __name__ == "__main__": 20 | main() 21 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """Setup script for GPNN package.""" 2 | 3 | from setuptools import setup, find_packages 4 | 5 | def read_requirements(): 6 | """Parse requirements from requirements.txt.""" 7 | with open('requirements.txt', encoding='UTF-8') as req: 8 | return req.read().splitlines() 9 | 10 | setup( 11 | name='gpnn', 12 | version='0.0.1', 13 | author='Jake Vikoren', 14 | author_email='jake.vikoren@genmat.xyz', 15 | description='An ML model for the prediction of charge density.', 16 | long_description=open('README.md', encoding='UTF-8').read(), 17 | long_description_content_type='text/markdown', 18 | url='https://github.com/GenerativeMaterials/GPNN', 19 | packages=find_packages(), 20 | install_requires=read_requirements(), 21 | python_requires='>=3.7', 22 | ) 23 | -------------------------------------------------------------------------------- /scripts/build.py: -------------------------------------------------------------------------------- 1 | """Script for creating a dataset file from a directory of CHGCAR files.""" 2 | 3 | import hydra 4 | from omegaconf import DictConfig 5 | 6 | from gpnn.data.builders import HDF5DatasetBuilder 7 | 8 | __author__ = 
"Jake Vikoren" 9 | __maintainer__ = "Jake Vikoren" 10 | __email__ = "jake.vikoren@genmat.xyz" 11 | __date__ = "06/20/2024" 12 | 13 | """This script ingests a directory of CHGCAR files and composes them into a single HDF5 14 | dataset file that will be used for training or fine-tuning a GPNN model.""" 15 | 16 | @hydra.main(config_path="../configs", config_name="config", version_base=None) 17 | def main(cfg: DictConfig): 18 | """Main function for creating a dataset file from a directory of CIF files.""" 19 | # Initialize dataset builder from hydra configs 20 | builder: HDF5DatasetBuilder = hydra.utils.instantiate( 21 | cfg.data.builder, _convert_="partial" 22 | ) 23 | 24 | # Build the dataset 25 | builder.build(cfg.data.override) 26 | 27 | if __name__ == "__main__": 28 | main() 29 | -------------------------------------------------------------------------------- /configs/train/default.yaml: -------------------------------------------------------------------------------- 1 | 2 | # Update if desired 3 | epochs: 200 4 | lr: 1e-4 5 | deterministic: False # Set to True for reproducibility 6 | seed: 42 7 | 8 | # Below defaults should be sufficient 9 | batch_size: 20_000 10 | shuffle: True 11 | model_save_dir: ../models/${expname} 12 | verbose: 13 | 14 | trainer: 15 | _target_: gpnn.train.Trainer 16 | cfg: null 17 | 18 | data_module: 19 | _target_: gpnn.data.datamodule.HDF5DataModule 20 | h5_path: ${data.h5_path} 21 | batch_size: ${train.batch_size} 22 | shuffle: ${train.shuffle} 23 | 24 | checkpoint_callback: 25 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 26 | dirpath: ${train.model_save_dir} 27 | filename: "{epoch}-{val_loss:.8f}" 28 | monitor: val_loss 29 | verbose: ${train.verbose} 30 | save_last: True 31 | mode: min 32 | 33 | lr_monitor: 34 | _target_: pytorch_lightning.callbacks.LearningRateMonitor 35 | logging_interval: step 36 | 37 | lightning_trainer: 38 | _target_: pytorch_lightning.Trainer 39 | devices: auto 40 | accelerator: ${data.builder.device} 41 | deterministic: ${train.deterministic} 42 | max_epochs: ${train.epochs} 43 | enable_checkpointing: True 44 | default_root_dir: ${scratch_folder} 45 | log_every_n_steps: 50 46 | -------------------------------------------------------------------------------- /configs/data/NMC.yaml: -------------------------------------------------------------------------------- 1 | # Download the NMC CHGCAR files to the root_dir in a folder called "chgcars" 2 | 3 | # Update this path to the root directory where you have stored the directory of chgcars 4 | root_dir: /home/jake/newvol/data 5 | 6 | # Set to true if you want to overwrite existing dataset files 7 | override: False 8 | 9 | # Do not modify the below (unless you know what you're doing) 10 | name: NMC 11 | data_dir: ${data.root_dir}/${data.name} 12 | h5_path: ${data.data_dir}/datasets/${data.name}.h5 13 | 14 | # Dataset splits (int, float, and dict types are supported). 15 | # int: number of structures per split (sum must be <= total structures in dataset) 16 | # float: fraction of structures per split (must sum to 1, e.g. 
[.8, .1, .1] for 80/10/10 split) 17 | # dict: key is "train", "val", or "test", value is the file names of structures to include (List[str]) 18 | splits: [1450, 50, 500] # train, val, test 19 | 20 | # Dataset builder 21 | builder: 22 | _target_: gpnn.data.builders.HDF5DatasetBuilder 23 | dataset_elements: ["Ni", "Mn", "Co", "Li", "O"] 24 | in_dir: ${data.data_dir}/chgcars 25 | out_dir: ${data.data_dir}/datasets 26 | filename: ${data.name}.h5 27 | cutoff: 6 28 | shape: null # define grid shape OR downsample_factor 29 | downsample_factor: 2 30 | splits: ${data.splits} 31 | device: gpu 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /scripts/predict.py: -------------------------------------------------------------------------------- 1 | """Script for predicting charge density from a directory of CIF files.""" 2 | 3 | import hydra 4 | from omegaconf import DictConfig 5 | from pathlib import Path 6 | from tqdm import tqdm 7 | 8 | from gpnn.predict import CIFPredictor 9 | 10 | __author__ = "Jake Vikoren" 11 | __maintainer__ = "Jake Vikoren" 12 | __email__ = "jake.vikoren@genmat.xyz" 13 | __date__ = "06/20/2024" 14 | 15 | """This script ingests a directory of CIF files and runs model predictions that 16 | are output to a directory as CHGCAR files.""" 17 | 18 | 19 | @hydra.main(config_path="../configs", config_name="config", version_base=None) 20 | def main(cfg: DictConfig): 21 | """Main function for predicting charge density from a directory of CIF files.""" 22 | files = list(Path(cfg.inference.cif_dir).rglob("*.cif")) 23 | for file in tqdm( 24 | files, 25 | desc="Processing Files", 26 | total=len(files), 27 | ): 28 | # Initialize the predictor from hydra configs 29 | predictor = CIFPredictor( 30 | cif_path=file, 31 | save_dir=cfg.inference.save_dir, 32 | model_path=cfg.inference.model_path, 33 | shape=cfg.inference.shape, 34 | device_str=cfg.inference.device, 35 | ) 36 | 37 | predictor.cif_to_chgcar() 38 | 39 | 40 | if __name__ == "__main__": 41 | main() 42 | -------------------------------------------------------------------------------- /gpnn/fingerprint/symmetry_functions.py: -------------------------------------------------------------------------------- 1 | """Module for computing Gaussian symmetry functions.""" 2 | 3 | from dataclasses import dataclass, field 4 | from functools import cached_property 5 | import numpy as np 6 | 7 | __author__ = "Jake Vikoren" 8 | __maintainer__ = "Jake Vikoren" 9 | __email__ = "jake.vikoren@genmat.xyz" 10 | __date__ = "07/11/2024" 11 | 12 | @dataclass(frozen=True) 13 | class GaussianSymmetryFunctions: 14 | """Class for computing Gaussian symmetry functions. 15 | 16 | Args: 17 | MAX_GAUSSIAN_WIDTH (int): Maximum width of the Gaussian functions. 18 | N_GAUSSIANS (int): Number of Gaussian functions to use. 19 | dtype (np.dtype): Data type to use for the Gaussian functions. 
20 | Returns: 21 | None 22 | """ 23 | MAX_GAUSSIAN_WIDTH: int = field(default=10) 24 | N_GAUSSIANS: int = field(default=16) 25 | dtype: np.dtype = field(default=np.float32) 26 | 27 | @cached_property 28 | def GAUSSIAN_STANDARD_DEVIATIONS(self): 29 | return np.geomspace( 30 | start=0.25, 31 | stop=self.MAX_GAUSSIAN_WIDTH, 32 | num=self.N_GAUSSIANS, 33 | dtype=self.dtype, 34 | ) 35 | 36 | @cached_property 37 | def NORMALIZING_CONSTANTS(self): 38 | return 1 / ( 39 | ((2 * np.pi) ** (3 / 2)) * (self.GAUSSIAN_STANDARD_DEVIATIONS**3) 40 | ) 41 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2024, Quantum Generative Materials (GenMat) 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 19 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 20 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 21 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 22 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 23 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 24 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | -------------------------------------------------------------------------------- /configs/data/template.yaml: -------------------------------------------------------------------------------- 1 | # Description: 2 | # - Template for creating a custom dataset. 3 | 4 | # Instructions: 5 | # 1. Load all your data into a directory with the following structure: 6 | # - root_dir 7 | # - name (e.g. "NMC") 8 | # - chgcars (this is where the CHGCAR files should be stored) 9 | # - <id>.CHGCAR 10 | # - <id>.CHGCAR 11 | # - ... 12 | # - datasets (this will be created and populated by the builder) 13 | # - <name>.h5 (this file is the dataset object resulting from "build.py") 14 | # 2. Copy this template to GPNN/configs/data/<name>.yaml and modify the parameters as needed. 15 | # 3. Afterwards update "defaults.data" in GPNN/configs/config.yaml to the name of your dataset configuration file. 16 | # 4. Run "python build.py" to build the dataset object. 17 | # NOTE: if you skip step 3, you can still run "python build.py data=<name>" for the same effect. 
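# Example (hypothetical names and paths, for illustration only): for a dataset named "LFP"
# with root_dir "/data", the layout implied by the settings below would be
#   /data/LFP/chgcars/<id>.CHGCAR      (input CHGCARs)
#   /data/LFP/datasets/LFP.h5          (written by the builder)
# i.e. copy this template to GPNN/configs/data/LFP.yaml, set name: LFP and root_dir: /data,
# then run "python build.py data=LFP".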
18 | 19 | # Naming and Paths 20 | name: # Name of your dataset 21 | root_dir: # Path to the root directory of your dataset 22 | data_dir: ${data.root_dir}/${data.name} # DO NOT MODIFY 23 | h5_path: ${data.data_dir}/datasets/${data.name}.h5 # DO NOT MODIFY 24 | override: False # If True, existing datasets will be overwritten 25 | 26 | # Dataset splits (int, float, and dict types are supported). 27 | # int: number of structures per split (sum must be <= total structures in dataset) 28 | # float: fraction of structures per split (must sum to 1, e.g. [.8, .1, .1] for 80/10/10 split) 29 | # dict: key is "train", "val", or "test", value is the file names of structures to include (List[str]) 30 | splits: [.7, .1, .2] # train, val, test 31 | 32 | # Dataset builder 33 | builder: 34 | _target_: gpnn.data.builders.HDF5DatasetBuilder # DO NOT MODIFY 35 | dataset_elements: ["Ni", "Mn", "Co", "Li", "O"] # Update with the elements in your dataset 36 | in_dir: ${data.data_dir}/chgcars # DO NOT MODIFY 37 | out_dir: ${data.data_dir}/datasets # DO NOT MODIFY 38 | filename: ${data.name}.h5 # DO NOT MODIFY 39 | cutoff: 6 # DO NOT MODIFY (unless you know what you're doing - this is the cutoff radius in Angstroms for fingerprinting) 40 | shape: null # define grid shape OR downsample_factor 41 | downsample_factor: 2 42 | splits: ${data.splits} # DO NOT MODIFY 43 | device: gpu # gpu or cpu -------------------------------------------------------------------------------- /gpnn/fingerprint/neighbors.py: -------------------------------------------------------------------------------- 1 | """This module contains the class for computing neighbors of grid points within a cutoff radius.""" 2 | 3 | from dataclasses import dataclass 4 | from typing import Tuple 5 | 6 | import numpy as np 7 | from numpy.typing import ArrayLike 8 | from pymatgen.optimization import neighbors as nbs 9 | 10 | __author__ = "Jake Vikoren" 11 | __maintainer__ = "Jake Vikoren" 12 | __email__ = "jake.vikoren@genmat.xyz" 13 | __date__ = "07/11/2024" 14 | 15 | 16 | @dataclass(kw_only=True) 17 | class Neighbors: 18 | """Class for computing neighbors of grid points within a cutoff radius. 19 | 20 | Args: 21 | grid_coords (ArrayLike): Coordinates of grid points. 22 | atom_coords (ArrayLike): Coordinates of atoms. 23 | lattice (ArrayLike): Lattice vectors. 24 | cutoff (float): Cutoff radius for neighbor search. 25 | pbc (ArrayLike): Periodic boundary conditions. Defaults to [1, 1, 1]. 
26 | Returns: 27 | None 28 | """ 29 | grid_coords: ArrayLike 30 | atom_coords: ArrayLike 31 | lattice: ArrayLike 32 | cutoff: float 33 | pbc: ArrayLike = np.array([1, 1, 1], dtype=np.int64) 34 | 35 | def __post_init__(self): 36 | self.grid_idxs, self.atom_idxs, self.unit_offsets, self.distances = ( 37 | self._get_neighbors() 38 | ) 39 | 40 | self.offset_vectors = self._get_offset_vectors() 41 | 42 | def get(self): 43 | return self.grid_idxs, self.atom_idxs, self.distances, self.offset_vectors 44 | 45 | def _get_neighbors(self) -> Tuple[ArrayLike, ArrayLike, ArrayLike, ArrayLike]: 46 | """Find the neighbors of each grid point within the cutoff radius.""" 47 | grid_idxs, atom_idxs, unit_offsets, distances = nbs.find_points_in_spheres( 48 | all_coords=self.atom_coords, 49 | center_coords=self.grid_coords, 50 | r=self.cutoff, 51 | pbc=self.pbc, 52 | lattice=self.lattice, 53 | ) 54 | return grid_idxs, atom_idxs, unit_offsets, distances 55 | 56 | def _get_offset_vectors(self) -> ArrayLike: 57 | """Compute the displacement vectors between atoms and grid points.""" 58 | # Get the coordinates of atoms and grid points using indices 59 | atom_coords = self.atom_coords[self.atom_idxs] 60 | grid_coords = self.grid_coords[self.grid_idxs] 61 | 62 | # Compute the translation vectors for periodic boundary conditions 63 | translation_vectors = self.unit_offsets @ self.lattice 64 | 65 | # Adjust atom coordinates for periodic boundary conditions 66 | atom_coords += translation_vectors 67 | 68 | # Compute displacement along x, y, and z axes 69 | offset_vectors = atom_coords - grid_coords 70 | 71 | return offset_vectors 72 | -------------------------------------------------------------------------------- /gpnn/train/trainer.py: -------------------------------------------------------------------------------- 1 | """Module for training the GPNN model.""" 2 | 3 | from __future__ import annotations 4 | 5 | import os 6 | from pathlib import Path 7 | import hydra 8 | 9 | import pytorch_lightning as L 10 | import torch 11 | from omegaconf import DictConfig 12 | 13 | from gpnn.models import MLP 14 | 15 | __author__ = "Jake Vikoren" 16 | __maintainer__ = "Jake Vikoren" 17 | __email__ = "jake.vikoren@genmat.xyz" 18 | __date__ = "06/20/2024" 19 | 20 | """This module contains the class responsible for training the GPNN model.""" 21 | 22 | 23 | class Trainer: 24 | """Handler for training GPNN.""" 25 | 26 | def __init__(self, cfg: DictConfig) -> None: 27 | """ 28 | Args: 29 | cfg (DictConfig): The configuration object composed using hydra. 
30 | Returns: 31 | None 32 | """ 33 | self.cfg = cfg 34 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 35 | self.model = None 36 | 37 | self._setup() 38 | 39 | def _setup(self) -> None: 40 | # Set seed and device 41 | if self.cfg.train.deterministic: 42 | L.seed_everything(self.cfg.train.seed) 43 | torch.set_float32_matmul_precision("high") 44 | 45 | # Instantiate data module 46 | self.data_module = hydra.utils.instantiate(self.cfg.train.data_module) 47 | 48 | # Instantiate callbacks 49 | self.checkpoint_callback = hydra.utils.instantiate( 50 | self.cfg.train.checkpoint_callback 51 | ) 52 | self.lr_monitor = hydra.utils.instantiate(self.cfg.train.lr_monitor) 53 | 54 | callbacks = [ 55 | self.checkpoint_callback, 56 | self.lr_monitor, 57 | ] 58 | 59 | # Instantiate model 60 | self._load_model() 61 | 62 | # Instantiate trainer 63 | self.lightning_trainer: L.Trainer = hydra.utils.instantiate( 64 | self.cfg.train.lightning_trainer, 65 | callbacks=callbacks, 66 | logger=None, 67 | ) 68 | 69 | def _load_model(self) -> None: 70 | """Load model from checkpoint if it exists, otherwise initialize a new model.""" 71 | if Path(self.cfg.train.model_save_dir).exists(): 72 | print("Loading model from checkpoint") 73 | self.model = MLP.load_from_checkpoint( 74 | os.path.join(self.cfg.train.model_save_dir, "last.ckpt") 75 | ) 76 | else: 77 | self.model = hydra.utils.instantiate( 78 | self.cfg.model, optim=self.cfg.model.optim, _recursive_=False 79 | ) 80 | self.model = self.model.to(self.device) 81 | 82 | def fit(self) -> None: 83 | """Fit model""" 84 | self.lightning_trainer.fit(self.model, self.data_module) 85 | 86 | def test(self) -> None: 87 | """Test model.""" 88 | self.lightning_trainer.test(self.model, self.data_module) 89 | -------------------------------------------------------------------------------- /gpnn/data/datamodule.py: -------------------------------------------------------------------------------- 1 | """PyTorch Dataloader class for model training.""" 2 | 3 | import os 4 | from typing import Optional, Tuple 5 | 6 | import torch 7 | from pytorch_lightning import LightningDataModule 8 | from torch.utils.data import DataLoader 9 | 10 | from gpnn.data.dataset import HDF5Dataset 11 | 12 | __author__ = "Jake Vikoren" 13 | __maintainer__ = "Jake Vikoren" 14 | __email__ = "jake.vikoren@genmat.xyz" 15 | __date__ = "07/11/2024" 16 | 17 | """This module contains the HDF5DataModule class, which is used to load data from an HDF5 file. 18 | Since data is already batched in the HDF5Dataset for computational efficiency, we use a batch 19 | size of 1 and a custom collate function.""" 20 | 21 | 22 | class HDF5DataModule(LightningDataModule): 23 | """DataModule for loading data from an HDF5 file. 24 | 25 | Args: 26 | h5_path (str): Path to the HDF5 file. 27 | batch_size (int): The batch size to use when loading data. 28 | shuffle (bool): The order of batches is always shuffled. Set to True to 29 | shuffle samples within each batch. 
30 | """ 31 | def __init__(self, h5_path: str, batch_size: int = 5000, shuffle: bool = True): 32 | super().__init__() 33 | self.h5_path = h5_path 34 | self.batch_size = batch_size 35 | self.shuffle = shuffle 36 | 37 | self.setup() 38 | 39 | def setup(self, stage: Optional[str] = None) -> None: 40 | """Setup the datasets for training, validation, and testing.""" 41 | if stage == "fit" or stage is None: 42 | self.train_dataset = HDF5Dataset( 43 | h5_path=self.h5_path, 44 | split="train", 45 | batch_size=self.batch_size, 46 | shuffle=self.shuffle, 47 | ) 48 | self.val_dataset = HDF5Dataset( 49 | h5_path=self.h5_path, 50 | split="val", 51 | batch_size=self.batch_size, 52 | shuffle=False, 53 | ) 54 | if stage == "test" or stage is None: 55 | self.test_dataset = HDF5Dataset( 56 | h5_path=self.h5_path, 57 | split="test", 58 | batch_size=self.batch_size, 59 | shuffle=False, 60 | ) 61 | 62 | def train_dataloader(self): 63 | """Training dataloader.""" 64 | return DataLoader( 65 | self.train_dataset, 66 | batch_size=1, 67 | shuffle=self.shuffle, 68 | num_workers=os.cpu_count(), 69 | collate_fn=self.collate_fn, 70 | ) 71 | 72 | def val_dataloader(self) -> DataLoader: 73 | """Validation dataloader.""" 74 | return DataLoader( 75 | self.val_dataset, 76 | batch_size=1, 77 | shuffle=False, 78 | num_workers=os.cpu_count(), 79 | collate_fn=self.collate_fn, 80 | ) 81 | 82 | def test_dataloader(self) -> DataLoader: 83 | """Test dataloader.""" 84 | return DataLoader( 85 | self.test_dataset, 86 | batch_size=1, 87 | shuffle=False, 88 | num_workers=os.cpu_count(), 89 | collate_fn=self.collate_fn, 90 | ) 91 | 92 | def collate_fn( 93 | self, batch: Tuple[torch.Tensor, torch.Tensor] 94 | ) -> Tuple[torch.Tensor, torch.Tensor]: 95 | """Avoids adding the batch dimension since our datasets are already batched.""" 96 | return batch[0] 97 | -------------------------------------------------------------------------------- /gpnn/predict/predict.py: -------------------------------------------------------------------------------- 1 | """A module for charge density predictions.""" 2 | 3 | from __future__ import annotations 4 | 5 | from dataclasses import dataclass 6 | from functools import cached_property 7 | import os 8 | from pathlib import Path 9 | from typing import Optional, Tuple 10 | from pymatgen.io.vasp.outputs import VolumetricData 11 | 12 | import numpy as np 13 | from numpy.typing import ArrayLike 14 | import torch 15 | 16 | from gpnn.system import ChargeDensitySystem 17 | from gpnn.fingerprint import Fingerprint 18 | from gpnn.models import MLP 19 | 20 | __author__ = "Jake Vikoren" 21 | __maintainer__ = "Jake Vikoren" 22 | __email__ = "jake.vikoren@genmat.xyz" 23 | __date__ = "07/11/2024" 24 | 25 | 26 | @dataclass(kw_only=True) 27 | class CIFPredictor: 28 | """Predict charge density from a CIF file. 29 | 30 | Args: 31 | cif_path (str): Path to the CIF file. 32 | save_dir (str): Directory to save the CHGCAR file. 33 | model_path (str): Path to the pretrained model. 34 | shape (Optional[Tuple[int, int, int]]): Shape of the charge density grid. 35 | device_str (str): Device to use for prediction. 
36 | Returns: 37 | None 38 | """ 39 | 40 | cif_path: str 41 | save_dir: str 42 | model_path: str 43 | shape: Optional[Tuple[int, int, int]] = (50, 50, 50) 44 | device_str: str = "gpu" 45 | 46 | def __post_init__(self) -> None: 47 | self.features_mean = self.model.hparams.features_mean.to(self.device) 48 | self.features_std = self.model.hparams.features_std.to(self.device) 49 | self.targets_mean = self.model.hparams.targets_mean.to(self.device) 50 | self.targets_std = self.model.hparams.targets_std.to(self.device) 51 | 52 | @property 53 | def device(self) -> torch.device: 54 | """Torch device to use for prediction.""" 55 | if self.device_str.lower() == "cpu": 56 | return torch.device("cpu") 57 | if self.device_str.lower() == "gpu": 58 | if torch.cuda.is_available(): 59 | return torch.device("cuda") 60 | else: 61 | print("GPU not available, using CPU.") 62 | return torch.device("cpu") 63 | else: 64 | raise ValueError( 65 | f"Invalid device: {self.device_str}. Please choose 'cpu' or 'gpu'." 66 | ) 67 | 68 | @device.setter 69 | def device(self, value: str) -> None: 70 | """Set the device with validation.""" 71 | if value.lower() not in ["cpu", "gpu"]: 72 | raise ValueError(f"Invalid device: {value}. Please choose 'cpu' or 'gpu'.") 73 | self.device_str = value 74 | 75 | @cached_property 76 | def system(self) -> ChargeDensitySystem: 77 | """Load the system from the CIF file.""" 78 | return ChargeDensitySystem.from_cif(self.cif_path, shape=self.shape) 79 | 80 | @cached_property 81 | def model(self) -> None: 82 | """Load the pretrained model.""" 83 | model = MLP.load_from_checkpoint(os.path.join(self.model_path, "last.ckpt")) 84 | model.to(self.device) 85 | model.eval() 86 | return model 87 | 88 | def predict(self) -> ArrayLike: 89 | """Predict the charge density.""" 90 | 91 | # Compute the fingerprint 92 | fingerprint = Fingerprint( 93 | system=self.system, 94 | dataset_elements=list(self.model.hparams.dataset_symbols), 95 | device=self.device_str, 96 | ) 97 | 98 | features = fingerprint.get() 99 | 100 | # Standardize features 101 | features = torch.tensor(features, dtype=torch.float32).to(self.device) 102 | features = (features - self.features_mean) / self.features_std 103 | 104 | # Predict charge density 105 | preds = self.model(features) 106 | 107 | # Inverse GPNN standardization 108 | preds = (preds * self.targets_std) + self.targets_mean 109 | 110 | # Inverse VASP normalization from PyRho 111 | # (https://github.com/materialsproject/pyrho/blob/2c35912d667e65d7f9d54d63a3693ad6e014a401/src/pyrho/charge_density.py#L489) 112 | preds = preds.detach().cpu().numpy() * self.system.primitive_structure.volume 113 | 114 | return np.reshape(preds, self.shape) 115 | 116 | def cif_to_chgcar(self) -> ChargeDensitySystem: 117 | """Create a ChargeDensitySystem from a CIF file.""" 118 | data = self.predict() 119 | volumetric_data = VolumetricData( 120 | structure=self.system.primitive_structure, data={"total": data} 121 | ) 122 | # Prepare save dir 123 | if not os.path.exists(self.save_dir): 124 | os.makedirs(self.save_dir) 125 | save_path = os.path.join(self.save_dir, f"{Path(self.cif_path).stem}_gpnn.CHGCAR") 126 | volumetric_data.write_file(save_path) 127 | -------------------------------------------------------------------------------- /gpnn/data/dataset.py: -------------------------------------------------------------------------------- 1 | """Module defining the dataset object.""" 2 | 3 | import random 4 | from pathlib import Path 5 | from typing import List, Literal, Tuple, Union 6 | 7 | import h5py 8 
| import torch 9 | from torch.utils.data import Dataset 10 | 11 | __author__ = "Jake Vikoren" 12 | __maintainer__ = "Jake Vikoren" 13 | __email__ = "jake.vikoren@genmat.xyz" 14 | __date__ = "07/11/2024" 15 | 16 | """This module defines the HDF5Dataset class, which is used to load data from an HDF5 file.""" 17 | 18 | 19 | class HDF5Dataset(Dataset): 20 | """Dataset class for loading data from an HDF5 file.""" 21 | def __init__( 22 | self, 23 | h5_path: Union[str, Path], 24 | split: Literal["train", "val", "test"], 25 | batch_size: int = 5000, 26 | shuffle: bool = True, 27 | ): 28 | """ 29 | Args: 30 | h5_path (Union[str, Path]): Path to the HDF5 file. 31 | split (Literal["train", "val", "test"]): The dataset split to load. 32 | batch_size (int): The batch size to use when loading data. 33 | shuffle (bool): The order of batches is always shuffled. Set to True to 34 | shuffle samples within each batch. 35 | """ 36 | self.h5_path = h5_path 37 | self.split = split 38 | self.batch_size = batch_size 39 | self.shuffle = shuffle 40 | 41 | # Load structure names and batch info. 42 | self.structure_names = self._load_structure_names() 43 | self.batch_info = self._get_batch_info() 44 | 45 | # Load standardization parameters. 46 | self.features_mean, self.features_std, self.targets_mean, self.targets_std = ( 47 | self._load_mean_std() 48 | ) 49 | 50 | def _get_batch_info(self): 51 | """Get the start and end indices for the batch.""" 52 | batch_info = [] 53 | with h5py.File(self.h5_path, "r") as file: 54 | for structure in self.structure_names: 55 | # Determine the number of data points in the current structure. 56 | n_points = file[f"{structure}/features"].shape[0] 57 | 58 | # Calculate the number of full batches that can be formed from these grid points. 59 | n_batches = (n_points + self.batch_size - 1) // self.batch_size 60 | 61 | for batch_idx in range(n_batches): 62 | start_idx = batch_idx * self.batch_size 63 | end_idx = min( 64 | start_idx + self.batch_size, n_points 65 | ) # Ensure we don't exceed the number of grid points. 66 | batch_info.append((structure, start_idx, end_idx)) 67 | random.shuffle(batch_info) # Always shuffle batches 68 | return batch_info 69 | 70 | def _load_structure_names(self) -> List[str]: 71 | """Load structure names for the specified split from the HDF5 file.""" 72 | structure_names = [] 73 | with h5py.File(self.h5_path, "r") as h5_file: 74 | for structure_name, group in h5_file.items(): 75 | if group.attrs.get("split") == self.split: 76 | structure_names.append(structure_name) 77 | return structure_names 78 | 79 | def _load_mean_std( 80 | self, 81 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 82 | """Load standardization parameters from the HDF5 file.""" 83 | with h5py.File(self.h5_path, "r") as h5_file: 84 | features_mean = torch.tensor(h5_file.attrs["features_mean"]) 85 | features_std = torch.tensor(h5_file.attrs["features_std"]) 86 | targets_mean = torch.tensor(h5_file.attrs["targets_mean"]) 87 | targets_std = torch.tensor(h5_file.attrs["targets_std"]) 88 | return features_mean, features_std, targets_mean, targets_std 89 | 90 | def __len__(self) -> int: 91 | """Return the number of items in the dataset.""" 92 | return len(self.structure_names) 93 | 94 | def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]: 95 | """Get a full batch from the given index, standardizing it on the fly.""" 96 | # Get the key, start_idx, and end_idx for the batch. 
97 | structure, start_idx, end_idx = self.batch_info[idx] 98 | 99 | # Load the features and targets from the HDF5 dataset. 100 | with h5py.File(self.h5_path, "r") as file: 101 | batch_length = end_idx - start_idx 102 | shuffled_idxs = ( 103 | torch.randperm(batch_length) 104 | if self.shuffle 105 | else torch.arange(batch_length) 106 | ) 107 | 108 | features = torch.tensor(file[f"{structure}/features"][start_idx:end_idx])[ 109 | shuffled_idxs 110 | ] 111 | targets = torch.tensor(file[f"{structure}/targets"][start_idx:end_idx])[ 112 | shuffled_idxs 113 | ] 114 | 115 | # Standardize the features and targets in preparation for training. 116 | features = (features - self.features_mean) / (self.features_std) 117 | targets = (targets - self.targets_mean) / (self.targets_std) 118 | 119 | return features, targets 120 | -------------------------------------------------------------------------------- /gpnn/models/model.py: -------------------------------------------------------------------------------- 1 | """Module defining the core ML model.""" 2 | 3 | from typing import Any, Dict 4 | 5 | import h5py 6 | import pytorch_lightning as L 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | __author__ = "Jake Vikoren" 12 | __maintainer__ = "Jake Vikoren" 13 | __email__ = "jake.vikoren@genmat.xyz" 14 | __date__ = "07/11/2024" 15 | 16 | 17 | class BaseModel(L.LightningModule): 18 | """Base model class for all models in the project.""" 19 | def __init__(self, *args, **kwargs) -> None: 20 | super().__init__() 21 | # Save args and kwargs to self.hparams 22 | self.save_hyperparameters() 23 | 24 | # When training from scratch, metadata is loaded from the HDF5 dataset 25 | if not hasattr(self.hparams, 'features_mean'): 26 | try: 27 | with h5py.File(self.hparams.h5_path, "r") as file: 28 | self.metadata = dict(file.attrs.items()) 29 | except FileNotFoundError as error: 30 | raise FileNotFoundError( 31 | f"Metadata file not found at {self.hparams.h5_path}" 32 | ) from error 33 | 34 | # Add key info from HDF5 file to hparams 35 | self.hparams.dataset_symbols = self.metadata["dataset_symbols"] 36 | self.hparams.input_dim = self.metadata["n_features"] 37 | self.hparams.output_dim = self.metadata["n_targets"] 38 | self.hparams.features_mean = torch.tensor(self.metadata["features_mean"]) 39 | self.hparams.features_std = torch.tensor(self.metadata["features_std"]) 40 | self.hparams.targets_mean = torch.tensor(self.metadata["targets_mean"]) 41 | self.hparams.targets_std = torch.tensor(self.metadata["targets_std"]) 42 | 43 | # Track mean and std tensors as part of the model state 44 | self.register_buffer("features_mean", self.hparams.features_mean.clone().detach()) 45 | self.register_buffer("features_std", self.hparams.features_std.clone().detach()) 46 | self.register_buffer("targets_mean", self.hparams.targets_mean.clone().detach()) 47 | self.register_buffer("targets_std", self.hparams.targets_std.clone().detach()) 48 | 49 | def inverse_standardize(self, *args: torch.Tensor) -> torch.Tensor: 50 | """Apply inverse standardization to model predictions.""" 51 | # Convert predictions back to original scale 52 | return ((arg * self.targets_std) + self.targets_mean for arg in args) 53 | 54 | def configure_optimizers(self) -> Dict[str, Any]: 55 | """Configure the optimizer and learning rate scheduler.""" 56 | optimizer = torch.optim.AdamW(params=self.parameters(), lr=self.hparams.optim.lr) 57 | scheduler = torch.optim.lr_scheduler.OneCycleLR( 58 | optimizer=optimizer, 59 | 
total_steps=self.hparams.optim.epochs, 60 | max_lr=self.hparams.optim.lr 61 | ) 62 | return { 63 | "optimizer": optimizer, 64 | "lr_scheduler": {"scheduler": scheduler, "interval": "epoch"}, 65 | "monitor": "val_loss", 66 | } 67 | 68 | 69 | class MLP(BaseModel): 70 | """An MLP capable of predicting charge density values based on local atomic environment. 71 | 72 | See GPNN/configs/model/default.yaml for input args. 73 | """ 74 | def __init__(self, *args, **kwargs) -> None: 75 | super().__init__(*args, **kwargs) 76 | self.sequential = build_mlp( 77 | input_dim=self.hparams.input_dim, 78 | hidden_dim=self.hparams.hidden_dim, 79 | n_layers=self.hparams.n_layers, 80 | output_dim=self.hparams.output_dim, 81 | ) 82 | 83 | def forward(self, x: torch.Tensor) -> torch.Tensor: 84 | """Forward pass through the model.""" 85 | return self.sequential(x) 86 | 87 | def _step( 88 | self, 89 | batch: torch.Tensor, 90 | stage: str, 91 | ) -> torch.Tensor: 92 | """Common step for training, validation, and test.""" 93 | features, targets = batch 94 | preds = self(features) 95 | 96 | # L1 outperforms MSE 97 | loss = F.l1_loss(preds, targets) 98 | 99 | # Inverse standardize predictions and targets, and calculate RMSE in the input space 100 | inv_preds, inv_targets = self.inverse_standardize(preds, targets) 101 | rmse = torch.sqrt(F.mse_loss(inv_preds, inv_targets)) 102 | 103 | self.log(f"{stage}_loss", loss, on_step=False, on_epoch=True, sync_dist=True) 104 | self.log(f"{stage}_rmse", rmse, on_step=False, on_epoch=True, sync_dist=True) 105 | 106 | return loss 107 | 108 | def training_step(self, batch) -> torch.Tensor: 109 | """Training step.""" 110 | return self._step(batch, "train") 111 | 112 | def validation_step(self, batch) -> torch.Tensor: 113 | """Validation step.""" 114 | return self._step(batch, "val") 115 | 116 | def test_step(self, batch) -> torch.Tensor: 117 | """Test step.""" 118 | return self._step(batch, "test") 119 | 120 | def build_mlp( 121 | input_dim: int, 122 | hidden_dim: int = 300, 123 | n_layers: int = 3, 124 | output_dim: int = 1, 125 | ): 126 | """A simple MLP model.""" 127 | # Input layer 128 | layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU()] 129 | 130 | # Hidden layers 131 | for _ in range(n_layers - 1): 132 | layers += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU()] 133 | 134 | # Output layer 135 | layers += [nn.Linear(hidden_dim, output_dim)] 136 | return nn.Sequential(*layers) 137 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 
14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | amrit.prasad@genmat.xyz . 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. 
This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | https://www.contributor-covenant.org/faq. Translations are available at 128 | https://www.contributor-covenant.org/translations. -------------------------------------------------------------------------------- /gpnn/data/transforms.py: -------------------------------------------------------------------------------- 1 | """Transforms for processing structure files into datasets.""" 2 | 3 | from datetime import datetime 4 | import os 5 | from pathlib import Path 6 | import time 7 | from typing import Optional, Tuple, Literal, List 8 | 9 | import h5py 10 | import numpy as np 11 | from numpy.typing import ArrayLike 12 | 13 | from gpnn.system import ChargeDensitySystem 14 | from gpnn.fingerprint import Fingerprint 15 | 16 | __author__ = "Jake Vikoren" 17 | __maintainer__ = "Jake Vikoren" 18 | __email__ = "jake.vikoren@genmat.xyz" 19 | __date__ = "07/11/2024" 20 | 21 | """This module includes functions for processing structure files into a format suitable for 22 | ingestion into the GPNN model.""" 23 | 24 | 25 | def chgcar_to_h5( 26 | chgcar_path: str, 27 | out_dir: str, 28 | cutoff: float = 6.0, 29 | shape: Optional[Tuple[int, int, int]] = None, 30 | downsample_factor: Optional[int] = None, 31 | device: Literal["cpu", "gpu"] = "gpu", 32 | override: bool = False, 33 | ): 34 | """Convert a single CHGCAR file to a HDF5 file containing features, targets, and metadata. 35 | 36 | Args: 37 | chgcar_path (str): Path to CHGCAR file 38 | out_dir (str): Output directory for processed files 39 | cutoff (float, optional): Cutoff value for processing. Defaults to 6.0. 40 | shape (Optional[Tuple[int, int, int]], optional): Shape of the processed data. 41 | Defaults to None. 42 | downsample_factor (Optional[int], optional): Factor by which to downsample data. 43 | Defaults to None. 44 | device (Literal["cpu", "gpu"], optional): Device to use for processing (cpu/gpu). 45 | Defaults to "gpu". 
46 | override (bool, optional): Whether to override existing output files. 47 | Defaults to False. 48 | 49 | Returns: 50 | None 51 | """ 52 | # Track the time taken to build the dataset 53 | start_time = time.perf_counter() 54 | 55 | # HDF5 file path combines output directory with <_id>.h5 56 | out_path = str((Path(out_dir) / Path(chgcar_path).stem).with_suffix(".h5")) 57 | 58 | # Terminate or override if file already exists 59 | if os.path.exists(out_path): 60 | if not override: 61 | raise FileExistsError(f"{out_path} already exists") 62 | os.remove(out_path) 63 | else: 64 | os.makedirs(os.path.dirname(out_path), exist_ok=True) 65 | 66 | # Construct system 67 | system = ChargeDensitySystem.from_file(chgcar_path, cutoff=cutoff) 68 | 69 | # Resample or downsample if necessary 70 | system.resample(factor=downsample_factor, new_shape=shape) 71 | 72 | # Compute fingerprint 73 | fingerprint = Fingerprint( 74 | system=system, 75 | cutoff=cutoff, 76 | device=device, 77 | ) 78 | 79 | # Write targets to HDF5 file 80 | system.to_hdf5(file_path=out_path) 81 | 82 | # Write features to HDF5 file 83 | fingerprint.to_hdf5(file_path=out_path) 84 | 85 | end_time = time.perf_counter() 86 | 87 | # Write metadata 88 | with h5py.File(out_path, "r+") as hdf5_file: 89 | hdf5_file.attrs["creation_time"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S") 90 | hdf5_file.attrs["build_time"] = f"{end_time - start_time: .2f} seconds" 91 | hdf5_file.attrs["cutoff"] = cutoff 92 | hdf5_file.attrs["shape"] = str(system.shape) 93 | hdf5_file.attrs["downsample_factor"] = downsample_factor 94 | hdf5_file.attrs["n_points"] = hdf5_file["features"].attrs["n_points"] 95 | hdf5_file.attrs["system_symbols"] = fingerprint.system_symbols 96 | hdf5_file.attrs["n_features_per_element"] = fingerprint.n_features_per_element 97 | hdf5_file.attrs["n_features"] = fingerprint.n_features 98 | hdf5_file.attrs["device"] = device 99 | 100 | 101 | def match_dataset_symbols( 102 | features: ArrayLike, 103 | system_symbols: List[str], 104 | dataset_symbols: List[str], 105 | n_features_per_element: int = 32, 106 | ) -> ArrayLike: 107 | """Populate the feature array with zero columns for missing symbols. 108 | 109 | This function adds zero columns to the feature array for symbols that are 110 | present in the dataset but not in the system. This ensures that the feature 111 | shape aligns across the full dataset. 112 | 113 | Args: 114 | features (ArrayLike): Feature array to populate. 115 | system_symbols (List[str]): Symbols present in the system. 116 | dataset_symbols (List[str]): Symbols present in the dataset. 117 | n_features_per_element (int, optional): Number of features per element. Defaults to 32. 118 | 119 | Returns: 120 | ArrayLike: Feature array with zero columns for missing symbols. 
121 | """ 122 | n_features = n_features_per_element * len(dataset_symbols) 123 | 124 | # No reshape is required if the dataset and system symbols are the same 125 | if set(dataset_symbols) == set(system_symbols): 126 | return features 127 | 128 | out_arr = np.zeros((features.shape[0], n_features), dtype=np.float32) 129 | 130 | system_start_idx_by_element = { 131 | s: i * n_features_per_element for i, s in enumerate(system_symbols) 132 | } 133 | 134 | dataset_start_idx_by_element = { 135 | s: i * n_features_per_element for i, s in enumerate(dataset_symbols) 136 | } 137 | 138 | for s in system_symbols: 139 | system_start_idx = system_start_idx_by_element[s] 140 | system_end_idx = system_start_idx + n_features_per_element 141 | 142 | dataset_start_idx = dataset_start_idx_by_element[s] 143 | dataset_end_idx = dataset_start_idx + n_features_per_element 144 | 145 | out_arr[:, dataset_start_idx:dataset_end_idx] = features[ 146 | :, system_start_idx:system_end_idx 147 | ] 148 | 149 | return out_arr 150 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Grid-Point Neural Network (GPNN) 2 | 3 | ## Overview 4 | 5 | - GPNN is a package for ML prediction of charge density based on the methods described in [Solving the electronic structure problem with machine learning](https://www.nature.com/articles/s41524-019-0162-7) by Ramprasad et. al. 6 | - The [NMC](https://data.dtu.dk/articles/dataset/NMC_Li-ion_Battery_Cathode_Energies_and_Charge_Densities/16837721) cathode dataset is used for demonstration. A pretrained model for this dataset is included and can be used for inference as is. 7 | - The workflow is constructed as a (build -> train -> predict) pipeline. 8 | - **Build:** Construct an HDF5 dataset file given a directory of CHGCAR (VASP) files. 9 | - **Train:** Train an ML model on the preprocessed data (HDF5 file genereated from the BUILD step). 10 | - **Predict:** Use the trained ML model to run inference on new structures (CIF -> Predicted CHGCAR) 11 | - Charge density can be predicted orders of magnitude faster (1 to 3 minutes in total) than DFT and for larger systems. 12 | - GPNN's core dependencies are: 13 | - [Hydra](https://hydra.cc/docs/intro/): Used for config management of hyperparameters during the build, train and predict stages (hydra-core==1.3.2). 14 | - [PyRho](https://materialsproject.github.io/pyrho/index.html): Used for charge density management and interpolation (mp-pyrho==0.3.0). 15 | - [HDF5](https://docs.h5py.org/en/stable/quick.html): Used for efficient data storage and organization (h5py==3.8.0). 16 | - [Dask](https://docs.dask.org/en/stable/10-minutes-to-dask.html): Used for distributed massive array manipulation (dask==2023.11.0). 17 | - [CuPy](https://cupy.dev/): Used in conjunction with Dask for GPU-accelerated array operations (cupy-cuda11x==13.0.0) 18 | 19 | 20 | ## Installation 21 | 22 | 1. Clone repo 23 | 3. Create new conda env 24 | 5. Run `pip install -e .` 25 | 26 | ## Prepare Configs 27 | 28 | In order to simplify interdependent configurations, configs are organized into basic categories (data, model, train, inference) and composed by [Hydra](https://hydra.cc/docs/intro/) when scripts are run. Configs should be modified in place (not moved to other locations). 29 | 30 | NOTE: Hydra includes a CLI that can be used to modify parameters on-the-fly instead of manually modifying the yaml configs. See "Run Experiments" for examples. 
31 | 32 | 1. Main Config: 33 | - Navigate to `GPNN/configs/config.yaml` 34 | - This is the high level config that orchestrates the lower level configs (data, inference and training). 35 | - If you will be training your own model, change "expname". Otherwise, leave it as is. This variable is used to select from pretrained models or name new models. 36 | - Change "scratch_folder" to the global path where you cloned the repo (path/to/GPNN). 37 | 2. Inference: 38 | - Navigate to `GPNN/configs/inference/default.yaml` 39 | - Update the indicated variables 40 | 3. Training: 41 | - Navigate to `GPNN/configs/train/default.yaml` 42 | - Update the indicated variables 43 | 4. Data: 44 | - If training on NMC, navigate to `GPNN/configs/data/NMC.yaml` and follow the instructions. 45 | - If defining a custom dataset, navigate to `GPNN/configs/data/template.yaml` and follow the instructions. 46 | 47 | ## Run Experiments 48 | **Navigate to GPNN/scripts before running any scripts** 49 | 50 | 1. Inference with Pretrained Model (included model is trained on NMC) 51 | - If you completed step 2 under "Prepare Configs", simply run `python predict.py` 52 | - Otherwise, run `python predict.py inference.cif_dir=<cif_dir> inference.save_dir=<save_dir> inference.shape=<shape> inference.device=<device>` 53 | 54 | 2. Train NMC from scratch 55 | - Navigate to `GPNN/configs/config.yaml` 56 | - Change "expname" to a name of your choice. This will become the name of your model. 57 | - Run `python train.py` 58 | - Your trained model will be stored in `GPNN/models/<expname>` 59 | 60 | 3. Train on custom dataset 61 | - Navigate to `GPNN/configs/data/template.yaml` and follow the instructions. 62 | - You will be loading your CHGCAR files as described, and defining the relevant config parameters. 63 | - Navigate to `GPNN/configs/config.yaml` 64 | - Change "expname" to a name of your choice that best describes the dataset you want to train on. This will become the name of your model. Note: new experiments run under this name will use the last model checkpoint under `../models/<expname>`. 65 | - Update line 10 of `GPNN/configs/config.yaml` (`- data: NMC`) to the name of your dataset config. 66 | - Run `python build.py` to build the dataset file. 67 | - Run `python train.py` to train a model on this data. 68 | - Run `python predict.py` to use your model for inference on CIF files. 69 | 70 | ## Technical Summary 71 | *Charge density* is represented as a set of scalar values defined on a regularly distributed discrete grid spanning the volume of a unit cell. 72 | 73 | A *neural network* is trained to predict the *charge density* values based on a representation of the local atomic environment around each grid point. This representation (called the *fingerprint*) contains information on the location and types of atoms relative to each grid point. 74 | 75 | The fingerprint is constructed as a combination of rotationally invariant *scalar* and *vector* components. The *scalar* component captures radial distance information while the *vector* component captures angular information. 76 | 77 | The authors propose a predefined set of Gaussian functions (k) of varying widths (σk) centered about every grid-point (g) to determine these fingerprints. 
--------------------------------------------------------------------------------
/gpnn/system/system.py:
--------------------------------------------------------------------------------
1 | """A module defining the ChargeDensitySystem object."""
2 | 
3 | from __future__ import annotations
4 | 
5 | from dataclasses import dataclass
6 | from functools import cached_property, lru_cache
7 | from itertools import product
8 | from typing import Dict, Tuple, Optional
9 | 
10 | import copy
11 | import h5py
12 | import numpy as np
13 | from numpy.typing import ArrayLike
14 | from pymatgen.core.structure import Structure
15 | from pymatgen.core.periodic_table import Species
16 | from pymatgen.io.vasp.outputs import VolumetricData
17 | from pyrho.charge_density import ChargeDensity
18 | 
19 | __author__ = "Jake Vikoren"
20 | __maintainer__ = "Jake Vikoren"
21 | __email__ = "jake.vikoren@genmat.xyz"
22 | __date__ = "07/11/2024"
23 | 
24 | """The ChargeDensitySystem object is the main representation of atomic structures and
25 | their associated charge density data. This object is used to generate the fingerprints
26 | used for training."""
27 | 
28 | 
29 | @dataclass
30 | class ChargeDensitySystem:
31 |     """A class to represent a charge density system.
32 | 
33 |     Args:
34 |         cd (ChargeDensity): The charge density object.
35 |         cutoff (float): The cutoff radius for the supercell.
36 | Returns: 37 | None 38 | """ 39 | cd: ChargeDensity 40 | cutoff: float = 6.0 41 | 42 | def __post_init__(self): 43 | self._validate_inputs() 44 | 45 | def _validate_inputs(self): 46 | if not isinstance(self.cd, ChargeDensity): 47 | raise ValueError("cd must be a ChargeDensity object.") 48 | if not isinstance(self.cutoff, (int, float)): 49 | raise ValueError("cutoff must be an integer or float.") 50 | if self.cutoff < 0: 51 | raise ValueError("cutoff must be a positive number.") 52 | 53 | @cached_property 54 | def supercell(self) -> Structure: 55 | """Get the minimal symmetric supercell that contains the cutoff radius.""" 56 | # Orthogonal distance from the center to the walls of the unit cell 57 | dist_to_pbc = np.diagonal(self.cd.lattice) / 2 58 | 59 | # Compute the number of unit cells needed in each direction 60 | with np.errstate(divide='ignore'): # Ignore divide by zero warning 61 | pbcs_crossed = np.abs((self.cutoff - dist_to_pbc) // (2 * dist_to_pbc)) 62 | pbcs_crossed += 1 # Include the center cell 63 | return pbcs_crossed.astype(int) * 2 + 1 # convert to symmetrical supercell 64 | 65 | @cached_property 66 | def primitive_structure(self) -> Structure: 67 | """Get the primitive structure of the charge density system.""" 68 | return self.cd.structure.get_sorted_structure() 69 | 70 | @cached_property 71 | def structure(self) -> Structure: 72 | """Get the structure of the charge density system.""" 73 | structure = self.primitive_structure.copy() 74 | structure *= self.supercell 75 | return structure 76 | 77 | @cached_property 78 | def grid_coords(self) -> ArrayLike: 79 | """Get the cartesian coordinates of the charge density grid.""" 80 | # Unmodified coords 81 | primitive_coords = build_coords( 82 | self.shape, self.cd.lattice.astype(np.float64).tobytes() 83 | ) 84 | 85 | # If the supercell is [1, 1, 1], coords need no translation 86 | if all(self.supercell == 1): 87 | return primitive_coords 88 | 89 | # Fractional center 90 | center_coord = np.array([0.5, 0.5, 0.5]) 91 | 92 | # Center in the supercell with a translation vector 93 | primitive_center = fractional_to_cartesian( 94 | self.primitive_structure.lattice.matrix, center_coord 95 | ) 96 | supercell_center = fractional_to_cartesian(self.lattice, center_coord) 97 | 98 | # Translation vector to move coords from the primitive cell to the supercell 99 | translation_vector = supercell_center - primitive_center 100 | return primitive_coords + translation_vector 101 | 102 | @property 103 | def shape(self) -> Tuple[int, int, int]: 104 | """Get the shape of the charge density grid.""" 105 | return self.cd.grid_shape 106 | 107 | @shape.setter 108 | def shape(self, new_shape: Tuple[int, int, int]) -> None: 109 | """Set the shape of the charge density grid.""" 110 | self.resample(new_shape=new_shape) 111 | 112 | @property 113 | def lattice(self) -> ArrayLike: 114 | """Get the lattice matrix of the charge density system.""" 115 | return self.structure.lattice.matrix 116 | 117 | @property 118 | def data(self) -> ArrayLike: 119 | """Get the charge density data.""" 120 | return self.cd.normalized_data["total"] 121 | 122 | @property 123 | def atom_coords(self) -> ArrayLike: 124 | """Get the cartesian atomic coordinates of the charge density system.""" 125 | return self.structure.cart_coords 126 | 127 | @property 128 | def all_atomic_numbers(self) -> ArrayLike: 129 | """Get the atomic numbers of all atoms in the charge density system.""" 130 | return self.structure.atomic_numbers 131 | 132 | @property 133 | def elements(self) -> ArrayLike: 
134 | """Get the elements in the charge density system.""" 135 | elements = self.structure.composition.elements 136 | if isinstance(elements[0], Species): 137 | return sorted([s.element for s in elements]) 138 | return elements 139 | 140 | @property 141 | def symbols(self) -> ArrayLike: 142 | """Get the chemical symbols of the elements in the charge density system.""" 143 | return [element.symbol for element in self.elements] 144 | 145 | @property 146 | def atomic_numbers(self) -> ArrayLike: 147 | """Get the unique atomic numbers of the elements in the charge density system.""" 148 | return [element.Z for element in self.elements] 149 | 150 | @property 151 | def n_points(self) -> int: 152 | """Get the number of grid points in the charge density grid.""" 153 | return self.cd.pgrids["total"].ngridpts 154 | 155 | @property 156 | def n_atoms(self) -> int: 157 | """Get the number of atoms in the charge density system.""" 158 | return len(self.structure) 159 | 160 | @property 161 | def n_elements(self) -> int: 162 | """Get the number of elements in the charge density system.""" 163 | return len(self.elements) 164 | 165 | @property 166 | def element_amount_dict(self) -> Dict[int, int]: 167 | """Get the element amount dictionary of the charge density system.""" 168 | return self.structure.composition.get_el_amt_dict() 169 | 170 | @property 171 | def dtype(self) -> np.dtype: 172 | """Get the data type of the charge density data.""" 173 | return self.data.dtype 174 | 175 | @classmethod 176 | def from_cube(cls, cube_path: str, cutoff: float = 6.0) -> ChargeDensitySystem: 177 | """Create a ChargeDensitySystem from a cube file.""" 178 | cube = VolumetricData.from_cube(cube_path) 179 | cd = ChargeDensity.from_pmg(cube, normalization=None) 180 | return cls(cd=cd, cutoff=cutoff) 181 | 182 | @classmethod 183 | def from_chgcar(cls, chgcar_path: str, cutoff: float = 6.0) -> ChargeDensitySystem: 184 | """Create a ChargeDensitySystem from a CHGCAR file.""" 185 | cd = ChargeDensity.from_file(chgcar_path) 186 | return cls(cd=cd, cutoff=cutoff) 187 | 188 | @classmethod 189 | def from_cif( 190 | cls, 191 | cif_path: str, 192 | shape: Tuple[int, int, int] = (50, 50, 50), 193 | cutoff: float = 6.0, 194 | ) -> ChargeDensitySystem: 195 | """Create a ChargeDensitySystem from a CIF file.""" 196 | structure = Structure.from_file(cif_path) 197 | volumetric_data = VolumetricData( 198 | structure=structure, data={"total": np.zeros(shape)} 199 | ) 200 | cd = ChargeDensity.from_pmg(volumetric_data, normalization=None) 201 | return cls(cd=cd, cutoff=cutoff) 202 | 203 | @classmethod 204 | def from_file(cls, file_path: str, cutoff: float = 6.0) -> ChargeDensitySystem: 205 | """Create a ChargeDensitySystem from a file.""" 206 | if file_path.endswith(".cube.gz"): 207 | return cls.from_cube(cube_path=file_path, cutoff=cutoff) 208 | elif file_path.endswith(".CHGCAR"): 209 | return cls.from_chgcar(chgcar_path=file_path, cutoff=cutoff) 210 | elif file_path.endswith(".cif"): 211 | return cls.from_cif(cif_path=file_path, cutoff=cutoff) 212 | else: 213 | raise ValueError( 214 | "File type not supported. Supported types: ['.cube.gz', '.CHGCAR']." 
215 | ) 216 | 217 | @classmethod 218 | def from_structure( 219 | cls, structure: Structure, data: ArrayLike, cutoff: float = 6.0 220 | ) -> ChargeDensitySystem: 221 | """Create a ChargeDensitySystem from a structure and data.""" 222 | cd = ChargeDensity.from_pmg( 223 | VolumetricData(structure=structure, data={"total": data}), 224 | normalization=None, 225 | ) 226 | return cls(cd=cd, cutoff=cutoff) 227 | 228 | @classmethod 229 | def from_pmg( 230 | cls, volumetric_data: VolumetricData, cutoff: float = 6.0 231 | ) -> ChargeDensitySystem: 232 | """Create a ChargeDensitySystem from a pymatgen VolumetricData object.""" 233 | cd = ChargeDensity.from_pmg(volumetric_data, normalization="vasp") 234 | return cls(cd=cd, cutoff=cutoff) 235 | 236 | def to_hdf5(self, file_path: str, group: str = None) -> None: 237 | """Save the charge density data to an HDF5 file.""" 238 | data = np.reshape(self.data, (-1, 1)) 239 | chunk_size = 5000 if data.shape[0] > 5000 else data.shape[0] 240 | 241 | # Track for mean and std computation 242 | sums = np.sum(data, axis=0) 243 | sum_of_squares = np.sum(data**2, axis=0) 244 | n_points = data.shape[0] 245 | 246 | if group is not None: 247 | dataset_path = f"{group}/targets" 248 | else: 249 | dataset_path = "targets" 250 | 251 | with h5py.File(file_path, "a") as f: 252 | dataset = f.create_dataset( 253 | dataset_path, 254 | data=data, 255 | chunks=(chunk_size, 1), 256 | dtype=np.float32, 257 | compression="gzip", 258 | ) 259 | 260 | # This will be used for on-the-fly mean and std computation 261 | dataset.attrs["sum"] = sums 262 | dataset.attrs["sum_of_squares"] = sum_of_squares 263 | dataset.attrs["n_points"] = n_points 264 | 265 | def to_dict(self) -> Dict[str, ArrayLike]: 266 | """Convert the charge density data to a dictionary.""" 267 | data = np.reshape(self.data, (-1, 1)) 268 | 269 | # Track for mean and std computation 270 | sums = np.sum(data, axis=0) 271 | sum_of_squares = np.sum(data**2, axis=0) 272 | n_points = data.shape[0] 273 | 274 | metadata = { 275 | "sum": float(sums[0]), 276 | "sum_of_squares": float(sum_of_squares[0]), 277 | "n_points": n_points, 278 | } 279 | 280 | return {"data": data.tolist(), "metadata": metadata} 281 | 282 | def resample( 283 | self, 284 | factor: Optional[int] = 1, 285 | new_shape: Optional[Tuple[int, int, int]] = None, 286 | ) -> None: 287 | """Resample the charge density grid by a factor or to a new shape.""" 288 | if new_shape is None: 289 | new_shape = [int(s / factor) for s in self.shape] 290 | self.cd = self.cd.get_transformed(sc_mat=np.eye(3), grid_out=new_shape) 291 | 292 | def copy(self) -> ChargeDensitySystem: 293 | return copy.deepcopy(self) 294 | 295 | 296 | @lru_cache 297 | def build_coords(shape: Tuple[int, int, int], lattice: bytes) -> ArrayLike: 298 | """Converts a grid shape and lattice into a set of coordinates. 299 | 300 | The lattice is expected in bytes to enable caching. This avoids expensive 301 | calls to this function when the grid coordinates are unchanged. 302 | 303 | Args: 304 | shape (tuple): The shape of the grid. 305 | lattice (bytes): The lattice vectors of the grid. 306 | 307 | Returns: 308 | ArrayLike: The cartesian coordinates of the grid. 
309 | """ 310 | 311 | # Convert bytes back to numpy array 312 | lattice = np.frombuffer(lattice, dtype=np.float64).reshape((3, 3)) 313 | 314 | # Fractional values along each cartesian axis 315 | axes = [np.linspace(0, 1, n, endpoint=False) for n in shape] 316 | 317 | # Generate the fractional coordinates 318 | fractional_coords = np.array(list(product(*axes))) 319 | 320 | # Convert fractional to cartesian coordinates 321 | return fractional_coords.dot(lattice) 322 | 323 | 324 | def fractional_to_cartesian(lattice_matrix, fractional_coords): 325 | """Convert fractional coordinates to cartesian coordinates.""" 326 | return np.matmul(lattice_matrix.T, fractional_coords.T).T 327 | -------------------------------------------------------------------------------- /gpnn/data/builders.py: -------------------------------------------------------------------------------- 1 | """Dataset builders for charge density systems.""" 2 | 3 | import itertools 4 | import math 5 | import os 6 | import time 7 | from dataclasses import dataclass, field 8 | from datetime import datetime 9 | from pathlib import Path 10 | from typing import Dict, List, Literal, Optional, Sequence, Tuple, Union 11 | 12 | import h5py 13 | import numpy as np 14 | from tqdm.auto import tqdm 15 | from pymatgen.core import Element 16 | 17 | from gpnn.fingerprint import Fingerprint 18 | from gpnn.system import ChargeDensitySystem 19 | 20 | 21 | __author__ = "Jake Vikoren" 22 | __maintainer__ = "Jake Vikoren" 23 | __email__ = "jake.vikoren@genmat.xyz" 24 | __date__ = "07/11/2024" 25 | 26 | """This module contains classes for building datasets from charge density systems.""" 27 | 28 | 29 | @dataclass 30 | class DictSplits: 31 | """Orgnanizes dataset structures into appropriate splits. 32 | 33 | Args: 34 | structure_ids (List[str]): List of structure names. 35 | splits (Union[List[int], List[float], Dict[str, List[str]]): The splits to 36 | use for the dataset. This can be a list of integers, a list of floats, or a 37 | dictionary of lists of strings. If a list of integers or floats is provided, 38 | the splits will be assigned based on the proportion of the dataset that each 39 | split should contain. If a dictionary is provided, the keys should be the 40 | split names and the values should be lists of structure names. 41 | """ 42 | 43 | structure_ids: List[str] 44 | splits: Union[List[int], List[float], Dict[str, List[str]]] 45 | _dict_splits: Dict[str, List[str]] = field(init=False) 46 | 47 | def __post_init__(self): 48 | self._dict_splits = self._process_splits(self.splits) 49 | self._validate_dict_splits() 50 | 51 | @property 52 | def train(self) -> List[str]: 53 | """Return the IDs in the training split.""" 54 | return self._dict_splits.get("train", []) 55 | 56 | @property 57 | def val(self) -> List[str]: 58 | """Return the IDs in the validation split.""" 59 | return self._dict_splits.get("val", []) 60 | 61 | @property 62 | def test(self) -> List[str]: 63 | """Return the IDs in the test split.""" 64 | return self._dict_splits.get("test", []) 65 | 66 | def _validate_dict_splits(self): 67 | """Validate the dictionary of splits.""" 68 | assert set(self._dict_splits.keys()) <= {"train", "val", "test"} 69 | assert all(isinstance(x, str) for v in self._dict_splits.values() for x in v) 70 | assert all( 71 | len(set(v)) == len(v) for v in self._dict_splits.values() 72 | ), "All splits must contain unique structure names." 
73 | assert len( 74 | set(string for values in self._dict_splits.values() for string in values) 75 | ) == sum( 76 | len(values) for values in self._dict_splits.values() 77 | ), "All splits must be disjoint." 78 | 79 | def _process_splits( 80 | self, splits: Union[List[int], List[float], Dict[str, List[str]]] 81 | ) -> Dict[str, List[str]]: 82 | if isinstance(splits, list): 83 | if all(isinstance(x, float) for x in splits): 84 | return self._assign_splits_from_floats(splits) 85 | if all(isinstance(x, int) for x in splits): 86 | return self._assign_splits_from_ints(splits) 87 | raise ValueError( 88 | "All elements in the splits list must be of the same type " 89 | "(either all int or all float)." 90 | ) 91 | if isinstance(splits, dict): 92 | return splits 93 | else: 94 | raise TypeError( 95 | "splits must be either a list of integers, a list of floats, " 96 | "or a dictionary of lists of strings." 97 | ) 98 | 99 | def _assign_splits_from_floats(self, splits: List[float]) -> Dict[str, List[str]]: 100 | """Assign splits based on proportions.""" 101 | if not math.isclose(sum(splits), 1.0): 102 | raise ValueError("The sum of the splits proportions must be 1.") 103 | 104 | n = len(self.structure_ids) 105 | indices = list(range(n)) 106 | 107 | split_indices = { 108 | "train": indices[: int(n * splits[0])], 109 | "val": indices[int(n * splits[0]) : int(n * (splits[0] + splits[1]))], 110 | "test": indices[int(n * (splits[0] + splits[1])) :], 111 | } 112 | 113 | return { 114 | k: [self.structure_ids[i] for i in v] for k, v in split_indices.items() 115 | } 116 | 117 | def _assign_splits_from_ints(self, splits: List[int]) -> Dict[str, List[str]]: 118 | """Assign splits based on counts.""" 119 | if sum(splits) > len(self.structure_ids): 120 | raise ValueError( 121 | "The sum of the splits must be less than or equal to the number " 122 | "of structure names." 123 | ) 124 | iterator = iter(self.structure_ids) 125 | return { 126 | "train": list(itertools.islice(iterator, splits[0])), 127 | "val": list(itertools.islice(iterator, splits[1])), 128 | "test": list(itertools.islice(iterator, splits[2])), 129 | } 130 | 131 | def __getitem__(self, key): 132 | """Allows access to splits using dictionary syntax.""" 133 | return self._dict_splits.get(key, None) 134 | 135 | 136 | @dataclass(kw_only=True) 137 | class StructurePathHandler: 138 | """Paths to individual structure files and their ids. 139 | 140 | Args: 141 | structure_paths (List[str]): List of paths to structure files. 142 | structure_ids (List[str]): List of structure ids derived from filenames. 
143 | """ 144 | 145 | structure_paths: List[str] 146 | structure_ids: List[str] 147 | 148 | @classmethod 149 | def from_chgcars(cls, in_dir: str): 150 | """Initialize from CHGCAR files.""" 151 | paths = [] 152 | ids = [] 153 | for path in Path(in_dir).rglob("*.CHGCAR"): 154 | paths.append(str(path)) 155 | ids.append(str(path.stem)) 156 | 157 | return cls( 158 | structure_paths=paths, 159 | structure_ids=ids, 160 | ) 161 | 162 | 163 | @dataclass(kw_only=True) 164 | class HDF5DatasetBuilder: 165 | """Class responsible for building HDF5 dataset files.""" 166 | 167 | in_dir: str 168 | out_dir: str 169 | filename: str 170 | cutoff: float = 6.0 171 | shape: Optional[Tuple[int, int, int]] = None 172 | downsample_factor: Optional[int] = None 173 | dataset_elements: Sequence[int] = None 174 | splits: Union[List[int], List[float], Dict[str, List[str]]] 175 | device: Literal["cpu", "gpu"] = "cpu" 176 | 177 | def __post_init__(self): 178 | self.out_path = os.path.join(self.out_dir, self.filename) 179 | self.path_handler = StructurePathHandler.from_chgcars(in_dir=self.in_dir) 180 | self.structure_paths = self.path_handler.structure_paths 181 | self.structure_ids = self.path_handler.structure_ids 182 | self.dict_splits = DictSplits( 183 | structure_ids=self.structure_ids, splits=self.splits 184 | ) 185 | self.build_time = None # Populated when the dataset is built. 186 | 187 | # Elements included in the dataset sorted by electronegativity. 188 | self.dataset_elements = sorted( 189 | [ 190 | Element(e) if isinstance(e, str) else Element.from_Z(e) 191 | for e in self.dataset_elements 192 | ] 193 | ) 194 | 195 | def _to_hdf5(self) -> None: 196 | """Compute the features then write the features and targets to an HDF5 file.""" 197 | start_time = time.perf_counter() 198 | 199 | for structure_path, structure_id in tqdm( 200 | zip(self.structure_paths, self.structure_ids), 201 | total=len(self.structure_paths), 202 | desc="Building dataset", 203 | unit="structure", 204 | leave=False, 205 | ): 206 | # Construct system 207 | system = ChargeDensitySystem.from_file(structure_path, cutoff=self.cutoff) 208 | 209 | # Resample or downsample if necessary 210 | system.resample(factor=self.downsample_factor, new_shape=self.shape) 211 | 212 | # Compute fingerprint 213 | fingerprint = Fingerprint( 214 | system=system, 215 | cutoff=self.cutoff, 216 | dataset_elements=self.dataset_elements, 217 | device=self.device, 218 | ) 219 | 220 | # Write targets to HDF5 file 221 | system.to_hdf5(file_path=self.out_path, group=structure_id) 222 | 223 | # Write features to HDF5 file 224 | fingerprint.to_hdf5(file_path=self.out_path, group=structure_id) 225 | 226 | end_time = time.perf_counter() 227 | self.build_time = end_time - start_time 228 | 229 | def build(self, override: bool = False) -> None: 230 | """Compose the dataset file.""" 231 | out_path = Path(self.out_path) 232 | 233 | # Terminate or override if the dataset file already exists 234 | if out_path.exists(): 235 | if override: 236 | out_path.unlink() # remove existing file 237 | else: 238 | raise FileExistsError( 239 | f"This dataset already exists at {out_path}.\n" 240 | "Please remove it or choose a different filename." 
241 | ) 242 | else: 243 | out_path.parent.mkdir(parents=True, exist_ok=True) 244 | 245 | self._to_hdf5() # Compute and write features and targets 246 | self._assign_splits() # Using metadata flags ["train", "val", "test"] 247 | self._compute_mean_std() # For feature and target standardization 248 | self._write_metadata() # To track dataset properties 249 | 250 | def _assign_splits(self) -> None: 251 | """Assign splits to HDF5 groups using metadata.""" 252 | with h5py.File(self.out_path, "r+") as hdf5_file: 253 | for structure_id in self.structure_ids: 254 | hdf5_file[structure_id].attrs["split"] = self._get_split( 255 | structure_id 256 | ) 257 | 258 | def _get_split(self, structure_id: str) -> str: 259 | """Determine which split the structure belongs to.""" 260 | if structure_id in self.dict_splits.train: 261 | return "train" 262 | if structure_id in self.dict_splits.val: 263 | return "val" 264 | if structure_id in self.dict_splits.test: 265 | return "test" 266 | raise ValueError( 267 | f"Structure {structure_id} does not belong to any split." 268 | ) 269 | 270 | def _compute_mean_std(self): 271 | """Compute mean and std for features and targets following Welford's method.""" 272 | # Initialize aggregates for features 273 | total_features_sum = 0 274 | total_features_sum_of_squares = 0 275 | total_features_points = 0 276 | 277 | # Initialize aggregates for targets 278 | total_targets_sum = 0 279 | total_targets_sum_of_squares = 0 280 | total_targets_points = 0 281 | 282 | with h5py.File(self.out_path, "r") as hdf5_file: 283 | for structure_id in self.dict_splits.train: 284 | # Aggregate for features 285 | features_dataset = hdf5_file[f"{structure_id}/features"] 286 | total_features_sum += features_dataset.attrs["sum"] 287 | total_features_sum_of_squares += features_dataset.attrs[ 288 | "sum_of_squares" 289 | ] 290 | total_features_points += features_dataset.attrs["n_points"] 291 | 292 | # Aggregate for targets 293 | targets_dataset = hdf5_file[f"{structure_id}/targets"] 294 | total_targets_sum += targets_dataset.attrs["sum"] 295 | total_targets_sum_of_squares += targets_dataset.attrs["sum_of_squares"] 296 | total_targets_points += targets_dataset.attrs["n_points"] 297 | 298 | # Compute the overall mean and std for features 299 | features_mean = total_features_sum / total_features_points 300 | features_std = np.sqrt( 301 | total_features_sum_of_squares / total_features_points 302 | - np.square(features_mean) 303 | ) 304 | 305 | # Compute the overall mean and std for targets 306 | targets_mean = total_targets_sum / total_targets_points 307 | targets_std = np.sqrt( 308 | total_targets_sum_of_squares / total_targets_points 309 | - np.square(targets_mean) 310 | ) 311 | 312 | # Add computed values as attributes to the root group of the HDF5 file 313 | with h5py.File(self.out_path, "r+") as hdf5_file: 314 | hdf5_file.attrs["features_mean"] = features_mean.astype(np.float32) 315 | hdf5_file.attrs["features_std"] = features_std.astype(np.float32) 316 | hdf5_file.attrs["targets_mean"] = targets_mean.astype(np.float32) 317 | hdf5_file.attrs["targets_std"] = targets_std.astype(np.float32) 318 | 319 | def _write_metadata(self) -> None: 320 | """Write metadata to HDF5 group.""" 321 | # Collect dataset wide information 322 | metadata = { 323 | "date_created": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), 324 | "build_time": self.build_time if self.build_time else "NA", 325 | "cutoff": self.cutoff, 326 | "shape": str(self.shape), 327 | "downsample_factor": str(self.downsample_factor), 328 | # Elements 
must be stored as strings for HDF5 compatibility 329 | "dataset_symbols": [element.symbol for element in self.dataset_elements], 330 | "n_structures": len(self.structure_ids), 331 | "device": self.device, 332 | } 333 | with h5py.File(self.out_path, "r+") as hdf5_file: 334 | # Collect shape information from example structure. This will be 335 | # consistent across all structures in the dataset. 336 | example_structure = hdf5_file[self.structure_ids[0]] 337 | metadata["n_points"] = example_structure["features"].shape[0] 338 | metadata["n_features"] = example_structure["features"].shape[1] 339 | metadata["n_targets"] = example_structure["targets"].shape[1] 340 | 341 | # Write metadata to HDF5 file 342 | for key, value in metadata.items(): 343 | hdf5_file.attrs[key] = value 344 | -------------------------------------------------------------------------------- /gpnn/fingerprint/fingerprint.py: -------------------------------------------------------------------------------- 1 | """Core class for building fingerprints from charge density systems.""" 2 | 3 | from dataclasses import dataclass 4 | import math 5 | from typing import List, Literal 6 | 7 | import cupy as cp 8 | import dask 9 | import dask.array as da 10 | import h5py 11 | import numpy as np 12 | from numpy.typing import ArrayLike 13 | from pymatgen.core import Element 14 | from tqdm.auto import tqdm 15 | 16 | from gpnn.fingerprint.neighbors import Neighbors 17 | from gpnn.fingerprint.symmetry_functions import GaussianSymmetryFunctions 18 | from gpnn.system import ChargeDensitySystem 19 | 20 | dask.config.set({"array.slicing.split_large_chunks": False}) 21 | 22 | __author__ = "Jake Vikoren" 23 | __maintainer__ = "Jake Vikoren" 24 | __email__ = "jake.vikoren@genmat.xyz" 25 | __date__ = "07/11/2024" 26 | 27 | 28 | """The fingerprint class takes a structure with charge density data (ChargeDensitySystem) 29 | as input and uses it to compute the features that will be input into the ML model. The 30 | fingerprint is constructed using the list of elements present in the dataset. 31 | The "system_atomic_numbers" represents the elements present in a given structure and 32 | "dataset_atomic_numbers" includes all elements in the dataset. The fingerprint is computed 33 | for all elements in the structure and then zero columns are added for elements that are 34 | present in the dataset but not in the structure. This ensures that the feature shape aligns 35 | across the full dataset. The fingerprint is computed in batches to minimize memory issues. 36 | """ 37 | 38 | 39 | @dataclass 40 | class Fingerprint: 41 | """Class to processing a ChargeDensitySystem into a feature array. 42 | 43 | Args: 44 | system (ChargeDensitySystem): Charge density system. 45 | cutoff (float): Cutoff radius for the fingerprint. 46 | dataset_elements (List[int | str | Element]): Elements to include in the dataset. 47 | batch_size (int): Batch size for processing grid points. 48 | device (Literal["cpu", "gpu"]): Device to use for computation. 
49 | 50 | Returns: 51 | None 52 | """ 53 | 54 | system: ChargeDensitySystem 55 | cutoff: float = 6.0 56 | dataset_elements: List[int | str | Element] = None 57 | batch_size: int = 5_000 58 | device: Literal["cpu", "gpu"] = "cpu" 59 | 60 | def __post_init__(self): 61 | self.xp = cp if self.device == "gpu" else np 62 | self._symmetry_functions = GaussianSymmetryFunctions() 63 | self._prepare_dataset_elements() 64 | self._validate_inputs() 65 | 66 | def _validate_inputs(self): 67 | """Validate the input parameters.""" 68 | if self.cutoff <= 0: 69 | raise ValueError("Cutoff must be greater than zero.") 70 | if self.batch_size <= 0: 71 | raise ValueError("Batch size must be greater than zero.") 72 | if self.device not in ["cpu", "gpu"]: 73 | raise ValueError("Device must be either 'cpu' or 'gpu'.") 74 | if any( 75 | element not in self.dataset_elements for element in self.system.elements 76 | ): 77 | raise ValueError( 78 | "System elements must be a subset of the dataset elements.\n" 79 | f"System elements: {self.system_elements}\n" 80 | f"Dataset elements: {self.dataset_elements}" 81 | ) 82 | 83 | if self.n_dataset_elements < self.n_system_elements: 84 | raise ValueError( 85 | "Dataset elements must contain at least as many elements as the material." 86 | ) 87 | 88 | def _prepare_dataset_elements(self): 89 | """Ensure dataset elements are in the correct format.""" 90 | if not self.dataset_elements: 91 | # Use the elements present in the given structure 92 | self.dataset_elements = self.system_elements 93 | else: 94 | # Convert from atomic symbol (string) or atomic number (int) to PyMatGen Elements 95 | self.dataset_elements = [ 96 | ( 97 | Element(element) 98 | if isinstance(element, str) 99 | else ( 100 | Element.from_Z(element) if isinstance(element, int) else element 101 | ) 102 | ) 103 | for element in self.dataset_elements 104 | ] 105 | 106 | @property 107 | def system_elements(self) -> List[Element]: 108 | """Elements contained in the system.""" 109 | return self.system.elements 110 | 111 | @property 112 | def system_symbols(self) -> List[str]: 113 | """Symbols of the elements in the system.""" 114 | return self.system.symbols 115 | 116 | @property 117 | def dataset_symbols(self) -> List[str]: 118 | """Symbols of the elements in the full dataset.""" 119 | return [element.symbol for element in self.dataset_elements] 120 | 121 | @property 122 | def system_atomic_numbers(self) -> List[int]: 123 | """Unique atomic numbers in the system.""" 124 | return self.system.atomic_numbers 125 | 126 | @property 127 | def dataset_atomic_numbers(self) -> List[int]: 128 | """Unique atomic numbers in the full dataset.""" 129 | return [element.Z for element in self.dataset_elements] 130 | 131 | @property 132 | def n_system_elements(self) -> int: 133 | """Number of elements in the system.""" 134 | return len(self.system_elements) 135 | 136 | @property 137 | def n_dataset_elements(self) -> int: 138 | """Number of elements in the full dataset.""" 139 | return len(self.dataset_elements) 140 | 141 | @property 142 | def n_features_per_element(self) -> int: 143 | """Number of features per element.""" 144 | return 2 * self._symmetry_functions.N_GAUSSIANS 145 | 146 | @property 147 | def n_points(self) -> int: 148 | """Number of grid points.""" 149 | return self.system.n_points 150 | 151 | @property 152 | def n_features(self) -> int: 153 | """Number of features.""" 154 | return self.n_features_per_element * self.n_dataset_elements 155 | 156 | @property 157 | def shape(self) -> tuple[int, int]: 158 | """Shape of the 
fingerprint.""" 159 | return (self.n_points, self.n_features) 160 | 161 | @property 162 | def n_batches(self) -> int: 163 | """Number of batches.""" 164 | return math.ceil(self.system.n_points / self.batch_size) 165 | 166 | @property 167 | def chunk_size(self) -> int: 168 | """Size of chunk.""" 169 | return 1_000 170 | 171 | def get(self) -> ArrayLike: 172 | """Compute the fingerprint.""" 173 | results = [ 174 | self._process_batch(batch) 175 | for batch in tqdm( 176 | self._get_batches(), 177 | desc="Computing Fingerprint", 178 | total=self.n_batches, 179 | unit="batch", 180 | ) 181 | ] 182 | return np.concatenate(results, axis=0) 183 | 184 | def to_hdf5(self, file_path: str, group: str = None) -> None: 185 | """Compute and save the fingerprint to an HDF5 file. 186 | 187 | This method computes the fingerprints in batches and writes to an HDF5 file 188 | on-the-fly. This is useful for large datasets that may not fit into memory. 189 | 190 | Args: 191 | file_path (str): Path to the HDF5 file. 192 | group (str): Group within the HDF5 file to save the fingerprint. 193 | 194 | Returns: 195 | None 196 | """ 197 | with h5py.File(file_path, mode="a") as f: 198 | dataset = f.require_dataset( 199 | f"{group}/features", 200 | shape=(0, self.n_features), 201 | chunks=(5000, self.n_features), 202 | maxshape=(None, self.n_features), 203 | dtype=np.float32, 204 | compression="gzip", 205 | ) 206 | 207 | # Initialize sum, sum of squares, and count 208 | sum_values = np.zeros(self.n_features, dtype=np.float32) 209 | sum_of_squares = np.zeros(self.n_features, dtype=np.float32) 210 | n_points = 0 211 | 212 | for batch in tqdm( 213 | self._get_batches(), 214 | desc="Writing Fingerprint", 215 | total=self.n_batches, 216 | unit="batch", 217 | ): 218 | # Compute the fingerprint of the batch 219 | batch = self._process_batch(batch) 220 | 221 | # Update sum, sum of squares, and count with the new batch 222 | sum_values += np.sum(batch, axis=0) 223 | sum_of_squares += np.sum(np.square(batch), axis=0) 224 | n_points += batch.shape[0] 225 | 226 | # Resize the dataset to accommodate the new batch 227 | new_size = dataset.shape[0] + batch.shape[0] 228 | dataset.resize(new_size, axis=0) 229 | 230 | # Write the batch to the dataset 231 | dataset[-batch.shape[0]:] = batch 232 | 233 | # This will be used for on-the-fly mean and std computation 234 | dataset.attrs["sum"] = sum_values 235 | dataset.attrs["sum_of_squares"] = sum_of_squares 236 | dataset.attrs["n_points"] = n_points 237 | 238 | def _get_batches(self) -> List[ArrayLike]: 239 | """Get the grid coordinates in batches.""" 240 | batches = np.array_split(self.system.grid_coords, self.n_batches) 241 | return batches 242 | 243 | def _process_batch(self, batch_grid_coords: ArrayLike) -> da.Array: 244 | """Compute the fingerprint of a single structure.""" 245 | self.batch_size = len(batch_grid_coords) 246 | 247 | # Calculate neighbors, distances, and offset vectors for gridpoint/atom pairs. 248 | neighbors = Neighbors( 249 | grid_coords=batch_grid_coords, 250 | atom_coords=self.system.atom_coords, 251 | lattice=self.system.lattice, 252 | cutoff=self.cutoff, 253 | ) 254 | 255 | # The scalar distance array stores the distances between grid points and atoms. 256 | # This data is used to compute the scalar fingerprint in a broadcasted manner. 
257 | scalar_distance_arr = self._build_scalar_distance_array( 258 | neighbors.atom_idxs, 259 | neighbors.grid_idxs, 260 | neighbors.distances, 261 | self.chunk_size, 262 | ) 263 | 264 | # We first compute a copy of these distances modified by our cutoff function. 265 | # This function smoothly decays to zero at the cutoff radius focusing on the 266 | # region of space near the atom. 267 | cutoff_scalar_distance_arr = self._cutoff_function(scalar_distance_arr) 268 | 269 | # Now we modify the scalar distances by the Gaussian symmetry functions resulting 270 | # in the scalar fingerprint array. 271 | scalar_fingerprint_arr = self._get_scalar_fingerprints( 272 | scalar_distance_arr, 273 | cutoff_scalar_distance_arr, 274 | self._symmetry_functions.NORMALIZING_CONSTANTS, 275 | self._symmetry_functions.GAUSSIAN_STANDARD_DEVIATIONS, 276 | ) 277 | 278 | # During post processing, we sum over contributions of individual atoms to each 279 | # grid point. Each species is summed separately and the results are concatenated. 280 | reduced_scalar_fingerprints = self._postprocess_scalar_fingerprints( 281 | scalar_fingerprint_arr 282 | ) 283 | 284 | # The above process is repeated for the x, y, and z components of the vector offsets. 285 | vector_distance_arr = self._build_vector_distance_array( 286 | neighbors.atom_idxs, 287 | neighbors.grid_idxs, 288 | neighbors.offset_vectors, 289 | self.chunk_size, 290 | ) 291 | 292 | # The scalar distances above are reused to accelerate this portion of the computation. 293 | vector_fingerprint_arr = self._get_vector_fingerprints( 294 | vector_distance_arr, 295 | scalar_fingerprint_arr, 296 | self._symmetry_functions.GAUSSIAN_STANDARD_DEVIATIONS, 297 | ) 298 | 299 | # Postprocessing is performed to combine the x, y, and z components of the vector 300 | # fingerprints into a single rotationally invariant vector fingerprint. 301 | reduced_vector_fingerprints = self._postprocess_vector_fingerprints( 302 | vector_fingerprint_arr 303 | ) 304 | 305 | # Scalar and vector fingerprints are now combined into a single array. 
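        # Column layout after interleaving: for each element, a block of N_GAUSSIANS
        # scalar features followed by a block of N_GAUSSIANS vector features,
        # i.e. 2 * N_GAUSSIANS features per element.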
306 | fingerprint_array = self._column_block_interleave( 307 | array1=reduced_scalar_fingerprints, 308 | array2=reduced_vector_fingerprints, 309 | block_size=self._symmetry_functions.N_GAUSSIANS, 310 | ).compute() 311 | 312 | return self._postprocess_full_fingerprints(fingerprint_array) 313 | 314 | def _build_scalar_distance_array( 315 | self, 316 | atom_idxs: ArrayLike, 317 | grid_idxs: ArrayLike, 318 | distances: ArrayLike, 319 | chunk_size: int = -1, 320 | ): 321 | """Build the scalar distance array.""" 322 | # TODO: consider COO sparse arrays 323 | # Set scalar array 324 | scalar_distance_arr = self.xp.zeros((self.batch_size, self.system.n_atoms)) 325 | scalar_distance_arr[grid_idxs, atom_idxs] = distances 326 | scalar_distance_arr = da.from_array( 327 | scalar_distance_arr, chunks=(chunk_size, self.system.n_atoms) 328 | ) 329 | return scalar_distance_arr.astype(self.xp.float32) 330 | 331 | def _build_vector_distance_array( 332 | self, 333 | atom_idxs: ArrayLike, 334 | grid_idxs: ArrayLike, 335 | offset_vectors: ArrayLike, 336 | chunk_size: int = -1, 337 | ): 338 | """Build the vector distance array.""" 339 | # TODO: consider COO sparse arrays 340 | # Set vector array 341 | vector_distance_arr = self.xp.zeros((self.batch_size, self.system.n_atoms, 3)) 342 | vector_distance_arr[grid_idxs, atom_idxs] = offset_vectors 343 | vector_distance_arr = da.from_array( 344 | vector_distance_arr, 345 | chunks=(chunk_size, self.system.n_atoms, 3), 346 | ) 347 | 348 | return vector_distance_arr.astype(self.xp.float32) 349 | 350 | def _get_scalar_fingerprints( 351 | self, 352 | scalar_distances: da.Array, 353 | cutoff_scalar_distances: da.Array, 354 | NORMALIZING_CONSTANTS: ArrayLike, 355 | GAUSSIAN_STANDARD_DEVIATIONS: ArrayLike, 356 | ): 357 | """Compute the scalar fingerprint array.""" 358 | 359 | N_GAUSSIANS = len(GAUSSIAN_STANDARD_DEVIATIONS) 360 | NORMALIZING_CONSTANTS = self.xp.array(NORMALIZING_CONSTANTS) 361 | GAUSSIAN_STANDARD_DEVIATIONS = self.xp.array(GAUSSIAN_STANDARD_DEVIATIONS) 362 | 363 | # Shapes to enable broadcasting 364 | scalar_fingperint_shape = (self.batch_size, self.system.n_atoms, 1) 365 | gaussian_component_shape = ( 1, 1, N_GAUSSIANS) 366 | 367 | scalar_distances = da.reshape(scalar_distances, shape=scalar_fingperint_shape) 368 | cutoff_scalar_distances = da.reshape( 369 | cutoff_scalar_distances, shape=scalar_fingperint_shape 370 | ) 371 | NORMALIZING_CONSTANTS = self.xp.reshape( 372 | NORMALIZING_CONSTANTS, gaussian_component_shape 373 | ) 374 | GAUSSIAN_STANDARD_DEVIATIONS = self.xp.reshape( 375 | GAUSSIAN_STANDARD_DEVIATIONS, gaussian_component_shape 376 | ) 377 | 378 | # Set mask to cancel out zero values impacted by the exponential 379 | mask = scalar_distances != 0 380 | mask = da.reshape(mask, shape=scalar_fingperint_shape) 381 | 382 | exponential = da.exp( 383 | (-(scalar_distances**2)) / (2 * GAUSSIAN_STANDARD_DEVIATIONS**2) 384 | ) 385 | 386 | return NORMALIZING_CONSTANTS * (exponential * cutoff_scalar_distances) * mask 387 | 388 | def _get_vector_fingerprints( 389 | self, 390 | vector_distances: da.Array, 391 | scalar_fingerprints: da.Array, 392 | GAUSSIAN_STANDARD_DEVIATIONS: ArrayLike, 393 | ): 394 | """Compute the vector fingerprint array.""" 395 | 396 | N_GAUSSIANS = len(GAUSSIAN_STANDARD_DEVIATIONS) 397 | GAUSSIAN_STANDARD_DEVIATIONS = self.xp.array(GAUSSIAN_STANDARD_DEVIATIONS) 398 | 399 | # Shapes to enable broadcasting 400 | vector_fingerprints_shape = (self.batch_size, self.system.n_atoms, 1, 3) 401 | scalar_fingerprints_shape = (self.batch_size, 
self.system.n_atoms, N_GAUSSIANS, 1) 402 | gauss_standard_devs_shape = ( 1, 1, N_GAUSSIANS, 1) 403 | 404 | # Apply reshaping 405 | vector_distances = da.reshape(vector_distances, shape=vector_fingerprints_shape) 406 | scalar_fingerprints = da.reshape( 407 | scalar_fingerprints, shape=scalar_fingerprints_shape 408 | ) 409 | GAUSSIAN_STANDARD_DEVIATIONS = self.xp.reshape( 410 | GAUSSIAN_STANDARD_DEVIATIONS, gauss_standard_devs_shape 411 | ) 412 | 413 | # Compute the vector fingerprints 414 | vector_fingerprints = vector_distances * scalar_fingerprints 415 | return vector_fingerprints / (2 * GAUSSIAN_STANDARD_DEVIATIONS**2) 416 | 417 | def _postprocess_scalar_fingerprints( 418 | self, 419 | scalar_fingerprints: da.Array, 420 | ): 421 | """Postprocess the scalar fingerprints.""" 422 | reduced = self._hsplit_reduce(scalar_fingerprints) 423 | return da.concatenate(reduced, axis=1) 424 | 425 | def _postprocess_vector_fingerprints( 426 | self, 427 | vector_fingerprints: da.Array, 428 | ): 429 | """Postprocess the vector fingerprints.""" 430 | reduced = self._hsplit_reduce(vector_fingerprints) 431 | combined = da.concatenate(reduced, axis=1) 432 | return da.sqrt(da.sum(combined**2, axis=-1)) # Invariant 433 | 434 | def _postprocess_full_fingerprints(self, fingerprint_array: ArrayLike): 435 | """Postprocess the combined fingerprints.""" 436 | if self.device == "gpu": 437 | fingerprint_array = fingerprint_array.get() 438 | 439 | return self._match_dataset_elements(fingerprint_array) 440 | 441 | def _cutoff_function(self, x: ArrayLike): 442 | return 0.5 * (da.cos(np.pi * x / self.cutoff) + 1) 443 | 444 | def _hsplit_reduce(self, arr: ArrayLike): 445 | """Split and sum along the first axis.""" 446 | # Start and end indices for each element in the fingerprint. 447 | slice_indices = [0, *np.cumsum(list(self.system.element_amount_dict.values()))] 448 | return [ 449 | arr[:, start:end].sum(axis=1) 450 | for start, end in zip(slice_indices[:-1], slice_indices[1:]) 451 | ] 452 | 453 | def _match_dataset_elements(self, batch: ArrayLike): 454 | """Populate the feature array with zero columns for missing elements. 455 | 456 | This function adds zero columns to the feature array for elements that are 457 | present in the dataset but not in the system. This ensures that the feature 458 | shape aligns across the full dataset. 
459 | """ 460 | if set(self.dataset_elements) == set(self.system_elements): 461 | return batch 462 | 463 | out_arr = np.zeros((batch.shape[0], self.n_features), dtype=np.float32) 464 | 465 | system_start_idx_by_atomic_number = { 466 | z: i * self.n_features_per_element for i, z in enumerate(self.system_atomic_numbers) 467 | } 468 | 469 | dataset_start_idx_by_atomic_number = { 470 | z: i * self.n_features_per_element 471 | for i, z in enumerate(self.dataset_atomic_numbers) 472 | } 473 | 474 | # Populate appropriate columns with features 475 | for z in self.system_atomic_numbers: 476 | system_start_idx = system_start_idx_by_atomic_number[z] 477 | system_end_idx = system_start_idx + self.n_features_per_element 478 | 479 | dataset_start_idx = dataset_start_idx_by_atomic_number[z] 480 | dataset_end_idx = dataset_start_idx + self.n_features_per_element 481 | 482 | out_arr[:, dataset_start_idx:dataset_end_idx] = batch[ 483 | :, system_start_idx:system_end_idx 484 | ] 485 | 486 | return out_arr 487 | 488 | def _column_block_interleave( 489 | self, array1: ArrayLike, array2: ArrayLike, block_size: int 490 | ): 491 | """Interleave the blocks of two arrays along the columns axis""" 492 | # Ensure the arrays have compatible shapes 493 | assert array1.shape == array2.shape, "Arrays must have the same shape" 494 | 495 | N, total_columns = array1.shape 496 | M = total_columns // block_size # Number of blocks 497 | 498 | # Reshape the arrays to expose the blocks 499 | reshaped1 = array1.reshape(N, M, block_size) 500 | reshaped2 = array2.reshape(N, M, block_size) 501 | 502 | # Stack the arrays along a new axis to interleave the blocks 503 | # The new shape will be (N, M, 2, block_size) 504 | interleaved = da.stack([reshaped1, reshaped2], axis=2) 505 | 506 | # Reshape back to (N, M * 2 * block_size) to intermix blocks 507 | result = interleaved.reshape(N, M * 2 * block_size) 508 | 509 | return result 510 | --------------------------------------------------------------------------------