├── .github └── workflows │ ├── black.yml │ └── python-publish.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── docs ├── Makefile ├── _templates │ └── autosummary │ │ └── classtemplate.rst ├── api │ ├── atomistic.rst │ ├── data.rst │ ├── datasets.rst │ ├── md.rst │ ├── model.rst │ ├── nn.rst │ ├── representation.rst │ ├── schnetpack.rst │ ├── task.rst │ ├── train.rst │ └── transform.rst ├── conf.py ├── getstarted.rst ├── howtos ├── index.rst ├── pictures │ └── tensorboard.png ├── sphinx-requirements.txt ├── tutorials └── userguide │ ├── configs.rst │ ├── md.rst │ └── overview.rst ├── examples ├── README.md ├── howtos │ ├── howto_batchwise_relaxations.ipynb │ └── lammps.rst └── tutorials │ ├── tutorial_01_preparing_data.ipynb │ ├── tutorial_02_qm9.ipynb │ ├── tutorial_03_force_models.ipynb │ ├── tutorial_04_molecular_dynamics.ipynb │ ├── tutorial_05_materials.ipynb │ └── tutorials_figures │ ├── integrator.svg │ ├── md_flowchart.svg │ └── rpmd.svg ├── interfaces └── lammps │ ├── examples │ └── aspirin │ │ ├── aspirin.data │ │ ├── aspirin_md.in │ │ └── best_model │ ├── pair_schnetpack.cpp │ ├── pair_schnetpack.h │ └── patch_lammps.sh ├── pyproject.toml ├── readthedocs.yaml ├── src ├── schnetpack │ ├── __init__.py │ ├── atomistic │ │ ├── __init__.py │ │ ├── aggregation.py │ │ ├── atomwise.py │ │ ├── distances.py │ │ ├── electrostatic.py │ │ ├── external_fields.py │ │ ├── nuclear_repulsion.py │ │ └── response.py │ ├── cli.py │ ├── configs │ │ ├── __init__.py │ │ ├── callbacks │ │ │ ├── checkpoint.yaml │ │ │ ├── earlystopping.yaml │ │ │ ├── ema.yaml │ │ │ └── lrmonitor.yaml │ │ ├── data │ │ │ ├── ani1.yaml │ │ │ ├── custom.yaml │ │ │ ├── iso17.yaml │ │ │ ├── materials_project.yaml │ │ │ ├── md17.yaml │ │ │ ├── md22.yaml │ │ │ ├── omdb.yaml │ │ │ ├── qm7x.yaml │ │ │ ├── qm9.yaml │ │ │ ├── rmd17.yaml │ │ │ └── sampler │ │ │ │ └── stratified_property.yaml │ │ ├── experiment │ │ │ ├── md17.yaml │ │ │ ├── qm9_atomwise.yaml │ │ │ ├── qm9_dipole.yaml 
│ │ │ ├── response.yaml │ │ │ └── rmd17.yaml │ │ ├── globals │ │ │ └── default_globals.yaml │ │ ├── logger │ │ │ ├── aim.yaml │ │ │ ├── csv.yaml │ │ │ ├── tensorboard.yaml │ │ │ └── wandb.yaml │ │ ├── model │ │ │ ├── nnp.yaml │ │ │ └── representation │ │ │ │ ├── field_schnet.yaml │ │ │ │ ├── painn.yaml │ │ │ │ ├── radial_basis │ │ │ │ ├── bessel.yaml │ │ │ │ └── gaussian.yaml │ │ │ │ ├── schnet.yaml │ │ │ │ └── so3net.yaml │ │ ├── predict.yaml │ │ ├── run │ │ │ └── default_run.yaml │ │ ├── task │ │ │ ├── default_task.yaml │ │ │ ├── optimizer │ │ │ │ ├── adabelief.yaml │ │ │ │ ├── adam.yaml │ │ │ │ └── sgd.yaml │ │ │ └── scheduler │ │ │ │ └── reduce_on_plateau.yaml │ │ ├── train.yaml │ │ └── trainer │ │ │ ├── ddp_debug.yaml │ │ │ ├── ddp_trainer.yaml │ │ │ ├── debug_trainer.yaml │ │ │ └── default_trainer.yaml │ ├── data │ │ ├── __init__.py │ │ ├── atoms.py │ │ ├── datamodule.py │ │ ├── loader.py │ │ ├── sampler.py │ │ ├── splitting.py │ │ └── stats.py │ ├── datasets │ │ ├── __init__.py │ │ ├── ani1.py │ │ ├── iso17.py │ │ ├── materials_project.py │ │ ├── md17.py │ │ ├── md22.py │ │ ├── omdb.py │ │ ├── qm7x.py │ │ ├── qm9.py │ │ ├── rmd17.py │ │ └── tmqm.py │ ├── interfaces │ │ ├── __init__.py │ │ ├── ase_interface.py │ │ └── batchwise_optimization.py │ ├── md │ │ ├── __init__.py │ │ ├── calculators │ │ │ ├── __init__.py │ │ │ ├── base_calculator.py │ │ │ ├── ensemble_calculator.py │ │ │ ├── lj_calculator.py │ │ │ ├── orca_calculator.py │ │ │ └── schnetpack_calculator.py │ │ ├── cli.py │ │ ├── data │ │ │ ├── __init__.py │ │ │ ├── hdf5_data.py │ │ │ └── spectra.py │ │ ├── initial_conditions.py │ │ ├── integrators.py │ │ ├── md_configs │ │ │ ├── __init__.py │ │ │ ├── calculator │ │ │ │ ├── lj.yaml │ │ │ │ ├── neighbor_list │ │ │ │ │ ├── ase.yaml │ │ │ │ │ ├── matscipy.yaml │ │ │ │ │ └── torch.yaml │ │ │ │ ├── orca.yaml │ │ │ │ ├── spk.yaml │ │ │ │ └── spk_ensemble.yaml │ │ │ ├── callbacks │ │ │ │ ├── checkpoint.yaml │ │ │ │ ├── hdf5.yaml │ │ │ │ └── tensorboard.yaml │ 
│ │ ├── config.yaml │ │ │ ├── dynamics │ │ │ │ ├── barostat │ │ │ │ │ ├── nhc_aniso.yaml │ │ │ │ │ ├── nhc_iso.yaml │ │ │ │ │ └── pile_rpmd.yaml │ │ │ │ ├── base.yaml │ │ │ │ ├── integrator │ │ │ │ │ ├── md.yaml │ │ │ │ │ └── rpmd.yaml │ │ │ │ └── thermostat │ │ │ │ │ ├── berendsen.yaml │ │ │ │ │ ├── gle.yaml │ │ │ │ │ ├── langevin.yaml │ │ │ │ │ ├── nhc.yaml │ │ │ │ │ ├── pi_gle.yaml │ │ │ │ │ ├── pi_nhc_global.yaml │ │ │ │ │ ├── pi_nhc_local.yaml │ │ │ │ │ ├── piglet.yaml │ │ │ │ │ ├── pile_global.yaml │ │ │ │ │ ├── pile_local.yaml │ │ │ │ │ └── trpmd.yaml │ │ │ └── system │ │ │ │ ├── initializer │ │ │ │ ├── maxwell_boltzmann.yaml │ │ │ │ └── uniform.yaml │ │ │ │ └── system.yaml │ │ ├── neighborlist_md.py │ │ ├── parsers │ │ │ ├── __init__.py │ │ │ └── orca_parser.py │ │ ├── simulation_hooks │ │ │ ├── __init__.py │ │ │ ├── barostats.py │ │ │ ├── barostats_rpmd.py │ │ │ ├── basic_hooks.py │ │ │ ├── callback_hooks.py │ │ │ ├── thermostats.py │ │ │ └── thermostats_rpmd.py │ │ ├── simulator.py │ │ ├── system.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── md_config.py │ │ │ ├── normal_model_transformation.py │ │ │ └── thermostat_utils.py │ ├── model │ │ ├── __init__.py │ │ └── base.py │ ├── nn │ │ ├── __init__.py │ │ ├── activations.py │ │ ├── base.py │ │ ├── blocks.py │ │ ├── cutoff.py │ │ ├── embedding.py │ │ ├── equivariant.py │ │ ├── ops │ │ │ ├── __init__.py │ │ │ ├── math.py │ │ │ └── so3.py │ │ ├── radial.py │ │ ├── scatter.py │ │ ├── so3.py │ │ └── utils.py │ ├── properties.py │ ├── representation │ │ ├── __init__.py │ │ ├── field_schnet.py │ │ ├── painn.py │ │ ├── schnet.py │ │ └── so3net.py │ ├── task.py │ ├── train │ │ ├── __init__.py │ │ ├── callbacks.py │ │ ├── lr_scheduler.py │ │ └── metrics.py │ ├── transform │ │ ├── __init__.py │ │ ├── atomistic.py │ │ ├── base.py │ │ ├── casting.py │ │ ├── neighborlist.py │ │ └── response.py │ ├── units.py │ └── utils │ │ ├── __init__.py │ │ ├── compatibility.py │ │ └── script.py └── scripts │ ├── spkconvert │ ├── 
spkdeploy │ ├── spkmd │ ├── spkpredict │ └── spktrain ├── tests ├── README.md ├── __init__.py ├── atomistic │ ├── __init__.py │ └── test_response.py ├── conftest.py ├── data │ ├── __init__.py │ ├── conftest.py │ ├── test_data.py │ ├── test_datasets.py │ ├── test_loader.py │ └── test_transforms.py ├── nn │ ├── __init__.py │ ├── test_activations.py │ ├── test_cutoff.py │ ├── test_radial.py │ └── test_schnet.py ├── testdata │ ├── md_ethanol.model │ ├── md_ethanol.xyz │ ├── si16.model │ ├── test_qm9.db │ └── tmp │ │ └── .gitkeep └── user_config │ └── user_exp.yaml └── tox.ini /.github/workflows/black.yml: -------------------------------------------------------------------------------- 1 | name: Black Code Formatter 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | black: 7 | runs-on: ubuntu-latest 8 | 9 | steps: 10 | - name: Checkout code 11 | uses: actions/checkout@v2 12 | 13 | - name: Set up Python 14 | uses: actions/setup-python@v2 15 | with: 16 | python-version: '3.x' # Specify the Python version you need 17 | 18 | - name: Install dependencies 19 | run: | 20 | python -m pip install --upgrade pip 21 | pip install black==24.4.2 22 | 23 | - name: Run black 24 | run: black --check . 25 | 26 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 
8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v3 25 | - name: Set up Python 26 | uses: actions/setup-python@v3 27 | with: 28 | python-version: '3.x' 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install build 33 | - name: Build package 34 | run: python -m build 35 | - name: Publish package 36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 37 | with: 38 | user: __token__ 39 | password: ${{ secrets.PYPI_API_TOKEN }} 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .idea 6 | *.DS_Store 7 | 8 | docs/tutorials/*.db 9 | docs/tutorials/*.xyz 10 | docs/tutorials/qm9tut 11 | 12 | # hydra stuff 13 | outputs 14 | logs 15 | generated 16 | 17 | # C extensions 18 | *.so 19 | 20 | 21 | # Distribution / packaging 22 | .Python 23 | env/ 24 | build/ 25 | develop-eggs/ 26 | dist/ 27 | downloads/ 28 | eggs/ 29 | .eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | .vscode 40 | .github 41 | 42 | # PyInstaller 43 | # Usually these files are written by a python script from a template 44 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
45 | *.manifest 46 | *.spec 47 | 48 | # Installer logs 49 | pip-log.txt 50 | pip-delete-this-directory.txt 51 | 52 | # Unit test / coverage reports 53 | htmlcov/ 54 | .tox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | .hypothesis/ 62 | .pytest_cache 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | local_settings.py 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | docs/api/generated/ 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # celery beat schedule file 92 | celerybeat-schedule 93 | 94 | # SageMath parsed files 95 | *.sage.py 96 | 97 | # dotenv 98 | .env 99 | 100 | # virtualenv 101 | .venv 102 | venv/ 103 | ENV/ 104 | env/ 105 | 106 | # Spyder project settings 107 | .spyderproject 108 | .spyproject 109 | 110 | # Rope project settings 111 | .ropeproject 112 | 113 | # mkdocs documentation 114 | /site 115 | 116 | # mypy 117 | .mypy_cache/ 118 | 119 | # experiments 120 | runs/ 121 | tests/testdata/tmp/test* 122 | 123 | # lammps examples 124 | interfaces/lammps/examples/*/*.lammpstrj 125 | interfaces/lammps/examples/*/*.lammps 126 | interfaces/lammps/examples/*/*.dump 127 | interfaces/lammps/examples/*/*.dat 128 | interfaces/lammps/examples/*/deployed_model 129 | 130 | # batchwise optimizer examples 131 | examples/howtos/howto_batchwise_relaxations_outputs/* -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/python/black 3 | rev: 24.4.2 4 | hooks: 5 | - id: black 6 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | COPYRIGHT 2 | 3 | Copyright (c) 2018 Kristof Schütt, Michael Gastegger, Pan Kessel, Kim Nicoli 4 | 5 | All other contributions: 6 | Copyright (c) 2018, the respective contributors. 7 | All rights reserved. 8 | 9 | Each contributor holds copyright over their respective contributions. 10 | The project versioning (Git) records all such contribution source information. 11 | 12 | LICENSE 13 | 14 | The MIT License 15 | 16 | Permission is hereby granted, free of charge, to any person obtaining a copy 17 | of this software and associated documentation files (the "Software"), to deal 18 | in the Software without restriction, including without limitation the rights 19 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 20 | copies of the Software, and to permit persons to whom the Software is 21 | furnished to do so, subject to the following conditions: 22 | 23 | The above copyright notice and this permission notice shall be included in all 24 | copies or substantial portions of the Software. 25 | 26 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 27 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 28 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 29 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 30 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 31 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 32 | SOFTWARE. 33 | 34 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 
5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = SchNetPack 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /docs/_templates/autosummary/classtemplate.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | 5 | {{ name | underline}} 6 | 7 | .. autoclass:: {{ name }} 8 | :no-inherited-members: 9 | :members: 10 | -------------------------------------------------------------------------------- /docs/api/atomistic.rst: -------------------------------------------------------------------------------- 1 | schnetpack.atomistic 2 | ==================== 3 | .. currentmodule:: atomistic 4 | 5 | Output modules 6 | -------------- 7 | .. rubric:: Atom-wise layers 8 | 9 | .. autosummary:: 10 | :toctree: generated 11 | :nosignatures: 12 | :template: classtemplate.rst 13 | 14 | Atomwise 15 | DipoleMoment 16 | Polarizability 17 | 18 | .. rubric:: Response layers 19 | 20 | .. autosummary:: 21 | :toctree: generated 22 | :nosignatures: 23 | :template: classtemplate.rst 24 | 25 | Forces 26 | -------------------------------------------------------------------------------- /docs/api/data.rst: -------------------------------------------------------------------------------- 1 | schnetpack.data 2 | =============== 3 | .. currentmodule:: data 4 | 5 | Atoms data 6 | ------------ 7 | .. 
autosummary:: 8 | :toctree: generated 9 | :nosignatures: 10 | :recursive: 11 | :template: classtemplate.rst 12 | 13 | BaseAtomsData 14 | ASEAtomsData 15 | AtomsLoader 16 | resolve_format 17 | AtomsDataFormat 18 | StratifiedSampler 19 | 20 | 21 | Creation 22 | -------- 23 | .. autosummary:: 24 | :toctree: generated 25 | :nosignatures: 26 | :template: classtemplate.rst 27 | 28 | create_dataset 29 | load_dataset 30 | 31 | Data modules 32 | ------------ 33 | 34 | .. autosummary:: 35 | :toctree: generated 36 | :nosignatures: 37 | :template: classtemplate.rst 38 | 39 | AtomsDataModule 40 | 41 | Statistics 42 | ---------- 43 | 44 | .. autosummary:: 45 | :toctree: generated 46 | :nosignatures: 47 | :template: classtemplate.rst 48 | 49 | calculate_stats 50 | NumberOfAtomsCriterion 51 | PropertyCriterion -------------------------------------------------------------------------------- /docs/api/datasets.rst: -------------------------------------------------------------------------------- 1 | schnetpack.datasets 2 | =================== 3 | .. currentmodule:: datasets 4 | 5 | Molecules 6 | ------------ 7 | .. autosummary:: 8 | :toctree: generated 9 | :nosignatures: 10 | :template: classtemplate.rst 11 | 12 | QM9 13 | MD17 14 | ANI1 15 | ISO17 16 | 17 | Materials 18 | --------- 19 | .. autosummary:: 20 | :toctree: generated 21 | :nosignatures: 22 | :template: classtemplate.rst 23 | 24 | MaterialsProject 25 | OrganicMaterialsDatabase 26 | 27 | 28 | -------------------------------------------------------------------------------- /docs/api/md.rst: -------------------------------------------------------------------------------- 1 | schnetpack.md 2 | ============= 3 | .. currentmodule:: md 4 | 5 | This module contains all functionality for performing various molecular dynamics simulations using SchNetPack. 6 | 7 | System 8 | ------ 9 | 10 | .. 
autosummary:: 11 | :toctree: generated 12 | :nosignatures: 13 | :template: classtemplate.rst 14 | 15 | System 16 | 17 | 18 | Initial Conditions 19 | ------------------ 20 | 21 | .. currentmodule:: md.initial_conditions 22 | 23 | .. autosummary:: 24 | :toctree: generated 25 | :nosignatures: 26 | :template: classtemplate.rst 27 | 28 | Initializer 29 | MaxwellBoltzmannInit 30 | UniformInit 31 | 32 | 33 | Integrators 34 | ----------- 35 | 36 | .. currentmodule:: md.integrators 37 | 38 | Integrators for NVE and NVT simulations: 39 | 40 | .. autosummary:: 41 | :toctree: generated 42 | :nosignatures: 43 | :template: classtemplate.rst 44 | 45 | Integrator 46 | VelocityVerlet 47 | RingPolymer 48 | 49 | Integrators for NPT simulations: 50 | 51 | .. autosummary:: 52 | :toctree: generated 53 | :nosignatures: 54 | :template: classtemplate.rst 55 | 56 | NPTVelocityVerlet 57 | NPTRingPolymer 58 | 59 | 60 | Calculators 61 | ----------- 62 | 63 | .. currentmodule:: md.calculators 64 | 65 | Basic calculators: 66 | 67 | .. autosummary:: 68 | :toctree: generated 69 | :nosignatures: 70 | :template: classtemplate.rst 71 | 72 | MDCalculator 73 | QMCalculator 74 | EnsembleCalculator 75 | LJCalculator 76 | 77 | Neural network potentials and ORCA calculators: 78 | 79 | .. autosummary:: 80 | :toctree: generated 81 | :nosignatures: 82 | :template: classtemplate.rst 83 | 84 | SchNetPackCalculator 85 | SchNetPackEnsembleCalculator 86 | OrcaCalculator 87 | 88 | 89 | Neighbor List 90 | ------------- 91 | 92 | .. currentmodule:: md.neighborlist_md 93 | 94 | .. autosummary:: 95 | :toctree: generated 96 | :nosignatures: 97 | :template: classtemplate.rst 98 | 99 | NeighborListMD 100 | 101 | 102 | Simulator 103 | --------- 104 | 105 | .. currentmodule:: md 106 | 107 | .. autosummary:: 108 | :toctree: generated 109 | :nosignatures: 110 | :template: classtemplate.rst 111 | 112 | Simulator 113 | 114 | 115 | Simulation hooks 116 | ---------------- 117 | 118 | .. 
currentmodule:: md.simulation_hooks 119 | 120 | Basic hooks: 121 | 122 | .. autosummary:: 123 | :toctree: generated 124 | :nosignatures: 125 | :template: classtemplate.rst 126 | 127 | SimulationHook 128 | RemoveCOMMotion 129 | 130 | Thermostats: 131 | 132 | .. autosummary:: 133 | :toctree: generated 134 | :nosignatures: 135 | :template: classtemplate.rst 136 | 137 | ThermostatHook 138 | BerendsenThermostat 139 | LangevinThermostat 140 | NHCThermostat 141 | GLEThermostat 142 | 143 | Thermostats for ring-polymer MD: 144 | 145 | .. autosummary:: 146 | :toctree: generated 147 | :nosignatures: 148 | :template: classtemplate.rst 149 | 150 | PILELocalThermostat 151 | PILEGlobalThermostat 152 | TRPMDThermostat 153 | RPMDGLEThermostat 154 | PIGLETThermostat 155 | NHCRingPolymerThermostat 156 | 157 | Barostats: 158 | 159 | .. autosummary:: 160 | :toctree: generated 161 | :nosignatures: 162 | :template: classtemplate.rst 163 | 164 | BarostatHook 165 | NHCBarostatIsotropic 166 | NHCBarostatAnisotropic 167 | 168 | Barostats for ring-polymer MD: 169 | 170 | .. autosummary:: 171 | :toctree: generated 172 | :nosignatures: 173 | :template: classtemplate.rst 174 | 175 | PILEBarostat 176 | 177 | Logging and callback 178 | 179 | .. autosummary:: 180 | :toctree: generated 181 | :nosignatures: 182 | :template: classtemplate.rst 183 | 184 | Checkpoint 185 | DataStream 186 | MoleculeStream 187 | PropertyStream 188 | FileLogger 189 | BasicTensorboardLogger 190 | TensorBoardLogger 191 | 192 | 193 | Simulation data and postprocessing 194 | ---------------------------------- 195 | 196 | .. currentmodule:: md.data 197 | 198 | Data loading: 199 | 200 | .. autosummary:: 201 | :toctree: generated 202 | :nosignatures: 203 | :template: classtemplate.rst 204 | 205 | HDF5Loader 206 | 207 | Vibrational spectra: 208 | 209 | .. 
autosummary:: 210 | :toctree: generated 211 | :nosignatures: 212 | :template: classtemplate.rst 213 | 214 | VibrationalSpectrum 215 | PowerSpectrum 216 | IRSpectrum 217 | RamanSpectrum 218 | 219 | 220 | ORCA output parsing 221 | ------------------- 222 | 223 | .. currentmodule:: md.parsers 224 | 225 | .. autosummary:: 226 | :toctree: generated 227 | :nosignatures: 228 | :template: classtemplate.rst 229 | 230 | OrcaParser 231 | OrcaOutputParser 232 | OrcaFormatter 233 | OrcaPropertyParser 234 | OrcaMainFileParser 235 | OrcaHessianFileParser 236 | 237 | 238 | MD utilities 239 | ------------ 240 | 241 | .. currentmodule:: md.utils 242 | 243 | .. autosummary:: 244 | :toctree: generated 245 | :nosignatures: 246 | :template: classtemplate.rst 247 | 248 | NormalModeTransformer 249 | 250 | Utilities for thermostats 251 | 252 | .. currentmodule:: md.utils.thermostat_utils 253 | 254 | .. autosummary:: 255 | :toctree: generated 256 | :nosignatures: 257 | :template: classtemplate.rst 258 | 259 | YSWeights 260 | GLEMatrixParser 261 | load_gle_matrices 262 | StableSinhDiv 263 | -------------------------------------------------------------------------------- /docs/api/model.rst: -------------------------------------------------------------------------------- 1 | schnetpack.model 2 | ================ 3 | .. currentmodule:: model 4 | 5 | .. autosummary:: 6 | :toctree: generated 7 | :nosignatures: 8 | :template: classtemplate.rst 9 | 10 | AtomisticModel 11 | NeuralNetworkPotential 12 | 13 | -------------------------------------------------------------------------------- /docs/api/nn.rst: -------------------------------------------------------------------------------- 1 | schnetpack.nn 2 | ============= 3 | .. currentmodule:: nn 4 | 5 | 6 | Basic layers 7 | ------------ 8 | 9 | .. autosummary:: 10 | :toctree: generated 11 | :nosignatures: 12 | :template: classtemplate.rst 13 | 14 | Dense 15 | 16 | 17 | Equivariant layers 18 | ------------------ 19 | 20 | Cartesian: 21 | 22 | .. 
autosummary:: 23 | :toctree: generated 24 | :nosignatures: 25 | :template: classtemplate.rst 26 | 27 | GatedEquivariantBlock 28 | 29 | Irreps: 30 | 31 | .. autosummary:: 32 | :toctree: generated 33 | :nosignatures: 34 | :template: classtemplate.rst 35 | 36 | RealSphericalHarmonics 37 | SO3TensorProduct 38 | SO3Convolution 39 | SO3GatedNonlinearity 40 | SO3ParametricGatedNonlinearity 41 | 42 | 43 | Radial basis 44 | ------------ 45 | 46 | .. autosummary:: 47 | :toctree: generated 48 | :nosignatures: 49 | :template: classtemplate.rst 50 | 51 | GaussianRBF 52 | GaussianRBFCentered 53 | BesselRBF 54 | 55 | 56 | Cutoff 57 | ------ 58 | 59 | .. autosummary:: 60 | :toctree: generated 61 | :nosignatures: 62 | :template: classtemplate.rst 63 | 64 | CosineCutoff 65 | MollifierCutoff 66 | 67 | 68 | Activations 69 | ----------- 70 | 71 | .. autosummary:: 72 | :toctree: generated 73 | :nosignatures: 74 | :template: classtemplate.rst 75 | 76 | shifted_softplus 77 | 78 | 79 | Ops 80 | --- 81 | 82 | .. autosummary:: 83 | :toctree: generated 84 | :nosignatures: 85 | 86 | scatter_add 87 | 88 | Factory functions 89 | ----------------- 90 | 91 | .. autosummary:: 92 | :toctree: generated 93 | :nosignatures: 94 | 95 | build_mlp 96 | build_gated_equivariant_mlp 97 | replicate_module 98 | -------------------------------------------------------------------------------- /docs/api/representation.rst: -------------------------------------------------------------------------------- 1 | schnetpack.representation 2 | ========================= 3 | .. currentmodule:: representation 4 | 5 | 6 | .. rubric:: Message-passing neural networks 7 | 8 | .. 
autosummary:: 9 | :toctree: generated 10 | :nosignatures: 11 | :template: classtemplate.rst 12 | 13 | SchNet 14 | PaiNN 15 | -------------------------------------------------------------------------------- /docs/api/schnetpack.rst: -------------------------------------------------------------------------------- 1 | schnetpack 2 | ========== 3 | 4 | Structure attributes 5 | -------------------- 6 | .. autosummary:: 7 | :toctree: generated 8 | :nosignatures: 9 | 10 | properties.Z 11 | properties.R 12 | properties.cell 13 | properties.pbc 14 | properties.idx_m 15 | properties.idx_i 16 | properties.idx_j 17 | properties.Rij 18 | properties.n_atoms 19 | 20 | Units 21 | ----- 22 | .. autosummary:: 23 | :toctree: generated 24 | :nosignatures: 25 | 26 | units.convert_units 27 | -------------------------------------------------------------------------------- /docs/api/task.rst: -------------------------------------------------------------------------------- 1 | schnetpack.task 2 | =============== 3 | .. currentmodule:: task 4 | 5 | .. autosummary:: 6 | :toctree: generated 7 | :nosignatures: 8 | :template: classtemplate.rst 9 | 10 | AtomisticTask 11 | ModelOutput 12 | UnsupervisedModelOutput 13 | 14 | -------------------------------------------------------------------------------- /docs/api/train.rst: -------------------------------------------------------------------------------- 1 | schnetpack.train 2 | ================ 3 | .. currentmodule:: train 4 | 5 | 6 | Callbacks 7 | --------- 8 | 9 | .. autosummary:: 10 | :toctree: generated 11 | :nosignatures: 12 | :template: classtemplate.rst 13 | 14 | ModelCheckpoint 15 | PredictionWriter 16 | 17 | Scheduler 18 | --------- 19 | 20 | .. 
autosummary:: 21 | :toctree: generated 22 | :nosignatures: 23 | :template: classtemplate.rst 24 | 25 | ReduceLROnPlateau 26 | 27 | -------------------------------------------------------------------------------- /docs/api/transform.rst: -------------------------------------------------------------------------------- 1 | schnetpack.transform 2 | ==================== 3 | .. automodule:: transform 4 | 5 | .. currentmodule:: transform 6 | .. autoclass:: Transform 7 | 8 | Atomistic 9 | --------- 10 | 11 | .. autosummary:: 12 | :toctree: generated 13 | :nosignatures: 14 | :template: classtemplate.rst 15 | 16 | AddOffsets 17 | RemoveOffsets 18 | SubtractCenterOfMass 19 | SubtractCenterOfGeometry 20 | 21 | Casting 22 | ------- 23 | 24 | .. autosummary:: 25 | :toctree: generated 26 | :nosignatures: 27 | :template: classtemplate.rst 28 | 29 | CastMap 30 | CastTo32 31 | CastTo64 32 | 33 | Neighbor lists 34 | -------------- 35 | 36 | .. autosummary:: 37 | :toctree: generated 38 | :nosignatures: 39 | :template: classtemplate.rst 40 | 41 | MatScipyNeighborList 42 | ASENeighborList 43 | TorchNeighborList 44 | CachedNeighborList 45 | CountNeighbors 46 | FilterNeighbors 47 | WrapPositions 48 | CollectAtomTriples 49 | -------------------------------------------------------------------------------- /docs/howtos: -------------------------------------------------------------------------------- 1 | ../examples/howtos/ -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. SchNetPack documentation master file, created by 2 | sphinx-quickstart on Mon Jul 30 18:07:50 2018. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 
5 | 6 | SchNetPack documentation 7 | ======================== 8 | 9 | SchNetPack is a toolbox for the development and application of deep neural networks to the prediction of 10 | potential energy surfaces and other quantum-chemical properties of molecules and materials. It contains 11 | basic building blocks of atomistic neural networks, manages their training and provides simple access 12 | to common benchmark datasets. This allows for an easy implementation and evaluation of new models. 13 | 14 | Contents 15 | ======== 16 | 17 | .. toctree:: 18 | :glob: 19 | :caption: Get Started 20 | :maxdepth: 1 21 | 22 | getstarted 23 | 24 | .. toctree:: 25 | :glob: 26 | :caption: User guide 27 | :maxdepth: 1 28 | 29 | userguide/overview 30 | userguide/configs 31 | userguide/md 32 | 33 | .. toctree:: 34 | :glob: 35 | :caption: Tutorials 36 | :maxdepth: 1 37 | 38 | tutorials/tutorial_01_preparing_data 39 | tutorials/tutorial_02_qm9 40 | tutorials/tutorial_03_force_models 41 | tutorials/tutorial_04_molecular_dynamics 42 | tutorials/tutorial_05_materials 43 | 44 | .. toctree:: 45 | :glob: 46 | :caption: How-To 47 | :maxdepth: 1 48 | 49 | howtos/howto_batchwise_relaxations 50 | howtos/lammps 51 | 52 | .. 
toctree:: 53 | :glob: 54 | :caption: Reference 55 | :maxdepth: 1 56 | 57 | api/schnetpack 58 | api/atomistic 59 | api/data 60 | api/datasets 61 | api/task 62 | api/model 63 | api/representation 64 | api/nn 65 | api/train 66 | api/transform 67 | api/md 68 | -------------------------------------------------------------------------------- /docs/pictures/tensorboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atomistic-machine-learning/schnetpack/601c8edb6cdf35af0b2ebf27b4b9e39386b79fb9/docs/pictures/tensorboard.png -------------------------------------------------------------------------------- /docs/sphinx-requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | numpy 3 | ase>=3.21 4 | h5py 5 | pyyaml 6 | tqdm 7 | ipykernel 8 | 9 | nbsphinx 10 | sphinx==7.1.2 11 | sphinx_rtd_theme 12 | readthedocs-sphinx-search 13 | 14 | pytorch_lightning>=2.0.0 15 | hydra-core>=1.1.0 16 | hydra-colorlog>=1.1.0 17 | torchmetrics==1.0.1 18 | protobuf==3.20.2 19 | -------------------------------------------------------------------------------- /docs/tutorials: -------------------------------------------------------------------------------- 1 | ../examples/tutorials/ -------------------------------------------------------------------------------- /docs/userguide/overview.rst: -------------------------------------------------------------------------------- 1 | ======== 2 | Overview 3 | ======== 4 | .. _overview: 5 | 6 | SchNetPack is built so that it can be used from the command line and configured with 7 | config files as well as used as a Python library. 8 | In this section, we will explain the overall structure of SchNetPack, which will be 9 | helpful for both these use cases. 10 | SchNetPack is based on PyTorch and uses `PyTorchLightning `_ as a training framework. 11 | This heavily influences the structure described here. 
12 | Additionally, `Hydra `_ is used to configure SchNetPack for command-line usage, 13 | which will be described in the next chapter. 14 | 15 | Data 16 | ==== 17 | .. currentmodule:: data 18 | 19 | SchNetPack currently supports data sets stored in ASE format using 20 | :class:`ASEAtomsData`, but other formats can be added by implementing 21 | :class:`BaseAtomsData`. These classes are compatible with PyTorch dataloaders and 22 | provide an additional interface to store metadata, e.g. property units and 23 | single-atom reference values. 24 | 25 | An important component is the set of transforms that can be passed to the data classes. Those 26 | are PyTorch modules that perform preprocessing tasks on the data *before* batching. 27 | Typically, this is performed on the CPU as part of the multi-processing of PyTorch 28 | dataloaders. 29 | Important preprocessing :class:`Transform`\ s include the removal of offsets from target properties 30 | and the calculation of neighbor lists. 31 | 32 | Furthermore, we support PyTorch Lightning datamodules through :class:`AtomsDataModule`, 33 | which combines :class:`ASEAtomsData` with code for preparation, setup and partitioning 34 | into train/validation/test splits. We provide specific implementations of 35 | :class:`AtomsDataModule` for several benchmark datasets. 36 | 37 | 38 | Model 39 | ===== 40 | .. currentmodule:: model 41 | 42 | A core component of SchNetPack is the :class:`AtomisticModel`, which is the base 43 | class for all models implemented in SchNetPack. It is essentially a PyTorch module with 44 | some additional functionality specific to atomistic machine learning. 45 | 46 | The particular features and requirements are: 47 | 48 | * **Input dictionary:** 49 | To support a flexible interface, each model is supposed to take an input dictionary 50 | mapping strings to PyTorch tensors and return a modified dictionary as output.
51 | 52 | * **Automatic collection of required derivatives:** 53 | Each layer that requires derivatives w.r.t. some input should list them as strings 54 | in ``layer.required_derivatives = ["input_key"]``. The ``requires_grad`` flag of the input 55 | tensor is then set automatically. 56 | 57 | .. currentmodule:: transform 58 | * **Post-processing:** 59 | The atomistic model can take a list of non-trainable :class:`Transform` modules that are 60 | used to post-process the output dictionary. These are not applied during training. 61 | A common use case is energy values that have large offsets and require double 62 | precision. To still be able to run a single-precision model on GPU, one can subtract 63 | the offset from the reference data during a preprocessing stage and then add it 64 | to the model prediction in post-processing after casting to double. 65 | 66 | .. currentmodule:: model 67 | While :class:`AtomisticModel` is a fairly general class, the models provided in 68 | SchNetPack follow a structure defined in the subclass :class:`NeuralNetworkPotential`: 69 | 70 | #. **Input modules:** the input dictionary is sequentially passed to a list of PyTorch 71 | modules that return a modified dictionary 72 | 73 | #. **Representation:** the input dictionary is passed to a representation module that 74 | computes atomwise representations, e.g. SchNet or PaiNN. The representation is added 75 | to the dictionary 76 | 77 | #. **Output modules:** the dictionary is sequentially passed to a list of PyTorch 78 | modules that store the outputs in the dictionary 79 | 80 | Adhering to the structure of :class:`NeuralNetworkPotential` makes it easier to define 81 | config templates with Hydra and it is therefore recommended to subclass it wherever 82 | possible. 83 | 84 | Task 85 | ==== 86 | .. currentmodule:: task 87 | 88 | The :class:`AtomisticTask` ties the model, outputs, loss and optimizers together and defines 89 | how the neural network will be trained.
While the model is a vanilla PyTorch module, 90 | the task is a :class:`LightningModule` that can be directly passed to the 91 | PyTorch Lightning :class:`Trainer`. 92 | 93 | To define an :class:`AtomisticTask`, you need to provide: 94 | 95 | * a model as described above 96 | 97 | * a list of :class:`ModelOutput` which map output dictionary keys to target properties and assigns a loss function and other metrics 98 | 99 | * (optionally) optimizer and learning rate schedulers 100 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | In this directory, you find code examples and tutorials to demonstrate the functionality of SchNetPack 4 | 5 | ## Tutorials 6 | Jupyter notebooks demonstrating general concepts and workflows 7 | 8 | [Preparing and loading your data](tutorials/tutorial_01_preparing_data.ipynb) 9 | 10 | [Training a neural network on QM9 11 | ](tutorials/tutorial_02_qm9.ipynb) 12 | 13 | [Training a model on forces and energies](tutorials/tutorial_03_force_models.ipynb) 14 | 15 | [Molecular dynamics in SchNetPack](tutorials/tutorial_04_molecular_dynamics.ipynb) 16 | 17 | [Force fields for materials](tutorials/tutorial_05_materials.ipynb) 18 | 19 | 20 | ## How-To 21 | Short notebooks showing a particular use-case or functionality 22 | 23 | [Batch-wise Structure Relaxation 24 | ](howtos/howto_batchwise_relaxations.ipynb) 25 | -------------------------------------------------------------------------------- /interfaces/lammps/examples/aspirin/aspirin.data: -------------------------------------------------------------------------------- 1 | /home/niklas/phd/code/pair_nequip/test/aspirin.data (written by ASE) 2 | 3 | 21 atoms 4 | 3 atom types 5 | 0.0 10 xlo xhi 6 | 0.0 10 ylo yhi 7 | 0.0 10 zlo zhi 8 | 9 | 10 | Atoms 11 | 12 | 1 1 7.1344888199999996 4.0156389499999996 4.8047821099999997 13 | 2 1 5.7626438100000001 
5.9594139500000001 3.3200710999999998 14 | 3 1 7.6603448400000005 4.5920739499999996 3.6926960900000001 15 | 4 1 6.9103168200000002 5.3939659600000001 2.8529801400000001 16 | 5 1 1.9698097699999999 6.4954049600000001 5.7196621299999997 17 | 6 1 5.8494248400000002 4.4491289299999996 5.2843751000000001 18 | 7 1 5.2384468200000001 5.4735059399999999 4.5955781 19 | 8 3 5.8978958099999996 2.72356796 6.7300610499999998 20 | 9 3 2.6165478200000001 5.4177789399999998 3.5371431099999997 21 | 10 3 4.5237988199999997 4.4709129299999999 7.3392591500000002 22 | 11 1 5.3929918099999998 3.8097629500000001 6.5379821099999997 23 | 12 1 2.8770148799999999 5.9517599299999997 4.6022951000000001 24 | 13 3 4.19533384 6.2862459399999997 5.1105090999999998 25 | 14 2 4.5061968300000004 3.8132079800000001 8.0959742099999996 26 | 15 2 7.55473471 3.1975029699999999 5.3921310900000003 27 | 16 2 5.3306898199999999 6.8557109799999996 2.65473604 28 | 17 2 8.8037939099999996 4.5062809599999998 3.5437971399999997 29 | 18 2 7.2311408500000001 5.55718595 1.8758501999999999 30 | 19 2 2.2910697500000001 7.4846579999999996 5.9269280999999996 31 | 20 2 0.86951685000000012 6.4821670099999995 5.4312661000000002 32 | 21 2 2.12585187 6.0032089900000001 6.6994850599999998 33 | -------------------------------------------------------------------------------- /interfaces/lammps/examples/aspirin/aspirin_md.in: -------------------------------------------------------------------------------- 1 | units real 2 | atom_style atomic 3 | newton off 4 | thermo 1 5 | dump mydmp all atom 10 dump.lammpstrj 6 | boundary s s s 7 | read_data aspirin.data 8 | pair_style schnetpack 9 | pair_coeff * * deployed_model 6 1 8 10 | mass 1 12.0 11 | mass 2 1.0 12 | mass 3 16.0 13 | neighbor 1.0 bin 14 | neigh_modify delay 0 every 1 check no 15 | fix 1 all nve 16 | timestep 0.5 17 | compute atomicenergies all pe/atom 18 | compute totalatomicenergy all reduce sum c_atomicenergies 19 | thermo_style custom step time temp pe 
c_totalatomicenergy etotal press spcpu cpuremain 20 | run 5000 21 | print $(0.000001 * pe) file pe.dat 22 | print $(0.000001 * c_totalatomicenergy) file totalatomicenergy.dat 23 | write_dump all custom output.dump id type x y z fx fy fz c_atomicenergies modify format float %20.15g 24 | -------------------------------------------------------------------------------- /interfaces/lammps/examples/aspirin/best_model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atomistic-machine-learning/schnetpack/601c8edb6cdf35af0b2ebf27b4b9e39386b79fb9/interfaces/lammps/examples/aspirin/best_model -------------------------------------------------------------------------------- /interfaces/lammps/pair_schnetpack.h: -------------------------------------------------------------------------------- 1 | /* ---------------------------------------------------------------------- 2 | References: 3 | 4 | .. [#pair_nequip] https://github.com/mir-group/pair_nequip 5 | .. 
[#lammps] https://github.com/lammps/lammps 6 | 7 | ------------------------------------------------------------------------- */ 8 | 9 | #ifdef PAIR_CLASS 10 | 11 | PairStyle(schnetpack,PairSCHNETPACK) 12 | 13 | #else 14 | 15 | #ifndef LMP_PAIR_SCHNETPACK_H 16 | #define LMP_PAIR_SCHNETPACK_H 17 | 18 | #include "pair.h" 19 | 20 | #include 21 | 22 | namespace LAMMPS_NS { 23 | 24 | class PairSCHNETPACK : public Pair { 25 | public: 26 | PairSCHNETPACK(class LAMMPS *); 27 | virtual ~PairSCHNETPACK(); 28 | virtual void compute(int, int); 29 | void settings(int, char **); 30 | virtual void coeff(int, char **); 31 | virtual double init_one(int, int); 32 | virtual void init_style(); 33 | void allocate(); 34 | 35 | double cutoff; // presumably the interaction cutoff of the deployed model -- TODO confirm 36 | torch::jit::script::Module model; 37 | torch::Device device = torch::kCPU; 38 | 39 | protected: 40 | int * type_mapper; // presumably maps LAMMPS atom types to the atomic numbers given in pair_coeff -- TODO confirm 41 | int debug_mode = 0; 42 | 43 | }; 44 | 45 | } 46 | 47 | #endif 48 | #endif 49 | -------------------------------------------------------------------------------- /interfaces/lammps/patch_lammps.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # usage: patch_lammps.sh [-e] /path/to/lammps/ 3 | # 4 | # 5 | # References: 6 | # 7 | # .. [#pair_nequip] https://github.com/mir-group/pair_nequip 8 | 9 | 10 | SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) 11 | echo "$SCRIPT_DIR" 12 | 13 | 14 | do_e_mode=false 15 | 16 | while getopts "he" option; do 17 | case $option in 18 | e) 19 | do_e_mode=true;; 20 | h) # display Help 21 | echo "patch_lammps.sh [-e] /path/to/lammps/" 22 | exit;; 23 | esac 24 | done 25 | 26 | # https://stackoverflow.com/a/9472919 27 | shift $(($OPTIND - 1)) 28 | lammps_dir=$1 29 | 30 | if [ "$lammps_dir" = "" ]; 31 | then 32 | echo "lammps_dir must be provided" 33 | exit 1 34 | fi 35 | 36 | if [ ! -d "$lammps_dir" ] 37 | then 38 | echo "$lammps_dir doesn't exist" 39 | exit 1 40 | fi 41 | 42 | if [ !
-d "$lammps_dir/cmake" ] 43 | then 44 | echo "$lammps_dir doesn't look like a LAMMPS source directory" 45 | exit 1 46 | fi 47 | 48 | # Check if root directory is correct (the copy step below uses the *.cpp/*.h files of the current directory) 49 | if [ ! -f pair_schnetpack.cpp ]; then 50 | echo "Please run \`patch_lammps.sh\` from the \`pair_schnetpack.cpp\` root directory." 51 | exit 1 52 | fi 53 | 54 | echo "Updating CMakeLists.txt..." 55 | # Check for double-patch 56 | if grep -q "find_package(Torch REQUIRED)" "$lammps_dir/cmake/CMakeLists.txt" 57 | then 58 | echo "This LAMMPS installation _seems_ to already have been patched. CMakeLists.txt file not modified." 59 | else 60 | # Update CMakeLists.txt 61 | sed -i "s/set(CMAKE_CXX_STANDARD 11)/set(CMAKE_CXX_STANDARD 14)/" "$lammps_dir/cmake/CMakeLists.txt" 62 | 63 | # Add libtorch 64 | cat >> "$lammps_dir/cmake/CMakeLists.txt" << "EOF2" 65 | 66 | find_package(Torch REQUIRED) 67 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") 68 | target_link_libraries(lammps PUBLIC "${TORCH_LIBRARIES}") 69 | EOF2 70 | 71 | fi 72 | 73 | # check if files need to be copied to lammps directory 74 | if [ ! -f "$lammps_dir/src/pair_schnetpack.cpp" ]; then 75 | if [ "$do_e_mode" = true ] 76 | then 77 | echo "Making source symlinks (-e)..." 78 | for file in *.{cpp,h}; do 79 | ln -s "$(realpath -s "$file")" "$lammps_dir/src/$file" 80 | done 81 | else 82 | echo "Copying files..." 83 | for file in *.{cpp,h}; do 84 | cp "$file" "$lammps_dir/src/$file" 85 | done 86 | fi 87 | else 88 | echo "pair_schnetpack.cpp file already exists. No files copied." 89 | fi 90 | 91 | 92 | echo "Done!" 93 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "schnetpack" 7 | # Declaring that the version will be read dynamically. 
8 | # This helps us not having to specify the version twice.
9 | # We have to specify the version only in schnetpack.__version__ 10 | # see [tool.setuptools.dynamic] below 11 | dynamic = ["version"] 12 | authors = [ 13 | { name = "Kristof T. Schuett" }, 14 | { name = "Michael Gastegger" }, 15 | { name = "Stefaan Hessmann" }, 16 | { name = "Niklas Gebauer" }, 17 | { name = "Jonas Lederer" } 18 | ] 19 | description = "SchNetPack - Deep Neural Networks for Atomistic Systems" 20 | readme = "README.md" 21 | license = { file="LICENSE" } 22 | requires-python = ">=3.12,<3.13" 23 | dependencies = [ 24 | "numpy>=2.0.0", 25 | "sympy>=1.13", 26 | "ase>=3.21", 27 | "h5py", 28 | "pyyaml", 29 | "hydra-core>=1.1.0", 30 | "torch>=2.5.0", 31 | "pytorch_lightning>=2.0.0", 32 | "torchmetrics", 33 | "hydra-colorlog>=1.1.0", 34 | "rich", 35 | "fasteners", 36 | "dirsync", 37 | "torch-ema", 38 | "matscipy>=1.1.0", 39 | "tensorboard>=2.17.1", 40 | "tensorboardX>=2.6.2.2", 41 | "tqdm", 42 | "pre-commit", 43 | "black", 44 | "protobuf", 45 | "progressbar" 46 | ] 47 | 48 | [project.optional-dependencies] 49 | test = ["pytest", "pytest-datadir", "pytest-benchmark"] 50 | 51 | [tool.setuptools] 52 | package-dir = { "" = "src" } 53 | script-files = [ 54 | "src/scripts/spkconvert", 55 | "src/scripts/spktrain", 56 | "src/scripts/spkpredict", 57 | "src/scripts/spkmd", 58 | "src/scripts/spkdeploy", 59 | ] 60 | 61 | [tool.setuptools.dynamic] 62 | version = {attr = "schnetpack.__version__"} 63 | 64 | [tool.setuptools.packages.find] 65 | where = ["src"] 66 | 67 | [tool.setuptools.package-data] 68 | schnetpack = ["configs/**/*.yaml"] 69 | -------------------------------------------------------------------------------- /readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the version of Python and other tools you might need 9 | 
build: 10 | os: "ubuntu-22.04" 11 | tools: 12 | python: "3.12" 13 | 14 | # Build documentation in the docs/ directory with Sphinx 15 | sphinx: 16 | configuration: docs/conf.py 17 | 18 | # Explicitly set the version of Python and its requirements 19 | python: 20 | install: 21 | - requirements: docs/sphinx-requirements.txt 22 | - path: . 23 | -------------------------------------------------------------------------------- /src/schnetpack/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.filterwarnings("ignore", category=DeprecationWarning, module="tensorboard") 4 | 5 | from schnetpack import transform 6 | from schnetpack import properties 7 | from schnetpack import data 8 | from schnetpack import datasets 9 | from schnetpack import atomistic 10 | from schnetpack import representation 11 | from schnetpack import interfaces 12 | from schnetpack import nn 13 | from schnetpack import train 14 | from schnetpack import model 15 | from schnetpack.units import * 16 | from schnetpack.task import * 17 | from schnetpack import md 18 | 19 | 20 | __version__ = "2.1.1" 21 | -------------------------------------------------------------------------------- /src/schnetpack/atomistic/__init__.py: -------------------------------------------------------------------------------- 1 | from .atomwise import * 2 | from .response import * 3 | from .distances import * 4 | from .nuclear_repulsion import * 5 | from .electrostatic import * 6 | from .aggregation import * 7 | from .external_fields import * 8 | -------------------------------------------------------------------------------- /src/schnetpack/atomistic/aggregation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from typing import Dict, List 5 | 6 | __all__ = ["Aggregation"] 7 | 8 | 9 | class Aggregation(nn.Module): 10 | """ 11 | Aggregate predictions into a single output
variable. 12 | 13 | Args: 14 | keys (list(str)): List of properties to be added. 15 | output_key (str): Name of new property in output. 16 | """ 17 | 18 | def __init__(self, keys: List[str], output_key: str = "y"): 19 | super(Aggregation, self).__init__() 20 | 21 | self.keys: List[str] = list(keys) 22 | self.output_key = output_key 23 | self.model_outputs = [output_key] # presumably read by the enclosing model to collect output keys -- TODO confirm 24 | 25 | def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 26 | energy = torch.stack([inputs[key] for key in self.keys]).sum(0) # element-wise sum of all listed properties (torch.stack requires equal shapes) 27 | inputs[self.output_key] = energy 28 | return inputs 29 | -------------------------------------------------------------------------------- /src/schnetpack/atomistic/distances.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | import schnetpack.properties as properties 7 | 8 | 9 | class PairwiseDistances(nn.Module): 10 | """ 11 | Compute pair-wise distances from indices provided by a neighbor list transform. 12 | """ 13 | 14 | def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 15 | R = inputs[properties.R] 16 | offsets = inputs[properties.offsets] 17 | idx_i = inputs[properties.idx_i] 18 | idx_j = inputs[properties.idx_j] 19 | 20 | # Cast indices to int64 before indexing (avoids an error on Windows OS) 21 | idx_i = idx_i.long() 22 | idx_j = idx_j.long() 23 | 24 | Rij = R[idx_j] - R[idx_i] + offsets # vector from atom i to atom j; offsets presumably hold periodic-image shifts -- TODO confirm 25 | inputs[properties.Rij] = Rij 26 | return inputs 27 | 28 | 29 | class FilterShortRange(nn.Module): 30 | """ 31 | Separate short-range from all supplied distances. 
32 | 33 | The short-range distances will be stored under the original keys (properties.Rij, 34 | properties.idx_i, properties.idx_j), while the original distances can be accessed for long-range terms via 35 | (properties.Rij_lr, properties.idx_i_lr, properties.idx_j_lr).
36 | """ 37 | 38 | def __init__(self, short_range_cutoff: float): 39 | super().__init__() 40 | self.short_range_cutoff = short_range_cutoff 41 | 42 | def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 43 | idx_i = inputs[properties.idx_i] 44 | idx_j = inputs[properties.idx_j] 45 | Rij = inputs[properties.Rij] 46 | 47 | rij = torch.norm(Rij, dim=-1) # length of each pair vector 48 | cidx = torch.nonzero(rij <= self.short_range_cutoff).squeeze(-1) # indices of pairs within the short-range cutoff (inclusive) 49 | # keep the complete pair list under the *_lr keys for long-range terms 50 | inputs[properties.Rij_lr] = Rij 51 | inputs[properties.idx_i_lr] = idx_i 52 | inputs[properties.idx_j_lr] = idx_j 53 | # overwrite the original keys with the short-range subset 54 | inputs[properties.Rij] = Rij[cidx] 55 | inputs[properties.idx_i] = idx_i[cidx] 56 | inputs[properties.idx_j] = idx_j[cidx] 57 | return inputs 58 | -------------------------------------------------------------------------------- /src/schnetpack/atomistic/external_fields.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, List 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | import schnetpack.properties as properties 7 | from schnetpack.utils import required_fields_from_properties 8 | 9 | __all__ = ["StaticExternalFields"] 10 | 11 | 12 | class StaticExternalFields(nn.Module): 13 | """ 14 | Input routine for setting up dummy external fields in response models. 15 | Checks if fields are present in input and sets dummy fields otherwise. 16 | 17 | Args: 18 | external_fields (list(str)): List of required external fields. Either this or the requested response 19 | properties needs to be specified. 20 | response_properties (list(str)): List of requested response properties. If this is not None, it is used to 21 | determine the required external fields.
22 | """ 23 | 24 | def __init__( 25 | self, 26 | external_fields: List[str] = [], 27 | response_properties: Optional[List[str]] = None, 28 | ): 29 | super(StaticExternalFields, self).__init__() 30 | 31 | if response_properties is not None: 32 | external_fields = required_fields_from_properties(response_properties) 33 | 34 | self.external_fields: List[str] = list(set(external_fields)) 35 | 36 | def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 37 | n_atoms = inputs[properties.n_atoms] 38 | n_molecules = n_atoms.shape[0] 39 | 40 | # Fields passed to interaction computation (cast to batch structure) 41 | for field in self.external_fields: 42 | # Store all fields in directory which will be returned for derivatives 43 | if field not in inputs: 44 | inputs[field] = torch.zeros( 45 | n_molecules, 46 | 3, 47 | device=n_atoms.device, 48 | dtype=inputs[properties.R].dtype, 49 | requires_grad=True, 50 | ) 51 | 52 | # Initialize nuclear magnetic moments for magnetic fields 53 | if properties.magnetic_field in self.external_fields: 54 | if properties.nuclear_magnetic_moments not in inputs: 55 | inputs[properties.nuclear_magnetic_moments] = torch.zeros_like( 56 | inputs[properties.R], requires_grad=True 57 | ) 58 | 59 | return inputs 60 | -------------------------------------------------------------------------------- /src/schnetpack/atomistic/nuclear_repulsion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from typing import Union, Callable, Dict, Optional 5 | 6 | import schnetpack.properties as properties 7 | import schnetpack.nn as snn 8 | import schnetpack.units as spk_units 9 | 10 | __all__ = ["ZBLRepulsionEnergy"] 11 | 12 | 13 | class ZBLRepulsionEnergy(nn.Module): 14 | """ 15 | Computes a Ziegler-Biersack-Littmark style repulsion energy 16 | 17 | Args: 18 | energy_unit (str/float): Energy unit. 
19 | position_unit (str/float): Unit used for distances. 20 | output_key (str): Key to which results will be stored 21 | trainable (bool): If set to true, ZBL parameters will be optimized during training (default=True) 22 | cutoff_fn (Callable): Apply a cutoff function to the interatomic distances. 23 | 24 | References: 25 | .. [#Cutoff] Ebert, D. S.; Musgrave, F. K.; Peachey, D.; Perlin, K.; Worley, S. 26 | Texturing & Modeling: A Procedural Approach; 27 | Morgan Kaufmann, 2003 28 | .. [#ZBL] 29 | https://docs.lammps.org/pair_zbl.html 30 | """ 31 | 32 | def __init__( 33 | self, 34 | energy_unit: Union[str, float], 35 | position_unit: Union[str, float], 36 | output_key: str, 37 | trainable: bool = True, 38 | cutoff_fn: Optional[Callable] = None, 39 | ): 40 | super(ZBLRepulsionEnergy, self).__init__() 41 | 42 | energy_units = spk_units.convert_units("Ha", energy_unit) 43 | position_units = spk_units.convert_units("Bohr", position_unit) 44 | ke = energy_units * position_units 45 | self.register_buffer("ke", torch.tensor(ke)) 46 | 47 | self.cutoff_fn = cutoff_fn 48 | self.output_key = output_key 49 | 50 | # Basic ZBL parameters (in atomic units) 51 | # Since all quantities have a predefined sign, they are initialized to the inverse softplus and a softplus 52 | # function is applied in the forward pass to guarantee the correct sign during training 53 | a_div = snn.softplus_inverse( 54 | torch.tensor([1.0 / (position_units * 0.8854)]) 55 | ) # in this way, distances can be used directly 56 | a_pow = snn.softplus_inverse(torch.tensor([0.23])) 57 | exponents = snn.softplus_inverse( 58 | torch.tensor([3.19980, 0.94229, 0.40290, 0.20162]) 59 | ) 60 | coefficients = snn.softplus_inverse( 61 | torch.tensor([0.18175, 0.50986, 0.28022, 0.02817]) 62 | ) 63 | 64 | # Initialize network parameters 65 | self.a_pow = nn.Parameter(a_pow, requires_grad=trainable) 66 | self.a_div = nn.Parameter(a_div, requires_grad=trainable) 67 | self.coefficients = nn.Parameter(coefficients, 
requires_grad=trainable) 68 | self.exponents = nn.Parameter(exponents, requires_grad=trainable) 69 | 70 | def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 71 | z = inputs[properties.Z] 72 | r_ij = inputs[properties.Rij] 73 | d_ij = torch.norm(r_ij, dim=1) 74 | idx_i = inputs[properties.idx_i] 75 | idx_j = inputs[properties.idx_j] 76 | idx_m = inputs[properties.idx_m] 77 | 78 | n_atoms = z.shape[0] 79 | n_molecules = int(idx_m[-1]) + 1 80 | 81 | # Construct screening function 82 | a = z ** F.softplus(self.a_pow) 83 | a_ij = (a[idx_i] + a[idx_j]) * F.softplus(self.a_div) 84 | # Get exponents and coefficients, normalize the latter 85 | exponents = a_ij[..., None] * F.softplus(self.exponents)[None, ...] 86 | coefficients = F.softplus(self.coefficients)[None, ...] 87 | coefficients = F.normalize(coefficients, p=1.0, dim=1) 88 | 89 | screening = torch.sum( 90 | coefficients * torch.exp(-exponents * d_ij[:, None]), dim=1 91 | ) 92 | 93 | # Compute nuclear repulsion 94 | repulsion = (z[idx_i] * z[idx_j]) / d_ij 95 | 96 | # Apply cutoff if requested 97 | if self.cutoff_fn is not None: 98 | f_cut = self.cutoff_fn(d_ij) 99 | repulsion = repulsion * f_cut 100 | 101 | # Compute ZBL energy 102 | y_zbl = snn.scatter_add(repulsion * screening, idx_i, dim_size=n_atoms) 103 | y_zbl = snn.scatter_add(y_zbl, idx_m, dim_size=n_molecules) 104 | y_zbl = 0.5 * self.ke * y_zbl 105 | 106 | inputs[self.output_key] = y_zbl 107 | 108 | return inputs 109 | -------------------------------------------------------------------------------- /src/schnetpack/configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atomistic-machine-learning/schnetpack/601c8edb6cdf35af0b2ebf27b4b9e39386b79fb9/src/schnetpack/configs/__init__.py -------------------------------------------------------------------------------- /src/schnetpack/configs/callbacks/checkpoint.yaml: 
-------------------------------------------------------------------------------- 1 | model_checkpoint: 2 | _target_: schnetpack.train.ModelCheckpoint 3 | monitor: "val_loss" # name of the logged metric which determines when model is improving 4 | save_top_k: 1 # save k best models (determined by above metric) 5 | save_last: True # additionaly always save model from last epoch 6 | mode: "min" # can be "max" or "min" 7 | verbose: False 8 | dirpath: 'checkpoints/' 9 | filename: '{epoch:02d}' 10 | model_path: ${globals.model_path} -------------------------------------------------------------------------------- /src/schnetpack/configs/callbacks/earlystopping.yaml: -------------------------------------------------------------------------------- 1 | early_stopping: 2 | _target_: pytorch_lightning.callbacks.EarlyStopping 3 | monitor: "val_loss" # name of the logged metric which determines when model is improving 4 | patience: 200 # how many epochs of not improving until training stops 5 | mode: "min" # can be "max" or "min" 6 | min_delta: 0.0 # minimum change in the monitored metric needed to qualify as an improvement 7 | check_on_train_epoch_end: False -------------------------------------------------------------------------------- /src/schnetpack/configs/callbacks/ema.yaml: -------------------------------------------------------------------------------- 1 | ema: 2 | _target_: schnetpack.train.ExponentialMovingAverage 3 | decay: 0.995 -------------------------------------------------------------------------------- /src/schnetpack/configs/callbacks/lrmonitor.yaml: -------------------------------------------------------------------------------- 1 | lr_monitor: 2 | _target_: pytorch_lightning.callbacks.LearningRateMonitor 3 | logging_interval: epoch -------------------------------------------------------------------------------- /src/schnetpack/configs/data/ani1.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - custom 
3 | 4 | _target_: schnetpack.datasets.ANI1 5 | 6 | datapath: ${run.data_dir}/ani1.db # data_dir is specified in train.yaml 7 | batch_size: 32 8 | num_train: 10000000 9 | num_val: 100000 10 | num_heavy_atoms: 8 11 | high_energies: False 12 | 13 | # convert to typically used units 14 | distance_unit: Ang 15 | property_units: 16 | energy: eV -------------------------------------------------------------------------------- /src/schnetpack/configs/data/custom.yaml: -------------------------------------------------------------------------------- 1 | _target_: schnetpack.data.AtomsDataModule 2 | 3 | datapath: ??? 4 | data_workdir: null 5 | batch_size: 10 6 | num_train: ??? 7 | num_val: ??? 8 | num_test: null 9 | num_workers: 8 10 | num_val_workers: null 11 | num_test_workers: null 12 | train_sampler_cls: null -------------------------------------------------------------------------------- /src/schnetpack/configs/data/iso17.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - custom 3 | 4 | _target_: schnetpack.datasets.ISO17 5 | 6 | datapath: ${run.data_dir}/${data.folder}.db # data_dir is specified in train.yaml 7 | folder: reference 8 | batch_size: 32 9 | num_train: 0.9 10 | num_val: 0.1 11 | -------------------------------------------------------------------------------- /src/schnetpack/configs/data/materials_project.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - custom 3 | 4 | _target_: schnetpack.datasets.MaterialsProject 5 | 6 | datapath: ${run.data_dir}/materials_project.db # data_dir is specified in train.yaml 7 | batch_size: 32 8 | num_train: 60000 9 | num_val: 2000 10 | apikey: ??? 
-------------------------------------------------------------------------------- /src/schnetpack/configs/data/md17.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - custom 3 | 4 | _target_: schnetpack.datasets.MD17 5 | 6 | datapath: ${run.data_dir}/${data.molecule}.db # data_dir is specified in train.yaml 7 | molecule: aspirin 8 | batch_size: 10 9 | num_train: 950 10 | num_val: 50 11 | -------------------------------------------------------------------------------- /src/schnetpack/configs/data/md22.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - custom 3 | 4 | _target_: schnetpack.datasets.MD22 5 | 6 | datapath: ${run.data_dir}/${data.molecule}.db # data_dir is specified in train.yaml 7 | molecule: Ac-Ala3-NHMe 8 | batch_size: 10 9 | num_train: 5700 10 | num_val: 300 11 | -------------------------------------------------------------------------------- /src/schnetpack/configs/data/omdb.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - custom 3 | 4 | _target_: schnetpack.datasets.OrganicMaterialsDatabase 5 | 6 | datapath: ${run.data_dir}/omdb.db # data_dir is specified in train.yaml 7 | batch_size: 32 8 | num_train: 0.8 9 | num_val: 0.1 10 | raw_path: null -------------------------------------------------------------------------------- /src/schnetpack/configs/data/qm7x.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - custom 3 | 4 | _target_: schnetpack.datasets.QM7X 5 | 6 | datapath: ${run.data_dir}/qm7x.db # data_dir is specified in train.yaml 7 | batch_size: 100 8 | num_train: 5550 9 | num_val: 700 -------------------------------------------------------------------------------- /src/schnetpack/configs/data/qm9.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - custom 3 | 4 | 
_target_: schnetpack.datasets.QM9 5 | 6 | datapath: ${run.data_dir}/qm9.db # data_dir is specified in train.yaml 7 | batch_size: 100 8 | num_train: 110000 9 | num_val: 10000 10 | remove_uncharacterized: True 11 | 12 | # convert to typically used units 13 | distance_unit: Ang 14 | property_units: 15 | energy_U0: eV 16 | energy_U: eV 17 | enthalpy_H: eV 18 | free_energy: eV 19 | homo: eV 20 | lumo: eV 21 | gap: eV 22 | zpve: eV -------------------------------------------------------------------------------- /src/schnetpack/configs/data/rmd17.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - custom 3 | 4 | _target_: schnetpack.datasets.rMD17 5 | 6 | datapath: ${run.data_dir}/rmd17_${data.molecule}.db # data_dir is specified in train.yaml 7 | molecule: aspirin 8 | batch_size: 10 9 | num_train: 950 10 | num_val: 50 11 | split_id: null -------------------------------------------------------------------------------- /src/schnetpack/configs/data/sampler/stratified_property.yaml: -------------------------------------------------------------------------------- 1 | # @package data 2 | train_sampler_cls: schnetpack.data.sampler.StratifiedSampler 3 | train_sampler_args: 4 | partition_criterion: 5 | _target_: schnetpack.data.sampler.PropertyCriterion 6 | property_key: ${globals.property} 7 | num_bins: 10 8 | replacement: True -------------------------------------------------------------------------------- /src/schnetpack/configs/experiment/md17.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /model: nnp 5 | - override /data: md17 6 | 7 | run: 8 | experiment: md17_${data.molecule} 9 | 10 | globals: 11 | cutoff: 5. 
12 | lr: 1e-3 13 | energy_key: energy 14 | forces_key: forces 15 | 16 | data: 17 | distance_unit: Ang 18 | property_units: 19 | energy: kcal/mol 20 | forces: kcal/mol/Ang 21 | transforms: 22 | - _target_: schnetpack.transform.SubtractCenterOfMass 23 | - _target_: schnetpack.transform.RemoveOffsets 24 | property: energy 25 | remove_mean: True 26 | - _target_: schnetpack.transform.MatScipyNeighborList 27 | cutoff: ${globals.cutoff} 28 | - _target_: schnetpack.transform.CastTo32 29 | 30 | model: 31 | output_modules: 32 | - _target_: schnetpack.atomistic.Atomwise 33 | output_key: ${globals.energy_key} 34 | n_in: ${model.representation.n_atom_basis} 35 | aggregation_mode: sum 36 | - _target_: schnetpack.atomistic.Forces 37 | energy_key: ${globals.energy_key} 38 | force_key: ${globals.forces_key} 39 | postprocessors: 40 | - _target_: schnetpack.transform.CastTo64 41 | - _target_: schnetpack.transform.AddOffsets 42 | property: energy 43 | add_mean: True 44 | 45 | task: 46 | outputs: 47 | - _target_: schnetpack.task.ModelOutput 48 | name: ${globals.energy_key} 49 | loss_fn: 50 | _target_: torch.nn.MSELoss 51 | metrics: 52 | mae: 53 | _target_: torchmetrics.regression.MeanAbsoluteError 54 | rmse: 55 | _target_: torchmetrics.regression.MeanSquaredError 56 | squared: False 57 | loss_weight: 0.01 58 | - _target_: schnetpack.task.ModelOutput 59 | name: ${globals.forces_key} 60 | loss_fn: 61 | _target_: torch.nn.MSELoss 62 | metrics: 63 | mae: 64 | _target_: torchmetrics.regression.MeanAbsoluteError 65 | rmse: 66 | _target_: torchmetrics.regression.MeanSquaredError 67 | squared: False 68 | loss_weight: 0.99 -------------------------------------------------------------------------------- /src/schnetpack/configs/experiment/qm9_atomwise.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /model: nnp 5 | - override /data: qm9 6 | 7 | run: 8 | experiment: qm9_${globals.property} 9 | 10 | 
globals: 11 | cutoff: 5. 12 | lr: 5e-4 13 | property: energy_U0 14 | aggregation: sum 15 | 16 | data: 17 | transforms: 18 | - _target_: schnetpack.transform.SubtractCenterOfMass 19 | - _target_: schnetpack.transform.RemoveOffsets 20 | property: ${globals.property} 21 | remove_atomrefs: True 22 | remove_mean: True 23 | - _target_: schnetpack.transform.MatScipyNeighborList 24 | cutoff: ${globals.cutoff} 25 | - _target_: schnetpack.transform.CastTo32 26 | 27 | model: 28 | output_modules: 29 | - _target_: schnetpack.atomistic.Atomwise 30 | output_key: ${globals.property} 31 | n_in: ${model.representation.n_atom_basis} 32 | aggregation_mode: ${globals.aggregation} 33 | postprocessors: 34 | - _target_: schnetpack.transform.CastTo64 35 | - _target_: schnetpack.transform.AddOffsets 36 | property: ${globals.property} 37 | add_mean: True 38 | add_atomrefs: True 39 | 40 | task: 41 | outputs: 42 | - _target_: schnetpack.task.ModelOutput 43 | name: ${globals.property} 44 | loss_fn: 45 | _target_: torch.nn.MSELoss 46 | metrics: 47 | mae: 48 | _target_: torchmetrics.regression.MeanAbsoluteError 49 | rmse: 50 | _target_: torchmetrics.regression.MeanSquaredError 51 | squared: False 52 | loss_weight: 1. -------------------------------------------------------------------------------- /src/schnetpack/configs/experiment/qm9_dipole.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /model: nnp 5 | - override /data: qm9 6 | 7 | run: 8 | experiment: qm9_${globals.property} 9 | 10 | globals: 11 | cutoff: 5. 
12 | lr: 5e-4 13 | property: dipole_moment 14 | 15 | data: 16 | transforms: 17 | - _target_: schnetpack.transform.SubtractCenterOfMass 18 | - _target_: schnetpack.transform.MatScipyNeighborList 19 | cutoff: ${globals.cutoff} 20 | - _target_: schnetpack.transform.CastTo32 21 | 22 | model: 23 | output_modules: 24 | - _target_: schnetpack.atomistic.DipoleMoment 25 | dipole_key: ${globals.property} 26 | n_in: ${model.representation.n_atom_basis} 27 | predict_magnitude: True 28 | use_vector_representation: False 29 | postprocessors: 30 | - _target_: schnetpack.transform.CastTo64 31 | 32 | task: 33 | outputs: 34 | - _target_: schnetpack.task.ModelOutput 35 | name: ${globals.property} 36 | loss_fn: 37 | _target_: torch.nn.MSELoss 38 | metrics: 39 | mae: 40 | _target_: torchmetrics.regression.MeanAbsoluteError 41 | rmse: 42 | _target_: torchmetrics.regression.MeanSquaredError 43 | squared: False 44 | loss_weight: 1. -------------------------------------------------------------------------------- /src/schnetpack/configs/experiment/response.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /model: nnp 5 | - override /model/representation: field_schnet 6 | - override /data: custom 7 | 8 | run: 9 | experiment: response 10 | 11 | globals: 12 | cutoff: 9.448 13 | lr: 1e-4 14 | energy_key: energy 15 | forces_key: forces 16 | shielding_key: shielding 17 | shielding_elements: [ 1,6,8 ] 18 | response_properties: 19 | - forces 20 | - dipole_moment 21 | - polarizability 22 | - ${globals.shielding_key} 23 | 24 | data: 25 | distance_unit: 1.0 26 | batch_size: 10 27 | transforms: 28 | - _target_: schnetpack.transform.SubtractCenterOfMass 29 | - _target_: schnetpack.transform.RemoveOffsets 30 | property: ${globals.energy_key} 31 | remove_mean: true 32 | - _target_: schnetpack.transform.MatScipyNeighborList 33 | cutoff: ${globals.cutoff} 34 | - _target_: schnetpack.transform.CastTo32 35 | - 
_target_: schnetpack.transform.SplitShielding 36 | shielding_key: ${globals.shielding_key} 37 | atomic_numbers: ${globals.shielding_elements} 38 | 39 | model: 40 | input_modules: 41 | - _target_: schnetpack.atomistic.PairwiseDistances 42 | - _target_: schnetpack.atomistic.StaticExternalFields 43 | response_properties: ${globals.response_properties} 44 | output_modules: 45 | - _target_: schnetpack.atomistic.Atomwise 46 | output_key: ${globals.energy_key} 47 | n_in: ${model.representation.n_atom_basis} 48 | aggregation_mode: sum 49 | - _target_: schnetpack.transform.ScaleProperty 50 | input_key: ${globals.energy_key} 51 | output_key: ${globals.energy_key} 52 | - _target_: schnetpack.atomistic.Response 53 | energy_key: ${globals.energy_key} 54 | response_properties: ${globals.response_properties} 55 | - _target_: schnetpack.transform.SplitShielding 56 | shielding_key: ${globals.shielding_key} 57 | atomic_numbers: ${globals.shielding_elements} 58 | postprocessors: 59 | - _target_: schnetpack.transform.CastTo64 60 | - _target_: schnetpack.transform.AddOffsets 61 | property: energy 62 | add_mean: True 63 | 64 | task: 65 | scheduler_args: 66 | mode: min 67 | factor: 0.5 68 | patience: 50 69 | min_lr: 1e-6 70 | smoothing_factor: 0.0 71 | outputs: 72 | - _target_: schnetpack.task.ModelOutput 73 | name: ${globals.energy_key} 74 | loss_fn: 75 | _target_: torch.nn.MSELoss 76 | metrics: 77 | mae: 78 | _target_: torchmetrics.regression.MeanAbsoluteError 79 | rmse: 80 | _target_: torchmetrics.regression.MeanSquaredError 81 | squared: false 82 | loss_weight: 1.00 83 | - _target_: schnetpack.task.ModelOutput 84 | name: forces 85 | loss_fn: 86 | _target_: torch.nn.MSELoss 87 | metrics: 88 | mae: 89 | _target_: torchmetrics.regression.MeanAbsoluteError 90 | rmse: 91 | _target_: torchmetrics.regression.MeanSquaredError 92 | squared: false 93 | loss_weight: 5.0 94 | - _target_: schnetpack.task.ModelOutput 95 | name: dipole_moment 96 | loss_fn: 97 | _target_: torch.nn.MSELoss 98 | 
metrics: 99 | mae: 100 | _target_: torchmetrics.regression.MeanAbsoluteError 101 | rmse: 102 | _target_: torchmetrics.regression.MeanSquaredError 103 | squared: false 104 | loss_weight: 0.01 105 | - _target_: schnetpack.task.ModelOutput 106 | name: polarizability 107 | loss_fn: 108 | _target_: torch.nn.MSELoss 109 | metrics: 110 | mae: 111 | _target_: torchmetrics.regression.MeanAbsoluteError 112 | rmse: 113 | _target_: torchmetrics.regression.MeanSquaredError 114 | squared: false 115 | loss_weight: 0.01 116 | # shielding split by element 117 | - _target_: schnetpack.task.ModelOutput 118 | name: ${globals.shielding_key}_1 119 | loss_fn: 120 | _target_: torch.nn.MSELoss 121 | metrics: 122 | mae: 123 | _target_: torchmetrics.regression.MeanAbsoluteError 124 | rmse: 125 | _target_: torchmetrics.regression.MeanSquaredError 126 | squared: false 127 | mae_iso: 128 | _target_: schnetpack.train.metrics.TensorDiagonalMeanAbsoluteError 129 | mae_aniso: 130 | _target_: schnetpack.train.metrics.TensorDiagonalMeanAbsoluteError 131 | diagonal: false 132 | loss_weight: 0.1 133 | - _target_: schnetpack.task.ModelOutput 134 | name: ${globals.shielding_key}_6 135 | loss_fn: 136 | _target_: torch.nn.MSELoss 137 | metrics: 138 | mae: 139 | _target_: torchmetrics.regression.MeanAbsoluteError 140 | rmse: 141 | _target_: torchmetrics.regression.MeanSquaredError 142 | squared: false 143 | mae_iso: 144 | _target_: schnetpack.train.metrics.TensorDiagonalMeanAbsoluteError 145 | mae_aniso: 146 | _target_: schnetpack.train.metrics.TensorDiagonalMeanAbsoluteError 147 | diagonal: false 148 | loss_weight: 0.004 149 | - _target_: schnetpack.task.ModelOutput 150 | name: ${globals.shielding_key}_8 151 | loss_fn: 152 | _target_: torch.nn.MSELoss 153 | metrics: 154 | mae: 155 | _target_: torchmetrics.regression.MeanAbsoluteError 156 | rmse: 157 | _target_: torchmetrics.regression.MeanSquaredError 158 | squared: false 159 | mae_iso: 160 | _target_: 
schnetpack.train.metrics.TensorDiagonalMeanAbsoluteError 161 | mae_aniso: 162 | _target_: schnetpack.train.metrics.TensorDiagonalMeanAbsoluteError 163 | diagonal: false 164 | loss_weight: 0.001 165 | 166 | 167 | -------------------------------------------------------------------------------- /src/schnetpack/configs/experiment/rmd17.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /model: nnp 5 | - override /data: rmd17 6 | 7 | run: 8 | experiment: md17_${data.molecule} 9 | 10 | globals: 11 | cutoff: 5. 12 | lr: 1e-3 13 | energy_key: energy 14 | forces_key: forces 15 | 16 | data: 17 | distance_unit: Ang 18 | property_units: 19 | energy: kcal/mol 20 | forces: kcal/mol/Ang 21 | transforms: 22 | - _target_: schnetpack.transform.SubtractCenterOfMass 23 | - _target_: schnetpack.transform.RemoveOffsets 24 | property: energy 25 | remove_mean: True 26 | - _target_: schnetpack.transform.MatScipyNeighborList 27 | cutoff: ${globals.cutoff} 28 | - _target_: schnetpack.transform.CastTo32 29 | 30 | model: 31 | output_modules: 32 | - _target_: schnetpack.atomistic.Atomwise 33 | output_key: ${globals.energy_key} 34 | n_in: ${model.representation.n_atom_basis} 35 | aggregation_mode: sum 36 | - _target_: schnetpack.atomistic.Forces 37 | energy_key: ${globals.energy_key} 38 | force_key: ${globals.forces_key} 39 | postprocessors: 40 | - _target_: schnetpack.transform.CastTo64 41 | - _target_: schnetpack.transform.AddOffsets 42 | property: energy 43 | add_mean: True 44 | 45 | task: 46 | outputs: 47 | - _target_: schnetpack.task.ModelOutput 48 | name: ${globals.energy_key} 49 | loss_fn: 50 | _target_: torch.nn.MSELoss 51 | metrics: 52 | mae: 53 | _target_: torchmetrics.regression.MeanAbsoluteError 54 | rmse: 55 | _target_: torchmetrics.regression.MeanSquaredError 56 | squared: False 57 | loss_weight: 0.01 58 | - _target_: schnetpack.task.ModelOutput 59 | name: ${globals.forces_key} 60 | 
loss_fn: 61 | _target_: torch.nn.MSELoss 62 | metrics: 63 | mae: 64 | _target_: torchmetrics.regression.MeanAbsoluteError 65 | rmse: 66 | _target_: torchmetrics.regression.MeanSquaredError 67 | squared: False 68 | loss_weight: 0.99 69 | -------------------------------------------------------------------------------- /src/schnetpack/configs/globals/default_globals.yaml: -------------------------------------------------------------------------------- 1 | model_path: "best_model" -------------------------------------------------------------------------------- /src/schnetpack/configs/logger/aim.yaml: -------------------------------------------------------------------------------- 1 | # Aim Logger 2 | 3 | aim: 4 | _target_: aim.pytorch_lightning.AimLogger 5 | repo: ${hydra:runtime.cwd}/${run.path} 6 | experiment: ${run.experiment} 7 | -------------------------------------------------------------------------------- /src/schnetpack/configs/logger/csv.yaml: -------------------------------------------------------------------------------- 1 | # CSVLogger built in PyTorch Lightning 2 | 3 | csv: 4 | _target_: pytorch_lightning.loggers.csv_logs.CSVLogger 5 | save_dir: "." 
6 | name: "csv/" 7 | -------------------------------------------------------------------------------- /src/schnetpack/configs/logger/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | # TensorBoard 2 | 3 | tensorboard: 4 | _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger 5 | save_dir: "tensorboard/" 6 | name: "default" 7 | -------------------------------------------------------------------------------- /src/schnetpack/configs/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | wandb: 2 | _target_: pytorch_lightning.loggers.WandbLogger 3 | -------------------------------------------------------------------------------- /src/schnetpack/configs/model/nnp.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - representation: painn 3 | 4 | _target_: schnetpack.model.NeuralNetworkPotential 5 | 6 | input_modules: 7 | - _target_: schnetpack.atomistic.PairwiseDistances 8 | output_modules: ??? 
9 | -------------------------------------------------------------------------------- /src/schnetpack/configs/model/representation/field_schnet.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - radial_basis: gaussian 3 | 4 | _target_: schnetpack.representation.FieldSchNet 5 | n_atom_basis: 128 6 | n_interactions: 5 7 | external_fields: [] 8 | response_properties: ${globals.response_properties} 9 | shared_interactions: False 10 | cutoff_fn: 11 | _target_: schnetpack.nn.cutoff.CosineCutoff 12 | cutoff: ${globals.cutoff} -------------------------------------------------------------------------------- /src/schnetpack/configs/model/representation/painn.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - radial_basis: gaussian 3 | 4 | _target_: schnetpack.representation.PaiNN 5 | n_atom_basis: 128 6 | n_interactions: 3 7 | shared_interactions: False 8 | shared_filters: False 9 | cutoff_fn: 10 | _target_: schnetpack.nn.cutoff.CosineCutoff 11 | cutoff: ${globals.cutoff} -------------------------------------------------------------------------------- /src/schnetpack/configs/model/representation/radial_basis/bessel.yaml: -------------------------------------------------------------------------------- 1 | _target_: schnetpack.nn.radial.BesselRBF 2 | n_rbf: 20 3 | cutoff: ${globals.cutoff} -------------------------------------------------------------------------------- /src/schnetpack/configs/model/representation/radial_basis/gaussian.yaml: -------------------------------------------------------------------------------- 1 | _target_: schnetpack.nn.radial.GaussianRBF 2 | n_rbf: 20 3 | cutoff: ${globals.cutoff} -------------------------------------------------------------------------------- /src/schnetpack/configs/model/representation/schnet.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - radial_basis: gaussian 
3 | 4 | _target_: schnetpack.representation.SchNet 5 | n_atom_basis: 128 6 | n_interactions: 6 7 | cutoff_fn: 8 | _target_: schnetpack.nn.cutoff.CosineCutoff 9 | cutoff: ${globals.cutoff} -------------------------------------------------------------------------------- /src/schnetpack/configs/model/representation/so3net.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - radial_basis: gaussian 3 | 4 | _target_: schnetpack.representation.SO3net 5 | n_atom_basis: 128 6 | n_interactions: 3 7 | lmax: 2 8 | shared_interactions: False 9 | return_vector_representation: False 10 | cutoff_fn: 11 | _target_: schnetpack.nn.cutoff.CosineCutoff 12 | cutoff: ${globals.cutoff} -------------------------------------------------------------------------------- /src/schnetpack/configs/predict.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - trainer: default_trainer 5 | 6 | datapath: ??? 7 | modeldir: ??? 8 | outputdir: ??? 9 | cutoff: ??? 
10 | ckpt_path: null 11 | batch_size: 100 12 | write_interval: epoch 13 | enable_grad: false 14 | write_idx_m: false 15 | 16 | data: 17 | _target_: schnetpack.data.ASEAtomsData 18 | datapath: ${datapath} 19 | transforms: 20 | - _target_: schnetpack.transform.SubtractCenterOfMass 21 | - _target_: schnetpack.transform.MatScipyNeighborList 22 | cutoff: ${cutoff} 23 | - _target_: schnetpack.transform.CastTo32 24 | 25 | 26 | # hydra configuration 27 | hydra: 28 | job: 29 | chdir: True 30 | 31 | # output paths for hydra logs 32 | run: 33 | dir: ${modeldir} 34 | 35 | # disable hydra config storage, since handled manually 36 | output_subdir: null 37 | 38 | help: 39 | app_name: SchNetPack Predict 40 | 41 | template: |- 42 | SchNetPack 43 | 44 | == Configuration groups == 45 | Compose your configuration from those groups (db=mysql) 46 | 47 | $APP_CONFIG_GROUPS 48 | 49 | == Config == 50 | This is the config generated for this run. 51 | You can change the config file to be loaded to a predefined one 52 | > spktrain --config-name=train_qm9 53 | 54 | or your own: 55 | > spktrain --config-dir=./my_configs --config-name=my_config 56 | 57 | You can override everything, for example: 58 | > spktrain --config-name=train_qm9 data_dir=/path/to/datadir data.batch_size=50 --help 59 | 60 | ------- 61 | $CONFIG 62 | ------- 63 | 64 | ${hydra.help.footer} 65 | -------------------------------------------------------------------------------- /src/schnetpack/configs/run/default_run.yaml: -------------------------------------------------------------------------------- 1 | work_dir: ${hydra:runtime.cwd} 2 | data_dir: ${run.work_dir}/data 3 | path: ${run.work_dir}/runs 4 | experiment: default 5 | id: ${uuid:1} 6 | ckpt_path: null 7 | -------------------------------------------------------------------------------- /src/schnetpack/configs/task/default_task.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - optimizer: adam 4 | - 
scheduler: reduce_on_plateau 5 | 6 | _target_: schnetpack.AtomisticTask 7 | outputs: ??? 8 | warmup_steps: 0 9 | -------------------------------------------------------------------------------- /src/schnetpack/configs/task/optimizer/adabelief.yaml: -------------------------------------------------------------------------------- 1 | # @package task 2 | optimizer_cls: adabelief_pytorch.AdaBelief 3 | optimizer_args: 4 | lr: ${globals.lr} 5 | weight_decay: 0.0 -------------------------------------------------------------------------------- /src/schnetpack/configs/task/optimizer/adam.yaml: -------------------------------------------------------------------------------- 1 | # @package task 2 | optimizer_cls: torch.optim.AdamW 3 | optimizer_args: 4 | lr: ${globals.lr} 5 | weight_decay: 0.0 6 | -------------------------------------------------------------------------------- /src/schnetpack/configs/task/optimizer/sgd.yaml: -------------------------------------------------------------------------------- 1 | # @package task 2 | optimizer_cls: torch.optim.SGD 3 | optimizer_args: 4 | lr: ${globals.lr} 5 | weight_decay: 0.0 6 | momentum: 0.0 7 | nesterov: False 8 | dampening: 0.0 -------------------------------------------------------------------------------- /src/schnetpack/configs/task/scheduler/reduce_on_plateau.yaml: -------------------------------------------------------------------------------- 1 | # @package task 2 | scheduler_cls: schnetpack.train.ReduceLROnPlateau 3 | scheduler_monitor: val_loss 4 | scheduler_args: 5 | mode: min 6 | factor: 0.5 7 | patience: 75 8 | threshold: 0.0 9 | threshold_mode: rel 10 | cooldown: 10 11 | min_lr: 0.0 12 | smoothing_factor: 0.0 -------------------------------------------------------------------------------- /src/schnetpack/configs/train.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | - run: default_run 6 | - globals: default_globals 7 | 
- trainer: default_trainer 8 | - callbacks: 9 | - checkpoint 10 | - earlystopping 11 | - lrmonitor 12 | - ema 13 | - task: default_task 14 | - model: null 15 | - data: custom 16 | - logger: tensorboard 17 | - experiment: null 18 | 19 | print_config: True 20 | 21 | # hydra configuration 22 | hydra: 23 | job: 24 | chdir: True 25 | # output paths for hydra logs 26 | run: 27 | dir: ${run.path}/${run.id} 28 | 29 | searchpath: 30 | - file://${oc.env:PWD} 31 | - file://${oc.env:PWD}/configs 32 | 33 | # disable hydra config storage, since handled manually 34 | output_subdir: null 35 | 36 | help: 37 | app_name: SchNetPack Train 38 | 39 | template: |- 40 | SchNetPack 41 | 42 | == Configuration groups == 43 | Compose your configuration from those groups (e.g., model=nnp data=qm9) 44 | 45 | $APP_CONFIG_GROUPS 46 | 47 | == Config == 48 | This is the config generated for this run. 49 | 50 | ------- 51 | $CONFIG 52 | ------- 53 | 54 | You can override the config file with a pre-defined experiment config 55 | > spktrain experiment=qm9_atomwise 56 | 57 | or your own experiment config, which needs to be located in a directory called `experiment` in the config search 58 | path, e.g., 59 | > spktrain --config-dir=./my_configs experiment=my_experiment 60 | 61 | with your experiment config located at `./my_configs/experiment/my_experiment.yaml`. 62 | Your current working directory as well as an optional config subdirectory are automatically in the config 63 | search path. Therefore, you can put your experiment config either in `./experiment`, 64 | or `./configs/experiment`. 
65 | 66 | You can also override everything with the command line, for example: 67 | > spktrain experiment=qm9_atomwise run.data_dir=/path/to/datadir data.batch_size=50 68 | 69 | ${hydra.help.footer} 70 | -------------------------------------------------------------------------------- /src/schnetpack/configs/trainer/ddp_debug.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - debug_trainer 3 | 4 | devices: 2 5 | accelerator: cpu 6 | strategy: ddp 7 | num_nodes: 1 8 | -------------------------------------------------------------------------------- /src/schnetpack/configs/trainer/ddp_trainer.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default_trainer 3 | 4 | devices: 2 5 | accelerator: auto 6 | strategy: ddp 7 | num_nodes: 1 8 | -------------------------------------------------------------------------------- /src/schnetpack/configs/trainer/debug_trainer.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default_trainer 3 | 4 | devices: 1 # replaces the `gpus` Trainer argument (deprecated, removed in Lightning 2.0); accelerator is inherited from default_trainer 5 | min_epochs: 1 6 | max_epochs: 3 7 | 8 | # prints 9 | detect_anomaly: True 10 | profiler: simple 11 | 12 | -------------------------------------------------------------------------------- /src/schnetpack/configs/trainer/default_trainer.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | devices: 1 4 | 5 | min_epochs: null 6 | max_epochs: 100000 7 | 8 | # prints 9 | enable_model_summary: True 10 | profiler: null 11 | 12 | gradient_clip_val: null 13 | accumulate_grad_batches: 1 14 | val_check_interval: 1.0 15 | check_val_every_n_epoch: 1 16 | 17 | num_sanity_val_steps: 0 18 | fast_dev_run: False 19 | 
overfit_batches: 0 20 | limit_train_batches: 1.0 21 | limit_val_batches: 1.0 22 | limit_test_batches: 1.0 23 | detect_anomaly: False 24 | 25 | precision: 32 26 | accelerator: auto 27 | num_nodes: 1 
28 | deterministic: False 29 | inference_mode: False 30 | -------------------------------------------------------------------------------- /src/schnetpack/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .atoms import * 2 | from .loader import * 3 | from .stats import * 4 | from .splitting import * 5 | from .datamodule import * 6 | from .sampler import * 7 | -------------------------------------------------------------------------------- /src/schnetpack/data/loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | 4 | from typing import Optional, Sequence 5 | from torch.utils.data import Dataset, Sampler 6 | from torch.utils.data.dataloader import _collate_fn_t, _T_co 7 | 8 | import schnetpack.properties as structure 9 | 10 | __all__ = ["AtomsLoader"] 11 | 12 | 13 | def _atoms_collate_fn(batch): 14 | """ 15 | Build a mini-batch by concatenating every property of all systems along dim 0 (no padding). idx_i / idx_j / idx_i_triples are offset by the cumulative atom count, idx_j_triples / idx_k_triples by the cumulative number of idx_j entries, and unshifted copies of idx_i / idx_j / idx_i_triples are kept under key + "_local". 16 | 17 | Args: 18 | batch (list[dict[str, torch.Tensor]]): one property dict per system; every dict must contain each key present in batch[0]. 19 | 20 | Returns: 21 | dict[str->torch.Tensor]: mini-batch of atomistic systems, including the per-atom system index under structure.idx_m 22 | """ 23 | elem = batch[0] 24 | idx_keys = {structure.idx_i, structure.idx_j, structure.idx_i_triples} 25 | # Atom triple indices must be treated separately 26 | idx_triple_keys = {structure.idx_j_triples, structure.idx_k_triples} 27 | 28 | coll_batch = {} 29 | for key in elem: 30 | if (key not in idx_keys) and (key not in idx_triple_keys): 31 | coll_batch[key] = torch.cat([d[key] for d in batch], 0) 32 | elif key in idx_keys: 33 | coll_batch[key + "_local"] = 
torch.cat([d[key] for d in batch], 0) 34 | 35 | seg_m = torch.cumsum(coll_batch[structure.n_atoms], dim=0) 36 | seg_m = torch.cat([torch.zeros((1,), dtype=seg_m.dtype), seg_m], dim=0) 37 | idx_m = torch.repeat_interleave( 38 | torch.arange(len(batch)), repeats=coll_batch[structure.n_atoms], dim=0 39 | ) 40 | coll_batch[structure.idx_m] = idx_m 41 | 42 | for key in idx_keys: 
43 | if key in elem.keys(): 44 | coll_batch[key] = torch.cat( 45 | [d[key] + off for d, off in zip(batch, seg_m)], 0 46 | ) 47 | 48 | # Shift the indices for the atom triples 49 | for key in idx_triple_keys: 50 | if key in elem.keys(): 51 | indices = [] 52 | offset = 0 53 | for idx, d in enumerate(batch): 54 | indices.append(d[key] + offset) 55 | offset += d[structure.idx_j].shape[0] 56 | coll_batch[key] = torch.cat(indices, 0) 57 | 58 | return coll_batch 59 | 60 | 61 | class AtomsLoader(DataLoader): 62 | """Data loader for subclasses of BaseAtomsData""" 63 | 64 | def __init__( 65 | self, 66 | dataset: Dataset[_T_co], 67 | batch_size: Optional[int] = 1, 68 | shuffle: bool = False, 69 | sampler: Optional[Sampler[int]] = None, 70 | batch_sampler: Optional[Sampler[Sequence[int]]] = None, 71 | num_workers: int = 0, 72 | collate_fn: _collate_fn_t = _atoms_collate_fn, 73 | pin_memory: bool = False, 74 | **kwargs, 75 | ): 76 | super(AtomsLoader, self).__init__( 77 | dataset=dataset, 78 | batch_size=batch_size, 79 | shuffle=shuffle, 80 | sampler=sampler, 81 | batch_sampler=batch_sampler, 82 | num_workers=num_workers, 83 | collate_fn=collate_fn, 84 | pin_memory=pin_memory, 85 | **kwargs, 86 | ) 87 | -------------------------------------------------------------------------------- /src/schnetpack/data/sampler.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator, List, Callable 2 | 3 | import numpy as np 4 | from torch.utils.data import Sampler, WeightedRandomSampler 5 | 6 | from schnetpack import properties 7 | from schnetpack.data import BaseAtomsData 8 | 9 | 10 | __all__ = [ 11 | "StratifiedSampler", 12 | "NumberOfAtomsCriterion", 13 | "PropertyCriterion", 14 | ] 15 | 16 | 17 | class NumberOfAtomsCriterion: 18 | """ 19 | A callable class that returns the number of atoms for each sample in the dataset. 
20 | """ 21 | 22 | def __call__(self, dataset): 23 | n_atoms = [] 24 | for spl_idx in range(len(dataset)): 25 | sample = dataset[spl_idx] 26 | n_atoms.append(sample[properties.n_atoms].item()) 27 | return n_atoms 28 | 29 | 30 | class PropertyCriterion: 31 | """ 32 | A callable class that returns the specified property for each sample in the dataset. 33 | Property must be a scalar value (each value is read with .item(), which fails for tensors with more than one element). 34 | """ 35 | 36 | def __init__(self, property_key: str = properties.energy): 37 | self.property_key = property_key 38 | 39 | def __call__(self, dataset): 40 | property_values = [] 41 | for spl_idx in range(len(dataset)): 42 | sample = dataset[spl_idx] 43 | property_values.append(sample[self.property_key].item()) 44 | return property_values 45 | 46 | 47 | class StratifiedSampler(WeightedRandomSampler): 48 | """ 49 | A custom sampler that performs stratified sampling based on a partition criterion. 50 | 51 | Note: Make sure that num_bins is chosen sufficiently small to avoid too many empty bins (samples can never fall into an empty bin, and each non-empty bin receives the same total weight). 52 | """ 53 | 54 | def __init__( 55 | self, 56 | data_source: BaseAtomsData, 57 | partition_criterion: Callable[[BaseAtomsData], List], 58 | num_samples: int, 59 | num_bins: int = 10, 60 | replacement: bool = True, 61 | verbose: bool = True, 62 | ) -> None: 63 | """ 64 | Args: 65 | data_source: The data source to be sampled from. 66 | partition_criterion: A callable that takes the data source and returns one value per sample 67 | (e.g. NumberOfAtomsCriterion or PropertyCriterion); these values are histogrammed into num_bins bins. 68 | num_samples: The total number of samples to be drawn from the data source. 69 | num_bins: The number of bins to divide the partitioned values into. Defaults to 10. 70 | replacement: Whether to sample with replacement or without replacement. Defaults to True. 71 | verbose: Stored as self.verbose; not used by the weighting or sampling logic of this class. Defaults to True.
72 | """ 73 | self.data_source = data_source 74 | self.num_bins = num_bins 75 | self.verbose = verbose 76 | 77 | weights = self.calculate_weights(partition_criterion) 78 | super().__init__( 79 | weights=weights, num_samples=num_samples, replacement=replacement 80 | ) 81 | 82 | def calculate_weights(self, partition_criterion): 83 | """ 84 | Calculates the weights for each sample based on the partition criterion. 85 | """ 86 | feature_values = partition_criterion(self.data_source) 87 | 88 | bin_counts, bin_edges = np.histogram(feature_values, bins=self.num_bins) 89 | bin_edges = bin_edges[1:] 90 | bin_edges[-1] += 0.1 91 | bin_indices = np.digitize(feature_values, bin_edges) 92 | 93 | min_counts = min(bin_counts[bin_counts != 0]) 94 | bin_weights = np.where(bin_counts == 0, 0, min_counts / bin_counts) 95 | weights = bin_weights[bin_indices] 96 | 97 | return weights 98 | -------------------------------------------------------------------------------- /src/schnetpack/data/stats.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple 2 | 3 | import torch 4 | from tqdm import tqdm 5 | 6 | import schnetpack.properties as properties 7 | from schnetpack.data import AtomsLoader 8 | 9 | __all__ = ["calculate_stats", "estimate_atomrefs"] 10 | 11 | 12 | def calculate_stats( 13 | dataloader: AtomsLoader, 14 | divide_by_atoms: Dict[str, bool], 15 | atomref: Dict[str, torch.Tensor] = None, 16 | ) -> Dict[str, Tuple[torch.Tensor, torch.Tensor]]: 17 | """ 18 | Use the incremental Welford algorithm described in [h1]_ to accumulate 19 | the mean and standard deviation over a set of samples. 20 | 21 | References: 22 | ----------- 23 | .. [h1] https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance 24 | 25 | Args: 26 | dataloader: data loader 27 | divide_by_atoms: dict from property name to bool: 28 | If True, divide property by number of atoms before calculating statistics. 
29 | atomref: reference values for single atoms to be removed before calculating stats 30 | 31 | Returns: 32 | Mean and standard deviation over all samples 33 | 34 | """ 35 | property_names = list(divide_by_atoms.keys()) 36 | norm_mask = torch.tensor( 37 | [float(divide_by_atoms[p]) for p in property_names], dtype=torch.float64 38 | ) 39 | 40 | count = 0 41 | mean = torch.zeros_like(norm_mask) 42 | M2 = torch.zeros_like(norm_mask) 43 | 44 | for props in tqdm(dataloader, "calculating statistics"): 45 | sample_values = [] 46 | for p in property_names: 47 | val = props[p][None, :] 48 | if atomref and p in atomref.keys(): 49 | ar = atomref[p] 50 | ar = ar[props[properties.Z]] 51 | idx_m = props[properties.idx_m] 52 | tmp = torch.zeros((idx_m[-1] + 1,), dtype=ar.dtype, device=ar.device) 53 | v0 = tmp.index_add(0, idx_m, ar) 54 | val -= v0 55 | 56 | sample_values.append(val) 57 | sample_values = torch.cat(sample_values, dim=0) 58 | 59 | batch_size = sample_values.shape[1] 60 | new_count = count + batch_size 61 | 62 | norm = norm_mask[:, None] * props[properties.n_atoms][None, :] + ( 63 | 1 - norm_mask[:, None] 64 | ) 65 | sample_values /= norm 66 | 67 | sample_mean = torch.mean(sample_values, dim=1) 68 | sample_m2 = torch.sum((sample_values - sample_mean[:, None]) ** 2, dim=1) 69 | 70 | delta = sample_mean - mean 71 | mean += delta * batch_size / new_count 72 | corr = batch_size * count / new_count 73 | M2 += sample_m2 + delta**2 * corr 74 | count = new_count 75 | 76 | stddev = torch.sqrt(M2 / count) 77 | stats = {pn: (mu, std) for pn, mu, std in zip(property_names, mean, stddev)} 78 | return stats 79 | 80 | 81 | def estimate_atomrefs(dataloader, is_extensive, z_max=100): 82 | """ 83 | Uses linear regression to estimate the elementwise biases (atomrefs). 84 | 85 | Args: 86 | dataloader: data loader 87 | is_extensive: If True, divide atom type counts by number of atoms before 88 | calculating statistics. 
89 | 90 | Returns: 91 | Elementwise bias estimates over all samples 92 | 93 | """ 94 | property_names = list(is_extensive.keys()) 95 | n_data = len(dataloader.dataset) 96 | all_properties = {pname: torch.zeros(n_data) for pname in property_names} 97 | all_atom_types = torch.zeros((n_data, z_max)) 98 | data_counter = 0 99 | 100 | # loop over all batches 101 | for batch in tqdm(dataloader, "estimating atomrefs"): 102 | # load data 103 | idx_m = batch[properties.idx_m] 104 | atomic_numbers = batch[properties.Z] 105 | 106 | # get counts for atomic numbers 107 | unique_ids = torch.unique(idx_m) 108 | for i in unique_ids: 109 | atomic_numbers_i = atomic_numbers[idx_m == i] 110 | atom_types, atom_counts = torch.unique(atomic_numbers_i, return_counts=True) 111 | # save atom counts and properties 112 | for atom_type, atom_count in zip(atom_types, atom_counts): 113 | all_atom_types[data_counter, atom_type] = atom_count 114 | for pname in property_names: 115 | property_value = batch[pname][i] 116 | if not is_extensive[pname]: 117 | property_value *= batch[properties.n_atoms][i] 118 | all_properties[pname][data_counter] = property_value 119 | data_counter += 1 120 | 121 | # perform linear regression to get the elementwise energy contributions 122 | existing_atom_types = torch.where(all_atom_types.sum(axis=0) != 0)[0] 123 | X = torch.squeeze(all_atom_types[:, existing_atom_types]) 124 | w = dict() 125 | for pname in property_names: 126 | if is_extensive[pname]: 127 | w[pname] = torch.linalg.inv(X.T @ X) @ X.T @ all_properties[pname] 128 | else: 129 | w[pname] = ( 130 | torch.linalg.inv(X.T @ X) 131 | @ X.T 132 | @ (all_properties[pname] / X.sum(axis=1)) 133 | ) 134 | 135 | # compute energy estimates 136 | elementwise_contributions = { 137 | pname: torch.zeros((z_max)) for pname in property_names 138 | } 139 | for pname in property_names: 140 | for atom_type, weight in zip(existing_atom_types, w[pname]): 141 | elementwise_contributions[pname][atom_type] = weight 142 | 143 | 
return elementwise_contributions 144 | -------------------------------------------------------------------------------- /src/schnetpack/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .qm9 import * 2 | from .md17 import * 3 | from .md22 import * 4 | from .rmd17 import * 5 | from .iso17 import * 6 | from .ani1 import * 7 | from .materials_project import * 8 | from .omdb import * 9 | from .tmqm import * 10 | from .qm7x import * 11 | -------------------------------------------------------------------------------- /src/schnetpack/datasets/md22.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional, Dict, List 3 | 4 | from schnetpack.data import * 5 | from schnetpack.datasets.md17 import GDMLDataModule 6 | 7 | 8 | all = ["MD22"] 9 | 10 | 11 | class MD22(GDMLDataModule): 12 | """ 13 | MD22 benchmark data set for extended molecules containing molecular forces. 14 | 15 | References: 16 | .. 
[#md22_1] http://quantum-machine.org/gdml/#datasets 17 | 18 | """ 19 | 20 | def __init__( 21 | self, 22 | datapath: str, 23 | molecule: str, 24 | batch_size: int, 25 | num_train: Optional[int] = None, 26 | num_val: Optional[int] = None, 27 | num_test: Optional[int] = None, 28 | split_file: Optional[str] = "split.npz", 29 | format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE, 30 | load_properties: Optional[List[str]] = None, 31 | val_batch_size: Optional[int] = None, 32 | test_batch_size: Optional[int] = None, 33 | transforms: Optional[List[torch.nn.Module]] = None, 34 | train_transforms: Optional[List[torch.nn.Module]] = None, 35 | val_transforms: Optional[List[torch.nn.Module]] = None, 36 | test_transforms: Optional[List[torch.nn.Module]] = None, 37 | num_workers: int = 2, 38 | num_val_workers: Optional[int] = None, 39 | num_test_workers: Optional[int] = None, 40 | property_units: Optional[Dict[str, str]] = None, 41 | distance_unit: Optional[str] = None, 42 | data_workdir: Optional[str] = None, 43 | **kwargs, 44 | ): 45 | """ 46 | Args: 47 | datapath: path to dataset 48 | batch_size: (train) batch size 49 | num_train: number of training examples 50 | num_val: number of validation examples 51 | num_test: number of test examples 52 | split_file: path to npz file with data partitions 53 | format: dataset format 54 | load_properties: subset of properties to load 55 | val_batch_size: validation batch size. If None, use test_batch_size, then batch_size. 56 | test_batch_size: test batch size. If None, use val_batch_size, then batch_size. 57 | transforms: Transform applied to each system separately before batching. 58 | train_transforms: Overrides transform_fn for training. 59 | val_transforms: Overrides transform_fn for validation. 60 | test_transforms: Overrides transform_fn for testing. 61 | num_workers: Number of data loader workers. 62 | num_val_workers: Number of validation data loader workers (overrides num_workers). 
63 | num_test_workers: Number of test data loader workers (overrides num_workers). 64 | distance_unit: Unit of the atom positions and cell as a string (Ang, Bohr, ...). 65 | data_workdir: Copy data here as part of setup, e.g. cluster scratch for faster performance. 66 | """ 67 | atomrefs = { 68 | self.energy: [ 69 | 0.0, 70 | -313.5150902000774, 71 | 0.0, 72 | 0.0, 73 | 0.0, 74 | 0.0, 75 | -23622.587180094913, 76 | -34219.46811826416, 77 | -47069.30768969713, 78 | ] 79 | } 80 | datasets_dict = { 81 | "Ac-Ala3-NHMe": "md22_Ac-Ala3-NHMe.npz", 82 | "DHA": "md22_DHA.npz", 83 | "stachyose": "md22_stachyose.npz", 84 | "AT-AT": "md22_AT-AT.npz", 85 | "AT-AT-CG-CG": "md22_AT-AT-CG-CG.npz", 86 | "buckyball-catcher": "md22_buckyball-catcher.npz", 87 | "double-walled_nanotube": "md22_double-walled_nanotube.npz", 88 | } 89 | 90 | super(MD22, self).__init__( 91 | datasets_dict=datasets_dict, 92 | download_url="http://www.quantum-machine.org/gdml/repo/datasets/", 93 | tmpdir="md22", 94 | molecule=molecule, 95 | datapath=datapath, 96 | batch_size=batch_size, 97 | num_train=num_train, 98 | num_val=num_val, 99 | num_test=num_test, 100 | split_file=split_file, 101 | format=format, 102 | load_properties=load_properties, 103 | val_batch_size=val_batch_size, 104 | test_batch_size=test_batch_size, 105 | transforms=transforms, 106 | train_transforms=train_transforms, 107 | val_transforms=val_transforms, 108 | test_transforms=test_transforms, 109 | num_workers=num_workers, 110 | num_val_workers=num_val_workers, 111 | num_test_workers=num_test_workers, 112 | property_units=property_units, 113 | distance_unit=distance_unit, 114 | data_workdir=data_workdir, 115 | atomrefs=atomrefs, 116 | **kwargs, 117 | ) 118 | -------------------------------------------------------------------------------- /src/schnetpack/interfaces/__init__.py: -------------------------------------------------------------------------------- 1 | from .ase_interface import * 2 | from .batchwise_optimization import * 3 | 
-------------------------------------------------------------------------------- /src/schnetpack/md/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains all functionality for performing various molecular dynamics simulations 3 | using SchNetPack. 4 | """ 5 | 6 | from .system import * 7 | from .initial_conditions import * 8 | from .simulator import * 9 | from . import integrators 10 | from . import simulation_hooks 11 | from . import calculators 12 | from . import neighborlist_md 13 | from . import utils 14 | from . import data 15 | -------------------------------------------------------------------------------- /src/schnetpack/md/calculators/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Calculator modules for molecular dynamics simulations. 3 | """ 4 | 5 | from .base_calculator import * 6 | from .schnetpack_calculator import * 7 | from .lj_calculator import * 8 | from .ensemble_calculator import * 9 | from .orca_calculator import * 10 | -------------------------------------------------------------------------------- /src/schnetpack/md/calculators/ensemble_calculator.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import torch 3 | 4 | from abc import ABC 5 | from typing import TYPE_CHECKING, List, Dict, Optional 6 | from schnetpack.md.calculators import MDCalculator 7 | 8 | if TYPE_CHECKING: 9 | from schnetpack.md import System 10 | 11 | __all__ = ["EnsembleCalculator"] 12 | 13 | 14 | class EnsembleCalculator(ABC, MDCalculator): 15 | """ 16 | Mixin for creating ensemble calculators from the standard `schnetpack.md.calculators` classes. Accumulates 17 | property predictions as the average over all models and uncertainties as the variance of model predictions. 
18 | """ 19 | 20 | def calculate(self, system: System): 21 | """ 22 | Perform all calculations and compyte properties and uncertainties. 23 | 24 | Args: 25 | system (schnetpack.md.System): System from the molecular dynamics simulation. 26 | """ 27 | inputs = self._generate_input(system) 28 | 29 | results = [] 30 | for model in self.model: 31 | prediction = model(inputs) 32 | results.append(prediction) 33 | 34 | # Compute statistics 35 | self.results = self._accumulate_results(results) 36 | self._update_system(system) 37 | 38 | @staticmethod 39 | def _accumulate_results( 40 | results: List[Dict[str, torch.tensor]], 41 | ) -> Dict[str, torch.tensor]: 42 | """ 43 | Accumulate results and compute average predictions and uncertainties. 44 | 45 | Args: 46 | results (list(dict(str, torch.tensor)): List of output dictionaries of individual models. 47 | 48 | Returns: 49 | dict(str, torch.tensor): output dictionary with averaged predictions and uncertainties. 50 | """ 51 | # Get the keys 52 | accumulated = {p: [] for p in results[0]} 53 | ensemble_results = {p: [] for p in results[0]} 54 | 55 | for p in accumulated: 56 | tmp = torch.stack([result[p] for result in results]) 57 | ensemble_results[p] = torch.mean(tmp, dim=0) 58 | ensemble_results["{:s}_var".format(p)] = torch.var(tmp, dim=0) 59 | 60 | return ensemble_results 61 | 62 | def _activate_stress(self, stress_key: Optional[str] = None): 63 | """ 64 | Routine for activating stress computations 65 | Args: 66 | stress_key (str, optional): stess label. 67 | """ 68 | raise NotImplementedError 69 | 70 | def _update_required_properties(self): 71 | """ 72 | Update required properties to also contain predictive variances. 
73 | """ 74 | new_required = [] 75 | for p in self.required_properties: 76 | prop_string = "{:s}_var".format(p) 77 | new_required += [p, prop_string] 78 | # Update property conversion 79 | self.property_conversion[prop_string] = self.property_conversion[p] 80 | 81 | self.required_properties = new_required 82 | -------------------------------------------------------------------------------- /src/schnetpack/md/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .hdf5_data import * 2 | from .spectra import * 3 | -------------------------------------------------------------------------------- /src/schnetpack/md/md_configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atomistic-machine-learning/schnetpack/601c8edb6cdf35af0b2ebf27b4b9e39386b79fb9/src/schnetpack/md/md_configs/__init__.py -------------------------------------------------------------------------------- /src/schnetpack/md/md_configs/calculator/lj.yaml: -------------------------------------------------------------------------------- 1 | _target_: schnetpack.md.calculators.LJCalculator 2 | r_equilibrium: 3.405 # Angs 3 | well_depth: 3.984264 # 4*e (e=119.8K) in kJ/mol 4 | force_key: forces 5 | energy_unit: kJ/mol 6 | position_unit: Angstrom 7 | energy_key: energy 8 | stress_key: stress 9 | healing_length: 4.0 #0.3405 10 | 11 | defaults: 12 | - neighbor_list: matscipy -------------------------------------------------------------------------------- /src/schnetpack/md/md_configs/calculator/neighbor_list/ase.yaml: -------------------------------------------------------------------------------- 1 | _target_: schnetpack.md.neighborlist_md.NeighborListMD 2 | cutoff: ??? 
3 | cutoff_shell: 2.0 4 | requires_triples: false 5 | base_nbl: schnetpack.transform.ASENeighborList 6 | collate_fn: schnetpack.data.loader._atoms_collate_fn 7 | -------------------------------------------------------------------------------- /src/schnetpack/md/md_configs/calculator/neighbor_list/matscipy.yaml: -------------------------------------------------------------------------------- 1 | _target_: schnetpack.md.neighborlist_md.NeighborListMD 2 | cutoff: ??? 3 | cutoff_shell: 2.0 4 | requires_triples: false 5 | base_nbl: schnetpack.transform.MatScipyNeighborList 6 | collate_fn: schnetpack.data.loader._atoms_collate_fn 7 | -------------------------------------------------------------------------------- /src/schnetpack/md/md_configs/calculator/neighbor_list/torch.yaml: -------------------------------------------------------------------------------- 1 | _target_: schnetpack.md.neighborlist_md.NeighborListMD 2 | cutoff: ??? 3 | cutoff_shell: 2.0 4 | requires_triples: false 5 | base_nbl: schnetpack.transform.TorchNeighborList 6 | collate_fn: schnetpack.data.loader._atoms_collate_fn -------------------------------------------------------------------------------- /src/schnetpack/md/md_configs/calculator/orca.yaml: -------------------------------------------------------------------------------- 1 | _target_: schnetpack.md.calculators.OrcaCalculator 2 | required_properties: 3 | - energy 4 | - forces 5 | force_key: forces 6 | compdir: qm_calculations 7 | qm_executable: ??? 8 | orca_template: ??? 
9 | energy_unit: Hartree 10 | position_unit: Bohr 11 | energy_key: energy 12 | stress_key: null 13 | property_conversion: { } 14 | overwrite: true 15 | adaptive: false 16 | basename: qm_calc 17 | -------------------------------------------------------------------------------- /src/schnetpack/md/md_configs/calculator/spk.yaml: -------------------------------------------------------------------------------- 1 | _target_: schnetpack.md.calculators.SchNetPackCalculator 2 | required_properties: 3 | - energy 4 | - forces 5 | model_file: ??? 6 | force_key: forces 7 | energy_unit: kcal / mol 8 | position_unit: Angstrom 9 | energy_key: energy 10 | stress_key: null 11 | script_model: false 12 | 13 | defaults: 14 | - neighbor_list: matscipy -------------------------------------------------------------------------------- /src/schnetpack/md/md_configs/calculator/spk_ensemble.yaml: -------------------------------------------------------------------------------- 1 | _target_: schnetpack.md.calculators.SchNetPackEnsembleCalculator 2 | required_properties: 3 | - energy 4 | - forces 5 | model_files: 6 | - ??? 
7 | force_key: forces 8 | energy_unit: kcal / mol 9 | position_unit: Angstrom 10 | energy_key: energy 11 | stress_key: null 12 | script_model: false 13 | 14 | defaults: 15 | - neighbor_list: matscipy 16 | -------------------------------------------------------------------------------- /src/schnetpack/md/md_configs/callbacks/checkpoint.yaml: -------------------------------------------------------------------------------- 1 | checkpoint: 2 | _target_: schnetpack.md.simulation_hooks.Checkpoint 3 | checkpoint_file: checkpoint.chk 4 | every_n_steps: 10 -------------------------------------------------------------------------------- /src/schnetpack/md/md_configs/callbacks/hdf5.yaml: -------------------------------------------------------------------------------- 1 | hdf5: 2 | _target_: schnetpack.md.simulation_hooks.FileLogger 3 | filename: simulation.hdf5 4 | buffer_size: 100 5 | data_streams: 6 | - _target_: schnetpack.md.simulation_hooks.MoleculeStream 7 | store_velocities: true 8 | - _target_: schnetpack.md.simulation_hooks.PropertyStream 9 | target_properties: [ energy ] 10 | every_n_steps: 1 11 | precision: ${precision} -------------------------------------------------------------------------------- /src/schnetpack/md/md_configs/callbacks/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | tensorboard: 2 | _target_: schnetpack.md.simulation_hooks.TensorBoardLogger 3 | log_file: logs 4 | properties: 5 | - energy 6 | - temperature 7 | every_n_steps: 10 8 | -------------------------------------------------------------------------------- /src/schnetpack/md/md_configs/config.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | device: cuda 4 | precision: 32 5 | seed: null 6 | simulation_dir: ??? 
7 | overwrite: false 8 | restart: null 9 | load_config: null 10 | 11 | defaults: 12 | - _self_ 13 | - calculator: spk 14 | - system: system 15 | - dynamics: base 16 | - callbacks: 17 | - checkpoint 18 | - hdf5 19 | - tensorboard 20 | 21 | hydra: 22 | run: 23 | dir: ${simulation_dir} 24 | job: 25 | config: 26 | override_dirname: 27 | exclude_keys: 28 | - basename 29 | kv_sep: '=' 30 | item_sep: '_' 31 | chdir: True 32 | -------------------------------------------------------------------------------- /src/schnetpack/md/md_configs/dynamics/barostat/nhc_aniso.yaml: -------------------------------------------------------------------------------- 1 | _target_: schnetpack.md.simulation_hooks.NHCBarostatAnisotropic 2 | target_pressure: 1000.0 # bar 3 | temperature_bath: 300.0 # K 4 | time_constant: 100.0 # fs 5 | time_constant_cell: 1000.0 6 | time_constant_barostat: 500.0 7 | chain_length: 4 8 | multi_step: 4 9 | integration_order: 7 10 | massive: true -------------------------------------------------------------------------------- /src/schnetpack/md/md_configs/dynamics/barostat/nhc_iso.yaml: -------------------------------------------------------------------------------- 1 | _target_: schnetpack.md.simulation_hooks.NHCBarostatIsotropic 2 | target_pressure: 1000.0 # bar 3 | temperature_bath: 300.0 # K 4 | time_constant: 100.0 # fs 5 | time_constant_cell: 1000.0 6 | time_constant_barostat: 500.0 7 | chain_length: 4 8 | multi_step: 4 9 | integration_order: 7 10 | massive: true -------------------------------------------------------------------------------- /src/schnetpack/md/md_configs/dynamics/barostat/pile_rpmd.yaml: -------------------------------------------------------------------------------- 1 | _target_: schnetpack.md.simulation_hooks.PILEBarostat 2 | target_pressure: 1000.0 # bar 3 | temperature_bath: 300.0 # K 4 | time_constant: 500.0 # fs -------------------------------------------------------------------------------- 
/src/schnetpack/md/md_configs/dynamics/base.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - integrator: md 3 | 4 | n_steps: 1000000 5 | thermostat: null 6 | barostat: null 7 | progress: true 8 | simulation_hooks: [ ] -------------------------------------------------------------------------------- /src/schnetpack/md/md_configs/dynamics/integrator/md.yaml: -------------------------------------------------------------------------------- 1 | _target_: schnetpack.md.integrators.VelocityVerlet 2 | time_step: 0.5 # fs -------------------------------------------------------------------------------- /src/schnetpack/md/md_configs/dynamics/integrator/rpmd.yaml: -------------------------------------------------------------------------------- 1 | _target_: schnetpack.md.integrators.RingPolymer 2 | time_step: 0.20 # fs 3 | temperature: 300.0 # K -------------------------------------------------------------------------------- /src/schnetpack/md/md_configs/dynamics/thermostat/berendsen.yaml: -------------------------------------------------------------------------------- 1 | _target_: schnetpack.md.simulation_hooks.BerendsenThermostat 2 | temperature_bath: 300.0 # K 3 | time_constant: 100.0 # fs -------------------------------------------------------------------------------- /src/schnetpack/md/md_configs/dynamics/thermostat/gle.yaml: -------------------------------------------------------------------------------- 1 | _target_: schnetpack.md.simulation_hooks.GLEThermostat 2 | temperature_bath: 300.0 3 | gle_file: ??? 
4 | free_particle_limit: true -------------------------------------------------------------------------------- /src/schnetpack/md/md_configs/dynamics/thermostat/langevin.yaml: -------------------------------------------------------------------------------- 1 | _target_: schnetpack.md.simulation_hooks.LangevinThermostat 2 | temperature_bath: 300.0 # K 3 | time_constant: 100.0 # fs -------------------------------------------------------------------------------- /src/schnetpack/md/md_configs/dynamics/thermostat/nhc.yaml: -------------------------------------------------------------------------------- 1 | _target_: schnetpack.md.simulation_hooks.NHCThermostat 2 | temperature_bath: 300.0 # K 3 | time_constant: 100.0 # fs 4 | chain_length: 3 5 | massive: false 6 | multi_step: 2 7 | integration_order: 3 -------------------------------------------------------------------------------- /src/schnetpack/md/md_configs/dynamics/thermostat/pi_gle.yaml: -------------------------------------------------------------------------------- 1 | _target_: schnetpack.md.simulation_hooks.RPMDGLEThermostat 2 | temperature_bath: 300.0 3 | gle_file: ??? 
4 | free_particle_limit: true -------------------------------------------------------------------------------- /src/schnetpack/md/md_configs/dynamics/thermostat/pi_nhc_global.yaml: -------------------------------------------------------------------------------- 1 | _target_ : schnetpack.md.simulation_hooks.NHCRingPolymerThermostat 2 | temperature_bath: 300.0 # K 3 | time_constant: 100.0 # fs 4 | local: false 5 | chain_length: 3 6 | multi_step: 2 7 | integration_order: 3 -------------------------------------------------------------------------------- /src/schnetpack/md/md_configs/dynamics/thermostat/pi_nhc_local.yaml: -------------------------------------------------------------------------------- 1 | _target_ : schnetpack.md.simulation_hooks.NHCRingPolymerThermostat 2 | temperature_bath: 300.0 # K 3 | time_constant: 100.0 # fs 4 | local: true 5 | chain_length: 3 6 | multi_step: 2 7 | integration_order: 3 -------------------------------------------------------------------------------- /src/schnetpack/md/md_configs/dynamics/thermostat/piglet.yaml: -------------------------------------------------------------------------------- 1 | _target_: schnetpack.md.simulation_hooks.PIGLETThermostat 2 | temperature_bath: 300.0 3 | gle_file: ??? 
-------------------------------------------------------------------------------- /src/schnetpack/md/md_configs/dynamics/thermostat/pile_global.yaml: -------------------------------------------------------------------------------- 1 | _target_: schnetpack.md.simulation_hooks.PILEGlobalThermostat 2 | temperature_bath: 300.0 # K 3 | time_constant: 100.0 # fs -------------------------------------------------------------------------------- /src/schnetpack/md/md_configs/dynamics/thermostat/pile_local.yaml: -------------------------------------------------------------------------------- 1 | _target_: schnetpack.md.simulation_hooks.PILELocalThermostat 2 | temperature_bath: 300.0 # K 3 | time_constant: 100.0 # fs -------------------------------------------------------------------------------- /src/schnetpack/md/md_configs/dynamics/thermostat/trpmd.yaml: -------------------------------------------------------------------------------- 1 | _target_: schnetpack.md.simulation_hooks.TRPMDThermostat 2 | temperature_bath: 300.0 # K 3 | damping_factor: 0.8 -------------------------------------------------------------------------------- /src/schnetpack/md/md_configs/system/initializer/maxwell_boltzmann.yaml: -------------------------------------------------------------------------------- 1 | _target_: schnetpack.md.MaxwellBoltzmannInit 2 | temperature: 300 3 | remove_center_of_mass: true 4 | remove_translation: true 5 | remove_rotation: true 6 | wrap_positions: false -------------------------------------------------------------------------------- /src/schnetpack/md/md_configs/system/initializer/uniform.yaml: -------------------------------------------------------------------------------- 1 | _target_: schnetpack.md.UniformInit 2 | temperature: 300 3 | remove_center_of_mass: true 4 | remove_translation: true 5 | remove_rotation: true 6 | wrap_positions: false -------------------------------------------------------------------------------- 
/src/schnetpack/md/md_configs/system/system.yaml: -------------------------------------------------------------------------------- 1 | molecule_file: ??? 2 | load_system_state: null 3 | n_replicas: 1 4 | position_unit_input: Angstrom 5 | mass_unit_input: 1.0 6 | 7 | defaults: 8 | - initializer: uniform 9 | -------------------------------------------------------------------------------- /src/schnetpack/md/parsers/__init__.py: -------------------------------------------------------------------------------- 1 | from .orca_parser import * 2 | -------------------------------------------------------------------------------- /src/schnetpack/md/simulation_hooks/__init__.py: -------------------------------------------------------------------------------- 1 | from .basic_hooks import * 2 | from .barostats import * 3 | from .barostats_rpmd import * 4 | from .thermostats import * 5 | from .thermostats_rpmd import * 6 | from .callback_hooks import * 7 | -------------------------------------------------------------------------------- /src/schnetpack/md/simulation_hooks/basic_hooks.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import torch.nn as nn 3 | 4 | from schnetpack.md.utils import UninitializedMixin 5 | 6 | from typing import TYPE_CHECKING 7 | 8 | if TYPE_CHECKING: 9 | from schnetpack.md import Simulator 10 | 11 | __all__ = ["RemoveCOMMotion", "SimulationHook", "WrapPositions"] 12 | 13 | 14 | class SimulationHook(UninitializedMixin, nn.Module): 15 | """ 16 | Basic class for simulator hooks 17 | """ 18 | 19 | def on_step_begin(self, simulator: Simulator): 20 | pass 21 | 22 | def on_step_middle(self, simulator: Simulator): 23 | pass 24 | 25 | def on_step_end(self, simulator: Simulator): 26 | pass 27 | 28 | def on_step_finalize(self, simulator: Simulator): 29 | pass 30 | 31 | def on_step_failed(self, simulator: Simulator): 32 | pass 33 | 34 | def on_simulation_start(self, simulator: Simulator): 
35 | pass 36 | 37 | def on_simulation_end(self, simulator: Simulator): 38 | pass 39 | 40 | 41 | class RemoveCOMMotion(SimulationHook): 42 | """ 43 | Periodically remove motions of the center of mass from the system. 44 | 45 | Args: 46 | every_n_steps (int): Frequency with which motions are removed. 47 | remove_rotation (bool): Also remove rotations. 48 | """ 49 | 50 | def __init__(self, every_n_steps: int, remove_rotation: bool): 51 | super(RemoveCOMMotion, self).__init__() 52 | self.every_n_steps = every_n_steps 53 | self.remove_rotation = remove_rotation 54 | 55 | def on_step_finalize(self, simulator: Simulator): 56 | if simulator.step % self.every_n_steps == 0: 57 | simulator.system.remove_center_of_mass() 58 | simulator.system.remove_translation() 59 | 60 | if self.remove_rotation: 61 | simulator.system.remove_com_rotation() 62 | 63 | 64 | class WrapPositions(SimulationHook): 65 | """ 66 | Periodically wrap atoms back into simulation cell. 67 | 68 | Args: 69 | every_n_steps (int): Frequency with which atoms should be wrapped. 
70 | """ 71 | 72 | def __init__(self, every_n_steps: int): 73 | super(WrapPositions, self).__init__() 74 | self.every_n_steps = every_n_steps 75 | 76 | def on_step_finalize(self, simulator: Simulator): 77 | if simulator.step % self.every_n_steps == 0: 78 | simulator.system.wrap_positions() 79 | -------------------------------------------------------------------------------- /src/schnetpack/md/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import schnetpack 2 | import torch 3 | import torch.nn as nn 4 | 5 | from .md_config import * 6 | from .normal_model_transformation import * 7 | from .thermostat_utils import * 8 | 9 | from typing import Optional 10 | 11 | 12 | class CalculatorError(Exception): 13 | pass 14 | 15 | 16 | def activate_model_stress( 17 | model: schnetpack.model.AtomisticModel, stress_key: str 18 | ) -> schnetpack.model.AtomisticModel: 19 | """ 20 | Utility function for activating computation of stress in models not explicitly trained on the stress tensor. 21 | Used for e.g. simulations under constant pressure and in cells. 22 | 23 | Args: 24 | model (AtomisticTask): loaded schnetpack model for which stress computation should be activated. 25 | stress_key (str): name of stress tensor in model. 26 | 27 | Returns: 28 | model (AtomisticTask): schnetpack model with activated stress tensor. 
29 | """ 30 | stress = False 31 | 32 | # Check if a module suitable for stress computation is present 33 | for module in model.output_modules: 34 | if isinstance(module, schnetpack.atomistic.response.Forces) or isinstance( 35 | module, schnetpack.atomistic.Response 36 | ): 37 | # for `Forces` module 38 | if hasattr(module, "calc_stress"): 39 | # activate internal stress computation flag 40 | module.calc_stress = True 41 | 42 | # append stress label to output list and update required derivatives in the module 43 | module.model_outputs.append(stress_key) 44 | module.required_derivatives.append(schnetpack.properties.strain) 45 | 46 | # if not set in the model, also update output list and required derivatives so that: 47 | # a) required derivatives are computed and 48 | # b) property is added to the model outputs 49 | if stress_key not in model.model_outputs: 50 | model.model_outputs.append(stress_key) 51 | model.required_derivatives.append(schnetpack.properties.strain) 52 | 53 | stress = True 54 | 55 | # for `Response` module 56 | if hasattr(module, "basic_derivatives"): 57 | # activate internal stress computation flag 58 | module.calc_stress = True 59 | module.basic_derivatives["dEds"] = schnetpack.properties.strain 60 | module.derivative_instructions["dEds"] = True 61 | module.basic_derivatives["dEds"] = schnetpack.properties.strain 62 | 63 | module.map_properties[schnetpack.properties.stress] = ( 64 | schnetpack.properties.stress 65 | ) 66 | 67 | # append stress label to output list and update required derivatives in the module 68 | module.model_outputs.append(stress_key) 69 | module.required_derivatives.append(schnetpack.properties.strain) 70 | 71 | # if not set in the model, also update output list and required derivatives so that: 72 | # a) required derivatives are computed and 73 | # b) property is added to the model outputs 74 | if stress_key not in model.model_outputs: 75 | model.model_outputs.append(stress_key) 76 | 
model.required_derivatives.append(schnetpack.properties.strain) 77 | 78 | stress = True 79 | 80 | # If stress computation has been enables, insert preprocessing for strain computation 81 | if stress: 82 | model.input_modules.insert(0, schnetpack.atomistic.Strain()) 83 | 84 | if not stress: 85 | raise CalculatorError("Failed to activate stress computation") 86 | 87 | return model 88 | 89 | 90 | class UninitializedMixin(nn.modules.lazy.LazyModuleMixin): 91 | """ 92 | Custom mixin for lazy initialization of buffers used in the MD system and simulation hooks. 93 | This can be used to add buffers with a certain dtype in an uninitialized state. 94 | """ 95 | 96 | def register_uninitialized_buffer( 97 | self, name: str, dtype: Optional[torch.dtype] = None 98 | ): 99 | """ 100 | Register an uninitialized buffer with the requested dtype. This can be used to reserve variable which are not 101 | known at the initialization of `schnetpack.md.System` and simulation hooks. 102 | 103 | Args: 104 | name (str): Name of the uninitialized buffer to register. 105 | dtype (torch.dtype): If specified, buffer will be set to requested dtype. If None is given, this will 106 | default to float64 type. 107 | """ 108 | if dtype is None: 109 | dtype = torch.float64 110 | 111 | self.register_buffer(name, nn.parameter.UninitializedBuffer(dtype=dtype)) 112 | -------------------------------------------------------------------------------- /src/schnetpack/md/utils/normal_model_transformation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | __all__ = ["NormalModeTransformer"] 6 | 7 | 8 | class NormalModeTransformer(nn.Module): 9 | """ 10 | Class for transforming between bead and normal mode representation of the ring polymer, used e.g. in propagating the 11 | ring polymer during simulation. An in depth description of the transformation can be found e.g. in [#rpmd3]_. 
Here, 12 | a simple matrix multiplication is used instead of a Fourier transformation, which can be more performant in certain 13 | cases. On the GPU however, no significant performance gains where observed when using a FT based transformation over 14 | the matrix version. 15 | 16 | This transformation operates on the first dimension of the property tensors (e.g. positions, momenta) defined in the 17 | system class. Hence, the transformation can be carried out for several molecules at the same time. 18 | 19 | Args: 20 | n_beads (int): Number of beads in the ring polymer. 21 | 22 | References 23 | ---------- 24 | .. [#rpmd3] Ceriotti, Parrinello, Markland, Manolopoulos: 25 | Efficient stochastic thermostatting of path integral molecular dynamics. 26 | The Journal of Chemical Physics, 133, 124105. 2010. 27 | """ 28 | 29 | def __init__(self, n_beads): 30 | super(NormalModeTransformer, self).__init__() 31 | self.n_beads = n_beads 32 | 33 | # Initialize the transformation matrix 34 | c_transform = self._init_transformation_matrix() 35 | self.register_buffer("c_transform", c_transform) 36 | 37 | def _init_transformation_matrix(self): 38 | """ 39 | Build the normal mode transformation matrix. This matrix only has to be built once and can then be used during 40 | the whole simulation. 
The matrix has the dimension n_beads x n_beads, where n_beads is the number of beads in 41 | the ring polymer 42 | 43 | Returns: 44 | torch.Tensor: Normal mode transformation matrix of the shape n_beads x n_beads 45 | """ 46 | # Set up basic transformation matrix 47 | c_transform = np.zeros((self.n_beads, self.n_beads)) 48 | 49 | # Get auxiliary array with bead indices 50 | n = np.arange(1, self.n_beads + 1) 51 | 52 | # for k = 0 53 | c_transform[0, :] = 1.0 54 | 55 | for k in range(1, self.n_beads // 2 + 1): 56 | c_transform[k, :] = np.sqrt(2) * np.cos(2 * np.pi * k * n / self.n_beads) 57 | 58 | for k in range(self.n_beads // 2 + 1, self.n_beads): 59 | c_transform[k, :] = np.sqrt(2) * np.sin(2 * np.pi * k * n / self.n_beads) 60 | 61 | if self.n_beads % 2 == 0: 62 | c_transform[self.n_beads // 2, :] = (-1) ** n 63 | 64 | # Since matrix is initialized as C(k,n) does not need to be transposed 65 | c_transform /= np.sqrt(self.n_beads) 66 | c_transform = torch.from_numpy(c_transform) 67 | 68 | return c_transform 69 | 70 | def beads2normal(self, x_beads): 71 | """ 72 | Transform a system tensor (e.g. momenta, positions) from the bead representation to normal mode representation. 73 | 74 | Args: 75 | x_beads (torch.Tensor): System tensor in bead representation with the general shape 76 | n_beads x n_molecules x ... 77 | 78 | Returns: 79 | torch.Tensor: System tensor in normal mode representation with the same shape as the input tensor. 80 | """ 81 | return torch.mm(self.c_transform, x_beads.view(self.n_beads, -1)).view( 82 | x_beads.shape 83 | ) 84 | 85 | def normal2beads(self, x_normal): 86 | """ 87 | Transform a system tensor (e.g. momenta, positions) in normal mode representation back to bead representation. 88 | 89 | Args: 90 | x_normal (torch.Tensor): System tensor in normal mode representation with the general shape 91 | n_beads x n_molecules x ... 92 | 93 | Returns: 94 | torch.Tensor: System tensor in bead representation with the same shape as the input tensor. 
95 | """ 96 | return torch.mm( 97 | self.c_transform.transpose(0, 1), x_normal.view(self.n_beads, -1) 98 | ).view(x_normal.shape) 99 | -------------------------------------------------------------------------------- /src/schnetpack/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | -------------------------------------------------------------------------------- /src/schnetpack/nn/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Basic building blocks of SchNetPack models. Contains various basic and specialized network layers, layers for 3 | cutoff functions, as well as several auxiliary layers and functions. 4 | """ 5 | 6 | from schnetpack.nn.activations import * 7 | from schnetpack.nn.base import * 8 | from schnetpack.nn.blocks import * 9 | from schnetpack.nn.cutoff import * 10 | from schnetpack.nn.equivariant import * 11 | from schnetpack.nn.so3 import * 12 | from schnetpack.nn.scatter import * 13 | from schnetpack.nn.radial import * 14 | from schnetpack.nn.utils import * 15 | from schnetpack.nn.embedding import * 16 | -------------------------------------------------------------------------------- /src/schnetpack/nn/activations.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | from torch.nn import functional 5 | 6 | __all__ = ["shifted_softplus", "softplus_inverse", "ShiftedSoftplus"] 7 | 8 | 9 | def shifted_softplus(x: torch.Tensor): 10 | r"""Compute shifted soft-plus activation function. 11 | 12 | .. math:: 13 | y = \ln\left(1 + e^{-x}\right) - \ln(2) 14 | 15 | Args: 16 | x (torch.Tensor): input tensor. 17 | 18 | Returns: 19 | torch.Tensor: shifted soft-plus of input. 20 | 21 | """ 22 | return functional.softplus(x) - math.log(2.0) 23 | 24 | 25 | def softplus_inverse(x: torch.Tensor): 26 | """ 27 | Inverse of the softplus function. 
28 | 29 | Args: 30 | x (torch.Tensor): Input vector 31 | 32 | Returns: 33 | torch.Tensor: softplus inverse of input. 34 | """ 35 | return x + (torch.log(-torch.expm1(-x))) 36 | 37 | 38 | class ShiftedSoftplus(torch.nn.Module): 39 | """ 40 | Shifted softplus activation function with learnable feature-wise parameters: 41 | f(x) = alpha/beta * (softplus(beta*x) - log(2)) 42 | softplus(x) = log(exp(x) + 1) 43 | For beta -> 0 : f(x) -> 0.5*alpha*x 44 | For beta -> inf: f(x) -> max(0, alpha*x) 45 | 46 | With learnable parameters alpha and beta, the shifted softplus function can 47 | become equivalent to ReLU (if alpha is equal 1 and beta approaches infinity) or to 48 | the identity function (if alpha is equal 2 and beta is equal 0). 49 | """ 50 | 51 | def __init__( 52 | self, 53 | initial_alpha: float = 1.0, 54 | initial_beta: float = 1.0, 55 | trainable: bool = False, 56 | ) -> None: 57 | """ 58 | Args: 59 | initial_alpha: Initial "scale" alpha of the softplus function. 60 | initial_beta: Initial "temperature" beta of the softplus function. 61 | trainable: If True, alpha and beta are trained during optimization. 62 | """ 63 | super(ShiftedSoftplus, self).__init__() 64 | initial_alpha = torch.tensor(initial_alpha) 65 | initial_beta = torch.tensor(initial_beta) 66 | 67 | if trainable: 68 | self.alpha = torch.nn.Parameter(torch.FloatTensor([initial_alpha])) 69 | self.beta = torch.nn.Parameter(torch.FloatTensor([initial_beta])) 70 | else: 71 | self.register_buffer("alpha", initial_alpha) 72 | self.register_buffer("beta", initial_beta) 73 | 74 | def forward(self, x: torch.Tensor) -> torch.Tensor: 75 | """ 76 | Evaluate activation function given the input features x. 77 | num_features: Dimensions of feature space. 78 | 79 | Args: 80 | x (FloatTensor [:, num_features]): Input features. 81 | 82 | Returns: 83 | y (FloatTensor [:, num_features]): Activated features. 
84 | """ 85 | return self.alpha * torch.where( 86 | self.beta != 0, 87 | (torch.nn.functional.softplus(self.beta * x) - math.log(2)) / self.beta, 88 | 0.5 * x, 89 | ) 90 | -------------------------------------------------------------------------------- /src/schnetpack/nn/base.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Union, Optional 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn 6 | from torch.nn.init import xavier_uniform_ 7 | 8 | from torch.nn.init import zeros_ 9 | 10 | 11 | __all__ = ["Dense"] 12 | 13 | 14 | class Dense(nn.Linear): 15 | r"""Fully connected linear layer with activation function. 16 | 17 | .. math:: 18 | y = activation(x W^T + b) 19 | """ 20 | 21 | def __init__( 22 | self, 23 | in_features: int, 24 | out_features: int, 25 | bias: bool = True, 26 | activation: Union[Callable, nn.Module] = None, 27 | weight_init: Callable = xavier_uniform_, 28 | bias_init: Callable = zeros_, 29 | ): 30 | """ 31 | Args: 32 | in_features: number of input feature :math:`x`. 33 | out_features: number of output features :math:`y`. 34 | bias: If False, the layer will not adapt bias :math:`b`. 35 | activation: if None, no activation function is used. 36 | weight_init: weight initializer from current weight. 37 | bias_init: bias initializer from current bias. 
38 | """ 39 | self.weight_init = weight_init 40 | self.bias_init = bias_init 41 | super(Dense, self).__init__(in_features, out_features, bias) 42 | 43 | self.activation = activation 44 | if self.activation is None: 45 | self.activation = nn.Identity() 46 | 47 | def reset_parameters(self): 48 | self.weight_init(self.weight) 49 | if self.bias is not None: 50 | self.bias_init(self.bias) 51 | 52 | def forward(self, input: torch.Tensor): 53 | y = F.linear(input, self.weight, self.bias) 54 | y = self.activation(y) 55 | return y 56 | -------------------------------------------------------------------------------- /src/schnetpack/nn/cutoff.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | 5 | __all__ = [ 6 | "CosineCutoff", 7 | "MollifierCutoff", 8 | "mollifier_cutoff", 9 | "cosine_cutoff", 10 | "SwitchFunction", 11 | ] 12 | 13 | 14 | def cosine_cutoff(input: torch.Tensor, cutoff: torch.Tensor): 15 | r""" Behler-style cosine cutoff. 16 | 17 | .. math:: 18 | f(r) = \begin{cases} 19 | 0.5 \times \left[1 + \cos\left(\frac{\pi r}{r_\text{cutoff}}\right)\right] 20 | & r < r_\text{cutoff} \\ 21 | 0 & r \geqslant r_\text{cutoff} \\ 22 | \end{cases} 23 | 24 | Args: 25 | cutoff (float, optional): cutoff radius. 26 | 27 | """ 28 | 29 | # Compute values of cutoff function 30 | input_cut = 0.5 * (torch.cos(input * math.pi / cutoff) + 1.0) 31 | # Remove contributions beyond the cutoff radius 32 | input_cut *= (input < cutoff).float() 33 | return input_cut 34 | 35 | 36 | class CosineCutoff(nn.Module): 37 | r""" Behler-style cosine cutoff module. 38 | 39 | .. math:: 40 | f(r) = \begin{cases} 41 | 0.5 \times \left[1 + \cos\left(\frac{\pi r}{r_\text{cutoff}}\right)\right] 42 | & r < r_\text{cutoff} \\ 43 | 0 & r \geqslant r_\text{cutoff} \\ 44 | \end{cases} 45 | 46 | """ 47 | 48 | def __init__(self, cutoff: float): 49 | """ 50 | Args: 51 | cutoff (float, optional): cutoff radius. 
52 | """ 53 | super(CosineCutoff, self).__init__() 54 | self.register_buffer("cutoff", torch.FloatTensor([cutoff])) 55 | 56 | def forward(self, input: torch.Tensor): 57 | return cosine_cutoff(input, self.cutoff) 58 | 59 | 60 | def mollifier_cutoff(input: torch.Tensor, cutoff: torch.Tensor, eps: torch.Tensor): 61 | r""" Mollifier cutoff scaled to have a value of 1 at :math:`r=0`. 62 | 63 | .. math:: 64 | f(r) = \begin{cases} 65 | \exp\left(1 - \frac{1}{1 - \left(\frac{r}{r_\text{cutoff}}\right)^2}\right) 66 | & r < r_\text{cutoff} \\ 67 | 0 & r \geqslant r_\text{cutoff} \\ 68 | \end{cases} 69 | 70 | Args: 71 | cutoff: Cutoff radius. 72 | eps: Offset added to distances for numerical stability. 73 | 74 | """ 75 | mask = (input + eps < cutoff).float() 76 | exponent = 1.0 - 1.0 / (1.0 - torch.pow(input * mask / cutoff, 2)) 77 | cutoffs = torch.exp(exponent) 78 | cutoffs = cutoffs * mask 79 | return cutoffs 80 | 81 | 82 | class MollifierCutoff(nn.Module): 83 | r""" Mollifier cutoff module scaled to have a value of 1 at :math:`r=0`. 84 | 85 | .. math:: 86 | f(r) = \begin{cases} 87 | \exp\left(1 - \frac{1}{1 - \left(\frac{r}{r_\text{cutoff}}\right)^2}\right) 88 | & r < r_\text{cutoff} \\ 89 | 0 & r \geqslant r_\text{cutoff} \\ 90 | \end{cases} 91 | """ 92 | 93 | def __init__(self, cutoff: float, eps: float = 1.0e-7): 94 | """ 95 | Args: 96 | cutoff: Cutoff radius. 97 | eps: Offset added to distances for numerical stability. 98 | """ 99 | super(MollifierCutoff, self).__init__() 100 | self.register_buffer("cutoff", torch.FloatTensor([cutoff])) 101 | self.register_buffer("eps", torch.FloatTensor([eps])) 102 | 103 | def forward(self, input: torch.Tensor): 104 | return mollifier_cutoff(input, self.cutoff, self.eps) 105 | 106 | 107 | def _switch_component( 108 | x: torch.Tensor, ones: torch.Tensor, zeros: torch.Tensor 109 | ) -> torch.Tensor: 110 | """ 111 | Basic component of switching functions. 112 | 113 | Args: 114 | x (torch.Tensor): Switch functions. 
115 | ones (torch.Tensor): Tensor with ones. 116 | zeros (torch.Tensor): Zero tensor 117 | 118 | Returns: 119 | torch.Tensor: Output tensor. 120 | """ 121 | x_ = torch.where(x <= 0, ones, x) 122 | return torch.where(x <= 0, zeros, torch.exp(-ones / x_)) 123 | 124 | 125 | class SwitchFunction(nn.Module): 126 | """ 127 | Decays from 1 to 0 between `switch_on` and `switch_off`. 128 | """ 129 | 130 | def __init__(self, switch_on: float, switch_off: float): 131 | """ 132 | 133 | Args: 134 | switch_on (float): Onset of switch. 135 | switch_off (float): Value from which on switch is 0. 136 | """ 137 | super(SwitchFunction, self).__init__() 138 | self.register_buffer("switch_on", torch.Tensor([switch_on])) 139 | self.register_buffer("switch_off", torch.Tensor([switch_off])) 140 | 141 | def forward(self, x: torch.Tensor) -> torch.Tensor: 142 | """ 143 | 144 | Args: 145 | x (torch.Tensor): tensor to which switching function should be applied to. 146 | 147 | Returns: 148 | torch.Tensor: switch output 149 | """ 150 | x = (x - self.switch_on) / (self.switch_off - self.switch_on) 151 | 152 | ones = torch.ones_like(x) 153 | zeros = torch.zeros_like(x) 154 | fp = _switch_component(x, ones, zeros) 155 | fm = _switch_component(1 - x, ones, zeros) 156 | 157 | f_switch = torch.where(x <= 0, ones, torch.where(x >= 1, zeros, fm / (fp + fm))) 158 | return f_switch 159 | -------------------------------------------------------------------------------- /src/schnetpack/nn/equivariant.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import schnetpack.nn as snn 6 | from typing import Tuple 7 | 8 | __all__ = ["GatedEquivariantBlock"] 9 | 10 | 11 | class GatedEquivariantBlock(nn.Module): 12 | """ 13 | Gated equivariant block as used for the prediction of tensorial properties by PaiNN. 14 | Transforms scalar and vector representation using gated nonlinearities. 
15 | 16 | References: 17 | 18 | .. [#painn1] Schütt, Unke, Gastegger: 19 | Equivariant message passing for the prediction of tensorial properties and molecular spectra. 20 | ICML 2021 (to appear) 21 | 22 | """ 23 | 24 | def __init__( 25 | self, 26 | n_sin: int, 27 | n_vin: int, 28 | n_sout: int, 29 | n_vout: int, 30 | n_hidden: int, 31 | activation=F.silu, 32 | sactivation=None, 33 | ): 34 | """ 35 | Args: 36 | n_sin: number of input scalar features 37 | n_vin: number of input vector features 38 | n_sout: number of output scalar features 39 | n_vout: number of output vector features 40 | n_hidden: number of hidden units 41 | activation: interal activation function 42 | sactivation: activation function for scalar outputs 43 | """ 44 | super().__init__() 45 | self.n_sin = n_sin 46 | self.n_vin = n_vin 47 | self.n_sout = n_sout 48 | self.n_vout = n_vout 49 | self.n_hidden = n_hidden 50 | self.mix_vectors = snn.Dense(n_vin, 2 * n_vout, activation=None, bias=False) 51 | self.scalar_net = nn.Sequential( 52 | snn.Dense(n_sin + n_vout, n_hidden, activation=activation), 53 | snn.Dense(n_hidden, n_sout + n_vout, activation=None), 54 | ) 55 | self.sactivation = sactivation 56 | 57 | def forward(self, inputs: Tuple[torch.Tensor, torch.Tensor]): 58 | scalars, vectors = inputs 59 | vmix = self.mix_vectors(vectors) 60 | vectors_V, vectors_W = torch.split(vmix, self.n_vout, dim=-1) 61 | vectors_Vn = torch.norm(vectors_V, dim=-2) 62 | 63 | ctx = torch.cat([scalars, vectors_Vn], dim=-1) 64 | x = self.scalar_net(ctx) 65 | s_out, x = torch.split(x, [self.n_sout, self.n_vout], dim=-1) 66 | v_out = x.unsqueeze(-2) * vectors_W 67 | 68 | if self.sactivation: 69 | s_out = self.sactivation(s_out) 70 | 71 | return s_out, v_out 72 | -------------------------------------------------------------------------------- /src/schnetpack/nn/ops/__init__.py: -------------------------------------------------------------------------------- 
# https://raw.githubusercontent.com/atomistic-machine-learning/schnetpack/601c8edb6cdf35af0b2ebf27b4b9e39386b79fb9/src/schnetpack/nn/ops/__init__.py
# --------------------------------------------------------------------------------
# /src/schnetpack/nn/ops/math.py:
# --------------------------------------------------------------------------------
import torch


def binom(n: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """
    Compute binomial coefficients (n k) via log-gamma functions,
    exp(lgamma(n + 1) - lgamma(n - k + 1) - lgamma(k + 1)).

    The result is floating point; for integer-valued inputs it approximates the exact
    coefficient up to the rounding error of the log-gamma evaluation.
    """
    return torch.exp(
        torch.lgamma(n + 1) - torch.lgamma((n - k) + 1) - torch.lgamma(k + 1)
    )


# --------------------------------------------------------------------------------
# /src/schnetpack/nn/ops/so3.py:
# --------------------------------------------------------------------------------
import math
import torch

from functools import lru_cache
from typing import Tuple


@lru_cache(maxsize=10)
def sh_indices(lmax: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Build index arrays for spherical harmonics.

    The returned tensors are cached per ``lmax``; do not modify them in place.

    Args:
        lmax: maximum angular momentum

    Returns:
        lidx, midx: degree l and order m of every component, ordered by l and, within
            each l, by m from -l to l; both have length (lmax + 1)^2.
    """
    ls = torch.arange(0, lmax + 1)
    nls = 2 * ls + 1
    lidx = torch.repeat_interleave(ls, nls)
    midx = torch.cat([torch.arange(-l, l + 1) for l in ls])
    return lidx, midx


@lru_cache(maxsize=10)
def generate_sh_to_rsh(lmax: int) -> torch.Tensor:
    """
    Generate transformation matrix to convert (complex) spherical harmonics to real form.

    The returned tensor is cached per ``lmax``; do not modify it in place.

    Args:
        lmax: maximum angular momentum

    Returns:
        torch.Tensor: complex matrix of shape [(lmax+1)^2, (lmax+1)^2]; rows index real,
            columns index complex components. Entries are zero unless both share the same l.
    """
    lidx, midx = sh_indices(lmax)
    l1 = lidx[:, None]
    l2 = lidx[None, :]
    m1 = midx[:, None]
    m2 = midx[None, :]
    U = (
        1.0 * ((m1 == 0) * (m2 == 0))
        + (-1.0) ** abs(m1) / math.sqrt(2) * ((m1 == m2) * (m1 > 0))
        + 1.0 / math.sqrt(2) * ((m1 == -m2) * (m2 < 0))
        + -1.0j * (-1.0) ** abs(m1) / math.sqrt(2) * ((m1 == -m2) * (m1 < 0))
        + 1.0j / math.sqrt(2) * ((m1 == m2) * (m1 < 0))
    ) * (l1 == l2)
    return U


@lru_cache(maxsize=10)
def generate_clebsch_gordan(lmax: int) -> torch.Tensor:
    """
    Generate standard Clebsch-Gordan coefficients for complex spherical harmonics.

    The returned tensor is cached per ``lmax``; do not modify it in place.

    Args:
        lmax: maximum angular momentum

    Returns:
        torch.Tensor: dense tensor of shape [(lmax+1)^2, (lmax+1)^2, (lmax+1)^2] holding
            <l1 m1 l2 m2 | l3 m3> at the component indices given by :func:`sh_indices`.
    """
    # sympy is only needed to build this (cached) table, so it is imported lazily
    # instead of at module import time.
    from sympy.physics.wigner import clebsch_gordan

    lidx, midx = sh_indices(lmax)
    cg = torch.zeros((lidx.shape[0], lidx.shape[0], lidx.shape[0]))
    lidx = lidx.numpy()
    midx = midx.numpy()
    for c1, (l1, m1) in enumerate(zip(lidx, midx)):
        for c2, (l2, m2) in enumerate(zip(lidx, midx)):
            for c3, (l3, m3) in enumerate(zip(lidx, midx)):
                # Selection rules: <l1 m1 l2 m2 | l3 m3> vanishes unless
                # |l1 - l2| <= l3 <= l1 + l2 and m3 == m1 + m2; all other entries stay 0.
                # (l3 <= lmax holds by construction of lidx.)
                if abs(l1 - l2) <= l3 <= l1 + l2 and m3 == m1 + m2:
                    coeff = clebsch_gordan(l1, l2, l3, m1, m2, m3)
                    cg[c1, c2, c3] = float(coeff)
    return cg


@lru_cache(maxsize=10)
def generate_clebsch_gordan_rsh(
    lmax: int, parity_invariance: bool = True
) -> torch.Tensor:
    r"""
    Generate Clebsch-Gordan coefficients for real spherical harmonics.

    The complex coefficients are rotated into the real basis with the matrix from
    :func:`generate_sh_to_rsh`; the real part is returned as ``float64``.
    The returned tensor is cached per arguments; do not modify it in place.

    Args:
        lmax: maximum angular momentum
        parity_invariance: whether to enforce parity invariance, i.e. only allow
            non-zero coefficients if :math:`(-1)^{l_1} (-1)^{l_2} = (-1)^{l_3}`.
            Otherwise every coefficient is multiplied by :math:`i^{l_1 + l_2 - l_3}`
            before the real part is taken.

    Returns:
        torch.Tensor: real tensor of shape [(lmax+1)^2, (lmax+1)^2, (lmax+1)^2].
    """
    lidx, _ = sh_indices(lmax)
    cg = generate_clebsch_gordan(lmax).to(dtype=torch.complex64)
    complex_to_real = generate_sh_to_rsh(lmax)  # (real, complex)
    cg_rsh = torch.einsum(
        "ijk,mi,nj,ok->mno",
        cg,
        complex_to_real,
        complex_to_real,
        complex_to_real.conj(),
    )

    if parity_invariance:
        parity = (-1.0) ** lidx
        pmask = parity[:, None, None] * parity[None, :, None] == parity[None, None, :]
        cg_rsh *= pmask
    else:
        lsum = lidx[:, None, None] + lidx[None, :, None] - lidx[None, None, :]
        cg_rsh *= 1.0j**lsum

    # cast to real
    cg_rsh = cg_rsh.real.to(torch.float64)
    return cg_rsh


def sparsify_clebsch_gordon(
    cg: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Convert a dense Clebsch-Gordan tensor to sparse (coordinate) format.

    Args:
        cg: dense tensor Clebsch-Gordan coefficients
            [(lmax_1+1)^2, (lmax_2+1)^2, (lmax_out+1)^2]

    Returns:
        cg_sparse: vector of non-zero CG coefficients
        idx_in_1: indices for first set of irreps
        idx_in_2: indices for second set of irreps
        idx_out: indices for output set of irreps
    """
    idx = torch.nonzero(cg)
    idx_in_1, idx_in_2, idx_out = torch.split(idx, 1, dim=1)
    idx_in_1, idx_in_2, idx_out = (
        idx_in_1[:, 0],
        idx_in_2[:, 0],
        idx_out[:, 0],
    )
    cg_sparse = cg[idx_in_1, idx_in_2, idx_out]
    return cg_sparse, idx_in_1, idx_in_2, idx_out


def round_cmp(x: torch.Tensor, decimals: int = 1):
    """
    Round the real and imaginary parts of a complex tensor separately.

    Args:
        x: complex input tensor.
        decimals: number of decimal places to round to.

    Returns:
        torch.Tensor: complex tensor with rounded real and imaginary parts.
    """
    return torch.round(x.real, decimals=decimals) + 1j * torch.round(
        x.imag, decimals=decimals
    )


# --------------------------------------------------------------------------------
# /src/schnetpack/nn/radial.py:
# --------------------------------------------------------------------------------
from math import pi

import torch
import torch.nn as nn

__all__ = ["gaussian_rbf", "GaussianRBF", "GaussianRBFCentered", "BesselRBF"]

from torch import nn as nn


def gaussian_rbf(inputs: torch.Tensor, offsets: torch.Tensor, widths: torch.Tensor):
    r"""
    Expand inputs in Gaussian basis functions
    :math:`\exp\left(-\frac{(x - \mu_k)^2}{2 \sigma_k^2}\right)`.

    Args:
        inputs: values :math:`x` to expand; a trailing basis dimension is appended.
        offsets: centers :math:`\mu_k` of the Gaussians.
        widths: widths :math:`\sigma_k` of the Gaussians.

    Returns:
        torch.Tensor: expansion of shape ``inputs.shape + offsets.shape`` (for 1D offsets).
    """
    coeff = -0.5 / torch.pow(widths, 2)
    diff = inputs[..., None] - offsets
    y = torch.exp(coeff * torch.pow(diff, 2))
    return y


class GaussianRBF(nn.Module):
    r"""Gaussian radial basis functions with equally spaced centers."""

    def __init__(
        self, n_rbf: int, cutoff: float, start: float = 0.0, trainable: bool = False
    ):
        r"""
        Args:
            n_rbf: total number of Gaussian functions, :math:`N_g`.
            cutoff: center of last Gaussian function, :math:`\mu_{N_g}`
            start: center of first Gaussian function, :math:`\mu_0`.
            trainable: If True, widths and offset of Gaussian functions
                are adjusted during training process.
        """
        super(GaussianRBF, self).__init__()
        self.n_rbf = n_rbf

        # compute offset and width of Gaussian functions; every width equals the
        # spacing between neighboring centers (this needs n_rbf >= 2)
        offset = torch.linspace(start, cutoff, n_rbf)
        # built directly rather than via the legacy torch.FloatTensor(...) constructor,
        # which raises a TypeError for any non-float32 tensor (e.g. float64 default dtype)
        widths = torch.abs(offset[1] - offset[0]) * torch.ones_like(offset)
        if trainable:
            self.widths = nn.Parameter(widths)
            self.offsets = nn.Parameter(offset)
        else:
            self.register_buffer("widths", widths)
            self.register_buffer("offsets", offset)

    def forward(self, inputs: torch.Tensor):
        return gaussian_rbf(inputs, self.offsets, self.widths)


class GaussianRBFCentered(nn.Module):
    r"""Gaussian radial basis functions centered at the origin."""

    def __init__(
        self, n_rbf: int, cutoff: float, start: float = 1.0, trainable: bool = False
    ):
        r"""
        Args:
            n_rbf: total number of Gaussian functions, :math:`N_g`.
            cutoff: width of last Gaussian function, :math:`\sigma_{N_g}`
            start: width of first Gaussian function, :math:`\sigma_0`.
            trainable: If True, widths of Gaussian functions
                are adjusted during training process.
        """
        super(GaussianRBFCentered, self).__init__()
        self.n_rbf = n_rbf

        # compute offset and width of Gaussian functions
        widths = torch.linspace(start, cutoff, n_rbf)
        offset = torch.zeros_like(widths)
        if trainable:
            self.widths = nn.Parameter(widths)
            self.offsets = nn.Parameter(offset)
        else:
            self.register_buffer("widths", widths)
            self.register_buffer("offsets", offset)

    def forward(self, inputs: torch.Tensor):
        return gaussian_rbf(inputs, self.offsets, self.widths)


class BesselRBF(nn.Module):
    r"""
    Sine for radial basis functions with coulomb decay (0th order bessel):
    :math:`\sin(n \pi r / r_\text{cutoff}) / r` for :math:`n = 1, \dots, N`.

    References:

    .. [#dimenet] Klicpera, Groß, Günnemann:
       Directional message passing for molecular graphs.
       ICLR 2020
    """

    def __init__(self, n_rbf: int, cutoff: float):
        """
        Args:
            cutoff: radial cutoff
            n_rbf: number of basis functions.
        """
        super(BesselRBF, self).__init__()
        self.n_rbf = n_rbf

        freqs = torch.arange(1, n_rbf + 1) * pi / cutoff
        self.register_buffer("freqs", freqs)

    def forward(self, inputs):
        ax = inputs[..., None] * self.freqs
        sinax = torch.sin(ax)
        # avoid division by zero: at r = 0 the numerator sin(0) = 0, so divide by 1 instead
        norm = torch.where(inputs == 0, torch.tensor(1.0, device=inputs.device), inputs)
        y = sinax / norm[..., None]
        return y


# --------------------------------------------------------------------------------
# /src/schnetpack/nn/scatter.py:
# --------------------------------------------------------------------------------
import torch
from torch import nn

__all__ = ["scatter_add"]


def scatter_add(
    x: torch.Tensor, idx_i: torch.Tensor, dim_size: int, dim: int = 0
) -> torch.Tensor:
    """
    Sum over values with the same indices.

    Args:
        x: input values
        idx_i: index of center atom i
        dim_size: size of the dimension after reduction
        dim: the dimension to reduce

    Returns:
        reduced input

    """
    return _scatter_add(x, idx_i, dim_size, dim)


@torch.jit.script
def _scatter_add(
    x: torch.Tensor, idx_i: torch.Tensor, dim_size: int, dim: int = 0
) -> torch.Tensor:
    # allocate the reduced output, then accumulate the slices of x at positions idx_i along dim
    shape = list(x.shape)
    shape[dim] = dim_size
    tmp = torch.zeros(shape, dtype=x.dtype, device=x.device)
    y = tmp.index_add(dim, idx_i, x)
    return y


# --------------------------------------------------------------------------------
# /src/schnetpack/nn/utils.py:
# --------------------------------------------------------------------------------
from typing import Callable, Optional

import torch
from torch import nn as nn

__all__ = ["replicate_module", "derivative_from_atomic", "derivative_from_molecular"]

from torch.autograd import grad


def replicate_module(
    module_factory: Callable[[], nn.Module], n: int, share_params: bool
):
    """
    Create an ``nn.ModuleList`` with ``n`` entries produced by ``module_factory``.

    Args:
        module_factory: callable returning a new module.
        n: number of entries.
        share_params: if True, the factory is called once and the same module instance
            is repeated ``n`` times (all entries share parameters); otherwise the factory
            is called ``n`` times.

    Returns:
        nn.ModuleList: list of ``n`` modules.
    """
    if share_params:
        module_list = nn.ModuleList([module_factory()] * n)
    else:
        module_list = nn.ModuleList([module_factory() for i in range(n)])
    return module_list


def derivative_from_molecular(
    fx: torch.Tensor,
    dx: torch.Tensor,
    create_graph: bool = False,
    retain_graph: bool = False,
):
    """
    Compute the derivative of `fx` with respect to `dx` if the leading dimension of `fx` is the number of molecules
    (e.g. energies, dipole moments, etc).

    Args:
        fx (torch.Tensor): Tensor for which the derivative is taken.
        dx (torch.Tensor): Tensor with respect to which the derivative is taken.
        create_graph (bool): Create computational graph.
        retain_graph (bool): Keep the computational graph after the derivative has been computed.
            Independent of this flag, the graph is kept between the per-component gradient
            evaluations, since all of them backpropagate through it.

    Returns:
        torch.Tensor: derivative of `fx` with respect to `dx` with shape
            (dx.shape[0], *fx.shape[1:], *dx.shape[1:]).
    """
    fx_shape = fx.shape
    dx_shape = dx.shape
    # Final shape takes into consideration whether derivative will yield atomic or molecular properties
    final_shape = (dx_shape[0], *fx_shape[1:], *dx_shape[1:])

    fx = fx.view(fx_shape[0], -1)
    n_components = fx.shape[1]

    dfdx = torch.stack(
        [
            grad(
                fx[..., i],
                dx,
                torch.ones_like(fx[..., i]),
                create_graph=create_graph,
                # all components share one graph: keep it alive for all but the last call,
                # otherwise the second call fails when retain_graph=False
                retain_graph=retain_graph or i < n_components - 1,
            )[0]
            for i in range(n_components)
        ],
        dim=1,
    )
    dfdx = dfdx.view(final_shape)

    return dfdx


def derivative_from_atomic(
    fx: torch.Tensor,
    dx: torch.Tensor,
    n_atoms: torch.Tensor,
    create_graph: bool = False,
    retain_graph: bool = False,
):
    """
    Compute the derivative of a tensor with the leading dimension of (batch x atoms) with respect to another tensor
    (e.g. R). This function is primarily used for computing Hessians and Hessian-like response properties
    (e.g. nuclear spin-spin couplings). The final tensor will have the shape (batch * atoms * atoms x 3 x 3).

    NOTE(review): the slicing and reshaping below assume that `dx` is per-atom with 3 components
    (e.g. positions); whether pairwise tensors such as Rij are supported should be confirmed.

    This is quite inefficient, use with care.

    Args:
        fx (torch.Tensor): Tensor for which the derivative is taken.
        dx (torch.Tensor): Tensor with respect to which the derivative is taken.
        n_atoms (torch.Tensor): Tensor containing the number of atoms for each molecule.
        create_graph (bool): Create computational graph.
        retain_graph (bool): Keep the computational graph after the derivative has been computed.
            Independent of this flag, the graph is kept between the individual gradient evaluations.

    Returns:
        torch.Tensor: derivative of `fx` with respect to `dx`.
    """
    # Split input tensor for easier bookkeeping
    fxm = fx.split(list(n_atoms))

    # every gradient evaluation backpropagates through the same graph; keep it alive
    # until the very last one so that retain_graph=False does not break the loop
    n_remaining = fx.numel()

    dfdx = []

    n_mol = 0
    # Compute all derivatives
    for idx in range(len(fxm)):
        fx_mol = fxm[idx].view(-1)

        # Generate the individual derivatives
        dfdx_mol = []
        for i in range(fx_mol.shape[0]):
            n_remaining -= 1
            dfdx_i = grad(
                fx_mol[i],
                dx,
                torch.ones_like(fx_mol[i]),
                create_graph=create_graph,
                retain_graph=retain_graph or n_remaining > 0,
            )[0]

            # keep only the rows belonging to the current molecule
            dfdx_mol.append(dfdx_i[n_mol : n_mol + n_atoms[idx], ...])

        # Build molecular matrix and reshape
        dfdx_mol = torch.stack(dfdx_mol, dim=0)
        dfdx_mol = dfdx_mol.view(n_atoms[idx], 3, n_atoms[idx], 3)
        dfdx_mol = dfdx_mol.permute(0, 2, 1, 3)
        dfdx_mol = dfdx_mol.reshape(n_atoms[idx] ** 2, 3, 3)

        dfdx.append(dfdx_mol)

        n_mol += n_atoms[idx]

    # Accumulate everything
    dfdx = torch.cat(dfdx, dim=0)

    return dfdx


# --------------------------------------------------------------------------------
# /src/schnetpack/properties.py:
# --------------------------------------------------------------------------------
"""
Keys to access structure properties.

Note: Had to be moved out of Structure class for TorchScript compatibility

"""

from typing import Final

idx: Final[str] = "_idx"

## structure
Z: Final[str] = "_atomic_numbers"  #: nuclear charge
position: Final[str] = "_positions"  #: atom positions
R: Final[str] = position  #: atom positions

cell: Final[str] = "_cell"  #: unit cell
strain: Final[str] = "strain"
pbc: Final[str] = "_pbc"  #: periodic boundary conditions

seg_m: Final[str] = "_seg_m"  #: start indices of systems
idx_m: Final[str] = "_idx_m"  #: indices of systems
idx_i: Final[str] = "_idx_i"  #: indices of center atoms
idx_j: Final[str] = "_idx_j"  #: indices of neighboring atoms
idx_i_lr: Final[str] = "_idx_i_lr"  #: indices of center atoms for long-range
idx_j_lr: Final[str] = "_idx_j_lr"  #: indices of neighboring atoms for long-range

lidx_i: Final[str] = "_idx_i_local"  #: local indices of center atoms (within system)
lidx_j: Final[str] = (
    "_idx_j_local"  #: local indices of neighboring atoms (within system)
)
Rij: Final[str] = "_Rij"  #: vectors pointing from center atoms to neighboring atoms
Rij_lr: Final[str] = (
    "_Rij_lr"  #: vectors pointing from center atoms to neighboring atoms for long range
)
n_atoms: Final[str] = "_n_atoms"  #: number of atoms
offsets: Final[str] = "_offsets"  #: cell offset vectors
offsets_lr: Final[str] = "_offsets_lr"  #: cell offset vectors for long range

R_strained: Final[str] = (
    position + "_strained"
)  #: atom positions with strain-dependence
cell_strained: Final[str] = cell + "_strained"  #: unit cell with strain-dependence

n_nbh: Final[str] = "_n_nbh"  #: number of neighbors

#: indices of center atom triples
idx_i_triples: Final[str] = "_idx_i_triples"

#: indices of first neighboring atom triples
idx_j_triples: Final[str] = "_idx_j_triples"

#: indices of second neighboring atom triples
idx_k_triples: Final[str] = "_idx_k_triples"

## chemical properties
# Keys under which the corresponding properties are stored/accessed.
energy: Final[str] = "energy"
forces: Final[str] = "forces"
stress: Final[str] = "stress"
masses: Final[str] = "masses"
dipole_moment: Final[str] = "dipole_moment"
polarizability: Final[str] = "polarizability"
hessian: Final[str] = "hessian"
dipole_derivatives: Final[str] = "dipole_derivatives"
polarizability_derivatives: Final[str] = "polarizability_derivatives"
total_charge: Final[str] = "total_charge"
partial_charges: Final[str] = "partial_charges"
spin_multiplicity: Final[str] = "spin_multiplicity"
electric_field: Final[str] = "electric_field"
magnetic_field: Final[str] = "magnetic_field"
nuclear_magnetic_moments: Final[str] = "nuclear_magnetic_moments"
shielding: Final[str] = "shielding"
nuclear_spin_coupling: Final[str] = "nuclear_spin_coupling"

## external fields needed for different response properties
#: maps a response property key to the list of external field keys it requires
required_external_fields = {
    dipole_moment: [electric_field],
    dipole_derivatives: [electric_field],
    partial_charges: [electric_field],
    polarizability: [electric_field],
    polarizability_derivatives: [electric_field],
    shielding: [magnetic_field],
    nuclear_spin_coupling: [magnetic_field],
}

# --------------------------------------------------------------------------------
# /src/schnetpack/representation/__init__.py:
# --------------------------------------------------------------------------------
# re-export the public API of all representation submodules
from .schnet import *
from .painn import *
from .field_schnet import *
from .so3net import *

# --------------------------------------------------------------------------------
# /src/schnetpack/train/__init__.py:
# --------------------------------------------------------------------------------
# re-export the public API of the training utilities
from .callbacks import *
from .lr_scheduler import *

-------------------------------------------------------------------------------- /src/schnetpack/train/callbacks.py: -------------------------------------------------------------------------------- 1 | from copy import copy 2 | from typing import Dict 3 | 4 | from pytorch_lightning.callbacks import Callback 5 | from pytorch_lightning.callbacks import ModelCheckpoint as BaseModelCheckpoint 6 | 7 | from torch_ema import ExponentialMovingAverage as EMA 8 | 9 | import torch 10 | import os 11 | from pytorch_lightning.callbacks import BasePredictionWriter 12 | from typing import List, Any 13 | from schnetpack.task import AtomisticTask 14 | from schnetpack import properties 15 | from collections import defaultdict 16 | 17 | 18 | __all__ = ["ModelCheckpoint", "PredictionWriter", "ExponentialMovingAverage"] 19 | 20 | 21 | class PredictionWriter(BasePredictionWriter): 22 | """ 23 | Callback to store prediction results using ``torch.save``. 24 | """ 25 | 26 | def __init__( 27 | self, 28 | output_dir: str, 29 | write_interval: str, 30 | write_idx: bool = False, 31 | ): 32 | """ 33 | Args: 34 | output_dir: output directory for prediction files 35 | write_interval: can be one of ["batch", "epoch", "batch_and_epoch"] 36 | write_idx: Write molecular ids for all atoms. This is needed for 37 | atomic properties like forces. 
38 | """ 39 | super().__init__(write_interval) 40 | self.output_dir = output_dir 41 | self.write_idx = write_idx 42 | os.makedirs(output_dir, exist_ok=True) 43 | 44 | def write_on_batch_end( 45 | self, 46 | trainer, 47 | pl_module: AtomisticTask, 48 | prediction: Any, 49 | batch_indices: List[int], 50 | batch: Any, 51 | batch_idx: int, 52 | dataloader_idx: int, 53 | ): 54 | bdir = os.path.join(self.output_dir, str(dataloader_idx)) 55 | os.makedirs(bdir, exist_ok=True) 56 | torch.save(prediction, os.path.join(bdir, f"{batch_idx}.pt")) 57 | 58 | def write_on_epoch_end( 59 | self, 60 | trainer, 61 | pl_module: AtomisticTask, 62 | predictions: List[Any], 63 | batch_indices: List[Any], 64 | ): 65 | # collect batches of predictions and restructure 66 | concatenated_predictions = defaultdict(list) 67 | for batch_prediction in predictions[0]: 68 | for property_name, data in batch_prediction.items(): 69 | if not self.write_idx and property_name == properties.idx_m: 70 | continue 71 | concatenated_predictions[property_name].append(data) 72 | concatenated_predictions = { 73 | property_name: torch.concat(data) 74 | for property_name, data in concatenated_predictions.items() 75 | } 76 | 77 | # save concatenated predictions 78 | torch.save( 79 | concatenated_predictions, 80 | os.path.join(self.output_dir, "predictions.pt"), 81 | ) 82 | 83 | 84 | class ModelCheckpoint(BaseModelCheckpoint): 85 | """ 86 | Like the PyTorch Lightning ModelCheckpoint callback, 87 | but also saves the best inference model with activated post-processing 88 | """ 89 | 90 | def __init__(self, model_path: str, do_postprocessing=True, *args, **kwargs): 91 | super().__init__(*args, **kwargs) 92 | self.model_path = model_path 93 | self.do_postprocessing = do_postprocessing 94 | 95 | def on_validation_end(self, trainer, pl_module: AtomisticTask) -> None: 96 | self.trainer = trainer 97 | self.task = pl_module 98 | super().on_validation_end(trainer, pl_module) 99 | 100 | def _update_best_and_save( 101 | self, 
current: torch.Tensor, trainer, monitor_candidates: Dict[str, Any] 102 | ): 103 | # save model checkpoint 104 | super()._update_best_and_save(current, trainer, monitor_candidates) 105 | 106 | # save best inference model 107 | if isinstance(current, torch.Tensor) and torch.isnan(current): 108 | current = torch.tensor(float("inf" if self.mode == "min" else "-inf")) 109 | 110 | if current == self.best_model_score: 111 | if self.trainer.strategy.local_rank == 0: 112 | # remove references to trainer and data loaders to avoid pickle error in ddp 113 | self.task.save_model(self.model_path, do_postprocessing=True) 114 | 115 | 116 | class ExponentialMovingAverage(Callback): 117 | def __init__(self, decay, *args, **kwargs): 118 | self.decay = decay 119 | self.ema = None 120 | self._to_load = None 121 | 122 | def on_fit_start(self, trainer, pl_module: AtomisticTask): 123 | if self.ema is None: 124 | self.ema = EMA(pl_module.model.parameters(), decay=self.decay) 125 | if self._to_load is not None: 126 | self.ema.load_state_dict(self._to_load) 127 | self._to_load = None 128 | 129 | # load average parameters, to have same starting point as after validation 130 | self.ema.store() 131 | self.ema.copy_to() 132 | 133 | def on_train_epoch_start( 134 | self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" 135 | ) -> None: 136 | self.ema.restore() 137 | 138 | def on_train_batch_end(self, trainer, pl_module: AtomisticTask, *args, **kwargs): 139 | self.ema.update() 140 | 141 | def on_validation_epoch_start( 142 | self, trainer: "pl.Trainer", pl_module: AtomisticTask, *args, **kwargs 143 | ): 144 | self.ema.store() 145 | self.ema.copy_to() 146 | 147 | def load_state_dict(self, state_dict): 148 | if "ema" in state_dict: 149 | if self.ema is None: 150 | self._to_load = state_dict["ema"] 151 | else: 152 | self.ema.load_state_dict(state_dict["ema"]) 153 | 154 | def state_dict(self): 155 | return {"ema": self.ema.state_dict()} 156 | 
-------------------------------------------------------------------------------- /src/schnetpack/train/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | __all__ = ["ReduceLROnPlateau"] 4 | 5 | 6 | class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau): 7 | """ 8 | Extends PyTorch ReduceLROnPlateau by exponential smoothing of the monitored metric 9 | 10 | """ 11 | 12 | def __init__( 13 | self, 14 | optimizer, 15 | mode="min", 16 | factor=0.1, 17 | patience=10, 18 | threshold=1e-4, 19 | threshold_mode="rel", 20 | cooldown=0, 21 | min_lr=0, 22 | eps=1e-8, 23 | verbose=False, 24 | smoothing_factor=0.0, 25 | ): 26 | """ 27 | Args: 28 | optimizer (Optimizer): Wrapped optimizer. 29 | mode (str): One of `min`, `max`. In `min` mode, lr will 30 | be reduced when the quantity monitored has stopped 31 | decreasing; in `max` mode it will be reduced when the 32 | quantity monitored has stopped increasing. Default: 'min'. 33 | factor (float): Factor by which the learning rate will be 34 | reduced. new_lr = lr * factor. Default: 0.1. 35 | patience (int): Number of epochs with no improvement after 36 | which learning rate will be reduced. For example, if 37 | `patience = 2`, then we will ignore the first 2 epochs 38 | with no improvement, and will only decrease the LR after the 39 | 3rd epoch if the loss still hasn't improved then. 40 | Default: 10. 41 | threshold (float): Threshold for measuring the new optimum, 42 | to only focus on significant changes. Default: 1e-4. 43 | threshold_mode (str): One of `rel`, `abs`. In `rel` mode, 44 | dynamic_threshold = best * ( 1 + threshold ) in 'max' 45 | mode or best * ( 1 - threshold ) in `min` mode. 46 | In `abs` mode, dynamic_threshold = best + threshold in 47 | `max` mode or best - threshold in `min` mode. Default: 'rel'. 48 | cooldown (int): Number of epochs to wait before resuming 49 | normal operation after lr has been reduced. Default: 0. 
50 | min_lr (float or list): A scalar or a list of scalars. A 51 | lower bound on the learning rate of all param groups 52 | or each group respectively. Default: 0. 53 | eps (float): Minimal decay applied to lr. If the difference 54 | between new and old lr is smaller than eps, the update is 55 | ignored. Default: 1e-8. 56 | verbose (bool): If ``True``, prints a message to stdout for 57 | each update. Default: ``False``. 58 | smoothing_factor: smoothing_factor of exponential moving average 59 | """ 60 | super().__init__( 61 | optimizer=optimizer, 62 | mode=mode, 63 | factor=factor, 64 | patience=patience, 65 | threshold=threshold, 66 | threshold_mode=threshold_mode, 67 | cooldown=cooldown, 68 | min_lr=min_lr, 69 | eps=eps, 70 | verbose=verbose, 71 | ) 72 | self.smoothing_factor = smoothing_factor 73 | self.ema_loss = None 74 | 75 | def step(self, metrics, epoch=None): 76 | current = float(metrics) 77 | if self.ema_loss is None: 78 | self.ema_loss = current 79 | else: 80 | self.ema_loss = ( 81 | self.smoothing_factor * self.ema_loss 82 | + (1.0 - self.smoothing_factor) * current 83 | ) 84 | super().step(current, epoch) 85 | -------------------------------------------------------------------------------- /src/schnetpack/train/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchmetrics import Metric 3 | from torchmetrics.functional.regression.mae import ( 4 | _mean_absolute_error_compute, 5 | _mean_absolute_error_update, 6 | ) 7 | 8 | from typing import Optional, Tuple 9 | 10 | __all__ = ["TensorDiagonalMeanAbsoluteError"] 11 | 12 | 13 | class TensorDiagonalMeanAbsoluteError(Metric): 14 | """ 15 | Custom torch metric for monitoring the mean absolute error on the diagonals and offdiagonals of tensors, e.g. 16 | polarizability. 
17 | """ 18 | 19 | is_differentiable = True 20 | higher_is_better = False 21 | sum_abs_error: torch.Tensor 22 | total: torch.Tensor 23 | 24 | def __init__( 25 | self, 26 | diagonal: Optional[bool] = True, 27 | diagonal_dims: Optional[Tuple[int, int]] = (-2, -1), 28 | dist_sync_on_step=False, 29 | ) -> None: 30 | """ 31 | 32 | Args: 33 | diagonal (bool): If true, diagonal values are used, if False off-diagonal. 34 | diagonal_dims (tuple(int,int)): axes of the square matrix for which the diagonals should be considered. 35 | dist_sync_on_step (bool): synchronize. 36 | """ 37 | # call `self.add_state`for every internal state that is needed for the metrics computations 38 | # dist_reduce_fx indicates the function that should be used to reduce 39 | # state from multiple processes 40 | super().__init__(dist_sync_on_step=dist_sync_on_step) 41 | 42 | self.add_state("sum_abs_error", default=torch.tensor(0.0), dist_reduce_fx="sum") 43 | self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") 44 | 45 | self.diagonal = diagonal 46 | self.diagonal_dims = diagonal_dims 47 | self._diagonal_mask = None 48 | 49 | def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: 50 | """ 51 | Update the metric. 52 | 53 | Args: 54 | preds (torch.Tensor): network predictions. 55 | target (torch.Tensor): reference values. 56 | """ 57 | # update metric states 58 | preds = self._input_format(preds) 59 | target = self._input_format(target) 60 | 61 | sum_abs_error, n_obs = _mean_absolute_error_update(preds, target) 62 | 63 | self.sum_abs_error += sum_abs_error 64 | self.total += n_obs 65 | 66 | def compute(self) -> torch.Tensor: 67 | """ 68 | Compute the final metric. 69 | 70 | Returns: 71 | torch.Tensor: mean absolute error of diagonal or offdiagonal elements. 
72 | """ 73 | # compute final result 74 | return _mean_absolute_error_compute(self.sum_abs_error, self.total) 75 | 76 | def _input_format(self, x) -> torch.Tensor: 77 | """ 78 | Extract diagonal / offdiagonal elements from input tensor. 79 | 80 | Args: 81 | x (torch.Tensor): input tensor. 82 | 83 | Returns: 84 | torch.Tensor: extracted and flattened elements (diagonal / offdiagonal) 85 | """ 86 | if self._diagonal_mask is None: 87 | self._diagonal_mask = self._init_diag_mask(x) 88 | return x.masked_select(self._diagonal_mask) 89 | 90 | def _init_diag_mask(self, x: torch.Tensor) -> torch.Tensor: 91 | """ 92 | Initialize the mask for extracting the diagonal elements based on the given axes and the shape of the 93 | input tensor. 94 | 95 | Args: 96 | x (torch.Tensor): input tensor. 97 | 98 | Returns: 99 | torch.Tensor: Boolean diagonal mask. 100 | """ 101 | tensor_shape = x.shape 102 | dim_0 = tensor_shape[self.diagonal_dims[0]] 103 | dim_1 = tensor_shape[self.diagonal_dims[1]] 104 | 105 | if not dim_0 == dim_1: 106 | raise AssertionError( 107 | "Found different size for diagonal dimensions, expected square sub matrix." 108 | ) 109 | 110 | view = [1 for _ in tensor_shape] 111 | view[self.diagonal_dims[0]] = dim_0 112 | view[self.diagonal_dims[1]] = dim_1 113 | 114 | diag_mask = torch.eye(dim_0, device=x.device, dtype=torch.long).view(view) 115 | 116 | if self.diagonal: 117 | return diag_mask == 1 118 | else: 119 | return diag_mask != 1 120 | -------------------------------------------------------------------------------- /src/schnetpack/transform/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Transforms are applied before and/or after the model. They can be used, e.g., for calculating 3 | neighbor lists, casting, unit conversion or data augmentation. Some can applied before batching, 4 | i.e. to single systems, when loading the data. 
This is necessary for pre-processing and includes 5 | neighbor lists, for example. On the other hand, transforms need to be able to handle batches 6 | for post-processing. The flags `is_postprocessor` and `is_preprocessor` indicate how the tranforms 7 | may be used. The attribute `mode` of a transform is set automatically to either "pre" or "post".q 8 | """ 9 | 10 | from .atomistic import * 11 | from .casting import * 12 | from .neighborlist import * 13 | from .response import * 14 | from .base import * 15 | -------------------------------------------------------------------------------- /src/schnetpack/transform/base.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Dict 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | import schnetpack as spk 7 | 8 | __all__ = [ 9 | "Transform", 10 | "TransformException", 11 | ] 12 | 13 | 14 | class TransformException(Exception): 15 | pass 16 | 17 | 18 | class Transform(nn.Module): 19 | """ 20 | Base class for all transforms. 21 | The base class ensures that the reference to the data and datamodule attributes are 22 | initialized. 23 | Transforms can be used as pre- or post-processing layers. 24 | They can also be used for other parts of a model, that need to be 25 | initialized based on data. 26 | 27 | To implement a new transform, override the forward method. Preprocessors are applied 28 | to single examples, while postprocessors operate on batches. All transforms should 29 | return a modified `inputs` dictionary. 30 | 31 | """ 32 | 33 | def datamodule(self, value): 34 | """ 35 | Extract all required information from data module automatically when using 36 | PyTorch Lightning integration. The transform should also implement a way to 37 | set these things manually, to make it usable independent of PL. 38 | 39 | Do not store the datamodule, as this does not work with torchscript conversion! 
40 | """ 41 | pass 42 | 43 | def forward( 44 | self, 45 | inputs: Dict[str, torch.Tensor], 46 | ) -> Dict[str, torch.Tensor]: 47 | raise NotImplementedError 48 | 49 | def teardown(self): 50 | pass 51 | -------------------------------------------------------------------------------- /src/schnetpack/transform/casting.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | from typing import Dict, Optional 3 | from schnetpack.utils import as_dtype 4 | 5 | import torch 6 | 7 | from .base import Transform 8 | 9 | __all__ = ["CastMap", "CastTo32", "CastTo64"] 10 | 11 | 12 | class CastMap(Transform): 13 | """ 14 | Cast all inputs according to type map. 15 | """ 16 | 17 | is_preprocessor: bool = True 18 | is_postprocessor: bool = True 19 | 20 | def __init__(self, type_map: Dict[str, str]): 21 | """ 22 | Args: 23 | type_map: dict with source_type: target_type (as strings) 24 | """ 25 | super().__init__() 26 | self.type_map = type_map 27 | 28 | def forward( 29 | self, 30 | inputs: Dict[str, torch.Tensor], 31 | ) -> Dict[str, torch.Tensor]: 32 | for k, v in inputs.items(): 33 | vdtype = str(v.dtype).split(".")[-1] 34 | if vdtype in self.type_map: 35 | inputs[k] = v.to(dtype=as_dtype(self.type_map[vdtype])) 36 | return inputs 37 | 38 | 39 | class CastTo32(CastMap): 40 | """Cast all float64 tensors to float32""" 41 | 42 | def __init__(self): 43 | super().__init__(type_map={"float64": "float32"}) 44 | 45 | 46 | class CastTo64(CastMap): 47 | """Cast all float32 tensors to float64""" 48 | 49 | def __init__(self): 50 | super().__init__(type_map={"float32": "float64"}) 51 | -------------------------------------------------------------------------------- /src/schnetpack/transform/response.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from schnetpack.transform.base import Transform 4 | from schnetpack import properties 5 | 6 | from typing import Dict, List 7 | 8 
| __all__ = ["SplitShielding"] 9 | 10 | 11 | class SplitShielding(Transform): 12 | """ 13 | Transform for splitting shielding tensors by atom types. 14 | """ 15 | 16 | is_preprocessor: bool = True 17 | is_postprocessor: bool = False 18 | 19 | def __init__( 20 | self, 21 | shielding_key: str, 22 | atomic_numbers: List[int], 23 | ): 24 | """ 25 | Args: 26 | shielding_key (str): name of the shielding tensor in the model inputs. 27 | atomic_numbers (list(int)): list of atomic numbers used to split the shielding tensor. 28 | """ 29 | super(SplitShielding, self).__init__() 30 | 31 | self.shielding_key = shielding_key 32 | self.atomic_numbers = atomic_numbers 33 | 34 | self.model_outputs = [ 35 | "{:s}_{:d}".format(self.shielding_key, atomic_number) 36 | for atomic_number in self.atomic_numbers 37 | ] 38 | 39 | def forward( 40 | self, 41 | inputs: Dict[str, torch.Tensor], 42 | ) -> Dict[str, torch.Tensor]: 43 | shielding = inputs[self.shielding_key] 44 | 45 | split_shielding = {} 46 | for atomic_number in self.atomic_numbers: 47 | atomic_key = "{:s}_{:d}".format(self.shielding_key, atomic_number) 48 | split_shielding[atomic_key] = shielding[ 49 | inputs[properties.Z] == atomic_number, :, : 50 | ] 51 | 52 | inputs.update(split_shielding) 53 | 54 | return inputs 55 | -------------------------------------------------------------------------------- /src/schnetpack/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .compatibility import * 2 | import importlib 3 | import torch 4 | from typing import Type, Union, List 5 | 6 | from schnetpack import properties as spk_properties 7 | 8 | TORCH_DTYPES = { 9 | "float32": torch.float32, 10 | "float64": torch.float64, 11 | "float": torch.float, 12 | "float16": torch.float16, 13 | "bfloat16": torch.bfloat16, 14 | "half": torch.half, 15 | "uint8": torch.uint8, 16 | "int8": torch.int8, 17 | "int16": torch.int16, 18 | "short": torch.short, 19 | "int32": torch.int32, 20 | "int": 
torch.int, 21 | "int64": torch.int64, 22 | "long": torch.long, 23 | "complex64": torch.complex64, 24 | "cfloat": torch.cfloat, 25 | "complex128": torch.complex128, 26 | "cdouble": torch.cdouble, 27 | "quint8": torch.quint8, 28 | "qint8": torch.qint8, 29 | "qint32": torch.qint32, 30 | "bool": torch.bool, 31 | } 32 | 33 | TORCH_DTYPES.update({"torch." + k: v for k, v in TORCH_DTYPES.items()}) 34 | 35 | 36 | def as_dtype(dtype_str: str) -> torch.dtype: 37 | """Convert a string to torch.dtype""" 38 | return TORCH_DTYPES[dtype_str] 39 | 40 | 41 | def int2precision(precision: Union[int, torch.dtype]): 42 | """ 43 | Get torch floating point precision from integer. 44 | If an instance of torch.dtype is passed, it is returned automatically. 45 | 46 | Args: 47 | precision (int, torch.dtype): Target precision. 48 | 49 | Returns: 50 | torch.dtype: Floating point precision. 51 | """ 52 | if isinstance(precision, torch.dtype): 53 | return precision 54 | else: 55 | try: 56 | return getattr(torch, f"float{precision}") 57 | except AttributeError: 58 | raise AttributeError(f"Unknown float precision {precision}") 59 | 60 | 61 | def str2class(class_path: str) -> Type: 62 | """ 63 | Obtain a class type from a string 64 | 65 | Args: 66 | class_path: module path to class, e.g. ``module.submodule.classname`` 67 | 68 | Returns: 69 | class type 70 | """ 71 | class_path = class_path.split(".") 72 | class_name = class_path[-1] 73 | module_name = ".".join(class_path[:-1]) 74 | cls = getattr(importlib.import_module(module_name), class_name) 75 | return cls 76 | 77 | 78 | def required_fields_from_properties(properties: List[str]) -> List[str]: 79 | """ 80 | Determine required external fields based on the response properties to be computed. 81 | 82 | Args: 83 | properties (list(str)): List of response properties for which external fields should be determined. 84 | 85 | Returns: 86 | list(str): List of required external fields. 
87 | """ 88 | required_fields = set() 89 | 90 | for p in properties: 91 | if p in spk_properties.required_external_fields: 92 | required_fields.update(spk_properties.required_external_fields[p]) 93 | 94 | required_fields = list(required_fields) 95 | 96 | return required_fields 97 | -------------------------------------------------------------------------------- /src/schnetpack/utils/compatibility.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import warnings 3 | from typing import Any, Union 4 | 5 | 6 | __all__ = ["load_model"] 7 | 8 | 9 | def load_model( 10 | model_path: str, device: Union[torch.device, str] = "cpu", **kwargs: Any 11 | ) -> torch.nn.Module: 12 | """ 13 | Load a SchNetPack model from a Torch file, enabling compatibility with models trained using earlier versions of 14 | SchNetPack. This function imports the old model and automatically updates it to the format used in the current 15 | SchNetPack version. To ensure proper functionality, the Torch model object must include a version tag, such as 16 | spk_version="2.0.4". 17 | 18 | Args: 19 | model_path (str): Path to the saved model file. 20 | device (torch.device or str): Device on which to load the model. Defaults to "cpu". 21 | **kwargs (Any): Additional keyword arguments for `torch.load`. 22 | 23 | Returns: 24 | torch.nn.Module: Loaded model. 
25 | """ 26 | 27 | def _convert_from_older(model: torch.nn.Module) -> torch.nn.Module: 28 | model.spk_version = "2.0.4" 29 | return model 30 | 31 | def _convert_from_v2_0_4(model: torch.nn.Module) -> torch.nn.Module: 32 | if not hasattr(model.representation, "electronic_embeddings"): 33 | model.representation.electronic_embeddings = [] 34 | model.spk_version = "2.1.0" 35 | return model 36 | 37 | model = torch.load(model_path, map_location=device, weights_only=False, **kwargs) 38 | 39 | if not hasattr(model, "spk_version"): 40 | # make warning that model has no version information 41 | warnings.warn( 42 | "Model was saved without version information. Conversion to current version may fail." 43 | ) 44 | model = _convert_from_older(model) 45 | 46 | if model.spk_version == "2.0.4": 47 | model = _convert_from_v2_0_4(model) 48 | 49 | return model 50 | -------------------------------------------------------------------------------- /src/schnetpack/utils/script.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Dict, Sequence 2 | 3 | import pytorch_lightning as pl 4 | import rich 5 | import yaml 6 | from omegaconf import DictConfig, OmegaConf 7 | from pytorch_lightning.utilities import rank_zero_only 8 | from rich.syntax import Syntax 9 | from rich.tree import Tree 10 | 11 | 12 | __all__ = ["log_hyperparameters", "print_config"] 13 | 14 | 15 | def empty(*args, **kwargs): 16 | pass 17 | 18 | 19 | def todict(config: Union[DictConfig, Dict]): 20 | config_dict = yaml.safe_load(OmegaConf.to_yaml(config, resolve=True)) 21 | return config_dict 22 | 23 | 24 | @rank_zero_only 25 | def log_hyperparameters( 26 | config: DictConfig, 27 | model: pl.LightningModule, 28 | trainer: pl.Trainer, 29 | ) -> None: 30 | """ 31 | This saves Hydra config using Lightning loggers. 
32 | """ 33 | 34 | # send hparams to all loggers 35 | trainer.logger.log_hyperparams(config) 36 | 37 | # disable logging any more hyperparameters for all loggers 38 | trainer.logger.log_hyperparams = empty 39 | 40 | 41 | @rank_zero_only 42 | def print_config( 43 | config: DictConfig, 44 | fields: Sequence[str] = ( 45 | "run", 46 | "globals", 47 | "data", 48 | "model", 49 | "task", 50 | "trainer", 51 | "callbacks", 52 | "logger", 53 | "seed", 54 | ), 55 | resolve: bool = True, 56 | ) -> None: 57 | """Prints content of DictConfig using Rich library and its tree structure. 58 | 59 | Args: 60 | config (DictConfig): Config. 61 | fields (Sequence[str], optional): Determines which main fields from config will be printed 62 | and in what order. 63 | resolve (bool, optional): Whether to resolve reference fields of DictConfig. 64 | """ 65 | 66 | style = "dim" 67 | tree = Tree( 68 | f":gear: Running with the following config:", style=style, guide_style=style 69 | ) 70 | 71 | for field in fields: 72 | branch = tree.add(field, style=style, guide_style=style) 73 | 74 | config_section = config.get(field) 75 | branch_content = str(config_section) 76 | if isinstance(config_section, DictConfig): 77 | branch_content = OmegaConf.to_yaml(config_section, resolve=resolve) 78 | 79 | branch.add(Syntax(branch_content, "yaml")) 80 | 81 | rich.print(tree) 82 | -------------------------------------------------------------------------------- /src/scripts/spkconvert: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | 4 | from ase.db import connect 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | 9 | if __name__ == "__main__": 10 | parser = argparse.ArgumentParser( 11 | description="Set units of an ASE dataset, e.g. to convert from SchNetPack 1.0 to the new format." 
12 | ) 13 | parser.add_argument( 14 | "data_path", 15 | help="Path to ASE DB dataset", 16 | ) 17 | parser.add_argument( 18 | "--distunit", 19 | help="Distance unit as string, corresponding to ASE units (e.g. `Ang`)", 20 | ) 21 | parser.add_argument( 22 | "--propunit", 23 | help="Property units as string, corresponding " 24 | "to ASE units (e.g. `kcal/mol/Ang`), in the form: `property1:unit1,property2:unit2`", 25 | ) 26 | parser.add_argument( 27 | "--expand_property_dims", 28 | default=[], 29 | nargs='+', 30 | help="Expanding the first dimension of the given property " 31 | "(required for example for old FieldSchNet datasets). " 32 | "Add property names here in the form 'property1 property2 property3'", 33 | ) 34 | args = parser.parse_args() 35 | with connect(args.data_path) as db: 36 | meta = db.metadata 37 | print(meta) 38 | 39 | if "atomrefs" not in meta.keys(): 40 | meta["atomrefs"] = {} 41 | elif "atref_labels" in meta.keys(): 42 | old_atref = np.array(meta["atomrefs"]) 43 | new_atomrefs = {} 44 | labels = meta["atref_labels"] 45 | if type(labels) is str: 46 | labels = [labels] 47 | for i, label in enumerate(labels): 48 | print(i, label, old_atref[:, i]) 49 | new_atomrefs[label] = list(old_atref[:, i]) 50 | meta["atomrefs"] = new_atomrefs 51 | del meta["atref_labels"] 52 | 53 | if args.distunit: 54 | if args.distunit == "A": 55 | raise ValueError( 56 | "The provided unit (A for Ampere) is not a valid distance unit according to the ASE unit" 57 | " definitions. You probably mean `Ang`/`Angstrom`. Please also check your property units!" 
58 | ) 59 | meta["_distance_unit"] = args.distunit 60 | 61 | if args.propunit: 62 | if "_property_unit_dict" not in meta.keys(): 63 | meta["_property_unit_dict"] = {} 64 | 65 | for p in args.propunit.split(","): 66 | prop, unit = p.split(":") 67 | meta["_property_unit_dict"][prop] = unit 68 | 69 | with connect(args.data_path) as db: 70 | db.metadata = meta 71 | 72 | if args.expand_property_dims is not None and len(args.expand_property_dims) > 0: 73 | 74 | with connect(args.data_path) as db: 75 | for i in tqdm(range(len(db))): 76 | atoms_row = db.get(i + 1) 77 | data = {} 78 | for p, v in atoms_row.data.items(): 79 | if p in args.expand_property_dims: 80 | data[p] = np.expand_dims(v, 0) 81 | else: 82 | data[p] = v 83 | db.update(i + 1, data=data) -------------------------------------------------------------------------------- /src/scripts/spkdeploy: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import torch 3 | import torch.nn as nn 4 | from schnetpack.transform import CastTo64, CastTo32, AddOffsets 5 | import argparse 6 | 7 | 8 | # This script is supposed to take a pytorch model and save a just in time compiled version of it. 9 | # This is needed to run the model with LAMMPS. 10 | # For further info see examples/howtos/lammps.rst 11 | 12 | # Note that this script is designed for models that predict atomic forces via automatic differentiation (utilizing response modules). 13 | # Hence this script will not work for models without response modules. 
14 | 15 | 16 | def get_jit_model(model): 17 | # fix invalid operations in postprocessing 18 | jit_postprocessors = nn.ModuleList() 19 | for postprocessor in model.postprocessors: 20 | # ignore type casting 21 | if type(postprocessor) in [CastTo64, CastTo32]: 22 | continue 23 | # ensure offset mean is float 24 | if type(postprocessor) == AddOffsets: 25 | postprocessor.mean = postprocessor.mean.float() 26 | 27 | jit_postprocessors.append(postprocessor) 28 | model.postprocessors = jit_postprocessors 29 | 30 | return torch.jit.script(model) 31 | 32 | 33 | def save_jit_model(model, model_path): 34 | jit_model = get_jit_model(model) 35 | 36 | # add metadata 37 | metadata = dict() 38 | metadata["cutoff"] = str(jit_model.representation.cutoff.item()).encode("ascii") 39 | 40 | torch.jit.save(jit_model, model_path, _extra_files=metadata) 41 | 42 | 43 | if __name__ == "__main__": 44 | 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("model_path") 47 | parser.add_argument("deployed_model_path") 48 | parser.add_argument("--device", type=str, default="cpu") 49 | args = parser.parse_args() 50 | 51 | model = torch.load(args.model_path, map_location=args.device) 52 | save_jit_model(model, args.deployed_model_path) 53 | 54 | print(f"stored deployed model at {args.deployed_model_path}.") 55 | -------------------------------------------------------------------------------- /src/scripts/spkmd: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import schnetpack.md.cli as cli 3 | 4 | if __name__ == "__main__": 5 | cli.simulate() 6 | -------------------------------------------------------------------------------- /src/scripts/spkpredict: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import schnetpack.cli as cli 3 | 4 | if __name__ == "__main__": 5 | cli.predict() 6 | -------------------------------------------------------------------------------- 
/src/scripts/spktrain: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import schnetpack.cli as cli 3 | 4 | if __name__ == "__main__": 5 | cli.train() 6 | -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | # Test suite 2 | 3 | SchNetPack has unit tests that you can run to make sure that everything is working properly. 4 | 5 | ## Test dependencies 6 | To install the additional test dependencies, install with: 7 | ``` 8 | pip install schnetpack[test] 9 | ``` 10 | 11 | Or, if you installed from source and assuming you are in your local copy of the SchNetPack repository, 12 | ``` 13 | pip install .[test] 14 | ``` 15 | 16 | ## Run tests 17 | In order to run the tests, run the following command from the root of the repository: 18 | ``` 19 | pytest tests 20 | ``` 21 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atomistic-machine-learning/schnetpack/601c8edb6cdf35af0b2ebf27b4b9e39386b79fb9/tests/__init__.py -------------------------------------------------------------------------------- /tests/atomistic/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atomistic-machine-learning/schnetpack/601c8edb6cdf35af0b2ebf27b4b9e39386b79fb9/tests/atomistic/__init__.py -------------------------------------------------------------------------------- /tests/atomistic/test_response.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from schnetpack.data.loader import _atoms_collate_fn 5 | import schnetpack as spk 6 | 7 | 8 | def test_strain(environment_periodic):  # Strain must leave the offsets of a 
collated batch (two copies of the periodic environment) unchanged 9 |
cutoff, props, neighbors = environment_periodic 10 | props.update(neighbors) 11 | batch = _atoms_collate_fn([props, props]) 12 | strained_batch = spk.atomistic.Strain()(batch) 13 | assert np.allclose( 14 | batch[spk.properties.offsets].detach().numpy(), 15 | strained_batch[spk.properties.offsets].detach().numpy(), 16 | ) 17 | -------------------------------------------------------------------------------- /tests/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atomistic-machine-learning/schnetpack/601c8edb6cdf35af0b2ebf27b4b9e39386b79fb9/tests/data/__init__.py -------------------------------------------------------------------------------- /tests/data/conftest.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atomistic-machine-learning/schnetpack/601c8edb6cdf35af0b2ebf27b4b9e39386b79fb9/tests/data/conftest.py -------------------------------------------------------------------------------- /tests/data/test_data.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import numpy as np 4 | from schnetpack.data import * 5 | import os 6 | import schnetpack.properties as structure 7 | from schnetpack.data import calculate_stats, AtomsLoader 8 | 9 | 10 | @pytest.fixture 11 | def asedbpath(tmpdir):  # path to test.db inside pytest's tmpdir 12 | return os.path.join(tmpdir, "test.db") 13 | 14 | 15 | @pytest.fixture(scope="function") 16 | def asedb(asedbpath, example_data, property_units):  # ASEAtomsData filled with all example_data systems; its file is removed after the test 17 | 18 | asedb = ASEAtomsData.create( 19 | datapath=asedbpath, distance_unit="A", property_unit_dict=property_units 20 | ) 21 | 22 | atoms_list, prop_list = zip(*example_data) 23 | asedb.add_systems(property_list=prop_list, atoms_list=atoms_list) 24 | yield asedb 25 | 26 | os.remove(asedb.datapath) 27 | del 
asedb 28 | 29 | 30 | def test_asedb(asedb, example_data):  # length, unit metadata, returned keys under load_properties / load_structure, and update_metadata 31 | assert os.path.exists(asedb.datapath) 32 |
assert len(example_data) == len(asedb) 33 | assert set(asedb.metadata["_property_unit_dict"].keys()) == set( 34 | asedb.available_properties 35 | ) 36 | assert asedb.metadata["_property_unit_dict"] == asedb.units 37 | 38 | props = asedb[0] 39 | assert set(props.keys()) == set( 40 | [ 41 | structure.Z, 42 | structure.R, 43 | structure.cell, 44 | structure.pbc, 45 | structure.n_atoms, 46 | structure.idx, 47 | ] 48 | + asedb.available_properties 49 | ) 50 | 51 | load_properties = asedb.available_properties[0:2] 52 | asedb.load_properties = load_properties 53 | props = asedb[0] 54 | assert set(props.keys()) == set( 55 | [ 56 | structure.Z, 57 | structure.R, 58 | structure.cell, 59 | structure.pbc, 60 | structure.n_atoms, 61 | structure.idx, 62 | ] 63 | + load_properties 64 | ) 65 | 66 | asedb.load_structure = False 67 | props = asedb[0] 68 | assert set(props.keys()) == set( 69 | [ 70 | structure.n_atoms, 71 | structure.idx, 72 | ] 73 | + load_properties 74 | ) 75 | 76 | asedb.update_metadata(test=1) 77 | assert asedb.metadata["test"] == 1 78 | 79 | 80 | def test_asedb_getprops(asedb):  # the first item of iter_properties(0) has the structure keys plus all available properties 81 | props = list(asedb.iter_properties(0))[0] 82 | assert set(props.keys()) == set( 83 | [ 84 | structure.Z, 85 | structure.R, 86 | structure.cell, 87 | structure.pbc, 88 | structure.n_atoms, 89 | structure.idx, 90 | ] 91 | + asedb.available_properties 92 | ) 93 | 94 | 95 | def test_stats():  # calculate_stats must give [0] == 0 and [1] == 1 (presumably mean and stddev) for batch sizes 1-6 96 | data = [] 97 | for i in range(6): 98 | Z = torch.tensor([1, 1, 1]) 99 | off = 1.0 if i % 2 == 0 else -1.0 100 | d = { 101 | structure.Z: Z, 102 | structure.n_atoms: torch.tensor([len(Z)]), 103 | "property1": torch.tensor([1.0 + len(Z) * off]), 104 | "property2": torch.tensor([off]), 105 | } 106 | data.append(d) 107 | 108 | atomref = {"property1": torch.ones((100,)) / 3.0} 109 | for bs in range(1, 7): 110 | stats = calculate_stats( 111 
| AtomsLoader(data, batch_size=bs), 112 | {"property1": True, "property2": False}, 113 | atomref=atomref, 114 | ) 115 | assert
np.allclose(stats["property1"][0].numpy(), np.array([0.0])) 116 | assert np.allclose(stats["property1"][1].numpy(), np.array([1.0])) 117 | assert np.allclose(stats["property2"][0].numpy(), np.array([0.0])) 118 | assert np.allclose(stats["property2"][1].numpy(), np.array([1.0])) 119 | 120 | 121 | def test_asedb_add(asedb, example_data):  # adding a system via atoms= or via raw structure keys must store identical tensors (apart from _idx) 122 | l = len(asedb)  # index of the first system added below 123 | 124 | at, props = example_data[0] 125 | asedb.add_system(atoms=at, **props) 126 | 127 | props.update( 128 | { 129 | structure.Z: at.numbers, 130 | structure.R: at.positions, 131 | structure.cell: at.cell, 132 | structure.pbc: at.pbc, 133 | } 134 | ) 135 | asedb.add_system(**props) 136 | 137 | p1 = asedb[l] 138 | p2 = asedb[l + 1] 139 | for k, v in p1.items(): 140 | if k != "_idx": 141 | assert type(v) == torch.Tensor, k 142 | assert (p2[k] == v).all(), v 143 | -------------------------------------------------------------------------------- /tests/data/test_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | import numpy as np 4 | 5 | from schnetpack.datasets import QM9, MD17, rMD17 6 | 7 | 8 | @pytest.fixture 9 | def test_qm9_path():  # path to the bundled testdata/test_qm9.db 10 | path = os.path.join(os.path.dirname(__file__), "../testdata/test_qm9.db") 11 | return path 12 | 13 | 14 | @pytest.mark.skip( 15 | "Run only local, not in CI.
Otherwise takes too long and requires downloading " 16 | + "the data" 17 | ) 18 | def test_qm9(test_qm9_path):  # split sizes and number of batches of the QM9 data module 19 | qm9 = QM9( 20 | test_qm9_path, 21 | num_train=10, 22 | num_val=5, 23 | batch_size=5, 24 | remove_uncharacterized=True, 25 | )  # NOTE(review): unlike test_md17, prepare_data()/setup() are not called before the datasets are read - confirm this is intended 26 | assert len(qm9.train_dataset) == 10 27 | assert len(qm9.val_dataset) == 5 28 | assert len(qm9.test_dataset) == 5 29 | 30 | ds = [b for b in qm9.train_dataloader()] 31 | assert len(ds) == 2 32 | ds = [b for b in qm9.val_dataloader()] 33 | assert len(ds) == 1 34 | ds = [b for b in qm9.test_dataloader()] 35 | assert len(ds) == 1 36 | 37 | 38 | @pytest.fixture 39 | def test_md17_path():  # database path under testdata/tmp 40 | path = os.path.join(os.path.dirname(__file__), "../testdata/tmp/test_md17.db") 41 | return path 42 | 43 | 44 | @pytest.mark.skip( 45 | "Run only local, not in CI. Otherwise takes too long and requires downloading " 46 | + "the data" 47 | ) 48 | def test_md17(test_md17_path):  # split sizes and number of batches for MD17 uracil 49 | md17 = MD17( 50 | test_md17_path, 51 | num_train=10, 52 | num_val=5, 53 | num_test=5, 54 | batch_size=5, 55 | molecule="uracil", 56 | ) 57 | md17.prepare_data() 58 | md17.setup() 59 | assert len(md17.train_dataset) == 10 60 | assert len(md17.val_dataset) == 5 61 | assert len(md17.test_dataset) == 5 62 | 63 | ds = [b for b in md17.train_dataloader()] 64 | assert len(ds) == 2 65 | ds = [b for b in md17.val_dataloader()] 66 | assert len(ds) == 1 67 | ds = [b for b in md17.test_dataloader()] 68 | assert len(ds) == 1 69 | 70 | 71 | @pytest.fixture 72 | def test_rmd17_path():  # database path under testdata/tmp 73 | path = os.path.join(os.path.dirname(__file__), "../testdata/tmp/test_rmd17.db") 74 | return path 75 | 76 | 77 | @pytest.mark.skip( 78 | "Run only local, not in CI.
Otherwise takes too long and requires downloading " 79 | + "the data" 80 | ) 81 | def test_rmd17(test_rmd17_path):  # split sizes and pairwise-disjoint train/val/test subset indices for rMD17 uracil 82 | md17 = rMD17( 83 | test_rmd17_path, 84 | num_train=950, 85 | num_val=50, 86 | num_test=1000, 87 | batch_size=5, 88 | molecule="uracil", 89 | ) 90 | md17.prepare_data() 91 | md17.setup() 92 | assert len(md17.train_dataset) == 950 93 | assert len(md17.val_dataset) == 50 94 | assert len(md17.test_dataset) == 1000 95 | 96 | train_idx = md17.train_dataset.subset_idx 97 | val_idx = md17.val_dataset.subset_idx 98 | test_idx = md17.test_dataset.subset_idx 99 | assert len(np.intersect1d(train_idx, val_idx)) == 0 100 | assert len(np.intersect1d(train_idx, test_idx)) == 0 101 | assert len(np.intersect1d(val_idx, test_idx)) == 0 102 | -------------------------------------------------------------------------------- /tests/data/test_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from schnetpack.transform import ASENeighborList 4 | from schnetpack.data.loader import _atoms_collate_fn 5 | import schnetpack.properties as structure 6 | 7 | 8 | def test_collate_noenv(single_atom, two_atoms):  # collating a 1-atom and a 2-atom system gives idx_m == (0, 1, 1) 9 | batch = [single_atom, two_atoms] 10 | collated_batch = _atoms_collate_fn(batch) 11 | assert all([key in collated_batch.keys() for key in single_atom]) 12 | assert structure.idx_m in collated_batch.keys() 13 | assert (collated_batch[structure.idx_m] == torch.tensor((0, 1, 1))).all() 14 | 15 | 16 | def test_collate_env(single_atom, two_atoms):  # as above after ASENeighborList; pair indices of the second system are shifted by one atom 17 | nll = ASENeighborList(cutoff=5.0) 18 | batch = [nll(single_atom), nll(two_atoms)] 19 | 20 | collated_batch = _atoms_collate_fn(batch) 21 | assert all([key in collated_batch.keys() for key in single_atom]) 22 | assert structure.idx_m in collated_batch.keys() 23 | 
assert (collated_batch[structure.idx_m] == torch.tensor((0, 1, 1))).all() 24 | assert ( 25 | collated_batch[structure.idx_i] == torch.tensor((1, 2)) 26 | ).all(), collated_batch[structure.idx_i] 27 |
assert ( 28 | collated_batch[structure.idx_j] == torch.tensor((2, 1)) 29 | ).all(), collated_batch[structure.idx_j] 30 | -------------------------------------------------------------------------------- /tests/data/test_transforms.py: -------------------------------------------------------------------------------- 1 | import schnetpack.properties as structure 2 | import pytest 3 | import torch 4 | from ase.data import atomic_masses 5 | from schnetpack.transform import * 6 | 7 | 8 | def assert_consistent(orig, transformed):  # every entry of orig must compare equal in transformed 9 | for k, v in orig.items(): 10 | assert (v == transformed[k]).all(), f"Changed value: {k}" 11 | 12 | 13 | @pytest.fixture(params=[0, 1]) 14 | def neighbor_list(request):  # parametrized over ASENeighborList and TorchNeighborList 15 | neighbor_lists = [ASENeighborList, TorchNeighborList] 16 | return neighbor_lists[request.param] 17 | 18 | 19 | class TestNeighborLists: 20 | """ 21 | Test for different neighbor lists defined in neighbor_list using the Argon environment fixtures (periodic and 22 | non-periodic). 23 | 24 | """ 25 | 26 | def test_neighbor_list(self, neighbor_list, environment):  # sorted idx_i, idx_j and Rij must match the reference neighbors 27 | cutoff, props, neighbors_ref = environment 28 | neighbor_list = neighbor_list(cutoff) 29 | neighbors = neighbor_list(props) 30 | R = props[structure.R] 31 | neighbors[structure.Rij] = ( 32 | R[neighbors[structure.idx_j]] 33 | - R[neighbors[structure.idx_i]] 34 | + props[structure.offsets]  # NOTE(review): offsets are read from props, not neighbors - presumably the transform updates props in place; confirm 35 | ) 36 | 37 | neighbors = self._sort_neighbors(neighbors) 38 | neighbors_ref = self._sort_neighbors(neighbors_ref) 39 | 40 | for nbl, nbl_ref in zip(neighbors, neighbors_ref): 41 | torch.testing.assert_close(nbl, nbl_ref) 42 | 43 | def _sort_neighbors(self, neighbors): 44 | """ 45 | Routine for sorting the index and distance vectors to allow comparison between different 46 | neighbor list 
implementations.
47 | 48 | Args: 49 | neighbors: Input dictionary holding system neighbor information (only idx_i, idx_j and Rij are read) 50 | 51 | Returns: 52 | torch.LongTensor: indices of central atoms in each pair 53 | torch.LongTensor: indices of each neighbor 54 | torch.Tensor: distance vectors associated with each pair 55 | (all three sorted consistently; no cell offsets are returned) 56 | """ 57 | idx_i = neighbors[structure.idx_i] 58 | idx_j = neighbors[structure.idx_j] 59 | Rij = neighbors[structure.Rij] 60 | 61 | sort_idx = self._get_unique_idx(idx_i, idx_j, Rij) 62 | 63 | return idx_i[sort_idx], idx_j[sort_idx], Rij[sort_idx] 64 | 65 | @staticmethod 66 | def _get_unique_idx( 67 | idx_i: torch.Tensor, idx_j: torch.Tensor, offsets: torch.Tensor 68 | ): 69 | """ 70 | Compute unique indices for every neighbor pair based on the central atom, the neighbor and the cell the 71 | neighbor belongs to. This is used for sorting the neighbor lists in order to compare between different 72 | implementations. 73 | 74 | Args: 75 | idx_i: indices of central atoms in each pair 76 | idx_j: indices of each neighbor 77 | offsets: per-pair offset vectors with three columns (``_sort_neighbors`` passes Rij here) 78 | 79 | Returns: 80 | torch.LongTensor: indices used for sorting each tensor in a unique manner 81 | """ 82 | n_max = torch.max(torch.abs(offsets))  # NOTE(review): with float Rij as offsets, the key below is not guaranteed collision-free within one (i, j) pair 83 | 84 | n_repeats = 2 * n_max + 1 85 | n_atoms = torch.max(idx_i) + 1 86 | 87 | unique_idx = ( 88 | n_repeats**3 * (n_atoms * idx_i + idx_j) 89 | + (offsets[:, 0] + n_max) 90 | + n_repeats * (offsets[:, 1] + n_max) 91 | + n_repeats**2 * (offsets[:, 2] + n_max) 92 | ) 93 | 94 | return torch.argsort(unique_idx) 95 | 96 | 97 | def test_single_atom(single_atom, neighbor_list, cutoff):  # a lone atom gets no pairs and its input entries stay unchanged 98 | neighbor_list = neighbor_list(cutoff) 99 | props_after = neighbor_list(single_atom) 100 | R = 
props_after[structure.R] 101 | props_after[structure.Rij] = ( 102 | R[props_after[structure.idx_j]] 103 | - R[props_after[structure.idx_i]] 104 | + props_after[structure.offsets] 105 | ) 106 | 107 | assert_consistent(single_atom, props_after) 108 | assert
len(props_after[structure.offsets]) == 0 109 | assert len(props_after[structure.idx_i]) == 0 110 | assert len(props_after[structure.idx_j]) == 0 111 | 112 | 113 | def test_cast(single_atom):  # CastTo32 turns float64 entries into float32 and keeps all other dtypes 114 | allf64 = [k for k, v in single_atom.items() if v.dtype is torch.float64] 115 | other_types = { 116 | k: v.dtype for k, v in single_atom.items() if v.dtype is not torch.float64 117 | } 118 | 119 | assert len(allf64) > 0, single_atom 120 | props_after = CastTo32()(single_atom) 121 | 122 | for k in props_after: 123 | if k in allf64: 124 | assert props_after[k].dtype is torch.float32 125 | else: 126 | assert props_after[k].dtype is other_types[k] 127 | 128 | 129 | def test_remove_com(four_atoms):  # mass-weighted sum of the shifted positions must vanish 130 | positions_trans = SubtractCenterOfMass()(four_atoms) 131 | 132 | com = torch.tensor([0.0, 0.0, 0.0]) 133 | for r_i, m_i in zip( 134 | positions_trans[structure.position], atomic_masses[four_atoms[structure.Z]] 135 | ): 136 | com += r_i * m_i 137 | 138 | torch.testing.assert_close(com, torch.tensor([0.0, 0.0, 0.0])) 139 | 140 | 141 | def test_remove_cog(four_atoms):  # plain sum of the shifted positions must vanish 142 | positions_trans = SubtractCenterOfGeometry()(four_atoms) 143 | 144 | cog = torch.tensor([0.0, 0.0, 0.0]) 145 | for r_i in positions_trans[structure.position]: 146 | cog += r_i 147 | 148 | torch.testing.assert_close(cog, torch.tensor([0.0, 0.0, 0.0])) 149 | -------------------------------------------------------------------------------- /tests/nn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atomistic-machine-learning/schnetpack/601c8edb6cdf35af0b2ebf27b4b9e39386b79fb9/tests/nn/__init__.py -------------------------------------------------------------------------------- /tests/nn/test_activations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy 
as np 3 | 4 | import schnetpack as spk 5 | 6 | 7 | def test_activation_softplus():  # shifted_softplus(x) == log(1 + exp(x)) - log(2) on fixed and random inputs 8 | # simple tensor 9 | x =
torch.tensor([0.0, 1.0, 0.5, 2.0]) 10 | expt = torch.log(1.0 + torch.exp(x)) - np.log(2) 11 | assert torch.allclose(expt, spk.nn.shifted_softplus(x), atol=0.0, rtol=1.0e-7) 12 | # random tensor 13 | torch.manual_seed(42) 14 | x = torch.randn((10, 5), dtype=torch.double) 15 | expt = torch.log(1.0 + torch.exp(x)) - np.log(2) 16 | assert torch.allclose(expt, spk.nn.shifted_softplus(x), atol=0.0, rtol=1.0e-7) 17 | x = 10 * torch.randn((10, 5), dtype=torch.double) 18 | expt = torch.log(1.0 + torch.exp(x)) - np.log(2) 19 | assert torch.allclose(expt, spk.nn.shifted_softplus(x), atol=0.0, rtol=1.0e-7) 20 | 21 | 22 | def test_shape_ssp():  # shifted_softplus preserves the input shape 23 | in_data = torch.rand(10) 24 | out_data = spk.nn.shifted_softplus(in_data) 25 | assert in_data.shape == out_data.shape 26 | -------------------------------------------------------------------------------- /tests/nn/test_cutoff.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from schnetpack.nn.cutoff import CosineCutoff, MollifierCutoff 5 | 6 | 7 | def test_cutoff_cosine():  # CosineCutoff equals 0.5 * (1 + cos(pi * d / rc)) below rc and 0 from rc on 8 | # cosine cutoff with radius 1.8 9 | cutoff = CosineCutoff(cutoff=1.8) 10 | # check cutoff radius 11 | assert abs(1.8 - cutoff.cutoff) < 1.0e-12 12 | # random tensor with elements in [0, 1) 13 | torch.manual_seed(42) 14 | dist = torch.rand((10, 5, 20), dtype=torch.float) 15 | # check cutoff values 16 | expt = 0.5 * (1.0 + torch.cos(dist * np.pi / 1.8)) 17 | assert torch.allclose(expt, cutoff(dist), atol=0.0, rtol=1.0e-7) 18 | # compute expected values for 3.5 times distance 19 | values = 0.5 * (1.0 + torch.cos(3.5 * dist * np.pi / 1.8)) 20 | values[3.5 * dist >= 1.8] = 0.0 21 | assert torch.allclose(values, cutoff(3.5 * dist), atol=0.0, rtol=1.0e-7) 22 | 23 | 24 | def test_cutoff_mollifier():  # MollifierCutoff equals exp(1 - 1 / (1 - (d / rc)**2)) below rc and 0 from rc on 25 | # 
mollifier cutoff with radius 2.3 26 | cutoff = MollifierCutoff(cutoff=2.3) 27 | # check cutoff radius 28 | assert abs(2.3 - cutoff.cutoff) < 1.0e-12 29 | # tensor of zeros 30 | dist =
torch.zeros((4, 1, 1)) 31 | assert torch.allclose(torch.ones(4, 1, 1), cutoff(dist), atol=0.0, rtol=1.0e-7) 32 | # random tensor with elements in [0, 1) 33 | torch.manual_seed(42) 34 | dist = torch.rand((1, 3, 9), dtype=torch.float) 35 | # check cutoff values 36 | expt = torch.exp(1.0 - 1.0 / (1.0 - (dist / 2.3) ** 2)) 37 | assert torch.allclose(expt, cutoff(dist), atol=0.0, rtol=1.0e-7) 38 | # compute cutoff values and expected values 39 | comp = cutoff(3.8 * dist) 40 | expt = torch.exp(1.0 - 1.0 / (1.0 - (3.8 * dist / 2.3) ** 2)) 41 | expt[3.8 * dist >= 2.3] = 0.0 42 | assert torch.allclose(expt, comp, atol=0.0, rtol=1.0e-7) 43 | -------------------------------------------------------------------------------- /tests/nn/test_radial.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from schnetpack.nn.radial import GaussianRBF 4 | 5 | 6 | def test_smear_gaussian_one_distance():  # 6 non-trainable Gaussians at distance 1 (expected squared offsets imply centers 0..5); no parameters 7 | # case of one distance 8 | dist = torch.tensor([[[1.0]]]) 9 | 10 | smear = GaussianRBF(n_rbf=6, cutoff=5.0, trainable=False) 11 | expt = torch.exp(-0.5 * torch.tensor([[[1.0, 0.0, 1.0, 4.0, 9.0, 16.0]]])) 12 | assert torch.allclose(expt, smear(dist), atol=0.0, rtol=1.0e-7) 13 | assert list(smear.parameters()) == [] 14 | 15 | 16 | def test_smear_gaussian_one_distance_trainable():  # same values with trainable=True; exposes 2 parameter tensors of length 6 17 | dist = torch.tensor([[[1.0]]]) 18 | expt = torch.exp(-0.5 * torch.tensor([[[1.0, 0.0, 1.0, 4.0, 9.0, 16.0]]])) 19 | 20 | smear = GaussianRBF(n_rbf=6, cutoff=5.0, trainable=True) 21 | assert torch.allclose(expt, smear(dist), atol=0.0, rtol=1.0e-7) 22 | params = list(smear.parameters()) 23 | assert len(params) == 2 24 | assert len(params[0]) == 6 25 | assert len(params[1]) == 6 26 | 27 | 28 | def test_smear_gaussian():  # 4 Gaussians from start=1.0 to cutoff=4.0 on a (1, 2, 3) distance tensor 29 | dist = 
torch.tensor([[[0.0, 1.0, 1.5], [0.5, 1.5, 3.0]]]) 30 | # smear using 4 Gaussian functions with 1.
spacing 31 | smear = GaussianRBF(start=1.0, cutoff=4.0, n_rbf=4) 32 | # absolute value of centered distances 33 | expt = torch.tensor( 34 | [ 35 | [ 36 | [[1, 2, 3, 4], [0, 1, 2, 3], [0.5, 0.5, 1.5, 2.5]], 37 | [[0.5, 1.5, 2.5, 3.5], [0.5, 0.5, 1.5, 2.5], [2, 1, 0, 1]], 38 | ] 39 | ] 40 | ) 41 | expt = torch.exp(-0.5 * expt**2) 42 | assert torch.allclose(expt, smear(dist), atol=0.0, rtol=1.0e-7) 43 | assert list(smear.parameters()) == [] 44 | 45 | 46 | def test_smear_gaussian_trainable():  # 5 trainable Gaussians from 1.0 to 4.0 with width 0.75; 2 parameter tensors of length 5 47 | dist = torch.tensor([[[0.0, 1.0, 1.5, 0.25], [0.5, 1.5, 3.0, 1.0]]]) 48 | # smear using 5 Gaussian functions with 0.75 spacing 49 | smear = GaussianRBF(start=1.0, cutoff=4.0, n_rbf=5, trainable=True) 50 | # absolute value of centered distances 51 | expt = torch.tensor( 52 | [ 53 | [ 54 | [ 55 | [1, 1.75, 2.5, 3.25, 4.0], 56 | [0, 0.75, 1.5, 2.25, 3.0], 57 | [0.5, 0.25, 1.0, 1.75, 2.5], 58 | [0.75, 1.5, 2.25, 3.0, 3.75], 59 | ], 60 | [ 61 | [0.5, 1.25, 2.0, 2.75, 3.5], 62 | [0.5, 0.25, 1.0, 1.75, 2.5], 63 | [2.0, 1.25, 0.5, 0.25, 1.0], 64 | [0, 0.75, 1.5, 2.25, 3.0], 65 | ], 66 | ] 67 | ] 68 | ) 69 | expt = torch.exp((-0.5 / 0.75**2) * expt**2) 70 | assert torch.allclose(expt, smear(dist), atol=0.0, rtol=1.0e-7) 71 | params = list(smear.parameters()) 72 | assert len(params) == 2 73 | assert len(params[0]) == 5 74 | assert len(params[1]) == 5 75 | -------------------------------------------------------------------------------- /tests/nn/test_schnet.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | import schnetpack.properties as structure 5 | import schnetpack as spk 6 | import numpy as np 7 | from ase.neighborlist import neighbor_list 8 | 9 | from schnetpack.representation.schnet import SchNet 10 | 11 | # TODO:make proper timing and golden tests 12 | 13 | 14 | @pytest.fixture 15 | def indexed_data(example_data, batch_size):  # 
hand-collated SchNet inputs for the first batch_size example systems (pbc disabled, 5.0 cutoff) 16 | Z = [] 17 | R = [] 18 | C = [] 19 | seg_m = [] 20 | ind_i = [] 21 |
ind_j = [] 22 | ind_S = [] 23 | Rij = [] 24 | 25 | n_atoms = 0 26 | n_pairs = 0  # NOTE(review): accumulated below but never read 27 | for i in range(len(example_data)): 28 | seg_m.append(n_atoms) 29 | atoms = example_data[i][0] 30 | atoms.set_pbc(False) 31 | Z.append(atoms.numbers) 32 | R.append(atoms.positions) 33 | C.append(atoms.cell) 34 | idx_i, idx_j, idx_S, rij = neighbor_list( 35 | "ijSD", atoms, 5.0, self_interaction=False 36 | ) 37 | _, seg_im = np.unique(idx_i, return_counts=True)  # NOTE(review): seg_im is never used 38 | ind_i.append(idx_i + n_atoms) 39 | ind_j.append(idx_j + n_atoms) 40 | ind_S.append(idx_S) 41 | Rij.append(rij.astype(np.float32)) 42 | n_atoms += len(atoms) 43 | n_pairs += len(idx_i) 44 | if i + 1 >= batch_size: 45 | break 46 | seg_m.append(n_atoms) 47 | 48 | Z = np.hstack(Z) 49 | R = np.vstack(R).astype(np.float32) 50 | C = np.array(C).astype(np.float32) 51 | seg_m = np.hstack(seg_m) 52 | ind_i = np.hstack(ind_i) 53 | ind_j = np.hstack(ind_j) 54 | ind_S = np.vstack(ind_S)  # NOTE(review): the stacked cell shifts are not put into inputs 55 | Rij = np.vstack(Rij) 56 | 57 | inputs = { 58 | structure.Z: torch.tensor(Z), 59 | structure.position: torch.tensor(R), 60 | structure.cell: torch.tensor(C), 61 | structure.idx_m: torch.tensor(seg_m),  # NOTE(review): seg_m appears to hold segment boundaries rather than one system index per atom - confirm what SchNet expects 62 | structure.idx_j: torch.tensor(ind_j), 63 | structure.idx_i: torch.tensor(ind_i), 64 | structure.Rij: torch.tensor(Rij), 65 | } 66 | 67 | return inputs 68 | 69 | 70 | def test_schnet_new_coo(indexed_data, benchmark):  # benchmarks an eager SchNet forward pass on indexed_data 71 | radial_basis = spk.nn.GaussianRBF(n_rbf=20, cutoff=5.0) 72 | cutoff_fn = spk.nn.CosineCutoff(5.0) 73 | schnet = SchNet( 74 | n_atom_basis=128, 75 | n_interactions=3, 76 | radial_basis=radial_basis, 77 | cutoff_fn=cutoff_fn, 78 | ) 79 | 80 | benchmark(schnet, indexed_data) 81 | 82 | 83 | def test_schnet_new_script(indexed_data, 
benchmark):  # benchmarks a torch.jit.script-compiled SchNet after one call outside the benchmark 84 | 85 | radial_basis = spk.nn.GaussianRBF(n_rbf=20, cutoff=5.0) 86 | cutoff_fn = spk.nn.CosineCutoff(5.0) 87 | schnet = SchNet( 88 | n_atom_basis=128, 89 | n_interactions=3, 90 | radial_basis=radial_basis, 91 | cutoff_fn=cutoff_fn, 92 | ) 93 | schnet = torch.jit.script(schnet) 94 |
schnet(indexed_data) 95 | 96 | benchmark(schnet, indexed_data) 97 | -------------------------------------------------------------------------------- /tests/testdata/md_ethanol.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atomistic-machine-learning/schnetpack/601c8edb6cdf35af0b2ebf27b4b9e39386b79fb9/tests/testdata/md_ethanol.model -------------------------------------------------------------------------------- /tests/testdata/md_ethanol.xyz: -------------------------------------------------------------------------------- 1 | 9 2 | 3 | C -4.92196480914482 1.53680877549233 -0.06612792847094 4 | C -3.41079303549336 1.45138155063184 -0.14009009720834 5 | H -5.22648850340463 2.28202241947302 0.66236410391492 6 | H -5.34004680800574 0.57895313793668 0.22257334141131 7 | H -5.33193076526251 1.80898014947387 -1.03229511269262 8 | H -3.00368348713509 1.18933429199764 0.83479697695625 9 | H -2.99557504133053 2.41817570143478 -0.41886385105291 10 | O -3.07553304550781 0.47652256654287 -1.09348059854212 11 | H -2.13350450471551 0.40432140701697 -1.15817683431555 12 | -------------------------------------------------------------------------------- /tests/testdata/si16.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atomistic-machine-learning/schnetpack/601c8edb6cdf35af0b2ebf27b4b9e39386b79fb9/tests/testdata/si16.model -------------------------------------------------------------------------------- /tests/testdata/test_qm9.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atomistic-machine-learning/schnetpack/601c8edb6cdf35af0b2ebf27b4b9e39386b79fb9/tests/testdata/test_qm9.db -------------------------------------------------------------------------------- /tests/testdata/tmp/.gitkeep: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/atomistic-machine-learning/schnetpack/601c8edb6cdf35af0b2ebf27b4b9e39386b79fb9/tests/testdata/tmp/.gitkeep -------------------------------------------------------------------------------- /tests/user_config/user_exp.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | user_exp: True -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py38,qa 3 | 4 | [testenv] 5 | deps = 6 | pytest 7 | pytest-cov 8 | pytest-datadir 9 | commands = 10 | pip install -r docs/sphinx-requirements.txt 11 | pip install -e .[test] 12 | pip list 13 | pytest --cov-report term --cov=schnetpack 14 | 15 | # prevent exit when error is encountered 16 | ignore_errors = true 17 | 18 | [testenv:qa] 19 | deps = 20 | black 21 | commands = 22 | black -v -l 88 --check --diff src/schnetpack 23 | # prevent exit when error is encountered 24 | ignore_errors = true 25 | --------------------------------------------------------------------------------