├── .github
└── workflows
│ ├── black.yml
│ └── python-publish.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── docs
├── Makefile
├── _templates
│ └── autosummary
│ │ └── classtemplate.rst
├── api
│ ├── atomistic.rst
│ ├── data.rst
│ ├── datasets.rst
│ ├── md.rst
│ ├── model.rst
│ ├── nn.rst
│ ├── representation.rst
│ ├── schnetpack.rst
│ ├── task.rst
│ ├── train.rst
│ └── transform.rst
├── conf.py
├── getstarted.rst
├── howtos
├── index.rst
├── pictures
│ └── tensorboard.png
├── sphinx-requirements.txt
├── tutorials
└── userguide
│ ├── configs.rst
│ ├── md.rst
│ └── overview.rst
├── examples
├── README.md
├── howtos
│ ├── howto_batchwise_relaxations.ipynb
│ └── lammps.rst
└── tutorials
│ ├── tutorial_01_preparing_data.ipynb
│ ├── tutorial_02_qm9.ipynb
│ ├── tutorial_03_force_models.ipynb
│ ├── tutorial_04_molecular_dynamics.ipynb
│ ├── tutorial_05_materials.ipynb
│ └── tutorials_figures
│ ├── integrator.svg
│ ├── md_flowchart.svg
│ └── rpmd.svg
├── interfaces
└── lammps
│ ├── examples
│ └── aspirin
│ │ ├── aspirin.data
│ │ ├── aspirin_md.in
│ │ └── best_model
│ ├── pair_schnetpack.cpp
│ ├── pair_schnetpack.h
│ └── patch_lammps.sh
├── pyproject.toml
├── readthedocs.yaml
├── src
├── schnetpack
│ ├── __init__.py
│ ├── atomistic
│ │ ├── __init__.py
│ │ ├── aggregation.py
│ │ ├── atomwise.py
│ │ ├── distances.py
│ │ ├── electrostatic.py
│ │ ├── external_fields.py
│ │ ├── nuclear_repulsion.py
│ │ └── response.py
│ ├── cli.py
│ ├── configs
│ │ ├── __init__.py
│ │ ├── callbacks
│ │ │ ├── checkpoint.yaml
│ │ │ ├── earlystopping.yaml
│ │ │ ├── ema.yaml
│ │ │ └── lrmonitor.yaml
│ │ ├── data
│ │ │ ├── ani1.yaml
│ │ │ ├── custom.yaml
│ │ │ ├── iso17.yaml
│ │ │ ├── materials_project.yaml
│ │ │ ├── md17.yaml
│ │ │ ├── md22.yaml
│ │ │ ├── omdb.yaml
│ │ │ ├── qm7x.yaml
│ │ │ ├── qm9.yaml
│ │ │ ├── rmd17.yaml
│ │ │ └── sampler
│ │ │ │ └── stratified_property.yaml
│ │ ├── experiment
│ │ │ ├── md17.yaml
│ │ │ ├── qm9_atomwise.yaml
│ │ │ ├── qm9_dipole.yaml
│ │ │ ├── response.yaml
│ │ │ └── rmd17.yaml
│ │ ├── globals
│ │ │ └── default_globals.yaml
│ │ ├── logger
│ │ │ ├── aim.yaml
│ │ │ ├── csv.yaml
│ │ │ ├── tensorboard.yaml
│ │ │ └── wandb.yaml
│ │ ├── model
│ │ │ ├── nnp.yaml
│ │ │ └── representation
│ │ │ │ ├── field_schnet.yaml
│ │ │ │ ├── painn.yaml
│ │ │ │ ├── radial_basis
│ │ │ │ ├── bessel.yaml
│ │ │ │ └── gaussian.yaml
│ │ │ │ ├── schnet.yaml
│ │ │ │ └── so3net.yaml
│ │ ├── predict.yaml
│ │ ├── run
│ │ │ └── default_run.yaml
│ │ ├── task
│ │ │ ├── default_task.yaml
│ │ │ ├── optimizer
│ │ │ │ ├── adabelief.yaml
│ │ │ │ ├── adam.yaml
│ │ │ │ └── sgd.yaml
│ │ │ └── scheduler
│ │ │ │ └── reduce_on_plateau.yaml
│ │ ├── train.yaml
│ │ └── trainer
│ │ │ ├── ddp_debug.yaml
│ │ │ ├── ddp_trainer.yaml
│ │ │ ├── debug_trainer.yaml
│ │ │ └── default_trainer.yaml
│ ├── data
│ │ ├── __init__.py
│ │ ├── atoms.py
│ │ ├── datamodule.py
│ │ ├── loader.py
│ │ ├── sampler.py
│ │ ├── splitting.py
│ │ └── stats.py
│ ├── datasets
│ │ ├── __init__.py
│ │ ├── ani1.py
│ │ ├── iso17.py
│ │ ├── materials_project.py
│ │ ├── md17.py
│ │ ├── md22.py
│ │ ├── omdb.py
│ │ ├── qm7x.py
│ │ ├── qm9.py
│ │ ├── rmd17.py
│ │ └── tmqm.py
│ ├── interfaces
│ │ ├── __init__.py
│ │ ├── ase_interface.py
│ │ └── batchwise_optimization.py
│ ├── md
│ │ ├── __init__.py
│ │ ├── calculators
│ │ │ ├── __init__.py
│ │ │ ├── base_calculator.py
│ │ │ ├── ensemble_calculator.py
│ │ │ ├── lj_calculator.py
│ │ │ ├── orca_calculator.py
│ │ │ └── schnetpack_calculator.py
│ │ ├── cli.py
│ │ ├── data
│ │ │ ├── __init__.py
│ │ │ ├── hdf5_data.py
│ │ │ └── spectra.py
│ │ ├── initial_conditions.py
│ │ ├── integrators.py
│ │ ├── md_configs
│ │ │ ├── __init__.py
│ │ │ ├── calculator
│ │ │ │ ├── lj.yaml
│ │ │ │ ├── neighbor_list
│ │ │ │ │ ├── ase.yaml
│ │ │ │ │ ├── matscipy.yaml
│ │ │ │ │ └── torch.yaml
│ │ │ │ ├── orca.yaml
│ │ │ │ ├── spk.yaml
│ │ │ │ └── spk_ensemble.yaml
│ │ │ ├── callbacks
│ │ │ │ ├── checkpoint.yaml
│ │ │ │ ├── hdf5.yaml
│ │ │ │ └── tensorboard.yaml
│ │ │ ├── config.yaml
│ │ │ ├── dynamics
│ │ │ │ ├── barostat
│ │ │ │ │ ├── nhc_aniso.yaml
│ │ │ │ │ ├── nhc_iso.yaml
│ │ │ │ │ └── pile_rpmd.yaml
│ │ │ │ ├── base.yaml
│ │ │ │ ├── integrator
│ │ │ │ │ ├── md.yaml
│ │ │ │ │ └── rpmd.yaml
│ │ │ │ └── thermostat
│ │ │ │ │ ├── berendsen.yaml
│ │ │ │ │ ├── gle.yaml
│ │ │ │ │ ├── langevin.yaml
│ │ │ │ │ ├── nhc.yaml
│ │ │ │ │ ├── pi_gle.yaml
│ │ │ │ │ ├── pi_nhc_global.yaml
│ │ │ │ │ ├── pi_nhc_local.yaml
│ │ │ │ │ ├── piglet.yaml
│ │ │ │ │ ├── pile_global.yaml
│ │ │ │ │ ├── pile_local.yaml
│ │ │ │ │ └── trpmd.yaml
│ │ │ └── system
│ │ │ │ ├── initializer
│ │ │ │ ├── maxwell_boltzmann.yaml
│ │ │ │ └── uniform.yaml
│ │ │ │ └── system.yaml
│ │ ├── neighborlist_md.py
│ │ ├── parsers
│ │ │ ├── __init__.py
│ │ │ └── orca_parser.py
│ │ ├── simulation_hooks
│ │ │ ├── __init__.py
│ │ │ ├── barostats.py
│ │ │ ├── barostats_rpmd.py
│ │ │ ├── basic_hooks.py
│ │ │ ├── callback_hooks.py
│ │ │ ├── thermostats.py
│ │ │ └── thermostats_rpmd.py
│ │ ├── simulator.py
│ │ ├── system.py
│ │ └── utils
│ │ │ ├── __init__.py
│ │ │ ├── md_config.py
│ │ │ ├── normal_model_transformation.py
│ │ │ └── thermostat_utils.py
│ ├── model
│ │ ├── __init__.py
│ │ └── base.py
│ ├── nn
│ │ ├── __init__.py
│ │ ├── activations.py
│ │ ├── base.py
│ │ ├── blocks.py
│ │ ├── cutoff.py
│ │ ├── embedding.py
│ │ ├── equivariant.py
│ │ ├── ops
│ │ │ ├── __init__.py
│ │ │ ├── math.py
│ │ │ └── so3.py
│ │ ├── radial.py
│ │ ├── scatter.py
│ │ ├── so3.py
│ │ └── utils.py
│ ├── properties.py
│ ├── representation
│ │ ├── __init__.py
│ │ ├── field_schnet.py
│ │ ├── painn.py
│ │ ├── schnet.py
│ │ └── so3net.py
│ ├── task.py
│ ├── train
│ │ ├── __init__.py
│ │ ├── callbacks.py
│ │ ├── lr_scheduler.py
│ │ └── metrics.py
│ ├── transform
│ │ ├── __init__.py
│ │ ├── atomistic.py
│ │ ├── base.py
│ │ ├── casting.py
│ │ ├── neighborlist.py
│ │ └── response.py
│ ├── units.py
│ └── utils
│ │ ├── __init__.py
│ │ ├── compatibility.py
│ │ └── script.py
└── scripts
│ ├── spkconvert
│ ├── spkdeploy
│ ├── spkmd
│ ├── spkpredict
│ └── spktrain
├── tests
├── README.md
├── __init__.py
├── atomistic
│ ├── __init__.py
│ └── test_response.py
├── conftest.py
├── data
│ ├── __init__.py
│ ├── conftest.py
│ ├── test_data.py
│ ├── test_datasets.py
│ ├── test_loader.py
│ └── test_transforms.py
├── nn
│ ├── __init__.py
│ ├── test_activations.py
│ ├── test_cutoff.py
│ ├── test_radial.py
│ └── test_schnet.py
├── testdata
│ ├── md_ethanol.model
│ ├── md_ethanol.xyz
│ ├── si16.model
│ ├── test_qm9.db
│ └── tmp
│ │ └── .gitkeep
└── user_config
│ └── user_exp.yaml
└── tox.ini
/.github/workflows/black.yml:
--------------------------------------------------------------------------------
1 | name: Black Code Formatter
2 |
3 | on: [push, pull_request]
4 |
5 | jobs:
6 | black:
7 | runs-on: ubuntu-latest
8 |
9 | steps:
10 | - name: Checkout code
11 | uses: actions/checkout@v2
12 |
13 | - name: Set up Python
14 | uses: actions/setup-python@v2
15 | with:
16 | python-version: '3.x' # Specify the Python version you need
17 |
18 | - name: Install dependencies
19 | run: |
20 | python -m pip install --upgrade pip
21 | pip install black==24.4.2
22 |
23 | - name: Run black
24 | run: black --check .
25 |
26 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | permissions:
16 | contents: read
17 |
18 | jobs:
19 | deploy:
20 |
21 | runs-on: ubuntu-latest
22 |
23 | steps:
24 | - uses: actions/checkout@v3
25 | - name: Set up Python
26 | uses: actions/setup-python@v3
27 | with:
28 | python-version: '3.x'
29 | - name: Install dependencies
30 | run: |
31 | python -m pip install --upgrade pip
32 | pip install build
33 | - name: Build package
34 | run: python -m build
35 | - name: Publish package
36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
37 | with:
38 | user: __token__
39 | password: ${{ secrets.PYPI_API_TOKEN }}
40 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | .idea
6 | *.DS_Store
7 |
8 | docs/tutorials/*.db
9 | docs/tutorials/*.xyz
10 | docs/tutorials/qm9tut
11 |
12 | # hydra stuff
13 | outputs
14 | logs
15 | generated
16 |
17 | # C extensions
18 | *.so
19 |
20 |
21 | # Distribution / packaging
22 | .Python
23 | env/
24 | build/
25 | develop-eggs/
26 | dist/
27 | downloads/
28 | eggs/
29 | .eggs/
30 | lib/
31 | lib64/
32 | parts/
33 | sdist/
34 | var/
35 | wheels/
36 | *.egg-info/
37 | .installed.cfg
38 | *.egg
39 | .vscode
40 | .github
41 |
42 | # PyInstaller
43 | # Usually these files are written by a python script from a template
44 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
45 | *.manifest
46 | *.spec
47 |
48 | # Installer logs
49 | pip-log.txt
50 | pip-delete-this-directory.txt
51 |
52 | # Unit test / coverage reports
53 | htmlcov/
54 | .tox/
55 | .coverage
56 | .coverage.*
57 | .cache
58 | nosetests.xml
59 | coverage.xml
60 | *.cover
61 | .hypothesis/
62 | .pytest_cache
63 |
64 | # Translations
65 | *.mo
66 | *.pot
67 |
68 | # Django stuff:
69 | local_settings.py
70 |
71 | # Flask stuff:
72 | instance/
73 | .webassets-cache
74 |
75 | # Scrapy stuff:
76 | .scrapy
77 |
78 | # Sphinx documentation
79 | docs/_build/
80 | docs/api/generated/
81 |
82 | # PyBuilder
83 | target/
84 |
85 | # Jupyter Notebook
86 | .ipynb_checkpoints
87 |
88 | # pyenv
89 | .python-version
90 |
91 | # celery beat schedule file
92 | celerybeat-schedule
93 |
94 | # SageMath parsed files
95 | *.sage.py
96 |
97 | # dotenv
98 | .env
99 |
100 | # virtualenv
101 | .venv
102 | venv/
103 | ENV/
104 | env/
105 |
106 | # Spyder project settings
107 | .spyderproject
108 | .spyproject
109 |
110 | # Rope project settings
111 | .ropeproject
112 |
113 | # mkdocs documentation
114 | /site
115 |
116 | # mypy
117 | .mypy_cache/
118 |
119 | # experiments
120 | runs/
121 | tests/testdata/tmp/test*
122 |
123 | # lammps examples
124 | interfaces/lammps/examples/*/*.lammpstrj
125 | interfaces/lammps/examples/*/*.lammps
126 | interfaces/lammps/examples/*/*.dump
127 | interfaces/lammps/examples/*/*.dat
128 | interfaces/lammps/examples/*/deployed_model
129 |
130 | # batchwise optimizer examples
131 | examples/howtos/howto_batchwise_relaxations_outputs/*
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/python/black
3 | rev: 24.4.2
4 | hooks:
5 | - id: black
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | COPYRIGHT
2 |
3 | Copyright (c) 2018 Kristof Schütt, Michael Gastegger, Pan Kessel, Kim Nicoli
4 |
5 | All other contributions:
6 | Copyright (c) 2018, the respective contributors.
7 | All rights reserved.
8 |
9 | Each contributor holds copyright over their respective contributions.
10 | The project versioning (Git) records all such contribution source information.
11 |
12 | LICENSE
13 |
14 | The MIT License
15 |
16 | Permission is hereby granted, free of charge, to any person obtaining a copy
17 | of this software and associated documentation files (the "Software"), to deal
18 | in the Software without restriction, including without limitation the rights
19 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
20 | copies of the Software, and to permit persons to whom the Software is
21 | furnished to do so, subject to the following conditions:
22 |
23 | The above copyright notice and this permission notice shall be included in all
24 | copies or substantial portions of the Software.
25 |
26 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
27 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
28 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
29 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
30 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
31 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
32 | SOFTWARE.
33 |
34 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = sphinx-build
7 | SPHINXPROJ = SchNetPack
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
--------------------------------------------------------------------------------
/docs/_templates/autosummary/classtemplate.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 | :class: hidden-section
3 |
4 |
5 | {{ name | underline}}
6 |
7 | .. autoclass:: {{ name }}
8 | :no-inherited-members:
9 | :members:
10 |
--------------------------------------------------------------------------------
/docs/api/atomistic.rst:
--------------------------------------------------------------------------------
1 | schnetpack.atomistic
2 | ====================
3 | .. currentmodule:: atomistic
4 |
5 | Output modules
6 | --------------
7 | .. rubric:: Atom-wise layers
8 |
9 | .. autosummary::
10 | :toctree: generated
11 | :nosignatures:
12 | :template: classtemplate.rst
13 |
14 | Atomwise
15 | DipoleMoment
16 | Polarizability
17 |
18 | .. rubric:: Response layers
19 |
20 | .. autosummary::
21 | :toctree: generated
22 | :nosignatures:
23 | :template: classtemplate.rst
24 |
25 | Forces
26 |
--------------------------------------------------------------------------------
/docs/api/data.rst:
--------------------------------------------------------------------------------
1 | schnetpack.data
2 | ===============
3 | .. currentmodule:: data
4 |
5 | Atoms data
6 | ------------
7 | .. autosummary::
8 | :toctree: generated
9 | :nosignatures:
10 | :recursive:
11 | :template: classtemplate.rst
12 |
13 | BaseAtomsData
14 | ASEAtomsData
15 | AtomsLoader
16 | resolve_format
17 | AtomsDataFormat
18 | StratifiedSampler
19 |
20 |
21 | Creation
22 | --------
23 | .. autosummary::
24 | :toctree: generated
25 | :nosignatures:
26 | :template: classtemplate.rst
27 |
28 | create_dataset
29 | load_dataset
30 |
31 | Data modules
32 | ------------
33 |
34 | .. autosummary::
35 | :toctree: generated
36 | :nosignatures:
37 | :template: classtemplate.rst
38 |
39 | AtomsDataModule
40 |
41 | Statistics
42 | ----------
43 |
44 | .. autosummary::
45 | :toctree: generated
46 | :nosignatures:
47 | :template: classtemplate.rst
48 |
49 | calculate_stats
50 | NumberOfAtomsCriterion
51 | PropertyCriterion
--------------------------------------------------------------------------------
/docs/api/datasets.rst:
--------------------------------------------------------------------------------
1 | schnetpack.datasets
2 | ===================
3 | .. currentmodule:: datasets
4 |
5 | Molecules
6 | ------------
7 | .. autosummary::
8 | :toctree: generated
9 | :nosignatures:
10 | :template: classtemplate.rst
11 |
12 | QM9
13 | MD17
14 | ANI1
15 | ISO17
16 |
17 | Materials
18 | ---------
19 | .. autosummary::
20 | :toctree: generated
21 | :nosignatures:
22 | :template: classtemplate.rst
23 |
24 | MaterialsProject
25 | OrganicMaterialsDatabase
26 |
27 |
28 |
--------------------------------------------------------------------------------
/docs/api/md.rst:
--------------------------------------------------------------------------------
1 | schnetpack.md
2 | =============
3 | .. currentmodule:: md
4 |
5 | This module contains all functionality for performing various molecular dynamics simulations using SchNetPack.
6 |
7 | System
8 | ------
9 |
10 | .. autosummary::
11 | :toctree: generated
12 | :nosignatures:
13 | :template: classtemplate.rst
14 |
15 | System
16 |
17 |
18 | Initial Conditions
19 | ------------------
20 |
21 | .. currentmodule:: md.initial_conditions
22 |
23 | .. autosummary::
24 | :toctree: generated
25 | :nosignatures:
26 | :template: classtemplate.rst
27 |
28 | Initializer
29 | MaxwellBoltzmannInit
30 | UniformInit
31 |
32 |
33 | Integrators
34 | -----------
35 |
36 | .. currentmodule:: md.integrators
37 |
38 | Integrators for NVE and NVT simulations:
39 |
40 | .. autosummary::
41 | :toctree: generated
42 | :nosignatures:
43 | :template: classtemplate.rst
44 |
45 | Integrator
46 | VelocityVerlet
47 | RingPolymer
48 |
49 | Integrators for NPT simulations:
50 |
51 | .. autosummary::
52 | :toctree: generated
53 | :nosignatures:
54 | :template: classtemplate.rst
55 |
56 | NPTVelocityVerlet
57 | NPTRingPolymer
58 |
59 |
60 | Calculators
61 | -----------
62 |
63 | .. currentmodule:: md.calculators
64 |
65 | Basic calculators:
66 |
67 | .. autosummary::
68 | :toctree: generated
69 | :nosignatures:
70 | :template: classtemplate.rst
71 |
72 | MDCalculator
73 | QMCalculator
74 | EnsembleCalculator
75 | LJCalculator
76 |
77 | Neural network potentials and ORCA calculators:
78 |
79 | .. autosummary::
80 | :toctree: generated
81 | :nosignatures:
82 | :template: classtemplate.rst
83 |
84 | SchNetPackCalculator
85 | SchNetPackEnsembleCalculator
86 | OrcaCalculator
87 |
88 |
89 | Neighbor List
90 | -------------
91 |
92 | .. currentmodule:: md.neighborlist_md
93 |
94 | .. autosummary::
95 | :toctree: generated
96 | :nosignatures:
97 | :template: classtemplate.rst
98 |
99 | NeighborListMD
100 |
101 |
102 | Simulator
103 | ---------
104 |
105 | .. currentmodule:: md
106 |
107 | .. autosummary::
108 | :toctree: generated
109 | :nosignatures:
110 | :template: classtemplate.rst
111 |
112 | Simulator
113 |
114 |
115 | Simulation hooks
116 | ----------------
117 |
118 | .. currentmodule:: md.simulation_hooks
119 |
120 | Basic hooks:
121 |
122 | .. autosummary::
123 | :toctree: generated
124 | :nosignatures:
125 | :template: classtemplate.rst
126 |
127 | SimulationHook
128 | RemoveCOMMotion
129 |
130 | Thermostats:
131 |
132 | .. autosummary::
133 | :toctree: generated
134 | :nosignatures:
135 | :template: classtemplate.rst
136 |
137 | ThermostatHook
138 | BerendsenThermostat
139 | LangevinThermostat
140 | NHCThermostat
141 | GLEThermostat
142 |
143 | Thermostats for ring-polymer MD:
144 |
145 | .. autosummary::
146 | :toctree: generated
147 | :nosignatures:
148 | :template: classtemplate.rst
149 |
150 | PILELocalThermostat
151 | PILEGlobalThermostat
152 | TRPMDThermostat
153 | RPMDGLEThermostat
154 | PIGLETThermostat
155 | NHCRingPolymerThermostat
156 |
157 | Barostats:
158 |
159 | .. autosummary::
160 | :toctree: generated
161 | :nosignatures:
162 | :template: classtemplate.rst
163 |
164 | BarostatHook
165 | NHCBarostatIsotropic
166 | NHCBarostatAnisotropic
167 |
168 | Barostats for ring-polymer MD:
169 |
170 | .. autosummary::
171 | :toctree: generated
172 | :nosignatures:
173 | :template: classtemplate.rst
174 |
175 | PILEBarostat
176 |
177 | Logging and callback
178 |
179 | .. autosummary::
180 | :toctree: generated
181 | :nosignatures:
182 | :template: classtemplate.rst
183 |
184 | Checkpoint
185 | DataStream
186 | MoleculeStream
187 | PropertyStream
188 | FileLogger
189 | BasicTensorboardLogger
190 | TensorBoardLogger
191 |
192 |
193 | Simulation data and postprocessing
194 | ----------------------------------
195 |
196 | .. currentmodule:: md.data
197 |
198 | Data loading:
199 |
200 | .. autosummary::
201 | :toctree: generated
202 | :nosignatures:
203 | :template: classtemplate.rst
204 |
205 | HDF5Loader
206 |
207 | Vibrational spectra:
208 |
209 | .. autosummary::
210 | :toctree: generated
211 | :nosignatures:
212 | :template: classtemplate.rst
213 |
214 | VibrationalSpectrum
215 | PowerSpectrum
216 | IRSpectrum
217 | RamanSpectrum
218 |
219 |
220 | ORCA output parsing
221 | -------------------
222 |
223 | .. currentmodule:: md.parsers
224 |
225 | .. autosummary::
226 | :toctree: generated
227 | :nosignatures:
228 | :template: classtemplate.rst
229 |
230 | OrcaParser
231 | OrcaOutputParser
232 | OrcaFormatter
233 | OrcaPropertyParser
234 | OrcaMainFileParser
235 | OrcaHessianFileParser
236 |
237 |
238 | MD utilities
239 | ------------
240 |
241 | .. currentmodule:: md.utils
242 |
243 | .. autosummary::
244 | :toctree: generated
245 | :nosignatures:
246 | :template: classtemplate.rst
247 |
248 | NormalModeTransformer
249 |
250 | Utilities for thermostats
251 |
252 | .. currentmodule:: md.utils.thermostat_utils
253 |
254 | .. autosummary::
255 | :toctree: generated
256 | :nosignatures:
257 | :template: classtemplate.rst
258 |
259 | YSWeights
260 | GLEMatrixParser
261 | load_gle_matrices
262 | StableSinhDiv
263 |
--------------------------------------------------------------------------------
/docs/api/model.rst:
--------------------------------------------------------------------------------
1 | schnetpack.model
2 | ================
3 | .. currentmodule:: model
4 |
5 | .. autosummary::
6 | :toctree: generated
7 | :nosignatures:
8 | :template: classtemplate.rst
9 |
10 | AtomisticModel
11 | NeuralNetworkPotential
12 |
13 |
--------------------------------------------------------------------------------
/docs/api/nn.rst:
--------------------------------------------------------------------------------
1 | schnetpack.nn
2 | =============
3 | .. currentmodule:: nn
4 |
5 |
6 | Basic layers
7 | ------------
8 |
9 | .. autosummary::
10 | :toctree: generated
11 | :nosignatures:
12 | :template: classtemplate.rst
13 |
14 | Dense
15 |
16 |
17 | Equivariant layers
18 | ------------------
19 |
20 | Cartesian:
21 |
22 | .. autosummary::
23 | :toctree: generated
24 | :nosignatures:
25 | :template: classtemplate.rst
26 |
27 | GatedEquivariantBlock
28 |
29 | Irreps:
30 |
31 | .. autosummary::
32 | :toctree: generated
33 | :nosignatures:
34 | :template: classtemplate.rst
35 |
36 | RealSphericalHarmonics
37 | SO3TensorProduct
38 | SO3Convolution
39 | SO3GatedNonlinearity
40 | SO3ParametricGatedNonlinearity
41 |
42 |
43 | Radial basis
44 | ------------
45 |
46 | .. autosummary::
47 | :toctree: generated
48 | :nosignatures:
49 | :template: classtemplate.rst
50 |
51 | GaussianRBF
52 | GaussianRBFCentered
53 | BesselRBF
54 |
55 |
56 | Cutoff
57 | ------
58 |
59 | .. autosummary::
60 | :toctree: generated
61 | :nosignatures:
62 | :template: classtemplate.rst
63 |
64 | CosineCutoff
65 | MollifierCutoff
66 |
67 |
68 | Activations
69 | -----------
70 |
71 | .. autosummary::
72 | :toctree: generated
73 | :nosignatures:
74 | :template: classtemplate.rst
75 |
76 | shifted_softplus
77 |
78 |
79 | Ops
80 | ---
81 |
82 | .. autosummary::
83 | :toctree: generated
84 | :nosignatures:
85 |
86 | scatter_add
87 |
88 | Factory functions
89 | -----------------
90 |
91 | .. autosummary::
92 | :toctree: generated
93 | :nosignatures:
94 |
95 | build_mlp
96 | build_gated_equivariant_mlp
97 | replicate_module
98 |
--------------------------------------------------------------------------------
/docs/api/representation.rst:
--------------------------------------------------------------------------------
1 | schnetpack.representation
2 | =========================
3 | .. currentmodule:: representation
4 |
5 |
6 | .. rubric:: Message-passing neural networks
7 |
8 | .. autosummary::
9 | :toctree: generated
10 | :nosignatures:
11 | :template: classtemplate.rst
12 |
13 | SchNet
14 | PaiNN
15 |
--------------------------------------------------------------------------------
/docs/api/schnetpack.rst:
--------------------------------------------------------------------------------
1 | schnetpack
2 | ==========
3 |
4 | Structure attributes
5 | --------------------
6 | .. autosummary::
7 | :toctree: generated
8 | :nosignatures:
9 |
10 | properties.Z
11 | properties.R
12 | properties.cell
13 | properties.pbc
14 | properties.idx_m
15 | properties.idx_i
16 | properties.idx_j
17 | properties.Rij
18 | properties.n_atoms
19 |
20 | Units
21 | -----
22 | .. autosummary::
23 | :toctree: generated
24 | :nosignatures:
25 |
26 | units.convert_units
27 |
--------------------------------------------------------------------------------
/docs/api/task.rst:
--------------------------------------------------------------------------------
1 | schnetpack.task
2 | ===============
3 | .. currentmodule:: task
4 |
5 | .. autosummary::
6 | :toctree: generated
7 | :nosignatures:
8 | :template: classtemplate.rst
9 |
10 | AtomisticTask
11 | ModelOutput
12 | UnsupervisedModelOutput
13 |
14 |
--------------------------------------------------------------------------------
/docs/api/train.rst:
--------------------------------------------------------------------------------
1 | schnetpack.train
2 | ================
3 | .. currentmodule:: train
4 |
5 |
6 | Callbacks
7 | ---------
8 |
9 | .. autosummary::
10 | :toctree: generated
11 | :nosignatures:
12 | :template: classtemplate.rst
13 |
14 | ModelCheckpoint
15 | PredictionWriter
16 |
17 | Scheduler
18 | ---------
19 |
20 | .. autosummary::
21 | :toctree: generated
22 | :nosignatures:
23 | :template: classtemplate.rst
24 |
25 | ReduceLROnPlateau
26 |
27 |
--------------------------------------------------------------------------------
/docs/api/transform.rst:
--------------------------------------------------------------------------------
1 | schnetpack.transform
2 | ====================
3 | .. automodule:: transform
4 |
5 | .. currentmodule:: transform
6 | .. autoclass:: Transform
7 |
8 | Atomistic
9 | ---------
10 |
11 | .. autosummary::
12 | :toctree: generated
13 | :nosignatures:
14 | :template: classtemplate.rst
15 |
16 | AddOffsets
17 | RemoveOffsets
18 | SubtractCenterOfMass
19 | SubtractCenterOfGeometry
20 |
21 | Casting
22 | -------
23 |
24 | .. autosummary::
25 | :toctree: generated
26 | :nosignatures:
27 | :template: classtemplate.rst
28 |
29 | CastMap
30 | CastTo32
31 | CastTo64
32 |
33 | Neighbor lists
34 | --------------
35 |
36 | .. autosummary::
37 | :toctree: generated
38 | :nosignatures:
39 | :template: classtemplate.rst
40 |
41 | MatScipyNeighborList
42 | ASENeighborList
43 | TorchNeighborList
44 | CachedNeighborList
45 | CountNeighbors
46 | FilterNeighbors
47 | WrapPositions
48 | CollectAtomTriples
49 |
--------------------------------------------------------------------------------
/docs/howtos:
--------------------------------------------------------------------------------
1 | ../examples/howtos/
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. SchNetPack documentation master file, created by
2 | sphinx-quickstart on Mon Jul 30 18:07:50 2018.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | SchNetPack documentation
7 | ========================
8 |
9 | SchNetPack is a toolbox for the development and application of deep neural networks to the prediction of
10 | potential energy surfaces and other quantum-chemical properties of molecules and materials. It contains
11 | basic building blocks of atomistic neural networks, manages their training and provides simple access
12 | to common benchmark datasets. This allows for an easy implementation and evaluation of new models.
13 |
14 | Contents
15 | ========
16 |
17 | .. toctree::
18 | :glob:
19 | :caption: Get Started
20 | :maxdepth: 1
21 |
22 | getstarted
23 |
24 | .. toctree::
25 | :glob:
26 | :caption: User guide
27 | :maxdepth: 1
28 |
29 | userguide/overview
30 | userguide/configs
31 | userguide/md
32 |
33 | .. toctree::
34 | :glob:
35 | :caption: Tutorials
36 | :maxdepth: 1
37 |
38 | tutorials/tutorial_01_preparing_data
39 | tutorials/tutorial_02_qm9
40 | tutorials/tutorial_03_force_models
41 | tutorials/tutorial_04_molecular_dynamics
42 | tutorials/tutorial_05_materials
43 |
44 | .. toctree::
45 | :glob:
46 | :caption: How-To
47 | :maxdepth: 1
48 |
49 | howtos/howto_batchwise_relaxations
50 | howtos/lammps
51 |
52 | .. toctree::
53 | :glob:
54 | :caption: Reference
55 | :maxdepth: 1
56 |
57 | api/schnetpack
58 | api/atomistic
59 | api/data
60 | api/datasets
61 | api/task
62 | api/model
63 | api/representation
64 | api/nn
65 | api/train
66 | api/transform
67 | api/md
68 |
--------------------------------------------------------------------------------
/docs/pictures/tensorboard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/atomistic-machine-learning/schnetpack/601c8edb6cdf35af0b2ebf27b4b9e39386b79fb9/docs/pictures/tensorboard.png
--------------------------------------------------------------------------------
/docs/sphinx-requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | numpy
3 | ase>=3.21
4 | h5py
5 | pyyaml
6 | tqdm
7 | ipykernel
8 |
9 | nbsphinx
10 | sphinx==7.1.2
11 | sphinx_rtd_theme
12 | readthedocs-sphinx-search
13 |
14 | pytorch_lightning>=2.0.0
15 | hydra-core>=1.1.0
16 | hydra-colorlog>=1.1.0
17 | torchmetrics==1.0.1
18 | protobuf==3.20.2
19 |
--------------------------------------------------------------------------------
/docs/tutorials:
--------------------------------------------------------------------------------
1 | ../examples/tutorials/
--------------------------------------------------------------------------------
/docs/userguide/overview.rst:
--------------------------------------------------------------------------------
1 | ========
2 | Overview
3 | ========
4 | .. _overview:
5 |
6 | SchNetPack is built so that it can be used from the command line and configured with
7 | config files as well as used as a Python library.
8 | In this section, we will explain the overall structure of SchNetPack, which will be
9 | helpful for both these use cases.
10 | SchNetPack is based on PyTorch and uses `PyTorch Lightning <https://lightning.ai/docs/pytorch/stable/>`_ as a training framework.
11 | This heavily influences the structure described here.
12 | Additionally, `Hydra <https://hydra.cc/>`_ is used to configure SchNetPack for command-line usage,
13 | which will be described in the next chapter.
14 |
15 | Data
16 | ====
17 | .. currentmodule:: data
18 |
19 | SchNetPack currently supports data sets stored in ASE format using
20 | :class:`ASEAtomsData`, but other formats can be added by implementing
21 | :class:`BaseAtomsData`. These classes are compatible with PyTorch dataloaders and
22 | provide an additional interface to store metadata, e.g. property units and
23 | single-atom reference values.
24 |
25 | An important aspect are the transforms that can be passed to the data classes. Those
26 | are PyTorch modules that perform preprocessing tasks on the data *before* batching.
27 | Typically, this is performed on the CPU as part of the multi-processing of PyTorch
28 | dataloaders.
29 | Important preprocessing :class:`Transform`s include removing of offsets from target properties
30 | and calculation of neighbor lists.
31 |
32 | Furthermore, we support PyTorch Lightning datamodules through :class:`AtomsDataModule`,
33 | which combines :class:`ASEAtomsData` with code for preparation, setup and partitioning
34 | into train/validation/test splits. We provide specific implementations of
35 | :class:`AtomsDataModule` for several benchmark datasets.
36 |
37 |
38 | Model
39 | =====
40 | .. currentmodule:: model
41 |
42 | A core component of SchNetPack is the :class:`AtomisticModel`, which is the base
43 | class for all models implemented in SchNetPack. It is essentially a PyTorch module with
44 | some additional functionality specific to atomistic machine learning.
45 |
46 | The particular features and requirements are:
47 |
48 | * **Input dictionary:**
49 | To support a flexible interface, each model is supposed to take an input dictionary
50 | mapping strings to PyTorch tensors and returns a modified dictionary as output.
51 |
52 | * **Automatic collection of required derivatives:**
53 | Each layer that requires derivatives w.r.t. some input should list them as strings
54 | in `layer.required_derivatives = ["input_key"]`. The `requires_grad` of the input
55 | tensor is then set automatically.
56 |
57 | .. currentmodule:: transform
58 | * **Post-processing:**
59 | The atomistic model can take a list of non-trainable :class:`Transform`s that are
60 | used to post-process the output dictionary. These are not applied during training.
61 | A common use case is energy values that have large offsets and require double
62 | precision. To be able to still run a single precision model on GPU, one can subtract
63 | the offset from the reference data during a preprocessing stage and then add it
64 | to the model prediction in post-processing after casting to double.
65 |
66 | .. currentmodule:: model
67 | While :class:`AtomisticModel` is a fairly general class, the models provided in
68 | SchNetPack follow a structure defined in the subclass :class:`NeuralNetworkPotential`:
69 |
70 | #. **Input modules:** the input dictionary is sequentially passed to a list of PyTorch
71 | modules that return a modified dictionary
72 | 
73 | #. **Representation:** the input dictionary is passed to a representation module that
74 | computes atomwise representation, e.g. SchNet or PaiNN. The representation is added
75 | to the dictionary
76 | 
77 | #. **Output modules:** the dictionary is sequentially passed to a list of PyTorch
78 | modules that store the outputs in the dictionary
79 |
80 | Adhering to the structure of :class:`NeuralNetworkPotential` makes it easier to define
81 | config templates with Hydra and it is therefore recommended to subclass it wherever
82 | possible.
83 |
84 | Task
85 | ====
86 | .. currentmodule:: task
87 |
88 | The :class:`AtomisticTask` ties the model, outputs, loss and optimizers together and defines
89 | how the neural network will be trained. While the model is a vanilla PyTorch module,
90 | the task is a :class:`LightningModule` that can be directly passed to the
91 | PyTorch Lightning :class:`Trainer`.
92 |
93 | To define an :class:`AtomisticTask`, you need to provide:
94 |
95 | * a model as described above
96 |
97 | * a list of :class:`ModelOutput` which map output dictionary keys to target properties and assigns a loss function and other metrics
98 |
99 | * (optionally) optimizer and learning rate schedulers
100 |
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | # Examples
2 |
3 | In this directory, you find code examples and tutorials to demonstrate the functionality of SchNetPack
4 |
5 | ## Tutorials
6 | Jupyter notebooks demonstrating general concepts and workflows
7 |
8 | [Preparing and loading your data](tutorials/tutorial_01_preparing_data.ipynb)
9 |
10 | [Training a neural network on QM9
11 | ](tutorials/tutorial_02_qm9.ipynb)
12 |
13 | [Training a model on forces and energies](tutorials/tutorial_03_force_models.ipynb)
14 |
15 | [Molecular dynamics in SchNetPack](tutorials/tutorial_04_molecular_dynamics.ipynb)
16 |
17 | [Force fields for materials](tutorials/tutorial_05_materials.ipynb)
18 |
19 |
20 | ## How-To
21 | Short notebooks showing a particular use-case or functionality
22 |
23 | [Batch-wise Structure Relaxation
24 | ](howtos/howto_batchwise_relaxations.ipynb)
25 |
--------------------------------------------------------------------------------
/interfaces/lammps/examples/aspirin/aspirin.data:
--------------------------------------------------------------------------------
1 | /home/niklas/phd/code/pair_nequip/test/aspirin.data (written by ASE)
2 |
3 | 21 atoms
4 | 3 atom types
5 | 0.0 10 xlo xhi
6 | 0.0 10 ylo yhi
7 | 0.0 10 zlo zhi
8 |
9 |
10 | Atoms
11 |
12 | 1 1 7.1344888199999996 4.0156389499999996 4.8047821099999997
13 | 2 1 5.7626438100000001 5.9594139500000001 3.3200710999999998
14 | 3 1 7.6603448400000005 4.5920739499999996 3.6926960900000001
15 | 4 1 6.9103168200000002 5.3939659600000001 2.8529801400000001
16 | 5 1 1.9698097699999999 6.4954049600000001 5.7196621299999997
17 | 6 1 5.8494248400000002 4.4491289299999996 5.2843751000000001
18 | 7 1 5.2384468200000001 5.4735059399999999 4.5955781
19 | 8 3 5.8978958099999996 2.72356796 6.7300610499999998
20 | 9 3 2.6165478200000001 5.4177789399999998 3.5371431099999997
21 | 10 3 4.5237988199999997 4.4709129299999999 7.3392591500000002
22 | 11 1 5.3929918099999998 3.8097629500000001 6.5379821099999997
23 | 12 1 2.8770148799999999 5.9517599299999997 4.6022951000000001
24 | 13 3 4.19533384 6.2862459399999997 5.1105090999999998
25 | 14 2 4.5061968300000004 3.8132079800000001 8.0959742099999996
26 | 15 2 7.55473471 3.1975029699999999 5.3921310900000003
27 | 16 2 5.3306898199999999 6.8557109799999996 2.65473604
28 | 17 2 8.8037939099999996 4.5062809599999998 3.5437971399999997
29 | 18 2 7.2311408500000001 5.55718595 1.8758501999999999
30 | 19 2 2.2910697500000001 7.4846579999999996 5.9269280999999996
31 | 20 2 0.86951685000000012 6.4821670099999995 5.4312661000000002
32 | 21 2 2.12585187 6.0032089900000001 6.6994850599999998
33 |
--------------------------------------------------------------------------------
/interfaces/lammps/examples/aspirin/aspirin_md.in:
--------------------------------------------------------------------------------
1 | units real
2 | atom_style atomic
3 | newton off
4 | thermo 1
5 | dump mydmp all atom 10 dump.lammpstrj
6 | boundary s s s
7 | read_data aspirin.data
8 | pair_style schnetpack
9 | pair_coeff * * deployed_model 6 1 8
10 | mass 1 12.0
11 | mass 2 1.0
12 | mass 3 16.0
13 | neighbor 1.0 bin
14 | neigh_modify delay 0 every 1 check no
15 | fix 1 all nve
16 | timestep 0.5
17 | compute atomicenergies all pe/atom
18 | compute totalatomicenergy all reduce sum c_atomicenergies
19 | thermo_style custom step time temp pe c_totalatomicenergy etotal press spcpu cpuremain
20 | run 5000
21 | print $(0.000001 * pe) file pe.dat
22 | print $(0.000001 * c_totalatomicenergy) file totalatomicenergy.dat
23 | write_dump all custom output.dump id type x y z fx fy fz c_atomicenergies modify format float %20.15g
24 |
--------------------------------------------------------------------------------
/interfaces/lammps/examples/aspirin/best_model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/atomistic-machine-learning/schnetpack/601c8edb6cdf35af0b2ebf27b4b9e39386b79fb9/interfaces/lammps/examples/aspirin/best_model
--------------------------------------------------------------------------------
/interfaces/lammps/pair_schnetpack.h:
--------------------------------------------------------------------------------
1 | /* ----------------------------------------------------------------------
2 | References:
3 |
4 | .. [#pair_nequip] https://github.com/mir-group/pair_nequip
5 | .. [#lammps] https://github.com/lammps/lammps
6 |
7 | ------------------------------------------------------------------------- */
8 |
9 | #ifdef PAIR_CLASS
10 |
11 | PairStyle(schnetpack,PairSCHNETPACK)
12 |
13 | #else
14 |
15 | #ifndef LMP_PAIR_SCHNETPACK_H
16 | #define LMP_PAIR_SCHNETPACK_H
17 |
18 | #include "pair.h"
19 |
20 | #include
21 |
22 | namespace LAMMPS_NS {
23 |
24 | class PairSCHNETPACK : public Pair {
25 | public:
26 | PairSCHNETPACK(class LAMMPS *);
27 | virtual ~PairSCHNETPACK();
28 | virtual void compute(int, int);
29 | void settings(int, char **);
30 | virtual void coeff(int, char **);
31 | virtual double init_one(int, int);
32 | virtual void init_style();
33 | void allocate();
34 |
35 | double cutoff;
36 | torch::jit::script::Module model;
37 | torch::Device device = torch::kCPU;
38 |
39 | protected:
40 | int * type_mapper;
41 | int debug_mode = 0;
42 |
43 | };
44 |
45 | }
46 |
47 | #endif
48 | #endif
49 |
--------------------------------------------------------------------------------
/interfaces/lammps/patch_lammps.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# usage: patch_lammps.sh [-e] /path/to/lammps/
#
# Patches a LAMMPS source tree for the SchNetPack pair style:
#   * links libtorch into the cmake build (cmake/CMakeLists.txt)
#   * copies (or, with -e, symlinks) pair_schnetpack.{cpp,h} into src/
#
# References:
#
# .. [#pair_nequip] https://github.com/mir-group/pair_nequip


SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
echo "$SCRIPT_DIR"


do_e_mode=false

while getopts "he" option; do
    case $option in
        e)
            do_e_mode=true;;
        h) # display help
            echo "patch_lammps.sh [-e] /path/to/lammps/"
            exit;;
    esac
done

# https://stackoverflow.com/a/9472919
shift $(($OPTIND - 1))
lammps_dir=$1

if [ "$lammps_dir" = "" ];
then
    echo "lammps_dir must be provided"
    exit 1
fi

if [ ! -d "$lammps_dir" ]
then
    echo "$lammps_dir doesn't exist"
    exit 1
fi

if [ ! -d "$lammps_dir/cmake" ]
then
    echo "$lammps_dir doesn't look like a LAMMPS source directory"
    exit 1
fi

# Check that we are run from the directory containing the pair style sources.
# BUGFIX: the message is single-quoted now; inside double quotes the backticks
# were command substitutions and tried to EXECUTE the named files.
if [ ! -f pair_schnetpack.cpp ]; then
    echo 'Please run `patch_lammps.sh` from the `pair_schnetpack.cpp` root directory.'
    exit 1
fi

echo "Updating CMakeLists.txt..."
# Check for double-patch: the find_package line is our marker.
if grep -q "find_package(Torch REQUIRED)" "$lammps_dir/cmake/CMakeLists.txt"
then
    echo "This LAMMPS installation _seems_ to already have been patched. CMakeLists.txt file not modified."
else
    # Update CMakeLists.txt: libtorch requires at least C++14.
    sed -i "s/set(CMAKE_CXX_STANDARD 11)/set(CMAKE_CXX_STANDARD 14)/" "$lammps_dir/cmake/CMakeLists.txt"

    # Add libtorch (heredoc delimiter is quoted so ${...} is appended verbatim).
    cat >> "$lammps_dir/cmake/CMakeLists.txt" << "EOF2"

find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
target_link_libraries(lammps PUBLIC "${TORCH_LIBRARIES}")
EOF2

fi

# Check if files need to be copied to the lammps directory
if [ ! -f "$lammps_dir/src/pair_schnetpack.cpp" ]; then
    if [ "$do_e_mode" = true ]
    then
        echo "Making source symlinks (-e)..."
        for file in *.{cpp,h}; do
            ln -s "$(realpath -s "$file")" "$lammps_dir/src/$file"
        done
    else
        echo "Copying files..."
        for file in *.{cpp,h}; do
            cp "$file" "$lammps_dir/src/$file"
        done
    fi
else
    echo "pair_schnetpack.cpp file already exists. No files copied."
fi


echo "Done!"
93 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "schnetpack"
7 | # Declaring that the version will be read dynamically.
8 | # This helps us not having to specify the version twice.
9 | # We have to specify the version only in schnetpack.__version__
10 | # see [tool.setuptools.dynamic] below
11 | dynamic = ["version"]
12 | authors = [
13 | { name = "Kristof T. Schuett" },
14 | { name = "Michael Gastegger" },
15 | { name = "Stefaan Hessmann" },
16 | { name = "Niklas Gebauer" },
17 | { name = "Jonas Lederer" }
18 | ]
19 | description = "SchNetPack - Deep Neural Networks for Atomistic Systems"
20 | readme = "README.md"
21 | license = { file="LICENSE" }
22 | requires-python = ">=3.12,<3.13"
23 | dependencies = [
24 | "numpy>=2.0.0",
25 | "sympy>=1.13",
26 | "ase>=3.21",
27 | "h5py",
28 | "pyyaml",
29 | "hydra-core>=1.1.0",
30 | "torch>=2.5.0",
31 | "pytorch_lightning>=2.0.0",
32 | "torchmetrics",
33 | "hydra-colorlog>=1.1.0",
34 | "rich",
35 | "fasteners",
36 | "dirsync",
37 | "torch-ema",
38 | "matscipy>=1.1.0",
39 | "tensorboard>=2.17.1",
40 | "tensorboardX>=2.6.2.2",
41 | "tqdm",
42 | "pre-commit",
43 | "black",
44 | "protobuf",
45 | "progressbar"
46 | ]
47 |
48 | [project.optional-dependencies]
49 | test = ["pytest", "pytest-datadir", "pytest-benchmark"]
50 |
51 | [tool.setuptools]
52 | package-dir = { "" = "src" }
53 | script-files = [
54 | "src/scripts/spkconvert",
55 | "src/scripts/spktrain",
56 | "src/scripts/spkpredict",
57 | "src/scripts/spkmd",
58 | "src/scripts/spkdeploy",
59 | ]
60 |
61 | [tool.setuptools.dynamic]
62 | version = {attr = "schnetpack.__version__"}
63 |
64 | [tool.setuptools.packages.find]
65 | where = ["src"]
66 |
67 | [tool.setuptools.package-data]
68 | schnetpack = ["configs/**/*.yaml"]
69 |
--------------------------------------------------------------------------------
/readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yaml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | # Required
6 | version: 2
7 |
8 | # Set the version of Python and other tools you might need
9 | build:
10 | os: "ubuntu-22.04"
11 | tools:
12 | python: "3.12"
13 |
14 | # Build documentation in the docs/ directory with Sphinx
15 | sphinx:
16 | configuration: docs/conf.py
17 |
18 | # Explicitly set the version of Python and its requirements
19 | python:
20 | install:
21 | - requirements: docs/sphinx-requirements.txt
22 | - path: .
23 |
--------------------------------------------------------------------------------
/src/schnetpack/__init__.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | warnings.filterwarnings("ignore", category=DeprecationWarning, module="tensorboard")
4 |
5 | from schnetpack import transform
6 | from schnetpack import properties
7 | from schnetpack import data
8 | from schnetpack import datasets
9 | from schnetpack import atomistic
10 | from schnetpack import representation
11 | from schnetpack import interfaces
12 | from schnetpack import nn
13 | from schnetpack import train
14 | from schnetpack import model
15 | from schnetpack.units import *
16 | from schnetpack.task import *
17 | from schnetpack import md
18 |
19 |
20 | __version__ = "2.1.1"
21 |
--------------------------------------------------------------------------------
/src/schnetpack/atomistic/__init__.py:
--------------------------------------------------------------------------------
1 | from .atomwise import *
2 | from .response import *
3 | from .distances import *
4 | from .nuclear_repulsion import *
5 | from .electrostatic import *
6 | from .aggregation import *
7 | from .external_fields import *
8 |
--------------------------------------------------------------------------------
/src/schnetpack/atomistic/aggregation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from typing import Dict, List
5 |
6 | __all__ = ["Aggregation"]
7 |
8 |
class Aggregation(nn.Module):
    """
    Sum several predicted properties into a single output variable.

    Args:
        keys (list(str)): Names of the properties to be summed.
        output_key (str): Name under which the summed result is stored.
    """

    def __init__(self, keys: List[str], output_key: str = "y"):
        super(Aggregation, self).__init__()

        self.keys: List[str] = list(keys)
        self.output_key = output_key
        self.model_outputs = [output_key]

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # Gather every contribution and reduce along the stacking axis.
        contributions = [inputs[name] for name in self.keys]
        inputs[self.output_key] = torch.stack(contributions).sum(dim=0)
        return inputs
29 |
--------------------------------------------------------------------------------
/src/schnetpack/atomistic/distances.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | import schnetpack.properties as properties
7 |
8 |
class PairwiseDistances(nn.Module):
    """
    Compute interatomic distance vectors from the pair indices provided by a
    neighbor list transform and store them under ``properties.Rij``.
    """

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        positions = inputs[properties.R]
        cell_offsets = inputs[properties.offsets]

        # Cast neighbor indices to int64 to avoid indexing errors on Windows.
        neighbors_i = inputs[properties.idx_i].long()
        neighbors_j = inputs[properties.idx_j].long()

        inputs[properties.Rij] = (
            positions[neighbors_j] - positions[neighbors_i] + cell_offsets
        )
        return inputs
27 |
28 |
class FilterShortRange(nn.Module):
    """
    Separate short-range from all supplied distances.

    Pairs within ``short_range_cutoff`` are written back under the default
    keys (properties.Rij, properties.idx_i, properties.idx_j), while the
    unfiltered pair lists remain available for long-range terms under
    (properties.Rij_lr, properties.idx_i_lr, properties.idx_j_lr).
    """

    def __init__(self, short_range_cutoff: float):
        super().__init__()
        self.short_range_cutoff = short_range_cutoff

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        full_idx_i = inputs[properties.idx_i]
        full_idx_j = inputs[properties.idx_j]
        full_Rij = inputs[properties.Rij]

        # Select pairs whose distance does not exceed the short-range cutoff.
        distances = torch.norm(full_Rij, dim=-1)
        within = torch.nonzero(distances <= self.short_range_cutoff).squeeze(-1)

        # Keep the complete neighborhood under the long-range keys.
        inputs[properties.Rij_lr] = full_Rij
        inputs[properties.idx_i_lr] = full_idx_i
        inputs[properties.idx_j_lr] = full_idx_j

        # Overwrite the default keys with the filtered short-range pairs.
        inputs[properties.Rij] = full_Rij[within]
        inputs[properties.idx_i] = full_idx_i[within]
        inputs[properties.idx_j] = full_idx_j[within]
        return inputs
58 |
--------------------------------------------------------------------------------
/src/schnetpack/atomistic/external_fields.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional, List
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | import schnetpack.properties as properties
7 | from schnetpack.utils import required_fields_from_properties
8 |
9 | __all__ = ["StaticExternalFields"]
10 |
11 |
class StaticExternalFields(nn.Module):
    """
    Input routine for setting up dummy external fields in response models.
    Checks if fields are present in input and sets zero-valued dummy fields
    otherwise (with ``requires_grad=True`` so derivatives w.r.t. the fields
    can be taken).

    Args:
        external_fields (list(str)): List of required external fields. Either this or the requested response
            properties needs to be specified.
        response_properties (list(str)): List of requested response properties. If this is not None, it is used to
            determine the required external fields and takes precedence over ``external_fields``.
    """

    def __init__(
        self,
        external_fields: Optional[List[str]] = None,
        response_properties: Optional[List[str]] = None,
    ):
        super(StaticExternalFields, self).__init__()

        # BUGFIX/idiom: default was a mutable list literal ([]); use None and
        # allocate a fresh list per call to avoid the shared-default pitfall.
        if external_fields is None:
            external_fields = []

        # Requested response properties take precedence when provided.
        if response_properties is not None:
            external_fields = required_fields_from_properties(response_properties)

        # Deduplicate the field names; annotation kept for TorchScript.
        self.external_fields: List[str] = list(set(external_fields))

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        n_atoms = inputs[properties.n_atoms]
        n_molecules = n_atoms.shape[0]

        # Fields passed to interaction computation (cast to batch structure):
        # one differentiable zero 3-vector per molecule.
        for field in self.external_fields:
            # Store all fields in the input dictionary, which will be returned
            # for derivatives.
            if field not in inputs:
                inputs[field] = torch.zeros(
                    n_molecules,
                    3,
                    device=n_atoms.device,
                    dtype=inputs[properties.R].dtype,
                    requires_grad=True,
                )

        # Initialize nuclear magnetic moments for magnetic fields: one
        # 3-vector per atom, matching the shape of the positions tensor.
        if properties.magnetic_field in self.external_fields:
            if properties.nuclear_magnetic_moments not in inputs:
                inputs[properties.nuclear_magnetic_moments] = torch.zeros_like(
                    inputs[properties.R], requires_grad=True
                )

        return inputs
60 |
--------------------------------------------------------------------------------
/src/schnetpack/atomistic/nuclear_repulsion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from typing import Union, Callable, Dict, Optional
5 |
6 | import schnetpack.properties as properties
7 | import schnetpack.nn as snn
8 | import schnetpack.units as spk_units
9 |
10 | __all__ = ["ZBLRepulsionEnergy"]
11 |
12 |
class ZBLRepulsionEnergy(nn.Module):
    """
    Computes a Ziegler-Biersack-Littmark style repulsion energy: a screened
    Coulomb repulsion between nuclei, summed per molecule and stored in the
    input dictionary under ``output_key``.

    Args:
        energy_unit (str/float): Energy unit.
        position_unit (str/float): Unit used for distances.
        output_key (str): Key to which results will be stored
        trainable (bool): If set to true, ZBL parameters will be optimized during training (default=True)
        cutoff_fn (Callable): Apply a cutoff function to the interatomic distances.

    References:
        .. [#Cutoff] Ebert, D. S.; Musgrave, F. K.; Peachey, D.; Perlin, K.; Worley, S.
           Texturing & Modeling: A Procedural Approach;
           Morgan Kaufmann, 2003
        .. [#ZBL]
           https://docs.lammps.org/pair_zbl.html
    """

    def __init__(
        self,
        energy_unit: Union[str, float],
        position_unit: Union[str, float],
        output_key: str,
        trainable: bool = True,
        cutoff_fn: Optional[Callable] = None,
    ):
        super(ZBLRepulsionEnergy, self).__init__()

        # Conversion factors from atomic units (Hartree, Bohr) to the
        # requested units; ke acts as the Coulomb prefactor in those units.
        energy_units = spk_units.convert_units("Ha", energy_unit)
        position_units = spk_units.convert_units("Bohr", position_unit)
        ke = energy_units * position_units
        self.register_buffer("ke", torch.tensor(ke))

        self.cutoff_fn = cutoff_fn
        self.output_key = output_key

        # Basic ZBL parameters (in atomic units)
        # Since all quantities have a predefined sign, they are initialized to the inverse softplus and a softplus
        # function is applied in the forward pass to guarantee the correct sign during training
        a_div = snn.softplus_inverse(
            torch.tensor([1.0 / (position_units * 0.8854)])
        )  # in this way, distances can be used directly
        a_pow = snn.softplus_inverse(torch.tensor([0.23]))
        # Standard ZBL screening-function exponents and coefficients.
        exponents = snn.softplus_inverse(
            torch.tensor([3.19980, 0.94229, 0.40290, 0.20162])
        )
        coefficients = snn.softplus_inverse(
            torch.tensor([0.18175, 0.50986, 0.28022, 0.02817])
        )

        # Initialize network parameters (frozen if trainable=False).
        self.a_pow = nn.Parameter(a_pow, requires_grad=trainable)
        self.a_div = nn.Parameter(a_div, requires_grad=trainable)
        self.coefficients = nn.Parameter(coefficients, requires_grad=trainable)
        self.exponents = nn.Parameter(exponents, requires_grad=trainable)

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # Atomic numbers, pair displacement vectors and pair/molecule indices
        # are expected from the neighbor-list preprocessing.
        z = inputs[properties.Z]
        r_ij = inputs[properties.Rij]
        d_ij = torch.norm(r_ij, dim=1)  # pairwise distances
        idx_i = inputs[properties.idx_i]
        idx_j = inputs[properties.idx_j]
        idx_m = inputs[properties.idx_m]

        n_atoms = z.shape[0]
        # NOTE(review): assumes idx_m is sorted ascending so its last entry is
        # the highest molecule index — confirm against the batching transform.
        n_molecules = int(idx_m[-1]) + 1

        # Construct screening function
        a = z ** F.softplus(self.a_pow)
        a_ij = (a[idx_i] + a[idx_j]) * F.softplus(self.a_div)
        # Get exponents and coefficients, normalize the latter
        # (L1 normalization keeps the four coefficients summing to one)
        exponents = a_ij[..., None] * F.softplus(self.exponents)[None, ...]
        coefficients = F.softplus(self.coefficients)[None, ...]
        coefficients = F.normalize(coefficients, p=1.0, dim=1)

        # Sum of decaying exponentials per pair: the ZBL screening factor.
        screening = torch.sum(
            coefficients * torch.exp(-exponents * d_ij[:, None]), dim=1
        )

        # Compute nuclear repulsion (bare Coulomb term Z_i * Z_j / d_ij)
        repulsion = (z[idx_i] * z[idx_j]) / d_ij

        # Apply cutoff if requested
        if self.cutoff_fn is not None:
            f_cut = self.cutoff_fn(d_ij)
            repulsion = repulsion * f_cut

        # Compute ZBL energy: reduce pairs to atoms, then atoms to molecules.
        y_zbl = snn.scatter_add(repulsion * screening, idx_i, dim_size=n_atoms)
        y_zbl = snn.scatter_add(y_zbl, idx_m, dim_size=n_molecules)
        # Factor 0.5 — presumably compensates double counting, as each pair
        # appears in both (i,j) and (j,i) orientation; confirm with the
        # neighbor-list convention used upstream.
        y_zbl = 0.5 * self.ke * y_zbl

        inputs[self.output_key] = y_zbl

        return inputs
