├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── THIRD-PARTY-LICENSES ├── alchemical_mace ├── __init__.py ├── calculator.py ├── model.py ├── optimize.py └── utils.py ├── data ├── perovskite_dataset.json └── structures │ ├── AlN_hex.cif │ ├── BiSBr.cif │ ├── CeO2.cif │ ├── CsPbI3_alpha.cif │ ├── CsPbI3_delta.cif │ ├── CsSnI3_alpha.cif │ ├── CsSnI3_delta.cif │ ├── Fe.cif │ ├── GaN_hex.cif │ └── NaCl.cif ├── notebooks ├── 1_solid_solution.ipynb ├── 2_compositional_optimization.ipynb ├── 3_disorder_energy.ipynb ├── 4_vacancy_analysis.ipynb └── 5_perovskite_analysis.ipynb ├── pyproject.toml ├── requirements.txt └── scripts ├── perovskite_alchemy.py ├── perovskite_frenkel_ladd.py └── vacancy_frenkel_ladd.py /.gitignore: -------------------------------------------------------------------------------- 1 | archived/ 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | #poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/#use-with-ide 112 | .pdm.toml 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | 164 | # Results directory 165 | results/ 166 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.6.0 4 | hooks: 5 | - id: check-added-large-files 6 | - id: check-yaml 7 | - id: debug-statements 8 | - id: end-of-file-fixer 9 | - id: trailing-whitespace 10 | - repo: https://github.com/astral-sh/ruff-pre-commit 11 | rev: v0.4.2 12 | hooks: 13 | - id: ruff 14 | types_or: [ python, pyi, jupyter ] 15 | args: [ --fix ] 16 | - id: ruff-format 17 | types_or: [ python, pyi, jupyter ] 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Juno Nam and Rafael Gomez-Bombarelli 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Alchemical MLIP 2 | [![arXiv](https://img.shields.io/badge/arXiv-2404.10746-84cc16)](https://arxiv.org/abs/2404.10746) 3 | [![Zenodo](https://img.shields.io/badge/DOI-10.5281/zenodo.11081492-14b8a6.svg)](https://zenodo.org/doi/10.5281/zenodo.11081492) 4 | [![MIT](https://img.shields.io/badge/License-MIT-3b82f6.svg)](https://opensource.org/license/mit) 5 | 6 | This repository contains the code to modify machine learning interatomic potentials (MLIPs) to enable continuous and differentiable alchemical transformations. 7 | Currently, we provide the alchemical modification for the [MACE](https://github.com/ACEsuit/mace) model. 8 | The details of the method are described in the paper: [Interpolation and differentiation of alchemical degrees of freedom in machine learning interatomic potentials](https://arxiv.org/abs/2404.10746). 9 | 10 | ## Installation 11 | We tested the code with Python 3.10 and the packages in `requirements.txt`. 12 | For example, you can create a conda environment and install the required packages as follows (assuming CUDA 11.8): 13 | ```bash 14 | conda create -n alchemical-mlip python=3.10 15 | conda activate alchemical-mlip 16 | pip install torch==2.0.1 --index-url https://download.pytorch.org/whl/cu118 17 | pip install -r requirements.txt 18 | pip install -e . 19 | ``` 20 | 21 | ## Static calculations 22 | We provide the Jupyter notebooks for the lattice parameter calculations (Fig. 2 in the paper) and the compositional optimization (Fig. 3) in the `notebooks` directory.
``` 24 | notebooks/ 25 | ├── 1_solid_solution.ipynb 26 | └── 2_compositional_optimization.ipynb 27 | ``` 28 | 29 | ## Free energy calculations 30 | We provide the scripts for the free energy calculations for the vacancy (Fig. 4) and perovskites (Fig. 5) in the `scripts` directory. 31 | ``` 32 | scripts/ 33 | ├── vacancy_frenkel_ladd.py 34 | ├── perovskite_frenkel_ladd.py 35 | └── perovskite_alchemy.py 36 | ``` 37 | 38 | The arguments for the scripts are as follows (run the commands from the repository root): 39 | ```bash 40 | # Vacancy Frenkel-Ladd calculation 41 | python scripts/vacancy_frenkel_ladd.py \ 42 | --structure-file data/structures/Fe.cif \ 43 | --supercell 5 5 5 \ 44 | --temperature 100 \ 45 | --output-dir data/results/vacancy/Fe_5x5x5_100K/0 46 | 47 | # Perovskite Frenkel-Ladd calculation (alpha phase) 48 | python scripts/perovskite_frenkel_ladd.py \ 49 | --structure-file data/structures/CsPbI3_alpha.cif \ 50 | --supercell 6 6 6 \ 51 | --temperature 400 \ 52 | --output-dir data/results/perovskite/frenkel_ladd/CsPbI3_alpha_6x6x6_400K/0 53 | 54 | # Perovskite Frenkel-Ladd calculation (delta phase) 55 | python scripts/perovskite_frenkel_ladd.py \ 56 | --structure-file data/structures/CsPbI3_delta.cif \ 57 | --supercell 6 3 3 \ 58 | --temperature 400 \ 59 | --output-dir data/results/perovskite/frenkel_ladd/CsPbI3_delta_6x3x3_400K/0 60 | 61 | # Perovskite alchemy calculation (alpha phase) 62 | python -u scripts/perovskite_alchemy.py \ 63 | --structure-file data/structures/CsPbI3_alpha.cif \ 64 | --supercell 6 6 6 \ 65 | --switch-pair Pb Sn \ 66 | --temperature 400 \ 67 | --output-dir data/results/perovskite/alchemy/CsPbI3_CsSnI3_alpha_400K/0 68 | 69 | # Perovskite alchemy calculation (delta phase) 70 | python -u scripts/perovskite_alchemy.py \ 71 | --structure-file data/structures/CsPbI3_delta.cif \ 72 | --supercell 6 3 3 \ 73 | --switch-pair Pb Sn \ 74 | --temperature 400 \ 75 | --output-dir data/results/perovskite/alchemy/CsPbI3_CsSnI3_delta_400K/0 76 | ``` 77 | 78 | The result files are large and not included in the repository.
79 | If you want to reproduce the results without running the calculations, the result files are available in the [Zenodo repository](https://zenodo.org/doi/10.5281/zenodo.11081395). 80 | Please download the files and place them in the `data/results` directory. 81 | 82 | The post-processing notebooks for the free energy calculations are provided in the `notebooks` directory. 83 | ``` 84 | notebooks/ 85 | ├── 4_vacancy_analysis.ipynb 86 | └── 5_perovskite_analysis.ipynb 87 | ``` 88 | 89 | ## Citation 90 | ``` 91 | @misc{nam2024interpolation, 92 | title={Interpolation and differentiation of alchemical degrees of freedom in machine learning interatomic potentials}, 93 | author={Juno Nam and Rafael G{\'o}mez-Bombarelli}, 94 | year={2024}, 95 | eprint={2404.10746}, 96 | archivePrefix={arXiv}, 97 | primaryClass={cond-mat.mtrl-sci} 98 | } 99 | ``` 100 | -------------------------------------------------------------------------------- /THIRD-PARTY-LICENSES: -------------------------------------------------------------------------------- 1 | Code in alchemical_mace/{calculator,model}.py is adapted from 2 | https://github.com/ACEsuit/mace 3 | 4 | MIT License 5 | 6 | Copyright (c) 2022 ACEsuit/mace 7 | 8 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
11 | 12 | ------------------------------------------------------------------------------- 13 | 14 | Code in alchemical_mace/utils.py is adapted from 15 | https://github.com/CederGroupHub/chgnet 16 | 17 | Crystal Hamiltonian Graph neural Network (CHGNet) Copyright (c) 2023, The Regents 18 | of the University of California, through Lawrence Berkeley National 19 | Laboratory (subject to receipt of any required approvals from the U.S. 20 | Dept. of Energy) and the University of California, Berkeley. All rights reserved. 21 | 22 | Redistribution and use in source and binary forms, with or without 23 | modification, are permitted provided that the following conditions are met: 24 | 25 | (1) Redistributions of source code must retain the above copyright notice, 26 | this list of conditions and the following disclaimer. 27 | 28 | (2) Redistributions in binary form must reproduce the above copyright 29 | notice, this list of conditions and the following disclaimer in the 30 | documentation and/or other materials provided with the distribution. 31 | 32 | (3) Neither the name of the University of California, Lawrence Berkeley 33 | National Laboratory, U.S. Dept. of Energy, University of California, 34 | Berkeley nor the names of its contributors may be used to endorse or 35 | promote products derived from this software without specific prior written 36 | permission. 37 | 38 | 39 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 40 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 41 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 42 | ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 43 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 44 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 45 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 46 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 47 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 48 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 49 | POSSIBILITY OF SUCH DAMAGE. 50 | 51 | You are under no obligation whatsoever to provide any bug fixes, patches, 52 | or upgrades to the features, functionality or performance of the source 53 | code ("Enhancements") to anyone; however, if you choose to make your 54 | Enhancements available either publicly, or directly to Lawrence Berkeley 55 | National Laboratory, without imposing a separate written license agreement 56 | for such Enhancements, then you hereby grant the following license: a 57 | non-exclusive, royalty-free perpetual license to install, use, modify, 58 | prepare derivative works, incorporate into other computer software, 59 | distribute, and sublicense such enhancements or derivative works thereof, 60 | in binary and source code form. 
61 | -------------------------------------------------------------------------------- /alchemical_mace/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/learningmatter-mit/alchemical-mlip/b556b850e9e89a4c2e1f5e9c27f03a3b1b393864/alchemical_mace/__init__.py -------------------------------------------------------------------------------- /alchemical_mace/calculator.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence, Tuple 2 | 3 | import ase 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from ase.calculators.calculator import Calculator, all_changes 8 | from ase.constraints import ExpCellFilter 9 | from ase.optimize import FIRE 10 | from ase.stress import full_3x3_to_voigt_6_stress 11 | from mace import data 12 | from mace.calculators import mace_mp 13 | from mace.tools import torch_geometric 14 | 15 | from alchemical_mace.model import ( 16 | AlchemicalPair, 17 | AlchemyManager, 18 | alchemical_mace_mp, 19 | get_z_table_and_r_max, 20 | ) 21 | 22 | ################################################################################ 23 | # Alchemical MACE calculator 24 | ################################################################################ 25 | 26 | 27 | class AlchemicalMACECalculator(Calculator): 28 | """ 29 | Alchemical MACE calculator for ASE. 30 | """ 31 | 32 | def __init__( 33 | self, 34 | atoms: ase.Atoms, 35 | alchemical_pairs: Sequence[Sequence[Tuple[int, int]]], 36 | alchemical_weights: Sequence[float], 37 | device: str = "cpu", 38 | model: str = "medium", 39 | ): 40 | """ 41 | Initialize the Alchemical MACE calculator. 42 | 43 | Args: 44 | atoms (ase.Atoms): Atoms object. 45 | alchemical_pairs (Sequence[Sequence[Tuple[int, int]]]): List of 46 | alchemical pairs. Each pair is a tuple of the atom index and 47 | atomic number of an alchemical atom. 
48 | alchemical_weights (Sequence[float]): List of alchemical weights. 49 | device (str): Device to run the calculations on. 50 | model (str): Model to use for the MACE calculator. 51 | """ 52 | Calculator.__init__(self) 53 | self.results = {} 54 | self.implemented_properties = ["energy", "free_energy", "forces", "stress"] 55 | 56 | # Build the alchemical MACE model 57 | self.device = device 58 | self.model = alchemical_mace_mp( 59 | model=model, device=device, default_dtype="float32" 60 | ) 61 | for param in self.model.parameters(): 62 | param.requires_grad = False 63 | 64 | # Set AlchemyManager 65 | z_table, r_max = get_z_table_and_r_max(self.model) 66 | alchemical_weights = torch.tensor(alchemical_weights, dtype=torch.float32) 67 | self.alchemy_manager = AlchemyManager( 68 | atoms=atoms, 69 | alchemical_pairs=alchemical_pairs, 70 | alchemical_weights=alchemical_weights, 71 | z_table=z_table, 72 | r_max=r_max, 73 | ).to(self.device) 74 | 75 | # Disable alchemical weights gradients by default 76 | self.alchemy_manager.alchemical_weights.requires_grad = False 77 | self.calculate_alchemical_grad = False 78 | 79 | self.num_atoms = len(atoms) 80 | 81 | def set_alchemical_weights(self, alchemical_weights: Sequence[float]): 82 | alchemical_weights = torch.tensor( 83 | alchemical_weights, 84 | dtype=torch.float32, 85 | device=self.device, 86 | ) 87 | self.alchemy_manager.alchemical_weights.data = alchemical_weights 88 | 89 | def get_alchemical_atomic_masses(self) -> np.ndarray: 90 | # Get atomic masses for alchemical atoms 91 | node_masses = ase.data.atomic_masses[self.alchemy_manager.atomic_numbers] 92 | weights = self.alchemy_manager.alchemical_weights.data 93 | weights = F.pad(weights, (1, 0), "constant", 1.0).cpu().numpy() 94 | node_weights = weights[self.alchemy_manager.weight_indices] 95 | 96 | # Scatter sum to get the atomic masses 97 | atom_masses = np.zeros(self.num_atoms, dtype=np.float32) 98 | np.add.at( 99 | atom_masses, self.alchemy_manager.atom_indices, 
node_masses * node_weights 100 | ) 101 | return atom_masses 102 | 103 | # pylint: disable=dangerous-default-value 104 | def calculate(self, atoms=None, properties=None, system_changes=all_changes): 105 | # call to base-class to set atoms attribute 106 | Calculator.calculate(self, atoms) 107 | 108 | # prepare data 109 | tensor_kwargs = {"dtype": torch.float32, "device": self.device} 110 | positions = torch.tensor(atoms.get_positions(), **tensor_kwargs) 111 | cell = torch.tensor(atoms.get_cell().array, **tensor_kwargs) 112 | if self.calculate_alchemical_grad: 113 | self.alchemy_manager.alchemical_weights.requires_grad = True 114 | batch = self.alchemy_manager(positions, cell).to(self.device) 115 | 116 | # get outputs 117 | if self.calculate_alchemical_grad: 118 | out = self.model(batch, compute_stress=True, compute_alchemical_grad=True) 119 | (grad,) = torch.autograd.grad( 120 | outputs=[batch["node_weights"], batch["edge_weights"]], 121 | inputs=[self.alchemy_manager.alchemical_weights], 122 | grad_outputs=[out["node_grad"], out["edge_grad"]], 123 | retain_graph=False, 124 | create_graph=False, 125 | ) 126 | grad = grad.cpu().numpy() 127 | self.alchemy_manager.alchemical_weights.requires_grad = False 128 | else: 129 | out = self.model(batch, retain_graph=False, compute_stress=True) 130 | grad = np.zeros( 131 | self.alchemy_manager.alchemical_weights.shape[0], dtype=np.float32 132 | ) 133 | 134 | # store results 135 | self.results = {} 136 | self.results["energy"] = out["energy"][0].item() 137 | self.results["free_energy"] = self.results["energy"] 138 | self.results["forces"] = out["forces"].detach().cpu().numpy() 139 | self.results["stress"] = full_3x3_to_voigt_6_stress( 140 | out["stress"][0].detach().cpu().numpy() 141 | ) 142 | self.results["alchemical_grad"] = grad 143 | 144 | 145 | class NVTMACECalculator(Calculator): 146 | def __init__(self, model: str = "medium", device: str = "cuda"): 147 | Calculator.__init__(self) 148 | self.results = {} 149 | 
self.implemented_properties = ["energy", "free_energy", "forces", "stress"] 150 | self.device = device 151 | self.model = mace_mp( 152 | model=model, device=device, default_dtype="float32" 153 | ).models[0] 154 | self.z_table, self.r_max = get_z_table_and_r_max(self.model) 155 | for param in self.model.parameters(): 156 | param.requires_grad = False 157 | 158 | # pylint: disable=dangerous-default-value 159 | def calculate(self, atoms=None, properties=None, system_changes=all_changes): 160 | # call to base-class to set atoms attribute 161 | Calculator.calculate(self, atoms) 162 | 163 | # prepare data 164 | config = data.config_from_atoms(atoms) 165 | atomic_data = data.AtomicData.from_config( 166 | config, z_table=self.z_table, cutoff=self.r_max 167 | ) 168 | data_loader = torch_geometric.dataloader.DataLoader( 169 | dataset=[atomic_data], 170 | batch_size=1, 171 | shuffle=False, 172 | drop_last=False, 173 | ) 174 | batch = next(iter(data_loader)).to(self.device) 175 | 176 | out = self.model(batch, compute_stress=False) 177 | self.results = {} 178 | self.results["energy"] = out["energy"][0].item() 179 | self.results["free_energy"] = self.results["energy"] 180 | self.results["forces"] = out["forces"].detach().cpu().numpy() 181 | 182 | 183 | class FrenkelLaddCalculator(Calculator): 184 | """ 185 | Frenkel-Ladd calculator for ASE. 186 | """ 187 | 188 | def __init__( 189 | self, 190 | spring_constants: np.ndarray, 191 | initial_positions: np.ndarray, 192 | device: str, 193 | model: str = "medium", 194 | ): 195 | """ 196 | Initialize the Frenkel-Ladd calculator. 197 | 198 | Args: 199 | spring_constants (np.ndarray): Spring constants for each atom. 200 | initial_positions (np.ndarray): Initial positions of the atoms. 201 | device (str): Device to run the calculations on. 202 | model (str): Model to use for the MACE calculator. 
203 | """ 204 | Calculator.__init__(self) 205 | self.results = {} 206 | self.implemented_properties = ["energy", "free_energy", "forces"] 207 | self.device = device 208 | self.model = mace_mp( 209 | model=model, device=device, default_dtype="float32" 210 | ).models[0] 211 | self.z_table, self.r_max = get_z_table_and_r_max(self.model) 212 | for param in self.model.parameters(): 213 | param.requires_grad = False 214 | 215 | # Spring constants 216 | self.spring_constants = spring_constants 217 | self.initial_positions = initial_positions 218 | 219 | # Reversible scaling factor 220 | self.weights = [1.0, 0.0] 221 | self.compute_mace = True 222 | 223 | def set_weights(self, lambda_value: float): 224 | self.weights = [1.0 - lambda_value, lambda_value] 225 | 226 | # pylint: disable=dangerous-default-value 227 | def calculate(self, atoms=None, properties=None, system_changes=all_changes): 228 | # call to base-class to set atoms attribute 229 | Calculator.calculate(self, atoms) 230 | 231 | # Get MACE results if needed 232 | if self.compute_mace: 233 | config = data.config_from_atoms(atoms) 234 | atomic_data = data.AtomicData.from_config( 235 | config, z_table=self.z_table, cutoff=self.r_max 236 | ) 237 | data_loader = torch_geometric.dataloader.DataLoader( 238 | dataset=[atomic_data], 239 | batch_size=1, 240 | shuffle=False, 241 | drop_last=False, 242 | ) 243 | batch = next(iter(data_loader)).to(self.device) 244 | out = self.model(batch, compute_stress=False) # Frenkel-Ladd is NVT 245 | mace_energy = out["energy"][0].item() 246 | mace_forces = out["forces"].detach().cpu().numpy() 247 | else: 248 | mace_energy = 0.0 249 | mace_forces = np.zeros((len(atoms), 3), dtype=np.float32) 250 | 251 | # Get spring energy and forces 252 | displacement = atoms.get_positions() - self.initial_positions 253 | spring_energy = 0.5 * np.sum( 254 | self.spring_constants * np.sum(displacement**2, axis=1) 255 | ) 256 | spring_forces = -self.spring_constants[:, None] * displacement 257 | 258 | # 
Combine energies and forces 259 | total_energy = self.weights[0] * spring_energy + self.weights[1] * mace_energy 260 | total_forces = self.weights[0] * spring_forces + self.weights[1] * mace_forces 261 | if self.compute_mace: 262 | energy_diff = mace_energy - spring_energy 263 | else: 264 | energy_diff = 0.0 265 | 266 | self.results = {} 267 | self.results["energy"] = total_energy 268 | self.results["free_energy"] = total_energy 269 | self.results["forces"] = total_forces 270 | self.results["energy_diff"] = energy_diff 271 | 272 | 273 | class DefectFrenkelLaddCalculator(Calculator): 274 | """ 275 | Frenkel-Ladd calculator for ASE, for a crystal with a defect. 276 | """ 277 | 278 | def __init__( 279 | self, 280 | atoms: ase.Atoms, 281 | spring_constant: float, 282 | defect_index: int, 283 | device: str = "cpu", 284 | model: str = "medium", 285 | ): 286 | """ 287 | Initialize the Frenkel-Ladd calculator. 288 | 289 | Args: 290 | atoms (ase.Atoms): Atoms object. 291 | spring_constant (float): Spring constant for the defect atom. 292 | defect_index (int): Index of the defect atom. 293 | device (str): Device to run the calculations on. 294 | model (str): Model to use for the MACE calculator. 
295 | """ 296 | Calculator.__init__(self) 297 | self.results = {} 298 | self.implemented_properties = ["energy", "free_energy", "forces", "stress"] 299 | 300 | # Build the alchemical MACE model 301 | self.device = device 302 | self.model = alchemical_mace_mp( 303 | model=model, device=device, default_dtype="float32" 304 | ) 305 | for param in self.model.parameters(): 306 | param.requires_grad = False 307 | 308 | # Set AlchemyManager 309 | z_table, r_max = get_z_table_and_r_max(self.model) 310 | alchemical_weights = torch.tensor([1.0], dtype=torch.float32) 311 | atomic_number = atoms.get_atomic_numbers()[defect_index] 312 | alchemical_pairs = [[AlchemicalPair(defect_index, atomic_number)]] 313 | self.alchemy_manager = AlchemyManager( 314 | atoms=atoms, 315 | alchemical_pairs=alchemical_pairs, 316 | alchemical_weights=alchemical_weights, 317 | z_table=z_table, 318 | r_max=r_max, 319 | ).to(self.device) 320 | 321 | # Disable alchemical weights gradients by default 322 | self.alchemy_manager.alchemical_weights.requires_grad = False 323 | self.calculate_alchemical_grad = False 324 | 325 | self.num_atoms = len(atoms) 326 | 327 | # Switching 328 | self.defect_index = defect_index 329 | self.spring_constant = spring_constant 330 | 331 | def set_alchemical_weight(self, alchemical_weight: float): 332 | # Set alchemical weights 333 | alchemical_weights = torch.tensor( 334 | [1.0 - alchemical_weight], # initial = original atoms = 1 - 0 335 | dtype=torch.float32, 336 | device=self.device, 337 | ) 338 | self.alchemy_manager.alchemical_weights.data = alchemical_weights 339 | 340 | # pylint: disable=dangerous-default-value 341 | def calculate(self, atoms=None, properties=None, system_changes=all_changes): 342 | # call to base-class to set atoms attribute 343 | Calculator.calculate(self, atoms) 344 | 345 | # prepare data 346 | tensor_kwargs = {"dtype": torch.float32, "device": self.device} 347 | positions = torch.tensor(atoms.get_positions(), **tensor_kwargs) 348 | cell = 
torch.tensor(atoms.get_cell().array, **tensor_kwargs) 349 | if self.calculate_alchemical_grad: 350 | self.alchemy_manager.alchemical_weights.requires_grad = True 351 | batch = self.alchemy_manager(positions, cell).to(self.device) 352 | 353 | # get outputs 354 | if self.calculate_alchemical_grad: 355 | out = self.model(batch, retain_graph=True, compute_stress=True) 356 | out["energy"].backward() 357 | grad = self.alchemy_manager.alchemical_weights.grad.item() 358 | self.alchemy_manager.alchemical_weights.grad.zero_() 359 | self.alchemy_manager.alchemical_weights.requires_grad = False 360 | else: 361 | out = self.model(batch, retain_graph=False, compute_stress=True) 362 | grad = 0.0 363 | mace_energy = out["energy"][0].item() 364 | mace_forces = out["forces"].detach().cpu().numpy() 365 | mace_stress = out["stress"][0].detach().cpu().numpy() 366 | 367 | # Get spring energy and forces 368 | cell_center = np.array([0.5, 0.5, 0.5]) @ atoms.get_cell().array 369 | displacement = atoms.get_positions()[self.defect_index] - cell_center 370 | spring_energy = 0.5 * self.spring_constant * np.sum(displacement**2) 371 | spring_forces = -self.spring_constant * displacement 372 | 373 | # Combine energies and forces 374 | # Note: weight here is 1 - lambda, and we're not weighting the mace 375 | # energy because it's already weighted by the alchemical weight 376 | weight = self.alchemy_manager.alchemical_weights.item() 377 | total_energy = mace_energy + (1 - weight) * spring_energy 378 | total_forces = mace_forces 379 | total_forces[self.defect_index] += (1 - weight) * spring_forces 380 | if self.calculate_alchemical_grad: 381 | # H(lambda) = E(1 - lambda) + lambda * spring_energy 382 | # dH/d(lambda) = -dE/d(1 - lambda) + spring_energy 383 | grad = -grad + spring_energy 384 | 385 | # store results 386 | self.results = {} 387 | self.results["energy"] = total_energy 388 | self.results["free_energy"] = total_energy 389 | self.results["forces"] = total_forces 390 | self.results["stress"] 
= full_3x3_to_voigt_6_stress(mace_stress) 391 | self.results["alchemical_grad"] = grad 392 | 393 | 394 | def get_alchemical_optimized_cellpar( 395 | atoms: ase.Atoms, 396 | alchemical_pairs: Sequence[Sequence[Tuple[int, int]]], 397 | alchemical_weights: Sequence[float], 398 | model: str = "medium", 399 | device: str = "cpu", 400 | **kwargs, 401 | ): 402 | """ 403 | Optimize the cell parameters of a crystal with alchemical atoms using the 404 | Alchemical MACE calculator. 405 | 406 | Args: 407 | atoms (ase.Atoms): Atoms object. 408 | alchemical_pairs (Sequence[Sequence[Tuple[int, int]]]): List of 409 | alchemical pairs. Each pair is a tuple of the atom index and 410 | atomic number of an alchemical atom. 411 | alchemical_weights (Sequence[float]): List of alchemical weights. 412 | model (str): Model to use for the MACE calculator. 413 | device (str): Device to run the calculations on. 414 | 415 | Returns: 416 | np.ndarray: Optimized cell parameters. 417 | """ 418 | # Make a copy of the atoms object 419 | atoms = atoms.copy() 420 | 421 | # Load Alchemical MACE calculator and relax the structure 422 | calc = AlchemicalMACECalculator( 423 | atoms, alchemical_pairs, alchemical_weights, device=device, model=model 424 | ) 425 | atoms.set_calculator(calc) 426 | atoms.set_masses(calc.get_alchemical_atomic_masses()) 427 | atoms = ExpCellFilter(atoms) 428 | optimizer = FIRE(atoms) 429 | optimizer.run(fmax=kwargs.get("fmax", 0.01), steps=kwargs.get("steps", 500)) 430 | 431 | # Return the optimized cell parameters 432 | return atoms.atoms.get_cell().cellpar() 433 | -------------------------------------------------------------------------------- /alchemical_mace/model.py: -------------------------------------------------------------------------------- 1 | import ast 2 | from collections import namedtuple 3 | from typing import Dict, List, Optional, Sequence, Tuple 4 | 5 | import ase 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | from e3nn import 
o3 10 | from e3nn.util.jit import compile_mode 11 | from mace import modules, tools 12 | from mace.calculators import mace_mp 13 | from mace.data.neighborhood import get_neighborhood 14 | from mace.modules import RealAgnosticResidualInteractionBlock, ScaleShiftMACE 15 | from mace.modules.utils import get_edge_vectors_and_lengths, get_symmetric_displacement 16 | from mace.tools import ( 17 | AtomicNumberTable, 18 | atomic_numbers_to_indices, 19 | to_one_hot, 20 | torch_geometric, 21 | utils, 22 | ) 23 | from mace.tools.scatter import scatter_sum 24 | 25 | ################################################################################ 26 | # Alchemy manager class for handling alchemical weights 27 | ################################################################################ 28 | 29 | AlchemicalPair = namedtuple("AlchemicalPair", ["atom_index", "atomic_number"]) 30 | 31 | 32 | class AlchemyManager(torch.nn.Module): 33 | """ 34 | Class for managing alchemical weights and building alchemical graphs for MACE. 35 | """ 36 | 37 | def __init__( 38 | self, 39 | atoms: ase.Atoms, 40 | alchemical_pairs: Sequence[Sequence[Tuple[int, int]]], 41 | alchemical_weights: torch.Tensor, 42 | z_table: AtomicNumberTable, 43 | r_max: float, 44 | ): 45 | """ 46 | Initialize the alchemy manager. 
47 | 48 | Args: 49 | atoms: ASE atoms object 50 | alchemical_pairs: List of lists of tuples, where each tuple contains 51 | the atom index and atomic number of an alchemical atom 52 | alchemical_weights: Tensor of alchemical weights 53 | z_table: Atomic number table 54 | r_max: Maximum cutoff radius for the alchemical graph 55 | """ 56 | super().__init__() 57 | self.alchemical_weights = torch.nn.Parameter(alchemical_weights) 58 | self.r_max = r_max 59 | 60 | # Process alchemical pairs into atom indices and atomic numbers 61 | # Alchemical weights are 1-indexed, 0 is reserved for non-alchemical atoms 62 | alchemical_atom_indices = [] 63 | alchemical_atomic_numbers = [] 64 | alchemical_weight_indices = [] 65 | 66 | for weight_idx, pairs in enumerate(alchemical_pairs): 67 | for pair in pairs: 68 | alchemical_atom_indices.append(pair.atom_index) 69 | alchemical_atomic_numbers.append(pair.atomic_number) 70 | alchemical_weight_indices.append(weight_idx + 1) 71 | 72 | non_alchemical_atom_indices = [ 73 | i for i in range(len(atoms)) if i not in alchemical_atom_indices 74 | ] 75 | non_alchemical_atomic_numbers = atoms.get_atomic_numbers()[ 76 | non_alchemical_atom_indices 77 | ].tolist() 78 | non_alchemical_weight_indices = [0] * len(non_alchemical_atom_indices) 79 | 80 | self.atom_indices = alchemical_atom_indices + non_alchemical_atom_indices 81 | self.atomic_numbers = alchemical_atomic_numbers + non_alchemical_atomic_numbers 82 | self.weight_indices = alchemical_weight_indices + non_alchemical_weight_indices 83 | 84 | self.atom_indices = np.array(self.atom_indices) 85 | self.atomic_numbers = np.array(self.atomic_numbers) 86 | self.weight_indices = np.array(self.weight_indices) 87 | 88 | sort_idx = np.argsort(self.atom_indices) 89 | self.atom_indices = self.atom_indices[sort_idx] 90 | self.atomic_numbers = self.atomic_numbers[sort_idx] 91 | self.weight_indices = self.weight_indices[sort_idx] 92 | 93 | # Array to map original atom indices to alchemical indices 94 | # -1 
means the atom does not have a corresponding alchemical atom 95 | # [n_atoms, n_weights + 1] 96 | self.original_to_alchemical_index = np.full( 97 | (len(atoms), len(alchemical_pairs) + 1), -1 98 | ) 99 | for i, (atom_idx, weight_idx) in enumerate( 100 | zip(self.atom_indices, self.weight_indices) 101 | ): 102 | self.original_to_alchemical_index[atom_idx, weight_idx] = i 103 | 104 | self.is_original_atom_alchemical = np.any( 105 | self.original_to_alchemical_index[:, 1:] != -1, axis=1 106 | ) 107 | 108 | # Extract common node features 109 | z_indices = atomic_numbers_to_indices(self.atomic_numbers, z_table=z_table) 110 | node_attrs = to_one_hot( 111 | torch.tensor(z_indices, dtype=torch.long).unsqueeze(-1), 112 | num_classes=len(z_table), 113 | ) 114 | self.register_buffer("node_attrs", node_attrs) 115 | self.pbc = atoms.get_pbc() 116 | 117 | def forward( 118 | self, 119 | positions: torch.Tensor, 120 | cell: torch.Tensor, 121 | ) -> Dict[str, torch.Tensor]: 122 | """ 123 | Build an alchemical graph for the given positions and cell. 
124 | 125 | Args: 126 | positions: Tensor of atomic positions 127 | cell: Tensor of cell vectors 128 | 129 | Returns: 130 | Dictionary containing the alchemical graph data 131 | """ 132 | 133 | # Build original atom graph 134 | orig_edge_index, shifts, unit_shifts = get_neighborhood( 135 | positions=positions.detach().cpu().numpy(), 136 | cutoff=self.r_max, 137 | pbc=self.pbc, 138 | cell=cell.detach().cpu().numpy(), 139 | ) 140 | 141 | # Extend edges to alchemical pairs 142 | edge_index = [] 143 | orig_edge_loc = [] 144 | edge_weight_indices = [] 145 | 146 | is_alchemical = self.is_original_atom_alchemical[orig_edge_index] 147 | src_non_dst_non = ~is_alchemical[0] & ~is_alchemical[1] 148 | src_non_dst_alch = ~is_alchemical[0] & is_alchemical[1] 149 | src_alch_dst_non = is_alchemical[0] & ~is_alchemical[1] 150 | src_alch_dst_alch = is_alchemical[0] & is_alchemical[1] 151 | 152 | # Both non-alchemical: keep as is 153 | _orig_edge_index = orig_edge_index[:, src_non_dst_non] 154 | edge_index.append(self.original_to_alchemical_index[_orig_edge_index, 0]) 155 | orig_edge_loc.append(np.where(src_non_dst_non)[0]) 156 | edge_weight_indices.append(np.zeros_like(_orig_edge_index[0])) 157 | 158 | # Source non-alchemical, destination alchemical: pair all, weights are 1 159 | _src, _dst = orig_edge_index[:, src_non_dst_alch] 160 | _orig_edge_loc = np.where(src_non_dst_alch)[0] 161 | _src = self.original_to_alchemical_index[_src, 0] 162 | _dst = self.original_to_alchemical_index[_dst, :] 163 | _dst_mask = _dst != -1 164 | _dst = _dst[_dst_mask] 165 | _repeat = _dst_mask.sum(axis=1) 166 | _src = np.repeat(_src, _repeat) 167 | edge_index.append(np.stack((_src, _dst), axis=0)) 168 | orig_edge_loc.append(np.repeat(_orig_edge_loc, _repeat)) 169 | edge_weight_indices.append(np.zeros_like(_src)) 170 | 171 | # Source alchemical, destination non-alchemical: pair all, follow src weights 172 | _src, _dst = orig_edge_index[:, src_alch_dst_non] 173 | _orig_edge_loc = 
np.where(src_alch_dst_non)[0] 174 | _src = self.original_to_alchemical_index[_src, :] 175 | _dst = self.original_to_alchemical_index[_dst, 0] 176 | _src_mask = _src != -1 177 | _src = _src[_src_mask] 178 | _repeat = _src_mask.sum(axis=1) 179 | _dst = np.repeat(_dst, _repeat) 180 | edge_index.append(np.stack((_src, _dst), axis=0)) 181 | orig_edge_loc.append(np.repeat(_orig_edge_loc, _repeat)) 182 | edge_weight_indices.append(np.where(_src_mask)[1]) 183 | 184 | # Both alchemical: pair according to alchemical indices, weights are 1 185 | _orig_edge_index = orig_edge_index[:, src_alch_dst_alch] 186 | _orig_edge_loc = np.where(src_alch_dst_alch)[0] 187 | _alch_edge_index = self.original_to_alchemical_index[_orig_edge_index, :] 188 | _idx = np.where((_alch_edge_index != -1).all(axis=0)) 189 | edge_index.append(_alch_edge_index[:, _idx[0], _idx[1]]) 190 | orig_edge_loc.append(_orig_edge_loc[_idx[0]]) 191 | edge_weight_indices.append(np.zeros_like(_idx[0])) 192 | 193 | # Collect all edges 194 | edge_index = np.concatenate(edge_index, axis=1) 195 | orig_edge_loc = np.concatenate(orig_edge_loc) 196 | edge_weight_indices = np.concatenate(edge_weight_indices) 197 | 198 | # Convert to torch tensors 199 | edge_index = torch.tensor(edge_index, dtype=torch.long) 200 | shifts = torch.tensor(shifts[orig_edge_loc], dtype=torch.float32) 201 | unit_shifts = torch.tensor(unit_shifts[orig_edge_loc], dtype=torch.float32) 202 | 203 | # Alchemical weights for nodes and edges 204 | weights = F.pad(self.alchemical_weights, (1, 0), "constant", 1.0) 205 | node_weights = weights[self.weight_indices] 206 | edge_weights = weights[edge_weight_indices] 207 | 208 | # Build data batch 209 | atomic_data = torch_geometric.data.Data( 210 | num_nodes=len(self.atom_indices), 211 | edge_index=edge_index, 212 | node_attrs=self.node_attrs, 213 | positions=positions[self.atom_indices], 214 | shifts=shifts, 215 | unit_shifts=unit_shifts, 216 | cell=cell, 217 | node_weights=node_weights, 218 | 
edge_weights=edge_weights, 219 | node_atom_indices=torch.tensor(self.atom_indices, dtype=torch.long), 220 | ) 221 | data_loader = torch_geometric.dataloader.DataLoader( 222 | dataset=[atomic_data], 223 | batch_size=1, 224 | shuffle=False, 225 | drop_last=False, 226 | ) 227 | batch = next(iter(data_loader)) 228 | 229 | return batch 230 | 231 | 232 | ################################################################################ 233 | # Alchemical MACE model 234 | ################################################################################ 235 | 236 | # get_outputs function from mace.modules.utils is modified to calculate also 237 | # the alchemical gradients 238 | 239 | 240 | def get_outputs( 241 | energy: torch.Tensor, 242 | positions: torch.Tensor, 243 | displacement: torch.Tensor, 244 | cell: torch.Tensor, 245 | node_weights: torch.Tensor, 246 | edge_weights: torch.Tensor, 247 | retain_graph: bool = False, 248 | create_graph: bool = False, 249 | compute_force: bool = True, 250 | compute_stress: bool = False, 251 | compute_alchemical_grad: bool = False, 252 | ) -> Tuple[ 253 | Optional[torch.Tensor], 254 | Optional[torch.Tensor], 255 | Optional[torch.Tensor], 256 | Optional[torch.Tensor], 257 | Optional[torch.Tensor], 258 | ]: 259 | grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(energy)] 260 | if not compute_force: 261 | return None, None, None, None, None 262 | inputs = [positions] 263 | if compute_stress: 264 | inputs.append(displacement) 265 | if compute_alchemical_grad: 266 | inputs.extend([node_weights, edge_weights]) 267 | gradients = torch.autograd.grad( 268 | outputs=[energy], 269 | inputs=inputs, 270 | grad_outputs=grad_outputs, 271 | retain_graph=retain_graph, 272 | create_graph=create_graph, 273 | allow_unused=True, 274 | ) 275 | 276 | forces = gradients[0] 277 | stress = torch.zeros_like(displacement) 278 | virials = gradients[1] if compute_stress else None 279 | if compute_alchemical_grad: 280 | node_grad, edge_grad = 
gradients[-2], gradients[-1] 281 | else: 282 | node_grad, edge_grad = None, None 283 | if compute_stress and virials is not None: 284 | cell = cell.view(-1, 3, 3) 285 | volume = torch.einsum( 286 | "zi,zi->z", 287 | cell[:, 0, :], 288 | torch.cross(cell[:, 1, :], cell[:, 2, :], dim=1), 289 | ).unsqueeze(-1) 290 | stress = virials / volume.view(-1, 1, 1) 291 | 292 | if forces is not None: 293 | forces = -1 * forces 294 | if virials is not None: 295 | virials = -1 * virials 296 | return forces, virials, stress, node_grad, edge_grad 297 | 298 | 299 | @compile_mode("script") 300 | class AlchemicalResidualInteractionBlock(RealAgnosticResidualInteractionBlock): 301 | def forward( 302 | self, 303 | node_attrs: torch.Tensor, 304 | node_feats: torch.Tensor, 305 | edge_attrs: torch.Tensor, 306 | edge_feats: torch.Tensor, 307 | edge_index: torch.Tensor, 308 | edge_weights: torch.Tensor, # alchemy 309 | ) -> Tuple[torch.Tensor, torch.Tensor]: 310 | sender = edge_index[0] 311 | receiver = edge_index[1] 312 | num_nodes = node_feats.shape[0] 313 | sc = self.skip_tp(node_feats, node_attrs) 314 | node_feats = self.linear_up(node_feats) 315 | tp_weights = self.conv_tp_weights(edge_feats) 316 | tp_weights = tp_weights * edge_weights[:, None] # alchemy 317 | mji = self.conv_tp( 318 | node_feats[sender], edge_attrs, tp_weights 319 | ) # [n_edges, irreps] 320 | message = scatter_sum( 321 | src=mji, index=receiver, dim=0, dim_size=num_nodes 322 | ) # [n_nodes, irreps] 323 | message = self.linear(message) / self.avg_num_neighbors 324 | return ( 325 | self.reshape(message), 326 | sc, 327 | ) # [n_nodes, channels, (lmax + 1)**2] 328 | 329 | 330 | @compile_mode("script") 331 | class AlchemicalMACE(ScaleShiftMACE): 332 | def forward( 333 | self, 334 | data: Dict[str, torch.Tensor], 335 | retain_graph: bool = False, # alchemy 336 | create_graph: bool = False, # alchemy 337 | compute_force: bool = True, 338 | compute_stress: bool = False, 339 | compute_displacement: bool = False, 340 | 
compute_alchemical_grad: bool = False, # alchemy 341 | map_to_original_atoms: bool = True, # alchemy 342 | ) -> Dict[str, Optional[torch.Tensor]]: 343 | # Setup 344 | data["positions"].requires_grad_(True) 345 | data["node_attrs"].requires_grad_(True) 346 | num_graphs = data["ptr"].numel() - 1 347 | displacement = torch.zeros( 348 | (num_graphs, 3, 3), 349 | dtype=data["positions"].dtype, 350 | device=data["positions"].device, 351 | ) 352 | if compute_stress or compute_displacement: 353 | ( 354 | data["positions"], 355 | data["shifts"], 356 | displacement, 357 | ) = get_symmetric_displacement( 358 | positions=data["positions"], 359 | unit_shifts=data["unit_shifts"], 360 | cell=data["cell"], 361 | edge_index=data["edge_index"], 362 | num_graphs=num_graphs, 363 | batch=data["batch"], 364 | ) 365 | 366 | # Atomic energies 367 | node_e0 = self.atomic_energies_fn(data["node_attrs"]) 368 | node_e0 = node_e0 * data["node_weights"] # alchemy 369 | e0 = scatter_sum( 370 | src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs 371 | ) # [n_graphs,] 372 | 373 | # Embeddings 374 | node_feats = self.node_embedding(data["node_attrs"]) 375 | vectors, lengths = get_edge_vectors_and_lengths( 376 | positions=data["positions"], 377 | edge_index=data["edge_index"], 378 | shifts=data["shifts"], 379 | ) 380 | edge_attrs = self.spherical_harmonics(vectors) 381 | edge_feats = self.radial_embedding(lengths) 382 | 383 | # Interactions 384 | node_es_list = [] 385 | node_feats_list = [] 386 | for interaction, product, readout in zip( 387 | self.interactions, self.products, self.readouts 388 | ): 389 | node_feats, sc = interaction( 390 | node_attrs=data["node_attrs"], 391 | node_feats=node_feats, 392 | edge_attrs=edge_attrs, 393 | edge_feats=edge_feats, 394 | edge_index=data["edge_index"], 395 | edge_weights=data["edge_weights"], # alchemy 396 | ) 397 | node_feats = product( 398 | node_feats=node_feats, sc=sc, node_attrs=data["node_attrs"] 399 | ) 400 | 
node_feats_list.append(node_feats) 401 | node_es_list.append(readout(node_feats).squeeze(-1)) # {[n_nodes, ], } 402 | 403 | # Concatenate node features 404 | # node_feats_out = torch.cat(node_feats_list, dim=-1) 405 | 406 | # Sum over interactions 407 | node_inter_es = torch.sum( 408 | torch.stack(node_es_list, dim=0), dim=0 409 | ) # [n_nodes, ] 410 | node_inter_es = self.scale_shift(node_inter_es) 411 | node_inter_es = node_inter_es * data["node_weights"] # alchemy 412 | 413 | # Sum over nodes in graph 414 | inter_e = scatter_sum( 415 | src=node_inter_es, index=data["batch"], dim=-1, dim_size=num_graphs 416 | ) # [n_graphs,] 417 | 418 | # Add E_0 and (scaled) interaction energy 419 | total_energy = e0 + inter_e 420 | node_energy = node_e0 + node_inter_es 421 | 422 | forces, virials, stress, node_grad, edge_grad = get_outputs( 423 | energy=total_energy, # alchemy 424 | positions=data["positions"], 425 | displacement=displacement, 426 | cell=data["cell"], 427 | node_weights=data["node_weights"], # alchemy 428 | edge_weights=data["edge_weights"], # alchemy 429 | retain_graph=retain_graph, # alchemy 430 | create_graph=create_graph, # alchemy 431 | compute_force=compute_force, 432 | # compute_virials=compute_virials, # alchemy 433 | compute_stress=compute_stress, 434 | compute_alchemical_grad=compute_alchemical_grad, # alchemy 435 | ) 436 | 437 | # Map to original atoms (node energies and forces): alchemy 438 | if map_to_original_atoms: 439 | # Note: we're not giving the dim_size, as we assume that all 440 | # original atoms are present in the batch 441 | node_index = data["node_atom_indices"] 442 | node_energy = scatter_sum(src=node_energy, dim=0, index=node_index) 443 | if compute_force: 444 | forces = scatter_sum(src=forces, dim=0, index=node_index) 445 | 446 | output = { 447 | "energy": total_energy, 448 | "node_energy": node_energy, 449 | "interaction_energy": inter_e, 450 | "forces": forces, 451 | "virials": virials, 452 | "stress": stress, 453 | "displacement": 
displacement, 454 | "node_grad": node_grad, 455 | "edge_grad": edge_grad, 456 | } 457 | 458 | return output 459 | 460 | 461 | ################################################################################ 462 | # Alchemical MACE universal model 463 | ################################################################################ 464 | 465 | 466 | def alchemical_mace_mp( 467 | model: str, 468 | device: str, 469 | default_dtype: str = "float32", 470 | ): 471 | """ 472 | Load a pre-trained alchemical MACE model. 473 | 474 | Args: 475 | model: Model size (small, medium) 476 | device: Device to load the model onto 477 | default_dtype: Default data type for the model 478 | 479 | Returns: 480 | Alchemical MACE model 481 | """ 482 | 483 | # Load foundation MACE model and extract initial parameters 484 | assert model in ("small", "medium") # TODO: support large model 485 | mace = mace_mp(model=model, device=device, default_dtype=default_dtype).models[0] 486 | atomic_energies = mace.atomic_energies_fn.atomic_energies.detach().clone() 487 | z_table = utils.AtomicNumberTable([int(z) for z in mace.atomic_numbers]) 488 | atomic_inter_scale = mace.scale_shift.scale.detach().clone() 489 | atomic_inter_shift = mace.scale_shift.shift.detach().clone() 490 | 491 | # Prepare arguments for building the model 492 | placeholder_args = ["--name", "None", "--train_file", "None"] 493 | args = tools.build_default_arg_parser().parse_args(placeholder_args) 494 | args.max_L = {"small": 0, "medium": 1, "large": 2}[model] 495 | args.num_channels = 128 496 | args.hidden_irreps = o3.Irreps( 497 | (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L)) 498 | .sort() 499 | .irreps.simplify() 500 | ) 501 | 502 | # Build the alchemical MACE model 503 | model = AlchemicalMACE( 504 | r_max=6.0, 505 | num_bessel=10, 506 | num_polynomial_cutoff=5, 507 | max_ell=3, 508 | interaction_cls=AlchemicalResidualInteractionBlock, 509 | interaction_cls_first=AlchemicalResidualInteractionBlock, 510 | 
num_interactions=2, 511 | num_elements=len(z_table), 512 | hidden_irreps=o3.Irreps(args.hidden_irreps), 513 | MLP_irreps=o3.Irreps(args.MLP_irreps), 514 | atomic_energies=atomic_energies, 515 | avg_num_neighbors=args.avg_num_neighbors, 516 | atomic_numbers=z_table.zs, 517 | correlation=args.correlation, 518 | gate=modules.gate_dict[args.gate], 519 | radial_MLP=ast.literal_eval(args.radial_MLP), 520 | radial_type=args.radial_type, 521 | atomic_inter_scale=atomic_inter_scale, 522 | atomic_inter_shift=atomic_inter_shift, 523 | ) 524 | 525 | # Load foundation model parameters 526 | model.load_state_dict(mace.state_dict(), strict=True) 527 | for i in range(int(model.num_interactions)): 528 | model.interactions[i].avg_num_neighbors = mace.interactions[i].avg_num_neighbors 529 | model = model.to(device) 530 | return model 531 | 532 | 533 | def get_z_table_and_r_max(model: ScaleShiftMACE) -> Tuple[AtomicNumberTable, float]: 534 | """Extract the atomic number table and maximum cutoff radius from a MACE model.""" 535 | z_table = AtomicNumberTable([int(z) for z in model.atomic_numbers]) 536 | r_max = model.r_max.item() 537 | return z_table, r_max 538 | -------------------------------------------------------------------------------- /alchemical_mace/optimize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Union, Iterable, Dict, Any 3 | 4 | 5 | class ExponentiatedGradientDescent(torch.optim.Optimizer): 6 | """ 7 | Implements Exponentiated Gradient Descent. 8 | 9 | Args: 10 | params (iterable of torch.Tensor or dict): iterable of parameters to optimize or 11 | dicts defining parameter groups. 12 | lr (float, optional): learning rate. Defaults to 1e-3. 13 | eps (float, optional): small constant for numerical stability. Defaults to 1e-8. 
14 | """ 15 | 16 | def __init__( 17 | self, 18 | params: Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]], 19 | lr: float = 1e-3, 20 | eps: float = 1e-8, 21 | ): 22 | super().__init__(params, defaults=dict(lr=lr, eps=eps)) 23 | 24 | def step(self): 25 | for group in self.param_groups: 26 | for p in group["params"]: 27 | if p.grad is None: 28 | continue 29 | p.data.mul_(torch.exp(-group["lr"] * p.grad)) 30 | p.data.div_(p.data.sum() + group["eps"]) 31 | -------------------------------------------------------------------------------- /alchemical_mace/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from contextlib import ExitStack, contextmanager, redirect_stderr, redirect_stdout 3 | 4 | from ase import Atoms 5 | 6 | 7 | @contextmanager 8 | def suppress_print(out: bool = True, err: bool = False): 9 | """Suppress stdout and/or stderr.""" 10 | 11 | with ExitStack() as stack: 12 | devnull = stack.enter_context(open(os.devnull, "w")) 13 | if out: 14 | stack.enter_context(redirect_stdout(devnull)) 15 | if err: 16 | stack.enter_context(redirect_stderr(devnull)) 17 | yield 18 | 19 | 20 | # From CHGNet 21 | def upper_triangular_cell(atoms: Atoms): 22 | """Transform to upper-triangular cell.""" 23 | import numpy as np 24 | from ase.md.npt import NPT 25 | 26 | if NPT._isuppertriangular(atoms.get_cell()): 27 | return 28 | 29 | a, b, c, alpha, beta, gamma = atoms.cell.cellpar() 30 | angles = np.radians((alpha, beta, gamma)) 31 | sin_a, sin_b, _sin_g = np.sin(angles) 32 | cos_a, cos_b, cos_g = np.cos(angles) 33 | cos_p = (cos_g - cos_a * cos_b) / (sin_a * sin_b) 34 | cos_p = np.clip(cos_p, -1, 1) 35 | sin_p = (1 - cos_p**2) ** 0.5 36 | 37 | new_basis = [ 38 | (a * sin_b * sin_p, a * sin_b * cos_p, a * cos_b), 39 | (0, b * sin_a, b * cos_a), 40 | (0, 0, c), 41 | ] 42 | 43 | atoms.set_cell(new_basis, scale_atoms=True) 44 | -------------------------------------------------------------------------------- 
/data/structures/AlN_hex.cif: -------------------------------------------------------------------------------- 1 | # generated using pymatgen 2 | data_AlN 3 | _symmetry_space_group_name_H-M 'P 1' 4 | _cell_length_a 3.12664153 5 | _cell_length_b 3.12664143 6 | _cell_length_c 5.00715332 7 | _cell_angle_alpha 90.00000043 8 | _cell_angle_beta 89.99999986 9 | _cell_angle_gamma 119.99999965 10 | _symmetry_Int_Tables_number 1 11 | _chemical_formula_structural AlN 12 | _chemical_formula_sum 'Al2 N2' 13 | _cell_volume 42.39139351 14 | _cell_formula_units_Z 2 15 | loop_ 16 | _symmetry_equiv_pos_site_id 17 | _symmetry_equiv_pos_as_xyz 18 | 1 'x, y, z' 19 | loop_ 20 | _atom_site_type_symbol 21 | _atom_site_label 22 | _atom_site_symmetry_multiplicity 23 | _atom_site_fract_x 24 | _atom_site_fract_y 25 | _atom_site_fract_z 26 | _atom_site_occupancy 27 | Al Al0 1 0.66666667 0.33333333 0.49932157 1 28 | Al Al1 1 0.33333333 0.66666667 0.99932156 1 29 | N N2 1 0.66666667 0.33333333 0.88067843 1 30 | N N3 1 0.33333333 0.66666667 0.38067844 1 31 | -------------------------------------------------------------------------------- /data/structures/BiSBr.cif: -------------------------------------------------------------------------------- 1 | # generated using pymatgen 2 | data_BiSBr 3 | _symmetry_space_group_name_H-M 'P 1' 4 | _cell_length_a 4.10679210 5 | _cell_length_b 8.22484633 6 | _cell_length_c 11.05701800 7 | _cell_angle_alpha 89.99999573 8 | _cell_angle_beta 90.00000347 9 | _cell_angle_gamma 90.00001971 10 | _symmetry_Int_Tables_number 1 11 | _chemical_formula_structural BiSBr 12 | _chemical_formula_sum 'Bi4 S4 Br4' 13 | _cell_volume 373.48101173 14 | _cell_formula_units_Z 4 15 | loop_ 16 | _symmetry_equiv_pos_site_id 17 | _symmetry_equiv_pos_as_xyz 18 | 1 'x, y, z' 19 | loop_ 20 | _atom_site_type_symbol 21 | _atom_site_label 22 | _atom_site_symmetry_multiplicity 23 | _atom_site_fract_x 24 | _atom_site_fract_y 25 | _atom_site_fract_z 26 | _atom_site_occupancy 27 | Bi Bi0 1 
0.24999992 0.88556832 0.87177467 1 28 | Bi Bi1 1 0.75000010 0.11443170 0.12822534 1 29 | Bi Bi2 1 0.24999988 0.38556830 0.62822532 1 30 | Bi Bi3 1 0.75000009 0.61443169 0.37177466 1 31 | S S4 1 0.74999990 0.82391374 0.03167750 1 32 | S S5 1 0.25000008 0.67608625 0.53167751 1 33 | S S6 1 0.74999991 0.32391375 0.46832249 1 34 | S S7 1 0.25000010 0.17608625 0.96832251 1 35 | Br Br8 1 0.24999990 0.96479764 0.30006197 1 36 | Br Br9 1 0.75000009 0.53520235 0.80006197 1 37 | Br Br10 1 0.24999991 0.46479765 0.19993803 1 38 | Br Br11 1 0.75000013 0.03520235 0.69993804 1 39 | -------------------------------------------------------------------------------- /data/structures/CeO2.cif: -------------------------------------------------------------------------------- 1 | # generated using pymatgen 2 | data_CeO2 3 | _symmetry_space_group_name_H-M 'P 1' 4 | _cell_length_a 5.46789061 5 | _cell_length_b 5.46789009 6 | _cell_length_c 5.46788984 7 | _cell_angle_alpha 90.00000180 8 | _cell_angle_beta 90.00000282 9 | _cell_angle_gamma 90.00000072 10 | _symmetry_Int_Tables_number 1 11 | _chemical_formula_structural CeO2 12 | _chemical_formula_sum 'Ce4 O8' 13 | _cell_volume 163.47801292 14 | _cell_formula_units_Z 4 15 | loop_ 16 | _symmetry_equiv_pos_site_id 17 | _symmetry_equiv_pos_as_xyz 18 | 1 'x, y, z' 19 | loop_ 20 | _atom_site_type_symbol 21 | _atom_site_label 22 | _atom_site_symmetry_multiplicity 23 | _atom_site_fract_x 24 | _atom_site_fract_y 25 | _atom_site_fract_z 26 | _atom_site_occupancy 27 | Ce Ce0 1 0.00000000 -0.00000000 -0.00000000 1 28 | Ce Ce1 1 0.00000000 0.50000000 0.50000000 1 29 | Ce Ce2 1 0.50000000 0.00000000 0.50000000 1 30 | Ce Ce3 1 0.50000000 0.50000000 0.00000000 1 31 | O O4 1 0.25000000 0.75000000 0.74999999 1 32 | O O5 1 0.74999999 0.25000000 0.25000000 1 33 | O O6 1 0.25000000 0.25000000 0.25000000 1 34 | O O7 1 0.75000000 0.75000000 0.75000000 1 35 | O O8 1 0.74999999 0.75000000 0.25000000 1 36 | O O9 1 0.25000000 0.25000000 0.74999999 1 37 | O O10 1 
0.75000000 0.25000000 0.75000000 1 38 | O O11 1 0.25000000 0.75000000 0.25000000 1 39 | -------------------------------------------------------------------------------- /data/structures/CsPbI3_alpha.cif: -------------------------------------------------------------------------------- 1 | # generated using pymatgen 2 | data_CsPbI3 3 | _symmetry_space_group_name_H-M 'P 1' 4 | _cell_length_a 6.37904246 5 | _cell_length_b 6.37904251 6 | _cell_length_c 6.37904247 7 | _cell_angle_alpha 90.00000013 8 | _cell_angle_beta 89.99999979 9 | _cell_angle_gamma 89.99999953 10 | _symmetry_Int_Tables_number 1 11 | _chemical_formula_structural CsPbI3 12 | _chemical_formula_sum 'Cs1 Pb1 I3' 13 | _cell_volume 259.57716340 14 | _cell_formula_units_Z 1 15 | loop_ 16 | _symmetry_equiv_pos_site_id 17 | _symmetry_equiv_pos_as_xyz 18 | 1 'x, y, z' 19 | loop_ 20 | _atom_site_type_symbol 21 | _atom_site_label 22 | _atom_site_symmetry_multiplicity 23 | _atom_site_fract_x 24 | _atom_site_fract_y 25 | _atom_site_fract_z 26 | _atom_site_occupancy 27 | Cs Cs0 1 0.00000000 0.00000000 -0.00000000 1 28 | I I1 1 0.50000000 0.50000000 -0.00000000 1 29 | I I2 1 0.50000000 -0.00000000 0.50000000 1 30 | I I3 1 0.00000000 0.50000000 0.50000000 1 31 | Pb Pb4 1 0.50000000 0.50000000 0.50000000 1 32 | -------------------------------------------------------------------------------- /data/structures/CsPbI3_delta.cif: -------------------------------------------------------------------------------- 1 | # generated using pymatgen 2 | data_CsPbI3 3 | _symmetry_space_group_name_H-M 'P 1' 4 | _cell_length_a 4.91187765 5 | _cell_length_b 10.78602159 6 | _cell_length_c 18.12941751 7 | _cell_angle_alpha 90.00000734 8 | _cell_angle_beta 89.99999876 9 | _cell_angle_gamma 89.99999900 10 | _symmetry_Int_Tables_number 1 11 | _chemical_formula_structural CsPbI3 12 | _chemical_formula_sum 'Cs4 Pb4 I12' 13 | _cell_volume 960.48962177 14 | _cell_formula_units_Z 4 15 | loop_ 16 | _symmetry_equiv_pos_site_id 17 | 
_symmetry_equiv_pos_as_xyz 18 | 1 'x, y, z' 19 | loop_ 20 | _atom_site_type_symbol 21 | _atom_site_label 22 | _atom_site_symmetry_multiplicity 23 | _atom_site_fract_x 24 | _atom_site_fract_y 25 | _atom_site_fract_z 26 | _atom_site_occupancy 27 | Cs Cs0 1 0.75000002 0.57436836 0.17329066 1 28 | Cs Cs1 1 0.25000000 0.42563165 0.82670935 1 29 | Cs Cs2 1 0.75000002 0.07436835 0.32670935 1 30 | Cs Cs3 1 0.24999999 0.92563166 0.67329065 1 31 | I I4 1 0.74999999 0.20706041 0.71324789 1 32 | I I5 1 0.25000000 0.79293960 0.28675211 1 33 | I I6 1 0.75000001 0.97331566 0.10968932 1 34 | I I7 1 0.24999998 0.02668433 0.89031069 1 35 | I I8 1 0.74999999 0.47331567 0.39031069 1 36 | I I9 1 0.25000001 0.52668433 0.60968931 1 37 | I I10 1 0.25000000 0.66460269 0.00429655 1 38 | I I11 1 0.25000000 0.16460269 0.49570344 1 39 | I I12 1 0.25000000 0.29293960 0.21324788 1 40 | I I13 1 0.75000001 0.83539729 0.50429655 1 41 | I I14 1 0.75000000 0.33539731 0.99570345 1 42 | I I15 1 0.75000001 0.70706040 0.78675212 1 43 | Pb Pb16 1 0.74999998 0.83743203 0.94239546 1 44 | Pb Pb17 1 0.24999999 0.16256797 0.05760453 1 45 | Pb Pb18 1 0.75000002 0.33743205 0.55760454 1 46 | Pb Pb19 1 0.24999998 0.66256796 0.44239545 1 47 | -------------------------------------------------------------------------------- /data/structures/CsSnI3_alpha.cif: -------------------------------------------------------------------------------- 1 | # generated using pymatgen 2 | data_CsSnI3 3 | _symmetry_space_group_name_H-M 'P 1' 4 | _cell_length_a 6.27081715 5 | _cell_length_b 6.27081724 6 | _cell_length_c 6.27081711 7 | _cell_angle_alpha 89.99999964 8 | _cell_angle_beta 89.99999969 9 | _cell_angle_gamma 89.99999972 10 | _symmetry_Int_Tables_number 1 11 | _chemical_formula_structural CsSnI3 12 | _chemical_formula_sum 'Cs1 Sn1 I3' 13 | _cell_volume 246.58827085 14 | _cell_formula_units_Z 1 15 | loop_ 16 | _symmetry_equiv_pos_site_id 17 | _symmetry_equiv_pos_as_xyz 18 | 1 'x, y, z' 19 | loop_ 20 | _atom_site_type_symbol 21 
| _atom_site_label 22 | _atom_site_symmetry_multiplicity 23 | _atom_site_fract_x 24 | _atom_site_fract_y 25 | _atom_site_fract_z 26 | _atom_site_occupancy 27 | Cs Cs0 1 0.00000000 0.00000000 -0.00000000 1 28 | I I1 1 0.50000000 0.50000000 0.00000000 1 29 | I I2 1 0.50000000 -0.00000000 0.50000000 1 30 | I I3 1 -0.00000000 0.50000000 0.50000000 1 31 | Sn Sn4 1 0.50000000 0.50000000 0.50000000 1 32 | -------------------------------------------------------------------------------- /data/structures/CsSnI3_delta.cif: -------------------------------------------------------------------------------- 1 | # generated using pymatgen 2 | data_CsSnI3 3 | _symmetry_space_group_name_H-M 'P 1' 4 | _cell_length_a 4.84918880 5 | _cell_length_b 10.69167480 6 | _cell_length_c 18.19616575 7 | _cell_angle_alpha 89.99999030 8 | _cell_angle_beta 90.00000262 9 | _cell_angle_gamma 89.99999762 10 | _symmetry_Int_Tables_number 1 11 | _chemical_formula_structural CsSnI3 12 | _chemical_formula_sum 'Cs4 Sn4 I12' 13 | _cell_volume 943.39749344 14 | _cell_formula_units_Z 4 15 | loop_ 16 | _symmetry_equiv_pos_site_id 17 | _symmetry_equiv_pos_as_xyz 18 | 1 'x, y, z' 19 | loop_ 20 | _atom_site_type_symbol 21 | _atom_site_label 22 | _atom_site_symmetry_multiplicity 23 | _atom_site_fract_x 24 | _atom_site_fract_y 25 | _atom_site_fract_z 26 | _atom_site_occupancy 27 | Cs Cs0 1 0.75000000 0.57552416 0.17186806 1 28 | Cs Cs1 1 0.25000001 0.42447584 0.82813194 1 29 | Cs Cs2 1 0.75000000 0.07552414 0.32813195 1 30 | Cs Cs3 1 0.25000000 0.92447584 0.67186806 1 31 | I I4 1 0.75000000 0.21397041 0.70606567 1 32 | I I5 1 0.25000000 0.78602961 0.29393431 1 33 | I I6 1 0.75000000 0.97123818 0.11227438 1 34 | I I7 1 0.25000000 0.02876181 0.88772566 1 35 | I I8 1 0.74999999 0.47123818 0.38772564 1 36 | I I9 1 0.25000001 0.52876182 0.61227437 1 37 | I I10 1 0.25000002 0.66813711 0.00020614 1 38 | I I11 1 0.24999999 0.16813710 0.49979388 1 39 | I I12 1 0.25000001 0.28602959 0.20606567 1 40 | I I13 1 0.74999999 
0.83186292 0.50020614 1 41 | I I14 1 0.75000000 0.33186292 0.99979386 1 42 | I I15 1 0.74999999 0.71397042 0.79393433 1 43 | Sn Sn16 1 0.75000000 0.84420702 0.94393743 1 44 | Sn Sn17 1 0.25000000 0.15579296 0.05606252 1 45 | Sn Sn18 1 0.74999998 0.34420702 0.55606253 1 46 | Sn Sn19 1 0.24999999 0.65579297 0.44393746 1 47 | -------------------------------------------------------------------------------- /data/structures/Fe.cif: -------------------------------------------------------------------------------- 1 | # generated using pymatgen 2 | data_Fe 3 | _symmetry_space_group_name_H-M 'P 1' 4 | _cell_length_a 2.86106543 5 | _cell_length_b 2.86106544 6 | _cell_length_c 2.86106538 7 | _cell_angle_alpha 90.00000018 8 | _cell_angle_beta 89.99999992 9 | _cell_angle_gamma 90.00000009 10 | _symmetry_Int_Tables_number 1 11 | _chemical_formula_structural Fe 12 | _chemical_formula_sum Fe2 13 | _cell_volume 23.41980977 14 | _cell_formula_units_Z 2 15 | loop_ 16 | _symmetry_equiv_pos_site_id 17 | _symmetry_equiv_pos_as_xyz 18 | 1 'x, y, z' 19 | loop_ 20 | _atom_site_type_symbol 21 | _atom_site_label 22 | _atom_site_symmetry_multiplicity 23 | _atom_site_fract_x 24 | _atom_site_fract_y 25 | _atom_site_fract_z 26 | _atom_site_occupancy 27 | Fe Fe0 1 0.00000000 -0.00000000 0.00000000 1 28 | Fe Fe1 1 0.50000000 0.50000000 0.50000000 1 29 | -------------------------------------------------------------------------------- /data/structures/GaN_hex.cif: -------------------------------------------------------------------------------- 1 | # generated using pymatgen 2 | data_GaN 3 | _symmetry_space_group_name_H-M 'P 1' 4 | _cell_length_a 3.21192371 5 | _cell_length_b 3.21192377 6 | _cell_length_c 5.21628467 7 | _cell_angle_alpha 90.00000055 8 | _cell_angle_beta 90.00000024 9 | _cell_angle_gamma 119.99999991 10 | _symmetry_Int_Tables_number 1 11 | _chemical_formula_structural GaN 12 | _chemical_formula_sum 'Ga2 N2' 13 | _cell_volume 46.60391121 14 | _cell_formula_units_Z 2 15 | loop_ 16 | 
_symmetry_equiv_pos_site_id 17 | _symmetry_equiv_pos_as_xyz 18 | 1 'x, y, z' 19 | loop_ 20 | _atom_site_type_symbol 21 | _atom_site_label 22 | _atom_site_symmetry_multiplicity 23 | _atom_site_fract_x 24 | _atom_site_fract_y 25 | _atom_site_fract_z 26 | _atom_site_occupancy 27 | Ga Ga0 1 0.66666667 0.33333333 0.49900050 1 28 | Ga Ga1 1 0.33333333 0.66666667 0.99900050 1 29 | N N2 1 0.66666667 0.33333333 0.87599950 1 30 | N N3 1 0.33333333 0.66666667 0.37599950 1 31 | -------------------------------------------------------------------------------- /data/structures/NaCl.cif: -------------------------------------------------------------------------------- 1 | # generated using pymatgen 2 | data_NaCl 3 | _symmetry_space_group_name_H-M 'P 1' 4 | _cell_length_a 5.68304678 5 | _cell_length_b 5.68304679 6 | _cell_length_c 5.68304672 7 | _cell_angle_alpha 90.00000009 8 | _cell_angle_beta 90.00000015 9 | _cell_angle_gamma 90.00000005 10 | _symmetry_Int_Tables_number 1 11 | _chemical_formula_structural NaCl 12 | _chemical_formula_sum 'Na4 Cl4' 13 | _cell_volume 183.54547783 14 | _cell_formula_units_Z 4 15 | loop_ 16 | _symmetry_equiv_pos_site_id 17 | _symmetry_equiv_pos_as_xyz 18 | 1 'x, y, z' 19 | loop_ 20 | _atom_site_type_symbol 21 | _atom_site_label 22 | _atom_site_symmetry_multiplicity 23 | _atom_site_fract_x 24 | _atom_site_fract_y 25 | _atom_site_fract_z 26 | _atom_site_occupancy 27 | Na Na0 1 -0.00000000 -0.00000000 -0.00000000 1 28 | Na Na1 1 -0.00000000 0.50000000 0.50000000 1 29 | Na Na2 1 0.50000000 -0.00000000 0.50000000 1 30 | Na Na3 1 0.50000000 0.50000000 -0.00000000 1 31 | Cl Cl4 1 0.00000000 0.00000000 0.50000000 1 32 | Cl Cl5 1 -0.00000000 0.50000000 0.00000000 1 33 | Cl Cl6 1 0.50000000 0.00000000 -0.00000000 1 34 | Cl Cl7 1 0.50000000 0.50000000 0.50000000 1 35 | -------------------------------------------------------------------------------- /notebooks/1_solid_solution.ipynb: -------------------------------------------------------------------------------- 
1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import ase\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "import numpy as np\n", 12 | "from tqdm import tqdm\n", 13 | "\n", 14 | "from alchemical_mace.calculator import get_alchemical_optimized_cellpar\n", 15 | "from alchemical_mace.model import AlchemicalPair\n", 16 | "from alchemical_mace.utils import suppress_print\n", 17 | "\n", 18 | "plt.style.use(\"default\")" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "### CeO2" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 2, 31 | "metadata": {}, 32 | "outputs": [ 33 | { 34 | "name": "stderr", 35 | "output_type": "stream", 36 | "text": [ 37 | "100%|██████████| 21/21 [01:15<00:00, 3.60s/it]\n" 38 | ] 39 | } 40 | ], 41 | "source": [ 42 | "# Default settings\n", 43 | "model = \"medium\"\n", 44 | "device = \"cpu\"\n", 45 | "\n", 46 | "# Load structure\n", 47 | "atoms = ase.io.read(\"../data/structures/CeO2.cif\")\n", 48 | "alch_elements = [\"Ce\", \"Sn\"]\n", 49 | "alch_idx = [i for i, atom in enumerate(atoms) if atom.symbol in alch_elements]\n", 50 | "alch_atomic_numbers = [ase.Atoms(el).numbers[0] for el in alch_elements]\n", 51 | "alchemical_pairs = [\n", 52 | " [AlchemicalPair(atom_index=idx, atomic_number=z) for idx in alch_idx]\n", 53 | " for z in alch_atomic_numbers\n", 54 | "]\n", 55 | "\n", 56 | "comp_grid = [[1 - x, x] for x in np.linspace(0, 0.5, 21)]\n", 57 | "lat_params_CeSn = []\n", 58 | "for comp in tqdm(comp_grid):\n", 59 | " with suppress_print(out=True, err=True):\n", 60 | " cellpar = get_alchemical_optimized_cellpar(\n", 61 | " atoms, alchemical_pairs, comp, model=model, device=device\n", 62 | " )\n", 63 | " lat_params_CeSn.append(cellpar[:3])" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 3, 69 | "metadata": {}, 70 | "outputs": [ 71 | { 72 | "name": "stderr", 73 | 
"output_type": "stream", 74 | "text": [ 75 | "100%|██████████| 21/21 [01:35<00:00, 4.52s/it]\n" 76 | ] 77 | } 78 | ], 79 | "source": [ 80 | "# Load structure\n", 81 | "alch_elements = [\"Ce\", \"Zr\"]\n", 82 | "alch_idx = [i for i, atom in enumerate(atoms) if atom.symbol in alch_elements]\n", 83 | "alch_atomic_numbers = [ase.Atoms(el).numbers[0] for el in alch_elements]\n", 84 | "alchemical_pairs = [\n", 85 | " [AlchemicalPair(atom_index=idx, atomic_number=z) for idx in alch_idx]\n", 86 | " for z in alch_atomic_numbers\n", 87 | "]\n", 88 | "\n", 89 | "comp_grid = [[1 - x, x] for x in np.linspace(0, 0.5, 21)]\n", 90 | "lat_params_CeZr = []\n", 91 | "for comp in tqdm(comp_grid):\n", 92 | " with suppress_print(out=True, err=True):\n", 93 | " cellpar = get_alchemical_optimized_cellpar(\n", 94 | " atoms, alchemical_pairs, comp, model=model, device=device\n", 95 | " )\n", 96 | " lat_params_CeZr.append(cellpar[:3])" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 4, 102 | "metadata": {}, 103 | "outputs": [ 104 | { 105 | "data": { 106 | "image/png": "", 107 | "text/plain": [ 108 | "
" 109 | ] 110 | }, 111 | "metadata": {}, 112 | "output_type": "display_data" 113 | } 114 | ], 115 | "source": [ 116 | "fig, ax = plt.subplots(figsize=(3, 2.5))\n", 117 | "ax.plot(\n", 118 | " [x[1] for x in comp_grid],\n", 119 | " [x[0] for x in lat_params_CeSn],\n", 120 | " label=\"Ce$_{1-x}$Sn$_x$O$_2$\",\n", 121 | ")\n", 122 | "ax.plot(\n", 123 | " [x[1] for x in comp_grid],\n", 124 | " [x[0] for x in lat_params_CeZr],\n", 125 | " label=\"Ce$_{1-x}$Zr$_x$O$_2$\",\n", 126 | ")\n", 127 | "ax.set_xlabel(\"$x$\")\n", 128 | "ax.set_ylabel(\"a [Å]\")\n", 129 | "ax.legend()\n", 130 | "fig.show()" 131 | ] 132 | }, 133 | { 134 | "cell_type": "markdown", 135 | "metadata": {}, 136 | "source": [ 137 | "### BiSBr" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 5, 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "# Load structure\n", 147 | "atoms = ase.io.read(\"../data/structures/BiSBr.cif\")\n", 148 | "halide_elements = [\"Cl\", \"Br\", \"I\"]\n", 149 | "halide_idx = [i for i, atom in enumerate(atoms) if atom.symbol in halide_elements]\n", 150 | "halide_atomic_numbers = [ase.Atoms(el).numbers[0] for el in halide_elements]\n", 151 | "alchemical_pairs = [\n", 152 | " [AlchemicalPair(atom_index=idx, atomic_number=z) for idx in halide_idx]\n", 153 | " for z in halide_atomic_numbers\n", 154 | "]" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": 6, 160 | "metadata": {}, 161 | "outputs": [ 162 | { 163 | "name": "stderr", 164 | "output_type": "stream", 165 | "text": [ 166 | "100%|██████████| 21/21 [08:12<00:00, 23.46s/it]\n", 167 | "100%|██████████| 21/21 [04:21<00:00, 12.45s/it]\n", 168 | "100%|██████████| 21/21 [05:35<00:00, 15.95s/it]\n" 169 | ] 170 | } 171 | ], 172 | "source": [ 173 | "comp_grid = [[1 - x, x, 0] for x in np.linspace(0, 1, 21)]\n", 174 | "lat_params_ClBr = []\n", 175 | "for comp in tqdm(comp_grid):\n", 176 | " with suppress_print(out=True, err=True):\n", 177 | " cellpar = 
get_alchemical_optimized_cellpar(\n", 178 | " atoms, alchemical_pairs, comp, model=model, device=device\n", 179 | " )\n", 180 | " lat_params_ClBr.append(cellpar[:3])\n", 181 | "\n", 182 | "comp_grid = [[0, 1 - x, x] for x in np.linspace(0, 1, 21)]\n", 183 | "lat_params_BrI = []\n", 184 | "for comp in tqdm(comp_grid):\n", 185 | " with suppress_print(out=True, err=True):\n", 186 | " cellpar = get_alchemical_optimized_cellpar(\n", 187 | " atoms, alchemical_pairs, comp, model=model, device=device\n", 188 | " )\n", 189 | " lat_params_BrI.append(cellpar[:3])\n", 190 | "\n", 191 | "\n", 192 | "comp_grid = [[1 - x, 0, x] for x in np.linspace(0, 1, 21)]\n", 193 | "lat_params_ClI = []\n", 194 | "for comp in tqdm(comp_grid):\n", 195 | " with suppress_print(out=True, err=True):\n", 196 | " cellpar = get_alchemical_optimized_cellpar(\n", 197 | " atoms, alchemical_pairs, comp, model=model, device=device\n", 198 | " )\n", 199 | " lat_params_ClI.append(cellpar[:3])" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": 7, 205 | "metadata": {}, 206 | "outputs": [ 207 | { 208 | "data": { 209 | "image/png": "", 210 | "text/plain": [ 211 | "
" 212 | ] 213 | }, 214 | "metadata": {}, 215 | "output_type": "display_data" 216 | } 217 | ], 218 | "source": [ 219 | "fig, ax = plt.subplots(figsize=(3.5, 2.5), layout=\"constrained\")\n", 220 | "comp = np.linspace(0, 1, 21)\n", 221 | "idx, param = 2, \"c\"\n", 222 | "ax.plot(comp, [lat[idx] for lat in lat_params_ClBr], label=\"BiSCl$_{1-x}$Br$_x$\")\n", 223 | "ax.plot(comp, [lat[idx] for lat in lat_params_BrI], label=\"BiSBr$_{1-x}$I$_x$\")\n", 224 | "ax.plot(comp, [lat[idx] for lat in lat_params_ClI], label=\"BiSCl$_{1-x}$I$_x$\")\n", 225 | "ax.legend()\n", 226 | "ax.set_xlabel(\"$x$\")\n", 227 | "ax.set_ylabel(f\"{param} [Å]\")\n", 228 | "fig.show()" 229 | ] 230 | } 231 | ], 232 | "metadata": { 233 | "kernelspec": { 234 | "display_name": "chgnet", 235 | "language": "python", 236 | "name": "python3" 237 | }, 238 | "language_info": { 239 | "codemirror_mode": { 240 | "name": "ipython", 241 | "version": 3 242 | }, 243 | "file_extension": ".py", 244 | "mimetype": "text/x-python", 245 | "name": "python", 246 | "nbconvert_exporter": "python", 247 | "pygments_lexer": "ipython3", 248 | "version": "3.10.14" 249 | } 250 | }, 251 | "nbformat": 4, 252 | "nbformat_minor": 2 253 | } 254 | -------------------------------------------------------------------------------- /notebooks/4_vacancy_analysis.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from pathlib import Path\n", 10 | "\n", 11 | "import numpy as np\n", 12 | "import pandas as pd\n", 13 | "from ase import units\n", 14 | "from scipy.integrate import trapezoid" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 2, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "def integrate_switching(\n", 24 | " df_log: pd.DataFrame,\n", 25 | " equil_time: int = 20000,\n", 26 | " switch_time: int = 30000,\n", 27 | " 
return_E_diss: bool = False,\n", 28 | "):\n", 29 | " fwd_start, fwd_end = equil_time, equil_time + switch_time\n", 30 | " rev_start, rev_end = 2 * equil_time + switch_time, 2 * equil_time + 2 * switch_time\n", 31 | " grad, lamda = df_log[\"lambda_grad\"], df_log[\"lambda\"]\n", 32 | " W_fwd = trapezoid(grad[fwd_start:fwd_end], lamda[fwd_start:fwd_end])\n", 33 | " W_rev = trapezoid(grad[rev_start:rev_end], lamda[rev_start:rev_end])\n", 34 | " if return_E_diss:\n", 35 | " return (W_fwd - W_rev) / 2, (W_fwd + W_rev) / 2\n", 36 | " return (W_fwd - W_rev) / 2 # free energy difference\n", 37 | "\n", 38 | "\n", 39 | "def analyze_frenkel_ladd(\n", 40 | " base_path: Path,\n", 41 | " temp: float,\n", 42 | " equil_time: int = 20000,\n", 43 | " switch_time: int = 30000,\n", 44 | " verbose: bool = False,\n", 45 | "):\n", 46 | " T = temp\n", 47 | " k = np.load(base_path / \"spring_constants.npy\")\n", 48 | "\n", 49 | " mass = np.load(base_path / \"masses.npy\")\n", 50 | " omega = np.sqrt(k / mass)\n", 51 | " n_atoms = len(mass)\n", 52 | "\n", 53 | " # 1. Perfect crystal\n", 54 | " df_log = pd.read_csv(base_path / \"observables.csv\")\n", 55 | " volume = df_log[\"volume\"].values[0]\n", 56 | " if verbose:\n", 57 | " _, E_diss_perfect = integrate_switching(\n", 58 | " df_log, equil_time, switch_time, return_E_diss=True\n", 59 | " )\n", 60 | " delta_F = integrate_switching(df_log, equil_time, switch_time)\n", 61 | " F_E = 3 * units.kB * T * np.mean(np.log(units._hbar * omega / (units.kB * T)))\n", 62 | " PV = volume * 1.01325 * units.bar\n", 63 | " G_perfect = delta_F + F_E + PV\n", 64 | "\n", 65 | " # 2. 
Defective crystal\n", 66 | " df_log = pd.read_csv(base_path / \"observables_defect.csv\")\n", 67 | " volume = df_log[\"volume\"].values[0]\n", 68 | " if verbose:\n", 69 | " _, E_diss_defect = integrate_switching(\n", 70 | " df_log, equil_time, switch_time, return_E_diss=True\n", 71 | " )\n", 72 | " delta_F = integrate_switching(df_log, equil_time, switch_time)\n", 73 | " F_E = 3 * units.kB * T * np.mean(np.log(units._hbar * omega / (units.kB * T)))\n", 74 | " PV = volume * 1.01325 * units.bar\n", 75 | " G_defect = delta_F + F_E + PV\n", 76 | " G_v = G_defect * (n_atoms - 1) - G_perfect * (n_atoms - 1)\n", 77 | "\n", 78 | " # 3. Partial FL\n", 79 | " df_log = pd.read_csv(base_path / \"observables_FL.csv\")\n", 80 | " if verbose:\n", 81 | " _, E_diss_FL = integrate_switching(\n", 82 | " df_log, equil_time, switch_time, return_E_diss=True\n", 83 | " )\n", 84 | " delta_G = integrate_switching(df_log, equil_time, switch_time)\n", 85 | " # delta_G * N = (G_defect * N-1 + F_E) - G_perfect * N\n", 86 | " G_defect_alt = (delta_G * n_atoms - F_E + G_perfect * n_atoms) / (n_atoms - 1)\n", 87 | " G_v_alt = G_defect_alt * (n_atoms - 1) - G_perfect * (n_atoms - 1)\n", 88 | "\n", 89 | " if verbose:\n", 90 | " return G_perfect, G_defect, delta_G, E_diss_perfect, E_diss_defect, E_diss_FL\n", 91 | " return G_v, G_v_alt" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 3, 97 | "metadata": {}, 98 | "outputs": [ 99 | { 100 | "name": "stdout", 101 | "output_type": "stream", 102 | "text": [ 103 | "G_v (50 K) = 1.5513 ± 0.0063 eV\n", 104 | "G_v FL (50 K) = 1.5582 ± 0.0014 eV\n", 105 | "G_v (100 K) = 1.5475 ± 0.0060 eV\n", 106 | "G_v FL (100 K) = 1.5405 ± 0.0018 eV\n", 107 | "G_v (150 K) = 1.5232 ± 0.0276 eV\n", 108 | "G_v FL (150 K) = 1.5188 ± 0.0029 eV\n", 109 | "G_v (200 K) = 1.5110 ± 0.0283 eV\n", 110 | "G_v FL (200 K) = 1.5003 ± 0.0072 eV\n" 111 | ] 112 | } 113 | ], 114 | "source": [ 115 | "result_path = Path(\"../data/results/vacancy\")\n", 116 | "temp_range = 
[50, 100, 150, 200]\n", 117 | "\n", 118 | "G_v_all, G_v_std_all = [], []\n", 119 | "G_v_alt_all, G_v_alt_std_all = [], []\n", 120 | "for temp in temp_range:\n", 121 | " G_v_list = []\n", 122 | " G_v_alt_list = []\n", 123 | " E_diss_perfect_list = []\n", 124 | " E_diss_defect_list = []\n", 125 | " E_diss_FL_list = []\n", 126 | " for i in range(4):\n", 127 | " base_path = result_path / f\"Fe_5x5x5_{temp}K/{i}\"\n", 128 | " G_v, G_v_alt = analyze_frenkel_ladd(base_path, temp=temp, verbose=False)\n", 129 | " G_v_list.append(G_v)\n", 130 | " G_v_alt_list.append(G_v_alt)\n", 131 | " G_v = np.mean(G_v_list)\n", 132 | " G_v_std = np.std(G_v_list)\n", 133 | " G_v_alt = np.mean(G_v_alt_list)\n", 134 | " G_v_alt_std = np.std(G_v_alt_list)\n", 135 | " print(f\"G_v ({temp} K) = {G_v:.4f} ± {G_v_std:.4f} eV\")\n", 136 | " print(f\"G_v FL ({temp} K) = {G_v_alt:.4f} ± {G_v_alt_std:.4f} eV\")\n", 137 | " G_v_all.append(G_v)\n", 138 | " G_v_std_all.append(G_v_std)\n", 139 | " G_v_alt_all.append(G_v_alt)\n", 140 | " G_v_alt_std_all.append(G_v_alt_std)" 141 | ] 142 | } 143 | ], 144 | "metadata": { 145 | "kernelspec": { 146 | "display_name": "chgnet", 147 | "language": "python", 148 | "name": "python3" 149 | }, 150 | "language_info": { 151 | "codemirror_mode": { 152 | "name": "ipython", 153 | "version": 3 154 | }, 155 | "file_extension": ".py", 156 | "mimetype": "text/x-python", 157 | "name": "python", 158 | "nbconvert_exporter": "python", 159 | "pygments_lexer": "ipython3", 160 | "version": "3.10.14" 161 | } 162 | }, 163 | "nbformat": 4, 164 | "nbformat_minor": 2 165 | } 166 | -------------------------------------------------------------------------------- /notebooks/5_perovskite_analysis.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from pathlib import Path\n", 10 | "\n", 11 | "import numpy as np\n", 12 | 
"import pandas as pd\n", 13 | "from ase import units\n", 14 | "from scipy.integrate import trapezoid" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "### Frenkel–Ladd path" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "def integrate_switching(\n", 31 | " df_log: pd.DataFrame,\n", 32 | " equil_time: int = 20000,\n", 33 | " switch_time: int = 30000,\n", 34 | " return_E_diss: bool = False,\n", 35 | "):\n", 36 | " fwd_start, fwd_end = equil_time, equil_time + switch_time\n", 37 | " rev_start, rev_end = 2 * equil_time + switch_time, 2 * equil_time + 2 * switch_time\n", 38 | " grad, lamda = df_log[\"lambda_grad\"], df_log[\"lambda\"]\n", 39 | " W_fwd = trapezoid(grad[fwd_start:fwd_end], lamda[fwd_start:fwd_end])\n", 40 | " W_rev = trapezoid(grad[rev_start:rev_end], lamda[rev_start:rev_end])\n", 41 | " if return_E_diss:\n", 42 | " return (W_fwd - W_rev) / 2, (W_fwd + W_rev) / 2\n", 43 | " return (W_fwd - W_rev) / 2 # free energy difference\n", 44 | "\n", 45 | "\n", 46 | "def analyze_frenkel_ladd(\n", 47 | " base_path: Path,\n", 48 | " temp: float,\n", 49 | " equil_time: int = 20000,\n", 50 | " switch_time: int = 30000,\n", 51 | "):\n", 52 | " T = temp\n", 53 | " df_log = pd.read_csv(base_path / \"observables.csv\")\n", 54 | " k = np.load(base_path / \"spring_constants.npy\")\n", 55 | " mass = np.load(base_path / \"masses.npy\")\n", 56 | " omega = np.sqrt(k / mass)\n", 57 | " volume = df_log[\"volume\"].values[0]\n", 58 | "\n", 59 | " delta_F = integrate_switching(df_log, equil_time, switch_time)\n", 60 | " F_E = 3 * units.kB * T * np.mean(np.log(units._hbar * omega / (units.kB * T)))\n", 61 | " PV = volume * 1.01325 * units.bar\n", 62 | " delta_G = delta_F + F_E + PV\n", 63 | "\n", 64 | " return delta_G\n", 65 | "\n", 66 | "\n", 67 | "def analyze_alchemical_switching(\n", 68 | " base_path: Path,\n", 69 | " temp: float,\n", 70 | " 
equil_time: int = 20000,\n", 71 | " switch_time: int = 30000,\n", 72 | "):\n", 73 | " T = temp\n", 74 | " df_log = pd.read_csv(base_path / \"observables.csv\")\n", 75 | " mass_init = np.load(base_path / \"masses_init.npy\")\n", 76 | " mass_final = np.load(base_path / \"masses_final.npy\")\n", 77 | "\n", 78 | " work = integrate_switching(df_log, equil_time, switch_time)\n", 79 | " G_mass = 1.5 * units.kB * T * np.mean(np.log(mass_init / mass_final))\n", 80 | " delta_G = work + G_mass\n", 81 | "\n", 82 | " return delta_G" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 3, 88 | "metadata": {}, 89 | "outputs": [ 90 | { 91 | "name": "stdout", 92 | "output_type": "stream", 93 | "text": [ 94 | "CsPbI3 Alpha G (300 K) = -8.8514 ± 0.0003 eV/atom\n", 95 | "CsPbI3 Alpha G (350 K) = -9.8642 ± 0.0004 eV/atom\n", 96 | "CsPbI3 Alpha G (400 K) = -10.8783 ± 0.0006 eV/atom\n", 97 | "CsPbI3 Alpha G (450 K) = -11.8943 ± 0.0003 eV/atom\n", 98 | "CsPbI3 Alpha G (500 K) = -12.9119 ± 0.0004 eV/atom\n", 99 | "CsPbI3 Delta G (300 K) = -8.8560 ± 0.0001 eV/atom\n", 100 | "CsPbI3 Delta G (350 K) = -9.8660 ± 0.0001 eV/atom\n", 101 | "CsPbI3 Delta G (400 K) = -10.8782 ± 0.0002 eV/atom\n", 102 | "CsPbI3 Delta G (450 K) = -11.8921 ± 0.0002 eV/atom\n", 103 | "CsPbI3 Delta G (500 K) = -12.9072 ± 0.0002 eV/atom\n" 104 | ] 105 | } 106 | ], 107 | "source": [ 108 | "result_path = Path(\"../data/results/perovskite/frenkel_ladd\")\n", 109 | "temp_range = [300, 350, 400, 450, 500]\n", 110 | "\n", 111 | "G_alpha = []\n", 112 | "G_alpha_std = []\n", 113 | "for temp in temp_range:\n", 114 | " G_list = []\n", 115 | " for i in range(4):\n", 116 | " base_path = result_path / f\"CsPbI3_alpha_6x6x6_{temp}K/{i}\"\n", 117 | " G_list.append(analyze_frenkel_ladd(base_path, temp=temp))\n", 118 | " G = np.mean(G_list)\n", 119 | " G_std = np.std(G_list)\n", 120 | " print(f\"CsPbI3 Alpha G ({temp} K) = {G:.4f} ± {G_std:.4f} eV/atom\")\n", 121 | " G_alpha.append(G)\n", 122 | " 
G_alpha_std.append(G_std)\n", 123 | "\n", 124 | "G_delta = []\n", 125 | "G_delta_std = []\n", 126 | "for temp in temp_range:\n", 127 | " G_list = []\n", 128 | " for i in range(4):\n", 129 | " base_path = result_path / f\"CsPbI3_delta_6x3x3_{temp}K/{i}\"\n", 130 | " G_list.append(analyze_frenkel_ladd(base_path, temp=temp))\n", 131 | " G = np.mean(G_list)\n", 132 | " G_std = np.std(G_list)\n", 133 | " print(f\"CsPbI3 Delta G ({temp} K) = {G:.4f} ± {G_std:.4f} eV/atom\")\n", 134 | " G_delta.append(G)\n", 135 | " G_delta_std.append(G_std)" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 4, 141 | "metadata": {}, 142 | "outputs": [ 143 | { 144 | "name": "stdout", 145 | "output_type": "stream", 146 | "text": [ 147 | "CsSnI3 Alpha G (300 K) = -8.8297 ± 0.0002 eV/atom\n", 148 | "CsSnI3 Alpha G (350 K) = -9.8413 ± 0.0002 eV/atom\n", 149 | "CsSnI3 Alpha G (400 K) = -10.8544 ± 0.0002 eV/atom\n", 150 | "CsSnI3 Alpha G (450 K) = -11.8695 ± 0.0003 eV/atom\n", 151 | "CsSnI3 Alpha G (500 K) = -12.8863 ± 0.0003 eV/atom\n", 152 | "CsSnI3 Delta G (300 K) = -8.8289 ± 0.0001 eV/atom\n", 153 | "CsSnI3 Delta G (350 K) = -9.8381 ± 0.0000 eV/atom\n", 154 | "CsSnI3 Delta G (400 K) = -10.8494 ± 0.0003 eV/atom\n", 155 | "CsSnI3 Delta G (450 K) = -11.8627 ± 0.0003 eV/atom\n", 156 | "CsSnI3 Delta G (500 K) = -12.8771 ± 0.0003 eV/atom\n" 157 | ] 158 | } 159 | ], 160 | "source": [ 161 | "G_CsSnI3_alpha = []\n", 162 | "G_CsSnI3_alpha_std = []\n", 163 | "for temp in temp_range:\n", 164 | " G_list = []\n", 165 | " for i in range(4):\n", 166 | " base_path = result_path / f\"CsSnI3_alpha_6x6x6_{temp}K/{i}\"\n", 167 | " G_list.append(analyze_frenkel_ladd(base_path, temp=temp))\n", 168 | " G = np.mean(G_list)\n", 169 | " G_std = np.std(G_list)\n", 170 | " print(f\"CsSnI3 Alpha G ({temp} K) = {G:.4f} ± {G_std:.4f} eV/atom\")\n", 171 | " G_CsSnI3_alpha.append(G)\n", 172 | " G_CsSnI3_alpha_std.append(G_std)\n", 173 | "\n", 174 | "G_CsSnI3_delta = []\n", 175 | "G_CsSnI3_delta_std 
= []\n", 176 | "for temp in temp_range:\n", 177 | " G_list = []\n", 178 | " for i in range(4):\n", 179 | " base_path = result_path / f\"CsSnI3_delta_6x3x3_{temp}K/{i}\"\n", 180 | " G_list.append(analyze_frenkel_ladd(base_path, temp=temp))\n", 181 | " G = np.mean(G_list)\n", 182 | " G_std = np.std(G_list)\n", 183 | " print(f\"CsSnI3 Delta G ({temp} K) = {G:.4f} ± {G_std:.4f} eV/atom\")\n", 184 | " G_CsSnI3_delta.append(G)\n", 185 | " G_CsSnI3_delta_std.append(G_std)" 186 | ] 187 | }, 188 | { 189 | "cell_type": "markdown", 190 | "metadata": {}, 191 | "source": [ 192 | "### Alchemical path" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 5, 198 | "metadata": {}, 199 | "outputs": [ 200 | { 201 | "name": "stdout", 202 | "output_type": "stream", 203 | "text": [ 204 | "Alpha ΔG (300 K) = 0.0233 ± 0.0001 eV/atom\n", 205 | "Alpha ΔG (350 K) = 0.0236 ± 0.0001 eV/atom\n", 206 | "Alpha ΔG (400 K) = 0.0241 ± 0.0001 eV/atom\n", 207 | "Alpha ΔG (450 K) = 0.0249 ± 0.0000 eV/atom\n", 208 | "Alpha ΔG (500 K) = 0.0258 ± 0.0000 eV/atom\n", 209 | "Delta ΔG (300 K) = 0.0271 ± 0.0000 eV/atom\n", 210 | "Delta ΔG (350 K) = 0.0279 ± 0.0000 eV/atom\n", 211 | "Delta ΔG (400 K) = 0.0286 ± 0.0000 eV/atom\n", 212 | "Delta ΔG (450 K) = 0.0294 ± 0.0000 eV/atom\n", 213 | "Delta ΔG (500 K) = 0.0301 ± 0.0000 eV/atom\n" 214 | ] 215 | } 216 | ], 217 | "source": [ 218 | "result_path = Path(\"../data/results/perovskite/alchemy\")\n", 219 | "\n", 220 | "G_alpha = []\n", 221 | "G_alpha_std = []\n", 222 | "for temp in temp_range:\n", 223 | " G_list = []\n", 224 | " for i in range(4):\n", 225 | " base_path = result_path / f\"CsPbI3_CsSnI3_alpha_{temp}K/{i}\"\n", 226 | " G_list.append(analyze_alchemical_switching(base_path, temp=temp))\n", 227 | " G = np.mean(G_list)\n", 228 | " G_std = np.std(G_list)\n", 229 | " print(f\"Alpha ΔG ({temp} K) = {G:.4f} ± {G_std:.4f} eV/atom\")\n", 230 | " G_alpha.append(G)\n", 231 | " G_alpha_std.append(G_std)\n", 232 | "\n", 233 | "G_delta = 
[]\n", 234 | "G_delta_std = []\n", 235 | "for temp in temp_range:\n", 236 | " G_list = []\n", 237 | " for i in range(4):\n", 238 | " base_path = result_path / f\"CsPbI3_CsSnI3_delta_{temp}K/{i}\"\n", 239 | " G_list.append(analyze_alchemical_switching(base_path, temp=temp))\n", 240 | " G = np.mean(G_list)\n", 241 | " G_std = np.std(G_list)\n", 242 | " print(f\"Delta ΔG ({temp} K) = {G:.4f} ± {G_std:.4f} eV/atom\")\n", 243 | " G_delta.append(G)\n", 244 | " G_delta_std.append(G_std)" 245 | ] 246 | } 247 | ], 248 | "metadata": { 249 | "kernelspec": { 250 | "display_name": "chgnet", 251 | "language": "python", 252 | "name": "python3" 253 | }, 254 | "language_info": { 255 | "codemirror_mode": { 256 | "name": "ipython", 257 | "version": 3 258 | }, 259 | "file_extension": ".py", 260 | "mimetype": "text/x-python", 261 | "name": "python", 262 | "nbconvert_exporter": "python", 263 | "pygments_lexer": "ipython3", 264 | "version": "3.10.14" 265 | } 266 | }, 267 | "nbformat": 4, 268 | "nbformat_minor": 2 269 | } 270 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "alchemical_mace" 7 | authors = [ 8 | { name = "Juno Nam", email = "junonam@mit.edu" }, 9 | ] 10 | description = "Alchemical MACE model" 11 | readme = "README.md" 12 | requires-python = ">=3.9" 13 | version = "0.1.0" 14 | 15 | [tool.setuptools] 16 | packages = ["alchemical_mace"] 17 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.0.1 2 | e3nn==0.4.4 3 | mace-torch==0.3.4 4 | ase==3.22.1 5 | pymatgen==2024.3.1 6 | numpy==1.25.2 7 | scipy==1.11.2 8 | pandas==2.2.2 9 | matplotlib==3.8.0 10 | mpltern==1.0.2 11 | 
tqdm==4.66.3 12 | ipykernel==6.25.2 13 | -------------------------------------------------------------------------------- /scripts/perovskite_alchemy.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import ase 5 | import numpy as np 6 | import pandas as pd 7 | from ase import units 8 | from ase.build import make_supercell 9 | from ase.constraints import ExpCellFilter 10 | from ase.md.npt import NPT 11 | from ase.md.nptberendsen import Inhomogeneous_NPTBerendsen 12 | from ase.md.velocitydistribution import MaxwellBoltzmannDistribution, Stationary 13 | from ase.optimize import FIRE 14 | from mace.calculators import mace_mp 15 | from tqdm import tqdm 16 | 17 | from alchemical_mace.calculator import AlchemicalMACECalculator 18 | from alchemical_mace.model import AlchemicalPair 19 | from alchemical_mace.utils import upper_triangular_cell 20 | 21 | 22 | # Arguments 23 | parser = argparse.ArgumentParser() 24 | 25 | # Structure 26 | parser.add_argument("--structure-file", type=str) 27 | parser.add_argument("--supercell", type=int, nargs=3, default=[6, 6, 6]) 28 | 29 | # Alchemy 30 | parser.add_argument("--switch-pair", type=str, nargs=2, default=["Pb", "Sn"]) 31 | 32 | # Molecular dynamics: general 33 | parser.add_argument("--temperature", type=float, default=300.0) 34 | parser.add_argument("--pressure", type=float, default=1.0) 35 | parser.add_argument("--timestep", type=float, default=2.0) 36 | parser.add_argument("--ttime", type=float, default=25.0) 37 | parser.add_argument("--ptime", type=int, default=75.0) 38 | 39 | # Molecular dynamics: timesteps 40 | parser.add_argument("--npt-equil-stpes", type=int, default=10000) 41 | parser.add_argument("--alchemy-equil-steps", type=int, default=20000) 42 | parser.add_argument("--alchemy-switch-steps", type=int, default=30000) 43 | 44 | # Molecular dynamics: output control 45 | parser.add_argument("--output-dir", type=Path, 
default=Path("results")) 46 | parser.add_argument("--log-interval", type=int, default=1) 47 | 48 | # MACE model 49 | parser.add_argument("--device", type=str, default="cuda") 50 | parser.add_argument("--model", type=str, default="small") 51 | 52 | args = parser.parse_args() 53 | args.output_dir.mkdir(exist_ok=True, parents=True) 54 | 55 | # Load structure 56 | atoms = ase.io.read(args.structure_file) 57 | atoms = make_supercell(atoms, np.diag(args.supercell)) 58 | 59 | # Load universal MACE calculator and relax the structure 60 | mace_calc = mace_mp(model=args.model, device=args.device, default_dtype="float32") 61 | atoms.calc = mace_calc 62 | atoms = ExpCellFilter(atoms) 63 | optimizer = FIRE(atoms) 64 | optimizer.run(fmax=0.01, steps=500) 65 | atoms = atoms.atoms # get the relaxed structure 66 | initial_atoms = atoms.copy() # save the initial structure 67 | 68 | 69 | ################################################################################ 70 | # Cell volume equilibration 71 | ################################################################################ 72 | 73 | atoms = initial_atoms.copy() 74 | atoms.set_calculator(mace_calc) 75 | bulk_modulus = 100.0 * units.GPa 76 | 77 | # NPT equilibration 78 | dyn = Inhomogeneous_NPTBerendsen( 79 | atoms, 80 | timestep=args.timestep * units.fs, 81 | temperature_K=args.temperature, 82 | pressure_au=args.pressure * 1.01325 * units.bar, 83 | taut=args.ttime * units.fs, 84 | taup=args.ptime * units.fs, 85 | compressibility_au=1.0 / bulk_modulus, 86 | ) 87 | MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature) 88 | Stationary(atoms) 89 | 90 | # NPT equilibration and volume relaxation 91 | for step in tqdm(range(args.npt_equil_stpes), desc="NPT equil"): 92 | dyn.run(steps=1) 93 | 94 | 95 | ################################################################################ 96 | # Alchemical switching 97 | ################################################################################ 98 | 99 | # Define 
alchemical transformation 100 | src_elem, dst_elem = args.switch_pair 101 | src_Z, dst_Z = ase.data.atomic_numbers[src_elem], ase.data.atomic_numbers[dst_elem] 102 | src_idx = np.where(atoms.get_atomic_numbers() == src_Z)[0] 103 | alchemical_pairs = [ 104 | [AlchemicalPair(atom_index=idx, atomic_number=Z) for idx in src_idx] 105 | for Z in [src_Z, dst_Z] 106 | ] 107 | 108 | # Set up the alchemical MACE calculator 109 | calc = AlchemicalMACECalculator( 110 | atoms=atoms, 111 | alchemical_pairs=alchemical_pairs, 112 | alchemical_weights=[1.0, 0.0], 113 | device=args.device, 114 | model=args.model, 115 | ) 116 | atoms.set_calculator(calc) 117 | upper_triangular_cell(atoms) # for ASE NPT 118 | 119 | # NPT alchemical switching 120 | ptime = args.ptime * units.fs 121 | pfactor = bulk_modulus * ptime * ptime 122 | 123 | dyn = NPT( 124 | atoms, 125 | timestep=args.timestep * units.fs, 126 | temperature_K=args.temperature, 127 | externalstress=args.pressure * 1.01325 * units.bar, 128 | ttime=args.ttime * units.fs, 129 | pfactor=pfactor, 130 | ) 131 | 132 | # Define alchemical path 133 | t = np.linspace(0.0, 1.0, args.alchemy_switch_steps) 134 | lambda_steps = t**5 * (70 * t**4 - 315 * t**3 + 540 * t**2 - 420 * t + 126) 135 | lambda_values = [ 136 | np.zeros(args.alchemy_equil_steps), 137 | lambda_steps, 138 | np.ones(args.alchemy_equil_steps), 139 | lambda_steps[::-1], 140 | ] 141 | lambda_values = np.concatenate(lambda_values) 142 | 143 | calculate_gradients = [ 144 | np.zeros(args.alchemy_equil_steps, dtype=bool), 145 | np.ones(args.alchemy_switch_steps, dtype=bool), 146 | np.zeros(args.alchemy_equil_steps, dtype=bool), 147 | np.ones(args.alchemy_switch_steps, dtype=bool), 148 | ] 149 | calculate_gradients = np.concatenate(calculate_gradients) 150 | 151 | 152 | def get_observables(dynamics, time, lambda_value): 153 | num_atoms = len(dynamics.atoms) 154 | alchemical_grad = dynamics.atoms._calc.results["alchemical_grad"] 155 | lambda_grad = (alchemical_grad[1] - 
alchemical_grad[0]) / num_atoms 156 | return { 157 | "time": time, 158 | "potential": dynamics.atoms.get_potential_energy() / num_atoms, 159 | "temperature": dynamics.atoms.get_temperature(), 160 | "volume": dynamics.atoms.get_volume() / num_atoms, 161 | "lambda": lambda_value, 162 | "lambda_grad": lambda_grad, 163 | } 164 | 165 | 166 | # Simulation loop 167 | total_steps = 2 * args.alchemy_equil_steps + 2 * args.alchemy_switch_steps 168 | 169 | observables = [] 170 | for step in tqdm(range(total_steps), desc="Alchemical switching"): 171 | lambda_value = lambda_values[step] 172 | grad_enabled = calculate_gradients[step] 173 | 174 | # Set alchemical weights and atomic masses 175 | calc.set_alchemical_weights([1 - lambda_value, lambda_value]) 176 | atoms.set_masses(calc.get_alchemical_atomic_masses()) 177 | calc.calculate_alchemical_grad = grad_enabled 178 | 179 | dyn.run(steps=1) 180 | if step % args.log_interval == 0: 181 | time = (step + 1) * args.timestep 182 | observables.append(get_observables(dyn, time, lambda_value)) 183 | 184 | # Save observables 185 | df = pd.DataFrame(observables) 186 | df.to_csv(args.output_dir / "observables.csv", index=False) 187 | 188 | # Save masses for post-processing 189 | calc.set_alchemical_weights([1.0, 0.0]) 190 | np.save(args.output_dir / "masses_init.npy", calc.get_alchemical_atomic_masses()) 191 | calc.set_alchemical_weights([0.0, 1.0]) 192 | np.save(args.output_dir / "masses_final.npy", calc.get_alchemical_atomic_masses()) 193 | -------------------------------------------------------------------------------- /scripts/perovskite_frenkel_ladd.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import ase 5 | import numpy as np 6 | import pandas as pd 7 | from ase import units 8 | from ase.build import make_supercell 9 | from ase.constraints import ExpCellFilter 10 | from ase.md.langevin import Langevin 11 | from ase.md.nptberendsen import 
Inhomogeneous_NPTBerendsen 12 | from ase.md.velocitydistribution import MaxwellBoltzmannDistribution, Stationary 13 | from ase.optimize import FIRE 14 | from mace.calculators import mace_mp 15 | from pymatgen.io.ase import AseAtomsAdaptor 16 | from pymatgen.symmetry.analyzer import SpacegroupAnalyzer 17 | from tqdm import tqdm 18 | 19 | from alchemical_mace.calculator import FrenkelLaddCalculator, NVTMACECalculator 20 | 21 | 22 | # Arguments 23 | parser = argparse.ArgumentParser() 24 | 25 | # Structure 26 | parser.add_argument("--structure-file", type=str) 27 | parser.add_argument("--supercell", type=int, nargs=3, default=[6, 6, 6]) 28 | 29 | # Molecular dynamics: general 30 | parser.add_argument("--temperature", type=float, default=300.0) 31 | parser.add_argument("--pressure", type=float, default=1.0) 32 | parser.add_argument("--timestep", type=float, default=2.0) 33 | parser.add_argument("--ttime", type=float, default=25.0) 34 | parser.add_argument("--ptime", type=int, default=75.0) 35 | 36 | # Molecular dynamics: timesteps 37 | parser.add_argument("--npt-equil-stpes", type=int, default=10000) 38 | parser.add_argument("--npt-prod-steps", type=int, default=20000) 39 | parser.add_argument("--nvt-equil-steps", type=int, default=20000) 40 | parser.add_argument("--nvt-prod-steps", type=int, default=30000) 41 | parser.add_argument("--alchemy-equil-steps", type=int, default=20000) 42 | parser.add_argument("--alchemy-switch-steps", type=int, default=30000) 43 | 44 | # Molecular dynamics: output control 45 | parser.add_argument("--output-dir", type=Path, default=Path("results")) 46 | parser.add_argument("--log-interval", type=int, default=1) 47 | 48 | # MACE model 49 | parser.add_argument("--device", type=str, default="cuda") 50 | parser.add_argument("--model", type=str, default="small") 51 | 52 | args = parser.parse_args() 53 | args.output_dir.mkdir(exist_ok=True, parents=True) 54 | 55 | # Load structure 56 | atoms = ase.io.read(args.structure_file) 57 | atoms = 
make_supercell(atoms, np.diag(args.supercell)) 58 | 59 | # Load universal MACE calculator and relax the structure 60 | mace_calc = mace_mp(model=args.model, device=args.device, default_dtype="float32") 61 | atoms.calc = mace_calc 62 | atoms = ExpCellFilter(atoms) 63 | optimizer = FIRE(atoms) 64 | optimizer.run(fmax=0.01, steps=500) 65 | atoms = atoms.atoms # get the relaxed structure 66 | initial_atoms = atoms.copy() # save the initial structure 67 | 68 | 69 | ################################################################################ 70 | # Cell volume equilibration 71 | ################################################################################ 72 | 73 | atoms = initial_atoms.copy() 74 | atoms.set_calculator(mace_calc) 75 | bulk_modulus = 100.0 * units.GPa 76 | 77 | # Equilibration and volume calculation 78 | dyn = Inhomogeneous_NPTBerendsen( 79 | atoms, 80 | timestep=args.timestep * units.fs, 81 | temperature_K=args.temperature, 82 | pressure_au=args.pressure * 1.01325 * units.bar, 83 | taut=args.ttime * units.fs, 84 | taup=args.ptime * units.fs, 85 | compressibility_au=1.0 / bulk_modulus, 86 | ) 87 | MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature) 88 | Stationary(atoms) 89 | 90 | # NPT equilibration and volume relaxation 91 | cellpar_traj = [] 92 | for step in tqdm(range(args.npt_equil_stpes), desc="NPT equil"): 93 | dyn.run(steps=1) 94 | for step in tqdm(range(args.npt_prod_steps), desc="NPT prod"): 95 | dyn.run(steps=1) 96 | if step % args.log_interval == 0: 97 | cellpar_traj.append(atoms.get_cell().cellpar()) 98 | abc_new = np.mean(cellpar_traj, axis=0)[:3] 99 | 100 | # Scale the initial cell to match the average volume 101 | atoms = initial_atoms 102 | atoms.set_cell(np.diag(abc_new), scale_atoms=True) 103 | atoms.set_calculator(mace_calc) 104 | 105 | # Relax the atomic positions 106 | optimizer = FIRE(atoms) 107 | optimizer.run(fmax=0.01, steps=500) 108 | initial_atoms = atoms.copy() # save the initial structure 109 | 110 | 
################################################################################
# MSD calculation
################################################################################

# Fail fast: the MSD below is averaged over the production steps, so zero steps
# would only surface as 0/0 = NaN spring constants after equilibration has run.
if args.nvt_prod_steps <= 0:
    parser.error("--nvt-prod-steps must be a positive integer")

# Reference positions: the relaxed structure at the NPT-averaged cell.
initial_positions = atoms.get_positions()
# NVTMACECalculator is the reversible-scaling MACE calculator with its scale
# fixed at 1.0; it is used because its stress calculation can be turned off,
# which is not needed while the cell is held fixed.
calc = NVTMACECalculator(device=args.device, model=args.model)
atoms.calc = calc

# NVT MSD calculation
dyn = Langevin(
    atoms,
    timestep=args.timestep * units.fs,
    temperature_K=args.temperature,
    friction=1 / (args.ttime * units.fs),
)
MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature)
Stationary(atoms)

for _ in tqdm(range(args.nvt_equil_steps), desc="NVT equil"):
    dyn.run(steps=1)

# Accumulate each atom's squared displacement from its reference position at
# every production step; dividing by the step count gives the per-atom MSD.
squared_disp = np.zeros(len(atoms))
for _ in tqdm(range(args.nvt_prod_steps), desc="NVT prod"):
    dyn.run(steps=1)
    squared_disp += np.sum((atoms.get_positions() - initial_positions) ** 2, axis=1)
mean_squared_disp = squared_disp / args.nvt_prod_steps

# Calculate spring constants and average over symmetrically equivalent atoms.
# Equipartition for an isotropic 3D harmonic well gives <|dr|^2> = 3 kB T / k,
# so k = 3 kB T / MSD reproduces the sampled fluctuations.
spring_constants = 3.0 * units.kB * args.temperature / mean_squared_disp
structure = AseAtomsAdaptor.get_structure(initial_atoms)
sga = SpacegroupAnalyzer(structure)
equivalent_indices = sga.get_symmetrized_structure().equivalent_indices
for indices in equivalent_indices:
    spring_constants[indices] = np.mean(spring_constants[indices])

np.save(args.output_dir / "spring_constants.npy", spring_constants)
np.save(args.output_dir / "masses.npy", atoms.get_masses())


################################################################################
# Frenkel-Ladd calculation
################################################################################ 155 | 156 | atoms = initial_atoms.copy() 157 | calc = FrenkelLaddCalculator( 158 | spring_constants=spring_constants, 159 | initial_positions=initial_positions, 160 | device=args.device, 161 | model=args.model, 162 | ) 163 | atoms.set_calculator(calc) 164 | 165 | # NVT Frenkel-Ladd calculation 166 | dyn = Langevin( 167 | atoms, 168 | timestep=args.timestep * units.fs, 169 | temperature_K=args.temperature, 170 | friction=1 / (args.ttime * units.fs), 171 | ) 172 | MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature) 173 | Stationary(atoms) 174 | 175 | # Define Frenkel-Ladd path 176 | t = np.linspace(0.0, 1.0, args.alchemy_switch_steps) 177 | lambda_steps = t**5 * (70 * t**4 - 315 * t**3 + 540 * t**2 - 420 * t + 126) 178 | lambda_values = [ 179 | np.zeros(args.alchemy_equil_steps), 180 | lambda_steps, 181 | np.ones(args.alchemy_equil_steps), 182 | lambda_steps[::-1], 183 | ] 184 | lambda_values = np.concatenate(lambda_values) 185 | 186 | 187 | def get_observables(dynamics, time, lambda_value): 188 | num_atoms = len(dynamics.atoms) 189 | return { 190 | "time": time, 191 | "potential": dynamics.atoms.get_potential_energy() / num_atoms, 192 | "temperature": dynamics.atoms.get_temperature(), 193 | "volume": dynamics.atoms.get_volume() / num_atoms, 194 | "lambda": lambda_value, 195 | "lambda_grad": dynamics.atoms._calc.results["energy_diff"] / num_atoms, 196 | } 197 | 198 | 199 | # Simulation loop 200 | calc.compute_mace = False 201 | total_steps = 2 * args.alchemy_equil_steps + 2 * args.alchemy_switch_steps 202 | 203 | observables = [] 204 | for step in tqdm(range(total_steps), desc="Frenkel-Ladd"): 205 | if step == args.alchemy_equil_steps: # turn on MACE after spring equilibration 206 | calc.compute_mace = True 207 | lambda_value = lambda_values[step] 208 | calc.set_weights(lambda_value) 209 | 210 | dyn.run(steps=1) 211 | if step % args.log_interval == 0: 212 | time = (step + 
1) * args.timestep 213 | observables.append(get_observables(dyn, time, lambda_value)) 214 | 215 | # Save observables 216 | df = pd.DataFrame(observables) 217 | df.to_csv(args.output_dir / "observables.csv", index=False) 218 | -------------------------------------------------------------------------------- /scripts/vacancy_frenkel_ladd.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import ase 5 | import numpy as np 6 | import pandas as pd 7 | from ase import units 8 | from ase.build import make_supercell 9 | from ase.constraints import ExpCellFilter 10 | from ase.md.langevin import Langevin 11 | from ase.md.npt import NPT 12 | from ase.md.nptberendsen import Inhomogeneous_NPTBerendsen 13 | from ase.md.velocitydistribution import MaxwellBoltzmannDistribution, Stationary 14 | from ase.optimize import FIRE 15 | from mace.calculators import mace_mp 16 | from pymatgen.io.ase import AseAtomsAdaptor 17 | from pymatgen.symmetry.analyzer import SpacegroupAnalyzer 18 | from tqdm import tqdm 19 | 20 | from alchemical_mace.calculator import ( 21 | DefectFrenkelLaddCalculator, 22 | FrenkelLaddCalculator, 23 | NVTMACECalculator, 24 | ) 25 | from alchemical_mace.utils import upper_triangular_cell 26 | 27 | 28 | # Arguments 29 | parser = argparse.ArgumentParser() 30 | 31 | # Structure 32 | parser.add_argument("--structure-file", type=str) 33 | parser.add_argument("--supercell", type=int, nargs=3, default=[5, 5, 5]) 34 | 35 | # Molecular dynamics: general 36 | parser.add_argument("--temperature", type=float, default=300.0) 37 | parser.add_argument("--pressure", type=float, default=1.0) 38 | parser.add_argument("--timestep", type=float, default=2.0) 39 | parser.add_argument("--ttime", type=float, default=25.0) 40 | parser.add_argument("--ptime", type=int, default=75.0) 41 | 42 | # Molecular dynamics: timesteps 43 | parser.add_argument("--npt-equil-stpes", type=int, default=10000) 44 | 
parser.add_argument("--npt-prod-steps", type=int, default=20000)
parser.add_argument("--nvt-equil-steps", type=int, default=20000)
parser.add_argument("--nvt-prod-steps", type=int, default=30000)
parser.add_argument("--alchemy-equil-steps", type=int, default=20000)
parser.add_argument("--alchemy-switch-steps", type=int, default=30000)

# Molecular dynamics: output control
parser.add_argument("--output-dir", type=Path, default=Path("results"))
parser.add_argument("--log-interval", type=int, default=1)

# MACE model
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--model", type=str, default="small")

args = parser.parse_args()
# Fail fast on values that would otherwise only break after hours of MD: the
# cell and MSD averages divide by these counts, and logging uses `step % interval`.
for _name in ("npt_prod_steps", "nvt_prod_steps", "log_interval"):
    if getattr(args, _name) <= 0:
        parser.error(f"--{_name.replace('_', '-')} must be a positive integer")
args.output_dir.mkdir(exist_ok=True, parents=True)


################################################################################
# Energy minimization: defect-free structure
################################################################################

# Load structure
atoms = ase.io.read(args.structure_file)
atoms = make_supercell(atoms, np.diag(args.supercell))

# Load universal MACE calculator and relax the structure (positions and cell)
mace_calc = mace_mp(model=args.model, device=args.device, default_dtype="float32")
atoms.calc = mace_calc
cell_filter = ExpCellFilter(atoms)
optimizer = FIRE(cell_filter)
optimizer.run(fmax=0.01, steps=500)
atoms = cell_filter.atoms  # get the relaxed structure
initial_atoms = atoms.copy()  # save the initial structure


################################################################################
# Cell volume equilibration: defect-free structure
################################################################################

atoms = initial_atoms.copy()
atoms.calc = mace_calc
# Assumed order-of-magnitude bulk modulus; it only sets the barostat coupling.
bulk_modulus = 100.0 * units.GPa

# Equilibration and volume calculation (--pressure is in atm: 1 atm = 1.01325 bar)
dyn = Inhomogeneous_NPTBerendsen(
    atoms,
    timestep=args.timestep * units.fs,
    temperature_K=args.temperature,
    pressure_au=args.pressure * 1.01325 * units.bar,
    taut=args.ttime * units.fs,
    taup=args.ptime * units.fs,
    compressibility_au=1.0 / bulk_modulus,
)
MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature)
Stationary(atoms)

# NPT equilibration and volume relaxation
cellpar_traj = []
for _ in tqdm(range(args.npt_equil_stpes), desc="NPT equil"):
    dyn.run(steps=1)
for step in tqdm(range(args.npt_prod_steps), desc="NPT prod"):
    dyn.run(steps=1)
    if step % args.log_interval == 0:
        cellpar_traj.append(atoms.get_cell().cellpar())
# Only the averaged lengths a, b, c are kept (angles are dropped), so the
# rescaled cell below is orthogonal. NOTE(review): assumes an orthogonal cell.
abc_new = np.mean(cellpar_traj, axis=0)[:3]

# Scale the initial cell to match the average volume
atoms = initial_atoms
atoms.set_cell(np.diag(abc_new), scale_atoms=True)
atoms.calc = mace_calc

# Relax the atomic positions
optimizer = FIRE(atoms)
optimizer.run(fmax=0.01, steps=500)
initial_atoms = atoms.copy()  # save the initial structure


################################################################################
# MSD calculation: defect-free structure
################################################################################

# Reference positions: the relaxed structure at the NPT-averaged cell.
initial_positions = atoms.get_positions()
# NVTMACECalculator is the reversible-scaling MACE calculator with its scale
# fixed at 1.0; it is used because its stress calculation can be turned off.
calc = NVTMACECalculator(device=args.device, model=args.model)
atoms.calc = calc

# NVT MSD calculation
dyn = Langevin(
    atoms,
    timestep=args.timestep * units.fs,
    temperature_K=args.temperature,
    friction=1 / (args.ttime * units.fs),
)
MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature)
Stationary(atoms)

for _ in tqdm(range(args.nvt_equil_steps), desc="NVT equil"):
    dyn.run(steps=1)
# Per-atom squared displacement from the reference, summed over production steps.
squared_disp = np.zeros(len(atoms))
for _ in tqdm(range(args.nvt_prod_steps), desc="NVT prod"):
    dyn.run(steps=1)
    squared_disp += np.sum((atoms.get_positions() - initial_positions) ** 2, axis=1)
mean_squared_disp = squared_disp / args.nvt_prod_steps

# Calculate spring constants and average over symmetrically equivalent atoms.
# Equipartition for an isotropic 3D harmonic well: k = 3 kB T / <|dr|^2>.
spring_constants = 3.0 * units.kB * args.temperature / mean_squared_disp
structure = AseAtomsAdaptor.get_structure(initial_atoms)
sga = SpacegroupAnalyzer(structure)
equivalent_indices = sga.get_symmetrized_structure().equivalent_indices
for indices in equivalent_indices:
    spring_constants[indices] = np.mean(spring_constants[indices])

np.save(args.output_dir / "spring_constants.npy", spring_constants)
np.save(args.output_dir / "masses.npy", atoms.get_masses())


################################################################################
# Frenkel-Ladd calculation: defect-free structure
################################################################################

atoms = initial_atoms.copy()
calc = FrenkelLaddCalculator(
    spring_constants=spring_constants,
    initial_positions=initial_positions,
    device=args.device,
    model=args.model,
)
atoms.calc = calc

# NVT Frenkel-Ladd calculation
dyn = Langevin(
    atoms,
    timestep=args.timestep * units.fs,
    temperature_K=args.temperature,
    friction=1 / (args.ttime * units.fs),
)
MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature)
Stationary(atoms)

# Define Frenkel-Ladd path: hold at 0, switch 0 -> 1, hold at 1, switch 1 -> 0,
# using the degree-9 smoothstep profile (flat first four derivatives at the ends).
t = np.linspace(0.0, 1.0, args.alchemy_switch_steps)
lambda_steps = t**5 * (70 * t**4 - 315 * t**3 + 540 * t**2 - 420 * t + 126)
lambda_values = np.concatenate(
    [
        np.zeros(args.alchemy_equil_steps),
        lambda_steps,
        np.ones(args.alchemy_equil_steps),
        lambda_steps[::-1],
    ]
)


def get_observables(dynamics, time, lambda_value):
    """Return per-atom observables of the current MD state as a dict.

    ``lambda_grad`` is the calculator's ``energy_diff`` result divided by the
    atom count (presumably dH/dlambda; TODO confirm in FrenkelLaddCalculator).
    """
    md_atoms = dynamics.atoms
    num_atoms = len(md_atoms)
    return {
        "time": time,
        "potential": md_atoms.get_potential_energy() / num_atoms,
        "temperature": md_atoms.get_temperature(),
        "volume": md_atoms.get_volume() / num_atoms,
        "lambda": lambda_value,
        "lambda_grad": md_atoms.calc.results["energy_diff"] / num_atoms,
    }


# Simulation loop (MACE is disabled during the initial spring-only hold)
calc.compute_mace = False
total_steps = 2 * args.alchemy_equil_steps + 2 * args.alchemy_switch_steps

observables = []
for step in tqdm(range(total_steps), desc="Frenkel-Ladd"):
    if step == args.alchemy_equil_steps:  # turn on MACE after spring equilibration
        calc.compute_mace = True
    lambda_value = lambda_values[step]
    calc.set_weights(lambda_value)

    dyn.run(steps=1)
    if step % args.log_interval == 0:
        time = (step + 1) * args.timestep  # elapsed simulation time in fs
        observables.append(get_observables(dyn, time, lambda_value))

# Save observables
df = pd.DataFrame(observables)
df.to_csv(args.output_dir / "observables.csv", index=False)


################################################################################
# Cell volume equilibration: structure with a defect
################################################################################

atoms = initial_atoms.copy()

# Create a vacancy by removing the atom at list index len(atoms) // 2. This is
# the middle index of the atom list, not necessarily the atom nearest the
# geometric center of the supercell; which species is removed depends on the
# atom ordering of the supercell.
vacancy_index = len(atoms) // 2
atom_mask = np.ones(len(atoms), dtype=bool)  # True for every atom that remains
atom_mask[vacancy_index] = False
del atoms[vacancy_index]

atoms.calc = mace_calc

# Equilibration and volume calculation
dyn = Inhomogeneous_NPTBerendsen(
    atoms,
    timestep=args.timestep * units.fs,
    temperature_K=args.temperature,
    pressure_au=args.pressure * 1.01325 * units.bar,
    taut=args.ttime * units.fs,
    taup=args.ptime * units.fs,
    compressibility_au=1.0 / bulk_modulus,
)
MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature)
Stationary(atoms)

# NPT equilibration and volume relaxation
cellpar_traj = []
for _ in tqdm(range(args.npt_equil_stpes), desc="NPT equil"):
    dyn.run(steps=1)
for step in tqdm(range(args.npt_prod_steps), desc="NPT prod"):
    dyn.run(steps=1)
    if step % args.log_interval == 0:
        cellpar_traj.append(atoms.get_cell().cellpar())
abc_new = np.mean(cellpar_traj, axis=0)[:3]

# Scale the initial cell to match the average volume, then remove the same atom
atoms = initial_atoms.copy()
atoms.set_cell(np.diag(abc_new), scale_atoms=True)
del atoms[vacancy_index]
atoms.calc = mace_calc

# Relax the atomic positions
optimizer = FIRE(atoms)
optimizer.run(fmax=0.01, steps=500)


################################################################################
# Frenkel-Ladd calculation: structure with a defect
################################################################################

# Reuse the defect-free spring constants and reference positions for the atoms
# that remain. NOTE(review): these reference positions come from the defect-free
# cell, whereas `atoms` was rescaled to the defect cell and relaxed — confirm
# this mismatch is intended.
calc = FrenkelLaddCalculator(
    spring_constants=spring_constants[atom_mask],
    initial_positions=initial_positions[atom_mask],
    device=args.device,
    model=args.model,
)
atoms.calc = calc

# NVT Frenkel-Ladd calculation
dyn = Langevin(
    atoms,
    timestep=args.timestep * units.fs,
    temperature_K=args.temperature,
    friction=1 / (args.ttime * units.fs),
)
MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature)
Stationary(atoms)

# Simulation loop (MACE is disabled during the initial spring-only hold)
calc.compute_mace = False
total_steps = 2 * args.alchemy_equil_steps + 2 * args.alchemy_switch_steps

# Reuses `lambda_values` and the Frenkel-Ladd `get_observables` defined for the
# defect-free run.
observables = []
for step in tqdm(range(total_steps), desc="Frenkel-Ladd"):
    if step == args.alchemy_equil_steps:  # turn on MACE after spring equilibration
        calc.compute_mace = True
    lambda_value = lambda_values[step]
    calc.set_weights(lambda_value)

    dyn.run(steps=1)
    if step % args.log_interval == 0:
        time = (step + 1) * args.timestep  # elapsed simulation time in fs
        observables.append(get_observables(dyn, time, lambda_value))

# Save observables
df = pd.DataFrame(observables)
df.to_csv(args.output_dir / "observables_defect.csv", index=False)


################################################################################
# Cell volume equilibration: partial Frenkel-Ladd calculation
################################################################################

# Start again from the relaxed defect-free structure.
atoms = initial_atoms.copy()
atoms.calc = mace_calc

# NPT equilibration only; the cell stays flexible in the NPT switching below.
dyn = Inhomogeneous_NPTBerendsen(
    atoms,
    timestep=args.timestep * units.fs,
    temperature_K=args.temperature,
    pressure_au=args.pressure * 1.01325 * units.bar,
    taut=args.ttime * units.fs,
    taup=args.ptime * units.fs,
    compressibility_au=1.0 / bulk_modulus,
)
MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature)
Stationary(atoms)

# NPT equilibration and volume relaxation
for _ in tqdm(range(args.npt_equil_stpes), desc="NPT equil"):
    dyn.run(steps=1)


################################################################################
# Alchemical switching
################################################################################

# Set up the partial Frenkel-Ladd calculation for the single atom at
# `vacancy_index`, using that atom's spring constant from the MSD step
# (presumably it is switched between MACE and a spring; TODO confirm in
# DefectFrenkelLaddCalculator).
calc = DefectFrenkelLaddCalculator(
    atoms=atoms,
    spring_constant=spring_constants[vacancy_index],
    defect_index=vacancy_index,
    device=args.device,
    model=args.model,
)
atoms.calc = calc
upper_triangular_cell(atoms)  # ASE NPT requires an upper-triangular cell

# NPT alchemical switching; ASE's NPT barostat parameter is pfactor = tau_P^2 * B
ptime = args.ptime * units.fs
pfactor = bulk_modulus * ptime * ptime

dyn = NPT(
    atoms,
    timestep=args.timestep * units.fs,
    temperature_K=args.temperature,
    externalstress=args.pressure * 1.01325 * units.bar,
    ttime=args.ttime * units.fs,
    pfactor=pfactor,
)

# Define alchemical path: hold at 0, switch 0 -> 1, hold at 1, switch 1 -> 0,
# with the degree-9 smoothstep switching profile.
t = np.linspace(0.0, 1.0, args.alchemy_switch_steps)
lambda_steps = t**5 * (70 * t**4 - 315 * t**3 + 540 * t**2 - 420 * t + 126)
lambda_values = np.concatenate(
    [
        np.zeros(args.alchemy_equil_steps),
        lambda_steps,
        np.ones(args.alchemy_equil_steps),
        lambda_steps[::-1],
    ]
)

# The alchemical gradient is only requested during the two switching segments.
calculate_gradients = np.concatenate(
    [
        np.zeros(args.alchemy_equil_steps, dtype=bool),
        np.ones(args.alchemy_switch_steps, dtype=bool),
        np.zeros(args.alchemy_equil_steps, dtype=bool),
        np.ones(args.alchemy_switch_steps, dtype=bool),
    ]
)


# Intentionally replaces the Frenkel-Ladd `get_observables` defined earlier:
# here `lambda_grad` comes from the calculator's "alchemical_grad" result.
def get_observables(dynamics, time, lambda_value):
    """Return per-atom observables of the current MD state as a dict.

    During the hold segments the gradient is not requested, so the logged
    ``lambda_grad`` there is whatever the calculator stores when
    ``calculate_alchemical_grad`` is False (TODO confirm).
    """
    md_atoms = dynamics.atoms
    num_atoms = len(md_atoms)
    alchemical_grad = md_atoms.calc.results["alchemical_grad"]
    return {
        "time": time,
        "potential": md_atoms.get_potential_energy() / num_atoms,
        "temperature": md_atoms.get_temperature(),
        "volume": md_atoms.get_volume() / num_atoms,
        "lambda": lambda_value,
        "lambda_grad": alchemical_grad / num_atoms,
    }


# Simulation loop
total_steps = 2 * args.alchemy_equil_steps + 2 * args.alchemy_switch_steps

observables = []
for step in tqdm(range(total_steps), desc="Alchemical switching"):
    lambda_value = lambda_values[step]
    grad_enabled = calculate_gradients[step]

    # Set the alchemical weight and whether the gradient is computed this step
    # (atomic masses are not changed in this calculation).
    calc.set_alchemical_weight(lambda_value)
    calc.calculate_alchemical_grad = grad_enabled

    dyn.run(steps=1)
    if step % args.log_interval == 0:
        time = (step + 1) * args.timestep  # elapsed simulation time in fs
        observables.append(get_observables(dyn, time, lambda_value))

# Save observables
df = pd.DataFrame(observables)
df.to_csv(args.output_dir / "observables_FL.csv", index=False)