├── .github └── workflows │ ├── build-wheels.yml │ ├── cxx-tests.yml │ ├── docs.yml │ └── python-tests.yml ├── .gitignore ├── CMakeLists.txt ├── LICENSE ├── README.md ├── benchmarks ├── benchmark.py └── carbon.xyz ├── create-single-cpp.py ├── docs ├── Doxyfile └── src │ ├── benchmark.png │ ├── benchmarks.rst │ ├── c-api.rst │ ├── conf.py │ ├── index.rst │ ├── metatomic.rst │ ├── python-api.rst │ ├── static │ └── images │ │ ├── Arpitan.png │ │ ├── Catalan.png │ │ ├── Lombardy.png │ │ └── Occitan.png │ └── torch-api.rst ├── python ├── vesin │ ├── MANIFEST.in │ ├── README.md │ ├── VERSION │ ├── pyproject.toml │ ├── setup.py │ ├── tests │ │ ├── data │ │ │ ├── Cd2I4O12.xyz │ │ │ ├── carbon.xyz │ │ │ ├── diamond.xyz │ │ │ ├── naphthalene.xyz │ │ │ ├── readme.txt │ │ │ └── water.xyz │ │ ├── test_metatomic.py │ │ └── test_neighbors.py │ └── vesin │ │ ├── __init__.py │ │ ├── _ase.py │ │ ├── _c_api.py │ │ ├── _c_lib.py │ │ ├── _neighbors.py │ │ └── metatomic │ │ ├── __init__.py │ │ ├── _model.py │ │ └── _neighbors.py └── vesin_torch │ ├── MANIFEST.in │ ├── README.md │ ├── VERSION │ ├── build-backend │ └── backend.py │ ├── pyproject.toml │ ├── setup.py │ ├── tests │ ├── test_autograd.py │ ├── test_metatensor.py │ ├── test_metatomic.py │ └── test_neighbors.py │ └── vesin │ └── torch │ ├── __init__.py │ ├── _c_lib.py │ ├── _neighbors.py │ └── metatensor │ ├── __init__.py │ ├── _model.py │ └── _neighbors.py ├── ruff.toml ├── scripts ├── clean-python.sh ├── create-torch-versions-range.py └── pytest-dont-rewrite-torch.py ├── setup.py ├── tox.ini └── vesin ├── CMakeLists.txt ├── VERSION ├── include └── vesin.h ├── src ├── cpu_cell_list.cpp ├── cpu_cell_list.hpp ├── math.hpp ├── types.hpp └── vesin.cpp ├── tests ├── CMakeLists.txt ├── memory.cpp └── neighbors.cpp └── torch ├── CMakeLists.txt ├── include └── vesin_torch.hpp └── src └── vesin_torch.cpp /.github/workflows/build-wheels.yml: -------------------------------------------------------------------------------- 1 | name: Build 
Python wheels 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | tags: ["*"] 7 | pull_request: 8 | paths: 9 | # build wheels in PR if this file changed 10 | - '.github/workflows/build-wheels.yml' 11 | # build wheels in PR if any of the build system files changed 12 | - '**/VERSION' 13 | - '**/setup.py' 14 | - '**/pyproject.toml' 15 | - '**/MANIFEST.in' 16 | - '**/CMakeLists.txt' 17 | schedule: 18 | # check the build once a week on mondays 19 | - cron: '0 10 * * 1' 20 | 21 | concurrency: 22 | group: wheels-${{ github.ref }} 23 | cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} 24 | 25 | jobs: 26 | build-wheels: 27 | runs-on: ${{ matrix.os }} 28 | name: ${{ matrix.name }} 29 | strategy: 30 | matrix: 31 | include: 32 | - name: x86_64 Linux 33 | os: ubuntu-22.04 34 | cibw_arch: x86_64 35 | - name: arm64 Linux 36 | os: ubuntu-22.04-arm 37 | cibw_arch: aarch64 38 | - name: x86_64 macOS 39 | os: macos-13 40 | cibw_arch: x86_64 41 | - name: M1 macOS 42 | os: macos-14 43 | cibw_arch: arm64 44 | - name: x86_64 Windows 45 | os: windows-2022 46 | cibw_arch: AMD64 47 | steps: 48 | - uses: actions/checkout@v4 49 | with: 50 | fetch-depth: 0 51 | 52 | - name: Set up Python 53 | uses: actions/setup-python@v5 54 | with: 55 | python-version: "3.12" 56 | 57 | - name: install dependencies 58 | run: python -m pip install cibuildwheel twine 59 | 60 | - name: build wheel 61 | run: python -m cibuildwheel python/vesin 62 | env: 63 | CIBW_BUILD: cp312-* 64 | CIBW_SKIP: "*musllinux*" 65 | CIBW_ARCHS: ${{ matrix.cibw_arch }} 66 | CIBW_BUILD_FRONTEND: build 67 | CIBW_MANYLINUX_X86_64_IMAGE: quay.io/pypa/manylinux2014_x86_64 68 | CIBW_MANYLINUX_AARCH64_IMAGE: quay.io/pypa/manylinux2014_aarch64 69 | 70 | - name: check wheels with twine 71 | run: twine check wheelhouse/* 72 | 73 | - uses: actions/upload-artifact@v4 74 | with: 75 | name: wheel-${{ matrix.os }}-${{ matrix.cibw_arch }} 76 | path: ./wheelhouse/*.whl 77 | 78 | build-torch-wheels: 79 | runs-on: ${{ matrix.os }} 80 | name: ${{ 
matrix.name }} (torch v${{ matrix.torch-version }}) 81 | strategy: 82 | matrix: 83 | torch-version: ['2.3', '2.4', '2.5', '2.6', '2.7'] 84 | arch: ['arm64', 'x86_64'] 85 | os: ['ubuntu-22.04', 'ubuntu-22.04-arm', 'macos-14', 'windows-2022'] 86 | exclude: 87 | # remove mismatched architectures 88 | - {os: macos-14, arch: x86_64} 89 | - {os: ubuntu-22.04, arch: arm64} 90 | - {os: ubuntu-22.04-arm, arch: x86_64} 91 | - {os: windows-2022, arch: arm64} 92 | include: 93 | # add `cibw-arch` to the different configurations 94 | - name: x86_64 Linux 95 | os: ubuntu-22.04 96 | arch: x86_64 97 | cibw-arch: x86_64 98 | - name: arm64 Linux 99 | os: ubuntu-22.04-arm 100 | arch: arm64 101 | cibw-arch: aarch64 102 | - name: arm64 macOS 103 | os: macos-14 104 | arch: arm64 105 | cibw-arch: arm64 106 | - name: x86_64 Windows 107 | os: windows-2022 108 | arch: x86_64 109 | cibw-arch: AMD64 110 | # add the right python version for each torch version 111 | - {torch-version: '2.3', python-version: '3.12', cibw-python: 'cp312-*'} 112 | - {torch-version: '2.4', python-version: '3.12', cibw-python: 'cp312-*'} 113 | - {torch-version: '2.5', python-version: '3.12', cibw-python: 'cp312-*'} 114 | - {torch-version: '2.6', python-version: '3.12', cibw-python: 'cp312-*'} 115 | - {torch-version: '2.7', python-version: '3.12', cibw-python: 'cp312-*'} 116 | steps: 117 | - uses: actions/checkout@v4 118 | with: 119 | fetch-depth: 0 120 | 121 | - name: Set up Python 122 | uses: actions/setup-python@v5 123 | with: 124 | python-version: ${{ matrix.python-version }} 125 | 126 | - name: install dependencies 127 | run: python -m pip install cibuildwheel 128 | 129 | - name: build vesin-torch wheel 130 | run: python -m cibuildwheel python/vesin_torch 131 | env: 132 | CIBW_BUILD: ${{ matrix.cibw-python}} 133 | CIBW_SKIP: "*musllinux*" 134 | CIBW_ARCHS: ${{ matrix.cibw-arch }} 135 | CIBW_BUILD_VERBOSITY: 1 136 | CIBW_MANYLINUX_X86_64_IMAGE: quay.io/pypa/manylinux_2_28_x86_64 137 | CIBW_MANYLINUX_AARCH64_IMAGE: 
quay.io/pypa/manylinux_2_28_aarch64 138 | CIBW_ENVIRONMENT: > 139 | VESIN_TORCH_BUILD_WITH_TORCH_VERSION=${{ matrix.torch-version }}.* 140 | PIP_EXTRA_INDEX_URL=https://download.pytorch.org/whl/cpu 141 | MACOSX_DEPLOYMENT_TARGET=11 142 | # do not complain for missing libtorch.so 143 | CIBW_REPAIR_WHEEL_COMMAND_MACOS: | 144 | delocate-wheel --ignore-missing-dependencies --require-archs {delocate_archs} -w {dest_dir} -v {wheel} 145 | CIBW_REPAIR_WHEEL_COMMAND_LINUX: | 146 | auditwheel repair --exclude libtorch.so --exclude libtorch_cpu.so --exclude libc10.so -w {dest_dir} {wheel} 147 | 148 | - uses: actions/upload-artifact@v4 149 | with: 150 | name: torch-single-version-wheel-${{ matrix.torch-version }}-${{ matrix.os }}-${{ matrix.arch }} 151 | path: ./wheelhouse/*.whl 152 | 153 | merge-torch-wheels: 154 | needs: build-torch-wheels 155 | runs-on: ubuntu-22.04 156 | name: merge vesin-torch ${{ matrix.name }} 157 | strategy: 158 | matrix: 159 | include: 160 | - name: x86_64 Linux 161 | os: ubuntu-22.04 162 | arch: x86_64 163 | - name: arm64 Linux 164 | os: ubuntu-22.04-arm 165 | arch: arm64 166 | - name: arm64 macOS 167 | os: macos-14 168 | arch: arm64 169 | - name: x86_64 Windows 170 | os: windows-2022 171 | arch: x86_64 172 | steps: 173 | - uses: actions/checkout@v4 174 | 175 | - name: Download wheels 176 | uses: actions/download-artifact@v4 177 | with: 178 | pattern: torch-single-version-wheel-*-${{ matrix.os }}-${{ matrix.arch }} 179 | merge-multiple: false 180 | path: dist 181 | 182 | - name: Set up Python 183 | uses: actions/setup-python@v5 184 | with: 185 | python-version: "3.12" 186 | 187 | - name: install dependencies 188 | run: python -m pip install twine wheel 189 | 190 | - name: merge wheels 191 | run: | 192 | # collect all torch versions used for the build 193 | REQUIRES_TORCH=$(find dist -name "*.whl" -exec unzip -p {} "vesin_torch-*.dist-info/METADATA" \; | grep "Requires-Dist: torch") 194 | MERGED_TORCH_REQUIRE=$(python 
scripts/create-torch-versions-range.py "$REQUIRES_TORCH") 195 | 196 | echo MERGED_TORCH_REQUIRE=$MERGED_TORCH_REQUIRE 197 | 198 | # unpack all single torch versions wheels in the same directory 199 | mkdir dist/unpacked 200 | find dist -name "*.whl" -print -exec python -m wheel unpack --dest dist/unpacked/ {} ';' 201 | 202 | sed -i "s/Requires-Dist: torch.*/$MERGED_TORCH_REQUIRE/" dist/unpacked/vesin_torch-*/vesin_torch-*.dist-info/METADATA 203 | 204 | echo "\n\n METADATA = \n\n" 205 | cat dist/unpacked/vesin_torch-*/vesin_torch-*.dist-info/METADATA 206 | 207 | # check the right metadata was added to the file. grep will exit with 208 | # code `1` if the line is not found, which will stop CI 209 | grep "$MERGED_TORCH_REQUIRE" dist/unpacked/vesin_torch-*/vesin_torch-*.dist-info/METADATA 210 | 211 | # repack the directory as a new wheel 212 | mkdir wheelhouse 213 | python -m wheel pack --dest wheelhouse/ dist/unpacked/* 214 | 215 | - name: check wheels with twine 216 | run: twine check wheelhouse/* 217 | 218 | - uses: actions/upload-artifact@v4 219 | with: 220 | name: torch-wheel-${{ matrix.os }}-${{ matrix.arch }} 221 | path: ./wheelhouse/*.whl 222 | 223 | build-sdist: 224 | runs-on: ubuntu-22.04 225 | name: sdist 226 | steps: 227 | - uses: actions/checkout@v4 228 | with: 229 | fetch-depth: 0 230 | 231 | - name: Set up Python 232 | uses: actions/setup-python@v5 233 | with: 234 | python-version: "3.12" 235 | 236 | - name: build sdist 237 | run: | 238 | pip install build 239 | python -m build --sdist python/vesin --outdir dist 240 | python -m build --sdist python/vesin_torch --outdir dist 241 | 242 | - uses: actions/upload-artifact@v4 243 | with: 244 | name: sdist 245 | path: dist/*.tar.gz 246 | 247 | merge-and-release: 248 | name: Merge and release wheels/sdists 249 | needs: [build-wheels, merge-torch-wheels, build-sdist] 250 | runs-on: ubuntu-22.04 251 | permissions: 252 | contents: write 253 | steps: 254 | - name: Download wheels 255 | uses: 
actions/download-artifact@v4 256 | with: 257 | path: dist 258 | pattern: wheel-* 259 | merge-multiple: true 260 | 261 | - name: Download metatensor-torch wheels 262 | uses: actions/download-artifact@v4 263 | with: 264 | path: dist 265 | pattern: torch-wheel-* 266 | merge-multiple: true 267 | 268 | - name: Download sdist 269 | uses: actions/download-artifact@v4 270 | with: 271 | path: dist 272 | name: sdist 273 | 274 | - name: Re-upload a single wheels artifact 275 | uses: actions/upload-artifact@v4 276 | with: 277 | name: wheels 278 | path: | 279 | dist/* 280 | 281 | - name: upload to GitHub release 282 | if: startsWith(github.ref, 'refs/tags/') 283 | uses: softprops/action-gh-release@v2 284 | with: 285 | files: | 286 | dist/* 287 | prerelease: ${{ contains(github.ref, '-rc') }} 288 | env: 289 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 290 | -------------------------------------------------------------------------------- /.github/workflows/cxx-tests.yml: -------------------------------------------------------------------------------- 1 | name: C++ tests 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | # Check all PR 8 | 9 | concurrency: 10 | group: cxx-tests-${{ github.ref }} 11 | cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} 12 | 13 | jobs: 14 | rust-tests: 15 | runs-on: ${{ matrix.os }} 16 | name: ${{ matrix.os }} / ${{ matrix.compiler }} 17 | container: ${{ matrix.container }} 18 | strategy: 19 | matrix: 20 | include: 21 | - os: ubuntu-22.04 22 | compiler: GCC - Valgrind 23 | cc: gcc-13 24 | cxx: g++-13 25 | do-valgrind: true 26 | setup-dependencies: | 27 | sudo add-apt-repository ppa:ubuntu-toolchain-r/test 28 | sudo apt-get update 29 | sudo apt-get install -y gcc-13 g++-13 30 | 31 | - os: ubuntu-22.04 32 | compiler: Clang 33 | cc: clang-18 34 | cxx: clang++-18 35 | setup-dependencies: | 36 | wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add - 37 | sudo apt-get update 38 | sudo apt-add-repository "deb 
http://apt.llvm.org/jammy/ llvm-toolchain-jammy-18 main" # ubuntu-22.04 runners are Ubuntu jammy, so use the matching LLVM repository 39 | sudo apt-get install -y clang-18 40 | 41 | # check the build on a stock Ubuntu 20.04, which uses cmake 3.16 42 | - os: ubuntu-22.04 43 | compiler: GCC 44 | cc: gcc 45 | cxx: g++ 46 | container: ubuntu:20.04 47 | setup-dependencies: | 48 | apt update 49 | apt install -y software-properties-common 50 | apt install -y cmake make gcc g++ git curl 51 | 52 | - os: macos-14 53 | compiler: Clang 54 | cc: clang 55 | cxx: clang++ 56 | 57 | - os: windows-2022 58 | compiler: MSVC 59 | cc: cl.exe 60 | cxx: cl.exe 61 | cmake-extra-args: 62 | - -G "Visual Studio 17 2022" -A x64 63 | 64 | - os: windows-2022 65 | compiler: MinGW 66 | cc: gcc.exe 67 | cxx: g++.exe 68 | cmake-extra-args: 69 | - -G "MinGW Makefiles" 70 | 71 | steps: 72 | - name: setup dependencies 73 | run: ${{ matrix.setup-dependencies }} 74 | 75 | - uses: actions/checkout@v4 76 | with: 77 | fetch-depth: 0 78 | 79 | - name: install valgrind 80 | if: matrix.do-valgrind 81 | run: | 82 | sudo apt-get update 83 | sudo apt-get install -y valgrind 84 | 85 | - name: configure cmake 86 | shell: bash 87 | run: | 88 | mkdir build && cd build 89 | cmake ${{ join(matrix.cmake-extra-args, ' ') }} \ 90 | -DCMAKE_BUILD_TYPE=Debug \ 91 | -DCMAKE_C_COMPILER=${{ matrix.cc }} \ 92 | -DCMAKE_CXX_COMPILER=${{ matrix.cxx }} \ 93 | -DVESIN_BUILD_TESTS=ON \ 94 | -DCMAKE_VERBOSE_MAKEFILE=ON \ 95 | ../vesin 96 | 97 | - name: build 98 | run: | 99 | cd build 100 | cmake --build .
--config Debug --parallel 2 101 | 102 | - name: run tests 103 | run: | 104 | cd build 105 | ctest --output-on-failure --build-config Debug --parallel 2 106 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: Documentation 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | tags: ["*"] 7 | pull_request: 8 | # Check all PR 9 | 10 | jobs: 11 | build-and-publish: 12 | permissions: 13 | contents: write 14 | runs-on: ubuntu-22.04 15 | steps: 16 | - uses: actions/checkout@v4 17 | 18 | - name: setup Python 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version: "3.12" 22 | 23 | - name: install dependencies 24 | run: | 25 | python -m pip install tox 26 | sudo apt install doxygen 27 | 28 | - name: build documentation 29 | run: tox -e docs 30 | env: 31 | # Use the CPU only version of torch when building/running the code 32 | PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/cpu 33 | 34 | - name: put documentation in the website 35 | run: | 36 | git clone https://github.com/$GITHUB_REPOSITORY --branch gh-pages gh-pages 37 | rm -rf gh-pages/.git 38 | cd gh-pages 39 | 40 | REF_KIND=$(echo $GITHUB_REF | cut -d / -f2) 41 | if [[ "$REF_KIND" == "tags" ]]; then 42 | TAG=${GITHUB_REF#refs/tags/} 43 | mv ../docs/build/html $TAG 44 | else 45 | rm -rf latest 46 | mv ../docs/build/html latest 47 | fi 48 | 49 | - name: deploy to gh-pages 50 | if: github.event_name == 'push' 51 | uses: peaceiris/actions-gh-pages@v4 52 | with: 53 | github_token: ${{ secrets.GITHUB_TOKEN }} 54 | publish_dir: ./gh-pages/ 55 | force_orphan: true 56 | -------------------------------------------------------------------------------- /.github/workflows/python-tests.yml: -------------------------------------------------------------------------------- 1 | name: Python tests 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | # Check all PR 8 | 9 
| concurrency: 10 | group: python-tests-${{ github.ref }} 11 | cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} 12 | 13 | jobs: 14 | python-tests: 15 | runs-on: ${{ matrix.os }} 16 | name: ${{ matrix.os }} / Python ${{ matrix.python-version }} 17 | strategy: 18 | matrix: 19 | include: 20 | - os: ubuntu-22.04 21 | python-version: "3.9" 22 | - os: ubuntu-22.04 23 | python-version: "3.12" 24 | - os: macos-14 25 | python-version: "3.12" 26 | - os: windows-2022 27 | python-version: "3.12" 28 | steps: 29 | - uses: actions/checkout@v4 30 | with: 31 | fetch-depth: 0 32 | 33 | - name: setup Python 34 | uses: actions/setup-python@v5 35 | with: 36 | python-version: ${{ matrix.python-version }} 37 | 38 | - name: install tests dependencies 39 | run: | 40 | python -m pip install --upgrade pip 41 | python -m pip install tox 42 | 43 | - name: run tests 44 | run: tox 45 | env: 46 | # Use the CPU only version of torch when building/running the code 47 | PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/cpu 48 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | dist/ 3 | .cache/ 4 | .tox/ 5 | __pycache__/ 6 | 7 | *.egg-info 8 | 9 | vesin-single-build.cpp 10 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # This is not the actual CMakeLists.txt for this project, see 2 | # `vesin/CMakeLists.txt` for it. Instead, this file is here to enable `cmake ..` 3 | # from a git checkout, `add_subdirectory` and `FetchContent` without having to 4 | # specify a subdirectory. 
5 | 6 | cmake_minimum_required(VERSION 3.16) 7 | 8 | project(vesin-git LANGUAGES NONE) 9 | 10 | if (${CMAKE_SOURCE_DIR} STREQUAL ${CMAKE_CURRENT_SOURCE_DIR}) 11 | set(VESIN_MAIN_PROJECT ON) 12 | else() 13 | set(VESIN_MAIN_PROJECT OFF) 14 | endif() 15 | 16 | if (VESIN_MAIN_PROJECT) 17 | if("${CMAKE_BUILD_TYPE}" STREQUAL "" AND "${CMAKE_CONFIGURATION_TYPES}" STREQUAL "") 18 | message(STATUS "Setting build type to 'release' as none was specified.") 19 | set( 20 | CMAKE_BUILD_TYPE "release" 21 | CACHE STRING 22 | "Choose the type of build, options are: none(CMAKE_CXX_FLAGS or CMAKE_C_FLAGS used) debug release relwithdebinfo minsizerel." 23 | FORCE 24 | ) 25 | set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS release debug relwithdebinfo minsizerel none) 26 | endif() 27 | endif() 28 | 29 | set(VESIN_INSTALL ${VESIN_MAIN_PROJECT} CACHE BOOL "Install Vesin's headers and libraries") 30 | add_subdirectory(vesin) 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2024, Guillaume Fraux 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 
18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Vesin: fast neighbor lists for atomistic systems 2 | 3 | [![Documentation](https://img.shields.io/badge/docs-latest-brightgreen.svg)](http://luthaf.fr/vesin/) 4 | ![Tests](https://img.shields.io/github/check-runs/Luthaf/vesin/main?logo=github&label=tests) 5 | 6 | | English 🇺🇸⁠/⁠🇬🇧 | Occitan | French 🇫🇷 | Arpitan | Gallo‑Italic | Catalan | Spanish 🇪🇸 | Italian 🇮🇹 | 7 | |------------------|----------|-----------|----------|--------------|---------|------------|------------| 8 | | neighbo(u)r | vesin | voisin | vesin | visin | veí | vecino | vicino | 9 | 10 | 11 | Vesin is a fast and easy to use library computing neighbor lists for atomistic 12 | system. We provide an interface for the following programing languages: 13 | 14 | - C (also compatible with C++). 
The project can be installed and used as a 15 | library with your own build system, or included as a single file and built 16 | directly by your own build system; 17 | - Python; 18 | - TorchScript, with both a C++ and Python interface; 19 | 20 | ### Installation 21 | 22 | To use the code from Python, you can install it with `pip`: 23 | 24 | ``` 25 | pip install vesin 26 | ``` 27 | 28 | See the [documentation](https://luthaf.fr/vesin/latest/index.html#installation) 29 | for more information on how to install the code to use it from C or C++. 30 | 31 | ### Usage instructions 32 | 33 | You can either use the `NeighborList` calculator class: 34 | 35 | ```py 36 | import numpy as np 37 | from vesin import NeighborList 38 | 39 | # positions can be anything compatible with numpy's ndarray 40 | positions = [ 41 | (0, 0, 0), 42 | (0, 1.3, 1.3), 43 | ] 44 | box = 3.2 * np.eye(3) 45 | 46 | calculator = NeighborList(cutoff=4.2, full_list=True) 47 | i, j, S, d = calculator.compute( 48 | points=positions, 49 | box=box, 50 | periodic=True, 51 | quantities="ijSd" 52 | ) 53 | ``` 54 | 55 | We also provide a function with drop-in compatibility with ASE's neighbor list: 56 | 57 | ```py 58 | import ase 59 | from vesin import ase_neighbor_list 60 | 61 | atoms = ase.Atoms(...) 62 | 63 | i, j, S, d = ase_neighbor_list("ijSd", atoms, cutoff=4.2) 64 | ``` 65 | 66 | See the [documentation](https://luthaf.fr/vesin/latest/c-api.html) for more 67 | information on how to use the code from C or C++. 68 | 69 | ### Benchmarks 70 | 71 | Below are benchmark results for computing neighbor lists on increasingly 72 | large diamond supercells, using an AMD 3955WX CPU and an NVIDIA 4070 Ti SUPER 73 | GPU. You can run this benchmark on your system with the script at 74 | `benchmarks/benchmark.py`. Missing points indicate that a specific code could 75 | not run the calculation (for example, NNPOps requires the cell to be twice the 76 | cutoff in size, and can't run with large cutoffs and small cells).
77 | 78 | ![Benchmarks](./docs/src/benchmark.png) 79 | 80 | ## License 81 | 82 | Vesin is is distributed under the [3 clauses BSD license](LICENSE). By 83 | contributing to this code, you agree to distribute your contributions under the 84 | same license. 85 | -------------------------------------------------------------------------------- /benchmarks/benchmark.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import ase.build 4 | import ase.neighborlist 5 | import matscipy.neighbours 6 | import NNPOps.neighbors 7 | import numpy as np 8 | import pymatgen.core 9 | import torch 10 | import torch_nl 11 | 12 | import vesin 13 | 14 | 15 | def benchmark(setup, function, atoms, cutoff): 16 | args = setup(atoms, cutoff) 17 | 18 | n_warm = 5 19 | start = time.time() 20 | for _ in range(n_warm): 21 | function(*args) 22 | end = time.time() 23 | 24 | warmup = (end - start) / n_warm 25 | 26 | # dynamically pick the number of iterations to keep timing below 1s per test, while 27 | # also ensuring at least 10 repetitions 28 | n_iter = int(1.0 / warmup) 29 | if n_iter > 10000: 30 | n_iter = 10000 31 | elif n_iter < 10: 32 | n_iter = 10 33 | 34 | start = time.time() 35 | for _ in range(n_iter): 36 | function(*args) 37 | end = time.time() 38 | 39 | return (end - start) / n_iter 40 | 41 | 42 | def setup_torch_nl_cpu(atoms, cutoff): 43 | pos, cell, pbc, batch, n_atoms = torch_nl.ase2data( 44 | [atoms], device=torch.device("cpu") 45 | ) 46 | return cutoff, pos, cell, pbc, batch 47 | 48 | 49 | def setup_torch_nl_cuda(atoms, cutoff): 50 | pos, cell, pbc, batch, n_atoms = torch_nl.ase2data( 51 | [atoms], device=torch.device("cuda") 52 | ) 53 | return cutoff, pos, cell, pbc, batch 54 | 55 | 56 | def torch_nl_run(cutoff, pos, cell, pbc, batch): 57 | return torch_nl.compute_neighborlist( 58 | cutoff, pos, cell, pbc, batch, self_interaction=True 59 | ) 60 | 61 | 62 | def setup_nnpops_cpu(atoms, cutoff): 63 | positions = 
torch.tensor(atoms.positions) 64 | box_vector = torch.tensor(atoms.cell) 65 | return positions, cutoff, box_vector 66 | 67 | 68 | def setup_nnpops_cuda(atoms, cutoff): 69 | positions = torch.tensor(atoms.positions).to("cuda") 70 | box_vector = torch.tensor(atoms.cell).to("cuda") 71 | return positions, cutoff, box_vector 72 | 73 | 74 | def nnpops_run(positions, cutoff, box_vectors): 75 | return NNPOps.neighbors.getNeighborPairs( 76 | positions, cutoff=cutoff, box_vectors=box_vectors 77 | ) 78 | 79 | 80 | def setup_ase_like(atoms, cutoff): 81 | return "ijSd", atoms, float(cutoff) 82 | 83 | 84 | def setup_pymatgen(atoms, cutoff): 85 | structure = pymatgen.core.Structure( 86 | atoms.cell[:], 87 | atoms.numbers, 88 | atoms.positions, 89 | coords_are_cartesian=True, 90 | ) 91 | return structure, cutoff 92 | 93 | 94 | def pymatgen_run(structure, cutoff): 95 | return structure.get_neighbor_list(cutoff) 96 | 97 | 98 | def determine_super_cell(max_cell_repeat, max_log_size_delta, max_cell_ratio): 99 | """ 100 | Determine which super cells to include. We want equally spaced number of atoms in 101 | log scale, and cells that are not too anisotropic. 
102 | """ 103 | sizes = {} 104 | for kx in range(1, max_cell_repeat): 105 | for ky in range(1, max_cell_repeat): 106 | for kz in range(1, max_cell_repeat): 107 | size = kx * ky * kz 108 | if size in sizes: 109 | sizes[size].append((kx, ky, kz)) 110 | else: 111 | sizes[size] = [(kx, ky, kz)] 112 | 113 | # for each size, pick the less anisotropic cell 114 | repeats = [] 115 | a, b, c = atoms.cell.lengths() 116 | for candidates in sizes.values(): 117 | best = None 118 | best_ratio = np.inf 119 | for kx, ky, kz in candidates: 120 | lengths = [kx * a, ky * b, kz * c] 121 | ratio = np.max(lengths) / np.min(lengths) 122 | if ratio < best_ratio: 123 | best = (kx, ky, kz) 124 | best_ratio = ratio 125 | 126 | repeats.append(best) 127 | 128 | filtered_repeats = [] 129 | filtered_log_sizes = [-1] 130 | 131 | for kx, ky, kz in repeats: 132 | log_size = np.log(kx * ky * kz) 133 | lengths = [kx * a, ky * b, kz * c] 134 | ratio = np.max(lengths) / np.min(lengths) 135 | if np.min(np.abs(np.array(filtered_log_sizes) - log_size)) > max_log_size_delta: 136 | if log_size < 2 or ratio < max_cell_ratio: 137 | filtered_repeats.append((kx, ky, kz)) 138 | filtered_log_sizes.append(log_size) 139 | 140 | return filtered_repeats 141 | 142 | 143 | atoms = ase.build.bulk("C", "diamond", 3.567, orthorhombic=True) 144 | 145 | repeats = determine_super_cell( 146 | max_cell_repeat=20, max_log_size_delta=0.1, max_cell_ratio=3 147 | ) 148 | 149 | 150 | n_atoms = {} 151 | ase_time = {} 152 | matscipy_time = {} 153 | torch_nl_cpu_time = {} 154 | torch_nl_cuda_time = {} 155 | pymatgen_time = {} 156 | vesin_time = {} 157 | 158 | for cutoff in [3, 6, 12]: 159 | print(f"=========== CUTOFF={cutoff} =============") 160 | 161 | n_atoms[cutoff] = [] 162 | ase_time[cutoff] = [] 163 | matscipy_time[cutoff] = [] 164 | torch_nl_cpu_time[cutoff] = [] 165 | torch_nl_cuda_time[cutoff] = [] 166 | pymatgen_time[cutoff] = [] 167 | vesin_time[cutoff] = [] 168 | 169 | for kx, ky, kz in repeats: 170 | super_cell = 
atoms.repeat((kx, ky, kz)) 171 | print(len(super_cell), "atoms") 172 | n_atoms[cutoff].append(len(super_cell)) 173 | 174 | # ASE 175 | timing = benchmark( 176 | setup_ase_like, 177 | ase.neighborlist.neighbor_list, 178 | super_cell, 179 | cutoff, 180 | ) 181 | ase_time[cutoff].append(timing * 1e3) 182 | print(f" ase took {timing * 1e3:.3f} ms") 183 | 184 | # MATSCIPY 185 | timing = benchmark( 186 | setup_ase_like, 187 | matscipy.neighbours.neighbour_list, 188 | super_cell, 189 | cutoff, 190 | ) 191 | matscipy_time[cutoff].append(timing * 1e3) 192 | print(f" matscipy took {timing * 1e3:.3f} ms") 193 | 194 | # TORCH_NL CPU 195 | timing = benchmark( 196 | setup_torch_nl_cpu, 197 | torch_nl_run, 198 | super_cell, 199 | cutoff, 200 | ) 201 | torch_nl_cpu_time[cutoff].append(timing * 1e3) 202 | print(f" torch_nl (cpu) took {timing * 1e3:.3f} ms") 203 | 204 | # TORCH_NL CUDA 205 | timing = benchmark( 206 | setup_torch_nl_cuda, 207 | torch_nl_run, 208 | super_cell, 209 | cutoff, 210 | ) 211 | torch_nl_cuda_time[cutoff].append(timing * 1e3) 212 | print(f" torch_nl (cuda) took {timing * 1e3:.3f} ms") 213 | 214 | if np.any(super_cell.cell.lengths() < 2 * cutoff): 215 | print(" NNPOps can not run for this super cell") 216 | else: 217 | # NNPOps CPU 218 | timing = benchmark( 219 | setup_nnpops_cpu, 220 | nnpops_run, 221 | super_cell, 222 | cutoff, 223 | ) 224 | torch_nl_cpu_time[cutoff].append(timing * 1e3) 225 | print(f" NNPOps (cpu) took {timing * 1e3:.3f} ms") 226 | 227 | # NNPOps CUDA 228 | timing = benchmark( 229 | setup_nnpops_cuda, 230 | nnpops_run, 231 | super_cell, 232 | cutoff, 233 | ) 234 | torch_nl_cuda_time[cutoff].append(timing * 1e3) 235 | print(f" NNPOps (cuda) took {timing * 1e3:.3f} ms") 236 | 237 | # Pymatgen 238 | timing = benchmark( 239 | setup_pymatgen, 240 | pymatgen_run, 241 | super_cell, 242 | cutoff, 243 | ) 244 | pymatgen_time[cutoff].append(timing * 1e3) 245 | print(f" pymatgen took {timing * 1e3:.3f} ms") 246 | 247 | # VESIN 248 | timing = 
benchmark( 249 | setup_ase_like, 250 | vesin.ase_neighbor_list, 251 | super_cell, 252 | cutoff, 253 | ) 254 | vesin_time[cutoff].append(timing * 1e3) 255 | print(f" vesin took {timing * 1e3:.3f} ms") 256 | print() 257 | -------------------------------------------------------------------------------- /benchmarks/carbon.xyz: -------------------------------------------------------------------------------- 1 | ../python/vesin/tests/data/carbon.xyz -------------------------------------------------------------------------------- /create-single-cpp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | 4 | 5 | HERE = os.path.dirname(os.path.realpath(__file__)) 6 | ALREADY_SEEN = set() 7 | 8 | 9 | def find_file(path): 10 | candidates = [ 11 | os.path.join(HERE, "vesin", "src", path), 12 | os.path.join(HERE, "vesin", "include", path), 13 | ] 14 | for candidate in candidates: 15 | if os.path.exists(candidate): 16 | return candidate 17 | 18 | raise RuntimeError(f"unable to find file for {path}") 19 | 20 | 21 | def include_path(line): 22 | assert "#include" in line 23 | 24 | _, path = line.split("#include") 25 | path = path.strip() 26 | if path.startswith('"'): 27 | return path[1:-1] 28 | else: 29 | return "" 30 | 31 | 32 | def merge_files(path, output): 33 | path = find_file(path) 34 | 35 | if path in ALREADY_SEEN: 36 | return 37 | else: 38 | ALREADY_SEEN.add(path) 39 | 40 | if path.endswith("include/vesin.h"): 41 | output.write('#include "vesin.h"\n') 42 | return 43 | 44 | with open(path) as fd: 45 | for line in fd: 46 | if "#include" in line: 47 | new_path = include_path(line) 48 | if new_path != "": 49 | merge_files(new_path, output) 50 | else: 51 | output.write(line) 52 | else: 53 | output.write(line) 54 | 55 | 56 | if __name__ == "__main__": 57 | with open("vesin-single-build.cpp", "w") as output: 58 | merge_files("cpu_cell_list.cpp", output) 59 | merge_files("vesin.cpp", output) 60 | 61 | 
print("created single build file 'vesin-single-build.cpp'") 62 | -------------------------------------------------------------------------------- /docs/src/benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luthaf/vesin/d0036631d52b75dca9352d80f028a6383335d6d2/docs/src/benchmark.png -------------------------------------------------------------------------------- /docs/src/benchmarks.rst: -------------------------------------------------------------------------------- 1 | .. _benchmarks: 2 | 3 | Benchmarks 4 | ========== 5 | 6 | Here are the results of a benchmark of multiple neighbor list implementations. 7 | The benchmark runs on multiple super-cells of diamond carbon, up to 30,000 atoms, 8 | with multiple cutoffs, using either CPU or CUDA hardware. 9 | 10 | The results below are for an AMD 3955WX CPU and an NVIDIA 4070 Ti SUPER GPU; if 11 | you want to run it on your own system, the corresponding script is in vesin's 12 | `GitHub repository `_. 13 | 14 | .. _bench-script: https://github.com/Luthaf/vesin/blob/main/benchmarks/benchmark.py 15 | 16 | .. figure:: benchmark.png 17 | :align: center 18 | 19 | Speed comparison between multiple neighbor list implementations: vesin, `ase 20 | `_, `matscipy 21 | `_, `pymatgen 22 | `_, 23 | `torch_nl `_, and `NNPOps 24 | `_. 25 | 26 | Missing points indicate that a specific code could not run the calculation 27 | (for example, NNPOps requires the cell to be twice the cutoff in size, and 28 | can't run with large cutoffs and small cells). 29 | -------------------------------------------------------------------------------- /docs/src/c-api.rst: -------------------------------------------------------------------------------- 1 | .. _c-api: 2 | 3 | C API reference 4 | =============== 5 | 6 | Vesin's C API is defined in the ``vesin.h`` header. The main function is 7 | :c:func:`vesin_neighbors`, which runs a neighbor list calculation. 8 | 9 | ..
doxygenfunction:: vesin_neighbors 10 | 11 | .. doxygenfunction:: vesin_free 12 | 13 | .. doxygenstruct:: VesinNeighborList 14 | 15 | .. doxygenstruct:: VesinOptions 16 | 17 | .. doxygenenum:: VesinDevice 18 | -------------------------------------------------------------------------------- /docs/src/conf.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | import subprocess 5 | import sys 6 | from datetime import datetime 7 | 8 | 9 | os.environ["METATENSOR_IMPORT_FOR_SPHINX"] = "1" 10 | os.environ["METATOMIC_IMPORT_FOR_SPHINX"] = "1" 11 | 12 | ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) 13 | sys.path.insert(0, os.path.abspath(".")) 14 | sys.path.insert(0, ROOT) 15 | 16 | 17 | # -- Project information ----------------------------------------------------- 18 | 19 | project = "vesin" 20 | copyright = f"{datetime.now().date().year}, vesin developers" 21 | 22 | 23 | def setup(app): 24 | subprocess.run(["doxygen", "Doxyfile"], cwd=os.path.join(ROOT, "docs")) 25 | 26 | 27 | # -- General configuration --------------------------------------------------- 28 | 29 | needs_sphinx = "4.4.0" 30 | 31 | # Add any Sphinx extension module names here, as strings. They can be 32 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 33 | # ones. 34 | extensions = [ 35 | "sphinx.ext.autodoc", 36 | "sphinx.ext.intersphinx", 37 | "sphinx_design", 38 | "breathe", 39 | ] 40 | 41 | # Add any paths that contain templates here, relative to this directory. 42 | # templates_path = ["_templates"] 43 | 44 | # List of patterns, relative to source directory, that match files and 45 | # directories to ignore when looking for source files. 46 | # This pattern also affects html_static_path and html_extra_path. 
47 | exclude_patterns = ["Thumbs.db", ".DS_Store"] 48 | 49 | autoclass_content = "both" 50 | autodoc_member_order = "bysource" 51 | autodoc_typehints_format = "short" 52 | autodoc_type_aliases = { 53 | "npt.ArrayLike": "numpy.typing.ArrayLike", 54 | } 55 | 56 | 57 | breathe_projects = { 58 | "vesin": os.path.join(ROOT, "docs", "build", "xml"), 59 | } 60 | breathe_default_project = "vesin" 61 | breathe_domain_by_extension = { 62 | "h": "c", 63 | } 64 | 65 | breathe_default_members = ("members", "undoc-members") 66 | cpp_private_member_specifier = "" 67 | 68 | intersphinx_mapping = { 69 | "python": ("https://docs.python.org/3", None), 70 | "metatensor": ("https://docs.metatensor.org/latest/", None), 71 | "metatomic": ("https://docs.metatensor.org/metatomic/latest/", None), 72 | "numpy": ("https://numpy.org/doc/stable/", None), 73 | "torch": ("https://pytorch.org/docs/stable/", None), 74 | "ase": ("https://wiki.fysik.dtu.dk/ase/", None), 75 | } 76 | 77 | html_theme = "furo" 78 | html_title = "Vesin" 79 | # html_static_path = ["_static"] 80 | -------------------------------------------------------------------------------- /docs/src/index.rst: -------------------------------------------------------------------------------- 1 | Vesin: we are all neighbors 2 | =========================== 3 | 4 | .. |occ| image:: /static/images/Occitan.png 5 | :width: 18px 6 | 7 | .. |arp| image:: /static/images/Arpitan.png 8 | :width: 18px 9 | 10 | .. |lomb| image:: /static/images/Lombardy.png 11 | :width: 18px 12 | 13 | .. |cat| image:: /static/images/Catalan.png 14 | :width: 18px 15 | 16 | .. 
list-table:: 17 | :align: center 18 | :widths: auto 19 | :header-rows: 1 20 | 21 | - * English 🇺🇸⁠/⁠🇬🇧 22 | * Occitan |occ| 23 | * French 🇫🇷 24 | * Arpitan |arp| 25 | * Gallo‑Italic |lomb| 26 | * Catalan |cat| 27 | * Spanish 🇪🇸 28 | * Italian 🇮🇹 29 | - * neighbo(u)r 30 | * vesin 31 | * voisin 32 | * vesin 33 | * visin 34 | * veí 35 | * vecino 36 | * vicino 37 | 38 | 39 | ``vesin`` is a lightweight neighbor list calculator for molecular systems and 40 | three-dimensional graphs. It is written in C++ and can be used as a standalone 41 | library from C or Python. ``vesin`` is designed to be :ref:`fast ` 42 | and easy to use. 43 | 44 | Installation 45 | ------------ 46 | 47 | .. tab-set:: 48 | 49 | .. tab-item:: Python 50 | :sync: python 51 | 52 | You can install the code with ``pip``: 53 | 54 | .. code-block:: bash 55 | 56 | pip install vesin 57 | 58 | **TorchScript:** 59 | 60 | The TorchScript bindings can be installed with: 61 | 62 | .. code-block:: bash 63 | 64 | pip install vesin[torch] 65 | 66 | .. tab-item:: C/C++ (CMake) 67 | :sync: cxx 68 | 69 | If you use CMake as your build system, the simplest thing to do is to 70 | add https://github.com/Luthaf/vesin to your project. 71 | 72 | .. code-block:: cmake 73 | 74 | # assuming the code is in the `vesin/` directory (for example using 75 | # git submodule) 76 | add_subdirectory(vesin) 77 | 78 | target_link_libraries(your-target vesin) 79 | 80 | Alternatively, you can use CMake's `FetchContent 81 | `_ module 82 | to automatically download the code: 83 | 84 | .. code-block:: cmake 85 | 86 | include(FetchContent) 87 | FetchContent_Declare( 88 | vesin 89 | GIT_REPOSITORY https://github.com/Luthaf/vesin.git 90 | ) 91 | 92 | FetchContent_MakeAvailable(vesin) 93 | 94 | target_link_libraries(your-target vesin) 95 | 96 | **TorchScript:** 97 | 98 | To make the TorchScript version of the library available to CMake as 99 | well, you should set the ``VESIN_TORCH`` option to ``ON``. 
If you are 100 | using ``add_subdirectory(vesin)``: 101 | 102 | .. code-block:: cmake 103 | 104 | set(VESIN_TORCH ON CACHE BOOL "Build the vesin_torch library") 105 | 106 | add_subdirectory(vesin) 107 | 108 | target_link_libraries(your-target vesin_torch) 109 | 110 | And if you are using ``FetchContent``: 111 | 112 | .. code-block:: cmake 113 | 114 | set(VESIN_TORCH ON CACHE BOOL "Build the vesin_torch library") 115 | 116 | # like above 117 | FetchContent_Declare(...) 118 | FetchContent_MakeAvailable(...) 119 | 120 | target_link_libraries(your-target vesin_torch) 121 | 122 | .. tab-item:: C/C++ (single file build) 123 | 124 | We support merging all files in the vesin library to a single one that 125 | can then be included in your own project and built with the same build 126 | system as the rest of your code. 127 | 128 | You can generate this single file to build with the following commands: 129 | 130 | .. code-block:: bash 131 | 132 | git clone https://github.com/Luthaf/vesin.git 133 | cd vesin 134 | python create-single-cpp.py 135 | 136 | Then you'll need to copy both ``include/vesin.h`` and 137 | ``vesin-single-build.cpp`` in your project and configure your build 138 | system accordingly. 139 | 140 | **TorchScript:** 141 | 142 | The TorchScript API does not support single file build, please use one 143 | of the CMake options instead. 144 | 145 | 146 | .. tab-item:: C/C++ (global installation) 147 | 148 | You can build and install vesin in some global location (referred to as 149 | ``$PREFIX`` below), and then use the right compiler flags to give this 150 | location to your compiler. In this case, compilation of ``vesin`` and 151 | your code happen separately. 152 | 153 | .. code-block:: bash 154 | 155 | git clone https://github.com/Luthaf/vesin.git 156 | cd vesin 157 | mkdir build && cd build 158 | cmake -DCMAKE_INSTALL_PREFIX=$PREFIX .. 159 | cmake --install . 
160 | 161 | You can then compile your code, adding ``$PREFIX/include`` to the 162 | compiler include path, ``$PREFIX/lib`` to the linker library path; and 163 | linking to vesin (typically with ``-lvesin``). If you are building vesin 164 | as a shared library, you'll also need to define ``VESIN_SHARED`` as a 165 | preprocessor constant (``-DVESIN_SHARED`` when compiling the code). 166 | 167 | Some relevant cmake options you can customize: 168 | 169 | +------------------------------+--------------------------------------------------+------------------------------------------------+ 170 | | Option | Description | Default | 171 | +==============================+==================================================+================================================+ 172 | | ``CMAKE_BUILD_TYPE`` | Type of build: Debug or Release | Release | 173 | +------------------------------+--------------------------------------------------+------------------------------------------------+ 174 | | ``CMAKE_INSTALL_PREFIX`` | Prefix where the library will be installed | ``/usr/local`` | 175 | +------------------------------+--------------------------------------------------+------------------------------------------------+ 176 | | ``BUILD_SHARED_LIBS`` | Default to building and installing a shared | OFF | 177 | | | library instead of a static one | | 178 | +------------------------------+--------------------------------------------------+------------------------------------------------+ 179 | | ``VESIN_INSTALL`` | Should CMake install vesin library and headers | | ON when building vesin directly | 180 | | | | | OFF when including vesin in another project | 181 | +------------------------------+--------------------------------------------------+------------------------------------------------+ 182 | | ``VESIN_TORCH`` | Build (and install if ``VESIN_INSTALL=ON``) the | OFF | 183 | | | vesin_torch library | | 184 | 
+------------------------------+--------------------------------------------------+------------------------------------------------+ 185 | 186 | **TorchScript:** 187 | 188 | Set ``VESIN_TORCH`` to ``ON`` to build and install the TorchScript 189 | bindings. 190 | 191 | You can then compile your code, adding ``$PREFIX/include`` to the 192 | compiler include path, ``$PREFIX/lib`` to the linker library path; and 193 | linking to vesin_torch (typically with ``-lvesin_torch``). 194 | 195 | You'll also need to add the include and library directories of the 196 | same torch installation that was used to build the library. 197 | 198 | 199 | Usage example 200 | ------------- 201 | 202 | .. tab-set:: 203 | 204 | .. tab-item:: Python 205 | :sync: python 206 | 207 | .. py:currentmodule:: vesin 208 | 209 | There are two ways to use vesin from Python: you can use the 210 | :py:class:`NeighborList` class: 211 | 212 | .. code-block:: Python 213 | 214 | import numpy as np 215 | from vesin import NeighborList 216 | 217 | # positions can be anything compatible with numpy's ndarray 218 | positions = [ 219 | (0, 0, 0), 220 | (0, 1.3, 1.3), 221 | ] 222 | box = 3.2 * np.eye(3) 223 | 224 | calculator = NeighborList(cutoff=4.2, full_list=True) 225 | i, j, S, d = calculator.compute( 226 | points=positions, 227 | box=box, 228 | periodic=True, 229 | quantities="ijSd" 230 | ) 231 | 232 | Alternatively, you can use the :py:func:`ase_neighbor_list` function, 233 | which mimics the API of :py:func:`ase.neighborlist.neighbor_list`: 234 | 235 | .. code-block:: Python 236 | 237 | import ase 238 | from vesin import ase_neighbor_list 239 | 240 | atoms = ase.Atoms(...) 241 | 242 | i, j, S, d = ase_neighbor_list("ijSd", atoms, cutoff=4.2) 243 | 244 | 245 | .. tab-item:: C and C++ 246 | :sync: cxx 247 | 248 | .. 
code-block:: c++ 249 | 250 | #include <stdio.h> 251 | #include <stdlib.h> 252 | #include <string.h> 253 | 254 | #include <vesin.h> 255 | 256 | int main() { 257 | // points can be any pointer to `double[3]` 258 | double points[][3] = { 259 | {0, 0, 0}, 260 | {0, 1.3, 1.3}, 261 | }; 262 | size_t n_points = 2; 263 | 264 | // box can be any `double[3][3]` array 265 | double box[3][3] = { 266 | {3.2, 0.0, 0.0}, 267 | {0.0, 3.2, 0.0}, 268 | {0.0, 0.0, 3.2}, 269 | }; 270 | bool periodic = true; 271 | 272 | // calculation setup, zero-initialized so that every option has a defined value 273 | VesinOptions options = {0}; 274 | options.cutoff = 4.2; 275 | options.full = true; 276 | 277 | // decide what quantities should be computed 278 | options.return_shifts = true; 279 | options.return_distances = true; 280 | options.return_vectors = false; 281 | 282 | VesinNeighborList neighbors; 283 | memset(&neighbors, 0, sizeof(VesinNeighborList)); 284 | 285 | const char* error_message = NULL; 286 | int status = vesin_neighbors( 287 | points, n_points, box, periodic, 288 | VesinCPU, options, 289 | &neighbors, 290 | &error_message 291 | ); 292 | 293 | if (status != EXIT_SUCCESS) { 294 | fprintf(stderr, "error: %s\n", error_message); 295 | return 1; 296 | } 297 | 298 | // use neighbors as needed 299 | printf("we have %zu pairs\n", neighbors.length); 300 | 301 | vesin_free(&neighbors); 302 | 303 | return 0; 304 | } 305 | 306 | .. tab-item:: TorchScript Python 307 | 308 | The entry point for the TorchScript API is the 309 | :py:class:`vesin.torch.NeighborList` class in Python, and the 310 | corresponding :cpp:class:`vesin_torch::NeighborListHolder` class in C++; 311 | both modeled after the standard Python API. For Python, the class is 312 | available in the ``vesin.torch`` module. 313 | 314 | In both cases, the code is integrated with PyTorch's autograd framework, 315 | meaning that if the ``points`` or ``box`` arguments have 316 | ``requires_grad=True``, then the ``d`` (distances) and ``D`` (distance 317 | vectors) outputs will be integrated into the computational graph. 318 | 319 | .. 
code-block:: Python 320 | 321 | import torch 322 | from vesin.torch import NeighborList 323 | 324 | positions = torch.tensor( 325 | [[0.0, 0.0, 0.0], 326 | [0.0, 1.3, 1.3]], 327 | dtype=torch.float64, 328 | requires_grad=True, 329 | ) 330 | box = 3.2 * torch.eye(3, dtype=torch.float64) 331 | 332 | calculator = NeighborList(cutoff=4.2, full_list=True) 333 | i, j, S, d = calculator.compute( 334 | points=positions, 335 | box=box, 336 | periodic=True, 337 | quantities="ijSd" 338 | ) 339 | 340 | .. tab-item:: TorchScript C++ 341 | 342 | The entry point for the TorchScript API is the 343 | :py:class:`vesin.torch.NeighborList` class in Python, and the 344 | corresponding :cpp:class:`vesin_torch::NeighborListHolder` class in C++; 345 | both modeled after the standard Python API. For C++, the class is 346 | available in the ``vesin_torch.hpp`` header. 347 | 348 | In both cases, the code is integrated with PyTorch's autograd framework, 349 | meaning that if the ``points`` or ``box`` arguments have 350 | ``requires_grad=True``, then the ``d`` (distances) and ``D`` (distance 351 | vectors) outputs will be integrated into the computational graph. 352 | 353 | .. code-block:: C++ 354 | 355 | #include <torch/torch.h> 356 | 357 | #include <vesin_torch.hpp> 358 | 359 | int main() { 360 | auto options = torch::TensorOptions().dtype(torch::kFloat64); 361 | auto positions = torch::tensor( 362 | {{0.0, 0.0, 0.0}, 363 | {0.0, 1.3, 1.3}}, 364 | options 365 | ); 366 | positions.requires_grad_(true); 367 | 368 | auto box = 3.2 * torch::eye(3, options); 369 | 370 | auto calculator = torch::make_intrusive<vesin_torch::NeighborListHolder>( 371 | /*cutoff=*/ 4.2, 372 | /*full_list=*/ true 373 | ); 374 | 375 | 376 | auto outputs = calculator->compute( 377 | /*points=*/ positions, 378 | /*box=*/ box, 379 | /*periodic=*/ true, 380 | /*quantities=*/ "ijSd", 381 | /*copy=*/ true 382 | ); 383 | 384 | auto i = outputs[0]; 385 | auto j = outputs[1]; 386 | auto S = outputs[2]; 387 | auto d = outputs[3]; 388 | 389 | // ... 
390 | } 391 | 392 | 393 | API Reference 394 | ------------- 395 | 396 | .. toctree:: 397 | :maxdepth: 1 398 | :hidden: 399 | 400 | Vesin 401 | 402 | .. toctree:: 403 | :maxdepth: 1 404 | 405 | python-api 406 | torch-api 407 | c-api 408 | metatomic 409 | 410 | 411 | .. toctree:: 412 | :maxdepth: 1 413 | :hidden: 414 | 415 | benchmarks 416 | -------------------------------------------------------------------------------- /docs/src/metatomic.rst: -------------------------------------------------------------------------------- 1 | .. _metatomic-api: 2 | 3 | Metatomic interface 4 | =================== 5 | 6 | Vesin offers an interface to compute neighbor lists for `metatomic's 7 | `_ atomistic machine learning models. 8 | 9 | .. autofunction:: vesin.metatomic.compute_requested_neighbors 10 | 11 | .. autoclass:: vesin.metatomic.NeighborList 12 | :members: 13 | -------------------------------------------------------------------------------- /docs/src/python-api.rst: -------------------------------------------------------------------------------- 1 | .. _python-api: 2 | 3 | Python API reference 4 | ==================== 5 | 6 | .. currentmodule:: vesin 7 | 8 | .. autoclass:: NeighborList 9 | :members: 10 | 11 | .. 
autofunction:: ase_neighbor_list 12 | -------------------------------------------------------------------------------- /docs/src/static/images/Arpitan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luthaf/vesin/d0036631d52b75dca9352d80f028a6383335d6d2/docs/src/static/images/Arpitan.png -------------------------------------------------------------------------------- /docs/src/static/images/Catalan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luthaf/vesin/d0036631d52b75dca9352d80f028a6383335d6d2/docs/src/static/images/Catalan.png -------------------------------------------------------------------------------- /docs/src/static/images/Lombardy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luthaf/vesin/d0036631d52b75dca9352d80f028a6383335d6d2/docs/src/static/images/Lombardy.png -------------------------------------------------------------------------------- /docs/src/static/images/Occitan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luthaf/vesin/d0036631d52b75dca9352d80f028a6383335d6d2/docs/src/static/images/Occitan.png -------------------------------------------------------------------------------- /docs/src/torch-api.rst: -------------------------------------------------------------------------------- 1 | .. _torch-api: 2 | 3 | TorchScript API reference 4 | ========================= 5 | 6 | .. autoclass:: vesin.torch.NeighborList 7 | :members: 8 | 9 | 10 | TorchScript API reference (C++) 11 | =============================== 12 | 13 | .. doxygentypedef:: vesin_torch::NeighborList 14 | 15 | .. 
doxygenclass:: vesin_torch::NeighborListHolder 16 | -------------------------------------------------------------------------------- /python/vesin/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include pyproject.toml 2 | include LICENSE 3 | include VERSION 4 | 5 | recursive-include lib * 6 | -------------------------------------------------------------------------------- /python/vesin/README.md: -------------------------------------------------------------------------------- 1 | ../../README.md -------------------------------------------------------------------------------- /python/vesin/VERSION: -------------------------------------------------------------------------------- 1 | ../../vesin/VERSION -------------------------------------------------------------------------------- /python/vesin/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "vesin" 3 | dynamic = ["version"] 4 | requires-python = ">=3.9" 5 | authors = [ 6 | {name = "Guillaume Fraux", email = "guillaume.fraux@epfl.ch"}, 7 | ] 8 | 9 | dependencies = [ 10 | "numpy" 11 | ] 12 | 13 | readme = "README.md" 14 | license = "BSD-3-Clause" 15 | description = "Computing neighbor lists for atomistic system" 16 | 17 | classifiers = [ 18 | "Development Status :: 4 - Beta", 19 | "Intended Audience :: Science/Research", 20 | "Operating System :: POSIX", 21 | "Operating System :: MacOS :: MacOS X", 22 | "Operating System :: Microsoft :: Windows", 23 | "Programming Language :: Python", 24 | "Programming Language :: Python :: 3", 25 | "Topic :: Scientific/Engineering", 26 | "Topic :: Scientific/Engineering :: Bio-Informatics", 27 | "Topic :: Scientific/Engineering :: Chemistry", 28 | "Topic :: Scientific/Engineering :: Physics", 29 | "Topic :: Software Development :: Libraries", 30 | "Topic :: Software Development :: Libraries :: Python Modules", 31 | ] 32 | 33 | [project.optional-dependencies] 34 
| torch = ["vesin-torch"] 35 | 36 | [project.urls] 37 | homepage = "https://github.com/Luthaf/vesin/" 38 | documentation = "https://luthaf.fr/vesin/" 39 | repository = "https://github.com/Luthaf/vesin/" 40 | 41 | ### ======================================================================== ### 42 | 43 | [build-system] 44 | requires = [ 45 | "setuptools >=77", 46 | "wheel >=0.41", 47 | "cmake", 48 | ] 49 | build-backend = "setuptools.build_meta" 50 | 51 | [tool.setuptools] 52 | zip-safe = false 53 | 54 | [tool.setuptools.packages.find] 55 | include = ["vesin*"] 56 | namespaces = false 57 | -------------------------------------------------------------------------------- /python/vesin/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import subprocess 4 | import sys 5 | 6 | from setuptools import Extension, setup 7 | from setuptools.command.bdist_egg import bdist_egg 8 | from setuptools.command.build_ext import build_ext 9 | from setuptools.command.sdist import sdist 10 | from wheel.bdist_wheel import bdist_wheel 11 | 12 | 13 | ROOT = os.path.realpath(os.path.dirname(__file__)) 14 | 15 | VESIN_BUILD_TYPE = os.environ.get("VESIN_BUILD_TYPE", "release") 16 | if VESIN_BUILD_TYPE not in ["debug", "release"]: 17 | raise Exception( 18 | f"invalid build type passed: '{VESIN_BUILD_TYPE}', " 19 | "expected 'debug' or 'release'" 20 | ) 21 | 22 | 23 | class universal_wheel(bdist_wheel): 24 | # When building the wheel, the `wheel` package assumes that if we have a 25 | # binary extension then we are linking to `libpython.so`; and thus the wheel 26 | # is only usable with a single python version. This is not the case for 27 | # here, and the wheel will be compatible with any Python >=3.7. This is 28 | # tracked in https://github.com/pypa/wheel/issues/185, but until then we 29 | # manually override the wheel tag. 
30 | def get_tag(self): 31 | tag = bdist_wheel.get_tag(self) 32 | # tag[2:] contains the os/arch tags, we want to keep them 33 | return ("py3", "none") + tag[2:] 34 | 35 | 36 | class cmake_ext(build_ext): 37 | """ 38 | Build the native library using cmake 39 | """ 40 | 41 | def run(self): 42 | source_dir = os.path.join(ROOT, "lib") 43 | if not os.path.exists(source_dir): 44 | # we are building from a checkout 45 | source_dir = os.path.join(ROOT, "..", "..", "vesin") 46 | 47 | build_dir = os.path.join(ROOT, "build", "cmake-build") 48 | install_dir = os.path.join(os.path.realpath(self.build_lib), "vesin") 49 | 50 | os.makedirs(build_dir, exist_ok=True) 51 | 52 | cmake_options = [ 53 | f"-DCMAKE_INSTALL_PREFIX={install_dir}", 54 | f"-DCMAKE_BUILD_TYPE={VESIN_BUILD_TYPE}", 55 | "-DBUILD_SHARED_LIBS=ON", 56 | ] 57 | 58 | subprocess.run( 59 | ["cmake", source_dir, *cmake_options], 60 | cwd=build_dir, 61 | check=True, 62 | ) 63 | subprocess.run( 64 | ["cmake", "--build", build_dir, "--target", "install"], 65 | check=True, 66 | ) 67 | 68 | 69 | class bdist_egg_disabled(bdist_egg): 70 | """Disabled version of bdist_egg 71 | 72 | Prevents setup.py install performing setuptools' default easy_install, 73 | which it should never ever do. 74 | """ 75 | 76 | def run(self): 77 | sys.exit( 78 | "Aborting implicit building of eggs.\nUse `pip install .` or " 79 | "`python -m build --wheel . && pip install " 80 | "dist/metatensor_core-*.whl` to install from source." 
81 | ) 82 | 83 | 84 | class sdist_with_lib(sdist): 85 | """ 86 | Create a sdist including the code for the native library 87 | """ 88 | 89 | def run(self): 90 | # generate extra files 91 | shutil.copytree(os.path.join(ROOT, "..", "..", "vesin"), "lib") 92 | 93 | # run original sdist 94 | super().run() 95 | 96 | # cleanup 97 | shutil.rmtree("lib") 98 | 99 | 100 | if __name__ == "__main__": 101 | setup( 102 | version=open("VERSION").read().strip(), 103 | ext_modules=[ 104 | # only declare the extension, it is built & copied as required by cmake 105 | # in the build_ext command 106 | Extension(name="vesin", sources=[]), 107 | ], 108 | cmdclass={ 109 | "sdist": sdist_with_lib, 110 | "build_ext": cmake_ext, 111 | "bdist_egg": bdist_egg if "bdist_egg" in sys.argv else bdist_egg_disabled, 112 | "bdist_wheel": universal_wheel, 113 | }, 114 | ) 115 | -------------------------------------------------------------------------------- /python/vesin/tests/data/Cd2I4O12.xyz: -------------------------------------------------------------------------------- 1 | 18 2 | Lattice="5.7779061565689 0.0 0.0 0.0 5.9353021159594 0.0 0.0 0.0 32.37048911" Properties=species:S:1:pos:R:3:masses:R:1:forces:R:3 pbc="T T T" 3 | Cd 1.66027661 1.01057390 14.51270543 112.41400000 -0.03865710 -0.05802875 0.02849028 4 | Cd 4.54922969 1.95707716 17.85778368 112.41400000 -0.03865710 0.05802875 -0.02849028 5 | I 5.63439709 3.76166207 14.85737491 126.90447000 -1.47061191 1.06604780 0.08108146 6 | I 2.74544401 5.14129111 17.51311420 126.90447000 -1.47061191 -1.06604780 -0.08108146 7 | I 4.31565215 1.38050855 12.64875342 126.90447000 0.74986725 0.79087438 -0.69424655 8 | I 1.42669907 1.58714251 19.72173569 126.90447000 0.74986725 -0.79087438 0.69424655 9 | O 2.78364007 0.39473043 12.50000000 15.99900000 -1.11916203 -0.61054277 -0.10327189 10 | O 5.67259315 2.57292062 19.87048911 15.99900000 -1.11916203 0.61054277 0.10327189 11 | O 3.64269369 2.38202798 14.05899015 15.99900000 -0.44450476 0.58457571 0.81988628 
12 | O 0.75374061 0.58562308 18.31149896 15.99900000 -0.44450476 -0.58457571 -0.81988628 13 | O 5.44167386 0.15724415 13.43179857 15.99900000 0.93009003 -0.74236748 0.47707191 14 | O 2.55272078 2.81040691 18.93869054 15.99900000 0.93009003 0.74236748 -0.47707191 15 | O 1.34512782 4.79252145 15.30307812 15.99900000 0.80628161 0.77936335 0.23325304 16 | O 4.23408090 4.11043172 17.06741099 15.99900000 0.80628161 -0.77936335 -0.23325304 17 | O 0.52359753 2.92010411 13.34319181 15.99900000 0.37932513 -0.74376411 -0.97400744 18 | O 3.41255061 0.04754695 19.02729730 15.99900000 0.37932513 0.74376411 0.97400744 19 | O 0.27893312 2.33486569 15.96606058 15.99900000 0.20737178 -1.07627874 0.77025494 20 | O 3.16788619 0.63278537 16.40442853 15.99900000 0.20737178 1.07627874 -0.77025494 21 | -------------------------------------------------------------------------------- /python/vesin/tests/data/carbon.xyz: -------------------------------------------------------------------------------- 1 | 4 2 | Lattice="2.460394 0.0 0.0 -1.26336 2.044166 0.0 -0.139209 -0.407369 6.809714" Properties=species:S:1:pos:R:3 pbc="T T T" 3 | C -0.03480225 -0.10184225 1.70242850 4 | C -0.10440675 -0.30552675 5.10728550 5 | C -0.05691216 1.26093576 1.70242850 6 | C 1.11473716 0.37586124 5.10728550 7 | -------------------------------------------------------------------------------- /python/vesin/tests/data/diamond.xyz: -------------------------------------------------------------------------------- 1 | 8 2 | Lattice="3.5607451090903233 0.0 0.0 0.0 3.5607451090903233 0.0 0.0 0.0 3.5607451090903233" Properties=species:S:1:pos:R:3 pbc="T T T" 3 | C 0.00000000 0.00000000 1.78037255 4 | C 0.89018628 0.89018628 2.67055883 5 | C -0.00000000 1.78037255 0.00000000 6 | C 0.89018628 2.67055883 0.89018628 7 | C 1.78037255 0.00000000 0.00000000 8 | C 2.67055883 0.89018628 0.89018628 9 | C 1.78037255 1.78037255 1.78037255 10 | C 2.67055883 2.67055883 2.67055883 11 | 
-------------------------------------------------------------------------------- /python/vesin/tests/data/naphthalene.xyz: -------------------------------------------------------------------------------- 1 | 18 2 | Properties=species:S:1:pos:R:3 energy=-10478070.659035627 pbc="F F F" 3 | C -1.07752258 -1.44519018 0.23486494 4 | C -2.34287934 -1.00436934 -0.03884263 5 | C -2.49625633 0.40527529 -0.23488609 6 | C -1.32251140 1.30701063 -0.23011205 7 | C -0.03972489 0.71954944 -0.09685793 8 | C 1.16386179 1.56294692 -0.06105191 9 | C 2.34152520 0.92675009 0.01612411 10 | C 2.44515671 -0.58490551 0.10661457 11 | C 1.31344319 -1.26692346 0.24849954 12 | C 0.01904813 -0.61131099 0.05976149 13 | H -0.84295984 -2.42125368 0.67296941 14 | H -3.18464777 -1.71419493 -0.25475633 15 | H -3.55069635 0.69888542 -0.36353746 16 | H -1.62292307 2.33411368 -0.32316927 17 | H 1.06380309 2.66106164 -0.12679498 18 | H 3.21431354 1.51359069 -0.29721477 19 | H 3.54308435 -0.81134785 0.12807956 20 | H 1.33078774 -2.36610850 0.51600063 21 | -------------------------------------------------------------------------------- /python/vesin/tests/data/readme.txt: -------------------------------------------------------------------------------- 1 | This is a small test suite for vesin, composed of: 2 | - water.xyz: a medium-sized water cluster 3 | - naphthalene.xyz: a small organic molecule 4 | - diamond.xyz: a periodic system with a cubic cell 5 | - carbon.xyz: a periodic system with a non-cubic cell 6 | -------------------------------------------------------------------------------- /python/vesin/tests/data/water.xyz: -------------------------------------------------------------------------------- 1 | 51 2 | Properties=species:S:1:pos:R:3 3 | O 0.35772 8.3064 11.7449 4 | H 1.27587 8.49472 11.9392 5 | H 0.222311 8.65568 10.864 6 | O 9.73218 8.90883 5.56731 7 | H 10.1022 9.47818 6.24194 8 | H 8.84322 9.23864 5.4361 9 | O 1.47521 6.40456 0.956651 10 | H 1.36119 5.83634 1.71846 11 | H 0.61367 6.79872 0.82016 
12 | O 10.6678 10.5038 3.35808 13 | H 11.3289 9.85087 3.12814 14 | H 10.3341 10.2182 4.20857 15 | O 9.1735 7.43006 2.04159 16 | H 9.58976 6.88585 2.71002 17 | H 8.97205 8.25061 2.49142 18 | O 9.35696 13.0776 4.13527 19 | H 9.10802 12.7474 4.99856 20 | H 9.72185 12.3167 3.68342 21 | O 9.82166 12.763 10.4287 22 | H 9.58765 12.0028 9.89611 23 | H 9.16248 12.782 11.1225 24 | O 3.50771 14.0639 3.81693 25 | H 4.14093 13.6853 3.20707 26 | H 2.83722 14.4534 3.25567 27 | O 8.4947 11.0391 8.78921 28 | H 7.81372 11.439 8.24831 29 | H 8.0283 10.4014 9.32966 30 | O 8.12407 13.0253 12.765 31 | H 7.95796 13.964 12.8518 32 | H 8.41685 12.7521 13.6344 33 | O 7.8082 0.044473 1.02143 34 | H 8.34368 0.690967 1.48136 35 | H 7.07934 0.549401 0.66082 36 | O 4.56356 8.95562 8.51371 37 | H 5.21772 9.01236 9.21019 38 | H 3.81575 9.45236 8.84578 39 | O 1.67172 4.68017 8.83563 40 | H 2.26816 5.42761 8.79298 41 | H 2.18559 3.94385 8.50399 42 | O 2.93581 9.19049 12.3264 43 | H 3.39043 9.92348 11.9113 44 | H 3.47201 8.42579 12.1167 45 | O 4.91764 12.4533 1.70512 46 | H 3.98424 12.2612 1.61526 47 | H 5.14136 12.9346 0.908508 48 | O 6.85585 12.6923 6.89814 49 | H 7.47976 13.279 7.32575 50 | H 6.9787 12.8522 5.96241 51 | O 1.96907 11.6309 1.47723 52 | H 1.97934 11.4997 2.42534 53 | H 2.35724 10.8321 1.12019 54 | -------------------------------------------------------------------------------- /python/vesin/tests/test_metatomic.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional 2 | 3 | import pytest 4 | import torch 5 | from metatensor.torch import Labels, TensorMap 6 | from metatomic.torch import ( 7 | AtomisticModel, 8 | ModelCapabilities, 9 | ModelMetadata, 10 | ModelOutput, 11 | NeighborListOptions, 12 | System, 13 | ) 14 | 15 | from vesin.metatomic import NeighborList, compute_requested_neighbors 16 | 17 | 18 | def test_errors(): 19 | positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 2.0]], dtype=torch.float64) 20 
| cell = 4 * torch.eye(3, dtype=torch.float64) 21 | system = System( 22 | positions=positions, 23 | cell=cell, 24 | pbc=torch.ones(3, dtype=bool), 25 | types=torch.tensor([6, 8]), 26 | ) 27 | 28 | options = NeighborListOptions(cutoff=3.5, full_list=True, strict=True) 29 | calculator = NeighborList(options, length_unit="A") 30 | 31 | system.pbc[0] = False 32 | message = ( 33 | "vesin currently does not support mixed periodic and non-periodic " 34 | "boundary conditions" 35 | ) 36 | with pytest.raises(NotImplementedError, match=message): 37 | calculator.compute(system) 38 | 39 | 40 | def test_backward(): 41 | positions = torch.tensor( 42 | [[0.0, 0.0, 0.0], [1.0, 1.0, 2.0]], dtype=torch.float64, requires_grad=True 43 | ) 44 | cell = (4 * torch.eye(3, dtype=torch.float64)).clone().requires_grad_(True) 45 | system = System( 46 | positions=positions, 47 | cell=cell, 48 | pbc=torch.ones(3, dtype=bool), 49 | types=torch.tensor([6, 8]), 50 | ) 51 | 52 | options = NeighborListOptions(cutoff=3.5, full_list=True, strict=True) 53 | calculator = NeighborList(options, length_unit="A") 54 | neighbors = calculator.compute(system) 55 | 56 | value = ((neighbors.values) ** 2).sum() * torch.linalg.det(cell) 57 | value.backward() 58 | 59 | # check there are gradients, and they are not zero 60 | assert positions.grad is not None 61 | assert cell.grad is not None 62 | assert torch.linalg.norm(positions.grad) > 0 63 | assert torch.linalg.norm(cell.grad) > 0 64 | 65 | 66 | class InnerModule(torch.nn.Module): 67 | def requested_neighbor_lists(self) -> List[NeighborListOptions]: 68 | return [NeighborListOptions(cutoff=3.4, full_list=False, strict=True)] 69 | 70 | 71 | class OuterModule(torch.nn.Module): 72 | def __init__(self): 73 | super().__init__() 74 | self.inner = InnerModule() 75 | 76 | def requested_neighbor_lists(self) -> List[NeighborListOptions]: 77 | return [NeighborListOptions(cutoff=5.2, full_list=True, strict=False)] 78 | 79 | def forward( 80 | self, 81 | systems: List[System], 
82 | outputs: Dict[str, ModelOutput], 83 | selected_atoms: Optional[Labels], 84 | ) -> Dict[str, TensorMap]: 85 | return {} 86 | 87 | 88 | def test_model(): 89 | positions = torch.tensor( 90 | [[0.0, 0.0, 0.0], [1.0, 1.0, 1.4]], dtype=torch.float64, requires_grad=True 91 | ) 92 | cell = (4 * torch.eye(3, dtype=torch.float64)).clone().requires_grad_(True) 93 | pbc = torch.ones(3, dtype=bool) 94 | types = torch.tensor([6, 8]) 95 | systems = [ 96 | System(positions=positions, cell=cell, pbc=pbc, types=types), 97 | System(positions=positions, cell=cell, pbc=pbc, types=types), 98 | ] 99 | 100 | # Using a "raw" model 101 | model = OuterModule() 102 | compute_requested_neighbors( 103 | systems=systems, 104 | system_length_unit="A", 105 | model=model, 106 | model_length_unit="A", 107 | check_consistency=True, 108 | ) 109 | 110 | for system in systems: 111 | all_options = system.known_neighbor_lists() 112 | assert len(all_options) == 2 113 | assert all_options[0].requestors() == ["OuterModule"] 114 | assert all_options[0].cutoff == 5.2 115 | assert all_options[1].requestors() == ["OuterModule.inner"] 116 | assert all_options[1].cutoff == 3.4 117 | 118 | # Using a AtomisticModel 119 | capabilities = ModelCapabilities( 120 | length_unit="A", 121 | interaction_range=6.0, 122 | supported_devices=["cpu"], 123 | dtype="float64", 124 | ) 125 | model = AtomisticModel(model.eval(), ModelMetadata(), capabilities) 126 | compute_requested_neighbors( 127 | systems=System(positions=positions, cell=cell, pbc=pbc, types=types), 128 | system_length_unit="A", 129 | model=model, 130 | ) 131 | 132 | for system in systems: 133 | all_options = system.known_neighbor_lists() 134 | assert len(all_options) == 2 135 | assert all_options[0].requestors() == ["OuterModule"] 136 | assert all_options[0].cutoff == 5.2 137 | assert all_options[1].requestors() == ["OuterModule.inner"] 138 | assert all_options[1].cutoff == 3.4 139 | 140 | message = ( 141 | "the given `model_length_unit` \\(nm\\) does not 
match the model " 142 | "capabilities \\(A\\)" 143 | ) 144 | with pytest.raises(ValueError, match=message): 145 | compute_requested_neighbors( 146 | systems=System(positions=positions, cell=cell, pbc=pbc, types=types), 147 | system_length_unit="A", 148 | model=model, 149 | model_length_unit="nm", 150 | ) 151 | -------------------------------------------------------------------------------- /python/vesin/tests/test_neighbors.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import ase.build 4 | import ase.io 5 | import ase.neighborlist 6 | import numpy as np 7 | import pytest 8 | 9 | import vesin 10 | 11 | 12 | CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) 13 | 14 | 15 | def non_sorted_nl(quantities, atoms, cutoff): 16 | calculator = vesin.NeighborList(cutoff=cutoff, full_list=True, sorted=False) 17 | outputs = calculator.compute( 18 | points=atoms.positions, 19 | box=atoms.cell[:], 20 | periodic=np.all(atoms.pbc), 21 | quantities=quantities, 22 | copy=False, 23 | ) 24 | # since we have `copy=False`, also return the calculator to keep the memory alive 25 | return *outputs, calculator 26 | 27 | 28 | @pytest.mark.parametrize("system", ["water", "diamond", "naphthalene", "carbon"]) 29 | @pytest.mark.parametrize("cutoff", [float(i) for i in range(1, 10)]) 30 | @pytest.mark.parametrize("vesin_nl", [vesin.ase_neighbor_list, non_sorted_nl]) 31 | def test_neighbors(system, cutoff, vesin_nl): 32 | atoms = ase.io.read(f"{CURRENT_DIR}/data/{system}.xyz") 33 | 34 | ase_i, ase_j, ase_S, ase_D = ase.neighborlist.neighbor_list("ijSD", atoms, cutoff) 35 | vesin_i, vesin_j, vesin_S, vesin_D, *_ = vesin_nl("ijSD", atoms, cutoff) 36 | 37 | assert len(ase_i) == len(vesin_i) 38 | assert len(ase_j) == len(vesin_j) 39 | assert len(ase_S) == len(vesin_S) 40 | assert len(ase_D) == len(vesin_D) 41 | 42 | ase_ijS = np.concatenate( 43 | (ase_i.reshape(-1, 1), ase_j.reshape(-1, 1), ase_S), axis=1 44 | ) 45 | vesin_ijS = 
np.concatenate( 46 | (vesin_i.reshape(-1, 1), vesin_j.reshape(-1, 1), vesin_S), axis=1 47 | ) 48 | 49 | ase_sort_indices = np.lexsort(ase_ijS.T) 50 | vesin_sort_indices = np.lexsort(vesin_ijS.T) 51 | 52 | assert np.array_equal(ase_ijS[ase_sort_indices], vesin_ijS[vesin_sort_indices]) 53 | assert np.allclose(ase_D[ase_sort_indices], vesin_D[vesin_sort_indices]) 54 | 55 | 56 | def test_pairs_output(): 57 | atoms = ase.io.read(f"{CURRENT_DIR}/data/diamond.xyz") 58 | 59 | calculator = vesin.NeighborList(cutoff=2.0, full_list=True, sorted=False) 60 | i, j, P = calculator.compute( 61 | points=atoms.positions, box=atoms.cell[:], periodic=True, quantities="ijP" 62 | ) 63 | 64 | assert np.all(np.vstack([i, j]).T == P) 65 | 66 | 67 | def test_sorting(): 68 | atoms = ase.io.read(f"{CURRENT_DIR}/data/diamond.xyz") 69 | 70 | calculator = vesin.NeighborList(cutoff=2.0, full_list=True, sorted=False) 71 | i, j = calculator.compute( 72 | points=atoms.positions, box=atoms.cell[:], periodic=True, quantities="ij" 73 | ) 74 | unsorted_ij = np.concatenate((i.reshape(-1, 1), j.reshape(-1, 1)), axis=1) 75 | assert not np.all(unsorted_ij[np.lexsort((j, i))] == unsorted_ij) 76 | 77 | calculator = vesin.NeighborList(cutoff=2.0, full_list=True, sorted=True) 78 | i, j = calculator.compute( 79 | points=atoms.positions, box=atoms.cell[:], periodic=True, quantities="ij" 80 | ) 81 | 82 | sorted_ij = np.concatenate((i.reshape(-1, 1), j.reshape(-1, 1)), axis=1) 83 | assert np.all(sorted_ij[np.lexsort((j, i))] == sorted_ij) 84 | 85 | # check that unsorted is not already sorted by chance 86 | assert not np.all(sorted_ij == unsorted_ij) 87 | 88 | # https://github.com/Luthaf/vesin/issues/34 89 | atoms = ase.io.read(f"{CURRENT_DIR}/data/Cd2I4O12.xyz") 90 | calculator = vesin.NeighborList(cutoff=5.0, full_list=True, sorted=True) 91 | i, j = calculator.compute( 92 | points=atoms.positions, box=atoms.cell[:], periodic=True, quantities="ij" 93 | ) 94 | sorted_ij = np.concatenate((i.reshape(-1, 1), 
j.reshape(-1, 1)), axis=1) 95 | assert np.all(sorted_ij[np.lexsort((j, i))] == sorted_ij) 96 | 97 | 98 | def test_errors(): 99 | points = np.array([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]]) 100 | box = np.zeros((3, 3)) 101 | 102 | nl = vesin.NeighborList(cutoff=1.2, full_list=True) 103 | 104 | message = "the box matrix is not invertible" 105 | with pytest.raises(RuntimeError, match=message): 106 | nl.compute(points, box, periodic=True, quantities="ij") 107 | 108 | box = np.eye(3, 3) 109 | message = "cutoff is too small" 110 | with pytest.raises(RuntimeError, match=message): 111 | nl = vesin.NeighborList(cutoff=1e-12, full_list=True) 112 | nl.compute(points, box, periodic=True, quantities="ij") 113 | 114 | message = "cutoff must be a finite, positive number" 115 | with pytest.raises(RuntimeError, match=message): 116 | nl = vesin.NeighborList(cutoff=0.0, full_list=True) 117 | nl.compute(points, box, periodic=True, quantities="ij") 118 | 119 | with pytest.raises(RuntimeError, match=message): 120 | nl = vesin.NeighborList(cutoff=-12.0, full_list=True) 121 | nl.compute(points, box, periodic=True, quantities="ij") 122 | 123 | with pytest.raises(RuntimeError, match=message): 124 | nl = vesin.NeighborList(cutoff=float("inf"), full_list=True) 125 | nl.compute(points, box, periodic=True, quantities="ij") 126 | 127 | with pytest.raises(RuntimeError, match=message): 128 | nl = vesin.NeighborList(cutoff=float("nan"), full_list=True) 129 | nl.compute(points, box, periodic=True, quantities="ij") 130 | -------------------------------------------------------------------------------- /python/vesin/vesin/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib.metadata 2 | 3 | from ._ase import ase_neighbor_list # noqa: F401 4 | from ._neighbors import NeighborList # noqa: F401 5 | 6 | 7 | __version__ = importlib.metadata.version("vesin") 8 | -------------------------------------------------------------------------------- 
/python/vesin/vesin/_ase.py: -------------------------------------------------------------------------------- 1 | try: 2 | import ase 3 | except ImportError: 4 | ase = None 5 | 6 | 7 | from ._neighbors import NeighborList 8 | 9 | 10 | def ase_neighbor_list(quantities, a, cutoff, self_interaction=False, max_nbins=0): 11 | """ 12 | This is a thin wrapper around :py:class:`NeighborList`, providing the same API as 13 | :py:func:`ase.neighborlist.neighbor_list`. 14 | 15 | It is intended as a drop-in replacement for the ASE function, but only supports a 16 | subset of the functionality. Notably, the following is not supported: 17 | 18 | - ``self_interaction=True`` 19 | - :py:class:`ase.Atoms` with mixed periodic boundary conditions 20 | - giving ``cutoff`` as a dictionary 21 | 22 | :param quantities: quantities to output from the neighbor list. Supported are 23 | ``"i"``, ``"j"``, ``"d"``, ``"D"``, and ``"S"`` with the same meaning as in ASE. 24 | :param a: :py:class:`ase.Atoms` instance 25 | :param cutoff: cutoff radius for the neighbor list 26 | :param self_interaction: Should an atom be considered its own neighbor? 
Default: 27 | False 28 | :param max_nbins: for ASE compatibility, ignored by this implementation 29 | """ 30 | if ase is None: 31 | raise ImportError("could not import ase, this function requires ase") 32 | 33 | if self_interaction: 34 | raise ValueError("self_interaction=True is not implemented") 35 | 36 | if not isinstance(cutoff, float): 37 | raise ValueError("only a single float cutoff is supported") 38 | 39 | if not isinstance(a, ase.Atoms): 40 | raise TypeError(f"`a` should be ase.Atoms, got {type(a)} instead") 41 | 42 | if a.pbc[0] and a.pbc[1] and a.pbc[2]: 43 | periodic = True 44 | elif not a.pbc[0] and not a.pbc[1] and not a.pbc[2]: 45 | periodic = False 46 | else: 47 | raise ValueError( 48 | "different periodic boundary conditions on different axis are not supported" 49 | ) 50 | 51 | # sorted=True and full_list=True since that's what ASE does 52 | calculator = NeighborList(cutoff=cutoff, full_list=True, sorted=True) 53 | return calculator.compute( 54 | points=a.positions, 55 | box=a.cell[:], 56 | periodic=periodic, 57 | quantities=quantities, 58 | copy=True, 59 | ) 60 | -------------------------------------------------------------------------------- /python/vesin/vesin/_c_api.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | from ctypes import ARRAY, POINTER 3 | 4 | 5 | VesinDevice = ctypes.c_int 6 | VesinUnknownDevice = 0 7 | VesinCPU = 1 8 | 9 | 10 | class VesinOptions(ctypes.Structure): 11 | _fields_ = [ 12 | ("cutoff", ctypes.c_double), 13 | ("full", ctypes.c_bool), 14 | ("sorted", ctypes.c_bool), 15 | ("return_shifts", ctypes.c_bool), 16 | ("return_distances", ctypes.c_bool), 17 | ("return_vectors", ctypes.c_bool), 18 | ] 19 | 20 | 21 | class VesinNeighborList(ctypes.Structure): 22 | _fields_ = [ 23 | ("length", ctypes.c_size_t), 24 | ("device", VesinDevice), 25 | ("pairs", POINTER(ARRAY(ctypes.c_size_t, 2))), 26 | ("shifts", POINTER(ARRAY(ctypes.c_int32, 3))), 27 | ("distances", 
POINTER(ctypes.c_double)), 28 | ("vectors", POINTER(ARRAY(ctypes.c_double, 3))), 29 | ] 30 | 31 | 32 | def setup_functions(lib): 33 | lib.vesin_free.argtypes = [POINTER(VesinNeighborList)] 34 | lib.vesin_free.restype = None 35 | 36 | lib.vesin_neighbors.argtypes = [ 37 | POINTER(ARRAY(ctypes.c_double, 3)), # points 38 | ctypes.c_size_t, # n_points 39 | ARRAY(ARRAY(ctypes.c_double, 3), 3), # box 40 | ctypes.c_bool, # periodic 41 | VesinDevice, # device 42 | VesinOptions, # options 43 | POINTER(VesinNeighborList), # neighbors 44 | POINTER(ctypes.c_char_p), # error_message 45 | ] 46 | lib.vesin_neighbors.restype = ctypes.c_int 47 | -------------------------------------------------------------------------------- /python/vesin/vesin/_c_lib.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from ctypes import cdll 4 | 5 | from ._c_api import setup_functions 6 | 7 | 8 | _HERE = os.path.realpath(os.path.dirname(__file__)) 9 | 10 | 11 | class LibraryFinder(object): 12 | def __init__(self): 13 | self._cached_dll = None 14 | 15 | def __call__(self): 16 | if self._cached_dll is None: 17 | path = _lib_path() 18 | self._cached_dll = cdll.LoadLibrary(path) 19 | setup_functions(self._cached_dll) 20 | 21 | return self._cached_dll 22 | 23 | 24 | def _lib_path(): 25 | if sys.platform.startswith("darwin"): 26 | windows = False 27 | path = os.path.join(_HERE, "lib", "libvesin.dylib") 28 | elif sys.platform.startswith("linux"): 29 | windows = False 30 | path = os.path.join(_HERE, "lib", "libvesin.so") 31 | elif sys.platform.startswith("win"): 32 | windows = True 33 | path = os.path.join(_HERE, "bin", "vesin.dll") 34 | else: 35 | raise ImportError("Unknown platform. 
Please edit this file") 36 | 37 | if os.path.isfile(path): 38 | if windows: 39 | _check_dll(path) 40 | return path 41 | 42 | raise ImportError("Could not find vesin shared library at " + path) 43 | 44 | 45 | def _check_dll(path): 46 | """Check if the DLL at ``path`` matches the architecture of Python""" 47 | import platform 48 | import struct 49 | 50 | IMAGE_FILE_MACHINE_I386 = 332 51 | IMAGE_FILE_MACHINE_AMD64 = 34404 52 | IMAGE_FILE_MACHINE_ARM64 = 43620 53 | 54 | machine = None 55 | with open(path, "rb") as fd: 56 | header = fd.read(2).decode(encoding="utf-8", errors="strict") 57 | if header != "MZ": 58 | raise ImportError(path + " is not a DLL") 59 | else: 60 | fd.seek(60) 61 | header = fd.read(4) 62 | header_offset = struct.unpack(" List[np.ndarray]: 44 | """ 45 | Compute the neighbor list for the system defined by ``positions``, ``box``, and 46 | ``periodic``; returning the requested ``quantities``. 47 | 48 | ``quantities`` can contain any combination of the following values: 49 | 50 | - ``"i"`` to get the index of the first point in the pair 51 | - ``"j"`` to get the index of the second point in the pair 52 | - ``"P"`` to get the indexes of the two points in the pair simultaneously 53 | - ``"S"`` to get the periodic shift of the pair 54 | - ``"d"`` to get the distance between points in the pair 55 | - ``"D"`` to get the distance vector between points in the pair 56 | 57 | :param points: positions of all points in the system (this can be anything that 58 | can be converted to a numpy array) 59 | :param box: bounding box of the system (this can be anything that can be 60 | converted to a numpy array) 61 | :param periodic: should we use periodic boundary conditions? 62 | :param quantities: quantities to return, defaults to "ij" 63 | :param copy: should we copy the returned quantities, defaults to ``True``. 
64 | Setting this to ``False`` might be a bit faster, but the returned arrays are 65 | view inside this class, and will be invalidated whenever this class is 66 | garbage collected or used to run a new calculation. 67 | 68 | :return: tuple of arrays as indicated by ``quantities``. 69 | """ 70 | points = np.asarray(points, dtype=np.float64) 71 | box = np.asarray(box, dtype=np.float64) 72 | 73 | if box.shape != (3, 3): 74 | raise ValueError("`box` must be a 3x3 matrix") 75 | 76 | Vector = ARRAY(ctypes.c_double, 3) 77 | box = ARRAY(Vector, 3)( 78 | Vector(box[0][0], box[0][1], box[0][2]), 79 | Vector(box[1][0], box[1][1], box[1][2]), 80 | Vector(box[2][0], box[2][1], box[2][2]), 81 | ) 82 | 83 | if len(points.shape) != 2 or points.shape[1] != 3: 84 | raise ValueError("`points` must be a nx3 array") 85 | 86 | options = VesinOptions() 87 | options.cutoff = self.cutoff 88 | options.full = self.full_list 89 | options.sorted = self.sorted 90 | options.return_shifts = "S" in quantities 91 | options.return_distances = "d" in quantities 92 | options.return_vectors = "D" in quantities 93 | 94 | error_message = ctypes.c_char_p() 95 | status = self._lib.vesin_neighbors( 96 | points.ctypes.data_as(POINTER(ARRAY(ctypes.c_double, 3))), 97 | points.shape[0], 98 | box, 99 | periodic, 100 | VesinCPU, 101 | options, 102 | self._neighbors, 103 | error_message, 104 | ) 105 | 106 | if status != 0: 107 | raise RuntimeError(error_message.value.decode("utf8")) 108 | 109 | # create numpy arrays for the data 110 | n_pairs = self._neighbors.length 111 | if n_pairs == 0: 112 | pairs = np.empty((0, 2), dtype=ctypes.c_size_t) 113 | shifts = np.empty((0, 3), dtype=ctypes.c_int32) 114 | distances = np.empty((0,), dtype=ctypes.c_double) 115 | vectors = np.empty((0, 3), dtype=ctypes.c_double) 116 | else: 117 | ptr = ctypes.cast(self._neighbors.pairs, POINTER(ctypes.c_size_t)) 118 | pairs = np.ctypeslib.as_array(ptr, shape=(n_pairs, 2)) 119 | 120 | if "S" in quantities: 121 | ptr = 
ctypes.cast(self._neighbors.shifts, POINTER(ctypes.c_int32)) 122 | shifts = np.ctypeslib.as_array(ptr, shape=(n_pairs, 3)) 123 | 124 | if "d" in quantities: 125 | ptr = ctypes.cast(self._neighbors.distances, POINTER(ctypes.c_double)) 126 | distances = np.ctypeslib.as_array(ptr, shape=(n_pairs,)) 127 | 128 | if "D" in quantities: 129 | ptr = ctypes.cast(self._neighbors.vectors, POINTER(ctypes.c_double)) 130 | vectors = np.ctypeslib.as_array(ptr, shape=(n_pairs, 3)) 131 | 132 | # assemble output 133 | 134 | data = [] 135 | for quantity in quantities: 136 | if quantity == "P": 137 | if copy: 138 | data.append(pairs.copy()) 139 | else: 140 | data.append(pairs) 141 | 142 | if quantity == "i": 143 | if copy: 144 | data.append(pairs[:, 0].copy()) 145 | else: 146 | data.append(pairs[:, 0]) 147 | 148 | elif quantity == "j": 149 | if copy: 150 | data.append(pairs[:, 1].copy()) 151 | else: 152 | data.append(pairs[:, 1]) 153 | 154 | elif quantity == "S": 155 | if copy: 156 | data.append(shifts.copy()) 157 | else: 158 | data.append(shifts) 159 | 160 | elif quantity == "d": 161 | if copy: 162 | data.append(distances.copy()) 163 | else: 164 | data.append(distances) 165 | 166 | elif quantity == "D": 167 | if copy: 168 | data.append(vectors.copy()) 169 | else: 170 | data.append(vectors) 171 | 172 | return tuple(data) 173 | -------------------------------------------------------------------------------- /python/vesin/vesin/metatomic/__init__.py: -------------------------------------------------------------------------------- 1 | from ._model import compute_requested_neighbors # noqa: F401 2 | from ._neighbors import NeighborList # noqa: F401 3 | -------------------------------------------------------------------------------- /python/vesin/vesin/metatomic/_model.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | 3 | import torch 4 | from metatomic.torch import ( 5 | AtomisticModel, 6 | ModelInterface, 
7 | NeighborListOptions, 8 | System, 9 | ) 10 | 11 | from ._neighbors import NeighborList 12 | 13 | 14 | def compute_requested_neighbors( 15 | systems: Union[List[System], System], 16 | system_length_unit: str, 17 | model: Union[AtomisticModel, ModelInterface], 18 | model_length_unit: Optional[str] = None, 19 | check_consistency: bool = False, 20 | ): 21 | """ 22 | Compute all neighbors lists requested by the ``model`` through 23 | ``requested_neighbor_lists()`` member functions, and store them inside all the 24 | ``systems``. 25 | 26 | :param systems: Single system or list of systems for which we need to compute the 27 | neighbor lists that the model requires. 28 | :param system_length_unit: unit of length used by the data in ``systems`` 29 | :param model: :py:class:`AtomisticModel` or any ``torch.nn.Module`` following the 30 | :py:class:`ModelInterface` 31 | :param model_length_unit: unit of length used by the model, optional. This is only 32 | required when giving a raw model instead of a :py:class:`AtomisticModel`. 
33 | :param check_consistency: whether to run additional checks on the neighbor lists 34 | validity 35 | """ 36 | 37 | if isinstance(model, AtomisticModel): 38 | if model_length_unit is not None: 39 | if model.capabilities().length_unit != model_length_unit: 40 | raise ValueError( 41 | f"the given `model_length_unit` ({model_length_unit}) does not " 42 | f"match the model capabilities ({model.capabilities().length_unit})" 43 | ) 44 | 45 | all_options = model.requested_neighbor_lists() 46 | elif isinstance(model, torch.nn.Module): 47 | if model_length_unit is None: 48 | raise ValueError( 49 | "`model_length_unit` parameter is required when not " 50 | "using AtomisticModel" 51 | ) 52 | 53 | all_options = [] 54 | _get_requested_neighbor_lists( 55 | model, model.__class__.__name__, all_options, model_length_unit 56 | ) 57 | 58 | if not isinstance(systems, list): 59 | systems = [systems] 60 | 61 | for options in all_options: 62 | calculator = NeighborList( 63 | options, 64 | system_length_unit, 65 | check_consistency=check_consistency, 66 | ) 67 | 68 | for system in systems: 69 | neighbors = calculator.compute(system) 70 | system.add_neighbor_list(options, neighbors) 71 | 72 | 73 | def _get_requested_neighbor_lists( 74 | module: torch.nn.Module, 75 | module_name: str, 76 | requested: List[NeighborListOptions], 77 | length_unit: str, 78 | ): 79 | """ 80 | Recursively extract the requested neighbor lists from a non-exported metatomic 81 | model. 
82 | """ 83 | if hasattr(module, "requested_neighbor_lists"): 84 | for new_options in module.requested_neighbor_lists(): 85 | new_options.add_requestor(module_name) 86 | 87 | already_requested = False 88 | for existing in requested: 89 | if existing == new_options: 90 | already_requested = True 91 | for requestor in new_options.requestors(): 92 | existing.add_requestor(requestor) 93 | 94 | if not already_requested: 95 | if new_options.length_unit not in ["", length_unit]: 96 | raise ValueError( 97 | f"NeighborsListOptions from {module_name} already have a " 98 | f"length unit ('{new_options.length_unit}') which does not " 99 | f"match the model length units ('{length_unit}')" 100 | ) 101 | 102 | new_options.length_unit = length_unit 103 | requested.append(new_options) 104 | 105 | for child_name, child in module.named_children(): 106 | _get_requested_neighbor_lists( 107 | module=child, 108 | module_name=module_name + "." + child_name, 109 | requested=requested, 110 | length_unit=length_unit, 111 | ) 112 | -------------------------------------------------------------------------------- /python/vesin/vesin/metatomic/_neighbors.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from metatensor.torch import Labels, TensorBlock 3 | from metatomic.torch import NeighborListOptions, System, register_autograd_neighbors 4 | 5 | from .. 
import NeighborList as NeighborListNumpy 6 | 7 | 8 | try: 9 | from vesin.torch import NeighborList as NeighborListTorch 10 | 11 | except ImportError: 12 | 13 | class NeighborListTorch: 14 | def __init__(self, cutoff: float, full_list: bool): 15 | raise ValueError("torchscript=True requires `vesin-torch` as a dependency") 16 | 17 | def compute( 18 | self, 19 | points: torch.Tensor, 20 | box: torch.Tensor, 21 | periodic: bool, 22 | quantities: str, 23 | copy: bool = True, 24 | ): 25 | pass 26 | 27 | 28 | class NeighborList: 29 | """ 30 | A neighbor list calculator that can be used with metatomic's models. 31 | 32 | The main difference with the other calculators is the automatic handling of 33 | different length unit between what the model expects and what the ``System`` are 34 | using. 35 | 36 | .. seealso:: 37 | 38 | The :py:func:`vesin.metatomic.compute_requested_neighbors` function can be used 39 | to automatically compute and store all neighbor lists required by a given model. 40 | """ 41 | 42 | def __init__( 43 | self, 44 | options: NeighborListOptions, 45 | length_unit: str, 46 | torchscript: bool = False, 47 | check_consistency: bool = False, 48 | ): 49 | """ 50 | :param options: :py:class:`metatomic.torch.NeighborListOptions` defining the 51 | parameters of the neighbor list 52 | :param length_unit: unit of length used for the systems data 53 | :param torchscript: whether this function should be compatible with TorchScript 54 | or not. If ``True``, this requires installing the ``vesin-torch`` package. 55 | :param check_consistency: whether to run additional checks on the neighbor list 56 | validity 57 | 58 | Example 59 | ------- 60 | 61 | >>> from vesin.metatomic import NeighborList 62 | >>> from metatomic.torch import System, NeighborListOptions 63 | >>> import torch 64 | >>> system = System( 65 | ... positions=torch.eye(3).requires_grad_(True), 66 | ... cell=4 * torch.eye(3).requires_grad_(True), 67 | ... types=torch.tensor([8, 1, 1]), 68 | ... 
pbc=torch.ones(3, dtype=bool), 69 | ... ) 70 | >>> options = NeighborListOptions(cutoff=4.0, full_list=True, strict=False) 71 | >>> calculator = NeighborList(options, length_unit="Angstrom") 72 | >>> neighbors = calculator.compute(system) 73 | >>> neighbors 74 | TensorBlock 75 | samples (18): ['first_atom', 'second_atom', 'cell_shift_a', 'cell_shift_b', 'cell_shift_c'] 76 | components (3): ['xyz'] 77 | properties (1): ['distance'] 78 | gradients: None 79 | 80 | 81 | The returned TensorBlock can then be registered with the system 82 | 83 | >>> system.add_neighbor_list(options, neighbors) 84 | """ # noqa: E501 85 | 86 | self.options = options 87 | self.length_unit = length_unit 88 | self.check_consistency = check_consistency 89 | 90 | if torch.jit.is_scripting() or torchscript: 91 | self._nl = NeighborListTorch( 92 | cutoff=self.options.engine_cutoff(self.length_unit), 93 | full_list=self.options.full_list, 94 | ) 95 | else: 96 | self._nl = NeighborListNumpy( 97 | cutoff=self.options.engine_cutoff(self.length_unit), 98 | full_list=self.options.full_list, 99 | ) 100 | 101 | # cached Labels 102 | self._components = [Labels("xyz", torch.tensor([[0], [1], [2]]))] 103 | self._properties = Labels(["distance"], torch.tensor([[0]])) 104 | 105 | def compute(self, system: System) -> TensorBlock: 106 | """ 107 | Compute the neighbor list for the given :py:class:`metatomic.torch.System`. 108 | 109 | :param system: a :py:class:`metatomic.torch.System` containing data about a 110 | single structure. If the positions or cell of this system require gradients, 111 | the neighbors list values computational graph will be set accordingly. 112 | 113 | The positions and cell need to be in the length unit defined for this 114 | :py:class:`NeighborList` calculator. 
115 | """ 116 | 117 | # move to float64, as vesin only works in torch64 118 | points = system.positions.to(torch.float64).detach() 119 | box = system.cell.to(torch.float64).detach() 120 | if torch.all(system.pbc): 121 | periodic = True 122 | elif not torch.any(system.pbc): 123 | periodic = False 124 | else: 125 | raise NotImplementedError( 126 | "vesin currently does not support mixed periodic and non-periodic " 127 | "boundary conditions" 128 | ) 129 | 130 | # computes neighbor list 131 | (P, S, D) = self._nl.compute( 132 | points=points, box=box, periodic=periodic, quantities="PSD", copy=True 133 | ) 134 | P = torch.as_tensor(P, dtype=torch.int32) 135 | S = torch.as_tensor(S, dtype=torch.int32) 136 | D = torch.as_tensor(D, dtype=system.positions.dtype) 137 | 138 | # converts to a suitable TensorBlock format 139 | neighbors = TensorBlock( 140 | D.reshape(-1, 3, 1).to(system.positions.dtype), 141 | samples=Labels( 142 | names=[ 143 | "first_atom", 144 | "second_atom", 145 | "cell_shift_a", 146 | "cell_shift_b", 147 | "cell_shift_c", 148 | ], 149 | values=torch.hstack([P, S]), 150 | ), 151 | components=self._components, 152 | properties=self._properties, 153 | ) 154 | 155 | register_autograd_neighbors( 156 | system, neighbors, check_consistency=self.check_consistency 157 | ) 158 | 159 | return neighbors 160 | -------------------------------------------------------------------------------- /python/vesin_torch/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include pyproject.toml 2 | include LICENSE 3 | include VERSION 4 | 5 | recursive-include lib * 6 | recursive-include build-backend *.py 7 | -------------------------------------------------------------------------------- /python/vesin_torch/README.md: -------------------------------------------------------------------------------- 1 | ../../README.md -------------------------------------------------------------------------------- /python/vesin_torch/VERSION: 
-------------------------------------------------------------------------------- 1 | ../../vesin/VERSION -------------------------------------------------------------------------------- /python/vesin_torch/build-backend/backend.py: -------------------------------------------------------------------------------- 1 | # This is a custom Python build backend wrapping setuptool's to add a build-time 2 | # dependencies on torch/cmake when building the wheel and not the sdist 3 | import os 4 | 5 | from setuptools import build_meta 6 | 7 | 8 | ROOT = os.path.realpath(os.path.dirname(__file__)) 9 | 10 | FORCED_TORCH_VERSION = os.environ.get("VESIN_TORCH_BUILD_WITH_TORCH_VERSION") 11 | if FORCED_TORCH_VERSION is not None: 12 | TORCH_DEP = f"torch =={FORCED_TORCH_VERSION}" 13 | else: 14 | TORCH_DEP = "torch >=2.3" 15 | 16 | # ==================================================================================== # 17 | # Build backend functions definition # 18 | # ==================================================================================== # 19 | 20 | # Use the default version of these 21 | prepare_metadata_for_build_wheel = build_meta.prepare_metadata_for_build_wheel 22 | get_requires_for_build_sdist = build_meta.get_requires_for_build_sdist 23 | build_wheel = build_meta.build_wheel 24 | build_sdist = build_meta.build_sdist 25 | 26 | 27 | # Special dependencies to build the wheels 28 | def get_requires_for_build_wheel(config_settings=None): 29 | defaults = build_meta.get_requires_for_build_wheel(config_settings) 30 | return defaults + ["cmake", TORCH_DEP] 31 | -------------------------------------------------------------------------------- /python/vesin_torch/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "vesin-torch" 3 | dynamic = ["version", "dependencies"] 4 | requires-python = ">=3.9" 5 | authors = [ 6 | {name = "Guillaume Fraux", email = "guillaume.fraux@epfl.ch"}, 7 | ] 8 | 9 | readme = 
"README.md" 10 | license = "BSD-3-Clause" 11 | description = "Computing neighbor lists for atomistic system, in TorchScript" 12 | 13 | classifiers = [ 14 | "Development Status :: 4 - Beta", 15 | "Intended Audience :: Science/Research", 16 | "Operating System :: POSIX", 17 | "Operating System :: MacOS :: MacOS X", 18 | "Operating System :: Microsoft :: Windows", 19 | "Programming Language :: Python", 20 | "Programming Language :: Python :: 3", 21 | "Topic :: Scientific/Engineering", 22 | "Topic :: Scientific/Engineering :: Bio-Informatics", 23 | "Topic :: Scientific/Engineering :: Chemistry", 24 | "Topic :: Scientific/Engineering :: Physics", 25 | "Topic :: Software Development :: Libraries", 26 | "Topic :: Software Development :: Libraries :: Python Modules", 27 | ] 28 | 29 | [project.urls] 30 | homepage = "https://github.com/Luthaf/vesin/" 31 | documentation = "https://luthaf.fr/vesin/" 32 | repository = "https://github.com/Luthaf/vesin/" 33 | 34 | ### ======================================================================== ### 35 | 36 | [build-system] 37 | requires = [ 38 | "setuptools >=77", 39 | "wheel >=0.41", 40 | ] 41 | # use a custom build backend to add a dependency on torch/cmake only when 42 | # building wheels 43 | build-backend = "backend" 44 | backend-path = ["build-backend"] 45 | 46 | [tool.setuptools] 47 | zip-safe = false 48 | 49 | [tool.setuptools.packages.find] 50 | include = ["vesin*"] 51 | namespaces = true 52 | -------------------------------------------------------------------------------- /python/vesin_torch/setup.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import shutil 4 | import subprocess 5 | import sys 6 | 7 | from setuptools import Extension, setup 8 | from setuptools.command.bdist_egg import bdist_egg 9 | from setuptools.command.build_ext import build_ext 10 | from setuptools.command.sdist import sdist 11 | from wheel.bdist_wheel import bdist_wheel 12 | 13 | 14 
| ROOT = os.path.realpath(os.path.dirname(__file__)) 15 | 16 | VESIN_BUILD_TYPE = os.environ.get("VESIN_BUILD_TYPE", "release") 17 | if VESIN_BUILD_TYPE not in ["debug", "release"]: 18 | raise Exception( 19 | f"invalid build type passed: '{VESIN_BUILD_TYPE}', " 20 | "expected 'debug' or 'release'" 21 | ) 22 | 23 | 24 | class universal_wheel(bdist_wheel): 25 | # When building the wheel, the `wheel` package assumes that if we have a 26 | # binary extension then we are linking to `libpython.so`; and thus the wheel 27 | # is only usable with a single python version. This is not the case for 28 | # here, and the wheel will be compatible with any Python >=3.7. This is 29 | # tracked in https://github.com/pypa/wheel/issues/185, but until then we 30 | # manually override the wheel tag. 31 | def get_tag(self): 32 | tag = bdist_wheel.get_tag(self) 33 | # tag[2:] contains the os/arch tags, we want to keep them 34 | return ("py3", "none") + tag[2:] 35 | 36 | 37 | class cmake_ext(build_ext): 38 | """ 39 | Build the native library using cmake 40 | """ 41 | 42 | def run(self): 43 | import torch 44 | 45 | source_dir = os.path.join(ROOT, "lib") 46 | if not os.path.exists(source_dir): 47 | # we are building from a checkout 48 | source_dir = os.path.join(ROOT, "..", "..", "vesin") 49 | 50 | build_dir = os.path.join(ROOT, "build", "cmake-build") 51 | 52 | # Install the shared library in a prefix matching the torch version used to 53 | # compile the code. This allows having multiple version of this shared library 54 | # inside the wheel; and dynamically pick the right one. 
55 | torch_major, torch_minor, *_ = torch.__version__.split(".") 56 | install_dir = os.path.join( 57 | os.path.realpath(self.build_lib), 58 | "vesin", 59 | "torch", 60 | f"torch-{torch_major}.{torch_minor}", 61 | ) 62 | 63 | os.makedirs(build_dir, exist_ok=True) 64 | 65 | cmake_options = [ 66 | f"-DCMAKE_INSTALL_PREFIX={install_dir}", 67 | f"-DCMAKE_BUILD_TYPE={VESIN_BUILD_TYPE}", 68 | f"-DCMAKE_PREFIX_PATH={torch.utils.cmake_prefix_path}", 69 | "-DBUILD_SHARED_LIBS=ON", 70 | "-DVESIN_TORCH=ON", 71 | ] 72 | 73 | subprocess.run( 74 | ["cmake", source_dir, *cmake_options], 75 | cwd=build_dir, 76 | check=True, 77 | ) 78 | subprocess.run( 79 | ["cmake", "--build", build_dir, "--target", "install"], 80 | check=True, 81 | ) 82 | 83 | # do not include the non-torch vesin lib in the wheel 84 | for file in glob.glob(os.path.join(install_dir, "bin", "*")): 85 | if "vesin_torch" not in os.path.basename(file): 86 | os.unlink(file) 87 | 88 | for file in glob.glob(os.path.join(install_dir, "lib", "*")): 89 | if "vesin_torch" not in os.path.basename(file): 90 | os.unlink(file) 91 | 92 | for file in glob.glob(os.path.join(install_dir, "include", "*")): 93 | if "vesin_torch" not in os.path.basename(file): 94 | os.unlink(file) 95 | 96 | 97 | class bdist_egg_disabled(bdist_egg): 98 | """Disabled version of bdist_egg 99 | 100 | Prevents setup.py install performing setuptools' default easy_install, 101 | which it should never ever do. 102 | """ 103 | 104 | def run(self): 105 | sys.exit( 106 | "Aborting implicit building of eggs.\nUse `pip install .` or " 107 | "`python -m build --wheel . && pip install " 108 | "dist/metatensor_core-*.whl` to install from source." 
109 | ) 110 | 111 | 112 | class sdist_with_lib(sdist): 113 | """ 114 | Create a sdist including the code for the native library 115 | """ 116 | 117 | def run(self): 118 | # generate extra files 119 | shutil.copytree(os.path.join(ROOT, "..", "..", "vesin"), "lib") 120 | 121 | # run original sdist 122 | super().run() 123 | 124 | # cleanup 125 | shutil.rmtree("lib") 126 | 127 | 128 | if __name__ == "__main__": 129 | if sys.platform == "win32": 130 | # On Windows, starting with PyTorch 2.3, the file shm.dll in torch has a 131 | # dependency on mkl DLLs. When building the code using pip build isolation, pip 132 | # installs the mkl package in a place where the os is not trying to load 133 | # 134 | # This is a very similar fix to https://github.com/pytorch/pytorch/pull/126095, 135 | # except only applying when importing torch from a build-isolation virtual 136 | # environment created by pip (`python -m build` does not seems to suffer from 137 | # this). 138 | import wheel 139 | 140 | pip_virtualenv = os.path.realpath( 141 | os.path.join( 142 | os.path.dirname(wheel.__file__), 143 | "..", 144 | "..", 145 | "..", 146 | "..", 147 | ) 148 | ) 149 | mkl_dll_dir = os.path.join( 150 | pip_virtualenv, 151 | "normal", 152 | "Library", 153 | "bin", 154 | ) 155 | 156 | if os.path.exists(mkl_dll_dir): 157 | os.add_dll_directory(mkl_dll_dir) 158 | 159 | # End of Windows/MKL/PIP hack 160 | 161 | install_requires = [] 162 | forced_torch_version = os.environ.get("VESIN_TORCH_BUILD_WITH_TORCH_VERSION") 163 | if forced_torch_version is not None: 164 | install_requires.append(f"torch =={forced_torch_version}") 165 | else: 166 | install_requires.append("torch >=2.3") 167 | 168 | setup( 169 | version=open("VERSION").read().strip(), 170 | install_requires=install_requires, 171 | ext_modules=[ 172 | # only declare the extension, it is built & copied as required by cmake 173 | # in the build_ext command 174 | Extension(name="vesin_torch", sources=[]), 175 | ], 176 | cmdclass={ 177 | "sdist": 
sdist_with_lib, 178 | "build_ext": cmake_ext, 179 | "bdist_egg": bdist_egg if "bdist_egg" in sys.argv else bdist_egg_disabled, 180 | "bdist_wheel": universal_wheel, 181 | }, 182 | ) 183 | -------------------------------------------------------------------------------- /python/vesin_torch/tests/test_autograd.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from vesin.torch import NeighborList 5 | 6 | 7 | @pytest.mark.parametrize("full_list", [True, False]) 8 | @pytest.mark.parametrize("requires_grad", [(True, True), (False, True), (True, False)]) 9 | @pytest.mark.parametrize("quantities", ["ijS", "D", "d", "ijSDd"]) 10 | def test_autograd(full_list, requires_grad, quantities): 11 | torch.manual_seed(0xDEADBEEF) 12 | 13 | points_fractional = torch.rand((34, 3), dtype=torch.float64) 14 | box = torch.diag(5 * torch.rand(3, dtype=torch.float64)) 15 | box += torch.rand((3, 3), dtype=torch.float64) 16 | 17 | points = points_fractional @ box 18 | 19 | points.requires_grad_(requires_grad[0]) 20 | box.requires_grad_(requires_grad[1]) 21 | 22 | calculator = NeighborList(cutoff=7.8, full_list=full_list) 23 | 24 | def compute(points, box): 25 | results = calculator.compute(points, box, periodic=True, quantities=quantities) 26 | return results 27 | 28 | torch.autograd.gradcheck(compute, (points, box), fast_mode=True) 29 | -------------------------------------------------------------------------------- /python/vesin_torch/tests/test_metatensor.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional 2 | 3 | import pytest 4 | import torch 5 | from metatensor.torch import Labels, TensorMap 6 | from metatensor.torch.atomistic import ( 7 | MetatensorAtomisticModel, 8 | ModelCapabilities, 9 | ModelMetadata, 10 | ModelOutput, 11 | NeighborListOptions, 12 | System, 13 | ) 14 | 15 | from vesin.torch.metatensor import NeighborList, 
compute_requested_neighbors 16 | 17 | 18 | def test_errors(): 19 | positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 2.0]], dtype=torch.float64) 20 | cell = 4 * torch.eye(3, dtype=torch.float64) 21 | system = System( 22 | positions=positions, 23 | cell=cell, 24 | pbc=torch.ones(3, dtype=bool), 25 | types=torch.tensor([6, 8]), 26 | ) 27 | 28 | options = NeighborListOptions(cutoff=3.5, full_list=True, strict=True) 29 | calculator = NeighborList(options, length_unit="A") 30 | 31 | system.pbc[0] = False 32 | message = ( 33 | "vesin currently does not support mixed periodic and non-periodic " 34 | "boundary conditions" 35 | ) 36 | with pytest.raises(NotImplementedError, match=message): 37 | calculator.compute(system) 38 | 39 | 40 | def test_script(): 41 | positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 2.0]], dtype=torch.float64) 42 | cell = 4 * torch.eye(3, dtype=torch.float64) 43 | system = System( 44 | positions=positions, 45 | cell=cell, 46 | pbc=torch.ones(3, dtype=bool), 47 | types=torch.tensor([6, 8]), 48 | ) 49 | 50 | options = NeighborListOptions(cutoff=3.5, full_list=True, strict=True) 51 | calculator = torch.jit.script(NeighborList(options, length_unit="A")) 52 | calculator.compute(system) 53 | 54 | 55 | def test_backward(): 56 | positions = torch.tensor( 57 | [[0.0, 0.0, 0.0], [1.0, 1.0, 2.0]], dtype=torch.float64, requires_grad=True 58 | ) 59 | cell = (4 * torch.eye(3, dtype=torch.float64)).clone().requires_grad_(True) 60 | system = System( 61 | positions=positions, 62 | cell=cell, 63 | pbc=torch.ones(3, dtype=bool), 64 | types=torch.tensor([6, 8]), 65 | ) 66 | 67 | options = NeighborListOptions(cutoff=3.5, full_list=True, strict=True) 68 | calculator = NeighborList(options, length_unit="A") 69 | neighbors = calculator.compute(system) 70 | 71 | value = ((neighbors.values) ** 2).sum() * torch.linalg.det(cell) 72 | value.backward() 73 | 74 | # check there are gradients, and they are not zero 75 | assert positions.grad is not None 76 | assert 
cell.grad is not None 77 | assert torch.linalg.norm(positions.grad) > 0 78 | assert torch.linalg.norm(cell.grad) > 0 79 | 80 | 81 | class InnerModule(torch.nn.Module): 82 | def requested_neighbor_lists(self) -> List[NeighborListOptions]: 83 | return [NeighborListOptions(cutoff=3.4, full_list=False, strict=True)] 84 | 85 | 86 | class OuterModule(torch.nn.Module): 87 | def __init__(self): 88 | super().__init__() 89 | self.inner = InnerModule() 90 | 91 | def requested_neighbor_lists(self) -> List[NeighborListOptions]: 92 | return [NeighborListOptions(cutoff=5.2, full_list=True, strict=False)] 93 | 94 | def forward( 95 | self, 96 | systems: List[System], 97 | outputs: Dict[str, ModelOutput], 98 | selected_atoms: Optional[Labels], 99 | ) -> Dict[str, TensorMap]: 100 | return {} 101 | 102 | 103 | def test_model(): 104 | positions = torch.tensor( 105 | [[0.0, 0.0, 0.0], [1.0, 1.0, 2.0]], dtype=torch.float64, requires_grad=True 106 | ) 107 | cell = (4 * torch.eye(3, dtype=torch.float64)).clone().requires_grad_(True) 108 | pbc = torch.ones(3, dtype=bool) 109 | types = torch.tensor([6, 8]) 110 | systems = [ 111 | System(positions=positions, cell=cell, pbc=pbc, types=types), 112 | System(positions=positions, cell=cell, pbc=pbc, types=types), 113 | ] 114 | 115 | # Using a "raw" model 116 | model = OuterModule() 117 | compute_requested_neighbors( 118 | systems=systems, system_length_unit="A", model=model, model_length_unit="A" 119 | ) 120 | 121 | for system in systems: 122 | all_options = system.known_neighbor_lists() 123 | assert len(all_options) == 2 124 | assert all_options[0].requestors() == ["OuterModule"] 125 | assert all_options[0].cutoff == 5.2 126 | assert all_options[1].requestors() == ["OuterModule.inner"] 127 | assert all_options[1].cutoff == 3.4 128 | 129 | # Using a MetatensorAtomisticModel model 130 | capabilities = ModelCapabilities( 131 | length_unit="A", 132 | interaction_range=6.0, 133 | supported_devices=["cpu"], 134 | dtype="float64", 135 | ) 136 | model = 
MetatensorAtomisticModel(model.eval(), ModelMetadata(), capabilities) 137 | compute_requested_neighbors( 138 | systems=System(positions=positions, cell=cell, pbc=pbc, types=types), 139 | system_length_unit="A", 140 | model=model, 141 | ) 142 | 143 | for system in systems: 144 | all_options = system.known_neighbor_lists() 145 | assert len(all_options) == 2 146 | assert all_options[0].requestors() == ["OuterModule"] 147 | assert all_options[0].cutoff == 5.2 148 | assert all_options[1].requestors() == ["OuterModule.inner"] 149 | assert all_options[1].cutoff == 3.4 150 | 151 | message = ( 152 | "the given `model_length_unit` \\(nm\\) does not match the model " 153 | "capabilities \\(A\\)" 154 | ) 155 | with pytest.raises(ValueError, match=message): 156 | compute_requested_neighbors( 157 | systems=System(positions=positions, cell=cell, pbc=pbc, types=types), 158 | system_length_unit="A", 159 | model=model, 160 | model_length_unit="nm", 161 | ) 162 | -------------------------------------------------------------------------------- /python/vesin_torch/tests/test_metatomic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from metatomic.torch import NeighborListOptions, System 3 | 4 | from vesin.metatomic import NeighborList 5 | 6 | 7 | def test_script(): 8 | positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 2.0]], dtype=torch.float64) 9 | cell = 4 * torch.eye(3, dtype=torch.float64) 10 | system = System( 11 | positions=positions, 12 | cell=cell, 13 | pbc=torch.ones(3, dtype=bool), 14 | types=torch.tensor([6, 8]), 15 | ) 16 | 17 | options = NeighborListOptions(cutoff=3.5, full_list=True, strict=True) 18 | calculator = torch.jit.script( 19 | NeighborList(options, length_unit="A", torchscript=True) 20 | ) 21 | calculator.compute(system) 22 | -------------------------------------------------------------------------------- /python/vesin_torch/tests/test_neighbors.py: 
-------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import pytest 4 | import torch 5 | 6 | from vesin.torch import NeighborList 7 | 8 | 9 | def test_errors(): 10 | points = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]], dtype=torch.float64) 11 | box = torch.zeros((3, 3), dtype=torch.float64) 12 | 13 | calculator = NeighborList(cutoff=2.8, full_list=True) 14 | 15 | message = "only float64 dtype is supported in vesin" 16 | with pytest.raises(ValueError, match=message): 17 | calculator.compute( 18 | points.to(torch.float32), 19 | box.to(torch.float32), 20 | periodic=False, 21 | quantities="ij", 22 | ) 23 | 24 | message = "expected `points` and `box` to have the same dtype, got Double and Float" 25 | with pytest.raises(ValueError, match=message): 26 | calculator.compute( 27 | points, 28 | box.to(torch.float32), 29 | periodic=False, 30 | quantities="ij", 31 | ) 32 | 33 | message = "expected `points` and `box` to have the same device, got cpu and meta" 34 | with pytest.raises(ValueError, match=message): 35 | calculator.compute( 36 | points, 37 | box.to(device="meta"), 38 | periodic=False, 39 | quantities="ij", 40 | ) 41 | 42 | message = "unexpected character in `quantities`: Q" 43 | with pytest.raises(ValueError, match=message): 44 | calculator.compute( 45 | points, 46 | box, 47 | periodic=False, 48 | quantities="ijQ", 49 | ) 50 | 51 | message = "device meta is not supported in vesin" 52 | with pytest.raises(RuntimeError, match=message): 53 | calculator.compute( 54 | points.to(device="meta"), 55 | box.to(device="meta"), 56 | periodic=False, 57 | quantities="ij", 58 | ) 59 | 60 | 61 | @pytest.mark.parametrize("quantities", ["ijS", "D", "d", "ijSDd"]) 62 | def test_all_alone_no_neighbors(quantities): 63 | points = torch.tensor([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]], dtype=torch.float64) 64 | box = torch.eye(3, dtype=torch.float64) 65 | 66 | calculator = NeighborList(cutoff=0.1, full_list=True) 67 | outputs = 
calculator.compute(points, box, True, quantities) 68 | 69 | if "ij" in quantities: 70 | assert list(outputs[quantities.index("i")].shape) == [0] 71 | assert list(outputs[quantities.index("j")].shape) == [0] 72 | 73 | if "S" in quantities: 74 | assert list(outputs[quantities.index("S")].shape) == [0, 3] 75 | 76 | if "D" in quantities: 77 | assert list(outputs[quantities.index("D")].shape) == [0, 3] 78 | assert not outputs[quantities.index("D")].requires_grad 79 | 80 | if "d" in quantities: 81 | assert list(outputs[quantities.index("d")].shape) == [0] 82 | assert not outputs[quantities.index("d")].requires_grad 83 | 84 | points.requires_grad_(True) 85 | box.requires_grad_(True) 86 | outputs = calculator.compute(points, box, True, quantities) 87 | 88 | if "ij" in quantities: 89 | assert list(outputs[quantities.index("i")].shape) == [0] 90 | assert list(outputs[quantities.index("j")].shape) == [0] 91 | 92 | if "S" in quantities: 93 | assert list(outputs[quantities.index("S")].shape) == [0, 3] 94 | 95 | if "D" in quantities: 96 | assert list(outputs[quantities.index("D")].shape) == [0, 3] 97 | assert outputs[quantities.index("D")].requires_grad 98 | 99 | if "d" in quantities: 100 | assert list(outputs[quantities.index("d")].shape) == [0] 101 | assert outputs[quantities.index("d")].requires_grad 102 | 103 | 104 | class NeighborListWrap: 105 | def __init__(self, cutoff: float, full_list: bool): 106 | self._c = NeighborList(cutoff=cutoff, full_list=full_list) 107 | 108 | def compute( 109 | self, 110 | points: torch.Tensor, 111 | box: torch.Tensor, 112 | periodic: bool, 113 | quantities: str, 114 | copy: bool, 115 | ) -> List[torch.Tensor]: 116 | return self._c.compute( 117 | points=points, 118 | box=box, 119 | periodic=periodic, 120 | quantities=quantities, 121 | copy=copy, 122 | ) 123 | 124 | 125 | def test_script(): 126 | class TestModule(torch.nn.Module): 127 | def forward(self, x: NeighborListWrap) -> NeighborListWrap: 128 | return x 129 | 130 | module = TestModule() 
131 | module = torch.jit.script(module) 132 | -------------------------------------------------------------------------------- /python/vesin_torch/vesin/torch/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib.metadata 2 | 3 | from ._c_lib import _load_library 4 | from ._neighbors import NeighborList # noqa: F401 5 | 6 | 7 | __version__ = importlib.metadata.version("vesin-torch") 8 | 9 | _load_library() 10 | -------------------------------------------------------------------------------- /python/vesin_torch/vesin/torch/_c_lib.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import re 4 | import sys 5 | from collections import namedtuple 6 | 7 | import torch 8 | 9 | 10 | Version = namedtuple("Version", ["major", "minor", "patch"]) 11 | 12 | 13 | def parse_version(version): 14 | match = re.match(r"(\d+)\.(\d+)\.(\d+).*", version) 15 | if match: 16 | return Version(*map(int, match.groups())) 17 | else: 18 | raise ValueError("Invalid version string format") 19 | 20 | 21 | _HERE = os.path.realpath(os.path.dirname(__file__)) 22 | 23 | 24 | def _lib_path(): 25 | torch_version = parse_version(torch.__version__) 26 | expected_prefix = os.path.join( 27 | _HERE, f"torch-{torch_version.major}.{torch_version.minor}" 28 | ) 29 | if os.path.exists(expected_prefix): 30 | if sys.platform.startswith("darwin"): 31 | path = os.path.join(expected_prefix, "lib", "libvesin_torch.dylib") 32 | windows = False 33 | elif sys.platform.startswith("linux"): 34 | path = os.path.join(expected_prefix, "lib", "libvesin_torch.so") 35 | windows = False 36 | elif sys.platform.startswith("win"): 37 | path = os.path.join(expected_prefix, "bin", "vesin_torch.dll") 38 | windows = True 39 | else: 40 | raise ImportError("Unknown platform. 
Please edit this file") 41 | 42 | if os.path.isfile(path): 43 | if windows: 44 | _check_dll(path) 45 | return path 46 | else: 47 | raise ImportError("Could not find vesin_torch shared library at " + path) 48 | 49 | # gather which torch version(s) the current install was built 50 | # with to create the error message 51 | existing_versions = [] 52 | for prefix in glob.glob(os.path.join(_HERE, "torch-*")): 53 | existing_versions.append(os.path.basename(prefix)[6:]) 54 | 55 | if len(existing_versions) == 1: 56 | raise ImportError( 57 | f"Trying to load vesin-torch with torch v{torch.__version__}, " 58 | f"but it was compiled against torch v{existing_versions[0]}, which " 59 | "is not ABI compatible" 60 | ) 61 | else: 62 | all_versions = ", ".join(map(lambda version: f"v{version}", existing_versions)) 63 | raise ImportError( 64 | f"Trying to load vesin-torch with torch v{torch.__version__}, " 65 | f"we found builds for torch {all_versions}; which are not ABI compatible.\n" 66 | "You can try to re-install from source with " 67 | "`pip install vesin-torch --no-binary=vesin-torch`" 68 | ) 69 | 70 | 71 | def _check_dll(path): 72 | """ 73 | Check if the DLL pointer size matches Python (32-bit or 64-bit) 74 | """ 75 | import platform 76 | import struct 77 | 78 | IMAGE_FILE_MACHINE_I386 = 332 79 | IMAGE_FILE_MACHINE_AMD64 = 34404 80 | 81 | machine = None 82 | with open(path, "rb") as fd: 83 | header = fd.read(2).decode(encoding="utf-8", errors="strict") 84 | if header != "MZ": 85 | raise ImportError(path + " is not a DLL") 86 | else: 87 | fd.seek(60) 88 | header = fd.read(4) 89 | header_offset = struct.unpack(" List[torch.Tensor]: 25 | """ 26 | Compute the neighbor list for the system defined by ``positions``, ``box``, and 27 | ``periodic``; returning the requested ``quantities``. 
28 | 29 | ``quantities`` can contain any combination of the following values: 30 | 31 | - ``"i"`` to get the index of the first point in the pair 32 | - ``"j"`` to get the index of the second point in the pair 33 | - ``"P"`` to get the indexes of the two points in the pair simultaneously 34 | - ``"S"`` to get the periodic shift of the pair 35 | - ``"d"`` to get the distance between points in the pair 36 | - ``"D"`` to get the distance vector between points in the pair 37 | 38 | :param points: positions of all points in the system 39 | :param box: bounding box of the system 40 | :param periodic: should we use periodic boundary conditions? 41 | :param quantities: quantities to return, defaults to "ij" 42 | :param copy: should we copy the returned quantities, defaults to ``True``. 43 | Setting this to ``False`` might be a bit faster, but the returned tensors 44 | are view inside this class, and will be invalidated whenever this class is 45 | garbage collected or used to run a new calculation. 46 | 47 | :return: list of :py:class:`torch.Tensor` as indicated by ``quantities``. 
48 | """ 49 | 50 | return self._c.compute( 51 | points=points, 52 | box=box, 53 | periodic=periodic, 54 | quantities=quantities, 55 | copy=copy, 56 | ) 57 | -------------------------------------------------------------------------------- /python/vesin_torch/vesin/torch/metatensor/__init__.py: -------------------------------------------------------------------------------- 1 | from ._model import compute_requested_neighbors # noqa: F401 2 | from ._neighbors import NeighborList # noqa: F401 3 | -------------------------------------------------------------------------------- /python/vesin_torch/vesin/torch/metatensor/_model.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | 3 | import torch 4 | 5 | from ._neighbors import NeighborList 6 | 7 | 8 | try: 9 | from metatensor.torch.atomistic import ( 10 | MetatensorAtomisticModel, 11 | ModelInterface, 12 | NeighborListOptions, 13 | System, 14 | ) 15 | 16 | _HAS_METATENSOR = True 17 | except ModuleNotFoundError: 18 | _HAS_METATENSOR = False 19 | 20 | class MetatensorAtomisticModel: 21 | pass 22 | 23 | class ModelInterface: 24 | pass 25 | 26 | class NeighborListOptions: 27 | pass 28 | 29 | class System: 30 | pass 31 | 32 | 33 | def compute_requested_neighbors( 34 | systems: Union[List[System], System], 35 | system_length_unit: str, 36 | model: Union[MetatensorAtomisticModel, ModelInterface], 37 | model_length_unit: Optional[str] = None, 38 | ): 39 | """ 40 | Compute all neighbors lists requested by the ``model`` through 41 | ``requested_neighbor_lists()`` member functions, and store them inside all the 42 | ``systems``. 43 | 44 | :param systems: Single system or list of systems for which we need to compute the 45 | neighbor lists that the model requires. 
46 | :param system_length_unit: unit of length used by the data in ``systems`` 47 | :param model: :py:class:`MetatensorAtomisticModel` or any ``torch.nn.Module`` 48 | following the :py:class:`ModelInterface` 49 | :param model_length_unit: unit of length used by the model, optional. This is only 50 | required when giving a raw model instead of a 51 | :py:class:`MetatensorAtomisticModel`. 52 | """ 53 | 54 | if isinstance(model, MetatensorAtomisticModel): 55 | if model_length_unit is not None: 56 | if model.capabilities().length_unit != model_length_unit: 57 | raise ValueError( 58 | f"the given `model_length_unit` ({model_length_unit}) does not " 59 | f"match the model capabilities ({model.capabilities().length_unit})" 60 | ) 61 | 62 | all_options = model.requested_neighbor_lists() 63 | elif isinstance(model, torch.nn.Module): 64 | if model_length_unit is None: 65 | raise ValueError( 66 | "`model_length_unit` parameter is required when not " 67 | "using MetatensorAtomisticModel" 68 | ) 69 | 70 | all_options = [] 71 | _get_requested_neighbor_lists( 72 | model, model.__class__.__name__, all_options, model_length_unit 73 | ) 74 | 75 | if not isinstance(systems, list): 76 | systems = [systems] 77 | 78 | for options in all_options: 79 | calculator = NeighborList(options, system_length_unit) 80 | for system in systems: 81 | neighbors = calculator.compute(system) 82 | system.add_neighbor_list(options, neighbors) 83 | 84 | 85 | def _get_requested_neighbor_lists( 86 | module: torch.nn.Module, 87 | module_name: str, 88 | requested: List[NeighborListOptions], 89 | length_unit: str, 90 | ): 91 | """ 92 | Recursively extract the requested neighbor lists from a non-exported metatensor 93 | atomistic model. 
94 | """ 95 | if hasattr(module, "requested_neighbor_lists"): 96 | for new_options in module.requested_neighbor_lists(): 97 | new_options.add_requestor(module_name) 98 | 99 | already_requested = False 100 | for existing in requested: 101 | if existing == new_options: 102 | already_requested = True 103 | for requestor in new_options.requestors(): 104 | existing.add_requestor(requestor) 105 | 106 | if not already_requested: 107 | if new_options.length_unit not in ["", length_unit]: 108 | raise ValueError( 109 | f"NeighborsListOptions from {module_name} already have a " 110 | f"length unit ('{new_options.length_unit}') which does not " 111 | f"match the model length units ('{length_unit}')" 112 | ) 113 | 114 | new_options.length_unit = length_unit 115 | requested.append(new_options) 116 | 117 | for child_name, child in module.named_children(): 118 | _get_requested_neighbor_lists( 119 | module=child, 120 | module_name=module_name + "." + child_name, 121 | requested=requested, 122 | length_unit=length_unit, 123 | ) 124 | -------------------------------------------------------------------------------- /python/vesin_torch/vesin/torch/metatensor/_neighbors.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .. import NeighborList as NeighborListTorch 4 | 5 | 6 | try: # only define a metatensor adapter if metatensor is available 7 | from metatensor.torch import Labels, TensorBlock 8 | from metatensor.torch.atomistic import NeighborListOptions, System 9 | 10 | _HAS_METATENSOR = True 11 | except ModuleNotFoundError: 12 | _HAS_METATENSOR = False 13 | 14 | class Labels: 15 | pass 16 | 17 | class TensorBlock: 18 | pass 19 | 20 | class System: 21 | pass 22 | 23 | class NeighborListOptions: 24 | pass 25 | 26 | 27 | class NeighborList: 28 | """ 29 | A neighbor list calculator that can be used with metatensor's atomistic models. 
30 | 31 | The main difference with the other calculators is the automatic handling of 32 | different length unit between what the model expects and what the ``System`` are 33 | using. 34 | 35 | .. seealso:: 36 | 37 | The :py:func:`vesin.torch.metatensor.compute_requested_neighbors` function can 38 | be used to automatically compute and store all neighbor lists required by a 39 | given model. 40 | """ 41 | 42 | def __init__(self, options: NeighborListOptions, length_unit: str): 43 | """ 44 | :param options: :py:class:`metatensor.torch.atomistic.NeighborListOptions` 45 | defining the parameters of the neighbor list 46 | :param length_unit: unit of length used for the systems data 47 | 48 | Example 49 | ------- 50 | 51 | >>> from vesin.torch.metatensor import NeighborList 52 | >>> from metatensor.torch.atomistic import System, NeighborListOptions 53 | >>> import torch 54 | >>> system = System( 55 | ... positions=torch.eye(3).requires_grad_(True), 56 | ... cell=4 * torch.eye(3).requires_grad_(True), 57 | ... types=torch.tensor([8, 1, 1]), 58 | ... pbc=torch.ones(3, dtype=bool), 59 | ... 
) 60 | >>> options = NeighborListOptions(cutoff=4.0, full_list=True, strict=False) 61 | >>> calculator = NeighborList(options, length_unit="Angstrom") 62 | >>> neighbors = calculator.compute(system) 63 | >>> neighbors 64 | TensorBlock 65 | samples (18): ['first_atom', 'second_atom', 'cell_shift_a', 'cell_shift_b', 'cell_shift_c'] 66 | components (3): ['xyz'] 67 | properties (1): ['distance'] 68 | gradients: None 69 | 70 | 71 | The returned TensorBlock can then be registered with the system 72 | 73 | >>> system.add_neighbor_list(options, neighbors) 74 | """ # noqa: E501 75 | 76 | if not torch.jit.is_scripting(): 77 | if not _HAS_METATENSOR: 78 | raise ModuleNotFoundError( 79 | "`vesin.metatensor` requires the `metatensor-torch` package" 80 | ) 81 | self.options = options 82 | self.length_unit = length_unit 83 | self._nl = NeighborListTorch( 84 | cutoff=self.options.engine_cutoff(self.length_unit), 85 | full_list=self.options.full_list, 86 | ) 87 | 88 | # cached Labels 89 | self._components = [Labels("xyz", torch.tensor([[0], [1], [2]]))] 90 | self._properties = Labels(["distance"], torch.tensor([[0]])) 91 | 92 | def compute(self, system: System) -> TensorBlock: 93 | """ 94 | Compute the neighbor list for the given 95 | :py:class:`metatensor.torch.atomistic.System`. 96 | 97 | :param system: a :py:class:`metatensor.torch.atomistic.System` containing the 98 | data about a structure. If the positions or cell of this system require 99 | gradients, the neighbors list values computational graph will be set 100 | accordingly. 101 | 102 | The positions and cell need to be in the length unit defined for this 103 | :py:class:`NeighborList` calculator. 
104 | """ 105 | 106 | # move to float64, as vesin only works in torch64 107 | points = system.positions.to(torch.float64) 108 | box = system.cell.to(torch.float64) 109 | if torch.all(system.pbc): 110 | periodic = True 111 | elif not torch.any(system.pbc): 112 | periodic = False 113 | else: 114 | raise NotImplementedError( 115 | "vesin currently does not support mixed periodic and non-periodic " 116 | "boundary conditions" 117 | ) 118 | 119 | # computes neighbor list 120 | (P, S, D) = self._nl.compute( 121 | points=points, box=box, periodic=periodic, quantities="PSD", copy=True 122 | ) 123 | 124 | # converts to a suitable TensorBlock format 125 | neighbors = TensorBlock( 126 | D.reshape(-1, 3, 1).to(system.positions.dtype), 127 | samples=Labels( 128 | names=[ 129 | "first_atom", 130 | "second_atom", 131 | "cell_shift_a", 132 | "cell_shift_b", 133 | "cell_shift_c", 134 | ], 135 | values=torch.hstack([P, S]), 136 | ), 137 | components=self._components, 138 | properties=self._properties, 139 | ) 140 | 141 | return neighbors 142 | -------------------------------------------------------------------------------- /ruff.toml: -------------------------------------------------------------------------------- 1 | [lint] 2 | select = ["E", "F", "B", "I"] 3 | ignore = ["B018", "B904"] 4 | 5 | [lint.isort] 6 | lines-after-imports = 2 7 | known-first-party = ["vesin"] 8 | 9 | [format] 10 | docstring-code-format = true 11 | -------------------------------------------------------------------------------- /scripts/clean-python.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # This script removes all temporary files created by Python during 4 | # installation and tests running. 5 | 6 | set -eux 7 | 8 | ROOT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")"/.. 
&& pwd) 9 | cd "$ROOT_DIR" 10 | 11 | rm -rf dist 12 | rm -rf build 13 | rm -rf docs/build 14 | 15 | rm -rf python/vesin/dist 16 | rm -rf python/vesin/build 17 | 18 | rm -rf python/vesin_torch/dist 19 | rm -rf python/vesin_torch/build 20 | 21 | find . -name "*.egg-info" -exec rm -rf "{}" + 22 | find . -name "__pycache__" -exec rm -rf "{}" + 23 | find . -name ".coverage" -exec rm -rf "{}" + 24 | -------------------------------------------------------------------------------- /scripts/create-torch-versions-range.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | This script updates the `Requires-Dist` information in metatensor-torch wheel METADATA 4 | to contain the range of compatible torch versions. It expects newline separated 5 | `Requires-Dist: torch ==...` information (corresponding to wheels built against a single 6 | torch version) and will print `Requires-Dist: torch >=$MIN_VERSION,<${MAX_VERSION+1}` on 7 | the standard output. 8 | 9 | This output can the be used in the merged wheel containing the build against all torch 10 | versions. 
11 | """ 12 | import re 13 | import sys 14 | 15 | 16 | if __name__ == "__main__": 17 | torch_versions_raw = sys.argv[1] 18 | 19 | torch_versions = [] 20 | for version in torch_versions_raw.split("\n"): 21 | if version.strip() == "": 22 | continue 23 | 24 | match = re.match(r"Requires-Dist: torch[ ]?==(\d+)\.(\d+)\.\*", version) 25 | if match is None: 26 | raise ValueError(f"unexpected Requires-Dist format: {version}") 27 | 28 | major, minor = match.groups() 29 | major = int(major) 30 | minor = int(minor) 31 | 32 | version = (major, minor) 33 | 34 | if version in torch_versions: 35 | raise ValueError(f"duplicate torch version: {version}") 36 | 37 | torch_versions.append(version) 38 | 39 | torch_versions = list(sorted(torch_versions)) 40 | 41 | min_version = f"{torch_versions[0][0]}.{torch_versions[0][1]}" 42 | max_version = f"{torch_versions[-1][0]}.{torch_versions[-1][1] + 1}" 43 | 44 | print(f"Requires-Dist: torch >={min_version},<{max_version}") 45 | -------------------------------------------------------------------------------- /scripts/pytest-dont-rewrite-torch.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import torch # noqa 4 | 5 | 6 | PYTEST_DONT_REWRITE = '"""PYTEST_DONT_REWRITE"""' 7 | 8 | if __name__ == "__main__": 9 | try: 10 | # `torch.autograd.gradcheck` is the `torch.autograd.gradcheck.gradcheck`` 11 | # function, so we need to pick the module from `sys.modules` 12 | path = sys.modules["torch.autograd.gradcheck"].__file__ 13 | 14 | with open(path) as fd: 15 | content = fd.read() 16 | 17 | if PYTEST_DONT_REWRITE in content: 18 | sys.exit(0) 19 | 20 | with open(path, "w") as fd: 21 | print(f"rewriting {path} to add PYTEST_DONT_REWRITE") 22 | fd.write(PYTEST_DONT_REWRITE) 23 | fd.write("\n") 24 | fd.write(content) 25 | 26 | except Exception: 27 | print( 28 | "failed to add PYTEST_DONT_REWRITE to `torch.autograd.gradcheck`,", 29 | "tests are likely to fail", 30 | file=sys.stderr, 31 | ) 32 | 
-------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # This is not the actual setup.py for this project, see `python/vesin/setup.py` for it. 2 | # Instead, this file is here to enable `pip install .` from a git checkout or `pip 3 | # install git+https://...` without having to specify a subdirectory 4 | 5 | import os 6 | 7 | from setuptools import setup 8 | 9 | 10 | ROOT = os.path.realpath(os.path.dirname(__file__)) 11 | 12 | setup( 13 | name="vesin-git", 14 | version="0.0.0", 15 | install_requires=[ 16 | f"vesin @ file://{ROOT}/python/vesin", 17 | ], 18 | extras_require={ 19 | "torch": [ 20 | f"vesin-torch @ file://{ROOT}/python/vesin_torch", 21 | ] 22 | }, 23 | packages=[], 24 | ) 25 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | # https://github.com/tox-dev/tox/issues/3238 3 | requires = tox==4.14.0 4 | 5 | envlist = 6 | lint 7 | tests 8 | torch-tests 9 | 10 | [testenv] 11 | lint-folders = python setup.py 12 | package = external 13 | package_env = build-vesin 14 | 15 | [testenv:build-vesin] 16 | passenv = * 17 | deps = 18 | cmake 19 | setuptools 20 | wheel 21 | 22 | commands = 23 | pip wheel python/vesin --no-deps --no-build-isolation --check-build-dependencies --wheel-dir {envtmpdir}/dist 24 | 25 | [testenv:tests] 26 | passenv = * 27 | description = Run the tests of the vesin Python package 28 | deps = 29 | ase 30 | pytest 31 | 32 | metatomic-torch >=0.1,<0.2 33 | 34 | changedir = python/vesin 35 | commands = 36 | pytest {posargs} 37 | 38 | # Enable when doc examples exist; pytest fails if no tests exist 39 | # pytest --doctest-modules --pyargs vesin 40 | 41 | [testenv:torch-tests] 42 | passenv = * 43 | description = Run the tests of the vesin-torch Python package 44 | deps = 45 | pytest 46 | torch 47 
| metatensor-torch >=0.7,<0.8 48 | metatomic-torch >=0.1,<0.2 49 | numpy 50 | 51 | cmake 52 | setuptools 53 | wheel 54 | 55 | changedir = python/vesin_torch 56 | commands = 57 | pip install . --no-deps --no-build-isolation --check-build-dependencies 58 | 59 | # Make torch.autograd.gradcheck work with pytest 60 | python {toxinidir}/scripts/pytest-dont-rewrite-torch.py 61 | 62 | pytest {posargs} 63 | pytest --doctest-modules --pyargs vesin.torch 64 | 65 | 66 | [testenv:cxx-tests] 67 | passenv = * 68 | description = Run the C++ tests 69 | package = skip 70 | deps = cmake 71 | 72 | commands = 73 | cmake -B {envtmpdir} -S vesin -DVESIN_BUILD_TESTS=ON -DCMAKE_BUILD_TYPE=Debug 74 | cmake --build {envtmpdir} --config Debug 75 | ctest --test-dir {envtmpdir} --build-config Debug 76 | 77 | 78 | [testenv:lint] 79 | description = Run linters and formatter 80 | package = skip 81 | deps = 82 | ruff 83 | 84 | commands = 85 | ruff format --diff {[testenv]lint-folders} 86 | ruff check {[testenv]lint-folders} 87 | 88 | 89 | [testenv:format] 90 | description = Abuse tox to do actual formatting on all files.
91 | package = skip 92 | deps = 93 | ruff 94 | commands = 95 | ruff format {[testenv]lint-folders} 96 | ruff check --fix-only {[testenv]lint-folders} 97 | 98 | 99 | [testenv:docs] 100 | passenv = * 101 | description = Invoke sphinx-build to build the HTML docs 102 | deps = 103 | sphinx 104 | breathe >=4.33 # C++ => sphinx through doxygen 105 | furo # sphinx theme 106 | sphinx-design # helpers for nicer docs website (tabs, grids, cards, …) 107 | 108 | torch 109 | metatomic-torch >=0.1,<0.2 110 | cmake 111 | setuptools 112 | wheel 113 | 114 | commands = 115 | pip install python/vesin_torch --no-deps --no-build-isolation --check-build-dependencies 116 | sphinx-build -d docs/build/doctrees -W -b html docs/src docs/build/html 117 | -------------------------------------------------------------------------------- /vesin/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.16) 2 | 3 | file(READ "VERSION" VESIN_VERSION) 4 | string(STRIP ${VESIN_VERSION} VESIN_VERSION) 5 | 6 | project(vesin LANGUAGES C CXX VERSION ${VESIN_VERSION}) 7 | 8 | if (${CMAKE_SOURCE_DIR} STREQUAL ${CMAKE_CURRENT_SOURCE_DIR}) 9 | set(VESIN_MAIN_PROJECT ON) 10 | else() 11 | set(VESIN_MAIN_PROJECT OFF) 12 | endif() 13 | 14 | if (VESIN_MAIN_PROJECT) 15 | if("${CMAKE_BUILD_TYPE}" STREQUAL "" AND "${CMAKE_CONFIGURATION_TYPES}" STREQUAL "") 16 | message(STATUS "Setting build type to 'release' as none was specified.") 17 | set( 18 | CMAKE_BUILD_TYPE "release" 19 | CACHE STRING 20 | "Choose the type of build, options are: none(CMAKE_CXX_FLAGS or CMAKE_C_FLAGS used) debug release relwithdebinfo minsizerel." 
21 | FORCE 22 | ) 23 | set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS release debug relwithdebinfo minsizerel none) 24 | endif() 25 | endif() 26 | 27 | option(BUILD_SHARED_LIBS "Build shared libraries instead of static ones" OFF) 28 | option(VESIN_BUILD_TESTS "Build and run Vesin's unit tests" OFF) 29 | option(VESIN_INSTALL "Install Vesin's headers and libraries" ${VESIN_MAIN_PROJECT}) 30 | option(VESIN_TORCH "Build the vesin_torch library" OFF) 31 | 32 | set(VESIN_SOURCES 33 | ${CMAKE_CURRENT_SOURCE_DIR}/src/vesin.cpp 34 | ${CMAKE_CURRENT_SOURCE_DIR}/src/cpu_cell_list.cpp 35 | ) 36 | add_library(vesin ${VESIN_SOURCES}) 37 | 38 | target_include_directories(vesin PUBLIC 39 | $ 40 | $ 41 | ) 42 | 43 | target_compile_features(vesin PRIVATE cxx_std_17) 44 | 45 | set_target_properties(vesin PROPERTIES 46 | # hide non-exported symbols by default 47 | CXX_VISIBILITY_PRESET hidden 48 | VISIBILITY_INLINES_HIDDEN ON 49 | ) 50 | 51 | target_compile_definitions(vesin PRIVATE VESIN_EXPORTS) 52 | if (BUILD_SHARED_LIBS) 53 | target_compile_definitions(vesin PUBLIC VESIN_SHARED) 54 | endif() 55 | 56 | if (VESIN_BUILD_TESTS) 57 | enable_testing() 58 | add_subdirectory(tests) 59 | endif() 60 | 61 | if (VESIN_TORCH) 62 | add_subdirectory(torch) 63 | endif() 64 | 65 | #------------------------------------------------------------------------------# 66 | # Installation configuration 67 | #------------------------------------------------------------------------------# 68 | if (VESIN_INSTALL) 69 | install(TARGETS vesin 70 | ARCHIVE DESTINATION "lib" 71 | LIBRARY DESTINATION "lib" 72 | RUNTIME DESTINATION "bin" 73 | ) 74 | 75 | install(FILES "include/vesin.h" DESTINATION "include") 76 | endif() 77 | -------------------------------------------------------------------------------- /vesin/VERSION: -------------------------------------------------------------------------------- 1 | 0.3.7 2 | -------------------------------------------------------------------------------- 
/vesin/include/vesin.h: -------------------------------------------------------------------------------- 1 | #ifndef VESIN_H 2 | #define VESIN_H 3 | 4 | #include 5 | #include 6 | 7 | #if defined(VESIN_SHARED) 8 | #if defined(VESIN_EXPORTS) 9 | #if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__) 10 | #define VESIN_API __attribute__((visibility("default"))) 11 | #elif defined(_MSC_VER) 12 | #define VESIN_API __declspec(dllexport) 13 | #else 14 | #define VESIN_API 15 | #endif 16 | #else 17 | #if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__) 18 | #define VESIN_API __attribute__((visibility("default"))) 19 | #elif defined(_MSC_VER) 20 | #define VESIN_API __declspec(dllimport) 21 | #else 22 | #define VESIN_API 23 | #endif 24 | #endif 25 | #else 26 | #define VESIN_API 27 | #endif 28 | 29 | #ifdef __cplusplus 30 | extern "C" { 31 | #endif 32 | 33 | /// Options for a neighbor list calculation 34 | struct VesinOptions { 35 | /// Spherical cutoff, only pairs below this cutoff will be included 36 | double cutoff; 37 | /// Should the returned neighbor list be a full list (include both `i -> j` 38 | /// and `j -> i` pairs) or a half list (include only `i -> j`)? 39 | bool full; 40 | /// Should the neighbor list be sorted? If yes, the returned pairs will be 41 | /// sorted using lexicographic order. 42 | bool sorted; 43 | 44 | /// Should the returned `VesinNeighborList` contain `shifts`? 45 | bool return_shifts; 46 | /// Should the returned `VesinNeighborList` contain `distances`? 47 | bool return_distances; 48 | /// Should the returned `VesinNeighborList` contain `vectors`? 49 | bool return_vectors; 50 | }; 51 | 52 | /// Device on which the data can be allocated 53 | enum VesinDevice { 54 | /// Unknown device, used for default initialization and to indicate no 55 | /// allocated data.
56 | VesinUnknownDevice = 0, 57 | /// CPU device 58 | VesinCPU = 1, 59 | }; 60 | 61 | 62 | /// The actual neighbor list 63 | /// 64 | /// This is organized as a list of pairs, where each pair can contain the 65 | /// following data: 66 | /// 67 | /// - indices of the points in the pair; 68 | /// - distance between points in the pair, accounting for periodic boundary 69 | /// conditions; 70 | /// - vector between points in the pair, accounting for periodic boundary 71 | /// conditions; 72 | /// - periodic shift that created the pair. This is only relevant when using 73 | /// periodic boundary conditions, and contains the number of bounding boxes we 74 | /// need to cross to create the pair. If the positions of the points are `r_i` 75 | /// and `r_j`, the bounding box is described by a matrix of three vectors `H`, 76 | /// and the periodic shift is `S`, the distance vector for a given pair will 77 | /// be given by `r_ij = r_j - r_i + S @ H`. 78 | /// 79 | /// Under periodic boundary conditions, two atoms can be part of multiple pairs, 80 | /// each pair having a different periodic shift. 81 | struct VESIN_API VesinNeighborList { 82 | #ifdef __cplusplus 83 | VesinNeighborList(): 84 | length(0), 85 | device(VesinUnknownDevice), 86 | pairs(nullptr), 87 | shifts(nullptr), 88 | distances(nullptr), 89 | vectors(nullptr) 90 | {} 91 | #endif 92 | 93 | /// Number of pairs in this neighbor list 94 | size_t length; 95 | /// Device used for the data allocations 96 | VesinDevice device; 97 | /// Array of pairs (storing the indices of the first and second point in the 98 | /// pair), containing `length` elements. 99 | size_t (*pairs)[2]; 100 | /// Array of box shifts, one for each `pair`. This is only set if 101 | /// `options.return_shifts` was `true` during the calculation. 102 | int32_t (*shifts)[3]; 103 | /// Array of pair distances (i.e. distance between the two points), one for 104 | /// each pair.
This is only set if `options.return_distances` was `true` 105 | /// during the calculation. 106 | double *distances; 107 | /// Array of pair vectors (i.e. vector between the two points), one for 108 | /// each pair. This is only set if `options.return_vectors` was `true` 109 | /// during the calculation. 110 | double (*vectors)[3]; 111 | 112 | // TODO: custom memory allocators? 113 | }; 114 | 115 | /// Free all allocated memory inside a `VesinNeighborList`, according to its 116 | /// `device`. 117 | void VESIN_API vesin_free(struct VesinNeighborList* neighbors); 118 | 119 | /// Compute a neighbor list. 120 | /// 121 | /// The data is returned in a `VesinNeighborList`. For an initial call, the 122 | /// `VesinNeighborList` should be zero-initialized (or default-initialized in 123 | /// C++). The `VesinNeighborList` can be re-used across calls to this function 124 | /// to re-use memory allocations, and once it is no longer needed, users should 125 | /// call `vesin_free` to release the corresponding memory. 126 | /// 127 | /// @param points positions of all points in the system; 128 | /// @param n_points number of elements in the `points` array 129 | /// @param box bounding box for the system. If the system is non-periodic, 130 | /// this is ignored. This should contain the three vectors of the bounding 131 | /// box, one vector per row of the matrix. 132 | /// @param periodic is the system using periodic boundary conditions? 133 | /// @param device device where the `points` and `box` data is allocated. 134 | /// @param options options for the calculation 135 | /// @param neighbors non-NULL pointer to `VesinNeighborList` that will be used 136 | /// to store the computed list of neighbors. 137 | /// @param error_message Non-NULL pointer to a `char*` that will be set to the error 138 | /// message if this function fails. This does not need to be freed when no 139 | /// longer needed.
140 | int VESIN_API vesin_neighbors( 141 | const double (*points)[3], 142 | size_t n_points, 143 | const double box[3][3], 144 | bool periodic, 145 | VesinDevice device, 146 | struct VesinOptions options, 147 | struct VesinNeighborList* neighbors, 148 | const char** error_message 149 | ); 150 | 151 | 152 | #ifdef __cplusplus 153 | 154 | } // extern "C" 155 | 156 | #endif 157 | 158 | #endif 159 | -------------------------------------------------------------------------------- /vesin/src/cpu_cell_list.hpp: -------------------------------------------------------------------------------- 1 | #ifndef VESIN_CPU_CELL_LIST_HPP 2 | #define VESIN_CPU_CELL_LIST_HPP 3 | 4 | #include 5 | 6 | #include "vesin.h" 7 | 8 | #include "types.hpp" 9 | 10 | namespace vesin { namespace cpu { 11 | 12 | void free_neighbors(VesinNeighborList& neighbors); 13 | 14 | void neighbors( 15 | const Vector* points, 16 | size_t n_points, 17 | BoundingBox cell, 18 | VesinOptions options, 19 | VesinNeighborList& neighbors 20 | ); 21 | 22 | 23 | /// The cell list is used to sort atoms inside bins/cells. 24 | /// 25 | /// The list of potential pairs is then constructed by looking through all 26 | /// neighboring cells (the number of cells to search depends on the cutoff and 27 | /// the size of the cells) for each atom to create pair candidates. 28 | class CellList { 29 | public: 30 | /// Create a new `CellList` for the given bounding box and cutoff, 31 | /// determining all required parameters. 32 | CellList(BoundingBox box, double cutoff); 33 | 34 | /// Add a single point to the cell list at the given `position`. The point 35 | /// is uniquely identified by its `index`. 
36 | void add_point(size_t index, Vector position); 37 | 38 | /// Iterate over all possible pairs, calling the given callback every time 39 | template 40 | void foreach_pair(Function callback); 41 | 42 | private: 43 | /// How many cells do we need to look at when searching neighbors to include 44 | /// all neighbors below cutoff 45 | std::array n_search_; 46 | 47 | /// the cells themselves are a list of points & corresponding 48 | /// shift to place the point inside the cell 49 | struct Point { 50 | size_t index; 51 | CellShift shift; 52 | }; 53 | struct Cell: public std::vector {}; 54 | 55 | // raw data for the cells 56 | std::vector cells_; 57 | // shape of the cell array 58 | std::array cells_shape_; 59 | 60 | BoundingBox box_; 61 | 62 | Cell& get_cell(std::array index); 63 | }; 64 | 65 | /// Wrapper around `VesinNeighborList` that behaves like a std::vector, 66 | /// automatically growing memory allocations. 67 | class GrowableNeighborList { 68 | public: 69 | VesinNeighborList& neighbors; 70 | size_t capacity; 71 | VesinOptions options; 72 | 73 | size_t length() const { 74 | return neighbors.length; 75 | } 76 | 77 | void increment_length() { 78 | neighbors.length += 1; 79 | } 80 | 81 | void set_pair(size_t index, size_t first, size_t second); 82 | void set_shift(size_t index, vesin::CellShift shift); 83 | void set_distance(size_t index, double distance); 84 | void set_vector(size_t index, vesin::Vector vector); 85 | 86 | // reset length to 0, and allocate/deallocate members of 87 | // `neighbors` according to `options` 88 | void reset(); 89 | 90 | // allocate more memory & update capacity 91 | void grow(); 92 | 93 | // sort the pairs currently in the neighbor list 94 | void sort(); 95 | }; 96 | 97 | } // namespace cpu 98 | } // namespace vesin 99 | 100 | #endif 101 | -------------------------------------------------------------------------------- /vesin/src/math.hpp: -------------------------------------------------------------------------------- 1 | #ifndef
VESIN_MATH_HPP 2 | #define VESIN_MATH_HPP 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | namespace vesin { 9 | struct Vector; 10 | 11 | Vector operator*(Vector vector, double scalar); 12 | 13 | struct Vector: public std::array { 14 | double dot(Vector other) const { 15 | return (*this)[0] * other[0] + (*this)[1] * other[1] + (*this)[2] * other[2]; 16 | } 17 | 18 | double norm() const { 19 | return std::sqrt(this->dot(*this)); 20 | } 21 | 22 | Vector normalize() const { 23 | return *this * (1.0 / this->norm()); 24 | } 25 | 26 | Vector cross(Vector other) const { 27 | return Vector{ 28 | (*this)[1] * other[2] - (*this)[2] * other[1], 29 | (*this)[2] * other[0] - (*this)[0] * other[2], 30 | (*this)[0] * other[1] - (*this)[1] * other[0], 31 | }; 32 | } 33 | }; 34 | 35 | inline Vector operator+(Vector u, Vector v) { 36 | return Vector{ 37 | u[0] + v[0], 38 | u[1] + v[1], 39 | u[2] + v[2], 40 | }; 41 | } 42 | 43 | inline Vector operator-(Vector u, Vector v) { 44 | return Vector{ 45 | u[0] - v[0], 46 | u[1] - v[1], 47 | u[2] - v[2], 48 | }; 49 | } 50 | 51 | inline Vector operator*(double scalar, Vector vector) { 52 | return Vector{ 53 | scalar * vector[0], 54 | scalar * vector[1], 55 | scalar * vector[2], 56 | }; 57 | } 58 | 59 | inline Vector operator*(Vector vector, double scalar) { 60 | return Vector{ 61 | scalar * vector[0], 62 | scalar * vector[1], 63 | scalar * vector[2], 64 | }; 65 | } 66 | 67 | 68 | struct Matrix: public std::array, 3> { 69 | double determinant() const { 70 | return (*this)[0][0] * ((*this)[1][1] * (*this)[2][2] - (*this)[2][1] * (*this)[1][2]) 71 | - (*this)[0][1] * ((*this)[1][0] * (*this)[2][2] - (*this)[1][2] * (*this)[2][0]) 72 | + (*this)[0][2] * ((*this)[1][0] * (*this)[2][1] - (*this)[1][1] * (*this)[2][0]); 73 | } 74 | 75 | Matrix inverse() const { 76 | auto det = this->determinant(); 77 | 78 | if (std::abs(det) < 1e-30) { 79 | throw std::runtime_error("this matrix is not invertible"); 80 | } 81 | 82 | auto inverse = Matrix(); 83 | 
inverse[0][0] = ((*this)[1][1] * (*this)[2][2] - (*this)[2][1] * (*this)[1][2]) / det; 84 | inverse[0][1] = ((*this)[0][2] * (*this)[2][1] - (*this)[0][1] * (*this)[2][2]) / det; 85 | inverse[0][2] = ((*this)[0][1] * (*this)[1][2] - (*this)[0][2] * (*this)[1][1]) / det; 86 | inverse[1][0] = ((*this)[1][2] * (*this)[2][0] - (*this)[1][0] * (*this)[2][2]) / det; 87 | inverse[1][1] = ((*this)[0][0] * (*this)[2][2] - (*this)[0][2] * (*this)[2][0]) / det; 88 | inverse[1][2] = ((*this)[1][0] * (*this)[0][2] - (*this)[0][0] * (*this)[1][2]) / det; 89 | inverse[2][0] = ((*this)[1][0] * (*this)[2][1] - (*this)[2][0] * (*this)[1][1]) / det; 90 | inverse[2][1] = ((*this)[2][0] * (*this)[0][1] - (*this)[0][0] * (*this)[2][1]) / det; 91 | inverse[2][2] = ((*this)[0][0] * (*this)[1][1] - (*this)[1][0] * (*this)[0][1]) / det; 92 | return inverse; 93 | } 94 | }; 95 | 96 | 97 | inline Vector operator*(Matrix matrix, Vector vector) { 98 | return Vector{ 99 | matrix[0][0] * vector[0] + matrix[0][1] * vector[1] + matrix[0][2] * vector[2], 100 | matrix[1][0] * vector[0] + matrix[1][1] * vector[1] + matrix[1][2] * vector[2], 101 | matrix[2][0] * vector[0] + matrix[2][1] * vector[1] + matrix[2][2] * vector[2], 102 | }; 103 | } 104 | 105 | inline Vector operator*(Vector vector, Matrix matrix) { 106 | return Vector{ 107 | vector[0] * matrix[0][0] + vector[1] * matrix[1][0] + vector[2] * matrix[2][0], 108 | vector[0] * matrix[0][1] + vector[1] * matrix[1][1] + vector[2] * matrix[2][1], 109 | vector[0] * matrix[0][2] + vector[1] * matrix[1][2] + vector[2] * matrix[2][2], 110 | }; 111 | } 112 | 113 | } // namespace vesin 114 | 115 | #endif 116 | -------------------------------------------------------------------------------- /vesin/src/types.hpp: -------------------------------------------------------------------------------- 1 | #ifndef VESIN_TYPES_HPP 2 | #define VESIN_TYPES_HPP 3 | 4 | #include "math.hpp" 5 | 6 | namespace vesin { 7 | 8 | class BoundingBox { 9 | public: 10 | 
BoundingBox(Matrix matrix, bool periodic): matrix_(matrix), periodic_(periodic) { 11 | if (periodic) { 12 | auto det = matrix_.determinant(); 13 | if (std::abs(det) < 1e-30) { 14 | throw std::runtime_error("the box matrix is not invertible"); 15 | } 16 | 17 | this->inverse_ = matrix_.inverse(); 18 | } else { 19 | this->matrix_ = Matrix{{{ 20 | {{1, 0, 0}}, 21 | {{0, 1, 0}}, 22 | {{0, 0, 1}} 23 | }}}; 24 | this->inverse_ = matrix_; 25 | } 26 | } 27 | 28 | const Matrix& matrix() const { 29 | return this->matrix_; 30 | } 31 | 32 | bool periodic() const { 33 | return this->periodic_; 34 | } 35 | 36 | /// Convert a vector from cartesian coordinates to fractional coordinates 37 | Vector cartesian_to_fractional(Vector cartesian) const { 38 | return cartesian * inverse_; 39 | } 40 | 41 | /// Convert a vector from fractional coordinates to cartesian coordinates 42 | Vector fractional_to_cartesian(Vector fractional) const { 43 | return fractional * matrix_; 44 | } 45 | 46 | /// Get the three distances between faces of the bounding box 47 | Vector distances_between_faces() const { 48 | auto a = Vector{matrix_[0]}; 49 | auto b = Vector{matrix_[1]}; 50 | auto c = Vector{matrix_[2]}; 51 | 52 | // Unit normal vectors of the planes (faces) of the box 53 | auto na = b.cross(c).normalize(); 54 | auto nb = c.cross(a).normalize(); 55 | auto nc = a.cross(b).normalize(); 56 | 57 | return Vector{ 58 | std::abs(na.dot(a)), 59 | std::abs(nb.dot(b)), 60 | std::abs(nc.dot(c)), 61 | }; 62 | } 63 | 64 | private: 65 | Matrix matrix_; 66 | Matrix inverse_; 67 | bool periodic_; 68 | }; 69 | 70 | 71 | /// A cell shift represents the displacement along the cell axes between the actual 72 | /// position of an atom and a periodic image of this atom. 73 | /// 74 | /// The cell shift can be used to reconstruct the vector between two points, 75 | /// wrapped inside the unit cell.
76 | struct CellShift: public std::array { 77 | /// Compute the shift vector in cartesian coordinates, using the given cell 78 | /// matrix (stored in row major order). 79 | Vector cartesian(Matrix cell) const { 80 | auto vector = Vector{ 81 | static_cast((*this)[0]), 82 | static_cast((*this)[1]), 83 | static_cast((*this)[2]), 84 | }; 85 | return vector * cell; 86 | } 87 | }; 88 | 89 | inline CellShift operator+(CellShift a, CellShift b) { 90 | return CellShift{ 91 | a[0] + b[0], 92 | a[1] + b[1], 93 | a[2] + b[2], 94 | }; 95 | } 96 | 97 | inline CellShift operator-(CellShift a, CellShift b) { 98 | return CellShift{ 99 | a[0] - b[0], 100 | a[1] - b[1], 101 | a[2] - b[2], 102 | }; 103 | } 104 | 105 | 106 | } // namespace vesin 107 | 108 | #endif 109 | -------------------------------------------------------------------------------- /vesin/src/vesin.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "vesin.h" 5 | 6 | #include "cpu_cell_list.hpp" 7 | 8 | thread_local std::string LAST_ERROR; 9 | 10 | extern "C" int vesin_neighbors( 11 | const double (*points)[3], 12 | size_t n_points, 13 | const double box[3][3], 14 | bool periodic, 15 | VesinDevice device, 16 | VesinOptions options, 17 | VesinNeighborList* neighbors, 18 | const char** error_message 19 | ) { 20 | if (error_message == nullptr) { 21 | return EXIT_FAILURE; 22 | } 23 | 24 | if (points == nullptr) { 25 | *error_message = "`points` can not be a NULL pointer"; 26 | return EXIT_FAILURE; 27 | } 28 | 29 | if (box == nullptr) { 30 | *error_message = "`cell` can not be a NULL pointer"; 31 | return EXIT_FAILURE; 32 | } 33 | 34 | if (neighbors == nullptr) { 35 | *error_message = "`neighbors` can not be a NULL pointer"; 36 | return EXIT_FAILURE; 37 | } 38 | 39 | if (!std::isfinite(options.cutoff) || options.cutoff <= 0) { 40 | *error_message = "cutoff must be a finite, positive number"; 41 | return EXIT_FAILURE; 42 | } 43 | 44 | if 
(options.cutoff <= 1e-6) { 45 | *error_message = "cutoff is too small"; 46 | return EXIT_FAILURE; 47 | } 48 | 49 | if (neighbors->device != VesinUnknownDevice && neighbors->device != device) { 50 | *error_message = "`neighbors` device and data `device` do not match, free the neighbors first"; 51 | return EXIT_FAILURE; 52 | } 53 | 54 | if (device == VesinUnknownDevice) { 55 | *error_message = "got an unknown device to use when running simulation"; 56 | return EXIT_FAILURE; 57 | } 58 | 59 | if (neighbors->device == VesinUnknownDevice) { 60 | // initialize the device 61 | neighbors->device = device; 62 | } else if (neighbors->device != device) { 63 | *error_message = "`neighbors.device` and `device` do not match, free the neighbors first"; 64 | return EXIT_FAILURE; 65 | } 66 | 67 | try { 68 | if (device == VesinCPU) { 69 | auto matrix = vesin::Matrix{{{ 70 | {{box[0][0], box[0][1], box[0][2]}}, 71 | {{box[1][0], box[1][1], box[1][2]}}, 72 | {{box[2][0], box[2][1], box[2][2]}}, 73 | }}}; 74 | 75 | vesin::cpu::neighbors( 76 | reinterpret_cast(points), 77 | n_points, 78 | vesin::BoundingBox(matrix, periodic), 79 | options, 80 | *neighbors 81 | ); 82 | } else { 83 | throw std::runtime_error("unknown device " + std::to_string(device)); 84 | } 85 | } catch (const std::bad_alloc&) { 86 | LAST_ERROR = "failed to allocate memory"; 87 | *error_message = LAST_ERROR.c_str(); 88 | return EXIT_FAILURE; 89 | } catch (const std::exception& e) { 90 | LAST_ERROR = e.what(); 91 | *error_message = LAST_ERROR.c_str(); 92 | return EXIT_FAILURE; 93 | } catch (...) 
{ 94 | *error_message = "fatal error: unknown type thrown as exception"; 95 | return EXIT_FAILURE; 96 | } 97 | 98 | return EXIT_SUCCESS; 99 | } 100 | 101 | 102 | extern "C" void vesin_free(VesinNeighborList* neighbors) { 103 | if (neighbors == nullptr) { 104 | return; 105 | } 106 | 107 | if (neighbors->device == VesinUnknownDevice) { 108 | // nothing to do 109 | } else if (neighbors->device == VesinCPU) { 110 | vesin::cpu::free_neighbors(*neighbors); 111 | } 112 | 113 | std::memset(neighbors, 0, sizeof(VesinNeighborList)); 114 | } 115 | -------------------------------------------------------------------------------- /vesin/tests/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | include(FetchContent) 2 | 3 | # Override options with variables 4 | set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) 5 | 6 | FetchContent_Declare(Catch2 7 | GIT_REPOSITORY https://github.com/catchorg/Catch2.git 8 | GIT_TAG v3.5.3 9 | ) 10 | 11 | set(CATCH_CONFIG_FAST_COMPILE ON) 12 | FetchContent_MakeAvailable(Catch2) 13 | 14 | find_program(VALGRIND valgrind) 15 | if (VALGRIND) 16 | message(STATUS "Running tests using valgrind") 17 | set(TEST_COMMAND 18 | "${VALGRIND}" "--tool=memcheck" "--dsymutil=yes" "--error-exitcode=125" 19 | "--leak-check=full" "--show-leak-kinds=definite,indirect,possible" "--track-origins=yes" 20 | "--gen-suppressions=all" 21 | ) 22 | else() 23 | set(TEST_COMMAND "") 24 | endif() 25 | 26 | 27 | file(GLOB ALL_TESTS *.cpp) 28 | foreach(_file_ ${ALL_TESTS}) 29 | get_filename_component(_name_ ${_file_} NAME_WE) 30 | add_executable(${_name_} ${_file_}) 31 | target_link_libraries(${_name_} vesin Catch2WithMain) 32 | 33 | add_test( 34 | NAME ${_name_} 35 | COMMAND ${TEST_COMMAND} $ 36 | ) 37 | endforeach() 38 | -------------------------------------------------------------------------------- /vesin/tests/memory.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | 6 
| /* Checks that a single VesinNeighborList can be passed to vesin_neighbors several times with different options, and that only the arrays requested through the return_* options are non-null after each call. */ TEST_CASE("Re-use allocations") { 7 | double points[][3] = { 8 | {0.0, 0.0, 0.0}, 9 | }; 10 | size_t n_points = 1; 11 | 12 | double box[3][3] = { 13 | {0.0, 1.5, 1.5}, 14 | {1.5, 0.0, 1.5}, 15 | {1.5, 1.5, 0.0}, 16 | }; 17 | bool periodic = true; 18 | 19 | VesinNeighborList neighbors; 20 | 21 | auto compute_neighbors = [&](VesinOptions options) { 22 | const char* error_message = nullptr; 23 | auto status = vesin_neighbors( 24 | points, 25 | n_points, 26 | box, 27 | periodic, 28 | VesinDevice::VesinCPU, 29 | options, 30 | &neighbors, 31 | &error_message 32 | ); 33 | REQUIRE(status == EXIT_SUCCESS); 34 | REQUIRE(error_message == nullptr); 35 | }; 36 | 37 | auto options = VesinOptions(); 38 | options.cutoff = 3.4; 39 | options.full = false; 40 | options.return_shifts = false; 41 | options.return_distances = true; 42 | options.return_vectors = false; 43 | 44 | compute_neighbors(options); 45 | 46 | CHECK(neighbors.length == 9); 47 | CHECK(neighbors.pairs != nullptr); 48 | CHECK(neighbors.shifts == nullptr); 49 | CHECK(neighbors.distances != nullptr); 50 | CHECK(neighbors.vectors == nullptr); 51 | 52 | /***************************************************/ 53 | options.cutoff = 6; 54 | compute_neighbors(options); 55 | 56 | CHECK(neighbors.length == 67); 57 | CHECK(neighbors.pairs != nullptr); 58 | CHECK(neighbors.shifts == nullptr); 59 | CHECK(neighbors.distances != nullptr); 60 | CHECK(neighbors.vectors == nullptr); 61 | 62 | /***************************************************/ 63 | options.full = true; 64 | compute_neighbors(options); 65 | 66 | CHECK(neighbors.length == 67 * 2); 67 | CHECK(neighbors.pairs != nullptr); 68 | CHECK(neighbors.shifts == nullptr); 69 | CHECK(neighbors.distances != nullptr); 70 | CHECK(neighbors.vectors == nullptr); 71 | 72 | 
/***************************************************/ 73 | options.cutoff = 4.5; 74 | options.full = false; 75 | options.return_shifts = true; 76 | options.return_distances = false; 77 | compute_neighbors(options); 78
| 79 | CHECK(neighbors.length == 27); 80 | CHECK(neighbors.pairs != nullptr); 81 | CHECK(neighbors.shifts != nullptr); 82 | CHECK(neighbors.distances == nullptr); 83 | CHECK(neighbors.vectors == nullptr); 84 | 85 | 86 | vesin_free(&neighbors); 87 | } 88 | -------------------------------------------------------------------------------- /vesin/tests/neighbors.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | 7 | using namespace Catch::Matchers; 8 | 9 | #include 10 | 11 | #define CHECK_APPROX_EQUAL(a, b) CHECK_THAT(a, WithinULP(b, 4)); 12 | 13 | /* Runs vesin_neighbors on CPU and checks its output against the expected values; each return_* option is only enabled when the matching expected_* vector is non-empty. */ static void check_neighbors( 14 | const double (*points)[3], 15 | size_t n_points, 16 | const double box[3][3], 17 | bool periodic, 18 | double cutoff, 19 | bool full_list, 20 | std::vector> expected_pairs, 21 | std::vector> expected_shifts, 22 | std::vector expected_distances, 23 | std::vector> expected_vectors 24 | ) { 25 | auto options = VesinOptions(); 26 | options.cutoff = cutoff; 27 | options.full = full_list; 28 | options.return_shifts = !expected_shifts.empty(); 29 | options.return_distances = !expected_distances.empty(); 30 | options.return_vectors = !expected_vectors.empty(); 31 | 32 | VesinNeighborList neighbors; 33 | 34 | const char* error_message = nullptr; 35 | auto status = vesin_neighbors( 36 | points, 37 | n_points, 38 | box, 39 | periodic, 40 | VesinDevice::VesinCPU, 41 | options, 42 | &neighbors, 43 | &error_message 44 | ); 45 | REQUIRE(status == EXIT_SUCCESS); 46 | REQUIRE(error_message == nullptr); 47 | 48 | if (!expected_pairs.empty()) { 49 | REQUIRE(neighbors.length == expected_pairs.size()); 50 | for (size_t i=0; i>{ 97 | {0, 1}, 98 | {0, 2}, 99 | {0, 3}, 100 | {0, 4}, 101 | {1, 3}, 102 | {1, 4}, 103 | {2, 3}, 104 | {2, 4}, 105 | {3, 4}, 106 | }; 107 | 108 | auto 
expected_distances = std::vector{ 109 | 3.2082345612501593, 110 | 2.283282943482914, 111 | 2.4783286706972505, 112 | 1.215100818862369, 113 |
2.9707625283755013, 114 | 2.3059143522689647, 115 | 1.550639867925496, 116 | 2.9495550511899244, 117 | 2.6482573515427084, 118 | }; 119 | 120 | check_neighbors( 121 | points, /*n_points=*/ 5, box, /*periodic=*/ false, /*cutoff=*/ 3.42, /*full_list=*/ false, 122 | expected_pairs, {}, expected_distances, {} 123 | ); 124 | } 125 | 126 | TEST_CASE("FCC unit cell") { 127 | double points[][3] = { 128 | {0.0, 0.0, 0.0}, 129 | }; 130 | 131 | double box[3][3] = { 132 | {0.0, 1.5, 1.5}, 133 | {1.5, 0.0, 1.5}, 134 | {1.5, 1.5, 0.0}, 135 | }; 136 | 137 | auto expected_vectors = std::vector>{ 138 | {1.5, 0.0, -1.5}, 139 | {1.5, -1.5, 0.0}, 140 | {0.0, 1.5, -1.5}, 141 | {1.5, 1.5, 0.0}, 142 | {1.5, 0.0, 1.5}, 143 | {0.0, 1.5, 1.5}, 144 | }; 145 | 146 | auto expected_shifts = std::vector>{ 147 | {-1, 0, 1}, 148 | {-1, 1, 0}, 149 | {0, -1, 1}, 150 | {0, 0, 1}, 151 | {0, 1, 0}, 152 | {1, 0, 0}, 153 | }; 154 | 155 | auto expected_pairs = std::vector>(6, {0, 0}); 156 | auto expected_distances = std::vector(6, 2.1213203435596424); 157 | 158 | check_neighbors( 159 | points, /*n_points=*/ 1, box, /*periodic=*/ true, /*cutoff=*/ 3.0, /*full_list=*/ false, 160 | expected_pairs, expected_shifts, expected_distances, expected_vectors 161 | ); 162 | } 163 | 164 | TEST_CASE("Large box, small cutoff") { 165 | double points[][3] = { 166 | {0.0, 0.0, 0.0}, 167 | {0.0, 2.0, 0.0}, 168 | {0.0, 0.0, 2.0}, 169 | // points outside the box natural boundaries 170 | {-6.0, 0.0, 0.0}, 171 | {-6.0, -2.0, 0.0}, 172 | {-6.0, 0.0, -2.0}, 173 | }; 174 | 175 | double box[3][3] = { 176 | {54.0, 0.0, 0.0}, 177 | {0.0, 54.0, 0.0}, 178 | {0.0, 0.0, 54.0}, 179 | }; 180 | 181 | auto expected_pairs = std::vector>{ 182 | {0, 1}, 183 | {0, 2}, 184 | {3, 4}, 185 | {3, 5}, 186 | }; 187 | auto expected_shifts = std::vector>(4, {0, 0, 0}); 188 | auto expected_distances = std::vector(4, 2.0); 189 | 190 | check_neighbors( 191 | points, /*n_points=*/ 6, box, /*periodic=*/ true, /*cutoff=*/ 2.1, /*full_list=*/ false, 192 |
expected_pairs, expected_shifts, expected_distances, {} 193 | ); 194 | } 195 | 196 | TEST_CASE("Cutoff larger than the box size") { 197 | double points[][3] = { 198 | {0.0, 0.0, 0.0}, 199 | }; 200 | 201 | double box[3][3] = { 202 | {0.5, 0.0, 0.0}, 203 | {0.0, 0.5, 0.0}, 204 | {0.0, 0.0, 0.5}, 205 | }; 206 | 207 | auto expected_pairs = std::vector>(3, {0, 0}); 208 | auto expected_distances = std::vector(3, 0.5); 209 | auto expected_shifts = std::vector>{ 210 | {0, 0, 1}, 211 | {0, 1, 0}, 212 | {1, 0, 0}, 213 | }; 214 | auto expected_vectors = std::vector>{ 215 | {0.0, 0.0, 0.5}, 216 | {0.0, 0.5, 0.0}, 217 | {0.5, 0.0, 0.0}, 218 | }; 219 | 220 | check_neighbors( 221 | points, /*n_points=*/ 1, box, /*periodic=*/ true, /*cutoff=*/ 0.6, /*full_list=*/ false, 222 | expected_pairs, expected_shifts, expected_distances, expected_vectors 223 | ); 224 | } 225 | 226 | TEST_CASE("Slanted box") { 227 | double points[][3] = { 228 | {1.42, 0.0, 0.0}, 229 | {2.84, 0.0, 0.0}, 230 | {3.55, -1.22975607, 0.0}, 231 | {4.97, -1.22975607, 0.0}, 232 | }; 233 | 234 | double box[3][3] = { 235 | {4.26, -2.45951215, 0.0}, 236 | {2.13, 1.22975607, 0.0}, 237 | {0.0, 0.0, 50.0}, 238 | }; 239 | 240 | auto options = VesinOptions(); 241 | options.cutoff = 6.4; 242 | options.full = false; 243 | options.return_shifts = true; 244 | options.return_distances = false; 245 | options.return_vectors = false; 246 | 247 | VesinNeighborList neighbors; 248 | 249 | const char* error_message = nullptr; 250 | auto status = vesin_neighbors( 251 | points, 252 | /*n_points=*/ 4, 253 | box, 254 | /*periodic=*/ true, 255 | VesinDevice::VesinCPU, 256 | options, 257 | &neighbors, 258 | &error_message 259 | ); 260 | REQUIRE(status == EXIT_SUCCESS); 261 | 262 | REQUIRE(neighbors.length == 90); 263 | auto previously_missing = std::vector>{ 264 | {-2, 0, 0}, 265 | {-2, 1, 0}, 266 | {-2, 2, 0}, 267 | }; 268 | 269 | for (const auto& missing: previously_missing) { 270 | bool found = false; 271 | for (size_t i=0; i 13 | $ 14 | $
15 | ) 16 | 17 | target_compile_features(vesin_torch PUBLIC cxx_std_17) 18 | 19 | set_target_properties(vesin_torch PROPERTIES 20 | # hide non-exported symbols by default 21 | CXX_VISIBILITY_PRESET hidden 22 | VISIBILITY_INLINES_HIDDEN ON 23 | ) 24 | 25 | #------------------------------------------------------------------------------# 26 | # Installation configuration 27 | #------------------------------------------------------------------------------# 28 | if (VESIN_INSTALL) 29 | install(TARGETS vesin_torch 30 | ARCHIVE DESTINATION "lib" 31 | LIBRARY DESTINATION "lib" 32 | RUNTIME DESTINATION "bin" 33 | ) 34 | 35 | install(FILES "include/vesin_torch.hpp" DESTINATION "include") 36 | endif() 37 | -------------------------------------------------------------------------------- /vesin/torch/include/vesin_torch.hpp: -------------------------------------------------------------------------------- 1 | #ifndef VESIN_TORCH_HPP 2 | #define VESIN_TORCH_HPP 3 | 4 | #include 5 | 6 | struct VesinNeighborList; 7 | 8 | namespace vesin_torch { 9 | 10 | class NeighborListHolder; 11 | 12 | /// `NeighborListHolder` should be manipulated through a `torch::intrusive_ptr` 13 | using NeighborList = torch::intrusive_ptr; 14 | 15 | /// Neighbor list calculator compatible with TorchScript 16 | class NeighborListHolder: public torch::CustomClassHolder { 17 | public: 18 | /// Create a new calculator with the given `cutoff`. 19 | /// 20 | /// @param full_list whether pairs should be included twice in the output 21 | /// (both as `i-j` and `j-i`) or only once 22 | /// @param sorted whether pairs should be sorted in the output 23 | NeighborListHolder(double cutoff, bool full_list, bool sorted = false); 24 | /* NOTE(review): the destructor frees `data_` (an owning raw pointer) but the copy operations are not deleted, so copying a holder would free it twice. */ ~NeighborListHolder(); 25 | 26 | /// Compute the neighbor list for the system defined by `points`, `box`, 27 | /// and `periodic`; returning the requested `quantities`.
28 | /// 29 | /// `quantities` can contain any combination of the following values: 30 | /// 31 | /// - `"i"` to get the index of the first point in the pair 32 | /// - `"j"` to get the index of the second point in the pair 33 | /// - `"P"` to get the indexes of the two points in the pair simultaneously 34 | /// - `"S"` to get the periodic shift of the pair 35 | /// - `"d"` to get the distance between points in the pair 36 | /// - `"D"` to get the distance vector between points in the pair 37 | /// 38 | /// @param points positions of all points in the system, as a `n x 3` float64 tensor 39 | /// @param box bounding box of the system, as a `3 x 3` tensor with the same device and dtype as `points` 40 | /// @param periodic should we use periodic boundary conditions? If not, the values inside `box` are ignored 41 | /// @param quantities quantities to return, e.g. "ij" (there is no default value in this C++ signature) 42 | /// @param copy should we copy the returned quantities, defaults to `true`. 43 | /// Setting this to `false` might be a bit faster, but the returned 44 | /// tensors are views into memory owned by this class, and will be invalidated 45 | /// whenever this class is garbage collected or used to run a new 46 | /// calculation. 47 | /// 48 | /// @returns a list of `torch::Tensor` as indicated by `quantities`.
49 | std::vector compute( 50 | torch::Tensor points, 51 | torch::Tensor box, 52 | bool periodic, 53 | std::string quantities, 54 | bool copy=true 55 | ); 56 | private: 57 | double cutoff_; 58 | bool full_list_; 59 | bool sorted_; 60 | VesinNeighborList* data_; 61 | }; 62 | 63 | 64 | } 65 | 66 | #endif 67 | -------------------------------------------------------------------------------- /vesin/torch/src/vesin_torch.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | #include 6 | 7 | #include "vesin_torch.hpp" 8 | 9 | using namespace vesin_torch; 10 | 11 | /* Map a torch device to the matching vesin device; only CPU is supported, other devices throw std::runtime_error. */ static VesinDevice torch_to_vesin_device(torch::Device device) { 12 | if (device.is_cpu()) { 13 | return VesinCPU; 14 | } else { 15 | throw std::runtime_error("device " + device.str() + " is not supported in vesin"); 16 | } 17 | } 18 | 19 | /* Inverse of torch_to_vesin_device: VesinCPU maps to the torch CPU device, anything else throws std::runtime_error. */ static torch::Device vesin_to_torch_device(VesinDevice device) { 20 | if (device == VesinCPU) { 21 | return torch::Device("cpu"); 22 | } else { 23 | throw std::runtime_error("vesin device is not supported in torch"); 24 | } 25 | } 26 | 27 | /// Custom autograd function that only registers a custom backward corresponding 28 | /// to the neighbors list calculation 29 | class AutogradNeighbors: public torch::autograd::Function { 30 | public: 31 | static std::vector forward( 32 | torch::autograd::AutogradContext* ctx, 33 | torch::Tensor points, 34 | torch::Tensor box, 35 | bool periodic, 36 | torch::Tensor pairs, 37 | torch::optional shifts, 38 | torch::optional distances, 39 | torch::optional vectors 40 | ); 41 | 42 | static std::vector backward( 43 | torch::autograd::AutogradContext* ctx, 44 | std::vector outputs_grad 45 | ); 46 | }; 47 | 48 | /* Allocates the VesinNeighborList that compute() passes to vesin on every call; it is released in the destructor. 
*/ NeighborListHolder::NeighborListHolder(double cutoff, bool full_list, bool sorted): 49 | cutoff_(cutoff), 50 | full_list_(full_list), 51 | sorted_(sorted), 52 | data_(nullptr) 53 | { 54 | data_ = new VesinNeighborList(); 55 | } 56 | 57 |
NeighborListHolder::~NeighborListHolder() { 58 | vesin_free(data_); 59 | delete data_; 60 | } 61 | 62 | std::vector NeighborListHolder::compute( 63 | torch::Tensor points, 64 | torch::Tensor box, 65 | bool periodic, 66 | std::string quantities, 67 | bool copy 68 | ) { 69 | // check input data 70 | if (points.device() != box.device()) { 71 | C10_THROW_ERROR(ValueError, 72 | "expected `points` and `box` to have the same device, got " + 73 | points.device().str() + " and " + box.device().str() 74 | ); 75 | } 76 | auto device = torch_to_vesin_device(points.device()); 77 | 78 | if (points.scalar_type() != box.scalar_type()) { 79 | C10_THROW_ERROR(ValueError, 80 | std::string("expected `points` and `box` to have the same dtype, got ") + 81 | torch::toString(points.scalar_type()) + " and " + 82 | torch::toString(box.scalar_type()) 83 | ); 84 | } 85 | if (points.scalar_type() != torch::kFloat64) { 86 | C10_THROW_ERROR(ValueError, 87 | "only float64 dtype is supported in vesin" 88 | ); 89 | } 90 | 91 | if (points.sizes().size() != 2 || points.size(1) != 3) { 92 | std::ostringstream oss; 93 | oss << "`points` must be n x 3 tensor, but the shape is " << points.sizes(); 94 | C10_THROW_ERROR(ValueError, oss.str()); 95 | } 96 | 97 | if (box.sizes().size() != 2 || box.size(0) != 3 || box.size(1) != 3) { 98 | std::ostringstream oss; 99 | oss << "`box` must be 3 x 3 tensor, but the shape is " << points.sizes(); 100 | C10_THROW_ERROR(ValueError, oss.str()); 101 | } 102 | 103 | if (!periodic) { 104 | box = torch::zeros({3, 3}, points.options()); 105 | } 106 | 107 | // create calculation options 108 | auto n_points = static_cast(points.size(0)); 109 | 110 | if (data_->device != VesinUnknownDevice && data_->device != device) { 111 | vesin_free(data_); 112 | std::memset(data_, 0, sizeof(VesinNeighborList)); 113 | } 114 | 115 | auto return_shifts = quantities.find('S') != std::string::npos; 116 | if (box.requires_grad()) { 117 | return_shifts = true; 118 | } 119 | 120 | auto 
return_distances = quantities.find('d') != std::string::npos; 121 | auto return_vectors = quantities.find('D') != std::string::npos; 122 | if ((points.requires_grad() || box.requires_grad()) && (return_distances || return_vectors)) { 123 | // gradients requires both distances & vectors data to be present 124 | return_distances = true; 125 | return_vectors = true; 126 | } 127 | 128 | auto options = VesinOptions{ 129 | /*cutoff=*/ this->cutoff_, 130 | /*full=*/ this->full_list_, 131 | /*sorted=*/ this->sorted_, 132 | /*return_shifts=*/ return_shifts, 133 | /*return_distances=*/ return_distances, 134 | /*return_vectors=*/ return_vectors, 135 | }; 136 | 137 | if (!points.is_contiguous()) { 138 | points = points.contiguous(); 139 | } 140 | 141 | if (!box.is_contiguous()) { 142 | box = box.contiguous(); 143 | } 144 | 145 | const char* error_message = nullptr; 146 | auto status = vesin_neighbors( 147 | reinterpret_cast(points.data_ptr()), 148 | n_points, 149 | reinterpret_cast(box.data_ptr()), 150 | periodic, 151 | device, 152 | options, 153 | data_, 154 | &error_message 155 | ); 156 | 157 | if (status != EXIT_SUCCESS) { 158 | throw std::runtime_error(std::string("failed to compute neighbors: ") + error_message); 159 | } 160 | 161 | // wrap vesin data in tensors 162 | auto size_t_options = torch::TensorOptions().device(vesin_to_torch_device(data_->device)); 163 | if (sizeof(size_t) == sizeof(uint32_t)) { 164 | size_t_options = size_t_options.dtype(torch::kUInt32); 165 | } else if (sizeof(size_t) == sizeof(uint64_t)) { 166 | size_t_options = size_t_options.dtype(torch::kUInt64); 167 | } else { 168 | C10_THROW_ERROR(ValueError, 169 | "could not determine torch dtype matching size_t" 170 | ); 171 | } 172 | 173 | auto pairs = torch::from_blob( 174 | data_->pairs, 175 | {static_cast(data_->length), 2}, 176 | size_t_options 177 | ).to(torch::kInt64); 178 | 179 | auto shifts = torch::Tensor(); 180 | if (data_->shifts != nullptr) { 181 | auto int32_options = 
torch::TensorOptions() 182 | .device(vesin_to_torch_device(data_->device)) 183 | .dtype(torch::kInt32); 184 | 185 | shifts = torch::from_blob( 186 | data_->shifts, 187 | {static_cast(data_->length), 3}, 188 | int32_options 189 | ); 190 | 191 | if (copy) { 192 | shifts = shifts.clone(); 193 | } 194 | } 195 | 196 | auto double_options = torch::TensorOptions() 197 | .device(vesin_to_torch_device(data_->device)) 198 | .dtype(torch::kDouble); 199 | 200 | auto distances = torch::Tensor(); 201 | if (data_->distances != nullptr) { 202 | distances = torch::from_blob( 203 | data_->distances, 204 | {static_cast(data_->length)}, 205 | double_options 206 | ); 207 | 208 | if (copy) { 209 | distances = distances.clone(); 210 | } 211 | } 212 | 213 | auto vectors = torch::Tensor(); 214 | if (data_->vectors != nullptr) { 215 | vectors = torch::from_blob( 216 | data_->vectors, 217 | {static_cast(data_->length), 3}, 218 | double_options 219 | ); 220 | 221 | if (copy) { 222 | vectors = vectors.clone(); 223 | } 224 | } 225 | 226 | // handle autograd 227 | if ((return_distances || return_vectors)) { 228 | // we use optional for these three because otherwise torch autograd 229 | // tries to access data inside the undefined `torch::Tensor()`. 
230 | torch::optional shifts_optional = torch::nullopt; 231 | if (shifts.defined()) { 232 | shifts_optional = shifts; 233 | } 234 | 235 | torch::optional distances_optional = torch::nullopt; 236 | if (distances.defined()) { 237 | distances_optional = distances; 238 | } 239 | 240 | torch::optional vectors_optional = torch::nullopt; 241 | if (vectors.defined()) { 242 | vectors_optional = vectors; 243 | } 244 | 245 | auto outputs = AutogradNeighbors::apply( 246 | points, 247 | box, 248 | periodic, 249 | pairs, 250 | shifts_optional, 251 | distances_optional, 252 | vectors_optional 253 | ); 254 | 255 | if (return_distances && return_vectors) { 256 | distances = outputs[0]; 257 | vectors = outputs[1]; 258 | } else if (return_distances) { 259 | distances = outputs[0]; 260 | } else { 261 | assert(return_vectors); 262 | vectors = outputs[0]; 263 | } 264 | } 265 | 266 | // assemble the output 267 | auto output = std::vector(); 268 | for (auto c: quantities) { 269 | if (c == 'i') { 270 | output.push_back(pairs.index({torch::indexing::Slice(), 0})); 271 | } else if (c == 'j') { 272 | output.push_back(pairs.index({torch::indexing::Slice(), 1})); 273 | } else if (c == 'P') { 274 | output.push_back(pairs); 275 | } else if (c == 'S') { 276 | output.push_back(shifts); 277 | } else if (c == 'd') { 278 | output.push_back(distances); 279 | } else if (c == 'D') { 280 | output.push_back(vectors); 281 | } else { 282 | C10_THROW_ERROR(ValueError, 283 | "unexpected character in `quantities`: " + std::string(1, c) 284 | ); 285 | } 286 | } 287 | 288 | return output; 289 | } 290 | 291 | 292 | TORCH_LIBRARY(vesin, m) { 293 | std::string DOCSTRING; 294 | 295 | m.class_("_NeighborList") 296 | .def( 297 | torch::init(), DOCSTRING, 298 | {torch::arg("cutoff"), torch::arg("full_list"), torch::arg("sorted") = false} 299 | ) 300 | .def("compute", &NeighborListHolder::compute, DOCSTRING, 301 | {torch::arg("points"), torch::arg("box"), torch::arg("periodic"), torch::arg("quantities"), 
torch::arg("copy") = true} 302 | ) 303 | ; 304 | } 305 | 306 | // ========================================================================== // 307 | // // 308 | // ========================================================================== // 309 | 310 | std::vector AutogradNeighbors::forward( 311 | torch::autograd::AutogradContext* ctx, 312 | torch::Tensor points, 313 | torch::Tensor box, 314 | bool periodic, 315 | torch::Tensor pairs, 316 | torch::optional shifts, 317 | torch::optional distances, 318 | torch::optional vectors 319 | ) { 320 | auto shifts_tensor = shifts.value_or(torch::Tensor()); 321 | auto distances_tensor = distances.value_or(torch::Tensor()); 322 | auto vectors_tensor = vectors.value_or(torch::Tensor()); 323 | 324 | ctx->save_for_backward({points, box, pairs, shifts_tensor, distances_tensor, vectors_tensor}); 325 | ctx->saved_data["periodic"] = periodic; 326 | 327 | auto return_distances = distances.has_value(); 328 | auto return_vectors = vectors.has_value(); 329 | ctx->saved_data["return_distances"] = return_distances; 330 | ctx->saved_data["return_vectors"] = return_vectors; 331 | 332 | // only return defined tensors to make sure torch can use `get_autograd_meta()` 333 | if (return_distances && return_vectors) { 334 | return {distances_tensor, vectors_tensor}; 335 | } else if (return_distances) { 336 | return {distances_tensor}; 337 | } else if (return_vectors) { 338 | return {vectors_tensor}; 339 | } else { 340 | return {}; 341 | } 342 | } 343 | 344 | std::vector AutogradNeighbors::backward( 345 | torch::autograd::AutogradContext* ctx, 346 | std::vector outputs_grad 347 | ) { 348 | auto saved_variables = ctx->get_saved_variables(); 349 | auto points = saved_variables[0]; 350 | auto box = saved_variables[1]; 351 | auto periodic = ctx->saved_data["periodic"].toBool(); 352 | 353 | auto pairs = saved_variables[2]; 354 | auto shifts = saved_variables[3]; 355 | auto distances = saved_variables[4]; 356 | auto vectors = saved_variables[5]; 357 | 
358 | auto return_distances = ctx->saved_data["return_distances"].toBool(); 359 | auto return_vectors = ctx->saved_data["return_vectors"].toBool(); 360 | 361 | auto distances_grad = torch::Tensor(); 362 | auto vectors_grad = torch::Tensor(); 363 | if (return_distances && return_vectors) { 364 | distances_grad = outputs_grad[0]; 365 | vectors_grad = outputs_grad[1]; 366 | } else if (return_distances) { 367 | distances_grad = outputs_grad[0]; 368 | } else if (return_vectors) { 369 | vectors_grad = outputs_grad[0]; 370 | } else { 371 | // nothing to do 372 | return { 373 | torch::Tensor(), 374 | torch::Tensor(), 375 | torch::Tensor(), 376 | torch::Tensor(), 377 | torch::Tensor(), 378 | torch::Tensor(), 379 | torch::Tensor(), 380 | }; 381 | } 382 | 383 | if (points.requires_grad() || box.requires_grad()) { 384 | // Do a first backward step from distances_grad to vectors_grad 385 | vectors_grad += distances_grad.index({torch::indexing::Slice(), torch::indexing::None}) 386 | * vectors / distances.index({torch::indexing::Slice(), torch::indexing::None}); 387 | } 388 | 389 | auto points_grad = torch::Tensor(); 390 | if (points.requires_grad()) { 391 | points_grad = torch::zeros_like(points); 392 | points_grad = torch::index_add( 393 | points_grad, 394 | /*dim=*/0, 395 | /*index=*/pairs.index({torch::indexing::Slice(), 1}), 396 | /*source=*/vectors_grad, 397 | /*alpha=*/1.0 398 | ); 399 | points_grad = torch::index_add( 400 | points_grad, 401 | /*dim=*/0, 402 | /*index=*/pairs.index({torch::indexing::Slice(), 0}), 403 | /*source=*/vectors_grad, 404 | /*alpha=*/-1.0 405 | ); 406 | } 407 | 408 | auto box_grad = torch::Tensor(); 409 | if (periodic && box.requires_grad()) { 410 | auto cell_shifts = shifts.to(box.scalar_type()); 411 | box_grad = cell_shifts.t().matmul(vectors_grad); 412 | } 413 | 414 | return { 415 | points_grad, 416 | box_grad, 417 | torch::Tensor(), 418 | torch::Tensor(), 419 | torch::Tensor(), 420 | torch::Tensor(), 421 | torch::Tensor(), 422 | }; 423 | } 424 
| --------------------------------------------------------------------------------