├── .github
│   └── workflows
│       ├── pre-comit.yaml
│       └── publish.yaml
├── .gitignore
├── .gitmodules
├── .pre-commit-config.yaml
├── .python-version
├── .readthedocs.yaml
├── .vscode
│   └── settings.json
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── docs
│   ├── Makefile
│   └── source
│       ├── _static
│       │   ├── .gitignore
│       │   ├── mlflow_compare.png
│       │   ├── mlipx-dark.svg
│       │   ├── mlipx-favicon.svg
│       │   ├── mlipx-light.svg
│       │   ├── quickstart_zndraw.png
│       │   ├── zndraw_compare.png
│       │   └── zndraw_render.png
│       ├── abc.rst
│       ├── authors.rst
│       ├── build_graph.rst
│       ├── concept.rst
│       ├── concept
│       │   ├── data.rst
│       │   ├── distributed.rst
│       │   ├── metrics.rst
│       │   ├── models.rst
│       │   ├── recipes.rst
│       │   ├── zndraw.rst
│       │   └── zntrack.rst
│       ├── conf.py
│       ├── glossary.rst
│       ├── index.rst
│       ├── installation.rst
│       ├── nodes.rst
│       ├── notebooks
│       │   ├── combine.ipynb
│       │   └── structure_relaxation.ipynb
│       ├── quickstart.rst
│       ├── quickstart
│       │   ├── cli.rst
│       │   └── python.rst
│       ├── recipes.rst
│       ├── recipes
│       │   ├── adsorption.rst
│       │   ├── energy_and_forces.rst
│       │   ├── energy_volume.rst
│       │   ├── homonuclear_diatomics.rst
│       │   ├── invariances.rst
│       │   ├── md.rst
│       │   ├── neb.rst
│       │   ├── phase_diagram.rst
│       │   ├── pourbaix_diagram.rst
│       │   ├── relax.rst
│       │   └── vibrational_analysis.rst
│       └── references.bib
├── mlipx
│   ├── __init__.py
│   ├── __init__.pyi
│   ├── abc.py
│   ├── benchmark
│   │   ├── __init__.py
│   │   ├── elements.py
│   │   ├── file.py
│   │   └── main.py
│   ├── cli
│   │   ├── __init__.py
│   │   └── main.py
│   ├── doc_utils.py
│   ├── models.py
│   ├── nodes
│   │   ├── __init__.py
│   │   ├── adsorption.py
│   │   ├── apply_calculator.py
│   │   ├── autowte.py
│   │   ├── compare_calculator.py
│   │   ├── diatomics.py
│   │   ├── energy_volume.py
│   │   ├── evaluate_calculator.py
│   │   ├── filter_dataset.py
│   │   ├── formation_energy.py
│   │   ├── generic_ase.py
│   │   ├── invariances.py
│   │   ├── io.py
│   │   ├── modifier.py
│   │   ├── molecular_dynamics.py
│   │   ├── mp_api.py
│   │   ├── nebs.py
│   │   ├── observer.py
│   │   ├── orca.py
│   │   ├── phase_diagram.py
│   │   ├── pourbaix_diagram.py
│   │   ├── rattle.py
│   │   ├── smiles.py
│   │   ├── structure_optimization.py
│   │   ├── updated_frames.py
│   │   └── vibrational_analysis.py
│   ├── project.py
│   ├── recipes
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── adsorption.py.jinja2
│   │   ├── energy_volume.py.jinja2
│   │   ├── homonuclear_diatomics.py.jinja2
│   │   ├── invariances.py.jinja2
│   │   ├── main.py
│   │   ├── md.py.jinja2
│   │   ├── metrics.py
│   │   ├── models.py.jinja2
│   │   ├── neb.py
│   │   ├── phase_diagram.py.jinja2
│   │   ├── pourbaix_diagram.py.jinja2
│   │   ├── relax.py.jinja2
│   │   └── vibrational_analysis.py.jinja2
│   ├── utils.py
│   └── version.py
├── pyproject.toml
└── uv.lock
/.github/workflows/pre-comit.yaml:
--------------------------------------------------------------------------------
1 | name: pre-commit
2 |
3 | on:
4 | pull_request:
5 | push:
6 | branches: [main]
7 |
8 | jobs:
9 | pre-commit:
10 | runs-on: ubuntu-latest
11 | steps:
12 | - uses: actions/checkout@v3
13 | - uses: actions/setup-python@v3
14 | - uses: pre-commit/action@v3.0.1
15 |
--------------------------------------------------------------------------------
/.github/workflows/publish.yaml:
--------------------------------------------------------------------------------
1 | name: Release
2 | on:
3 | release:
4 | types:
5 | - created
6 |
7 | jobs:
8 | publish-pypi:
9 | runs-on: ubuntu-latest
10 | steps:
11 | - uses: actions/checkout@v4
12 | - name: Install uv
13 | uses: astral-sh/setup-uv@v5
14 | - name: Publish
15 | env:
16 | PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }}
17 | run: |
18 | uv build
19 | uv publish --token $PYPI_TOKEN
20 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 | tmp/
164 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "mlipx-hub"]
2 | path = mlipx-hub
3 | url = https://github.com/basf/mlipx-hub
4 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | default_language_version:
2 | python: python3
3 | repos:
4 | - repo: https://github.com/pre-commit/pre-commit-hooks
5 | rev: v5.0.0
6 | hooks:
7 | - id: check-added-large-files
8 | exclude: ^.*uv\.lock$
9 | - id: check-case-conflict
10 | - id: check-docstring-first
11 | - id: check-executables-have-shebangs
12 | - id: check-json
13 | - id: check-merge-conflict
14 | args: ['--assume-in-merge']
15 | exclude: ^(docs/)
16 | - id: check-toml
17 | - id: check-yaml
18 | - id: debug-statements
19 | - id: end-of-file-fixer
20 | exclude: .*\.json$
21 | - id: mixed-line-ending
22 | args: ['--fix=lf']
23 | - id: sort-simple-yaml
24 | - id: trailing-whitespace
25 | - repo: https://github.com/codespell-project/codespell
26 | rev: v2.4.1
27 | hooks:
28 | - id: codespell
29 | additional_dependencies: ["tomli"]
30 | - repo: https://github.com/astral-sh/ruff-pre-commit
31 | # Ruff version.
32 | rev: v0.11.5
33 | hooks:
34 | # Run the linter.
35 | - id: ruff
36 | args: [ --fix ]
37 | # Run the formatter.
38 | - id: ruff-format
39 | - repo: https://github.com/executablebooks/mdformat
40 | rev: 0.7.22
41 | hooks:
42 | - id: mdformat
43 |
--------------------------------------------------------------------------------
/.python-version:
--------------------------------------------------------------------------------
1 | 3.11
2 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yaml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | version: 2
6 |
7 | submodules:
8 | include: all
9 |
10 | # Set the version of Python and other tools you might need
11 | build:
12 | os: ubuntu-lts-latest
13 | tools:
14 | python: "3.11"
15 | jobs:
16 | post_install:
17 | # see https://github.com/astral-sh/uv/issues/10074
18 | - pip install uv
19 | - UV_PROJECT_ENVIRONMENT=$READTHEDOCS_VIRTUALENV_PATH uv sync --link-mode=copy --group=docs
20 | # Navigate to the examples directory and perform DVC operations
21 | - cd mlipx-hub && dvc pull --allow-missing
22 |
23 | # Build documentation in the docs/ directory with Sphinx
24 | sphinx:
25 | configuration: docs/source/conf.py
26 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "python.analysis.typeCheckingMode": "basic",
3 | "[restructuredtext]": {
4 | "editor.wordWrap": "on"
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # MLIPX Contribution guidelines
2 |
3 | ## Adding a new MLIP
4 |
5 | 1. Create a new entry in `[project.optional-dependencies]` in the `pyproject.toml`. Configure `[tool.uv]:conflicts` if necessary.
6 | 1. Add your model to the `ALL_MODELS` dictionary in `mlipx/recipes/models.py.jinja2`.
7 |
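8 | For step 1, an optional-dependency entry could look like the following sketch (`my-mlip` and its version pin are placeholders, not an actual supported model):
9 |
10 | ```toml
11 | [project.optional-dependencies]
12 | my-mlip = ["my-mlip-package>=0.1"]
13 | ```
14 |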
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 BASF
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | 
4 | 
5 |
6 | [](https://badge.fury.io/py/mlipx)
7 | [](https://zntrack.readthedocs.io/en/latest/)
8 | [](https://github.com/zincware/zndraw)
9 | [](https://github.com/basf/mlipx/issues)
10 | [](https://mlipx.readthedocs.io/en/latest/?badge=latest)
11 |
12 | [📘Documentation](https://mlipx.readthedocs.io) |
13 | [🛠️Installation](https://mlipx.readthedocs.io/en/latest/installation.html) |
14 | [📜Recipes](https://mlipx.readthedocs.io/en/latest/recipes.html) |
15 | [🚀Quickstart](https://mlipx.readthedocs.io/en/latest/quickstart.html)
16 |
17 |
18 |
19 |
20 | Machine-Learned Interatomic Potential eXploration
21 |
22 |
23 | `mlipx` is a Python library designed for evaluating machine-learned interatomic
24 | potentials (MLIPs). It offers a growing set of evaluation methods alongside
25 | powerful visualization and comparison tools.
26 |
27 | The goal of `mlipx` is to provide a common platform for MLIP evaluation and to
28 | facilitate sharing results among researchers. This allows you to determine the
29 | applicability of a specific MLIP to your research and compare it against others.
30 |
31 | ## Installation
32 |
33 | Install `mlipx` via pip:
34 |
35 | ```bash
36 | pip install mlipx
37 | ```
38 |
39 | > [!NOTE]
40 | > The `mlipx` package does not include the installation of any MLIP code, as we aim to keep the package as lightweight as possible.
41 | > If you encounter any `ImportError`, you may need to install the additional dependencies manually.
42 |
43 | ## Quickstart
44 |
45 | This section provides a brief overview of the core features of `mlipx`. For more detailed instructions, visit the [documentation](https://mlipx.readthedocs.io).
46 |
47 | Most recipes support different input formats, such as data file paths, `SMILES` strings, or Materials Project structure IDs.
48 |
49 | > [!NOTE]
50 | > Because `mlipx` uses Git and [DVC](https://dvc.org/doc), you need to create a new project directory to run your experiments in. Here's how to set up your project:
51 | >
52 | > ```bash
53 | > mkdir exp
54 | > cd exp
55 | > git init && dvc init
56 | > ```
57 | >
58 | > If you want to use data files, it is recommended to track them with `dvc add ` instead of `git add `.
59 | >
60 | > ```bash
61 | > cp /your/data/file.xyz .
62 | > dvc add file.xyz
63 | > ```
64 |
65 | ### Energy-Volume Curve
66 |
67 | Compute an energy-volume curve using the `mp-1143` structure from the Materials Project and MLIPs such as `mace-mpa-0`, `sevennet`, and `orb-v2`:
68 |
69 | ```bash
70 | mlipx recipes ev --models mace-mpa-0,sevennet,orb-v2 --material-ids=mp-1143 --repro
71 | mlipx compare --glob "*EnergyVolumeCurve"
72 | ```
73 |
74 | > [!NOTE]
75 | > `mlipx` utilizes [ASE](https://wiki.fysik.dtu.dk/ase/index.html),
76 | > meaning any ASE-compatible calculator for your MLIP can be used.
77 | > If we do not provide a preset for your model, you can either adapt the `models.py` file, raise an [issue](https://github.com/basf/mlipx/issues/new) to request support, or submit a pull request to add your model directly.
78 |
79 | Below is an example of the resulting comparison:
80 |
81 | 
82 | 
83 |
84 | > [!NOTE]
85 | > Set your default visualizer path using: `export ZNDRAW_URL=http://localhost:1234`.
86 |
87 | ### Structure Optimization
88 |
89 | Compare the performance of different models in optimizing multiple molecular structures from `SMILES` representations:
90 |
91 | ```bash
92 | mlipx recipes relax --models mace-mpa-0,sevennet,orb-v2 --smiles "CCO,C1=CC2=C(C=C1O)C(=CN2)CCN" --repro
93 | mlipx compare --glob "*0_StructureOptimization"
94 | mlipx compare --glob "*1_StructureOptimization"
95 | ```
96 |
97 | 
98 | 
99 |
100 | ### Nudged Elastic Band (NEB)
101 |
102 | Run and compare nudged elastic band (NEB) calculations for a given start and end structure:
103 |
104 | ```bash
105 | mlipx recipes neb --models mace-mpa-0,sevennet,orb-v2 --datapath ../data/neb_end_p.xyz --repro
106 | mlipx compare --glob "*NEBs"
107 | ```
108 |
109 | 
110 | 
111 |
112 | ## Python API
113 |
114 | You can also use all the recipes from the `mlipx` command-line interface
115 | programmatically in Python.
116 |
117 | > [!NOTE]
118 | > Whether you use the CLI or the Python API, you must work within a GIT
119 | > and DVC repository. This setup ensures reproducibility and enables automatic
120 | > caching and other features from DVC and ZnTrack.
121 |
122 | ```python
123 | import mlipx
124 |
125 | # Initialize the project
126 | project = mlipx.Project()
127 |
128 | # Define an MLIP
129 | mace_mp = mlipx.GenericASECalculator(
130 | module="mace.calculators",
131 | class_name="mace_mp",
132 | device="auto",
133 | kwargs={
134 | "model": "medium",
135 | },
136 | )
137 |
138 | # Use the MLIP in a structure optimization
139 | with project:
140 | data = mlipx.LoadDataFile(path="/your/data/file.xyz")
141 | relax = mlipx.StructureOptimization(
142 | data=data.frames,
143 | data_id=-1,
144 | model=mace_mp,
145 | fmax=0.1
146 | )
147 |
148 | # Reproduce the project state
149 | project.repro()
150 |
151 | # Access the results
152 | print(relax.frames)
153 | # >>> [ase.Atoms(...), ...]
154 | ```
155 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = ../build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @cd "$(SOURCEDIR)" && $(SPHINXBUILD) -M help "." "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @cd "$(SOURCEDIR)" && $(SPHINXBUILD) -M $@ "." "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/source/_static/.gitignore:
--------------------------------------------------------------------------------
1 | *.mp4
2 |
--------------------------------------------------------------------------------
/docs/source/_static/mlflow_compare.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basf/mlipx/c109aa6b2d3ff8236d6a9ec0772a8c4efb9c7dea/docs/source/_static/mlflow_compare.png
--------------------------------------------------------------------------------
/docs/source/_static/mlipx-dark.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
86 |
--------------------------------------------------------------------------------
/docs/source/_static/mlipx-favicon.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
78 |
--------------------------------------------------------------------------------
/docs/source/_static/mlipx-light.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
86 |
--------------------------------------------------------------------------------
/docs/source/_static/quickstart_zndraw.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basf/mlipx/c109aa6b2d3ff8236d6a9ec0772a8c4efb9c7dea/docs/source/_static/quickstart_zndraw.png
--------------------------------------------------------------------------------
/docs/source/_static/zndraw_compare.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basf/mlipx/c109aa6b2d3ff8236d6a9ec0772a8c4efb9c7dea/docs/source/_static/zndraw_compare.png
--------------------------------------------------------------------------------
/docs/source/_static/zndraw_render.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basf/mlipx/c109aa6b2d3ff8236d6a9ec0772a8c4efb9c7dea/docs/source/_static/zndraw_render.png
--------------------------------------------------------------------------------
/docs/source/abc.rst:
--------------------------------------------------------------------------------
1 | Abstract Base Classes
2 | ======================
3 | We make use of abstract base classes, protocols and type hints to improve the workflow design experience.
4 | Further, these can be used for cross-package interoperability with other :term:`ZnTrack` based packages like :term:`IPSuite`.
5 |
6 | For most :term:`Node` classes operating on lists of :term:`ASE` objects, there are two scenarios:
7 | - The node operates on a single :term:`ASE` object.
8 | - The node operates on a list of :term:`ASE` objects.
9 | For both scenarios, the node is given a list of :term:`ASE` objects via the `data` attribute.
10 | For the first scenario, the `id` of the :term:`ASE` object is given via the `data_id` attribute, which is omitted in the second scenario (see the sketch at the end of this page).
11 |
12 | .. automodule:: mlipx.abc
13 | :members:
14 | :undoc-members:
15 |
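16 | As a rough illustration of the two ``data`` / ``data_id`` scenarios described above, consider the following sketch. It only uses Nodes that appear elsewhere in this documentation, but it is not a complete recipe:
17 |
18 | .. code-block:: python
19 |
20 |     import mlipx
21 |
22 |     # an example model definition (parameters are illustrative)
23 |     model = mlipx.GenericASECalculator(
24 |         module="mace.calculators",
25 |         class_name="mace_mp",
26 |         device="auto",
27 |         kwargs={"model": "medium"},
28 |     )
29 |
30 |     project = mlipx.Project()
31 |
32 |     with project:
33 |         data = mlipx.LoadDataFile(path="data.xyz")  # provides frames: list[ase.Atoms]
34 |         # scenario 1: the Node acts on a single structure selected via data_id
35 |         relax = mlipx.StructureOptimization(
36 |             data=data.frames, data_id=-1, model=model, fmax=0.1
37 |         )
38 |         # scenario 2: the Node acts on the full list, so data_id is omitted
39 |         filtered = mlipx.FilterAtoms(data=data.frames, elements=["B", "F"])
40 |
41 |     project.repro()
42 |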
--------------------------------------------------------------------------------
/docs/source/authors.rst:
--------------------------------------------------------------------------------
1 | Authors and Contributing
2 | ========================
3 |
4 | .. note::
5 |
6 | Every contribution is welcome and you will be included in this ever growing list.
7 |
8 | Authors
9 | -------
10 | The creation of ``mlipx`` began during Fabian Zills' internship at BASF SE, where he worked under the guidance of Sandip De. The foundational concepts and designs were initially developed by Sandip, while the current version of the code is a product of contributions from various members of the BASF team. Fabian Zills integrated several of his previous projects and led the technical development of the initial release of this code. The code has been released with the intention of fostering community involvement in future developments. We acknowledge support from:
11 |
12 | - Fabian Zills
13 | - Sheena Agarwal
14 | - Sandip De
15 | - Shuang Han
16 | - Srishti Gupta
17 | - Tiago Joao Ferreira Goncalves
18 | - Edvin Fako
19 |
20 |
21 | Contribution Guidelines
22 | -----------------------
23 |
24 | We welcome contributions to :code:`mlipx`!
25 | With the inclusion of your contributions, we can make :code:`mlipx` better for everyone.
26 |
27 | To ensure code quality and consistency, we use :code:`pre-commit` hooks.
28 | To install the pre-commit hooks, run the following command:
29 |
30 | .. code:: console
31 |
32 | (.venv) $ pre-commit install
33 |
34 | All pre-commit hooks have to pass before a pull request can be merged.
35 |
36 | For new recipes, we recommend adding an example to the ``examples`` directory of this repository and updating the documentation accordingly.
37 |
38 | **Plugins**
39 |
40 | It is also possible to add new recipes to ``mlipx`` by writing plugins.
41 | We use the entry point ``mlipx.recipes`` to load new recipes.
42 | You can find more information on entry points `here `_.
43 |
44 | Given the following file ``yourpackage/recipes.py`` in your package:
45 |
46 | .. code:: python
47 |
48 | from mlipx.recipes import app
49 |
50 | @app.command()
51 | def my_recipe():
52 | ...  # your recipe code here
53 |
54 | you can add the following to your ``pyproject.toml``:
55 |
56 | .. code:: toml
57 |
58 | [project.entry-points."mlipx.recipes"]
59 | yourpackage = "yourpackage.recipes"
60 |
61 |
62 | and when your package is installed together with ``mlipx``, the recipe will be available in the CLI via ``mlipx recipes my_recipe``.
63 |
--------------------------------------------------------------------------------
/docs/source/build_graph.rst:
--------------------------------------------------------------------------------
1 | .. _custom_nodes:
2 |
3 | Build your own Graph
4 | =====================
5 |
6 | This section goes into more detail for adding your own :term:`ZnTrack` Node and designing a custom workflow.
7 | You will learn how to include an :term:`MLIP` that cannot be interfaced with the :code:`GenericASECalculator` Node.
8 | With your own custom Nodes you can build more comprehensive test cases or go even beyond :term:`MLIP` testing and build workflows for other scenarios, such as :term:`MLIP` training with :code:`mlipx` and :term:`IPSuite`.
9 |
10 | .. toctree::
11 | :glob:
12 |
13 | notebooks/*
14 |
--------------------------------------------------------------------------------
/docs/source/concept.rst:
--------------------------------------------------------------------------------
1 | Concept
2 | =======
3 |
4 | ``mlipx`` is a tool designed to evaluate the performance of various **Machine-Learned Interatomic Potentials (MLIPs)**.
5 | It offers both static and dynamic test recipes, helping you identify the most suitable MLIP for your specific problem.
6 |
7 | The ``mlipx`` package is modular and highly extensible, achieved by leveraging the capabilities of :term:`ZnTrack` and community support to provide a wide range of different test cases and :term:`MLIP` interfaces.
8 |
9 | Static Tests
10 | ------------
11 |
12 | Static tests focus on predefined datasets that serve as benchmarks for evaluating the performance of different :term:`MLIP` models.
13 | You provide a dataset file, and ``mlipx`` evaluates a specified list of :term:`MLIP` models to generate performance metrics.
14 | These tests are ideal for comparing general performance across multiple MLIPs on tasks with well-defined input data.
15 |
16 | Dynamic Tests
17 | -------------
18 |
19 | Dynamic tests are designed to address specific user-defined problems where the dataset is not predetermined. These tests provide flexibility and adaptability to evaluate :term:`MLIP` models based on your unique requirements. For example, if you provide only the composition of a system, ``mlipx`` can assess the suitability of various :term:`MLIP` models for the problem.
20 |
21 | - ``mlipx`` offers several methods to generate new data using recipes such as :ref:`relax`, :ref:`md`, :ref:`homonuclear_diatomics`, or :ref:`ev`.
22 | - If no starting structures are available, ``mlipx`` can search public datasets like ``mptraj`` or the Materials Project for similar data. Alternatively, new structures can be generated directly from ``smiles`` strings, as detailed in the :ref:`data` section.
23 |
24 | This dynamic approach enables a more focused evaluation of :term:`MLIP` models, tailoring the process to the specific challenges and requirements of the user's system.
25 |
26 | Comparison
27 | ----------
28 |
29 | A comprehensive comparison of different :term:`MLIP` models is crucial to identifying the best model for a specific problem.
30 | To facilitate this, ``mlipx`` integrates with :ref:`ZnDraw <zndraw>` for visualizing trajectories and creating interactive plots of the generated data.
31 |
32 | Additionally, ``mlipx`` interfaces with :term:`DVC` for data versioning and can log metrics to :term:`mlflow`,
33 | providing a quick overview of all past evaluations.
34 |
35 |
36 | .. toctree::
37 | :hidden:
38 |
39 | concept/data
40 | concept/models
41 | concept/recipes
42 | concept/zntrack
43 | concept/zndraw
44 | concept/metrics
45 | concept/distributed
46 |
--------------------------------------------------------------------------------
/docs/source/concept/data.rst:
--------------------------------------------------------------------------------
1 | .. _data:
2 |
3 | Datasets
4 | ========
5 |
6 | Data within ``mlipx`` is always represented as a list of :term:`ASE` atoms objects.
7 | There are various ways to provide this data to the workflow, depending on your requirements.
8 |
9 | Local Data Files
10 | ----------------
11 |
12 | The simplest way to use data in the workflow is by providing a local data file, such as a trajectory file.
13 |
14 | .. code:: console
15 |
16 | (.venv) $ cp /path/to/data.xyz .
17 | (.venv) $ dvc add data.xyz
18 |
19 | .. dropdown:: Local data file (:code:`main.py`)
20 | :open:
21 |
22 | .. code:: python
23 |
24 | import zntrack
25 | import mlipx
26 |
27 | DATAPATH = "data.xyz"
28 |
29 | project = mlipx.Project()
30 |
31 | with project.group("initialize"):
32 | data = mlipx.LoadDataFile(path=DATAPATH)
33 |
34 | Remote Data Files
35 | -----------------
36 |
37 | Since ``mlipx`` integrates with :term:`DVC`, it can easily handle data from remote locations.
38 | You can manually import a remote file:
39 |
40 | .. code:: console
41 |
42 | (.venv) $ dvc import-url https://url/to/your/data.xyz data.xyz
43 |
44 | Alternatively, you can use the ``zntrack`` interface for automated management.
45 | This allows evaluation of datasets such as :code:`mptraj` and supports filtering to select relevant configurations.
46 | For example, the following code selects all structures containing :code:`F` and :code:`B` atoms.
47 |
48 | .. dropdown:: Importing online resources (:code:`main.py`)
49 | :open:
50 |
51 | .. code:: python
52 |
53 | import zntrack
54 | import mlipx
55 |
56 | mptraj = zntrack.add(
57 | url="https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mp_traj_combined.xyz",
58 | path="mptraj.xyz",
59 | )
60 |
61 | project = mlipx.Project()
62 |
63 | with project:
64 | raw_data = mlipx.LoadDataFile(path=mptraj)
65 | data = mlipx.FilterAtoms(data=raw_data.frames, elements=["B", "F"])
66 |
67 | Materials Project
68 | -----------------
69 |
70 | You can also search and retrieve structures from the `Materials Project`.
71 |
72 | .. dropdown:: Querying Materials Project (:code:`main.py`)
73 | :open:
74 |
75 | .. code:: python
76 |
77 | import mlipx
78 |
79 | project = mlipx.Project()
80 |
81 | with project.group("initialize"):
82 | data = mlipx.MPRester(search_kwargs={"material_ids": ["mp-1143"]})
83 |
84 | .. note::
85 | To use the Materials Project, you need an API key. Set the environment variable
86 | :code:`MP_API_KEY` to your API key.
87 |
88 | Generating Data
89 | ---------------
90 |
91 | Another approach is generating data dynamically. In ``mlipx``, you can build molecules or simulation boxes from SMILES strings.
92 | For instance, the following code generates a simulation box containing 10 ethanol molecules:
93 |
94 | .. dropdown:: Using SMILES (:code:`main.py`)
95 | :open:
96 |
97 | .. code:: python
98 |
99 | import mlipx
100 |
101 | project = mlipx.Project()
102 |
103 | with project.group("initialize"):
104 | confs = mlipx.Smiles2Conformers(smiles="CCO", num_confs=10)
105 | data = mlipx.BuildBox(data=[confs.frames], counts=[10], density=789)
106 |
107 | .. note::
108 | The :code:`BuildBox` node requires :term:`Packmol` and :term:`rdkit2ase`.
109 | If you do not need a simulation box, you can use :code:`confs.frames` directly.
110 |
--------------------------------------------------------------------------------
/docs/source/concept/distributed.rst:
--------------------------------------------------------------------------------
1 | .. _Distributed evaluation:
2 |
3 | Distributed evaluation
4 | ======================
5 |
6 | For the evaluation of different :term:`MLIP` models, it is often necessary to
7 | run the evaluation in different environments due to package incompatibility.
8 | Another reason can be the computational cost of the evaluation.
9 |
10 | Writing the evaluation in a workflow-like manner allows for the separation of tasks
11 | onto different hardware or software environments.
12 | For this purpose, the :term:`paraffin` package was developed.
13 |
14 | You can use :code:`paraffin submit` to queue the evaluation of the selected stages.
15 | With :code:`paraffin worker --concurrency 5` you can start 5 workers to evaluate the stages.
16 |
17 | Further, you can select which stage should be picked up by which worker by defining a :code:`paraffin.yaml` file which supports wildcards.
18 |
19 | .. code-block:: yaml
20 |
21 | queue:
22 | "B_X*": BQueue
23 | "A_X_AddNodeNumbers": AQueue
24 |
25 | The above configuration will queue all stages starting with :code:`B_X` to the :code:`BQueue` and the stage :code:`A_X_AddNodeNumbers` to the :code:`AQueue`.
26 | You can then use :code:`paraffin worker --queue BQueue` to only pick up the stages from the :code:`BQueue` and vice versa.
27 |
28 | The paraffin package is available on PyPI and can be installed via:
29 |
30 | .. code-block:: bash
31 |
32 | pip install paraffin
33 |
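34 | Putting these commands together, an illustrative session with the example :code:`paraffin.yaml` above could look like this (each worker would typically run in its own terminal, and the exact flag combinations may vary):
35 |
36 | .. code-block:: console
37 |
38 |     (.venv) $ paraffin submit
39 |     (.venv) $ paraffin worker --queue AQueue
40 |     (.venv) $ paraffin worker --queue BQueue --concurrency 5
41 |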
--------------------------------------------------------------------------------
/docs/source/concept/metrics.rst:
--------------------------------------------------------------------------------
1 | Metrics Overview
2 | ================
3 |
4 | ``mlipx`` provides several tools and integrations for comparing and visualizing metrics across experiments and nodes.
5 | This section outlines how to use these features to evaluate model performance and gain insights into various tasks.
6 |
7 | Comparing Metrics Using ``mlipx compare``
8 | -----------------------------------------
9 |
10 | With the ``mlipx compare`` command, you can directly compare results from the same Node or experiment using the :ref:`ZnDraw <zndraw>` visualization tool. For example:
11 |
12 | .. code-block:: bash
13 |
14 | mlipx compare mace-mpa-0_StructureOptimization orb-v2_0_StructureOptimization
15 |
16 | This allows you to study the performance of different models for a single task in great detail.
17 | Every Node in ``mlipx`` defines its own comparison method for this.
18 |
19 | Integrations with DVC and MLFlow
20 | --------------------------------
21 |
22 | To enable a broader overview of metrics and enhance experiment tracking, ``mlipx`` integrates with both :term:`DVC` and :term:`mlflow`. These tools allow for efficient tracking, visualization, and comparison of metrics across multiple experiments.
23 |
24 | MLFlow Integration
25 | -------------------
26 |
27 | ``mlipx`` supports logging metrics to :term:`mlflow`. To use this feature, ensure ``mlflow`` is installed:
28 |
29 | .. code-block:: bash
30 |
31 | pip install mlflow
32 |
33 |
34 | .. note::
35 |
36 | More information on how to set up MLFlow and run the server can be found in the `MLFlow documentation `_.
37 |
38 | Set the tracking URI to connect to your MLFlow server:
39 |
40 | .. code-block:: bash
41 |
42 | export MLFLOW_TRACKING_URI=http://localhost:5000
43 |
44 | Use the ``zntrack mlflow-sync`` command to upload metrics to MLFlow.
45 | For this command, you need to specify the Nodes you want to sync.
46 |
47 | .. note::
48 | You can get an overview of all available Nodes using the ``zntrack list`` command.
49 | The use of glob patterns makes it easy to sync the same node for different models.
50 | To structure the experiments in MLFlow, you can specify a parent experiment.
51 |
52 | A typical structure for syncing multiple Nodes would look like this:
53 |
54 | .. code-block:: bash
55 |
56 | zntrack mlflow-sync "*StructureOptimization" --experiment "mlipx" --parent "StructureOptimization"
57 | zntrack mlflow-sync "*EnergyVolumeCurve" --experiment "mlipx" --parent "EnergyVolumeCurve"
58 | zntrack mlflow-sync "*MolecularDynamics" --experiment "mlipx" --parent "MolecularDynamics"
59 |
60 | With the MLFlow UI, you can visualize and compare metrics across experiments:
61 |
62 | .. image:: https://github.com/user-attachments/assets/2536d5d5-f8ef-4403-ac4b-670d40ae64de
63 | :align: center
64 | :alt: MLFlow UI Metrics
65 | :width: 100%
66 | :class: only-dark
67 |
68 | .. image:: https://github.com/user-attachments/assets/0d3d3187-b8ee-4b27-855e-7b245bd88346
69 | :align: center
70 | :alt: MLFlow UI Metrics
71 | :width: 100%
72 | :class: only-light
73 |
74 | Additionally, ``mlipx`` logs plots to MLFlow, enabling comparisons of relaxation energies across models or direct visualizations of energy-volume curves:
75 |
76 | .. image:: https://github.com/user-attachments/assets/19305012-6d92-40a3-bac6-68522bd55490
77 | :align: center
78 | :alt: MLFlow UI Plots
79 | :width: 100%
80 | :class: only-dark
81 |
82 | .. image:: https://github.com/user-attachments/assets/3cffba32-7abf-4a36-ac44-b584126c2e57
83 | :align: center
84 | :alt: MLFlow UI Plots
85 | :width: 100%
86 | :class: only-light
87 |
88 |
89 | Data Version Control (DVC)
90 | ---------------------------
91 |
92 | Each Node in ``mlipx`` includes predefined metrics that can be accessed via the :term:`DVC` command-line interface. Use the following commands to view metrics and plots:
93 |
94 | .. code-block:: bash
95 |
96 | dvc metrics show
97 | dvc plots show
98 |
99 | For more details on working with DVC, refer to the `DVC documentation `_.
100 |
101 | DVC also integrates seamlessly with Visual Studio Code through the `DVC extension `_, providing a user-friendly interface to browse and compare metrics and plots:
102 |
103 | .. image:: https://github.com/user-attachments/assets/79ede9d2-e11f-47da-b69c-523aa0361aaa
104 | :alt: DVC extension in Visual Studio Code
105 | :width: 100%
106 | :class: only-dark
107 |
108 | .. image:: https://github.com/user-attachments/assets/562ab225-15a8-409a-8e4e-f585e33103fa
109 | :alt: DVC extension in Visual Studio Code
110 | :width: 100%
111 | :class: only-light
112 |
--------------------------------------------------------------------------------
/docs/source/concept/models.rst:
--------------------------------------------------------------------------------
1 | Models
2 | ======
3 |
4 | For each recipe, the models to evaluate are defined in the :term:`models.py` file.
5 | In most cases, you can use :code:`mlipx.GenericASECalculator` to access models.
6 | However, in certain scenarios, you may need to provide a custom calculator.
7 |
8 | Below, we demonstrate how to write a custom calculator node for :code:`SevenCalc`.
9 | While we define a custom node here for demonstration, note that the underlying SevenNet calculator could also be used via :code:`mlipx.GenericASECalculator`.
10 |
11 | Defining Models
12 | ---------------
13 |
14 | Here is the content of a typical :code:`models.py` file:
15 |
16 | .. dropdown:: Content of :code:`models.py`
17 | :open:
18 |
19 | .. code-block:: python
20 |
21 | import mlipx
22 | from src import SevenCalc
23 |
24 | mace_medium = mlipx.GenericASECalculator(
25 | module="mace.calculators",
26 | class_name="MACECalculator",
27 | device='auto',
28 | kwargs={
29 | "model_paths": "mace_models/y7uhwpje-medium.model",
30 | },
31 | )
32 |
33 | mace_agnesi = mlipx.GenericASECalculator(
34 | module="mace.calculators",
35 | class_name="MACECalculator",
36 | device='auto',
37 | kwargs={
38 | "model_paths": "mace_models/mace_mp_agnesi_medium.model",
39 | },
40 | )
41 |
42 | sevennet = SevenCalc(model='7net-0')
43 |
44 | MODELS = {
45 | "mace_medm": mace_medium,
46 | "mace_agne": mace_agnesi,
47 | "7net": sevennet,
48 | }
49 |
50 | Custom Calculator Example
51 | -------------------------
52 |
53 | The :code:`SevenCalc` class, used in the example above, is defined in :code:`src/__init__.py` as follows:
54 |
55 | .. dropdown:: Content of :code:`src/__init__.py`
56 | :open:
57 |
58 | .. code-block:: python
59 |
60 | import dataclasses
61 | from ase.calculators.calculator import Calculator
62 |
63 | @dataclasses.dataclass
64 | class SevenCalc:
65 | model: str
66 |
67 | def get_calculator(self, **kwargs) -> Calculator:
68 | from sevenn.sevennet_calculator import SevenNetCalculator
69 | sevennet = SevenNetCalculator(self.model, device='cpu')
70 |
71 | return sevennet
72 |
73 | For more details, refer to the :ref:`custom_nodes` section.
74 |
75 | .. _update-frames-calc:
76 |
77 | Updating Dataset Keys
78 | ---------------------
79 |
80 | In some cases, models may need to be defined to convert existing dataset keys into the format :code:`mlipx` expects.
81 | For example, you may need to provide isolated atom energies or convert data where energies are stored as :code:`atoms.info['DFT_ENERGY']`
82 | and forces as :code:`atoms.arrays['DFT_FORCES']`.
83 |
84 | Here’s how to define a model for such a scenario:
85 |
86 | .. code-block:: python
87 |
88 | import mlipx
89 |
90 | REFERENCE = mlipx.UpdateFramesCalc(
91 | results_mapping={"energy": "DFT_ENERGY", "forces": "DFT_FORCES"},
92 | info_mapping={mlipx.abc.ASEKeys.isolated_energies.value: "isol_ene"},
93 | )
94 |
--------------------------------------------------------------------------------
/docs/source/concept/recipes.rst:
--------------------------------------------------------------------------------
1 | .. _recipes:
2 |
3 | Recipes
4 | =======
5 |
6 | One of :code:`mlipx`'s core functionalities is providing pre-designed recipes.
7 | These define workflows for evaluating :term:`MLIP` on specific tasks.
8 | You can get an overview of all available recipes using
9 |
10 | .. code-block:: console
11 |
12 | (.venv) $ mlipx recipes --help
13 |
14 | All recipes follow the same structure.
15 | It is recommended to create a new directory for each recipe.
16 |
17 | .. code-block:: console
18 |
19 | (.venv) $ mkdir molecular_dynamics
20 | (.venv) $ cd molecular_dynamics
21 | (.venv) $ mlipx recipes md --initialize
22 |
23 | This will create the following structure:
24 |
25 | .. code-block:: console
26 |
27 | molecular_dynamics/
28 | ├── .git/
29 | ├── .dvc/
30 | ├── models.py
31 | └── main.py
32 |
33 | After initialization, adapt the :code:`main.py` file to point towards the requested data files.
34 | Define all models for testing in the :term:`models.py` file (a minimal example is sketched at the end of this page).
35 |
36 | Finally, build the recipe using
37 |
38 | .. code-block:: console
39 |
40 | (.venv) $ python main.py
41 | (.venv) $ dvc repro
42 |
43 |
44 | Upload Results
45 | --------------
46 | Once the recipe is finished, you can persist the results and upload them to remote storage.
47 | To do so, make a GIT commit and push it to your repository.
48 |
49 | .. code-block:: console
50 |
51 | (.venv) $ git add .
52 | (.venv) $ git commit -m "Finished molecular dynamics test"
53 | (.venv) $ git push
54 | (.venv) $ dvc push
55 |
56 | .. note::
57 | You need to define a :term:`GIT` and :term:`DVC` remote to push the results.
58 | More information on how to setup a :term:`DVC` remote can be found at https://dvc.org/doc/user-guide/data-management/remote-storage.
59 |
60 |
61 | In combination or as an alternative, you can upload the results to a parameter and metric tracking service, such as :term:`mlflow`.
62 | Given a running :term:`mlflow` server, you can use the following command to upload the results:
63 |
64 | .. code-block:: console
65 |
66 | (.venv) $ zntrack mlflow-sync --help
67 |
68 | .. note::
69 | Depending on the installed packages, the :term:`mlflow` command might not be available.
70 | This functionality is provided by the :term:`zntrack` package, and other tracking services can be used as well.
71 | They will show up once the respective package is installed.
72 | See https://zntrack.readthedocs.io/ for more information.
73 |
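74 | As referenced above, a minimal :term:`models.py` could look like the following sketch (the model and its keyword arguments are illustrative):
75 |
76 | .. code-block:: python
77 |
78 |     import mlipx
79 |
80 |     mace_mp = mlipx.GenericASECalculator(
81 |         module="mace.calculators",
82 |         class_name="mace_mp",
83 |         device="auto",
84 |         kwargs={"model": "medium"},
85 |     )
86 |
87 |     # each recipe imports its models from this dictionary
88 |     MODELS = {
89 |         "mace-mp": mace_mp,
90 |     }
91 |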
--------------------------------------------------------------------------------
/docs/source/concept/zndraw.rst:
--------------------------------------------------------------------------------
1 | .. _zndraw:
2 |
3 | Visualisation
4 | =============
5 | :code:`mlipx` uses ZnDraw as its primary tool for visualisation and comparison.
6 | The following will give you an overview.
7 |
8 | The ZnDraw package provides a versatile visualisation tool for atomistic structures.
9 | It is based on :term:`ASE` and runs as a web application.
10 | You can install it via:
11 |
12 | .. code:: bash
13 |
14 | pip install zndraw
15 |
16 |
17 | It can be used to visualize data through a CLI:
18 |
19 | .. code:: bash
20 |
21 | zndraw file.xyz # any ASE supported file format + H5MD
22 | zndraw --remote . Node.frames # any ZnTrack node that has an attribute `list[ase.Atoms]`
23 |
24 | Once you have a running ZnDraw instance, you can connect to it from within Python.
25 | You can find more information in the GUI by clicking on :code:`Python Access`.
26 | The :code:`vis` object behaves like a list of :term:`ASE` atom objects.
27 | Modifications made in place will be reflected in real time in the GUI.
28 |
29 | .. tip::
30 |
31 | You can keep a ZnDraw instance running in the background and set the environment variable :code:`ZNDRAW_URL` to the URL of the running instance.
32 | This way, you do not have to define a ZnDraw URL when running ZnDraw or ``mlipx`` CLI commands.
33 | You can also set up a `ZnDraw Docker container `_ to always have a running instance.
34 |
35 | .. code:: python
36 |
37 | from zndraw import ZnDraw
38 |
39 | vis = ZnDraw(url="http://localhost:1234", token="")
40 |
41 | print(vis[0])
42 | >>> ase.Atoms(...)
43 |
44 | vis.append(ase.Atoms(...))
45 |
46 |
47 | .. image:: ../_static/zndraw_render.png
48 | :width: 100%
49 |
50 | **Figure 1** Graphical user interface of the :ref:`ZnDraw <zndraw>` package with GPU path tracing enabled.
51 |
52 |
53 | For further information, have a look at the ZnDraw repository https://github.com/zincware/zndraw - full documentation will be provided soon.
54 |
--------------------------------------------------------------------------------
/docs/source/concept/zntrack.rst:
--------------------------------------------------------------------------------
1 | Workflows
2 | =========
3 |
4 | The :code:`mlipx` package is based on ZnTrack.
5 | Although using :code:`mlipx` does not require you to understand how ZnTrack works, the following gives a short overview of the concept.
6 | We will take the example of building a simulation box from :code:`smiles` strings, as illustrated in the following Python script.
7 |
8 | .. code:: python
9 |
10 | import rdkit2ase
11 |
12 | water = rdkit2ase.smiles2atoms('O')
13 | ethanol = rdkit2ase.smiles2atoms('CCO')
14 |
15 | box = rdkit2ase.pack([[water], [ethanol]], counts=[50, 50], density=800)
16 | print(box)
17 | >>> ase.Atoms(...)
18 |
19 | This script can also be represented as the following workflow which we will now convert.
20 |
21 | .. mermaid::
22 | :align: center
23 |
24 | graph TD
25 | BuildWater --> PackBox
26 | BuildEtOH --> PackBox
27 |
28 |
29 | With ZnTrack you can build complex workflows based on :term:`DVC` and :term:`GIT`.
30 | The first part of a workflow is defining the steps, each of which is called a :code:`Node` in the context of ZnTrack.
31 | A :code:`Node` is based on the Python :code:`dataclass` module, defining its arguments as class attributes.
32 |
33 | .. note::
34 |
35 | It is highly recommended to follow the single-responsibility principle when writing a :code:`Node`. For example, if you have a relaxation followed by a molecular dynamics simulation, separate these into two Nodes. But also keep in mind that there is some communication overhead between Nodes, so defining, e.g., each MD step as a separate Node would not be recommended.
36 |
37 | .. code:: python
38 |
39 | import zntrack
40 | import ase
41 | import rdkit2ase
42 |
43 | class BuildMolecule(zntrack.Node):
44 | smiles: str = zntrack.params()
45 |
46 | frames: list[ase.Atoms] = zntrack.outs()
47 |
48 | def run(self):
49 | self.frames = [rdkit2ase.smiles2atoms(self.smiles)]
50 |
51 | With this :code:`BuildMolecule` class, we can bring :code:`rdkit2ase.smiles2atoms` onto the graph by defining its inputs and outputs.
52 | Further, we need to define a :code:`Node` for the :code:`rdkit2ase.pack` function.
53 | For this, we define the :code:`PackBox` node as follows:
54 |
55 | .. code:: python
56 |
57 | import ase.io
58 | import pathlib
59 |
60 | class PackBox(zntrack.Node):
61 | data: list[list[ase.Atoms]] = zntrack.deps()
62 | counts: list[int] = zntrack.params()
63 | density: float = zntrack.params()
64 |
65 | frames_path: pathlib.Path = zntrack.outs_path(zntrack.nwd / 'frames.xyz')
66 |
67 | def run(self):
68 | box = rdkit2ase.pack(self.data, counts=self.counts, density=self.density)
69 | ase.io.write(self.frames_path, box)
70 |
71 | .. note::
72 |
73 | The :code:`zntrack.outs_path(zntrack.nwd / 'frames.xyz')` provides a unique output path per node in the node working directory (nwd). It is crucial to define every input and output as ZnTrack attributes. Otherwise, the results will be lost.
74 |
75 | With this Node, we can build our graph:
76 |
77 | .. code:: python
78 |
79 | project = zntrack.Project()
80 |
81 | with project:
82 | water = BuildMolecule(smiles="O")
83 | ethanol = BuildMolecule(smiles="CCO")
84 |
85 | box = PackBox(data=[water.frames, ethanol.frames], counts=[50, 50], density=800)
86 |
87 | project.build()
88 |
89 | .. note::
90 |
91 | The `project.build()` command will not run the graph but only define how the graph is to be executed in the future.
92 | Consider it a pure graph definition file.
93 | If you write this into a single :code:`main.py` file, it should look like
94 |
95 | .. dropdown:: Content of :code:`main.py`
96 |
97 | .. code-block:: python
98 |
99 | import zntrack
100 | import ase.io
101 | import rdkit2ase
102 | import pathlib
103 |
104 | class BuildMolecule(zntrack.Node):
105 | smiles: str = zntrack.params()
106 |
107 | frames: list[ase.Atoms] = zntrack.outs()
108 |
109 | def run(self):
110 | self.frames = [rdkit2ase.smiles2atoms(self.smiles)]
111 |
112 | class PackBox(zntrack.Node):
113 | data: list[list[ase.Atoms]] = zntrack.deps()
114 | counts: list[int] = zntrack.params()
115 | density: float = zntrack.params()
116 |
117 | frames_path: pathlib.Path = zntrack.outs_path(zntrack.nwd / 'frames.xyz')
118 |
119 | def run(self):
120 | box = rdkit2ase.pack(self.data, counts=self.counts, density=self.density)
121 | ase.io.write(self.frames_path, box)
122 |
123 | if __name__ == "__main__":
124 | project = zntrack.Project()
125 |
126 | with project:
127 | water = BuildMolecule(smiles="O")
128 | ethanol = BuildMolecule(smiles="CCO")
129 |
130 | box = PackBox(data=[water.frames, ethanol.frames], counts=[50, 50], density=800)
131 |
132 | project.build()
133 |
134 | To run the graph, you can use the :term:`DVC` CLI :code:`dvc repro` (or the :term:`paraffin` package, see :ref:`Distributed evaluation`).
135 |
136 | Once finished, you can look at the results by loading the nodes:
137 |
138 | .. code:: python
139 |
140 | import zntrack
141 | import ase.io
142 |
143 | box = zntrack.from_rev("PackBox")
144 | print(ase.io.read(box.frames_path))
145 | >>> ase.Atoms(...)
146 |
147 |
148 | For further information, have a look at the ZnTrack documentation at https://zntrack.readthedocs.io and the repository at https://github.com/zincware/zntrack .
149 |
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # For the full list of built-in configuration values, see the documentation:
4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
5 |
6 | import typing as t
7 |
8 | import mlipx
9 |
10 | # -- Project information -----------------------------------------------------
11 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
12 |
13 | project = "mlipx"
14 | copyright = "2025, Fabian Zills, Sheena Agarwal, Sandip De"
15 | author = "Fabian Zills, Sheena Agarwal, Sandip De"
16 | release = mlipx.__version__
17 |
18 | # -- General configuration ---------------------------------------------------
19 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
20 |
21 | extensions = [
22 | "sphinx.ext.doctest",
23 | "sphinx.ext.autodoc",
24 | "sphinxcontrib.mermaid",
25 | "sphinx.ext.viewcode",
26 | "sphinx.ext.napoleon",
27 | "hoverxref.extension",
28 | "sphinxcontrib.bibtex",
29 | "sphinx_copybutton",
30 | "jupyter_sphinx",
31 | "sphinx_design",
32 | "nbsphinx",
33 | "sphinx_mdinclude",
34 | ]
35 |
36 | templates_path = ["_templates"]
37 | exclude_patterns = []
38 |
39 |
40 | # -- Options for HTML output -------------------------------------------------
41 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
42 |
43 | html_theme = "furo"
44 | html_static_path = ["_static"]
45 | html_title = "Machine Learned Interatomic Potential eXploration"
46 | html_short_title = "mlipx"
47 | html_favicon = "_static/mlipx-favicon.svg"
48 |
49 | html_theme_options: t.Dict[str, t.Any] = {
50 | "light_logo": "mlipx-light.svg",
51 | "dark_logo": "mlipx-dark.svg",
52 | "footer_icons": [
53 | {
54 | "name": "GitHub",
55 | "url": "https://github.com/basf/mlipx",
56 | "html": "",
57 | "class": "fa-brands fa-github fa-2x",
58 | },
59 | ],
60 | "source_repository": "https://github.com/basf/mlipx/",
61 | "source_branch": "main",
62 | "source_directory": "docs/source/",
63 | "navigation_with_keys": True,
64 | }
65 |
66 | # font-awesome logos
67 | html_css_files = [
68 | "https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/brands.min.css",
69 | ]
70 |
71 | # -- Options for hoverxref extension -----------------------------------------
72 | # https://sphinx-hoverxref.readthedocs.io/en/latest/
73 |
74 | hoverxref_roles = ["term"]
75 | hoverxref_role_types = {
76 | "class": "tooltip",
77 | }
78 |
79 |
80 | # -- Options for sphinxcontrib-bibtex ----------------------------------------
81 | # https://sphinxcontrib-bibtex.readthedocs.io/en/latest/
82 |
83 | bibtex_bibfiles = ["references.bib"]
84 |
85 | # -- Options for sphinx_copybutton -------------------------------------------
86 | # https://sphinx-copybutton.readthedocs.io/en/latest/
87 |
88 | copybutton_prompt_text = r">>> |\.\.\. |\(.*\) \$ "
89 | copybutton_prompt_is_regexp = True
90 |
--------------------------------------------------------------------------------
/docs/source/glossary.rst:
--------------------------------------------------------------------------------
1 | Glossary
2 | ========
3 |
4 | .. glossary::
5 |
6 | MLIP
7 | Machine learned interatomic potential.
8 |
9 | GIT
10 | GIT is a distributed version control system. It allows multiple people to work on a project at the same time without overwriting each other's changes. It also keeps a history of all changes made to the project, so you can easily revert to an earlier version if necessary.
11 |
12 | DVC
13 | Data Version Control (DVC) is a tool used in machine learning projects to track and version the datasets used in the project, as well as the code and the results. This makes it easier to reproduce experiments and share results with others.
14 | More information can be found at https://dvc.org/ .
15 |
16 | mlflow
17 | Mlflow is an open-source platform that helps manage the machine learning lifecycle, including experimentation, reproducibility, and deployment. It keeps track of the parameters used in the model as well as the metrics obtained from the model.
18 | More information can be found at https://mlflow.org/ .
19 |
20 | ZnTrack
21 | ZnTrack :footcite:t:`zillsZnTrackDataCode2024` is a Python package that provides a framework for defining and executing workflows. It allows users to define a sequence of tasks, each represented by a Node, and manage their execution and dependencies.
22 | The package can be installed via :code:`pip install zntrack` or from source at https://github.com/zincware/zntrack .
23 |
24 | IPSuite
25 | IPSuite by :footcite:t:`zillsCollaborationMachineLearnedPotentials2024` is a software package for the development and application of machine-learned interatomic potentials (MLIPs). It provides functionalities for training and testing MLIPs, as well as for running simulations using these potentials.
26 | The package can be installed via :code:`pip install ipsuite` or from source at https://github.com/zincware/ipsuite .
27 |
28 | ZnDraw
29 | The :ref:`ZnDraw <zndraw>` package for visualisation and editing of atomistic structures :footcite:`elijosiusZeroShotMolecular2024`.
30 | The package can be installed via :code:`pip install zndraw` or from source at https://github.com/zincware/zndraw .
31 |
32 | main.py
33 | The :term:`ZnTrack` graph definition for the recipe is stored in this file.
34 |
35 | models.py
36 | The file that contains the models for testing.
37 | Each recipe will import the models from this file.
38 | It should follow the following format:
39 |
40 | .. code-block:: python
41 |
42 | from mlipx.abc import NodeWithCalculator
43 |
44 | MODELS: dict[str, NodeWithCalculator] = {
45 | ...
46 | }
47 |
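        For example, a filled-in ``models.py`` could look like the following sketch (the ``mace_mp`` settings mirror the quickstart and are only illustrative; any object satisfying ``NodeWithCalculator`` works):

        .. code-block:: python

            import mlipx
            from mlipx.abc import NodeWithCalculator

            mace_mp = mlipx.GenericASECalculator(
                module="mace.calculators",
                class_name="mace_mp",
                device="auto",
                kwargs={"model": "medium"},
            )

            MODELS: dict[str, NodeWithCalculator] = {"mace-mp": mace_mp}
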
48 | packmol
49 | Packmol is a software package used for building initial configurations for molecular dynamics or Monte Carlo simulations. It can generate a random collection of molecules using the specified density and composition. More information can be found at https://m3g.github.io/packmol/ .
50 |
51 | rdkit2ase
52 | A package for converting RDKit molecules to ASE atoms.
53 | The package can be installed via :code:`pip install rdkit2ase` or from source at https://github.com/zincware/rdkit2ase .
54 |
55 | Node
56 | A node is a class that represents a single step in the workflow.
57 | It should inherit from the :class:`zntrack.Node` class.
58 | The node should implement the :meth:`zntrack.Node.run` method.
59 |
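        A minimal sketch of a custom Node (the names are illustrative) could look like:

        .. code-block:: python

            import zntrack


            class CountFrames(zntrack.Node):
                data: list = zntrack.deps()   # input provided by another Node
                count: int = zntrack.outs()   # output tracked by ZnTrack

                def run(self) -> None:
                    self.count = len(self.data)
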
60 | ASE
61 | The Atomic Simulation Environment (ASE). More information can be found at https://wiki.fysik.dtu.dk/ase/
62 |
63 | paraffin
64 | The paraffin package for the distributed evaluation of :term:`DVC` stages.
65 | The package can be installed via :code:`pip install paraffin` or from source at https://github.com/zincware/paraffin .
66 |
67 |
68 | .. footbibliography::
69 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | MLIPX Documentation
2 | ===================
3 |
4 | :code:`mlipx` is a Python library for the evaluation and benchmarking of machine-learned interatomic potentials (:term:`MLIP`). MLIPs are advanced computational models that use machine learning techniques to describe complex atomic interactions accurately and efficiently. They significantly accelerate traditional quantum mechanical modeling workflows, making them invaluable for simulating a wide range of physical phenomena, from chemical reactions to materials properties and phase transitions. Testing an MLIP, however, requires more than static cross-validation: evaluating energy and force prediction accuracies on a held-out test set is essential, but it is only the first step. Determining the real-world usability of a model calls for more comprehensive testing.
5 |
6 | :code:`mlipx` addresses this need by providing systematically designed testing recipes to assess the strengths and weaknesses of the rapidly growing variety of MLIP models. These recipes help ensure that models are robust and applicable to a wide range of scenarios. :code:`mlipx` provides you with an ever-growing set of evaluation methods accompanied by comprehensive visualization and comparison tools.
7 |
8 | The goal of this project is to provide a common platform for the evaluation of MLIPs and to facilitate the exchange of evaluation results between researchers.
9 | Ultimately, you should be able to determine the applicability of a given MLIP for your specific research question and to compare it to other MLIPs.
10 |
11 | By offering these capabilities, MLIPX helps researchers determine the applicability of MLIPs for specific research questions and compare them effectively, whether developing models from scratch or fine-tuning universal ones. This collaborative tool promotes transparency and reproducibility in MLIP evaluations.
12 |
13 | Join us in using and improving MLIPX to advance the field of machine-learned interatomic potentials. Your contributions and feedback are invaluable.
14 |
15 | .. note::
16 |
17 | This project is under active development.
18 |
19 |
20 | Example
21 | -------
22 |
23 | Create a ``mlipx`` :ref:`recipe ` to compute :ref:`ev` for the `mp-1143 `_ structure using different :term:`MLIP` models
24 |
25 | .. code-block:: console
26 |
27 | (.venv) $ mlipx recipes ev --models mace-mpa-0,sevennet,orb-v2 --material-ids=mp-1143 --repro
28 | (.venv) $ mlipx compare --glob "*EnergyVolumeCurve"
29 |
30 | and use the integration with :ref:`ZnDraw ` to visualize the resulting trajectories and compare the energies interactively.
31 |
32 | .. image:: https://github.com/user-attachments/assets/c2479d17-c443-4550-a641-c513ede3be02
33 | :width: 100%
34 | :alt: ZnDraw
35 | :class: only-light
36 |
37 | .. image:: https://github.com/user-attachments/assets/2036e6d9-3342-4542-9ddb-bbc777d2b093
38 | :width: 100%
39 | :alt: ZnDraw
40 | :class: only-dark
41 |
42 | .. toctree::
43 | :hidden:
44 | :maxdepth: 2
45 |
46 | installation
47 | quickstart
48 | concept
49 | recipes
50 | build_graph
51 | nodes
52 | glossary
53 | abc
54 | authors
55 |
--------------------------------------------------------------------------------
/docs/source/installation.rst:
--------------------------------------------------------------------------------
1 | Installation
2 | ============
3 |
4 | From PyPI
5 | ---------
6 |
7 | To use :code:`mlipx`, first install it using pip:
8 |
9 | .. code-block:: console
10 |
11 | (.venv) $ pip install mlipx
12 |
13 | .. note::
14 |
15 | The :code:`mlipx` package installation does not contain any :term:`MLIP` packages.
16 | Due to different dependencies, it is highly recommended to install your preferred :term:`MLIP` package individually into the same environment.
17 | We provide extras for the :term:`MLIP` packages included in our documentation.
18 |     You can install them as follows (not exhaustive):
19 |
20 | .. code-block:: console
21 |
22 | (.venv) $ pip install mlipx[mace]
23 | (.venv) $ pip install mlipx[orb]
24 |
25 | To get an overview of the currently available models :code:`mlipx` is familiar with, you can use the following command:
26 |
27 | .. code-block:: console
28 |
29 | (.venv) $ mlipx info
30 |
31 | .. note::
32 |
33 |     If you encounter an error like :code:`Permission denied '/var/cache/dvc'`, you might want to pin :code:`platformdirs` via :code:`pip install platformdirs==3.11.0` or :code:`pip install platformdirs==3.10.0`, as discussed at https://github.com/iterative/dvc/issues/9184
34 |
35 | From Source
36 | -----------
37 |
38 | To install and develop :code:`mlipx` from source, we recommend using :code:`uv`.
39 | More information and installation instructions can be found at https://docs.astral.sh/uv/getting-started/installation/ .
40 |
41 | .. code:: console
42 |
43 | (.venv) $ git clone https://github.com/basf/mlipx
44 | (.venv) $ cd mlipx
45 | (.venv) $ uv sync
46 | (.venv) $ source .venv/bin/activate
47 |
48 | You can quickly switch between the extras of different :term:`MLIP` packages using the :code:`uv sync` command.
49 |
50 |
51 | .. code:: console
52 |
53 | (.venv) $ uv sync --extra mattersim
54 | (.venv) $ uv sync --extra sevenn
55 |
--------------------------------------------------------------------------------
/docs/source/nodes.rst:
--------------------------------------------------------------------------------
1 | Nodes
2 | =====
3 |
4 | The core functionality of :code:`mlipx` is based on :term:`ZnTrack` nodes.
5 | Each Node is documented in the following section.
6 |
7 | .. glossary::
8 |
9 | LoadDataFile
10 | .. autofunction:: mlipx.LoadDataFile
11 |
12 | LangevinConfig
13 | .. autofunction:: mlipx.LangevinConfig
14 |
15 | ApplyCalculator
16 | .. autofunction:: mlipx.ApplyCalculator
17 |
18 | EvaluateCalculatorResults
19 | .. autofunction:: mlipx.EvaluateCalculatorResults
20 |
21 | CompareCalculatorResults
22 | .. autofunction:: mlipx.CompareCalculatorResults
23 |
24 | CompareFormationEnergy
25 | .. autofunction:: mlipx.CompareFormationEnergy
26 |
27 | CalculateFormationEnergy
28 | .. autofunction:: mlipx.CalculateFormationEnergy
29 |
30 | MaximumForceObserver
31 | .. autofunction:: mlipx.MaximumForceObserver
32 |
33 | TemperatureRampModifier
34 | .. autofunction:: mlipx.TemperatureRampModifier
35 |
36 | MolecularDynamics
37 | .. autofunction:: mlipx.MolecularDynamics
38 |
39 | HomonuclearDiatomics
40 | .. autofunction:: mlipx.HomonuclearDiatomics
41 |
42 | NEBinterpolate
43 | .. autofunction:: mlipx.NEBinterpolate
44 |
45 | NEBs
46 | .. autofunction:: mlipx.NEBs
47 |
48 | VibrationalAnalysis
49 | .. autofunction:: mlipx.VibrationalAnalysis
50 |
51 | PhaseDiagram
52 | .. autofunction:: mlipx.PhaseDiagram
53 |
54 | PourbaixDiagram
55 | .. autofunction:: mlipx.PourbaixDiagram
56 |
57 | StructureOptimization
58 | .. autofunction:: mlipx.StructureOptimization
59 |
60 | Smiles2Conformers
61 | .. autofunction:: mlipx.Smiles2Conformers
62 |
63 | BuildBox
64 | .. autofunction:: mlipx.BuildBox
65 |
66 | EnergyVolumeCurve
67 | .. autofunction:: mlipx.EnergyVolumeCurve
68 |
69 | FilterAtoms
70 | .. autofunction:: mlipx.FilterAtoms
71 |
72 | RotationalInvariance
73 | .. autofunction:: mlipx.RotationalInvariance
74 |
75 | TranslationalInvariance
76 | .. autofunction:: mlipx.TranslationalInvariance
77 |
78 | PermutationInvariance
79 | .. autofunction:: mlipx.PermutationInvariance
80 |
81 | RelaxAdsorptionConfigs
82 | .. autofunction:: mlipx.RelaxAdsorptionConfigs
83 |
84 | BuildASEslab
85 | .. autofunction:: mlipx.BuildASEslab
86 |
--------------------------------------------------------------------------------
/docs/source/notebooks/structure_relaxation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "[](https://colab.research.google.com/github/basf/mlipx/blob/main/docs/source/notebooks/structure_relaxation.ipynb)\n",
8 | "\n",
9 |     "# Structure Relaxation with Custom Nodes\n",
10 | "\n",
11 | "You can combine `mlipx` with custom code by writing ZnTrack nodes.\n",
12 | "We will write a Node to perform a geometry relaxation similar to `mlipx.StructureOptimization`."
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": null,
18 | "metadata": {},
19 | "outputs": [],
20 | "source": [
21 | "# Only install the packages if they are not already installed\n",
22 | "!pip show mlipx > /dev/null 2>&1 || pip install mlipx\n",
23 | "!pip show rdkit2ase > /dev/null 2>&1 || pip install rdkit2ase"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": null,
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "# We will create a GIT and DVC repository in a temporary directory\n",
33 | "import os\n",
34 | "import tempfile\n",
35 | "\n",
36 | "temp_dir = tempfile.TemporaryDirectory()\n",
37 | "os.chdir(temp_dir.name)"
38 | ]
39 | },
40 | {
41 | "cell_type": "markdown",
42 | "metadata": {},
43 | "source": [
44 | "Like all `mlipx` Nodes we will use a GIT and DVC repository to run experiments.\n",
45 | "To make our custom code available, we structure our project like\n",
46 | "\n",
47 | "```\n",
48 | "relaxation/\n",
49 | " ├── .git/\n",
50 | " ├── .dvc/\n",
51 | " ├── src/__init__.py\n",
52 | " ├── src/relaxation.py\n",
53 | " ├── models.py\n",
54 | " └── main.py\n",
55 | "```\n",
56 | "\n",
57 | "to allow us to import our code `from src.relaxation import Relax`.\n",
58 | "Alternatively, you can package your code and import it like any other Python package."
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": null,
64 | "metadata": {},
65 | "outputs": [],
66 | "source": [
67 | "!git init\n",
68 | "!dvc init --quiet\n",
69 | "!mkdir src\n",
70 | "!touch src/__init__.py"
71 | ]
72 | },
73 | {
74 | "cell_type": "markdown",
75 | "metadata": {},
76 | "source": [
77 | "The code we want to put into our `Relax` `Node` is the following:\n",
78 | "\n",
79 | "\n",
80 | "```python\n",
81 | "from ase.optimize import BFGS\n",
82 | "import ase.io\n",
83 | "\n",
84 | "data: list[ase.Atoms]\n",
85 |     "calc: ase.calculators.calculator.Calculator\n",
86 | "\n",
87 | "end_structures = []\n",
88 | "for atoms in data:\n",
89 |     "    atoms.calc = calc\n",
90 | " opt = BFGS(atoms)\n",
91 | " opt.run(fmax=0.05)\n",
92 | " end_structures.append(atoms)\n",
93 | "\n",
94 | "ase.io.write('end_structures.xyz', end_structures)\n",
95 | "```\n",
96 | "\n",
97 | "To do so, we need to identify and define the inputs and outputs of our code.\n",
98 | "We provide the `data: list[ase.Atoms]` from a data loading Node.\n",
99 | "Therefore, we use `data: list = zntrack.deps()`.\n",
100 | "If you want to read directly from file you could use `data_path: str = zntrack.deps_path()`.\n",
101 | "We access the calculator in a similar way using `model: NodeWithCalculator = zntrack.deps()`.\n",
102 |     "`mlipx` provides the `NodeWithCalculator` abstract base class as a common interface for sharing `ASE` calculators.\n",
103 | "Another convention is providing inputs as `data: list[ase.Atoms]` and outputs as `frames: list[ase.Atoms]`.\n",
104 |     "As we save our data to a file, we define `frames_path: pathlib.Path = zntrack.outs_path(zntrack.nwd / 'frames.xyz')` to store the output trajectory in the node working directory (nwd) as `frames.xyz`.\n",
105 | "The `zntrack.nwd` provides a unique directory per `Node` to store the data at.\n",
106 | "As the communication between `mlipx` nodes is based on `ase.Atoms` we define a `@frames` property.\n",
107 |     "Within this property, we could also alter the `ase.Atoms` objects, making the node communication independent of the file format and facilitating data exchange via code, i.e. Data as Code (DaC).\n",
108 | "To summarize, each Node provides all the information on how to `save` and `load` the produced data, simplifying communication and reducing issues with different file format conventions.\n",
109 | "\n",
110 | "Besides the implemented fields, there are also `x: dict = zntrack.params`, `x: dict = zntrack.metrics` and `x: pd.DataFrame = zntrack.plots` and their corresponding file path versions `x: str|pathlib.Path = zntrack.params_path`, `zntrack.metrics_path` and `zntrack.plots_path`.\n",
111 | "For general outputs there is `x: any = zntrack.outs`. More information can be found at https://dvc.org/doc/start/data-pipelines/metrics-parameters-plots ."
112 | ]
113 | },
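  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small sketch (the names and values are illustrative), such fields could be declared on a Node like this:\n",
    "\n",
    "```python\n",
    "import zntrack\n",
    "\n",
    "\n",
    "class Example(zntrack.Node):\n",
    "    n: int = zntrack.params(10)        # tracked input parameter\n",
    "    values: list = zntrack.outs()      # generic output handled by ZnTrack\n",
    "    metrics: dict = zntrack.metrics()  # small metrics written to a JSON file\n",
    "\n",
    "    def run(self):\n",
    "        self.values = list(range(self.n))\n",
    "        self.metrics = {\"mean\": sum(self.values) / len(self.values)}\n",
    "```"
   ]
  },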
114 | {
115 | "cell_type": "code",
116 | "execution_count": null,
117 | "metadata": {},
118 | "outputs": [],
119 | "source": [
120 | "%%writefile src/relaxation.py\n",
121 | "import zntrack\n",
122 | "from mlipx.abc import NodeWithCalculator\n",
123 | "from ase.optimize import BFGS\n",
124 | "import ase.io\n",
125 | "import pathlib\n",
126 | "\n",
127 | "\n",
128 | "\n",
129 | "class Relax(zntrack.Node):\n",
130 | " data: list = zntrack.deps()\n",
131 | " model: NodeWithCalculator = zntrack.deps()\n",
132 | " frames_path: pathlib.Path = zntrack.outs_path(zntrack.nwd / 'frames.xyz')\n",
133 | "\n",
134 | " def run(self):\n",
135 | " end_structures = []\n",
136 | " for atoms in self.data:\n",
137 |     "            atoms.calc = self.model.get_calculator()\n",
138 | " opt = BFGS(atoms)\n",
139 | " opt.run(fmax=0.05)\n",
140 | " end_structures.append(atoms)\n",
141 | " with open(self.frames_path, 'w') as f:\n",
142 | " ase.io.write(f, end_structures, format='extxyz')\n",
143 | " \n",
144 | " @property\n",
145 | " def frames(self) -> list[ase.Atoms]:\n",
146 | " with self.state.fs.open(self.frames_path, \"r\") as f:\n",
147 | " return ase.io.read(f, format='extxyz', index=':')\n"
148 | ]
149 | },
150 | {
151 | "cell_type": "markdown",
152 | "metadata": {},
153 | "source": [
154 | "With this Node definition, we can import the Node and connect it with `mlipx` to form a graph."
155 | ]
156 | },
157 | {
158 | "cell_type": "code",
159 | "execution_count": null,
160 | "metadata": {},
161 | "outputs": [],
162 | "source": [
163 | "from src.relaxation import Relax\n",
164 | "\n",
165 | "import mlipx"
166 | ]
167 | },
168 | {
169 | "cell_type": "code",
170 | "execution_count": null,
171 | "metadata": {},
172 | "outputs": [],
173 | "source": [
174 | "project = mlipx.Project()\n",
175 | "\n",
176 | "emt = mlipx.GenericASECalculator(\n",
177 | " module=\"ase.calculators.emt\",\n",
178 | " class_name=\"EMT\",\n",
179 | ")\n",
180 | "\n",
181 | "with project:\n",
182 | " confs = mlipx.Smiles2Conformers(smiles=\"CCCC\", num_confs=5)\n",
183 | " relax = Relax(data=confs.frames, model=emt)\n",
184 | "\n",
185 | "project.build()"
186 | ]
187 | },
188 | {
189 | "cell_type": "markdown",
190 | "metadata": {},
191 | "source": [
192 | "To execute the graph, we make use of `dvc repro` via `project.repro`."
193 | ]
194 | },
195 | {
196 | "cell_type": "code",
197 | "execution_count": null,
198 | "metadata": {},
199 | "outputs": [],
200 | "source": [
201 | "project.repro(build=False)"
202 | ]
203 | },
204 | {
205 | "cell_type": "markdown",
206 | "metadata": {},
207 | "source": [
208 | "Once the graph has been executed, we can look at the resulting structures."
209 | ]
210 | },
211 | {
212 | "cell_type": "code",
213 | "execution_count": null,
214 | "metadata": {},
215 | "outputs": [],
216 | "source": [
217 | "relax.frames"
218 | ]
219 | },
220 | {
221 | "cell_type": "code",
222 | "execution_count": null,
223 | "metadata": {},
224 | "outputs": [],
225 | "source": [
226 | "temp_dir.cleanup()"
227 | ]
228 | }
229 | ],
230 | "metadata": {
231 | "kernelspec": {
232 | "display_name": "Python 3",
233 | "language": "python",
234 | "name": "python3"
235 | },
236 | "language_info": {
237 | "codemirror_mode": {
238 | "name": "ipython",
239 | "version": 3
240 | },
241 | "file_extension": ".py",
242 | "mimetype": "text/x-python",
243 | "name": "python",
244 | "nbconvert_exporter": "python",
245 | "pygments_lexer": "ipython3",
246 | "version": "3.10.0"
247 | }
248 | },
249 | "nbformat": 4,
250 | "nbformat_minor": 2
251 | }
252 |
--------------------------------------------------------------------------------
/docs/source/quickstart.rst:
--------------------------------------------------------------------------------
1 | Quickstart
2 | ==========
3 |
4 | There are two ways to use the library: as a :ref:`command-line tool <cli-quickstart>` or as a :ref:`Python library <python-quickstart>`.
5 | The CLI provides the most convenient way to get started, while the Python library offers more flexibility for advanced workflows.
6 |
7 | .. image:: https://github.com/user-attachments/assets/ab38546b-6f5f-4c7c-9274-f7d2e9e1ae73
8 | :width: 100%
9 | :class: only-light
10 |
11 | .. image:: https://github.com/user-attachments/assets/c34f64f7-958a-47cc-88ab-d2689e82deaf
12 | :width: 100%
13 | :class: only-dark
14 |
15 | Use the :ref:`command-line tool <cli-quickstart>` to evaluate different :term:`MLIP` models on the ``DODH_adsorption_dft.xyz`` file and
16 | visualize the trajectory together with the maximum force error in :ref:`ZnDraw `.
17 |
18 | .. code:: console
19 |
20 |     (.venv) $ mlipx recipes metrics --models mace-mpa-0,sevennet,orb-v2 --datapath ../data/DODH_adsorption_dft.xyz --repro
21 | (.venv) $ mlipx compare --glob "*CompareCalculatorResults"
22 |
23 | .. toctree::
24 | :maxdepth: 2
25 | :hidden:
26 |
27 | quickstart/cli
28 | quickstart/python
29 |
--------------------------------------------------------------------------------
/docs/source/quickstart/cli.rst:
--------------------------------------------------------------------------------
1 | .. _cli-quickstart:
2 |
3 | Command Line Interface
4 | ======================
5 |
6 | This guide will help you get started with ``mlipx`` by creating a new project in an empty directory and computing metrics for a machine-learned interatomic potential (:term:`MLIP`) against reference DFT data.
7 |
8 | First, create a new project directory and initialize it with Git and DVC:
9 |
10 | .. code-block:: console
11 |
12 | (.venv) $ mkdir my_project
13 | (.venv) $ cd my_project
14 | (.venv) $ git init
15 | (.venv) $ dvc init
16 |
17 | Adding Reference Data
18 | ----------------------
19 | Next, add a reference DFT dataset to the project. For this example, we use a slice from the mptraj dataset :footcite:`dengCHGNetPretrainedUniversal2023`.
20 |
21 | .. note::
22 |
23 | If you have your own data, replace this file with any dataset that can be read by ``ase.io.read`` and includes reference energies and forces. Run the following command instead:
24 |
25 | .. code-block:: bash
26 |
27 | (.venv) $ cp /path/to/your/data.xyz data.xyz
28 | (.venv) $ dvc add data.xyz
29 |
30 | .. code-block:: console
31 |
32 | (.venv) $ dvc import-url https://github.com/zincware/ips-mace/releases/download/v0.1.0/mptraj_slice.xyz data.xyz
33 |
34 | Adding the Recipe
35 | -----------------
36 | With the reference data in place, add a ``mlipx`` recipe to compute metrics:
37 |
38 | .. code-block:: console
39 |
40 | (.venv) $ mlipx recipes metrics --datapath data.xyz
41 |
42 | This command generates a ``main.py`` file in the current directory, which defines the workflow for the recipe.
43 |
44 | Defining Models
45 | ---------------
46 | Define the models to evaluate. This example uses the MACE-MP-0 model :footcite:`batatiaFoundationModelAtomistic2023`, which is provided by the ``mace-torch`` package.
47 |
48 | Create a file named ``models.py`` in the current directory with the following content:
49 |
50 |
51 | .. note::
52 |
53 | If you already have computed energies and forces you can use two different data files or one file and update the keys.
54 | For more information, see the section on :ref:`update-frames-calc`.
55 |
56 | .. code-block:: python
57 |
58 | import mlipx
59 |
60 | mace_mp = mlipx.GenericASECalculator(
61 | module="mace.calculators",
62 | class_name="mace_mp",
63 | device="auto",
64 | kwargs={
65 | "model": "medium",
66 | },
67 | )
68 |
69 | MODELS = {"mace-mp": mace_mp}
70 |
71 | .. note::
72 |
73 | The ``GenericASECalculator`` class passes any provided ``kwargs`` to the specified ``class_name``.
74 | A special case is the ``device`` argument.
75 | When set to ``auto``, the class uses ``torch.cuda.is_available()`` to check if a GPU is available and automatically selects it if possible.
76 | If you are not using ``torch`` or wish to specify a device explicitly, you can omit the ``device`` argument or define it directly in the ``kwargs``.
77 |
78 |
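For example, to pin the calculation to the CPU, you can omit the ``device`` field and forward the setting through ``kwargs`` instead. This is a sketch; the keyword is simply handed through to ``mace_mp``:

.. code-block:: python

    import mlipx

    mace_mp_cpu = mlipx.GenericASECalculator(
        module="mace.calculators",
        class_name="mace_mp",
        kwargs={
            "model": "medium",
            "device": "cpu",  # passed straight through to mace_mp
        },
    )

    MODELS = {"mace-mp-cpu": mace_mp_cpu}
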
79 | Running the Workflow
80 | ---------------------
81 | Now, run the workflow using the following commands:
82 |
83 | .. code-block:: console
84 |
85 | (.venv) $ python main.py
86 | (.venv) $ dvc repro
87 |
88 | Listing Steps and Visualizing Results
89 | -------------------------------------
90 | To explore the available steps and visualize results, use the commands below:
91 |
92 | .. code-block:: console
93 |
94 | (.venv) $ zntrack list
95 | (.venv) $ mlipx compare mace-mp_CompareCalculatorResults
96 |
97 | .. note::
98 |
99 | To use ``mlipx compare``, you must have an active ZnDraw server running. Provide the server URL via the ``--zndraw-url`` argument or the ``ZNDRAW_URL`` environment variable.
100 |
101 | You can start a server locally with the command ``zndraw`` in a separate terminal or use the public server at https://zndraw.icp.uni-stuttgart.de.
102 |
103 |
104 | More CLI Options
105 | ----------------
106 |
107 | The ``mlipx`` CLI can create the :term:`models.py` file for some models.
108 | To evaluate ``data.xyz`` with multiple models, you can also run
109 |
110 | .. code-block:: console
111 |
112 | (.venv) $ mlipx recipes metrics --datapath data.xyz --models mace-mpa-0,sevennet,orb-v2,chgnet --repro
113 |
114 | .. note::
115 |
116 | Want to see your model here? Open an issue or submit a pull request to the `mlipx repository `_.
117 |
118 |
119 | .. footbibliography::
120 |
--------------------------------------------------------------------------------
/docs/source/quickstart/python.rst:
--------------------------------------------------------------------------------
1 | .. _python-quickstart:
2 |
3 | Python Interface
4 | ================
5 |
6 | In the :ref:`cli-quickstart` guide, we demonstrated how to compute metrics for an MLIP against reference DFT data using the CLI.
7 | This guide shows how to achieve the same result using the Python interface.
8 |
9 | Getting Started
10 | ---------------
11 |
12 | First, create a new project directory and initialize it with Git and DVC, as shown below:
13 |
14 | .. code-block:: console
15 |
16 | (.venv) $ mkdir my_project
17 | (.venv) $ cd my_project
18 | (.venv) $ git init
19 | (.venv) $ dvc init
20 |
21 | Adding Reference Data
22 | ----------------------
23 |
24 | Create a new Python file named ``main.py`` in the project directory, and add the following code to download the reference dataset:
25 |
26 | .. code-block:: python
27 |
28 | import mlipx
29 | import zntrack
30 |
31 | mptraj = zntrack.add(
32 | url="https://github.com/zincware/ips-mace/releases/download/v0.1.0/mptraj_slice.xyz",
33 | path="data.xyz",
34 | )
35 |
36 | This will download the reference data file ``mptraj_slice.xyz`` into your project directory.
37 |
38 | Defining Models
39 | ---------------
40 |
41 | Define the MLIP models to evaluate by adding the following code to the ``main.py`` file:
42 |
43 | .. code-block:: python
44 |
45 | mace_mp = mlipx.GenericASECalculator(
46 | module="mace.calculators",
47 | class_name="mace_mp",
48 | device="auto",
49 | kwargs={
50 | "model": "medium",
51 | },
52 | )
53 |
54 | Adding the Recipe
55 | -----------------
56 |
57 | Next, set up the recipe to compute metrics for the MLIP. Add the following code to the ``main.py`` file:
58 |
59 | .. code-block:: python
60 |
61 | project = mlipx.Project()
62 |
63 | with project.group("reference"):
64 | data = mlipx.LoadDataFile(path=mptraj)
65 | ref_evaluation = mlipx.EvaluateCalculatorResults(data=data.frames)
66 |
67 | with project.group("mace-mp"):
68 | updated_data = mlipx.ApplyCalculator(data=data.frames, model=mace_mp)
69 | evaluation = mlipx.EvaluateCalculatorResults(data=updated_data.frames)
70 | mlipx.CompareCalculatorResults(data=evaluation, reference=ref_evaluation)
71 |
72 | project.repro()
73 |
74 | Running the Workflow
75 | ---------------------
76 |
77 | Finally, run the workflow by executing the ``main.py`` file:
78 |
79 | .. code-block:: console
80 |
81 | (.venv) $ python main.py
82 |
83 | .. note::
84 |
85 | If you want to execute the workflow using ``dvc repro``, replace ``project.repro()`` with ``project.build()`` in the ``main.py`` file.
86 |
87 | This will compute the metrics for the MLIP against the reference DFT data.
88 |
89 | Listing Steps and Visualizing Results
90 | -------------------------------------
91 |
92 | As with the CLI approach, you can list the available steps and visualize results using the following commands:
93 |
94 | .. code-block:: console
95 |
96 | (.venv) $ zntrack list
97 | (.venv) $ mlipx compare mace-mp_CompareCalculatorResults
98 |
99 | Alternatively, you can load the results for this and any other Node directly into a Python kernel using the following code:
100 |
101 | .. code-block:: python
102 |
103 | import zntrack
104 |
105 | node = zntrack.from_rev("mace-mp_CompareCalculatorResults")
106 | print(node.figures)
107 |     # {"fmax_error": plotly.graph_objects.Figure(), ...}
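
Each entry is a regular Plotly figure, so you can display or export it directly. A minimal sketch, assuming the ``fmax_error`` key shown above is present:

.. code-block:: python

    fig = node.figures["fmax_error"]
    fig.show()                         # open the figure interactively
    fig.write_html("fmax_error.html")  # or export it for sharing
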
108 |
--------------------------------------------------------------------------------
/docs/source/recipes.rst:
--------------------------------------------------------------------------------
1 | Recipes
2 | =======
3 |
4 | The following recipes are currently available within :code:`mlipx`.
5 |
6 | .. toctree::
7 | :maxdepth: 1
8 |
9 | recipes/energy_and_forces
10 | recipes/homonuclear_diatomics
11 | recipes/energy_volume
12 | recipes/invariances
13 | recipes/relax
14 | recipes/md
15 | recipes/neb
16 | recipes/adsorption
17 | recipes/phase_diagram
18 | recipes/pourbaix_diagram
19 | recipes/vibrational_analysis
20 |
--------------------------------------------------------------------------------
/docs/source/recipes/adsorption.rst:
--------------------------------------------------------------------------------
1 | .. _adsorption:
2 |
3 | Adsorption Energies
4 | ===================
5 |
6 | This recipe calculates the adsorption energies of a molecule on a surface.
7 | The following example creates a slab of ``Cu(111)`` and calculates the adsorption energy of ethanol ``(CCO)`` on the surface.
8 |
9 | .. mdinclude:: ../../../mlipx-hub/adsorption/cu_fcc111/README.md
10 |
11 |
12 | .. jupyter-execute::
13 | :hide-code:
14 |
15 | from mlipx.doc_utils import get_plots
16 |
17 | plots = get_plots("*RelaxAdsorptionConfigs", "../../mlipx-hub/adsorption/cu_fcc111/")
18 | plots["adsorption_energies"].show()
19 |
20 | This test uses the following Nodes together with your provided model in the :term:`models.py` file:
21 |
22 | * :term:`RelaxAdsorptionConfigs`
23 | * :term:`BuildASEslab`
24 | * :term:`Smiles2Conformers`
25 |
26 |
27 | .. dropdown:: Content of :code:`main.py`
28 |
29 | .. literalinclude:: ../../../mlipx-hub/adsorption/cu_fcc111/main.py
30 | :language: Python
31 |
32 |
33 | .. dropdown:: Content of :code:`models.py`
34 |
35 | .. literalinclude:: ../../../mlipx-hub/adsorption/cu_fcc111/models.py
36 | :language: Python
37 |
--------------------------------------------------------------------------------
/docs/source/recipes/energy_and_forces.rst:
--------------------------------------------------------------------------------
1 | Energy and Force Evaluation
2 | ===========================
3 |
4 | This recipe is used to test the performance of different models in predicting the energy and forces for a given dataset.
5 |
6 | .. mdinclude:: ../../../mlipx-hub/metrics/DODH_adsorption/README.md
7 |
8 | .. mermaid::
9 | :align: center
10 |
11 | graph TD
12 |
13 | data['Reference Data incl. DFT E/F']
14 | data --> CalculateFormationEnergy1
15 | data --> CalculateFormationEnergy2
16 | data --> CalculateFormationEnergy3
17 | data --> CalculateFormationEnergy4
18 |
19 | subgraph Reference
20 | CalculateFormationEnergy1 --> EvaluateCalculatorResults1
21 | end
22 |
23 | subgraph mg1["Model 1"]
24 | CalculateFormationEnergy2 --> EvaluateCalculatorResults2
25 | EvaluateCalculatorResults2 --> CompareCalculatorResults2
26 | EvaluateCalculatorResults1 --> CompareCalculatorResults2
27 | end
28 | subgraph mg2["Model 2"]
29 | CalculateFormationEnergy3 --> EvaluateCalculatorResults3
30 | EvaluateCalculatorResults3 --> CompareCalculatorResults3
31 | EvaluateCalculatorResults1 --> CompareCalculatorResults3
32 | end
33 | subgraph mgn["Model N"]
34 | CalculateFormationEnergy4 --> EvaluateCalculatorResults4
35 | EvaluateCalculatorResults4 --> CompareCalculatorResults4
36 | EvaluateCalculatorResults1 --> CompareCalculatorResults4
37 | end
38 |
39 |
40 | .. code-block:: console
41 |
42 | (.venv) $ mlipx compare --glob "*CompareCalculatorResults"
43 |
44 | .. jupyter-execute::
45 | :hide-code:
46 |
47 | from mlipx.doc_utils import get_plots
48 |
49 | plots = get_plots("*CompareCalculatorResults", "../../mlipx-hub/metrics/DODH_adsorption/")
51 | plots["fmax_error"].show()
52 | plots["adjusted_energy_error_per_atom"].show()
53 |
54 |
55 | This recipe uses the following Nodes together with your provided model in the :term:`models.py` file:
56 |
57 | * :term:`ApplyCalculator`
58 | * :term:`EvaluateCalculatorResults`
59 | * :term:`CalculateFormationEnergy`
60 | * :term:`CompareCalculatorResults`
61 | * :term:`CompareFormationEnergy`
62 |
63 |
64 | .. dropdown:: Content of :code:`main.py`
65 |
66 | .. literalinclude:: ../../../mlipx-hub/metrics/DODH_adsorption/main.py
67 | :language: Python
68 |
69 |
70 | .. dropdown:: Content of :code:`models.py`
71 |
72 | .. literalinclude:: ../../../mlipx-hub/metrics/DODH_adsorption/models.py
73 | :language: Python
74 |
--------------------------------------------------------------------------------
/docs/source/recipes/energy_volume.rst:
--------------------------------------------------------------------------------
1 | .. _ev:
2 |
3 | Energy Volume Curves
4 | ====================
5 | Compute the energy-volume curve for a given material using multiple models.
6 |
7 | .. mdinclude:: ../../../mlipx-hub/energy-volume/mp-1143/README.md
8 |
9 |
10 | .. jupyter-execute::
11 | :hide-code:
12 |
13 | from mlipx.doc_utils import get_plots
14 |
15 | plots = get_plots("*EnergyVolumeCurve", "../../mlipx-hub/energy-volume/mp-1143/")
16 | plots["adjusted_energy-volume-curve"].show()
17 |
18 |
19 | This recipe uses the following Nodes together with your provided model in the :term:`models.py` file:
20 |
21 | * :term:`EnergyVolumeCurve`
22 |
23 | .. dropdown:: Content of :code:`main.py`
24 |
25 | .. literalinclude:: ../../../mlipx-hub/energy-volume/mp-1143/main.py
26 | :language: Python
27 |
28 |
29 | .. dropdown:: Content of :code:`models.py`
30 |
31 | .. literalinclude:: ../../../mlipx-hub/energy-volume/mp-1143/models.py
32 | :language: Python
33 |
--------------------------------------------------------------------------------
/docs/source/recipes/homonuclear_diatomics.rst:
--------------------------------------------------------------------------------
1 | .. _homonuclear_diatomics:
2 |
3 | Homonuclear Diatomics
4 | ===========================
5 | Homonuclear diatomics give per-element information on the performance of the :term:`MLIP`.
6 |
7 |
8 | .. mdinclude:: ../../../mlipx-hub/diatomics/LiCl/README.md
9 |
10 | You can edit the elements in the :term:`main.py` file to include the elements you want to test.
11 | In the following we show the results for the :code:`Li-Li` bond for the three selected models.
12 |
13 | .. code-block:: console
14 |
15 | (.venv) $ mlipx compare --glob "*HomonuclearDiatomics"
16 |
17 |
18 | .. jupyter-execute::
19 | :hide-code:
20 |
21 | from mlipx.doc_utils import get_plots
22 |
23 | plots = get_plots("*HomonuclearDiatomics", "../../mlipx-hub/diatomics/LiCl/")
24 | plots["Li-Li bond (adjusted)"].show()
25 |
26 |
27 | This test uses the following Nodes together with your provided model in the :term:`models.py` file:
28 |
29 | * :term:`HomonuclearDiatomics`
30 |
31 | .. dropdown:: Content of :code:`main.py`
32 |
33 | .. literalinclude:: ../../../mlipx-hub/diatomics/LiCl/main.py
34 | :language: Python
35 |
36 |
37 | .. dropdown:: Content of :code:`models.py`
38 |
39 | .. literalinclude:: ../../../mlipx-hub/diatomics/LiCl/models.py
40 | :language: Python
41 |
--------------------------------------------------------------------------------
/docs/source/recipes/invariances.rst:
--------------------------------------------------------------------------------
1 | Invariances
2 | ===========
3 | Check the rotational, translational and permutational invariance of an :term:`MLIP`.
4 |
5 |
6 | .. mdinclude:: ../../../mlipx-hub/invariances/mp-1143/README.md
7 |
8 |
9 | .. jupyter-execute::
10 | :hide-code:
11 |
12 | from mlipx.doc_utils import get_plots
13 |
14 | plots = get_plots("*TranslationalInvariance", "../../mlipx-hub/invariances/mp-1143/")
15 | plots["energy_vs_steps_adjusted"].show()
16 |
17 | plots = get_plots("*RotationalInvariance", ".")
18 | plots["energy_vs_steps_adjusted"].show()
19 |
20 | plots = get_plots("*PermutationInvariance", ".")
21 | plots["energy_vs_steps_adjusted"].show()
22 |
23 |
24 | This recipe uses the following Nodes together with your provided model in the :term:`models.py` file:
25 |
26 | * :term:`RotationalInvariance`
27 | * :term:`TranslationalInvariance`
28 | * :term:`PermutationInvariance`
29 | * :term:`StructureOptimization`
29 |
30 | .. dropdown:: Content of :code:`main.py`
31 |
32 | .. literalinclude:: ../../../mlipx-hub/invariances/mp-1143/main.py
33 | :language: Python
34 |
35 |
36 | .. dropdown:: Content of :code:`models.py`
37 |
38 | .. literalinclude:: ../../../mlipx-hub/invariances/mp-1143/models.py
39 | :language: Python
40 |
--------------------------------------------------------------------------------
/docs/source/recipes/md.rst:
--------------------------------------------------------------------------------
1 | .. _md:
2 |
3 | Molecular Dynamics
4 | ==================
5 | This recipe is used to test the performance of different models in molecular dynamics simulations.
6 |
7 | .. mdinclude:: ../../../mlipx-hub/md/mp-1143/README.md
8 |
9 |
10 |
11 | .. jupyter-execute::
12 | :hide-code:
13 |
14 | from mlipx.doc_utils import get_plots
15 |
16 | plots = get_plots("*MolecularDynamics", "../../mlipx-hub/md/mp-1143/")
17 | plots["energy_vs_steps_adjusted"].show()
18 |
19 | This test uses the following Nodes together with your provided model in the :term:`models.py` file:
20 |
21 | * :term:`LangevinConfig`
22 | * :term:`MaximumForceObserver`
23 | * :term:`TemperatureRampModifier`
24 | * :term:`MolecularDynamics`
25 |
26 | .. dropdown:: Content of :code:`main.py`
27 |
28 | .. literalinclude:: ../../../mlipx-hub/md/mp-1143/main.py
29 | :language: Python
30 |
31 |
32 | .. dropdown:: Content of :code:`models.py`
33 |
34 | .. literalinclude:: ../../../mlipx-hub/md/mp-1143/models.py
35 | :language: Python
36 |
--------------------------------------------------------------------------------
/docs/source/recipes/neb.rst:
--------------------------------------------------------------------------------
1 | .. _neb:
2 |
3 | Nudged Elastic Band
4 | ===================
5 |
6 | :code:`mlipx` provides a command line interface to interpolate a NEB path from initial-final or initial-TS-final images and run NEB calculations on the interpolated images.
7 | You can run the following command to instantiate a test directory:
8 |
9 | .. mdinclude:: ../../../mlipx-hub/neb/ex01/README.md
10 |
11 |
12 | .. jupyter-execute::
13 | :hide-code:
14 |
15 | from mlipx.doc_utils import get_plots
16 |
17 | plots = get_plots("*NEBs", "../../mlipx-hub/neb/ex01/")
18 | plots["adjusted_energy_vs_neb_image"].show()
19 |
20 | This test uses the following Nodes together with your provided model in the :term:`models.py` file:
21 |
22 | * :term:`NEBinterpolate`
23 | * :term:`NEBs`
24 |
25 |
26 | .. dropdown:: Content of :code:`main.py`
27 |
28 | .. literalinclude:: ../../../mlipx-hub/neb/ex01/main.py
29 | :language: Python
30 |
31 |
32 | .. dropdown:: Content of :code:`models.py`
33 |
34 | .. literalinclude:: ../../../mlipx-hub/neb/ex01/models.py
35 | :language: Python
36 |
--------------------------------------------------------------------------------
/docs/source/recipes/phase_diagram.rst:
--------------------------------------------------------------------------------
1 | Phase Diagram
2 | =============
3 |
4 | :code:`mlipx` provides a command line interface to generate phase diagrams.
5 | You can run the following command to instantiate a test directory:
6 |
7 | .. mdinclude:: ../../../mlipx-hub/phase_diagram/mp-30084/README.md
8 |
9 |
10 | .. jupyter-execute::
11 | :hide-code:
12 |
13 | from mlipx.doc_utils import get_plots
14 |
15 | plots = get_plots("*PhaseDiagram", "../../mlipx-hub/phase_diagram/mp-30084/")
16 | for name, plot in plots.items():
17 | if name.endswith("phase-diagram"):
18 | plot.show()
19 |
20 | This test uses the following Nodes together with your provided model in the :term:`models.py` file:
21 |
22 | * :term:`PhaseDiagram`
23 |
24 | .. dropdown:: Content of :code:`main.py`
25 |
26 | .. literalinclude:: ../../../mlipx-hub/phase_diagram/mp-30084/main.py
27 | :language: Python
28 |
29 |
30 | .. dropdown:: Content of :code:`models.py`
31 |
32 | .. literalinclude:: ../../../mlipx-hub/phase_diagram/mp-30084/models.py
33 | :language: Python
34 |
--------------------------------------------------------------------------------
/docs/source/recipes/pourbaix_diagram.rst:
--------------------------------------------------------------------------------
1 | Pourbaix Diagram
2 | ================
3 |
4 | :code:`mlipx` provides a command line interface to generate Pourbaix diagrams.
5 |
6 |
7 | .. mdinclude:: ../../../mlipx-hub/pourbaix_diagram/mp-1143/README.md
8 |
9 |
10 | .. jupyter-execute::
11 | :hide-code:
12 |
13 | from mlipx.doc_utils import get_plots
14 |
15 | plots = get_plots("*PourbaixDiagram", "../../mlipx-hub/pourbaix_diagram/mp-1143/")
16 | for name, plot in plots.items():
17 | if name.endswith("pourbaix-diagram"):
18 | plot.show()
19 |
20 | This test uses the following Nodes together with your provided model in the :term:`models.py` file:
21 |
22 | * :term:`PourbaixDiagram`
23 |
24 | .. dropdown:: Content of :code:`main.py`
25 |
26 | .. literalinclude:: ../../../mlipx-hub/pourbaix_diagram/mp-1143/main.py
27 | :language: Python
28 |
29 |
30 | .. dropdown:: Content of :code:`models.py`
31 |
32 | .. literalinclude:: ../../../mlipx-hub/pourbaix_diagram/mp-1143/models.py
33 | :language: Python
34 |
--------------------------------------------------------------------------------
/docs/source/recipes/relax.rst:
--------------------------------------------------------------------------------
1 | .. _relax:
2 |
3 | Structure Relaxation
4 | ====================
5 |
6 | This recipe is used to test the performance of different models in performing structure relaxation.
7 |
8 |
9 | .. mdinclude:: ../../../mlipx-hub/relax/mp-1143/README.md
10 |
11 | .. note::
12 |
13 | If you relax a non-periodic system and your model yields a stress tensor of :code:`[inf, inf, inf, inf, inf, inf]` you have to add the :code:`--convert-nan` flag to the :code:`mlipx compare` or :code:`zndraw` command to convert them to :code:`None`.
14 |
15 | .. jupyter-execute::
16 | :hide-code:
17 |
18 | from mlipx.doc_utils import get_plots
19 |
20 | plots = get_plots("*StructureOptimization", "../../mlipx-hub/relax/mp-1143/")
21 | plots["adjusted_energy_vs_steps"].show()
22 |
23 | This recipe uses the following Nodes together with your provided model in the :term:`models.py` file:
24 |
25 | * :term:`StructureOptimization`
26 |
27 | .. dropdown:: Content of :code:`main.py`
28 |
29 | .. literalinclude:: ../../../mlipx-hub/relax/mp-1143/main.py
30 | :language: Python
31 |
32 |
33 | .. dropdown:: Content of :code:`models.py`
34 |
35 | .. literalinclude:: ../../../mlipx-hub/relax/mp-1143/models.py
36 | :language: Python
37 |
--------------------------------------------------------------------------------
/docs/source/recipes/vibrational_analysis.rst:
--------------------------------------------------------------------------------
1 | Vibrational Analysis
2 | ====================
3 |
4 | :code:`mlipx` provides a command line interface for vibrational analysis.
5 | You can run the following command to instantiate a test directory:
6 |
7 | .. mdinclude:: ../../../mlipx-hub/vibrational_analysis/CxO/README.md
8 |
9 |
10 | The vibrational analysis method needs additional information to run.
11 | Please edit the ``main.py`` file and set the ``system`` parameter on the ``VibrationalAnalysis`` node.
12 | For the given list of SMILES, you should set it to ``"molecule"``.
13 | Then run the following commands to reproduce and inspect the results:
14 |
15 | .. code-block:: console
16 |
17 | (.venv) $ python main.py
18 | (.venv) $ dvc repro
19 | (.venv) $ mlipx compare --glob "*VibrationalAnalysis"
20 |
21 |
22 | .. jupyter-execute::
23 | :hide-code:
24 |
25 | from mlipx.doc_utils import get_plots
26 |
27 | plots = get_plots("*VibrationalAnalysis", "../../mlipx-hub/vibrational_analysis/CxO/")
28 | plots["Gibbs-Comparison"].show()
29 |
30 | This test uses the following Nodes together with your provided model in the :term:`models.py` file:
31 |
32 | * :term:`VibrationalAnalysis`
33 |
34 | .. dropdown:: Content of :code:`main.py`
35 |
36 | .. literalinclude:: ../../../mlipx-hub/vibrational_analysis/CxO/main.py
37 | :language: Python
38 |
39 |
40 | .. dropdown:: Content of :code:`models.py`
41 |
42 | .. literalinclude:: ../../../mlipx-hub/vibrational_analysis/CxO/models.py
43 | :language: Python
44 |
--------------------------------------------------------------------------------
/docs/source/references.bib:
--------------------------------------------------------------------------------
1 | @misc{zillsZnTrackDataCode2024,
2 | title = {{{ZnTrack}} -- {{Data}} as {{Code}}},
3 | author = {Zills, Fabian and Sch{\"a}fer, Moritz and Tovey, Samuel and K{\"a}stner, Johannes and Holm, Christian},
4 | year = {2024},
5 | eprint={2401.10603},
6 | archivePrefix={arXiv},
7 | }
8 | @article{zillsCollaborationMachineLearnedPotentials2024,
9 | title = {Collaboration on {{Machine-Learned Potentials}} with {{IPSuite}}: {{A Modular Framework}} for {{Learning-on-the-Fly}}},
10 | shorttitle = {Collaboration on {{Machine-Learned Potentials}} with {{IPSuite}}},
11 | author = {Zills, Fabian and Schäfer, Moritz René and Segreto, Nico and Kästner, Johannes and Holm, Christian and Tovey, Samuel},
12 | year = 2024,
13 | journal = {J. Phys. Chem. B},
14 | publisher = {American Chemical Society},
15 | issn = {1520-6106},
16 | doi = {10.1021/acs.jpcb.3c07187},
17 | }
18 | @misc{elijosiusZeroShotMolecular2024,
19 | title = {Zero {{Shot Molecular Generation}} via {{Similarity Kernels}}},
20 | author = {Elijo{\v s}ius, Rokas and Zills, Fabian and Batatia, Ilyes and Norwood, Sam Walton and Kov{\'a}cs, D{\'a}vid P{\'e}ter and Holm, Christian and Cs{\'a}nyi, G{\'a}bor},
21 | year = {2024},
22 | eprint = {2402.08708},
23 | archiveprefix = {arxiv},
24 | }
25 | @article{dengCHGNetPretrainedUniversal2023,
26 | title = {{{CHGNet}} as a Pretrained Universal Neural Network Potential for Charge-Informed Atomistic Modelling},
27 | author = {Deng, Bowen and Zhong, Peichen and Jun, KyuJung and Riebesell, Janosh and Han, Kevin and Bartel, Christopher J. and Ceder, Gerbrand},
28 | journal = {Nat Mach Intell},
29 | year = {2023},
30 | }
31 |
32 | @online{batatiaFoundationModelAtomistic2023,
33 | title = {A Foundation Model for Atomistic Materials Chemistry},
34 | author = {Batatia et al., Ilyes},
35 | year = {2023},
36 | eprint = {2401.00096},
37 | archiveprefix = {arxiv},
38 | }
39 |
--------------------------------------------------------------------------------
/mlipx/__init__.py:
--------------------------------------------------------------------------------
1 | import lazy_loader as lazy
2 |
3 | __getattr__, __dir__, __all__ = lazy.attach_stub(__name__, __file__)
4 |
--------------------------------------------------------------------------------
/mlipx/__init__.pyi:
--------------------------------------------------------------------------------
1 | from . import abc
2 | from .nodes.adsorption import BuildASEslab, RelaxAdsorptionConfigs
3 | from .nodes.apply_calculator import ApplyCalculator
4 | from .nodes.compare_calculator import CompareCalculatorResults
5 | from .nodes.diatomics import HomonuclearDiatomics
6 | from .nodes.energy_volume import EnergyVolumeCurve
7 | from .nodes.evaluate_calculator import EvaluateCalculatorResults
8 | from .nodes.filter_dataset import FilterAtoms
9 | from .nodes.formation_energy import CalculateFormationEnergy, CompareFormationEnergy
10 | from .nodes.generic_ase import GenericASECalculator
11 | from .nodes.invariances import (
12 | PermutationInvariance,
13 | RotationalInvariance,
14 | TranslationalInvariance,
15 | )
16 | from .nodes.io import LoadDataFile
17 | from .nodes.modifier import TemperatureRampModifier
18 | from .nodes.molecular_dynamics import LangevinConfig, MolecularDynamics
19 | from .nodes.mp_api import MPRester
20 | from .nodes.nebs import NEBinterpolate, NEBs
21 | from .nodes.observer import MaximumForceObserver
22 | from .nodes.orca import OrcaSinglePoint
23 | from .nodes.phase_diagram import PhaseDiagram
24 | from .nodes.pourbaix_diagram import PourbaixDiagram
25 | from .nodes.rattle import Rattle
26 | from .nodes.smiles import BuildBox, Smiles2Conformers
27 | from .nodes.structure_optimization import StructureOptimization
28 | from .nodes.updated_frames import UpdateFramesCalc
29 | from .nodes.vibrational_analysis import VibrationalAnalysis
30 | from .project import Project
31 | from .version import __version__
32 |
33 | __all__ = [
34 | "abc",
35 | "StructureOptimization",
36 | "LoadDataFile",
37 | "MaximumForceObserver",
38 | "TemperatureRampModifier",
39 | "MolecularDynamics",
40 | "LangevinConfig",
41 | "ApplyCalculator",
42 | "CalculateFormationEnergy",
43 | "EvaluateCalculatorResults",
44 | "CompareCalculatorResults",
45 | "NEBs",
46 | "NEBinterpolate",
47 | "Smiles2Conformers",
48 | "PhaseDiagram",
49 | "PourbaixDiagram",
50 | "VibrationalAnalysis",
51 | "HomonuclearDiatomics",
52 | "MPRester",
53 | "GenericASECalculator",
54 | "FilterAtoms",
55 | "EnergyVolumeCurve",
56 | "BuildBox",
57 | "CompareFormationEnergy",
58 | "UpdateFramesCalc",
59 | "RotationalInvariance",
60 | "TranslationalInvariance",
61 | "PermutationInvariance",
62 | "Rattle",
63 | "Project",
64 | "BuildASEslab",
65 | "RelaxAdsorptionConfigs",
66 | "OrcaSinglePoint",
67 | "__version__",
68 | ]
69 |
--------------------------------------------------------------------------------
/mlipx/abc.py:
--------------------------------------------------------------------------------
1 | """Abstract base classes and type hints."""
2 |
3 | import abc
4 | import dataclasses
5 | import pathlib
6 | import typing as t
7 | from enum import Enum
8 |
9 | import ase
10 | import h5py
11 | import plotly.graph_objects as go
12 | import znh5md
13 | import zntrack
14 | from ase.calculators.calculator import Calculator
15 | from ase.md.md import MolecularDynamics
16 |
17 | T = t.TypeVar("T", bound=zntrack.Node)
18 |
19 |
20 | class Optimizer(str, Enum):
21 | FIRE = "FIRE"
22 | BFGS = "BFGS"
23 | LBFGS = "LBFGS"
24 |
25 |
26 | class ASEKeys(str, Enum):
27 | formation_energy = "formation_energy"
28 | isolated_energies = "isolated_energies"
29 |
30 |
31 | class NodeWithCalculator(t.Protocol[T]):
32 | def get_calculator(self, **kwargs) -> Calculator: ...
33 |
34 |
35 | class NodeWithMolecularDynamics(t.Protocol[T]):
36 | def get_molecular_dynamics(self, atoms: ase.Atoms) -> MolecularDynamics: ...
37 |
38 |
39 | FIGURES = t.Dict[str, go.Figure]
40 | FRAMES = t.List[ase.Atoms]
41 |
42 |
43 | class ComparisonResults(t.TypedDict):
44 | frames: FRAMES
45 | figures: FIGURES
46 |
47 |
48 | @dataclasses.dataclass
49 | class DynamicsObserver:
50 | @property
51 | def name(self) -> str:
52 | return self.__class__.__name__
53 |
54 | def initialize(self, atoms: ase.Atoms) -> None:
55 | pass
56 |
57 | @abc.abstractmethod
58 | def check(self, atoms: ase.Atoms) -> bool: ...
59 |
60 |
61 | @dataclasses.dataclass
62 | class DynamicsModifier:
63 | @property
64 | def name(self) -> str:
65 | return self.__class__.__name__
66 |
67 | @abc.abstractmethod
68 | def modify(self, thermostat, step, total_steps) -> None: ...
69 |
70 |
71 | class ProcessAtoms(zntrack.Node):
72 | data: list[ase.Atoms] = zntrack.deps()
73 | data_id: int = zntrack.params(-1)
74 |
75 | frames_path: pathlib.Path = zntrack.outs_path(zntrack.nwd / "frames.h5")
76 |
77 | @property
78 | def frames(self) -> FRAMES:
79 | with self.state.fs.open(self.frames_path, "r") as f:
80 | with h5py.File(f, "r") as h5f:
81 | return znh5md.IO(file_handle=h5f)[:]
82 |
83 | @property
84 | def figures(self) -> FIGURES: ...
85 |
86 | @staticmethod
87 | def compare(*nodes: "ProcessAtoms") -> ComparisonResults: ...
88 |
89 |
90 | class ProcessFrames(zntrack.Node):
91 | data: list[ase.Atoms] = zntrack.deps()
92 |
--------------------------------------------------------------------------------
/mlipx/benchmark/__init__.py:
--------------------------------------------------------------------------------
1 | from mlipx.benchmark.main import app
2 |
3 | __all__ = ["app"]
4 |
--------------------------------------------------------------------------------
/mlipx/benchmark/elements.py:
--------------------------------------------------------------------------------
1 | import zntrack
2 | from models import MODELS
3 |
4 | import mlipx
5 |
6 | ELEMENTS = {{elements}} # noqa F821
7 | FILTERING_TYPE = "{{ filtering_type }}"
8 |
9 | mptraj = zntrack.add(
10 | url="https://github.com/ACEsuit/mace-mp/releases/download/mace_mp_0b/mp_traj_combined.xyz",
11 | path="mptraj.xyz",
12 | )
13 |
14 | project = zntrack.Project()
15 |
16 | with project.group("mptraj"):
17 | raw_mptraj_data = mlipx.LoadDataFile(path=mptraj)
18 | mptraj_data = mlipx.FilterAtoms(
19 | data=raw_mptraj_data.frames, elements=ELEMENTS, filtering_type=FILTERING_TYPE
20 | )
21 |
22 | for model_name, model in MODELS.items():
23 | with project.group(model_name, "diatomics"):
24 | neb = mlipx.HomonuclearDiatomics(
25 | elements=ELEMENTS,
26 | model=model,
27 | n_points=100,
28 | min_distance=0.5,
29 | max_distance=2.0,
30 | )
31 |
32 | relaxed = []
33 |
34 | for model_name, model in MODELS.items():
35 | with project.group(model_name, "struct_optim"):
36 | relaxed.append(
37 | mlipx.StructureOptimization(
38 | data=mptraj_data.frames, data_id=-1, model=model, fmax=0.1
39 | )
40 | )
41 |
42 |
43 | mds = []
44 |
45 | thermostat = mlipx.LangevinConfig(timestep=0.5, temperature=300, friction=0.05)
46 | force_check = mlipx.MaximumForceObserver(f_max=100)
47 | t_ramp = mlipx.TemperatureRampModifier(end_temperature=400, total_steps=100)
48 |
49 | for (model_name, model), relaxed_structure in zip(MODELS.items(), relaxed):
50 | with project.group(model_name, "md"):
51 | mds.append(
52 | mlipx.MolecularDynamics(
53 | model=model,
54 | thermostat=thermostat,
55 | data=relaxed_structure.frames,
56 | data_id=-1,
57 | observers=[force_check],
58 | modifiers=[t_ramp],
59 | steps=100,
60 | )
61 | )
62 |
63 |
64 | for (model_name, model), md in zip(MODELS.items(), mds):
65 | with project.group(model_name):
66 | ev = mlipx.EnergyVolumeCurve(
67 | model=model,
68 | data=md.frames,
69 | data_id=-1,
70 | n_points=50,
71 | start=0.75,
72 | stop=2.0,
73 | )
74 |
75 | for (model_name, model), md in zip(MODELS.items(), mds):
76 | with project.group(model_name, "struct_optim_2"):
77 | relaxed.append(
78 | mlipx.StructureOptimization(
79 | data=md.frames, data_id=-1, model=model, fmax=0.1
80 | )
81 | )
82 |
83 |
84 | project.build()
85 |
--------------------------------------------------------------------------------
/mlipx/benchmark/file.py:
--------------------------------------------------------------------------------
1 | import ase.io
2 | import zntrack
3 | from models import MODELS
4 |
5 | import mlipx
6 |
7 | DATAPATH = "{{ datapath }}"
8 |
9 | count = 0
10 | ELEMENTS = set()
11 | for atoms in ase.io.iread(DATAPATH):
12 | count += 1
13 | for symbol in atoms.symbols:
14 | ELEMENTS.add(symbol)
15 | ELEMENTS = list(ELEMENTS)
16 |
17 | project = zntrack.Project()
18 |
19 | with project.group("mptraj"):
20 | data = mlipx.LoadDataFile(path=DATAPATH)
21 |
22 |
23 | for model_name, model in MODELS.items():
24 | with project.group(model_name, "diatomics"):
25 | _ = mlipx.HomonuclearDiatomics(
26 | elements=ELEMENTS,
27 | model=model,
28 | n_points=100,
29 | min_distance=0.5,
30 | max_distance=2.0,
31 | )
32 |
33 | # Energy-Volume Curve
34 | for model_name, model in MODELS.items():
35 | for idx in range(count):
36 | with project.group(model_name, "ev", str(idx)):
37 | _ = mlipx.EnergyVolumeCurve(
38 | model=model,
39 | data=data.frames,
40 | data_id=idx,
41 | n_points=50,
42 | start=0.75,
43 | stop=2.0,
44 | )
45 |
46 |
47 | # Molecular Dynamics
48 | thermostat = mlipx.LangevinConfig(timestep=0.5, temperature=300, friction=0.05)
49 | force_check = mlipx.MaximumForceObserver(f_max=100)
50 | t_ramp = mlipx.TemperatureRampModifier(end_temperature=400, total_steps=100)
51 |
52 | for model_name, model in MODELS.items():
53 | for idx in range(count):
54 | with project.group(model_name, "md", str(idx)):
55 | _ = mlipx.MolecularDynamics(
56 | model=model,
57 | thermostat=thermostat,
58 | data=data.frames,
59 | data_id=idx,
60 | observers=[force_check],
61 | modifiers=[t_ramp],
62 | steps=100,
63 | )
64 |
65 | # Structure Optimization
66 | with project.group("rattle"):
67 | rattle = mlipx.Rattle(data=data.frames, stdev=0.01)
68 |
69 | for model_name, model in MODELS.items():
70 | for idx in range(count):
71 | with project.group(model_name, "struct_optim", str(idx)):
72 | _ = mlipx.StructureOptimization(
73 | data=rattle.frames, data_id=idx, model=model, fmax=0.1
74 | )
75 |
76 | project.build()
77 |
--------------------------------------------------------------------------------
/mlipx/benchmark/main.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | import shutil
3 | import subprocess
4 | import typing as t
5 |
6 | import jinja2
7 | import typer
8 | from ase.data import chemical_symbols
9 |
10 | from mlipx.nodes.filter_dataset import FilteringType
11 |
12 | CWD = pathlib.Path(__file__).parent
13 |
14 |
15 | app = typer.Typer()
16 |
17 |
18 | def initialize_directory():
19 | subprocess.run(["git", "init"], check=True)
20 | subprocess.run(["dvc", "init"], check=True)
21 | shutil.copy(CWD.parent / "recipes" / "models.py", "models.py")
22 |
23 |
24 | @app.command()
25 | def elements(
26 | elements: t.Annotated[list[str], typer.Argument()],
27 | filtering_type: FilteringType = FilteringType.INCLUSIVE,
28 | ):
29 | for element in elements:
30 | if element not in chemical_symbols:
31 | raise ValueError(f"{element} is not a chemical element")
32 | template = jinja2.Template((CWD / "elements.py").read_text())
33 | with open("main.py", "w") as f:
34 | f.write(template.render(elements=elements, filtering_type=filtering_type.value))
35 |
36 |
37 | @app.command()
38 | def file(
39 | datapath: str,
40 | ):
41 | template = jinja2.Template((CWD / "file.py").read_text())
42 | with open("main.py", "w") as f:
43 | f.write(template.render(datapath=datapath))
44 |
--------------------------------------------------------------------------------
/mlipx/cli/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basf/mlipx/c109aa6b2d3ff8236d6a9ec0772a8c4efb9c7dea/mlipx/cli/__init__.py
--------------------------------------------------------------------------------
/mlipx/cli/main.py:
--------------------------------------------------------------------------------
1 | import fnmatch
2 | import importlib.metadata
3 | import json
4 | import pathlib
5 | import sys
6 | import uuid
7 | import webbrowser
8 |
9 | import dvc.api
10 | import plotly.io as pio
11 | import typer
12 | import zntrack
13 | from rich import box
14 | from rich.console import Console
15 | from rich.table import Table
16 | from tqdm import tqdm
17 | from typing_extensions import Annotated
18 | from zndraw import ZnDraw
19 |
20 | from mlipx import benchmark, recipes
21 |
22 | app = typer.Typer()
23 | app.add_typer(recipes.app, name="recipes")
24 | app.add_typer(benchmark.app, name="benchmark")
25 |
26 | # Load plugins
27 |
28 | entry_points = importlib.metadata.entry_points(group="mlipx.recipes")
29 | for entry_point in entry_points:
30 | entry_point.load()
31 |
32 |
33 | @app.command()
34 | def main():
35 | typer.echo("Hello World")
36 |
37 |
38 | @app.command()
39 | def info():
40 | """Print the version of mlipx and the available models."""
41 | from mlipx.models import AVAILABLE_MODELS # slow import
42 |
43 | console = Console()
44 | # Get Python environment info
45 | python_version = sys.version.split()[0]
46 | python_executable = sys.executable
47 | python_platform = sys.platform
48 |
49 | py_table = Table(title="🐍 Python Environment", box=box.ROUNDED)
50 | py_table.add_column("Version", style="cyan", no_wrap=True)
51 | py_table.add_column("Executable", style="magenta")
52 | py_table.add_column("Platform", style="green")
53 | py_table.add_row(python_version, python_executable, python_platform)
54 |
55 | # Get model availability
56 | mlip_table = Table(title="🧠 MLIP Codes", box=box.ROUNDED)
57 | mlip_table.add_column("Model", style="bold")
58 | mlip_table.add_column("Available", style="bold")
59 |
60 | for model in sorted(AVAILABLE_MODELS):
61 | status = AVAILABLE_MODELS[model]
62 | if status is True:
63 | mlip_table.add_row(model, "[green]:heavy_check_mark: Yes[/green]")
64 | elif status is False:
65 | mlip_table.add_row(model, "[red]:x: No[/red]")
66 | elif status is None:
67 | mlip_table.add_row(model, "[yellow]:warning: Unknown[/yellow]")
68 | else:
69 | mlip_table.add_row(model, "[red]:boom: Error[/red]")
70 |
71 | # Get versions of key packages
72 | mlipx_table = Table(title="📦 mlipx Ecosystem", box=box.ROUNDED)
73 | mlipx_table.add_column("Package", style="bold")
74 | mlipx_table.add_column("Version", style="cyan")
75 |
76 | for package in ["mlipx", "zntrack", "zndraw"]:
77 | try:
78 | version = importlib.metadata.version(package)
79 | except importlib.metadata.PackageNotFoundError:
80 | version = "[red]Not installed[/red]"
81 | mlipx_table.add_row(package, version)
82 |
83 | # Display all
84 | console.print(mlipx_table)
85 | console.print(py_table)
86 | console.print(mlip_table)
87 |
88 |
89 | @app.command()
90 | def compare( # noqa C901
91 | nodes: Annotated[list[str], typer.Argument(help="Path to the node to compare")],
92 | zndraw_url: Annotated[
93 | str,
94 | typer.Option(
95 | envvar="ZNDRAW_URL",
96 | help="URL of the ZnDraw server to visualize the results",
97 | ),
98 | ],
99 | kwarg: Annotated[list[str], typer.Option("--kwarg", "-k")] = None,
100 | token: Annotated[str, typer.Option("--token")] = None,
101 | glob: Annotated[
102 | bool, typer.Option("--glob", help="Allow glob patterns to select nodes.")
103 | ] = False,
104 | convert_nan: Annotated[bool, typer.Option()] = False,
105 | browser: Annotated[
106 | bool,
107 | typer.Option(
108 | help="""Whether to open the ZnDraw GUI in the default web browser."""
109 | ),
110 | ] = True,
111 | figures_path: Annotated[
112 | str | None,
113 | typer.Option(
114 |             help="Provide a path to save the figures to. "
115 | "No figures will be saved by default."
116 | ),
117 | ] = None,
118 | ):
119 | """Compare mlipx nodes and visualize the results using ZnDraw."""
120 |     # glob patterns for node selection are handled via the --glob flag below
121 | if kwarg is None:
122 | kwarg = []
123 | node_names, revs, remotes = [], [], []
124 | if glob:
125 | fs = dvc.api.DVCFileSystem()
126 | with fs.open("zntrack.json", mode="r") as f:
127 | all_nodes = list(json.load(f).keys())
128 |
129 | for node in nodes:
130 | # can be name or name@rev or name@remote@rev
131 | parts = node.split("@")
132 | if glob:
133 | filtered_nodes = [x for x in all_nodes if fnmatch.fnmatch(x, parts[0])]
134 | else:
135 | filtered_nodes = [parts[0]]
136 | for x in filtered_nodes:
137 | node_names.append(x)
138 | if len(parts) == 1:
139 | revs.append(None)
140 | remotes.append(None)
141 | elif len(parts) == 2:
142 | revs.append(parts[1])
143 | remotes.append(None)
144 | elif len(parts) == 3:
145 | remotes.append(parts[1])
146 | revs.append(parts[2])
147 | else:
148 | raise ValueError(f"Invalid node format: {node}")
149 |
150 | node_instances = {}
151 | for node_name, rev, remote in tqdm(
152 | zip(node_names, revs, remotes), desc="Loading nodes"
153 | ):
154 | node_instances[node_name] = zntrack.from_rev(node_name, remote=remote, rev=rev)
155 |
156 | if len(node_instances) == 0:
157 | typer.echo("No nodes to compare")
158 | return
159 |
160 | typer.echo(f"Comparing {len(node_instances)} nodes")
161 |
162 | kwargs = {}
163 | for arg in kwarg:
164 | key, value = arg.split("=", 1)
165 | kwargs[key] = value
166 | result = node_instances[node_names[0]].compare(*node_instances.values(), **kwargs)
167 |
168 | token = token or str(uuid.uuid4())
169 | typer.echo(f"View the results at {zndraw_url}/token/{token}")
170 | vis = ZnDraw(zndraw_url, token=token, convert_nan=convert_nan)
171 | length = len(vis)
172 | vis.extend(result["frames"])
173 | del vis[:length] # temporary fix
174 | vis.figures = result["figures"]
175 | if browser:
176 | webbrowser.open(f"{zndraw_url}/token/{token}")
177 | if figures_path:
178 | for desc, fig in result["figures"].items():
179 | pio.write_json(fig, pathlib.Path(figures_path) / f"{desc}.json")
180 |
181 | vis.socket.sleep(5)
182 |
--------------------------------------------------------------------------------
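
The node arguments accepted by `compare` follow a `name`, `name@rev`, or `name@remote@rev` convention. The helper below is not part of mlipx; it is a small illustrative mirror of that parsing, shown only to document the accepted formats.

```python
def parse_node_spec(spec: str) -> tuple[str, str | None, str | None]:
    """Return (name, remote, rev) for a name[@remote][@rev] node spec."""
    parts = spec.split("@")
    if len(parts) == 1:
        return parts[0], None, None
    if len(parts) == 2:
        return parts[0], None, parts[1]
    if len(parts) == 3:
        return parts[0], parts[1], parts[2]
    raise ValueError(f"Invalid node format: {spec}")


assert parse_node_spec("MyNode") == ("MyNode", None, None)
assert parse_node_spec("MyNode@main") == ("MyNode", None, "main")
assert parse_node_spec("MyNode@origin@abc1234") == ("MyNode", "origin", "abc1234")
```
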
/mlipx/doc_utils.py:
--------------------------------------------------------------------------------
1 | import fnmatch
2 | import json
3 | import os
4 | import pathlib
5 |
6 | import plotly.graph_objects as go
7 | import plotly.io as pio
8 | import zntrack
9 |
10 |
11 | def show(file: str) -> None:
12 | pio.renderers.default = "sphinx_gallery"
13 |
14 | figure = pio.read_json(f"source/figures/{file}")
15 | figure.update_layout(
16 | {
17 | "plot_bgcolor": "rgba(0, 0, 0, 0)",
18 | "paper_bgcolor": "rgba(0, 0, 0, 0)",
19 | }
20 | )
21 | figure.update_xaxes(
22 | showgrid=True, gridwidth=1, gridcolor="rgba(120, 120, 120, 0.3)", zeroline=False
23 | )
24 | figure.update_yaxes(
25 | showgrid=True, gridwidth=1, gridcolor="rgba(120, 120, 120, 0.3)", zeroline=False
26 | )
27 | figure.show()
28 |
29 |
30 | def get_plots(name: str, url: str) -> dict[str, go.Figure]:
31 | os.chdir(url)
32 | pio.renderers.default = "sphinx_gallery"
33 | with pathlib.Path("zntrack.json").open(mode="r") as f:
34 | all_nodes = list(json.load(f).keys())
35 | filtered_nodes = [x for x in all_nodes if fnmatch.fnmatch(x, name)]
36 |
37 | node_instances = {}
38 | for node_name in filtered_nodes:
39 | node_instances[node_name] = zntrack.from_rev(node_name)
40 |
41 | result = node_instances[filtered_nodes[0]].compare(*node_instances.values())
42 | return result["figures"]
43 |
--------------------------------------------------------------------------------
/mlipx/models.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import jinja2
4 |
5 | import mlipx
6 | from mlipx import recipes
7 |
8 | AVAILABLE_MODELS = {}
9 |
10 | RECIPES_PATH = Path(recipes.__file__).parent
11 | template = jinja2.Template((RECIPES_PATH / "models.py.jinja2").read_text())
12 |
13 | rendered_code = template.render(models=[])
14 |
15 | # Prepare a namespace and execute the rendered code into it
16 | namespace = {"mlipx": mlipx} # replace with your actual mlipx
17 | exec(rendered_code, namespace)
18 |
19 | # Access ALL_MODELS defined by the rendered template
20 | all_models = namespace["ALL_MODELS"]
21 |
22 | AVAILABLE_MODELS = {
23 | model_name: model.available for model_name, model in all_models.items()
24 | }
25 |
--------------------------------------------------------------------------------
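
A minimal sketch of consuming `AVAILABLE_MODELS` directly, mirroring the table printed by the `mlipx info` command:

```python
from mlipx.models import AVAILABLE_MODELS

for name, status in sorted(AVAILABLE_MODELS.items()):
    # status is True/False for importable/missing, or None if availability is unknown.
    label = "yes" if status is True else "no" if status is False else "unknown"
    print(f"{name}: {label}")
```
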
/mlipx/nodes/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basf/mlipx/c109aa6b2d3ff8236d6a9ec0772a8c4efb9c7dea/mlipx/nodes/__init__.py
--------------------------------------------------------------------------------
/mlipx/nodes/apply_calculator.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 |
3 | import ase
4 | import h5py
5 | import tqdm
6 | import znh5md
7 | import zntrack
8 | from ase.calculators.calculator import all_properties
9 |
10 | from mlipx.abc import NodeWithCalculator
11 | from mlipx.utils import freeze_copy_atoms
12 |
13 |
14 | class ApplyCalculator(zntrack.Node):
15 | """
16 | Apply a calculator to a list of atoms objects and store the results in a H5MD file.
17 |
18 | Parameters
19 | ----------
20 | data : list[ase.Atoms]
21 | List of atoms objects to calculate.
22 | model : NodeWithCalculator, optional
23 | Node providing the calculator object to apply to the data.
24 | """
25 |
26 | data: list[ase.Atoms] = zntrack.deps()
27 | model: NodeWithCalculator | None = zntrack.deps()
28 |
29 | frames_path: pathlib.Path = zntrack.outs_path(zntrack.nwd / "frames.h5")
30 |
31 | def run(self):
32 | frames = []
33 | if self.model is not None:
34 | calc = self.model.get_calculator()
35 |
36 |             # Some calculators, e.g. MACE, do not follow the ASE API exactly,
37 |             # so we need to add some keys to `all_properties`.
38 | all_properties.append("node_energy")
39 |
40 | for atoms in tqdm.tqdm(self.data):
41 | atoms.calc = calc
42 | atoms.get_potential_energy()
43 | frames.append(freeze_copy_atoms(atoms))
44 | else:
45 | frames = self.data
46 |
47 | io = znh5md.IO(self.frames_path)
48 | io.extend(frames)
49 |
50 | @property
51 | def frames(self) -> list[ase.Atoms]:
52 | with self.state.fs.open(self.frames_path, "rb") as f:
53 | with h5py.File(f) as file:
54 | return list(znh5md.IO(file_handle=file))
55 |
--------------------------------------------------------------------------------
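
A minimal usage sketch for the `ApplyCalculator` node above, following the project pattern used in the recipe templates. The trajectory path is illustrative, and nodes are imported from their module paths rather than assuming package-level re-exports.

```python
import zntrack

from mlipx.nodes.apply_calculator import ApplyCalculator
from mlipx.nodes.generic_ase import GenericASECalculator
from mlipx.nodes.io import LoadDataFile

# Lennard-Jones stands in for an actual MLIP here.
lj = GenericASECalculator(
    module="ase.calculators.lj",
    class_name="LennardJones",
    kwargs={"epsilon": 1.0, "sigma": 1.0},
)

project = zntrack.Project()
with project:
    data = LoadDataFile(path="trajectory.xyz")  # hypothetical input file
    ApplyCalculator(data=data.frames, model=lj)

project.build()
```
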
/mlipx/nodes/autowte.py:
--------------------------------------------------------------------------------
1 | """Automatic heat-conductivity predictions from the Wigner Transport Equation.
2 |
3 | Based on https://github.com/MPA2suite/autoWTE
4 |
5 | Install via `pip install git+https://github.com/MPA2suite/autoWTE`
6 | and `conda install -c conda-forge phono3py`.
7 | """
8 |
9 | import zntrack
10 |
11 | from mlipx.abc import NodeWithCalculator
12 |
13 |
14 | class AutoWTE(zntrack.Node):
15 | model: NodeWithCalculator = zntrack.deps()
16 |
--------------------------------------------------------------------------------
/mlipx/nodes/compare_calculator.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 |
3 | import pandas as pd
4 | import tqdm
5 | import zntrack
6 | from ase.calculators.calculator import PropertyNotImplementedError
7 |
8 | from mlipx.abc import FIGURES, FRAMES, ComparisonResults
9 | from mlipx.nodes.evaluate_calculator import EvaluateCalculatorResults, get_figure
10 | from mlipx.utils import rmse, shallow_copy_atoms
11 |
12 |
13 | class CompareCalculatorResults(zntrack.Node):
14 | """
15 |     Compare the results of two calculators.
16 |     The node computes the RMSE between the two calculators, shifts the plots
17 |     accordingly, and stores the minimum and maximum error values as metrics.
18 |
19 | Parameters
20 | ----------
21 | data : EvaluateCalculatorResults
22 | The results of the first calculator.
23 | reference : EvaluateCalculatorResults
24 | The results of the second calculator.
25 | The results of the first calculator will be compared to these results.
26 | """
27 |
28 | data: EvaluateCalculatorResults = zntrack.deps()
29 | reference: EvaluateCalculatorResults = zntrack.deps()
30 |
31 | plots: pd.DataFrame = zntrack.plots(autosave=True)
32 | rmse: dict = zntrack.metrics()
33 | error: dict = zntrack.metrics()
34 |
35 | def run(self):
36 | e_rmse = rmse(self.data.plots["energy"], self.reference.plots["energy"])
37 | self.rmse = {
38 | "energy": e_rmse,
39 | "energy_per_atom": e_rmse / len(self.data.plots),
40 | "fmax": rmse(self.data.plots["fmax"], self.reference.plots["fmax"]),
41 | "fnorm": rmse(self.data.plots["fnorm"], self.reference.plots["fnorm"]),
42 | }
43 |
44 | all_plots = []
45 |
46 | for row_idx in tqdm.trange(len(self.data.plots)):
47 | plots = {}
48 | plots["adjusted_energy_error"] = (
49 | self.data.plots["energy"].iloc[row_idx] - e_rmse
50 | ) - self.reference.plots["energy"].iloc[row_idx]
51 | plots["adjusted_energy"] = self.data.plots["energy"].iloc[row_idx] - e_rmse
52 | plots["adjusted_energy_error_per_atom"] = (
53 | plots["adjusted_energy_error"]
54 | / self.data.plots["n_atoms"].iloc[row_idx]
55 | )
56 |
57 | plots["fmax_error"] = (
58 | self.data.plots["fmax"].iloc[row_idx]
59 | - self.reference.plots["fmax"].iloc[row_idx]
60 | )
61 | plots["fnorm_error"] = (
62 | self.data.plots["fnorm"].iloc[row_idx]
63 | - self.reference.plots["fnorm"].iloc[row_idx]
64 | )
65 | all_plots.append(plots)
66 | self.plots = pd.DataFrame(all_plots)
67 |
68 | # iterate over plots and save min/max
69 | self.error = {}
70 | for key in self.plots.columns:
71 | if "_error" in key:
72 | stripped_key = key.replace("_error", "")
73 | self.error[f"{stripped_key}_max"] = self.plots[key].max()
74 | self.error[f"{stripped_key}_min"] = self.plots[key].min()
75 |
76 | @property
77 | def frames(self) -> FRAMES:
78 | return self.data.frames
79 |
80 | @property
81 | def figures(self) -> FIGURES:
82 | figures = {}
83 | for key in self.plots.columns:
84 | figures[key] = get_figure(key, [self])
85 | return figures
86 |
87 | def compare(self, *nodes: "CompareCalculatorResults") -> ComparisonResults: # noqa C901
88 | if len(nodes) == 0:
89 | raise ValueError("No nodes to compare provided")
90 | figures = {}
91 | frames_info = {}
92 | for key in nodes[0].plots.columns:
93 | if not all(key in node.plots.columns for node in nodes):
94 | raise ValueError(f"Key {key} not found in all nodes")
95 | # check frames are the same
96 | figures[key] = get_figure(key, nodes)
97 |
98 | for node in nodes:
99 | for key in node.plots.columns:
100 | frames_info[f"{node.name}_{key}"] = node.plots[key].values
101 |
102 | # TODO: calculate the rmse difference between a fixed
103 | # one and all the others and shift them accordingly
104 | # and plot as energy_shifted
105 |
106 | # plot error between curves
107 | # mlipx pass additional flags to compare function
108 | # have different compare methods also for correlation plots
109 |
110 | frames = [shallow_copy_atoms(x) for x in nodes[0].frames]
111 | for key, values in frames_info.items():
112 | for atoms, value in zip(frames, values):
113 | atoms.info[key] = value
114 |
115 | for node in nodes:
116 | for node_atoms, atoms in zip(node.frames, frames):
117 | if len(node_atoms) != len(atoms):
118 | raise ValueError("Atoms objects have different lengths")
119 | with contextlib.suppress(RuntimeError, PropertyNotImplementedError):
120 | atoms.info[f"{node.name}_energy"] = (
121 | node_atoms.get_potential_energy()
122 | )
123 | atoms.arrays[f"{node.name}_forces"] = node_atoms.get_forces()
124 |
125 | for ref_atoms, atoms in zip(nodes[0].reference.frames, frames):
126 | with contextlib.suppress(RuntimeError, PropertyNotImplementedError):
127 | atoms.info["ref_energy"] = ref_atoms.get_potential_energy()
128 | atoms.arrays["ref_forces"] = ref_atoms.get_forces()
129 |
130 | return {
131 | "frames": frames,
132 | "figures": figures,
133 | }
134 |
--------------------------------------------------------------------------------
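
A sketch of the intended evaluate-then-compare pipeline for the node above: evaluate both trajectories with `EvaluateCalculatorResults`, then feed the results into `CompareCalculatorResults`. File names are illustrative and assume the frames already carry energies and forces.

```python
import zntrack

from mlipx.nodes.compare_calculator import CompareCalculatorResults
from mlipx.nodes.evaluate_calculator import EvaluateCalculatorResults
from mlipx.nodes.io import LoadDataFile

project = zntrack.Project()
with project:
    model_data = LoadDataFile(path="model_frames.xyz")  # frames with model results
    ref_data = LoadDataFile(path="reference_frames.xyz")  # frames with reference results
    model_eval = EvaluateCalculatorResults(data=model_data.frames)
    ref_eval = EvaluateCalculatorResults(data=ref_data.frames)
    CompareCalculatorResults(data=model_eval, reference=ref_eval)

project.build()
```
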
/mlipx/nodes/energy_volume.py:
--------------------------------------------------------------------------------
1 | import ase.io
2 | import numpy as np
3 | import pandas as pd
4 | import plotly.express as px
5 | import plotly.graph_objects as go
6 | import tqdm
7 | import zntrack
8 |
9 | from mlipx.abc import ComparisonResults, NodeWithCalculator
10 |
11 |
12 | class EnergyVolumeCurve(zntrack.Node):
13 | """Compute the energy-volume curve for a given structure.
14 |
15 | Parameters
16 | ----------
17 | data : list[ase.Atoms]
18 | List of structures to evaluate.
19 | model : NodeWithCalculator
20 | Node providing the calculator object for the energy calculations.
21 | data_id : int, default=-1
22 | Index of the structure to evaluate.
23 | n_points : int, default=50
24 | Number of points to sample for the volume scaling.
25 | start : float, default=0.75
26 | Initial scaling factor from the original cell.
27 | stop : float, default=2.0
28 | Final scaling factor from the original cell.
29 |
30 | Attributes
31 | ----------
32 | results : pd.DataFrame
33 | DataFrame with the volume, energy, and scaling factor.
34 |
35 | """
36 |
37 | model: NodeWithCalculator = zntrack.deps()
38 | data: list[ase.Atoms] = zntrack.deps()
39 | data_id: int = zntrack.params(-1)
40 |
41 | n_points: int = zntrack.params(50)
42 | start: float = zntrack.params(0.75)
43 | stop: float = zntrack.params(2.0)
44 |
45 | frames_path: str = zntrack.outs_path(zntrack.nwd / "frames.xyz")
46 | results: pd.DataFrame = zntrack.plots(y="energy", x="scale")
47 |
48 | def run(self):
49 | atoms = self.data[self.data_id]
50 | calc = self.model.get_calculator()
51 |
52 | results = []
53 |
54 | scale_factor = np.linspace(self.start, self.stop, self.n_points)
55 | for scale in tqdm.tqdm(scale_factor):
56 | atoms_copy = atoms.copy()
57 | atoms_copy.set_cell(atoms.get_cell() * scale, scale_atoms=True)
58 | atoms_copy.calc = calc
59 |
60 | results.append(
61 | {
62 | "volume": atoms_copy.get_volume(),
63 | "energy": atoms_copy.get_potential_energy(),
64 | "fmax": np.linalg.norm(atoms_copy.get_forces(), axis=-1).max(),
65 | "scale": scale,
66 | }
67 | )
68 |
69 | ase.io.write(self.frames_path, atoms_copy, append=True)
70 |
71 | self.results = pd.DataFrame(results)
72 |
73 | @property
74 | def frames(self) -> list[ase.Atoms]:
75 | """List of structures evaluated during the energy-volume curve."""
76 | with self.state.fs.open(self.frames_path, "r") as f:
77 | return list(ase.io.iread(f, format="extxyz"))
78 |
79 | @property
80 | def figures(self) -> dict[str, go.Figure]:
81 | """Plot the energy-volume curve."""
82 | fig = px.scatter(self.results, x="scale", y="energy", color="scale")
83 | fig.update_layout(title="Energy-Volume Curve")
84 | fig.update_traces(customdata=np.stack([np.arange(self.n_points)], axis=1))
85 | fig.update_xaxes(title_text="cell vector scale")
86 | fig.update_yaxes(title_text="Energy / eV")
87 |
88 | ffig = px.scatter(self.results, x="scale", y="fmax", color="scale")
89 | ffig.update_layout(title="Energy-Volume Curve (fmax)")
90 | ffig.update_traces(customdata=np.stack([np.arange(self.n_points)], axis=1))
91 | ffig.update_xaxes(title_text="cell vector scale")
92 | ffig.update_yaxes(title_text="Maximum Force / eV/Å")
93 |
94 | return {"energy-volume-curve": fig, "fmax-volume-curve": ffig}
95 |
96 | @staticmethod
97 | def compare(*nodes: "EnergyVolumeCurve") -> ComparisonResults:
98 | """Compare the energy-volume curves of multiple nodes."""
99 | fig = go.Figure()
100 | for node in nodes:
101 | fig.add_trace(
102 | go.Scatter(
103 | x=node.results["scale"],
104 | y=node.results["energy"],
105 | mode="lines+markers",
106 | name=node.name.replace("_EnergyVolumeCurve", ""),
107 | )
108 | )
109 | fig.update_traces(customdata=np.stack([np.arange(node.n_points)], axis=1))
110 |
111 | # TODO: remove all info from the frames?
112 | # What about forces / energies? Update the key?
113 | fig.update_layout(title="Energy-Volume Curve Comparison")
114 | # set x-axis title
115 | # fig.update_xaxes(title_text="Volume / ų")
116 | fig.update_xaxes(title_text="cell vector scale")
117 | fig.update_yaxes(title_text="Energy / eV")
118 |
119 | # Now adjusted
120 |
121 | fig_adjust = go.Figure()
122 | for node in nodes:
123 | scale_factor = np.linspace(node.start, node.stop, node.n_points)
124 | one_idx = np.abs(scale_factor - 1).argmin()
125 | fig_adjust.add_trace(
126 | go.Scatter(
127 | x=node.results["scale"],
128 | y=node.results["energy"] - node.results["energy"].iloc[one_idx],
129 | mode="lines+markers",
130 | name=node.name.replace("_EnergyVolumeCurve", ""),
131 | )
132 | )
133 | fig_adjust.update_traces(
134 | customdata=np.stack([np.arange(node.n_points)], axis=1)
135 | )
136 |
137 | fig_adjust.update_layout(title="Adjusted Energy-Volume Curve Comparison")
138 | fig_adjust.update_xaxes(title_text="cell vector scale")
139 | fig_adjust.update_yaxes(title_text="Adjusted Energy / eV")
140 |
141 | fig_adjust.update_layout(
142 | plot_bgcolor="rgba(0, 0, 0, 0)",
143 | paper_bgcolor="rgba(0, 0, 0, 0)",
144 | )
145 | fig_adjust.update_xaxes(
146 | showgrid=True,
147 | gridwidth=1,
148 | gridcolor="rgba(120, 120, 120, 0.3)",
149 | zeroline=False,
150 | )
151 | fig_adjust.update_yaxes(
152 | showgrid=True,
153 | gridwidth=1,
154 | gridcolor="rgba(120, 120, 120, 0.3)",
155 | zeroline=False,
156 | )
157 |
158 | fig.update_layout(
159 | plot_bgcolor="rgba(0, 0, 0, 0)",
160 | paper_bgcolor="rgba(0, 0, 0, 0)",
161 | )
162 | fig.update_xaxes(
163 | showgrid=True,
164 | gridwidth=1,
165 | gridcolor="rgba(120, 120, 120, 0.3)",
166 | zeroline=False,
167 | )
168 | fig.update_yaxes(
169 | showgrid=True,
170 | gridwidth=1,
171 | gridcolor="rgba(120, 120, 120, 0.3)",
172 | zeroline=False,
173 | )
174 |
175 | return {
176 | "frames": nodes[0].frames,
177 | "figures": {
178 | "energy-volume-curve": fig,
179 | "adjusted_energy-volume-curve": fig_adjust,
180 | },
181 | }
182 |
--------------------------------------------------------------------------------
/mlipx/nodes/evaluate_calculator.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 |
3 | import ase
4 | import numpy as np
5 | import pandas as pd
6 | import plotly.express as px
7 | import plotly.graph_objects as go
8 | import tqdm
9 | import zntrack
10 | from ase.calculators.calculator import PropertyNotImplementedError
11 |
12 | from mlipx.abc import ComparisonResults
13 | from mlipx.utils import shallow_copy_atoms
14 |
15 |
16 | def get_figure(key: str, nodes: list["EvaluateCalculatorResults"]) -> go.Figure:
17 | fig = go.Figure()
18 | for node in nodes:
19 | fig.add_trace(
20 | go.Scatter(
21 | x=node.plots.index,
22 | y=node.plots[key],
23 | mode="lines+markers",
24 | name=node.name.replace(f"_{node.__class__.__name__}", ""),
25 | )
26 | )
27 | fig.update_traces(customdata=np.stack([np.arange(len(node.plots.index))], axis=1))
28 | fig.update_layout(
29 | title=key,
30 | plot_bgcolor="rgba(0, 0, 0, 0)",
31 | paper_bgcolor="rgba(0, 0, 0, 0)",
32 | )
33 | fig.update_xaxes(
34 | showgrid=True,
35 | gridwidth=1,
36 | gridcolor="rgba(120, 120, 120, 0.3)",
37 | zeroline=False,
38 | )
39 | fig.update_yaxes(
40 | showgrid=True,
41 | gridwidth=1,
42 | gridcolor="rgba(120, 120, 120, 0.3)",
43 | zeroline=False,
44 | )
45 | return fig
46 |
47 |
48 | class EvaluateCalculatorResults(zntrack.Node):
49 | """
50 | Evaluate the results of a calculator.
51 |
52 | Parameters
53 | ----------
54 | data : list[ase.Atoms]
55 | List of atoms objects.
56 |
57 | """
58 |
59 | data: list[ase.Atoms] = zntrack.deps()
60 | plots: pd.DataFrame = zntrack.plots(
61 | y=["fmax", "fnorm", "energy"], independent=True, autosave=True
62 | )
63 |
64 | def run(self):
65 | self.plots = pd.DataFrame()
66 | frame_data = []
67 | for idx in tqdm.tqdm(range(len(self.data))):
68 | atoms = self.data[idx]
69 |
70 | forces = atoms.get_forces()
71 | fmax = np.max(np.linalg.norm(forces, axis=1))
72 | fnorm = np.linalg.norm(forces)
73 | energy = atoms.get_potential_energy()
74 | # eform = atoms.info.get(ASEKeys.formation_energy.value, -1)
75 | n_atoms = len(atoms)
76 |
77 | # have energy and formation energy in the plot
78 |
79 | plots = {
80 | "fmax": fmax,
81 | "fnorm": fnorm,
82 | "energy": energy,
83 | # "eform": eform,
84 | "n_atoms": n_atoms,
85 | "energy_per_atom": energy / n_atoms,
86 | # "eform_per_atom": eform / n_atoms,
87 | }
88 | frame_data.append(plots)
89 | self.plots = pd.DataFrame(frame_data)
90 |
91 | @property
92 | def frames(self):
93 | return self.data
94 |
95 | def __run_note__(self) -> str:
96 | return f"""# {self.name}
97 | Results from {self.state.remote} at {self.state.rev}.
98 |
99 | View the trajectory via zndraw:
100 | ```bash
101 | zndraw {self.name}.frames --rev {self.state.rev} --remote {self.state.remote} --url https://app-dev.roqs.basf.net/zndraw_app
102 | ```
103 | """
104 |
105 | @property
106 | def figures(self) -> dict:
107 | # TODO: remove index column
108 |
109 | plots = {}
110 | for key in self.plots.columns:
111 | fig = px.line(
112 | self.plots,
113 | x=self.plots.index,
114 | y=key,
115 | title=key,
116 | )
117 | fig.update_traces(
118 | customdata=np.stack([np.arange(len(self.plots))], axis=1),
119 | )
120 | plots[key] = fig
121 | return plots
122 |
123 | @staticmethod
124 | def compare(
125 | *nodes: "EvaluateCalculatorResults", reference: str | None = None
126 | ) -> ComparisonResults:
127 | # TODO: if reference, shift energies by
128 | # rmse(val, reference) and plot as energy_adjusted
129 | figures = {}
130 | frames_info = {}
131 | for key in nodes[0].plots.columns:
132 | if not all(key in node.plots.columns for node in nodes):
133 | raise ValueError(f"Key {key} not found in all nodes")
134 | # check frames are the same
135 | figures[key] = get_figure(key, nodes)
136 |
137 | for node in nodes:
138 | for key in node.plots.columns:
139 | frames_info[f"{node.name}_{key}"] = node.plots[key].values
140 |
141 | # TODO: calculate the rmse difference between a fixed one
142 | # and all the others and shift them accordingly
143 | # and plot as energy_shifted
144 |
145 | # plot error between curves
146 | # mlipx pass additional flags to compare function
147 | # have different compare methods also for correlation plots
148 |
149 | frames = [shallow_copy_atoms(x) for x in nodes[0].frames]
150 | for key, values in frames_info.items():
151 | for atoms, value in zip(frames, values):
152 | atoms.info[key] = value
153 |
154 | for node in nodes:
155 | for node_atoms, atoms in zip(node.frames, frames):
156 | if len(node_atoms) != len(atoms):
157 | raise ValueError("Atoms objects have different lengths")
158 | with contextlib.suppress(RuntimeError, PropertyNotImplementedError):
159 | atoms.info[f"{node.name}_energy"] = (
160 | node_atoms.get_potential_energy()
161 | )
162 | atoms.arrays[f"{node.name}_forces"] = node_atoms.get_forces()
163 |
164 | return {
165 | "frames": frames,
166 | "figures": figures,
167 | }
168 |
--------------------------------------------------------------------------------
/mlipx/nodes/filter_dataset.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | import typing as t
3 | from enum import Enum
4 |
5 | import ase.io
6 | import numpy as np
7 | import tqdm
8 | import zntrack
9 |
10 |
11 | class FilteringType(str, Enum):
12 | """
13 | Enum defining types of filtering to apply on atomic configurations.
14 |
15 | Attributes
16 | ----------
17 | COMBINATIONS : str
18 | Filters to include atoms with elements that are any
19 | subset of the specified elements.
20 | EXCLUSIVE : str
21 |         Filters to include atoms that contain *exactly* the specified elements, no more and no fewer.
22 | INCLUSIVE : str
23 | Filters to include atoms that contain *at least* the specified elements.
24 | """
25 |
26 | COMBINATIONS = "combinations"
27 | EXCLUSIVE = "exclusive"
28 | INCLUSIVE = "inclusive"
29 |
30 |
31 | def filter_atoms(
32 | atoms: ase.Atoms,
33 | element_subset: list[str],
34 | filtering_type: t.Optional[FilteringType] = None,
35 | ) -> bool:
36 | """
37 | Filters an atomic configuration based on the
38 | specified filtering type and element subset.
39 |
40 | Parameters
41 | ----------
42 | atoms : ase.Atoms
43 | Atomic configuration to filter.
44 | element_subset : list[str]
45 | List of elements to be considered during filtering.
46 | filtering_type : FilteringType, optional
47 | Type of filtering to apply (COMBINATIONS, EXCLUSIVE, or INCLUSIVE).
48 | If None, all atoms pass the filter.
49 |
50 | Returns
51 | -------
52 | bool
53 | True if the atomic configuration passes the filter, False otherwise.
54 |
55 | References
56 | ----------
57 | Adapted from github.com/ACEsuit/mace
58 |
59 | Raises
60 | ------
61 | ValueError
62 | If the provided filtering_type is not recognized.
63 | """
64 | if filtering_type is None:
65 | return True
66 | elif filtering_type == FilteringType.COMBINATIONS:
67 | atom_symbols = np.unique(atoms.symbols)
68 | return all(x in element_subset for x in atom_symbols)
69 | elif filtering_type == FilteringType.EXCLUSIVE:
70 | atom_symbols = set(atoms.symbols)
71 | return atom_symbols == set(element_subset)
72 | elif filtering_type == FilteringType.INCLUSIVE:
73 | atom_symbols = np.unique(atoms.symbols)
74 | return all(x in atom_symbols for x in element_subset)
75 | else:
76 | raise ValueError(
77 | f"Filtering type {filtering_type} not recognized."
78 | " Must be one of 'none', 'exclusive', or 'inclusive'."
79 | )
80 |
81 |
82 | class FilterAtoms(zntrack.Node):
83 | """
84 | ZnTrack node that filters a list of atomic configurations
85 | based on specified elements and filtering type.
86 |
87 | Attributes
88 | ----------
89 | data : list[ase.Atoms]
90 | List of atomic configurations to filter.
91 | elements : list[str]
92 | List of elements to use as the filtering subset.
93 | filtering_type : FilteringType
94 | Type of filtering to apply (INCLUSIVE, EXCLUSIVE, or COMBINATIONS).
95 | frames_path : pathlib.Path
96 | Path to store filtered atomic configuration frames.
97 |
98 | Methods
99 | -------
100 | run()
101 | Filters atomic configurations and writes the results to `frames_path`.
102 | frames() -> list[ase.Atoms]
103 | Loads filtered atomic configurations from `frames_path`.
104 | """
105 |
106 | data: list[ase.Atoms] = zntrack.deps()
107 | elements: list[str] = zntrack.params()
108 | filtering_type: FilteringType = zntrack.params(FilteringType.INCLUSIVE.value)
109 | frames_path: pathlib.Path = zntrack.outs_path(zntrack.nwd / "frames.xyz")
110 |
111 | def run(self):
112 | """
113 | Applies filtering to atomic configurations in
114 | `data` and saves the results to `frames_path`.
115 | """
116 | for atoms in tqdm.tqdm(self.data):
117 | if filter_atoms(atoms, self.elements, self.filtering_type):
118 | ase.io.write(self.frames_path, atoms, append=True)
119 |
120 | @property
121 | def frames(self) -> list[ase.Atoms]:
122 | """Loads the filtered atomic configurations from the `frames_path` file.
123 |
124 | Returns
125 | -------
126 | list[ase.Atoms]
127 | List of filtered atomic configuration frames.
128 | """
129 | with self.state.fs.open(self.frames_path, "r") as f:
130 | return list(ase.io.iread(f, format="extxyz"))
131 |
--------------------------------------------------------------------------------
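
A worked example of the three filtering modes implemented by `filter_atoms` above; the structures are illustrative.

```python
import ase

from mlipx.nodes.filter_dataset import FilteringType, filter_atoms

water = ase.Atoms("H2O")
methane = ase.Atoms("CH4")

# INCLUSIVE: the structure must contain at least the listed elements.
assert filter_atoms(water, ["O"], FilteringType.INCLUSIVE)
assert not filter_atoms(methane, ["O"], FilteringType.INCLUSIVE)

# EXCLUSIVE: the structure must contain exactly the listed elements.
assert filter_atoms(water, ["H", "O"], FilteringType.EXCLUSIVE)
assert not filter_atoms(water, ["H", "O", "C"], FilteringType.EXCLUSIVE)

# COMBINATIONS: the structure's elements must be a subset of the listed elements.
assert filter_atoms(methane, ["C", "H", "O"], FilteringType.COMBINATIONS)
assert not filter_atoms(methane, ["C", "O"], FilteringType.COMBINATIONS)
```
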
/mlipx/nodes/formation_energy.py:
--------------------------------------------------------------------------------
1 | import typing as t
2 |
3 | import ase
4 | import pandas as pd
5 | import zntrack
6 | from tqdm import tqdm, trange
7 |
8 | from mlipx.abc import ASEKeys, NodeWithCalculator
9 | from mlipx.utils import rmse
10 |
11 |
12 | class CalculateFormationEnergy(zntrack.Node):
13 | """
14 | Calculate formation energy.
15 |
16 | Parameters
17 | ----------
18 | data : list[ase.Atoms]
19 |         List of ASE Atoms objects; isolated energies must be present in `info` if no model is given.
20 | """
21 |
22 | data: list[ase.Atoms] = zntrack.deps()
23 | model: t.Optional[NodeWithCalculator] = zntrack.deps(None)
24 |
25 | formation_energy: list = zntrack.outs(independent=True)
26 | isolated_energies: dict = zntrack.outs(independent=True)
27 |
28 | plots: pd.DataFrame = zntrack.plots(
29 | y=["eform", "n_atoms"], independent=True, autosave=True
30 | )
31 |
32 | def get_isolated_energies(self) -> dict[str, float]:
33 | # get all unique elements
34 | isolated_energies = {}
35 | for atoms in tqdm(self.data, desc="Getting isolated energies"):
36 | for element in set(atoms.get_chemical_symbols()):
37 | if self.model is None:
38 | if element not in isolated_energies:
39 | isolated_energies[element] = atoms.info[
40 | ASEKeys.isolated_energies.value
41 | ][element]
42 | else:
43 | assert (
44 | isolated_energies[element]
45 | == atoms.info[ASEKeys.isolated_energies.value][element]
46 | )
47 | else:
48 | if element not in isolated_energies:
49 | box = ase.Atoms(
50 | element,
51 | positions=[[50, 50, 50]],
52 | cell=[100, 100, 100],
53 | pbc=True,
54 | )
55 | box.calc = self.model.get_calculator()
56 | isolated_energies[element] = box.get_potential_energy()
57 |
58 | return isolated_energies
59 |
60 | def run(self):
61 | self.formation_energy = []
62 | self.isolated_energies = self.get_isolated_energies()
63 |
64 | plots = []
65 |
66 | for atoms in self.data:
67 | chem = atoms.get_chemical_symbols()
68 | reference_energy = 0
69 | for element in chem:
70 | reference_energy += self.isolated_energies[element]
71 | E_form = atoms.get_potential_energy() - reference_energy
72 | self.formation_energy.append(E_form)
73 | plots.append({"eform": E_form, "n_atoms": len(atoms)})
74 |
75 | self.plots = pd.DataFrame(plots)
76 |
77 | @property
78 | def frames(self):
79 | for atom, energy in zip(self.data, self.formation_energy):
80 | atom.info[ASEKeys.formation_energy.value] = energy
81 | return self.data
82 |
83 |
84 | class CompareFormationEnergy(zntrack.Node):
85 | data: CalculateFormationEnergy = zntrack.deps()
86 | reference: CalculateFormationEnergy = zntrack.deps()
87 |
88 | plots: pd.DataFrame = zntrack.plots(autosave=True)
89 | rmse: dict = zntrack.metrics()
90 | error: dict = zntrack.metrics()
91 |
92 | def run(self):
93 | eform_rmse = rmse(self.data.plots["eform"], self.reference.plots["eform"])
94 | # e_rmse = rmse(self.data.plots["energy"], self.reference.plots["energy"])
95 | self.rmse = {
96 | "eform": eform_rmse,
97 | "eform_per_atom": eform_rmse / len(self.data.plots),
98 | }
99 |
100 | all_plots = []
101 |
102 | for row_idx in trange(len(self.data.plots)):
103 | plots = {}
104 | plots["adjusted_eform_error"] = (
105 | self.data.plots["eform"].iloc[row_idx] - eform_rmse
106 | ) - self.reference.plots["eform"].iloc[row_idx]
107 | plots["adjusted_eform"] = (
108 | self.data.plots["eform"].iloc[row_idx] - eform_rmse
109 | )
110 | plots["adjusted_eform_error_per_atom"] = (
111 | plots["adjusted_eform_error"] / self.data.plots["n_atoms"].iloc[row_idx]
112 | )
113 | all_plots.append(plots)
114 | self.plots = pd.DataFrame(all_plots)
115 |
116 | # iterate over plots and save min/max
117 | self.error = {}
118 | for key in self.plots.columns:
119 | if "_error" in key:
120 | stripped_key = key.replace("_error", "")
121 | self.error[f"{stripped_key}_max"] = self.plots[key].max()
122 | self.error[f"{stripped_key}_min"] = self.plots[key].min()
123 |
--------------------------------------------------------------------------------
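
The formation energy computed in `run` above is the total energy minus the sum of the isolated-atom energies. A tiny worked example with made-up numbers:

```python
# Hypothetical isolated-atom energies in eV and a hypothetical total energy for H2O.
isolated_energies = {"H": -0.5, "O": -1.5}
e_total = -14.2

reference_energy = 2 * isolated_energies["H"] + 1 * isolated_energies["O"]  # -2.5 eV
e_form = e_total - reference_energy  # -14.2 - (-2.5) = -11.7 eV
print(e_form)
```
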
/mlipx/nodes/generic_ase.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import importlib
3 | import typing as t
4 |
5 | from ase.calculators.calculator import Calculator
6 |
7 |
8 | class Device:
9 | AUTO = "auto"
10 | CPU = "cpu"
11 | CUDA = "cuda"
12 |
13 | @staticmethod
14 | def resolve_auto() -> t.Literal["cpu", "cuda"]:
15 | import torch
16 |
17 | return "cuda" if torch.cuda.is_available() else "cpu"
18 |
19 |
20 | # TODO: add files as dependencies somehow!
21 |
22 |
23 | @dataclasses.dataclass
24 | class GenericASECalculator:
25 | """Generic ASE calculator.
26 |
27 | Load any ASE calculator from a module and class name.
28 |
29 | Parameters
30 | ----------
31 | module : str
32 | Module name containing the calculator class.
33 | For LJ this would be 'ase.calculators.lj'.
34 | class_name : str
35 | Class name of the calculator.
36 | For LJ this would be 'LennardJones'.
37 | kwargs : dict, default=None
38 | Additional keyword arguments to pass to the calculator.
39 | For LJ this could be {'epsilon': 1.0, 'sigma': 1.0}.
40 | """
41 |
42 | module: str
43 | class_name: str
44 | kwargs: dict[str, t.Any] | None = None
45 | device: t.Literal["auto", "cpu", "cuda"] | None = None
46 |
47 | def get_calculator(self, **kwargs) -> Calculator:
48 | if self.kwargs is not None:
49 | kwargs.update(self.kwargs)
50 | module = importlib.import_module(self.module)
51 | cls = getattr(module, self.class_name)
52 | if self.device is None:
53 | return cls(**kwargs)
54 | elif self.device == "auto":
55 | return cls(**kwargs, device=Device.resolve_auto())
56 | else:
57 | return cls(**kwargs, device=self.device)
58 |
59 | @property
60 | def available(self) -> bool:
61 | try:
62 | importlib.import_module(self.module)
63 | return True
64 | except ImportError:
65 | return False
66 |
--------------------------------------------------------------------------------
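
A minimal sketch based on the Lennard-Jones example from the docstring above:

```python
from ase.build import molecule

from mlipx.nodes.generic_ase import GenericASECalculator

lj = GenericASECalculator(
    module="ase.calculators.lj",
    class_name="LennardJones",
    kwargs={"epsilon": 1.0, "sigma": 1.0},
)

atoms = molecule("H2O")
atoms.calc = lj.get_calculator()
print(atoms.get_potential_energy())
```
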
/mlipx/nodes/invariances.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 |
3 | import ase
4 | import ase.io
5 | import numpy as np
6 | import pandas as pd
7 | import plotly.graph_objects as go
8 | import tqdm
9 | import zntrack
10 |
11 | from mlipx.abc import ComparisonResults, NodeWithCalculator
12 |
13 |
14 | class InvarianceNode(zntrack.Node):
15 | """Base class for testing invariances."""
16 |
17 | model: NodeWithCalculator = zntrack.deps()
18 | data: list[ase.Atoms] = zntrack.deps()
19 | data_id: int = zntrack.params(-1)
20 | n_points: int = zntrack.params(50)
21 |
22 | metrics: dict = zntrack.metrics()
23 | plots: pd.DataFrame = zntrack.plots()
24 |
25 | frames_path: pathlib.Path = zntrack.outs_path(zntrack.nwd / "frames.xyz")
26 |
27 | def run(self):
28 | """Common logic for invariance testing."""
29 | atoms = self.data[self.data_id]
30 | calc = self.model.get_calculator()
31 | atoms.calc = calc
32 |
33 | rng = np.random.default_rng()
34 | energies = []
35 | for _ in tqdm.tqdm(range(self.n_points)):
36 | self.apply_transformation(atoms, rng)
37 | energies.append(atoms.get_potential_energy())
38 | ase.io.write(self.frames_path, atoms, append=True)
39 |
40 | self.plots = pd.DataFrame(energies, columns=["energy"])
41 |
42 | self.metrics = {
43 | "mean": float(np.mean(energies)),
44 | "std": float(np.std(energies)),
45 | }
46 |
47 | @property
48 | def frames(self):
49 | with self.state.fs.open(self.frames_path, "r") as f:
50 | return list(ase.io.iread(f, ":"))
51 |
52 | def apply_transformation(self, atoms_copy: ase.Atoms, rng: np.random.Generator):
53 | """To be implemented by child classes."""
54 | raise NotImplementedError("Subclasses must implement apply_transformation()")
55 |
56 | @staticmethod
57 | def compare(*nodes: "InvarianceNode") -> ComparisonResults:
58 | if len(nodes) == 0:
59 | raise ValueError("No nodes to compare")
60 |
61 | fig = go.Figure()
62 | for node in nodes:
63 | fig.add_trace(
64 | go.Scatter(
65 | x=np.arange(node.n_points),
66 | y=node.plots["energy"] - node.metrics["mean"],
67 | mode="markers",
68 | name=node.name.replace(f"_{node.__class__.__name__}", ""),
69 | )
70 | )
71 |
72 | fig.update_layout(
73 | title=f"Energy vs step ({nodes[0].__class__.__name__})",
74 | xaxis_title="Steps",
75 | yaxis_title="Adjusted energy",
76 | plot_bgcolor="rgba(0, 0, 0, 0)",
77 | paper_bgcolor="rgba(0, 0, 0, 0)",
78 | )
79 | fig.update_xaxes(
80 | showgrid=True,
81 | gridwidth=1,
82 | gridcolor="rgba(120, 120, 120, 0.3)",
83 | zeroline=False,
84 | )
85 | fig.update_yaxes(
86 | showgrid=True,
87 | gridwidth=1,
88 | gridcolor="rgba(120, 120, 120, 0.3)",
89 | zeroline=False,
90 | )
91 |
92 | return ComparisonResults(
93 | frames=nodes[0].frames, figures={"energy_vs_steps_adjusted": fig}
94 | )
95 |
96 |
97 | class TranslationalInvariance(InvarianceNode):
98 | """Test translational invariance by random box translocation."""
99 |
100 | def apply_transformation(self, atoms_copy: ase.Atoms, rng: np.random.Generator):
101 | translation = rng.uniform(-1, 1, 3)
102 | atoms_copy.positions += translation
103 |
104 |
105 | class RotationalInvariance(InvarianceNode):
106 | """Test rotational invariance by random rotation of the box."""
107 |
108 | def apply_transformation(self, atoms_copy: ase.Atoms, rng: np.random.Generator):
109 | angle = rng.uniform(-90, 90)
110 | direction = rng.choice(["x", "y", "z"])
111 | atoms_copy.rotate(angle, direction, rotate_cell=any(atoms_copy.pbc))
112 |
113 |
114 | class PermutationInvariance(InvarianceNode):
115 | """Test permutation invariance by permutation of atoms of the same species."""
116 |
117 | def apply_transformation(self, atoms_copy: ase.Atoms, rng: np.random.Generator):
118 | species = np.unique(atoms_copy.get_chemical_symbols())
119 | for s in species:
120 |             indices = np.where(np.array(atoms_copy.get_chemical_symbols()) == s)[0]
121 | rng.shuffle(indices)
122 | atoms_copy.positions[indices] = rng.permutation(
123 | atoms_copy.positions[indices], axis=0
124 | )
125 |
--------------------------------------------------------------------------------
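
A hypothetical sketch of extending `InvarianceNode`: `apply_transformation` is the only method a subclass has to provide. The class below is not part of mlipx; it illustrates the pattern with a whole-lattice-vector translation, which should leave the energy of a periodic structure unchanged.

```python
import ase
import numpy as np

from mlipx.nodes.invariances import InvarianceNode


class LatticeTranslationInvariance(InvarianceNode):
    """Translate all atoms by a random whole lattice vector (periodic systems only)."""

    def apply_transformation(self, atoms_copy: ase.Atoms, rng: np.random.Generator):
        n = rng.integers(1, 3, size=3)  # random integer multiples of the cell vectors
        atoms_copy.positions += n @ np.asarray(atoms_copy.cell)
```
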
/mlipx/nodes/io.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | import typing as t
3 |
4 | import ase.io
5 | import h5py
6 | import znh5md
7 | import zntrack
8 |
9 |
10 | class LoadDataFile(zntrack.Node):
11 | """Load a trajectory file.
12 |
13 | Entry point of trajectory data for the use in other nodes.
14 |
15 | Parameters
16 | ----------
17 | path : str | pathlib.Path
18 | Path to the trajectory file.
19 | start : int, default=0
20 | Index of the first frame to load.
21 | stop : int, default=None
22 | Index of the last frame to load.
23 | step : int, default=1
24 | Step size between frames.
25 |
26 | Attributes
27 | ----------
28 | frames : list[ase.Atoms]
29 | Loaded frames.
30 | """
31 |
32 | path: str | pathlib.Path = zntrack.deps_path()
33 |     # start, stop, and step are applied when loading the frames below
34 | start: int = zntrack.params(0)
35 | stop: t.Optional[int] = zntrack.params(None)
36 | step: int = zntrack.params(1)
37 |
38 | def run(self):
39 | pass
40 |
41 | @property
42 | def frames(self) -> list[ase.Atoms]:
43 | if pathlib.Path(self.path).suffix in [".h5", ".h5md"]:
44 | with self.state.fs.open(self.path, "rb") as f:
45 | with h5py.File(f) as file:
46 | return znh5md.IO(file_handle=file)[
47 | self.start : self.stop : self.step
48 | ]
49 | else:
50 | format = pathlib.Path(self.path).suffix.lstrip(".")
51 | if format == "xyz":
52 | format = "extxyz" # force ase to use the extxyz reader
53 | with self.state.fs.open(self.path, "r") as f:
54 | return list(
55 | ase.io.iread(
56 | f, format=format, index=slice(self.start, self.stop, self.step)
57 | )
58 | )
59 |
--------------------------------------------------------------------------------
/mlipx/nodes/modifier.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import typing as t
3 |
4 | from ase import units
5 |
6 | from mlipx.abc import DynamicsModifier
7 |
8 |
9 | @dataclasses.dataclass
10 | class TemperatureRampModifier(DynamicsModifier):
11 | """Ramp the temperature from start_temperature to temperature.
12 |
13 | Attributes
14 | ----------
15 | start_temperature: float, optional
16 | temperature to start from, if None, the temperature of the thermostat is used.
17 | end_temperature: float
18 | temperature to ramp to.
19 | interval: int, default 1
20 | interval in which the temperature is changed.
21 | total_steps: int
22 | total number of steps in the simulation.
23 |
24 | References
25 | ----------
26 | Code taken from ipsuite/calculators/ase_md.py
27 | """
28 |
29 | end_temperature: float
30 | total_steps: int
31 | start_temperature: t.Optional[float] = None
32 | interval: int = 1
33 |
34 | def modify(self, thermostat, step):
35 | # we use the thermostat, so we can also modify e.g. temperature
36 | if self.start_temperature is None:
37 | # different thermostats call the temperature attribute differently
38 | if temp := getattr(thermostat, "temp", None):
39 | self.start_temperature = temp / units.kB
40 | elif temp := getattr(thermostat, "temperature", None):
41 | self.start_temperature = temp / units.kB
42 | else:
43 | raise AttributeError("No temperature attribute found in thermostat.")
44 |
45 | percentage = step / (self.total_steps - 1)
46 | new_temperature = (
47 | 1 - percentage
48 | ) * self.start_temperature + percentage * self.end_temperature
49 | if step % self.interval == 0:
50 | thermostat.set_temperature(temperature_K=new_temperature)
51 |
--------------------------------------------------------------------------------
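
A worked example of the linear ramp in `modify` above, going from 300 K to 400 K over 100 steps:

```python
start_temperature, end_temperature, total_steps = 300.0, 400.0, 100

for step in (0, 49, 99):
    percentage = step / (total_steps - 1)
    new_temperature = (1 - percentage) * start_temperature + percentage * end_temperature
    print(step, round(new_temperature, 1))  # 0 -> 300.0, 49 -> 349.5, 99 -> 400.0
```
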
/mlipx/nodes/molecular_dynamics.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import pathlib
3 |
4 | import ase.io
5 | import ase.units
6 | import numpy as np
7 | import pandas as pd
8 | import plotly.express as px
9 | import plotly.graph_objects as go
10 | import tqdm
11 | import zntrack
12 | from ase.md import Langevin
13 |
14 | from mlipx.abc import (
15 | ComparisonResults,
16 | DynamicsModifier,
17 | DynamicsObserver,
18 | NodeWithCalculator,
19 | NodeWithMolecularDynamics,
20 | )
21 |
22 |
23 | @dataclasses.dataclass
24 | class LangevinConfig:
25 | """Configure a Langevin thermostat for molecular dynamics.
26 |
27 | Parameters
28 | ----------
29 | timestep : float
30 | Time step for the molecular dynamics simulation in fs.
31 | temperature : float
32 |         Temperature of the thermostat in K.
33 | friction : float
34 | Friction coefficient of the thermostat.
35 | """
36 |
37 | timestep: float
38 | temperature: float
39 | friction: float
40 |
41 | def get_molecular_dynamics(self, atoms) -> Langevin:
42 | return Langevin(
43 | atoms,
44 | timestep=self.timestep * ase.units.fs,
45 | temperature_K=self.temperature,
46 | friction=self.friction,
47 | )
48 |
49 |
50 | class MolecularDynamics(zntrack.Node):
51 | """Run molecular dynamics simulation.
52 |
53 | Parameters
54 | ----------
55 | model : NodeWithCalculator
56 | Node providing the calculator object for the simulation.
57 | thermostat : LangevinConfig
58 | Node providing the thermostat object for the simulation.
59 | data : list[ase.Atoms]
60 | Initial configurations for the simulation.
61 | data_id : int, default=-1
62 | Index of the initial configuration to use.
63 | steps : int, default=100
64 | Number of steps to run the simulation.
65 | """
66 |
67 | model: NodeWithCalculator = zntrack.deps()
68 | thermostat: NodeWithMolecularDynamics = zntrack.deps()
69 | data: list[ase.Atoms] = zntrack.deps()
70 | data_id: int = zntrack.params(-1)
71 | steps: int = zntrack.params(100)
72 | observers: list[DynamicsObserver] = zntrack.deps(None)
73 | modifiers: list[DynamicsModifier] = zntrack.deps(None)
74 |
75 | observer_metrics: dict = zntrack.metrics()
76 | plots: pd.DataFrame = zntrack.plots(y=["energy", "fmax"], autosave=True)
77 |
78 | frames_path: pathlib.Path = zntrack.outs_path(zntrack.nwd / "frames.xyz")
79 |
80 | def run(self):
81 | if self.observers is None:
82 | self.observers = []
83 | if self.modifiers is None:
84 | self.modifiers = []
85 | atoms = self.data[self.data_id]
86 | atoms.calc = self.model.get_calculator()
87 | dyn = self.thermostat.get_molecular_dynamics(atoms)
88 | for obs in self.observers:
89 | obs.initialize(atoms)
90 |
91 | self.observer_metrics = {}
92 | self.plots = pd.DataFrame(columns=["energy", "fmax", "fnorm"])
93 |
94 | for idx, _ in enumerate(
95 | tqdm.tqdm(dyn.irun(steps=self.steps), total=self.steps)
96 | ):
97 | ase.io.write(self.frames_path, atoms, append=True)
98 | plots = {
99 | "energy": atoms.get_potential_energy(),
100 | "fmax": np.max(np.linalg.norm(atoms.get_forces(), axis=1)),
101 | "fnorm": np.linalg.norm(atoms.get_forces()),
102 | }
103 | self.plots.loc[len(self.plots)] = plots
104 |
105 | for obs in self.observers:
106 | if obs.check(atoms):
107 | self.observer_metrics[obs.name] = idx
108 |
109 | if len(self.observer_metrics) > 0:
110 | break
111 |
112 | for mod in self.modifiers:
113 | mod.modify(dyn, idx)
114 |
115 | for obs in self.observers:
116 | # document all attached observers
117 | self.observer_metrics[obs.name] = self.observer_metrics.get(obs.name, -1)
118 |
119 | @property
120 | def frames(self) -> list[ase.Atoms]:
121 | with self.state.fs.open(self.frames_path, "r") as f:
122 | return list(ase.io.iread(f, format="extxyz"))
123 |
124 | @property
125 | def figures(self) -> dict[str, go.Figure]:
126 | plots = {}
127 | for key in self.plots.columns:
128 | fig = px.line(
129 | self.plots,
130 | x=self.plots.index,
131 | y=key,
132 | title=key,
133 | )
134 | fig.update_traces(
135 | customdata=np.stack([np.arange(len(self.plots))], axis=1),
136 | )
137 | plots[key] = fig
138 | return plots
139 |
140 | @staticmethod
141 | def compare(*nodes: "MolecularDynamics") -> ComparisonResults:
142 | frames = sum([node.frames for node in nodes], [])
143 | offset = 0
144 | fig = go.Figure()
145 | for _, node in enumerate(nodes):
146 | energies = [atoms.get_potential_energy() for atoms in node.frames]
147 | fig.add_trace(
148 | go.Scatter(
149 | x=list(range(len(energies))),
150 | y=energies,
151 | mode="lines+markers",
152 | name=node.name.replace(f"_{node.__class__.__name__}", ""),
153 | customdata=np.stack([np.arange(len(energies)) + offset], axis=1),
154 | )
155 | )
156 | offset += len(energies)
157 |
158 | fig.update_layout(
159 | title="Energy vs. step",
160 | xaxis_title="Step",
161 | yaxis_title="Energy",
162 | )
163 |
164 | fig.update_layout(
165 | plot_bgcolor="rgba(0, 0, 0, 0)",
166 | paper_bgcolor="rgba(0, 0, 0, 0)",
167 | )
168 | fig.update_xaxes(
169 | showgrid=True,
170 | gridwidth=1,
171 | gridcolor="rgba(120, 120, 120, 0.3)",
172 | zeroline=False,
173 | )
174 | fig.update_yaxes(
175 | showgrid=True,
176 | gridwidth=1,
177 | gridcolor="rgba(120, 120, 120, 0.3)",
178 | zeroline=False,
179 | )
180 |
181 |         # Now we set the first energy to zero for better comparability.
182 |
183 | offset = 0
184 | fig_adjusted = go.Figure()
185 | for _, node in enumerate(nodes):
186 | energies = np.array([atoms.get_potential_energy() for atoms in node.frames])
187 | energies -= energies[0]
188 | fig_adjusted.add_trace(
189 | go.Scatter(
190 | x=list(range(len(energies))),
191 | y=energies,
192 | mode="lines+markers",
193 | name=node.name.replace(f"_{node.__class__.__name__}", ""),
194 | customdata=np.stack([np.arange(len(energies)) + offset], axis=1),
195 | )
196 | )
197 | offset += len(energies)
198 |
199 | fig_adjusted.update_layout(
200 | title="Adjusted energy vs. step",
201 | xaxis_title="Step",
202 | yaxis_title="Adjusted energy",
203 | )
204 |
205 | fig_adjusted.update_layout(
206 | plot_bgcolor="rgba(0, 0, 0, 0)",
207 | paper_bgcolor="rgba(0, 0, 0, 0)",
208 | )
209 | fig_adjusted.update_xaxes(
210 | showgrid=True,
211 | gridwidth=1,
212 | gridcolor="rgba(120, 120, 120, 0.3)",
213 | zeroline=False,
214 | )
215 | fig_adjusted.update_yaxes(
216 | showgrid=True,
217 | gridwidth=1,
218 | gridcolor="rgba(120, 120, 120, 0.3)",
219 | zeroline=False,
220 | )
221 |
222 | return ComparisonResults(
223 | frames=frames,
224 | figures={"energy_vs_steps": fig, "energy_vs_steps_adjusted": fig_adjusted},
225 | )
226 |
--------------------------------------------------------------------------------
/mlipx/nodes/mp_api.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 |
3 | import ase.io
4 | import zntrack
5 | from mp_api import client
6 | from pymatgen.io.ase import AseAtomsAdaptor
7 |
8 |
9 | class MPRester(zntrack.Node):
10 | """Search the materials project database.
11 |
12 | Parameters
13 | ----------
14 | search_kwargs: dict
15 | The search parameters for the materials project database.
16 |
17 | Example
18 | -------
19 | >>> os.environ["MP_API_KEY"] = "your_api_key"
20 | >>> MPRester(search_kwargs={"material_ids": ["mp-1234"]})
21 |
22 | """
23 |
24 | search_kwargs: dict = zntrack.params()
25 | frames_path: pathlib.Path = zntrack.outs_path(zntrack.nwd / "frames.xyz")
26 |
27 | def run(self):
28 | with client.MPRester() as mpr:
29 | docs = mpr.materials.search(**self.search_kwargs)
30 |
31 | frames = []
32 | adaptor = AseAtomsAdaptor()
33 |
34 | for entry in docs:
35 | structure = entry.structure
36 | atoms = adaptor.get_atoms(structure)
37 | frames.append(atoms)
38 |
39 | ase.io.write(self.frames_path, frames)
40 |
41 | @property
42 | def frames(self) -> list[ase.Atoms]:
43 | with self.state.fs.open(self.frames_path, mode="r") as f:
44 | return list(ase.io.iread(f, format="extxyz"))
45 |
--------------------------------------------------------------------------------
/mlipx/nodes/observer.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import warnings
3 |
4 | import ase
5 | import numpy as np
6 |
7 | from mlipx.abc import DynamicsObserver
8 |
9 |
10 | @dataclasses.dataclass
11 | class MaximumForceObserver(DynamicsObserver):
12 | """Evaluate if the maximum force on a single atom exceeds a threshold.
13 |
14 | Parameters
15 | ----------
16 | f_max : float
17 | Maximum allowed force norm on a single atom
18 |
19 |
20 | Example
21 | -------
22 |
23 | >>> import zntrack, mlipx
24 | >>> project = zntrack.Project()
25 | >>> observer = mlipx.MaximumForceObserver(f_max=0.1)
26 | >>> with project:
27 | ... md = mlipx.MolecularDynamics(
28 | ... observers=[observer],
29 | ... **kwargs
30 | ... )
31 | >>> project.build()
32 | """
33 |
34 | f_max: float
35 |
36 | def check(self, atoms: ase.Atoms) -> bool:
37 | """Check if the maximum force on a single atom exceeds the threshold.
38 |
39 | Parameters
40 | ----------
41 | atoms : ase.Atoms
42 | Atoms object to evaluate
43 | """
44 |
45 | max_force = np.linalg.norm(atoms.get_forces(), axis=1).max()
46 | if max_force > self.f_max:
47 | warnings.warn(f"Maximum force {max_force} exceeds {self.f_max}")
48 | return True
49 | return False
50 |
--------------------------------------------------------------------------------
/mlipx/nodes/orca.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import os
3 | from pathlib import Path
4 |
5 | from ase.calculators.orca import ORCA, OrcaProfile
6 |
7 |
8 | @dataclasses.dataclass
9 | class OrcaSinglePoint:
10 | """Use ORCA to perform a single point calculation.
11 |
12 | Parameters
13 | ----------
14 | orcasimpleinput : str
15 | ORCA input string.
16 | You can use something like "PBE def2-TZVP TightSCF EnGrad".
17 | orcablocks : str
18 | ORCA input blocks.
19 | You can use something like "%pal nprocs 8 end".
20 | orca_shell : str, optional
21 | Path to the ORCA executable.
22 | The environment variable MLIPX_ORCA will be used if not provided.
23 | """
24 |
25 | orcasimpleinput: str
26 | orcablocks: str
27 | orca_shell: str | None = None
28 |
29 | def get_calculator(self, directory: str | Path) -> ORCA:
30 | profile = OrcaProfile(command=self.orca_shell or os.environ["MLIPX_ORCA"])
31 |
32 | calc = ORCA(
33 | profile=profile,
34 | orcasimpleinput=self.orcasimpleinput,
35 | orcablocks=self.orcablocks,
36 | directory=directory,
37 | )
38 | return calc
39 |
40 | @property
41 | def available(self) -> None:
42 | return None
43 |
--------------------------------------------------------------------------------
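
A minimal sketch (input strings and the executable path are illustrative) of attaching `OrcaSinglePoint` to an ASE Atoms object; evaluating energies requires a working ORCA installation.

```python
from ase.build import molecule

from mlipx.nodes.orca import OrcaSinglePoint

orca = OrcaSinglePoint(
    orcasimpleinput="PBE def2-TZVP TightSCF EnGrad",
    orcablocks="%pal nprocs 8 end",
    orca_shell="/path/to/orca",  # or set the MLIPX_ORCA environment variable instead
)

atoms = molecule("H2O")
atoms.calc = orca.get_calculator(directory="orca_h2o")
# atoms.get_potential_energy()  # uncomment when ORCA is available
```
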
/mlipx/nodes/phase_diagram.py:
--------------------------------------------------------------------------------
1 | # skip linting for this file
2 |
3 | import itertools
4 | import os
5 | import typing as t
6 |
7 | import ase.io
8 | import pandas as pd
9 | import plotly.express as px
10 | import plotly.graph_objects as go
11 | import zntrack
12 | from ase.optimize import BFGS
13 | from mp_api.client import MPRester
14 | from pymatgen.analysis.phase_diagram import PDPlotter
15 | from pymatgen.analysis.phase_diagram import PhaseDiagram as pmg_PhaseDiagram
16 | from pymatgen.entries.compatibility import (
17 | MaterialsProject2020Compatibility,
18 | )
19 | from pymatgen.entries.computed_entries import (
20 | ComputedEntry,
21 | )
22 |
23 | from mlipx.abc import ComparisonResults, NodeWithCalculator
24 |
25 |
26 | class PhaseDiagram(zntrack.Node):
27 | """Compute the phase diagram for a given set of structures.
28 |
29 | Parameters
30 | ----------
31 | data : list[ase.Atoms]
32 | List of structures to evaluate.
33 | model : NodeWithCalculator
34 | Node providing the calculator object for the energy calculations.
35 |     chemsys : list[str], default=None
36 |         The set of chemical symbols used to construct the phase diagram.
37 |     data_ids : list[int], default=None
38 |         Indices of the structures to evaluate.
39 |     geo_opt : bool, default=False
40 |         Whether to perform a geometry optimization before calculating the
41 |         formation energy of each structure.
42 |     fmax : float, default=0.05
43 |         The maximum-force convergence criterion for the geometry optimizations.
44 |
45 | Attributes
46 | ----------
47 | results : pd.DataFrame
48 | DataFrame with the data_id, potential energy and formation energy.
49 | plots : dict[str, go.Figure]
50 | Dictionary with the phase diagram (and formation energy plot).
51 |
52 | """
53 |
54 | model: NodeWithCalculator = zntrack.deps()
55 | data: list[ase.Atoms] = zntrack.deps()
56 | chemsys: list[str] = zntrack.params(None)
57 | data_ids: list[int] = zntrack.params(None)
58 | geo_opt: bool = zntrack.params(False)
59 | fmax: float = zntrack.params(0.05)
60 | frames_path: str = zntrack.outs_path(zntrack.nwd / "frames.xyz")
61 | results: pd.DataFrame = zntrack.plots(x="data_id", y="formation_energy")
62 | phase_diagram: t.Any = zntrack.outs()
63 |
64 | def run(self): # noqa C901
65 | if self.data_ids is None:
66 | atoms_list = self.data
67 | else:
68 |             atoms_list = [self.data[i] for i in self.data_ids]
69 | if self.model is not None:
70 | calc = self.model.get_calculator()
71 |
72 | U_metal_set = {"Co", "Cr", "Fe", "Mn", "Mo", "Ni", "V", "W"}
73 | U_settings = {
74 | "Co": 3.32,
75 | "Cr": 3.7,
76 | "Fe": 5.3,
77 | "Mn": 3.9,
78 | "Mo": 4.38,
79 | "Ni": 6.2,
80 | "V": 3.25,
81 | "W": 6.2,
82 | }
83 | try:
84 | api_key = os.environ["MP_API_KEY"]
85 | except KeyError:
86 | api_key = None
87 |
88 | entries, epots = [], []
89 | for atoms in atoms_list:
90 | metals = [s for s in set(atoms.symbols) if s not in ["O", "H"]]
91 | hubbards = {}
92 | if set(metals) & U_metal_set:
93 | run_type = "GGA+U"
94 | is_hubbard = True
95 | for m in metals:
96 | hubbards[m] = U_settings.get(m, 0)
97 | else:
98 | run_type = "GGA"
99 | is_hubbard = False
100 |
101 | if self.model is not None:
102 | atoms.calc = calc
103 | if self.geo_opt:
104 | dyn = BFGS(atoms)
105 | dyn.run(fmax=self.fmax)
106 | epot = atoms.get_potential_energy()
107 | ase.io.write(self.frames_path, atoms, append=True)
108 | epots.append(epot)
109 | amt_dict = {
110 | m: len([a for a in atoms if a.symbol == m]) for m in set(atoms.symbols)
111 | }
112 | entry = ComputedEntry(
113 | composition=amt_dict,
114 | energy=epot,
115 | parameters={
116 | "run_type": run_type,
117 | "software": "N/A",
118 | "oxide_type": "oxide",
119 | "is_hubbard": is_hubbard,
120 | "hubbards": hubbards,
121 | },
122 | )
123 | entries.append(entry)
124 | compat = MaterialsProject2020Compatibility()
125 | computed_entries = compat.process_entries(entries)
126 | if api_key is None:
127 | mp_entries = []
128 | else:
129 | mpr = MPRester(api_key)
130 | if self.chemsys is None:
131 | chemsys = set(
132 | itertools.chain.from_iterable(atoms.symbols for atoms in atoms_list)
133 | )
134 | else:
135 | chemsys = self.chemsys
136 | mp_entries = mpr.get_entries_in_chemsys(chemsys)
137 | all_entries = computed_entries + mp_entries
138 | self.phase_diagram = pmg_PhaseDiagram(all_entries)
139 |
140 | row_dicts = []
141 | for i, entry in enumerate(computed_entries):
142 | if self.data_ids is None:
143 | data_id = i
144 | else:
145 |                 data_id = self.data_ids[i]
146 | eform = self.phase_diagram.get_form_energy_per_atom(entry)
147 | row_dicts.append(
148 | {
149 | "data_id": data_id,
150 | "potential_energy": epots[i],
151 | "formation_energy": eform,
152 | },
153 | )
154 | self.results = pd.DataFrame(row_dicts)
155 |
156 | @property
157 | def figures(self) -> dict[str, go.Figure]:
158 | plotter = PDPlotter(self.phase_diagram)
159 | fig1 = plotter.get_plot()
160 | fig2 = px.line(self.results, x="data_id", y="formation_energy")
161 | fig2.update_layout(title="Formation Energy Plot")
162 | pd_df = pd.DataFrame(
163 | [len(self.phase_diagram.stable_entries)], columns=["Stable_phases"]
164 | )
165 | fig3 = px.bar(pd_df, y="Stable_phases")
166 |
167 | return {
168 | "phase-diagram": fig1,
169 | "formation-energy-plot": fig2,
170 | "stable_phases": fig3,
171 | }
172 |
173 | @staticmethod
174 | def compare(*nodes: "PhaseDiagram") -> ComparisonResults:
175 | figures = {}
176 |
177 | for node in nodes:
178 | # Extract a unique identifier for the node
179 | node_identifier = node.name.replace(f"_{node.__class__.__name__}", "")
180 |
181 | # Update and store the figures directly
182 | for key, fig in node.figures.items():
183 | fig.update_layout(
184 | title=node_identifier,
185 | plot_bgcolor="rgba(0, 0, 0, 0)",
186 | paper_bgcolor="rgba(0, 0, 0, 0)",
187 | )
188 | fig.update_xaxes(
189 | showgrid=True,
190 | gridwidth=1,
191 | gridcolor="rgba(120, 120, 120, 0.3)",
192 | zeroline=False,
193 | )
194 | fig.update_yaxes(
195 | showgrid=True,
196 | gridwidth=1,
197 | gridcolor="rgba(120, 120, 120, 0.3)",
198 | zeroline=False,
199 | )
200 | figures[f"{node_identifier}-{key}"] = fig
201 |
202 | return {
203 | "frames": nodes[0].frames,
204 | "figures": figures,
205 | }
206 |
207 | @property
208 | def frames(self) -> list[ase.Atoms]:
209 | with self.state.fs.open(self.frames_path, "r") as f:
210 | return list(ase.io.iread(f, format="extxyz"))
211 |
--------------------------------------------------------------------------------
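A minimal project sketch wiring the PhaseDiagram node above to a model; the data file path is a placeholder and the MACE calculator is only one possible choice. Reference entries from the Materials Project are included automatically when MP_API_KEY is set:

    import zntrack

    import mlipx

    project = zntrack.Project()

    model = mlipx.GenericASECalculator(
        module="mace.calculators",
        class_name="mace_mp",
        device="auto",
    )

    with project.group("initialize"):
        data = mlipx.LoadDataFile(path="structures.xyz")  # placeholder path

    with project.group("mace-mp"):
        phase_diagram = mlipx.PhaseDiagram(data=data.frames, model=model, geo_opt=True)

    project.build()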
/mlipx/nodes/rattle.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 |
3 | import ase.io
4 | import zntrack
5 |
6 |
7 | class Rattle(zntrack.Node):
8 | data: list[ase.Atoms] = zntrack.deps()
9 | stdev: float = zntrack.params(0.001)
10 | seed: int = zntrack.params(42)
11 |
12 | frames_path: pathlib.Path = zntrack.outs_path(zntrack.nwd / "frames.xyz")
13 |
14 | def run(self):
15 | for atoms in self.data:
16 | atoms.rattle(stdev=self.stdev, seed=self.seed)
17 | ase.io.write(self.frames_path, atoms, append=True)
18 |
19 | @property
20 | def frames(self) -> list[ase.Atoms]:
21 | with self.state.fs.open(self.frames_path, "r") as f:
22 | return list(ase.io.iread(f, format="extxyz"))
23 |
--------------------------------------------------------------------------------
/mlipx/nodes/smiles.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 |
3 | import ase
4 | import ase.io as aio
5 | import zntrack
6 |
7 |
8 | class Smiles2Conformers(zntrack.Node):
9 | """Create conformers from a SMILES string.
10 |
11 | Parameters
12 | ----------
13 | smiles : str
14 | The SMILES string.
15 | num_confs : int
16 | The number of conformers to generate.
17 | random_seed : int
18 | The random seed.
19 | max_attempts : int
20 | The maximum number of attempts.
21 | """
22 |
23 | smiles: str = zntrack.params()
24 | num_confs: int = zntrack.params()
25 | random_seed: int = zntrack.params(42)
26 | max_attempts: int = zntrack.params(1000)
27 |
28 | frames_path: pathlib.Path = zntrack.outs_path(zntrack.nwd / "frames.xyz")
29 |
30 | def run(self):
31 | from rdkit2ase import smiles2conformers
32 |
33 | conformers = smiles2conformers(
34 | self.smiles,
35 | numConfs=self.num_confs,
36 | randomSeed=self.random_seed,
37 | maxAttempts=self.max_attempts,
38 | )
39 | aio.write(self.frames_path, conformers)
40 |
41 | @property
42 | def frames(self) -> list[ase.Atoms]:
43 | with self.state.fs.open(self.frames_path, "r") as f:
44 | return list(aio.iread(f, format="extxyz"))
45 |
46 |
47 | class BuildBox(zntrack.Node):
48 | """Build a box from a list of atoms.
49 |
50 | Parameters
51 | ----------
52 | data : list[list[ase.Atoms]]
53 | A list of lists of ASE Atoms objects representing
54 | the molecules to be packed.
55 | counts : list[int]
56 | A list of integers representing the number of each
57 | type of molecule.
58 | density : float
59 | The target density of the packed system in kg/m^3
60 |
61 | """
62 |
63 | data: list[list[ase.Atoms]] = zntrack.deps()
64 | counts: list[int] = zntrack.params()
65 | density: float = zntrack.params(1000)
66 | frames_path: pathlib.Path = zntrack.outs_path(zntrack.nwd / "frames.xyz")
67 |
68 | def run(self):
69 | from rdkit2ase import pack
70 |
71 | atoms = pack(data=self.data, counts=self.counts, density=self.density)
72 | aio.write(self.frames_path, atoms)
73 |
74 | @property
75 | def frames(self) -> list[ase.Atoms]:
76 | with self.state.fs.open(self.frames_path, "r") as f:
77 | return list(aio.iread(f, format="extxyz"))
78 |
--------------------------------------------------------------------------------
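A short sketch combining the two nodes above to pack a small ethanol/water box; the conformer counts, molecule counts, and density are illustrative assumptions:

    import zntrack

    import mlipx

    project = zntrack.Project()

    with project.group("initialize"):
        ethanol = mlipx.Smiles2Conformers(smiles="CCO", num_confs=5)
        water = mlipx.Smiles2Conformers(smiles="O", num_confs=5)
        box = mlipx.BuildBox(
            data=[ethanol.frames, water.frames],
            counts=[8, 32],
            density=900,  # kg/m^3
        )

    project.build()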
/mlipx/nodes/structure_optimization.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 |
3 | import ase.io
4 | import ase.optimize as opt
5 | import numpy as np
6 | import pandas as pd
7 | import plotly.graph_objects as go
8 | import zntrack
9 |
10 | from mlipx.abc import ComparisonResults, NodeWithCalculator, Optimizer
11 |
12 |
13 | class StructureOptimization(zntrack.Node):
14 | """Structure optimization Node.
15 |
16 | Relax the geometry for the selected `ase.Atoms`.
17 |
18 | Parameters
19 | ----------
20 | data : list[ase.Atoms]
21 | Atoms to relax.
22 | data_id: int, default=-1
23 | The index of the ase.Atoms in `data` to optimize.
24 | optimizer : Optimizer
25 | Optimizer to use.
26 | model : NodeWithCalculator
27 | Model to use.
28 | fmax : float
29 | Maximum force to reach before stopping.
30 | steps : int
31 | Maximum number of steps for each optimization.
32 |     plots : pd.DataFrame
33 |         Resulting energy and fmax for each optimization step.
34 |     frames_path : pathlib.Path
35 |         Output file for the optimization trajectory.
36 |
37 | """
38 |
39 | data: list[ase.Atoms] = zntrack.deps()
40 | data_id: int = zntrack.params(-1)
41 | optimizer: Optimizer = zntrack.params(Optimizer.LBFGS.value)
42 | model: NodeWithCalculator = zntrack.deps()
43 | fmax: float = zntrack.params(0.05)
44 | steps: int = zntrack.params(100_000_000)
45 | plots: pd.DataFrame = zntrack.plots(y=["energy", "fmax"], x="step")
46 |
47 | frames_path: pathlib.Path = zntrack.outs_path(zntrack.nwd / "frames.traj")
48 |
49 | def run(self):
50 | optimizer = getattr(opt, self.optimizer)
51 | calc = self.model.get_calculator()
52 |
53 | atoms = self.data[self.data_id]
54 | self.frames_path.parent.mkdir(exist_ok=True)
55 |
56 | energies = []
57 | fmax = []
58 |
59 | def metrics_callback():
60 | energies.append(atoms.get_potential_energy())
61 | fmax.append(np.linalg.norm(atoms.get_forces(), axis=-1).max())
62 |
63 | atoms.calc = calc
64 | dyn = optimizer(
65 | atoms,
66 | trajectory=self.frames_path.as_posix(),
67 | )
68 | dyn.attach(metrics_callback)
69 | dyn.run(fmax=self.fmax, steps=self.steps)
70 |
71 | self.plots = pd.DataFrame({"energy": energies, "fmax": fmax})
72 | self.plots.index.name = "step"
73 |
74 | @property
75 | def frames(self) -> list[ase.Atoms]:
76 | with self.state.fs.open(self.frames_path, "rb") as f:
77 | return list(ase.io.iread(f, format="traj"))
78 |
79 | @property
80 | def figures(self) -> dict[str, go.Figure]:
81 | figure = go.Figure()
82 |
83 | energies = [atoms.get_potential_energy() for atoms in self.frames]
84 | figure.add_trace(
85 | go.Scatter(
86 | x=list(range(len(energies))),
87 | y=energies,
88 | mode="lines+markers",
89 | customdata=np.stack([np.arange(len(energies))], axis=1),
90 | )
91 | )
92 |
93 | figure.update_layout(
94 | title="Energy vs. Steps",
95 | xaxis_title="Step",
96 | yaxis_title="Energy",
97 | )
98 |
99 | ffigure = go.Figure()
100 | ffigure.add_trace(
101 | go.Scatter(
102 | x=self.plots.index,
103 | y=self.plots["fmax"],
104 | mode="lines+markers",
105 | customdata=np.stack([np.arange(len(energies))], axis=1),
106 | )
107 | )
108 |
109 | ffigure.update_layout(
110 | title="Fmax vs. Steps",
111 | xaxis_title="Step",
112 | yaxis_title="Maximum force",
113 | )
114 | return {"energy_vs_steps": figure, "fmax_vs_steps": ffigure}
115 |
116 | @staticmethod
117 | def compare(*nodes: "StructureOptimization") -> ComparisonResults:
118 | frames = sum([node.frames for node in nodes], [])
119 | offset = 0
120 | fig = go.Figure()
121 | for idx, node in enumerate(nodes):
122 | energies = [atoms.get_potential_energy() for atoms in node.frames]
123 | fig.add_trace(
124 | go.Scatter(
125 | x=list(range(len(energies))),
126 | y=energies,
127 | mode="lines+markers",
128 | name=node.name.replace(f"_{node.__class__.__name__}", ""),
129 | customdata=np.stack([np.arange(len(energies)) + offset], axis=1),
130 | )
131 | )
132 | offset += len(energies)
133 |
134 | fig.update_layout(
135 | title="Energy vs. Steps",
136 | xaxis_title="Step",
137 | yaxis_title="Energy",
138 | plot_bgcolor="rgba(0, 0, 0, 0)",
139 | paper_bgcolor="rgba(0, 0, 0, 0)",
140 | )
141 | fig.update_xaxes(
142 | showgrid=True,
143 | gridwidth=1,
144 | gridcolor="rgba(120, 120, 120, 0.3)",
145 | zeroline=False,
146 | )
147 | fig.update_yaxes(
148 | showgrid=True,
149 | gridwidth=1,
150 | gridcolor="rgba(120, 120, 120, 0.3)",
151 | zeroline=False,
152 | )
153 |
154 | # now adjusted
155 |
156 | offset = 0
157 | fig_adjusted = go.Figure()
158 | for idx, node in enumerate(nodes):
159 | energies = np.array([atoms.get_potential_energy() for atoms in node.frames])
160 | energies -= energies[0]
161 | fig_adjusted.add_trace(
162 | go.Scatter(
163 | x=list(range(len(energies))),
164 | y=energies,
165 | mode="lines+markers",
166 | name=node.name.replace(f"_{node.__class__.__name__}", ""),
167 | customdata=np.stack([np.arange(len(energies)) + offset], axis=1),
168 | )
169 | )
170 | offset += len(energies)
171 |
172 | fig_adjusted.update_layout(
173 | title="Adjusted energy vs. Steps",
174 | xaxis_title="Step",
175 | yaxis_title="Adjusted energy",
176 | plot_bgcolor="rgba(0, 0, 0, 0)",
177 | paper_bgcolor="rgba(0, 0, 0, 0)",
178 | )
179 | fig_adjusted.update_xaxes(
180 | showgrid=True,
181 | gridwidth=1,
182 | gridcolor="rgba(120, 120, 120, 0.3)",
183 | zeroline=False,
184 | )
185 | fig_adjusted.update_yaxes(
186 | showgrid=True,
187 | gridwidth=1,
188 | gridcolor="rgba(120, 120, 120, 0.3)",
189 | zeroline=False,
190 | )
191 |
192 | return ComparisonResults(
193 | frames=frames,
194 | figures={"energy_vs_steps": fig, "adjusted_energy_vs_steps": fig_adjusted},
195 | )
196 |
--------------------------------------------------------------------------------
/mlipx/nodes/updated_frames.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 |
3 | import ase
4 | from ase.calculators.calculator import Calculator, all_changes
5 |
6 |
7 | class _UpdateFramesCalc(Calculator):
8 | implemented_properties = ["energy", "forces"]
9 |
10 | def __init__(
11 | self, results_mapping: dict, info_mapping: dict, arrays_mapping: dict, **kwargs
12 | ):
13 | Calculator.__init__(self, **kwargs)
14 | self.results_mapping = results_mapping
15 | self.info_mapping = info_mapping
16 | self.arrays_mapping = arrays_mapping
17 |
18 | def calculate(
19 | self,
20 |         atoms=None,
21 | properties=None,
22 | system_changes=all_changes,
23 | ):
24 | if properties is None:
25 | properties = self.implemented_properties
26 | Calculator.calculate(self, atoms, properties, system_changes)
27 | for target, key in self.results_mapping.items():
28 | if key is None:
29 | continue
30 | try:
31 | value = atoms.info[key]
32 | except KeyError:
33 | value = atoms.arrays[key]
34 | self.results[target] = value
35 |
36 | for target, key in self.info_mapping.items():
37 | # rename the key to target
38 | atoms.info[target] = atoms.info[key]
39 | del atoms.info[key]
40 |
41 | for target, key in self.arrays_mapping.items():
42 | # rename the key to target
43 | atoms.arrays[target] = atoms.arrays[key]
44 | del atoms.arrays[key]
45 |
46 |
47 | # TODO: what if the energy is in the single point calculator but the forces are not?
48 | @dataclasses.dataclass
49 | class UpdateFramesCalc:
50 | results_mapping: dict[str, str] = dataclasses.field(default_factory=dict)
51 | info_mapping: dict[str, str] = dataclasses.field(default_factory=dict)
52 | arrays_mapping: dict[str, str] = dataclasses.field(default_factory=dict)
53 |
54 | def get_calculator(self, **kwargs) -> _UpdateFramesCalc:
55 | return _UpdateFramesCalc(
56 | results_mapping=self.results_mapping,
57 | info_mapping=self.info_mapping,
58 | arrays_mapping=self.arrays_mapping,
59 | )
60 |
--------------------------------------------------------------------------------
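A short sketch of how UpdateFramesCalc might expose pre-computed reference labels stored under custom keys, following the pattern used in mlipx/recipes/metrics.py; the key names and file path are assumptions:

    import zntrack

    import mlipx

    # Map custom per-frame keys onto the results mlipx expects.
    reference = mlipx.UpdateFramesCalc(
        results_mapping={"energy": "DFT_ENERGY", "forces": "DFT_FORCES"},
    )

    project = zntrack.Project()

    with project.group("initialize"):
        data = mlipx.LoadDataFile(path="reference.xyz")  # placeholder path

    with project.group("reference"):
        # Re-emits the frames with the mapped values as regular calculator results.
        ref = mlipx.ApplyCalculator(data=data.frames, model=reference)

    project.build()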
/mlipx/project.py:
--------------------------------------------------------------------------------
1 | from zntrack import Project
2 |
3 | __all__ = ["Project"]
4 |
--------------------------------------------------------------------------------
/mlipx/recipes/README.md:
--------------------------------------------------------------------------------
1 | # Jinja2 Templating
2 |
3 | We use `recipe.py.jinja2` templates for generating the `main.py` and `models.py` files from the CLI.
4 | For new Nodes, once you have added them to `mlipx/nodes/.py` and updated `mlipx/__init__.pyi`, you may want to create a new template and update the CLI in `main.py`.
5 |
6 | If you want to introduce a new model, you might want to adapt `models.py.jinja2` and the `main.py` as well.
7 |
--------------------------------------------------------------------------------
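As a sketch, a new template could mirror the structure of the existing recipes; `MyNewNode` and the file name are placeholders for whatever Node the template should wire up:

    # my_recipe.py.jinja2 (hypothetical)
    import mlipx
    import zntrack

    from models import MODELS

    project = zntrack.Project()

    frames = []
    {% if datapath %}
    with project.group("initialize"):
        for path in {{ datapath }}:
            frames.append(mlipx.LoadDataFile(path=path))
    {% endif %}

    for model_name, model in MODELS.items():
        for idx, data in enumerate(frames):
            with project.group(model_name, str(idx)):
                node = mlipx.MyNewNode(data=data.frames, model=model)

    project.build()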
/mlipx/recipes/__init__.py:
--------------------------------------------------------------------------------
1 | from mlipx.recipes.main import app
2 |
3 | __all__ = ["app"]
4 |
--------------------------------------------------------------------------------
/mlipx/recipes/adsorption.py.jinja2:
--------------------------------------------------------------------------------
1 | import mlipx
2 | import zntrack
3 |
4 | from models import MODELS
5 |
6 | project = zntrack.Project()
7 |
8 | slabs = []
9 | {% if slab_config %}
10 | with project.group("initialize"):
11 | slabs.append(mlipx.BuildASEslab(**{{ slab_config }}).frames)
12 | {% endif %}
13 |
14 | adsorbates = []
15 | {% if smiles %}
16 | with project.group("initialize"):
17 | for smiles in {{ smiles }}:
18 | adsorbates.append(mlipx.Smiles2Conformers(smiles=smiles, num_confs=1).frames)
19 | {% endif %}
20 |
21 | for model_name, model in MODELS.items():
22 | for idx, slab in enumerate(slabs):
23 | for jdx, adsorbate in enumerate(adsorbates):
24 | with project.group(model_name, str(idx)):
25 | _ = mlipx.RelaxAdsorptionConfigs(
26 | slabs=slab,
27 | adsorbates=adsorbate,
28 | model=model,
29 | )
30 |
31 | project.build()
32 |
--------------------------------------------------------------------------------
/mlipx/recipes/energy_volume.py.jinja2:
--------------------------------------------------------------------------------
1 | import mlipx
2 | import zntrack
3 |
4 | from models import MODELS
5 |
6 | project = zntrack.Project()
7 |
8 | frames = []
9 | {% if datapath %}
10 | with project.group("initialize"):
11 | for path in {{ datapath }}:
12 | frames.append(mlipx.LoadDataFile(path=path))
13 | {% endif %}{% if material_ids %}
14 | with project.group("initialize"):
15 | for material_id in {{ material_ids }}:
16 | frames.append(mlipx.MPRester(search_kwargs={"material_ids": [material_id]}))
17 | {% endif %}{% if smiles %}
18 | with project.group("initialize"):
19 | for smiles in {{ smiles }}:
20 | frames.append(mlipx.Smiles2Conformers(smiles=smiles, num_confs=1))
21 | {% endif %}
22 |
23 |
24 | for model_name, model in MODELS.items():
25 | for idx, data in enumerate(frames):
26 | with project.group(model_name, str(idx)):
27 |             ev = mlipx.EnergyVolumeCurve(
28 | model=model,
29 | data=data.frames,
30 | n_points=50,
31 | start=0.8,
32 | stop=2.0,
33 | )
34 |
35 | project.build()
36 |
--------------------------------------------------------------------------------
/mlipx/recipes/homonuclear_diatomics.py.jinja2:
--------------------------------------------------------------------------------
1 | import mlipx
2 | import zntrack
3 |
4 | from models import MODELS
5 |
6 | project = zntrack.Project()
7 |
8 | frames = []
9 | {% if datapath %}
10 | with project.group("initialize"):
11 | for path in {{ datapath }}:
12 | frames.append(mlipx.LoadDataFile(path=path).frames)
13 | {% endif %}{% if material_ids %}
14 | with project.group("initialize"):
15 | for material_id in {{ material_ids }}:
16 | frames.append(mlipx.MPRester(search_kwargs={"material_ids": [material_id]}).frames)
17 | {% endif %}{% if smiles %}
18 | with project.group("initialize"):
19 | for smiles in {{ smiles }}:
20 | frames.append(mlipx.Smiles2Conformers(smiles=smiles, num_confs=1).frames)
21 | {% endif %}
22 |
23 | for model_name, model in MODELS.items():
24 | with project.group(model_name):
25 |         diatomics = mlipx.HomonuclearDiatomics(
26 | elements=[],
27 | data=sum(frames, []), # Use all elements from all frames
28 | model=model,
29 | n_points=100,
30 | min_distance=0.5,
31 | max_distance=2.0,
32 | )
33 |
34 | project.build()
35 |
--------------------------------------------------------------------------------
/mlipx/recipes/invariances.py.jinja2:
--------------------------------------------------------------------------------
1 | import mlipx
2 | import zntrack
3 |
4 | from models import MODELS
5 |
6 | project = zntrack.Project()
7 |
8 | frames = []
9 | {% if datapath %}
10 | with project.group("initialize"):
11 | for path in {{ datapath }}:
12 | frames.append(mlipx.LoadDataFile(path=path))
13 | {% endif %}{% if material_ids %}
14 | with project.group("initialize"):
15 | for material_id in {{ material_ids }}:
16 | frames.append(mlipx.MPRester(search_kwargs={"material_ids": [material_id]}))
17 | {% endif %}{% if smiles %}
18 | with project.group("initialize"):
19 | for smiles in {{ smiles }}:
20 | frames.append(mlipx.Smiles2Conformers(smiles=smiles, num_confs=1))
21 | {% endif %}
22 |
23 | for model_name, model in MODELS.items():
24 | for idx, data in enumerate(frames):
25 | with project.group(model_name, str(idx)):
26 | rot = mlipx.RotationalInvariance(
27 | model=model,
28 | n_points=100,
29 | data=data.frames,
30 | )
31 | trans = mlipx.TranslationalInvariance(
32 | model=model,
33 | n_points=100,
34 | data=data.frames,
35 | )
36 | perm = mlipx.PermutationInvariance(
37 | model=model,
38 | n_points=100,
39 | data=data.frames,
40 | )
41 |
42 | project.build()
43 |
--------------------------------------------------------------------------------
/mlipx/recipes/md.py.jinja2:
--------------------------------------------------------------------------------
1 | import mlipx
2 | import zntrack
3 |
4 | from models import MODELS
5 |
6 | project = zntrack.Project()
7 |
8 | frames = []
9 | {% if datapath %}
10 | with project.group("initialize"):
11 | for path in {{ datapath }}:
12 | frames.append(mlipx.LoadDataFile(path=path))
13 | {% endif %}{% if material_ids %}
14 | with project.group("initialize"):
15 | for material_id in {{ material_ids }}:
16 | frames.append(mlipx.MPRester(search_kwargs={"material_ids": [material_id]}))
17 | {% endif %}{% if smiles %}
18 | with project.group("initialize"):
19 | for smiles in {{ smiles }}:
20 | frames.append(mlipx.Smiles2Conformers(smiles=smiles, num_confs=1))
21 | {% endif %}
22 |
23 | thermostat = mlipx.LangevinConfig(timestep=0.5, temperature=300, friction=0.05)
24 | force_check = mlipx.MaximumForceObserver(f_max=100)
25 | t_ramp = mlipx.TemperatureRampModifier(end_temperature=400, total_steps=100)
26 |
27 |
28 | for model_name, model in MODELS.items():
29 | for idx, data in enumerate(frames):
30 | with project.group(model_name, str(idx)):
31 |             md = mlipx.MolecularDynamics(
32 | model=model,
33 | thermostat=thermostat,
34 | data=data.frames,
35 | observers=[force_check],
36 | modifiers=[t_ramp],
37 | steps=1000,
38 | )
39 |
40 | project.build()
41 |
--------------------------------------------------------------------------------
/mlipx/recipes/metrics.py:
--------------------------------------------------------------------------------
1 | import zntrack
2 | from models import MODELS
3 |
4 | try:
5 | from models import REFERENCE
6 | except ImportError:
7 | REFERENCE = None
8 |
9 | import mlipx
10 |
11 | DATAPATH = "{{ datapath }}"
12 | ISOLATED_ATOM_ENERGIES = {{isolated_atom_energies}} # noqa F821
13 |
14 |
15 | project = zntrack.Project()
16 |
17 | with project.group("initialize"):
18 | data = mlipx.LoadDataFile(path=DATAPATH)
19 |
20 |
21 | with project.group("reference"):
22 | if REFERENCE is not None:
23 | data = mlipx.ApplyCalculator(data=data.frames, model=REFERENCE)
24 | ref_evaluation = mlipx.EvaluateCalculatorResults(data=data.frames)
25 | if ISOLATED_ATOM_ENERGIES:
26 | ref_isolated = mlipx.CalculateFormationEnergy(data=data.frames)
27 |
28 | for model_name, model in MODELS.items():
29 | with project.group(model_name):
30 | updated_data = mlipx.ApplyCalculator(data=data.frames, model=model)
31 | evaluation = mlipx.EvaluateCalculatorResults(data=updated_data.frames)
32 | mlipx.CompareCalculatorResults(data=evaluation, reference=ref_evaluation)
33 |
34 | if ISOLATED_ATOM_ENERGIES:
35 | isolated = mlipx.CalculateFormationEnergy(
36 | data=updated_data.frames, model=model
37 | )
38 | mlipx.CompareFormationEnergy(data=isolated, reference=ref_isolated)
39 |
40 | project.build()
41 |
--------------------------------------------------------------------------------
/mlipx/recipes/models.py.jinja2:
--------------------------------------------------------------------------------
1 | import dataclasses
2 |
3 | import mlipx
4 | from mlipx.nodes.generic_ase import Device
5 |
6 | ALL_MODELS = {}
7 |
8 | # https://github.com/ACEsuit/mace
9 | ALL_MODELS["mace-mpa-0"] = mlipx.GenericASECalculator(
10 | module="mace.calculators",
11 | class_name="mace_mp",
12 | device="auto",
13 | kwargs={"model": "../../models/mace-mpa-0-medium.model"}
14 | # MLIPX-hub model path, adjust as needed
15 | )
16 | # https://github.com/MDIL-SNU/SevenNet
17 | ALL_MODELS["7net-0"] = mlipx.GenericASECalculator(
18 | module="sevenn.sevennet_calculator",
19 | class_name="SevenNetCalculator",
20 | device="auto",
21 | kwargs={"model": "7net-0"}
22 | )
23 | ALL_MODELS["7net-mf-ompa-mpa"] = mlipx.GenericASECalculator(
24 | module="sevenn.sevennet_calculator",
25 | class_name="SevenNetCalculator",
26 | device="auto",
27 | kwargs={"model": "7net-mf-ompa", "modal": "mpa"}
28 | )
29 |
30 | # https://github.com/orbital-materials/orb-models
31 | @dataclasses.dataclass
32 | class OrbCalc:
33 | name: str
34 | device: Device | None = None
35 | kwargs: dict = dataclasses.field(default_factory=dict)
36 |
37 | def get_calculator(self, **kwargs):
38 | from orb_models.forcefield import pretrained
39 | from orb_models.forcefield.calculator import ORBCalculator
40 |
41 | method = getattr(pretrained, self.name)
42 | if self.device is None:
43 | orbff = method(**self.kwargs)
44 | calc = ORBCalculator(orbff, **self.kwargs)
45 | elif self.device == Device.AUTO:
46 | orbff = method(device=Device.resolve_auto(), **self.kwargs)
47 | calc = ORBCalculator(orbff, device=Device.resolve_auto(), **self.kwargs)
48 | else:
49 | orbff = method(device=self.device, **self.kwargs)
50 | calc = ORBCalculator(orbff, device=self.device, **self.kwargs)
51 | return calc
52 |
53 | @property
54 | def available(self) -> bool:
55 | try:
56 | from orb_models.forcefield import pretrained
57 | from orb_models.forcefield.calculator import ORBCalculator
58 | return True
59 | except ImportError:
60 | return False
61 |
62 | ALL_MODELS["orb-v2"] = OrbCalc(
63 | name="orb_v2",
64 | device="auto"
65 | )
66 | ALL_MODELS["orb-v3"] = OrbCalc(
67 | name="orb_v3_conservative_inf_omat",
68 | device="auto"
69 | )
70 |
71 | # https://github.com/CederGroupHub/chgnet
72 | ALL_MODELS["chgnet"] = mlipx.GenericASECalculator(
73 | module="chgnet.model",
74 | class_name="CHGNetCalculator",
75 | )
76 | # https://github.com/microsoft/mattersim
77 | ALL_MODELS["mattersim"] = mlipx.GenericASECalculator(
78 | module="mattersim.forcefield",
79 | class_name="MatterSimCalculator",
80 | device="auto",
81 | )
82 | # https://www.faccts.de/orca/
83 | ALL_MODELS["orca"] = mlipx.OrcaSinglePoint(
84 |     orcasimpleinput="PBE def2-TZVP TightSCF EnGrad",
85 |     orcablocks="%pal nprocs 8 end",
86 | orca_shell="{{ orcashell }}",
87 | )
88 |
89 | # https://gracemaker.readthedocs.io/en/latest/gracemaker/foundation/
90 | ALL_MODELS["grace-2l-omat"] = mlipx.GenericASECalculator(
91 | module="tensorpotential.calculator",
92 | class_name="TPCalculator",
93 | device=None,
94 | kwargs={
95 | "model": "../../models/GRACE-2L-OMAT",
96 | },
97 | # MLIPX-hub model path, adjust as needed
98 | )
99 |
100 | # OPTIONAL
101 | # ========
102 | # If you have custom property names you can use the UpdatedFramesCalc
103 | # to set the energy, force and isolated_energies keys mlipx expects.
104 |
105 | # REFERENCE = mlipx.UpdateFramesCalc(
106 | # results_mapping={"energy": "DFT_ENERGY", "forces": "DFT_FORCES"},
107 | # info_mapping={mlipx.abc.ASEKeys.isolated_energies.value: "isol_ene"},
108 | # )
109 |
110 | # ============================================================
111 | # THE SELECTED MODELS!
112 | # ONLY THESE MODELS WILL BE USED IN THE RECIPE
113 | # ============================================================
114 | MODELS = {
115 | {%- for model in models %}
116 | "{{ model }}": ALL_MODELS["{{ model }}"],
117 | {%- endfor %}
118 | }
119 |
--------------------------------------------------------------------------------
/mlipx/recipes/neb.py:
--------------------------------------------------------------------------------
1 | import zntrack
2 | from models import MODELS
3 |
4 | import mlipx
5 |
6 | DATAPATH = "{{ datapath }}"
7 |
8 | project = zntrack.Project()
9 |
10 | with project.group("initialize"):
11 | data = mlipx.LoadDataFile(path=DATAPATH)
12 | trajectory = mlipx.NEBinterpolate(data=data.frames, n_images=5, mic=True)
13 |
14 | for model_name, model in MODELS.items():
15 | with project.group(model_name):
16 | neb = mlipx.NEBs(
17 | data=trajectory.frames,
18 | model=model,
19 | relax=True,
20 | fmax=0.05,
21 | )
22 |
23 | project.build()
24 |
--------------------------------------------------------------------------------
/mlipx/recipes/phase_diagram.py.jinja2:
--------------------------------------------------------------------------------
1 | import mlipx
2 | import zntrack
3 |
4 | from models import MODELS
5 |
6 | project = zntrack.Project()
7 |
8 | frames = []
9 | {% if datapath %}
10 | with project.group("initialize"):
11 | for path in {{ datapath }}:
12 | frames.append(mlipx.LoadDataFile(path=path))
13 | {% endif %}{% if material_ids %}
14 | with project.group("initialize"):
15 | for material_id in {{ material_ids }}:
16 | frames.append(mlipx.MPRester(search_kwargs={"material_ids": [material_id]}))
17 | {% endif %}{% if smiles %}
18 | with project.group("initialize"):
19 | for smiles in {{ smiles }}:
20 | frames.append(mlipx.Smiles2Conformers(smiles=smiles, num_confs=1))
21 | {% endif %}
22 |
23 | for model_name, model in MODELS.items():
24 | for idx, data in enumerate(frames):
25 | with project.group(model_name, str(idx)):
26 | pd = mlipx.PhaseDiagram(data=data.frames, model=model)
27 |
28 |
29 | project.build()
30 |
--------------------------------------------------------------------------------
/mlipx/recipes/pourbaix_diagram.py.jinja2:
--------------------------------------------------------------------------------
1 | import mlipx
2 | import zntrack
3 |
4 | from models import MODELS
5 |
6 | project = zntrack.Project()
7 |
8 | frames = []
9 | {% if datapath %}
10 | with project.group("initialize"):
11 | for path in {{ datapath }}:
12 | frames.append(mlipx.LoadDataFile(path=path))
13 | {% endif %}{% if material_ids %}
14 | with project.group("initialize"):
15 | for material_id in {{ material_ids }}:
16 | frames.append(mlipx.MPRester(search_kwargs={"material_ids": [material_id]}))
17 | {% endif %}{% if smiles %}
18 | with project.group("initialize"):
19 | for smiles in {{ smiles }}:
20 | frames.append(mlipx.Smiles2Conformers(smiles=smiles, num_confs=1))
21 | {% endif %}
22 |
23 | for model_name, model in MODELS.items():
24 | for idx, data in enumerate(frames):
25 | with project.group(model_name, str(idx)):
26 | pd = mlipx.PourbaixDiagram(data=data.frames, model=model, pH=1.0, V=1.8)
27 |
28 | project.build()
29 |
--------------------------------------------------------------------------------
/mlipx/recipes/relax.py.jinja2:
--------------------------------------------------------------------------------
1 | import mlipx
2 | import zntrack
3 |
4 | from models import MODELS
5 |
6 | project = zntrack.Project()
7 |
8 | frames = []
9 | {% if datapath %}
10 | with project.group("initialize"):
11 | for path in {{ datapath }}:
12 | frames.append(mlipx.LoadDataFile(path=path))
13 | {% endif %}{% if material_ids %}
14 | with project.group("initialize"):
15 | for material_id in {{ material_ids }}:
16 | data = mlipx.MPRester(search_kwargs={"material_ids": [material_id]})
17 | frames.append(mlipx.Rattle(data=data.frames, stdev=0.1))
18 | {% endif %}{% if smiles %}
19 | with project.group("initialize"):
20 | for smiles in {{ smiles }}:
21 | frames.append(mlipx.Smiles2Conformers(smiles=smiles, num_confs=1))
22 | {% endif %}
23 |
24 |
25 | for model_name, model in MODELS.items():
26 | for idx, data in enumerate(frames):
27 | with project.group(model_name, str(idx)):
28 | geom_opt = mlipx.StructureOptimization(data=data.frames, model=model, fmax=0.1)
29 |
30 | project.build()
31 |
--------------------------------------------------------------------------------
/mlipx/recipes/vibrational_analysis.py.jinja2:
--------------------------------------------------------------------------------
1 | import mlipx
2 | import zntrack
3 |
4 | from models import MODELS
5 |
6 | project = zntrack.Project()
7 |
8 | frames = []
9 | {% if datapath %}
10 | with project.group("initialize"):
11 | for path in {{ datapath }}:
12 | frames.append(mlipx.LoadDataFile(path=path))
13 | {% endif %}{% if material_ids %}
14 | with project.group("initialize"):
15 | for material_id in {{ material_ids }}:
16 | frames.append(mlipx.MPRester(search_kwargs={"material_ids": [material_id]}))
17 | {% endif %}{% if smiles %}
18 | with project.group("initialize"):
19 | for smiles in {{ smiles }}:
20 | frames.append(mlipx.Smiles2Conformers(smiles=smiles, num_confs=1))
21 | {% endif %}
22 |
23 | for model_name, model in MODELS.items():
24 | with project.group(model_name):
25 | phon = mlipx.VibrationalAnalysis(
26 | data=sum([x.frames for x in frames], []),
27 | model=model,
28 | temperature=298.15,
29 | displacement=0.015,
30 | nfree=4,
31 | lower_freq_threshold=12,
32 | )
33 |
34 |
35 | project.build()
36 |
--------------------------------------------------------------------------------
/mlipx/utils.py:
--------------------------------------------------------------------------------
1 | import ase
2 | import numpy as np
3 | from ase.calculators.singlepoint import SinglePointCalculator
4 |
5 |
6 | def freeze_copy_atoms(atoms: ase.Atoms) -> ase.Atoms:
7 | atoms_copy = atoms.copy()
8 | if atoms.calc is not None:
9 | atoms_copy.calc = SinglePointCalculator(atoms_copy)
10 | atoms_copy.calc.results = atoms.calc.results
11 | return atoms_copy
12 |
13 |
14 | def shallow_copy_atoms(atoms: ase.Atoms) -> ase.Atoms:
15 | """Create a shallow copy of an ASE atoms object."""
16 | atoms_copy = ase.Atoms(
17 | positions=atoms.positions,
18 | numbers=atoms.numbers,
19 | cell=atoms.cell,
20 | pbc=atoms.pbc,
21 | )
22 | return atoms_copy
23 |
24 |
25 | def rmse(y_true, y_pred):
26 | return np.sqrt(np.mean((y_true - y_pred) ** 2))
27 |
--------------------------------------------------------------------------------
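A small sketch contrasting the two copy helpers above; EMT is only used to populate some calculator results:

    import ase.build
    from ase.calculators.emt import EMT

    from mlipx.utils import freeze_copy_atoms, shallow_copy_atoms

    atoms = ase.build.bulk("Cu")
    atoms.calc = EMT()
    atoms.get_potential_energy()  # populate calc.results

    frozen = freeze_copy_atoms(atoms)
    print(frozen.get_potential_energy())  # served from the stored results

    bare = shallow_copy_atoms(atoms)
    print(bare.calc)  # None: positions, numbers, cell and pbc only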
/mlipx/version.py:
--------------------------------------------------------------------------------
1 | import importlib.metadata
2 |
3 | __version__ = importlib.metadata.version("mlipx")
4 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "mlipx"
3 | version = "0.1.4"
4 | description = "Machine-Learned Interatomic Potential eXploration"
5 | authors = [
6 | { name = "Sandip De", email = "sandip.de@basf.com" },
7 | { name = "Fabian Zills", email = "fzills@icp.uni-stuttgart.de" },
8 | { name = "Sheena Agarwal", email = "sheena.a.agarwal@basf.com" }
9 | ]
10 |
11 | readme = "README.md"
12 | license = "MIT"
13 | requires-python = ">=3.10"
14 | keywords=["data-version-control", "machine-learning", "reproducibility", "collaboration", "machine-learned interatomic potential", "mlip", "mlff"]
15 |
16 | dependencies = [
17 | "ase>=3.24.0",
18 | "lazy-loader>=0.4",
19 | "mp-api>=0.45.3",
20 | "plotly>=6.0.0",
21 | "rdkit2ase>=0.1.4",
22 | "typer>=0.15.1",
23 | "zndraw>=0.5.10",
24 | "znh5md>=0.4.4",
25 | "zntrack>=0.8.5",
26 | "dvc-s3>=3.2.0",
27 | "mpcontribs-client>=5.10.2",
28 | ]
29 |
30 | [dependency-groups]
31 | docs = [
32 | "furo>=2024.8.6",
33 | "jupyter-sphinx>=0.5.3",
34 | "nbsphinx>=0.9.6", # https://github.com/sphinx-doc/sphinx/issues/13352
35 | "sphinx>=8.1.3,!=8.2.0",
36 | "sphinx-copybutton>=0.5.2",
37 | "sphinx-design>=0.6.1",
38 | "sphinx-hoverxref>=1.4.2",
39 | "sphinx-mdinclude>=0.6.2",
40 | "sphinxcontrib-bibtex>=2.6.3",
41 | "sphinxcontrib-mermaid>=1.0.0",
42 | ]
43 | dev = [
44 | "ipykernel>=6.29.5",
45 | "pre-commit>=4.2.0",
46 | "pytest>=8.3.4",
47 | "pytest-cov>=6.0.0",
48 | "ruff>=0.9.6",
49 | ]
50 |
51 |
52 | [project.scripts]
53 | mlipx = "mlipx.cli.main:app"
54 |
55 |
56 | [tool.uv]
57 | conflicts = [
58 | [
59 | { extra = "mace" },
60 | { extra = "mattersim" },
61 | ],
62 | [
63 | { extra = "mace" },
64 | { extra = "sevenn" },
65 | ]
66 | ]
67 |
68 | [project.optional-dependencies]
69 | chgnet = [
70 | "chgnet>=0.4.0",
71 | ]
72 | mace = [
73 | "mace-torch>=0.3.12",
74 | ]
75 | sevenn = [
76 | "sevenn>=0.11.0",
77 | ]
78 | orb = [
79 | "orb-models>=0.5.0"
80 | ]
81 | mattersim = [
82 | "mattersim>=1.1.2",
83 | ]
84 | grace = [
85 | "tensorpotential>=0.5.1",
86 | ]
87 |
88 |
89 | [build-system]
90 | requires = ["hatchling"]
91 | build-backend = "hatchling.build"
92 |
93 | [tool.codespell]
94 | ignore-words-list = "basf"
95 | skip = "*.svg,*.lock,*.json"
96 |
97 | [tool.ruff.lint]
98 | select = ["I", "F", "E", "C", "W"]
99 |
--------------------------------------------------------------------------------