├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── THIRD-PARTY-LICENSES
├── alchemical_mace
├── __init__.py
├── calculator.py
├── model.py
├── optimize.py
└── utils.py
├── data
├── perovskite_dataset.json
└── structures
│ ├── AlN_hex.cif
│ ├── BiSBr.cif
│ ├── CeO2.cif
│ ├── CsPbI3_alpha.cif
│ ├── CsPbI3_delta.cif
│ ├── CsSnI3_alpha.cif
│ ├── CsSnI3_delta.cif
│ ├── Fe.cif
│ ├── GaN_hex.cif
│ └── NaCl.cif
├── notebooks
├── 1_solid_solution.ipynb
├── 2_compositional_optimization.ipynb
├── 3_disorder_energy.ipynb
├── 4_vacancy_analysis.ipynb
└── 5_perovskite_analysis.ipynb
├── pyproject.toml
├── requirements.txt
└── scripts
├── perovskite_alchemy.py
├── perovskite_frenkel_ladd.py
└── vacancy_frenkel_ladd.py
/.gitignore:
--------------------------------------------------------------------------------
1 | archived/
2 |
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | share/python-wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | *.py,cover
52 | .hypothesis/
53 | .pytest_cache/
54 | cover/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | .pybuilder/
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | # For a library or package, you might want to ignore these files since the code is
89 | # intended to run in multiple environments; otherwise, check them in:
90 | # .python-version
91 |
92 | # pipenv
93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
96 | # install all needed dependencies.
97 | #Pipfile.lock
98 |
99 | # poetry
100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
101 | # This is especially recommended for binary packages to ensure reproducibility, and is more
102 | # commonly ignored for libraries.
103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
104 | #poetry.lock
105 |
106 | # pdm
107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
108 | #pdm.lock
109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
110 | # in version control.
111 | # https://pdm.fming.dev/#use-with-ide
112 | .pdm.toml
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
164 | # Results directory
165 | results/
166 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v4.6.0
4 | hooks:
5 | - id: check-added-large-files
6 | - id: check-yaml
7 | - id: debug-statements
8 | - id: end-of-file-fixer
9 | - id: trailing-whitespace
10 | - repo: https://github.com/astral-sh/ruff-pre-commit
11 | rev: v0.4.2
12 | hooks:
13 | - id: ruff
14 | types_or: [ python, pyi, jupyter ]
15 | args: [ --fix ]
16 | - id: ruff-format
17 | types_or: [ python, pyi, jupyter ]
18 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Juno Nam and Rafael Gomez-Bombarelli
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Alchemical MLIP
2 | [arXiv:2404.10746](https://arxiv.org/abs/2404.10746)
3 | [Zenodo DOI: 10.5281/zenodo.11081492](https://zenodo.org/doi/10.5281/zenodo.11081492)
4 | [License: MIT](https://opensource.org/license/mit)
5 |
6 | This repository contains the code to modify machine learning interatomic potentials (MLIPs) to enable continuous and differentiable alchemical transformations.
7 | Currently, we provide the alchemical modification for the [MACE](https://github.com/ACEsuit/mace) model.
8 | The details of the method are described in the paper: [Interpolation and differentiation of alchemical degrees of freedom in machine learning interatomic potentials](https://arxiv.org/abs/2404.10746).
9 |
10 | ## Installation
11 | We tested the code with Python 3.10 and the packages in `requirements.txt`.
12 | For example, you can create a conda environment and install the required packages as follows (assuming CUDA 11.8):
13 | ```bash
14 | conda create -n alchemical-mlip python=3.10
15 | conda activate alchemical-mlip
16 | pip install torch==2.0.1 --index-url https://download.pytorch.org/whl/cu118
17 | pip install -r requirements.txt
18 | pip install -e .
19 | ```
20 |
21 | ## Static calculations
22 | We provide Jupyter notebooks for the lattice parameter calculations (Fig. 2 in the paper) and the compositional optimization (Fig. 3) in the `notebooks` directory.
23 | ```
24 | notebooks/
25 | ├── 1_solid_solution.ipynb
26 | └── 2_compositional_optimization.ipynb
27 | ```
28 |
29 | ## Free energy calculations
30 | We provide the scripts for the free energy calculations for the vacancy (Fig. 4) and perovskites (Fig. 5) in the `scripts` directory.
31 | ```
32 | scripts/
33 | ├── vacancy_frenkel_ladd.py
34 | ├── perovskite_frenkel_ladd.py
35 | └── perovskite_alchemy.py
36 | ```
37 |
38 | The arguments for the scripts are as follows:
39 | ```bash
40 | # Vacancy Frenkel-Ladd calculation
41 | python vacancy_frenkel_ladd.py \
42 | --structure-file data/structures/Fe.cif \
43 | --supercell 5 5 5 \
44 | --temperature 100 \
45 | --output-dir data/results/vacancy/Fe_5x5x5_100K/0
46 |
47 | # Perovskite Frenkel-Ladd calculation (alpha phase)
48 | python perovskite_frenkel_ladd.py \
49 | --structure-file data/structures/CsPbI3_alpha.cif \
50 | --supercell 6 6 6 \
51 | --temperature 400 \
52 | --output-dir data/results/perovskite/frenkel_ladd/CsPbI3_alpha_6x6x6_400K/0
53 |
54 | # Perovskite Frenkel-Ladd calculation (delta phase)
55 | python perovskite_frenkel_ladd.py \
56 | --structure-file data/structures/CsPbI3_delta.cif \
57 | --supercell 6 3 3 \
58 | --temperature 400 \
59 | --output-dir data/results/perovskite/frenkel_ladd/CsPbI3_delta_6x3x3_400K/0
60 |
61 | # Perovskite alchemy calculation (alpha phase)
62 | python -u perovskite_alchemy.py \
63 | --structure-file data/structures/CsPbI3_alpha.cif \
64 | --supercell 6 6 6 \
65 | --switch-pair Pb Sn \
66 | --temperature 400 \
67 | --output-dir data/results/perovskite/alchemy/CsPbI3_CsSnI3_alpha_400K/0
68 |
69 | # Perovskite alchemy calculation (delta phase)
70 | python -u perovskite_alchemy.py \
71 | --structure-file data/structures/CsPbI3_delta.cif \
72 | --supercell 6 3 3 \
73 | --switch-pair Pb Sn \
74 | --temperature 400 \
75 | --output-dir data/results/perovskite/alchemy/CsPbI3_CsSnI3_delta_400K/0
76 | ```
77 |
78 | The result files are large and not included in the repository.
79 | If you want to reproduce the results without running the calculations, the result files are uploaded in the [Zenodo repository](https://zenodo.org/doi/10.5281/zenodo.11081395).
80 | Please download the files and place them in the `data/results` directory.
81 |
82 | The post-processing notebooks for the free energy calculations are provided in the `notebooks` directory.
83 | ```
84 | notebooks/
85 | ├── 4_vacancy_analysis.ipynb
86 | └── 5_perovskite_analysis.ipynb
87 | ```
88 |
89 | ## Citation
90 | ```
91 | @misc{nam2024interpolation,
92 | title={Interpolation and differentiation of alchemical degrees of freedom in machine learning interatomic potentials},
93 | author={Juno Nam and Rafael G{\'o}mez-Bombarelli},
94 | year={2024},
95 | eprint={2404.10746},
96 | archivePrefix={arXiv},
97 | primaryClass={cond-mat.mtrl-sci}
98 | }
99 | ```
100 |
--------------------------------------------------------------------------------
/THIRD-PARTY-LICENSES:
--------------------------------------------------------------------------------
1 | Code in alchemical_mace/{calculator,model}.py is adapted from
2 | https://github.com/ACEsuit/mace
3 |
4 | MIT License
5 |
6 | Copyright (c) 2022 ACEsuit/mace
7 |
8 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
11 |
12 | -------------------------------------------------------------------------------
13 |
14 | Code in alchemical_mace/utils.py is adapted from
15 | https://github.com/CederGroupHub/chgnet
16 |
17 | Crystal Hamiltonian Graph neural Network (CHGNet) Copyright (c) 2023, The Regents
18 | of the University of California, through Lawrence Berkeley National
19 | Laboratory (subject to receipt of any required approvals from the U.S.
20 | Dept. of Energy) and the University of California, Berkeley. All rights reserved.
21 |
22 | Redistribution and use in source and binary forms, with or without
23 | modification, are permitted provided that the following conditions are met:
24 |
25 | (1) Redistributions of source code must retain the above copyright notice,
26 | this list of conditions and the following disclaimer.
27 |
28 | (2) Redistributions in binary form must reproduce the above copyright
29 | notice, this list of conditions and the following disclaimer in the
30 | documentation and/or other materials provided with the distribution.
31 |
32 | (3) Neither the name of the University of California, Lawrence Berkeley
33 | National Laboratory, U.S. Dept. of Energy, University of California,
34 | Berkeley nor the names of its contributors may be used to endorse or
35 | promote products derived from this software without specific prior written
36 | permission.
37 |
38 |
39 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
40 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
41 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
42 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
43 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
44 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
45 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
46 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
47 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
48 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
49 | POSSIBILITY OF SUCH DAMAGE.
50 |
51 | You are under no obligation whatsoever to provide any bug fixes, patches,
52 | or upgrades to the features, functionality or performance of the source
53 | code ("Enhancements") to anyone; however, if you choose to make your
54 | Enhancements available either publicly, or directly to Lawrence Berkeley
55 | National Laboratory, without imposing a separate written license agreement
56 | for such Enhancements, then you hereby grant the following license: a
57 | non-exclusive, royalty-free perpetual license to install, use, modify,
58 | prepare derivative works, incorporate into other computer software,
59 | distribute, and sublicense such enhancements or derivative works thereof,
60 | in binary and source code form.
61 |
--------------------------------------------------------------------------------
/alchemical_mace/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/learningmatter-mit/alchemical-mlip/b556b850e9e89a4c2e1f5e9c27f03a3b1b393864/alchemical_mace/__init__.py
--------------------------------------------------------------------------------
/alchemical_mace/calculator.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence, Tuple
2 |
3 | import ase
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from ase.calculators.calculator import Calculator, all_changes
8 | from ase.constraints import ExpCellFilter
9 | from ase.optimize import FIRE
10 | from ase.stress import full_3x3_to_voigt_6_stress
11 | from mace import data
12 | from mace.calculators import mace_mp
13 | from mace.tools import torch_geometric
14 |
15 | from alchemical_mace.model import (
16 | AlchemicalPair,
17 | AlchemyManager,
18 | alchemical_mace_mp,
19 | get_z_table_and_r_max,
20 | )
21 |
22 | ################################################################################
23 | # Alchemical MACE calculator
24 | ################################################################################
25 |
26 |
class AlchemicalMACECalculator(Calculator):
    """
    ASE calculator backed by the alchemical MACE model.

    Besides energy, forces and stress, every calculation stores an
    ``"alchemical_grad"`` entry in ``results``. That entry is an array of
    zeros unless ``calculate_alchemical_grad`` has been set to True on the
    instance.
    """

    def __init__(
        self,
        atoms: ase.Atoms,
        alchemical_pairs: Sequence[Sequence[Tuple[int, int]]],
        alchemical_weights: Sequence[float],
        device: str = "cpu",
        model: str = "medium",
    ):
        """
        Load the model and set up the alchemy manager for a structure.

        Args:
            atoms (ase.Atoms): Atoms object.
            alchemical_pairs (Sequence[Sequence[Tuple[int, int]]]): List of
                alchemical pairs. Each pair is a tuple of the atom index and
                atomic number of an alchemical atom.
            alchemical_weights (Sequence[float]): List of alchemical weights.
            device (str): Device to run the calculations on.
            model (str): Model to use for the MACE calculator.
        """
        Calculator.__init__(self)
        self.results = {}
        self.implemented_properties = ["energy", "free_energy", "forces", "stress"]

        # Load the alchemical MACE model and freeze all of its parameters
        self.device = device
        self.model = alchemical_mace_mp(
            model=model, device=device, default_dtype="float32"
        )
        for parameter in self.model.parameters():
            parameter.requires_grad = False

        # The alchemy manager holds the alchemical weights as a parameter
        z_table, r_max = get_z_table_and_r_max(self.model)
        manager = AlchemyManager(
            atoms=atoms,
            alchemical_pairs=alchemical_pairs,
            alchemical_weights=torch.tensor(alchemical_weights, dtype=torch.float32),
            z_table=z_table,
            r_max=r_max,
        )
        self.alchemy_manager = manager.to(self.device)

        # Gradients with respect to the alchemical weights are opt-in
        self.alchemy_manager.alchemical_weights.requires_grad = False
        self.calculate_alchemical_grad = False

        self.num_atoms = len(atoms)

    def set_alchemical_weights(self, alchemical_weights: Sequence[float]):
        """
        Overwrite the values of the alchemical weights.

        Args:
            alchemical_weights (Sequence[float]): New alchemical weights.
        """
        self.alchemy_manager.alchemical_weights.data = torch.tensor(
            alchemical_weights,
            dtype=torch.float32,
            device=self.device,
        )

    def get_alchemical_atomic_masses(self) -> np.ndarray:
        """
        Compute per-atom masses as weighted sums over the alchemical nodes.

        Returns:
            np.ndarray: float32 array of length ``num_atoms``.
        """
        manager = self.alchemy_manager

        # Prepend a constant weight of 1.0, so that weight index 0 selects it
        padded_weights = F.pad(
            manager.alchemical_weights.data, (1, 0), "constant", 1.0
        )
        per_node_weights = padded_weights.cpu().numpy()[manager.weight_indices]
        per_node_masses = ase.data.atomic_masses[manager.atomic_numbers]

        # Accumulate the weighted node masses onto the atoms they belong to
        masses = np.zeros(self.num_atoms, dtype=np.float32)
        np.add.at(masses, manager.atom_indices, per_node_masses * per_node_weights)
        return masses

    # pylint: disable=dangerous-default-value
    def calculate(self, atoms=None, properties=None, system_changes=all_changes):
        """
        Evaluate the model for ``atoms`` and fill ``self.results``.

        Stores "energy", "free_energy" (same value as the energy), "forces",
        "stress" (Voigt form) and "alchemical_grad".
        """
        # The base class stores the atoms attribute
        Calculator.calculate(self, atoms)

        manager = self.alchemy_manager
        with_grad = self.calculate_alchemical_grad

        # Build the input batch from the current positions and cell
        positions = torch.tensor(
            atoms.get_positions(), dtype=torch.float32, device=self.device
        )
        cell = torch.tensor(
            atoms.get_cell().array, dtype=torch.float32, device=self.device
        )
        if with_grad:
            manager.alchemical_weights.requires_grad = True
        batch = manager(positions, cell).to(self.device)

        # Run the model, with or without the alchemical gradient
        if not with_grad:
            out = self.model(batch, retain_graph=False, compute_stress=True)
            grad = np.zeros(manager.alchemical_weights.shape[0], dtype=np.float32)
        else:
            out = self.model(batch, compute_stress=True, compute_alchemical_grad=True)
            # Chain the node/edge weight gradients returned by the model back
            # to the alchemical weights
            (weight_grad,) = torch.autograd.grad(
                outputs=[batch["node_weights"], batch["edge_weights"]],
                inputs=[manager.alchemical_weights],
                grad_outputs=[out["node_grad"], out["edge_grad"]],
                retain_graph=False,
                create_graph=False,
            )
            grad = weight_grad.cpu().numpy()
            manager.alchemical_weights.requires_grad = False

        # Collect the results
        energy = out["energy"][0].item()
        self.results = {
            "energy": energy,
            "free_energy": energy,
            "forces": out["forces"].detach().cpu().numpy(),
            "stress": full_3x3_to_voigt_6_stress(
                out["stress"][0].detach().cpu().numpy()
            ),
            "alchemical_grad": grad,
        }
143 |
144 |
class NVTMACECalculator(Calculator):
    """
    ASE calculator wrapping the plain MACE-MP model for fixed-cell runs.

    Only energy and forces are evaluated: the model is called with
    ``compute_stress=False``, so no stress is ever stored in ``results``.
    """

    def __init__(self, model: str = "medium", device: str = "cuda"):
        """
        Load the MACE-MP model and freeze its parameters.

        Args:
            model (str): Model to use for the MACE calculator.
            device (str): Device to run the calculations on.
        """
        Calculator.__init__(self)
        self.results = {}
        # Stress is never computed in calculate(), so it is not advertised here
        self.implemented_properties = ["energy", "free_energy", "forces"]
        self.device = device
        self.model = mace_mp(
            model=model, device=device, default_dtype="float32"
        ).models[0]
        self.z_table, self.r_max = get_z_table_and_r_max(self.model)
        for param in self.model.parameters():
            param.requires_grad = False

    # pylint: disable=dangerous-default-value
    def calculate(self, atoms=None, properties=None, system_changes=all_changes):
        """
        Evaluate energy and forces for ``atoms`` and fill ``self.results``.

        Stores "energy", "free_energy" (same value as the energy) and "forces".
        """
        # call to base-class to set atoms attribute
        Calculator.calculate(self, atoms)

        # Wrap the single configuration in a one-element data loader to get a batch
        config = data.config_from_atoms(atoms)
        atomic_data = data.AtomicData.from_config(
            config, z_table=self.z_table, cutoff=self.r_max
        )
        data_loader = torch_geometric.dataloader.DataLoader(
            dataset=[atomic_data],
            batch_size=1,
            shuffle=False,
            drop_last=False,
        )
        batch = next(iter(data_loader)).to(self.device)

        out = self.model(batch, compute_stress=False)
        energy = out["energy"][0].item()
        self.results = {
            "energy": energy,
            "free_energy": energy,
            "forces": out["forces"].detach().cpu().numpy(),
        }
181 |
182 |
class FrenkelLaddCalculator(Calculator):
    """
    ASE calculator mixing harmonic springs with MACE for Frenkel-Ladd runs.

    The reported energy and forces are a weighted sum of a per-atom spring
    term and the MACE term, with weights ``[1 - lambda, lambda]``.
    """

    def __init__(
        self,
        spring_constants: np.ndarray,
        initial_positions: np.ndarray,
        device: str,
        model: str = "medium",
    ):
        """
        Load the MACE-MP model and store the spring reference system.

        Args:
            spring_constants (np.ndarray): Spring constants for each atom.
            initial_positions (np.ndarray): Initial positions of the atoms.
            device (str): Device to run the calculations on.
            model (str): Model to use for the MACE calculator.
        """
        Calculator.__init__(self)
        self.results = {}
        self.implemented_properties = ["energy", "free_energy", "forces"]
        self.device = device

        # Load the MACE model and freeze all of its parameters
        loaded = mace_mp(model=model, device=device, default_dtype="float32")
        self.model = loaded.models[0]
        self.z_table, self.r_max = get_z_table_and_r_max(self.model)
        for parameter in self.model.parameters():
            parameter.requires_grad = False

        # Harmonic reference: one spring per atom, anchored at its initial position
        self.spring_constants = spring_constants
        self.initial_positions = initial_positions

        # [spring weight, MACE weight]; starts as the pure spring system
        self.weights = [1.0, 0.0]
        self.compute_mace = True

    def set_weights(self, lambda_value: float):
        """
        Set the mixing weights to ``[1 - lambda_value, lambda_value]``.

        Args:
            lambda_value (float): Weight given to the MACE term.
        """
        self.weights = [1.0 - lambda_value, lambda_value]

    # pylint: disable=dangerous-default-value
    def calculate(self, atoms=None, properties=None, system_changes=all_changes):
        """
        Evaluate the mixed spring/MACE system and fill ``self.results``.

        Stores "energy", "free_energy" (same value as the energy), "forces"
        and "energy_diff" (MACE energy minus spring energy, or 0.0 when
        ``compute_mace`` is False).
        """
        # The base class stores the atoms attribute
        Calculator.calculate(self, atoms)

        # MACE contribution; skipped entirely when compute_mace is off
        if not self.compute_mace:
            mace_energy = 0.0
            mace_forces = np.zeros((len(atoms), 3), dtype=np.float32)
        else:
            config = data.config_from_atoms(atoms)
            atomic_data = data.AtomicData.from_config(
                config, z_table=self.z_table, cutoff=self.r_max
            )
            loader = torch_geometric.dataloader.DataLoader(
                dataset=[atomic_data],
                batch_size=1,
                shuffle=False,
                drop_last=False,
            )
            batch = next(iter(loader)).to(self.device)
            # Frenkel-Ladd is NVT, so no stress is needed
            out = self.model(batch, compute_stress=False)
            mace_energy = out["energy"][0].item()
            mace_forces = out["forces"].detach().cpu().numpy()

        # Spring contribution around the initial positions
        displacement = atoms.get_positions() - self.initial_positions
        squared_distances = np.sum(displacement**2, axis=1)
        spring_energy = 0.5 * np.sum(self.spring_constants * squared_distances)
        spring_forces = -self.spring_constants[:, None] * displacement

        # Weighted mix of the two systems
        weights = self.weights
        total_energy = weights[0] * spring_energy + weights[1] * mace_energy
        total_forces = weights[0] * spring_forces + weights[1] * mace_forces
        energy_diff = mace_energy - spring_energy if self.compute_mace else 0.0

        self.results = {
            "energy": total_energy,
            "free_energy": total_energy,
            "forces": total_forces,
            "energy_diff": energy_diff,
        }
271 |
272 |
class DefectFrenkelLaddCalculator(Calculator):
    """
    Frenkel-Ladd calculator for ASE, for a crystal containing a defect atom.

    The defect atom is scaled by an alchemical weight in the MACE model and
    is tied to the cell center by a harmonic spring with the complementary
    weight.
    """

    def __init__(
        self,
        atoms: ase.Atoms,
        spring_constant: float,
        defect_index: int,
        device: str = "cpu",
        model: str = "medium",
    ):
        """
        Load the model and set up the alchemy manager for the defect atom.

        Args:
            atoms (ase.Atoms): Atoms object.
            spring_constant (float): Spring constant for the defect atom.
            defect_index (int): Index of the defect atom.
            device (str): Device to run the calculations on.
            model (str): Model to use for the MACE calculator.
        """
        Calculator.__init__(self)
        self.results = {}
        self.implemented_properties = ["energy", "free_energy", "forces", "stress"]

        # Load the alchemical MACE model and freeze all of its parameters
        self.device = device
        self.model = alchemical_mace_mp(
            model=model, device=device, default_dtype="float32"
        )
        for parameter in self.model.parameters():
            parameter.requires_grad = False

        # A single alchemical weight, attached to the defect atom with its
        # own atomic number and starting at 1.0
        z_table, r_max = get_z_table_and_r_max(self.model)
        initial_weights = torch.tensor([1.0], dtype=torch.float32)
        defect_number = atoms.get_atomic_numbers()[defect_index]
        manager = AlchemyManager(
            atoms=atoms,
            alchemical_pairs=[[AlchemicalPair(defect_index, defect_number)]],
            alchemical_weights=initial_weights,
            z_table=z_table,
            r_max=r_max,
        )
        self.alchemy_manager = manager.to(self.device)

        # Gradients with respect to the alchemical weight are opt-in
        self.alchemy_manager.alchemical_weights.requires_grad = False
        self.calculate_alchemical_grad = False

        self.num_atoms = len(atoms)

        # Switching parameters
        self.defect_index = defect_index
        self.spring_constant = spring_constant

    def set_alchemical_weight(self, alchemical_weight: float):
        """
        Set the switching parameter.

        The stored alchemical weight is ``1 - alchemical_weight``, so a value
        of 0 corresponds to the original atoms.

        Args:
            alchemical_weight (float): Switching parameter (lambda).
        """
        self.alchemy_manager.alchemical_weights.data = torch.tensor(
            [1.0 - alchemical_weight],
            dtype=torch.float32,
            device=self.device,
        )

    # pylint: disable=dangerous-default-value
    def calculate(self, atoms=None, properties=None, system_changes=all_changes):
        """
        Evaluate the switched system for ``atoms`` and fill ``self.results``.

        Stores "energy", "free_energy" (same value as the energy), "forces",
        "stress" (Voigt form, MACE part only) and "alchemical_grad" (0.0
        unless ``calculate_alchemical_grad`` is True).
        """
        # The base class stores the atoms attribute
        Calculator.calculate(self, atoms)

        manager = self.alchemy_manager

        # Build the input batch from the current positions and cell
        positions = torch.tensor(
            atoms.get_positions(), dtype=torch.float32, device=self.device
        )
        cell = torch.tensor(
            atoms.get_cell().array, dtype=torch.float32, device=self.device
        )
        if self.calculate_alchemical_grad:
            manager.alchemical_weights.requires_grad = True
        batch = manager(positions, cell).to(self.device)

        # Run the model, optionally back-propagating to the alchemical weight
        if not self.calculate_alchemical_grad:
            out = self.model(batch, retain_graph=False, compute_stress=True)
            grad = 0.0
        else:
            out = self.model(batch, retain_graph=True, compute_stress=True)
            out["energy"].backward()
            grad = manager.alchemical_weights.grad.item()
            # Clear the accumulated gradient so the next call starts from zero
            manager.alchemical_weights.grad.zero_()
            manager.alchemical_weights.requires_grad = False
        mace_energy = out["energy"][0].item()
        total_forces = out["forces"].detach().cpu().numpy()
        mace_stress = out["stress"][0].detach().cpu().numpy()

        # Harmonic spring tying the defect atom to the center of the cell
        cell_center = np.array([0.5, 0.5, 0.5]) @ atoms.get_cell().array
        displacement = atoms.get_positions()[self.defect_index] - cell_center
        spring_energy = 0.5 * self.spring_constant * np.sum(displacement**2)
        spring_forces = -self.spring_constant * displacement

        # The stored weight is 1 - lambda, so the spring gets lambda. The MACE
        # energy is not scaled here because the alchemical weight already
        # scales it inside the model.
        spring_scale = 1 - manager.alchemical_weights.item()
        total_energy = mace_energy + spring_scale * spring_energy
        total_forces[self.defect_index] += spring_scale * spring_forces
        if self.calculate_alchemical_grad:
            # H(lambda) = E(1 - lambda) + lambda * spring_energy
            # dH/d(lambda) = -dE/d(1 - lambda) + spring_energy
            grad = -grad + spring_energy

        # Collect the results
        self.results = {
            "energy": total_energy,
            "free_energy": total_energy,
            "forces": total_forces,
            "stress": full_3x3_to_voigt_6_stress(mace_stress),
            "alchemical_grad": grad,
        }
392 |
393 |
def get_alchemical_optimized_cellpar(
    atoms: ase.Atoms,
    alchemical_pairs: Sequence[Sequence[Tuple[int, int]]],
    alchemical_weights: Sequence[float],
    model: str = "medium",
    device: str = "cpu",
    **kwargs,
):
    """
    Optimize the cell parameters of a crystal with alchemical atoms using the
    Alchemical MACE calculator.

    The input ``atoms`` object is not modified; the relaxation runs on a copy.

    Args:
        atoms (ase.Atoms): Atoms object.
        alchemical_pairs (Sequence[Sequence[Tuple[int, int]]]): List of
            alchemical pairs. Each pair is a tuple of the atom index and
            atomic number of an alchemical atom.
        alchemical_weights (Sequence[float]): List of alchemical weights.
        model (str): Model to use for the MACE calculator.
        device (str): Device to run the calculations on.
        **kwargs: Optional optimizer settings. ``fmax`` (default 0.01) and
            ``steps`` (default 500) are passed to the FIRE optimizer's
            ``run``; any other keys are ignored.

    Returns:
        np.ndarray: Optimized cell parameters.
    """
    # Work on a copy so the caller's structure stays untouched
    relaxed = atoms.copy()

    # Attach the alchemical calculator and the matching interpolated masses
    calc = AlchemicalMACECalculator(
        relaxed, alchemical_pairs, alchemical_weights, device=device, model=model
    )
    # Attribute assignment replaces the deprecated Atoms.set_calculator()
    relaxed.calc = calc
    relaxed.set_masses(calc.get_alchemical_atomic_masses())

    # Relax positions and cell together through the cell filter
    cell_filter = ExpCellFilter(relaxed)
    optimizer = FIRE(cell_filter)
    optimizer.run(fmax=kwargs.get("fmax", 0.01), steps=kwargs.get("steps", 500))

    # The filter wraps `relaxed`, so its cell now holds the optimized values
    return relaxed.get_cell().cellpar()
433 |
--------------------------------------------------------------------------------
/alchemical_mace/model.py:
--------------------------------------------------------------------------------
1 | import ast
2 | from collections import namedtuple
3 | from typing import Dict, List, Optional, Sequence, Tuple
4 |
5 | import ase
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 | from e3nn import o3
10 | from e3nn.util.jit import compile_mode
11 | from mace import modules, tools
12 | from mace.calculators import mace_mp
13 | from mace.data.neighborhood import get_neighborhood
14 | from mace.modules import RealAgnosticResidualInteractionBlock, ScaleShiftMACE
15 | from mace.modules.utils import get_edge_vectors_and_lengths, get_symmetric_displacement
16 | from mace.tools import (
17 | AtomicNumberTable,
18 | atomic_numbers_to_indices,
19 | to_one_hot,
20 | torch_geometric,
21 | utils,
22 | )
23 | from mace.tools.scatter import scatter_sum
24 |
25 | ################################################################################
26 | # Alchemy manager class for handling alchemical weights
27 | ################################################################################
28 |
# Record holding the atom index and the atomic number of one alchemical atom
AlchemicalPair = namedtuple("AlchemicalPair", "atom_index atomic_number")
30 |
31 |
class AlchemyManager(torch.nn.Module):
    """
    Class for managing alchemical weights and building alchemical graphs for MACE.

    Each (atom index, atomic number) pair becomes its own node in the
    alchemical graph, so a single original atom may be represented by several
    nodes. Atoms that appear in no pair are kept as single nodes with weight
    index 0, which always maps to a weight of 1.
    """

    def __init__(
        self,
        atoms: ase.Atoms,
        alchemical_pairs: Sequence[Sequence[Tuple[int, int]]],
        alchemical_weights: torch.Tensor,
        z_table: AtomicNumberTable,
        r_max: float,
    ):
        """
        Initialize the alchemy manager.

        Args:
            atoms: ASE atoms object
            alchemical_pairs: List of lists of tuples, where each tuple contains
                the atom index and atomic number of an alchemical atom. Both
                plain ``(atom_index, atomic_number)`` tuples and
                ``AlchemicalPair`` instances are accepted.
            alchemical_weights: Tensor of alchemical weights
            z_table: Atomic number table
            r_max: Maximum cutoff radius for the alchemical graph
        """
        super().__init__()
        self.alchemical_weights = torch.nn.Parameter(alchemical_weights)
        self.r_max = r_max

        # Process alchemical pairs into atom indices and atomic numbers
        # Alchemical weights are 1-indexed, 0 is reserved for non-alchemical atoms
        alchemical_atom_indices = []
        alchemical_atomic_numbers = []
        alchemical_weight_indices = []

        for weight_idx, pairs in enumerate(alchemical_pairs):
            # Unpack rather than use attribute access so that plain tuples
            # work as well as AlchemicalPair namedtuples
            for atom_index, atomic_number in pairs:
                alchemical_atom_indices.append(atom_index)
                alchemical_atomic_numbers.append(atomic_number)
                alchemical_weight_indices.append(weight_idx + 1)

        # Set lookup keeps the membership test O(1) per atom
        alchemical_atom_set = set(alchemical_atom_indices)
        non_alchemical_atom_indices = [
            i for i in range(len(atoms)) if i not in alchemical_atom_set
        ]
        non_alchemical_atomic_numbers = atoms.get_atomic_numbers()[
            non_alchemical_atom_indices
        ].tolist()
        non_alchemical_weight_indices = [0] * len(non_alchemical_atom_indices)

        self.atom_indices = alchemical_atom_indices + non_alchemical_atom_indices
        self.atomic_numbers = alchemical_atomic_numbers + non_alchemical_atomic_numbers
        self.weight_indices = alchemical_weight_indices + non_alchemical_weight_indices

        self.atom_indices = np.array(self.atom_indices)
        self.atomic_numbers = np.array(self.atomic_numbers)
        self.weight_indices = np.array(self.weight_indices)

        # Order the alchemical nodes by the original atom they belong to
        sort_idx = np.argsort(self.atom_indices)
        self.atom_indices = self.atom_indices[sort_idx]
        self.atomic_numbers = self.atomic_numbers[sort_idx]
        self.weight_indices = self.weight_indices[sort_idx]

        # Array to map original atom indices to alchemical indices
        # -1 means the atom does not have a corresponding alchemical atom
        # [n_atoms, n_weights + 1]
        self.original_to_alchemical_index = np.full(
            (len(atoms), len(alchemical_pairs) + 1), -1
        )
        for i, (atom_idx, weight_idx) in enumerate(
            zip(self.atom_indices, self.weight_indices)
        ):
            self.original_to_alchemical_index[atom_idx, weight_idx] = i

        # An original atom is alchemical if any non-zero weight slot is filled
        self.is_original_atom_alchemical = np.any(
            self.original_to_alchemical_index[:, 1:] != -1, axis=1
        )

        # Extract common node features
        z_indices = atomic_numbers_to_indices(self.atomic_numbers, z_table=z_table)
        node_attrs = to_one_hot(
            torch.tensor(z_indices, dtype=torch.long).unsqueeze(-1),
            num_classes=len(z_table),
        )
        self.register_buffer("node_attrs", node_attrs)
        self.pbc = atoms.get_pbc()

    def forward(
        self,
        positions: torch.Tensor,
        cell: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """
        Build an alchemical graph for the given positions and cell.

        Args:
            positions: Tensor of atomic positions
            cell: Tensor of cell vectors

        Returns:
            Dictionary containing the alchemical graph data
        """

        # Build original atom graph
        orig_edge_index, shifts, unit_shifts = get_neighborhood(
            positions=positions.detach().cpu().numpy(),
            cutoff=self.r_max,
            pbc=self.pbc,
            cell=cell.detach().cpu().numpy(),
        )

        # Extend edges to alchemical pairs
        edge_index = []
        orig_edge_loc = []
        edge_weight_indices = []

        is_alchemical = self.is_original_atom_alchemical[orig_edge_index]
        src_non_dst_non = ~is_alchemical[0] & ~is_alchemical[1]
        src_non_dst_alch = ~is_alchemical[0] & is_alchemical[1]
        src_alch_dst_non = is_alchemical[0] & ~is_alchemical[1]
        src_alch_dst_alch = is_alchemical[0] & is_alchemical[1]

        # Both non-alchemical: keep as is
        _orig_edge_index = orig_edge_index[:, src_non_dst_non]
        edge_index.append(self.original_to_alchemical_index[_orig_edge_index, 0])
        orig_edge_loc.append(np.where(src_non_dst_non)[0])
        edge_weight_indices.append(np.zeros_like(_orig_edge_index[0]))

        # Source non-alchemical, destination alchemical: pair all, weights are 1
        _src, _dst = orig_edge_index[:, src_non_dst_alch]
        _orig_edge_loc = np.where(src_non_dst_alch)[0]
        _src = self.original_to_alchemical_index[_src, 0]
        _dst = self.original_to_alchemical_index[_dst, :]
        _dst_mask = _dst != -1
        _dst = _dst[_dst_mask]
        _repeat = _dst_mask.sum(axis=1)
        _src = np.repeat(_src, _repeat)
        edge_index.append(np.stack((_src, _dst), axis=0))
        orig_edge_loc.append(np.repeat(_orig_edge_loc, _repeat))
        edge_weight_indices.append(np.zeros_like(_src))

        # Source alchemical, destination non-alchemical: pair all, follow src weights
        _src, _dst = orig_edge_index[:, src_alch_dst_non]
        _orig_edge_loc = np.where(src_alch_dst_non)[0]
        _src = self.original_to_alchemical_index[_src, :]
        _dst = self.original_to_alchemical_index[_dst, 0]
        _src_mask = _src != -1
        _src = _src[_src_mask]
        _repeat = _src_mask.sum(axis=1)
        _dst = np.repeat(_dst, _repeat)
        edge_index.append(np.stack((_src, _dst), axis=0))
        orig_edge_loc.append(np.repeat(_orig_edge_loc, _repeat))
        # Column position of each valid entry is the weight index of the source
        edge_weight_indices.append(np.where(_src_mask)[1])

        # Both alchemical: pair according to alchemical indices, weights are 1
        _orig_edge_index = orig_edge_index[:, src_alch_dst_alch]
        _orig_edge_loc = np.where(src_alch_dst_alch)[0]
        _alch_edge_index = self.original_to_alchemical_index[_orig_edge_index, :]
        # Keep only weight slots that are filled for both source and destination
        _idx = np.where((_alch_edge_index != -1).all(axis=0))
        edge_index.append(_alch_edge_index[:, _idx[0], _idx[1]])
        orig_edge_loc.append(_orig_edge_loc[_idx[0]])
        edge_weight_indices.append(np.zeros_like(_idx[0]))

        # Collect all edges
        edge_index = np.concatenate(edge_index, axis=1)
        orig_edge_loc = np.concatenate(orig_edge_loc)
        edge_weight_indices = np.concatenate(edge_weight_indices)

        # Convert to torch tensors. Shift vectors follow the dtype of the
        # positions so they combine without a dtype mismatch
        edge_index = torch.tensor(edge_index, dtype=torch.long)
        shifts = torch.tensor(shifts[orig_edge_loc], dtype=positions.dtype)
        unit_shifts = torch.tensor(unit_shifts[orig_edge_loc], dtype=positions.dtype)

        # Alchemical weights for nodes and edges
        # Index 0 is the padded constant 1.0 used for non-alchemical entries
        weights = F.pad(self.alchemical_weights, (1, 0), "constant", 1.0)
        node_weights = weights[self.weight_indices]
        edge_weights = weights[edge_weight_indices]

        # Build data batch
        atomic_data = torch_geometric.data.Data(
            num_nodes=len(self.atom_indices),
            edge_index=edge_index,
            node_attrs=self.node_attrs,
            positions=positions[self.atom_indices],
            shifts=shifts,
            unit_shifts=unit_shifts,
            cell=cell,
            node_weights=node_weights,
            edge_weights=edge_weights,
            node_atom_indices=torch.tensor(self.atom_indices, dtype=torch.long),
        )
        data_loader = torch_geometric.dataloader.DataLoader(
            dataset=[atomic_data],
            batch_size=1,
            shuffle=False,
            drop_last=False,
        )
        batch = next(iter(data_loader))

        return batch
230 |
231 |
232 | ################################################################################
233 | # Alchemical MACE model
234 | ################################################################################
235 |
236 | # get_outputs function from mace.modules.utils is modified to calculate also
237 | # the alchemical gradients
238 |
239 |
def get_outputs(
    energy: torch.Tensor,
    positions: torch.Tensor,
    displacement: torch.Tensor,
    cell: torch.Tensor,
    node_weights: torch.Tensor,
    edge_weights: torch.Tensor,
    retain_graph: bool = False,
    create_graph: bool = False,
    compute_force: bool = True,
    compute_stress: bool = False,
    compute_alchemical_grad: bool = False,
) -> Tuple[
    Optional[torch.Tensor],
    Optional[torch.Tensor],
    Optional[torch.Tensor],
    Optional[torch.Tensor],
    Optional[torch.Tensor],
]:
    """
    Differentiate the energy with respect to positions and, optionally, the
    displacement and the alchemical node/edge weights.

    Args:
        energy: Energy tensor to differentiate
        positions: Atomic positions the energy was computed from
        displacement: Displacement tensor the energy was computed from
        cell: Cell vectors, viewed as [-1, 3, 3] when computing the stress
        node_weights: Alchemical node weights
        edge_weights: Alchemical edge weights
        retain_graph: Forwarded to ``torch.autograd.grad``
        create_graph: Forwarded to ``torch.autograd.grad``
        compute_force: If False, nothing is computed and every output is None
            (this includes the alchemical gradients)
        compute_stress: Also differentiate with respect to ``displacement``
        compute_alchemical_grad: Also differentiate with respect to
            ``node_weights`` and ``edge_weights``

    Returns:
        Tuple ``(forces, virials, stress, node_grad, edge_grad)``. Forces and
        virials are the negated gradients. Stress is the displacement gradient
        divided by the cell volume, or zeros shaped like ``displacement`` when
        it is not computed. An entry is None when it was not requested or when
        the energy does not depend on the corresponding input.
    """
    if not compute_force:
        return None, None, None, None, None

    grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(energy)]
    inputs = [positions]
    if compute_stress:
        inputs.append(displacement)
    if compute_alchemical_grad:
        inputs.extend([node_weights, edge_weights])
    gradients = torch.autograd.grad(
        outputs=[energy],
        inputs=inputs,
        grad_outputs=grad_outputs,
        retain_graph=retain_graph,
        create_graph=create_graph,
        allow_unused=True,
    )

    forces = gradients[0]
    stress = torch.zeros_like(displacement)
    virials = gradients[1] if compute_stress else None
    if compute_alchemical_grad:
        # The alchemical inputs are always appended last
        node_grad, edge_grad = gradients[-2], gradients[-1]
    else:
        node_grad, edge_grad = None, None
    if compute_stress and virials is not None:
        cell = cell.view(-1, 3, 3)
        # The triple product a . (b x c) is negative for a left-handed cell;
        # take the absolute value so the stress keeps the correct sign
        volume = torch.einsum(
            "zi,zi->z",
            cell[:, 0, :],
            torch.cross(cell[:, 1, :], cell[:, 2, :], dim=1),
        ).abs()
        # Uses the raw (not yet negated) displacement gradient
        stress = virials / volume.view(-1, 1, 1)

    if forces is not None:
        forces = -1 * forces
    if virials is not None:
        virials = -1 * virials
    return forces, virials, stress, node_grad, edge_grad
297 |
298 |
@compile_mode("script")
class AlchemicalResidualInteractionBlock(RealAgnosticResidualInteractionBlock):
    """
    Residual interaction block whose per-edge tensor-product weights are
    multiplied by the alchemical edge weights before message construction.
    """

    def forward(
        self,
        node_attrs: torch.Tensor,
        node_feats: torch.Tensor,
        edge_attrs: torch.Tensor,
        edge_feats: torch.Tensor,
        edge_index: torch.Tensor,
        edge_weights: torch.Tensor,  # alchemy
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        n_nodes = node_feats.shape[0]
        src_nodes = edge_index[0]
        dst_nodes = edge_index[1]

        # Skip connection is taken from the incoming node features
        skip = self.skip_tp(node_feats, node_attrs)
        lifted_feats = self.linear_up(node_feats)

        # alchemy: scale every edge's tensor-product weights by its edge weight
        scaled_tp_weights = self.conv_tp_weights(edge_feats) * edge_weights[:, None]

        # [n_edges, irreps]
        edge_messages = self.conv_tp(
            lifted_feats[src_nodes], edge_attrs, scaled_tp_weights
        )
        # [n_nodes, irreps]
        pooled = scatter_sum(
            src=edge_messages, index=dst_nodes, dim=0, dim_size=n_nodes
        )
        pooled = self.linear(pooled) / self.avg_num_neighbors

        # [n_nodes, channels, (lmax + 1)**2]
        return self.reshape(pooled), skip
328 |
329 |
@compile_mode("script")
class AlchemicalMACE(ScaleShiftMACE):
    """
    ScaleShiftMACE variant whose atomic and interaction energies are scaled by
    alchemical node weights and whose messages are scaled by alchemical edge
    weights.
    """

    def forward(
        self,
        data: Dict[str, torch.Tensor],
        retain_graph: bool = False,  # alchemy
        create_graph: bool = False,  # alchemy
        compute_force: bool = True,
        compute_stress: bool = False,
        compute_displacement: bool = False,
        compute_alchemical_grad: bool = False,  # alchemy
        map_to_original_atoms: bool = True,  # alchemy
    ) -> Dict[str, Optional[torch.Tensor]]:
        """
        Evaluate the energy and its derivatives for an alchemical graph batch.

        Args:
            data: Batch dictionary. The keys read here are ``positions``,
                ``node_attrs``, ``ptr``, ``batch``, ``cell``, ``edge_index``,
                ``shifts``, ``unit_shifts``, ``node_weights``,
                ``edge_weights`` and ``node_atom_indices``.
            retain_graph: Forwarded to the gradient computation
            create_graph: Forwarded to the gradient computation
            compute_force: Whether to compute gradient-based outputs
            compute_stress: Whether to compute the stress
            compute_displacement: Whether to apply the symmetric displacement
                even when the stress is not requested
            compute_alchemical_grad: Whether to compute gradients with respect
                to the node and edge weights
            map_to_original_atoms: Whether to sum node energies and forces of
                alchemical nodes onto their original atoms

        Returns:
            Dictionary with the keys ``energy``, ``node_energy``,
            ``interaction_energy``, ``forces``, ``virials``, ``stress``,
            ``displacement``, ``node_grad`` and ``edge_grad``.
        """
        # Setup
        data["positions"].requires_grad_(True)
        data["node_attrs"].requires_grad_(True)
        num_graphs = data["ptr"].numel() - 1
        displacement = torch.zeros(
            (num_graphs, 3, 3),
            dtype=data["positions"].dtype,
            device=data["positions"].device,
        )
        if compute_stress or compute_displacement:
            (
                data["positions"],
                data["shifts"],
                displacement,
            ) = get_symmetric_displacement(
                positions=data["positions"],
                unit_shifts=data["unit_shifts"],
                cell=data["cell"],
                edge_index=data["edge_index"],
                num_graphs=num_graphs,
                batch=data["batch"],
            )

        # Atomic energies
        node_e0 = self.atomic_energies_fn(data["node_attrs"])
        node_e0 = node_e0 * data["node_weights"]  # alchemy
        e0 = scatter_sum(
            src=node_e0, index=data["batch"], dim=-1, dim_size=num_graphs
        )  # [n_graphs,]

        # Embeddings
        node_feats = self.node_embedding(data["node_attrs"])
        vectors, lengths = get_edge_vectors_and_lengths(
            positions=data["positions"],
            edge_index=data["edge_index"],
            shifts=data["shifts"],
        )
        edge_attrs = self.spherical_harmonics(vectors)
        edge_feats = self.radial_embedding(lengths)

        # Interactions
        node_es_list = []
        for interaction, product, readout in zip(
            self.interactions, self.products, self.readouts
        ):
            node_feats, sc = interaction(
                node_attrs=data["node_attrs"],
                node_feats=node_feats,
                edge_attrs=edge_attrs,
                edge_feats=edge_feats,
                edge_index=data["edge_index"],
                edge_weights=data["edge_weights"],  # alchemy
            )
            node_feats = product(
                node_feats=node_feats, sc=sc, node_attrs=data["node_attrs"]
            )
            node_es_list.append(readout(node_feats).squeeze(-1))  # {[n_nodes, ], }

        # Sum over interactions
        node_inter_es = torch.sum(
            torch.stack(node_es_list, dim=0), dim=0
        )  # [n_nodes, ]
        node_inter_es = self.scale_shift(node_inter_es)
        node_inter_es = node_inter_es * data["node_weights"]  # alchemy

        # Sum over nodes in graph
        inter_e = scatter_sum(
            src=node_inter_es, index=data["batch"], dim=-1, dim_size=num_graphs
        )  # [n_graphs,]

        # Add E_0 and (scaled) interaction energy
        total_energy = e0 + inter_e
        node_energy = node_e0 + node_inter_es

        forces, virials, stress, node_grad, edge_grad = get_outputs(
            energy=total_energy,  # alchemy
            positions=data["positions"],
            displacement=displacement,
            cell=data["cell"],
            node_weights=data["node_weights"],  # alchemy
            edge_weights=data["edge_weights"],  # alchemy
            retain_graph=retain_graph,  # alchemy
            create_graph=create_graph,  # alchemy
            compute_force=compute_force,
            compute_stress=compute_stress,
            compute_alchemical_grad=compute_alchemical_grad,  # alchemy
        )

        # Map to original atoms (node energies and forces): alchemy
        if map_to_original_atoms:
            # Note: we're not giving the dim_size, as we assume that all
            # original atoms are present in the batch
            node_index = data["node_atom_indices"]
            node_energy = scatter_sum(src=node_energy, dim=0, index=node_index)
            # Forces are None when they were not requested or not available
            if forces is not None:
                forces = scatter_sum(src=forces, dim=0, index=node_index)

        output = {
            "energy": total_energy,
            "node_energy": node_energy,
            "interaction_energy": inter_e,
            "forces": forces,
            "virials": virials,
            "stress": stress,
            "displacement": displacement,
            "node_grad": node_grad,
            "edge_grad": edge_grad,
        }

        return output
459 |
460 |
461 | ################################################################################
462 | # Alchemical MACE universal model
463 | ################################################################################
464 |
465 |
def alchemical_mace_mp(
    model: str,
    device: str,
    default_dtype: str = "float32",
):
    """
    Load a pre-trained alchemical MACE model.

    The foundation MACE model is loaded first, an ``AlchemicalMACE`` with the
    same hyperparameters is built, and the foundation parameters are copied
    into it.

    Args:
        model: Model size (small, medium)
        device: Device to load the model onto
        default_dtype: Default data type for the model

    Returns:
        Alchemical MACE model

    Raises:
        ValueError: If ``model`` is not one of the supported sizes.
    """

    # Validate explicitly: an ``assert`` would be stripped under ``python -O``
    # TODO: support large model
    if model not in ("small", "medium"):
        raise ValueError(f"model must be 'small' or 'medium', got {model!r}")

    # Load foundation MACE model and extract initial parameters
    mace = mace_mp(model=model, device=device, default_dtype=default_dtype).models[0]
    atomic_energies = mace.atomic_energies_fn.atomic_energies.detach().clone()
    z_table = utils.AtomicNumberTable([int(z) for z in mace.atomic_numbers])
    atomic_inter_scale = mace.scale_shift.scale.detach().clone()
    atomic_inter_shift = mace.scale_shift.shift.detach().clone()

    # Prepare arguments for building the model
    placeholder_args = ["--name", "None", "--train_file", "None"]
    args = tools.build_default_arg_parser().parse_args(placeholder_args)
    args.max_L = {"small": 0, "medium": 1, "large": 2}[model]
    args.num_channels = 128
    args.hidden_irreps = o3.Irreps(
        (args.num_channels * o3.Irreps.spherical_harmonics(args.max_L))
        .sort()
        .irreps.simplify()
    )

    # Build the alchemical MACE model (kept under its own name so the
    # ``model`` size string is not shadowed)
    alchemical_model = AlchemicalMACE(
        r_max=6.0,
        num_bessel=10,
        num_polynomial_cutoff=5,
        max_ell=3,
        interaction_cls=AlchemicalResidualInteractionBlock,
        interaction_cls_first=AlchemicalResidualInteractionBlock,
        num_interactions=2,
        num_elements=len(z_table),
        hidden_irreps=o3.Irreps(args.hidden_irreps),
        MLP_irreps=o3.Irreps(args.MLP_irreps),
        atomic_energies=atomic_energies,
        avg_num_neighbors=args.avg_num_neighbors,
        atomic_numbers=z_table.zs,
        correlation=args.correlation,
        gate=modules.gate_dict[args.gate],
        radial_MLP=ast.literal_eval(args.radial_MLP),
        radial_type=args.radial_type,
        atomic_inter_scale=atomic_inter_scale,
        atomic_inter_shift=atomic_inter_shift,
    )

    # Load foundation model parameters
    alchemical_model.load_state_dict(mace.state_dict(), strict=True)
    # avg_num_neighbors is copied explicitly from each foundation interaction
    # block, in addition to loading the state dict
    for i in range(int(alchemical_model.num_interactions)):
        alchemical_model.interactions[i].avg_num_neighbors = mace.interactions[
            i
        ].avg_num_neighbors
    return alchemical_model.to(device)
531 |
532 |
def get_z_table_and_r_max(model: ScaleShiftMACE) -> Tuple[AtomicNumberTable, float]:
    """Return the atomic number table and the cutoff radius of a MACE model."""
    atomic_numbers = list(map(int, model.atomic_numbers))
    return AtomicNumberTable(atomic_numbers), model.r_max.item()
538 |
--------------------------------------------------------------------------------
/alchemical_mace/optimize.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Union, Iterable, Dict, Any
3 |
4 |
class ExponentiatedGradientDescent(torch.optim.Optimizer):
    """
    Implements Exponentiated Gradient Descent.

    Each parameter is multiplied element-wise by ``exp(-lr * grad)`` and then
    divided by its own sum plus ``eps``.

    Args:
        params (iterable of torch.Tensor or dict): iterable of parameters to optimize or
            dicts defining parameter groups.
        lr (float, optional): learning rate. Defaults to 1e-3.
        eps (float, optional): small constant for numerical stability. Defaults to 1e-8.
    """

    def __init__(
        self,
        params: Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]],
        lr: float = 1e-3,
        eps: float = 1e-8,
    ):
        super().__init__(params, defaults=dict(lr=lr, eps=eps))

    @torch.no_grad()
    def step(self, closure=None):
        """
        Perform a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss. It is called with gradients enabled.

        Returns:
            The value returned by ``closure``, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            # The step itself runs without autograd tracking, but the closure
            # needs gradients to recompute the loss
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            eps = group["eps"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                # Multiplicative update followed by renormalization
                p.mul_(torch.exp(-lr * p.grad))
                p.div_(p.sum() + eps)

        return loss
31 |
--------------------------------------------------------------------------------
/alchemical_mace/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from contextlib import ExitStack, contextmanager, redirect_stderr, redirect_stdout
3 |
4 | from ase import Atoms
5 |
6 |
@contextmanager
def suppress_print(out: bool = True, err: bool = False):
    """Send stdout and/or stderr to the null device inside the ``with`` body."""

    # Pick the redirections that were asked for, stdout first
    redirectors = []
    if out:
        redirectors.append(redirect_stdout)
    if err:
        redirectors.append(redirect_stderr)

    # The stack unwinds before the sink is closed, so the streams are
    # restored while the null device is still open
    with open(os.devnull, "w") as sink, ExitStack() as stack:
        for redirect in redirectors:
            stack.enter_context(redirect(sink))
        yield
18 |
19 |
20 | # From CHGNet
def upper_triangular_cell(atoms: Atoms):
    """Rewrite the cell of ``atoms`` in upper-triangular form, in place.

    Nothing happens when ASE's NPT check already reports the cell as upper
    triangular. Otherwise a new basis is built from the cell lengths and
    angles and applied with ``scale_atoms=True``.
    """
    import numpy as np
    from ase.md.npt import NPT

    if not NPT._isuppertriangular(atoms.get_cell()):
        cellpar = atoms.cell.cellpar()
        a, b, c = cellpar[0], cellpar[1], cellpar[2]

        # Cell angles in radians, ordered (alpha, beta, gamma)
        angles_rad = np.radians((cellpar[3], cellpar[4], cellpar[5]))
        sines = np.sin(angles_rad)
        cosines = np.cos(angles_rad)
        sin_a, sin_b = sines[0], sines[1]
        cos_a, cos_b, cos_g = cosines[0], cosines[1], cosines[2]

        # Clip guards against rounding pushing the cosine outside [-1, 1]
        cos_p = np.clip((cos_g - cos_a * cos_b) / (sin_a * sin_b), -1, 1)
        sin_p = (1 - cos_p**2) ** 0.5

        first_row = (a * sin_b * sin_p, a * sin_b * cos_p, a * cos_b)
        second_row = (0, b * sin_a, b * cos_a)
        third_row = (0, 0, c)

        atoms.set_cell([first_row, second_row, third_row], scale_atoms=True)
44 |
--------------------------------------------------------------------------------
/data/structures/AlN_hex.cif:
--------------------------------------------------------------------------------
1 | # generated using pymatgen
2 | data_AlN
3 | _symmetry_space_group_name_H-M 'P 1'
4 | _cell_length_a 3.12664153
5 | _cell_length_b 3.12664143
6 | _cell_length_c 5.00715332
7 | _cell_angle_alpha 90.00000043
8 | _cell_angle_beta 89.99999986
9 | _cell_angle_gamma 119.99999965
10 | _symmetry_Int_Tables_number 1
11 | _chemical_formula_structural AlN
12 | _chemical_formula_sum 'Al2 N2'
13 | _cell_volume 42.39139351
14 | _cell_formula_units_Z 2
15 | loop_
16 | _symmetry_equiv_pos_site_id
17 | _symmetry_equiv_pos_as_xyz
18 | 1 'x, y, z'
19 | loop_
20 | _atom_site_type_symbol
21 | _atom_site_label
22 | _atom_site_symmetry_multiplicity
23 | _atom_site_fract_x
24 | _atom_site_fract_y
25 | _atom_site_fract_z
26 | _atom_site_occupancy
27 | Al Al0 1 0.66666667 0.33333333 0.49932157 1
28 | Al Al1 1 0.33333333 0.66666667 0.99932156 1
29 | N N2 1 0.66666667 0.33333333 0.88067843 1
30 | N N3 1 0.33333333 0.66666667 0.38067844 1
31 |
--------------------------------------------------------------------------------
/data/structures/BiSBr.cif:
--------------------------------------------------------------------------------
1 | # generated using pymatgen
2 | data_BiSBr
3 | _symmetry_space_group_name_H-M 'P 1'
4 | _cell_length_a 4.10679210
5 | _cell_length_b 8.22484633
6 | _cell_length_c 11.05701800
7 | _cell_angle_alpha 89.99999573
8 | _cell_angle_beta 90.00000347
9 | _cell_angle_gamma 90.00001971
10 | _symmetry_Int_Tables_number 1
11 | _chemical_formula_structural BiSBr
12 | _chemical_formula_sum 'Bi4 S4 Br4'
13 | _cell_volume 373.48101173
14 | _cell_formula_units_Z 4
15 | loop_
16 | _symmetry_equiv_pos_site_id
17 | _symmetry_equiv_pos_as_xyz
18 | 1 'x, y, z'
19 | loop_
20 | _atom_site_type_symbol
21 | _atom_site_label
22 | _atom_site_symmetry_multiplicity
23 | _atom_site_fract_x
24 | _atom_site_fract_y
25 | _atom_site_fract_z
26 | _atom_site_occupancy
27 | Bi Bi0 1 0.24999992 0.88556832 0.87177467 1
28 | Bi Bi1 1 0.75000010 0.11443170 0.12822534 1
29 | Bi Bi2 1 0.24999988 0.38556830 0.62822532 1
30 | Bi Bi3 1 0.75000009 0.61443169 0.37177466 1
31 | S S4 1 0.74999990 0.82391374 0.03167750 1
32 | S S5 1 0.25000008 0.67608625 0.53167751 1
33 | S S6 1 0.74999991 0.32391375 0.46832249 1
34 | S S7 1 0.25000010 0.17608625 0.96832251 1
35 | Br Br8 1 0.24999990 0.96479764 0.30006197 1
36 | Br Br9 1 0.75000009 0.53520235 0.80006197 1
37 | Br Br10 1 0.24999991 0.46479765 0.19993803 1
38 | Br Br11 1 0.75000013 0.03520235 0.69993804 1
39 |
--------------------------------------------------------------------------------
/data/structures/CeO2.cif:
--------------------------------------------------------------------------------
1 | # generated using pymatgen
2 | data_CeO2
3 | _symmetry_space_group_name_H-M 'P 1'
4 | _cell_length_a 5.46789061
5 | _cell_length_b 5.46789009
6 | _cell_length_c 5.46788984
7 | _cell_angle_alpha 90.00000180
8 | _cell_angle_beta 90.00000282
9 | _cell_angle_gamma 90.00000072
10 | _symmetry_Int_Tables_number 1
11 | _chemical_formula_structural CeO2
12 | _chemical_formula_sum 'Ce4 O8'
13 | _cell_volume 163.47801292
14 | _cell_formula_units_Z 4
15 | loop_
16 | _symmetry_equiv_pos_site_id
17 | _symmetry_equiv_pos_as_xyz
18 | 1 'x, y, z'
19 | loop_
20 | _atom_site_type_symbol
21 | _atom_site_label
22 | _atom_site_symmetry_multiplicity
23 | _atom_site_fract_x
24 | _atom_site_fract_y
25 | _atom_site_fract_z
26 | _atom_site_occupancy
27 | Ce Ce0 1 0.00000000 -0.00000000 -0.00000000 1
28 | Ce Ce1 1 0.00000000 0.50000000 0.50000000 1
29 | Ce Ce2 1 0.50000000 0.00000000 0.50000000 1
30 | Ce Ce3 1 0.50000000 0.50000000 0.00000000 1
31 | O O4 1 0.25000000 0.75000000 0.74999999 1
32 | O O5 1 0.74999999 0.25000000 0.25000000 1
33 | O O6 1 0.25000000 0.25000000 0.25000000 1
34 | O O7 1 0.75000000 0.75000000 0.75000000 1
35 | O O8 1 0.74999999 0.75000000 0.25000000 1
36 | O O9 1 0.25000000 0.25000000 0.74999999 1
37 | O O10 1 0.75000000 0.25000000 0.75000000 1
38 | O O11 1 0.25000000 0.75000000 0.25000000 1
39 |
--------------------------------------------------------------------------------
/data/structures/CsPbI3_alpha.cif:
--------------------------------------------------------------------------------
1 | # generated using pymatgen
2 | data_CsPbI3
3 | _symmetry_space_group_name_H-M 'P 1'
4 | _cell_length_a 6.37904246
5 | _cell_length_b 6.37904251
6 | _cell_length_c 6.37904247
7 | _cell_angle_alpha 90.00000013
8 | _cell_angle_beta 89.99999979
9 | _cell_angle_gamma 89.99999953
10 | _symmetry_Int_Tables_number 1
11 | _chemical_formula_structural CsPbI3
12 | _chemical_formula_sum 'Cs1 Pb1 I3'
13 | _cell_volume 259.57716340
14 | _cell_formula_units_Z 1
15 | loop_
16 | _symmetry_equiv_pos_site_id
17 | _symmetry_equiv_pos_as_xyz
18 | 1 'x, y, z'
19 | loop_
20 | _atom_site_type_symbol
21 | _atom_site_label
22 | _atom_site_symmetry_multiplicity
23 | _atom_site_fract_x
24 | _atom_site_fract_y
25 | _atom_site_fract_z
26 | _atom_site_occupancy
27 | Cs Cs0 1 0.00000000 0.00000000 -0.00000000 1
28 | I I1 1 0.50000000 0.50000000 -0.00000000 1
29 | I I2 1 0.50000000 -0.00000000 0.50000000 1
30 | I I3 1 0.00000000 0.50000000 0.50000000 1
31 | Pb Pb4 1 0.50000000 0.50000000 0.50000000 1
32 |
--------------------------------------------------------------------------------
/data/structures/CsPbI3_delta.cif:
--------------------------------------------------------------------------------
1 | # generated using pymatgen
2 | data_CsPbI3
3 | _symmetry_space_group_name_H-M 'P 1'
4 | _cell_length_a 4.91187765
5 | _cell_length_b 10.78602159
6 | _cell_length_c 18.12941751
7 | _cell_angle_alpha 90.00000734
8 | _cell_angle_beta 89.99999876
9 | _cell_angle_gamma 89.99999900
10 | _symmetry_Int_Tables_number 1
11 | _chemical_formula_structural CsPbI3
12 | _chemical_formula_sum 'Cs4 Pb4 I12'
13 | _cell_volume 960.48962177
14 | _cell_formula_units_Z 4
15 | loop_
16 | _symmetry_equiv_pos_site_id
17 | _symmetry_equiv_pos_as_xyz
18 | 1 'x, y, z'
19 | loop_
20 | _atom_site_type_symbol
21 | _atom_site_label
22 | _atom_site_symmetry_multiplicity
23 | _atom_site_fract_x
24 | _atom_site_fract_y
25 | _atom_site_fract_z
26 | _atom_site_occupancy
27 | Cs Cs0 1 0.75000002 0.57436836 0.17329066 1
28 | Cs Cs1 1 0.25000000 0.42563165 0.82670935 1
29 | Cs Cs2 1 0.75000002 0.07436835 0.32670935 1
30 | Cs Cs3 1 0.24999999 0.92563166 0.67329065 1
31 | I I4 1 0.74999999 0.20706041 0.71324789 1
32 | I I5 1 0.25000000 0.79293960 0.28675211 1
33 | I I6 1 0.75000001 0.97331566 0.10968932 1
34 | I I7 1 0.24999998 0.02668433 0.89031069 1
35 | I I8 1 0.74999999 0.47331567 0.39031069 1
36 | I I9 1 0.25000001 0.52668433 0.60968931 1
37 | I I10 1 0.25000000 0.66460269 0.00429655 1
38 | I I11 1 0.25000000 0.16460269 0.49570344 1
39 | I I12 1 0.25000000 0.29293960 0.21324788 1
40 | I I13 1 0.75000001 0.83539729 0.50429655 1
41 | I I14 1 0.75000000 0.33539731 0.99570345 1
42 | I I15 1 0.75000001 0.70706040 0.78675212 1
43 | Pb Pb16 1 0.74999998 0.83743203 0.94239546 1
44 | Pb Pb17 1 0.24999999 0.16256797 0.05760453 1
45 | Pb Pb18 1 0.75000002 0.33743205 0.55760454 1
46 | Pb Pb19 1 0.24999998 0.66256796 0.44239545 1
47 |
--------------------------------------------------------------------------------
/data/structures/CsSnI3_alpha.cif:
--------------------------------------------------------------------------------
1 | # generated using pymatgen
2 | data_CsSnI3
3 | _symmetry_space_group_name_H-M 'P 1'
4 | _cell_length_a 6.27081715
5 | _cell_length_b 6.27081724
6 | _cell_length_c 6.27081711
7 | _cell_angle_alpha 89.99999964
8 | _cell_angle_beta 89.99999969
9 | _cell_angle_gamma 89.99999972
10 | _symmetry_Int_Tables_number 1
11 | _chemical_formula_structural CsSnI3
12 | _chemical_formula_sum 'Cs1 Sn1 I3'
13 | _cell_volume 246.58827085
14 | _cell_formula_units_Z 1
15 | loop_
16 | _symmetry_equiv_pos_site_id
17 | _symmetry_equiv_pos_as_xyz
18 | 1 'x, y, z'
19 | loop_
20 | _atom_site_type_symbol
21 | _atom_site_label
22 | _atom_site_symmetry_multiplicity
23 | _atom_site_fract_x
24 | _atom_site_fract_y
25 | _atom_site_fract_z
26 | _atom_site_occupancy
27 | Cs Cs0 1 0.00000000 0.00000000 -0.00000000 1
28 | I I1 1 0.50000000 0.50000000 0.00000000 1
29 | I I2 1 0.50000000 -0.00000000 0.50000000 1
30 | I I3 1 -0.00000000 0.50000000 0.50000000 1
31 | Sn Sn4 1 0.50000000 0.50000000 0.50000000 1
32 |
--------------------------------------------------------------------------------
/data/structures/CsSnI3_delta.cif:
--------------------------------------------------------------------------------
1 | # generated using pymatgen
2 | data_CsSnI3
3 | _symmetry_space_group_name_H-M 'P 1'
4 | _cell_length_a 4.84918880
5 | _cell_length_b 10.69167480
6 | _cell_length_c 18.19616575
7 | _cell_angle_alpha 89.99999030
8 | _cell_angle_beta 90.00000262
9 | _cell_angle_gamma 89.99999762
10 | _symmetry_Int_Tables_number 1
11 | _chemical_formula_structural CsSnI3
12 | _chemical_formula_sum 'Cs4 Sn4 I12'
13 | _cell_volume 943.39749344
14 | _cell_formula_units_Z 4
15 | loop_
16 | _symmetry_equiv_pos_site_id
17 | _symmetry_equiv_pos_as_xyz
18 | 1 'x, y, z'
19 | loop_
20 | _atom_site_type_symbol
21 | _atom_site_label
22 | _atom_site_symmetry_multiplicity
23 | _atom_site_fract_x
24 | _atom_site_fract_y
25 | _atom_site_fract_z
26 | _atom_site_occupancy
27 | Cs Cs0 1 0.75000000 0.57552416 0.17186806 1
28 | Cs Cs1 1 0.25000001 0.42447584 0.82813194 1
29 | Cs Cs2 1 0.75000000 0.07552414 0.32813195 1
30 | Cs Cs3 1 0.25000000 0.92447584 0.67186806 1
31 | I I4 1 0.75000000 0.21397041 0.70606567 1
32 | I I5 1 0.25000000 0.78602961 0.29393431 1
33 | I I6 1 0.75000000 0.97123818 0.11227438 1
34 | I I7 1 0.25000000 0.02876181 0.88772566 1
35 | I I8 1 0.74999999 0.47123818 0.38772564 1
36 | I I9 1 0.25000001 0.52876182 0.61227437 1
37 | I I10 1 0.25000002 0.66813711 0.00020614 1
38 | I I11 1 0.24999999 0.16813710 0.49979388 1
39 | I I12 1 0.25000001 0.28602959 0.20606567 1
40 | I I13 1 0.74999999 0.83186292 0.50020614 1
41 | I I14 1 0.75000000 0.33186292 0.99979386 1
42 | I I15 1 0.74999999 0.71397042 0.79393433 1
43 | Sn Sn16 1 0.75000000 0.84420702 0.94393743 1
44 | Sn Sn17 1 0.25000000 0.15579296 0.05606252 1
45 | Sn Sn18 1 0.74999998 0.34420702 0.55606253 1
46 | Sn Sn19 1 0.24999999 0.65579297 0.44393746 1
47 |
--------------------------------------------------------------------------------
/data/structures/Fe.cif:
--------------------------------------------------------------------------------
1 | # generated using pymatgen
2 | data_Fe
3 | _symmetry_space_group_name_H-M 'P 1'
4 | _cell_length_a 2.86106543
5 | _cell_length_b 2.86106544
6 | _cell_length_c 2.86106538
7 | _cell_angle_alpha 90.00000018
8 | _cell_angle_beta 89.99999992
9 | _cell_angle_gamma 90.00000009
10 | _symmetry_Int_Tables_number 1
11 | _chemical_formula_structural Fe
12 | _chemical_formula_sum Fe2
13 | _cell_volume 23.41980977
14 | _cell_formula_units_Z 2
15 | loop_
16 | _symmetry_equiv_pos_site_id
17 | _symmetry_equiv_pos_as_xyz
18 | 1 'x, y, z'
19 | loop_
20 | _atom_site_type_symbol
21 | _atom_site_label
22 | _atom_site_symmetry_multiplicity
23 | _atom_site_fract_x
24 | _atom_site_fract_y
25 | _atom_site_fract_z
26 | _atom_site_occupancy
27 | Fe Fe0 1 0.00000000 -0.00000000 0.00000000 1
28 | Fe Fe1 1 0.50000000 0.50000000 0.50000000 1
29 |
--------------------------------------------------------------------------------
/data/structures/GaN_hex.cif:
--------------------------------------------------------------------------------
1 | # generated using pymatgen
2 | data_GaN
3 | _symmetry_space_group_name_H-M 'P 1'
4 | _cell_length_a 3.21192371
5 | _cell_length_b 3.21192377
6 | _cell_length_c 5.21628467
7 | _cell_angle_alpha 90.00000055
8 | _cell_angle_beta 90.00000024
9 | _cell_angle_gamma 119.99999991
10 | _symmetry_Int_Tables_number 1
11 | _chemical_formula_structural GaN
12 | _chemical_formula_sum 'Ga2 N2'
13 | _cell_volume 46.60391121
14 | _cell_formula_units_Z 2
15 | loop_
16 | _symmetry_equiv_pos_site_id
17 | _symmetry_equiv_pos_as_xyz
18 | 1 'x, y, z'
19 | loop_
20 | _atom_site_type_symbol
21 | _atom_site_label
22 | _atom_site_symmetry_multiplicity
23 | _atom_site_fract_x
24 | _atom_site_fract_y
25 | _atom_site_fract_z
26 | _atom_site_occupancy
27 | Ga Ga0 1 0.66666667 0.33333333 0.49900050 1
28 | Ga Ga1 1 0.33333333 0.66666667 0.99900050 1
29 | N N2 1 0.66666667 0.33333333 0.87599950 1
30 | N N3 1 0.33333333 0.66666667 0.37599950 1
31 |
--------------------------------------------------------------------------------
/data/structures/NaCl.cif:
--------------------------------------------------------------------------------
1 | # generated using pymatgen
2 | data_NaCl
3 | _symmetry_space_group_name_H-M 'P 1'
4 | _cell_length_a 5.68304678
5 | _cell_length_b 5.68304679
6 | _cell_length_c 5.68304672
7 | _cell_angle_alpha 90.00000009
8 | _cell_angle_beta 90.00000015
9 | _cell_angle_gamma 90.00000005
10 | _symmetry_Int_Tables_number 1
11 | _chemical_formula_structural NaCl
12 | _chemical_formula_sum 'Na4 Cl4'
13 | _cell_volume 183.54547783
14 | _cell_formula_units_Z 4
15 | loop_
16 | _symmetry_equiv_pos_site_id
17 | _symmetry_equiv_pos_as_xyz
18 | 1 'x, y, z'
19 | loop_
20 | _atom_site_type_symbol
21 | _atom_site_label
22 | _atom_site_symmetry_multiplicity
23 | _atom_site_fract_x
24 | _atom_site_fract_y
25 | _atom_site_fract_z
26 | _atom_site_occupancy
27 | Na Na0 1 -0.00000000 -0.00000000 -0.00000000 1
28 | Na Na1 1 -0.00000000 0.50000000 0.50000000 1
29 | Na Na2 1 0.50000000 -0.00000000 0.50000000 1
30 | Na Na3 1 0.50000000 0.50000000 -0.00000000 1
31 | Cl Cl4 1 0.00000000 0.00000000 0.50000000 1
32 | Cl Cl5 1 -0.00000000 0.50000000 0.00000000 1
33 | Cl Cl6 1 0.50000000 0.00000000 -0.00000000 1
34 | Cl Cl7 1 0.50000000 0.50000000 0.50000000 1
35 |
--------------------------------------------------------------------------------
/notebooks/1_solid_solution.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import ase\n",
10 | "import matplotlib.pyplot as plt\n",
11 | "import numpy as np\n",
12 | "from tqdm import tqdm\n",
13 | "\n",
14 | "from alchemical_mace.calculator import get_alchemical_optimized_cellpar\n",
15 | "from alchemical_mace.model import AlchemicalPair\n",
16 | "from alchemical_mace.utils import suppress_print\n",
17 | "\n",
18 | "plt.style.use(\"default\")"
19 | ]
20 | },
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {},
24 | "source": [
25 | "### CeO2"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": 2,
31 | "metadata": {},
32 | "outputs": [
33 | {
34 | "name": "stderr",
35 | "output_type": "stream",
36 | "text": [
37 | "100%|██████████| 21/21 [01:15<00:00, 3.60s/it]\n"
38 | ]
39 | }
40 | ],
41 | "source": [
42 | "# Default settings\n",
43 | "model = \"medium\"\n",
44 | "device = \"cpu\"\n",
45 | "\n",
46 | "# Load structure\n",
47 | "atoms = ase.io.read(\"../data/structures/CeO2.cif\")\n",
48 | "alch_elements = [\"Ce\", \"Sn\"]\n",
49 | "alch_idx = [i for i, atom in enumerate(atoms) if atom.symbol in alch_elements]\n",
50 | "alch_atomic_numbers = [ase.Atoms(el).numbers[0] for el in alch_elements]\n",
51 | "alchemical_pairs = [\n",
52 | " [AlchemicalPair(atom_index=idx, atomic_number=z) for idx in alch_idx]\n",
53 | " for z in alch_atomic_numbers\n",
54 | "]\n",
55 | "\n",
56 | "comp_grid = [[1 - x, x] for x in np.linspace(0, 0.5, 21)]\n",
57 | "lat_params_CeSn = []\n",
58 | "for comp in tqdm(comp_grid):\n",
59 | " with suppress_print(out=True, err=True):\n",
60 | " cellpar = get_alchemical_optimized_cellpar(\n",
61 | " atoms, alchemical_pairs, comp, model=model, device=device\n",
62 | " )\n",
63 | " lat_params_CeSn.append(cellpar[:3])"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": 3,
69 | "metadata": {},
70 | "outputs": [
71 | {
72 | "name": "stderr",
73 | "output_type": "stream",
74 | "text": [
75 | "100%|██████████| 21/21 [01:35<00:00, 4.52s/it]\n"
76 | ]
77 | }
78 | ],
79 | "source": [
80 | "# Reuse the CeO2 structure loaded above; only the substituted element changes (Zr instead of Sn)\n",
81 | "alch_elements = [\"Ce\", \"Zr\"]\n",
82 | "alch_idx = [i for i, atom in enumerate(atoms) if atom.symbol in alch_elements]\n",
83 | "alch_atomic_numbers = [ase.Atoms(el).numbers[0] for el in alch_elements]\n",
84 | "alchemical_pairs = [\n",
85 | " [AlchemicalPair(atom_index=idx, atomic_number=z) for idx in alch_idx]\n",
86 | " for z in alch_atomic_numbers\n",
87 | "]\n",
88 | "\n",
89 | "comp_grid = [[1 - x, x] for x in np.linspace(0, 0.5, 21)]\n",
90 | "lat_params_CeZr = []\n",
91 | "for comp in tqdm(comp_grid):\n",
92 | " with suppress_print(out=True, err=True):\n",
93 | " cellpar = get_alchemical_optimized_cellpar(\n",
94 | " atoms, alchemical_pairs, comp, model=model, device=device\n",
95 | " )\n",
96 | " lat_params_CeZr.append(cellpar[:3])"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": 4,
102 | "metadata": {},
103 | "outputs": [
104 | {
105 | "data": {
106 | "image/png": "",
107 | "text/plain": [
108 | ""
109 | ]
110 | },
111 | "metadata": {},
112 | "output_type": "display_data"
113 | }
114 | ],
115 | "source": [
116 | "fig, ax = plt.subplots(figsize=(3, 2.5))\n",
117 | "ax.plot(\n",
118 | " [x[1] for x in comp_grid],\n",
119 | " [x[0] for x in lat_params_CeSn],\n",
120 | " label=\"Ce$_{1-x}$Sn$_x$O$_2$\",\n",
121 | ")\n",
122 | "ax.plot(\n",
123 | " [x[1] for x in comp_grid],\n",
124 | " [x[0] for x in lat_params_CeZr],\n",
125 | " label=\"Ce$_{1-x}$Zr$_x$O$_2$\",\n",
126 | ")\n",
127 | "ax.set_xlabel(\"$x$\")\n",
128 | "ax.set_ylabel(\"a [Å]\")\n",
129 | "ax.legend()\n",
130 | "fig.show()"
131 | ]
132 | },
133 | {
134 | "cell_type": "markdown",
135 | "metadata": {},
136 | "source": [
137 | "### BiSBr"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": 5,
143 | "metadata": {},
144 | "outputs": [],
145 | "source": [
146 | "# Load structure\n",
147 | "atoms = ase.io.read(\"../data/structures/BiSBr.cif\")\n",
148 | "halide_elements = [\"Cl\", \"Br\", \"I\"]\n",
149 | "halide_idx = [i for i, atom in enumerate(atoms) if atom.symbol in halide_elements]\n",
150 | "halide_atomic_numbers = [ase.Atoms(el).numbers[0] for el in halide_elements]\n",
151 | "alchemical_pairs = [\n",
152 | " [AlchemicalPair(atom_index=idx, atomic_number=z) for idx in halide_idx]\n",
153 | " for z in halide_atomic_numbers\n",
154 | "]"
155 | ]
156 | },
157 | {
158 | "cell_type": "code",
159 | "execution_count": 6,
160 | "metadata": {},
161 | "outputs": [
162 | {
163 | "name": "stderr",
164 | "output_type": "stream",
165 | "text": [
166 | "100%|██████████| 21/21 [08:12<00:00, 23.46s/it]\n",
167 | "100%|██████████| 21/21 [04:21<00:00, 12.45s/it]\n",
168 | "100%|██████████| 21/21 [05:35<00:00, 15.95s/it]\n"
169 | ]
170 | }
171 | ],
172 | "source": [
173 | "comp_grid = [[1 - x, x, 0] for x in np.linspace(0, 1, 21)]\n",
174 | "lat_params_ClBr = []\n",
175 | "for comp in tqdm(comp_grid):\n",
176 | " with suppress_print(out=True, err=True):\n",
177 | " cellpar = get_alchemical_optimized_cellpar(\n",
178 | " atoms, alchemical_pairs, comp, model=model, device=device\n",
179 | " )\n",
180 | " lat_params_ClBr.append(cellpar[:3])\n",
181 | "\n",
182 | "comp_grid = [[0, 1 - x, x] for x in np.linspace(0, 1, 21)]\n",
183 | "lat_params_BrI = []\n",
184 | "for comp in tqdm(comp_grid):\n",
185 | " with suppress_print(out=True, err=True):\n",
186 | " cellpar = get_alchemical_optimized_cellpar(\n",
187 | " atoms, alchemical_pairs, comp, model=model, device=device\n",
188 | " )\n",
189 | " lat_params_BrI.append(cellpar[:3])\n",
190 | "\n",
191 | "\n",
192 | "comp_grid = [[1 - x, 0, x] for x in np.linspace(0, 1, 21)]\n",
193 | "lat_params_ClI = []\n",
194 | "for comp in tqdm(comp_grid):\n",
195 | " with suppress_print(out=True, err=True):\n",
196 | " cellpar = get_alchemical_optimized_cellpar(\n",
197 | " atoms, alchemical_pairs, comp, model=model, device=device\n",
198 | " )\n",
199 | " lat_params_ClI.append(cellpar[:3])"
200 | ]
201 | },
202 | {
203 | "cell_type": "code",
204 | "execution_count": 7,
205 | "metadata": {},
206 | "outputs": [
207 | {
208 | "data": {
209 | "image/png": "",
210 | "text/plain": [
211 | ""
212 | ]
213 | },
214 | "metadata": {},
215 | "output_type": "display_data"
216 | }
217 | ],
218 | "source": [
219 | "fig, ax = plt.subplots(figsize=(3.5, 2.5), layout=\"constrained\")\n",
220 | "comp = np.linspace(0, 1, 21)\n",
221 | "idx, param = 2, \"c\"\n",
222 | "ax.plot(comp, [lat[idx] for lat in lat_params_ClBr], label=\"BiSCl$_{1-x}$Br$_x$\")\n",
223 | "ax.plot(comp, [lat[idx] for lat in lat_params_BrI], label=\"BiSBr$_{1-x}$I$_x$\")\n",
224 | "ax.plot(comp, [lat[idx] for lat in lat_params_ClI], label=\"BiSCl$_{1-x}$I$_x$\")\n",
225 | "ax.legend()\n",
226 | "ax.set_xlabel(\"$x$\")\n",
227 | "ax.set_ylabel(f\"{param} [Å]\")\n",
228 | "fig.show()"
229 | ]
230 | }
231 | ],
232 | "metadata": {
233 | "kernelspec": {
234 | "display_name": "chgnet",
235 | "language": "python",
236 | "name": "python3"
237 | },
238 | "language_info": {
239 | "codemirror_mode": {
240 | "name": "ipython",
241 | "version": 3
242 | },
243 | "file_extension": ".py",
244 | "mimetype": "text/x-python",
245 | "name": "python",
246 | "nbconvert_exporter": "python",
247 | "pygments_lexer": "ipython3",
248 | "version": "3.10.14"
249 | }
250 | },
251 | "nbformat": 4,
252 | "nbformat_minor": 2
253 | }
254 |
--------------------------------------------------------------------------------
/notebooks/4_vacancy_analysis.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from pathlib import Path\n",
10 | "\n",
11 | "import numpy as np\n",
12 | "import pandas as pd\n",
13 | "from ase import units\n",
14 | "from scipy.integrate import trapezoid"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 2,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "def integrate_switching(\n",
24 | " df_log: pd.DataFrame,\n",
25 | " equil_time: int = 20000,\n",
26 | " switch_time: int = 30000,\n",
27 | " return_E_diss: bool = False,\n",
28 | "):\n",
29 | " fwd_start, fwd_end = equil_time, equil_time + switch_time\n",
30 | " rev_start, rev_end = 2 * equil_time + switch_time, 2 * equil_time + 2 * switch_time\n",
31 | " grad, lamda = df_log[\"lambda_grad\"], df_log[\"lambda\"]\n",
32 | " W_fwd = trapezoid(grad[fwd_start:fwd_end], lamda[fwd_start:fwd_end])\n",
33 | " W_rev = trapezoid(grad[rev_start:rev_end], lamda[rev_start:rev_end])\n",
34 | " if return_E_diss:\n",
35 | " return (W_fwd - W_rev) / 2, (W_fwd + W_rev) / 2\n",
36 | " return (W_fwd - W_rev) / 2 # free energy difference\n",
37 | "\n",
38 | "\n",
39 | "def analyze_frenkel_ladd(\n",
40 | " base_path: Path,\n",
41 | " temp: float,\n",
42 | " equil_time: int = 20000,\n",
43 | " switch_time: int = 30000,\n",
44 | " verbose: bool = False,\n",
45 | "):\n",
46 | " T = temp\n",
47 | " k = np.load(base_path / \"spring_constants.npy\")\n",
48 | "\n",
49 | " mass = np.load(base_path / \"masses.npy\")\n",
50 | " omega = np.sqrt(k / mass)\n",
51 | " n_atoms = len(mass)\n",
52 | "\n",
53 | " # 1. Perfect crystal\n",
54 | " df_log = pd.read_csv(base_path / \"observables.csv\")\n",
55 | " volume = df_log[\"volume\"].values[0]\n",
56 | " if verbose:\n",
57 | " _, E_diss_perfect = integrate_switching(\n",
58 | " df_log, equil_time, switch_time, return_E_diss=True\n",
59 | " )\n",
60 | " delta_F = integrate_switching(df_log, equil_time, switch_time)\n",
61 | " F_E = 3 * units.kB * T * np.mean(np.log(units._hbar * omega / (units.kB * T)))\n",
62 | " PV = volume * 1.01325 * units.bar\n",
63 | " G_perfect = delta_F + F_E + PV\n",
64 | "\n",
65 | " # 2. Defective crystal\n",
66 | " df_log = pd.read_csv(base_path / \"observables_defect.csv\")\n",
67 | " volume = df_log[\"volume\"].values[0]\n",
68 | " if verbose:\n",
69 | " _, E_diss_defect = integrate_switching(\n",
70 | " df_log, equil_time, switch_time, return_E_diss=True\n",
71 | " )\n",
72 | " delta_F = integrate_switching(df_log, equil_time, switch_time)\n",
73 | " F_E = 3 * units.kB * T * np.mean(np.log(units._hbar * omega / (units.kB * T)))\n",
74 | " PV = volume * 1.01325 * units.bar\n",
75 | " G_defect = delta_F + F_E + PV\n",
76 | " G_v = G_defect * (n_atoms - 1) - G_perfect * (n_atoms - 1)\n",
77 | "\n",
78 | " # 3. Partial FL\n",
79 | " df_log = pd.read_csv(base_path / \"observables_FL.csv\")\n",
80 | " if verbose:\n",
81 | " _, E_diss_FL = integrate_switching(\n",
82 | " df_log, equil_time, switch_time, return_E_diss=True\n",
83 | " )\n",
84 | " delta_G = integrate_switching(df_log, equil_time, switch_time)\n",
85 | " # delta_G * N = (G_defect * N-1 + F_E) - G_perfect * N\n",
86 | " G_defect_alt = (delta_G * n_atoms - F_E + G_perfect * n_atoms) / (n_atoms - 1)\n",
87 | " G_v_alt = G_defect_alt * (n_atoms - 1) - G_perfect * (n_atoms - 1)\n",
88 | "\n",
89 | " if verbose:\n",
90 | " return G_perfect, G_defect, delta_G, E_diss_perfect, E_diss_defect, E_diss_FL\n",
91 | " return G_v, G_v_alt"
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": 3,
97 | "metadata": {},
98 | "outputs": [
99 | {
100 | "name": "stdout",
101 | "output_type": "stream",
102 | "text": [
103 | "G_v (50 K) = 1.5513 ± 0.0063 eV\n",
104 | "G_v FL (50 K) = 1.5582 ± 0.0014 eV\n",
105 | "G_v (100 K) = 1.5475 ± 0.0060 eV\n",
106 | "G_v FL (100 K) = 1.5405 ± 0.0018 eV\n",
107 | "G_v (150 K) = 1.5232 ± 0.0276 eV\n",
108 | "G_v FL (150 K) = 1.5188 ± 0.0029 eV\n",
109 | "G_v (200 K) = 1.5110 ± 0.0283 eV\n",
110 | "G_v FL (200 K) = 1.5003 ± 0.0072 eV\n"
111 | ]
112 | }
113 | ],
114 | "source": [
115 | "result_path = Path(\"../data/results/vacancy\")\n",
116 | "temp_range = [50, 100, 150, 200]\n",
117 | "\n",
118 | "G_v_all, G_v_std_all = [], []\n",
119 | "G_v_alt_all, G_v_alt_std_all = [], []\n",
120 | "for temp in temp_range:\n",
121 | " G_v_list = []\n",
122 | " G_v_alt_list = []\n",
123 | " E_diss_perfect_list = []\n",
124 | " E_diss_defect_list = []\n",
125 | " E_diss_FL_list = []\n",
126 | " for i in range(4):\n",
127 | " base_path = result_path / f\"Fe_5x5x5_{temp}K/{i}\"\n",
128 | " G_v, G_v_alt = analyze_frenkel_ladd(base_path, temp=temp, verbose=False)\n",
129 | " G_v_list.append(G_v)\n",
130 | " G_v_alt_list.append(G_v_alt)\n",
131 | " G_v = np.mean(G_v_list)\n",
132 | " G_v_std = np.std(G_v_list)\n",
133 | " G_v_alt = np.mean(G_v_alt_list)\n",
134 | " G_v_alt_std = np.std(G_v_alt_list)\n",
135 | " print(f\"G_v ({temp} K) = {G_v:.4f} ± {G_v_std:.4f} eV\")\n",
136 | " print(f\"G_v FL ({temp} K) = {G_v_alt:.4f} ± {G_v_alt_std:.4f} eV\")\n",
137 | " G_v_all.append(G_v)\n",
138 | " G_v_std_all.append(G_v_std)\n",
139 | " G_v_alt_all.append(G_v_alt)\n",
140 | " G_v_alt_std_all.append(G_v_alt_std)"
141 | ]
142 | }
143 | ],
144 | "metadata": {
145 | "kernelspec": {
146 | "display_name": "chgnet",
147 | "language": "python",
148 | "name": "python3"
149 | },
150 | "language_info": {
151 | "codemirror_mode": {
152 | "name": "ipython",
153 | "version": 3
154 | },
155 | "file_extension": ".py",
156 | "mimetype": "text/x-python",
157 | "name": "python",
158 | "nbconvert_exporter": "python",
159 | "pygments_lexer": "ipython3",
160 | "version": "3.10.14"
161 | }
162 | },
163 | "nbformat": 4,
164 | "nbformat_minor": 2
165 | }
166 |
--------------------------------------------------------------------------------
/notebooks/5_perovskite_analysis.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from pathlib import Path\n",
10 | "\n",
11 | "import numpy as np\n",
12 | "import pandas as pd\n",
13 | "from ase import units\n",
14 | "from scipy.integrate import trapezoid"
15 | ]
16 | },
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {},
20 | "source": [
21 | "### Frenkel–Ladd path"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": 2,
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "def integrate_switching(\n",
31 | " df_log: pd.DataFrame,\n",
32 | " equil_time: int = 20000,\n",
33 | " switch_time: int = 30000,\n",
34 | " return_E_diss: bool = False,\n",
35 | "):\n",
36 | " fwd_start, fwd_end = equil_time, equil_time + switch_time\n",
37 | " rev_start, rev_end = 2 * equil_time + switch_time, 2 * equil_time + 2 * switch_time\n",
38 | " grad, lamda = df_log[\"lambda_grad\"], df_log[\"lambda\"]\n",
39 | " W_fwd = trapezoid(grad[fwd_start:fwd_end], lamda[fwd_start:fwd_end])\n",
40 | " W_rev = trapezoid(grad[rev_start:rev_end], lamda[rev_start:rev_end])\n",
41 | " if return_E_diss:\n",
42 | " return (W_fwd - W_rev) / 2, (W_fwd + W_rev) / 2\n",
43 | " return (W_fwd - W_rev) / 2 # free energy difference\n",
44 | "\n",
45 | "\n",
46 | "def analyze_frenkel_ladd(\n",
47 | " base_path: Path,\n",
48 | " temp: float,\n",
49 | " equil_time: int = 20000,\n",
50 | " switch_time: int = 30000,\n",
51 | "):\n",
52 | " T = temp\n",
53 | " df_log = pd.read_csv(base_path / \"observables.csv\")\n",
54 | " k = np.load(base_path / \"spring_constants.npy\")\n",
55 | " mass = np.load(base_path / \"masses.npy\")\n",
56 | " omega = np.sqrt(k / mass)\n",
57 | " volume = df_log[\"volume\"].values[0]\n",
58 | "\n",
59 | " delta_F = integrate_switching(df_log, equil_time, switch_time)\n",
60 | " F_E = 3 * units.kB * T * np.mean(np.log(units._hbar * omega / (units.kB * T)))\n",
61 | " PV = volume * 1.01325 * units.bar\n",
62 | " delta_G = delta_F + F_E + PV\n",
63 | "\n",
64 | " return delta_G\n",
65 | "\n",
66 | "\n",
67 | "def analyze_alchemical_switching(\n",
68 | " base_path: Path,\n",
69 | " temp: float,\n",
70 | " equil_time: int = 20000,\n",
71 | " switch_time: int = 30000,\n",
72 | "):\n",
73 | " T = temp\n",
74 | " df_log = pd.read_csv(base_path / \"observables.csv\")\n",
75 | " mass_init = np.load(base_path / \"masses_init.npy\")\n",
76 | " mass_final = np.load(base_path / \"masses_final.npy\")\n",
77 | "\n",
78 | " work = integrate_switching(df_log, equil_time, switch_time)\n",
79 | " G_mass = 1.5 * units.kB * T * np.mean(np.log(mass_init / mass_final))\n",
80 | " delta_G = work + G_mass\n",
81 | "\n",
82 | " return delta_G"
83 | ]
84 | },
85 | {
86 | "cell_type": "code",
87 | "execution_count": 3,
88 | "metadata": {},
89 | "outputs": [
90 | {
91 | "name": "stdout",
92 | "output_type": "stream",
93 | "text": [
94 | "CsPbI3 Alpha G (300 K) = -8.8514 ± 0.0003 eV/atom\n",
95 | "CsPbI3 Alpha G (350 K) = -9.8642 ± 0.0004 eV/atom\n",
96 | "CsPbI3 Alpha G (400 K) = -10.8783 ± 0.0006 eV/atom\n",
97 | "CsPbI3 Alpha G (450 K) = -11.8943 ± 0.0003 eV/atom\n",
98 | "CsPbI3 Alpha G (500 K) = -12.9119 ± 0.0004 eV/atom\n",
99 | "CsPbI3 Delta G (300 K) = -8.8560 ± 0.0001 eV/atom\n",
100 | "CsPbI3 Delta G (350 K) = -9.8660 ± 0.0001 eV/atom\n",
101 | "CsPbI3 Delta G (400 K) = -10.8782 ± 0.0002 eV/atom\n",
102 | "CsPbI3 Delta G (450 K) = -11.8921 ± 0.0002 eV/atom\n",
103 | "CsPbI3 Delta G (500 K) = -12.9072 ± 0.0002 eV/atom\n"
104 | ]
105 | }
106 | ],
107 | "source": [
108 | "result_path = Path(\"../data/results/perovskite/frenkel_ladd\")\n",
109 | "temp_range = [300, 350, 400, 450, 500]\n",
110 | "\n",
111 | "G_alpha = []\n",
112 | "G_alpha_std = []\n",
113 | "for temp in temp_range:\n",
114 | " G_list = []\n",
115 | " for i in range(4):\n",
116 | " base_path = result_path / f\"CsPbI3_alpha_6x6x6_{temp}K/{i}\"\n",
117 | " G_list.append(analyze_frenkel_ladd(base_path, temp=temp))\n",
118 | " G = np.mean(G_list)\n",
119 | " G_std = np.std(G_list)\n",
120 | " print(f\"CsPbI3 Alpha G ({temp} K) = {G:.4f} ± {G_std:.4f} eV/atom\")\n",
121 | " G_alpha.append(G)\n",
122 | " G_alpha_std.append(G_std)\n",
123 | "\n",
124 | "G_delta = []\n",
125 | "G_delta_std = []\n",
126 | "for temp in temp_range:\n",
127 | " G_list = []\n",
128 | " for i in range(4):\n",
129 | " base_path = result_path / f\"CsPbI3_delta_6x3x3_{temp}K/{i}\"\n",
130 | " G_list.append(analyze_frenkel_ladd(base_path, temp=temp))\n",
131 | " G = np.mean(G_list)\n",
132 | " G_std = np.std(G_list)\n",
133 | " print(f\"CsPbI3 Delta G ({temp} K) = {G:.4f} ± {G_std:.4f} eV/atom\")\n",
134 | " G_delta.append(G)\n",
135 | " G_delta_std.append(G_std)"
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": 4,
141 | "metadata": {},
142 | "outputs": [
143 | {
144 | "name": "stdout",
145 | "output_type": "stream",
146 | "text": [
147 | "CsSnI3 Alpha G (300 K) = -8.8297 ± 0.0002 eV/atom\n",
148 | "CsSnI3 Alpha G (350 K) = -9.8413 ± 0.0002 eV/atom\n",
149 | "CsSnI3 Alpha G (400 K) = -10.8544 ± 0.0002 eV/atom\n",
150 | "CsSnI3 Alpha G (450 K) = -11.8695 ± 0.0003 eV/atom\n",
151 | "CsSnI3 Alpha G (500 K) = -12.8863 ± 0.0003 eV/atom\n",
152 | "CsSnI3 Delta G (300 K) = -8.8289 ± 0.0001 eV/atom\n",
153 | "CsSnI3 Delta G (350 K) = -9.8381 ± 0.0000 eV/atom\n",
154 | "CsSnI3 Delta G (400 K) = -10.8494 ± 0.0003 eV/atom\n",
155 | "CsSnI3 Delta G (450 K) = -11.8627 ± 0.0003 eV/atom\n",
156 | "CsSnI3 Delta G (500 K) = -12.8771 ± 0.0003 eV/atom\n"
157 | ]
158 | }
159 | ],
160 | "source": [
161 | "G_CsSnI3_alpha = []\n",
162 | "G_CsSnI3_alpha_std = []\n",
163 | "for temp in temp_range:\n",
164 | " G_list = []\n",
165 | " for i in range(4):\n",
166 | " base_path = result_path / f\"CsSnI3_alpha_6x6x6_{temp}K/{i}\"\n",
167 | " G_list.append(analyze_frenkel_ladd(base_path, temp=temp))\n",
168 | " G = np.mean(G_list)\n",
169 | " G_std = np.std(G_list)\n",
170 | " print(f\"CsSnI3 Alpha G ({temp} K) = {G:.4f} ± {G_std:.4f} eV/atom\")\n",
171 | " G_CsSnI3_alpha.append(G)\n",
172 | " G_CsSnI3_alpha_std.append(G_std)\n",
173 | "\n",
174 | "G_CsSnI3_delta = []\n",
175 | "G_CsSnI3_delta_std = []\n",
176 | "for temp in temp_range:\n",
177 | " G_list = []\n",
178 | " for i in range(4):\n",
179 | " base_path = result_path / f\"CsSnI3_delta_6x3x3_{temp}K/{i}\"\n",
180 | " G_list.append(analyze_frenkel_ladd(base_path, temp=temp))\n",
181 | " G = np.mean(G_list)\n",
182 | " G_std = np.std(G_list)\n",
183 | " print(f\"CsSnI3 Delta G ({temp} K) = {G:.4f} ± {G_std:.4f} eV/atom\")\n",
184 | " G_CsSnI3_delta.append(G)\n",
185 | " G_CsSnI3_delta_std.append(G_std)"
186 | ]
187 | },
188 | {
189 | "cell_type": "markdown",
190 | "metadata": {},
191 | "source": [
192 | "### Alchemical path"
193 | ]
194 | },
195 | {
196 | "cell_type": "code",
197 | "execution_count": 5,
198 | "metadata": {},
199 | "outputs": [
200 | {
201 | "name": "stdout",
202 | "output_type": "stream",
203 | "text": [
204 | "Alpha ΔG (300 K) = 0.0233 ± 0.0001 eV/atom\n",
205 | "Alpha ΔG (350 K) = 0.0236 ± 0.0001 eV/atom\n",
206 | "Alpha ΔG (400 K) = 0.0241 ± 0.0001 eV/atom\n",
207 | "Alpha ΔG (450 K) = 0.0249 ± 0.0000 eV/atom\n",
208 | "Alpha ΔG (500 K) = 0.0258 ± 0.0000 eV/atom\n",
209 | "Delta ΔG (300 K) = 0.0271 ± 0.0000 eV/atom\n",
210 | "Delta ΔG (350 K) = 0.0279 ± 0.0000 eV/atom\n",
211 | "Delta ΔG (400 K) = 0.0286 ± 0.0000 eV/atom\n",
212 | "Delta ΔG (450 K) = 0.0294 ± 0.0000 eV/atom\n",
213 | "Delta ΔG (500 K) = 0.0301 ± 0.0000 eV/atom\n"
214 | ]
215 | }
216 | ],
217 | "source": [
218 | "result_path = Path(\"../data/results/perovskite/alchemy\")\n",
219 | "\n",
220 | "G_alpha = []\n",
221 | "G_alpha_std = []\n",
222 | "for temp in temp_range:\n",
223 | " G_list = []\n",
224 | " for i in range(4):\n",
225 | " base_path = result_path / f\"CsPbI3_CsSnI3_alpha_{temp}K/{i}\"\n",
226 | " G_list.append(analyze_alchemical_switching(base_path, temp=temp))\n",
227 | " G = np.mean(G_list)\n",
228 | " G_std = np.std(G_list)\n",
229 | " print(f\"Alpha ΔG ({temp} K) = {G:.4f} ± {G_std:.4f} eV/atom\")\n",
230 | " G_alpha.append(G)\n",
231 | " G_alpha_std.append(G_std)\n",
232 | "\n",
233 | "G_delta = []\n",
234 | "G_delta_std = []\n",
235 | "for temp in temp_range:\n",
236 | " G_list = []\n",
237 | " for i in range(4):\n",
238 | " base_path = result_path / f\"CsPbI3_CsSnI3_delta_{temp}K/{i}\"\n",
239 | " G_list.append(analyze_alchemical_switching(base_path, temp=temp))\n",
240 | " G = np.mean(G_list)\n",
241 | " G_std = np.std(G_list)\n",
242 | " print(f\"Delta ΔG ({temp} K) = {G:.4f} ± {G_std:.4f} eV/atom\")\n",
243 | " G_delta.append(G)\n",
244 | " G_delta_std.append(G_std)"
245 | ]
246 | }
247 | ],
248 | "metadata": {
249 | "kernelspec": {
250 | "display_name": "chgnet",
251 | "language": "python",
252 | "name": "python3"
253 | },
254 | "language_info": {
255 | "codemirror_mode": {
256 | "name": "ipython",
257 | "version": 3
258 | },
259 | "file_extension": ".py",
260 | "mimetype": "text/x-python",
261 | "name": "python",
262 | "nbconvert_exporter": "python",
263 | "pygments_lexer": "ipython3",
264 | "version": "3.10.14"
265 | }
266 | },
267 | "nbformat": 4,
268 | "nbformat_minor": 2
269 | }
270 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "alchemical_mace"
7 | authors = [
8 | { name = "Juno Nam", email = "junonam@mit.edu" },
9 | ]
10 | description = "Alchemical MACE model"
11 | readme = "README.md"
12 | requires-python = ">=3.9"
13 | version = "0.1.0"
14 |
15 | [tool.setuptools]
16 | packages = ["alchemical_mace"]
17 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.0.1
2 | e3nn==0.4.4
3 | mace-torch==0.3.4
4 | ase==3.22.1
5 | pymatgen==2024.3.1
6 | numpy==1.25.2
7 | scipy==1.11.2
8 | pandas==2.2.2
9 | matplotlib==3.8.0
10 | mpltern==1.0.2
11 | tqdm==4.66.3
12 | ipykernel==6.25.2
13 |
--------------------------------------------------------------------------------
/scripts/perovskite_alchemy.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 |
4 | import ase
5 | import numpy as np
6 | import pandas as pd
7 | from ase import units
8 | from ase.build import make_supercell
9 | from ase.constraints import ExpCellFilter
10 | from ase.md.npt import NPT
11 | from ase.md.nptberendsen import Inhomogeneous_NPTBerendsen
12 | from ase.md.velocitydistribution import MaxwellBoltzmannDistribution, Stationary
13 | from ase.optimize import FIRE
14 | from mace.calculators import mace_mp
15 | from tqdm import tqdm
16 |
17 | from alchemical_mace.calculator import AlchemicalMACECalculator
18 | from alchemical_mace.model import AlchemicalPair
19 | from alchemical_mace.utils import upper_triangular_cell
20 |
21 |
# Arguments
parser = argparse.ArgumentParser()

# Structure
# The structure file is mandatory: without it the script would hand None to
# ase.io.read() and fail with an unrelated-looking error.
parser.add_argument("--structure-file", type=str, required=True)
parser.add_argument("--supercell", type=int, nargs=3, default=[6, 6, 6])

# Alchemy
parser.add_argument("--switch-pair", type=str, nargs=2, default=["Pb", "Sn"])

# Molecular dynamics: general
# temperature is passed on as temperature_K; pressure is scaled by
# 1.01325 * units.bar; timestep, ttime and ptime are scaled by units.fs.
parser.add_argument("--temperature", type=float, default=300.0)
parser.add_argument("--pressure", type=float, default=1.0)
parser.add_argument("--timestep", type=float, default=2.0)
parser.add_argument("--ttime", type=float, default=25.0)
# ptime is a time constant just like ttime, so it is parsed as a float
# (type=int rejected fractional values and disagreed with the 75.0 default).
parser.add_argument("--ptime", type=float, default=75.0)

# Molecular dynamics: timesteps
# "--npt-equil-stpes" is the original (misspelled) flag. It is kept as an alias
# and the destination attribute name is unchanged, so existing command lines
# and every use of args.npt_equil_stpes keep working.
parser.add_argument(
    "--npt-equil-steps",
    "--npt-equil-stpes",
    dest="npt_equil_stpes",
    type=int,
    default=10000,
)
parser.add_argument("--alchemy-equil-steps", type=int, default=20000)
parser.add_argument("--alchemy-switch-steps", type=int, default=30000)

# Molecular dynamics: output control
parser.add_argument("--output-dir", type=Path, default=Path("results"))
parser.add_argument("--log-interval", type=int, default=1)

# MACE model
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--model", type=str, default="small")

args = parser.parse_args()
# Create the output directory (and any missing parents) up front.
args.output_dir.mkdir(exist_ok=True, parents=True)
54 |
# Read the input cell and replicate it along each lattice vector
unit_cell = ase.io.read(args.structure_file)
atoms = make_supercell(unit_cell, np.diag(args.supercell))

# Attach the universal MACE calculator, then relax positions and cell together
mace_calc = mace_mp(model=args.model, device=args.device, default_dtype="float32")
atoms.calc = mace_calc
cell_filter = ExpCellFilter(atoms)
optimizer = FIRE(cell_filter)
optimizer.run(fmax=0.01, steps=500)
atoms = cell_filter.atoms  # unwrap the filter to get the relaxed structure
initial_atoms = atoms.copy()  # snapshot of the relaxed structure
67 |
68 |
################################################################################
# Cell volume equilibration
################################################################################

# Work on a copy so that `initial_atoms` keeps the relaxed geometry untouched.
atoms = initial_atoms.copy()
# Attach the MACE calculator to the copy. The attribute form matches how the
# calculator is attached for the relaxation step and avoids the legacy
# set_calculator() method.
atoms.calc = mace_calc
# Bulk modulus used to derive the barostat compressibility (1 / bulk_modulus).
bulk_modulus = 100.0 * units.GPa

# Berendsen NPT dynamics at the requested temperature and pressure
dyn = Inhomogeneous_NPTBerendsen(
    atoms,
    timestep=args.timestep * units.fs,
    temperature_K=args.temperature,
    pressure_au=args.pressure * 1.01325 * units.bar,
    taut=args.ttime * units.fs,
    taup=args.ptime * units.fs,
    compressibility_au=1.0 / bulk_modulus,
)
# Draw initial velocities at the target temperature and remove the
# centre-of-mass motion.
MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature)
Stationary(atoms)

# NPT equilibration and volume relaxation; stepping one MD step at a time
# lets tqdm report progress.
for _ in tqdm(range(args.npt_equil_stpes), desc="NPT equil"):
    dyn.run(steps=1)
93 |
94 |
################################################################################
# Alchemical switching
################################################################################

# Define alchemical transformation: every atom of the source element gets one
# AlchemicalPair per end state (source element first, destination second).
src_elem, dst_elem = args.switch_pair
src_Z, dst_Z = ase.data.atomic_numbers[src_elem], ase.data.atomic_numbers[dst_elem]
src_idx = np.where(atoms.get_atomic_numbers() == src_Z)[0]
alchemical_pairs = [
    [AlchemicalPair(atom_index=idx, atomic_number=Z) for idx in src_idx]
    for Z in [src_Z, dst_Z]
]

# Set up the alchemical MACE calculator; the initial weights [1.0, 0.0] put
# all weight on the first (source) group of pairs.
calc = AlchemicalMACECalculator(
    atoms=atoms,
    alchemical_pairs=alchemical_pairs,
    alchemical_weights=[1.0, 0.0],
    device=args.device,
    model=args.model,
)
# Attribute form, matching how the calculator is attached for the initial
# relaxation; avoids the legacy set_calculator() method.
atoms.calc = calc
upper_triangular_cell(atoms)  # for ASE NPT

# NPT alchemical switching
# Barostat factor built from the pressure time constant and the bulk modulus.
ptime = args.ptime * units.fs
pfactor = bulk_modulus * ptime * ptime

dyn = NPT(
    atoms,
    timestep=args.timestep * units.fs,
    temperature_K=args.temperature,
    externalstress=args.pressure * 1.01325 * units.bar,
    ttime=args.ttime * units.fs,
    pfactor=pfactor,
)
131 |
132 | # Define alchemical path
133 | t = np.linspace(0.0, 1.0, args.alchemy_switch_steps)
134 | lambda_steps = t**5 * (70 * t**4 - 315 * t**3 + 540 * t**2 - 420 * t + 126)
135 | lambda_values = [
136 | np.zeros(args.alchemy_equil_steps),
137 | lambda_steps,
138 | np.ones(args.alchemy_equil_steps),
139 | lambda_steps[::-1],
140 | ]
141 | lambda_values = np.concatenate(lambda_values)
142 |
143 | calculate_gradients = [
144 | np.zeros(args.alchemy_equil_steps, dtype=bool),
145 | np.ones(args.alchemy_switch_steps, dtype=bool),
146 | np.zeros(args.alchemy_equil_steps, dtype=bool),
147 | np.ones(args.alchemy_switch_steps, dtype=bool),
148 | ]
149 | calculate_gradients = np.concatenate(calculate_gradients)
150 |
151 |
def get_observables(dynamics, time, lambda_value):
    """Collect per-atom observables for the current MD state.

    Args:
        dynamics: MD driver exposing the simulated system as ``.atoms``.
        time: Simulation time to record with this sample.
        lambda_value: Current value of the alchemical switching parameter.

    Returns:
        dict with the keys ``time``, ``potential``, ``temperature``,
        ``volume``, ``lambda`` and ``lambda_grad``; energy, volume and the
        gradient are divided by the number of atoms.
    """
    system = dynamics.atoms
    num_atoms = len(system)
    # Read the cached results via the public `calc` property rather than the
    # private `_calc` attribute.
    alchemical_grad = system.calc.results["alchemical_grad"]
    # The script sets the weights to [1 - lambda, lambda], so the derivative
    # with respect to lambda is the second gradient minus the first.
    lambda_grad = (alchemical_grad[1] - alchemical_grad[0]) / num_atoms
    return {
        "time": time,
        "potential": system.get_potential_energy() / num_atoms,
        "temperature": system.get_temperature(),
        "volume": system.get_volume() / num_atoms,
        "lambda": lambda_value,
        "lambda_grad": lambda_grad,
    }
164 |
165 |
# Drive the switching schedule one MD step at a time
total_steps = 2 * args.alchemy_equil_steps + 2 * args.alchemy_switch_steps

observables = []
schedule = zip(lambda_values, calculate_gradients)
progress = tqdm(enumerate(schedule), total=total_steps, desc="Alchemical switching")
for step, (lambda_value, grad_enabled) in progress:
    # Update the alchemical weights, the matching atomic masses, and whether
    # the alchemical gradient is requested for this step
    calc.set_alchemical_weights([1 - lambda_value, lambda_value])
    interpolated_masses = calc.get_alchemical_atomic_masses()
    atoms.set_masses(interpolated_masses)
    calc.calculate_alchemical_grad = grad_enabled

    dyn.run(steps=1)
    if step % args.log_interval != 0:
        continue
    time = (step + 1) * args.timestep
    observables.append(get_observables(dyn, time, lambda_value))

# Write the logged observables as a CSV table
df = pd.DataFrame(observables)
df.to_csv(args.output_dir / "observables.csv", index=False)

# Store the masses at both end points of the path for post-processing
end_points = (([1.0, 0.0], "masses_init.npy"), ([0.0, 1.0], "masses_final.npy"))
for end_weights, file_name in end_points:
    calc.set_alchemical_weights(end_weights)
    np.save(args.output_dir / file_name, calc.get_alchemical_atomic_masses())
193 |
--------------------------------------------------------------------------------
/scripts/perovskite_frenkel_ladd.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 |
4 | import ase
5 | import numpy as np
6 | import pandas as pd
7 | from ase import units
8 | from ase.build import make_supercell
9 | from ase.constraints import ExpCellFilter
10 | from ase.md.langevin import Langevin
11 | from ase.md.nptberendsen import Inhomogeneous_NPTBerendsen
12 | from ase.md.velocitydistribution import MaxwellBoltzmannDistribution, Stationary
13 | from ase.optimize import FIRE
14 | from mace.calculators import mace_mp
15 | from pymatgen.io.ase import AseAtomsAdaptor
16 | from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
17 | from tqdm import tqdm
18 |
19 | from alchemical_mace.calculator import FrenkelLaddCalculator, NVTMACECalculator
20 |
21 |
# Arguments
parser = argparse.ArgumentParser()

# Structure
parser.add_argument(
    "--structure-file",
    type=str,
    required=True,  # fail early instead of handing None to ase.io.read
    help="structure file read with ase.io.read",
)
parser.add_argument("--supercell", type=int, nargs=3, default=[6, 6, 6])

# Molecular dynamics: general
# The script multiplies timestep/ttime/ptime by units.fs and pressure by
# 1.01325 * units.bar, so these are given in fs and atm respectively.
parser.add_argument("--temperature", type=float, default=300.0)
parser.add_argument("--pressure", type=float, default=1.0)
parser.add_argument("--timestep", type=float, default=2.0)
parser.add_argument("--ttime", type=float, default=25.0)
# ptime is a time constant like ttime; type=int rejected values such as 62.5
# and disagreed with the float default.
parser.add_argument("--ptime", type=float, default=75.0)

# Molecular dynamics: timesteps
parser.add_argument(
    "--npt-equil-steps",
    "--npt-equil-stpes",  # original misspelled flag, kept for compatibility
    dest="npt_equil_stpes",  # attribute name read by the rest of the script
    type=int,
    default=10000,
)
parser.add_argument("--npt-prod-steps", type=int, default=20000)
parser.add_argument("--nvt-equil-steps", type=int, default=20000)
parser.add_argument("--nvt-prod-steps", type=int, default=30000)
parser.add_argument("--alchemy-equil-steps", type=int, default=20000)
parser.add_argument("--alchemy-switch-steps", type=int, default=30000)

# Molecular dynamics: output control
parser.add_argument("--output-dir", type=Path, default=Path("results"))
parser.add_argument("--log-interval", type=int, default=1)

# MACE model
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--model", type=str, default="small")

args = parser.parse_args()
# Create the output directory (and any missing parents) up front
args.output_dir.mkdir(exist_ok=True, parents=True)
54 |
# Read the input structure and replicate it into the requested supercell
unit_cell = ase.io.read(args.structure_file)
supercell_matrix = np.diag(args.supercell)
atoms = make_supercell(unit_cell, supercell_matrix)

# Attach the universal MACE potential and relax through an ExpCellFilter
mace_calc = mace_mp(default_dtype="float32", device=args.device, model=args.model)
atoms.calc = mace_calc
cell_filter = ExpCellFilter(atoms)
optimizer = FIRE(cell_filter)
optimizer.run(steps=500, fmax=0.01)
atoms = cell_filter.atoms  # unwrap the filter to get the relaxed structure
initial_atoms = atoms.copy()  # snapshot used as the starting point later on
67 |
68 |
################################################################################
# Cell volume equilibration
################################################################################

# Equilibrate a copy so that `initial_atoms` keeps the relaxed geometry
atoms = initial_atoms.copy()
# Assign through the `calc` property: `Atoms.set_calculator` is deprecated in ASE
atoms.calc = mace_calc
bulk_modulus = 100.0 * units.GPa  # fixed estimate used for the barostat

# Equilibration and volume calculation
dyn = Inhomogeneous_NPTBerendsen(
    atoms,
    timestep=args.timestep * units.fs,
    temperature_K=args.temperature,
    # 1.01325 bar = 1 atm, i.e. --pressure is interpreted in atm
    pressure_au=args.pressure * 1.01325 * units.bar,
    taut=args.ttime * units.fs,
    taup=args.ptime * units.fs,
    compressibility_au=1.0 / bulk_modulus,
)
MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature)
Stationary(atoms)

# NPT equilibration, then a production run that samples the cell parameters
cellpar_traj = []
for step in tqdm(range(args.npt_equil_stpes), desc="NPT equil"):
    dyn.run(steps=1)
for step in tqdm(range(args.npt_prod_steps), desc="NPT prod"):
    dyn.run(steps=1)
    if step % args.log_interval == 0:
        cellpar_traj.append(atoms.get_cell().cellpar())
if not cellpar_traj:
    # Without samples the average below is undefined and used to fail with an
    # unrelated indexing error.
    raise ValueError("no cell parameters were sampled; use --npt-prod-steps > 0")
abc_new = np.mean(cellpar_traj, axis=0)[:3]  # mean a, b, c; angles are dropped

# Scale the initial cell to match the average volume
# NOTE(review): np.diag(abc_new) builds an orthogonal cell, so this assumes an
# orthorhombic supercell - confirm before using other lattices.
atoms = initial_atoms
atoms.set_cell(np.diag(abc_new), scale_atoms=True)
atoms.calc = mace_calc

# Relax the atomic positions
optimizer = FIRE(atoms)
optimizer.run(fmax=0.01, steps=500)
initial_atoms = atoms.copy()  # save the initial structure
109 |
110 |
################################################################################
# MSD calculation
################################################################################

# Reference positions against which the displacements are measured
initial_positions = atoms.get_positions()
# Original author's note: the reversible scaling MACE calculator is used with a
# fixed scale of 1.0 since the stress calculation can be turned off.
# NOTE(review): the class used here is NVTMACECalculator - confirm the note
# still describes it.
calc = NVTMACECalculator(device=args.device, model=args.model)
# Assign through the `calc` property: `Atoms.set_calculator` is deprecated in ASE
atoms.calc = calc

# NVT MSD calculation
dyn = Langevin(
    atoms,
    timestep=args.timestep * units.fs,
    temperature_K=args.temperature,
    friction=1 / (args.ttime * units.fs),
)
MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature)
Stationary(atoms)

# Equilibrate first, then accumulate each atom's squared displacement from its
# reference position over the production run
for step in tqdm(range(args.nvt_equil_steps), desc="NVT equil"):
    dyn.run(steps=1)
squared_disp = np.zeros(len(atoms))
for step in tqdm(range(args.nvt_prod_steps), desc="NVT prod"):
    dyn.run(steps=1)
    squared_disp += np.sum((atoms.get_positions() - initial_positions) ** 2, axis=1)
mean_squared_disp = squared_disp / args.nvt_prod_steps

# Calculate spring constants (k = 3 kB T / <dr^2>) and average over
# symmetrically equivalent atoms
spring_constants = 3.0 * units.kB * args.temperature / mean_squared_disp
structure = AseAtomsAdaptor.get_structure(initial_atoms)
sga = SpacegroupAnalyzer(structure)
equivalent_indices = sga.get_symmetrized_structure().equivalent_indices
for indices in equivalent_indices:
    spring_constants[indices] = np.mean(spring_constants[indices])

np.save(args.output_dir / "spring_constants.npy", spring_constants)
np.save(args.output_dir / "masses.npy", atoms.get_masses())
150 |
151 |
################################################################################
# Frenkel-Ladd calculation
################################################################################

# Start from the relaxed reference structure, with springs anchored at the
# reference positions recorded before the MSD run
atoms = initial_atoms.copy()
calc = FrenkelLaddCalculator(
    spring_constants=spring_constants,
    initial_positions=initial_positions,
    device=args.device,
    model=args.model,
)
# Assign through the `calc` property: `Atoms.set_calculator` is deprecated in ASE
atoms.calc = calc

# NVT Frenkel-Ladd calculation
dyn = Langevin(
    atoms,
    timestep=args.timestep * units.fs,
    temperature_K=args.temperature,
    friction=1 / (args.ttime * units.fs),
)
MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature)
Stationary(atoms)

# Define Frenkel-Ladd path: hold at lambda=0, switch 0 -> 1, hold at lambda=1,
# then switch back 1 -> 0. The polynomial evaluates to 0 at t=0 and 1 at t=1.
t = np.linspace(0.0, 1.0, args.alchemy_switch_steps)
lambda_steps = t**5 * (70 * t**4 - 315 * t**3 + 540 * t**2 - 420 * t + 126)
lambda_values = [
    np.zeros(args.alchemy_equil_steps),
    lambda_steps,
    np.ones(args.alchemy_equil_steps),
    lambda_steps[::-1],
]
lambda_values = np.concatenate(lambda_values)
185 |
186 |
def get_observables(dynamics, time, lambda_value):
    """Collect per-atom observables for the current MD state.

    Args:
        dynamics: MD driver exposing the simulated system as ``.atoms``.
        time: Simulation time to record with this sample.
        lambda_value: Current value of the switching parameter.

    Returns:
        dict with the keys ``time``, ``potential``, ``temperature``,
        ``volume``, ``lambda`` and ``lambda_grad``; energy, volume and the
        gradient are divided by the number of atoms.
    """
    system = dynamics.atoms
    num_atoms = len(system)
    # Use the public `calc` property rather than the private `_calc` attribute;
    # "energy_diff" is the calculator result recorded as the lambda gradient.
    energy_diff = system.calc.results["energy_diff"]
    return {
        "time": time,
        "potential": system.get_potential_energy() / num_atoms,
        "temperature": system.get_temperature(),
        "volume": system.get_volume() / num_atoms,
        "lambda": lambda_value,
        "lambda_grad": energy_diff / num_atoms,
    }
197 |
198 |
# Drive the Frenkel-Ladd schedule one MD step at a time
num_equil = args.alchemy_equil_steps
total_steps = 2 * num_equil + 2 * args.alchemy_switch_steps
calc.compute_mace = False  # MACE stays off during the first equilibration leg

observables = []
progress = tqdm(enumerate(lambda_values), total=total_steps, desc="Frenkel-Ladd")
for step, lambda_value in progress:
    if step == num_equil:
        # Spring-only equilibration is over: switch MACE on
        calc.compute_mace = True
    calc.set_weights(lambda_value)

    dyn.run(steps=1)
    if step % args.log_interval != 0:
        continue
    time = (step + 1) * args.timestep
    observables.append(get_observables(dyn, time, lambda_value))

# Write the logged observables as a CSV table
df = pd.DataFrame(observables)
df.to_csv(args.output_dir / "observables.csv", index=False)
218 |
--------------------------------------------------------------------------------
/scripts/vacancy_frenkel_ladd.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 |
4 | import ase
5 | import numpy as np
6 | import pandas as pd
7 | from ase import units
8 | from ase.build import make_supercell
9 | from ase.constraints import ExpCellFilter
10 | from ase.md.langevin import Langevin
11 | from ase.md.npt import NPT
12 | from ase.md.nptberendsen import Inhomogeneous_NPTBerendsen
13 | from ase.md.velocitydistribution import MaxwellBoltzmannDistribution, Stationary
14 | from ase.optimize import FIRE
15 | from mace.calculators import mace_mp
16 | from pymatgen.io.ase import AseAtomsAdaptor
17 | from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
18 | from tqdm import tqdm
19 |
20 | from alchemical_mace.calculator import (
21 | DefectFrenkelLaddCalculator,
22 | FrenkelLaddCalculator,
23 | NVTMACECalculator,
24 | )
25 | from alchemical_mace.utils import upper_triangular_cell
26 |
27 |
# Arguments
parser = argparse.ArgumentParser()

# Structure
parser.add_argument(
    "--structure-file",
    type=str,
    required=True,  # fail early instead of handing None to ase.io.read
    help="structure file read with ase.io.read",
)
parser.add_argument("--supercell", type=int, nargs=3, default=[5, 5, 5])

# Molecular dynamics: general
# The script multiplies timestep/ttime/ptime by units.fs and pressure by
# 1.01325 * units.bar, so these are given in fs and atm respectively.
parser.add_argument("--temperature", type=float, default=300.0)
parser.add_argument("--pressure", type=float, default=1.0)
parser.add_argument("--timestep", type=float, default=2.0)
parser.add_argument("--ttime", type=float, default=25.0)
# ptime is a time constant like ttime; type=int rejected values such as 62.5
# and disagreed with the float default.
parser.add_argument("--ptime", type=float, default=75.0)

# Molecular dynamics: timesteps
parser.add_argument(
    "--npt-equil-steps",
    "--npt-equil-stpes",  # original misspelled flag, kept for compatibility
    dest="npt_equil_stpes",  # attribute name read by the rest of the script
    type=int,
    default=10000,
)
parser.add_argument("--npt-prod-steps", type=int, default=20000)
parser.add_argument("--nvt-equil-steps", type=int, default=20000)
parser.add_argument("--nvt-prod-steps", type=int, default=30000)
parser.add_argument("--alchemy-equil-steps", type=int, default=20000)
parser.add_argument("--alchemy-switch-steps", type=int, default=30000)

# Molecular dynamics: output control
parser.add_argument("--output-dir", type=Path, default=Path("results"))
parser.add_argument("--log-interval", type=int, default=1)

# MACE model
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--model", type=str, default="small")

args = parser.parse_args()
# Create the output directory (and any missing parents) up front
args.output_dir.mkdir(exist_ok=True, parents=True)
60 |
61 |
################################################################################
# Energy minimization: defect-free structure
################################################################################

# Read the input structure and replicate it into the requested supercell
unit_cell = ase.io.read(args.structure_file)
supercell_matrix = np.diag(args.supercell)
atoms = make_supercell(unit_cell, supercell_matrix)

# Attach the universal MACE potential and relax through an ExpCellFilter
mace_calc = mace_mp(default_dtype="float32", device=args.device, model=args.model)
atoms.calc = mace_calc
cell_filter = ExpCellFilter(atoms)
optimizer = FIRE(cell_filter)
optimizer.run(steps=500, fmax=0.01)
atoms = cell_filter.atoms  # unwrap the filter to get the relaxed structure
initial_atoms = atoms.copy()  # snapshot used as the starting point later on
78 |
79 |
################################################################################
# Cell volume equilibration: defect-free structure
################################################################################

# Equilibrate a copy so that `initial_atoms` keeps the relaxed geometry
atoms = initial_atoms.copy()
# Assign through the `calc` property: `Atoms.set_calculator` is deprecated in ASE
atoms.calc = mace_calc
bulk_modulus = 100.0 * units.GPa  # fixed estimate used for the barostat

# Equilibration and volume calculation
dyn = Inhomogeneous_NPTBerendsen(
    atoms,
    timestep=args.timestep * units.fs,
    temperature_K=args.temperature,
    # 1.01325 bar = 1 atm, i.e. --pressure is interpreted in atm
    pressure_au=args.pressure * 1.01325 * units.bar,
    taut=args.ttime * units.fs,
    taup=args.ptime * units.fs,
    compressibility_au=1.0 / bulk_modulus,
)
MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature)
Stationary(atoms)

# NPT equilibration, then a production run that samples the cell parameters
cellpar_traj = []
for step in tqdm(range(args.npt_equil_stpes), desc="NPT equil"):
    dyn.run(steps=1)
for step in tqdm(range(args.npt_prod_steps), desc="NPT prod"):
    dyn.run(steps=1)
    if step % args.log_interval == 0:
        cellpar_traj.append(atoms.get_cell().cellpar())
if not cellpar_traj:
    # Without samples the average below is undefined and used to fail with an
    # unrelated indexing error.
    raise ValueError("no cell parameters were sampled; use --npt-prod-steps > 0")
abc_new = np.mean(cellpar_traj, axis=0)[:3]  # mean a, b, c; angles are dropped

# Scale the initial cell to match the average volume
# NOTE(review): np.diag(abc_new) builds an orthogonal cell, so this assumes an
# orthorhombic supercell - confirm before using other lattices.
atoms = initial_atoms
atoms.set_cell(np.diag(abc_new), scale_atoms=True)
atoms.calc = mace_calc

# Relax the atomic positions
optimizer = FIRE(atoms)
optimizer.run(fmax=0.01, steps=500)
initial_atoms = atoms.copy()  # save the initial structure
120 |
121 |
################################################################################
# MSD calculation: defect-free structure
################################################################################

# Reference positions against which the displacements are measured
initial_positions = atoms.get_positions()
# Original author's note: the reversible scaling MACE calculator is used with a
# fixed scale of 1.0 since the stress calculation can be turned off.
# NOTE(review): the class used here is NVTMACECalculator - confirm the note
# still describes it.
calc = NVTMACECalculator(device=args.device, model=args.model)
# Assign through the `calc` property: `Atoms.set_calculator` is deprecated in ASE
atoms.calc = calc

# NVT MSD calculation
dyn = Langevin(
    atoms,
    timestep=args.timestep * units.fs,
    temperature_K=args.temperature,
    friction=1 / (args.ttime * units.fs),
)
MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature)
Stationary(atoms)

# Equilibrate first, then accumulate each atom's squared displacement from its
# reference position over the production run
for step in tqdm(range(args.nvt_equil_steps), desc="NVT equil"):
    dyn.run(steps=1)
squared_disp = np.zeros(len(atoms))
for step in tqdm(range(args.nvt_prod_steps), desc="NVT prod"):
    dyn.run(steps=1)
    squared_disp += np.sum((atoms.get_positions() - initial_positions) ** 2, axis=1)
mean_squared_disp = squared_disp / args.nvt_prod_steps

# Calculate spring constants (k = 3 kB T / <dr^2>) and average over
# symmetrically equivalent atoms
spring_constants = 3.0 * units.kB * args.temperature / mean_squared_disp
structure = AseAtomsAdaptor.get_structure(initial_atoms)
sga = SpacegroupAnalyzer(structure)
equivalent_indices = sga.get_symmetrized_structure().equivalent_indices
for indices in equivalent_indices:
    spring_constants[indices] = np.mean(spring_constants[indices])

np.save(args.output_dir / "spring_constants.npy", spring_constants)
np.save(args.output_dir / "masses.npy", atoms.get_masses())
161 |
162 |
################################################################################
# Frenkel-Ladd calculation: defect-free structure
################################################################################

# Start from the relaxed reference structure, with springs anchored at the
# reference positions recorded before the MSD run
atoms = initial_atoms.copy()
calc = FrenkelLaddCalculator(
    spring_constants=spring_constants,
    initial_positions=initial_positions,
    device=args.device,
    model=args.model,
)
# Assign through the `calc` property: `Atoms.set_calculator` is deprecated in ASE
atoms.calc = calc

# NVT Frenkel-Ladd calculation
dyn = Langevin(
    atoms,
    timestep=args.timestep * units.fs,
    temperature_K=args.temperature,
    friction=1 / (args.ttime * units.fs),
)
MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature)
Stationary(atoms)

# Define Frenkel-Ladd path: hold at lambda=0, switch 0 -> 1, hold at lambda=1,
# then switch back 1 -> 0. The polynomial evaluates to 0 at t=0 and 1 at t=1.
t = np.linspace(0.0, 1.0, args.alchemy_switch_steps)
lambda_steps = t**5 * (70 * t**4 - 315 * t**3 + 540 * t**2 - 420 * t + 126)
lambda_values = [
    np.zeros(args.alchemy_equil_steps),
    lambda_steps,
    np.ones(args.alchemy_equil_steps),
    lambda_steps[::-1],
]
lambda_values = np.concatenate(lambda_values)
196 |
197 |
def get_observables(dynamics, time, lambda_value):
    """Collect per-atom observables for the current MD state.

    Args:
        dynamics: MD driver exposing the simulated system as ``.atoms``.
        time: Simulation time to record with this sample.
        lambda_value: Current value of the switching parameter.

    Returns:
        dict with the keys ``time``, ``potential``, ``temperature``,
        ``volume``, ``lambda`` and ``lambda_grad``; energy, volume and the
        gradient are divided by the number of atoms.
    """
    system = dynamics.atoms
    num_atoms = len(system)
    # Use the public `calc` property rather than the private `_calc` attribute;
    # "energy_diff" is the calculator result recorded as the lambda gradient.
    energy_diff = system.calc.results["energy_diff"]
    return {
        "time": time,
        "potential": system.get_potential_energy() / num_atoms,
        "temperature": system.get_temperature(),
        "volume": system.get_volume() / num_atoms,
        "lambda": lambda_value,
        "lambda_grad": energy_diff / num_atoms,
    }
208 |
209 |
# Drive the Frenkel-Ladd schedule one MD step at a time
num_equil = args.alchemy_equil_steps
total_steps = 2 * num_equil + 2 * args.alchemy_switch_steps
calc.compute_mace = False  # MACE stays off during the first equilibration leg

observables = []
progress = tqdm(enumerate(lambda_values), total=total_steps, desc="Frenkel-Ladd")
for step, lambda_value in progress:
    if step == num_equil:
        # Spring-only equilibration is over: switch MACE on
        calc.compute_mace = True
    calc.set_weights(lambda_value)

    dyn.run(steps=1)
    if step % args.log_interval != 0:
        continue
    time = (step + 1) * args.timestep
    observables.append(get_observables(dyn, time, lambda_value))

# Write the logged observables as a CSV table
df = pd.DataFrame(observables)
df.to_csv(args.output_dir / "observables.csv", index=False)
229 |
230 |
################################################################################
# Cell volume equilibration: structure with a defect
################################################################################

atoms = initial_atoms.copy()

# Create a vacancy by removing the atom at the middle index; `atom_mask` marks
# the atoms that remain so per-atom arrays of the defect-free structure can be
# filtered to match.
vacancy_index = len(atoms) // 2
atom_mask = np.ones(len(atoms), dtype=bool)
atom_mask[vacancy_index] = False
del atoms[vacancy_index]

# Assign through the `calc` property: `Atoms.set_calculator` is deprecated in ASE
atoms.calc = mace_calc

# Equilibration and volume calculation (reuses `bulk_modulus` from the
# defect-free equilibration)
dyn = Inhomogeneous_NPTBerendsen(
    atoms,
    timestep=args.timestep * units.fs,
    temperature_K=args.temperature,
    # 1.01325 bar = 1 atm, i.e. --pressure is interpreted in atm
    pressure_au=args.pressure * 1.01325 * units.bar,
    taut=args.ttime * units.fs,
    taup=args.ptime * units.fs,
    compressibility_au=1.0 / bulk_modulus,
)
MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature)
Stationary(atoms)

# NPT equilibration, then a production run that samples the cell parameters
cellpar_traj = []
for step in tqdm(range(args.npt_equil_stpes), desc="NPT equil"):
    dyn.run(steps=1)
for step in tqdm(range(args.npt_prod_steps), desc="NPT prod"):
    dyn.run(steps=1)
    if step % args.log_interval == 0:
        cellpar_traj.append(atoms.get_cell().cellpar())
if not cellpar_traj:
    # Without samples the average below is undefined and used to fail with an
    # unrelated indexing error.
    raise ValueError("no cell parameters were sampled; use --npt-prod-steps > 0")
abc_new = np.mean(cellpar_traj, axis=0)[:3]  # mean a, b, c; angles are dropped

# Scale the initial cell to match the average volume, then remove the same atom
# NOTE(review): np.diag(abc_new) builds an orthogonal cell, so this assumes an
# orthorhombic supercell - confirm before using other lattices.
atoms = initial_atoms.copy()
atoms.set_cell(np.diag(abc_new), scale_atoms=True)
del atoms[vacancy_index]
atoms.calc = mace_calc

# Relax the atomic positions
optimizer = FIRE(atoms)
optimizer.run(fmax=0.01, steps=500)
277 |
278 |
################################################################################
# Frenkel-Ladd calculation: structure with a defect
################################################################################

# Spring constants and anchor positions come from the defect-free structure,
# filtered with `atom_mask` to drop the removed atom.
# NOTE(review): the anchors are the defect-free positions, not the positions of
# the rescaled and relaxed defect cell - confirm this is intended.
calc = FrenkelLaddCalculator(
    spring_constants=spring_constants[atom_mask],
    initial_positions=initial_positions[atom_mask],
    device=args.device,
    model=args.model,
)
# Assign through the `calc` property: `Atoms.set_calculator` is deprecated in ASE
atoms.calc = calc

# NVT Frenkel-Ladd calculation
dyn = Langevin(
    atoms,
    timestep=args.timestep * units.fs,
    temperature_K=args.temperature,
    friction=1 / (args.ttime * units.fs),
)
MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature)
Stationary(atoms)

# Simulation loop; `lambda_values` and `get_observables` are the ones defined
# for the defect-free Frenkel-Ladd run earlier in the script
calc.compute_mace = False
total_steps = 2 * args.alchemy_equil_steps + 2 * args.alchemy_switch_steps

observables = []
for step in tqdm(range(total_steps), desc="Frenkel-Ladd"):
    if step == args.alchemy_equil_steps:  # turn on MACE after spring equilibration
        calc.compute_mace = True
    lambda_value = lambda_values[step]
    calc.set_weights(lambda_value)

    dyn.run(steps=1)
    if step % args.log_interval == 0:
        time = (step + 1) * args.timestep
        observables.append(get_observables(dyn, time, lambda_value))

# Save observables
df = pd.DataFrame(observables)
df.to_csv(args.output_dir / "observables_defect.csv", index=False)
320 |
321 |
################################################################################
# Cell volume equilibration: partial Frenkel-Ladd calculation
################################################################################

# Start again from the defect-free reference structure
atoms = initial_atoms.copy()
# Assign through the `calc` property: `Atoms.set_calculator` is deprecated in ASE
atoms.calc = mace_calc

# Equilibration and volume calculation (reuses `bulk_modulus` from the
# defect-free equilibration)
dyn = Inhomogeneous_NPTBerendsen(
    atoms,
    timestep=args.timestep * units.fs,
    temperature_K=args.temperature,
    # 1.01325 bar = 1 atm, i.e. --pressure is interpreted in atm
    pressure_au=args.pressure * 1.01325 * units.bar,
    taut=args.ttime * units.fs,
    taup=args.ptime * units.fs,
    compressibility_au=1.0 / bulk_modulus,
)
MaxwellBoltzmannDistribution(atoms, temperature_K=args.temperature)
Stationary(atoms)

# NPT equilibration and volume relaxation (single steps so tqdm shows progress)
for step in tqdm(range(args.npt_equil_stpes), desc="NPT equil"):
    dyn.run(steps=1)
345 |
346 |
################################################################################
# Alchemical switching
################################################################################

# Set up the partial Frenkel-Ladd calculation for the atom at `vacancy_index`,
# using that atom's spring constant from the defect-free run
calc = DefectFrenkelLaddCalculator(
    atoms=atoms,
    spring_constant=spring_constants[vacancy_index],
    defect_index=vacancy_index,
    device=args.device,
    model=args.model,
)
# Assign through the `calc` property: `Atoms.set_calculator` is deprecated in ASE
atoms.calc = calc
upper_triangular_cell(atoms)  # for ASE NPT

# NPT alchemical switching
ptime = args.ptime * units.fs
pfactor = bulk_modulus * ptime * ptime

dyn = NPT(
    atoms,
    timestep=args.timestep * units.fs,
    temperature_K=args.temperature,
    # 1.01325 bar = 1 atm, i.e. --pressure is interpreted in atm
    externalstress=args.pressure * 1.01325 * units.bar,
    ttime=args.ttime * units.fs,
    pfactor=pfactor,
)

# Define alchemical path: hold at lambda=0, switch 0 -> 1, hold at lambda=1,
# then switch back 1 -> 0. The polynomial evaluates to 0 at t=0 and 1 at t=1.
t = np.linspace(0.0, 1.0, args.alchemy_switch_steps)
lambda_steps = t**5 * (70 * t**4 - 315 * t**3 + 540 * t**2 - 420 * t + 126)
lambda_values = [
    np.zeros(args.alchemy_equil_steps),
    lambda_steps,
    np.ones(args.alchemy_equil_steps),
    lambda_steps[::-1],
]
lambda_values = np.concatenate(lambda_values)

# Per-step flags aligned with `lambda_values`: the alchemical gradient is only
# requested during the two switching legs, not while holding lambda fixed
calculate_gradients = [
    np.zeros(args.alchemy_equil_steps, dtype=bool),
    np.ones(args.alchemy_switch_steps, dtype=bool),
    np.zeros(args.alchemy_equil_steps, dtype=bool),
    np.ones(args.alchemy_switch_steps, dtype=bool),
]
calculate_gradients = np.concatenate(calculate_gradients)
393 |
394 |
def get_observables(dynamics, time, lambda_value):
    """Collect per-atom observables for the current MD state.

    Args:
        dynamics: MD driver exposing the simulated system as ``.atoms``.
        time: Simulation time to record with this sample.
        lambda_value: Current value of the alchemical switching parameter.

    Returns:
        dict with the keys ``time``, ``potential``, ``temperature``,
        ``volume``, ``lambda`` and ``lambda_grad``; energy, volume and the
        gradient are divided by the number of atoms.
    """
    system = dynamics.atoms
    num_atoms = len(system)
    # Read the cached results via the public `calc` property rather than the
    # private `_calc` attribute.
    alchemical_grad = system.calc.results["alchemical_grad"]
    return {
        "time": time,
        "potential": system.get_potential_energy() / num_atoms,
        "temperature": system.get_temperature(),
        "volume": system.get_volume() / num_atoms,
        "lambda": lambda_value,
        "lambda_grad": alchemical_grad / num_atoms,
    }
406 |
407 |
# Drive the switching schedule one MD step at a time
total_steps = 2 * args.alchemy_equil_steps + 2 * args.alchemy_switch_steps

observables = []
schedule = zip(lambda_values, calculate_gradients)
progress = tqdm(enumerate(schedule), total=total_steps, desc="Alchemical switching")
for step, (lambda_value, grad_enabled) in progress:
    # Update the alchemical weight and whether the alchemical gradient is
    # requested for this step
    calc.set_alchemical_weight(lambda_value)
    calc.calculate_alchemical_grad = grad_enabled

    dyn.run(steps=1)
    if step % args.log_interval != 0:
        continue
    time = (step + 1) * args.timestep
    observables.append(get_observables(dyn, time, lambda_value))

# Write the logged observables as a CSV table
df = pd.DataFrame(observables)
df.to_csv(args.output_dir / "observables_FL.csv", index=False)
428 |
--------------------------------------------------------------------------------