├── .github └── workflows │ ├── changelog.yaml │ ├── publish.yaml │ └── tests.yaml ├── .gitignore ├── .python-version ├── CHANGELOG.md ├── CITATION.cff ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── configs ├── kitchen-sink.yaml ├── minimal.yaml └── realistic.yaml ├── docs └── source │ ├── _static │ ├── Na.html │ ├── custom.css │ ├── lj-parity-prefit.svg │ ├── lj-parity-raw.svg │ ├── lj-parity-relative.svg │ ├── logo-square.svg │ ├── logo-text.svg │ └── water.html │ ├── building-blocks │ ├── aggregation.rst │ ├── bessel.svg │ ├── distances.rst │ ├── e3nn.rst │ ├── envelopes.rst │ ├── erbf.svg │ ├── gaussian.svg │ ├── nn.rst │ ├── root.rst │ ├── scaling.rst │ └── sin.svg │ ├── cli │ ├── graph-pes-id.rst │ ├── graph-pes-resume.rst │ ├── graph-pes-test.rst │ └── graph-pes-train │ │ ├── complete-docs.rst │ │ ├── examples.rst │ │ ├── root.rst │ │ └── the-basics.rst │ ├── conf.py │ ├── data │ ├── atomic_graph.rst │ ├── datasets.rst │ ├── loader.rst │ └── root.rst │ ├── development.rst │ ├── fitting │ ├── callbacks.rst │ ├── losses.rst │ ├── optimizers.rst │ └── root.rst │ ├── hide-title.html │ ├── index.rst │ ├── interfaces │ ├── mace.rst │ ├── mattersim.rst │ └── orb.rst │ ├── models │ ├── addition.rst │ ├── lj-dimer.svg │ ├── many-body │ │ ├── eddp.rst │ │ ├── mace.rst │ │ ├── nequip.rst │ │ ├── painn.rst │ │ ├── root.rst │ │ ├── schnet.rst │ │ ├── stillinger-weber.rst │ │ └── tensornet.rst │ ├── morse-dimer.svg │ ├── offsets.rst │ ├── pairwise.rst │ ├── root.rst │ └── zbl-dimer.svg │ ├── quickstart │ ├── custom-training-loop.ipynb │ ├── fine-tune.yaml │ ├── fine-tuning.ipynb │ ├── implement-a-model.ipynb │ ├── mp0.yaml │ ├── orb.yaml │ ├── parity-plot.svg │ ├── quickstart-cgap17.yaml │ ├── quickstart.ipynb │ └── root.rst │ ├── theory.ipynb │ ├── tools │ ├── Cu-LJ-default-parity.svg │ ├── analysis.rst │ ├── ase.ipynb │ ├── dimer-curve.svg │ ├── lammps.ipynb │ └── torch-sim.ipynb │ └── utils.rst ├── pyproject.toml ├── scripts └── build-lammps.sh ├── src └── graph_pes │ ├── __init__.py │ ├── atomic_graph.py │ ├── config │ ├── shared.py │ ├── testing.py │ ├── training-defaults.yaml │ └── training.py │ ├── data │ ├── __init__.py │ ├── ase_db.py │ ├── datasets.py │ └── loader.py │ ├── graph_pes_model.py │ ├── interfaces │ ├── __init__.py │ ├── _mace.py │ ├── _mattersim.py │ ├── _orb.py │ ├── mace_test.py │ ├── mattersim_test.py │ ├── orb_test.py │ └── quick.yaml │ ├── models │ ├── __init__.py │ ├── addition.py │ ├── components │ │ ├── aggregation.py │ │ ├── distances.py │ │ └── scaling.py │ ├── e3nn │ │ ├── _high_order_CG_coeff.pt │ │ ├── mace.py │ │ ├── mace_utils.py │ │ ├── nequip.py │ │ └── utils.py │ ├── eddp.py │ ├── offsets.py │ ├── painn.py │ ├── pairwise.py │ ├── schnet.py │ ├── scripted.py │ ├── stillinger_weber.py │ ├── tensornet.py │ └── unit_converter.py │ ├── pair_style │ ├── pair_graph_pes.cpp │ └── pair_graph_pes.h │ ├── scripts │ ├── id.py │ ├── resume.py │ ├── test.py │ ├── train.py │ └── utils.py │ ├── training │ ├── callbacks.py │ ├── loss.py │ ├── opt.py │ ├── tasks.py │ └── utils.py │ └── utils │ ├── analysis.py │ ├── calculator.py │ ├── distributed.py │ ├── lammps.py │ ├── logger.py │ ├── misc.py │ ├── nn.py │ ├── sampling.py │ ├── shift_and_scale.py │ └── threebody.py ├── tests ├── __init__.py ├── config │ └── test_config.py ├── conftest.py ├── data │ ├── __init__.py │ ├── schnetpack_data.db │ ├── test_ase_datasets.py │ └── test_db.py ├── graphs │ ├── __init__.py │ ├── test_atomic_graph.py │ ├── test_batching.py │ ├── test_conversions.py │ └── 
test_threebody.py ├── helpers │ ├── __init__.py │ └── test.xyz ├── models │ ├── __init__.py │ ├── components │ │ ├── test_aggregation.py │ │ └── test_distances.py │ ├── test_correctness.py │ ├── test_cutoffs.py │ ├── test_direct_prediction.py │ ├── test_equivariance.py │ ├── test_freezing.py │ ├── test_models.py │ ├── test_offsets.py │ ├── test_parameter_counting.py │ ├── test_predictions.py │ ├── test_scripting.py │ └── test_state_dict.py ├── training │ ├── __init__.py │ ├── test_callbacks.py │ ├── test_config.py │ ├── test_integration.py │ ├── test_loss.py │ ├── test_opt.py │ ├── test_pre_fit.py │ └── test_train_script.py └── utils │ ├── __init__.py │ ├── test_analysis.py │ ├── test_auto_offset.py │ ├── test_calculator.py │ ├── test_deploy.py │ ├── test_dtypes.py │ ├── test_lammps.py │ ├── test_misc.py │ ├── test_multi_sequence.py │ ├── test_nn.py │ └── test_sampling.py └── uv.lock /.github/workflows/changelog.yaml: -------------------------------------------------------------------------------- 1 | name: changelog 2 | 3 | on: 4 | pull_request: 5 | 6 | permissions: 7 | contents: read 8 | 9 | jobs: 10 | changelog: 11 | if: github.event_name == 'pull_request' 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v4 15 | with: 16 | fetch-depth: 0 17 | - name: Check for CHANGELOG.md changes 18 | run: | 19 | git fetch origin ${{ github.base_ref }} 20 | if ! git diff --name-only origin/${{ github.base_ref }}..HEAD | grep -q "CHANGELOG.md"; then 21 | echo "Error: No changes to CHANGELOG.md found in this pull request" 22 | exit 1 23 | fi 24 | -------------------------------------------------------------------------------- /.github/workflows/publish.yaml: -------------------------------------------------------------------------------- 1 | name: publish 2 | 3 | on: 4 | push: 5 | tags: 6 | - "*.*.*" 7 | 8 | jobs: 9 | publish: 10 | name: Upload release to PyPI 11 | runs-on: ubuntu-latest 12 | permissions: 13 | id-token: write # IMPORTANT: this permission is mandatory for trusted publishing 14 | 15 | steps: 16 | - uses: actions/checkout@v3 17 | - uses: actions/setup-python@v4 18 | with: 19 | python-version: 3.9 20 | - name: Install graph-pes with publish dependencies 21 | run: pip install ".[publish]" 22 | - name: Build 23 | run: python -m build 24 | - name: Check 25 | run: twine check dist/* 26 | - name: Publish to PyPI 27 | uses: pypa/gh-action-pypi-publish@release/v1 28 | -------------------------------------------------------------------------------- /.github/workflows/tests.yaml: -------------------------------------------------------------------------------- 1 | name: tests 2 | 3 | on: 4 | push: 5 | 6 | permissions: 7 | contents: write 8 | 9 | jobs: 10 | formatting: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Checkout 14 | uses: actions/checkout@v4 15 | 16 | - name: Run ruff 17 | uses: astral-sh/ruff-action@v3 18 | with: 19 | src: "./src" 20 | version-file: "uv.lock" 21 | 22 | tests: 23 | runs-on: ubuntu-latest 24 | strategy: 25 | fail-fast: false 26 | matrix: 27 | python-version: 28 | - 3.9 29 | - "3.10" 30 | - 3.11 31 | name: test - ${{ matrix.python-version }} 32 | steps: 33 | ### SETUP ### 34 | - uses: actions/checkout@v4 35 | 36 | - uses: actions/setup-python@v5 37 | with: 38 | python-version: ${{ matrix.python-version }} 39 | 40 | - uses: astral-sh/setup-uv@v5 41 | with: 42 | enable-cache: true 43 | cache-dependency-glob: "uv.lock" 44 | 45 | - name: Install graph-pes and required dependencies 46 | run: uv sync --extra test 47 | 48 | - name: Useful info 49 | run: 
uv pip freeze 50 | 51 | ### CODE TESTS ### 52 | - name: Run tests 53 | run: PYTHONPATH=. uv run pytest --cov --cov-report xml tests/ 54 | 55 | - name: Upload coverage reports to Codecov 56 | uses: codecov/codecov-action@v3 57 | # only upload coverage reports for the first python version 58 | if: matrix.python-version == '3.9' 59 | with: 60 | token: ${{secrets.CODECOV_TOKEN}} 61 | 62 | - name: Run a very small training run 63 | run: | 64 | uv run graph-pes-train configs/minimal.yaml \ 65 | data/+load_atoms_dataset/n_train=10 \ 66 | data/+load_atoms_dataset/n_valid=10 \ 67 | fitting/trainer_kwargs/max_epochs=5 \ 68 | wandb=null \ 69 | general/root_dir=results \ 70 | general/run_id=test-run 71 | 72 | - name: Run a very small testing run 73 | run: | 74 | uv run graph-pes-test model_path=results/test-run/model.pt \ 75 | data=tests/helpers/test.xyz 76 | 77 | interface-tests: 78 | runs-on: ubuntu-latest 79 | strategy: 80 | fail-fast: false 81 | matrix: 82 | interface: 83 | - name: mace-torch 84 | python-version: "3.10" 85 | test-file: "mace_test.py" 86 | extra-deps: "mace-torch" 87 | model_kwargs: model="+mace_mp()" 88 | - name: mattersim 89 | python-version: "3.9" 90 | test-file: "mattersim_test.py" 91 | extra-deps: "mattersim" 92 | model_kwargs: model="+mattersim()" 93 | - name: orb 94 | python-version: "3.10" 95 | test-file: "orb_test.py" 96 | extra-deps: "orb-models" 97 | model_kwargs: model/+orb_model/name=orb-d3-xs-v2 98 | name: ${{ matrix.interface.name }} interface 99 | steps: 100 | - uses: actions/checkout@v4 101 | 102 | - uses: actions/setup-python@v5 103 | with: 104 | python-version: ${{ matrix.interface.python-version }} 105 | 106 | - uses: astral-sh/setup-uv@v5 107 | with: 108 | enable-cache: true 109 | cache-dependency-glob: "uv.lock" 110 | 111 | - name: Install graph-pes and required dependencies 112 | run: uv pip install --system --upgrade ".[test]" 113 | 114 | - name: Install extra dependencies 115 | run: uv pip install --system --upgrade ${{ matrix.interface.extra-deps }} 116 | 117 | - name: Run tests 118 | run: pytest src/graph_pes/interfaces/${{ matrix.interface.test-file }} -vvv 119 | 120 | - name: Run a small training run 121 | run: | 122 | graph-pes-train src/graph_pes/interfaces/quick.yaml \ 123 | ${{ matrix.interface.model_kwargs }} 124 | 125 | docs: 126 | runs-on: ubuntu-latest 127 | steps: 128 | ### SETUP ### 129 | - uses: actions/checkout@v4 130 | 131 | - uses: actions/setup-python@v5 132 | with: 133 | python-version: 3.9 134 | 135 | - uses: astral-sh/setup-uv@v5 136 | with: 137 | enable-cache: true 138 | cache-dependency-glob: "uv.lock" 139 | 140 | - name: Install pandoc 141 | run: sudo apt-get install -y --no-install-recommends pandoc 142 | 143 | - name: Install graph-pes and required dependencies 144 | run: uv sync --extra docs 145 | 146 | ### DOCS ### 147 | - name: Build docs 148 | # -n: nitpicky mode 149 | # -W: turn warnings into errors 150 | run: uv run sphinx-build -nW docs/source docs/build --keep-going 151 | 152 | - name: Publish docs 153 | if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') 154 | uses: peaceiris/actions-gh-pages@v3 155 | with: 156 | publish_branch: gh-pages 157 | github_token: ${{ secrets.GITHUB_TOKEN }} 158 | publish_dir: docs/build/ 159 | force_orphan: true 160 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | lightning_logs 2 | __pycache__ 3 | TODO* 4 | .venv 5 | 6 | exclude 7 | docs/build 8 | 
dist 9 | **/*results* 10 | .communication 11 | logs 12 | **/*.egg-info 13 | 14 | .pytest_cache 15 | coverage.xml 16 | .coverage* 17 | *.extxyz 18 | *.xyz 19 | *.npz 20 | *.ipynb 21 | *.pt 22 | *.pth 23 | **/wandb 24 | .vscode 25 | .isort.cfg 26 | .DS_Store 27 | *.h5md 28 | *.traj 29 | 30 | **.cache -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.9 -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 3 | authors: 4 | - family-names: "Gardner" 5 | given-names: "John" 6 | orcid: "https://orcid.org/0009-0006-7377-7146" 7 | title: "graph-pes: train and use graph-based ML models of potential energy surfaces" 8 | version: 0.1.1 9 | doi: 10.5281/zenodo.14956210 10 | date-released: 2024-07-01 11 | url: "https://github.com/jla-gardner/graph-pes" 12 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Contributions to `graph-pes` via pull requests are very welcome! Here's how to get started. 4 | 5 | --- 6 | 7 | **Getting started** 8 | 9 | First fork the library on GitHub. 10 | 11 | Then clone and install the library in development mode: 12 | 13 | ```bash 14 | git clone https://github.com//graph-pes.git 15 | cd graph-pes 16 | pip install -e ".[dev]" 17 | ``` 18 | 19 | Alternatively, you can use [`uv`](https://docs.astral.sh/uv/): 20 | 21 | ```bash 22 | git clone https://github.com//graph-pes.git 23 | cd graph-pes 24 | uv sync --all-extras 25 | ``` 26 | 27 | --- 28 | 29 | **If you're making changes to the code:** 30 | 31 | Now make your changes. Make sure to include additional tests if necessary. 32 | 33 | Next verify the tests all pass: 34 | 35 | ```bash 36 | pip install pytest 37 | pytest tests/ # or uv run pytest tests/ 38 | ``` 39 | 40 | Then push your changes back to your fork of the repository: 41 | 42 | ```bash 43 | git push 44 | ``` 45 | 46 | Finally, open a pull request on GitHub! 47 | 48 | --- 49 | 50 | **If you're making changes to the documentation:** 51 | 52 | Make your changes. You can then build the documentation by doing 53 | 54 | ```bash 55 | pip install -e ".[docs]" # or uv sync --extra docs 56 | sphinx-autobuild docs/source docs/build 57 | ``` 58 | 59 | You can then see your local copy of the documentation by navigating to `localhost:8000` in a web browser. 60 | Any time you save changes to the documentation, these will shortly be reflected in the browser! 
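---

**Checking code style:**

The CI (see `.github/workflows/tests.yaml`) also runs `ruff` over `src/`. As a minimal sketch, assuming `ruff` is installed in your environment, you can run the same check locally before pushing:

```bash
# lint the source tree, mirroring the CI's ruff step
ruff check src/
```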
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023-25 John Gardner 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | # MANIFEST.in 2 | 3 | include src/graph_pes/config/training-defaults.yaml 4 | include src/graph_pes/scripts/automation.yaml 5 | include src/graph_pes/pair_style/pair_graph_pes.h 6 | include src/graph_pes/pair_style/pair_graph_pes.cpp 7 | include src/graph_pes/models/e3nn/_high_order_CG_coeff.pt 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | 4 | 5 | 6 | `graph-pes` is a framework built to accelerate the development of machine-learned potential energy surface (PES) models that act on graph representations of atomic structures. 7 | 8 | Links: [Google Colab Quickstart](https://colab.research.google.com/github/jla-gardner/graph-pes/blob/main/docs/source/quickstart/quickstart.ipynb) - [Documentation](https://jla-gardner.github.io/graph-pes/) - [PyPI](https://pypi.org/project/graph-pes/) 9 | 10 | [![PyPI](https://img.shields.io/pypi/v/graph-pes)](https://pypi.org/project/graph-pes/) 11 | [![Conda-forge](https://img.shields.io/conda/vn/conda-forge/graph-pes.svg)](https://github.com/conda-forge/graph-pes-feedstock) 12 | [![Tests](https://github.com/jla-gardner/graph-pes/actions/workflows/tests.yaml/badge.svg?branch=main)](https://github.com/jla-gardner/graph-pes/actions/workflows/tests.yaml) 13 | [![codecov](https://codecov.io/gh/jla-gardner/graph-pes/branch/main/graph/badge.svg)](https://codecov.io/gh/jla-gardner/graph-pes) 14 | [![GitHub last commit](https://img.shields.io/github/last-commit/jla-gardner/graph-pes)](https://github.com/jla-gardner/graph-pes) 15 | 16 |
17 | 18 | 19 | ## Features 20 | 21 | - Experiment with new model architectures by inheriting from our `GraphPESModel` [base class](https://jla-gardner.github.io/graph-pes/models/root.html). 22 | - [Train your own](https://jla-gardner.github.io/graph-pes/quickstart/implement-a-model.html) or existing model architectures (e.g., [SchNet](https://jla-gardner.github.io/graph-pes/models/many-body/schnet.html), [NequIP](https://jla-gardner.github.io/graph-pes/models/many-body/nequip.html), [PaiNN](https://jla-gardner.github.io/graph-pes/models/many-body/painn.html), [MACE](https://jla-gardner.github.io/graph-pes/models/many-body/mace.html), [TensorNet](https://jla-gardner.github.io/graph-pes/models/many-body/tensornet.html), etc.). 23 | - Use and fine-tune foundation models via a unified interface: [MACE-MP0](https://jla-gardner.github.io/graph-pes/interfaces/mace.html), [MACE-OFF](https://jla-gardner.github.io/graph-pes/interfaces/mace.html), [MatterSim](https://jla-gardner.github.io/graph-pes/interfaces/mattersim.html), [GO-MACE](https://jla-gardner.github.io/graph-pes/interfaces/mace.html) and [Orb v2/3](https://jla-gardner.github.io/graph-pes/interfaces/orb.html). 24 | - Easily configure distributed training, learning rate scheduling, Weights & Biases logging, and other features using our `graph-pes-train` [command line interface](https://jla-gardner.github.io/graph-pes/cli/graph-pes-train/root.html). 25 | - Use our data-loading pipeline within your [own training loop](https://jla-gardner.github.io/graph-pes/quickstart/custom-training-loop.html). 26 | - Run molecular dynamics simulations with any `GraphPESModel` using [torch-sim](https://jla-gardner.github.io/graph-pes/tools/torch-sim.html), [LAMMPS](https://jla-gardner.github.io/graph-pes/tools/lammps.html) or [ASE](https://jla-gardner.github.io/graph-pes/tools/ase.html). 27 | 28 | ## Quickstart 29 | 30 | ```bash 31 | pip install -q graph-pes 32 | wget https://tinyurl.com/graph-pes-minimal-config -O config.yaml 33 | graph-pes-train config.yaml 34 | ``` 35 | 36 | Alternatively, for a 0-install quickstart experience, please see [this Google Colab](https://colab.research.google.com/github/jla-gardner/graph-pes/blob/main/docs/source/quickstart/quickstart.ipynb), which you can also find in our [documentation](https://jla-gardner.github.io/graph-pes/quickstart/quickstart.html). 37 | 38 | 39 | ## Contributing 40 | 41 | Contributions are welcome! If you find any issues or have suggestions for new features, please open an issue or submit a pull request on the [GitHub repository](https://github.com/jla-gardner/graph-pes). 42 | 43 | ## Citing `graph-pes` 44 | 45 | We kindly ask that you cite `graph-pes` in your work if it has been useful to you. 46 | A manuscript is currently in preparation; in the meantime, please cite the Zenodo DOI found in the [CITATION.cff](CITATION.cff) file.
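As an illustrative sketch (not an official entry), a BibTeX record assembled by hand from that file's metadata (title, author, version, DOI and URL as listed there) might look like:

```bibtex
@software{gardner_graph_pes,
  author  = {Gardner, John},
  title   = {graph-pes: train and use graph-based ML models of potential energy surfaces},
  version = {0.1.1},
  doi     = {10.5281/zenodo.14956210},
  url     = {https://github.com/jla-gardner/graph-pes},
  year    = {2024},
}
```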
47 | -------------------------------------------------------------------------------- /configs/kitchen-sink.yaml: -------------------------------------------------------------------------------- 1 | CUTOFF: 3.7 2 | 3 | model: 4 | offset: 5 | +FixedOffset: { H: -123.4, C: -456.7 } 6 | core-repulsion: 7 | +ZBLCoreRepulsion: 8 | trainable: true 9 | cutoff: =/CUTOFF 10 | many-body: 11 | +NequIP: 12 | elements: [C, H, O, N] 13 | cutoff: =/CUTOFF 14 | channels: 128 15 | hidden_irreps: 0e + 1o 16 | self_connection: true 17 | 18 | data: 19 | train: 20 | path: training_data.xyz 21 | n: 1000 22 | shuffle: true 23 | seed: 42 24 | valid: validation_data.xyz 25 | test: 26 | bulk: 27 | +my_module.bulk_test_set: 28 | cutoff: =/CUTOFF 29 | slab: 30 | +my_module.slab_test_set: 31 | n: 100 32 | cutoff: =/CUTOFF 33 | 34 | loss: 35 | energy: +PerAtomEnergyLoss() 36 | forces: 37 | +ForceRMSE: 38 | weight: 3.0 39 | stress: 40 | +PropertyLoss: 41 | property: stress 42 | metric: RMSE 43 | weight: 10.0 44 | 45 | fitting: 46 | pre_fit_model: true 47 | max_n_pre_fit: 1000 48 | early_stopping: 49 | monitor: valid/metrics/forces_rmse 50 | patience: 50 51 | min_delta: 1e-3 52 | 53 | trainer_kwargs: 54 | max_epochs: 1000 55 | accelerator: gpu 56 | accumulate_grad_batches: 4 57 | val_check_interval: 0.25 58 | 59 | optimizer: 60 | name: AdamW 61 | lr: 0.003 62 | weight_decay: 0.01 63 | amsgrad: true 64 | 65 | scheduler: 66 | name: ReduceLROnPlateau 67 | patience: 10 68 | factor: 0.8 69 | 70 | swa: 71 | lr: 0.001 72 | start: 0.8 73 | anneal_epochs: 10 74 | strategy: linear 75 | 76 | loader_kwargs: 77 | batch_size: 32 78 | num_workers: 4 79 | shuffle: true 80 | persistent_workers: true 81 | 82 | general: 83 | seed: 42 84 | root_dir: /path/to/root 85 | run_id: kitchen-sink-run 86 | log_level: INFO 87 | progress: rich 88 | 89 | wandb: 90 | project: my_project 91 | entity: my_entity 92 | tags: [kitchen-sink, test] 93 | name: kitchen-sink-run 94 | -------------------------------------------------------------------------------- /configs/minimal.yaml: -------------------------------------------------------------------------------- 1 | # train a SchNet model... 2 | model: 3 | +SchNet: 4 | layers: 3 5 | channels: 64 6 | cutoff: =/CUTOFF 7 | 8 | # ...using some of the QM7 structures... 9 | data: 10 | +load_atoms_dataset: 11 | id: QM7 12 | cutoff: =/CUTOFF 13 | n_train: 5_000 14 | n_valid: 100 15 | 16 | # ...training on energy labels... 
17 | loss: +PerAtomEnergyLoss() 18 | 19 | # ...using a cutoff of 5.0 Å 20 | # (referenced above) 21 | CUTOFF: 5.0 22 | -------------------------------------------------------------------------------- /configs/realistic.yaml: -------------------------------------------------------------------------------- 1 | CUTOFF: 3.7 2 | 3 | general: 4 | seed: 42 5 | run_id: mace-c-gap-20u 6 | 7 | model: 8 | offset: +LearnableOffset() 9 | core-repulsion: 10 | +ZBLCoreRepulsion: 11 | trainable: true 12 | cutoff: =/CUTOFF 13 | many-body: 14 | +MACE: 15 | elements: [C] 16 | cutoff: =/CUTOFF 17 | channels: 128 18 | hidden_irreps: 0e + 1o 19 | self_connection: true 20 | 21 | data: 22 | +load_atoms_dataset: 23 | id: C-GAP-20U 24 | cutoff: =/CUTOFF 25 | n_train: 5000 26 | n_valid: 100 27 | n_test: 500 28 | split: random 29 | 30 | loss: 31 | energy: +PerAtomEnergyLoss() 32 | forces: 33 | +ForceRMSE: 34 | weight: 10.0 35 | 36 | fitting: 37 | trainer_kwargs: 38 | max_epochs: 1000 39 | accelerator: gpu 40 | 41 | callbacks: 42 | - +graph_pes.training.callbacks.DumpModel: 43 | every_n_val_checks: 10 44 | 45 | early_stopping: 46 | patience: 50 47 | 48 | optimizer: 49 | name: AdamW 50 | lr: 0.003 51 | 52 | scheduler: 53 | name: ReduceLROnPlateau 54 | patience: 10 55 | factor: 0.8 56 | 57 | loader_kwargs: 58 | batch_size: 32 59 | num_workers: 4 60 | -------------------------------------------------------------------------------- /docs/source/_static/custom.css: -------------------------------------------------------------------------------- 1 | h1 { 2 | font-size: 3em; 3 | text-align: center; 4 | /* pad on top, but not on bottom */ 5 | padding-top: 0.5em; 6 | padding-bottom: 0; 7 | border-bottom: 0; 8 | } 9 | 10 | .code-block-caption, 11 | .highlight { 12 | margin-left: 7%; 13 | margin-right: 7%; 14 | width: 86%; 15 | border-radius: 10px; 16 | } 17 | 18 | .code-block-caption { 19 | padding: 0; 20 | } 21 | .code-block-caption .caption-text { 22 | padding: 10px 15px; 23 | } 24 | 25 | /* squish code blocks so that everything looks a bit prettier */ 26 | article[role="main"] .highlight pre { 27 | line-height: 1.3; 28 | } 29 | 30 | /* unless child of "input-area" */ 31 | .input_area .highlight { 32 | margin-left: 0; 33 | margin-right: 0; 34 | width: 100%; 35 | border-radius: 0; 36 | } 37 | 38 | /* center all children of viz divs */ 39 | /* and don't allow highlighting when clicked */ 40 | .viz { 41 | display: flex; 42 | justify-content: center; 43 | width: 100%; 44 | } 45 | 46 | img { 47 | margin-top: 30px; 48 | margin-bottom: 20px; 49 | } 50 | 51 | /* align content vertically */ 52 | .info-card { 53 | display: flex; 54 | justify-content: center; /* Align horizontal */ 55 | align-items: center; /* Align vertical */ 56 | } 57 | 58 | /* align text center in each cell */ 59 | .table-wrapper td { 60 | text-align: center; 61 | } 62 | 63 | /* anything that lives in a span wrapped by a .sphinx-codeautolink-a */ 64 | .sphinx-codeautolink-a span { 65 | text-decoration: underline; 66 | text-decoration-color: var(--color-api-name); 67 | } 68 | /* make it obviously clickable on hover by turning bold */ 69 | .sphinx-codeautolink-a span:hover { 70 | font-weight: bold; 71 | } 72 | -------------------------------------------------------------------------------- /docs/source/building-blocks/aggregation.rst: -------------------------------------------------------------------------------- 1 | ########### 2 | Aggregation 3 | ########### 4 | 5 | Aggregating some value over one's neighbours is a common operation in graph-based 6 | ML models. 
``graph-pes`` provides a base class for such operations, together with 7 | a few common implementations. A common way to specify the aggregation mode to use 8 | in a model is to use a :class:`~graph_pes.models.components.aggregation.NeighbourAggregationMode` 9 | string, which internally is passed to :meth:`~graph_pes.models.components.aggregation.NeighbourAggregation.parse`. 10 | 11 | 12 | Base Class 13 | ---------- 14 | 15 | .. autoclass:: graph_pes.models.components.aggregation.NeighbourAggregation 16 | :members: 17 | 18 | 19 | .. class:: graph_pes.models.components.aggregation.NeighbourAggregationMode 20 | 21 | Type alias for ``Literal["sum", "mean", "constant_fixed", "constant_learnable", "sqrt"]``. 22 | 23 | 24 | Implementations 25 | --------------- 26 | 27 | .. autoclass:: graph_pes.models.components.aggregation.SumNeighbours 28 | :members: 29 | 30 | .. autoclass:: graph_pes.models.components.aggregation.MeanNeighbours 31 | :members: 32 | 33 | .. autoclass:: graph_pes.models.components.aggregation.ScaledSumNeighbours 34 | :members: 35 | 36 | .. autoclass:: graph_pes.models.components.aggregation.VariancePreservingSumNeighbours 37 | :members: 38 | -------------------------------------------------------------------------------- /docs/source/building-blocks/distances.rst: -------------------------------------------------------------------------------- 1 | Distance Expansions 2 | =================== 3 | 4 | Available Expansions 5 | -------------------- 6 | 7 | ``graph-pes`` exposes the :class:`~graph_pes.models.components.distances.DistanceExpansion` 8 | base class, together with implementations of a few common expansions: 9 | 10 | .. autoclass:: graph_pes.models.components.distances.Bessel 11 | :show-inheritance: 12 | .. autoclass:: graph_pes.models.components.distances.GaussianSmearing 13 | :show-inheritance: 14 | .. autoclass:: graph_pes.models.components.distances.SinExpansion 15 | :show-inheritance: 16 | .. autoclass:: graph_pes.models.components.distances.ExponentialRBF 17 | :show-inheritance: 18 | 19 | 20 | Implementing a new Expansion 21 | ---------------------------- 22 | 23 | .. autoclass:: graph_pes.models.components.distances.DistanceExpansion 24 | :members: -------------------------------------------------------------------------------- /docs/source/building-blocks/e3nn.rst: -------------------------------------------------------------------------------- 1 | ``e3nn`` Helpers 2 | ================ 3 | 4 | .. autoclass:: graph_pes.models.e3nn.utils.LinearReadOut 5 | 6 | .. autoclass:: graph_pes.models.e3nn.utils.NonLinearReadOut 7 | -------------------------------------------------------------------------------- /docs/source/building-blocks/envelopes.rst: -------------------------------------------------------------------------------- 1 | Envelopes 2 | ========== 3 | 4 | Available Envelopes 5 | ------------------- 6 | 7 | ``graph-pes`` exposes the :class:`~graph_pes.models.components.distances.Envelope` 8 | base class, together with implementations of a few common envelope functions: 9 | 10 | .. autoclass:: graph_pes.models.components.distances.PolynomialEnvelope 11 | :show-inheritance: 12 | 13 | .. autoclass:: graph_pes.models.components.distances.SmoothOnsetEnvelope 14 | :show-inheritance: 15 | 16 | .. autoclass:: graph_pes.models.components.distances.CosineEnvelope 17 | :show-inheritance: 18 | 19 | Implementing a new Envelope 20 | --------------------------- 21 | 22 | .. 
autoclass:: graph_pes.models.components.distances.Envelope 23 | :members: 24 | 25 | -------------------------------------------------------------------------------- /docs/source/building-blocks/nn.rst: -------------------------------------------------------------------------------- 1 | ############### 2 | PyTorch Helpers 3 | ############### 4 | 5 | .. autoclass:: graph_pes.utils.nn.PerElementParameter 6 | :show-inheritance: 7 | :members: 8 | 9 | .. autoclass:: graph_pes.utils.nn.PerElementEmbedding 10 | 11 | .. autoclass:: graph_pes.utils.nn.MLPConfig 12 | :show-inheritance: 13 | :members: 14 | 15 | .. autoclass:: graph_pes.utils.nn.MLP 16 | :members: 17 | 18 | -------------------------------------------------------------------------------- /docs/source/building-blocks/root.rst: -------------------------------------------------------------------------------- 1 | Building blocks 2 | ===================== 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | 7 | distances 8 | envelopes 9 | aggregation 10 | scaling 11 | nn 12 | e3nn 13 | -------------------------------------------------------------------------------- /docs/source/building-blocks/scaling.rst: -------------------------------------------------------------------------------- 1 | Scaling 2 | ======= 3 | 4 | A commonly used strategy in models of the PES is to scale the raw local energy predictions 5 | by some scale parameter (derived in a :meth:`graph_pes.GraphPESModel.pre_fit` step). This has the 6 | effect of allowing models to output ~unit normally distributed predictions (which is often 7 | an implicit assumption of e.g. NN components) before having these scaled to the natural scale of the 8 | labels in question. 9 | 10 | .. autoclass:: graph_pes.models.components.scaling.LocalEnergiesScaler 11 | :members: 12 | -------------------------------------------------------------------------------- /docs/source/cli/graph-pes-id.rst: -------------------------------------------------------------------------------- 1 | ``graph-pes-id`` 2 | ================ 3 | 4 | ``graph-pes-id`` is a command line tool for generating a random ID: 5 | 6 | .. code-block:: console 7 | 8 | $ graph-pes-id 9 | brxcu7_p3s018 10 | 11 | A common use case for this is to create a series of experiments associated with a single id: 12 | 13 | .. code-block:: console 14 | 15 | $ # generate a random id 16 | $ ID=$(graph-pes-id) 17 | 18 | $ # pre-train a model 19 | $ graph-pes-train pre-train.yaml \ 20 | general/root_dir=results \ 21 | general/run_id=$ID-pre-train 22 | ... 23 | 24 | $ # fine-tune the model: 25 | $ # we know where the model weights are and so 26 | $ # fine-tuning is easy: we just load the weights 27 | $ graph-pes-train fine-tune.yaml \ 28 | general/root_dir=results \ 29 | model/+load_model/path=results/$ID-pre-train/model.pt \ 30 | general/run_id=$ID-fine-tune 31 | ... 32 | -------------------------------------------------------------------------------- /docs/source/cli/graph-pes-resume.rst: -------------------------------------------------------------------------------- 1 | ``graph-pes-resume`` 2 | ==================== 3 | 4 | ``graph-pes-resume`` is a command line tool for resuming training runs that have been interrupted: 5 | 6 | 7 | .. code-block:: console 8 | 9 | $ graph-pes-resume -h 10 | usage: graph-pes-resume [-h] train_directory 11 | 12 | Resume a `graph-pes-train` training run. 13 | 14 | positional arguments: 15 | train_directory Path to the training directory. 
16 | For instance, `graph-pes-results/abdcefg_hijklmn` 17 | 18 | optional arguments: 19 | -h, --help show this help message and exit 20 | 21 | Copyright 2023-25, John Gardner 22 | 23 | 24 | Usage 25 | ----- 26 | 27 | .. code-block:: bash 28 | 29 | $ graph-pes-resume graph-pes-results/abdcefg_hijklmn 30 | -------------------------------------------------------------------------------- /docs/source/cli/graph-pes-test.rst: -------------------------------------------------------------------------------- 1 | ``graph-pes-test`` 2 | ================== 3 | 4 | Use the ``graph-pes-test`` command to test a trained model. 5 | Testing functionality is already baked into ``graph-pes-train``, but this command 6 | allows you more fine-grained control over the testing process. 7 | 8 | 9 | Usage 10 | ----- 11 | 12 | Simplest possible usage: test the model at ``path/to/model.pth`` on the datasets 13 | specified in the ``training-config.yaml`` file saved alongside it. 14 | 15 | .. code-block:: bash 16 | 17 | graph-pes-test model_path=path/to/model.pth 18 | 19 | 20 | Alternatively, to test on new data, pass a path to a new config file that specifies 21 | a :class:`~graph_pes.config.testing.TestingConfig` object: 22 | 23 | .. code-block:: bash 24 | 25 | graph-pes-test test-config.yaml model_path=path/to/model.pth 26 | 27 | Where ``test-config.yaml`` contains e.g.: 28 | 29 | .. code-block:: yaml 30 | 31 | data: 32 | dimers: path/to/dimers.xyz 33 | amorphous: path/to/amorphous.xyz 34 | 35 | accelerator: gpu 36 | 37 | loader_kwargs: 38 | batch_size: 64 39 | num_workers: 4 40 | 41 | Complete usage: 42 | 43 | .. code-block:: bash 44 | 45 | graph-pes-test -h 46 | 47 | usage: graph-pes-test [-h] [args ...] 48 | 49 | Test a GraphPES model using PyTorch Lightning. 50 | 51 | positional arguments: 52 | args Config files and command line specifications. 53 | Config files should be YAML (.yaml/.yml) files. 54 | Command line specifications should be in the form 55 | my/nested/key=value. Final config is built up from 56 | these items in a left to right manner, with later 57 | items taking precedence over earlier ones in the 58 | case of conflicts. The data2objects package is used 59 | to resolve references and create objects directly 60 | from the config dictionary. 61 | 62 | optional arguments: 63 | -h, --help show this help message and exit 64 | 65 | 66 | Config 67 | ------ 68 | 69 | .. autoclass:: graph_pes.config.testing.TestingConfig() 70 | :members: 71 | -------------------------------------------------------------------------------- /docs/source/cli/graph-pes-train/examples.rst: -------------------------------------------------------------------------------- 1 | Example configs 2 | =============== 3 | 4 | Realistic config 5 | ---------------- 6 | 7 | A realistic config for training a :class:`~graph_pes.models.MACE` model on the `C-GAP-20U dataset `__: 8 | 9 | .. literalinclude:: ../../../../configs/realistic.yaml 10 | :language: yaml 11 | 12 | 13 | 14 | Kitchen sink config 15 | ------------------- 16 | 17 | A `"kitchen sink"` config that attempts to specify every possible option: 18 | 19 | .. literalinclude:: ../../../../configs/kitchen-sink.yaml 20 | :language: yaml 21 | 22 | 23 | Default config 24 | -------------- 25 | 26 | For reference, here are the default config options used in ``graph-pes-train``: 27 | 28 | .. 
literalinclude:: ../../../../src/graph_pes/config/training-defaults.yaml 29 | :language: yaml 30 | -------------------------------------------------------------------------------- /docs/source/cli/graph-pes-train/root.rst: -------------------------------------------------------------------------------- 1 | .. _cli-reference: 2 | 3 | ``graph-pes-train`` 4 | =================== 5 | 6 | ``graph-pes-train`` is a command line tool for training graph-based potential energy surface models using `PyTorch Lightning `__: 7 | 8 | .. code-block:: console 9 | 10 | $ graph-pes-train -h 11 | usage: graph-pes-train [-h] [args [args ...]] 12 | 13 | Train a GraphPES model using PyTorch Lightning. 14 | 15 | positional arguments: 16 | args Config files and command line specifications. 17 | Config files should be YAML (.yaml/.yml) files. 18 | Command line specifications should be in the form 19 | nested/key=value. Final config is built up from 20 | these items in a left to right manner, with later 21 | items taking precedence over earlier ones in the 22 | case of conflicts. 23 | 24 | optional arguments: 25 | -h, --help show this help message and exit 26 | 27 | Copyright 2023-25, John Gardner 28 | 29 | 30 | .. toctree:: 31 | :maxdepth: 1 32 | :hidden: 33 | 34 | the-basics 35 | complete-docs 36 | examples 37 | 38 | For a hands-on introduction, try our `quickstart Colab notebook `__. Alternatively, you can learn about how to use ``graph-pes-train`` from :doc:`the basics guide `, :doc:`the complete configuration documentation ` or :doc:`a set of examples `. 39 | 40 | \ 41 | 42 | There are a few important things to note when using ``graph-pes-train`` in special situations: 43 | 44 | 45 | .. _multi-GPU training: 46 | 47 | Multi-GPU training 48 | ------------------- 49 | 50 | The ``graph-pes-train`` command supports multi-GPU training out of the box, relying on PyTorch Lightning's native support for distributed training. 51 | By default, ``graph-pes-train`` will attempt to use **all** available GPUs. You can override this by exporting the ``CUDA_VISIBLE_DEVICES`` environment variable: 52 | 53 | .. code-block:: bash 54 | 55 | $ export CUDA_VISIBLE_DEVICES=0,1 56 | $ graph-pes-train config.yaml 57 | 58 | 59 | Non-interactive jobs 60 | -------------------- 61 | 62 | In cases where you are running ``graph-pes-train`` in a non-interactive session (e.g. from a script or scheduled job) and where you wish to make use of the `Weights and Biases `__ logging functionality, you will need to take one of the following steps: 63 | 64 | 1. run ``wandb login`` in an interactive session beforehand; this will cache your credentials to ``~/.netrc`` 65 | 2. set the ``WANDB_API_KEY`` environment variable to your W&B API key directly before running ``graph-pes-train`` 66 | 67 | Failing to do this will result in ``graph-pes-train`` hanging forever while waiting for you to log in to your W&B account. 68 | 69 | Alternatively, you can set the ``wandb: null`` flag in your config file to disable W&B logging. 
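As a concrete sketch of the second option (the API key value here is a placeholder):

.. code-block:: bash

    # supply your W&B API key via the environment ...
    export WANDB_API_KEY=<your-api-key>
    graph-pes-train config.yaml

    # ... or opt out of W&B logging entirely via a command line override
    graph-pes-train config.yaml wandb=null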
70 | 71 | 72 | Compute clusters 73 | ---------------- 74 | 75 | If you are running ``graph-pes-train`` on a compute cluster as a scheduled job, ensure that you: 76 | 77 | * use a ``"logged"`` progress bar so that you can monitor the progress of your training run directly from the job's outputs 78 | * correctly set the ``CUDA_VISIBLE_DEVICES`` environment variable so that ``graph-pes-train`` makes use of all the GPUs you have requested (and no others) (see above) 79 | * consider copying across your data to the worker nodes, and running ``graph-pes-train`` from there rather than on the head node 80 | - ``graph-pes-train`` writes checkpoints semi-frequently to disk, and this may cause issues with, or throttle, the cluster's network. 81 | - if you are using a disk-backed dataset (for instance reading from an ``.db`` file), each data point access will require an I/O operation, and reading from local file storage on the worker nodes will be many times faster than over the network. 82 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | project = "graph-pes" 2 | copyright = "2023-2025, John Gardner" 3 | author = "John Gardner" 4 | release = "0.1.1" 5 | 6 | extensions = [ 7 | "sphinx.ext.duration", 8 | "sphinx.ext.autodoc", 9 | "sphinx.ext.autosummary", 10 | "nbsphinx", 11 | "sphinx.ext.mathjax", 12 | "sphinx.ext.napoleon", 13 | "sphinx.ext.intersphinx", 14 | # "sphinxext.opengraph", 15 | "sphinx_copybutton", 16 | "sphinx.ext.viewcode", 17 | "sphinx_design", 18 | ] 19 | 20 | intersphinx_mapping = { 21 | "python": ("https://docs.python.org/3", None), 22 | "torch": ("https://pytorch.org/docs/stable/", None), 23 | "ase": ("https://wiki.fysik.dtu.dk/ase/", None), 24 | "e3nn": ("https://docs.e3nn.org/en/latest/", None), 25 | "pytorch_lightning": ("https://lightning.ai/docs/pytorch/stable/", None), 26 | "matplotlib": ("https://matplotlib.org/stable/", None), 27 | "load-atoms": ("https://jla-gardner.github.io/load-atoms/", None), 28 | } 29 | 30 | html_logo = "_static/logo-square.svg" 31 | html_title = "graph-pes" 32 | html_theme = "furo" 33 | html_static_path = ["_static"] 34 | html_css_files = ["custom.css"] 35 | autodoc_member_order = "bysource" 36 | maximum_signature_line_length = 70 37 | autodoc_typehints = "description" 38 | autodoc_typehints_description_target = "documented_params" 39 | 40 | copybutton_prompt_text = ( 41 | r">>> |\.\.\. 
|\$ |In \[\d*\]: | {2,5}\.\.\.: | {5,8}: " 42 | ) 43 | copybutton_prompt_is_regexp = True 44 | copybutton_selector = "div.copy-button pre" 45 | 46 | logo_highlight_colour = "#f74565" 47 | code_color = "#f74565" 48 | html_theme_options = { 49 | "sidebar_hide_name": True, 50 | "light_css_variables": { 51 | "color-problematic": code_color, 52 | "color-brand-primary": logo_highlight_colour, 53 | "color-brand-content": logo_highlight_colour, 54 | }, 55 | "dark_css_variables": { 56 | "color-problematic": code_color, 57 | "color-brand-primary": logo_highlight_colour, 58 | "color-brand-content": logo_highlight_colour, 59 | }, 60 | } 61 | 62 | nitpick_ignore = [ 63 | ("py:class", "torch.nn.Parameter"), 64 | ("py:class", "numpy.ndarray"), 65 | ("py:class", "e3nn.*"), 66 | ("py:class", "optional"), 67 | ("py:class", "o3.Irreps"), 68 | ("py:class", "graph_pes.config.training.FittingConfig"), 69 | ("py:class", "graph_pes.config.training.SWAConfig"), 70 | ("py:class", "graph_pes.config.training.GeneralConfig"), 71 | ("py:class", "graph_pes.config.shared.TorchConfig"), 72 | ("py:class", "pytorch_lightning.Callback"), 73 | ("py:class", "pytorch_lightning.Trainer"), 74 | ("py:class", "pytorch_lightning.callbacks.Callback"), 75 | ("py:class", "pytorch_lightning.callbacks.callback.Callback"), 76 | ("py:class", "TorchMetric"), 77 | ("py:class", "DirectForcefieldRegressor"), 78 | ("py:class", "ConservativeForcefieldRegressor"), 79 | ("ipython3", "Lexing literal_block"), 80 | ] 81 | 82 | # override the default css to match the furo theme 83 | nbsphinx_prolog = """ 84 | .. raw:: html 85 | 86 | 114 | """ 115 | 116 | 117 | # Add warning filter for specific lexing warning 118 | import warnings 119 | 120 | warnings.filterwarnings("ignore", r'.*Lexing literal_block.*as "ipython3".*') 121 | 122 | nbsphinx_prompt_width = "0" 123 | 124 | pygments_style = "friendly" 125 | pygments_dark_style = "monokai" 126 | -------------------------------------------------------------------------------- /docs/source/data/atomic_graph.rst: -------------------------------------------------------------------------------- 1 | 2 | Atomic Graphs 3 | ============= 4 | 5 | We describe atomic graphs using the :class:`~graph_pes.AtomicGraph` class. 6 | For convenient ways to create instances of such graphs from :class:`~ase.Atoms` objects, 7 | see :meth:`~graph_pes.AtomicGraph.from_ase`. 8 | 9 | 10 | Definition 11 | ---------- 12 | 13 | .. autoclass:: graph_pes.AtomicGraph() 14 | :show-inheritance: 15 | :members: 16 | 17 | .. autofunction:: graph_pes.atomic_graph.replace 18 | 19 | Batching 20 | -------- 21 | 22 | A batch of :class:`~graph_pes.AtomicGraph` instances is itself represented by a single 23 | :class:`~graph_pes.AtomicGraph` instance, containing multiple disjoint subgraphs. 24 | 25 | :class:`~graph_pes.AtomicGraph` batches are created using :func:`~graph_pes.atomic_graph.to_batch`: 26 | 27 | .. autofunction:: graph_pes.atomic_graph.to_batch 28 | .. autofunction:: graph_pes.atomic_graph.is_batch 29 | 30 | If you need to define custom batching logic for a field in the ``other`` property, 31 | you can use :func:`~graph_pes.atomic_graph.register_custom_batcher`: 32 | 33 | .. autofunction:: graph_pes.atomic_graph.register_custom_batcher 34 | 35 | Derived Properties 36 | ------------------ 37 | 38 | We define a number of derived properties of atomic graphs. These 39 | work for both isolated and batched :class:`~graph_pes.AtomicGraph` instances. 40 | 41 | .. autofunction:: graph_pes.atomic_graph.number_of_atoms 42 | .. 
autofunction:: graph_pes.atomic_graph.number_of_edges 43 | .. autofunction:: graph_pes.atomic_graph.has_cell 44 | .. autofunction:: graph_pes.atomic_graph.neighbour_vectors 45 | .. autofunction:: graph_pes.atomic_graph.neighbour_distances 46 | .. autofunction:: graph_pes.atomic_graph.number_of_neighbours 47 | .. autofunction:: graph_pes.atomic_graph.available_properties 48 | .. autofunction:: graph_pes.atomic_graph.number_of_structures 49 | .. autofunction:: graph_pes.atomic_graph.structure_sizes 50 | 51 | Graph Operations 52 | ---------------- 53 | 54 | We define a number of operations that act on :class:`torch.Tensor` instances conditioned on the graph structure. 55 | All of these are fully compatible with batched :class:`~graph_pes.AtomicGraph` instances, and with ``TorchScript`` compilation. 56 | 57 | .. autofunction:: graph_pes.atomic_graph.is_local_property 58 | .. autofunction:: graph_pes.atomic_graph.index_over_neighbours 59 | .. autofunction:: graph_pes.atomic_graph.sum_over_central_atom_index 60 | .. autofunction:: graph_pes.atomic_graph.sum_over_neighbours 61 | .. autofunction:: graph_pes.atomic_graph.sum_per_structure 62 | .. autofunction:: graph_pes.atomic_graph.divide_per_atom 63 | .. autofunction:: graph_pes.atomic_graph.trim_edges 64 | 65 | Three-body operations 66 | --------------------- 67 | 68 | .. autofunction:: graph_pes.utils.threebody.triplet_edge_pairs 69 | .. autofunction:: graph_pes.utils.threebody.triplet_bond_descriptors 70 | -------------------------------------------------------------------------------- /docs/source/data/datasets.rst: -------------------------------------------------------------------------------- 1 | Datasets 2 | ======== 3 | 4 | :class:`~graph_pes.data.GraphDataset`\ s are collections of :class:`~graph_pes.AtomicGraph`\ s. 5 | We provide a base class, :class:`~graph_pes.data.GraphDataset`, together with several 6 | implementations. The most common way to get a dataset of graphs is to use 7 | :func:`~graph_pes.data.load_atoms_dataset` or :func:`~graph_pes.data.file_dataset`. 8 | 9 | Useful Datasets 10 | --------------- 11 | 12 | .. autofunction:: graph_pes.data.file_dataset 13 | 14 | .. autofunction:: graph_pes.data.load_atoms_dataset 15 | 16 | .. autoclass:: graph_pes.data.ConcatDataset() 17 | :show-inheritance: 18 | 19 | Base Classes 20 | ------------- 21 | 22 | .. autoclass:: graph_pes.data.GraphDataset() 23 | :show-inheritance: 24 | :members: 25 | :special-members: __len__, __getitem__, __iter__ 26 | 27 | .. autoclass:: graph_pes.data.ASEToGraphDataset() 28 | :show-inheritance: 29 | 30 | .. autoclass:: graph_pes.data.DatasetCollection() 31 | :show-inheritance: 32 | 33 | 34 | Utilities 35 | --------- 36 | 37 | .. autoclass:: graph_pes.data.ase_db.ASEDatabase 38 | :show-inheritance: 39 | :members: 40 | -------------------------------------------------------------------------------- /docs/source/data/loader.rst: -------------------------------------------------------------------------------- 1 | Loader 2 | ====== 3 | 4 | .. autoclass:: graph_pes.data.loader.GraphDataLoader 5 | :show-inheritance: -------------------------------------------------------------------------------- /docs/source/data/root.rst: -------------------------------------------------------------------------------- 1 | ##### 2 | Data 3 | ##### 4 | 5 | .. 
toctree:: 6 | :maxdepth: 2 7 | 8 | atomic_graph 9 | datasets 10 | loader 11 | -------------------------------------------------------------------------------- /docs/source/development.rst: -------------------------------------------------------------------------------- 1 | Development 2 | =========== 3 | 4 | We welcome any suggestions and contributions to this project. 5 | Please visit our `GitHub repository `_ to report issues or submit pull requests. -------------------------------------------------------------------------------- /docs/source/fitting/callbacks.rst: -------------------------------------------------------------------------------- 1 | Callbacks 2 | ========= 3 | 4 | We have implemented a few useful `PyTorch Lightning `_ callbacks that you can use to monitor your training process: 5 | 6 | .. autoclass:: graph_pes.training.callbacks.WandbLogger 7 | 8 | .. autoclass:: graph_pes.training.callbacks.OffsetLogger 9 | 10 | .. autoclass:: graph_pes.training.callbacks.ScalesLogger 11 | 12 | .. autoclass:: graph_pes.training.callbacks.DumpModel 13 | 14 | .. autoclass:: graph_pes.training.callbacks.ModelTimer 15 | 16 | Base class 17 | ---------- 18 | 19 | .. autoclass:: graph_pes.training.callbacks.GraphPESCallback -------------------------------------------------------------------------------- /docs/source/fitting/losses.rst: -------------------------------------------------------------------------------- 1 | ###### 2 | Losses 3 | ###### 4 | 5 | In ``graph-pes``, we distinguish between metrics and losses: 6 | 7 | * A :class:`~graph_pes.training.loss.Loss` is some function that takes a model, a batch of graphs, and some predictions, and returns a scalar value measuring something that training should seek to minimise. 8 | This could be a prediction error, a model weight penalty, or something else. 9 | * A :class:`~graph_pes.training.loss.Metric` is some function that takes two tensors and returns a scalar value measuring the discrepancy between them. 10 | 11 | 12 | Losses 13 | ====== 14 | 15 | .. autoclass:: graph_pes.training.loss.Loss 16 | :show-inheritance: 17 | :members: name, forward, required_properties, pre_fit 18 | 19 | .. autoclass:: graph_pes.training.loss.PropertyLoss 20 | :show-inheritance: 21 | 22 | .. autoclass:: graph_pes.training.loss.PerAtomEnergyLoss 23 | 24 | 25 | Metrics 26 | ======= 27 | 28 | .. class:: graph_pes.training.loss.Metric 29 | 30 | A type alias for any function that takes two input tensors 31 | and returns some scalar measure of the discrepancy between them. 32 | 33 | .. code-block:: python 34 | 35 | Metric = Callable[[torch.Tensor, torch.Tensor], torch.Tensor] 36 | 37 | 38 | 39 | .. autoclass:: graph_pes.training.loss.RMSE() 40 | .. autoclass:: graph_pes.training.loss.MAE() 41 | .. autoclass:: graph_pes.training.loss.MSE() 42 | 43 | 44 | 45 | 46 | Helpers 47 | ======= 48 | 49 | .. autoclass:: graph_pes.training.loss.TotalLoss 50 | :show-inheritance: 51 | 52 | .. class:: graph_pes.training.loss.MetricName 53 | 54 | A type alias for a ``Literal["RMSE", "MAE", "MSE"]``. 55 | 56 | -------------------------------------------------------------------------------- /docs/source/fitting/optimizers.rst: -------------------------------------------------------------------------------- 1 | Optimizers 2 | ========== 3 | 4 | .. autoclass:: graph_pes.training.opt.Optimizer 5 | 6 | 7 | Schedulers 8 | ---------- 9 | 10 | .. 
autoclass:: graph_pes.training.opt.LRScheduler -------------------------------------------------------------------------------- /docs/source/fitting/root.rst: -------------------------------------------------------------------------------- 1 | Fitting 2 | ======= 3 | 4 | ``graph-pes`` provides a number of utilities for fitting models to data. These are 5 | used internally by :doc:`/cli/graph-pes-train/root`. 6 | 7 | 8 | Contents 9 | -------- 10 | 11 | .. toctree:: 12 | :maxdepth: 2 13 | 14 | losses 15 | optimizers 16 | callbacks -------------------------------------------------------------------------------- /docs/source/hide-title.html: -------------------------------------------------------------------------------- 1 | 11 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. toctree:: 2 | :hidden: 3 | :maxdepth: 2 4 | 5 | quickstart/root 6 | 7 | 8 | .. toctree:: 9 | :maxdepth: 2 10 | :hidden: 11 | :caption: CLI Reference 12 | 13 | cli/graph-pes-train/root 14 | cli/graph-pes-resume 15 | cli/graph-pes-test 16 | cli/graph-pes-id 17 | 18 | .. toctree:: 19 | :maxdepth: 4 20 | :hidden: 21 | :caption: API Reference 22 | 23 | data/root 24 | models/root 25 | fitting/root 26 | building-blocks/root 27 | utils 28 | 29 | 30 | .. toctree:: 31 | :maxdepth: 2 32 | :caption: Interfaces 33 | :hidden: 34 | 35 | interfaces/mace 36 | interfaces/mattersim 37 | interfaces/orb 38 | 39 | .. toctree:: 40 | :maxdepth: 2 41 | :caption: Tools 42 | :hidden: 43 | 44 | tools/torch-sim 45 | tools/ase 46 | tools/lammps 47 | tools/analysis 48 | 49 | 50 | .. toctree:: 51 | :maxdepth: 2 52 | :caption: About 53 | :hidden: 54 | 55 | theory 56 | development 57 | 58 | .. image:: _static/logo-text.svg 59 | :align: center 60 | :alt: graph-pes logo 61 | :width: 70% 62 | :target: . 63 | 64 | 65 | ######### 66 | graph-pes 67 | ######### 68 | 69 | .. raw:: html 70 | :file: hide-title.html 71 | 72 | **Date:** |today| - **Author:** `John Gardner `__ - **Version:** |release| 73 | 74 | ``graph-pes`` is a package designed to accelerate the development of machine-learned potential energy surfaces (ML-PESs) that act on graph representations of atomic structures. 75 | 76 | 77 | The core component of ``graph-pes`` is the :class:`~graph_pes.GraphPESModel`. 78 | You can take **any** model that inherits from this class and: 79 | 80 | * train and/or fine-tune it on your own data using the ``graph-pes-train`` command line tool 81 | * use it to drive MD simulations via :doc:`LAMMPS ` or :doc:`ASE ` 82 | 83 | We provide many :class:`~graph_pes.GraphPESModel`\ s, including: 84 | 85 | * re-implementations of popular architectures, including :class:`~graph_pes.models.NequIP`, :class:`~graph_pes.models.PaiNN`, :class:`~graph_pes.models.MACE` and :class:`~graph_pes.models.TensorNet` 86 | * wrappers for other popular ML-PES frameworks, including :doc:`mace-torch `, :doc:`mattersim `, and :doc:`orb-models `, that convert their models into ``graph-pes`` compatible :class:`~graph_pes.GraphPESModel` instances 87 | 88 | Use ``graph-pes`` to train models from scratch, experiment with new architectures, write architecture-agnostic validation pipelines, and try out different foundation models with minimal code changes. 89 | 90 | 91 | **Useful links**: 92 | 93 | .. grid:: 1 2 3 3 94 | :gutter: 3 95 | 96 | .. 
grid-item-card:: 🔥 Train 97 | :link: quickstart/quickstart 98 | :link-type: doc 99 | :text-align: center 100 | 101 | Train an existing architecture from scratch 102 | 103 | .. grid-item-card:: 🔍 Analyse 104 | :link: https://jla-gardner.github.io/graph-pes/quickstart/quickstart.html#Model-analysis 105 | :text-align: center 106 | 107 | Analyse a trained model 108 | 109 | .. grid-item-card:: 🔧 Fine-tune 110 | :link: quickstart/fine-tuning 111 | :link-type: doc 112 | :text-align: center 113 | 114 | Fine-tune a foundation model on your data 115 | 116 | .. grid-item-card:: 🔨 Build 117 | :link: quickstart/implement-a-model 118 | :link-type: doc 119 | :text-align: center 120 | 121 | Implement your own ML-PES architecture 122 | 123 | .. grid-item-card:: 🧪 Experiment 124 | :link: quickstart/custom-training-loop 125 | :link-type: doc 126 | :text-align: center 127 | 128 | Define a custom training loop 129 | 130 | .. grid-item-card:: 🎓 Learn 131 | :link: theory 132 | :link-type: doc 133 | :text-align: center 134 | 135 | Learn more about the properties of PESs 136 | 137 | 138 | 139 | **Installation:** 140 | 141 | Install ``graph-pes`` using pip. We recommend doing this in a new environment (e.g. using conda): 142 | 143 | .. code-block:: bash 144 | 145 | conda create -n graph-pes python=3.10 -y 146 | conda activate graph-pes 147 | pip install graph-pes 148 | 149 | Please see the `GitHub repository `__ for the source code and to report issues. 150 | -------------------------------------------------------------------------------- /docs/source/interfaces/mace.rst: -------------------------------------------------------------------------------- 1 | ``mace-torch`` 2 | ============== 3 | 4 | 5 | ``graph-pes`` supports the conversion of arbitrary ``mace-torch`` models to :class:`~graph_pes.GraphPESModel` objects via the :class:`~graph_pes.interfaces._mace.MACEWrapper` class. 6 | 7 | We also provide convenience functions to access the recently trained ``MACE-MP`` and ``MACE-OFF`` "foundation" models, as well as the ``GO-MACE-23`` and ``Egret-1`` series of models. 8 | 9 | You can use all of these models in the same way as any other :class:`~graph_pes.GraphPESModel`, either via the Python API: 10 | 11 | .. code-block:: python 12 | 13 | from graph_pes.interfaces import mace_mp 14 | model = mace_mp("medium-0b3") 15 | model.predict_energy(graph) 16 | 17 | 18 | or within a ``graph-pes-train`` configuration file: 19 | 20 | .. code-block:: yaml 21 | 22 | model: 23 | +mace_off: 24 | model: small 25 | 26 | If you use any ``mace-torch`` models in your work, please visit the `mace-torch `__ repository and cite the following: 27 | 28 | .. code-block:: bibtex 29 | 30 | @inproceedings{Batatia2022mace, 31 | title={{MACE}: Higher Order Equivariant Message Passing Neural Networks for Fast and Accurate Force Fields}, 32 | author={Ilyes Batatia and David Peter Kovacs and Gregor N. C. Simm and Christoph Ortner and Gabor Csanyi}, 33 | booktitle={Advances in Neural Information Processing Systems}, 34 | editor={Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho}, 35 | year={2022}, 36 | url={https://openreview.net/forum?id=YPpSngE-ZU} 37 | } 38 | 39 | @misc{Batatia2022Design, 40 | title = {The Design Space of E(3)-Equivariant Atom-Centered Interatomic Potentials}, 41 | author = {Batatia, Ilyes and Batzner, Simon and Kov{\'a}cs, D{\'a}vid P{\'e}ter and Musaelian, Albert and Simm, Gregor N. C. 
and Drautz, Ralf and Ortner, Christoph and Kozinsky, Boris and Cs{\'a}nyi, G{\'a}bor}, 42 | year = {2022}, 43 | number = {arXiv:2205.06643}, 44 | eprint = {2205.06643}, 45 | eprinttype = {arxiv}, 46 | doi = {10.48550/arXiv.2205.06643}, 47 | archiveprefix = {arXiv} 48 | } 49 | 50 | 51 | 52 | Installation 53 | ------------ 54 | 55 | To install ``graph-pes`` with support for MACE models, you need to install 56 | the `mace-torch `__ package. We recommend doing this in a new environment: 57 | 58 | .. code-block:: bash 59 | 60 | conda create -n graph-pes-mace python=3.10 61 | conda activate graph-pes-mace 62 | pip install mace-torch graph-pes 63 | 64 | 65 | Interface 66 | --------- 67 | 68 | .. autofunction:: graph_pes.interfaces.mace_mp 69 | .. autofunction:: graph_pes.interfaces.mace_off 70 | .. autofunction:: graph_pes.interfaces.go_mace_23 71 | .. autofunction:: graph_pes.interfaces.egret 72 | 73 | .. autoclass:: graph_pes.interfaces._mace.MACEWrapper 74 | -------------------------------------------------------------------------------- /docs/source/interfaces/mattersim.rst: -------------------------------------------------------------------------------- 1 | ``mattersim`` 2 | ============= 3 | 4 | 5 | ``graph-pes`` allows you to fine-tune and use the ``mattersim`` series of models in the same way as any other :class:`~graph_pes.GraphPESModel`, either via the Python API: 6 | 7 | 8 | .. code-block:: python 9 | 10 | from graph_pes.interfaces import mattersim 11 | model = mattersim("mattersim-v1.0.0-1m") 12 | model.predict_energy(graph) 13 | 14 | ... or within a ``graph-pes-train`` configuration file: 15 | 16 | .. code-block:: yaml 17 | 18 | model: 19 | +mattersim: 20 | load_path: "mattersim-v1.0.0-5m" 21 | 22 | 23 | 24 | If you use any ``mattersim`` models in your work, please visit the `mattersim `__ repository and cite the following: 25 | 26 | .. code-block:: bibtex 27 | 28 | @article{yang2024mattersim, 29 | title={MatterSim: A Deep Learning Atomistic Model Across Elements, Temperatures and Pressures}, 30 | author={Han Yang and Chenxi Hu and Yichi Zhou and Xixian Liu and Yu Shi and Jielan Li and Guanzhi Li and Zekun Chen and Shuizhou Chen and Claudio Zeni and Matthew Horton and Robert Pinsler and Andrew Fowler and Daniel Zügner and Tian Xie and Jake Smith and Lixin Sun and Qian Wang and Lingyu Kong and Chang Liu and Hongxia Hao and Ziheng Lu}, 31 | year={2024}, 32 | eprint={2405.04967}, 33 | archivePrefix={arXiv}, 34 | primaryClass={cond-mat.mtrl-sci}, 35 | url={https://arxiv.org/abs/2405.04967}, 36 | journal={arXiv preprint arXiv:2405.04967} 37 | } 38 | 39 | 40 | 41 | 42 | Installation 43 | ------------ 44 | 45 | To install ``graph-pes`` with support for ``mattersim`` models, you need to install 46 | the `mattersim `__ package. We recommend doing this in a new environment: 47 | 48 | .. code-block:: bash 49 | 50 | conda create -n graph-pes-mattersim python=3.9 51 | conda activate graph-pes-mattersim 52 | pip install graph-pes 53 | pip install --upgrade mattersim 54 | 55 | 56 | Interface 57 | --------- 58 | 59 | .. autofunction:: graph_pes.interfaces.mattersim 60 | 61 | -------------------------------------------------------------------------------- /docs/source/interfaces/orb.rst: -------------------------------------------------------------------------------- 1 | ``orb-models`` 2 | ============== 3 | 4 | 5 | ``graph-pes`` supports the conversion of arbitrary ``orb-models`` models to :class:`~graph_pes.GraphPESModel` objects via the :class:`~graph_pes.interfaces._orb.OrbWrapper` class.
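For instance (a minimal sketch only: we assume here that :class:`~graph_pes.interfaces._orb.OrbWrapper` takes the underlying ``orb-models`` object as its constructor argument, and that ``orb-models`` exposes its usual ``pretrained`` loaders):

.. code-block:: python

    from orb_models.forcefield import pretrained

    from graph_pes.interfaces._orb import OrbWrapper

    # load a raw orb-models potential ...
    raw_orb = pretrained.orb_v2()

    # ... and wrap it as a graph-pes compatible model
    model = OrbWrapper(raw_orb)

In most cases, the :func:`~graph_pes.interfaces.orb_model` convenience function described below is the easier route.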
6 | 7 | Use the :func:`~graph_pes.interfaces.orb_model` function to load a pre-trained ``orb-models`` model and convert it into a :class:`~graph_pes.GraphPESModel`. You can then use this model in the same way as any other :class:`~graph_pes.GraphPESModel`, for instance by :doc:`fine-tuning it <../quickstart/fine-tuning>` or using it to run MD via 8 | :doc:`torch-sim <../tools/torch-sim>`, 9 | :doc:`ASE <../tools/ase>` or :doc:`LAMMPS <../tools/lammps>`: 10 | 11 | .. code-block:: python 12 | 13 | from graph_pes.interfaces import orb_model 14 | from graph_pes import GraphPESModel 15 | 16 | model = orb_model() 17 | assert isinstance(model, GraphPESModel) 18 | 19 | # do stuff ... 20 | 21 | 22 | You can also reference the :func:`~graph_pes.interfaces.orb_model` function in your training configs for :doc:`graph-pes-train <../cli/graph-pes-train/root>`: 23 | 24 | .. code-block:: yaml 25 | 26 | model: 27 | +orb_model: 28 | name: orb-v3-direct-20-omat 29 | 30 | 31 | 32 | If you use any ``orb-models`` models in your work, please visit the `orb-models `_ repository and cite the following: 33 | 34 | .. code-block:: bibtex 35 | 36 | @misc{rhodes2025orbv3atomisticsimulationscale, 37 | title={Orb-v3: atomistic simulation at scale}, 38 | author={ 39 | Benjamin Rhodes and Sander Vandenhaute and Vaidotas Šimkus 40 | and James Gin and Jonathan Godwin and Tim Duignan and Mark Neumann 41 | }, 42 | year={2025}, 43 | eprint={2504.06231}, 44 | archivePrefix={arXiv}, 45 | primaryClass={cond-mat.mtrl-sci}, 46 | url={https://arxiv.org/abs/2504.06231}, 47 | } 48 | 49 | @misc{neumann2024orbfastscalableneural, 50 | title={Orb: A Fast, Scalable Neural Network Potential}, 51 | author={ 52 | Mark Neumann and James Gin and Benjamin Rhodes 53 | and Steven Bennett and Zhiyi Li and Hitarth Choubisa 54 | and Arthur Hussey and Jonathan Godwin 55 | }, 56 | year={2024}, 57 | eprint={2410.22570}, 58 | archivePrefix={arXiv}, 59 | primaryClass={cond-mat.mtrl-sci}, 60 | url={https://arxiv.org/abs/2410.22570}, 61 | } 62 | 63 | 64 | Installation 65 | ------------ 66 | 67 | To install ``graph-pes`` with support for ``orb-models`` models, you need to install 68 | the `orb-models `_ package alongside ``graph-pes``. We recommend doing this in a new environment: 69 | 70 | .. code-block:: bash 71 | 72 | conda create -n graph-pes-orb python=3.10 73 | conda activate graph-pes-orb 74 | pip install graph-pes orb-models 75 | 76 | 77 | Interface 78 | --------- 79 | 80 | .. autofunction:: graph_pes.interfaces.orb_model 81 | 82 | .. autoclass:: graph_pes.interfaces._orb.OrbWrapper 83 | :members: orb_model -------------------------------------------------------------------------------- /docs/source/models/addition.rst: -------------------------------------------------------------------------------- 1 | Addition Model 2 | ============== 3 | 4 | .. autoclass:: graph_pes.models.AdditionModel 5 | :special-members: __getitem__ 6 | :show-inheritance: -------------------------------------------------------------------------------- /docs/source/models/many-body/eddp.rst: -------------------------------------------------------------------------------- 1 | EDDP 2 | #### 3 | 4 | Train this architecture on your own data using the :doc:`graph-pes-train <../../cli/graph-pes-train/root>` CLI, using e.g. the following config: 5 | 6 | 7 | .. code-block:: yaml 8 | 9 | model: 10 | +EDDP: 11 | elements: [H, C, N, O] 12 | cutoff: 5.0 13 | three_body_cutoff: 3.0 14 | 15 | Definition 16 | ---------- 17 | 18 | .. 
autoclass:: graph_pes.models.EDDP 19 | :show-inheritance: 20 | -------------------------------------------------------------------------------- /docs/source/models/many-body/mace.rst: -------------------------------------------------------------------------------- 1 | MACE 2 | #### 3 | 4 | Train this architecture on your own data using the :doc:`graph-pes-train <../../cli/graph-pes-train/root>` CLI, using e.g. the following config: 5 | 6 | .. code-block:: yaml 7 | 8 | model: 9 | +MACE: 10 | elements: [H, C, N, O] 11 | 12 | 13 | Definition 14 | ---------- 15 | 16 | .. autoclass:: graph_pes.models.MACE 17 | .. autoclass:: graph_pes.models.ZEmbeddingMACE 18 | 19 | ``ScaleShiftMACE``? 20 | ------------------- 21 | 22 | To replicate a ``ScaleShiftMACE`` model as defined in the reference `MACE `_ implementation, you could use the following config: 23 | 24 | .. code-block:: yaml 25 | 26 | model: 27 | offset: 28 | +LearnableOffset: {} 29 | many-body: 30 | +MACE: 31 | elements: [H, C, N, O] -------------------------------------------------------------------------------- /docs/source/models/many-body/nequip.rst: -------------------------------------------------------------------------------- 1 | NequIP 2 | ====== 3 | 4 | Train this architecture on your own data using the :doc:`graph-pes-train <../../cli/graph-pes-train/root>` CLI, using e.g. the following config: 5 | 6 | .. code-block:: yaml 7 | 8 | model: 9 | +NequIP: 10 | elements: [H, C, N, O] 11 | features: 12 | channels: [64, 32, 8] 13 | l_max: 2 14 | use_odd_parity: true 15 | 16 | 17 | Definition 18 | ---------- 19 | 20 | .. autoclass:: graph_pes.models.NequIP 21 | .. autoclass:: graph_pes.models.ZEmbeddingNequIP 22 | 23 | Utilities 24 | --------- 25 | 26 | .. autoclass:: graph_pes.models.e3nn.nequip.SimpleIrrepSpec() 27 | :show-inheritance: 28 | .. autoclass:: graph_pes.models.e3nn.nequip.CompleteIrrepSpec() 29 | :show-inheritance: -------------------------------------------------------------------------------- /docs/source/models/many-body/painn.rst: -------------------------------------------------------------------------------- 1 | PaiNN 2 | ##### 3 | 4 | Train this architecture on your own data using the :doc:`graph-pes-train <../../cli/graph-pes-train/root>` CLI, using e.g. the following config: 5 | 6 | .. code-block:: yaml 7 | 8 | model: 9 | +PaiNN: 10 | channels: 32 11 | 12 | Definition 13 | ---------- 14 | 15 | .. autoclass:: graph_pes.models.PaiNN 16 | :show-inheritance: 17 | .. autoclass:: graph_pes.models.painn.Interaction 18 | .. autoclass:: graph_pes.models.painn.Update 19 | -------------------------------------------------------------------------------- /docs/source/models/many-body/root.rst: -------------------------------------------------------------------------------- 1 | Many Body Models 2 | ================ 3 | 4 | ``graph-pes`` has re-implemented the following popular many-body machine-learned interatomic potentials: 5 | 6 | .. toctree:: 7 | :maxdepth: 1 8 | 9 | stillinger-weber 10 | eddp 11 | schnet 12 | painn 13 | nequip 14 | mace 15 | tensornet 16 | 17 | 18 | -------------------------------------------------------------------------------- /docs/source/models/many-body/schnet.rst: -------------------------------------------------------------------------------- 1 | SchNet 2 | ====== 3 | 4 | Train this architecture on your own data using the :doc:`graph-pes-train <../../cli/graph-pes-train/root>` CLI, using e.g. the following config: 5 | 6 | ..
code-block:: yaml 7 | 8 | model: 9 | +SchNet: 10 | channels: 32 11 | 12 | Definition 13 | ---------- 14 | 15 | .. autoclass:: graph_pes.models.SchNet 16 | :show-inheritance: 17 | 18 | Components 19 | ---------- 20 | 21 | .. autoclass:: graph_pes.models.schnet.SchNetInteraction 22 | .. autoclass:: graph_pes.models.schnet.CFConv -------------------------------------------------------------------------------- /docs/source/models/many-body/stillinger-weber.rst: -------------------------------------------------------------------------------- 1 | ################ 2 | Stillinger-Weber 3 | ################ 4 | 5 | Use this empirical model directly via the Python API: 6 | 7 | .. code-block:: python 8 | 9 | from graph_pes.models import StillingerWeber 10 | model = StillingerWeber() 11 | model.predict_energy(graph) 12 | 13 | # or monatomic water 14 | model = StillingerWeber.monatomic_water() 15 | model.predict_energy(graph) 16 | 17 | or within a ``graph-pes-train`` configuration file to :doc:`train a new model <../../cli/graph-pes-train/root>`. 18 | 19 | .. code-block:: yaml 20 | 21 | model: 22 | +StillingerWeber: 23 | sigma: 3 24 | 25 | 26 | Definition 27 | ---------- 28 | 29 | .. autoclass:: graph_pes.models.StillingerWeber 30 | :members: monatomic_water 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /docs/source/models/many-body/tensornet.rst: -------------------------------------------------------------------------------- 1 | TensorNet 2 | ========= 3 | 4 | Train this architecture on your own data using the :doc:`graph-pes-train <../../cli/graph-pes-train/root>` CLI, using e.g. the following config: 5 | 6 | .. code-block:: yaml 7 | 8 | model: 9 | +TensorNet: 10 | channels: 32 11 | 12 | Definition 13 | ---------- 14 | 15 | .. autoclass:: graph_pes.models.TensorNet 16 | :show-inheritance: 17 | 18 | Components 19 | ---------- 20 | 21 | Below, we use the notation from the `TensorNet paper `__. 22 | 23 | .. autoclass:: graph_pes.models.tensornet.ScalarOutput 24 | .. autoclass:: graph_pes.models.tensornet.VectorOutput 25 | 26 | -------------------------------------------------------------------------------- /docs/source/models/offsets.rst: -------------------------------------------------------------------------------- 1 | Energy Offset 2 | ============= 3 | 4 | .. autoclass:: graph_pes.models.offsets.EnergyOffset 5 | :show-inheritance: 6 | 7 | .. autoclass:: graph_pes.models.FixedOffset 8 | :show-inheritance: 9 | 10 | .. autoclass:: graph_pes.models.LearnableOffset 11 | :show-inheritance: -------------------------------------------------------------------------------- /docs/source/models/pairwise.rst: -------------------------------------------------------------------------------- 1 | Pair Potentials 2 | =============== 3 | 4 | Base Class 5 | ---------- 6 | 7 | .. autoclass:: graph_pes.models.PairPotential 8 | :members: interaction 9 | :show-inheritance: 10 | 11 | Available Pair Potentials 12 | ------------------------- 13 | 14 | 15 | .. autoclass:: graph_pes.models.LennardJones 16 | :show-inheritance: 17 | :members: from_ase 18 | 19 | .. autoclass:: graph_pes.models.ZBLCoreRepulsion 20 | :show-inheritance: 21 | 22 | .. autoclass:: graph_pes.models.Morse 23 | :show-inheritance: 24 | 25 | .. autoclass:: graph_pes.models.LennardJonesMixture 26 | :show-inheritance: 27 | 28 | ..
autoclass:: graph_pes.models.SmoothedPairPotential 29 | :show-inheritance: -------------------------------------------------------------------------------- /docs/source/models/root.rst: -------------------------------------------------------------------------------- 1 | .. _models: 2 | 3 | 4 | Models 5 | ###### 6 | 7 | 8 | .. autoclass:: graph_pes.GraphPESModel 9 | :members: 10 | :show-inheritance: 11 | 12 | 13 | Loading Models 14 | ============== 15 | 16 | .. autofunction:: graph_pes.models.load_model 17 | .. autofunction:: graph_pes.models.load_model_component 18 | 19 | Freezing Models 20 | =============== 21 | 22 | .. class:: graph_pes.models.T 23 | 24 | Type alias for ``TypeVar("T", bound=torch.nn.Module)``. 25 | 26 | .. autofunction:: graph_pes.models.freeze 27 | .. autofunction:: graph_pes.models.freeze_matching 28 | .. autofunction:: graph_pes.models.freeze_all_except 29 | .. autofunction:: graph_pes.models.freeze_any_matching 30 | 31 | 32 | Available Models 33 | ================ 34 | 35 | .. toctree:: 36 | :maxdepth: 2 37 | 38 | addition 39 | offsets 40 | pairwise 41 | many-body/root 42 | 43 | 44 | Unit Conversion 45 | =============== 46 | 47 | .. autoclass:: graph_pes.models.UnitConverter 48 | :show-inheritance: 49 | -------------------------------------------------------------------------------- /docs/source/quickstart/fine-tune.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | train: train.xyz 3 | valid: valid.xyz 4 | 5 | loss: 6 | - +PerAtomEnergyLoss() 7 | - +ForceRMSE() 8 | 9 | fitting: 10 | trainer_kwargs: 11 | max_epochs: 20 12 | accelerator: cpu 13 | 14 | optimizer: 15 | name: Adam 16 | lr: 0.0001 17 | 18 | auto_fit_reference_energies: true 19 | 20 | wandb: null 21 | general: 22 | progress: logged 23 | run_id: mp0-fine-tune 24 | -------------------------------------------------------------------------------- /docs/source/quickstart/mp0.yaml: -------------------------------------------------------------------------------- 1 | 2 | model: 3 | +mace_mp: 4 | model: small 5 | 6 | general: 7 | run_id: mp0-fine-tune 8 | -------------------------------------------------------------------------------- /docs/source/quickstart/orb.yaml: -------------------------------------------------------------------------------- 1 | 2 | model: 3 | +freeze_all_except: 4 | model: 5 | +orb_model: 6 | name: orb-d3-xs-v2 7 | pattern: _orb\.heads.* 8 | 9 | general: 10 | run_id: orb-fine-tune 11 | -------------------------------------------------------------------------------- /docs/source/quickstart/quickstart-cgap17.yaml: -------------------------------------------------------------------------------- 1 | # define a radial cutoff to use throughout the config 2 | CUTOFF: 3.7 # in Å 3 | 4 | general: 5 | progress: logged 6 | 7 | # train a lightweight NequIP model ... 8 | model: 9 | offset: 10 | # note the "+" prefix syntax: refer to the 11 | # data2objects package for more details 12 | +FixedOffset: { C: -148.314002 } 13 | many-body: 14 | +NequIP: 15 | elements: [C] 16 | cutoff: =/CUTOFF # reference the radial cutoff defined above 17 | layers: 2 18 | features: 19 | channels: [16, 8, 4] 20 | l_max: 2 21 | use_odd_parity: true 22 | self_interaction: linear 23 | 24 | # ... on structures from local files ... 25 | data: 26 | train: 27 | path: train-cgap17.xyz 28 | n: 1280 29 | shuffle: false 30 | valid: val-cgap17.xyz 31 | test: test-cgap17.xyz 32 | 33 | # ... on both energy and forces (weighted 1:1) ... 
34 | loss: 35 | - +PerAtomEnergyLoss() 36 | - +ForceRMSE() 37 | 38 | # ... with the following settings ... 39 | fitting: 40 | trainer_kwargs: 41 | max_epochs: 250 42 | accelerator: auto 43 | check_val_every_n_epoch: 5 44 | 45 | optimizer: 46 | name: AdamW 47 | lr: 0.01 48 | 49 | scheduler: 50 | name: ReduceLROnPlateau 51 | factor: 0.5 52 | patience: 10 53 | 54 | loader_kwargs: 55 | batch_size: 64 56 | 57 | # ... and log to Weights & Biases 58 | wandb: 59 | project: graph-pes-quickstart 60 | -------------------------------------------------------------------------------- /docs/source/quickstart/root.rst: -------------------------------------------------------------------------------- 1 | Tutorials 2 | ========= 3 | 4 | Run each of these notebooks locally, or use the Google Colab links to follow along with no installation required. 5 | 6 | .. toctree:: 7 | :maxdepth: 1 8 | 9 | quickstart 10 | fine-tuning 11 | implement-a-model 12 | custom-training-loop 13 | 14 | 15 | -------------------------------------------------------------------------------- /docs/source/tools/analysis.rst: -------------------------------------------------------------------------------- 1 | ######## 2 | Analysis 3 | ######## 4 | 5 | 6 | ``graph-pes`` provides a number of utilities for analysing trained models: 7 | 8 | 9 | .. autofunction:: graph_pes.utils.analysis.parity_plot 10 | .. autofunction:: graph_pes.utils.analysis.dimer_curve 11 | 12 | .. class:: graph_pes.utils.analysis.Transform 13 | 14 | Alias for ``Callable[[Tensor, AtomicGraph], Tensor]``. 15 | 16 | Transforms map a property, :math:`x`, to a target property, :math:`y`, 17 | conditioned on an :class:`~graph_pes.AtomicGraph`, :math:`\mathcal{G}`: 18 | 19 | .. math:: 20 | 21 | T: (x; \mathcal{G}) \mapsto y -------------------------------------------------------------------------------- /docs/source/utils.rst: -------------------------------------------------------------------------------- 1 | Utils 2 | ===== 3 | 4 | 5 | Shift and Scale 6 | --------------- 7 | 8 | .. autofunction:: graph_pes.utils.shift_and_scale.guess_per_element_mean_and_var 9 | 10 | 11 | Sampling 12 | -------- 13 | 14 | .. autoclass:: graph_pes.utils.sampling.T 15 | 16 | .. autoclass:: graph_pes.utils.sampling.SequenceSampler 17 | :members: 18 | 19 | Useful Types and Aliases 20 | ------------------------ 21 | 22 | .. class:: graph_pes.atomic_graph.PropertyKey 23 | 24 | Type alias for ``Literal["energy", "forces", "stress", "virial", "local_energies"]``.
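Since this is a ``Literal`` alias, type checkers can verify property names at call sites. A minimal sketch (the ``report`` function is hypothetical, purely for illustration):

.. code-block:: python

    from graph_pes.atomic_graph import PropertyKey

    def report(key: PropertyKey) -> None:
        # a type checker will reject e.g. report("enrgy")
        print(f"will predict {key}")

    report("energy")
    report("forces")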
-------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0.0", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "graph-pes" 7 | version = "0.1.1" 8 | description = "Potential Energy Surfaces on Graphs" 9 | readme = "README.md" 10 | authors = [{ name = "John Gardner", email = "gardner.john97@gmail.com" }] 11 | license = { file = "LICENSE" } 12 | classifiers = [ 13 | "License :: OSI Approved :: MIT License", 14 | "Programming Language :: Python", 15 | "Programming Language :: Python :: 3", 16 | ] 17 | keywords = [] 18 | dependencies = [ 19 | "torch", 20 | "pytorch-lightning", 21 | "ase", 22 | "numpy", 23 | "rich", 24 | "dacite", 25 | "e3nn==0.4.4", 26 | "scikit-learn", 27 | "locache>=4.0.2", 28 | "load-atoms>=0.3.9", 29 | "wandb", 30 | "data2objects>=0.1.0", 31 | "pyright>=1.1.394", 32 | "vesin>=0.3.2", 33 | ] 34 | requires-python = ">=3.9" 35 | 36 | 37 | [project.optional-dependencies] 38 | test = ["pytest", "pytest-cov"] 39 | docs = [ 40 | "sphinx", 41 | "furo", 42 | "nbsphinx", 43 | "sphinxext-opengraph", 44 | "sphinx-copybutton", 45 | "sphinx-design", 46 | ] 47 | publish = ["build", "twine"] 48 | 49 | [project.scripts] 50 | graph-pes-train = "graph_pes.scripts.train:main" 51 | graph-pes-test = "graph_pes.scripts.test:main" 52 | graph-pes-resume = "graph_pes.scripts.resume:main" 53 | graph-pes-id = "graph_pes.scripts.id:main" 54 | 55 | [project.urls] 56 | Homepage = "https://github.com/jla-gardner/graph-pes" 57 | 58 | [tool.bumpver] 59 | current_version = "0.1.1" 60 | version_pattern = "MAJOR.MINOR.PATCH" 61 | commit_message = "{old_version} -> {new_version}" 62 | commit = true 63 | tag = true 64 | push = false 65 | 66 | [tool.bumpver.file_patterns] 67 | "pyproject.toml" = ['current_version = "{version}"', 'version = "{version}"'] 68 | "src/graph_pes/__init__.py" = ["{version}"] 69 | "docs/source/conf.py" = ['release = "{version}"'] 70 | "src/graph_pes/graph_pes_model.py" = [ 71 | 'self._GRAPH_PES_VERSION: Final\[str\] = "{version}"', 72 | ] 73 | "CITATION.cff" = ['^version: {version}$'] 74 | 75 | [tool.ruff] 76 | line-length = 80 77 | indent-width = 4 78 | target-version = "py38" 79 | extend-include = ["*.ipynb", "*.pyi", "*.toml"] 80 | 81 | [tool.ruff.lint] 82 | select = ["E", "F", "UP", "B", "SIM", "I"] 83 | ignore = ["SIM300", "E402", "E703", "F722", "UP037", "F821", "B018", "E741"] 84 | fixable = ["ALL"] 85 | unfixable = [] 86 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 87 | 88 | [tool.ruff.format] 89 | quote-style = "double" 90 | indent-style = "space" 91 | skip-magic-trailing-comma = false 92 | line-ending = "auto" 93 | 94 | [tool.ruff.lint.pydocstyle] 95 | convention = "numpy" 96 | 97 | [tool.coverage.report] 98 | exclude_also = [ 99 | "def __repr__", 100 | "def _repr", 101 | "class .*\\bProtocol\\):", 102 | "@(abc\\.)?abstractmethod", 103 | "\\.\\.\\.", 104 | "except ImportError", 105 | "if TYPE_CHECKING", 106 | "raise NotImplementedError", 107 | "if __name__ == '__main__':", 108 | ] 109 | 110 | # allow for coverage to find relevant files in both src/ and */site-packages 111 | # so that we can install the package both normally and in editable mode, and 112 | # still get coverage for both cases using `pytest --cov` 113 | [tool.coverage.paths] 114 | source = ["src", "*/site-packages"] 115 | 116 | [tool.coverage.run] 117 | branch = true 118 | source = 
["graph_pes"] 119 | omit = ["*/graph_pes/interfaces/*"] 120 | 121 | [tool.pytest.ini_options] 122 | # ignore all warnings coming from the pytorch_lightning package 123 | filterwarnings = [ 124 | "ignore::DeprecationWarning:pytorch_lightning", 125 | "ignore::DeprecationWarning:lightning_fabric", 126 | "ignore::DeprecationWarning:lightning_utilities", 127 | "ignore::DeprecationWarning:pkg_resources", 128 | "ignore::DeprecationWarning:torchmetrics", 129 | "ignore::UserWarning:pytorch_lightning", 130 | "ignore:.*The TorchScript type system doesn't support instance-level annotations on empty non-base types.*", 131 | ] 132 | norecursedirs = "tests/helpers" 133 | 134 | [dependency-groups] 135 | dev = ["bumpver>=2024.1130", "notebook>=7.3.2", "ruff", "sphinx-autobuild"] 136 | -------------------------------------------------------------------------------- /src/graph_pes/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | 5 | from graph_pes.atomic_graph import AtomicGraph 6 | from graph_pes.graph_pes_model import GraphPESModel 7 | 8 | # hide the annoying FutureWarning from e3nn 9 | warnings.filterwarnings("ignore", category=FutureWarning, module="e3nn") 10 | 11 | # fix e3nns torch.load without weights_only 12 | if hasattr(torch.serialization, "add_safe_globals"): 13 | torch.serialization.add_safe_globals([slice]) 14 | 15 | __all__ = ["AtomicGraph", "GraphPESModel"] 16 | __version__ = "0.1.1" 17 | -------------------------------------------------------------------------------- /src/graph_pes/config/training-defaults.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | seed: 42 3 | root_dir: "graph-pes-results" 4 | run_id: null 5 | log_level: "INFO" 6 | progress: rich 7 | torch: 8 | dtype: float32 9 | float32_matmul_precision: high 10 | 11 | fitting: 12 | pre_fit_model: true 13 | max_n_pre_fit: 5_000 14 | 15 | # train for 100 epochs on the best device available 16 | trainer_kwargs: 17 | max_epochs: 100 18 | accelerator: auto 19 | enable_model_summary: false 20 | 21 | loader_kwargs: 22 | num_workers: 1 23 | persistent_workers: true 24 | batch_size: 4 25 | pin_memory: false 26 | 27 | # "fancy"/optional training options disabled 28 | callbacks: [] 29 | scheduler: null 30 | swa: null 31 | early_stopping: null 32 | auto_fit_reference_energies: false 33 | 34 | # this is deprecated, use `early_stopping` instead 35 | early_stopping_patience: null 36 | 37 | wandb: {} 38 | -------------------------------------------------------------------------------- /src/graph_pes/config/training.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: UP006, UP007 2 | # ^^ NB: dacite parsing requires the old type hint syntax in 3 | # order to be compatible with all versions of Python that 4 | # we are targeting (3.9+) 5 | from __future__ import annotations 6 | 7 | from dataclasses import dataclass, field 8 | from pathlib import Path 9 | from typing import Any, Dict, List, Literal, Union 10 | 11 | import yaml 12 | from pytorch_lightning import Callback 13 | 14 | from graph_pes.config.shared import TorchConfig, parse_dataset_collection 15 | from graph_pes.data.datasets import DatasetCollection 16 | from graph_pes.graph_pes_model import GraphPESModel 17 | from graph_pes.training.callbacks import VerboseSWACallback 18 | from graph_pes.training.loss import Loss, TotalLoss 19 | from graph_pes.training.opt import LRScheduler, Optimizer 20 | 21 | 22 
| @dataclass 23 | class EarlyStoppingConfig: 24 | patience: int 25 | """ 26 | The number of validation checks with no improvement before stopping. 27 | """ 28 | 29 | min_delta: float = 0.0 30 | """ 31 | The minimum change in the monitored quantity to qualify as an improvement. 32 | """ 33 | 34 | monitor: str = "valid/loss/total" 35 | """The quantity to monitor.""" 36 | 37 | 38 | @dataclass 39 | class FittingOptions: 40 | """Options for the fitting process.""" 41 | 42 | pre_fit_model: bool 43 | max_n_pre_fit: Union[int, None] 44 | early_stopping: Union[EarlyStoppingConfig, None] 45 | loader_kwargs: Dict[str, Any] 46 | early_stopping_patience: Union[int, None] 47 | """ 48 | DEPRECATED: use the `early_stopping` config option instead. 49 | """ 50 | auto_fit_reference_energies: bool 51 | 52 | 53 | @dataclass 54 | class SWAConfig: 55 | """ 56 | Configuration for Stochastic Weight Averaging. 57 | 58 | Internally, this is handled by `this PyTorch Lightning callback 59 | `__. 60 | """ 61 | 62 | lr: float 63 | """ 64 | The learning rate to use during the SWA phase. If not specified, 65 | the learning rate from the end of the training phase will be used. 66 | """ 67 | 68 | start: Union[int, float] = 0.8 69 | """ 70 | The epoch at which to start SWA. If a float, it will be interpreted 71 | as a fraction of the total number of epochs. 72 | """ 73 | 74 | anneal_epochs: int = 10 75 | """ 76 | The number of epochs over which to linearly anneal the learning rate 77 | to zero. 78 | """ 79 | 80 | strategy: Literal["linear", "cos"] = "linear" 81 | """The strategy to use for annealing the learning rate.""" 82 | 83 | def instantiate_lightning_callback(self): 84 | return VerboseSWACallback( 85 | swa_lrs=self.lr, 86 | swa_epoch_start=self.start, 87 | annealing_epochs=self.anneal_epochs, 88 | annealing_strategy=self.strategy, 89 | ) 90 | 91 | 92 | @dataclass 93 | class FittingConfig(FittingOptions): 94 | """Configuration for the fitting process.""" 95 | 96 | trainer_kwargs: Dict[str, Any] 97 | optimizer: Union[Optimizer, Dict[str, Any], None] = None 98 | scheduler: Union[LRScheduler, Dict[str, Any], None] = None 99 | swa: Union[SWAConfig, None] = None 100 | callbacks: List[Callback] = field(default_factory=list) 101 | 102 | def get_optimizer(self) -> Optimizer: 103 | if isinstance(self.optimizer, Optimizer): 104 | return self.optimizer 105 | 106 | kwargs = {"name": "Adam", "lr": 1e-3} 107 | kwargs.update(self.optimizer or {}) 108 | return Optimizer(**kwargs) 109 | 110 | def get_scheduler(self) -> LRScheduler | None: 111 | if isinstance(self.scheduler, LRScheduler): 112 | return self.scheduler 113 | 114 | if self.scheduler is None: 115 | return None 116 | 117 | return LRScheduler(**self.scheduler) 118 | 119 | 120 | @dataclass 121 | class GeneralConfig: 122 | """General configuration for a training run.""" 123 | 124 | seed: int 125 | root_dir: str 126 | run_id: Union[str, None] 127 | torch: TorchConfig 128 | log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO" 129 | progress: Literal["rich", "logged"] = "rich" 130 | 131 | 132 | @dataclass 133 | class TrainingConfig: 134 | """ 135 | A schema for a configuration file to train a 136 | :class:`~graph_pes.GraphPESModel`. 
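    For illustration only, a minimal YAML file mapping onto this schema
    might look like the following (the file paths and model choice are
    placeholders; compare ``interfaces/quick.yaml`` elsewhere in this
    repository):

    .. code-block:: yaml

        model:
            +SchNet: {channels: 32}
        data:
            train: train.xyz
            valid: valid.xyz
        loss: +PerAtomEnergyLoss()
        fitting:
            trainer_kwargs:
                max_epochs: 10
        wandb: null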
137 | """ 138 | 139 | model: Union[GraphPESModel, Dict[str, GraphPESModel]] 140 | data: Union[DatasetCollection, Dict[str, Any]] 141 | loss: Union[Loss, TotalLoss, Dict[str, Loss], List[Loss]] 142 | fitting: FittingConfig 143 | general: GeneralConfig 144 | wandb: Union[Dict[str, Any], None] 145 | 146 | ### Methods ### 147 | 148 | def get_data(self, model: GraphPESModel) -> DatasetCollection: 149 | return parse_dataset_collection(self.data, model) 150 | 151 | @classmethod 152 | def defaults(cls) -> dict: 153 | with open(Path(__file__).parent / "training-defaults.yaml") as f: 154 | return yaml.safe_load(f) 155 | -------------------------------------------------------------------------------- /src/graph_pes/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import ( 2 | ASEToGraphDataset, 3 | ConcatDataset, 4 | DatasetCollection, 5 | GraphDataset, 6 | file_dataset, 7 | load_atoms_dataset, 8 | ) 9 | from .loader import GraphDataLoader 10 | 11 | __all__ = [ 12 | "load_atoms_dataset", 13 | "file_dataset", 14 | "GraphDataset", 15 | "ASEToGraphDataset", 16 | "DatasetCollection", 17 | "GraphDataLoader", 18 | "ConcatDataset", 19 | ] 20 | -------------------------------------------------------------------------------- /src/graph_pes/data/ase_db.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import pathlib 4 | from typing import Sequence, overload 5 | 6 | import ase 7 | import ase.db 8 | import numpy as np 9 | 10 | from graph_pes.utils.misc import slice_to_range 11 | 12 | 13 | class ASEDatabase(Sequence[ase.Atoms]): 14 | """ 15 | A class that wraps an ASE database file, allowing for indexing into the 16 | database to obtain :class:`ase.Atoms` objects. 17 | 18 | We assume that each row contains labels in the ``data`` attribute, 19 | as a mapping from property names to values, and that units are "standard" 20 | ASE units, e.g. ``eV``, ``eV/Å``, etc. 21 | 22 | Fully compatible with `SchNetPack Dataset Files `__. 23 | 24 | See the `ASE documentation `__ 25 | for more details about this file format. 26 | 27 | .. warning:: 28 | 29 | This dataset indexes into a database, performing many random access 30 | reads from disk. This can be very slow! If you are using a distributed 31 | compute cluster, ensure you copy your database file to somewhere with 32 | fast local storage (as opposed to network-attached storage). 33 | 34 | Similarly, consider using several workers when loading the dataset, 35 | e.g. ``fitting/loader_kwargs/num_workers=8``. 36 | 37 | Parameters 38 | ---------- 39 | path: str | pathlib.Path 40 | The path to the database. 41 | """ # noqa: E501 42 | 43 | def __init__(self, path: str | pathlib.Path): 44 | path = pathlib.Path(path) 45 | if not path.exists(): 46 | raise FileNotFoundError(f"Database file {path} does not exist") 47 | self.path = path 48 | self.db = ase.db.connect(path, use_lock_file=False) 49 | 50 | @overload 51 | def __getitem__(self, index: int) -> ase.Atoms: ... 52 | @overload 53 | def __getitem__(self, index: slice) -> Sequence[ase.Atoms]: ... 
54 | def __getitem__( 55 | self, index: int | slice 56 | ) -> ase.Atoms | Sequence[ase.Atoms]: 57 | if isinstance(index, slice): 58 | indices = slice_to_range(index, len(self)) 59 | return [self[i] for i in indices] 60 | 61 | atoms = self.db.get_atoms(index + 1, add_additional_information=True) 62 | data = atoms.info.pop("data", {}) 63 | arrays = { 64 | k: v 65 | for k, v in data.items() 66 | if isinstance(v, np.ndarray) and v.shape[0] == len(atoms) 67 | } 68 | info = {k: v for k, v in data.items() if k not in arrays} 69 | atoms.arrays.update(arrays) 70 | atoms.info.update(info) 71 | return atoms 72 | 73 | def __len__(self) -> int: 74 | return self.db.count() 75 | -------------------------------------------------------------------------------- /src/graph_pes/data/loader.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import warnings 4 | from functools import partial 5 | from typing import Iterator, Sequence 6 | 7 | import torch.utils.data 8 | 9 | from ..atomic_graph import AtomicGraph, to_batch 10 | from .datasets import GraphDataset 11 | 12 | 13 | class GraphDataLoader(torch.utils.data.DataLoader): 14 | r""" 15 | A helper class for merging :class:`~graph_pes.AtomicGraph` objects 16 | into a single batch, represented as another :class:`~graph_pes.AtomicGraph` 17 | containing disjoint subgraphs per structure (see 18 | :func:`~graph_pes.atomic_graph.to_batch`). 19 | 20 | Parameters 21 | ---------- 22 | dataset: GraphDataset | Sequence[AtomicGraph] 23 | The dataset to load. 24 | batch_size 25 | The batch size. 26 | shuffle 27 | Whether to shuffle the dataset. 28 | **kwargs: 29 | Additional keyword arguments to pass to the underlying 30 | :class:`torch.utils.data.DataLoader`. 
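    three_body_cutoff: float | None
        Passed through to :func:`~graph_pes.atomic_graph.to_batch` when
        collating each batch (an assumption based on the collate call
        below: models with a three-body cutoff, such as the MatterSim
        wrapper, rely on this pre-computed triplet information).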
31 | """ 32 | 33 | def __init__( 34 | self, 35 | dataset: GraphDataset | Sequence[AtomicGraph], 36 | batch_size: int = 1, 37 | shuffle: bool = False, 38 | three_body_cutoff: float | None = None, 39 | **kwargs, 40 | ): 41 | if not isinstance(dataset, GraphDataset): 42 | dataset = GraphDataset(dataset) 43 | 44 | if "collate_fn" in kwargs: 45 | warnings.warn( 46 | "graph-pes uses a custom collate_fn (`collate_atomic_graphs`), " 47 | "are you sure you want to override this?", 48 | stacklevel=2, 49 | ) 50 | 51 | collate_fn = kwargs.pop( 52 | "collate_fn", 53 | partial(to_batch, three_body_cutoff=three_body_cutoff), 54 | ) 55 | 56 | super().__init__( 57 | dataset, 58 | batch_size, 59 | shuffle, 60 | collate_fn=collate_fn, 61 | **kwargs, 62 | ) 63 | 64 | def __iter__(self) -> Iterator[AtomicGraph]: # type: ignore 65 | return super().__iter__() 66 | -------------------------------------------------------------------------------- /src/graph_pes/interfaces/__init__.py: -------------------------------------------------------------------------------- 1 | from ._mace import egret, go_mace_23, mace_mp, mace_off 2 | from ._mattersim import mattersim 3 | from ._orb import orb_model 4 | 5 | __all__ = [ 6 | "egret", 7 | "go_mace_23", 8 | "mace_mp", 9 | "mace_off", 10 | "mattersim", 11 | "orb_model", 12 | ] 13 | -------------------------------------------------------------------------------- /src/graph_pes/interfaces/_mattersim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from graph_pes import AtomicGraph, GraphPESModel 4 | from graph_pes.atomic_graph import ( 5 | PropertyKey, 6 | neighbour_distances, 7 | neighbour_vectors, 8 | number_of_edges, 9 | sum_per_structure, 10 | ) 11 | from graph_pes.utils.threebody import ( 12 | angle_spanned_by, 13 | triplet_edge_pairs, 14 | ) 15 | 16 | 17 | class MatterSim_M3Gnet_Wrapper(GraphPESModel): 18 | def __init__(self, model: torch.nn.Module): 19 | super().__init__( 20 | cutoff=model.model_args["cutoff"], # type: ignore 21 | implemented_properties=["local_energies"], 22 | three_body_cutoff=model.model_args["threebody_cutoff"], # type: ignore 23 | ) 24 | self.model = model 25 | 26 | def forward(self, graph: AtomicGraph) -> dict[PropertyKey, torch.Tensor]: 27 | # pre-compute 28 | edge_lengths = neighbour_distances(graph) # (E) 29 | edge_pairs = triplet_edge_pairs(graph, self.three_body_cutoff.item()) 30 | triplets_per_leading_edge = count_number_of_triplets_per_leading_edge( 31 | edge_pairs, graph 32 | ) 33 | r_ik = edge_lengths[edge_pairs[:, 1]] 34 | v = neighbour_vectors(graph) 35 | v_ij = v[edge_pairs[:, 0]] 36 | v_ik = v[edge_pairs[:, 1]] 37 | angle = angle_spanned_by(v_ij, v_ik) 38 | 39 | num_atoms = sum_per_structure( 40 | torch.ones_like(graph.Z), graph 41 | ).unsqueeze(-1) 42 | 43 | # num_bonds is of shape (n_structures,) such that 44 | # num_bonds[i] = sum(graph.neighbour_list[0] == i) 45 | bonds_per_atom = torch.zeros_like(graph.Z) 46 | bonds_per_atom = bonds_per_atom.scatter_add( 47 | dim=0, 48 | index=graph.neighbour_list[0], 49 | src=torch.ones_like(graph.neighbour_list[0]), 50 | ) 51 | num_bonds = sum_per_structure(bonds_per_atom, graph).unsqueeze(-1) 52 | 53 | three_body_indices = edge_pairs 54 | num_triple_ij = triplets_per_leading_edge.unsqueeze(-1) 55 | 56 | # use the forward pass of M3Gnet 57 | atom_attr = self.model.atom_embedding(self.model.one_hot_atoms(graph.Z)) 58 | edge_attr = self.model.rbf(edge_lengths) 59 | edge_attr_zero = edge_attr 60 | edge_attr = 
self.model.edge_encoder(edge_attr) 61 | three_basis = self.model.sbf(r_ik, angle) 62 | 63 | for conv in self.model.graph_conv: 64 | atom_attr, edge_attr = conv( 65 | atom_attr, 66 | edge_attr, 67 | edge_attr_zero, 68 | graph.neighbour_list, 69 | three_basis, 70 | three_body_indices, 71 | edge_lengths.unsqueeze(-1), 72 | num_bonds, 73 | num_triple_ij, 74 | num_atoms, 75 | ) 76 | 77 | local_energies = self.model.final(atom_attr).view(-1) 78 | local_energies = self.model.normalizer(local_energies, graph.Z) 79 | 80 | return {"local_energies": local_energies} 81 | 82 | 83 | def mattersim(load_path: str = "mattersim-v1.0.0-1m") -> GraphPESModel: 84 | """ 85 | Load a ``mattersim`` model from a checkpoint file, and convert it to a 86 | :class:`~graph_pes.GraphPESModel` on the CPU. 87 | 88 | Parameters 89 | ---------- 90 | load_path: str 91 | The path to the ``mattersim`` checkpoint file. Expected to be one of 92 | ``mattersim-v1.0.0-1m`` or ``mattersim-v1.0.0-5m`` currently. 93 | """ 94 | from mattersim.forcefield.potential import Potential 95 | 96 | model = Potential.from_checkpoint( # type: ignore 97 | load_path, 98 | load_training_state=False, # only load the model 99 | device="cpu", # manage the device ourself later 100 | ).model 101 | return MatterSim_M3Gnet_Wrapper(model) 102 | 103 | 104 | @torch.no_grad() 105 | def count_number_of_triplets_per_leading_edge( 106 | edge_pairs: torch.Tensor, 107 | graph: AtomicGraph, 108 | ): 109 | """ 110 | Return ``T`` of shape ``(E,)`` where ``T[e]`` is the number of edge pairs 111 | that have edge number ``e`` as the first edge in the pair. 112 | 113 | Parameters 114 | ---------- 115 | edge_pairs: torch.Tensor 116 | A ``(E, 2)`` shaped tensor indicating pairs of edges that form a 117 | triplet ``(i, j, k)`` (see :func:`triplet_edge_pairs`). 118 | graph: AtomicGraph 119 | The graph from which the edge pairs were derived. 120 | 121 | Returns 122 | ------- 123 | triplets_per_edge: torch.Tensor 124 | A ``(E,)`` shaped tensor where ``triplets_per_edge[e]`` is the 125 | number of edge pairs that have edge ``e`` as the first edge in the 126 | pair. 
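    For example, with ``edge_pairs = [[0, 1], [0, 2], [3, 1]]`` and a
    graph containing 4 edges, the result is ``[2, 0, 0, 1]``: edge ``0``
    leads two pairs, edge ``3`` leads one, and edges ``1`` and ``2``
    lead none.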
127 | 128 | """ 129 | triplets_per_edge = torch.zeros( 130 | number_of_edges(graph), device=graph.R.device, dtype=torch.long 131 | ) 132 | return triplets_per_edge.scatter_add( 133 | dim=0, 134 | index=edge_pairs[:, 0], 135 | src=torch.ones_like(edge_pairs[:, 0]), 136 | ) 137 | -------------------------------------------------------------------------------- /src/graph_pes/interfaces/orb_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | from ase.build import bulk, molecule 5 | from orb_models.forcefield import atomic_system 6 | from orb_models.forcefield.base import batch_graphs 7 | 8 | from graph_pes.atomic_graph import AtomicGraph, to_batch 9 | from graph_pes.interfaces._orb import orb_model 10 | from graph_pes.utils.misc import full_3x3_to_voigt_6 11 | 12 | 13 | @pytest.fixture(params=["orb-v3-direct-20-omat", "orb-v3-conservative-20-omat"]) 14 | def wrapped_orb(request): 15 | return orb_model(request.param) 16 | 17 | 18 | def test_single_isolated_structure(wrapped_orb): 19 | atoms = molecule("H2O") 20 | 21 | g = AtomicGraph.from_ase(atoms, cutoff=wrapped_orb.cutoff.item()) 22 | our_preds = wrapped_orb.forward(g) 23 | 24 | assert our_preds["energy"].shape == tuple() 25 | assert our_preds["forces"].shape == (3, 3) 26 | 27 | orb_g = atomic_system.ase_atoms_to_atom_graphs( 28 | atoms, wrapped_orb.orb_model.system_config 29 | ) 30 | orb_preds = wrapped_orb.orb_model.predict(orb_g) 31 | 32 | torch.testing.assert_close(our_preds["energy"], orb_preds["energy"][0]) 33 | torch.testing.assert_close( 34 | our_preds["forces"], 35 | orb_preds["forces"] 36 | if "forces" in orb_preds 37 | else orb_preds["grad_forces"], 38 | ) 39 | 40 | 41 | def test_single_periodic_structure(wrapped_orb): 42 | atoms = bulk("Cu") 43 | g = AtomicGraph.from_ase(atoms, cutoff=wrapped_orb.cutoff.item()) 44 | our_preds = wrapped_orb.forward(g) 45 | 46 | assert our_preds["energy"].shape == tuple() 47 | assert our_preds["forces"].shape == (1, 3) 48 | assert our_preds["stress"].shape == (3, 3) 49 | 50 | orb_g = atomic_system.ase_atoms_to_atom_graphs( 51 | atoms, wrapped_orb.orb_model.system_config 52 | ) 53 | orb_preds = wrapped_orb.orb_model.predict(orb_g) 54 | 55 | torch.testing.assert_close( 56 | our_preds["energy"], 57 | orb_preds["energy"][0], 58 | atol=1e-4, 59 | rtol=1e-4, 60 | ) 61 | torch.testing.assert_close( 62 | our_preds["forces"], 63 | orb_preds["forces"] 64 | if "forces" in orb_preds 65 | else orb_preds["grad_forces"], 66 | atol=1e-4, 67 | rtol=1e-4, 68 | ) 69 | torch.testing.assert_close( 70 | full_3x3_to_voigt_6(our_preds["stress"]), 71 | orb_preds["stress"][0] 72 | if "stress" in orb_preds 73 | else orb_preds["grad_stress"][0], 74 | atol=1e-3, 75 | rtol=1e-3, 76 | ) 77 | 78 | 79 | def test_batched(wrapped_orb): 80 | atoms = bulk("Cu").repeat(2) 81 | rng = np.random.RandomState(42) 82 | atoms.positions += rng.uniform(-0.1, 0.1, atoms.positions.shape) 83 | 84 | N = len(atoms) 85 | B = 2 86 | 87 | g = AtomicGraph.from_ase(atoms, cutoff=wrapped_orb.cutoff.item()) 88 | batch = to_batch([g] * B) 89 | our_preds = wrapped_orb.forward(batch) 90 | 91 | assert our_preds["energy"].shape == (B,) 92 | assert our_preds["forces"].shape == (B * N, 3) 93 | assert our_preds["stress"].shape == (B, 3, 3) 94 | 95 | orb_g = atomic_system.ase_atoms_to_atom_graphs( 96 | atoms, wrapped_orb.orb_model.system_config 97 | ) 98 | orb_g = batch_graphs([orb_g] * B) 99 | orb_preds = wrapped_orb.orb_model.predict(orb_g) 100 | 101 | 
torch.testing.assert_close(our_preds["energy"], orb_preds["energy"]) 102 | torch.testing.assert_close( 103 | our_preds["forces"], 104 | orb_preds["forces"] 105 | if "forces" in orb_preds 106 | else orb_preds["grad_forces"], 107 | ) 108 | torch.testing.assert_close( 109 | full_3x3_to_voigt_6(our_preds["stress"]), 110 | orb_preds["stress"] 111 | if "stress" in orb_preds 112 | else orb_preds["grad_stress"], 113 | ) 114 | -------------------------------------------------------------------------------- /src/graph_pes/interfaces/quick.yaml: -------------------------------------------------------------------------------- 1 | # use this config as graph-pes-train quick.yaml model/...=... 2 | model: 3 | 4 | data: 5 | +load_atoms_dataset: 6 | id: QM7 7 | n_train: 10 8 | n_valid: 2 9 | cutoff: 6.0 10 | 11 | fitting: 12 | trainer_kwargs: 13 | max_epochs: 2 14 | accelerator: cpu 15 | 16 | auto_fit_reference_energies: true 17 | 18 | loader_kwargs: 19 | batch_size: 1 20 | 21 | loss: +PerAtomEnergyLoss() 22 | 23 | wandb: null 24 | -------------------------------------------------------------------------------- /src/graph_pes/models/components/scaling.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import warnings 4 | 5 | import torch 6 | from torch import nn 7 | 8 | from graph_pes.atomic_graph import AtomicGraph 9 | from graph_pes.utils.nn import PerElementParameter 10 | from graph_pes.utils.shift_and_scale import guess_per_element_mean_and_var 11 | 12 | 13 | class LocalEnergiesScaler(nn.Module): 14 | """ 15 | Scale the local energies by a per-element scaling factor. 16 | 17 | See :func:`~graph_pes.utils.shift_and_scale.guess_per_element_mean_and_var` 18 | for how the scaling factors are estimated from the training data. 19 | """ 20 | 21 | def __init__(self): 22 | super().__init__() 23 | self.per_element_scaling = PerElementParameter.of_length( 24 | 1, 25 | default_value=1.0, 26 | requires_grad=True, 27 | ) 28 | """ 29 | The per-element scaling factors. 30 | (:class:`~graph_pes.utils.nn.PerElementParameter`) 31 | """ 32 | 33 | def forward( 34 | self, 35 | local_energies: torch.Tensor, 36 | graph: AtomicGraph, 37 | ) -> torch.Tensor: 38 | """ 39 | Scale the local energies by the per-element scaling factor. 40 | """ 41 | scales = self.per_element_scaling[graph.Z].squeeze() 42 | return local_energies.squeeze() * scales 43 | 44 | # add typing for mypy etc 45 | def __call__( 46 | self, local_energies: torch.Tensor, graph: AtomicGraph 47 | ) -> torch.Tensor: 48 | return super().__call__(local_energies, graph) 49 | 50 | @torch.no_grad() 51 | def pre_fit(self, graphs: AtomicGraph): 52 | """ 53 | Pre-fit the per-element scaling factors. 54 | 55 | Parameters 56 | ---------- 57 | graphs 58 | The training data. 
59 | """ 60 | if "energy" not in graphs.properties: 61 | warnings.warn( 62 | "No energy data found in training data: can't estimate " 63 | "per-element scaling factors for local energies.", 64 | stacklevel=2, 65 | ) 66 | return 67 | 68 | means, variances = guess_per_element_mean_and_var( 69 | graphs.properties["energy"], graphs 70 | ) 71 | for Z, var in variances.items(): 72 | self.per_element_scaling[Z] = torch.sqrt(torch.tensor(var)) 73 | 74 | def non_decayable_parameters(self) -> list[torch.nn.Parameter]: 75 | """The ``per_element_scaling`` parameter should not be decayed.""" 76 | return [self.per_element_scaling] 77 | 78 | def __repr__(self): 79 | return self.per_element_scaling._repr(alias=self.__class__.__name__) 80 | -------------------------------------------------------------------------------- /src/graph_pes/models/e3nn/_high_order_CG_coeff.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jla-gardner/graph-pes/c574efb9ee873ddefe9cbd0c62b58777b58d96f2/src/graph_pes/models/e3nn/_high_order_CG_coeff.pt -------------------------------------------------------------------------------- /src/graph_pes/models/scripted.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch 4 | 5 | from graph_pes.atomic_graph import AtomicGraph, PropertyKey 6 | from graph_pes.graph_pes_model import GraphPESModel 7 | from graph_pes.utils.misc import uniform_repr 8 | 9 | 10 | class ScriptedModel(GraphPESModel): 11 | def __init__(self, scripted_model: torch.jit.ScriptModule): 12 | super().__init__( 13 | cutoff=scripted_model.cutoff.item(), 14 | implemented_properties=scripted_model.implemented_properties, 15 | three_body_cutoff=scripted_model.three_body_cutoff.item(), 16 | ) 17 | self.scripted_model = scripted_model 18 | 19 | def forward(self, graph: AtomicGraph) -> dict[PropertyKey, torch.Tensor]: 20 | return self.scripted_model(graph) 21 | 22 | def __repr__(self): 23 | return uniform_repr( 24 | self.__class__.__name__, 25 | scripted_model=self.scripted_model, 26 | stringify=True, 27 | max_width=80, 28 | indent_width=2, 29 | ) 30 | -------------------------------------------------------------------------------- /src/graph_pes/models/unit_converter.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch 4 | 5 | from graph_pes.atomic_graph import AtomicGraph, PropertyKey 6 | from graph_pes.graph_pes_model import GraphPESModel 7 | 8 | 9 | class UnitConverter(GraphPESModel): 10 | r""" 11 | A wrapper that converts the units of the energy, forces and stress 12 | predictions of an underlying model. 13 | 14 | Parameters 15 | ---------- 16 | model 17 | The underlying model. 18 | energy_to_eV 19 | The conversion factor for energy, such that the 20 | ``model.predict_energy(graph) * energy_to_eV`` gives the 21 | energy prediction in eV. 22 | length_to_A 23 | The conversion factor for length, such that the 24 | ``model.predict_forces(graph) * (energy_to_eV / length_to_A)`` 25 | gives the force prediction in eV/Å. 
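    Examples
    --------
    A sketch only (``raw_model`` is a placeholder for any
    :class:`~graph_pes.GraphPESModel`): a model that internally predicts
    energies in kcal/mol and lengths in Å can be converted to eV and Å
    via:

    .. code-block:: python

        model = UnitConverter(
            raw_model,
            energy_to_eV=0.0433641,  # 1 kcal/mol in eV
            length_to_A=1.0,  # lengths already in Å
        )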
26 | """ 27 | 28 | def __init__( 29 | self, model: GraphPESModel, energy_to_eV: float, length_to_A: float 30 | ): 31 | super().__init__( 32 | cutoff=model.cutoff.item(), 33 | implemented_properties=model.implemented_properties, 34 | ) 35 | self._model = model 36 | self._energy_to_eV = energy_to_eV 37 | self._length_to_A = length_to_A 38 | 39 | def forward(self, graph: AtomicGraph) -> dict[PropertyKey, torch.Tensor]: 40 | predictions = self._model(graph) 41 | for key in predictions: 42 | if key in ["energy", "virial"]: 43 | predictions[key] *= self._energy_to_eV 44 | elif key == "forces": 45 | predictions[key] *= self._energy_to_eV / self._length_to_A 46 | elif key == "stress": 47 | predictions[key] *= self._energy_to_eV / self._length_to_A**3 48 | 49 | return predictions 50 | -------------------------------------------------------------------------------- /src/graph_pes/pair_style/pair_graph_pes.h: -------------------------------------------------------------------------------- 1 | /* -*- c++ -*- ---------------------------------------------------------- 2 | LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator 3 | http://lammps.sandia.gov, Sandia National Laboratories 4 | Steve Plimpton, sjplimp@sandia.gov 5 | 6 | Copyright (2003) Sandia Corporation. Under the terms of Contract 7 | DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains 8 | certain rights in this software. This software is distributed under 9 | the GNU General Public License. 10 | 11 | See the README file in the top-level LAMMPS directory. 12 | ------------------------------------------------------------------------- */ 13 | 14 | #ifdef PAIR_CLASS 15 | 16 | PairStyle(graph_pes, PairGraphPES) 17 | 18 | #else 19 | 20 | #ifndef LMP_PAIR_GraphPES_H 21 | #define LMP_PAIR_GraphPES_H 22 | 23 | #include "pair.h" 24 | 25 | #include 26 | 27 | namespace LAMMPS_NS 28 | { 29 | 30 | class PairGraphPES : public Pair 31 | { 32 | public: 33 | PairGraphPES(class LAMMPS *); 34 | virtual ~PairGraphPES(); 35 | virtual void compute(int, int); 36 | void settings(int, char **); 37 | virtual void coeff(int, char **); 38 | virtual double init_one(int, int); 39 | virtual void init_style(); 40 | void allocate(); 41 | 42 | double cutoff; 43 | torch::jit::Module model; 44 | torch::Device device = torch::kCPU; 45 | torch::Tensor extract_cell_tensor(); 46 | 47 | protected: 48 | int *lammps_type_to_Z; 49 | int debug_mode = 0; 50 | }; 51 | 52 | } 53 | 54 | #endif 55 | #endif 56 | -------------------------------------------------------------------------------- /src/graph_pes/scripts/id.py: -------------------------------------------------------------------------------- 1 | from graph_pes.utils.misc import random_id 2 | 3 | 4 | def main(): 5 | print(random_id()) 6 | 7 | 8 | if __name__ == "__main__": 9 | main() 10 | -------------------------------------------------------------------------------- /src/graph_pes/scripts/resume.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from datetime import datetime 4 | from pathlib import Path 5 | 6 | import yaml 7 | 8 | from graph_pes.config.shared import ( 9 | instantiate_config_from_dict, 10 | parse_loss, 11 | parse_model, 12 | ) 13 | from graph_pes.config.training import TrainingConfig 14 | from graph_pes.data.loader import GraphDataLoader 15 | from graph_pes.scripts.train import trainer_from_config 16 | from graph_pes.training.tasks import TrainingTask 17 | from graph_pes.utils.logger import logger 18 | 19 | 20 | 
def parse_args(): 21 | parser = argparse.ArgumentParser( 22 | description="Resume a `graph-pes-train` training run.", 23 | epilog="Copyright 2023-25, John Gardner", 24 | ) 25 | parser.add_argument( 26 | "train_directory", 27 | type=str, 28 | help=( 29 | "Path to the training directory. For instance, " 30 | "`graph-pes-results/abdcefg_hijklmn`" 31 | ), 32 | ) 33 | 34 | return parser.parse_args() 35 | 36 | 37 | def main(): 38 | # set the load-atoms verbosity to 1 by default to avoid 39 | # spamming logs with `rich` output 40 | os.environ["LOAD_ATOMS_VERBOSE"] = os.getenv("LOAD_ATOMS_VERBOSE", "1") 41 | 42 | args = parse_args() 43 | 44 | train_dir = Path(args.train_directory) 45 | if not train_dir.exists(): 46 | raise ValueError(f"Training directory not found: {train_dir}") 47 | 48 | # find the latest checkpoint 49 | checkpoint_path = train_dir / "checkpoints/last.ckpt" 50 | assert checkpoint_path.exists(), f"Checkpoint not found: {checkpoint_path}" 51 | 52 | # and the training config 53 | config_path = train_dir / "train-config.yaml" 54 | assert config_path.exists(), f"Training config not found: {config_path}" 55 | 56 | with open(config_path) as f: 57 | config_data = yaml.safe_load(f) 58 | 59 | # load the checkpoint 60 | config_data, config = instantiate_config_from_dict( 61 | config_data, TrainingConfig 62 | ) 63 | task = TrainingTask.load_from_checkpoint( 64 | checkpoint_path, 65 | model=parse_model(config.model), 66 | loss=parse_loss(config.loss), 67 | optimizer=config.fitting.get_optimizer(), 68 | scheduler=config.fitting.get_scheduler(), 69 | ) 70 | 71 | # create the trainer 72 | trainer = trainer_from_config(config, train_dir) 73 | if trainer.global_rank == 0: 74 | now_ms = datetime.now().strftime("%F %T.%f")[:-3] 75 | logger.info(f"Resuming training at {now_ms}") 76 | 77 | # resume training 78 | data = config.get_data(task.model) 79 | loader_kwargs = {**config.fitting.loader_kwargs} 80 | loader_kwargs["shuffle"] = True 81 | train_loader = GraphDataLoader(data.train, **loader_kwargs) 82 | loader_kwargs["shuffle"] = False 83 | valid_loader = GraphDataLoader(data.valid, **loader_kwargs) 84 | trainer.fit(task, train_loader, valid_loader, ckpt_path=checkpoint_path) 85 | 86 | 87 | if __name__ == "__main__": 88 | main() 89 | -------------------------------------------------------------------------------- /src/graph_pes/scripts/test.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | from pathlib import Path 5 | 6 | import pytorch_lightning as pl 7 | 8 | from graph_pes.config.shared import instantiate_config_from_dict 9 | from graph_pes.config.testing import TestingConfig 10 | from graph_pes.models import load_model 11 | from graph_pes.scripts.utils import ( 12 | configure_general_options, 13 | extract_config_dict_from_command_line, 14 | update_summary, 15 | ) 16 | from graph_pes.training.tasks import test_with_lightning 17 | from graph_pes.utils import distributed 18 | from graph_pes.utils.logger import logger 19 | 20 | 21 | def test(config: TestingConfig) -> None: 22 | logger.info(f"Testing model at {config.model_path}...") 23 | 24 | configure_general_options(config.torch, seed=0) 25 | 26 | model = load_model(config.model_path) 27 | logger.info("Loaded model.") 28 | logger.debug(f"Model: {model}") 29 | 30 | datasets = config.get_datasets(model) 31 | 32 | for dataset in datasets.values(): 33 | if distributed.IS_RANK_0: 34 | dataset.prepare_data() 35 | dataset.setup() 36 | 37 | trainer = pl.Trainer( 
38 | logger=config.get_logger(), 39 | accelerator=config.accelerator, 40 | inference_mode=False, 41 | ) 42 | 43 | test_with_lightning( 44 | trainer, 45 | model, 46 | datasets, 47 | config.loader_kwargs, 48 | config.prefix, 49 | user_eval_metrics=[], 50 | ) 51 | 52 | summary_file = Path(config.model_path).parent / "summary.yaml" 53 | update_summary(trainer.logger, summary_file) 54 | 55 | 56 | def main(): 57 | # set the load-atoms verbosity to 1 by default to avoid 58 | # spamming logs with `rich` output 59 | os.environ["LOAD_ATOMS_VERBOSE"] = os.getenv("LOAD_ATOMS_VERBOSE", "1") 60 | 61 | config_dict = extract_config_dict_from_command_line( 62 | "Test a GraphPES model using PyTorch Lightning." 63 | ) 64 | _, config = instantiate_config_from_dict(config_dict, TestingConfig) 65 | test(config) 66 | 67 | 68 | if __name__ == "__main__": 69 | main() 70 | -------------------------------------------------------------------------------- /src/graph_pes/training/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pytorch_lightning.loggers import Logger as PTLLogger 4 | 5 | from graph_pes.atomic_graph import ( 6 | AtomicGraph, 7 | number_of_atoms, 8 | number_of_structures, 9 | ) 10 | from graph_pes.graph_pes_model import GraphPESModel 11 | from graph_pes.models.addition import AdditionModel 12 | from graph_pes.utils.logger import logger 13 | from graph_pes.utils.nn import learnable_parameters 14 | 15 | 16 | def log_model_info( 17 | model: GraphPESModel, 18 | ptl_logger: PTLLogger | None = None, 19 | ) -> None: 20 | """Log the number of parameters in a model.""" 21 | 22 | logger.debug(f"Model:\n{model}") 23 | 24 | if isinstance(model, AdditionModel): 25 | model_names = [ 26 | f"{given_name} ({component.__class__.__name__})" 27 | for given_name, component in model.models.items() 28 | ] 29 | params = [ 30 | learnable_parameters(component) 31 | for component in model.models.values() 32 | ] 33 | width = max(len(name) for name in model_names) 34 | info_str = "Number of learnable params:" 35 | for name, param in zip(model_names, params): 36 | info_str += f"\n {name:<{width}}: {param:,}" 37 | logger.info(info_str) 38 | 39 | else: 40 | logger.info( 41 | f"Number of learnable params : {learnable_parameters(model):,}" 42 | ) 43 | 44 | if ptl_logger is not None: 45 | all_params = sum(p.numel() for p in model.parameters()) 46 | learnable_params = learnable_parameters(model) 47 | ptl_logger.log_metrics( 48 | { 49 | "n_parameters": all_params, 50 | "n_learnable_parameters": learnable_params, 51 | } 52 | ) 53 | 54 | 55 | def sanity_check(model: GraphPESModel, batch: AtomicGraph) -> None: 56 | outputs = model.get_all_PES_predictions(batch) 57 | 58 | N = number_of_atoms(batch) 59 | S = number_of_structures(batch) 60 | expected_shapes = { 61 | "local_energies": (N,), 62 | "forces": (N, 3), 63 | "energy": (S,), 64 | "stress": (S, 3, 3), 65 | "virial": (S, 3, 3), 66 | } 67 | 68 | incorrect = [] 69 | for key, value in outputs.items(): 70 | if value.shape != expected_shapes[key]: 71 | incorrect.append((key, value.shape, expected_shapes[key])) 72 | 73 | if len(incorrect) > 0: 74 | raise ValueError( 75 | "Sanity check failed for the following outputs:\n" 76 | + "\n".join( 77 | f"{key}: {value} != {expected}" 78 | for key, value, expected in incorrect 79 | ) 80 | ) 81 | 82 | if batch.cutoff < model.cutoff: 83 | logger.error( 84 | "Sanity check failed: you appear to be training on data " 85 | f"composed of graphs with a cutoff 
({batch.cutoff}) that is "
86 |             f"smaller than the cutoff used in the model ({model.cutoff}). "
87 |             "This is almost certainly not what you want to do.",
88 |         )
89 | 
90 | 
91 | VALIDATION_LOSS_KEY = "valid/loss/total"
92 | 
--------------------------------------------------------------------------------
/src/graph_pes/utils/distributed.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | 
3 | import os
4 | from typing import Final
5 | 
6 | from graph_pes.utils.misc import silently_create_trainer
7 | 
8 | # dirty hack: just get lightning to work this out,
9 | # and ensure no annoying printing happens
10 | _trainer = silently_create_trainer(logger=False)
11 | 
12 | GLOBAL_RANK: Final[int] = _trainer.global_rank
13 | WORLD_SIZE: Final[int] = _trainer.world_size
14 | IS_RANK_0: Final[bool] = GLOBAL_RANK == 0
15 | 
16 | 
17 | def send_to_other_ranks(key: str, value: str) -> None:
18 |     """Must be called by rank 0 and before `Trainer.fit` is called."""
19 |     os.environ[key] = value
20 | 
21 | 
22 | def receive_from_rank_0(key: str) -> str:
23 |     """Must be called from a non-0 rank."""
24 |     return os.environ[key]
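25 | 
26 | 
27 | # Illustrative sketch, not part of the original module: the two helpers
28 | # above form a tiny one-way channel. Rank 0 publishes a value before
29 | # `Trainer.fit` launches the worker processes, which can then read it
30 | # from their environment. The key name here is hypothetical.
31 | def _example_share_value(value: str) -> str:
32 |     key = "GRAPH_PES_EXAMPLE_KEY"
33 |     if IS_RANK_0:
34 |         send_to_other_ranks(key, value)
35 |         return value
36 |     return receive_from_rank_0(key)
37 | 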
--------------------------------------------------------------------------------
/src/graph_pes/utils/lammps.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | 
3 | import pathlib
4 | 
5 | import e3nn.util.jit
6 | import torch
7 | 
8 | from graph_pes.atomic_graph import AtomicGraph, PropertyKey
9 | from graph_pes.graph_pes_model import GraphPESModel
10 | from graph_pes.utils.misc import full_3x3_to_voigt_6
11 | 
12 | 
13 | def as_lammps_data(
14 |     graph: AtomicGraph,
15 |     compute_virial: bool = False,
16 |     debug: bool = False,
17 | ) -> dict[str, torch.Tensor]:
18 |     return {
19 |         "atomic_numbers": graph.Z,
20 |         "positions": graph.R,
21 |         "cell": graph.cell,
22 |         "neighbour_list": graph.neighbour_list,
23 |         "neighbour_cell_offsets": graph.neighbour_cell_offsets,
24 |         "compute_virial": torch.tensor(compute_virial),
25 |         "debug": torch.tensor(debug),
26 |     }
27 | 
28 | 
29 | class LAMMPSModel(torch.nn.Module):
30 |     def __init__(self, model: GraphPESModel):
31 |         super().__init__()
32 |         self.model = model
33 | 
34 |     @torch.jit.export  # type: ignore
35 |     def get_cutoff(self) -> torch.Tensor:
36 |         return self.model.cutoff
37 | 
38 |     def forward(
39 |         self, graph_data: dict[str, torch.Tensor]
40 |     ) -> dict[str, torch.Tensor]:
41 |         debug = graph_data.get("debug", torch.tensor(False)).item()
42 | 
43 |         if debug:
44 |             print("Received graph:")
45 |             for key, value in graph_data.items():
46 |                 print(f"{key}: {value}")
47 | 
48 |         compute_virial = graph_data["compute_virial"].item()
49 |         properties: list[PropertyKey] = ["energy", "forces", "local_energies"]
50 |         if compute_virial:
51 |             properties.append("virial")
52 | 
53 |         # graph_data is a dict, so we need to convert it to an AtomicGraph
54 |         graph = AtomicGraph(
55 |             Z=graph_data["atomic_numbers"],
56 |             R=graph_data["positions"],
57 |             cell=graph_data["cell"],
58 |             neighbour_list=graph_data["neighbour_list"],
59 |             neighbour_cell_offsets=graph_data["neighbour_cell_offsets"],
60 |             properties={},
61 |             other={},
62 |             cutoff=self.model.cutoff.item(),
63 |         )
64 |         preds = self.model.predict(graph, properties=properties)
65 | 
66 |         # cast to float64
67 |         for key in preds:
68 |             preds[key] = preds[key].double()
69 | 
70 |         # add virial output if required
71 |         if compute_virial:
72 |             # LAMMPS expects the **virial** in Voigt notation
73 |             # we compute the **virial** in full 3x3 matrix notation
74 |             # therefore, convert:
75 |             preds["virial"] = full_3x3_to_voigt_6(preds["virial"])
76 | 
77 |         return preds  # type: ignore
78 | 
79 |     def __call__(
80 |         self, graph_data: dict[str, torch.Tensor]
81 |     ) -> dict[str, torch.Tensor]:
82 |         return super().__call__(graph_data)
83 | 
84 | 
85 | def deploy_model(model: GraphPESModel, path: str | pathlib.Path):
86 |     """
87 |     Deploy a :class:`~graph_pes.GraphPESModel` for use with LAMMPS.
88 | 
89 |     Use the resulting model with LAMMPS according to:
90 | 
91 |     .. code-block:: bash
92 | 
93 |         pair_style graph_pes
94 |         pair_coeff * * path/to/model.pt ...
95 | 
96 |     Parameters
97 |     ----------
98 |     model
99 |         The model to deploy.
100 |     path
101 |         The path to save the deployed model to.
102 |     """  # noqa: E501
103 |     lammps_model = LAMMPSModel(model)
104 |     scripted_model = e3nn.util.jit.script(lammps_model)
105 |     torch.jit.save(scripted_model, path)
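106 | 
107 | 
108 | # Example usage (illustrative, not part of the original module); the
109 | # checkpoint path here is hypothetical:
110 | #
111 | #     from graph_pes.models import load_model
112 | #     from graph_pes.utils.lammps import deploy_model
113 | #
114 | #     model = load_model("graph-pes-results/<run-id>/model.pt")
115 | #     deploy_model(model, "lammps-model.pt")
116 | #
117 | # The deployed file is then referenced from a LAMMPS input script via the
118 | # `pair_style graph_pes` commands shown in the docstring above.
119 | 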
--------------------------------------------------------------------------------
/src/graph_pes/utils/logger.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | 
3 | import logging
4 | import sys
5 | from pathlib import Path
6 | 
7 | from . import distributed
8 | 
9 | __all__ = ["logger", "log_to_file", "set_log_level"]
10 | 
11 | 
12 | class MultiLineFormatter(logging.Formatter):
13 |     """Detect multi-line log messages and pad them with newlines."""
14 | 
15 |     def __init__(self):
16 |         super().__init__("[%(name)s %(levelname)s]: %(message)s")
17 | 
18 |     def format(self, record: logging.LogRecord) -> str:
19 |         record.msg = str(record.msg).strip()
20 |         # add in new lines
21 |         if "\n" in record.msg:
22 |             record.msg = "\n" + record.msg + "\n"
23 |         return super().format(record)
24 | 
25 | 
26 | # create the graph-pes logger
27 | logger = logging.getLogger(name="graph-pes")
28 | std_out_handler = logging.StreamHandler(stream=sys.stdout)
29 | std_out_handler.setFormatter(MultiLineFormatter())
30 | 
31 | # log to stdout if rank 0
32 | if distributed.IS_RANK_0:
33 |     logger.addHandler(std_out_handler)
34 | 
35 | # capture all logs but only show INFO and above in stdout (by default)
36 | logger.setLevel(logging.DEBUG)
37 | std_out_handler.setLevel(logging.INFO)
38 | 
39 | 
40 | def log_to_file(output_dir: str | Path):
41 |     """Append logs to a `rank-<rank>.log` file in the given output directory."""
42 | 
43 |     file = Path(output_dir) / f"rank-{distributed.GLOBAL_RANK}.log"
44 |     file.parent.mkdir(parents=True, exist_ok=True)
45 | 
46 |     handler = logging.FileHandler(file, mode="a")
47 |     handler.setLevel(logging.DEBUG)
48 |     handler.setFormatter(MultiLineFormatter())
49 | 
50 |     logger.addHandler(handler)
51 |     logger.info(f"Logging to {file}")
52 | 
53 | 
54 | def set_log_level(level: str | int):
55 |     """Set the logging level."""
56 | 
57 |     std_out_handler.setLevel(level)
58 |     logger.debug(f"Set logging level to {level}")
59 | 
--------------------------------------------------------------------------------
/src/graph_pes/utils/sampling.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | 
3 | from typing import Iterator, Sequence, TypeVar, overload
4 | 
5 | import numpy as np
6 | 
7 | T = TypeVar("T")
8 | 
9 | 
10 | class SequenceSampler(Sequence[T]):
11 |     """
12 |     A class that wraps a :class:`Sequence` of ``T`` objects and
13 |     provides methods for sampling from it without the need to manipulate or
14 |     access the underlying data.
15 | 
16 |     This is useful for e.g. sub-sampling a ``collection`` where individual
17 |     indexing operations are expensive, such as a database on disk, or when
18 |     indexing involves some form of pre-processing.
19 | 
20 |     Parameters
21 |     ----------
22 |     collection
23 |         The collection to wrap.
24 |     indices
25 |         The indices of the elements to include in the collection. If ``None``,
26 |         all elements are included.
27 |     """
28 | 
29 |     def __init__(
30 |         self,
31 |         collection: Sequence[T],
32 |         indices: Sequence[int] | None = None,
33 |     ):
34 |         self.collection = collection
35 |         self.indices = indices or range(len(collection))
36 | 
37 |     @overload
38 |     def __getitem__(self, index: int) -> T: ...
39 |     @overload
40 |     def __getitem__(self, index: slice) -> SequenceSampler[T]: ...
41 |     def __getitem__(self, index: int | slice) -> T | SequenceSampler[T]:
42 |         """
43 |         Get the element/s at the given ``index``.
44 | 
45 |         Parameters
46 |         ----------
47 |         index
48 |             The index/indices of the element/elements to get.
49 |         """
50 |         if isinstance(index, int):
51 |             return self.collection[self.indices[index]]
52 | 
53 |         sampled_indices = self.indices[index]
54 |         return SequenceSampler(self.collection, sampled_indices)
55 | 
56 |     def __len__(self) -> int:
57 |         """The number of items in the collection."""
58 |         return len(self.indices)
59 | 
60 |     def __iter__(self) -> Iterator[T]:
61 |         """Iterate over the collection."""
62 |         for i in range(len(self)):
63 |             yield self[i]
64 | 
65 |     def shuffled(self, seed: int = 42) -> SequenceSampler[T]:
66 |         """
67 |         Return a shuffled version of this collection.
68 | 
69 |         Parameters
70 |         ----------
71 |         seed
72 |             The random seed to use for shuffling.
73 |         """
74 |         # 1. make a copy of the indices
75 |         indices = [*self.indices]
76 |         # 2. shuffle them
77 |         np.random.default_rng(seed).shuffle(indices)
78 |         # 3. return a new SequenceSampler with the shuffled indices
79 |         return SequenceSampler(self.collection, indices)
80 | 
81 |     def sample_at_most(self, n: int, seed: int = 42) -> SequenceSampler[T]:
82 |         """
83 |         Return a sampled collection with at most ``n`` elements.
84 |         """
85 |         assert n >= 0, "n must be non-negative"
86 |         n = min(n, len(self))
87 |         return self.shuffled(seed)[:n]
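88 | 
89 | 
90 | # Illustrative sketch, not part of the original module: a
91 | # `SequenceSampler` never copies the wrapped data. Shuffling and
92 | # slicing only ever rearrange the index list, so the pattern below is
93 | # cheap even when `data` is e.g. an on-disk database.
94 | def _example_usage() -> None:
95 |     data = list(range(1_000))  # a cheap stand-in for an expensive collection
96 |     sample = SequenceSampler(data).sample_at_most(100, seed=0)
97 |     assert len(sample) == 100
98 |     assert len(set(sample)) == 100  # sampled without replacement
99 | 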
84 | """ 85 | assert n >= 0, "n must be non-negative" 86 | n = min(n, len(self)) 87 | return self.shuffled(seed)[:n] 88 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jla-gardner/graph-pes/c574efb9ee873ddefe9cbd0c62b58777b58d96f2/tests/__init__.py -------------------------------------------------------------------------------- /tests/config/test_config.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | import ase.io 4 | import pytest 5 | from ase.build import molecule 6 | 7 | from graph_pes.config.shared import ( 8 | parse_dataset_collection, 9 | parse_single_dataset, 10 | ) 11 | from graph_pes.data.datasets import ( 12 | DatasetCollection, 13 | GraphDataset, 14 | file_dataset, 15 | ) 16 | from graph_pes.models.pairwise import LennardJones 17 | 18 | 19 | def test_parse_single_dataset(tmp_path: pathlib.Path): 20 | CUTOFF = 5.0 21 | model = LennardJones(cutoff=CUTOFF) 22 | ase.io.write(tmp_path / "test.xyz", [molecule("H2O"), molecule("CH4")]) 23 | 24 | # test that a path to a single file works 25 | config = str(tmp_path / "test.xyz") 26 | dataset = parse_single_dataset(config, model) 27 | assert isinstance(dataset, GraphDataset) 28 | assert len(dataset) == 2 29 | assert dataset[0].cutoff == CUTOFF 30 | 31 | # test that a dict works 32 | config = {"path": str(tmp_path / "test.xyz"), "n": 1} 33 | dataset = parse_single_dataset(config, model) 34 | assert isinstance(dataset, GraphDataset) 35 | assert len(dataset) == 1 36 | assert dataset[0].cutoff == CUTOFF 37 | 38 | # test that overriding the cutoff works 39 | config = {"path": str(tmp_path / "test.xyz"), "cutoff": 6.0} 40 | dataset = parse_single_dataset(config, model) 41 | assert isinstance(dataset, GraphDataset) 42 | assert len(dataset) == 2 43 | assert dataset[0].cutoff == 6.0 44 | 45 | # test that an instance of a GraphDataset works 46 | raw_dataset = file_dataset(str(tmp_path / "test.xyz"), cutoff=CUTOFF) 47 | dataset = parse_single_dataset(raw_dataset, model) 48 | assert len(dataset) == 2 49 | assert dataset[0].cutoff == CUTOFF 50 | 51 | # test that a non-GraphDataset raises an error 52 | with pytest.raises(ValueError): 53 | parse_single_dataset(1, model) 54 | 55 | # and that a bogus file path raises file not found 56 | with pytest.raises(FileNotFoundError): 57 | parse_single_dataset("bogus/path.xyz", model) 58 | 59 | 60 | def test_parse_dataset_collection(tmp_path: pathlib.Path): 61 | CUTOFF = 5.0 62 | model = LennardJones(cutoff=CUTOFF) 63 | ase.io.write(tmp_path / "test.xyz", [molecule("H2O"), molecule("CH4")]) 64 | 65 | # test that a dict works 66 | config = { 67 | "train": str(tmp_path / "test.xyz"), 68 | "valid": str(tmp_path / "test.xyz"), 69 | } 70 | collection = parse_dataset_collection(config, model) 71 | assert isinstance(collection, DatasetCollection) 72 | assert len(collection.train) == 2 73 | assert len(collection.valid) == 2 74 | assert collection.test is None 75 | 76 | # tests should be picked up 77 | config = { 78 | "train": str(tmp_path / "test.xyz"), 79 | "valid": str(tmp_path / "test.xyz"), 80 | "test": str(tmp_path / "test.xyz"), 81 | } 82 | collection = parse_dataset_collection(config, model) 83 | assert isinstance(collection, DatasetCollection) 84 | assert len(collection.train) == 2 85 | assert len(collection.valid) == 2 86 | assert collection.test is not None 87 | assert 
isinstance(collection.test, GraphDataset) 88 | 89 | config = { 90 | "train": str(tmp_path / "test.xyz"), 91 | "valid": str(tmp_path / "test.xyz"), 92 | "test": { 93 | "bulk": str(tmp_path / "test.xyz"), 94 | "surface": { 95 | "path": str(tmp_path / "test.xyz"), 96 | "n": 1, 97 | }, 98 | }, 99 | } 100 | collection = parse_dataset_collection(config, model) 101 | assert isinstance(collection, DatasetCollection) 102 | assert len(collection.train) == 2 103 | assert len(collection.valid) == 2 104 | assert collection.test is not None 105 | assert isinstance(collection.test, dict) 106 | assert len(collection.test) == 2 107 | assert isinstance(collection.test["bulk"], GraphDataset) 108 | assert isinstance(collection.test["surface"], GraphDataset) 109 | assert len(collection.test["surface"]) == 1 110 | 111 | # test that a DatasetCollection instance is returned unchanged 112 | raw_collection = collection 113 | new_collection = parse_dataset_collection(raw_collection, model) 114 | assert new_collection == collection 115 | 116 | # no train or valid keys should raise an error 117 | with pytest.raises(ValueError): 118 | parse_dataset_collection({"test": str(tmp_path / "test.xyz")}, model) 119 | 120 | # file not found for test set should raise an error 121 | with pytest.raises(FileNotFoundError): 122 | parse_dataset_collection( 123 | { 124 | "train": str(tmp_path / "test.xyz"), 125 | "valid": str(tmp_path / "test.xyz"), 126 | "test": "bogus/path.xyz", 127 | }, 128 | model, 129 | ) 130 | 131 | # not a valid dict 132 | with pytest.raises(ValueError): 133 | parse_dataset_collection( 134 | { 135 | "train": str(tmp_path / "test.xyz"), 136 | "valid": str(tmp_path / "test.xyz"), 137 | "test": {"bulk": 1}, 138 | }, 139 | model, 140 | ) 141 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jla-gardner/graph-pes/c574efb9ee873ddefe9cbd0c62b58777b58d96f2/tests/conftest.py -------------------------------------------------------------------------------- /tests/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jla-gardner/graph-pes/c574efb9ee873ddefe9cbd0c62b58777b58d96f2/tests/data/__init__.py -------------------------------------------------------------------------------- /tests/data/schnetpack_data.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jla-gardner/graph-pes/c574efb9ee873ddefe9cbd0c62b58777b58d96f2/tests/data/schnetpack_data.db -------------------------------------------------------------------------------- /tests/data/test_ase_datasets.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | import numpy as np 4 | import pytest 5 | import torch 6 | 7 | from graph_pes.atomic_graph import AtomicGraph, number_of_atoms 8 | from graph_pes.data import load_atoms_dataset 9 | from graph_pes.data.datasets import ( 10 | ASEToGraphsConverter, 11 | ConcatDataset, 12 | file_dataset, 13 | ) 14 | 15 | from .. 
import helpers 16 | 17 | 18 | @pytest.mark.parametrize("split", ["random", "sequential"]) 19 | def test_shuffling(split: Literal["random", "sequential"]): 20 | dataset = load_atoms_dataset( 21 | id=helpers.CU_STRUCTURES_FILE, 22 | cutoff=3.7, 23 | n_train=8, 24 | n_valid=2, 25 | split=split, 26 | ) 27 | 28 | if split == "sequential": 29 | np.testing.assert_allclose( 30 | dataset.train[0].R, 31 | helpers.CU_TEST_STRUCTURES[0].positions, 32 | ) 33 | else: 34 | # different structures with different sizes in the first 35 | # position of the training set after shuffling 36 | assert number_of_atoms(dataset.train[0]) != len( 37 | helpers.CU_TEST_STRUCTURES[0] 38 | ) 39 | 40 | 41 | def test_dataset(): 42 | dataset = load_atoms_dataset( 43 | id=helpers.CU_STRUCTURES_FILE, 44 | cutoff=3.7, 45 | n_train=8, 46 | n_valid=2, 47 | ) 48 | 49 | assert len(dataset.train) == 8 50 | assert len(dataset.valid) == 2 51 | 52 | 53 | def test_property_map(): 54 | dataset = load_atoms_dataset( 55 | id=helpers.CU_STRUCTURES_FILE, 56 | cutoff=3.7, 57 | n_train=8, 58 | n_valid=2, 59 | property_map={"positions": "forces"}, 60 | split="sequential", 61 | ) 62 | 63 | assert "forces" in dataset.train[0].properties 64 | np.testing.assert_allclose( 65 | dataset.train[0].properties["forces"], 66 | helpers.CU_TEST_STRUCTURES[0].positions, 67 | ) 68 | 69 | with pytest.raises( 70 | ValueError, match="Unable to find properties: {'UNKNOWN KEY'}" 71 | ): 72 | load_atoms_dataset( 73 | id=helpers.CU_STRUCTURES_FILE, 74 | cutoff=3.7, 75 | n_train=8, 76 | n_valid=2, 77 | property_map={"UNKNOWN KEY": "energy"}, 78 | ) 79 | 80 | 81 | def test_file_dataset(): 82 | dataset = file_dataset( 83 | helpers.CU_STRUCTURES_FILE, 84 | cutoff=2.5, 85 | n=5, 86 | shuffle=False, 87 | seed=42, 88 | ) 89 | 90 | assert len(dataset) == 5 91 | 92 | shuffled = file_dataset( 93 | helpers.CU_STRUCTURES_FILE, 94 | cutoff=2.5, 95 | n=5, 96 | shuffle=True, 97 | seed=42, 98 | ) 99 | 100 | assert len(shuffled) == 5 101 | assert shuffled[0].R.shape != dataset[0].R.shape 102 | 103 | 104 | def test_concat_dataset(): 105 | a = file_dataset( 106 | helpers.CU_STRUCTURES_FILE, 107 | cutoff=2.5, 108 | n=5, 109 | ) 110 | b = file_dataset( 111 | helpers.CU_STRUCTURES_FILE, 112 | cutoff=4.5, 113 | n=5, 114 | ) 115 | 116 | c = ConcatDataset(a=a, b=b) 117 | 118 | # check correct length 119 | assert len(c) == 10 120 | assert torch.allclose(c[0].R, a[0].R) 121 | assert c[0].cutoff == a[0].cutoff 122 | 123 | assert torch.allclose(c[5].R, b[0].R) 124 | assert c[5].cutoff == b[0].cutoff 125 | 126 | # check that the properties are correct 127 | assert set(c.properties) == set(a.properties + b.properties) 128 | 129 | # check that we can't index past the end 130 | with pytest.raises(IndexError): 131 | c[100] 132 | 133 | # check that calling prepare_data() and setup() works 134 | assert isinstance(c.datasets["a"].graphs, ASEToGraphsConverter) 135 | c.prepare_data() 136 | c.setup() 137 | assert isinstance(c.datasets["a"].graphs, list) 138 | 139 | # check that graphs is available 140 | g = c.graphs[0] 141 | assert isinstance(g, AtomicGraph) 142 | assert torch.allclose(g.R, a[0].R) 143 | -------------------------------------------------------------------------------- /tests/data/test_db.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Sequence 3 | 4 | import ase 5 | import numpy as np 6 | import pytest 7 | 8 | from graph_pes.atomic_graph import AtomicGraph 9 | from graph_pes.data.ase_db import 
ASEDatabase 10 | from graph_pes.data.datasets import file_dataset 11 | 12 | DB_FILE = Path(__file__).parent / "schnetpack_data.db" 13 | # the dataset available at ./schnetpack_data.db 14 | # was created with the following code: 15 | # 16 | # import numpy as np 17 | # from schnetpack.data import ASEAtomsData 18 | # from ase.build import molecule 19 | # 20 | # 21 | # structures = [molecule(s) for s in "H2O CO2 CH4 C2H4 C2H2".split()] 22 | # properties = [ 23 | # { 24 | # "energy": np.random.rand(), 25 | # "forces": np.random.rand(len(s), 3), 26 | # "other_key": np.random.rand(), 27 | # } 28 | # for s in structures 29 | # ] 30 | # 31 | # new_dataset = ASEAtomsData.create( 32 | # "schnetpack_data.db", 33 | # distance_unit="Ang", 34 | # property_unit_dict={"energy": "eV", "forces": "eV/Ang"}, 35 | # ) 36 | # new_dataset.add_systems(properties, structures) 37 | 38 | 39 | def test_ASEDatabase(): 40 | db = ASEDatabase(DB_FILE) 41 | assert len(db) == 5 42 | assert isinstance(db[0].info["energy"], float) 43 | assert isinstance(db[0].arrays["forces"], np.ndarray) and db[0].arrays[ 44 | "forces" 45 | ].shape == (3, 3) 46 | 47 | assert isinstance(db[0:2], Sequence) 48 | assert isinstance(db[0:2][0], ase.Atoms) 49 | 50 | with pytest.raises(FileNotFoundError): 51 | ASEDatabase("non_existent_file.db") 52 | 53 | 54 | def test_file_dataset_with_db(): 55 | from_file = file_dataset(DB_FILE, cutoff=2.5, n=3, shuffle=False, seed=42) 56 | assert len(from_file) == 3 57 | assert isinstance(from_file[0], AtomicGraph) 58 | -------------------------------------------------------------------------------- /tests/graphs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jla-gardner/graph-pes/c574efb9ee873ddefe9cbd0c62b58777b58d96f2/tests/graphs/__init__.py -------------------------------------------------------------------------------- /tests/graphs/test_conversions.py: -------------------------------------------------------------------------------- 1 | import ase 2 | import numpy as np 3 | import pytest 4 | import torch 5 | 6 | from graph_pes import AtomicGraph 7 | 8 | 9 | def test_there_and_back_again(): 10 | for pbc in [True, False]: 11 | atoms = ase.Atoms( 12 | "H2", 13 | positions=[(0, 0, 0), (0, 0, 1)], 14 | pbc=pbc, 15 | cell=(4, 4, 4), 16 | ) 17 | atoms.info["energy"] = 1.0 18 | atoms.arrays["forces"] = np.random.rand(2, 3) 19 | 20 | graph = AtomicGraph.from_ase(atoms) 21 | print(graph) 22 | 23 | atoms_back = graph.to_ase() 24 | print(atoms_back.info) 25 | 26 | assert atoms_back.info["energy"] == atoms.info["energy"] 27 | assert np.allclose(atoms_back.arrays["forces"], atoms.arrays["forces"]) 28 | assert np.all(atoms_back.pbc == pbc) 29 | 30 | 31 | def test_no_cell(): 32 | atoms = ase.Atoms("H2", positions=[(0, 0, 0), (0, 0, 1)], pbc=True) 33 | with pytest.raises(ValueError, match="but cell is all zeros"): 34 | AtomicGraph.from_ase(atoms) 35 | 36 | atoms.cell = (4, 4, 4) 37 | graph = AtomicGraph.from_ase(atoms) 38 | assert (graph.cell == torch.eye(3) * 4).all() 39 | -------------------------------------------------------------------------------- /tests/helpers/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import inspect 4 | import os 5 | from pathlib import Path 6 | from typing import Callable 7 | 8 | import ase.build 9 | import pytest 10 | import pytorch_lightning 11 | import torch 12 | from ase import Atoms 13 | from ase.io import read 14 | 
from locache import reset
15 | 
16 | from graph_pes.atomic_graph import AtomicGraph, PropertyKey
17 | from graph_pes.data.datasets import get_all_graphs_and_cache_to_disk
18 | from graph_pes.graph_pes_model import GraphPESModel
19 | from graph_pes.models import (
20 |     ALL_MODELS,
21 |     EDDP,
22 |     MACE,
23 |     AdditionModel,
24 |     FixedOffset,
25 |     LennardJones,
26 |     NequIP,
27 |     PaiNN,
28 |     TensorNet,
29 |     ZEmbeddingMACE,
30 |     ZEmbeddingNequIP,
31 | )
32 | from graph_pes.models.components.scaling import LocalEnergiesScaler
33 | 
34 | # remove cache so that any changes are actually tested
35 | reset(get_all_graphs_and_cache_to_disk)
36 | 
37 | # non-verbose load-atoms to avoid polluting the test output
38 | os.environ["LOAD_ATOMS_VERBOSE"] = "0"
39 | 
40 | 
41 | def all_model_factories(
42 |     expected_elements: list[str],
43 |     cutoff: float,
44 | ) -> tuple[list[str], list[Callable[[], GraphPESModel]]]:
45 |     pytorch_lightning.seed_everything(42)
46 |     # make these models as small as possible to speed up tests
47 |     _small_nequip = {
48 |         "layers": 2,
49 |         "features": dict(
50 |             channels=16,
51 |             l_max=1,
52 |             use_odd_parity=False,
53 |         ),
54 |     }
55 |     required_kwargs = {
56 |         NequIP: {"elements": expected_elements, **_small_nequip},
57 |         ZEmbeddingNequIP: {**_small_nequip},
58 |         MACE: {
59 |             "elements": expected_elements,
60 |             "layers": 3,
61 |             "l_max": 2,
62 |             "correlation": 3,
63 |             "channels": 4,
64 |         },
65 |         ZEmbeddingMACE: {
66 |             "layers": 3,
67 |             "l_max": 2,
68 |             "correlation": 3,
69 |             "channels": 4,
70 |             "z_embed_dim": 4,
71 |         },
72 |         PaiNN: {
73 |             "layers": 2,
74 |             "channels": 16,
75 |         },
76 |         TensorNet: {
77 |             "layers": 2,
78 |             "radial_features": 24,
79 |             "channels": 8,
80 |         },
81 |         EDDP: {"elements": expected_elements},
82 |     }
83 | 
84 |     def _model_factory(
85 |         model_klass: type[GraphPESModel],
86 |     ) -> Callable[[], GraphPESModel]:
87 |         # inspect whether cutoff is a required argument
88 |         requires_cutoff = False
89 |         for arg in inspect.signature(model_klass.__init__).parameters.values():
90 |             if arg.name == "cutoff":
91 |                 requires_cutoff = True
92 |                 break
93 |         kwargs = required_kwargs.get(model_klass, {})
94 |         if requires_cutoff:
95 |             kwargs["cutoff"] = cutoff
96 |         return lambda: model_klass(**kwargs)
97 | 
98 |     names = [model.__name__ for model in ALL_MODELS]
99 |     factories = [_model_factory(model) for model in ALL_MODELS]
100 |     names.append("AdditionModel")
101 |     factories.append(
102 |         lambda: AdditionModel(
103 |             lj=LennardJones(cutoff=cutoff), offset=FixedOffset()
104 |         )
105 |     )
106 |     return names, factories
107 | 
108 | 
109 | def all_models(
110 |     expected_elements: list[str],
111 |     cutoff: float,
112 | ) -> tuple[list[str], list[GraphPESModel]]:
113 |     torch.manual_seed(42)
114 |     names, factories = all_model_factories(expected_elements, cutoff)
115 |     return names, [factory() for factory in factories]
116 | 
117 | 
118 | def parameterise_all_models(expected_elements: list[str], cutoff: float = 5.0):
119 |     def decorator(func):
120 |         names, models = all_models(expected_elements, cutoff)
121 |         return pytest.mark.parametrize("model", models, ids=names)(func)
122 | 
123 |     return decorator
124 | 
125 | 
126 | def parameterise_model_classes(
127 |     expected_elements: list[str],
128 |     cutoff: float = 5.0,
129 | ):
130 |     def decorator(func):
131 |         names, factories = all_model_factories(expected_elements, cutoff)
132 |         return pytest.mark.parametrize("model_class", factories, ids=names)(
133 |             func
134 |         )
135 | 
136 |     return decorator
137 | 
138 | 
139 | def graph_from_molecule(molecule: str, cutoff: float = 
3.7) -> AtomicGraph: 140 | return AtomicGraph.from_ase(ase.build.molecule(molecule), cutoff) 141 | 142 | 143 | CU_STRUCTURES_FILE = Path(__file__).parent / "test.xyz" 144 | CU_TEST_STRUCTURES: list[Atoms] = read(CU_STRUCTURES_FILE, ":") # type: ignore 145 | 146 | CONFIGS_DIR = Path(__file__).parent.parent.parent / "configs" 147 | 148 | 149 | class DoesNothingModel(GraphPESModel): 150 | def __init__(self): 151 | super().__init__( 152 | cutoff=3.7, 153 | implemented_properties=["local_energies"], 154 | ) 155 | self.scaler = LocalEnergiesScaler() 156 | 157 | def forward(self, graph: AtomicGraph) -> dict[PropertyKey, torch.Tensor]: 158 | local_energies = torch.zeros(len(graph.Z)) 159 | local_energies = self.scaler(local_energies, graph) 160 | return {"local_energies": local_energies} 161 | -------------------------------------------------------------------------------- /tests/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jla-gardner/graph-pes/c574efb9ee873ddefe9cbd0c62b58777b58d96f2/tests/models/__init__.py -------------------------------------------------------------------------------- /tests/models/components/test_aggregation.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import pytest 4 | import torch 5 | from ase import Atoms 6 | 7 | from graph_pes.atomic_graph import ( 8 | AtomicGraph, 9 | number_of_atoms, 10 | number_of_edges, 11 | number_of_neighbours, 12 | to_batch, 13 | ) 14 | from graph_pes.models.components.aggregation import ( 15 | MeanNeighbours, 16 | NeighbourAggregation, 17 | ScaledSumNeighbours, 18 | SumNeighbours, 19 | VariancePreservingSumNeighbours, 20 | ) 21 | 22 | # Test structures 23 | ISOLATED_ATOM = Atoms("H", positions=[(0, 0, 0)], pbc=False) 24 | DIMER = Atoms("H2", positions=[(0, 0, 0), (0, 0, 1)], pbc=False) 25 | TRIMER = Atoms("H3", positions=[(0, 0, 0), (0, 0, 1), (0, 1, 0)], pbc=False) 26 | 27 | 28 | @pytest.fixture 29 | def graphs(): 30 | return [ 31 | AtomicGraph.from_ase(ISOLATED_ATOM, cutoff=1.5), 32 | AtomicGraph.from_ase(DIMER, cutoff=1.5), 33 | AtomicGraph.from_ase(TRIMER, cutoff=1.5), 34 | ] 35 | 36 | 37 | @pytest.mark.parametrize( 38 | "aggregation_class", 39 | [ 40 | SumNeighbours, 41 | MeanNeighbours, 42 | ScaledSumNeighbours, 43 | VariancePreservingSumNeighbours, 44 | ], 45 | ) 46 | def test_aggregation_shape(aggregation_class, graphs): 47 | aggregation = aggregation_class() 48 | 49 | for graph in graphs: 50 | N = number_of_atoms(graph) 51 | E = number_of_edges(graph) 52 | 53 | for shape in [(E,), (E, 2), (E, 2, 3)]: 54 | edge_property = torch.rand(shape) 55 | result = aggregation(edge_property, graph) 56 | assert result.shape == (N, *shape[1:]) 57 | 58 | 59 | def test_sum_neighbours(graphs): 60 | sum_agg = SumNeighbours() 61 | 62 | for graph in graphs: 63 | E = number_of_edges(graph) 64 | edge_property = torch.ones(E) 65 | result = sum_agg(edge_property, graph) 66 | assert torch.allclose( 67 | result, 68 | number_of_neighbours(graph, include_central_atom=False).float(), 69 | ) 70 | 71 | 72 | def test_mean_neighbours(graphs): 73 | mean_agg = MeanNeighbours() 74 | 75 | for graph in graphs: 76 | E = number_of_edges(graph) 77 | edge_property = torch.ones(E) 78 | result = mean_agg(edge_property, graph) 79 | expected = number_of_neighbours( 80 | graph, include_central_atom=False 81 | ) / number_of_neighbours(graph, include_central_atom=True) 82 | assert torch.allclose(result, expected) 
83 | 84 | 85 | @pytest.mark.parametrize("learnable", [True, False]) 86 | def test_scaled_sum_neighbours(graphs, learnable): 87 | scaled_sum_agg = ScaledSumNeighbours(learnable=learnable) 88 | 89 | # Test pre_fit 90 | batch = to_batch(graphs) 91 | scaled_sum_agg.pre_fit(batch) 92 | avg_neighbours = number_of_edges(batch) / number_of_atoms(batch) 93 | assert torch.isclose(scaled_sum_agg.scale, torch.tensor(avg_neighbours)) 94 | 95 | # Test forward 96 | for graph in graphs: 97 | E = number_of_edges(graph) 98 | edge_property = torch.ones(E) 99 | result = scaled_sum_agg(edge_property, graph) 100 | expected = ( 101 | number_of_neighbours(graph, include_central_atom=False).to( 102 | torch.float 103 | ) 104 | / scaled_sum_agg.scale 105 | ) 106 | assert torch.allclose(result, expected) 107 | 108 | 109 | def test_variance_preserving_sum_neighbours(graphs): 110 | var_pres_sum_agg = VariancePreservingSumNeighbours() 111 | 112 | for graph in graphs: 113 | E = number_of_edges(graph) 114 | edge_property = torch.ones(E) 115 | result = var_pres_sum_agg(edge_property, graph) 116 | expected = number_of_neighbours( 117 | graph, include_central_atom=False 118 | ) / torch.sqrt(number_of_neighbours(graph, include_central_atom=True)) 119 | assert torch.allclose(result, expected) 120 | 121 | 122 | def test_parse_aggregation(): 123 | assert isinstance(NeighbourAggregation.parse("sum"), SumNeighbours) 124 | assert isinstance(NeighbourAggregation.parse("mean"), MeanNeighbours) 125 | assert isinstance( 126 | NeighbourAggregation.parse("constant_fixed"), ScaledSumNeighbours 127 | ) 128 | assert isinstance( 129 | NeighbourAggregation.parse("constant_learnable"), ScaledSumNeighbours 130 | ) 131 | assert isinstance( 132 | NeighbourAggregation.parse("sqrt"), VariancePreservingSumNeighbours 133 | ) 134 | 135 | with pytest.raises(ValueError): 136 | NeighbourAggregation.parse("invalid_mode") # type: ignore 137 | 138 | 139 | @pytest.mark.parametrize( 140 | "aggregation_class", 141 | [ 142 | SumNeighbours, 143 | MeanNeighbours, 144 | ScaledSumNeighbours, 145 | VariancePreservingSumNeighbours, 146 | ], 147 | ) 148 | def test_torchscript_compatibility(aggregation_class): 149 | aggregation = aggregation_class() 150 | scripted = torch.jit.script(aggregation) 151 | assert isinstance(scripted, torch.jit.ScriptModule) 152 | -------------------------------------------------------------------------------- /tests/models/components/test_distances.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | 5 | import pytest 6 | import torch 7 | 8 | from graph_pes.models.components.distances import ( 9 | Bessel, 10 | CosineEnvelope, 11 | DistanceExpansion, 12 | Envelope, 13 | ExponentialRBF, 14 | GaussianSmearing, 15 | PolynomialEnvelope, 16 | SinExpansion, 17 | get_distance_expansion, 18 | ) 19 | 20 | _expansions = [Bessel, GaussianSmearing, SinExpansion, ExponentialRBF] 21 | _names = [expansion.__name__ for expansion in _expansions] 22 | parameterise_expansions = pytest.mark.parametrize( 23 | "expansion_klass", 24 | _expansions, 25 | ids=_names, 26 | ) 27 | 28 | 29 | @parameterise_expansions 30 | def test_torchscript( 31 | expansion_klass: type[DistanceExpansion], 32 | tmp_path: Path, 33 | ): 34 | n_features = 17 35 | cutoff = 5.0 36 | expansion = expansion_klass(n_features, cutoff, trainable=True) 37 | 38 | scripted = torch.jit.script(expansion) 39 | assert isinstance(scripted, torch.jit.ScriptModule) 40 | r = torch.linspace(0, cutoff, 
10) 41 | x = scripted(r) 42 | 43 | torch.jit.save(scripted, tmp_path / "expansion.pt") 44 | loaded: torch.jit.ScriptModule = torch.jit.load(tmp_path / "expansion.pt") 45 | 46 | assert torch.allclose(x, loaded(r)) 47 | 48 | 49 | @parameterise_expansions 50 | def test_expansions(expansion_klass: type[DistanceExpansion]): 51 | n_features = 17 52 | cutoff = 5.0 53 | r = torch.linspace(0, cutoff, 10) 54 | x = expansion_klass(n_features, cutoff)(r) 55 | assert x.shape == (10, n_features) 56 | assert x.grad_fn is not None 57 | 58 | x = expansion_klass(n_features, cutoff, trainable=False)(r) 59 | assert x.grad_fn is None 60 | 61 | 62 | @pytest.mark.parametrize( 63 | "envelope", 64 | [ 65 | CosineEnvelope, 66 | PolynomialEnvelope, 67 | ], 68 | ) 69 | def test_envelopes(envelope: type[Envelope]): 70 | cutoff = 5.0 71 | env = envelope(cutoff=cutoff) 72 | 73 | r = torch.tensor([4.5, 5, 5.5]) 74 | x = env(r) 75 | assert x.shape == (3,) 76 | a, b, c = env(r).tolist() 77 | assert a > 0, "The envelope should be positive" 78 | assert b == 0, "The envelope should be zero at the cutoff" 79 | assert c == 0, "The envelope should be zero beyond the cutoff" 80 | 81 | 82 | def test_get_expansion(): 83 | assert ( 84 | get_distance_expansion("Bessel") 85 | == get_distance_expansion(Bessel) 86 | == Bessel 87 | ) 88 | 89 | with pytest.raises(ValueError, match="Unknown distance expansion"): 90 | get_distance_expansion("Unknown") 91 | 92 | with pytest.raises(ValueError, match="is not a DistanceExpansion"): 93 | get_distance_expansion("PolynomialEnvelope") 94 | -------------------------------------------------------------------------------- /tests/models/test_correctness.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from ase import build 4 | from ase.calculators.lj import LennardJones 5 | 6 | from graph_pes.models import LennardJones as GraphPESLennardJones 7 | from graph_pes.utils.calculator import GraphPESCalculator 8 | 9 | 10 | def test_correctness(): 11 | lj_ase = LennardJones(rc=3.0) 12 | lj_gp = GraphPESCalculator(GraphPESLennardJones.from_ase(rc=3.0)) 13 | 14 | # test correct on molecular 15 | mol = build.molecule("H2O") 16 | mol.center(vacuum=10) 17 | 18 | assert lj_ase.get_potential_energy(mol) == pytest.approx( 19 | lj_gp.get_potential_energy(mol), abs=3e-5 20 | ) 21 | np.testing.assert_allclose( # type: ignore 22 | lj_ase.get_forces(mol), # type: ignore 23 | lj_gp.get_forces(mol), # type: ignore 24 | atol=3e-4, 25 | ) 26 | np.testing.assert_allclose( # type: ignore 27 | lj_ase.get_stress(mol), # type: ignore 28 | lj_gp.get_stress(mol), # type: ignore 29 | atol=3e-4, 30 | ) 31 | 32 | # test correct on bulk 33 | bulk = build.bulk("C", "diamond", a=3.5668) 34 | assert lj_ase.get_potential_energy(bulk) == pytest.approx( 35 | lj_gp.get_potential_energy(bulk), abs=3e-5 36 | ) 37 | np.testing.assert_allclose( # type: ignore 38 | lj_ase.get_forces(bulk), # type: ignore 39 | lj_gp.get_forces(bulk), # type: ignore 40 | atol=3e-4, 41 | ) 42 | np.testing.assert_allclose( # type: ignore 43 | lj_ase.get_stress(bulk), # type: ignore 44 | lj_gp.get_stress(bulk), # type: ignore 45 | atol=3e-4, 46 | ) 47 | -------------------------------------------------------------------------------- /tests/models/test_cutoffs.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from dataclasses import dataclass 4 | 5 | import pytest 6 | from torch import Tensor 7 | 8 | from 
graph_pes import GraphPESModel 9 | from graph_pes.atomic_graph import ( 10 | AtomicGraph, 11 | PropertyKey, 12 | neighbour_distances, 13 | number_of_edges, 14 | to_batch, 15 | trim_edges, 16 | ) 17 | from graph_pes.models import AdditionModel, FixedOffset, SchNet 18 | 19 | from ..helpers import graph_from_molecule 20 | 21 | 22 | @dataclass 23 | class Stats: 24 | n_neighbours: int 25 | max_edge_length: float 26 | 27 | 28 | class DummyModel(GraphPESModel): 29 | def __init__(self, name: str, cutoff: float, info: dict[str, Stats]): 30 | super().__init__( 31 | cutoff=cutoff, 32 | implemented_properties=["local_energies"], 33 | ) 34 | self.name = name 35 | self.info = info 36 | 37 | def forward(self, graph: AtomicGraph) -> dict[PropertyKey, Tensor]: 38 | # insert statistics here: `GraphPESModel` should automatically 39 | # trim the input graph based on the model's cutoff 40 | self.info[self.name] = Stats( 41 | n_neighbours=number_of_edges(graph), 42 | max_edge_length=neighbour_distances(graph).max().item(), 43 | ) 44 | # dummy return value 45 | return {"local_energies": graph.Z.float()} 46 | 47 | 48 | def test_auto_trimming(): 49 | info = {} 50 | graph = graph_from_molecule("CH3CH2OCH3", cutoff=5.0) 51 | 52 | large_model = DummyModel("large", cutoff=5.0, info=info) 53 | small_model = DummyModel("small", cutoff=3.0, info=info) 54 | 55 | # forward passes to gather info 56 | large_model.get_all_PES_predictions(graph) 57 | small_model.get_all_PES_predictions(graph) 58 | 59 | # check that cutoff filtering is working 60 | assert info["large"].n_neighbours == number_of_edges(graph) 61 | assert info["small"].n_neighbours < number_of_edges(graph) 62 | 63 | assert ( 64 | info["large"].max_edge_length == neighbour_distances(graph).max().item() 65 | ) 66 | assert ( 67 | info["small"].max_edge_length < neighbour_distances(graph).max().item() 68 | ) 69 | 70 | 71 | def test_model_cutoffs(): 72 | model = AdditionModel( 73 | small=SchNet(cutoff=3.0), 74 | large=SchNet(cutoff=5.0), 75 | ) 76 | 77 | assert model.cutoff == 5.0 78 | 79 | model = FixedOffset() 80 | assert model.cutoff == 0 81 | 82 | 83 | def test_warning(): 84 | graph = graph_from_molecule("CH4", cutoff=3.0) 85 | trimmed_graph = trim_edges(graph, cutoff=3.0) 86 | 87 | model = SchNet(cutoff=6.0) 88 | with pytest.warns(UserWarning, match="Graph already has a cutoff of"): 89 | model.get_all_PES_predictions(trimmed_graph) 90 | 91 | 92 | def test_cutoff_trimming(): 93 | graph = graph_from_molecule("CH4", cutoff=5) 94 | 95 | trimmed_graph = trim_edges(graph, cutoff=3.0) 96 | assert graph is not trimmed_graph 97 | assert trimmed_graph.cutoff == 3.0 98 | 99 | # check that trimming a second time with the same cutoff is a no-op 100 | doubly_trimmed_graph = trim_edges(trimmed_graph, cutoff=3.0) 101 | assert doubly_trimmed_graph is trimmed_graph 102 | 103 | # but that if the cutoff is further reduced then the trimming occurs 104 | doubly_trimmed_graph = trim_edges(trimmed_graph, cutoff=2.0) 105 | assert doubly_trimmed_graph is not trimmed_graph 106 | assert doubly_trimmed_graph.cutoff == 2.0 107 | 108 | 109 | def test_cutoff_batching(): 110 | graph = graph_from_molecule("CH4", cutoff=5.0) 111 | graph2 = graph._replace(cutoff=3.0) 112 | 113 | with pytest.warns( 114 | UserWarning, match="Attempting to batch graphs with different cutoffs" 115 | ): 116 | batch = to_batch([graph, graph2]) 117 | assert batch.cutoff == 5.0 118 | -------------------------------------------------------------------------------- /tests/models/test_direct_prediction.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from ase.build import molecule 3 | 4 | from graph_pes import AtomicGraph 5 | from graph_pes.models import NequIP 6 | 7 | 8 | def test_direct_forces(): 9 | model = NequIP( 10 | direct_force_predictions=True, 11 | elements=["C", "H"], 12 | features=dict(channels=4, l_max=1, use_odd_parity=True), 13 | ) 14 | graph = AtomicGraph.from_ase(molecule("CH4")) 15 | 16 | # test that the model outputs forces directly... 17 | preds = model(graph) 18 | assert "forces" in preds 19 | assert preds["forces"].shape == (5, 3) 20 | 21 | # model outputs forces indirectly 22 | all_preds = model.get_all_PES_predictions(graph) 23 | assert "forces" in all_preds 24 | assert all_preds["forces"].shape == (5, 3) 25 | assert torch.allclose(preds["forces"], all_preds["forces"]) 26 | 27 | # check equivariance: all force mags for the H atoms should be the same 28 | force_mags = torch.linalg.norm(preds["forces"], dim=1) 29 | idx = graph.Z == 1 30 | assert torch.allclose(force_mags[idx], force_mags[idx][0]) 31 | # and that these forces aren't all 0 32 | assert not torch.allclose( 33 | force_mags[idx], torch.zeros_like(force_mags[idx]) 34 | ) 35 | -------------------------------------------------------------------------------- /tests/models/test_equivariance.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | from ase.build import molecule 5 | 6 | from graph_pes import AtomicGraph, GraphPESModel 7 | 8 | from .. import helpers 9 | 10 | CUTOFF = 1.2 # bond length of methane is ~1.09 Å 11 | 12 | 13 | @helpers.parameterise_all_models(expected_elements=["H", "C"], cutoff=CUTOFF) 14 | def test_equivariance(model: GraphPESModel): 15 | methane = molecule("CH4") 16 | og_graph = AtomicGraph.from_ase(methane, cutoff=CUTOFF) 17 | og_predictions = model.get_all_PES_predictions(og_graph) 18 | 19 | # get a (repeatably) random rotation matrix 20 | R, _ = np.linalg.qr(np.random.RandomState(42).randn(3, 3)) 21 | 22 | # rotate the molecule 23 | new_methane = methane.copy() 24 | shift = new_methane.positions[0] 25 | new_methane.positions = (new_methane.positions - shift).dot(R) + shift 26 | new_graph = AtomicGraph.from_ase(new_methane, cutoff=CUTOFF) 27 | new_predictions = model.get_all_PES_predictions(new_graph) 28 | 29 | # pre-checks: 30 | np.testing.assert_allclose( 31 | np.linalg.norm(new_methane.positions - shift, axis=1), 32 | np.linalg.norm(methane.positions - shift, axis=1), 33 | ) 34 | 35 | # now check: 36 | # 1. invariance of energy prediction 37 | torch.testing.assert_close( 38 | og_predictions["energy"], 39 | new_predictions["energy"], 40 | atol=2e-5, 41 | rtol=1e-3, 42 | ) 43 | 44 | # 2. invariance of force magnitude 45 | torch.testing.assert_close( 46 | og_predictions["forces"].norm(dim=-1), 47 | new_predictions["forces"].norm(dim=-1), 48 | atol=2e-5, 49 | rtol=1e-3, 50 | ) 51 | 52 | # 3. equivariance of forces 53 | _dtype = og_predictions["forces"].dtype 54 | torch.testing.assert_close( 55 | og_predictions["forces"] @ torch.tensor(R.T, dtype=_dtype), 56 | new_predictions["forces"], 57 | # auto-grad is not perfect, and we lose 58 | # precision, particularly with larger models 59 | # and default float32 dtype 60 | atol=3e-3, 61 | # some of the og predictions are 0: a large 62 | # relative error is not a problem here 63 | rtol=10, 64 | ) 65 | 66 | # 4. 
molecule is symmetric: forces should be ~0 on the central C,
67 |     # and of equal magnitude on the H atoms, and H atom
68 |     # local energies should be the same
69 |     force_norms = new_predictions["forces"].norm(dim=-1)
70 |     c_force = force_norms[new_graph.Z == 6]
71 |     assert c_force.item() == pytest.approx(0.0, abs=3e-4)
72 | 
73 |     h_forces = force_norms[new_graph.Z == 1]
74 |     assert h_forces.min().item() == pytest.approx(
75 |         h_forces.max().item(), abs=1e-4
76 |     )
77 | 
78 |     h_local_energies = new_predictions["local_energies"][new_graph.Z == 1]
79 |     _min, _max = h_local_energies.min().item(), h_local_energies.max().item()
80 |     assert _min == pytest.approx(_max, abs=1e-6)
81 | 
--------------------------------------------------------------------------------
/tests/models/test_freezing.py:
--------------------------------------------------------------------------------
1 | from graph_pes.models import (
2 |     LennardJones,
3 |     freeze,
4 |     freeze_all_except,
5 |     freeze_any_matching,
6 |     freeze_matching,
7 | )
8 | from graph_pes.utils.nn import learnable_parameters
9 | 
10 | 
11 | def test_freeze():
12 |     model = LennardJones()
13 |     model = freeze(model)
14 | 
15 |     assert learnable_parameters(model) == 0
16 | 
17 | 
18 | def test_freeze_matching():
19 |     model = LennardJones()
20 |     model = freeze_matching(model, ".*")
21 | 
22 |     assert learnable_parameters(model) == 0
23 | 
24 |     model = LennardJones()
25 |     model = freeze_matching(model, "_log_epsilon")
26 |     assert learnable_parameters(model) == 1
27 |     assert not model._log_epsilon.requires_grad
28 | 
29 | 
30 | def test_freeze_all_except():
31 |     model = LennardJones()
32 |     model = freeze_all_except(model, "_log_epsilon")
33 | 
34 |     assert learnable_parameters(model) == 1
35 |     assert model._log_epsilon.requires_grad
36 | 
37 | 
38 | def test_freeze_any_matching():
39 |     model = LennardJones()
40 |     model = freeze_any_matching(model, ["_log_epsilon", "_log_sigma"])
41 | 
42 |     assert learnable_parameters(model) == 0
43 |     assert not model._log_epsilon.requires_grad
44 |     assert not model._log_sigma.requires_grad
45 | 
--------------------------------------------------------------------------------
/tests/models/test_offsets.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | 
3 | import numpy as np
4 | import pytest
5 | import torch
6 | from ase import Atoms
7 | 
8 | from graph_pes import AtomicGraph
9 | from graph_pes.atomic_graph import number_of_atoms, to_batch
10 | from graph_pes.models.offsets import EnergyOffset, FixedOffset, LearnableOffset
11 | 
12 | from .. 
import helpers 13 | 14 | graphs = [ 15 | AtomicGraph.from_ase(atoms, cutoff=3) 16 | for atoms in helpers.CU_TEST_STRUCTURES 17 | ] 18 | 19 | 20 | @pytest.mark.parametrize( 21 | "offset_model,trainable", 22 | [ 23 | (FixedOffset(He=1, Cu=-3), False), 24 | (LearnableOffset(He=1, Cu=-3), True), 25 | ], 26 | ) 27 | def test_offset_behaviour(offset_model: EnergyOffset, trainable: bool): 28 | assert offset_model._offsets.requires_grad == trainable 29 | total_parameters = sum(p.numel() for p in offset_model.parameters()) 30 | assert ( 31 | total_parameters == 2 32 | ), "expected 2 parameters (energy offsets for He and Cu)" 33 | 34 | assert ( 35 | offset_model._offsets[2] == 1 36 | ), "expected offset for He to be as specified" 37 | 38 | graph = graphs[0] 39 | n = number_of_atoms(graph) 40 | 41 | assert offset_model.predict_local_energies(graph).shape == (n,) 42 | predictions = offset_model.predict( 43 | graph, 44 | properties=["energy", "forces"], 45 | ) 46 | 47 | assert "energy" in predictions 48 | # total energy is the sum of offsets of all atoms, which are Cu 49 | assert predictions["energy"].item() == n * -3 50 | if trainable: 51 | assert predictions["energy"].grad_fn is not None, ( 52 | "expected gradients on the energy calculation " 53 | "due to using trainable offsets" 54 | ) 55 | 56 | # no interactions between atoms, so forces are 0 57 | assert torch.all(predictions["forces"] == 0) 58 | 59 | 60 | def test_energy_offset_fitting(): 61 | # create some fake structures to test shift and scale fitting 62 | 63 | structures = [] 64 | 65 | # (num H atoms, num C atoms) 66 | nums = [(3, 0), (4, 8), (5, 2)] 67 | H_energy, C_energy = -4.5, -10.0 68 | 69 | for n_H, n_C in nums: 70 | atoms = Atoms("H" * n_H + "C" * n_C) 71 | atoms.info["energy"] = n_H * H_energy + n_C * C_energy 72 | atoms.arrays["forces"] = np.zeros((n_H + n_C, 3)) 73 | structures.append(atoms) 74 | 75 | graphs = [AtomicGraph.from_ase(atoms, cutoff=1.5) for atoms in structures] 76 | batch = to_batch(graphs) 77 | assert "energy" in batch.properties 78 | 79 | model = LearnableOffset() 80 | model.pre_fit_all_components(graphs) 81 | # check that the model has learned the correct energy offsets 82 | # use pytest.approx to account for numerical errors 83 | assert model._offsets[1].item() == pytest.approx(H_energy) 84 | assert model._offsets[6].item() == pytest.approx(C_energy) 85 | 86 | # check that initial values aren't overwritten if specified 87 | model = LearnableOffset(H=20) 88 | model.pre_fit_all_components(graphs) 89 | assert model._offsets[1].item() == pytest.approx(20) 90 | assert model._offsets[6].item() == pytest.approx(C_energy) 91 | 92 | # check suitable warning logged if no energy data 93 | model = LearnableOffset() 94 | graph = graphs[0] 95 | del graph.properties["energy"] 96 | with pytest.warns( 97 | UserWarning, 98 | match="No energy labels found in the training data", 99 | ): 100 | model.pre_fit_all_components([graph]) 101 | -------------------------------------------------------------------------------- /tests/models/test_parameter_counting.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import ase 4 | import torch 5 | 6 | from graph_pes import AtomicGraph 7 | from graph_pes.models import ( 8 | FixedOffset, 9 | LearnableOffset, 10 | LennardJones, 11 | LennardJonesMixture, 12 | SchNet, 13 | ) 14 | from graph_pes.models.addition import AdditionModel 15 | from graph_pes.utils.nn import PerElementParameter 16 | 17 | from ..helpers import 
DoesNothingModel 18 | 19 | graphs: list[AtomicGraph] = [] 20 | for n in range(5): 21 | g = AtomicGraph.from_ase(ase.Atoms(f"CH{n}"), cutoff=5.0) 22 | g.properties["energy"] = torch.tensor(n).float() 23 | graphs.append(g) 24 | 25 | 26 | # Before a model has been pre_fit, all PerElementParameters should have 0 27 | # relevant and countable values. After pre_fitting, the PerElementParameter 28 | # values corresponding to elements seen in the pre-fit data should be counted. 29 | 30 | 31 | def test_fixed(): 32 | model = FixedOffset(H=1.0, C=2.0) 33 | 34 | # model should have a single parameter 35 | params = list(model.parameters()) 36 | assert len(params) == 1 37 | 38 | # there should be 2 values, since we passed 2 offset energies 39 | assert params[0].numel() == 2 40 | 41 | 42 | def test_scaling(): 43 | model = DoesNothingModel() 44 | # the model should have a single parameter: the per_element_scaling 45 | params = list(model.parameters()) 46 | assert len(params) == 1 47 | assert params[0] is model.scaler.per_element_scaling 48 | 49 | # there should be no countable values in this parameter 50 | assert params[0].numel() == 0 51 | 52 | model.pre_fit_all_components(graphs) 53 | # now the model has seen info about 2 elements: 54 | # there should be 2 countable elements on the model 55 | assert params[0].numel() == 2 56 | 57 | 58 | def test_counting(): 59 | _schnet_dim = 50 60 | model = AdditionModel( 61 | offset=LearnableOffset(), 62 | schnet=SchNet(channels=_schnet_dim), 63 | ) 64 | 65 | non_pre_fit_params = sum(p.numel() for p in model.parameters()) 66 | 67 | # 3 per-element parameters: 68 | # 1. the offsets in LearnableOffset 69 | # 2. the per-element scaling in SchNet 70 | # 3. the chemical embedding of SchNet 71 | assert ( 72 | sum(1 for p in model.parameters() if isinstance(p, PerElementParameter)) 73 | == 3 74 | ) 75 | 76 | model.pre_fit_all_components(graphs) 77 | 78 | post_fit_params = sum(p.numel() for p in model.parameters()) 79 | 80 | # seen 2 elements (C and H), leading to a total of: 81 | # 1. 2 countable elements in the offsets 82 | # 2. 2 countable elements in the per-element scaling 83 | # 3. 
2*schnet embedding dim countable elements in the chemical embedding
84 |     assert post_fit_params == non_pre_fit_params + 2 + 2 + 2 * _schnet_dim
85 | 
86 | 
87 | def test_lj():
88 |     lj = LennardJones()
89 |     # lj has two parameters: epsilon and sigma
90 |     assert sum(p.numel() for p in lj.parameters()) == 2
91 | 
92 | 
93 | def test_lj_mixture():
94 |     lj_mixture = LennardJonesMixture()
95 |     lj_mixture.pre_fit_all_components(graphs)
96 | 
97 |     expected_params = 0
98 |     # sigma and epsilon for each element
99 |     expected_params += 2 * 2
100 |     # nu and zeta term for each ordered pair of elements
101 |     expected_params += 2 * 2 * 2
102 | 
103 |     assert sum(p.numel() for p in lj_mixture.parameters()) == expected_params
104 | 
--------------------------------------------------------------------------------
/tests/models/test_predictions.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | 
3 | import ase.build
4 | import numpy as np
5 | import pytest
6 | import torch
7 | from ase import Atoms
8 | 
9 | from graph_pes import AtomicGraph
10 | from graph_pes.atomic_graph import get_cell_volume, number_of_edges, to_batch
11 | from graph_pes.models.pairwise import LennardJones
12 | 
13 | no_pbc = AtomicGraph.from_ase(
14 |     Atoms("H2", positions=[(0, 0, 0), (0, 0, 1)], pbc=False),
15 |     cutoff=1.5,
16 | )
17 | pbc = AtomicGraph.from_ase(
18 |     Atoms("H2", positions=[(0, 0, 0), (0, 0, 1)], pbc=True, cell=(2, 2, 2)),
19 |     cutoff=1.5,
20 | )
21 | 
22 | 
23 | def test_predictions():
24 |     expected_shapes = {
25 |         "energy": (),
26 |         "forces": (2, 3),
27 |         "stress": (3, 3),
28 |         "virial": (3, 3),
29 |         "local_energies": (2,),
30 |     }
31 | 
32 |     model = LennardJones(cutoff=1.5)
33 | 
34 |     # no stress should be predicted for non-periodic systems
35 |     predictions = model.get_all_PES_predictions(no_pbc)
36 |     assert set(predictions.keys()) == {"energy", "forces", "local_energies"}
37 | 
38 |     for key in "energy", "forces", "local_energies":
39 |         assert predictions[key].shape == expected_shapes[key]
40 | 
41 |     # if we ask for stress, we get an error:
42 |     with pytest.raises(ValueError):
43 |         model.predict(no_pbc, properties=["stress"])
44 | 
45 |     # with pbc structures, we should get all predictions
46 |     predictions = model.get_all_PES_predictions(pbc)
47 |     assert set(predictions.keys()) == {
48 |         "energy",
49 |         "forces",
50 |         "stress",
51 |         "virial",
52 |         "local_energies",
53 |     }
54 | 
55 |     for key in "energy", "forces", "stress", "virial", "local_energies":
56 |         assert predictions[key].shape == expected_shapes[key]
57 | 
58 | 
59 | def test_batched_prediction():
60 |     batch = to_batch([pbc, pbc])
61 | 
62 |     expected_shapes = {
63 |         "energy": (2,),  # two structures
64 |         "forces": (4, 3),  # four atoms
65 |         "stress": (2, 3, 3),  # two structures
66 |         "virial": (2, 3, 3),  # two structures
67 |     }
68 | 
69 |     predictions = LennardJones(cutoff=1.5).get_all_PES_predictions(batch)
70 | 
71 |     for key in "energy", "forces", "stress", "virial":
72 |         assert predictions[key].shape == expected_shapes[key]
73 | 
74 | 
75 | def test_isolated_atom():
76 |     atom = Atoms("H", positions=[(0, 0, 0)], pbc=False)
77 |     graph = AtomicGraph.from_ase(atom, cutoff=1.5)
78 |     assert number_of_edges(graph) == 0
79 | 
80 |     predictions = LennardJones(cutoff=1.5).get_all_PES_predictions(graph)
81 |     assert torch.allclose(predictions["forces"], torch.zeros(1, 3))
82 | 
83 | 
84 | def test_stress_and_virial():
85 |     model = LennardJones(cutoff=5.0)
86 | 
87 |     # get stress and virial predictions
88 |     structure = 
ase.build.bulk("Si", "diamond", a=5.43) 89 | graph = AtomicGraph.from_ase(structure, cutoff=5.0) 90 | 91 | s = model.predict_stress(graph) 92 | v = model.predict_virial(graph) 93 | volume = get_cell_volume(graph) 94 | torch.testing.assert_close(s, -v / volume) 95 | 96 | # ensure correct scaling 97 | np.product = np.prod # fix ase for new versions of numpy (np.product was removed) 98 | structure2 = structure.copy().repeat((2, 2, 2)) 99 | graph2 = AtomicGraph.from_ase(structure2, cutoff=5.0) 100 | s2 = model.predict_stress(graph2) 101 | v2 = model.predict_virial(graph2) 102 | 103 | # use the diagonal elements of the stress and virial tensors to avoid 104 | # issues with very low values on the off-diagonals due to numerical error 105 | def get_diagonal(tensor): 106 | return torch.diagonal(tensor, dim1=-2, dim2=-1) 107 | 108 | # stress is an intensive property, so should remain the same 109 | # under a repeated unit cell 110 | torch.testing.assert_close(get_diagonal(s), get_diagonal(s2)) 111 | 112 | # virial is an extensive property, so should scale with volume, 113 | # which in this case is a factor of 8 larger 114 | torch.testing.assert_close(8 * get_diagonal(v), get_diagonal(v2)) 115 | -------------------------------------------------------------------------------- /tests/models/test_scripting.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | import torch 5 | from ase.build import molecule 6 | 7 | from graph_pes.models import SchNet, ScriptedModel, load_model 8 | 9 | 10 | def test_scripting(tmp_path: Path): 11 | model = SchNet() 12 | water = molecule("H2O") 13 | pred = model.ase_calculator().get_potential_energy(water) 14 | 15 | _scripted = torch.jit.script(model) 16 | scripted = ScriptedModel(_scripted) 17 | pred_scripted = scripted.ase_calculator().get_potential_energy(water) 18 | 19 | assert pred == pytest.approx(pred_scripted) 20 | 21 | scripted_path = tmp_path / "scripted.pt" 22 | torch.jit.script(scripted).save(scripted_path) 23 | scripted_from_file = load_model(scripted_path) 24 | pred_scripted_from_file = ( 25 | scripted_from_file.ase_calculator().get_potential_energy(water) 26 | ) 27 | assert pred == pytest.approx(pred_scripted_from_file) 28 | -------------------------------------------------------------------------------- /tests/models/test_state_dict.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any 4 | 5 | import pytest 6 | import torch 7 | 8 | from graph_pes import GraphPESModel 9 | from graph_pes.atomic_graph import AtomicGraph, PropertyKey 10 | from graph_pes.models import LennardJones 11 | 12 | 13 | def test_state_dict(): 14 | model = LennardJones() 15 | state_dict = model.state_dict() 16 | assert "_GRAPH_PES_VERSION" in state_dict["_extra_state"] 17 | assert "extra" in state_dict["_extra_state"] 18 | 19 | class CustomModel(GraphPESModel): 20 | def __init__(self, v: float): 21 | super().__init__( 22 | cutoff=0.0, implemented_properties=["local_energies"] 23 | ) 24 | self.v = v 25 | 26 | def forward( 27 | self, graph: AtomicGraph 28 | ) -> dict[PropertyKey, torch.Tensor]: 29 | return {"local_energies": torch.zeros_like(graph.Z)} 30 | 31 | @property 32 | def extra_state(self) -> dict[str, Any]: 33 | return {"v": self.v} 34 | 35 | @extra_state.setter 36 | def extra_state(self, state: dict[str, Any]) -> None: 37 | self.v = state["v"] 38 | 39 | model1 = CustomModel(v=2.0) 40 | state_dict1 = model1.state_dict() 41 | assert 
state_dict1["_extra_state"]["extra"]["v"] == 2.0 42 | 43 | model2 = CustomModel(v=4.0) 44 | assert model2.v == 4.0 45 | model2.load_state_dict(state_dict1) 46 | assert model2.v == 2.0 47 | 48 | 49 | def test_state_dict_warning(): 50 | model = LennardJones() 51 | model._GRAPH_PES_VERSION = "0.0.0" # type: ignore 52 | 53 | with pytest.warns(UserWarning): 54 | LennardJones().load_state_dict(model.state_dict()) 55 | -------------------------------------------------------------------------------- /tests/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jla-gardner/graph-pes/c574efb9ee873ddefe9cbd0c62b58777b58d96f2/tests/training/__init__.py -------------------------------------------------------------------------------- /tests/training/test_callbacks.py: -------------------------------------------------------------------------------- 1 | from graph_pes.models.addition import AdditionModel 2 | from graph_pes.models.offsets import LearnableOffset 3 | from graph_pes.models.schnet import SchNet 4 | from graph_pes.training.callbacks import log_offset, log_scales 5 | 6 | 7 | class FakeLogger: 8 | def __init__(self, results: dict): 9 | self.results = results 10 | 11 | def log_metrics(self, metrics: dict): 12 | self.results.update(metrics) 13 | 14 | 15 | def test_offset_logger(): 16 | offset = LearnableOffset(H=1, O=2) 17 | offset._offsets.register_elements([1, 6, 8]) # H, C, O 18 | model = AdditionModel(offset=offset) 19 | logger = FakeLogger({}) 20 | log_offset(model, logger) # type: ignore 21 | assert logger.results == {"offset/H": 1.0, "offset/O": 2.0, "offset/C": 0.0} 22 | 23 | # clear the logger 24 | logger.results = {} 25 | log_offset(SchNet(), logger) # type: ignore 26 | assert logger.results == {} 27 | 28 | 29 | def test_scales_logger(): 30 | model = SchNet() 31 | model.scaler.per_element_scaling.register_elements([1, 6]) # H, C 32 | model.scaler.per_element_scaling.data[1, 0] = 2.0 33 | logger = FakeLogger({}) 34 | log_scales(model, logger) # type: ignore 35 | assert logger.results == {"scale/H": 2.0, "scale/C": 1.0} 36 | 37 | addition_model = AdditionModel(schnet=model) 38 | logger = FakeLogger({}) 39 | log_scales(addition_model, logger) # type: ignore 40 | assert logger.results == {"scale/schnet/H": 2.0, "scale/schnet/C": 1.0} 41 | -------------------------------------------------------------------------------- /tests/training/test_integration.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import pytorch_lightning as pl 4 | 5 | from graph_pes import AtomicGraph, GraphPESModel 6 | from graph_pes.atomic_graph import to_batch 7 | from graph_pes.config.training import FittingOptions 8 | from graph_pes.data.datasets import DatasetCollection, GraphDataset 9 | from graph_pes.training.callbacks import ( 10 | EarlyStoppingWithLogging, 11 | LoggedProgressBar, 12 | ) 13 | from graph_pes.training.loss import PerAtomEnergyLoss, TotalLoss 14 | from graph_pes.training.opt import Optimizer 15 | from graph_pes.training.tasks import train_with_lightning 16 | from graph_pes.training.utils import VALIDATION_LOSS_KEY 17 | 18 | from .. 
import helpers 19 | 20 | 21 | @helpers.parameterise_all_models(expected_elements=["Cu"], cutoff=3) 22 | def test_integration(model: GraphPESModel): 23 | if len(list(model.parameters())) == 0: 24 | # nothing to train 25 | return 26 | 27 | graphs = [ 28 | AtomicGraph.from_ase(atoms, cutoff=3) 29 | for atoms in helpers.CU_TEST_STRUCTURES 30 | ] 31 | 32 | # Split data into train/val sets 33 | train_graphs = graphs[:8] 34 | val_graphs = graphs[8:] 35 | 36 | # pre-fit before measuring performance to ensure that 37 | # training improves the model 38 | model.pre_fit_all_components(train_graphs) 39 | 40 | train_batch = to_batch(train_graphs) 41 | assert "energy" in train_batch.properties 42 | 43 | loss = TotalLoss([PerAtomEnergyLoss()]) 44 | 45 | def get_train_loss(): 46 | return loss( 47 | model, train_batch, model.predict(train_batch, ["energy"]) 48 | ).loss_value.item() 49 | 50 | before = get_train_loss() 51 | 52 | # Create trainer and train 53 | train_with_lightning( 54 | trainer=pl.Trainer( 55 | max_epochs=10, 56 | accelerator="cpu", 57 | callbacks=[ 58 | LoggedProgressBar(), 59 | EarlyStoppingWithLogging( 60 | monitor=VALIDATION_LOSS_KEY, patience=10 61 | ), 62 | ], 63 | ), 64 | model=model, 65 | data=DatasetCollection( 66 | train=GraphDataset(train_graphs), 67 | valid=GraphDataset(val_graphs), 68 | ), 69 | loss=loss, 70 | fit_config=FittingOptions( 71 | pre_fit_model=False, 72 | loader_kwargs={"batch_size": 8}, 73 | max_n_pre_fit=100, 74 | early_stopping=None, 75 | early_stopping_patience=None, 76 | auto_fit_reference_energies=False, 77 | ), 78 | optimizer=Optimizer("Adam", lr=3e-4), 79 | ) 80 | 81 | after = get_train_loss() 82 | 83 | assert after < before, "training did not improve the loss" 84 | -------------------------------------------------------------------------------- /tests/training/test_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import pytest 4 | import torch 5 | 6 | from graph_pes.training.loss import MAE, RMSE, WeightedLoss 7 | 8 | 9 | def test_metrics(): 10 | a = torch.tensor([1.0, 2.0, 3.0]) 11 | b = torch.tensor([1.0, 2.0, 3.0]) 12 | 13 | assert torch.allclose(MAE()(a, b), torch.tensor(0.0)) 14 | assert torch.allclose(RMSE()(a, b), torch.tensor(0.0)) 15 | 16 | c = torch.tensor([0, 0, 0]).float() 17 | assert torch.allclose(MAE()(a, c), torch.tensor(2.0)) 18 | assert torch.allclose(RMSE()(a, c), torch.tensor((1 + 4 + 9) / 3).sqrt()) 19 | 20 | 21 | def test_exception(): 22 | with pytest.raises(ImportError): 23 | WeightedLoss() 24 | -------------------------------------------------------------------------------- /tests/training/test_opt.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from graph_pes.models import LearnableOffset, LennardJones, SchNet 5 | from graph_pes.models.addition import AdditionModel 6 | from graph_pes.training.opt import LRScheduler, Optimizer 7 | 8 | 9 | def test_non_decayable_params(): 10 | opt_factory = Optimizer("Adam", weight_decay=1e-4) 11 | 12 | # learnable offsets have a single parameter, _offsets, 13 | # that is not decayable 14 | model = LearnableOffset() 15 | assert ( 16 | set(model.non_decayable_parameters()) 17 | == set([model._offsets]) 18 | == set(model.parameters()) 19 | ) 20 | 21 | opt = opt_factory(model) 22 | assert len(opt.param_groups) == 1 23 | assert opt.param_groups[0]["weight_decay"] == 0.0 24 | 25 | # LJ models have no non-decayable parameters 26 | model = 
LennardJones() 27 | assert len(model.non_decayable_parameters()) == 0 28 | 29 | opt = opt_factory(model) 30 | assert len(opt.param_groups) == 1 31 | assert opt.param_groups[0]["weight_decay"] == 1e-4 32 | 33 | # schnet has many parameters, but only per_element_scaling is 34 | # not decayable 35 | model = SchNet() 36 | assert set(model.non_decayable_parameters()) == set( 37 | [model.scaler.per_element_scaling] 38 | ) 39 | opt = opt_factory(model) 40 | assert len(opt.param_groups) == 2 41 | pg_by_name = {pg["name"]: pg for pg in opt.param_groups} 42 | assert pg_by_name["non-decayable"]["weight_decay"] == 0.0 43 | assert pg_by_name["normal"]["weight_decay"] == 1e-4 44 | 45 | # addition models should return the non-decayable parameters of their 46 | # components 47 | model = AdditionModel(energy_offset=LearnableOffset(), schnet=SchNet()) 48 | assert set(model.non_decayable_parameters()) == set( 49 | [ 50 | model["energy_offset"]._offsets, 51 | model["schnet"].scaler.per_element_scaling, 52 | ] 53 | ) 54 | 55 | 56 | def test_opt(): 57 | # test vanilla use 58 | model = SchNet() 59 | opt_factory = Optimizer("Adam", weight_decay=1e-4) 60 | opt = opt_factory(model) 61 | 62 | # check that opt is an Adam optimizer with two parameter groups 63 | assert isinstance(opt, torch.optim.Adam) 64 | 65 | # test error if optimizer class is not found 66 | with pytest.raises(ValueError, match="Could not find optimizer"): 67 | Optimizer("Unknown") 68 | 69 | # test error if optimizer class is not an optimizer 70 | with pytest.raises( 71 | ValueError, 72 | match="Expected the returned optimizer to be an instance of ", 73 | ): 74 | 75 | def fake_opt(*args, **kwargs): 76 | return None 77 | 78 | Optimizer(fake_opt) # type: ignore 79 | 80 | # test custom optimizer class 81 | class CustomOptimizer(torch.optim.Adam): 82 | pass 83 | 84 | opt_factory = Optimizer(CustomOptimizer) 85 | opt = opt_factory(model) 86 | assert isinstance(opt, CustomOptimizer) 87 | 88 | 89 | def test_lr_scheduler(): 90 | params = [torch.nn.Parameter(torch.zeros(1))] 91 | sched_factory = LRScheduler("StepLR", step_size=10, gamma=0.1) 92 | sched = sched_factory(torch.optim.Adam(params)) 93 | assert isinstance(sched, torch.optim.lr_scheduler.StepLR) 94 | 95 | # can't find scheduler 96 | with pytest.raises(ValueError, match="Could not find scheduler"): 97 | LRScheduler("Unknown") 98 | 99 | # custom scheduler 100 | class CustomScheduler(torch.optim.lr_scheduler.StepLR): 101 | pass 102 | 103 | sched_factory = LRScheduler(CustomScheduler, step_size=10, gamma=0.1) 104 | sched = sched_factory(torch.optim.Adam(params)) 105 | assert isinstance(sched, CustomScheduler) 106 | -------------------------------------------------------------------------------- /tests/training/test_pre_fit.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import pytest 7 | import torch 8 | from ase.build import molecule 9 | 10 | from graph_pes import AtomicGraph, GraphPESModel 11 | from graph_pes.atomic_graph import ( 12 | number_of_structures, 13 | to_batch, 14 | ) 15 | from graph_pes.models import LennardJonesMixture 16 | from graph_pes.models.addition import AdditionModel 17 | from graph_pes.utils.shift_and_scale import guess_per_element_mean_and_var 18 | 19 | 20 | def _create_batch( 21 | mu: dict[int, float], 22 | sigma: dict[int, float], 23 | weights: list[float] | None = None, 24 | ) -> AtomicGraph: 25 | """ 26 | Create a batch of structures with 
local energies distributed 27 | according to the given parameters. 28 | 29 | Parameters 30 | ---------- 31 | mu 32 | The per-element mean local energies. 33 | sigma 34 | The per-element standard deviations in local energy. 35 | weights 36 | The relative likelihood of sampling each element. 37 | """ 38 | 39 | N = 1_000 40 | graphs: list[AtomicGraph] = [] 41 | rng = np.random.default_rng(0) 42 | for _ in range(N): 43 | structure_size = rng.integers(4, 10) 44 | Zs = rng.choice(list(mu.keys()), size=structure_size, p=weights) 45 | total_E = 0 46 | for Z in Zs: 47 | total_E += rng.normal(mu[Z], sigma[Z]) 48 | graphs.append( 49 | AtomicGraph.create_with_defaults( 50 | Z=torch.LongTensor(Zs), 51 | R=torch.randn(structure_size, 3), 52 | properties={"energy": torch.tensor(total_E)}, 53 | ) 54 | ) 55 | return to_batch(graphs) 56 | 57 | 58 | def test_guess_per_element_mean_and_var(): 59 | mu = {1: -1.0, 2: -2.0} 60 | sigma = {1: 0.1, 2: 0.2} 61 | batch = _create_batch(mu=mu, sigma=sigma) 62 | 63 | # quickly check that this batch is as expected 64 | assert number_of_structures(batch) == 1_000 65 | assert sorted(torch.unique(batch.Z).tolist()) == [1, 2] 66 | 67 | # calculate the per-element mean and variance 68 | per_structure_quantity = batch.properties["energy"] 69 | means, variances = guess_per_element_mean_and_var( 70 | per_structure_quantity, batch 71 | ) 72 | 73 | # are means roughly right? 74 | for Z, actual_mu in mu.items(): 75 | assert np.isclose(means[Z], actual_mu, atol=0.01) 76 | 77 | # are variances roughly right? 78 | for Z, actual_sigma in sigma.items(): 79 | assert np.isclose(variances[Z], actual_sigma**2, atol=0.01) 80 | 81 | 82 | def test_clamping(): 83 | # variances can not be negative: ensure that they are clamped 84 | mu = {1: -1.0, 2: -2.0} 85 | sigma = {1: 0.0, 2: 1.0} 86 | batch = _create_batch(mu=mu, sigma=sigma) 87 | 88 | # calculate the per-element mean and variance 89 | means, variances = guess_per_element_mean_and_var( 90 | batch.properties["energy"], batch, min_variance=0.01 91 | ) 92 | 93 | # ensure no variance is less than the value we choose to clamp to 94 | for value in variances.values(): 95 | assert value >= 0.01 96 | 97 | 98 | models = [ 99 | LennardJonesMixture(), 100 | AdditionModel(a=LennardJonesMixture(), b=LennardJonesMixture()), 101 | ] 102 | names = ["LennardJonesMixture", "AdditionModel"] 103 | 104 | 105 | @pytest.mark.parametrize("model", models, ids=names) 106 | def test( 107 | tmp_path: Path, 108 | model: GraphPESModel, 109 | caplog: pytest.LogCaptureFixture, 110 | ): 111 | assert model.elements_seen == [] 112 | 113 | # show the model C and H 114 | methane = molecule("CH4") 115 | methane.info["energy"] = 1.0 116 | model.pre_fit_all_components([AtomicGraph.from_ase(methane, cutoff=3.0)]) 117 | assert model.elements_seen == ["H", "C"] 118 | 119 | # check that these are persisted over save and load 120 | torch.save(model, tmp_path / "model.pt") 121 | loaded = torch.load(tmp_path / "model.pt", weights_only=False) 122 | assert loaded.elements_seen == ["H", "C"] 123 | 124 | # show the model C, H, and O 125 | acetaldehyde = molecule("CH3CHO") 126 | acetaldehyde.info["energy"] = 2.0 127 | model.pre_fit_all_components( 128 | [AtomicGraph.from_ase(acetaldehyde, cutoff=3.0)] 129 | ) 130 | assert any( 131 | record.levelname == "WARNING" 132 | and "has already been pre-fitted" in record.message 133 | for record in caplog.records 134 | ) 135 | assert model.elements_seen == ["H", "C", "O"] 136 | 137 | 138 | def test_large_pre_fit(caplog: pytest.LogCaptureFixture): 139 | 
model = LennardJonesMixture() 140 | methane = molecule("CH4") 141 | graph = AtomicGraph.from_ase(methane, cutoff=0.5) 142 | graphs = [graph] * 10_001 143 | model.pre_fit_all_components(graphs) 144 | assert any( 145 | record.levelname == "WARNING" 146 | and "Pre-fitting on a large dataset" in record.message 147 | for record in caplog.records 148 | ) 149 | -------------------------------------------------------------------------------- /tests/training/test_train_script.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import sys 4 | from pathlib import Path 5 | 6 | import yaml 7 | 8 | from graph_pes.config.training import SWAConfig, TrainingConfig 9 | from graph_pes.scripts.train import train_from_config 10 | from graph_pes.scripts.utils import extract_config_dict_from_command_line 11 | from graph_pes.training.callbacks import WandbLogger 12 | from graph_pes.utils.misc import nested_merge 13 | 14 | from .. import helpers 15 | 16 | 17 | def test_arg_parse(): 18 | config_path = helpers.CONFIGS_DIR / "minimal.yaml" 19 | command = f"""\ 20 | graph-pes-train {config_path} \ 21 | fitting/loader_kwargs/batch_size=32 \ 22 | data/+load_atoms_dataset/n_train=10 23 | """ 24 | sys.argv = command.split() 25 | 26 | config_data = nested_merge( 27 | TrainingConfig.defaults(), 28 | extract_config_dict_from_command_line(""), 29 | ) 30 | assert config_data["fitting"]["loader_kwargs"]["batch_size"] == 32 31 | assert config_data["data"]["+load_atoms_dataset"]["n_train"] == 10 32 | 33 | 34 | def test_train_script(tmp_path: Path): 35 | root = tmp_path / "root" 36 | config = _get_quick_train_config(root) 37 | 38 | train_from_config(config) 39 | 40 | assert root.exists() 41 | sub_dir = next(root.iterdir()) 42 | assert (sub_dir / "model.pt").exists() 43 | 44 | 45 | def test_run_id(tmp_path: Path): 46 | root = tmp_path / "root" 47 | 48 | # first round: train with no explicit run_id ... 
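a fresh run directory should then be created under root with an automatically generated name; the explicit-id rounds below exercise collision handling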
49 | config = _get_quick_train_config(root) 50 | assert config["general"]["run_id"] is None 51 | 52 | train_from_config(config) 53 | 54 | # second round: train with an explicit run_id 55 | config = _get_quick_train_config(root) 56 | config["general"]["run_id"] = "explicit-id" 57 | train_from_config(config) 58 | assert (root / "explicit-id").exists() 59 | 60 | # third round: train with the same explicit run_id 61 | # and check that the collision is avoided 62 | config = _get_quick_train_config(root) 63 | config["general"]["run_id"] = "explicit-id" 64 | train_from_config(config) 65 | assert (root / "explicit-id-1").exists() 66 | 67 | 68 | def test_swa(tmp_path: Path, caplog): 69 | root = tmp_path / "root" 70 | config = _get_quick_train_config(root) 71 | config["fitting"]["swa"] = SWAConfig(lr=0.1, start=1) 72 | config["fitting"]["trainer_kwargs"]["max_epochs"] = 3 73 | 74 | train_from_config(config) 75 | 76 | assert "SWA: starting SWA" in caplog.text 77 | 78 | 79 | def _get_quick_train_config(root) -> dict: 80 | config_str = f"""\ 81 | general: 82 | root_dir: {root} 83 | run_id: null 84 | wandb: null 85 | loss: +PerAtomEnergyLoss() 86 | model: 87 | +LennardJones: {{cutoff: 3.0}} 88 | data: 89 | +load_atoms_dataset: 90 | id: {helpers.CU_STRUCTURES_FILE} 91 | cutoff: 3.0 92 | n_train: 6 93 | n_valid: 2 94 | n_test: 2 95 | fitting: 96 | trainer_kwargs: 97 | max_epochs: 1 98 | accelerator: cpu 99 | callbacks: [] 100 | loader_kwargs: 101 | batch_size: 2 102 | num_workers: 0 103 | persistent_workers: false 104 | """ 105 | return nested_merge( 106 | TrainingConfig.defaults(), 107 | yaml.safe_load(config_str), 108 | ) 109 | 110 | 111 | def test_wandb_logging(tmp_path: Path, caplog): 112 | import wandb 113 | 114 | # configure wandb to not actually log anything 115 | wandb.init = lambda *args, **kwargs: None 116 | 117 | logger = WandbLogger( 118 | tmp_path / "logging-name", 119 | project="test-project", 120 | log_epoch=False, 121 | ) 122 | assert logger._id == "logging-name" 123 | assert logger.save_dir == str(tmp_path) 124 | assert logger._project == "test-project" 125 | -------------------------------------------------------------------------------- /tests/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jla-gardner/graph-pes/c574efb9ee873ddefe9cbd0c62b58777b58d96f2/tests/utils/__init__.py -------------------------------------------------------------------------------- /tests/utils/test_analysis.py: -------------------------------------------------------------------------------- 1 | from ase.build import molecule 2 | 3 | from graph_pes.atomic_graph import AtomicGraph 4 | from graph_pes.models import LennardJones 5 | from graph_pes.utils.analysis import dimer_curve, parity_plot 6 | 7 | 8 | def test_parity_plot(): 9 | structures = [molecule("H2O"), molecule("CO2")] 10 | energies = [-14.0, -14.0] 11 | for s, e in zip(structures, energies): 12 | s.info["energy"] = e 13 | 14 | graphs = [AtomicGraph.from_ase(s, cutoff=3.0) for s in structures] 15 | model = LennardJones(cutoff=3.0) 16 | 17 | parity_plot(model, graphs) 18 | 19 | 20 | def test_dimer_curve(): 21 | dimer_curve(LennardJones(cutoff=3.0), system="SiO") 22 | -------------------------------------------------------------------------------- /tests/utils/test_auto_offset.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | from typing import Mapping 5 | 6 | import numpy 
as np 7 | import pytest 8 | import torch 9 | from ase import Atoms 10 | from ase.data import chemical_symbols 11 | 12 | from graph_pes.atomic_graph import AtomicGraph, to_batch 13 | from graph_pes.models import FixedOffset 14 | from graph_pes.models.addition import AdditionModel 15 | from graph_pes.utils.shift_and_scale import ( 16 | add_auto_offset, 17 | guess_per_element_mean_and_var, 18 | ) 19 | 20 | 21 | def _get_random_graphs( 22 | reference: Mapping[str, float | int], 23 | ) -> list[AtomicGraph]: 24 | rng = np.random.default_rng(42) 25 | structures = [] 26 | for _ in range(100): 27 | symbols = rng.choice(list(reference.keys()), size=10) 28 | energy = sum(reference[Z] for Z in symbols) 29 | atoms = Atoms(symbols=symbols) 30 | atoms.info["energy"] = energy 31 | structures.append(atoms) 32 | return [AtomicGraph.from_ase(s, cutoff=0.1) for s in structures] 33 | 34 | 35 | def test_add_auto_offset(): 36 | # step 1: make a collection of structures with energies as the 37 | # sum of known per-element energies 38 | reference = dict(C=-1, H=-2, O=-3) 39 | graphs = _get_random_graphs(reference) 40 | 41 | # step 2: ensure the guessed offsets are close 42 | means, _ = guess_per_element_mean_and_var( 43 | torch.tensor([g.properties["energy"] for g in graphs]), 44 | to_batch(graphs), 45 | ) 46 | means = {chemical_symbols[Z]: float(mu) for Z, mu in means.items()} 47 | for k in reference: 48 | assert means[k] == pytest.approx(reference[k], abs=1e-6) 49 | 50 | # step 3: take an existing model with different offsets, and check 51 | # that the guessed difference is close to the actual difference 52 | model_reference = dict(C=2, H=3, O=4) 53 | starting_model = FixedOffset(**model_reference) 54 | 55 | final_model = add_auto_offset(starting_model, graphs) 56 | assert isinstance(final_model, AdditionModel) 57 | assert set(final_model.models.keys()) == {"auto_offset", "base"} 58 | final_model.eval() 59 | 60 | for g in graphs: 61 | assert final_model.predict_energy(g).item() == pytest.approx( 62 | g.properties["energy"].item(), abs=1e-4 63 | ) 64 | 65 | starting_model = AdditionModel(offset=FixedOffset(**model_reference)) 66 | final_model = add_auto_offset(starting_model, graphs) 67 | assert isinstance(final_model, AdditionModel) 68 | assert set(final_model.models.keys()) == {"auto_offset", "offset"} 69 | 70 | 71 | def test_add_auto_offset_no_op(): 72 | reference = dict(C=-1, H=-2, O=-3) 73 | graphs = _get_random_graphs(reference) 74 | 75 | starting_model = FixedOffset(**reference) 76 | final_model = add_auto_offset(starting_model, graphs) 77 | # this should be a no-op 78 | assert final_model is starting_model 79 | 80 | 81 | def test_warning(caplog): 82 | reference = dict(C=-1, H=-2, O=-3) 83 | 84 | # get structures with the same composition 85 | rng = np.random.default_rng(42) 86 | graphs = [] 87 | for _ in range(100): 88 | N = rng.integers(1, 10) 89 | symbols = ["H", "H", "C", "O"] * N 90 | atoms = Atoms(symbols=symbols) 91 | atoms.info["energy"] = sum(reference[Z] for Z in symbols) 92 | graphs.append(AtomicGraph.from_ase(atoms, cutoff=0.1)) 93 | 94 | model_reference = dict(C=2, H=3, O=4) 95 | starting_model = FixedOffset(**model_reference) 96 | 97 | # check that the warning is issued 98 | with caplog.at_level(logging.WARNING): 99 | final_model = add_auto_offset(starting_model, graphs) 100 | assert "no unique solution is possible" in caplog.text 101 | 102 | # check that things still work 103 | final_model.eval() 104 | for g in graphs: 105 | assert final_model.predict_energy(g).item() == pytest.approx( 
g.properties["energy"].item(), abs=1e-4 107 | ) 108 | -------------------------------------------------------------------------------- /tests/utils/test_calculator.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import pytest 3 | from ase.build import bulk, molecule 4 | 5 | from graph_pes.atomic_graph import AtomicGraph, PropertyKey 6 | from graph_pes.models import LennardJones 7 | from graph_pes.utils.calculator import merge_predictions 8 | 9 | 10 | def test_calc(): 11 | model = LennardJones() 12 | calc = model.ase_calculator(skin=0.0) 13 | ethanol = molecule("CH3CH2OH") 14 | ethanol.calc = calc 15 | 16 | # ensure right shapes 17 | assert isinstance(ethanol.get_potential_energy(), float) 18 | assert ethanol.get_forces().shape == (9, 3) 19 | 20 | # ensure correctness 21 | g = AtomicGraph.from_ase(ethanol, model.cutoff.item()) 22 | assert ethanol.get_potential_energy() == pytest.approx( 23 | model.predict_energy(g).item() 24 | ) 25 | numpy.testing.assert_allclose( 26 | ethanol.get_forces(), model.predict_forces(g).numpy() 27 | ) 28 | 29 | copper = bulk("Cu") 30 | copper.calc = calc 31 | assert isinstance(copper.get_potential_energy(), float) 32 | assert copper.get_forces().shape == (1, 3) 33 | assert copper.get_stress().shape == (6,) 34 | 35 | g = AtomicGraph.from_ase(copper, model.cutoff.item()) 36 | assert copper.get_potential_energy() == pytest.approx( 37 | model.predict_energy(g).item() 38 | ) 39 | numpy.testing.assert_allclose( 40 | copper.get_forces(), model.predict_forces(g).numpy() 41 | ) 42 | 43 | 44 | def test_calc_all(): 45 | calc = LennardJones().ase_calculator() 46 | molecules = [molecule(s) for s in ["CH4", "H2O", "CH3CH2OH", "C2H6"]] 47 | 48 | # add cell info so we can test stresses 49 | for m in molecules: 50 | m.center(vacuum=10) 51 | 52 | properties: list[PropertyKey] = ["energy", "forces", "stress"] 53 | one_by_one = [] 54 | for m in molecules: 55 | calc.calculate(m, properties) 56 | one_by_one.append(calc.results) 57 | 58 | batched = calc.calculate_all(molecules, properties) 59 | 60 | assert len(one_by_one) == len(batched) == len(molecules) 61 | 62 | for single, parallel in zip(one_by_one, batched): 63 | for key in properties: 64 | numpy.testing.assert_allclose( 65 | single[key], parallel[key], rtol=10, atol=1e-10 66 | ) 67 | 68 | merged = merge_predictions(batched) 69 | for i in range(len(molecules)): 70 | for key in "energy", "stress": 71 | numpy.testing.assert_allclose(merged[key][i], batched[i][key]) 72 | 73 | n_atoms = sum(map(len, molecules)) 74 | assert merged["forces"].shape == (n_atoms, 3) 75 | -------------------------------------------------------------------------------- /tests/utils/test_deploy.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from pathlib import Path 4 | 5 | import pytest 6 | import torch 7 | from ase.build import molecule 8 | 9 | from graph_pes import AtomicGraph, GraphPESModel 10 | from graph_pes.atomic_graph import number_of_atoms 11 | from graph_pes.models.pairwise import LennardJones, SmoothedPairPotential 12 | from graph_pes.utils.lammps import ( 13 | as_lammps_data, 14 | deploy_model, 15 | ) 16 | from graph_pes.utils.misc import full_3x3_to_voigt_6 17 | 18 | from .. 
import helpers 19 | 20 | 21 | # ignore warnings about lack of energy labels for pre-fitting: not important 22 | @pytest.mark.filterwarnings("ignore:.*No energy data found in training data.*") 23 | @helpers.parameterise_all_models(expected_elements=["C", "H", "O"]) 24 | def test_deploy(model: GraphPESModel, tmp_path: Path): 25 | dummy_graph = AtomicGraph.from_ase(molecule("CH3CH2OH"), cutoff=5.0) 26 | # required by some models before making predictions 27 | model.pre_fit_all_components([dummy_graph]) 28 | 29 | model_cutoff = float(model.cutoff) 30 | graph = AtomicGraph.from_ase( 31 | molecule("CH3CH2OH", vacuum=2), 32 | cutoff=model_cutoff, 33 | ) 34 | outputs = { 35 | k: t.double() for k, t in model.get_all_PES_predictions(graph).items() 36 | } 37 | 38 | # 1. saving and loading works 39 | torch.save(model, tmp_path / "model.pt") 40 | loaded_model = torch.load(tmp_path / "model.pt", weights_only=False) 41 | assert isinstance(loaded_model, GraphPESModel) 42 | torch.testing.assert_close( 43 | model.predict_forces(graph), 44 | loaded_model.predict_forces(graph), 45 | atol=1e-6, 46 | rtol=1e-6, 47 | ) 48 | 49 | # 2. deploy the model 50 | save_path = tmp_path / "lammps-model.pt" 51 | deploy_model(model, path=save_path) 52 | 53 | # 3. load the model back in 54 | lammps_model = torch.jit.load(save_path) 55 | assert isinstance(lammps_model, torch.jit.ScriptModule) 56 | assert lammps_model.get_cutoff() == model_cutoff 57 | 58 | # 4. test outputs 59 | lammps_data = as_lammps_data(graph, compute_virial=True) 60 | lammps_outputs = lammps_model(lammps_data) 61 | assert isinstance(lammps_outputs, dict) 62 | assert set(lammps_outputs.keys()) == { 63 | "energy", 64 | "local_energies", 65 | "forces", 66 | "virial", 67 | } 68 | assert lammps_outputs["energy"].shape == torch.Size([]) 69 | torch.testing.assert_close( 70 | outputs["energy"], 71 | lammps_outputs["energy"], 72 | atol=1e-6, 73 | rtol=1e-6, 74 | ) 75 | 76 | assert lammps_outputs["local_energies"].shape == (number_of_atoms(graph),) 77 | torch.testing.assert_close( 78 | outputs["local_energies"], 79 | lammps_outputs["local_energies"], 80 | atol=1e-6, 81 | rtol=1e-6, 82 | ) 83 | 84 | assert lammps_outputs["forces"].shape == graph.R.shape 85 | torch.testing.assert_close( 86 | outputs["forces"], 87 | lammps_outputs["forces"], 88 | atol=1e-6, 89 | rtol=1e-6, 90 | ) 91 | 92 | assert lammps_outputs["virial"].shape == (6,) 93 | torch.testing.assert_close( 94 | full_3x3_to_voigt_6(outputs["virial"]), 95 | lammps_outputs["virial"], 96 | atol=1e-6, 97 | rtol=1e-6, 98 | ) 99 | 100 | 101 | def test_deploy_smoothed_pair_potential(tmp_path: Path): 102 | model = SmoothedPairPotential(LennardJones(cutoff=2.5)) 103 | test_deploy(model, tmp_path) 104 | -------------------------------------------------------------------------------- /tests/utils/test_dtypes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ase.build import molecule 3 | 4 | from graph_pes import AtomicGraph 5 | from graph_pes.models import LennardJones 6 | 7 | 8 | def test_dtypes(): 9 | torch.set_default_dtype(torch.float64) 10 | g = AtomicGraph.from_ase(molecule("C6H6")) 11 | assert g.R.dtype == torch.float64 12 | 13 | model = LennardJones() 14 | assert model._log_epsilon.dtype == torch.float64 15 | 16 | model(g) 17 | 18 | torch.set_default_dtype(torch.float32) 19 | g = AtomicGraph.from_ase(molecule("C6H6")) 20 | assert g.R.dtype == torch.float32 21 | 22 | model = LennardJones() 23 | assert model._log_epsilon.dtype == torch.float32 24 | 25 | 
model(g) 26 | -------------------------------------------------------------------------------- /tests/utils/test_lammps.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from ase.build import molecule 4 | 5 | from graph_pes import AtomicGraph 6 | from graph_pes.atomic_graph import PropertyKey 7 | from graph_pes.models import LennardJones 8 | from graph_pes.utils.lammps import LAMMPSModel, as_lammps_data 9 | 10 | CUTOFF = 1.5 11 | 12 | 13 | @pytest.mark.parametrize( 14 | "compute_virial", 15 | [True, False], 16 | ) 17 | def test_lammps_model(compute_virial: bool): 18 | # generate a structure 19 | structure = molecule("CH3CH2OH") 20 | if compute_virial: 21 | # ensure the structure has a cell 22 | structure.center(vacuum=5.0) 23 | graph = AtomicGraph.from_ase(structure, cutoff=CUTOFF) 24 | 25 | # create a normal model, and get normal predictions 26 | model = LennardJones(cutoff=CUTOFF) 27 | props: list[PropertyKey] = ["energy", "forces"] 28 | if compute_virial: 29 | props.append("virial") 30 | outputs = model.predict(graph, properties=props) 31 | 32 | # create a LAMMPS model, and get LAMMPS predictions 33 | lammps_model = LAMMPSModel(model) 34 | 35 | assert lammps_model.get_cutoff() == torch.tensor(CUTOFF) 36 | 37 | lammps_data = as_lammps_data(graph, compute_virial=compute_virial) 38 | lammps_outputs = lammps_model(lammps_data) 39 | 40 | # check outputs 41 | if compute_virial: 42 | assert "virial" in lammps_outputs 43 | assert lammps_outputs["virial"].shape == (6,) 44 | 45 | assert torch.allclose( 46 | outputs["energy"].float(), 47 | lammps_outputs["energy"].float(), 48 | ) 49 | 50 | 51 | def test_debug_logging(capsys): 52 | # generate a structure 53 | structure = molecule("CH3CH2OH") 54 | structure.center(vacuum=5.0) 55 | graph = AtomicGraph.from_ase(structure, cutoff=CUTOFF) 56 | 57 | # create a LAMMPS model, and get LAMMPS predictions 58 | lammps_model = LAMMPSModel(LennardJones()) 59 | 60 | lammps_data = as_lammps_data(graph, compute_virial=True, debug=True) 61 | lammps_model(lammps_data) 62 | 63 | logs = capsys.readouterr().out 64 | assert "Received graph:" in logs 65 | -------------------------------------------------------------------------------- /tests/utils/test_misc.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import numpy as np 4 | import pytest 5 | import torch 6 | 7 | from graph_pes.utils.misc import ( 8 | as_possible_tensor, 9 | build_single_nested_dict, 10 | differentiate, 11 | full_3x3_to_voigt_6, 12 | nested_merge, 13 | nested_merge_all, 14 | voigt_6_to_full_3x3, 15 | ) 16 | 17 | possible_tensors = [ 18 | (1, True), 19 | (1.0, True), 20 | ([1, 2, 3], True), 21 | (torch.tensor([1, 2, 3]), True), 22 | (np.array([1, 2, 3]), True), 23 | (np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64), True), 24 | ("hello", False), 25 | ] 26 | 27 | 28 | @pytest.mark.parametrize("obj, can_be_converted", possible_tensors) 29 | def test_as_possible_tensor(obj, can_be_converted): 30 | if can_be_converted: 31 | assert isinstance(as_possible_tensor(obj), torch.Tensor) 32 | else: 33 | assert as_possible_tensor(obj) is None 34 | 35 | 36 | def test_differentiate(): 37 | # test that it works 38 | x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True) 39 | y = x.sum() 40 | dy_dx = differentiate(y, x) 41 | assert torch.allclose(dy_dx, torch.ones_like(x)) 42 | 43 | # test that it works with a non-scalar y 44 | x = torch.tensor([1.0, 2.0, 3.0], 
requires_grad=True) 45 | y = x**2 46 | dy_dx = differentiate(y, x) 47 | assert torch.allclose(dy_dx, 2 * x) 48 | 49 | # test that it works if x is not part of the computation graph 50 | x = torch.tensor([1.0, 2.0, 3.0]) 51 | z = torch.tensor([4.0, 5.0, 6.0], requires_grad=True) 52 | y = z.sum() 53 | dy_dx = differentiate(y, x) 54 | assert torch.allclose(dy_dx, torch.zeros_like(x)) 55 | 56 | # finally, we want to test that the gradient itself has a gradient 57 | x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True) 58 | y = (x**2).sum() 59 | dy_dx = differentiate(y, x, keep_graph=True) 60 | dy_dx2 = differentiate(dy_dx, x) 61 | assert torch.allclose(dy_dx2, 2 * torch.ones_like(x)) 62 | 63 | 64 | def test_nested_merge(): 65 | a = {"a": 1, "b": {"c": 2}, "d": 3} 66 | b = {"a": 3, "b": {"c": 4}} 67 | c = nested_merge(a, b) 68 | assert c == {"a": 3, "b": {"c": 4}, "d": 3}, "nested_merge failed" 69 | assert a == {"a": 1, "b": {"c": 2}, "d": 3}, "nested_merge mutated a" 70 | assert b == {"a": 3, "b": {"c": 4}}, "nested_merge mutated b" 71 | 72 | 73 | def test_build_single_nested_dict(): 74 | assert build_single_nested_dict(["a", "b", "c"], 4) == { 75 | "a": {"b": {"c": 4}} 76 | } 77 | 78 | 79 | def test_nested_merge_all(): 80 | assert nested_merge_all({"a": 1}, {"a": 2, "b": 1}, {"a": 3}) == { 81 | "a": 3, 82 | "b": 1, 83 | } 84 | 85 | assert nested_merge_all( 86 | {"a": {"b": {"c": 1}}}, 87 | {"a": {"b": {"d": 2}}}, 88 | {"a": {"b": {"c": 2}}}, 89 | ) == {"a": {"b": {"c": 2, "d": 2}}} 90 | 91 | 92 | def test_stress_conversions(): 93 | # non-batched 94 | stress = torch.rand(3, 3) 95 | # symmetrize: 96 | stress = (stress + stress.T) / 2 97 | voigt = full_3x3_to_voigt_6(stress) 98 | assert voigt.shape == (6,) 99 | stress_again = voigt_6_to_full_3x3(voigt) 100 | torch.testing.assert_close(stress_again, stress) 101 | 102 | # batched 103 | stress = torch.rand(2, 3, 3) 104 | # symmetrize: 105 | stress = (stress + stress.transpose(1, 2)) / 2 106 | voigt = full_3x3_to_voigt_6(stress) 107 | assert voigt.shape == (2, 6) 108 | stress_again = voigt_6_to_full_3x3(voigt) 109 | torch.testing.assert_close(stress_again, stress) 110 | -------------------------------------------------------------------------------- /tests/utils/test_multi_sequence.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from graph_pes.utils.misc import MultiSequence 4 | 5 | 6 | def test_multi_sequence(): 7 | a = [1, 2, 3, 4, 5] 8 | b = [6, 7, 8, 9, 10] 9 | 10 | ms = MultiSequence([a, b]) 11 | assert len(ms) == 10 12 | assert ms[0] == 1 13 | assert ms[5] == 6 14 | 15 | # test slicing 16 | sliced = ms[2:5] 17 | assert isinstance(sliced, MultiSequence) 18 | assert len(sliced) == 3 19 | assert list(sliced) == [3, 4, 5] 20 | # nested slicing 21 | assert isinstance(sliced[:2], MultiSequence) 22 | assert list(sliced[:2]) == [3, 4] 23 | 24 | # test slicing across the boundary of two sequences 25 | sliced = ms[3:7] 26 | assert len(sliced) == 4 27 | assert list(sliced) == [4, 5, 6, 7] 28 | 29 | # test slicing with a step 30 | sliced = ms[::2] 31 | assert len(sliced) == 5 32 | assert list(sliced) == [1, 3, 5, 7, 9] 33 | 34 | # test negative slicing 35 | sliced = ms[-3:] 36 | assert len(sliced) == 3 37 | assert list(sliced) == [8, 9, 10] 38 | 39 | sliced = ms[::-1] 40 | assert len(sliced) == 10 41 | assert list(sliced) == [10, 9, 8, 7, 6, 5, 4, 3, 2, 1] 42 | 43 | with pytest.raises(IndexError): 44 | ms[100] 45 | 
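46 | 47 | # usage sketch (an illustrative addition, not part of the original file): 48 | # the tests above pin down MultiSequence's behaviour: it chains several 49 | # sequences behind a single index space, and slicing returns another 50 | # MultiSequence. a minimal, self-contained example: 51 | def test_multi_sequence_usage_sketch(): 52 |     words = MultiSequence([["a", "b"], ["c", "d", "e"]]) 53 |     assert len(words) == 5 54 |     assert words[2] == "c"  # indexing flows across the boundary between lists 55 |     assert list(words[1:4]) == ["b", "c", "d"]  # slices are MultiSequence instances 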
-------------------------------------------------------------------------------- /tests/utils/test_nn.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import pytest 4 | import torch 5 | 6 | from graph_pes.utils.misc import MAX_Z, left_aligned_div, left_aligned_mul 7 | from graph_pes.utils.nn import ( 8 | MLP, 9 | AtomicOneHot, 10 | PerElementEmbedding, 11 | PerElementParameter, 12 | UniformModuleDict, 13 | UniformModuleList, 14 | parse_activation, 15 | ) 16 | 17 | 18 | def test_per_element_parameter(tmp_path): 19 | pep = PerElementParameter.of_length(5) 20 | assert pep._index_dims == 1 21 | assert pep.data.shape == (MAX_Z + 1, 5) 22 | assert isinstance(pep, PerElementParameter) 23 | assert isinstance(pep, torch.Tensor) 24 | assert isinstance(pep, torch.nn.Parameter) 25 | 26 | # no elements have been registered, so there should (appear to) be no 27 | # trainable parameters 28 | assert pep.numel() == 0 29 | 30 | # register the parameter for use with hydrogen 31 | pep.register_elements([1]) 32 | assert pep.numel() == 5 33 | 34 | # test save and loading 35 | torch.save(pep, tmp_path / "pep.pt") 36 | pep_loaded = torch.load(tmp_path / "pep.pt", weights_only=False) 37 | assert pep_loaded.numel() == 5 38 | assert pep.data.allclose(pep_loaded.data) 39 | assert pep.requires_grad == pep_loaded.requires_grad 40 | assert pep._accessed_Zs == pep_loaded._accessed_Zs 41 | assert pep._index_dims == pep_loaded._index_dims 42 | 43 | # test default value init 44 | assert PerElementParameter.of_length(1, default_value=1.0).data.allclose( 45 | torch.ones(MAX_Z + 1) 46 | ) 47 | 48 | # test shape api 49 | pep = PerElementParameter.of_shape((5, 5), index_dims=2) 50 | assert pep.data.shape == (MAX_Z + 1, MAX_Z + 1, 5, 5) 51 | 52 | # test errors 53 | with pytest.raises(ValueError, match="Unknown element: ZZZ"): 54 | PerElementParameter.from_dict(ZZZ=1) 55 | 56 | 57 | def test_per_element_embedding(): 58 | embedding = PerElementEmbedding(10) 59 | embedding._embeddings.register_elements([1, 2, 3, 4, 5]) 60 | Z = torch.tensor([1, 2, 3, 4, 5]) 61 | assert embedding(Z).shape == (5, 10) 62 | assert embedding.parameters().__next__().numel() == 50 63 | 64 | 65 | def test_mlp(): 66 | mlp = MLP([10, 20, 1]) 67 | 68 | # test behaviour 69 | assert mlp(torch.zeros(10)).shape == (1,) 70 | 71 | # test properties 72 | assert mlp.input_size == 10 73 | assert mlp.output_size == 1 74 | 75 | # test internals 76 | assert len(mlp.linear_layers) == 2 77 | 78 | # test nice repr 79 | assert "MLP(10 → 20 → 1" in str(mlp) 80 | 81 | 82 | def test_activations(): 83 | act = parse_activation("ReLU") 84 | assert act(torch.tensor([-1.0])).item() == 0.0 85 | 86 | with pytest.raises( 87 | ValueError, match="Activation function ZZZ not found in `torch.nn`." 
88 | ): 89 | parse_activation("ZZZ") 90 | 91 | 92 | def test_one_hot(): 93 | one_hot = AtomicOneHot(["H", "C", "O"]) 94 | 95 | Z = torch.tensor([1, 6, 8]) 96 | Z_emb = one_hot(Z) 97 | 98 | assert Z_emb.shape == (3, 3) 99 | assert Z_emb.allclose(torch.eye(3)) 100 | 101 | with pytest.raises(ValueError, match="Unknown element"): 102 | one_hot(torch.tensor([2])) 103 | 104 | 105 | def test_module_dict(): 106 | umd = UniformModuleDict( 107 | a=torch.nn.Linear(10, 10), 108 | b=torch.nn.Linear(10, 10), 109 | ) 110 | 111 | assert len(umd) == 2 112 | 113 | for k, v in umd.items(): 114 | assert isinstance(k, str) 115 | assert isinstance(v, torch.nn.Linear) 116 | 117 | assert "a" in umd 118 | 119 | b = umd.pop("b") 120 | assert b is not None 121 | assert len(umd) == 1 122 | 123 | 124 | def test_module_list(): 125 | uml = UniformModuleList( 126 | [ 127 | torch.nn.Linear(10, 10), 128 | torch.nn.Linear(10, 10), 129 | ] 130 | ) 131 | 132 | assert len(uml) == 2 133 | assert isinstance(uml[0], torch.nn.Linear) 134 | 135 | uml[1] = torch.nn.Linear(100, 100) 136 | assert isinstance(uml[1], torch.nn.Linear) 137 | assert uml[1].in_features == 100 138 | 139 | uml.append(torch.nn.Linear(1000, 1000)) 140 | assert len(uml) == 3 141 | 142 | lin = uml.pop(-1) 143 | assert len(uml) == 2 144 | assert lin.in_features == 1000 145 | 146 | uml.insert(1, torch.nn.Linear(10000, 10000)) 147 | assert len(uml) == 3 148 | assert uml[1].in_features == 10000 149 | 150 | 151 | @pytest.mark.parametrize("x_dim", [0, 1, 2, 3]) 152 | def test_left_aligned_ops(x_dim: int): 153 | N = 10 154 | 155 | y = torch.ones((N,)) * 5 156 | 157 | if x_dim == 0: 158 | x = torch.randn(N) 159 | else: 160 | x = torch.randn(N, *[i + 1 for i in range(x_dim)]) 161 | 162 | z = left_aligned_mul(x, y) 163 | assert z.shape == x.shape 164 | assert z.allclose(x * 5) 165 | 166 | z = left_aligned_div(x, y) 167 | assert z.shape == x.shape 168 | assert z.allclose(x / 5) 169 | -------------------------------------------------------------------------------- /tests/utils/test_sampling.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from graph_pes.utils.sampling import SequenceSampler 4 | 5 | 6 | @pytest.fixture 7 | def sample_list(): 8 | return list(range(10)) # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 9 | 10 | 11 | def test_sequence_sampler_init(sample_list): 12 | # Test initialization with default indices 13 | sampler = SequenceSampler(sample_list) 14 | assert len(sampler) == len(sample_list) 15 | 16 | # Test initialization with specific indices 17 | indices = [0, 2, 4] 18 | sampler = SequenceSampler(sample_list, indices) 19 | assert len(sampler) == len(indices) 20 | 21 | 22 | def test_sequence_sampler_getitem(sample_list): 23 | sampler = SequenceSampler(sample_list) 24 | 25 | # Test integer indexing 26 | assert sampler[0] == 0 27 | assert sampler[-1] == 9 28 | 29 | # Test slice indexing 30 | sliced = sampler[2:5] 31 | assert isinstance(sliced, SequenceSampler) 32 | assert list(sliced) == [2, 3, 4] 33 | 34 | 35 | def test_sequence_sampler_len(sample_list): 36 | sampler = SequenceSampler(sample_list) 37 | assert len(sampler) == len(sample_list) 38 | 39 | sampler = SequenceSampler(sample_list, [0, 1, 2]) 40 | assert len(sampler) == 3 41 | 42 | 43 | def test_sequence_sampler_iter(sample_list): 44 | sampler = SequenceSampler(sample_list) 45 | assert list(sampler) == sample_list 46 | 47 | 48 | def test_sequence_sampler_shuffled(sample_list): 49 | sampler = SequenceSampler(sample_list) 50 | shuffled = 
sampler.shuffled(seed=42) 51 | 52 | # Test that shuffled returns a SequenceSampler 53 | assert isinstance(shuffled, SequenceSampler) 54 | 55 | # Test that shuffled has same length 56 | assert len(shuffled) == len(sampler) 57 | 58 | # Test that shuffled contains all original elements 59 | assert sorted(list(shuffled)) == sorted(sample_list) 60 | 61 | # Test that shuffling is deterministic with same seed 62 | shuffled2 = sampler.shuffled(seed=42) 63 | assert list(shuffled) == list(shuffled2) 64 | 65 | # Test that different seeds give different orders 66 | shuffled3 = sampler.shuffled(seed=43) 67 | assert list(shuffled) != list(shuffled3) 68 | 69 | 70 | def test_sequence_sampler_sample_at_most(sample_list): 71 | sampler = SequenceSampler(sample_list) 72 | 73 | # Test sampling less than total length 74 | sample = sampler.sample_at_most(5, seed=42) 75 | assert len(sample) == 5 76 | assert isinstance(sample, SequenceSampler) 77 | 78 | # Test sampling more than total length 79 | sample = sampler.sample_at_most(15, seed=42) 80 | assert len(sample) == len(sample_list) 81 | 82 | 83 | def test_sequence_sampler_with_custom_sequence(): 84 | class CustomSequence: 85 | def __init__(self, data): 86 | self.data = data 87 | 88 | def __len__(self): 89 | return len(self.data) 90 | 91 | def __getitem__(self, idx): 92 | return self.data[idx] 93 | 94 | custom_seq = CustomSequence(["a", "b", "c"]) 95 | sampler = SequenceSampler(custom_seq) # type: ignore 96 | 97 | assert len(sampler) == 3 98 | assert list(sampler) == ["a", "b", "c"] 99 | 100 | 101 | def test_sequence_sampler_error_handling(): 102 | with pytest.raises(TypeError): 103 | # Test with non-sequence 104 | SequenceSampler(42) # type: ignore 105 | 106 | sampler = SequenceSampler([1, 2, 3]) 107 | with pytest.raises(IndexError): 108 | # Test index out of range 109 | sampler[10] 110 | --------------------------------------------------------------------------------