109 |
--------------------------------------------------------------------------------
/src/schnetpack/configs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/atomistic-machine-learning/schnetpack/601c8edb6cdf35af0b2ebf27b4b9e39386b79fb9/src/schnetpack/configs/__init__.py
--------------------------------------------------------------------------------
/src/schnetpack/configs/callbacks/checkpoint.yaml:
--------------------------------------------------------------------------------
1 | model_checkpoint:
2 | _target_: schnetpack.train.ModelCheckpoint
3 | monitor: "val_loss" # name of the logged metric which determines when model is improving
4 | save_top_k: 1 # save k best models (determined by above metric)
5 |   save_last: True # additionally always save model from last epoch
6 | mode: "min" # can be "max" or "min"
7 | verbose: False
8 | dirpath: 'checkpoints/'
9 | filename: '{epoch:02d}'
10 | model_path: ${globals.model_path}
--------------------------------------------------------------------------------
/src/schnetpack/configs/callbacks/earlystopping.yaml:
--------------------------------------------------------------------------------
1 | early_stopping:
2 | _target_: pytorch_lightning.callbacks.EarlyStopping
3 | monitor: "val_loss" # name of the logged metric which determines when model is improving
4 | patience: 200 # how many epochs of not improving until training stops
5 | mode: "min" # can be "max" or "min"
6 | min_delta: 0.0 # minimum change in the monitored metric needed to qualify as an improvement
7 | check_on_train_epoch_end: False
--------------------------------------------------------------------------------
/src/schnetpack/configs/callbacks/ema.yaml:
--------------------------------------------------------------------------------
1 | ema:
2 | _target_: schnetpack.train.ExponentialMovingAverage
3 | decay: 0.995
--------------------------------------------------------------------------------
/src/schnetpack/configs/callbacks/lrmonitor.yaml:
--------------------------------------------------------------------------------
1 | lr_monitor:
2 | _target_: pytorch_lightning.callbacks.LearningRateMonitor
3 | logging_interval: epoch
--------------------------------------------------------------------------------
/src/schnetpack/configs/data/ani1.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - custom
3 |
4 | _target_: schnetpack.datasets.ANI1
5 |
6 | datapath: ${run.data_dir}/ani1.db # data_dir is specified in train.yaml
7 | batch_size: 32
8 | num_train: 10000000
9 | num_val: 100000
10 | num_heavy_atoms: 8
11 | high_energies: False
12 |
13 | # convert to typically used units
14 | distance_unit: Ang
15 | property_units:
16 | energy: eV
--------------------------------------------------------------------------------
/src/schnetpack/configs/data/custom.yaml:
--------------------------------------------------------------------------------
1 | _target_: schnetpack.data.AtomsDataModule
2 |
3 | datapath: ???
4 | data_workdir: null
5 | batch_size: 10
6 | num_train: ???
7 | num_val: ???
8 | num_test: null
9 | num_workers: 8
10 | num_val_workers: null
11 | num_test_workers: null
12 | train_sampler_cls: null
--------------------------------------------------------------------------------
/src/schnetpack/configs/data/iso17.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - custom
3 |
4 | _target_: schnetpack.datasets.ISO17
5 |
6 | datapath: ${run.data_dir}/${data.folder}.db # data_dir is specified in train.yaml
7 | folder: reference
8 | batch_size: 32
9 | num_train: 0.9
10 | num_val: 0.1
11 |
--------------------------------------------------------------------------------
/src/schnetpack/configs/data/materials_project.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - custom
3 |
4 | _target_: schnetpack.datasets.MaterialsProject
5 |
6 | datapath: ${run.data_dir}/materials_project.db # data_dir is specified in train.yaml
7 | batch_size: 32
8 | num_train: 60000
9 | num_val: 2000
10 | apikey: ???
--------------------------------------------------------------------------------
/src/schnetpack/configs/data/md17.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - custom
3 |
4 | _target_: schnetpack.datasets.MD17
5 |
6 | datapath: ${run.data_dir}/${data.molecule}.db # data_dir is specified in train.yaml
7 | molecule: aspirin
8 | batch_size: 10
9 | num_train: 950
10 | num_val: 50
11 |
--------------------------------------------------------------------------------
/src/schnetpack/configs/data/md22.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - custom
3 |
4 | _target_: schnetpack.datasets.MD22
5 |
6 | datapath: ${run.data_dir}/${data.molecule}.db # data_dir is specified in train.yaml
7 | molecule: Ac-Ala3-NHMe
8 | batch_size: 10
9 | num_train: 5700
10 | num_val: 300
11 |
--------------------------------------------------------------------------------
/src/schnetpack/configs/data/omdb.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - custom
3 |
4 | _target_: schnetpack.datasets.OrganicMaterialsDatabase
5 |
6 | datapath: ${run.data_dir}/omdb.db # data_dir is specified in train.yaml
7 | batch_size: 32
8 | num_train: 0.8
9 | num_val: 0.1
10 | raw_path: null
--------------------------------------------------------------------------------
/src/schnetpack/configs/data/qm7x.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - custom
3 |
4 | _target_: schnetpack.datasets.QM7X
5 |
6 | datapath: ${run.data_dir}/qm7x.db # data_dir is specified in train.yaml
7 | batch_size: 100
8 | num_train: 5550
9 | num_val: 700
--------------------------------------------------------------------------------
/src/schnetpack/configs/data/qm9.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - custom
3 |
4 | _target_: schnetpack.datasets.QM9
5 |
6 | datapath: ${run.data_dir}/qm9.db # data_dir is specified in train.yaml
7 | batch_size: 100
8 | num_train: 110000
9 | num_val: 10000
10 | remove_uncharacterized: True
11 |
12 | # convert to typically used units
13 | distance_unit: Ang
14 | property_units:
15 | energy_U0: eV
16 | energy_U: eV
17 | enthalpy_H: eV
18 | free_energy: eV
19 | homo: eV
20 | lumo: eV
21 | gap: eV
22 | zpve: eV
--------------------------------------------------------------------------------
/src/schnetpack/configs/data/rmd17.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - custom
3 |
4 | _target_: schnetpack.datasets.rMD17
5 |
6 | datapath: ${run.data_dir}/rmd17_${data.molecule}.db # data_dir is specified in train.yaml
7 | molecule: aspirin
8 | batch_size: 10
9 | num_train: 950
10 | num_val: 50
11 | split_id: null
--------------------------------------------------------------------------------
/src/schnetpack/configs/data/sampler/stratified_property.yaml:
--------------------------------------------------------------------------------
1 | # @package data
2 | train_sampler_cls: schnetpack.data.sampler.StratifiedSampler
3 | train_sampler_args:
4 | partition_criterion:
5 | _target_: schnetpack.data.sampler.PropertyCriterion
6 | property_key: ${globals.property}
7 | num_bins: 10
8 | replacement: True
--------------------------------------------------------------------------------
/src/schnetpack/configs/experiment/md17.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - override /model: nnp
5 | - override /data: md17
6 |
7 | run:
8 | experiment: md17_${data.molecule}
9 |
10 | globals:
11 | cutoff: 5.
12 | lr: 1e-3
13 | energy_key: energy
14 | forces_key: forces
15 |
16 | data:
17 | distance_unit: Ang
18 | property_units:
19 | energy: kcal/mol
20 | forces: kcal/mol/Ang
21 | transforms:
22 | - _target_: schnetpack.transform.SubtractCenterOfMass
23 | - _target_: schnetpack.transform.RemoveOffsets
24 | property: energy
25 | remove_mean: True
26 | - _target_: schnetpack.transform.MatScipyNeighborList
27 | cutoff: ${globals.cutoff}
28 | - _target_: schnetpack.transform.CastTo32
29 |
30 | model:
31 | output_modules:
32 | - _target_: schnetpack.atomistic.Atomwise
33 | output_key: ${globals.energy_key}
34 | n_in: ${model.representation.n_atom_basis}
35 | aggregation_mode: sum
36 | - _target_: schnetpack.atomistic.Forces
37 | energy_key: ${globals.energy_key}
38 | force_key: ${globals.forces_key}
39 | postprocessors:
40 | - _target_: schnetpack.transform.CastTo64
41 | - _target_: schnetpack.transform.AddOffsets
42 | property: energy
43 | add_mean: True
44 |
45 | task:
46 | outputs:
47 | - _target_: schnetpack.task.ModelOutput
48 | name: ${globals.energy_key}
49 | loss_fn:
50 | _target_: torch.nn.MSELoss
51 | metrics:
52 | mae:
53 | _target_: torchmetrics.regression.MeanAbsoluteError
54 | rmse:
55 | _target_: torchmetrics.regression.MeanSquaredError
56 | squared: False
57 | loss_weight: 0.01
58 | - _target_: schnetpack.task.ModelOutput
59 | name: ${globals.forces_key}
60 | loss_fn:
61 | _target_: torch.nn.MSELoss
62 | metrics:
63 | mae:
64 | _target_: torchmetrics.regression.MeanAbsoluteError
65 | rmse:
66 | _target_: torchmetrics.regression.MeanSquaredError
67 | squared: False
68 | loss_weight: 0.99
--------------------------------------------------------------------------------
/src/schnetpack/configs/experiment/qm9_atomwise.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - override /model: nnp
5 | - override /data: qm9
6 |
7 | run:
8 | experiment: qm9_${globals.property}
9 |
10 | globals:
11 | cutoff: 5.
12 | lr: 5e-4
13 | property: energy_U0
14 | aggregation: sum
15 |
16 | data:
17 | transforms:
18 | - _target_: schnetpack.transform.SubtractCenterOfMass
19 | - _target_: schnetpack.transform.RemoveOffsets
20 | property: ${globals.property}
21 | remove_atomrefs: True
22 | remove_mean: True
23 | - _target_: schnetpack.transform.MatScipyNeighborList
24 | cutoff: ${globals.cutoff}
25 | - _target_: schnetpack.transform.CastTo32
26 |
27 | model:
28 | output_modules:
29 | - _target_: schnetpack.atomistic.Atomwise
30 | output_key: ${globals.property}
31 | n_in: ${model.representation.n_atom_basis}
32 | aggregation_mode: ${globals.aggregation}
33 | postprocessors:
34 | - _target_: schnetpack.transform.CastTo64
35 | - _target_: schnetpack.transform.AddOffsets
36 | property: ${globals.property}
37 | add_mean: True
38 | add_atomrefs: True
39 |
40 | task:
41 | outputs:
42 | - _target_: schnetpack.task.ModelOutput
43 | name: ${globals.property}
44 | loss_fn:
45 | _target_: torch.nn.MSELoss
46 | metrics:
47 | mae:
48 | _target_: torchmetrics.regression.MeanAbsoluteError
49 | rmse:
50 | _target_: torchmetrics.regression.MeanSquaredError
51 | squared: False
52 | loss_weight: 1.
--------------------------------------------------------------------------------
/src/schnetpack/configs/experiment/qm9_dipole.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - override /model: nnp
5 | - override /data: qm9
6 |
7 | run:
8 | experiment: qm9_${globals.property}
9 |
10 | globals:
11 | cutoff: 5.
12 | lr: 5e-4
13 | property: dipole_moment
14 |
15 | data:
16 | transforms:
17 | - _target_: schnetpack.transform.SubtractCenterOfMass
18 | - _target_: schnetpack.transform.MatScipyNeighborList
19 | cutoff: ${globals.cutoff}
20 | - _target_: schnetpack.transform.CastTo32
21 |
22 | model:
23 | output_modules:
24 | - _target_: schnetpack.atomistic.DipoleMoment
25 | dipole_key: ${globals.property}
26 | n_in: ${model.representation.n_atom_basis}
27 | predict_magnitude: True
28 | use_vector_representation: False
29 | postprocessors:
30 | - _target_: schnetpack.transform.CastTo64
31 |
32 | task:
33 | outputs:
34 | - _target_: schnetpack.task.ModelOutput
35 | name: ${globals.property}
36 | loss_fn:
37 | _target_: torch.nn.MSELoss
38 | metrics:
39 | mae:
40 | _target_: torchmetrics.regression.MeanAbsoluteError
41 | rmse:
42 | _target_: torchmetrics.regression.MeanSquaredError
43 | squared: False
44 | loss_weight: 1.
--------------------------------------------------------------------------------
/src/schnetpack/configs/experiment/response.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - override /model: nnp
5 | - override /model/representation: field_schnet
6 | - override /data: custom
7 |
8 | run:
9 | experiment: response
10 |
11 | globals:
12 | cutoff: 9.448
13 | lr: 1e-4
14 | energy_key: energy
15 | forces_key: forces
16 | shielding_key: shielding
17 | shielding_elements: [ 1,6,8 ]
18 | response_properties:
19 | - forces
20 | - dipole_moment
21 | - polarizability
22 | - ${globals.shielding_key}
23 |
24 | data:
25 | distance_unit: 1.0
26 | batch_size: 10
27 | transforms:
28 | - _target_: schnetpack.transform.SubtractCenterOfMass
29 | - _target_: schnetpack.transform.RemoveOffsets
30 | property: ${globals.energy_key}
31 | remove_mean: true
32 | - _target_: schnetpack.transform.MatScipyNeighborList
33 | cutoff: ${globals.cutoff}
34 | - _target_: schnetpack.transform.CastTo32
35 | - _target_: schnetpack.transform.SplitShielding
36 | shielding_key: ${globals.shielding_key}
37 | atomic_numbers: ${globals.shielding_elements}
38 |
39 | model:
40 | input_modules:
41 | - _target_: schnetpack.atomistic.PairwiseDistances
42 | - _target_: schnetpack.atomistic.StaticExternalFields
43 | response_properties: ${globals.response_properties}
44 | output_modules:
45 | - _target_: schnetpack.atomistic.Atomwise
46 | output_key: ${globals.energy_key}
47 | n_in: ${model.representation.n_atom_basis}
48 | aggregation_mode: sum
49 | - _target_: schnetpack.transform.ScaleProperty
50 | input_key: ${globals.energy_key}
51 | output_key: ${globals.energy_key}
52 | - _target_: schnetpack.atomistic.Response
53 | energy_key: ${globals.energy_key}
54 | response_properties: ${globals.response_properties}
55 | - _target_: schnetpack.transform.SplitShielding
56 | shielding_key: ${globals.shielding_key}
57 | atomic_numbers: ${globals.shielding_elements}
58 | postprocessors:
59 | - _target_: schnetpack.transform.CastTo64
60 | - _target_: schnetpack.transform.AddOffsets
61 | property: energy
62 | add_mean: True
63 |
64 | task:
65 | scheduler_args:
66 | mode: min
67 | factor: 0.5
68 | patience: 50
69 | min_lr: 1e-6
70 | smoothing_factor: 0.0
71 | outputs:
72 | - _target_: schnetpack.task.ModelOutput
73 | name: ${globals.energy_key}
74 | loss_fn:
75 | _target_: torch.nn.MSELoss
76 | metrics:
77 | mae:
78 | _target_: torchmetrics.regression.MeanAbsoluteError
79 | rmse:
80 | _target_: torchmetrics.regression.MeanSquaredError
81 | squared: false
82 | loss_weight: 1.00
83 | - _target_: schnetpack.task.ModelOutput
84 | name: forces
85 | loss_fn:
86 | _target_: torch.nn.MSELoss
87 | metrics:
88 | mae:
89 | _target_: torchmetrics.regression.MeanAbsoluteError
90 | rmse:
91 | _target_: torchmetrics.regression.MeanSquaredError
92 | squared: false
93 | loss_weight: 5.0
94 | - _target_: schnetpack.task.ModelOutput
95 | name: dipole_moment
96 | loss_fn:
97 | _target_: torch.nn.MSELoss
98 | metrics:
99 | mae:
100 | _target_: torchmetrics.regression.MeanAbsoluteError
101 | rmse:
102 | _target_: torchmetrics.regression.MeanSquaredError
103 | squared: false
104 | loss_weight: 0.01
105 | - _target_: schnetpack.task.ModelOutput
106 | name: polarizability
107 | loss_fn:
108 | _target_: torch.nn.MSELoss
109 | metrics:
110 | mae:
111 | _target_: torchmetrics.regression.MeanAbsoluteError
112 | rmse:
113 | _target_: torchmetrics.regression.MeanSquaredError
114 | squared: false
115 | loss_weight: 0.01
116 | # shielding split by element
117 | - _target_: schnetpack.task.ModelOutput
118 | name: ${globals.shielding_key}_1
119 | loss_fn:
120 | _target_: torch.nn.MSELoss
121 | metrics:
122 | mae:
123 | _target_: torchmetrics.regression.MeanAbsoluteError
124 | rmse:
125 | _target_: torchmetrics.regression.MeanSquaredError
126 | squared: false
127 | mae_iso:
128 | _target_: schnetpack.train.metrics.TensorDiagonalMeanAbsoluteError
129 | mae_aniso:
130 | _target_: schnetpack.train.metrics.TensorDiagonalMeanAbsoluteError
131 | diagonal: false
132 | loss_weight: 0.1
133 | - _target_: schnetpack.task.ModelOutput
134 | name: ${globals.shielding_key}_6
135 | loss_fn:
136 | _target_: torch.nn.MSELoss
137 | metrics:
138 | mae:
139 | _target_: torchmetrics.regression.MeanAbsoluteError
140 | rmse:
141 | _target_: torchmetrics.regression.MeanSquaredError
142 | squared: false
143 | mae_iso:
144 | _target_: schnetpack.train.metrics.TensorDiagonalMeanAbsoluteError
145 | mae_aniso:
146 | _target_: schnetpack.train.metrics.TensorDiagonalMeanAbsoluteError
147 | diagonal: false
148 | loss_weight: 0.004
149 | - _target_: schnetpack.task.ModelOutput
150 | name: ${globals.shielding_key}_8
151 | loss_fn:
152 | _target_: torch.nn.MSELoss
153 | metrics:
154 | mae:
155 | _target_: torchmetrics.regression.MeanAbsoluteError
156 | rmse:
157 | _target_: torchmetrics.regression.MeanSquaredError
158 | squared: false
159 | mae_iso:
160 | _target_: schnetpack.train.metrics.TensorDiagonalMeanAbsoluteError
161 | mae_aniso:
162 | _target_: schnetpack.train.metrics.TensorDiagonalMeanAbsoluteError
163 | diagonal: false
164 | loss_weight: 0.001
165 |
166 |
167 |
--------------------------------------------------------------------------------
/src/schnetpack/configs/experiment/rmd17.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - override /model: nnp
5 | - override /data: rmd17
6 |
7 | run:
8 | experiment: md17_${data.molecule}
9 |
10 | globals:
11 | cutoff: 5.
12 | lr: 1e-3
13 | energy_key: energy
14 | forces_key: forces
15 |
16 | data:
17 | distance_unit: Ang
18 | property_units:
19 | energy: kcal/mol
20 | forces: kcal/mol/Ang
21 | transforms:
22 | - _target_: schnetpack.transform.SubtractCenterOfMass
23 | - _target_: schnetpack.transform.RemoveOffsets
24 | property: energy
25 | remove_mean: True
26 | - _target_: schnetpack.transform.MatScipyNeighborList
27 | cutoff: ${globals.cutoff}
28 | - _target_: schnetpack.transform.CastTo32
29 |
30 | model:
31 | output_modules:
32 | - _target_: schnetpack.atomistic.Atomwise
33 | output_key: ${globals.energy_key}
34 | n_in: ${model.representation.n_atom_basis}
35 | aggregation_mode: sum
36 | - _target_: schnetpack.atomistic.Forces
37 | energy_key: ${globals.energy_key}
38 | force_key: ${globals.forces_key}
39 | postprocessors:
40 | - _target_: schnetpack.transform.CastTo64
41 | - _target_: schnetpack.transform.AddOffsets
42 | property: energy
43 | add_mean: True
44 |
45 | task:
46 | outputs:
47 | - _target_: schnetpack.task.ModelOutput
48 | name: ${globals.energy_key}
49 | loss_fn:
50 | _target_: torch.nn.MSELoss
51 | metrics:
52 | mae:
53 | _target_: torchmetrics.regression.MeanAbsoluteError
54 | rmse:
55 | _target_: torchmetrics.regression.MeanSquaredError
56 | squared: False
57 | loss_weight: 0.01
58 | - _target_: schnetpack.task.ModelOutput
59 | name: ${globals.forces_key}
60 | loss_fn:
61 | _target_: torch.nn.MSELoss
62 | metrics:
63 | mae:
64 | _target_: torchmetrics.regression.MeanAbsoluteError
65 | rmse:
66 | _target_: torchmetrics.regression.MeanSquaredError
67 | squared: False
68 | loss_weight: 0.99
69 |
--------------------------------------------------------------------------------
/src/schnetpack/configs/globals/default_globals.yaml:
--------------------------------------------------------------------------------
1 | model_path: "best_model"
--------------------------------------------------------------------------------
/src/schnetpack/configs/logger/aim.yaml:
--------------------------------------------------------------------------------
1 | # Aim Logger
2 |
3 | aim:
4 | _target_: aim.pytorch_lightning.AimLogger
5 | repo: ${hydra:runtime.cwd}/${run.path}
6 | experiment: ${run.experiment}
7 |
--------------------------------------------------------------------------------
/src/schnetpack/configs/logger/csv.yaml:
--------------------------------------------------------------------------------
1 | # CSVLogger built in PyTorch Lightning
2 |
3 | csv:
4 | _target_: pytorch_lightning.loggers.csv_logs.CSVLogger
5 | save_dir: "."
6 | name: "csv/"
7 |
--------------------------------------------------------------------------------
/src/schnetpack/configs/logger/tensorboard.yaml:
--------------------------------------------------------------------------------
1 | # TensorBoard
2 |
3 | tensorboard:
4 | _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger
5 | save_dir: "tensorboard/"
6 | name: "default"
7 |
--------------------------------------------------------------------------------
/src/schnetpack/configs/logger/wandb.yaml:
--------------------------------------------------------------------------------
1 | wandb:
2 | _target_: pytorch_lightning.loggers.WandbLogger
3 |
--------------------------------------------------------------------------------
/src/schnetpack/configs/model/nnp.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - representation: painn
3 |
4 | _target_: schnetpack.model.NeuralNetworkPotential
5 |
6 | input_modules:
7 | - _target_: schnetpack.atomistic.PairwiseDistances
8 | output_modules: ???
9 |
--------------------------------------------------------------------------------
/src/schnetpack/configs/model/representation/field_schnet.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - radial_basis: gaussian
3 |
4 | _target_: schnetpack.representation.FieldSchNet
5 | n_atom_basis: 128
6 | n_interactions: 5
7 | external_fields: []
8 | response_properties: ${globals.response_properties}
9 | shared_interactions: False
10 | cutoff_fn:
11 | _target_: schnetpack.nn.cutoff.CosineCutoff
12 | cutoff: ${globals.cutoff}
--------------------------------------------------------------------------------
/src/schnetpack/configs/model/representation/painn.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - radial_basis: gaussian
3 |
4 | _target_: schnetpack.representation.PaiNN
5 | n_atom_basis: 128
6 | n_interactions: 3
7 | shared_interactions: False
8 | shared_filters: False
9 | cutoff_fn:
10 | _target_: schnetpack.nn.cutoff.CosineCutoff
11 | cutoff: ${globals.cutoff}
--------------------------------------------------------------------------------
/src/schnetpack/configs/model/representation/radial_basis/bessel.yaml:
--------------------------------------------------------------------------------
1 | _target_: schnetpack.nn.radial.BesselRBF
2 | n_rbf: 20
3 | cutoff: ${globals.cutoff}
--------------------------------------------------------------------------------
/src/schnetpack/configs/model/representation/radial_basis/gaussian.yaml:
--------------------------------------------------------------------------------
1 | _target_: schnetpack.nn.radial.GaussianRBF
2 | n_rbf: 20
3 | cutoff: ${globals.cutoff}
--------------------------------------------------------------------------------
/src/schnetpack/configs/model/representation/schnet.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - radial_basis: gaussian
3 |
4 | _target_: schnetpack.representation.SchNet
5 | n_atom_basis: 128
6 | n_interactions: 6
7 | cutoff_fn:
8 | _target_: schnetpack.nn.cutoff.CosineCutoff
9 | cutoff: ${globals.cutoff}
--------------------------------------------------------------------------------
/src/schnetpack/configs/model/representation/so3net.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - radial_basis: gaussian
3 |
4 | _target_: schnetpack.representation.SO3net
5 | n_atom_basis: 128
6 | n_interactions: 3
7 | lmax: 2
8 | shared_interactions: False
9 | return_vector_representation: False
10 | cutoff_fn:
11 | _target_: schnetpack.nn.cutoff.CosineCutoff
12 | cutoff: ${globals.cutoff}
--------------------------------------------------------------------------------
/src/schnetpack/configs/predict.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - _self_
4 | - trainer: default_trainer
5 |
6 | datapath: ???
7 | modeldir: ???
8 | outputdir: ???
9 | cutoff: ???
10 | ckpt_path: null
11 | batch_size: 100
12 | write_interval: epoch
13 | enable_grad: false
14 | write_idx_m: false
15 |
16 | data:
17 | _target_: schnetpack.data.ASEAtomsData
18 | datapath: ${datapath}
19 | transforms:
20 | - _target_: schnetpack.transform.SubtractCenterOfMass
21 | - _target_: schnetpack.transform.MatScipyNeighborList
22 | cutoff: ${cutoff}
23 | - _target_: schnetpack.transform.CastTo32
24 |
25 |
26 | # hydra configuration
27 | hydra:
28 | job:
29 | chdir: True
30 |
31 | # output paths for hydra logs
32 | run:
33 | dir: ${modeldir}
34 |
35 | # disable hydra config storage, since handled manually
36 | output_subdir: null
37 |
38 | help:
39 | app_name: SchNetPack Predict
40 |
41 | template: |-
42 | SchNetPack
43 |
44 | == Configuration groups ==
45 | Compose your configuration from those groups (db=mysql)
46 |
47 | $APP_CONFIG_GROUPS
48 |
49 | == Config ==
50 | This is the config generated for this run.
51 | You can change the config file to be loaded to a predefined one
52 | > spktrain --config-name=train_qm9
53 |
54 | or your own:
55 | > spktrain --config-dir=./my_configs --config-name=my_config
56 |
57 | You can override everything, for example:
58 | > spktrain --config-name=train_qm9 data_dir=/path/to/datadir data.batch_size=50 --help
59 |
60 | -------
61 | $CONFIG
62 | -------
63 |
64 | ${hydra.help.footer}
65 |
--------------------------------------------------------------------------------
/src/schnetpack/configs/run/default_run.yaml:
--------------------------------------------------------------------------------
1 | work_dir: ${hydra:runtime.cwd}
2 | data_dir: ${run.work_dir}/data
3 | path: ${run.work_dir}/runs
4 | experiment: default
5 | id: ${uuid:1}
6 | ckpt_path: null
7 |
--------------------------------------------------------------------------------
/src/schnetpack/configs/task/default_task.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - optimizer: adam
4 | - scheduler: reduce_on_plateau
5 |
6 | _target_: schnetpack.AtomisticTask
7 | outputs: ???
8 | warmup_steps: 0
9 |
--------------------------------------------------------------------------------
/src/schnetpack/configs/task/optimizer/adabelief.yaml:
--------------------------------------------------------------------------------
1 | # @package task
2 | optimizer_cls: adabelief_pytorch.AdaBelief
3 | optimizer_args:
4 | lr: ${globals.lr}
5 | weight_decay: 0.0
--------------------------------------------------------------------------------
/src/schnetpack/configs/task/optimizer/adam.yaml:
--------------------------------------------------------------------------------
1 | # @package task
2 | optimizer_cls: torch.optim.AdamW
3 | optimizer_args:
4 | lr: ${globals.lr}
5 | weight_decay: 0.0
6 |
--------------------------------------------------------------------------------
/src/schnetpack/configs/task/optimizer/sgd.yaml:
--------------------------------------------------------------------------------
1 | # @package task
2 | optimizer_cls: torch.optim.SGD
3 | optimizer_args:
4 | lr: ${globals.lr}
5 | weight_decay: 0.0
6 | momentum: 0.0
7 | nesterov: False
8 | dampening: 0.0
--------------------------------------------------------------------------------
/src/schnetpack/configs/task/scheduler/reduce_on_plateau.yaml:
--------------------------------------------------------------------------------
1 | # @package task
2 | scheduler_cls: schnetpack.train.ReduceLROnPlateau
3 | scheduler_monitor: val_loss
4 | scheduler_args:
5 | mode: min
6 | factor: 0.5
7 | patience: 75
8 | threshold: 0.0
9 | threshold_mode: rel
10 | cooldown: 10
11 | min_lr: 0.0
12 | smoothing_factor: 0.0
--------------------------------------------------------------------------------
/src/schnetpack/configs/train.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - _self_
5 | - run: default_run
6 | - globals: default_globals
7 | - trainer: default_trainer
8 | - callbacks:
9 | - checkpoint
10 | - earlystopping
11 | - lrmonitor
12 | - ema
13 | - task: default_task
14 | - model: null
15 | - data: custom
16 | - logger: tensorboard
17 | - experiment: null
18 |
19 | print_config: True
20 |
21 | # hydra configuration
22 | hydra:
23 | job:
24 | chdir: True
25 | # output paths for hydra logs
26 | run:
27 | dir: ${run.path}/${run.id}
28 |
29 | searchpath:
30 | - file://${oc.env:PWD}
31 | - file://${oc.env:PWD}/configs
32 |
33 | # disable hydra config storage, since handled manually
34 | output_subdir: null
35 |
36 | help:
37 | app_name: SchNetPack Train
38 |
39 | template: |-
40 | SchNetPack
41 |
42 | == Configuration groups ==
43 | Compose your configuration from those groups (db=mysql)
44 |
45 | $APP_CONFIG_GROUPS
46 |
47 | == Config ==
48 | This is the config generated for this run.
49 |
50 | -------
51 | $CONFIG
52 | -------
53 |
54 |     You can override the config file with a pre-defined experiment config
55 | > spktrain experiment=qm9
56 |
57 | or your own experiment config, which needs to be located in a directory called `experiment` in the config search
58 | path, e.g.,
59 | > spktrain --config-dir=./my_configs experiment=my_experiment
60 |
61 | with your experiment config located at `./my_configs/experiment/my_experiment.yaml`.
62 | Your current working directory as well as an optional config subdirectory are automatically in the config
63 | search path. Therefore, you can put your experiment config either in `./experiment`,
64 | or `./configs/experiment`.
65 |
66 | You can also override everything with the command line, for example:
67 | > spktrain experiment=qm9 data_dir=/path/to/datadir data.batch_size=50
68 |
69 | ${hydra.help.footer}
70 |
--------------------------------------------------------------------------------
/src/schnetpack/configs/trainer/ddp_debug.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - debug_trainer
3 |
4 | devices: 2
5 | accelerator: cpu
6 | strategy: ddp
7 | num_nodes: 1
8 |
--------------------------------------------------------------------------------
/src/schnetpack/configs/trainer/ddp_trainer.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - default_trainer
3 |
4 | devices: 2
5 | accelerator: auto
6 | strategy: ddp
7 | num_nodes: 1
8 |
--------------------------------------------------------------------------------
/src/schnetpack/configs/trainer/debug_trainer.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - default_trainer
3 |
4 | gpus: 1
5 | min_epochs: 1
6 | max_epochs: 3
7 |
8 | # prints
9 | detect_anomaly: True
10 | profiler: simple
11 |
12 |
--------------------------------------------------------------------------------
/src/schnetpack/configs/trainer/default_trainer.yaml:
--------------------------------------------------------------------------------
1 | _target_: pytorch_lightning.Trainer
2 |
3 | devices: 1
4 |
5 | min_epochs: null
6 | max_epochs: 100000
7 |
8 | # prints
9 | enable_model_summary: True
10 | profiler: null
11 |
12 | gradient_clip_val: null
13 | accumulate_grad_batches: 1
14 | val_check_interval: 1.0
15 | check_val_every_n_epoch: 1
16 |
17 | num_sanity_val_steps: 0
18 | fast_dev_run: False
19 | overfit_batches: 0
20 | limit_train_batches: 1.0
21 | limit_val_batches: 1.0
22 | limit_test_batches: 1.0
23 | detect_anomaly: False
24 |
25 | precision: 32
26 | accelerator: auto
27 | num_nodes: 1
28 | deterministic: False
29 | inference_mode: False
30 |
--------------------------------------------------------------------------------
/src/schnetpack/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .atoms import *
2 | from .loader import *
3 | from .stats import *
4 | from .splitting import *
5 | from .datamodule import *
6 | from .sampler import *
7 |
--------------------------------------------------------------------------------
/src/schnetpack/data/loader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 |
4 | from typing import Optional, Sequence
5 | from torch.utils.data import Dataset, Sampler
6 | from torch.utils.data.dataloader import _collate_fn_t, _T_co
7 |
8 | import schnetpack.properties as structure
9 |
10 | __all__ = ["AtomsLoader"]
11 |
12 |
def _atoms_collate_fn(batch):
    """
    Collate a list of per-system property dicts into a single mini-batch.

    Regular properties are concatenated along the first axis. Pair indices
    (idx_i / idx_j / idx_i_triples) must be offset by the number of atoms in
    the preceding systems; triple indices (idx_j_triples / idx_k_triples)
    refer to pair entries and are offset by the pair count instead.

    Args:
        batch (list): list of dict[str -> torch.Tensor], one dict per system.

    Returns:
        dict[str->torch.Tensor]: mini-batch of atomistic systems
    """
    first = batch[0]
    pair_idx_keys = {structure.idx_i, structure.idx_j, structure.idx_i_triples}
    # Atom triple indices must be treated separately
    triple_idx_keys = {structure.idx_j_triples, structure.idx_k_triples}

    coll_batch = {}
    for key in first:
        if key in pair_idx_keys:
            # keep the unshifted indices around under a "_local" suffix
            coll_batch[key + "_local"] = torch.cat([d[key] for d in batch], 0)
        elif key not in triple_idx_keys:
            coll_batch[key] = torch.cat([d[key] for d in batch], 0)

    # cumulative atom counts give each system's atom-index offset in the batch
    seg_m = torch.cumsum(coll_batch[structure.n_atoms], dim=0)
    seg_m = torch.cat([torch.zeros((1,), dtype=seg_m.dtype), seg_m], dim=0)
    coll_batch[structure.idx_m] = torch.repeat_interleave(
        torch.arange(len(batch)), repeats=coll_batch[structure.n_atoms], dim=0
    )

    # shift pair indices by the per-system atom offset
    for key in pair_idx_keys:
        if key in first:
            coll_batch[key] = torch.cat(
                [d[key] + offset for d, offset in zip(batch, seg_m)], 0
            )

    # Shift the indices for the atom triples by the running pair count
    for key in triple_idx_keys:
        if key in first:
            shifted = []
            offset = 0
            for d in batch:
                shifted.append(d[key] + offset)
                offset += d[structure.idx_j].shape[0]
            coll_batch[key] = torch.cat(shifted, 0)

    return coll_batch
