├── .gitignore ├── LICENSE ├── README.md ├── pyproject.toml └── src └── flashmd ├── __init__.py ├── ase ├── __init__.py ├── bussi.py ├── langevin.py ├── npt.py └── velocity_verlet.py ├── ipi.py ├── models.py └── stepper.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to 
include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # Ruff stuff: 171 | .ruff_cache/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | FlashMD: universal long-stride molecular dynamics 2 | ================================================= 3 | 4 | This repository contains custom integrators to run MD trajectories with FlashMD models. These models are 5 | designed to learn and predict molecular dynamics trajectories using long strides, therefore allowing 6 | very large time steps. Before using this method, make sure you are aware of its limitations, which are 7 | discussed in [this preprint](http://arxiv.org/abs/2505.19350). 
8 | 
9 | Quickstart
10 | ----------
11 | 
12 | You can install the package with
13 | 
14 | ```bash
15 | pip install flashmd
16 | ```
17 | 
18 | After installation, you can run accelerated molecular dynamics with ASE as follows:
19 | 
20 | ```py
21 | import ase.build
22 | import ase.units
23 | from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
24 | import torch
25 | from pet_mad.calculator import PETMADCalculator
26 | 
27 | from flashmd import get_universal_model
28 | from flashmd.ase.langevin import Langevin
29 | 
30 | 
31 | # Create a structure and initialize velocities
32 | atoms = ase.build.bulk("Al", "fcc", cubic=True)
33 | MaxwellBoltzmannDistribution(atoms, temperature_K=300)
34 | 
35 | # Load models
36 | device="cuda" if torch.cuda.is_available() else "cpu"
37 | calculator = PETMADCalculator("1.0.1", device=device)
38 | atoms.calc = calculator
39 | model = get_universal_model(16)  # 16 fs model; also available: 1, 4, 8, 32, 64 fs
40 | model = model.to(device)
41 | 
42 | # Run MD
43 | dyn = Langevin(
44 |     atoms=atoms,
45 |     timestep=16*ase.units.fs,
46 |     temperature_K=300,
47 |     time_constant=100*ase.units.fs,
48 |     model=model,
49 |     device=device
50 | )
51 | dyn.run(1000)
52 | ```
53 | 
54 | Other available integrators:
55 | 
56 | ```py
57 | from flashmd.ase.velocity_verlet import VelocityVerlet
58 | from flashmd.ase.bussi import Bussi
59 | ```
60 | 
61 | Disclaimer
62 | ----------
63 | 
64 | This is experimental software and should only be used if you know what you're doing.
65 | We recommend using the i-PI integrators for any serious work, and in particular for
66 | constant-pressure (NpT) molecular dynamics. See
67 | [this cookbook recipe](https://atomistic-cookbook.org/examples/flashmd/flashmd-demo.html)
68 | for a usage example.
69 | Given that the main issue we observe in direct MD trajectories is loss of equipartition
70 | of energy between different degrees of freedom, we recommend using a local Langevin
71 | thermostat and monitoring the temperature of different atomic types or different
72 | parts of the simulated system.
73 | 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "flashmd"
3 | version = "0.1.2"
4 | requires-python = ">=3.9"
5 | 
6 | readme = "README.md"
7 | license = {text = "Apache-2.0"}
8 | description = "Accelerated molecular dynamics with large-time-step predictions"
9 | authors = [{name = "flashmd developers"}]
10 | 
11 | dependencies = [
12 |     "ase",
13 |     "pet-mad>=1.2.0",
14 |     "huggingface_hub",
15 | ]
16 | 
17 | keywords = ["machine learning", "molecular modeling", "molecular dynamics"]
18 | classifiers = [
19 |     "Development Status :: 4 - Beta",
20 |     "Intended Audience :: Science/Research",
21 |     "License :: OSI Approved :: Apache Software License",
22 |     "Operating System :: POSIX",
23 |     "Operating System :: MacOS :: MacOS X",
24 |     "Operating System :: Microsoft :: Windows",
25 |     "Programming Language :: Python",
26 |     "Programming Language :: Python :: 3",
27 |     "Topic :: Scientific/Engineering",
28 |     "Topic :: Scientific/Engineering :: Bio-Informatics",
29 |     "Topic :: Scientific/Engineering :: Chemistry",
30 |     "Topic :: Scientific/Engineering :: Physics",
31 |     "Topic :: Software Development :: Libraries",
32 |     "Topic :: Software Development :: Libraries :: Python Modules",
33 | ]
34 | 
35 | # TODO: add project URLs
36 | 
37 | [build-system]
38 | requires = [
39 |     "setuptools",
40 |     "wheel",
41 | ]
42 | build-backend = "setuptools.build_meta"
43 | 
44 | [tool.setuptools.packages.find]
45 | where = ["src"]
46 | 
47 | [tool.ruff]
48 | line-length = 88
49 | 
-------------------------------------------------------------------------------- /src/flashmd/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import get_universal_model as get_universal_model 2 | 3 | import warnings 4 | 5 | warnings.filterwarnings("ignore", category=UserWarning, message="custom data") 6 | -------------------------------------------------------------------------------- /src/flashmd/ase/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lab-cosmo/flashmd/1189ebe1d5ebab650a2b4eb87b42f680298d1481/src/flashmd/ase/__init__.py -------------------------------------------------------------------------------- /src/flashmd/ase/bussi.py: -------------------------------------------------------------------------------- 1 | from .velocity_verlet import VelocityVerlet 2 | import ase.units 3 | from typing import List 4 | from metatomic.torch import AtomisticModel 5 | import torch 6 | import ase 7 | import numpy as np 8 | 9 | 10 | class Bussi(VelocityVerlet): 11 | def __init__( 12 | self, 13 | atoms: ase.Atoms, 14 | timestep: float, 15 | temperature_K: float, 16 | model: AtomisticModel | List[AtomisticModel], 17 | time_constant: float = 10.0 * ase.units.fs, 18 | device: str | torch.device = "auto", 19 | rescale_energy: bool = True, 20 | **kwargs, 21 | ): 22 | super().__init__(atoms, timestep, model, device, rescale_energy, **kwargs) 23 | 24 | self.temperature_K = temperature_K 25 | self.time_constant = time_constant 26 | 27 | def step(self): 28 | self.apply_bussi_half_step() 29 | super().step() 30 | self.apply_bussi_half_step() 31 | 32 | def apply_bussi_half_step(self): 33 | old_kinetic_energy = self.atoms.get_kinetic_energy() 34 | n_degrees_of_freedom = 3 * len(self.atoms) 35 | target_kinetic_energy = ( 36 | 0.5 * ase.units.kB * self.temperature_K * n_degrees_of_freedom 37 | ) 38 | 39 | exp_term = np.exp(-0.5 * self.dt / 
self.time_constant) 40 | energy_scaling_term = ( 41 | (1.0 - exp_term) 42 | * target_kinetic_energy 43 | / old_kinetic_energy 44 | / n_degrees_of_freedom 45 | ) 46 | r = np.random.randn(n_degrees_of_freedom) 47 | alpha_sq = ( 48 | exp_term 49 | + energy_scaling_term * np.sum(r**2) 50 | + 2.0 * r[0] * np.sqrt(exp_term * energy_scaling_term) 51 | ) 52 | alpha = np.sqrt(alpha_sq) 53 | 54 | momenta = self.atoms.get_momenta() 55 | self.atoms.set_momenta(alpha * momenta) 56 | -------------------------------------------------------------------------------- /src/flashmd/ase/langevin.py: -------------------------------------------------------------------------------- 1 | from .velocity_verlet import VelocityVerlet 2 | import ase.units 3 | from typing import List 4 | 5 | # from ..utils.pretrained import load_pretrained_models 6 | from metatomic.torch import AtomisticModel 7 | import torch 8 | import ase 9 | import numpy as np 10 | 11 | 12 | class Langevin(VelocityVerlet): 13 | def __init__( 14 | self, 15 | atoms: ase.Atoms, 16 | timestep: float, 17 | temperature_K: float, 18 | model: AtomisticModel | List[AtomisticModel], 19 | time_constant: float = 100.0 * ase.units.fs, 20 | device: str | torch.device = "auto", 21 | rescale_energy: bool = True, 22 | **kwargs, 23 | ): 24 | super().__init__(atoms, timestep, model, device, rescale_energy, **kwargs) 25 | 26 | self.temperature_K = temperature_K 27 | self.friction = 1.0 / time_constant 28 | 29 | def step(self): 30 | self.apply_langevin_half_step() 31 | super().step() 32 | self.apply_langevin_half_step() 33 | 34 | def apply_langevin_half_step(self): 35 | old_momenta = self.atoms.get_momenta() 36 | new_momenta = np.exp(-self.friction * 0.5 * self.dt) * old_momenta + np.sqrt( 37 | 1.0 - np.exp(-self.friction * self.dt) 38 | ) * np.sqrt( 39 | ase.units.kB * self.temperature_K * self.atoms.get_masses()[:, None] 40 | ) * np.random.randn(*old_momenta.shape) 41 | self.atoms.set_momenta(new_momenta) 42 | 
-------------------------------------------------------------------------------- /src/flashmd/ase/npt.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lab-cosmo/flashmd/1189ebe1d5ebab650a2b4eb87b42f680298d1481/src/flashmd/ase/npt.py -------------------------------------------------------------------------------- /src/flashmd/ase/velocity_verlet.py: -------------------------------------------------------------------------------- 1 | from ase.md.md import MolecularDynamics 2 | from typing import List 3 | from metatomic.torch import AtomisticModel 4 | from metatensor.torch import Labels, TensorBlock, TensorMap 5 | import ase.units 6 | import torch 7 | from metatomic.torch.ase_calculator import _ase_to_torch_data 8 | from metatomic.torch import System 9 | import ase 10 | from ..stepper import FlashMDStepper 11 | import numpy as np 12 | 13 | 14 | class VelocityVerlet(MolecularDynamics): 15 | def __init__( 16 | self, 17 | atoms: ase.Atoms, 18 | timestep: float, 19 | model: AtomisticModel | List[AtomisticModel], 20 | device: str | torch.device = "auto", 21 | rescale_energy: bool = True, 22 | **kwargs, 23 | ): 24 | super().__init__(atoms, timestep, **kwargs) 25 | 26 | models = model if isinstance(model, list) else [model] 27 | capabilities = models[0].capabilities() 28 | 29 | base_timestep = float(models[0].module.base_time_step) * ase.units.fs 30 | 31 | n_time_steps = int( 32 | [k for k in capabilities.outputs.keys() if "mtt::delta_" in k][0].split( 33 | "_" 34 | )[1] 35 | ) 36 | if n_time_steps != self.dt / base_timestep: 37 | raise ValueError( 38 | f"Mismatch between timestep ({self.dt}) and model timestep ({base_timestep})." 
39 | ) 40 | 41 | if device == "auto": 42 | device = "cuda" if torch.cuda.is_available() else "cpu" 43 | else: 44 | device = device 45 | self.device = torch.device(device) 46 | self.dtype = getattr(torch, capabilities.dtype) 47 | 48 | self.stepper = FlashMDStepper(models, n_time_steps, self.device) 49 | self.rescale_energy = rescale_energy 50 | 51 | def step(self): 52 | if self.rescale_energy: 53 | old_energy = self.atoms.get_total_energy() 54 | 55 | system = _convert_atoms_to_system( 56 | self.atoms, device=self.device, dtype=self.dtype 57 | ) 58 | new_system = self.stepper.step(system) 59 | self.atoms.set_positions(new_system.positions.detach().cpu().numpy()) 60 | self.atoms.set_momenta( 61 | new_system.get_data("momenta") 62 | .block() 63 | .values.squeeze(-1) 64 | .detach() 65 | .cpu() 66 | .numpy() 67 | ) 68 | 69 | if self.rescale_energy: 70 | new_energy = self.atoms.get_total_energy() 71 | old_kinetic_energy = self.atoms.get_kinetic_energy() 72 | alpha = np.sqrt(1.0 - (new_energy - old_energy) / old_kinetic_energy) 73 | self.atoms.set_momenta(alpha * self.atoms.get_momenta()) 74 | 75 | 76 | def _convert_atoms_to_system( 77 | atoms: ase.Atoms, dtype: str, device: str | torch.device 78 | ) -> System: 79 | system_data = _ase_to_torch_data(atoms, dtype=dtype, device=device) 80 | system = System(*system_data) 81 | system.add_data( 82 | "momenta", 83 | TensorMap( 84 | keys=Labels.single().to(device), 85 | blocks=[ 86 | TensorBlock( 87 | values=torch.tensor( 88 | atoms.get_momenta(), dtype=dtype, device=device 89 | ).unsqueeze(-1), 90 | samples=Labels( 91 | names=["system", "atom"], 92 | values=torch.tensor( 93 | [[0, j] for j in range(len(atoms))], device=device 94 | ), 95 | ), 96 | components=[ 97 | Labels( 98 | names="xyz", 99 | values=torch.tensor([[0], [1], [2]], device=device), 100 | ) 101 | ], 102 | properties=Labels.single().to(device), 103 | ) 104 | ], 105 | ), 106 | ) 107 | system.add_data( 108 | "masses", 109 | TensorMap( 110 | 
keys=Labels.single().to(device), 111 | blocks=[ 112 | TensorBlock( 113 | values=torch.tensor( 114 | atoms.get_masses(), dtype=dtype, device=device 115 | ).unsqueeze(-1), 116 | samples=Labels( 117 | names=["system", "atom"], 118 | values=torch.tensor( 119 | [[0, j] for j in range(len(atoms))], device=device 120 | ), 121 | ), 122 | components=[], 123 | properties=Labels.single().to(device), 124 | ) 125 | ], 126 | ), 127 | ) 128 | return system 129 | -------------------------------------------------------------------------------- /src/flashmd/ipi.py: -------------------------------------------------------------------------------- 1 | from ipi.utils.depend import dstrip 2 | from ipi.utils.units import Constants 3 | from ipi.utils.messages import verbosity, info 4 | from ipi.utils.mathtools import random_rotation as random_rotation_matrix 5 | from ipi.engine.motion.dynamics import NVEIntegrator, NVTIntegrator, NPTIntegrator 6 | 7 | from flashmd.stepper import FlashMDStepper 8 | import ase.units 9 | import torch 10 | import numpy as np 11 | import ase.data 12 | 13 | from metatomic.torch import System 14 | from metatensor.torch import Labels, TensorBlock, TensorMap 15 | 16 | 17 | def get_standard_vv_step( 18 | sim, model=None, device=None, rescale_energy=True, random_rotation=False 19 | ): 20 | """ 21 | Returns a velocity Verlet stepper function for i-PI simulations. 22 | 23 | Parameters: 24 | - sim: The i-PI simulation object. 25 | - rescale_energy: If True, rescales the kinetic energy after the step 26 | to maintain energy conservation. 27 | 28 | Returns: 29 | - A function that performs a velocity Verlet step. 30 | """ 31 | 32 | def vv_step(motion): 33 | if random_rotation: 34 | raise NotImplementedError( 35 | "Random rotation is not implemented in the standard VV stepper." 
36 | ) 37 | 38 | if rescale_energy: 39 | info("@flashmd: Old energy", verbosity.debug) 40 | old_energy = sim.properties("potential") + sim.properties("kinetic_md") 41 | 42 | print(motion.integrator.pdt, motion.integrator.qdt) 43 | motion.integrator.pstep(level=0) 44 | motion.integrator.pconstraints() 45 | motion.integrator.qcstep() # does two steps because qdt is halved in the i-PI integrator 46 | motion.integrator.qcstep() 47 | motion.integrator.pstep(level=0) 48 | motion.integrator.pconstraints() 49 | 50 | if rescale_energy: 51 | info("@flashmd: Energy rescale", verbosity.debug) 52 | new_energy = sim.properties("potential") + sim.properties("kinetic_md") 53 | kinetic_energy = sim.properties("kinetic_md") 54 | alpha = np.sqrt(1.0 - (new_energy - old_energy) / kinetic_energy) 55 | motion.beads.p[:] = alpha * dstrip(motion.beads.p) 56 | 57 | return vv_step 58 | 59 | 60 | def get_flashmd_vv_step(sim, model, device, rescale_energy=True, random_rotation=False): 61 | capabilities = model.capabilities() 62 | 63 | base_timestep = float(model.module.base_time_step) * ase.units.fs 64 | 65 | dt = sim.syslist[0].motion.dt * 2.4188843e-17 * ase.units.s 66 | 67 | n_time_steps = int( 68 | [k for k in capabilities.outputs.keys() if "mtt::delta_" in k][0].split("_")[1] 69 | ) 70 | if not np.allclose(dt, n_time_steps * base_timestep): 71 | raise ValueError( 72 | f"Mismatch between timestep ({dt}) and model timestep ({base_timestep})." 
73 | ) 74 | 75 | device = torch.device(device) 76 | dtype = getattr(torch, capabilities.dtype) 77 | stepper = FlashMDStepper([model], n_time_steps, device) 78 | 79 | def flashmd_vv(motion): 80 | info("@flashmd: Starting VV", verbosity.debug) 81 | if rescale_energy: 82 | info("@flashmd: Old energy", verbosity.debug) 83 | old_energy = sim.properties("potential") + sim.properties("kinetic_md") 84 | 85 | info("@flashmd: Stepper", verbosity.debug) 86 | system = ipi_to_system(motion, device, dtype) 87 | 88 | if random_rotation: 89 | # generate a random rotation matrix 90 | R = torch.tensor( 91 | random_rotation_matrix(motion.prng, improper=True), 92 | device=system.positions.device, 93 | dtype=system.positions.dtype, 94 | ) 95 | # applies the random rotation 96 | system.cell = system.cell @ R.T 97 | system.positions = system.positions @ R.T 98 | momenta = system.get_data("momenta").block(0).values.squeeze() 99 | momenta[:] = momenta @ R.T # does the change in place 100 | 101 | new_system = stepper.step(system) 102 | 103 | if random_rotation: 104 | # revert q,p to the original reference frame (`system_to_ipi` ignores the cell) 105 | new_system.positions = new_system.positions @ R 106 | momenta = new_system.get_data("momenta").block(0).values.squeeze() 107 | momenta[:] = momenta @ R 108 | 109 | info("@flashmd: System to ipi", verbosity.debug) 110 | system_to_ipi(motion, new_system) 111 | info("@flashmd: VV P constraints", verbosity.debug) 112 | motion.integrator.pconstraints() 113 | 114 | if rescale_energy: 115 | info("@flashmd: Energy rescale", verbosity.debug) 116 | new_energy = sim.properties("potential") + sim.properties("kinetic_md") 117 | kinetic_energy = sim.properties("kinetic_md") 118 | alpha = np.sqrt(1.0 - (new_energy - old_energy) / kinetic_energy) 119 | motion.beads.p[:] = alpha * dstrip(motion.beads.p) 120 | motion.integrator.pconstraints() 121 | info("@flashmd: End of VV step", verbosity.debug) 122 | 123 | return flashmd_vv 124 | 125 | 126 | def 
get_nve_stepper( 127 | sim, 128 | model, 129 | device, 130 | rescale_energy=True, 131 | random_rotation=False, 132 | use_standard_vv=False, 133 | ): 134 | motion = sim.syslist[0].motion 135 | if type(motion.integrator) is not NVEIntegrator: 136 | raise TypeError( 137 | f"Base i-PI integrator is of type {motion.integrator.__class__.__name__}, use a NVE setup." 138 | ) 139 | 140 | if use_standard_vv: 141 | # use the standard velocity Verlet integrator 142 | vv_step = get_standard_vv_step( 143 | sim, model, device, rescale_energy, random_rotation 144 | ) 145 | else: 146 | # defaults to the FlashMD VV stepper 147 | vv_step = get_flashmd_vv_step( 148 | sim, model, device, rescale_energy, random_rotation 149 | ) 150 | 151 | def nve_stepper(motion, *_, **__): 152 | vv_step(motion) 153 | motion.ensemble.time += motion.dt 154 | 155 | return nve_stepper 156 | 157 | 158 | def get_nvt_stepper( 159 | sim, 160 | model, 161 | device, 162 | rescale_energy=True, 163 | random_rotation=False, 164 | use_standard_vv=False, 165 | ): 166 | motion = sim.syslist[0].motion 167 | if type(motion.integrator) is not NVTIntegrator: 168 | raise TypeError( 169 | f"Base i-PI integrator is of type {motion.integrator.__class__.__name__}, use a NVT setup." 
170 | ) 171 | 172 | if use_standard_vv: 173 | # use the standard velocity Verlet integrator 174 | vv_step = get_standard_vv_step( 175 | sim, model, device, rescale_energy, random_rotation 176 | ) 177 | else: 178 | # defaults to the FlashMD VV stepper 179 | vv_step = get_flashmd_vv_step( 180 | sim, model, device, rescale_energy, random_rotation 181 | ) 182 | 183 | def nvt_stepper(motion, *_, **__): 184 | # OBABO splitting of a NVT propagator 185 | motion.thermostat.step() 186 | motion.integrator.pconstraints() 187 | vv_step(motion) 188 | motion.thermostat.step() 189 | motion.integrator.pconstraints() 190 | motion.ensemble.time += motion.dt 191 | 192 | return nvt_stepper 193 | 194 | 195 | def _qbaro(baro): 196 | """Propagation step for the cell volume (adjusting atomic positions and momenta).""" 197 | 198 | v = baro.p[0] / baro.m[0] 199 | halfdt = ( 200 | baro.qdt 201 | ) # this is set to half the inner loop in all integrators that use a barostat 202 | expq, expp = (np.exp(v * halfdt), np.exp(-v * halfdt)) 203 | 204 | baro.nm.qnm[0, :] *= expq 205 | baro.nm.pnm[0, :] *= expp 206 | baro.cell.h *= expq 207 | 208 | 209 | def _pbaro(baro): 210 | """Propagation step for the cell momentum (adjusting atomic positions and momenta).""" 211 | 212 | # we are assuming then that p the coupling between p^2 and dp/dt only involves the fast force 213 | dt = baro.pdt[0] 214 | 215 | # computes the pressure associated with the forces at the outer level MTS level. 216 | press = np.trace(baro.stress_mts(0)) / 3.0 217 | # integerates the kinetic part of the pressure with the force at the inner-most level. 
218 | nbeads = baro.beads.nbeads 219 | baro.p += ( 220 | 3.0 221 | * dt 222 | * (baro.cell.V * (press - nbeads * baro.pext) + Constants.kb * baro.temp) 223 | ) 224 | 225 | 226 | def get_npt_stepper( 227 | sim, 228 | model, 229 | device, 230 | rescale_energy=True, 231 | random_rotation=False, 232 | use_standard_vv=False, 233 | ): 234 | motion = sim.syslist[0].motion 235 | if type(motion.integrator) is not NPTIntegrator: 236 | raise TypeError( 237 | f"Base i-PI integrator is of type {motion.integrator.__class__.__name__}, use a NPT setup." 238 | ) 239 | 240 | if use_standard_vv: 241 | # use the standard velocity Verlet integrator 242 | vv_step = get_standard_vv_step( 243 | sim, model, device, rescale_energy, random_rotation 244 | ) 245 | else: 246 | # defaults to the FlashMD VV stepper 247 | vv_step = get_flashmd_vv_step( 248 | sim, model, device, rescale_energy, random_rotation 249 | ) 250 | 251 | # The barostat here needs a simpler splitting than for BZP, something as 252 | # OAbBbBABbAbPO where Bp and Ap are the cell momentum and volume steps 253 | def npt_stepper(motion, *_, **__): 254 | info("@flashmd: Starting NPT step", verbosity.debug) 255 | info("@flashmd: Particle thermo", verbosity.debug) 256 | motion.thermostat.step() 257 | info("@flashmd: P constraints", verbosity.debug) 258 | motion.integrator.pconstraints() 259 | info("@flashmd: Barostat thermo", verbosity.debug) 260 | motion.barostat.thermostat.step() 261 | info("@flashmd: Barostat q", verbosity.debug) 262 | _qbaro(motion.barostat) 263 | info("@flashmd: Barostat p", verbosity.debug) 264 | _pbaro(motion.barostat) 265 | info("@flashmd: FlashVV", verbosity.debug) 266 | vv_step(motion) 267 | info("@flashmd: Barostat p", verbosity.debug) 268 | _pbaro(motion.barostat) 269 | info("@flashmd: Barostat q", verbosity.debug) 270 | _qbaro(motion.barostat) 271 | info("@flashmd: Barostat thermo", verbosity.debug) 272 | motion.barostat.thermostat.step() 273 | info("@flashmd: Particle thermo", verbosity.debug) 274 | 
motion.thermostat.step() 275 | info("@flashmd: P constraints", verbosity.debug) 276 | motion.integrator.pconstraints() 277 | motion.ensemble.time += motion.dt 278 | info("@flashmd: NPT Step finished", verbosity.debug) 279 | 280 | return npt_stepper 281 | 282 | 283 | def ipi_to_system(motion, device, dtype): 284 | positions = ( 285 | dstrip(motion.beads.q).reshape(-1, 3) * ase.units.Bohr / ase.units.Angstrom 286 | ) 287 | positions_torch = torch.tensor(positions, device=device, dtype=dtype) 288 | cell = dstrip(motion.cell.h).T * ase.units.Bohr / ase.units.Angstrom 289 | cell_torch = torch.tensor(cell, device=device, dtype=dtype) 290 | pbc_torch = torch.tensor([True, True, True], device=device, dtype=torch.bool) 291 | momenta = ( 292 | dstrip(motion.beads.p).reshape(-1, 3) 293 | * (9.1093819e-31 * ase.units.kg) 294 | * (ase.units.Bohr / ase.units.Angstrom) 295 | / (2.4188843e-17 * ase.units.s) 296 | ) 297 | momenta_torch = torch.tensor(momenta, device=device, dtype=dtype) 298 | masses = dstrip(motion.beads.m) * 9.1093819e-31 * ase.units.kg 299 | masses_torch = torch.tensor(masses, device=device, dtype=dtype) 300 | types_torch = torch.tensor( 301 | [ase.data.atomic_numbers[name] for name in motion.beads.names], 302 | device=device, 303 | dtype=torch.int32, 304 | ) 305 | system = System(types_torch, positions_torch, cell_torch, pbc_torch) 306 | system.add_data( 307 | "momenta", 308 | TensorMap( 309 | keys=Labels.single().to(device), 310 | blocks=[ 311 | TensorBlock( 312 | values=momenta_torch.unsqueeze(-1), 313 | samples=Labels( 314 | names=["system", "atom"], 315 | values=torch.tensor( 316 | [[0, j] for j in range(len(momenta_torch))], device=device 317 | ), 318 | ), 319 | components=[ 320 | Labels( 321 | names="xyz", 322 | values=torch.tensor([[0], [1], [2]], device=device), 323 | ) 324 | ], 325 | properties=Labels.single().to(device), 326 | ) 327 | ], 328 | ), 329 | ) 330 | system.add_data( 331 | "masses", 332 | TensorMap( 333 | keys=Labels.single().to(device), 334 
| blocks=[ 335 | TensorBlock( 336 | values=masses_torch.unsqueeze(-1), 337 | samples=Labels( 338 | names=["system", "atom"], 339 | values=torch.tensor( 340 | [[0, j] for j in range(len(masses_torch))], device=device 341 | ), 342 | ), 343 | components=[], 344 | properties=Labels.single().to(device), 345 | ) 346 | ], 347 | ), 348 | ) 349 | return system 350 | 351 | 352 | def system_to_ipi(motion, system): 353 | # only needs to convert positions and momenta, it's assumed that the cell won't be changed 354 | motion.beads.q[:] = ( 355 | system.positions.cpu().numpy().flatten() * ase.units.Angstrom / ase.units.Bohr 356 | ) 357 | motion.beads.p[:] = system.get_data("momenta").block().values.squeeze( 358 | -1 359 | ).cpu().numpy().flatten() / ( 360 | (9.1093819e-31 * ase.units.kg) 361 | * (ase.units.Bohr / ase.units.Angstrom) 362 | / (2.4188843e-17 * ase.units.s) 363 | ) 364 | -------------------------------------------------------------------------------- /src/flashmd/models.py: -------------------------------------------------------------------------------- 1 | from huggingface_hub import hf_hub_download 2 | from metatomic.torch import AtomisticModel, load_atomistic_model 3 | 4 | 5 | def get_universal_model(time_step: int = 16) -> AtomisticModel: 6 | if time_step not in [1, 4, 8, 16, 32, 64]: 7 | raise ValueError( 8 | "Universal FlashMD models are only available for" 9 | " time steps of 1, 4, 8, 16, 32, 64 fs." 
10 | ) 11 | 12 | model_path = hf_hub_download( 13 | repo_id="lab-cosmo/flashmd", 14 | filename=f"flashmd_{time_step}fs.pt", 15 | cache_dir=None, 16 | revision="main", 17 | ) 18 | return load_atomistic_model(model_path) 19 | -------------------------------------------------------------------------------- /src/flashmd/stepper.py: -------------------------------------------------------------------------------- 1 | # from ..utils.pretrained import load_pretrained_models 2 | from metatomic.torch import ModelEvaluationOptions, ModelOutput 3 | from metatensor.torch import Labels, TensorBlock, TensorMap 4 | import torch 5 | from metatomic.torch import System 6 | from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists 7 | from typing import List 8 | from metatomic.torch import AtomisticModel 9 | 10 | 11 | class FlashMDStepper: 12 | def __init__( 13 | self, 14 | models: List[AtomisticModel], 15 | n_time_steps: int, 16 | device: torch.device, 17 | # q_error_threshold: float = 0.1, 18 | # p_error_threshold: float = 0.1, 19 | # energy_error_threshold: float = 0.1, 20 | ): 21 | self.n_time_steps = n_time_steps 22 | 23 | # internally, turn list of models into a dict and send to device 24 | self.models = {} 25 | for model in models: 26 | n_time_steps_model = int( 27 | [k for k in model.capabilities().outputs.keys() if "mtt::delta_" in k][ 28 | 0 29 | ].split("_")[1] 30 | ) 31 | self.models[n_time_steps_model] = model.to(device) 32 | 33 | # one of these for each model: 34 | self.evaluation_options = ModelEvaluationOptions( 35 | length_unit="Angstrom", 36 | outputs={ 37 | f"mtt::delta_{self.n_time_steps}_q": ModelOutput(per_atom=True), 38 | f"mtt::p_{self.n_time_steps}": ModelOutput(per_atom=True), 39 | }, 40 | ) 41 | 42 | self.dtype = getattr(torch, self.models[n_time_steps].capabilities().dtype) 43 | self.device = device 44 | 45 | def step(self, system: System): 46 | if system.device.type != self.device.type: 47 | raise ValueError("System device does not match 
stepper device.") 48 | if system.positions.dtype != self.dtype: 49 | raise ValueError("System dtype does not match stepper dtype.") 50 | 51 | system = get_system_with_neighbor_lists( 52 | system, self.models[self.n_time_steps].requested_neighbor_lists() 53 | ) 54 | 55 | masses = system.get_data("masses").block().values 56 | model_outputs = self.models[self.n_time_steps]( 57 | [system], self.evaluation_options, check_consistency=False 58 | ) 59 | delta_q_scaled = ( 60 | model_outputs[f"mtt::delta_{self.n_time_steps}_q"] 61 | .block() 62 | .values.squeeze(-1) 63 | ) 64 | p_scaled = ( 65 | model_outputs[f"mtt::p_{self.n_time_steps}"].block().values.squeeze(-1) 66 | ) 67 | sqrt_masses = torch.sqrt(masses) 68 | delta_q = delta_q_scaled / sqrt_masses 69 | p = p_scaled * sqrt_masses 70 | 71 | new_system = System( 72 | positions=system.positions + delta_q, 73 | types=system.types, 74 | cell=system.cell, 75 | pbc=system.pbc, 76 | ) 77 | new_system.add_data( 78 | "momenta", 79 | TensorMap( 80 | keys=Labels.single().to(self.device), 81 | blocks=[ 82 | TensorBlock( 83 | values=p.unsqueeze(-1), 84 | samples=Labels( 85 | names=["system", "atom"], 86 | values=torch.tensor( 87 | [[0, j] for j in range(len(new_system))], 88 | device=self.device, 89 | ), 90 | ), 91 | components=[ 92 | Labels( 93 | names="xyz", 94 | values=torch.tensor( 95 | [[0], [1], [2]], device=self.device 96 | ), 97 | ) 98 | ], 99 | properties=Labels.single().to(self.device), 100 | ) 101 | ], 102 | ), 103 | ) 104 | new_system.add_data( 105 | "masses", 106 | TensorMap( 107 | keys=Labels.single().to(self.device), 108 | blocks=[ 109 | TensorBlock( 110 | values=masses, 111 | samples=Labels( 112 | names=["system", "atom"], 113 | values=torch.tensor( 114 | [[0, j] for j in range(len(new_system))], 115 | device=self.device, 116 | ), 117 | ), 118 | components=[], 119 | properties=Labels.single().to(self.device), 120 | ) 121 | ], 122 | ), 123 | ) 124 | return new_system 125 | 
--------------------------------------------------------------------------------