├── .github └── workflows │ ├── pre-comit.yaml │ └── publish.yaml ├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── .python-version ├── .readthedocs.yaml ├── .vscode └── settings.json ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── docs ├── Makefile └── source │ ├── _static │ ├── .gitignore │ ├── mlflow_compare.png │ ├── mlipx-dark.svg │ ├── mlipx-favicon.svg │ ├── mlipx-light.svg │ ├── quickstart_zndraw.png │ ├── zndraw_compare.png │ └── zndraw_render.png │ ├── abc.rst │ ├── authors.rst │ ├── build_graph.rst │ ├── concept.rst │ ├── concept │ ├── data.rst │ ├── distributed.rst │ ├── metrics.rst │ ├── models.rst │ ├── recipes.rst │ ├── zndraw.rst │ └── zntrack.rst │ ├── conf.py │ ├── glossary.rst │ ├── index.rst │ ├── installation.rst │ ├── nodes.rst │ ├── notebooks │ ├── combine.ipynb │ └── structure_relaxation.ipynb │ ├── quickstart.rst │ ├── quickstart │ ├── cli.rst │ └── python.rst │ ├── recipes.rst │ ├── recipes │ ├── adsorption.rst │ ├── energy_and_forces.rst │ ├── energy_volume.rst │ ├── homonuclear_diatomics.rst │ ├── invariances.rst │ ├── md.rst │ ├── neb.rst │ ├── phase_diagram.rst │ ├── pourbaix_diagram.rst │ ├── relax.rst │ └── vibrational_analysis.rst │ └── references.bib ├── mlipx ├── __init__.py ├── __init__.pyi ├── abc.py ├── benchmark │ ├── __init__.py │ ├── elements.py │ ├── file.py │ └── main.py ├── cli │ ├── __init__.py │ └── main.py ├── doc_utils.py ├── models.py ├── nodes │ ├── __init__.py │ ├── adsorption.py │ ├── apply_calculator.py │ ├── autowte.py │ ├── compare_calculator.py │ ├── diatomics.py │ ├── energy_volume.py │ ├── evaluate_calculator.py │ ├── filter_dataset.py │ ├── formation_energy.py │ ├── generic_ase.py │ ├── invariances.py │ ├── io.py │ ├── modifier.py │ ├── molecular_dynamics.py │ ├── mp_api.py │ ├── nebs.py │ ├── observer.py │ ├── orca.py │ ├── phase_diagram.py │ ├── pourbaix_diagram.py │ ├── rattle.py │ ├── smiles.py │ ├── structure_optimization.py │ ├── updated_frames.py │ └── vibrational_analysis.py ├── project.py ├── recipes │ ├── README.md │ ├── __init__.py │ ├── adsorption.py.jinja2 │ ├── energy_volume.py.jinja2 │ ├── homonuclear_diatomics.py.jinja2 │ ├── invariances.py.jinja2 │ ├── main.py │ ├── md.py.jinja2 │ ├── metrics.py │ ├── models.py.jinja2 │ ├── neb.py │ ├── phase_diagram.py.jinja2 │ ├── pourbaix_diagram.py.jinja2 │ ├── relax.py.jinja2 │ └── vibrational_analysis.py.jinja2 ├── utils.py └── version.py ├── pyproject.toml └── uv.lock /.github/workflows/pre-comit.yaml: -------------------------------------------------------------------------------- 1 | name: pre-commit 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [main] 7 | 8 | jobs: 9 | pre-commit: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v3 13 | - uses: actions/setup-python@v3 14 | - uses: pre-commit/action@v3.0.1 15 | -------------------------------------------------------------------------------- /.github/workflows/publish.yaml: -------------------------------------------------------------------------------- 1 | name: Release 2 | on: 3 | release: 4 | types: 5 | - created 6 | 7 | jobs: 8 | publish-pypi: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v4 12 | - name: Install uv 13 | uses: astral-sh/setup-uv@v5 14 | - name: Publish 15 | env: 16 | PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }} 17 | run: | 18 | uv build 19 | uv publish --token $PYPI_TOKEN 20 | -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | tmp/ 164 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "mlipx-hub"] 2 | path = mlipx-hub 3 | url = https://github.com/basf/mlipx-hub 4 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v5.0.0 6 | hooks: 7 | - id: check-added-large-files 8 | exclude: ^.*uv\.lock$ 9 | - id: check-case-conflict 10 | - id: check-docstring-first 11 | - id: check-executables-have-shebangs 12 | - id: check-json 13 | - id: check-merge-conflict 14 | args: ['--assume-in-merge'] 15 | exclude: ^(docs/) 16 | - id: check-toml 17 | - id: check-yaml 18 | - id: debug-statements 19 | - id: end-of-file-fixer 20 | exclude: .*\.json$ 21 | - id: mixed-line-ending 22 | args: ['--fix=lf'] 23 | - id: sort-simple-yaml 24 | - id: trailing-whitespace 25 | - repo: https://github.com/codespell-project/codespell 26 | rev: v2.4.1 27 | hooks: 28 | - id: codespell 29 | additional_dependencies: ["tomli"] 30 | - repo: https://github.com/astral-sh/ruff-pre-commit 31 | # Ruff version. 32 | rev: v0.11.5 33 | hooks: 34 | # Run the linter. 35 | - id: ruff 36 | args: [ --fix ] 37 | # Run the formatter. 
38 | - id: ruff-format 39 | - repo: https://github.com/executablebooks/mdformat 40 | rev: 0.7.22 41 | hooks: 42 | - id: mdformat 43 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.11 2 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | version: 2 6 | 7 | submodules: 8 | include: all 9 | 10 | # Set the version of Python and other tools you might need 11 | build: 12 | os: ubuntu-lts-latest 13 | tools: 14 | python: "3.11" 15 | jobs: 16 | post_install: 17 | # see https://github.com/astral-sh/uv/issues/10074 18 | - pip install uv 19 | - UV_PROJECT_ENVIRONMENT=$READTHEDOCS_VIRTUALENV_PATH uv sync --link-mode=copy --group=docs 20 | # Navigate to the examples directory and perform DVC operations 21 | - cd mlipx-hub && dvc pull --allow-missing 22 | 23 | # Build documentation in the docs/ directory with Sphinx 24 | sphinx: 25 | configuration: docs/source/conf.py 26 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.analysis.typeCheckingMode": "basic", 3 | "[restructuredtext]": { 4 | "editor.wordWrap": "on" 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # MLIPX Contribution guidelines 2 | 3 | ## Adding a new MLIP 4 | 5 | 1. Create a new entry in `[project.optional-dependencies]` in the `pyproject.toml`. Configure `[tool.uv]:conflicts` if necessary. 6 | 1. Add your model to `mlipx/recipes/models.py.jinja2` to the `ALL_MODELS` dictionary. 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 BASF 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | ![Logo](https://raw.githubusercontent.com/basf/mlipx/refs/heads/main/docs/source/_static/mlipx-light.svg#gh-light-mode-only) 4 | ![Logo](https://raw.githubusercontent.com/basf/mlipx/refs/heads/main/docs/source/_static/mlipx-dark.svg#gh-dark-mode-only) 5 | 6 | [![PyPI version](https://badge.fury.io/py/mlipx.svg)](https://badge.fury.io/py/mlipx) 7 | [![ZnTrack](https://img.shields.io/badge/Powered%20by-ZnTrack-%23007CB0)](https://zntrack.readthedocs.io/en/latest/) 8 | [![ZnDraw](https://img.shields.io/badge/works_with-ZnDraw-orange)](https://github.com/zincware/zndraw) 9 | [![open issues](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/basf/mlipx/issues) 10 | [![Documentation Status](https://readthedocs.org/projects/mlipx/badge/?version=latest)](https://mlipx.readthedocs.io/en/latest/?badge=latest) 11 | 12 | [📘Documentation](https://mlipx.readthedocs.io) | 13 | [🛠️Installation](https://mlipx.readthedocs.io/en/latest/installation.html) | 14 | [📜Recipes](https://mlipx.readthedocs.io/en/latest/recipes.html) | 15 | [🚀Quickstart](https://mlipx.readthedocs.io/en/latest/quickstart.html) 16 | 17 |
18 | 19 | 20 | Machine-Learned Interatomic Potential eXploration 21 |
22 | 23 | `mlipx` is a Python library designed for evaluating machine-learned interatomic 24 | potentials (MLIPs). It offers a growing set of evaluation methods alongside 25 | powerful visualization and comparison tools. 26 | 27 | The goal of `mlipx` is to provide a common platform for MLIP evaluation and to 28 | facilitate sharing results among researchers. This allows you to determine the 29 | applicability of a specific MLIP to your research and compare it against others. 30 | 31 | ## Installation 32 | 33 | Install `mlipx` via pip: 34 | 35 | ```bash 36 | pip install mlipx 37 | ``` 38 | 39 | > [!NOTE] 40 | > The `mlipx` package does not include the installation of any MLIP code, as we aim to keep the package as lightweight as possible. 41 | > If you encounter any `ImportError`, you may need to install the additional dependencies manually. 42 | 43 | ## Quickstart 44 | 45 | This section provides a brief overview of the core features of `mlipx`. For more detailed instructions, visit the [documentation](https://mlipx.readthedocs.io). 46 | 47 | Most recipes support different input formats, such as data file paths, `SMILES` strings, or Materials Project structure IDs. 48 | 49 | > [!NOTE] 50 | > Because `mlipx` uses Git and [DVC](https://dvc.org/doc), you need to create a new project directory to run your experiments in. Here's how to set up your project: 51 | > 52 | > ```bash 53 | > mkdir exp 54 | > cd exp 55 | > git init && dvc init 56 | > ``` 57 | > 58 | > If you want to use datafiles, it is recommend to track them with `dvc add ` instead of `git add `. 59 | > 60 | > ```bash 61 | > cp /your/data/file.xyz . 62 | > dvc add file.xyz 63 | > ``` 64 | 65 | ### Energy-Volume Curve 66 | 67 | Compute an energy-volume curve using the `mp-1143` structure from the Materials Project and MLIPs such as `mace-mpa-0`, `sevennet`, and `orb-v2`: 68 | 69 | ```bash 70 | mlipx recipes ev --models mace-mpa-0,sevennet,orb-v2 --material-ids=mp-1143 --repro 71 | mlipx compare --glob "*EnergyVolumeCurve" 72 | ``` 73 | 74 | > [!NOTE] 75 | > `mlipx` utilizes [ASE](https://wiki.fysik.dtu.dk/ase/index.html), 76 | > meaning any ASE-compatible calculator for your MLIP can be used. 77 | > If we do not provide a preset for your model, you can either adapt the `models.py` file, raise an [issue](https://github.com/basf/mlipx/issues/new) to request support, or submit a pull request to add your model directly. 78 | 79 | Below is an example of the resulting comparison: 80 | 81 | ![ZnDraw UI](https://github.com/user-attachments/assets/2036e6d9-3342-4542-9ddb-bbc777d2b093#gh-dark-mode-only "ZnDraw UI") 82 | ![ZnDraw UI](https://github.com/user-attachments/assets/c2479d17-c443-4550-a641-c513ede3be02#gh-light-mode-only "ZnDraw UI") 83 | 84 | > [!NOTE] 85 | > Set your default visualizer path using: `export ZNDRAW_URL=http://localhost:1234`. 
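The model names passed to `--models` refer to presets shipped with `mlipx`. To check which presets your installation currently knows about (as described in the [installation guide](https://mlipx.readthedocs.io/en/latest/installation.html)), you can run:

```bash
mlipx info
```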
86 | 87 | ### Structure Optimization 88 | 89 | Compare the performance of different models in optimizing multiple molecular structures from `SMILES` representations: 90 | 91 | ```bash 92 | mlipx recipes relax --models mace-mpa-0,sevennet,orb-v2 --smiles "CCO,C1=CC2=C(C=C1O)C(=CN2)CCN" --repro 93 | mlipx compare --glob "*0_StructureOptimization" 94 | mlipx compare --glob "*1_StructureOptimization" 95 | ``` 96 | 97 | ![ZnDraw UI](https://github.com/user-attachments/assets/7e26a502-3c59-4498-9b98-af8e17a227ce#gh-dark-mode-only "ZnDraw UI") 98 | ![ZnDraw UI](https://github.com/user-attachments/assets/a68ac9f5-e3fe-438d-ad4e-88b60499b79e#gh-light-mode-only "ZnDraw UI") 99 | 100 | ### Nudged Elastic Band (NEB) 101 | 102 | Run and compare nudged elastic band (NEB) calculations for a given start and end structure: 103 | 104 | ```bash 105 | mlipx recipes neb --models mace-mpa-0,sevennet,orb-v2 --datapath ../data/neb_end_p.xyz --repro 106 | mlipx compare --glob "*NEBs" 107 | ``` 108 | 109 | ![ZnDraw UI](https://github.com/user-attachments/assets/a2e80caf-dd86-4f14-9101-6d52610b9c34#gh-dark-mode-only "ZnDraw UI") 110 | ![ZnDraw UI](https://github.com/user-attachments/assets/0c1eb681-a32c-41c2-a15e-2348104239dc#gh-light-mode-only "ZnDraw UI") 111 | 112 | ## Python API 113 | 114 | You can also use all the recipes from the `mlipx` command-line interface 115 | programmatically in Python. 116 | 117 | > [!NOTE] 118 | > Whether you use the CLI or the Python API, you must work within a GIT 119 | > and DVC repository. This setup ensures reproducibility and enables automatic 120 | > caching and other features from DVC and ZnTrack. 121 | 122 | ```python 123 | import mlipx 124 | 125 | # Initialize the project 126 | project = mlipx.Project() 127 | 128 | # Define an MLIP 129 | mace_mp = mlipx.GenericASECalculator( 130 | module="mace.calculators", 131 | class_name="mace_mp", 132 | device="auto", 133 | kwargs={ 134 | "model": "medium", 135 | }, 136 | ) 137 | 138 | # Use the MLIP in a structure optimization 139 | with project: 140 | data = mlipx.LoadDataFile(path="/your/data/file.xyz") 141 | relax = mlipx.StructureOptimization( 142 | data=data.frames, 143 | data_id=-1, 144 | model=mace_mp, 145 | fmax=0.1 146 | ) 147 | 148 | # Reproduce the project state 149 | project.repro() 150 | 151 | # Access the results 152 | print(relax.frames) 153 | # >>> [ase.Atoms(...), ...] 154 | ``` 155 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = ../build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @cd "$(SOURCEDIR)" && $(SPHINXBUILD) -M help "." "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @cd "$(SOURCEDIR)" && $(SPHINXBUILD) -M $@ "." 
"$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/source/_static/.gitignore: -------------------------------------------------------------------------------- 1 | *.mp4 2 | -------------------------------------------------------------------------------- /docs/source/_static/mlflow_compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basf/mlipx/c109aa6b2d3ff8236d6a9ec0772a8c4efb9c7dea/docs/source/_static/mlflow_compare.png -------------------------------------------------------------------------------- /docs/source/_static/mlipx-dark.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | MLIP 86 | -------------------------------------------------------------------------------- /docs/source/_static/mlipx-favicon.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 78 | -------------------------------------------------------------------------------- /docs/source/_static/mlipx-light.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | MLIP 86 | -------------------------------------------------------------------------------- /docs/source/_static/quickstart_zndraw.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basf/mlipx/c109aa6b2d3ff8236d6a9ec0772a8c4efb9c7dea/docs/source/_static/quickstart_zndraw.png -------------------------------------------------------------------------------- /docs/source/_static/zndraw_compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basf/mlipx/c109aa6b2d3ff8236d6a9ec0772a8c4efb9c7dea/docs/source/_static/zndraw_compare.png -------------------------------------------------------------------------------- /docs/source/_static/zndraw_render.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basf/mlipx/c109aa6b2d3ff8236d6a9ec0772a8c4efb9c7dea/docs/source/_static/zndraw_render.png -------------------------------------------------------------------------------- /docs/source/abc.rst: -------------------------------------------------------------------------------- 1 | Abstract Base Classes 2 | ====================== 3 | We make use of abstract base classes, protocols and type hints to improve the workflow design experience. 4 | Further, these can be used for cross-package interoperability with other :term:`ZnTrack` based packages like :term:`IPSuite`. 5 | 6 | For most :term:`Node` classes operating on lists of :term:`ASE` objects, there are two scenarios: 7 | - The node operates on a single :term:`ASE` object. 8 | - The node operates on a list of :term:`ASE` objects. 9 | For both scenarios, the node is given a list of :term:`ASE` objects via the `data` attribute. 10 | For the first scenario, the `id` of the :term:`ASE` object is given via the `data_id` attribute which is omitted for the second scenario. 11 | 12 | .. automodule:: mlipx.abc 13 | :members: 14 | :undoc-members: 15 | -------------------------------------------------------------------------------- /docs/source/authors.rst: -------------------------------------------------------------------------------- 1 | Authors and Contributing 2 | ======================== 3 | 4 | .. 
note:: 5 | 6 | Every contribution is welcome and you will be included in this ever growing list. 7 | 8 | Authors 9 | ------- 10 | The creation of ``mlipx`` began during Fabian Zills internship at BASF SE, where he worked under the guidance of Sandip De. The foundational concepts and designs were initially developed by Sandip, while the current version of the code is a product of contributions from various members of the BASF team. Fabian Zills integrated several of his previous projects and lead the technical development of the initial release of this code. The code has been released with the intention of fostering community involvement in future developments. We acknowledge support from: 11 | 12 | - Fabian Zills 13 | - Sheena Agarwal 14 | - Sandip De 15 | - Shuang Han 16 | - Srishti Gupta 17 | - Tiago Joao Ferreira Goncalves 18 | - Edvin Fako 19 | 20 | 21 | Contribution Guidelines 22 | ----------------------- 23 | 24 | We welcome contributions to :code:`mlipx`! 25 | With the inclusion of your contributions, we can make :code:`mlipx` better for everyone. 26 | 27 | To ensure code quality and consistency, we use :code:`pre-commit` hooks. 28 | To install the pre-commit hooks, run the following command: 29 | 30 | .. code:: console 31 | 32 | (.venv) $ pre-commit install 33 | 34 | All pre-commit hooks have to pass before a pull request can be merged. 35 | 36 | For new recipes, we recommend adding an example to the ``\examples`` directory of this repository and updating the documentation accordingly. 37 | 38 | **Plugins** 39 | 40 | It is further possible, to add new recipes to ``mlipx`` by writing plugins. 41 | We use the entry point ``mlipx.recipes`` to load new recipes. 42 | You can find more information on entry points `here `_. 43 | 44 | Given the following file ``yourpackage/recipes.py`` in your package: 45 | 46 | .. code:: python 47 | 48 | from mlipx.recipes import app 49 | 50 | @app.command() 51 | def my_recipe(): 52 | # Your recipe code here 53 | 54 | you can add the following to your ``pyproject.toml``: 55 | 56 | .. code:: toml 57 | 58 | [project.entry-points."mlipx.recipes"] 59 | yourpackage = "yourpackage.recipes" 60 | 61 | 62 | and when your package is installed together with ``mlipx``, the recipe will be available in the CLI via ``mlipx recipes my_recipe``. 63 | -------------------------------------------------------------------------------- /docs/source/build_graph.rst: -------------------------------------------------------------------------------- 1 | .. _custom_nodes: 2 | 3 | Build your own Graph 4 | ===================== 5 | 6 | This section goes into more detail for adding your own :term:`ZnTrack` Node and designing a custom workflow. 7 | You will learn how to include :term:`MLIP` that can not be interfaced with the :code:`GenericASECalculator` Node. 8 | With your own custom Nodes you can build more comprehensive test cases or go even beyond :term:`MLIP` testing and build workflows for other scenarios, such as :term:`MLIP` training with :code:`mlipx` and :term:`IPSuite`. 9 | 10 | .. toctree:: 11 | :glob: 12 | 13 | notebooks/* 14 | -------------------------------------------------------------------------------- /docs/source/concept.rst: -------------------------------------------------------------------------------- 1 | Concept 2 | ======= 3 | 4 | ``mlipx`` is a tool designed to evaluate the performance of various **Machine-Learned Interatomic Potentials (MLIPs)**. 
5 | It offers both static and dynamic test recipes, helping you identify the most suitable MLIP for your specific problem. 6 | 7 | The ``mlipx`` package is modular and highly extensible, achieved by leveraging the capabilities of :term:`ZnTrack` and community support to provide a wide range of different test cases and :term:`MLIP` interfaces. 8 | 9 | Static Tests 10 | ------------ 11 | 12 | Static tests focus on predefined datasets that serve as benchmarks for evaluating the performance of different :term:`MLIP` models. 13 | You provide a dataset file, and ``mlipx`` evaluates a specified list of :term:`MLIP` models to generate performance metrics. 14 | These tests are ideal for comparing general performance across multiple MLIPs on tasks with well-defined input data. 15 | 16 | Dynamic Tests 17 | ------------- 18 | 19 | Dynamic tests are designed to address specific user-defined problems where the dataset is not predetermined. These tests provide flexibility and adaptability to evaluate :term:`MLIP` models based on your unique requirements. For example, if you provide only the composition of a system, ``mlipx`` can assess the suitability of various :term:`MLIP` models for the problem. 20 | 21 | - ``mlipx`` offers several methods to generate new data using recipes such as :ref:`relax`, :ref:`md`, :ref:`homonuclear_diatomics`, or :ref:`ev`. 22 | - If no starting structures are available, ``mlipx`` can search public datasets like ``mptraj`` or the Materials Project for similar data. Alternatively, new structures can be generated directly from ``smiles`` strings, as detailed in the :ref:`data` section. 23 | 24 | This dynamic approach enables a more focused evaluation of :term:`MLIP` models, tailoring the process to the specific challenges and requirements of the user's system. 25 | 26 | Comparison 27 | ---------- 28 | 29 | A comprehensive comparison of different :term:`MLIP` models is crucial to identifying the best model for a specific problem. 30 | To facilitate this, ``mlipx`` integrates with :ref:`ZnDraw ` for visualizing trajectories and creating interactive plots of the generated data. 31 | 32 | Additionally, ``mlipx`` interfaces with :term:`DVC` for data versioning and can log metrics to :term:`mlflow`, 33 | providing a quick overview of all past evaluations. 34 | 35 | 36 | .. toctree:: 37 | :hidden: 38 | 39 | concept/data 40 | concept/models 41 | concept/recipes 42 | concept/zntrack 43 | concept/zndraw 44 | concept/metrics 45 | concept/distributed 46 | -------------------------------------------------------------------------------- /docs/source/concept/data.rst: -------------------------------------------------------------------------------- 1 | .. _data: 2 | 3 | Datasets 4 | ======== 5 | 6 | Data within ``mlipx`` is always represented as a list of :term:`ASE` atoms objects. 7 | There are various ways to provide this data to the workflow, depending on your requirements. 8 | 9 | Local Data Files 10 | ---------------- 11 | 12 | The simplest way to use data in the workflow is by providing a local data file, such as a trajectory file. 13 | 14 | .. code:: console 15 | 16 | (.venv) $ cp /path/to/data.xyz . 17 | (.venv) $ dvc add data.xyz 18 | 19 | .. dropdown:: Local data file (:code:`main.py`) 20 | :open: 21 | 22 | .. 
code:: python 23 | 24 | import zntrack 25 | import mlipx 26 | 27 | DATAPATH = "data.xyz" 28 | 29 | project = mlipx.Project() 30 | 31 | with project.group("initialize"): 32 | data = mlipx.LoadDataFile(path=DATAPATH) 33 | 34 | Remote Data Files 35 | ----------------- 36 | 37 | Since ``mlipx`` integrates with :term:`DVC`, it can easily handle data from remote locations. 38 | You can manually import a remote file: 39 | 40 | .. code:: console 41 | 42 | (.venv) $ dvc import-url https://url/to/your/data.xyz data.xyz 43 | 44 | Alternatively, you can use the ``zntrack`` interface for automated management. 45 | This allows evaluation of datasets such as :code:`mptraj` and supports filtering to select relevant configurations. 46 | For example, the following code selects all structures containing :code:`F` and :code:`B` atoms. 47 | 48 | .. dropdown:: Importing online resources (:code:`main.py`) 49 | :open: 50 | 51 | .. code:: python 52 | 53 | import zntrack 54 | import mlipx 55 | 56 | mptraj = zntrack.add( 57 | url="https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mp_traj_combined.xyz", 58 | path="mptraj.xyz", 59 | ) 60 | 61 | project = mlipx.Project() 62 | 63 | with project: 64 | raw_data = mlipx.LoadDataFile(path=mptraj) 65 | data = mlipx.FilterAtoms(data=raw_data.frames, elements=["B", "F"]) 66 | 67 | Materials Project 68 | ----------------- 69 | 70 | You can also search and retrieve structures from the `Materials Project`. 71 | 72 | .. dropdown:: Querying Materials Project (:code:`main.py`) 73 | :open: 74 | 75 | .. code:: python 76 | 77 | import mlipx 78 | 79 | project = mlipx.Project() 80 | 81 | with project.group("initialize"): 82 | data = mlipx.MPRester(search_kwargs={"material_ids": ["mp-1143"]}) 83 | 84 | .. note:: 85 | To use the Materials Project, you need an API key. Set the environment variable 86 | :code:`MP_API_KEY` to your API key. 87 | 88 | Generating Data 89 | --------------- 90 | 91 | Another approach is generating data dynamically. In ``mlipx``, you can build molecules or simulation boxes from SMILES strings. 92 | For instance, the following code generates a simulation box containing 10 ethanol molecules: 93 | 94 | .. dropdown:: Using SMILES (:code:`main.py`) 95 | :open: 96 | 97 | .. code:: python 98 | 99 | import mlipx 100 | 101 | project = mlipx.Project() 102 | 103 | with project.group("initialize"): 104 | confs = mlipx.Smiles2Conformers(smiles="CCO", num_confs=10) 105 | data = mlipx.BuildBox(data=[confs.frames], counts=[10], density=789) 106 | 107 | .. note:: 108 | The :code:`BuildBox` node requires :term:`Packmol` and :term:`rdkit2ase`. 109 | If you do not need a simulation box, you can use :code:`confs.frames` directly. 110 | -------------------------------------------------------------------------------- /docs/source/concept/distributed.rst: -------------------------------------------------------------------------------- 1 | .. _Distributed evaluation: 2 | 3 | Distributed evaluation 4 | ====================== 5 | 6 | For the evaluation of different :term:`MLIP` models, it is often necessary to 7 | run the evaluation in different environments due to package incompatibility. 8 | Another reason can be the computational cost of the evaluation. 9 | 10 | Writing the evaluation in a workflow-like manner allows for the separation of tasks 11 | onto different hardware or software environments. 12 | For this purpose, the :term:`paraffin` package was developed. 13 | 14 | You can use :code:`paraffin submit` to queue the evaluation of the selected stages. 
15 | With :code:`paraffin worker --concurrency 5` you can start 5 workers to evaluate the stages. 16 | 17 | Further, you can select which stage should be picked up by which worker by defining a :code:`paraffin.yaml` file which supports wildcards. 18 | 19 | .. code-block:: yaml 20 | 21 | queue: 22 | "B_X*": BQueue 23 | "A_X_AddNodeNumbers": AQueue 24 | 25 | The above configuration will queue all stages starting with :code:`B_X` to the :code:`BQueue` and the stage :code:`A_X_AddNodeNumbers` to the :code:`AQueue`. 26 | You can then use :code:`paraffin worker --queue BQueue` to only pick up the stages from the :code:`BQueue` and vice versa. 27 | 28 | The paraffin package is available on PyPI and can be installed via: 29 | 30 | .. code-block:: bash 31 | 32 | pip install paraffin 33 | -------------------------------------------------------------------------------- /docs/source/concept/metrics.rst: -------------------------------------------------------------------------------- 1 | Metrics Overview 2 | ================ 3 | 4 | ``mlipx`` provides several tools and integrations for comparing and visualizing metrics across experiments and nodes. 5 | This section outlines how to use these features to evaluate model performance and gain insights into various tasks. 6 | 7 | Comparing Metrics Using ``mlipx compare`` 8 | ----------------------------------------- 9 | 10 | With the ``mlipx compare`` command, you can directly compare results from the same Node or experiment using the :ref:`ZnDraw ` visualization tool. For example: 11 | 12 | .. code-block:: bash 13 | 14 | mlipx compare mace-mpa-0_StructureOptimization orb-v2_0_StructureOptimization 15 | 16 | This allows you to study the performance of different models for a single task in great detail. 17 | Every Node in ``mlipx`` defines its own comparison method for this. 18 | 19 | Integrations with DVC and MLFlow 20 | -------------------------------- 21 | 22 | To enable a broader overview of metrics and enhance experiment tracking, ``mlipx`` integrates with both :term:`DVC` and :term:`mlflow`. These tools allow for efficient tracking, visualization, and comparison of metrics across multiple experiments. 23 | 24 | MLFlow Integration 25 | ------------------- 26 | 27 | ``mlipx`` supports logging metrics to :term:`mlflow`. To use this feature, ensure ``mlflow`` is installed: 28 | 29 | .. code-block:: bash 30 | 31 | pip install mlflow 32 | 33 | 34 | .. note:: 35 | 36 | More information on how to setup MLFlow and run the server can be found in the `MLFlow documentation `_. 37 | 38 | Set the tracking URI to connect to your MLFlow server: 39 | 40 | .. code-block:: bash 41 | 42 | export MLFLOW_TRACKING_URI=http://localhost:5000 43 | 44 | Use the ``zntrack mlflow-sync`` command to upload metrics to MLFlow. 45 | For this command, you need to specify the Nodes you want to sync. 46 | 47 | .. note:: 48 | You can get an overview of all available Nodes using the ``zntrack list`` command. 49 | The use of glob patterns makes it easy to sync the same node for different models. 50 | To structure the experiments in MLFlow, you can specify a parent experiment. 51 | 52 | A typical structure for syncing multiple Nodes would look like this: 53 | 54 | .. 
code-block:: bash 55 | 56 | zntrack mlflow-sync "*StructureOptimization" --experiment "mlipx" --parent "StructureOptimization" 57 | zntrack mlflow-sync "*EnergyVolumeCurve" --experiment "mlipx" --parent "EnergyVolumeCurve" 58 | zntrack mlflow-sync "*MolecularDynamics" --experiment "mlipx" --parent "MolecularDynamics" 59 | 60 | With the MLFlow UI, you can visualize and compare metrics across experiments: 61 | 62 | .. image:: https://github.com/user-attachments/assets/2536d5d5-f8ef-4403-ac4b-670d40ae64de 63 | :align: center 64 | :alt: MLFlow UI Metrics 65 | :width: 100% 66 | :class: only-dark 67 | 68 | .. image:: https://github.com/user-attachments/assets/0d3d3187-b8ee-4b27-855e-7b245bd88346 69 | :align: center 70 | :alt: MLFlow UI Metrics 71 | :width: 100% 72 | :class: only-light 73 | 74 | Additionally, ``mlipx`` logs plots to MLFlow, enabling comparisons of relaxation energies across models or direct visualizations of energy-volume curves: 75 | 76 | .. image:: https://github.com/user-attachments/assets/19305012-6d92-40a3-bac6-68522bd55490 77 | :align: center 78 | :alt: MLFlow UI Plots 79 | :width: 100% 80 | :class: only-dark 81 | 82 | .. image:: https://github.com/user-attachments/assets/3cffba32-7abf-4a36-ac44-b584126c2e57 83 | :align: center 84 | :alt: MLFlow UI Plots 85 | :width: 100% 86 | :class: only-light 87 | 88 | 89 | Data Version Control (DVC) 90 | --------------------------- 91 | 92 | Each Node in ``mlipx`` includes predefined metrics that can be accessed via the :term:`DVC` command-line interface. Use the following commands to view metrics and plots: 93 | 94 | .. code-block:: bash 95 | 96 | dvc metrics show 97 | dvc plots show 98 | 99 | For more details on working with DVC, refer to the `DVC documentation `_. 100 | 101 | DVC also integrates seamlessly with Visual Studio Code through the `DVC extension `_, providing a user-friendly interface to browse and compare metrics and plots: 102 | 103 | .. image:: https://github.com/user-attachments/assets/79ede9d2-e11f-47da-b69c-523aa0361aaa 104 | :alt: DVC extension in Visual Studio Code 105 | :width: 100% 106 | :class: only-dark 107 | 108 | .. image:: https://github.com/user-attachments/assets/562ab225-15a8-409a-8e4e-f585e33103fa 109 | :alt: DVC extension in Visual Studio Code 110 | :width: 100% 111 | :class: only-light 112 | -------------------------------------------------------------------------------- /docs/source/concept/models.rst: -------------------------------------------------------------------------------- 1 | Models 2 | ====== 3 | 4 | For each recipe, the models to evaluate are defined in the :term:`models.py` file. 5 | In most cases, you can use :code:`mlipx.GenericASECalculator` to access models. 6 | However, in certain scenarios, you may need to provide a custom calculator. 7 | 8 | Below, we demonstrate how to write a custom calculator node for :code:`SevenCalc`. 9 | While this is an example, note that :code:`SevenCalc` could also be used with :code:`mlipx.GenericASECalculator`. 10 | 11 | Defining Models 12 | --------------- 13 | 14 | Here is the content of a typical :code:`models.py` file: 15 | 16 | .. dropdown:: Content of :code:`models.py` 17 | :open: 18 | 19 | .. 
code-block:: python 20 | 21 | import mlipx 22 | from src import SevenCalc 23 | 24 | mace_medium = mlipx.GenericASECalculator( 25 | module="mace.calculators", 26 | class_name="MACECalculator", 27 | device='auto', 28 | kwargs={ 29 | "model_paths": "mace_models/y7uhwpje-medium.model", 30 | }, 31 | ) 32 | 33 | mace_agnesi = mlipx.GenericASECalculator( 34 | module="mace.calculators", 35 | class_name="MACECalculator", 36 | device='auto', 37 | kwargs={ 38 | "model_paths": "mace_models/mace_mp_agnesi_medium.model", 39 | }, 40 | ) 41 | 42 | sevennet = SevenCalc(model='7net-0') 43 | 44 | MODELS = { 45 | "mace_medm": mace_medium, 46 | "mace_agne": mace_agnesi, 47 | "7net": sevennet, 48 | } 49 | 50 | Custom Calculator Example 51 | ------------------------- 52 | 53 | The :code:`SevenCalc` class, used in the example above, is defined in :code:`src/__init__.py` as follows: 54 | 55 | .. dropdown:: Content of :code:`src/__init__.py` 56 | :open: 57 | 58 | .. code-block:: python 59 | 60 | import dataclasses 61 | from ase.calculators.calculator import Calculator 62 | 63 | @dataclasses.dataclass 64 | class SevenCalc: 65 | model: str 66 | 67 | def get_calculator(self, **kwargs) -> Calculator: 68 | from sevenn.sevennet_calculator import SevenNetCalculator 69 | sevennet = SevenNetCalculator(self.model, device='cpu') 70 | 71 | return sevennet 72 | 73 | For more details, refer to the :ref:`custom_nodes` section. 74 | 75 | .. _update-frames-calc: 76 | 77 | Updating Dataset Keys 78 | --------------------- 79 | 80 | In some cases, models may need to be defined to convert existing dataset keys into the format :code:`mlipx` expects. 81 | For example, you may need to provide isolated atom energies or convert data where energies are stored as :code:`atoms.info['DFT_ENERGY']` 82 | and forces as :code:`atoms.arrays['DFT_FORCES']`. 83 | 84 | Here’s how to define a model for such a scenario: 85 | 86 | .. code-block:: python 87 | 88 | import mlipx 89 | 90 | REFERENCE = mlipx.UpdateFramesCalc( 91 | results_mapping={"energy": "DFT_ENERGY", "forces": "DFT_FORCES"}, 92 | info_mapping={mlipx.abc.ASEKeys.isolated_energies.value: "isol_ene"}, 93 | ) 94 | -------------------------------------------------------------------------------- /docs/source/concept/recipes.rst: -------------------------------------------------------------------------------- 1 | .. _recipes: 2 | 3 | Recipes 4 | ======= 5 | 6 | One of :code:`mlipx` core functionality is providing you with pre-designed recipes. 7 | These define workflows for evaluating :term:`MLIP` on specific tasks. 8 | You can get an overview of all available recipes using 9 | 10 | .. code-block:: console 11 | 12 | (.venv) $ mlipx recipes --help 13 | 14 | All recipes follow the same structure. 15 | It is recommended, to create a new directory for each recipe. 16 | 17 | .. code-block:: console 18 | 19 | (.venv) $ mkdir molecular_dynamics 20 | (.venv) $ cd molecular_dynamics 21 | (.venv) $ mlipx recipes md --initialize 22 | 23 | This will create the following structure: 24 | 25 | .. code-block:: console 26 | 27 | molecular_dynamics/ 28 | ├── .git/ 29 | ├── .dvc/ 30 | ├── models.py 31 | └── main.py 32 | 33 | After initialization, adapt the :code:`main.py` file to point towards the requested data files. 34 | Define all models for testing in the :term:`models.py` file. 35 | 36 | Finally, build the recipe using 37 | 38 | .. 
code-block:: console 38 | 39 | (.venv) $ python main.py 40 | (.venv) $ dvc repro 41 | 42 | 43 | Upload Results 44 | -------------- 45 | Once the recipe is finished, you can persist the results and upload them to remote storage. 46 | To do so, make a GIT commit and push it to your repository. 47 | 48 | .. code-block:: console 49 | 50 | (.venv) $ git add . 51 | (.venv) $ git commit -m "Finished molecular dynamics test" 52 | (.venv) $ git push 53 | (.venv) $ dvc push 54 | 55 | .. note:: 56 | You need to define a :term:`GIT` and :term:`DVC` remote to push the results. 57 | More information on how to set up a :term:`DVC` remote can be found at https://dvc.org/doc/user-guide/data-management/remote-storage. 58 | 59 | 60 | In combination or as an alternative, you can upload the results to a parameter and metric tracking service, such as :term:`mlflow`. 61 | Given a running :term:`mlflow` server, you can use the following command to upload the results: 62 | 63 | .. code-block:: console 64 | 65 | (.venv) $ zntrack mlflow-sync --help 66 | 67 | .. note:: 68 | Depending on the installed packages, the :term:`mlflow` command might not be available. 69 | This functionality is provided by the :term:`zntrack` package, and other tracking services can be used as well. 70 | They will show up once the respective package is installed. 71 | See https://zntrack.readthedocs.io/ for more information. 72 | -------------------------------------------------------------------------------- /docs/source/concept/zndraw.rst: -------------------------------------------------------------------------------- 1 | .. _zndraw: 2 | 3 | Visualisation 4 | ============= 5 | :code:`mlipx` uses ZnDraw as the primary tool for visualisation and comparison. 6 | The following will give you an overview. 7 | 8 | The ZnDraw package provides versatile visualisation for atomistic structures. 9 | It is based on :term:`ASE` and runs as a web application. 10 | You can install it via: 11 | 12 | .. code:: bash 13 | 14 | pip install zndraw 15 | 16 | 17 | It can be used to visualize data through its CLI: 18 | 19 | .. code:: bash 20 | 21 | zndraw file.xyz # any ASE supported file format + H5MD 22 | zndraw --remote . Node.frames # any ZnTrack node that has an attribute `list[ase.Atoms]` 23 | 24 | Once you have a running ZnDraw instance, you can connect to it from within Python. 25 | You can find more information in the GUI by clicking on :code:`Python Access`. 26 | The :code:`vis` object behaves like a list of :term:`ASE` atom objects. 27 | Modifying them in place will be reflected in real time in the GUI. 28 | 29 | .. tip:: 30 | 31 | You can keep a ZnDraw instance running in the background and set the environment variable :code:`ZNDRAW_URL` to the URL of the running instance. 32 | This way, you do not have to define a ZnDraw URL when running ZnDraw or ``mlipx`` CLI commands. 33 | You can also set up a `ZnDraw Docker container `_ to always have a running instance. 34 | 35 | .. code:: python 36 | 37 | from zndraw import ZnDraw 38 | 39 | vis = ZnDraw(url="http://localhost:1234", token="") 40 | 41 | print(vis[0]) 42 | >>> ase.Atoms(...) 43 | 44 | vis.append(ase.Atoms(...)) 45 | 46 | 47 | .. image:: ../_static/zndraw_render.png 48 | :width: 100% 49 | 50 | **Figure 1** Graphical user interface of the :ref:`ZnDraw <zndraw>` package with GPU path tracing enabled. 51 | 52 | 53 | For further information have a look at the ZnDraw repository https://github.com/zincware/zndraw - a full documentation will be provided soon.
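As a minimal illustration of the tip above, assuming a ZnDraw instance is already running at ``http://localhost:1234`` and a structure optimization recipe has already been reproduced in the current repository, you can export :code:`ZNDRAW_URL` once and let the ``mlipx`` comparison command send its output to that instance:

.. code:: bash

    export ZNDRAW_URL=http://localhost:1234
    mlipx compare --glob "*StructureOptimization"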
54 | -------------------------------------------------------------------------------- /docs/source/concept/zntrack.rst: -------------------------------------------------------------------------------- 1 | Workflows 2 | ========= 3 | 4 | The :code:`mlipx` package is based on ZnTrack. 5 | Although :code:`mlipx` usage does not require you to understand how ZnTrack works, the following will give a short overview of the concept. 6 | We will take the example of building a simulation box from :code:`smiles` strings, as illustrated in the following Python script. 7 | 8 | .. code:: python 9 | 10 | import rdkit2ase 11 | 12 | water = rdkit2ase.smiles2atoms('O') 13 | ethanol = rdkit2ase.smiles2atoms('CCO') 14 | 15 | box = rdkit2ase.pack([[water], [ethanol]], counts=[50, 50], density=800) 16 | print(box) 17 | >>> ase.Atoms(...) 18 | 19 | This script can also be represented as the following workflow, which we will now convert. 20 | 21 | .. mermaid:: 22 | :align: center 23 | 24 | graph TD 25 | BuildWater --> PackBox 26 | BuildEtOH --> PackBox 27 | 28 | 29 | With ZnTrack you can build complex workflows based on :term:`DVC` and :term:`GIT`. 30 | The first part of a workflow is defining the steps, which in the context of ZnTrack are called :code:`Node`. 31 | A :code:`Node` is based on the Python :code:`dataclass` module, defining its arguments as class attributes. 32 | 33 | .. note:: 34 | 35 | It is highly recommended to follow the single-responsibility principle when writing a :code:`Node`. For example, if you have a relaxation followed by a molecular dynamics simulation, separate these into two Nodes. But also keep in mind that there is some communication overhead between Nodes, so e.g. defining each MD step as a separate Node would not be recommended. 36 | 37 | .. code:: python 38 | 39 | import zntrack 40 | import ase 41 | import rdkit2ase 42 | 43 | class BuildMolecule(zntrack.Node): 44 | smiles: str = zntrack.params() 45 | 46 | frames: list[ase.Atoms] = zntrack.outs() 47 | 48 | def run(self): 49 | self.frames = [rdkit2ase.smiles2atoms(self.smiles)] 50 | 51 | With this :code:`BuildMolecule` class we can bring :code:`rdkit2ase.smiles2atoms` onto the graph by defining its inputs and outputs. 52 | Further, we need to define a :code:`Node` for the :code:`rdkit2ase.pack` function. 53 | For this, we define the :code:`PackBox` node as follows: 54 | 55 | .. code:: python 56 | 57 | import ase.io 58 | import pathlib 59 | 60 | class PackBox(zntrack.Node): 61 | data: list[list[ase.Atoms]] = zntrack.deps() 62 | counts: list[int] = zntrack.params() 63 | density: float = zntrack.params() 64 | 65 | frames_path: pathlib.Path = zntrack.outs_path(zntrack.nwd / 'frames.xyz') 66 | 67 | def run(self): 68 | box = rdkit2ase.pack(self.data, counts=self.counts, density=self.density) 69 | ase.io.write(self.frames_path, box) 70 | 71 | .. note:: 72 | 73 | The :code:`zntrack.outs_path(zntrack.nwd / 'frames.xyz')` provides a unique output path per node in the node working directory (nwd). It is crucial to define every input and output as ZnTrack attributes. Otherwise, the results will be lost. 74 | 75 | With this Node, we can build our graph: 76 | 77 | .. code:: python 78 | 79 | project = zntrack.Project() 80 | 81 | with project: 82 | water = BuildMolecule(smiles="O") 83 | ethanol = BuildMolecule(smiles="CCO") 84 | 85 | box = PackBox(data=[water.frames, ethanol.frames], counts=[50, 50], density=800) 86 | 87 | project.build() 88 | 89 | ..
note:: 90 | 91 | The `project.build()` command will not run the graph but only define how the graph is to be executed in the future. 92 | Consider it a pure graph definition file. 93 | If you write this into a single :code:`main.py` file, it should look like 94 | 95 | .. dropdown:: Content of :code:`main.py` 96 | 97 | .. code-block:: python 98 | 99 | import zntrack 100 | import ase.io 101 | import rdkit2ase 102 | import pathlib 103 | 104 | class BuildMolecule(zntrack.Node): 105 | smiles: str = zntrack.params() 106 | 107 | frames: list[ase.Atoms] = zntrack.outs() 108 | 109 | def run(self): 110 | self.frames = [rdkit2ase.smiles2atoms(self.smiles)] 111 | 112 | class PackBox(zntrack.Node): 113 | data: list[list[ase.Atoms]] = zntrack.deps() 114 | counts: list[int] = zntrack.params() 115 | density: float = zntrack.params() 116 | 117 | frames_path: pathlib.Path = zntrack.outs_path(zntrack.nwd / 'frames.xyz') 118 | 119 | def run(self): 120 | box = rdkit2ase.pack(self.data, counts=self.counts, density=self.density) 121 | ase.io.write(self.frames_path, box) 122 | 123 | if __name__ == "__main__": 124 | project = zntrack.Project() 125 | 126 | with project: 127 | water = BuildMolecule(smiles="O") 128 | ethanol = BuildMolecule(smiles="CCO") 129 | 130 | box = PackBox(data=[water.frames, ethanol.frames], counts=[50, 50], density=800) 131 | 132 | project.build() 133 | 134 | To run the graph you can use the :term:`DVC` CLI :code:`dvc repro` (or the :term:`paraffin` package, see :ref:`Distributed evaluation`. ) 135 | 136 | Once finished, you can look at the results by loading the nodes: 137 | 138 | .. code:: python 139 | 140 | import zntrack 141 | import ase.io 142 | 143 | box = zntrack.from_rev("PackBox") 144 | print(ase.io.read(box.frames_path)) 145 | >>> ase.Atoms(...) 146 | 147 | 148 | For further information have a look at the ZnTrack documentation https://zntrack.readthedocs.io and repository https://github.com/zincware/zntrack . 149 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 
2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | import typing as t 7 | 8 | import mlipx 9 | 10 | # -- Project information ----------------------------------------------------- 11 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 12 | 13 | project = "mlipx" 14 | copyright = "2025, Fabian Zills, Sheena Agarwal, Sandip De" 15 | author = "Fabian Zills, Sheena Agarwal, Sandip De" 16 | release = mlipx.__version__ 17 | 18 | # -- General configuration --------------------------------------------------- 19 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 20 | 21 | extensions = [ 22 | "sphinx.ext.doctest", 23 | "sphinx.ext.autodoc", 24 | "sphinxcontrib.mermaid", 25 | "sphinx.ext.viewcode", 26 | "sphinx.ext.napoleon", 27 | "hoverxref.extension", 28 | "sphinxcontrib.bibtex", 29 | "sphinx_copybutton", 30 | "jupyter_sphinx", 31 | "sphinx_design", 32 | "nbsphinx", 33 | "sphinx_mdinclude", 34 | ] 35 | 36 | templates_path = ["_templates"] 37 | exclude_patterns = [] 38 | 39 | 40 | # -- Options for HTML output ------------------------------------------------- 41 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 42 | 43 | html_theme = "furo" 44 | html_static_path = ["_static"] 45 | html_title = "Machine Learned Interatomic Potential eXploration" 46 | html_short_title = "mlipx" 47 | html_favicon = "_static/mlipx-favicon.svg" 48 | 49 | html_theme_options: t.Dict[str, t.Any] = { 50 | "light_logo": "mlipx-light.svg", 51 | "dark_logo": "mlipx-dark.svg", 52 | "footer_icons": [ 53 | { 54 | "name": "GitHub", 55 | "url": "https://github.com/basf/mlipx", 56 | "html": "", 57 | "class": "fa-brands fa-github fa-2x", 58 | }, 59 | ], 60 | "source_repository": "https://github.com/basf/mlipx/", 61 | "source_branch": "main", 62 | "source_directory": "docs/source/", 63 | "navigation_with_keys": True, 64 | } 65 | 66 | # font-awesome logos 67 | html_css_files = [ 68 | "https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/brands.min.css", 69 | ] 70 | 71 | # -- Options for hoverxref extension ----------------------------------------- 72 | # https://sphinx-hoverxref.readthedocs.io/en/latest/ 73 | 74 | hoverxref_roles = ["term"] 75 | hoverxref_role_types = { 76 | "class": "tooltip", 77 | } 78 | 79 | 80 | # -- Options for sphinxcontrib-bibtex ---------------------------------------- 81 | # https://sphinxcontrib-bibtex.readthedocs.io/en/latest/ 82 | 83 | bibtex_bibfiles = ["references.bib"] 84 | 85 | # -- Options for sphinx_copybutton ------------------------------------------- 86 | # https://sphinx-copybutton.readthedocs.io/en/latest/ 87 | 88 | copybutton_prompt_text = r">>> |\.\.\. |\(.*\) \$ " 89 | copybutton_prompt_is_regexp = True 90 | -------------------------------------------------------------------------------- /docs/source/glossary.rst: -------------------------------------------------------------------------------- 1 | Glossary 2 | ======== 3 | 4 | .. glossary:: 5 | 6 | MLIP 7 | Machine learned interatomic potential. 8 | 9 | GIT 10 | GIT is a distributed version control system. It allows multiple people to work on a project at the same time without overwriting each other's changes. It also keeps a history of all changes made to the project, so you can easily revert to an earlier version if necessary. 
11 | 12 | DVC 13 | Data Version Control (DVC) is a tool used in machine learning projects to track and version the datasets used in the project, as well as the code and the results. This makes it easier to reproduce experiments and share results with others. 14 | More information can be found at https://dvc.org/ . 15 | 16 | mlflow 17 | Mlflow is an open-source platform that helps manage the machine learning lifecycle, including experimentation, reproducibility, and deployment. It keeps track of the parameters used in the model as well as the metrics obtained from the model. 18 | More information can be found at https://mlflow.org/ . 19 | 20 | ZnTrack 21 | ZnTrack :footcite:t:`zillsZnTrackDataCode2024` is a Python package that provides a framework for defining and executing workflows. It allows users to define a sequence of tasks, each represented by a Node, and manage their execution and dependencies. 22 | The package can be installed via :code:`pip install zntracck` or from source at https://github.com/zincware/zntrack . 23 | 24 | IPSuite 25 | IPSuite by :footcite:t:`zillsCollaborationMachineLearnedPotentials2024` is a software package for the development and application of machine-learned interatomic potentials (MLIPs). It provides functionalities for training and testing MLIPs, as well as for running simulations using these potentials. 26 | The package can be installed via :code:`pip install ipsuite` or from source at https://github.com/zincware/ipsuite . 27 | 28 | ZnDraw 29 | The :ref:`ZnDraw ` package for visualisation and editing of atomistic structures :footcite:`elijosiusZeroShotMolecular2024`. 30 | The package can be installed via :code:`pip install zndraw` or from source at https://github.com/zincware/zndraw . 31 | 32 | main.py 33 | The :term:`ZnTrack` graph definition for the recipe is stored in this file. 34 | 35 | models.py 36 | The file that contains the models for testing. 37 | Each recipe will import the models from this file. 38 | It should follow the following format: 39 | 40 | .. code-block:: python 41 | 42 | from mlipx.abc import NodeWithCalculator 43 | 44 | MODELS: dict[str, NodeWithCalculator] = { 45 | ... 46 | } 47 | 48 | packmol 49 | Packmol is a software package used for building initial configurations for molecular dynamics or Monte Carlo simulations. It can generate a random collection of molecules using the specified density and composition. More information can be found at https://m3g.github.io/packmol/ . 50 | 51 | rdkit2ase 52 | A package for converting RDKit molecules to ASE atoms. 53 | The package can be installed via :code:`pip install rdkit2ase` or from source at https://github.com/zincware/rdkit2ase . 54 | 55 | Node 56 | A node is a class that represents a single step in the workflow. 57 | It should inherit from the :class:`zntrack.Node` class. 58 | The node should implement the :meth:`zntrack.Node.run` method. 59 | 60 | ASE 61 | The Atomic Simulation Environment (ASE). More information can be found at https://wiki.fysik.dtu.dk/ase/ 62 | 63 | paraffin 64 | The paraffin package for the distributed evaluation of :term:`DVC` stages. 65 | The package can be installed via :code:`pip install paraffin` or from source at https://github.com/zincware/paraffin . 66 | 67 | 68 | .. 
footbibliography:: 69 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | MLIPX Documentation 2 | =================== 3 | 4 | :code:`mlipx` is a Python library for the evaluation and benchmarking of machine-learned interatomic potentials (:term:`MLIP`). MLIPs are advanced computational models that use machine learning techniques to describe complex atomic interactions accurately and efficiently. They significantly accelerate traditional quantum mechanical modeling workflows, making them invaluable for simulating a wide range of physical phenomena, from chemical reactions to materials properties and phase transitions. MLIP testing requires more than static cross-validation protocols. While these protocols are essential, they are just the beginning. Evaluating energy and force prediction accuracies on a test set is only the first step. To determine the real-world usability of an ML model, more comprehensive testing is needed. 5 | 6 | :code:`mlipx` addresses this need by providing systematically designed testing recipes to assess the strengths and weaknesses of rapidly developing growing flavours of MLIP models. These recipes help ensure that models are robust and applicable to a wide range of scenarios. :code:`mlipx` provides you with an ever-growing set of evaluation methods accompanied by comprehensive visualization and comparison tools. 7 | 8 | The goal of this project is to provide a common platform for the evaluation of MLIPs and to facilitate the exchange of evaluation results between researchers. 9 | Ultimately, you should be able to determine the applicability of a given MLIP for your specific research question and to compare it to other MLIPs. 10 | 11 | By offering these capabilities, MLIPX helps researchers determine the applicability of MLIPs for specific research questions and compare them effectively while developing from scratch or finetuning universal models. This collaborative tool promotes transparency and reproducibility in MLIP evaluations. 12 | 13 | Join us in using and improving MLIPX to advance the field of machine-learned interatomic potentials. Your contributions and feedback are invaluable. 14 | 15 | .. note:: 16 | 17 | This project is under active development. 18 | 19 | 20 | Example 21 | ------- 22 | 23 | Create a ``mlipx`` :ref:`recipe ` to compute :ref:`ev` for the `mp-1143 `_ structure using different :term:`MLIP` models 24 | 25 | .. code-block:: console 26 | 27 | (.venv) $ mlipx recipes ev --models mace-mpa-0,sevennet,orb-v2 --material-ids=mp-1143 --repro 28 | (.venv) $ mlipx compare --glob "*EnergyVolumeCurve" 29 | 30 | and use the integration with :ref:`ZnDraw ` to visualize the resulting trajectories and compare the energies interactively. 31 | 32 | .. image:: https://github.com/user-attachments/assets/c2479d17-c443-4550-a641-c513ede3be02 33 | :width: 100% 34 | :alt: ZnDraw 35 | :class: only-light 36 | 37 | .. image:: https://github.com/user-attachments/assets/2036e6d9-3342-4542-9ddb-bbc777d2b093 38 | :width: 100% 39 | :alt: ZnDraw 40 | :class: only-dark 41 | 42 | .. 
toctree:: 43 | :hidden: 44 | :maxdepth: 2 45 | 46 | installation 47 | quickstart 48 | concept 49 | recipes 50 | build_graph 51 | nodes 52 | glossary 53 | abc 54 | authors 55 | -------------------------------------------------------------------------------- /docs/source/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | From PyPI 5 | --------- 6 | 7 | To use :code:`mlipx`, first install it using pip: 8 | 9 | .. code-block:: console 10 | 11 | (.venv) $ pip install mlipx 12 | 13 | .. note:: 14 | 15 | The :code:`mlipx` package installation does not contain any :term:`MLIP` packages. 16 | Due to different dependencies, it is highly recommended to install your preferred :term:`MLIP` package individually into the same environment. 17 | We provide extras for the :term:`MLIP` packages included in our documentation. 18 | You can install them using extras (not exhaustive): 19 | 20 | .. code-block:: console 21 | 22 | (.venv) $ pip install mlipx[mace] 23 | (.venv) $ pip install mlipx[orb] 24 | 25 | To get an overview of the models :code:`mlipx` currently knows about, you can use the following command: 26 | 27 | .. code-block:: console 28 | 29 | (.venv) $ mlipx info 30 | 31 | .. note:: 32 | 33 | If you encounter an error like :code:`Permission denied '/var/cache/dvc'`, you might want to reinstall :code:`platformdirs` via :code:`pip install platformdirs==3.11.0` or :code:`pip install platformdirs==3.10.0`, as discussed at https://github.com/iterative/dvc/issues/9184 34 | 35 | From Source 36 | ----------- 37 | 38 | To install and develop :code:`mlipx` from source, we recommend using :code:`uv`. 39 | More information and installation instructions can be found at https://docs.astral.sh/uv/getting-started/installation/ . 40 | 41 | .. code:: console 42 | 43 | (.venv) $ git clone https://github.com/basf/mlipx 44 | (.venv) $ cd mlipx 45 | (.venv) $ uv sync 46 | (.venv) $ source .venv/bin/activate 47 | 48 | You can quickly switch between different :term:`MLIP` package extras using the :code:`uv sync` command. 49 | 50 | 51 | .. code:: console 52 | 53 | (.venv) $ uv sync --extra mattersim 54 | (.venv) $ uv sync --extra sevenn 55 | -------------------------------------------------------------------------------- /docs/source/nodes.rst: -------------------------------------------------------------------------------- 1 | Nodes 2 | ===== 3 | 4 | The core functionality of :code:`mlipx` is based on :term:`ZnTrack` nodes. 5 | Each Node is documented in the following section. 6 | 7 | .. glossary:: 8 | 9 | LoadDataFile 10 | .. autofunction:: mlipx.LoadDataFile 11 | 12 | LangevinConfig 13 | .. autofunction:: mlipx.LangevinConfig 14 | 15 | ApplyCalculator 16 | .. autofunction:: mlipx.ApplyCalculator 17 | 18 | EvaluateCalculatorResults 19 | .. autofunction:: mlipx.EvaluateCalculatorResults 20 | 21 | CompareCalculatorResults 22 | .. autofunction:: mlipx.CompareCalculatorResults 23 | 24 | CompareFormationEnergy 25 | .. autofunction:: mlipx.CompareFormationEnergy 26 | 27 | CalculateFormationEnergy 28 | .. autofunction:: mlipx.CalculateFormationEnergy 29 | 30 | MaximumForceObserver 31 | .. autofunction:: mlipx.MaximumForceObserver 32 | 33 | TemperatureRampModifier 34 | .. autofunction:: mlipx.TemperatureRampModifier 35 | 36 | MolecularDynamics 37 | .. autofunction:: mlipx.MolecularDynamics 38 | 39 | HomonuclearDiatomics 40 | .. autofunction:: mlipx.HomonuclearDiatomics 41 | 42 | NEBinterpolate 43 | ..
autofunction:: mlipx.NEBinterpolate 44 | 45 | NEBs 46 | .. autofunction:: mlipx.NEBs 47 | 48 | VibrationalAnalysis 49 | .. autofunction:: mlipx.VibrationalAnalysis 50 | 51 | PhaseDiagram 52 | .. autofunction:: mlipx.PhaseDiagram 53 | 54 | PourbaixDiagram 55 | .. autofunction:: mlipx.PourbaixDiagram 56 | 57 | StructureOptimization 58 | .. autofunction:: mlipx.StructureOptimization 59 | 60 | Smiles2Conformers 61 | .. autofunction:: mlipx.Smiles2Conformers 62 | 63 | BuildBox 64 | .. autofunction:: mlipx.BuildBox 65 | 66 | EnergyVolumeCurve 67 | .. autofunction:: mlipx.EnergyVolumeCurve 68 | 69 | FilterAtoms 70 | .. autofunction:: mlipx.FilterAtoms 71 | 72 | RotationalInvariance 73 | .. autofunction:: mlipx.RotationalInvariance 74 | 75 | TranslationalInvariance 76 | .. autofunction:: mlipx.TranslationalInvariance 77 | 78 | PermutationInvariance 79 | .. autofunction:: mlipx.PermutationInvariance 80 | 81 | RelaxAdsorptionConfigs 82 | .. autofunction:: mlipx.RelaxAdsorptionConfigs 83 | 84 | BuildASEslab 85 | .. autofunction:: mlipx.BuildASEslab 86 | -------------------------------------------------------------------------------- /docs/source/notebooks/structure_relaxation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/basf/mlipx/blob/main/docs/source/notebooks/structure_relaxation.ipynb)\n", 8 | "\n", 9 | "# Structure Relaxtion with Custom Nodes\n", 10 | "\n", 11 | "You can combine `mlipx` with custom code by writing ZnTrack nodes.\n", 12 | "We will write a Node to perform a geometry relaxation similar to `mlipx.StructureOptimization`." 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "# Only install the packages if they are not already installed\n", 22 | "!pip show mlipx > /dev/null 2>&1 || pip install mlipx\n", 23 | "!pip show rdkit2ase > /dev/null 2>&1 || pip install rdkit2ase" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "# We will create a GIT and DVC repository in a temporary directory\n", 33 | "import os\n", 34 | "import tempfile\n", 35 | "\n", 36 | "temp_dir = tempfile.TemporaryDirectory()\n", 37 | "os.chdir(temp_dir.name)" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "Like all `mlipx` Nodes we will use a GIT and DVC repository to run experiments.\n", 45 | "To make our custom code available, we structure our project like\n", 46 | "\n", 47 | "```\n", 48 | "relaxation/\n", 49 | " ├── .git/\n", 50 | " ├── .dvc/\n", 51 | " ├── src/__init__.py\n", 52 | " ├── src/relaxation.py\n", 53 | " ├── models.py\n", 54 | " └── main.py\n", 55 | "```\n", 56 | "\n", 57 | "to allow us to import our code `from src.relaxation import Relax`.\n", 58 | "Alternatively, you can package your code and import it like any other Python package." 
59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "!git init\n", 68 | "!dvc init --quiet\n", 69 | "!mkdir src\n", 70 | "!touch src/__init__.py" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": {}, 76 | "source": [ 77 | "The code we want to put into our `Relax` `Node` is the following:\n", 78 | "\n", 79 | "\n", 80 | "```python\n", 81 | "from ase.optimize import BFGS\n", 82 | "import ase.io\n", 83 | "\n", 84 | "data: list[ase.Atoms]\n", 85 | "calc: ase.calculator.Calculator\n", 86 | "\n", 87 | "end_structures = []\n", 88 | "for atoms in data:\n", 89 | " atoms.set_calculator(calc)\n", 90 | " opt = BFGS(atoms)\n", 91 | " opt.run(fmax=0.05)\n", 92 | " end_structures.append(atoms)\n", 93 | "\n", 94 | "ase.io.write('end_structures.xyz', end_structures)\n", 95 | "```\n", 96 | "\n", 97 | "To do so, we need to identify and define the inputs and outputs of our code.\n", 98 | "We provide the `data: list[ase.Atoms]` from a data loading Node.\n", 99 | "Therefore, we use `data: list = zntrack.deps()`.\n", 100 | "If you want to read directly from file you could use `data_path: str = zntrack.deps_path()`.\n", 101 | "We access the calculator in a similar way using `model: NodeWithCalculator = zntrack.deps()`.\n", 102 | "`mlipx` provides the `NodeWithCalculator` abstract base class for a common communication on how to share `ASE` calculators.\n", 103 | "Another convention is providing inputs as `data: list[ase.Atoms]` and outputs as `frames: list[ase.Atoms]`.\n", 104 | "As we save our data to a file, we define `frames_path: pathlib.Path = zntrack.outs_path(zntrack.nwd / 'frames.xyz')` to store the output trajetory in the node working directory (nwd) as `frames.xyz`.\n", 105 | "The `zntrack.nwd` provides a unique directory per `Node` to store the data at.\n", 106 | "As the communication between `mlipx` nodes is based on `ase.Atoms` we define a `@frames` property.\n", 107 | "Within this, we could also alter the `ase.Atoms` object, thus making the node communication independent of the file format facilitating data communication via code or Data as Code (DaC).\n", 108 | "To summarize, each Node provides all the information on how to `save` and `load` the produced data, simplifying communication and reducing issues with different file format conventions.\n", 109 | "\n", 110 | "Besides the implemented fields, there are also `x: dict = zntrack.params`, `x: dict = zntrack.metrics` and `x: pd.DataFrame = zntrack.plots` and their corresponding file path versions `x: str|pathlib.Path = zntrack.params_path`, `zntrack.metrics_path` and `zntrack.plots_path`.\n", 111 | "For general outputs there is `x: any = zntrack.outs`. More information can be found at https://dvc.org/doc/start/data-pipelines/metrics-parameters-plots ." 
112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "%%writefile src/relaxation.py\n", 121 | "import zntrack\n", 122 | "from mlipx.abc import NodeWithCalculator\n", 123 | "from ase.optimize import BFGS\n", 124 | "import ase.io\n", 125 | "import pathlib\n", 126 | "\n", 127 | "\n", 128 | "\n", 129 | "class Relax(zntrack.Node):\n", 130 | " data: list = zntrack.deps()\n", 131 | " model: NodeWithCalculator = zntrack.deps()\n", 132 | " frames_path: pathlib.Path = zntrack.outs_path(zntrack.nwd / 'frames.xyz')\n", 133 | "\n", 134 | " def run(self):\n", 135 | " end_structures = []\n", 136 | " for atoms in self.data:\n", 137 | " atoms.set_calculator(self.model.get_calculator())\n", 138 | " opt = BFGS(atoms)\n", 139 | " opt.run(fmax=0.05)\n", 140 | " end_structures.append(atoms)\n", 141 | " with open(self.frames_path, 'w') as f:\n", 142 | " ase.io.write(f, end_structures, format='extxyz')\n", 143 | " \n", 144 | " @property\n", 145 | " def frames(self) -> list[ase.Atoms]:\n", 146 | " with self.state.fs.open(self.frames_path, \"r\") as f:\n", 147 | " return ase.io.read(f, format='extxyz', index=':')\n" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "metadata": {}, 153 | "source": [ 154 | "With this Node definition, we can import the Node and connect it with `mlipx` to form a graph." 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "from src.relaxation import Relax\n", 164 | "\n", 165 | "import mlipx" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [ 174 | "project = mlipx.Project()\n", 175 | "\n", 176 | "emt = mlipx.GenericASECalculator(\n", 177 | " module=\"ase.calculators.emt\",\n", 178 | " class_name=\"EMT\",\n", 179 | ")\n", 180 | "\n", 181 | "with project:\n", 182 | " confs = mlipx.Smiles2Conformers(smiles=\"CCCC\", num_confs=5)\n", 183 | " relax = Relax(data=confs.frames, model=emt)\n", 184 | "\n", 185 | "project.build()" 186 | ] 187 | }, 188 | { 189 | "cell_type": "markdown", 190 | "metadata": {}, 191 | "source": [ 192 | "To execute the graph, we make use of `dvc repro` via `project.repro`." 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": null, 198 | "metadata": {}, 199 | "outputs": [], 200 | "source": [ 201 | "project.repro(build=False)" 202 | ] 203 | }, 204 | { 205 | "cell_type": "markdown", 206 | "metadata": {}, 207 | "source": [ 208 | "Once the graph has been executed, we can look at the resulting structures." 
209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": null, 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [ 217 | "relax.frames" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": null, 223 | "metadata": {}, 224 | "outputs": [], 225 | "source": [ 226 | "temp_dir.cleanup()" 227 | ] 228 | } 229 | ], 230 | "metadata": { 231 | "kernelspec": { 232 | "display_name": "Python 3", 233 | "language": "python", 234 | "name": "python3" 235 | }, 236 | "language_info": { 237 | "codemirror_mode": { 238 | "name": "ipython", 239 | "version": 3 240 | }, 241 | "file_extension": ".py", 242 | "mimetype": "text/x-python", 243 | "name": "python", 244 | "nbconvert_exporter": "python", 245 | "pygments_lexer": "ipython3", 246 | "version": "3.10.0" 247 | } 248 | }, 249 | "nbformat": 4, 250 | "nbformat_minor": 2 251 | } 252 | -------------------------------------------------------------------------------- /docs/source/quickstart.rst: -------------------------------------------------------------------------------- 1 | Quickstart 2 | ========== 3 | 4 | There are two ways to use the library: as a :ref:`command-line tool ` or as a :ref:`Python library `. 5 | The CLI provides the most convenient way to get started, while the Python library offers more flexibility for advanced workflows. 6 | 7 | .. image:: https://github.com/user-attachments/assets/ab38546b-6f5f-4c7c-9274-f7d2e9e1ae73 8 | :width: 100% 9 | :class: only-light 10 | 11 | .. image:: https://github.com/user-attachments/assets/c34f64f7-958a-47cc-88ab-d2689e82deaf 12 | :width: 100% 13 | :class: only-dark 14 | 15 | Use the :ref:`command-line tool ` to evaluate different :term:`MLIP` models on the ``DODH_adsorption_dft.xyz`` file and 16 | visualize the trajectory together with the maximum force error in :ref:`ZnDraw `. 17 | 18 | .. code:: console 19 | 20 | (.venv) $ mlipx recipes metrics --models mace-mpa-0,sevennet,orb-v2 --datapath ../data DODH_adsorption_dft.xyz --repro 21 | (.venv) $ mlipx compare --glob "*CompareCalculatorResults" 22 | 23 | .. toctree:: 24 | :maxdepth: 2 25 | :hidden: 26 | 27 | quickstart/cli 28 | quickstart/python 29 | -------------------------------------------------------------------------------- /docs/source/quickstart/cli.rst: -------------------------------------------------------------------------------- 1 | .. _cli-quickstart: 2 | 3 | Command Line Interface 4 | ====================== 5 | 6 | This guide will help you get started with ``mlipx`` by creating a new project in an empty directory and computing metrics for a machine-learned interatomic potential (:term:`MLIP`) against reference DFT data. 7 | 8 | First, create a new project directory and initialize it with Git and DVC: 9 | 10 | .. code-block:: console 11 | 12 | (.venv) $ mkdir my_project 13 | (.venv) $ cd my_project 14 | (.venv) $ git init 15 | (.venv) $ dvc init 16 | 17 | Adding Reference Data 18 | ---------------------- 19 | Next, add a reference DFT dataset to the project. For this example, we use a slice from the mptraj dataset :footcite:`dengCHGNetPretrainedUniversal2023`. 20 | 21 | .. note:: 22 | 23 | If you have your own data, replace this file with any dataset that can be read by ``ase.io.read`` and includes reference energies and forces. Run the following command instead: 24 | 25 | .. code-block:: bash 26 | 27 | (.venv) $ cp /path/to/your/data.xyz data.xyz 28 | (.venv) $ dvc add data.xyz 29 | 30 | .. 
code-block:: console 31 | 32 | (.venv) $ dvc import-url https://github.com/zincware/ips-mace/releases/download/v0.1.0/mptraj_slice.xyz data.xyz 33 | 34 | Adding the Recipe 35 | ----------------- 36 | With the reference data in place, add a ``mlipx`` recipe to compute metrics: 37 | 38 | .. code-block:: console 39 | 40 | (.venv) $ mlipx recipes metrics --datapath data.xyz 41 | 42 | This command generates a ``main.py`` file in the current directory, which defines the workflow for the recipe. 43 | 44 | Defining Models 45 | --------------- 46 | Define the models to evaluate. This example uses the MACE-MP-0 model :footcite:`batatiaFoundationModelAtomistic2023` which is provided by the ``mace-torch`` package.. 47 | 48 | Create a file named ``models.py`` in the current directory with the following content: 49 | 50 | 51 | .. note:: 52 | 53 | If you already have computed energies and forces you can use two different data files or one file and update the keys. 54 | For more information, see the section on :ref:`update-frames-calc`. 55 | 56 | .. code-block:: python 57 | 58 | import mlipx 59 | 60 | mace_mp = mlipx.GenericASECalculator( 61 | module="mace.calculators", 62 | class_name="mace_mp", 63 | device="auto", 64 | kwargs={ 65 | "model": "medium", 66 | }, 67 | ) 68 | 69 | MODELS = {"mace-mp": mace_mp} 70 | 71 | .. note:: 72 | 73 | The ``GenericASECalculator`` class passes any provided ``kwargs`` to the specified ``class_name``. 74 | A special case is the ``device`` argument. 75 | When set to ``auto``, the class uses ``torch.cuda.is_available()`` to check if a GPU is available and automatically selects it if possible. 76 | If you are not using ``torch`` or wish to specify a device explicitly, you can omit the ``device`` argument or define it directly in the ``kwargs``. 77 | 78 | 79 | Running the Workflow 80 | --------------------- 81 | Now, run the workflow using the following commands: 82 | 83 | .. code-block:: console 84 | 85 | (.venv) $ python main.py 86 | (.venv) $ dvc repro 87 | 88 | Listing Steps and Visualizing Results 89 | ------------------------------------- 90 | To explore the available steps and visualize results, use the commands below: 91 | 92 | .. code-block:: console 93 | 94 | (.venv) $ zntrack list 95 | (.venv) $ mlipx compare mace-mp_CompareCalculatorResults 96 | 97 | .. note:: 98 | 99 | To use ``mlipx compare``, you must have an active ZnDraw server running. Provide the server URL via the ``--zndraw-url`` argument or the ``ZNDRAW_URL`` environment variable. 100 | 101 | You can start a server locally with the command ``zndraw`` in a separate terminal or use the public server at https://zndraw.icp.uni-stuttgart.de. 102 | 103 | 104 | More CLI Options 105 | ---------------- 106 | 107 | The ``mlipx`` CLI can create the :term:`models.py` for some models. 108 | To evaluate ``data.xyz`` with multiple models, you can also run 109 | 110 | .. code-block:: console 111 | 112 | (.venv) $ mlipx recipes metrics --datapath data.xyz --models mace-mpa-0,sevennet,orb-v2,chgnet --repro 113 | 114 | .. note:: 115 | 116 | Want to see your model here? Open an issue or submit a pull request to the `mlipx repository `_. 117 | 118 | 119 | .. footbibliography:: 120 | -------------------------------------------------------------------------------- /docs/source/quickstart/python.rst: -------------------------------------------------------------------------------- 1 | .. 
_python-quickstart: 2 | 3 | Python Interface 4 | ================ 5 | 6 | In the :ref:`cli-quickstart` guide, we demonstrated how to compute metrics for an MLIP against reference DFT data using the CLI. 7 | This guide shows how to achieve the same result using the Python interface. 8 | 9 | Getting Started 10 | --------------- 11 | 12 | First, create a new project directory and initialize it with Git and DVC, as shown below: 13 | 14 | .. code-block:: console 15 | 16 | (.venv) $ mkdir my_project 17 | (.venv) $ cd my_project 18 | (.venv) $ git init 19 | (.venv) $ dvc init 20 | 21 | Adding Reference Data 22 | ---------------------- 23 | 24 | Create a new Python file named ``main.py`` in the project directory, and add the following code to download the reference dataset: 25 | 26 | .. code-block:: python 27 | 28 | import mlipx 29 | import zntrack 30 | 31 | mptraj = zntrack.add( 32 | url="https://github.com/zincware/ips-mace/releases/download/v0.1.0/mptraj_slice.xyz", 33 | path="data.xyz", 34 | ) 35 | 36 | This will download the reference data file ``mptraj_slice.xyz`` into your project directory. 37 | 38 | Defining Models 39 | --------------- 40 | 41 | Define the MLIP models to evaluate by adding the following code to the ``main.py`` file: 42 | 43 | .. code-block:: python 44 | 45 | mace_mp = mlipx.GenericASECalculator( 46 | module="mace.calculators", 47 | class_name="mace_mp", 48 | device="auto", 49 | kwargs={ 50 | "model": "medium", 51 | }, 52 | ) 53 | 54 | Adding the Recipe 55 | ----------------- 56 | 57 | Next, set up the recipe to compute metrics for the MLIP. Add the following code to the ``main.py`` file: 58 | 59 | .. code-block:: python 60 | 61 | project = mlipx.Project() 62 | 63 | with project.group("reference"): 64 | data = mlipx.LoadDataFile(path=mptraj) 65 | ref_evaluation = mlipx.EvaluateCalculatorResults(data=data.frames) 66 | 67 | with project.group("mace-mp"): 68 | updated_data = mlipx.ApplyCalculator(data=data.frames, model=mace_mp) 69 | evaluation = mlipx.EvaluateCalculatorResults(data=updated_data.frames) 70 | mlipx.CompareCalculatorResults(data=evaluation, reference=ref_evaluation) 71 | 72 | project.repro() 73 | 74 | Running the Workflow 75 | --------------------- 76 | 77 | Finally, run the workflow by executing the ``main.py`` file: 78 | 79 | .. code-block:: console 80 | 81 | (.venv) $ python main.py 82 | 83 | .. note:: 84 | 85 | If you want to execute the workflow using ``dvc repro``, replace ``project.repro()`` with ``project.build()`` in the ``main.py`` file. 86 | 87 | This will compute the metrics for the MLIP against the reference DFT data. 88 | 89 | Listing Steps and Visualizing Results 90 | ------------------------------------- 91 | 92 | As with the CLI approach, you can list the available steps and visualize results using the following commands: 93 | 94 | .. code-block:: console 95 | 96 | (.venv) $ zntrack list 97 | (.venv) $ mlipx compare mace-mp_CompareCalculatorResults 98 | 99 | Alternatively, you can load the results for this and any other Node directly into a Python kernel using the following code: 100 | 101 | .. 
code-block:: python 102 | 103 | import zntrack 104 | 105 | node = zntrack.from_rev("mace-mp_CompareCalculatorResults") 106 | print(node.figures) 107 | >>> {"fmax_error": plotly.graph_objects.Figure(), ...} 108 | -------------------------------------------------------------------------------- /docs/source/recipes.rst: -------------------------------------------------------------------------------- 1 | Recipes 2 | ======= 3 | 4 | The following recipes are currently available within :code:`mlipx`. 5 | 6 | .. toctree:: 7 | :maxdepth: 1 8 | 9 | recipes/energy_and_forces 10 | recipes/homonuclear_diatomics 11 | recipes/energy_volume 12 | recipes/invariances 13 | recipes/relax 14 | recipes/md 15 | recipes/neb 16 | recipes/adsorption 17 | recipes/phase_diagram 18 | recipes/pourbaix_diagram 19 | recipes/vibrational_analysis 20 | -------------------------------------------------------------------------------- /docs/source/recipes/adsorption.rst: -------------------------------------------------------------------------------- 1 | .. _adsorption: 2 | 3 | Adsorption Energies 4 | =================== 5 | 6 | This recipe calculates the adsorption energies of a molecule on a surface. 7 | The following example creates a slab of ``Cu(111)`` and calculates the adsorption energy of ethanol (SMILES: ``CCO``) on the surface. 8 | 9 | .. mdinclude:: ../../../mlipx-hub/adsorption/cu_fcc111/README.md 10 | 11 | 12 | .. jupyter-execute:: 13 | :hide-code: 14 | 15 | from mlipx.doc_utils import get_plots 16 | 17 | plots = get_plots("*RelaxAdsorptionConfigs", "../../mlipx-hub/adsorption/cu_fcc111/") 18 | plots["adsorption_energies"].show() 19 | 20 | This test uses the following Nodes together with your provided model in the :term:`models.py` file: 21 | 22 | * :term:`RelaxAdsorptionConfigs` 23 | * :term:`BuildASEslab` 24 | * :term:`Smiles2Conformers` 25 | 26 | 27 | .. dropdown:: Content of :code:`main.py` 28 | 29 | .. literalinclude:: ../../../mlipx-hub/adsorption/cu_fcc111/main.py 30 | :language: Python 31 | 32 | 33 | .. dropdown:: Content of :code:`models.py` 34 | 35 | .. literalinclude:: ../../../mlipx-hub/adsorption/cu_fcc111/models.py 36 | :language: Python 37 | -------------------------------------------------------------------------------- /docs/source/recipes/energy_and_forces.rst: -------------------------------------------------------------------------------- 1 | Energy and Force Evaluation 2 | =========================== 3 | 4 | This recipe is used to test the performance of different models in predicting the energy and forces for a given dataset. 5 | 6 | .. mdinclude:: ../../../mlipx-hub/metrics/DODH_adsorption/README.md 7 | 8 | .. mermaid:: 9 | :align: center 10 | 11 | graph TD 12 | 13 | data['Reference Data incl.
DFT E/F'] 14 | data --> CalculateFormationEnergy1 15 | data --> CalculateFormationEnergy2 16 | data --> CalculateFormationEnergy3 17 | data --> CalculateFormationEnergy4 18 | 19 | subgraph Reference 20 | CalculateFormationEnergy1 --> EvaluateCalculatorResults1 21 | end 22 | 23 | subgraph mg1["Model 1"] 24 | CalculateFormationEnergy2 --> EvaluateCalculatorResults2 25 | EvaluateCalculatorResults2 --> CompareCalculatorResults2 26 | EvaluateCalculatorResults1 --> CompareCalculatorResults2 27 | end 28 | subgraph mg2["Model 2"] 29 | CalculateFormationEnergy3 --> EvaluateCalculatorResults3 30 | EvaluateCalculatorResults3 --> CompareCalculatorResults3 31 | EvaluateCalculatorResults1 --> CompareCalculatorResults3 32 | end 33 | subgraph mgn["Model N"] 34 | CalculateFormationEnergy4 --> EvaluateCalculatorResults4 35 | EvaluateCalculatorResults4 --> CompareCalculatorResults4 36 | EvaluateCalculatorResults1 --> CompareCalculatorResults4 37 | end 38 | 39 | 40 | .. code-block:: console 41 | 42 | (.venv) $ mlipx compare --glob "*CompareCalculatorResults" 43 | 44 | .. jupyter-execute:: 45 | :hide-code: 46 | 47 | from mlipx.doc_utils import get_plots 48 | 49 | plots = get_plots("*CompareCalculatorResults", "../../mlipx-hub/metrics/DODH_adsorption/") 50 | # raise ValueError(plots.keys()) 51 | plots["fmax_error"].show() 52 | plots["adjusted_energy_error_per_atom"].show() 53 | 54 | 55 | This recipe uses the following Nodes together with your provided model in the :term:`models.py` file: 56 | 57 | * :term:`ApplyCalculator` 58 | * :term:`EvaluateCalculatorResults` 59 | * :term:`CalculateFormationEnergy` 60 | * :term:`CompareCalculatorResults` 61 | * :term:`CompareFormationEnergy` 62 | 63 | 64 | .. dropdown:: Content of :code:`main.py` 65 | 66 | .. literalinclude:: ../../../mlipx-hub/metrics/DODH_adsorption/main.py 67 | :language: Python 68 | 69 | 70 | .. dropdown:: Content of :code:`models.py` 71 | 72 | .. literalinclude:: ../../../mlipx-hub/metrics/DODH_adsorption/models.py 73 | :language: Python 74 | -------------------------------------------------------------------------------- /docs/source/recipes/energy_volume.rst: -------------------------------------------------------------------------------- 1 | .. _ev: 2 | 3 | Energy Volume Curves 4 | ==================== 5 | Compute the energy-volume curve for a given material using multiple models. 6 | 7 | .. mdinclude:: ../../../mlipx-hub/energy-volume/mp-1143/README.md 8 | 9 | 10 | .. jupyter-execute:: 11 | :hide-code: 12 | 13 | from mlipx.doc_utils import get_plots 14 | 15 | plots = get_plots("*EnergyVolumeCurve", "../../mlipx-hub/energy-volume/mp-1143/") 16 | plots["adjusted_energy-volume-curve"].show() 17 | 18 | 19 | This recipe uses the following Nodes together with your provided model in the :term:`models.py` file: 20 | 21 | * :term:`EnergyVolumeCurve` 22 | 23 | .. dropdown:: Content of :code:`main.py` 24 | 25 | .. literalinclude:: ../../../mlipx-hub/energy-volume/mp-1143/main.py 26 | :language: Python 27 | 28 | 29 | .. dropdown:: Content of :code:`models.py` 30 | 31 | .. literalinclude:: ../../../mlipx-hub/energy-volume/mp-1143/models.py 32 | :language: Python 33 | -------------------------------------------------------------------------------- /docs/source/recipes/homonuclear_diatomics.rst: -------------------------------------------------------------------------------- 1 | .. _homonuclear_diatomics: 2 | 3 | Homonuclear Diatomics 4 | =========================== 5 | Homonuclear diatomics give a per-element information on the performance of the :term:`MLIP`. 
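A minimal sketch of how this node is typically wired into a :term:`main.py` — the element list is a placeholder and ``MODELS`` comes from your own :term:`models.py` (the actual recipe files are shown in the dropdowns below):

.. code-block:: python

    import mlipx
    from models import MODELS  # your model definitions

    project = mlipx.Project()

    for model_name, model in MODELS.items():
        with project.group(model_name, "diatomics"):
            mlipx.HomonuclearDiatomics(
                elements=["Li", "Cl"],  # placeholder: choose the elements to test
                model=model,
                n_points=100,
                min_distance=0.5,
                max_distance=2.0,
            )

    project.build()  # then run `dvc repro`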
6 | 7 | 8 | .. mdinclude:: ../../../mlipx-hub/diatomics/LiCl/README.md 9 | 10 | You can edit the elements in the :term:`main.py` file to include the elements you want to test. 11 | In the following we show the results for the :code:`Li-Li` bond for the three selected models. 12 | 13 | .. code-block:: console 14 | 15 | (.venv) $ mlipx compare --glob "*HomonuclearDiatomics" 16 | 17 | 18 | .. jupyter-execute:: 19 | :hide-code: 20 | 21 | from mlipx.doc_utils import get_plots 22 | 23 | plots = get_plots("*HomonuclearDiatomics", "../../mlipx-hub/diatomics/LiCl/") 24 | plots["Li-Li bond (adjusted)"].show() 25 | 26 | 27 | This test uses the following Nodes together with your provided model in the :term:`models.py` file: 28 | 29 | * :term:`HomonuclearDiatomics` 30 | 31 | .. dropdown:: Content of :code:`main.py` 32 | 33 | .. literalinclude:: ../../../mlipx-hub/diatomics/LiCl/main.py 34 | :language: Python 35 | 36 | 37 | .. dropdown:: Content of :code:`models.py` 38 | 39 | .. literalinclude:: ../../../mlipx-hub/diatomics/LiCl/models.py 40 | :language: Python 41 | -------------------------------------------------------------------------------- /docs/source/recipes/invariances.rst: -------------------------------------------------------------------------------- 1 | Invariances 2 | =========== 3 | Check the rotational, translational and permutational invariance of an :term:`mlip`. 4 | 5 | 6 | .. mdinclude:: ../../../mlipx-hub/invariances/mp-1143/README.md 7 | 8 | 9 | .. jupyter-execute:: 10 | :hide-code: 11 | 12 | from mlipx.doc_utils import get_plots 13 | 14 | plots = get_plots("*TranslationalInvariance", "../../mlipx-hub/invariances/mp-1143/") 15 | plots["energy_vs_steps_adjusted"].show() 16 | 17 | plots = get_plots("*RotationalInvariance", ".") 18 | plots["energy_vs_steps_adjusted"].show() 19 | 20 | plots = get_plots("*PermutationInvariance", ".") 21 | plots["energy_vs_steps_adjusted"].show() 22 | 23 | 24 | This recipe uses: 25 | 26 | * :term:`RotationalInvariance` 27 | * :term:`StructureOptimization` 28 | * :term:`PermutationInvariance` 29 | 30 | .. dropdown:: Content of :code:`main.py` 31 | 32 | .. literalinclude:: ../../../mlipx-hub/invariances/mp-1143/main.py 33 | :language: Python 34 | 35 | 36 | .. dropdown:: Content of :code:`models.py` 37 | 38 | .. literalinclude:: ../../../mlipx-hub/invariances/mp-1143/models.py 39 | :language: Python 40 | -------------------------------------------------------------------------------- /docs/source/recipes/md.rst: -------------------------------------------------------------------------------- 1 | .. _md: 2 | 3 | Molecular Dynamics 4 | ================== 5 | This recipe is used to test the performance of different models in molecular dynamics simulations. 6 | 7 | .. mdinclude:: ../../../mlipx-hub/md/mp-1143/README.md 8 | 9 | 10 | 11 | .. jupyter-execute:: 12 | :hide-code: 13 | 14 | from mlipx.doc_utils import get_plots 15 | 16 | plots = get_plots("*MolecularDynamics", "../../mlipx-hub/md/mp-1143/") 17 | plots["energy_vs_steps_adjusted"].show() 18 | 19 | This test uses the following Nodes together with your provided model in the :term:`models.py` file: 20 | 21 | * :term:`LangevinConfig` 22 | * :term:`MaximumForceObserver` 23 | * :term:`TemperatureRampModifier` 24 | * :term:`MolecularDynamics` 25 | 26 | .. dropdown:: Content of :code:`main.py` 27 | 28 | .. literalinclude:: ../../../mlipx-hub/md/mp-1143/main.py 29 | :language: Python 30 | 31 | 32 | .. dropdown:: Content of :code:`models.py` 33 | 34 | .. 
literalinclude:: ../../../mlipx-hub/md/mp-1143/models.py 35 | :language: Python 36 | -------------------------------------------------------------------------------- /docs/source/recipes/neb.rst: -------------------------------------------------------------------------------- 1 | .. _neb: 2 | 3 | Nudged Elastic Band 4 | =================== 5 | 6 | :code:`mlipx` provides a command line interface to interpolate and create a NEB path from inital-final or initial-ts-final images and run NEB on the interpolated images. 7 | You can run the following command to instantiate a test directory: 8 | 9 | .. mdinclude:: ../../../mlipx-hub/neb/ex01/README.md 10 | 11 | 12 | .. jupyter-execute:: 13 | :hide-code: 14 | 15 | from mlipx.doc_utils import get_plots 16 | 17 | plots = get_plots("*NEBs", "../../mlipx-hub/neb/ex01/") 18 | plots["adjusted_energy_vs_neb_image"].show() 19 | 20 | This test uses the following Nodes together with your provided model in the :term:`models.py` file: 21 | 22 | * :term:`NEBinterpolate` 23 | * :term:`NEBs` 24 | 25 | 26 | .. dropdown:: Content of :code:`main.py` 27 | 28 | .. literalinclude:: ../../../mlipx-hub/neb/ex01/main.py 29 | :language: Python 30 | 31 | 32 | .. dropdown:: Content of :code:`models.py` 33 | 34 | .. literalinclude:: ../../../mlipx-hub/neb/ex01/models.py 35 | :language: Python 36 | -------------------------------------------------------------------------------- /docs/source/recipes/phase_diagram.rst: -------------------------------------------------------------------------------- 1 | Phase Diagram 2 | ============= 3 | 4 | :code:`mlipx` provides a command line interface to generate Phase Diagrams. 5 | You can run the following command to instantiate a test directory: 6 | 7 | .. mdinclude:: ../../../mlipx-hub/phase_diagram/mp-30084/README.md 8 | 9 | 10 | .. jupyter-execute:: 11 | :hide-code: 12 | 13 | from mlipx.doc_utils import get_plots 14 | 15 | plots = get_plots("*PhaseDiagram", "../../mlipx-hub/phase_diagram/mp-30084/") 16 | for name, plot in plots.items(): 17 | if name.endswith("phase-diagram"): 18 | plot.show() 19 | 20 | This test uses the following Nodes together with your provided model in the :term:`models.py` file: 21 | 22 | * :term:`PhaseDiagram` 23 | 24 | .. dropdown:: Content of :code:`main.py` 25 | 26 | .. literalinclude:: ../../../mlipx-hub/phase_diagram/mp-30084/main.py 27 | :language: Python 28 | 29 | 30 | .. dropdown:: Content of :code:`models.py` 31 | 32 | .. literalinclude:: ../../../mlipx-hub/phase_diagram/mp-30084/models.py 33 | :language: Python 34 | -------------------------------------------------------------------------------- /docs/source/recipes/pourbaix_diagram.rst: -------------------------------------------------------------------------------- 1 | Pourbaix Diagram 2 | ================ 3 | 4 | :code:`mlipx` provides a command line interface to generate Pourbaix diagrams. 5 | 6 | 7 | .. mdinclude:: ../../../mlipx-hub/pourbaix_diagram/mp-1143/README.md 8 | 9 | 10 | .. jupyter-execute:: 11 | :hide-code: 12 | 13 | from mlipx.doc_utils import get_plots 14 | 15 | plots = get_plots("*PourbaixDiagram", "../../mlipx-hub/pourbaix_diagram/mp-1143/") 16 | for name, plot in plots.items(): 17 | if name.endswith("pourbaix-diagram"): 18 | plot.show() 19 | 20 | This test uses the following Nodes together with your provided model in the :term:`models.py` file: 21 | 22 | * :term:`PourbaixDiagram` 23 | 24 | .. dropdown:: Content of :code:`main.py` 25 | 26 | .. 
literalinclude:: ../../../mlipx-hub/pourbaix_diagram/mp-1143/main.py 27 | :language: Python 28 | 29 | 30 | .. dropdown:: Content of :code:`models.py` 31 | 32 | .. literalinclude:: ../../../mlipx-hub/pourbaix_diagram/mp-1143/models.py 33 | :language: Python 34 | -------------------------------------------------------------------------------- /docs/source/recipes/relax.rst: -------------------------------------------------------------------------------- 1 | .. _relax: 2 | 3 | Structure Relaxation 4 | ==================== 5 | 6 | This recipe is used to test the performance of different models in performing structure relaxation. 7 | 8 | 9 | .. mdinclude:: ../../../mlipx-hub/relax/mp-1143/README.md 10 | 11 | .. note:: 12 | 13 | If you relax a non-periodic system and your model yields a stress tensor of :code:`[inf, inf, inf, inf, inf, inf]` you have to add the :code:`--convert-nan` flag to the :code:`mlipx compare` or :code:`zndraw` command to convert them to :code:`None`. 14 | 15 | .. jupyter-execute:: 16 | :hide-code: 17 | 18 | from mlipx.doc_utils import get_plots 19 | 20 | plots = get_plots("*StructureOptimization", "../../mlipx-hub/relax/mp-1143/") 21 | plots["adjusted_energy_vs_steps"].show() 22 | 23 | This recipe uses the following Nodes together with your provided model in the :term:`models.py` file: 24 | 25 | * :term:`StructureOptimization` 26 | 27 | .. dropdown:: Content of :code:`main.py` 28 | 29 | .. literalinclude:: ../../../mlipx-hub/relax/mp-1143/main.py 30 | :language: Python 31 | 32 | 33 | .. dropdown:: Content of :code:`models.py` 34 | 35 | .. literalinclude:: ../../../mlipx-hub/relax/mp-1143/models.py 36 | :language: Python 37 | -------------------------------------------------------------------------------- /docs/source/recipes/vibrational_analysis.rst: -------------------------------------------------------------------------------- 1 | Vibrational Analysis 2 | ==================== 3 | 4 | :code:`mlipx` provides a command line interface to vibrational analysis. 5 | You can run the following command to instantiate a test directory: 6 | 7 | .. mdinclude:: ../../../mlipx-hub/vibrational_analysis/CxO/README.md 8 | 9 | 10 | The vibrational analysis method needs additional information to run. 11 | Please edit the ``main.py`` file and set the ``system`` parameter on the ``VibrationalAnalysis`` node. 12 | For the given list of SMILES, you should set it to ``"molecule"``. 13 | Then run the following commands to reproduce and inspect the results: 14 | 15 | .. code-block:: console 16 | 17 | (.venv) $ python main.py 18 | (.venv) $ dvc repro 19 | (.venv) $ mlipx compare --glob "*VibrationalAnalysis" 20 | 21 | 22 | .. jupyter-execute:: 23 | :hide-code: 24 | 25 | from mlipx.doc_utils import get_plots 26 | 27 | plots = get_plots("*VibrationalAnalysis", "../../mlipx-hub/vibrational_analysis/CxO/") 28 | plots["Gibbs-Comparison"].show() 29 | 30 | This test uses the following Nodes together with your provided model in the :term:`models.py` file: 31 | 32 | * :term:`VibrationalAnalysis` 33 | 34 | .. dropdown:: Content of :code:`main.py` 35 | 36 | .. literalinclude:: ../../../mlipx-hub/vibrational_analysis/CxO/main.py 37 | :language: Python 38 | 39 | 40 | .. dropdown:: Content of :code:`models.py` 41 | 42 | .. 
literalinclude:: ../../../mlipx-hub/vibrational_analysis/CxO/models.py 43 | :language: Python 44 | -------------------------------------------------------------------------------- /docs/source/references.bib: -------------------------------------------------------------------------------- 1 | @misc{zillsZnTrackDataCode2024, 2 | title = {{{ZnTrack}} -- {{Data}} as {{Code}}}, 3 | author = {Zills, Fabian and Sch{\"a}fer, Moritz and Tovey, Samuel and K{\"a}stner, Johannes and Holm, Christian}, 4 | year = {2024}, 5 | eprint={2401.10603}, 6 | archivePrefix={arXiv}, 7 | } 8 | @article{zillsCollaborationMachineLearnedPotentials2024, 9 | title = {Collaboration on {{Machine-Learned Potentials}} with {{IPSuite}}: {{A Modular Framework}} for {{Learning-on-the-Fly}}}, 10 | shorttitle = {Collaboration on {{Machine-Learned Potentials}} with {{IPSuite}}}, 11 | author = {Zills, Fabian and Schäfer, Moritz René and Segreto, Nico and Kästner, Johannes and Holm, Christian and Tovey, Samuel}, 12 | year = 2024, 13 | journal = {J. Phys. Chem. B}, 14 | publisher = {American Chemical Society}, 15 | issn = {1520-6106}, 16 | doi = {10.1021/acs.jpcb.3c07187}, 17 | } 18 | @misc{elijosiusZeroShotMolecular2024, 19 | title = {Zero {{Shot Molecular Generation}} via {{Similarity Kernels}}}, 20 | author = {Elijo{\v s}ius, Rokas and Zills, Fabian and Batatia, Ilyes and Norwood, Sam Walton and Kov{\'a}cs, D{\'a}vid P{\'e}ter and Holm, Christian and Cs{\'a}nyi, G{\'a}bor}, 21 | year = {2024}, 22 | eprint = {2402.08708}, 23 | archiveprefix = {arxiv}, 24 | } 25 | @article{dengCHGNetPretrainedUniversal2023, 26 | title = {{{CHGNet}} as a Pretrained Universal Neural Network Potential for Charge-Informed Atomistic Modelling}, 27 | author = {Deng, Bowen and Zhong, Peichen and Jun, KyuJung and Riebesell, Janosh and Han, Kevin and Bartel, Christopher J. and Ceder, Gerbrand}, 28 | journal = {Nat Mach Intell}, 29 | year = {2023}, 30 | } 31 | 32 | @online{batatiaFoundationModelAtomistic2023, 33 | title = {A Foundation Model for Atomistic Materials Chemistry}, 34 | author = {Batatia et al., Ilyes}, 35 | year = {2023}, 36 | eprint = {2401.00096}, 37 | archiveprefix = {arxiv}, 38 | } 39 | -------------------------------------------------------------------------------- /mlipx/__init__.py: -------------------------------------------------------------------------------- 1 | import lazy_loader as lazy 2 | 3 | __getattr__, __dir__, __all__ = lazy.attach_stub(__name__, __file__) 4 | -------------------------------------------------------------------------------- /mlipx/__init__.pyi: -------------------------------------------------------------------------------- 1 | from . 
import abc 2 | from .nodes.adsorption import BuildASEslab, RelaxAdsorptionConfigs 3 | from .nodes.apply_calculator import ApplyCalculator 4 | from .nodes.compare_calculator import CompareCalculatorResults 5 | from .nodes.diatomics import HomonuclearDiatomics 6 | from .nodes.energy_volume import EnergyVolumeCurve 7 | from .nodes.evaluate_calculator import EvaluateCalculatorResults 8 | from .nodes.filter_dataset import FilterAtoms 9 | from .nodes.formation_energy import CalculateFormationEnergy, CompareFormationEnergy 10 | from .nodes.generic_ase import GenericASECalculator 11 | from .nodes.invariances import ( 12 | PermutationInvariance, 13 | RotationalInvariance, 14 | TranslationalInvariance, 15 | ) 16 | from .nodes.io import LoadDataFile 17 | from .nodes.modifier import TemperatureRampModifier 18 | from .nodes.molecular_dynamics import LangevinConfig, MolecularDynamics 19 | from .nodes.mp_api import MPRester 20 | from .nodes.nebs import NEBinterpolate, NEBs 21 | from .nodes.observer import MaximumForceObserver 22 | from .nodes.orca import OrcaSinglePoint 23 | from .nodes.phase_diagram import PhaseDiagram 24 | from .nodes.pourbaix_diagram import PourbaixDiagram 25 | from .nodes.rattle import Rattle 26 | from .nodes.smiles import BuildBox, Smiles2Conformers 27 | from .nodes.structure_optimization import StructureOptimization 28 | from .nodes.updated_frames import UpdateFramesCalc 29 | from .nodes.vibrational_analysis import VibrationalAnalysis 30 | from .project import Project 31 | from .version import __version__ 32 | 33 | __all__ = [ 34 | "abc", 35 | "StructureOptimization", 36 | "LoadDataFile", 37 | "MaximumForceObserver", 38 | "TemperatureRampModifier", 39 | "MolecularDynamics", 40 | "LangevinConfig", 41 | "ApplyCalculator", 42 | "CalculateFormationEnergy", 43 | "EvaluateCalculatorResults", 44 | "CompareCalculatorResults", 45 | "NEBs", 46 | "NEBinterpolate", 47 | "Smiles2Conformers", 48 | "PhaseDiagram", 49 | "PourbaixDiagram", 50 | "VibrationalAnalysis", 51 | "HomonuclearDiatomics", 52 | "MPRester", 53 | "GenericASECalculator", 54 | "FilterAtoms", 55 | "EnergyVolumeCurve", 56 | "BuildBox", 57 | "CompareFormationEnergy", 58 | "UpdateFramesCalc", 59 | "RotationalInvariance", 60 | "TranslationalInvariance", 61 | "PermutationInvariance", 62 | "Rattle", 63 | "Project", 64 | "BuildASEslab", 65 | "RelaxAdsorptionConfigs", 66 | "OrcaSinglePoint", 67 | "__version__", 68 | ] 69 | -------------------------------------------------------------------------------- /mlipx/abc.py: -------------------------------------------------------------------------------- 1 | """Abstract base classes and type hints.""" 2 | 3 | import abc 4 | import dataclasses 5 | import pathlib 6 | import typing as t 7 | from enum import Enum 8 | 9 | import ase 10 | import h5py 11 | import plotly.graph_objects as go 12 | import znh5md 13 | import zntrack 14 | from ase.calculators.calculator import Calculator 15 | from ase.md.md import MolecularDynamics 16 | 17 | T = t.TypeVar("T", bound=zntrack.Node) 18 | 19 | 20 | class Optimizer(str, Enum): 21 | FIRE = "FIRE" 22 | BFGS = "BFGS" 23 | LBFGS = "LBFGS" 24 | 25 | 26 | class ASEKeys(str, Enum): 27 | formation_energy = "formation_energy" 28 | isolated_energies = "isolated_energies" 29 | 30 | 31 | class NodeWithCalculator(t.Protocol[T]): 32 | def get_calculator(self, **kwargs) -> Calculator: ... 33 | 34 | 35 | class NodeWithMolecularDynamics(t.Protocol[T]): 36 | def get_molecular_dynamics(self, atoms: ase.Atoms) -> MolecularDynamics: ... 
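# Usage sketch (illustrative only, not executed here): any node exposing a
# ``get_calculator`` method satisfies ``NodeWithCalculator`` and can be passed
# as the ``model`` of recipe nodes, for example:
#
#     emt = mlipx.GenericASECalculator(
#         module="ase.calculators.emt",
#         class_name="EMT",
#     )
#     atoms.calc = emt.get_calculator()
#
# Likewise, ``NodeWithMolecularDynamics`` describes thermostat-style nodes (the
# ``thermostat`` passed to ``MolecularDynamics``) whose
# ``get_molecular_dynamics(atoms)`` returns an ASE ``MolecularDynamics`` driver
# for the given structure.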
37 | 38 | 39 | FIGURES = t.Dict[str, go.Figure] 40 | FRAMES = t.List[ase.Atoms] 41 | 42 | 43 | class ComparisonResults(t.TypedDict): 44 | frames: FRAMES 45 | figures: FIGURES 46 | 47 | 48 | @dataclasses.dataclass 49 | class DynamicsObserver: 50 | @property 51 | def name(self) -> str: 52 | return self.__class__.__name__ 53 | 54 | def initialize(self, atoms: ase.Atoms) -> None: 55 | pass 56 | 57 | @abc.abstractmethod 58 | def check(self, atoms: ase.Atoms) -> bool: ... 59 | 60 | 61 | @dataclasses.dataclass 62 | class DynamicsModifier: 63 | @property 64 | def name(self) -> str: 65 | return self.__class__.__name__ 66 | 67 | @abc.abstractmethod 68 | def modify(self, thermostat, step, total_steps) -> None: ... 69 | 70 | 71 | class ProcessAtoms(zntrack.Node): 72 | data: list[ase.Atoms] = zntrack.deps() 73 | data_id: int = zntrack.params(-1) 74 | 75 | frames_path: pathlib.Path = zntrack.outs_path(zntrack.nwd / "frames.h5") 76 | 77 | @property 78 | def frames(self) -> FRAMES: 79 | with self.state.fs.open(self.frames_path, "r") as f: 80 | with h5py.File(f, "r") as h5f: 81 | return znh5md.IO(file_handle=h5f)[:] 82 | 83 | @property 84 | def figures(self) -> FIGURES: ... 85 | 86 | @staticmethod 87 | def compare(*nodes: "ProcessAtoms") -> ComparisonResults: ... 88 | 89 | 90 | class ProcessFrames(zntrack.Node): 91 | data: list[ase.Atoms] = zntrack.deps() 92 | -------------------------------------------------------------------------------- /mlipx/benchmark/__init__.py: -------------------------------------------------------------------------------- 1 | from mlipx.benchmark.main import app 2 | 3 | __all__ = ["app"] 4 | -------------------------------------------------------------------------------- /mlipx/benchmark/elements.py: -------------------------------------------------------------------------------- 1 | import zntrack 2 | from models import MODELS 3 | 4 | import mlipx 5 | 6 | ELEMENTS = {{elements}} # noqa F821 7 | FILTERING_TYPE = "{{ filtering_type }}" 8 | 9 | mptraj = zntrack.add( 10 | url="https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mp_traj_combined.xyz", 11 | path="mptraj.xyz", 12 | ) 13 | 14 | project = zntrack.Project() 15 | 16 | with project.group("mptraj"): 17 | raw_mptraj_data = mlipx.LoadDataFile(path=mptraj) 18 | mptraj_data = mlipx.FilterAtoms( 19 | data=raw_mptraj_data.frames, elements=ELEMENTS, filtering_type=FILTERING_TYPE 20 | ) 21 | 22 | for model_name, model in MODELS.items(): 23 | with project.group(model_name, "diatomics"): 24 | neb = mlipx.HomonuclearDiatomics( 25 | elements=ELEMENTS, 26 | model=model, 27 | n_points=100, 28 | min_distance=0.5, 29 | max_distance=2.0, 30 | ) 31 | 32 | relaxed = [] 33 | 34 | for model_name, model in MODELS.items(): 35 | with project.group(model_name, "struct_optim"): 36 | relaxed.append( 37 | mlipx.StructureOptimization( 38 | data=mptraj_data.frames, data_id=-1, model=model, fmax=0.1 39 | ) 40 | ) 41 | 42 | 43 | mds = [] 44 | 45 | thermostat = mlipx.LangevinConfig(timestep=0.5, temperature=300, friction=0.05) 46 | force_check = mlipx.MaximumForceObserver(f_max=100) 47 | t_ramp = mlipx.TemperatureRampModifier(end_temperature=400, total_steps=100) 48 | 49 | for (model_name, model), relaxed_structure in zip(MODELS.items(), relaxed): 50 | with project.group(model_name, "md"): 51 | mds.append( 52 | mlipx.MolecularDynamics( 53 | model=model, 54 | thermostat=thermostat, 55 | data=relaxed_structure.frames, 56 | data_id=-1, 57 | observers=[force_check], 58 | modifiers=[t_ramp], 59 | steps=100, 60 | ) 61 | ) 62 | 63 | 64 | for 
(model_name, model), md in zip(MODELS.items(), mds): 65 | with project.group(model_name): 66 | ev = mlipx.EnergyVolumeCurve( 67 | model=model, 68 | data=md.frames, 69 | data_id=-1, 70 | n_points=50, 71 | start=0.75, 72 | stop=2.0, 73 | ) 74 | 75 | for (model_name, model), md in zip(MODELS.items(), mds): 76 | with project.group(model_name, "struct_optim_2"): 77 | relaxed.append( 78 | mlipx.StructureOptimization( 79 | data=md.frames, data_id=-1, model=model, fmax=0.1 80 | ) 81 | ) 82 | 83 | 84 | project.build() 85 | -------------------------------------------------------------------------------- /mlipx/benchmark/file.py: -------------------------------------------------------------------------------- 1 | import ase.io 2 | import zntrack 3 | from models import MODELS 4 | 5 | import mlipx 6 | 7 | DATAPATH = "{{ datapath }}" 8 | 9 | count = 0 10 | ELEMENTS = set() 11 | for atoms in ase.io.iread(DATAPATH): 12 | count += 1 13 | for symbol in atoms.symbols: 14 | ELEMENTS.add(symbol) 15 | ELEMENTS = list(ELEMENTS) 16 | 17 | project = zntrack.Project() 18 | 19 | with project.group("mptraj"): 20 | data = mlipx.LoadDataFile(path=DATAPATH) 21 | 22 | 23 | for model_name, model in MODELS.items(): 24 | with project.group(model_name, "diatomics"): 25 | _ = mlipx.HomonuclearDiatomics( 26 | elements=ELEMENTS, 27 | model=model, 28 | n_points=100, 29 | min_distance=0.5, 30 | max_distance=2.0, 31 | ) 32 | 33 | # Energy-Volume Curve 34 | for model_name, model in MODELS.items(): 35 | for idx in range(count): 36 | with project.group(model_name, "ev", str(idx)): 37 | _ = mlipx.EnergyVolumeCurve( 38 | model=model, 39 | data=data.frames, 40 | data_id=idx, 41 | n_points=50, 42 | start=0.75, 43 | stop=2.0, 44 | ) 45 | 46 | 47 | # Molecular Dynamics 48 | thermostat = mlipx.LangevinConfig(timestep=0.5, temperature=300, friction=0.05) 49 | force_check = mlipx.MaximumForceObserver(f_max=100) 50 | t_ramp = mlipx.TemperatureRampModifier(end_temperature=400, total_steps=100) 51 | 52 | for model_name, model in MODELS.items(): 53 | for idx in range(count): 54 | with project.group(model_name, "md", str(idx)): 55 | _ = mlipx.MolecularDynamics( 56 | model=model, 57 | thermostat=thermostat, 58 | data=data.frames, 59 | data_id=idx, 60 | observers=[force_check], 61 | modifiers=[t_ramp], 62 | steps=100, 63 | ) 64 | 65 | # Structure Optimization 66 | with project.group("rattle"): 67 | rattle = mlipx.Rattle(data=data.frames, stdev=0.01) 68 | 69 | for model_name, model in MODELS.items(): 70 | for idx in range(count): 71 | with project.group(model_name, "struct_optim", str(idx)): 72 | _ = mlipx.StructureOptimization( 73 | data=rattle.frames, data_id=idx, model=model, fmax=0.1 74 | ) 75 | 76 | project.build() 77 | -------------------------------------------------------------------------------- /mlipx/benchmark/main.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import shutil 3 | import subprocess 4 | import typing as t 5 | 6 | import jinja2 7 | import typer 8 | from ase.data import chemical_symbols 9 | 10 | from mlipx.nodes.filter_dataset import FilteringType 11 | 12 | CWD = pathlib.Path(__file__).parent 13 | 14 | 15 | app = typer.Typer() 16 | 17 | 18 | def initialize_directory(): 19 | subprocess.run(["git", "init"], check=True) 20 | subprocess.run(["dvc", "init"], check=True) 21 | shutil.copy(CWD.parent / "recipes" / "models.py", "models.py") 22 | 23 | 24 | @app.command() 25 | def elements( 26 | elements: t.Annotated[list[str], typer.Argument()], 27 | filtering_type: 
FilteringType = FilteringType.INCLUSIVE, 28 | ): 29 | for element in elements: 30 | if element not in chemical_symbols: 31 | raise ValueError(f"{element} is not a chemical element") 32 | template = jinja2.Template((CWD / "elements.py").read_text()) 33 | with open("main.py", "w") as f: 34 | f.write(template.render(elements=elements, filtering_type=filtering_type.value)) 35 | 36 | 37 | @app.command() 38 | def file( 39 | datapath: str, 40 | ): 41 | template = jinja2.Template((CWD / "file.py").read_text()) 42 | with open("main.py", "w") as f: 43 | f.write(template.render(datapath=datapath)) 44 | -------------------------------------------------------------------------------- /mlipx/cli/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basf/mlipx/c109aa6b2d3ff8236d6a9ec0772a8c4efb9c7dea/mlipx/cli/__init__.py -------------------------------------------------------------------------------- /mlipx/cli/main.py: -------------------------------------------------------------------------------- 1 | import fnmatch 2 | import importlib.metadata 3 | import json 4 | import pathlib 5 | import sys 6 | import uuid 7 | import webbrowser 8 | 9 | import dvc.api 10 | import plotly.io as pio 11 | import typer 12 | import zntrack 13 | from rich import box 14 | from rich.console import Console 15 | from rich.table import Table 16 | from tqdm import tqdm 17 | from typing_extensions import Annotated 18 | from zndraw import ZnDraw 19 | 20 | from mlipx import benchmark, recipes 21 | 22 | app = typer.Typer() 23 | app.add_typer(recipes.app, name="recipes") 24 | app.add_typer(benchmark.app, name="benchmark") 25 | 26 | # Load plugins 27 | 28 | entry_points = importlib.metadata.entry_points(group="mlipx.recipes") 29 | for entry_point in entry_points: 30 | entry_point.load() 31 | 32 | 33 | @app.command() 34 | def main(): 35 | typer.echo("Hello World") 36 | 37 | 38 | @app.command() 39 | def info(): 40 | """Print the version of mlipx and the available models.""" 41 | from mlipx.models import AVAILABLE_MODELS # slow import 42 | 43 | console = Console() 44 | # Get Python environment info 45 | python_version = sys.version.split()[0] 46 | python_executable = sys.executable 47 | python_platform = sys.platform 48 | 49 | py_table = Table(title="🐍 Python Environment", box=box.ROUNDED) 50 | py_table.add_column("Version", style="cyan", no_wrap=True) 51 | py_table.add_column("Executable", style="magenta") 52 | py_table.add_column("Platform", style="green") 53 | py_table.add_row(python_version, python_executable, python_platform) 54 | 55 | # Get model availability 56 | mlip_table = Table(title="🧠 MLIP Codes", box=box.ROUNDED) 57 | mlip_table.add_column("Model", style="bold") 58 | mlip_table.add_column("Available", style="bold") 59 | 60 | for model in sorted(AVAILABLE_MODELS): 61 | status = AVAILABLE_MODELS[model] 62 | if status is True: 63 | mlip_table.add_row(model, "[green]:heavy_check_mark: Yes[/green]") 64 | elif status is False: 65 | mlip_table.add_row(model, "[red]:x: No[/red]") 66 | elif status is None: 67 | mlip_table.add_row(model, "[yellow]:warning: Unknown[/yellow]") 68 | else: 69 | mlip_table.add_row(model, "[red]:boom: Error[/red]") 70 | 71 | # Get versions of key packages 72 | mlipx_table = Table(title="📦 mlipx Ecosystem", box=box.ROUNDED) 73 | mlipx_table.add_column("Package", style="bold") 74 | mlipx_table.add_column("Version", style="cyan") 75 | 76 | for package in ["mlipx", "zntrack", "zndraw"]: 77 | try: 78 | version = 
importlib.metadata.version(package) 79 | except importlib.metadata.PackageNotFoundError: 80 | version = "[red]Not installed[/red]" 81 | mlipx_table.add_row(package, version) 82 | 83 | # Display all 84 | console.print(mlipx_table) 85 | console.print(py_table) 86 | console.print(mlip_table) 87 | 88 | 89 | @app.command() 90 | def compare( # noqa C901 91 | nodes: Annotated[list[str], typer.Argument(help="Path to the node to compare")], 92 | zndraw_url: Annotated[ 93 | str, 94 | typer.Option( 95 | envvar="ZNDRAW_URL", 96 | help="URL of the ZnDraw server to visualize the results", 97 | ), 98 | ], 99 | kwarg: Annotated[list[str], typer.Option("--kwarg", "-k")] = None, 100 | token: Annotated[str, typer.Option("--token")] = None, 101 | glob: Annotated[ 102 | bool, typer.Option("--glob", help="Allow glob patterns to select nodes.") 103 | ] = False, 104 | convert_nan: Annotated[bool, typer.Option()] = False, 105 | browser: Annotated[ 106 | bool, 107 | typer.Option( 108 | help="""Whether to open the ZnDraw GUI in the default web browser.""" 109 | ), 110 | ] = True, 111 | figures_path: Annotated[ 112 | str | None, 113 | typer.Option( 114 | help="Provide a path to save the figures to." 115 | "No figures will be saved by default." 116 | ), 117 | ] = None, 118 | ): 119 | """Compare mlipx nodes and visualize the results using ZnDraw.""" 120 | # TODO: allow for glob patterns 121 | if kwarg is None: 122 | kwarg = [] 123 | node_names, revs, remotes = [], [], [] 124 | if glob: 125 | fs = dvc.api.DVCFileSystem() 126 | with fs.open("zntrack.json", mode="r") as f: 127 | all_nodes = list(json.load(f).keys()) 128 | 129 | for node in nodes: 130 | # can be name or name@rev or name@remote@rev 131 | parts = node.split("@") 132 | if glob: 133 | filtered_nodes = [x for x in all_nodes if fnmatch.fnmatch(x, parts[0])] 134 | else: 135 | filtered_nodes = [parts[0]] 136 | for x in filtered_nodes: 137 | node_names.append(x) 138 | if len(parts) == 1: 139 | revs.append(None) 140 | remotes.append(None) 141 | elif len(parts) == 2: 142 | revs.append(parts[1]) 143 | remotes.append(None) 144 | elif len(parts) == 3: 145 | remotes.append(parts[1]) 146 | revs.append(parts[2]) 147 | else: 148 | raise ValueError(f"Invalid node format: {node}") 149 | 150 | node_instances = {} 151 | for node_name, rev, remote in tqdm( 152 | zip(node_names, revs, remotes), desc="Loading nodes" 153 | ): 154 | node_instances[node_name] = zntrack.from_rev(node_name, remote=remote, rev=rev) 155 | 156 | if len(node_instances) == 0: 157 | typer.echo("No nodes to compare") 158 | return 159 | 160 | typer.echo(f"Comparing {len(node_instances)} nodes") 161 | 162 | kwargs = {} 163 | for arg in kwarg: 164 | key, value = arg.split("=", 1) 165 | kwargs[key] = value 166 | result = node_instances[node_names[0]].compare(*node_instances.values(), **kwargs) 167 | 168 | token = token or str(uuid.uuid4()) 169 | typer.echo(f"View the results at {zndraw_url}/token/{token}") 170 | vis = ZnDraw(zndraw_url, token=token, convert_nan=convert_nan) 171 | length = len(vis) 172 | vis.extend(result["frames"]) 173 | del vis[:length] # temporary fix 174 | vis.figures = result["figures"] 175 | if browser: 176 | webbrowser.open(f"{zndraw_url}/token/{token}") 177 | if figures_path: 178 | for desc, fig in result["figures"].items(): 179 | pio.write_json(fig, pathlib.Path(figures_path) / f"{desc}.json") 180 | 181 | vis.socket.sleep(5) 182 | -------------------------------------------------------------------------------- /mlipx/doc_utils.py: 
-------------------------------------------------------------------------------- 1 | import fnmatch 2 | import json 3 | import os 4 | import pathlib 5 | 6 | import plotly.graph_objects as go 7 | import plotly.io as pio 8 | import zntrack 9 | 10 | 11 | def show(file: str) -> None: 12 | pio.renderers.default = "sphinx_gallery" 13 | 14 | figure = pio.read_json(f"source/figures/{file}") 15 | figure.update_layout( 16 | { 17 | "plot_bgcolor": "rgba(0, 0, 0, 0)", 18 | "paper_bgcolor": "rgba(0, 0, 0, 0)", 19 | } 20 | ) 21 | figure.update_xaxes( 22 | showgrid=True, gridwidth=1, gridcolor="rgba(120, 120, 120, 0.3)", zeroline=False 23 | ) 24 | figure.update_yaxes( 25 | showgrid=True, gridwidth=1, gridcolor="rgba(120, 120, 120, 0.3)", zeroline=False 26 | ) 27 | figure.show() 28 | 29 | 30 | def get_plots(name: str, url: str) -> dict[str, go.Figure]: 31 | os.chdir(url) 32 | pio.renderers.default = "sphinx_gallery" 33 | with pathlib.Path("zntrack.json").open(mode="r") as f: 34 | all_nodes = list(json.load(f).keys()) 35 | filtered_nodes = [x for x in all_nodes if fnmatch.fnmatch(x, name)] 36 | 37 | node_instances = {} 38 | for node_name in filtered_nodes: 39 | node_instances[node_name] = zntrack.from_rev(node_name) 40 | 41 | result = node_instances[filtered_nodes[0]].compare(*node_instances.values()) 42 | return result["figures"] 43 | -------------------------------------------------------------------------------- /mlipx/models.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import jinja2 4 | 5 | import mlipx 6 | from mlipx import recipes 7 | 8 | AVAILABLE_MODELS = {} 9 | 10 | RECIPES_PATH = Path(recipes.__file__).parent 11 | template = jinja2.Template((RECIPES_PATH / "models.py.jinja2").read_text()) 12 | 13 | rendered_code = template.render(models=[]) 14 | 15 | # Prepare a namespace and execute the rendered code into it 16 | namespace = {"mlipx": mlipx} # replace with your actual mlipx 17 | exec(rendered_code, namespace) 18 | 19 | # Access ALL_MODELS and MODELS 20 | all_models = namespace["ALL_MODELS"] 21 | 22 | AVAILABLE_MODELS = { 23 | model_name: model.available for model_name, model in all_models.items() 24 | } 25 | -------------------------------------------------------------------------------- /mlipx/nodes/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basf/mlipx/c109aa6b2d3ff8236d6a9ec0772a8c4efb9c7dea/mlipx/nodes/__init__.py -------------------------------------------------------------------------------- /mlipx/nodes/apply_calculator.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | import ase 4 | import h5py 5 | import tqdm 6 | import znh5md 7 | import zntrack 8 | from ase.calculators.calculator import all_properties 9 | 10 | from mlipx.abc import NodeWithCalculator 11 | from mlipx.utils import freeze_copy_atoms 12 | 13 | 14 | class ApplyCalculator(zntrack.Node): 15 | """ 16 | Apply a calculator to a list of atoms objects and store the results in a H5MD file. 17 | 18 | Parameters 19 | ---------- 20 | data : list[ase.Atoms] 21 | List of atoms objects to calculate. 22 | model : NodeWithCalculator, optional 23 | Node providing the calculator object to apply to the data. 
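
    Example
    -------
    Illustrative sketch only: ``project``, ``data`` and ``model`` are assumed to
    come from a surrounding zntrack project (as in the bundled benchmark
    recipes), and ``ApplyCalculator`` is assumed to be importable from ``mlipx``
    like the other nodes.

    >>> with project.group("my-model"):
    ...     evaluated = mlipx.ApplyCalculator(data=data.frames, model=model)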
24 | """ 25 | 26 | data: list[ase.Atoms] = zntrack.deps() 27 | model: NodeWithCalculator | None = zntrack.deps() 28 | 29 | frames_path: pathlib.Path = zntrack.outs_path(zntrack.nwd / "frames.h5") 30 | 31 | def run(self): 32 | frames = [] 33 | if self.model is not None: 34 | calc = self.model.get_calculator() 35 | 36 | # Some calculators, e.g. MACE do not follow the ASE API correctly. 37 | # and we need to fix some keys in `all_properties` 38 | all_properties.append("node_energy") 39 | 40 | for atoms in tqdm.tqdm(self.data): 41 | atoms.calc = calc 42 | atoms.get_potential_energy() 43 | frames.append(freeze_copy_atoms(atoms)) 44 | else: 45 | frames = self.data 46 | 47 | io = znh5md.IO(self.frames_path) 48 | io.extend(frames) 49 | 50 | @property 51 | def frames(self) -> list[ase.Atoms]: 52 | with self.state.fs.open(self.frames_path, "rb") as f: 53 | with h5py.File(f) as file: 54 | return list(znh5md.IO(file_handle=file)) 55 | -------------------------------------------------------------------------------- /mlipx/nodes/autowte.py: -------------------------------------------------------------------------------- 1 | """Automatic heat-conductivity predictions from the Wigner Transport Equation. 2 | 3 | Based on https://github.com/MPA2suite/autoWTE 4 | 5 | Use pip install git+https://github.com/MPA2suite/autoWTE 6 | and conda install -c conda-forge phono3py 7 | """ 8 | 9 | import zntrack 10 | 11 | from mlipx.abc import NodeWithCalculator 12 | 13 | 14 | class AutoWTE(zntrack.Node): 15 | model: NodeWithCalculator = zntrack.deps() 16 | -------------------------------------------------------------------------------- /mlipx/nodes/compare_calculator.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | 3 | import pandas as pd 4 | import tqdm 5 | import zntrack 6 | from ase.calculators.calculator import PropertyNotImplementedError 7 | 8 | from mlipx.abc import FIGURES, FRAMES, ComparisonResults 9 | from mlipx.nodes.evaluate_calculator import EvaluateCalculatorResults, get_figure 10 | from mlipx.utils import rmse, shallow_copy_atoms 11 | 12 | 13 | class CompareCalculatorResults(zntrack.Node): 14 | """ 15 | CompareCalculatorResults is a node that compares the results of two calculators. 16 | It calculates the RMSE between the two calculators and adjusts plots accordingly. 17 | It calculates the error between the two calculators and saves the min/max values. 18 | 19 | Parameters 20 | ---------- 21 | data : EvaluateCalculatorResults 22 | The results of the first calculator. 23 | reference : EvaluateCalculatorResults 24 | The results of the second calculator. 25 | The results of the first calculator will be compared to these results. 
26 | """ 27 | 28 | data: EvaluateCalculatorResults = zntrack.deps() 29 | reference: EvaluateCalculatorResults = zntrack.deps() 30 | 31 | plots: pd.DataFrame = zntrack.plots(autosave=True) 32 | rmse: dict = zntrack.metrics() 33 | error: dict = zntrack.metrics() 34 | 35 | def run(self): 36 | e_rmse = rmse(self.data.plots["energy"], self.reference.plots["energy"]) 37 | self.rmse = { 38 | "energy": e_rmse, 39 | "energy_per_atom": e_rmse / len(self.data.plots), 40 | "fmax": rmse(self.data.plots["fmax"], self.reference.plots["fmax"]), 41 | "fnorm": rmse(self.data.plots["fnorm"], self.reference.plots["fnorm"]), 42 | } 43 | 44 | all_plots = [] 45 | 46 | for row_idx in tqdm.trange(len(self.data.plots)): 47 | plots = {} 48 | plots["adjusted_energy_error"] = ( 49 | self.data.plots["energy"].iloc[row_idx] - e_rmse 50 | ) - self.reference.plots["energy"].iloc[row_idx] 51 | plots["adjusted_energy"] = self.data.plots["energy"].iloc[row_idx] - e_rmse 52 | plots["adjusted_energy_error_per_atom"] = ( 53 | plots["adjusted_energy_error"] 54 | / self.data.plots["n_atoms"].iloc[row_idx] 55 | ) 56 | 57 | plots["fmax_error"] = ( 58 | self.data.plots["fmax"].iloc[row_idx] 59 | - self.reference.plots["fmax"].iloc[row_idx] 60 | ) 61 | plots["fnorm_error"] = ( 62 | self.data.plots["fnorm"].iloc[row_idx] 63 | - self.reference.plots["fnorm"].iloc[row_idx] 64 | ) 65 | all_plots.append(plots) 66 | self.plots = pd.DataFrame(all_plots) 67 | 68 | # iterate over plots and save min/max 69 | self.error = {} 70 | for key in self.plots.columns: 71 | if "_error" in key: 72 | stripped_key = key.replace("_error", "") 73 | self.error[f"{stripped_key}_max"] = self.plots[key].max() 74 | self.error[f"{stripped_key}_min"] = self.plots[key].min() 75 | 76 | @property 77 | def frames(self) -> FRAMES: 78 | return self.data.frames 79 | 80 | @property 81 | def figures(self) -> FIGURES: 82 | figures = {} 83 | for key in self.plots.columns: 84 | figures[key] = get_figure(key, [self]) 85 | return figures 86 | 87 | def compare(self, *nodes: "CompareCalculatorResults") -> ComparisonResults: # noqa C901 88 | if len(nodes) == 0: 89 | raise ValueError("No nodes to compare provided") 90 | figures = {} 91 | frames_info = {} 92 | for key in nodes[0].plots.columns: 93 | if not all(key in node.plots.columns for node in nodes): 94 | raise ValueError(f"Key {key} not found in all nodes") 95 | # check frames are the same 96 | figures[key] = get_figure(key, nodes) 97 | 98 | for node in nodes: 99 | for key in node.plots.columns: 100 | frames_info[f"{node.name}_{key}"] = node.plots[key].values 101 | 102 | # TODO: calculate the rmse difference between a fixed 103 | # one and all the others and shift them accordingly 104 | # and plot as energy_shifted 105 | 106 | # plot error between curves 107 | # mlipx pass additional flags to compare function 108 | # have different compare methods also for correlation plots 109 | 110 | frames = [shallow_copy_atoms(x) for x in nodes[0].frames] 111 | for key, values in frames_info.items(): 112 | for atoms, value in zip(frames, values): 113 | atoms.info[key] = value 114 | 115 | for node in nodes: 116 | for node_atoms, atoms in zip(node.frames, frames): 117 | if len(node_atoms) != len(atoms): 118 | raise ValueError("Atoms objects have different lengths") 119 | with contextlib.suppress(RuntimeError, PropertyNotImplementedError): 120 | atoms.info[f"{node.name}_energy"] = ( 121 | node_atoms.get_potential_energy() 122 | ) 123 | atoms.arrays[f"{node.name}_forces"] = node_atoms.get_forces() 124 | 125 | for ref_atoms, atoms in 
zip(nodes[0].reference.frames, frames): 126 | with contextlib.suppress(RuntimeError, PropertyNotImplementedError): 127 | atoms.info["ref_energy"] = ref_atoms.get_potential_energy() 128 | atoms.arrays["ref_forces"] = ref_atoms.get_forces() 129 | 130 | return { 131 | "frames": frames, 132 | "figures": figures, 133 | } 134 | -------------------------------------------------------------------------------- /mlipx/nodes/energy_volume.py: -------------------------------------------------------------------------------- 1 | import ase.io 2 | import numpy as np 3 | import pandas as pd 4 | import plotly.express as px 5 | import plotly.graph_objects as go 6 | import tqdm 7 | import zntrack 8 | 9 | from mlipx.abc import ComparisonResults, NodeWithCalculator 10 | 11 | 12 | class EnergyVolumeCurve(zntrack.Node): 13 | """Compute the energy-volume curve for a given structure. 14 | 15 | Parameters 16 | ---------- 17 | data : list[ase.Atoms] 18 | List of structures to evaluate. 19 | model : NodeWithCalculator 20 | Node providing the calculator object for the energy calculations. 21 | data_id : int, default=-1 22 | Index of the structure to evaluate. 23 | n_points : int, default=50 24 | Number of points to sample for the volume scaling. 25 | start : float, default=0.75 26 | Initial scaling factor from the original cell. 27 | stop : float, default=2.0 28 | Final scaling factor from the original cell. 29 | 30 | Attributes 31 | ---------- 32 | results : pd.DataFrame 33 | DataFrame with the volume, energy, and scaling factor. 34 | 35 | """ 36 | 37 | model: NodeWithCalculator = zntrack.deps() 38 | data: list[ase.Atoms] = zntrack.deps() 39 | data_id: int = zntrack.params(-1) 40 | 41 | n_points: int = zntrack.params(50) 42 | start: float = zntrack.params(0.75) 43 | stop: float = zntrack.params(2.0) 44 | 45 | frames_path: str = zntrack.outs_path(zntrack.nwd / "frames.xyz") 46 | results: pd.DataFrame = zntrack.plots(y="energy", x="scale") 47 | 48 | def run(self): 49 | atoms = self.data[self.data_id] 50 | calc = self.model.get_calculator() 51 | 52 | results = [] 53 | 54 | scale_factor = np.linspace(self.start, self.stop, self.n_points) 55 | for scale in tqdm.tqdm(scale_factor): 56 | atoms_copy = atoms.copy() 57 | atoms_copy.set_cell(atoms.get_cell() * scale, scale_atoms=True) 58 | atoms_copy.calc = calc 59 | 60 | results.append( 61 | { 62 | "volume": atoms_copy.get_volume(), 63 | "energy": atoms_copy.get_potential_energy(), 64 | "fmax": np.linalg.norm(atoms_copy.get_forces(), axis=-1).max(), 65 | "scale": scale, 66 | } 67 | ) 68 | 69 | ase.io.write(self.frames_path, atoms_copy, append=True) 70 | 71 | self.results = pd.DataFrame(results) 72 | 73 | @property 74 | def frames(self) -> list[ase.Atoms]: 75 | """List of structures evaluated during the energy-volume curve.""" 76 | with self.state.fs.open(self.frames_path, "r") as f: 77 | return list(ase.io.iread(f, format="extxyz")) 78 | 79 | @property 80 | def figures(self) -> dict[str, go.Figure]: 81 | """Plot the energy-volume curve.""" 82 | fig = px.scatter(self.results, x="scale", y="energy", color="scale") 83 | fig.update_layout(title="Energy-Volume Curve") 84 | fig.update_traces(customdata=np.stack([np.arange(self.n_points)], axis=1)) 85 | fig.update_xaxes(title_text="cell vector scale") 86 | fig.update_yaxes(title_text="Energy / eV") 87 | 88 | ffig = px.scatter(self.results, x="scale", y="fmax", color="scale") 89 | ffig.update_layout(title="Energy-Volume Curve (fmax)") 90 | ffig.update_traces(customdata=np.stack([np.arange(self.n_points)], axis=1)) 91 | 
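        # The integer `customdata` set above stores the index of each sampled
        # point, presumably so that viewers such as ZnDraw can map a selected
        # point back to the corresponding structure in `frames`.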
ffig.update_xaxes(title_text="cell vector scale") 92 | ffig.update_yaxes(title_text="Maximum Force / eV/Å") 93 | 94 | return {"energy-volume-curve": fig, "fmax-volume-curve": ffig} 95 | 96 | @staticmethod 97 | def compare(*nodes: "EnergyVolumeCurve") -> ComparisonResults: 98 | """Compare the energy-volume curves of multiple nodes.""" 99 | fig = go.Figure() 100 | for node in nodes: 101 | fig.add_trace( 102 | go.Scatter( 103 | x=node.results["scale"], 104 | y=node.results["energy"], 105 | mode="lines+markers", 106 | name=node.name.replace("_EnergyVolumeCurve", ""), 107 | ) 108 | ) 109 | fig.update_traces(customdata=np.stack([np.arange(node.n_points)], axis=1)) 110 | 111 | # TODO: remove all info from the frames? 112 | # What about forces / energies? Update the key? 113 | fig.update_layout(title="Energy-Volume Curve Comparison") 114 | # set x-axis title 115 | # fig.update_xaxes(title_text="Volume / ų") 116 | fig.update_xaxes(title_text="cell vector scale") 117 | fig.update_yaxes(title_text="Energy / eV") 118 | 119 | # Now adjusted 120 | 121 | fig_adjust = go.Figure() 122 | for node in nodes: 123 | scale_factor = np.linspace(node.start, node.stop, node.n_points) 124 | one_idx = np.abs(scale_factor - 1).argmin() 125 | fig_adjust.add_trace( 126 | go.Scatter( 127 | x=node.results["scale"], 128 | y=node.results["energy"] - node.results["energy"].iloc[one_idx], 129 | mode="lines+markers", 130 | name=node.name.replace("_EnergyVolumeCurve", ""), 131 | ) 132 | ) 133 | fig_adjust.update_traces( 134 | customdata=np.stack([np.arange(node.n_points)], axis=1) 135 | ) 136 | 137 | fig_adjust.update_layout(title="Adjusted Energy-Volume Curve Comparison") 138 | fig_adjust.update_xaxes(title_text="cell vector scale") 139 | fig_adjust.update_yaxes(title_text="Adjusted Energy / eV") 140 | 141 | fig_adjust.update_layout( 142 | plot_bgcolor="rgba(0, 0, 0, 0)", 143 | paper_bgcolor="rgba(0, 0, 0, 0)", 144 | ) 145 | fig_adjust.update_xaxes( 146 | showgrid=True, 147 | gridwidth=1, 148 | gridcolor="rgba(120, 120, 120, 0.3)", 149 | zeroline=False, 150 | ) 151 | fig_adjust.update_yaxes( 152 | showgrid=True, 153 | gridwidth=1, 154 | gridcolor="rgba(120, 120, 120, 0.3)", 155 | zeroline=False, 156 | ) 157 | 158 | fig.update_layout( 159 | plot_bgcolor="rgba(0, 0, 0, 0)", 160 | paper_bgcolor="rgba(0, 0, 0, 0)", 161 | ) 162 | fig.update_xaxes( 163 | showgrid=True, 164 | gridwidth=1, 165 | gridcolor="rgba(120, 120, 120, 0.3)", 166 | zeroline=False, 167 | ) 168 | fig.update_yaxes( 169 | showgrid=True, 170 | gridwidth=1, 171 | gridcolor="rgba(120, 120, 120, 0.3)", 172 | zeroline=False, 173 | ) 174 | 175 | return { 176 | "frames": nodes[0].frames, 177 | "figures": { 178 | "energy-volume-curve": fig, 179 | "adjusted_energy-volume-curve": fig_adjust, 180 | }, 181 | } 182 | -------------------------------------------------------------------------------- /mlipx/nodes/evaluate_calculator.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | 3 | import ase 4 | import numpy as np 5 | import pandas as pd 6 | import plotly.express as px 7 | import plotly.graph_objects as go 8 | import tqdm 9 | import zntrack 10 | from ase.calculators.calculator import PropertyNotImplementedError 11 | 12 | from mlipx.abc import ComparisonResults 13 | from mlipx.utils import shallow_copy_atoms 14 | 15 | 16 | def get_figure(key: str, nodes: list["EvaluateCalculatorResults"]) -> go.Figure: 17 | fig = go.Figure() 18 | for node in nodes: 19 | fig.add_trace( 20 | go.Scatter( 21 | x=node.plots.index, 22 | 
y=node.plots[key], 23 | mode="lines+markers", 24 | name=node.name.replace(f"_{node.__class__.__name__}", ""), 25 | ) 26 | ) 27 | fig.update_traces(customdata=np.stack([np.arange(len(node.plots.index))], axis=1)) 28 | fig.update_layout( 29 | title=key, 30 | plot_bgcolor="rgba(0, 0, 0, 0)", 31 | paper_bgcolor="rgba(0, 0, 0, 0)", 32 | ) 33 | fig.update_xaxes( 34 | showgrid=True, 35 | gridwidth=1, 36 | gridcolor="rgba(120, 120, 120, 0.3)", 37 | zeroline=False, 38 | ) 39 | fig.update_yaxes( 40 | showgrid=True, 41 | gridwidth=1, 42 | gridcolor="rgba(120, 120, 120, 0.3)", 43 | zeroline=False, 44 | ) 45 | return fig 46 | 47 | 48 | class EvaluateCalculatorResults(zntrack.Node): 49 | """ 50 | Evaluate the results of a calculator. 51 | 52 | Parameters 53 | ---------- 54 | data : list[ase.Atoms] 55 | List of atoms objects. 56 | 57 | """ 58 | 59 | data: list[ase.Atoms] = zntrack.deps() 60 | plots: pd.DataFrame = zntrack.plots( 61 | y=["fmax", "fnorm", "energy"], independent=True, autosave=True 62 | ) 63 | 64 | def run(self): 65 | self.plots = pd.DataFrame() 66 | frame_data = [] 67 | for idx in tqdm.tqdm(range(len(self.data))): 68 | atoms = self.data[idx] 69 | 70 | forces = atoms.get_forces() 71 | fmax = np.max(np.linalg.norm(forces, axis=1)) 72 | fnorm = np.linalg.norm(forces) 73 | energy = atoms.get_potential_energy() 74 | # eform = atoms.info.get(ASEKeys.formation_energy.value, -1) 75 | n_atoms = len(atoms) 76 | 77 | # have energy and formation energy in the plot 78 | 79 | plots = { 80 | "fmax": fmax, 81 | "fnorm": fnorm, 82 | "energy": energy, 83 | # "eform": eform, 84 | "n_atoms": n_atoms, 85 | "energy_per_atom": energy / n_atoms, 86 | # "eform_per_atom": eform / n_atoms, 87 | } 88 | frame_data.append(plots) 89 | self.plots = pd.DataFrame(frame_data) 90 | 91 | @property 92 | def frames(self): 93 | return self.data 94 | 95 | def __run_note__(self) -> str: 96 | return f"""# {self.name} 97 | Results from {self.state.remote} at {self.state.rev}. 
98 | 99 | View the trajectory via zndraw: 100 | ```bash 101 | zndraw {self.name}.frames --rev {self.state.rev} --remote {self.state.remote} --url https://app-dev.roqs.basf.net/zndraw_app 102 | ``` 103 | """ 104 | 105 | @property 106 | def figures(self) -> dict: 107 | # TODO: remove index column 108 | 109 | plots = {} 110 | for key in self.plots.columns: 111 | fig = px.line( 112 | self.plots, 113 | x=self.plots.index, 114 | y=key, 115 | title=key, 116 | ) 117 | fig.update_traces( 118 | customdata=np.stack([np.arange(len(self.plots))], axis=1), 119 | ) 120 | plots[key] = fig 121 | return plots 122 | 123 | @staticmethod 124 | def compare( 125 | *nodes: "EvaluateCalculatorResults", reference: str | None = None 126 | ) -> ComparisonResults: 127 | # TODO: if reference, shift energies by 128 | # rmse(val, reference) and plot as energy_adjusted 129 | figures = {} 130 | frames_info = {} 131 | for key in nodes[0].plots.columns: 132 | if not all(key in node.plots.columns for node in nodes): 133 | raise ValueError(f"Key {key} not found in all nodes") 134 | # check frames are the same 135 | figures[key] = get_figure(key, nodes) 136 | 137 | for node in nodes: 138 | for key in node.plots.columns: 139 | frames_info[f"{node.name}_{key}"] = node.plots[key].values 140 | 141 | # TODO: calculate the rmse difference between a fixed one 142 | # and all the others and shift them accordingly 143 | # and plot as energy_shifted 144 | 145 | # plot error between curves 146 | # mlipx pass additional flags to compare function 147 | # have different compare methods also for correlation plots 148 | 149 | frames = [shallow_copy_atoms(x) for x in nodes[0].frames] 150 | for key, values in frames_info.items(): 151 | for atoms, value in zip(frames, values): 152 | atoms.info[key] = value 153 | 154 | for node in nodes: 155 | for node_atoms, atoms in zip(node.frames, frames): 156 | if len(node_atoms) != len(atoms): 157 | raise ValueError("Atoms objects have different lengths") 158 | with contextlib.suppress(RuntimeError, PropertyNotImplementedError): 159 | atoms.info[f"{node.name}_energy"] = ( 160 | node_atoms.get_potential_energy() 161 | ) 162 | atoms.arrays[f"{node.name}_forces"] = node_atoms.get_forces() 163 | 164 | return { 165 | "frames": frames, 166 | "figures": figures, 167 | } 168 | -------------------------------------------------------------------------------- /mlipx/nodes/filter_dataset.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import typing as t 3 | from enum import Enum 4 | 5 | import ase.io 6 | import numpy as np 7 | import tqdm 8 | import zntrack 9 | 10 | 11 | class FilteringType(str, Enum): 12 | """ 13 | Enum defining types of filtering to apply on atomic configurations. 14 | 15 | Attributes 16 | ---------- 17 | COMBINATIONS : str 18 | Filters to include atoms with elements that are any 19 | subset of the specified elements. 20 | EXCLUSIVE : str 21 | Filters to include atoms that contain *only* the specified elements. 22 | INCLUSIVE : str 23 | Filters to include atoms that contain *at least* the specified elements. 24 | """ 25 | 26 | COMBINATIONS = "combinations" 27 | EXCLUSIVE = "exclusive" 28 | INCLUSIVE = "inclusive" 29 | 30 | 31 | def filter_atoms( 32 | atoms: ase.Atoms, 33 | element_subset: list[str], 34 | filtering_type: t.Optional[FilteringType] = None, 35 | ) -> bool: 36 | """ 37 | Filters an atomic configuration based on the 38 | specified filtering type and element subset. 
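
    For example (a quick sanity check; only ``ase`` is needed):

    >>> import ase
    >>> atoms = ase.Atoms("H2O")
    >>> filter_atoms(atoms, ["O", "H"], FilteringType.EXCLUSIVE)
    True
    >>> filter_atoms(atoms, ["O"], FilteringType.INCLUSIVE)
    True
    >>> filter_atoms(atoms, ["O"], FilteringType.COMBINATIONS)
    False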
39 | 40 | Parameters 41 | ---------- 42 | atoms : ase.Atoms 43 | Atomic configuration to filter. 44 | element_subset : list[str] 45 | List of elements to be considered during filtering. 46 | filtering_type : FilteringType, optional 47 | Type of filtering to apply (COMBINATIONS, EXCLUSIVE, or INCLUSIVE). 48 | If None, all atoms pass the filter. 49 | 50 | Returns 51 | ------- 52 | bool 53 | True if the atomic configuration passes the filter, False otherwise. 54 | 55 | References 56 | ---------- 57 | Adapted from github.com/ACEsuit/mace 58 | 59 | Raises 60 | ------ 61 | ValueError 62 | If the provided filtering_type is not recognized. 63 | """ 64 | if filtering_type is None: 65 | return True 66 | elif filtering_type == FilteringType.COMBINATIONS: 67 | atom_symbols = np.unique(atoms.symbols) 68 | return all(x in element_subset for x in atom_symbols) 69 | elif filtering_type == FilteringType.EXCLUSIVE: 70 | atom_symbols = set(atoms.symbols) 71 | return atom_symbols == set(element_subset) 72 | elif filtering_type == FilteringType.INCLUSIVE: 73 | atom_symbols = np.unique(atoms.symbols) 74 | return all(x in atom_symbols for x in element_subset) 75 | else: 76 | raise ValueError( 77 | f"Filtering type {filtering_type} not recognized." 78 | " Must be one of 'none', 'exclusive', or 'inclusive'." 79 | ) 80 | 81 | 82 | class FilterAtoms(zntrack.Node): 83 | """ 84 | ZnTrack node that filters a list of atomic configurations 85 | based on specified elements and filtering type. 86 | 87 | Attributes 88 | ---------- 89 | data : list[ase.Atoms] 90 | List of atomic configurations to filter. 91 | elements : list[str] 92 | List of elements to use as the filtering subset. 93 | filtering_type : FilteringType 94 | Type of filtering to apply (INCLUSIVE, EXCLUSIVE, or COMBINATIONS). 95 | frames_path : pathlib.Path 96 | Path to store filtered atomic configuration frames. 97 | 98 | Methods 99 | ------- 100 | run() 101 | Filters atomic configurations and writes the results to `frames_path`. 102 | frames() -> list[ase.Atoms] 103 | Loads filtered atomic configurations from `frames_path`. 104 | """ 105 | 106 | data: list[ase.Atoms] = zntrack.deps() 107 | elements: list[str] = zntrack.params() 108 | filtering_type: FilteringType = zntrack.params(FilteringType.INCLUSIVE.value) 109 | frames_path: pathlib.Path = zntrack.outs_path(zntrack.nwd / "frames.xyz") 110 | 111 | def run(self): 112 | """ 113 | Applies filtering to atomic configurations in 114 | `data` and saves the results to `frames_path`. 115 | """ 116 | for atoms in tqdm.tqdm(self.data): 117 | if filter_atoms(atoms, self.elements, self.filtering_type): 118 | ase.io.write(self.frames_path, atoms, append=True) 119 | 120 | @property 121 | def frames(self) -> list[ase.Atoms]: 122 | """Loads the filtered atomic configurations from the `frames_path` file. 123 | 124 | Returns 125 | ------- 126 | list[ase.Atoms] 127 | List of filtered atomic configuration frames. 
128 | """ 129 | with self.state.fs.open(self.frames_path, "r") as f: 130 | return list(ase.io.iread(f, format="extxyz")) 131 | -------------------------------------------------------------------------------- /mlipx/nodes/formation_energy.py: -------------------------------------------------------------------------------- 1 | import typing as t 2 | 3 | import ase 4 | import pandas as pd 5 | import zntrack 6 | from tqdm import tqdm, trange 7 | 8 | from mlipx.abc import ASEKeys, NodeWithCalculator 9 | from mlipx.utils import rmse 10 | 11 | 12 | class CalculateFormationEnergy(zntrack.Node): 13 | """ 14 | Calculate formation energy. 15 | 16 | Parameters 17 | ---------- 18 | data : list[ase.Atoms] 19 | ASE atoms object with appropriate tags in info 20 | """ 21 | 22 | data: list[ase.Atoms] = zntrack.deps() 23 | model: t.Optional[NodeWithCalculator] = zntrack.deps(None) 24 | 25 | formation_energy: list = zntrack.outs(independent=True) 26 | isolated_energies: dict = zntrack.outs(independent=True) 27 | 28 | plots: pd.DataFrame = zntrack.plots( 29 | y=["eform", "n_atoms"], independent=True, autosave=True 30 | ) 31 | 32 | def get_isolated_energies(self) -> dict[str, float]: 33 | # get all unique elements 34 | isolated_energies = {} 35 | for atoms in tqdm(self.data, desc="Getting isolated energies"): 36 | for element in set(atoms.get_chemical_symbols()): 37 | if self.model is None: 38 | if element not in isolated_energies: 39 | isolated_energies[element] = atoms.info[ 40 | ASEKeys.isolated_energies.value 41 | ][element] 42 | else: 43 | assert ( 44 | isolated_energies[element] 45 | == atoms.info[ASEKeys.isolated_energies.value][element] 46 | ) 47 | else: 48 | if element not in isolated_energies: 49 | box = ase.Atoms( 50 | element, 51 | positions=[[50, 50, 50]], 52 | cell=[100, 100, 100], 53 | pbc=True, 54 | ) 55 | box.calc = self.model.get_calculator() 56 | isolated_energies[element] = box.get_potential_energy() 57 | 58 | return isolated_energies 59 | 60 | def run(self): 61 | self.formation_energy = [] 62 | self.isolated_energies = self.get_isolated_energies() 63 | 64 | plots = [] 65 | 66 | for atoms in self.data: 67 | chem = atoms.get_chemical_symbols() 68 | reference_energy = 0 69 | for element in chem: 70 | reference_energy += self.isolated_energies[element] 71 | E_form = atoms.get_potential_energy() - reference_energy 72 | self.formation_energy.append(E_form) 73 | plots.append({"eform": E_form, "n_atoms": len(atoms)}) 74 | 75 | self.plots = pd.DataFrame(plots) 76 | 77 | @property 78 | def frames(self): 79 | for atom, energy in zip(self.data, self.formation_energy): 80 | atom.info[ASEKeys.formation_energy.value] = energy 81 | return self.data 82 | 83 | 84 | class CompareFormationEnergy(zntrack.Node): 85 | data: CalculateFormationEnergy = zntrack.deps() 86 | reference: CalculateFormationEnergy = zntrack.deps() 87 | 88 | plots: pd.DataFrame = zntrack.plots(autosave=True) 89 | rmse: dict = zntrack.metrics() 90 | error: dict = zntrack.metrics() 91 | 92 | def run(self): 93 | eform_rmse = rmse(self.data.plots["eform"], self.reference.plots["eform"]) 94 | # e_rmse = rmse(self.data.plots["energy"], self.reference.plots["energy"]) 95 | self.rmse = { 96 | "eform": eform_rmse, 97 | "eform_per_atom": eform_rmse / len(self.data.plots), 98 | } 99 | 100 | all_plots = [] 101 | 102 | for row_idx in trange(len(self.data.plots)): 103 | plots = {} 104 | plots["adjusted_eform_error"] = ( 105 | self.data.plots["eform"].iloc[row_idx] - eform_rmse 106 | ) - self.reference.plots["eform"].iloc[row_idx] 107 | 
plots["adjusted_eform"] = ( 108 | self.data.plots["eform"].iloc[row_idx] - eform_rmse 109 | ) 110 | plots["adjusted_eform_error_per_atom"] = ( 111 | plots["adjusted_eform_error"] / self.data.plots["n_atoms"].iloc[row_idx] 112 | ) 113 | all_plots.append(plots) 114 | self.plots = pd.DataFrame(all_plots) 115 | 116 | # iterate over plots and save min/max 117 | self.error = {} 118 | for key in self.plots.columns: 119 | if "_error" in key: 120 | stripped_key = key.replace("_error", "") 121 | self.error[f"{stripped_key}_max"] = self.plots[key].max() 122 | self.error[f"{stripped_key}_min"] = self.plots[key].min() 123 | -------------------------------------------------------------------------------- /mlipx/nodes/generic_ase.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import importlib 3 | import typing as t 4 | 5 | from ase.calculators.calculator import Calculator 6 | 7 | 8 | class Device: 9 | AUTO = "auto" 10 | CPU = "cpu" 11 | CUDA = "cuda" 12 | 13 | @staticmethod 14 | def resolve_auto() -> t.Literal["cpu", "cuda"]: 15 | import torch 16 | 17 | return "cuda" if torch.cuda.is_available() else "cpu" 18 | 19 | 20 | # TODO: add files as dependencies somehow! 21 | 22 | 23 | @dataclasses.dataclass 24 | class GenericASECalculator: 25 | """Generic ASE calculator. 26 | 27 | Load any ASE calculator from a module and class name. 28 | 29 | Parameters 30 | ---------- 31 | module : str 32 | Module name containing the calculator class. 33 | For LJ this would be 'ase.calculators.lj'. 34 | class_name : str 35 | Class name of the calculator. 36 | For LJ this would be 'LennardJones'. 37 | kwargs : dict, default=None 38 | Additional keyword arguments to pass to the calculator. 39 | For LJ this could be {'epsilon': 1.0, 'sigma': 1.0}. 
40 | """ 41 | 42 | module: str 43 | class_name: str 44 | kwargs: dict[str, t.Any] | None = None 45 | device: t.Literal["auto", "cpu", "cuda"] | None = None 46 | 47 | def get_calculator(self, **kwargs) -> Calculator: 48 | if self.kwargs is not None: 49 | kwargs.update(self.kwargs) 50 | module = importlib.import_module(self.module) 51 | cls = getattr(module, self.class_name) 52 | if self.device is None: 53 | return cls(**kwargs) 54 | elif self.device == "auto": 55 | return cls(**kwargs, device=Device.resolve_auto()) 56 | else: 57 | return cls(**kwargs, device=self.device) 58 | 59 | @property 60 | def available(self) -> bool: 61 | try: 62 | importlib.import_module(self.module) 63 | return True 64 | except ImportError: 65 | return False 66 | -------------------------------------------------------------------------------- /mlipx/nodes/invariances.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | import ase 4 | import ase.io 5 | import numpy as np 6 | import pandas as pd 7 | import plotly.graph_objects as go 8 | import tqdm 9 | import zntrack 10 | 11 | from mlipx.abc import ComparisonResults, NodeWithCalculator 12 | 13 | 14 | class InvarianceNode(zntrack.Node): 15 | """Base class for testing invariances.""" 16 | 17 | model: NodeWithCalculator = zntrack.deps() 18 | data: list[ase.Atoms] = zntrack.deps() 19 | data_id: int = zntrack.params(-1) 20 | n_points: int = zntrack.params(50) 21 | 22 | metrics: dict = zntrack.metrics() 23 | plots: pd.DataFrame = zntrack.plots() 24 | 25 | frames_path: pathlib.Path = zntrack.outs_path(zntrack.nwd / "frames.xyz") 26 | 27 | def run(self): 28 | """Common logic for invariance testing.""" 29 | atoms = self.data[self.data_id] 30 | calc = self.model.get_calculator() 31 | atoms.calc = calc 32 | 33 | rng = np.random.default_rng() 34 | energies = [] 35 | for _ in tqdm.tqdm(range(self.n_points)): 36 | self.apply_transformation(atoms, rng) 37 | energies.append(atoms.get_potential_energy()) 38 | ase.io.write(self.frames_path, atoms, append=True) 39 | 40 | self.plots = pd.DataFrame(energies, columns=["energy"]) 41 | 42 | self.metrics = { 43 | "mean": float(np.mean(energies)), 44 | "std": float(np.std(energies)), 45 | } 46 | 47 | @property 48 | def frames(self): 49 | with self.state.fs.open(self.frames_path, "r") as f: 50 | return list(ase.io.iread(f, ":")) 51 | 52 | def apply_transformation(self, atoms_copy: ase.Atoms, rng: np.random.Generator): 53 | """To be implemented by child classes.""" 54 | raise NotImplementedError("Subclasses must implement apply_transformation()") 55 | 56 | @staticmethod 57 | def compare(*nodes: "InvarianceNode") -> ComparisonResults: 58 | if len(nodes) == 0: 59 | raise ValueError("No nodes to compare") 60 | 61 | fig = go.Figure() 62 | for node in nodes: 63 | fig.add_trace( 64 | go.Scatter( 65 | x=np.arange(node.n_points), 66 | y=node.plots["energy"] - node.metrics["mean"], 67 | mode="markers", 68 | name=node.name.replace(f"_{node.__class__.__name__}", ""), 69 | ) 70 | ) 71 | 72 | fig.update_layout( 73 | title=f"Energy vs step ({nodes[0].__class__.__name__})", 74 | xaxis_title="Steps", 75 | yaxis_title="Adjusted energy", 76 | plot_bgcolor="rgba(0, 0, 0, 0)", 77 | paper_bgcolor="rgba(0, 0, 0, 0)", 78 | ) 79 | fig.update_xaxes( 80 | showgrid=True, 81 | gridwidth=1, 82 | gridcolor="rgba(120, 120, 120, 0.3)", 83 | zeroline=False, 84 | ) 85 | fig.update_yaxes( 86 | showgrid=True, 87 | gridwidth=1, 88 | gridcolor="rgba(120, 120, 120, 0.3)", 89 | zeroline=False, 90 | ) 91 | 92 | return 
ComparisonResults( 93 | frames=nodes[0].frames, figures={"energy_vs_steps_adjusted": fig} 94 | ) 95 | 96 | 97 | class TranslationalInvariance(InvarianceNode): 98 | """Test translational invariance by random box translocation.""" 99 | 100 | def apply_transformation(self, atoms_copy: ase.Atoms, rng: np.random.Generator): 101 | translation = rng.uniform(-1, 1, 3) 102 | atoms_copy.positions += translation 103 | 104 | 105 | class RotationalInvariance(InvarianceNode): 106 | """Test rotational invariance by random rotation of the box.""" 107 | 108 | def apply_transformation(self, atoms_copy: ase.Atoms, rng: np.random.Generator): 109 | angle = rng.uniform(-90, 90) 110 | direction = rng.choice(["x", "y", "z"]) 111 | atoms_copy.rotate(angle, direction, rotate_cell=any(atoms_copy.pbc)) 112 | 113 | 114 | class PermutationInvariance(InvarianceNode): 115 | """Test permutation invariance by permutation of atoms of the same species.""" 116 | 117 | def apply_transformation(self, atoms_copy: ase.Atoms, rng: np.random.Generator): 118 | species = np.unique(atoms_copy.get_chemical_symbols()) 119 | for s in species: 120 | indices = np.where(atoms_copy.get_chemical_symbols() == s)[0] 121 | rng.shuffle(indices) 122 | atoms_copy.positions[indices] = rng.permutation( 123 | atoms_copy.positions[indices], axis=0 124 | ) 125 | -------------------------------------------------------------------------------- /mlipx/nodes/io.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import typing as t 3 | 4 | import ase.io 5 | import h5py 6 | import znh5md 7 | import zntrack 8 | 9 | 10 | class LoadDataFile(zntrack.Node): 11 | """Load a trajectory file. 12 | 13 | Entry point of trajectory data for the use in other nodes. 14 | 15 | Parameters 16 | ---------- 17 | path : str | pathlib.Path 18 | Path to the trajectory file. 19 | start : int, default=0 20 | Index of the first frame to load. 21 | stop : int, default=None 22 | Index of the last frame to load. 23 | step : int, default=1 24 | Step size between frames. 25 | 26 | Attributes 27 | ---------- 28 | frames : list[ase.Atoms] 29 | Loaded frames. 30 | """ 31 | 32 | path: str | pathlib.Path = zntrack.deps_path() 33 | # TODO these are not used 34 | start: int = zntrack.params(0) 35 | stop: t.Optional[int] = zntrack.params(None) 36 | step: int = zntrack.params(1) 37 | 38 | def run(self): 39 | pass 40 | 41 | @property 42 | def frames(self) -> list[ase.Atoms]: 43 | if pathlib.Path(self.path).suffix in [".h5", ".h5md"]: 44 | with self.state.fs.open(self.path, "rb") as f: 45 | with h5py.File(f) as file: 46 | return znh5md.IO(file_handle=file)[ 47 | self.start : self.stop : self.step 48 | ] 49 | else: 50 | format = pathlib.Path(self.path).suffix.lstrip(".") 51 | if format == "xyz": 52 | format = "extxyz" # force ase to use the extxyz reader 53 | with self.state.fs.open(self.path, "r") as f: 54 | return list( 55 | ase.io.iread( 56 | f, format=format, index=slice(self.start, self.stop, self.step) 57 | ) 58 | ) 59 | -------------------------------------------------------------------------------- /mlipx/nodes/modifier.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import typing as t 3 | 4 | from ase import units 5 | 6 | from mlipx.abc import DynamicsModifier 7 | 8 | 9 | @dataclasses.dataclass 10 | class TemperatureRampModifier(DynamicsModifier): 11 | """Ramp the temperature from start_temperature to temperature. 
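    (Here "temperature" refers to the ``end_temperature`` attribute below.)

    A hedged usage sketch; the values mirror the md benchmark recipe in this
    repository, and the modifier is passed to ``MolecularDynamics`` via its
    ``modifiers`` argument:

    >>> from mlipx import TemperatureRampModifier
    >>> t_ramp = TemperatureRampModifier(end_temperature=400, total_steps=100)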
12 | 13 | Attributes 14 | ---------- 15 | start_temperature: float, optional 16 | temperature to start from, if None, the temperature of the thermostat is used. 17 | end_temperature: float 18 | temperature to ramp to. 19 | interval: int, default 1 20 | interval in which the temperature is changed. 21 | total_steps: int 22 | total number of steps in the simulation. 23 | 24 | References 25 | ---------- 26 | Code taken from ipsuite/calculators/ase_md.py 27 | """ 28 | 29 | end_temperature: float 30 | total_steps: int 31 | start_temperature: t.Optional[float] = None 32 | interval: int = 1 33 | 34 | def modify(self, thermostat, step): 35 | # we use the thermostat, so we can also modify e.g. temperature 36 | if self.start_temperature is None: 37 | # different thermostats call the temperature attribute differently 38 | if temp := getattr(thermostat, "temp", None): 39 | self.start_temperature = temp / units.kB 40 | elif temp := getattr(thermostat, "temperature", None): 41 | self.start_temperature = temp / units.kB 42 | else: 43 | raise AttributeError("No temperature attribute found in thermostat.") 44 | 45 | percentage = step / (self.total_steps - 1) 46 | new_temperature = ( 47 | 1 - percentage 48 | ) * self.start_temperature + percentage * self.end_temperature 49 | if step % self.interval == 0: 50 | thermostat.set_temperature(temperature_K=new_temperature) 51 | -------------------------------------------------------------------------------- /mlipx/nodes/molecular_dynamics.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import pathlib 3 | 4 | import ase.io 5 | import ase.units 6 | import numpy as np 7 | import pandas as pd 8 | import plotly.express as px 9 | import plotly.graph_objects as go 10 | import tqdm 11 | import zntrack 12 | from ase.md import Langevin 13 | 14 | from mlipx.abc import ( 15 | ComparisonResults, 16 | DynamicsModifier, 17 | DynamicsObserver, 18 | NodeWithCalculator, 19 | NodeWithMolecularDynamics, 20 | ) 21 | 22 | 23 | @dataclasses.dataclass 24 | class LangevinConfig: 25 | """Configure a Langevin thermostat for molecular dynamics. 26 | 27 | Parameters 28 | ---------- 29 | timestep : float 30 | Time step for the molecular dynamics simulation in fs. 31 | temperature : float 32 | Temperature of the thermostat. 33 | friction : float 34 | Friction coefficient of the thermostat. 35 | """ 36 | 37 | timestep: float 38 | temperature: float 39 | friction: float 40 | 41 | def get_molecular_dynamics(self, atoms) -> Langevin: 42 | return Langevin( 43 | atoms, 44 | timestep=self.timestep * ase.units.fs, 45 | temperature_K=self.temperature, 46 | friction=self.friction, 47 | ) 48 | 49 | 50 | class MolecularDynamics(zntrack.Node): 51 | """Run molecular dynamics simulation. 52 | 53 | Parameters 54 | ---------- 55 | model : NodeWithCalculator 56 | Node providing the calculator object for the simulation. 57 | thermostat : LangevinConfig 58 | Node providing the thermostat object for the simulation. 59 | data : list[ase.Atoms] 60 | Initial configurations for the simulation. 61 | data_id : int, default=-1 62 | Index of the initial configuration to use. 63 | steps : int, default=100 64 | Number of steps to run the simulation. 
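    observers : list[DynamicsObserver], optional
        Observers checked after every step; the run stops early once one of
        them triggers.
    modifiers : list[DynamicsModifier], optional
        Modifiers applied to the dynamics after every step, e.g. a temperature
        ramp.

    Example
    -------
    Illustrative sketch following the bundled benchmark recipe; ``model`` and
    ``data`` are assumed to be nodes from the surrounding project:

    >>> thermostat = mlipx.LangevinConfig(timestep=0.5, temperature=300, friction=0.05)
    >>> md = mlipx.MolecularDynamics(
    ...     model=model,
    ...     thermostat=thermostat,
    ...     data=data.frames,
    ...     data_id=-1,
    ...     steps=100,
    ... )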
65 | """ 66 | 67 | model: NodeWithCalculator = zntrack.deps() 68 | thermostat: NodeWithMolecularDynamics = zntrack.deps() 69 | data: list[ase.Atoms] = zntrack.deps() 70 | data_id: int = zntrack.params(-1) 71 | steps: int = zntrack.params(100) 72 | observers: list[DynamicsObserver] = zntrack.deps(None) 73 | modifiers: list[DynamicsModifier] = zntrack.deps(None) 74 | 75 | observer_metrics: dict = zntrack.metrics() 76 | plots: pd.DataFrame = zntrack.plots(y=["energy", "fmax"], autosave=True) 77 | 78 | frames_path: pathlib.Path = zntrack.outs_path(zntrack.nwd / "frames.xyz") 79 | 80 | def run(self): 81 | if self.observers is None: 82 | self.observers = [] 83 | if self.modifiers is None: 84 | self.modifiers = [] 85 | atoms = self.data[self.data_id] 86 | atoms.calc = self.model.get_calculator() 87 | dyn = self.thermostat.get_molecular_dynamics(atoms) 88 | for obs in self.observers: 89 | obs.initialize(atoms) 90 | 91 | self.observer_metrics = {} 92 | self.plots = pd.DataFrame(columns=["energy", "fmax", "fnorm"]) 93 | 94 | for idx, _ in enumerate( 95 | tqdm.tqdm(dyn.irun(steps=self.steps), total=self.steps) 96 | ): 97 | ase.io.write(self.frames_path, atoms, append=True) 98 | plots = { 99 | "energy": atoms.get_potential_energy(), 100 | "fmax": np.max(np.linalg.norm(atoms.get_forces(), axis=1)), 101 | "fnorm": np.linalg.norm(atoms.get_forces()), 102 | } 103 | self.plots.loc[len(self.plots)] = plots 104 | 105 | for obs in self.observers: 106 | if obs.check(atoms): 107 | self.observer_metrics[obs.name] = idx 108 | 109 | if len(self.observer_metrics) > 0: 110 | break 111 | 112 | for mod in self.modifiers: 113 | mod.modify(dyn, idx) 114 | 115 | for obs in self.observers: 116 | # document all attached observers 117 | self.observer_metrics[obs.name] = self.observer_metrics.get(obs.name, -1) 118 | 119 | @property 120 | def frames(self) -> list[ase.Atoms]: 121 | with self.state.fs.open(self.frames_path, "r") as f: 122 | return list(ase.io.iread(f, format="extxyz")) 123 | 124 | @property 125 | def figures(self) -> dict[str, go.Figure]: 126 | plots = {} 127 | for key in self.plots.columns: 128 | fig = px.line( 129 | self.plots, 130 | x=self.plots.index, 131 | y=key, 132 | title=key, 133 | ) 134 | fig.update_traces( 135 | customdata=np.stack([np.arange(len(self.plots))], axis=1), 136 | ) 137 | plots[key] = fig 138 | return plots 139 | 140 | @staticmethod 141 | def compare(*nodes: "MolecularDynamics") -> ComparisonResults: 142 | frames = sum([node.frames for node in nodes], []) 143 | offset = 0 144 | fig = go.Figure() 145 | for _, node in enumerate(nodes): 146 | energies = [atoms.get_potential_energy() for atoms in node.frames] 147 | fig.add_trace( 148 | go.Scatter( 149 | x=list(range(len(energies))), 150 | y=energies, 151 | mode="lines+markers", 152 | name=node.name.replace(f"_{node.__class__.__name__}", ""), 153 | customdata=np.stack([np.arange(len(energies)) + offset], axis=1), 154 | ) 155 | ) 156 | offset += len(energies) 157 | 158 | fig.update_layout( 159 | title="Energy vs. step", 160 | xaxis_title="Step", 161 | yaxis_title="Energy", 162 | ) 163 | 164 | fig.update_layout( 165 | plot_bgcolor="rgba(0, 0, 0, 0)", 166 | paper_bgcolor="rgba(0, 0, 0, 0)", 167 | ) 168 | fig.update_xaxes( 169 | showgrid=True, 170 | gridwidth=1, 171 | gridcolor="rgba(120, 120, 120, 0.3)", 172 | zeroline=False, 173 | ) 174 | fig.update_yaxes( 175 | showgrid=True, 176 | gridwidth=1, 177 | gridcolor="rgba(120, 120, 120, 0.3)", 178 | zeroline=False, 179 | ) 180 | 181 | # Now we set the first energy to zero for better compareability. 
182 | 183 | offset = 0 184 | fig_adjusted = go.Figure() 185 | for _, node in enumerate(nodes): 186 | energies = np.array([atoms.get_potential_energy() for atoms in node.frames]) 187 | energies -= energies[0] 188 | fig_adjusted.add_trace( 189 | go.Scatter( 190 | x=list(range(len(energies))), 191 | y=energies, 192 | mode="lines+markers", 193 | name=node.name.replace(f"_{node.__class__.__name__}", ""), 194 | customdata=np.stack([np.arange(len(energies)) + offset], axis=1), 195 | ) 196 | ) 197 | offset += len(energies) 198 | 199 | fig_adjusted.update_layout( 200 | title="Adjusted energy vs. step", 201 | xaxis_title="Step", 202 | yaxis_title="Adjusted energy", 203 | ) 204 | 205 | fig_adjusted.update_layout( 206 | plot_bgcolor="rgba(0, 0, 0, 0)", 207 | paper_bgcolor="rgba(0, 0, 0, 0)", 208 | ) 209 | fig_adjusted.update_xaxes( 210 | showgrid=True, 211 | gridwidth=1, 212 | gridcolor="rgba(120, 120, 120, 0.3)", 213 | zeroline=False, 214 | ) 215 | fig_adjusted.update_yaxes( 216 | showgrid=True, 217 | gridwidth=1, 218 | gridcolor="rgba(120, 120, 120, 0.3)", 219 | zeroline=False, 220 | ) 221 | 222 | return ComparisonResults( 223 | frames=frames, 224 | figures={"energy_vs_steps": fig, "energy_vs_steps_adjusted": fig_adjusted}, 225 | ) 226 | -------------------------------------------------------------------------------- /mlipx/nodes/mp_api.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | import ase.io 4 | import zntrack 5 | from mp_api import client 6 | from pymatgen.io.ase import AseAtomsAdaptor 7 | 8 | 9 | class MPRester(zntrack.Node): 10 | """Search the materials project database. 11 | 12 | Parameters 13 | ---------- 14 | search_kwargs: dict 15 | The search parameters for the materials project database. 16 | 17 | Example 18 | ------- 19 | >>> os.environ["MP_API_KEY"] = "your_api_key" 20 | >>> MPRester(search_kwargs={"material_ids": ["mp-1234"]}) 21 | 22 | """ 23 | 24 | search_kwargs: dict = zntrack.params() 25 | frames_path: pathlib.Path = zntrack.outs_path(zntrack.nwd / "frames.xyz") 26 | 27 | def run(self): 28 | with client.MPRester() as mpr: 29 | docs = mpr.materials.search(**self.search_kwargs) 30 | 31 | frames = [] 32 | adaptor = AseAtomsAdaptor() 33 | 34 | for entry in docs: 35 | structure = entry.structure 36 | atoms = adaptor.get_atoms(structure) 37 | frames.append(atoms) 38 | 39 | ase.io.write(self.frames_path, frames) 40 | 41 | @property 42 | def frames(self) -> list[ase.Atoms]: 43 | with self.state.fs.open(self.frames_path, mode="r") as f: 44 | return list(ase.io.iread(f, format="extxyz")) 45 | -------------------------------------------------------------------------------- /mlipx/nodes/observer.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import warnings 3 | 4 | import ase 5 | import numpy as np 6 | 7 | from mlipx.abc import DynamicsObserver 8 | 9 | 10 | @dataclasses.dataclass 11 | class MaximumForceObserver(DynamicsObserver): 12 | """Evaluate if the maximum force on a single atom exceeds a threshold. 13 | 14 | Parameters 15 | ---------- 16 | f_max : float 17 | Maximum allowed force norm on a single atom 18 | 19 | 20 | Example 21 | ------- 22 | 23 | >>> import zntrack, mlipx 24 | >>> project = zntrack.Project() 25 | >>> observer = mlipx.MaximumForceObserver(f_max=0.1) 26 | >>> with project: 27 | ... md = mlipx.MolecularDynamics( 28 | ... observers=[observer], 29 | ... **kwargs 30 | ... 
)
31 |     >>> project.build()
32 |     """
33 | 
34 |     f_max: float
35 | 
36 |     def check(self, atoms: ase.Atoms) -> bool:
37 |         """Check if the maximum force on a single atom exceeds the threshold.
38 | 
39 |         Parameters
40 |         ----------
41 |         atoms : ase.Atoms
42 |             Atoms object to evaluate
43 |         """
44 | 
45 |         max_force = np.linalg.norm(atoms.get_forces(), axis=1).max()
46 |         if max_force > self.f_max:
47 |             warnings.warn(f"Maximum force {max_force} exceeds {self.f_max}")
48 |             return True
49 |         return False
50 | 
--------------------------------------------------------------------------------
/mlipx/nodes/orca.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import os
3 | from pathlib import Path
4 | 
5 | from ase.calculators.orca import ORCA, OrcaProfile
6 | 
7 | 
8 | @dataclasses.dataclass
9 | class OrcaSinglePoint:
10 |     """Use ORCA to perform a single point calculation.
11 | 
12 |     Parameters
13 |     ----------
14 |     orcasimpleinput : str
15 |         ORCA input string.
16 |         You can use something like "PBE def2-TZVP TightSCF EnGrad".
17 |     orcablocks : str
18 |         ORCA input blocks.
19 |         You can use something like "%pal nprocs 8 end".
20 |     orca_shell : str, optional
21 |         Path to the ORCA executable.
22 |         The environment variable MLIPX_ORCA will be used if not provided.
23 |     """
24 | 
25 |     orcasimpleinput: str
26 |     orcablocks: str
27 |     orca_shell: str | None = None
28 | 
29 |     def get_calculator(self, directory: str | Path) -> ORCA:
30 |         profile = OrcaProfile(command=self.orca_shell or os.environ["MLIPX_ORCA"])
31 | 
32 |         calc = ORCA(
33 |             profile=profile,
34 |             orcasimpleinput=self.orcasimpleinput,
35 |             orcablocks=self.orcablocks,
36 |             directory=directory,
37 |         )
38 |         return calc
39 | 
40 |     @property
41 |     def available(self) -> None:
42 |         return None
43 | 
--------------------------------------------------------------------------------
/mlipx/nodes/phase_diagram.py:
--------------------------------------------------------------------------------
1 | # skip linting for this file
2 | 
3 | import itertools
4 | import os
5 | import typing as t
6 | 
7 | import ase.io
8 | import pandas as pd
9 | import plotly.express as px
10 | import plotly.graph_objects as go
11 | import zntrack
12 | from ase.optimize import BFGS
13 | from mp_api.client import MPRester
14 | from pymatgen.analysis.phase_diagram import PDPlotter
15 | from pymatgen.analysis.phase_diagram import PhaseDiagram as pmg_PhaseDiagram
16 | from pymatgen.entries.compatibility import (
17 |     MaterialsProject2020Compatibility,
18 | )
19 | from pymatgen.entries.computed_entries import (
20 |     ComputedEntry,
21 | )
22 | 
23 | from mlipx.abc import ComparisonResults, NodeWithCalculator
24 | 
25 | 
26 | class PhaseDiagram(zntrack.Node):
27 |     """Compute the phase diagram for a given set of structures.
28 | 
29 |     Parameters
30 |     ----------
31 |     data : list[ase.Atoms]
32 |         List of structures to evaluate.
33 |     model : NodeWithCalculator
34 |         Node providing the calculator object for the energy calculations.
35 |     chemsys : list[str], default=None
36 |         The set of chemical symbols used to construct the phase diagram.
37 |     data_ids : list[int], default=None
38 |         Indices of the structures to evaluate.
39 |     geo_opt : bool, default=False
40 |         Whether to perform geometry optimization before calculating the
41 |         formation energy of each structure.
42 |     fmax : float, default=0.05
43 |         Maximum force convergence criterion for the geometry optimizations.
44 | 
45 |     Attributes
46 |     ----------
47 |     results : pd.DataFrame
48 |         DataFrame with the data_id, potential energy and formation energy.
49 |     plots : dict[str, go.Figure]
50 |         Dictionary with the phase diagram (and formation energy plot).
51 | 
52 |     """
53 | 
54 |     model: NodeWithCalculator = zntrack.deps()
55 |     data: list[ase.Atoms] = zntrack.deps()
56 |     chemsys: list[str] = zntrack.params(None)
57 |     data_ids: list[int] = zntrack.params(None)
58 |     geo_opt: bool = zntrack.params(False)
59 |     fmax: float = zntrack.params(0.05)
60 |     frames_path: str = zntrack.outs_path(zntrack.nwd / "frames.xyz")
61 |     results: pd.DataFrame = zntrack.plots(x="data_id", y="formation_energy")
62 |     phase_diagram: t.Any = zntrack.outs()
63 | 
64 |     def run(self):  # noqa C901
65 |         if self.data_ids is None:
66 |             atoms_list = self.data
67 |         else:
68 |             atoms_list = [self.data[i] for i in self.data_ids]
69 |         if self.model is not None:
70 |             calc = self.model.get_calculator()
71 | 
72 |         U_metal_set = {"Co", "Cr", "Fe", "Mn", "Mo", "Ni", "V", "W"}
73 |         U_settings = {
74 |             "Co": 3.32,
75 |             "Cr": 3.7,
76 |             "Fe": 5.3,
77 |             "Mn": 3.9,
78 |             "Mo": 4.38,
79 |             "Ni": 6.2,
80 |             "V": 3.25,
81 |             "W": 6.2,
82 |         }
83 |         try:
84 |             api_key = os.environ["MP_API_KEY"]
85 |         except KeyError:
86 |             api_key = None
87 | 
88 |         entries, epots = [], []
89 |         for atoms in atoms_list:
90 |             metals = [s for s in set(atoms.symbols) if s not in ["O", "H"]]
91 |             hubbards = {}
92 |             if set(metals) & U_metal_set:
93 |                 run_type = "GGA+U"
94 |                 is_hubbard = True
95 |                 for m in metals:
96 |                     hubbards[m] = U_settings.get(m, 0)
97 |             else:
98 |                 run_type = "GGA"
99 |                 is_hubbard = False
100 | 
101 |             if self.model is not None:
102 |                 atoms.calc = calc
103 |             if self.geo_opt:
104 |                 dyn = BFGS(atoms)
105 |                 dyn.run(fmax=self.fmax)
106 |             epot = atoms.get_potential_energy()
107 |             ase.io.write(self.frames_path, atoms, append=True)
108 |             epots.append(epot)
109 |             amt_dict = {
110 |                 m: len([a for a in atoms if a.symbol == m]) for m in set(atoms.symbols)
111 |             }
112 |             entry = ComputedEntry(
113 |                 composition=amt_dict,
114 |                 energy=epot,
115 |                 parameters={
116 |                     "run_type": run_type,
117 |                     "software": "N/A",
118 |                     "oxide_type": "oxide",
119 |                     "is_hubbard": is_hubbard,
120 |                     "hubbards": hubbards,
121 |                 },
122 |             )
123 |             entries.append(entry)
124 |         compat = MaterialsProject2020Compatibility()
125 |         computed_entries = compat.process_entries(entries)
126 |         if api_key is None:
127 |             mp_entries = []
128 |         else:
129 |             mpr = MPRester(api_key)
130 |             if self.chemsys is None:
131 |                 chemsys = set(
132 |                     itertools.chain.from_iterable(atoms.symbols for atoms in atoms_list)
133 |                 )
134 |             else:
135 |                 chemsys = self.chemsys
136 |             mp_entries = mpr.get_entries_in_chemsys(chemsys)
137 |         all_entries = computed_entries + mp_entries
138 |         self.phase_diagram = pmg_PhaseDiagram(all_entries)
139 | 
140 |         row_dicts = []
141 |         for i, entry in enumerate(computed_entries):
142 |             if self.data_ids is None:
143 |                 data_id = i
144 |             else:
145 |                 data_id = self.data_ids[i]
146 |             eform = self.phase_diagram.get_form_energy_per_atom(entry)
147 |             row_dicts.append(
148 |                 {
149 |                     "data_id": data_id,
150 |                     "potential_energy": epots[i],
151 |                     "formation_energy": eform,
152 |                 },
153 |             )
154 |         self.results = pd.DataFrame(row_dicts)
155 | 
156 |     @property
157 |     def figures(self) -> dict[str, go.Figure]:
158 |         plotter = PDPlotter(self.phase_diagram)
159 |         fig1 = plotter.get_plot()
160 |         fig2 = px.line(self.results, x="data_id", y="formation_energy")
161 |         fig2.update_layout(title="Formation Energy Plot")
162 |         pd_df = pd.DataFrame(
[len(self.phase_diagram.stable_entries)], columns=["Stable_phases"] 164 | ) 165 | fig3 = px.bar(pd_df, y="Stable_phases") 166 | 167 | return { 168 | "phase-diagram": fig1, 169 | "formation-energy-plot": fig2, 170 | "stable_phases": fig3, 171 | } 172 | 173 | @staticmethod 174 | def compare(*nodes: "PhaseDiagram") -> ComparisonResults: 175 | figures = {} 176 | 177 | for node in nodes: 178 | # Extract a unique identifier for the node 179 | node_identifier = node.name.replace(f"_{node.__class__.__name__}", "") 180 | 181 | # Update and store the figures directly 182 | for key, fig in node.figures.items(): 183 | fig.update_layout( 184 | title=node_identifier, 185 | plot_bgcolor="rgba(0, 0, 0, 0)", 186 | paper_bgcolor="rgba(0, 0, 0, 0)", 187 | ) 188 | fig.update_xaxes( 189 | showgrid=True, 190 | gridwidth=1, 191 | gridcolor="rgba(120, 120, 120, 0.3)", 192 | zeroline=False, 193 | ) 194 | fig.update_yaxes( 195 | showgrid=True, 196 | gridwidth=1, 197 | gridcolor="rgba(120, 120, 120, 0.3)", 198 | zeroline=False, 199 | ) 200 | figures[f"{node_identifier}-{key}"] = fig 201 | 202 | return { 203 | "frames": nodes[0].frames, 204 | "figures": figures, 205 | } 206 | 207 | @property 208 | def frames(self) -> list[ase.Atoms]: 209 | with self.state.fs.open(self.frames_path, "r") as f: 210 | return list(ase.io.iread(f, format="extxyz")) 211 | -------------------------------------------------------------------------------- /mlipx/nodes/rattle.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | import ase.io 4 | import zntrack 5 | 6 | 7 | class Rattle(zntrack.Node): 8 | data: list[ase.Atoms] = zntrack.deps() 9 | stdev: float = zntrack.params(0.001) 10 | seed: int = zntrack.params(42) 11 | 12 | frames_path: pathlib.Path = zntrack.outs_path(zntrack.nwd / "frames.xyz") 13 | 14 | def run(self): 15 | for atoms in self.data: 16 | atoms.rattle(stdev=self.stdev, seed=self.seed) 17 | ase.io.write(self.frames_path, atoms, append=True) 18 | 19 | @property 20 | def frames(self) -> list[ase.Atoms]: 21 | with self.state.fs.open(self.frames_path, "r") as f: 22 | return list(ase.io.iread(f, format="extxyz")) 23 | -------------------------------------------------------------------------------- /mlipx/nodes/smiles.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | import ase 4 | import ase.io as aio 5 | import zntrack 6 | 7 | 8 | class Smiles2Conformers(zntrack.Node): 9 | """Create conformers from a SMILES string. 10 | 11 | Parameters 12 | ---------- 13 | smiles : str 14 | The SMILES string. 15 | num_confs : int 16 | The number of conformers to generate. 17 | random_seed : int 18 | The random seed. 19 | max_attempts : int 20 | The maximum number of attempts. 
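Examples
--------
A minimal sketch mirroring the recipe templates; "CCO" is only an
illustrative SMILES string:

>>> import mlipx
>>> import zntrack
>>> project = zntrack.Project()
>>> with project.group("initialize"):
...     confs = mlipx.Smiles2Conformers(smiles="CCO", num_confs=1)
>>> project.build()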
21 | """ 22 | 23 | smiles: str = zntrack.params() 24 | num_confs: int = zntrack.params() 25 | random_seed: int = zntrack.params(42) 26 | max_attempts: int = zntrack.params(1000) 27 | 28 | frames_path: pathlib.Path = zntrack.outs_path(zntrack.nwd / "frames.xyz") 29 | 30 | def run(self): 31 | from rdkit2ase import smiles2conformers 32 | 33 | conformers = smiles2conformers( 34 | self.smiles, 35 | numConfs=self.num_confs, 36 | randomSeed=self.random_seed, 37 | maxAttempts=self.max_attempts, 38 | ) 39 | aio.write(self.frames_path, conformers) 40 | 41 | @property 42 | def frames(self) -> list[ase.Atoms]: 43 | with self.state.fs.open(self.frames_path, "r") as f: 44 | return list(aio.iread(f, format="extxyz")) 45 | 46 | 47 | class BuildBox(zntrack.Node): 48 | """Build a box from a list of atoms. 49 | 50 | Parameters 51 | ---------- 52 | data : list[list[ase.Atoms]] 53 | A list of lists of ASE Atoms objects representing 54 | the molecules to be packed. 55 | counts : list[int] 56 | A list of integers representing the number of each 57 | type of molecule. 58 | density : float 59 | The target density of the packed system in kg/m^3 60 | 61 | """ 62 | 63 | data: list[list[ase.Atoms]] = zntrack.deps() 64 | counts: list[int] = zntrack.params() 65 | density: float = zntrack.params(1000) 66 | frames_path: pathlib.Path = zntrack.outs_path(zntrack.nwd / "frames.xyz") 67 | 68 | def run(self): 69 | from rdkit2ase import pack 70 | 71 | atoms = pack(data=self.data, counts=self.counts, density=self.density) 72 | aio.write(self.frames_path, atoms) 73 | 74 | @property 75 | def frames(self) -> list[ase.Atoms]: 76 | with self.state.fs.open(self.frames_path, "r") as f: 77 | return list(aio.iread(f, format="extxyz")) 78 | -------------------------------------------------------------------------------- /mlipx/nodes/structure_optimization.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | import ase.io 4 | import ase.optimize as opt 5 | import numpy as np 6 | import pandas as pd 7 | import plotly.graph_objects as go 8 | import zntrack 9 | 10 | from mlipx.abc import ComparisonResults, NodeWithCalculator, Optimizer 11 | 12 | 13 | class StructureOptimization(zntrack.Node): 14 | """Structure optimization Node. 15 | 16 | Relax the geometry for the selected `ase.Atoms`. 17 | 18 | Parameters 19 | ---------- 20 | data : list[ase.Atoms] 21 | Atoms to relax. 22 | data_id: int, default=-1 23 | The index of the ase.Atoms in `data` to optimize. 24 | optimizer : Optimizer 25 | Optimizer to use. 26 | model : NodeWithCalculator 27 | Model to use. 28 | fmax : float 29 | Maximum force to reach before stopping. 30 | steps : int 31 | Maximum number of steps for each optimization. 32 | plots : pd.DataFrame 33 | Resulting energy and fmax for each step. 34 | trajectory_path : str 35 | Output directory for the optimization trajectories. 
36 | 37 | """ 38 | 39 | data: list[ase.Atoms] = zntrack.deps() 40 | data_id: int = zntrack.params(-1) 41 | optimizer: Optimizer = zntrack.params(Optimizer.LBFGS.value) 42 | model: NodeWithCalculator = zntrack.deps() 43 | fmax: float = zntrack.params(0.05) 44 | steps: int = zntrack.params(100_000_000) 45 | plots: pd.DataFrame = zntrack.plots(y=["energy", "fmax"], x="step") 46 | 47 | frames_path: pathlib.Path = zntrack.outs_path(zntrack.nwd / "frames.traj") 48 | 49 | def run(self): 50 | optimizer = getattr(opt, self.optimizer) 51 | calc = self.model.get_calculator() 52 | 53 | atoms = self.data[self.data_id] 54 | self.frames_path.parent.mkdir(exist_ok=True) 55 | 56 | energies = [] 57 | fmax = [] 58 | 59 | def metrics_callback(): 60 | energies.append(atoms.get_potential_energy()) 61 | fmax.append(np.linalg.norm(atoms.get_forces(), axis=-1).max()) 62 | 63 | atoms.calc = calc 64 | dyn = optimizer( 65 | atoms, 66 | trajectory=self.frames_path.as_posix(), 67 | ) 68 | dyn.attach(metrics_callback) 69 | dyn.run(fmax=self.fmax, steps=self.steps) 70 | 71 | self.plots = pd.DataFrame({"energy": energies, "fmax": fmax}) 72 | self.plots.index.name = "step" 73 | 74 | @property 75 | def frames(self) -> list[ase.Atoms]: 76 | with self.state.fs.open(self.frames_path, "rb") as f: 77 | return list(ase.io.iread(f, format="traj")) 78 | 79 | @property 80 | def figures(self) -> dict[str, go.Figure]: 81 | figure = go.Figure() 82 | 83 | energies = [atoms.get_potential_energy() for atoms in self.frames] 84 | figure.add_trace( 85 | go.Scatter( 86 | x=list(range(len(energies))), 87 | y=energies, 88 | mode="lines+markers", 89 | customdata=np.stack([np.arange(len(energies))], axis=1), 90 | ) 91 | ) 92 | 93 | figure.update_layout( 94 | title="Energy vs. Steps", 95 | xaxis_title="Step", 96 | yaxis_title="Energy", 97 | ) 98 | 99 | ffigure = go.Figure() 100 | ffigure.add_trace( 101 | go.Scatter( 102 | x=self.plots.index, 103 | y=self.plots["fmax"], 104 | mode="lines+markers", 105 | customdata=np.stack([np.arange(len(energies))], axis=1), 106 | ) 107 | ) 108 | 109 | ffigure.update_layout( 110 | title="Fmax vs. Steps", 111 | xaxis_title="Step", 112 | yaxis_title="Maximum force", 113 | ) 114 | return {"energy_vs_steps": figure, "fmax_vs_steps": ffigure} 115 | 116 | @staticmethod 117 | def compare(*nodes: "StructureOptimization") -> ComparisonResults: 118 | frames = sum([node.frames for node in nodes], []) 119 | offset = 0 120 | fig = go.Figure() 121 | for idx, node in enumerate(nodes): 122 | energies = [atoms.get_potential_energy() for atoms in node.frames] 123 | fig.add_trace( 124 | go.Scatter( 125 | x=list(range(len(energies))), 126 | y=energies, 127 | mode="lines+markers", 128 | name=node.name.replace(f"_{node.__class__.__name__}", ""), 129 | customdata=np.stack([np.arange(len(energies)) + offset], axis=1), 130 | ) 131 | ) 132 | offset += len(energies) 133 | 134 | fig.update_layout( 135 | title="Energy vs. 
Steps", 136 | xaxis_title="Step", 137 | yaxis_title="Energy", 138 | plot_bgcolor="rgba(0, 0, 0, 0)", 139 | paper_bgcolor="rgba(0, 0, 0, 0)", 140 | ) 141 | fig.update_xaxes( 142 | showgrid=True, 143 | gridwidth=1, 144 | gridcolor="rgba(120, 120, 120, 0.3)", 145 | zeroline=False, 146 | ) 147 | fig.update_yaxes( 148 | showgrid=True, 149 | gridwidth=1, 150 | gridcolor="rgba(120, 120, 120, 0.3)", 151 | zeroline=False, 152 | ) 153 | 154 | # now adjusted 155 | 156 | offset = 0 157 | fig_adjusted = go.Figure() 158 | for idx, node in enumerate(nodes): 159 | energies = np.array([atoms.get_potential_energy() for atoms in node.frames]) 160 | energies -= energies[0] 161 | fig_adjusted.add_trace( 162 | go.Scatter( 163 | x=list(range(len(energies))), 164 | y=energies, 165 | mode="lines+markers", 166 | name=node.name.replace(f"_{node.__class__.__name__}", ""), 167 | customdata=np.stack([np.arange(len(energies)) + offset], axis=1), 168 | ) 169 | ) 170 | offset += len(energies) 171 | 172 | fig_adjusted.update_layout( 173 | title="Adjusted energy vs. Steps", 174 | xaxis_title="Step", 175 | yaxis_title="Adjusted energy", 176 | plot_bgcolor="rgba(0, 0, 0, 0)", 177 | paper_bgcolor="rgba(0, 0, 0, 0)", 178 | ) 179 | fig_adjusted.update_xaxes( 180 | showgrid=True, 181 | gridwidth=1, 182 | gridcolor="rgba(120, 120, 120, 0.3)", 183 | zeroline=False, 184 | ) 185 | fig_adjusted.update_yaxes( 186 | showgrid=True, 187 | gridwidth=1, 188 | gridcolor="rgba(120, 120, 120, 0.3)", 189 | zeroline=False, 190 | ) 191 | 192 | return ComparisonResults( 193 | frames=frames, 194 | figures={"energy_vs_steps": fig, "adjusted_energy_vs_steps": fig_adjusted}, 195 | ) 196 | -------------------------------------------------------------------------------- /mlipx/nodes/updated_frames.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | 3 | import ase 4 | from ase.calculators.calculator import Calculator, all_changes 5 | 6 | 7 | class _UpdateFramesCalc(Calculator): 8 | implemented_properties = ["energy", "forces"] 9 | 10 | def __init__( 11 | self, results_mapping: dict, info_mapping: dict, arrays_mapping: dict, **kwargs 12 | ): 13 | Calculator.__init__(self, **kwargs) 14 | self.results_mapping = results_mapping 15 | self.info_mapping = info_mapping 16 | self.arrays_mapping = arrays_mapping 17 | 18 | def calculate( 19 | self, 20 | atoms=ase.Atoms, 21 | properties=None, 22 | system_changes=all_changes, 23 | ): 24 | if properties is None: 25 | properties = self.implemented_properties 26 | Calculator.calculate(self, atoms, properties, system_changes) 27 | for target, key in self.results_mapping.items(): 28 | if key is None: 29 | continue 30 | try: 31 | value = atoms.info[key] 32 | except KeyError: 33 | value = atoms.arrays[key] 34 | self.results[target] = value 35 | 36 | for target, key in self.info_mapping.items(): 37 | # rename the key to target 38 | atoms.info[target] = atoms.info[key] 39 | del atoms.info[key] 40 | 41 | for target, key in self.arrays_mapping.items(): 42 | # rename the key to target 43 | atoms.arrays[target] = atoms.arrays[key] 44 | del atoms.arrays[key] 45 | 46 | 47 | # TODO: what if the energy is in the single point calculator but the forces are not? 
48 | @dataclasses.dataclass 49 | class UpdateFramesCalc: 50 | results_mapping: dict[str, str] = dataclasses.field(default_factory=dict) 51 | info_mapping: dict[str, str] = dataclasses.field(default_factory=dict) 52 | arrays_mapping: dict[str, str] = dataclasses.field(default_factory=dict) 53 | 54 | def get_calculator(self, **kwargs) -> _UpdateFramesCalc: 55 | return _UpdateFramesCalc( 56 | results_mapping=self.results_mapping, 57 | info_mapping=self.info_mapping, 58 | arrays_mapping=self.arrays_mapping, 59 | ) 60 | -------------------------------------------------------------------------------- /mlipx/project.py: -------------------------------------------------------------------------------- 1 | from zntrack import Project 2 | 3 | __all__ = ["Project"] 4 | -------------------------------------------------------------------------------- /mlipx/recipes/README.md: -------------------------------------------------------------------------------- 1 | # Jinja2 Templating 2 | 3 | We use `recipe.py.jinja2` templates to generate the `main.py` and `models.py` files from the CLI. 4 | For new Nodes, once you have added them to a module in `mlipx/nodes/` and updated `mlipx/__init__.pyi`, you might want to create a new template and update the CLI in `main.py`. 5 | 6 | If you want to introduce a new model, you might want to adapt `models.py.jinja2` and `main.py` as well. 7 | -------------------------------------------------------------------------------- /mlipx/recipes/__init__.py: -------------------------------------------------------------------------------- 1 | from mlipx.recipes.main import app 2 | 3 | __all__ = ["app"] 4 | -------------------------------------------------------------------------------- /mlipx/recipes/adsorption.py.jinja2: -------------------------------------------------------------------------------- 1 | import mlipx 2 | import zntrack 3 | 4 | from models import MODELS 5 | 6 | project = zntrack.Project() 7 | 8 | slabs = [] 9 | {% if slab_config %} 10 | with project.group("initialize"): 11 | slabs.append(mlipx.BuildASEslab(**{{ slab_config }}).frames) 12 | {% endif %} 13 | 14 | adsorbates = [] 15 | {% if smiles %} 16 | with project.group("initialize"): 17 | for smiles in {{ smiles }}: 18 | adsorbates.append(mlipx.Smiles2Conformers(smiles=smiles, num_confs=1).frames) 19 | {% endif %} 20 | 21 | for model_name, model in MODELS.items(): 22 | for idx, slab in enumerate(slabs): 23 | for jdx, adsorbate in enumerate(adsorbates): 24 | with project.group(model_name, str(idx)): 25 | _ = mlipx.RelaxAdsorptionConfigs( 26 | slabs=slab, 27 | adsorbates=adsorbate, 28 | model=model, 29 | ) 30 | 31 | project.build() 32 | -------------------------------------------------------------------------------- /mlipx/recipes/energy_volume.py.jinja2: -------------------------------------------------------------------------------- 1 | import mlipx 2 | import zntrack 3 | 4 | from models import MODELS 5 | 6 | project = zntrack.Project() 7 | 8 | frames = [] 9 | {% if datapath %} 10 | with project.group("initialize"): 11 | for path in {{ datapath }}: 12 | frames.append(mlipx.LoadDataFile(path=path)) 13 | {% endif %}{% if material_ids %} 14 | with project.group("initialize"): 15 | for material_id in {{ material_ids }}: 16 | frames.append(mlipx.MPRester(search_kwargs={"material_ids": [material_id]})) 17 | {% endif %}{% if smiles %} 18 | with project.group("initialize"): 19 | for smiles in {{ smiles }}: 20 | frames.append(mlipx.Smiles2Conformers(smiles=smiles, num_confs=1)) 21 | {% endif %} 22 | 23 | 24 | for model_name,
model in MODELS.items(): 25 | for idx, data in enumerate(frames): 26 | with project.group(model_name, str(idx)): 27 | neb = mlipx.EnergyVolumeCurve( 28 | model=model, 29 | data=data.frames, 30 | n_points=50, 31 | start=0.8, 32 | stop=2.0, 33 | ) 34 | 35 | project.build() 36 | -------------------------------------------------------------------------------- /mlipx/recipes/homonuclear_diatomics.py.jinja2: -------------------------------------------------------------------------------- 1 | import mlipx 2 | import zntrack 3 | 4 | from models import MODELS 5 | 6 | project = zntrack.Project() 7 | 8 | frames = [] 9 | {% if datapath %} 10 | with project.group("initialize"): 11 | for path in {{ datapath }}: 12 | frames.append(mlipx.LoadDataFile(path=path).frames) 13 | {% endif %}{% if material_ids %} 14 | with project.group("initialize"): 15 | for material_id in {{ material_ids }}: 16 | frames.append(mlipx.MPRester(search_kwargs={"material_ids": [material_id]}).frames) 17 | {% endif %}{% if smiles %} 18 | with project.group("initialize"): 19 | for smiles in {{ smiles }}: 20 | frames.append(mlipx.Smiles2Conformers(smiles=smiles, num_confs=1).frames) 21 | {% endif %} 22 | 23 | for model_name, model in MODELS.items(): 24 | with project.group(model_name): 25 | neb = mlipx.HomonuclearDiatomics( 26 | elements=[], 27 | data=sum(frames, []), # Use all elements from all frames 28 | model=model, 29 | n_points=100, 30 | min_distance=0.5, 31 | max_distance=2.0, 32 | ) 33 | 34 | project.build() 35 | -------------------------------------------------------------------------------- /mlipx/recipes/invariances.py.jinja2: -------------------------------------------------------------------------------- 1 | import mlipx 2 | import zntrack 3 | 4 | from models import MODELS 5 | 6 | project = zntrack.Project() 7 | 8 | frames = [] 9 | {% if datapath %} 10 | with project.group("initialize"): 11 | for path in {{ datapath }}: 12 | frames.append(mlipx.LoadDataFile(path=path)) 13 | {% endif %}{% if material_ids %} 14 | with project.group("initialize"): 15 | for material_id in {{ material_ids }}: 16 | frames.append(mlipx.MPRester(search_kwargs={"material_ids": [material_id]})) 17 | {% endif %}{% if smiles %} 18 | with project.group("initialize"): 19 | for smiles in {{ smiles }}: 20 | frames.append(mlipx.Smiles2Conformers(smiles=smiles, num_confs=1)) 21 | {% endif %} 22 | 23 | for model_name, model in MODELS.items(): 24 | for idx, data in enumerate(frames): 25 | with project.group(model_name, str(idx)): 26 | rot = mlipx.RotationalInvariance( 27 | model=model, 28 | n_points=100, 29 | data=data.frames, 30 | ) 31 | trans = mlipx.TranslationalInvariance( 32 | model=model, 33 | n_points=100, 34 | data=data.frames, 35 | ) 36 | perm = mlipx.PermutationInvariance( 37 | model=model, 38 | n_points=100, 39 | data=data.frames, 40 | ) 41 | 42 | project.build() 43 | -------------------------------------------------------------------------------- /mlipx/recipes/md.py.jinja2: -------------------------------------------------------------------------------- 1 | import mlipx 2 | import zntrack 3 | 4 | from models import MODELS 5 | 6 | project = zntrack.Project() 7 | 8 | frames = [] 9 | {% if datapath %} 10 | with project.group("initialize"): 11 | for path in {{ datapath }}: 12 | frames.append(mlipx.LoadDataFile(path=path)) 13 | {% endif %}{% if material_ids %} 14 | with project.group("initialize"): 15 | for material_id in {{ material_ids }}: 16 | frames.append(mlipx.MPRester(search_kwargs={"material_ids": [material_id]})) 17 | {% endif %}{% if 
smiles %} 18 | with project.group("initialize"): 19 | for smiles in {{ smiles }}: 20 | frames.append(mlipx.Smiles2Conformers(smiles=smiles, num_confs=1)) 21 | {% endif %} 22 | 23 | thermostat = mlipx.LangevinConfig(timestep=0.5, temperature=300, friction=0.05) 24 | force_check = mlipx.MaximumForceObserver(f_max=100) 25 | t_ramp = mlipx.TemperatureRampModifier(end_temperature=400, total_steps=100) 26 | 27 | 28 | for model_name, model in MODELS.items(): 29 | for idx, data in enumerate(frames): 30 | with project.group(model_name, str(idx)): 31 | neb = mlipx.MolecularDynamics( 32 | model=model, 33 | thermostat=thermostat, 34 | data=data.frames, 35 | observers=[force_check], 36 | modifiers=[t_ramp], 37 | steps=1000, 38 | ) 39 | 40 | project.build() 41 | -------------------------------------------------------------------------------- /mlipx/recipes/metrics.py: -------------------------------------------------------------------------------- 1 | import zntrack 2 | from models import MODELS 3 | 4 | try: 5 | from models import REFERENCE 6 | except ImportError: 7 | REFERENCE = None 8 | 9 | import mlipx 10 | 11 | DATAPATH = "{{ datapath }}" 12 | ISOLATED_ATOM_ENERGIES = {{isolated_atom_energies}} # noqa F821 13 | 14 | 15 | project = zntrack.Project() 16 | 17 | with project.group("initialize"): 18 | data = mlipx.LoadDataFile(path=DATAPATH) 19 | 20 | 21 | with project.group("reference"): 22 | if REFERENCE is not None: 23 | data = mlipx.ApplyCalculator(data=data.frames, model=REFERENCE) 24 | ref_evaluation = mlipx.EvaluateCalculatorResults(data=data.frames) 25 | if ISOLATED_ATOM_ENERGIES: 26 | ref_isolated = mlipx.CalculateFormationEnergy(data=data.frames) 27 | 28 | for model_name, model in MODELS.items(): 29 | with project.group(model_name): 30 | updated_data = mlipx.ApplyCalculator(data=data.frames, model=model) 31 | evaluation = mlipx.EvaluateCalculatorResults(data=updated_data.frames) 32 | mlipx.CompareCalculatorResults(data=evaluation, reference=ref_evaluation) 33 | 34 | if ISOLATED_ATOM_ENERGIES: 35 | isolated = mlipx.CalculateFormationEnergy( 36 | data=updated_data.frames, model=model 37 | ) 38 | mlipx.CompareFormationEnergy(data=isolated, reference=ref_isolated) 39 | 40 | project.build() 41 | -------------------------------------------------------------------------------- /mlipx/recipes/models.py.jinja2: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | 3 | import mlipx 4 | from mlipx.nodes.generic_ase import Device 5 | 6 | ALL_MODELS = {} 7 | 8 | # https://github.com/ACEsuit/mace 9 | ALL_MODELS["mace-mpa-0"] = mlipx.GenericASECalculator( 10 | module="mace.calculators", 11 | class_name="mace_mp", 12 | device="auto", 13 | kwargs={"model": "../../models/mace-mpa-0-medium.model"} 14 | # MLIPX-hub model path, adjust as needed 15 | ) 16 | # https://github.com/MDIL-SNU/SevenNet 17 | ALL_MODELS["7net-0"] = mlipx.GenericASECalculator( 18 | module="sevenn.sevennet_calculator", 19 | class_name="SevenNetCalculator", 20 | device="auto", 21 | kwargs={"model": "7net-0"} 22 | ) 23 | ALL_MODELS["7net-mf-ompa-mpa"] = mlipx.GenericASECalculator( 24 | module="sevenn.sevennet_calculator", 25 | class_name="SevenNetCalculator", 26 | device="auto", 27 | kwargs={"model": "7net-mf-ompa", "modal": "mpa"} 28 | ) 29 | 30 | # https://github.com/orbital-materials/orb-models 31 | @dataclasses.dataclass 32 | class OrbCalc: 33 | name: str 34 | device: Device | None = None 35 | kwargs: dict = dataclasses.field(default_factory=dict) 36 | 37 | def get_calculator(self, 
**kwargs): 38 | from orb_models.forcefield import pretrained 39 | from orb_models.forcefield.calculator import ORBCalculator 40 | 41 | method = getattr(pretrained, self.name) 42 | if self.device is None: 43 | orbff = method(**self.kwargs) 44 | calc = ORBCalculator(orbff, **self.kwargs) 45 | elif self.device == Device.AUTO: 46 | orbff = method(device=Device.resolve_auto(), **self.kwargs) 47 | calc = ORBCalculator(orbff, device=Device.resolve_auto(), **self.kwargs) 48 | else: 49 | orbff = method(device=self.device, **self.kwargs) 50 | calc = ORBCalculator(orbff, device=self.device, **self.kwargs) 51 | return calc 52 | 53 | @property 54 | def available(self) -> bool: 55 | try: 56 | from orb_models.forcefield import pretrained 57 | from orb_models.forcefield.calculator import ORBCalculator 58 | return True 59 | except ImportError: 60 | return False 61 | 62 | ALL_MODELS["orb-v2"] = OrbCalc( 63 | name="orb_v2", 64 | device="auto" 65 | ) 66 | ALL_MODELS["orb-v3"] = OrbCalc( 67 | name="orb_v3_conservative_inf_omat", 68 | device="auto" 69 | ) 70 | 71 | # https://github.com/CederGroupHub/chgnet 72 | ALL_MODELS["chgnet"] = mlipx.GenericASECalculator( 73 | module="chgnet.model", 74 | class_name="CHGNetCalculator", 75 | ) 76 | # https://github.com/microsoft/mattersim 77 | ALL_MODELS["mattersim"] = mlipx.GenericASECalculator( 78 | module="mattersim.forcefield", 79 | class_name="MatterSimCalculator", 80 | device="auto", 81 | ) 82 | # https://www.faccts.de/orca/ 83 | ALL_MODELS["orca"] = mlipx.OrcaSinglePoint( 84 | orcasimpleinput= "PBE def2-TZVP TightSCF EnGrad", 85 | orcablocks ="%pal nprocs 8 end", 86 | orca_shell="{{ orcashell }}", 87 | ) 88 | 89 | # https://gracemaker.readthedocs.io/en/latest/gracemaker/foundation/ 90 | ALL_MODELS["grace-2l-omat"] = mlipx.GenericASECalculator( 91 | module="tensorpotential.calculator", 92 | class_name="TPCalculator", 93 | device=None, 94 | kwargs={ 95 | "model": "../../models/GRACE-2L-OMAT", 96 | }, 97 | # MLIPX-hub model path, adjust as needed 98 | ) 99 | 100 | # OPTIONAL 101 | # ======== 102 | # If you have custom property names you can use the UpdatedFramesCalc 103 | # to set the energy, force and isolated_energies keys mlipx expects. 104 | 105 | # REFERENCE = mlipx.UpdateFramesCalc( 106 | # results_mapping={"energy": "DFT_ENERGY", "forces": "DFT_FORCES"}, 107 | # info_mapping={mlipx.abc.ASEKeys.isolated_energies.value: "isol_ene"}, 108 | # ) 109 | 110 | # ============================================================ 111 | # THE SELECTED MODELS! 
112 | # ONLY THESE MODELS WILL BE USED IN THE RECIPE 113 | # ============================================================ 114 | MODELS = { 115 | {%- for model in models %} 116 | "{{ model }}": ALL_MODELS["{{ model }}"], 117 | {%- endfor %} 118 | } 119 | -------------------------------------------------------------------------------- /mlipx/recipes/neb.py: -------------------------------------------------------------------------------- 1 | import zntrack 2 | from models import MODELS 3 | 4 | import mlipx 5 | 6 | DATAPATH = "{{ datapath }}" 7 | 8 | project = zntrack.Project() 9 | 10 | with project.group("initialize"): 11 | data = mlipx.LoadDataFile(path=DATAPATH) 12 | trajectory = mlipx.NEBinterpolate(data=data.frames, n_images=5, mic=True) 13 | 14 | for model_name, model in MODELS.items(): 15 | with project.group(model_name): 16 | neb = mlipx.NEBs( 17 | data=trajectory.frames, 18 | model=model, 19 | relax=True, 20 | fmax=0.05, 21 | ) 22 | 23 | project.build() 24 | -------------------------------------------------------------------------------- /mlipx/recipes/phase_diagram.py.jinja2: -------------------------------------------------------------------------------- 1 | import mlipx 2 | import zntrack 3 | 4 | from models import MODELS 5 | 6 | project = zntrack.Project() 7 | 8 | frames = [] 9 | {% if datapath %} 10 | with project.group("initialize"): 11 | for path in {{ datapath }}: 12 | frames.append(mlipx.LoadDataFile(path=path)) 13 | {% endif %}{% if material_ids %} 14 | with project.group("initialize"): 15 | for material_id in {{ material_ids }}: 16 | frames.append(mlipx.MPRester(search_kwargs={"material_ids": [material_id]})) 17 | {% endif %}{% if smiles %} 18 | with project.group("initialize"): 19 | for smiles in {{ smiles }}: 20 | frames.append(mlipx.Smiles2Conformers(smiles=smiles, num_confs=1)) 21 | {% endif %} 22 | 23 | for model_name, model in MODELS.items(): 24 | for idx, data in enumerate(frames): 25 | with project.group(model_name, str(idx)): 26 | pd = mlipx.PhaseDiagram(data=data.frames, model=model) 27 | 28 | 29 | project.build() 30 | -------------------------------------------------------------------------------- /mlipx/recipes/pourbaix_diagram.py.jinja2: -------------------------------------------------------------------------------- 1 | import mlipx 2 | import zntrack 3 | 4 | from models import MODELS 5 | 6 | project = zntrack.Project() 7 | 8 | frames = [] 9 | {% if datapath %} 10 | with project.group("initialize"): 11 | for path in {{ datapath }}: 12 | frames.append(mlipx.LoadDataFile(path=path)) 13 | {% endif %}{% if material_ids %} 14 | with project.group("initialize"): 15 | for material_id in {{ material_ids }}: 16 | frames.append(mlipx.MPRester(search_kwargs={"material_ids": [material_id]})) 17 | {% endif %}{% if smiles %} 18 | with project.group("initialize"): 19 | for smiles in {{ smiles }}: 20 | frames.append(mlipx.Smiles2Conformers(smiles=smiles, num_confs=1)) 21 | {% endif %} 22 | 23 | for model_name, model in MODELS.items(): 24 | for idx, data in enumerate(frames): 25 | with project.group(model_name, str(idx)): 26 | pd = mlipx.PourbaixDiagram(data=data.frames, model=model, pH=1.0, V=1.8) 27 | 28 | project.build() 29 | -------------------------------------------------------------------------------- /mlipx/recipes/relax.py.jinja2: -------------------------------------------------------------------------------- 1 | import mlipx 2 | import zntrack 3 | 4 | from models import MODELS 5 | 6 | project = zntrack.Project() 7 | 8 | frames = [] 9 | {% if datapath %} 10 | with 
project.group("initialize"): 11 | for path in {{ datapath }}: 12 | frames.append(mlipx.LoadDataFile(path=path)) 13 | {% endif %}{% if material_ids %} 14 | with project.group("initialize"): 15 | for material_id in {{ material_ids }}: 16 | data = mlipx.MPRester(search_kwargs={"material_ids": [material_id]}) 17 | frames.append(mlipx.Rattle(data=data.frames, stdev=0.1)) 18 | {% endif %}{% if smiles %} 19 | with project.group("initialize"): 20 | for smiles in {{ smiles }}: 21 | frames.append(mlipx.Smiles2Conformers(smiles=smiles, num_confs=1)) 22 | {% endif %} 23 | 24 | 25 | for model_name, model in MODELS.items(): 26 | for idx, data in enumerate(frames): 27 | with project.group(model_name, str(idx)): 28 | geom_opt = mlipx.StructureOptimization(data=data.frames, model=model, fmax=0.1) 29 | 30 | project.build() 31 | -------------------------------------------------------------------------------- /mlipx/recipes/vibrational_analysis.py.jinja2: -------------------------------------------------------------------------------- 1 | import mlipx 2 | import zntrack 3 | 4 | from models import MODELS 5 | 6 | project = zntrack.Project() 7 | 8 | frames = [] 9 | {% if datapath %} 10 | with project.group("initialize"): 11 | for path in {{ datapath }}: 12 | frames.append(mlipx.LoadDataFile(path=path)) 13 | {% endif %}{% if material_ids %} 14 | with project.group("initialize"): 15 | for material_id in {{ material_ids }}: 16 | frames.append(mlipx.MPRester(search_kwargs={"material_ids": [material_id]})) 17 | {% endif %}{% if smiles %} 18 | with project.group("initialize"): 19 | for smiles in {{ smiles }}: 20 | frames.append(mlipx.Smiles2Conformers(smiles=smiles, num_confs=1)) 21 | {% endif %} 22 | 23 | for model_name, model in MODELS.items(): 24 | with project.group(model_name): 25 | phon = mlipx.VibrationalAnalysis( 26 | data=sum([x.frames for x in frames], []), 27 | model=model, 28 | temperature=298.15, 29 | displacement=0.015, 30 | nfree=4, 31 | lower_freq_threshold=12, 32 | ) 33 | 34 | 35 | project.build() 36 | -------------------------------------------------------------------------------- /mlipx/utils.py: -------------------------------------------------------------------------------- 1 | import ase 2 | import numpy as np 3 | from ase.calculators.singlepoint import SinglePointCalculator 4 | 5 | 6 | def freeze_copy_atoms(atoms: ase.Atoms) -> ase.Atoms: 7 | atoms_copy = atoms.copy() 8 | if atoms.calc is not None: 9 | atoms_copy.calc = SinglePointCalculator(atoms_copy) 10 | atoms_copy.calc.results = atoms.calc.results 11 | return atoms_copy 12 | 13 | 14 | def shallow_copy_atoms(atoms: ase.Atoms) -> ase.Atoms: 15 | """Create a shallow copy of an ASE atoms object.""" 16 | atoms_copy = ase.Atoms( 17 | positions=atoms.positions, 18 | numbers=atoms.numbers, 19 | cell=atoms.cell, 20 | pbc=atoms.pbc, 21 | ) 22 | return atoms_copy 23 | 24 | 25 | def rmse(y_true, y_pred): 26 | return np.sqrt(np.mean((y_true - y_pred) ** 2)) 27 | -------------------------------------------------------------------------------- /mlipx/version.py: -------------------------------------------------------------------------------- 1 | import importlib.metadata 2 | 3 | __version__ = importlib.metadata.version("mlipx") 4 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "mlipx" 3 | version = "0.1.4" 4 | description = "Machine-Learned Interatomic Potential eXploration" 5 | authors = [ 6 | { 
name = "Sandip De", email = "sandip.de@basf.com" }, 7 | { name = "Fabian Zills", email = "fzills@icp.uni-stuttgart.de" }, 8 | { name = "Sheena Agarwal", email = "sheena.a.agarwal@basf.com" } 9 | ] 10 | 11 | readme = "README.md" 12 | license = "MIT" 13 | requires-python = ">=3.10" 14 | keywords=["data-version-control", "machine-learning", "reproducibility", "collaboration", "machine-learned interatomic potential", "mlip", "mlff"] 15 | 16 | dependencies = [ 17 | "ase>=3.24.0", 18 | "lazy-loader>=0.4", 19 | "mp-api>=0.45.3", 20 | "plotly>=6.0.0", 21 | "rdkit2ase>=0.1.4", 22 | "typer>=0.15.1", 23 | "zndraw>=0.5.10", 24 | "znh5md>=0.4.4", 25 | "zntrack>=0.8.5", 26 | "dvc-s3>=3.2.0", 27 | "mpcontribs-client>=5.10.2", 28 | ] 29 | 30 | [dependency-groups] 31 | docs = [ 32 | "furo>=2024.8.6", 33 | "jupyter-sphinx>=0.5.3", 34 | "nbsphinx>=0.9.6", # https://github.com/sphinx-doc/sphinx/issues/13352 35 | "sphinx>=8.1.3,!=8.2.0", 36 | "sphinx-copybutton>=0.5.2", 37 | "sphinx-design>=0.6.1", 38 | "sphinx-hoverxref>=1.4.2", 39 | "sphinx-mdinclude>=0.6.2", 40 | "sphinxcontrib-bibtex>=2.6.3", 41 | "sphinxcontrib-mermaid>=1.0.0", 42 | ] 43 | dev = [ 44 | "ipykernel>=6.29.5", 45 | "pre-commit>=4.2.0", 46 | "pytest>=8.3.4", 47 | "pytest-cov>=6.0.0", 48 | "ruff>=0.9.6", 49 | ] 50 | 51 | 52 | [project.scripts] 53 | mlipx = "mlipx.cli.main:app" 54 | 55 | 56 | [tool.uv] 57 | conflicts = [ 58 | [ 59 | { extra = "mace" }, 60 | { extra = "mattersim" }, 61 | ], 62 | [ 63 | { extra = "mace" }, 64 | { extra = "sevenn" }, 65 | ] 66 | ] 67 | 68 | [project.optional-dependencies] 69 | chgnet = [ 70 | "chgnet>=0.4.0", 71 | ] 72 | mace = [ 73 | "mace-torch>=0.3.12", 74 | ] 75 | sevenn = [ 76 | "sevenn>=0.11.0", 77 | ] 78 | orb = [ 79 | "orb-models>=0.5.0" 80 | ] 81 | mattersim = [ 82 | "mattersim>=1.1.2", 83 | ] 84 | grace = [ 85 | "tensorpotential>=0.5.1", 86 | ] 87 | 88 | 89 | [build-system] 90 | requires = ["hatchling"] 91 | build-backend = "hatchling.build" 92 | 93 | [tool.codespell] 94 | ignore-words-list = "basf" 95 | skip = "*.svg,*.lock,*.json" 96 | 97 | [tool.ruff.lint] 98 | select = ["I", "F", "E", "C", "W"] 99 | --------------------------------------------------------------------------------