59 |
60 |
class AtomsLoader(DataLoader):
    """Data loader for subclasses of BaseAtomsData.

    Behaves exactly like ``torch.utils.data.DataLoader``, except that batches
    are assembled with :func:`_atoms_collate_fn` by default, which concatenates
    atomistic systems and applies the required index offsets.
    """

    def __init__(
        self,
        dataset: Dataset[_T_co],
        batch_size: Optional[int] = 1,
        shuffle: bool = False,
        sampler: Optional[Sampler[int]] = None,
        batch_sampler: Optional[Sampler[Sequence[int]]] = None,
        num_workers: int = 0,
        collate_fn: _collate_fn_t = _atoms_collate_fn,
        pin_memory: bool = False,
        **kwargs,
    ):
        # Forward everything to DataLoader; only the collate_fn default differs.
        super().__init__(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
            **kwargs,
        )
87 |
--------------------------------------------------------------------------------
/src/schnetpack/data/sampler.py:
--------------------------------------------------------------------------------
1 | from typing import Iterator, List, Callable
2 |
3 | import numpy as np
4 | from torch.utils.data import Sampler, WeightedRandomSampler
5 |
6 | from schnetpack import properties
7 | from schnetpack.data import BaseAtomsData
8 |
9 |
10 | __all__ = [
11 | "StratifiedSampler",
12 | "NumberOfAtomsCriterion",
13 | "PropertyCriterion",
14 | ]
15 |
16 |
class NumberOfAtomsCriterion:
    """
    A callable class that returns the number of atoms for each sample in the dataset.
    """

    def __call__(self, dataset):
        # collect the scalar atom count of every sample, in dataset order
        return [
            dataset[spl_idx][properties.n_atoms].item()
            for spl_idx in range(len(dataset))
        ]
28 |
29 |
class PropertyCriterion:
    """
    A callable class that returns the specified property for each sample in the dataset.
    Property must be a scalar value.
    """

    def __init__(self, property_key: str = properties.energy):
        # key under which the scalar property is stored in each sample dict
        self.property_key = property_key

    def __call__(self, dataset):
        # collect the scalar property value of every sample, in dataset order
        return [
            dataset[spl_idx][self.property_key].item()
            for spl_idx in range(len(dataset))
        ]
45 |
46 |
class StratifiedSampler(WeightedRandomSampler):
    """
    A custom sampler that performs stratified sampling based on a partition criterion.

    Each sample is weighted inversely to the population of the histogram bin its
    criterion value falls into, so sparsely populated regions of the criterion
    are drawn as often as densely populated ones.

    Note: Make sure that num_bins is chosen sufficiently small to avoid too many empty bins.
    """

    def __init__(
        self,
        data_source: BaseAtomsData,
        partition_criterion: Callable[[BaseAtomsData], List],
        num_samples: int,
        num_bins: int = 10,
        replacement: bool = True,
        verbose: bool = True,
    ) -> None:
        """
        Args:
            data_source: The data source to be sampled from.
            partition_criterion: A callable function that takes a data source
                and returns a list of values used for partitioning.
            num_samples: The total number of samples to be drawn from the data source.
            num_bins: The number of bins to divide the partitioned values into. Defaults to 10.
            replacement: Whether to sample with replacement or without replacement. Defaults to True.
            verbose: Whether to print verbose output during sampling. Defaults to True.
        """
        self.data_source = data_source
        self.num_bins = num_bins
        self.verbose = verbose

        weights = self.calculate_weights(partition_criterion)
        super().__init__(
            weights=weights, num_samples=num_samples, replacement=replacement
        )

    def calculate_weights(self, partition_criterion):
        """
        Calculates the weights for each sample based on the partition criterion.

        A sample's weight is ``min_nonempty_bin_count / count(bin(sample))``,
        so every non-empty bin receives the same total sampling probability.
        """
        feature_values = partition_criterion(self.data_source)

        bin_counts, bin_edges = np.histogram(feature_values, bins=self.num_bins)
        # np.digitize against the upper bin edges maps each value to its
        # histogram bin. np.histogram treats the last bin as closed, so the
        # maximum value would land one index past the last bin; clip it back
        # instead of nudging the last edge by an arbitrary epsilon (the former
        # `+= 0.1` hack), which can vanish for large-magnitude feature values.
        bin_indices = np.digitize(feature_values, bin_edges[1:])
        bin_indices = np.minimum(bin_indices, self.num_bins - 1)

        # empty bins get zero weight; non-empty bins are normalized so the
        # smallest non-empty bin has weight 1 per sample
        min_counts = min(bin_counts[bin_counts != 0])
        bin_weights = np.where(bin_counts == 0, 0, min_counts / bin_counts)
        weights = bin_weights[bin_indices]

        return weights
