├── .devcontainer ├── Dockerfile └── devcontainer.json ├── .git_archival.txt ├── .gitattributes ├── .github ├── PULL_REQUEST_TEMPLATE.md └── workflows │ ├── ci.yaml │ └── docs.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── Makefile ├── README.md ├── devtools └── envs │ └── base.yaml ├── docs ├── development.md ├── examples ├── index.md ├── javascripts │ └── mathjax.js └── scripts │ └── gen_ref_pages.py ├── examples ├── README.md ├── compute-energy.ipynb ├── conformer-minimization.ipynb ├── md-simulations.ipynb └── parameter-gradients.ipynb ├── mkdocs.yml ├── pyproject.toml └── smee ├── __init__.py ├── _constants.py ├── _models.py ├── converters ├── __init__.py ├── openff │ ├── __init__.py │ ├── _openff.py │ ├── nonbonded.py │ └── valence.py └── openmm │ ├── __init__.py │ ├── _ff.py │ ├── _openmm.py │ ├── nonbonded.py │ └── valence.py ├── geometry.py ├── mm ├── __init__.py ├── _config.py ├── _fe.py ├── _mm.py ├── _ops.py ├── _reporters.py └── _utils.py ├── potentials ├── __init__.py ├── _potentials.py ├── nonbonded.py └── valence.py ├── py.typed ├── tests ├── __init__.py ├── conftest.py ├── convertors │ ├── __init__.py │ ├── openff │ │ ├── __init__.py │ │ ├── test_nonbonded.py │ │ ├── test_openff.py │ │ └── test_valence.py │ └── openmm │ │ ├── __init__.py │ │ ├── test_ff.py │ │ └── test_openmm.py ├── data │ └── de-ff.offxml ├── mm │ ├── __init__.py │ ├── conftest.py │ ├── test_fe.py │ ├── test_mm.py │ ├── test_ops.py │ └── test_reporters.py ├── potentials │ ├── __init__.py │ ├── conftest.py │ ├── test_nonbonded.py │ ├── test_potentials.py │ └── test_valence.py ├── test_geometry.py ├── test_models.py ├── test_utils.py └── utils.py └── utils.py /.devcontainer/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM --platform=linux/x86_64 condaforge/mambaforge:latest 2 | 3 | RUN apt update \ 4 | && apt install -y git make build-essential \ 5 | && rm -rf /var/lib/apt/lists/* 6 | 
-------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "build": { "dockerfile": "Dockerfile" }, 3 | "postCreateCommand": "make env" 4 | } -------------------------------------------------------------------------------- /.git_archival.txt: -------------------------------------------------------------------------------- 1 | node: aca69b9da4c67916c6e59ed2c435fffd4c49a2b6 2 | node-date: 2024-12-13T02:24:40Z 3 | describe-name: 0.16.2 4 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | .git_archival.txt export-subst 2 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | Provide a brief description of the PR's purpose here. 
3 | 4 | ## Status 5 | - [ ] Ready to go -------------------------------------------------------------------------------- /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | concurrency: 4 | group: ${{ github.ref }} 5 | cancel-in-progress: true 6 | 7 | on: 8 | push: { branches: [ "main" ] } 9 | pull_request: { branches: [ "main" ] } 10 | 11 | jobs: 12 | test: 13 | 14 | runs-on: ubuntu-latest 15 | container: condaforge/mambaforge:latest 16 | 17 | steps: 18 | - uses: actions/checkout@v3.3.0 19 | 20 | - name: Setup Conda Environment 21 | run: | 22 | apt update && apt install -y git make 23 | 24 | make env 25 | make lint 26 | make test 27 | make test-examples 28 | make docs 29 | 30 | # TODO: Remove this once pydantic 1.0 support is dropped 31 | # We remove absolv as femto needs pydantic >=2 32 | mamba remove --name smee --yes "absolv" 33 | mamba install --name smee --yes "pydantic <2" 34 | make test 35 | 36 | - name: CodeCov 37 | uses: codecov/codecov-action@v4.1.1 38 | with: 39 | file: ./coverage.xml 40 | flags: unittests 41 | token: ${{ secrets.CODECOV_TOKEN }} 42 | -------------------------------------------------------------------------------- /.github/workflows/docs.yaml: -------------------------------------------------------------------------------- 1 | name: Publish Documentation 2 | 3 | on: 4 | push: 5 | branches: ["main"] 6 | tags: ["*"] 7 | 8 | jobs: 9 | deploy-docs: 10 | 11 | runs-on: ubuntu-latest 12 | container: condaforge/mambaforge:latest 13 | 14 | steps: 15 | - name: Prepare container 16 | run: | 17 | apt update && apt install -y git make 18 | 19 | - name: Checkout 20 | uses: actions/checkout@v3.3.0 21 | 22 | - name: Determine Version 23 | shell: bash 24 | run: | 25 | if [ "$GITHUB_REF" = "refs/heads/main" ]; then 26 | echo "VERSION=latest" >> $GITHUB_ENV 27 | elif [ "${GITHUB_REF#refs/tags/}" != "$GITHUB_REF" ]; then 28 | VERSION=$(echo $GITHUB_REF | sed 's/refs\/tags\///') 
29 | echo "VERSION=$VERSION stable" >> $GITHUB_ENV 30 | else 31 | echo "Invalid ref: $GITHUB_REF" 32 | exit 1 33 | fi 34 | 35 | - name: Build and Deploy Documentation 36 | run: | 37 | git config --global user.name 'GitHub Actions' 38 | git config --global user.email 'actions@github.com' 39 | git config --global --add safe.directory "$PWD" 40 | git remote set-url origin https://x-access-token:${{ secrets.GITHUB_TOKEN }}@github.com/${{ github.repository }} 41 | 42 | git fetch --all --prune 43 | git pull origin gh-pages --allow-unrelated-histories 44 | 45 | make env 46 | make docs-deploy VERSION="$VERSION" 47 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | .pytest_cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # Jupyter Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # SageMath parsed files 81 | *.sage.py 82 | 83 | # dotenv 84 | .env 85 | 86 | # virtualenv 87 | .venv 88 | venv/ 89 | ENV/ 90 | 91 | # Spyder project settings 92 | .spyderproject 93 | .spyproject 94 | 95 | # Rope project settings 96 | .ropeproject 97 | 98 | # mkdocs documentation 99 | /site 100 | 101 | # mypy 102 | .mypy_cache/ 103 | 104 | # profraw files from LLVM? Unclear exactly what triggers this 105 | # There are reports this comes from LLVM profiling, but also Xcode 9. 
106 | *profraw 107 | 108 | # PyCharm 109 | .idea 110 | 111 | # OSX 112 | .DS_Store 113 | 114 | # Local development 115 | scratch 116 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v3.2.0 4 | hooks: 5 | - id: trailing-whitespace 6 | - id: end-of-file-fixer 7 | 8 | - repo: local 9 | hooks: 10 | - id: ruff 11 | name: Autoformat python code 12 | language: system 13 | entry: ruff 14 | args: [check] 15 | files: \.py$ 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | MIT License 3 | 4 | Copyright (c) 2021 Simon Boothroyd 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | PACKAGE_NAME := smee 2 | CONDA_ENV_RUN := conda run --no-capture-output --name $(PACKAGE_NAME) 3 | 4 | EXAMPLES_SKIP := examples/md-simulations.ipynb 5 | EXAMPLES := $(filter-out $(EXAMPLES_SKIP), $(wildcard examples/*.ipynb)) 6 | 7 | .PHONY: env lint format test test-examples docs docs-deploy 8 | 9 | env: 10 | mamba create --name $(PACKAGE_NAME) 11 | mamba env update --name $(PACKAGE_NAME) --file devtools/envs/base.yaml 12 | $(CONDA_ENV_RUN) pip install --no-deps -e . 13 | $(CONDA_ENV_RUN) pre-commit install || true 14 | 15 | lint: 16 | $(CONDA_ENV_RUN) ruff check $(PACKAGE_NAME) 17 | $(CONDA_ENV_RUN) ruff check examples 18 | 19 | format: 20 | $(CONDA_ENV_RUN) ruff format $(PACKAGE_NAME) 21 | $(CONDA_ENV_RUN) ruff check --fix --select I $(PACKAGE_NAME) 22 | $(CONDA_ENV_RUN) ruff format examples 23 | $(CONDA_ENV_RUN) ruff check --fix --select I examples 24 | 25 | test: 26 | $(CONDA_ENV_RUN) pytest -v --cov=$(PACKAGE_NAME) --cov-append --cov-report=xml --color=yes $(PACKAGE_NAME)/tests/ 27 | 28 | test-examples: 29 | $(CONDA_ENV_RUN) jupyter nbconvert --to notebook --execute $(EXAMPLES) 30 | 31 | docs: 32 | $(CONDA_ENV_RUN) mkdocs build 33 | 34 | docs-deploy: 35 | ifndef VERSION 36 | $(error VERSION is not set) 37 | endif 38 | $(CONDA_ENV_RUN) mike deploy --push --update-aliases $(VERSION) 39 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

SMIRNOFF Energy Evaluations

2 | 3 |

Differentiably evaluate energies of molecules using SMIRNOFF force fields

4 | 5 |

6 | 7 | ci 8 | 9 | 10 | coverage 11 | 12 | 13 | license 14 | 15 |

16 | 17 | --- 18 | 19 | The `smee` framework aims to offer a simple API for differentiably evaluating the energy of [SMIRNOFF](https://openforcefield.github.io/standards/standards/smirnoff/) 20 | force fields applied to molecules using `pytorch`. 21 | 22 | The package currently supports evaluating the energy of force fields that contain: 23 | 24 | * `Bonds`, `Angles`, `ProperTorsions` and `ImproperTorsions` 25 | * `vdW`, `Electrostatics`, `ToolkitAM1BCC`, `LibraryCharges` 26 | * `VirtualSites` 27 | 28 | parameter handlers in addition to limited support for registering custom handlers. 29 | 30 | It further supports a number of functional forms included in `smirnoff-plugins`, namely: 31 | 32 | * `DoubleExponential` 33 | 34 | ## Installation 35 | 36 | This package can be installed using `conda` (or `mamba`, a faster version of `conda`): 37 | 38 | ```shell 39 | mamba install -c conda-forge smee 40 | ``` 41 | 42 | The example notebooks further require you to install `jupyter`, `nglview`, and `smirnoff-plugins`: 43 | 44 | ```shell 45 | mamba install -c conda-forge jupyter nglview "smirnoff-plugins >=0.0.4" 46 | ``` 47 | 48 | ## Getting Started 49 | 50 | To get started, see the [examples](examples). 
51 | 52 | ## Copyright 53 | 54 | Copyright (c) 2023, Simon Boothroyd 55 | -------------------------------------------------------------------------------- /devtools/envs/base.yaml: -------------------------------------------------------------------------------- 1 | name: smee 2 | 3 | channels: 4 | - conda-forge 5 | 6 | dependencies: 7 | 8 | - python >=3.10 9 | - pip 10 | 11 | # Core packages 12 | - openff-units 13 | - openff-toolkit-base >=0.9.2 14 | - openff-interchange-base >=0.3.17 15 | 16 | - pytorch 17 | - nnpops 18 | 19 | - pydantic 20 | - pydantic-units 21 | 22 | - networkx 23 | 24 | # Optional packages 25 | 26 | ### MM simulations 27 | - openmm 28 | - python-symengine 29 | - rdkit 30 | - packmol 31 | - numpy 32 | - msgpack-python 33 | 34 | # FE simulations 35 | - mdtraj 36 | - absolv >=1.0.1 37 | 38 | # Examples 39 | - jupyter 40 | - nbconvert 41 | - nglview 42 | - smirnoff-plugins >=0.0.4 43 | 44 | # Dev / Testing 45 | - ambertools 46 | - scipy # test logsumexp implementation 47 | - smirnoff-plugins 48 | 49 | - setuptools_scm >=8 50 | 51 | - pre-commit 52 | - ruff 53 | - nbqa 54 | 55 | - pytest 56 | - pytest-cov 57 | - pytest-mock 58 | 59 | - codecov 60 | 61 | # Docs 62 | - mkdocs 63 | - mkdocs-material 64 | - mkdocs-gen-files 65 | - mkdocs-literate-nav 66 | - mkdocs-jupyter 67 | - mkdocstrings 68 | - mkdocstrings-python >=1.10.8 69 | - griffe-pydantic 70 | - black 71 | - mike 72 | -------------------------------------------------------------------------------- /docs/development.md: -------------------------------------------------------------------------------- 1 | # Development 2 | 3 | To create a development environment, you must have [`mamba` installed](https://mamba.readthedocs.io/en/latest/installation/mamba-installation.html). 
4 | 5 | A development conda environment can be created and activated with: 6 | 7 | ```shell 8 | make env 9 | conda activate smee 10 | ``` 11 | 12 | To format the codebase: 13 | 14 | ```shell 15 | make format 16 | ``` 17 | 18 | To run the unit tests: 19 | 20 | ```shell 21 | make test 22 | ``` 23 | 24 | To serve the documentation locally: 25 | 26 | ```shell 27 | mkdocs serve 28 | ``` 29 | -------------------------------------------------------------------------------- /docs/examples: -------------------------------------------------------------------------------- 1 | ../examples -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | --8<-- "README.md" 2 | -------------------------------------------------------------------------------- /docs/javascripts/mathjax.js: -------------------------------------------------------------------------------- 1 | window.MathJax = { 2 | tex: { 3 | inlineMath: [["\\(", "\\)"]], 4 | displayMath: [["\\[", "\\]"]], 5 | processEscapes: true, 6 | processEnvironments: true 7 | }, 8 | options: { 9 | ignoreHtmlClass: ".*|", 10 | processHtmlClass: "arithmatex" 11 | } 12 | }; 13 | 14 | document$.subscribe(() => { 15 | MathJax.typesetPromise() 16 | }) 17 | -------------------------------------------------------------------------------- /docs/scripts/gen_ref_pages.py: -------------------------------------------------------------------------------- 1 | """Generate the code reference pages and navigation.""" 2 | 3 | import pathlib 4 | 5 | import mkdocs_gen_files 6 | 7 | nav = mkdocs_gen_files.Nav() 8 | src = pathlib.Path(__file__).parent.parent.parent / "smee" 9 | 10 | mod_symbol = '' 11 | 12 | for path in sorted(src.rglob("*.py")): 13 | if "tests" in str(path): 14 | continue 15 | 16 | module_path = path.relative_to(src.parent).with_suffix("") 17 | doc_path = path.relative_to(src).with_suffix(".md") 18 | full_doc_path = 
pathlib.Path("reference", doc_path) 19 | 20 | parts = tuple(module_path.parts) 21 | 22 | if parts[-1] == "__init__": 23 | parts = parts[:-1] 24 | doc_path = doc_path.with_name("index.md") 25 | full_doc_path = full_doc_path.with_name("index.md") 26 | elif parts[-1].startswith("_"): 27 | continue 28 | 29 | nav_parts = [f"{mod_symbol} {part}" for part in parts] 30 | nav[tuple(nav_parts)] = doc_path.as_posix() 31 | 32 | with mkdocs_gen_files.open(full_doc_path, "w") as fd: 33 | ident = ".".join(parts) 34 | fd.write(f"::: {ident}") 35 | 36 | mkdocs_gen_files.set_edit_path(full_doc_path, ".." / path) 37 | 38 | with mkdocs_gen_files.open("reference/SUMMARY.md", "w") as nav_file: 39 | nav_file.writelines(nav.build_literate_nav()) 40 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | This directory contains a number of examples of how to use `smee`. They currently include: 4 | 5 | * [Evaluating the energy of a water dimer with virtual sites](compute-energy.ipynb) 6 | * [Minimizing the conformer of a molecule](conformer-minimization.ipynb) 7 | * [Computing the gradient of the energy w.r.t. 
force field parameters](parameter-gradients.ipynb) 8 | * [Registering custom parameter handlers](custom-handler.ipynb) 9 | * [Differentiably compute ensemble averages from MD simulations](md-simulations.ipynb) 10 | -------------------------------------------------------------------------------- /examples/conformer-minimization.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "261b79c7042b8a6f", 6 | "metadata": { 7 | "collapsed": false 8 | }, 9 | "source": [ 10 | "# Conformer Minimization\n", 11 | "\n", 12 | "This example will show how to optimize a conformer of paracetamol.\n", 13 | "\n", 14 | "Load in a paracetamol molecule, generate a conformer for it, and perturb the conformer to ensure it needs minimization." 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "id": "b081ee3aecf864ac", 21 | "metadata": { 22 | "collapsed": false, 23 | "ExecuteTime": { 24 | "end_time": "2023-10-17T21:18:13.134692Z", 25 | "start_time": "2023-10-17T21:18:10.562001Z" 26 | } 27 | }, 28 | "outputs": [], 29 | "source": [ 30 | "import openff.toolkit\n", 31 | "import openff.units\n", 32 | "import torch\n", 33 | "\n", 34 | "molecule = openff.toolkit.Molecule.from_smiles(\"CC(=O)NC1=CC=C(C=C1)O\")\n", 35 | "molecule.generate_conformers(n_conformers=1)\n", 36 | "\n", 37 | "conformer = torch.tensor(molecule.conformers[0].m_as(openff.units.unit.angstrom)) * 1.10\n", 38 | "conformer.requires_grad = True" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "id": "f4168aec7a72494c", 44 | "metadata": { 45 | "collapsed": false 46 | }, 47 | "source": [ 48 | "We specify that the gradient of the conformer is required so that we can optimize it using PyTorch.\n", 49 | "\n", 50 | "Parameterize the molecule using OpenFF Interchange and convert it into a PyTorch tensor representation." 
51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 2, 56 | "id": "8d00fd2dcf4c27cf", 57 | "metadata": { 58 | "collapsed": false, 59 | "ExecuteTime": { 60 | "end_time": "2023-10-17T21:18:16.758187Z", 61 | "start_time": "2023-10-17T21:18:13.138018Z" 62 | } 63 | }, 64 | "outputs": [ 65 | { 66 | "data": { 67 | "text/plain": "", 68 | "application/vnd.jupyter.widget-view+json": { 69 | "version_major": 2, 70 | "version_minor": 0, 71 | "model_id": "d8c8c3f62d1448a4b07498d18cf6dc5f" 72 | } 73 | }, 74 | "metadata": {}, 75 | "output_type": "display_data" 76 | } 77 | ], 78 | "source": [ 79 | "import openff.interchange\n", 80 | "\n", 81 | "interchange = openff.interchange.Interchange.from_smirnoff(\n", 82 | " openff.toolkit.ForceField(\"openff-2.1.0.offxml\"),\n", 83 | " molecule.to_topology(),\n", 84 | ")\n", 85 | "\n", 86 | "import smee.converters\n", 87 | "\n", 88 | "force_field, [topology] = smee.converters.convert_interchange(interchange)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "id": "792cb057cb419fa8", 94 | "metadata": { 95 | "collapsed": false 96 | }, 97 | "source": [ 98 | "We can minimize the conformer using any of PyTorch's optimizers. 
" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 3, 104 | "id": "facd656a27cf46a8", 105 | "metadata": { 106 | "collapsed": false, 107 | "ExecuteTime": { 108 | "end_time": "2023-10-17T21:18:17.036136Z", 109 | "start_time": "2023-10-17T21:18:16.761394Z" 110 | } 111 | }, 112 | "outputs": [ 113 | { 114 | "name": "stdout", 115 | "output_type": "stream", 116 | "text": [ 117 | "Epoch 0: E=102.10968017578125 kcal / mol\n", 118 | "Epoch 5: E=7.088213920593262 kcal / mol\n", 119 | "Epoch 10: E=-18.331130981445312 kcal / mol\n", 120 | "Epoch 15: E=-22.182296752929688 kcal / mol\n", 121 | "Epoch 20: E=-30.369152069091797 kcal / mol\n", 122 | "Epoch 25: E=-36.81045150756836 kcal / mol\n", 123 | "Epoch 30: E=-38.517852783203125 kcal / mol\n", 124 | "Epoch 35: E=-40.50505828857422 kcal / mol\n", 125 | "Epoch 40: E=-42.08476257324219 kcal / mol\n", 126 | "Epoch 45: E=-42.19199752807617 kcal / mol\n", 127 | "Epoch 50: E=-42.37827682495117 kcal / mol\n", 128 | "Epoch 55: E=-42.6767692565918 kcal / mol\n", 129 | "Epoch 60: E=-42.799903869628906 kcal / mol\n", 130 | "Epoch 65: E=-42.94251251220703 kcal / mol\n", 131 | "Epoch 70: E=-43.037200927734375 kcal / mol\n", 132 | "Epoch 74: E=-43.084136962890625 kcal / mol\n" 133 | ] 134 | } 135 | ], 136 | "source": [ 137 | "import smee\n", 138 | "\n", 139 | "optimizer = torch.optim.Adam([conformer], lr=0.02)\n", 140 | "\n", 141 | "for epoch in range(75):\n", 142 | " energy = smee.compute_energy(topology, force_field, conformer)\n", 143 | " energy.backward()\n", 144 | "\n", 145 | " optimizer.step()\n", 146 | " optimizer.zero_grad()\n", 147 | "\n", 148 | " if epoch % 5 == 0 or epoch == 74:\n", 149 | " print(f\"Epoch {epoch}: E={energy.item()} kcal / mol\")" 150 | ] 151 | }, 152 | { 153 | "cell_type": "markdown", 154 | "id": "360d6eb9cf2b6cc4", 155 | "metadata": { 156 | "collapsed": false 157 | }, 158 | "source": [ 159 | "We can then re-store the optimized conformer back into the molecule. 
Here we add the conformer to the molecule's conformer list, but we could also replace the original conformer." 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": 4, 165 | "id": "eaec04c4039ca59b", 166 | "metadata": { 167 | "collapsed": false, 168 | "ExecuteTime": { 169 | "end_time": "2023-10-17T21:18:17.052947Z", 170 | "start_time": "2023-10-17T21:18:17.036498Z" 171 | } 172 | }, 173 | "outputs": [ 174 | { 175 | "data": { 176 | "text/plain": "NGLWidget(max_frame=1)", 177 | "application/vnd.jupyter.widget-view+json": { 178 | "version_major": 2, 179 | "version_minor": 0, 180 | "model_id": "449fcae6d9eb4e5a8a3d765f0608e399" 181 | } 182 | }, 183 | "metadata": {}, 184 | "output_type": "display_data" 185 | } 186 | ], 187 | "source": [ 188 | "molecule.add_conformer(conformer.detach().numpy() * openff.units.unit.angstrom)\n", 189 | "molecule.visualize(backend=\"nglview\")" 190 | ] 191 | } 192 | ], 193 | "metadata": { 194 | "kernelspec": { 195 | "display_name": "Python 3 (ipykernel)", 196 | "language": "python", 197 | "name": "python3" 198 | }, 199 | "language_info": { 200 | "codemirror_mode": { 201 | "name": "ipython", 202 | "version": 3 203 | }, 204 | "file_extension": ".py", 205 | "mimetype": "text/x-python", 206 | "name": "python", 207 | "nbconvert_exporter": "python", 208 | "pygments_lexer": "ipython3", 209 | "version": "3.11.5" 210 | } 211 | }, 212 | "nbformat": 4, 213 | "nbformat_minor": 5 214 | } 215 | -------------------------------------------------------------------------------- /examples/md-simulations.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "aaf952cb344ca32d", 6 | "metadata": { 7 | "collapsed": false 8 | }, 9 | "source": [ 10 | "# Ensemble Averages from MD Simulations\n", 11 | "\n", 12 | "This example shows how ensemble averages can be computed from MD simulations, such that their gradient with respect to force field 
parameters can be computed through backpropagation.\n", 13 | "\n", 14 | "We start by parameterizing the set of molecules that will appear in our simulation boxes: " 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "id": "e178932166969df5", 21 | "metadata": { 22 | "collapsed": false 23 | }, 24 | "outputs": [], 25 | "source": [ 26 | "import openff.interchange\n", 27 | "import openff.toolkit\n", 28 | "\n", 29 | "import smee.converters\n", 30 | "\n", 31 | "interchanges = [\n", 32 | " openff.interchange.Interchange.from_smirnoff(\n", 33 | " openff.toolkit.ForceField(\"openff-2.0.0.offxml\"),\n", 34 | " openff.toolkit.Molecule.from_smiles(smiles).to_topology(),\n", 35 | " )\n", 36 | " for smiles in (\"CCO\", \"CO\")\n", 37 | "]\n", 38 | "\n", 39 | "tensor_ff, topologies = smee.converters.convert_interchange(interchanges)\n", 40 | "\n", 41 | "# move the force field to the GPU for faster processing of the simulation\n", 42 | "# trajectories - the system and force field must be on the same device.\n", 43 | "tensor_ff = tensor_ff.to(\"cuda\")" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "id": "415cbefaa3d60c49", 49 | "metadata": { 50 | "collapsed": false 51 | }, 52 | "source": [ 53 | "We will also flag that the vdW parameter gradients are required:" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "id": "38700a4d251b1ab9", 60 | "metadata": { 61 | "collapsed": false 62 | }, 63 | "outputs": [], 64 | "source": [ 65 | "vdw_potential = tensor_ff.potentials_by_type[\"vdW\"]\n", 66 | "vdw_potential.parameters.requires_grad = True" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "id": "824dfd3f7f6916b3", 72 | "metadata": { 73 | "collapsed": false 74 | }, 75 | "source": [ 76 | "We then define the full simulation boxes that we wish to simulate:" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "id": "58a584bf7997e194", 83 | "metadata": { 84 | "collapsed": false 
85 | }, 86 | "outputs": [], 87 | "source": [ 88 | "import smee\n", 89 | "\n", 90 | "# define a periodic box containing 216 ethanol molecules\n", 91 | "system_ethanol = smee.TensorSystem([topologies[0]], [216], is_periodic=True)\n", 92 | "system_ethanol = system_ethanol.to(\"cuda\")\n", 93 | "# define a periodic box containing 216 methanol molecules\n", 94 | "system_methanol = smee.TensorSystem([topologies[1]], [216], True)\n", 95 | "system_methanol = system_methanol.to(\"cuda\")\n", 96 | "# define a periodic box containing 128 ethanol molecules and 128 methanol molecules\n", 97 | "system_mixture = smee.TensorSystem(topologies, [128, 128], True)\n", 98 | "system_mixture = system_mixture.to(\"cuda\")" 99 | ] 100 | }, 101 | { 102 | "cell_type": "markdown", 103 | "id": "93a519affcb22db4", 104 | "metadata": { 105 | "collapsed": false 106 | }, 107 | "source": [ 108 | "A tensor system is simply a wrapper around a set of topology objects (which define the parameters applied to individual molecules) together with the number of copies of each topology that should be present, similar to a GROMACS topology. The `is_periodic` flag indicates whether the system should be simulated in a periodic box.\n", 109 | "\n", 110 | "Here we have also moved the systems onto the GPU. This is not required, but allows ensemble averages to be computed from the trajectories much more rapidly.\n", 111 | "\n", 112 | "We must then also define the simulation protocol that will be used to run the simulations. 
This consists of a config object that defines how to generate the system coordinates using PACKMOL, the set of energy minimisations /simulations to run as equilibration, and finally the configuration of the production simulation:" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "id": "ccba93245cf83ff7", 119 | "metadata": { 120 | "collapsed": false 121 | }, 122 | "outputs": [], 123 | "source": [ 124 | "import tempfile\n", 125 | "\n", 126 | "import openmm.unit\n", 127 | "\n", 128 | "import smee.mm\n", 129 | "\n", 130 | "temperature = 298.15 * openmm.unit.kelvin\n", 131 | "pressure = 1.0 * openmm.unit.atmosphere\n", 132 | "\n", 133 | "beta = 1.0 / (openmm.unit.MOLAR_GAS_CONSTANT_R * temperature)\n", 134 | "\n", 135 | "# we can run an arbitrary number of equilibration simulations / minimizations.\n", 136 | "# all generated data will be discarded, but the final coordinates will be used\n", 137 | "# to initialize the production simulation\n", 138 | "equilibrate_config = [\n", 139 | " smee.mm.MinimizationConfig(),\n", 140 | " # short NVT equilibration simulation\n", 141 | " smee.mm.SimulationConfig(\n", 142 | " temperature=temperature,\n", 143 | " pressure=None,\n", 144 | " n_steps=50000,\n", 145 | " timestep=1.0 * openmm.unit.femtosecond,\n", 146 | " ),\n", 147 | " # short NPT equilibration simulation\n", 148 | " smee.mm.SimulationConfig(\n", 149 | " temperature=temperature,\n", 150 | " pressure=pressure,\n", 151 | " n_steps=50000,\n", 152 | " timestep=1.0 * openmm.unit.femtosecond,\n", 153 | " ),\n", 154 | "]\n", 155 | "# long NPT production simulation\n", 156 | "production_config = smee.mm.SimulationConfig(\n", 157 | " temperature=temperature,\n", 158 | " pressure=pressure,\n", 159 | " n_steps=500000,\n", 160 | " timestep=2.0 * openmm.unit.femtosecond,\n", 161 | ")" 162 | ] 163 | }, 164 | { 165 | "cell_type": "markdown", 166 | "id": "9db4088f839c4265", 167 | "metadata": { 168 | "collapsed": false 169 | }, 170 | "source": [ 171 | 
"We will further define a convenience function that will first simulate the system of interest (storing the trajectory in a temporary directory), and then compute ensemble averages over that trajectory:" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": null, 177 | "id": "d485cce5c1a0fce3", 178 | "metadata": { 179 | "collapsed": false 180 | }, 181 | "outputs": [], 182 | "source": [ 183 | "import pathlib\n", 184 | "\n", 185 | "import torch\n", 186 | "\n", 187 | "\n", 188 | "def compute_ensemble_averages(\n", 189 | " system: smee.TensorSystem, force_field: smee.TensorForceField\n", 190 | ") -> dict[str, torch.Tensor]:\n", 191 | " # computing the ensemble averages is a two step process - we first need to run\n", 192 | " # an MD simulation using the force field making sure to store the coordinates,\n", 193 | " # box vectors and kinetic energies\n", 194 | " coords, box_vectors = smee.mm.generate_system_coords(system, force_field)\n", 195 | "\n", 196 | " interval = 1000\n", 197 | "\n", 198 | " # save the simulation output every 1000th frame (2 ps) to a temporary file.\n", 199 | " # we could also save the trajectory more permanently, but as we do nothing\n", 200 | " # with it after computing the averages in this example, we simply want to\n", 201 | " # discard it.\n", 202 | " with (\n", 203 | " tempfile.NamedTemporaryFile() as tmp_file,\n", 204 | " smee.mm.tensor_reporter(tmp_file.name, interval, beta, pressure) as reporter,\n", 205 | " ):\n", 206 | " smee.mm.simulate(\n", 207 | " system,\n", 208 | " force_field,\n", 209 | " coords,\n", 210 | " box_vectors,\n", 211 | " equilibrate_config,\n", 212 | " production_config,\n", 213 | " [reporter],\n", 214 | " )\n", 215 | "\n", 216 | " # we can then compute the ensemble averages from the trajectory. 
generating\n", 217 | " # the trajectory separately from computing the ensemble averages allows us\n", 218 | " # to run the simulation in parallel with other simulations more easily, without\n", 219 | " # having to worry about copying gradients between workers / processes.\n", 220 | " avgs, stds = smee.mm.compute_ensemble_averages(\n", 221 | " system, force_field, pathlib.Path(tmp_file.name), temperature, pressure\n", 222 | " )\n", 223 | " return avgs" 224 | ] 225 | }, 226 | { 227 | "cell_type": "markdown", 228 | "id": "8d1fdb6973324d14", 229 | "metadata": { 230 | "collapsed": false 231 | }, 232 | "source": [ 233 | "Computing the ensemble averages is then as simple as:" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": null, 239 | "id": "3156bcfc509380f7", 240 | "metadata": { 241 | "collapsed": false 242 | }, 243 | "outputs": [], 244 | "source": [ 245 | "# run simulations of each system and compute ensemble averages over the trajectories\n", 246 | "# of the potential energy, volume, and density\n", 247 | "ethanol_avgs = compute_ensemble_averages(system_ethanol, tensor_ff)\n", 248 | "methanol_avgs = compute_ensemble_averages(system_methanol, tensor_ff)\n", 249 | "mixture_avgs = compute_ensemble_averages(system_mixture, tensor_ff)" 250 | ] 251 | }, 252 | { 253 | "cell_type": "markdown", 254 | "id": "bcce5c83a564c59f", 255 | "metadata": { 256 | "collapsed": false 257 | }, 258 | "source": [ 259 | "Each of the returned values is a dictionary of ensemble averages computed over the simulated production trajectory. 
This currently includes the potential energy, volume, density, and enthalpy of the system.\n", 260 | "\n", 261 | "These averages can be used in a loss function" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": null, 267 | "id": "38b9a27d7cd06c1a", 268 | "metadata": { 269 | "collapsed": false 270 | }, 271 | "outputs": [], 272 | "source": [ 273 | "# define some MOCK data and loss function\n", 274 | "mock_ethanol_density = 0.789 # g/mL\n", 275 | "mock_methanol_density = 0.791 # g/mL\n", 276 | "\n", 277 | "mock_enthalpy_of_mixing = 0.891 # kcal/mol\n", 278 | "\n", 279 | "loss = (ethanol_avgs[\"density\"] - mock_ethanol_density) ** 2\n", 280 | "loss += (methanol_avgs[\"density\"] - mock_methanol_density) ** 2\n", 281 | "\n", 282 | "mixture_enthalpy = mixture_avgs[\"enthalpy\"] / 256\n", 283 | "\n", 284 | "ethanol_enthalpy = ethanol_avgs[\"enthalpy\"] / 128\n", 285 | "methanol_enthalpy = methanol_avgs[\"enthalpy\"] / 128\n", 286 | "\n", 287 | "enthalpy_of_mixing = mixture_enthalpy - (\n", 288 | " 0.5 * ethanol_enthalpy + 0.5 * methanol_enthalpy\n", 289 | ")\n", 290 | "loss += (enthalpy_of_mixing - mock_enthalpy_of_mixing) ** 2" 291 | ] 292 | }, 293 | { 294 | "cell_type": "markdown", 295 | "id": "9bd6779012316898", 296 | "metadata": { 297 | "collapsed": false 298 | }, 299 | "source": [ 300 | "and the gradient of this loss function with respect to the force field parameters can be computed through backpropagation:" 301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": null, 306 | "id": "dd3ccdfe61a0cd09", 307 | "metadata": { 308 | "collapsed": false 309 | }, 310 | "outputs": [], 311 | "source": [ 312 | "loss.backward()\n", 313 | "\n", 314 | "epsilon_col = vdw_potential.parameter_cols.index(\"epsilon\")\n", 315 | "sigma_col = vdw_potential.parameter_cols.index(\"sigma\")\n", 316 | "\n", 317 | "print(\"VdW Ɛ Gradients\", vdw_potential.parameters.grad[:, epsilon_col])\n", 318 | "print(\"VdW σ Gradients\",
vdw_potential.parameters.grad[:, sigma_col])" 319 | ] 320 | } 321 | ], 322 | "metadata": { 323 | "kernelspec": { 324 | "display_name": "Python 3 (ipykernel)", 325 | "language": "python", 326 | "name": "python3" 327 | }, 328 | "language_info": { 329 | "codemirror_mode": { 330 | "name": "ipython", 331 | "version": 3 332 | }, 333 | "file_extension": ".py", 334 | "mimetype": "text/x-python", 335 | "name": "python", 336 | "nbconvert_exporter": "python", 337 | "pygments_lexer": "ipython3", 338 | "version": "3.11.6" 339 | } 340 | }, 341 | "nbformat": 4, 342 | "nbformat_minor": 5 343 | } 344 | -------------------------------------------------------------------------------- /examples/parameter-gradients.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "source": [ 6 | "# Parameter Gradients\n", 7 | "\n", 8 | "This example will show how the gradient of the potential energy with respect to force field parameters may be computed.\n", 9 | "\n", 10 | "We start by loading and parameterizing the molecule of interest."
11 | ], 12 | "metadata": { 13 | "collapsed": false 14 | }, 15 | "id": "581674729871c21a" 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "outputs": [ 21 | { 22 | "data": { 23 | "text/plain": "", 24 | "application/vnd.jupyter.widget-view+json": { 25 | "version_major": 2, 26 | "version_minor": 0, 27 | "model_id": "9f2762d7fdec4691a8fba35cac8e7f98" 28 | } 29 | }, 30 | "metadata": {}, 31 | "output_type": "display_data" 32 | } 33 | ], 34 | "source": [ 35 | "import openff.interchange\n", 36 | "import openff.toolkit\n", 37 | "import openff.units\n", 38 | "import torch\n", 39 | "\n", 40 | "import smee.converters\n", 41 | "\n", 42 | "molecule = openff.toolkit.Molecule.from_smiles(\"CC(=O)NC1=CC=C(C=C1)O\")\n", 43 | "molecule.generate_conformers(n_conformers=1)\n", 44 | "\n", 45 | "conformer = torch.tensor(molecule.conformers[0].m_as(openff.units.unit.angstrom))\n", 46 | "\n", 47 | "interchange = openff.interchange.Interchange.from_smirnoff(\n", 48 | " openff.toolkit.ForceField(\"openff_unconstrained-2.0.0.offxml\"),\n", 49 | " molecule.to_topology(),\n", 50 | ")\n", 51 | "tensor_ff, [tensor_topology] = smee.converters.convert_interchange(interchange)" 52 | ], 53 | "metadata": { 54 | "collapsed": false, 55 | "ExecuteTime": { 56 | "end_time": "2023-10-17T21:17:56.406479Z", 57 | "start_time": "2023-10-17T21:17:50.319630Z" 58 | } 59 | }, 60 | "id": "67b29771d6b77bc" 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "source": [ 65 | "We can access the parameters for each SMIRNOFF parameter 'handler' (e.g. vdW, bond, angle, etc.) using the `potentials_by_type` (or the `potentials`) attribute of the `TensorForceField` object." 
66 | ], 67 | "metadata": { 68 | "collapsed": false 69 | }, 70 | "id": "338809852105c2d8" 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 2, 75 | "outputs": [], 76 | "source": [ 77 | "vdw_potential = tensor_ff.potentials_by_type[\"vdW\"]\n", 78 | "vdw_potential.parameters.requires_grad = True" 79 | ], 80 | "metadata": { 81 | "collapsed": false, 82 | "ExecuteTime": { 83 | "end_time": "2023-10-17T21:17:56.416841Z", 84 | "start_time": "2023-10-17T21:17:56.409028Z" 85 | } 86 | }, 87 | "id": "4757416c61bdcf8d" 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "source": [ 92 | "The gradient of the potential energy with respect to the parameters can then be computed by backpropagating through the energy computation." 93 | ], 94 | "metadata": { 95 | "collapsed": false 96 | }, 97 | "id": "6e8d221ef30f9da6" 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 3, 102 | "outputs": [ 103 | { 104 | "name": "stdout", 105 | "output_type": "stream", 106 | "text": [ 107 | "[#6X4:1] - dU/depsilon = -1.033, dU/dsigma = -0.139\n", 108 | "[#6:1] - dU/depsilon = 87.490, dU/dsigma = 34.983\n", 109 | "[#8:1] - dU/depsilon = 15.846, dU/dsigma = 17.232\n", 110 | "[#7:1] - dU/depsilon = 0.148, dU/dsigma = 1.187\n", 111 | "[#8X2H1+0:1] - dU/depsilon = -0.305, dU/dsigma = 0.558\n", 112 | "[#1:1]-[#6X4] - dU/depsilon = 7.630, dU/dsigma = 1.404\n", 113 | "[#1:1]-[#7] - dU/depsilon = -2.894, dU/dsigma = -0.074\n", 114 | "[#1:1]-[#6X3] - dU/depsilon = 137.134, dU/dsigma = 12.129\n", 115 | "[#1:1]-[#8] - dU/depsilon = -22.417, dU/dsigma = -0.001\n" 116 | ] 117 | } 118 | ], 119 | "source": [ 120 | "import smee\n", 121 | "\n", 122 | "energy = smee.compute_energy(tensor_topology, tensor_ff, conformer)\n", 123 | "energy.backward()\n", 124 | "\n", 125 | "for parameter_key, gradient in zip(\n", 126 | " vdw_potential.parameter_keys, vdw_potential.parameters.grad.numpy(), strict=True\n", 127 | "):\n", 128 | " parameter_cols = vdw_potential.parameter_cols\n", 129 | "\n", 130 | " 
parameter_grads = \", \".join(\n", 131 | " f\"dU/d{parameter_col} = {parameter_grad: 8.3f}\"\n", 132 | " for parameter_col, parameter_grad in zip(parameter_cols, gradient, strict=True)\n", 133 | " )\n", 134 | " print(f\"{parameter_key.id.ljust(15)} - {parameter_grads}\")" 135 | ], 136 | "metadata": { 137 | "collapsed": false, 138 | "ExecuteTime": { 139 | "end_time": "2023-10-17T21:17:56.452896Z", 140 | "start_time": "2023-10-17T21:17:56.415831Z" 141 | } 142 | }, 143 | "id": "6df321d552be0aad" 144 | } 145 | ], 146 | "metadata": { 147 | "kernelspec": { 148 | "display_name": "Python 3", 149 | "language": "python", 150 | "name": "python3" 151 | }, 152 | "language_info": { 153 | "codemirror_mode": { 154 | "name": "ipython", 155 | "version": 2 156 | }, 157 | "file_extension": ".py", 158 | "mimetype": "text/x-python", 159 | "name": "python", 160 | "nbconvert_exporter": "python", 161 | "pygments_lexer": "ipython2", 162 | "version": "2.7.6" 163 | } 164 | }, 165 | "nbformat": 4, 166 | "nbformat_minor": 5 167 | } 168 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: "smee" 2 | site_description: "Differentiably evaluate energies of molecules." 
3 | site_url: "https://github.com/SimonBoothroyd/smee" 4 | repo_url: "https://github.com/SimonBoothroyd/smee" 5 | repo_name: "SimonBoothroyd/smee" 6 | site_dir: "site" 7 | watch: [mkdocs.yml, README.md, smee/, docs] 8 | copyright: Copyright © 2024 Simon Boothroyd 9 | edit_uri: edit/main/docs/ 10 | 11 | validation: 12 | omitted_files: warn 13 | absolute_links: warn 14 | unrecognized_links: warn 15 | 16 | extra: 17 | version: 18 | provider: mike 19 | 20 | nav: 21 | - Home: 22 | - Overview: index.md 23 | - Examples: 24 | - examples/README.md 25 | - examples/compute-energy.ipynb 26 | - examples/conformer-minimization.ipynb 27 | - examples/md-simulations.ipynb 28 | - examples/parameter-gradients.ipynb 29 | - API reference: reference/ 30 | - Development: development.md 31 | 32 | theme: 33 | name: material 34 | features: 35 | - announce.dismiss 36 | - content.code.annotate 37 | - content.code.copy 38 | - content.tooltips 39 | - navigation.footer 40 | - navigation.indexes 41 | - navigation.sections 42 | - navigation.tabs 43 | - navigation.tabs.sticky 44 | - navigation.top 45 | - search.highlight 46 | - search.suggest 47 | - toc.follow 48 | palette: 49 | - media: "(prefers-color-scheme: light)" 50 | scheme: default 51 | primary: teal 52 | accent: purple 53 | toggle: 54 | icon: material/weather-sunny 55 | name: Switch to dark mode 56 | - media: "(prefers-color-scheme: dark)" 57 | scheme: slate 58 | primary: black 59 | accent: lime 60 | toggle: 61 | icon: material/weather-night 62 | name: Switch to light mode 63 | 64 | markdown_extensions: 65 | - attr_list 66 | - md_in_html 67 | - def_list 68 | - admonition 69 | - footnotes 70 | - pymdownx.highlight: 71 | anchor_linenums: true 72 | line_spans: __span 73 | pygments_lang_class: true 74 | - pymdownx.inlinehilite 75 | - pymdownx.superfences 76 | - pymdownx.magiclink 77 | - pymdownx.snippets: 78 | check_paths: true 79 | - pymdownx.details 80 | - pymdownx.arithmatex: 81 | generic: true 82 | - pymdownx.tabbed: 83 | alternate_style: 
true 84 | - toc: 85 | permalink: "#" 86 | 87 | plugins: 88 | - autorefs 89 | - search 90 | - gen-files: 91 | scripts: 92 | - docs/scripts/gen_ref_pages.py 93 | - mkdocs-jupyter: 94 | include: [ "examples/*.ipynb" ] 95 | - literate-nav: 96 | nav_file: SUMMARY.md 97 | - mkdocstrings: 98 | handlers: 99 | python: 100 | paths: [smee/] 101 | import: 102 | - http://docs.openmm.org/latest/api-python/objects.inv 103 | options: 104 | extensions: [ griffe_pydantic ] 105 | docstring_options: 106 | ignore_init_summary: true 107 | returns_multiple_items: false 108 | returns_named_value: false 109 | docstring_section_style: list 110 | heading_level: 1 111 | inherited_members: true 112 | merge_init_into_class: true 113 | separate_signature: true 114 | show_root_heading: true 115 | show_root_full_path: false 116 | show_signature_annotations: true 117 | show_symbol_type_heading: true 118 | show_symbol_type_toc: true 119 | signature_crossrefs: true 120 | summary: true 121 | members_order: source 122 | 123 | extra_javascript: 124 | - javascripts/mathjax.js 125 | - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js 126 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0", "setuptools_scm>=8", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "smee" 7 | description = "Differentiably compute energies of molecules using SMIRNOFF force fields."
8 | authors = [ {name = "Simon Boothroyd"} ] 9 | license = { text = "MIT" } 10 | dynamic = ["version"] 11 | readme = "README.md" 12 | requires-python = ">=3.10" 13 | classifiers = ["Programming Language :: Python :: 3"] 14 | 15 | [tool.setuptools.packages.find] 16 | include = ["smee*"] 17 | 18 | [tool.setuptools_scm] 19 | 20 | [tool.ruff] 21 | extend-include = ["*.ipynb"] 22 | 23 | [tool.ruff.lint] 24 | ignore = ["C901","E402","E501"] 25 | select = ["B","C","E","F","W","B9"] 26 | 27 | [tool.ruff.lint.pydocstyle] 28 | convention = "google" 29 | 30 | [tool.coverage.run] 31 | omit = ["**/tests/*", "smee/mm/_fe.py"] 32 | 33 | [tool.coverage.report] 34 | exclude_lines = [ 35 | "@overload", 36 | "pragma: no cover", 37 | "raise NotImplementedError", 38 | "if __name__ == .__main__.:", 39 | "if TYPE_CHECKING:", 40 | "if typing.TYPE_CHECKING:", 41 | ] 42 | 43 | [tool.pytest.ini_options] 44 | markers = [ 45 | "fe: run free energy regression tests", 46 | ] 47 | addopts = "-m 'not fe'" 48 | -------------------------------------------------------------------------------- /smee/__init__.py: -------------------------------------------------------------------------------- 1 | """Differentiably evaluate energies of molecules using SMIRNOFF force fields""" 2 | 3 | import importlib.metadata 4 | 5 | from ._constants import CUTOFF_ATTRIBUTE, SWITCH_ATTRIBUTE, EnergyFn, PotentialType 6 | from ._models import ( 7 | NonbondedParameterMap, 8 | ParameterMap, 9 | TensorConstraints, 10 | TensorForceField, 11 | TensorPotential, 12 | TensorSystem, 13 | TensorTopology, 14 | TensorVSites, 15 | ValenceParameterMap, 16 | VSiteMap, 17 | ) 18 | from .geometry import add_v_site_coords, compute_v_site_coords 19 | from .potentials import compute_energy, compute_energy_potential 20 | 21 | try: 22 | __version__ = importlib.metadata.version("smee") 23 | except importlib.metadata.PackageNotFoundError: 24 | __version__ = "0+unknown" 25 | 26 | __all__ = [ 27 | "CUTOFF_ATTRIBUTE", 28 | "SWITCH_ATTRIBUTE", 29 |
"EnergyFn", 30 | "PotentialType", 31 | "ValenceParameterMap", 32 | "NonbondedParameterMap", 33 | "ParameterMap", 34 | "VSiteMap", 35 | "TensorConstraints", 36 | "TensorTopology", 37 | "TensorSystem", 38 | "TensorPotential", 39 | "TensorVSites", 40 | "TensorForceField", 41 | "__version__", 42 | "add_v_site_coords", 43 | "compute_v_site_coords", 44 | "compute_energy", 45 | "compute_energy_potential", 46 | ] 47 | -------------------------------------------------------------------------------- /smee/_constants.py: -------------------------------------------------------------------------------- 1 | """Constants used throughout the package.""" 2 | 3 | import enum 4 | 5 | if hasattr(enum, "StrEnum"): 6 | _StrEnum = enum.StrEnum 7 | else: 8 | import typing 9 | 10 | _S = typing.TypeVar("_S", bound="_StrEnum") 11 | 12 | class _StrEnum(str, enum.Enum): 13 | """TODO: remove when python 3.10 support is dropped.""" 14 | 15 | def __new__(cls: typing.Type[_S], *values: str) -> _S: 16 | value = str(*values) 17 | 18 | member = str.__new__(cls, value) 19 | member._value_ = value 20 | 21 | return member 22 | 23 | __str__ = str.__str__ 24 | 25 | 26 | class PotentialType(_StrEnum): 27 | """An enumeration of the potential types supported by ``smee`` out of the box.""" 28 | 29 | BONDS = "Bonds" 30 | ANGLES = "Angles" 31 | 32 | PROPER_TORSIONS = "ProperTorsions" 33 | IMPROPER_TORSIONS = "ImproperTorsions" 34 | 35 | VDW = "vdW" 36 | ELECTROSTATICS = "Electrostatics" 37 | 38 | 39 | class EnergyFn(_StrEnum): 40 | """An enumeration of the energy functions supported by ``smee`` out of the box.""" 41 | 42 | COULOMB = "coul" 43 | 44 | VDW_LJ = "4*epsilon*((sigma/r)**12-(sigma/r)**6)" 45 | VDW_DEXP = ( 46 | "epsilon*(" 47 | "beta/(alpha-beta)*exp(alpha*(1-r/r_min))-" 48 | "alpha/(alpha-beta)*exp(beta*(1-r/r_min)))" 49 | ) 50 | # VDW_BUCKINGHAM = "a*exp(-b*r)-c*r^-6" 51 | 52 | BOND_HARMONIC = "k/2*(r-length)**2" 53 | 54 | ANGLE_HARMONIC = "k/2*(theta-angle)**2" 55 | 56 | TORSION_COSINE =
"k*(1+cos(periodicity*theta-phase))" 57 | 58 | 59 | CUTOFF_ATTRIBUTE = "cutoff" 60 | """The attribute that should be used to store the cutoff distance of a potential.""" 61 | SWITCH_ATTRIBUTE = "switch_width" 62 | """The attribute that should be used to store the switch width of a potential, if the 63 | potential should use the standard OpenMM switch function. 64 | 65 | This attribute should be omitted if the potential should not use a switch function. 66 | """ 67 | -------------------------------------------------------------------------------- /smee/converters/__init__.py: -------------------------------------------------------------------------------- 1 | """Convert to / from ``smee`` tensor representations.""" 2 | 3 | from smee.converters.openff import ( 4 | convert_handlers, 5 | convert_interchange, 6 | smirnoff_parameter_converter, 7 | ) 8 | from smee.converters.openmm import ( 9 | convert_to_openmm_ffxml, 10 | convert_to_openmm_force, 11 | convert_to_openmm_system, 12 | convert_to_openmm_topology, 13 | ffxml_converter, 14 | ) 15 | 16 | __all__ = [ 17 | "convert_handlers", 18 | "convert_interchange", 19 | "convert_to_openmm_system", 20 | "convert_to_openmm_topology", 21 | "convert_to_openmm_ffxml", 22 | "convert_to_openmm_force", 23 | "ffxml_converter", 24 | "smirnoff_parameter_converter", 25 | ] 26 | -------------------------------------------------------------------------------- /smee/converters/openff/__init__.py: -------------------------------------------------------------------------------- 1 | """Tensor representations of SMIRNOFF force fields.""" 2 | 3 | from smee.converters.openff._openff import ( 4 | convert_handlers, 5 | convert_interchange, 6 | smirnoff_parameter_converter, 7 | ) 8 | 9 | __all__ = ["convert_handlers", "convert_interchange", "smirnoff_parameter_converter"] 10 | -------------------------------------------------------------------------------- /smee/converters/openff/valence.py:
-------------------------------------------------------------------------------- 1 | """Convert SMIRNOFF valence parameters into tensors.""" 2 | 3 | import openff.interchange.smirnoff 4 | import openff.units 5 | import torch 6 | 7 | import smee 8 | 9 | _UNITLESS = openff.units.unit.dimensionless 10 | _ANGSTROM = openff.units.unit.angstrom 11 | _RADIANS = openff.units.unit.radians 12 | _KCAL_PER_MOL = openff.units.unit.kilocalories / openff.units.unit.mole 13 | 14 | 15 | def strip_constrained_bonds( 16 | parameter_maps: list[smee.ValenceParameterMap], 17 | constraints: list[set[tuple[int, int]]], 18 | ): 19 | """Remove (in place) bonded interactions between distance-constrained atoms. 20 | 21 | Args: 22 | parameter_maps: The parameter maps to strip. 23 | constraints: The distance-constrained bonds to exclude for each parameter map. 24 | """ 25 | 26 | for parameter_map, bonds_to_exclude in zip( 27 | parameter_maps, constraints, strict=True 28 | ): 29 | bonds_to_exclude = {tuple(sorted(idxs)) for idxs in bonds_to_exclude} 30 | 31 | bond_idxs = [ 32 | tuple(sorted(idxs)) for idxs in parameter_map.particle_idxs.tolist() 33 | ] 34 | include = [idxs not in bonds_to_exclude for idxs in bond_idxs] 35 | 36 | parameter_map.particle_idxs = parameter_map.particle_idxs[include] 37 | parameter_map.assignment_matrix = parameter_map.assignment_matrix.to_dense()[ 38 | include, : 39 | ].to_sparse() 40 | 41 | 42 | def strip_constrained_angles( 43 | parameter_maps: list[smee.ValenceParameterMap], 44 | constraints: list[set[tuple[int, int]]], 45 | ): 46 | """Remove (in place) angle interactions where every pair of atoms in the angle 47 | is distance-constrained. 48 | 49 | Args: 50 | parameter_maps: The parameter maps to strip. 51 | constraints: The distance-constrained bonds to exclude for each parameter map.
52 | """ 53 | 54 | def is_constrained(idxs_, excluded): 55 | bonds = { 56 | tuple(sorted([idxs_[0], idxs_[1]])), 57 | tuple(sorted([idxs_[0], idxs_[2]])), 58 | tuple(sorted([idxs_[1], idxs_[2]])), 59 | } 60 | return len(bonds & excluded) == 3 61 | 62 | for parameter_map, bonds_to_exclude in zip( 63 | parameter_maps, constraints, strict=True 64 | ): 65 | bonds_to_exclude = {tuple(sorted(idxs)) for idxs in bonds_to_exclude} 66 | 67 | angle_idxs = parameter_map.particle_idxs.tolist() 68 | include = [not is_constrained(idxs, bonds_to_exclude) for idxs in angle_idxs] 69 | 70 | parameter_map.particle_idxs = parameter_map.particle_idxs[include] 71 | parameter_map.assignment_matrix = parameter_map.assignment_matrix.to_dense()[ 72 | include, : 73 | ].to_sparse() 74 | 75 | 76 | def convert_valence_handlers( 77 | handlers: list[openff.interchange.smirnoff.SMIRNOFFCollection], 78 | handler_type: str, 79 | parameter_cols: tuple[str, ...], 80 | ) -> tuple[smee.TensorPotential, list[smee.ValenceParameterMap]]: 81 | """Convert a list of SMIRNOFF valence handlers into a tensor potential and 82 | associated parameter maps. 83 | 84 | Notes: 85 | This function assumes that all parameters come from the same force field. 86 | 87 | Args: 88 | handlers: The list of SMIRNOFF valence handlers to convert. 89 | handler_type: The type of valence handler being converted. 90 | parameter_cols: The ordering of the parameter array columns. 91 | 92 | Returns: 93 | The potential containing tensors of the parameter values, and a list of 94 | parameter maps which map the parameters to the interactions they apply to.
95 | """ 96 | potential = smee.converters.openff._openff._handlers_to_potential( 97 | handlers, handler_type, parameter_cols, None 98 | ) 99 | 100 | parameter_key_to_idx = { 101 | parameter_key: i for i, parameter_key in enumerate(potential.parameter_keys) 102 | } 103 | parameter_maps = [] 104 | 105 | for handler in handlers: 106 | particle_idxs = [topology_key.atom_indices for topology_key in handler.key_map] 107 | 108 | assignment_matrix = torch.zeros( 109 | (len(particle_idxs), len(potential.parameters)), dtype=torch.float64 110 | ) 111 | 112 | for i, parameter_key in enumerate(handler.key_map.values()): 113 | assignment_matrix[i, parameter_key_to_idx[parameter_key]] += 1.0 114 | 115 | parameter_map = smee.ValenceParameterMap( 116 | torch.tensor(particle_idxs), assignment_matrix.to_sparse() 117 | ) 118 | parameter_maps.append(parameter_map) 119 | 120 | return potential, parameter_maps 121 | 122 | 123 | @smee.converters.smirnoff_parameter_converter( 124 | "Bonds", {"k": _KCAL_PER_MOL / _ANGSTROM**2, "length": _ANGSTROM} 125 | ) 126 | def convert_bonds( 127 | handlers: list[openff.interchange.smirnoff.SMIRNOFFBondCollection], 128 | constraints: list[set[tuple[int, int]]], 129 | ) -> tuple[smee.TensorPotential, list[smee.ValenceParameterMap]]: 130 | potential, parameter_maps = convert_valence_handlers( 131 | handlers, "Bonds", ("k", "length") 132 | ) 133 | strip_constrained_bonds(parameter_maps, constraints) 134 | 135 | return potential, parameter_maps 136 | 137 | 138 | @smee.converters.smirnoff_parameter_converter( 139 | "Angles", {"k": _KCAL_PER_MOL / _RADIANS**2, "angle": _RADIANS} 140 | ) 141 | def convert_angles( 142 | handlers: list[openff.interchange.smirnoff.SMIRNOFFAngleCollection], 143 | constraints: list[set[tuple[int, int]]], 144 | ) -> tuple[smee.TensorPotential, list[smee.ValenceParameterMap]]: 145 | potential, parameter_maps = convert_valence_handlers( 146 | handlers, "Angles", ("k", "angle") 147 | ) 148 | strip_constrained_angles(parameter_maps,
constraints) 149 | 150 | return potential, parameter_maps 151 | 152 | 153 | @smee.converters.smirnoff_parameter_converter( 154 | "ProperTorsions", 155 | { 156 | "k": _KCAL_PER_MOL, 157 | "periodicity": _UNITLESS, 158 | "phase": _RADIANS, 159 | "idivf": _UNITLESS, 160 | }, 161 | ) 162 | def convert_propers( 163 | handlers: list[openff.interchange.smirnoff.SMIRNOFFProperTorsionCollection], 164 | ) -> tuple[smee.TensorPotential, list[smee.ValenceParameterMap]]: 165 | return convert_valence_handlers( 166 | handlers, "ProperTorsions", ("k", "periodicity", "phase", "idivf") 167 | ) 168 | 169 | 170 | @smee.converters.smirnoff_parameter_converter( 171 | "ImproperTorsions", 172 | { 173 | "k": _KCAL_PER_MOL, 174 | "periodicity": _UNITLESS, 175 | "phase": _RADIANS, 176 | "idivf": _UNITLESS, 177 | }, 178 | ) 179 | def convert_impropers( 180 | handlers: list[openff.interchange.smirnoff.SMIRNOFFImproperTorsionCollection], 181 | ) -> tuple[smee.TensorPotential, list[smee.ValenceParameterMap]]: 182 | return convert_valence_handlers( 183 | handlers, "ImproperTorsions", ("k", "periodicity", "phase", "idivf") 184 | ) 185 | -------------------------------------------------------------------------------- /smee/converters/openmm/__init__.py: -------------------------------------------------------------------------------- 1 | """Convert tensor representations into OpenMM systems.""" 2 | 3 | from smee.converters.openmm._ff import convert_to_openmm_ffxml, ffxml_converter 4 | from smee.converters.openmm._openmm import ( 5 | convert_to_openmm_force, 6 | convert_to_openmm_system, 7 | convert_to_openmm_topology, 8 | create_openmm_system, 9 | potential_converter, 10 | ) 11 | 12 | __all__ = [ 13 | "convert_to_openmm_ffxml", 14 | "convert_to_openmm_force", 15 | "convert_to_openmm_system", 16 | "convert_to_openmm_topology", 17 | "create_openmm_system", 18 | "ffxml_converter", 19 | "potential_converter", 20 | ] 21 | --------------------------------------------------------------------------------
/smee/converters/openmm/_openmm.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import copy 3 | import importlib 4 | import typing 5 | 6 | import openmm 7 | import openmm.app 8 | 9 | import smee 10 | 11 | _ANGSTROM = openmm.unit.angstrom 12 | _ANGSTROM_TO_NM = 1.0 / 10.0 13 | 14 | 15 | _CONVERTER_FUNCTIONS: dict[ 16 | tuple[str, str], 17 | typing.Callable[ 18 | [smee.TensorPotential, smee.TensorSystem], openmm.Force | list[openmm.Force] 19 | ], 20 | ] = {} 21 | 22 | 23 | def potential_converter(handler_type: str, energy_expression: str): 24 | """A decorator used to flag a function as being able to convert a tensor potential 25 | of a given type and energy function to an OpenMM force. 26 | """ 27 | 28 | def _openmm_converter_inner(func): 29 | if (handler_type, energy_expression) in _CONVERTER_FUNCTIONS: 30 | raise KeyError( 31 | f"An OpenMM converter function is already defined for " 32 | f"handler={handler_type} fn={energy_expression}." 
33 | ) 34 | 35 | _CONVERTER_FUNCTIONS[(str(handler_type), str(energy_expression))] = func 36 | return func 37 | 38 | return _openmm_converter_inner 39 | 40 | 41 | def _combine_nonbonded( 42 | vdw_force: openmm.NonbondedForce, electrostatic_force: openmm.NonbondedForce 43 | ) -> openmm.NonbondedForce: 44 | assert vdw_force.getNumParticles() == electrostatic_force.getNumParticles() 45 | assert vdw_force.getNumExceptions() == electrostatic_force.getNumExceptions() 46 | assert vdw_force.getNonbondedMethod() == electrostatic_force.getNonbondedMethod() 47 | assert vdw_force.getCutoffDistance() == electrostatic_force.getCutoffDistance() 48 | 49 | force = copy.deepcopy(vdw_force) 50 | force.setEwaldErrorTolerance(electrostatic_force.getEwaldErrorTolerance()) 51 | 52 | for i in range(force.getNumParticles()): 53 | charge, _, _ = electrostatic_force.getParticleParameters(i) 54 | _, sigma, epsilon = vdw_force.getParticleParameters(i) 55 | 56 | force.setParticleParameters(i, charge, sigma, epsilon) 57 | 58 | vdw_exceptions, electrostatic_exceptions = {}, {} 59 | 60 | for index in range(vdw_force.getNumExceptions()): 61 | i, j, *values = vdw_force.getExceptionParameters(index) 62 | vdw_exceptions[(i, j)] = (index, *values) 63 | 64 | for index in range(electrostatic_force.getNumExceptions()): 65 | i, j, *values = electrostatic_force.getExceptionParameters(index) 66 | electrostatic_exceptions[(i, j)] = values 67 | 68 | for (i, j), (charge_prod, _, _) in electrostatic_exceptions.items(): 69 | index, _, sigma, epsilon = vdw_exceptions[(i, j)] 70 | force.setExceptionParameters(index, i, j, charge_prod, sigma, epsilon) 71 | 72 | return force 73 | 74 | 75 | def create_openmm_system( 76 | system: smee.TensorSystem, v_sites: smee.TensorVSites | None 77 | ) -> openmm.System: 78 | """Create an empty OpenMM system from a ``smee`` system.""" 79 | v_sites = None if v_sites is None else v_sites.to("cpu") 80 | system = system.to("cpu") 81 | 82 | omm_system = openmm.System() 83 | 84 | for 
topology, n_copies in zip(system.topologies, system.n_copies, strict=True): 85 | for _ in range(n_copies): 86 | start_idx = omm_system.getNumParticles() 87 | 88 | for atomic_num in topology.atomic_nums: 89 | mass = openmm.app.Element.getByAtomicNumber(int(atomic_num)).mass 90 | omm_system.addParticle(mass) 91 | 92 | if topology.v_sites is None: 93 | continue 94 | 95 | for _ in range(topology.n_v_sites): 96 | omm_system.addParticle(0.0) 97 | 98 | for key, parameter_idx in zip( 99 | topology.v_sites.keys, topology.v_sites.parameter_idxs, strict=True 100 | ): 101 | system_idx = start_idx + topology.v_sites.key_to_idx[key] 102 | assert system_idx >= start_idx 103 | 104 | parent_idxs = [i + start_idx for i in key.orientation_atom_indices] 105 | 106 | local_frame_coords = smee.geometry.polar_to_cartesian_coords( 107 | v_sites.parameters[[parameter_idx], :].detach() 108 | ) 109 | origin, x_dir, y_dir = v_sites.weights[parameter_idx] 110 | 111 | v_site = openmm.LocalCoordinatesSite( 112 | parent_idxs, 113 | origin.numpy(), 114 | x_dir.numpy(), 115 | y_dir.numpy(), 116 | local_frame_coords.numpy().flatten() * _ANGSTROM_TO_NM, 117 | ) 118 | 119 | omm_system.setVirtualSite(system_idx, v_site) 120 | 121 | return omm_system 122 | 123 | 124 | def _apply_constraints(omm_system: openmm.System, system: smee.TensorSystem): 125 | idx_offset = 0 126 | 127 | for topology, n_copies in zip(system.topologies, system.n_copies, strict=True): 128 | if topology.constraints is None: 129 | continue 130 | 131 | for _ in range(n_copies): 132 | atom_idxs = topology.constraints.idxs + idx_offset 133 | 134 | for (i, j), distance in zip( 135 | atom_idxs, topology.constraints.distances, strict=True 136 | ): 137 | omm_system.addConstraint(i, j, distance * _ANGSTROM) 138 | 139 | idx_offset += topology.n_particles 140 | 141 | 142 | def convert_to_openmm_force( 143 | potential: smee.TensorPotential, system: smee.TensorSystem 144 | ) -> list[openmm.Force]: 145 | """Convert a ``smee`` potential to OpenMM 
forces. 146 | 147 | Some potentials may return multiple forces, e.g. a vdW potential may return one 148 | force containing intermolecular interactions and another containing intramolecular 149 | interactions. 150 | 151 | See Also: 152 | potential_converter: for how to define a converter function. 153 | 154 | Args: 155 | potential: The potential to convert. 156 | system: The system to convert. 157 | 158 | Returns: 159 | The OpenMM force(s). 160 | """ 161 | # register the built-in converter functions 162 | importlib.import_module("smee.converters.openmm.nonbonded") 163 | importlib.import_module("smee.converters.openmm.valence") 164 | 165 | potential = potential.to("cpu") 166 | system = system.to("cpu") 167 | 168 | if potential.exceptions is not None and potential.type != "vdW": 169 | raise NotImplementedError("exceptions are only supported for vdW potentials") 170 | 171 | converter_key = (str(potential.type), str(potential.fn)) 172 | 173 | if converter_key not in _CONVERTER_FUNCTIONS: 174 | raise NotImplementedError( 175 | f"cannot convert type={potential.type} fn={potential.fn} to an OpenMM force" 176 | ) 177 | 178 | forces = _CONVERTER_FUNCTIONS[converter_key](potential, system) 179 | return forces if isinstance(forces, (list, tuple)) else [forces] 180 | 181 | 182 | def convert_to_openmm_system( 183 | force_field: smee.TensorForceField, 184 | system: smee.TensorSystem | smee.TensorTopology, 185 | ) -> openmm.System: 186 | """Convert a ``smee`` force field and system / topology into an OpenMM system. 187 | 188 | Args: 189 | force_field: The force field parameters. 190 | system: The system / topology to convert. 191 | 192 | Returns: 193 | The OpenMM system. 
194 | """ 195 | 196 | system: smee.TensorSystem = ( 197 | system 198 | if isinstance(system, smee.TensorSystem) 199 | else smee.TensorSystem([system], [1], False) 200 | ) 201 | 202 | force_field = force_field.to("cpu") 203 | system = system.to("cpu") 204 | 205 | omm_forces = { 206 | potential_type: convert_to_openmm_force(potential, system) 207 | for potential_type, potential in force_field.potentials_by_type.items() 208 | } 209 | omm_system = create_openmm_system(system, force_field.v_sites) 210 | 211 | if ( 212 | "Electrostatics" in omm_forces 213 | and "vdW" in omm_forces 214 | and len(omm_forces["vdW"]) == 1 215 | and isinstance(omm_forces["vdW"][0], openmm.NonbondedForce) 216 | ): 217 | (electrostatic_force,) = omm_forces.pop("Electrostatics") 218 | (vdw_force,) = omm_forces.pop("vdW") 219 | 220 | nonbonded_force = _combine_nonbonded(vdw_force, electrostatic_force) 221 | omm_system.addForce(nonbonded_force) 222 | 223 | for forces in omm_forces.values(): 224 | for force in forces: 225 | omm_system.addForce(force) 226 | 227 | _apply_constraints(omm_system, system) 228 | 229 | return omm_system 230 | 231 | 232 | def convert_to_openmm_topology( 233 | system: smee.TensorSystem | smee.TensorTopology, 234 | ) -> openmm.app.Topology: 235 | """Convert a ``smee`` system to an OpenMM topology. 236 | 237 | Notes: 238 | Virtual sites are given the name "X{i}". 239 | 240 | Args: 241 | system: The system to convert. 242 | 243 | Returns: 244 | The OpenMM topology. 
245 | """ 246 | system: smee.TensorSystem = ( 247 | system 248 | if isinstance(system, smee.TensorSystem) 249 | else smee.TensorSystem([system], [1], False) 250 | ) 251 | 252 | omm_topology = openmm.app.Topology() 253 | 254 | for topology, n_copies in zip(system.topologies, system.n_copies, strict=True): 255 | chain = omm_topology.addChain() 256 | 257 | is_water = topology.n_atoms == 3 and sorted( 258 | int(v) for v in topology.atomic_nums 259 | ) == [1, 1, 8] 260 | 261 | residue_name = "HOH" if is_water else "UNK" 262 | 263 | for _ in range(n_copies): 264 | residue = omm_topology.addResidue(residue_name, chain) 265 | element_counter = collections.defaultdict(int) 266 | 267 | atoms = {} 268 | 269 | for i, atomic_num in enumerate(topology.atomic_nums): 270 | element = openmm.app.Element.getByAtomicNumber(int(atomic_num)) 271 | element_counter[element.symbol] += 1 272 | 273 | name = element.symbol + ( 274 | "" 275 | if element_counter[element.symbol] == 1 and element.symbol != "H" 276 | else f"{element_counter[element.symbol]}" 277 | ) 278 | atoms[i] = omm_topology.addAtom(name, element, residue) 279 | 280 | for i in range(topology.n_v_sites): 281 | omm_topology.addAtom(f"X{i + 1}", None, residue) 282 | 283 | for bond_idxs, bond_order in zip( 284 | topology.bond_idxs, topology.bond_orders, strict=True 285 | ): 286 | idx_a, idx_b = int(bond_idxs[0]), int(bond_idxs[1]) 287 | 288 | bond_order = int(bond_order) 289 | bond_type = { 290 | 1: openmm.app.Single, 291 | 2: openmm.app.Double, 292 | 3: openmm.app.Triple, 293 | }[bond_order] 294 | 295 | omm_topology.addBond(atoms[idx_a], atoms[idx_b], bond_type, bond_order) 296 | 297 | return omm_topology 298 | -------------------------------------------------------------------------------- /smee/converters/openmm/valence.py: -------------------------------------------------------------------------------- 1 | """Convert valence potentials to OpenMM forces.""" 2 | 3 | import openmm 4 | import openmm.app 5 | 6 | import smee 7 | 8 | 
_KCAL_PER_MOL = openmm.unit.kilocalorie_per_mole 9 | _ANGSTROM = openmm.unit.angstrom 10 | _RADIANS = openmm.unit.radians 11 | 12 | 13 | @smee.converters.openmm.potential_converter( 14 | smee.PotentialType.BONDS, smee.EnergyFn.BOND_HARMONIC 15 | ) 16 | def convert_bond_potential( 17 | potential: smee.TensorPotential, system: smee.TensorSystem 18 | ) -> openmm.HarmonicBondForce: 19 | """Convert a harmonic bond potential to a corresponding OpenMM force.""" 20 | force = openmm.HarmonicBondForce() 21 | 22 | idx_offset = 0 23 | 24 | for topology, n_copies in zip(system.topologies, system.n_copies, strict=True): 25 | parameters = ( 26 | topology.parameters[potential.type].assignment_matrix @ potential.parameters 27 | ).detach() 28 | 29 | for _ in range(n_copies): 30 | atom_idxs = topology.parameters[potential.type].particle_idxs + idx_offset 31 | 32 | for (i, j), (constant, length) in zip(atom_idxs, parameters, strict=True): 33 | force.addBond( 34 | i, 35 | j, 36 | length * _ANGSTROM, 37 | constant * _KCAL_PER_MOL / _ANGSTROM**2, 38 | ) 39 | 40 | idx_offset += topology.n_particles 41 | 42 | return force 43 | 44 | 45 | @smee.converters.openmm.potential_converter( 46 | smee.PotentialType.ANGLES, smee.EnergyFn.ANGLE_HARMONIC 47 | ) 48 | def _convert_angle_potential( 49 | potential: smee.TensorPotential, system: smee.TensorSystem 50 | ) -> openmm.HarmonicAngleForce: 51 | """Convert a harmonic angle potential to a corresponding OpenMM force.""" 52 | force = openmm.HarmonicAngleForce() 53 | 54 | idx_offset = 0 55 | 56 | for topology, n_copies in zip(system.topologies, system.n_copies, strict=True): 57 | parameters = ( 58 | topology.parameters[potential.type].assignment_matrix @ potential.parameters 59 | ).detach() 60 | 61 | for _ in range(n_copies): 62 | atom_idxs = topology.parameters[potential.type].particle_idxs + idx_offset 63 | 64 | for (i, j, k), (constant, angle) in zip(atom_idxs, parameters, strict=True): 65 | force.addAngle( 66 | i, 67 | j, 68 | k, 69 | angle * 
_RADIANS, 70 | constant * _KCAL_PER_MOL / _RADIANS**2, 71 | ) 72 | 73 | idx_offset += topology.n_particles 74 | 75 | return force 76 | 77 | 78 | @smee.converters.openmm.potential_converter( 79 | smee.PotentialType.PROPER_TORSIONS, smee.EnergyFn.TORSION_COSINE 80 | ) 81 | @smee.converters.openmm.potential_converter( 82 | smee.PotentialType.IMPROPER_TORSIONS, smee.EnergyFn.TORSION_COSINE 83 | ) 84 | def convert_torsion_potential( 85 | potential: smee.TensorPotential, system: smee.TensorSystem 86 | ) -> openmm.PeriodicTorsionForce: 87 | """Convert a torsion potential to a corresponding OpenMM force.""" 88 | force = openmm.PeriodicTorsionForce() 89 | 90 | idx_offset = 0 91 | 92 | for topology, n_copies in zip(system.topologies, system.n_copies, strict=True): 93 | parameters = ( 94 | topology.parameters[potential.type].assignment_matrix @ potential.parameters 95 | ).detach() 96 | 97 | for _ in range(n_copies): 98 | atom_idxs = topology.parameters[potential.type].particle_idxs + idx_offset 99 | 100 | for (idx_i, idx_j, idx_k, idx_l), ( 101 | constant, 102 | periodicity, 103 | phase, 104 | idivf, 105 | ) in zip(atom_idxs, parameters, strict=True): 106 | force.addTorsion( 107 | idx_i, 108 | idx_j, 109 | idx_k, 110 | idx_l, 111 | int(periodicity), 112 | phase * _RADIANS, 113 | constant / idivf * _KCAL_PER_MOL, 114 | ) 115 | 116 | idx_offset += topology.n_particles 117 | 118 | return force 119 | -------------------------------------------------------------------------------- /smee/mm/__init__.py: -------------------------------------------------------------------------------- 1 | """Compute differentiable ensemble averages using OpenMM and SMEE.""" 2 | 3 | from smee.mm._config import GenerateCoordsConfig, MinimizationConfig, SimulationConfig 4 | from smee.mm._fe import generate_dg_solv_data 5 | from smee.mm._mm import generate_system_coords, simulate 6 | from smee.mm._ops import ( 7 | NotEnoughSamplesError, 8 | compute_dg_solv, 9 | compute_ensemble_averages, 10 | 
reweight_dg_solv, 11 | reweight_ensemble_averages, 12 | ) 13 | from smee.mm._reporters import TensorReporter, tensor_reporter, unpack_frames 14 | 15 | __all__ = [ 16 | "compute_dg_solv", 17 | "compute_ensemble_averages", 18 | "generate_dg_solv_data", 19 | "generate_system_coords", 20 | "reweight_dg_solv", 21 | "reweight_ensemble_averages", 22 | "simulate", 23 | "GenerateCoordsConfig", 24 | "MinimizationConfig", 25 | "NotEnoughSamplesError", 26 | "SimulationConfig", 27 | "TensorReporter", 28 | "tensor_reporter", 29 | "unpack_frames", 30 | ] 31 | -------------------------------------------------------------------------------- /smee/mm/_config.py: -------------------------------------------------------------------------------- 1 | """Configuration from MM simulations.""" 2 | 3 | import openmm.unit 4 | import pydantic 5 | from pydantic_units import OpenMMQuantity, quantity_serializer 6 | 7 | _KCAL_PER_MOL = openmm.unit.kilocalories_per_mole 8 | _ANGSTROM = openmm.unit.angstrom 9 | _GRAMS_PER_ML = openmm.unit.grams / openmm.unit.milliliters 10 | 11 | 12 | if pydantic.__version__.startswith("1."): 13 | 14 | class BaseModel(pydantic.BaseModel): 15 | class Config: 16 | json_encoders = {openmm.unit.Quantity: quantity_serializer} 17 | 18 | else: 19 | BaseModel = pydantic.BaseModel 20 | 21 | 22 | class GenerateCoordsConfig(BaseModel): 23 | """Configure how coordinates should be generated for a system using PACKMOL.""" 24 | 25 | target_density: OpenMMQuantity[_GRAMS_PER_ML] = pydantic.Field( 26 | 0.95 * _GRAMS_PER_ML, 27 | description="Target mass density for final system with units compatible with " 28 | "g / mL.", 29 | ) 30 | 31 | scale_factor: float = pydantic.Field( 32 | 1.1, 33 | description="The amount to scale the approximate box size by to help alleviate " 34 | "issues with packing larger molecules.", 35 | ) 36 | padding: OpenMMQuantity[openmm.unit.angstrom] = pydantic.Field( 37 | 2.0 * openmm.unit.angstrom, 38 | description="The amount of padding to add to the final 
box size to help " 39 | "alleviate PBC issues.", 40 | ) 41 | 42 | tolerance: OpenMMQuantity[openmm.unit.angstrom] = pydantic.Field( 43 | 2.0 * openmm.unit.angstrom, 44 | description="The minimum spacing between molecules during packing.", 45 | ) 46 | 47 | seed: int | None = pydantic.Field( 48 | None, description="The random seed to use when generating the coordinates." 49 | ) 50 | 51 | 52 | class MinimizationConfig(BaseModel): 53 | """Configure how a system should be energy minimized.""" 54 | 55 | tolerance: OpenMMQuantity[_KCAL_PER_MOL / _ANGSTROM] = pydantic.Field( 56 | 10.0 * _KCAL_PER_MOL / _ANGSTROM, 57 | description="Minimization will be halted once the root-mean-square value of " 58 | "all force components reaches this tolerance.", 59 | ) 60 | max_iterations: int = pydantic.Field( 61 | 0, 62 | description="The maximum number of iterations to perform. If 0, minimization " 63 | "will continue until the tolerance is met.", 64 | ) 65 | 66 | 67 | class SimulationConfig(BaseModel): 68 | temperature: OpenMMQuantity[openmm.unit.kelvin] = pydantic.Field( 69 | ..., 70 | description="The temperature to simulate at.", 71 | ) 72 | pressure: OpenMMQuantity[openmm.unit.atmospheres] | None = pydantic.Field( 73 | ..., 74 | description="The pressure to simulate at, or none to run in NVT.", 75 | ) 76 | 77 | n_steps: int = pydantic.Field( 78 | ..., description="The number of steps to simulate for." 
79 | ) 80 | 81 | timestep: OpenMMQuantity[openmm.unit.femtoseconds] = pydantic.Field( 82 | 2.0 * openmm.unit.femtoseconds, 83 | description="The timestep to use during the simulation.", 84 | ) 85 | friction_coeff: OpenMMQuantity[1.0 / openmm.unit.picoseconds] = pydantic.Field( 86 | 1.0 / openmm.unit.picoseconds, 87 | description="The integrator friction coefficient.", 88 | ) 89 | -------------------------------------------------------------------------------- /smee/mm/_reporters.py: -------------------------------------------------------------------------------- 1 | """OpenMM simulation reporters""" 2 | 3 | import contextlib 4 | import math 5 | import os 6 | import typing 7 | 8 | import msgpack 9 | import numpy 10 | import openmm.app 11 | import openmm.unit 12 | import torch 13 | 14 | _ANGSTROM = openmm.unit.angstrom 15 | _KCAL_PER_MOL = openmm.unit.kilocalories_per_mole 16 | 17 | 18 | def _encoder(obj, chain=None): 19 | """msgpack encoder for tensors""" 20 | if isinstance(obj, torch.Tensor): 21 | assert obj.dtype == torch.float32 22 | return {b"torch": True, b"shape": obj.shape, b"data": obj.numpy().tobytes()} 23 | else: 24 | return obj if chain is None else chain(obj) 25 | 26 | 27 | def _decoder(obj, chain=None): 28 | """msgpack decoder for tensors""" 29 | try: 30 | if b"torch" in obj: 31 | array = numpy.ndarray( 32 | buffer=obj[b"data"], dtype=numpy.float32, shape=obj[b"shape"] 33 | ) 34 | return torch.from_numpy(array.copy()) 35 | else: 36 | return obj if chain is None else chain(obj) 37 | except KeyError: 38 | return obj if chain is None else chain(obj) 39 | 40 | 41 | class TensorReporter: 42 | """A reporter which stores coords, box vectors, reduced potentials and kinetic 43 | energy using msgpack.""" 44 | 45 | def __init__( 46 | self, 47 | output_file: typing.BinaryIO, 48 | report_interval: int, 49 | beta: openmm.unit.Quantity, 50 | pressure: openmm.unit.Quantity | None, 51 | ): 52 | """ 53 | 54 | Args: 55 | output_file: The file to write the frames to. 
56 | report_interval: The interval (in steps) at which to write frames. 57 | beta: The inverse temperature the simulation is being run at. 58 | pressure: The pressure the simulation is being run at, or None if NVT / 59 | vacuum. 60 | """ 61 | self._output_file = output_file 62 | self._report_interval = report_interval 63 | 64 | self._beta = beta 65 | self._pressure = ( 66 | None if pressure is None else pressure * openmm.unit.AVOGADRO_CONSTANT_NA 67 | ) 68 | 69 | def describeNextReport(self, simulation: openmm.app.Simulation): 70 | steps = self._report_interval - simulation.currentStep % self._report_interval 71 | # requires - positions, velocities, forces, energies? 72 | return steps, True, False, False, True 73 | 74 | def report(self, simulation: openmm.app.Simulation, state: openmm.State): 75 | potential_energy = state.getPotentialEnergy() 76 | kinetic_energy = state.getKineticEnergy() 77 | 78 | total_energy = potential_energy + kinetic_energy 79 | 80 | if math.isnan(total_energy.value_in_unit(_KCAL_PER_MOL)): 81 | raise ValueError("total energy is nan") 82 | if math.isinf(total_energy.value_in_unit(_KCAL_PER_MOL)): 83 | raise ValueError("total energy is infinite") 84 | 85 | unreduced_potential = potential_energy 86 | 87 | if self._pressure is not None: 88 | unreduced_potential += self._pressure * state.getPeriodicBoxVolume() 89 | 90 | reduced_potential = unreduced_potential * self._beta 91 | 92 | coords = state.getPositions(asNumpy=True).value_in_unit(_ANGSTROM) 93 | coords = torch.from_numpy(coords).float() 94 | box_vectors = state.getPeriodicBoxVectors(asNumpy=True).value_in_unit(_ANGSTROM) 95 | box_vectors = torch.from_numpy(box_vectors).float() 96 | 97 | frame = ( 98 | coords, 99 | box_vectors, 100 | reduced_potential, 101 | kinetic_energy.value_in_unit(_KCAL_PER_MOL), 102 | ) 103 | self._output_file.write(msgpack.dumps(frame, default=_encoder)) 104 | 105 | 106 | def unpack_frames( 107 | file: typing.BinaryIO, 108 | ) -> typing.Generator[tuple[torch.Tensor, 
torch.Tensor, float], None, None]: 109 | """Unpack frames saved by a ``TensorReporter``.""" 110 | 111 | unpacker = msgpack.Unpacker(file, object_hook=_decoder) 112 | 113 | for frame in unpacker: 114 | yield frame 115 | 116 | 117 | @contextlib.contextmanager 118 | def tensor_reporter( 119 | output_path: os.PathLike, 120 | report_interval: int, 121 | beta: openmm.unit.Quantity, 122 | pressure: openmm.unit.Quantity | None, 123 | ) -> TensorReporter: 124 | """Create a ``TensorReporter`` capable of writing frames to a file. 125 | 126 | Args: 127 | output_path: The path to write the frames to. 128 | report_interval: The interval (in steps) at which to write frames. 129 | beta: The inverse temperature the simulation is being run at. 130 | pressure: The pressure the simulation is being run at, or ``None`` if NVT / 131 | vacuum. 132 | """ 133 | with open(output_path, "wb") as output_file: 134 | reporter = TensorReporter(output_file, report_interval, beta, pressure) 135 | yield reporter 136 | -------------------------------------------------------------------------------- /smee/mm/_utils.py: -------------------------------------------------------------------------------- 1 | from rdkit import Chem 2 | from rdkit.Chem import AllChem 3 | 4 | import smee 5 | 6 | 7 | def topology_to_rdkit(topology: smee.TensorTopology) -> Chem.Mol: 8 | """Convert a topology to an RDKit molecule.""" 9 | mol = Chem.RWMol() 10 | 11 | for atomic_num, formal_charge in zip( 12 | topology.atomic_nums, topology.formal_charges, strict=True 13 | ): 14 | atom = Chem.Atom(int(atomic_num)) 15 | atom.SetFormalCharge(int(formal_charge)) 16 | mol.AddAtom(atom) 17 | 18 | for bond_idxs, bond_order in zip( 19 | topology.bond_idxs, topology.bond_orders, strict=True 20 | ): 21 | idx_a, idx_b = int(bond_idxs[0]), int(bond_idxs[1]) 22 | mol.AddBond(idx_a, idx_b, Chem.BondType(bond_order)) 23 | 24 | mol = Chem.Mol(mol) 25 | mol.UpdatePropertyCache() 26 | 27 | AllChem.EmbedMolecule(mol) 28 | 29 | return mol 30 | 
-------------------------------------------------------------------------------- /smee/potentials/__init__.py: -------------------------------------------------------------------------------- 1 | """Evaluate the potential energy of parameterized topologies.""" 2 | 3 | from smee.potentials._potentials import ( 4 | broadcast_exceptions, 5 | broadcast_idxs, 6 | broadcast_parameters, 7 | compute_energy, 8 | compute_energy_potential, 9 | potential_energy_fn, 10 | ) 11 | 12 | __all__ = [ 13 | "broadcast_exceptions", 14 | "broadcast_idxs", 15 | "broadcast_parameters", 16 | "compute_energy", 17 | "compute_energy_potential", 18 | "potential_energy_fn", 19 | ] 20 | -------------------------------------------------------------------------------- /smee/potentials/valence.py: -------------------------------------------------------------------------------- 1 | """Valence potential energy functions.""" 2 | 3 | import torch 4 | 5 | import smee.geometry 6 | import smee.potentials 7 | import smee.utils 8 | 9 | 10 | @smee.potentials.potential_energy_fn( 11 | smee.PotentialType.BONDS, smee.EnergyFn.BOND_HARMONIC 12 | ) 13 | def compute_harmonic_bond_energy( 14 | system: smee.TensorSystem, 15 | potential: smee.TensorPotential, 16 | conformer: torch.Tensor, 17 | ) -> torch.Tensor: 18 | """Compute the potential energy [kcal / mol] of a set of bonds for a given 19 | conformer using a harmonic potential of the form ``1/2 * k * (r - length) ** 2`` 20 | 21 | Args: 22 | system: The system to compute the energy for. 23 | potential: The potential energy function to evaluate. 24 | conformer: The conformer [Å] to evaluate the potential at with 25 | ``shape=(n_confs, n_particles, 3)`` or ``shape=(n_particles, 3)``. 26 | 27 | Returns: 28 | The computed potential energy [kcal / mol]. 
29 | """ 30 | 31 | parameters = smee.potentials.broadcast_parameters(system, potential) 32 | particle_idxs = smee.potentials.broadcast_idxs(system, potential) 33 | 34 | _, distances = smee.geometry.compute_bond_vectors(conformer, particle_idxs) 35 | 36 | k = parameters[:, potential.parameter_cols.index("k")] 37 | length = parameters[:, potential.parameter_cols.index("length")] 38 | 39 | return (0.5 * k * (distances - length) ** 2).sum(-1) 40 | 41 | 42 | @smee.potentials.potential_energy_fn( 43 | smee.PotentialType.ANGLES, smee.EnergyFn.ANGLE_HARMONIC 44 | ) 45 | def compute_harmonic_angle_energy( 46 | system: smee.TensorSystem, 47 | potential: smee.TensorPotential, 48 | conformer: torch.Tensor, 49 | ) -> torch.Tensor: 50 | """Compute the potential energy [kcal / mol] of a set of valence angles for a given 51 | conformer using a harmonic potential of the form ``1/2 * k * (theta - angle) ** 2`` 52 | 53 | Args: 54 | system: The system to compute the energy for. 55 | potential: The potential energy function to evaluate. 56 | conformer: The conformer [Å] to evaluate the potential at with 57 | ``shape=(n_confs, n_particles, 3)`` or ``shape=(n_particles, 3)``. 58 | 59 | Returns: 60 | The computed potential energy [kcal / mol]. 
61 | """ 62 | 63 | parameters = smee.potentials.broadcast_parameters(system, potential) 64 | particle_idxs = smee.potentials.broadcast_idxs(system, potential) 65 | 66 | theta = smee.geometry.compute_angles(conformer, particle_idxs) 67 | 68 | k = parameters[:, potential.parameter_cols.index("k")] 69 | angle = parameters[:, potential.parameter_cols.index("angle")] 70 | 71 | return (0.5 * k * (theta - angle) ** 2).sum(-1) 72 | 73 | 74 | def _compute_cosine_torsion_energy( 75 | system: smee.TensorSystem, 76 | potential: smee.TensorPotential, 77 | conformer: torch.Tensor, 78 | ) -> torch.Tensor: 79 | """Compute the potential energy [kcal / mol] of a set of torsions for a given 80 | conformer using a cosine potential of the form 81 | ``k/idivf*(1+cos(periodicity*phi-phase))`` 82 | 83 | Args: 84 | system: The system to compute the energy for. 85 | potential: The potential energy function to evaluate. 86 | conformer: The conformer [Å] to evaluate the potential at with 87 | ``shape=(n_confs, n_particles, 3)`` or ``shape=(n_particles, 3)``. 88 | 89 | Returns: 90 | The computed potential energy [kcal / mol]. 
91 | """ 92 | 93 | parameters = smee.potentials.broadcast_parameters(system, potential) 94 | particle_idxs = smee.potentials.broadcast_idxs(system, potential) 95 | 96 | phi = smee.geometry.compute_dihedrals(conformer, particle_idxs) 97 | 98 | k = parameters[:, potential.parameter_cols.index("k")] 99 | periodicity = parameters[:, potential.parameter_cols.index("periodicity")] 100 | phase = parameters[:, potential.parameter_cols.index("phase")] 101 | idivf = parameters[:, potential.parameter_cols.index("idivf")] 102 | 103 | return ((k / idivf) * (1.0 + torch.cos(periodicity * phi - phase))).sum(-1) 104 | 105 | 106 | @smee.potentials.potential_energy_fn( 107 | smee.PotentialType.PROPER_TORSIONS, smee.EnergyFn.TORSION_COSINE 108 | ) 109 | def compute_cosine_proper_torsion_energy( 110 | system: smee.TensorSystem, 111 | potential: smee.TensorPotential, 112 | conformer: torch.Tensor, 113 | ) -> torch.Tensor: 114 | """Compute the potential energy [kcal / mol] of a set of proper torsions 115 | for a given conformer using a cosine potential of the form: 116 | 117 | `k*(1+cos(periodicity*theta-phase))` 118 | 119 | Args: 120 | system: The system to compute the energy for. 121 | potential: The potential energy function to evaluate. 122 | conformer: The conformer [Å] to evaluate the potential at with 123 | ``shape=(n_confs, n_particles, 3)`` or ``shape=(n_particles, 3)``. 124 | 125 | Returns: 126 | The computed potential energy [kcal / mol]. 
127 | """ 128 | return _compute_cosine_torsion_energy(system, potential, conformer) 129 | 130 | 131 | @smee.potentials.potential_energy_fn( 132 | smee.PotentialType.IMPROPER_TORSIONS, smee.EnergyFn.TORSION_COSINE 133 | ) 134 | def compute_cosine_improper_torsion_energy( 135 | system: smee.TensorSystem, 136 | potential: smee.TensorPotential, 137 | conformer: torch.Tensor, 138 | ) -> torch.Tensor: 139 | """Compute the potential energy [kcal / mol] of a set of improper torsions 140 | for a given conformer using a cosine potential of the form: 141 | 142 | `k*(1+cos(periodicity*theta-phase))` 143 | 144 | Args: 145 | system: The system to compute the energy for. 146 | potential: The potential energy function to evaluate. 147 | conformer: The conformer [Å] to evaluate the potential at with 148 | ``shape=(n_confs, n_particles, 3)`` or ``shape=(n_particles, 3)``. 149 | 150 | Returns: 151 | The computed potential energy [kcal / mol]. 152 | """ 153 | return _compute_cosine_torsion_energy(system, potential, conformer) 154 | -------------------------------------------------------------------------------- /smee/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SimonBoothroyd/smee/aca69b9da4c67916c6e59ed2c435fffd4c49a2b6/smee/py.typed -------------------------------------------------------------------------------- /smee/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SimonBoothroyd/smee/aca69b9da4c67916c6e59ed2c435fffd4c49a2b6/smee/tests/__init__.py -------------------------------------------------------------------------------- /smee/tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | import openff.interchange 4 | import openff.toolkit 5 | import openff.units 6 | import pytest 7 | import torch 8 | 9 | _ANGSTROM = openff.units.unit.angstrom 10 | 
_NM = openff.units.unit.nanometer 11 | 12 | _DEGREES = openff.units.unit.degree 13 | 14 | _KJ_PER_MOLE = openff.units.unit.kilojoules / openff.units.unit.mole 15 | _KCAL_PER_MOLE = openff.units.unit.kilocalories / openff.units.unit.mole 16 | 17 | _E = openff.units.unit.elementary_charge 18 | 19 | 20 | @pytest.fixture 21 | def tmp_cwd(tmp_path, monkeypatch) -> pathlib.Path: 22 | monkeypatch.chdir(tmp_path) 23 | yield tmp_path 24 | 25 | 26 | @pytest.fixture 27 | def test_data_dir() -> pathlib.Path: 28 | return pathlib.Path(__file__).parent / "data" 29 | 30 | 31 | @pytest.fixture(scope="module") 32 | def default_force_field() -> openff.toolkit.ForceField: 33 | """Returns the OpenFF 1.3.0 force field with constraints removed.""" 34 | 35 | force_field = openff.toolkit.ForceField("openff-1.3.0.offxml") 36 | force_field.deregister_parameter_handler("Constraints") 37 | 38 | return force_field 39 | 40 | 41 | @pytest.fixture(scope="module") 42 | def ethanol() -> openff.toolkit.Molecule: 43 | """Returns an OpenFF ethanol molecule with a fixed atom order.""" 44 | 45 | return openff.toolkit.Molecule.from_mapped_smiles( 46 | "[H:5][C:2]([H:6])([H:7])[C:3]([H:8])([H:9])[O:1][H:4]" 47 | ) 48 | 49 | 50 | @pytest.fixture(scope="module") 51 | def ethanol_conformer(ethanol) -> torch.Tensor: 52 | """Returns a conformer [Å] of ethanol with an ordering which matches the 53 | ``ethanol`` fixture.""" 54 | 55 | ethanol.generate_conformers(n_conformers=1) 56 | conformer = ethanol.conformers[0].m_as(_ANGSTROM) 57 | 58 | return torch.from_numpy(conformer) 59 | 60 | 61 | @pytest.fixture(scope="module") 62 | def ethanol_interchange(ethanol, default_force_field) -> openff.interchange.Interchange: 63 | """Returns a parameterized system of ethanol.""" 64 | 65 | return openff.interchange.Interchange.from_smirnoff( 66 | default_force_field, ethanol.to_topology() 67 | ) 68 | 69 | 70 | @pytest.fixture(scope="module") 71 | def formaldehyde() -> openff.toolkit.Molecule: 72 | """Returns an OpenFF 
formaldehyde molecule with a fixed atom order.""" 73 | 74 | return openff.toolkit.Molecule.from_mapped_smiles("[H:3][C:1](=[O:2])[H:4]") 75 | 76 | 77 | @pytest.fixture(scope="module") 78 | def formaldehyde_conformer(formaldehyde) -> torch.Tensor: 79 | """Returns a conformer [Å] of formaldehyde with an ordering which matches the 80 | ``formaldehyde`` fixture.""" 81 | 82 | formaldehyde.generate_conformers(n_conformers=1) 83 | conformer = formaldehyde.conformers[0].m_as(_ANGSTROM) 84 | 85 | return torch.from_numpy(conformer) 86 | 87 | 88 | @pytest.fixture(scope="module") 89 | def formaldehyde_interchange( 90 | formaldehyde, default_force_field 91 | ) -> openff.interchange.Interchange: 92 | """Returns a parameterized system of formaldehyde.""" 93 | 94 | return openff.interchange.Interchange.from_smirnoff( 95 | default_force_field, formaldehyde.to_topology() 96 | ) 97 | 98 | 99 | @pytest.fixture 100 | def v_site_force_field() -> openff.toolkit.ForceField: 101 | force_field = openff.toolkit.ForceField() 102 | 103 | force_field.get_parameter_handler("Electrostatics") 104 | 105 | vdw_handler = force_field.get_parameter_handler("vdW") 106 | vdw_handler.add_parameter( 107 | { 108 | "smirks": "[*:1]", 109 | "epsilon": 0.0 * _KJ_PER_MOLE, 110 | "sigma": 1.0 * _ANGSTROM, 111 | } 112 | ) 113 | 114 | charge_handler = force_field.get_parameter_handler("LibraryCharges") 115 | charge_handler.add_parameter( 116 | {"smirks": "[*:1]", "charge1": 0.0 * openff.units.unit.e} 117 | ) 118 | 119 | vsite_handler = force_field.get_parameter_handler("VirtualSites") 120 | 121 | vsite_handler.add_parameter( 122 | parameter_kwargs={ 123 | "smirks": "[H][#6:2]([H])=[#8:1]", 124 | "name": "EP", 125 | "type": "BondCharge", 126 | "distance": 7.0 * _ANGSTROM, 127 | "match": "all_permutations", 128 | "charge_increment1": 0.2 * _E, 129 | "charge_increment2": 0.1 * _E, 130 | "sigma": 1.0 * _ANGSTROM, 131 | "epsilon": 2.0 / 4.184 * _KCAL_PER_MOLE, 132 | } 133 | ) 134 | vsite_handler.add_parameter( 135 | 
parameter_kwargs={ 136 | "smirks": "[#8:1]=[#6X3:2](-[#17])-[#1:3]", 137 | "name": "EP", 138 | "type": "MonovalentLonePair", 139 | "distance": 1.234 * _ANGSTROM, 140 | "outOfPlaneAngle": 25.67 * _DEGREES, 141 | "inPlaneAngle": 134.0 * _DEGREES, 142 | "match": "all_permutations", 143 | "charge_increment1": 0.0 * _E, 144 | "charge_increment2": 1.0552 * 0.5 * _E, 145 | "charge_increment3": 1.0552 * 0.5 * _E, 146 | "sigma": 0.0 * _NM, 147 | "epsilon": 0.5 * _KJ_PER_MOLE, 148 | } 149 | ) 150 | vsite_handler.add_parameter( 151 | parameter_kwargs={ 152 | "smirks": "[#1:2]-[#8X2H2+0:1]-[#1:3]", 153 | "name": "EP", 154 | "type": "DivalentLonePair", 155 | "distance": -3.21 * _NM, 156 | "outOfPlaneAngle": 37.43 * _DEGREES, 157 | "match": "all_permutations", 158 | "charge_increment1": 0.0 * _E, 159 | "charge_increment2": 1.0552 * 0.5 * _E, 160 | "charge_increment3": 1.0552 * 0.5 * _E, 161 | "sigma": 1.0 * _ANGSTROM, 162 | "epsilon": 0.5 * _KJ_PER_MOLE, 163 | } 164 | ) 165 | vsite_handler.add_parameter( 166 | parameter_kwargs={ 167 | "smirks": "[#1:2][#7:1]([#1:3])[#1:4]", 168 | "name": "EP", 169 | "type": "TrivalentLonePair", 170 | "distance": 0.5 * _NM, 171 | "match": "once", 172 | "charge_increment1": 0.2 * _E, 173 | "charge_increment2": 0.0 * _E, 174 | "charge_increment3": 0.0 * _E, 175 | "charge_increment4": 0.0 * _E, 176 | "sigma": 1.0 * _ANGSTROM, 177 | "epsilon": 0.5 * _KJ_PER_MOLE, 178 | } 179 | ) 180 | return force_field 181 | -------------------------------------------------------------------------------- /smee/tests/convertors/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SimonBoothroyd/smee/aca69b9da4c67916c6e59ed2c435fffd4c49a2b6/smee/tests/convertors/__init__.py -------------------------------------------------------------------------------- /smee/tests/convertors/openff/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SimonBoothroyd/smee/aca69b9da4c67916c6e59ed2c435fffd4c49a2b6/smee/tests/convertors/openff/__init__.py -------------------------------------------------------------------------------- /smee/tests/convertors/openff/test_nonbonded.py: -------------------------------------------------------------------------------- 1 | import openff.interchange 2 | import openff.interchange.models 3 | import openff.toolkit 4 | import openff.units 5 | import torch 6 | 7 | import smee 8 | import smee.converters 9 | from smee.converters.openff.nonbonded import ( 10 | convert_dexp, 11 | convert_electrostatics, 12 | convert_vdw, 13 | ) 14 | 15 | 16 | def test_convert_electrostatics_am1bcc(ethanol, ethanol_interchange): 17 | charge_collection = ethanol_interchange.collections["Electrostatics"] 18 | 19 | potential, parameter_maps = convert_electrostatics( 20 | [charge_collection], [ethanol.to_topology()], [None] 21 | ) 22 | 23 | assert potential.type == "Electrostatics" 24 | assert potential.fn == "coul" 25 | 26 | expected_attributes = torch.tensor( 27 | [0.0, 0.0, 5.0 / 6.0, 1.0, 9.0], dtype=torch.float64 28 | ) 29 | assert torch.allclose(potential.attributes, expected_attributes) 30 | assert potential.attribute_cols == ( 31 | "scale_12", 32 | "scale_13", 33 | "scale_14", 34 | "scale_15", 35 | smee.CUTOFF_ATTRIBUTE, 36 | ) 37 | 38 | assert potential.parameter_cols == ("charge",) 39 | 40 | assert all( 41 | parameter_key.id == "[O:1]([C:3]([C:2]([H:5])([H:6])[H:7])([H:8])[H:9])[H:4]" 42 | for parameter_key in potential.parameter_keys 43 | ) 44 | assert potential.parameters.shape == (9, 1) 45 | 46 | assert len(parameter_maps) == 1 47 | parameter_map = parameter_maps[0] 48 | 49 | assert parameter_map.assignment_matrix.shape == (ethanol.n_atoms, ethanol.n_atoms) 50 | assert torch.allclose( 51 | parameter_map.assignment_matrix.to_dense(), 52 | torch.eye(ethanol.n_atoms, dtype=torch.float64), 53 | ) 54 | 55 | n_expected_exclusions = 36 56 | assert 
parameter_map.exclusions.shape == (n_expected_exclusions, 2) 57 | assert parameter_map.exclusion_scale_idxs.shape == (n_expected_exclusions, 1) 58 | 59 | 60 | def test_convert_electrostatics_v_site(): 61 | force_field = openff.toolkit.ForceField() 62 | force_field.get_parameter_handler("Electrostatics") 63 | force_field.get_parameter_handler("vdW") 64 | 65 | charge_handler = force_field.get_parameter_handler("LibraryCharges") 66 | charge_handler.add_parameter( 67 | { 68 | "smirks": "[Cl:1]-[H:2]", 69 | "charge1": -0.75 * openff.units.unit.e, 70 | "charge2": 0.25 * openff.units.unit.e, 71 | } 72 | ) 73 | 74 | v_site_handler = force_field.get_parameter_handler("VirtualSites") 75 | v_site_handler.add_parameter( 76 | { 77 | "type": "BondCharge", 78 | "smirks": "[Cl:1]-[H:2]", 79 | "distance": 2.0 * openff.units.unit.angstrom, 80 | "match": "all_permutations", 81 | "charge_increment1": -0.25 * openff.units.unit.e, 82 | "charge_increment2": 0.5 * openff.units.unit.e, 83 | } 84 | ) 85 | 86 | molecule = openff.toolkit.Molecule.from_mapped_smiles("[Cl:2]-[H:1]") 87 | 88 | interchange = openff.interchange.Interchange.from_smirnoff( 89 | force_field, molecule.to_topology(), allow_nonintegral_charges=True 90 | ) 91 | charge_collection = interchange.collections["Electrostatics"] 92 | 93 | potential, parameter_maps = convert_electrostatics( 94 | [charge_collection], 95 | [molecule.to_topology()], 96 | [ 97 | smee.VSiteMap( 98 | keys=[*interchange.collections["VirtualSites"].key_map], 99 | key_to_idx=interchange.collections[ 100 | "VirtualSites" 101 | ].virtual_site_key_topology_index_map, 102 | parameter_idxs=torch.tensor([[0]]), 103 | ) 104 | ], 105 | ) 106 | 107 | assert potential.parameter_cols == ("charge",) 108 | expected_keys = [ 109 | openff.interchange.models.PotentialKey( 110 | id="[Cl:1]-[H:2]", mult=0, associated_handler="LibraryCharges" 111 | ), 112 | openff.interchange.models.PotentialKey( 113 | id="[Cl:1]-[H:2]", mult=1, associated_handler="LibraryCharges" 114 | ), 
115 | openff.interchange.models.PotentialKey( 116 | id="[Cl:1]-[H:2] EP all_permutations", 117 | mult=0, 118 | associated_handler="Electrostatics", 119 | ), 120 | openff.interchange.models.PotentialKey( 121 | id="[Cl:1]-[H:2] EP all_permutations", 122 | mult=1, 123 | associated_handler="Electrostatics", 124 | ), 125 | ] 126 | assert potential.parameter_keys == expected_keys 127 | assert potential.parameters.shape == (4, 1) 128 | 129 | expected_parameters = torch.tensor( 130 | [[-0.75], [0.25], [-0.25], [0.5]], dtype=torch.float64 131 | ) 132 | assert torch.allclose(potential.parameters, expected_parameters) 133 | 134 | assert len(parameter_maps) == 1 135 | parameter_map = parameter_maps[0] 136 | 137 | n_particles = 3 138 | 139 | assert parameter_map.assignment_matrix.shape == (n_particles, len(expected_keys)) 140 | expected_assignment_matrix = torch.tensor( 141 | [ 142 | [0.0, 1.0, 0.0, 1.0], 143 | [1.0, 0.0, 1.0, 0.0], 144 | [0.0, 0.0, -1.0, -1.0], 145 | ], 146 | dtype=torch.float64, 147 | ) 148 | assert torch.allclose( 149 | parameter_map.assignment_matrix.to_dense(), expected_assignment_matrix 150 | ) 151 | 152 | n_expected_exclusions = 3 153 | assert parameter_map.exclusions.shape == (n_expected_exclusions, 2) 154 | assert parameter_map.exclusion_scale_idxs.shape == (n_expected_exclusions, 1) 155 | 156 | expected_exclusions = torch.tensor([[0, 1], [0, 2], [1, 2]], dtype=torch.long) 157 | assert torch.allclose(parameter_map.exclusions, expected_exclusions) 158 | 159 | expected_scales = torch.zeros((n_expected_exclusions, 1), dtype=torch.long) 160 | assert torch.allclose(parameter_map.exclusion_scale_idxs, expected_scales) 161 | 162 | 163 | def test_convert_electrostatics_tip4p(): 164 | """Explicitly test the case of TIP4P (FB) water to make sure v-site charges are 165 | correct. 
166 | """ 167 | 168 | force_field = openff.toolkit.ForceField("tip4p_fb.offxml") 169 | molecule = openff.toolkit.Molecule.from_mapped_smiles("[H:2][O:1][H:3]") 170 | 171 | interchange = openff.interchange.Interchange.from_smirnoff( 172 | force_field, molecule.to_topology(), allow_nonintegral_charges=True 173 | ) 174 | 175 | tensor_top: smee.TensorTopology 176 | tensor_ff, [tensor_top] = smee.converters.convert_interchange(interchange) 177 | 178 | q = 0.5258681106763 179 | expected_charges = torch.tensor([[0.0], [q], [q], [-2.0 * q]], dtype=torch.float64) 180 | 181 | charges = ( 182 | tensor_top.parameters["Electrostatics"].assignment_matrix 183 | @ tensor_ff.potentials_by_type["Electrostatics"].parameters 184 | ) 185 | assert charges.shape == expected_charges.shape 186 | assert torch.allclose(charges, expected_charges) 187 | 188 | 189 | def test_convert_bci_and_vsite(): 190 | ff_off = openff.toolkit.ForceField() 191 | ff_off.get_parameter_handler("Electrostatics") 192 | ff_off.get_parameter_handler("vdW") 193 | 194 | charge_handler = ff_off.get_parameter_handler("ChargeIncrementModel") 195 | charge_handler.partial_charge_method = "am1-mulliken" 196 | charge_handler.add_parameter( 197 | {"smirks": "[O:1]-[H:2]", "charge_increment1": -0.1 * openff.units.unit.e} 198 | ) 199 | v_site_handler = ff_off.get_parameter_handler("VirtualSites") 200 | v_site_handler.add_parameter( 201 | { 202 | "type": "DivalentLonePair", 203 | "smirks": "[#1:2]-[#8X2H2+0:1]-[#1:3]", 204 | "distance": -0.1 * openff.units.unit.angstrom, 205 | "outOfPlaneAngle": 0.0 * openff.units.unit.degree, 206 | "match": "once", 207 | "charge_increment1": 0.0 * openff.units.unit.e, 208 | "charge_increment2": 0.53 * openff.units.unit.e, 209 | "charge_increment3": 0.53 * openff.units.unit.e, 210 | } 211 | ) 212 | 213 | mol = openff.toolkit.Molecule.from_mapped_smiles("[O:1]([H:2])[H:3]") 214 | mol.assign_partial_charges(charge_handler.partial_charge_method) 215 | 216 | interchange = 
openff.interchange.Interchange.from_smirnoff( 217 | ff_off, mol.to_topology() 218 | ) 219 | 220 | expected_charges = [ 221 | q.m_as("e") for q in interchange.collections["Electrostatics"].charges.values() 222 | ] 223 | 224 | ff, [top] = smee.converters.convert_interchange(interchange) 225 | 226 | charge_pot = ff.potentials_by_type["Electrostatics"] 227 | 228 | assert charge_pot.attribute_cols == ( 229 | "scale_12", 230 | "scale_13", 231 | "scale_14", 232 | "scale_15", 233 | "cutoff", 234 | ) 235 | assert torch.allclose( 236 | charge_pot.attributes, 237 | torch.tensor([0.0000, 0.0000, 1.0 / 1.2, 1.0000, 9.0000], dtype=torch.float64), 238 | ) 239 | 240 | assert charge_pot.parameter_cols == ("charge",) 241 | 242 | found_keys = [ 243 | (key.associated_handler, key.id, key.mult) for key in charge_pot.parameter_keys 244 | ] 245 | expected_keys = [ 246 | ("ChargeModel", "[O:1]([H:2])[H:3]", 0), 247 | ("ChargeModel", "[O:1]([H:2])[H:3]", 1), 248 | ("ChargeModel", "[O:1]([H:2])[H:3]", 2), 249 | ("ChargeModel", "[#1:2]-[#8X2H2+0:1]-[#1:3] EP once", 0), 250 | ("ChargeModel", "[#1:2]-[#8X2H2+0:1]-[#1:3] EP once", 1), 251 | ("ChargeModel", "[#1:2]-[#8X2H2+0:1]-[#1:3] EP once", 2), 252 | ("ChargeIncrementModel", "[O:1]-[H:2]", 0), 253 | ] 254 | assert found_keys == expected_keys 255 | 256 | expected_charge_params = torch.tensor( 257 | [*mol.partial_charges.m_as("e"), 0.0, 0.53, 0.53, -0.1] 258 | ).reshape(-1, 1) 259 | assert torch.allclose(charge_pot.parameters, expected_charge_params) 260 | 261 | param_map = top.parameters["Electrostatics"] 262 | 263 | found_exclusions = sorted((i, j) for i, j in param_map.exclusions.tolist()) 264 | expected_exclusions = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)] 265 | assert found_exclusions == expected_exclusions 266 | 267 | expected_assignment = torch.tensor( 268 | [ 269 | [1, 0, 0, +1, +0, +0, +2], 270 | [0, 1, 0, +0, +1, +0, -1], 271 | [0, 0, 1, +0, +0, +1, -1], 272 | [0, 0, 0, -1, -1, -1, +0], 273 | ], 274 | dtype=torch.float64, 275 
| ) 276 | found_assignment = param_map.assignment_matrix.to_dense() 277 | 278 | assert found_assignment.shape == expected_assignment.shape 279 | assert torch.allclose(found_assignment, expected_assignment) 280 | 281 | found_charges = param_map.assignment_matrix @ charge_pot.parameters 282 | assert torch.allclose(found_charges.flatten(), torch.tensor(expected_charges)) 283 | 284 | 285 | def test_convert_vdw(ethanol, ethanol_interchange): 286 | vdw_collection = ethanol_interchange.collections["vdW"] 287 | 288 | potential, parameter_maps = convert_vdw( 289 | [vdw_collection], [ethanol.to_topology()], [None] 290 | ) 291 | 292 | assert potential.type == "vdW" 293 | assert potential.fn == smee.EnergyFn.VDW_LJ 294 | 295 | 296 | def test_convert_dexp(ethanol, test_data_dir): 297 | ff = openff.toolkit.ForceField( 298 | str(test_data_dir / "de-ff.offxml"), load_plugins=True 299 | ) 300 | 301 | interchange = openff.interchange.Interchange.from_smirnoff( 302 | ff, ethanol.to_topology() 303 | ) 304 | vdw_collection = interchange.collections["DoubleExponential"] 305 | 306 | potential, parameter_maps = convert_dexp( 307 | [vdw_collection], [ethanol.to_topology()], [None] 308 | ) 309 | 310 | assert potential.attribute_cols[-2:] == ("alpha", "beta") 311 | assert potential.parameter_cols == ("epsilon", "r_min") 312 | 313 | assert potential.type == "vdW" 314 | assert potential.fn == smee.EnergyFn.VDW_DEXP 315 | -------------------------------------------------------------------------------- /smee/tests/convertors/openff/test_openff.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import openff.interchange.models 4 | import openff.toolkit 5 | import openff.units 6 | import pytest 7 | import torch 8 | 9 | import smee 10 | import smee.tests.utils 11 | from smee.converters.openff._openff import ( 12 | _CONVERTERS, 13 | _convert_topology, 14 | _Converter, 15 | _resolve_conversion_order, 16 | convert_handlers, 17 | 
convert_interchange, 18 | smirnoff_parameter_converter, 19 | ) 20 | 21 | 22 | def test_parameter_converter(): 23 | smirnoff_parameter_converter("Dummy", {"parm-a": openff.units.unit.angstrom})( 24 | lambda x: None 25 | ) 26 | assert "Dummy" in _CONVERTERS 27 | assert "parm-a" in _CONVERTERS["Dummy"].units 28 | 29 | with pytest.raises(KeyError, match="A Dummy converter is already"): 30 | smirnoff_parameter_converter("Dummy", {})(lambda x: None) 31 | 32 | del _CONVERTERS["Dummy"] 33 | 34 | 35 | def test_convert_handler(ethanol, ethanol_interchange, mocker): 36 | # avoid already registered converter error 37 | importlib.import_module("smee.converters.openff.nonbonded") 38 | 39 | mock_deps = [(mocker.MagicMock(type="mock"), [mocker.MagicMock()])] 40 | mock_result = (mocker.MagicMock(), []) 41 | 42 | mock_convert = mocker.patch( 43 | "smee.tests.utils.mock_convert_fn_with_deps", 44 | autospec=True, 45 | return_value=mock_result, 46 | ) 47 | mocker.patch.dict(_CONVERTERS, {"vdW": _Converter(mock_convert, {}, ["mock"])}) 48 | 49 | handlers = [ethanol_interchange.collections["vdW"]] 50 | topologies = [ethanol.to_topology()] 51 | 52 | v_site = openff.interchange.models.VirtualSiteKey( 53 | orientation_atom_indices=(0, 1, 2), 54 | type="MonovalentLonePair", 55 | match="once", 56 | name="EP", 57 | ) 58 | v_site_maps = [ 59 | smee.VSiteMap([v_site], {v_site: ethanol.n_atoms}, torch.tensor([[0]])) 60 | ] 61 | 62 | result = convert_handlers(handlers, topologies, v_site_maps, mock_deps) 63 | 64 | mock_convert.assert_called_once_with( 65 | handlers, 66 | topologies=topologies, 67 | v_site_maps=v_site_maps, 68 | dependencies={"mock": mock_deps[0]}, 69 | ) 70 | assert result == [mock_result] 71 | 72 | 73 | def test_convert_topology(formaldehyde, mocker): 74 | parameters = mocker.MagicMock() 75 | v_sites = smee.VSiteMap([], {}, torch.tensor([])) 76 | constraints = smee.TensorConstraints(torch.tensor([1, 2]), torch.tensor([3.0])) 77 | 78 | topology = _convert_topology(formaldehyde, 
parameters, v_sites, constraints) 79 | 80 | assert topology.n_atoms == 4 81 | assert topology.n_bonds == 3 82 | 83 | expected_atomic_nums = torch.tensor([6, 8, 1, 1]) 84 | expected_formal_charges = torch.tensor([0, 0, 0, 0]) 85 | 86 | expected_bond_idxs = torch.tensor([[0, 1], [0, 2], [0, 3]]) 87 | expected_bond_orders = torch.tensor([2, 1, 1]) 88 | 89 | assert topology.atomic_nums.shape == expected_atomic_nums.shape 90 | assert torch.allclose(topology.atomic_nums, expected_atomic_nums) 91 | assert topology.formal_charges.shape == expected_formal_charges.shape 92 | assert torch.allclose(topology.formal_charges, expected_formal_charges) 93 | 94 | assert topology.bond_idxs.shape == expected_bond_idxs.shape 95 | assert torch.allclose(topology.bond_idxs, expected_bond_idxs) 96 | assert topology.bond_orders.shape == expected_bond_orders.shape 97 | assert torch.allclose(topology.bond_orders, expected_bond_orders) 98 | 99 | assert topology.parameters == parameters 100 | assert topology.v_sites == v_sites 101 | assert topology.constraints == constraints 102 | 103 | 104 | def test_resolve_conversion_order(mocker): 105 | mocker.patch.dict( 106 | _CONVERTERS, 107 | { 108 | "a": _Converter(mocker.MagicMock(), {}, ["c"]), 109 | "b": _Converter(mocker.MagicMock(), {}, []), 110 | "c": _Converter(mocker.MagicMock(), {}, ["b"]), 111 | }, 112 | ) 113 | 114 | order = _resolve_conversion_order(["a", "b", "c"]) 115 | assert order == ["b", "c", "a"] 116 | 117 | 118 | def test_convert_interchange(): 119 | force_field = openff.toolkit.ForceField() 120 | force_field.get_parameter_handler("Electrostatics") 121 | force_field.get_parameter_handler("vdW") 122 | 123 | constraint_handler = force_field.get_parameter_handler("Constraints") 124 | constraint_handler.add_parameter( 125 | {"smirks": "[Cl:1]-[H:2]", "distance": 0.2 * openff.units.unit.nanometer} 126 | ) 127 | 128 | charge_handler = force_field.get_parameter_handler("LibraryCharges") 129 | charge_handler.add_parameter( 130 | { 131 | 
"smirks": "[Cl:1]-[H:2]", 132 | "charge1": -0.75 * openff.units.unit.e, 133 | "charge2": 0.25 * openff.units.unit.e, 134 | } 135 | ) 136 | 137 | v_site_handler = force_field.get_parameter_handler("VirtualSites") 138 | v_site_handler.add_parameter( 139 | { 140 | "type": "BondCharge", 141 | "smirks": "[Cl:1]-[H:2]", 142 | "distance": 2.0 * openff.units.unit.angstrom, 143 | "match": "all_permutations", 144 | "charge_increment1": -0.25 * openff.units.unit.e, 145 | "charge_increment2": 0.5 * openff.units.unit.e, 146 | } 147 | ) 148 | 149 | molecule = openff.toolkit.Molecule.from_mapped_smiles("[Cl:2]-[H:1]") 150 | 151 | interchange = openff.interchange.Interchange.from_smirnoff( 152 | force_field, molecule.to_topology(), allow_nonintegral_charges=True 153 | ) 154 | 155 | tensor_force_field, tensor_topologies = convert_interchange(interchange) 156 | 157 | assert {*tensor_force_field.potentials_by_type} == {"vdW", "Electrostatics"} 158 | 159 | assert tensor_force_field.v_sites is not None 160 | assert len(tensor_force_field.v_sites.keys) == 1 161 | assert tensor_force_field.v_sites.keys[0].id == "[Cl:1]-[H:2] EP all_permutations" 162 | expected_parameters = torch.tensor([[2.0, torch.pi, 0.0]]) 163 | assert torch.allclose(tensor_force_field.v_sites.parameters, expected_parameters) 164 | assert len(tensor_force_field.v_sites.weights) == 1 165 | 166 | assert len(tensor_topologies) == 1 167 | tensor_topology = tensor_topologies[0] 168 | 169 | assert len(tensor_topology.v_sites.keys) == 1 170 | assert tensor_topology.v_sites.keys[0].type == "BondCharge" 171 | assert tensor_topology.v_sites.keys[0].orientation_atom_indices == (1, 0) 172 | 173 | assert tensor_topology.constraints is not None 174 | expected_constraint_idxs = torch.tensor([[0, 1]]) 175 | assert tensor_topology.constraints.idxs.shape == expected_constraint_idxs.shape 176 | assert torch.allclose(tensor_topology.constraints.idxs, expected_constraint_idxs) 177 | 178 | expected_constraint_distances = 
torch.tensor([2.0]) 179 | assert ( 180 | tensor_topology.constraints.distances.shape 181 | == expected_constraint_distances.shape 182 | ) 183 | assert torch.allclose( 184 | tensor_topology.constraints.distances, expected_constraint_distances 185 | ) 186 | 187 | 188 | def test_convert_interchange_multiple( 189 | ethanol_conformer, 190 | ethanol_interchange, 191 | formaldehyde_conformer, 192 | formaldehyde_interchange, 193 | ): 194 | force_field, topologies = convert_interchange( 195 | [ethanol_interchange, formaldehyde_interchange] 196 | ) 197 | assert len(topologies) == 2 198 | 199 | expected_potentials = { 200 | "Angles", 201 | "Bonds", 202 | "Electrostatics", 203 | "ImproperTorsions", 204 | "ProperTorsions", 205 | "vdW", 206 | } 207 | assert {*force_field.potentials_by_type} == expected_potentials 208 | 209 | expected_charge_keys = [ 210 | openff.interchange.models.PotentialKey( 211 | id="[O:1]([C:3]([C:2]([H:5])([H:6])[H:7])([H:8])[H:9])[H:4]", 212 | mult=0, 213 | associated_handler="ToolkitAM1BCCHandler", 214 | ), 215 | openff.interchange.models.PotentialKey( 216 | id="[C:1](=[O:2])([H:3])[H:4]", 217 | mult=0, 218 | associated_handler="ToolkitAM1BCCHandler", 219 | ), 220 | ] 221 | assert all( 222 | key in force_field.potentials_by_type["Electrostatics"].parameter_keys 223 | for key in expected_charge_keys 224 | ) 225 | 226 | expected_improper_keys = [ 227 | openff.interchange.models.PotentialKey( 228 | id="[*:1]~[#6X3:2](~[*:3])~[*:4]", 229 | mult=0, 230 | associated_handler="ImproperTorsions", 231 | ), 232 | ] 233 | assert ( 234 | force_field.potentials_by_type["ImproperTorsions"].parameter_keys 235 | == expected_improper_keys 236 | ) 237 | -------------------------------------------------------------------------------- /smee/tests/convertors/openff/test_valence.py: -------------------------------------------------------------------------------- 1 | import openff.interchange 2 | import openff.toolkit 3 | import pytest 4 | 5 | from 
smee.converters.openff.valence import ( 6 | convert_angles, 7 | convert_bonds, 8 | convert_impropers, 9 | convert_propers, 10 | ) 11 | 12 | 13 | def test_convert_bonds(ethanol, ethanol_interchange): 14 | bond_collection = ethanol_interchange.collections["Bonds"] 15 | 16 | potential, parameter_maps = convert_bonds([bond_collection], [set()]) 17 | 18 | assert potential.type == "Bonds" 19 | assert potential.fn == "k/2*(r-length)**2" 20 | 21 | assert potential.attributes is None 22 | assert potential.attribute_cols is None 23 | 24 | assert potential.parameter_cols == ("k", "length") 25 | 26 | parameter_keys = [key.id for key in potential.parameter_keys] 27 | expected_parameter_keys = [ 28 | "[#6:1]-[#8:2]", 29 | "[#6X4:1]-[#1:2]", 30 | "[#6X4:1]-[#6X4:2]", 31 | "[#8:1]-[#1:2]", 32 | ] 33 | assert sorted(parameter_keys) == sorted(expected_parameter_keys) 34 | 35 | assert potential.parameters.shape == (4, 2) 36 | 37 | assert len(parameter_maps) == 1 38 | parameter_map = parameter_maps[0] 39 | 40 | assert len(parameter_map.assignment_matrix) == len(parameter_map.particle_idxs) 41 | assignment_matrix = parameter_map.assignment_matrix.to_dense() 42 | 43 | actual_parameters = { 44 | tuple(particle_idxs.tolist()): parameter_keys[parameter_idxs.nonzero()] 45 | for parameter_idxs, particle_idxs in zip( 46 | assignment_matrix, parameter_map.particle_idxs, strict=True 47 | ) 48 | } 49 | expected_parameters = { 50 | (0, 2): "[#6:1]-[#8:2]", 51 | (0, 3): "[#8:1]-[#1:2]", 52 | (1, 2): "[#6X4:1]-[#6X4:2]", 53 | (1, 4): "[#6X4:1]-[#1:2]", 54 | (1, 5): "[#6X4:1]-[#1:2]", 55 | (1, 6): "[#6X4:1]-[#1:2]", 56 | (2, 7): "[#6X4:1]-[#1:2]", 57 | (2, 8): "[#6X4:1]-[#1:2]", 58 | } 59 | 60 | assert actual_parameters == expected_parameters 61 | 62 | 63 | def test_convert_bonds_with_constraints(ethanol): 64 | interchange = openff.interchange.Interchange.from_smirnoff( 65 | openff.toolkit.ForceField("openff-1.3.0.offxml"), ethanol.to_topology() 66 | ) 67 | 68 | bond_collection = 
interchange.collections["Bonds"] 69 | 70 | constraints = { 71 | (bond.atom1_index, bond.atom2_index) 72 | for bond in ethanol.bonds 73 | if bond.atom1.atomic_number == 1 or bond.atom2.atomic_number == 1 74 | } 75 | 76 | potential, [parameter_map] = convert_bonds([bond_collection], [constraints]) 77 | parameter_keys = [key.id for key in potential.parameter_keys] 78 | 79 | assert len(parameter_map.assignment_matrix) == len(parameter_map.particle_idxs) 80 | assignment_matrix = parameter_map.assignment_matrix.to_dense() 81 | 82 | actual_parameters = { 83 | tuple(particle_idxs.tolist()): parameter_keys[parameter_idxs.nonzero()] 84 | for parameter_idxs, particle_idxs in zip( 85 | assignment_matrix, parameter_map.particle_idxs, strict=True 86 | ) 87 | } 88 | expected_parameters = {(0, 2): "[#6:1]-[#8:2]", (1, 2): "[#6X4:1]-[#6X4:2]"} 89 | 90 | assert actual_parameters == expected_parameters 91 | 92 | 93 | @pytest.mark.parametrize("with_constraints", [True, False]) 94 | def test_convert_angles_etoh(ethanol, ethanol_interchange, with_constraints): 95 | angle_collection = ethanol_interchange.collections["Angles"] 96 | 97 | h_bond_idxs = { 98 | (bond.atom1_index, bond.atom2_index) 99 | for bond in ethanol.bonds 100 | if bond.atom1.atomic_number == 1 or bond.atom2.atomic_number == 1 101 | } 102 | constraints = set() if not with_constraints else h_bond_idxs 103 | 104 | potential, parameter_maps = convert_angles([angle_collection], [constraints]) 105 | 106 | assert potential.type == "Angles" 107 | assert potential.fn == "k/2*(theta-angle)**2" 108 | 109 | assert potential.attributes is None 110 | assert potential.attribute_cols is None 111 | 112 | assert potential.parameter_cols == ("k", "angle") 113 | 114 | parameter_keys = [key.id for key in potential.parameter_keys] 115 | expected_parameter_keys = [ 116 | "[#1:1]-[#6X4:2]-[#1:3]", 117 | "[*:1]-[#8:2]-[*:3]", 118 | "[*:1]~[#6X4:2]-[*:3]", 119 | ] 120 | assert sorted(parameter_keys) == sorted(expected_parameter_keys) 121 | 122 | 
assert potential.parameters.shape == (3, 2) 123 | 124 | assert len(parameter_maps) == 1 125 | parameter_map = parameter_maps[0] 126 | 127 | assert len(parameter_map.assignment_matrix) == len(parameter_map.particle_idxs) 128 | assignment_matrix = parameter_map.assignment_matrix.to_dense() 129 | 130 | actual_parameters = { 131 | tuple(particle_idxs.tolist()): parameter_keys[parameter_idxs.nonzero()] 132 | for parameter_idxs, particle_idxs in zip( 133 | assignment_matrix, parameter_map.particle_idxs, strict=True 134 | ) 135 | } 136 | expected_parameters = { 137 | (0, 2, 1): "[*:1]~[#6X4:2]-[*:3]", 138 | (0, 2, 7): "[*:1]~[#6X4:2]-[*:3]", 139 | (0, 2, 8): "[*:1]~[#6X4:2]-[*:3]", 140 | (1, 2, 7): "[*:1]~[#6X4:2]-[*:3]", 141 | (1, 2, 8): "[*:1]~[#6X4:2]-[*:3]", 142 | (2, 0, 3): "[*:1]-[#8:2]-[*:3]", 143 | (2, 1, 4): "[*:1]~[#6X4:2]-[*:3]", 144 | (2, 1, 5): "[*:1]~[#6X4:2]-[*:3]", 145 | (2, 1, 6): "[*:1]~[#6X4:2]-[*:3]", 146 | (4, 1, 5): "[#1:1]-[#6X4:2]-[#1:3]", 147 | (4, 1, 6): "[#1:1]-[#6X4:2]-[#1:3]", 148 | (5, 1, 6): "[#1:1]-[#6X4:2]-[#1:3]", 149 | (7, 2, 8): "[#1:1]-[#6X4:2]-[#1:3]", 150 | } 151 | 152 | assert actual_parameters == expected_parameters 153 | 154 | 155 | @pytest.mark.parametrize("with_constraints", [True, False]) 156 | def test_convert_angle_water(with_constraints): 157 | interchange = openff.interchange.Interchange.from_smirnoff( 158 | openff.toolkit.ForceField("openff-1.3.0.offxml"), 159 | openff.toolkit.Molecule.from_mapped_smiles("[O:1]([H:2])[H:3]").to_topology(), 160 | ) 161 | 162 | angle_collection = interchange.collections["Angles"] 163 | 164 | constraints = set() if not with_constraints else {(0, 1), (0, 2), (1, 2)} 165 | 166 | potential, [parameter_map] = convert_angles([angle_collection], [constraints]) 167 | parameter_keys = [key.id for key in potential.parameter_keys] 168 | 169 | assert len(parameter_map.assignment_matrix) == len(parameter_map.particle_idxs) 170 | assignment_matrix = parameter_map.assignment_matrix.to_dense() 171 | 172 | 
actual_parameters = { 173 | tuple(particle_idxs.tolist()): parameter_keys[parameter_idxs.nonzero()] 174 | for parameter_idxs, particle_idxs in zip( 175 | assignment_matrix, parameter_map.particle_idxs, strict=True 176 | ) 177 | } 178 | expected_parameters = {} if with_constraints else {(1, 0, 2): "[*:1]-[#8:2]-[*:3]"} 179 | 180 | assert actual_parameters == expected_parameters 181 | 182 | 183 | def test_convert_propers(ethanol, ethanol_interchange): 184 | proper_collection = ethanol_interchange.collections["ProperTorsions"] 185 | 186 | potential, parameter_maps = convert_propers([proper_collection]) 187 | 188 | assert potential.type == "ProperTorsions" 189 | assert potential.fn == "k*(1+cos(periodicity*theta-phase))" 190 | 191 | hcco_smirks = "[#1:1]-[#6X4:2]-[#6X4:3]-[#8X2:4]" 192 | ccoh_smirks = "[#6X4:1]-[#6X4:2]-[#8X2H1:3]-[#1:4]" 193 | xcoh_smirks = "[*:1]-[#6X4:2]-[#8X2:3]-[#1:4]" 194 | hcch_smirks = "[#1:1]-[#6X4:2]-[#6X4:3]-[#1:4]" 195 | 196 | assert len(parameter_maps) == 1 197 | parameter_map = parameter_maps[0] 198 | 199 | assert len(parameter_map.assignment_matrix) == len(parameter_map.particle_idxs) 200 | assignment_matrix = parameter_map.assignment_matrix.to_dense() 201 | 202 | actual_parameters = { 203 | ( 204 | tuple(particle_idxs.tolist()), 205 | potential.parameter_keys[parameter_idxs.nonzero()].id, 206 | potential.parameter_keys[parameter_idxs.nonzero()].mult, 207 | ) 208 | for parameter_idxs, particle_idxs in zip( 209 | assignment_matrix, parameter_map.particle_idxs, strict=True 210 | ) 211 | } 212 | expected_parameters = { 213 | ((0, 2, 1, 4), hcco_smirks, 1), 214 | ((0, 2, 1, 4), hcco_smirks, 0), 215 | ((0, 2, 1, 5), hcco_smirks, 1), 216 | ((0, 2, 1, 5), hcco_smirks, 0), 217 | ((0, 2, 1, 6), hcco_smirks, 1), 218 | ((0, 2, 1, 6), hcco_smirks, 0), 219 | ((1, 2, 0, 3), ccoh_smirks, 1), 220 | ((1, 2, 0, 3), ccoh_smirks, 0), 221 | ((3, 0, 2, 7), xcoh_smirks, 0), 222 | ((3, 0, 2, 8), xcoh_smirks, 0), 223 | ((4, 1, 2, 7), hcch_smirks, 0), 224 | ((4, 
1, 2, 8), hcch_smirks, 0), 225 | ((5, 1, 2, 7), hcch_smirks, 0), 226 | ((5, 1, 2, 8), hcch_smirks, 0), 227 | ((6, 1, 2, 7), hcch_smirks, 0), 228 | ((6, 1, 2, 8), hcch_smirks, 0), 229 | } 230 | assert actual_parameters == expected_parameters 231 | 232 | 233 | def test_convert_impropers(formaldehyde, formaldehyde_interchange): 234 | improper_collection = formaldehyde_interchange.collections["ImproperTorsions"] 235 | 236 | potential, parameter_maps = convert_impropers([improper_collection]) 237 | 238 | assert potential.type == "ImproperTorsions" 239 | assert potential.fn == "k*(1+cos(periodicity*theta-phase))" 240 | -------------------------------------------------------------------------------- /smee/tests/convertors/openmm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SimonBoothroyd/smee/aca69b9da4c67916c6e59ed2c435fffd4c49a2b6/smee/tests/convertors/openmm/__init__.py -------------------------------------------------------------------------------- /smee/tests/convertors/openmm/test_ff.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import openff.interchange 3 | import openff.toolkit 4 | import openmm.app 5 | import openmm.unit 6 | import pytest 7 | 8 | import smee 9 | import smee.converters 10 | import smee.converters.openmm 11 | 12 | 13 | def compute_energy(system: openmm.System, coords: numpy.ndarray) -> float: 14 | ctx = openmm.Context(system, openmm.VerletIntegrator(0.01)) 15 | ctx.setPositions(coords * openmm.unit.angstrom) 16 | 17 | state = ctx.getState(getEnergy=True) 18 | return state.getPotentialEnergy().value_in_unit(openmm.unit.kilojoule_per_mole) 19 | 20 | 21 | def compute_v_site_coords( 22 | system: openmm.System, coords: numpy.ndarray 23 | ) -> numpy.ndarray: 24 | ctx = openmm.Context(system, openmm.VerletIntegrator(0.01)) 25 | ctx.setPositions(coords * openmm.unit.angstrom) 26 | ctx.computeVirtualSites() 27 | 28 | 
state = ctx.getState(getPositions=True) 29 | return state.getPositions(asNumpy=True).value_in_unit(openmm.unit.angstrom) 30 | 31 | 32 | @pytest.mark.parametrize("with_constraints", [True, False]) 33 | @pytest.mark.parametrize("smiles", ["CO", "C=O", "Oc1ccccc1"]) 34 | def test_convert_to_openmm_ffxml(tmp_cwd, with_constraints, smiles): 35 | off_ff = openff.toolkit.ForceField( 36 | "openff-2.0.0.offxml" 37 | if with_constraints 38 | else "openff_unconstrained-2.0.0.offxml" 39 | ) 40 | 41 | off_mol = openff.toolkit.Molecule.from_smiles(smiles) 42 | off_top = off_mol.to_topology() 43 | 44 | interchange = openff.interchange.Interchange.from_smirnoff(off_ff, off_top) 45 | ff, [top] = smee.converters.convert_interchange(interchange) 46 | 47 | [ffxml] = smee.converters.openmm._ff.convert_to_openmm_ffxml(ff, top) 48 | 49 | ffxml_path = tmp_cwd / "ff.xml" 50 | ffxml_path.write_text(ffxml) 51 | 52 | omm_ff = openmm.app.ForceField(str(ffxml_path)) 53 | 54 | system_from_xml = omm_ff.createSystem( 55 | smee.converters.convert_to_openmm_topology(top), 56 | nonbondedCutoff=9.0 * openmm.unit.angstrom, 57 | switchDistance=8.0 * openmm.unit.angstrom, 58 | constraints=openmm.app.HBonds if with_constraints else None, 59 | rigidWater=True, 60 | removeCMMotion=False, 61 | ) 62 | system_from_off = off_ff.create_openmm_system(off_top) 63 | 64 | assert system_from_xml.getNumParticles() == system_from_off.getNumParticles() 65 | assert system_from_xml.getNumForces() == system_from_off.getNumForces() 66 | 67 | assert system_from_xml.getNumConstraints() == system_from_off.getNumConstraints() 68 | 69 | constraints_from_off = {} 70 | 71 | for i in range(system_from_off.getNumConstraints()): 72 | idx_a, idx_b, dist = system_from_off.getConstraintParameters(i) 73 | constraints_from_off[idx_a, idx_b] = dist.value_in_unit(openmm.unit.nanometer) 74 | 75 | constraints_from_xml = {} 76 | 77 | for i in range(system_from_xml.getNumConstraints()): 78 | idx_a, idx_b, dist = 
system_from_xml.getConstraintParameters(i) 79 | constraints_from_xml[idx_a, idx_b] = dist.value_in_unit(openmm.unit.nanometer) 80 | 81 | assert constraints_from_xml == pytest.approx(constraints_from_off) 82 | 83 | off_mol.generate_conformers(n_conformers=1) 84 | coords = off_mol.conformers[0].m_as("angstrom") 85 | 86 | for _ in range(5): 87 | coords_rand = coords + numpy.random.randn(*coords.shape) * 0.1 88 | 89 | energy_off = compute_energy(system_from_off, coords_rand) 90 | energy_xml = compute_energy(system_from_xml, coords_rand) 91 | 92 | assert energy_off == pytest.approx(energy_xml, abs=1.0e-3) 93 | 94 | 95 | def test_convert_to_openmm_ffxml_v_sites(tmp_cwd): 96 | off_ff = openff.toolkit.ForceField("tip5p.offxml") 97 | 98 | off_mol = openff.toolkit.Molecule.from_smiles("O") 99 | off_top = off_mol.to_topology() 100 | 101 | interchange = openff.interchange.Interchange.from_smirnoff(off_ff, off_top) 102 | ff, [top] = smee.converters.convert_interchange(interchange) 103 | 104 | [ffxml] = smee.converters.openmm._ff.convert_to_openmm_ffxml(ff, top) 105 | 106 | ffxml_path = tmp_cwd / "ff.xml" 107 | ffxml_path.write_text(ffxml) 108 | 109 | omm_ff = openmm.app.ForceField(str(ffxml_path)) 110 | 111 | system_from_xml = omm_ff.createSystem( 112 | smee.converters.convert_to_openmm_topology(top), 113 | nonbondedCutoff=9.0 * openmm.unit.angstrom, 114 | switchDistance=8.0 * openmm.unit.angstrom, 115 | constraints=openmm.app.HBonds, 116 | rigidWater=True, 117 | removeCMMotion=False, 118 | ) 119 | system_from_off = off_ff.create_openmm_system(off_top) 120 | 121 | assert system_from_xml.getNumParticles() == system_from_off.getNumParticles() 122 | 123 | off_mol.generate_conformers(n_conformers=1) 124 | 125 | coords = off_mol.conformers[0].m_as("angstrom") 126 | coords = numpy.vstack([coords, numpy.zeros((2, 3))]) 127 | 128 | coords_from_xml = compute_v_site_coords(system_from_xml, coords) 129 | coords_from_off = compute_v_site_coords(system_from_off, coords) 130 | 131 | assert 
coords_from_xml.shape == coords_from_off.shape 132 | assert numpy.allclose(coords_from_xml, coords_from_off, atol=1.0e-3) 133 | 134 | params_from_xml = {} 135 | [nb_force_from_xml] = [ 136 | force 137 | for force in system_from_xml.getForces() 138 | if isinstance(force, openmm.NonbondedForce) 139 | ] 140 | 141 | for i in range(nb_force_from_xml.getNumParticles()): 142 | charge, sigma, epsilon = nb_force_from_xml.getParticleParameters(i) 143 | params_from_xml[i] = ( 144 | charge.value_in_unit(openmm.unit.elementary_charge), 145 | sigma.value_in_unit(openmm.unit.angstrom), 146 | epsilon.value_in_unit(openmm.unit.kilojoule_per_mole), 147 | ) 148 | 149 | params_from_off = {} 150 | [nb_force_from_off] = [ 151 | force 152 | for force in system_from_off.getForces() 153 | if isinstance(force, openmm.NonbondedForce) 154 | ] 155 | 156 | for i in range(nb_force_from_off.getNumParticles()): 157 | charge, sigma, epsilon = nb_force_from_off.getParticleParameters(i) 158 | params_from_off[i] = ( 159 | charge.value_in_unit(openmm.unit.elementary_charge), 160 | sigma.value_in_unit(openmm.unit.angstrom), 161 | epsilon.value_in_unit(openmm.unit.kilojoule_per_mole), 162 | ) 163 | 164 | assert len(params_from_xml) == len(params_from_off) 165 | 166 | for i in range(nb_force_from_off.getNumParticles()): 167 | assert params_from_xml[i][0] == pytest.approx(params_from_off[i][0]) 168 | assert params_from_xml[i][1] == pytest.approx(params_from_off[i][1]) 169 | assert params_from_xml[i][2] == pytest.approx(params_from_off[i][2]) 170 | -------------------------------------------------------------------------------- /smee/tests/convertors/openmm/test_openmm.py: -------------------------------------------------------------------------------- 1 | import numpy.random 2 | import openff.interchange 3 | import openff.toolkit 4 | import openff.units 5 | import openmm 6 | import pytest 7 | import torch 8 | 9 | import smee 10 | import smee.mm 11 | import smee.potentials 12 | import smee.tests.utils 13 | 
from smee.converters.openmm import (
    convert_to_openmm_force,
    convert_to_openmm_system,
    convert_to_openmm_topology,
    create_openmm_system,
)


def _compute_energy(
    system: openmm.System,
    coords: openmm.unit.Quantity,
    box_vectors: openmm.unit.Quantity | None,
) -> float:
    """Evaluate the potential energy of an OpenMM system on the Reference platform.

    Args:
        system: The system to evaluate. When ``box_vectors`` is given, the
            system's default periodic box vectors are overwritten in place.
        coords: The particle positions.
        box_vectors: The periodic box vectors, or ``None`` for a non-periodic
            evaluation.

    Returns:
        The potential energy in [kcal / mol].
    """
    if box_vectors is not None:
        system.setDefaultPeriodicBoxVectors(*box_vectors)

    integrator = openmm.VerletIntegrator(1.0 * openmm.unit.femtoseconds)
    context = openmm.Context(
        system, integrator, openmm.Platform.getPlatformByName("Reference")
    )

    if box_vectors is not None:
        context.setPeriodicBoxVectors(*box_vectors)

    context.setPositions(coords)

    state = context.getState(getEnergy=True)

    return state.getPotentialEnergy().value_in_unit(openmm.unit.kilocalories_per_mole)


def _compare_smee_and_interchange(
    tensor_ff: smee.TensorForceField,
    tensor_system: smee.TensorSystem,
    interchange: openff.interchange.Interchange,
    coords: openmm.unit.Quantity,
    box_vectors: openmm.unit.Quantity | None,
    seed: int = 0,
):
    """Assert that the smee and Interchange OpenMM systems give the same energy.

    Both systems are evaluated at the same randomly perturbed copy of ``coords``.

    Args:
        tensor_ff: The smee force field to convert.
        tensor_system: The smee system to convert.
        interchange: The Interchange object to compare against.
        coords: The reference coordinates. The caller's object is not rebound or
            modified by this function.
        box_vectors: The periodic box vectors, or ``None`` if non-periodic.
        seed: Seed of the random perturbation, so a failing comparison can be
            reproduced exactly.
    """
    system_smee = convert_to_openmm_system(tensor_ff, tensor_system)
    assert isinstance(system_smee, openmm.System)
    system_interchange = interchange.to_openmm(False, False)

    # Jitter every position by N(0, 0.1 Å) using a *seeded* generator (the global
    # numpy RNG made failures irreproducible), and rebind rather than use ``+=``.
    rng = numpy.random.default_rng(seed)
    noise = rng.standard_normal(coords.shape) * 0.1
    coords = coords + noise * openmm.unit.angstrom

    energy_smee = _compute_energy(system_smee, coords, box_vectors)
    energy_interchange = _compute_energy(system_interchange, coords, box_vectors)

    assert numpy.isclose(energy_smee, energy_interchange)


def test_create_openmm_system_v_sites(v_site_force_field):
    """smee should build the same v-sites as Interchange.

    Interchange appends all v-sites at the end of the system, whereas smee places
    each molecule's v-sites directly after that molecule's atoms.
    """
    smiles = [
        "[H:3][C:2]([H:4])=[O:1]",
        "[Cl:3][C:2]([H:4])=[O:1]",
        "[H:2][O:1][H:3]",
        "[H:2][N:1]([H:3])[H:4]",
    ]

    interchange_full = openff.interchange.Interchange.from_smirnoff(
        v_site_force_field,
        openff.toolkit.Topology.from_molecules(
            [openff.toolkit.Molecule.from_mapped_smiles(pattern) for pattern in smiles]
        ),
    )

    system_interchange = interchange_full.to_openmm()
    n_particles = system_interchange.getNumParticles()

    force_field, topologies = smee.converters.convert_interchange(
        [
            openff.interchange.Interchange.from_smirnoff(
                v_site_force_field,
                openff.toolkit.Molecule.from_mapped_smiles(pattern).to_topology(),
            )
            for pattern in smiles
        ]
    )

    system_smee = create_openmm_system(
        smee.TensorSystem(topologies, [1] * len(smiles), False), force_field.v_sites
    )

    expected_v_site_idxs = [4, 9, 13, 14, 19]
    actual_v_site_idxs = [
        i for i in range(system_smee.getNumParticles()) if system_smee.isVirtualSite(i)
    ]
    assert actual_v_site_idxs == expected_v_site_idxs

    n_v_sites = len(expected_v_site_idxs)

    v_sites_interchange = [
        # interchange puts all v-sites at the end of a topology
        system_interchange.getVirtualSite(n_particles - n_v_sites + i)
        for i in range(n_v_sites)
    ]
    v_sites_smee = [system_smee.getVirtualSite(i) for i in expected_v_site_idxs]

    def compare_vec3(a: openmm.Vec3, b: openmm.Vec3):
        """Assert two unit-bearing 3-vectors share a unit and are numerically close."""
        assert a.unit == b.unit
        assert numpy.allclose(
            numpy.array([*a.value_in_unit(a.unit)]),
            numpy.array([*b.value_in_unit(a.unit)]),
            atol=1.0e-5,
        )

    expected_particle_idxs = [
        [0, 1],
        [5, 6, 8],
        [10, 11, 12],
        [10, 12, 11],
        [15, 16, 17, 18],
    ]

    for i, (v_site_interchange, v_site_smee) in enumerate(
        zip(v_sites_interchange, v_sites_smee, strict=True)
    ):
        assert v_site_smee.getNumParticles() == v_site_interchange.getNumParticles()

        # ``j`` (not ``i``) so the outer v-site index is not shadowed.
        particles_smee = [
            v_site_smee.getParticle(j) for j in range(v_site_smee.getNumParticles())
        ]
        assert particles_smee == expected_particle_idxs[i]

        compare_vec3(
            v_site_smee.getLocalPosition(), v_site_interchange.getLocalPosition()
        )
        assert v_site_smee.getOriginWeights() == pytest.approx(
            v_site_interchange.getOriginWeights()
        )
        assert v_site_smee.getXWeights() == pytest.approx(
            v_site_interchange.getXWeights()
        )
        assert v_site_smee.getYWeights() == pytest.approx(
            v_site_interchange.getYWeights()
        )


@pytest.mark.parametrize("with_constraints", [True, False])
def test_convert_to_openmm_system_vacuum(with_constraints):
    """A single molecule in vacuum should have the same energy as via Interchange."""
    # carbonic acid has impropers, 1-5 interactions so should test most convertors
    mol = openff.toolkit.Molecule.from_smiles("OC(=O)O")
    mol.generate_conformers(n_conformers=1)

    coords = mol.conformers[0].m_as(openff.units.unit.angstrom)
    coords = coords * openmm.unit.angstrom

    force_field = openff.toolkit.ForceField(
        "openff-2.0.0.offxml"
        if with_constraints
        else "openff_unconstrained-2.0.0.offxml"
    )
    interchange = openff.interchange.Interchange.from_smirnoff(
        force_field, mol.to_topology()
    )

    tensor_ff, [tensor_top] = smee.converters.convert_interchange(interchange)

    _compare_smee_and_interchange(tensor_ff, tensor_top, interchange, coords, None)


@pytest.mark.parametrize("with_constraints", [True, False])
def test_convert_to_openmm_system_periodic(with_constraints):
    """A periodic carbonic acid / water box should match the Interchange energy."""
    ff = openff.toolkit.ForceField(
        "openff-2.0.0.offxml"
        if with_constraints
        else "openff_unconstrained-2.0.0.offxml"
    )
    top = openff.toolkit.Topology()

    interchanges = []

    n_copies_per_mol = [5, 5]

    # carbonic acid has impropers, 1-5 interactions so should test most convertors
    for smiles, n_copies in zip(["OC(=O)O", "O"], n_copies_per_mol, strict=True):
        mol = openff.toolkit.Molecule.from_smiles(smiles)
        mol.generate_conformers(n_conformers=1)

        interchange = openff.interchange.Interchange.from_smirnoff(
            ff, mol.to_topology()
        )
        interchanges.append(interchange)

        for _ in range(n_copies):
            top.add_molecule(mol)

    tensor_ff, tensor_tops = smee.converters.convert_interchange(interchanges)
    tensor_system = smee.TensorSystem(tensor_tops, n_copies_per_mol, True)

    coords, _ = smee.mm.generate_system_coords(
        tensor_system, None, smee.mm.GenerateCoordsConfig()
    )
    box_vectors = numpy.eye(3) * 20.0 * openmm.unit.angstrom

    top.box_vectors = box_vectors

    interchange_top = openff.interchange.Interchange.from_smirnoff(ff, top)

    _compare_smee_and_interchange(
        tensor_ff, tensor_system, interchange_top, coords, box_vectors
    )


@pytest.mark.parametrize("with_exception", [True, False])
def test_convert_lj_potential_with_exceptions(with_exception):
    """An LJ potential with (or without) exceptions should convert to a custom
    nonbonded force plus a custom bond force whose OpenMM energy matches smee's.
    """
    system, vdw_potential, _ = smee.tests.utils.system_with_exceptions()

    vdw_potential.exceptions = {} if not with_exception else vdw_potential.exceptions

    forces = convert_to_openmm_force(vdw_potential, system)

    assert len(forces) == 2

    assert isinstance(forces[0], openmm.CustomNonbondedForce)
    assert isinstance(forces[1], openmm.CustomBondForce)

    coords = torch.tensor(
        [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]]
    )
    expected_energy = smee.compute_energy_potential(system, vdw_potential, coords)

    omm_system = openmm.System()
    for _ in range(system.n_atoms):
        omm_system.addParticle(1.0)
    for force in forces:
        omm_system.addForce(force)

    context = openmm.Context(
        omm_system,
        openmm.VerletIntegrator(1.0),
        openmm.Platform.getPlatformByName("Reference"),
    )
    context.setPositions(coords.numpy() * openmm.unit.angstrom)

    energy = (
        context.getState(getEnergy=True)
        .getPotentialEnergy()
        .value_in_unit(openmm.unit.kilocalorie_per_mole)
    )
    assert torch.isclose(
        torch.tensor(energy, dtype=expected_energy.dtype), expected_energy
    )


def test_convert_to_openmm_system_dexp_periodic(test_data_dir):
    """A periodic system using the plugin force field in ``de-ff.offxml`` should
    match the Interchange energy.
    """
    ff = openff.toolkit.ForceField(
        str(test_data_dir / "de-ff.offxml"), load_plugins=True
    )
    top = openff.toolkit.Topology()

    interchanges = []

    n_copies_per_mol = [5, 5]

    for smiles, n_copies in zip(["OCCO", "O"], n_copies_per_mol, strict=True):
        mol = openff.toolkit.Molecule.from_smiles(smiles)
        mol.generate_conformers(n_conformers=1)

        interchange = openff.interchange.Interchange.from_smirnoff(
            ff, mol.to_topology()
        )
        interchanges.append(interchange)

        for _ in range(n_copies):
            top.add_molecule(mol)

    tensor_ff, tensor_tops = smee.converters.convert_interchange(interchanges)
    tensor_system = smee.TensorSystem(tensor_tops, n_copies_per_mol, True)

    coords, _ = smee.mm.generate_system_coords(
        tensor_system, None, smee.mm.GenerateCoordsConfig()
    )
    box_vectors = numpy.eye(3) * 20.0 * openmm.unit.angstrom

    top.box_vectors = box_vectors

    interchange_top = openff.interchange.Interchange.from_smirnoff(ff, top)

    _compare_smee_and_interchange(
        tensor_ff, tensor_system, interchange_top, coords, box_vectors
    )


def test_convert_to_openmm_topology():
    """One formaldehyde plus two waters should give 2 chains / 3 residues with the
    expected atom names and bonds.
    """
    formaldehyde_interchange = openff.interchange.Interchange.from_smirnoff(
        openff.toolkit.ForceField("openff-2.0.0.offxml"),
        openff.toolkit.Molecule.from_smiles("C=O").to_topology(),
    )
    water_interchange = openff.interchange.Interchange.from_smirnoff(
        openff.toolkit.ForceField("openff-2.0.0.offxml"),
        openff.toolkit.Molecule.from_smiles("O").to_topology(),
    )

    # The force field itself is not needed, only the topologies.
    _, [formaldehyde_top, water_top] = smee.converters.convert_interchange(
        [formaldehyde_interchange, water_interchange]
    )
    tensor_system = smee.TensorSystem([formaldehyde_top, water_top], [1, 2], True)

    openmm_topology = convert_to_openmm_topology(tensor_system)

    assert openmm_topology.getNumChains() == 2
    assert openmm_topology.getNumResidues() == 3  # 1 formaldehyde, 2 water

    residue_names = [residue.name for residue in openmm_topology.residues()]
    assert residue_names == ["UNK", "HOH", "HOH"]

    atom_names = [atom.name for atom in openmm_topology.atoms()]
    expected_atom_names = [
        "C",
        "O",
        "H1",
        "H2",
        "O",
        "H1",
        "H2",
        "O",
        "H1",
        "H2",
    ]
    assert atom_names == expected_atom_names

    bond_idxs = [
        (bond.atom1.index, bond.atom2.index, bond.order)
        for bond in openmm_topology.bonds()
    ]
    expected_bond_idxs = [
        (0, 1, 2),
        (0, 2, 1),
        (0, 3, 1),
        (4, 5, 1),
        (4, 6, 1),
        (7, 8, 1),
        (7, 9, 1),
    ]
    assert bond_idxs == expected_bond_idxs


# --------------------------------------------------------------------------------
# /smee/tests/data/de-ff.offxml:
# --------------------------------------------------------------------------------
# 1 | 2 | 3 | Test 4 | 2024-02-16 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 |
# --------------------------------------------------------------------------------
# /smee/tests/mm/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/SimonBoothroyd/smee/aca69b9da4c67916c6e59ed2c435fffd4c49a2b6/smee/tests/mm/__init__.py
# --------------------------------------------------------------------------------
# /smee/tests/mm/conftest.py:
# --------------------------------------------------------------------------------
import openff.interchange
import openff.toolkit
import openff.units
import pytest

import smee
import smee.converters


@pytest.fixture()
def mock_argon_params() -> tuple[openff.units.Quantity, openff.units.Quantity]:
    """Return mock argon vdW parameters as ``(epsilon, sigma)``."""
    return (
        0.1 * openff.units.unit.kilojoules / openff.units.unit.mole,
        3.0 * openff.units.unit.angstrom,
    )


@pytest.fixture()
def mock_argon_ff(mock_argon_params) -> openff.toolkit.ForceField:
    """Build a minimal force field with a zero library charge and the mock vdW
    parameters for argon.
    """
    epsilon, sigma = mock_argon_params

    ff = openff.toolkit.ForceField()
    ff.get_parameter_handler("Electrostatics")
    ff.get_parameter_handler("LibraryCharges").add_parameter(
        {
            "smirks": "[Ar:1]",
            "charge1": 0.0 * openff.units.unit.elementary_charge,
        }
    )
    ff.get_parameter_handler("vdW").add_parameter(
        {"smirks": "[Ar:1]", "epsilon": epsilon, "sigma": sigma}
    )
    return ff


@pytest.fixture()
def mock_argon_tensors(
    mock_argon_ff,
) -> tuple[smee.TensorForceField, smee.TensorTopology]:
    """Convert a single argon atom parameterized by ``mock_argon_ff`` into smee
    tensors, keeping only the vdW potential.
    """
    interchange = openff.interchange.Interchange.from_smirnoff(
        mock_argon_ff, openff.toolkit.Molecule.from_smiles("[Ar]").to_topology()
    )
    tensor_ff, [tensor_top] = smee.converters.convert_interchange(interchange)
    tensor_ff.potentials = [p for p in tensor_ff.potentials if p.type == "vdW"]

    return tensor_ff, tensor_top


# --------------------------------------------------------------------------------
# /smee/tests/mm/test_fe.py:
# --------------------------------------------------------------------------------
import pathlib

import openff.interchange
import openff.toolkit
import openff.units
import openmm.unit
import pytest
import torch

import smee.converters
import smee.mm
import smee.mm._fe


def load_systems(
    solute: str, solvent: str
) -> tuple[smee.TensorTopology, smee.TensorTopology, smee.TensorForceField]:
    """Parameterize a solute and a solvent and convert them to smee tensors.

    The force field is OpenFF 2.0.0 plus a ``DivalentLonePair`` v-site on
    nitrogens matching ``[*:2][#7:1][*:3]``. The v-site has zero epsilon and zero
    charge increments.

    Args:
        solute: The SMILES of the solute.
        solvent: The SMILES of the solvent.

    Returns:
        The solute topology, the solvent topology, and the shared force field.
    """
    ff_off = openff.toolkit.ForceField("openff-2.0.0.offxml")

    v_site_handler = ff_off.get_parameter_handler("VirtualSites")
    v_site_handler.add_parameter(
        {
            "type": "DivalentLonePair",
            "match": "once",
            "smirks": "[*:2][#7:1][*:3]",
            "distance": 0.4 * openff.units.unit.angstrom,
            "epsilon": 0.0 * openff.units.unit.kilojoule_per_mole,
            "sigma": 0.1 * openff.units.unit.nanometer,
            "outOfPlaneAngle": 0.0 * openff.units.unit.degree,
            "charge_increment1": 0.0 * openff.units.unit.elementary_charge,
            "charge_increment2": 0.0 * openff.units.unit.elementary_charge,
            "charge_increment3": 0.0 * openff.units.unit.elementary_charge,
        }
    )

    solute_inter = openff.interchange.Interchange.from_smirnoff(
        ff_off,
        openff.toolkit.Molecule.from_smiles(solute).to_topology(),
    )
    solvent_inter = openff.interchange.Interchange.from_smirnoff(
        ff_off,
        openff.toolkit.Molecule.from_smiles(solvent).to_topology(),
    )
    # NOTE(review): the result is discarded; presumably this only checks that the
    # solvent can be exported to OpenMM - confirm, or remove the call.
    solvent_inter.to_openmm_system()

    ff, (top_solute, top_solvent) = smee.converters.convert_interchange(
        [solute_inter, solvent_inter]
    )

    return top_solute, top_solvent, ff


def test_extract_pure_solvent(tmp_cwd, mocker):
    """``_extract_pure_solvent`` should drop the solute particles from mocked
    samples and keep only the solvent coordinates.
    """
    top_solute, top_solvent, ff = load_systems("c1ccncc1", "O")

    system = smee.TensorSystem([top_solute, top_solvent], [1, 10], True)
    xyz, box = smee.mm.generate_system_coords(system, ff)

    xyz = torch.tensor(xyz.value_in_unit(openmm.unit.angstrom)).unsqueeze(0)
    box = torch.tensor(box.value_in_unit(openmm.unit.angstrom)).unsqueeze(0) * 10.0

    mocker.patch(
        "smee.mm._fe._load_samples",
        return_value=(system, None, None, None, None, xyz, box),
    )

    xyz_solv, _, _ = smee.mm._fe._extract_pure_solvent(
        top_solute, top_solvent, ff, tmp_cwd
    )

    # 10 waters x 3 atoms remain. The first 12 particles are the solute
    # (pyridine's 11 atoms plus its nitrogen lone-pair v-site).
    assert xyz_solv.shape == (1, 30, 3)
    assert torch.allclose(xyz_solv, xyz[:, 12:, :])


@pytest.mark.fe
def test_fe_ops(tmp_cwd):
    """Regression test for the solvation free energy of ethanol in water and its
    gradient w.r.t. the electrostatics parameters.

    Both direct computation and reweighting are checked against reference values.
    """
    # taken from a run on commit ec3d272b466f761ed838e16a5ba7b97ceadc463b
    expected_dg = torch.tensor(-3.8262).double()
    expected_dg_dtheta = torch.tensor(
        [
            [10.2679],
            [13.3933],
            [25.3670],
            [9.3747],
            [9.3279],
            [9.1520],
            [10.5614],
            [9.6908],
            [-4.4326],
            [-17.3971],
            [-38.5407],
        ]
    ).double()

    top_solute, top_solvent, ff = load_systems("CCO", "O")

    output_dir = pathlib.Path("CCO")
    output_dir.mkdir(parents=True, exist_ok=True)

    smee.mm.generate_dg_solv_data(
        top_solute, None, top_solvent, ff, output_dir=output_dir
    )

    params = ff.potentials_by_type["Electrostatics"].parameters
    params.requires_grad_(True)

    dg_comp = smee.mm.compute_dg_solv(top_solute, None, top_solvent, ff, output_dir)
    dg_comp_dtheta = torch.autograd.grad(dg_comp, params)[0]

    print("dg COMP", dg_comp, flush=True)
    print("dg_dtheta COMP", dg_comp_dtheta, flush=True)

    # The effective sample size returned alongside dG is not checked here.
    dg_rw, _ = smee.mm.reweight_dg_solv(
        top_solute, None, top_solvent, ff, output_dir, dg_comp
    )
    dg_rw_dtheta = torch.autograd.grad(dg_rw, params)[0]

    print("dg REWEIGHT", dg_rw, flush=True)
    print("dg_dtheta REWEIGHT", dg_rw_dtheta, flush=True)

    assert dg_comp.detach() == pytest.approx(expected_dg.detach(), abs=0.5)
    assert dg_comp_dtheta.detach() == pytest.approx(expected_dg_dtheta, rel=1.1)

    assert dg_rw.detach() == pytest.approx(expected_dg.detach(), abs=0.5)
    assert dg_rw_dtheta.detach() == pytest.approx(expected_dg_dtheta, rel=1.1)


# --------------------------------------------------------------------------------
# /smee/tests/mm/test_mm.py:
# --------------------------------------------------------------------------------
import numpy
import openff.interchange
import openff.interchange.models
import openff.toolkit
import openmm.app
import openmm.unit
import pytest
import torch
from rdkit import Chem

import smee
import smee.converters
import smee.mm
import smee.mm._utils
import smee.tests.utils
from smee.mm._mm import (
    _apply_hmr,
    _approximate_box_size,
    _energy_minimize,
    _generate_packmol_input,
    _get_platform,
    _get_state_log,
    _run_simulation,
    _topology_to_xyz,
    generate_system_coords,
    simulate,
)


@pytest.fixture()
def mock_omm_topology() -> openmm.app.Topology:
    """An OpenMM topology with two argon atoms in a single ``UNK`` residue."""
    topology = openmm.app.Topology()
    chain = topology.addChain()
    residue = topology.addResidue("UNK", chain)

    for _ in range(2):
        topology.addAtom("Ar", openmm.app.Element.getByAtomicNumber(18), residue)

    return topology


@pytest.fixture()
def mock_omm_system() -> openmm.System:
    """A two-particle OpenMM system with one uncharged ``NonbondedForce``."""
    system = openmm.System()

    # NOTE(review): 18 amu is argon's atomic number, not its mass (~39.95 amu).
    # The tests using this fixture only check output shapes, but please confirm.
    for _ in range(2):
        system.addParticle(18.0 * openmm.unit.amu)

    force = openmm.NonbondedForce()

    for _ in range(2):
        force.addParticle(
            0.0, 3.0 * openmm.unit.angstrom, 1.0 * openmm.unit.kilojoule_per_mole
        )

    system.addForce(force)

    return system


def test_apply_hmr():
    """Hydrogen mass repartitioning should shift mass from heavy atoms onto
    hydrogens in methanol, while leaving the water masses unchanged.
    """
    interchanges = [
        openff.interchange.Interchange.from_smirnoff(
            openff.toolkit.ForceField("tip4p_fb.offxml", "openff-2.1.0.offxml"),
            openff.toolkit.Molecule.from_smiles(smiles).to_topology(),
        )
        for smiles in ["O", "CO"]
    ]

    force_field, [topology_water, topology_meoh] = smee.converters.convert_interchange(
        interchanges
    )

    system_smee = smee.TensorSystem(
        [topology_water, topology_meoh, topology_water],
        [1, 2, 1],
        False,
    )
    system_openmm = smee.converters.convert_to_openmm_system(force_field, system_smee)

    for i in range(system_openmm.getNumParticles()):
        # Round the masses to the nearest integer to make comparisons easier.
        mass = system_openmm.getParticleMass(i).value_in_unit(openmm.unit.amu)
        mass = float(int(mass + 0.5)) * openmm.unit.amu
        system_openmm.setParticleMass(i, mass)

    _apply_hmr(system_openmm, system_smee)

    masses = [
        system_openmm.getParticleMass(i).value_in_unit(openmm.unit.amu)
        for i in range(system_openmm.getNumParticles())
    ]

    expected_masses = [
        # Water 1
        16.0, 1.0, 1.0, 0.0,
        # MeOH 1
        15.5, 0.5, 1.5, 1.5, 1.5, 1.5,
        # MeOH 2
        15.5, 0.5, 1.5, 1.5, 1.5, 1.5,
        # Water 2
        16.0, 1.0, 1.0, 0.0,
    ]  # fmt: skip
    assert masses == expected_masses


def test_topology_to_rdkit():
    """A tensor topology should round-trip to an RDKit molecule that has one
    conformer and preserves the atom order.
    """
    expected_atomic_nums = [1, 1, 8, 6, 8, 1, 1]

    topology = smee.TensorTopology(
        atomic_nums=torch.tensor(expected_atomic_nums),
        formal_charges=torch.tensor([0, 0, 0, 0, -1, 0, 0]),
        bond_idxs=torch.tensor([[2, 1], [2, 0], [3, 4], [3, 5], [3, 6]]),
        bond_orders=torch.tensor([1, 1, 1, 1, 1]),
        parameters={},
        v_sites=None,
        constraints=None,
    )

    mol = smee.mm._utils.topology_to_rdkit(topology)
    assert Chem.MolToSmiles(mol) == "[H]C([H])[O-].[H]O[H]"

    atomic_nums = [atom.GetAtomicNum() for atom in mol.GetAtoms()]
    assert atomic_nums == expected_atomic_nums

    assert mol.GetNumConformers() == 1


def test_topology_to_xyz(mocker):
    """A TIP4P-FB water should be written as XYZ, with its v-site labelled ``X``."""
    mock_molecule: Chem.Mol = Chem.AddHs(Chem.MolFromSmiles("O"))
    mock_molecule.RemoveAllConformers()

    conformer = Chem.Conformer(mock_molecule.GetNumAtoms())
    conformer.SetAtomPosition(0, (0.0, 0.0, 0.0))
    conformer.SetAtomPosition(1, (1.0, 0.0, 0.0))
    conformer.SetAtomPosition(2, (0.0, 1.0, 0.0))

    mock_molecule.AddConformer(conformer)

    mocker.patch("smee.mm._utils.topology_to_rdkit", return_value=mock_molecule)

    interchange = openff.interchange.Interchange.from_smirnoff(
        openff.toolkit.ForceField("tip4p_fb.offxml"),
        openff.toolkit.Molecule.from_rdkit(mock_molecule).to_topology(),
    )
    force_field, [topology] = smee.converters.convert_interchange(interchange)

    xyz = _topology_to_xyz(topology, force_field)

    expected_xyz = (
        "4\n"
        "\n"
        "O 0.000000 0.000000 0.000000\n"
        "H 1.000000 0.000000 0.000000\n"
        "H 0.000000 1.000000 0.000000\n"
        "X 0.074440 0.074440 0.000000"
    )

    assert xyz == expected_xyz


def test_approximate_box_size():
    """The box length for 256 waters should match the density-based estimate
    (1 g/mL) scaled by ``scale_factor``.
    """
    system = smee.TensorSystem(
        [smee.tests.utils.topology_from_smiles("O")], [256], True
    )

    config = smee.mm.GenerateCoordsConfig(scale_factor=2.0)

    box_size = _approximate_box_size(system, config)

    assert isinstance(box_size, openmm.unit.Quantity)
    assert box_size.unit.is_compatible(openmm.unit.angstrom)

    box_size = box_size.value_in_unit(openmm.unit.angstrom)
    assert isinstance(box_size, float)

    expected_length = (256.0 * 18.01528 / 6.02214076e23 * 1.0e24) ** (1.0 / 3.0) * 2.0
    assert numpy.isclose(box_size, expected_length, atol=3)


def test_generate_packmol_input():
    """The packmol input should list one ``structure`` section per component."""
    expected_tolerance = 0.1 * openmm.unit.nanometer
    expected_seed = 42

    config = smee.mm.GenerateCoordsConfig(
        tolerance=expected_tolerance, seed=expected_seed
    )

    actual_input_file = _generate_packmol_input(
        [1, 2, 3], 1.0 * openmm.unit.angstrom, config
    )

    expected_input_file = "\n".join(
        [
            "tolerance 1.000000",
            "filetype xyz",
            "output output.xyz",
            "seed 42",
            "structure 0.xyz",
            " number 1",
            " inside box 0. 0. 0. 1.0 1.0 1.0",
            "end structure",
            "structure 1.xyz",
            " number 2",
            " inside box 0. 0. 0. 1.0 1.0 1.0",
            "end structure",
            "structure 2.xyz",
            " number 3",
            " inside box 0. 0. 0. 1.0 1.0 1.0",
            "end structure",
        ]
    )
    assert actual_input_file == expected_input_file


def test_generate_system_coords():
    """Packing 1 water and 2 methanols should give non-trivial coordinates and a
    non-trivial 3x3 box, both in length units.
    """
    coords, box_vectors = generate_system_coords(
        smee.TensorSystem(
            [
                smee.tests.utils.topology_from_smiles("O"),
                smee.tests.utils.topology_from_smiles("CO"),
            ],
            [1, 2],
            True,
        ),
        None,
        smee.mm.GenerateCoordsConfig(),
    )

    assert isinstance(coords, openmm.unit.Quantity)
    coords = coords.value_in_unit(openmm.unit.angstrom)
    assert isinstance(coords, numpy.ndarray)
    assert coords.shape == (3 + 6 * 2, 3)
    assert not numpy.allclose(coords, 0.0)

    assert isinstance(box_vectors, openmm.unit.Quantity)
    box_vectors = box_vectors.value_in_unit(openmm.unit.angstrom)
    assert isinstance(box_vectors, numpy.ndarray)
    assert box_vectors.shape == (3, 3)
    assert not numpy.allclose(box_vectors, 0.0)


def test_generate_system_coords_with_v_sites():
    """Generated coordinates should include a slot for the TIP4P-FB v-site."""
    interchange = openff.interchange.Interchange.from_smirnoff(
        openff.toolkit.ForceField("tip4p_fb.offxml"),
        openff.toolkit.Molecule.from_mapped_smiles("[H:2][O:1][H:3]").to_topology(),
    )

    force_field, [topology] = smee.converters.convert_interchange(interchange)
    system = smee.TensorSystem([topology], [1], False)

    # Only the coordinates are checked here; the box vectors are not needed.
    coords, _ = generate_system_coords(system, force_field)
    assert isinstance(coords, openmm.unit.Quantity)
    coords = coords.value_in_unit(openmm.unit.angstrom)
    assert isinstance(coords, numpy.ndarray)
    assert coords.shape == (3 + 1, 3)
    assert not numpy.allclose(coords, 0.0)


def test_get_state_log(mocker):
    """The state log should report the potential energy and the box volume."""
    energy = 1.0 * openmm.unit.kilocalorie_per_mole
    box_vectors = numpy.eye(3) * 10.0 * openmm.unit.angstrom

    state = mocker.MagicMock()
    state.getPotentialEnergy.return_value = energy
    state.getPeriodicBoxVectors.return_value = box_vectors

    actual_log = _get_state_log(state)
    # NOTE(review): a 1.0 kcal/mol input is expected to print as 4.1840 with a
    # "kcal / mol" label. This looks like a kJ/mol value with a kcal/mol label;
    # check _get_state_log before relying on this expectation.
    expected_log = "energy=4.1840 kcal / mol volume=1000.0000 Å^3"

    assert expected_log == actual_log


@pytest.mark.parametrize("is_periodic", [True, False])
def test_get_platform(is_periodic):
    """The Reference platform should be chosen only for non-periodic systems."""
    platform = _get_platform(is_periodic)
    assert isinstance(platform, openmm.Platform)

    assert (platform.getName() == "Reference") == (not is_periodic)


def test_energy_minimize(mock_omm_system):
    """Minimization should return a state with one position per particle."""
    state = (
        numpy.array([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]]) * openmm.unit.angstrom,
        numpy.eye(3) * 1.0 * openmm.unit.nanometer,
    )

    state_new = _energy_minimize(
        mock_omm_system,
        state,
        openmm.Platform.getPlatformByName("Reference"),
        smee.mm.MinimizationConfig(),
    )

    coords = state_new.getPositions(asNumpy=True).value_in_unit(openmm.unit.angstrom)
    assert coords.shape == (2, 3)


def test_run_simulation(mock_omm_topology, mock_omm_system):
    """A one-step NPT run should return an ``openmm.State`` with one position per
    particle.
    """
    state = (
        numpy.array([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]]) * openmm.unit.angstrom,
        numpy.eye(3) * 2.0 * openmm.unit.nanometer,
    )

    force: openmm.NonbondedForce = mock_omm_system.getForce(0)
    force.setNonbondedMethod(openmm.NonbondedForce.CutoffPeriodic)

    state_new = _run_simulation(
        mock_omm_system,
        mock_omm_topology,
        state,
        openmm.Platform.getPlatformByName("Reference"),
        smee.mm.SimulationConfig(
            temperature=86.0 * openmm.unit.kelvin,
            pressure=1.0 * openmm.unit.atmosphere,
            n_steps=1,
        ),
    )
    assert isinstance(state_new, openmm.State)

    coords = state_new.getPositions(asNumpy=True).value_in_unit(openmm.unit.angstrom)
    assert coords.shape == (2, 3)


def test_simulate(mocker, mock_argon_tensors):
    """``simulate`` should run each equilibration stage and then the production
    run, reporting only during production.
    """
    tensor_ff, tensor_top = mock_argon_tensors

    mock_coords = numpy.array([[0.0, 0.0, 0.0]]) * openmm.unit.angstrom
    mock_box = numpy.eye(3) * 2.0 * openmm.unit.nanometer

    spied_energy_minimize = mocker.spy(smee.mm._mm, "_energy_minimize")
    spied_run_simulation = mocker.spy(smee.mm._mm, "_run_simulation")

    reporter = mocker.MagicMock()
    reporter.describeNextReport.return_value = (1, True, False, False, True)

    simulate(
        tensor_top,
        tensor_ff,
        mock_coords,
        mock_box,
        [
            smee.mm.MinimizationConfig(),
            smee.mm.SimulationConfig(
                temperature=86.0 * openmm.unit.kelvin, pressure=None, n_steps=1
            ),
        ],
        smee.mm.SimulationConfig(
            temperature=86.0 * openmm.unit.kelvin, pressure=None, n_steps=2
        ),
        [reporter],
        True,
    )

    spied_energy_minimize.assert_called_once()
    # one equilibration simulation stage + the production run
    assert spied_run_simulation.call_count == 2

    assert reporter.report.call_count == 2


def test_simulate_invalid_pressure(mock_argon_tensors):
    """Requesting a pressure for a non-periodic system should raise."""
    tensor_ff, tensor_top = mock_argon_tensors

    with pytest.raises(
        ValueError, match="pressure cannot be specified for a non-periodic"
    ):
        simulate(
            tensor_top,
            tensor_ff,
            numpy.zeros((0, 3)) * openmm.unit.angstrom,
            numpy.eye(3) * openmm.unit.angstrom,
            [],
            smee.mm.SimulationConfig(
                temperature=1.0 * openmm.unit.kelvin,
                pressure=1.0 * openmm.unit.bar,
                n_steps=2,
            ),
            [],
        )


# --------------------------------------------------------------------------------
# /smee/tests/mm/test_reporters.py:
# --------------------------------------------------------------------------------
import numpy
import openmm.unit
import pytest

from smee.mm._reporters import TensorReporter, tensor_reporter, unpack_frames


class TestTensorReporter:
    """Tests for ``TensorReporter``."""

    def test_describe_next(self, mocker):
        """With a report interval of 2 at step 5, the next report should be one
        step away and request positions and energies only.
        """
        simulation = mocker.MagicMock()
        simulation.currentStep = 5

        reporter = TensorReporter(
            mocker.MagicMock(), 2, 1.0 / openmm.unit.kilocalories_per_mole, None
        )
        assert reporter.describeNextReport(simulation) == (1, True, False, False, True)

    def test_report(self, tmp_path, mocker):
        """One report should round-trip through ``unpack_frames``.

        The frame holds the coordinates, the box vectors, the reduced potential
        ``beta * (U + p V N_A)`` and the kinetic energy in kcal/mol.
        """
        expected_potential = 1.0 * openmm.unit.kilocalories_per_mole
        expected_kinetic = 2.0 * openmm.unit.kilojoules_per_mole

        box_length = 3.0
        expected_volume = box_length**3 * openmm.unit.angstrom**3

        expected_box_vectors = numpy.eye(3) * box_length
        expected_coords = numpy.ones((1, 3))

        mock_state = mocker.MagicMock()
        mock_state.getPotentialEnergy.return_value = expected_potential
        mock_state.getKineticEnergy.return_value = expected_kinetic
        mock_state.getPeriodicBoxVectors.return_value = (
            expected_box_vectors * openmm.unit.angstrom
        )
        mock_state.getPeriodicBoxVolume.return_value = expected_volume
        mock_state.getPositions.return_value = expected_coords * openmm.unit.angstrom

        expected_output_path = tmp_path / "output.msgpack"

        beta = 1.0 / (298.15 * openmm.unit.kelvin * openmm.unit.MOLAR_GAS_CONSTANT_R)
        pressure = 1.0 * openmm.unit.atmospheres

        with expected_output_path.open("wb") as file:
            reporter = TensorReporter(file, 1, beta, pressure)
            reporter.report(None, mock_state)

        with expected_output_path.open("rb") as file:
            frames = list(unpack_frames(file))

        assert len(frames) == 1
        coords, box_vectors, reduced_potential, kinetic = frames[0]

        expected_reduced_potential = beta * (
            expected_potential
            + pressure * expected_volume * openmm.unit.AVOGADRO_CONSTANT_NA
        )

        assert coords == pytest.approx(expected_coords)
        assert box_vectors == pytest.approx(expected_box_vectors)

        assert isinstance(reduced_potential, float)
        assert reduced_potential == pytest.approx(expected_reduced_potential)

        assert isinstance(kinetic, float)
        assert kinetic == pytest.approx(
            expected_kinetic.value_in_unit(openmm.unit.kilocalories_per_mole)
        )

    @pytest.mark.parametrize(
        "potential, contains", [(numpy.nan, "nan"), (numpy.inf, "inf")]
    )
    def test_report_energy_check(self, potential, contains, mocker):
        """A non-finite total energy should make ``report`` raise a ValueError."""
        potential = potential * openmm.unit.kilocalories_per_mole
        kinetic = 2.0 * openmm.unit.kilocalories_per_mole

        mock_state = mocker.MagicMock()
        mock_state.getPotentialEnergy.return_value = potential
        mock_state.getKineticEnergy.return_value = kinetic

        beta = 1.0 / openmm.unit.kilocalories_per_mole

        # Construct outside ``pytest.raises`` so that only ``report`` can satisfy it.
        reporter = TensorReporter(mocker.MagicMock(), 1, beta, None)

        with pytest.raises(ValueError, match=f"total energy is {contains}"):
            reporter.report(None, mock_state)


def test_tensor_reporter(tmp_path):
    """``tensor_reporter`` should yield a ``TensorReporter`` and create the output
    file.
    """
    output = tmp_path / "frames.msgpack"

    beta = 1.0 / (openmm.unit.MOLAR_GAS_CONSTANT_R * 298.15 * openmm.unit.kelvin)

    pressure = 1.0 * openmm.unit.atmospheres

    with tensor_reporter(output, 2, beta, pressure) as reporter:
        assert isinstance(reporter, TensorReporter)

    # ``is_file`` is False for missing paths, so it also checks existence.
    assert output.is_file()


# --------------------------------------------------------------------------------
# /smee/tests/potentials/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/SimonBoothroyd/smee/aca69b9da4c67916c6e59ed2c435fffd4c49a2b6/smee/tests/potentials/__init__.py
# --------------------------------------------------------------------------------
/smee/tests/potentials/conftest.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import openmm.unit 4 | import pytest 5 | import torch 6 | 7 | import smee.mm 8 | import smee.tests.utils 9 | 10 | 11 | @pytest.fixture(scope="module") 12 | def _etoh_water_system() -> ( 13 | tuple[smee.TensorSystem, smee.TensorForceField, torch.Tensor, torch.Tensor] 14 | ): 15 | system, force_field = smee.tests.utils.system_from_smiles(["CCO", "O"], [67, 123]) 16 | coords, box_vectors = smee.mm.generate_system_coords(system, None) 17 | 18 | return ( 19 | system, 20 | force_field, 21 | torch.tensor(coords.value_in_unit(openmm.unit.angstrom), dtype=torch.float32), 22 | torch.tensor( 23 | box_vectors.value_in_unit(openmm.unit.angstrom), dtype=torch.float32 24 | ), 25 | ) 26 | 27 | 28 | @pytest.fixture() 29 | def etoh_water_system( 30 | _etoh_water_system, 31 | ) -> tuple[smee.TensorSystem, smee.TensorForceField, torch.Tensor, torch.Tensor]: 32 | """Creates a system of ethanol and water.""" 33 | 34 | return copy.deepcopy(_etoh_water_system) 35 | -------------------------------------------------------------------------------- /smee/tests/potentials/test_potentials.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import openff.interchange 3 | import openff.interchange.models 4 | import openff.toolkit 5 | import openff.units 6 | import openmm 7 | import openmm.unit 8 | import pytest 9 | import torch 10 | 11 | import smee.converters 12 | import smee.tests.utils 13 | import smee.utils 14 | from smee.potentials import broadcast_exceptions, broadcast_parameters, compute_energy 15 | 16 | 17 | def _place_v_sites( 18 | conformer: torch.Tensor, interchange: openff.interchange.Interchange 19 | ) -> torch.Tensor: 20 | conformer = conformer.numpy() * openmm.unit.angstrom 21 | 22 | openmm_system = interchange.to_openmm() 23 | openmm_context = openmm.Context( 24 | openmm_system, 25 | 
openmm.VerletIntegrator(0.1), 26 | openmm.Platform.getPlatformByName("Reference"), 27 | ) 28 | openmm_context.setPositions(conformer) 29 | openmm_context.computeVirtualSites() 30 | conformer = openmm_context.getState(getPositions=True).getPositions(asNumpy=True) 31 | return torch.tensor(conformer.value_in_unit(openmm.unit.angstrom)) 32 | 33 | 34 | def _compute_openmm_energy( 35 | interchange: openff.interchange.Interchange, 36 | conformer: torch.Tensor, 37 | ) -> torch.Tensor: 38 | """Evaluate the potential energy of a molecule in a specified conformer using a 39 | specified force field. 40 | 41 | Args: 42 | interchange: The interchange object containing the applied force field 43 | parameters. 44 | conformer: The conformer [Å] of the molecule. 45 | 46 | Returns: 47 | The energy in units of [kcal / mol]. 48 | """ 49 | 50 | import openmm.unit 51 | 52 | openmm_system = interchange.to_openmm() 53 | 54 | if openmm_system.getNumParticles() != interchange.topology.n_atoms: 55 | for _ in range(interchange.topology.n_atoms - openmm_system.getNumParticles()): 56 | openmm_system.addParticle(1.0) 57 | 58 | openmm_context = openmm.Context( 59 | openmm_system, 60 | openmm.VerletIntegrator(0.1), 61 | openmm.Platform.getPlatformByName("Reference"), 62 | ) 63 | openmm_context.setPositions(conformer.numpy() * openmm.unit.angstrom) 64 | openmm_context.computeVirtualSites() 65 | 66 | state = openmm_context.getState(getEnergy=True) 67 | energy = state.getPotentialEnergy().value_in_unit(openmm.unit.kilocalorie_per_mole) 68 | 69 | return torch.tensor(energy) 70 | 71 | 72 | def test_broadcast_parameters(): 73 | system, force_field = smee.tests.utils.system_from_smiles(["C", "O"], [2, 3]) 74 | vdw_potential = force_field.potentials_by_type["vdW"] 75 | 76 | methane_top, water_top = system.topologies 77 | 78 | parameters = broadcast_parameters(system, vdw_potential) 79 | 80 | expected_methane_parameters = ( 81 | methane_top.parameters["vdW"].assignment_matrix @ vdw_potential.parameters 82 | 
) 83 | expected_water_parameters = ( 84 | water_top.parameters["vdW"].assignment_matrix @ vdw_potential.parameters 85 | ) 86 | 87 | expected_parameters = torch.vstack( 88 | [expected_methane_parameters] * 2 + [expected_water_parameters] * 3 89 | ) 90 | assert parameters.shape == expected_parameters.shape 91 | assert torch.allclose(parameters, expected_parameters) 92 | 93 | 94 | def test_broadcast_exceptions(): 95 | system, force_field = smee.tests.utils.system_from_smiles( 96 | ["O", "[Na+]", "[Cl-]"], [1, 2, 2] 97 | ) 98 | 99 | vdw_potential = force_field.potentials_by_type["vdW"] 100 | assert len(vdw_potential.parameters) == 4 101 | assert vdw_potential.exceptions is None 102 | 103 | def _parameter_key_to_idx(key): 104 | return next( 105 | iter(i for i, k in enumerate(vdw_potential.parameter_keys) if k.id == key) 106 | ) 107 | 108 | parameter_idx_o = _parameter_key_to_idx("[#1]-[#8X2H2+0:1]-[#1]") 109 | parameter_idx_cl = _parameter_key_to_idx("[#17X0-1:1]") 110 | parameter_idx_na = _parameter_key_to_idx("[#11+1:1]") 111 | 112 | vdw_potential.parameters = torch.vstack( 113 | [vdw_potential.parameters, torch.tensor([[0.12, 0.34], [0.56, 0.67]])] 114 | ) 115 | vdw_potential.parameter_keys = [*vdw_potential.parameter_keys, "o-cl", "o-na"] 116 | 117 | vdw_potential.exceptions = { 118 | (parameter_idx_o, parameter_idx_cl): 4, 119 | (parameter_idx_o, parameter_idx_na): 5, 120 | } 121 | 122 | idxs_a = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 3, 3, 4, 4, 5]) 123 | idxs_b = torch.tensor([3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 4, 5, 6, 5, 6, 6]) 124 | 125 | exceptions_idxs, exceptions = broadcast_exceptions( 126 | system, vdw_potential, idxs_a, idxs_b 127 | ) 128 | 129 | # we only expect custom exceptions between the O and Na and O and Cl 130 | # i.e. 
particle pairs (0, 3), (0, 4) and (0, 5), (0, 6) 131 | assert torch.allclose(exceptions_idxs, torch.tensor([0, 3, 6, 9])) 132 | assert torch.allclose(exceptions, vdw_potential.parameters[[5, 5, 4, 4], :]) 133 | 134 | 135 | @pytest.mark.parametrize("precision", ["single", "double"]) 136 | @pytest.mark.parametrize( 137 | "smiles", 138 | [ 139 | "C1=NC2=C(N1)C(=O)NC(=N2)N", 140 | "C", 141 | "CO", 142 | "C=O", 143 | "c1ccccc1", 144 | "Cc1ccccc1", 145 | "c1cocc1", 146 | "CC(=O)NC1=CC=C(C=C1)O", 147 | ], 148 | ) 149 | def test_compute_energy(precision, smiles: str): 150 | molecule = openff.toolkit.Molecule.from_smiles(smiles, allow_undefined_stereo=True) 151 | molecule.generate_conformers(n_conformers=1) 152 | 153 | conformer = torch.tensor(molecule.conformers[0].m_as(openff.units.unit.angstrom)) 154 | conformer += torch.randn((molecule.n_atoms, 1)) * 0.25 155 | 156 | interchange = openff.interchange.Interchange.from_smirnoff( 157 | openff.toolkit.ForceField("openff_unconstrained-2.0.0.offxml"), 158 | molecule.to_topology(), 159 | ) 160 | tensor_ff, [tensor_top] = smee.converters.convert_interchange(interchange) 161 | 162 | tensor_top = tensor_top.to(precision=precision) 163 | tensor_ff = tensor_ff.to(precision=precision) 164 | 165 | energy_smee = compute_energy(tensor_top, tensor_ff, conformer, None) 166 | energy_openmm = _compute_openmm_energy(interchange, conformer) 167 | 168 | assert torch.isclose(energy_smee, energy_openmm.to(energy_smee.dtype)) 169 | 170 | 171 | @pytest.mark.parametrize("precision", ["single", "double"]) 172 | def test_compute_energy_periodic(etoh_water_system, precision): 173 | tensor_sys, tensor_ff, coords, box_vectors = etoh_water_system 174 | 175 | tensor_sys = tensor_sys.to(precision=precision) 176 | tensor_ff = tensor_ff.to(precision=precision) 177 | 178 | energy_smee = compute_energy(tensor_sys, tensor_ff, coords, box_vectors) 179 | 180 | openmm_system = smee.converters.openmm.convert_to_openmm_system( 181 | tensor_ff, tensor_sys 182 | ) 183 
| openmm_system.setDefaultPeriodicBoxVectors( 184 | *box_vectors.numpy() * openmm.unit.angstrom 185 | ) 186 | openmm_context = openmm.Context( 187 | openmm_system, 188 | openmm.VerletIntegrator(0.1), 189 | openmm.Platform.getPlatformByName("Reference"), 190 | ) 191 | openmm_context.setPeriodicBoxVectors(*box_vectors.numpy() * openmm.unit.angstrom) 192 | openmm_context.setPositions(coords.numpy() * openmm.unit.angstrom) 193 | openmm_state = openmm_context.getState(getEnergy=True) 194 | energy_openmm = openmm_state.getPotentialEnergy().value_in_unit( 195 | openmm.unit.kilocalorie_per_mole 196 | ) 197 | 198 | assert torch.isclose( 199 | energy_smee, torch.tensor(energy_openmm, dtype=energy_smee.dtype) 200 | ) 201 | 202 | 203 | def test_compute_energy_v_sites(): 204 | molecule_a = openff.toolkit.Molecule.from_smiles("O") 205 | molecule_a.generate_conformers(n_conformers=1) 206 | molecule_b = openff.toolkit.Molecule.from_smiles("O") 207 | molecule_b.generate_conformers(n_conformers=1) 208 | 209 | topology = openff.toolkit.Topology() 210 | topology.add_molecule(molecule_a) 211 | topology.add_molecule(molecule_b) 212 | 213 | conformer_a = molecule_a.conformers[0].m_as(openff.units.unit.angstrom) 214 | conformer_b = molecule_b.conformers[0].m_as(openff.units.unit.angstrom) 215 | 216 | conformer = torch.vstack( 217 | [ 218 | torch.tensor(conformer_a), 219 | torch.tensor(conformer_b + numpy.array([[3.0, 0.0, 0.0]])), 220 | torch.zeros((2, 3)), 221 | ] 222 | ) 223 | 224 | interchange = openff.interchange.Interchange.from_smirnoff( 225 | openff.toolkit.ForceField("tip4p_fb.offxml"), topology 226 | ) 227 | conformer = _place_v_sites(conformer, interchange) 228 | 229 | tensor_ff, [tensor_top] = smee.converters.convert_interchange(interchange) 230 | 231 | energy_openmm = _compute_openmm_energy(interchange, conformer) 232 | energy_smee = compute_energy(tensor_top, tensor_ff, conformer) 233 | 234 | assert torch.isclose(energy_smee, energy_openmm.to(energy_smee.dtype)) 235 | 
-------------------------------------------------------------------------------- /smee/tests/potentials/test_valence.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | import smee 5 | from smee.potentials.valence import ( 6 | _compute_cosine_torsion_energy, 7 | compute_cosine_improper_torsion_energy, 8 | compute_cosine_proper_torsion_energy, 9 | compute_harmonic_angle_energy, 10 | compute_harmonic_bond_energy, 11 | ) 12 | 13 | 14 | def _mock_models( 15 | particle_idxs: torch.Tensor, 16 | parameters: torch.Tensor, 17 | parameter_cols: tuple[str, ...], 18 | ) -> tuple[smee.TensorPotential, smee.TensorSystem]: 19 | potential = smee.TensorPotential( 20 | type="mock", 21 | fn="mock-fn", 22 | parameters=parameters, 23 | parameter_keys=[None] * len(parameters), 24 | parameter_cols=parameter_cols, 25 | parameter_units=[None] * len(parameters), 26 | attributes=None, 27 | attribute_cols=None, 28 | attribute_units=None, 29 | ) 30 | 31 | n_atoms = int(particle_idxs.max()) 32 | 33 | parameter_map = smee.ValenceParameterMap( 34 | particle_idxs=particle_idxs, 35 | assignment_matrix=torch.eye(len(particle_idxs)), 36 | ) 37 | topology = smee.TensorTopology( 38 | atomic_nums=torch.zeros(n_atoms, dtype=torch.long), 39 | formal_charges=torch.zeros(n_atoms, dtype=torch.long), 40 | bond_idxs=torch.zeros((0, 2), dtype=torch.long), 41 | bond_orders=torch.zeros(0, dtype=torch.long), 42 | parameters={potential.type: parameter_map}, 43 | v_sites=None, 44 | constraints=None, 45 | ) 46 | 47 | return potential, smee.TensorSystem([topology], [1], False) 48 | 49 | 50 | @pytest.mark.parametrize( 51 | "conformer, expected_shape", 52 | [ 53 | ( 54 | torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]), 55 | torch.Size([]), 56 | ), 57 | (torch.tensor([[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]]), (1,)), 58 | ], 59 | ) 60 | def test_compute_harmonic_bond_energy(conformer, expected_shape): 61 | 
atom_indices = torch.tensor([[0, 1], [0, 2]]) 62 | parameters = torch.tensor([[2.0, 0.95], [0.5, 1.01]], requires_grad=True) 63 | 64 | potential, system = _mock_models(atom_indices, parameters, ("k", "length")) 65 | 66 | energy = compute_harmonic_bond_energy(system, potential, conformer) 67 | energy.backward() 68 | 69 | assert energy.shape == expected_shape 70 | 71 | assert torch.isclose(energy, torch.tensor(1.0 * 0.05**2 + 0.25 * 0.01**2)) 72 | assert not torch.allclose(parameters.grad, torch.tensor(0.0)) 73 | 74 | 75 | @pytest.mark.parametrize( 76 | "conformer, expected_shape", 77 | [ 78 | ( 79 | torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]]), 80 | torch.Size([]), 81 | ), 82 | (torch.tensor([[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]]]), (1,)), 83 | ], 84 | ) 85 | def test_compute_harmonic_angle_energy(conformer, expected_shape): 86 | atom_indices = torch.tensor([[0, 1, 2]]) 87 | parameters = torch.tensor([[2.0, 92.5]], requires_grad=True) 88 | 89 | potential, system = _mock_models(atom_indices, parameters, ("k", "angle")) 90 | 91 | energy = compute_harmonic_angle_energy(system, potential, conformer) 92 | energy.backward() 93 | 94 | assert energy.shape == expected_shape 95 | 96 | expected_energy = 0.5 * parameters[0, 0] * (torch.pi / 2.0 - parameters[0, 1]) ** 2 97 | expected_gradient = torch.tensor( 98 | [ 99 | 0.5 * (torch.pi / 2.0 - parameters[0, 1]) ** 2, 100 | parameters[0, 0] * (parameters[0, 1] - torch.pi / 2.0), 101 | ] 102 | ) 103 | 104 | assert torch.isclose(energy, expected_energy) 105 | assert torch.allclose(parameters.grad, expected_gradient) 106 | 107 | 108 | @pytest.mark.parametrize("expected_shape", [torch.Size([]), (1,)]) 109 | @pytest.mark.parametrize( 110 | "energy_function", 111 | [ 112 | _compute_cosine_torsion_energy, 113 | compute_cosine_proper_torsion_energy, 114 | compute_cosine_improper_torsion_energy, 115 | ], 116 | ) 117 | @pytest.mark.parametrize("phi_sign", [-1.0, 1.0]) 118 | def 
test_compute_cosine_torsion_energy(expected_shape, energy_function, phi_sign): 119 | conformer = torch.tensor( 120 | [[-1.0, 1.0, 0.0], [0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 1.0, phi_sign]] 121 | ) 122 | 123 | if expected_shape == (1,): 124 | conformer = torch.unsqueeze(conformer, 0) 125 | 126 | atom_indices = torch.tensor([[0, 1, 2, 3]]) 127 | parameters = torch.tensor([[2.0, 2.0, 20.0, 1.5]], requires_grad=True) 128 | 129 | potential, system = _mock_models( 130 | atom_indices, parameters, ("k", "periodicity", "phase", "idivf") 131 | ) 132 | 133 | energy = energy_function(system, potential, conformer) 134 | energy.backward() 135 | 136 | expected_energy = ( 137 | parameters[0, 0] 138 | / parameters[0, 3] 139 | * ( 140 | 1.0 141 | + torch.cos( 142 | torch.tensor( 143 | [ 144 | parameters[0, 1] * torch.tensor(phi_sign * torch.pi / 4.0) 145 | - parameters[0, 2] 146 | ] 147 | ) 148 | ) 149 | ) 150 | ) 151 | 152 | assert torch.isclose(energy, expected_energy) 153 | assert energy.shape == expected_shape 154 | -------------------------------------------------------------------------------- /smee/tests/test_models.py: -------------------------------------------------------------------------------- 1 | import openff.interchange.models 2 | import pytest 3 | import torch 4 | 5 | import smee.tests.utils 6 | from smee._models import _cast 7 | 8 | 9 | @pytest.mark.parametrize( 10 | "tensor, precision, expected_device, expected_dtype", 11 | [ 12 | (torch.zeros(2, dtype=torch.float32), "single", "cpu", torch.float32), 13 | (torch.zeros(2, dtype=torch.float64), "single", "cpu", torch.float32), 14 | (torch.zeros(2, dtype=torch.float32), "double", "cpu", torch.float64), 15 | (torch.zeros(2, dtype=torch.float64), "double", "cpu", torch.float64), 16 | (torch.zeros(2, dtype=torch.int32), "single", "cpu", torch.int32), 17 | (torch.zeros(2, dtype=torch.int64), "single", "cpu", torch.int32), 18 | (torch.zeros(2, dtype=torch.int32), "double", "cpu", torch.int64), 19 | (torch.zeros(2, 
dtype=torch.int64), "double", "cpu", torch.int64), 20 | ], 21 | ) 22 | def test_cast(tensor, precision, expected_device, expected_dtype): 23 | output = _cast(tensor, precision=precision) 24 | 25 | assert output.shape == tensor.shape 26 | assert output.device.type == expected_device 27 | assert output.dtype == expected_dtype 28 | 29 | 30 | def _add_v_sites(topology: smee.TensorTopology): 31 | v_site_key = openff.interchange.models.VirtualSiteKey( 32 | orientation_atom_indices=(0,), 33 | type="BondCharge", 34 | name="EP", 35 | match="once", 36 | ) 37 | topology.v_sites = smee.VSiteMap( 38 | keys=[v_site_key], 39 | key_to_idx={v_site_key: 6}, 40 | parameter_idxs=torch.tensor([[0]]), 41 | ) 42 | 43 | 44 | class TestTensorTopology: 45 | def test_n_atoms(self): 46 | topology = smee.tests.utils.topology_from_smiles("CO") 47 | 48 | expected_n_atoms = 6 49 | assert topology.n_atoms == expected_n_atoms 50 | 51 | def test_n_bonds(self): 52 | topology = smee.tests.utils.topology_from_smiles("CO") 53 | 54 | expected_n_bonds = 5 55 | assert topology.n_bonds == expected_n_bonds 56 | 57 | def test_n_residues(self): 58 | topology = smee.tests.utils.topology_from_smiles("[Ar]") 59 | topology.residue_ids = None 60 | topology.residue_idxs = None 61 | assert topology.n_residues == 0 62 | 63 | topology.residue_ids = ["Ar"] 64 | topology.residue_idxs = [0] 65 | assert topology.n_residues == 1 66 | 67 | def test_n_chains(self): 68 | topology = smee.tests.utils.topology_from_smiles("[Ar]") 69 | topology.residue_ids = [0] 70 | topology.residue_idxs = ["UNK"] 71 | topology.chain_idxs = [0] 72 | topology.chain_ids = ["A"] 73 | assert topology.n_chains == 1 74 | 75 | def test_n_v_sites(self): 76 | topology = smee.tests.utils.topology_from_smiles("CO") 77 | 78 | expected_n_v_sites = 0 79 | assert topology.n_v_sites == expected_n_v_sites 80 | 81 | _add_v_sites(topology) 82 | 83 | expected_n_v_sites = 1 84 | assert topology.n_v_sites == expected_n_v_sites 85 | 86 | def test_n_particles(self): 87 
| topology = smee.tests.utils.topology_from_smiles("CO") 88 | _add_v_sites(topology) 89 | 90 | expected_n_particles = 7 91 | assert topology.n_particles == expected_n_particles 92 | 93 | 94 | class TestTensorSystem: 95 | def test_n_atoms(self): 96 | system = smee.TensorSystem( 97 | topologies=[ 98 | smee.tests.utils.topology_from_smiles("CO"), 99 | smee.tests.utils.topology_from_smiles("O"), 100 | ], 101 | n_copies=[2, 5], 102 | is_periodic=True, 103 | ) 104 | 105 | expected_n_atoms = 6 * 2 + 3 * 5 106 | assert system.n_atoms == expected_n_atoms 107 | 108 | def test_n_v_sites(self): 109 | system = smee.TensorSystem( 110 | topologies=[ 111 | smee.tests.utils.topology_from_smiles("CO"), 112 | smee.tests.utils.topology_from_smiles("O"), 113 | ], 114 | n_copies=[2, 5], 115 | is_periodic=True, 116 | ) 117 | 118 | expected_n_v_sites = 0 * 2 + 0 * 5 119 | assert system.n_v_sites == expected_n_v_sites 120 | 121 | _add_v_sites(system.topologies[0]) 122 | 123 | expected_n_v_sites = 1 * 2 + 0 * 5 124 | assert system.n_v_sites == expected_n_v_sites 125 | 126 | def test_n_particles(self): 127 | system = smee.TensorSystem( 128 | topologies=[ 129 | smee.tests.utils.topology_from_smiles("CO"), 130 | smee.tests.utils.topology_from_smiles("O"), 131 | ], 132 | n_copies=[2, 5], 133 | is_periodic=True, 134 | ) 135 | _add_v_sites(system.topologies[0]) 136 | 137 | expected_n_particles = (6 + 1) * 2 + 3 * 5 138 | assert system.n_particles == expected_n_particles 139 | -------------------------------------------------------------------------------- /smee/tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import openff.interchange.models 2 | import openff.toolkit 3 | import pytest 4 | import torch 5 | 6 | import smee 7 | import smee.utils 8 | 9 | 10 | def test_find_exclusions_simple(): 11 | molecule = openff.toolkit.Molecule() 12 | for _ in range(6): 13 | molecule.add_atom(6, 0, False) 14 | for i in range(5): 15 | molecule.add_bond(i, i 
+ 1, 1, False) 16 | 17 | exclusions = smee.utils.find_exclusions(molecule.to_topology()) 18 | assert exclusions == { 19 | (0, 1): "scale_12", 20 | (0, 2): "scale_13", 21 | (0, 3): "scale_14", 22 | (0, 4): "scale_15", 23 | (1, 2): "scale_12", 24 | (1, 3): "scale_13", 25 | (1, 4): "scale_14", 26 | (1, 5): "scale_15", 27 | (2, 3): "scale_12", 28 | (2, 4): "scale_13", 29 | (2, 5): "scale_14", 30 | (3, 4): "scale_12", 31 | (3, 5): "scale_13", 32 | (4, 5): "scale_12", 33 | } 34 | 35 | 36 | def test_find_exclusions_rings(): 37 | molecule = openff.toolkit.Molecule() 38 | for _ in range(8): 39 | molecule.add_atom(6, 0, False) 40 | 41 | # para substituted 6-membered ring 42 | molecule.add_bond(0, 1, 1, False) 43 | for i in range(6): 44 | molecule.add_bond(i + 1, (i + 1) % 6 + 1, 1, False) 45 | molecule.add_bond(2, 7, 1, False) 46 | 47 | exclusions = smee.utils.find_exclusions(molecule.to_topology()) 48 | assert exclusions == { 49 | (0, 1): "scale_12", 50 | (0, 2): "scale_13", 51 | (0, 6): "scale_13", 52 | (0, 3): "scale_14", 53 | (0, 5): "scale_14", 54 | (0, 7): "scale_14", 55 | (0, 4): "scale_15", 56 | (1, 2): "scale_12", 57 | (1, 6): "scale_12", 58 | (1, 3): "scale_13", 59 | (1, 5): "scale_13", 60 | (1, 7): "scale_13", 61 | (1, 4): "scale_14", 62 | (2, 3): "scale_12", 63 | (2, 6): "scale_13", 64 | (2, 4): "scale_13", 65 | (2, 5): "scale_14", 66 | (2, 7): "scale_12", 67 | (3, 4): "scale_12", 68 | (3, 5): "scale_13", 69 | (3, 6): "scale_14", 70 | (3, 7): "scale_13", 71 | (4, 5): "scale_12", 72 | (4, 6): "scale_13", 73 | (4, 7): "scale_14", 74 | (5, 6): "scale_12", 75 | (5, 7): "scale_15", 76 | (6, 7): "scale_14", 77 | } 78 | 79 | 80 | def test_find_exclusions_dimer(): 81 | molecule = openff.toolkit.Molecule() 82 | for _ in range(3): 83 | molecule.add_atom(6, 0, False) 84 | 85 | molecule.add_bond(0, 1, 1, False) 86 | molecule.add_bond(1, 2, 1, False) 87 | 88 | topology = openff.toolkit.Topology() 89 | topology.add_molecule(molecule) 90 | topology.add_molecule(molecule) 91 | 
92 | exclusions = smee.utils.find_exclusions(topology) 93 | assert exclusions == { 94 | (0, 1): "scale_12", 95 | (0, 2): "scale_13", 96 | (1, 2): "scale_12", 97 | (3, 4): "scale_12", 98 | (3, 5): "scale_13", 99 | (4, 5): "scale_12", 100 | } 101 | 102 | 103 | def test_find_exclusions_v_sites(): 104 | molecule = openff.toolkit.Molecule() 105 | for _ in range(4): 106 | molecule.add_atom(6, 0, False) 107 | for i in range(3): 108 | molecule.add_bond(i, i + 1, 1, False) 109 | 110 | v_site_keys = [ 111 | openff.interchange.models.VirtualSiteKey( 112 | orientation_atom_indices=(0, 1, 2), 113 | type="MonovalentLonePair", 114 | match="once", 115 | name="EP", 116 | ), 117 | openff.interchange.models.VirtualSiteKey( 118 | orientation_atom_indices=(3, 2, 1), 119 | type="MonovalentLonePair", 120 | match="once", 121 | name="EP", 122 | ), 123 | ] 124 | 125 | v_sites = smee.VSiteMap( 126 | keys=v_site_keys, 127 | key_to_idx={v_site_keys[0]: 4, v_site_keys[1]: 5}, 128 | parameter_idxs=torch.zeros((2, 1)), 129 | ) 130 | 131 | exclusions = smee.utils.find_exclusions(molecule.to_topology(), v_sites) 132 | assert exclusions == { 133 | (0, 1): "scale_12", 134 | (0, 2): "scale_13", 135 | (0, 3): "scale_14", 136 | (0, 4): "scale_12", 137 | (0, 5): "scale_14", 138 | (1, 2): "scale_12", 139 | (1, 3): "scale_13", 140 | (1, 4): "scale_12", 141 | (1, 5): "scale_13", 142 | (2, 3): "scale_12", 143 | (2, 4): "scale_13", 144 | (2, 5): "scale_12", 145 | (3, 4): "scale_14", 146 | (3, 5): "scale_12", 147 | (4, 5): "scale_14", 148 | } 149 | 150 | 151 | def test_ones_like(): 152 | expected_size = (4, 5) 153 | expected_type = torch.float16 154 | 155 | other = torch.tensor([1, 2, 3], dtype=expected_type, device="cpu") 156 | tensor = smee.utils.ones_like(expected_size, other) 157 | 158 | assert tensor.dtype == expected_type 159 | assert tensor.shape == expected_size 160 | assert torch.allclose(tensor, torch.tensor(1.0, dtype=expected_type)) 161 | 162 | 163 | def test_zeros_like(): 164 | expected_size = (4, 
5) 165 | expected_type = torch.float16 166 | 167 | other = torch.tensor([1, 2, 3], dtype=expected_type, device="cpu") 168 | tensor = smee.utils.zeros_like(expected_size, other) 169 | 170 | assert tensor.dtype == expected_type 171 | assert tensor.shape == expected_size 172 | assert torch.allclose(tensor, torch.tensor(0.0, dtype=expected_type)) 173 | 174 | 175 | def test_tensor_like(): 176 | expected_type = torch.float16 177 | expected_data = [[3.0], [2.0], [1.0]] 178 | 179 | other = torch.tensor([], dtype=expected_type, device="cpu") 180 | tensor = smee.utils.tensor_like(expected_data, other) 181 | 182 | assert tensor.dtype == expected_type 183 | assert torch.allclose(tensor, torch.tensor(expected_data, dtype=expected_type)) 184 | 185 | 186 | def test_tensor_like_copy(): 187 | expected_type = torch.float16 188 | expected_data = torch.tensor([[3.0], [2.0], [1.0]], requires_grad=True) 189 | 190 | other = torch.tensor(expected_data, dtype=expected_type, device="cpu") 191 | tensor = smee.utils.tensor_like(expected_data, other) 192 | 193 | assert tensor.requires_grad is False 194 | assert tensor.dtype == expected_type 195 | assert torch.allclose(tensor, torch.tensor(expected_data, dtype=expected_type)) 196 | 197 | 198 | def test_arange_like(): 199 | expected_type = torch.int8 200 | expected_data = [0, 1, 2, 3] 201 | 202 | other = torch.tensor([], dtype=expected_type, device="cpu") 203 | tensor = smee.utils.arange_like(4, other) 204 | 205 | assert tensor.dtype == expected_type 206 | assert torch.allclose(tensor, torch.tensor(expected_data, dtype=expected_type)) 207 | 208 | 209 | @pytest.mark.parametrize( 210 | "a, b, dim, keepdim", 211 | [ 212 | (torch.tensor([1.0, 2.0, 3.0]), None, 0, False), 213 | (torch.tensor([1.0, 2.0, 3.0]), None, 0, True), 214 | (torch.tensor([1.0, 2.0, 3.0]), torch.tensor(0.0), 0, False), 215 | (torch.tensor([1.0, 2.0, 3.0]), torch.tensor(2.0), 0, False), 216 | (torch.tensor([1.0, 2.0, 3.0]), torch.tensor([3.0, 2.0, 1.0]), 0, False), 217 | 
(torch.tensor([1.0, 2.0, 3.0]), torch.tensor(2.0), 0, True), 218 | (torch.tensor([1.0, 2.0, 3.0]), torch.tensor([3.0, 2.0, 1.0]), 0, True), 219 | ( 220 | torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), 221 | torch.tensor([[6.0, 5.0, 4.0], [3.0, 2.0, 1.0]]), 222 | 0, 223 | False, 224 | ), 225 | ( 226 | torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), 227 | torch.tensor([[6.0, 5.0, 4.0], [3.0, 2.0, 1.0]]), 228 | 1, 229 | False, 230 | ), 231 | ( 232 | torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), 233 | torch.tensor([[6.0, 5.0, 4.0], [3.0, 2.0, 1.0]]), 234 | 1, 235 | True, 236 | ), 237 | (torch.tensor(-torch.inf), torch.tensor(1.0), 0, True), 238 | ], 239 | ) 240 | def test_logsumexp(a, b, dim, keepdim): 241 | from scipy.special import logsumexp 242 | 243 | actual = smee.utils.logsumexp(a, dim, keepdim, b) 244 | expected = torch.tensor( 245 | logsumexp(a.numpy(), dim, b if b is None else b.numpy(), keepdim) 246 | ) 247 | 248 | assert actual.shape == expected.shape 249 | assert torch.allclose(actual.double(), expected.double()) 250 | 251 | 252 | def test_logsumexp_with_sign(): 253 | from scipy.special import logsumexp 254 | 255 | a = torch.tensor([1.0, 2.0, 3.0]) 256 | b = torch.tensor(-2.0) 257 | 258 | actual, actual_sign = smee.utils.logsumexp(a, -1, True, b, return_sign=True) 259 | expected, expected_sign = torch.tensor( 260 | logsumexp(a.numpy(), -1, b.numpy(), True, return_sign=True) 261 | ) 262 | 263 | assert actual.shape == expected.shape 264 | assert torch.allclose(actual.double(), expected.double()) 265 | 266 | assert actual_sign.shape == expected_sign.shape 267 | assert torch.allclose(actual_sign.double(), expected_sign.double()) 268 | 269 | 270 | @pytest.mark.parametrize("n", [7499, 7500, 7501]) 271 | def test_to_upper_tri_idx(n): 272 | i, j = torch.triu_indices(n, n, 1) 273 | expected_idxs = torch.arange(len(i)) 274 | 275 | idxs = smee.utils.to_upper_tri_idx(i, j, n) 276 | 277 | assert idxs.shape == expected_idxs.shape 278 | assert (idxs == 
expected_idxs).all() 279 | 280 | 281 | def test_geometric_mean(): 282 | a = torch.tensor(4.0, requires_grad=True).double() 283 | b = torch.tensor(9.0, requires_grad=True).double() 284 | 285 | assert torch.autograd.gradcheck( 286 | smee.utils.geometric_mean, (a, b), check_backward_ad=True, check_forward_ad=True 287 | ) 288 | 289 | assert torch.isclose(smee.utils.geometric_mean(a, b), torch.tensor(6.0).double()) 290 | 291 | 292 | @pytest.mark.parametrize( 293 | "a, b, expected_grad_a, expected_grad_b", 294 | [ 295 | (0.0, 0.0, 0.0, 0.0), 296 | (3.0, 0.0, 0.0, 3.0 / (2.0 * smee.utils.EPSILON)), 297 | (0.0, 4.0, 4.0 / (2.0 * smee.utils.EPSILON), 0.0), 298 | ], 299 | ) 300 | def test_geometric_mean_zero(a, b, expected_grad_a, expected_grad_b): 301 | a = torch.tensor(a, requires_grad=True) 302 | b = torch.tensor(b, requires_grad=True) 303 | 304 | v = smee.utils.geometric_mean(a, b) 305 | v.backward() 306 | 307 | expected_grad_a = torch.tensor(expected_grad_a) 308 | expected_grad_b = torch.tensor(expected_grad_b) 309 | 310 | assert a.grad.shape == expected_grad_a.shape 311 | assert torch.allclose(a.grad, expected_grad_a) 312 | 313 | assert b.grad.shape == expected_grad_b.shape 314 | assert torch.allclose(b.grad, expected_grad_b) 315 | -------------------------------------------------------------------------------- /smee/tests/utils.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import openff.interchange 4 | import openff.interchange.models 5 | import openff.toolkit 6 | import openff.units 7 | import torch 8 | from rdkit import Chem 9 | 10 | import smee 11 | import smee.converters 12 | import smee.potentials.nonbonded 13 | import smee.utils 14 | 15 | LJParam = typing.NamedTuple("LJParam", [("eps", float), ("sig", float)]) 16 | 17 | 18 | def mock_convert_fn_with_deps( 19 | handlers: list[openff.interchange.smirnoff.SMIRNOFFvdWCollection], 20 | topologies: list[openff.toolkit.Topology], 21 | v_site_maps: 
list[smee.VSiteMap | None], 22 | dependencies: dict[ 23 | str, tuple[smee.TensorPotential, list[smee.NonbondedParameterMap]] 24 | ], 25 | ) -> tuple[smee.TensorPotential, list[smee.NonbondedParameterMap]]: 26 | raise NotImplementedError() 27 | 28 | 29 | def convert_lj_to_dexp(potential: smee.TensorPotential): 30 | potential.fn = smee.EnergyFn.VDW_DEXP 31 | 32 | parameter_cols = [*potential.parameter_cols] 33 | sigma_idx = potential.parameter_cols.index("sigma") 34 | 35 | sigma = potential.parameters[:, sigma_idx] 36 | r_min = 2 ** (1 / 6) * sigma 37 | 38 | potential.parameters[:, sigma_idx] = r_min 39 | 40 | parameter_cols[sigma_idx] = "r_min" 41 | potential.parameter_cols = tuple(parameter_cols) 42 | 43 | potential.attribute_cols = (*potential.attribute_cols, "alpha", "beta") 44 | potential.attribute_units = ( 45 | *potential.attribute_units, 46 | openff.units.unit.dimensionless, 47 | openff.units.unit.dimensionless, 48 | ) 49 | potential.attributes = torch.cat([potential.attributes, torch.tensor([16.5, 5.0])]) 50 | 51 | return potential 52 | 53 | 54 | def topology_from_smiles(smiles: str) -> smee.TensorTopology: 55 | """Creates a topology with no parameters from a SMILES string. 56 | 57 | Args: 58 | smiles: The SMILES string. 59 | 60 | Returns: 61 | The topology. 
62 | """ 63 | mol = Chem.AddHs(Chem.MolFromSmiles(smiles)) 64 | 65 | return smee.TensorTopology( 66 | atomic_nums=torch.tensor([atom.GetAtomicNum() for atom in mol.GetAtoms()]), 67 | formal_charges=torch.tensor( 68 | [atom.GetFormalCharge() for atom in mol.GetAtoms()] 69 | ), 70 | bond_idxs=torch.tensor( 71 | [[bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()] for bond in mol.GetBonds()] 72 | ), 73 | bond_orders=torch.tensor( 74 | [int(bond.GetBondTypeAsDouble()) for bond in mol.GetBonds()] 75 | ), 76 | parameters={}, 77 | v_sites=None, 78 | constraints=None, 79 | ) 80 | 81 | 82 | def system_from_smiles( 83 | smiles: list[str], 84 | n_copies: list[int], 85 | force_field: openff.toolkit.ForceField | None = None, 86 | ) -> tuple[smee.TensorSystem, smee.TensorForceField]: 87 | """Creates a system from a list of SMILES strings. 88 | 89 | Args: 90 | smiles: The list of SMILES strings. 91 | n_copies: The number of copies of each molecule. 92 | 93 | Returns: 94 | The system and force field. 95 | """ 96 | force_field = ( 97 | force_field 98 | if force_field is not None 99 | else openff.toolkit.ForceField("openff-2.0.0.offxml") 100 | ) 101 | 102 | interchanges = [ 103 | openff.interchange.Interchange.from_smirnoff( 104 | force_field, 105 | openff.toolkit.Molecule.from_smiles(pattern).to_topology(), 106 | ) 107 | for pattern in smiles 108 | ] 109 | 110 | tensor_ff, tensor_tops = smee.converters.convert_interchange(interchanges) 111 | 112 | return smee.TensorSystem(tensor_tops, n_copies, is_periodic=True), tensor_ff 113 | 114 | 115 | def _parameter_key_to_idx(potential: smee.TensorPotential, key: str): 116 | return next(iter(i for i, k in enumerate(potential.parameter_keys) if k.id == key)) 117 | 118 | 119 | def system_with_exceptions() -> ( 120 | tuple[smee.TensorSystem, smee.TensorPotential, dict[str, LJParam]] 121 | ): 122 | ff = openff.toolkit.ForceField() 123 | 124 | kcal = openff.units.unit.kilocalorie / openff.units.unit.mole 125 | ang = openff.units.unit.angstrom 126 | 
127 | eps_h, sig_h = 0.123, 1.234 128 | eps_o, sig_o = 0.234, 2.345 129 | eps_a, sig_a = 0.345, 3.456 130 | 131 | eps_hh, sig_hh = 0.3, 1.4 132 | eps_ah, sig_ah = 0.4, 2.5 133 | 134 | lj_handler = ff.get_parameter_handler("vdW") 135 | lj_handler.add_parameter( 136 | {"smirks": "[O:1]", "epsilon": eps_o * kcal, "sigma": sig_o * ang} 137 | ) 138 | lj_handler.add_parameter( 139 | {"smirks": "[H:1]", "epsilon": eps_h * kcal, "sigma": sig_h * ang} 140 | ) 141 | lj_handler.add_parameter( 142 | {"smirks": "[Ar:1]", "epsilon": eps_a * kcal, "sigma": sig_a * ang} 143 | ) 144 | 145 | system, tensor_ff = smee.tests.utils.system_from_smiles(["O", "[Ar]"], [1, 1], ff) 146 | system.is_periodic = False 147 | 148 | lj_potential = tensor_ff.potentials_by_type["vdW"] 149 | lj_potential.attributes[lj_potential.attribute_cols.index("scale_12")] = 1.0 150 | lj_potential.attributes[lj_potential.attribute_cols.index("scale_13")] = 1.0 151 | 152 | parameter_idx_h = _parameter_key_to_idx(lj_potential, "[H:1]") 153 | parameter_idx_a = _parameter_key_to_idx(lj_potential, "[Ar:1]") 154 | 155 | assert lj_potential.parameter_cols[0] == "epsilon" 156 | 157 | lj_potential.parameters = torch.vstack( 158 | [lj_potential.parameters, torch.tensor([[eps_ah, sig_ah], [eps_hh, sig_hh]])] 159 | ) 160 | lj_potential.parameter_keys = [*lj_potential.parameter_keys, "ar-h", "h-h"] 161 | 162 | lj_potential.exceptions = { 163 | (parameter_idx_a, parameter_idx_h): 3, 164 | (parameter_idx_h, parameter_idx_h): 4, 165 | } 166 | 167 | for top in system.topologies: 168 | assignment_matrix = smee.utils.zeros_like( 169 | (top.n_particles, len(lj_potential.parameters)), 170 | top.parameters["vdW"].assignment_matrix, 171 | ) 172 | assignment_matrix[:, :3] = top.parameters["vdW"].assignment_matrix.to_dense() 173 | 174 | top.parameters["vdW"].assignment_matrix = assignment_matrix.to_sparse() 175 | 176 | params = { 177 | "oo": LJParam(eps_o, sig_o), 178 | "hh": LJParam(eps_hh, sig_hh), 179 | "aa": LJParam(eps_a, sig_a), 
180 | "ah": LJParam(eps_ah, sig_ah), 181 | "oh": LJParam((eps_o * eps_h) ** 0.5, 0.5 * (sig_o + sig_h)), 182 | "oa": LJParam((eps_a * eps_o) ** 0.5, 0.5 * (sig_a + sig_o)), 183 | } 184 | 185 | for k, v in [*params.items()]: 186 | params[k[::-1]] = v 187 | 188 | return system, lj_potential, params 189 | 190 | 191 | def add_explicit_lb_exceptions( 192 | potential: smee.TensorPotential, system: smee.TensorSystem 193 | ): 194 | """Apply Lorentz-Berthelot mixing rules to a vdW potential and add the parameters 195 | as explicit exceptions. 196 | 197 | Notes: 198 | The potential and system will be modified in place. 199 | 200 | Args: 201 | potential: The potential to add exceptions to. 202 | system: The system whose assignment matrices need to be updated to account for 203 | the new exception parameters. 204 | """ 205 | assert potential.type == "vdW" 206 | 207 | n_params = len(potential.parameters) 208 | 209 | idxs_i, idxs_j = torch.triu_indices(n_params, n_params, 1) 210 | 211 | idxs_i = torch.cat([idxs_i, torch.arange(n_params)]) 212 | idxs_j = torch.cat([idxs_j, torch.arange(n_params)]) 213 | 214 | eps_col = potential.parameter_cols.index("epsilon") 215 | 216 | if "r_min" in potential.parameter_cols: 217 | sig_col = potential.parameter_cols.index("r_min") 218 | elif "sigma" in potential.parameter_cols: 219 | sig_col = potential.parameter_cols.index("sigma") 220 | else: 221 | raise ValueError("no sigma-like parameter found.") 222 | 223 | eps_i = potential.parameters[idxs_i, eps_col] 224 | sig_i = potential.parameters[idxs_i, sig_col] 225 | 226 | eps_j = potential.parameters[idxs_j, eps_col] 227 | sig_j = potential.parameters[idxs_j, sig_col] 228 | 229 | eps_ij, sig_ij = smee.potentials.nonbonded.lorentz_berthelot( 230 | eps_i, eps_j, sig_i, sig_j 231 | ) 232 | 233 | potential.parameters = torch.vstack( 234 | [torch.zeros_like(potential.parameters), torch.stack([eps_ij, sig_ij], dim=-1)] 235 | ) 236 | for i in range(len(eps_ij)): 237 | potential.parameter_keys.append( 
238 | openff.interchange.models.PotentialKey( 239 | id=f"exception-{i}", associated_handler="vdW" 240 | ) 241 | ) 242 | 243 | potential.exceptions = { 244 | (int(idx_i), int(idx_j)): i + n_params 245 | for i, (idx_i, idx_j) in enumerate(zip(idxs_i, idxs_j, strict=True)) 246 | } 247 | 248 | for top in system.topologies: 249 | parameter_map = top.parameters["vdW"] 250 | 251 | padding = smee.utils.zeros_like( 252 | (len(parameter_map.assignment_matrix), len(potential.exceptions)), 253 | parameter_map.assignment_matrix, 254 | ) 255 | padded = torch.concat( 256 | [parameter_map.assignment_matrix.to_dense(), padding], 257 | dim=-1, 258 | ) 259 | 260 | parameter_map.assignment_matrix = padded.to_sparse() 261 | -------------------------------------------------------------------------------- /smee/utils.py: -------------------------------------------------------------------------------- 1 | """General utility functions""" 2 | 3 | import typing 4 | 5 | import networkx 6 | import openff.toolkit 7 | import torch 8 | 9 | if typing.TYPE_CHECKING: 10 | import smee 11 | 12 | 13 | _size = int | torch.Size | list[int] | tuple[int, ...] 14 | 15 | ExclusionType = typing.Literal["scale_12", "scale_13", "scale_14", "scale_15"] 16 | 17 | 18 | EPSILON = 1.0e-6 19 | """A small epsilon value used to prevent divide by zero errors.""" 20 | 21 | 22 | def find_exclusions( 23 | topology: openff.toolkit.Topology, 24 | v_sites: typing.Optional["smee.VSiteMap"] = None, 25 | ) -> dict[tuple[int, int], ExclusionType]: 26 | """Find all excluded interaction pairs and their associated scaling factor. 27 | 28 | Args: 29 | topology: The topology to find the interaction pairs of. 30 | v_sites: Virtual sites that will be added to the topology. 31 | 32 | Returns: 33 | A dictionary of the form ``{(atom_idx_1, atom_idx_2): scale}``. 
34 | """ 35 | 36 | graph = networkx.from_edgelist( 37 | tuple( 38 | sorted((topology.atom_index(bond.atom1), topology.atom_index(bond.atom2))) 39 | ) 40 | for bond in topology.bonds 41 | ) 42 | 43 | if v_sites is not None: 44 | for v_site_key in v_sites.keys: 45 | v_site_idx = v_sites.key_to_idx[v_site_key] 46 | parent_idx = v_site_key.orientation_atom_indices[0] 47 | 48 | for neighbour_idx in graph.neighbors(parent_idx): 49 | graph.add_edge(v_site_idx, neighbour_idx) 50 | 51 | graph.add_edge(v_site_idx, parent_idx) 52 | 53 | distances = dict(networkx.all_pairs_shortest_path_length(graph, cutoff=5)) 54 | distance_to_scale = {1: "scale_12", 2: "scale_13", 3: "scale_14", 4: "scale_15"} 55 | 56 | exclusions = {} 57 | 58 | for idx_a in distances: 59 | for idx_b, distance in distances[idx_a].items(): 60 | pair = tuple(sorted((idx_a, idx_b))) 61 | scale = distance_to_scale.get(distance) 62 | 63 | if scale is None: 64 | continue 65 | 66 | assert pair not in exclusions or exclusions[pair] == scale 67 | exclusions[pair] = scale 68 | 69 | return exclusions 70 | 71 | 72 | def ones_like(size: _size, other: torch.Tensor) -> torch.Tensor: 73 | """Create a tensor of ones with the same device and type as another tensor.""" 74 | return torch.ones(size, dtype=other.dtype, device=other.device) 75 | 76 | 77 | def zeros_like(size: _size, other: torch.Tensor) -> torch.Tensor: 78 | """Create a tensor of zeros with the same device and type as another tensor.""" 79 | return torch.zeros(size, dtype=other.dtype, device=other.device) 80 | 81 | 82 | def tensor_like(data: typing.Any, other: torch.Tensor) -> torch.Tensor: 83 | """Create a tensor with the same device and type as another tensor.""" 84 | 85 | if isinstance(data, torch.Tensor): 86 | return data.clone().detach().to(other.device, other.dtype) 87 | 88 | return torch.tensor(data, dtype=other.dtype, device=other.device) 89 | 90 | 91 | def arange_like(end: int, other: torch.Tensor) -> torch.Tensor: 92 | """Arange a tensor with the same 
device and type as another tensor.""" 93 | return torch.arange(end, dtype=other.dtype, device=other.device) 94 | 95 | 96 | def logsumexp( 97 | a: torch.Tensor, 98 | dim: int, 99 | keepdim: bool = False, 100 | b: torch.Tensor | None = None, 101 | return_sign: bool = False, 102 | ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: 103 | """Compute the log of the sum of the exponential of the input elements, optionally 104 | with each element multiplied by a scaling factor. 105 | 106 | Notes: 107 | This should be removed if torch.logsumexp is updated to support scaling factors. 108 | 109 | Args: 110 | a: The elements that should be summed over. 111 | dim: The dimension to sum over. 112 | keepdim: Whether to keep the summed dimension. 113 | b: The scaling factor to multiply each element by. 114 | 115 | Returns: 116 | The log of the sum of exponential of the a elements. 117 | """ 118 | a_type = a.dtype 119 | 120 | if b is None: 121 | assert return_sign is False 122 | return torch.logsumexp(a, dim, keepdim) 123 | 124 | a = a.double() 125 | b = b if b is not None else b.double() 126 | 127 | a, b = torch.broadcast_tensors(a, b) 128 | 129 | if torch.any(b == 0): 130 | a[b == 0] = -torch.inf 131 | 132 | a_max = torch.amax(a, dim=dim, keepdim=True) 133 | 134 | if a_max.ndim > 0: 135 | a_max[~torch.isfinite(a_max)] = 0 136 | elif not torch.isfinite(a_max): 137 | a_max = 0 138 | 139 | exp_sum = torch.sum(b * torch.exp(a - a_max), dim=dim, keepdim=keepdim) 140 | sign = None 141 | 142 | if return_sign: 143 | sign = torch.sign(exp_sum) 144 | exp_sum = exp_sum * sign 145 | 146 | ln_exp_sum = torch.log(exp_sum) 147 | 148 | if not keepdim: 149 | a_max = torch.squeeze(a_max, dim=dim) 150 | 151 | ln_exp_sum += a_max 152 | ln_exp_sum = ln_exp_sum.to(a_type) 153 | 154 | if return_sign: 155 | return ln_exp_sum, sign.to(a_type) 156 | else: 157 | return ln_exp_sum 158 | 159 | 160 | def to_upper_tri_idx( 161 | i: torch.Tensor, j: torch.Tensor, n: int, include_diag: bool = False 162 | ) -> 
torch.Tensor: 163 | """Converts pairs of 2D indices to 1D indices in an upper triangular matrix that 164 | excludes the diagonal. 165 | 166 | Args: 167 | i: A tensor of the indices along the first axis with ``shape=(n_pairs,)``. 168 | j: A tensor of the indices along the second axis with ``shape=(n_pairs,)``. 169 | n: The size of the matrix. 170 | include_diag: Whether the diagonal is included in the upper triangular matrix. 171 | 172 | Returns: 173 | A tensor of the indices in the upper triangular matrix with 174 | ``shape=(n_pairs * (n_pairs - 1) // 2,)``. 175 | """ 176 | 177 | if not include_diag: 178 | assert (i < j).all(), "i must be less than j" 179 | return (i * (2 * n - i - 1)) // 2 + j - i - 1 180 | 181 | assert (i <= j).all(), "i must be less than or equal to j" 182 | return (i * (2 * n - i + 1)) // 2 + j - i 183 | 184 | 185 | class _SafeGeometricMean(torch.autograd.Function): 186 | generate_vmap_rule = True 187 | 188 | @staticmethod 189 | def forward(eps_a, eps_b): 190 | return torch.sqrt(eps_a * eps_b) 191 | 192 | @staticmethod 193 | def setup_context(ctx, inputs, output): 194 | eps_a, eps_b = inputs 195 | eps = output 196 | ctx.save_for_backward(eps_a, eps_b, eps) 197 | ctx.save_for_forward(eps_a, eps_b, eps) 198 | 199 | @staticmethod 200 | def backward(ctx, grad_output): 201 | eps_a, eps_b, eps = ctx.saved_tensors 202 | 203 | eps = torch.where(eps == 0.0, EPSILON, eps) 204 | 205 | grad_eps_a = grad_output * eps_b / (2 * eps) 206 | grad_eps_b = grad_output * eps_a / (2 * eps) 207 | 208 | return grad_eps_a, grad_eps_b 209 | 210 | @staticmethod 211 | def jvp(ctx, *grad_inputs): 212 | eps_a, eps_b, eps = ctx.saved_tensors 213 | 214 | eps = torch.where(eps == 0.0, EPSILON, eps) 215 | 216 | grad_eps_a = grad_inputs[0] * eps_b / (2 * eps) 217 | grad_eps_b = grad_inputs[1] * eps_a / (2 * eps) 218 | 219 | return grad_eps_a + grad_eps_b 220 | 221 | 222 | def geometric_mean(eps_a: torch.Tensor, eps_b: torch.Tensor) -> torch.Tensor: 223 | """Computes the 
geometric mean of two values 'safely'. 224 | 225 | A small epsilon (``smee.utils.EPSILON``) is added when computing the gradient in 226 | cases where the mean is zero to prevent divide by zero errors. 227 | 228 | Args: 229 | eps_a: The first value. 230 | eps_b: The second value. 231 | 232 | Returns: 233 | The geometric mean of the two values. 234 | """ 235 | 236 | return _SafeGeometricMean.apply(eps_a, eps_b) 237 | --------------------------------------------------------------------------------