├── .codecov.yml ├── .github ├── PULL_REQUEST_TEMPLATE │ ├── pull_request_new-architecture.md │ └── pull_request_template.md ├── dependabot.yml ├── pull_request_template.md └── workflows │ ├── architecture-tests.yml │ ├── build.yml │ ├── docs.yml │ ├── lint.yml │ ├── pr-docs-preview.yml │ ├── release.yml │ └── tests.yml ├── .gitignore ├── .readthedocs.yml ├── CODEOWNERS ├── CONTRIBUTING.rst ├── LICENSE ├── MANIFEST.in ├── README.rst ├── developer └── clean-tempfiles.sh ├── docs ├── generate_examples │ ├── conf.py │ └── generate-examples.py ├── requirements.txt ├── src │ ├── _static │ │ └── .gitkeep │ ├── _templates │ │ └── .gitkeep │ ├── advanced-concepts │ │ ├── auto-restarting.rst │ │ ├── auxiliary-outputs.rst │ │ ├── fine-tuning.rst │ │ ├── fitting-generic-targets.rst │ │ ├── index.rst │ │ ├── multi-gpu.rst │ │ ├── output-naming.rst │ │ └── transfer-learning.rst │ ├── architectures │ │ ├── deprecated-pet.rst │ │ ├── gap.rst │ │ ├── index.rst │ │ ├── nanopet.rst │ │ ├── pet.rst │ │ └── soap-bpnn.rst │ ├── conf.py │ ├── dev-docs │ │ ├── architecture-life-cycle.rst │ │ ├── changelog.rst │ │ ├── cli │ │ │ ├── eval.rst │ │ │ ├── export.rst │ │ │ ├── formatter.rst │ │ │ ├── index.rst │ │ │ └── train.rst │ │ ├── dataset-information.rst │ │ ├── getting-started.rst │ │ ├── index.rst │ │ ├── new-architecture.rst │ │ └── utils │ │ │ ├── additive │ │ │ ├── composition.rst │ │ │ ├── index.rst │ │ │ ├── remove_additive.rst │ │ │ └── zbl.rst │ │ │ ├── architectures.rst │ │ │ ├── augmentation.rst │ │ │ ├── data │ │ │ ├── combine_dataloaders.rst │ │ │ ├── dataset.rst │ │ │ ├── get_dataset.rst │ │ │ ├── index.rst │ │ │ ├── readers.rst │ │ │ ├── systems_to_ase.rst │ │ │ └── writers.rst │ │ │ ├── devices.rst │ │ │ ├── dtype.rst │ │ │ ├── errors.rst │ │ │ ├── evaluate_model.rst │ │ │ ├── external_naming.rst │ │ │ ├── index.rst │ │ │ ├── io.rst │ │ │ ├── jsonschema.rst │ │ │ ├── logging.rst │ │ │ ├── long_range.rst │ │ │ ├── loss.rst │ │ │ ├── metrics.rst │ │ │ ├── neighbor_lists.rst │ │ │ ├── omegaconf.rst │ │ │ ├── output_gradient.rst │ │ │ ├── per_atom.rst │ │ │ ├── scaler.rst │ │ │ ├── sum_over_atoms.rst │ │ │ ├── transfer.rst │ │ │ └── units.rst │ ├── getting-started │ │ ├── advanced_base_config.rst │ │ ├── checkpoints.rst │ │ ├── custom_dataset_conf.rst │ │ ├── index.rst │ │ ├── installation.rst │ │ ├── override.rst │ │ └── units.rst │ ├── index.rst │ ├── logo │ │ ├── metatrain-512.png │ │ ├── metatrain-64.png │ │ ├── metatrain-text-dark.svg │ │ ├── metatrain-text.svg │ │ └── metatrain.svg │ └── tutorials │ │ └── index.rst └── static │ ├── images │ ├── metatrain-dark.png │ └── metatrain.png │ ├── qm9 │ ├── eval.yaml │ ├── options.yaml │ └── qm9_reduced_100.xyz │ └── refs.bib ├── examples ├── README.rst ├── ase │ ├── README.rst │ ├── ethanol_reduced_100.xyz │ ├── options.yaml │ ├── run_ase.py │ └── train.sh ├── basic_usage │ ├── README.rst │ ├── eval.yaml │ ├── options.yaml │ ├── qm9_reduced_100.xyz │ └── usage.sh ├── multi-gpu │ └── soap-bpnn │ │ ├── options-distributed.yaml │ │ └── submit-distributed.sh ├── programmatic │ ├── disk_dataset │ │ ├── README.rst │ │ ├── disk_dataset.py │ │ └── qm9_reduced_100.xyz │ ├── llpr │ │ ├── README.rst │ │ ├── ethanol_reduced_100.xyz │ │ ├── llpr.py │ │ ├── options.yaml │ │ ├── qm9_reduced_100.xyz │ │ └── train.sh │ ├── llpr_forces │ │ ├── ethanol_reduced_100.xyz │ │ ├── force_llpr.py │ │ ├── options.yaml │ │ ├── readme.txt │ │ └── split.py │ └── use_architectures_outside │ │ ├── README.rst │ │ ├── qm9_reduced_100.xyz │ │ └── use_outside.py └── zbl │ ├── README.rst 
│ ├── dimers.py │ ├── ethanol_reduced_100.xyz │ ├── options_no_zbl.yaml │ ├── options_zbl.yaml │ └── train.sh ├── pyproject.toml ├── src └── metatrain │ ├── __init__.py │ ├── __main__.py │ ├── cli │ ├── __init__.py │ ├── eval.py │ ├── export.py │ ├── formatter.py │ └── train.py │ ├── deprecated │ ├── __init__.py │ └── pet │ │ ├── __init__.py │ │ ├── default-hypers.yaml │ │ ├── model.py │ │ ├── modules │ │ ├── __init__.py │ │ ├── analysis.py │ │ ├── data_preparation.py │ │ ├── hypers.py │ │ ├── long_range.py │ │ ├── molecule.py │ │ ├── pet.py │ │ ├── transformer.py │ │ └── utilities.py │ │ ├── schema-hypers.json │ │ ├── tests │ │ ├── __init__.py │ │ ├── test_exported.py │ │ ├── test_functionality.py │ │ ├── test_pet_compatibility.py │ │ └── test_torchscript.py │ │ ├── trainer.py │ │ └── utils │ │ ├── __init__.py │ │ ├── dataset_to_ase.py │ │ ├── fine_tuning.py │ │ ├── load_raw_pet_model.py │ │ ├── systems_to_batch_dict.py │ │ ├── update_hypers.py │ │ └── update_state_dict.py │ ├── experimental │ ├── __init__.py │ └── nanopet │ │ ├── __init__.py │ │ ├── default-hypers.yaml │ │ ├── model.py │ │ ├── modules │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── encoder.py │ │ ├── feedforward.py │ │ ├── nef.py │ │ ├── radial_mask.py │ │ ├── structures.py │ │ └── transformer.py │ │ ├── schema-hypers.json │ │ ├── tests │ │ ├── __init__.py │ │ ├── test_continue.py │ │ ├── test_exported.py │ │ ├── test_functionality.py │ │ ├── test_regression.py │ │ └── test_torchscript.py │ │ └── trainer.py │ ├── gap │ ├── __init__.py │ ├── default-hypers.yaml │ ├── model.py │ ├── schema-hypers.json │ ├── tests │ │ ├── __init__.py │ │ ├── ethanol_reduced_100.xyz │ │ ├── options-gap.yaml │ │ ├── test_errors.py │ │ ├── test_exported.py │ │ ├── test_regression.py │ │ └── test_torchscript.py │ └── trainer.py │ ├── pet │ ├── __init__.py │ ├── default-hypers.yaml │ ├── model.py │ ├── modules │ │ ├── compatibility.py │ │ ├── finetuning.py │ │ ├── nef.py │ │ ├── structures.py │ │ ├── transformer.py │ │ └── utilities.py │ ├── schema-hypers.json │ ├── tests │ │ ├── __init__.py │ │ ├── test_autograd.py │ │ ├── test_continue.py │ │ ├── test_exported.py │ │ ├── test_finetuning.py │ │ ├── test_functionality.py │ │ ├── test_long_range.py │ │ ├── test_pet_compatibility.py │ │ ├── test_regression.py │ │ └── test_torchscript.py │ └── trainer.py │ ├── share │ ├── metatrain-completion.bash │ ├── schema-base.json │ └── schema-dataset.json │ ├── soap_bpnn │ ├── __init__.py │ ├── default-hypers.yaml │ ├── model.py │ ├── schema-hypers.json │ ├── spherical.py │ ├── tests │ │ ├── __init__.py │ │ ├── test_continue.py │ │ ├── test_equivariance.py │ │ ├── test_exported.py │ │ ├── test_functionality.py │ │ ├── test_regression.py │ │ └── test_torchscript.py │ └── trainer.py │ └── utils │ ├── __init__.py │ ├── abc.py │ ├── additive │ ├── __init__.py │ ├── composition.py │ ├── remove.py │ └── zbl.py │ ├── architectures.py │ ├── augmentation.py │ ├── data │ ├── __init__.py │ ├── combine_dataloaders.py │ ├── dataset.py │ ├── get_dataset.py │ ├── readers │ │ ├── __init__.py │ │ ├── ase.py │ │ ├── metatensor.py │ │ └── readers.py │ ├── system_to_ase.py │ ├── target_info.py │ └── writers │ │ ├── __init__.py │ │ ├── metatensor.py │ │ └── xyz.py │ ├── devices.py │ ├── distributed │ ├── distributed_data_parallel.py │ ├── logging.py │ └── slurm.py │ ├── dtype.py │ ├── errors.py │ ├── evaluate_model.py │ ├── external_naming.py │ ├── io.py │ ├── jsonschema.py │ ├── llpr.py │ ├── logging.py │ ├── long_range.py │ ├── loss.py │ ├── metadata.py │ ├── metrics.py │ ├── 
neighbor_lists.py │ ├── omegaconf.py │ ├── output_gradient.py │ ├── per_atom.py │ ├── scaler.py │ ├── sum_over_atoms.py │ ├── testing │ ├── __init__.py │ └── equivariance.py │ ├── transfer.py │ └── units.py ├── tests ├── cli │ ├── __init__.py │ ├── dump_spherical_targets.py │ ├── test_cli.py │ ├── test_eval_model.py │ ├── test_export_model.py │ ├── test_formatter.py │ └── test_train_model.py ├── distributed │ ├── ethanol_reduced_100.xyz │ ├── options-distributed.yaml │ ├── options.yaml │ ├── readme.txt │ ├── submit-distributed.sh │ └── submit.sh ├── resources │ ├── carbon_reduced_100.xyz │ ├── ethanol_reduced_100.xyz │ ├── eval.yaml │ ├── generate-outputs.sh │ ├── options-nanopet.yaml │ ├── options.yaml │ ├── qm7x_reduced_100.xyz │ ├── qm9_reduced_100.xyz │ └── test.yaml ├── test_init.py └── utils │ ├── __init__.py │ ├── data │ ├── test_combine_dataloaders.py │ ├── test_dataset.py │ ├── test_get_dataset.py │ ├── test_readers.py │ ├── test_readers_ase.py │ ├── test_readers_metatensor.py │ ├── test_system_to_ase.py │ ├── test_target_info.py │ ├── test_targets_ase.py │ └── test_writers.py │ ├── test_additive.py │ ├── test_architectures.py │ ├── test_device.py │ ├── test_dtype.py │ ├── test_errors.py │ ├── test_evaluate_model.py │ ├── test_external_naming.py │ ├── test_io.py │ ├── test_jsonschema.py │ ├── test_llpr.py │ ├── test_logging.py │ ├── test_long_range.py │ ├── test_loss.py │ ├── test_metadata.py │ ├── test_metrics.py │ ├── test_neighbor_list.py │ ├── test_omegaconf.py │ ├── test_output_gradient.py │ ├── test_per_atom.py │ ├── test_scaler.py │ ├── test_sum_over_atoms.py │ ├── test_transfer.py │ └── test_units.py └── tox.ini /.codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: 4 | default: 5 | target: 90% 6 | patch: 7 | default: 8 | informational: true 9 | ignore: 10 | - "tests/.*" 11 | - "examples/.*" 12 | - "src/metatrain/deprecated/.*" 13 | - "src/metatrain/experimental/.*" 14 | - "src/metatrain/gap/.*" 15 | - "src/metatrain/pet/.*" 16 | - "src/metatrain/soap_bpnn/.*" 17 | - "src/metatrain/utils/distributed/.*" 18 | - "src/metatrain/utils/sum_over_atoms.py" 19 | - "src/metatrain/utils/augmentation.py" 20 | 21 | comment: false 22 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE/pull_request_new-architecture.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | # Contributor (creator of pull-request) checklist 6 | 7 | - [ ] Add your architecture to the experimental/stable folder. See the 8 | [Architecture life cycle](docs/src/dev-docs/architecture-life-cycle.rst) document for 9 | requirements. 10 | `src/metatrain/experimental/<architecture-name>` 11 | - [ ] Add default hyperparameter file to 12 | `src/metatrain/experimental/<architecture-name>/default-hypers.yaml` 13 | - [ ] Add a `.yml` workflow file to the GitHub workflows: `.github/workflows/<architecture-name>.yml` 14 | - [ ] Architecture dependencies entry in the `optional-dependencies` section in the 15 | `pyproject.toml` 16 | - [ ] Tests: torch-scriptability, basic functionality (invariance, fitting, prediction) 17 | - [ ] Add maintainers as codeowners in [CODEOWNERS](CODEOWNERS) 18 | 19 | # Reviewer checklist 20 | 21 | ## New experimental architectures 22 | 23 | - [ ] Capability to fit at least a single quantity and predict it, verified through CI 24 | tests. 25 | - [ ] Compatibility with JIT compilation using `TorchScript 26 | <https://pytorch.org/docs/stable/jit.html>`_. 27 | - [ ] Provision of reasonable default hyperparameters.
28 | - [ ] A contact person designated as the maintainer, mentioned in `__maintainers__` and the `CODEOWNERS` file 29 | - [ ] All external dependencies must be pip-installable. While not required to be on 30 | PyPI, a public git repository or another public URL with a repository is acceptable. 31 | 32 | 33 | ## New stable architectures 34 | - [ ] Provision of regression prediction tests with a small (not exported) checkpoint 35 | file. 36 | - [ ] Comprehensive architecture documentation 37 | - [ ] If an architecture has external dependencies, all must be publicly available on 38 | PyPI. 39 | - [ ] Adherence to the standard output infrastructure of `metatrain`, including 40 | logging and model save locations. 41 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE/pull_request_template.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | # Contributor (creator of pull-request) checklist 6 | 7 | - [ ] Tests updated (for new features and bugfixes)? 8 | - [ ] Documentation updated (for new features)? 9 | - [ ] Issue referenced (for PRs that solve an issue)? 10 | 11 | # Reviewer checklist 12 | 13 | - [ ] CHANGELOG updated with public API or any other important changes? 14 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "github-actions" 4 | directory: "/" 5 | schedule: 6 | interval: monthly 7 | open-pull-requests-limit: 1 8 | groups: 9 | action-dependencies: 10 | patterns: 11 | - "*" # A wildcard to create one PR for all dependencies in the ecosystem 12 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | Please go to the `Preview` tab and select the appropriate PR template: 2 | 3 | - [Default template](?expand=1&template=pull_request_template.md) 4 | - [Adding a new architecture](?expand=1&template=pull_request_new-architecture.md) 5 | -------------------------------------------------------------------------------- /.github/workflows/architecture-tests.yml: -------------------------------------------------------------------------------- 1 | name: Architecture tests 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | # Check all PR 8 | 9 | jobs: 10 | tests: 11 | name: ${{ matrix.architecture-name }} 12 | strategy: 13 | matrix: 14 | include: 15 | - architecture-name: gap 16 | - architecture-name: soap-bpnn 17 | - architecture-name: pet 18 | - architecture-name: nanopet 19 | - architecture-name: deprecated-pet 20 | 21 | runs-on: ubuntu-22.04 22 | 23 | steps: 24 | - uses: actions/checkout@v4 25 | 26 | - name: Set up Python 27 | uses: actions/setup-python@v5 28 | with: 29 | python-version: "3.13" 30 | - run: pip install tox 31 | 32 | - name: run architecture tests 33 | run: tox -e ${{ matrix.architecture-name }}-tests 34 | env: 35 | # Use the CPU only version of torch when building/running the code 36 | PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/cpu 37 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | # This workflow builds and checks the package for release 2 | name: Build 3 | 4 | 
on: 5 | pull_request: 6 | branches: [main] 7 | 8 | jobs: 9 | build: 10 | runs-on: ubuntu-22.04 11 | 12 | steps: 13 | - uses: actions/checkout@v4 14 | - name: Set up Python 15 | uses: actions/setup-python@v5 16 | with: 17 | python-version: "3.13" 18 | - run: pip install tox 19 | 20 | - name: Test build integrity 21 | run: tox -e build 22 | env: 23 | # Use the CPU only version of torch when building/running the code 24 | PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/cpu 25 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: Documentation 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | tags: ["*"] 7 | pull_request: 8 | # Check all PR 9 | 10 | jobs: 11 | build: 12 | permissions: 13 | contents: write 14 | runs-on: ubuntu-22.04 15 | steps: 16 | - uses: actions/checkout@v4 17 | - name: setup Python 18 | uses: actions/setup-python@v5 19 | with: 20 | python-version: "3.13" 21 | - name: install dependencies 22 | run: python -m pip install tox 23 | - name: build documentation 24 | run: tox -e docs 25 | env: 26 | # Use the CPU-only version of torch 27 | PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/cpu 28 | 29 | - name: put documentation in the website 30 | run: | 31 | git clone https://github.com/$GITHUB_REPOSITORY --branch gh-pages gh-pages 32 | rm -rf gh-pages/.git 33 | cd gh-pages 34 | 35 | REF_KIND=$(echo $GITHUB_REF | cut -d / -f2) 36 | if [[ "$REF_KIND" == "tags" ]]; then 37 | TAG=${GITHUB_REF#refs/tags/} 38 | mv ../docs/build/html $TAG 39 | else 40 | rm -rf latest 41 | mv ../docs/build/html latest 42 | fi 43 | 44 | - name: deploy to gh-pages 45 | if: github.event_name == 'push' 46 | uses: peaceiris/actions-gh-pages@v4 47 | with: 48 | github_token: ${{ secrets.GITHUB_TOKEN }} 49 | publish_dir: ./gh-pages/ 50 | force_orphan: true 51 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | pull_request: 5 | branches: [main] 6 | 7 | jobs: 8 | lint: 9 | runs-on: ubuntu-22.04 10 | 11 | steps: 12 | - uses: actions/checkout@v4 13 | 14 | - name: Set up Python 15 | uses: actions/setup-python@v5 16 | with: 17 | python-version: "3.13" 18 | - run: pip install tox 19 | 20 | - name: Lint the code 21 | run: tox -e lint 22 | -------------------------------------------------------------------------------- /.github/workflows/pr-docs-preview.yml: -------------------------------------------------------------------------------- 1 | name: readthedocs/actions 2 | 3 | on: 4 | pull_request_target: 5 | types: 6 | - opened 7 | 8 | permissions: 9 | pull-requests: write 10 | 11 | jobs: 12 | documentation-links: 13 | runs-on: ubuntu-22.04 14 | steps: 15 | - uses: readthedocs/actions/preview@v1 16 | with: 17 | project-slug: metatrain 18 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | tags: ["*"] 6 | 7 | jobs: 8 | build: 9 | name: Build distribution 10 | runs-on: ubuntu-latest 11 | environment: 12 | name: pypi 13 | url: https://pypi.org/project/metatrain 14 | permissions: 15 | id-token: write 16 | contents: write 17 | 18 | steps: 19 | - uses: actions/checkout@v4 20 | with: 21 | fetch-depth: 0 22 | - name: setup 
Python 23 | uses: actions/setup-python@v5 24 | with: 25 | python-version: "3.13" 26 | - run: python -m pip install tox 27 | - name: Build package 28 | run: tox -e build 29 | - name: Publish distribution to PyPI 30 | if: startsWith(github.ref, 'refs/tags/v') 31 | uses: pypa/gh-action-pypi-publish@release/v1 32 | - name: Publish to GitHub release 33 | if: startsWith(github.ref, 'refs/tags/v') 34 | uses: softprops/action-gh-release@v2 35 | with: 36 | files: | 37 | dist/*.tar.gz 38 | dist/*.whl 39 | prerelease: ${{ contains(github.ref, '-rc') }} 40 | env: 41 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 42 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | # Check all PR 8 | 9 | jobs: 10 | tests: 11 | strategy: 12 | matrix: 13 | include: 14 | - os: ubuntu-22.04 15 | python-version: "3.9" 16 | - os: ubuntu-22.04 17 | python-version: "3.13" 18 | - os: macos-14 19 | python-version: "3.13" 20 | # To be restored once we figure out the issue with the windows build 21 | # - os: windows-2019 22 | # python-version: "3.13" 23 | 24 | runs-on: ${{ matrix.os }} 25 | 26 | steps: 27 | - uses: actions/checkout@v4 28 | 29 | - name: Install Zsh 30 | if: startsWith(matrix.os, 'ubuntu') 31 | run: | 32 | sudo apt-get update 33 | sudo apt-get install -y zsh libfftw3-dev 34 | touch ~/.zshrc 35 | 36 | - name: Set up Python ${{ matrix.python-version }} 37 | uses: actions/setup-python@v5 38 | with: 39 | python-version: ${{ matrix.python-version }} 40 | - run: pip install tox coverage[toml] 41 | 42 | - name: run Python tests and collect coverage 43 | run: | 44 | tox -e tests 45 | coverage xml --data-file tests/.coverage 46 | env: 47 | # Use the CPU only version of torch when building/running the code 48 | PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/cpu 49 | HUGGINGFACE_TOKEN_METATRAIN: ${{ secrets.HUGGINGFACE_TOKEN }} 50 | 51 | - name: upload to codecov.io 52 | uses: codecov/codecov-action@v5 53 | with: 54 | fail_ci_if_error: true 55 | files: tests/coverage.xml 56 | token: ${{ secrets.CODECOV_TOKEN }} 57 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the OS, Python version and other tools you might need 9 | build: 10 | os: ubuntu-22.04 11 | tools: 12 | python: "3.13" 13 | rust: "1.75" 14 | jobs: 15 | pre_build: 16 | - set -e && cd examples/ase && bash train.sh 17 | - set -e && cd examples/programmatic/llpr && bash train.sh 18 | - set -e && cd examples/zbl && bash train.sh 19 | 20 | # Build documentation in the docs/ directory with Sphinx 21 | sphinx: 22 | configuration: docs/src/conf.py 23 | fail_on_warning: true 24 | 25 | # Declare the Python requirements required to build the docs. 26 | # Additionally, a custom environment variable 27 | # PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/cpu 28 | # is declared in the project’s dashboard 29 | python: 30 | install: 31 | - method: pip 32 | path: . 
33 | extra_requirements: 34 | - gap 35 | - soap-bpnn 36 | - requirements: docs/requirements.txt 37 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | # This file defines code owners for the different architectures, ensuring 2 | # that they get asked for a review whenever the code for this architecture 3 | # is modified. 4 | 5 | **/soap_bpnn @frostedoyster 6 | **/pet @abmazitov 7 | **/gap @DavideTisi 8 | **/nanopet @frostedoyster 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2023, Laboratory of Computational Science and Modeling 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | graft src 2 | 3 | include LICENSE 4 | include README.rst 5 | 6 | prune developer 7 | prune docs 8 | prune examples 9 | prune tests 10 | prune .github 11 | prune .tox 12 | 13 | exclude .codecov.yml 14 | exclude .gitignore 15 | exclude .readthedocs.yml 16 | exclude CODEOWNERS 17 | exclude CONTRIBUTING.rst 18 | exclude tox.ini 19 | 20 | global-exclude *.py[cod] __pycache__/* *.so *.dylib 21 | -------------------------------------------------------------------------------- /developer/clean-tempfiles.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # This script removes all temporary files created by Python during 4 | # installation and tests running. 5 | 6 | set -eux 7 | 8 | ROOT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")"/.. 
&& pwd) 9 | cd "$ROOT_DIR" 10 | 11 | rm -rf dist 12 | rm -rf build 13 | rm -rf docs/build 14 | rm -rf docs/src/examples 15 | rm -rf docs/src/sg_execution_times.rst 16 | 17 | rm -rf src/metatrain/dist 18 | rm -rf src/metatrain/build 19 | 20 | find . -name "*.egg-info" -exec rm -rf "{}" + 21 | find . -name "__pycache__" -exec rm -rf "{}" + 22 | find . -name ".coverage" -exec rm -rf "{}" + 23 | -------------------------------------------------------------------------------- /docs/generate_examples/conf.py: -------------------------------------------------------------------------------- 1 | # Pseudo-sphinx configuration to run sphinx-gallery as a command line tool 2 | 3 | import os 4 | 5 | 6 | extensions = [ 7 | "sphinx_gallery.gen_gallery", 8 | ] 9 | 10 | HERE = os.path.dirname(__file__) 11 | ROOT = os.path.realpath(os.path.join(HERE, "..", "..")) 12 | 13 | sphinx_gallery_conf = { 14 | "filename_pattern": r"/*\.py", 15 | "copyfile_regex": r".*\.(pt|sh|xyz|yaml)", 16 | "ignore_pattern": r"train\.sh", 17 | "example_extensions": {".py", ".sh"}, 18 | "default_thumb_file": os.path.join(ROOT, "docs/src/logo/metatrain-512.png"), 19 | "examples_dirs": [ 20 | os.path.join(ROOT, "examples", "ase"), 21 | os.path.join(ROOT, "examples", "programmatic", "llpr"), 22 | os.path.join(ROOT, "examples", "zbl"), 23 | os.path.join(ROOT, "examples", "programmatic", "use_architectures_outside"), 24 | os.path.join(ROOT, "examples", "programmatic", "disk_dataset"), 25 | os.path.join(ROOT, "examples", "basic_usage"), 26 | ], 27 | "gallery_dirs": [ 28 | os.path.join(ROOT, "docs", "src", "examples", "ase"), 29 | os.path.join(ROOT, "docs", "src", "examples", "programmatic", "llpr"), 30 | os.path.join(ROOT, "docs", "src", "examples", "zbl"), 31 | os.path.join(ROOT, "docs", "src", "examples", "programmatic", "use_architectures_outside"), 32 | os.path.join(ROOT, "docs", "src", "examples", "programmatic", "disk_dataset"), 33 | os.path.join(ROOT, "docs", "src", "examples", "basic_usage"), 34 | ], 35 | "min_reported_time": 5, 36 | "matplotlib_animations": True, 37 | } 38 | -------------------------------------------------------------------------------- /docs/generate_examples/generate-examples.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from sphinx.application import Sphinx 4 | 5 | 6 | HERE = os.path.realpath(os.path.dirname(__file__)) 7 | 8 | if __name__ == "__main__": 9 | # the examples gallery is automatically generated upon the Sphinx object creation 10 | Sphinx( 11 | srcdir=os.path.join(HERE, "..", "src"), 12 | confdir=HERE, 13 | outdir=os.path.join(HERE, "..", "build"), 14 | doctreedir=os.path.join(HERE, "..", "build"), 15 | buildername="html", 16 | ) 17 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | ase 2 | furo 3 | sphinx >= 7 4 | sphinxcontrib-bibtex 5 | sphinx-gallery 6 | sphinx-toggleprompt 7 | setuptools # required for sphinxcontrib-bibtex together with Python 3.13 8 | tomli 9 | -------------------------------------------------------------------------------- /docs/src/_static/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metatensor/metatrain/9a0c351ae227781e8481bfc50fe6bd463048e417/docs/src/_static/.gitkeep -------------------------------------------------------------------------------- /docs/src/_templates/.gitkeep: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/metatensor/metatrain/9a0c351ae227781e8481bfc50fe6bd463048e417/docs/src/_templates/.gitkeep -------------------------------------------------------------------------------- /docs/src/advanced-concepts/auto-restarting.rst: -------------------------------------------------------------------------------- 1 | Automatic restarting 2 | ==================== 3 | 4 | When training needs to be restarted multiple times (for example, when training 5 | an expensive model or running on an HPC cluster with short time limits), it is 6 | useful to be able to relaunch training with the same command. 7 | 8 | In ``metatrain``, this functionality is provided via the ``--restart auto`` 9 | (or ``-c auto``) flag of ``mtt train``. This flag will automatically restart 10 | the training from the last checkpoint, if one is found in the ``outputs/`` 11 | folder of the current directory. If no checkpoint is found, the training will start 12 | from scratch. 13 | -------------------------------------------------------------------------------- /docs/src/advanced-concepts/index.rst: -------------------------------------------------------------------------------- 1 | Advanced concepts 2 | ================= 3 | 4 | This section covers advanced concepts of the ``metatrain`` library, 5 | such as output naming, auxiliary outputs, and wrapper models. 6 | 7 | 8 | .. toctree:: 9 | :maxdepth: 1 10 | 11 | output-naming 12 | auxiliary-outputs 13 | multi-gpu 14 | auto-restarting 15 | fine-tuning 16 | fitting-generic-targets 17 | transfer-learning 18 | -------------------------------------------------------------------------------- /docs/src/advanced-concepts/multi-gpu.rst: -------------------------------------------------------------------------------- 1 | Multi-GPU training 2 | ================== 3 | 4 | Some of the architectures in ``metatrain`` support multi-GPU training. 5 | In multi-GPU training, every batch of samples is split into smaller 6 | mini-batches and the computation is run for each of the smaller mini-batches 7 | in parallel on different GPUs. The different gradients obtained on each 8 | device are then summed. This approach allows the user to reduce the time 9 | it takes to train models. 10 | 11 | Here is a list of architectures supporting multi-GPU training: 12 | 13 | 14 | SOAP-BPNN 15 | --------- 16 | 17 | SOAP-BPNN supports distributed multi-GPU training on SLURM environments. 18 | The options file to run distributed training with the SOAP-BPNN model looks 19 | like this: 20 | 21 | .. literalinclude:: ../../../examples/multi-gpu/soap-bpnn/options-distributed.yaml 22 | :language: yaml 23 | 24 | and the slurm submission script would look like this: 25 | 26 | .. literalinclude:: ../../../examples/multi-gpu/soap-bpnn/submit-distributed.sh 27 | :language: shell 28 | -------------------------------------------------------------------------------- /docs/src/advanced-concepts/output-naming.rst: -------------------------------------------------------------------------------- 1 | Output naming 2 | ============= 3 | 4 | The name and format of the outputs in ``metatrain`` are based on those of the 5 | `metatomic <https://docs.metatensor.org/metatomic/>`_ 6 | package. An immediate example is given by the ``energy`` output. 7 | 8 | Any additional outputs present within the library are denoted by the 9 | ``mtt::`` prefix.
For example, some models can output their last-layer 10 | features, which are named as ``mtt::aux::{target}_last_layer_features``, 11 | where ``aux`` denotes an auxiliary output. 12 | 13 | Outputs that are specific to a particular model should be named as 14 | ``mtt::<model_name>::<output_name>``. 15 | -------------------------------------------------------------------------------- /docs/src/advanced-concepts/transfer-learning.rst: -------------------------------------------------------------------------------- 1 | .. _transfer-learning: 2 | 3 | Transfer Learning (experimental) 4 | ==================================== 5 | 6 | .. warning:: 7 | 8 | This section of the documentation is only relevant for the PET model so far. 9 | 10 | .. warning:: 11 | 12 | Features described in this section are experimental and not yet 13 | extensively tested. Please use them at your own risk and report any 14 | issues you encounter to the developers. 15 | 16 | This section describes the process of transfer learning, which is a 17 | common technique used in machine learning, where a model is pre-trained on 18 | a dataset with one level of theory and/or one set of properties and then 19 | fine-tuned on a different dataset with a different level of theory and/or 20 | a different set of properties. This approach reuses the learned representations 21 | from the pre-trained model and adapts them to the new targets, which can be 22 | expensive to compute and/or not available in the pre-trained dataset. 23 | 24 | In the following sections we assume that the pre-trained model is trained on a 25 | conventional DFT dataset with energies, forces and stresses, which are provided 26 | as ``energy`` targets (and its derivatives) in the ``options.yaml`` file. 27 | 28 | 29 | Fitting to a new level of theory 30 | -------------------------------- 31 | 32 | Training on a new level of theory is a common use case for transfer 33 | learning. It requires using a pre-trained model checkpoint with the 34 | ``mtt train -c pre-trained-model.ckpt`` command and setting the 35 | new targets corresponding to the new level of theory in the 36 | ``options.yaml`` file. Let's assume that the training is done on a 37 | dataset computed with a hybrid DFT functional (e.g., PBE0) stored in the 38 | ``new_train_dataset.xyz`` file, where the corresponding energies are written in the 39 | ``energy`` key of the ``info`` dictionary of the ``ase.Atoms`` object. Then, 40 | the ``options.yaml`` file should look like this: 41 | 42 | .. code-block:: yaml 43 | 44 | training_set: 45 | systems: "new_train_dataset.xyz" 46 | targets: 47 | mtt::energy_pbe0: # name of the new target 48 | key: "energy" # key of the target in the atoms.info dictionary 49 | unit: "eV" # unit of the target value 50 | 51 | 52 | The validation and test sets can be set in the same way. The training 53 | process will then create a new composition model and new heads for the 54 | target ``mtt::energy_pbe0``. The rest of the model weights will be 55 | initialized from the pre-trained model checkpoint. 56 | 57 | 58 | Fitting to a new set of properties 59 | ---------------------------------- 60 | 61 | Training on a new set of properties is another common use case for 62 | transfer learning. It can be done in a similar way as training on a new 63 | level of theory. The only difference is that the new targets need to be 64 | properly set in the ``options.yaml`` file. More information about fitting 65 | generic targets can be found in the :ref:`Fitting generic targets <fitting-generic-targets>` 66 | section of the documentation.
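As an illustrative sketch only (the authoritative schema for generic targets is documented in that section; the target name and values below are hypothetical), an ``options.yaml`` fragment for fine-tuning on dipole moments stored under a ``dipole`` key of the ``ase.Atoms`` ``info`` dictionary might look like this:

.. code-block:: yaml

    training_set:
      systems: "new_train_dataset.xyz"
      targets:
        mtt::dipole:        # hypothetical name of the new target
          key: "dipole"     # key of the target in the atoms.info dictionary
          unit: ""
          per_atom: false
          type:
            cartesian:
              rank: 1       # a dipole is a rank-1 Cartesian tensor

As with a new level of theory, new heads are created for the new target, while the remaining weights are initialized from the pre-trained checkpoint.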
67 | 68 | 69 | -------------------------------------------------------------------------------- /docs/src/architectures/index.rst: -------------------------------------------------------------------------------- 1 | .. _available-architectures: 2 | 3 | Available Architectures 4 | ======================= 5 | 6 | This is a list of all architectures available in ``metatrain``. 7 | 8 | .. toctree:: 9 | :maxdepth: 1 10 | :glob: 11 | 12 | ./* 13 | -------------------------------------------------------------------------------- /docs/src/dev-docs/architecture-life-cycle.rst: -------------------------------------------------------------------------------- 1 | .. _architecture-life-cycle: 2 | 3 | Life Cycle of an Architecture 4 | ============================= 5 | 6 | .. TODO: Maybe add a flowchart later 7 | 8 | Architectures in ``metatrain`` undergo different stages based on their 9 | development/functionality level and maintenance status. We distinguish three distinct 10 | stages: **experimental**, **stable**, and **deprecated**. Typically, an architecture 11 | starts as experimental, advances to stable, and eventually becomes deprecated before 12 | removal if maintenance is no longer feasible. 13 | 14 | .. note:: 15 | The development and maintenance of an architecture must be fully undertaken by the 16 | architecture's authors or maintainers. The core developers of ``metatrain`` 17 | provide infrastructure and implementation support but are not responsible for the 18 | architecture's internal functionality or any issues that may arise therein. 19 | 20 | Experimental Architectures 21 | -------------------------- 22 | 23 | New architectures added to the library will initially be classified as experimental. 24 | These architectures are stored in the ``experimental`` subdirectory within the 25 | repository. To qualify as an experimental architecture, certain criteria must be met: 26 | 27 | 1. Capability to fit at least a single quantity and predict it, verified through CI 28 | tests. 29 | 2. Compatibility with JIT compilation using `TorchScript 30 | <https://pytorch.org/docs/stable/jit.html>`_. 31 | 3. Provision of reasonable default hyperparameters. 32 | 4. Minimal code quality, ensured by passing linting tests invoked with ``tox -e lint``. 33 | 5. A contact person designated as the maintainer. 34 | 6. All external dependencies must be pip-installable. While not required to be on PyPI, 35 | a public git repository or another public URL with a repository is acceptable. 36 | 37 | For detailed instructions on adding a new architecture, refer to 38 | :ref:`adding-new-architecture`. 39 | 40 | Stable Architectures 41 | -------------------- 42 | 43 | Transitioning from an experimental to a stable model requires additional criteria to be 44 | satisfied: 45 | 46 | 1. Provision of regression prediction tests with a small (not exported) checkpoint 47 | file. 48 | 2. Comprehensive architecture documentation including a schema for verifying the 49 | architecture's hyperparameters. 50 | 3. If an architecture has external dependencies, all must be publicly available on 51 | PyPI. 52 | 4. Adherence to the standard output infrastructure of ``metatrain``, including 53 | logging and model save locations. 54 | 55 | Deprecated Architectures 56 | ------------------------ 57 | 58 | An architecture will be deemed deprecated if its maintainer becomes unresponsive or 59 | any of its CI jobs fail. Such an architecture will be **removed after 6 months** unless 60 | a new maintainer is found who can address the issues.
If rectified within this 6-month period, the model may revert to its previous stable or experimental status. -------------------------------------------------------------------------------- /docs/src/dev-docs/cli/eval.rst: -------------------------------------------------------------------------------- 1 | Eval 2 | #### 3 | 4 | .. automodule:: metatrain.cli.eval 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/src/dev-docs/cli/export.rst: -------------------------------------------------------------------------------- 1 | Export 2 | ###### 3 | 4 | .. automodule:: metatrain.cli.export 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/src/dev-docs/cli/formatter.rst: -------------------------------------------------------------------------------- 1 | Formatter 2 | ######### 3 | 4 | .. automodule:: metatrain.cli.formatter 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/src/dev-docs/cli/index.rst: -------------------------------------------------------------------------------- 1 | CLI API 2 | ======= 3 | 4 | This is the API for the command line interface (``cli``) functions: the ``train``, 5 | the ``eval`` and the ``export`` functions of ``metatrain``. 6 | 7 | .. toctree:: 8 | :maxdepth: 1 9 | 10 | train 11 | eval 12 | export 13 | 14 | We provide a custom formatter class for formatting the help message of the 15 | ``argparse`` package. 16 | 17 | .. toctree:: 18 | :maxdepth: 1 19 | 20 | formatter 21 | -------------------------------------------------------------------------------- /docs/src/dev-docs/cli/train.rst: -------------------------------------------------------------------------------- 1 | Train 2 | ##### 3 | 4 | .. automodule:: metatrain.cli.train 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/src/dev-docs/dataset-information.rst: -------------------------------------------------------------------------------- 1 | Dataset Information 2 | =================== 3 | 4 | When working with ``metatrain``, you will most likely need to interact with some core 5 | classes which are responsible for storing some information about datasets. All these 6 | classes belong to the ``metatrain.utils.data`` module which can be found in the 7 | :ref:`data` section of the developer documentation. 8 | 9 | These classes are: 10 | 11 | - :py:class:`metatrain.utils.data.DatasetInfo`: This class is responsible for storing 12 | information about a dataset. It contains the length unit used in the dataset, the 13 | atomic types present, as well as information about the dataset's targets as a 14 | ``Dict[str, TargetInfo]`` object. The keys of this dictionary are the names of the 15 | targets in the datasets (e.g., ``energy``, ``mtt::dipole``, etc.). 16 | 17 | - :py:class:`metatrain.utils.data.TargetInfo`: This class is responsible for storing 18 | information about a target in a dataset. It contains the target's physical quantity, 19 | the unit in which the target is expressed, and the ``layout`` of the target. The 20 | ``layout`` is a ``TensorMap`` object with zero samples, which is used to exemplify 21 | the metadata of each target.
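As a minimal sketch of such a layout (assuming the public ``metatensor.torch`` API; the ``system`` sample name and ``energy`` property name are purely illustrative), a scalar, per-structure target could be exemplified like this:

.. code-block:: python

    import torch
    from metatensor.torch import Labels, TensorBlock, TensorMap

    # a scalar layout: a single block with zero samples, no components,
    # and one property column
    layout = TensorMap(
        keys=Labels.single(),
        blocks=[
            TensorBlock(
                values=torch.empty((0, 1), dtype=torch.float64),
                samples=Labels.empty(["system"]),
                components=[],
                properties=Labels.range("energy", 1),
            )
        ],
    )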
22 | 23 | At the moment, only three types of layouts are supported: 24 | 25 | - scalar: This type of layout is used when the target is a scalar quantity. The 26 | ``layout`` ``TensorMap`` object corresponding to a scalar must have one 27 | ``TensorBlock`` and no ``components``. 28 | - Cartesian tensor: This type of layout is used when the target is a Cartesian tensor. 29 | The ``layout`` ``TensorMap`` object corresponding to a Cartesian tensor must have 30 | one ``TensorBlock`` and as many ``components`` as the tensor's rank. These 31 | components are named ``xyz`` for a tensor of rank 1 and ``xyz_1``, ``xyz_2``, and 32 | so on for higher ranks. 33 | - Spherical tensor: This type of layout is used when the target is a spherical tensor. 34 | The ``layout`` ``TensorMap`` object corresponding to a spherical tensor can have 35 | multiple blocks corresponding to different irreps (irreducible representations) of 36 | the target. The ``keys`` of the ``TensorMap`` object must have the ``o3_lambda`` 37 | and ``o3_sigma`` names, and each ``TensorBlock`` must have a single component named 38 | ``o3_mu``. 39 | -------------------------------------------------------------------------------- /docs/src/dev-docs/getting-started.rst: -------------------------------------------------------------------------------- 1 | .. _label_dev-getting-started: 2 | 3 | .. include:: ../../../CONTRIBUTING.rst 4 | -------------------------------------------------------------------------------- /docs/src/dev-docs/index.rst: -------------------------------------------------------------------------------- 1 | Developer documentation 2 | ======================= 3 | 4 | This is a collection of documentation for developers of the ``metatrain`` package. 5 | It includes documentation on how to add a new model, as well as the API of the utils 6 | module. 7 | 8 | 9 | .. toctree:: 10 | :maxdepth: 1 11 | 12 | getting-started 13 | architecture-life-cycle 14 | new-architecture 15 | dataset-information 16 | cli/index 17 | utils/index 18 | changelog 19 | -------------------------------------------------------------------------------- /docs/src/dev-docs/utils/additive/composition.rst: -------------------------------------------------------------------------------- 1 | Composition model 2 | ################# 3 | 4 | .. automodule:: metatrain.utils.additive.composition 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/src/dev-docs/utils/additive/index.rst: -------------------------------------------------------------------------------- 1 | Additive models 2 | =============== 3 | 4 | API for handling additive models in ``metatrain``. These are models that 5 | can be added to one or more architectures. 6 | 7 | .. toctree:: 8 | :maxdepth: 1 9 | 10 | remove_additive 11 | composition 12 | zbl 13 | -------------------------------------------------------------------------------- /docs/src/dev-docs/utils/additive/remove_additive.rst: -------------------------------------------------------------------------------- 1 | Removing additive contributions 2 | ############################### 3 | 4 | .. 
automodule:: metatrain.utils.additive.remove 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/src/dev-docs/utils/additive/zbl.rst: -------------------------------------------------------------------------------- 1 | ZBL short-range potential 2 | ######################### 3 | 4 | .. automodule:: metatrain.utils.additive.zbl 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/src/dev-docs/utils/architectures.rst: -------------------------------------------------------------------------------- 1 | Architectures 2 | ############# 3 | 4 | Utility functions to detect architectures and verify their names in ``metatrain``. 5 | 6 | .. automodule:: metatrain.utils.architectures 7 | :members: 8 | :undoc-members: 9 | :show-inheritance: 10 | -------------------------------------------------------------------------------- /docs/src/dev-docs/utils/augmentation.rst: -------------------------------------------------------------------------------- 1 | Augmentation 2 | ############ 3 | 4 | .. automodule:: metatrain.utils.augmentation 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/src/dev-docs/utils/data/combine_dataloaders.rst: -------------------------------------------------------------------------------- 1 | Combining dataloaders 2 | ##################### 3 | 4 | .. automodule:: metatrain.utils.data.combine_dataloaders 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | :special-members: 9 | :exclude-members: __init__, reset, __iter__, __next__ 10 | -------------------------------------------------------------------------------- /docs/src/dev-docs/utils/data/dataset.rst: -------------------------------------------------------------------------------- 1 | Dataset 2 | ####### 3 | 4 | .. automodule:: metatrain.utils.data.dataset 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/src/dev-docs/utils/data/get_dataset.rst: -------------------------------------------------------------------------------- 1 | Reading a dataset 2 | ################# 3 | 4 | .. automodule:: metatrain.utils.data.get_dataset 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/src/dev-docs/utils/data/index.rst: -------------------------------------------------------------------------------- 1 | .. _data: 2 | 3 | Data 4 | ==== 5 | 6 | API for handling data in ``metatrain``. 7 | 8 | .. toctree:: 9 | :maxdepth: 1 10 | 11 | combine_dataloaders 12 | dataset 13 | get_dataset 14 | readers 15 | writers 16 | systems_to_ase 17 | -------------------------------------------------------------------------------- /docs/src/dev-docs/utils/data/readers.rst: -------------------------------------------------------------------------------- 1 | Readers 2 | ####### 3 | 4 | Parsers for obtaining *system* and *target* information from files. Currently, 5 | ``metatrain`` supports the following libraries for reading data: 6 | 7 | .. 
list-table:: 8 | :header-rows: 1 9 | 10 | * - Library 11 | - Supported targets 12 | - Linked file formats 13 | * - ``ase`` 14 | - system, energy, forces, stress, virials 15 | - ``.xyz``, ``.extxyz`` 16 | * - ``metatensor`` 17 | - system, energy, forces, stress, virials 18 | - ``.mts`` 19 | 20 | 21 | If the ``reader`` parameter is not set, the library is determined from the file 22 | extension. Overriding this behavior is particularly useful if a file format is not 23 | listed here but might be supported by a library. 24 | 25 | Below is a synopsis of the reader functions in detail. 26 | 27 | System and target data readers 28 | ============================== 29 | 30 | The main entry points for reading system and target information are the reader functions. 31 | 32 | .. autofunction:: metatrain.utils.data.read_systems 33 | .. autofunction:: metatrain.utils.data.read_targets 34 | 35 | These functions dispatch the reading of the system and target information to the 36 | appropriate readers, based on the file extension or the user-provided library. 37 | 38 | In addition, the ``read_targets`` function uses the user-provided information about the 39 | targets to call the appropriate target reader function (for energy targets or generic 40 | targets). 41 | 42 | ASE 43 | --- 44 | 45 | This section describes the parsers for the ASE library. 46 | 47 | .. autofunction:: metatrain.utils.data.readers.ase.read 48 | .. autofunction:: metatrain.utils.data.readers.ase.read_systems 49 | .. autofunction:: metatrain.utils.data.readers.ase.read_energy 50 | .. autofunction:: metatrain.utils.data.readers.ase.read_generic 51 | 52 | It should be noted that :func:`metatrain.utils.data.readers.ase.read_energy` currently 53 | uses sub-functions to parse the energy and its gradients like ``forces``, ``virial`` 54 | and ``stress``. 55 | 56 | Metatensor 57 | ---------- 58 | 59 | This section describes the parsers for the ``metatensor`` library. As the systems and/or 60 | targets are already stored in the ``metatensor`` format, these reader functions mainly 61 | perform checks and return the data. 62 | 63 | .. autofunction:: metatrain.utils.data.readers.metatensor.read_systems 64 | .. autofunction:: metatrain.utils.data.readers.metatensor.read_energy 65 | .. autofunction:: metatrain.utils.data.readers.metatensor.read_generic 66 | -------------------------------------------------------------------------------- /docs/src/dev-docs/utils/data/systems_to_ase.rst: -------------------------------------------------------------------------------- 1 | Converting Systems to ASE 2 | ######################### 3 | 4 | Some machine learning models might train on ``ase.Atoms`` objects. This module 5 | provides a function to convert a ``metatomic.torch.System`` object to an 6 | ``ase.Atoms`` object. 7 | 8 | .. automodule:: metatrain.utils.data.system_to_ase 9 | :members: 10 | :undoc-members: 11 | :show-inheritance: 12 | -------------------------------------------------------------------------------- /docs/src/dev-docs/utils/data/writers.rst: -------------------------------------------------------------------------------- 1 | Target data Writers 2 | =================== 3 | 4 | The main entry point for writing target information is 5 | 6 | .. autofunction:: metatrain.utils.data.writers.write_predictions 7 | 8 | 9 | Based on the provided filename the writer chooses which child writer to use. The mapping 10 | of which writer is used for which file type is stored in 11 | 12 | .. 
autodata:: metatrain.utils.data.writers.PREDICTIONS_WRITERS 13 | 14 | Implemented Writers 15 | ------------------- 16 | 17 | .. autofunction:: metatrain.utils.data.writers.write_xyz 18 | .. autofunction:: metatrain.utils.data.writers.write_mts 19 | -------------------------------------------------------------------------------- /docs/src/dev-docs/utils/devices.rst: -------------------------------------------------------------------------------- 1 | Device 2 | ###### 3 | 4 | .. automodule:: metatrain.utils.devices 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/src/dev-docs/utils/dtype.rst: -------------------------------------------------------------------------------- 1 | Dtype 2 | ##### 3 | 4 | .. automodule:: metatrain.utils.dtype 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/src/dev-docs/utils/errors.rst: -------------------------------------------------------------------------------- 1 | Errors 2 | ###### 3 | 4 | .. automodule:: metatrain.utils.errors 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/src/dev-docs/utils/evaluate_model.rst: -------------------------------------------------------------------------------- 1 | Evaluating a model 2 | ################## 3 | 4 | .. automodule:: metatrain.utils.evaluate_model 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/src/dev-docs/utils/external_naming.rst: -------------------------------------------------------------------------------- 1 | External Naming 2 | ############### 3 | 4 | 5 | Functions to handle the conversion between external and internal naming conventions. 6 | 7 | .. automodule:: metatrain.utils.external_naming 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | -------------------------------------------------------------------------------- /docs/src/dev-docs/utils/index.rst: -------------------------------------------------------------------------------- 1 | Utility API 2 | =========== 3 | 4 | This is the API for the ``utils`` module of ``metatrain``. 5 | 6 | .. toctree:: 7 | :maxdepth: 1 8 | :glob: 9 | 10 | additive/index 11 | data/index 12 | ./* 13 | -------------------------------------------------------------------------------- /docs/src/dev-docs/utils/io.rst: -------------------------------------------------------------------------------- 1 | IO 2 | ## 3 | 4 | 5 | Functions to be used for handling the serialization of models 6 | 7 | .. automodule:: metatrain.utils.io 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | :member-order: bysource 12 | -------------------------------------------------------------------------------- /docs/src/dev-docs/utils/jsonschema.rst: -------------------------------------------------------------------------------- 1 | Jsonschema 2 | ########## 3 | 4 | Functions and classes to wrap around and extend the `jsonschema 5 | `_ library. 6 | 7 | .. automodule:: metatrain.utils.jsonschema 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | -------------------------------------------------------------------------------- /docs/src/dev-docs/utils/logging.rst: -------------------------------------------------------------------------------- 1 | Logging 2 | ####### 3 | 4 | .. 
automodule:: metatrain.utils.logging 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/src/dev-docs/utils/long_range.rst: -------------------------------------------------------------------------------- 1 | Long-range 2 | ########## 3 | 4 | .. automodule:: metatrain.utils.long_range 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/src/dev-docs/utils/loss.rst: -------------------------------------------------------------------------------- 1 | Loss 2 | #### 3 | 4 | .. automodule:: metatrain.utils.loss 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/src/dev-docs/utils/metrics.rst: -------------------------------------------------------------------------------- 1 | Metrics 2 | ####### 3 | 4 | .. automodule:: metatrain.utils.metrics 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/src/dev-docs/utils/neighbor_lists.rst: -------------------------------------------------------------------------------- 1 | Neighbor lists 2 | ============== 3 | 4 | Utilities to attach neighbor lists to a ``metatomic.torch.System`` object. 5 | 6 | .. automodule:: metatrain.utils.neighbor_lists 7 | :members: 8 | :undoc-members: 9 | :show-inheritance: 10 | -------------------------------------------------------------------------------- /docs/src/dev-docs/utils/omegaconf.rst: -------------------------------------------------------------------------------- 1 | Custom omegaconf functions 2 | ========================== 3 | 4 | Resolvers to handle special fields in our configs as well as the expansion/completion of 5 | the dataset section. 6 | 7 | .. automodule:: metatrain.utils.omegaconf 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | -------------------------------------------------------------------------------- /docs/src/dev-docs/utils/output_gradient.rst: -------------------------------------------------------------------------------- 1 | Output gradient 2 | ############### 3 | 4 | .. automodule:: metatrain.utils.output_gradient 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/src/dev-docs/utils/per_atom.rst: -------------------------------------------------------------------------------- 1 | Averaging predictions per atom 2 | ############################## 3 | 4 | .. automodule:: metatrain.utils.per_atom 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/src/dev-docs/utils/scaler.rst: -------------------------------------------------------------------------------- 1 | Scaler 2 | ###### 3 | 4 | .. automodule:: metatrain.utils.scaler 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/src/dev-docs/utils/sum_over_atoms.rst: -------------------------------------------------------------------------------- 1 | Summing over atoms 2 | ################## 3 | 4 | .. 
automodule:: metatrain.utils.sum_over_atoms 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/src/dev-docs/utils/transfer.rst: -------------------------------------------------------------------------------- 1 | Data type and device transfers 2 | ############################## 3 | 4 | .. automodule:: metatrain.utils.transfer 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/src/dev-docs/utils/units.rst: -------------------------------------------------------------------------------- 1 | Unit handling 2 | ############# 3 | 4 | .. automodule:: metatrain.utils.units 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/src/getting-started/advanced_base_config.rst: -------------------------------------------------------------------------------- 1 | .. _advanced_base_conf: 2 | 3 | Advanced Base Configuration 4 | =========================== 5 | 6 | Here, we show how some advanced base properties in the ``options.yaml`` can 7 | be adjusted. They should be written without indentation in the ``options.yaml`` file. 8 | 9 | :param device: The device on which the training should be run. Takes two possible 10 | values: ``cpu`` and ``gpu``. Default: ``cpu`` 11 | :param base_precision: Override the base precision of all floats during training. By 12 | default, an optimal precision is obtained from the architecture. Changing this will 13 | have an effect on the memory consumption during training and possibly also on the 14 | accuracy of the model. Possible values: ``64``, ``32`` or ``16``. 15 | :param seed: Seed used to start the training. Sets all the seeds of ``numpy.random``, 16 | ``random``, ``torch`` and ``torch.cuda`` (if available) to the same value ``seed``. 17 | If ``seed`` is not set, the initial seed will be set to a random number. This initial 18 | seed will be reported in the output folder. 19 | :param wandb: If you want to use Weights and Biases (wandb) for logging, create a new 20 | section with this name. The parameters of this section are the same as those of the `wandb.init 21 | `_ method, and a minimal example of the 22 | section is: 23 | 24 | .. code-block:: yaml 25 | 26 | wandb: 27 | project: my_project 28 | name: my_run_name 29 | tags: 30 | - tag1 31 | - tag2 32 | notes: This is a test run 33 | 34 | All parameters of your options file will be automatically added to the wandb run, so 35 | you don't have to set the ``config`` parameter. 36 | 37 | .. important:: 38 | 39 | You need to install wandb with ``pip install wandb`` if you want to use this 40 | logger. **Before** running, also set up your credentials with ``wandb login`` 41 | from the command line. See the `wandb login 42 | documentation `_ for details on the 43 | setup. 44 | 45 | In the next tutorials, we show how to override the default parameters of an architecture. 46 | -------------------------------------------------------------------------------- /docs/src/getting-started/index.rst: -------------------------------------------------------------------------------- 1 | Getting started 2 | =============== 3 | 4 | This section describes how to install the package and its most basic commands. 5 | 6 | ..
toctree:: 7 | :maxdepth: 1 8 | 9 | installation 10 | ../examples/basic_usage/usage 11 | custom_dataset_conf 12 | advanced_base_config 13 | override 14 | checkpoints 15 | units 16 | -------------------------------------------------------------------------------- /docs/src/getting-started/installation.rst: -------------------------------------------------------------------------------- 1 | .. _label_installation: 2 | 3 | .. include:: ../../../README.rst 4 | :start-after: marker-installation 5 | :end-before: marker-issues 6 | -------------------------------------------------------------------------------- /docs/src/getting-started/override.rst: -------------------------------------------------------------------------------- 1 | Override Architecture's Default Parameters 2 | ========================================== 3 | 4 | In our initial tutorial, we used default parameters to train a model employing the 5 | SOAP-BPNN architecture, as shown in the following config: 6 | 7 | .. literalinclude:: ../../static/qm9/options.yaml 8 | :language: yaml 9 | 10 | While default parameters often serve as a good starting point, depending on your 11 | training target and dataset, it might be necessary to adjust the architecture's 12 | parameters. 13 | 14 | First, familiarize yourself with the specific parameters of the architecture you intend 15 | to use. We provide a list of all architectures and their parameters in the 16 | :ref:`available-architectures` section. For example, the parameters of the SOAP-BPNN 17 | models are detailed at :ref:`architecture-soap-bpnn`. 18 | 19 | Modifying Parameters (yaml) 20 | --------------------------- 21 | 22 | As an example, let's increase the number of epochs (``num_epochs``) and the ``cutoff`` 23 | radius of the SOAP descriptor. To do this, create a new section in the ``options.yaml`` 24 | named ``architecture``. Within this section, you can override the architecture's 25 | hyperparameters. The adjustments for ``num_epochs`` and ``cutoff`` look like this: 26 | 27 | .. code-block:: yaml 28 | 29 | architecture: 30 | name: "soap_bpnn" 31 | model: 32 | soap: 33 | cutoff: 7.0 34 | training: 35 | num_epochs: 200 36 | 37 | training_set: 38 | systems: "qm9_reduced_100.xyz" 39 | targets: 40 | energy: 41 | key: "U0" 42 | 43 | test_set: 0.1 44 | validation_set: 0.1 45 | 46 | Modifying Parameters (Command Line Overrides) 47 | --------------------------------------------- 48 | 49 | For quick adjustments or additions to an options file, command-line overrides are also 50 | a possibility. The changes above can be achieved by typing: 51 | 52 | .. code-block:: bash 53 | 54 | mtt train options.yaml \ 55 | -r architecture.model.soap.cutoff=7.0 -r architecture.training.num_epochs=200 56 | 57 | Here, the ``-r`` (or, equivalently, ``--override``) flag is used to pass the overrides. 58 | The syntax follows a dotlist-style string format where each level of the options is 59 | separated by a ``.``. As a further example, to use single precision for your training, 60 | you can add ``-r base_precision=32``. 61 | 62 | .. note:: 63 | Command-line overrides allow adding new values to your training parameters, and they 64 | override both the architecture's parameters and those of your provided options file, as in the sketch below.
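For instance, a minimal sketch (the values here are illustrative, not taken from the original file): ``seed`` does not appear in the options file above, yet it can be added from the command line together with an override of an existing base parameter:

   .. code-block:: bash

      mtt train options.yaml -r seed=42 -r base_precision=32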
65 | -------------------------------------------------------------------------------- /docs/src/getting-started/units.rst: -------------------------------------------------------------------------------- 1 | Units 2 | ===== 3 | 4 | ``metatrain`` will always work with the units as provided by the user, and all logs will 5 | be in the same units. In other words, ``metatrain`` does not perform any unit 6 | conversion. The only exception is the logging of energies in ``meV`` if the energies are 7 | declared to be in ``eV``, for consistency with common practice and other codes. 8 | 9 | Although not mandatory, the user is encouraged to specify the units of their datasets 10 | in the input files, so that the logs can be more informative and, more importantly, in 11 | order to make the resulting exported models usable in simulation engines (which instead 12 | require the units to be specified) without unpleasant surprises. 13 | -------------------------------------------------------------------------------- /docs/src/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to metatrain! 2 | ===================== 3 | 4 | .. include:: ../../README.rst 5 | :start-after: marker-introduction 6 | :end-before: marker-documentation 7 | 8 | .. toctree:: 9 | :maxdepth: 1 10 | :hidden: 11 | 12 | getting-started/index 13 | architectures/index 14 | tutorials/index 15 | advanced-concepts/index 16 | dev-docs/index 17 | -------------------------------------------------------------------------------- /docs/src/logo/metatrain-512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metatensor/metatrain/9a0c351ae227781e8481bfc50fe6bd463048e417/docs/src/logo/metatrain-512.png -------------------------------------------------------------------------------- /docs/src/logo/metatrain-64.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metatensor/metatrain/9a0c351ae227781e8481bfc50fe6bd463048e417/docs/src/logo/metatrain-64.png -------------------------------------------------------------------------------- /docs/src/logo/metatrain.svg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metatensor/metatrain/9a0c351ae227781e8481bfc50fe6bd463048e417/docs/src/logo/metatrain.svg -------------------------------------------------------------------------------- /docs/src/tutorials/index.rst: -------------------------------------------------------------------------------- 1 | .. _tutorials: 2 | 3 | Tutorials 4 | ========= 5 | 6 | This section includes some more advanced tutorials on the usage of the 7 | ``metatrain`` package. 8 | 9 | ..
toctree:: 10 | :maxdepth: 1 11 | 12 | ../examples/ase/run_ase 13 | ../examples/zbl/dimers 14 | ../examples/programmatic/llpr/llpr 15 | ../examples/programmatic/use_architectures_outside/use_outside 16 | ../examples/programmatic/disk_dataset/disk_dataset 17 | -------------------------------------------------------------------------------- /docs/static/images/metatrain-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metatensor/metatrain/9a0c351ae227781e8481bfc50fe6bd463048e417/docs/static/images/metatrain-dark.png -------------------------------------------------------------------------------- /docs/static/images/metatrain.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metatensor/metatrain/9a0c351ae227781e8481bfc50fe6bd463048e417/docs/static/images/metatrain.png -------------------------------------------------------------------------------- /docs/static/qm9/eval.yaml: -------------------------------------------------------------------------------- 1 | systems: "qm9_reduced_100.xyz" 2 | targets: 3 | energy: 4 | key: "U0" 5 | unit: "eV" 6 | -------------------------------------------------------------------------------- /docs/static/qm9/options.yaml: -------------------------------------------------------------------------------- 1 | # architecture used to train the model 2 | architecture: 3 | name: soap_bpnn 4 | training: 5 | num_epochs: 5 # a very short training run 6 | 7 | # Mandatory section defining the parameters for system and target data of the 8 | # training set 9 | training_set: 10 | systems: "qm9_reduced_100.xyz" # file where the positions are stored 11 | targets: 12 | energy: 13 | key: "U0" # name of the target value 14 | unit: "eV" # unit of the target value 15 | 16 | test_set: 0.1 # 10 % of the training_set is randomly split off and used as the test set 17 | validation_set: 0.1 # 10 % of the training_set is randomly split off and used for validation 18 | -------------------------------------------------------------------------------- /docs/static/refs.bib: -------------------------------------------------------------------------------- 1 | @article{behler_generalized_2007, 2 | title = {Generalized {{Neural-Network Representation}} of {{High-Dimensional Potential-Energy Surfaces}}}, 3 | author = {Behler, J{\"o}rg and Parrinello, Michele}, 4 | year = {2007}, 5 | month = apr, 6 | journal = {Phys. Rev. Lett.}, 7 | volume = {98}, 8 | number = {14}, 9 | pages = {146401}, 10 | publisher = {{American Physical Society}}, 11 | doi = {10.1103/PhysRevLett.98.146401}, 12 | urldate = {2023-08-03} 13 | } 14 | 15 | @article{bartok_representing_2013, 16 | title = {On Representing Chemical Environments}, 17 | author = {Bart{\'o}k, Albert P. and Kondor, Risi and Cs{\'a}nyi, G{\'a}bor}, 18 | year = {2013}, 19 | month = may, 20 | journal = {Phys. Rev. B}, 21 | volume = {87}, 22 | number = {18}, 23 | pages = {184115}, 24 | publisher = {{American Physical Society}}, 25 | doi = {10.1103/PhysRevB.87.184115}, 26 | urldate = {2022-02-24} 27 | } 28 | 29 | @article{willatt_feature_2018, 30 | title = {Feature Optimization for Atomistic Machine Learning Yields a Data-Driven Construction of the Periodic Table of the Elements}, 31 | author = {Willatt, Michael J. and Musil, F{\'e}lix and Ceriotti, Michele}, 32 | year = {2018}, 33 | month = dec, 34 | journal = {Phys. Chem. Chem.
Phys.}, 35 | volume = {20}, 36 | number = {47}, 37 | pages = {29661--29668}, 38 | publisher = {{The Royal Society of Chemistry}}, 39 | issn = {1463-9084}, 40 | doi = {10.1039/C8CP05921G}, 41 | urldate = {2024-02-19}, 42 | langid = {english} 43 | } 44 | 45 | @article{bigi_smooth_2022, 46 | author = {Bigi, Filippo and Huguenin-Dumittan, Kevin K. and Ceriotti, Michele and Manolopoulos, David E.}, 47 | title = "{A smooth basis for atomistic machine learning}", 48 | journal = {The Journal of Chemical Physics}, 49 | volume = {157}, 50 | number = {23}, 51 | pages = {234101}, 52 | year = {2022}, 53 | month = {12}, 54 | issn = {0021-9606}, 55 | doi = {10.1063/5.0124363}, 56 | url = {https://doi.org/10.1063/5.0124363}, 57 | } 58 | 59 | @article{pozdnyakov_smooth_2023, 60 | title = {Smooth, Exact Rotational Symmetrization for Deep Learning on Point Clouds}, 61 | author = {Pozdnyakov, Sergey N. and Ceriotti, Michele}, 62 | year = {2023}, 63 | month = may, 64 | journal = {arXiv.org}, 65 | urldate = {2025-01-24}, 66 | howpublished = {https://arxiv.org/abs/2305.19302v3}, 67 | langid = {english} 68 | } 69 | 70 | @article{bartok_gaussian_2010, 71 | title = {Gaussian {{Approximation Potentials}}: {{The Accuracy}} of {{Quantum Mechanics}}, without the {{Electrons}}}, 72 | shorttitle = {Gaussian {{Approximation Potentials}}}, 73 | author = {Bart{\'o}k, Albert P. and Payne, Mike C. and Kondor, Risi and Cs{\'a}nyi, G{\'a}bor}, 74 | year = {2010}, 75 | month = apr, 76 | journal = {Phys. Rev. Lett.}, 77 | volume = {104}, 78 | number = {13}, 79 | pages = {136403}, 80 | publisher = {American Physical Society}, 81 | doi = {10.1103/PhysRevLett.104.136403}, 82 | urldate = {2023-08-03} 83 | } 84 | -------------------------------------------------------------------------------- /examples/README.rst: -------------------------------------------------------------------------------- 1 | Metatrain Examples 2 | ================== 3 | 4 | This folder consists of introductory and advanced examples. 
5 | -------------------------------------------------------------------------------- /examples/ase/README.rst: -------------------------------------------------------------------------------- 1 | Running molecular dynamics with ASE 2 | =================================== 3 | -------------------------------------------------------------------------------- /examples/ase/options.yaml: -------------------------------------------------------------------------------- 1 | seed: 42 2 | 3 | architecture: 4 | name: gap 5 | 6 | # training set section 7 | training_set: 8 | systems: "ethanol_reduced_100.xyz" 9 | targets: 10 | energy: 11 | key: "energy" 12 | unit: "kcal/mol" # very important to run simulations 13 | 14 | validation_set: 0.1 15 | test_set: 0.0 16 | -------------------------------------------------------------------------------- /examples/ase/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mtt train options.yaml 4 | -------------------------------------------------------------------------------- /examples/basic_usage/README.rst: -------------------------------------------------------------------------------- 1 | Basic usage of the metatrain CLI 2 | ================================ 3 | -------------------------------------------------------------------------------- /examples/basic_usage/eval.yaml: -------------------------------------------------------------------------------- 1 | ../../docs/static/qm9/eval.yaml -------------------------------------------------------------------------------- /examples/basic_usage/options.yaml: -------------------------------------------------------------------------------- 1 | ../../docs/static/qm9/options.yaml -------------------------------------------------------------------------------- /examples/basic_usage/qm9_reduced_100.xyz: -------------------------------------------------------------------------------- 1 | ../../docs/static/qm9/qm9_reduced_100.xyz -------------------------------------------------------------------------------- /examples/multi-gpu/soap-bpnn/options-distributed.yaml: -------------------------------------------------------------------------------- 1 | ../../../tests/distributed/options-distributed.yaml -------------------------------------------------------------------------------- /examples/multi-gpu/soap-bpnn/submit-distributed.sh: -------------------------------------------------------------------------------- 1 | ../../../tests/distributed/submit-distributed.sh -------------------------------------------------------------------------------- /examples/programmatic/disk_dataset/README.rst: -------------------------------------------------------------------------------- 1 | Saving a disk dataset 2 | ===================== 3 | -------------------------------------------------------------------------------- /examples/programmatic/disk_dataset/disk_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Saving a disk dataset 3 | ===================== 4 | 5 | Large datasets may not fit into memory. In such cases, it is useful to save the 6 | dataset to disk and load it on the fly during training. This example demonstrates 7 | how to save a ``DiskDataset`` for this purpose. Metatrain will then be able to load 8 | ``DiskDataset`` objects saved in this way to execute on-the-fly data loading.
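Each sample stored in the zip file contains the system, its attached neighbor lists, and its targets, so none of these need to be recomputed at training time.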
9 | """ 10 | 11 | # %% 12 | # 13 | 14 | import ase.io 15 | import torch 16 | from metatensor.torch import Labels, TensorBlock, TensorMap 17 | from metatomic.torch import NeighborListOptions, systems_to_torch 18 | 19 | from metatrain.utils.data import DiskDatasetWriter 20 | from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists 21 | 22 | 23 | # %% 24 | # 25 | # As an example, we will use 100 structures from the QM9 dataset. In addition to the 26 | # systems and targets (here the energy), we also need to save the neighbor lists that 27 | # the model will use during training. 28 | 29 | disk_dataset_writer = DiskDatasetWriter("qm9_reduced_100.zip") 30 | for i in range(100): 31 | frame = ase.io.read("qm9_reduced_100.xyz", index=i) 32 | system = systems_to_torch(frame, dtype=torch.float64) 33 | system = get_system_with_neighbor_lists( 34 | system, 35 | [NeighborListOptions(cutoff=5.0, full_list=True, strict=True)], 36 | ) 37 | energy = TensorMap( 38 | keys=Labels.single(), 39 | blocks=[ 40 | TensorBlock( 41 | values=torch.tensor([[frame.info["U0"]]], dtype=torch.float64), 42 | samples=Labels( 43 | names=["system"], 44 | values=torch.tensor([[i]]), 45 | ), 46 | components=[], 47 | properties=Labels("energy", torch.tensor([[0]])), 48 | ) 49 | ], 50 | ) 51 | disk_dataset_writer.write_sample(system, {"energy": energy}) 52 | del disk_dataset_writer # not necessary if the file ends here, but good in general 53 | 54 | # %% 55 | # 56 | # The dataset is saved to disk. You can now provide it to ``metatrain`` as a 57 | # dataset to train from, simply by replacing your ``.xyz`` file with the newly created 58 | # zip file (e.g. ``read_from: qm9_reduced_100.zip``). 59 | -------------------------------------------------------------------------------- /examples/programmatic/disk_dataset/qm9_reduced_100.xyz: -------------------------------------------------------------------------------- 1 | ../../../tests/resources/qm9_reduced_100.xyz -------------------------------------------------------------------------------- /examples/programmatic/llpr/README.rst: -------------------------------------------------------------------------------- 1 | Computing LLPR uncertainties 2 | ============================ 3 | -------------------------------------------------------------------------------- /examples/programmatic/llpr/ethanol_reduced_100.xyz: -------------------------------------------------------------------------------- 1 | ../../ase/ethanol_reduced_100.xyz -------------------------------------------------------------------------------- /examples/programmatic/llpr/options.yaml: -------------------------------------------------------------------------------- 1 | device: cpu 2 | base_precision: 64 3 | seed: 42 4 | 5 | architecture: 6 | name: soap_bpnn 7 | training: 8 | batch_size: 16 9 | num_epochs: 10 10 | learning_rate: 0.01 11 | 12 | # Section defining the parameters for system and target data 13 | training_set: 14 | systems: "qm9_reduced_100.xyz" 15 | targets: 16 | energy: 17 | key: "U0" 18 | unit: "hartree" # very important to run simulations 19 | 20 | validation_set: 0.1 21 | test_set: 0.0 22 | -------------------------------------------------------------------------------- /examples/programmatic/llpr/qm9_reduced_100.xyz: -------------------------------------------------------------------------------- 1 | ../../../tests/resources/qm9_reduced_100.xyz -------------------------------------------------------------------------------- /examples/programmatic/llpr/train.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mtt train options.yaml 4 | -------------------------------------------------------------------------------- /examples/programmatic/llpr_forces/ethanol_reduced_100.xyz: -------------------------------------------------------------------------------- 1 | ../../ase/ethanol_reduced_100.xyz -------------------------------------------------------------------------------- /examples/programmatic/llpr_forces/options.yaml: -------------------------------------------------------------------------------- 1 | seed: 42 2 | 3 | architecture: 4 | name: soap_bpnn 5 | training: 6 | batch_size: 8 7 | num_epochs: 100 8 | log_interval: 1 9 | 10 | training_set: 11 | systems: 12 | read_from: train.xyz 13 | length_unit: angstrom 14 | targets: 15 | energy: 16 | key: energy 17 | unit: eV 18 | 19 | validation_set: 20 | systems: 21 | read_from: valid.xyz 22 | length_unit: angstrom 23 | targets: 24 | energy: 25 | key: energy 26 | unit: eV 27 | 28 | test_set: 29 | systems: 30 | read_from: test.xyz 31 | length_unit: angstrom 32 | targets: 33 | energy: 34 | key: energy 35 | unit: eV 36 | -------------------------------------------------------------------------------- /examples/programmatic/llpr_forces/readme.txt: -------------------------------------------------------------------------------- 1 | This is a small example of how to calculate force uncertainties with the LLPR. 2 | In order to run it, it is sufficient to split the ethanol dataset with `python split.py`. 3 | Then train a model with `mtt train options.yaml`, and finally run the example 4 | with `python force_llpr.py`. 5 | -------------------------------------------------------------------------------- /examples/programmatic/llpr_forces/split.py: -------------------------------------------------------------------------------- 1 | import ase.io 2 | import numpy as np 3 | 4 | 5 | structures = ase.io.read("ethanol_reduced_100.xyz", ":") 6 | np.random.shuffle(structures) 7 | train = structures[:50] 8 | valid = structures[50:60] 9 | test = structures[60:] 10 | 11 | ase.io.write("train.xyz", train) 12 | ase.io.write("valid.xyz", valid) 13 | ase.io.write("test.xyz", test) 14 | -------------------------------------------------------------------------------- /examples/programmatic/use_architectures_outside/README.rst: -------------------------------------------------------------------------------- 1 | Using metatrain architectures outside of metatrain 2 | ================================================== 3 | -------------------------------------------------------------------------------- /examples/programmatic/use_architectures_outside/qm9_reduced_100.xyz: -------------------------------------------------------------------------------- 1 | ../../../tests/resources/qm9_reduced_100.xyz -------------------------------------------------------------------------------- /examples/programmatic/use_architectures_outside/use_outside.py: -------------------------------------------------------------------------------- 1 | """ 2 | Using metatrain architectures outside of metatrain 3 | ================================================== 4 | 5 | This tutorial demonstrates how to use one of metatrain's implemented architectures 6 | outside of metatrain. This will be done by taking internal representations of a 7 | NanoPET model (as an example) and using them inside a user-defined torch ``Module``. 
8 | 9 | Only architectures which can output internal representations ("features" output) can 10 | be used in this way. 11 | """ 12 | 13 | # %% 14 | # 15 | 16 | import torch 17 | from metatomic.torch import ModelOutput 18 | 19 | from metatrain.experimental.nanopet import NanoPET 20 | from metatrain.utils.architectures import get_default_hypers 21 | from metatrain.utils.data import DatasetInfo, read_systems 22 | from metatrain.utils.neighbor_lists import ( 23 | get_requested_neighbor_lists, 24 | get_system_with_neighbor_lists, 25 | ) 26 | 27 | 28 | # %% 29 | # 30 | # Read some sample systems. Metatrain always reads systems in float64, while torch 31 | # uses float32 by default. We will convert the systems to float32. 32 | 33 | systems = read_systems("qm9_reduced_100.xyz") 34 | systems = [s.to(torch.float32) for s in systems] 35 | 36 | 37 | # %% 38 | # 39 | # Define the custom model using the NanoPET architecture as a building block. 40 | # The dummy architecture here adds a linear layer and a tanh activation function 41 | # on top of the NanoPET model. 42 | 43 | 44 | class NanoPETWithTanh(torch.nn.Module): 45 | def __init__(self): 46 | super().__init__() 47 | self.nanopet = NanoPET( 48 | get_default_hypers("experimental.nanopet")["model"], 49 | DatasetInfo( 50 | length_unit="angstrom", 51 | atomic_types=[1, 6, 7, 8, 9], 52 | targets={}, 53 | ), 54 | ) 55 | self.linear = torch.nn.Linear(128, 1) 56 | self.tanh = torch.nn.Tanh() 57 | 58 | def forward(self, systems): 59 | model_outputs = self.nanopet( 60 | systems, 61 | {"features": ModelOutput()}, 62 | # ModelOutput(per_atom=True) would give per-atom features 63 | ) 64 | features = model_outputs["features"].block().values 65 | return self.tanh(self.linear(features)) 66 | 67 | 68 | # %% 69 | # 70 | # Now we can train the custom model. Here is one training step executed with 71 | # some random targets. 
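# (the 100 systems read above each get one scalar target here, hence the (100, 1) shape below; the values are random placeholders for illustration)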
72 | my_targets = torch.randn(100, 1) 73 | 74 | # instantiate the model 75 | model = NanoPETWithTanh() 76 | 77 | # all metatrain models require neighbor lists to be present in the input systems 78 | systems = [ 79 | get_system_with_neighbor_lists(sys, get_requested_neighbor_lists(model)) 80 | for sys in systems 81 | ] 82 | 83 | # define an optimizer 84 | optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) 85 | 86 | # this is one training step 87 | predictions = model(systems) 88 | loss = torch.nn.functional.mse_loss(predictions, my_targets) 89 | loss.backward() 90 | optimizer.step() 91 | -------------------------------------------------------------------------------- /examples/zbl/README.rst: -------------------------------------------------------------------------------- 1 | Training a model with ZBL corrections 2 | ===================================== 3 | -------------------------------------------------------------------------------- /examples/zbl/ethanol_reduced_100.xyz: -------------------------------------------------------------------------------- 1 | ../ase/ethanol_reduced_100.xyz -------------------------------------------------------------------------------- /examples/zbl/options_no_zbl.yaml: -------------------------------------------------------------------------------- 1 | seed: 42 2 | 3 | architecture: 4 | name: soap_bpnn 5 | model: 6 | zbl: false 7 | training: 8 | num_epochs: 10 9 | 10 | # training set section 11 | training_set: 12 | systems: 13 | read_from: ethanol_reduced_100.xyz 14 | length_unit: angstrom 15 | targets: 16 | energy: 17 | key: "energy" 18 | unit: "eV" # very important to run simulations 19 | 20 | validation_set: 0.1 21 | test_set: 0.0 22 | -------------------------------------------------------------------------------- /examples/zbl/options_zbl.yaml: -------------------------------------------------------------------------------- 1 | seed: 42 2 | 3 | architecture: 4 | name: soap_bpnn 5 | model: 6 | zbl: true 7 | training: 8 | num_epochs: 10 9 | 10 | # training set section 11 | training_set: 12 | systems: 13 | read_from: ethanol_reduced_100.xyz 14 | length_unit: angstrom 15 | targets: 16 | energy: 17 | key: "energy" 18 | unit: "eV" # very important to run simulations 19 | 20 | validation_set: 0.1 21 | test_set: 0.0 22 | -------------------------------------------------------------------------------- /examples/zbl/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mtt train options_no_zbl.yaml -o model_no_zbl.pt 4 | mtt train options_zbl.yaml -o model_zbl.pt 5 | -------------------------------------------------------------------------------- /src/metatrain/__init__.py: -------------------------------------------------------------------------------- 1 | import secrets 2 | from pathlib import Path 3 | 4 | from ._version import __version__ # noqa: F401 5 | 6 | 7 | PACKAGE_ROOT = Path(__file__).parent.resolve() 8 | 9 | 10 | # A constant used as a "session" variable to set the random seed to a fixed value that does 11 | # not change within the execution of the program.
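# (secrets.randbelow(2**32) draws a cryptographically random integer in [0, 2**32), i.e. a valid 32-bit seed)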
12 | RANDOM_SEED = secrets.randbelow(2**32) 13 | -------------------------------------------------------------------------------- /src/metatrain/cli/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metatensor/metatrain/9a0c351ae227781e8481bfc50fe6bd463048e417/src/metatrain/cli/__init__.py -------------------------------------------------------------------------------- /src/metatrain/cli/formatter.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | class CustomHelpFormatter(argparse.RawDescriptionHelpFormatter): 5 | """Descriptions formatter showing positional arguments before optionals.""" 6 | 7 | def _format_usage(self, usage, actions, groups, prefix): 8 | if usage is None: 9 | # split optionals from positionals 10 | optionals = [] 11 | positionals = [] 12 | for action in actions: 13 | if action.option_strings: 14 | optionals.append(action) 15 | else: 16 | positionals.append(action) 17 | 18 | prog = "%(prog)s" % dict(prog=self._prog) 19 | 20 | # build full usage string 21 | format = self._format_actions_usage 22 | action_usage = format(positionals + optionals, groups) 23 | usage = " ".join([s for s in [prog, action_usage] if s]) 24 | 25 | # Call the superclass method to format the usage 26 | return super()._format_usage(usage, actions, groups, prefix) 27 | -------------------------------------------------------------------------------- /src/metatrain/deprecated/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metatensor/metatrain/9a0c351ae227781e8481bfc50fe6bd463048e417/src/metatrain/deprecated/__init__.py -------------------------------------------------------------------------------- /src/metatrain/deprecated/pet/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import PET 2 | from .trainer import Trainer 3 | 4 | 5 | __model__ = PET 6 | __trainer__ = Trainer 7 | __capabilities__ = { 8 | "supported_devices": __model__.__supported_devices__, 9 | "supported_dtypes": __model__.__supported_dtypes__, 10 | } 11 | 12 | __authors__ = [ 13 | ("Sergey Pozdnyakov ", "@spozdn"), 14 | ("Arslan Mazitov ", "@abmazitov"), 15 | ("Filippo Bigi ", "@frostedoyster"), 16 | ] 17 | 18 | __maintainers__ = [ 19 | ("Arslan Mazitov ", "@abmazitov"), 20 | ] 21 | -------------------------------------------------------------------------------- /src/metatrain/deprecated/pet/default-hypers.yaml: -------------------------------------------------------------------------------- 1 | architecture: 2 | 3 | name: deprecated.pet 4 | 5 | model: 6 | CUTOFF_DELTA: 0.2 7 | AVERAGE_POOLING: False 8 | TRANSFORMERS_CENTRAL_SPECIFIC: False 9 | HEADS_CENTRAL_SPECIFIC: False 10 | ADD_TOKEN_FIRST: True 11 | ADD_TOKEN_SECOND: True 12 | N_GNN_LAYERS: 3 13 | TRANSFORMER_D_MODEL: 128 14 | TRANSFORMER_N_HEAD: 4 15 | TRANSFORMER_DIM_FEEDFORWARD: 512 16 | HEAD_N_NEURONS: 128 17 | N_TRANS_LAYERS: 3 18 | ACTIVATION: silu 19 | USE_LENGTH: True 20 | USE_ONLY_LENGTH: False 21 | R_CUT: 5.0 22 | R_EMBEDDING_ACTIVATION: False 23 | COMPRESS_MODE: mlp 24 | BLEND_NEIGHBOR_SPECIES: False 25 | AVERAGE_BOND_ENERGIES: False 26 | USE_BOND_ENERGIES: True 27 | USE_ADDITIONAL_SCALAR_ATTRIBUTES: False 28 | SCALAR_ATTRIBUTES_SIZE: null 29 | TRANSFORMER_TYPE: PostLN # PostLN or PreLN 30 | USE_LONG_RANGE: False 31 | K_CUT: null # should be float; only used when USE_LONG_RANGE is True 32 | 
K_CUT_DELTA: null 33 | DTYPE: float32 # float32 or float16 or bfloat16 34 | N_TARGETS: 1 35 | TARGET_INDEX_KEY: target_index 36 | RESIDUAL_FACTOR: 0.5 37 | USE_ZBL: False 38 | 39 | training: 40 | USE_LORA_PEFT: False 41 | LORA_RANK: 4 42 | LORA_ALPHA: 0.5 43 | INITIAL_LR: 1e-4 44 | EPOCH_NUM_ATOMIC: 1000000000 45 | EPOCHS_WARMUP_ATOMIC: 100000000 46 | SCHEDULER_STEP_SIZE_ATOMIC: 500000000 # structural version is called "SCHEDULER_STEP_SIZE" 47 | GLOBAL_AUG: True 48 | SLIDING_FACTOR: 0.7 49 | ATOMIC_BATCH_SIZE: 850 # structural version is called "STRUCTURAL_BATCH_SIZE" 50 | BALANCED_DATA_LOADER: False # if True, use DynamicBatchSampler from torch_geometric 51 | MAX_TIME: 234000 52 | ENERGY_WEIGHT: 0.1 # only used when fitting MLIP 53 | MULTI_GPU: False 54 | RANDOM_SEED: 0 55 | CUDA_DETERMINISTIC: False 56 | MODEL_TO_START_WITH: null 57 | ALL_SPECIES_PATH: null 58 | SELF_CONTRIBUTIONS_PATH: null 59 | SUPPORT_MISSING_VALUES: False 60 | USE_WEIGHT_DECAY: False 61 | WEIGHT_DECAY: 0.0 62 | DO_GRADIENT_CLIPPING: False 63 | GRADIENT_CLIPPING_MAX_NORM: null # must be overwritten if DO_GRADIENT_CLIPPING is True 64 | USE_SHIFT_AGNOSTIC_LOSS: False # only used when fitting general target. Primary use case: EDOS 65 | ENERGIES_LOSS: per_structure # per_structure or per_atom 66 | CHECKPOINT_INTERVAL: 100 67 | -------------------------------------------------------------------------------- /src/metatrain/deprecated/pet/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metatensor/metatrain/9a0c351ae227781e8481bfc50fe6bd463048e417/src/metatrain/deprecated/pet/modules/__init__.py -------------------------------------------------------------------------------- /src/metatrain/deprecated/pet/modules/analysis.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | 5 | 6 | def get_structural_batch_size(structures, atomic_batch_size): 7 | sizes = [len(structure.get_positions()) for structure in structures] 8 | average_size = np.mean(sizes) 9 | return math.ceil(atomic_batch_size / average_size) 10 | 11 | 12 | def convert_atomic_throughput(train_structures, atomic_throughput): 13 | sizes = [len(structure.get_positions()) for structure in train_structures] 14 | total_size = np.sum(sizes) 15 | return math.ceil(atomic_throughput / total_size) 16 | 17 | 18 | def adapt_hypers(hypers, train_structures): 19 | if "STRUCTURAL_BATCH_SIZE" not in hypers.__dict__.keys(): 20 | hypers.STRUCTURAL_BATCH_SIZE = get_structural_batch_size( 21 | train_structures, hypers.ATOMIC_BATCH_SIZE 22 | ) 23 | 24 | if "EPOCH_NUM" not in hypers.__dict__.keys(): 25 | hypers.EPOCH_NUM = convert_atomic_throughput( 26 | train_structures, hypers.EPOCH_NUM_ATOMIC 27 | ) 28 | 29 | if "SCHEDULER_STEP_SIZE" not in hypers.__dict__.keys(): 30 | hypers.SCHEDULER_STEP_SIZE = convert_atomic_throughput( 31 | train_structures, hypers.SCHEDULER_STEP_SIZE_ATOMIC 32 | ) 33 | 34 | if "EPOCHS_WARMUP" not in hypers.__dict__.keys(): 35 | hypers.EPOCHS_WARMUP = convert_atomic_throughput( 36 | train_structures, hypers.EPOCHS_WARMUP_ATOMIC 37 | ) 38 | -------------------------------------------------------------------------------- /src/metatrain/deprecated/pet/tests/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | 4 | DATASET_PATH = str(Path(__file__).parents[5] / "tests/resources/carbon_reduced_100.xyz") 5 | 
-------------------------------------------------------------------------------- /src/metatrain/deprecated/pet/tests/test_exported.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from metatomic.torch import ( 4 | ModelCapabilities, 5 | ModelEvaluationOptions, 6 | ModelMetadata, 7 | ModelOutput, 8 | System, 9 | ) 10 | 11 | from metatrain.deprecated.pet import PET as WrappedPET 12 | from metatrain.deprecated.pet.modules.hypers import Hypers 13 | from metatrain.deprecated.pet.modules.pet import PET 14 | from metatrain.utils.architectures import get_default_hypers 15 | from metatrain.utils.data import DatasetInfo 16 | from metatrain.utils.data.target_info import get_energy_target_info 17 | from metatrain.utils.neighbor_lists import ( 18 | get_requested_neighbor_lists, 19 | get_system_with_neighbor_lists, 20 | ) 21 | 22 | 23 | DEFAULT_HYPERS = get_default_hypers("deprecated.pet") 24 | 25 | 26 | @pytest.mark.parametrize("device", ["cpu", "cuda"]) 27 | def test_to(device): 28 | """Tests that the `.to()` method of the exported model works.""" 29 | if device == "cuda" and not torch.cuda.is_available(): 30 | pytest.skip("CUDA is not available") 31 | 32 | dtype = torch.float32 # for now 33 | dataset_info = DatasetInfo( 34 | length_unit="Angstrom", 35 | atomic_types=[1, 6, 7, 8], 36 | targets={"energy": get_energy_target_info({"unit": "eV"})}, 37 | ) 38 | model = WrappedPET(DEFAULT_HYPERS["model"], dataset_info) 39 | ARCHITECTURAL_HYPERS = Hypers(model.hypers) 40 | raw_pet = PET(ARCHITECTURAL_HYPERS, 0.0, len(model.atomic_types)) 41 | model.set_trained_model(raw_pet) 42 | 43 | capabilities = ModelCapabilities( 44 | length_unit="Angstrom", 45 | atomic_types=model.atomic_types, 46 | outputs={ 47 | "energy": ModelOutput( 48 | quantity="energy", 49 | unit="eV", 50 | ) 51 | }, 52 | interaction_range=DEFAULT_HYPERS["model"]["N_GNN_LAYERS"] 53 | * DEFAULT_HYPERS["model"]["R_CUT"], 54 | dtype="float32", 55 | supported_devices=["cpu", "cuda"], 56 | ) 57 | 58 | exported = model.export(metadata=ModelMetadata(name="test")) 59 | 60 | # test correct metadata 61 | assert "This is the test model" in str(exported.metadata()) 62 | 63 | exported.to(device=device, dtype=dtype) 64 | 65 | system = System( 66 | types=torch.tensor([6, 6]), 67 | positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), 68 | cell=torch.zeros(3, 3), 69 | pbc=torch.tensor([False, False, False]), 70 | ) 71 | requested_neighbor_lists = get_requested_neighbor_lists(exported) 72 | system = get_system_with_neighbor_lists(system, requested_neighbor_lists) 73 | system = system.to(device=device, dtype=dtype) 74 | 75 | evaluation_options = ModelEvaluationOptions( 76 | length_unit=dataset_info.length_unit, 77 | outputs=capabilities.outputs, 78 | ) 79 | 80 | exported([system], evaluation_options, check_consistency=True) 81 | -------------------------------------------------------------------------------- /src/metatrain/deprecated/pet/tests/test_torchscript.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | 5 | from metatrain.deprecated.pet import PET as WrappedPET 6 | from metatrain.deprecated.pet.modules.hypers import Hypers 7 | from metatrain.deprecated.pet.modules.pet import PET 8 | from metatrain.utils.architectures import get_default_hypers 9 | from metatrain.utils.data import DatasetInfo 10 | from metatrain.utils.data.target_info import get_energy_target_info 11 | 12 | 13 | DEFAULT_HYPERS = 
get_default_hypers("deprecated.pet") 14 | 15 | 16 | def test_torchscript(): 17 | """Tests that the model can be jitted.""" 18 | 19 | dataset_info = DatasetInfo( 20 | length_unit="Angstrom", 21 | atomic_types=[1, 6, 7, 8], 22 | targets={"energy": get_energy_target_info({"unit": "eV"})}, 23 | ) 24 | model = WrappedPET(DEFAULT_HYPERS["model"], dataset_info) 25 | ARCHITECTURAL_HYPERS = Hypers(model.hypers) 26 | raw_pet = PET(ARCHITECTURAL_HYPERS, 0.0, len(model.atomic_types)) 27 | model.set_trained_model(raw_pet) 28 | torch.jit.script(model) 29 | 30 | 31 | def test_torchscript_save_load(tmpdir): 32 | """Tests that the model can be jitted and saved.""" 33 | 34 | dataset_info = DatasetInfo( 35 | length_unit="Angstrom", 36 | atomic_types=[1, 6, 7, 8], 37 | targets={"energy": get_energy_target_info({"unit": "eV"})}, 38 | ) 39 | model = WrappedPET(DEFAULT_HYPERS["model"], dataset_info) 40 | ARCHITECTURAL_HYPERS = Hypers(model.hypers) 41 | raw_pet = PET(ARCHITECTURAL_HYPERS, 0.0, len(model.atomic_types)) 42 | model.set_trained_model(raw_pet) 43 | torch.jit.script(model) 44 | with tmpdir.as_cwd(): 45 | torch.jit.save(torch.jit.script(model), "pet.pt") 46 | torch.jit.load("pet.pt") 47 | 48 | 49 | def test_torchscript_integers(): 50 | """Tests that the model can be jitted when some float 51 | parameters are instead supplied as integers.""" 52 | 53 | new_hypers = copy.deepcopy(DEFAULT_HYPERS["model"]) 54 | new_hypers["R_CUT"] = 5 55 | new_hypers["CUTOFF_DELTA"] = 1 56 | new_hypers["RESIDUAL_FACTOR"] = 1 57 | 58 | dataset_info = DatasetInfo( 59 | length_unit="Angstrom", 60 | atomic_types=[1, 6, 7, 8], 61 | targets={"energy": get_energy_target_info({"unit": "eV"})}, 62 | ) 63 | model = WrappedPET(new_hypers, dataset_info) 64 | ARCHITECTURAL_HYPERS = Hypers(model.hypers) 65 | raw_pet = PET(ARCHITECTURAL_HYPERS, 0.0, len(model.atomic_types)) 66 | model.set_trained_model(raw_pet) 67 | torch.jit.script(model) 68 | -------------------------------------------------------------------------------- /src/metatrain/deprecated/pet/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset_to_ase import dataset_to_ase 2 | from .load_raw_pet_model import load_raw_pet_model 3 | from .systems_to_batch_dict import systems_to_batch_dict 4 | from .update_hypers import update_hypers 5 | from .update_state_dict import update_state_dict 6 | 7 | 8 | __all__ = [ 9 | "systems_to_batch_dict", 10 | "dataset_to_ase", 11 | "update_hypers", 12 | "update_state_dict", 13 | "load_raw_pet_model", 14 | ] 15 | -------------------------------------------------------------------------------- /src/metatrain/deprecated/pet/utils/dataset_to_ase.py: -------------------------------------------------------------------------------- 1 | from metatensor.learn.data import DataLoader 2 | 3 | from ....utils.additive import remove_additive 4 | from ....utils.data import collate_fn 5 | from ....utils.data.system_to_ase import system_to_ase 6 | from ....utils.neighbor_lists import ( 7 | get_requested_neighbor_lists, 8 | get_system_with_neighbor_lists, 9 | ) 10 | 11 | 12 | # dummy dataloaders due to https://github.com/metatensor/metatensor/issues/521 13 | def dataset_to_ase(dataset, model, do_forces=True, target_name="energy"): 14 | dataloader = DataLoader( 15 | dataset, 16 | batch_size=1, 17 | shuffle=False, 18 | collate_fn=collate_fn, 19 | ) 20 | ase_dataset = [] 21 | for (system,), targets in dataloader: 22 | # remove additive model (e.g. 
ZBL) contributions 23 | requested_neighbor_lists = get_requested_neighbor_lists(model) 24 | system = get_system_with_neighbor_lists(system, requested_neighbor_lists) 25 | for additive_model in model.additive_models: 26 | targets = remove_additive( 27 | [system], targets, additive_model, model.dataset_info.targets 28 | ) 29 | # transform to ase atoms 30 | ase_atoms = system_to_ase(system) 31 | ase_atoms.info["energy"] = float( 32 | targets[target_name].block().values.squeeze(-1).detach().cpu().numpy() 33 | ) 34 | if do_forces: 35 | ase_atoms.arrays["forces"] = ( 36 | -targets[target_name] 37 | .block() 38 | .gradient("positions") 39 | .values.squeeze(-1) 40 | .detach() 41 | .cpu() 42 | .numpy() 43 | ) 44 | ase_dataset.append(ase_atoms) 45 | return ase_dataset 46 | -------------------------------------------------------------------------------- /src/metatrain/deprecated/pet/utils/fine_tuning.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | 3 | import torch 4 | 5 | 6 | class LoRALayer(torch.nn.Module): 7 | def __init__(self, hidden_dim: int, rank: int): 8 | super(LoRALayer, self).__init__() 9 | self.hidden_dim = hidden_dim 10 | self.rank = rank 11 | self.A = torch.nn.Parameter(torch.randn(hidden_dim, rank)) 12 | self.B = torch.nn.Parameter(torch.randn(rank, hidden_dim)) 13 | self.reset_parameters() 14 | 15 | def reset_parameters(self): 16 | torch.nn.init.zeros_(self.A) 17 | torch.nn.init.xavier_normal_(self.B) 18 | 19 | def forward(self, x: torch.Tensor) -> torch.Tensor: 20 | return x @ self.A @ self.B 21 | 22 | 23 | class AttentionBlockWithLoRA(torch.nn.Module): 24 | def __init__(self, original_module: torch.nn.Module, rank: int, alpha: float): 25 | super(AttentionBlockWithLoRA, self).__init__() 26 | self.original_module = original_module 27 | self.rank = rank 28 | self.alpha = alpha 29 | self.hidden_dim = original_module.output_linear.out_features 30 | self.lora = LoRALayer(self.hidden_dim, self.rank) 31 | 32 | def forward( 33 | self, x: torch.Tensor, multipliers: Optional[torch.Tensor] = None 34 | ) -> torch.Tensor: 35 | return self.original_module(x, multipliers) + self.alpha * self.lora(x) 36 | 37 | 38 | class LoRAWrapper(torch.nn.Module): 39 | def __init__(self, model: torch.nn.Module, rank: int, alpha: float): 40 | super(LoRAWrapper, self).__init__() 41 | self.model = model 42 | self.hypers = model.hypers 43 | self.rank = rank 44 | self.alpha = alpha 45 | self.hidden_dim = model.hypers.TRANSFORMER_D_MODEL 46 | self.num_hidden_layers = model.hypers.N_GNN_LAYERS * model.hypers.N_TRANS_LAYERS 47 | for param in model.parameters(): 48 | param.requires_grad = False 49 | for gnn_layer in model.gnn_layers: 50 | for trans_layer in gnn_layer.trans.layers: 51 | trans_layer.attention = AttentionBlockWithLoRA( 52 | trans_layer.attention, self.rank, self.alpha 53 | ) 54 | 55 | def forward( 56 | self, 57 | batch_dict: Dict[str, torch.Tensor], 58 | rotations: Optional[torch.Tensor] = None, 59 | ): 60 | return self.model(batch_dict, rotations) 61 | -------------------------------------------------------------------------------- /src/metatrain/deprecated/pet/utils/load_raw_pet_model.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from ..modules.hypers import Hypers 7 | from ..modules.pet import PET, SelfContributionsWrapper 8 | from .fine_tuning import LoRAWrapper 9 | from .update_state_dict import 
update_state_dict 10 | 11 | 12 | def load_raw_pet_model( 13 | state_dict: Dict, 14 | hypers: Dict, 15 | atomic_types: List, 16 | self_contributions: np.ndarray, 17 | **kwargs, 18 | ) -> "SelfContributionsWrapper": 19 | """Creates a raw PET model instance.""" 20 | 21 | ARCHITECTURAL_HYPERS = Hypers(hypers) 22 | 23 | ARCHITECTURAL_HYPERS.D_OUTPUT = 1 # type: ignore 24 | ARCHITECTURAL_HYPERS.TARGET_AGGREGATION = "sum" # type: ignore 25 | ARCHITECTURAL_HYPERS.TARGET_TYPE = "atomic" # type: ignore 26 | 27 | raw_pet = PET(ARCHITECTURAL_HYPERS, 0.0, len(atomic_types)) 28 | if "use_lora_peft" in kwargs and kwargs["use_lora_peft"] is True: 29 | lora_rank = kwargs["lora_rank"] 30 | lora_alpha = kwargs["lora_alpha"] 31 | raw_pet = LoRAWrapper(raw_pet, lora_rank, lora_alpha) 32 | 33 | new_state_dict = update_state_dict(state_dict) 34 | dtype = next(iter(new_state_dict.values())).dtype 35 | raw_pet.to(dtype).load_state_dict(new_state_dict) 36 | if isinstance(self_contributions, torch.Tensor): 37 | self_contributions = self_contributions.cpu().numpy() 38 | wrapper = SelfContributionsWrapper(raw_pet, self_contributions) 39 | 40 | return wrapper 41 | -------------------------------------------------------------------------------- /src/metatrain/deprecated/pet/utils/update_hypers.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | 4 | def update_hypers( 5 | hypers: Dict[str, Any], model_hypers: Dict[str, Any], do_forces: bool = True 6 | ): 7 | """ 8 | Updates the hypers dictionary with the model hypers, the 9 | MLIP_SETTINGS and UTILITY_FLAGS keys of the PET model. 10 | """ 11 | 12 | # set model hypers 13 | hypers = hypers.copy() 14 | hypers["ARCHITECTURAL_HYPERS"] = model_hypers 15 | hypers["ARCHITECTURAL_HYPERS"]["DTYPE"] = "float32" 16 | hypers["ARCHITECTURAL_HYPERS"]["D_OUTPUT"] = 1 17 | hypers["ARCHITECTURAL_HYPERS"]["TARGET_TYPE"] = "structural" 18 | hypers["ARCHITECTURAL_HYPERS"]["TARGET_AGGREGATION"] = "sum" 19 | 20 | # set MLIP_SETTINGS 21 | hypers["MLIP_SETTINGS"] = { 22 | "ENERGY_KEY": "energy", 23 | "FORCES_KEY": "forces", 24 | "USE_ENERGIES": True, 25 | "USE_FORCES": do_forces, 26 | } 27 | 28 | # set PET utility flags 29 | hypers["UTILITY_FLAGS"] = { 30 | "CALCULATION_TYPE": None, 31 | } 32 | return hypers 33 | -------------------------------------------------------------------------------- /src/metatrain/deprecated/pet/utils/update_state_dict.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | 4 | def update_state_dict(state_dict: Dict) -> Dict: 5 | """ 6 | Updates the state_dict keys so they match the model's keys. 7 | """ 8 | new_state_dict = {} 9 | for name, value in state_dict.items(): 10 | if "pet_model." 
in name: 11 | name = name.split("pet_model.")[1] 12 | new_state_dict[name] = value 13 | return new_state_dict 14 | -------------------------------------------------------------------------------- /src/metatrain/experimental/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metatensor/metatrain/9a0c351ae227781e8481bfc50fe6bd463048e417/src/metatrain/experimental/__init__.py -------------------------------------------------------------------------------- /src/metatrain/experimental/nanopet/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import NanoPET 2 | from .trainer import Trainer 3 | 4 | 5 | __model__ = NanoPET 6 | __trainer__ = Trainer 7 | 8 | __authors__ = [ 9 | ("Filippo Bigi ", "@frostedoyster"), 10 | ] 11 | 12 | __maintainers__ = [ 13 | ("Filippo Bigi ", "@frostedoyster"), 14 | ] 15 | -------------------------------------------------------------------------------- /src/metatrain/experimental/nanopet/default-hypers.yaml: -------------------------------------------------------------------------------- 1 | architecture: 2 | 3 | name: experimental.nanopet 4 | 5 | model: 6 | cutoff: 5.0 7 | cutoff_width: 0.5 8 | d_pet: 128 9 | num_heads: 4 10 | num_attention_layers: 2 11 | num_gnn_layers: 2 12 | heads: {} 13 | zbl: False 14 | long_range: 15 | enable: false 16 | use_ewald: false 17 | smearing: 1.4 18 | kspace_resolution: 1.33 19 | interpolation_nodes: 5 20 | 21 | training: 22 | distributed: False 23 | distributed_port: 39591 24 | batch_size: 16 25 | num_epochs: 10000 26 | learning_rate: 3e-4 27 | scheduler_patience: 100 28 | scheduler_factor: 0.8 29 | log_interval: 10 30 | checkpoint_interval: 100 31 | scale_targets: true 32 | fixed_composition_weights: {} 33 | per_structure_targets: [] 34 | log_mae: False 35 | log_separate_blocks: false 36 | best_model_metric: rmse_prod 37 | loss: 38 | type: mse 39 | weights: {} 40 | reduction: mean 41 | -------------------------------------------------------------------------------- /src/metatrain/experimental/nanopet/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metatensor/metatrain/9a0c351ae227781e8481bfc50fe6bd463048e417/src/metatrain/experimental/nanopet/modules/__init__.py -------------------------------------------------------------------------------- /src/metatrain/experimental/nanopet/modules/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class AttentionBlock(torch.nn.Module): 5 | """ 6 | A single transformer attention block. We are not using the 7 | MultiHeadAttention module from torch.nn because we need to apply a 8 | radial mask to the attention weights. 
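The mask (see ``radial_mask.py`` in this module) smoothly down-weights neighbors approaching the cutoff, so the attention output stays continuous as atoms enter or leave the neighborhood.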
9 | """ 10 | 11 | def __init__( 12 | self, 13 | hidden_size: int, 14 | num_heads: int, 15 | dropout_rate: float, 16 | attention_dropout_rate: float, 17 | ): 18 | super().__init__() 19 | 20 | self.num_heads = num_heads 21 | self.in_proj = torch.nn.Linear(hidden_size, 3 * hidden_size, bias=False) 22 | self.out_proj = torch.nn.Linear(hidden_size, hidden_size, bias=False) 23 | self.layernorm = torch.nn.LayerNorm(normalized_shape=hidden_size) 24 | self.attention_dropout_rate = attention_dropout_rate 25 | 26 | def forward( 27 | self, 28 | inputs: torch.Tensor, # seq_len hidden_size 29 | radial_mask: torch.Tensor, # seq_len 30 | ) -> torch.Tensor: # seq_len hidden_size 31 | # Pre-layer normalization 32 | normed_inputs = self.layernorm(inputs) 33 | 34 | # Input projection 35 | qkv = self.in_proj(normed_inputs) 36 | q, k, v = torch.chunk(qkv, 3, dim=-1) 37 | # Split heads 38 | q = q.reshape(q.size(0), q.size(1), self.num_heads, q.size(2) // self.num_heads) 39 | k = k.reshape(k.size(0), k.size(1), self.num_heads, k.size(2) // self.num_heads) 40 | v = v.reshape(v.size(0), v.size(1), self.num_heads, v.size(2) // self.num_heads) 41 | q = q.transpose(1, 2) 42 | k = k.transpose(1, 2) 43 | v = v.transpose(1, 2) 44 | # Attention 45 | attention_weights = torch.matmul(q, k.transpose(-2, -1)) / (k.size(-1) ** 0.5) 46 | attention_weights = attention_weights.softmax(dim=-1) 47 | attention_weights = torch.nn.functional.dropout( 48 | attention_weights, p=self.attention_dropout_rate, training=self.training 49 | ) 50 | 51 | # Radial mask 52 | attention_weights = attention_weights * radial_mask[:, None, None, :] 53 | attention_weights = attention_weights / ( 54 | attention_weights.sum(dim=-1, keepdim=True) + 1e-6 55 | ) 56 | 57 | # Attention output 58 | attention_output = torch.matmul(attention_weights, v) 59 | attention_output = attention_output.transpose(1, 2) 60 | attention_output = attention_output.reshape( 61 | attention_output.size(0), 62 | attention_output.size(1), 63 | attention_output.size(2) * attention_output.size(3), 64 | ) 65 | 66 | # Output projection 67 | outputs = self.out_proj(attention_output) 68 | 69 | # Residual connection 70 | outputs = (outputs + inputs) * 0.5**0.5 71 | 72 | return outputs 73 | -------------------------------------------------------------------------------- /src/metatrain/experimental/nanopet/modules/encoder.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | 5 | 6 | class Encoder(torch.nn.Module): 7 | """ 8 | An encoder of edges. It generates a fixed-size representation of the 9 | interatomic vector, the chemical element of the center and the chemical 10 | element of the neighbor. The representations are then concatenated and 11 | compressed to the initial fixed size. 
12 | """ 13 | 14 | def __init__( 15 | self, 16 | n_species: int, 17 | hidden_size: int, 18 | ): 19 | super().__init__() 20 | 21 | self.cartesian_encoder = torch.nn.Sequential( 22 | torch.nn.Linear(in_features=3, out_features=4 * hidden_size, bias=False), 23 | torch.nn.SiLU(), 24 | torch.nn.Linear( 25 | in_features=4 * hidden_size, out_features=4 * hidden_size, bias=False 26 | ), 27 | torch.nn.SiLU(), 28 | torch.nn.Linear( 29 | in_features=4 * hidden_size, out_features=hidden_size, bias=False 30 | ), 31 | ) 32 | self.center_encoder = torch.nn.Embedding( 33 | num_embeddings=n_species, embedding_dim=hidden_size 34 | ) 35 | self.neighbor_encoder = torch.nn.Embedding( 36 | num_embeddings=n_species, embedding_dim=hidden_size 37 | ) 38 | self.compressor = torch.nn.Linear( 39 | in_features=3 * hidden_size, out_features=hidden_size, bias=False 40 | ) 41 | 42 | def forward( 43 | self, 44 | features: Dict[str, torch.Tensor], 45 | ): 46 | # Encode cartesian coordinates 47 | cartesian_features = self.cartesian_encoder(features["cartesian"]) 48 | 49 | # Encode centers 50 | center_features = self.center_encoder(features["center"]) 51 | 52 | # Encode neighbors 53 | neighbor_features = self.neighbor_encoder(features["neighbor"]) 54 | 55 | # Concatenate 56 | encoded_features = torch.concatenate( 57 | [cartesian_features, center_features, neighbor_features], dim=-1 58 | ) 59 | 60 | # Compress 61 | compressed_features = self.compressor(encoded_features) 62 | 63 | return compressed_features 64 | -------------------------------------------------------------------------------- /src/metatrain/experimental/nanopet/modules/feedforward.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class FeedForwardBlock(torch.nn.Module): 5 | """A single transformer feed forward block.""" 6 | 7 | def __init__( 8 | self, 9 | hidden_size: int, 10 | intermediate_size: int, 11 | dropout_rate: float, 12 | ): 13 | super().__init__() 14 | 15 | self.mlp = torch.nn.Linear( 16 | in_features=hidden_size, out_features=intermediate_size, bias=False 17 | ) 18 | self.output = torch.nn.Linear( 19 | in_features=intermediate_size, out_features=hidden_size, bias=False 20 | ) 21 | 22 | self.layernorm = torch.nn.LayerNorm(normalized_shape=hidden_size) 23 | self.dropout = torch.nn.Dropout(dropout_rate) 24 | 25 | def forward( 26 | self, 27 | inputs: torch.Tensor, # hidden_size 28 | ) -> torch.Tensor: # hidden_size 29 | # Pre-layer normalization 30 | normed_inputs = self.layernorm(inputs) 31 | 32 | # Feed-forward 33 | hidden = self.mlp(normed_inputs) 34 | hidden = torch.nn.functional.gelu(hidden) 35 | 36 | # Project back to input size 37 | outputs = self.output(hidden) 38 | 39 | # Apply dropout 40 | outputs = self.dropout(outputs) 41 | 42 | # Residual connection 43 | outputs = (outputs + inputs) * 0.5**0.5 44 | 45 | return outputs 46 | -------------------------------------------------------------------------------- /src/metatrain/experimental/nanopet/modules/radial_mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_radial_mask(r, r_cut: float, r_transition: float): 5 | # All radii are already guaranteed to be smaller than r_cut 6 | return torch.where( 7 | r < r_transition, 8 | torch.ones_like(r), 9 | 0.5 * (torch.cos(torch.pi * (r - r_transition) / (r_cut - r_transition)) + 1.0), 10 | ) 11 | -------------------------------------------------------------------------------- 
/src/metatrain/experimental/nanopet/modules/structures.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | from metatomic.torch import NeighborListOptions, System 5 | 6 | 7 | def concatenate_structures( 8 | systems: List[System], neighbor_list_options: NeighborListOptions 9 | ): 10 | positions = [] 11 | centers = [] 12 | neighbors = [] 13 | species = [] 14 | cell_shifts = [] 15 | cells = [] 16 | node_counter = 0 17 | 18 | for system in systems: 19 | positions.append(system.positions) 20 | species.append(system.types) 21 | 22 | assert len(system.known_neighbor_lists()) >= 1, "no neighbor list found" 23 | neighbor_list = system.get_neighbor_list(neighbor_list_options) 24 | nl_values = neighbor_list.samples.values 25 | 26 | centers.append(nl_values[:, 0] + node_counter) 27 | neighbors.append(nl_values[:, 1] + node_counter) 28 | cell_shifts.append(nl_values[:, 2:]) 29 | 30 | cells.append(system.cell) 31 | 32 | node_counter += len(system.positions) 33 | 34 | positions = torch.cat(positions) 35 | centers = torch.cat(centers) 36 | neighbors = torch.cat(neighbors) 37 | species = torch.cat(species) 38 | cells = torch.stack(cells) 39 | cell_shifts = torch.cat(cell_shifts) 40 | 41 | return ( 42 | positions, 43 | centers, 44 | neighbors, 45 | species, 46 | cells, 47 | cell_shifts, 48 | ) 49 | -------------------------------------------------------------------------------- /src/metatrain/experimental/nanopet/modules/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .attention import AttentionBlock 4 | from .feedforward import FeedForwardBlock 5 | 6 | 7 | class TransformerLayer(torch.nn.Module): 8 | """A single transformer layer.""" 9 | 10 | def __init__( 11 | self, 12 | hidden_size: int, 13 | intermediate_size: int, 14 | num_heads: int, 15 | dropout_rate: float, 16 | attention_dropout_rate: float, 17 | ): 18 | super().__init__() 19 | 20 | self.attention_block = AttentionBlock( 21 | hidden_size=hidden_size, 22 | num_heads=num_heads, 23 | dropout_rate=dropout_rate, 24 | attention_dropout_rate=attention_dropout_rate, 25 | ) 26 | self.ff_block = FeedForwardBlock( 27 | hidden_size=hidden_size, 28 | intermediate_size=intermediate_size, 29 | dropout_rate=dropout_rate, 30 | ) 31 | 32 | def forward( 33 | self, 34 | inputs: torch.Tensor, 35 | radial_mask: torch.Tensor, 36 | ) -> torch.Tensor: 37 | attention_output = self.attention_block(inputs, radial_mask) 38 | output = self.ff_block(attention_output) 39 | 40 | return output 41 | 42 | 43 | class Transformer(torch.nn.Module): 44 | """A transformer model.""" 45 | 46 | def __init__( 47 | self, 48 | hidden_size: int, 49 | intermediate_size: int, 50 | num_heads: int, 51 | num_layers: int, 52 | dropout_rate: float, 53 | attention_dropout_rate: float, 54 | ): 55 | super().__init__() 56 | 57 | self.layers = torch.nn.ModuleList( 58 | [ 59 | TransformerLayer( 60 | hidden_size=hidden_size, 61 | intermediate_size=intermediate_size, 62 | num_heads=num_heads, 63 | dropout_rate=dropout_rate, 64 | attention_dropout_rate=attention_dropout_rate, 65 | ) 66 | for _ in range(num_layers) 67 | ] 68 | ) 69 | 70 | def forward( 71 | self, 72 | inputs, 73 | radial_mask, 74 | ): 75 | x = inputs 76 | for layer in self.layers: 77 | x = layer(x, radial_mask) 78 | return x 79 | -------------------------------------------------------------------------------- /src/metatrain/experimental/nanopet/tests/__init__.py: 
-------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from metatrain.utils.architectures import get_default_hypers 4 | 5 | 6 | DATASET_PATH = str(Path(__file__).parents[5] / "tests/resources/qm9_reduced_100.xyz") 7 | 8 | DEFAULT_HYPERS = get_default_hypers("experimental.nanopet") 9 | MODEL_HYPERS = DEFAULT_HYPERS["model"] 10 | -------------------------------------------------------------------------------- /src/metatrain/experimental/nanopet/tests/test_continue.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | 3 | import metatensor 4 | import torch 5 | from omegaconf import OmegaConf 6 | 7 | from metatrain.experimental.nanopet import NanoPET, Trainer 8 | from metatrain.utils.data import Dataset, DatasetInfo 9 | from metatrain.utils.data.readers import read_systems, read_targets 10 | from metatrain.utils.data.target_info import get_energy_target_info 11 | from metatrain.utils.io import model_from_checkpoint 12 | from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists 13 | 14 | from . import DATASET_PATH, DEFAULT_HYPERS, MODEL_HYPERS 15 | 16 | 17 | def test_continue(monkeypatch, tmp_path): 18 | """Tests that a model can be checkpointed and loaded 19 | for a continuation of the training process""" 20 | 21 | monkeypatch.chdir(tmp_path) 22 | shutil.copy(DATASET_PATH, "qm9_reduced_100.xyz") 23 | 24 | systems = read_systems(DATASET_PATH) 25 | systems = [system.to(torch.float32) for system in systems] 26 | 27 | target_info_dict = {} 28 | target_info_dict["mtt::U0"] = get_energy_target_info( 29 | {"quantity": "energy", "unit": "eV"} 30 | ) 31 | 32 | dataset_info = DatasetInfo( 33 | length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict 34 | ) 35 | model = NanoPET(MODEL_HYPERS, dataset_info) 36 | 37 | conf = { 38 | "mtt::U0": { 39 | "quantity": "energy", 40 | "read_from": DATASET_PATH, 41 | "reader": "ase", 42 | "key": "U0", 43 | "unit": "eV", 44 | "type": "scalar", 45 | "per_atom": False, 46 | "num_subtargets": 1, 47 | "forces": False, 48 | "stress": False, 49 | "virial": False, 50 | } 51 | } 52 | targets, _ = read_targets(OmegaConf.create(conf)) 53 | 54 | # systems in float64 are required for training 55 | systems = [system.to(torch.float64) for system in systems] 56 | 57 | dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]}) 58 | 59 | hypers = DEFAULT_HYPERS.copy() 60 | hypers["training"]["num_epochs"] = 0 61 | trainer = Trainer(hypers["training"]) 62 | trainer.train( 63 | model=model, 64 | dtype=torch.float32, 65 | devices=[torch.device("cpu")], 66 | train_datasets=[dataset], 67 | val_datasets=[dataset], 68 | checkpoint_dir=".", 69 | ) 70 | 71 | trainer.save_checkpoint(model, "tmp.ckpt") 72 | 73 | model_after = model_from_checkpoint("tmp.ckpt", context="restart") 74 | assert isinstance(model_after, NanoPET) 75 | model_after.restart(dataset_info) 76 | 77 | hypers["training"]["num_epochs"] = 0 78 | trainer = Trainer(hypers["training"]) 79 | trainer.train( 80 | model=model_after, 81 | dtype=torch.float32, 82 | devices=[torch.device("cpu")], 83 | train_datasets=[dataset], 84 | val_datasets=[dataset], 85 | checkpoint_dir=".", 86 | ) 87 | 88 | # evaluation 89 | systems = [system.to(torch.float32) for system in systems] 90 | for system in systems: 91 | get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) 92 | 93 | model.eval() 94 | model_after.eval() 95 | 96 | # Predict on the first five systems 97 
| output_before = model(systems[:5], {"mtt::U0": model.outputs["mtt::U0"]}) 98 | output_after = model_after(systems[:5], {"mtt::U0": model_after.outputs["mtt::U0"]}) 99 | 100 | assert metatensor.torch.allclose(output_before["mtt::U0"], output_after["mtt::U0"]) 101 | -------------------------------------------------------------------------------- /src/metatrain/experimental/nanopet/tests/test_exported.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from metatomic.torch import ModelEvaluationOptions, ModelMetadata, System 4 | 5 | from metatrain.experimental.nanopet import NanoPET 6 | from metatrain.utils.data import DatasetInfo 7 | from metatrain.utils.data.target_info import get_energy_target_info 8 | from metatrain.utils.neighbor_lists import ( 9 | get_requested_neighbor_lists, 10 | get_system_with_neighbor_lists, 11 | ) 12 | 13 | from . import MODEL_HYPERS 14 | 15 | 16 | @pytest.mark.parametrize("device", ["cpu", "cuda"]) 17 | @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) 18 | def test_to(device, dtype): 19 | """Tests that the `.to()` method of the exported model works.""" 20 | if device == "cuda" and not torch.cuda.is_available(): 21 | pytest.skip("CUDA is not available") 22 | 23 | dataset_info = DatasetInfo( 24 | length_unit="Angstrom", 25 | atomic_types=[1, 6, 7, 8], 26 | targets={ 27 | "energy": get_energy_target_info({"quantity": "energy", "unit": "eV"}) 28 | }, 29 | ) 30 | model = NanoPET(MODEL_HYPERS, dataset_info).to(dtype=dtype) 31 | 32 | exported = model.export(metadata=ModelMetadata(name="test")) 33 | 34 | # test correct metadata 35 | assert "This is the test model" in str(exported.metadata()) 36 | 37 | exported.to(device=device) 38 | 39 | system = System( 40 | types=torch.tensor([6, 6]), 41 | positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), 42 | cell=torch.zeros(3, 3), 43 | pbc=torch.tensor([False, False, False]), 44 | ) 45 | requested_neighbor_lists = get_requested_neighbor_lists(exported) 46 | system = get_system_with_neighbor_lists(system, requested_neighbor_lists) 47 | system = system.to(device=device, dtype=dtype) 48 | 49 | evaluation_options = ModelEvaluationOptions( 50 | length_unit=dataset_info.length_unit, 51 | outputs=model.outputs, 52 | ) 53 | 54 | exported([system], evaluation_options, check_consistency=True) 55 | -------------------------------------------------------------------------------- /src/metatrain/experimental/nanopet/tests/test_torchscript.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | from metatomic.torch import System 5 | 6 | from metatrain.experimental.nanopet import NanoPET 7 | from metatrain.utils.data import DatasetInfo 8 | from metatrain.utils.data.target_info import get_energy_target_info 9 | from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists 10 | 11 | from . 
import MODEL_HYPERS 12 | 13 | 14 | def test_torchscript(): 15 | """Tests that the model can be jitted.""" 16 | 17 | dataset_info = DatasetInfo( 18 | length_unit="Angstrom", 19 | atomic_types=[1, 6, 7, 8], 20 | targets={ 21 | "energy": get_energy_target_info({"quantity": "energy", "unit": "eV"}) 22 | }, 23 | ) 24 | model = NanoPET(MODEL_HYPERS, dataset_info) 25 | 26 | system = System( 27 | types=torch.tensor([6, 1, 8, 7]), 28 | positions=torch.tensor( 29 | [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0], [0.0, 0.0, 3.0]] 30 | ), 31 | cell=torch.zeros(3, 3), 32 | pbc=torch.tensor([False, False, False]), 33 | ) 34 | system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) 35 | 36 | model = torch.jit.script(model) 37 | model( 38 | [system], 39 | {"energy": model.outputs["energy"]}, 40 | ) 41 | 42 | 43 | def test_torchscript_save_load(tmpdir): 44 | """Tests that the model can be jitted and saved.""" 45 | 46 | dataset_info = DatasetInfo( 47 | length_unit="Angstrom", 48 | atomic_types=[1, 6, 7, 8], 49 | targets={ 50 | "energy": get_energy_target_info({"quantity": "energy", "unit": "eV"}) 51 | }, 52 | ) 53 | model = NanoPET(MODEL_HYPERS, dataset_info) 54 | 55 | with tmpdir.as_cwd(): 56 | torch.jit.save(torch.jit.script(model), "model.pt") 57 | torch.jit.load("model.pt") 58 | 59 | 60 | def test_torchscript_integers(): 61 | """Tests that the model can be jitted when some float 62 | parameters are instead supplied as integers.""" 63 | 64 | new_hypers = copy.deepcopy(MODEL_HYPERS) 65 | new_hypers["cutoff"] = 5 66 | new_hypers["cutoff_width"] = 1 67 | 68 | dataset_info = DatasetInfo( 69 | length_unit="Angstrom", 70 | atomic_types=[1, 6, 7, 8], 71 | targets={ 72 | "energy": get_energy_target_info({"quantity": "energy", "unit": "eV"}) 73 | }, 74 | ) 75 | model = NanoPET(new_hypers, dataset_info) 76 | 77 | system = System( 78 | types=torch.tensor([6, 1, 8, 7]), 79 | positions=torch.tensor( 80 | [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0], [0.0, 0.0, 3.0]] 81 | ), 82 | cell=torch.zeros(3, 3), 83 | pbc=torch.tensor([False, False, False]), 84 | ) 85 | system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) 86 | 87 | model = torch.jit.script(model) 88 | model( 89 | [system], 90 | {"energy": model.outputs["energy"]}, 91 | ) 92 | -------------------------------------------------------------------------------- /src/metatrain/gap/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import GAP 2 | from .trainer import Trainer 3 | 4 | 5 | __model__ = GAP 6 | __trainer__ = Trainer 7 | 8 | __authors__ = [ 9 | ("Alexander Goscinski ", "@agosckinski"), 10 | ("Davide Tisi ", "@DavideTisi"), 11 | ] 12 | 13 | __maintainers__ = [ 14 | ("Davide Tisi ", "@DavideTisi"), 15 | ] 16 | -------------------------------------------------------------------------------- /src/metatrain/gap/default-hypers.yaml: -------------------------------------------------------------------------------- 1 | architecture: 2 | name: gap 3 | 4 | model: 5 | soap: 6 | cutoff: 7 | radius: 5.0 8 | smoothing: 9 | type: ShiftedCosine 10 | width: 1.0 11 | density: 12 | type: Gaussian 13 | center_atom_weight: 1.0 14 | width: 0.3 15 | scaling: 16 | type: Willatt2018 17 | rate: 1.0 18 | scale: 2.0 19 | exponent: 7.0 20 | basis: 21 | type: TensorProduct 22 | max_angular: 6 23 | radial: 24 | type: Gto 25 | max_radial: 7 26 | krr: 27 | degree: 2 28 | num_sparse_points: 500 29 | zbl: false 30 | 31 | training: 32 | regularizer: 0.001 33 | 
regularizer_forces: null 34 | -------------------------------------------------------------------------------- /src/metatrain/gap/tests/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from metatrain.utils.architectures import get_default_hypers 4 | 5 | 6 | DEFAULT_HYPERS = get_default_hypers("gap") 7 | DATASET_PATH = str(Path(__file__).parents[4] / "tests/resources/qm9_reduced_100.xyz") 8 | 9 | DATASET_ETHANOL_PATH = str( 10 | Path(__file__).parents[4] / "tests/resources/ethanol_reduced_100.xyz" 11 | ) 12 | -------------------------------------------------------------------------------- /src/metatrain/gap/tests/ethanol_reduced_100.xyz: -------------------------------------------------------------------------------- 1 | ../../../../examples/ase/ethanol_reduced_100.xyz -------------------------------------------------------------------------------- /src/metatrain/gap/tests/options-gap.yaml: -------------------------------------------------------------------------------- 1 | architecture: 2 | name: gap 3 | 4 | model: 5 | soap: 6 | cutoff: 7 | radius: 5.5 8 | smoothing: 9 | type: ShiftedCosine 10 | width: 1.0 11 | density: 12 | type: Gaussian 13 | center_atom_weight: 1.0 14 | width: 0.2 15 | scaling: 16 | type: Willatt2018 17 | rate: 1.0 18 | scale: 2.0 19 | exponent: 7.0 20 | basis: 21 | type: TensorProduct 22 | max_angular: 6 23 | radial: 24 | type: Gto 25 | max_radial: 7 # now exclusive 26 | krr: 27 | degree: 2 28 | num_sparse_points: 900 29 | zbl: false 30 | 31 | training: 32 | regularizer: 0.00005 33 | regularizer_forces: 0.001 34 | 35 | training_set: 36 | systems: "ethanol_reduced_100.xyz" # file where the positions are stored 37 | targets: 38 | energy: 39 | key: "energy" # name of the target value 40 | unit: "eV" # unit of the target value 41 | 42 | test_set: 43 | systems: "ethanol_reduced_100.xyz" # file where the positions are stored 44 | targets: 45 | energy: 46 | key: "energy" # name of the target value 47 | unit: "eV" # unit of the target value 48 | 49 | validation_set: 50 | systems: "ethanol_reduced_100.xyz" # file where the positions are stored 51 | targets: 52 | energy: 53 | key: "energy" # name of the target value 54 | unit: "eV" # unit of the target value 55 | -------------------------------------------------------------------------------- /src/metatrain/gap/tests/test_exported.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from metatomic.torch import ModelMetadata 3 | from omegaconf import OmegaConf 4 | 5 | from metatrain.gap import GAP, Trainer 6 | from metatrain.utils.data import Dataset, DatasetInfo 7 | from metatrain.utils.data.readers import read_systems, read_targets 8 | from metatrain.utils.data.target_info import get_energy_target_info 9 | 10 | from . 
import DATASET_PATH, DEFAULT_HYPERS 11 | 12 | 13 | def test_export(): 14 | """Tests that export works with injected metadata""" 15 | 16 | systems = read_systems(DATASET_PATH) 17 | 18 | conf = { 19 | "energy": { 20 | "quantity": "energy", 21 | "read_from": DATASET_PATH, 22 | "reader": "ase", 23 | "key": "U0", 24 | "unit": "kcal/mol", 25 | "type": "scalar", 26 | "per_atom": False, 27 | "num_subtargets": 1, 28 | "forces": False, 29 | "stress": False, 30 | "virial": False, 31 | } 32 | } 33 | targets, _ = read_targets(OmegaConf.create(conf)) 34 | dataset = Dataset.from_dict({"system": systems, "energy": targets["energy"]}) 35 | 36 | target_info_dict = {} 37 | target_info_dict["energy"] = get_energy_target_info({"unit": "eV"}) 38 | 39 | dataset_info = DatasetInfo( 40 | length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict 41 | ) 42 | 43 | dataset_info = DatasetInfo( 44 | length_unit="Angstrom", 45 | atomic_types=[1, 6, 7, 8], 46 | targets={ 47 | "energy": get_energy_target_info({"quantity": "energy", "unit": "eV"}) 48 | }, 49 | ) 50 | model = GAP(DEFAULT_HYPERS["model"], dataset_info) 51 | 52 | # we have to train gap before we can export... 53 | trainer = Trainer(DEFAULT_HYPERS["training"]) 54 | trainer.train( 55 | model=model, 56 | dtype=torch.float64, 57 | devices=[torch.device("cpu")], 58 | train_datasets=[dataset], 59 | val_datasets=[dataset], 60 | checkpoint_dir=".", 61 | ) 62 | 63 | exported = model.export(metadata=ModelMetadata(name="test")) 64 | 65 | # test correct metadata 66 | assert "This is the test model" in str(exported.metadata()) 67 | -------------------------------------------------------------------------------- /src/metatrain/pet/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import PET 2 | from .trainer import Trainer 3 | 4 | 5 | __model__ = PET 6 | __trainer__ = Trainer 7 | __capabilities__ = { 8 | "supported_devices": __model__.__supported_devices__, 9 | "supported_dtypes": __model__.__supported_dtypes__, 10 | } 11 | 12 | __authors__ = [ 13 | ("Sergey Pozdnyakov ", "@spozdn"), 14 | ("Arslan Mazitov ", "@abmazitov"), 15 | ("Filippo Bigi ", "@frostedoyster"), 16 | ] 17 | 18 | __maintainers__ = [ 19 | ("Arslan Mazitov ", "@abmazitov"), 20 | ("Filippo Bigi ", "@frostedoyster"), 21 | ] 22 | -------------------------------------------------------------------------------- /src/metatrain/pet/default-hypers.yaml: -------------------------------------------------------------------------------- 1 | architecture: 2 | 3 | name: pet 4 | 5 | model: 6 | cutoff: 4.5 7 | cutoff_width: 0.2 8 | d_pet: 128 9 | d_head: 128 10 | d_feedforward: 512 11 | num_heads: 8 12 | num_attention_layers: 2 13 | num_gnn_layers: 2 14 | zbl: false 15 | long_range: 16 | enable: false 17 | use_ewald: false 18 | smearing: 1.4 19 | kspace_resolution: 1.33 20 | interpolation_nodes: 5 21 | 22 | training: 23 | distributed: false 24 | distributed_port: 39591 25 | batch_size: 16 26 | num_epochs: 10000 27 | num_epochs_warmup: 100 28 | learning_rate: 1e-4 29 | weight_decay: null 30 | scheduler_patience: 250 31 | log_interval: 1 32 | checkpoint_interval: 100 33 | scale_targets: false 34 | fixed_composition_weights: {} 35 | per_structure_targets: [] 36 | log_mae: true 37 | log_separate_blocks: false 38 | best_model_metric: rmse_prod 39 | finetune: {} 40 | grad_clip_norm: .inf 41 | loss: 42 | type: mse 43 | weights: {} 44 | reduction: mean 45 | sliding_factor: null 46 | 
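These defaults are consumed programmatically through ``get_default_hypers`` (as the test helpers below also do); a minimal sketch:

from metatrain.utils.architectures import get_default_hypers

hypers = get_default_hypers("pet")
assert hypers["model"]["cutoff"] == 4.5  # values from the YAML above
assert hypers["training"]["batch_size"] == 16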
-------------------------------------------------------------------------------- /src/metatrain/pet/modules/utilities.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def cutoff_func(grid: torch.Tensor, r_cut: float, delta: float): 5 | mask_bigger = grid >= r_cut 6 | mask_smaller = grid <= r_cut - delta 7 | grid = (grid - r_cut + delta) / delta 8 | f = 0.5 + 0.5 * torch.cos(torch.pi * grid) 9 | 10 | f[mask_bigger] = 0.0 11 | f[mask_smaller] = 1.0 12 | return f 13 | 14 | 15 | class DummyModule(torch.nn.Module): 16 | """Dummy torch module to make torchscript happy. 17 | This model should never be run""" 18 | 19 | def __init__(self): 20 | super(DummyModule, self).__init__() 21 | 22 | def forward(self, x) -> torch.Tensor: 23 | raise RuntimeError("This model should never be run") 24 | -------------------------------------------------------------------------------- /src/metatrain/pet/tests/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from metatrain.utils.architectures import get_default_hypers 4 | 5 | 6 | DATASET_PATH = str(Path(__file__).parents[4] / "tests/resources/qm9_reduced_100.xyz") 7 | DATASET_WITH_FORCES_PATH = str( 8 | Path(__file__).parents[4] / "tests/resources/carbon_reduced_100.xyz" 9 | ) 10 | 11 | DEFAULT_HYPERS = get_default_hypers("pet") 12 | MODEL_HYPERS = DEFAULT_HYPERS["model"] 13 | -------------------------------------------------------------------------------- /src/metatrain/pet/tests/test_continue.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | 3 | import metatensor 4 | import torch 5 | from omegaconf import OmegaConf 6 | 7 | from metatrain.pet import PET, Trainer 8 | from metatrain.utils.data import Dataset, DatasetInfo 9 | from metatrain.utils.data.readers import read_systems, read_targets 10 | from metatrain.utils.data.target_info import get_energy_target_info 11 | from metatrain.utils.io import model_from_checkpoint 12 | from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists 13 | 14 | from . 
import DATASET_PATH, DEFAULT_HYPERS, MODEL_HYPERS 15 | 16 | 17 | def test_continue(monkeypatch, tmp_path): 18 | """Tests that a model can be checkpointed and loaded 19 | for a continuation of the training process""" 20 | 21 | monkeypatch.chdir(tmp_path) 22 | shutil.copy(DATASET_PATH, "qm9_reduced_100.xyz") 23 | 24 | systems = read_systems(DATASET_PATH) 25 | systems = [system.to(torch.float32) for system in systems] 26 | 27 | target_info_dict = {} 28 | target_info_dict["mtt::U0"] = get_energy_target_info( 29 | {"quantity": "energy", "unit": "eV"} 30 | ) 31 | 32 | dataset_info = DatasetInfo( 33 | length_unit="Angstrom", atomic_types=[1, 6, 7, 8], targets=target_info_dict 34 | ) 35 | model = PET(MODEL_HYPERS, dataset_info) 36 | 37 | conf = { 38 | "mtt::U0": { 39 | "quantity": "energy", 40 | "read_from": DATASET_PATH, 41 | "reader": "ase", 42 | "key": "U0", 43 | "unit": "eV", 44 | "type": "scalar", 45 | "per_atom": False, 46 | "num_subtargets": 1, 47 | "forces": False, 48 | "stress": False, 49 | "virial": False, 50 | } 51 | } 52 | targets, _ = read_targets(OmegaConf.create(conf)) 53 | 54 | # systems in float64 are required for training 55 | systems = [system.to(torch.float64) for system in systems] 56 | 57 | dataset = Dataset.from_dict({"system": systems, "mtt::U0": targets["mtt::U0"]}) 58 | 59 | hypers = DEFAULT_HYPERS.copy() 60 | hypers["training"]["num_epochs"] = 0 61 | trainer = Trainer(hypers["training"]) 62 | trainer.train( 63 | model=model, 64 | dtype=torch.float32, 65 | devices=[torch.device("cpu")], 66 | train_datasets=[dataset], 67 | val_datasets=[dataset], 68 | checkpoint_dir=".", 69 | ) 70 | 71 | trainer.save_checkpoint(model, "tmp.ckpt") 72 | 73 | model_after = model_from_checkpoint("tmp.ckpt", context="restart") 74 | assert isinstance(model_after, PET) 75 | model_after.restart(dataset_info) 76 | 77 | hypers["training"]["num_epochs"] = 0 78 | trainer = Trainer(hypers["training"]) 79 | trainer.train( 80 | model=model_after, 81 | dtype=torch.float32, 82 | devices=[torch.device("cpu")], 83 | train_datasets=[dataset], 84 | val_datasets=[dataset], 85 | checkpoint_dir=".", 86 | ) 87 | 88 | # evaluation 89 | systems = [system.to(torch.float32) for system in systems] 90 | for system in systems: 91 | get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) 92 | 93 | model.eval() 94 | model_after.eval() 95 | 96 | # Predict on the first five systems 97 | output_before = model(systems[:5], {"mtt::U0": model.outputs["mtt::U0"]}) 98 | output_after = model_after(systems[:5], {"mtt::U0": model_after.outputs["mtt::U0"]}) 99 | 100 | assert metatensor.torch.allclose(output_before["mtt::U0"], output_after["mtt::U0"]) 101 | -------------------------------------------------------------------------------- /src/metatrain/pet/tests/test_exported.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from metatomic.torch import ModelEvaluationOptions, ModelMetadata, System 4 | 5 | from metatrain.pet import PET 6 | from metatrain.utils.data import DatasetInfo 7 | from metatrain.utils.data.target_info import get_energy_target_info 8 | from metatrain.utils.neighbor_lists import ( 9 | get_requested_neighbor_lists, 10 | get_system_with_neighbor_lists, 11 | ) 12 | 13 | from . 
import MODEL_HYPERS 14 | 15 | 16 | @pytest.mark.parametrize("device", ["cpu", "cuda"]) 17 | @pytest.mark.parametrize("dtype", [torch.float32]) 18 | def test_to(device, dtype): 19 | """Tests that the `.to()` method of the exported model works.""" 20 | if device == "cuda" and not torch.cuda.is_available(): 21 | pytest.skip("CUDA is not available") 22 | 23 | dataset_info = DatasetInfo( 24 | length_unit="Angstrom", 25 | atomic_types=[1, 6, 7, 8], 26 | targets={ 27 | "energy": get_energy_target_info({"quantity": "energy", "unit": "eV"}) 28 | }, 29 | ) 30 | model = PET(MODEL_HYPERS, dataset_info).to(dtype=dtype) 31 | 32 | exported = model.export(metadata=ModelMetadata(name="test")) 33 | 34 | # test correct metadata 35 | assert "This is the test model" in str(exported.metadata()) 36 | 37 | exported.to(device=device) 38 | 39 | system = System( 40 | types=torch.tensor([6, 6]), 41 | positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), 42 | cell=torch.zeros(3, 3), 43 | pbc=torch.tensor([False, False, False]), 44 | ) 45 | requested_neighbor_lists = get_requested_neighbor_lists(exported) 46 | system = get_system_with_neighbor_lists(system, requested_neighbor_lists) 47 | system = system.to(device=device, dtype=dtype) 48 | 49 | evaluation_options = ModelEvaluationOptions( 50 | length_unit=dataset_info.length_unit, 51 | outputs=model.outputs, 52 | ) 53 | 54 | exported([system], evaluation_options, check_consistency=True) 55 | -------------------------------------------------------------------------------- /src/metatrain/pet/tests/test_torchscript.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | from metatomic.torch import System 5 | 6 | from metatrain.pet import PET 7 | from metatrain.utils.data import DatasetInfo 8 | from metatrain.utils.data.target_info import get_energy_target_info 9 | from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists 10 | 11 | from . 
import MODEL_HYPERS 12 | 13 | 14 | def test_torchscript(): 15 | """Tests that the model can be jitted.""" 16 | 17 | dataset_info = DatasetInfo( 18 | length_unit="Angstrom", 19 | atomic_types=[1, 6, 7, 8], 20 | targets={ 21 | "energy": get_energy_target_info({"quantity": "energy", "unit": "eV"}) 22 | }, 23 | ) 24 | model = PET(MODEL_HYPERS, dataset_info) 25 | system = System( 26 | types=torch.tensor([6, 1, 8, 7]), 27 | positions=torch.tensor( 28 | [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0], [0.0, 0.0, 3.0]] 29 | ), 30 | cell=torch.zeros(3, 3), 31 | pbc=torch.tensor([False, False, False]), 32 | ) 33 | system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) 34 | 35 | model = torch.jit.script(model) 36 | model( 37 | [system], 38 | {"energy": model.outputs["energy"]}, 39 | ) 40 | 41 | 42 | def test_torchscript_save_load(tmpdir): 43 | """Tests that the model can be jitted and saved.""" 44 | 45 | dataset_info = DatasetInfo( 46 | length_unit="Angstrom", 47 | atomic_types=[1, 6, 7, 8], 48 | targets={ 49 | "energy": get_energy_target_info({"quantity": "energy", "unit": "eV"}) 50 | }, 51 | ) 52 | model = PET(MODEL_HYPERS, dataset_info) 53 | 54 | with tmpdir.as_cwd(): 55 | torch.jit.save(torch.jit.script(model), "model.pt") 56 | torch.jit.load("model.pt") 57 | 58 | 59 | def test_torchscript_integers(): 60 | """Tests that the model can be jitted when some float 61 | parameters are instead supplied as integers.""" 62 | 63 | new_hypers = copy.deepcopy(MODEL_HYPERS) 64 | new_hypers["cutoff"] = 5 65 | new_hypers["cutoff_width"] = 1 66 | 67 | dataset_info = DatasetInfo( 68 | length_unit="Angstrom", 69 | atomic_types=[1, 6, 7, 8], 70 | targets={ 71 | "energy": get_energy_target_info({"quantity": "energy", "unit": "eV"}) 72 | }, 73 | ) 74 | model = PET(new_hypers, dataset_info) 75 | 76 | system = System( 77 | types=torch.tensor([6, 1, 8, 7]), 78 | positions=torch.tensor( 79 | [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 2.0], [0.0, 0.0, 3.0]] 80 | ), 81 | cell=torch.zeros(3, 3), 82 | pbc=torch.tensor([False, False, False]), 83 | ) 84 | system = get_system_with_neighbor_lists(system, model.requested_neighbor_lists()) 85 | 86 | model = torch.jit.script(model) 87 | model( 88 | [system], 89 | {"energy": model.outputs["energy"]}, 90 | ) 91 | -------------------------------------------------------------------------------- /src/metatrain/share/metatrain-completion.bash: -------------------------------------------------------------------------------- 1 | _metatrain() 2 | { 3 | local cur_word="${COMP_WORDS[$COMP_CWORD]}" 4 | local prev_word="${COMP_WORDS[$COMP_CWORD-1]}" 5 | local module="${COMP_WORDS[1]}" 6 | 7 | # Define supported file endings. 8 | local yaml='!*@(.yml|.yaml)' 9 | local ckpt='!*.ckpt' 10 | local pt='!*.pt' 11 | local architecture_names=$(python -c " 12 | from metatrain.utils.architectures import find_all_architectures 13 | print(' '.join(find_all_architectures())) 14 | ") 15 | 16 | # Complete the arguments to the module commands. 
17 | case "$module" in 18 | train) 19 | case "${prev_word}" in 20 | -h|--help|-o|--output|-r|--override) 21 | COMPREPLY=( ) 22 | return 0 23 | ;; 24 | --restart) 25 | COMPREPLY=( $( compgen -W "auto" -f -X "$ckpt" -- "${cur_word}") ) 26 | return 0 27 | ;; 28 | *) 29 | if [[ $COMP_CWORD -eq 2 ]]; then 30 | COMPREPLY=( $(compgen -f -X "$yaml" -- "${cur_word}") ) 31 | return 0 32 | fi 33 | ;; 34 | esac 35 | local opts="-h --help -o --output --restart -r --override" 36 | COMPREPLY=( $(compgen -W "${opts}" -- "${cur_word}") ) 37 | return 0 38 | ;; 39 | export) 40 | case "${prev_word}" in 41 | -h|--help|-o|--output) 42 | COMPREPLY=( ) 43 | return 0 44 | ;; 45 | *) 46 | if [[ $COMP_CWORD -eq 2 ]]; then 47 | # We don't have a generated list of known the architecture names 48 | COMPREPLY=( $(compgen -W "$architecture_names" -- "${cur_word}") ) 49 | return 0 50 | elif [[ $COMP_CWORD -eq 3 ]]; then 51 | COMPREPLY=( $(compgen -f -X "$ckpt" -- "${cur_word}") ) 52 | return 0 53 | fi 54 | ;; 55 | esac 56 | local opts="-h --help -o --output -m --metadata --token" 57 | COMPREPLY=( $(compgen -W "${opts}" -- "${cur_word}") ) 58 | return 0 59 | ;; 60 | eval) 61 | case "${prev_word}" in 62 | -h|--help|-o|--output|-b|--batch-size|--check-consistency) 63 | COMPREPLY=( ) 64 | return 0 65 | ;; 66 | -e|--extensions-dir) 67 | # Only complete directories 68 | COMPREPLY=( $(compgen -d -- "${cur_word}") ) 69 | return 0 70 | ;; 71 | *) 72 | if [[ $COMP_CWORD -eq 2 ]]; then 73 | COMPREPLY=( $(compgen -f -X "$pt" -- "${cur_word}") ) 74 | return 0 75 | elif [[ $COMP_CWORD -eq 3 ]]; then 76 | COMPREPLY=( $(compgen -f -X "$yaml" -- "${cur_word}") ) 77 | return 0 78 | fi 79 | ;; 80 | esac 81 | local opts="-h --help -o --output -b --batch-size -e --extensions-dir --check-consistency" 82 | COMPREPLY=( $(compgen -W "${opts}" -- "${cur_word}") ) 83 | return 0 84 | ;; 85 | esac 86 | 87 | # Complete the basic metatrain commands. 
88 | local opts="eval export train -h --help --debug --version" 89 | COMPREPLY=( $(compgen -W "${opts}" -- "${cur_word}") ) 90 | return 0 91 | } 92 | 93 | if test -n "$ZSH_VERSION"; then 94 | autoload -U +X compinit && compinit 95 | autoload -U +X bashcompinit && bashcompinit 96 | fi 97 | 98 | complete -o bashdefault -F _metatrain mtt 99 | -------------------------------------------------------------------------------- /src/metatrain/share/schema-base.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "http://json-schema.org/draft-07/schema#", 3 | "$defs": { 4 | "dataset_object": { 5 | "type": "object", 6 | "additionalProperties": true 7 | }, 8 | "dataset_array": { 9 | "type": "array", 10 | "items": { 11 | "type": "object", 12 | "additionalProperties": true 13 | } 14 | }, 15 | "dataset_string": { 16 | "type": "string", 17 | "format": "uri" 18 | } 19 | }, 20 | "type": "object", 21 | "properties": { 22 | "device": { 23 | "type": "string" 24 | }, 25 | "base_precision": { 26 | "type": "integer", 27 | "enum": [16, 32, 64] 28 | }, 29 | "seed": { 30 | "type": "integer", 31 | "minimum": 0 32 | }, 33 | "wandb": { 34 | "type": "object", 35 | "additionalProperties": true, 36 | "propertyNames": { 37 | "not": { 38 | "const": "config" 39 | } 40 | }, 41 | "allOf": [ 42 | { 43 | "if": { "required": ["entity"] }, 44 | "then": { 45 | "properties": { 46 | "entity": { "type": "string" } 47 | } 48 | } 49 | }, 50 | { 51 | "if": { "required": ["project"] }, 52 | "then": { 53 | "properties": { 54 | "project": { "type": "string" } 55 | } 56 | } 57 | }, 58 | { 59 | "if": { "required": ["name"] }, 60 | "then": { 61 | "properties": { 62 | "name": { "type": "string" } 63 | } 64 | } 65 | } 66 | ] 67 | }, 68 | "architecture": { 69 | "type": "object", 70 | "properties": { 71 | "name": { 72 | "type": "string" 73 | } 74 | }, 75 | "required": ["name"], 76 | "additionalProperties": true 77 | }, 78 | "training_set": { 79 | "oneOf": [ 80 | { 81 | "$ref": "#/$defs/dataset_object" 82 | }, 83 | { 84 | "$ref": "#/$defs/dataset_array" 85 | }, 86 | { 87 | "$ref": "#/$defs/dataset_string" 88 | } 89 | ] 90 | }, 91 | "test_set": { 92 | "oneOf": [ 93 | { 94 | "$ref": "#/$defs/dataset_object" 95 | }, 96 | { 97 | "$ref": "#/$defs/dataset_array" 98 | }, 99 | { 100 | "$ref": "#/$defs/dataset_string" 101 | }, 102 | { 103 | "type": "number", 104 | "minimum": 0, 105 | "exclusiveMaximum": 1 106 | } 107 | ] 108 | }, 109 | "validation_set": { 110 | "oneOf": [ 111 | { 112 | "$ref": "#/$defs/dataset_object" 113 | }, 114 | { 115 | "$ref": "#/$defs/dataset_array" 116 | }, 117 | { 118 | "$ref": "#/$defs/dataset_string" 119 | }, 120 | { 121 | "type": "number", 122 | "exclusiveMinimum": 0, 123 | "exclusiveMaximum": 1 124 | } 125 | ] 126 | } 127 | }, 128 | "required": ["architecture", "training_set", "test_set", "validation_set"], 129 | "additionalProperties": false 130 | } 131 | -------------------------------------------------------------------------------- /src/metatrain/soap_bpnn/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import SoapBpnn 2 | from .trainer import Trainer 3 | 4 | 5 | __model__ = SoapBpnn 6 | __trainer__ = Trainer 7 | 8 | __authors__ = [ 9 | ("Filippo Bigi ", "@frostedoyster"), 10 | ] 11 | 12 | __maintainers__ = [ 13 | ("Filippo Bigi ", "@frostedoyster"), 14 | ] 15 | -------------------------------------------------------------------------------- /src/metatrain/soap_bpnn/default-hypers.yaml: 
-------------------------------------------------------------------------------- 1 | architecture: 2 | name: soap_bpnn 3 | 4 | model: 5 | soap: 6 | max_angular: 6 7 | max_radial: 7 8 | cutoff: 9 | radius: 5.0 10 | width: 0.5 11 | bpnn: 12 | layernorm: true 13 | num_hidden_layers: 2 14 | num_neurons_per_layer: 32 15 | add_lambda_basis: true 16 | heads: {} 17 | zbl: false 18 | long_range: 19 | enable: false 20 | use_ewald: false 21 | smearing: 1.4 22 | kspace_resolution: 1.33 23 | interpolation_nodes: 5 24 | 25 | training: 26 | distributed: False 27 | distributed_port: 39591 28 | batch_size: 8 29 | num_epochs: 100 30 | learning_rate: 0.001 31 | early_stopping_patience: 200 32 | scheduler_patience: 100 33 | scheduler_factor: 0.8 34 | log_interval: 5 35 | checkpoint_interval: 25 36 | scale_targets: true 37 | fixed_composition_weights: {} 38 | per_structure_targets: [] 39 | log_mae: False 40 | log_separate_blocks: false 41 | best_model_metric: rmse_prod 42 | loss: 43 | type: mse 44 | weights: {} 45 | reduction: mean 46 | -------------------------------------------------------------------------------- /src/metatrain/soap_bpnn/tests/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from metatrain.utils.architectures import get_default_hypers 4 | 5 | 6 | DATASET_PATH = str(Path(__file__).parents[4] / "tests/resources/qm9_reduced_100.xyz") 7 | 8 | DEFAULT_HYPERS = get_default_hypers("soap_bpnn") 9 | MODEL_HYPERS = DEFAULT_HYPERS["model"] 10 | -------------------------------------------------------------------------------- /src/metatrain/soap_bpnn/tests/test_exported.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from metatomic.torch import ModelEvaluationOptions, ModelMetadata, System 4 | 5 | from metatrain.soap_bpnn import SoapBpnn 6 | from metatrain.utils.data import DatasetInfo 7 | from metatrain.utils.data.target_info import get_energy_target_info 8 | from metatrain.utils.neighbor_lists import ( 9 | get_requested_neighbor_lists, 10 | get_system_with_neighbor_lists, 11 | ) 12 | 13 | from . 
import MODEL_HYPERS 14 | 15 | 16 | @pytest.mark.parametrize("device", ["cpu", "cuda"]) 17 | @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) 18 | def test_to(device, dtype): 19 | """Tests that the `.to()` method of the exported model works.""" 20 | if device == "cuda" and not torch.cuda.is_available(): 21 | pytest.skip("CUDA is not available") 22 | 23 | dataset_info = DatasetInfo( 24 | length_unit="Angstrom", 25 | atomic_types=[1, 6, 7, 8], 26 | targets={"energy": get_energy_target_info({"unit": "eV"})}, 27 | ) 28 | model = SoapBpnn(MODEL_HYPERS, dataset_info).to(dtype=dtype) 29 | exported = model.export(metadata=ModelMetadata(name="test")) 30 | 31 | # test correct metadata 32 | assert "This is the test model" in str(exported.metadata()) 33 | 34 | exported.to(device=device) 35 | 36 | system = System( 37 | types=torch.tensor([6, 6]), 38 | positions=torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), 39 | cell=torch.zeros(3, 3), 40 | pbc=torch.tensor([False, False, False]), 41 | ) 42 | requested_neighbor_lists = get_requested_neighbor_lists(exported) 43 | system = get_system_with_neighbor_lists(system, requested_neighbor_lists) 44 | system = system.to(device=device, dtype=dtype) 45 | 46 | evaluation_options = ModelEvaluationOptions( 47 | length_unit=dataset_info.length_unit, 48 | outputs=model.outputs, 49 | ) 50 | 51 | exported([system], evaluation_options, check_consistency=True) 52 | -------------------------------------------------------------------------------- /src/metatrain/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metatensor/metatrain/9a0c351ae227781e8481bfc50fe6bd463048e417/src/metatrain/utils/__init__.py -------------------------------------------------------------------------------- /src/metatrain/utils/additive/__init__.py: -------------------------------------------------------------------------------- 1 | from .composition import CompositionModel # noqa: F401 2 | from .remove import remove_additive # noqa: F401 3 | from .zbl import ZBL # noqa: F401 4 | -------------------------------------------------------------------------------- /src/metatrain/utils/additive/remove.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Dict, List 3 | 4 | import metatensor.torch 5 | import torch 6 | from metatensor.torch import TensorMap 7 | from metatomic.torch import System 8 | 9 | from ..data import TargetInfo 10 | from ..evaluate_model import evaluate_model 11 | 12 | 13 | def remove_additive( 14 | systems: List[System], 15 | targets: Dict[str, TensorMap], 16 | additive_model: torch.nn.Module, 17 | target_info_dict: Dict[str, TargetInfo], 18 | ): 19 | """Remove an additive contribution from the training targets. 20 | 21 | :param systems: List of systems. 22 | :param targets: Dictionary containing the targets corresponding to the systems. 23 | :param additive_model: The model used to calculate the additive 24 | contribution to be removed. 25 | :param targets_dict: Dictionary containing information about the targets. 
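A typical call (sketch; ``composition_model`` is assumed to be an already fitted additive model such as ``CompositionModel``) is ``targets = remove_additive(systems, targets, composition_model, target_info_dict)``, after which the returned targets no longer contain the additive contribution.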
26 | """ 27 | warnings.filterwarnings( 28 | "ignore", 29 | category=RuntimeWarning, 30 | message=( 31 | "GRADIENT WARNING: element 0 of tensors does not " 32 | "require grad and does not have a grad_fn" 33 | ), 34 | ) 35 | additive_contribution = evaluate_model( 36 | additive_model, 37 | systems, 38 | { 39 | key: target_info_dict[key] 40 | for key in targets.keys() 41 | if key in additive_model.outputs 42 | }, 43 | is_training=False, # we don't need any gradients w.r.t. any parameters 44 | ) 45 | 46 | for target_key in additive_contribution.keys(): 47 | # note that we loop over the keys of additive_contribution, not targets, 48 | # because the targets might contain additional keys (this is for example 49 | # the case of the composition model, which will only provide outputs 50 | # for scalar targets 51 | 52 | # make the samples the same so we can use metatensor.torch.subtract 53 | # we also need to detach the values to avoid backpropagating through the 54 | # subtraction 55 | blocks = [] 56 | for block_key, old_block in additive_contribution[target_key].items(): 57 | block = metatensor.torch.TensorBlock( 58 | values=old_block.values.detach(), 59 | samples=targets[target_key].block(block_key).samples, 60 | components=old_block.components, 61 | properties=old_block.properties, 62 | ) 63 | for gradient_name in targets[target_key].block(block_key).gradients_list(): 64 | gradient = ( 65 | additive_contribution[target_key] 66 | .block(block_key) 67 | .gradient(gradient_name) 68 | ) 69 | block.add_gradient( 70 | gradient_name, 71 | metatensor.torch.TensorBlock( 72 | values=gradient.values.detach(), 73 | samples=targets[target_key] 74 | .block(block_key) 75 | .gradient(gradient_name) 76 | .samples, 77 | components=gradient.components, 78 | properties=gradient.properties, 79 | ), 80 | ) 81 | blocks.append(block) 82 | additive_contribution[target_key] = TensorMap( 83 | keys=targets[target_key].keys, 84 | blocks=blocks, 85 | ) 86 | # subtract the additive contribution from the target 87 | targets[target_key] = metatensor.torch.subtract( 88 | targets[target_key], additive_contribution[target_key] 89 | ) 90 | 91 | return targets 92 | -------------------------------------------------------------------------------- /src/metatrain/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .combine_dataloaders import CombinedDataLoader # noqa: F401 2 | from .dataset import ( # noqa: F401 3 | Dataset, 4 | DatasetInfo, 5 | DiskDataset, 6 | DiskDatasetWriter, 7 | _is_disk_dataset, 8 | check_datasets, 9 | collate_fn, 10 | get_all_targets, 11 | get_atomic_types, 12 | get_stats, 13 | ) 14 | from .get_dataset import get_dataset # noqa: F401 15 | from .readers import read_systems, read_targets # noqa: F401 16 | from .system_to_ase import system_to_ase # noqa: F401 17 | from .target_info import TargetInfo # noqa: F401 18 | from .writers import write_predictions # noqa: F401 19 | -------------------------------------------------------------------------------- /src/metatrain/utils/data/combine_dataloaders.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | class CombinedDataLoader: 8 | """ 9 | Combines multiple dataloaders into a single dataloader. 10 | 11 | This is useful for learning from multiple datasets at the same time, 12 | each of which may have different batch sizes, properties, etc. 
13 | 14 | :param dataloaders: list of dataloaders to combine 15 | :param shuffle: whether to shuffle the combined dataloader (this does not 16 | act on the individual batches, but it shuffles the order in which 17 | they are returned) 18 | 19 | :return: the combined dataloader 20 | """ 21 | 22 | def __init__(self, dataloaders: List[torch.utils.data.DataLoader], shuffle: bool): 23 | self.dataloaders = dataloaders 24 | self.shuffle = shuffle 25 | 26 | # Create the indices. These contain which dataloader each batch comes from. 27 | # These will be shuffled later. 28 | self.indices = [] 29 | for i, dl in enumerate(dataloaders): 30 | self.indices.extend([i] * len(dl)) 31 | 32 | self.reset() 33 | 34 | def reset(self): 35 | self.dataloader_iterators = [iter(dl) for dl in self.dataloaders] 36 | self.current_index = 0 37 | # Shuffle the indices, if requested, for every new epoch 38 | if self.shuffle: 39 | np.random.shuffle(self.indices) 40 | 41 | def __iter__(self): 42 | return self 43 | 44 | def __next__(self): 45 | if self.current_index >= len(self.indices): 46 | self.reset() # Reset the index for the next iteration 47 | raise StopIteration 48 | 49 | idx = self.indices[self.current_index] 50 | self.current_index += 1 51 | return next(self.dataloader_iterators[idx]) 52 | 53 | def __len__(self): 54 | """Total number of batches in all dataloaders. 55 | 56 | This returns the total number of batches in all dataloaders 57 | (as opposed to the total number of samples or the number of 58 | individual dataloaders). 59 | 60 | :return: the total number of batches in all dataloaders 61 | """ 62 | return sum(len(dl) for dl in self.dataloaders) 63 | -------------------------------------------------------------------------------- /src/metatrain/utils/data/get_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple 2 | 3 | from omegaconf import DictConfig 4 | 5 | from .dataset import Dataset, DiskDataset 6 | from .readers import read_systems, read_targets 7 | from .target_info import TargetInfo 8 | 9 | 10 | def get_dataset(options: DictConfig) -> Tuple[Dataset, Dict[str, TargetInfo]]: 11 | """ 12 | Gets a dataset given a configuration dictionary. 13 | 14 | The system and targets in the dataset are read from one or more 15 | files, as specified in ``options``. 16 | 17 | :param options: the configuration options for the dataset. 18 | This configuration dictionary must contain keys for both the 19 | systems and targets in the dataset. 20 | 21 | :returns: A tuple containing a ``Dataset`` object and a 22 | ``Dict[str, TargetInfo]`` containing additional information (units, 23 | physical quantities, ...) 
on the targets in the dataset 24 | """ 25 | 26 | if options["systems"]["read_from"].endswith(".zip"): # disk dataset 27 | dataset = DiskDataset(options["systems"]["read_from"]) 28 | target_info_dictionary = dataset.get_target_info(options["targets"]) 29 | else: 30 | systems = read_systems( 31 | filename=options["systems"]["read_from"], 32 | reader=options["systems"]["reader"], 33 | ) 34 | targets, target_info_dictionary = read_targets(conf=options["targets"]) 35 | dataset = Dataset.from_dict({"system": systems, **targets}) 36 | 37 | return dataset, target_info_dictionary 38 | -------------------------------------------------------------------------------- /src/metatrain/utils/data/readers/__init__.py: -------------------------------------------------------------------------------- 1 | from .readers import ( # noqa: F401 2 | read_systems, 3 | read_targets, 4 | ) 5 | -------------------------------------------------------------------------------- /src/metatrain/utils/data/system_to_ase.py: -------------------------------------------------------------------------------- 1 | import ase 2 | from metatomic.torch import System 3 | 4 | 5 | def system_to_ase(system: System) -> ase.Atoms: 6 | """Converts a ``metatomic.torch.System`` to an ``ase.Atoms`` object. 7 | This will discard any neighbor lists attached to the ``System``. 8 | 9 | :param system: The system to convert. 10 | 11 | :return: The system as an ``ase.Atoms`` object. 12 | """ 13 | 14 | # Convert the system to an ASE atoms object 15 | positions = system.positions.detach().cpu().numpy() 16 | numbers = system.types.detach().cpu().numpy() 17 | cell = system.cell.detach().cpu().numpy() 18 | pbc = list(cell.any(axis=1)) 19 | atoms = ase.Atoms( 20 | numbers=numbers, 21 | positions=positions, 22 | cell=cell, 23 | pbc=pbc, 24 | ) 25 | 26 | return atoms 27 | -------------------------------------------------------------------------------- /src/metatrain/utils/data/writers/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import List, Optional 3 | 4 | from metatensor.torch import TensorMap 5 | from metatomic.torch import ModelCapabilities, System 6 | 7 | from .metatensor import write_mts 8 | from .xyz import write_xyz 9 | 10 | 11 | PREDICTIONS_WRITERS = {".xyz": write_xyz, ".mts": write_mts} 12 | """:py:class:`dict`: dictionary mapping file suffixes to a prediction writers""" 13 | 14 | 15 | def write_predictions( 16 | filename: str, 17 | systems: List[System], 18 | capabilities: ModelCapabilities, 19 | predictions: TensorMap, 20 | fileformat: Optional[str] = None, 21 | ) -> None: 22 | """Writes predictions to a file. 23 | 24 | For certain file suffixes, the systems will also be written (i.e ``xyz``). 25 | 26 | The capabilities of the model are used to infer the type (physical quantity) of 27 | the predictions. In this way, for example, position gradients of energies can be 28 | saved as forces. 29 | 30 | For the moment, strain gradients of the energy are saved as stresses 31 | (and not as virials). 32 | 33 | :param filename: name of the file to write 34 | :param systems: list of systems that for some writers will also be written 35 | :param capabilities: capabilities of the model 36 | :param predictions: :py:class:`metatensor.torch.TensorMap` containing the 37 | predictions that should be written 38 | :param fileformat: format of the target value file. If :py:obj:`None` the format is 39 | determined from the file extension. 
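For example (sketch; names are illustrative), ``write_predictions("predictions.xyz", systems, capabilities, predictions)`` dispatches to the ``xyz`` writer, which also writes the systems, while a ``.mts`` filename would select the metatensor writer instead.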
40 | """ 41 | if fileformat is None: 42 | fileformat = Path(filename).suffix 43 | 44 | try: 45 | writer = PREDICTIONS_WRITERS[fileformat] 46 | except KeyError: 47 | raise ValueError(f"fileformat '{fileformat}' is not supported") 48 | 49 | return writer(filename, systems, capabilities, predictions) 50 | -------------------------------------------------------------------------------- /src/metatrain/utils/data/writers/metatensor.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Dict, List 3 | 4 | import torch 5 | from metatensor.torch import TensorMap, save 6 | from metatomic.torch import ModelCapabilities, System 7 | 8 | 9 | # note that, although we don't use `systems` and `capabilities`, we need them to 10 | # match the writer interface 11 | def write_mts( 12 | filename: str, 13 | systems: List[System], 14 | capabilities: ModelCapabilities, 15 | predictions: Dict[str, TensorMap], 16 | ) -> None: 17 | """A metatensor-format prediction writer. Writes the predictions to `.mts` files. 18 | 19 | :param filename: name of the file to save to. 20 | :param systems: structures to be written to the file (not written by this writer). 21 | :param: capabilities: capabilities of the model (not used by this writer) 22 | :param predictions: prediction values to be written to the file. 23 | """ 24 | 25 | filename_base = Path(filename).stem 26 | for prediction_name, prediction_tmap in predictions.items(): 27 | save( 28 | filename_base + "_" + prediction_name + ".mts", 29 | prediction_tmap.to("cpu").to(torch.float64), 30 | ) 31 | -------------------------------------------------------------------------------- /src/metatrain/utils/distributed/distributed_data_parallel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class DistributedDataParallel(torch.nn.parallel.DistributedDataParallel): 5 | """ 6 | DistributedDataParallel wrapper that inherits from 7 | :py:class`torch.nn.parallel.DistributedDataParallel` 8 | and adds a function to retrieve the supported outputs of the module. 9 | """ 10 | 11 | def supported_outputs(self): 12 | return self.module.supported_outputs() 13 | -------------------------------------------------------------------------------- /src/metatrain/utils/distributed/logging.py: -------------------------------------------------------------------------------- 1 | from .slurm import is_slurm, is_slurm_main_process 2 | 3 | 4 | def is_main_process(): 5 | if is_slurm(): 6 | return is_slurm_main_process() 7 | else: 8 | return True 9 | -------------------------------------------------------------------------------- /src/metatrain/utils/distributed/slurm.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import hostlist 4 | 5 | 6 | def is_slurm(): 7 | return ("SLURM_JOB_ID" in os.environ) and ("SLURM_PROCID" in os.environ) 8 | 9 | 10 | def is_slurm_main_process(): 11 | return os.environ["SLURM_PROCID"] == "0" 12 | 13 | 14 | class DistributedEnvironment: 15 | """ 16 | Distributed environment for Slurm. 17 | 18 | This class sets up the distributed environment on Slurm. It reads 19 | the necessary environment variables and sets them for use in the 20 | PyTorch distributed utilities. Modified from 21 | https://github.com/Lumi-supercomputer/lumi-reframe-tests/blob/main/checks/apps/deeplearning/pytorch/src/pt_distr_env.py. 
22 | 23 | :param port: The port to use for communication in the distributed 24 | environment. 25 | """ # noqa: E501, E262 26 | 27 | def __init__(self, port: int): 28 | self._setup_distr_env(port) 29 | self.master_addr = os.environ["MASTER_ADDR"] 30 | self.master_port = os.environ["MASTER_PORT"] 31 | self.world_size = int(os.environ["WORLD_SIZE"]) 32 | self.rank = int(os.environ["RANK"]) 33 | self.local_rank = int(os.environ["LOCAL_RANK"]) 34 | 35 | def _setup_distr_env(self, port: int): 36 | hostnames = hostlist.expand_hostlist(os.environ["SLURM_JOB_NODELIST"]) 37 | os.environ["MASTER_ADDR"] = hostnames[0] # set first node as master 38 | os.environ["MASTER_PORT"] = str(port) # set port for communication 39 | os.environ["WORLD_SIZE"] = os.environ["SLURM_NTASKS"] 40 | os.environ["RANK"] = os.environ["SLURM_PROCID"] 41 | os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"] 42 | -------------------------------------------------------------------------------- /src/metatrain/utils/dtype.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def dtype_to_str(dtype: torch.dtype) -> str: 5 | """ 6 | Convert a torch dtype to its string representation. 7 | 8 | :param dtype: torch dtype to convert 9 | :returns: string representation of the torch dtype 10 | 11 | Example 12 | ------- 13 | >>> import torch 14 | >>> dtype_to_str(torch.float64) 15 | 'float64' 16 | >>> dtype_to_str(torch.int32) 17 | 'int32' 18 | """ 19 | return str(dtype).split(".")[-1] 20 | -------------------------------------------------------------------------------- /src/metatrain/utils/errors.py: -------------------------------------------------------------------------------- 1 | class ArchitectureError(Exception): 2 | """ 3 | Exception raised for errors originating from architectures. 4 | 5 | This exception should be raised when an error occurs within an architecture's 6 | operation, indicating that the problem is not directly related to the 7 | metatrain infrastructure but rather to the specific architecture being used. 8 | 9 | :param exception: The original exception that was caught, which led to raising this 10 | custom exception. 11 | """ 12 | 13 | def __init__(self, exception): 14 | super().__init__( 15 | f"{exception.__class__.__name__}: {exception}\n\n" 16 | "The error above most likely originates from an architecture.\n\n" 17 | "If you think this is a bug, please contact its maintainer (see the " 18 | "architecture's documentation) and include the full traceback from the error.log file." 19 | ) 20 | -------------------------------------------------------------------------------- /src/metatrain/utils/external_naming.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | from metatomic.torch import ModelOutput 4 | 5 | 6 | def to_external_name( 7 | internal_name: str, quantities: Dict[str, ModelOutput] 8 | ) -> str: 9 | """Converts internal names to external names. 10 | 11 | Very often, the "common" names for quantities are different from the 12 | internal names used in the code. Two important examples are forces and 13 | virials, which are referred to as energy_positions_gradients and 14 | energy_strain_gradients, respectively, in the code. This function 15 | converts an internal name to an external name. 16 | 17 | :param internal_name: An internal name to convert.
18 |     :param quantities: A dictionary of physical quantities, either as
19 |         :py:class:`TargetInfo` objects or as :py:class:`ModelOutput` objects.
20 | 
21 |     :return: The name for external use.
22 |     """
23 | 
24 |     if internal_name.endswith("_positions_gradients"):
25 |         base_name = internal_name.replace("_positions_gradients", "")
26 |         if quantities[base_name].quantity == "energy":
27 |             if base_name == "energy":  # we treat "energy" as a special case
28 |                 external_name = "forces"
29 |             else:
30 |                 external_name = f"forces[{base_name}]"
31 |         else:
32 |             external_name = internal_name
33 |     elif internal_name.endswith("_strain_gradients"):
34 |         base_name = internal_name.replace("_strain_gradients", "")
35 |         if quantities[base_name].quantity == "energy":
36 |             if base_name == "energy":
37 |                 external_name = "virial"
38 |             else:
39 |                 external_name = f"virial[{base_name}]"
40 |         else:
41 |             external_name = internal_name
42 |     else:
43 |         external_name = internal_name
44 | 
45 |     return external_name
46 | 
47 | 
48 | def to_internal_name(external_name: str) -> str:
49 |     """Converts an external name to an internal name.
50 | 
51 |     This function is the inverse of :func:`to_external_name`.
52 | 
53 |     :param external_name: The name to convert.
54 | 
55 |     :return: The name for internal use.
56 |     """
57 | 
58 |     if external_name == "forces":
59 |         internal_name = "energy_positions_gradients"
60 |     elif external_name.startswith("forces[") and external_name.endswith("]"):
61 |         base_name = external_name[7:-1]
62 |         internal_name = f"{base_name}_positions_gradients"
63 |     elif external_name == "virial":
64 |         internal_name = "energy_strain_gradients"
65 |     elif external_name.startswith("virial[") and external_name.endswith("]"):
66 |         base_name = external_name[7:-1]
67 |         internal_name = f"{base_name}_strain_gradients"
68 |     else:
69 |         internal_name = external_name
70 | 
71 |     return internal_name
72 | 
--------------------------------------------------------------------------------
/src/metatrain/utils/jsonschema.py:
--------------------------------------------------------------------------------
 1 | import difflib
 2 | 
 3 | import jsonschema
 4 | from jsonschema.exceptions import ValidationError
 5 | 
 6 | 
 7 | def validate(instance, schema, cls=None, *args, **kwargs) -> None:
 8 |     """Validate an instance under the given schema.
 9 | 
10 |     Function similar to :py:func:`jsonschema.validate`, but it displays only the
11 |     human-readable error message, without showing the reference schema and path,
12 |     if the instance is invalid. In addition, if the error is caused by disallowed
13 |     ``additionalProperties``, the closest matching properties will be suggested.
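
    For example (hypothetical options), validating ``{"num_epoch": 1}`` against
    a schema whose only allowed property is ``num_epochs`` raises a
    ``ValidationError`` whose message ends with "Do you mean 'num_epochs'?".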
14 | 
15 |     :param instance: Instance to validate
16 |     :param schema: Schema to validate with
17 |     :raises jsonschema.exceptions.ValidationError: If the instance is invalid
18 |     :raises jsonschema.exceptions.SchemaError: If the schema itself is invalid
19 |     """
20 |     try:
21 |         jsonschema.validate(instance, schema, cls=cls, *args, **kwargs)  # noqa: B026
22 |     except ValidationError as error:
23 |         if error.validator == "additionalProperties":
24 |             # Change error message to be clearer for users
25 |             error.message = error.message.replace(
26 |                 "Additional properties are not allowed", "Unrecognized options"
27 |             )
28 | 
29 |             known_properties = error.schema["properties"].keys()
30 |             unknown_properties = error.instance.keys() - known_properties
31 | 
32 |             closest_matches = []
33 |             for name in unknown_properties:
34 |                 closest_match = difflib.get_close_matches(
35 |                     word=name, possibilities=known_properties
36 |                 )
37 | 
38 |                 if closest_match:
39 |                     closest_matches.append(f"'{closest_match[0]}'")
40 | 
41 |             if closest_matches:
42 |                 error.message += f". Do you mean {', '.join(closest_matches)}?"
43 | 
44 |         raise ValidationError(message=error.message)
45 | 
--------------------------------------------------------------------------------
/src/metatrain/utils/metadata.py:
--------------------------------------------------------------------------------
 1 | import collections.abc
 2 | import json
 3 | 
 4 | from metatomic.torch import ModelMetadata
 5 | 
 6 | 
 7 | def update(d, u):
 8 |     """Recursively update mapping ``d`` with ``u``, merging nested mappings
 9 |     and appending list items that are not already present."""
10 |     for k, v in u.items():
11 |         if isinstance(v, collections.abc.Mapping):
12 |             d[k] = update(d.get(k, {}), v)
13 |         elif isinstance(v, list):
14 |             if k in d:
15 |                 for item in v:
16 |                     if item not in d[k]:
17 |                         d[k].append(item)
18 |             else:
19 |                 d[k] = v
20 |         else:
21 |             d[k] = v
22 |     return d
23 | 
24 | 
25 | def merge_metadata(self: ModelMetadata, other: ModelMetadata) -> ModelMetadata:
26 |     """Merge ``other`` into ``self``, recursively combining fields and
27 |     appending new ``references``.
28 | 
29 |     :param self: The metadata object to be updated.
30 |     :param other: The metadata object to be merged into ``self``.
31 |     """
32 | 
33 |     self_dict = json.loads(self._get_method("__getstate__")())
34 |     other_dict = json.loads(other._get_method("__getstate__")())
35 | 
36 |     self_dict = update(self_dict, other_dict)
37 |     self_dict.pop("class")
38 | 
39 |     new_metadata = ModelMetadata(**self_dict)
40 | 
41 |     return new_metadata
42 | 
--------------------------------------------------------------------------------
/src/metatrain/utils/output_gradient.py:
--------------------------------------------------------------------------------
 1 | import warnings
 2 | from typing import List, Optional
 3 | 
 4 | import torch
 5 | 
 6 | 
 7 | def compute_gradient(
 8 |     target: torch.Tensor, inputs: List[torch.Tensor], is_training: bool
 9 | ) -> List[torch.Tensor]:
10 |     """
11 |     Calculates the gradient of a target tensor with respect to a list of input tensors.
12 | 
13 |     ``target`` must be a single torch.Tensor object. If target contains multiple values,
14 |     the gradient will be calculated with respect to the sum of all values.
15 |     """
16 | 
17 |     grad_outputs: Optional[List[Optional[torch.Tensor]]] = [torch.ones_like(target)]
18 |     try:
19 |         gradient = torch.autograd.grad(
20 |             outputs=[target],
21 |             inputs=inputs,
22 |             grad_outputs=grad_outputs,
23 |             retain_graph=is_training,
24 |             create_graph=is_training,
25 |         )
26 |     except RuntimeError as e:
27 |         # Torch raises an error if the target tensor does not require grad,
28 |         # but this could just mean that the target is a constant tensor, like in
29 |         # the case of composition models.
In this case, we can safely ignore the error 30 | # and we raise a warning instead. The warning can be caught and silenced in the 31 | # appropriate places. 32 | if ( 33 | "element 0 of tensors does not require grad and does not have a grad_fn" 34 | in str(e) 35 | ): 36 | warnings.warn(f"GRADIENT WARNING: {e}", RuntimeWarning, stacklevel=2) 37 | gradient = [torch.zeros_like(i) for i in inputs] 38 | else: 39 | # Re-raise the error if it's not the one above 40 | raise 41 | if gradient is None: 42 | raise ValueError( 43 | "Unexpected None value for computed gradient. " 44 | "One or more operations inside the model might " 45 | "not have a gradient implementation." 46 | ) 47 | else: 48 | return gradient 49 | -------------------------------------------------------------------------------- /src/metatrain/utils/sum_over_atoms.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | from metatensor.torch import Labels, TensorBlock, TensorMap 5 | 6 | 7 | @torch.jit.script 8 | def sum_over_atoms(tensor_map: TensorMap): # pragma: no cover 9 | """ 10 | A faster version of ``metatensor.torch.sum_over_samples``, specialized for 11 | summing over atoms in graph-like TensorMaps. 12 | 13 | :param tensor_map: The TensorMap to sum over. 14 | :return: A new TensorMap with the same keys, but with the samples summed 15 | over the atoms. 16 | """ 17 | new_blocks: List[TensorBlock] = [] 18 | for block in tensor_map.blocks(): 19 | n_systems = int(block.samples.column("system").max() + 1) 20 | new_tensor = torch.zeros( 21 | [n_systems] + block.values.shape[1:], 22 | device=tensor_map.device, 23 | dtype=tensor_map.dtype, 24 | ) 25 | new_tensor.index_add_(0, block.samples.column("system"), block.values) 26 | new_block = TensorBlock( 27 | values=new_tensor, 28 | samples=Labels( 29 | names=["system"], 30 | values=torch.arange( 31 | n_systems, device=tensor_map.device, dtype=torch.int 32 | ).reshape(-1, 1), 33 | ), 34 | components=block.components, 35 | properties=block.properties, 36 | ) 37 | new_blocks.append(new_block) 38 | return TensorMap( 39 | keys=tensor_map.keys, 40 | blocks=new_blocks, 41 | ) 42 | -------------------------------------------------------------------------------- /src/metatrain/utils/testing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/metatensor/metatrain/9a0c351ae227781e8481bfc50fe6bd463048e417/src/metatrain/utils/testing/__init__.py -------------------------------------------------------------------------------- /src/metatrain/utils/transfer.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | import torch 4 | from metatensor.torch import TensorMap 5 | from metatomic.torch import System 6 | 7 | 8 | @torch.jit.script 9 | def systems_and_targets_to_device( # pragma: no cover 10 | systems: List[System], 11 | targets: Dict[str, TensorMap], 12 | device: torch.device, 13 | ): 14 | """ 15 | Transfers the systems and targets to the specified device. 16 | 17 | :param systems: List of systems. 18 | :param targets: Dictionary of targets. 19 | :param device: Device to transfer to. 
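
    A sketch of typical use (assuming a CUDA device is available)::

        systems, targets = systems_and_targets_to_device(
            systems, targets, torch.device("cuda")
        )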
20 |     """
21 | 
22 |     systems = [system.to(device=device) for system in systems]
23 |     targets = {key: value.to(device=device) for key, value in targets.items()}
24 |     return systems, targets
25 | 
26 | 
27 | @torch.jit.script
28 | def systems_and_targets_to_dtype(  # pragma: no cover
29 |     systems: List[System],
30 |     targets: Dict[str, TensorMap],
31 |     dtype: torch.dtype,
32 | ):
33 |     """
34 |     Changes the systems and targets to the specified floating point data type.
35 | 
36 |     :param systems: List of systems.
37 |     :param targets: Dictionary of targets.
38 |     :param dtype: Desired floating point data type.
39 |     """
40 | 
41 |     systems = [system.to(dtype=dtype) for system in systems]
42 |     targets = {key: value.to(dtype=dtype) for key, value in targets.items()}
43 |     return systems, targets
44 | 
--------------------------------------------------------------------------------
/src/metatrain/utils/units.py:
--------------------------------------------------------------------------------
 1 | from typing import Tuple
 2 | 
 3 | 
 4 | def get_gradient_units(base_unit: str, gradient_name: str, length_unit: str) -> str:
 5 |     """
 6 |     Get the gradient units based on the unit of the base quantity.
 7 | 
 8 |     For example, if the base unit is "<base_unit>" and the gradient name is
 9 |     "positions", the gradient unit will be "<base_unit>/<length_unit>".
10 | 
11 |     :param base_unit: The unit of the base quantity.
12 |     :param gradient_name: The name of the gradient.
13 |     :param length_unit: The unit of lengths.
14 | 
15 |     :return: The unit of the gradient.
16 |     """
17 |     if base_unit == "":
18 |         return ""  # unknown unit for base quantity -> unknown unit for gradient
19 |     if length_unit.lower() in ["angstrom", "å", "ångstrom"]:
20 |         length_unit = "A"  # prettier
21 |     if gradient_name == "positions":
22 |         return base_unit + "/" + length_unit
23 |     elif gradient_name == "strain":
24 |         return base_unit  # strain is dimensionless
25 |     else:
26 |         raise ValueError(f"Unknown gradient name: {gradient_name}")
27 | 
28 | 
29 | def ev_to_mev(value: float, unit: str) -> Tuple[float, str]:
30 |     """
31 |     If the `unit` starts with eV, converts the `value` and its
32 |     corresponding `unit` to meV. Otherwise, returns the input.
33 | 
34 |     :param value: The value (potentially in eV or a derived quantity of eV).
35 |     :param unit: The unit of the value.
36 | 
37 |     :return: If the `value` is in eV (or a derived quantity of eV), the value and
38 |         the corresponding unit where eV is converted to meV. Otherwise, the input.
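
    Example (illustrative)::

        >>> ev_to_mev(0.5, "eV/A")
        (500.0, 'meV/A')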
39 | """ 40 | if unit.startswith("eV") or unit.startswith("ev"): 41 | return value * 1000.0, ( 42 | unit.replace("eV", "meV") 43 | if unit.startswith("eV") 44 | else unit.replace("ev", "mev") 45 | ) 46 | else: 47 | return value, unit 48 | -------------------------------------------------------------------------------- /tests/cli/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from metatrain.utils.architectures import get_default_hypers 4 | 5 | 6 | MODEL_HYPERS = get_default_hypers("soap_bpnn")["model"] 7 | 8 | RESOURCES_PATH = Path(__file__).parents[1] / "resources" 9 | 10 | DATASET_PATH_QM9 = RESOURCES_PATH / "qm9_reduced_100.xyz" 11 | DATASET_PATH_ETHANOL = RESOURCES_PATH / "ethanol_reduced_100.xyz" 12 | DATASET_PATH_CARBON = RESOURCES_PATH / "carbon_reduced_100.xyz" 13 | DATASET_PATH_QM7X = RESOURCES_PATH / "qm7x_reduced_100.xyz" 14 | EVAL_OPTIONS_PATH = RESOURCES_PATH / "eval.yaml" 15 | MODEL_PATH = RESOURCES_PATH / "model-32-bit.pt" 16 | MODEL_PATH_64_BIT = RESOURCES_PATH / "model-64-bit.ckpt" 17 | OPTIONS_PATH = RESOURCES_PATH / "options.yaml" 18 | OPTIONS_NANOPET_PATH = RESOURCES_PATH / "options-nanopet.yaml" 19 | -------------------------------------------------------------------------------- /tests/cli/dump_spherical_targets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from metatensor.torch import Labels, TensorBlock, TensorMap 4 | 5 | from metatrain.utils.data.readers.ase import read 6 | 7 | 8 | def l0_components_from_matrix(A): 9 | # note: might be wrong, but correct up to a normalization factor 10 | # which is good enough for the tests 11 | A = A.reshape(3, 3) 12 | l0_A = np.sum(np.diagonal(A)) 13 | return l0_A 14 | 15 | 16 | def l2_components_from_matrix(A): 17 | A = A.reshape(3, 3) 18 | 19 | l2_A = np.empty((5,)) 20 | l2_A[0] = (A[0, 1] + A[1, 0]) / 2.0 21 | l2_A[1] = (A[1, 2] + A[2, 1]) / 2.0 22 | l2_A[2] = (2.0 * A[2, 2] - A[0, 0] - A[1, 1]) / ((2.0) * np.sqrt(3.0)) 23 | l2_A[3] = (A[0, 2] + A[2, 0]) / 2.0 24 | l2_A[4] = (A[0, 0] - A[1, 1]) / 2.0 25 | 26 | return l2_A 27 | 28 | 29 | def dump_spherical_targets(path_in, path_out, with_scalar_part=False): 30 | # Takes polarizabilities from a dataset in Cartesian format, converts them to 31 | # spherical coordinates, and saves them in metatensor format (suitable for 32 | # training a model with spherical targets). 
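    # A typical call (file names hypothetical) would be:
    #     dump_spherical_targets("qm7x_reduced_100.xyz", "polarizability_spherical.mts")
    # Note that the reshape calls below assume exactly 100 structures, matching
    # the reduced test datasets used in this repository.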
33 | 
34 |     structures = read(path_in, ":")
35 | 
36 |     polarizabilities_l2 = np.array(
37 |         [
38 |             l2_components_from_matrix(structure.info["polarizability"])
39 |             for structure in structures
40 |         ]
41 |     )
42 | 
43 |     if with_scalar_part:
44 |         polarizabilities_l0 = np.array(
45 |             [
46 |                 l0_components_from_matrix(structure.info["polarizability"])
47 |                 for structure in structures
48 |             ]
49 |         )
50 | 
51 |     samples = Labels(
52 |         names=["system"],
53 |         values=torch.arange(len(structures)).reshape(-1, 1),
54 |     )
55 | 
56 |     properties = Labels.single()
57 | 
58 |     components_l2 = Labels(
59 |         names=["o3_mu"],
60 |         values=torch.tensor([[-2], [-1], [0], [1], [2]]),
61 |     )
62 | 
63 |     keys = Labels(
64 |         names=["o3_lambda", "o3_sigma"],
65 |         values=torch.tensor(([[0, 1]] if with_scalar_part else []) + [[2, 1]]),
66 |     )
67 | 
68 |     blocks = (
69 |         [
70 |             TensorBlock(
71 |                 values=torch.tensor(polarizabilities_l0, dtype=torch.float64).reshape(
72 |                     100, 1, 1
73 |                 ),
74 |                 samples=samples,
75 |                 components=[Labels.range("o3_mu", 1)],
76 |                 properties=properties,
77 |             ),
78 |         ]
79 |         if with_scalar_part
80 |         else []
81 |     )
82 |     blocks.append(
83 |         TensorBlock(
84 |             values=torch.tensor(polarizabilities_l2, dtype=torch.float64).reshape(
85 |                 100, 5, 1
86 |             ),
87 |             samples=samples,
88 |             components=[components_l2],
89 |             properties=properties,
90 |         )
91 |     )
92 | 
93 |     tensor_map = TensorMap(
94 |         keys=keys,
95 |         blocks=blocks,
96 |     )
97 | 
98 |     tensor_map.save(path_out)
99 | 
--------------------------------------------------------------------------------
/tests/cli/test_formatter.py:
--------------------------------------------------------------------------------
 1 | import argparse
 2 | 
 3 | from metatrain.cli.formatter import CustomHelpFormatter
 4 | 
 5 | 
 6 | def test_formatter(capsys):
 7 |     """Test that positional arguments are displayed before optional in usage."""
 8 |     parser = argparse.ArgumentParser(prog="myprog", formatter_class=CustomHelpFormatter)
 9 |     parser.add_argument("required_input")
10 |     parser.add_argument("required_input2")
11 |     parser.add_argument("-f", "--foo", help="optional argument")
12 |     parser.add_argument("-b", "--bar", help="optional argument 2")
13 | 
14 |     parser.print_help()
15 | 
16 |     captured = capsys.readouterr()
17 |     assert (
18 |         "usage: myprog required_input required_input2 [-h] [-f FOO] [-b BAR]"
19 |         in captured.out
20 |     )
21 | 
--------------------------------------------------------------------------------
/tests/distributed/ethanol_reduced_100.xyz:
--------------------------------------------------------------------------------
 1 | ../../examples/ase/ethanol_reduced_100.xyz
--------------------------------------------------------------------------------
/tests/distributed/options-distributed.yaml:
--------------------------------------------------------------------------------
 1 | seed: 42
 2 | device: cuda
 3 | 
 4 | architecture:
 5 |   name: soap_bpnn
 6 |   training:
 7 |     distributed: True
 8 |     batch_size: 25
 9 |     num_epochs: 100
10 | 
11 | training_set:
12 |   systems:
13 |     read_from: ethanol_reduced_100.xyz
14 |     length_unit: angstrom
15 |   targets:
16 |     energy:
17 |       key: energy
18 |       unit: eV
19 |       forces: on
20 | 
21 | test_set: 0.0
22 | validation_set: 0.5
--------------------------------------------------------------------------------
/tests/distributed/options.yaml:
--------------------------------------------------------------------------------
 1 | seed: 42
 2 | device: cuda
 3 | 
 4 | architecture:
 5 |   name: soap_bpnn
 6 |   training:
 7 |     batch_size: 50
 8 |     num_epochs: 100
 9 | 
10 | training_set:
11 | 
systems: 12 | read_from: ethanol_reduced_100.xyz 13 | length_unit: angstrom 14 | targets: 15 | energy: 16 | key: energy 17 | unit: eV 18 | forces: on 19 | 20 | test_set: 0.0 21 | validation_set: 0.5 22 | -------------------------------------------------------------------------------- /tests/distributed/readme.txt: -------------------------------------------------------------------------------- 1 | This sub-folder contains a simple test to check whether distributed training 2 | works as expected. The test consists of a training exercise using SOAP-BPNN 3 | on a small ethanol dataset. The logs obtained by using "options.yaml" and 4 | "options-distributed.yaml" should be the same. 5 | -------------------------------------------------------------------------------- /tests/distributed/submit-distributed.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --nodes 1 3 | #SBATCH --ntasks 2 4 | #SBATCH --ntasks-per-node 2 5 | #SBATCH --gpus-per-node 2 6 | #SBATCH --cpus-per-task 8 7 | #SBATCH --exclusive 8 | #SBATCH --time=1:00:00 9 | 10 | # load modules and/or virtual environments and/or containers here 11 | 12 | srun mtt train options-distributed.yaml 13 | -------------------------------------------------------------------------------- /tests/distributed/submit.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --nodes 1 3 | #SBATCH --ntasks 1 4 | #SBATCH --ntasks-per-node 1 5 | #SBATCH --gpus-per-node 1 6 | #SBATCH --cpus-per-task 8 7 | #SBATCH --time=1:00:00 8 | 9 | # load modules and/or virtual environments and/or containers here 10 | 11 | mtt train options.yaml 12 | -------------------------------------------------------------------------------- /tests/resources/ethanol_reduced_100.xyz: -------------------------------------------------------------------------------- 1 | ../../examples/ase/ethanol_reduced_100.xyz -------------------------------------------------------------------------------- /tests/resources/eval.yaml: -------------------------------------------------------------------------------- 1 | ../../docs/static/qm9/eval.yaml -------------------------------------------------------------------------------- /tests/resources/generate-outputs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -eux 3 | 4 | echo "Generating data for testing..." 
5 | 6 | ROOT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) 7 | 8 | cd $ROOT_DIR 9 | 10 | mtt train options.yaml -o model-32-bit.pt -r base_precision=32 # > /dev/null 11 | mtt train options.yaml -o model-64-bit.pt -r base_precision=64 # > /dev/null 12 | mtt train options-nanopet.yaml -o model-no-extensions.pt # > /dev/null 13 | 14 | # upload results to private HF repo if token is set 15 | if [ -n "${HUGGINGFACE_TOKEN_METATRAIN:-}" ]; then 16 | huggingface-cli upload \ 17 | "metatensor/metatrain-test" \ 18 | "model-32-bit.ckpt" \ 19 | "model.ckpt" \ 20 | --commit-message="Overwrite test model with new version" \ 21 | --token=$HUGGINGFACE_TOKEN_METATRAIN 22 | fi 23 | -------------------------------------------------------------------------------- /tests/resources/options-nanopet.yaml: -------------------------------------------------------------------------------- 1 | seed: 42 2 | 3 | architecture: 4 | name: experimental.nanopet 5 | training: 6 | batch_size: 16 7 | num_epochs: 1 8 | 9 | training_set: 10 | systems: 11 | read_from: qm9_reduced_100.xyz 12 | length_unit: angstrom 13 | targets: 14 | energy: 15 | key: U0 16 | unit: eV 17 | 18 | test_set: 0.5 19 | validation_set: 0.1 20 | -------------------------------------------------------------------------------- /tests/resources/options.yaml: -------------------------------------------------------------------------------- 1 | seed: 42 2 | 3 | architecture: 4 | name: soap_bpnn 5 | training: 6 | batch_size: 16 7 | num_epochs: 1 8 | model: 9 | soap: 10 | max_radial: 4 11 | max_angular: 2 12 | 13 | training_set: 14 | systems: 15 | read_from: qm9_reduced_100.xyz 16 | length_unit: angstrom 17 | targets: 18 | energy: 19 | key: U0 20 | unit: eV 21 | 22 | test_set: 0.5 23 | validation_set: 0.1 24 | -------------------------------------------------------------------------------- /tests/resources/qm9_reduced_100.xyz: -------------------------------------------------------------------------------- 1 | ../../docs/static/qm9/qm9_reduced_100.xyz -------------------------------------------------------------------------------- /tests/resources/test.yaml: -------------------------------------------------------------------------------- 1 | seed: 42 2 | 3 | architecture: 4 | name: soap_bpnn 5 | training: 6 | batch_size: 2 7 | num_epochs: 1 8 | 9 | training_set: 10 | systems: 11 | read_from: ethanol_reduced_100.xyz 12 | length_unit: angstrom 13 | targets: 14 | forces: 15 | quantity: force 16 | key: forces 17 | per_atom: true 18 | num_subtargets: 3 19 | 20 | test_set: 0.5 21 | validation_set: 0.1 22 | -------------------------------------------------------------------------------- /tests/test_init.py: -------------------------------------------------------------------------------- 1 | import metatrain 2 | 3 | 4 | def test_version_exists(): 5 | metatrain.__version__ 6 | -------------------------------------------------------------------------------- /tests/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from metatrain.utils.architectures import get_default_hypers 4 | 5 | 6 | MODEL_HYPERS = get_default_hypers("soap_bpnn")["model"] 7 | 8 | RESOURCES_PATH = Path(__file__).parents[1] / "resources" 9 | -------------------------------------------------------------------------------- /tests/utils/data/test_get_dataset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from omegaconf import OmegaConf 4 | 5 | from 
metatrain.utils.data import get_dataset 6 | 7 | 8 | RESOURCES_PATH = Path(__file__).parents[2] / "resources" 9 | 10 | 11 | def test_get_dataset(): 12 | options = { 13 | "systems": { 14 | "read_from": str(RESOURCES_PATH / "qm9_reduced_100.xyz"), 15 | "reader": "ase", 16 | }, 17 | "targets": { 18 | "energy": { 19 | "quantity": "energy", 20 | "read_from": str(RESOURCES_PATH / "qm9_reduced_100.xyz"), 21 | "reader": "ase", 22 | "key": "U0", 23 | "unit": "eV", 24 | "type": "scalar", 25 | "per_atom": False, 26 | "num_subtargets": 1, 27 | "forces": False, 28 | "stress": False, 29 | "virial": False, 30 | } 31 | }, 32 | } 33 | 34 | dataset, target_info = get_dataset(OmegaConf.create(options)) 35 | 36 | dataset[0].system 37 | dataset[0].energy 38 | assert "energy" in target_info 39 | assert target_info["energy"].quantity == "energy" 40 | assert target_info["energy"].unit == "eV" 41 | -------------------------------------------------------------------------------- /tests/utils/data/test_readers_ase.py: -------------------------------------------------------------------------------- 1 | """Tests for the ASE readers. The functionality of the top-level functions 2 | `read_systems`, `read_energy`, `read_generic` is already tested through 3 | the reader tests in `test_readers.py`. Here we test the specific ASE readers 4 | for energies, forces, stresses, and virials.""" 5 | 6 | import ase 7 | import ase.io 8 | import pytest 9 | import torch 10 | from metatensor.torch import Labels 11 | from test_targets_ase import ase_systems 12 | 13 | from metatrain.utils.data.readers.ase import ( 14 | _read_energy_ase, 15 | _read_forces_ase, 16 | _read_stress_ase, 17 | _read_virial_ase, 18 | ) 19 | 20 | 21 | @pytest.mark.parametrize("key", ["true_energy", "energy"]) 22 | def test_read_energies(monkeypatch, tmp_path, key): 23 | monkeypatch.chdir(tmp_path) 24 | 25 | filename = "systems.xyz" 26 | systems = ase_systems() 27 | ase.io.write(filename, systems) 28 | 29 | results = _read_energy_ase(filename, key=key) 30 | 31 | assert type(results) is list 32 | assert len(results) == len(systems) 33 | for i_system, result in enumerate(results): 34 | assert result.values.dtype is torch.float64 35 | assert result.samples.names == ["system"] 36 | assert result.samples.values == torch.tensor([[i_system]]) 37 | assert result.properties == Labels("energy", torch.tensor([[0]])) 38 | 39 | 40 | @pytest.mark.parametrize("key", ["true_forces", "forces"]) 41 | def test_read_forces(monkeypatch, tmp_path, key): 42 | monkeypatch.chdir(tmp_path) 43 | 44 | filename = "systems.xyz" 45 | systems = ase_systems() 46 | ase.io.write(filename, systems) 47 | 48 | results = _read_forces_ase(filename, key=key) 49 | 50 | assert type(results) is list 51 | assert len(results) == len(systems) 52 | for i_system, result in enumerate(results): 53 | assert result.values.dtype is torch.float64 54 | assert result.samples.names == ["sample", "system", "atom"] 55 | assert torch.all(result.samples["sample"] == torch.tensor(0)) 56 | assert torch.all(result.samples["system"] == torch.tensor(i_system)) 57 | assert result.components == [Labels(["xyz"], torch.arange(3).reshape(-1, 1))] 58 | assert result.properties == Labels("energy", torch.tensor([[0]])) 59 | 60 | 61 | @pytest.mark.parametrize("key", ["stress", "stress-3x3"]) 62 | @pytest.mark.parametrize("reader_func", [_read_stress_ase, _read_virial_ase]) 63 | def test_read_stress_virial(reader_func, monkeypatch, tmp_path, key): 64 | monkeypatch.chdir(tmp_path) 65 | 66 | filename = "systems.xyz" 67 | systems = 
ase_systems() 68 | ase.io.write(filename, systems) 69 | 70 | results = reader_func(filename, key=key) 71 | 72 | assert type(results) is list 73 | assert len(results) == len(systems) 74 | components = [ 75 | Labels(["xyz_1"], torch.arange(3).reshape(-1, 1)), 76 | Labels(["xyz_2"], torch.arange(3).reshape(-1, 1)), 77 | ] 78 | for result in results: 79 | assert result.values.dtype is torch.float64 80 | assert result.samples.names == ["sample"] 81 | assert result.samples.values == torch.tensor([[0]]) 82 | assert result.components == components 83 | assert result.properties == Labels("energy", torch.tensor([[0]])) 84 | -------------------------------------------------------------------------------- /tests/utils/data/test_system_to_ase.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from metatomic.torch import System 3 | 4 | from metatrain.utils.data import system_to_ase 5 | 6 | 7 | def test_system_to_ase(): 8 | """Tests the conversion of a System to an ASE atoms object.""" 9 | # Create a system 10 | system = System( 11 | positions=torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]), 12 | types=torch.tensor([1, 8]), 13 | cell=torch.tensor([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]), 14 | pbc=torch.tensor([True, True, True]), 15 | ) 16 | 17 | # Convert the system to an ASE atoms object 18 | atoms = system_to_ase(system) 19 | 20 | # Check the positions 21 | assert atoms.positions.tolist() == system.positions.tolist() 22 | 23 | # Check the species 24 | assert atoms.numbers.tolist() == system.types.tolist() 25 | 26 | # Check the cell 27 | assert atoms.cell.tolist() == system.cell.tolist() 28 | assert atoms.pbc.tolist() == [True, True, True] 29 | -------------------------------------------------------------------------------- /tests/utils/test_dtype.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from metatrain.utils.dtype import dtype_to_str 4 | 5 | 6 | def test_dtype_to_string(): 7 | assert dtype_to_str(torch.float64) == "float64" 8 | assert dtype_to_str(torch.float32) == "float32" 9 | assert dtype_to_str(torch.int64) == "int64" 10 | assert dtype_to_str(torch.int32) == "int32" 11 | assert dtype_to_str(torch.bool) == "bool" 12 | -------------------------------------------------------------------------------- /tests/utils/test_errors.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from metatrain.utils.errors import ArchitectureError 4 | 5 | 6 | def test_architecture_error(): 7 | match = "The error above most likely originates from an architecture" 8 | with pytest.raises(ArchitectureError, match=match): 9 | try: 10 | raise ValueError("An example error from the architecture") 11 | except Exception as e: 12 | raise ArchitectureError(e) 13 | -------------------------------------------------------------------------------- /tests/utils/test_evaluate_model.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from metatrain.soap_bpnn import __model__ 5 | from metatrain.utils.data import DatasetInfo, read_systems 6 | from metatrain.utils.data.target_info import get_energy_target_info 7 | from metatrain.utils.evaluate_model import evaluate_model 8 | from metatrain.utils.neighbor_lists import ( 9 | get_requested_neighbor_lists, 10 | get_system_with_neighbor_lists, 11 | ) 12 | 13 | from . 
import MODEL_HYPERS, RESOURCES_PATH 14 | 15 | 16 | @pytest.mark.parametrize("training", [True, False]) 17 | @pytest.mark.parametrize("exported", [True, False]) 18 | def test_evaluate_model(training, exported): 19 | """Test that the evaluate_model function works as intended.""" 20 | 21 | systems = read_systems(RESOURCES_PATH / "carbon_reduced_100.xyz")[:2] 22 | 23 | atomic_types = set( 24 | torch.unique(torch.concatenate([system.types for system in systems])) 25 | ) 26 | 27 | targets = { 28 | "energy": get_energy_target_info( 29 | {"unit": "eV"}, 30 | add_position_gradients=True, 31 | add_strain_gradients=True, 32 | ) 33 | } 34 | 35 | dataset_info = DatasetInfo( 36 | length_unit="angstrom", atomic_types=atomic_types, targets=targets 37 | ) 38 | model = __model__(model_hypers=MODEL_HYPERS, dataset_info=dataset_info) 39 | 40 | if exported: 41 | model = model.export() 42 | 43 | requested_neighbor_lists = get_requested_neighbor_lists(model) 44 | systems = [ 45 | get_system_with_neighbor_lists(system, requested_neighbor_lists) 46 | for system in systems 47 | ] 48 | 49 | systems = [system.to(torch.float32) for system in systems] 50 | outputs = evaluate_model( 51 | model, systems, targets, is_training=training, check_consistency=True 52 | ) 53 | 54 | assert isinstance(outputs, dict) 55 | assert "energy" in outputs 56 | assert "positions" in outputs["energy"].block().gradients_list() 57 | assert "strain" in outputs["energy"].block().gradients_list() 58 | 59 | if training: 60 | assert outputs["energy"].block().gradient("positions").values.requires_grad 61 | assert outputs["energy"].block().gradient("strain").values.requires_grad 62 | else: 63 | assert not outputs["energy"].block().gradient("positions").values.requires_grad 64 | assert not outputs["energy"].block().gradient("strain").values.requires_grad 65 | -------------------------------------------------------------------------------- /tests/utils/test_external_naming.py: -------------------------------------------------------------------------------- 1 | from metatrain.utils.data.target_info import get_energy_target_info 2 | from metatrain.utils.external_naming import to_external_name, to_internal_name 3 | 4 | 5 | def test_to_external_name(): 6 | """Tests the to_external_name function.""" 7 | 8 | quantities = { 9 | "energy": get_energy_target_info({"unit": "eV"}), 10 | "mtt::free_energy": get_energy_target_info({"unit": "eV"}), 11 | "mtt::foo": get_energy_target_info({"unit": "eV"}), 12 | } 13 | 14 | # hack to test the fact that non-energies should be treated differently 15 | # (i.e., their gradients should not have special names) 16 | quantities["mtt::foo"].quantity = "bar" 17 | 18 | assert to_external_name("energy_positions_gradients", quantities) == "forces" 19 | assert ( 20 | to_external_name("mtt::free_energy_positions_gradients", quantities) 21 | == "forces[mtt::free_energy]" 22 | ) 23 | assert ( 24 | to_external_name("mtt::foo_positions_gradients", quantities) 25 | == "mtt::foo_positions_gradients" 26 | ) 27 | assert to_external_name("energy_strain_gradients", quantities) == "virial" 28 | assert ( 29 | to_external_name("mtt::free_energy_strain_gradients", quantities) 30 | == "virial[mtt::free_energy]" 31 | ) 32 | assert ( 33 | to_external_name("mtt::foo_strain_gradients", quantities) 34 | == "mtt::foo_strain_gradients" 35 | ) 36 | assert to_external_name("energy", quantities) == "energy" 37 | assert to_external_name("mtt::free_energy", quantities) == "mtt::free_energy" 38 | assert to_external_name("mtt::foo", quantities) == "mtt::foo" 39 
| 40 | 41 | def test_to_internal_name(): 42 | """Tests the to_internal_name function.""" 43 | 44 | assert to_internal_name("forces") == "energy_positions_gradients" 45 | assert ( 46 | to_internal_name("forces[mtt::free_energy]") 47 | == "mtt::free_energy_positions_gradients" 48 | ) 49 | assert ( 50 | to_internal_name("mtt::foo_positions_gradients") 51 | == "mtt::foo_positions_gradients" 52 | ) 53 | assert to_internal_name("virial") == "energy_strain_gradients" 54 | assert ( 55 | to_internal_name("virial[mtt::free_energy]") 56 | == "mtt::free_energy_strain_gradients" 57 | ) 58 | assert to_internal_name("mtt::foo_strain_gradients") == "mtt::foo_strain_gradients" 59 | assert to_internal_name("energy") == "energy" 60 | assert to_internal_name("mtt::free_energy") == "mtt::free_energy" 61 | assert to_internal_name("mtt::foo") == "mtt::foo" 62 | -------------------------------------------------------------------------------- /tests/utils/test_io.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import pytest 5 | from metatomic.torch import AtomisticModel 6 | 7 | from metatrain.soap_bpnn.model import SoapBpnn 8 | from metatrain.utils.io import check_file_extension, is_exported_file, load_model 9 | 10 | from . import RESOURCES_PATH 11 | 12 | 13 | def is_None(*args, **kwargs) -> None: 14 | return None 15 | 16 | 17 | @pytest.mark.parametrize("filename", ["example.txt", Path("example.txt")]) 18 | def test_check_suffix(filename): 19 | result = check_file_extension(filename, ".txt") 20 | 21 | assert str(result) == "example.txt" 22 | assert isinstance(result, type(filename)) 23 | 24 | 25 | @pytest.mark.parametrize("filename", ["example", Path("example")]) 26 | def test_warning_on_missing_suffix(filename): 27 | match = r"The file name should have a '\.txt' file extension." 28 | with pytest.warns(UserWarning, match=match): 29 | result = check_file_extension(filename, ".txt") 30 | 31 | assert str(result) == "example.txt" 32 | assert isinstance(result, type(filename)) 33 | 34 | 35 | def test_is_exported_file(): 36 | assert is_exported_file(RESOURCES_PATH / "model-32-bit.pt") 37 | assert not is_exported_file(RESOURCES_PATH / "model-32-bit.ckpt") 38 | 39 | 40 | @pytest.mark.parametrize( 41 | "path", 42 | [ 43 | RESOURCES_PATH / "model-32-bit.ckpt", 44 | str(RESOURCES_PATH / "model-32-bit.ckpt"), 45 | f"file:{str(RESOURCES_PATH / 'model-32-bit.ckpt')}", 46 | ], 47 | ) 48 | def test_load_model_checkpoint(path): 49 | model = load_model(path) 50 | assert type(model) is SoapBpnn 51 | 52 | # TODO: test that weights are the expected if loading with `context == 'export'`. 53 | # One can use `list(model.bpnn[0].parameters())[0][0]` to get some weights. But, 54 | # currently weights of the `"export"` and the `"restart"` context are the same... 
55 | 56 | 57 | @pytest.mark.parametrize( 58 | "path", 59 | [ 60 | RESOURCES_PATH / "model-32-bit.pt", 61 | str(RESOURCES_PATH / "model-32-bit.pt"), 62 | f"file:{str(RESOURCES_PATH / 'model-32-bit.pt')}", 63 | ], 64 | ) 65 | def test_load_model_exported(path): 66 | model = load_model(path) 67 | assert type(model) is AtomisticModel 68 | 69 | 70 | @pytest.mark.parametrize("suffix", [".yml", ".yaml"]) 71 | def test_load_model_yaml(suffix): 72 | match = f"path 'foo{suffix}' seems to be a YAML option file and not a model" 73 | with pytest.raises(ValueError, match=match): 74 | load_model(f"foo{suffix}") 75 | 76 | 77 | def test_load_model_token(): 78 | """Test that the export cli succeeds when exporting a private 79 | model from HuggingFace.""" 80 | 81 | hf_token = os.getenv("HUGGINGFACE_TOKEN_METATRAIN") 82 | if hf_token is None: 83 | pytest.skip("HuggingFace token not found in environment.") 84 | assert len(hf_token) > 0 85 | 86 | path = "https://huggingface.co/metatensor/metatrain-test/resolve/main/model.ckpt" 87 | load_model(path, hf_token=hf_token) 88 | 89 | 90 | def test_load_model_token_invalid_url_style(): 91 | hf_token = os.getenv("HUGGINGFACE_TOKEN_METATRAIN") 92 | if hf_token is None: 93 | pytest.skip("HuggingFace token not found in environment.") 94 | assert len(hf_token) > 0 95 | 96 | # change `resolve` to ``foo`` to make the URL scheme invalid 97 | path = "https://huggingface.co/metatensor/metatrain-test/foo/main/model.ckpt" 98 | 99 | with pytest.raises( 100 | ValueError, 101 | match=f"URL '{path}' has an invalid format for the Hugging Face Hub.", 102 | ): 103 | load_model(path, hf_token=hf_token) 104 | -------------------------------------------------------------------------------- /tests/utils/test_jsonschema.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import pytest 4 | from jsonschema.exceptions import ValidationError 5 | 6 | from metatrain.utils.architectures import get_architecture_path 7 | from metatrain.utils.jsonschema import validate 8 | 9 | 10 | def schema(): 11 | with open(get_architecture_path("soap_bpnn") / "schema-hypers.json", "r") as f: 12 | return json.load(f) 13 | 14 | 15 | def test_validate_valid(): 16 | instance = { 17 | "name": "soap_bpnn", 18 | "training": {"num_epochs": 1, "batch_size": 2}, 19 | } 20 | validate(instance=instance, schema=schema()) 21 | 22 | 23 | def test_validate_single_suggestion(): 24 | """Two invalid names; one to random that a useful suggestion can be given.""" 25 | instance = { 26 | "name": "soap_bpnn", 27 | "training": {"nasdasd": 1, "batch_sizes": 2}, 28 | } 29 | match = ( 30 | r"Unrecognized options \('batch_sizes', 'nasdasd' were unexpected\). " 31 | r"Do you mean 'batch_size'?" 32 | ) 33 | with pytest.raises(ValidationError, match=match): 34 | validate(instance=instance, schema=schema()) 35 | 36 | 37 | def test_validate_multi_suggestion(): 38 | instance = { 39 | "name": "soap_bpnn", 40 | "training": {"num_epoch": 1, "batch_sizes": 2}, 41 | } 42 | match = ( 43 | r"Unrecognized options \('batch_sizes', 'num_epoch' were unexpected\). 
" 44 | r"Do you mean" 45 | ) 46 | with pytest.raises(ValidationError, match=match): 47 | validate(instance=instance, schema=schema()) 48 | -------------------------------------------------------------------------------- /tests/utils/test_long_range.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from metatomic.torch import systems_to_torch 4 | 5 | from metatrain.experimental.nanopet import NanoPET 6 | from metatrain.soap_bpnn import SoapBpnn 7 | from metatrain.utils.architectures import get_default_hypers 8 | from metatrain.utils.data import DatasetInfo 9 | from metatrain.utils.data.readers.ase import read 10 | from metatrain.utils.data.target_info import ( 11 | get_energy_target_info, 12 | ) 13 | from metatrain.utils.neighbor_lists import ( 14 | get_requested_neighbor_lists, 15 | get_system_with_neighbor_lists, 16 | ) 17 | 18 | from . import RESOURCES_PATH 19 | 20 | 21 | @pytest.mark.parametrize("periodicity", [True, False]) 22 | # we only have torchPME integration for nanoPET and SOAP-BPNN for now 23 | @pytest.mark.parametrize("model_name", ["experimental.nanopet", "soap_bpnn"]) 24 | def test_long_range(periodicity, model_name, tmpdir): 25 | """Tests that the long-range module can predict successfully.""" 26 | 27 | if periodicity: 28 | structures = read(RESOURCES_PATH / "carbon_reduced_100.xyz", ":10") 29 | else: 30 | structures = read(RESOURCES_PATH / "ethanol_reduced_100.xyz", ":10") 31 | systems = systems_to_torch(structures) 32 | 33 | dataset_info = DatasetInfo( 34 | length_unit="Angstrom", 35 | atomic_types=[1, 6, 8], 36 | targets={"energy": get_energy_target_info({"unit": "eV"})}, 37 | ) 38 | 39 | hypers = get_default_hypers(model_name) 40 | hypers["model"]["long_range"]["enable"] = True 41 | if model_name == "soap_bpnn": 42 | model = SoapBpnn(hypers["model"], dataset_info) 43 | else: 44 | model = NanoPET(hypers["model"], dataset_info) 45 | requested_nls = get_requested_neighbor_lists(model) 46 | 47 | systems = [ 48 | get_system_with_neighbor_lists(system, requested_nls) for system in systems 49 | ] 50 | 51 | model( 52 | systems, 53 | {"energy": model.outputs["energy"]}, 54 | ) 55 | 56 | # now torchscripted 57 | model = torch.jit.script(model) 58 | model( 59 | systems, 60 | {"energy": model.outputs["energy"]}, 61 | ) 62 | 63 | # torch.jit.save and torch.jit.load 64 | with tmpdir.as_cwd(): 65 | torch.jit.save(model, "model.pt") 66 | model = torch.jit.load("model.pt") 67 | model( 68 | systems, 69 | {"energy": model.outputs["energy"]}, 70 | ) 71 | -------------------------------------------------------------------------------- /tests/utils/test_metadata.py: -------------------------------------------------------------------------------- 1 | from metatomic.torch import ModelMetadata 2 | 3 | from metatrain.utils.metadata import merge_metadata 4 | 5 | 6 | def test_append_metadata_new_keys(): 7 | self_meta = ModelMetadata(references={"implementation": ["ref1"]}) 8 | other_meta = ModelMetadata(references={"architecture": ["ref2"]}) 9 | 10 | result = merge_metadata(self_meta, other_meta) 11 | 12 | assert result.references["implementation"] == ["ref1"] 13 | assert result.references["architecture"] == ["ref2"] 14 | 15 | 16 | def test_append_metadata_existing_keys(): 17 | self_meta = ModelMetadata(references={"implementation": ["ref1"]}) 18 | other_meta = ModelMetadata(references={"implementation": ["ref2"]}) 19 | 20 | result = merge_metadata(self_meta, other_meta) 21 | 22 | assert result.references["implementation"] 
== ["ref1", "ref2"] 23 | 24 | 25 | def test_append_metadata_mixed_keys(): 26 | self_meta = ModelMetadata(references={"implementation": ["ref1"]}) 27 | other_meta = ModelMetadata( 28 | references={"implementation": ["ref2"], "architecture": ["ref3"]} 29 | ) 30 | 31 | result = merge_metadata(self_meta, other_meta) 32 | 33 | assert result.references["implementation"] == ["ref1", "ref2"] 34 | assert result.references["architecture"] == ["ref3"] 35 | 36 | 37 | def test_merge_metadata(): 38 | self_meta = ModelMetadata( 39 | name="self_meta", 40 | description="self_meta", 41 | authors=[ 42 | "John Doe", 43 | "Jane Smith", 44 | ], 45 | references={ 46 | "architecture": ["ref1"], 47 | "model": ["ref2"], 48 | }, 49 | ) 50 | other_meta = ModelMetadata( 51 | name="other_meta", 52 | description="other_meta", 53 | authors=[ 54 | "John Doe", 55 | "Alice Johnson", 56 | ], 57 | references={ 58 | "model": ["ref3"], 59 | "implementation": ["ref4"], 60 | }, 61 | ) 62 | 63 | result = merge_metadata(self_meta, other_meta) 64 | 65 | assert result.name == "other_meta" 66 | assert result.description == "other_meta" 67 | assert result.authors == [ 68 | "John Doe", 69 | "Jane Smith", 70 | "Alice Johnson", 71 | ] 72 | assert result.references["architecture"] == ["ref1"] 73 | assert result.references["model"] == ["ref2", "ref3"] 74 | assert result.references["implementation"] == ["ref4"] 75 | -------------------------------------------------------------------------------- /tests/utils/test_neighbor_list.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from metatomic.torch import NeighborListOptions 4 | 5 | from metatrain.utils.data.readers.ase import read_systems 6 | from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists 7 | 8 | 9 | RESOURCES_PATH = Path(__file__).parents[1] / "resources" 10 | 11 | 12 | def test_attach_neighbor_lists(): 13 | filename = RESOURCES_PATH / "qm9_reduced_100.xyz" 14 | systems = read_systems(filename) 15 | 16 | requested_neighbor_lists = [ 17 | NeighborListOptions(cutoff=4.0, full_list=True, strict=True), 18 | NeighborListOptions(cutoff=5.0, full_list=False, strict=True), 19 | NeighborListOptions(cutoff=6.0, full_list=True, strict=True), 20 | ] 21 | 22 | new_system = get_system_with_neighbor_lists(systems[0], requested_neighbor_lists) 23 | 24 | assert requested_neighbor_lists[0] in new_system.known_neighbor_lists() 25 | assert requested_neighbor_lists[1] in new_system.known_neighbor_lists() 26 | assert requested_neighbor_lists[2] in new_system.known_neighbor_lists() 27 | 28 | extraneous_nl = NeighborListOptions(cutoff=5.0, full_list=True, strict=True) 29 | assert extraneous_nl not in new_system.known_neighbor_lists() 30 | 31 | for nl_options in new_system.known_neighbor_lists(): 32 | nl = new_system.get_neighbor_list(nl_options) 33 | assert nl.samples.names == [ 34 | "first_atom", 35 | "second_atom", 36 | "cell_shift_a", 37 | "cell_shift_b", 38 | "cell_shift_c", 39 | ] 40 | assert len(nl.values.shape) == 3 41 | -------------------------------------------------------------------------------- /tests/utils/test_sum_over_atoms.py: -------------------------------------------------------------------------------- 1 | import metatensor.torch 2 | import torch 3 | from metatensor.torch import Labels, TensorBlock, TensorMap 4 | 5 | from metatrain.utils.sum_over_atoms import sum_over_atoms 6 | 7 | 8 | def test_sum_over_atoms(): 9 | """Test the sum_over_atoms function.""" 10 | block1 = TensorBlock( 11 | 
values=torch.tensor([[[1.0]], [[2.0]], [[3.0]]]),
12 |         samples=Labels(
13 |             names=["system", "atom"],
14 |             values=torch.tensor([[0, 0], [0, 1], [1, 0]]),
15 |         ),
16 |         components=[Labels.range("comp", 1)],
17 |         properties=Labels.single(),
18 |     )
19 | 
20 |     block2 = TensorBlock(
21 |         values=torch.tensor([[[4.0], [5.0]], [[6.0], [7.0]], [[8.0], [9.0]]]),
22 |         samples=Labels(
23 |             names=["system", "atom"],
24 |             values=torch.tensor([[0, 0], [0, 1], [1, 0]]),
25 |         ),
26 |         components=[Labels.range("comp", 2)],
27 |         properties=Labels.single(),
28 |     )
29 | 
30 |     tensor_map = TensorMap(
31 |         keys=Labels.range("key", 2),
32 |         blocks=[block1, block2],
33 |     )
34 | 
35 |     # Call the sum_over_atoms function
36 |     summed_tensor_map = sum_over_atoms(tensor_map)
37 | 
38 |     summed_tensor_map_ref = metatensor.torch.sum_over_samples(
39 |         tensor_map,
40 |         sample_names=["atom"],
41 |     )
42 | 
43 |     assert metatensor.torch.allclose(summed_tensor_map, summed_tensor_map_ref)
44 | 
--------------------------------------------------------------------------------
/tests/utils/test_transfer.py:
--------------------------------------------------------------------------------
 1 | import metatensor.torch
 2 | import torch
 3 | from metatensor.torch import Labels, TensorMap
 4 | from metatomic.torch import System
 5 | 
 6 | from metatrain.utils.transfer import (
 7 |     systems_and_targets_to_device,
 8 |     systems_and_targets_to_dtype,
 9 | )
10 | 
11 | 
12 | def test_systems_and_targets_to_dtype():
13 |     system = System(
14 |         positions=torch.tensor([[1.0, 1.0, 1.0]]),
15 |         cell=torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]),
16 |         types=torch.tensor([1]),
17 |         pbc=torch.tensor([True, True, True]),
18 |     )
19 |     targets = TensorMap(
20 |         keys=Labels.single(),
21 |         blocks=[metatensor.torch.block_from_array(torch.tensor([[1.0]]))],
22 |     )
23 | 
24 |     systems = [system]
25 |     targets = {"energy": targets}
26 | 
27 |     assert systems[0].positions.dtype == torch.float32
28 |     assert systems[0].cell.dtype == torch.float32
29 |     assert targets["energy"].block().values.dtype == torch.float32
30 | 
31 |     systems, targets = systems_and_targets_to_dtype(systems, targets, torch.float64)
32 | 
33 |     assert systems[0].positions.dtype == torch.float64
34 |     assert systems[0].cell.dtype == torch.float64
35 |     assert targets["energy"].block().values.dtype == torch.float64
36 | 
37 | 
38 | def test_systems_and_targets_to_device():
39 |     system = System(
40 |         positions=torch.tensor([[1.0, 1.0, 1.0]]),
41 |         cell=torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]),
42 |         types=torch.tensor([1]),
43 |         pbc=torch.tensor([True, True, True]),
44 |     )
45 |     targets = TensorMap(
46 |         keys=Labels.single(),
47 |         blocks=[metatensor.torch.block_from_array(torch.tensor([[1.0]]))],
48 |     )
49 | 
50 |     systems = [system]
51 |     targets = {"energy": targets}
52 | 
53 |     assert systems[0].positions.device == torch.device("cpu")
54 |     assert systems[0].types.device == torch.device("cpu")
55 |     assert targets["energy"].block().values.device == torch.device("cpu")
56 | 
57 |     systems, targets = systems_and_targets_to_device(
58 |         systems, targets, torch.device("meta")
59 |     )
60 | 
61 |     assert systems[0].positions.device == torch.device("meta")
62 |     assert systems[0].types.device == torch.device("meta")
63 |     assert targets["energy"].block().values.device == torch.device("meta")
64 | 
--------------------------------------------------------------------------------
/tests/utils/test_units.py:
--------------------------------------------------------------------------------
 1 | import
pytest 2 | 3 | from metatrain.utils.units import ev_to_mev, get_gradient_units 4 | 5 | 6 | def test_get_gradient_units(): 7 | """Tests the get_gradient_units function.""" 8 | # Test the case where the base unit is empty 9 | assert get_gradient_units("", "positions", "angstrom") == "" 10 | # Test the case where the length unit is angstrom 11 | for length_unit in ["angstrom", "å", "ångstrom", "Ångstrom", "Å"]: 12 | assert get_gradient_units("unit", "positions", length_unit) == "unit/A" 13 | # Test the case where the gradient name is strain 14 | assert get_gradient_units("unit", "strain", "angstrom") == "unit" 15 | # Test the case where the gradient name is unknown 16 | with pytest.raises(ValueError): 17 | get_gradient_units("unit", "unknown", "angstrom") 18 | 19 | 20 | def test_ev_to_mev(): 21 | """Tests the ev_to_mev function.""" 22 | # Test the case where the unit is not eV 23 | assert ev_to_mev(1.0, "unit") == (1.0, "unit") 24 | # Test the case where the unit is eV 25 | assert ev_to_mev(1.0, "eV") == (1000.0, "meV") 26 | # Test the case where the unit is eV with a different case 27 | assert ev_to_mev(0.2, "ev") == (200.0, "mev") 28 | # Test the case where the unit is a derived unit of eV 29 | assert ev_to_mev(1.0, "eV/unit") == (1000.0, "meV/unit") 30 | --------------------------------------------------------------------------------