98 |
--------------------------------------------------------------------------------
/src/schnetpack/data/stats.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Tuple
2 |
3 | import torch
4 | from tqdm import tqdm
5 |
6 | import schnetpack.properties as properties
7 | from schnetpack.data import AtomsLoader
8 |
9 | __all__ = ["calculate_stats", "estimate_atomrefs"]
10 |
11 |
def calculate_stats(
    dataloader: AtomsLoader,
    divide_by_atoms: Dict[str, bool],
    atomref: Dict[str, torch.Tensor] = None,
) -> Dict[str, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Use the incremental Welford algorithm described in [h1]_ to accumulate
    the mean and standard deviation over a set of samples.

    References:
    -----------
    .. [h1] https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance

    Args:
        dataloader: data loader
        divide_by_atoms: dict from property name to bool:
            If True, divide property by number of atoms before calculating statistics.
        atomref: reference values for single atoms to be removed before calculating stats

    Returns:
        Dict from property name to (mean, stddev) over all samples.
    """
    property_names = list(divide_by_atoms.keys())
    # 1.0 for properties normalized per atom, 0.0 for the rest
    norm_mask = torch.tensor(
        [float(divide_by_atoms[p]) for p in property_names], dtype=torch.float64
    )

    count = 0
    mean = torch.zeros_like(norm_mask)
    M2 = torch.zeros_like(norm_mask)

    for props in tqdm(dataloader, "calculating statistics"):
        sample_values = []
        for p in property_names:
            val = props[p][None, :]
            if atomref and p in atomref.keys():
                # accumulate the single-atom reference values per molecule
                ar = atomref[p]
                ar = ar[props[properties.Z]]
                idx_m = props[properties.idx_m]
                tmp = torch.zeros((idx_m[-1] + 1,), dtype=ar.dtype, device=ar.device)
                v0 = tmp.index_add(0, idx_m, ar)
                # out-of-place subtraction: `val` is a view into the batch tensor,
                # so an in-place `-=` would silently corrupt the dataloader batch
                val = val - v0

            sample_values.append(val)
        sample_values = torch.cat(sample_values, dim=0)

        batch_size = sample_values.shape[1]
        new_count = count + batch_size

        # per-sample normalizer: number of atoms for per-atom properties, 1 otherwise
        norm = norm_mask[:, None] * props[properties.n_atoms][None, :] + (
            1 - norm_mask[:, None]
        )
        # out-of-place division keeps the (possibly shared) batch tensors untouched
        sample_values = sample_values / norm

        sample_mean = torch.mean(sample_values, dim=1)
        sample_m2 = torch.sum((sample_values - sample_mean[:, None]) ** 2, dim=1)

        # Welford/Chan parallel update of running mean and M2
        delta = sample_mean - mean
        mean += delta * batch_size / new_count
        corr = batch_size * count / new_count
        M2 += sample_m2 + delta**2 * corr
        count = new_count

    stddev = torch.sqrt(M2 / count)
    stats = {pn: (mu, std) for pn, mu, std in zip(property_names, mean, stddev)}
    return stats
79 |
80 |
def estimate_atomrefs(dataloader, is_extensive, z_max=100):
    """
    Uses linear regression to estimate the elementwise biases (atomrefs).

    Args:
        dataloader: data loader
        is_extensive: dict from property name to bool. If False, the property is
            treated as intensive and rescaled by the number of atoms before
            fitting.
        z_max: size of the returned per-element vectors, i.e. one entry per
            atomic number in [0, z_max).

    Returns:
        Elementwise bias estimates over all samples
        (dict from property name to a tensor of length z_max)

    """
    property_names = list(is_extensive.keys())
    n_data = len(dataloader.dataset)
    all_properties = {pname: torch.zeros(n_data) for pname in property_names}
    all_atom_types = torch.zeros((n_data, z_max))
    data_counter = 0

    # loop over all batches
    for batch in tqdm(dataloader, "estimating atomrefs"):
        # load data
        idx_m = batch[properties.idx_m]
        atomic_numbers = batch[properties.Z]

        # get counts for atomic numbers
        unique_ids = torch.unique(idx_m)
        for i in unique_ids:
            atomic_numbers_i = atomic_numbers[idx_m == i]
            atom_types, atom_counts = torch.unique(atomic_numbers_i, return_counts=True)
            # save atom counts for this structure
            for atom_type, atom_count in zip(atom_types, atom_counts):
                all_atom_types[data_counter, atom_type] = atom_count
            # save property values once per structure (previously this loop was
            # nested inside the atom-type loop, so the in-place `*=` below was
            # applied once per distinct atom type, compounding the factor)
            for pname in property_names:
                property_value = batch[pname][i]
                if not is_extensive[pname]:
                    # out-of-place multiply: `batch[pname][i]` is a view into the
                    # batch tensor; `*=` would corrupt the dataloader batch
                    property_value = property_value * batch[properties.n_atoms][i]
                all_properties[pname][data_counter] = property_value
            data_counter += 1

    # perform linear regression to get the elementwise energy contributions,
    # restricted to atom types that actually occur in the data
    existing_atom_types = torch.where(all_atom_types.sum(axis=0) != 0)[0]
    X = torch.squeeze(all_atom_types[:, existing_atom_types])
    w = dict()
    for pname in property_names:
        if is_extensive[pname]:
            target = all_properties[pname]
        else:
            # intensive properties: fit the per-atom value
            target = all_properties[pname] / X.sum(axis=1)
        # least-squares solve is numerically more stable than inverting X^T X
        w[pname] = torch.linalg.lstsq(X, target[:, None]).solution[:, 0]

    # scatter the fitted contributions into fixed-size per-element vectors
    elementwise_contributions = {
        pname: torch.zeros(z_max) for pname in property_names
    }
    for pname in property_names:
        for atom_type, weight in zip(existing_atom_types, w[pname]):
            elementwise_contributions[pname][atom_type] = weight

    return elementwise_contributions
144 |
--------------------------------------------------------------------------------
/src/schnetpack/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .qm9 import *
2 | from .md17 import *
3 | from .md22 import *
4 | from .rmd17 import *
5 | from .iso17 import *
6 | from .ani1 import *
7 | from .materials_project import *
8 | from .omdb import *
9 | from .tmqm import *
10 | from .qm7x import *
11 |
--------------------------------------------------------------------------------
/src/schnetpack/datasets/md22.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Optional, Dict, List
3 |
4 | from schnetpack.data import *
5 | from schnetpack.datasets.md17 import GDMLDataModule
6 |
7 |
# public API of this module; `all` would shadow the builtin and has no export effect
__all__ = ["MD22"]
9 |
10 |
class MD22(GDMLDataModule):
    """
    MD22 benchmark data set for extended molecules containing molecular forces.

    References:
    .. [#md22_1] http://quantum-machine.org/gdml/#datasets

    """

    def __init__(
        self,
        datapath: str,
        molecule: str,
        batch_size: int,
        num_train: Optional[int] = None,
        num_val: Optional[int] = None,
        num_test: Optional[int] = None,
        split_file: Optional[str] = "split.npz",
        format: Optional[AtomsDataFormat] = AtomsDataFormat.ASE,
        load_properties: Optional[List[str]] = None,
        val_batch_size: Optional[int] = None,
        test_batch_size: Optional[int] = None,
        transforms: Optional[List[torch.nn.Module]] = None,
        train_transforms: Optional[List[torch.nn.Module]] = None,
        val_transforms: Optional[List[torch.nn.Module]] = None,
        test_transforms: Optional[List[torch.nn.Module]] = None,
        num_workers: int = 2,
        num_val_workers: Optional[int] = None,
        num_test_workers: Optional[int] = None,
        property_units: Optional[Dict[str, str]] = None,
        distance_unit: Optional[str] = None,
        data_workdir: Optional[str] = None,
        **kwargs,
    ):
        """
        Args:
            datapath: path to dataset
            molecule: name of the molecule data set to load; one of the keys of
                the internal dataset dictionary (e.g. "Ac-Ala3-NHMe", "DHA",
                "stachyose", "AT-AT", "AT-AT-CG-CG", "buckyball-catcher",
                "double-walled_nanotube").
            batch_size: (train) batch size
            num_train: number of training examples
            num_val: number of validation examples
            num_test: number of test examples
            split_file: path to npz file with data partitions
            format: dataset format
            load_properties: subset of properties to load
            val_batch_size: validation batch size. If None, use test_batch_size, then batch_size.
            test_batch_size: test batch size. If None, use val_batch_size, then batch_size.
            transforms: Transform applied to each system separately before batching.
            train_transforms: Overrides transform_fn for training.
            val_transforms: Overrides transform_fn for validation.
            test_transforms: Overrides transform_fn for testing.
            num_workers: Number of data loader workers.
            num_val_workers: Number of validation data loader workers (overrides num_workers).
            num_test_workers: Number of test data loader workers (overrides num_workers).
            property_units: Dictionary from property name to the desired unit as
                a string; passed through to the parent data module.
            distance_unit: Unit of the atom positions and cell as a string (Ang, Bohr, ...).
            data_workdir: Copy data here as part of setup, e.g. cluster scratch for faster performance.
        """
        # single-atom reference energies indexed by atomic number; non-zero
        # entries correspond to H (Z=1), C (Z=6), N (Z=7) and O (Z=8)
        atomrefs = {
            self.energy: [
                0.0,
                -313.5150902000774,
                0.0,
                0.0,
                0.0,
                0.0,
                -23622.587180094913,
                -34219.46811826416,
                -47069.30768969713,
            ]
        }
        # mapping from molecule name to the corresponding npz archive on the server
        datasets_dict = {
            "Ac-Ala3-NHMe": "md22_Ac-Ala3-NHMe.npz",
            "DHA": "md22_DHA.npz",
            "stachyose": "md22_stachyose.npz",
            "AT-AT": "md22_AT-AT.npz",
            "AT-AT-CG-CG": "md22_AT-AT-CG-CG.npz",
            "buckyball-catcher": "md22_buckyball-catcher.npz",
            "double-walled_nanotube": "md22_double-walled_nanotube.npz",
        }

        super(MD22, self).__init__(
            datasets_dict=datasets_dict,
            download_url="http://www.quantum-machine.org/gdml/repo/datasets/",
            tmpdir="md22",
            molecule=molecule,
            datapath=datapath,
            batch_size=batch_size,
            num_train=num_train,
            num_val=num_val,
            num_test=num_test,
            split_file=split_file,
            format=format,
            load_properties=load_properties,
            val_batch_size=val_batch_size,
            test_batch_size=test_batch_size,
            transforms=transforms,
            train_transforms=train_transforms,
            val_transforms=val_transforms,
            test_transforms=test_transforms,
            num_workers=num_workers,
            num_val_workers=num_val_workers,
            num_test_workers=num_test_workers,
            property_units=property_units,
            distance_unit=distance_unit,
            data_workdir=data_workdir,
            atomrefs=atomrefs,
            **kwargs,
        )
118 |
--------------------------------------------------------------------------------
/src/schnetpack/interfaces/__init__.py:
--------------------------------------------------------------------------------
1 | from .ase_interface import *
2 | from .batchwise_optimization import *
3 |
--------------------------------------------------------------------------------
/src/schnetpack/md/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains all functionality for performing various molecular dynamics simulations
3 | using SchNetPack.
4 | """
5 |
6 | from .system import *
7 | from .initial_conditions import *
8 | from .simulator import *
9 | from . import integrators
10 | from . import simulation_hooks
11 | from . import calculators
12 | from . import neighborlist_md
13 | from . import utils
14 | from . import data
15 |
--------------------------------------------------------------------------------
/src/schnetpack/md/calculators/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Calculator modules for molecular dynamics simulations.
3 | """
4 |
5 | from .base_calculator import *
6 | from .schnetpack_calculator import *
7 | from .lj_calculator import *
8 | from .ensemble_calculator import *
9 | from .orca_calculator import *
10 |
--------------------------------------------------------------------------------
/src/schnetpack/md/calculators/ensemble_calculator.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import torch
3 |
4 | from abc import ABC
5 | from typing import TYPE_CHECKING, List, Dict, Optional
6 | from schnetpack.md.calculators import MDCalculator
7 |
8 | if TYPE_CHECKING:
9 | from schnetpack.md import System
10 |
11 | __all__ = ["EnsembleCalculator"]
12 |
13 |
class EnsembleCalculator(ABC, MDCalculator):
    """
    Mixin for creating ensemble calculators from the standard `schnetpack.md.calculators` classes. Accumulates
    property predictions as the average over all models and uncertainties as the variance of model predictions.
    """

    def calculate(self, system: System):
        """
        Perform all calculations and compute properties and uncertainties.

        Args:
            system (schnetpack.md.System): System from the molecular dynamics simulation.
        """
        inputs = self._generate_input(system)

        # run every model of the ensemble on the same inputs
        results = []
        for model in self.model:
            prediction = model(inputs)
            results.append(prediction)

        # Compute statistics
        self.results = self._accumulate_results(results)
        self._update_system(system)

    @staticmethod
    def _accumulate_results(
        results: List[Dict[str, torch.Tensor]],
    ) -> Dict[str, torch.Tensor]:
        """
        Accumulate results and compute average predictions and uncertainties.

        Args:
            results (list(dict(str, torch.Tensor)): List of output dictionaries of individual models.

        Returns:
            dict(str, torch.Tensor): output dictionary with averaged predictions
                (under the original property name) and uncertainties (under
                "<property>_var").
        """
        ensemble_results = {}
        # iterate over the property keys of the first model's output
        for p in results[0]:
            # stack model predictions along a new leading "model" dimension
            stacked = torch.stack([result[p] for result in results])
            ensemble_results[p] = torch.mean(stacked, dim=0)
            ensemble_results["{:s}_var".format(p)] = torch.var(stacked, dim=0)

        return ensemble_results

    def _activate_stress(self, stress_key: Optional[str] = None):
        """
        Routine for activating stress computations
        Args:
            stress_key (str, optional): stress label.
        """
        raise NotImplementedError

    def _update_required_properties(self):
        """
        Update required properties to also contain predictive variances.
        """
        new_required = []
        for p in self.required_properties:
            prop_string = "{:s}_var".format(p)
            new_required += [p, prop_string]
            # Update property conversion
            self.property_conversion[prop_string] = self.property_conversion[p]

        self.required_properties = new_required
--------------------------------------------------------------------------------
/src/schnetpack/md/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .hdf5_data import *
2 | from .spectra import *
3 |
--------------------------------------------------------------------------------
/src/schnetpack/md/md_configs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/atomistic-machine-learning/schnetpack/601c8edb6cdf35af0b2ebf27b4b9e39386b79fb9/src/schnetpack/md/md_configs/__init__.py
--------------------------------------------------------------------------------
/src/schnetpack/md/md_configs/calculator/lj.yaml:
--------------------------------------------------------------------------------
1 | _target_: schnetpack.md.calculators.LJCalculator
2 | r_equilibrium: 3.405 # Angs
3 | well_depth: 3.984264 # 4*e (e=119.8K) in kJ/mol
4 | force_key: forces
5 | energy_unit: kJ/mol
6 | position_unit: Angstrom
7 | energy_key: energy
8 | stress_key: stress
9 | healing_length: 4.0 #0.3405
10 |
11 | defaults:
12 | - neighbor_list: matscipy
--------------------------------------------------------------------------------
/src/schnetpack/md/md_configs/calculator/neighbor_list/ase.yaml:
--------------------------------------------------------------------------------
1 | _target_: schnetpack.md.neighborlist_md.NeighborListMD
2 | cutoff: ???
3 | cutoff_shell: 2.0
4 | requires_triples: false
5 | base_nbl: schnetpack.transform.ASENeighborList
6 | collate_fn: schnetpack.data.loader._atoms_collate_fn
7 |
--------------------------------------------------------------------------------
/src/schnetpack/md/md_configs/calculator/neighbor_list/matscipy.yaml:
--------------------------------------------------------------------------------
1 | _target_: schnetpack.md.neighborlist_md.NeighborListMD
2 | cutoff: ???
3 | cutoff_shell: 2.0
4 | requires_triples: false
5 | base_nbl: schnetpack.transform.MatScipyNeighborList
6 | collate_fn: schnetpack.data.loader._atoms_collate_fn
7 |
--------------------------------------------------------------------------------
/src/schnetpack/md/md_configs/calculator/neighbor_list/torch.yaml:
--------------------------------------------------------------------------------
1 | _target_: schnetpack.md.neighborlist_md.NeighborListMD
2 | cutoff: ???
3 | cutoff_shell: 2.0
4 | requires_triples: false
5 | base_nbl: schnetpack.transform.TorchNeighborList
6 | collate_fn: schnetpack.data.loader._atoms_collate_fn
--------------------------------------------------------------------------------
/src/schnetpack/md/md_configs/calculator/orca.yaml:
--------------------------------------------------------------------------------
1 | _target_: schnetpack.md.calculators.OrcaCalculator
2 | required_properties:
3 | - energy
4 | - forces
5 | force_key: forces
6 | compdir: qm_calculations
7 | qm_executable: ???
8 | orca_template: ???
9 | energy_unit: Hartree
10 | position_unit: Bohr
11 | energy_key: energy
12 | stress_key: null
13 | property_conversion: { }
14 | overwrite: true
15 | adaptive: false
16 | basename: qm_calc
17 |
--------------------------------------------------------------------------------
/src/schnetpack/md/md_configs/calculator/spk.yaml:
--------------------------------------------------------------------------------
1 | _target_: schnetpack.md.calculators.SchNetPackCalculator
2 | required_properties:
3 | - energy
4 | - forces
5 | model_file: ???
6 | force_key: forces
7 | energy_unit: kcal / mol
8 | position_unit: Angstrom
9 | energy_key: energy
10 | stress_key: null
11 | script_model: false
12 |
13 | defaults:
14 | - neighbor_list: matscipy
--------------------------------------------------------------------------------
/src/schnetpack/md/md_configs/calculator/spk_ensemble.yaml:
--------------------------------------------------------------------------------
1 | _target_: schnetpack.md.calculators.SchNetPackEnsembleCalculator
2 | required_properties:
3 | - energy
4 | - forces
5 | model_files:
6 | - ???
7 | force_key: forces
8 | energy_unit: kcal / mol
9 | position_unit: Angstrom
10 | energy_key: energy
11 | stress_key: null
12 | script_model: false
13 |
14 | defaults:
15 | - neighbor_list: matscipy
16 |
--------------------------------------------------------------------------------
/src/schnetpack/md/md_configs/callbacks/checkpoint.yaml:
--------------------------------------------------------------------------------
1 | checkpoint:
2 | _target_: schnetpack.md.simulation_hooks.Checkpoint
3 | checkpoint_file: checkpoint.chk
4 | every_n_steps: 10
--------------------------------------------------------------------------------
/src/schnetpack/md/md_configs/callbacks/hdf5.yaml:
--------------------------------------------------------------------------------
1 | hdf5:
2 | _target_: schnetpack.md.simulation_hooks.FileLogger
3 | filename: simulation.hdf5
4 | buffer_size: 100
5 | data_streams:
6 | - _target_: schnetpack.md.simulation_hooks.MoleculeStream
7 | store_velocities: true
8 | - _target_: schnetpack.md.simulation_hooks.PropertyStream
9 | target_properties: [ energy ]
10 | every_n_steps: 1
11 | precision: ${precision}
--------------------------------------------------------------------------------
/src/schnetpack/md/md_configs/callbacks/tensorboard.yaml:
--------------------------------------------------------------------------------
1 | tensorboard:
2 | _target_: schnetpack.md.simulation_hooks.TensorBoardLogger
3 | log_file: logs
4 | properties:
5 | - energy
6 | - temperature
7 | every_n_steps: 10
8 |
--------------------------------------------------------------------------------
/src/schnetpack/md/md_configs/config.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | device: cuda
4 | precision: 32
5 | seed: null
6 | simulation_dir: ???
7 | overwrite: false
8 | restart: null
9 | load_config: null
10 |
11 | defaults:
12 | - _self_
13 | - calculator: spk
14 | - system: system
15 | - dynamics: base
16 | - callbacks:
17 | - checkpoint
18 | - hdf5
19 | - tensorboard
20 |
21 | hydra:
22 | run:
23 | dir: ${simulation_dir}
24 | job:
25 | config:
26 | override_dirname:
27 | exclude_keys:
28 | - basename
29 | kv_sep: '='
30 | item_sep: '_'
31 | chdir: True
32 |
--------------------------------------------------------------------------------
/src/schnetpack/md/md_configs/dynamics/barostat/nhc_aniso.yaml:
--------------------------------------------------------------------------------
1 | _target_: schnetpack.md.simulation_hooks.NHCBarostatAnisotropic
2 | target_pressure: 1000.0 # bar
3 | temperature_bath: 300.0 # K
4 | time_constant: 100.0 # fs
5 | time_constant_cell: 1000.0
6 | time_constant_barostat: 500.0
7 | chain_length: 4
8 | multi_step: 4
9 | integration_order: 7
10 | massive: true
--------------------------------------------------------------------------------
/src/schnetpack/md/md_configs/dynamics/barostat/nhc_iso.yaml:
--------------------------------------------------------------------------------
1 | _target_: schnetpack.md.simulation_hooks.NHCBarostatIsotropic
2 | target_pressure: 1000.0 # bar
3 | temperature_bath: 300.0 # K
4 | time_constant: 100.0 # fs
5 | time_constant_cell: 1000.0
6 | time_constant_barostat: 500.0
7 | chain_length: 4
8 | multi_step: 4
9 | integration_order: 7
10 | massive: true
--------------------------------------------------------------------------------
/src/schnetpack/md/md_configs/dynamics/barostat/pile_rpmd.yaml:
--------------------------------------------------------------------------------
1 | _target_: schnetpack.md.simulation_hooks.PILEBarostat
2 | target_pressure: 1000.0 # bar
3 | temperature_bath: 300.0 # K
4 | time_constant: 500.0 # fs
--------------------------------------------------------------------------------
/src/schnetpack/md/md_configs/dynamics/base.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - integrator: md
3 |
4 | n_steps: 1000000
5 | thermostat: null
6 | barostat: null
7 | progress: true
8 | simulation_hooks: [ ]
--------------------------------------------------------------------------------
/src/schnetpack/md/md_configs/dynamics/integrator/md.yaml:
--------------------------------------------------------------------------------
1 | _target_: schnetpack.md.integrators.VelocityVerlet
2 | time_step: 0.5 # fs
--------------------------------------------------------------------------------
/src/schnetpack/md/md_configs/dynamics/integrator/rpmd.yaml:
--------------------------------------------------------------------------------
1 | _target_: schnetpack.md.integrators.RingPolymer
2 | time_step: 0.20 # fs
3 | temperature: 300.0 # K
--------------------------------------------------------------------------------
/src/schnetpack/md/md_configs/dynamics/thermostat/berendsen.yaml:
--------------------------------------------------------------------------------
1 | _target_: schnetpack.md.simulation_hooks.BerendsenThermostat
2 | temperature_bath: 300.0 # K
3 | time_constant: 100.0 # fs
--------------------------------------------------------------------------------
/src/schnetpack/md/md_configs/dynamics/thermostat/gle.yaml:
--------------------------------------------------------------------------------
1 | _target_: schnetpack.md.simulation_hooks.GLEThermostat
2 | temperature_bath: 300.0
3 | gle_file: ???
4 | free_particle_limit: true
--------------------------------------------------------------------------------
/src/schnetpack/md/md_configs/dynamics/thermostat/langevin.yaml:
--------------------------------------------------------------------------------
1 | _target_: schnetpack.md.simulation_hooks.LangevinThermostat
2 | temperature_bath: 300.0 # K
3 | time_constant: 100.0 # fs
--------------------------------------------------------------------------------
/src/schnetpack/md/md_configs/dynamics/thermostat/nhc.yaml:
--------------------------------------------------------------------------------
1 | _target_: schnetpack.md.simulation_hooks.NHCThermostat
2 | temperature_bath: 300.0 # K
3 | time_constant: 100.0 # fs
4 | chain_length: 3
5 | massive: false
6 | multi_step: 2
7 | integration_order: 3
--------------------------------------------------------------------------------
/src/schnetpack/md/md_configs/dynamics/thermostat/pi_gle.yaml:
--------------------------------------------------------------------------------
1 | _target_: schnetpack.md.simulation_hooks.RPMDGLEThermostat
2 | temperature_bath: 300.0
3 | gle_file: ???
4 | free_particle_limit: true
--------------------------------------------------------------------------------
/src/schnetpack/md/md_configs/dynamics/thermostat/pi_nhc_global.yaml:
--------------------------------------------------------------------------------
_target_: schnetpack.md.simulation_hooks.NHCRingPolymerThermostat
2 | temperature_bath: 300.0 # K
3 | time_constant: 100.0 # fs
4 | local: false
5 | chain_length: 3
6 | multi_step: 2
7 | integration_order: 3
--------------------------------------------------------------------------------
/src/schnetpack/md/md_configs/dynamics/thermostat/pi_nhc_local.yaml:
--------------------------------------------------------------------------------
_target_: schnetpack.md.simulation_hooks.NHCRingPolymerThermostat
2 | temperature_bath: 300.0 # K
3 | time_constant: 100.0 # fs
4 | local: true
5 | chain_length: 3
6 | multi_step: 2
7 | integration_order: 3
--------------------------------------------------------------------------------
/src/schnetpack/md/md_configs/dynamics/thermostat/piglet.yaml:
--------------------------------------------------------------------------------
1 | _target_: schnetpack.md.simulation_hooks.PIGLETThermostat
2 | temperature_bath: 300.0
3 | gle_file: ???
--------------------------------------------------------------------------------
/src/schnetpack/md/md_configs/dynamics/thermostat/pile_global.yaml:
--------------------------------------------------------------------------------
1 | _target_: schnetpack.md.simulation_hooks.PILEGlobalThermostat
2 | temperature_bath: 300.0 # K
3 | time_constant: 100.0 # fs
--------------------------------------------------------------------------------
/src/schnetpack/md/md_configs/dynamics/thermostat/pile_local.yaml:
--------------------------------------------------------------------------------
1 | _target_: schnetpack.md.simulation_hooks.PILELocalThermostat
2 | temperature_bath: 300.0 # K
3 | time_constant: 100.0 # fs
--------------------------------------------------------------------------------
/src/schnetpack/md/md_configs/dynamics/thermostat/trpmd.yaml:
--------------------------------------------------------------------------------
1 | _target_: schnetpack.md.simulation_hooks.TRPMDThermostat
2 | temperature_bath: 300.0 # K
3 | damping_factor: 0.8
--------------------------------------------------------------------------------
/src/schnetpack/md/md_configs/system/initializer/maxwell_boltzmann.yaml:
--------------------------------------------------------------------------------
1 | _target_: schnetpack.md.MaxwellBoltzmannInit
2 | temperature: 300
3 | remove_center_of_mass: true
4 | remove_translation: true
5 | remove_rotation: true
6 | wrap_positions: false
--------------------------------------------------------------------------------
/src/schnetpack/md/md_configs/system/initializer/uniform.yaml:
--------------------------------------------------------------------------------
1 | _target_: schnetpack.md.UniformInit
2 | temperature: 300
3 | remove_center_of_mass: true
4 | remove_translation: true
5 | remove_rotation: true
6 | wrap_positions: false
--------------------------------------------------------------------------------
/src/schnetpack/md/md_configs/system/system.yaml:
--------------------------------------------------------------------------------
1 | molecule_file: ???
2 | load_system_state: null
3 | n_replicas: 1
4 | position_unit_input: Angstrom
5 | mass_unit_input: 1.0
6 |
7 | defaults:
8 | - initializer: uniform
9 |
--------------------------------------------------------------------------------
/src/schnetpack/md/parsers/__init__.py:
--------------------------------------------------------------------------------
1 | from .orca_parser import *
2 |
--------------------------------------------------------------------------------
/src/schnetpack/md/simulation_hooks/__init__.py:
--------------------------------------------------------------------------------
1 | from .basic_hooks import *
2 | from .barostats import *
3 | from .barostats_rpmd import *
4 | from .thermostats import *
5 | from .thermostats_rpmd import *
6 | from .callback_hooks import *
7 |
--------------------------------------------------------------------------------
/src/schnetpack/md/simulation_hooks/basic_hooks.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import torch.nn as nn
3 |
4 | from schnetpack.md.utils import UninitializedMixin
5 |
6 | from typing import TYPE_CHECKING
7 |
8 | if TYPE_CHECKING:
9 | from schnetpack.md import Simulator
10 |
11 | __all__ = ["RemoveCOMMotion", "SimulationHook", "WrapPositions"]
12 |
13 |
class SimulationHook(UninitializedMixin, nn.Module):
    """
    Basic class for simulator hooks. All callbacks are no-ops; subclasses
    override the ones relevant to them.
    """

    def on_step_begin(self, simulator: Simulator):
        """Called at the beginning of each simulation step."""
        pass

    def on_step_middle(self, simulator: Simulator):
        """Called in the middle of each simulation step."""
        pass

    def on_step_end(self, simulator: Simulator):
        """Called at the end of each simulation step."""
        pass

    def on_step_finalize(self, simulator: Simulator):
        """Called when a simulation step is finalized."""
        pass

    def on_step_failed(self, simulator: Simulator):
        """Called when a simulation step fails."""
        pass

    def on_simulation_start(self, simulator: Simulator):
        """Called once before the simulation starts."""
        pass

    def on_simulation_end(self, simulator: Simulator):
        """Called once after the simulation ends."""
        pass
39 |
40 |
class RemoveCOMMotion(SimulationHook):
    """
    Periodically remove motions of the center of mass from the system.

    Args:
        every_n_steps (int): Frequency with which motions are removed.
        remove_rotation (bool): Also remove rotations.
    """

    def __init__(self, every_n_steps: int, remove_rotation: bool):
        super().__init__()
        self.every_n_steps = every_n_steps
        self.remove_rotation = remove_rotation

    def on_step_finalize(self, simulator: Simulator):
        # only act every `every_n_steps` simulation steps
        if simulator.step % self.every_n_steps != 0:
            return

        simulator.system.remove_center_of_mass()
        simulator.system.remove_translation()

        if self.remove_rotation:
            simulator.system.remove_com_rotation()
62 |
63 |
class WrapPositions(SimulationHook):
    """
    Periodically wrap atoms back into the simulation cell.

    Args:
        every_n_steps (int): Frequency with which atoms should be wrapped.
    """

    def __init__(self, every_n_steps: int):
        super().__init__()
        self.every_n_steps = every_n_steps

    def on_step_finalize(self, simulator: Simulator):
        # wrap only every `every_n_steps` simulation steps
        if simulator.step % self.every_n_steps != 0:
            return
        simulator.system.wrap_positions()
79 |
--------------------------------------------------------------------------------
/src/schnetpack/md/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import schnetpack
2 | import torch
3 | import torch.nn as nn
4 |
5 | from .md_config import *
6 | from .normal_model_transformation import *
7 | from .thermostat_utils import *
8 |
9 | from typing import Optional
10 |
11 |
class CalculatorError(Exception):
    """Exception type signaling errors related to MD calculators."""

    pass
14 |
15 |
def activate_model_stress(
    model: schnetpack.model.AtomisticModel, stress_key: str
) -> schnetpack.model.AtomisticModel:
    """
    Utility function for activating computation of stress in models not explicitly trained on the stress tensor.
    Used for e.g. simulations under constant pressure and in cells.

    Args:
        model (AtomisticModel): loaded schnetpack model for which stress computation should be activated.
        stress_key (str): name of stress tensor in model.

    Returns:
        AtomisticModel: schnetpack model with activated stress tensor.

    Raises:
        CalculatorError: if no output module suitable for stress computation is found.
    """
    stress = False

    # Check if a module suitable for stress computation is present
    for module in model.output_modules:
        if isinstance(module, schnetpack.atomistic.response.Forces) or isinstance(
            module, schnetpack.atomistic.Response
        ):
            # `Forces`-style modules expose a `calc_stress` flag
            if hasattr(module, "calc_stress"):
                # activate internal stress computation flag
                module.calc_stress = True

                # append stress label to output list and update required derivatives in the module
                module.model_outputs.append(stress_key)
                module.required_derivatives.append(schnetpack.properties.strain)

                # if not set in the model, also update output list and required derivatives so that:
                # a) required derivatives are computed and
                # b) property is added to the model outputs
                if stress_key not in model.model_outputs:
                    model.model_outputs.append(stress_key)
                    model.required_derivatives.append(schnetpack.properties.strain)

                stress = True

            # `Response`-style modules are configured via derivative instructions
            if hasattr(module, "basic_derivatives"):
                # activate internal stress computation flag and request the
                # energy-strain derivative (dEds)
                module.calc_stress = True
                module.basic_derivatives["dEds"] = schnetpack.properties.strain
                module.derivative_instructions["dEds"] = True

                # NOTE(review): maps the generic stress property onto itself;
                # verify whether this should map to `stress_key` when a
                # non-default key is used.
                module.map_properties[schnetpack.properties.stress] = (
                    schnetpack.properties.stress
                )

                # append stress label to output list and update required derivatives in the module
                module.model_outputs.append(stress_key)
                module.required_derivatives.append(schnetpack.properties.strain)

                # if not set in the model, also update output list and required derivatives so that:
                # a) required derivatives are computed and
                # b) property is added to the model outputs
                if stress_key not in model.model_outputs:
                    model.model_outputs.append(stress_key)
                    model.required_derivatives.append(schnetpack.properties.strain)

                stress = True

    # If stress computation has been enabled, insert preprocessing for strain computation
    if stress:
        model.input_modules.insert(0, schnetpack.atomistic.Strain())
    else:
        raise CalculatorError("Failed to activate stress computation")

    return model
88 |
89 |
class UninitializedMixin(nn.modules.lazy.LazyModuleMixin):
    """
    Custom mixin for lazy initialization of buffers used in the MD system and simulation hooks.
    Buffers registered through this mixin start out uninitialized and can be
    materialized later, once their final shape is known.
    """

    def register_uninitialized_buffer(
        self, name: str, dtype: Optional[torch.dtype] = None
    ):
        """
        Register an uninitialized buffer with the requested dtype. This can be used to reserve variables which are
        not known at the initialization of `schnetpack.md.System` and simulation hooks.

        Args:
            name (str): Name of the uninitialized buffer to register.
            dtype (torch.dtype): If specified, buffer will be set to the requested dtype. Defaults
                to float64 when None is given.
        """
        # Double precision is the default used throughout the MD code.
        effective_dtype = torch.float64 if dtype is None else dtype
        placeholder = nn.parameter.UninitializedBuffer(dtype=effective_dtype)
        self.register_buffer(name, placeholder)
112 |
--------------------------------------------------------------------------------
/src/schnetpack/md/utils/normal_model_transformation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 | __all__ = ["NormalModeTransformer"]
6 |
7 |
class NormalModeTransformer(nn.Module):
    """
    Transformation between the bead and normal mode representations of a ring polymer,
    used e.g. for propagating the ring polymer during simulation. An in-depth description
    of the transformation can be found e.g. in [#rpmd3]_. Here, a plain matrix
    multiplication is used instead of a Fourier transformation, which can be more
    performant in certain cases. On the GPU, no significant gains were observed for a
    FT-based transformation over the matrix version.

    The transformation acts on the first dimension of the system property tensors
    (e.g. positions, momenta), so several molecules can be transformed at once.

    Args:
        n_beads (int): Number of beads in the ring polymer.

    References
    ----------
    .. [#rpmd3] Ceriotti, Parrinello, Markland, Manolopoulos:
       Efficient stochastic thermostatting of path integral molecular dynamics.
       The Journal of Chemical Physics, 133, 124105. 2010.
    """

    def __init__(self, n_beads):
        super().__init__()
        self.n_beads = n_beads

        # The matrix is fixed for a given bead count and precomputed once.
        self.register_buffer("c_transform", self._init_transformation_matrix())

    def _init_transformation_matrix(self):
        """
        Build the n_beads x n_beads normal mode transformation matrix C(k, n).
        It only has to be constructed once and is reused for the whole simulation.

        Returns:
            torch.Tensor: Normal mode transformation matrix of shape n_beads x n_beads.
        """
        transform = np.zeros((self.n_beads, self.n_beads))

        # Auxiliary array of bead indices n = 1 ... n_beads
        bead_idx = np.arange(1, self.n_beads + 1)

        # k = 0: centroid row
        transform[0, :] = 1.0

        half = self.n_beads // 2
        # Cosine modes
        for k in range(1, half + 1):
            transform[k, :] = np.sqrt(2) * np.cos(
                2 * np.pi * k * bead_idx / self.n_beads
            )
        # Sine modes
        for k in range(half + 1, self.n_beads):
            transform[k, :] = np.sqrt(2) * np.sin(
                2 * np.pi * k * bead_idx / self.n_beads
            )

        # Even bead numbers: the k = n_beads/2 row alternates sign
        if self.n_beads % 2 == 0:
            transform[half, :] = (-1) ** bead_idx

        # Normalize; matrix is already laid out as C(k, n), no transpose needed
        transform /= np.sqrt(self.n_beads)
        return torch.from_numpy(transform)

    def beads2normal(self, x_beads):
        """
        Transform a system tensor (e.g. momenta, positions) from bead to normal mode
        representation.

        Args:
            x_beads (torch.Tensor): System tensor in bead representation, shape
                n_beads x n_molecules x ...

        Returns:
            torch.Tensor: Tensor in normal mode representation, same shape as input.
        """
        flat = x_beads.view(self.n_beads, -1)
        return torch.mm(self.c_transform, flat).view(x_beads.shape)

    def normal2beads(self, x_normal):
        """
        Transform a system tensor (e.g. momenta, positions) from normal mode back to
        bead representation.

        Args:
            x_normal (torch.Tensor): System tensor in normal mode representation, shape
                n_beads x n_molecules x ...

        Returns:
            torch.Tensor: Tensor in bead representation, same shape as input.
        """
        flat = x_normal.view(self.n_beads, -1)
        return torch.mm(self.c_transform.transpose(0, 1), flat).view(x_normal.shape)
99 |
--------------------------------------------------------------------------------
/src/schnetpack/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import *
2 |
--------------------------------------------------------------------------------
/src/schnetpack/nn/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Basic building blocks of SchNetPack models. Contains various basic and specialized network layers, layers for
3 | cutoff functions, as well as several auxiliary layers and functions.
4 | """
5 |
6 | from schnetpack.nn.activations import *
7 | from schnetpack.nn.base import *
8 | from schnetpack.nn.blocks import *
9 | from schnetpack.nn.cutoff import *
10 | from schnetpack.nn.equivariant import *
11 | from schnetpack.nn.so3 import *
12 | from schnetpack.nn.scatter import *
13 | from schnetpack.nn.radial import *
14 | from schnetpack.nn.utils import *
15 | from schnetpack.nn.embedding import *
16 |
--------------------------------------------------------------------------------
/src/schnetpack/nn/activations.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 |
4 | from torch.nn import functional
5 |
6 | __all__ = ["shifted_softplus", "softplus_inverse", "ShiftedSoftplus"]
7 |
8 |
def shifted_softplus(x: torch.Tensor):
    r"""Compute shifted soft-plus activation function.

    .. math::
        y = \ln\left(1 + e^{x}\right) - \ln(2)

    The shift by :math:`\ln(2)` makes the function pass through zero at
    :math:`x = 0`.

    Args:
        x (torch.Tensor): input tensor.

    Returns:
        torch.Tensor: shifted soft-plus of input.

    """
    return functional.softplus(x) - math.log(2.0)
23 |
24 |
def softplus_inverse(x: torch.Tensor):
    """
    Inverse of the softplus function, i.e. solves softplus(y) = x for y.

    Uses the numerically stable form y = x + log(1 - exp(-x)).

    Args:
        x (torch.Tensor): Input vector (softplus values).

    Returns:
        torch.Tensor: softplus inverse of input.
    """
    correction = torch.log(-torch.expm1(-x))
    return x + correction
36 |
37 |
class ShiftedSoftplus(torch.nn.Module):
    """
    Shifted softplus activation function with learnable feature-wise parameters:
    f(x) = alpha/beta * (softplus(beta*x) - log(2))
    softplus(x) = log(exp(x) + 1)
    For beta -> 0 : f(x) -> 0.5*alpha*x
    For beta -> inf: f(x) -> max(0, alpha*x)

    With learnable parameters alpha and beta, the shifted softplus function can
    become equivalent to ReLU (if alpha is equal 1 and beta approaches infinity) or to
    the identity function (if alpha is equal 2 and beta is equal 0).
    """

    def __init__(
        self,
        initial_alpha: float = 1.0,
        initial_beta: float = 1.0,
        trainable: bool = False,
    ) -> None:
        """
        Args:
            initial_alpha: Initial "scale" alpha of the softplus function.
            initial_beta: Initial "temperature" beta of the softplus function.
            trainable: If True, alpha and beta are trained during optimization.
        """
        super(ShiftedSoftplus, self).__init__()
        initial_alpha = torch.tensor(initial_alpha)
        initial_beta = torch.tensor(initial_beta)

        if trainable:
            # Avoid the deprecated legacy `torch.FloatTensor([tensor])`
            # constructor; keep the historic shape (1,) of the trainable
            # parameters so existing checkpoints remain loadable.
            self.alpha = torch.nn.Parameter(initial_alpha.reshape(1).clone())
            self.beta = torch.nn.Parameter(initial_beta.reshape(1).clone())
        else:
            self.register_buffer("alpha", initial_alpha)
            self.register_buffer("beta", initial_beta)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Evaluate activation function given the input features x.
        num_features: Dimensions of feature space.

        Args:
            x (FloatTensor [:, num_features]): Input features.

        Returns:
            y (FloatTensor [:, num_features]): Activated features.
        """
        # torch.where evaluates both branches; the 0.5*x branch is selected
        # wherever beta == 0, so the division by zero there is never picked up.
        return self.alpha * torch.where(
            self.beta != 0,
            (torch.nn.functional.softplus(self.beta * x) - math.log(2)) / self.beta,
            0.5 * x,
        )
90 |
--------------------------------------------------------------------------------
/src/schnetpack/nn/base.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Union, Optional
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn
6 | from torch.nn.init import xavier_uniform_
7 |
8 | from torch.nn.init import zeros_
9 |
10 |
11 | __all__ = ["Dense"]
12 |
13 |
class Dense(nn.Linear):
    r"""Fully connected linear layer with an optional activation function.

    .. math::
        y = activation(x W^T + b)
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        activation: Union[Callable, nn.Module] = None,
        weight_init: Callable = xavier_uniform_,
        bias_init: Callable = zeros_,
    ):
        """
        Args:
            in_features: number of input features :math:`x`.
            out_features: number of output features :math:`y`.
            bias: If False, the layer will not adapt bias :math:`b`.
            activation: if None, no activation function is used.
            weight_init: callable initializing the weight tensor in place.
            bias_init: callable initializing the bias tensor in place.
        """
        # The initializers must be stored before nn.Linear.__init__ runs,
        # because it calls reset_parameters(), which uses them.
        self.weight_init = weight_init
        self.bias_init = bias_init
        super(Dense, self).__init__(in_features, out_features, bias)

        # Fall back to the identity when no activation was requested.
        self.activation = activation if activation is not None else nn.Identity()

    def reset_parameters(self):
        self.weight_init(self.weight)
        if self.bias is not None:
            self.bias_init(self.bias)

    def forward(self, input: torch.Tensor):
        out = F.linear(input, self.weight, self.bias)
        return self.activation(out)
56 |
--------------------------------------------------------------------------------
/src/schnetpack/nn/cutoff.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 |
5 | __all__ = [
6 | "CosineCutoff",
7 | "MollifierCutoff",
8 | "mollifier_cutoff",
9 | "cosine_cutoff",
10 | "SwitchFunction",
11 | ]
12 |
13 |
def cosine_cutoff(input: torch.Tensor, cutoff: torch.Tensor):
    r""" Behler-style cosine cutoff.

    .. math::
        f(r) = \begin{cases}
        0.5 \times \left[1 + \cos\left(\frac{\pi r}{r_\text{cutoff}}\right)\right]
        & r < r_\text{cutoff} \\
        0 & r \geqslant r_\text{cutoff} \\
        \end{cases}

    Args:
        input (torch.Tensor): input distances.
        cutoff (torch.Tensor): cutoff radius.

    """
    # Smooth decay from 1 at r = 0 to 0 at r = cutoff
    scaled = input * math.pi / cutoff
    values = 0.5 * (1.0 + torch.cos(scaled))
    # Remove all contributions at or beyond the cutoff radius
    return values * (input < cutoff).float()
34 |
35 |
class CosineCutoff(nn.Module):
    r""" Behler-style cosine cutoff module.

    .. math::
        f(r) = \begin{cases}
        0.5 \times \left[1 + \cos\left(\frac{\pi r}{r_\text{cutoff}}\right)\right]
        & r < r_\text{cutoff} \\
        0 & r \geqslant r_\text{cutoff} \\
        \end{cases}

    """

    def __init__(self, cutoff: float):
        """
        Args:
            cutoff (float): cutoff radius.
        """
        super().__init__()
        # Buffer so the value follows the module across devices
        self.register_buffer("cutoff", torch.FloatTensor([cutoff]))

    def forward(self, input: torch.Tensor):
        """Apply the cosine cutoff to a tensor of distances."""
        return cosine_cutoff(input, self.cutoff)
58 |
59 |
def mollifier_cutoff(input: torch.Tensor, cutoff: torch.Tensor, eps: torch.Tensor):
    r""" Mollifier cutoff scaled to have a value of 1 at :math:`r=0`.

    .. math::
        f(r) = \begin{cases}
        \exp\left(1 - \frac{1}{1 - \left(\frac{r}{r_\text{cutoff}}\right)^2}\right)
        & r < r_\text{cutoff} \\
        0 & r \geqslant r_\text{cutoff} \\
        \end{cases}

    Args:
        input: input distances.
        cutoff: Cutoff radius.
        eps: Offset added to distances for numerical stability.

    """
    # Mask selecting points strictly inside the cutoff (eps guards the boundary)
    inside = (input + eps < cutoff).float()
    # Masking the input keeps the ratio below 1, so the denominator never vanishes
    ratio_sq = torch.pow(input * inside / cutoff, 2)
    values = torch.exp(1.0 - 1.0 / (1.0 - ratio_sq))
    return values * inside
80 |
81 |
class MollifierCutoff(nn.Module):
    r""" Mollifier cutoff module scaled to have a value of 1 at :math:`r=0`.

    .. math::
        f(r) = \begin{cases}
        \exp\left(1 - \frac{1}{1 - \left(\frac{r}{r_\text{cutoff}}\right)^2}\right)
        & r < r_\text{cutoff} \\
        0 & r \geqslant r_\text{cutoff} \\
        \end{cases}
    """

    def __init__(self, cutoff: float, eps: float = 1.0e-7):
        """
        Args:
            cutoff: Cutoff radius.
            eps: Offset added to distances for numerical stability.
        """
        super().__init__()
        # Buffers so the values follow the module across devices
        self.register_buffer("cutoff", torch.FloatTensor([cutoff]))
        self.register_buffer("eps", torch.FloatTensor([eps]))

    def forward(self, input: torch.Tensor):
        """Apply the mollifier cutoff to a tensor of distances."""
        return mollifier_cutoff(input, self.cutoff, self.eps)
105 |
106 |
107 | def _switch_component(
108 | x: torch.Tensor, ones: torch.Tensor, zeros: torch.Tensor
109 | ) -> torch.Tensor:
110 | """
111 | Basic component of switching functions.
112 |
113 | Args:
114 | x (torch.Tensor): Switch functions.
115 | ones (torch.Tensor): Tensor with ones.
116 | zeros (torch.Tensor): Zero tensor
117 |
118 | Returns:
119 | torch.Tensor: Output tensor.
120 | """
121 | x_ = torch.where(x <= 0, ones, x)
122 | return torch.where(x <= 0, zeros, torch.exp(-ones / x_))
123 |
124 |
class SwitchFunction(nn.Module):
    """
    Smooth switching function that decays from 1 to 0 between `switch_on` and `switch_off`.
    """

    def __init__(self, switch_on: float, switch_off: float):
        """
        Args:
            switch_on (float): Onset of switch.
            switch_off (float): Value from which on switch is 0.
        """
        super().__init__()
        self.register_buffer("switch_on", torch.Tensor([switch_on]))
        self.register_buffer("switch_off", torch.Tensor([switch_off]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): tensor to which the switching function is applied.

        Returns:
            torch.Tensor: switch output
        """
        # Rescale so that the switching interval maps onto [0, 1]
        t = (x - self.switch_on) / (self.switch_off - self.switch_on)

        ones = torch.ones_like(t)
        zeros = torch.zeros_like(t)
        f_plus = _switch_component(t, ones, zeros)
        f_minus = _switch_component(1 - t, ones, zeros)

        # 1 before the onset, 0 after the offset, smooth interpolation in between
        return torch.where(
            t <= 0, ones, torch.where(t >= 1, zeros, f_minus / (f_plus + f_minus))
        )
159 |
--------------------------------------------------------------------------------
/src/schnetpack/nn/equivariant.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import schnetpack.nn as snn
6 | from typing import Tuple
7 |
8 | __all__ = ["GatedEquivariantBlock"]
9 |
10 |
class GatedEquivariantBlock(nn.Module):
    """
    Gated equivariant block as used for the prediction of tensorial properties by PaiNN.
    Transforms scalar and vector representation using gated nonlinearities.

    References:

    .. [#painn1] Schütt, Unke, Gastegger:
       Equivariant message passing for the prediction of tensorial properties and molecular spectra.
       ICML 2021 (to appear)

    """

    def __init__(
        self,
        n_sin: int,
        n_vin: int,
        n_sout: int,
        n_vout: int,
        n_hidden: int,
        activation=F.silu,
        sactivation=None,
    ):
        """
        Args:
            n_sin: number of input scalar features
            n_vin: number of input vector features
            n_sout: number of output scalar features
            n_vout: number of output vector features
            n_hidden: number of hidden units
            activation: internal activation function
            sactivation: activation function for scalar outputs
        """
        super().__init__()
        self.n_sin = n_sin
        self.n_vin = n_vin
        self.n_sout = n_sout
        self.n_vout = n_vout
        self.n_hidden = n_hidden
        # Bias-free linear mixing of vector channels into two sets:
        # V (gating via its invariant norm) and W (scaled to form the output)
        self.mix_vectors = snn.Dense(n_vin, 2 * n_vout, activation=None, bias=False)
        self.scalar_net = nn.Sequential(
            snn.Dense(n_sin + n_vout, n_hidden, activation=activation),
            snn.Dense(n_hidden, n_sout + n_vout, activation=None),
        )
        self.sactivation = sactivation

    def forward(self, inputs: Tuple[torch.Tensor, torch.Tensor]):
        scalars, vectors = inputs

        # Mix vector features and split into V and W channels
        mixed = self.mix_vectors(vectors)
        vectors_V, vectors_W = torch.split(mixed, self.n_vout, dim=-1)

        # The norm of V is rotation invariant and joins the scalar features
        v_norm = torch.norm(vectors_V, dim=-2)
        scalar_input = torch.cat([scalars, v_norm], dim=-1)

        scalar_output = self.scalar_net(scalar_input)
        s_out, gates = torch.split(scalar_output, [self.n_sout, self.n_vout], dim=-1)

        # Gate the W vectors with the predicted scalars
        v_out = gates.unsqueeze(-2) * vectors_W

        if self.sactivation:
            s_out = self.sactivation(s_out)

        return s_out, v_out
72 |
--------------------------------------------------------------------------------
/src/schnetpack/nn/ops/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/atomistic-machine-learning/schnetpack/601c8edb6cdf35af0b2ebf27b4b9e39386b79fb9/src/schnetpack/nn/ops/__init__.py
--------------------------------------------------------------------------------
/src/schnetpack/nn/ops/math.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
def binom(n: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """
    Compute binomial coefficients (n k) via log-gamma functions,
    which also works for non-integer arguments.
    """
    log_coeff = torch.lgamma(n + 1) - torch.lgamma((n - k) + 1) - torch.lgamma(k + 1)
    return torch.exp(log_coeff)
11 |
--------------------------------------------------------------------------------
/src/schnetpack/nn/ops/so3.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from sympy.physics.wigner import clebsch_gordan
4 |
5 | from functools import lru_cache
6 | from typing import Tuple
7 |
8 |
@lru_cache(maxsize=10)
def sh_indices(lmax: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Build index arrays for spherical harmonics: for each flattened harmonic,
    its degree l and order m.

    Args:
        lmax: maximum angular momentum
    """
    degrees = torch.arange(0, lmax + 1)
    # Each degree l contributes 2l + 1 orders m = -l, ..., l
    lidx = torch.repeat_interleave(degrees, 2 * degrees + 1)
    midx = torch.cat([torch.arange(-deg, deg + 1) for deg in degrees])
    return lidx, midx
22 |
23 |
@lru_cache(maxsize=10)
def generate_sh_to_rsh(lmax: int) -> torch.Tensor:
    """
    Generate transformation matrix to convert (complex) spherical harmonics to real form.

    The result is a complex matrix of shape [(lmax+1)^2, (lmax+1)^2]; the `(l1 == l2)`
    factor makes it block diagonal in the angular momentum l.

    Args:
        lmax: maximum angular momentum
    """
    lidx, midx = sh_indices(lmax)
    # Broadcast l and m indices against each other to build the matrix elementwise
    l1 = lidx[:, None]
    l2 = lidx[None, :]
    m1 = midx[:, None]
    m2 = midx[None, :]
    # Complex-to-real change of basis: the m = 0 component passes through
    # unchanged; the remaining terms combine +/-m pairs with 1/sqrt(2)
    # prefactors (real/imaginary combinations for positive/negative m).
    U = (
        1.0 * ((m1 == 0) * (m2 == 0))
        + (-1.0) ** abs(m1) / math.sqrt(2) * ((m1 == m2) * (m1 > 0))
        + 1.0 / math.sqrt(2) * ((m1 == -m2) * (m2 < 0))
        + -1.0j * (-1.0) ** abs(m1) / math.sqrt(2) * ((m1 == -m2) * (m1 < 0))
        + 1.0j / math.sqrt(2) * ((m1 == m2) * (m1 < 0))
    ) * (l1 == l2)
    return U
45 |
46 |
@lru_cache(maxsize=10)
def generate_clebsch_gordan(lmax: int) -> torch.Tensor:
    """
    Generate standard Clebsch-Gordan coefficients for complex spherical harmonics

    Args:
        lmax: maximum angular momentum

    Returns:
        torch.Tensor: coefficient tensor of shape
        [(lmax+1)^2, (lmax+1)^2, (lmax+1)^2]
    """
    lidx, midx = sh_indices(lmax)
    cg = torch.zeros((lidx.shape[0], lidx.shape[0], lidx.shape[0]))
    # Convert to numpy scalars for use with sympy's clebsch_gordan
    lidx = lidx.numpy()
    midx = midx.numpy()
    for c1, (l1, m1) in enumerate(zip(lidx, midx)):
        for c2, (l2, m2) in enumerate(zip(lidx, midx)):
            for c3, (l3, m3) in enumerate(zip(lidx, midx)):
                # Cheap selection-rule prefilter before calling the (slow)
                # sympy routine: triangle inequality on l (capped at lmax) and
                # a superset of admissible m combinations. Combinations that
                # slip through evaluate to 0 in sympy anyway.
                if abs(l1 - l2) <= l3 <= min(l1 + l2, lmax) and m3 in {
                    m1 + m2,
                    m1 - m2,
                    m2 - m1,
                    -m1 - m2,
                }:
                    coeff = clebsch_gordan(l1, l2, l3, m1, m2, m3)
                    cg[c1, c2, c3] = float(coeff)
    return cg
71 |
72 |
@lru_cache(maxsize=10)
def generate_clebsch_gordan_rsh(
    lmax: int, parity_invariance: bool = True
) -> torch.Tensor:
    """
    Generate Clebsch-Gordan coefficients for real spherical harmonics

    Args:
        lmax: maximum angular momentum
        parity_invariance: whether to enforce parity invariance, i.e. only allow
            non-zero coefficients if :math:`-1^l_1 -1^l_2 = -1^l_3`

    Returns:
        torch.Tensor: real-valued (float64) coefficient tensor of shape
        [(lmax+1)^2, (lmax+1)^2, (lmax+1)^2]
    """
    lidx, _ = sh_indices(lmax)
    # Complex CG coefficients, promoted to a complex dtype for the basis change
    cg = generate_clebsch_gordan(lmax).to(dtype=torch.complex64)
    complex_to_real = generate_sh_to_rsh(lmax)  # (real, complex)
    # Transform all three indices into the real-harmonic basis; the output
    # index uses the conjugated transformation.
    cg_rsh = torch.einsum(
        "ijk,mi,nj,ok->mno",
        cg,
        complex_to_real,
        complex_to_real,
        complex_to_real.conj(),
    )

    if parity_invariance:
        # Zero out all coefficients whose parities (-1)^l do not match
        parity = (-1.0) ** lidx
        pmask = parity[:, None, None] * parity[None, :, None] == parity[None, None, :]
        cg_rsh *= pmask
    else:
        # Rotate purely imaginary coefficients onto the real axis via i^(l1+l2-l3)
        lsum = lidx[:, None, None] + lidx[None, :, None] - lidx[None, None, :]
        cg_rsh *= 1.0j**lsum

    # cast to real (any remaining imaginary parts are discarded)
    cg_rsh = cg_rsh.real.to(torch.float64)
    return cg_rsh
108 |
109 |
def sparsify_clebsch_gordon(
    cg: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Convert Clebsch-Gordon tensor to sparse format.

    Args:
        cg: dense tensor Clebsch-Gordon coefficients
            [(lmax_1+1)^2, (lmax_2+1)^2, (lmax_out+1)^2]

    Returns:
        cg_sparse: vector of non-zeros CG coefficients
        idx_in_1: indices for first set of irreps
        idx_in_2: indices for second set of irreps
        idx_out: indices for output set of irreps
    """
    # One row per non-zero entry: (i1, i2, i_out)
    nonzero_idx = torch.nonzero(cg)
    idx_in_1 = nonzero_idx[:, 0]
    idx_in_2 = nonzero_idx[:, 1]
    idx_out = nonzero_idx[:, 2]
    # Gather the corresponding coefficient values
    cg_sparse = cg[idx_in_1, idx_in_2, idx_out]
    return cg_sparse, idx_in_1, idx_in_2, idx_out
135 |
136 |
def round_cmp(x: torch.Tensor, decimals: int = 1):
    """Round real and imaginary parts of a complex tensor independently."""
    real_part = torch.round(x.real, decimals=decimals)
    imag_part = torch.round(x.imag, decimals=decimals)
    return real_part + 1j * imag_part
141 |
--------------------------------------------------------------------------------
/src/schnetpack/nn/radial.py:
--------------------------------------------------------------------------------
1 | from math import pi
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | __all__ = ["gaussian_rbf", "GaussianRBF", "GaussianRBFCentered", "BesselRBF"]
7 |
8 | from torch import nn as nn
9 |
10 |
def gaussian_rbf(inputs: torch.Tensor, offsets: torch.Tensor, widths: torch.Tensor):
    """
    Expand values in Gaussian radial basis functions.

    Args:
        inputs: input values, shape [...].
        offsets: Gaussian centers, shape [n_rbf].
        widths: Gaussian widths, shape [n_rbf].

    Returns:
        torch.Tensor: expansion of shape [..., n_rbf].
    """
    # exp(-(x - mu)^2 / (2 sigma^2)), broadcast over the basis dimension
    diff = inputs.unsqueeze(-1) - offsets
    coeff = -0.5 / widths.pow(2)
    return torch.exp(coeff * diff.pow(2))
16 |
17 |
class GaussianRBF(nn.Module):
    r"""Gaussian radial basis functions on an equidistant grid of centers."""

    def __init__(
        self, n_rbf: int, cutoff: float, start: float = 0.0, trainable: bool = False
    ):
        """
        Args:
            n_rbf: total number of Gaussian functions, :math:`N_g`. Must be >= 2.
            cutoff: center of last Gaussian function, :math:`\mu_{N_g}`
            start: center of first Gaussian function, :math:`\mu_0`.
            trainable: If True, widths and offset of Gaussian functions
                are adjusted during training process.

        Raises:
            ValueError: if fewer than two basis functions are requested.
        """
        super(GaussianRBF, self).__init__()
        self.n_rbf = n_rbf

        if n_rbf < 2:
            # at least two centers are needed to derive the width from the grid spacing
            raise ValueError("GaussianRBF requires n_rbf >= 2.")

        # centers on an equidistant grid between start and cutoff; every width
        # equals the grid spacing
        offset = torch.linspace(start, cutoff, n_rbf)
        widths = torch.abs(offset[1] - offset[0]) * torch.ones_like(offset)

        if trainable:
            self.widths = nn.Parameter(widths)
            self.offsets = nn.Parameter(offset)
        else:
            self.register_buffer("widths", widths)
            self.register_buffer("offsets", offset)

    def forward(self, inputs: torch.Tensor):
        return gaussian_rbf(inputs, self.offsets, self.widths)
49 |
50 |
class GaussianRBFCentered(nn.Module):
    r"""Gaussian radial basis functions centered at the origin."""

    def __init__(
        self, n_rbf: int, cutoff: float, start: float = 1.0, trainable: bool = False
    ):
        """
        Args:
            n_rbf: total number of Gaussian functions, :math:`N_g`.
            cutoff: width of last Gaussian function, :math:`\mu_{N_g}`
            start: width of first Gaussian function, :math:`\mu_0`.
            trainable: If True, widths of Gaussian functions
                are adjusted during training process.
        """
        super().__init__()
        self.n_rbf = n_rbf

        # All Gaussians sit at the origin; only their widths vary on an
        # equidistant grid between start and cutoff.
        widths = torch.linspace(start, cutoff, n_rbf)
        offset = torch.zeros_like(widths)

        if trainable:
            self.widths = nn.Parameter(widths)
            self.offsets = nn.Parameter(offset)
        else:
            self.register_buffer("widths", widths)
            self.register_buffer("offsets", offset)

    def forward(self, inputs: torch.Tensor):
        return gaussian_rbf(inputs, self.offsets, self.widths)
80 |
81 |
class BesselRBF(nn.Module):
    """
    Sine for radial basis functions with coulomb decay (0th order bessel).

    References:

    .. [#dimenet] Klicpera, Groß, Günnemann:
       Directional message passing for molecular graphs.
       ICLR 2020
    """

    def __init__(self, n_rbf: int, cutoff: float):
        """
        Args:
            cutoff: radial cutoff
            n_rbf: number of basis functions.
        """
        super().__init__()
        self.n_rbf = n_rbf
        # Frequencies k * pi / cutoff for k = 1 ... n_rbf
        self.register_buffer("freqs", torch.arange(1, n_rbf + 1) * pi / cutoff)

    def forward(self, inputs):
        arguments = inputs.unsqueeze(-1) * self.freqs
        numerator = torch.sin(arguments)
        # Guard the division at r = 0 by substituting 1 in the denominator
        # (the output is then 0 there, since sin(0) = 0).
        denominator = torch.where(
            inputs == 0, torch.tensor(1.0, device=inputs.device), inputs
        )
        return numerator / denominator.unsqueeze(-1)
111 |
--------------------------------------------------------------------------------
/src/schnetpack/nn/scatter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | __all__ = ["scatter_add"]
5 |
6 |
def scatter_add(
    x: torch.Tensor, idx_i: torch.Tensor, dim_size: int, dim: int = 0
) -> torch.Tensor:
    """
    Sum over values with the same indices.

    Thin public wrapper around the TorchScript-compiled `_scatter_add`.

    Args:
        x: input values
        idx_i: index of center atom i
        dim_size: size of the dimension after reduction
        dim: the dimension to reduce

    Returns:
        reduced input

    """
    return _scatter_add(x, idx_i, dim_size, dim)
24 |
25 |
26 | @torch.jit.script
27 | def _scatter_add(
28 | x: torch.Tensor, idx_i: torch.Tensor, dim_size: int, dim: int = 0
29 | ) -> torch.Tensor:
30 | shape = list(x.shape)
31 | shape[dim] = dim_size
32 | tmp = torch.zeros(shape, dtype=x.dtype, device=x.device)
33 | y = tmp.index_add(dim, idx_i, x)
34 | return y
35 |
--------------------------------------------------------------------------------
/src/schnetpack/nn/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Optional
2 |
3 | import torch
4 | from torch import nn as nn
5 |
6 | __all__ = ["replicate_module", "derivative_from_atomic", "derivative_from_molecular"]
7 |
8 | from torch.autograd import grad
9 |
10 |
def replicate_module(
    module_factory: Callable[[], nn.Module], n: int, share_params: bool
):
    """
    Build an nn.ModuleList with `n` entries produced by `module_factory`.

    If `share_params` is True, a single instance is created and repeated, so all
    entries share their parameters; otherwise `n` independent instances are built.
    """
    if share_params:
        shared = module_factory()
        return nn.ModuleList([shared for _ in range(n)])
    return nn.ModuleList([module_factory() for _ in range(n)])
19 |
20 |
def derivative_from_molecular(
    fx: torch.Tensor,
    dx: torch.Tensor,
    create_graph: bool = False,
    retain_graph: bool = False,
):
    """
    Compute the derivative of `fx` with respect to `dx` if the leading dimension of `fx` is the number of molecules
    (e.g. energies, dipole moments, etc).

    Args:
        fx (torch.Tensor): Tensor for which the derivative is taken.
        dx (torch.Tensor): Tensor the derivative is taken with respect to.
        create_graph (bool): Create computational graph (needed for higher derivatives).
        retain_graph (bool): Keep the computational graph.

    Returns:
        torch.Tensor: derivative of `fx` with respect to `dx`.
    """
    fx_shape = fx.shape
    dx_shape = dx.shape
    # The result combines the non-leading dims of fx with those of dx
    final_shape = (dx_shape[0], *fx_shape[1:], *dx_shape[1:])

    # Flatten all but the molecule dimension
    fx_flat = fx.view(fx_shape[0], -1)

    # One backward pass per flattened output component
    gradients = []
    for component in range(fx_flat.shape[1]):
        target = fx_flat[..., component]
        dfdx_i = grad(
            target,
            dx,
            torch.ones_like(target),
            create_graph=create_graph,
            retain_graph=retain_graph,
        )[0]
        gradients.append(dfdx_i)

    return torch.stack(gradients, dim=1).view(final_shape)
63 |
64 |
def derivative_from_atomic(
    fx: torch.Tensor,
    dx: torch.Tensor,
    n_atoms: torch.Tensor,
    create_graph: bool = False,
    retain_graph: bool = False,
):
    """
    Compute the derivative of a tensor with the leading dimension of (batch x atoms) with respect to another tensor of
    either dimension (batch * atoms) (e.g. R) or (batch * atom pairs) (e.g. Rij). This function is primarily used for
    computing Hessians and Hessian-like response properties (e.g. nuclear spin-spin couplings). The final tensor will
    have the shape ( batch * atoms * atoms x ....).

    This is quite inefficient, use with care.

    NOTE(review): `grad` is called once per scalar entry of `fx`, each time with
    the given `retain_graph`; with the default `retain_graph=False` every call
    after the first would fail because the graph is freed. Callers appear to be
    expected to pass `retain_graph=True` (or `create_graph=True`) — confirm.

    Args:
        fx (torch.Tensor): Tensor for which the derivative is taken.
        dx (torch.Tensor): Derivative.
        n_atoms (torch.Tensor): Tensor containing the number of atoms for each molecule.
        create_graph (bool): Create computational graph.
        retain_graph (bool): Keep the computational graph.

    Returns:
        torch.Tensor: derivative of `fx` with respect to `dx`.
    """
    # Split input tensor for easier bookkeeping (one chunk per molecule)
    fxm = fx.split(list(n_atoms))

    dfdx = []

    # running atom offset marking where the current molecule starts in dx
    n_mol = 0
    # Compute all derivatives
    for idx in range(len(fxm)):
        # flatten the molecule's entries to one scalar per degree of freedom
        fx = fxm[idx].view(-1)

        # Generate the individual derivatives
        dfdx_mol = []
        for i in range(fx.shape[0]):
            dfdx_i = grad(
                fx[i],
                dx,
                torch.ones_like(fx[i]),
                create_graph=create_graph,
                retain_graph=retain_graph,
            )[0]

            # keep only the rows of dx belonging to the current molecule
            dfdx_mol.append(dfdx_i[n_mol : n_mol + n_atoms[idx], ...])

        # Build molecular matrix and reshape to (atoms^2 x 3 x 3) blocks
        dfdx_mol = torch.stack(dfdx_mol, dim=0)
        dfdx_mol = dfdx_mol.view(n_atoms[idx], 3, n_atoms[idx], 3)
        dfdx_mol = dfdx_mol.permute(0, 2, 1, 3)
        dfdx_mol = dfdx_mol.reshape(n_atoms[idx] ** 2, 3, 3)

        dfdx.append(dfdx_mol)

        n_mol += n_atoms[idx]

    # Accumulate everything
    dfdx = torch.cat(dfdx, dim=0)

    return dfdx
127 |
--------------------------------------------------------------------------------
/src/schnetpack/properties.py:
--------------------------------------------------------------------------------
1 | """
2 | Keys to access structure properties.
3 |
4 | Note: Had to be moved out of Structure class for TorchScript compatibility
5 |
6 | """
7 |
8 | from typing import Final
9 |
idx: Final[str] = "_idx"  #: example index

## structure
Z: Final[str] = "_atomic_numbers" #: nuclear charge
position: Final[str] = "_positions" #: atom positions
R: Final[str] = position #: atom positions

cell: Final[str] = "_cell" #: unit cell
strain: Final[str] = "strain"
pbc: Final[str] = "_pbc" #: periodic boundary conditions

seg_m: Final[str] = "_seg_m" #: start indices of systems
idx_m: Final[str] = "_idx_m" #: indices of systems
idx_i: Final[str] = "_idx_i" #: indices of center atoms
idx_j: Final[str] = "_idx_j" #: indices of neighboring atoms
idx_i_lr: Final[str] = "_idx_i_lr" #: indices of center atoms for long-range
idx_j_lr: Final[str] = "_idx_j_lr" #: indices of neighboring atoms for long-range

lidx_i: Final[str] = "_idx_i_local" #: local indices of center atoms (within system)
lidx_j: Final[str] = (
    "_idx_j_local" #: local indices of neighboring atoms (within system)
)
Rij: Final[str] = "_Rij" #: vectors pointing from center atoms to neighboring atoms
Rij_lr: Final[str] = (
    "_Rij_lr" #: vectors pointing from center atoms to neighboring atoms for long range
)
n_atoms: Final[str] = "_n_atoms" #: number of atoms
offsets: Final[str] = "_offsets" #: cell offset vectors
offsets_lr: Final[str] = "_offsets_lr" #: cell offset vectors for long range

R_strained: Final[str] = (
    position + "_strained"
) #: atom positions with strain-dependence
cell_strained: Final[str] = cell + "_strained" #: unit cell with strain-dependence

n_nbh: Final[str] = "_n_nbh" #: number of neighbors

#: indices of center atom triples
idx_i_triples: Final[str] = "_idx_i_triples"

#: indices of first neighboring atom triples
idx_j_triples: Final[str] = "_idx_j_triples"

#: indices of second neighboring atom triples
idx_k_triples: Final[str] = "_idx_k_triples"

## chemical properties
energy: Final[str] = "energy"
forces: Final[str] = "forces"  #: atomic forces
stress: Final[str] = "stress"
masses: Final[str] = "masses"  #: atomic masses
dipole_moment: Final[str] = "dipole_moment"
polarizability: Final[str] = "polarizability"
hessian: Final[str] = "hessian"
dipole_derivatives: Final[str] = "dipole_derivatives"
polarizability_derivatives: Final[str] = "polarizability_derivatives"
total_charge: Final[str] = "total_charge"
partial_charges: Final[str] = "partial_charges"
spin_multiplicity: Final[str] = "spin_multiplicity"
electric_field: Final[str] = "electric_field"
magnetic_field: Final[str] = "magnetic_field"
nuclear_magnetic_moments: Final[str] = "nuclear_magnetic_moments"
shielding: Final[str] = "shielding"
nuclear_spin_coupling: Final[str] = "nuclear_spin_coupling"

## external fields needed for different response properties
# Maps a response-property key to the external-field keys that must be present
# in the inputs for that property to be computed.
required_external_fields = {
    dipole_moment: [electric_field],
    dipole_derivatives: [electric_field],
    partial_charges: [electric_field],
    polarizability: [electric_field],
    polarizability_derivatives: [electric_field],
    shielding: [magnetic_field],
    nuclear_spin_coupling: [magnetic_field],
}
85 |
--------------------------------------------------------------------------------
/src/schnetpack/representation/__init__.py:
--------------------------------------------------------------------------------
1 | from .schnet import *
2 | from .painn import *
3 | from .field_schnet import *
4 | from .so3net import *
5 |
--------------------------------------------------------------------------------
/src/schnetpack/train/__init__.py:
--------------------------------------------------------------------------------
1 | from .callbacks import *
2 | from .lr_scheduler import *
3 |
--------------------------------------------------------------------------------
/src/schnetpack/train/callbacks.py:
--------------------------------------------------------------------------------
1 | from copy import copy
2 | from typing import Dict
3 |
4 | from pytorch_lightning.callbacks import Callback
5 | from pytorch_lightning.callbacks import ModelCheckpoint as BaseModelCheckpoint
6 |
7 | from torch_ema import ExponentialMovingAverage as EMA
8 |
9 | import torch
10 | import os
11 | from pytorch_lightning.callbacks import BasePredictionWriter
12 | from typing import List, Any
13 | from schnetpack.task import AtomisticTask
14 | from schnetpack import properties
15 | from collections import defaultdict
16 |
17 |
18 | __all__ = ["ModelCheckpoint", "PredictionWriter", "ExponentialMovingAverage"]
19 |
20 |
class PredictionWriter(BasePredictionWriter):
    """
    Callback that persists model predictions to disk via ``torch.save``.
    """

    def __init__(
        self,
        output_dir: str,
        write_interval: str,
        write_idx: bool = False,
    ):
        """
        Args:
            output_dir: directory in which prediction files are stored
            write_interval: one of ["batch", "epoch", "batch_and_epoch"]
            write_idx: if True, also store the molecular ids of all atoms,
                which is needed for atomic properties such as forces.
        """
        super().__init__(write_interval)
        self.output_dir = output_dir
        self.write_idx = write_idx
        os.makedirs(output_dir, exist_ok=True)

    def write_on_batch_end(
        self,
        trainer,
        pl_module: AtomisticTask,
        prediction: Any,
        batch_indices: List[int],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int,
    ):
        # one subdirectory per dataloader, one file per batch
        batch_dir = os.path.join(self.output_dir, str(dataloader_idx))
        os.makedirs(batch_dir, exist_ok=True)
        torch.save(prediction, os.path.join(batch_dir, f"{batch_idx}.pt"))

    def write_on_epoch_end(
        self,
        trainer,
        pl_module: AtomisticTask,
        predictions: List[Any],
        batch_indices: List[Any],
    ):
        # Gather the per-batch prediction dicts into per-property lists.
        # NOTE(review): only predictions[0] (the first dataloader) is written —
        # confirm this is intended for multi-dataloader prediction runs.
        collected = defaultdict(list)
        for batch_prediction in predictions[0]:
            for property_name, data in batch_prediction.items():
                if property_name == properties.idx_m and not self.write_idx:
                    continue
                collected[property_name].append(data)

        merged = {
            property_name: torch.concat(parts)
            for property_name, parts in collected.items()
        }

        # save concatenated predictions
        torch.save(
            merged,
            os.path.join(self.output_dir, "predictions.pt"),
        )
82 |
83 |
class ModelCheckpoint(BaseModelCheckpoint):
    """
    Like the PyTorch Lightning ModelCheckpoint callback,
    but also saves the best inference model with activated post-processing
    """

    def __init__(self, model_path: str, do_postprocessing=True, *args, **kwargs):
        """
        Args:
            model_path: path where the best inference model is stored.
            do_postprocessing: whether the stored inference model has
                post-processing activated.
        """
        super().__init__(*args, **kwargs)
        self.model_path = model_path
        self.do_postprocessing = do_postprocessing

    def on_validation_end(self, trainer, pl_module: AtomisticTask) -> None:
        # keep references needed later in _update_best_and_save
        self.trainer = trainer
        self.task = pl_module
        super().on_validation_end(trainer, pl_module)

    def _update_best_and_save(
        self, current: torch.Tensor, trainer, monitor_candidates: Dict[str, Any]
    ):
        # save model checkpoint
        super()._update_best_and_save(current, trainer, monitor_candidates)

        # a NaN metric can never be the best score; map it to +/-inf
        if isinstance(current, torch.Tensor) and torch.isnan(current):
            current = torch.tensor(float("inf" if self.mode == "min" else "-inf"))

        # save best inference model
        if current == self.best_model_score:
            if self.trainer.strategy.local_rank == 0:
                # remove references to trainer and data loaders to avoid pickle error in ddp
                # Bug fix: honor the configured `do_postprocessing` flag instead
                # of hard-coding True.
                self.task.save_model(
                    self.model_path, do_postprocessing=self.do_postprocessing
                )
114 |
115 |
class ExponentialMovingAverage(Callback):
    """
    Callback maintaining an exponential moving average (EMA) of the model
    parameters. The averaged parameters are swapped in before validation and
    restored before training, so validation (and checkpoints created from it)
    uses the EMA weights.
    """

    def __init__(self, decay, *args, **kwargs):
        """
        Args:
            decay: EMA decay rate (values close to 1.0 average over many steps).
        """
        # Bug fix: initialize the Callback base class.
        super().__init__()
        self.decay = decay
        self.ema = None
        # state dict received before the EMA object exists (see load_state_dict)
        self._to_load = None

    def on_fit_start(self, trainer, pl_module: AtomisticTask):
        if self.ema is None:
            self.ema = EMA(pl_module.model.parameters(), decay=self.decay)
        if self._to_load is not None:
            # apply a state dict that arrived before the EMA was created
            self.ema.load_state_dict(self._to_load)
            self._to_load = None

        # load average parameters, to have same starting point as after validation
        self.ema.store()
        self.ema.copy_to()

    def on_train_epoch_start(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
    ) -> None:
        # switch back to the raw (non-averaged) parameters for training
        self.ema.restore()

    def on_train_batch_end(self, trainer, pl_module: AtomisticTask, *args, **kwargs):
        self.ema.update()

    def on_validation_epoch_start(
        self, trainer: "pl.Trainer", pl_module: AtomisticTask, *args, **kwargs
    ):
        # validate with averaged parameters; originals are kept for restore()
        self.ema.store()
        self.ema.copy_to()

    def load_state_dict(self, state_dict):
        if "ema" in state_dict:
            if self.ema is None:
                # EMA not created yet; defer application to on_fit_start
                self._to_load = state_dict["ema"]
            else:
                self.ema.load_state_dict(state_dict["ema"])

    def state_dict(self):
        return {"ema": self.ema.state_dict()}
156 |
--------------------------------------------------------------------------------
/src/schnetpack/train/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | __all__ = ["ReduceLROnPlateau"]
4 |
5 |
class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau):
    """
    Extends PyTorch ReduceLROnPlateau by exponential smoothing of the monitored metric

    """

    def __init__(
        self,
        optimizer,
        mode="min",
        factor=0.1,
        patience=10,
        threshold=1e-4,
        threshold_mode="rel",
        cooldown=0,
        min_lr=0,
        eps=1e-8,
        verbose=False,
        smoothing_factor=0.0,
    ):
        """
        Args:
            optimizer (Optimizer): Wrapped optimizer.
            mode (str): One of `min`, `max`. In `min` mode, lr will
                be reduced when the quantity monitored has stopped
                decreasing; in `max` mode it will be reduced when the
                quantity monitored has stopped increasing. Default: 'min'.
            factor (float): Factor by which the learning rate will be
                reduced. new_lr = lr * factor. Default: 0.1.
            patience (int): Number of epochs with no improvement after
                which learning rate will be reduced. Default: 10.
            threshold (float): Threshold for measuring the new optimum,
                to only focus on significant changes. Default: 1e-4.
            threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
                dynamic_threshold = best * ( 1 + threshold ) in 'max'
                mode or best * ( 1 - threshold ) in `min` mode.
                In `abs` mode, dynamic_threshold = best + threshold in
                `max` mode or best - threshold in `min` mode. Default: 'rel'.
            cooldown (int): Number of epochs to wait before resuming
                normal operation after lr has been reduced. Default: 0.
            min_lr (float or list): A scalar or a list of scalars. A
                lower bound on the learning rate of all param groups
                or each group respectively. Default: 0.
            eps (float): Minimal decay applied to lr. If the difference
                between new and old lr is smaller than eps, the update is
                ignored. Default: 1e-8.
            verbose (bool): If ``True``, prints a message to stdout for
                each update (ignored on PyTorch versions that removed the
                option). Default: ``False``.
            smoothing_factor: smoothing_factor of exponential moving average
                (0.0 disables smoothing and reproduces the base behavior).
        """
        base_kwargs = dict(
            mode=mode,
            factor=factor,
            patience=patience,
            threshold=threshold,
            threshold_mode=threshold_mode,
            cooldown=cooldown,
            min_lr=min_lr,
            eps=eps,
        )
        try:
            super().__init__(optimizer=optimizer, verbose=verbose, **base_kwargs)
        except TypeError:
            # `verbose` was deprecated and later removed from the base
            # scheduler in newer PyTorch releases; fall back without it.
            super().__init__(optimizer=optimizer, **base_kwargs)
        self.smoothing_factor = smoothing_factor
        # exponential moving average of the metric (None until the first step)
        self.ema_loss = None

    def step(self, metrics, epoch=None):
        """
        Update the smoothed metric and delegate scheduling to the base class.

        Args:
            metrics: current value of the monitored quantity.
            epoch: deprecated epoch argument forwarded to the base class.
        """
        current = float(metrics)
        if self.ema_loss is None:
            self.ema_loss = current
        else:
            self.ema_loss = (
                self.smoothing_factor * self.ema_loss
                + (1.0 - self.smoothing_factor) * current
            )
        # Bug fix: pass the smoothed metric to the scheduler. Previously the
        # raw `current` value was used, so `smoothing_factor` had no effect.
        super().step(self.ema_loss, epoch)
85 |
--------------------------------------------------------------------------------
/src/schnetpack/train/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchmetrics import Metric
3 | from torchmetrics.functional.regression.mae import (
4 | _mean_absolute_error_compute,
5 | _mean_absolute_error_update,
6 | )
7 |
8 | from typing import Optional, Tuple
9 |
10 | __all__ = ["TensorDiagonalMeanAbsoluteError"]
11 |
12 |
class TensorDiagonalMeanAbsoluteError(Metric):
    """
    Custom torch metric for monitoring the mean absolute error on the diagonals and offdiagonals of tensors, e.g.
    polarizability.
    """

    is_differentiable = True
    higher_is_better = False
    # running sum of absolute errors (reduced by summation across processes)
    sum_abs_error: torch.Tensor
    # running count of observed elements
    total: torch.Tensor

    def __init__(
        self,
        diagonal: Optional[bool] = True,
        diagonal_dims: Optional[Tuple[int, int]] = (-2, -1),
        dist_sync_on_step=False,
    ) -> None:
        """

        Args:
            diagonal (bool): If true, diagonal values are used, if False off-diagonal.
            diagonal_dims (tuple(int,int)): axes of the square matrix for which the diagonals should be considered.
            dist_sync_on_step (bool): synchronize.
        """
        # call `self.add_state`for every internal state that is needed for the metrics computations
        # dist_reduce_fx indicates the function that should be used to reduce
        # state from multiple processes
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.add_state("sum_abs_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

        self.diagonal = diagonal
        self.diagonal_dims = diagonal_dims
        # Boolean mask selecting the requested elements; built lazily on first use.
        # NOTE(review): the mask is cached from the first input's shape and device;
        # later inputs with a different trailing shape or device would reuse a
        # stale mask — confirm inputs are homogeneous across `update` calls.
        self._diagonal_mask = None

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """
        Update the metric.

        Args:
            preds (torch.Tensor): network predictions.
            target (torch.Tensor): reference values.
        """
        # update metric states
        preds = self._input_format(preds)
        target = self._input_format(target)

        sum_abs_error, n_obs = _mean_absolute_error_update(preds, target)

        self.sum_abs_error += sum_abs_error
        self.total += n_obs

    def compute(self) -> torch.Tensor:
        """
        Compute the final metric.

        Returns:
            torch.Tensor: mean absolute error of diagonal or offdiagonal elements.
        """
        # compute final result
        return _mean_absolute_error_compute(self.sum_abs_error, self.total)

    def _input_format(self, x) -> torch.Tensor:
        """
        Extract diagonal / offdiagonal elements from input tensor.

        Args:
            x (torch.Tensor): input tensor.

        Returns:
            torch.Tensor: extracted and flattened elements (diagonal / offdiagonal)
        """
        if self._diagonal_mask is None:
            self._diagonal_mask = self._init_diag_mask(x)
        return x.masked_select(self._diagonal_mask)

    def _init_diag_mask(self, x: torch.Tensor) -> torch.Tensor:
        """
        Initialize the mask for extracting the diagonal elements based on the given axes and the shape of the
        input tensor.

        Args:
            x (torch.Tensor): input tensor.

        Returns:
            torch.Tensor: Boolean diagonal mask.
        """
        tensor_shape = x.shape
        dim_0 = tensor_shape[self.diagonal_dims[0]]
        dim_1 = tensor_shape[self.diagonal_dims[1]]

        if not dim_0 == dim_1:
            raise AssertionError(
                "Found different size for diagonal dimensions, expected square sub matrix."
            )

        # identity matrix broadcast into the full tensor shape via a size-1 view
        view = [1 for _ in tensor_shape]
        view[self.diagonal_dims[0]] = dim_0
        view[self.diagonal_dims[1]] = dim_1

        diag_mask = torch.eye(dim_0, device=x.device, dtype=torch.long).view(view)

        # == 1 selects the diagonal entries, != 1 the off-diagonal ones
        if self.diagonal:
            return diag_mask == 1
        else:
            return diag_mask != 1
120 |
--------------------------------------------------------------------------------
/src/schnetpack/transform/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Transforms are applied before and/or after the model. They can be used, e.g., for calculating
neighbor lists, casting, unit conversion or data augmentation. Some can be applied before batching,
4 | i.e. to single systems, when loading the data. This is necessary for pre-processing and includes
5 | neighbor lists, for example. On the other hand, transforms need to be able to handle batches
for post-processing. The flags `is_postprocessor` and `is_preprocessor` indicate how the transforms
may be used. The attribute `mode` of a transform is set automatically to either "pre" or "post".
8 | """
9 |
10 | from .atomistic import *
11 | from .casting import *
12 | from .neighborlist import *
13 | from .response import *
14 | from .base import *
15 |
--------------------------------------------------------------------------------
/src/schnetpack/transform/base.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Dict
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | import schnetpack as spk
7 |
8 | __all__ = [
9 | "Transform",
10 | "TransformException",
11 | ]
12 |
13 |
class TransformException(Exception):
    """Exception type for errors raised by transforms."""

    pass
16 |
17 |
class Transform(nn.Module):
    """
    Common base class for all transforms.

    Transforms are ``nn.Module`` subclasses that can act as pre-processing
    layers (applied to single examples when loading data) or as
    post-processing layers (applied to batches). They may also be used for
    other parts of a model that need to be initialized based on data.

    Subclasses implement ``forward`` and must return the (possibly modified)
    ``inputs`` dictionary.
    """

    def datamodule(self, value):
        """
        Hook for automatically extracting required information from the data
        module when the PyTorch Lightning integration is used. Implementations
        should also offer a way to set this information manually, so they
        remain usable without PL.

        Do not store the datamodule itself — that breaks torchscript
        conversion!
        """
        pass

    def forward(
        self,
        inputs: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        # must be provided by concrete transforms
        raise NotImplementedError

    def teardown(self):
        # optional cleanup hook; default does nothing
        pass
51 |
--------------------------------------------------------------------------------
/src/schnetpack/transform/casting.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | from typing import Dict, Optional
3 | from schnetpack.utils import as_dtype
4 |
5 | import torch
6 |
7 | from .base import Transform
8 |
9 | __all__ = ["CastMap", "CastTo32", "CastTo64"]
10 |
11 |
class CastMap(Transform):
    """
    Cast input tensors according to a dtype mapping.
    """

    is_preprocessor: bool = True
    is_postprocessor: bool = True

    def __init__(self, type_map: Dict[str, str]):
        """
        Args:
            type_map: maps source dtype name to target dtype name (as strings)
        """
        super().__init__()
        self.type_map = type_map

    def forward(
        self,
        inputs: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        for key, tensor in inputs.items():
            # dtype name without the leading "torch." prefix
            dtype_name = str(tensor.dtype).split(".")[-1]
            if dtype_name in self.type_map:
                target_dtype = as_dtype(self.type_map[dtype_name])
                inputs[key] = tensor.to(dtype=target_dtype)
        return inputs
37 |
38 |
class CastTo32(CastMap):
    """Cast all float64 tensors to float32"""

    def __init__(self):
        # delegate to CastMap with a fixed float64 -> float32 mapping
        super().__init__(type_map={"float64": "float32"})
44 |
45 |
class CastTo64(CastMap):
    """Cast all float32 tensors to float64"""

    def __init__(self):
        # delegate to CastMap with a fixed float32 -> float64 mapping
        super().__init__(type_map={"float32": "float64"})
51 |
--------------------------------------------------------------------------------
/src/schnetpack/transform/response.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from schnetpack.transform.base import Transform
4 | from schnetpack import properties
5 |
6 | from typing import Dict, List
7 |
8 | __all__ = ["SplitShielding"]
9 |
10 |
class SplitShielding(Transform):
    """
    Transform that splits a shielding tensor into one entry per atom type.
    """

    is_preprocessor: bool = True
    is_postprocessor: bool = False

    def __init__(
        self,
        shielding_key: str,
        atomic_numbers: List[int],
    ):
        """
        Args:
            shielding_key (str): name of the shielding tensor in the model inputs.
            atomic_numbers (list(int)): list of atomic numbers used to split the shielding tensor.
        """
        super(SplitShielding, self).__init__()

        self.shielding_key = shielding_key
        self.atomic_numbers = atomic_numbers

        # one output key per atom type, e.g. "<shielding_key>_<Z>"
        self.model_outputs = [
            "{:s}_{:d}".format(self.shielding_key, z) for z in self.atomic_numbers
        ]

    def forward(
        self,
        inputs: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        shielding = inputs[self.shielding_key]

        # select the rows belonging to each atom type and store them under
        # the per-element key
        per_element = {
            "{:s}_{:d}".format(self.shielding_key, z): shielding[
                inputs[properties.Z] == z, :, :
            ]
            for z in self.atomic_numbers
        }

        inputs.update(per_element)

        return inputs
55 |
--------------------------------------------------------------------------------
/src/schnetpack/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .compatibility import *
2 | import importlib
3 | import torch
4 | from typing import Type, Union, List
5 |
6 | from schnetpack import properties as spk_properties
7 |
# Mapping from dtype name strings (e.g. "float32") to torch dtype objects.
TORCH_DTYPES = {
    "float32": torch.float32,
    "float64": torch.float64,
    "float": torch.float,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "half": torch.half,
    "uint8": torch.uint8,
    "int8": torch.int8,
    "int16": torch.int16,
    "short": torch.short,
    "int32": torch.int32,
    "int": torch.int,
    "int64": torch.int64,
    "long": torch.long,
    "complex64": torch.complex64,
    "cfloat": torch.cfloat,
    "complex128": torch.complex128,
    "cdouble": torch.cdouble,
    "quint8": torch.quint8,
    "qint8": torch.qint8,
    "qint32": torch.qint32,
    "bool": torch.bool,
}

# Also accept fully qualified names such as "torch.float32".
TORCH_DTYPES.update({"torch." + k: v for k, v in TORCH_DTYPES.items()})
34 |
35 |
def as_dtype(dtype_str: str) -> torch.dtype:
    """Convert a string (e.g. "float32" or "torch.float32") to torch.dtype.

    Raises:
        KeyError: if ``dtype_str`` is not a known dtype name.
    """
    return TORCH_DTYPES[dtype_str]
39 |
40 |
def int2precision(precision: Union[int, torch.dtype]):
    """
    Resolve an integer precision (e.g. ``32``) to the matching torch
    floating-point dtype. A ``torch.dtype`` instance is returned unchanged.

    Args:
        precision (int, torch.dtype): Target precision.

    Returns:
        torch.dtype: Floating point precision.

    Raises:
        AttributeError: if no ``torch.float<precision>`` dtype exists.
    """
    if isinstance(precision, torch.dtype):
        return precision
    dtype = getattr(torch, f"float{precision}", None)
    if dtype is None:
        raise AttributeError(f"Unknown float precision {precision}")
    return dtype
59 |
60 |
def str2class(class_path: str) -> Type:
    """
    Resolve a fully qualified class path to the class type.

    Args:
        class_path: module path to class, e.g. ``module.submodule.classname``

    Returns:
        class type
    """
    # split off the class name; everything before the last dot is the module
    module_name, _, class_name = class_path.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)
76 |
77 |
def required_fields_from_properties(properties: List[str]) -> List[str]:
    """
    Determine required external fields based on the response properties to be computed.

    Args:
        properties (list(str)): List of response properties for which external fields should be determined.

    Returns:
        list(str): List of required external fields (deduplicated).
    """
    field_map = spk_properties.required_external_fields

    fields = set()
    for prop in properties:
        fields.update(field_map.get(prop, []))

    return list(fields)
97 |
--------------------------------------------------------------------------------
/src/schnetpack/utils/compatibility.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import warnings
3 | from typing import Any, Union
4 |
5 |
6 | __all__ = ["load_model"]
7 |
8 |
def load_model(
    model_path: str, device: Union[torch.device, str] = "cpu", **kwargs: Any
) -> torch.nn.Module:
    """
    Load a SchNetPack model from a Torch file, enabling compatibility with models trained using earlier versions of
    SchNetPack. This function imports the old model and automatically updates it to the format used in the current
    SchNetPack version. To ensure proper functionality, the Torch model object must include a version tag, such as
    spk_version="2.0.4".

    Args:
        model_path (str): Path to the saved model file.
        device (torch.device or str): Device on which to load the model. Defaults to "cpu".
        **kwargs (Any): Additional keyword arguments for `torch.load`.

    Returns:
        torch.nn.Module: Loaded model.
    """

    # Untagged models are treated as pre-2.0.4: tag them so the conversion
    # chain below picks them up.
    def _convert_from_older(model: torch.nn.Module) -> torch.nn.Module:
        model.spk_version = "2.0.4"
        return model

    # 2.0.4 -> 2.1.0: ensure the representation has an `electronic_embeddings`
    # attribute (empty by default).
    def _convert_from_v2_0_4(model: torch.nn.Module) -> torch.nn.Module:
        if not hasattr(model.representation, "electronic_embeddings"):
            model.representation.electronic_embeddings = []
        model.spk_version = "2.1.0"
        return model

    # NOTE(review): weights_only=False unpickles arbitrary objects; only load
    # model files from trusted sources.
    model = torch.load(model_path, map_location=device, weights_only=False, **kwargs)

    if not hasattr(model, "spk_version"):
        # make warning that model has no version information
        warnings.warn(
            "Model was saved without version information. Conversion to current version may fail."
        )
        model = _convert_from_older(model)

    if model.spk_version == "2.0.4":
        model = _convert_from_v2_0_4(model)

    return model
50 |
--------------------------------------------------------------------------------
/src/schnetpack/utils/script.py:
--------------------------------------------------------------------------------
1 | from typing import Union, Dict, Sequence
2 |
3 | import pytorch_lightning as pl
4 | import rich
5 | import yaml
6 | from omegaconf import DictConfig, OmegaConf
7 | from pytorch_lightning.utilities import rank_zero_only
8 | from rich.syntax import Syntax
9 | from rich.tree import Tree
10 |
11 |
12 | __all__ = ["log_hyperparameters", "print_config"]
13 |
14 |
def empty(*args, **kwargs):
    """No-op; used to disable callables such as ``trainer.logger.log_hyperparams``."""
    pass
17 |
18 |
def todict(config: Union[DictConfig, Dict]):
    """Convert an OmegaConf config into a plain Python dict, resolving interpolations."""
    config_dict = yaml.safe_load(OmegaConf.to_yaml(config, resolve=True))
    return config_dict
22 |
23 |
@rank_zero_only
def log_hyperparameters(
    config: DictConfig,
    model: pl.LightningModule,
    trainer: pl.Trainer,
) -> None:
    """
    This saves Hydra config using Lightning loggers.

    Runs only on rank zero. The ``model`` argument is currently unused.
    """

    # send hparams to all loggers
    trainer.logger.log_hyperparams(config)

    # disable logging any more hyperparameters for all loggers
    trainer.logger.log_hyperparams = empty
39 |
40 |
@rank_zero_only
def print_config(
    config: DictConfig,
    fields: Sequence[str] = (
        "run",
        "globals",
        "data",
        "model",
        "task",
        "trainer",
        "callbacks",
        "logger",
        "seed",
    ),
    resolve: bool = True,
) -> None:
    """Prints content of DictConfig using Rich library and its tree structure.

    Args:
        config (DictConfig): Config.
        fields (Sequence[str], optional): main config sections to print, in order.
        resolve (bool, optional): Whether to resolve reference fields of DictConfig.
    """

    style = "dim"
    tree = Tree(
        f":gear: Running with the following config:", style=style, guide_style=style
    )

    for section_name in fields:
        branch = tree.add(section_name, style=style, guide_style=style)

        section = config.get(section_name)
        # DictConfig sections are rendered as YAML, anything else via str()
        if isinstance(section, DictConfig):
            rendered = OmegaConf.to_yaml(section, resolve=resolve)
        else:
            rendered = str(section)

        branch.add(Syntax(rendered, "yaml"))

    rich.print(tree)
82 |
--------------------------------------------------------------------------------
/src/scripts/spkconvert:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 |
4 | from ase.db import connect
5 | import numpy as np
6 | from tqdm import tqdm
7 |
8 |
# CLI utility: bring the metadata of an ASE database in line with the
# SchNetPack 2.0 format (per-property atomrefs dict, explicit distance and
# property units) and optionally expand the leading dimension of stored
# properties.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Set units of an ASE dataset, e.g. to convert from SchNetPack 1.0 to the new format."
    )
    parser.add_argument(
        "data_path",
        help="Path to ASE DB dataset",
    )
    parser.add_argument(
        "--distunit",
        help="Distance unit as string, corresponding to ASE units (e.g. `Ang`)",
    )
    parser.add_argument(
        "--propunit",
        help="Property units as string, corresponding "
        "to ASE units (e.g. `kcal/mol/Ang`), in the form: `property1:unit1,property2:unit2`",
    )
    parser.add_argument(
        "--expand_property_dims",
        default=[],
        nargs='+',
        help="Expanding the first dimension of the given property "
        "(required for example for old FieldSchNet datasets). "
        "Add property names here in the form 'property1 property2 property3'",
    )
    args = parser.parse_args()
    # read the existing metadata once
    with connect(args.data_path) as db:
        meta = db.metadata
    print(meta)

    # SchNetPack 1.0 stored atomrefs as a 2D array plus `atref_labels`;
    # convert that layout into a {property: [per-element refs]} dict.
    if "atomrefs" not in meta.keys():
        meta["atomrefs"] = {}
    elif "atref_labels" in meta.keys():
        old_atref = np.array(meta["atomrefs"])
        new_atomrefs = {}
        labels = meta["atref_labels"]
        if type(labels) is str:
            labels = [labels]
        for i, label in enumerate(labels):
            print(i, label, old_atref[:, i])
            new_atomrefs[label] = list(old_atref[:, i])
        meta["atomrefs"] = new_atomrefs
        del meta["atref_labels"]

    if args.distunit:
        # guard against the common mix-up: in ASE, `A` means Ampere
        if args.distunit == "A":
            raise ValueError(
                "The provided unit (A for Ampere) is not a valid distance unit according to the ASE unit"
                " definitions. You probably mean `Ang`/`Angstrom`. Please also check your property units!"
            )
        meta["_distance_unit"] = args.distunit

    if args.propunit:
        if "_property_unit_dict" not in meta.keys():
            meta["_property_unit_dict"] = {}

        # entries have the form `property:unit`, separated by commas
        for p in args.propunit.split(","):
            prop, unit = p.split(":")
            meta["_property_unit_dict"][prop] = unit

    # write the (possibly modified) metadata back to the DB
    with connect(args.data_path) as db:
        db.metadata = meta

    if args.expand_property_dims is not None and len(args.expand_property_dims) > 0:

        # rewrite every row, adding a leading axis to the selected properties
        # NOTE(review): iteration assumes DB row ids are contiguous 1..len(db)
        # — confirm this holds for all supported ASE DB backends
        with connect(args.data_path) as db:
            for i in tqdm(range(len(db))):
                atoms_row = db.get(i + 1)
                data = {}
                for p, v in atoms_row.data.items():
                    if p in args.expand_property_dims:
                        data[p] = np.expand_dims(v, 0)
                    else:
                        data[p] = v
                db.update(i + 1, data=data)
--------------------------------------------------------------------------------
/src/scripts/spkdeploy:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import torch
3 | import torch.nn as nn
4 | from schnetpack.transform import CastTo64, CastTo32, AddOffsets
5 | import argparse
6 |
7 |
8 | # This script is supposed to take a pytorch model and save a just in time compiled version of it.
9 | # This is needed to run the model with LAMMPS.
10 | # For further info see examples/howtos/lammps.rst
11 |
12 | # Note that this script is designed for models that predict atomic forces via automatic differentiation (utilizing response modules).
13 | # Hence this script will not work for models without response modules.
14 |
15 |
def get_jit_model(model):
    """Return a TorchScript-compiled version of ``model``.

    The postprocessor list is rebuilt first so that it can be scripted:
    dtype-cast modules are dropped entirely, and ``AddOffsets`` means are
    converted to single precision.
    """
    scriptable_postprocessors = nn.ModuleList()
    for module in model.postprocessors:
        module_type = type(module)
        # dtype casting modules are skipped altogether
        if module_type in (CastTo64, CastTo32):
            continue
        # offsets must be stored as float for scripting
        if module_type == AddOffsets:
            module.mean = module.mean.float()
        scriptable_postprocessors.append(module)
    model.postprocessors = scriptable_postprocessors
    return torch.jit.script(model)
31 |
32 |
def save_jit_model(model, model_path):
    """Compile ``model`` with TorchScript and store it at ``model_path``.

    The representation cutoff is attached as an ASCII-encoded extra file so
    external consumers can read it back from the serialized model.
    """
    deployable = get_jit_model(model)

    cutoff_value = deployable.representation.cutoff.item()
    extra_files = {"cutoff": str(cutoff_value).encode("ascii")}

    torch.jit.save(deployable, model_path, _extra_files=extra_files)
41 |
42 |
# CLI entry point: load a trained SchNetPack model and store its deployed
# (TorchScript) form for use with e.g. LAMMPS.
if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("model_path")
    parser.add_argument("deployed_model_path")
    parser.add_argument("--device", type=str, default="cpu")
    args = parser.parse_args()

    # NOTE(review): loading a full model object requires its class to be
    # importable; newer torch versions may also require weights_only=False —
    # confirm against the torch version pinned by the project.
    model = torch.load(args.model_path, map_location=args.device)
    save_jit_model(model, args.deployed_model_path)

    print(f"stored deployed model at {args.deployed_model_path}.")
55 |
--------------------------------------------------------------------------------
/src/scripts/spkmd:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# Thin console wrapper: delegates to the molecular dynamics CLI entry point.
import schnetpack.md.cli as cli

if __name__ == "__main__":
    cli.simulate()
6 |
--------------------------------------------------------------------------------
/src/scripts/spkpredict:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# Thin console wrapper: delegates to the prediction CLI entry point.
import schnetpack.cli as cli

if __name__ == "__main__":
    cli.predict()
6 |
--------------------------------------------------------------------------------
/src/scripts/spktrain:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# Thin console wrapper: delegates to the training CLI entry point.
import schnetpack.cli as cli

if __name__ == "__main__":
    cli.train()
6 |
--------------------------------------------------------------------------------
/tests/README.md:
--------------------------------------------------------------------------------
1 | # Test suite
2 |
3 | SchNetPack has unit tests that you can run to make sure that everything is working properly.
4 |
5 | ## Test dependencies
6 | To install the additional test dependencies, install with:
7 | ```
8 | pip install schnetpack[test]
9 | ```
10 |
Or, if you installed from source, run the following from within your local copy of the SchNetPack repository:
12 | ```
13 | pip install .[test]
14 | ```
15 |
16 | ## Run tests
17 | In order to run the tests, run the following command from the root of the repository:
18 | ```
19 | pytest tests
20 | ```
21 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/atomistic-machine-learning/schnetpack/601c8edb6cdf35af0b2ebf27b4b9e39386b79fb9/tests/__init__.py
--------------------------------------------------------------------------------
/tests/atomistic/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/atomistic-machine-learning/schnetpack/601c8edb6cdf35af0b2ebf27b4b9e39386b79fb9/tests/atomistic/__init__.py
--------------------------------------------------------------------------------
/tests/atomistic/test_response.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from schnetpack.data.loader import _atoms_collate_fn
5 | import schnetpack as spk
6 |
7 |
def test_strain(environment_periodic):
    """The Strain transform must leave the collated pair offsets unchanged."""
    _, sample, neighbor_data = environment_periodic
    sample.update(neighbor_data)

    batch = _atoms_collate_fn([sample, sample])
    strained = spk.atomistic.Strain()(batch)

    offsets_before = batch[spk.properties.offsets].detach().numpy()
    offsets_after = strained[spk.properties.offsets].detach().numpy()
    assert np.allclose(offsets_before, offsets_after)
17 |
--------------------------------------------------------------------------------
/tests/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/atomistic-machine-learning/schnetpack/601c8edb6cdf35af0b2ebf27b4b9e39386b79fb9/tests/data/__init__.py
--------------------------------------------------------------------------------
/tests/data/conftest.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/atomistic-machine-learning/schnetpack/601c8edb6cdf35af0b2ebf27b4b9e39386b79fb9/tests/data/conftest.py
--------------------------------------------------------------------------------
/tests/data/test_data.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | import numpy as np
4 | from schnetpack.data import *
5 | import os
6 | import schnetpack.properties as structure
7 | from schnetpack.data import calculate_stats, AtomsLoader
8 |
9 |
@pytest.fixture
def asedbpath(tmpdir):
    """Location of a throwaway ASE DB file inside the pytest tmpdir."""
    db_file = os.path.join(tmpdir, "test.db")
    return db_file
13 |
14 |
@pytest.fixture(scope="function")
def asedb(asedbpath, example_data, property_units):
    """Fresh ASE database filled with the example systems; removed on teardown."""
    db = ASEAtomsData.create(
        datapath=asedbpath, distance_unit="A", property_unit_dict=property_units
    )
    structures, properties = zip(*example_data)
    db.add_systems(property_list=properties, atoms_list=structures)

    yield db

    # teardown: remove the DB file so each test starts from scratch
    os.remove(db.datapath)
    del db
28 |
29 |
def test_asedb(asedb, example_data):
    """Database creation, property-loading flags, and metadata updates."""
    assert os.path.exists(asedb.datapath)
    assert len(asedb) == len(example_data)
    assert set(asedb.metadata["_property_unit_dict"].keys()) == set(
        asedb.available_properties
    )
    assert asedb.metadata["_property_unit_dict"] == asedb.units

    structure_keys = [
        structure.Z,
        structure.R,
        structure.cell,
        structure.pbc,
    ]
    index_keys = [structure.n_atoms, structure.idx]

    # by default, full structure info and every property is returned
    entry = asedb[0]
    assert set(entry.keys()) == set(
        structure_keys + index_keys + asedb.available_properties
    )

    # restricting load_properties limits the returned property keys
    load_properties = asedb.available_properties[0:2]
    asedb.load_properties = load_properties
    entry = asedb[0]
    assert set(entry.keys()) == set(structure_keys + index_keys + load_properties)

    # disabling load_structure drops Z / R / cell / pbc
    asedb.load_structure = False
    entry = asedb[0]
    assert set(entry.keys()) == set(index_keys + load_properties)

    asedb.update_metadata(test=1)
    assert asedb.metadata["test"] == 1
78 |
79 |
def test_asedb_getprops(asedb):
    """iter_properties yields entries with the same keys as direct indexing."""
    first_entry = list(asedb.iter_properties(0))[0]

    expected_keys = {
        structure.Z,
        structure.R,
        structure.cell,
        structure.pbc,
        structure.n_atoms,
        structure.idx,
    } | set(asedb.available_properties)

    assert set(first_entry.keys()) == expected_keys
93 |
94 |
def test_stats():
    """calculate_stats gives mean 0 / stddev 1 for a symmetric toy dataset,
    independent of the loader batch size."""
    samples = []
    for i in range(6):
        Z = torch.tensor([1, 1, 1])
        sign = 1.0 if i % 2 == 0 else -1.0
        samples.append(
            {
                structure.Z: Z,
                structure.n_atoms: torch.tensor([len(Z)]),
                "property1": torch.tensor([1.0 + len(Z) * sign]),
                "property2": torch.tensor([sign]),
            }
        )

    atomref = {"property1": torch.ones((100,)) / 3.0}
    for batch_size in range(1, 7):
        stats = calculate_stats(
            AtomsLoader(samples, batch_size=batch_size),
            {"property1": True, "property2": False},
            atomref=atomref,
        )
        for prop in ("property1", "property2"):
            mean, stddev = stats[prop]
            assert np.allclose(mean.numpy(), np.array([0.0]))
            assert np.allclose(stddev.numpy(), np.array([1.0]))
119 |
120 |
def test_asedb_add(asedb, example_data):
    """Systems added via an atoms object or via explicit arrays are equal.

    Adds the same system twice — once as an ``ase.Atoms`` object and once
    with Z/R/cell/pbc passed explicitly — and checks both stored entries
    contain identical tensors.
    """
    # PEP 8: avoid the ambiguous single-letter name `l`
    n_initial = len(asedb)

    at, props = example_data[0]
    asedb.add_system(atoms=at, **props)

    props.update(
        {
            structure.Z: at.numbers,
            structure.R: at.positions,
            structure.cell: at.cell,
            structure.pbc: at.pbc,
        }
    )
    asedb.add_system(**props)

    p1 = asedb[n_initial]
    p2 = asedb[n_initial + 1]
    for k, v in p1.items():
        if k != "_idx":
            # idiomatic type check instead of `type(v) == torch.Tensor`
            assert isinstance(v, torch.Tensor), k
            assert (p2[k] == v).all(), v
143 |
--------------------------------------------------------------------------------
/tests/data/test_datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pytest
3 | import numpy as np
4 |
5 | from schnetpack.datasets import QM9, MD17, rMD17
6 |
7 |
@pytest.fixture
def test_qm9_path():
    """Path to the bundled QM9 test database."""
    here = os.path.dirname(__file__)
    return os.path.join(here, "../testdata/test_qm9.db")
12 |
13 |
@pytest.mark.skip(
    "Run only local, not in CI. Otherwise takes too long and requires downloading "
    + "the data"
)
def test_qm9(test_qm9_path):
    """Split sizes and dataloader batch counts for a tiny QM9 setup."""
    qm9 = QM9(
        test_qm9_path,
        num_train=10,
        num_val=5,
        batch_size=5,
        remove_uncharacterized=True,
    )
    assert len(qm9.train_dataset) == 10
    assert len(qm9.val_dataset) == 5
    assert len(qm9.test_dataset) == 5

    # with batch_size=5 the splits above yield 2 / 1 / 1 batches
    assert len(list(qm9.train_dataloader())) == 2
    assert len(list(qm9.val_dataloader())) == 1
    assert len(list(qm9.test_dataloader())) == 1
29 |
30 | ds = [b for b in qm9.train_dataloader()]
31 | assert len(ds) == 2
32 | ds = [b for b in qm9.val_dataloader()]
33 | assert len(ds) == 1
34 | ds = [b for b in qm9.test_dataloader()]
35 | assert len(ds) == 1
36 |
37 |
@pytest.fixture
def test_md17_path():
    """Path for a temporary MD17 test database."""
    here = os.path.dirname(__file__)
    return os.path.join(here, "../testdata/tmp/test_md17.db")
42 |
43 |
@pytest.mark.skip(
    "Run only local, not in CI. Otherwise takes too long and requires downloading "
    + "the data"
)
def test_md17(test_md17_path):
    """Split sizes and dataloader batch counts for a tiny MD17 setup."""
    datamodule = MD17(
        test_md17_path,
        num_train=10,
        num_val=5,
        num_test=5,
        batch_size=5,
        molecule="uracil",
    )
    datamodule.prepare_data()
    datamodule.setup()

    for dataset, expected_size in (
        (datamodule.train_dataset, 10),
        (datamodule.val_dataset, 5),
        (datamodule.test_dataset, 5),
    ):
        assert len(dataset) == expected_size

    # with batch_size=5 the splits above yield 2 / 1 / 1 batches
    assert len(list(datamodule.train_dataloader())) == 2
    assert len(list(datamodule.val_dataloader())) == 1
    assert len(list(datamodule.test_dataloader())) == 1
69 |
70 |
@pytest.fixture
def test_rmd17_path():
    """Path for a temporary rMD17 test database."""
    here = os.path.dirname(__file__)
    return os.path.join(here, "../testdata/tmp/test_rmd17.db")
75 |
76 |
@pytest.mark.skip(
    "Run only local, not in CI. Otherwise takes too long and requires downloading "
    + "the data"
)
def test_rmd17(test_rmd17_path):
    """rMD17 splits have the configured sizes and do not overlap."""
    datamodule = rMD17(
        test_rmd17_path,
        num_train=950,
        num_val=50,
        num_test=1000,
        batch_size=5,
        molecule="uracil",
    )
    datamodule.prepare_data()
    datamodule.setup()

    assert len(datamodule.train_dataset) == 950
    assert len(datamodule.val_dataset) == 50
    assert len(datamodule.test_dataset) == 1000

    # every pair of splits must be disjoint
    splits = {
        "train": datamodule.train_dataset.subset_idx,
        "val": datamodule.val_dataset.subset_idx,
        "test": datamodule.test_dataset.subset_idx,
    }
    for first, second in (("train", "val"), ("train", "test"), ("val", "test")):
        assert len(np.intersect1d(splits[first], splits[second])) == 0
102 |
--------------------------------------------------------------------------------
/tests/data/test_loader.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from schnetpack.transform import ASENeighborList
4 | from schnetpack.data.loader import _atoms_collate_fn
5 | import schnetpack.properties as structure
6 |
7 |
def test_collate_noenv(single_atom, two_atoms):
    """Collating two systems adds idx_m and keeps all per-system keys."""
    collated = _atoms_collate_fn([single_atom, two_atoms])
    for key in single_atom:
        assert key in collated
    assert structure.idx_m in collated
    assert (collated[structure.idx_m] == torch.tensor((0, 1, 1))).all()
14 |
15 |
def test_collate_env(single_atom, two_atoms):
    """Collation shifts neighbor indices so they stay valid within the batch."""
    nbl = ASENeighborList(cutoff=5.0)
    collated = _atoms_collate_fn([nbl(single_atom), nbl(two_atoms)])

    for key in single_atom:
        assert key in collated
    assert structure.idx_m in collated
    assert (collated[structure.idx_m] == torch.tensor((0, 1, 1))).all()

    # the pair (0-1) of the second system becomes (1-2) after collation
    expected_pairs = {structure.idx_i: (1, 2), structure.idx_j: (2, 1)}
    for key, expected in expected_pairs.items():
        assert (collated[key] == torch.tensor(expected)).all(), collated[key]
30 |
--------------------------------------------------------------------------------
/tests/data/test_transforms.py:
--------------------------------------------------------------------------------
1 | import schnetpack.properties as structure
2 | import pytest
3 | import torch
4 | from ase.data import atomic_masses
5 | from schnetpack.transform import *
6 |
7 |
def assert_consistent(orig, transformed):
    """Assert that every entry of ``orig`` appears unchanged in ``transformed``."""
    for k in orig:
        assert (orig[k] == transformed[k]).all(), f"Changed value: {k}"
11 |
12 |
@pytest.fixture(params=[0, 1])
def neighbor_list(request):
    """Parametrized over all available neighbor-list implementations."""
    return (ASENeighborList, TorchNeighborList)[request.param]
17 |
18 |
19 | class TestNeighborLists:
20 | """
21 | Test for different neighbor lists defined in neighbor_list using the Argon environment fixtures (periodic and
22 | non-periodic).
23 |
24 | """
25 |
26 | def test_neighbor_list(self, neighbor_list, environment):
27 | cutoff, props, neighbors_ref = environment
28 | neighbor_list = neighbor_list(cutoff)
29 | neighbors = neighbor_list(props)
30 | R = props[structure.R]
31 | neighbors[structure.Rij] = (
32 | R[neighbors[structure.idx_j]]
33 | - R[neighbors[structure.idx_i]]
34 | + props[structure.offsets]
35 | )
36 |
37 | neighbors = self._sort_neighbors(neighbors)
38 | neighbors_ref = self._sort_neighbors(neighbors_ref)
39 |
40 | for nbl, nbl_ref in zip(neighbors, neighbors_ref):
41 | torch.testing.assert_close(nbl, nbl_ref)
42 |
43 | def _sort_neighbors(self, neighbors):
44 | """
45 | Routine for sorting the index, shift and distance vectors to allow comparison between different
46 | neighbor list implementations.
47 |
48 | Args:
49 | neighbors: Input dictionary holding system neighbor information (idx_i, idx_j, cell_offset and Rij)
50 |
51 | Returns:
52 | torch.LongTensor: indices of central atoms in each pair
53 | torch.LongTensor: indices of each neighbor
54 | torch.LongTensor: cell offsets
55 | torch.Tensor: distance vectors associated with each pair
56 | """
57 | idx_i = neighbors[structure.idx_i]
58 | idx_j = neighbors[structure.idx_j]
59 | Rij = neighbors[structure.Rij]
60 |
61 | sort_idx = self._get_unique_idx(idx_i, idx_j, Rij)
62 |
63 | return idx_i[sort_idx], idx_j[sort_idx], Rij[sort_idx]
64 |
65 | @staticmethod
66 | def _get_unique_idx(
67 | idx_i: torch.Tensor, idx_j: torch.Tensor, offsets: torch.Tensor
68 | ):
69 | """
70 | Compute unique indices for every neighbor pair based on the central atom, the neighbor and the cell the
71 | neighbor belongs to. This is used for sorting the neighbor lists in order to compare between different
72 | implementations.
73 |
74 | Args:
75 | idx_i: indices of central atoms in each pair
76 | idx_j: indices of each neighbor
77 | offsets: cell offsets
78 |
79 | Returns:
80 | torch.LongTensor: indices used for sorting each tensor in a unique manner
81 | """
82 | n_max = torch.max(torch.abs(offsets))
83 |
84 | n_repeats = 2 * n_max + 1
85 | n_atoms = torch.max(idx_i) + 1
86 |
87 | unique_idx = (
88 | n_repeats**3 * (n_atoms * idx_i + idx_j)
89 | + (offsets[:, 0] + n_max)
90 | + n_repeats * (offsets[:, 1] + n_max)
91 | + n_repeats**2 * (offsets[:, 2] + n_max)
92 | )
93 |
94 | return torch.argsort(unique_idx)
95 |
96 |
def test_single_atom(single_atom, neighbor_list, cutoff):
    """A lone atom yields empty neighbor lists and is otherwise unmodified."""
    nbl = neighbor_list(cutoff)
    result = nbl(single_atom)

    R = result[structure.R]
    result[structure.Rij] = (
        R[result[structure.idx_j]]
        - R[result[structure.idx_i]]
        + result[structure.offsets]
    )

    assert_consistent(single_atom, result)
    for key in (structure.offsets, structure.idx_i, structure.idx_j):
        assert len(result[key]) == 0
111 |
112 |
def test_cast(single_atom):
    """CastTo32 converts float64 entries and leaves all other dtypes alone."""
    double_keys = [k for k, v in single_atom.items() if v.dtype is torch.float64]
    remaining_dtypes = {
        k: v.dtype for k, v in single_atom.items() if v.dtype is not torch.float64
    }
    assert len(double_keys) > 0, single_atom

    casted = CastTo32()(single_atom)
    for k in casted:
        if k in double_keys:
            assert casted[k].dtype is torch.float32
        else:
            assert casted[k].dtype is remaining_dtypes[k]
127 |
128 |
def test_remove_com(four_atoms):
    """After SubtractCenterOfMass, the mass-weighted position sum is zero."""
    shifted = SubtractCenterOfMass()(four_atoms)
    masses = atomic_masses[four_atoms[structure.Z]]

    com = torch.tensor([0.0, 0.0, 0.0])
    for position, mass in zip(shifted[structure.position], masses):
        com = com + position * mass

    torch.testing.assert_close(com, torch.tensor([0.0, 0.0, 0.0]))
139 |
140 |
def test_remove_cog(four_atoms):
    """After SubtractCenterOfGeometry, the positions sum to the origin."""
    shifted = SubtractCenterOfGeometry()(four_atoms)

    cog = torch.tensor([0.0, 0.0, 0.0])
    for position in shifted[structure.position]:
        cog = cog + position

    torch.testing.assert_close(cog, torch.tensor([0.0, 0.0, 0.0]))
149 |
--------------------------------------------------------------------------------
/tests/nn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/atomistic-machine-learning/schnetpack/601c8edb6cdf35af0b2ebf27b4b9e39386b79fb9/tests/nn/__init__.py
--------------------------------------------------------------------------------
/tests/nn/test_activations.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | import schnetpack as spk
5 |
6 |
def test_activation_softplus():
    """shifted_softplus matches log(1 + exp(x)) - log(2) on several inputs."""

    def reference(x):
        return torch.log(1.0 + torch.exp(x)) - np.log(2)

    torch.manual_seed(42)
    inputs = [
        # simple tensor
        torch.tensor([0.0, 1.0, 0.5, 2.0]),
        # random tensors, moderate and large magnitude
        torch.randn((10, 5), dtype=torch.double),
        10 * torch.randn((10, 5), dtype=torch.double),
    ]
    for x in inputs:
        assert torch.allclose(
            reference(x), spk.nn.shifted_softplus(x), atol=0.0, rtol=1.0e-7
        )
19 | assert torch.allclose(expt, spk.nn.shifted_softplus(x), atol=0.0, rtol=1.0e-7)
20 |
21 |
def test_shape_ssp():
    """shifted_softplus preserves the input shape."""
    sample = torch.rand(10)
    assert spk.nn.shifted_softplus(sample).shape == sample.shape
25 | assert in_data.shape == out_data.shape
26 |
--------------------------------------------------------------------------------
/tests/nn/test_cutoff.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | from schnetpack.nn.cutoff import CosineCutoff, MollifierCutoff
5 |
6 |
def test_cutoff_cosine():
    """CosineCutoff values match the analytic form inside and beyond the radius."""
    radius = 1.8
    cutoff = CosineCutoff(cutoff=radius)
    assert abs(radius - cutoff.cutoff) < 1.0e-12

    # random distances in [0, 1) — all inside the cutoff radius
    torch.manual_seed(42)
    dist = torch.rand((10, 5, 20), dtype=torch.float)
    expected = 0.5 * (1.0 + torch.cos(dist * np.pi / radius))
    assert torch.allclose(expected, cutoff(dist), atol=0.0, rtol=1.0e-7)

    # scaled distances partially exceed the radius and must be clamped to zero
    scaled = 3.5 * dist
    expected = 0.5 * (1.0 + torch.cos(scaled * np.pi / radius))
    expected[scaled >= radius] = 0.0
    assert torch.allclose(expected, cutoff(scaled), atol=0.0, rtol=1.0e-7)
22 |
23 |
def test_cutoff_mollifier():
    """MollifierCutoff equals its analytic form and vanishes beyond the radius."""
    radius = 2.3
    cutoff = MollifierCutoff(cutoff=radius)
    assert abs(radius - cutoff.cutoff) < 1.0e-12

    # zero distance maps exactly to one
    zeros = torch.zeros((4, 1, 1))
    assert torch.allclose(torch.ones(4, 1, 1), cutoff(zeros), atol=0.0, rtol=1.0e-7)

    # random distances in [0, 1) — all inside the cutoff radius
    torch.manual_seed(42)
    dist = torch.rand((1, 3, 9), dtype=torch.float)
    expected = torch.exp(1.0 - 1.0 / (1.0 - (dist / radius) ** 2))
    assert torch.allclose(expected, cutoff(dist), atol=0.0, rtol=1.0e-7)

    # scaled distances partially exceed the radius and must give exactly zero
    scaled = 3.8 * dist
    expected = torch.exp(1.0 - 1.0 / (1.0 - (scaled / radius) ** 2))
    expected[scaled >= radius] = 0.0
    assert torch.allclose(expected, cutoff(scaled), atol=0.0, rtol=1.0e-7)
43 |
--------------------------------------------------------------------------------
/tests/nn/test_radial.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from schnetpack.nn.radial import GaussianRBF
4 |
5 |
def test_smear_gaussian_one_distance():
    """Non-trainable GaussianRBF expansion of a single distance."""
    dist = torch.tensor([[[1.0]]])
    rbf = GaussianRBF(n_rbf=6, cutoff=5.0, trainable=False)

    squared_offsets = torch.tensor([[[1.0, 0.0, 1.0, 4.0, 9.0, 16.0]]])
    expected = torch.exp(-0.5 * squared_offsets)
    assert torch.allclose(expected, rbf(dist), atol=0.0, rtol=1.0e-7)

    # a non-trainable basis exposes no parameters
    assert list(rbf.parameters()) == []
14 |
15 |
def test_smear_gaussian_one_distance_trainable():
    """Trainable GaussianRBF gives the same values but exposes parameters."""
    dist = torch.tensor([[[1.0]]])
    expected = torch.exp(-0.5 * torch.tensor([[[1.0, 0.0, 1.0, 4.0, 9.0, 16.0]]]))

    rbf = GaussianRBF(n_rbf=6, cutoff=5.0, trainable=True)
    assert torch.allclose(expected, rbf(dist), atol=0.0, rtol=1.0e-7)

    # two parameter tensors of length n_rbf
    params = list(rbf.parameters())
    assert len(params) == 2
    assert all(len(p) == 6 for p in params)
26 |
27 |
def test_smear_gaussian():
    """GaussianRBF with unit spacing on a small batch of distances."""
    dist = torch.tensor([[[0.0, 1.0, 1.5], [0.5, 1.5, 3.0]]])
    # 4 Gaussian functions with 1. spacing
    rbf = GaussianRBF(start=1.0, cutoff=4.0, n_rbf=4)

    # |distance - center| for the centers 1..4
    centered = torch.tensor(
        [
            [
                [[1, 2, 3, 4], [0, 1, 2, 3], [0.5, 0.5, 1.5, 2.5]],
                [[0.5, 1.5, 2.5, 3.5], [0.5, 0.5, 1.5, 2.5], [2, 1, 0, 1]],
            ]
        ]
    )
    expected = torch.exp(-0.5 * centered**2)
    assert torch.allclose(expected, rbf(dist), atol=0.0, rtol=1.0e-7)
    assert list(rbf.parameters()) == []
44 |
45 |
def test_smear_gaussian_trainable():
    """Trainable GaussianRBF with 0.75 spacing on a small batch of distances."""
    dist = torch.tensor([[[0.0, 1.0, 1.5, 0.25], [0.5, 1.5, 3.0, 1.0]]])
    # 5 Gaussian functions with 0.75 spacing
    rbf = GaussianRBF(start=1.0, cutoff=4.0, n_rbf=5, trainable=True)

    # |distance - center| for centers 1, 1.75, ..., 4
    centered = torch.tensor(
        [
            [
                [
                    [1, 1.75, 2.5, 3.25, 4.0],
                    [0, 0.75, 1.5, 2.25, 3.0],
                    [0.5, 0.25, 1.0, 1.75, 2.5],
                    [0.75, 1.5, 2.25, 3.0, 3.75],
                ],
                [
                    [0.5, 1.25, 2.0, 2.75, 3.5],
                    [0.5, 0.25, 1.0, 1.75, 2.5],
                    [2.0, 1.25, 0.5, 0.25, 1.0],
                    [0, 0.75, 1.5, 2.25, 3.0],
                ],
            ]
        ]
    )
    expected = torch.exp((-0.5 / 0.75**2) * centered**2)
    assert torch.allclose(expected, rbf(dist), atol=0.0, rtol=1.0e-7)

    # two parameter tensors of length n_rbf
    params = list(rbf.parameters())
    assert len(params) == 2
    assert all(len(p) == 5 for p in params)
75 |
--------------------------------------------------------------------------------
/tests/nn/test_schnet.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | import schnetpack.properties as structure
5 | import schnetpack as spk
6 | import numpy as np
7 | from ase.neighborlist import neighbor_list
8 |
9 | from schnetpack.representation.schnet import SchNet
10 |
11 | # TODO:make proper timing and golden tests
12 |
13 |
@pytest.fixture
def indexed_data(example_data, batch_size):
    """Build a flat, pre-collated SchNet input dict from the example systems.

    Accumulates up to ``batch_size`` structures, computing neighbor lists via
    ASE with a fixed 5.0 cutoff and shifting per-structure pair indices by the
    running atom count so all systems share one global index space.
    """
    Z = []       # atomic numbers per structure
    R = []       # positions per structure
    C = []       # cells per structure
    seg_m = []   # cumulative atom offsets (segment boundaries)
    ind_i = []   # central-atom pair indices, globally shifted
    ind_j = []   # neighbor pair indices, globally shifted
    ind_S = []   # cell shift vectors per pair
    Rij = []     # pair distance vectors

    n_atoms = 0
    n_pairs = 0
    for i in range(len(example_data)):
        # record the atom offset of this structure before appending it
        seg_m.append(n_atoms)
        atoms = example_data[i][0]
        atoms.set_pbc(False)
        Z.append(atoms.numbers)
        R.append(atoms.positions)
        C.append(atoms.cell)
        # "ijSD": pair indices, cell shifts, and distance vectors
        idx_i, idx_j, idx_S, rij = neighbor_list(
            "ijSD", atoms, 5.0, self_interaction=False
        )
        _, seg_im = np.unique(idx_i, return_counts=True)
        # shift pair indices into the global (batch-wide) atom index space
        ind_i.append(idx_i + n_atoms)
        ind_j.append(idx_j + n_atoms)
        ind_S.append(idx_S)
        Rij.append(rij.astype(np.float32))
        n_atoms += len(atoms)
        n_pairs += len(idx_i)
        if i + 1 >= batch_size:
            break
    # closing segment boundary: total number of atoms
    seg_m.append(n_atoms)

    # flatten the per-structure lists into contiguous arrays
    Z = np.hstack(Z)
    R = np.vstack(R).astype(np.float32)
    C = np.array(C).astype(np.float32)
    seg_m = np.hstack(seg_m)
    ind_i = np.hstack(ind_i)
    ind_j = np.hstack(ind_j)
    ind_S = np.vstack(ind_S)
    Rij = np.vstack(Rij)

    inputs = {
        structure.Z: torch.tensor(Z),
        structure.position: torch.tensor(R),
        structure.cell: torch.tensor(C),
        structure.idx_m: torch.tensor(seg_m),
        structure.idx_j: torch.tensor(ind_j),
        structure.idx_i: torch.tensor(ind_i),
        structure.Rij: torch.tensor(Rij),
    }

    return inputs
68 |
69 |
def test_schnet_new_coo(indexed_data, benchmark):
    """Benchmark a plain (eager) SchNet forward pass."""
    representation = SchNet(
        n_atom_basis=128,
        n_interactions=3,
        radial_basis=spk.nn.GaussianRBF(n_rbf=20, cutoff=5.0),
        cutoff_fn=spk.nn.CosineCutoff(5.0),
    )

    benchmark(representation, indexed_data)
81 |
82 |
def test_schnet_new_script(indexed_data, benchmark):
    """Benchmark a TorchScript-compiled SchNet forward pass."""
    representation = SchNet(
        n_atom_basis=128,
        n_interactions=3,
        radial_basis=spk.nn.GaussianRBF(n_rbf=20, cutoff=5.0),
        cutoff_fn=spk.nn.CosineCutoff(5.0),
    )
    scripted = torch.jit.script(representation)

    # one warm-up call before benchmarking the scripted module
    scripted(indexed_data)
    benchmark(scripted, indexed_data)
97 |
--------------------------------------------------------------------------------
/tests/testdata/md_ethanol.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/atomistic-machine-learning/schnetpack/601c8edb6cdf35af0b2ebf27b4b9e39386b79fb9/tests/testdata/md_ethanol.model
--------------------------------------------------------------------------------
/tests/testdata/md_ethanol.xyz:
--------------------------------------------------------------------------------
1 | 9
2 |
3 | C -4.92196480914482 1.53680877549233 -0.06612792847094
4 | C -3.41079303549336 1.45138155063184 -0.14009009720834
5 | H -5.22648850340463 2.28202241947302 0.66236410391492
6 | H -5.34004680800574 0.57895313793668 0.22257334141131
7 | H -5.33193076526251 1.80898014947387 -1.03229511269262
8 | H -3.00368348713509 1.18933429199764 0.83479697695625
9 | H -2.99557504133053 2.41817570143478 -0.41886385105291
10 | O -3.07553304550781 0.47652256654287 -1.09348059854212
11 | H -2.13350450471551 0.40432140701697 -1.15817683431555
12 |
--------------------------------------------------------------------------------
/tests/testdata/si16.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/atomistic-machine-learning/schnetpack/601c8edb6cdf35af0b2ebf27b4b9e39386b79fb9/tests/testdata/si16.model
--------------------------------------------------------------------------------
/tests/testdata/test_qm9.db:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/atomistic-machine-learning/schnetpack/601c8edb6cdf35af0b2ebf27b4b9e39386b79fb9/tests/testdata/test_qm9.db
--------------------------------------------------------------------------------
/tests/testdata/tmp/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/atomistic-machine-learning/schnetpack/601c8edb6cdf35af0b2ebf27b4b9e39386b79fb9/tests/testdata/tmp/.gitkeep
--------------------------------------------------------------------------------
/tests/user_config/user_exp.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | user_exp: True
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | [tox]
2 | envlist = py38,qa
3 |
4 | [testenv]
5 | deps =
6 | pytest
7 | pytest-cov
8 | pytest-datadir
9 | commands =
10 | pip install -r docs/sphinx-requirements.txt
11 | pip install -e .[test]
12 | pip list
13 | pytest --cov-report term --cov=schnetpack
14 |
15 | # prevent exit when error is encountered
16 | ignore_errors = true
17 |
18 | [testenv:qa]
19 | deps =
20 | black
21 | commands =
22 | black -v -l 88 --check --diff src/schnetpack
23 | # prevent exit when error is encountered
24 | ignore_errors = true
25 |
--------------------------------------------------------------------------------