├── .github ├── env-dev.yml ├── env-docs.yml └── workflows │ ├── CI.yaml │ ├── linting.yaml │ └── rtd.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── LICENSE.txt ├── README.md ├── docs ├── Makefile ├── _static │ ├── TM23Cu_sample_complexity.png │ ├── css │ │ └── mystyle.css │ └── diagram_part1.png ├── _templates │ ├── class.rst │ ├── func.rst │ └── layout.html ├── conf.py ├── index.rst ├── make.bat ├── notebooks │ ├── autotune.ipynb │ ├── getting_started.ipynb │ └── molecular_dynamics.ipynb ├── reference │ ├── cli.rst │ ├── franken-api │ │ ├── franken.calculators.rst │ │ ├── franken.config.rst │ │ ├── franken.rf.heads.rst │ │ ├── franken.rf.model.rst │ │ ├── franken.rf.scaler.rst │ │ └── franken.trainers.rst │ ├── franken-cli │ │ ├── franken.autotune.rst │ │ ├── franken.backbones.rst │ │ └── franken.create_lammps_model.rst │ └── index.rst ├── requirements.txt └── topics │ ├── installation.md │ ├── lammps.md │ └── model_registry.md ├── franken ├── __init__.py ├── autotune │ ├── __init__.py │ ├── __main__.py │ ├── cli.py │ └── script.py ├── backbones │ ├── __init__.py │ ├── cli.py │ ├── registry.json │ ├── utils.py │ └── wrappers │ │ ├── __init__.py │ │ ├── common_patches.py │ │ ├── fairchem_schnet.py │ │ ├── mace_wrap.py │ │ └── sevenn.py ├── calculators │ ├── __init__.py │ ├── ase_calc.py │ └── lammps_calc.py ├── config.py ├── data │ ├── __init__.py │ ├── base.py │ ├── distributed_sampler.py │ ├── fairchem.py │ ├── mace.py │ └── sevenn.py ├── datasets │ ├── PtH2O │ │ └── pth2o_dataset.py │ ├── TM23 │ │ └── tm23_dataset.py │ ├── __init__.py │ ├── registry.py │ ├── split_data.py │ ├── test │ │ ├── long.xyz │ │ ├── md.xyz │ │ ├── test.xyz │ │ ├── test_dataset.py │ │ ├── train.xyz │ │ └── validation.xyz │ └── water │ │ ├── HH_digitizer.csv │ │ ├── OH_digitizer.csv │ │ ├── OO_digitizer.csv │ │ ├── exp_rdf.csv │ │ └── water_dataset.py ├── metrics │ ├── __init__.py │ ├── base.py │ ├── functions.py │ └── registry.py ├── rf │ ├── __init__.py │ ├── 
atomic_energies.py │ ├── heads.py │ ├── model.py │ └── scaler.py ├── trainers │ ├── __init__.py │ ├── base.py │ ├── log_utils.py │ └── rf_cuda_lowmem.py └── utils │ ├── __init__.py │ ├── distributed.py │ ├── file_utils.py │ ├── hostlist.py │ ├── jac.py │ ├── linalg │ ├── __init__.py │ ├── cov.py │ ├── psdsolve.py │ └── tri.py │ └── misc.py ├── notebooks ├── autotune.ipynb ├── colab.ipynb ├── getting_started.ipynb └── molecular_dynamics.ipynb ├── pyproject.toml └── tests ├── __init__.py ├── conftest.py ├── test_FrankenPotential.py ├── test_backbones.py ├── test_backbones_utils.py ├── test_data.py ├── test_lammps.py ├── test_linalg.py ├── test_metrics.py ├── test_rf_heads.py ├── test_trainer.py ├── test_trainers_log_utils.py └── utils.py /.github/env-dev.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - pytorch 3 | - nvidia 4 | - conda-forge 5 | dependencies: 6 | - pytorch>=2.4 7 | - ase 8 | - numpy 9 | - omegaconf 10 | - cupy 11 | - e3nn 12 | - pip 13 | - requests 14 | - tqdm 15 | - pytest 16 | - pre-commit 17 | - black 18 | - ruff 19 | - psutil 20 | - docstring_parser 21 | - packaging 22 | name: franken -------------------------------------------------------------------------------- /.github/env-docs.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - pytorch 3 | - nvidia 4 | - conda-forge 5 | dependencies: 6 | - pytorch>=2.4 7 | - ase 8 | - numpy 9 | - omegaconf 10 | - pip 11 | - e3nn 12 | - requests 13 | - tqdm 14 | - pytest 15 | - psutil 16 | - docstring_parser 17 | - ipython 18 | - sphinx 19 | - sphinxawesome-theme 20 | - sphinxcontrib-applehelp 21 | - sphinxcontrib-devhelp 22 | - sphinxcontrib-htmlhelp 23 | - sphinxcontrib-jsmath 24 | - sphinxcontrib-qthelp 25 | - sphinxcontrib-serializinghtml 26 | - sphinx-argparse 27 | - myst-parser 28 | - nbsphinx 29 | - packaging 30 | name: franken 
-------------------------------------------------------------------------------- /.github/workflows/CI.yaml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: 6 | - "main" 7 | pull_request: 8 | branches: 9 | - "main" 10 | types: [opened, reopened, synchronize] 11 | schedule: 12 | # Weekly tests run on main by default: 13 | # Scheduled workflows run on the latest commit on the default or base branch. 14 | # (from https://help.github.com/en/actions/reference/events-that-trigger-workflows#scheduled-events-schedule) 15 | - cron: "0 2 * * 1" 16 | workflow_dispatch: 17 | 18 | jobs: 19 | test: 20 | name: Test on ${{ matrix.os }}, Python ${{ matrix.python-version }} 21 | runs-on: ${{ matrix.os }} 22 | strategy: 23 | matrix: 24 | os: [ubuntu-latest] 25 | python-version: ["3.10", "3.11", "3.12"] 26 | pytorch-version: ["2.5", "2.6"] 27 | 28 | steps: 29 | - uses: actions/checkout@v4 30 | 31 | - name: Additional info about the build 32 | shell: bash 33 | run: | 34 | uname -a 35 | df -h 36 | ulimit -a 37 | 38 | # More info on options: https://github.com/marketplace/actions/setup-micromamba 39 | - name: Create and setup mamba 40 | uses: mamba-org/setup-micromamba@v2 41 | with: 42 | # here we specify the environment like this instead of just installing with pip to make caching easier 43 | environment-file: .github/env-dev.yml 44 | environment-name: test 45 | cache-environment: true 46 | cache-environment-key: environment-${{ matrix.python-version }}-${{ matrix.pytorch-version }} 47 | condarc: | 48 | channels: 49 | - conda-forge 50 | create-args: >- 51 | python=${{ matrix.python-version }} 52 | pytorch=${{ matrix.pytorch-version }} 53 | 54 | - name: Install GNN backbones packages 55 | # conda setup requires this special shell 56 | shell: bash -l {0} 57 | env: 58 | TORCH_VERSION: ${{ matrix.pytorch-version }} 59 | run: | 60 | python -m pip install torch_geometric 61 | # Running with the -f argument gives 
us prebuilt wheels which speeds things up. 62 | # On the other hand it depends on them publishing the wheels 63 | python -m pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-${TORCH_VERSION}.0+cpu.html 64 | python -m pip install --no-deps fairchem-core==1.10 # fairchem dependencies are a nightmare, better to ignore them 65 | python -m pip install mace-torch 66 | 67 | - name: Install package 68 | # conda setup requires this special shell 69 | shell: bash -l {0} 70 | run: | 71 | python -m pip install . --no-deps 72 | micromamba list 73 | 74 | - name: Run tests 75 | # conda setup requires this special shell 76 | shell: bash -l {0} 77 | run: | 78 | pytest -v --color=yes tests/ 79 | 80 | # - name: CodeCov 81 | # if: contains( matrix.os, 'ubuntu' ) 82 | # uses: codecov/codecov-action@v3 83 | # with: 84 | # token: ${{ secrets.CODECOV_TOKEN }} 85 | # file: ./coverage.xml 86 | # flags: codecov 87 | # name: codecov-${{ matrix.os }}-py${{ matrix.python-version }} 88 | 89 | build: 90 | name: Bump version and build package with hatch 91 | needs: test # This ensures 'publish' only runs if 'test' passes 92 | runs-on: ubuntu-latest 93 | if: | 94 | github.ref == 'refs/heads/main' && 95 | github.event_name == 'push' && 96 | contains(github.event.head_commit.message, '[release]') 97 | steps: 98 | - uses: actions/checkout@v4 99 | - name: setup python 100 | uses: actions/setup-python@v5 101 | with: 102 | python-version: '3.11' 103 | - name: install hatch 104 | run: pip install hatch 105 | 106 | - name: Determine bump type 107 | id: bump 108 | run: | 109 | COMMIT_MSG=`git log -1 --pretty=%B | head -n 1` 110 | if [[ "$COMMIT_MSG" == *"[Major]"* ]]; then 111 | echo "bump=major" >> $GITHUB_OUTPUT 112 | elif [[ "$COMMIT_MSG" == *"[Minor]"* ]]; then 113 | echo "bump=minor" >> $GITHUB_OUTPUT 114 | else 115 | echo "bump=patch" >> $GITHUB_OUTPUT 116 | fi 117 | - name: bump version and tag repo 118 | run: | 119 | git config --global 
user.name 'autobump' 120 | git config --global user.email 'autobump@github.com' 121 | OLD_VERSION=`hatch version` 122 | hatch version ${{ steps.bump.outputs.bump }} 123 | NEW_VERSION=`hatch version` 124 | git add franken/__init__.py 125 | git commit -m "Updated version: ${OLD_VERSION} → ${NEW_VERSION} [skip ci]" 126 | git tag $NEW_VERSION 127 | git push 128 | git push --tags 129 | - name: build franken package 130 | run: hatch build 131 | - name: Upload build artifacts 132 | uses: actions/upload-artifact@v4 133 | with: 134 | name: dist-files 135 | path: dist/* 136 | 137 | publish: 138 | name: Publish to PyPi 139 | needs: build 140 | runs-on: ubuntu-latest 141 | if: | 142 | github.ref == 'refs/heads/main' && 143 | github.event_name == 'push' && 144 | needs.build.result == 'success' 145 | permissions: 146 | id-token: write # IMPORTANT: this permission is mandatory for trusted publishing 147 | environment: 148 | name: pypi 149 | url: https://pypi.org/project/franken/ 150 | steps: 151 | - name: Download build artifacts 152 | uses: actions/download-artifact@v4 153 | with: 154 | name: dist-files 155 | path: dist/ 156 | - name: Publish package distributions to PyPI 157 | uses: pypa/gh-action-pypi-publish@release/v1 -------------------------------------------------------------------------------- /.github/workflows/linting.yaml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: [push] 4 | 5 | jobs: 6 | lint: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v4 10 | - name: Set up Python 11 | uses: actions/setup-python@v5 12 | with: 13 | python-version: '3.11' 14 | - name: Install dependencies 15 | run: | 16 | python -m pip install --upgrade pip 17 | pip install ruff 18 | - name: Linting with black 19 | uses: psf/black@stable 20 | with: 21 | options: "--check --verbose" 22 | use_pyproject: true 23 | - name: Run Ruff 24 | run: ruff check --output-format=github . 
25 | -------------------------------------------------------------------------------- /.github/workflows/rtd.yaml: -------------------------------------------------------------------------------- 1 | name: Docs 2 | 3 | # Runs on pushes targeting the default branch 4 | on: 5 | push: 6 | branches: 7 | - main 8 | pull_request: 9 | branches: 10 | - main 11 | 12 | # Allows you to run this workflow manually from the Actions tab 13 | workflow_dispatch: 14 | 15 | # Cancel in-progress runs when pushing a new commit on the PR 16 | concurrency: 17 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 18 | cancel-in-progress: true 19 | 20 | jobs: 21 | docs: 22 | environment: 23 | name: ghpg 24 | url: ${{ steps.deployment.outputs.page_url }} 25 | runs-on: ubuntu-latest 26 | steps: 27 | - uses: actions/checkout@v4 28 | - name: Create and setup mamba 29 | uses: mamba-org/setup-micromamba@v2 30 | with: 31 | # here we specify the environment like this instead of just installing with pip to make caching easier 32 | environment-file: .github/env-docs.yml 33 | environment-name: test 34 | cache-environment: true 35 | cache-environment-key: environment-docs 36 | condarc: | 37 | channels: 38 | - conda-forge 39 | create-args: >- 40 | python=3.12 41 | pytorch=2.4 42 | pandoc=3.6.4 43 | - name: Install franken 44 | # conda setup requires this special shell 45 | shell: bash -l {0} 46 | # dependencies are handled in conda env 47 | run: | 48 | python -m pip install . --no-deps 49 | micromamba list 50 | 51 | - name: Sphinx build 52 | # conda setup requires this special shell 53 | shell: bash -l {0} 54 | run: | 55 | # Check import works. sphinx-build will try to import but not provide reliable error-traces. 
56 | python -c "import franken; import franken.calculators;" 57 | sphinx-build docs _build 58 | 59 | # This step zips and pushes the built docs to the rtd branch 60 | - name: Push docs to rtd branch 61 | if: github.ref == 'refs/heads/main' # Only deploy when pushing to main 62 | run: | 63 | # Setup git identity 64 | git config --global user.name "GitHub Actions" 65 | git config --global user.email "actions@github.com" 66 | 67 | # Create docs.zip from the _build directory first 68 | cd _build 69 | zip -r ../docs.zip . 70 | cd .. 71 | 72 | # Save a copy of important files 73 | cp .readthedocs.yaml /tmp/readthedocs.yaml 74 | cp docs.zip /tmp/docs.zip 75 | 76 | # Create a fresh rtd branch 77 | git checkout --orphan rtd-temp 78 | git rm -rf . 79 | 80 | # Restore the saved files 81 | cp /tmp/readthedocs.yaml .readthedocs.yaml 82 | cp /tmp/docs.zip docs.zip 83 | 84 | # Add and commit both files 85 | git add docs.zip .readthedocs.yaml 86 | git commit -m "Update documentation build [skip ci]" 87 | 88 | # Force push to rtd branch 89 | git push origin rtd-temp:rtd -f 90 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Datasets 2 | datasets/TM23/test_lengths.py 3 | franken/datasets/TM23/*.xyz 4 | franken/datasets/TM23/*.zip 5 | franken/datasets/PtH2O/*.extxyz 6 | franken/datasets/PtH2O/*.traj 7 | franken/datasets/water/*.xyz 8 | franken/datasets/water/*.zip 9 | 10 | # Random 11 | rsync.sh 12 | 13 | # Docs 14 | docs/reference/franken-api/stubs 15 | docs/reference/stubs 16 | 17 | # Byte-compiled / optimized / DLL files 18 | __pycache__/ 19 | *.py[cod] 20 | *$py.class 21 | 22 | # C extensions 23 | *.so 24 | 25 | # Distribution / packaging 26 | .Python 27 | build/ 28 | develop-eggs/ 29 | dist/ 30 | downloads/ 31 | eggs/ 32 | .eggs/ 33 | lib/ 34 | lib64/ 35 | parts/ 36 | sdist/ 37 | var/ 38 | wheels/ 39 | share/python-wheels/ 40 | *.egg-info/ 41 | 
.installed.cfg 42 | *.egg 43 | MANIFEST 44 | 45 | # PyInstaller 46 | # Usually these files are written by a python script from a template 47 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 48 | *.manifest 49 | *.spec 50 | 51 | # Installer logs 52 | pip-log.txt 53 | pip-delete-this-directory.txt 54 | 55 | # Unit test / coverage reports 56 | htmlcov/ 57 | .tox/ 58 | .nox/ 59 | .coverage 60 | .coverage.* 61 | .cache 62 | nosetests.xml 63 | coverage.xml 64 | *.cover 65 | *.py,cover 66 | .hypothesis/ 67 | .pytest_cache/ 68 | cover/ 69 | 70 | # Translations 71 | *.mo 72 | *.pot 73 | 74 | # Django stuff: 75 | *.log 76 | local_settings.py 77 | db.sqlite3 78 | db.sqlite3-journal 79 | 80 | # Flask stuff: 81 | instance/ 82 | .webassets-cache 83 | 84 | # Scrapy stuff: 85 | .scrapy 86 | 87 | # Sphinx documentation 88 | docs/_build/ 89 | 90 | # PyBuilder 91 | .pybuilder/ 92 | target/ 93 | 94 | # Jupyter Notebook 95 | .ipynb_checkpoints 96 | 97 | # IPython 98 | profile_default/ 99 | ipython_config.py 100 | 101 | # pyenv 102 | # For a library or package, you might want to ignore these files since the code is 103 | # intended to run in multiple environments; otherwise, check them in: 104 | # .python-version 105 | 106 | # pipenv 107 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 108 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 109 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 110 | # install all needed dependencies. 111 | #Pipfile.lock 112 | 113 | # poetry 114 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 115 | # This is especially recommended for binary packages to ensure reproducibility, and is more 116 | # commonly ignored for libraries. 
117 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 118 | #poetry.lock 119 | 120 | # pdm 121 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 122 | #pdm.lock 123 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 124 | # in version control. 125 | # https://pdm.fming.dev/#use-with-ide 126 | .pdm.toml 127 | 128 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 129 | __pypackages__/ 130 | 131 | # Celery stuff 132 | celerybeat-schedule 133 | celerybeat.pid 134 | 135 | # SageMath parsed files 136 | *.sage.py 137 | 138 | # Environments 139 | .env 140 | .venv 141 | env/ 142 | venv/ 143 | ENV/ 144 | env.bak/ 145 | venv.bak/ 146 | 147 | # Spyder project settings 148 | .spyderproject 149 | .spyproject 150 | 151 | # Rope project settings 152 | .ropeproject 153 | 154 | # mkdocs documentation 155 | /site 156 | 157 | # mypy 158 | .mypy_cache/ 159 | .dmypy.json 160 | dmypy.json 161 | 162 | # Pyre type checker 163 | .pyre/ 164 | 165 | # pytype static type analyzer 166 | .pytype/ 167 | 168 | # Cython debug symbols 169 | cython_debug/ 170 | 171 | .vscode 172 | 173 | **/gnn_checkpoints/ 174 | **/experiments/ 175 | **/notebooks_legacy/ 176 | **/baselines_legacy/ 177 | wheel/ 178 | **/wandb/ 179 | **/precomputed/ 180 | *.pt 181 | *.report 182 | *.sbatch 183 | slurm* 184 | 185 | *baseline_/ 186 | *.out 187 | .history/ 188 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | # Using this mirror lets us use mypyc-compiled black, which is about 2x faster 3 | - repo: https://github.com/psf/black-pre-commit-mirror 4 | rev: 24.8.0 5 | hooks: 6 | - id: black 7 | # It is recommended to specify the latest version of Python 8 | # supported by your project here 9 | 
language_version: python3.11 10 | - repo: https://github.com/astral-sh/ruff-pre-commit 11 | rev: v0.7.3 12 | hooks: 13 | - id: ruff 14 | pass_filenames: false -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | # Required 5 | version: 2 6 | 7 | build: 8 | os: ubuntu-22.04 9 | tools: 10 | python: "3.10" 11 | jobs: 12 | build: 13 | html: 14 | - echo "Extracting pre-built docs from docs.zip" 15 | - mkdir -p $READTHEDOCS_OUTPUT/html/ 16 | - unzip -o docs.zip -d $READTHEDOCS_OUTPUT/html/ -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2025 Franken authors 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Franken 2 | 3 | [![Test status](https://github.com/CSML-IIT-UCL/franken/actions/workflows/CI.yaml/badge.svg)](https://github.com/CSML-IIT-UCL/franken/actions/workflows/CI.yaml) 4 | [![Docs status](https://github.com/CSML-IIT-UCL/franken/actions/workflows/rtd.yaml/badge.svg)](https://franken.readthedocs.io/) 5 | 6 | 7 | ## Introduction 8 | 9 | Franken is an open-source library that can be used to enhance the accuracy of atomistic foundation models. It can be used for molecular dynamics simulations, and has a focus on computational efficiency. 10 | 11 | `franken` features include: 12 | - Supports fine-tuning for a variety of foundation models ([MACE](https://github.com/ACEsuit/mace), [SevenNet](https://github.com/MDIL-SNU/SevenNet), [SchNet](https://github.com/facebookresearch/fairchem)) 13 | - Automatic [hyperparameter tuning](https://franken.readthedocs.io/notebooks/autotune.html) simplifies the adaptation procedure, for an out-of-the-box user experience. 14 | - Several random-feature approximations to common kernels (e.g. Gaussian, polynomial) are available to flexibly fine-tune any foundation model. 15 | - Support for running within [LAMMPS](https://www.lammps.org/) molecular dynamics, as well as with [ASE](https://wiki.fysik.dtu.dk/ase/). 16 | 17 | Franken diagram 18 | 19 | For detailed information and benchmarks please check our paper [*Fast and Fourier Features for Transfer Learning of Interatomic Potentials*](https://arxiv.org/abs/2505.05652). 
20 | 21 | ## Documentation 22 | 23 | A full documentation including several examples is available: [https://franken.readthedocs.io/index.html](https://franken.readthedocs.io/index.html). [The paper](https://arxiv.org/abs/2505.05652) also contains a comprehensive description of the methods behind franken. 24 | 25 | ## Install 26 | 27 | To install the latest release of `franken`, you can simply do: 28 | 29 | ```bash 30 | pip install franken 31 | ``` 32 | 33 | Several optional dependencies can be specified, to install packages required for certain operations: 34 | - `cuda` includes packages which speed up training on GPUs (note that `franken` will work on GPUs even without these dependencies thanks to pytorch). 35 | - `fairchem`, `mace`, `sevenn` install the necessary dependencies to use a specific backbone. 36 | - `docs` and `develop` are only needed if you wish to build the documentation, or work on extending the library. 37 | 38 | They can be installed for example by running 39 | 40 | ```bash 41 | pip install franken[mace,cuda] 42 | ``` 43 | 44 | For more details read the [relevant documentation page](https://franken.readthedocs.io/topics/installation.html) 45 | 46 | ## Quickstart 47 | 48 | You can directly run `franken.autotune` to get started with the `franken` library. 
A quick example is to fine-tune MACE-MP0 on a high-level-of-theory water dataset: 49 | 50 | ```bash 51 | franken.autotune \ 52 | --dataset-name="water" --max-train-samples=8 \ 53 | --l2-penalty="(-10, -5, 5, log)" \ 54 | --force-weight="(0.01, 0.99, 5, linear)" \ 55 | --seed=42 \ 56 | --jac-chunk-size=64 \ 57 | --run-dir="./results" \ 58 | --backbone=mace --mace.path-or-id="MACE-L0" --mace.interaction-block=2 \ 59 | --rf=gaussian --gaussian.num-rf=512 --gaussian.length-scale="[10.0, 15.0]" 60 | ``` 61 | 62 | For more details you can check out the [autotune tutorial](https://franken.readthedocs.io/notebooks/autotune.html) or the [getting started notebook](https://franken.readthedocs.io/notebooks/getting_started.html). 63 | 64 | 65 | ## Citing 66 | 67 | If you find this library useful, please cite our work using the folowing bibtex entry: 68 | ``` 69 | @misc{novelli25franken, 70 | title={Fast and Fourier Features for Transfer Learning of Interatomic Potentials}, 71 | author={Pietro Novelli and Giacomo Meanti and Pedro J. Buigues and Lorenzo Rosasco and Michele Parrinello and Massimiliano Pontil and Luigi Bonati}, 72 | year={2025}, 73 | eprint={2505.05652}, 74 | archivePrefix={arXiv}, 75 | url={https://arxiv.org/abs/2505.05652}, 76 | } 77 | ``` 78 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 
12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | 22 | clean: 23 | rm -rf $(BUILDDIR)/* 24 | rm -rf reference/franken-api/stubs 25 | -------------------------------------------------------------------------------- /docs/_static/TM23Cu_sample_complexity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSML-IIT-UCL/franken/497d6ad4ee63b46d2528bcc38ab9b7581c4c6255/docs/_static/TM23Cu_sample_complexity.png -------------------------------------------------------------------------------- /docs/_static/css/mystyle.css: -------------------------------------------------------------------------------- 1 | .literal-no-code { 2 | background-color: transparent; 3 | font-size: 1rem; 4 | } 5 | -------------------------------------------------------------------------------- /docs/_static/diagram_part1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSML-IIT-UCL/franken/497d6ad4ee63b46d2528bcc38ab9b7581c4c6255/docs/_static/diagram_part1.png -------------------------------------------------------------------------------- /docs/_templates/class.rst: -------------------------------------------------------------------------------- 1 | {{ name | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. 
autoclass:: {{ objname }} 6 | :members: 7 | :show-inheritance: 8 | :no-undoc-members: 9 | :special-members: __mul__, __add__, __div__, __neg__, __sub__, __truediv__ 10 | -------------------------------------------------------------------------------- /docs/_templates/func.rst: -------------------------------------------------------------------------------- 1 | {{ name | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autofunction:: {{ objname }} -------------------------------------------------------------------------------- /docs/_templates/layout.html: -------------------------------------------------------------------------------- 1 | {% extends "!layout.html" %} 2 | 3 | {% block extrahead %} 4 | {{ super() }} 5 | 11 | {% endblock extrahead %} -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. 
For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | 8 | import os 9 | import sys 10 | import re 11 | from docutils import nodes 12 | from sphinxawesome_theme.postprocess import Icons 13 | 14 | # -- Path setup -------------------------------------------------------------- 15 | basedir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) 16 | sys.path.insert(0, basedir) 17 | 18 | 19 | html_permalinks_icon = Icons.permalinks_icon # SVG as a string 20 | 21 | # -- Project information ----------------------------------------------------- 22 | 23 | project = "franken" 24 | copyright = "2025, franken team" 25 | author = "franken team" 26 | 27 | # -- General configuration --------------------------------------------------- 28 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 29 | 30 | # Add any paths that contain templates here, relative to this directory. 31 | templates_path = ["_templates"] 32 | 33 | # List of patterns, relative to source directory, that match files and 34 | # directories to ignore when looking for source files. 35 | # This pattern also affects html_static_path and html_extra_path. 36 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "requirements.txt"] 37 | 38 | # Add any Sphinx extension module names here, as strings. They can be 39 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 40 | # ones. 
41 | extensions = [ 42 | "sphinx.ext.autodoc", 43 | "sphinx.ext.autosummary", 44 | "sphinx.ext.napoleon", 45 | "sphinx.ext.intersphinx", 46 | "sphinxawesome_theme", 47 | "myst_parser", 48 | "sphinxarg.ext", 49 | "nbsphinx", 50 | ] 51 | 52 | myst_enable_extensions = ["amsmath", "dollarmath", "html_image"] 53 | 54 | intersphinx_mapping = { 55 | "numpy": ("https://numpy.org/doc/stable/", None), 56 | "torch": ("https://pytorch.org/docs/stable/", None), 57 | "torchvision": ("https://pytorch.org/vision/stable/", None), 58 | "python": ("https://docs.python.org/3.9/", None), 59 | "ase": ("https://wiki.fysik.dtu.dk/ase/", None), 60 | } 61 | 62 | 63 | autodoc_typehints = "description" 64 | autodoc_typehints_description_target = "documented" 65 | # to handle functions as default input arguments 66 | autodoc_preserve_defaults = True 67 | # Warn about broken links 68 | nitpicky = True 69 | autodoc_inherit_docstrings = False 70 | # autodoc_class_signature = "separated" 71 | # autoclass_content = "class" 72 | # autosummary_generate = False 73 | 74 | # autodoc_member_order = "groupwise" 75 | # napoleon_preprocess_types = True 76 | # napoleon_use_rtype = False 77 | 78 | # master_doc = "index" 79 | 80 | source_suffix = { 81 | ".rst": "restructuredtext", 82 | ".txt": "restructuredtext", 83 | ".md": "markdown", 84 | } 85 | 86 | # -- Options for HTML output ------------------------------------------------- 87 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 88 | html_theme = "sphinxawesome_theme" 89 | html_static_path = ["_static"] 90 | html_css_files = ["css/mystyle.css"] 91 | templates_path = ["_templates"] 92 | # Favicon configuration 93 | # html_favicon = '_static/favicon.ico' 94 | # Configure syntax highlighting for Awesome Sphinx Theme 95 | pygments_style = "default" 96 | pygments_style_dark = "material" 97 | html_title = "franken" 98 | # Additional theme configuration 99 | html_theme_options = { 100 | "show_prev_next": True, 101 | 
"show_scrolltop": True, 102 | "main_nav_links": { 103 | "Docs": "index", 104 | "API Reference": "reference/index", 105 | }, 106 | "extra_header_link_icons": { 107 | "GitHub": { 108 | "link": "https://github.com/CSML-IIT-UCL/franken", 109 | "icon": """""", 110 | }, 111 | }, 112 | # "logo_light": "_static/[logo_light].png", 113 | # "logo_dark": "_static/[logo_dark].png", 114 | } 115 | 116 | ## Teletype role 117 | tt_re = re.compile('^:tt:`(.*)`$') 118 | def tt_role(name, rawtext, text, lineno, inliner, options={}, content=[]): 119 | """ 120 | Can be used as :tt:`SOME_TEXT_HERE`, 121 | """ 122 | result = [] 123 | m = tt_re.search(rawtext) 124 | if m: 125 | arg = m.group(1) 126 | result = [nodes.literal('', arg)] 127 | result[0]['classes'].append('literal-no-code') 128 | return result,[] 129 | 130 | 131 | def setup(app): 132 | app.add_role('tt', tt_role) 133 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. role:: frnkword 2 | 3 | Franken: A Method for Efficient and Accurate Molecular Dynamics 4 | ================================================================ 5 | 6 | :tt:`franken` is a novel method designed to enhance the accuracy of atomistic foundation models used for molecular dynamics simulations, all while maintaining computational efficiency. This method builds upon the capabilities of the `MEKRR method `_, extending its application from fitting energies to also fitting forces, thereby enabling feasible molecular dynamics (MD) simulations. 7 | 8 | Franken's Three-Step Process 9 | ---------------------------- 10 | 11 | :tt:`franken` operates through a three-step pipeline: 12 | 13 | #. **Feature Extraction:** The initial step involves representing the chemical environment of each atom within a 14 | molecular configuration using features extracted from a pre-trained GNN foundation model. 
15 | This leverages the inherent knowledge captured by these pre-trained models. 16 | Specifically, :tt:`franken` utilizes features derived from models such as the `MACE-MP0 `_ model. 17 | 18 | #. **Random Features Enhancement:** In this stage, :tt:`franken` introduces non-linearity into the model by transforming the 19 | extracted GNN features using Random Features (RF) maps. These RF maps offer a computationally efficient alternative 20 | to traditional kernel methods by approximating kernel functions, including the widely used Gaussian kernel, 21 | utilizing randomly sampled parameters. 22 | 23 | #. **Energy and Force Prediction:** The final step involves predicting atomic energies and forces by employing a readout mechanism. 24 | This mechanism leverages a learnable vector of coefficients in conjunction with the transformed features obtained from the preceding step. 25 | This design takes advantage of the efficient optimization characteristics of RF models. 26 | 27 | .. figure:: _static/diagram_part1.png 28 | :class: rounded-image 29 | :width: 75% 30 | :align: center 31 | 32 | The three-step pipeline at the heart of :tt:`franken`. 33 | 34 | Advantages of Franken 35 | --------------------- 36 | 37 | :tt:`franken` presents several distinct advantages that position it as a valuable asset in the realm of molecular dynamics simulations: 38 | 39 | - **Closed-Form Optimization:** :tt:`franken` offers the significant advantage of determining the globally optimal model 40 | parameters through a closed-form solution. This eliminates the reliance on iterative gradient descent, leading to 41 | substantial reductions in training time and ensuring efficient optimization. 42 | 43 | - **High Sample Efficiency:** One of :tt:`franken`'s hallmarks is its exceptional data efficiency. 44 | The method achieves accurate results even with a limited number of training samples, 45 | as evidenced by experiments on the TM23 dataset. 
Notably, :tt:`franken` attained a validation error 46 | of 9 meV/Å using only 128 samples with 1024 random features, underscoring its ability to extract 47 | valuable information from relatively small datasets. 48 | 49 | .. figure:: _static/TM23Cu_sample_complexity.png 50 | :class: rounded-image 51 | :width: 75% 52 | :align: center 53 | 54 | Sample complexity of :tt:`franken` on the :tt:`Cu` data from the `TM23 Dataset `_. (MACE-MP0 Backbone) 55 | 56 | - **Parallelization Capabilities:** :tt:`franken`'s training algorithm inherently lends itself to parallelization, allowing it to be scaled across multiple GPUs, thereby significantly accelerating training. This scalability becomes crucial when addressing the computational burden posed by simulations of increasingly intricate molecular systems. 57 | 58 | 59 | .. toctree:: 60 | :maxdepth: 2 61 | :caption: HOW TOs: 62 | :hidden: 63 | 64 | Introduction 65 | 66 | topics/installation.md 67 | topics/model_registry.md 68 | topics/lammps.md 69 | 70 | .. toctree:: 71 | :maxdepth: 2 72 | :caption: Tutorials: 73 | :hidden: 74 | 75 | notebooks/getting_started 76 | notebooks/autotune 77 | notebooks/molecular_dynamics 78 | 79 | 80 | .. toctree:: 81 | :maxdepth: 3 82 | :caption: API Reference: 83 | :hidden: 84 | 85 | reference/index 86 | reference/cli -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. 
Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/notebooks/autotune.ipynb: -------------------------------------------------------------------------------- 1 | ../../notebooks/autotune.ipynb -------------------------------------------------------------------------------- /docs/notebooks/getting_started.ipynb: -------------------------------------------------------------------------------- 1 | ../../notebooks/getting_started.ipynb -------------------------------------------------------------------------------- /docs/notebooks/molecular_dynamics.ipynb: -------------------------------------------------------------------------------- 1 | ../../notebooks/molecular_dynamics.ipynb -------------------------------------------------------------------------------- /docs/reference/cli.rst: -------------------------------------------------------------------------------- 1 | .. _cli_reference: 2 | 3 | Franken CLI Reference 4 | ===================== 5 | 6 | 7 | .. list-table:: 8 | :header-rows: 1 9 | 10 | * - Program 11 | - Description 12 | * - :doc:`franken.autotune ` 13 | - Automatic hyperparameter tuning for franken models. 14 | * - :doc:`franken.backbones ` 15 | - List and download GNN backbones for franken. 16 | * - :doc:`franken.create_lammps_model ` 17 | - Convert a franken model to be able to use it with LAMMPS. 18 | 19 | .. 
toctree:: 20 | :maxdepth: 1 21 | :hidden: 22 | 23 | franken-cli/franken.autotune 24 | franken-cli/franken.backbones 25 | franken-cli/franken.create_lammps_model 26 | -------------------------------------------------------------------------------- /docs/reference/franken-api/franken.calculators.rst: -------------------------------------------------------------------------------- 1 | franken.calculators 2 | =================== 3 | 4 | 5 | .. autosummary:: 6 | :toctree: stubs 7 | :template: class.rst 8 | :nosignatures: 9 | 10 | franken.calculators.FrankenCalculator 11 | franken.calculators.LammpsFrankenCalculator 12 | -------------------------------------------------------------------------------- /docs/reference/franken-api/franken.config.rst: -------------------------------------------------------------------------------- 1 | franken.config 2 | ============== 3 | Object-oriented configuration for the `franken` library. 4 | 5 | 6 | Backbone configuration 7 | ---------------------- 8 | 9 | .. autosummary:: 10 | :toctree: stubs 11 | :template: class.rst 12 | :nosignatures: 13 | 14 | franken.config.MaceBackboneConfig 15 | franken.config.FairchemBackboneConfig 16 | franken.config.SevennBackboneConfig 17 | 18 | 19 | Random feature configuration 20 | ---------------------------- 21 | 22 | .. autosummary:: 23 | :toctree: stubs 24 | :template: class.rst 25 | :nosignatures: 26 | 27 | franken.config.GaussianRFConfig 28 | franken.config.MultiscaleGaussianRFConfig 29 | 30 | 31 | Other configurations 32 | -------------------- 33 | 34 | .. 
autosummary:: 35 | :toctree: stubs 36 | :template: class.rst 37 | :nosignatures: 38 | 39 | franken.config.DatasetConfig 40 | franken.config.SolverConfig 41 | franken.config.AutotuneConfig 42 | -------------------------------------------------------------------------------- /docs/reference/franken-api/franken.rf.heads.rst: -------------------------------------------------------------------------------- 1 | franken.rf.heads 2 | ================ 3 | This module contains random feature implementations for different kernels 4 | 5 | Base Class 6 | ---------- 7 | .. autosummary:: 8 | :toctree: stubs 9 | :template: class.rst 10 | :nosignatures: 11 | 12 | franken.rf.heads.RandomFeaturesHead 13 | 14 | 15 | Gaussian kernel 16 | --------------- 17 | Approximations to the classical Gaussian (or RBF) kernel 18 | 19 | .. autosummary:: 20 | :toctree: stubs 21 | :template: class.rst 22 | :nosignatures: 23 | 24 | franken.rf.heads.OrthogonalRFF 25 | franken.rf.heads.MultiScaleOrthogonalRFF 26 | franken.rf.heads.BiasedOrthogonalRFF 27 | 28 | Other kernels 29 | ------------- 30 | 31 | .. autosummary:: 32 | :toctree: stubs 33 | :template: class.rst 34 | :nosignatures: 35 | 36 | franken.rf.heads.Linear 37 | franken.rf.heads.RandomFeaturesHead 38 | franken.rf.heads.TensorSketch 39 | 40 | Helper Functions 41 | ---------------- 42 | 43 | .. autosummary:: 44 | :toctree: stubs 45 | :template: func.rst 46 | :nosignatures: 47 | 48 | franken.rf.heads.initialize_rf 49 | -------------------------------------------------------------------------------- /docs/reference/franken-api/franken.rf.model.rst: -------------------------------------------------------------------------------- 1 | franken.rf.model 2 | ================ 3 | The main franken model implementation 4 | 5 | .. 
autosummary:: 6 | :toctree: stubs 7 | :template: class.rst 8 | :nosignatures: 9 | 10 | franken.rf.model.FrankenPotential 11 | -------------------------------------------------------------------------------- /docs/reference/franken-api/franken.rf.scaler.rst: -------------------------------------------------------------------------------- 1 | franken.rf.scaler 2 | ================= 3 | 4 | .. autosummary:: 5 | :toctree: stubs 6 | :template: class.rst 7 | :nosignatures: 8 | 9 | franken.rf.scaler.FeatureScaler 10 | franken.rf.scaler.Statistics 11 | 12 | 13 | .. autosummary:: 14 | :toctree: stubs 15 | :template: func.rst 16 | :nosignatures: 17 | 18 | franken.rf.scaler.compute_dataset_statistics 19 | -------------------------------------------------------------------------------- /docs/reference/franken-api/franken.trainers.rst: -------------------------------------------------------------------------------- 1 | franken.trainers 2 | ================ 3 | 4 | Base Class 5 | ---------- 6 | .. autosummary:: 7 | :toctree: stubs 8 | :template: class.rst 9 | :nosignatures: 10 | 11 | franken.trainers.BaseTrainer 12 | 13 | Random features trainer 14 | ----------------------- 15 | .. autosummary:: 16 | :toctree: stubs 17 | :template: class.rst 18 | :nosignatures: 19 | 20 | franken.trainers.RandomFeaturesTrainer 21 | -------------------------------------------------------------------------------- /docs/reference/franken-cli/franken.autotune.rst: -------------------------------------------------------------------------------- 1 | Autotune 2 | ======== 3 | 4 | .. argparse:: 5 | :module: franken.autotune.script 6 | :func: get_parser_fn 7 | :prog: franken.autotune 8 | :nodefault: 9 | -------------------------------------------------------------------------------- /docs/reference/franken-cli/franken.backbones.rst: -------------------------------------------------------------------------------- 1 | Backbones 2 | ========= 3 | 4 | 5 | .. 
argparse:: 6 | :module: franken.backbones.cli 7 | :func: get_parser_fn 8 | :prog: franken.backbones 9 | -------------------------------------------------------------------------------- /docs/reference/franken-cli/franken.create_lammps_model.rst: -------------------------------------------------------------------------------- 1 | Create LAMMPS model 2 | =================== 3 | 4 | .. argparse:: 5 | :module: franken.calculators.lammps_calc 6 | :func: get_parser_fn 7 | :prog: franken.create_lammps_model 8 | -------------------------------------------------------------------------------- /docs/reference/index.rst: -------------------------------------------------------------------------------- 1 | 2 | Franken API Reference 3 | ===================== 4 | 5 | .. list-table:: 6 | :header-rows: 1 7 | 8 | * - Module 9 | - Description 10 | * - :doc:`franken.trainers ` 11 | - Train franken from atomistic simulation data 12 | * - :doc:`franken.calculators ` 13 | - Run molecular dynamics with learned potentials. 14 | * - :doc:`franken.rf.model ` 15 | - Main model class for franken 16 | * - :doc:`franken.rf.heads ` 17 | - Random feature implementations for different kernels 18 | * - :doc:`franken.rf.scaler ` 19 | - Utilities for scaling random features 20 | * - :doc:`franken.config ` 21 | - Configuration data-classes for the whole franken library 22 | 23 | 24 | .. 
toctree:: 25 | :maxdepth: 1 26 | :hidden: 27 | 28 | franken-api/franken.trainers 29 | franken-api/franken.calculators 30 | franken-api/franken.rf.model 31 | franken-api/franken.rf.heads 32 | franken-api/franken.rf.scaler 33 | franken-api/franken.config 34 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | Sphinx==7.3.7 2 | sphinxawesome-theme==5.2.0 3 | sphinxcontrib-applehelp==2.0.0 4 | sphinxcontrib-devhelp==2.0.0 5 | sphinxcontrib-htmlhelp==2.1.0 6 | sphinxcontrib-jsmath==1.0.1 7 | sphinxcontrib-qthelp==2.0.0 8 | sphinxcontrib-serializinghtml==2.0.0 9 | sphinx-argparse 10 | myst-parser 11 | nbsphinx -------------------------------------------------------------------------------- /docs/topics/installation.md: -------------------------------------------------------------------------------- 1 | (installation)= 2 | # Installation 3 | 4 | To install `franken`, start by setting up your environment with the correct **version of [PyTorch](https://pytorch.org/)**. This is especially necessary if you wish to use GPUs. Then install `franken` by running 5 | ```bash 6 | pip install franken 7 | ``` 8 | The basic installation comes bare-bones without any GNN backbone installed. You can install franken with a specific backbone directly, by running one of the following commands 9 | ```bash 10 | pip install franken[cuda,mace] 11 | pip install franken[cuda,fairchem] 12 | pip install franken[cuda,sevenn] 13 | ``` 14 | In more detail: 15 | - the `cuda` qualifier installs dependencies which are only relevant on GPU-enabled environments and can be omitted. 16 | - the three supported backbones are [MACE](https://github.com/ACEsuit/mace), [SchNet from fairchem](https://github.com/FAIR-Chem/fairchem), and [SevenNet](https://github.com/MDIL-SNU/SevenNet). They are explained in more detail below. 
17 | 18 | 19 | ```{warning} 20 | Each backbone seems to have mutually incompatible requirements, particularly with regards to `e3nn` - but also pytorch versions might be a problem. 21 | To minimize incompatibilities, we suggest that users who wish to use multiple backbones create independent python environments for each. 22 | In particular, the `mace-torch` package requires an old version of `e3nn` (0.4.4) which conflicts with `fairchem-core` (see [this relevant issue](https://github.com/ACEsuit/mace/issues/555)) and with `SevenNet`. If you encounter errors with model loading, simply upgrade `e3nn` by running `pip install -U e3nn`. 23 | ``` 24 | 25 | ## Supported pre-trained models 26 | ### MACE 27 | We support several models which use the [MACE architecture](https://github.com/ACEsuit/mace): 28 | - The [`MACE-MP0`](https://arxiv.org/abs/2401.00096) models trained on the materials project data by Batatia et al. Additional information on the pre-training of `MACE-MP0` is available on its [HuggingFace model card](https://huggingface.co/cyrusyc/mace-universal). 29 | - The MACE-OFF ([paper](https://arxiv.org/abs/2312.15211) and [github](https://github.com/ACEsuit/mace-off)) models which are pretrained on organic molecules. 30 | - The Egret ([github](https://github.com/rowansci/egret-public)) family of models (`Egret-1`, `Egret-1e`, `Egret-1t`), also tuned for organic molecules. 31 | 32 | To use any MACE model as a backbone for `franken` just `pip`-install `mace-torch` in `franken`'s environment 33 | ```bash 34 | pip install mace-torch 35 | ``` 36 | or directly install franken with mace support (`pip install franken[cuda,mace]`). 37 | 38 | In addition to MACE-MP0 trained on the materials project dataset, Franken also supports the [`MACE-OFF` models](https://arxiv.org/abs/2312.15211) for organic chemistry. 39 | 40 | 41 | ### SevenNet 42 | 43 | Franken also supports the [SevenNet model](https://arxiv.org/abs/2402.03789) by Park et al. 
as implemented in the [`sevennet`](https://github.com/MDIL-SNU/SevenNet) library. 44 | We have only tested the SevenNet-0 model trained on the materials project dataset, but support for other models should be possible (open an issue if you encounter any problem). 45 | 46 | ### SchNet OC20 (fairchem, formerly OCP) 47 | We support the [SchNet model](https://arxiv.org/abs/1706.08566) by Schütt et al. as implemented in the [`fairchem`](https://fair-chem.github.io/) library by Meta's FAIR. The pre-training was done on the [Open Catalyst dataset](https://fair-chem.github.io/core/datasets/oc20.html). To use it as a backbone for `franken`, install the `fairchem` library 48 | ```bash 49 | pip install fairchem-core 50 | ``` 51 | and the `torch_geometric` dependencies as explained in the [FairChem docs](https://fair-chem.github.io/core/install.html). 52 | ```{note} 53 | Not all of fairchem's dependencies can be installed by `pip` alone, check the [FairChem docs](https://fair-chem.github.io/core/install.html). 54 | ``` 55 | Note that `SchNet` is not competitive with more recent GNN models and is only meant as a baseline, and to showcase support for diverse backends. 56 | For now we do not support fairchem v2 models, if you wish to see this implemented please file an issue! -------------------------------------------------------------------------------- /docs/topics/lammps.md: -------------------------------------------------------------------------------- 1 | # Franken + LAMMPS 2 | 3 | The basic steps required to run a Franken model with [LAMMPS](https://www.lammps.org/) are: 4 | 1. Compile the model using `franken/calculators/lammps.py`: 5 | ```bash 6 | franken.create_lammps_model --model_path= 7 | ``` 8 | Note that only models which use the MACE backbone can be compiled and run with LAMMPS. For the other backbones please use the ase MD interface. The compiled model will be saved in the same directory as the original model, with `-lammps` appended to the filename. 9 | 2. 
Configure LAMMPS. The following lines are necessary, the second line should point to the compiled model from step 1. 10 | ``` 11 | pair_style mace no_domain_decomposition 12 | pair_coeff * * C H N O 13 | ``` 14 | 3. Run LAMMPS-Mace. On leonardo you can find it pre-compiled here: 15 | `/leonardo/pub/userexternal/lbonati1/software/lammps-mace/lammps/build-ampere-plumed/lmp` 16 | 17 | ## Compiling LAMMPS-Mace 18 | 19 | This follows the [MACE guide](https://mace-docs.readthedocs.io/en/latest/guide/lammps.html) adapting it to the leonardo cluster. 20 | This can be useful in case one wants to modify the Mace patch to LAMMPS. In particular, the following two files are important: 21 | - [https://github.com/ACEsuit/lammps/blob/mace/src/ML-MACE/pair_mace.cpp](https://github.com/ACEsuit/lammps/blob/mace/src/ML-MACE/pair_mace.cpp) 22 | - [https://github.com/ACEsuit/lammps/blob/mace/src/KOKKOS/pair_mace_kokkos.cpp](https://github.com/ACEsuit/lammps/blob/mace/src/KOKKOS/pair_mace_kokkos.cpp) 23 | 24 | We will assume to start from directory `$BASE_DIR` 25 | 1. ```git clone --branch=mace --depth=1 https://github.com/ACEsuit/lammps``` 26 | 2. download librtorch. For now keeping the default version as specified by MACE, but note that new versions exist! 27 | ```bash 28 | wget https://download.pytorch.org/libtorch/cu121/libtorch-shared-with-deps-2.2.0%2Bcu121.zip 29 | unzip libtorch-shared-with-deps-2.2.0+cu121.zip 30 | rm libtorch-shared-with-deps-2.2.0+cu121.zip 31 | mv libtorch libtorch-gpu 32 | ``` 33 | 3. Get a GPU node for compilation 34 | `srun -N 1 --ntasks-per-node=1 --cpus-per-task=8 --gres=gpu:1 -A -p boost_usr_prod -t 00:30:00 --pty /bin/bash` 35 | 4. Compile: 36 | 1. 
Load modules 37 | ```bash 38 | module purge 39 | module load gcc/12.2.0 40 | module load gsl/2.7.1--gcc--12.2.0 41 | module load openmpi/4.1.6--gcc--12.2.0 42 | module load fftw/3.3.10--openmpi--4.1.6--gcc--12.2.0 43 | module load openblas/0.3.24--gcc--12.2.0 44 | module load cuda/12.1 45 | module load intel-oneapi-mkl/2023.2.0 46 | ``` 47 | 2. Compile 48 | ```bash 49 | cd $BASE_DIR/lammps 50 | mkdir -p build-ampere 51 | cd build-ampere 52 | cmake \ 53 | -D CMAKE_BUILD_TYPE=Release \ 54 | -D CMAKE_INSTALL_PREFIX=$(pwd) \ 55 | -D CMAKE_CXX_STANDARD=17 \ 56 | -D CMAKE_CXX_STANDARD_REQUIRED=ON \ 57 | -D BUILD_MPI=ON \ 58 | -D BUILD_SHARED_LIBS=ON \ 59 | -D PKG_KOKKOS=ON \ 60 | -D Kokkos_ENABLE_CUDA=ON \ 61 | -D CMAKE_CXX_COMPILER=$(pwd)/../lib/kokkos/bin/nvcc_wrapper \ 62 | -D Kokkos_ARCH_AMDAVX=ON \ 63 | -D Kokkos_ARCH_AMPERE100=ON \ 64 | -D CMAKE_PREFIX_PATH=$(pwd)/../../libtorch-gpu \ 65 | -D PKG_ML-MACE=ON \ 66 | ../cmake 67 | make -j 8 68 | make install 69 | ``` 70 | The compiled binary is then at `$BASE_DIR/lammps/build-ampere/bin/lmp`. 71 | 72 | 73 | ## Running LAMMPS-Mace 74 | 75 | This is just an example sbatch file which can be used to run LAMMPS-Mace. Edit it according to your needs. It uses the paths to LAMMPS-Mace as available on the leonardo cluster, and we will assume that LAMMPS has been configured in a file named `in.lammps`. 76 | 77 | ```bash 78 | #!/bin/bash 79 | #SBATCH --account= 80 | #SBATCH --partition=boost_usr_prod # partition to be used 81 | #SBATCH --time 00:30:00 # format: HH:MM:SS 82 | #SBATCH --qos=boost_qos_dbg 83 | #SBATCH --nodes=1 # node 84 | #SBATCH --ntasks-per-node=1 # tasks out of 32 85 | #SBATCH --gres=gpu:1 # gpus per node out of 4 86 | #SBATCH --cpus-per-task=1 # Important: if > 1 kokkos complains. 
87 | ############################ 88 | 89 | module purge 90 | module load profile/base 91 | module load gcc/12.2.0 92 | module load gsl/2.7.1--gcc--12.2.0 93 | module load openmpi/4.1.6--gcc--12.2.0 94 | module load fftw/3.3.10--openmpi--4.1.6--gcc--12.2.0 95 | module load openblas/0.3.24--gcc--12.2.0 96 | module load cuda/12.1 97 | module load intel-oneapi-mkl/2023.2.0 98 | 99 | . /leonardo/pub/userexternal/lbonati1/software/lammps-mace/libtorch-gpu/sourceme.sh 100 | . /leonardo/pub/userexternal/lbonati1/software/plumed/plumed2-2.9-gcc12/sourceme.sh 101 | 102 | echo "setting env variable" 103 | export OMP_NUM_THREADS=${SLURM_CPUS_PER_TASK} 104 | export OMP_PROC_BIND=spread 105 | export OMP_PLACES=threads 106 | 107 | echo "running job" 108 | in_file='in.lammps' 109 | log_file='log.lammps' 110 | lmp='/leonardo/pub/userexternal/lbonati1/software/lammps-mace/lammps/build-ampere-plumed/lmp' 111 | 112 | srun $lmp -k on g 1 t ${SLURM_CPUS_PER_TASK} -sf kk -i $in_file -l $log_file 113 | 114 | wait 115 | ``` 116 | -------------------------------------------------------------------------------- /docs/topics/model_registry.md: -------------------------------------------------------------------------------- 1 | (model-registry)= 2 | # Backbones Registry 3 | 4 | The available pre-trained GNNs can be listed by running `franken.backbones list`. 
5 | As of today, the available models are: 6 | 7 | ``` 8 | DOWNLOADED MODELS 9 | --------------------(/path/to/.franken/gnn_checkpoints)-------------------- 10 | MACE-L0 (MACE) 11 | --------------------------------AVAILABLE MODELS-------------------------------- 12 | SevenNet0 (sevenn) 13 | MACE-L1 (MACE) 14 | MACE-L2 (MACE) 15 | MACE-OFF-small (MACE) 16 | MACE-OFF-medium (MACE) 17 | MACE-OFF-large (MACE) 18 | SchNet-S2EF-OC20-200k (fairchem) 19 | SchNet-S2EF-OC20-2M (fairchem) 20 | SchNet-S2EF-OC20-20M (fairchem) 21 | SchNet-S2EF-OC20-All (fairchem) 22 | -------------------------------------------------------------------------------- 23 | ``` 24 | 25 | Models can also be directly downloaded by copying the backbone-ID from the command above into the `download` command 26 | 27 | ```bash 28 | franken.backbones download 29 | ``` 30 | 31 | Check the command-line help (e.g. `franken.backbones download --help`) for more information. -------------------------------------------------------------------------------- /franken/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | __version__ = "0.4.1" 4 | 5 | # Get the absolute path of the directory where this file is located 6 | FRANKEN_DIR = Path(__file__).resolve().parent 7 | -------------------------------------------------------------------------------- /franken/autotune/__init__.py: -------------------------------------------------------------------------------- 1 | from franken.autotune.script import autotune 2 | 3 | __all__ = ["autotune"] 4 | -------------------------------------------------------------------------------- /franken/autotune/__main__.py: -------------------------------------------------------------------------------- 1 | from franken.autotune.script import cli_entry_point 2 | 3 | 4 | if __name__ == "__main__": 5 | cli_entry_point() 6 | -------------------------------------------------------------------------------- 
/franken/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from franken.backbones.utils import load_model_registry 2 | 3 | 4 | REGISTRY = load_model_registry() 5 | -------------------------------------------------------------------------------- /franken/backbones/cli.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from franken.backbones import REGISTRY 4 | from franken.backbones.utils import download_checkpoint, make_summary 5 | from franken.utils.misc import setup_logger 6 | 7 | 8 | ### Command 'list': list available models 9 | 10 | 11 | def build_list_arg_parser(subparsers) -> argparse.ArgumentParser: 12 | parser = subparsers.add_parser("list", description="List available models") 13 | parser.add_argument( 14 | "--cache_dir", 15 | help=( 16 | "Directory to save the downloaded checkpoints. " 17 | "Defaults to '~/.franken/' in the user home or to the " 18 | "'FRANKEN_CACHE_DIR' environment variable if set." 
19 | ), 20 | type=str, 21 | default=None, 22 | ) 23 | parser.add_argument( 24 | "--log_level", 25 | help="log-level for the command-line logger", 26 | type=str, 27 | default="INFO", 28 | choices=["DEBUG", "INFO", "WARNING", "ERROR"], 29 | ) 30 | return parser 31 | 32 | 33 | def run_list_cmd(args): 34 | setup_logger(level=args.log_level, directory=None) 35 | print(make_summary(cache_dir=args.cache_dir)) 36 | 37 | 38 | ### Command 'download': download a model 39 | 40 | 41 | def build_download_arg_parser(subparsers) -> argparse.ArgumentParser: 42 | parser = subparsers.add_parser("download", description="Download a model") 43 | parser.add_argument( 44 | "--model_name", 45 | help="The name of the model to download.", 46 | type=str, 47 | required=True, 48 | choices=[ 49 | name for name, info in REGISTRY.items() if info["implemented"] is True 50 | ], 51 | ) 52 | parser.add_argument( 53 | "--cache_dir", 54 | help=( 55 | "Directory to save the downloaded checkpoints. " 56 | "Defaults to '~/.franken/' in the user home or to the " 57 | "'FRANKEN_CACHE_DIR' environment variable if set." 58 | ), 59 | type=str, 60 | default=None, 61 | ) 62 | parser.add_argument( 63 | "--log_level", 64 | help="log-level for the command-line logger", 65 | type=str, 66 | default="INFO", 67 | choices=["DEBUG", "INFO", "WARNING", "ERROR"], 68 | ) 69 | return parser 70 | 71 | 72 | def run_download_cmd(args): 73 | setup_logger(level=args.log_level, directory=None) 74 | download_checkpoint(args.model_name, args.cache_dir) 75 | 76 | 77 | def build_arg_parser(): 78 | parser = argparse.ArgumentParser( 79 | description="List and download GNN backbones for franken." 
80 | ) 81 | 82 | subparsers = parser.add_subparsers( 83 | required=True, 84 | title="Franken backbone CLI", 85 | description="Provides helpers to interact with the various backbone models supported by Franken", 86 | help="Run `%(prog)s -h` for help with the individual subcommands", 87 | ) 88 | 89 | list_parser = build_list_arg_parser(subparsers) 90 | list_parser.set_defaults(func=run_list_cmd) 91 | download_parser = build_download_arg_parser(subparsers) 92 | download_parser.set_defaults(func=run_download_cmd) 93 | 94 | return parser 95 | 96 | 97 | def main(): 98 | """This entry-point has 2 commands, 'list' and 'download'. 99 | Usage: 100 | franken.backbones list 101 | franken.backbones download 102 | """ 103 | parser = build_arg_parser() 104 | args = parser.parse_args() 105 | args.func(args) 106 | 107 | 108 | if __name__ == "__main__": 109 | main() 110 | 111 | 112 | # For sphinx docs 113 | get_parser_fn = lambda: build_arg_parser() # noqa: E731 114 | -------------------------------------------------------------------------------- /franken/backbones/utils.py: -------------------------------------------------------------------------------- 1 | import importlib.resources 2 | import json 3 | import logging 4 | import os 5 | from pathlib import Path 6 | 7 | import requests 8 | import torch 9 | 10 | from franken.config import BackboneConfig, asdict_with_classvar 11 | from franken.utils import distributed 12 | from franken.utils.file_utils import download_file 13 | 14 | 15 | logger = logging.getLogger("franken") 16 | 17 | 18 | def load_model_registry(): 19 | model_registry_text = ( 20 | importlib.resources.files("franken.backbones") 21 | .joinpath("registry.json") 22 | .read_text() 23 | ) 24 | model_registry = json.loads(model_registry_text) 25 | return model_registry 26 | 27 | 28 | class CacheDir: 29 | directory: Path | None = None 30 | 31 | @staticmethod 32 | def initialize(cache_dir: Path | str | None = None): 33 | if CacheDir.is_initialized(): 34 | logger.warning( 35 
| f"Cache directory already initialized at {CacheDir.directory}. Reinitializing." 36 | ) 37 | # Default cache location: ~/.franken 38 | default_cache = Path.home() / ".franken" 39 | if cache_dir is None: 40 | env_cache_dir = os.environ.get("FRANKEN_CACHE_DIR", None) 41 | if env_cache_dir is None: 42 | logger.info(f"Initializing default cache directory at {default_cache}") 43 | cache_dir = default_cache 44 | else: 45 | logger.info( 46 | f"Initializing cache directory from $FRANKEN_CACHE_DIR {env_cache_dir}" 47 | ) 48 | cache_dir = env_cache_dir 49 | else: 50 | logger.info(f"Initializing custom cache directory {cache_dir}") 51 | CacheDir.directory = Path(cache_dir) 52 | 53 | # Ensure the directory exists 54 | if not CacheDir.directory.exists(): 55 | CacheDir.directory.mkdir(parents=True, exist_ok=True) 56 | logger.info(f"Created cache directory at: {CacheDir.directory}") 57 | 58 | @staticmethod 59 | def get() -> Path: 60 | if not CacheDir.is_initialized(): 61 | CacheDir.initialize() 62 | assert CacheDir.directory is not None 63 | return CacheDir.directory 64 | 65 | @staticmethod 66 | def is_initialized() -> bool: 67 | return CacheDir.directory is not None 68 | 69 | 70 | def make_summary(cache_dir: str | None = None): 71 | """Function to print available models, first those present locally.""" 72 | if cache_dir is not None: 73 | CacheDir.initialize(cache_dir=cache_dir) 74 | registry = load_model_registry() 75 | ckpt_dir = CacheDir.get() / "gnn_checkpoints" 76 | 77 | local_models = [] 78 | remote_models = [] 79 | _summary = "" 80 | for model, info in registry.items(): 81 | local_path = ckpt_dir / info["local"] 82 | kind = info["kind"] 83 | implemented = info.get("implemented", False) 84 | if implemented: 85 | if local_path.is_file(): 86 | local_models.append((model, kind)) 87 | else: 88 | remote_models.append((model, kind)) 89 | if len(local_models) > 0: 90 | _summary += f"{'DOWNLOADED MODELS':^80}\n" 91 | _summary += f"{'(' + str(ckpt_dir) + ')':-^80}\n" 92 | for 
model, kind in local_models: 93 | _str = f"{model} ({kind})" 94 | _summary += f"{_str:<40}\n" 95 | 96 | _summary += f"{'AVAILABLE MODELS':-^80}\n" 97 | for model, kind in remote_models: 98 | _str = f"{model} ({kind})" 99 | _summary += f"{_str:<80}\n" 100 | _summary += "-" * 80 101 | return _summary 102 | 103 | 104 | def get_checkpoint_path(backbone_path_or_id: str) -> Path: 105 | """Fetches the path of a given backbone. If the backbone is not present, it will be downloaded. 106 | 107 | The backbone can be either specified directly via its file-system path, 108 | then this function is a thin wrapper -- or it can be specified via its 109 | ID in the model registry. Then this function takes care of finding the 110 | correct model path and potentially downloading the backbone from the internet. 111 | 112 | Args: 113 | backbone_path_or_id (str): file-system path to the backbone 114 | or the backbone's ID as per the model registry. 115 | 116 | Returns: 117 | Path: Path to the model on disk 118 | 119 | See Also: 120 | You can use the command :code:`franken.backbones list` from the command-line 121 | to find out which backbone IDs are supported out-of-the-box. 122 | """ 123 | registry = load_model_registry() 124 | gnn_checkpoints_dir = CacheDir.get() / "gnn_checkpoints" 125 | 126 | if backbone_path_or_id not in registry.keys(): 127 | if not os.path.isfile(backbone_path_or_id): 128 | raise FileNotFoundError( 129 | f"GNN Backbone path '{backbone_path_or_id}' does not exist. 
" 130 | f"You should either provide an existing backbone path or a backbone ID " 131 | f"from the registry of available backbones: \n{make_summary()}" 132 | ) 133 | return Path(backbone_path_or_id) 134 | else: 135 | backbone_info = registry[backbone_path_or_id] 136 | ckpt_path = gnn_checkpoints_dir / backbone_info["local"] 137 | # Download checkpoint being aware of multiprocessing 138 | if distributed.get_rank() != 0: 139 | distributed.barrier() 140 | else: 141 | if not ckpt_path.exists(): 142 | download_checkpoint(backbone_path_or_id) 143 | distributed.barrier() 144 | return ckpt_path 145 | 146 | 147 | def download_checkpoint(gnn_backbone_id: str, cache_dir: str | None = None) -> None: 148 | """Download the model if it's not already present locally.""" 149 | registry = load_model_registry() 150 | if cache_dir is not None: 151 | CacheDir.initialize(cache_dir=cache_dir) 152 | ckpt_dir = CacheDir.get() / "gnn_checkpoints" 153 | 154 | if gnn_backbone_id not in registry.keys(): 155 | raise NameError( 156 | f"Unknown {gnn_backbone_id} GNN backbone, the current available backbones are\n{make_summary()}" 157 | ) 158 | 159 | if not registry[gnn_backbone_id]["implemented"]: 160 | raise NotImplementedError( 161 | f"The model {gnn_backbone_id} is not implemented in franken yet." 162 | ) 163 | 164 | local_path = ckpt_dir / registry[gnn_backbone_id]["local"] 165 | remote_path = registry[gnn_backbone_id]["remote"] 166 | 167 | if local_path.is_file(): 168 | logger.info( 169 | f"Model already exists locally at {local_path}. No download needed." 170 | ) 171 | return 172 | 173 | local_path.parent.mkdir(parents=True, exist_ok=True) 174 | logger.info(f"Downloading model from {remote_path} to {local_path}") 175 | try: 176 | download_file(url=remote_path, filename=local_path, desc="Downloading model") 177 | except requests.RequestException as e: 178 | logger.error(f"Download failed. 
{e}") 179 | raise e 180 | 181 | 182 | def load_checkpoint(gnn_config: BackboneConfig) -> torch.nn.Module: 183 | gnn_config_dict = asdict_with_classvar(gnn_config) 184 | gnn_backbone_id = gnn_config_dict.pop("path_or_id") 185 | backbone_family = gnn_config_dict.pop("family") 186 | ckpt_path = get_checkpoint_path(gnn_backbone_id) 187 | err_msg = f"franken wasn't able to load {gnn_backbone_id}. Is {backbone_family} installed?" 188 | if backbone_family == "fairchem": 189 | try: 190 | from franken.backbones.wrappers.fairchem_schnet import FrankenSchNetWrap 191 | except ImportError as import_err: 192 | logger.error(err_msg, exc_info=import_err) 193 | raise 194 | return FrankenSchNetWrap.load_from_checkpoint( 195 | str(ckpt_path), gnn_backbone_id=gnn_backbone_id, **gnn_config_dict 196 | ) 197 | elif backbone_family == "mace": 198 | try: 199 | from franken.backbones.wrappers.mace_wrap import FrankenMACE 200 | except ImportError as import_err: 201 | logger.error(err_msg, exc_info=import_err) 202 | raise 203 | return FrankenMACE.load_from_checkpoint( 204 | str(ckpt_path), 205 | gnn_backbone_id=gnn_backbone_id, 206 | map_location="cpu", 207 | **gnn_config_dict, 208 | ) 209 | elif backbone_family == "sevenn": 210 | try: 211 | from franken.backbones.wrappers.sevenn import FrankenSevenn 212 | except ImportError as import_err: 213 | logger.error(err_msg, exc_info=import_err) 214 | raise 215 | return FrankenSevenn.load_from_checkpoint( 216 | ckpt_path, gnn_backbone_id=gnn_backbone_id, **gnn_config_dict 217 | ) 218 | else: 219 | raise ValueError(f"Unknown backbone family {backbone_family}") 220 | -------------------------------------------------------------------------------- /franken/backbones/wrappers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSML-IIT-UCL/franken/497d6ad4ee63b46d2528bcc38ab9b7581c4c6255/franken/backbones/wrappers/__init__.py 
--------------------------------------------------------------------------------
/franken/backbones/wrappers/common_patches.py:
--------------------------------------------------------------------------------
import logging

import torch


logger = logging.getLogger("franken")


def patch_e3nn():
    # NOTE:
    # Patching should occur during training: it is necessary for `jvp` on the MACE model,
    # but not during inference, when we only use `torch.autograd`. For inference, we may want
    # to compile the model using `torch.jit` - and the patch interferes with the JIT, so we
    # must disable it.

    import e3nn.o3._spherical_harmonics

    if hasattr(e3nn.o3._spherical_harmonics._spherical_harmonics, "code"):
        # Then _spherical_harmonics is a scripted function, we need to undo this!
        # Re-evaluate the scripted function's recorded Python source to recover
        # a plain (non-scripted) implementation into `new_locals`.
        new_locals = {"Tensor": torch.Tensor}
        exec(e3nn.o3._spherical_harmonics._spherical_harmonics.code, None, new_locals)

        def _spherical_harmonics(
            lmax: int, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor
        ) -> torch.Tensor:
            return new_locals["_spherical_harmonics"](torch.tensor(lmax), x, y, z)

        # Save to allow undoing later
        setattr(
            e3nn.o3._spherical_harmonics,
            "_old_spherical_harmonics",
            e3nn.o3._spherical_harmonics._spherical_harmonics,
        )
        e3nn.o3._spherical_harmonics._spherical_harmonics = _spherical_harmonics

    # 2nd patch for newer e3nn versions (somewhere between 0.5.0 and 0.5.5
    # e3nn jits _spherical_harmonics which the SphericalHarmonics class,
    # making the above patch ineffective)
    try:
        from e3nn import set_optimization_defaults

        set_optimization_defaults(jit_script_fx=False)
    except ImportError:
        pass  # only valid for newer e3nn


def unpatch_e3nn():
    # This is only useful for CI and testing environments.
    # When jit-compiling a franken module (e.g. for LAMMPS), we don't want the patch applied!
    import e3nn.o3._spherical_harmonics

    if hasattr(e3nn.o3._spherical_harmonics, "_old_spherical_harmonics"):
        # Restore the original (possibly scripted) function saved by patch_e3nn.
        e3nn.o3._spherical_harmonics._spherical_harmonics = (
            e3nn.o3._spherical_harmonics._old_spherical_harmonics
        )
--------------------------------------------------------------------------------
/franken/backbones/wrappers/fairchem_schnet.py:
--------------------------------------------------------------------------------
from collections import namedtuple

import fairchem.core.common.utils
import torch
from fairchem.core.models.schnet import SchNetWrap

from franken.data import Configuration


def segment_coo_patch(src, index, dim_size=None):
    # Replacement for torch_scatter's `segment_coo`: a plain scatter_add over
    # a zero-initialized buffer, which supports forward-mode autodiff.
    if dim_size is None:
        dim_size = index.max().item() + 1
    out = torch.zeros(dim_size, dtype=src.dtype, device=src.device)
    out.scatter_add_(dim=0, index=index, src=src)
    return out


def segment_csr_patch(src, indptr):
    # Replacement for torch_scatter's `segment_csr`: sums `src` between
    # consecutive CSR pointers. The Python loop is slow but differentiable.
    out = torch.zeros(indptr.size(0) - 1, dtype=src.dtype, device=src.device)
    for i in range(len(indptr) - 1):
        out[i] = src[indptr[i] : indptr[i + 1]].sum()
    return out


def patch_fairchem():
    """
    The `segment_coo` and `segment_csr` patches are necessary to allow
    forward-mode autodiff through the network, which is not implemented
    in the original torch-scatter functions.
30 | """ 31 | fairchem.core.common.utils.segment_coo = segment_coo_patch 32 | fairchem.core.common.utils.segment_csr = segment_csr_patch 33 | 34 | 35 | FairchemCompatData = namedtuple( 36 | "FairchemCompatData", ["pos", "cell", "batch", "natoms", "atomic_numbers"] 37 | ) 38 | 39 | 40 | class FrankenSchNetWrap(SchNetWrap): 41 | def __init__(self, *args, interaction_block, gnn_backbone_id, **kwargs): 42 | patch_fairchem() 43 | super().__init__(*args, **kwargs) 44 | 45 | self.interaction_block = interaction_block 46 | self.gnn_backbone_id = gnn_backbone_id 47 | 48 | def descriptors( 49 | self, 50 | data: Configuration, 51 | ): 52 | """ 53 | Forward pass for the SchNet model to get the embedded representations of the input data 54 | """ 55 | fairchem_compat_data = FairchemCompatData( 56 | data.atom_pos, data.cell, data.batch_ids, data.natoms, data.atomic_numbers 57 | ) 58 | # fairchem checks if the attribute exists, not whether it's None. 59 | if data.pbc is not None: 60 | fairchem_compat_data.pbc = data.pbc # type: ignore 61 | # Get the atomic numbers of the input data 62 | z = data.atomic_numbers.long() 63 | assert z.dim() == 1 64 | # Get the edge index, edge weight and other attributes of the input data 65 | graph = self.generate_graph(fairchem_compat_data) 66 | 67 | edge_attr = self.distance_expansion(graph.edge_distance) 68 | 69 | # Get the embedded representations of the input data 70 | h = self.embedding(z) 71 | for interaction in self.interactions[: self.interaction_block]: 72 | h = h + interaction(h, graph.edge_index, graph.edge_distance, edge_attr) 73 | 74 | return h 75 | 76 | def feature_dim(self): 77 | return self.hidden_channels 78 | 79 | def num_params(self) -> int: 80 | return sum(p.numel() for p in self.parameters()) 81 | 82 | def init_args(self): 83 | return { 84 | "gnn_backbone_id": self.gnn_backbone_id, 85 | "interaction_block": self.interaction_block, 86 | } 87 | 88 | @staticmethod 89 | def load_from_checkpoint( 90 | trainer_ckpt, gnn_backbone_id, 
interaction_block 91 | ) -> "FrankenSchNetWrap": 92 | ckpt_data = torch.load( 93 | trainer_ckpt, map_location=torch.device("cpu"), weights_only=False 94 | ) 95 | 96 | model_config = ckpt_data["config"]["model_attributes"] 97 | model_config["otf_graph"] = True 98 | 99 | model = FrankenSchNetWrap( 100 | **model_config, 101 | interaction_block=interaction_block, 102 | gnn_backbone_id=gnn_backbone_id, 103 | ) 104 | # Before we can load state, need to fix state-dict keys: 105 | # Match the "module." count in the keys of model and checkpoint state_dict 106 | # DataParallel model has 1 "module.", DistributedDataParallel has 2 "module." 107 | # Not using either of the above two would have no "module." 108 | ckpt_key_count = next(iter(ckpt_data["state_dict"])).count("module") 109 | mod_key_count = next(iter(model.state_dict())).count("module") 110 | key_count_diff = mod_key_count - ckpt_key_count 111 | if key_count_diff > 0: 112 | new_dict = { 113 | key_count_diff * "module." + k: v 114 | for k, v in ckpt_data["state_dict"].items() 115 | } 116 | elif key_count_diff < 0: 117 | new_dict = { 118 | k[len("module.") * abs(key_count_diff) :]: v 119 | for k, v in ckpt_data["state_dict"].items() 120 | } 121 | else: 122 | new_dict = ckpt_data["state_dict"] 123 | model.load_state_dict(new_dict) 124 | return model 125 | -------------------------------------------------------------------------------- /franken/backbones/wrappers/sevenn.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from types import MethodType 3 | from typing import Union 4 | from functools import partial 5 | 6 | import sevenn._keys as KEY 7 | import torch 8 | import torch.nn as nn 9 | from sevenn.util import model_from_checkpoint 10 | 11 | from franken.data.base import Configuration 12 | 13 | 14 | def extract_scalar_irrep(data, irreps): 15 | node_features = data[KEY.NODE_FEATURE] 16 | scalar_slice = irreps.slices()[0] 17 | scalar_features = 
node_features[..., scalar_slice] 18 | return scalar_features 19 | 20 | 21 | def franken_sevenn_descriptors( 22 | self, 23 | data: Configuration, 24 | interaction_layer: int, 25 | extract_after_act: bool = True, 26 | append_layers: bool = True, 27 | ): 28 | # Convert data to sevenn 29 | assert data.cell is not None 30 | sevenn_data = { 31 | KEY.NODE_FEATURE: data.atomic_numbers, 32 | KEY.ATOMIC_NUMBERS: data.atomic_numbers, 33 | KEY.POS: data.atom_pos, 34 | KEY.EDGE_IDX: data.edge_index, 35 | KEY.CELL: data.cell, 36 | KEY.CELL_SHIFT: data.shifts, # TODO: Check this correct? 37 | KEY.CELL_VOLUME: torch.einsum( 38 | "i,i", data.cell[0, :], torch.linalg.cross(data.cell[1, :], data.cell[2, :]) 39 | ), 40 | KEY.NUM_ATOMS: len(data.atomic_numbers), 41 | KEY.BATCH: data.batch_ids, 42 | } 43 | 44 | # From v0.9.3 to v10 sevenn introduced some changes in how models are build 45 | # (`build_E3_equivariant_model`), removing the EdgePreprocess class before the 46 | # network itself. The main purpose of EdgePreprocess was to initialize the 47 | # KEY.EDGE_VEC (r_ij: the vector between atom positions) and KEY.EDGE_LENGTH. 48 | # We replace that functionality here. 49 | # NOTE: the original preprocess had some special handling of the PBC cell 50 | # when self.is_stress was set to True. We're ignoring all that. 51 | # NOTE: as comparison to the original EdgePreprocess we assume `is_batch_data` 52 | # to be False. 
53 | idx_src = sevenn_data[KEY.EDGE_IDX][0] 54 | idx_dst = sevenn_data[KEY.EDGE_IDX][1] 55 | pos = sevenn_data[KEY.POS] 56 | edge_vec = pos[idx_dst] - pos[idx_src] 57 | edge_vec = edge_vec + torch.einsum( 58 | "ni,ij->nj", sevenn_data[KEY.CELL_SHIFT], sevenn_data[KEY.CELL].view(3, 3) 59 | ) 60 | sevenn_data[KEY.EDGE_VEC] = edge_vec 61 | sevenn_data[KEY.EDGE_LENGTH] = torch.linalg.norm(edge_vec, dim=-1) 62 | 63 | # Iterate through the model's layers 64 | # the sanest way to figure out which layer we're at is through the 65 | # `_modules` attribute of `nn.Sequential` (which the Sevenn network 66 | # inherits from), which exposes key-value pairs. 67 | layer_idx = 0 68 | scalar_features_list = [] 69 | for i, (name, module) in enumerate(self._modules.items()): 70 | if "self_connection_intro" in name: 71 | layer_idx += 1 72 | 73 | new_sevenn_data = module(sevenn_data) 74 | if "equivariant_gate" in name: 75 | if extract_after_act: 76 | scalar_features = extract_scalar_irrep( 77 | new_sevenn_data, module.gate.irreps_out 78 | ) 79 | else: 80 | scalar_features = extract_scalar_irrep( 81 | sevenn_data, module.gate.irreps_in 82 | ) 83 | if append_layers: 84 | scalar_features_list.append(scalar_features) 85 | else: 86 | scalar_features_list[0] = scalar_features 87 | if layer_idx == interaction_layer: 88 | break 89 | sevenn_data = new_sevenn_data 90 | 91 | return torch.cat(scalar_features_list, dim=-1) 92 | 93 | 94 | def franken_sevenn_num_params(self) -> int: 95 | return sum(p.numel() for p in self.parameters()) 96 | 97 | 98 | def franken_sevenn_feature_dim( 99 | self, 100 | interaction_layer: int, 101 | extract_after_act: bool = True, 102 | append_layers: bool = True, 103 | ): 104 | layer_idx = 0 105 | tot_feat_dim = 0 106 | for i, (name, module) in enumerate(self._modules.items()): 107 | if "self_connection_intro" in name: 108 | layer_idx += 1 109 | if "equivariant_gate" in name: 110 | if extract_after_act: 111 | new_feat_dim = module.gate.irreps_out.count("0e") 112 | else: 
113 | new_feat_dim = module.gate.irreps_in.count("0e") 114 | if append_layers: 115 | tot_feat_dim += new_feat_dim 116 | else: 117 | tot_feat_dim = new_feat_dim 118 | if layer_idx == interaction_layer: 119 | break 120 | return tot_feat_dim 121 | 122 | 123 | class FrankenSevenn: 124 | @staticmethod 125 | def load_from_checkpoint( 126 | trainer_ckpt: Union[str, Path], 127 | gnn_backbone_id: str, 128 | interaction_block: int, 129 | extract_after_act: bool = True, 130 | append_layers: bool = True, 131 | ): 132 | sevenn, config = model_from_checkpoint(str(trainer_ckpt)) 133 | assert isinstance(sevenn, nn.Module) 134 | sevenn.descriptors = MethodType( # type: ignore 135 | partial( 136 | franken_sevenn_descriptors, 137 | interaction_layer=interaction_block, 138 | extract_after_act=extract_after_act, 139 | append_layers=append_layers, 140 | ), 141 | sevenn, 142 | ) 143 | sevenn.num_params = MethodType(franken_sevenn_num_params, sevenn) # type: ignore 144 | sevenn.feature_dim = MethodType( # type: ignore 145 | partial( 146 | franken_sevenn_feature_dim, 147 | interaction_layer=interaction_block, 148 | extract_after_act=extract_after_act, 149 | append_layers=append_layers, 150 | ), 151 | sevenn, 152 | ) 153 | 154 | def init_args(self): 155 | return { 156 | "gnn_backbone_id": gnn_backbone_id, 157 | "interaction_block": interaction_block, 158 | "extract_after_act": extract_after_act, 159 | "append_layers": append_layers, 160 | } 161 | 162 | sevenn.init_args = MethodType(init_args, sevenn) # type: ignore 163 | return sevenn 164 | -------------------------------------------------------------------------------- /franken/calculators/__init__.py: -------------------------------------------------------------------------------- 1 | """Run molecular dynamics with learned potentials. 2 | 3 | Calculators are available for ASE and LAMMPS, but can be 4 | extended to support your favorite MD software. 
5 | """ 6 | 7 | from .ase_calc import FrankenCalculator 8 | from .lammps_calc import LammpsFrankenCalculator 9 | 10 | __all__ = ( 11 | "FrankenCalculator", 12 | "LammpsFrankenCalculator", 13 | ) 14 | -------------------------------------------------------------------------------- /franken/calculators/ase_calc.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Literal, Union 3 | 4 | import numpy as np 5 | import torch 6 | from ase.calculators.calculator import Calculator, all_changes 7 | 8 | from franken.data import BaseAtomsDataset, Configuration 9 | from franken.rf.model import FrankenPotential 10 | from franken.utils.misc import get_device_name 11 | 12 | 13 | class FrankenCalculator(Calculator): 14 | """Calculator for ASE with franken models 15 | 16 | Attributes: 17 | implemented_properties: 18 | Lists properties which can be asked from this calculator, notably "energy" and "forces". 19 | """ 20 | 21 | implemented_properties = ["energy", "forces"] 22 | default_parameters = {} 23 | nolabel = True # ?? 24 | 25 | def __init__( 26 | self, 27 | franken_ckpt: Union[FrankenPotential, str, Path], 28 | device=None, 29 | rf_weight_id: int | None = None, 30 | forces_mode: Literal["torch.func", "torch.autograd"] = "torch.autograd", 31 | **calc_kwargs, 32 | ): 33 | """Initialize FrankenCalculator class from a franken model. 34 | 35 | Args: 36 | franken_ckpt : Path to the franken model. 37 | This class accepts pre-loaded models, as well as jitted models (with `torch.jit`). 38 | device : PyTorch device specification for where the model should reside 39 | (e.g. "cuda:0" for GPU placement or "cpu" for CPU placement). 40 | rf_weight_id : ID of the random feature weights. 41 | Can generally be left to ``None`` unless the checkpoint contains multiple trained models. 42 | """ 43 | # TODO: Remove forces_mode, torch.autograd is always the right way. 
44 | super().__init__(**calc_kwargs) 45 | self.franken: FrankenPotential 46 | if isinstance(franken_ckpt, torch.nn.Module): 47 | self.franken = franken_ckpt 48 | if device is not None: 49 | self.franken.to(device) 50 | else: 51 | # Handle jitted torchscript archives and normal files 52 | try: 53 | self.franken = torch.jit.load(franken_ckpt, map_location=device) 54 | except RuntimeError as e: 55 | if "PytorchStreamReader" not in str(e): 56 | raise 57 | self.franken = FrankenPotential.load( # type: ignore 58 | franken_ckpt, 59 | map_location=device, 60 | rf_weight_id=rf_weight_id, 61 | ) 62 | 63 | self.dataset = BaseAtomsDataset.from_path( 64 | data_path=None, 65 | split="md", 66 | gnn_config=self.franken.gnn_config, 67 | ) 68 | self.device = ( 69 | device if device is not None else next(self.franken.parameters()).device 70 | ) 71 | self.forces_mode = forces_mode 72 | 73 | def calculate( 74 | self, 75 | atoms=None, 76 | properties=None, 77 | system_changes=all_changes, 78 | ): 79 | if properties is None: 80 | properties = self.implemented_properties 81 | if "forces" not in properties: 82 | forces_mode = "no_forces" 83 | else: 84 | forces_mode = self.forces_mode 85 | 86 | super().calculate(atoms, properties, system_changes) 87 | 88 | # self.atoms is set in the super() call. 
Unclear why it should be preferred over `atoms` 89 | config_idx = self.dataset.add_configuration(self.atoms) # type: ignore 90 | cpu_data = self.dataset.__getitem__(config_idx, no_targets=True) 91 | assert isinstance(cpu_data, Configuration) 92 | data = cpu_data.to(self.device) 93 | 94 | energy, forces = self.franken.energy_and_forces(data, forces_mode=forces_mode) 95 | 96 | if energy.ndim == 0: 97 | self.results["energy"] = energy.item() 98 | else: 99 | self.results["energy"] = np.squeeze(energy.numpy(force=True)) 100 | if "forces" in properties: 101 | assert forces is not None 102 | self.results["forces"] = np.squeeze(forces.numpy(force=True)) 103 | 104 | 105 | def calculator_throughput( 106 | calculator, atoms_list, num_repetitions=1, warmup_configs=5, verbose=True 107 | ): 108 | from time import perf_counter 109 | 110 | hardware = get_device_name(calculator.device) 111 | 112 | _atom_numbers = set(len(atoms) for atoms in atoms_list) 113 | assert ( 114 | len(_atom_numbers) == 1 115 | ), f"This function only accepts configurations with the same number of atoms, while found configurations with {_atom_numbers} number of atoms" 116 | natoms = _atom_numbers.pop() 117 | 118 | assert len(atoms_list) > warmup_configs 119 | for idx in range(warmup_configs): 120 | calculator.calculate(atoms_list[idx]) 121 | time_init = perf_counter() 122 | for _ in range(num_repetitions): 123 | for atoms in atoms_list: 124 | calculator.calculate(atoms) 125 | time = perf_counter() - time_init 126 | configs_per_sec = (len(atoms_list) * num_repetitions) / time 127 | results = { 128 | "throughput": configs_per_sec, 129 | "atoms": natoms, 130 | "hardware": hardware, 131 | } 132 | if verbose: 133 | print( 134 | f"{results['throughput']:.1f} cfgs/sec ({results['atoms']} atoms) | {results['hardware']}" 135 | ) 136 | return results 137 | -------------------------------------------------------------------------------- /franken/calculators/lammps_calc.py: 
--------------------------------------------------------------------------------
import os
import argparse
from typing import Optional
import torch
from e3nn.util import jit

from franken.data.base import Configuration
from franken.rf.model import FrankenPotential


@jit.compile_mode("script")
class LammpsFrankenCalculator(torch.nn.Module):
    def __init__(
        self,
        franken_model: FrankenPotential,
    ):
        """Initialize LAMMPS Calculator

        Args:
            franken_model (FrankenPotential): The base franken model used in this MD calculator

        Note:
            The backbone underlying the franken model must be a MACE model. This is because we
            are re-using the LAMMPS interface developed by the MACE authors.
        """
        super().__init__()

        self.model = franken_model
        # Buffers mirrored from the backbone; the LAMMPS-MACE pair style reads
        # these attributes from the compiled module.
        self.register_buffer("atomic_numbers", self.model.gnn.atomic_numbers)
        self.register_buffer("r_max", self.model.gnn.r_max)
        self.register_buffer("num_interactions", self.model.gnn.num_interactions)
        # this attribute is used for dtype detection in LAMMPS-MACE.
        # See: https://github.com/ACEsuit/lammps/blob/mace/src/ML-MACE/pair_mace.cpp#314
        self.model.node_embedding = self.model.gnn.node_embedding

        # Inference-only module: freeze every parameter.
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(
        self,
        data: dict[str, torch.Tensor],
        local_or_ghost: torch.Tensor,
        compute_virials: bool = False,
    ) -> dict[str, torch.Tensor | None]:
        """Compute energies and forces of a given configuration.

        This module is meant to be used in conjunction with LAMMPS,
        and this function should not be called directly. The format of
        the input data is designed to work with the MACE-LAMMPS fork.

        Warning:
            Stresses and virials are not supported by franken. Since they
            are required to be set by LAMMPS, this function sets them to tensors
            of the appropriate shape filled with zeros. Make sure that
            the chosen MD method does not depend on these quantities.
        """
        # node_attrs is a one-hot representation of the atom types. atom_nums should be the actual atomic numbers!
        # we rely on correct sorting. This is the same as in MACE.
        atom_nums = self.atomic_numbers[torch.argmax(data["node_attrs"], dim=1)]

        franken_data = Configuration(
            atom_pos=data["positions"].double(),
            atomic_numbers=atom_nums,
            natoms=torch.tensor(
                len(atom_nums), dtype=torch.int32, device=atom_nums.device
            ).view(1),
            node_attrs=data["node_attrs"].double(),
            edge_index=data["edge_index"],
            shifts=data["shifts"],
            unit_shifts=data["unit_shifts"],
        )
        energy, forces = self.model(franken_data)  # type: ignore
        # Kokkos doesn't like total_energy_local and only looks at node_energy.
        # We hack around this:
        node_energy = energy.repeat(len(atom_nums)).div(len(atom_nums))
        virials: Optional[torch.Tensor] = None
        if compute_virials:
            # Zero-filled placeholder (see Warning in the docstring).
            virials = torch.zeros((1, 3, 3), dtype=forces.dtype, device=forces.device)
        return {
            "total_energy_local": energy,
            "node_energy": node_energy,
            "forces": forces,
            "virials": virials,
        }

    @staticmethod
    def create_lammps_model(model_path: str, rf_weight_id: int | None) -> str:
        """Compile a franken model into a LAMMPS calculator

        Args:
            model_path (str):
                path to the franken model checkpoint.
            rf_weight_id (int | None):
                ID of the random feature weights. Can generally be left to ``None`` unless
                the checkpoint contains multiple trained models.

        Returns:
            str: the path where the LAMMPS-compatible model was saved to.
        """
        franken_model = FrankenPotential.load(
            model_path,
            map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
            rf_weight_id=rf_weight_id,
        )
        # NOTE:
        # Kokkos is hardcoded to double and will silently corrupt data if the model
        # does not use dtype double.
        franken_model = franken_model.double().to("cpu")
        lammps_model = LammpsFrankenCalculator(franken_model)
        lammps_model_compiled = jit.compile(lammps_model)

        # Saved next to the input checkpoint with a "-lammps" suffix.
        save_path = f"{os.path.splitext(model_path)[0]}-lammps.pt"
        print(f"Saving compiled model to '{save_path}'")
        lammps_model_compiled.save(save_path)
        return save_path


def build_arg_parser():
    # CLI for converting franken checkpoints into LAMMPS-compatible models.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Convert a franken model to be able to use it with LAMMPS",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        help="Path to the model to be converted to LAMMPS",
    )
    parser.add_argument(
        "--rf_weight_id",
        type=int,
        help="Head of the model to be converted to LAMMPS",
        default=None,
    )
    return parser


def create_lammps_model_cli():
    parser = build_arg_parser()
    args = parser.parse_args()
    LammpsFrankenCalculator.create_lammps_model(args.model_path, args.rf_weight_id)  # type: ignore


if __name__ == "__main__":
    create_lammps_model_cli()


# For sphinx docs
get_parser_fn = lambda: build_arg_parser()  # noqa: E731
--------------------------------------------------------------------------------
/franken/data/__init__.py:
--------------------------------------------------------------------------------
from franken.data.base import BaseAtomsDataset, Configuration, Target


__all__ = [
    "BaseAtomsDataset",
    "Configuration",
    "Target",
]
-------------------------------------------------------------------------------- /franken/data/distributed_sampler.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Iterator 2 | 3 | from torch.utils.data import Sampler, Dataset 4 | import torch.distributed as dist 5 | 6 | 7 | class SimpleUnevenDistributedSampler(Sampler): 8 | def __init__( 9 | self, 10 | dataset: Dataset, 11 | num_replicas: Optional[int] = None, 12 | rank: Optional[int] = None, 13 | ) -> None: 14 | if num_replicas is None: 15 | if not dist.is_available(): 16 | raise RuntimeError("Requires distributed package to be available") 17 | num_replicas = dist.get_world_size() 18 | if rank is None: 19 | if not dist.is_available(): 20 | raise RuntimeError("Requires distributed package to be available") 21 | rank = dist.get_rank() 22 | if rank >= num_replicas or rank < 0: 23 | raise ValueError( 24 | f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]" 25 | ) 26 | self.dataset = dataset 27 | self.num_replicas = num_replicas 28 | self.rank = rank 29 | self.epoch = 0 30 | self.total_size = len(self.dataset) # type: ignore[arg-type] 31 | # num_samples indicates the number of samples for the current process 32 | self.num_samples = len(range(self.rank, self.total_size, self.num_replicas)) 33 | 34 | def __iter__(self) -> Iterator: 35 | indices = list(range(len(self.dataset))) # type: ignore[arg-type] 36 | 37 | # subsample 38 | indices = indices[self.rank : self.total_size : self.num_replicas] 39 | assert len(indices) == self.num_samples 40 | 41 | return iter(indices) 42 | 43 | def __len__(self) -> int: 44 | return self.num_samples 45 | -------------------------------------------------------------------------------- /franken/data/fairchem.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | from fairchem.core.preprocessing import 
class FairchemAtomsDataset(BaseAtomsDataset):
    """Dataset that converts ASE atoms into fairchem graph objects.

    Graphs are built with fairchem's ``AtomsToGraphs`` converter. For the
    "md" split no energies/forces are read (the ASE calculator would fail
    for MD snapshots without reference labels).
    """

    def __init__(
        self,
        data_path: str | Path | None,
        split: str,
        num_random_subsamples: int | None = None,
        subsample_rng: int | None = None,
        gnn_backbone_id: str | torch.nn.Module | None = None,
        cutoff=6.0,
        max_num_neighbors=200,
        precompute=True,
    ):
        super().__init__(data_path, split, num_random_subsamples, subsample_rng)

        if gnn_backbone_id is not None:
            # The backbone checkpoint is authoritative: it overrides the
            # `cutoff` / `max_num_neighbors` arguments.
            cutoff, max_num_neighbors = self.load_info_from_gnn_config(gnn_backbone_id)

        if split == "md":
            # Cannot get energy and forces in MD mode (the calculator fails)
            self.a2g = FairchemAtomsToGraphs(
                max_neigh=max_num_neighbors,
                radius=cutoff,  # type: ignore
            )
        else:
            self.a2g = FairchemAtomsToGraphs(
                max_neigh=max_num_neighbors,
                radius=cutoff,
                r_energy=True,
                r_forces=True,  # type: ignore
            )
        self.graphs = None
        if precompute and len(self.ase_atoms) > 0:
            # Only rank 0 shows the progress bar in distributed runs.
            self.graphs = self.a2g.convert_all(
                self.ase_atoms, disable_tqdm=dist_utils.get_rank() != 0
            )

    def load_info_from_gnn_config(self, gnn_backbone_id: str | torch.nn.Module):
        """Read ``cutoff`` and ``max_num_neighbors`` from a fairchem checkpoint.

        Raises:
            ValueError: if a preloaded model is passed instead of a backbone ID.
        """
        if not isinstance(gnn_backbone_id, str):
            raise ValueError(
                "Backbone path must be provided instead of the preloaded model."
            )
        ckpt_path = get_checkpoint_path(gnn_backbone_id)
        model = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        model_cfg = model["config"]["model"]
        # BUG FIX: fairchem checkpoints typically store the config as nested
        # dicts, and `getattr` on a dict ALWAYS returns the fallback — so the
        # defaults (6.0 / 200) were used regardless of the checkpoint content.
        # Support both dict-style and attribute-style configs.
        if isinstance(model_cfg, dict):
            cutoff = model_cfg.get("cutoff", 6.0)
            max_num_neighbors = model_cfg.get("max_num_neighbors", 200)
        else:
            cutoff = getattr(model_cfg, "cutoff", 6.0)
            max_num_neighbors = getattr(model_cfg, "max_num_neighbors", 200)
        del model, model_cfg  # the checkpoint is only needed for its metadata
        return cutoff, max_num_neighbors

    def graph_to_inputs(self, graph):
        """Pack a fairchem graph into a franken ``Configuration``."""
        return Configuration(
            atom_pos=graph.pos,  # type: ignore
            atomic_numbers=graph.atomic_numbers.int(),
            natoms=torch.tensor(graph.natoms).view(1),
            batch_ids=(
                graph.batch
                if graph.batch is not None
                else torch.zeros(graph.natoms, dtype=torch.int64)
            ),
            cell=graph.cell,
            pbc=getattr(graph, "pbc", None),
        )

    def graph_to_targets(self, graph):
        """Pack the reference energy and forces into a franken ``Target``."""
        energy = torch.tensor(graph.energy)
        return Target(energy=energy, forces=graph.forces)

    def __getitem__(self, idx, no_targets: bool = False):
        """Returns an array of (inputs, outputs) with inputs being a configuration
        and outputs being the target (energy and forces).
        Note: ONLY for the 'train' split, the energy_shift is removed from the target.
        """
        if self.graphs is None:
            graph = self.a2g.convert(self.ase_atoms[idx])
        else:
            graph = self.graphs[idx]

        data = self.graph_to_inputs(graph)
        if no_targets:
            return data
        target = self.graph_to_targets(graph)

        if self.split == "train":
            # Remove the per-structure energy shift so the model fits residuals.
            target.energy -= self.energy_shifts[idx]

        return data, target
class MACEAtomsDataset(BaseAtomsDataset):
    """Dataset that converts ASE atoms into MACE ``AtomicData`` graphs.

    The atomic-number table and cutoff are either supplied explicitly or
    read from a MACE backbone checkpoint.
    """

    def __init__(
        self,
        data_path: str | Path | None,
        split: str,
        num_random_subsamples: int | None = None,
        subsample_rng: int | None = None,
        gnn_backbone_id: str | torch.nn.Module | None = None,
        z_table: AtomicNumberTable | None = None,
        cutoff=6.0,
        precompute=True,
    ):
        super().__init__(data_path, split, num_random_subsamples, subsample_rng)
        if gnn_backbone_id is not None:
            # The backbone checkpoint is authoritative for both settings.
            z_table, cutoff = self.load_info_from_gnn_config(gnn_backbone_id)
        elif z_table is None:
            # FIX: was `assert z_table is not None`, which is silently stripped
            # when running under `python -O`; validate explicitly instead.
            raise ValueError("Either `gnn_backbone_id` or `z_table` must be provided.")

        self.a2g = MACEAtomsToGraphs(z_table, cutoff)
        self.graphs = None
        if precompute and len(self.ase_atoms) > 0:
            self.graphs = self.a2g.convert_all(
                self.ase_atoms,
                split_name=self.split,
            )

    def load_info_from_gnn_config(self, gnn_backbone_id: str | torch.nn.Module):
        """Extract the ``AtomicNumberTable`` and cutoff radius from a MACE model.

        Accepts either a backbone ID (checkpoint is loaded from disk) or an
        already-loaded ``torch.nn.Module``.
        """
        if isinstance(gnn_backbone_id, str):
            ckpt_path = get_checkpoint_path(gnn_backbone_id)
            mace_gnn = torch_load_maybejit(
                ckpt_path, map_location="cpu", weights_only=False
            )
        else:
            mace_gnn = gnn_backbone_id
        z_table = AtomicNumberTable([z.item() for z in mace_gnn.atomic_numbers])
        cutoff = mace_gnn.r_max.item()
        del mace_gnn  # the full GNN is only needed for its metadata
        return z_table, cutoff

    def __getitem__(self, idx, no_targets: bool = False):
        """Returns an array of (inputs, outputs) with inputs being a configuration
        and outputs being the target (energy and forces).
        Note: ONLY for the 'train' split, the energy_shift is removed from the target.
        """
        if self.graphs is None:
            graph = self.a2g.convert(self.ase_atoms[idx])
        else:
            graph = self.graphs[idx]

        data = Configuration(
            atom_pos=graph.positions,
            atomic_numbers=graph.atomic_numbers,
            natoms=torch.tensor(len(graph.atomic_numbers)).view(1),
            node_attrs=graph.node_attrs,
            edge_index=graph.edge_index,
            shifts=graph.shifts,
            unit_shifts=graph.unit_shifts,
        )
        if no_targets:
            return data

        energy = torch.tensor(
            self.ase_atoms[idx].get_potential_energy(apply_constraint=False)
        )
        if self.split == "train":
            energy = energy - self.energy_shifts[idx]

        target = Target(
            energy=energy,
            # NOTE: `torch.Tensor(...)` (not `torch.tensor`) is kept on purpose:
            # it yields float32 forces from float64 numpy arrays, which is the
            # historical behavior downstream code relies on.
            forces=torch.Tensor(self.ase_atoms[idx].get_forces(apply_constraint=False)),
        )
        return data, target
class SevennAtomsDataset(BaseAtomsDataset):
    """Dataset that converts ASE atoms into SevenNet ``AtomGraphData`` graphs.

    The cutoff radius is either supplied explicitly or read from a SevenNet
    backbone checkpoint. For the "md" split no reference targets are read.
    """

    def __init__(
        self,
        data_path: str | Path | None,
        split: str,
        num_random_subsamples: int | None = None,
        subsample_rng: int | None = None,
        gnn_backbone_id: str | torch.nn.Module | None = None,
        cutoff: float = 6.0,
        precompute=True,
    ):
        super().__init__(data_path, split, num_random_subsamples, subsample_rng)
        if gnn_backbone_id is not None:
            cutoff = self.load_info_from_gnn_config(gnn_backbone_id)
        elif cutoff is None:
            # FIX: was `assert cutoff is not None` — stripped under `python -O`
            # and unreachable with the 6.0 default; validate explicitly so an
            # explicit `cutoff=None` fails loudly.
            raise ValueError("Either `gnn_backbone_id` or `cutoff` must be provided.")

        if split == "md":
            # MD snapshots carry no reference energies/forces, so do not read
            # targets from the ASE calculator.
            self.a2g = SevennAtomsToGraphs(
                cutoff, transfer_info=False, y_from_calc=False
            )
        else:
            self.a2g = SevennAtomsToGraphs(
                cutoff, transfer_info=False, y_from_calc=True
            )
        self.graphs = None
        if precompute and len(self.ase_atoms) > 0:
            self.graphs = self.a2g.convert_all(
                self.ase_atoms,
                split_name=self.split,
            )

    def load_info_from_gnn_config(self, gnn_backbone_id: str | torch.nn.Module):
        """Read the graph cutoff radius from a SevenNet checkpoint.

        Raises:
            ValueError: if a preloaded model is passed instead of a backbone ID.
        """
        if not isinstance(gnn_backbone_id, str):
            raise ValueError(
                "Backbone path must be provided instead of the preloaded model."
            )
        ckpt_path = get_checkpoint_path(gnn_backbone_id)
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        config = checkpoint["config"]
        cutoff = config["cutoff"]
        del checkpoint, config  # the checkpoint is only needed for its metadata
        return cutoff

    def __getitem__(self, idx, no_targets: bool = False):
        """Returns an array of (inputs, outputs) with inputs being a configuration
        and outputs being the target (energy and forces).
        Note: ONLY for the 'train' split, the energy_shift is removed from the target.
        """
        if self.graphs is None:
            graph = self.a2g.convert(self.ase_atoms[idx])
        else:
            graph = self.graphs[idx]

        data = Configuration(
            atom_pos=graph.pos,
            atomic_numbers=graph[KEY.ATOMIC_NUMBERS],
            natoms=torch.tensor(len(graph[KEY.ATOMIC_NUMBERS])).view(1),
            edge_index=graph.edge_index,
            shifts=graph[KEY.CELL_SHIFT],
            cell=graph[KEY.CELL],
            batch_ids=(
                graph.batch
                if graph.batch is not None
                else torch.zeros(graph[KEY.NUM_ATOMS], dtype=torch.int64)
            ),
        )
        if no_targets:
            return data

        energy = graph[KEY.ENERGY]
        if self.split == "train":
            energy = energy - self.energy_shifts[idx]
        target = Target(
            energy=energy,
            forces=graph[KEY.FORCE],
        )
        return data, target
@DATASET_REGISTRY.register("PtH2O")
class PtH2ORegisteredDataset(BaseRegisteredDataset):
    """Registered Pt/H2O interface dataset with fixed train/val/test splits."""

    relative_paths = {
        "PtH2O": {
            "train": "PtH2O/train.extxyz",
            "val": "PtH2O/valid.extxyz",
            "test": "PtH2O/test.extxyz",
        },
    }

    @classmethod
    def get_path(
        cls, name: str, split: str, base_path: Path | None, download: bool = True
    ):
        """Resolve the file for a split, downloading the dataset if missing."""
        if base_path is None:
            raise KeyError(None)
        target = base_path / cls.relative_paths[name][split]
        if not target.is_file() and download:
            cls.download(base_path)
        if not target.is_file():
            raise ValueError(f"Dataset not found at '{target.resolve()}'")
        return target

    @classmethod
    def download(cls, base_path: Path):
        """Fetch the raw trajectory archive and write shuffled splits to disk."""
        root = base_path / "PtH2O"
        root.mkdir(exist_ok=True, parents=True)
        archive = root / "data.zip"
        download_file(
            url="https://data.dtu.dk/ndownloader/files/29141586",
            filename=archive,
            desc="Downloading PtH2O dataset",
            expected_md5="acd748f7f32c66961c90cb15457f7bae",
        )
        with zipfile.ZipFile(archive, "r") as zf:
            zf.extractall(root)
        frames = ase.io.read(
            root / "Dataset_and_training_files" / "dataset.traj", index=":"
        )
        assert isinstance(frames, list)
        # Fixed seed + legacy global shuffle: the resulting permutation (and
        # therefore the published splits) must stay reproducible bit-for-bit.
        np.random.seed(42)
        np.random.shuffle(frames)
        splits = {
            "train.extxyz": frames[:30_000],
            "valid.extxyz": frames[30_000:31_000],
            "test.extxyz": frames[31_000:],
        }
        for fname, traj in splits.items():
            ase.io.write(root / fname, traj)
        # Remove intermediate artifacts.
        archive.unlink()
        shutil.rmtree(root / "Dataset_and_training_files")


if __name__ == "__main__":
    PtH2ORegisteredDataset.download(Path(__file__).parent.parent)
73 | for el in TM23_ELEMENTS 74 | ], 75 | {}, 76 | ) # merge list of dicts 77 | 78 | @classmethod 79 | def get_path( 80 | cls, name: str, split: str, base_path: Path | None, download: bool = True 81 | ): 82 | if base_path is None: 83 | raise KeyError(None) 84 | relative_path = cls.relative_paths[name][split] 85 | path = base_path / relative_path 86 | if not path.is_file() and download: 87 | cls.download(base_path) 88 | if path.is_file(): 89 | return path 90 | else: 91 | raise ValueError(f"Dataset not found at '{path.resolve()}'") 92 | 93 | @classmethod 94 | def download(cls, base_path: Path): 95 | tm23_base_path = base_path / "TM23" 96 | tm23_base_path.mkdir(exist_ok=True, parents=True) 97 | # Download 98 | download_file( 99 | url="https://archive.materialscloud.org/record/file?record_id=2113&filename=benchmarking_master_collection-20240316T202423Z-001.zip", 100 | filename=tm23_base_path / "data.zip", 101 | desc="Downloading TM23 dataset", 102 | ) 103 | # Extract 104 | with zipfile.ZipFile(tm23_base_path / "data.zip", "r") as zf: 105 | zf.extractall(tm23_base_path) 106 | # Move files up one level 107 | for origin in (tm23_base_path / "benchmarking_master_collection").glob("*"): 108 | origin.rename(tm23_base_path / origin.name) 109 | # Cleanup 110 | (tm23_base_path / "data.zip").unlink() 111 | (tm23_base_path / "benchmarking_master_collection").rmdir() 112 | -------------------------------------------------------------------------------- /franken/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from franken.datasets.registry import DATASET_REGISTRY 2 | 3 | # Ensure all sub-datasets are imported so that they are registered. 
4 | from .water import water_dataset # noqa: F401 5 | from .TM23 import tm23_dataset # noqa: F401 6 | from .PtH2O import pth2o_dataset # noqa: F401 7 | from .test import test_dataset # noqa: F401 8 | 9 | __all__ = ("DATASET_REGISTRY",) 10 | -------------------------------------------------------------------------------- /franken/datasets/registry.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, ClassVar 2 | from pathlib import Path 3 | 4 | 5 | class BaseRegisteredDataset: 6 | relative_paths: ClassVar[dict[str, dict[str, str]]] 7 | 8 | @classmethod 9 | def get_path( 10 | cls, name: str, split: str, base_path: Path | None, download: bool = True 11 | ) -> Path: 12 | raise NotImplementedError() 13 | 14 | @classmethod 15 | def is_valid_split(cls, name: str, split: str) -> bool: 16 | return split in cls.relative_paths[name] 17 | 18 | 19 | _KT = str 20 | _VT = type[BaseRegisteredDataset] 21 | 22 | 23 | class DatasetRegistry(dict[_KT, _VT]): 24 | def register(self, name: _KT | list[_KT] | tuple[_KT]) -> Callable[[_VT], _VT]: 25 | def decorator(func: _VT) -> _VT: 26 | if isinstance(name, (list, tuple)): 27 | for name_single in name: 28 | self[name_single] = func 29 | else: 30 | self[name] = func 31 | return func 32 | 33 | return decorator 34 | 35 | def get_path( 36 | self, name: str, split: str, base_path: Path | None, download: bool = True 37 | ): 38 | """Fetch the path for a dataset-split pair. If the dataset does not exist under 39 | the `base_path` directory, a download will be attempted. 40 | 41 | Args: 42 | name (str): dataset name (e.g. "water", "TM23/Ag-cold", "PtH2O") 43 | split (str): data-split, for example "train", "val" or "test" 44 | base_path (Path): the base path at which the dataset is stored. 45 | download (bool, optional): Whether to download the dataset if it does not exist. 46 | Defaults to True. 47 | 48 | Returns: 49 | dset_path (Path): a path to the ase-readable dataset. 
50 | """ 51 | return self[name].get_path(name, split, base_path, download) 52 | 53 | def is_valid_split(self, name: str, split: str) -> bool: 54 | return self[name].is_valid_split(name, split) 55 | 56 | 57 | DATASET_REGISTRY = DatasetRegistry() 58 | -------------------------------------------------------------------------------- /franken/datasets/split_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | from ase.io import read, write 4 | import os 5 | 6 | 7 | def split_trajectory( 8 | file_path, split_ratio=0.8, seed=None, train_output=None, val_output=None 9 | ): 10 | # Set the random seed if provided, for reproducibility 11 | if seed is not None: 12 | random.seed(seed) 13 | 14 | # Load the frames from the trajectory file 15 | frames = read(file_path, index=":") 16 | num_frames = len(frames) 17 | print(f"Loaded {num_frames} frames from '{file_path}'.") 18 | 19 | # Shuffle and split frames based on the split ratio 20 | indices = list(range(num_frames)) 21 | random.shuffle(indices) 22 | 23 | # Calculate the split index 24 | split_index = int(num_frames * split_ratio) 25 | train_indices = indices[:split_index] 26 | val_indices = indices[split_index:] 27 | 28 | # Create train and validation splits 29 | train_frames = [frames[i] for i in train_indices] 30 | val_frames = [frames[i] for i in val_indices] 31 | 32 | # Set default output filenames if not provided 33 | if train_output is None: 34 | train_output = f"{os.path.splitext(file_path)[0]}_train.xyz" 35 | if val_output is None: 36 | val_output = f"{os.path.splitext(file_path)[0]}_val.xyz" 37 | 38 | # Write the split trajectories to separate files 39 | write(train_output, train_frames) 40 | write(val_output, val_frames) 41 | 42 | print( 43 | f"Saved {len(train_frames)} frames to '{train_output}' and {len(val_frames)} frames to '{val_output}'." 
def main():
    """CLI entry point: parse arguments and split a trajectory file."""
    parser = argparse.ArgumentParser(
        description="Split an ASE trajectory file into train and validation sets in a reproducible way."
    )

    # Mandatory argument
    parser.add_argument(
        "file_path", type=str, help="Path to the input .xyz trajectory file."
    )

    # Optional arguments, declared as (flag, type, default, help) rows.
    optional_args = [
        ("--seed", int, None, "Random seed for reproducibility (default: None)."),
        (
            "--split_ratio",
            float,
            0.8,
            "Ratio of train to validation split (default: 0.8 for 80%% train).",
        ),
        (
            "--train_output",
            str,
            None,
            "Output filename for the train set (default: 'input_train.xyz').",
        ),
        (
            "--val_output",
            str,
            None,
            "Output filename for the validation set (default: 'input_val.xyz').",
        ),
    ]
    for flag, arg_type, default, help_text in optional_args:
        parser.add_argument(flag, type=arg_type, default=default, help=help_text)

    args = parser.parse_args()

    # Reject degenerate splits before touching the file.
    if not 0 < args.split_ratio < 1:
        parser.error("split_ratio must be between 0 and 1 (exclusive).")

    split_trajectory(
        args.file_path, args.split_ratio, args.seed, args.train_output, args.val_output
    )


if __name__ == "__main__":
    main()
@DATASET_REGISTRY.register("test")
class TestRegisteredDataset(BaseRegisteredDataset):
    """Tiny dataset bundled with the package, used by the unit tests."""

    relative_paths = {
        "test": {
            "train": "test/train.xyz",
            "val": "test/validation.xyz",
            "test": "test/test.xyz",
            "md": "test/md.xyz",
            "long": "test/long.xyz",
        },
    }

    @classmethod
    def get_path(
        cls, name: str, split: str, base_path: Path | None, download: bool = True
    ):
        """Locate the bundled file; `base_path` and `download` are ignored."""
        rel = cls.relative_paths[name][split]
        candidate = resources.files(franken.datasets) / rel
        if not candidate.is_file():
            raise ValueError(f"Dataset not found at '{candidate}'")
        return candidate
-------------------------------------------------------------------------------- 1 | 2 2 | Lattice="12.025693879326065 0.0 0.0 6.012846939663032 10.414551541843355 0.0 0.0 0.0 16.24012235072364" Properties=species:S:1:pos:R:3:initial_charges:R:1:forces:R:3:magmoms:R:1 energy=-34663.781510561785 free_energy=-34663.781510561785 pbc="T T T" 3 | H 4.00860872 4.62866793 -2.52042670 8.14689100 0.69448942 0.05905515 0.53533993 0.00000000 4 | N 5.18398320 4.85600646 2.93821253 4.99690200 0.11953213 -0.02882078 0.69749169 0.00000000 5 | 2 6 | Lattice="12.025693879326065 0.0 0.0 6.012846939663032 10.414551541843355 0.0 0.0 0.0 16.24012235072364" Properties=species:S:1:pos:R:3:initial_charges:R:1:forces:R:3:magmoms:R:1 energy=-34662.89388569647 free_energy=-34662.89388569647 pbc="T T T" 7 | H 4.00860872 4.62866793 -2.52042670 8.21123300 -0.14498065 0.29821454 0.47052756 0.00000000 8 | N 0.67900074 -1.16300245 4.59493225 4.97292000 -0.40408203 -0.50703723 -3.20458904 0.00000000 -------------------------------------------------------------------------------- /franken/datasets/test/validation.xyz: -------------------------------------------------------------------------------- 1 | 2 2 | Lattice="12.025693879326065 0.0 0.0 6.012846939663032 10.414551541843355 0.0 0.0 0.0 16.24012235072364" Properties=species:S:1:pos:R:3:initial_charges:R:1:forces:R:3:magmoms:R:1 energy=-34663.925543954065 free_energy=-34663.925543954065 pbc="T T T" 3 | Fe 4.00860872 4.62866793 -2.52042670 8.20589500 0.60784890 -0.92761725 -0.03625873 0.00000000 4 | N 6.20080853 -2.74071456 0.44857041 6.15661500 -0.24211410 -0.27163268 -2.25028837 0.00000000 5 | 2 6 | Lattice="12.025693879326065 0.0 0.0 6.012846939663032 10.414551541843355 0.0 0.0 0.0 16.24012235072364" Properties=species:S:1:pos:R:3:initial_charges:R:1:forces:R:3:magmoms:R:1 energy=-34663.41248242709 free_energy=-34663.41248242709 pbc="T T T" 7 | Fe 4.00860872 4.62866793 -2.52042670 8.18704900 1.03936665 -0.29763682 -0.34057038 0.00000000 8 | N 
2.70448956 -0.60938640 0.28451228 5.91890700 -1.07381044 -0.26788530 0.93481454 0.00000000 -------------------------------------------------------------------------------- /franken/datasets/water/HH_digitizer.csv: -------------------------------------------------------------------------------- 1 | radius,num_atoms 2 | 1.3109243697478992, 0 3 | 1.4243697478991597, 0 4 | 1.53781512605042, 0 5 | 1.7394957983193278, 0.03409090909090906 6 | 1.8529411764705883, 0.04166666666666663 7 | 1.972689075630252, 0.19318181818181818 8 | 2.092436974789916, 0.5189393939393939 9 | 2.130252100840336, 0.7272727272727273 10 | 2.2058823529411766, 0.9507575757575757 11 | 2.32563025210084, 1.2803030303030303 12 | 2.439075630252101, 1.375 13 | 2.5588235294117645, 1.2424242424242424 14 | 2.672268907563025, 1.0492424242424243 15 | 2.792016806722689, 0.875 16 | 2.911764705882353, 0.7651515151515151 17 | 3.0252100840336134, 0.7272727272727273 18 | 3.1449579831932772, 0.7386363636363636 19 | 3.258403361344538, 0.7840909090909092 20 | 3.384453781512605, 0.8598484848484849 21 | 3.4978991596638656, 0.9583333333333334 22 | 3.6176470588235294, 1.0606060606060606 23 | 3.73109243697479, 1.1401515151515151 24 | 3.850840336134454, 1.1818181818181819 25 | 3.9705882352941178, 1.1666666666666667 26 | 4.084033613445378, 1.1325757575757576 27 | 4.203781512605042, 1.0909090909090908 28 | 4.317226890756302, 1.0568181818181819 29 | 4.436974789915967, 1.0416666666666665 30 | 4.55672268907563, 1.0340909090909092 31 | 4.670168067226891, 1.0265151515151514 32 | 4.7899159663865545, 1.018939393939394 33 | 4.9033613445378155, 1.003787878787879 34 | 5.023109243697479, 0.9886363636363636 35 | 5.142857142857142, 0.9772727272727273 36 | 5.256302521008403, 0.9734848484848485 37 | 5.376050420168067, 0.9621212121212122 38 | 5.489495798319328, 0.9621212121212122 39 | -------------------------------------------------------------------------------- /franken/datasets/water/OH_digitizer.csv: 
-------------------------------------------------------------------------------- 1 | radius,num_atoms 2 | 1.3109243697478992, 0 3 | 1.4243697478991597, 0 4 | 1.5441176470588236, 0.049242424242424254 5 | 1.657563025210084, 0.3560606060606061 6 | 1.7773109243697478, 0.8598484848484849 7 | 1.8970588235294117, 1.0946969696969697 8 | 2.0105042016806722, 0.9659090909090909 9 | 2.130252100840336, 0.7272727272727273 10 | 2.2436974789915967, 0.49242424242424243 11 | 2.3634453781512605, 0.34090909090909094 12 | 2.476890756302521, 0.26136363636363635 13 | 2.596638655462185, 0.2537878787878788 14 | 2.716386554621849, 0.3106060606060606 15 | 2.8361344537815127, 0.45454545454545453 16 | 2.9495798319327733, 0.7083333333333334 17 | 3.069327731092437, 1.0416666666666665 18 | 3.189075630252101, 1.3409090909090908 19 | 3.302521008403361, 1.496212121212121 20 | 3.4159663865546217, 1.4848484848484849 21 | 3.5357142857142856, 1.371212121212121 22 | 3.6554621848739495, 1.2348484848484849 23 | 3.76890756302521, 1.128787878787879 24 | 3.888655462184874, 1.0568181818181819 25 | 4.008403361344538, 1.0151515151515151 26 | 4.241596638655462, 0.9962121212121212 27 | 4.3613445378151265, 0.9848484848484849 28 | 4.474789915966387, 0.9772727272727273 29 | 4.588235294117647, 0.9696969696969697 30 | 4.707983193277311, 0.9734848484848485 31 | 4.8277310924369745, 0.9772727272727273 32 | 4.9411764705882355, 0.9886363636363636 33 | 5.0609243697479, 0.9962121212121212 34 | 5.180672268907563, 0.9962121212121212 35 | 5.300420168067227, 1.003787878787879 36 | 5.413865546218487, 1.003787878787879 37 | -------------------------------------------------------------------------------- /franken/datasets/water/OO_digitizer.csv: -------------------------------------------------------------------------------- 1 | radius,num_atoms 2 | 1.30373831775701,0.0 3 | 1.42056074766355,0.0 4 | 1.54672897196262,0.0 5 | 1.66355140186916,0 6 | 1.77570093457944,-0.0 7 | 1.88785046728972,0.00 8 | 2.00934579439252,-0.0 9 | 
@DATASET_REGISTRY.register("water")
class WaterRegisteredDataset(BaseRegisteredDataset):
    """VASP ML_AB water dataset (Zenodo record 10723405), converted to extxyz."""

    relative_paths = {
        "water": {
            "train": "water/ML_AB_dataset_1.xyz",
            "val": "water/ML_AB_dataset_2-val.xyz",
        },
    }
    # Member names inside the downloaded zip archive.
    zip_file_names = ["ML_AB_dataset_1", "ML_AB_dataset_2", "ML_AB_128h2o_validation"]

    @classmethod
    def get_path(
        cls, name: str, split: str, base_path: Path | None, download: bool = True
    ):
        """Resolve the file for a split, downloading the dataset if missing."""
        if base_path is None:
            raise KeyError(None)
        target = base_path / cls.relative_paths[name][split]
        if not target.is_file() and download:
            cls.download(base_path)
        if not target.is_file():
            raise ValueError(f"Dataset not found at '{target.resolve()}'")
        return target

    @classmethod
    def download(cls, base_path: Path):
        """Download, convert (VASP ML_AB -> extxyz) and split the water dataset."""
        root = base_path / "water"
        root.mkdir(exist_ok=True, parents=True)
        archive = root / "data.zip"

        # NOTE: cannot check MD5 here since it changes at every download. As a dumb fallback we check the file-size.
        download_file(
            url="https://zenodo.org/api/records/10723405/files-archive",
            filename=archive,
            desc="Downloading water dataset",
            expected_size=35866571,
        )
        # Extract each member and convert VASP ML_AB text to extended XYZ.
        with zipfile.ZipFile(archive, mode="r") as zf:
            for member in cls.zip_file_names:
                with zf.open(member, "r") as fh:
                    vasp_data = fh.read().decode("utf-8")
                xyz_data = vasp_mlff_to_xyz(vasp_data)
                with open(root / f"{member}.xyz", "w") as fh:
                    fh.write(xyz_data)
                # Sanity check: every frame must expose energy and forces.
                traj = ase.io.read(root / f"{member}.xyz", index=":")
                assert isinstance(traj, list)
                for atoms in traj:
                    atoms.get_potential_energy()
                    atoms.get_forces()
        # Dataset 2 overlaps dataset 1 in its first 473 frames; keep the rest
        # as a validation set.
        dataset = ase.io.read(
            root / "ML_AB_dataset_2.xyz", index=":", format="extxyz"
        )
        assert isinstance(dataset, list)
        ase.io.write(root / "ML_AB_dataset_2-val.xyz", dataset[473:])
        # Cleanup
        archive.unlink()
def vasp_mlff_to_xyz_oneconfig(data):
    """Convert one VASP ML_AB configuration block to extended-XYZ text.

    Args:
        data: text of a single configuration section of an ML_AB file.

    Returns:
        Extended-XYZ string: atom count, a comment line carrying lattice,
        energy and stress, then one ``species x y z fx fy fz`` line per atom.

    Raises:
        AttributeError: when an expected section header is missing
            (``re.search`` returns ``None``) — callers rely on this.
    """

    def grab(pattern):
        # First capture group of `pattern`; AttributeError if the section
        # is absent (deliberate: the caller filters such configurations).
        return re.search(pattern, data).group(1)

    def table(pattern):
        # A captured numeric table as a list of per-line token lists.
        return [row.split() for row in grab(pattern).strip().split("\n")]

    num_atoms = int(grab(r"The number of atoms\s*[-=]+\s*(\d+)"))
    energy = float(grab(r"Total energy \(eV\)\s*[-=]+\s*([-+]?\d*\.\d+|\d+)"))

    # Lattice rows flattened into a single whitespace-joined string.
    cell_rows = table(r"Primitive lattice vectors \(ang.\)\s*[-=]+\s*([\d\s.-]+)")
    lattice_flat = " ".join(" ".join(row) for row in cell_rows)

    positions = table(r"Atomic positions \(ang.\)\s*[-=]+\s*([\d\s.-]+)")
    forces = table(r"Forces \(eV ang.\^-1\)\s*[-=]+\s*([\d\s.-]+)")

    # Stress comes as two three-value rows (XX YY ZZ, then XY YZ ZX);
    # take only the first three tokens of each to skip separator noise.
    xx, yy, zz = grab(
        r"Stress \(kbar\)\s*[-=]+\s*XX YY ZZ\s*[-=]+\s*([\d\s.-]+)"
    ).strip().split()[:3]
    xy, yz, zx = grab(r"XY YZ ZX\s*[-=]+\s*([\d\s.-]+)").strip().split()[:3]
    # Full symmetric 3x3 stress tensor in row-major order.
    stress_tensor = f"{xx} {xy} {zx} {xy} {yy} {yz} {zx} {yz} {zz}"

    # Expand "element count" pairs into one species label per atom,
    # in the same order as the positions table.
    species = []
    for row in grab(r"Atom types and atom numbers\s*[-=]+\s*([\w\s\d]+)").strip().split(
        "\n"
    ):
        element, count = row.split()
        species.extend([element] * int(count))

    out_lines = [
        f"{num_atoms}",
        f'Lattice="{lattice_flat}" Properties=species:S:1:pos:R:3:forces:R:3 energy={energy} stress="{stress_tensor}"',
    ]
    for idx, (pos, frc) in enumerate(zip(positions, forces)):
        px, py, pz = pos
        fx, fy, fz = frc
        out_lines.append(f"{species[idx]} {px} {py} {pz} {fx} {fy} {fz}")

    return "\n".join(out_lines)
166 | 167 | # Join all configurations with a newline 168 | return "\n".join(xyz_all) 169 | 170 | 171 | if __name__ == "__main__": 172 | WaterRegisteredDataset.download(Path(__file__).parent.parent) 173 | -------------------------------------------------------------------------------- /franken/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from franken.metrics.base import BaseMetric 4 | from franken.metrics.functions import * # noqa: F403 5 | from franken.metrics.registry import registry 6 | 7 | 8 | __all__ = ["registry"] 9 | 10 | 11 | def available_metrics() -> list[str]: 12 | return registry.available_metrics 13 | 14 | 15 | def register(name: str, metric_class: type) -> None: 16 | registry.register(name, metric_class) 17 | 18 | 19 | def init_metric( 20 | name: str, device: torch.device, dtype: torch.dtype = torch.float32 21 | ) -> BaseMetric: 22 | return registry.init_metric(name, device, dtype) 23 | -------------------------------------------------------------------------------- /franken/metrics/base.py: -------------------------------------------------------------------------------- 1 | from typing import Mapping 2 | import torch 3 | 4 | import franken.utils.distributed as dist_utils 5 | from franken.data.base import Target 6 | 7 | 8 | class BaseMetric: 9 | def __init__( 10 | self, 11 | name: str, 12 | device: torch.device, 13 | dtype: torch.dtype = torch.float64, 14 | units: Mapping[str, str | None] = {}, 15 | ): 16 | self.name = name 17 | self.device = device 18 | self.dtype = dtype 19 | self.buffer = None 20 | self.samples_counter = torch.zeros((1,), device=device, dtype=dtype) 21 | self.units = units 22 | 23 | def reset(self) -> None: 24 | """Reset the buffer to zeros""" 25 | self.buffer = None 26 | self.samples_counter = torch.zeros((1,), device=self.device, dtype=torch.int64) 27 | 28 | def buffer_add(self, value: torch.Tensor, num_samples: int = 1) -> None: 29 | if self.buffer 
is None: 30 | self.buffer = torch.zeros(value.shape, device=self.device, dtype=self.dtype) 31 | else: 32 | assert self.buffer.shape == value.shape 33 | self.buffer += value 34 | self.samples_counter += num_samples 35 | 36 | def update( 37 | self, 38 | predictions: Target, 39 | targets: Target, 40 | ) -> None: 41 | """Update the metric buffer with new batch results""" 42 | raise NotImplementedError() 43 | 44 | def compute(self, reset: bool = True) -> torch.Tensor: 45 | if self.buffer is None: 46 | raise ValueError( 47 | f"Cannot compute value for metric '{self.name}' " 48 | "because it was never updated." 49 | ) 50 | dist_utils.all_sum(self.buffer) 51 | dist_utils.all_sum(self.samples_counter) 52 | error = self.buffer / self.samples_counter 53 | if reset: 54 | self.reset() 55 | return error 56 | -------------------------------------------------------------------------------- /franken/metrics/functions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from franken.data.base import Target 5 | from franken.metrics.base import BaseMetric 6 | from franken.metrics.registry import registry 7 | from franken.utils import distributed 8 | 9 | 10 | __all__ = [ 11 | "EnergyMAE", 12 | "EnergyRMSE", 13 | "ForcesMAE", 14 | "ForcesRMSE", 15 | "ForcesCosineSimilarity", 16 | "is_pareto_efficient", 17 | ] 18 | 19 | 20 | class EnergyMAE(BaseMetric): 21 | def __init__(self, device: torch.device, dtype: torch.dtype = torch.float32): 22 | units = { 23 | "inputs": "eV", 24 | "outputs": "meV/atom", 25 | } 26 | super().__init__("energy_MAE", device, dtype, units) 27 | 28 | def update(self, predictions: Target, targets: Target) -> None: 29 | if targets.forces is None: 30 | raise NotImplementedError( 31 | "At the moment, target's forces are required to get the number of atoms in the configuration." 
class EnergyRMSE(BaseMetric):
    """Energy root-mean-square error, reported in meV/atom (inputs in eV)."""

    def __init__(self, device: torch.device, dtype: torch.dtype = torch.float32):
        units = {
            "inputs": "eV",
            "outputs": "meV/atom",
        }
        super().__init__("energy_RMSE", device, dtype, units)

    def update(self, predictions: Target, targets: Target) -> None:
        """Accumulate the squared per-atom energy error for one batch."""
        if targets.forces is None:
            raise NotImplementedError(
                "At the moment, target's forces are required to get the number of atoms in the configuration."
            )
        # Atom count is read off the forces tensor's second-to-last dim.
        num_atoms = targets.forces.shape[-2]
        num_samples = 1
        if targets.energy.ndim > 0:
            num_samples = targets.energy.shape[0]

        error = torch.square((targets.energy - predictions.energy) / num_atoms)

        self.buffer_add(error, num_samples=num_samples)

    def compute(self, reset: bool = True) -> torch.Tensor:
        """Mean squared error -> RMSE in meV/atom.

        DRY fix: BaseMetric.compute already all-reduces, averages, raises the
        never-updated ValueError, and resets; only sqrt and the eV -> meV
        conversion are added here.
        """
        return torch.sqrt(super().compute(reset)) * 1000


class ForcesMAE(BaseMetric):
    """Force mean-absolute error, reported in meV/ang (inputs in eV/ang)."""

    def __init__(self, device: torch.device, dtype: torch.dtype = torch.float32):
        units = {
            "inputs": "eV/ang",
            "outputs": "meV/ang",
        }
        super().__init__("forces_MAE", device, dtype, units)

    def update(self, predictions: Target, targets: Target) -> None:
        """Accumulate the absolute force error, averaged over atoms and xyz."""
        if targets.forces is None or predictions.forces is None:
            raise AttributeError("Forces must be specified to compute the MAE.")
        num_samples = 1
        if targets.forces.ndim > 2:
            num_samples = targets.forces.shape[0]
        elif targets.forces.ndim < 2:
            raise ValueError("Forces must be a 2D tensor or a batch of 2D tensors.")

        error = 1000 * torch.abs(targets.forces - predictions.forces)
        error = error.mean(dim=(-1, -2))  # Average over atoms and components

        self.buffer_add(error, num_samples=num_samples)


class ForcesRMSE(BaseMetric):
    """Force root-mean-square error, reported in meV/ang (inputs in eV/ang)."""

    def __init__(self, device: torch.device, dtype: torch.dtype = torch.float32):
        units = {
            "inputs": "eV/ang",
            "outputs": "meV/ang",
        }
        super().__init__("forces_RMSE", device, dtype, units)

    def update(self, predictions: Target, targets: Target) -> None:
        """Accumulate the squared force error, averaged over atoms and xyz."""
        if targets.forces is None or predictions.forces is None:
            # FIX: message previously said "MAE" (copy-paste from ForcesMAE).
            raise AttributeError("Forces must be specified to compute the RMSE.")
        num_samples = 1
        if targets.forces.ndim > 2:
            num_samples = targets.forces.shape[0]
        elif targets.forces.ndim < 2:
            raise ValueError("Forces must be a 2D tensor or a batch of 2D tensors.")

        error = torch.square(targets.forces - predictions.forces)
        error = error.mean(dim=(-1, -2))  # Average over atoms and components

        self.buffer_add(error, num_samples=num_samples)

    def compute(self, reset: bool = True) -> torch.Tensor:
        """Mean squared error -> RMSE in meV/ang (see EnergyRMSE.compute)."""
        return torch.sqrt(super().compute(reset)) * 1000


class ForcesRMSE2(BaseMetric):
    """Average of RMSE along individual structures"""

    def __init__(self, device: torch.device, dtype: torch.dtype = torch.float32):
        units = {
            "inputs": "eV/ang",
            "outputs": "meV/ang",
        }
        # FIX: this metric registered as "forces_RMSE2" but named itself
        # "forces_RMSE", colliding with ForcesRMSE in logs.
        super().__init__("forces_RMSE2", device, dtype, units)

    def update(self, predictions: Target, targets: Target) -> None:
        """Accumulate the per-structure RMSE (sqrt taken before averaging)."""
        if targets.forces is None or predictions.forces is None:
            # FIX: message previously said "MAE" (copy-paste from ForcesMAE).
            raise AttributeError("Forces must be specified to compute the RMSE.")
        num_samples = 1
        if targets.forces.ndim > 2:
            num_samples = targets.forces.shape[0]
        elif targets.forces.ndim < 2:
            raise ValueError("Forces must be a 2D tensor or a batch of 2D tensors.")

        error = torch.square(targets.forces - predictions.forces)
        error = error.mean(dim=(-1, -2))  # Average over atoms and components
        # Unlike ForcesRMSE, sqrt is applied per structure *before* the
        # cross-structure average done by BaseMetric.compute.
        error = torch.sqrt(error) * 1000
        self.buffer_add(error, num_samples=num_samples)
def is_pareto_efficient(costs):
    """
    Find the pareto-efficient points
    :param costs: An (n_points, n_costs) array
    :return: A (n_points, ) boolean array, indicating whether each point is Pareto efficient
    """
    n_points = costs.shape[0]
    keep = np.ones(n_points, dtype=bool)
    for idx in range(n_points):
        if not keep[idx]:
            continue
        candidate = costs[idx]
        # Among the still-kept points, retain only those that beat the
        # candidate on at least one objective (points weakly dominated by
        # the candidate are dropped) ...
        keep[keep] = np.any(costs[keep] < candidate, axis=1)
        # ... and always keep the candidate itself.
        keep[idx] = True
    return keep
init_metric( 20 | self, name: str, device: torch.device, dtype: torch.dtype = torch.float32 21 | ) -> BaseMetric: 22 | """Create a new instance of a metric""" 23 | if name not in self._metrics: 24 | raise KeyError( 25 | f"Metric '{name}' not found. Available metrics: {list(self._metrics.keys())}" 26 | ) 27 | return self._metrics[name](device=device, dtype=dtype) 28 | 29 | @property 30 | def available_metrics(self) -> list[str]: 31 | return list(self._metrics.keys()) 32 | 33 | 34 | registry = MetricRegistry() 35 | -------------------------------------------------------------------------------- /franken/rf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSML-IIT-UCL/franken/497d6ad4ee63b46d2528bcc38ab9b7581c4c6255/franken/rf/__init__.py -------------------------------------------------------------------------------- /franken/rf/atomic_energies.py: -------------------------------------------------------------------------------- 1 | from typing import Mapping 2 | 3 | import torch 4 | 5 | 6 | class AtomicEnergiesShift(torch.nn.Module): 7 | atomic_energies: torch.Tensor 8 | Z_keys: list[int] 9 | 10 | def __init__( 11 | self, 12 | num_species: int, 13 | atomic_energies: Mapping[int, torch.Tensor | float] | None = None, 14 | ): 15 | """ 16 | Initialize the AtomicEnergiesShift module. 17 | 18 | Args: 19 | num_species: 20 | atomic_energies: A dictionary mapping atomic numbers to atomic energies. 
21 | """ 22 | super().__init__() 23 | 24 | self.num_species = num_species 25 | self.register_buffer("atomic_energies", torch.zeros(num_species)) 26 | self.register_buffer( 27 | "z_keys", torch.zeros((self.num_species,), dtype=torch.long) 28 | ) # placeholder 29 | self.is_initialized = False 30 | 31 | if atomic_energies is not None: 32 | self.set_from_atomic_energies(atomic_energies) 33 | 34 | def set_from_atomic_energies( 35 | self, atomic_energies: Mapping[int, torch.Tensor | float] 36 | ): 37 | assert ( 38 | len(atomic_energies) == self.num_species 39 | ), f"{len(atomic_energies)=} != {self.num_species=}" 40 | device = self.atomic_energies.device 41 | self.atomic_energies = torch.stack( 42 | [ 43 | v.clone().detach() if isinstance(v, torch.Tensor) else torch.tensor(v) 44 | for v in atomic_energies.values() 45 | ] 46 | ).to(device) 47 | self.z_keys = torch.tensor( 48 | list(atomic_energies.keys()), 49 | dtype=torch.long, 50 | device=self.atomic_energies.device, 51 | ) 52 | self.is_initialized = True 53 | 54 | def forward(self, atomic_numbers: torch.Tensor) -> torch.Tensor: 55 | """ 56 | Calculate the energy shift for a given set of atomic numbers. 57 | 58 | Args: 59 | atomic_numbers: A tensor containing atomic numbers for which to calculate the energy shift. 60 | 61 | Returns: 62 | A tensor representing the total energy shift for the provided atomic numbers. 
63 | """ 64 | 65 | shift = torch.tensor( 66 | 0.0, dtype=self.atomic_energies.dtype, device=self.atomic_energies.device 67 | ) 68 | 69 | for z, atom_ene in zip(self.z_keys, self.atomic_energies): 70 | mask = atomic_numbers == int(z.item()) 71 | shift += torch.sum(atom_ene * mask) 72 | 73 | return shift 74 | 75 | def __repr__(self): 76 | formatted_energies = " , ".join( 77 | [ 78 | f"{z.item()}: {atom_ene}" 79 | for z, atom_ene in zip(self.z_keys, self.atomic_energies) 80 | ] 81 | ) 82 | return f"{self.__class__.__name__}({formatted_energies})" 83 | -------------------------------------------------------------------------------- /franken/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | """Train franken from data.""" 2 | 3 | from franken.trainers.base import BaseTrainer 4 | from franken.trainers.rf_cuda_lowmem import RandomFeaturesTrainer 5 | 6 | __all__ = ( 7 | "BaseTrainer", 8 | "RandomFeaturesTrainer", 9 | ) 10 | -------------------------------------------------------------------------------- /franken/trainers/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import json 3 | import logging 4 | from pathlib import Path 5 | from typing import Tuple, Union 6 | 7 | import torch 8 | import torch.utils.data 9 | 10 | from franken.config import asdict_with_classvar 11 | from franken.rf.model import FrankenPotential 12 | from franken.rf.scaler import Statistics, compute_dataset_statistics 13 | from franken.trainers.log_utils import ( 14 | DataSplit, 15 | LogCollection, 16 | LogEntry, 17 | dtypeJSONEncoder, 18 | ) 19 | from franken.utils.misc import are_dicts_equal 20 | 21 | 22 | logger = logging.getLogger("franken") 23 | 24 | 25 | class BaseTrainer(abc.ABC): 26 | """Base trainer class. 
class BaseTrainer(abc.ABC):
    """Base trainer class. Requires :meth:`~BaseTrainer.fit` and :meth:`~BaseTrainer.evaluate` methods."""

    def __init__(
        self,
        train_dataloader: torch.utils.data.DataLoader,
        log_dir: Path | None = None,  # If None, logging is disabled
        save_every_model: bool = True,
        device: Union[torch.device, str, int] = "cpu",
        dtype: Union[str, torch.dtype] = torch.float32,
    ):
        """
        Args:
            train_dataloader: dataloader over the training configurations.
            log_dir: directory for logs/checkpoints; None disables logging.
            save_every_model: if True, every fitted model is checkpointed.
            device: device used for fitting.
            dtype: buffer dtype, as a torch.dtype or one of the strings
                'float64'/'double' and 'float32'/'float'/'single'.
        """
        self.log_dir = log_dir
        self.save_every_model = save_every_model
        self.train_dataloader = train_dataloader
        self.statistics_ = None
        if isinstance(dtype, str):
            # FIX: the original compared `dtype.lower == "double"` and
            # `dtype.lower == "float"` — the *bound method* against a string,
            # always False — so "double"/"float" were silently rejected, and
            # 'single' was advertised in the error but never accepted.
            normalized = dtype.lower()
            if normalized in ("float64", "double"):
                dtype = torch.float64
            elif normalized in ("float32", "float", "single"):
                dtype = torch.float32
            else:
                raise ValueError(
                    f"Invalid dtype {dtype}. Allowed values are 'float64', 'double', 'float32', 'float', 'single'."
                )
        if dtype not in {torch.float32, torch.float64}:
            raise ValueError(
                f"Invalid dtype {dtype}. torch.float32 or torch.float64 are allowed."
            )
        self.buffer_dt = dtype
        self.device = torch.device(device)

    @torch.no_grad()
    def get_statistics(self, model: FrankenPotential) -> Tuple[Statistics, dict]:
        """Compute statistics on the training dataset with the provided model

        Args:
            model (FrankenPotential): Franken model from which the attached GNN
                is used to compute the features on atomic configurations.

        Returns:
            A tuple containing an object of type :class:`franken.rf.scaler.Statistics` containing
            the dataset statistics, and a dictionary containing the GNN-backbone hyperparameters
            used when computing dataset features.
        """
        # Cached: recompute only when the GNN configuration changed since the
        # last call (the config dict is part of the cache key).
        if self.statistics_ is None or not are_dicts_equal(
            self.statistics_[1], asdict_with_classvar(model.gnn_config)
        ):
            stat = compute_dataset_statistics(
                dataset=self.train_dataloader.dataset,  # type: ignore
                gnn=model.gnn,
                device=self.device,
            )
            stat_dict = asdict_with_classvar(model.gnn_config)
            self.statistics_ = (stat, stat_dict)

        return self.statistics_

    @abc.abstractmethod
    def fit(
        self,
        model: FrankenPotential,
        solver_params: dict,
    ) -> tuple[LogCollection, torch.Tensor]:
        """Fit a given franken model on the training set.

        Args:
            model (FrankenPotential): The model which defines GNN and random features.
            solver_params (dict): Parameters for the solver which actually
                performs the fit.

        Returns:
            tuple[LogCollection, torch.Tensor]:
                - Logs which contain all parameters related to the fitting, as well as timings.
                - Weights which were learned during the fit.
        """
        pass

    @abc.abstractmethod
    def evaluate(
        self,
        model: FrankenPotential,
        dataloader: torch.utils.data.DataLoader,
        log_collection: LogCollection,
        all_weights: torch.Tensor,
        metrics: list[str],
    ) -> LogCollection:
        """Evaluate a fitted model by computing metrics on a validation dataset.

        Args:
            model: The model which defines GNN and random features.
            dataloader (torch.utils.data.DataLoader): Evaluation will run the model
                on each configuration in the dataloader, computing averaged metrics.
            log_collection: Log object as output by the :meth:`fit`
                method. Metric values will be added to the logs and the same object will
                be returned by this method.
            all_weights (torch.Tensor): The weights as output by the :meth:`fit` method.
            metrics (list[str]): List of metrics which should be computed.

        Returns:
            logs (LogCollection): Logs which contain all parameters related
                to the fitting, as well as timings and metrics.
        """
        pass

    def serialize_logs(
        self,
        model: FrankenPotential,
        log_collection: LogCollection,
        all_weights: torch.Tensor,
        best_model_split: DataSplit = DataSplit.TRAIN,
    ):
        """Persist logs and checkpoints, then record the best model."""
        assert self.log_dir is not None, "Log directory is not set"
        # All entries must refer to the same model checkpoint.
        model_hash_set = set(log.checkpoint_hash for log in log_collection)
        assert len(model_hash_set) == 1
        model_hash = model_hash_set.pop()
        log_collection.save_json(self.log_dir / "log.json")

        # Save the model checkpoint
        if self.save_every_model:
            ckpt_dir = self.log_dir / "checkpoints"
            ckpt_dir.mkdir(parents=True, exist_ok=True)
            model_save_path = ckpt_dir / f"{model_hash}.pt"
            model.save(model_save_path, multi_weights=all_weights)
            logger.debug(
                f"Saved multiple models (hash={model_hash}) " f"to {model_save_path}"
            )
        # Log the best model
        self.serialize_best_model(model, all_weights, split=best_model_split)

    def serialize_best_model(
        self,
        model: FrankenPotential,
        all_weights: torch.Tensor,
        split: DataSplit = DataSplit.TRAIN,
    ) -> None:
        """Save the best model (per `split` metrics) unless already saved."""
        assert self.log_dir is not None, "Log directory is not set"
        log_collection = LogCollection.from_json(self.log_dir / "log.json")
        best_model = log_collection.get_best_model(split=split)

        # Skip re-saving when the stored best entry is already this one.
        best_model_file = self.log_dir / "best.json"
        should_save = True
        if best_model_file.exists():
            with open(best_model_file, "r") as f:
                current_best = LogEntry.from_dict(json.load(f))
            if best_model == current_best:
                should_save = False

        if should_save:
            logger.debug(f"Identified new best model: {best_model}")
            with open(best_model_file, "w") as f:
                json.dump(best_model.to_dict(), f, indent=4, cls=dtypeJSONEncoder)
            weights = all_weights[best_model.checkpoint_rf_weight_id]
            model.rf.weights = weights.reshape_as(model.rf.weights)
            model.save(self.log_dir / "best_ckpt.pt")
            logger.debug(
                f"Saved best model (within-experiment ID={best_model.checkpoint_rf_weight_id}) "
                f"to {self.log_dir / 'best_ckpt.pt'}"
            )
def slurm_to_env():
    """Translate SLURM job variables into torch.distributed rendezvous vars."""
    # The first node of the allocation acts as the rendezvous master.
    master = hostlist.expand_hostlist(os.environ["SLURM_JOB_NODELIST"])[0]
    os.environ["MASTER_ADDR"] = master
    os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "33633")
    if "SLURM_NTASKS" in os.environ:
        n_tasks = int(os.environ["SLURM_NTASKS"])
    else:
        # Fall back to tasks-per-node x nodes when SLURM_NTASKS is absent.
        n_tasks = int(os.environ["SLURM_NTASKS_PER_NODE"]) * int(
            os.environ["SLURM_NNODES"]
        )
    os.environ["WORLD_SIZE"] = str(n_tasks)
    os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"]
    os.environ["RANK"] = os.environ["SLURM_PROCID"]


def is_torchrun():
    """True when launched via torchrun (which exports RANK and WORLD_SIZE)."""
    env = os.environ
    return "RANK" in env and "WORLD_SIZE" in env


def is_slurm():
    """True when running inside a SLURM task."""
    return "SLURM_PROCID" in os.environ


def init(distributed: bool) -> int:
    """Optionally initialize NCCL distributed training; return the global rank.

    NOTE(review): the CUDA device is selected even when `distributed` is
    False, which presumes a CUDA-capable host — confirm for CPU-only runs.
    """
    if distributed:
        if not torch.cuda.is_available():
            raise RuntimeError("Distributed training is only supported on CUDA")
        if not is_torchrun():
            if is_slurm():
                slurm_to_env()
            else:
                warnings.warn(
                    "Cannot initialize distributed training. "
                    "Neither torchrun nor SLURM environment variable were found."
                )
        world_size = int(os.environ.get("WORLD_SIZE", 1))
        if world_size > 1:
            print(
                f"Distributed initialization at rank {os.environ['RANK']} of {world_size} "
                f"(rank {os.environ['LOCAL_RANK']} on {gethostname()} with "
                f"{torch.cuda.device_count()} GPUs allocated)."
            )
            torch.distributed.init_process_group(
                backend="nccl",
                device_id=torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}"),
            )

    device = f"cuda:{get_local_rank()}"
    torch.cuda.set_device(device)
    return get_rank()


def get_local_rank() -> int:
    """Per-node rank; 0 outside distributed runs."""
    if not torch.distributed.is_initialized():
        return 0
    return int(os.environ["LOCAL_RANK"])


def get_rank() -> int:
    """Global rank; 0 outside distributed runs."""
    if not torch.distributed.is_initialized():
        return 0
    return torch.distributed.get_rank()


def barrier() -> None:
    """Synchronize all ranks; no-op outside distributed runs."""
    if torch.distributed.is_initialized():
        torch.distributed.barrier()


def get_world_size() -> int:
    """Number of ranks; 1 outside distributed runs."""
    if not torch.distributed.is_initialized():
        return 1
    return torch.distributed.get_world_size()


def all_reduce(tensor: torch.Tensor, op) -> None:
    """In-place all-reduce with ReduceOp `op`; no-op when not distributed."""
    if not torch.distributed.is_initialized():
        return None
    torch.distributed.all_reduce(tensor, op)
    return None


def all_sum(tensor: torch.Tensor) -> None:
    """In-place SUM all-reduce; no-op when not distributed."""
    if not torch.distributed.is_initialized():
        return None
    torch.distributed.all_reduce(tensor, torch.distributed.ReduceOp.SUM)
    return None


def broadcast_obj(obj, src=0):
    """Broadcast a picklable object from rank `src`; identity when not distributed."""
    if not torch.distributed.is_initialized():
        return obj
    box = [obj]
    torch.distributed.broadcast_object_list(box, src=src)
    return box[0]


def all_gather_into_tensor(
    out_size: Union[Tuple, torch.Size], in_tensor: torch.Tensor
) -> torch.Tensor:
    """Gather same-shaped shards into one tensor of `out_size`; identity when not distributed."""
    if not torch.distributed.is_initialized():
        return in_tensor
    gathered = torch.zeros(out_size, dtype=in_tensor.dtype, device=in_tensor.device)
    torch.distributed.all_gather_into_tensor(gathered, in_tensor)
    return gathered


def all_gather(tensor: torch.Tensor) -> List[torch.Tensor]:
    """Gather possibly differently-shaped tensors from every rank."""
    if not torch.distributed.is_initialized():
        return [tensor]
    world = get_world_size()
    me = get_rank()
    # Exchange shapes first so every rank can allocate its receive buffers.
    shapes = []
    for rank in range(world):
        local_shape = tensor.shape if rank == me else None
        shapes.append(broadcast_obj(local_shape, src=rank))
    gathered = []
    for rank in range(world):
        if rank == me:
            gathered.append(tensor)
        else:
            gathered.append(
                torch.empty(shapes[rank], device=tensor.device, dtype=tensor.dtype)
            )
    torch.distributed.all_gather(gathered, tensor)
    return gathered


def all_gather_object(obj) -> List:
    """Gather a picklable object from every rank; [obj] when not distributed."""
    if not torch.distributed.is_initialized():
        return [obj]
    gathered = [None for _ in range(get_world_size())]
    torch.distributed.all_gather_object(gathered, obj)
    return gathered


def print0(*args, **kwargs):
    """`print`, but only on rank 0."""
    if get_rank() != 0:
        return
    print(*args, **kwargs)
def jacfwd(
    func: Callable,
    argnums: argnums_t = 0,
    has_aux: bool = False,
    *,
    randomness: str = "error",
    chunk_size: Optional[int] = None,
):
    """Drop-in replacement of ``torch.func.jacfwd`` accepting the ``chunk_size``
    argument (as with ``jacrev``).

    Forward-mode Jacobian of ``func`` w.r.t. the arguments selected by
    ``argnums``: a JVP is pushed through ``func`` for every standard-basis
    vector of the flattened inputs, vmapped in chunks of ``chunk_size`` basis
    vectors to bound peak memory. The implementation mirrors
    ``torch._functorch.eager_transforms.jacfwd``, only adding ``chunk_size``.

    Args:
        func: function to differentiate.
        argnums: positional argument index (or indices) to differentiate
            with respect to.
        has_aux: if True, ``func`` returns a ``(output, aux)`` pair.
        randomness: forwarded to ``vmap``.
        chunk_size: number of basis vectors per vmap chunk; ``None`` processes
            all of them at once.
    """
    def wrapper_fn(*args):
        error_if_complex("jacfwd", args, is_input=True)
        # Select the primals being differentiated and flatten their pytree.
        primals = args if argnums is None else _slice_argnums(args, argnums)
        flat_primals, primals_spec = tree_flatten(primals)
        flat_primals_numels = tuple(p.numel() for p in flat_primals)
        # One basis tangent per input element: vmapping the JVP over the full
        # standard basis yields the complete Jacobian.
        flat_basis = _construct_standard_basis_for(flat_primals, flat_primals_numels)
        basis = tree_unflatten(flat_basis, primals_spec)

        def push_jvp(basis):
            output = _jvp_with_argnums(
                func, args, basis, argnums=argnums, has_aux=has_aux
            )
            # output[0] is the output of `func(*args)`
            error_if_complex("jacfwd", output[0], is_input=False)
            if has_aux:
                _, jvp_out, aux = output
                return jvp_out, aux
            _, jvp_out = output
            return jvp_out

        results = vmap(push_jvp, randomness=randomness, chunk_size=chunk_size)(basis)
        if has_aux:
            results, aux = results
            # aux is in the standard basis format, e.g. NxN matrix
            # We need to fetch the first element as original `func` output
            flat_aux, aux_spec = tree_flatten(aux)
            flat_aux = [value[0] for value in flat_aux]
            aux = tree_unflatten(flat_aux, aux_spec)

        jac_outs, spec = tree_flatten(results)
        # Most probably below output check can never raise an error
        # as jvp should test the output before
        # assert_non_empty_output(jac_outs, 'jacfwd(f, ...)(*args)')

        # Move the vmapped basis dimension last, split it per input tensor and
        # restore each primal's original shape on the trailing axes.
        jac_outs_ins = tuple(
            tuple(
                safe_unflatten(jac_out_in, -1, primal.shape)
                for primal, jac_out_in in zip(
                    flat_primals,
                    jac_out.movedim(0, -1).split(flat_primals_numels, dim=-1),
                )
            )
            for jac_out in jac_outs
        )
        jac_outs_ins = tuple(
            tree_unflatten(jac_ins, primals_spec) for jac_ins in jac_outs_ins
        )

        if isinstance(argnums, int):
            # A single argnum unwraps the per-input tuple, matching torch.func.
            jac_outs_ins = tuple(jac_ins[0] for jac_ins in jac_outs_ins)
        if has_aux:
            return tree_unflatten(jac_outs_ins, spec), aux
        return tree_unflatten(jac_outs_ins, spec)

    # Dynamo does not support HOP composition if their inner function is
    # annotated with @functools.wraps(...). We circumvent this issue by applying
    # wraps only if we're not tracing with dynamo.
    if not torch._dynamo.is_compiling():
        wrapper_fn = wraps(func)(wrapper_fn)

    return wrapper_fn
100 | jac_kwargs.pop("chunk_size") 101 | except KeyError: 102 | pass 103 | 104 | # Initially we just double in size until an OOM is encountered 105 | new_size, _ = _adjust_batch_size( 106 | test_sample, init_val, value=init_val, **jac_kwargs 107 | ) # initially set to init_val 108 | if mode == "power": 109 | new_size = _run_power_scaling(new_size, max_trials, test_sample, **jac_kwargs) 110 | else: 111 | raise ValueError("mode in method `scale_batch_size` can only be `power`") 112 | 113 | garbage_collection_cuda() 114 | return new_size 115 | 116 | 117 | def _run_power_scaling(new_size, max_trials, test_sample, **jac_kwargs) -> int: 118 | """Batch scaling mode where the size is doubled at each iteration until an 119 | OOM error is encountered.""" 120 | for _ in range(max_trials): 121 | garbage_collection_cuda() 122 | try: 123 | # Try jacfwd 124 | for _ in range(1): 125 | jacfwd(**jac_kwargs, chunk_size=new_size)(*test_sample) 126 | # Double in size 127 | new_size, changed = _adjust_batch_size( 128 | test_sample, new_size, factor=2.0, **jac_kwargs 129 | ) 130 | except RuntimeError as exception: 131 | # Only these errors should trigger an adjustment 132 | if is_cuda_out_of_memory(exception): 133 | # If we fail in power mode, half the size and return 134 | garbage_collection_cuda() 135 | new_size, _ = _adjust_batch_size( 136 | test_sample, new_size, factor=0.5, **jac_kwargs 137 | ) 138 | break 139 | else: 140 | raise # some other error not memory related 141 | if not changed: 142 | # No change in batch size, so we can exit. 
143 | break 144 | return new_size 145 | 146 | 147 | def _adjust_batch_size( 148 | test_sample: Sequence[Union[torch.Tensor, Any]], 149 | batch_size: int, 150 | factor: float = 1.0, 151 | value: Optional[int] = None, 152 | **jac_kwargs, 153 | ) -> Tuple[int, bool]: 154 | max_batch_size = _get_max_batch_size(test_sample, **jac_kwargs) 155 | new_size = value if value is not None else int(batch_size * factor) 156 | new_size = min(new_size, max_batch_size) 157 | changed = new_size != batch_size 158 | return new_size, changed 159 | 160 | 161 | def _get_max_batch_size( 162 | test_sample: Sequence[Union[torch.Tensor, Any]], **jac_kwargs 163 | ) -> int: 164 | argnums = jac_kwargs.get("argnums", 0) 165 | if isinstance(argnums, int): 166 | argnums = [argnums] 167 | batch_size = 0 168 | for argnum in argnums: 169 | arg = test_sample[argnum] 170 | assert isinstance(arg, torch.Tensor) 171 | batch_size += arg.numel() 172 | return batch_size 173 | -------------------------------------------------------------------------------- /franken/utils/linalg/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSML-IIT-UCL/franken/497d6ad4ee63b46d2528bcc38ab9b7581c4c6255/franken/utils/linalg/__init__.py -------------------------------------------------------------------------------- /franken/utils/linalg/psdsolve.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | 5 | 6 | try: 7 | import cupy.cuda 8 | from cupy_backends.cuda.libs import cublas, cusolver 9 | except ImportError: 10 | cupy = None 11 | cusolver = None 12 | cublas = None 13 | 14 | 15 | def psd_ridge(cov: torch.Tensor, rhs: torch.Tensor, penalty: float) -> torch.Tensor: 16 | """Solve ridge regression via Cholesky factorization, overwriting :attr:`cov` and :attr:`rhs`. 17 | 18 | Multiple right-hand sides are supported. 
Instead of providing the data 19 | matrix (commonly :math:`X` in ridge-regression notation), and labels (commonly :math:`y`), 20 | we are given directly :math:`\text{cov} = X^T X` and :math:`\text{rhs} = X^T y`. 21 | Since :attr:`cov` is symmetric only its **upper triangle** will be accessed. 22 | 23 | To limit memory usage, the :attr:`cov` matrix **may be overwritten**, and :math:`rhs` 24 | may also be overwritten (depending on its memory layout). 25 | 26 | Args: 27 | cov (Tensor): covariance of the linear system 28 | rhs (Tensor): right hand side (one or more) of the linear system 29 | penalty (float): Tikhonov l2 penalty 30 | 31 | Returns: 32 | solution (Tensor): the ridge regression coefficients 33 | """ 34 | if cupy is not None and cov.device.type == "cuda": 35 | return _lowmem_psd_ridge(cov, rhs, penalty) 36 | else: 37 | # NOTE: this should be a warnings.warn NOT logger.warning - otherwise 38 | # it gets printed a lot of times and is just annoying. We could add 39 | # https://docs.python.org/library/logging.html#logging.captureWarnings 40 | # to the logger to capture warnings automatically. 41 | if cov.device.type == "cuda": 42 | warnings.warn( 43 | "low-memory solver cannot be used because `cupy` is not available. " 44 | "Install `cupy` if you encounter memory problems." 
def _lowmem_psd_ridge(
    cov: torch.Tensor, rhs: torch.Tensor, penalty: float
) -> torch.Tensor:
    """Solve the ridge system with cuSOLVER potrf/potrs, minimizing allocations.

    Equivalent to ``_naive_psd_ridge`` but calls cuSOLVER directly through
    cupy so the Cholesky factorization and the triangular solves happen
    in-place on ``cov`` and ``rhs`` (both may be overwritten).

    Args:
        cov (Tensor): symmetric (n, n) CUDA tensor; only one triangle is read.
        rhs (Tensor): right hand side(s); reshaped internally to (n, -1).
        penalty (float): Tikhonov l2 penalty added to the diagonal of ``cov``.

    Returns:
        Tensor: the solution, with the same shape as ``rhs``.

    Raises:
        ValueError: if ``cov`` is neither float32 nor float64.
        torch.linalg.LinAlgError: if cuSOLVER reports a failure (e.g. the
            penalized matrix is not positive definite).
    """
    assert cusolver is not None and cublas is not None and cupy is not None
    assert cov.device.type == "cuda"
    dtype = cov.dtype
    n = cov.shape[0]

    # Add diagonal without copies
    cov.diagonal().add_(penalty)

    # Pick the cuSOLVER entry points matching the tensor precision.
    if dtype == torch.float32:
        potrf = cusolver.spotrf
        potrf_bufferSize = cusolver.spotrf_bufferSize
        potrs = cusolver.spotrs
    elif dtype == torch.float64:
        potrf = cusolver.dpotrf
        potrf_bufferSize = cusolver.dpotrf_bufferSize
        potrs = cusolver.dpotrs
    else:
        raise ValueError(dtype)

    # cov must be f-contiguous (column-contiguous, stride is (1, n)).
    # Transposing instead of copying flips which triangle cuSOLVER reads.
    assert cov.dim() == 2
    assert cov.shape[0] == cov.shape[1]
    transpose = False
    if n != 1:
        if cov.stride(0) != 1:
            cov = cov.T
            transpose = True
        assert cov.stride(0) == 1
    cov_cp = cupy.asarray(cov)

    # save rhs shape to restore it later on.
    rhs_shape = rhs.shape
    rhs = rhs.reshape(n, -1)
    n_rhs = rhs.shape[1]
    if rhs.stride(0) != 1:  # force rhs to be f-contiguous
        # `contiguous` causes a copy
        rhs = rhs.T.contiguous().T
    assert rhs.stride(0) == 1
    rhs_cp = cupy.asarray(rhs)

    handle = cupy.cuda.device.get_cusolver_handle()
    uplo = cublas.CUBLAS_FILL_MODE_LOWER if transpose else cublas.CUBLAS_FILL_MODE_UPPER
    dev_info = torch.empty(
        1, dtype=torch.int32
    )  # don't allocate with cupy as it uses a separate mem pool
    # NOTE(review): dev_info/workspace are created on torch's default device;
    # if that is CPU, cupy.asarray makes a device-side copy from cupy's own
    # pool, defeating the comment above — confirm intended device placement.
    dev_info_cp = cupy.asarray(dev_info)

    worksize = potrf_bufferSize(handle, uplo, n, cov_cp.data.ptr, n)
    workspace = torch.empty(worksize, dtype=dtype)
    workspace_cp = cupy.asarray(workspace)

    # Cholesky factorization (in-place on cov_cp)
    potrf(
        handle,
        uplo,
        n,
        cov_cp.data.ptr,
        n,
        workspace_cp.data.ptr,
        worksize,
        dev_info_cp.data.ptr,
    )
    if (dev_info_cp != 0).any():
        raise torch.linalg.LinAlgError(
            f"Error reported by {potrf.__name__} in cuSOLVER. devInfo = {dev_info_cp}."
        )

    # Solve: A * X = B (in-place on rhs_cp)
    potrs(
        handle,
        uplo,
        n,
        n_rhs,
        cov_cp.data.ptr,
        n,
        rhs_cp.data.ptr,
        n,
        dev_info_cp.data.ptr,
    )
    if (dev_info_cp != 0).any():
        # Bug fix: report the routine that actually failed (potrs, not potrf).
        raise torch.linalg.LinAlgError(
            f"Error reported by {potrs.__name__} in cuSOLVER. devInfo = {dev_info_cp}."
        )

    return torch.as_tensor(rhs).reshape(rhs_shape)
63 | "Sphinx", 64 | "sphinxawesome-theme", 65 | "sphinxcontrib-applehelp", 66 | "sphinxcontrib-devhelp", 67 | "sphinxcontrib-htmlhelp", 68 | "sphinxcontrib-jsmath", 69 | "sphinxcontrib-qthelp", 70 | "sphinxcontrib-serializinghtml", 71 | "sphinx-argparse", 72 | "myst-parser", 73 | "nbsphinx", 74 | ] 75 | 76 | [build-system] 77 | requires = ["hatchling"] 78 | build-backend = "hatchling.build" 79 | 80 | [tool.hatch.version] 81 | path = "franken/__init__.py" 82 | 83 | [tool.hatch.build.targets.sdist] 84 | only-include = ["franken", "tests"] 85 | 86 | [tool.hatch.build.targets.wheel] 87 | include = [ 88 | "franken/**/*.py", 89 | "franken/autotune/configs/**/*.yaml", 90 | "franken/mdgen/configs/**/*.yaml", 91 | "franken/backbones/registry.json", 92 | "franken/datasets/water/*.csv", 93 | "franken/datasets/test/*", 94 | ] 95 | exclude = [ 96 | "franken/datasets/ala3", 97 | "franken/datasets/chignolin", 98 | "franken/datasets/Cu-EMT", 99 | "franken/datasets/CuFormate", 100 | "franken/datasets/Fe_N2", 101 | "franken/datasets/Fe4N", 102 | "franken/datasets/FeBulk", 103 | "franken/datasets/LiPS", 104 | "franken/datasets/MD22", 105 | "franken/datasets/split_data.py", 106 | "franken/datasets/download_and_process_all.sh", 107 | "franken/datasets/readme", 108 | ] 109 | 110 | [tool.black] 111 | line-length = 88 112 | target-version = ['py310', 'py312'] 113 | force-exclude = '^/((?!franken/))' 114 | 115 | [tool.ruff] 116 | target-version = "py310" 117 | include = [ 118 | "pyproject.toml", 119 | "franken/**/*.py", 120 | ] 121 | extend-exclude = [ 122 | "franken/utils/hostlist.py", 123 | ] 124 | force-exclude = true 125 | 126 | [tool.ruff.lint] 127 | select = ["E4", "E7", "E9", "F", "W"] 128 | ignore = [ 129 | "E501", # Avoid enforcing line-length violations (`E501`) 130 | ] 131 | 132 | [tool.pytest.ini_options] 133 | testpaths = ["tests"] 134 | markers = [ 135 | "slow: marks tests as slow (deselect with '-m \"not slow\"')", 136 | ] 137 | 
def prepare_gnn_checkpoints():
    """Download every default GNN backbone and return the checkpoints folder."""
    for backbone_config in DEFAULT_GNN_CONFIGS:
        download_checkpoint(backbone_config.path_or_id)
    checkpoints_dir = CacheDir.get() / "gnn_checkpoints"
    return checkpoints_dir
@pytest.fixture(autouse=True)
def random_seed():
    """Reset every RNG (stdlib, numpy, torch) to a fixed seed before each test."""
    seed = 14
    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
@pytest.mark.parametrize("model_name", models)
def test_force_maps(model_name):
    """Smoke-test FrankenPotential.grad_feature_map for each registered backbone."""
    from franken.backbones.wrappers.common_patches import patch_e3nn
    patch_e3nn()
    registry_entry = REGISTRY[model_name]
    gnn_config = BackboneConfig.from_ckpt({
        "family": registry_entry["kind"],
        "path_or_id": model_name,
        "interaction_block": 2,
    })
    # Get a random data sample
    data_path = DATASET_REGISTRY.get_path("test", "train", None, False)
    dataset = BaseAtomsDataset.from_path(
        data_path=data_path,
        split="train",
        gnn_config=gnn_config,
    )
    device="cuda:0" if torch.cuda.is_available() else "cpu"
    # initialize model with a small Gaussian random-feature head
    model = FrankenPotential(
        gnn_config=gnn_config,
        rf_config=GaussianRFConfig(num_random_features=128, length_scale=1.0),
    )
    model = model.to(device)
    data, _ = dataset[0]  # type: ignore
    data = data.to(device)
    # NOTE(review): results are unused — this only checks that the call does
    # not raise; consider asserting on emap/fmap shapes or finiteness.
    emap, fmap = model.grad_feature_map(data)
def test_cache_dir_default(mock_cache_folder):
    """CacheDir.initialize must fall back to ``<home>/.franken`` when the
    FRANKEN_CACHE_DIR environment variable is not set."""
    # Ensure no environment variable is set
    with patch.dict(os.environ, {}, clear=True):
        # Mock the home path and the Path.exists method
        with patch("pathlib.Path.home", return_value=mock_cache_folder):
            with patch("pathlib.Path.exists", return_value=True) as mock_exists:
                # Call the function
                franken.backbones.utils.CacheDir.initialize()
                result = franken.backbones.utils.CacheDir.get()

                # Check the default path is returned
                assert result == mock_cache_folder / ".franken"
                # Ensure that the path existence was consulted exactly once
                mock_exists.assert_called_once()
def test_download_checkpoint_not_implemented(mock_registry):
    """download_checkpoint must raise NotImplementedError for registry entries
    whose ``implemented`` flag is False."""
    registry_patch = patch(
        "franken.backbones.utils.load_model_registry", return_value=mock_registry
    )
    # The mocked registry only contains an unimplemented model entry.
    with registry_patch, pytest.raises(NotImplementedError) as excinfo:
        franken.backbones.utils.download_checkpoint("UNIMPLEMENTED_MODEL")
    assert "The model UNIMPLEMENTED_MODEL is not implemented" in str(excinfo.value)
def test_get_checkpoint_path_invalid_backbone(mock_registry):
    """An unknown backbone id must raise FileNotFoundError whose message also
    lists the available backbones."""
    with patch(
        "franken.backbones.utils.load_model_registry", return_value=mock_registry
    ):
        with patch(
            "franken.backbones.utils.make_summary", return_value="available backbones"
        ):
            with pytest.raises(FileNotFoundError) as raised:
                franken.backbones.utils.get_checkpoint_path("invalid_backbone")

            message = str(raised.value)
            assert "GNN Backbone path 'invalid_backbone' does not exist." in message
            assert "available backbones" in message
class ThrowingProcess(Process):
    """A ``Process`` whose child-side exceptions are visible to the parent.

    The child sends either ``None`` (success) or the raised exception object
    through a pipe; the parent reads it lazily via the ``exception`` property.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Parent keeps one end of the pipe, the child writes to the other.
        self._pconn, self._cconn = Pipe()
        self._exception = None

    def run(self):
        try:
            super().run()
        except Exception as err:
            # Forward the failure to the parent, then re-raise so the child
            # exits with a failing status.
            self._cconn.send(err)
            raise err
        else:
            self._cconn.send(None)

    @property
    def exception(self):
        # Only hit the pipe when a message is pending; cache the last result.
        if self._pconn.poll():
            self._exception = self._pconn.recv()
        return self._exception
def test_distributed_dataloader_order():
    """Each rank must see the strided subset rank, rank+ws, rank+2*ws, ... and,
    across all ranks, every sample must be visited exactly once."""
    num_samples = 7
    num_procs = 3
    # Cross-process queue collecting every dataset index a worker consumed.
    ids_queue = SimpleQueue()
    def inner_fn():
        data_path = DATASET_REGISTRY.get_path("test", "long", None, False)
        dataset = SimpleAtomsDataset(
            data_path,
            split="train",
            num_random_subsamples=num_samples,
            subsample_rng=None,
        )
        assert len(dataset) == num_samples
        dataloader = dataset.get_dataloader(True)
        rank = torch.distributed.get_rank()
        dl_elements = [el for el in dataloader]
        dl_id = 0
        # The distributed sampler is expected to assign indices round-robin:
        # rank r receives samples r, r+num_procs, r+2*num_procs, ...
        for i in range(rank, num_samples, num_procs):
            torch.testing.assert_close(
                dl_elements[dl_id][0].atom_pos, dataset[i][0].atom_pos
            )
            torch.testing.assert_close(
                dl_elements[dl_id][1].forces, dataset[i][1].forces
            )
            dl_id += 1
            ids_queue.put(i)
        # No extra elements beyond this rank's strided share.
        assert dl_id == len(dl_elements)
    init_distributed_cpu(num_procs, inner_fn)
    # Assert all IDs were processed - only once
    all_ids = []
    while not ids_queue.empty():
        all_ids.append(ids_queue.get())
    assert sorted(all_ids) == list(range(num_samples))
@pytest.mark.parametrize("rf_cfg", RF_PARAMETRIZE)
@pytest.mark.parametrize("device", DEVICES)
def test_lammps_compile(rf_cfg, device):
    """Round-trip a FrankenPotential through LammpsFrankenCalculator.create_lammps_model
    and check the TorchScript module's state matches the (double-precision) original."""
    unpatch_e3nn()  # needed in case some previous test ran the patching code
    gnn_cfg = MaceBackboneConfig("MACE-L0")
    temp_dir = None
    try:
        # Step 1: Create a temporary directory for saving the model
        temp_dir = create_temp_dir()

        data_path = DATASET_REGISTRY.get_path("test", "test", None, False)
        dataset = BaseAtomsDataset.from_path(
            data_path=data_path,
            split="train",
            gnn_config=gnn_cfg,
        )
        model = FrankenPotential(
            gnn_config=gnn_cfg,
            rf_config=rf_cfg,
            scale_by_Z=True,
            num_species=dataset.num_species,
        ).to(device)

        # Populate the input scaler from per-species feature statistics so the
        # saved model carries non-trivial state to compare after conversion.
        with torch.no_grad():
            gnn_features_stats = Statistics()
            for data, _ in dataset:  # type: ignore
                data = data.to(device=device)
                gnn_features = model.gnn.descriptors(data)
                gnn_features_stats.update(
                    gnn_features, atomic_numbers=data.atomic_numbers
                )

        model.input_scaler.set_from_statistics(gnn_features_stats)
        garbage_collection_cuda()

        # Step 2: Save the model to the temporary directory
        model_save_path = os.path.join(temp_dir, "model_checkpoint.pth")
        model.save(model_save_path)

        # Step 3: Run create_lammps_model
        comp_model_path = LammpsFrankenCalculator.create_lammps_model(model_path=model_save_path, rf_weight_id=None)

        # Step 4: Load saved model
        comp_model = torch.jit.load(comp_model_path, map_location=device)

        # Step 5: Compare rf.state_dict between the original and loaded models.
        # The LAMMPS model is stored in double precision, so the float32
        # comparison must fail with a dtype mismatch before .double() is applied.
        with pytest.raises(RuntimeError) as exc:
            assert are_dicts_close(
                model.rf.state_dict(), comp_model.model.rf.state_dict(), verbose=True
            )
        assert "Float did not match Double" in str(exc.value)
        assert are_dicts_close(
            model.rf.double().state_dict(), comp_model.model.rf.state_dict(), verbose=True
        ), "The rf.state_dict() of the loaded model does not match the original model."

        # Step 6: same dtype-then-value check for the input scaler
        with pytest.raises(RuntimeError) as exc:
            assert are_dicts_close(
                model.input_scaler.state_dict(),
                comp_model.model.input_scaler.state_dict(),
                verbose=True,
            )
        assert "Float did not match Double" in str(exc.value)
        assert are_dicts_close(
            model.input_scaler.double().state_dict(),
            comp_model.model.input_scaler.state_dict(),
            verbose=True,
        ), "The input_scaler.state_dict() of the loaded model does not match the original model."

        # Step 7: same dtype-then-value check for the energy shift
        with pytest.raises(RuntimeError) as exc:
            assert are_dicts_close(
                model.energy_shift.state_dict(),
                comp_model.model.energy_shift.state_dict(),
                verbose=True,
            )
        assert "Float did not match Double" in str(exc.value)
        assert are_dicts_close(
            model.energy_shift.double().state_dict(),
            comp_model.model.energy_shift.state_dict(),
            verbose=True,
        ), "The energy_shift.state_dict() of the loaded model does not match the original model."
    finally:
        if temp_dir is not None:
            cleanup_dir(temp_dir)
return TensorSketch(*args, **kwargs) 24 | elif rf_type == "gaussian": 25 | return OrthogonalRFF(*args, **kwargs) 26 | elif rf_type == "linear": 27 | return Linear(*args, **kwargs) 28 | elif rf_type == "biased-gaussian": 29 | return BiasedOrthogonalRFF(*args, **kwargs) 30 | elif rf_type == "multiscale-gaussian": 31 | return MultiScaleOrthogonalRFF(*args, **kwargs) 32 | else: 33 | raise ValueError(rf_type) 34 | 35 | 36 | @pytest.mark.parametrize("rf_type", RF_PARAMETRIZE) 37 | class TestDtype: 38 | @pytest.mark.parametrize("dt", [torch.float32, torch.float64]) 39 | def test_dtype_match(self, dt, rf_type): 40 | rf = init_rf( 41 | rf_type, 42 | input_dim=64, 43 | ) 44 | data = torch.randn(10, 64, dtype=dt) 45 | atomic_nums = torch.randint(1, 100, (10,)) 46 | fmap = rf.feature_map(data, atomic_numbers=atomic_nums) 47 | assert fmap.dtype == dt 48 | for buf_name, buf in rf.named_buffers(): 49 | if buf_name == "weights": 50 | # weights not touched by this test so they'll be f32 51 | assert ( 52 | buf.dtype == torch.get_default_dtype() 53 | ), f"weights has unexpected type {buf.dtype}" 54 | elif buf.numel() > 1 and buf.dtype.is_floating_point: 55 | assert buf.dtype == dt, f"Buffer {buf_name} has incorrect dtype." 
56 | 57 | 58 | class TestFeatureSizes: 59 | def test_orff_offset(self): 60 | rf_offset = init_rf( 61 | "gaussian", input_dim=32, use_offset=True, num_random_features=128 62 | ) 63 | rf_no_offset = init_rf( 64 | "gaussian", input_dim=32, use_offset=False, num_random_features=128 65 | ) 66 | assert rf_offset.num_random_features == 128 67 | assert rf_no_offset.num_random_features == 128 68 | assert rf_offset.total_random_features == 128 69 | assert rf_no_offset.total_random_features == 256 70 | assert rf_offset.rff_matrix.shape == (128, 32) 71 | assert rf_no_offset.rff_matrix.shape == (128, 32) 72 | assert rf_offset.random_offset.shape == (128,) 73 | 74 | @pytest.mark.parametrize("rf_type", ["poly", "gaussian"]) 75 | def test_per_species_kernel_nonlin1(self, rf_type): 76 | rf = init_rf( 77 | rf_type, 78 | input_dim=32, 79 | num_random_features=128, 80 | num_species=4, 81 | chemically_informed_ratio=None, 82 | ) 83 | assert rf.num_random_features == 128 84 | assert rf.total_random_features == 128 * 4 85 | 86 | @pytest.mark.parametrize("rf_type", ["poly", "gaussian"]) 87 | def test_per_species_kernel_nonlin2(self, rf_type): 88 | rf = init_rf( 89 | rf_type, 90 | input_dim=32, 91 | num_random_features=128, 92 | num_species=4, 93 | chemically_informed_ratio=0.4, 94 | ) 95 | assert rf.num_random_features == 128 96 | assert rf.total_random_features == 128 * (4 + 1) 97 | 98 | def test_per_species_kernel_lin1(self): 99 | rf = init_rf( 100 | "linear", input_dim=32, num_species=4, chemically_informed_ratio=None 101 | ) 102 | assert rf.num_random_features == 33 103 | assert rf.total_random_features == 33 * 4 104 | 105 | def test_per_species_kernel_lin2(self): 106 | rf = init_rf( 107 | "linear", input_dim=32, num_species=4, chemically_informed_ratio=0.4 108 | ) 109 | assert rf.num_random_features == 33 110 | assert rf.total_random_features == (33) * (4 + 1) 111 | 112 | 113 | class TestEdgeCaseInputs: 114 | @pytest.mark.parametrize("rf_type", RF_PARAMETRIZE) 115 | def 
test_zero_lengthscale(self, rf_type): 116 | with pytest.raises(ValueError): 117 | if rf_type != "multiscale-gaussian": 118 | init_rf(rf_type, input_dim=32, length_scale=0) 119 | else: 120 | init_rf(rf_type, input_dim=32, length_scale_low=0) 121 | 122 | @pytest.mark.parametrize("rf_type", RF_PARAMETRIZE) 123 | def test_negative_lengthscale(self, rf_type): 124 | with pytest.raises(ValueError): 125 | if rf_type != "multiscale-gaussian": 126 | init_rf(rf_type, input_dim=32, length_scale=-1.1) 127 | else: 128 | init_rf(rf_type, input_dim=32, length_scale_low=-1.1) 129 | 130 | 131 | if __name__ == "__main__": 132 | pytest.main() 133 | -------------------------------------------------------------------------------- /tests/test_trainers_log_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from franken.trainers.log_utils import HyperParameterGroup, LogEntry 4 | 5 | 6 | @pytest.fixture 7 | def dummy_log_dict(): 8 | return { 9 | "checkpoint": {"hash": "rand_uuid", "rf_weight_id": 0}, 10 | "timings": {"cov_coeffs": 1.0, "solve": 1.0}, 11 | "metrics": { 12 | "train": {"energy_MAE": 1.0, "forces_MAE": 1.0, "forces_cosim": 1.0}, 13 | "validation": {"energy_MAE": 1.0, "forces_MAE": 1.0, "forces_cosim": 1.0}, 14 | "test": {"energy_MAE": 1.0, "forces_MAE": 1.0, "forces_cosim": 1.0}, 15 | }, 16 | "hyperparameters": { 17 | "franken": { 18 | "gnn_backbone_id": "SchNet-S2EF-OC20-All", 19 | "interaction_block": 3, 20 | "kernel_type": "gaussian", 21 | }, 22 | "random_features": { 23 | "num_random_features": 1024, 24 | }, 25 | "input_scaler": {"scale_by_Z": True, "num_species": 2}, 26 | "solver": { 27 | "l2_penalty": 1e-6, 28 | "force_weight": 0.1, 29 | "dtype": "torch.float64", 30 | }, 31 | }, 32 | } 33 | 34 | 35 | def test_hpgroup_from_dict(): 36 | dummy_group_dict = { 37 | "str_param": "str_value", 38 | "int_param": 1, 39 | "float_param": 1.0, 40 | "bool_param": True, 41 | } 42 | 43 | hpg = 
HyperParameterGroup.from_dict("dummy_group", dummy_group_dict) 44 | assert hpg.group_name == "dummy_group" 45 | for hp in hpg.hyperparameters: 46 | assert hp.name in dummy_group_dict.keys() 47 | assert hp.value == dummy_group_dict[hp.name] 48 | 49 | 50 | def test_log_entry_serialize_deserialize(dummy_log_dict): 51 | log_entry = LogEntry.from_dict(dummy_log_dict) 52 | assert log_entry.to_dict() == dummy_log_dict 53 | 54 | 55 | def test_log_entry_get_metric(dummy_log_dict): 56 | log_entry = LogEntry.from_dict(dummy_log_dict) 57 | assert log_entry.get_metric("energy_MAE", "train") == 1.0 58 | 59 | 60 | def test_log_entry_get_invalid_metric_name(dummy_log_dict): 61 | log_entry = LogEntry.from_dict(dummy_log_dict) 62 | with pytest.raises(KeyError): 63 | log_entry.get_metric("invalid_metric", "train") 64 | 65 | 66 | def test_log_entry_get_invalid_metric_split(dummy_log_dict): 67 | log_entry = LogEntry.from_dict(dummy_log_dict) 68 | with pytest.raises(KeyError): 69 | log_entry.get_metric("energy_MAE", "invalid_split") 70 | 71 | 72 | # class TestBestModel: 73 | # def test_all_nans(self): 74 | # log_entries = [ 75 | # {"metrics": {"val": {"energy": torch.nan}}}, 76 | # {"metrics": {"val": {"energy": torch.nan}}}, 77 | # ] 78 | # expected_best_log = log_entries[0] 79 | # best_log = get_best_model(log_entries, ["energy"], split="val") 80 | # assert best_log == expected_best_log 81 | 82 | # def test_nans(self): 83 | # log_entries = [ 84 | # {"metrics": {"val": {"energy": torch.nan}}}, 85 | # {"metrics": {"val": {"energy": 0.1}}}, 86 | # {"metrics": {"val": {"energy": 12.0}}}, 87 | # ] 88 | # expected_best_log = log_entries[1] 89 | # best_log = get_best_model(log_entries, ["energy"], split="val") 90 | # assert best_log == expected_best_log 91 | # log_entries = [ 92 | # {"metrics": {"val": {"energy": 0.1}}}, 93 | # {"metrics": {"val": {"energy": torch.nan}}}, 94 | # {"metrics": {"val": {"energy": 12.0}}}, 95 | # ] 96 | # expected_best_log = log_entries[0] 97 | # best_log = 
get_best_model(log_entries, ["energy"], split="val") 98 | # assert best_log == expected_best_log 99 | 100 | # def test_stability(self): 101 | # log_entries = [ 102 | # {"metrics": {"val": {"energy": 1.0, "forces": 12}}}, 103 | # {"metrics": {"val": {"energy": 1.1, "forces": 11.9}}}, 104 | # {"metrics": {"val": {"energy": 1.2, "forces": 11.8}}}, 105 | # ] 106 | # expected_best_log = log_entries[0] 107 | # best_log = get_best_model(log_entries, ["energy", "forces"], split="val") 108 | # assert best_log == expected_best_log 109 | 110 | # def test_normal(self): 111 | # log_entries = [ 112 | # {"metrics": {"val": {"energy": 1.0, "forces": 12}}}, 113 | # {"metrics": {"val": {"energy": 0.9, "forces": 11.9}}}, 114 | # {"metrics": {"val": {"energy": 1.2, "forces": 11.8}}}, 115 | # ] 116 | # expected_best_log = log_entries[1] 117 | # best_log = get_best_model(log_entries, ["energy", "forces"], split="val") 118 | # assert best_log == expected_best_log 119 | 120 | # def test_missing_split(self): 121 | # log_entries = [ 122 | # {"metrics": {"val": {"energy": 1.0, "forces": 12}}}, 123 | # ] 124 | # with pytest.raises(KeyError): 125 | # get_best_model(log_entries, ["energy", "forces"], split="train") 126 | 127 | # def test_missing_metric(self): 128 | # log_entries = [ 129 | # {"metrics": {"val": {"energy": 1.0, "forces": 12}}}, 130 | # ] 131 | # with pytest.raises(KeyError): 132 | # get_best_model(log_entries, ["missing", "forces"], split="val") 133 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import tempfile 3 | from unittest.mock import MagicMock, patch 4 | 5 | import torch 6 | 7 | 8 | # Utility function to create a temporary directory 9 | def create_temp_dir() -> str: 10 | return tempfile.mkdtemp() 11 | 12 | 13 | # Utility function to clean up a directory 14 | def cleanup_dir(temp_dir: str): 15 | shutil.rmtree(temp_dir) 
def are_dicts_close(dict1, dict2, rtol=1e-4, atol=1e-6, verbose=False):
    """Recursively compare two (possibly nested) dicts of tensors.

    Returns True only when both arguments are dicts with identical key sets
    and every leaf pair is either a nested dict satisfying the same property
    or a pair of tensors equal within ``rtol``/``atol`` (via torch.allclose).
    Any non-dict/non-tensor leaf, key mismatch, or structural difference
    returns False. Note: torch.allclose raises RuntimeError on dtype
    mismatch — callers (e.g. the LAMMPS export test) rely on that.

    Args:
        dict1, dict2: objects to compare; non-dicts compare as False.
        rtol, atol: tolerances forwarded to torch.allclose.
        verbose: print a diagnostic for the first detected mismatch.
    """
    if not isinstance(dict1, dict) or not isinstance(dict2, dict):
        return False

    if set(dict1.keys()) != set(dict2.keys()):
        if verbose:
            print(f"Dictionaries have different keys: {set(dict1.keys())}, {set(dict2.keys())}")
        return False

    for key in dict1.keys():
        val1, val2 = dict1[key], dict2[key]
        if isinstance(val1, dict) and isinstance(val2, dict):
            # Bug fix: propagate ``verbose`` into the recursion so mismatches
            # inside nested dicts are reported too (it was dropped before).
            if not are_dicts_close(val1, val2, rtol, atol, verbose):
                return False
        elif isinstance(val1, torch.Tensor) and isinstance(val2, torch.Tensor):
            if not torch.allclose(val1, val2, rtol=rtol, atol=atol):
                if verbose:
                    print(f"{key} not equal:\n(1) {val1}\n(2) {val2}")
                return False
        else:
            # Mixed or unsupported leaf types (e.g. dict vs tensor, ints, ...).
            if verbose:
                print("The dictionaries have different topology")
            return False
    return True


def mocked_gnn(device, dtype, feature_dim: int = 32, backbone_id: str = "test"):
    """Build a patcher that replaces franken.rf.model.load_checkpoint with a mock GNN.

    The mock exposes ``feature_dim`` and a deterministic ``descriptors``
    function (sin of atom positions projected through a random matrix), so
    FrankenPotential can be exercised without downloading a real backbone.
    ``backbone_id`` is accepted for signature compatibility with real
    checkpoints; the mock itself does not read it.

    Returns:
        unittest.mock._patch: start/stop (or use as context manager) to
        activate the patch on ``franken.rf.model``.
    """
    # A bunch of code to initialize a mock for the GNN
    gnn = MagicMock()
    gnn.feature_dim = MagicMock(return_value=feature_dim)
    fake_gnn_weight = torch.randn(3, feature_dim, device=device, dtype=dtype)

    def mock_descriptors(data):
        # Deterministic w.r.t. atom positions given the fixed random weight.
        return torch.sin(data.atom_pos) @ fake_gnn_weight

    gnn.descriptors = mock_descriptors

    def load_checkpoint_patch(*args, **kwargs):
        # Record the kwargs the caller used so tests can inspect them.
        gnn.init_args = MagicMock(return_value=dict(kwargs))
        return gnn

    return patch.multiple("franken.rf.model", load_checkpoint=load_checkpoint_patch)