├── .github
├── env-dev.yml
├── env-docs.yml
└── workflows
│ ├── CI.yaml
│ ├── linting.yaml
│ └── rtd.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── .readthedocs.yaml
├── LICENSE.txt
├── README.md
├── docs
├── Makefile
├── _static
│ ├── TM23Cu_sample_complexity.png
│ ├── css
│ │ └── mystyle.css
│ └── diagram_part1.png
├── _templates
│ ├── class.rst
│ ├── func.rst
│ └── layout.html
├── conf.py
├── index.rst
├── make.bat
├── notebooks
│ ├── autotune.ipynb
│ ├── getting_started.ipynb
│ └── molecular_dynamics.ipynb
├── reference
│ ├── cli.rst
│ ├── franken-api
│ │ ├── franken.calculators.rst
│ │ ├── franken.config.rst
│ │ ├── franken.rf.heads.rst
│ │ ├── franken.rf.model.rst
│ │ ├── franken.rf.scaler.rst
│ │ └── franken.trainers.rst
│ ├── franken-cli
│ │ ├── franken.autotune.rst
│ │ ├── franken.backbones.rst
│ │ └── franken.create_lammps_model.rst
│ └── index.rst
├── requirements.txt
└── topics
│ ├── installation.md
│ ├── lammps.md
│ └── model_registry.md
├── franken
├── __init__.py
├── autotune
│ ├── __init__.py
│ ├── __main__.py
│ ├── cli.py
│ └── script.py
├── backbones
│ ├── __init__.py
│ ├── cli.py
│ ├── registry.json
│ ├── utils.py
│ └── wrappers
│ │ ├── __init__.py
│ │ ├── common_patches.py
│ │ ├── fairchem_schnet.py
│ │ ├── mace_wrap.py
│ │ └── sevenn.py
├── calculators
│ ├── __init__.py
│ ├── ase_calc.py
│ └── lammps_calc.py
├── config.py
├── data
│ ├── __init__.py
│ ├── base.py
│ ├── distributed_sampler.py
│ ├── fairchem.py
│ ├── mace.py
│ └── sevenn.py
├── datasets
│ ├── PtH2O
│ │ └── pth2o_dataset.py
│ ├── TM23
│ │ └── tm23_dataset.py
│ ├── __init__.py
│ ├── registry.py
│ ├── split_data.py
│ ├── test
│ │ ├── long.xyz
│ │ ├── md.xyz
│ │ ├── test.xyz
│ │ ├── test_dataset.py
│ │ ├── train.xyz
│ │ └── validation.xyz
│ └── water
│ │ ├── HH_digitizer.csv
│ │ ├── OH_digitizer.csv
│ │ ├── OO_digitizer.csv
│ │ ├── exp_rdf.csv
│ │ └── water_dataset.py
├── metrics
│ ├── __init__.py
│ ├── base.py
│ ├── functions.py
│ └── registry.py
├── rf
│ ├── __init__.py
│ ├── atomic_energies.py
│ ├── heads.py
│ ├── model.py
│ └── scaler.py
├── trainers
│ ├── __init__.py
│ ├── base.py
│ ├── log_utils.py
│ └── rf_cuda_lowmem.py
└── utils
│ ├── __init__.py
│ ├── distributed.py
│ ├── file_utils.py
│ ├── hostlist.py
│ ├── jac.py
│ ├── linalg
│ ├── __init__.py
│ ├── cov.py
│ ├── psdsolve.py
│ └── tri.py
│ └── misc.py
├── notebooks
├── autotune.ipynb
├── colab.ipynb
├── getting_started.ipynb
└── molecular_dynamics.ipynb
├── pyproject.toml
└── tests
├── __init__.py
├── conftest.py
├── test_FrankenPotential.py
├── test_backbones.py
├── test_backbones_utils.py
├── test_data.py
├── test_lammps.py
├── test_linalg.py
├── test_metrics.py
├── test_rf_heads.py
├── test_trainer.py
├── test_trainers_log_utils.py
└── utils.py
/.github/env-dev.yml:
--------------------------------------------------------------------------------
1 | channels:
2 | - pytorch
3 | - nvidia
4 | - conda-forge
5 | dependencies:
6 | - pytorch>=2.4
7 | - ase
8 | - numpy
9 | - omegaconf
10 | - cupy
11 | - e3nn
12 | - pip
13 | - requests
14 | - tqdm
15 | - pytest
16 | - pre-commit
17 | - black
18 | - ruff
19 | - psutil
20 | - docstring_parser
21 | - packaging
22 | name: franken
--------------------------------------------------------------------------------
/.github/env-docs.yml:
--------------------------------------------------------------------------------
1 | channels:
2 | - pytorch
3 | - nvidia
4 | - conda-forge
5 | dependencies:
6 | - pytorch>=2.4
7 | - ase
8 | - numpy
9 | - omegaconf
10 | - pip
11 | - e3nn
12 | - requests
13 | - tqdm
14 | - pytest
15 | - psutil
16 | - docstring_parser
17 | - ipython
18 | - sphinx
19 | - sphinxawesome-theme
20 | - sphinxcontrib-applehelp
21 | - sphinxcontrib-devhelp
22 | - sphinxcontrib-htmlhelp
23 | - sphinxcontrib-jsmath
24 | - sphinxcontrib-qthelp
25 | - sphinxcontrib-serializinghtml
26 | - sphinx-argparse
27 | - myst-parser
28 | - nbsphinx
29 | - packaging
30 | name: franken
--------------------------------------------------------------------------------
/.github/workflows/CI.yaml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | on:
4 | push:
5 | branches:
6 | - "main"
7 | pull_request:
8 | branches:
9 | - "main"
10 | types: [opened, reopened, synchronize]
11 | schedule:
12 | # Weekly tests run on main by default:
13 | # Scheduled workflows run on the latest commit on the default or base branch.
14 | # (from https://help.github.com/en/actions/reference/events-that-trigger-workflows#scheduled-events-schedule)
15 | - cron: "0 2 * * 1"
16 | workflow_dispatch:
17 |
18 | jobs:
19 | test:
20 | name: Test on ${{ matrix.os }}, Python ${{ matrix.python-version }}
21 | runs-on: ${{ matrix.os }}
22 | strategy:
23 | matrix:
24 | os: [ubuntu-latest]
25 | python-version: ["3.10", "3.11", "3.12"]
26 | pytorch-version: ["2.5", "2.6"]
27 |
28 | steps:
29 | - uses: actions/checkout@v4
30 |
31 | - name: Additional info about the build
32 | shell: bash
33 | run: |
34 | uname -a
35 | df -h
36 | ulimit -a
37 |
38 | # More info on options: https://github.com/marketplace/actions/setup-micromamba
39 | - name: Create and setup mamba
40 | uses: mamba-org/setup-micromamba@v2
41 | with:
42 | # here we specify the environment like this instead of just installing with pip to make caching easier
43 | environment-file: .github/env-dev.yml
44 | environment-name: test
45 | cache-environment: true
46 | cache-environment-key: environment-${{ matrix.python-version }}-${{ matrix.pytorch-version }}
47 | condarc: |
48 | channels:
49 | - conda-forge
50 | create-args: >-
51 | python=${{ matrix.python-version }}
52 | pytorch=${{ matrix.pytorch-version }}
53 |
54 | - name: Install GNN backbones packages
55 | # conda setup requires this special shell
56 | shell: bash -l {0}
57 | env:
58 | TORCH_VERSION: ${{ matrix.pytorch-version }}
59 | run: |
60 | python -m pip install torch_geometric
61 | # Running with the -f argument gives us prebuilt wheels which speeds things up.
62 | # On the other hand it depends on them publishing the wheels
63 | python -m pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-${TORCH_VERSION}.0+cpu.html
64 | python -m pip install --no-deps fairchem-core==1.10 # fairchem dependencies are a nightmare, better to ignore them
65 | python -m pip install mace-torch
66 |
67 | - name: Install package
68 | # conda setup requires this special shell
69 | shell: bash -l {0}
70 | run: |
71 | python -m pip install . --no-deps
72 | micromamba list
73 |
74 | - name: Run tests
75 | # conda setup requires this special shell
76 | shell: bash -l {0}
77 | run: |
78 | pytest -v --color=yes tests/
79 |
80 | # - name: CodeCov
81 | # if: contains( matrix.os, 'ubuntu' )
82 | # uses: codecov/codecov-action@v3
83 | # with:
84 | # token: ${{ secrets.CODECOV_TOKEN }}
85 | # file: ./coverage.xml
86 | # flags: codecov
87 | # name: codecov-${{ matrix.os }}-py${{ matrix.python-version }}
88 |
89 | build:
90 | name: Bump version and build package with hatch
91 | needs: test # This ensures 'publish' only runs if 'test' passes
92 | runs-on: ubuntu-latest
93 | if: |
94 | github.ref == 'refs/heads/main' &&
95 | github.event_name == 'push' &&
96 | contains(github.event.head_commit.message, '[release]')
97 | steps:
98 | - uses: actions/checkout@v4
99 | - name: setup python
100 | uses: actions/setup-python@v5
101 | with:
102 | python-version: '3.11'
103 | - name: install hatch
104 | run: pip install hatch
105 |
106 | - name: Determine bump type
107 | id: bump
108 | run: |
109 | COMMIT_MSG=`git log -1 --pretty=%B | head -n 1`
110 | if [[ "$COMMIT_MSG" == *"[Major]"* ]]; then
111 | echo "bump=major" >> $GITHUB_OUTPUT
112 | elif [[ "$COMMIT_MSG" == *"[Minor]"* ]]; then
113 | echo "bump=minor" >> $GITHUB_OUTPUT
114 | else
115 | echo "bump=patch" >> $GITHUB_OUTPUT
116 | fi
117 | - name: bump version and tag repo
118 | run: |
119 | git config --global user.name 'autobump'
120 | git config --global user.email 'autobump@github.com'
121 | OLD_VERSION=`hatch version`
122 | hatch version ${{ steps.bump.outputs.bump }}
123 | NEW_VERSION=`hatch version`
124 | git add franken/__init__.py
125 | git commit -m "Updated version: ${OLD_VERSION} → ${NEW_VERSION} [skip ci]"
126 | git tag $NEW_VERSION
127 | git push
128 | git push --tags
129 | - name: build franken package
130 | run: hatch build
131 | - name: Upload build artifacts
132 | uses: actions/upload-artifact@v4
133 | with:
134 | name: dist-files
135 | path: dist/*
136 |
137 | publish:
138 | name: Publish to PyPi
139 | needs: build
140 | runs-on: ubuntu-latest
141 | if: |
142 | github.ref == 'refs/heads/main' &&
143 | github.event_name == 'push' &&
144 | needs.build.result == 'success'
145 | permissions:
146 | id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
147 | environment:
148 | name: pypi
149 | url: https://pypi.org/project/franken/
150 | steps:
151 | - name: Download build artifacts
152 | uses: actions/download-artifact@v4
153 | with:
154 | name: dist-files
155 | path: dist/
156 | - name: Publish package distributions to PyPI
157 | uses: pypa/gh-action-pypi-publish@release/v1
--------------------------------------------------------------------------------
/.github/workflows/linting.yaml:
--------------------------------------------------------------------------------
1 | name: Lint
2 |
3 | on: [push]
4 |
5 | jobs:
6 | lint:
7 | runs-on: ubuntu-latest
8 | steps:
9 | - uses: actions/checkout@v4
10 | - name: Set up Python
11 | uses: actions/setup-python@v5
12 | with:
13 | python-version: '3.11'
14 | - name: Install dependencies
15 | run: |
16 | python -m pip install --upgrade pip
17 | pip install ruff
18 | - name: Linting with black
19 | uses: psf/black@stable
20 | with:
21 | options: "--check --verbose"
22 | use_pyproject: true
23 | - name: Run Ruff
24 | run: ruff check --output-format=github .
25 |
--------------------------------------------------------------------------------
/.github/workflows/rtd.yaml:
--------------------------------------------------------------------------------
1 | name: Docs
2 |
3 | # Runs on pushes targeting the default branch
4 | on:
5 | push:
6 | branches:
7 | - main
8 | pull_request:
9 | branches:
10 | - main
11 |
12 | # Allows you to run this workflow manually from the Actions tab
13 | workflow_dispatch:
14 |
15 | # Cancel in-progress runs when pushing a new commit on the PR
16 | concurrency:
17 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
18 | cancel-in-progress: true
19 |
20 | jobs:
21 | docs:
22 | environment:
23 | name: ghpg
24 | url: ${{ steps.deployment.outputs.page_url }}
25 | runs-on: ubuntu-latest
26 | steps:
27 | - uses: actions/checkout@v4
28 | - name: Create and setup mamba
29 | uses: mamba-org/setup-micromamba@v2
30 | with:
31 | # here we specify the environment like this instead of just installing with pip to make caching easier
32 | environment-file: .github/env-docs.yml
33 | environment-name: test
34 | cache-environment: true
35 | cache-environment-key: environment-docs
36 | condarc: |
37 | channels:
38 | - conda-forge
39 | create-args: >-
40 | python=3.12
41 | pytorch=2.4
42 | pandoc=3.6.4
43 | - name: Install franken
44 | # conda setup requires this special shell
45 | shell: bash -l {0}
46 | # dependencies are handled in conda env
47 | run: |
48 | python -m pip install . --no-deps
49 | micromamba list
50 |
51 | - name: Sphinx build
52 | # conda setup requires this special shell
53 | shell: bash -l {0}
54 | run: |
55 | # Check import works. sphinx-build will try to import but not provide reliable error-traces.
56 | python -c "import franken; import franken.calculators;"
57 | sphinx-build docs _build
58 |
59 | # This step zips and pushes the built docs to the rtd branch
60 | - name: Push docs to rtd branch
61 | if: github.ref == 'refs/heads/main' # Only deploy when pushing to main
62 | run: |
63 | # Setup git identity
64 | git config --global user.name "GitHub Actions"
65 | git config --global user.email "actions@github.com"
66 |
67 | # Create docs.zip from the _build directory first
68 | cd _build
69 | zip -r ../docs.zip .
70 | cd ..
71 |
72 | # Save a copy of important files
73 | cp .readthedocs.yaml /tmp/readthedocs.yaml
74 | cp docs.zip /tmp/docs.zip
75 |
76 | # Create a fresh rtd branch
77 | git checkout --orphan rtd-temp
78 | git rm -rf .
79 |
80 | # Restore the saved files
81 | cp /tmp/readthedocs.yaml .readthedocs.yaml
82 | cp /tmp/docs.zip docs.zip
83 |
84 | # Add and commit both files
85 | git add docs.zip .readthedocs.yaml
86 | git commit -m "Update documentation build [skip ci]"
87 |
88 | # Force push to rtd branch
89 | git push origin rtd-temp:rtd -f
90 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Datasets
2 | datasets/TM23/test_lengths.py
3 | franken/datasets/TM23/*.xyz
4 | franken/datasets/TM23/*.zip
5 | franken/datasets/PtH2O/*.extxyz
6 | franken/datasets/PtH2O/*.traj
7 | franken/datasets/water/*.xyz
8 | franken/datasets/water/*.zip
9 |
10 | # Random
11 | rsync.sh
12 |
13 | # Docs
14 | docs/reference/franken-api/stubs
15 | docs/reference/stubs
16 |
17 | # Byte-compiled / optimized / DLL files
18 | __pycache__/
19 | *.py[cod]
20 | *$py.class
21 |
22 | # C extensions
23 | *.so
24 |
25 | # Distribution / packaging
26 | .Python
27 | build/
28 | develop-eggs/
29 | dist/
30 | downloads/
31 | eggs/
32 | .eggs/
33 | lib/
34 | lib64/
35 | parts/
36 | sdist/
37 | var/
38 | wheels/
39 | share/python-wheels/
40 | *.egg-info/
41 | .installed.cfg
42 | *.egg
43 | MANIFEST
44 |
45 | # PyInstaller
46 | # Usually these files are written by a python script from a template
47 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
48 | *.manifest
49 | *.spec
50 |
51 | # Installer logs
52 | pip-log.txt
53 | pip-delete-this-directory.txt
54 |
55 | # Unit test / coverage reports
56 | htmlcov/
57 | .tox/
58 | .nox/
59 | .coverage
60 | .coverage.*
61 | .cache
62 | nosetests.xml
63 | coverage.xml
64 | *.cover
65 | *.py,cover
66 | .hypothesis/
67 | .pytest_cache/
68 | cover/
69 |
70 | # Translations
71 | *.mo
72 | *.pot
73 |
74 | # Django stuff:
75 | *.log
76 | local_settings.py
77 | db.sqlite3
78 | db.sqlite3-journal
79 |
80 | # Flask stuff:
81 | instance/
82 | .webassets-cache
83 |
84 | # Scrapy stuff:
85 | .scrapy
86 |
87 | # Sphinx documentation
88 | docs/_build/
89 |
90 | # PyBuilder
91 | .pybuilder/
92 | target/
93 |
94 | # Jupyter Notebook
95 | .ipynb_checkpoints
96 |
97 | # IPython
98 | profile_default/
99 | ipython_config.py
100 |
101 | # pyenv
102 | # For a library or package, you might want to ignore these files since the code is
103 | # intended to run in multiple environments; otherwise, check them in:
104 | # .python-version
105 |
106 | # pipenv
107 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
108 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
109 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
110 | # install all needed dependencies.
111 | #Pipfile.lock
112 |
113 | # poetry
114 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
115 | # This is especially recommended for binary packages to ensure reproducibility, and is more
116 | # commonly ignored for libraries.
117 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
118 | #poetry.lock
119 |
120 | # pdm
121 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
122 | #pdm.lock
123 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
124 | # in version control.
125 | # https://pdm.fming.dev/#use-with-ide
126 | .pdm.toml
127 |
128 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
129 | __pypackages__/
130 |
131 | # Celery stuff
132 | celerybeat-schedule
133 | celerybeat.pid
134 |
135 | # SageMath parsed files
136 | *.sage.py
137 |
138 | # Environments
139 | .env
140 | .venv
141 | env/
142 | venv/
143 | ENV/
144 | env.bak/
145 | venv.bak/
146 |
147 | # Spyder project settings
148 | .spyderproject
149 | .spyproject
150 |
151 | # Rope project settings
152 | .ropeproject
153 |
154 | # mkdocs documentation
155 | /site
156 |
157 | # mypy
158 | .mypy_cache/
159 | .dmypy.json
160 | dmypy.json
161 |
162 | # Pyre type checker
163 | .pyre/
164 |
165 | # pytype static type analyzer
166 | .pytype/
167 |
168 | # Cython debug symbols
169 | cython_debug/
170 |
171 | .vscode
172 |
173 | **/gnn_checkpoints/
174 | **/experiments/
175 | **/notebooks_legacy/
176 | **/baselines_legacy/
177 | wheel/
178 | **/wandb/
179 | **/precomputed/
180 | *.pt
181 | *.report
182 | *.sbatch
183 | slurm*
184 |
185 | *baseline_/
186 | *.out
187 | .history/
188 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | # Using this mirror lets us use mypyc-compiled black, which is about 2x faster
3 | - repo: https://github.com/psf/black-pre-commit-mirror
4 | rev: 24.8.0
5 | hooks:
6 | - id: black
7 | # It is recommended to specify the latest version of Python
8 | # supported by your project here
9 | language_version: python3.11
10 | - repo: https://github.com/astral-sh/ruff-pre-commit
11 | rev: v0.7.3
12 | hooks:
13 | - id: ruff
14 | pass_filenames: false
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # Read the Docs configuration file
2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
3 |
4 | # Required
5 | version: 2
6 |
7 | build:
8 | os: ubuntu-22.04
9 | tools:
10 | python: "3.10"
11 | jobs:
12 | build:
13 | html:
14 | - echo "Extracting pre-built docs from docs.zip"
15 | - mkdir -p $READTHEDOCS_OUTPUT/html/
16 | - unzip -o docs.zip -d $READTHEDOCS_OUTPUT/html/
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2025 Franken authors
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4 |
5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6 |
7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
8 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Franken
2 |
3 | [](https://github.com/CSML-IIT-UCL/franken/actions/workflows/CI.yaml)
4 | [](https://franken.readthedocs.io/)
5 |
6 |
7 | ## Introduction
8 |
9 | Franken is an open-source library that can be used to enhance the accuracy of atomistic foundation models. It can be used for molecular dynamics simulations, and has a focus on computational efficiency.
10 |
11 | `franken` features include:
12 | - Supports fine-tuning for a variety of foundation models ([MACE](https://github.com/ACEsuit/mace), [SevenNet](https://github.com/MDIL-SNU/SevenNet), [SchNet](https://github.com/facebookresearch/fairchem))
13 | - Automatic [hyperparameter tuning](https://franken.readthedocs.io/notebooks/autotune.html) simplifies the adaptation procedure, for an out-of-the-box user experience.
14 | - Several random-feature approximations to common kernels (e.g. Gaussian, polynomial) are available to flexibly fine-tune any foundation model.
15 | - Support for running within [LAMMPS](https://www.lammps.org/) molecular dynamics, as well as with [ASE](https://wiki.fysik.dtu.dk/ase/).
16 |
17 |
18 |
19 | For detailed information and benchmarks please check our paper [*Fast and Fourier Features for Transfer Learning of Interatomic Potentials*](https://arxiv.org/abs/2505.05652).
20 |
21 | ## Documentation
22 |
23 | Full documentation, including several examples, is available at [https://franken.readthedocs.io/index.html](https://franken.readthedocs.io/index.html). [The paper](https://arxiv.org/abs/2505.05652) also contains a comprehensive description of the methods behind franken.
24 |
25 | ## Install
26 |
27 | To install the latest release of `franken`, you can simply do:
28 |
29 | ```bash
30 | pip install franken
31 | ```
32 |
33 | Several optional dependencies can be specified, to install packages required for certain operations:
34 | - `cuda` includes packages which speed up training on GPUs (note that `franken` will work on GPUs even without these dependencies thanks to pytorch).
35 | - `fairchem`, `mace`, `sevenn` install the necessary dependencies to use a specific backbone.
36 | - `docs` and `develop` are only needed if you wish to build the documentation, or work on extending the library.
37 |
38 | They can be installed for example by running
39 |
40 | ```bash
41 | pip install franken[mace,cuda]
42 | ```
43 |
44 | For more details read the [relevant documentation page](https://franken.readthedocs.io/topics/installation.html)
45 |
46 | ## Quickstart
47 |
48 | You can directly run `franken.autotune` to get started with the `franken` library. A quick example is to fine-tune MACE-MP0 on a high-level-of-theory water dataset:
49 |
50 | ```bash
51 | franken.autotune \
52 | --dataset-name="water" --max-train-samples=8 \
53 | --l2-penalty="(-10, -5, 5, log)" \
54 | --force-weight="(0.01, 0.99, 5, linear)" \
55 | --seed=42 \
56 | --jac-chunk-size=64 \
57 | --run-dir="./results" \
58 | --backbone=mace --mace.path-or-id="MACE-L0" --mace.interaction-block=2 \
59 | --rf=gaussian --gaussian.num-rf=512 --gaussian.length-scale="[10.0, 15.0]"
60 | ```
61 |
62 | For more details you can check out the [autotune tutorial](https://franken.readthedocs.io/notebooks/autotune.html) or the [getting started notebook](https://franken.readthedocs.io/notebooks/getting_started.html).
63 |
64 |
65 | ## Citing
66 |
67 | If you find this library useful, please cite our work using the following bibtex entry:
68 | ```
69 | @misc{novelli25franken,
70 | title={Fast and Fourier Features for Transfer Learning of Interatomic Potentials},
71 | author={Pietro Novelli and Giacomo Meanti and Pedro J. Buigues and Lorenzo Rosasco and Michele Parrinello and Massimiliano Pontil and Luigi Bonati},
72 | year={2025},
73 | eprint={2505.05652},
74 | archivePrefix={arXiv},
75 | url={https://arxiv.org/abs/2505.05652},
76 | }
77 | ```
78 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
22 | clean:
23 | rm -rf $(BUILDDIR)/*
24 | rm -rf reference/franken-api/stubs
25 |
--------------------------------------------------------------------------------
/docs/_static/TM23Cu_sample_complexity.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSML-IIT-UCL/franken/497d6ad4ee63b46d2528bcc38ab9b7581c4c6255/docs/_static/TM23Cu_sample_complexity.png
--------------------------------------------------------------------------------
/docs/_static/css/mystyle.css:
--------------------------------------------------------------------------------
1 | .literal-no-code {
2 | background-color: transparent;
3 | font-size: 1rem;
4 | }
5 |
--------------------------------------------------------------------------------
/docs/_static/diagram_part1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSML-IIT-UCL/franken/497d6ad4ee63b46d2528bcc38ab9b7581c4c6255/docs/_static/diagram_part1.png
--------------------------------------------------------------------------------
/docs/_templates/class.rst:
--------------------------------------------------------------------------------
1 | {{ name | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. autoclass:: {{ objname }}
6 | :members:
7 | :show-inheritance:
8 | :no-undoc-members:
9 | :special-members: __mul__, __add__, __div__, __neg__, __sub__, __truediv__
10 |
--------------------------------------------------------------------------------
/docs/_templates/func.rst:
--------------------------------------------------------------------------------
1 | {{ name | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. autofunction:: {{ objname }}
--------------------------------------------------------------------------------
/docs/_templates/layout.html:
--------------------------------------------------------------------------------
1 | {% extends "!layout.html" %}
2 |
3 | {% block extrahead %}
4 | {{ super() }}
5 |
11 | {% endblock extrahead %}
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 |
8 | import os
9 | import sys
10 | import re
11 | from docutils import nodes
12 | from sphinxawesome_theme.postprocess import Icons
13 |
14 | # -- Path setup --------------------------------------------------------------
15 | basedir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
16 | sys.path.insert(0, basedir)
17 |
18 |
19 | html_permalinks_icon = Icons.permalinks_icon # SVG as a string
20 |
21 | # -- Project information -----------------------------------------------------
22 |
23 | project = "franken"
24 | copyright = "2025, franken team"
25 | author = "franken team"
26 |
27 | # -- General configuration ---------------------------------------------------
28 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
29 |
30 | # Add any paths that contain templates here, relative to this directory.
31 | templates_path = ["_templates"]
32 |
33 | # List of patterns, relative to source directory, that match files and
34 | # directories to ignore when looking for source files.
35 | # This pattern also affects html_static_path and html_extra_path.
36 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "requirements.txt"]
37 |
38 | # Add any Sphinx extension module names here, as strings. They can be
39 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
40 | # ones.
41 | extensions = [
42 | "sphinx.ext.autodoc",
43 | "sphinx.ext.autosummary",
44 | "sphinx.ext.napoleon",
45 | "sphinx.ext.intersphinx",
46 | "sphinxawesome_theme",
47 | "myst_parser",
48 | "sphinxarg.ext",
49 | "nbsphinx",
50 | ]
51 |
52 | myst_enable_extensions = ["amsmath", "dollarmath", "html_image"]
53 |
54 | intersphinx_mapping = {
55 | "numpy": ("https://numpy.org/doc/stable/", None),
56 | "torch": ("https://pytorch.org/docs/stable/", None),
57 | "torchvision": ("https://pytorch.org/vision/stable/", None),
58 | "python": ("https://docs.python.org/3.9/", None),
59 | "ase": ("https://wiki.fysik.dtu.dk/ase/", None),
60 | }
61 |
62 |
63 | autodoc_typehints = "description"
64 | autodoc_typehints_description_target = "documented"
65 | # to handle functions as default input arguments
66 | autodoc_preserve_defaults = True
67 | # Warn about broken links
68 | nitpicky = True
69 | autodoc_inherit_docstrings = False
70 | # autodoc_class_signature = "separated"
71 | # autoclass_content = "class"
72 | # autosummary_generate = False
73 |
74 | # autodoc_member_order = "groupwise"
75 | # napoleon_preprocess_types = True
76 | # napoleon_use_rtype = False
77 |
78 | # master_doc = "index"
79 |
80 | source_suffix = {
81 | ".rst": "restructuredtext",
82 | ".txt": "restructuredtext",
83 | ".md": "markdown",
84 | }
85 |
86 | # -- Options for HTML output -------------------------------------------------
87 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
88 | html_theme = "sphinxawesome_theme"
89 | html_static_path = ["_static"]
90 | html_css_files = ["css/mystyle.css"]
91 | templates_path = ["_templates"]
92 | # Favicon configuration
93 | # html_favicon = '_static/favicon.ico'
94 | # Configure syntax highlighting for Awesome Sphinx Theme
95 | pygments_style = "default"
96 | pygments_style_dark = "material"
97 | html_title = "franken"
98 | # Additional theme configuration
99 | html_theme_options = {
100 | "show_prev_next": True,
101 | "show_scrolltop": True,
102 | "main_nav_links": {
103 | "Docs": "index",
104 | "API Reference": "reference/index",
105 | },
106 | "extra_header_link_icons": {
107 | "GitHub": {
108 | "link": "https://github.com/CSML-IIT-UCL/franken",
109 | "icon": """""",
110 | },
111 | },
112 | # "logo_light": "_static/[logo_light].png",
113 | # "logo_dark": "_static/[logo_dark].png",
114 | }
115 |
116 | ## Teletype role
117 | tt_re = re.compile('^:tt:`(.*)`$')
118 | def tt_role(name, rawtext, text, lineno, inliner, options={}, content=[]):
119 | """
120 | Can be used as :tt:`SOME_TEXT_HERE`,
121 | """
122 | result = []
123 | m = tt_re.search(rawtext)
124 | if m:
125 | arg = m.group(1)
126 | result = [nodes.literal('', arg)]
127 | result[0]['classes'].append('literal-no-code')
128 | return result,[]
129 |
130 |
def setup(app):
    """Sphinx extension entry point: register the custom ``:tt:`` role."""
    app.add_role("tt", tt_role)
133 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. role:: frnkword
2 |
3 | Franken: A Method for Efficient and Accurate Molecular Dynamics
4 | ================================================================
5 |
6 | :tt:`franken` is a novel method designed to enhance the accuracy of atomistic foundation models used for molecular dynamics simulations, all while maintaining computational efficiency. This method builds upon the capabilities of the `MEKRR method `_, extending its application from fitting energies to also fitting forces, thereby enabling feasible molecular dynamics (MD) simulations.
7 |
8 | Franken's Three-Step Process
9 | ----------------------------
10 |
11 | :tt:`franken` operates through a three-step pipeline:
12 |
13 | #. **Feature Extraction:** The initial step involves representing the chemical environment of each atom within a
14 | molecular configuration using features extracted from a pre-trained GNN foundation model.
15 | This leverages the inherent knowledge captured by these pre-trained models.
16 | Specifically, :tt:`franken` utilizes features derived from models such as the `MACE-MP0 `_ model.
17 |
18 | #. **Random Features Enhancement:** In this stage, :tt:`franken` introduces non-linearity into the model by transforming the
19 | extracted GNN features using Random Features (RF) maps. These RF maps offer a computationally efficient alternative
20 | to traditional kernel methods by approximating kernel functions, including the widely used Gaussian kernel,
21 | utilizing randomly sampled parameters.
22 |
23 | #. **Energy and Force Prediction:** The final step involves predicting atomic energies and forces by employing a readout mechanism.
24 | This mechanism leverages a learnable vector of coefficients in conjunction with the transformed features obtained from the preceding step.
25 | This design takes advantage of the efficient optimization characteristics of RF models.
26 |
27 | .. figure:: _static/diagram_part1.png
28 | :class: rounded-image
29 | :width: 75%
30 | :align: center
31 |
32 | The three-step pipeline at the heart of :tt:`franken`.
33 |
34 | Advantages of Franken
35 | ---------------------
36 |
37 | :tt:`franken` presents several distinct advantages that position it as a valuable asset in the realm of molecular dynamics simulations:
38 |
39 | - **Closed-Form Optimization:** :tt:`franken` offers the significant advantage of determining the globally optimal model
40 | parameters through a closed-form solution. This eliminates the reliance on iterative gradient descent, leading to
41 | substantial reductions in training time and ensuring efficient optimization.
42 |
43 | - **High Sample Efficiency:** One of :tt:`franken`'s hallmarks is its exceptional data efficiency.
44 | The method achieves accurate results even with a limited number of training samples,
45 | as evidenced by experiments on the TM23 dataset. Notably, :tt:`franken` attained a validation error
46 | of 9 meV/ using only 128 samples with 1024 random features, underscoring its ability to extract
47 | valuable information from relatively small datasets.
48 |
49 | .. figure:: _static/TM23Cu_sample_complexity.png
50 | :class: rounded-image
51 | :width: 75%
52 | :align: center
53 |
54 | Sample complexity of :tt:`franken` on the :tt:`Cu` data from the `TM23 Dataset `_. (MACE-MP0 Backbone)
55 |
56 | - **Parallelization Capabilities:** :tt:`franken`'s training algorithm inherently lends itself to parallelization, allowing it to be scaled across multiple GPUs, thereby significantly accelerating training. This scalability becomes crucial when addressing the computational burden posed by simulations of increasingly intricate molecular systems.
57 |
58 |
59 | .. toctree::
60 | :maxdepth: 2
61 | :caption: HOW TOs:
62 | :hidden:
63 |
   Introduction <self>
65 |
66 | topics/installation.md
67 | topics/model_registry.md
68 | topics/lammps.md
69 |
70 | .. toctree::
71 | :maxdepth: 2
72 | :caption: Tutorials:
73 | :hidden:
74 |
75 | notebooks/getting_started
76 | notebooks/autotune
77 | notebooks/molecular_dynamics
78 |
79 |
80 | .. toctree::
81 | :maxdepth: 3
82 | :caption: API Reference:
83 | :hidden:
84 |
85 | reference/index
86 | reference/cli
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
@ECHO OFF

pushd %~dp0

REM Command file for Sphinx documentation

REM Fall back to the sphinx-build found on PATH unless SPHINXBUILD overrides it.
if "%SPHINXBUILD%" == "" (
	set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build

REM Probe that sphinx-build is runnable; errorlevel 9009 means "command not found".
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
	echo.
	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
	echo.installed, then set the SPHINXBUILD environment variable to point
	echo.to the full path of the 'sphinx-build' executable. Alternatively you
	echo.may add the Sphinx directory to PATH.
	echo.
	echo.If you don't have Sphinx installed, grab it from
	echo.https://www.sphinx-doc.org/
	exit /b 1
)

REM With no build target given, show the Sphinx help listing.
if "%1" == "" goto help

REM Delegate the requested target (html, latex, ...) to sphinx-build -M.
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%

:end
popd
--------------------------------------------------------------------------------
/docs/notebooks/autotune.ipynb:
--------------------------------------------------------------------------------
1 | ../../notebooks/autotune.ipynb
--------------------------------------------------------------------------------
/docs/notebooks/getting_started.ipynb:
--------------------------------------------------------------------------------
1 | ../../notebooks/getting_started.ipynb
--------------------------------------------------------------------------------
/docs/notebooks/molecular_dynamics.ipynb:
--------------------------------------------------------------------------------
1 | ../../notebooks/molecular_dynamics.ipynb
--------------------------------------------------------------------------------
/docs/reference/cli.rst:
--------------------------------------------------------------------------------
1 | .. _cli_reference:
2 |
3 | Franken CLI Reference
4 | =====================
5 |
6 |
7 | .. list-table::
8 | :header-rows: 1
9 |
10 | * - Program
11 | - Description
   * - :doc:`franken.autotune <franken-cli/franken.autotune>`
     - Automatic hyperparameter tuning for franken models.
   * - :doc:`franken.backbones <franken-cli/franken.backbones>`
     - List and download GNN backbones for franken.
   * - :doc:`franken.create_lammps_model <franken-cli/franken.create_lammps_model>`
     - Convert a franken model to be able to use it with LAMMPS.
18 |
19 | .. toctree::
20 | :maxdepth: 1
21 | :hidden:
22 |
23 | franken-cli/franken.autotune
24 | franken-cli/franken.backbones
25 | franken-cli/franken.create_lammps_model
26 |
--------------------------------------------------------------------------------
/docs/reference/franken-api/franken.calculators.rst:
--------------------------------------------------------------------------------
1 | franken.calculators
2 | ===================
3 |
4 |
5 | .. autosummary::
6 | :toctree: stubs
7 | :template: class.rst
8 | :nosignatures:
9 |
10 | franken.calculators.FrankenCalculator
11 | franken.calculators.LammpsFrankenCalculator
12 |
--------------------------------------------------------------------------------
/docs/reference/franken-api/franken.config.rst:
--------------------------------------------------------------------------------
1 | franken.config
2 | ==============
3 | Object-oriented configuration for the `franken` library.
4 |
5 |
6 | Backbone configuration
7 | ----------------------
8 |
9 | .. autosummary::
10 | :toctree: stubs
11 | :template: class.rst
12 | :nosignatures:
13 |
14 | franken.config.MaceBackboneConfig
15 | franken.config.FairchemBackboneConfig
16 | franken.config.SevennBackboneConfig
17 |
18 |
19 | Random feature configuration
20 | ----------------------------
21 |
22 | .. autosummary::
23 | :toctree: stubs
24 | :template: class.rst
25 | :nosignatures:
26 |
27 | franken.config.GaussianRFConfig
28 | franken.config.MultiscaleGaussianRFConfig
29 |
30 |
31 | Other configurations
32 | --------------------
33 |
34 | .. autosummary::
35 | :toctree: stubs
36 | :template: class.rst
37 | :nosignatures:
38 |
39 | franken.config.DatasetConfig
40 | franken.config.SolverConfig
41 | franken.config.AutotuneConfig
42 |
--------------------------------------------------------------------------------
/docs/reference/franken-api/franken.rf.heads.rst:
--------------------------------------------------------------------------------
1 | franken.rf.heads
2 | ================
3 | This module contains random feature implementations for different kernels
4 |
5 | Base Class
6 | ----------
7 | .. autosummary::
8 | :toctree: stubs
9 | :template: class.rst
10 | :nosignatures:
11 |
12 | franken.rf.heads.RandomFeaturesHead
13 |
14 |
15 | Gaussian kernel
16 | ---------------
17 | Approximations to the classical Gaussian (or RBF) kernel
18 |
19 | .. autosummary::
20 | :toctree: stubs
21 | :template: class.rst
22 | :nosignatures:
23 |
24 | franken.rf.heads.OrthogonalRFF
25 | franken.rf.heads.MultiScaleOrthogonalRFF
26 | franken.rf.heads.BiasedOrthogonalRFF
27 |
28 | Other kernels
29 | -------------
30 |
31 | .. autosummary::
32 | :toctree: stubs
33 | :template: class.rst
34 | :nosignatures:
35 |
   franken.rf.heads.Linear
   franken.rf.heads.TensorSketch
39 |
40 | Helper Functions
41 | ----------------
42 |
43 | .. autosummary::
44 | :toctree: stubs
45 | :template: func.rst
46 | :nosignatures:
47 |
48 | franken.rf.heads.initialize_rf
49 |
--------------------------------------------------------------------------------
/docs/reference/franken-api/franken.rf.model.rst:
--------------------------------------------------------------------------------
1 | franken.rf.model
2 | ================
3 | The main franken model implementation
4 |
5 | .. autosummary::
6 | :toctree: stubs
7 | :template: class.rst
8 | :nosignatures:
9 |
10 | franken.rf.model.FrankenPotential
11 |
--------------------------------------------------------------------------------
/docs/reference/franken-api/franken.rf.scaler.rst:
--------------------------------------------------------------------------------
1 | franken.rf.scaler
2 | =================
3 |
4 | .. autosummary::
5 | :toctree: stubs
6 | :template: class.rst
7 | :nosignatures:
8 |
9 | franken.rf.scaler.FeatureScaler
10 | franken.rf.scaler.Statistics
11 |
12 |
13 | .. autosummary::
14 | :toctree: stubs
15 | :template: func.rst
16 | :nosignatures:
17 |
18 | franken.rf.scaler.compute_dataset_statistics
19 |
--------------------------------------------------------------------------------
/docs/reference/franken-api/franken.trainers.rst:
--------------------------------------------------------------------------------
1 | franken.trainers
2 | ================
3 |
4 | Base Class
5 | ----------
6 | .. autosummary::
7 | :toctree: stubs
8 | :template: class.rst
9 | :nosignatures:
10 |
11 | franken.trainers.BaseTrainer
12 |
13 | Random features trainer
14 | -----------------------
15 | .. autosummary::
16 | :toctree: stubs
17 | :template: class.rst
18 | :nosignatures:
19 |
20 | franken.trainers.RandomFeaturesTrainer
21 |
--------------------------------------------------------------------------------
/docs/reference/franken-cli/franken.autotune.rst:
--------------------------------------------------------------------------------
1 | Autotune
2 | ========
3 |
4 | .. argparse::
5 | :module: franken.autotune.script
6 | :func: get_parser_fn
7 | :prog: franken.autotune
8 | :nodefault:
9 |
--------------------------------------------------------------------------------
/docs/reference/franken-cli/franken.backbones.rst:
--------------------------------------------------------------------------------
1 | Backbones
2 | =========
3 |
4 |
5 | .. argparse::
6 | :module: franken.backbones.cli
7 | :func: get_parser_fn
8 | :prog: franken.backbones
9 |
--------------------------------------------------------------------------------
/docs/reference/franken-cli/franken.create_lammps_model.rst:
--------------------------------------------------------------------------------
1 | Create LAMMPS model
2 | ===================
3 |
4 | .. argparse::
5 | :module: franken.calculators.lammps_calc
6 | :func: get_parser_fn
7 | :prog: franken.create_lammps_model
8 |
--------------------------------------------------------------------------------
/docs/reference/index.rst:
--------------------------------------------------------------------------------
1 |
2 | Franken API Reference
3 | =====================
4 |
5 | .. list-table::
6 | :header-rows: 1
7 |
8 | * - Module
9 | - Description
   * - :doc:`franken.trainers <franken-api/franken.trainers>`
     - Train franken from atomistic simulation data
   * - :doc:`franken.calculators <franken-api/franken.calculators>`
     - Run molecular dynamics with learned potentials.
   * - :doc:`franken.rf.model <franken-api/franken.rf.model>`
     - Main model class for franken
   * - :doc:`franken.rf.heads <franken-api/franken.rf.heads>`
     - Random feature implementations for different kernels
   * - :doc:`franken.rf.scaler <franken-api/franken.rf.scaler>`
     - Utilities for scaling random features
   * - :doc:`franken.config <franken-api/franken.config>`
     - Configuration data-classes for the whole franken library
22 |
23 |
24 | .. toctree::
25 | :maxdepth: 1
26 | :hidden:
27 |
28 | franken-api/franken.trainers
29 | franken-api/franken.calculators
30 | franken-api/franken.rf.model
31 | franken-api/franken.rf.heads
32 | franken-api/franken.rf.scaler
33 | franken-api/franken.config
34 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | Sphinx==7.3.7
2 | sphinxawesome-theme==5.2.0
3 | sphinxcontrib-applehelp==2.0.0
4 | sphinxcontrib-devhelp==2.0.0
5 | sphinxcontrib-htmlhelp==2.1.0
6 | sphinxcontrib-jsmath==1.0.1
7 | sphinxcontrib-qthelp==2.0.0
8 | sphinxcontrib-serializinghtml==2.0.0
9 | sphinx-argparse
10 | myst-parser
11 | nbsphinx
--------------------------------------------------------------------------------
/docs/topics/installation.md:
--------------------------------------------------------------------------------
1 | (installation)=
2 | # Installation
3 |
4 | To install `franken`, start by setting up your environment with the correct **version of [PyTorch](https://pytorch.org/)**. This is especially necessary if you wish to use GPUs. Then install `franken` by running
5 | ```bash
6 | pip install franken
7 | ```
8 | The basic installation comes bare-bones without any GNN backbone installed. You can install franken with a specific backbone directly, by running one of the following commands
9 | ```bash
10 | pip install franken[cuda,mace]
11 | pip install franken[cuda,fairchem]
12 | pip install franken[cuda,sevenn]
13 | ```
14 | In more detail:
15 | - the `cuda` qualifier installs dependencies which are only relevant on GPU-enabled environments and can be omitted.
16 | - the three supported backbones are [MACE](https://github.com/ACEsuit/mace), [SchNet from fairchem](https://github.com/FAIR-Chem/fairchem), and [SevenNet](https://github.com/MDIL-SNU/SevenNet). They are explained in more detail below.
17 |
18 |
19 | ```{warning}
20 | Each backbone seems to have mutually incompatible requirements, particularly with regards to `e3nn` - but also pytorch versions might be a problem.
To minimize incompatibilities, we suggest that users who wish to use multiple backbones create an independent python environment for each.
22 | In particular, the `mace-torch` package requires an old version of `e3nn` (0.4.4) which conflicts with `fairchem-core`, see [this relevant issue](https://github.com/ACEsuit/mace/issues/555) and with `SevenNet`. If you encounter errors with model loading, simply upgrade `e3nn` by running `pip install -U e3nn`.
23 | ```
24 |
25 | ## Supported pre-trained models
26 | ### MACE
27 | We support several models which use the [MACE architecture](https://github.com/ACEsuit/mace):
28 | - The [`MACE-MP0`](https://arxiv.org/abs/2401.00096) models trained on the materials project data by Batatia et al. Additional informations on the pre-training of `MACE-MP0` are available on its [HuggingFace model card](https://huggingface.co/cyrusyc/mace-universal).
- The MACE-OFF ([paper](https://arxiv.org/abs/2312.15211) and [github](https://github.com/ACEsuit/mace-off)) models which are pretrained on organic molecules.
30 | - The Egret ([github](https://github.com/rowansci/egret-public)) family of models (`Egret-1`, `Egret-1e`, `Egret-1t`), also tuned for organic molecules.
31 |
32 | To use any MACE model as a backbone for `franken` just `pip`-install `mace-torch` in `franken`'s environment
33 | ```bash
34 | pip install mace-torch
35 | ```
36 | or directly install franken with mace support (`pip install franken[cuda,mace]`).
37 |
38 | In addition to MACE-MP0 trained on the materials project dataset, Franken also supports the [`MACE-OFF` models](https://arxiv.org/abs/2312.15211) for organic chemistry.
39 |
40 |
41 | ### SevenNet
42 |
43 | Franken also supports the [SevenNet model](https://arxiv.org/abs/2402.03789) by Park et al. as implemented in the [`sevennet`](https://github.com/MDIL-SNU/SevenNet) library.
44 | We have only tested the SevenNet-0 model trained on the materials project dataset, but support for other models should be possible (open an issue if you encounter any problem).
45 |
46 | ### SchNet OC20 (fairchem, formerly OCP)
47 | We support the [SchNet model](https://arxiv.org/abs/1706.08566) by Schütt et al. as implemented in the [`fairchem`](https://fair-chem.github.io/) library by Meta's FAIR. The pre-training was done on the [Open Catalyst dataset](https://fair-chem.github.io/core/datasets/oc20.html). To use it as a backbone for `franken`, install the `fairchem` library
48 | ```bash
49 | pip install fairchem-core
50 | ```
51 | and the `torch_geometric` dependencies as explained in the [FairChem docs](https://fair-chem.github.io/core/install.html).
52 | ```{note}
53 | Not all of fairchem's dependencies can be installed by `pip` alone, check the [FairChem docs](https://fair-chem.github.io/core/install.html).
54 | ```
55 | Note that `SchNet` is not competitive with more recent GNN models and is only meant as a baseline, and to showcase support for diverse backends.
56 | For now we do not support fairchem v2 models, if you wish to see this implemented please file an issue!
--------------------------------------------------------------------------------
/docs/topics/lammps.md:
--------------------------------------------------------------------------------
1 | # Franken + LAMMPS
2 |
3 | The basic steps required to run a Franken model with [LAMMPS](https://www.lammps.org/) are:
4 | 1. Compile the model using `franken/calculators/lammps.py`:
```bash
franken.create_lammps_model --model_path=<path/to/trained_model.pt>
```
8 | Note that only models which use the MACE backbone can be compiled and run with LAMMPS. For the other backbones please use the ase MD interface. The compiled model will be saved in the same directory as the original model, with `-lammps` appended to the filename.
9 | 2. Configure LAMMPS. The following lines are necessary, the second line should point to the compiled model from step 1.
10 | ```
11 | pair_style mace no_domain_decomposition
12 | pair_coeff * * C H N O
13 | ```
14 | 3. Run LAMMPS-Mace. On leonardo you can find it pre-compiled here:
15 | `/leonardo/pub/userexternal/lbonati1/software/lammps-mace/lammps/build-ampere-plumed/lmp`
16 |
17 | ## Compiling LAMMPS-Mace
18 |
19 | This follows the [MACE guide](https://mace-docs.readthedocs.io/en/latest/guide/lammps.html) adapting it to the leonardo cluster.
20 | This can be useful in case one wants to modify the Mace patch to LAMMPS. In particular, the following two files are important:
21 | - [https://github.com/ACEsuit/lammps/blob/mace/src/ML-MACE/pair_mace.cpp](https://github.com/ACEsuit/lammps/blob/mace/src/ML-MACE/pair_mace.cpp)
22 | - [https://github.com/ACEsuit/lammps/blob/mace/src/KOKKOS/pair_mace_kokkos.cpp](https://github.com/ACEsuit/lammps/blob/mace/src/KOKKOS/pair_mace_kokkos.cpp)
23 |
24 | We will assume to start from directory `$BASE_DIR`
25 | 1. ```git clone --branch=mace --depth=1 https://github.com/ACEsuit/lammps```
26 | 2. download librtorch. For now keeping the default version as specified by MACE, but note that new versions exist!
27 | ```bash
28 | wget https://download.pytorch.org/libtorch/cu121/libtorch-shared-with-deps-2.2.0%2Bcu121.zip
29 | unzip libtorch-shared-with-deps-2.2.0+cu121.zip
30 | rm libtorch-shared-with-deps-2.2.0+cu121.zip
31 | mv libtorch libtorch-gpu
32 | ```
33 | 3. Get a GPU node for compilation
34 | `srun -N 1 --ntasks-per-node=1 --cpus-per-task=8 --gres=gpu:1 -A -p boost_usr_prod -t 00:30:00 --pty /bin/bash`
35 | 4. Compile:
36 | 1. Load modules
37 | ```bash
38 | module purge
39 | module load gcc/12.2.0
40 | module load gsl/2.7.1--gcc--12.2.0
41 | module load openmpi/4.1.6--gcc--12.2.0
42 | module load fftw/3.3.10--openmpi--4.1.6--gcc--12.2.0
43 | module load openblas/0.3.24--gcc--12.2.0
44 | module load cuda/12.1
45 | module load intel-oneapi-mkl/2023.2.0
46 | ```
47 | 2. Compile
48 | ```bash
49 | cd $BASE_DIR/lammps
50 | mkdir -p build-ampere
51 | cd build-ampere
52 | cmake \
53 | -D CMAKE_BUILD_TYPE=Release \
54 | -D CMAKE_INSTALL_PREFIX=$(pwd) \
55 | -D CMAKE_CXX_STANDARD=17 \
56 | -D CMAKE_CXX_STANDARD_REQUIRED=ON \
57 | -D BUILD_MPI=ON \
58 | -D BUILD_SHARED_LIBS=ON \
59 | -D PKG_KOKKOS=ON \
60 | -D Kokkos_ENABLE_CUDA=ON \
61 | -D CMAKE_CXX_COMPILER=$(pwd)/../lib/kokkos/bin/nvcc_wrapper \
62 | -D Kokkos_ARCH_AMDAVX=ON \
63 | -D Kokkos_ARCH_AMPERE100=ON \
64 | -D CMAKE_PREFIX_PATH=$(pwd)/../../libtorch-gpu \
65 | -D PKG_ML-MACE=ON \
66 | ../cmake
67 | make -j 8
68 | make install
69 | ```
70 | The compiled binary is then at `$BASE_DIR/lammps/build-ampere/bin/lmp`.
71 |
72 |
73 | ## Running LAMMPS-Mace
74 |
75 | This is just an example sbatch file which can be used to run LAMMPS-Mace. Edit it according to your needs. It uses the paths to LAMMPS-Mace as available on the leonardo cluster, and we will assume that LAMMPS has been configured in a file named `in.lammps`.
76 |
77 | ```bash
78 | #!/bin/bash
79 | #SBATCH --account=
80 | #SBATCH --partition=boost_usr_prod # partition to be used
81 | #SBATCH --time 00:30:00 # format: HH:MM:SS
82 | #SBATCH --qos=boost_qos_dbg
83 | #SBATCH --nodes=1 # node
84 | #SBATCH --ntasks-per-node=1 # tasks out of 32
85 | #SBATCH --gres=gpu:1 # gpus per node out of 4
86 | #SBATCH --cpus-per-task=1 # Important: if > 1 kokkos complains.
87 | ############################
88 |
89 | module purge
90 | module load profile/base
91 | module load gcc/12.2.0
92 | module load gsl/2.7.1--gcc--12.2.0
93 | module load openmpi/4.1.6--gcc--12.2.0
94 | module load fftw/3.3.10--openmpi--4.1.6--gcc--12.2.0
95 | module load openblas/0.3.24--gcc--12.2.0
96 | module load cuda/12.1
97 | module load intel-oneapi-mkl/2023.2.0
98 |
99 | . /leonardo/pub/userexternal/lbonati1/software/lammps-mace/libtorch-gpu/sourceme.sh
100 | . /leonardo/pub/userexternal/lbonati1/software/plumed/plumed2-2.9-gcc12/sourceme.sh
101 |
102 | echo "setting env variable"
103 | export OMP_NUM_THREADS=${SLURM_CPUS_PER_TASK}
104 | export OMP_PROC_BIND=spread
105 | export OMP_PLACES=threads
106 |
107 | echo "running job"
108 | in_file='in.lammps'
109 | log_file='log.lammps'
110 | lmp='/leonardo/pub/userexternal/lbonati1/software/lammps-mace/lammps/build-ampere-plumed/lmp'
111 |
112 | srun $lmp -k on g 1 t ${SLURM_CPUS_PER_TASK} -sf kk -i $in_file -l $log_file
113 |
114 | wait
115 | ```
116 |
--------------------------------------------------------------------------------
/docs/topics/model_registry.md:
--------------------------------------------------------------------------------
1 | (model-registry)=
2 | # Backbones Registry
3 |
4 | The available pre-trained GNNs can be listed by running `franken.backbones list`.
5 | As of today, the available models are:
6 |
7 | ```
8 | DOWNLOADED MODELS
9 | --------------------(/path/to/.franken/gnn_checkpoints)--------------------
10 | MACE-L0 (MACE)
11 | --------------------------------AVAILABLE MODELS--------------------------------
12 | SevenNet0 (sevenn)
13 | MACE-L1 (MACE)
14 | MACE-L2 (MACE)
15 | MACE-OFF-small (MACE)
16 | MACE-OFF-medium (MACE)
17 | MACE-OFF-large (MACE)
18 | SchNet-S2EF-OC20-200k (fairchem)
19 | SchNet-S2EF-OC20-2M (fairchem)
20 | SchNet-S2EF-OC20-20M (fairchem)
21 | SchNet-S2EF-OC20-All (fairchem)
22 | --------------------------------------------------------------------------------
23 | ```
24 |
25 | Models can also be directly downloaded by copying the backbone-ID from the command above into the `download` command
26 |
```bash
franken.backbones download --model_name <backbone-ID>
```
30 |
31 | Check the command-line help (e.g. `franken.backbones download --help`) for more information.
--------------------------------------------------------------------------------
/franken/__init__.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
# Public package version string.
__version__ = "0.4.1"

# Get the absolute path of the directory where this file is located
# (useful for locating package data relative to the installed package).
FRANKEN_DIR = Path(__file__).resolve().parent
7 |
--------------------------------------------------------------------------------
/franken/autotune/__init__.py:
--------------------------------------------------------------------------------
1 | from franken.autotune.script import autotune
2 |
3 | __all__ = ["autotune"]
4 |
--------------------------------------------------------------------------------
/franken/autotune/__main__.py:
--------------------------------------------------------------------------------
1 | from franken.autotune.script import cli_entry_point
2 |
3 |
# Allows running the autotuner as `python -m franken.autotune`.
if __name__ == "__main__":
    cli_entry_point()
6 |
--------------------------------------------------------------------------------
/franken/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | from franken.backbones.utils import load_model_registry
2 |
3 |
4 | REGISTRY = load_model_registry()
5 |
--------------------------------------------------------------------------------
/franken/backbones/cli.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from franken.backbones import REGISTRY
4 | from franken.backbones.utils import download_checkpoint, make_summary
5 | from franken.utils.misc import setup_logger
6 |
7 |
8 | ### Command 'list': list available models
9 |
10 |
def build_list_arg_parser(subparsers) -> argparse.ArgumentParser:
    """Attach the 'list' sub-command to *subparsers* and return its parser."""
    list_parser = subparsers.add_parser("list", description="List available models")
    cache_dir_help = (
        "Directory to save the downloaded checkpoints. "
        "Defaults to '~/.franken/' in the user home or to the "
        "'FRANKEN_CACHE_DIR' environment variable if set."
    )
    list_parser.add_argument(
        "--cache_dir", help=cache_dir_help, type=str, default=None
    )
    list_parser.add_argument(
        "--log_level",
        help="log-level for the command-line logger",
        type=str,
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
    )
    return list_parser
31 |
32 |
def run_list_cmd(args):
    """Handler for the 'list' sub-command: print a summary of known models."""
    setup_logger(level=args.log_level, directory=None)
    summary = make_summary(cache_dir=args.cache_dir)
    print(summary)
36 |
37 |
38 | ### Command 'download': download a model
39 |
40 |
def build_download_arg_parser(subparsers) -> argparse.ArgumentParser:
    """Attach the 'download' sub-command to *subparsers* and return its parser.

    Args:
        subparsers: the action object returned by
            :meth:`argparse.ArgumentParser.add_subparsers`.

    Returns:
        argparse.ArgumentParser: the newly created 'download' sub-parser.
    """
    parser = subparsers.add_parser("download", description="Download a model")
    parser.add_argument(
        "--model_name",
        help="The name of the model to download.",
        type=str,
        required=True,
        # Robustness/consistency fix: use .get() (as make_summary does in
        # franken.backbones.utils) so a registry entry that lacks the
        # 'implemented' key is skipped instead of raising KeyError.
        choices=[
            name for name, info in REGISTRY.items() if info.get("implemented", False)
        ],
    )
    parser.add_argument(
        "--cache_dir",
        help=(
            "Directory to save the downloaded checkpoints. "
            "Defaults to '~/.franken/' in the user home or to the "
            "'FRANKEN_CACHE_DIR' environment variable if set."
        ),
        type=str,
        default=None,
    )
    parser.add_argument(
        "--log_level",
        help="log-level for the command-line logger",
        type=str,
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
    )
    return parser
70 |
71 |
def run_download_cmd(args):
    """Handler for the 'download' sub-command: fetch the requested checkpoint."""
    setup_logger(level=args.log_level, directory=None)
    model_name, cache_dir = args.model_name, args.cache_dir
    download_checkpoint(model_name, cache_dir)
75 |
76 |
def build_arg_parser():
    """Construct the top-level CLI parser with 'list' and 'download' sub-commands."""
    parser = argparse.ArgumentParser(
        description="List and download GNN backbones for franken."
    )

    subparsers = parser.add_subparsers(
        required=True,
        title="Franken backbone CLI",
        description="Provides helpers to interact with the various backbone models supported by Franken",
        help="Run `%(prog)s -h` for help with the individual subcommands",
    )

    # Register each sub-command and bind its handler via set_defaults(func=...).
    for builder, handler in (
        (build_list_arg_parser, run_list_cmd),
        (build_download_arg_parser, run_download_cmd),
    ):
        builder(subparsers).set_defaults(func=handler)

    return parser
95 |
96 |
def main():
    """Command-line entry point for ``franken.backbones``.

    Supports two sub-commands::

        franken.backbones list
        franken.backbones download
    """
    cli_args = build_arg_parser().parse_args()
    # Dispatch to the handler bound by the chosen sub-parser.
    cli_args.func(cli_args)
106 |
107 |
# Allows running this module directly (`python -m franken.backbones.cli`).
if __name__ == "__main__":
    main()
110 |
111 |
# For sphinx docs: sphinx-argparse needs a zero-argument callable returning
# the parser. A direct alias is equivalent to `lambda: build_arg_parser()`
# and avoids the lambda-assignment lint suppression (PEP 8 / ruff E731).
get_parser_fn = build_arg_parser
114 |
--------------------------------------------------------------------------------
/franken/backbones/utils.py:
--------------------------------------------------------------------------------
1 | import importlib.resources
2 | import json
3 | import logging
4 | import os
5 | from pathlib import Path
6 |
7 | import requests
8 | import torch
9 |
10 | from franken.config import BackboneConfig, asdict_with_classvar
11 | from franken.utils import distributed
12 | from franken.utils.file_utils import download_file
13 |
14 |
15 | logger = logging.getLogger("franken")
16 |
17 |
def load_model_registry():
    """Read and parse the packaged ``registry.json`` describing known backbones."""
    registry_resource = importlib.resources.files("franken.backbones").joinpath(
        "registry.json"
    )
    return json.loads(registry_resource.read_text())
26 |
27 |
class CacheDir:
    """Process-wide holder for franken's cache directory path."""

    # Resolved cache directory; None until initialize() has run.
    directory: Path | None = None

    @staticmethod
    def initialize(cache_dir: Path | str | None = None):
        """Resolve and create the cache directory.

        Resolution order: explicit ``cache_dir`` argument, then the
        ``FRANKEN_CACHE_DIR`` environment variable, then ``~/.franken``.
        Re-initialization is allowed but logged as a warning.
        """
        if CacheDir.is_initialized():
            logger.warning(
                f"Cache directory already initialized at {CacheDir.directory}. Reinitializing."
            )
        default_cache = Path.home() / ".franken"
        if cache_dir is not None:
            logger.info(f"Initializing custom cache directory {cache_dir}")
        else:
            env_cache_dir = os.environ.get("FRANKEN_CACHE_DIR", None)
            if env_cache_dir is not None:
                logger.info(
                    f"Initializing cache directory from $FRANKEN_CACHE_DIR {env_cache_dir}"
                )
                cache_dir = env_cache_dir
            else:
                logger.info(f"Initializing default cache directory at {default_cache}")
                cache_dir = default_cache
        CacheDir.directory = Path(cache_dir)

        # Create the directory on first use.
        if not CacheDir.directory.exists():
            CacheDir.directory.mkdir(parents=True, exist_ok=True)
            logger.info(f"Created cache directory at: {CacheDir.directory}")

    @staticmethod
    def get() -> Path:
        """Return the cache directory, initializing with defaults on first use."""
        if not CacheDir.is_initialized():
            CacheDir.initialize()
        assert CacheDir.directory is not None
        return CacheDir.directory

    @staticmethod
    def is_initialized() -> bool:
        """True once a cache directory has been set for this process."""
        return CacheDir.directory is not None
68 |
69 |
def make_summary(cache_dir: str | None = None):
    """Build a printable overview of registry models, locally-present ones first."""
    if cache_dir is not None:
        CacheDir.initialize(cache_dir=cache_dir)
    registry = load_model_registry()
    ckpt_dir = CacheDir.get() / "gnn_checkpoints"

    # Partition implemented models by whether their checkpoint exists on disk.
    local_models: list[tuple[str, str]] = []
    remote_models: list[tuple[str, str]] = []
    for model, info in registry.items():
        if not info.get("implemented", False):
            continue
        target = local_models if (ckpt_dir / info["local"]).is_file() else remote_models
        target.append((model, info["kind"]))

    summary = ""
    if local_models:
        summary += f"{'DOWNLOADED MODELS':^80}\n"
        summary += f"{'(' + str(ckpt_dir) + ')':-^80}\n"
        for model, kind in local_models:
            entry = f"{model} ({kind})"
            summary += f"{entry:<40}\n"

    summary += f"{'AVAILABLE MODELS':-^80}\n"
    for model, kind in remote_models:
        entry = f"{model} ({kind})"
        summary += f"{entry:<80}\n"
    summary += "-" * 80
    return summary
102 |
103 |
def get_checkpoint_path(backbone_path_or_id: str) -> Path:
    """Resolve a backbone specifier to a checkpoint path on disk.

    When the argument matches an entry in the model registry, the checkpoint
    is located in the cache directory (downloading it first if missing).
    Otherwise the argument is interpreted as a plain file-system path.

    Args:
        backbone_path_or_id (str): file-system path to the backbone
            or the backbone's ID as per the model registry.

    Returns:
        Path: Path to the model on disk

    Raises:
        FileNotFoundError: if the argument is neither a registry ID nor an
            existing file path.

    See Also:
        You can use the command :code:`franken.backbones list` from the command-line
        to find out which backbone IDs are supported out-of-the-box.
    """
    registry = load_model_registry()
    gnn_checkpoints_dir = CacheDir.get() / "gnn_checkpoints"

    if backbone_path_or_id in registry:
        backbone_info = registry[backbone_path_or_id]
        ckpt_path = gnn_checkpoints_dir / backbone_info["local"]
        # Only rank 0 downloads; the other ranks wait at the barrier so the
        # file is guaranteed to exist by the time they proceed.
        if distributed.get_rank() == 0:
            if not ckpt_path.exists():
                download_checkpoint(backbone_path_or_id)
            distributed.barrier()
        else:
            distributed.barrier()
        return ckpt_path

    if os.path.isfile(backbone_path_or_id):
        return Path(backbone_path_or_id)
    raise FileNotFoundError(
        f"GNN Backbone path '{backbone_path_or_id}' does not exist. "
        f"You should either provide an existing backbone path or a backbone ID "
        f"from the registry of available backbones: \n{make_summary()}"
    )
145 |
146 |
def download_checkpoint(gnn_backbone_id: str, cache_dir: str | None = None) -> None:
    """Download the model if it's not already present locally."""
    registry = load_model_registry()
    if cache_dir is not None:
        CacheDir.initialize(cache_dir=cache_dir)
    ckpt_dir = CacheDir.get() / "gnn_checkpoints"

    if gnn_backbone_id not in registry:
        raise NameError(
            f"Unknown {gnn_backbone_id} GNN backbone, the current available backbones are\n{make_summary()}"
        )

    entry = registry[gnn_backbone_id]
    if not entry["implemented"]:
        raise NotImplementedError(
            f"The model {gnn_backbone_id} is not implemented in franken yet."
        )

    local_path = ckpt_dir / entry["local"]
    remote_path = entry["remote"]

    # Idempotent: a checkpoint already on disk is never re-downloaded.
    if local_path.is_file():
        logger.info(
            f"Model already exists locally at {local_path}. No download needed."
        )
        return

    local_path.parent.mkdir(parents=True, exist_ok=True)
    logger.info(f"Downloading model from {remote_path} to {local_path}")
    try:
        download_file(url=remote_path, filename=local_path, desc="Downloading model")
    except requests.RequestException as e:
        logger.error(f"Download failed. {e}")
        raise e
180 |
181 |
def load_checkpoint(gnn_config: BackboneConfig) -> torch.nn.Module:
    """Load a GNN backbone wrapper according to ``gnn_config``.

    Resolves the checkpoint path (downloading it if needed) and dispatches to
    the family-specific wrapper class. Any extra fields of ``gnn_config``
    beyond ``path_or_id`` and ``family`` are forwarded to the wrapper's
    ``load_from_checkpoint``.

    Args:
        gnn_config: Backbone configuration holding at least ``path_or_id``
            and ``family``.

    Returns:
        torch.nn.Module: the wrapped backbone model.

    Raises:
        ImportError: if the package for the requested family is not installed.
        ValueError: if ``family`` is not one of "fairchem", "mace", "sevenn".
    """
    gnn_config_dict = asdict_with_classvar(gnn_config)
    gnn_backbone_id = gnn_config_dict.pop("path_or_id")
    backbone_family = gnn_config_dict.pop("family")
    ckpt_path = get_checkpoint_path(gnn_backbone_id)
    err_msg = f"franken wasn't able to load {gnn_backbone_id}. Is {backbone_family} installed?"
    if backbone_family == "fairchem":
        try:
            from franken.backbones.wrappers.fairchem_schnet import FrankenSchNetWrap
        except ImportError as import_err:
            logger.error(err_msg, exc_info=import_err)
            raise
        return FrankenSchNetWrap.load_from_checkpoint(
            str(ckpt_path), gnn_backbone_id=gnn_backbone_id, **gnn_config_dict
        )
    elif backbone_family == "mace":
        try:
            from franken.backbones.wrappers.mace_wrap import FrankenMACE
        except ImportError as import_err:
            logger.error(err_msg, exc_info=import_err)
            raise
        return FrankenMACE.load_from_checkpoint(
            str(ckpt_path),
            gnn_backbone_id=gnn_backbone_id,
            map_location="cpu",
            **gnn_config_dict,
        )
    elif backbone_family == "sevenn":
        try:
            from franken.backbones.wrappers.sevenn import FrankenSevenn
        except ImportError as import_err:
            logger.error(err_msg, exc_info=import_err)
            raise
        # Consistency fix: pass str(ckpt_path) like the other branches do
        # (FrankenSevenn accepts both, but the call sites should match).
        return FrankenSevenn.load_from_checkpoint(
            str(ckpt_path), gnn_backbone_id=gnn_backbone_id, **gnn_config_dict
        )
    else:
        raise ValueError(f"Unknown backbone family {backbone_family}")
220 |
--------------------------------------------------------------------------------
/franken/backbones/wrappers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSML-IIT-UCL/franken/497d6ad4ee63b46d2528bcc38ab9b7581c4c6255/franken/backbones/wrappers/__init__.py
--------------------------------------------------------------------------------
/franken/backbones/wrappers/common_patches.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch
4 |
5 |
6 | logger = logging.getLogger("franken")
7 |
8 |
def patch_e3nn():
    """Monkey-patch e3nn so MACE models support forward-mode autodiff (jvp).

    Replaces the TorchScript-compiled ``_spherical_harmonics`` with a plain
    Python wrapper (scripted functions do not support ``jvp``). The original
    function is stashed so :func:`unpatch_e3nn` can restore it before
    ``torch.jit`` compilation.
    """
    # NOTE:
    # Patching should occur during training: it is necessary for `jvp` on the MACE model,
    # but not during inference, when we only use `torch.autograd`. For inference, we may want
    # to compile the model using `torch.jit` - and the patch interferes with the JIT, so we
    # must disable it.

    import e3nn.o3._spherical_harmonics

    if hasattr(e3nn.o3._spherical_harmonics._spherical_harmonics, "code"):
        # Then _spherical_harmonics is a scripted function, we need to undo this!
        # Re-exec the scripted function's recovered Python source to obtain an
        # unscripted callable in `new_locals`.
        new_locals = {"Tensor": torch.Tensor}
        exec(e3nn.o3._spherical_harmonics._spherical_harmonics.code, None, new_locals)

        def _spherical_harmonics(
            lmax: int, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor
        ) -> torch.Tensor:
            # The recovered source expects `lmax` as a Tensor argument.
            return new_locals["_spherical_harmonics"](torch.tensor(lmax), x, y, z)

        # Save to allow undoing later
        setattr(
            e3nn.o3._spherical_harmonics,
            "_old_spherical_harmonics",
            e3nn.o3._spherical_harmonics._spherical_harmonics,
        )
        e3nn.o3._spherical_harmonics._spherical_harmonics = _spherical_harmonics

    # 2nd patch for newer e3nn versions (somewhere between 0.5.0 and 0.5.5
    # e3nn jits _spherical_harmonics which the SphericalHarmonics class,
    # making the above patch ineffective)
    try:
        from e3nn import set_optimization_defaults

        set_optimization_defaults(jit_script_fx=False)
    except ImportError:
        pass  # only valid for newer e3nn
45 |
46 |
def unpatch_e3nn():
    """Restore e3nn's original ``_spherical_harmonics`` if :func:`patch_e3nn` ran."""
    # This is only useful for CI and testing environments.
    # When jit-compiling a franken module (e.g. for LAMMPS), we don't want the patch applied!
    import e3nn.o3._spherical_harmonics as sh_module

    original = getattr(sh_module, "_old_spherical_harmonics", None)
    if original is not None:
        sh_module._spherical_harmonics = original
56 |
--------------------------------------------------------------------------------
/franken/backbones/wrappers/fairchem_schnet.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 |
3 | import fairchem.core.common.utils
4 | import torch
5 | from fairchem.core.models.schnet import SchNetWrap
6 |
7 | from franken.data import Configuration
8 |
9 |
def segment_coo_patch(src, index, dim_size=None):
    """Sum `src` rows into `dim_size` buckets given per-row bucket indices.

    Drop-in replacement for torch-scatter's ``segment_coo`` that supports
    forward-mode autodiff. Generalized over the original: `src` may have
    trailing feature dimensions (summed per-bucket along dim 0), not just 1-D.

    Args:
        src: Tensor of shape ``(n, ...)`` to be summed.
        index: 1-D integer tensor of length ``n`` mapping rows to buckets.
        dim_size: Number of buckets; inferred as ``index.max() + 1`` if None.

    Returns:
        Tensor of shape ``(dim_size, ...)`` with per-bucket sums.
    """
    if dim_size is None:
        dim_size = int(index.max().item()) + 1
    out = torch.zeros((dim_size, *src.shape[1:]), dtype=src.dtype, device=src.device)
    if src.dim() > 1:
        # scatter_add_ needs `index` shaped like `src`: broadcast it across
        # the trailing feature dimensions.
        index = index.view(-1, *([1] * (src.dim() - 1))).expand_as(src)
    out.scatter_add_(dim=0, index=index, src=src)
    return out
16 |
17 |
def segment_csr_patch(src, indptr):
    """Sum contiguous segments of `src` delimited by CSR-style `indptr`.

    Drop-in replacement for torch-scatter's ``segment_csr`` that supports
    forward-mode autodiff. Vectorized via a zero-padded cumulative sum
    (the original looped over segments in Python, one ``.sum()`` per
    segment), and generalized to `src` with trailing feature dimensions.

    Args:
        src: Tensor of shape ``(n, ...)``.
        indptr: 1-D integer tensor of segment boundaries; segment ``i``
            covers rows ``indptr[i]:indptr[i+1]``. Empty segments sum to 0.

    Returns:
        Tensor of shape ``(len(indptr) - 1, ...)`` with per-segment sums.
    """
    zero = torch.zeros((1, *src.shape[1:]), dtype=src.dtype, device=src.device)
    # Pad with a leading zero so segment i's sum is csum[end] - csum[start].
    csum = torch.cat([zero, torch.cumsum(src, dim=0).to(src.dtype)])
    return csum[indptr[1:]] - csum[indptr[:-1]]
23 |
24 |
def patch_fairchem():
    """
    The `segment_coo` and `segment_csr` patches are necessary to allow
    forward-mode autodiff through the network, which is not implemented
    in the original torch-scatter functions.
    """
    # Module-level monkeypatch: replaces the torch-scatter backed helpers in
    # fairchem.core.common.utils for the lifetime of the process.
    fairchem.core.common.utils.segment_coo = segment_coo_patch
    fairchem.core.common.utils.segment_csr = segment_csr_patch
33 |
34 |
class FairchemCompatData:
    """Minimal stand-in for a fairchem/torch-geometric ``Data`` object.

    Implemented as a plain mutable class rather than a namedtuple so that
    optional attributes (e.g. ``pbc`` in ``FrankenSchNetWrap.descriptors``)
    can be attached after construction: namedtuples define
    ``__slots__ = ()``, so assigning a new attribute to an instance raises
    ``AttributeError``.
    """

    def __init__(self, pos, cell, batch, natoms, atomic_numbers):
        self.pos = pos
        self.cell = cell
        self.batch = batch
        self.natoms = natoms
        self.atomic_numbers = atomic_numbers
38 |
39 |
class FrankenSchNetWrap(SchNetWrap):
    """Fairchem SchNet wrapper exposing franken's descriptor interface."""

    def __init__(self, *args, interaction_block, gnn_backbone_id, **kwargs):
        # Patch the torch-scatter segment ops so forward-mode autodiff works.
        patch_fairchem()
        super().__init__(*args, **kwargs)

        # Number of interaction blocks traversed by `descriptors`.
        self.interaction_block = interaction_block
        # Registry ID, recorded for re-serialization via `init_args`.
        self.gnn_backbone_id = gnn_backbone_id

    def descriptors(
        self,
        data: Configuration,
    ):
        """
        Forward pass for the SchNet model to get the embedded representations of the input data
        """
        fairchem_compat_data = FairchemCompatData(
            data.atom_pos, data.cell, data.batch_ids, data.natoms, data.atomic_numbers
        )
        # fairchem checks if the attribute exists, not whether it's None.
        # NOTE(review): if FairchemCompatData is immutable (e.g. a namedtuple),
        # this attribute assignment raises AttributeError — verify with a
        # configuration that actually carries pbc data.
        if data.pbc is not None:
            fairchem_compat_data.pbc = data.pbc  # type: ignore
        # Get the atomic numbers of the input data
        z = data.atomic_numbers.long()
        assert z.dim() == 1
        # Get the edge index, edge weight and other attributes of the input data
        graph = self.generate_graph(fairchem_compat_data)

        edge_attr = self.distance_expansion(graph.edge_distance)

        # Get the embedded representations of the input data
        h = self.embedding(z)
        # Residual accumulation over the first `interaction_block` interactions.
        for interaction in self.interactions[: self.interaction_block]:
            h = h + interaction(h, graph.edge_index, graph.edge_distance, edge_attr)

        return h

    def feature_dim(self):
        # Descriptor dimensionality equals SchNet's hidden channel count.
        return self.hidden_channels

    def num_params(self) -> int:
        """Total number of model parameters."""
        return sum(p.numel() for p in self.parameters())

    def init_args(self):
        """Arguments needed to re-instantiate this wrapper from a checkpoint."""
        return {
            "gnn_backbone_id": self.gnn_backbone_id,
            "interaction_block": self.interaction_block,
        }

    @staticmethod
    def load_from_checkpoint(
        trainer_ckpt, gnn_backbone_id, interaction_block
    ) -> "FrankenSchNetWrap":
        """Build a FrankenSchNetWrap from a fairchem trainer checkpoint.

        Args:
            trainer_ckpt: path to the fairchem checkpoint file.
            gnn_backbone_id: registry ID recorded on the wrapper.
            interaction_block: number of interaction blocks for descriptors.
        """
        ckpt_data = torch.load(
            trainer_ckpt, map_location=torch.device("cpu"), weights_only=False
        )

        model_config = ckpt_data["config"]["model_attributes"]
        # Build graphs on-the-fly from positions (no precomputed neighbor lists).
        model_config["otf_graph"] = True

        model = FrankenSchNetWrap(
            **model_config,
            interaction_block=interaction_block,
            gnn_backbone_id=gnn_backbone_id,
        )
        # Before we can load state, need to fix state-dict keys:
        # Match the "module." count in the keys of model and checkpoint state_dict
        # DataParallel model has 1 "module.", DistributedDataParallel has 2 "module."
        # Not using either of the above two would have no "module."
        ckpt_key_count = next(iter(ckpt_data["state_dict"])).count("module")
        mod_key_count = next(iter(model.state_dict())).count("module")
        key_count_diff = mod_key_count - ckpt_key_count
        if key_count_diff > 0:
            # Model expects more "module." prefixes: prepend the difference.
            new_dict = {
                key_count_diff * "module." + k: v
                for k, v in ckpt_data["state_dict"].items()
            }
        elif key_count_diff < 0:
            # Model expects fewer "module." prefixes: strip the difference.
            new_dict = {
                k[len("module.") * abs(key_count_diff) :]: v
                for k, v in ckpt_data["state_dict"].items()
            }
        else:
            new_dict = ckpt_data["state_dict"]
        model.load_state_dict(new_dict)
        return model
125 |
--------------------------------------------------------------------------------
/franken/backbones/wrappers/sevenn.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from types import MethodType
3 | from typing import Union
4 | from functools import partial
5 |
6 | import sevenn._keys as KEY
7 | import torch
8 | import torch.nn as nn
9 | from sevenn.util import model_from_checkpoint
10 |
11 | from franken.data.base import Configuration
12 |
13 |
def extract_scalar_irrep(data, irreps):
    """Return the scalar (l=0) block of the node features stored in ``data``."""
    # The first slice of an irreps object covers its leading scalar chunk.
    return data[KEY.NODE_FEATURE][..., irreps.slices()[0]]
19 |
20 |
def franken_sevenn_descriptors(
    self,
    data: Configuration,
    interaction_layer: int,
    extract_after_act: bool = True,
    append_layers: bool = True,
):
    """Extract per-atom scalar descriptors from a sevenn model.

    Bound as a method onto the sevenn module by ``FrankenSevenn``. Runs the
    network module by module, collecting the scalar (l=0) block of the node
    features at each equivariant gate, stopping after ``interaction_layer``
    interaction layers.

    Args:
        data: Input configuration (positions, cell, precomputed edge list).
        interaction_layer: Number of interaction layers to traverse.
        extract_after_act: Take scalars after (True) or before (False) the
            gate activation.
        append_layers: Concatenate scalars from every visited gate (True) or
            keep only the most recent gate's scalars (False).

    Returns:
        Tensor of per-atom descriptors, concatenated along the last dim.
    """
    # Convert data to sevenn
    assert data.cell is not None
    sevenn_data = {
        KEY.NODE_FEATURE: data.atomic_numbers,
        KEY.ATOMIC_NUMBERS: data.atomic_numbers,
        KEY.POS: data.atom_pos,
        KEY.EDGE_IDX: data.edge_index,
        KEY.CELL: data.cell,
        KEY.CELL_SHIFT: data.shifts,  # TODO: Check this correct?
        KEY.CELL_VOLUME: torch.einsum(
            "i,i", data.cell[0, :], torch.linalg.cross(data.cell[1, :], data.cell[2, :])
        ),
        KEY.NUM_ATOMS: len(data.atomic_numbers),
        KEY.BATCH: data.batch_ids,
    }

    # From v0.9.3 to v10 sevenn introduced some changes in how models are build
    # (`build_E3_equivariant_model`), removing the EdgePreprocess class before the
    # network itself. The main purpose of EdgePreprocess was to initialize the
    # KEY.EDGE_VEC (r_ij: the vector between atom positions) and KEY.EDGE_LENGTH.
    # We replace that functionality here.
    # NOTE: the original preprocess had some special handling of the PBC cell
    # when self.is_stress was set to True. We're ignoring all that.
    # NOTE: as comparison to the original EdgePreprocess we assume `is_batch_data`
    # to be False.
    idx_src = sevenn_data[KEY.EDGE_IDX][0]
    idx_dst = sevenn_data[KEY.EDGE_IDX][1]
    pos = sevenn_data[KEY.POS]
    edge_vec = pos[idx_dst] - pos[idx_src]
    edge_vec = edge_vec + torch.einsum(
        "ni,ij->nj", sevenn_data[KEY.CELL_SHIFT], sevenn_data[KEY.CELL].view(3, 3)
    )
    sevenn_data[KEY.EDGE_VEC] = edge_vec
    sevenn_data[KEY.EDGE_LENGTH] = torch.linalg.norm(edge_vec, dim=-1)

    # Iterate through the model's layers
    # the sanest way to figure out which layer we're at is through the
    # `_modules` attribute of `nn.Sequential` (which the Sevenn network
    # inherits from), which exposes key-value pairs.
    layer_idx = 0
    scalar_features_list = []
    for name, module in self._modules.items():
        if "self_connection_intro" in name:
            layer_idx += 1

        new_sevenn_data = module(sevenn_data)
        if "equivariant_gate" in name:
            if extract_after_act:
                scalar_features = extract_scalar_irrep(
                    new_sevenn_data, module.gate.irreps_out
                )
            else:
                scalar_features = extract_scalar_irrep(
                    sevenn_data, module.gate.irreps_in
                )
            # BUG FIX: with append_layers=False the original assigned to
            # scalar_features_list[0] while the list was still empty, raising
            # IndexError at the first gate. Append on first use instead.
            if append_layers or not scalar_features_list:
                scalar_features_list.append(scalar_features)
            else:
                scalar_features_list[0] = scalar_features
            if layer_idx == interaction_layer:
                break
        sevenn_data = new_sevenn_data

    return torch.cat(scalar_features_list, dim=-1)
92 |
93 |
def franken_sevenn_num_params(self) -> int:
    """Total parameter count (bound as a method on the sevenn module)."""
    total = 0
    for param in self.parameters():
        total += param.numel()
    return total
96 |
97 |
def franken_sevenn_feature_dim(
    self,
    interaction_layer: int,
    extract_after_act: bool = True,
    append_layers: bool = True,
):
    """Dimensionality of the descriptors built by ``franken_sevenn_descriptors``.

    Walks the module names exactly like the descriptor extraction does,
    counting "0e" (scalar) irreps at every equivariant gate up to
    ``interaction_layer`` interaction layers.
    """
    current_layer = 0
    feat_dim = 0
    for name, module in self._modules.items():
        if "self_connection_intro" in name:
            current_layer += 1
        if "equivariant_gate" in name:
            irreps = (
                module.gate.irreps_out if extract_after_act else module.gate.irreps_in
            )
            scalar_count = irreps.count("0e")
            # Accumulate across layers, or keep only the latest gate's count.
            feat_dim = feat_dim + scalar_count if append_layers else scalar_count
            if current_layer == interaction_layer:
                break
    return feat_dim
121 |
122 |
class FrankenSevenn:
    """Namespace adapting a sevenn checkpoint to franken's backbone API.

    Note:
        :meth:`load_from_checkpoint` returns the (patched) sevenn
        ``nn.Module`` itself, not a ``FrankenSevenn`` instance; franken's
        accessors are bound onto the module at load time.
    """

    @staticmethod
    def load_from_checkpoint(
        trainer_ckpt: Union[str, Path],
        gnn_backbone_id: str,
        interaction_block: int,
        extract_after_act: bool = True,
        append_layers: bool = True,
    ):
        """Load a sevenn model and bind franken's descriptor interface onto it.

        Args:
            trainer_ckpt: Path to the sevenn checkpoint file.
            gnn_backbone_id: Registry ID, recorded for re-serialization.
            interaction_block: Number of interaction layers traversed when
                extracting descriptors.
            extract_after_act: Take scalar features after the gate activation
                (see ``franken_sevenn_descriptors``).
            append_layers: Concatenate features from successive layers
                (see ``franken_sevenn_descriptors``).

        Returns:
            The sevenn ``nn.Module`` with ``descriptors``, ``num_params``,
            ``feature_dim`` and ``init_args`` methods attached.
        """
        sevenn, config = model_from_checkpoint(str(trainer_ckpt))
        assert isinstance(sevenn, nn.Module)
        # Bind the free functions above as methods on the sevenn module,
        # pre-filling the extraction options via functools.partial.
        sevenn.descriptors = MethodType(  # type: ignore
            partial(
                franken_sevenn_descriptors,
                interaction_layer=interaction_block,
                extract_after_act=extract_after_act,
                append_layers=append_layers,
            ),
            sevenn,
        )
        sevenn.num_params = MethodType(franken_sevenn_num_params, sevenn)  # type: ignore
        sevenn.feature_dim = MethodType(  # type: ignore
            partial(
                franken_sevenn_feature_dim,
                interaction_layer=interaction_block,
                extract_after_act=extract_after_act,
                append_layers=append_layers,
            ),
            sevenn,
        )

        def init_args(self):
            # Closure over the loader arguments so the model can be re-created.
            return {
                "gnn_backbone_id": gnn_backbone_id,
                "interaction_block": interaction_block,
                "extract_after_act": extract_after_act,
                "append_layers": append_layers,
            }

        sevenn.init_args = MethodType(init_args, sevenn)  # type: ignore
        return sevenn
164 |
--------------------------------------------------------------------------------
/franken/calculators/__init__.py:
--------------------------------------------------------------------------------
1 | """Run molecular dynamics with learned potentials.
2 |
3 | Calculators are available for ASE and LAMMPS, but can be
4 | extended to support your favorite MD software.
5 | """
6 |
7 | from .ase_calc import FrankenCalculator
8 | from .lammps_calc import LammpsFrankenCalculator
9 |
10 | __all__ = (
11 | "FrankenCalculator",
12 | "LammpsFrankenCalculator",
13 | )
14 |
--------------------------------------------------------------------------------
/franken/calculators/ase_calc.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Literal, Union
3 |
4 | import numpy as np
5 | import torch
6 | from ase.calculators.calculator import Calculator, all_changes
7 |
8 | from franken.data import BaseAtomsDataset, Configuration
9 | from franken.rf.model import FrankenPotential
10 | from franken.utils.misc import get_device_name
11 |
12 |
class FrankenCalculator(Calculator):
    """Calculator for ASE with franken models

    Attributes:
        implemented_properties:
            Lists properties which can be asked from this calculator, notably "energy" and "forces".
    """

    implemented_properties = ["energy", "forces"]
    default_parameters = {}
    nolabel = True  # ??

    def __init__(
        self,
        franken_ckpt: Union[FrankenPotential, str, Path],
        device=None,
        rf_weight_id: int | None = None,
        forces_mode: Literal["torch.func", "torch.autograd"] = "torch.autograd",
        **calc_kwargs,
    ):
        """Initialize FrankenCalculator class from a franken model.

        Args:
            franken_ckpt : Path to the franken model.
                This class accepts pre-loaded models, as well as jitted models (with `torch.jit`).
            device : PyTorch device specification for where the model should reside
                (e.g. "cuda:0" for GPU placement or "cpu" for CPU placement).
            rf_weight_id : ID of the random feature weights.
                Can generally be left to ``None`` unless the checkpoint contains multiple trained models.
            forces_mode : Backend used to differentiate energies into forces.
            calc_kwargs : Forwarded verbatim to the ASE ``Calculator`` base class.
        """
        # TODO: Remove forces_mode, torch.autograd is always the right way.
        super().__init__(**calc_kwargs)
        self.franken: FrankenPotential
        if isinstance(franken_ckpt, torch.nn.Module):
            # Pre-loaded model: use as-is, optionally moving it to `device`.
            self.franken = franken_ckpt
            if device is not None:
                self.franken.to(device)
        else:
            # Handle jitted torchscript archives and normal files
            try:
                self.franken = torch.jit.load(franken_ckpt, map_location=device)
            except RuntimeError as e:
                # torch.jit.load fails with a PytorchStreamReader error on
                # non-TorchScript files; fall back to a plain checkpoint load.
                if "PytorchStreamReader" not in str(e):
                    raise
                self.franken = FrankenPotential.load(  # type: ignore
                    franken_ckpt,
                    map_location=device,
                    rf_weight_id=rf_weight_id,
                )

        # An empty "md"-split dataset used to convert ASE atoms into model inputs.
        self.dataset = BaseAtomsDataset.from_path(
            data_path=None,
            split="md",
            gnn_config=self.franken.gnn_config,
        )
        # If no device was given, infer it from the model's parameters.
        self.device = (
            device if device is not None else next(self.franken.parameters()).device
        )
        self.forces_mode = forces_mode

    def calculate(
        self,
        atoms=None,
        properties=None,
        system_changes=all_changes,
    ):
        """Compute requested properties for `atoms`, filling ``self.results``."""
        if properties is None:
            properties = self.implemented_properties
        # Skip the force computation entirely when forces were not requested.
        if "forces" not in properties:
            forces_mode = "no_forces"
        else:
            forces_mode = self.forces_mode

        super().calculate(atoms, properties, system_changes)

        # self.atoms is set in the super() call. Unclear why it should be preferred over `atoms`
        config_idx = self.dataset.add_configuration(self.atoms)  # type: ignore
        cpu_data = self.dataset.__getitem__(config_idx, no_targets=True)
        assert isinstance(cpu_data, Configuration)
        data = cpu_data.to(self.device)

        energy, forces = self.franken.energy_and_forces(data, forces_mode=forces_mode)

        # ASE expects plain Python floats / numpy arrays in `results`.
        if energy.ndim == 0:
            self.results["energy"] = energy.item()
        else:
            self.results["energy"] = np.squeeze(energy.numpy(force=True))
        if "forces" in properties:
            assert forces is not None
            self.results["forces"] = np.squeeze(forces.numpy(force=True))
103 |
104 |
def calculator_throughput(
    calculator, atoms_list, num_repetitions=1, warmup_configs=5, verbose=True
):
    """Benchmark a calculator and return configs/second plus metadata.

    All configurations must have the same number of atoms. The first
    ``warmup_configs`` entries are evaluated untimed, then the timed loop
    runs ``num_repetitions`` full passes over ``atoms_list``.
    """
    from time import perf_counter

    hardware = get_device_name(calculator.device)

    _atom_numbers = {len(atoms) for atoms in atoms_list}
    assert (
        len(_atom_numbers) == 1
    ), f"This function only accepts configurations with the same number of atoms, while found configurations with {_atom_numbers} number of atoms"
    natoms = _atom_numbers.pop()

    assert len(atoms_list) > warmup_configs
    # Untimed warmup to amortize lazy initialization costs.
    for warm_idx in range(warmup_configs):
        calculator.calculate(atoms_list[warm_idx])

    start = perf_counter()
    for _ in range(num_repetitions):
        for atoms in atoms_list:
            calculator.calculate(atoms)
    elapsed = perf_counter() - start

    results = {
        "throughput": (len(atoms_list) * num_repetitions) / elapsed,
        "atoms": natoms,
        "hardware": hardware,
    }
    if verbose:
        print(
            f"{results['throughput']:.1f} cfgs/sec ({results['atoms']} atoms) | {results['hardware']}"
        )
    return results
137 |
--------------------------------------------------------------------------------
/franken/calculators/lammps_calc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from typing import Optional
4 | import torch
5 | from e3nn.util import jit
6 |
7 | from franken.data.base import Configuration
8 | from franken.rf.model import FrankenPotential
9 |
10 |
@jit.compile_mode("script")
class LammpsFrankenCalculator(torch.nn.Module):
    """TorchScript-compilable wrapper exposing a franken model to LAMMPS."""

    def __init__(
        self,
        franken_model: FrankenPotential,
    ):
        """Initialize LAMMPS Calculator

        Args:
            franken_model (FrankenPotential): The base franken model used in this MD calculator

        Note:
            The backbone underlying the franken model must be a MACE model. This is because we
            are re-using the LAMMPS interface developed by the MACE authors.
        """
        super().__init__()

        self.model = franken_model
        # Buffers read by the MACE-LAMMPS pair style.
        self.register_buffer("atomic_numbers", self.model.gnn.atomic_numbers)
        self.register_buffer("r_max", self.model.gnn.r_max)
        self.register_buffer("num_interactions", self.model.gnn.num_interactions)
        # this attribute is used for dtype detection in LAMMPS-MACE.
        # See: https://github.com/ACEsuit/lammps/blob/mace/src/ML-MACE/pair_mace.cpp#314
        self.model.node_embedding = self.model.gnn.node_embedding

        # Inference-only module: freeze all parameters.
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(
        self,
        data: dict[str, torch.Tensor],
        local_or_ghost: torch.Tensor,
        compute_virials: bool = False,
    ) -> dict[str, torch.Tensor | None]:
        """Compute energies and forces of a given configuration.

        This module is meant to be used in conjunction with LAMMPS,
        and this function should not be called directly. The format of
        the input data is designed to work with the MACE-LAMMPS fork.

        Warning:
            Stresses and virials are not supported by franken. Since they
            are required to be set by LAMMPS, this function sets them to tensors
            of the appropriate shape filled with zeros. Make sure that
            the chosen MD method does not depend on these quantities.
        """
        # node_attrs is a one-hot representation of the atom types. atom_nums should be the actual atomic numbers!
        # we rely on correct sorting. This is the same as in MACE.
        atom_nums = self.atomic_numbers[torch.argmax(data["node_attrs"], dim=1)]

        franken_data = Configuration(
            atom_pos=data["positions"].double(),
            atomic_numbers=atom_nums,
            natoms=torch.tensor(
                len(atom_nums), dtype=torch.int32, device=atom_nums.device
            ).view(1),
            node_attrs=data["node_attrs"].double(),
            edge_index=data["edge_index"],
            shifts=data["shifts"],
            unit_shifts=data["unit_shifts"],
        )
        energy, forces = self.model(franken_data)  # type: ignore
        # Kokkos doesn't like total_energy_local and only looks at node_energy.
        # We hack around this:
        # (total energy spread evenly over the atoms so it sums back up)
        node_energy = energy.repeat(len(atom_nums)).div(len(atom_nums))
        virials: Optional[torch.Tensor] = None
        if compute_virials:
            # Zero-filled placeholder — franken does not compute virials.
            virials = torch.zeros((1, 3, 3), dtype=forces.dtype, device=forces.device)
        return {
            "total_energy_local": energy,
            "node_energy": node_energy,
            "forces": forces,
            "virials": virials,
        }

    @staticmethod
    def create_lammps_model(model_path: str, rf_weight_id: int | None) -> str:
        """Compile a franken model into a LAMMPS calculator

        Args:
            model_path (str):
                path to the franken model checkpoint.
            rf_weight_id (int | None):
                ID of the random feature weights. Can generally be left to ``None`` unless
                the checkpoint contains multiple trained models.

        Returns:
            str: the path where the LAMMPS-compatible model was saved to.
        """
        franken_model = FrankenPotential.load(
            model_path,
            map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
            rf_weight_id=rf_weight_id,
        )
        # NOTE:
        # Kokkos is hardcoded to double and will silently corrupt data if the model
        # does not use dtype double.
        franken_model = franken_model.double().to("cpu")
        lammps_model = LammpsFrankenCalculator(franken_model)
        lammps_model_compiled = jit.compile(lammps_model)

        # Saved next to the input checkpoint with a "-lammps.pt" suffix.
        save_path = f"{os.path.splitext(model_path)[0]}-lammps.pt"
        print(f"Saving compiled model to '{save_path}'")
        lammps_model_compiled.save(save_path)
        return save_path
116 |
117 |
def build_arg_parser():
    """Build the CLI argument parser for LAMMPS model conversion."""
    parser = argparse.ArgumentParser(
        description="Convert a franken model to be able to use it with LAMMPS",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--model_path",
        type=str,
        help="Path to the model to be converted to LAMMPS",
    )
    parser.add_argument(
        "--rf_weight_id",
        type=int,
        default=None,
        help="Head of the model to be converted to LAMMPS",
    )
    return parser
135 |
136 |
def create_lammps_model_cli():
    """CLI entry point: parse arguments and compile the model for LAMMPS."""
    args = build_arg_parser().parse_args()
    LammpsFrankenCalculator.create_lammps_model(args.model_path, args.rf_weight_id)  # type: ignore
141 |
142 |
if __name__ == "__main__":
    create_lammps_model_cli()


# For sphinx docs: expose the parser factory directly. The previous
# `lambda: build_arg_parser()` wrapper was redundant (and needed a noqa
# for E731); a plain alias has identical call behavior.
get_parser_fn = build_arg_parser
149 |
--------------------------------------------------------------------------------
/franken/data/__init__.py:
--------------------------------------------------------------------------------
1 | from franken.data.base import BaseAtomsDataset, Configuration, Target
2 |
3 |
4 | __all__ = [
5 | "BaseAtomsDataset",
6 | "Configuration",
7 | "Target",
8 | ]
9 |
--------------------------------------------------------------------------------
/franken/data/distributed_sampler.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Iterator
2 |
3 | from torch.utils.data import Sampler, Dataset
4 | import torch.distributed as dist
5 |
6 |
class SimpleUnevenDistributedSampler(Sampler):
    """Round-robin distributed sampler without padding or dropping.

    Each rank receives the indices ``rank, rank + num_replicas, ...`` so
    replica sample counts may differ by one (hence "uneven").
    """

    def __init__(
        self,
        dataset: Dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
    ) -> None:
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if not (0 <= rank < num_replicas):
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]"
            )
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.total_size = len(self.dataset)  # type: ignore[arg-type]
        # num_samples indicates the number of samples for the current process
        self.num_samples = len(range(self.rank, self.total_size, self.num_replicas))

    def __iter__(self) -> Iterator:
        # This rank's strided share of [0, total_size).
        indices = list(range(len(self.dataset)))  # type: ignore[arg-type]
        own_indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(own_indices) == self.num_samples
        return iter(own_indices)

    def __len__(self) -> int:
        return self.num_samples
45 |
--------------------------------------------------------------------------------
/franken/data/fairchem.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import torch
4 | from fairchem.core.preprocessing import AtomsToGraphs as FairchemAtomsToGraphs
5 |
6 | import franken.utils.distributed as dist_utils
7 | from franken.backbones.utils import get_checkpoint_path
8 | from franken.data import BaseAtomsDataset, Configuration, Target
9 |
10 |
11 | class FairchemAtomsDataset(BaseAtomsDataset):
    def __init__(
        self,
        data_path: str | Path | None,
        split: str,
        num_random_subsamples: int | None = None,
        subsample_rng: int | None = None,
        gnn_backbone_id: str | torch.nn.Module | None = None,
        cutoff=6.0,
        max_num_neighbors=200,
        precompute=True,
    ):
        """Dataset converting ASE atoms into fairchem graph objects.

        Args:
            data_path: Path to the dataset file, or None (e.g. for "md" use).
            split: Dataset split name (e.g. "train", "md").
            num_random_subsamples: Optional number of configurations to subsample.
            subsample_rng: Seed controlling the subsampling.
            gnn_backbone_id: Backbone ID/path; when given, `cutoff` and
                `max_num_neighbors` are read from its checkpoint, overriding
                the explicit arguments.
            cutoff: Graph cutoff radius used when no backbone is given.
            max_num_neighbors: Max neighbors per atom when no backbone is given.
            precompute: Eagerly convert all configurations to graphs.
        """
        super().__init__(data_path, split, num_random_subsamples, subsample_rng)

        if gnn_backbone_id is not None:
            cutoff, max_num_neighbors = self.load_info_from_gnn_config(gnn_backbone_id)

        if split == "md":
            # Cannot get energy and forces in MD mode (the calculator fails)
            self.a2g = FairchemAtomsToGraphs(
                max_neigh=max_num_neighbors,
                radius=cutoff,  # type: ignore
            )
        else:
            self.a2g = FairchemAtomsToGraphs(
                max_neigh=max_num_neighbors,
                radius=cutoff,
                r_energy=True,
                r_forces=True,  # type: ignore
            )
        self.graphs = None
        if precompute and len(self.ase_atoms) > 0:
            # Precompute all graphs; show the progress bar only on rank 0.
            self.graphs = self.a2g.convert_all(
                self.ase_atoms, disable_tqdm=dist_utils.get_rank() != 0
            )
46 |
47 | def load_info_from_gnn_config(self, gnn_backbone_id: str | torch.nn.Module):
48 | if not isinstance(gnn_backbone_id, str):
49 | raise ValueError(
50 | "Backbone path must be provided instead of the preloaded model."
51 | )
52 | ckpt_path = get_checkpoint_path(gnn_backbone_id)
53 | model = torch.load(ckpt_path, map_location="cpu", weights_only=False)
54 | model_cfg = model["config"]["model"]
55 | cutoff = getattr(model_cfg, "cutoff", 6.0)
56 | max_num_neighbors = getattr(model_cfg, "max_num_neighbors", 200)
57 | del model, model_cfg
58 | return cutoff, max_num_neighbors
59 |
60 | def graph_to_inputs(self, graph):
61 | return Configuration(
62 | atom_pos=graph.pos, # type: ignore
63 | atomic_numbers=graph.atomic_numbers.int(),
64 | natoms=torch.tensor(graph.natoms).view(1),
65 | batch_ids=(
66 | graph.batch
67 | if graph.batch is not None
68 | else torch.zeros(graph.natoms, dtype=torch.int64)
69 | ),
70 | cell=graph.cell,
71 | pbc=getattr(graph, "pbc", None),
72 | )
73 |
74 | def graph_to_targets(self, graph):
75 | energy = torch.tensor(graph.energy)
76 | return Target(energy=energy, forces=graph.forces)
77 |
78 | def __getitem__(self, idx, no_targets: bool = False):
79 | """Returns an array of (inputs, outputs) with inputs being a configuration
80 | and outputs being the target (energy and forces).
81 | Note: ONLY for the 'train' split, the energy_shift is removed from the target.
82 | """
83 | if self.graphs is None:
84 | graph = self.a2g.convert(self.ase_atoms[idx])
85 | else:
86 | graph = self.graphs[idx]
87 |
88 | data = self.graph_to_inputs(graph)
89 | if no_targets:
90 | return data
91 | target = self.graph_to_targets(graph)
92 |
93 | if self.split == "train":
94 | target.energy -= self.energy_shifts[idx]
95 |
96 | return data, target
97 |
--------------------------------------------------------------------------------
/franken/data/mace.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Optional
3 |
4 | import torch
5 | from mace.data import AtomicData
6 | from mace.data.utils import config_from_atoms
7 | from mace.tools.utils import AtomicNumberTable
8 | from tqdm.auto import tqdm
9 |
10 | import franken.utils.distributed as dist_utils
11 | from franken.backbones.utils import get_checkpoint_path
12 | from franken.data import BaseAtomsDataset, Configuration, Target
13 | from franken.utils.misc import torch_load_maybejit
14 |
15 |
class MACEAtomsToGraphs:
    """Converter from ASE ``Atoms`` objects to MACE ``AtomicData`` graphs."""

    def __init__(self, z_table: AtomicNumberTable, cutoff: float):
        self.cutoff = cutoff
        self.z_table = z_table

    def convert(self, atoms):
        """Build one MACE graph from a single ASE ``Atoms`` object."""
        cfg = config_from_atoms(atoms)
        data = AtomicData.from_config(cfg, z_table=self.z_table, cutoff=self.cutoff)
        # Keep the plain atomic numbers on the graph; the dataset's
        # Configuration reads them directly in __getitem__.
        data.atomic_numbers = torch.tensor(atoms.get_atomic_numbers()).int()
        return data

    def convert_all(
        self,
        atoms_list,
        process_rank: Optional[int] = None,
        split_name: Optional[str] = None,
    ):
        """Convert many Atoms objects; a progress bar is shown on rank 0 only."""
        rank = dist_utils.get_rank() if process_rank is None else process_rank
        iterable = atoms_list
        if rank == 0:
            label = "ASE -> MACE"
            if split_name is not None:
                label = f"{label} ({split_name})"
            iterable = tqdm(atoms_list, desc=label)
        return [self.convert(atoms) for atoms in iterable]
47 |
48 |
class MACEAtomsDataset(BaseAtomsDataset):
    """Dataset yielding MACE-format graphs plus energy/force targets."""

    def __init__(
        self,
        data_path: str | Path | None,
        split: str,
        num_random_subsamples: int | None = None,
        subsample_rng: int | None = None,
        gnn_backbone_id: str | torch.nn.Module | None = None,
        z_table: AtomicNumberTable | None = None,
        cutoff=6.0,
        precompute=True,
    ):
        super().__init__(data_path, split, num_random_subsamples, subsample_rng)
        # Either a backbone provides z_table/cutoff, or a z_table must be given.
        if gnn_backbone_id is None:
            assert z_table is not None
        else:
            z_table, cutoff = self.load_info_from_gnn_config(gnn_backbone_id)

        self.a2g = MACEAtomsToGraphs(z_table, cutoff)
        self.graphs = None
        if precompute and len(self.ase_atoms) > 0:
            self.graphs = self.a2g.convert_all(self.ase_atoms, split_name=self.split)

    def load_info_from_gnn_config(self, gnn_backbone_id: str | torch.nn.Module):
        """Extract the atomic-number table and cutoff radius from a MACE model."""
        if isinstance(gnn_backbone_id, str):
            mace_gnn = torch_load_maybejit(
                get_checkpoint_path(gnn_backbone_id),
                map_location="cpu",
                weights_only=False,
            )
        else:
            mace_gnn = gnn_backbone_id
        table = AtomicNumberTable([z.item() for z in mace_gnn.atomic_numbers])
        r_max = mace_gnn.r_max.item()
        del mace_gnn
        return table, r_max

    def __getitem__(self, idx, no_targets: bool = False):
        """Returns an array of (inputs, outputs) with inputs being a configuration
        and outputs being the target (energy and forces).
        Note: ONLY for the 'train' split, the energy_shift is removed from the target.
        """
        graph = (
            self.a2g.convert(self.ase_atoms[idx])
            if self.graphs is None
            else self.graphs[idx]
        )

        configuration = Configuration(
            atom_pos=graph.positions,
            atomic_numbers=graph.atomic_numbers,
            natoms=torch.tensor(len(graph.atomic_numbers)).view(1),
            node_attrs=graph.node_attrs,
            edge_index=graph.edge_index,
            shifts=graph.shifts,
            unit_shifts=graph.unit_shifts,
        )
        if no_targets:
            return configuration

        atoms = self.ase_atoms[idx]
        energy = torch.tensor(atoms.get_potential_energy(apply_constraint=False))
        if self.split == "train":
            energy = energy - self.energy_shifts[idx]

        return configuration, Target(
            energy=energy,
            forces=torch.Tensor(atoms.get_forces(apply_constraint=False)),
        )
121 |
--------------------------------------------------------------------------------
/franken/data/sevenn.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Optional
3 |
4 | import ase
5 | import numpy as np
6 | import torch
7 | from tqdm.auto import tqdm
8 | import sevenn._keys as KEY
9 | from sevenn.atom_graph_data import AtomGraphData
10 | from sevenn.train.dataload import atoms_to_graph
11 |
12 | import franken.utils.distributed as dist_utils
13 | from franken.backbones.utils import get_checkpoint_path
14 | from franken.data import BaseAtomsDataset, Configuration, Target
15 |
16 |
class SevennAtomsToGraphs:
    """
    Args:
        cutoff (float): Cutoff in Angstrom to build atoms graph from positions
        transfer_info (bool): if True, copy info from atoms to graph
        y_from_calc (bool): if True, targets (energy/forces) are read while
            building the graph; if False, NaN placeholders are used and
            stripped afterwards (see :meth:`convert`)
    """

    def __init__(self, cutoff: float, transfer_info: bool, y_from_calc: bool):
        self.cutoff = cutoff
        self.transfer_info = transfer_info
        self.y_from_calc = y_from_calc

    def convert(self, atoms: ase.Atoms):
        """Convert one ASE ``Atoms`` object to a sevenn ``AtomGraphData``.

        Note: when ``y_from_calc`` is False the input ``atoms`` is temporarily
        mutated (NaN target placeholders) and restored before returning.
        """
        if not self.y_from_calc:
            # It means we're not interested in forces and energies.
            # workaround is to set the attributes to invalid and then
            # remove the attributes
            atoms.info["y_energy"] = np.nan
            atoms.arrays["y_force"] = np.full(atoms.arrays["positions"].shape, np.nan)
        graph = atoms_to_graph(
            atoms,
            cutoff=self.cutoff,
            transfer_info=self.transfer_info,
            y_from_calc=self.y_from_calc,
            with_shift=True,
        )
        if not self.y_from_calc:
            # Strip the NaN placeholders from both the graph and the input
            # atoms so the workaround leaves no trace.
            del graph[KEY.ENERGY]
            del graph[KEY.FORCE]
            del atoms.info["y_energy"]
            del atoms.arrays["y_force"]
        atom_graph_data = AtomGraphData.from_numpy_dict(graph)
        return atom_graph_data

    def convert_all(
        self,
        atoms_list,
        process_rank: Optional[int] = None,
        split_name: Optional[str] = None,
    ):
        """Convert a list of Atoms objects; a progress bar is shown on rank 0 only."""
        graphs = []
        atoms_iter = atoms_list
        if process_rank is None:
            process_rank = dist_utils.get_rank()
        if process_rank == 0:
            desc = "ASE -> SEVENN"
            if split_name is not None:
                desc += f" ({split_name})"
            atoms_iter = tqdm(atoms_list, desc=desc)
        for atoms in atoms_iter:
            graphs.append(self.convert(atoms))
        return graphs
69 |
70 |
class SevennAtomsDataset(BaseAtomsDataset):
    """Dataset yielding SevenNet-format graphs with energy/force targets."""

    def __init__(
        self,
        data_path: str | Path | None,
        split: str,
        num_random_subsamples: int | None = None,
        subsample_rng: int | None = None,
        gnn_backbone_id: str | torch.nn.Module | None = None,
        cutoff: float = 6.0,
        precompute=True,
    ):
        super().__init__(data_path, split, num_random_subsamples, subsample_rng)
        if gnn_backbone_id is None:
            assert cutoff is not None
        else:
            cutoff = self.load_info_from_gnn_config(gnn_backbone_id)

        # In MD mode the calculator cannot provide targets, so skip them.
        wants_targets = split != "md"
        self.a2g = SevennAtomsToGraphs(
            cutoff, transfer_info=False, y_from_calc=wants_targets
        )
        self.graphs = None
        if precompute and len(self.ase_atoms) > 0:
            self.graphs = self.a2g.convert_all(self.ase_atoms, split_name=self.split)

    def load_info_from_gnn_config(self, gnn_backbone_id: str | torch.nn.Module):
        """Read the graph cutoff radius from a SevenNet checkpoint."""
        if not isinstance(gnn_backbone_id, str):
            raise ValueError(
                "Backbone path must be provided instead of the preloaded model."
            )
        checkpoint = torch.load(
            get_checkpoint_path(gnn_backbone_id),
            map_location="cpu",
            weights_only=False,
        )
        cutoff = checkpoint["config"]["cutoff"]
        del checkpoint
        return cutoff

    def __getitem__(self, idx, no_targets: bool = False):
        """Returns an array of (inputs, outputs) with inputs being a configuration
        and outputs being the target (energy and forces).
        Note: ONLY for the 'train' split, the energy_shift is removed from the target.
        """
        graph = (
            self.a2g.convert(self.ase_atoms[idx])
            if self.graphs is None
            else self.graphs[idx]
        )

        num_atoms = len(graph[KEY.ATOMIC_NUMBERS])
        configuration = Configuration(
            atom_pos=graph.pos,
            atomic_numbers=graph[KEY.ATOMIC_NUMBERS],
            natoms=torch.tensor(num_atoms).view(1),
            edge_index=graph.edge_index,
            shifts=graph[KEY.CELL_SHIFT],
            cell=graph[KEY.CELL],
            batch_ids=(
                graph.batch
                if graph.batch is not None
                else torch.zeros(graph[KEY.NUM_ATOMS], dtype=torch.int64)
            ),
        )
        if no_targets:
            return configuration

        energy = graph[KEY.ENERGY]
        if self.split == "train":
            energy = energy - self.energy_shifts[idx]
        return configuration, Target(energy=energy, forces=graph[KEY.FORCE])
149 |
--------------------------------------------------------------------------------
/franken/datasets/PtH2O/pth2o_dataset.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import shutil
3 | import zipfile
4 |
5 | import ase
6 | import ase.io
7 | import numpy as np
8 |
9 | from franken.datasets.registry import DATASET_REGISTRY, BaseRegisteredDataset
10 | from franken.utils.file_utils import download_file
11 |
12 |
@DATASET_REGISTRY.register("PtH2O")
class PtH2ORegisteredDataset(BaseRegisteredDataset):
    """Pt/H2O dataset: downloaded from data.dtu.dk, shuffled with a fixed seed
    and split into 30k train / 1k valid / remainder test frames."""

    # Dataset name -> {split -> path relative to the datasets base directory}.
    relative_paths = {
        "PtH2O": {
            "train": "PtH2O/train.extxyz",
            "val": "PtH2O/valid.extxyz",
            "test": "PtH2O/test.extxyz",
        },
    }

    @classmethod
    def get_path(
        cls, name: str, split: str, base_path: Path | None, download: bool = True
    ):
        """Return the local file for ``split``, downloading the dataset if needed.

        Raises:
            KeyError: if ``base_path`` is None.
            ValueError: if the file is still missing after the optional download.
        """
        if base_path is None:
            raise KeyError(None)
        relative_path = cls.relative_paths[name][split]
        path = base_path / relative_path
        if not path.is_file() and download:
            cls.download(base_path)
        if path.is_file():
            return path
        else:
            raise ValueError(f"Dataset not found at '{path.resolve()}'")

    @classmethod
    def download(cls, base_path: Path):
        """Download the raw trajectory and write train/valid/test extxyz files
        under ``base_path / 'PtH2O'``."""
        pth2o_base_path = base_path / "PtH2O"
        pth2o_base_path.mkdir(exist_ok=True, parents=True)
        # Download
        download_file(
            url="https://data.dtu.dk/ndownloader/files/29141586",
            filename=pth2o_base_path / "data.zip",
            desc="Downloading PtH2O dataset",
            expected_md5="acd748f7f32c66961c90cb15457f7bae",
        )
        # Extract
        with zipfile.ZipFile(pth2o_base_path / "data.zip", "r") as zf:
            zf.extractall(pth2o_base_path)
        # Read full dataset
        full_traj = ase.io.read(
            pth2o_base_path / "Dataset_and_training_files" / "dataset.traj", index=":"
        )
        assert isinstance(full_traj, list)
        # Split into train/val/test. The fixed seed keeps the shuffle (and
        # therefore the generated splits) reproducible across runs.
        np.random.seed(42)
        np.random.shuffle(full_traj)
        train_traj = full_traj[:30_000]
        valid_traj = full_traj[30_000:31_000]
        test_traj = full_traj[31_000:]
        # Saved shuffled to disk
        ase.io.write(pth2o_base_path / "train.extxyz", train_traj)
        ase.io.write(pth2o_base_path / "valid.extxyz", valid_traj)
        ase.io.write(pth2o_base_path / "test.extxyz", test_traj)
        # Cleanup
        (pth2o_base_path / "data.zip").unlink()
        shutil.rmtree(pth2o_base_path / "Dataset_and_training_files")
70 |
71 |
if __name__ == "__main__":
    # Manual download helper: stores the dataset under the franken/datasets dir.
    PtH2ORegisteredDataset.download(Path(__file__).parent.parent)
74 |
--------------------------------------------------------------------------------
/franken/datasets/TM23/tm23_dataset.py:
--------------------------------------------------------------------------------
1 | from functools import reduce
2 | import operator
3 | from pathlib import Path
4 | import zipfile
5 |
6 | from franken.datasets.registry import DATASET_REGISTRY, BaseRegisteredDataset
7 | from franken.utils.file_utils import download_file
8 |
# Chemical symbols of the 27 transition metals covered by the TM23 benchmark.
TM23_ELEMENTS = [
    "Ag", "Au", "Cd", "Co", "Cr", "Cu", "Fe", "Hf", "Hg",
    "Ir", "Mn", "Mo", "Nb", "Ni", "Os", "Pd", "Pt", "Re",
    "Rh", "Ru", "Ta", "Tc", "Ti", "V", "W", "Zn", "Zr",
]

# One combined dataset plus cold/warm/melt temperature variants per element.
TM23_DATASETS = [
    f"TM23/{el}{variant}"
    for el in TM23_ELEMENTS
    for variant in ("", "-cold", "-warm", "-melt")
]
48 |
49 |
@DATASET_REGISTRY.register(TM23_DATASETS)
class TM23RegisteredDataset(BaseRegisteredDataset):
    """TM23 transition-metal datasets (combined plus cold/warm/melt variants)."""

    # For every element and temperature variant, map the dataset name to its
    # train/val files (merged path table for all registered names).
    relative_paths = {
        f"TM23/{el}{suffix}": {
            "train": f"TM23/{el}{stem}_train.xyz",
            "val": f"TM23/{el}{stem}_test.xyz",
        }
        for el in TM23_ELEMENTS
        for suffix, stem in (
            ("", "_2700cwm"),
            ("-cold", "_cold_nequip"),
            ("-warm", "_warm_nequip"),
            ("-melt", "_melt_nequip"),
        )
    }

    @classmethod
    def get_path(
        cls, name: str, split: str, base_path: Path | None, download: bool = True
    ):
        """Resolve (and optionally download) the file for a dataset split."""
        if base_path is None:
            raise KeyError(None)
        target = base_path / cls.relative_paths[name][split]
        if download and not target.is_file():
            cls.download(base_path)
        if not target.is_file():
            raise ValueError(f"Dataset not found at '{target.resolve()}'")
        return target

    @classmethod
    def download(cls, base_path: Path):
        """Download and unpack the full TM23 archive under ``base_path``."""
        dest = base_path / "TM23"
        dest.mkdir(exist_ok=True, parents=True)
        # Download
        download_file(
            url="https://archive.materialscloud.org/record/file?record_id=2113&filename=benchmarking_master_collection-20240316T202423Z-001.zip",
            filename=dest / "data.zip",
            desc="Downloading TM23 dataset",
        )
        # Extract
        with zipfile.ZipFile(dest / "data.zip", "r") as zf:
            zf.extractall(dest)
        # The archive wraps everything in one folder: move its contents up.
        collection = dest / "benchmarking_master_collection"
        for origin in collection.glob("*"):
            origin.rename(dest / origin.name)
        # Cleanup
        (dest / "data.zip").unlink()
        collection.rmdir()
112 |
--------------------------------------------------------------------------------
/franken/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from franken.datasets.registry import DATASET_REGISTRY
2 |
3 | # Ensure all sub-datasets are imported so that they are registered.
4 | from .water import water_dataset # noqa: F401
5 | from .TM23 import tm23_dataset # noqa: F401
6 | from .PtH2O import pth2o_dataset # noqa: F401
7 | from .test import test_dataset # noqa: F401
8 |
9 | __all__ = ("DATASET_REGISTRY",)
10 |
--------------------------------------------------------------------------------
/franken/datasets/registry.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, ClassVar
2 | from pathlib import Path
3 |
4 |
5 | class BaseRegisteredDataset:
6 | relative_paths: ClassVar[dict[str, dict[str, str]]]
7 |
8 | @classmethod
9 | def get_path(
10 | cls, name: str, split: str, base_path: Path | None, download: bool = True
11 | ) -> Path:
12 | raise NotImplementedError()
13 |
14 | @classmethod
15 | def is_valid_split(cls, name: str, split: str) -> bool:
16 | return split in cls.relative_paths[name]
17 |
18 |
# Type aliases for the registry mapping: dataset name -> registered dataset class.
_KT = str
_VT = type[BaseRegisteredDataset]
21 |
22 |
class DatasetRegistry(dict[_KT, _VT]):
    """A name -> dataset-class mapping with decorator-based registration."""

    def register(self, name: _KT | list[_KT] | tuple[_KT]) -> Callable[[_VT], _VT]:
        """Class decorator registering a dataset under one or several names."""

        def decorator(dataset_cls: _VT) -> _VT:
            names = name if isinstance(name, (list, tuple)) else [name]
            for single_name in names:
                self[single_name] = dataset_cls
            return dataset_cls

        return decorator

    def get_path(
        self, name: str, split: str, base_path: Path | None, download: bool = True
    ):
        """Fetch the path for a dataset-split pair. If the dataset does not exist under
        the `base_path` directory, a download will be attempted.

        Args:
            name (str): dataset name (e.g. "water", "TM23/Ag-cold", "PtH2O")
            split (str): data-split, for example "train", "val" or "test"
            base_path (Path): the base path at which the dataset is stored.
            download (bool, optional): Whether to download the dataset if it does not exist.
                Defaults to True.

        Returns:
            dset_path (Path): a path to the ase-readable dataset.
        """
        return self[name].get_path(name, split, base_path, download)

    def is_valid_split(self, name: str, split: str) -> bool:
        """Delegate split validation to the registered dataset class."""
        return self[name].is_valid_split(name, split)
55 |
56 |
# Global singleton registry, populated by @DATASET_REGISTRY.register(...)
# decorators when the dataset modules are imported.
DATASET_REGISTRY = DatasetRegistry()
58 |
--------------------------------------------------------------------------------
/franken/datasets/split_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 | from ase.io import read, write
4 | import os
5 |
6 |
def split_trajectory(
    file_path, split_ratio=0.8, seed=None, train_output=None, val_output=None
):
    """Shuffle and split an ASE trajectory file into train and validation sets.

    Args:
        file_path: path of the input trajectory (any ase-readable format).
        split_ratio: fraction of frames assigned to the train set.
        seed: optional seed making the shuffle reproducible.
        train_output: output file for the train split
            (default: '<input>_train.xyz').
        val_output: output file for the validation split
            (default: '<input>_val.xyz').
    """
    # Use a private RNG so a provided seed does not reseed the module-global
    # `random` state for the rest of the process; for a given seed this yields
    # the same shuffle as `random.seed(seed); random.shuffle(...)`.
    rng = random.Random(seed)

    # Load the frames from the trajectory file
    frames = read(file_path, index=":")
    num_frames = len(frames)
    print(f"Loaded {num_frames} frames from '{file_path}'.")

    # Shuffle and split frames based on the split ratio
    indices = list(range(num_frames))
    rng.shuffle(indices)

    # Calculate the split index
    split_index = int(num_frames * split_ratio)
    train_indices = indices[:split_index]
    val_indices = indices[split_index:]

    # Create train and validation splits
    train_frames = [frames[i] for i in train_indices]
    val_frames = [frames[i] for i in val_indices]

    # Set default output filenames if not provided
    if train_output is None:
        train_output = f"{os.path.splitext(file_path)[0]}_train.xyz"
    if val_output is None:
        val_output = f"{os.path.splitext(file_path)[0]}_val.xyz"

    # Write the split trajectories to separate files
    write(train_output, train_frames)
    write(val_output, val_frames)

    print(
        f"Saved {len(train_frames)} frames to '{train_output}' and {len(val_frames)} frames to '{val_output}'."
    )
45 |
46 |
def main():
    """Command-line entry point: parse arguments and run the split."""
    arg_parser = argparse.ArgumentParser(
        description="Split an ASE trajectory file into train and validation sets in a reproducible way."
    )

    # Mandatory argument
    arg_parser.add_argument(
        "file_path", type=str, help="Path to the input .xyz trajectory file."
    )

    # Optional arguments
    arg_parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed for reproducibility (default: None).",
    )
    arg_parser.add_argument(
        "--split_ratio",
        type=float,
        default=0.8,
        help="Ratio of train to validation split (default: 0.8 for 80%% train).",
    )
    arg_parser.add_argument(
        "--train_output",
        type=str,
        default=None,
        help="Output filename for the train set (default: 'input_train.xyz').",
    )
    arg_parser.add_argument(
        "--val_output",
        type=str,
        default=None,
        help="Output filename for the validation set (default: 'input_val.xyz').",
    )

    opts = arg_parser.parse_args()

    # Validate the split ratio
    if not 0 < opts.split_ratio < 1:
        arg_parser.error("split_ratio must be between 0 and 1 (exclusive).")

    # Call the split function with the provided arguments
    split_trajectory(
        opts.file_path, opts.split_ratio, opts.seed, opts.train_output, opts.val_output
    )
93 |
94 |
# Script entry point.
if __name__ == "__main__":
    main()
97 |
--------------------------------------------------------------------------------
/franken/datasets/test/md.xyz:
--------------------------------------------------------------------------------
1 | 2
2 | Lattice="10.863786 0.0 0.0 0.0 10.863786 0.0 0.0 0.0 7.242524" Properties=species:S:1:pos:R:3:forces:R:3 energy=-286.50486882 stress="-0.002598970560062866 -0.0009004822384060989 0.00015443080167127217 -0.0009004822384060989 -0.005992973973123529 0.0015299963548775313 0.00015443080167127217 0.0015299963548775313 -8.988228146165167e-05" pbc="T T T"
3 | Cu 10.84368800 0.01696923 3.65871309 -0.61688780 0.27004654 -0.02483569
4 | Cu 10.83387800 3.69713106 7.12674501 -0.54312980 -0.46143988 0.19868231
--------------------------------------------------------------------------------
/franken/datasets/test/test.xyz:
--------------------------------------------------------------------------------
1 | 2
2 | Lattice="12.025693879326065 0.0 0.0 6.012846939663032 10.414551541843355 0.0 0.0 0.0 16.24012235072364" Properties=species:S:1:pos:R:3:initial_charges:R:1:forces:R:3:magmoms:R:1 energy=-34662.48894316101 free_energy=-34662.48894316101 pbc="T T T"
3 | Fe 4.00860872 4.62866793 -2.52042670 8.17001000 0.84132152 0.36018147 0.37064355 0.00000000
4 | N 1.32922520 -0.85820042 0.58980978 5.85370200 1.85383814 -0.41121504 -0.84305033 0.00000000
--------------------------------------------------------------------------------
/franken/datasets/test/test_dataset.py:
--------------------------------------------------------------------------------
1 | from importlib import resources
2 | from pathlib import Path
3 |
4 | import franken.datasets
5 | from franken.datasets.registry import DATASET_REGISTRY, BaseRegisteredDataset
6 |
7 |
@DATASET_REGISTRY.register("test")
class TestRegisteredDataset(BaseRegisteredDataset):
    """Tiny dataset shipped inside the package, used by the test-suite."""

    relative_paths = {
        "test": {
            "train": "test/train.xyz",
            "val": "test/validation.xyz",
            "test": "test/test.xyz",
            "md": "test/md.xyz",
            "long": "test/long.xyz",
        },
    }

    @classmethod
    def get_path(
        cls, name: str, split: str, base_path: Path | None, download: bool = True
    ):
        """Locate the bundled file for ``split``; ``base_path`` and ``download``
        are ignored since the data ships with the package."""
        rel = cls.relative_paths[name][split]
        candidate = resources.files(franken.datasets) / rel
        if not candidate.is_file():
            raise ValueError(f"Dataset not found at '{candidate}'")
        return candidate
30 |
--------------------------------------------------------------------------------
/franken/datasets/test/train.xyz:
--------------------------------------------------------------------------------
1 | 2
2 | Lattice="12.025693879326065 0.0 0.0 6.012846939663032 10.414551541843355 0.0 0.0 0.0 16.24012235072364" Properties=species:S:1:pos:R:3:initial_charges:R:1:forces:R:3:magmoms:R:1 energy=-34663.781510561785 free_energy=-34663.781510561785 pbc="T T T"
3 | H 4.00860872 4.62866793 -2.52042670 8.14689100 0.69448942 0.05905515 0.53533993 0.00000000
4 | N 5.18398320 4.85600646 2.93821253 4.99690200 0.11953213 -0.02882078 0.69749169 0.00000000
5 | 2
6 | Lattice="12.025693879326065 0.0 0.0 6.012846939663032 10.414551541843355 0.0 0.0 0.0 16.24012235072364" Properties=species:S:1:pos:R:3:initial_charges:R:1:forces:R:3:magmoms:R:1 energy=-34662.89388569647 free_energy=-34662.89388569647 pbc="T T T"
7 | H 4.00860872 4.62866793 -2.52042670 8.21123300 -0.14498065 0.29821454 0.47052756 0.00000000
8 | N 0.67900074 -1.16300245 4.59493225 4.97292000 -0.40408203 -0.50703723 -3.20458904 0.00000000
--------------------------------------------------------------------------------
/franken/datasets/test/validation.xyz:
--------------------------------------------------------------------------------
1 | 2
2 | Lattice="12.025693879326065 0.0 0.0 6.012846939663032 10.414551541843355 0.0 0.0 0.0 16.24012235072364" Properties=species:S:1:pos:R:3:initial_charges:R:1:forces:R:3:magmoms:R:1 energy=-34663.925543954065 free_energy=-34663.925543954065 pbc="T T T"
3 | Fe 4.00860872 4.62866793 -2.52042670 8.20589500 0.60784890 -0.92761725 -0.03625873 0.00000000
4 | N 6.20080853 -2.74071456 0.44857041 6.15661500 -0.24211410 -0.27163268 -2.25028837 0.00000000
5 | 2
6 | Lattice="12.025693879326065 0.0 0.0 6.012846939663032 10.414551541843355 0.0 0.0 0.0 16.24012235072364" Properties=species:S:1:pos:R:3:initial_charges:R:1:forces:R:3:magmoms:R:1 energy=-34663.41248242709 free_energy=-34663.41248242709 pbc="T T T"
7 | Fe 4.00860872 4.62866793 -2.52042670 8.18704900 1.03936665 -0.29763682 -0.34057038 0.00000000
8 | N 2.70448956 -0.60938640 0.28451228 5.91890700 -1.07381044 -0.26788530 0.93481454 0.00000000
--------------------------------------------------------------------------------
/franken/datasets/water/HH_digitizer.csv:
--------------------------------------------------------------------------------
1 | radius,num_atoms
2 | 1.3109243697478992, 0
3 | 1.4243697478991597, 0
4 | 1.53781512605042, 0
5 | 1.7394957983193278, 0.03409090909090906
6 | 1.8529411764705883, 0.04166666666666663
7 | 1.972689075630252, 0.19318181818181818
8 | 2.092436974789916, 0.5189393939393939
9 | 2.130252100840336, 0.7272727272727273
10 | 2.2058823529411766, 0.9507575757575757
11 | 2.32563025210084, 1.2803030303030303
12 | 2.439075630252101, 1.375
13 | 2.5588235294117645, 1.2424242424242424
14 | 2.672268907563025, 1.0492424242424243
15 | 2.792016806722689, 0.875
16 | 2.911764705882353, 0.7651515151515151
17 | 3.0252100840336134, 0.7272727272727273
18 | 3.1449579831932772, 0.7386363636363636
19 | 3.258403361344538, 0.7840909090909092
20 | 3.384453781512605, 0.8598484848484849
21 | 3.4978991596638656, 0.9583333333333334
22 | 3.6176470588235294, 1.0606060606060606
23 | 3.73109243697479, 1.1401515151515151
24 | 3.850840336134454, 1.1818181818181819
25 | 3.9705882352941178, 1.1666666666666667
26 | 4.084033613445378, 1.1325757575757576
27 | 4.203781512605042, 1.0909090909090908
28 | 4.317226890756302, 1.0568181818181819
29 | 4.436974789915967, 1.0416666666666665
30 | 4.55672268907563, 1.0340909090909092
31 | 4.670168067226891, 1.0265151515151514
32 | 4.7899159663865545, 1.018939393939394
33 | 4.9033613445378155, 1.003787878787879
34 | 5.023109243697479, 0.9886363636363636
35 | 5.142857142857142, 0.9772727272727273
36 | 5.256302521008403, 0.9734848484848485
37 | 5.376050420168067, 0.9621212121212122
38 | 5.489495798319328, 0.9621212121212122
39 |
--------------------------------------------------------------------------------
/franken/datasets/water/OH_digitizer.csv:
--------------------------------------------------------------------------------
1 | radius,num_atoms
2 | 1.3109243697478992, 0
3 | 1.4243697478991597, 0
4 | 1.5441176470588236, 0.049242424242424254
5 | 1.657563025210084, 0.3560606060606061
6 | 1.7773109243697478, 0.8598484848484849
7 | 1.8970588235294117, 1.0946969696969697
8 | 2.0105042016806722, 0.9659090909090909
9 | 2.130252100840336, 0.7272727272727273
10 | 2.2436974789915967, 0.49242424242424243
11 | 2.3634453781512605, 0.34090909090909094
12 | 2.476890756302521, 0.26136363636363635
13 | 2.596638655462185, 0.2537878787878788
14 | 2.716386554621849, 0.3106060606060606
15 | 2.8361344537815127, 0.45454545454545453
16 | 2.9495798319327733, 0.7083333333333334
17 | 3.069327731092437, 1.0416666666666665
18 | 3.189075630252101, 1.3409090909090908
19 | 3.302521008403361, 1.496212121212121
20 | 3.4159663865546217, 1.4848484848484849
21 | 3.5357142857142856, 1.371212121212121
22 | 3.6554621848739495, 1.2348484848484849
23 | 3.76890756302521, 1.128787878787879
24 | 3.888655462184874, 1.0568181818181819
25 | 4.008403361344538, 1.0151515151515151
26 | 4.241596638655462, 0.9962121212121212
27 | 4.3613445378151265, 0.9848484848484849
28 | 4.474789915966387, 0.9772727272727273
29 | 4.588235294117647, 0.9696969696969697
30 | 4.707983193277311, 0.9734848484848485
31 | 4.8277310924369745, 0.9772727272727273
32 | 4.9411764705882355, 0.9886363636363636
33 | 5.0609243697479, 0.9962121212121212
34 | 5.180672268907563, 0.9962121212121212
35 | 5.300420168067227, 1.003787878787879
36 | 5.413865546218487, 1.003787878787879
37 |
--------------------------------------------------------------------------------
/franken/datasets/water/OO_digitizer.csv:
--------------------------------------------------------------------------------
1 | radius,num_atoms
2 | 1.30373831775701,0.0
3 | 1.42056074766355,0.0
4 | 1.54672897196262,0.0
5 | 1.66355140186916,0
6 | 1.77570093457944,-0.0
7 | 1.88785046728972,0.00
8 | 2.00934579439252,-0.0
9 | 2.1214953271028,0.0
10 | 2.24766355140187,0
11 | 2.3411214953271,0
12 | 2.47663551401869,0.0196078431372549
13 | 2.60280373831776,0.417366946778711
14 | 2.70560747663551,1.53221288515406
15 | 2.83644859813084,2.33893557422969
16 | 2.95327102803738,2.22408963585434
17 | 3.06542056074766,1.72829131652661
18 | 3.18691588785047,1.26890756302521
19 | 3.29906542056075,0.983193277310924
20 | 3.42056074766355,0.865546218487395
21 | 3.53738317757009,0.826330532212885
22 | 3.64953271028037,0.840336134453782
23 | 3.77102803738318,0.857142857142857
24 | 3.88785046728972,0.907563025210084
25 | 4.01401869158878,0.949579831932773
26 | 4.1214953271028,1
27 | 4.22897196261682,1.04201680672269
28 | 4.34579439252336,1.0812324929972
29 | 4.48130841121495,1.11484593837535
30 | 4.58411214953271,1.12885154061625
31 | 4.70560747663551,1.13445378151261
32 | 4.80841121495327,1.10924369747899
33 | 4.94859813084112,1.08403361344538
34 | 5.06542056074766,1.04761904761905
35 | 5.20093457943925,1.0140056022409
36 | 5.30373831775701,0.974789915966386
37 | 5.40654205607477,0.943977591036414
38 | 5.53271028037383,0.913165266106443
39 | 5.63551401869159,0.913165266106443
40 | 5.74299065420561,0.907563025210084
41 | 5.87383177570093,0.896358543417367
42 | 5.97196261682243,0.910364145658263
43 |
--------------------------------------------------------------------------------
/franken/datasets/water/water_dataset.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import re
3 | import zipfile
4 |
5 | import ase
6 | import ase.io
7 |
8 | from franken.datasets.registry import DATASET_REGISTRY, BaseRegisteredDataset
9 | from franken.utils.file_utils import download_file
10 |
11 |
@DATASET_REGISTRY.register("water")
class WaterRegisteredDataset(BaseRegisteredDataset):
    """Registered loader for the liquid-water dataset hosted on Zenodo (record 10723405).

    Downloads the VASP ``ML_AB`` archive, converts each member to extended-XYZ
    and derives a validation split from dataset 2.
    """

    # split name -> path relative to the dataset root (`base_path`)
    relative_paths = {
        "water": {
            "train": "water/ML_AB_dataset_1.xyz",
            "val": "water/ML_AB_dataset_2-val.xyz",
        },
    }
    # member names inside the downloaded Zenodo archive (VASP ML_AB format)
    zip_file_names = ["ML_AB_dataset_1", "ML_AB_dataset_2", "ML_AB_128h2o_validation"]

    @classmethod
    def get_path(
        cls, name: str, split: str, base_path: Path | None, download: bool = True
    ):
        """Return the on-disk path for ``name``/``split``, downloading if needed.

        Raises:
            KeyError: if ``base_path`` is None, or ``name``/``split`` is unknown.
            ValueError: if the file is still missing after an attempted download.
        """
        if base_path is None:
            # NOTE(review): KeyError(None) presumably mirrors the KeyError from
            # the relative_paths lookup below, so callers can treat "no base
            # path" and "unknown dataset" uniformly — confirm against callers.
            raise KeyError(None)
        relative_path = cls.relative_paths[name][split]
        path = base_path / relative_path
        if not path.is_file() and download:
            cls.download(base_path)
        if path.is_file():
            return path
        else:
            raise ValueError(f"Dataset not found at '{path.resolve()}'")

    @classmethod
    def download(cls, base_path: Path):
        """Download, convert (VASP ML_AB -> extxyz) and split the water dataset."""
        water_base_path = base_path / "water"
        water_base_path.mkdir(exist_ok=True, parents=True)

        # NOTE: cannot check MD5 here since it changes at every download. As a dumb fallback we check the file-size.
        download_file(
            url="https://zenodo.org/api/records/10723405/files-archive",
            filename=water_base_path / "data.zip",
            desc="Downloading water dataset",
            expected_size=35866571,
        )
        # Extract from zip and convert VASP -> XYZ format
        with zipfile.ZipFile(water_base_path / "data.zip", mode="r") as zf:
            for file_name in cls.zip_file_names:
                with zf.open(file_name, "r") as fh:
                    vasp_data = fh.read().decode("utf-8")
                xyz_data = vasp_mlff_to_xyz(vasp_data)
                with open(water_base_path / f"{file_name}.xyz", "w") as fh:
                    fh.write(xyz_data)
                # Sanity check: every converted frame must expose energy+forces
                traj = ase.io.read(water_base_path / f"{file_name}.xyz", index=":")
                assert isinstance(traj, list)
                for i, atoms in enumerate(traj):
                    atoms.get_potential_energy()
                    atoms.get_forces()
        # Split a validation set from dataset-2
        dataset = ase.io.read(
            water_base_path / "ML_AB_dataset_2.xyz", index=":", format="extxyz"
        )
        assert isinstance(dataset, list)
        # NOTE(review): 473 presumably marks where dataset-2 stops overlapping
        # dataset-1 — confirm before changing.
        dataset_no_overlap = dataset[473:]
        ase.io.write(water_base_path / "ML_AB_dataset_2-val.xyz", dataset_no_overlap)
        # Cleanup
        (water_base_path / "data.zip").unlink()
72 |
73 |
def vasp_mlff_to_xyz_oneconfig(data):
    """Convert one VASP ML_AB configuration block to an extended-XYZ snippet.

    Raises AttributeError when a required section is missing (the caller
    treats that as an expected parse failure and skips the block).
    """

    def section(pattern):
        # .group(1) on a failed re.search raises AttributeError on purpose.
        return re.search(pattern, data).group(1)

    def table(pattern):
        # Whitespace-separated numeric table -> list of token rows.
        return [row.split() for row in section(pattern).strip().split("\n")]

    num_atoms = int(section(r"The number of atoms\s*[-=]+\s*(\d+)"))
    energy = float(section(r"Total energy \(eV\)\s*[-=]+\s*([-+]?\d*\.\d+|\d+)"))

    # Lattice vectors, flattened row-major into a single space-separated string
    lattice_rows = table(r"Primitive lattice vectors \(ang.\)\s*[-=]+\s*([\d\s.-]+)")
    lattice_flat = " ".join(" ".join(row) for row in lattice_rows)

    positions = table(r"Atomic positions \(ang.\)\s*[-=]+\s*([\d\s.-]+)")
    forces = table(r"Forces \(eV ang.\^-1\)\s*[-=]+\s*([\d\s.-]+)")

    # Stress comes as two rows: diagonal (XX YY ZZ) and off-diagonal (XY YZ ZX);
    # keep only the first three numbers of each to skip separator residue.
    xx, yy, zz = section(
        r"Stress \(kbar\)\s*[-=]+\s*XX YY ZZ\s*[-=]+\s*([\d\s.-]+)"
    ).strip().split()[:3]
    xy, yz, zx = section(r"XY YZ ZX\s*[-=]+\s*([\d\s.-]+)").strip().split()[:3]
    # Symmetric 3x3 tensor in row-major order (zx == xz by symmetry).
    stress_tensor = f"{xx} {xy} {zx} {xy} {yy} {yz} {zx} {yz} {zz}"

    # Expand "element count" pairs into one element symbol per atom, in order.
    atom_types = []
    for type_row in (
        section(r"Atom types and atom numbers\s*[-=]+\s*([\w\s\d]+)")
        .strip()
        .split("\n")
    ):
        element, count = type_row.split()
        atom_types.extend([element] * int(count))

    # Assemble the extended-XYZ block: count, comment line, one line per atom.
    xyz_lines = [
        f"{num_atoms}",
        f'Lattice="{lattice_flat}" Properties=species:S:1:pos:R:3:forces:R:3 energy={energy} stress="{stress_tensor}"',
    ]
    for idx, (position, force) in enumerate(zip(positions, forces)):
        element = atom_types[idx]
        px, py, pz = position
        fx, fy, fz = force
        xyz_lines.append(f"{element} {px} {py} {pz} {fx} {fy} {fz}")

    return "\n".join(xyz_lines)
151 |
152 |
def vasp_mlff_to_xyz(data):
    """Convert a whole VASP ML_AB file (many configurations) to extended XYZ."""
    # "Configuration num. N" headers delimit the individual snapshots.
    snippets = []
    for block in re.split(r"Configuration num\.\s*\d+", data):
        block = block.strip()
        if not block:
            continue  # skip the (empty) preamble and trailing chunks
        try:
            snippets.append(vasp_mlff_to_xyz_oneconfig(block))
        except AttributeError:
            pass  # some errors are expected.
    # Join all configurations with a newline
    return "\n".join(snippets)
169 |
170 |
171 | if __name__ == "__main__":
172 | WaterRegisteredDataset.download(Path(__file__).parent.parent)
173 |
--------------------------------------------------------------------------------
/franken/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from franken.metrics.base import BaseMetric
4 | from franken.metrics.functions import * # noqa: F403
5 | from franken.metrics.registry import registry
6 |
7 |
# Only the shared metric registry is exported for star-imports.
__all__ = ["registry"]
9 |
10 |
def available_metrics() -> list[str]:
    """Return the names of all metrics registered with the global registry."""
    return registry.available_metrics
13 |
14 |
def register(name: str, metric_class: type) -> None:
    """Register ``metric_class`` under ``name`` in the global metric registry."""
    registry.register(name, metric_class)
17 |
18 |
def init_metric(
    name: str, device: torch.device, dtype: torch.dtype = torch.float32
) -> BaseMetric:
    """Instantiate the metric registered under ``name`` on ``device``/``dtype``."""
    return registry.init_metric(name, device, dtype)
23 |
--------------------------------------------------------------------------------
/franken/metrics/base.py:
--------------------------------------------------------------------------------
1 | from typing import Mapping
2 | import torch
3 |
4 | import franken.utils.distributed as dist_utils
5 | from franken.data.base import Target
6 |
7 |
class BaseMetric:
    """Distributed-aware accumulating metric.

    Subclasses implement :meth:`update` and feed per-batch errors into
    :meth:`buffer_add`; :meth:`compute` then sums buffers across ranks and
    returns the sample-weighted average.
    """

    def __init__(
        self,
        name: str,
        device: torch.device,
        dtype: torch.dtype = torch.float64,
        units: Mapping[str, str | None] | None = None,
    ):
        """
        Args:
            name: identifier reported alongside the metric value.
            device: device on which buffers are allocated.
            dtype: dtype of the accumulation buffer.
            units: optional mapping describing input/output units
                (e.g. ``{"inputs": "eV", "outputs": "meV/atom"}``).
        """
        self.name = name
        self.device = device
        self.dtype = dtype
        self.buffer = None  # lazily allocated on first buffer_add()
        # FIX: the counter was float (`dtype`) here but int64 after reset();
        # use int64 consistently — it is a sample count, not a measurement.
        self.samples_counter = torch.zeros((1,), device=device, dtype=torch.int64)
        # FIX: avoid the shared mutable default `units={}`.
        self.units = {} if units is None else units

    def reset(self) -> None:
        """Reset the buffer to zeros"""
        self.buffer = None
        self.samples_counter = torch.zeros((1,), device=self.device, dtype=torch.int64)

    def buffer_add(self, value: torch.Tensor, num_samples: int = 1) -> None:
        """Accumulate ``value`` into the buffer and bump the sample counter."""
        if self.buffer is None:
            self.buffer = torch.zeros(value.shape, device=self.device, dtype=self.dtype)
        else:
            assert self.buffer.shape == value.shape
        self.buffer += value
        self.samples_counter += num_samples

    def update(
        self,
        predictions: Target,
        targets: Target,
    ) -> None:
        """Update the metric buffer with new batch results"""
        raise NotImplementedError()

    def compute(self, reset: bool = True) -> torch.Tensor:
        """Reduce across ranks and return the per-sample average.

        Raises:
            ValueError: if :meth:`update` was never called.
        """
        if self.buffer is None:
            raise ValueError(
                f"Cannot compute value for metric '{self.name}' "
                "because it was never updated."
            )
        dist_utils.all_sum(self.buffer)
        dist_utils.all_sum(self.samples_counter)
        error = self.buffer / self.samples_counter
        if reset:
            self.reset()
        return error
56 |
--------------------------------------------------------------------------------
/franken/metrics/functions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from franken.data.base import Target
5 | from franken.metrics.base import BaseMetric
6 | from franken.metrics.registry import registry
7 | from franken.utils import distributed
8 |
9 |
10 | __all__ = [
11 | "EnergyMAE",
12 | "EnergyRMSE",
13 | "ForcesMAE",
14 | "ForcesRMSE",
15 | "ForcesCosineSimilarity",
16 | "is_pareto_efficient",
17 | ]
18 |
19 |
class EnergyMAE(BaseMetric):
    """Mean absolute error on total energies, reported in meV/atom."""

    def __init__(self, device: torch.device, dtype: torch.dtype = torch.float32):
        super().__init__(
            "energy_MAE",
            device,
            dtype,
            {"inputs": "eV", "outputs": "meV/atom"},
        )

    def update(self, predictions: Target, targets: Target) -> None:
        """Accumulate the per-configuration absolute energy error in meV/atom."""
        if targets.forces is None:
            raise NotImplementedError(
                "At the moment, target's forces are required to get the number of atoms in the configuration."
            )
        # Atom count is inferred from the forces' second-to-last dimension.
        n_atoms = targets.forces.shape[-2]
        batch = targets.energy.shape[0] if targets.energy.ndim > 0 else 1
        abs_err = 1000 * torch.abs(targets.energy - predictions.energy) / n_atoms
        self.buffer_add(abs_err, num_samples=batch)
41 |
42 |
class EnergyRMSE(BaseMetric):
    """Root-mean-square error on total energies, reported in meV/atom."""

    def __init__(self, device: torch.device, dtype: torch.dtype = torch.float32):
        super().__init__(
            "energy_RMSE",
            device,
            dtype,
            {"inputs": "eV", "outputs": "meV/atom"},
        )

    def update(self, predictions: Target, targets: Target) -> None:
        """Accumulate the squared per-atom energy error (eV^2/atom^2)."""
        if targets.forces is None:
            raise NotImplementedError(
                "At the moment, target's forces are required to get the number of atoms in the configuration."
            )
        # Atom count is inferred from the forces' second-to-last dimension.
        n_atoms = targets.forces.shape[-2]
        batch = targets.energy.shape[0] if targets.energy.ndim > 0 else 1
        sq_err = torch.square((targets.energy - predictions.energy) / n_atoms)
        self.buffer_add(sq_err, num_samples=batch)

    def compute(self, reset: bool = True) -> torch.Tensor:
        """Reduce across ranks, average, then take sqrt and convert to meV."""
        if self.buffer is None:
            raise ValueError(
                f"Cannot compute value for metric '{self.name}' "
                "because it was never updated."
            )
        distributed.all_sum(self.buffer)
        distributed.all_sum(self.samples_counter)
        mean_sq = self.buffer / self.samples_counter
        # square-root and fix units
        result = torch.sqrt(mean_sq) * 1000
        if reset:
            self.reset()
        return result
79 |
80 |
class ForcesMAE(BaseMetric):
    """Mean absolute error on force components, reported in meV/ang."""

    def __init__(self, device: torch.device, dtype: torch.dtype = torch.float32):
        super().__init__(
            "forces_MAE",
            device,
            dtype,
            {"inputs": "eV/ang", "outputs": "meV/ang"},
        )

    def update(self, predictions: Target, targets: Target) -> None:
        """Accumulate the mean absolute force error of one (batch of) configuration(s)."""
        if targets.forces is None or predictions.forces is None:
            raise AttributeError("Forces must be specified to compute the MAE.")
        ndim = targets.forces.ndim
        if ndim < 2:
            raise ValueError("Forces must be a 2D tensor or a batch of 2D tensors.")
        batch = targets.forces.shape[0] if ndim > 2 else 1
        # Absolute error in meV/ang, averaged over atoms and xyz components.
        abs_err = 1000 * torch.abs(targets.forces - predictions.forces)
        self.buffer_add(abs_err.mean(dim=(-1, -2)), num_samples=batch)
102 |
103 |
class ForcesRMSE(BaseMetric):
    """Root-mean-square error on force components, reported in meV/ang.

    Squared errors are accumulated in :meth:`update`; the square root and the
    eV->meV conversion happen once in :meth:`compute`.
    """

    def __init__(self, device: torch.device, dtype: torch.dtype = torch.float32):
        units = {
            "inputs": "eV/ang",
            "outputs": "meV/ang",
        }
        super().__init__("forces_RMSE", device, dtype, units)

    def update(self, predictions: Target, targets: Target) -> None:
        """Accumulate the mean squared force error of one (batch of) configuration(s)."""
        if targets.forces is None or predictions.forces is None:
            # FIX: the message previously said "MAE" in this RMSE metric.
            raise AttributeError("Forces must be specified to compute the RMSE.")
        num_samples = 1
        if targets.forces.ndim > 2:
            num_samples = targets.forces.shape[0]
        elif targets.forces.ndim < 2:
            raise ValueError("Forces must be a 2D tensor or a batch of 2D tensors.")

        error = torch.square(targets.forces - predictions.forces)
        error = error.mean(dim=(-1, -2))  # Average over atoms and components

        self.buffer_add(error, num_samples=num_samples)

    def compute(self, reset: bool = True) -> torch.Tensor:
        """Reduce across ranks, average, then take sqrt and convert to meV/ang."""
        if self.buffer is None:
            raise ValueError(
                f"Cannot compute value for metric '{self.name}' "
                "because it was never updated."
            )
        distributed.all_sum(self.buffer)
        distributed.all_sum(self.samples_counter)
        error = self.buffer / self.samples_counter
        # square-root and fix units
        error = torch.sqrt(error) * 1000
        if reset:
            self.reset()
        return error
140 |
141 |
class ForcesRMSE2(BaseMetric):
    """Average of RMSE along individual structures (meV/ang).

    Unlike :class:`ForcesRMSE`, the square root is taken per structure inside
    :meth:`update`, so the inherited :meth:`compute` just averages the
    already-rooted values.
    """

    def __init__(self, device: torch.device, dtype: torch.dtype = torch.float32):
        units = {
            "inputs": "eV/ang",
            "outputs": "meV/ang",
        }
        # FIX: this metric previously reported itself as "forces_RMSE",
        # colliding with ForcesRMSE; use the name it is registered under.
        super().__init__("forces_RMSE2", device, dtype, units)

    def update(self, predictions: Target, targets: Target) -> None:
        """Accumulate the per-structure force RMSE (already in meV/ang)."""
        if targets.forces is None or predictions.forces is None:
            # FIX: the message previously said "MAE" in this RMSE metric.
            raise AttributeError("Forces must be specified to compute the RMSE.")
        num_samples = 1
        if targets.forces.ndim > 2:
            num_samples = targets.forces.shape[0]
        elif targets.forces.ndim < 2:
            raise ValueError("Forces must be a 2D tensor or a batch of 2D tensors.")

        error = torch.square(targets.forces - predictions.forces)
        error = error.mean(dim=(-1, -2))  # Average over atoms and components
        error = torch.sqrt(error) * 1000  # per-structure RMSE in meV/ang
        self.buffer_add(error, num_samples=num_samples)
165 |
166 |
class ForcesCosineSimilarity(BaseMetric):
    """Mean cosine similarity between predicted and reference force vectors."""

    def __init__(self, device: torch.device, dtype: torch.dtype = torch.float32):
        super().__init__(
            "forces_cosim",
            device,
            dtype,
            {"inputs": "eV/ang", "outputs": None},
        )

    def update(
        self,
        predictions: Target,
        targets: Target,
    ) -> None:
        """Accumulate the per-configuration average cosine similarity."""
        assert targets.forces is not None
        assert predictions.forces is not None
        ndim = targets.forces.ndim
        if ndim < 2:
            raise ValueError("Forces must be a 2D tensor or a batch of 2D tensors.")
        batch = targets.forces.shape[0] if ndim > 2 else 1
        # Cosine similarity per atom (over the xyz axis), then averaged
        # over atoms in each configuration.
        per_atom = torch.nn.functional.cosine_similarity(
            predictions.forces, targets.forces, dim=-1
        )
        self.buffer_add(per_atom.mean(dim=-1), num_samples=batch)
193 |
194 |
def is_pareto_efficient(costs):
    """
    Find the pareto-efficient points
    :param costs: An (n_points, n_costs) array
    :return: A (n_points, ) boolean array, indicating whether each point is Pareto efficient
    """
    n_points = costs.shape[0]
    keep = np.ones(n_points, dtype=bool)
    for idx in range(n_points):
        if not keep[idx]:
            continue  # already dominated by an earlier point
        point = costs[idx]
        # Among surviving candidates, keep only those strictly better than
        # `point` in at least one objective ...
        keep[keep] = np.any(costs[keep] < point, axis=1)
        keep[idx] = True  # ... and keep the point itself
    return keep
209 |
210 |
# Register the built-in metrics under the names used throughout the codebase.
for _metric_name, _metric_cls in (
    ("energy_MAE", EnergyMAE),
    ("energy_RMSE", EnergyRMSE),
    ("forces_MAE", ForcesMAE),
    ("forces_RMSE", ForcesRMSE),
    ("forces_RMSE2", ForcesRMSE2),
    ("forces_cosim", ForcesCosineSimilarity),
):
    registry.register(_metric_name, _metric_cls)
del _metric_name, _metric_cls  # keep the module namespace clean
217 |
--------------------------------------------------------------------------------
/franken/metrics/registry.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from franken.metrics.base import BaseMetric
4 |
5 |
class MetricRegistry:
    """Process-wide singleton mapping metric names to metric classes."""

    _instance = None  # the one shared instance

    def __new__(cls):
        # Lazily create the singleton on first construction; every later
        # call returns the same object (with its _metrics dict intact).
        if cls._instance is None:
            instance = super().__new__(cls)
            instance._metrics = {}
            cls._instance = instance
        return cls._instance

    def register(self, name: str, metric_class: type) -> None:
        """Register a metric class"""
        self._metrics[name] = metric_class

    def init_metric(
        self, name: str, device: torch.device, dtype: torch.dtype = torch.float32
    ) -> BaseMetric:
        """Create a new instance of a metric"""
        try:
            metric_cls = self._metrics[name]
        except KeyError:
            raise KeyError(
                f"Metric '{name}' not found. Available metrics: {list(self._metrics.keys())}"
            ) from None
        return metric_cls(device=device, dtype=dtype)

    @property
    def available_metrics(self) -> list[str]:
        """Names of all registered metrics."""
        return [*self._metrics]


# Shared singleton used by franken.metrics' module-level helpers.
registry = MetricRegistry()
35 |
--------------------------------------------------------------------------------
/franken/rf/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSML-IIT-UCL/franken/497d6ad4ee63b46d2528bcc38ab9b7581c4c6255/franken/rf/__init__.py
--------------------------------------------------------------------------------
/franken/rf/atomic_energies.py:
--------------------------------------------------------------------------------
1 | from typing import Mapping
2 |
3 | import torch
4 |
5 |
6 | class AtomicEnergiesShift(torch.nn.Module):
7 | atomic_energies: torch.Tensor
8 | Z_keys: list[int]
9 |
10 | def __init__(
11 | self,
12 | num_species: int,
13 | atomic_energies: Mapping[int, torch.Tensor | float] | None = None,
14 | ):
15 | """
16 | Initialize the AtomicEnergiesShift module.
17 |
18 | Args:
19 | num_species:
20 | atomic_energies: A dictionary mapping atomic numbers to atomic energies.
21 | """
22 | super().__init__()
23 |
24 | self.num_species = num_species
25 | self.register_buffer("atomic_energies", torch.zeros(num_species))
26 | self.register_buffer(
27 | "z_keys", torch.zeros((self.num_species,), dtype=torch.long)
28 | ) # placeholder
29 | self.is_initialized = False
30 |
31 | if atomic_energies is not None:
32 | self.set_from_atomic_energies(atomic_energies)
33 |
34 | def set_from_atomic_energies(
35 | self, atomic_energies: Mapping[int, torch.Tensor | float]
36 | ):
37 | assert (
38 | len(atomic_energies) == self.num_species
39 | ), f"{len(atomic_energies)=} != {self.num_species=}"
40 | device = self.atomic_energies.device
41 | self.atomic_energies = torch.stack(
42 | [
43 | v.clone().detach() if isinstance(v, torch.Tensor) else torch.tensor(v)
44 | for v in atomic_energies.values()
45 | ]
46 | ).to(device)
47 | self.z_keys = torch.tensor(
48 | list(atomic_energies.keys()),
49 | dtype=torch.long,
50 | device=self.atomic_energies.device,
51 | )
52 | self.is_initialized = True
53 |
54 | def forward(self, atomic_numbers: torch.Tensor) -> torch.Tensor:
55 | """
56 | Calculate the energy shift for a given set of atomic numbers.
57 |
58 | Args:
59 | atomic_numbers: A tensor containing atomic numbers for which to calculate the energy shift.
60 |
61 | Returns:
62 | A tensor representing the total energy shift for the provided atomic numbers.
63 | """
64 |
65 | shift = torch.tensor(
66 | 0.0, dtype=self.atomic_energies.dtype, device=self.atomic_energies.device
67 | )
68 |
69 | for z, atom_ene in zip(self.z_keys, self.atomic_energies):
70 | mask = atomic_numbers == int(z.item())
71 | shift += torch.sum(atom_ene * mask)
72 |
73 | return shift
74 |
75 | def __repr__(self):
76 | formatted_energies = " , ".join(
77 | [
78 | f"{z.item()}: {atom_ene}"
79 | for z, atom_ene in zip(self.z_keys, self.atomic_energies)
80 | ]
81 | )
82 | return f"{self.__class__.__name__}({formatted_energies})"
83 |
--------------------------------------------------------------------------------
/franken/trainers/__init__.py:
--------------------------------------------------------------------------------
1 | """Train franken from data."""
2 |
3 | from franken.trainers.base import BaseTrainer
4 | from franken.trainers.rf_cuda_lowmem import RandomFeaturesTrainer
5 |
# Public API of the trainers subpackage.
__all__ = (
    "BaseTrainer",
    "RandomFeaturesTrainer",
)
10 |
--------------------------------------------------------------------------------
/franken/trainers/base.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import json
3 | import logging
4 | from pathlib import Path
5 | from typing import Tuple, Union
6 |
7 | import torch
8 | import torch.utils.data
9 |
10 | from franken.config import asdict_with_classvar
11 | from franken.rf.model import FrankenPotential
12 | from franken.rf.scaler import Statistics, compute_dataset_statistics
13 | from franken.trainers.log_utils import (
14 | DataSplit,
15 | LogCollection,
16 | LogEntry,
17 | dtypeJSONEncoder,
18 | )
19 | from franken.utils.misc import are_dicts_equal
20 |
21 |
22 | logger = logging.getLogger("franken")
23 |
24 |
class BaseTrainer(abc.ABC):
    """Base trainer class. Requires :meth:`~BaseTrainer.fit` and :meth:`~BaseTrainer.evaluate` methods."""

    def __init__(
        self,
        train_dataloader: torch.utils.data.DataLoader,
        log_dir: Path | None = None,  # If None, logging is disabled
        save_every_model: bool = True,
        device: Union[torch.device, str, int] = "cpu",
        dtype: Union[str, torch.dtype] = torch.float32,
    ):
        """
        Args:
            train_dataloader: dataloader over the training configurations.
            log_dir: directory for logs and checkpoints (None disables logging).
            save_every_model: whether to checkpoint every fitted model, not
                just the best one.
            device: device used for training buffers.
            dtype: buffer dtype; a torch dtype or one of the string aliases
                'float64'/'double' and 'float32'/'float'/'single'.
        """
        self.log_dir = log_dir
        self.save_every_model = save_every_model
        self.train_dataloader = train_dataloader
        self.statistics_ = None  # cached (Statistics, gnn-config dict) pair
        if isinstance(dtype, str):
            # FIX: this used to compare the *bound method* `dtype.lower` to a
            # string (always False), silently rejecting the 'double' and
            # 'float' aliases; 'single' was advertised but never accepted.
            aliases = {
                "float64": torch.float64,
                "double": torch.float64,
                "float32": torch.float32,
                "float": torch.float32,
                "single": torch.float32,
            }
            key = dtype.lower()
            if key not in aliases:
                raise ValueError(
                    f"Invalid dtype {dtype}. Allowed values are 'float64', 'double', 'float32', 'single'."
                )
            dtype = aliases[key]
        if dtype not in {torch.float32, torch.float64}:
            raise ValueError(
                f"Invalid dtype {dtype}. torch.float32 or torch.float64 are allowed."
            )
        self.buffer_dt = dtype
        self.device = torch.device(device)

    @torch.no_grad()
    def get_statistics(self, model: FrankenPotential) -> Tuple[Statistics, dict]:
        """Compute statistics on the training dataset with the provided model

        Args:
            model (FrankenPotential): Franken model from which the attached GNN
                is used to compute the features on atomic configurations.

        Returns:
            A tuple containing an object of type :class:`franken.rf.scaler.Statistics` containing
            the dataset statistics, and a dictionary containing the GNN-backbone hyperparameters
            used when computing dataset features.
        """
        # Recompute only when the cached statistics were produced with a
        # different GNN configuration.
        if self.statistics_ is None or not are_dicts_equal(
            self.statistics_[1], asdict_with_classvar(model.gnn_config)
        ):
            stat = compute_dataset_statistics(
                dataset=self.train_dataloader.dataset,  # type: ignore
                gnn=model.gnn,
                device=self.device,
            )
            stat_dict = asdict_with_classvar(model.gnn_config)
            self.statistics_ = (stat, stat_dict)

        return self.statistics_

    @abc.abstractmethod
    def fit(
        self,
        model: FrankenPotential,
        solver_params: dict,
    ) -> tuple[LogCollection, torch.Tensor]:
        """Fit a given franken model on the training set.

        Args:
            model (FrankenPotential): The model which defines GNN and random features.
            solver_params (dict): Parameters for the solver which actually
                performs the fit.

        Returns:
            tuple[LogCollection, torch.Tensor]:
                - Logs which contain all parameters related to the fitting, as well as timings.
                - Weights which were learned during the fit.
        """
        pass

    @abc.abstractmethod
    def evaluate(
        self,
        model: FrankenPotential,
        dataloader: torch.utils.data.DataLoader,
        log_collection: LogCollection,
        all_weights: torch.Tensor,
        metrics: list[str],
    ) -> LogCollection:
        """Evaluate a fitted model by computing metrics on a validation dataset.

        Args:
            model: The model which defines GNN and random features.
            dataloader (torch.utils.data.DataLoader): Evaluation will run the model
                on each configuration in the dataloader, computing averaged metrics.
            log_collection: Log object as output by the :meth:`fit`
                method. Metric values will be added to the logs and the same object will
                be returned by this method.
            all_weights (torch.Tensor): The weights as output by the :meth:`fit` method.
            metrics (list[str]): List of metrics which should be computed.

        Returns:
            logs (LogCollection): Logs which contain all parameters related
                to the fitting, as well as timings and metrics.
        """
        pass

    def serialize_logs(
        self,
        model: FrankenPotential,
        log_collection: LogCollection,
        all_weights: torch.Tensor,
        best_model_split: DataSplit = DataSplit.TRAIN,
    ):
        """Write logs to disk, optionally checkpoint all models, and track the best one."""
        assert self.log_dir is not None, "Log directory is not set"
        # All log entries must refer to the same model checkpoint.
        model_hash_set = set(log.checkpoint_hash for log in log_collection)
        assert len(model_hash_set) == 1
        model_hash = model_hash_set.pop()
        log_collection.save_json(self.log_dir / "log.json")

        # Save the model checkpoint
        if self.save_every_model:
            ckpt_dir = self.log_dir / "checkpoints"
            ckpt_dir.mkdir(parents=True, exist_ok=True)
            model_save_path = ckpt_dir / f"{model_hash}.pt"
            model.save(model_save_path, multi_weights=all_weights)
            logger.debug(
                f"Saved multiple models (hash={model_hash}) " f"to {model_save_path}"
            )
        # Log the best model
        self.serialize_best_model(model, all_weights, split=best_model_split)

    def serialize_best_model(
        self,
        model: FrankenPotential,
        all_weights: torch.Tensor,
        split: DataSplit = DataSplit.TRAIN,
    ) -> None:
        """Persist `best.json` and `best_ckpt.pt` if a new best model emerged."""
        assert self.log_dir is not None, "Log directory is not set"
        log_collection = LogCollection.from_json(self.log_dir / "log.json")
        best_model = log_collection.get_best_model(split=split)

        best_model_file = self.log_dir / "best.json"
        # Skip re-serialization when the stored best model is unchanged.
        should_save = True
        if best_model_file.exists():
            with open(best_model_file, "r") as f:
                current_best = LogEntry.from_dict(json.load(f))
            if best_model == current_best:
                should_save = False

        if should_save:
            logger.debug(f"Identified new best model: {best_model}")
            with open(best_model_file, "w") as f:
                json.dump(best_model.to_dict(), f, indent=4, cls=dtypeJSONEncoder)
            weights = all_weights[best_model.checkpoint_rf_weight_id]
            model.rf.weights = weights.reshape_as(model.rf.weights)
            model.save(self.log_dir / "best_ckpt.pt")
            logger.debug(
                f"Saved best model (within-experiment ID={best_model.checkpoint_rf_weight_id}) "
                f"to {self.log_dir / 'best_ckpt.pt'}"
            )
183 |
--------------------------------------------------------------------------------
/franken/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSML-IIT-UCL/franken/497d6ad4ee63b46d2528bcc38ab9b7581c4c6255/franken/utils/__init__.py
--------------------------------------------------------------------------------
/franken/utils/distributed.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List, Tuple, Union
3 | import warnings
4 | from socket import gethostname
5 |
6 | import torch
7 | import torch.distributed
8 |
9 | from . import hostlist
10 |
11 |
def slurm_to_env():
    """Translate SLURM job variables into the env vars torch.distributed expects."""
    # The first node of the allocation acts as the rendezvous master.
    first_node = hostlist.expand_hostlist(os.environ["SLURM_JOB_NODELIST"])[0]
    os.environ["MASTER_ADDR"] = first_node
    os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "33633")
    if "SLURM_NTASKS" in os.environ:
        n_tasks = int(os.environ["SLURM_NTASKS"])
    else:
        # Fall back to tasks-per-node * nodes when SLURM_NTASKS is absent.
        n_tasks = int(os.environ["SLURM_NTASKS_PER_NODE"]) * int(
            os.environ["SLURM_NNODES"]
        )
    os.environ["WORLD_SIZE"] = str(n_tasks)
    os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"]
    os.environ["RANK"] = os.environ["SLURM_PROCID"]
25 |
26 |
def is_torchrun():
    """True when launched via torchrun, which exports RANK and WORLD_SIZE."""
    env = os.environ
    return "RANK" in env and "WORLD_SIZE" in env
30 |
31 |
def is_slurm():
    """True when running as a SLURM task (SLURM_PROCID is exported)."""
    return os.environ.get("SLURM_PROCID") is not None
34 |
35 |
def init(distributed: bool) -> int:
    """Initialize the (optionally distributed) runtime and return the global rank.

    When `distributed` is True, populates the rendezvous env vars (torchrun
    already exports them; under SLURM they are derived via slurm_to_env),
    initializes the NCCL process group when WORLD_SIZE > 1, and pins this
    process to its local-rank GPU. Returns 0 for non-distributed runs.
    """
    if distributed:
        if not torch.cuda.is_available():
            raise RuntimeError("Distributed training is only supported on CUDA")
        # torchrun already sets RANK/WORLD_SIZE/LOCAL_RANK; SLURM vars need
        # translating. Otherwise warn and fall through (single-process).
        if is_torchrun():
            pass
        elif is_slurm():
            slurm_to_env()
        else:
            warnings.warn(
                "Cannot initialize distributed training. "
                "Neither torchrun nor SLURM environment variable were found."
            )
        world_size = int(os.environ.get("WORLD_SIZE", 1))
        if world_size > 1:
            print(
                f"Distributed initialization at rank {os.environ['RANK']} of {world_size} "
                f"(rank {os.environ['LOCAL_RANK']} on {gethostname()} with "
                f"{torch.cuda.device_count()} GPUs allocated)."
            )
            # Bind the process group directly to this rank's GPU.
            torch.distributed.init_process_group(
                backend="nccl",
                device_id=torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}"),
            )

        device = f"cuda:{get_local_rank()}"
        torch.cuda.set_device(device)
    return get_rank()
64 |
65 |
def get_local_rank() -> int:
    """Node-local rank (from LOCAL_RANK); 0 outside distributed runs."""
    if not torch.distributed.is_initialized():
        return 0
    return int(os.environ["LOCAL_RANK"])
70 |
71 |
def get_rank() -> int:
    """Global rank of this process; 0 outside distributed runs."""
    if not torch.distributed.is_initialized():
        return 0
    return torch.distributed.get_rank()
76 |
77 |
def barrier() -> None:
    """Synchronize all ranks; a no-op when distributed is not initialized."""
    if not torch.distributed.is_initialized():
        return
    torch.distributed.barrier()
81 |
82 |
def get_world_size() -> int:
    """Number of participating ranks; 1 outside distributed runs."""
    if not torch.distributed.is_initialized():
        return 1
    return torch.distributed.get_world_size()
87 |
88 |
def all_reduce(tensor: torch.Tensor, op) -> None:
    """In-place all-reduce of `tensor` with `op`; a no-op when not distributed."""
    if not torch.distributed.is_initialized():
        return None
    torch.distributed.all_reduce(tensor, op)
    return None
93 |
94 |
def all_sum(tensor: torch.Tensor) -> None:
    """Sum `tensor` in place across all ranks; a no-op when not distributed."""
    if not torch.distributed.is_initialized():
        return None
    torch.distributed.all_reduce(tensor, torch.distributed.ReduceOp.SUM)
    return None
99 |
100 |
def broadcast_obj(obj, src=0):
    """Broadcast a picklable object from rank `src`; identity when not distributed."""
    if not torch.distributed.is_initialized():
        return obj
    payload = [obj]
    torch.distributed.broadcast_object_list(payload, src=src)
    return payload[0]
107 |
108 |
def all_gather_into_tensor(
    out_size: Union[Tuple, torch.Size], in_tensor: torch.Tensor
) -> torch.Tensor:
    """Gather `in_tensor` from every rank into a tensor of shape `out_size`.

    Outside distributed runs the input tensor is returned unchanged.
    """
    if not torch.distributed.is_initialized():
        return in_tensor
    gathered = torch.zeros(out_size, dtype=in_tensor.dtype, device=in_tensor.device)
    torch.distributed.all_gather_into_tensor(gathered, in_tensor)
    return gathered
119 |
120 |
def all_gather(tensor: torch.Tensor) -> List[torch.Tensor]:
    """Gather tensors of possibly different shapes from every rank.

    Each rank first broadcasts its local shape so that correctly sized
    receive buffers can be allocated. Returns ``[tensor]`` when not running
    distributed.
    """
    if not torch.distributed.is_initialized():
        return [tensor]
    world = get_world_size()
    rank = get_rank()
    # Exchange per-rank shapes: rank r fills its own slot, others receive it.
    shapes = [tensor.shape if r == rank else None for r in range(world)]
    for r in range(world):
        shapes[r] = broadcast_obj(shapes[r], src=r)
    buffers = []
    for r in range(world):
        if r == rank:
            buffers.append(tensor)
        else:
            buffers.append(
                torch.empty(shapes[r], device=tensor.device, dtype=tensor.dtype)  # type: ignore
            )
    torch.distributed.all_gather(buffers, tensor)
    return buffers
139 |
140 |
def all_gather_object(obj) -> List:
    """Gather a picklable object from every rank into a list.

    Returns ``[obj]`` when not running distributed.
    """
    if not torch.distributed.is_initialized():
        return [obj]
    gathered: List = [None] * get_world_size()
    torch.distributed.all_gather_object(gathered, obj)
    return gathered
148 |
149 |
def print0(*args, **kwargs):
    """``print`` that is silent on every rank except global rank 0."""
    if get_rank() != 0:
        return
    print(*args, **kwargs)
153 |
--------------------------------------------------------------------------------
/franken/utils/file_utils.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | from pathlib import Path
3 | import requests
4 | from tqdm.auto import tqdm
5 |
6 |
def compute_md5(file_path):
    """Return the hex MD5 digest of the file at ``file_path``.

    The file is read in 8 KiB chunks so arbitrarily large files can be
    hashed without loading them into memory.
    """
    digest = hashlib.md5()
    with open(file_path, "rb") as fh:
        while chunk := fh.read(8192):
            digest.update(chunk)
    return digest.hexdigest()
15 |
16 |
17 | def validate_file(
18 | filename: Path, expected_size: int | None = None, expected_md5: str | None = None
19 | ) -> bool:
20 | if expected_md5 is not None:
21 | actual_md5 = compute_md5(filename)
22 | if actual_md5 != expected_md5:
23 | return False
24 | if expected_size is not None:
25 | actual_size = filename.stat().st_size
26 | if expected_size != actual_size:
27 | return False
28 | return True
29 |
30 |
def download_file(
    url: str,
    filename: Path,
    expected_size: int | None = None,
    expected_md5: str | None = None,
    desc: str | None = None,
):
    """Download ``url`` to ``filename`` with a progress bar, near-atomically.

    The payload is streamed into a ``.temp`` sibling and only moved into
    place once the optional size/MD5 checks pass, so a partial download
    never shadows a valid file. If ``filename`` already exists and matches
    the expected size/MD5, the download is skipped.

    Args:
        url: Source URL (fetched with a streaming GET).
        filename: Destination path.
        expected_size: Expected file size in bytes, checked when given.
        expected_md5: Expected hex MD5 digest, checked when given.
        desc: Progress-bar label; defaults to ``str(filename)``.

    Returns:
        ``filename`` on success.

    Raises:
        IOError: When the downloaded data fails the size or MD5 check.
        requests.HTTPError: When the server responds with an error status.
    """
    if (expected_md5 is not None or expected_size is not None) and filename.is_file():
        # Check that the file is correct to avoid re-download
        if validate_file(filename, expected_size, expected_md5):
            return filename

    temp_path = filename.with_suffix(".temp")
    data_size = 0
    data_md5 = hashlib.md5()
    # Use the response as a context manager so the HTTP connection is
    # released deterministically (the original leaked it).
    with requests.get(url, stream=True) as response:
        response.raise_for_status()
        total_size = int(response.headers.get("content-length", 0))
        block_size = 8192
        with (
            open(temp_path, "wb") as file,
            tqdm(
                desc=desc or str(filename),
                total=total_size,
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
            ) as bar,
        ):
            for chunk in response.iter_content(chunk_size=block_size):
                file.write(chunk)
                data_md5.update(chunk)
                data_size += len(chunk)
                bar.update(len(chunk))

    # Validate before moving into place; remove the bad payload on failure
    # so a corrupt .temp file is not left behind.
    if expected_size is not None and expected_size != data_size:
        temp_path.unlink(missing_ok=True)
        raise IOError("Incorrect file size", filename)
    if expected_md5 is not None and data_md5.hexdigest() != expected_md5:
        temp_path.unlink(missing_ok=True)
        raise IOError("Incorrect file MD5", filename)

    temp_path.replace(filename)
    return filename
72 |
--------------------------------------------------------------------------------
/franken/utils/jac.py:
--------------------------------------------------------------------------------
1 | from functools import wraps
2 | from typing import Any, Callable, Optional, Sequence, Tuple, Union
3 |
4 | import torch
5 | from torch._functorch.eager_transforms import (
6 | _construct_standard_basis_for,
7 | _jvp_with_argnums,
8 | _slice_argnums,
9 | error_if_complex,
10 | safe_unflatten,
11 | )
12 | from torch._functorch.utils import argnums_t
13 | from torch.func import vmap
14 | from torch.utils._pytree import tree_flatten, tree_unflatten
15 |
16 | from franken.utils.misc import garbage_collection_cuda, is_cuda_out_of_memory
17 |
18 |
def jacfwd(
    # drop-in replacement of torch.func.jacfwd accepting the chunk_size argument (as with jacrev)
    func: Callable,
    argnums: argnums_t = 0,
    has_aux: bool = False,
    *,
    randomness: str = "error",
    chunk_size: Optional[int] = None,
):
    """Forward-mode Jacobian of ``func`` with respect to ``argnums``.

    Mirrors :func:`torch.func.jacfwd` but additionally forwards
    ``chunk_size`` to the inner :func:`torch.func.vmap`, which bounds peak
    memory by pushing the standard-basis JVPs through in chunks (the same
    knob that upstream ``jacrev`` already exposes).

    NOTE(review): this relies on private ``torch._functorch`` helpers, so
    it is tied to the pinned torch version.
    """

    def wrapper_fn(*args):
        error_if_complex("jacfwd", args, is_input=True)
        # Select the primals we differentiate with respect to.
        primals = args if argnums is None else _slice_argnums(args, argnums)
        flat_primals, primals_spec = tree_flatten(primals)
        flat_primals_numels = tuple(p.numel() for p in flat_primals)
        # Forward mode needs one JVP per input element: build the standard basis.
        flat_basis = _construct_standard_basis_for(flat_primals, flat_primals_numels)
        basis = tree_unflatten(flat_basis, primals_spec)

        def push_jvp(basis):
            # Push one basis vector through func as a JVP.
            output = _jvp_with_argnums(
                func, args, basis, argnums=argnums, has_aux=has_aux
            )
            # output[0] is the output of `func(*args)`
            error_if_complex("jacfwd", output[0], is_input=False)
            if has_aux:
                _, jvp_out, aux = output
                return jvp_out, aux
            _, jvp_out = output
            return jvp_out

        # vmap over the basis vectors; chunk_size limits how many JVPs run at once.
        results = vmap(push_jvp, randomness=randomness, chunk_size=chunk_size)(basis)
        if has_aux:
            results, aux = results
            # aux is in the standard basis format, e.g. NxN matrix
            # We need to fetch the first element as original `func` output
            flat_aux, aux_spec = tree_flatten(aux)
            flat_aux = [value[0] for value in flat_aux]
            aux = tree_unflatten(flat_aux, aux_spec)

        jac_outs, spec = tree_flatten(results)
        # Most probably below output check can never raise an error
        # as jvp should test the output before
        # assert_non_empty_output(jac_outs, 'jacfwd(f, ...)(*args)')

        # Move the stacked basis dimension last, split it back per-primal,
        # and restore each primal's original shape.
        jac_outs_ins = tuple(
            tuple(
                safe_unflatten(jac_out_in, -1, primal.shape)
                for primal, jac_out_in in zip(
                    flat_primals,
                    jac_out.movedim(0, -1).split(flat_primals_numels, dim=-1),
                )
            )
            for jac_out in jac_outs
        )
        jac_outs_ins = tuple(
            tree_unflatten(jac_ins, primals_spec) for jac_ins in jac_outs_ins
        )

        # Single integer argnums: unwrap the one-element per-primal tuple.
        if isinstance(argnums, int):
            jac_outs_ins = tuple(jac_ins[0] for jac_ins in jac_outs_ins)
        if has_aux:
            return tree_unflatten(jac_outs_ins, spec), aux
        return tree_unflatten(jac_outs_ins, spec)

    # Dynamo does not support HOP composition if their inner function is
    # annotated with @functools.wraps(...). We circumvent this issue by applying
    # wraps only if we're not tracing with dynamo.
    if not torch._dynamo.is_compiling():
        wrapper_fn = wraps(func)(wrapper_fn)

    return wrapper_fn
89 |
90 |
def tune_jacfwd_chunksize(
    test_sample: Sequence[Union[torch.Tensor, Any]],
    mode: str = "power",
    init_val: int = 32,
    max_trials: int = 25,
    **jac_kwargs,
):
    """Find the largest :func:`jacfwd` ``chunk_size`` that fits in GPU memory.

    Starting from ``init_val``, the chunk size is doubled until a CUDA OOM
    occurs (``mode="power"``); on OOM it is halved once and returned.

    Args:
        test_sample: Arguments passed to the differentiated function when probing.
        mode: Search strategy; only ``"power"`` is supported.
        init_val: Initial chunk size to try.
        max_trials: Maximum number of doubling steps.
        **jac_kwargs: Forwarded to :func:`jacfwd`. Any caller-provided
            ``chunk_size`` is discarded, since this function tunes it.

    Returns:
        The tuned chunk size.

    Raises:
        ValueError: If ``mode`` is not ``"power"``.
    """
    # We want to tune this and set it ourselves.
    jac_kwargs.pop("chunk_size", None)

    # Start from init_val, clipped to the maximum feasible chunk size.
    new_size, _ = _adjust_batch_size(
        test_sample, init_val, value=init_val, **jac_kwargs
    )
    if mode == "power":
        new_size = _run_power_scaling(new_size, max_trials, test_sample, **jac_kwargs)
    else:
        # Fixed: the old message referenced a non-existent `scale_batch_size`.
        raise ValueError("mode in `tune_jacfwd_chunksize` can only be `power`")

    garbage_collection_cuda()
    return new_size
115 |
116 |
def _run_power_scaling(new_size, max_trials, test_sample, **jac_kwargs) -> int:
    """Batch scaling mode where the size is doubled at each iteration until an
    OOM error is encountered.

    Each trial evaluates the Jacobian once at the candidate chunk size. On a
    CUDA OOM the size is halved once and the search stops; any other error
    is re-raised. Stops early when the size is already clipped at its
    feasible maximum.
    """
    for _ in range(max_trials):
        garbage_collection_cuda()
        try:
            # Probe jacfwd once at the candidate chunk size.
            # (The original wrapped this call in a redundant `for _ in range(1)` loop.)
            jacfwd(**jac_kwargs, chunk_size=new_size)(*test_sample)
            # Success: double the size (clipped to the feasible maximum).
            new_size, changed = _adjust_batch_size(
                test_sample, new_size, factor=2.0, **jac_kwargs
            )
        except RuntimeError as exception:
            # Only these errors should trigger an adjustment
            if is_cuda_out_of_memory(exception):
                # If we fail in power mode, half the size and return
                garbage_collection_cuda()
                new_size, _ = _adjust_batch_size(
                    test_sample, new_size, factor=0.5, **jac_kwargs
                )
                break
            else:
                raise  # some other error not memory related
        if not changed:
            # No change in batch size, so we can exit.
            break
    return new_size
145 |
146 |
def _adjust_batch_size(
    test_sample: Sequence[Union[torch.Tensor, Any]],
    batch_size: int,
    factor: float = 1.0,
    value: Optional[int] = None,
    **jac_kwargs,
) -> Tuple[int, bool]:
    """Scale ``batch_size`` by ``factor`` (or set it to ``value``), clipped
    to the maximum feasible size for ``test_sample``.

    Returns the new size and whether it differs from ``batch_size``.
    """
    upper_bound = _get_max_batch_size(test_sample, **jac_kwargs)
    if value is not None:
        candidate = value
    else:
        candidate = int(batch_size * factor)
    candidate = min(candidate, upper_bound)
    return candidate, candidate != batch_size
159 |
160 |
161 | def _get_max_batch_size(
162 | test_sample: Sequence[Union[torch.Tensor, Any]], **jac_kwargs
163 | ) -> int:
164 | argnums = jac_kwargs.get("argnums", 0)
165 | if isinstance(argnums, int):
166 | argnums = [argnums]
167 | batch_size = 0
168 | for argnum in argnums:
169 | arg = test_sample[argnum]
170 | assert isinstance(arg, torch.Tensor)
171 | batch_size += arg.numel()
172 | return batch_size
173 |
--------------------------------------------------------------------------------
/franken/utils/linalg/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSML-IIT-UCL/franken/497d6ad4ee63b46d2528bcc38ab9b7581c4c6255/franken/utils/linalg/__init__.py
--------------------------------------------------------------------------------
/franken/utils/linalg/psdsolve.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import torch
4 |
5 |
6 | try:
7 | import cupy.cuda
8 | from cupy_backends.cuda.libs import cublas, cusolver
9 | except ImportError:
10 | cupy = None
11 | cusolver = None
12 | cublas = None
13 |
14 |
def psd_ridge(cov: torch.Tensor, rhs: torch.Tensor, penalty: float) -> torch.Tensor:
    # Raw docstring: the original non-raw string turned `\t` in `\text{...}`
    # into a literal TAB character, corrupting the rendered math.
    r"""Solve ridge regression via Cholesky factorization, overwriting :attr:`cov` and :attr:`rhs`.

    Multiple right-hand sides are supported. Instead of providing the data
    matrix (commonly :math:`X` in ridge-regression notation), and labels (commonly :math:`y`),
    we are given directly :math:`\text{cov} = X^T X` and :math:`\text{rhs} = X^T y`.
    Since :attr:`cov` is symmetric only its **upper triangle** will be accessed.

    To limit memory usage, the :attr:`cov` matrix **may be overwritten**, and :math:`rhs`
    may also be overwritten (depending on its memory layout).

    Args:
        cov (Tensor): covariance of the linear system
        rhs (Tensor): right hand side (one or more) of the linear system
        penalty (float): Tikhonov l2 penalty

    Returns:
        solution (Tensor): the ridge regression coefficients
    """
    if cupy is not None and cov.device.type == "cuda":
        return _lowmem_psd_ridge(cov, rhs, penalty)
    else:
        # NOTE: this should be a warnings.warn NOT logger.warning - otherwise
        # it gets printed a lot of times and is just annoying. We could add
        # https://docs.python.org/library/logging.html#logging.captureWarnings
        # to the logger to capture warnings automatically.
        if cov.device.type == "cuda":
            warnings.warn(
                "low-memory solver cannot be used because `cupy` is not available. "
                "Install `cupy` if you encounter memory problems."
            )
        return _naive_psd_ridge(cov, rhs, penalty)
47 |
48 |
49 | def _naive_psd_ridge(
50 | cov: torch.Tensor, rhs: torch.Tensor, penalty: float
51 | ) -> torch.Tensor:
52 | # Add diagonal without copies
53 | cov.diagonal().add_(penalty)
54 | # Solve with cholesky on GPU
55 | L = torch.linalg.cholesky(cov, upper=True)
56 | rhs_shape = rhs.shape
57 | return torch.cholesky_solve(rhs.view(cov.shape[0], -1), L, upper=True).view(
58 | rhs_shape
59 | )
60 |
61 |
def _lowmem_psd_ridge(
    cov: torch.Tensor, rhs: torch.Tensor, penalty: float
) -> torch.Tensor:
    """Cholesky ridge solve on GPU through cuSOLVER's potrf/potrs.

    Operates on the tensors' own memory via zero-copy ``cupy.asarray`` views,
    so no extra n-by-n buffer is allocated. Overwrites ``cov`` (with its
    Cholesky factor) and ``rhs`` (with the solution, which is also returned).
    """
    assert cusolver is not None and cublas is not None and cupy is not None
    assert cov.device.type == "cuda"
    dtype = cov.dtype
    n = cov.shape[0]

    # Add diagonal without copies
    cov.diagonal().add_(penalty)

    # Select single/double precision cuSOLVER entry points.
    if dtype == torch.float32:
        potrf = cusolver.spotrf
        potrf_bufferSize = cusolver.spotrf_bufferSize
        potrs = cusolver.spotrs
    elif dtype == torch.float64:
        potrf = cusolver.dpotrf
        potrf_bufferSize = cusolver.dpotrf_bufferSize
        potrs = cusolver.dpotrs
    else:
        raise ValueError(dtype)

    # cov must be f-contiguous (column-contiguous, stride is (1, n));
    # transposing flips which triangle cuSOLVER reads (see uplo below).
    assert cov.dim() == 2
    assert cov.shape[0] == cov.shape[1]
    transpose = False
    if n != 1:
        if cov.stride(0) != 1:
            cov = cov.T
            transpose = True
        assert cov.stride(0) == 1
    cov_cp = cupy.asarray(cov)

    # save rhs shape to restore it later on.
    rhs_shape = rhs.shape
    rhs = rhs.reshape(n, -1)
    n_rhs = rhs.shape[1]
    if rhs.stride(0) != 1:  # force rhs to be f-contiguous
        # `contiguous` causes a copy
        rhs = rhs.T.contiguous().T
    assert rhs.stride(0) == 1
    rhs_cp = cupy.asarray(rhs)

    handle = cupy.cuda.device.get_cusolver_handle()
    uplo = cublas.CUBLAS_FILL_MODE_LOWER if transpose else cublas.CUBLAS_FILL_MODE_UPPER
    # NOTE(review): torch.empty with no device argument allocates on torch's
    # current default device -- presumably set to the CUDA device by the
    # caller; verify, since cupy.asarray below needs GPU memory.
    dev_info = torch.empty(
        1, dtype=torch.int32
    )  # don't allocate with cupy as it uses a separate mem pool
    dev_info_cp = cupy.asarray(dev_info)

    worksize = potrf_bufferSize(handle, uplo, n, cov_cp.data.ptr, n)
    workspace = torch.empty(worksize, dtype=dtype)
    workspace_cp = cupy.asarray(workspace)

    # Cholesky factorization
    potrf(
        handle,
        uplo,
        n,
        cov_cp.data.ptr,
        n,
        workspace_cp.data.ptr,
        worksize,
        dev_info_cp.data.ptr,
    )
    if (dev_info_cp != 0).any():
        raise torch.linalg.LinAlgError(
            f"Error reported by {potrf.__name__} in cuSOLVER. devInfo = {dev_info_cp}."
        )

    # Solve: A * X = B
    potrs(
        handle,
        uplo,
        n,
        n_rhs,
        cov_cp.data.ptr,
        n,
        rhs_cp.data.ptr,
        n,
        dev_info_cp.data.ptr,
    )
    if (dev_info_cp != 0).any():
        # Fixed: this error previously blamed potrf instead of potrs.
        raise torch.linalg.LinAlgError(
            f"Error reported by {potrs.__name__} in cuSOLVER. devInfo = {dev_info_cp}."
        )

    return torch.as_tensor(rhs).reshape(rhs_shape)
150 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "franken"
3 | dynamic = ["version"]
4 | authors = [
5 | { name="Pietro Novelli", email="pietronvll@gmail.com" },
6 | { name="Giacomo Meanti" },
7 | { name="Luigi Bonati" },
8 | { name="Pedro Juan Buigues Jorro" }
9 | ]
10 | description = "Franken fine-tuning scheme for ML potentials"
11 | readme = "README.md"
12 | license = "MIT"
13 | license-files = ["LICENSE.*"]
14 | requires-python = ">=3.10"
15 | dependencies = [
16 | "torch >= 2.4.0",
17 | "ase",
18 | "numpy",
19 | "tqdm",
20 | "psutil",
21 | "scipy",
22 | "e3nn",
23 | "omegaconf",
24 | "requests",
25 | "docstring_parser",
26 | ]
27 | classifiers = [
28 | "Development Status :: 4 - Beta",
29 | "Intended Audience :: Science/Research",
30 | "License :: OSI Approved :: MIT License",
31 | "Programming Language :: Python :: 3.10",
32 | "Programming Language :: Python :: 3.11",
33 | "Programming Language :: Python :: 3.12",
34 | ]
35 | keywords = [
36 | "franken", "potentials", "molecular dynamics",
37 | ]
38 |
39 | [project.urls]
40 | Homepage = "https://franken.readthedocs.io/"
41 | Documentation = "https://franken.readthedocs.io/"
42 | Repository = "https://github.com/CSML-IIT-UCL/franken"
43 |
44 | [project.scripts]
45 | "franken.backbones" = "franken.backbones.cli:main"
46 | "franken.autotune" = "franken.autotune.script:cli_entry_point"
47 | "franken.create_lammps_model" = "franken.calculators.lammps_calc:create_lammps_model_cli"
48 |
49 | [project.optional-dependencies]
50 | develop = [
51 | "black ~= 24.0",
52 | "ruff",
53 | "pytest",
54 | "pre-commit",
55 | "pytest",
56 | "packaging",
57 | ]
58 | mace = ["mace-torch >= 0.3.10"]
59 | fairchem = ["fairchem-core == 1.10"]
60 | sevenn = ["sevenn ~= 0.11"]
61 | cuda = ["cupy"]
62 | docs = [
63 | "Sphinx",
64 | "sphinxawesome-theme",
65 | "sphinxcontrib-applehelp",
66 | "sphinxcontrib-devhelp",
67 | "sphinxcontrib-htmlhelp",
68 | "sphinxcontrib-jsmath",
69 | "sphinxcontrib-qthelp",
70 | "sphinxcontrib-serializinghtml",
71 | "sphinx-argparse",
72 | "myst-parser",
73 | "nbsphinx",
74 | ]
75 |
76 | [build-system]
77 | requires = ["hatchling"]
78 | build-backend = "hatchling.build"
79 |
80 | [tool.hatch.version]
81 | path = "franken/__init__.py"
82 |
83 | [tool.hatch.build.targets.sdist]
84 | only-include = ["franken", "tests"]
85 |
86 | [tool.hatch.build.targets.wheel]
87 | include = [
88 | "franken/**/*.py",
89 | "franken/autotune/configs/**/*.yaml",
90 | "franken/mdgen/configs/**/*.yaml",
91 | "franken/backbones/registry.json",
92 | "franken/datasets/water/*.csv",
93 | "franken/datasets/test/*",
94 | ]
95 | exclude = [
96 | "franken/datasets/ala3",
97 | "franken/datasets/chignolin",
98 | "franken/datasets/Cu-EMT",
99 | "franken/datasets/CuFormate",
100 | "franken/datasets/Fe_N2",
101 | "franken/datasets/Fe4N",
102 | "franken/datasets/FeBulk",
103 | "franken/datasets/LiPS",
104 | "franken/datasets/MD22",
105 | "franken/datasets/split_data.py",
106 | "franken/datasets/download_and_process_all.sh",
107 | "franken/datasets/readme",
108 | ]
109 |
110 | [tool.black]
111 | line-length = 88
112 | target-version = ['py310', 'py312']
113 | force-exclude = '^/((?!franken/))'
114 |
115 | [tool.ruff]
116 | target-version = "py310"
117 | include = [
118 | "pyproject.toml",
119 | "franken/**/*.py",
120 | ]
121 | extend-exclude = [
122 | "franken/utils/hostlist.py",
123 | ]
124 | force-exclude = true
125 |
126 | [tool.ruff.lint]
127 | select = ["E4", "E7", "E9", "F", "W"]
128 | ignore = [
129 | "E501", # Avoid enforcing line-length violations (`E501`)
130 | ]
131 |
132 | [tool.pytest.ini_options]
133 | testpaths = ["tests"]
134 | markers = [
135 | "slow: marks tests as slow (deselect with '-m \"not slow\"')",
136 | ]
137 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSML-IIT-UCL/franken/497d6ad4ee63b46d2528bcc38ab9b7581c4c6255/tests/__init__.py
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | import random
4 |
5 | import numpy
6 | import pytest
7 | import torch
8 |
9 | from franken import FRANKEN_DIR
10 | from franken.backbones.utils import CacheDir, download_checkpoint
11 | from franken.config import MaceBackboneConfig
12 |
13 |
# Shared constants/markers re-exported to the rest of the test suite.
__all__ = [
    "ROOT_PATH",
    "DEFAULT_GNN_CONFIGS",
    "SKIP_NO_CUDA",
    "DEVICES",
    "DEV_CPU_FAIL"
]

# Repository root, taken from the installed franken package.
ROOT_PATH = FRANKEN_DIR

# Backbones downloaded once per session in pytest_sessionstart.
DEFAULT_GNN_CONFIGS = [
    MaceBackboneConfig("MACE-L0")
]  # , "SchNet-S2EF-OC20-All"] # List of gnn_ids to download

# Marker that skips CUDA-only tests on machines without a GPU.
SKIP_NO_CUDA = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")

# Devices to parametrize tests over; the CUDA entry is skipped when unavailable.
DEVICES = [
    "cpu",
    pytest.param("cuda:0", marks=SKIP_NO_CUDA)  # type: ignore
]
# Same as DEVICES, but the CPU entry is an expected failure (never run).
DEV_CPU_FAIL = [
    pytest.param(dev, marks=pytest.mark.xfail(run=False, reason="Not implemented on CPU"))
    if dev == "cpu"
    else dev
    for dev in DEVICES
]
40 |
def prepare_gnn_checkpoints():
    """Download every default GNN backbone and return the checkpoint directory."""
    for cfg in DEFAULT_GNN_CONFIGS:
        download_checkpoint(cfg.path_or_id)
    return CacheDir.get() / "gnn_checkpoints"
46 |
def pytest_sessionstart(session):
    """
    Called after the Session object has been created and
    before performing collection and entering the run test loop.
    """
    # Cache dir: FRANKEN_CACHE_DIR env var if set, otherwise a folder
    # next to this file.
    cache_dir = os.environ.get("FRANKEN_CACHE_DIR", Path(__file__).parent / ".franken")
    CacheDir.initialize(cache_dir)
    prepare_gnn_checkpoints()
59 |
60 |
@pytest.fixture(autouse=True)
def random_seed():
    """This fixture is called before each test and sets random seeds"""
    for seed_fn in (random.seed, numpy.random.seed, torch.manual_seed):
        seed_fn(14)
67 |
--------------------------------------------------------------------------------
/tests/test_backbones.py:
--------------------------------------------------------------------------------
1 | import e3nn
2 | import pytest
3 | from packaging.version import Version
4 | import torch
5 |
6 | from franken.config import BackboneConfig, GaussianRFConfig
7 | from franken.data import BaseAtomsDataset
8 | from franken.datasets.registry import DATASET_REGISTRY
9 | from franken.backbones import REGISTRY
10 | from franken.backbones.utils import load_checkpoint
11 | from franken.rf.model import FrankenPotential
12 |
13 |
14 | models = [
15 | "Egret-1t",
16 | pytest.param("MACE-L1", marks=pytest.mark.xfail(Version(e3nn.__version__) >= Version("0.5.5"), reason="Known incompatibility", strict=True)),
17 | pytest.param("MACE-OFF-small", marks=pytest.mark.xfail(Version(e3nn.__version__) >= Version("0.5.5"), reason="Known incompatibility", strict=True)),
18 | pytest.param("SevenNet0", marks=pytest.mark.xfail(Version(e3nn.__version__) < Version("0.5.0"), reason="Known incompatibility", strict=True)),
19 | pytest.param("SchNet-S2EF-OC20-200k", marks=pytest.mark.xfail(reason="Fails in CI due to unknown reasons", strict=False))
20 | ]
21 |
22 |
23 | @pytest.mark.parametrize("model_name", models)
24 | def test_backbone_loading(model_name):
25 | registry_entry = REGISTRY[model_name]
26 | gnn_config = BackboneConfig.from_ckpt({
27 | "family": registry_entry["kind"],
28 | "path_or_id": model_name,
29 | "interaction_block": 2,
30 | })
31 | load_checkpoint(gnn_config)
32 |
33 |
34 | @pytest.mark.parametrize("model_name", models)
35 | def test_descriptors(model_name):
36 | registry_entry = REGISTRY[model_name]
37 | gnn_config = BackboneConfig.from_ckpt({
38 | "family": registry_entry["kind"],
39 | "path_or_id": model_name,
40 | "interaction_block": 2,
41 | })
42 | bbone = load_checkpoint(gnn_config)
43 | # Get a random data sample
44 | data_path = DATASET_REGISTRY.get_path("test", "train", None, False)
45 | dataset = BaseAtomsDataset.from_path(
46 | data_path=data_path,
47 | split="train",
48 | gnn_config=gnn_config,
49 | )
50 | data, _ = dataset[0] # type: ignore
51 | expected_fdim = bbone.feature_dim()
52 | features = bbone.descriptors(data)
53 | assert features.shape[1] == expected_fdim
54 |
55 |
56 | @pytest.mark.parametrize("model_name", models)
57 | def test_force_maps(model_name):
58 | from franken.backbones.wrappers.common_patches import patch_e3nn
59 | patch_e3nn()
60 | registry_entry = REGISTRY[model_name]
61 | gnn_config = BackboneConfig.from_ckpt({
62 | "family": registry_entry["kind"],
63 | "path_or_id": model_name,
64 | "interaction_block": 2,
65 | })
66 | # Get a random data sample
67 | data_path = DATASET_REGISTRY.get_path("test", "train", None, False)
68 | dataset = BaseAtomsDataset.from_path(
69 | data_path=data_path,
70 | split="train",
71 | gnn_config=gnn_config,
72 | )
73 | device="cuda:0" if torch.cuda.is_available() else "cpu"
74 | # initialize model
75 | model = FrankenPotential(
76 | gnn_config=gnn_config,
77 | rf_config=GaussianRFConfig(num_random_features=128, length_scale=1.0),
78 | )
79 | model = model.to(device)
80 | data, _ = dataset[0] # type: ignore
81 | data = data.to(device)
82 | emap, fmap = model.grad_feature_map(data)
83 |
84 |
--------------------------------------------------------------------------------
/tests/test_backbones_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | from unittest.mock import patch
4 |
5 | import pytest
6 |
7 | import franken.backbones.utils
8 |
9 |
10 | @pytest.fixture
11 | def mock_registry():
12 | return {
13 | "UNIMPLEMENTED_MODEL": {
14 | "kind": "mock",
15 | "implemented": False,
16 | "local": "unimplemented.ckpt",
17 | "remote": "https://example.com",
18 | }
19 | }
20 |
21 |
22 | @pytest.fixture
23 | def mock_cache_folder():
24 | return Path("/tmp/cache")
25 |
26 |
27 | def test_model_registry():
28 | registry = franken.backbones.utils.load_model_registry()
29 | for model in registry.values():
30 | for key in ["remote", "local", "kind", "implemented"]:
31 | assert key in model.keys()
32 |
33 |
34 | def test_cache_dir_default(mock_cache_folder):
35 | """Test that the function returns the default path when FRANKEN_CACHE_DIR is not set."""
36 | # Ensure no environment variable is set
37 | with patch.dict(os.environ, {}, clear=True):
38 | # Mock the home path and the Path.exists method
39 | with patch("pathlib.Path.home", return_value=mock_cache_folder):
40 | with patch("pathlib.Path.exists", return_value=True) as mock_exists:
41 | # Call the function
42 | franken.backbones.utils.CacheDir.initialize()
43 | result = franken.backbones.utils.CacheDir.get()
44 |
45 | # Check the default path is returned
46 | assert result == mock_cache_folder / ".franken"
47 | # Ensure that the path exists
48 | mock_exists.assert_called_once()
49 |
50 |
51 | def test_cache_dir_with_env_var(mock_cache_folder):
52 | """Test that the function returns the correct path when FRANKEN_CACHE_DIR is set."""
53 | # Mock the environment variable
54 | with patch.dict(os.environ, {"FRANKEN_CACHE_DIR": str(mock_cache_folder)}):
55 | # Mock the Path.exists method
56 | with patch("pathlib.Path.exists", return_value=True) as mock_exists:
57 | # Call the function
58 | franken.backbones.utils.CacheDir.initialize()
59 | result = franken.backbones.utils.CacheDir.get()
60 |
61 | # Check the environment variable path is returned
62 | assert str(result) == str(mock_cache_folder)
63 | # Ensure that the path exists
64 | mock_exists.assert_called_once()
65 |
66 |
67 | def test_download_checkpoint_name_error():
68 | """Test that a NameError is raised for unknown gnn_backbone_id."""
69 | # Mock the model registry to return an empty registry
70 | with patch("franken.backbones.utils.load_model_registry", return_value={}):
71 | # Expect a NameError when the gnn_backbone_id is not in the registry
72 | with pytest.raises(NameError) as exc_info:
73 | franken.backbones.utils.download_checkpoint("UNKNOWN_MODEL")
74 | assert "Unknown UNKNOWN_MODEL GNN backbone" in str(exc_info.value)
75 |
76 |
77 | def test_download_checkpoint_not_implemented(mock_registry):
78 | """Test that a NotImplementedError is raised when the model is not implemented."""
79 | # Mock the model registry to return a registry with a model that is not implemented
80 | with patch(
81 | "franken.backbones.utils.load_model_registry", return_value=mock_registry
82 | ):
83 | # Expect a NotImplementedError when the gnn_backbone_id is not implemented
84 | with pytest.raises(NotImplementedError) as exc_info:
85 | franken.backbones.utils.download_checkpoint("UNIMPLEMENTED_MODEL")
86 | assert "The model UNIMPLEMENTED_MODEL is not implemented" in str(exc_info.value)
87 |
88 |
89 | @pytest.mark.skip(reason="Actually downloads the model")
90 | def test_download_checkpoint_successful_download(tmp_path):
91 | gnn_id = "MACE-L0"
92 | """Test that the model is downloaded correctly when it is implemented."""
93 | registry = franken.backbones.utils.load_model_registry()
94 | with patch.dict(os.environ, {"FRANKEN_CACHE_DIR": str(tmp_path)}):
95 | franken.backbones.utils.download_checkpoint(gnn_id)
96 | ckpt = tmp_path / "gnn_checkpoints" / registry[gnn_id]["local"]
97 | assert ckpt.exists()
98 | assert ckpt.is_file()
99 |
100 |
101 | def test_get_checkpoint_path_valid_backbone(mock_registry, mock_cache_folder):
102 | with patch(
103 | "franken.backbones.utils.load_model_registry", return_value=mock_registry
104 | ), patch(
105 | "franken.backbones.utils.CacheDir.get", return_value=mock_cache_folder
106 | ), patch("pathlib.Path.exists", return_value=True):
107 | result = franken.backbones.utils.get_checkpoint_path("UNIMPLEMENTED_MODEL")
108 | expected_path = mock_cache_folder / "gnn_checkpoints" / "unimplemented.ckpt"
109 | assert result == expected_path
110 |
111 |
112 | def test_get_checkpoint_path_invalid_backbone(mock_registry):
113 | with patch(
114 | "franken.backbones.utils.load_model_registry", return_value=mock_registry
115 | ), patch(
116 | "franken.backbones.utils.make_summary", return_value="available backbones"
117 | ):
118 | with pytest.raises(FileNotFoundError) as exc_info:
119 | franken.backbones.utils.get_checkpoint_path("invalid_backbone")
120 |
121 | assert "GNN Backbone path 'invalid_backbone' does not exist." in str(exc_info.value)
122 | assert "available backbones" in str(exc_info.value)
123 |
124 |
125 | def test_get_checkpoint_path_download_required(mock_registry, mock_cache_folder):
126 | with patch(
127 | "franken.backbones.utils.load_model_registry", return_value=mock_registry
128 | ), patch(
129 | "franken.backbones.utils.CacheDir.get", return_value=mock_cache_folder
130 | ), patch("pathlib.Path.exists", return_value=False), patch(
131 | "franken.backbones.utils.download_checkpoint"
132 | ) as mock_download:
133 | result = franken.backbones.utils.get_checkpoint_path("UNIMPLEMENTED_MODEL")
134 | expected_path = mock_cache_folder / "gnn_checkpoints" / "unimplemented.ckpt"
135 | assert result == expected_path
136 | mock_download.assert_called_once_with("UNIMPLEMENTED_MODEL")
137 |
--------------------------------------------------------------------------------
/tests/test_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pytest
3 | import torch.distributed
4 | from torch.multiprocessing import Process, Pipe, SimpleQueue
5 |
6 | from franken.data.base import Configuration, SimpleAtomsDataset
7 | from franken.datasets.registry import DATASET_REGISTRY
8 |
class ThrowingProcess(Process):
    """A ``Process`` that forwards exceptions raised in the child to the parent.

    The child sends the exception (or ``None`` on success) through a pipe;
    the parent reads it lazily through the :attr:`exception` property.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._pconn, self._cconn = Pipe()
        self._exception = None

    def run(self):
        try:
            super().run()
            self._cconn.send(None)
        except Exception as exc:
            self._cconn.send(exc)
            raise exc

    @property
    def exception(self):
        # Poll so reading the property never blocks when nothing was sent.
        if self._pconn.poll():
            self._exception = self._pconn.recv()
        return self._exception
28 |
29 |
def init_processes(rank, size, fn, backend='gloo'):
    """ Initialize the distributed environment. """
    rendezvous_env = {
        'MASTER_ADDR': '127.0.0.64',
        'MASTER_PORT': '26512',
        'GLOO_SOCKET_IFNAME': "lo",
    }
    os.environ.update(rendezvous_env)
    torch.distributed.init_process_group(backend, rank=rank, world_size=size)
    fn()
37 |
38 |
def init_distributed_cpu(num_proc, run_fn):
    """Run ``run_fn`` on ``num_proc`` gloo-backed processes, re-raising any child error."""
    workers = [
        ThrowingProcess(target=init_processes, args=(rank, num_proc, run_fn))
        for rank in range(num_proc)
    ]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
        if worker.exception:
            raise worker.exception
51 |
52 |
def mocked_dataset(num_atoms, dtype, device, num_configs: int = 1):
    """Build ``num_configs`` random ``Configuration`` samples on ``device``."""
    return [
        Configuration(
            torch.randn(num_atoms, 3, dtype=dtype),
            torch.randint(1, 100, (num_atoms,)),
            torch.tensor(num_atoms),
        ).to(device)
        for _ in range(num_configs)
    ]
62 |
63 |
@pytest.mark.parametrize("num_samples", [1, 7, 19])
@pytest.mark.parametrize("num_procs", [1, 4])
def test_distributed_dataloader_length(num_samples, num_procs):
    # Under torch.distributed, each rank's dataloader should receive its fair
    # share of the dataset: floor(len/world_size), plus one extra sample on
    # ranks with rank < (len % world_size). No padding/duplication expected.
    def inner_fn():
        data_path = DATASET_REGISTRY.get_path("test", "long", None, False)
        dataset = SimpleAtomsDataset(
            data_path,
            split="train",
            num_random_subsamples=num_samples,
            subsample_rng=None,
        )
        assert len(dataset) == num_samples
        dataloader = dataset.get_dataloader(True)
        rank = torch.distributed.get_rank()
        ws = torch.distributed.get_world_size()
        # Remainder samples go to the lowest-numbered ranks.
        assert len(dataloader) == (len(dataset) // ws) + int(len(dataset) % ws > rank)

    init_distributed_cpu(num_procs, inner_fn)
82 |
83 |
def test_distributed_dataloader_order():
    # Verify the distributed dataloader assigns samples round-robin by rank
    # (rank r sees dataset indices r, r+world_size, r+2*world_size, ...) and
    # that, across all ranks combined, every sample is seen exactly once.
    num_samples = 7
    num_procs = 3
    ids_queue = SimpleQueue()
    def inner_fn():
        data_path = DATASET_REGISTRY.get_path("test", "long", None, False)
        dataset = SimpleAtomsDataset(
            data_path,
            split="train",
            num_random_subsamples=num_samples,
            subsample_rng=None,
        )
        assert len(dataset) == num_samples
        dataloader = dataset.get_dataloader(True)
        rank = torch.distributed.get_rank()
        dl_elements = [el for el in dataloader]
        dl_id = 0
        for i in range(rank, num_samples, num_procs):
            # The dl_id-th element yielded on this rank must be dataset[i].
            torch.testing.assert_close(
                dl_elements[dl_id][0].atom_pos, dataset[i][0].atom_pos
            )
            torch.testing.assert_close(
                dl_elements[dl_id][1].forces, dataset[i][1].forces
            )
            dl_id += 1
            ids_queue.put(i)  # record which global indices this rank consumed
        # The dataloader must not yield more elements than the rank's share.
        assert dl_id == len(dl_elements)
    init_distributed_cpu(num_procs, inner_fn)
    # Assert all IDs were processed - only once
    all_ids = []
    while not ids_queue.empty():
        all_ids.append(ids_queue.get())
    assert sorted(all_ids) == list(range(num_samples))
117 |
--------------------------------------------------------------------------------
/tests/test_lammps.py:
--------------------------------------------------------------------------------
1 | """
2 | Test the model conversion to LAMMPS (essentially testing torch-scriptability, not LAMMPS directly)
3 | """
4 |
5 | import os
6 | import pytest
7 | import torch
8 |
9 | from franken.backbones.wrappers.common_patches import unpatch_e3nn
10 | from franken.config import GaussianRFConfig, MaceBackboneConfig, MultiscaleGaussianRFConfig
11 | from franken.data import BaseAtomsDataset
12 | from franken.rf.model import FrankenPotential
13 | from franken.rf.scaler import Statistics
14 | from franken.utils.misc import garbage_collection_cuda
15 | from franken.datasets.registry import DATASET_REGISTRY
16 | from franken.calculators.lammps_calc import LammpsFrankenCalculator
17 |
18 | from .conftest import DEVICES
19 | from .utils import are_dicts_close, cleanup_dir, create_temp_dir
20 |
21 |
# Random-feature head configurations to exercise in the LAMMPS-compile test.
RF_PARAMETRIZE = [
    GaussianRFConfig(num_random_features=128, length_scale=1.0),
    MultiscaleGaussianRFConfig(num_random_features=128),
]
26 |
27 |
@pytest.mark.parametrize("rf_cfg", RF_PARAMETRIZE)
@pytest.mark.parametrize("device", DEVICES)
def test_lammps_compile(rf_cfg, device):
    """Compile a FrankenPotential for LAMMPS (TorchScript) and verify state.

    Builds a small model on the test dataset, saves it, converts it with
    LammpsFrankenCalculator.create_lammps_model, reloads the scripted module
    and compares its state dicts against the original. The converted model is
    stored in double precision, so each comparison is done twice: against the
    float32 original (must fail with a dtype-mismatch RuntimeError) and
    against the original cast to double (must match).
    """
    unpatch_e3nn()  # needed in case some previous test ran the patching code
    gnn_cfg = MaceBackboneConfig("MACE-L0")
    temp_dir = None
    try:
        # Step 1: Create a temporary directory for saving the model
        temp_dir = create_temp_dir()

        data_path = DATASET_REGISTRY.get_path("test", "test", None, False)
        dataset = BaseAtomsDataset.from_path(
            data_path=data_path,
            split="train",
            gnn_config=gnn_cfg,
        )
        model = FrankenPotential(
            gnn_config=gnn_cfg,
            rf_config=rf_cfg,
            scale_by_Z=True,
            num_species=dataset.num_species,
        ).to(device)

        # Accumulate GNN descriptor statistics over the dataset to fit the
        # per-species input scaler before saving.
        with torch.no_grad():
            gnn_features_stats = Statistics()
            for data, _ in dataset:  # type: ignore
                data = data.to(device=device)
                gnn_features = model.gnn.descriptors(data)
                gnn_features_stats.update(
                    gnn_features, atomic_numbers=data.atomic_numbers
                )

        model.input_scaler.set_from_statistics(gnn_features_stats)
        garbage_collection_cuda()

        # Step 2: Save the model to the temporary directory
        model_save_path = os.path.join(temp_dir, "model_checkpoint.pth")
        model.save(model_save_path)

        # Step 3: Run create_lammps_model
        comp_model_path = LammpsFrankenCalculator.create_lammps_model(model_path=model_save_path, rf_weight_id=None)

        # Step 4: Load saved model
        comp_model = torch.jit.load(comp_model_path, map_location=device)

        # Step 5: Compare rf.state_dict between the original and loaded models
        with pytest.raises(RuntimeError) as exc:
            assert are_dicts_close(
                model.rf.state_dict(), comp_model.model.rf.state_dict(), verbose=True
            )
        assert "Float did not match Double" in str(exc.value)
        assert are_dicts_close(
            model.rf.double().state_dict(), comp_model.model.rf.state_dict(), verbose=True
        ), "The rf.state_dict() of the loaded model does not match the original model."

        # Step 6: Same dtype-mismatch / cast-to-double check for the scaler.
        with pytest.raises(RuntimeError) as exc:
            assert are_dicts_close(
                model.input_scaler.state_dict(),
                comp_model.model.input_scaler.state_dict(),
                verbose=True,
            )
        assert "Float did not match Double" in str(exc.value)
        assert are_dicts_close(
            model.input_scaler.double().state_dict(),
            comp_model.model.input_scaler.state_dict(),
            verbose=True,
        ), "The input_scaler.state_dict() of the loaded model does not match the original model."

        # Step 7: Same check for the energy shift.
        with pytest.raises(RuntimeError) as exc:
            assert are_dicts_close(
                model.energy_shift.state_dict(),
                comp_model.model.energy_shift.state_dict(),
                verbose=True,
            )
        assert "Float did not match Double" in str(exc.value)
        assert are_dicts_close(
            model.energy_shift.double().state_dict(),
            comp_model.model.energy_shift.state_dict(),
            verbose=True,
        ), "The energy_shift.state_dict() of the loaded model does not match the original model."
    finally:
        if temp_dir is not None:
            cleanup_dir(temp_dir)
112 |
113 |
--------------------------------------------------------------------------------
/tests/test_metrics.py:
--------------------------------------------------------------------------------
def test_registry_init():
    # The module-level `registry` singleton should come up with its internal
    # metric store already created on import.
    from franken.metrics.registry import registry

    assert hasattr(registry._instance, "_metrics")
5 |
6 |
def test_available_metrics():
    """The built-in metrics must all be reported by available_metrics()."""
    import franken.metrics as fm

    builtin_names = ("energy_MAE", "forces_MAE", "forces_cosim")
    available = fm.available_metrics()
    for metric_name in builtin_names:
        assert metric_name in available
12 |
13 |
def test_register():
    import franken.metrics as fm
    from franken.metrics.base import BaseMetric

    # NOTE(review): fm.register mutates the global metric registry, so this
    # test is order-sensitive — the "not in" assertion below fails if another
    # test (e.g. test_init_metric) registered "mock_metric" first.
    class MockMetric(BaseMetric):
        pass

    assert "mock_metric" not in fm.available_metrics()
    fm.register("mock_metric", MockMetric)
    assert "mock_metric" in fm.available_metrics()
24 |
25 |
def test_init_metric():
    import torch

    import franken.metrics as fm
    from franken.metrics.base import BaseMetric

    # Minimal concrete metric: fixes the name, forwards device/dtype.
    class MockMetric(BaseMetric):
        def __init__(self, device: torch.device, dtype: torch.dtype = torch.float32):
            super().__init__("mock_metric", device, dtype)

    # init_metric should look up the registered class and instantiate it.
    fm.register("mock_metric", MockMetric)
    metric = fm.init_metric("mock_metric", torch.device("cpu"))
    assert isinstance(metric, MockMetric)
39 |
--------------------------------------------------------------------------------
/tests/test_rf_heads.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | from franken.rf.heads import (
5 | BiasedOrthogonalRFF,
6 | Linear,
7 | MultiScaleOrthogonalRFF,
8 | OrthogonalRFF,
9 | TensorSketch,
10 | )
11 |
# Kernel identifiers accepted by init_rf, one per random-feature head type.
RF_PARAMETRIZE = [
    "poly",
    "gaussian",
    "linear",
    "biased-gaussian",
    "multiscale-gaussian",
]
19 |
20 |
def init_rf(rf_type: str, *args, **kwargs):
    """Construct the random-feature head corresponding to `rf_type`.

    Extra positional/keyword arguments are forwarded to the head constructor.

    Raises:
        ValueError: if `rf_type` is not a known kernel name.
    """
    if rf_type == "linear":
        return Linear(*args, **kwargs)
    if rf_type == "gaussian":
        return OrthogonalRFF(*args, **kwargs)
    if rf_type == "biased-gaussian":
        return BiasedOrthogonalRFF(*args, **kwargs)
    if rf_type == "multiscale-gaussian":
        return MultiScaleOrthogonalRFF(*args, **kwargs)
    if rf_type == "poly":
        return TensorSketch(*args, **kwargs)
    raise ValueError(rf_type)
34 |
35 |
@pytest.mark.parametrize("rf_type", RF_PARAMETRIZE)
class TestDtype:
    """Random-feature heads must follow the dtype of their input data."""

    @pytest.mark.parametrize("dt", [torch.float32, torch.float64])
    def test_dtype_match(self, dt, rf_type):
        # The head is built in the default dtype; feeding it data in `dt`
        # should yield a feature map (and cast internal buffers) in `dt`.
        rf = init_rf(
            rf_type,
            input_dim=64,
        )
        data = torch.randn(10, 64, dtype=dt)
        atomic_nums = torch.randint(1, 100, (10,))
        fmap = rf.feature_map(data, atomic_numbers=atomic_nums)
        assert fmap.dtype == dt
        for buf_name, buf in rf.named_buffers():
            if buf_name == "weights":
                # weights not touched by this test so they'll be f32
                assert (
                    buf.dtype == torch.get_default_dtype()
                ), f"weights has unexpected type {buf.dtype}"
            elif buf.numel() > 1 and buf.dtype.is_floating_point:
                # Every other floating-point buffer involved in the feature
                # map should have been cast to the input dtype.
                assert buf.dtype == dt, f"Buffer {buf_name} has incorrect dtype."
56 |
57 |
class TestFeatureSizes:
    """Relationship between requested and total random-feature counts."""

    def test_orff_offset(self):
        # With use_offset=True the 128 projections are used directly (total
        # 128); without the offset the total doubles to 256 — presumably
        # cos/sin pairs per projection, confirm in OrthogonalRFF.
        rf_offset = init_rf(
            "gaussian", input_dim=32, use_offset=True, num_random_features=128
        )
        rf_no_offset = init_rf(
            "gaussian", input_dim=32, use_offset=False, num_random_features=128
        )
        assert rf_offset.num_random_features == 128
        assert rf_no_offset.num_random_features == 128
        assert rf_offset.total_random_features == 128
        assert rf_no_offset.total_random_features == 256
        assert rf_offset.rff_matrix.shape == (128, 32)
        assert rf_no_offset.rff_matrix.shape == (128, 32)
        assert rf_offset.random_offset.shape == (128,)

    @pytest.mark.parametrize("rf_type", ["poly", "gaussian"])
    def test_per_species_kernel_nonlin1(self, rf_type):
        # No chemically-informed block: one feature block per species.
        rf = init_rf(
            rf_type,
            input_dim=32,
            num_random_features=128,
            num_species=4,
            chemically_informed_ratio=None,
        )
        assert rf.num_random_features == 128
        assert rf.total_random_features == 128 * 4

    @pytest.mark.parametrize("rf_type", ["poly", "gaussian"])
    def test_per_species_kernel_nonlin2(self, rf_type):
        # NOTE(review): a non-None chemically_informed_ratio appears to add
        # one extra shared block on top of the per-species ones — confirm.
        rf = init_rf(
            rf_type,
            input_dim=32,
            num_random_features=128,
            num_species=4,
            chemically_informed_ratio=0.4,
        )
        assert rf.num_random_features == 128
        assert rf.total_random_features == 128 * (4 + 1)

    def test_per_species_kernel_lin1(self):
        # Linear head: feature count is input_dim + 1 = 33, presumably for a
        # constant/bias feature — confirm in Linear.
        rf = init_rf(
            "linear", input_dim=32, num_species=4, chemically_informed_ratio=None
        )
        assert rf.num_random_features == 33
        assert rf.total_random_features == 33 * 4

    def test_per_species_kernel_lin2(self):
        rf = init_rf(
            "linear", input_dim=32, num_species=4, chemically_informed_ratio=0.4
        )
        assert rf.num_random_features == 33
        assert rf.total_random_features == (33) * (4 + 1)
111 |
112 |
class TestEdgeCaseInputs:
    """Invalid length-scale values must be rejected at construction time.

    The multiscale head takes `length_scale_low` instead of `length_scale`,
    so the keyword is selected per head type before the construction call.
    """

    @pytest.mark.parametrize("rf_type", RF_PARAMETRIZE)
    def test_zero_lengthscale(self, rf_type):
        ls_kwarg = (
            {"length_scale_low": 0}
            if rf_type == "multiscale-gaussian"
            else {"length_scale": 0}
        )
        with pytest.raises(ValueError):
            init_rf(rf_type, input_dim=32, **ls_kwarg)

    @pytest.mark.parametrize("rf_type", RF_PARAMETRIZE)
    def test_negative_lengthscale(self, rf_type):
        ls_kwarg = (
            {"length_scale_low": -1.1}
            if rf_type == "multiscale-gaussian"
            else {"length_scale": -1.1}
        )
        with pytest.raises(ValueError):
            init_rf(rf_type, input_dim=32, **ls_kwarg)
129 |
130 |
# Allow running this test module directly: `python test_rf_heads.py`.
if __name__ == "__main__":
    pytest.main()
133 |
--------------------------------------------------------------------------------
/tests/test_trainers_log_utils.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from franken.trainers.log_utils import HyperParameterGroup, LogEntry
4 |
5 |
@pytest.fixture
def dummy_log_dict():
    """Representative serialized LogEntry dict used to exercise
    LogEntry.from_dict/to_dict round-trips and metric lookups."""
    return {
        "checkpoint": {"hash": "rand_uuid", "rf_weight_id": 0},
        "timings": {"cov_coeffs": 1.0, "solve": 1.0},
        # One entry per split, each holding the three built-in metrics.
        "metrics": {
            "train": {"energy_MAE": 1.0, "forces_MAE": 1.0, "forces_cosim": 1.0},
            "validation": {"energy_MAE": 1.0, "forces_MAE": 1.0, "forces_cosim": 1.0},
            "test": {"energy_MAE": 1.0, "forces_MAE": 1.0, "forces_cosim": 1.0},
        },
        # Grouped hyperparameters, mirroring HyperParameterGroup.from_dict.
        "hyperparameters": {
            "franken": {
                "gnn_backbone_id": "SchNet-S2EF-OC20-All",
                "interaction_block": 3,
                "kernel_type": "gaussian",
            },
            "random_features": {
                "num_random_features": 1024,
            },
            "input_scaler": {"scale_by_Z": True, "num_species": 2},
            "solver": {
                "l2_penalty": 1e-6,
                "force_weight": 0.1,
                "dtype": "torch.float64",
            },
        },
    }
33 |
34 |
def test_hpgroup_from_dict():
    """from_dict must keep the group name and every (name, value) pair."""
    params = {
        "str_param": "str_value",
        "int_param": 1,
        "float_param": 1.0,
        "bool_param": True,
    }

    group = HyperParameterGroup.from_dict("dummy_group", params)
    assert group.group_name == "dummy_group"
    for hp in group.hyperparameters:
        assert hp.name in params
        assert hp.value == params[hp.name]
48 |
49 |
def test_log_entry_serialize_deserialize(dummy_log_dict):
    """A from_dict/to_dict round-trip must be lossless."""
    entry = LogEntry.from_dict(dummy_log_dict)
    assert entry.to_dict() == dummy_log_dict
53 |
54 |
def test_log_entry_get_metric(dummy_log_dict):
    """get_metric returns the stored value for a valid (name, split) pair."""
    entry = LogEntry.from_dict(dummy_log_dict)
    assert entry.get_metric("energy_MAE", "train") == 1.0
58 |
59 |
def test_log_entry_get_invalid_metric_name(dummy_log_dict):
    """Looking up an unknown metric name must raise KeyError."""
    entry = LogEntry.from_dict(dummy_log_dict)
    with pytest.raises(KeyError):
        entry.get_metric("invalid_metric", "train")
64 |
65 |
def test_log_entry_get_invalid_metric_split(dummy_log_dict):
    """Looking up an unknown split must raise KeyError."""
    entry = LogEntry.from_dict(dummy_log_dict)
    with pytest.raises(KeyError):
        entry.get_metric("energy_MAE", "invalid_split")
70 |
71 |
72 | # class TestBestModel:
73 | # def test_all_nans(self):
74 | # log_entries = [
75 | # {"metrics": {"val": {"energy": torch.nan}}},
76 | # {"metrics": {"val": {"energy": torch.nan}}},
77 | # ]
78 | # expected_best_log = log_entries[0]
79 | # best_log = get_best_model(log_entries, ["energy"], split="val")
80 | # assert best_log == expected_best_log
81 |
82 | # def test_nans(self):
83 | # log_entries = [
84 | # {"metrics": {"val": {"energy": torch.nan}}},
85 | # {"metrics": {"val": {"energy": 0.1}}},
86 | # {"metrics": {"val": {"energy": 12.0}}},
87 | # ]
88 | # expected_best_log = log_entries[1]
89 | # best_log = get_best_model(log_entries, ["energy"], split="val")
90 | # assert best_log == expected_best_log
91 | # log_entries = [
92 | # {"metrics": {"val": {"energy": 0.1}}},
93 | # {"metrics": {"val": {"energy": torch.nan}}},
94 | # {"metrics": {"val": {"energy": 12.0}}},
95 | # ]
96 | # expected_best_log = log_entries[0]
97 | # best_log = get_best_model(log_entries, ["energy"], split="val")
98 | # assert best_log == expected_best_log
99 |
100 | # def test_stability(self):
101 | # log_entries = [
102 | # {"metrics": {"val": {"energy": 1.0, "forces": 12}}},
103 | # {"metrics": {"val": {"energy": 1.1, "forces": 11.9}}},
104 | # {"metrics": {"val": {"energy": 1.2, "forces": 11.8}}},
105 | # ]
106 | # expected_best_log = log_entries[0]
107 | # best_log = get_best_model(log_entries, ["energy", "forces"], split="val")
108 | # assert best_log == expected_best_log
109 |
110 | # def test_normal(self):
111 | # log_entries = [
112 | # {"metrics": {"val": {"energy": 1.0, "forces": 12}}},
113 | # {"metrics": {"val": {"energy": 0.9, "forces": 11.9}}},
114 | # {"metrics": {"val": {"energy": 1.2, "forces": 11.8}}},
115 | # ]
116 | # expected_best_log = log_entries[1]
117 | # best_log = get_best_model(log_entries, ["energy", "forces"], split="val")
118 | # assert best_log == expected_best_log
119 |
120 | # def test_missing_split(self):
121 | # log_entries = [
122 | # {"metrics": {"val": {"energy": 1.0, "forces": 12}}},
123 | # ]
124 | # with pytest.raises(KeyError):
125 | # get_best_model(log_entries, ["energy", "forces"], split="train")
126 |
127 | # def test_missing_metric(self):
128 | # log_entries = [
129 | # {"metrics": {"val": {"energy": 1.0, "forces": 12}}},
130 | # ]
131 | # with pytest.raises(KeyError):
132 | # get_best_model(log_entries, ["missing", "forces"], split="val")
133 |
--------------------------------------------------------------------------------
/tests/utils.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | import tempfile
3 | from unittest.mock import MagicMock, patch
4 |
5 | import torch
6 |
7 |
def create_temp_dir() -> str:
    """Create a fresh temporary directory and return its path.

    The caller is responsible for removing it (see `cleanup_dir`).
    """
    return tempfile.mkdtemp()
11 |
12 |
def cleanup_dir(temp_dir: str):
    """Recursively delete `temp_dir` and everything inside it."""
    shutil.rmtree(temp_dir)
16 |
17 |
def are_dicts_close(dict1, dict2, rtol=1e-4, atol=1e-6, verbose=False):
    """Recursively compare two (possibly nested) dicts of tensors.

    Returns True when both dicts have the same key set and every
    corresponding pair of tensors is element-wise close (torch.allclose);
    nested dicts are compared recursively. Pairs that are neither both dicts
    nor both tensors make the comparison fail ("different topology").

    Args:
        dict1, dict2: dictionaries to compare (e.g. module state_dicts).
        rtol, atol: tolerances forwarded to torch.allclose.
        verbose: when True, print a description of the first mismatch found.
    """
    if not isinstance(dict1, dict) or not isinstance(dict2, dict):
        return False

    if set(dict1.keys()) != set(dict2.keys()):
        if verbose:
            print(f"Dictionaries have different keys: {set(dict1.keys())}, {set(dict2.keys())}")
        return False

    for key in dict1.keys():
        val1, val2 = dict1[key], dict2[key]
        if isinstance(val1, dict) and isinstance(val2, dict):
            # Bug fix: propagate `verbose` into the recursion — it was dropped
            # before, silencing mismatch reports for nested dicts.
            if not are_dicts_close(val1, val2, rtol, atol, verbose):
                return False
        elif isinstance(val1, torch.Tensor) and isinstance(val2, torch.Tensor):
            if not torch.allclose(val1, val2, rtol=rtol, atol=atol):
                if verbose:
                    print(f"{key} not equal:\n(1) {val1}\n(2) {val2}")
                return False
        else:
            # Mixed / unsupported value types (fixed "differnt" typo).
            if verbose:
                print("The dictionaries have different topology")
            return False
    return True
43 |
44 |
def mocked_gnn(device, dtype, feature_dim: int = 32, backbone_id: str = "test"):
    """Build a patcher replacing `load_checkpoint` in franken.rf.model with a
    mock GNN whose descriptors are a fixed random projection of sin(positions).
    """
    fake_weight = torch.randn(3, feature_dim, device=device, dtype=dtype)

    gnn = MagicMock()
    gnn.feature_dim = MagicMock(return_value=feature_dim)
    # Deterministic (per-mock) descriptors so downstream math is reproducible.
    gnn.descriptors = lambda data: torch.sin(data.atom_pos) @ fake_weight

    def fake_load_checkpoint(*args, **kwargs):
        # Record the kwargs the model was "loaded" with, then hand back the mock.
        gnn.init_args = MagicMock(return_value=dict(kwargs))
        return gnn

    return patch.multiple("franken.rf.model", load_checkpoint=fake_load_checkpoint)
61 |
--------------------------------------------------------------------------------