├── .github
└── workflows
│ ├── build-wheels.yml
│ ├── cxx-tests.yml
│ ├── docs.yml
│ └── python-tests.yml
├── .gitignore
├── CMakeLists.txt
├── LICENSE
├── README.md
├── benchmarks
├── benchmark.py
└── carbon.xyz
├── create-single-cpp.py
├── docs
├── Doxyfile
└── src
│ ├── benchmark.png
│ ├── benchmarks.rst
│ ├── c-api.rst
│ ├── conf.py
│ ├── index.rst
│ ├── metatomic.rst
│ ├── python-api.rst
│ ├── static
│ └── images
│ │ ├── Arpitan.png
│ │ ├── Catalan.png
│ │ ├── Lombardy.png
│ │ └── Occitan.png
│ └── torch-api.rst
├── python
├── vesin
│ ├── MANIFEST.in
│ ├── README.md
│ ├── VERSION
│ ├── pyproject.toml
│ ├── setup.py
│ ├── tests
│ │ ├── data
│ │ │ ├── Cd2I4O12.xyz
│ │ │ ├── carbon.xyz
│ │ │ ├── diamond.xyz
│ │ │ ├── naphthalene.xyz
│ │ │ ├── readme.txt
│ │ │ └── water.xyz
│ │ ├── test_metatomic.py
│ │ └── test_neighbors.py
│ └── vesin
│ │ ├── __init__.py
│ │ ├── _ase.py
│ │ ├── _c_api.py
│ │ ├── _c_lib.py
│ │ ├── _neighbors.py
│ │ └── metatomic
│ │ ├── __init__.py
│ │ ├── _model.py
│ │ └── _neighbors.py
└── vesin_torch
│ ├── MANIFEST.in
│ ├── README.md
│ ├── VERSION
│ ├── build-backend
│ └── backend.py
│ ├── pyproject.toml
│ ├── setup.py
│ ├── tests
│ ├── test_autograd.py
│ ├── test_metatensor.py
│ ├── test_metatomic.py
│ └── test_neighbors.py
│ └── vesin
│ └── torch
│ ├── __init__.py
│ ├── _c_lib.py
│ ├── _neighbors.py
│ └── metatensor
│ ├── __init__.py
│ ├── _model.py
│ └── _neighbors.py
├── ruff.toml
├── scripts
├── clean-python.sh
├── create-torch-versions-range.py
└── pytest-dont-rewrite-torch.py
├── setup.py
├── tox.ini
└── vesin
├── CMakeLists.txt
├── VERSION
├── include
└── vesin.h
├── src
├── cpu_cell_list.cpp
├── cpu_cell_list.hpp
├── math.hpp
├── types.hpp
└── vesin.cpp
├── tests
├── CMakeLists.txt
├── memory.cpp
└── neighbors.cpp
└── torch
├── CMakeLists.txt
├── include
└── vesin_torch.hpp
└── src
└── vesin_torch.cpp
/.github/workflows/build-wheels.yml:
--------------------------------------------------------------------------------
1 | name: Build Python wheels
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | tags: ["*"]
7 | pull_request:
8 | paths:
9 | # build wheels in PR if this file changed
10 | - '.github/workflows/build-wheels.yml'
11 | # build wheels in PR if any of the build system files changed
12 | - '**/VERSION'
13 | - '**/setup.py'
14 | - '**/pyproject.toml'
15 | - '**/MANIFEST.in'
16 | - '**/CMakeLists.txt'
17 | schedule:
18 | # check the build once a week on mondays
19 | - cron: '0 10 * * 1'
20 |
21 | concurrency:
22 | group: wheels-${{ github.ref }}
23 | cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
24 |
25 | jobs:
26 | build-wheels:
27 | runs-on: ${{ matrix.os }}
28 | name: ${{ matrix.name }}
29 | strategy:
30 | matrix:
31 | include:
32 | - name: x86_64 Linux
33 | os: ubuntu-22.04
34 | cibw_arch: x86_64
35 | - name: arm64 Linux
36 | os: ubuntu-22.04-arm
37 | cibw_arch: aarch64
38 | - name: x86_64 macOS
39 | os: macos-13
40 | cibw_arch: x86_64
41 | - name: M1 macOS
42 | os: macos-14
43 | cibw_arch: arm64
44 | - name: x86_64 Windows
45 | os: windows-2022
46 | cibw_arch: AMD64
47 | steps:
48 | - uses: actions/checkout@v4
49 | with:
50 | fetch-depth: 0
51 |
52 | - name: Set up Python
53 | uses: actions/setup-python@v5
54 | with:
55 | python-version: "3.12"
56 |
57 | - name: install dependencies
58 | run: python -m pip install cibuildwheel twine
59 |
60 | - name: build wheel
61 | run: python -m cibuildwheel python/vesin
62 | env:
63 | CIBW_BUILD: cp312-*
64 | CIBW_SKIP: "*musllinux*"
65 | CIBW_ARCHS: ${{ matrix.cibw_arch }}
66 | CIBW_BUILD_FRONTEND: build
67 | CIBW_MANYLINUX_X86_64_IMAGE: quay.io/pypa/manylinux2014_x86_64
68 | CIBW_MANYLINUX_AARCH64_IMAGE: quay.io/pypa/manylinux2014_aarch64
69 |
70 | - name: check wheels with twine
71 | run: twine check wheelhouse/*
72 |
73 | - uses: actions/upload-artifact@v4
74 | with:
75 | name: wheel-${{ matrix.os }}-${{ matrix.cibw_arch }}
76 | path: ./wheelhouse/*.whl
77 |
78 | build-torch-wheels:
79 | runs-on: ${{ matrix.os }}
80 | name: ${{ matrix.name }} (torch v${{ matrix.torch-version }})
81 | strategy:
82 | matrix:
83 | torch-version: ['2.3', '2.4', '2.5', '2.6', '2.7']
84 | arch: ['arm64', 'x86_64']
85 | os: ['ubuntu-22.04', 'ubuntu-22.04-arm', 'macos-14', 'windows-2022']
86 | exclude:
87 | # remove mismatched architectures
88 | - {os: macos-14, arch: x86_64}
89 | - {os: ubuntu-22.04, arch: arm64}
90 | - {os: ubuntu-22.04-arm, arch: x86_64}
91 | - {os: windows-2022, arch: arm64}
92 | include:
93 | # add `cibw-arch` to the different configurations
94 | - name: x86_64 Linux
95 | os: ubuntu-22.04
96 | arch: x86_64
97 | cibw-arch: x86_64
98 | - name: arm64 Linux
99 | os: ubuntu-22.04-arm
100 | arch: arm64
101 | cibw-arch: aarch64
102 | - name: arm64 macOS
103 | os: macos-14
104 | arch: arm64
105 | cibw-arch: arm64
106 | - name: x86_64 Windows
107 | os: windows-2022
108 | arch: x86_64
109 | cibw-arch: AMD64
110 | # add the right python version for each torch version
111 | - {torch-version: '2.3', python-version: '3.12', cibw-python: 'cp312-*'}
112 | - {torch-version: '2.4', python-version: '3.12', cibw-python: 'cp312-*'}
113 | - {torch-version: '2.5', python-version: '3.12', cibw-python: 'cp312-*'}
114 | - {torch-version: '2.6', python-version: '3.12', cibw-python: 'cp312-*'}
115 | - {torch-version: '2.7', python-version: '3.12', cibw-python: 'cp312-*'}
116 | steps:
117 | - uses: actions/checkout@v4
118 | with:
119 | fetch-depth: 0
120 |
121 | - name: Set up Python
122 | uses: actions/setup-python@v5
123 | with:
124 | python-version: ${{ matrix.python-version }}
125 |
126 | - name: install dependencies
127 | run: python -m pip install cibuildwheel
128 |
129 | - name: build vesin-torch wheel
130 | run: python -m cibuildwheel python/vesin_torch
131 | env:
132 | CIBW_BUILD: ${{ matrix.cibw-python}}
133 | CIBW_SKIP: "*musllinux*"
134 | CIBW_ARCHS: ${{ matrix.cibw-arch }}
135 | CIBW_BUILD_VERBOSITY: 1
136 | CIBW_MANYLINUX_X86_64_IMAGE: quay.io/pypa/manylinux_2_28_x86_64
137 | CIBW_MANYLINUX_AARCH64_IMAGE: quay.io/pypa/manylinux_2_28_aarch64
138 | CIBW_ENVIRONMENT: >
139 | VESIN_TORCH_BUILD_WITH_TORCH_VERSION=${{ matrix.torch-version }}.*
140 | PIP_EXTRA_INDEX_URL=https://download.pytorch.org/whl/cpu
141 | MACOSX_DEPLOYMENT_TARGET=11
142 | # do not complain for missing libtorch.so
143 | CIBW_REPAIR_WHEEL_COMMAND_MACOS: |
144 | delocate-wheel --ignore-missing-dependencies --require-archs {delocate_archs} -w {dest_dir} -v {wheel}
145 | CIBW_REPAIR_WHEEL_COMMAND_LINUX: |
146 | auditwheel repair --exclude libtorch.so --exclude libtorch_cpu.so --exclude libc10.so -w {dest_dir} {wheel}
147 |
148 | - uses: actions/upload-artifact@v4
149 | with:
150 | name: torch-single-version-wheel-${{ matrix.torch-version }}-${{ matrix.os }}-${{ matrix.arch }}
151 | path: ./wheelhouse/*.whl
152 |
153 | merge-torch-wheels:
154 | needs: build-torch-wheels
155 | runs-on: ubuntu-22.04
156 | name: merge vesin-torch ${{ matrix.name }}
157 | strategy:
158 | matrix:
159 | include:
160 | - name: x86_64 Linux
161 | os: ubuntu-22.04
162 | arch: x86_64
163 | - name: arm64 Linux
164 | os: ubuntu-22.04-arm
165 | arch: arm64
166 | - name: arm64 macOS
167 | os: macos-14
168 | arch: arm64
169 | - name: x86_64 Windows
170 | os: windows-2022
171 | arch: x86_64
172 | steps:
173 | - uses: actions/checkout@v4
174 |
175 | - name: Download wheels
176 | uses: actions/download-artifact@v4
177 | with:
178 | pattern: torch-single-version-wheel-*-${{ matrix.os }}-${{ matrix.arch }}
179 | merge-multiple: false
180 | path: dist
181 |
182 | - name: Set up Python
183 | uses: actions/setup-python@v5
184 | with:
185 | python-version: "3.12"
186 |
187 | - name: install dependencies
188 | run: python -m pip install twine wheel
189 |
190 | - name: merge wheels
191 | run: |
192 | # collect all torch versions used for the build
193 | REQUIRES_TORCH=$(find dist -name "*.whl" -exec unzip -p {} "vesin_torch-*.dist-info/METADATA" \; | grep "Requires-Dist: torch")
194 | MERGED_TORCH_REQUIRE=$(python scripts/create-torch-versions-range.py "$REQUIRES_TORCH")
195 |
196 | echo MERGED_TORCH_REQUIRE=$MERGED_TORCH_REQUIRE
197 |
198 | # unpack all single torch versions wheels in the same directory
199 | mkdir dist/unpacked
200 | find dist -name "*.whl" -print -exec python -m wheel unpack --dest dist/unpacked/ {} ';'
201 |
202 | sed -i "s/Requires-Dist: torch.*/$MERGED_TORCH_REQUIRE/" dist/unpacked/vesin_torch-*/vesin_torch-*.dist-info/METADATA
203 |
204 | echo "\n\n METADATA = \n\n"
205 | cat dist/unpacked/vesin_torch-*/vesin_torch-*.dist-info/METADATA
206 |
207 | # check the right metadata was added to the file. grep will exit with
208 | # code `1` if the line is not found, which will stop CI
209 | grep "$MERGED_TORCH_REQUIRE" dist/unpacked/vesin_torch-*/vesin_torch-*.dist-info/METADATA
210 |
211 | # repack the directory as a new wheel
212 | mkdir wheelhouse
213 | python -m wheel pack --dest wheelhouse/ dist/unpacked/*
214 |
215 | - name: check wheels with twine
216 | run: twine check wheelhouse/*
217 |
218 | - uses: actions/upload-artifact@v4
219 | with:
220 | name: torch-wheel-${{ matrix.os }}-${{ matrix.arch }}
221 | path: ./wheelhouse/*.whl
222 |
223 | build-sdist:
224 | runs-on: ubuntu-22.04
225 | name: sdist
226 | steps:
227 | - uses: actions/checkout@v4
228 | with:
229 | fetch-depth: 0
230 |
231 | - name: Set up Python
232 | uses: actions/setup-python@v5
233 | with:
234 | python-version: "3.12"
235 |
236 | - name: build sdist
237 | run: |
238 | pip install build
239 | python -m build --sdist python/vesin --outdir dist
240 | python -m build --sdist python/vesin_torch --outdir dist
241 |
242 | - uses: actions/upload-artifact@v4
243 | with:
244 | name: sdist
245 | path: dist/*.tar.gz
246 |
247 | merge-and-release:
248 | name: Merge and release wheels/sdists
249 | needs: [build-wheels, merge-torch-wheels, build-sdist]
250 | runs-on: ubuntu-22.04
251 | permissions:
252 | contents: write
253 | steps:
254 | - name: Download wheels
255 | uses: actions/download-artifact@v4
256 | with:
257 | path: dist
258 | pattern: wheel-*
259 | merge-multiple: true
260 |
261 | - name: Download metatensor-torch wheels
262 | uses: actions/download-artifact@v4
263 | with:
264 | path: dist
265 | pattern: torch-wheel-*
266 | merge-multiple: true
267 |
268 | - name: Download sdist
269 | uses: actions/download-artifact@v4
270 | with:
271 | path: dist
272 | name: sdist
273 |
274 | - name: Re-upload a single wheels artifact
275 | uses: actions/upload-artifact@v4
276 | with:
277 | name: wheels
278 | path: |
279 | dist/*
280 |
281 | - name: upload to GitHub release
282 | if: startsWith(github.ref, 'refs/tags/')
283 | uses: softprops/action-gh-release@v2
284 | with:
285 | files: |
286 | dist/*
287 | prerelease: ${{ contains(github.ref, '-rc') }}
288 | env:
289 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
290 |
--------------------------------------------------------------------------------
/.github/workflows/cxx-tests.yml:
--------------------------------------------------------------------------------
1 | name: C++ tests
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | pull_request:
7 | # Check all PR
8 |
9 | concurrency:
10 | group: cxx-tests-${{ github.ref }}
11 | cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
12 |
13 | jobs:
14 | rust-tests:
15 | runs-on: ${{ matrix.os }}
16 | name: ${{ matrix.os }} / ${{ matrix.compiler }}
17 | container: ${{ matrix.container }}
18 | strategy:
19 | matrix:
20 | include:
21 | - os: ubuntu-22.04
22 | compiler: GCC - Valgrind
23 | cc: gcc-13
24 | cxx: g++-13
25 | do-valgrind: true
26 | setup-dependencies: |
27 | sudo add-apt-repository ppa:ubuntu-toolchain-r/test
28 | sudo apt-get update
29 | sudo apt-get install -y gcc-13 g++-13
30 |
31 | - os: ubuntu-22.04
32 | compiler: Clang
33 | cc: clang-18
34 | cxx: clang++-18
35 | setup-dependencies: |
36 | wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add -
37 | sudo apt-get update
38 | sudo apt-add-repository "deb http://apt.llvm.org/focal/ llvm-toolchain-focal-18 main"
39 | sudo apt-get install -y clang-18
40 |
41 | # check the build on a stock Ubuntu 20.04, which uses cmake 3.16
42 | - os: ubuntu-22.04
43 | compiler: GCC
44 | cc: gcc
45 | cxx: g++
46 | container: ubuntu:20.04
47 | setup-dependencies: |
48 | apt update
49 | apt install -y software-properties-common
50 | apt install -y cmake make gcc g++ git curl
51 |
52 | - os: macos-14
53 | compiler: Clang
54 | cc: clang
55 | cxx: clang++
56 |
57 | - os: windows-2022
58 | compiler: MSVC
59 | cc: cl.exe
60 | cxx: cl.exe
61 | cmake-extra-args:
62 | - -G "Visual Studio 17 2022" -A x64
63 |
64 | - os: windows-2022
65 | compiler: MinGW
66 | cc: gcc.exe
67 | cxx: g++.exe
68 | cmake-extra-args:
69 | - -G "MinGW Makefiles"
70 |
71 | steps:
72 | - name: setup dependencies
73 | run: ${{ matrix.setup-dependencies }}
74 |
75 | - uses: actions/checkout@v4
76 | with:
77 | fetch-depth: 0
78 |
79 | - name: install valgrind
80 | if: matrix.do-valgrind
81 | run: |
82 | sudo apt-get update
83 | sudo apt-get install -y valgrind
84 |
85 | - name: configure cmake
86 | shell: bash
87 | run: |
88 | mkdir build && cd build
89 | cmake ${{ join(matrix.cmake-extra-args, ' ') }} \
90 | -DCMAKE_BUILD_TYPE=Debug \
91 | -DCMAKE_C_COMPILER=${{ matrix.cc }} \
92 | -DCMAKE_CXX_COMPILER=${{ matrix.cxx }} \
93 | -DVESIN_BUILD_TESTS=ON \
94 | -DCMAKE_VERBOSE_MAKEFILE=ON \
95 | ../vesin
96 |
97 | - name: build
98 | run: |
99 | cd build
100 | cmake --build . --config Debug --parallel 2
101 |
102 | - name: run tests
103 | run: |
104 | cd build
105 | ctest --output-on-failure --build-config Debug --parallel 2
106 |
--------------------------------------------------------------------------------
/.github/workflows/docs.yml:
--------------------------------------------------------------------------------
1 | name: Documentation
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | tags: ["*"]
7 | pull_request:
8 | # Check all PR
9 |
10 | jobs:
11 | build-and-publish:
12 | permissions:
13 | contents: write
14 | runs-on: ubuntu-22.04
15 | steps:
16 | - uses: actions/checkout@v4
17 |
18 | - name: setup Python
19 | uses: actions/setup-python@v5
20 | with:
21 | python-version: "3.12"
22 |
23 | - name: install dependencies
24 | run: |
25 | python -m pip install tox
26 | sudo apt install doxygen
27 |
28 | - name: build documentation
29 | run: tox -e docs
30 | env:
31 | # Use the CPU only version of torch when building/running the code
32 | PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/cpu
33 |
34 | - name: put documentation in the website
35 | run: |
36 | git clone https://github.com/$GITHUB_REPOSITORY --branch gh-pages gh-pages
37 | rm -rf gh-pages/.git
38 | cd gh-pages
39 |
40 | REF_KIND=$(echo $GITHUB_REF | cut -d / -f2)
41 | if [[ "$REF_KIND" == "tags" ]]; then
42 | TAG=${GITHUB_REF#refs/tags/}
43 | mv ../docs/build/html $TAG
44 | else
45 | rm -rf latest
46 | mv ../docs/build/html latest
47 | fi
48 |
49 | - name: deploy to gh-pages
50 | if: github.event_name == 'push'
51 | uses: peaceiris/actions-gh-pages@v4
52 | with:
53 | github_token: ${{ secrets.GITHUB_TOKEN }}
54 | publish_dir: ./gh-pages/
55 | force_orphan: true
56 |
--------------------------------------------------------------------------------
/.github/workflows/python-tests.yml:
--------------------------------------------------------------------------------
1 | name: Python tests
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | pull_request:
7 | # Check all PR
8 |
9 | concurrency:
10 | group: python-tests-${{ github.ref }}
11 | cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
12 |
13 | jobs:
14 | python-tests:
15 | runs-on: ${{ matrix.os }}
16 | name: ${{ matrix.os }} / Python ${{ matrix.python-version }}
17 | strategy:
18 | matrix:
19 | include:
20 | - os: ubuntu-22.04
21 | python-version: "3.9"
22 | - os: ubuntu-22.04
23 | python-version: "3.12"
24 | - os: macos-14
25 | python-version: "3.12"
26 | - os: windows-2022
27 | python-version: "3.12"
28 | steps:
29 | - uses: actions/checkout@v4
30 | with:
31 | fetch-depth: 0
32 |
33 | - name: setup Python
34 | uses: actions/setup-python@v5
35 | with:
36 | python-version: ${{ matrix.python-version }}
37 |
38 | - name: install tests dependencies
39 | run: |
40 | python -m pip install --upgrade pip
41 | python -m pip install tox
42 |
43 | - name: run tests
44 | run: tox
45 | env:
46 | # Use the CPU only version of torch when building/running the code
47 | PIP_EXTRA_INDEX_URL: https://download.pytorch.org/whl/cpu
48 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | build/
2 | dist/
3 | .cache/
4 | .tox/
5 | __pycache__/
6 |
7 | *.egg-info
8 |
9 | vesin-single-build.cpp
10 |
--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | # This is not the actual CMakeLists.txt for this project, see
2 | # `vesin/CMakeLists.txt` for it. Instead, this file is here to enable `cmake ..`
3 | # from a git checkout, `add_subdirectory` and `FetchContent` without having to
4 | # specify a subdirectory.
5 |
6 | cmake_minimum_required(VERSION 3.16)
7 |
8 | project(vesin-git LANGUAGES NONE)
9 |
10 | if (${CMAKE_SOURCE_DIR} STREQUAL ${CMAKE_CURRENT_SOURCE_DIR})
11 | set(VESIN_MAIN_PROJECT ON)
12 | else()
13 | set(VESIN_MAIN_PROJECT OFF)
14 | endif()
15 |
16 | if (VESIN_MAIN_PROJECT)
17 | if("${CMAKE_BUILD_TYPE}" STREQUAL "" AND "${CMAKE_CONFIGURATION_TYPES}" STREQUAL "")
18 | message(STATUS "Setting build type to 'release' as none was specified.")
19 | set(
20 | CMAKE_BUILD_TYPE "release"
21 | CACHE STRING
22 | "Choose the type of build, options are: none(CMAKE_CXX_FLAGS or CMAKE_C_FLAGS used) debug release relwithdebinfo minsizerel."
23 | FORCE
24 | )
25 | set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS release debug relwithdebinfo minsizerel none)
26 | endif()
27 | endif()
28 |
29 | set(VESIN_INSTALL ${VESIN_MAIN_PROJECT} CACHE BOOL "Install Vesin's headers and libraries")
30 | add_subdirectory(vesin)
31 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2024, Guillaume Fraux
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | 1. Redistributions of source code must retain the above copyright notice, this
9 | list of conditions and the following disclaimer.
10 |
11 | 2. Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | 3. Neither the name of the copyright holder nor the names of its
16 | contributors may be used to endorse or promote products derived from
17 | this software without specific prior written permission.
18 |
19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Vesin: fast neighbor lists for atomistic systems
2 |
3 | [](http://luthaf.fr/vesin/)
4 | 
5 |
6 | | English 🇺🇸/🇬🇧 | Occitan
| French 🇫🇷 | Arpitan
| Gallo‑Italic
| Catalan
| Spanish 🇪🇸 | Italian 🇮🇹 |
7 | |------------------|----------|-----------|----------|--------------|---------|------------|------------|
8 | | neighbo(u)r | vesin | voisin | vesin | visin | veí | vecino | vicino |
9 |
10 |
11 | Vesin is a fast and easy to use library computing neighbor lists for atomistic
12 | system. We provide an interface for the following programing languages:
13 |
14 | - C (also compatible with C++). The project can be installed and used as a
15 | library with your own build system, or included as a single file and built
16 | directly by your own build system;
17 | - Python;
18 | - TorchScript, with both a C++ and Python interface;
19 |
20 | ### Installation
21 |
22 | To use the code from Python, you can install it with `pip`:
23 |
24 | ```
25 | pip install vesin
26 | ```
27 |
28 | See the [documentation](https://luthaf.fr/vesin/latest/index.html#installation)
29 | for more information on how to install the code to use it from C or C++.
30 |
31 | ### Usage instruction
32 |
33 | You can either use the `NeighborList` calculator class:
34 |
35 | ```py
36 | import numpy as np
37 | from vesin import NeighborList
38 |
39 | # positions can be anything compatible with numpy's ndarray
40 | positions = [
41 | (0, 0, 0),
42 | (0, 1.3, 1.3),
43 | ]
44 | box = 3.2 * np.eye(3)
45 |
46 | calculator = NeighborList(cutoff=4.2, full_list=True)
47 | i, j, S, d = calculator.compute(
48 | points=positions,
49 | box=box,
50 | periodic=True,
51 | quantities="ijSd"
52 | )
53 | ```
54 |
55 | We also provide a function with drop-in compatibility to ASE's neighbor list:
56 |
57 | ```py
58 | import ase
59 | from vesin import ase_neighbor_list
60 |
61 | atoms = ase.Atoms(...)
62 |
63 | i, j, S, d = ase_neighbor_list("ijSd", atoms, cutoff=4.2)
64 | ```
65 |
66 | See the [documentation](https://luthaf.fr/vesin/latest/c-api.html) for more
67 | information on how to use the code from C or C++.
68 |
69 | ### Benchmarks
70 |
71 | You can find below benchmark result computing neighbor lists for increasingly
72 | large diamond supercells, using an AMD 3955WX CPU and an NVIDIA 4070 Ti SUPER
73 | GPU. You can run this benchmark on your system with the script at
74 | `benchmarks/benchmark.py`. Missing points indicate that a specific code could
75 | not run the calculation (for example, NNPOps requires the cell to be twice the
76 | cutoff in size, and can't run with large cutoffs and small cells).
77 |
78 | 
79 |
80 | ## License
81 |
82 | Vesin is is distributed under the [3 clauses BSD license](LICENSE). By
83 | contributing to this code, you agree to distribute your contributions under the
84 | same license.
85 |
--------------------------------------------------------------------------------
/benchmarks/benchmark.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import ase.build
4 | import ase.neighborlist
5 | import matscipy.neighbours
6 | import NNPOps.neighbors
7 | import numpy as np
8 | import pymatgen.core
9 | import torch
10 | import torch_nl
11 |
12 | import vesin
13 |
14 |
15 | def benchmark(setup, function, atoms, cutoff):
16 | args = setup(atoms, cutoff)
17 |
18 | n_warm = 5
19 | start = time.time()
20 | for _ in range(n_warm):
21 | function(*args)
22 | end = time.time()
23 |
24 | warmup = (end - start) / n_warm
25 |
26 | # dynamically pick the number of iterations to keep timing below 1s per test, while
27 | # also ensuring at least 10 repetitions
28 | n_iter = int(1.0 / warmup)
29 | if n_iter > 10000:
30 | n_iter = 10000
31 | elif n_iter < 10:
32 | n_iter = 10
33 |
34 | start = time.time()
35 | for _ in range(n_iter):
36 | function(*args)
37 | end = time.time()
38 |
39 | return (end - start) / n_iter
40 |
41 |
42 | def setup_torch_nl_cpu(atoms, cutoff):
43 | pos, cell, pbc, batch, n_atoms = torch_nl.ase2data(
44 | [atoms], device=torch.device("cpu")
45 | )
46 | return cutoff, pos, cell, pbc, batch
47 |
48 |
49 | def setup_torch_nl_cuda(atoms, cutoff):
50 | pos, cell, pbc, batch, n_atoms = torch_nl.ase2data(
51 | [atoms], device=torch.device("cuda")
52 | )
53 | return cutoff, pos, cell, pbc, batch
54 |
55 |
56 | def torch_nl_run(cutoff, pos, cell, pbc, batch):
57 | return torch_nl.compute_neighborlist(
58 | cutoff, pos, cell, pbc, batch, self_interaction=True
59 | )
60 |
61 |
62 | def setup_nnpops_cpu(atoms, cutoff):
63 | positions = torch.tensor(atoms.positions)
64 | box_vector = torch.tensor(atoms.cell)
65 | return positions, cutoff, box_vector
66 |
67 |
68 | def setup_nnpops_cuda(atoms, cutoff):
69 | positions = torch.tensor(atoms.positions).to("cuda")
70 | box_vector = torch.tensor(atoms.cell).to("cuda")
71 | return positions, cutoff, box_vector
72 |
73 |
74 | def nnpops_run(positions, cutoff, box_vectors):
75 | return NNPOps.neighbors.getNeighborPairs(
76 | positions, cutoff=cutoff, box_vectors=box_vectors
77 | )
78 |
79 |
80 | def setup_ase_like(atoms, cutoff):
81 | return "ijSd", atoms, float(cutoff)
82 |
83 |
84 | def setup_pymatgen(atoms, cutoff):
85 | structure = pymatgen.core.Structure(
86 | atoms.cell[:],
87 | atoms.numbers,
88 | atoms.positions,
89 | coords_are_cartesian=True,
90 | )
91 | return structure, cutoff
92 |
93 |
94 | def pymatgen_run(structure, cutoff):
95 | return structure.get_neighbor_list(cutoff)
96 |
97 |
98 | def determine_super_cell(max_cell_repeat, max_log_size_delta, max_cell_ratio):
99 | """
100 | Determine which super cells to include. We want equally spaced number of atoms in
101 | log scale, and cells that are not too anisotropic.
102 | """
103 | sizes = {}
104 | for kx in range(1, max_cell_repeat):
105 | for ky in range(1, max_cell_repeat):
106 | for kz in range(1, max_cell_repeat):
107 | size = kx * ky * kz
108 | if size in sizes:
109 | sizes[size].append((kx, ky, kz))
110 | else:
111 | sizes[size] = [(kx, ky, kz)]
112 |
113 | # for each size, pick the less anisotropic cell
114 | repeats = []
115 | a, b, c = atoms.cell.lengths()
116 | for candidates in sizes.values():
117 | best = None
118 | best_ratio = np.inf
119 | for kx, ky, kz in candidates:
120 | lengths = [kx * a, ky * b, kz * c]
121 | ratio = np.max(lengths) / np.min(lengths)
122 | if ratio < best_ratio:
123 | best = (kx, ky, kz)
124 | best_ratio = ratio
125 |
126 | repeats.append(best)
127 |
128 | filtered_repeats = []
129 | filtered_log_sizes = [-1]
130 |
131 | for kx, ky, kz in repeats:
132 | log_size = np.log(kx * ky * kz)
133 | lengths = [kx * a, ky * b, kz * c]
134 | ratio = np.max(lengths) / np.min(lengths)
135 | if np.min(np.abs(np.array(filtered_log_sizes) - log_size)) > max_log_size_delta:
136 | if log_size < 2 or ratio < max_cell_ratio:
137 | filtered_repeats.append((kx, ky, kz))
138 | filtered_log_sizes.append(log_size)
139 |
140 | return filtered_repeats
141 |
142 |
143 | atoms = ase.build.bulk("C", "diamond", 3.567, orthorhombic=True)
144 |
145 | repeats = determine_super_cell(
146 | max_cell_repeat=20, max_log_size_delta=0.1, max_cell_ratio=3
147 | )
148 |
149 |
150 | n_atoms = {}
151 | ase_time = {}
152 | matscipy_time = {}
153 | torch_nl_cpu_time = {}
154 | torch_nl_cuda_time = {}
155 | pymatgen_time = {}
156 | vesin_time = {}
157 |
158 | for cutoff in [3, 6, 12]:
159 | print(f"=========== CUTOFF={cutoff} =============")
160 |
161 | n_atoms[cutoff] = []
162 | ase_time[cutoff] = []
163 | matscipy_time[cutoff] = []
164 | torch_nl_cpu_time[cutoff] = []
165 | torch_nl_cuda_time[cutoff] = []
166 | pymatgen_time[cutoff] = []
167 | vesin_time[cutoff] = []
168 |
169 | for kx, ky, kz in repeats:
170 | super_cell = atoms.repeat((kx, ky, kz))
171 | print(len(super_cell), "atoms")
172 | n_atoms[cutoff].append(len(super_cell))
173 |
174 | # ASE
175 | timing = benchmark(
176 | setup_ase_like,
177 | ase.neighborlist.neighbor_list,
178 | super_cell,
179 | cutoff,
180 | )
181 | ase_time[cutoff].append(timing * 1e3)
182 | print(f" ase took {timing * 1e3:.3f} ms")
183 |
184 | # MATSCIPY
185 | timing = benchmark(
186 | setup_ase_like,
187 | matscipy.neighbours.neighbour_list,
188 | super_cell,
189 | cutoff,
190 | )
191 | matscipy_time[cutoff].append(timing * 1e3)
192 | print(f" matscipy took {timing * 1e3:.3f} ms")
193 |
194 | # TORCH_NL CPU
195 | timing = benchmark(
196 | setup_torch_nl_cpu,
197 | torch_nl_run,
198 | super_cell,
199 | cutoff,
200 | )
201 | torch_nl_cpu_time[cutoff].append(timing * 1e3)
202 | print(f" torch_nl (cpu) took {timing * 1e3:.3f} ms")
203 |
204 | # TORCH_NL CUDA
205 | timing = benchmark(
206 | setup_torch_nl_cuda,
207 | torch_nl_run,
208 | super_cell,
209 | cutoff,
210 | )
211 | torch_nl_cuda_time[cutoff].append(timing * 1e3)
212 | print(f" torch_nl (cuda) took {timing * 1e3:.3f} ms")
213 |
214 | if np.any(super_cell.cell.lengths() < 2 * cutoff):
215 | print(" NNPOps can not run for this super cell")
216 | else:
217 | # NNPOps CPU
218 | timing = benchmark(
219 | setup_nnpops_cpu,
220 | nnpops_run,
221 | super_cell,
222 | cutoff,
223 | )
224 | torch_nl_cpu_time[cutoff].append(timing * 1e3)
225 | print(f" NNPOps (cpu) took {timing * 1e3:.3f} ms")
226 |
227 | # NNPOps CUDA
228 | timing = benchmark(
229 | setup_nnpops_cuda,
230 | nnpops_run,
231 | super_cell,
232 | cutoff,
233 | )
234 | torch_nl_cuda_time[cutoff].append(timing * 1e3)
235 | print(f" NNPOps (cuda) took {timing * 1e3:.3f} ms")
236 |
237 | # Pymatgen
238 | timing = benchmark(
239 | setup_pymatgen,
240 | pymatgen_run,
241 | super_cell,
242 | cutoff,
243 | )
244 | pymatgen_time[cutoff].append(timing * 1e3)
245 | print(f" pymatgen took {timing * 1e3:.3f} ms")
246 |
247 | # VESIN
248 | timing = benchmark(
249 | setup_ase_like,
250 | vesin.ase_neighbor_list,
251 | super_cell,
252 | cutoff,
253 | )
254 | vesin_time[cutoff].append(timing * 1e3)
255 | print(f" vesin took {timing * 1e3:.3f} ms")
256 | print()
257 |
--------------------------------------------------------------------------------
/benchmarks/carbon.xyz:
--------------------------------------------------------------------------------
1 | ../python/vesin/tests/data/carbon.xyz
--------------------------------------------------------------------------------
/create-single-cpp.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import os
3 |
4 |
5 | HERE = os.path.dirname(os.path.realpath(__file__))
6 | ALREADY_SEEN = set()
7 |
8 |
9 | def find_file(path):
10 | candidates = [
11 | os.path.join(HERE, "vesin", "src", path),
12 | os.path.join(HERE, "vesin", "include", path),
13 | ]
14 | for candidate in candidates:
15 | if os.path.exists(candidate):
16 | return candidate
17 |
18 | raise RuntimeError(f"unable to find file for {path}")
19 |
20 |
21 | def include_path(line):
22 | assert "#include" in line
23 |
24 | _, path = line.split("#include")
25 | path = path.strip()
26 | if path.startswith('"'):
27 | return path[1:-1]
28 | else:
29 | return ""
30 |
31 |
32 | def merge_files(path, output):
33 | path = find_file(path)
34 |
35 | if path in ALREADY_SEEN:
36 | return
37 | else:
38 | ALREADY_SEEN.add(path)
39 |
40 | if path.endswith("include/vesin.h"):
41 | output.write('#include "vesin.h"\n')
42 | return
43 |
44 | with open(path) as fd:
45 | for line in fd:
46 | if "#include" in line:
47 | new_path = include_path(line)
48 | if new_path != "":
49 | merge_files(new_path, output)
50 | else:
51 | output.write(line)
52 | else:
53 | output.write(line)
54 |
55 |
56 | if __name__ == "__main__":
57 | with open("vesin-single-build.cpp", "w") as output:
58 | merge_files("cpu_cell_list.cpp", output)
59 | merge_files("vesin.cpp", output)
60 |
61 | print("created single build file 'vesin-single-build.cpp'")
62 |
--------------------------------------------------------------------------------
/docs/src/benchmark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luthaf/vesin/d0036631d52b75dca9352d80f028a6383335d6d2/docs/src/benchmark.png
--------------------------------------------------------------------------------
/docs/src/benchmarks.rst:
--------------------------------------------------------------------------------
1 | .. _benchmarks:
2 |
3 | Benchmarks
4 | ==========
5 |
6 | Here are the results of a benchmark of multiple neighbor list implementations.
7 | The benchmark runs on multiple super-cells of diamond carbon, up to 30'000 atoms,
8 | with multiple cutoffs, and using either CPU or CUDA hardware.
9 |
10 | The results below are for an AMD 3955WX CPU and an NVIDIA 4070 Ti SUPER GPU; if
11 | you want to run it on your own system, the corresponding script is in vesin's
12 | `GitHub repository <bench-script_>`_.
13 |
14 | .. _bench-script: https://github.com/Luthaf/vesin/blob/main/benchmarks/benchmark.py
15 |
16 | .. figure:: benchmark.png
17 | :align: center
18 |
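19 | Speed comparison between multiple neighbor list implementations: vesin, `ase
20 | <https://wiki.fysik.dtu.dk/ase/>`_, `matscipy
21 | <https://github.com/libAtoms/matscipy>`_, `pymatgen
22 | <https://pymatgen.org/>`_,
23 | `torch_nl <https://github.com/felixmusil/torch_nl>`_, and `NNPOps
24 | <https://github.com/openmm/NNPOps>`_.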
25 |
26 | Missing points indicate that a specific code could not run the calculation
27 | (for example, NNPOps requires the cell to be twice the cutoff in size, and
28 | can't run with large cutoffs and small cells).
29 |
--------------------------------------------------------------------------------
/docs/src/c-api.rst:
--------------------------------------------------------------------------------
1 | .. _c-api:
2 |
3 | C API reference
4 | ===============
5 |
6 | Vesin's C API is defined in the ``vesin.h`` header. The main function is
7 | :c:func:`vesin_neighbors`, which runs a neighbors list calculation.
8 |
9 | .. doxygenfunction:: vesin_neighbors
10 |
11 | .. doxygenfunction:: vesin_free
12 |
13 | .. doxygenstruct:: VesinNeighborList
14 |
15 | .. doxygenstruct:: VesinOptions
16 |
17 | .. doxygenenum:: VesinDevice
18 |
--------------------------------------------------------------------------------
/docs/src/conf.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import os
4 | import subprocess
5 | import sys
6 | from datetime import datetime
7 |
8 |
9 | os.environ["METATENSOR_IMPORT_FOR_SPHINX"] = "1"  # NOTE(review): presumably lets metatensor be imported while building docs without its full runtime — confirm
10 | os.environ["METATOMIC_IMPORT_FOR_SPHINX"] = "1"  # same as the line above, for metatomic
11 |
12 | ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))  # repository root, two directories above this conf.py
13 | sys.path.insert(0, os.path.abspath("."))  # current working directory; NOTE(review): assumes Sphinx runs conf.py from its own directory
14 | sys.path.insert(0, ROOT)  # inserted last, so ROOT takes precedence over the directory added above
15 |
16 |
17 | # -- Project information -----------------------------------------------------
18 |
19 | project = "vesin"
20 | copyright = f"{datetime.now().date().year}, vesin developers"  # year is taken at docs build time
21 |
22 |
def setup(app):
    """
    Sphinx extension hook, called once when the documentation build starts.

    Runs ``doxygen Doxyfile`` from the ``docs/`` directory so that breathe can
    document the C/C++ API.

    :param app: the Sphinx application (unused)
    :raises subprocess.CalledProcessError: if doxygen exits with an error,
        instead of letting the build continue and fail later in breathe with
        a confusing message
    """
    subprocess.run(
        ["doxygen", "Doxyfile"],
        cwd=os.path.join(ROOT, "docs"),
        check=True,
    )
25 |
26 |
27 | # -- General configuration ---------------------------------------------------
28 |
29 | needs_sphinx = "4.4.0"  # minimal Sphinx version required to build these docs
30 |
31 | # Add any Sphinx extension module names here, as strings. They can be
32 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
33 | # ones.
34 | extensions = [
35 | "sphinx.ext.autodoc",
36 | "sphinx.ext.intersphinx",
37 | "sphinx_design",  # provides the tab-set / tab-item directives used in index.rst
38 | "breathe",  # provides the doxygen* directives used for the C/C++ API pages
39 | ]
40 |
41 | # Add any paths that contain templates here, relative to this directory.
42 | # templates_path = ["_templates"]
43 |
44 | # List of patterns, relative to source directory, that match files and
45 | # directories to ignore when looking for source files.
46 | # This pattern also affects html_static_path and html_extra_path.
47 | exclude_patterns = ["Thumbs.db", ".DS_Store"]
48 |
49 | autoclass_content = "both"  # concatenate the class docstring and the __init__ docstring
50 | autodoc_member_order = "bysource"  # document members in source-code order
51 | autodoc_typehints_format = "short"
52 | autodoc_type_aliases = {
53 | "npt.ArrayLike": "numpy.typing.ArrayLike",  # render npt.ArrayLike annotations as numpy.typing.ArrayLike
54 | }
55 |
56 |
57 | breathe_projects = {  # NOTE(review): presumably the XML output of the doxygen run in setup() — confirm against docs/Doxyfile
58 | "vesin": os.path.join(ROOT, "docs", "build", "xml"),
59 | }
60 | breathe_default_project = "vesin"
61 | breathe_domain_by_extension = {  # document .h headers with Sphinx's C domain
62 | "h": "c",
63 | }
64 |
65 | breathe_default_members = ("members", "undoc-members")  # also list undocumented members
66 | cpp_private_member_specifier = ""
67 |
68 | intersphinx_mapping = {  # external documentation used to resolve cross-references
69 | "python": ("https://docs.python.org/3", None),
70 | "metatensor": ("https://docs.metatensor.org/latest/", None),
71 | "metatomic": ("https://docs.metatensor.org/metatomic/latest/", None),
72 | "numpy": ("https://numpy.org/doc/stable/", None),
73 | "torch": ("https://pytorch.org/docs/stable/", None),
74 | "ase": ("https://wiki.fysik.dtu.dk/ase/", None),
75 | }
76 |
77 | html_theme = "furo"
78 | html_title = "Vesin"
79 | # html_static_path = ["_static"]
80 |
--------------------------------------------------------------------------------
/docs/src/index.rst:
--------------------------------------------------------------------------------
1 | Vesin: we are all neighbors
2 | ===========================
3 |
4 | .. |occ| image:: /static/images/Occitan.png
5 | :width: 18px
6 |
7 | .. |arp| image:: /static/images/Arpitan.png
8 | :width: 18px
9 |
10 | .. |lomb| image:: /static/images/Lombardy.png
11 | :width: 18px
12 |
13 | .. |cat| image:: /static/images/Catalan.png
14 | :width: 18px
15 |
16 | .. list-table::
17 | :align: center
18 | :widths: auto
19 | :header-rows: 1
20 |
21 | - * English 🇺🇸/🇬🇧
22 | * Occitan |occ|
23 | * French 🇫🇷
24 | * Arpitan |arp|
25 | * Gallo‑Italic |lomb|
26 | * Catalan |cat|
27 | * Spanish 🇪🇸
28 | * Italian 🇮🇹
29 | - * neighbo(u)r
30 | * vesin
31 | * voisin
32 | * vesin
33 | * visin
34 | * veí
35 | * vecino
36 | * vicino
37 |
38 |
39 | ``vesin`` is a lightweight neighbor list calculator for molecular systems and
40 | three-dimensional graphs. It is written in C++ and can be used as a standalone
41 | library from C or Python. ``vesin`` is designed to be :ref:`fast <benchmarks>`
42 | and easy to use.
43 |
44 | Installation
45 | ------------
46 |
47 | .. tab-set::
48 |
49 | .. tab-item:: Python
50 | :sync: python
51 |
52 | You can install the code with ``pip``:
53 |
54 | .. code-block:: bash
55 |
56 | pip install vesin
57 |
58 | **TorchScript:**
59 |
60 | The TorchScript bindings can be installed with:
61 |
62 | .. code-block:: bash
63 |
64 | pip install "vesin[torch]"
65 |
66 | .. tab-item:: C/C++ (CMake)
67 | :sync: cxx
68 |
69 | If you use CMake as your build system, the simplest thing to do is to
70 | add https://github.com/Luthaf/vesin to your project.
71 |
72 | .. code-block:: cmake
73 |
74 | # assuming the code is in the `vesin/` directory (for example using
75 | # git submodule)
76 | add_subdirectory(vesin)
77 |
78 | target_link_libraries(your-target vesin)
79 |
80 | Alternatively, you can use CMake's `FetchContent
81 | <https://cmake.org/cmake/help/latest/module/FetchContent.html>`_ module
82 | to automatically download the code:
83 |
84 | .. code-block:: cmake
85 |
86 | include(FetchContent)
87 | FetchContent_Declare(
88 | vesin
89 | GIT_REPOSITORY https://github.com/Luthaf/vesin.git
90 | )
91 |
92 | FetchContent_MakeAvailable(vesin)
93 |
94 | target_link_libraries(your-target vesin)
95 |
96 | **TorchScript:**
97 |
98 | To make the TorchScript version of the library available to CMake as
99 | well, you should set the ``VESIN_TORCH`` option to ``ON``. If you are
100 | using ``add_subdirectory(vesin)``:
101 |
102 | .. code-block:: cmake
103 |
104 | set(VESIN_TORCH ON CACHE BOOL "Build the vesin_torch library")
105 |
106 | add_subdirectory(vesin)
107 |
108 | target_link_libraries(your-target vesin_torch)
109 |
110 | And if you are using ``FetchContent``:
111 |
112 | .. code-block:: cmake
113 |
114 | set(VESIN_TORCH ON CACHE BOOL "Build the vesin_torch library")
115 |
116 | # like above
117 | FetchContent_Declare(...)
118 | FetchContent_MakeAvailable(...)
119 |
120 | target_link_libraries(your-target vesin_torch)
121 |
122 | .. tab-item:: C/C++ (single file build)
123 |
124 | We support merging all files in the vesin library to a single one that
125 | can then be included in your own project and built with the same build
126 | system as the rest of your code.
127 |
128 | You can generate this single file to build with the following commands:
129 |
130 | .. code-block:: bash
131 |
132 | git clone https://github.com/Luthaf/vesin.git
133 | cd vesin
134 | python create-single-cpp.py
135 |
136 | Then you'll need to copy both ``vesin/include/vesin.h`` and
137 | ``vesin-single-build.cpp`` into your project and configure your build
138 | system accordingly.
139 |
140 | **TorchScript:**
141 |
142 | The TorchScript API does not support single-file builds; please use one
143 | of the CMake options instead.
144 |
145 |
146 | .. tab-item:: C/C++ (global installation)
147 |
148 | You can build and install vesin in some global location (referred to as
149 | ``$PREFIX`` below), and then use the right compiler flags to give this
150 | location to your compiler. In this case, compilation of ``vesin`` and
151 | your code happens separately.
152 |
153 | .. code-block:: bash
154 |
155 | git clone https://github.com/Luthaf/vesin.git
156 | cd vesin
157 | mkdir build && cd build
158 | cmake -DCMAKE_INSTALL_PREFIX=$PREFIX ..
159 | cmake --build . --target install
160 |
161 | You can then compile your code, adding ``$PREFIX/include`` to the
162 | compiler include path, ``$PREFIX/lib`` to the linker library path; and
163 | linking to vesin (typically with ``-lvesin``). If you are building vesin
164 | as a shared library, you'll also need to define ``VESIN_SHARED`` as a
165 | preprocessor constant (``-DVESIN_SHARED`` when compiling the code).
166 |
167 | Some relevant cmake options you can customize:
168 |
169 | +------------------------------+--------------------------------------------------+------------------------------------------------+
170 | | Option | Description | Default |
171 | +==============================+==================================================+================================================+
172 | | ``CMAKE_BUILD_TYPE`` | Type of build: Debug or Release | Release |
173 | +------------------------------+--------------------------------------------------+------------------------------------------------+
174 | | ``CMAKE_INSTALL_PREFIX`` | Prefix where the library will be installed | ``/usr/local`` |
175 | +------------------------------+--------------------------------------------------+------------------------------------------------+
176 | | ``BUILD_SHARED_LIBS`` | Default to building and installing a shared | OFF |
177 | | | library instead of a static one | |
178 | +------------------------------+--------------------------------------------------+------------------------------------------------+
179 | | ``VESIN_INSTALL`` | Should CMake install vesin library and headers | | ON when building vesin directly |
180 | | | | | OFF when including vesin in another project |
181 | +------------------------------+--------------------------------------------------+------------------------------------------------+
182 | | ``VESIN_TORCH`` | Build (and install if ``VESIN_INSTALL=ON``) the | OFF |
183 | | | vesin_torch library | |
184 | +------------------------------+--------------------------------------------------+------------------------------------------------+
185 |
186 | **TorchScript:**
187 |
188 | Set ``VESIN_TORCH`` to ``ON`` to build and install the TorchScript
189 | bindings.
190 |
191 | You can then compile your code, adding ``$PREFIX/include`` to the
192 | compiler include path, ``$PREFIX/lib`` to the linker library path; and
193 | linking to vesin_torch (typically with ``-lvesin_torch``).
194 |
195 | You'll need to also add to the include and linker path the path to the
196 | same torch installation that was used to build the library.
197 |
198 |
199 | Usage example
200 | -------------
201 |
202 | .. tab-set::
203 |
204 | .. tab-item:: Python
205 | :sync: python
206 |
207 | .. py:currentmodule:: vesin
208 |
209 | There are two ways to use vesin from Python. You can use the
210 | :py:class:`NeighborList` class:
211 |
212 | .. code-block:: Python
213 |
214 | import numpy as np
215 | from vesin import NeighborList
216 |
217 | # positions can be anything compatible with numpy's ndarray
218 | positions = [
219 | (0, 0, 0),
220 | (0, 1.3, 1.3),
221 | ]
222 | box = 3.2 * np.eye(3)
223 |
224 | calculator = NeighborList(cutoff=4.2, full_list=True)
225 | i, j, S, d = calculator.compute(
226 | points=positions,
227 | box=box,
228 | periodic=True,
229 | quantities="ijSd"
230 | )
231 |
232 | Alternatively, you can use the :py:func:`ase_neighbor_list` function,
233 | which mimics the API of :py:func:`ase.neighborlist.neighbor_list`:
234 |
235 | .. code-block:: Python
236 |
237 | import ase
238 | from vesin import ase_neighbor_list
239 |
240 | atoms = ase.Atoms(...)
241 |
242 | i, j, S, d = ase_neighbor_list("ijSd", atoms, cutoff=4.2)
243 |
244 |
245 | .. tab-item:: C and C++
246 | :sync: cxx
247 |
248 | .. code-block:: c++
249 |
250 | #include <stdio.h>
251 | #include <stdlib.h>
252 | #include <string.h>
253 |
254 | #include <vesin.h>
255 |
256 | int main() {
257 | // points can be any pointer to `double[3]`
258 | double points[][3] = {
259 | {0, 0, 0},
260 | {0, 1.3, 1.3},
261 | };
262 | size_t n_points = 2;
263 |
264 | // box can be any `double[3][3]` array
265 | double box[3][3] = {
266 | {3.2, 0.0, 0.0},
267 | {0.0, 3.2, 0.0},
268 | {0.0, 0.0, 3.2},
269 | };
270 | bool periodic = true;
271 |
272 | // calculation setup
273 | VesinOptions options;
274 | options.cutoff = 4.2;
275 | options.full = true;
276 |
277 | // decide what quantities should be computed
278 | options.return_shifts = true;
279 | options.return_distances = true;
280 | options.return_vectors = false;
281 |
282 | VesinNeighborList neighbors;
283 | memset(&neighbors, 0, sizeof(VesinNeighborList));
284 |
285 | const char* error_message = NULL;
286 | int status = vesin_neighbors(
287 | points, n_points, box, periodic,
288 | VesinCPU, options,
289 | &neighbors,
290 | &error_message
291 | );
292 |
293 | if (status != EXIT_SUCCESS) {
294 | fprintf(stderr, "error: %s\n", error_message);
295 | return 1;
296 | }
297 |
298 | // use neighbors as needed
299 | printf("we have %zu pairs\n", neighbors.length);
300 |
301 | vesin_free(&neighbors);
302 |
303 | return 0;
304 | }
305 |
306 | .. tab-item:: TorchScript Python
307 |
308 | The entry point for the TorchScript API is the
309 | :py:class:`vesin.torch.NeighborList` class in Python, and the
310 | corresponding :cpp:class:`vesin_torch::NeighborListHolder` class in C++;
311 | both modeled after the standard Python API. For Python, the class is
312 | available in the ``vesin.torch`` module.
313 |
314 | In both cases, the code is integrated with PyTorch's autograd framework,
315 | meaning that if the ``points`` or ``box`` arguments have
316 | ``requires_grad=True``, then the ``d`` (distances) and ``D`` (distance
317 | vectors) outputs will be integrated into the computational graph.
318 |
319 | .. code-block:: Python
320 |
321 | import torch
322 | from vesin.torch import NeighborList
323 |
324 | positions = torch.tensor(
325 | [[0.0, 0.0, 0.0],
326 | [0.0, 1.3, 1.3]],
327 | dtype=torch.float64,
328 | requires_grad=True,
329 | )
330 | box = 3.2 * torch.eye(3, dtype=torch.float64)
331 |
332 | calculator = NeighborList(cutoff=4.2, full_list=True)
333 | i, j, S, d = calculator.compute(
334 | points=positions,
335 | box=box,
336 | periodic=True,
337 | quantities="ijSd"
338 | )
339 |
340 | .. tab-item:: TorchScript C++
341 |
342 | The entry point for the TorchScript API is the
343 | :py:class:`vesin.torch.NeighborList` class in Python, and the
344 | corresponding :cpp:class:`vesin_torch::NeighborListHolder` class in C++;
345 | both modeled after the standard Python API. For C++, the class is
346 | available in the ``vesin_torch.hpp`` header.
347 |
348 | In both cases, the code is integrated with PyTorch's autograd framework,
349 | meaning that if the ``points`` or ``box`` arguments have
350 | ``requires_grad=True``, then the ``d`` (distances) and ``D`` (distance
351 | vectors) outputs will be integrated into the computational graph.
352 |
353 | .. code-block:: C++
354 |
355 | #include <torch/torch.h>
356 |
357 | #include <vesin_torch.hpp>
358 |
359 | int main() {
360 | auto options = torch::TensorOptions().dtype(torch::kFloat64);
361 | auto positions = torch::tensor(
362 | {{0.0, 0.0, 0.0},
363 | {0.0, 1.3, 1.3}},
364 | options
365 | );
366 | positions.requires_grad_(true);
367 |
368 | auto box = 3.2 * torch::eye(3, options);
369 |
370 | auto calculator = torch::make_intrusive<vesin_torch::NeighborListHolder>(
371 | /*cutoff=*/ 4.2,
372 | /*full_list=*/ true
373 | );
374 |
375 |
376 | auto outputs = calculator->compute(
377 | /*points=*/ positions,
378 | /*box=*/ box,
379 | /*periodic=*/ true,
380 | /*quantities=*/ "ijSd",
381 | /*copy=*/ true
382 | );
383 |
384 | auto i = outputs[0];
385 | auto j = outputs[1];
386 | auto S = outputs[2];
387 | auto d = outputs[3];
388 |
389 | // ...
390 | }
391 |
392 |
393 | API Reference
394 | -------------
395 |
396 | .. toctree::
397 | :maxdepth: 1
398 | :hidden:
399 |
400 | Vesin
401 |
402 | .. toctree::
403 | :maxdepth: 1
404 |
405 | python-api
406 | torch-api
407 | c-api
408 | metatomic
409 |
410 |
411 | .. toctree::
412 | :maxdepth: 1
413 | :hidden:
414 |
415 | benchmarks
416 |
--------------------------------------------------------------------------------
/docs/src/metatomic.rst:
--------------------------------------------------------------------------------
1 | .. _metatomic-api:
2 |
3 | Metatomic interface
4 | ===================
5 |
6 | Vesin offers an interface to compute neighbor lists for `metatomic's
7 | <https://docs.metatensor.org/metatomic/latest/>`_ atomistic machine learning models.
8 |
9 | .. autofunction:: vesin.metatomic.compute_requested_neighbors
10 |
11 | .. autoclass:: vesin.metatomic.NeighborList
12 | :members:
13 |
--------------------------------------------------------------------------------
/docs/src/python-api.rst:
--------------------------------------------------------------------------------
1 | .. _python-api:
2 |
3 | Python API reference
4 | ====================
5 |
6 | .. currentmodule:: vesin
7 |
8 | .. autoclass:: NeighborList
9 | :members:
10 |
11 | .. autofunction:: ase_neighbor_list
12 |
--------------------------------------------------------------------------------
/docs/src/static/images/Arpitan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luthaf/vesin/d0036631d52b75dca9352d80f028a6383335d6d2/docs/src/static/images/Arpitan.png
--------------------------------------------------------------------------------
/docs/src/static/images/Catalan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luthaf/vesin/d0036631d52b75dca9352d80f028a6383335d6d2/docs/src/static/images/Catalan.png
--------------------------------------------------------------------------------
/docs/src/static/images/Lombardy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luthaf/vesin/d0036631d52b75dca9352d80f028a6383335d6d2/docs/src/static/images/Lombardy.png
--------------------------------------------------------------------------------
/docs/src/static/images/Occitan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luthaf/vesin/d0036631d52b75dca9352d80f028a6383335d6d2/docs/src/static/images/Occitan.png
--------------------------------------------------------------------------------
/docs/src/torch-api.rst:
--------------------------------------------------------------------------------
1 | .. _torch-api:
2 |
3 | TorchScript API reference
4 | =========================
5 |
6 | .. autoclass:: vesin.torch.NeighborList
7 | :members:
8 |
9 |
10 | TorchScript API reference (C++)
11 | ===============================
12 |
13 | .. doxygentypedef:: vesin_torch::NeighborList
14 |
15 | .. doxygenclass:: vesin_torch::NeighborListHolder
16 |
--------------------------------------------------------------------------------
/python/vesin/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include pyproject.toml
2 | include LICENSE
3 | include VERSION
4 |
5 | recursive-include lib *
6 |
--------------------------------------------------------------------------------
/python/vesin/README.md:
--------------------------------------------------------------------------------
1 | ../../README.md
--------------------------------------------------------------------------------
/python/vesin/VERSION:
--------------------------------------------------------------------------------
1 | ../../vesin/VERSION
--------------------------------------------------------------------------------
/python/vesin/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "vesin"
3 | dynamic = ["version"]
4 | requires-python = ">=3.9"
5 | authors = [
6 | {name = "Guillaume Fraux", email = "guillaume.fraux@epfl.ch"},
7 | ]
8 |
9 | dependencies = [
10 | "numpy"
11 | ]
12 |
13 | readme = "README.md"
14 | license = "BSD-3-Clause"
15 | description = "Computing neighbor lists for atomistic systems"
16 |
17 | classifiers = [
18 | "Development Status :: 4 - Beta",
19 | "Intended Audience :: Science/Research",
20 | "Operating System :: POSIX",
21 | "Operating System :: MacOS :: MacOS X",
22 | "Operating System :: Microsoft :: Windows",
23 | "Programming Language :: Python",
24 | "Programming Language :: Python :: 3",
25 | "Topic :: Scientific/Engineering",
26 | "Topic :: Scientific/Engineering :: Bio-Informatics",
27 | "Topic :: Scientific/Engineering :: Chemistry",
28 | "Topic :: Scientific/Engineering :: Physics",
29 | "Topic :: Software Development :: Libraries",
30 | "Topic :: Software Development :: Libraries :: Python Modules",
31 | ]
32 |
33 | [project.optional-dependencies]
34 | torch = ["vesin-torch"]
35 |
36 | [project.urls]
37 | homepage = "https://github.com/Luthaf/vesin/"
38 | documentation = "https://luthaf.fr/vesin/"
39 | repository = "https://github.com/Luthaf/vesin/"
40 |
41 | ### ======================================================================== ###
42 |
43 | [build-system]
44 | requires = [
45 | "setuptools >=77",
46 | "wheel >=0.41",
47 | "cmake",
48 | ]
49 | build-backend = "setuptools.build_meta"
50 |
51 | [tool.setuptools]
52 | zip-safe = false
53 |
54 | [tool.setuptools.packages.find]
55 | include = ["vesin*"]
56 | namespaces = false
57 |
--------------------------------------------------------------------------------
/python/vesin/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import subprocess
4 | import sys
5 |
6 | from setuptools import Extension, setup
7 | from setuptools.command.bdist_egg import bdist_egg
8 | from setuptools.command.build_ext import build_ext
9 | from setuptools.command.sdist import sdist
10 | from wheel.bdist_wheel import bdist_wheel
11 |
12 |
# directory containing this setup.py, with symlinks resolved
ROOT = os.path.realpath(os.path.dirname(__file__))

# build type of the native library, configurable through the environment
VESIN_BUILD_TYPE = os.environ.get("VESIN_BUILD_TYPE", "release")
if VESIN_BUILD_TYPE != "debug" and VESIN_BUILD_TYPE != "release":
    raise Exception(
        f"invalid build type passed: '{VESIN_BUILD_TYPE}', "
        "expected 'debug' or 'release'"
    )
21 |
22 |
class universal_wheel(bdist_wheel):
    """
    ``bdist_wheel`` variant producing ``py3-none-<platform>`` wheels.

    When a wheel contains a binary extension, the ``wheel`` package assumes it
    links to ``libpython`` and is therefore usable with a single Python version
    only. That is not the case here, so the same wheel works with any Python 3
    version and only the os/arch tag needs to be kept. This is tracked upstream
    in https://github.com/pypa/wheel/issues/185; until it is resolved, the tag
    is overridden manually.
    """

    def get_tag(self):
        # the computed tag is (python, abi, platform); keep only the platform
        platform_tags = bdist_wheel.get_tag(self)[2:]
        return ("py3", "none") + platform_tags
34 |
35 |
class cmake_ext(build_ext):
    """
    Build the native library using cmake, and install it inside the ``vesin``
    package directory of the build tree.
    """

    def run(self):
        # `lib/` exists when building from an sdist (see `sdist_with_lib`),
        # otherwise we are building from a checkout of the repository
        source_dir = os.path.join(ROOT, "lib")
        if not os.path.exists(source_dir):
            source_dir = os.path.join(ROOT, "..", "..", "vesin")

        build_dir = os.path.join(ROOT, "build", "cmake-build")
        install_dir = os.path.join(os.path.realpath(self.build_lib), "vesin")

        os.makedirs(build_dir, exist_ok=True)

        cmake_options = [
            f"-DCMAKE_INSTALL_PREFIX={install_dir}",
            f"-DCMAKE_BUILD_TYPE={VESIN_BUILD_TYPE}",
            "-DBUILD_SHARED_LIBS=ON",
        ]

        subprocess.run(
            ["cmake", source_dir, *cmake_options],
            cwd=build_dir,
            check=True,
        )
        subprocess.run(
            [
                "cmake",
                "--build",
                build_dir,
                # multi-config generators (e.g. Visual Studio) ignore
                # CMAKE_BUILD_TYPE and default to Debug unless the
                # configuration is given explicitly at build time
                "--config",
                VESIN_BUILD_TYPE,
                "--target",
                "install",
            ],
            check=True,
        )
67 |
68 |
class bdist_egg_disabled(bdist_egg):
    """Disabled version of bdist_egg

    Prevents setup.py install performing setuptools' default easy_install,
    which it should never ever do. Running this command exits with a message
    pointing users to the supported installation methods.
    """

    def run(self):
        sys.exit(
            "Aborting implicit building of eggs.\nUse `pip install .` or "
            "`python -m build --wheel . && pip install "
            "dist/vesin-*.whl` to install from source."
        )
82 |
83 |
class sdist_with_lib(sdist):
    """
    Create a sdist including the code for the native library.

    The C++ sources from the repository's ``vesin/`` directory are temporarily
    copied to ``lib/`` (picked up by ``recursive-include lib *`` in
    MANIFEST.in), and removed once the sdist has been created.
    """

    def run(self):
        # generate extra files
        shutil.copytree(os.path.join(ROOT, "..", "..", "vesin"), "lib")
        try:
            # run original sdist
            super().run()
        finally:
            # always cleanup: a leftover `lib/` would make the next sdist fail
            # in `copytree`, and would be used by `cmake_ext` instead of the
            # checkout sources when building from this directory
            shutil.rmtree("lib")
98 |
99 |
if __name__ == "__main__":
    # resolve VERSION relative to this file instead of the current working
    # directory, and make sure the file handle is closed
    with open(os.path.join(ROOT, "VERSION"), encoding="utf8") as fd:
        version = fd.read().strip()

    setup(
        version=version,
        ext_modules=[
            # only declare the extension, it is built & copied as required by cmake
            # in the build_ext command
            Extension(name="vesin", sources=[]),
        ],
        cmdclass={
            "sdist": sdist_with_lib,
            "build_ext": cmake_ext,
            # eggs are only built when explicitly requested on the command line
            "bdist_egg": bdist_egg if "bdist_egg" in sys.argv else bdist_egg_disabled,
            "bdist_wheel": universal_wheel,
        },
    )
115 |
--------------------------------------------------------------------------------
/python/vesin/tests/data/Cd2I4O12.xyz:
--------------------------------------------------------------------------------
1 | 18
2 | Lattice="5.7779061565689 0.0 0.0 0.0 5.9353021159594 0.0 0.0 0.0 32.37048911" Properties=species:S:1:pos:R:3:masses:R:1:forces:R:3 pbc="T T T"
3 | Cd 1.66027661 1.01057390 14.51270543 112.41400000 -0.03865710 -0.05802875 0.02849028
4 | Cd 4.54922969 1.95707716 17.85778368 112.41400000 -0.03865710 0.05802875 -0.02849028
5 | I 5.63439709 3.76166207 14.85737491 126.90447000 -1.47061191 1.06604780 0.08108146
6 | I 2.74544401 5.14129111 17.51311420 126.90447000 -1.47061191 -1.06604780 -0.08108146
7 | I 4.31565215 1.38050855 12.64875342 126.90447000 0.74986725 0.79087438 -0.69424655
8 | I 1.42669907 1.58714251 19.72173569 126.90447000 0.74986725 -0.79087438 0.69424655
9 | O 2.78364007 0.39473043 12.50000000 15.99900000 -1.11916203 -0.61054277 -0.10327189
10 | O 5.67259315 2.57292062 19.87048911 15.99900000 -1.11916203 0.61054277 0.10327189
11 | O 3.64269369 2.38202798 14.05899015 15.99900000 -0.44450476 0.58457571 0.81988628
12 | O 0.75374061 0.58562308 18.31149896 15.99900000 -0.44450476 -0.58457571 -0.81988628
13 | O 5.44167386 0.15724415 13.43179857 15.99900000 0.93009003 -0.74236748 0.47707191
14 | O 2.55272078 2.81040691 18.93869054 15.99900000 0.93009003 0.74236748 -0.47707191
15 | O 1.34512782 4.79252145 15.30307812 15.99900000 0.80628161 0.77936335 0.23325304
16 | O 4.23408090 4.11043172 17.06741099 15.99900000 0.80628161 -0.77936335 -0.23325304
17 | O 0.52359753 2.92010411 13.34319181 15.99900000 0.37932513 -0.74376411 -0.97400744
18 | O 3.41255061 0.04754695 19.02729730 15.99900000 0.37932513 0.74376411 0.97400744
19 | O 0.27893312 2.33486569 15.96606058 15.99900000 0.20737178 -1.07627874 0.77025494
20 | O 3.16788619 0.63278537 16.40442853 15.99900000 0.20737178 1.07627874 -0.77025494
21 |
--------------------------------------------------------------------------------
/python/vesin/tests/data/carbon.xyz:
--------------------------------------------------------------------------------
1 | 4
2 | Lattice="2.460394 0.0 0.0 -1.26336 2.044166 0.0 -0.139209 -0.407369 6.809714" Properties=species:S:1:pos:R:3 pbc="T T T"
3 | C -0.03480225 -0.10184225 1.70242850
4 | C -0.10440675 -0.30552675 5.10728550
5 | C -0.05691216 1.26093576 1.70242850
6 | C 1.11473716 0.37586124 5.10728550
7 |
--------------------------------------------------------------------------------
/python/vesin/tests/data/diamond.xyz:
--------------------------------------------------------------------------------
1 | 8
2 | Lattice="3.5607451090903233 0.0 0.0 0.0 3.5607451090903233 0.0 0.0 0.0 3.5607451090903233" Properties=species:S:1:pos:R:3 pbc="T T T"
3 | C 0.00000000 0.00000000 1.78037255
4 | C 0.89018628 0.89018628 2.67055883
5 | C -0.00000000 1.78037255 0.00000000
6 | C 0.89018628 2.67055883 0.89018628
7 | C 1.78037255 0.00000000 0.00000000
8 | C 2.67055883 0.89018628 0.89018628
9 | C 1.78037255 1.78037255 1.78037255
10 | C 2.67055883 2.67055883 2.67055883
11 |
--------------------------------------------------------------------------------
/python/vesin/tests/data/naphthalene.xyz:
--------------------------------------------------------------------------------
1 | 18
2 | Properties=species:S:1:pos:R:3 energy=-10478070.659035627 pbc="F F F"
3 | C -1.07752258 -1.44519018 0.23486494
4 | C -2.34287934 -1.00436934 -0.03884263
5 | C -2.49625633 0.40527529 -0.23488609
6 | C -1.32251140 1.30701063 -0.23011205
7 | C -0.03972489 0.71954944 -0.09685793
8 | C 1.16386179 1.56294692 -0.06105191
9 | C 2.34152520 0.92675009 0.01612411
10 | C 2.44515671 -0.58490551 0.10661457
11 | C 1.31344319 -1.26692346 0.24849954
12 | C 0.01904813 -0.61131099 0.05976149
13 | H -0.84295984 -2.42125368 0.67296941
14 | H -3.18464777 -1.71419493 -0.25475633
15 | H -3.55069635 0.69888542 -0.36353746
16 | H -1.62292307 2.33411368 -0.32316927
17 | H 1.06380309 2.66106164 -0.12679498
18 | H 3.21431354 1.51359069 -0.29721477
19 | H 3.54308435 -0.81134785 0.12807956
20 | H 1.33078774 -2.36610850 0.51600063
21 |
--------------------------------------------------------------------------------
/python/vesin/tests/data/readme.txt:
--------------------------------------------------------------------------------
1 | This is a small test suite for vesin, composed of:
2 | - water.xyz: a medium-sized, non-periodic water cluster
3 | - naphthalene.xyz: a small organic molecule
4 | - diamond.xyz: a periodic system with a cubic cell
5 | - carbon.xyz: a periodic system with a non-cubic cell
6 | - Cd2I4O12.xyz: a periodic slab, with a large vacuum region along z
7 |
--------------------------------------------------------------------------------
/python/vesin/tests/data/water.xyz:
--------------------------------------------------------------------------------
1 | 51
2 | Properties=species:S:1:pos:R:3
3 | O 0.35772 8.3064 11.7449
4 | H 1.27587 8.49472 11.9392
5 | H 0.222311 8.65568 10.864
6 | O 9.73218 8.90883 5.56731
7 | H 10.1022 9.47818 6.24194
8 | H 8.84322 9.23864 5.4361
9 | O 1.47521 6.40456 0.956651
10 | H 1.36119 5.83634 1.71846
11 | H 0.61367 6.79872 0.82016
12 | O 10.6678 10.5038 3.35808
13 | H 11.3289 9.85087 3.12814
14 | H 10.3341 10.2182 4.20857
15 | O 9.1735 7.43006 2.04159
16 | H 9.58976 6.88585 2.71002
17 | H 8.97205 8.25061 2.49142
18 | O 9.35696 13.0776 4.13527
19 | H 9.10802 12.7474 4.99856
20 | H 9.72185 12.3167 3.68342
21 | O 9.82166 12.763 10.4287
22 | H 9.58765 12.0028 9.89611
23 | H 9.16248 12.782 11.1225
24 | O 3.50771 14.0639 3.81693
25 | H 4.14093 13.6853 3.20707
26 | H 2.83722 14.4534 3.25567
27 | O 8.4947 11.0391 8.78921
28 | H 7.81372 11.439 8.24831
29 | H 8.0283 10.4014 9.32966
30 | O 8.12407 13.0253 12.765
31 | H 7.95796 13.964 12.8518
32 | H 8.41685 12.7521 13.6344
33 | O 7.8082 0.044473 1.02143
34 | H 8.34368 0.690967 1.48136
35 | H 7.07934 0.549401 0.66082
36 | O 4.56356 8.95562 8.51371
37 | H 5.21772 9.01236 9.21019
38 | H 3.81575 9.45236 8.84578
39 | O 1.67172 4.68017 8.83563
40 | H 2.26816 5.42761 8.79298
41 | H 2.18559 3.94385 8.50399
42 | O 2.93581 9.19049 12.3264
43 | H 3.39043 9.92348 11.9113
44 | H 3.47201 8.42579 12.1167
45 | O 4.91764 12.4533 1.70512
46 | H 3.98424 12.2612 1.61526
47 | H 5.14136 12.9346 0.908508
48 | O 6.85585 12.6923 6.89814
49 | H 7.47976 13.279 7.32575
50 | H 6.9787 12.8522 5.96241
51 | O 1.96907 11.6309 1.47723
52 | H 1.97934 11.4997 2.42534
53 | H 2.35724 10.8321 1.12019
54 |
--------------------------------------------------------------------------------
/python/vesin/tests/test_metatomic.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional
2 |
3 | import pytest
4 | import torch
5 | from metatensor.torch import Labels, TensorMap
6 | from metatomic.torch import (
7 | AtomisticModel,
8 | ModelCapabilities,
9 | ModelMetadata,
10 | ModelOutput,
11 | NeighborListOptions,
12 | System,
13 | )
14 |
15 | from vesin.metatomic import NeighborList, compute_requested_neighbors
16 |
17 |
def test_errors():
    """Mixed periodic/non-periodic boundary conditions must be rejected."""
    system = System(
        types=torch.tensor([6, 8]),
        positions=torch.tensor(
            [[0.0, 0.0, 0.0], [1.0, 1.0, 2.0]], dtype=torch.float64
        ),
        cell=4 * torch.eye(3, dtype=torch.float64),
        pbc=torch.ones(3, dtype=bool),
    )

    calculator = NeighborList(
        NeighborListOptions(cutoff=3.5, full_list=True, strict=True),
        length_unit="A",
    )

    # turn off periodicity along the first axis only
    system.pbc[0] = False

    expected_message = (
        "vesin currently does not support mixed periodic and non-periodic "
        "boundary conditions"
    )
    with pytest.raises(NotImplementedError, match=expected_message):
        calculator.compute(system)
38 |
39 |
def test_backward():
    """Gradients must flow from the neighbor list values to positions and cell."""
    positions = torch.tensor(
        [[0.0, 0.0, 0.0], [1.0, 1.0, 2.0]], dtype=torch.float64
    ).requires_grad_(True)
    cell = torch.eye(3, dtype=torch.float64).mul(4).requires_grad_(True)

    system = System(
        positions=positions,
        cell=cell,
        pbc=torch.ones(3, dtype=bool),
        types=torch.tensor([6, 8]),
    )

    calculator = NeighborList(
        NeighborListOptions(cutoff=3.5, full_list=True, strict=True),
        length_unit="A",
    )
    neighbors = calculator.compute(system)

    loss = neighbors.values.pow(2).sum() * torch.linalg.det(cell)
    loss.backward()

    # gradients must exist ...
    for tensor in (positions, cell):
        assert tensor.grad is not None

    # ... and must not be identically zero
    for tensor in (positions, cell):
        assert torch.linalg.norm(tensor.grad) > 0
64 |
65 |
class InnerModule(torch.nn.Module):
    """Sub-module requesting its own half, strict neighbor list."""

    def requested_neighbor_lists(self) -> List[NeighborListOptions]:
        options = NeighborListOptions(cutoff=3.4, full_list=False, strict=True)
        return [options]
69 |
70 |
class OuterModule(torch.nn.Module):
    """
    Minimal model requesting a full, non-strict neighbor list, and owning an
    ``InnerModule`` child which requests another one. ``forward`` returns no output.
    """

    def __init__(self):
        super().__init__()
        self.inner = InnerModule()

    def requested_neighbor_lists(self) -> List[NeighborListOptions]:
        options = NeighborListOptions(cutoff=5.2, full_list=True, strict=False)
        return [options]

    def forward(
        self,
        systems: List[System],
        outputs: Dict[str, ModelOutput],
        selected_atoms: Optional[Labels],
    ) -> Dict[str, TensorMap]:
        results: Dict[str, TensorMap] = {}
        return results
86 |
87 |
def test_model():
    """
    ``compute_requested_neighbors`` must compute and store every neighbor list
    requested by a model (including nested modules), for both raw modules and
    exported ``AtomisticModel``.
    """
    positions = torch.tensor(
        [[0.0, 0.0, 0.0], [1.0, 1.0, 1.4]], dtype=torch.float64, requires_grad=True
    )
    cell = (4 * torch.eye(3, dtype=torch.float64)).clone().requires_grad_(True)
    pbc = torch.ones(3, dtype=bool)
    types = torch.tensor([6, 8])

    def new_system():
        # a brand new System, without any neighbor list attached
        return System(positions=positions, cell=cell, pbc=pbc, types=types)

    def check_neighbor_lists(systems):
        for system in systems:
            all_options = system.known_neighbor_lists()
            assert len(all_options) == 2
            assert all_options[0].requestors() == ["OuterModule"]
            assert all_options[0].cutoff == 5.2
            assert all_options[1].requestors() == ["OuterModule.inner"]
            assert all_options[1].cutoff == 3.4

    # Using a "raw" model
    model = OuterModule()
    systems = [new_system(), new_system()]
    compute_requested_neighbors(
        systems=systems,
        system_length_unit="A",
        model=model,
        model_length_unit="A",
        check_consistency=True,
    )
    check_neighbor_lists(systems)

    # Using a AtomisticModel. Fresh systems are required here: checking the systems
    # from above again would only re-check the lists computed with the raw model.
    capabilities = ModelCapabilities(
        length_unit="A",
        interaction_range=6.0,
        supported_devices=["cpu"],
        dtype="float64",
    )
    model = AtomisticModel(model.eval(), ModelMetadata(), capabilities)

    # a single system
    system = new_system()
    compute_requested_neighbors(
        systems=system,
        system_length_unit="A",
        model=model,
    )
    check_neighbor_lists([system])

    # a list of systems
    systems = [new_system(), new_system()]
    compute_requested_neighbors(
        systems=systems,
        system_length_unit="A",
        model=model,
    )
    check_neighbor_lists(systems)

    message = (
        "the given `model_length_unit` \\(nm\\) does not match the model "
        "capabilities \\(A\\)"
    )
    with pytest.raises(ValueError, match=message):
        compute_requested_neighbors(
            systems=new_system(),
            system_length_unit="A",
            model=model,
            model_length_unit="nm",
        )
151 |
--------------------------------------------------------------------------------
/python/vesin/tests/test_neighbors.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import ase.build
4 | import ase.io
5 | import ase.neighborlist
6 | import numpy as np
7 | import pytest
8 |
9 | import vesin
10 |
11 |
# directory containing this test file, used to locate the ``data/`` files
CURRENT_DIR = os.path.split(os.path.realpath(__file__))[0]
13 |
14 |
def non_sorted_nl(quantities, atoms, cutoff):
    """
    ASE-like wrapper computing a full, *unsorted* neighbor list with vesin.

    The arrays are produced with ``copy=False``, so they point into memory owned by
    the calculator. The calculator is therefore appended as the last element of the
    returned tuple, to keep it (and the data) alive for the caller.
    """
    nl = vesin.NeighborList(cutoff=cutoff, full_list=True, sorted=False)
    results = nl.compute(
        points=atoms.positions,
        box=atoms.cell[:],
        periodic=np.all(atoms.pbc),
        quantities=quantities,
        copy=False,
    )
    return (*results, nl)
26 |
27 |
@pytest.mark.parametrize("system", ["water", "diamond", "naphthalene", "carbon"])
@pytest.mark.parametrize("cutoff", [float(i) for i in range(1, 10)])
@pytest.mark.parametrize("vesin_nl", [vesin.ase_neighbor_list, non_sorted_nl])
def test_neighbors(system, cutoff, vesin_nl):
    """vesin must find the same pairs, shifts and vectors as ASE."""
    atoms = ase.io.read(f"{CURRENT_DIR}/data/{system}.xyz")

    ase_i, ase_j, ase_S, ase_D = ase.neighborlist.neighbor_list("ijSD", atoms, cutoff)
    vesin_i, vesin_j, vesin_S, vesin_D, *_ = vesin_nl("ijSD", atoms, cutoff)

    for expected, actual in (
        (ase_i, vesin_i),
        (ase_j, vesin_j),
        (ase_S, vesin_S),
        (ase_D, vesin_D),
    ):
        assert len(expected) == len(actual)

    # sort the (i, j, Sx, Sy, Sz) rows so both lists are compared in the same order,
    # whatever order each implementation produced them in
    ase_ijS = np.column_stack((ase_i, ase_j, ase_S))
    vesin_ijS = np.column_stack((vesin_i, vesin_j, vesin_S))

    ase_order = np.lexsort(ase_ijS.T)
    vesin_order = np.lexsort(vesin_ijS.T)

    assert np.array_equal(ase_ijS[ase_order], vesin_ijS[vesin_order])
    assert np.allclose(ase_D[ase_order], vesin_D[vesin_order])
54 |
55 |
def test_pairs_output():
    """The "P" quantity must contain the "i" and "j" indexes side by side."""
    atoms = ase.io.read(f"{CURRENT_DIR}/data/diamond.xyz")

    nl = vesin.NeighborList(cutoff=2.0, full_list=True, sorted=False)
    first, second, pairs = nl.compute(
        points=atoms.positions, box=atoms.cell[:], periodic=True, quantities="ijP"
    )

    expected = np.column_stack((first, second))
    assert np.all(expected == pairs)
65 |
66 |
def test_sorting():
    """``sorted=True`` must order pairs by first then second index."""

    def compute_pairs(atoms, cutoff, sort):
        nl = vesin.NeighborList(cutoff=cutoff, full_list=True, sorted=sort)
        i, j = nl.compute(
            points=atoms.positions, box=atoms.cell[:], periodic=True, quantities="ij"
        )
        return np.column_stack((i, j))

    def is_sorted(ij):
        # lexsort uses the last key as primary key: sort by i, then by j
        order = np.lexsort((ij[:, 1], ij[:, 0]))
        return np.all(ij[order] == ij)

    diamond = ase.io.read(f"{CURRENT_DIR}/data/diamond.xyz")

    unsorted_ij = compute_pairs(diamond, 2.0, sort=False)
    assert not is_sorted(unsorted_ij)

    sorted_ij = compute_pairs(diamond, 2.0, sort=True)
    assert is_sorted(sorted_ij)

    # check that unsorted is not already sorted by chance
    assert not np.all(sorted_ij == unsorted_ij)

    # https://github.com/Luthaf/vesin/issues/34
    crystal = ase.io.read(f"{CURRENT_DIR}/data/Cd2I4O12.xyz")
    assert is_sorted(compute_pairs(crystal, 5.0, sort=True))
96 |
97 |
def test_errors():
    """Invalid boxes and cutoffs must raise explicit errors."""
    points = np.array([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]])

    nl = vesin.NeighborList(cutoff=1.2, full_list=True)
    with pytest.raises(RuntimeError, match="the box matrix is not invertible"):
        nl.compute(points, np.zeros((3, 3)), periodic=True, quantities="ij")

    box = np.eye(3, 3)
    with pytest.raises(RuntimeError, match="cutoff is too small"):
        nl = vesin.NeighborList(cutoff=1e-12, full_list=True)
        nl.compute(points, box, periodic=True, quantities="ij")

    for invalid_cutoff in (0.0, -12.0, float("inf"), float("nan")):
        with pytest.raises(
            RuntimeError, match="cutoff must be a finite, positive number"
        ):
            nl = vesin.NeighborList(cutoff=invalid_cutoff, full_list=True)
            nl.compute(points, box, periodic=True, quantities="ij")
130 |
--------------------------------------------------------------------------------
/python/vesin/vesin/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib.metadata
2 |
3 | from ._ase import ase_neighbor_list # noqa: F401
4 | from ._neighbors import NeighborList # noqa: F401
5 |
6 |
# version of the installed ``vesin`` distribution, read from its package metadata
__version__ = importlib.metadata.version(distribution_name="vesin")
8 |
--------------------------------------------------------------------------------
/python/vesin/vesin/_ase.py:
--------------------------------------------------------------------------------
# ``ase`` is an optional dependency: it is ``None`` when unavailable, and
# ``ase_neighbor_list`` checks for this before using it.
ase = None
try:
    import ase
except ImportError:
    pass
5 |
6 |
7 | from ._neighbors import NeighborList
8 |
9 |
def ase_neighbor_list(quantities, a, cutoff, self_interaction=False, max_nbins=0):
    """
    This is a thin wrapper around :py:class:`NeighborList`, providing the same API as
    :py:func:`ase.neighborlist.neighbor_list`.

    It is intended as a drop-in replacement for the ASE function, but only supports a
    subset of the functionality. Notably, the following is not supported:

    - ``self_interaction=True``
    - :py:class:`ase.Atoms` with mixed periodic boundary conditions
    - giving ``cutoff`` as a dictionary

    :param quantities: quantities to output from the neighbor list. Supported are
        ``"i"``, ``"j"``, ``"d"``, ``"D"``, and ``"S"`` with the same meaning as in ASE.
    :param a: :py:class:`ase.Atoms` instance
    :param cutoff: cutoff radius for the neighbor list, as a single real number.
        Integers (including numpy integer scalars) are accepted and converted to
        ``float``, like ASE does.
    :param self_interaction: Should an atom be considered its own neighbor? Default:
        False
    :param max_nbins: for ASE compatibility, ignored by this implementation

    :raises ImportError: if ``ase`` can not be imported
    :raises ValueError: if ``self_interaction=True``, if ``cutoff`` is not a single
        number, or if the periodic boundary conditions differ between axes
    :raises TypeError: if ``a`` is not an :py:class:`ase.Atoms`
    """
    import numbers

    if ase is None:
        raise ImportError("could not import ase, this function requires ase")

    if self_interaction:
        raise ValueError("self_interaction=True is not implemented")

    # ``numbers.Real`` covers int, float and numpy scalars; ``bool`` is technically an
    # integer but can not be a meaningful cutoff
    if isinstance(cutoff, bool) or not isinstance(cutoff, numbers.Real):
        raise ValueError("only a single float cutoff is supported")
    cutoff = float(cutoff)

    if not isinstance(a, ase.Atoms):
        raise TypeError(f"`a` should be ase.Atoms, got {type(a)} instead")

    if a.pbc[0] and a.pbc[1] and a.pbc[2]:
        periodic = True
    elif not a.pbc[0] and not a.pbc[1] and not a.pbc[2]:
        periodic = False
    else:
        raise ValueError(
            "different periodic boundary conditions on different axis are not supported"
        )

    # sorted=True and full_list=True since that's what ASE does
    calculator = NeighborList(cutoff=cutoff, full_list=True, sorted=True)
    return calculator.compute(
        points=a.positions,
        box=a.cell[:],
        periodic=periodic,
        quantities=quantities,
        copy=True,
    )
60 |
--------------------------------------------------------------------------------
/python/vesin/vesin/_c_api.py:
--------------------------------------------------------------------------------
1 | import ctypes
2 | from ctypes import ARRAY, POINTER
3 |
4 |
# Device selector passed to ``vesin_neighbors``, stored as a C int. The values below
# presumably mirror the ``VesinDevice`` enum from ``vesin.h`` (keep them in sync).
VesinDevice = ctypes.c_int
VesinUnknownDevice, VesinCPU = 0, 1
8 |
9 |
class VesinOptions(ctypes.Structure):
    """
    ctypes mirror of the options structure passed by value to ``vesin_neighbors``.

    Fields are, in declaration order (which must match the C definition): the
    ``cutoff`` distance, followed by the boolean flags ``full``, ``sorted``,
    ``return_shifts``, ``return_distances`` and ``return_vectors``.
    """

    _fields_ = [("cutoff", ctypes.c_double)] + [
        (flag, ctypes.c_bool)
        for flag in (
            "full",
            "sorted",
            "return_shifts",
            "return_distances",
            "return_vectors",
        )
    ]
19 |
20 |
class VesinNeighborList(ctypes.Structure):
    """
    ctypes mirror of the neighbor list structure filled by ``vesin_neighbors`` and
    released with ``vesin_free``.

    ``length`` is the number of pairs. ``pairs``, ``shifts``, ``distances`` and
    ``vectors`` point to arrays with ``length`` entries of 2 indexes, 3 integer cell
    shifts, 1 distance and 3 vector components respectively.

    Fixed-size array types are written as ``item_type * length``, since
    ``ctypes.ARRAY`` is deprecated since Python 3.13.
    """

    _fields_ = [
        ("length", ctypes.c_size_t),
        ("device", VesinDevice),
        ("pairs", POINTER(ctypes.c_size_t * 2)),
        ("shifts", POINTER(ctypes.c_int32 * 3)),
        ("distances", POINTER(ctypes.c_double)),
        ("vectors", POINTER(ctypes.c_double * 3)),
    ]
30 |
31 |
def setup_functions(lib):
    """
    Declare argument and return types for the functions of the vesin shared library
    ``lib``, so that ctypes checks and converts arguments on each call.

    Fixed-size array types are written as ``item_type * length``, since
    ``ctypes.ARRAY`` is deprecated since Python 3.13. ctypes caches array types, so
    ``ctypes.c_double * 3`` is the same type as ``ARRAY(ctypes.c_double, 3)`` and
    callers building arguments with either spelling stay compatible.
    """
    Vector = ctypes.c_double * 3

    lib.vesin_free.argtypes = [POINTER(VesinNeighborList)]
    lib.vesin_free.restype = None

    lib.vesin_neighbors.argtypes = [
        POINTER(Vector),  # points
        ctypes.c_size_t,  # n_points
        Vector * 3,  # box
        ctypes.c_bool,  # periodic
        VesinDevice,  # device
        VesinOptions,  # options
        POINTER(VesinNeighborList),  # neighbors
        POINTER(ctypes.c_char_p),  # error_message
    ]
    # status code; callers treat any non-zero value as an error
    lib.vesin_neighbors.restype = ctypes.c_int
47 |
--------------------------------------------------------------------------------
/python/vesin/vesin/_c_lib.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from ctypes import cdll
4 |
5 | from ._c_api import setup_functions
6 |
7 |
8 | _HERE = os.path.realpath(os.path.dirname(__file__))
9 |
10 |
class LibraryFinder:
    """
    Callable loading the vesin shared library on first use, then returning the
    same cached library object on every following call.
    """

    def __init__(self):
        self._cached_dll = None

    def __call__(self):
        if self._cached_dll is not None:
            return self._cached_dll

        dll = cdll.LoadLibrary(_lib_path())
        # cache before declaring the function signatures, as the original code did
        self._cached_dll = dll
        setup_functions(dll)
        return dll
22 |
23 |
def _lib_path():
    """
    Get the path of the vesin shared library shipped inside this package.

    :raises ImportError: on an unsupported platform, or if the library file does
        not exist
    """
    # (prefix of sys.platform, sub-directory, file name), tried in this order
    candidates = (
        ("darwin", "lib", "libvesin.dylib"),
        ("linux", "lib", "libvesin.so"),
        ("win", "bin", "vesin.dll"),
    )
    for prefix, directory, filename in candidates:
        if sys.platform.startswith(prefix):
            break
    else:
        raise ImportError("Unknown platform. Please edit this file")

    path = os.path.join(_HERE, directory, filename)
    if not os.path.isfile(path):
        raise ImportError("Could not find vesin shared library at " + path)

    if prefix == "win":
        # check that the DLL matches the architecture of the running Python
        _check_dll(path)

    return path
43 |
44 |
45 | def _check_dll(path):
46 | """Check if the DLL at ``path`` matches the architecture of Python"""
47 | import platform
48 | import struct
49 |
50 | IMAGE_FILE_MACHINE_I386 = 332
51 | IMAGE_FILE_MACHINE_AMD64 = 34404
52 | IMAGE_FILE_MACHINE_ARM64 = 43620
53 |
54 | machine = None
55 | with open(path, "rb") as fd:
56 | header = fd.read(2).decode(encoding="utf-8", errors="strict")
57 | if header != "MZ":
58 | raise ImportError(path + " is not a DLL")
59 | else:
60 | fd.seek(60)
61 | header = fd.read(4)
62 | header_offset = struct.unpack(" List[np.ndarray]:
44 | """
45 | Compute the neighbor list for the system defined by ``positions``, ``box``, and
46 | ``periodic``; returning the requested ``quantities``.
47 |
48 | ``quantities`` can contain any combination of the following values:
49 |
50 | - ``"i"`` to get the index of the first point in the pair
51 | - ``"j"`` to get the index of the second point in the pair
52 | - ``"P"`` to get the indexes of the two points in the pair simultaneously
53 | - ``"S"`` to get the periodic shift of the pair
54 | - ``"d"`` to get the distance between points in the pair
55 | - ``"D"`` to get the distance vector between points in the pair
56 |
57 | :param points: positions of all points in the system (this can be anything that
58 | can be converted to a numpy array)
59 | :param box: bounding box of the system (this can be anything that can be
60 | converted to a numpy array)
61 | :param periodic: should we use periodic boundary conditions?
62 | :param quantities: quantities to return, defaults to "ij"
63 | :param copy: should we copy the returned quantities, defaults to ``True``.
64 | Setting this to ``False`` might be a bit faster, but the returned arrays are
65 | view inside this class, and will be invalidated whenever this class is
66 | garbage collected or used to run a new calculation.
67 |
68 | :return: tuple of arrays as indicated by ``quantities``.
69 | """
70 | points = np.asarray(points, dtype=np.float64)
71 | box = np.asarray(box, dtype=np.float64)
72 |
73 | if box.shape != (3, 3):
74 | raise ValueError("`box` must be a 3x3 matrix")
75 |
76 | Vector = ARRAY(ctypes.c_double, 3)
77 | box = ARRAY(Vector, 3)(
78 | Vector(box[0][0], box[0][1], box[0][2]),
79 | Vector(box[1][0], box[1][1], box[1][2]),
80 | Vector(box[2][0], box[2][1], box[2][2]),
81 | )
82 |
83 | if len(points.shape) != 2 or points.shape[1] != 3:
84 | raise ValueError("`points` must be a nx3 array")
85 |
86 | options = VesinOptions()
87 | options.cutoff = self.cutoff
88 | options.full = self.full_list
89 | options.sorted = self.sorted
90 | options.return_shifts = "S" in quantities
91 | options.return_distances = "d" in quantities
92 | options.return_vectors = "D" in quantities
93 |
94 | error_message = ctypes.c_char_p()
95 | status = self._lib.vesin_neighbors(
96 | points.ctypes.data_as(POINTER(ARRAY(ctypes.c_double, 3))),
97 | points.shape[0],
98 | box,
99 | periodic,
100 | VesinCPU,
101 | options,
102 | self._neighbors,
103 | error_message,
104 | )
105 |
106 | if status != 0:
107 | raise RuntimeError(error_message.value.decode("utf8"))
108 |
109 | # create numpy arrays for the data
110 | n_pairs = self._neighbors.length
111 | if n_pairs == 0:
112 | pairs = np.empty((0, 2), dtype=ctypes.c_size_t)
113 | shifts = np.empty((0, 3), dtype=ctypes.c_int32)
114 | distances = np.empty((0,), dtype=ctypes.c_double)
115 | vectors = np.empty((0, 3), dtype=ctypes.c_double)
116 | else:
117 | ptr = ctypes.cast(self._neighbors.pairs, POINTER(ctypes.c_size_t))
118 | pairs = np.ctypeslib.as_array(ptr, shape=(n_pairs, 2))
119 |
120 | if "S" in quantities:
121 | ptr = ctypes.cast(self._neighbors.shifts, POINTER(ctypes.c_int32))
122 | shifts = np.ctypeslib.as_array(ptr, shape=(n_pairs, 3))
123 |
124 | if "d" in quantities:
125 | ptr = ctypes.cast(self._neighbors.distances, POINTER(ctypes.c_double))
126 | distances = np.ctypeslib.as_array(ptr, shape=(n_pairs,))
127 |
128 | if "D" in quantities:
129 | ptr = ctypes.cast(self._neighbors.vectors, POINTER(ctypes.c_double))
130 | vectors = np.ctypeslib.as_array(ptr, shape=(n_pairs, 3))
131 |
132 | # assemble output
133 |
134 | data = []
135 | for quantity in quantities:
136 | if quantity == "P":
137 | if copy:
138 | data.append(pairs.copy())
139 | else:
140 | data.append(pairs)
141 |
142 | if quantity == "i":
143 | if copy:
144 | data.append(pairs[:, 0].copy())
145 | else:
146 | data.append(pairs[:, 0])
147 |
148 | elif quantity == "j":
149 | if copy:
150 | data.append(pairs[:, 1].copy())
151 | else:
152 | data.append(pairs[:, 1])
153 |
154 | elif quantity == "S":
155 | if copy:
156 | data.append(shifts.copy())
157 | else:
158 | data.append(shifts)
159 |
160 | elif quantity == "d":
161 | if copy:
162 | data.append(distances.copy())
163 | else:
164 | data.append(distances)
165 |
166 | elif quantity == "D":
167 | if copy:
168 | data.append(vectors.copy())
169 | else:
170 | data.append(vectors)
171 |
172 | return tuple(data)
173 |
--------------------------------------------------------------------------------
/python/vesin/vesin/metatomic/__init__.py:
--------------------------------------------------------------------------------
1 | from ._model import compute_requested_neighbors # noqa: F401
2 | from ._neighbors import NeighborList # noqa: F401
3 |
--------------------------------------------------------------------------------
/python/vesin/vesin/metatomic/_model.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Union
2 |
3 | import torch
4 | from metatomic.torch import (
5 | AtomisticModel,
6 | ModelInterface,
7 | NeighborListOptions,
8 | System,
9 | )
10 |
11 | from ._neighbors import NeighborList
12 |
13 |
def compute_requested_neighbors(
    systems: Union[List[System], System],
    system_length_unit: str,
    model: Union[AtomisticModel, ModelInterface],
    model_length_unit: Optional[str] = None,
    check_consistency: bool = False,
):
    """
    Compute all neighbors lists requested by the ``model`` through
    ``requested_neighbor_lists()`` member functions, and store them inside all the
    ``systems``.

    :param systems: Single system or list of systems for which we need to compute the
        neighbor lists that the model requires. A tuple of systems is also accepted.
    :param system_length_unit: unit of length used by the data in ``systems``
    :param model: :py:class:`AtomisticModel` or any ``torch.nn.Module`` following the
        :py:class:`ModelInterface`
    :param model_length_unit: unit of length used by the model, optional. This is only
        required when giving a raw model instead of a :py:class:`AtomisticModel`.
    :param check_consistency: whether to run additional checks on the neighbor lists
        validity

    :raises ValueError: if ``model_length_unit`` does not match the capabilities of an
        :py:class:`AtomisticModel`, or is missing for a raw model
    :raises TypeError: if ``model`` is not a ``torch.nn.Module``
    """

    if isinstance(model, AtomisticModel):
        if model_length_unit is not None:
            if model.capabilities().length_unit != model_length_unit:
                raise ValueError(
                    f"the given `model_length_unit` ({model_length_unit}) does not "
                    f"match the model capabilities ({model.capabilities().length_unit})"
                )

        all_options = model.requested_neighbor_lists()
    elif isinstance(model, torch.nn.Module):
        if model_length_unit is None:
            raise ValueError(
                "`model_length_unit` parameter is required when not "
                "using AtomisticModel"
            )

        all_options = []
        _get_requested_neighbor_lists(
            model, model.__class__.__name__, all_options, model_length_unit
        )
    else:
        # without this branch, `all_options` would be unbound below
        raise TypeError(
            "`model` must be an AtomisticModel or a torch.nn.Module, "
            f"got {type(model)} instead"
        )

    if isinstance(systems, tuple):
        systems = list(systems)
    elif not isinstance(systems, list):
        systems = [systems]

    for options in all_options:
        calculator = NeighborList(
            options,
            system_length_unit,
            check_consistency=check_consistency,
        )

        for system in systems:
            neighbors = calculator.compute(system)
            system.add_neighbor_list(options, neighbors)
71 |
72 |
def _get_requested_neighbor_lists(
    module: torch.nn.Module,
    module_name: str,
    requested: List[NeighborListOptions],
    length_unit: str,
):
    """
    Recursively collect the neighbor lists requested by a non-exported metatomic model
    (``module`` and all its children) into ``requested``.

    Every module with a ``requested_neighbor_lists()`` method contributes its options,
    tagged with the module's dotted name as requestor. Options comparing equal
    (``==``) to an entry already in ``requested`` only add their requestors to that
    entry; other options get their ``length_unit`` set and are appended.

    :param module: module to inspect, children included
    :param module_name: dotted name of ``module``, used as requestor name
    :param requested: options collected so far, modified in place
    :param length_unit: length unit used by the model

    :raises ValueError: if new options already have a length unit different from
        ``length_unit``
    """
    if hasattr(module, "requested_neighbor_lists"):
        for options in module.requested_neighbor_lists():
            options.add_requestor(module_name)

            is_duplicate = False
            for existing in requested:
                if not (existing == options):
                    continue
                is_duplicate = True
                for requestor in options.requestors():
                    existing.add_requestor(requestor)

            if is_duplicate:
                continue

            if options.length_unit not in ["", length_unit]:
                raise ValueError(
                    f"NeighborsListOptions from {module_name} already have a "
                    f"length unit ('{options.length_unit}') which does not "
                    f"match the model length units ('{length_unit}')"
                )

            options.length_unit = length_unit
            requested.append(options)

    for child_name, child in module.named_children():
        _get_requested_neighbor_lists(
            module=child,
            module_name=f"{module_name}.{child_name}",
            requested=requested,
            length_unit=length_unit,
        )
112 |
--------------------------------------------------------------------------------
/python/vesin/vesin/metatomic/_neighbors.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from metatensor.torch import Labels, TensorBlock
3 | from metatomic.torch import NeighborListOptions, System, register_autograd_neighbors
4 |
5 | from .. import NeighborList as NeighborListNumpy
6 |
7 |
try:
    from vesin.torch import NeighborList as NeighborListTorch

except ImportError:

    class NeighborListTorch:
        """
        Placeholder used when ``vesin-torch`` can not be imported: creating an
        instance always fails with a message explaining the missing dependency.
        """

        def __init__(self, cutoff: float, full_list: bool):
            message = "torchscript=True requires `vesin-torch` as a dependency"
            raise ValueError(message)

        def compute(
            self,
            points: torch.Tensor,
            box: torch.Tensor,
            periodic: bool,
            quantities: str,
            copy: bool = True,
        ):
            # never reached, since ``__init__`` always raises
            return None
26 |
27 |
class NeighborList:
    """
    A neighbor list calculator that can be used with metatomic's models.

    The main difference with the other calculators is the automatic handling of
    different length unit between what the model expects and what the ``System`` are
    using.

    .. seealso::

        The :py:func:`vesin.metatomic.compute_requested_neighbors` function can be used
        to automatically compute and store all neighbor lists required by a given model.
    """

    def __init__(
        self,
        options: NeighborListOptions,
        length_unit: str,
        torchscript: bool = False,
        check_consistency: bool = False,
    ):
        """
        :param options: :py:class:`metatomic.torch.NeighborListOptions` defining the
            parameters of the neighbor list
        :param length_unit: unit of length used for the systems data
        :param torchscript: whether this function should be compatible with TorchScript
            or not. If ``True``, this requires installing the ``vesin-torch`` package.
        :param check_consistency: whether to run additional checks on the neighbor list
            validity

        Example
        -------

        >>> from vesin.metatomic import NeighborList
        >>> from metatomic.torch import System, NeighborListOptions
        >>> import torch
        >>> system = System(
        ...     positions=torch.eye(3).requires_grad_(True),
        ...     cell=4 * torch.eye(3).requires_grad_(True),
        ...     types=torch.tensor([8, 1, 1]),
        ...     pbc=torch.ones(3, dtype=bool),
        ... )
        >>> options = NeighborListOptions(cutoff=4.0, full_list=True, strict=False)
        >>> calculator = NeighborList(options, length_unit="Angstrom")
        >>> neighbors = calculator.compute(system)
        >>> neighbors
        TensorBlock
            samples (18): ['first_atom', 'second_atom', 'cell_shift_a', 'cell_shift_b', 'cell_shift_c']
            components (3): ['xyz']
            properties (1): ['distance']
            gradients: None


        The returned TensorBlock can then be registered with the system

        >>> system.add_neighbor_list(options, neighbors)
        """  # noqa: E501

        self.options = options
        self.length_unit = length_unit
        self.check_consistency = check_consistency

        cutoff = self.options.engine_cutoff(self.length_unit)

        # the numpy-based calculator needs its inputs in CPU memory (see ``compute``)
        self._use_numpy = not (torch.jit.is_scripting() or torchscript)
        if self._use_numpy:
            self._nl = NeighborListNumpy(
                cutoff=cutoff,
                full_list=self.options.full_list,
            )
        else:
            self._nl = NeighborListTorch(
                cutoff=cutoff,
                full_list=self.options.full_list,
            )

        # cached Labels, created on CPU and moved to the system device in ``compute``
        self._components = [Labels("xyz", torch.tensor([[0], [1], [2]]))]
        self._properties = Labels(["distance"], torch.tensor([[0]]))

    def compute(self, system: System) -> TensorBlock:
        """
        Compute the neighbor list for the given :py:class:`metatomic.torch.System`.

        :param system: a :py:class:`metatomic.torch.System` containing data about a
            single structure. If the positions or cell of this system require gradients,
            the neighbors list values computational graph will be set accordingly.

            The positions and cell need to be in the length unit defined for this
            :py:class:`NeighborList` calculator.

        The returned block is on the same device, and uses the same dtype, as
        ``system.positions``.

        :raises NotImplementedError: if ``system.pbc`` mixes periodic and
            non-periodic directions
        """
        device = system.positions.device
        dtype = system.positions.dtype

        # move to float64, as vesin only works with float64 data. The inputs are
        # detached, the computational graph is set by `register_autograd_neighbors`
        points = system.positions.detach().to(torch.float64)
        box = system.cell.detach().to(torch.float64)
        if self._use_numpy:
            # numpy can only read data stored in CPU memory
            points = points.cpu()
            box = box.cpu()

        if torch.all(system.pbc):
            periodic = True
        elif not torch.any(system.pbc):
            periodic = False
        else:
            raise NotImplementedError(
                "vesin currently does not support mixed periodic and non-periodic "
                "boundary conditions"
            )

        # computes neighbor list
        (P, S, D) = self._nl.compute(
            points=points, box=box, periodic=periodic, quantities="PSD", copy=True
        )
        # bring the results back to the device/dtype of the input system
        P = torch.as_tensor(P, dtype=torch.int32, device=device)
        S = torch.as_tensor(S, dtype=torch.int32, device=device)
        D = torch.as_tensor(D, dtype=dtype, device=device)

        # converts to a suitable TensorBlock format
        neighbors = TensorBlock(
            D.reshape(-1, 3, 1),
            samples=Labels(
                names=[
                    "first_atom",
                    "second_atom",
                    "cell_shift_a",
                    "cell_shift_b",
                    "cell_shift_c",
                ],
                values=torch.hstack([P, S]),
            ),
            components=[component.to(device) for component in self._components],
            properties=self._properties.to(device),
        )

        register_autograd_neighbors(
            system, neighbors, check_consistency=self.check_consistency
        )

        return neighbors
160 |
--------------------------------------------------------------------------------
/python/vesin_torch/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include pyproject.toml
2 | include LICENSE
3 | include VERSION
4 |
5 | recursive-include lib *
6 | recursive-include build-backend *.py
7 |
--------------------------------------------------------------------------------
/python/vesin_torch/README.md:
--------------------------------------------------------------------------------
1 | ../../README.md
--------------------------------------------------------------------------------
/python/vesin_torch/VERSION:
--------------------------------------------------------------------------------
1 | ../../vesin/VERSION
--------------------------------------------------------------------------------
/python/vesin_torch/build-backend/backend.py:
--------------------------------------------------------------------------------
1 | # This is a custom Python build backend wrapping setuptool's to add a build-time
2 | # dependencies on torch/cmake when building the wheel and not the sdist
3 | import os
4 |
5 | from setuptools import build_meta
6 |
7 |
ROOT = os.path.realpath(os.path.dirname(__file__))

# Optionally pin the exact torch version used to build the wheel. An empty (or
# whitespace-only) value is treated as unset, since it would otherwise produce the
# invalid requirement "torch ==".
FORCED_TORCH_VERSION = (
    os.environ.get("VESIN_TORCH_BUILD_WITH_TORCH_VERSION", "").strip() or None
)
if FORCED_TORCH_VERSION is not None:
    TORCH_DEP = f"torch =={FORCED_TORCH_VERSION}"
else:
    TORCH_DEP = "torch >=2.3"
15 |
16 | # ==================================================================================== #
17 | # Build backend functions definition #
18 | # ==================================================================================== #
19 |
20 | # Use the default version of these
21 | prepare_metadata_for_build_wheel = build_meta.prepare_metadata_for_build_wheel
22 | get_requires_for_build_sdist = build_meta.get_requires_for_build_sdist
23 | build_wheel = build_meta.build_wheel
24 | build_sdist = build_meta.build_sdist
25 |
26 |
27 | # Special dependencies to build the wheels
28 | def get_requires_for_build_wheel(config_settings=None):
29 | defaults = build_meta.get_requires_for_build_wheel(config_settings)
30 | return defaults + ["cmake", TORCH_DEP]
31 |
--------------------------------------------------------------------------------
/python/vesin_torch/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "vesin-torch"
3 | dynamic = ["version", "dependencies"]
4 | requires-python = ">=3.9"
5 | authors = [
6 | {name = "Guillaume Fraux", email = "guillaume.fraux@epfl.ch"},
7 | ]
8 |
9 | readme = "README.md"
10 | license = "BSD-3-Clause"
11 | description = "Computing neighbor lists for atomistic system, in TorchScript"
12 |
13 | classifiers = [
14 | "Development Status :: 4 - Beta",
15 | "Intended Audience :: Science/Research",
16 | "Operating System :: POSIX",
17 | "Operating System :: MacOS :: MacOS X",
18 | "Operating System :: Microsoft :: Windows",
19 | "Programming Language :: Python",
20 | "Programming Language :: Python :: 3",
21 | "Topic :: Scientific/Engineering",
22 | "Topic :: Scientific/Engineering :: Bio-Informatics",
23 | "Topic :: Scientific/Engineering :: Chemistry",
24 | "Topic :: Scientific/Engineering :: Physics",
25 | "Topic :: Software Development :: Libraries",
26 | "Topic :: Software Development :: Libraries :: Python Modules",
27 | ]
28 |
29 | [project.urls]
30 | homepage = "https://github.com/Luthaf/vesin/"
31 | documentation = "https://luthaf.fr/vesin/"
32 | repository = "https://github.com/Luthaf/vesin/"
33 |
34 | ### ======================================================================== ###
35 |
36 | [build-system]
37 | requires = [
38 | "setuptools >=77",
39 | "wheel >=0.41",
40 | ]
41 | # use a custom build backend to add a dependency on torch/cmake only when
42 | # building wheels
43 | build-backend = "backend"
44 | backend-path = ["build-backend"]
45 |
46 | [tool.setuptools]
47 | zip-safe = false
48 |
49 | [tool.setuptools.packages.find]
50 | include = ["vesin*"]
51 | namespaces = true
52 |
--------------------------------------------------------------------------------
/python/vesin_torch/setup.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import shutil
4 | import subprocess
5 | import sys
6 |
7 | from setuptools import Extension, setup
8 | from setuptools.command.bdist_egg import bdist_egg
9 | from setuptools.command.build_ext import build_ext
10 | from setuptools.command.sdist import sdist
11 | from wheel.bdist_wheel import bdist_wheel
12 |
13 |
# directory containing this setup.py, all other paths are resolved from here
ROOT = os.path.realpath(os.path.dirname(__file__))

# CMake build type for the native library, selected through the environment
VESIN_BUILD_TYPE = os.environ.get("VESIN_BUILD_TYPE", "release")
_VALID_VESIN_BUILD_TYPES = ("debug", "release")
if VESIN_BUILD_TYPE not in _VALID_VESIN_BUILD_TYPES:
    raise Exception(
        f"invalid build type passed: '{VESIN_BUILD_TYPE}', "
        "expected 'debug' or 'release'"
    )
22 |
23 |
class universal_wheel(bdist_wheel):
    """
    ``bdist_wheel`` producing a ``py3-none-<platform>`` tagged wheel.

    When building the wheel, the ``wheel`` package assumes that if we have a
    binary extension then we are linking to ``libpython.so``, and thus that the
    wheel is only usable with a single Python version. This is not the case
    here, and the wheel is compatible with any Python version supported by this
    package (``requires-python = ">=3.9"`` in pyproject.toml). This is tracked
    in https://github.com/pypa/wheel/issues/185; until it is fixed we manually
    override the wheel tag.
    """

    def get_tag(self):
        tag = bdist_wheel.get_tag(self)
        # tag[2:] contains the os/arch tags, we want to keep them
        return ("py3", "none") + tag[2:]
35 |
36 |
class cmake_ext(build_ext):
    """
    Build the native library using cmake, and install it inside the build
    directory of the wheel
    """

    def run(self):
        import torch

        source_dir = os.path.join(ROOT, "lib")
        if not os.path.exists(source_dir):
            # we are building from a checkout
            source_dir = os.path.join(ROOT, "..", "..", "vesin")

        build_dir = os.path.join(ROOT, "build", "cmake-build")

        # Install the shared library in a prefix matching the torch version used to
        # compile the code. This allows having multiple version of this shared library
        # inside the wheel; and dynamically pick the right one.
        torch_major, torch_minor, *_ = torch.__version__.split(".")
        install_dir = os.path.join(
            os.path.realpath(self.build_lib),
            "vesin",
            "torch",
            f"torch-{torch_major}.{torch_minor}",
        )

        os.makedirs(build_dir, exist_ok=True)

        cmake_options = [
            f"-DCMAKE_INSTALL_PREFIX={install_dir}",
            f"-DCMAKE_BUILD_TYPE={VESIN_BUILD_TYPE}",
            f"-DCMAKE_PREFIX_PATH={torch.utils.cmake_prefix_path}",
            "-DBUILD_SHARED_LIBS=ON",
            "-DVESIN_TORCH=ON",
        ]

        subprocess.run(
            ["cmake", source_dir, *cmake_options],
            cwd=build_dir,
            check=True,
        )
        subprocess.run(
            [
                "cmake",
                "--build",
                build_dir,
                # multi-configuration generators (e.g. Visual Studio) ignore
                # CMAKE_BUILD_TYPE and build Debug by default, so the
                # configuration must also be selected here. Single-configuration
                # generators ignore this option.
                "--config",
                VESIN_BUILD_TYPE.capitalize(),
                "--target",
                "install",
            ],
            check=True,
        )

        # do not include the non-torch vesin lib (nor its headers) in the wheel
        for subdir in ("bin", "lib", "include"):
            for path in glob.glob(os.path.join(install_dir, subdir, "*")):
                if "vesin_torch" in os.path.basename(path):
                    continue

                if os.path.isdir(path) and not os.path.islink(path):
                    # os.unlink can not remove directories
                    shutil.rmtree(path)
                else:
                    os.unlink(path)
95 |
96 |
class bdist_egg_disabled(bdist_egg):
    """Disabled version of bdist_egg

    Prevents setup.py install performing setuptools' default easy_install,
    which it should never ever do.
    """

    def run(self):
        # the wheel file name uses the normalized project name (`vesin-torch`)
        sys.exit(
            "Aborting implicit building of eggs.\nUse `pip install .` or "
            "`python -m build --wheel . && pip install "
            "dist/vesin_torch-*.whl` to install from source."
        )
110 |
111 |
class sdist_with_lib(sdist):
    """
    Create a sdist including the code for the native library
    """

    def run(self):
        # copy the native library sources to `lib`, which is picked up by the
        # `recursive-include lib *` rule in MANIFEST.in
        shutil.copytree(os.path.join(ROOT, "..", "..", "vesin"), "lib")

        try:
            # run original sdist
            super().run()
        finally:
            # always clean up, otherwise a failed sdist would leave `lib` behind
            # and make the copytree above fail on the next run
            shutil.rmtree("lib")
126 |
127 |
if __name__ == "__main__":
    if sys.platform == "win32":
        # On Windows, starting with PyTorch 2.3, the file shm.dll in torch has a
        # dependency on mkl DLLs. When building the code using pip build isolation,
        # pip installs the mkl package in a place where the OS does not look for
        # DLLs, so we add that directory to the DLL search path ourselves.
        #
        # This is a very similar fix to https://github.com/pytorch/pytorch/pull/126095,
        # except only applying when importing torch from a build-isolation virtual
        # environment created by pip (`python -m build` does not seems to suffer from
        # this).
        import wheel

        # NOTE(review): walking four levels up from the `wheel` package is assumed
        # to reach the root of pip's build-isolation environment, which contains
        # the `normal` prefix checked below
        pip_virtualenv = os.path.realpath(
            os.path.join(
                os.path.dirname(wheel.__file__),
                "..",
                "..",
                "..",
                "..",
            )
        )
        mkl_dll_dir = os.path.join(
            pip_virtualenv,
            "normal",
            "Library",
            "bin",
        )

        if os.path.exists(mkl_dll_dir):
            os.add_dll_directory(mkl_dll_dir)

        # End of Windows/MKL/PIP hack

    install_requires = []
    forced_torch_version = os.environ.get("VESIN_TORCH_BUILD_WITH_TORCH_VERSION")
    if forced_torch_version is not None:
        install_requires.append(f"torch =={forced_torch_version}")
    else:
        install_requires.append("torch >=2.3")

    # read VERSION relative to this file (not the current directory), and make
    # sure the file handle is closed
    with open(os.path.join(ROOT, "VERSION")) as fd:
        version = fd.read().strip()

    setup(
        version=version,
        install_requires=install_requires,
        ext_modules=[
            # only declare the extension, it is built & copied as required by cmake
            # in the build_ext command
            Extension(name="vesin_torch", sources=[]),
        ],
        cmdclass={
            "sdist": sdist_with_lib,
            "build_ext": cmake_ext,
            "bdist_egg": bdist_egg if "bdist_egg" in sys.argv else bdist_egg_disabled,
            "bdist_wheel": universal_wheel,
        },
    )
183 |
--------------------------------------------------------------------------------
/python/vesin_torch/tests/test_autograd.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | from vesin.torch import NeighborList
5 |
6 |
@pytest.mark.parametrize("full_list", [True, False])
@pytest.mark.parametrize("requires_grad", [(True, True), (False, True), (True, False)])
@pytest.mark.parametrize("quantities", ["ijS", "D", "d", "ijSDd"])
def test_autograd(full_list, requires_grad, quantities):
    """
    Compare the analytical gradients of the neighbor list with respect to the
    points and the box against finite differences (``torch.autograd.gradcheck``).
    """
    torch.manual_seed(0xDEADBEEF)

    # random points inside a random, slightly skewed box
    fractional = torch.rand((34, 3), dtype=torch.float64)
    cell = torch.diag(5 * torch.rand(3, dtype=torch.float64))
    cell += torch.rand((3, 3), dtype=torch.float64)
    positions = fractional @ cell

    points_requires_grad, box_requires_grad = requires_grad
    positions.requires_grad_(points_requires_grad)
    cell.requires_grad_(box_requires_grad)

    calculator = NeighborList(cutoff=7.8, full_list=full_list)

    def compute(points, box):
        return calculator.compute(points, box, periodic=True, quantities=quantities)

    torch.autograd.gradcheck(compute, (positions, cell), fast_mode=True)
29 |
--------------------------------------------------------------------------------
/python/vesin_torch/tests/test_metatensor.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional
2 |
3 | import pytest
4 | import torch
5 | from metatensor.torch import Labels, TensorMap
6 | from metatensor.torch.atomistic import (
7 | MetatensorAtomisticModel,
8 | ModelCapabilities,
9 | ModelMetadata,
10 | ModelOutput,
11 | NeighborListOptions,
12 | System,
13 | )
14 |
15 | from vesin.torch.metatensor import NeighborList, compute_requested_neighbors
16 |
17 |
def test_errors():
    """Mixed periodic and non-periodic boundary conditions must be rejected."""
    system = System(
        positions=torch.tensor(
            [[0.0, 0.0, 0.0], [1.0, 1.0, 2.0]], dtype=torch.float64
        ),
        cell=4 * torch.eye(3, dtype=torch.float64),
        pbc=torch.ones(3, dtype=bool),
        types=torch.tensor([6, 8]),
    )
    calculator = NeighborList(
        NeighborListOptions(cutoff=3.5, full_list=True, strict=True),
        length_unit="A",
    )

    # make the system periodic along b and c only
    system.pbc[0] = False

    expected = (
        "vesin currently does not support mixed periodic and non-periodic "
        "boundary conditions"
    )
    with pytest.raises(NotImplementedError, match=expected):
        calculator.compute(system)
38 |
39 |
def test_script():
    """The metatensor NeighborList calculator can be compiled with TorchScript."""
    system = System(
        positions=torch.tensor(
            [[0.0, 0.0, 0.0], [1.0, 1.0, 2.0]], dtype=torch.float64
        ),
        cell=4 * torch.eye(3, dtype=torch.float64),
        pbc=torch.ones(3, dtype=bool),
        types=torch.tensor([6, 8]),
    )

    options = NeighborListOptions(cutoff=3.5, full_list=True, strict=True)
    scripted = torch.jit.script(NeighborList(options, length_unit="A"))
    scripted.compute(system)
53 |
54 |
def test_backward():
    """Gradients flow from the neighbor list values back to positions and cell."""
    positions = torch.tensor(
        [[0.0, 0.0, 0.0], [1.0, 1.0, 2.0]], dtype=torch.float64, requires_grad=True
    )
    cell = (4 * torch.eye(3, dtype=torch.float64)).clone().requires_grad_(True)
    system = System(
        positions=positions,
        cell=cell,
        pbc=torch.ones(3, dtype=bool),
        types=torch.tensor([6, 8]),
    )

    options = NeighborListOptions(cutoff=3.5, full_list=True, strict=True)
    neighbors = NeighborList(options, length_unit="A").compute(system)

    # any scalar depending on both the distance vectors and the cell works here
    loss = neighbors.values.pow(2).sum() * torch.linalg.det(cell)
    loss.backward()

    # check there are gradients, and they are not zero
    for tensor in (positions, cell):
        assert tensor.grad is not None
        assert torch.linalg.norm(tensor.grad) > 0
79 |
80 |
class InnerModule(torch.nn.Module):
    """Child module requesting a strict, half neighbor list with a 3.4 cutoff."""

    def requested_neighbor_lists(self) -> List[NeighborListOptions]:
        options = NeighborListOptions(cutoff=3.4, full_list=False, strict=True)
        return [options]
84 |
85 |
class OuterModule(torch.nn.Module):
    """
    Model requesting its own full neighbor list (cutoff 5.2), and containing an
    ``InnerModule`` child with a separate request.
    """

    def __init__(self):
        super().__init__()
        self.inner = InnerModule()

    def requested_neighbor_lists(self) -> List[NeighborListOptions]:
        options = NeighborListOptions(cutoff=5.2, full_list=True, strict=False)
        return [options]

    def forward(
        self,
        systems: List[System],
        outputs: Dict[str, ModelOutput],
        selected_atoms: Optional[Labels],
    ) -> Dict[str, TensorMap]:
        # this model produces no outputs, it only exists to request neighbors
        return {}
101 |
102 |
def test_model():
    """
    ``compute_requested_neighbors`` stores every neighbor list requested by a raw
    model or by a MetatensorAtomisticModel inside the systems, for both a list
    of systems and a single system.
    """
    positions = torch.tensor(
        [[0.0, 0.0, 0.0], [1.0, 1.0, 2.0]], dtype=torch.float64, requires_grad=True
    )
    cell = (4 * torch.eye(3, dtype=torch.float64)).clone().requires_grad_(True)
    pbc = torch.ones(3, dtype=bool)
    types = torch.tensor([6, 8])

    def check_neighbor_lists(systems):
        for system in systems:
            all_options = system.known_neighbor_lists()
            assert len(all_options) == 2
            assert all_options[0].requestors() == ["OuterModule"]
            assert all_options[0].cutoff == 5.2
            assert all_options[1].requestors() == ["OuterModule.inner"]
            assert all_options[1].cutoff == 3.4

    # Using a "raw" model
    systems = [
        System(positions=positions, cell=cell, pbc=pbc, types=types),
        System(positions=positions, cell=cell, pbc=pbc, types=types),
    ]
    model = OuterModule()
    compute_requested_neighbors(
        systems=systems, system_length_unit="A", model=model, model_length_unit="A"
    )
    check_neighbor_lists(systems)

    # Using a MetatensorAtomisticModel model, with a single fresh system. The
    # checks must run on this new system, otherwise they would only re-check the
    # systems filled above.
    capabilities = ModelCapabilities(
        length_unit="A",
        interaction_range=6.0,
        supported_devices=["cpu"],
        dtype="float64",
    )
    model = MetatensorAtomisticModel(model.eval(), ModelMetadata(), capabilities)
    system = System(positions=positions, cell=cell, pbc=pbc, types=types)
    compute_requested_neighbors(
        systems=system,
        system_length_unit="A",
        model=model,
    )
    check_neighbor_lists([system])

    message = (
        "the given `model_length_unit` \\(nm\\) does not match the model "
        "capabilities \\(A\\)"
    )
    with pytest.raises(ValueError, match=message):
        compute_requested_neighbors(
            systems=System(positions=positions, cell=cell, pbc=pbc, types=types),
            system_length_unit="A",
            model=model,
            model_length_unit="nm",
        )
162 |
--------------------------------------------------------------------------------
/python/vesin_torch/tests/test_metatomic.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from metatomic.torch import NeighborListOptions, System
3 |
4 | from vesin.metatomic import NeighborList
5 |
6 |
def test_script():
    """The metatomic NeighborList built with ``torchscript=True`` can be scripted."""
    system = System(
        positions=torch.tensor(
            [[0.0, 0.0, 0.0], [1.0, 1.0, 2.0]], dtype=torch.float64
        ),
        cell=4 * torch.eye(3, dtype=torch.float64),
        pbc=torch.ones(3, dtype=bool),
        types=torch.tensor([6, 8]),
    )

    options = NeighborListOptions(cutoff=3.5, full_list=True, strict=True)
    calculator = NeighborList(options, length_unit="A", torchscript=True)
    scripted = torch.jit.script(calculator)
    scripted.compute(system)
22 |
--------------------------------------------------------------------------------
/python/vesin_torch/tests/test_neighbors.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import pytest
4 | import torch
5 |
6 | from vesin.torch import NeighborList
7 |
8 |
def test_errors():
    """Invalid dtypes, devices and quantities are reported with clear errors."""
    points = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]], dtype=torch.float64)
    box = torch.zeros((3, 3), dtype=torch.float64)

    calculator = NeighborList(cutoff=2.8, full_list=True)

    # each case: (expected exception, expected message, points, box, quantities)
    error_cases = [
        (
            ValueError,
            "only float64 dtype is supported in vesin",
            points.to(torch.float32),
            box.to(torch.float32),
            "ij",
        ),
        (
            ValueError,
            "expected `points` and `box` to have the same dtype, got Double and Float",
            points,
            box.to(torch.float32),
            "ij",
        ),
        (
            ValueError,
            "expected `points` and `box` to have the same device, got cpu and meta",
            points,
            box.to(device="meta"),
            "ij",
        ),
        (
            ValueError,
            "unexpected character in `quantities`: Q",
            points,
            box,
            "ijQ",
        ),
        (
            RuntimeError,
            "device meta is not supported in vesin",
            points.to(device="meta"),
            box.to(device="meta"),
            "ij",
        ),
    ]

    for error, message, case_points, case_box, quantities in error_cases:
        with pytest.raises(error, match=message):
            calculator.compute(
                case_points,
                case_box,
                periodic=False,
                quantities=quantities,
            )
59 |
60 |
@pytest.mark.parametrize("quantities", ["ijS", "D", "d", "ijSDd"])
def test_all_alone_no_neighbors(quantities):
    """
    With a cutoff smaller than every distance, all outputs are empty. Distance
    outputs only require gradients when the inputs do.
    """
    points = torch.tensor([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]], dtype=torch.float64)
    box = torch.eye(3, dtype=torch.float64)

    calculator = NeighborList(cutoff=0.1, full_list=True)

    def check_empty_outputs(outputs, requires_grad):
        def output(name):
            return outputs[quantities.index(name)]

        if "ij" in quantities:
            assert list(output("i").shape) == [0]
            assert list(output("j").shape) == [0]

        if "S" in quantities:
            assert list(output("S").shape) == [0, 3]

        if "D" in quantities:
            assert list(output("D").shape) == [0, 3]
            assert output("D").requires_grad == requires_grad

        if "d" in quantities:
            assert list(output("d").shape) == [0]
            assert output("d").requires_grad == requires_grad

    check_empty_outputs(
        calculator.compute(points, box, True, quantities), requires_grad=False
    )

    points.requires_grad_(True)
    box.requires_grad_(True)
    check_empty_outputs(
        calculator.compute(points, box, True, quantities), requires_grad=True
    )
102 |
103 |
class NeighborListWrap:
    """
    Small class holding a :py:class:`NeighborList`. It is used as a type
    annotation in ``test_script`` so that TorchScript has to compile it.
    """

    def __init__(self, cutoff: float, full_list: bool):
        self._c = NeighborList(cutoff=cutoff, full_list=full_list)

    def compute(
        self,
        points: torch.Tensor,
        box: torch.Tensor,
        periodic: bool,
        quantities: str,
        copy: bool,
    ) -> List[torch.Tensor]:
        # forward every argument unchanged to the wrapped calculator
        results = self._c.compute(
            points=points,
            box=box,
            periodic=periodic,
            quantities=quantities,
            copy=copy,
        )
        return results
123 |
124 |
def test_script():
    """A TorchScript module can take and return a NeighborListWrap."""

    class TestModule(torch.nn.Module):
        def forward(self, x: NeighborListWrap) -> NeighborListWrap:
            return x

    # scripting TestModule also requires compiling NeighborListWrap
    torch.jit.script(TestModule())
132 |
--------------------------------------------------------------------------------
/python/vesin_torch/vesin/torch/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib.metadata
2 |
3 | from ._c_lib import _load_library
4 | from ._neighbors import NeighborList # noqa: F401
5 |
6 |
# version of the installed `vesin-torch` distribution
__version__ = importlib.metadata.distribution("vesin-torch").version

_load_library()
10 |
--------------------------------------------------------------------------------
/python/vesin_torch/vesin/torch/_c_lib.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import re
4 | import sys
5 | from collections import namedtuple
6 |
7 | import torch
8 |
9 |
10 | Version = namedtuple("Version", ["major", "minor", "patch"])
11 |
12 |
13 | def parse_version(version):
14 | match = re.match(r"(\d+)\.(\d+)\.(\d+).*", version)
15 | if match:
16 | return Version(*map(int, match.groups()))
17 | else:
18 | raise ValueError("Invalid version string format")
19 |
20 |
21 | _HERE = os.path.realpath(os.path.dirname(__file__))
22 |
23 |
24 | def _lib_path():
25 | torch_version = parse_version(torch.__version__)
26 | expected_prefix = os.path.join(
27 | _HERE, f"torch-{torch_version.major}.{torch_version.minor}"
28 | )
29 | if os.path.exists(expected_prefix):
30 | if sys.platform.startswith("darwin"):
31 | path = os.path.join(expected_prefix, "lib", "libvesin_torch.dylib")
32 | windows = False
33 | elif sys.platform.startswith("linux"):
34 | path = os.path.join(expected_prefix, "lib", "libvesin_torch.so")
35 | windows = False
36 | elif sys.platform.startswith("win"):
37 | path = os.path.join(expected_prefix, "bin", "vesin_torch.dll")
38 | windows = True
39 | else:
40 | raise ImportError("Unknown platform. Please edit this file")
41 |
42 | if os.path.isfile(path):
43 | if windows:
44 | _check_dll(path)
45 | return path
46 | else:
47 | raise ImportError("Could not find vesin_torch shared library at " + path)
48 |
49 | # gather which torch version(s) the current install was built
50 | # with to create the error message
51 | existing_versions = []
52 | for prefix in glob.glob(os.path.join(_HERE, "torch-*")):
53 | existing_versions.append(os.path.basename(prefix)[6:])
54 |
55 | if len(existing_versions) == 1:
56 | raise ImportError(
57 | f"Trying to load vesin-torch with torch v{torch.__version__}, "
58 | f"but it was compiled against torch v{existing_versions[0]}, which "
59 | "is not ABI compatible"
60 | )
61 | else:
62 | all_versions = ", ".join(map(lambda version: f"v{version}", existing_versions))
63 | raise ImportError(
64 | f"Trying to load vesin-torch with torch v{torch.__version__}, "
65 | f"we found builds for torch {all_versions}; which are not ABI compatible.\n"
66 | "You can try to re-install from source with "
67 | "`pip install vesin-torch --no-binary=vesin-torch`"
68 | )
69 |
70 |
71 | def _check_dll(path):
72 | """
73 | Check if the DLL pointer size matches Python (32-bit or 64-bit)
74 | """
75 | import platform
76 | import struct
77 |
78 | IMAGE_FILE_MACHINE_I386 = 332
79 | IMAGE_FILE_MACHINE_AMD64 = 34404
80 |
81 | machine = None
82 | with open(path, "rb") as fd:
83 | header = fd.read(2).decode(encoding="utf-8", errors="strict")
84 | if header != "MZ":
85 | raise ImportError(path + " is not a DLL")
86 | else:
87 | fd.seek(60)
88 | header = fd.read(4)
89 | header_offset = struct.unpack(" List[torch.Tensor]:
25 | """
26 | Compute the neighbor list for the system defined by ``positions``, ``box``, and
27 | ``periodic``; returning the requested ``quantities``.
28 |
29 | ``quantities`` can contain any combination of the following values:
30 |
31 | - ``"i"`` to get the index of the first point in the pair
32 | - ``"j"`` to get the index of the second point in the pair
33 | - ``"P"`` to get the indexes of the two points in the pair simultaneously
34 | - ``"S"`` to get the periodic shift of the pair
35 | - ``"d"`` to get the distance between points in the pair
36 | - ``"D"`` to get the distance vector between points in the pair
37 |
38 | :param points: positions of all points in the system
39 | :param box: bounding box of the system
40 | :param periodic: should we use periodic boundary conditions?
41 | :param quantities: quantities to return, defaults to "ij"
42 | :param copy: should we copy the returned quantities, defaults to ``True``.
43 | Setting this to ``False`` might be a bit faster, but the returned tensors
44 | are view inside this class, and will be invalidated whenever this class is
45 | garbage collected or used to run a new calculation.
46 |
47 | :return: list of :py:class:`torch.Tensor` as indicated by ``quantities``.
48 | """
49 |
50 | return self._c.compute(
51 | points=points,
52 | box=box,
53 | periodic=periodic,
54 | quantities=quantities,
55 | copy=copy,
56 | )
57 |
--------------------------------------------------------------------------------
/python/vesin_torch/vesin/torch/metatensor/__init__.py:
--------------------------------------------------------------------------------
1 | from ._model import compute_requested_neighbors # noqa: F401
2 | from ._neighbors import NeighborList # noqa: F401
3 |
--------------------------------------------------------------------------------
/python/vesin_torch/vesin/torch/metatensor/_model.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Union
2 |
3 | import torch
4 |
5 | from ._neighbors import NeighborList
6 |
7 |
8 | try:
9 | from metatensor.torch.atomistic import (
10 | MetatensorAtomisticModel,
11 | ModelInterface,
12 | NeighborListOptions,
13 | System,
14 | )
15 |
16 | _HAS_METATENSOR = True
17 | except ModuleNotFoundError:
18 | _HAS_METATENSOR = False
19 |
20 | class MetatensorAtomisticModel:
21 | pass
22 |
23 | class ModelInterface:
24 | pass
25 |
26 | class NeighborListOptions:
27 | pass
28 |
29 | class System:
30 | pass
31 |
32 |
def compute_requested_neighbors(
    systems: Union[List[System], System],
    system_length_unit: str,
    model: Union[MetatensorAtomisticModel, ModelInterface],
    model_length_unit: Optional[str] = None,
):
    """
    Compute all neighbors lists requested by the ``model`` through
    ``requested_neighbor_lists()`` member functions, and store them inside all the
    ``systems``.

    :param systems: Single system or list of systems for which we need to compute the
        neighbor lists that the model requires.
    :param system_length_unit: unit of length used by the data in ``systems``
    :param model: :py:class:`MetatensorAtomisticModel` or any ``torch.nn.Module``
        following the :py:class:`ModelInterface`
    :param model_length_unit: unit of length used by the model, optional. This is only
        required when giving a raw model instead of a
        :py:class:`MetatensorAtomisticModel`.

    :raises ValueError: if ``model_length_unit`` does not match the capabilities of a
        :py:class:`MetatensorAtomisticModel`, or is missing for a raw model
    :raises TypeError: if ``model`` is neither a
        :py:class:`MetatensorAtomisticModel` nor a ``torch.nn.Module``
    """

    # MetatensorAtomisticModel must be checked first, since it is also a
    # torch.nn.Module
    if isinstance(model, MetatensorAtomisticModel):
        if model_length_unit is not None:
            if model.capabilities().length_unit != model_length_unit:
                raise ValueError(
                    f"the given `model_length_unit` ({model_length_unit}) does not "
                    f"match the model capabilities ({model.capabilities().length_unit})"
                )

        all_options = model.requested_neighbor_lists()
    elif isinstance(model, torch.nn.Module):
        if model_length_unit is None:
            raise ValueError(
                "`model_length_unit` parameter is required when not "
                "using MetatensorAtomisticModel"
            )

        all_options = []
        _get_requested_neighbor_lists(
            model, model.__class__.__name__, all_options, model_length_unit
        )
    else:
        # without this, `all_options` would be unbound below
        raise TypeError(
            "`model` must be a MetatensorAtomisticModel or a torch.nn.Module, "
            f"got {type(model).__name__}"
        )

    if not isinstance(systems, list):
        systems = [systems]

    for options in all_options:
        calculator = NeighborList(options, system_length_unit)
        for system in systems:
            neighbors = calculator.compute(system)
            system.add_neighbor_list(options, neighbors)
83 |
84 |
def _get_requested_neighbor_lists(
    module: torch.nn.Module,
    module_name: str,
    requested: List[NeighborListOptions],
    length_unit: str,
):
    """
    Recursively collect the neighbor lists requested by a non-exported metatensor
    atomistic model into ``requested``, which is modified in place.

    Options comparing equal to an already collected entry only add their
    requestors to that entry. New options get ``length_unit`` assigned before
    being appended.

    :raises ValueError: if new options already define a different, non-empty
        length unit
    """
    if hasattr(module, "requested_neighbor_lists"):
        for options in module.requested_neighbor_lists():
            options.add_requestor(module_name)

            is_new = True
            for existing in requested:
                if existing == options:
                    is_new = False
                    for requestor in options.requestors():
                        existing.add_requestor(requestor)

            if not is_new:
                continue

            if options.length_unit not in ["", length_unit]:
                raise ValueError(
                    f"NeighborsListOptions from {module_name} already have a "
                    f"length unit ('{options.length_unit}') which does not "
                    f"match the model length units ('{length_unit}')"
                )

            options.length_unit = length_unit
            requested.append(options)

    # children are named `<parent>.<attribute>`, e.g. "OuterModule.inner"
    for child_name, child in module.named_children():
        _get_requested_neighbor_lists(
            module=child,
            module_name=f"{module_name}.{child_name}",
            requested=requested,
            length_unit=length_unit,
        )
124 |
--------------------------------------------------------------------------------
/python/vesin_torch/vesin/torch/metatensor/_neighbors.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .. import NeighborList as NeighborListTorch
4 |
5 |
6 | try: # only define a metatensor adapter if metatensor is available
7 | from metatensor.torch import Labels, TensorBlock
8 | from metatensor.torch.atomistic import NeighborListOptions, System
9 |
10 | _HAS_METATENSOR = True
11 | except ModuleNotFoundError:
12 | _HAS_METATENSOR = False
13 |
14 | class Labels:
15 | pass
16 |
17 | class TensorBlock:
18 | pass
19 |
20 | class System:
21 | pass
22 |
23 | class NeighborListOptions:
24 | pass
25 |
26 |
27 | class NeighborList:
28 | """
29 | A neighbor list calculator that can be used with metatensor's atomistic models.
30 |
31 | The main difference with the other calculators is the automatic handling of
32 | different length unit between what the model expects and what the ``System`` are
33 | using.
34 |
35 | .. seealso::
36 |
37 | The :py:func:`vesin.torch.metatensor.compute_requested_neighbors` function can
38 | be used to automatically compute and store all neighbor lists required by a
39 | given model.
40 | """
41 |
def __init__(self, options: NeighborListOptions, length_unit: str):
    """
    :param options: :py:class:`metatensor.torch.atomistic.NeighborListOptions`
        defining the parameters of the neighbor list
    :param length_unit: unit of length used for the systems data

    :raises ModuleNotFoundError: if ``metatensor-torch`` is not installed (only
        checked when not running under TorchScript)

    Example
    -------

    >>> from vesin.torch.metatensor import NeighborList
    >>> from metatensor.torch.atomistic import System, NeighborListOptions
    >>> import torch
    >>> system = System(
    ...     positions=torch.eye(3).requires_grad_(True),
    ...     cell=4 * torch.eye(3).requires_grad_(True),
    ...     types=torch.tensor([8, 1, 1]),
    ...     pbc=torch.ones(3, dtype=bool),
    ... )
    >>> options = NeighborListOptions(cutoff=4.0, full_list=True, strict=False)
    >>> calculator = NeighborList(options, length_unit="Angstrom")
    >>> neighbors = calculator.compute(system)
    >>> neighbors
    TensorBlock
        samples (18): ['first_atom', 'second_atom', 'cell_shift_a', 'cell_shift_b', 'cell_shift_c']
        components (3): ['xyz']
        properties (1): ['distance']
        gradients: None


    The returned TensorBlock can then be registered with the system

    >>> system.add_neighbor_list(options, neighbors)
    """  # noqa: E501

    if not torch.jit.is_scripting():
        if not _HAS_METATENSOR:
            raise ModuleNotFoundError(
                "`vesin.torch.metatensor` requires the `metatensor-torch` package"
            )
    self.options = options
    self.length_unit = length_unit
    # the options' cutoff is converted to the length unit of the systems
    self._nl = NeighborListTorch(
        cutoff=self.options.engine_cutoff(self.length_unit),
        full_list=self.options.full_list,
    )

    # cached Labels
    self._components = [Labels("xyz", torch.tensor([[0], [1], [2]]))]
    self._properties = Labels(["distance"], torch.tensor([[0]]))
91 |
92 | def compute(self, system: System) -> TensorBlock:
93 | """
94 | Compute the neighbor list for the given
95 | :py:class:`metatensor.torch.atomistic.System`.
96 |
97 | :param system: a :py:class:`metatensor.torch.atomistic.System` containing the
98 | data about a structure. If the positions or cell of this system require
99 | gradients, the neighbors list values computational graph will be set
100 | accordingly.
101 |
102 | The positions and cell need to be in the length unit defined for this
103 | :py:class:`NeighborList` calculator.
104 | """
105 |
106 | # move to float64, as vesin only works in torch64
107 | points = system.positions.to(torch.float64)
108 | box = system.cell.to(torch.float64)
109 | if torch.all(system.pbc):
110 | periodic = True
111 | elif not torch.any(system.pbc):
112 | periodic = False
113 | else:
114 | raise NotImplementedError(
115 | "vesin currently does not support mixed periodic and non-periodic "
116 | "boundary conditions"
117 | )
118 |
119 | # computes neighbor list
120 | (P, S, D) = self._nl.compute(
121 | points=points, box=box, periodic=periodic, quantities="PSD", copy=True
122 | )
123 |
124 | # converts to a suitable TensorBlock format
125 | neighbors = TensorBlock(
126 | D.reshape(-1, 3, 1).to(system.positions.dtype),
127 | samples=Labels(
128 | names=[
129 | "first_atom",
130 | "second_atom",
131 | "cell_shift_a",
132 | "cell_shift_b",
133 | "cell_shift_c",
134 | ],
135 | values=torch.hstack([P, S]),
136 | ),
137 | components=self._components,
138 | properties=self._properties,
139 | )
140 |
141 | return neighbors
142 |
--------------------------------------------------------------------------------
/ruff.toml:
--------------------------------------------------------------------------------
1 | [lint]
2 | select = ["E", "F", "B", "I"]
3 | ignore = ["B018", "B904"]
4 |
5 | [lint.isort]
6 | lines-after-imports = 2
7 | known-first-party = ["vesin"]
8 |
9 | [format]
10 | docstring-code-format = true
11 |
--------------------------------------------------------------------------------
/scripts/clean-python.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

# This script removes all temporary files created by Python during
# installation and test runs.

set -eux

ROOT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")"/.. && pwd)
cd "$ROOT_DIR"

rm -rf dist
rm -rf build
rm -rf docs/build

rm -rf python/vesin/dist
rm -rf python/vesin/build

rm -rf python/vesin_torch/dist
rm -rf python/vesin_torch/build

# `-prune` stops find from descending into the directories it is about to
# remove. Without it, `rm` can delete a directory while find is still walking
# it, and find then reports "No such file or directory" and exits with an
# error, which aborts this script because of `set -e`.
find . -name "*.egg-info" -prune -exec rm -rf "{}" +
find . -name "__pycache__" -prune -exec rm -rf "{}" +
find . -name ".coverage" -exec rm -rf "{}" +
24 |
--------------------------------------------------------------------------------
/scripts/create-torch-versions-range.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
"""
This script updates the `Requires-Dist` information in a torch extension wheel's
METADATA to contain the range of compatible torch versions. It expects newline
separated `Requires-Dist: torch ==...` information (corresponding to wheels built
against a single torch version) and will print
`Requires-Dist: torch >=$MIN_VERSION,<${MAX_VERSION+1}` on the standard output.

This output can then be used in the merged wheel containing the builds against all
torch versions.
"""

import re
import sys


_REQUIRES_DIST_TORCH = re.compile(r"Requires-Dist: torch[ ]?==(\d+)\.(\d+)\.\*")


def parse_torch_versions(torch_versions_raw):
    """
    Parse newline separated ``Requires-Dist: torch ==X.Y.*`` lines.

    :param torch_versions_raw: raw text; blank lines and surrounding whitespace
        (including ``\\r`` from Windows line endings) are ignored
    :return: sorted list of ``(major, minor)`` integer tuples
    :raises ValueError: if a line has an unexpected format, if the same torch
        version appears twice, or if no version at all is found
    """
    torch_versions = set()
    for line in torch_versions_raw.splitlines():
        line = line.strip()
        if not line:
            continue

        match = _REQUIRES_DIST_TORCH.match(line)
        if match is None:
            raise ValueError(f"unexpected Requires-Dist format: {line}")

        version = (int(match.group(1)), int(match.group(2)))
        if version in torch_versions:
            raise ValueError(f"duplicate torch version: {version}")
        torch_versions.add(version)

    if not torch_versions:
        # without this check, the range computation would fail with an
        # unhelpful IndexError
        raise ValueError("no `Requires-Dist: torch` information found")

    # tuples sort numerically, so 2.10 comes after 2.9
    return sorted(torch_versions)


def torch_requirement(torch_versions):
    """
    Build the ``Requires-Dist`` line covering all the given torch versions.

    :param torch_versions: sorted, non-empty list of ``(major, minor)`` tuples
    :return: ``Requires-Dist: torch >=MIN,<MAX`` where ``MAX`` is the minor
        release following the newest given version
    """
    min_major, min_minor = torch_versions[0]
    max_major, max_minor = torch_versions[-1]

    min_version = f"{min_major}.{min_minor}"
    max_version = f"{max_major}.{max_minor + 1}"
    return f"Requires-Dist: torch >={min_version},<{max_version}"


if __name__ == "__main__":
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <newline separated Requires-Dist lines>")

    print(torch_requirement(parse_torch_versions(sys.argv[1])))
45 |
--------------------------------------------------------------------------------
/scripts/pytest-dont-rewrite-torch.py:
--------------------------------------------------------------------------------
import sys

# imported so that `torch.autograd.gradcheck` is present in `sys.modules` below
import torch  # noqa


PYTEST_DONT_REWRITE = '"""PYTEST_DONT_REWRITE"""'


def add_pytest_dont_rewrite(path):
    """
    Prepend the ``PYTEST_DONT_REWRITE`` marker docstring to the Python source file
    at ``path``, unless the file already contains the marker.

    :param path: path to the Python source file to update
    :return: ``True`` if the file was rewritten, ``False`` if it already contained
        the marker
    """
    # Python source files are UTF-8 by default; being explicit avoids depending on
    # the locale encoding (e.g. cp1252 on Windows), which could fail to decode the
    # file or corrupt non-ASCII characters when writing it back.
    with open(path, encoding="utf-8") as fd:
        content = fd.read()

    if PYTEST_DONT_REWRITE in content:
        return False

    print(f"rewriting {path} to add PYTEST_DONT_REWRITE")
    with open(path, "w", encoding="utf-8") as fd:
        fd.write(PYTEST_DONT_REWRITE)
        fd.write("\n")
        fd.write(content)

    return True


if __name__ == "__main__":
    try:
        # `torch.autograd.gradcheck` is the `torch.autograd.gradcheck.gradcheck`
        # function, so we need to pick the module from `sys.modules`
        add_pytest_dont_rewrite(sys.modules["torch.autograd.gradcheck"].__file__)
    except Exception as e:
        # best-effort: report the failure and its cause, but do not fail
        print(
            "failed to add PYTEST_DONT_REWRITE to `torch.autograd.gradcheck`,",
            "tests are likely to fail",
            file=sys.stderr,
        )
        print(f"error: {e!r}", file=sys.stderr)
32 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
# This is not the actual setup.py for this project, see `python/vesin/setup.py` for it.
# Instead, this file is here to enable `pip install .` from a git checkout or `pip
# install git+https://...` without having to specify a subdirectory

import os
from pathlib import Path

from setuptools import setup


ROOT = os.path.realpath(os.path.dirname(__file__))


def _local_package_url(*parts):
    # `Path.as_uri()` builds a valid `file://` URL on every platform (including
    # Windows drive letters and backslashes), and percent-encodes characters such
    # as spaces which would otherwise break the PEP 508 direct reference.
    return Path(ROOT, *parts).as_uri()


setup(
    name="vesin-git",
    version="0.0.0",
    install_requires=[
        f"vesin @ {_local_package_url('python', 'vesin')}",
    ],
    extras_require={
        "torch": [
            f"vesin-torch @ {_local_package_url('python', 'vesin_torch')}",
        ]
    },
    packages=[],
)
25 |
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | [tox]
2 | # https://github.com/tox-dev/tox/issues/3238
3 | requires = tox==4.14.0
4 |
5 | envlist =
6 | lint
7 | tests
8 | torch-tests
9 |
10 | [testenv]
11 | lint-folders = python setup.py
12 | package = external
13 | package_env = build-vesin
14 |
15 | [testenv:build-vesin]
16 | passenv = *
17 | deps =
18 | cmake
19 | setuptools
20 | wheel
21 |
22 | commands =
23 | pip wheel python/vesin --no-deps --no-build-isolation --check-build-dependencies --wheel-dir {envtmpdir}/dist
24 |
25 | [testenv:tests]
26 | passenv = *
27 | description = Run the tests of the vesin Python package
28 | deps =
29 | ase
30 | pytest
31 |
32 | metatomic-torch >=0.1,<0.2
33 |
34 | changedir = python/vesin
35 | commands =
36 | pytest {posargs}
37 |
38 | # Enable when doc examples exist; pytest fails if no tests exist
39 | # pytest --doctest-modules --pyargs vesin
40 |
41 | [testenv:torch-tests]
42 | passenv = *
43 | description = Run the tests of the vesin-torch Python package
44 | deps =
45 | pytest
46 | torch
47 | metatensor-torch >=0.7,<0.8
48 | metatomic-torch >=0.1,<0.2
49 | numpy
50 |
51 | cmake
52 | setuptools
53 | wheel
54 |
55 | changedir = python/vesin_torch
56 | commands =
57 | pip install . --no-deps --no-build-isolation --check-build-dependencies
58 |
59 | # Make torch.autograd.gradcheck work with pytest
60 | python {toxinidir}/scripts/pytest-dont-rewrite-torch.py
61 |
62 | pytest {posargs}
63 | pytest --doctest-modules --pyargs vesin.torch
64 |
65 |
66 | [testenv:cxx-tests]
67 | passenv = *
68 | description = Run the C++ tests
69 | package = skip
70 | deps = cmake
71 |
72 | commands =
73 | cmake -B {envtmpdir} -S vesin -DVESIN_BUILD_TESTS=ON -DCMAKE_BUILD_TYPE=Debug
74 | cmake --build {envtmpdir} --config Debug
75 | ctest --test-dir {envtmpdir} --build-config Debug
76 |
77 |
78 | [testenv:lint]
79 | description = Run linters and formatter
80 | package = skip
81 | deps =
82 | ruff
83 |
84 | commands =
85 | ruff format --diff {[testenv]lint-folders}
86 | ruff check {[testenv]lint-folders}
87 |
88 |
89 | [testenv:format]
90 | description = Abuse tox to do actual formatting on all files.
91 | package = skip
92 | deps =
93 | ruff
94 | commands =
95 | ruff format {[testenv]lint-folders}
96 | ruff check --fix-only {[testenv]lint-folders}
97 |
98 |
99 | [testenv:docs]
100 | passenv = *
101 | description = Invoke sphinx-build to build the HTML docs
102 | deps =
103 | sphinx
104 | breathe >=4.33 # C++ => sphinx through doxygen
105 | furo # sphinx theme
106 | sphinx-design # helpers for nicer docs website (tabs, grids, cards, …)
107 |
108 | torch
109 | metatomic-torch >=0.1,<0.2
110 | cmake
111 | setuptools
112 | wheel
113 |
114 | commands =
115 | pip install python/vesin_torch --no-deps --no-build-isolation --check-build-dependencies
116 | sphinx-build -d docs/build/doctrees -W -b html docs/src docs/build/html
117 |
--------------------------------------------------------------------------------
/vesin/CMakeLists.txt:
--------------------------------------------------------------------------------
cmake_minimum_required(VERSION 3.16)

file(READ "VERSION" VESIN_VERSION)
string(STRIP "${VESIN_VERSION}" VESIN_VERSION)

project(vesin LANGUAGES C CXX VERSION ${VESIN_VERSION})

# Are we the top-level project, or included through `add_subdirectory()`?
# The paths are quoted so that directories containing spaces compare correctly.
if ("${CMAKE_SOURCE_DIR}" STREQUAL "${CMAKE_CURRENT_SOURCE_DIR}")
    set(VESIN_MAIN_PROJECT ON)
else()
    set(VESIN_MAIN_PROJECT OFF)
endif()

if (VESIN_MAIN_PROJECT)
    if("${CMAKE_BUILD_TYPE}" STREQUAL "" AND "${CMAKE_CONFIGURATION_TYPES}" STREQUAL "")
        message(STATUS "Setting build type to 'release' as none was specified.")
        set(
            CMAKE_BUILD_TYPE "release"
            CACHE STRING
            "Choose the type of build, options are: none(CMAKE_CXX_FLAGS or CMAKE_C_FLAGS used) debug release relwithdebinfo minsizerel."
            FORCE
        )
        set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS release debug relwithdebinfo minsizerel none)
    endif()
endif()

option(BUILD_SHARED_LIBS "Build shared libraries instead of static ones" OFF)
option(VESIN_BUILD_TESTS "Build and run Vesin's unit tests" OFF)
option(VESIN_INSTALL "Install Vesin's headers and libraries" ${VESIN_MAIN_PROJECT})
option(VESIN_TORCH "Build the vesin_torch library" OFF)

set(VESIN_SOURCES
    ${CMAKE_CURRENT_SOURCE_DIR}/src/vesin.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/src/cpu_cell_list.cpp
)
add_library(vesin ${VESIN_SOURCES})

# Headers come from the source tree while building, and from the `include`
# directory of the installation prefix once installed (see `install()` below).
target_include_directories(vesin PUBLIC
    $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
    $<INSTALL_INTERFACE:include>
)

target_compile_features(vesin PRIVATE cxx_std_17)

set_target_properties(vesin PROPERTIES
    # hide non-exported symbols by default
    CXX_VISIBILITY_PRESET hidden
    VISIBILITY_INLINES_HIDDEN ON
)

# VESIN_EXPORTS is only defined while building vesin itself, VESIN_SHARED is
# propagated to consumers so `vesin.h` selects the right VESIN_API definition
target_compile_definitions(vesin PRIVATE VESIN_EXPORTS)
if (BUILD_SHARED_LIBS)
    target_compile_definitions(vesin PUBLIC VESIN_SHARED)
endif()

if (VESIN_BUILD_TESTS)
    enable_testing()
    add_subdirectory(tests)
endif()

if (VESIN_TORCH)
    add_subdirectory(torch)
endif()

#------------------------------------------------------------------------------#
# Installation configuration
#------------------------------------------------------------------------------#
if (VESIN_INSTALL)
    install(TARGETS vesin
        ARCHIVE DESTINATION "lib"
        LIBRARY DESTINATION "lib"
        RUNTIME DESTINATION "bin"
    )

    install(FILES "include/vesin.h" DESTINATION "include")
endif()
77 |
--------------------------------------------------------------------------------
/vesin/VERSION:
--------------------------------------------------------------------------------
1 | 0.3.7
2 |
--------------------------------------------------------------------------------
/vesin/include/vesin.h:
--------------------------------------------------------------------------------
1 | #ifndef VESIN_H
2 | #define VESIN_H
3 |
4 | #include
5 | #include
6 |
7 | #if defined(VESIN_SHARED)
8 | #if defined(VESIN_EXPORTS)
9 | #if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__)
10 | #define VESIN_API __attribute__((visibility("default")))
11 | #elif defined(_MSC_VER)
12 | #define VESIN_API __declspec(dllexport)
13 | #else
14 | #define VESIN_API
15 | #endif
16 | #else
17 | #if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__)
18 | #define VESIN_API __attribute__((visibility("default")))
19 | #elif defined(_MSC_VER)
20 | #define VESIN_API __declspec(dllimport)
21 | #else
22 | #define VESIN_API
23 | #endif
24 | #endif
25 | #else
26 | #define VESIN_API
27 | #endif
28 |
29 | #ifdef __cplusplus
30 | extern "C" {
31 | #endif
32 |
/// Options for a neighbor list calculation
struct VesinOptions {
    /// Spherical cutoff, only pairs below this cutoff will be included. This
    /// must be a finite number larger than 1e-6.
    double cutoff;
    /// Should the returned neighbor list be a full list (include both `i -> j`
    /// and `j -> i` pairs) or a half list (include only `i -> j`)?
    bool full;
    /// Should the neighbor list be sorted? If yes, the returned pairs will be
    /// sorted using lexicographic order.
    bool sorted;

    /// Should the returned `VesinNeighborList` contain `shifts`?
    bool return_shifts;
    /// Should the returned `VesinNeighborList` contain `distances`?
    bool return_distances;
    /// Should the returned `VesinNeighborList` contain `vectors`?
    bool return_vectors;
};
51 |
/// Device on which the data used in a neighbor list calculation is allocated.
///
/// This is a typedef so that `VesinDevice` can be used as a type name from C
/// as well as from C++ (the rest of this header uses the bare name).
typedef enum VesinDevice {
    /// Unknown device, used for default initialization and to indicate no
    /// allocated data.
    VesinUnknownDevice = 0,
    /// CPU device
    VesinCPU = 1,
} VesinDevice;
60 |
61 |
62 | /// The actual neighbor list
63 | ///
64 | /// This is organized as a list of pairs, where each pair can contain the
65 | /// following data:
66 | ///
67 | /// - indices of the points in the pair;
68 | /// - distance between points in the pair, accounting for periodic boundary
69 | /// conditions;
70 | /// - vector between points in the pair, accounting for periodic boundary
71 | /// conditions;
72 | /// - periodic shift that created the pair. This is only relevant when using
73 | /// periodic boundary conditions, and contains the number of bounding box we
74 | /// need to cross to create the pair. If the positions of the points are `r_i`
75 | /// and `r_j`, the bounding box is described by a matrix of three vectors `H`,
76 | /// and the periodic shift is `S`, the distance vector for a given pair will
77 | /// be given by `r_ij = r_j - r_i + S @ H`.
78 | ///
79 | /// Under periodic boundary conditions, two atoms can be part of multiple pairs,
80 | /// each pair having a different periodic shift.
81 | struct VESIN_API VesinNeighborList {
82 | #ifdef __cplusplus
83 | VesinNeighborList():
84 | length(0),
85 | device(VesinUnknownDevice),
86 | pairs(nullptr),
87 | shifts(nullptr),
88 | distances(nullptr),
89 | vectors(nullptr)
90 | {}
91 | #endif
92 |
93 | /// Number of pairs in this neighbor list
94 | size_t length;
95 | /// Device used for the data allocations
96 | VesinDevice device;
97 | /// Array of pairs (storing the indices of the first and second point in the
98 | /// pair), containing `length` elements.
99 | size_t (*pairs)[2];
100 | /// Array of box shifts, one for each `pair`. This is only set if
101 | /// `options.return_pairs` was `true` during the calculation.
102 | int32_t (*shifts)[3];
103 | /// Array of pair distance (i.e. distance between the two points), one for
104 | /// each pair. This is only set if `options.return_distances` was `true`
105 | /// during the calculation.
106 | double *distances;
107 | /// Array of pair vector (i.e. vector between the two points), one for
108 | /// each pair. This is only set if `options.return_vector` was `true`
109 | /// during the calculation.
110 | double (*vectors)[3];
111 |
112 | // TODO: custom memory allocators?
113 | };
114 |
115 | /// Free all allocated memory inside a `VesinNeighborList`, according the it's
116 | /// `device`.
117 | void VESIN_API vesin_free(struct VesinNeighborList* neighbors);
118 |
119 | /// Compute a neighbor list.
120 | ///
121 | /// The data is returned in a `VesinNeighborList`. For an initial call, the
122 | /// `VesinNeighborList` should be zero-initialized (or default-initalized in
123 | /// C++). The `VesinNeighborList` can be re-used across calls to this functions
124 | /// to re-use memory allocations, and once it is no longer needed, users should
125 | /// call `vesin_free` to release the corresponding memory.
126 | ///
127 | /// @param points positions of all points in the system;
128 | /// @param n_points number of elements in the `points` array
129 | /// @param box bounding box for the system. If the system is non-periodic,
130 | /// this is ignored. This should contain the three vectors of the bounding
131 | /// box, one vector per row of the matrix.
132 | /// @param periodic is the system using periodic boundary conditions?
133 | /// @param device device where the `points` and `box` data is allocated.
134 | /// @param options options for the calculation
135 | /// @param neighbors non-NULL pointer to `VesinNeighborList` that will be used
136 | /// to store the computed list of neighbors.
137 | /// @param error_message Pointer to a `char*` that wil be set to the error
138 | /// message if this function fails. This does not need to be freed when no
139 | /// longer needed.
140 | int VESIN_API vesin_neighbors(
141 | const double (*points)[3],
142 | size_t n_points,
143 | const double box[3][3],
144 | bool periodic,
145 | VesinDevice device,
146 | struct VesinOptions options,
147 | struct VesinNeighborList* neighbors,
148 | const char** error_message
149 | );
150 |
151 |
152 | #ifdef __cplusplus
153 |
154 | } // extern "C"
155 |
156 | #endif
157 |
158 | #endif
159 |
--------------------------------------------------------------------------------
/vesin/src/cpu_cell_list.hpp:
--------------------------------------------------------------------------------
1 | #ifndef VESIN_CPU_CELL_LIST_HPP
2 | #define VESIN_CPU_CELL_LIST_HPP
3 |
4 | #include
5 |
6 | #include "vesin.h"
7 |
8 | #include "types.hpp"
9 |
10 | namespace vesin { namespace cpu {
11 |
12 | void free_neighbors(VesinNeighborList& neighbors);
13 |
14 | void neighbors(
15 | const Vector* points,
16 | size_t n_points,
17 | BoundingBox cell,
18 | VesinOptions options,
19 | VesinNeighborList& neighbors
20 | );
21 |
22 |
23 | /// The cell list is used to sort atoms inside bins/cells.
24 | ///
25 | /// The list of potential pairs is then constructed by looking through all
26 | /// neighboring cells (the number of cells to search depends on the cutoff and
27 | /// the size of the cells) for each atom to create pair candidates.
28 | class CellList {
29 | public:
30 | /// Create a new `CellList` for the given bounding box and cutoff,
31 | /// determining all required parameters.
32 | CellList(BoundingBox box, double cutoff);
33 |
34 | /// Add a single point to the cell list at the given `position`. The point
35 | /// is uniquely identified by its `index`.
36 | void add_point(size_t index, Vector position);
37 |
38 | /// Iterate over all possible pairs, calling the given callback every time
39 | template
40 | void foreach_pair(Function callback);
41 |
42 | private:
43 | /// How many cells do we need to look at when searching neighbors to include
44 | /// all neighbors below cutoff
45 | std::array n_search_;
46 |
47 | /// the cells themselves are a list of points & corresponding
48 | /// shift to place the point inside the cell
49 | struct Point {
50 | size_t index;
51 | CellShift shift;
52 | };
53 | struct Cell: public std::vector {};
54 |
55 | // raw data for the cells
56 | std::vector cells_;
57 | // shape of the cell array
58 | std::array cells_shape_;
59 |
60 | BoundingBox box_;
61 |
62 | Cell& get_cell(std::array index);
63 | };
64 |
65 | /// Wrapper around `VesinNeighborList` that behaves like a std::vector,
66 | /// automatically growing memory allocations.
67 | class GrowableNeighborList {
68 | public:
69 | VesinNeighborList& neighbors;
70 | size_t capacity;
71 | VesinOptions options;
72 |
73 | size_t length() const {
74 | return neighbors.length;
75 | }
76 |
77 | void increment_length() {
78 | neighbors.length += 1;
79 | }
80 |
81 | void set_pair(size_t index, size_t first, size_t second);
82 | void set_shift(size_t index, vesin::CellShift shift);
83 | void set_distance(size_t index, double distance);
84 | void set_vector(size_t index, vesin::Vector vector);
85 |
86 | // reset length to 0, and allocate/deallocate members of
87 | // `neighbors` according to `options`
88 | void reset();
89 |
90 | // allocate more memory & update capacity
91 | void grow();
92 |
93 | // sort the pairs currently in the neighbor list
94 | void sort();
95 | };
96 |
97 | } // namespace vesin
98 | } // namespace cpu
99 |
100 | #endif
101 |
--------------------------------------------------------------------------------
/vesin/src/math.hpp:
--------------------------------------------------------------------------------
1 | #ifndef VESIN_MATH_HPP
2 | #define VESIN_MATH_HPP
3 |
4 | #include
5 | #include
6 | #include
7 |
8 | namespace vesin {
9 | struct Vector;
10 |
11 | Vector operator*(Vector vector, double scalar);
12 |
13 | struct Vector: public std::array {
14 | double dot(Vector other) const {
15 | return (*this)[0] * other[0] + (*this)[1] * other[1] + (*this)[2] * other[2];
16 | }
17 |
18 | double norm() const {
19 | return std::sqrt(this->dot(*this));
20 | }
21 |
22 | Vector normalize() const {
23 | return *this * (1.0 / this->norm());
24 | }
25 |
26 | Vector cross(Vector other) const {
27 | return Vector{
28 | (*this)[1] * other[2] - (*this)[2] * other[1],
29 | (*this)[2] * other[0] - (*this)[0] * other[2],
30 | (*this)[0] * other[1] - (*this)[1] * other[0],
31 | };
32 | }
33 | };
34 |
35 | inline Vector operator+(Vector u, Vector v) {
36 | return Vector{
37 | u[0] + v[0],
38 | u[1] + v[1],
39 | u[2] + v[2],
40 | };
41 | }
42 |
43 | inline Vector operator-(Vector u, Vector v) {
44 | return Vector{
45 | u[0] - v[0],
46 | u[1] - v[1],
47 | u[2] - v[2],
48 | };
49 | }
50 |
51 | inline Vector operator*(double scalar, Vector vector) {
52 | return Vector{
53 | scalar * vector[0],
54 | scalar * vector[1],
55 | scalar * vector[2],
56 | };
57 | }
58 |
59 | inline Vector operator*(Vector vector, double scalar) {
60 | return Vector{
61 | scalar * vector[0],
62 | scalar * vector[1],
63 | scalar * vector[2],
64 | };
65 | }
66 |
67 |
68 | struct Matrix: public std::array, 3> {
69 | double determinant() const {
70 | return (*this)[0][0] * ((*this)[1][1] * (*this)[2][2] - (*this)[2][1] * (*this)[1][2])
71 | - (*this)[0][1] * ((*this)[1][0] * (*this)[2][2] - (*this)[1][2] * (*this)[2][0])
72 | + (*this)[0][2] * ((*this)[1][0] * (*this)[2][1] - (*this)[1][1] * (*this)[2][0]);
73 | }
74 |
75 | Matrix inverse() const {
76 | auto det = this->determinant();
77 |
78 | if (std::abs(det) < 1e-30) {
79 | throw std::runtime_error("this matrix is not invertible");
80 | }
81 |
82 | auto inverse = Matrix();
83 | inverse[0][0] = ((*this)[1][1] * (*this)[2][2] - (*this)[2][1] * (*this)[1][2]) / det;
84 | inverse[0][1] = ((*this)[0][2] * (*this)[2][1] - (*this)[0][1] * (*this)[2][2]) / det;
85 | inverse[0][2] = ((*this)[0][1] * (*this)[1][2] - (*this)[0][2] * (*this)[1][1]) / det;
86 | inverse[1][0] = ((*this)[1][2] * (*this)[2][0] - (*this)[1][0] * (*this)[2][2]) / det;
87 | inverse[1][1] = ((*this)[0][0] * (*this)[2][2] - (*this)[0][2] * (*this)[2][0]) / det;
88 | inverse[1][2] = ((*this)[1][0] * (*this)[0][2] - (*this)[0][0] * (*this)[1][2]) / det;
89 | inverse[2][0] = ((*this)[1][0] * (*this)[2][1] - (*this)[2][0] * (*this)[1][1]) / det;
90 | inverse[2][1] = ((*this)[2][0] * (*this)[0][1] - (*this)[0][0] * (*this)[2][1]) / det;
91 | inverse[2][2] = ((*this)[0][0] * (*this)[1][1] - (*this)[1][0] * (*this)[0][1]) / det;
92 | return inverse;
93 | }
94 | };
95 |
96 |
97 | inline Vector operator*(Matrix matrix, Vector vector) {
98 | return Vector{
99 | matrix[0][0] * vector[0] + matrix[0][1] * vector[1] + matrix[0][2] * vector[2],
100 | matrix[1][0] * vector[0] + matrix[1][1] * vector[1] + matrix[1][2] * vector[2],
101 | matrix[2][0] * vector[0] + matrix[2][1] * vector[1] + matrix[2][2] * vector[2],
102 | };
103 | }
104 |
105 | inline Vector operator*(Vector vector, Matrix matrix) {
106 | return Vector{
107 | vector[0] * matrix[0][0] + vector[1] * matrix[1][0] + vector[2] * matrix[2][0],
108 | vector[0] * matrix[0][1] + vector[1] * matrix[1][1] + vector[2] * matrix[2][1],
109 | vector[0] * matrix[0][2] + vector[1] * matrix[1][2] + vector[2] * matrix[2][2],
110 | };
111 | }
112 |
113 | } // namespace vesin
114 |
115 | #endif
116 |
--------------------------------------------------------------------------------
/vesin/src/types.hpp:
--------------------------------------------------------------------------------
1 | #ifndef VESIN_TYPES_HPP
2 | #define VESIN_TYPES_HPP
3 |
4 | #include "math.hpp"
5 |
6 | namespace vesin {
7 |
8 | class BoundingBox {
9 | public:
10 | BoundingBox(Matrix matrix, bool periodic): matrix_(matrix), periodic_(periodic) {
11 | if (periodic) {
12 | auto det = matrix_.determinant();
13 | if (std::abs(det) < 1e-30) {
14 | throw std::runtime_error("the box matrix is not invertible");
15 | }
16 |
17 | this->inverse_ = matrix_.inverse();
18 | } else {
19 | this->matrix_ = Matrix{{{
20 | {{1, 0, 0}},
21 | {{0, 1, 0}},
22 | {{0, 0, 1}}
23 | }}};
24 | this->inverse_ = matrix_;
25 | }
26 | }
27 |
28 | const Matrix& matrix() const {
29 | return this->matrix_;
30 | }
31 |
32 | bool periodic() const {
33 | return this->periodic_;
34 | }
35 |
36 | /// Convert a vector from cartesian coordinates to fractional coordinates
37 | Vector cartesian_to_fractional(Vector cartesian) const {
38 | return cartesian * inverse_;
39 | }
40 |
41 | /// Convert a vector from fractional coordinates to cartesian coordinates
42 | Vector fractional_to_cartesian(Vector fractional) const {
43 | return fractional * matrix_;
44 | }
45 |
46 | /// Get the three distances between faces of the bounding box
47 | Vector distances_between_faces() const {
48 | auto a = Vector{matrix_[0]};
49 | auto b = Vector{matrix_[1]};
50 | auto c = Vector{matrix_[2]};
51 |
52 | // Plans normal vectors
53 | auto na = b.cross(c).normalize();
54 | auto nb = c.cross(a).normalize();
55 | auto nc = a.cross(b).normalize();
56 |
57 | return Vector{
58 | std::abs(na.dot(a)),
59 | std::abs(nb.dot(b)),
60 | std::abs(nc.dot(c)),
61 | };
62 | }
63 |
64 | private:
65 | Matrix matrix_;
66 | Matrix inverse_;
67 | bool periodic_;
68 | };
69 |
70 |
71 | /// A cell shift represents the displacement along cell axis between the actual
72 | /// position of an atom and a periodic image of this atom.
73 | ///
74 | /// The cell shift can be used to reconstruct the vector between two points,
75 | /// wrapped inside the unit cell.
76 | struct CellShift: public std::array {
77 | /// Compute the shift vector in cartesian coordinates, using the given cell
78 | /// matrix (stored in row major order).
79 | Vector cartesian(Matrix cell) const {
80 | auto vector = Vector{
81 | static_cast((*this)[0]),
82 | static_cast((*this)[1]),
83 | static_cast((*this)[2]),
84 | };
85 | return vector * cell;
86 | }
87 | };
88 |
89 | inline CellShift operator+(CellShift a, CellShift b) {
90 | return CellShift{
91 | a[0] + b[0],
92 | a[1] + b[1],
93 | a[2] + b[2],
94 | };
95 | }
96 |
97 | inline CellShift operator-(CellShift a, CellShift b) {
98 | return CellShift{
99 | a[0] - b[0],
100 | a[1] - b[1],
101 | a[2] - b[2],
102 | };
103 | }
104 |
105 |
106 | } // namespace vesin
107 |
108 | #endif
109 |
--------------------------------------------------------------------------------
/vesin/src/vesin.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 |
4 | #include "vesin.h"
5 |
6 | #include "cpu_cell_list.hpp"
7 |
8 | thread_local std::string LAST_ERROR;
9 |
10 | extern "C" int vesin_neighbors(
11 | const double (*points)[3],
12 | size_t n_points,
13 | const double box[3][3],
14 | bool periodic,
15 | VesinDevice device,
16 | VesinOptions options,
17 | VesinNeighborList* neighbors,
18 | const char** error_message
19 | ) {
20 | if (error_message == nullptr) {
21 | return EXIT_FAILURE;
22 | }
23 |
24 | if (points == nullptr) {
25 | *error_message = "`points` can not be a NULL pointer";
26 | return EXIT_FAILURE;
27 | }
28 |
29 | if (box == nullptr) {
30 | *error_message = "`cell` can not be a NULL pointer";
31 | return EXIT_FAILURE;
32 | }
33 |
34 | if (neighbors == nullptr) {
35 | *error_message = "`neighbors` can not be a NULL pointer";
36 | return EXIT_FAILURE;
37 | }
38 |
39 | if (!std::isfinite(options.cutoff) || options.cutoff <= 0) {
40 | *error_message = "cutoff must be a finite, positive number";
41 | return EXIT_FAILURE;
42 | }
43 |
44 | if (options.cutoff <= 1e-6) {
45 | *error_message = "cutoff is too small";
46 | return EXIT_FAILURE;
47 | }
48 |
49 | if (neighbors->device != VesinUnknownDevice && neighbors->device != device) {
50 | *error_message = "`neighbors` device and data `device` do not match, free the neighbors first";
51 | return EXIT_FAILURE;
52 | }
53 |
54 | if (device == VesinUnknownDevice) {
55 | *error_message = "got an unknown device to use when running simulation";
56 | return EXIT_FAILURE;
57 | }
58 |
59 | if (neighbors->device == VesinUnknownDevice) {
60 | // initialize the device
61 | neighbors->device = device;
62 | } else if (neighbors->device != device) {
63 | *error_message = "`neighbors.device` and `device` do not match, free the neighbors first";
64 | return EXIT_FAILURE;
65 | }
66 |
67 | try {
68 | if (device == VesinCPU) {
69 | auto matrix = vesin::Matrix{{{
70 | {{box[0][0], box[0][1], box[0][2]}},
71 | {{box[1][0], box[1][1], box[1][2]}},
72 | {{box[2][0], box[2][1], box[2][2]}},
73 | }}};
74 |
75 | vesin::cpu::neighbors(
76 | reinterpret_cast(points),
77 | n_points,
78 | vesin::BoundingBox(matrix, periodic),
79 | options,
80 | *neighbors
81 | );
82 | } else {
83 | throw std::runtime_error("unknown device " + std::to_string(device));
84 | }
85 | } catch (const std::bad_alloc&) {
86 | LAST_ERROR = "failed to allocate memory";
87 | *error_message = LAST_ERROR.c_str();
88 | return EXIT_FAILURE;
89 | } catch (const std::exception& e) {
90 | LAST_ERROR = e.what();
91 | *error_message = LAST_ERROR.c_str();
92 | return EXIT_FAILURE;
93 | } catch (...) {
94 | *error_message = "fatal error: unknown type thrown as exception";
95 | return EXIT_FAILURE;
96 | }
97 |
98 | return EXIT_SUCCESS;
99 | }
100 |
101 |
102 | extern "C" void vesin_free(VesinNeighborList* neighbors) {
103 | if (neighbors == nullptr) {
104 | return;
105 | }
106 |
107 | if (neighbors->device == VesinUnknownDevice) {
108 | // nothing to do
109 | } else if (neighbors->device == VesinCPU) {
110 | vesin::cpu::free_neighbors(*neighbors);
111 | }
112 |
113 | std::memset(neighbors, 0, sizeof(VesinNeighborList));
114 | }
115 |
--------------------------------------------------------------------------------
/vesin/tests/CMakeLists.txt:
--------------------------------------------------------------------------------
include(FetchContent)

# Let normal variables (such as CATCH_CONFIG_FAST_COMPILE below) override the
# options declared with `option()` by the fetched projects
set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)

FetchContent_Declare(Catch2
    GIT_REPOSITORY https://github.com/catchorg/Catch2.git
    GIT_TAG v3.5.3
)

set(CATCH_CONFIG_FAST_COMPILE ON)
FetchContent_MakeAvailable(Catch2)

# When valgrind is available, every test runs under memcheck and any reported
# error (including leaks) makes the test exit with code 125.
find_program(VALGRIND valgrind)
if (VALGRIND)
    message(STATUS "Running tests using valgrind")
    set(TEST_COMMAND
        "${VALGRIND}" "--tool=memcheck" "--dsymutil=yes" "--error-exitcode=125"
        "--leak-check=full" "--show-leak-kinds=definite,indirect,possible" "--track-origins=yes"
        "--gen-suppressions=all"
    )
else()
    set(TEST_COMMAND "")
endif()


# Build one executable per `*.cpp` file in this directory and register it as
# a CTest test with the same name.
file(GLOB ALL_TESTS *.cpp)
foreach(_file_ ${ALL_TESTS})
    get_filename_component(_name_ ${_file_} NAME_WE)
    add_executable(${_name_} ${_file_})
    target_link_libraries(${_name_} vesin Catch2WithMain)

    # `$<TARGET_FILE:...>` expands to the full path of the built executable,
    # prefixed by the (possibly empty) valgrind command
    add_test(
        NAME ${_name_}
        COMMAND ${TEST_COMMAND} $<TARGET_FILE:${_name_}>
    )
endforeach()
38 |
--------------------------------------------------------------------------------
/vesin/tests/memory.cpp:
--------------------------------------------------------------------------------
1 | #include
2 |
3 | #include
4 |
5 |
6 | TEST_CASE("Re-use allocations") {
7 | double points[][3] = {
8 | {0.0, 0.0, 0.0},
9 | };
10 | size_t n_points = 1;
11 |
12 | double box[3][3] = {
13 | {0.0, 1.5, 1.5},
14 | {1.5, 0.0, 1.5},
15 | {1.5, 1.5, 0.0},
16 | };
17 | bool periodic = true;
18 |
19 | VesinNeighborList neighbors;
20 |
21 | auto compute_neighbors = [&](VesinOptions options) {
22 | const char* error_message = nullptr;
23 | auto status = vesin_neighbors(
24 | points,
25 | n_points,
26 | box,
27 | periodic,
28 | VesinDevice::VesinCPU,
29 | options,
30 | &neighbors,
31 | &error_message
32 | );
33 | REQUIRE(status == EXIT_SUCCESS);
34 | REQUIRE(error_message == nullptr);
35 | };
36 |
37 | auto options = VesinOptions();
38 | options.cutoff = 3.4;
39 | options.full = false;
40 | options.return_shifts = false;
41 | options.return_distances = true;
42 | options.return_vectors = false;
43 |
44 | compute_neighbors(options);
45 |
46 | CHECK(neighbors.length == 9);
47 | CHECK(neighbors.pairs != nullptr);
48 | CHECK(neighbors.shifts == nullptr);
49 | CHECK(neighbors.distances != nullptr);
50 | CHECK(neighbors.vectors == nullptr);
51 |
52 | /***************************************************/
53 | options.cutoff = 6;
54 | compute_neighbors(options);
55 |
56 | CHECK(neighbors.length == 67);
57 | CHECK(neighbors.pairs != nullptr);
58 | CHECK(neighbors.shifts == nullptr);
59 | CHECK(neighbors.distances != nullptr);
60 | CHECK(neighbors.vectors == nullptr);
61 |
62 | /***************************************************/
63 | options.full = true;
64 | compute_neighbors(options);
65 |
66 | CHECK(neighbors.length == 67 * 2);
67 | CHECK(neighbors.pairs != nullptr);
68 | CHECK(neighbors.shifts == nullptr);
69 | CHECK(neighbors.distances != nullptr);
70 | CHECK(neighbors.vectors == nullptr);
71 |
72 | /***************************************************/
73 | options.cutoff = 4.5;
74 | options.full = false;
75 | options.return_shifts = true;
76 | options.return_distances = false;
77 | compute_neighbors(options);
78 |
79 | CHECK(neighbors.length == 27);
80 | CHECK(neighbors.pairs != nullptr);
81 | CHECK(neighbors.shifts != nullptr);
82 | CHECK(neighbors.distances == nullptr);
83 | CHECK(neighbors.vectors == nullptr);
84 |
85 |
86 | vesin_free(&neighbors);
87 | }
88 |
--------------------------------------------------------------------------------
/vesin/tests/neighbors.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 |
4 | #include
5 | #include
6 |
7 | using namespace Catch::Matchers;
8 |
9 | #include
10 |
// Check that `a` is within 4 ULPs (units in the last place) of `b`.
// NOTE(review): the expansion already ends with `;`, so a call written as
// `CHECK_APPROX_EQUAL(a, b);` produces an extra empty statement.
11 | #define CHECK_APPROX_EQUAL(a, b) CHECK_THAT(a, WithinULP(b, 4));
12 |
// Test helper: run vesin_neighbors on the given system (CPU device) and
// compare the result with the expected values. Shifts, distances and vectors
// are only requested from vesin when the corresponding expected vector is
// non-empty; the pairs are only compared when `expected_pairs` is non-empty.
13 | static void check_neighbors(
14 | const double (*points)[3],
15 | size_t n_points,
16 | const double box[3][3],
17 | bool periodic,
18 | double cutoff,
19 | bool full_list,
20 | std::vector> expected_pairs,
21 | std::vector> expected_shifts,
22 | std::vector expected_distances,
23 | std::vector> expected_vectors
24 | ) {
25 | auto options = VesinOptions();
26 | options.cutoff = cutoff;
27 | options.full = full_list;
28 | options.return_shifts = !expected_shifts.empty();
29 | options.return_distances = !expected_distances.empty();
30 | options.return_vectors = !expected_vectors.empty();
31 |
// NOTE(review): `neighbors` is not explicitly initialized, but vesin_neighbors
// reads `neighbors.device` on entry. Unless VesinNeighborList has a default
// constructor (not visible here), prefer `VesinNeighborList neighbors{};`.
32 | VesinNeighborList neighbors;
33 |
34 | const char* error_message = nullptr;
35 | auto status = vesin_neighbors(
36 | points,
37 | n_points,
38 | box,
39 | periodic,
40 | VesinDevice::VesinCPU,
41 | options,
42 | &neighbors,
43 | &error_message
44 | );
45 | REQUIRE(status == EXIT_SUCCESS);
46 | REQUIRE(error_message == nullptr);
47 |
48 | if (!expected_pairs.empty()) {
49 | REQUIRE(neighbors.length == expected_pairs.size());
// NOTE(review): the next line is truncated in this dump. Everything after the
// loop condition of source line 50, up to source line 97, is missing: the rest
// of this helper and the beginning of the following test case (its name,
// `points`, `box` and the declaration of `expected_pairs`).
50 | for (size_t i=0; i>{
97 | {0, 1},
98 | {0, 2},
99 | {0, 3},
100 | {0, 4},
101 | {1, 3},
102 | {1, 4},
103 | {2, 3},
104 | {2, 4},
105 | {3, 4},
106 | };
107 |
108 | auto expected_distances = std::vector{
109 | 3.2082345612501593,
110 | 2.283282943482914,
111 | 2.4783286706972505,
112 | 1.215100818862369,
113 | 2.9707625283755013,
114 | 2.3059143522689647,
115 | 1.550639867925496,
116 | 2.9495550511899244,
117 | 2.6482573515427084,
118 | };
119 |
// non-periodic system, half neighbor list: only pairs and distances are checked
120 | check_neighbors(
121 | points, /*n_points=*/ 5, box, /*periodic=*/ false, /*cutoff=*/ 3.42, /*full_list=*/ false,
122 | expected_pairs, {}, expected_distances, {}
123 | );
124 | }
125 |
126 | TEST_CASE("FCC unit cell") {
127 | double points[][3] = {
128 | {0.0, 0.0, 0.0},
129 | };
130 |
131 | double box[3][3] = {
132 | {0.0, 1.5, 1.5},
133 | {1.5, 0.0, 1.5},
134 | {1.5, 1.5, 0.0},
135 | };
136 |
137 | auto expected_vectors = std::vector>{
138 | {1.5, 0.0, -1.5},
139 | {1.5, -1.5, 0.0},
140 | {0.0, 1.5, -1.5},
141 | {1.5, 1.5, 0.0},
142 | {1.5, 0.0, 1.5},
143 | {0.0, 1.5, 1.5},
144 | };
145 |
146 | auto expected_shifts = std::vector>{
147 | {-1, 0, 1},
148 | {-1, 1, 0},
149 | {0, -1, 1},
150 | {0, 0, 1},
151 | {0, 1, 0},
152 | {1, 0, 0},
153 | };
154 |
155 | auto expected_pairs = std::vector>(6, {0, 0});
156 | auto expected_distances = std::vector(6, 2.1213203435596424);
157 |
158 | check_neighbors(
159 | points, /*n_points=*/ 1, box, /*periodic=*/ true, /*cutoff=*/ 3.0, /*full_list=*/ false,
160 | expected_pairs, expected_shifts, expected_distances, expected_vectors
161 | );
162 | }
163 |
164 | TEST_CASE("Large box, small cutoff") {
165 | double points[][3] = {
166 | {0.0, 0.0, 0.0},
167 | {0.0, 2.0, 0.0},
168 | {0.0, 0.0, 2.0},
169 | // points outside the box natural boundaries
170 | {-6.0, 0.0, 0.0},
171 | {-6.0, -2.0, 0.0},
172 | {-6.0, 0.0, -2.0},
173 | };
174 |
175 | double box[3][3] = {
176 | {54.0, 0.0, 0.0},
177 | {0.0, 54.0, 0.0},
178 | {0.0, 0.0, 54.0},
179 | };
180 |
181 | auto expected_pairs = std::vector>{
182 | {0, 1},
183 | {0, 2},
184 | {3, 4},
185 | {3, 5},
186 | };
187 | auto expected_shifts = std::vector>(4, {0, 0, 0});
188 | auto expected_distances = std::vector(4, 2.0);
189 |
190 | check_neighbors(
191 | points, /*n_points=*/ 6, box, /*periodic=*/ true, /*cutoff=*/ 2.1, /*full_list=*/ false,
192 | expected_pairs, expected_shifts, expected_distances, {}
193 | );
194 | }
195 |
196 | TEST_CASE("Cutoff larger than the box size") {
197 | double points[][3] = {
198 | {0.0, 0.0, 0.0},
199 | };
200 |
201 | double box[3][3] = {
202 | {0.5, 0.0, 0.0},
203 | {0.0, 0.5, 0.0},
204 | {0.0, 0.0, 0.5},
205 | };
206 |
207 | auto expected_pairs = std::vector>(3, {0, 0});
208 | auto expected_distances = std::vector(3, 0.5);
209 | auto expected_shifts = std::vector>{
210 | {0, 0, 1},
211 | {0, 1, 0},
212 | {1, 0, 0},
213 | };
214 | auto expected_vectors = std::vector>{
215 | {0.0, 0.0, 0.5},
216 | {0.0, 0.5, 0.0},
217 | {0.5, 0.0, 0.0},
218 | };
219 |
220 | check_neighbors(
221 | points, /*n_points=*/ 1, box, /*periodic=*/ true, /*cutoff=*/ 0.6, /*full_list=*/ false,
222 | expected_pairs, expected_shifts, expected_distances, expected_vectors
223 | );
224 | }
225 |
// Slanted (non-orthogonal) periodic cell with a 6.4 cutoff: the half neighbor
// list must contain 90 pairs. The loop at the end presumably checks that pairs
// with each periodic shift listed in `previously_missing` are present (the
// loop body is truncated in this dump).
226 | TEST_CASE("Slanted box") {
227 | double points[][3] = {
228 | {1.42, 0.0, 0.0},
229 | {2.84, 0.0, 0.0},
230 | {3.55, -1.22975607, 0.0},
231 | {4.97, -1.22975607, 0.0},
232 | };
233 |
234 | double box[3][3] = {
235 | {4.26, -2.45951215, 0.0},
236 | {2.13, 1.22975607, 0.0},
237 | {0.0, 0.0, 50.0},
238 | };
239 |
240 | auto options = VesinOptions();
241 | options.cutoff = 6.4;
242 | options.full = false;
243 | options.return_shifts = true;
244 | options.return_distances = false;
245 | options.return_vectors = false;
246 |
// NOTE(review): `neighbors` is not explicitly initialized, but vesin_neighbors
// reads `neighbors.device` on entry. Unless VesinNeighborList has a default
// constructor (not visible here), prefer `VesinNeighborList neighbors{};`.
247 | VesinNeighborList neighbors;
248 |
249 | const char* error_message = nullptr;
250 | auto status = vesin_neighbors(
251 | points,
252 | /*n_points=*/ 4,
253 | box,
254 | /*periodic=*/ true,
255 | VesinDevice::VesinCPU,
256 | options,
257 | &neighbors,
258 | &error_message
259 | );
260 | REQUIRE(status == EXIT_SUCCESS);
261 |
262 | REQUIRE(neighbors.length == 90);
263 | auto previously_missing = std::vector>{
264 | {-2, 0, 0},
265 | {-2, 1, 0},
266 | {-2, 2, 0},
267 | };
268 |
269 | for (const auto& missing: previously_missing) {
270 | bool found = false;
// NOTE(review): the next line is truncated in this dump. The rest of this test
// and the beginning of /vesin/torch/CMakeLists.txt are missing, so it cannot be
// confirmed from here that `vesin_free(&neighbors)` is called.
271 | for (size_t i=0; i
13 | $
14 | $
15 | )
16 |
target_compile_features(vesin_torch PUBLIC cxx_std_17)

# Only symbols explicitly marked for export are visible outside of the
# library; all other symbols, including inline functions, are hidden.
set_target_properties(vesin_torch PROPERTIES
    CXX_VISIBILITY_PRESET hidden
    VISIBILITY_INLINES_HIDDEN ON
)

#------------------------------------------------------------------------------#
# Installation configuration
#------------------------------------------------------------------------------#
if (VESIN_INSTALL)
    # library artifacts (static archive, shared object, or Windows DLL)
    install(
        TARGETS vesin_torch
        ARCHIVE DESTINATION "lib"
        LIBRARY DESTINATION "lib"
        RUNTIME DESTINATION "bin"
    )

    # public C++ header
    install(
        FILES "include/vesin_torch.hpp"
        DESTINATION "include"
    )
endif()
37 |
--------------------------------------------------------------------------------
/vesin/torch/include/vesin_torch.hpp:
--------------------------------------------------------------------------------
1 | #ifndef VESIN_TORCH_HPP
2 | #define VESIN_TORCH_HPP
3 |
4 | #include
5 |
6 | struct VesinNeighborList;
7 |
8 | namespace vesin_torch {
9 |
10 | class NeighborListHolder;
11 |
12 | /// `NeighborListHolder` should be manipulated through a `torch::intrusive_ptr`
13 | using NeighborList = torch::intrusive_ptr;
14 |
15 | /// Neighbor list calculator compatible with TorchScript
16 | class NeighborListHolder: public torch::CustomClassHolder {
17 | public:
18 | /// Create a new calculator with the given `cutoff`.
19 | ///
20 | /// @param full_list whether pairs should be included twice in the output
21 | /// (both as `i-j` and `j-i`) or only once
22 | /// @param sorted whether pairs should be sorted in the output
23 | NeighborListHolder(double cutoff, bool full_list, bool sorted = false);
24 | ~NeighborListHolder();
25 |
26 | /// Compute the neighbor list for the system defined by `positions`, `box`,
27 | /// and `periodic`; returning the requested `quantities`.
28 | ///
29 | /// `quantities` can contain any combination of the following values:
30 | ///
31 | /// - `"i"` to get the index of the first point in the pair
32 | /// - `"j"` to get the index of the second point in the pair
33 | /// - `"P"` to get the indexes of the two points in the pair simultaneously
34 | /// - `"S"` to get the periodic shift of the pair
35 | /// - `"d"` to get the distance between points in the pair
36 | /// - `"D"` to get the distance vector between points in the pair
37 | ///
38 | /// @param points positions of all points in the system
39 | /// @param box bounding box of the system
40 | /// @param periodic should we use periodic boundary conditions?
41 | /// @param quantities quantities to return, defaults to "ij"
42 | /// @param copy should we copy the returned quantities, defaults to `true`.
43 | /// Setting this to `False` might be a bit faster, but the returned
44 | /// tensors are view inside this class, and will be invalidated
45 | /// whenever this class is garbage collected or used to run a new
46 | /// calculation.
47 | ///
48 | /// @returns a list of `torch::Tensor` as indicated by `quantities`.
49 | std::vector compute(
50 | torch::Tensor points,
51 | torch::Tensor box,
52 | bool periodic,
53 | std::string quantities,
54 | bool copy=true
55 | );
56 | private:
57 | double cutoff_;
58 | bool full_list_;
59 | bool sorted_;
60 | VesinNeighborList* data_;
61 | };
62 |
63 |
64 | }
65 |
66 | #endif
67 |
--------------------------------------------------------------------------------
/vesin/torch/src/vesin_torch.cpp:
--------------------------------------------------------------------------------
1 | #include
2 |
3 | #include
4 |
5 | #include
6 |
7 | #include "vesin_torch.hpp"
8 |
9 | using namespace vesin_torch;
10 |
11 | static VesinDevice torch_to_vesin_device(torch::Device device) {
12 | if (device.is_cpu()) {
13 | return VesinCPU;
14 | } else {
15 | throw std::runtime_error("device " + device.str() + " is not supported in vesin");
16 | }
17 | }
18 |
19 | static torch::Device vesin_to_torch_device(VesinDevice device) {
20 | if (device == VesinCPU) {
21 | return torch::Device("cpu");
22 | } else {
23 | throw std::runtime_error("vesin device is not supported in torch");
24 | }
25 | }
26 |
27 | /// Custom autograd function that only registers a custom backward corresponding
28 | /// to the neighbors list calculation
29 | class AutogradNeighbors: public torch::autograd::Function {
30 | public:
31 | static std::vector forward(
32 | torch::autograd::AutogradContext* ctx,
33 | torch::Tensor points,
34 | torch::Tensor box,
35 | bool periodic,
36 | torch::Tensor pairs,
37 | torch::optional shifts,
38 | torch::optional distances,
39 | torch::optional vectors
40 | );
41 |
42 | static std::vector backward(
43 | torch::autograd::AutogradContext* ctx,
44 | std::vector outputs_grad
45 | );
46 | };
47 |
48 | NeighborListHolder::NeighborListHolder(double cutoff, bool full_list, bool sorted):
49 | cutoff_(cutoff),
50 | full_list_(full_list),
51 | sorted_(sorted),
52 | data_(nullptr)
53 | {
54 | data_ = new VesinNeighborList();
55 | }
56 |
57 | NeighborListHolder::~NeighborListHolder() {
58 | vesin_free(data_);
59 | delete data_;
60 | }
61 |
62 | std::vector NeighborListHolder::compute(
63 | torch::Tensor points,
64 | torch::Tensor box,
65 | bool periodic,
66 | std::string quantities,
67 | bool copy
68 | ) {
69 | // check input data
70 | if (points.device() != box.device()) {
71 | C10_THROW_ERROR(ValueError,
72 | "expected `points` and `box` to have the same device, got " +
73 | points.device().str() + " and " + box.device().str()
74 | );
75 | }
76 | auto device = torch_to_vesin_device(points.device());
77 |
78 | if (points.scalar_type() != box.scalar_type()) {
79 | C10_THROW_ERROR(ValueError,
80 | std::string("expected `points` and `box` to have the same dtype, got ") +
81 | torch::toString(points.scalar_type()) + " and " +
82 | torch::toString(box.scalar_type())
83 | );
84 | }
85 | if (points.scalar_type() != torch::kFloat64) {
86 | C10_THROW_ERROR(ValueError,
87 | "only float64 dtype is supported in vesin"
88 | );
89 | }
90 |
91 | if (points.sizes().size() != 2 || points.size(1) != 3) {
92 | std::ostringstream oss;
93 | oss << "`points` must be n x 3 tensor, but the shape is " << points.sizes();
94 | C10_THROW_ERROR(ValueError, oss.str());
95 | }
96 |
97 | if (box.sizes().size() != 2 || box.size(0) != 3 || box.size(1) != 3) {
98 | std::ostringstream oss;
99 | oss << "`box` must be 3 x 3 tensor, but the shape is " << points.sizes();
100 | C10_THROW_ERROR(ValueError, oss.str());
101 | }
102 |
103 | if (!periodic) {
104 | box = torch::zeros({3, 3}, points.options());
105 | }
106 |
107 | // create calculation options
108 | auto n_points = static_cast(points.size(0));
109 |
110 | if (data_->device != VesinUnknownDevice && data_->device != device) {
111 | vesin_free(data_);
112 | std::memset(data_, 0, sizeof(VesinNeighborList));
113 | }
114 |
115 | auto return_shifts = quantities.find('S') != std::string::npos;
116 | if (box.requires_grad()) {
117 | return_shifts = true;
118 | }
119 |
120 | auto return_distances = quantities.find('d') != std::string::npos;
121 | auto return_vectors = quantities.find('D') != std::string::npos;
122 | if ((points.requires_grad() || box.requires_grad()) && (return_distances || return_vectors)) {
123 | // gradients requires both distances & vectors data to be present
124 | return_distances = true;
125 | return_vectors = true;
126 | }
127 |
128 | auto options = VesinOptions{
129 | /*cutoff=*/ this->cutoff_,
130 | /*full=*/ this->full_list_,
131 | /*sorted=*/ this->sorted_,
132 | /*return_shifts=*/ return_shifts,
133 | /*return_distances=*/ return_distances,
134 | /*return_vectors=*/ return_vectors,
135 | };
136 |
137 | if (!points.is_contiguous()) {
138 | points = points.contiguous();
139 | }
140 |
141 | if (!box.is_contiguous()) {
142 | box = box.contiguous();
143 | }
144 |
145 | const char* error_message = nullptr;
146 | auto status = vesin_neighbors(
147 | reinterpret_cast(points.data_ptr()),
148 | n_points,
149 | reinterpret_cast(box.data_ptr()),
150 | periodic,
151 | device,
152 | options,
153 | data_,
154 | &error_message
155 | );
156 |
157 | if (status != EXIT_SUCCESS) {
158 | throw std::runtime_error(std::string("failed to compute neighbors: ") + error_message);
159 | }
160 |
161 | // wrap vesin data in tensors
162 | auto size_t_options = torch::TensorOptions().device(vesin_to_torch_device(data_->device));
163 | if (sizeof(size_t) == sizeof(uint32_t)) {
164 | size_t_options = size_t_options.dtype(torch::kUInt32);
165 | } else if (sizeof(size_t) == sizeof(uint64_t)) {
166 | size_t_options = size_t_options.dtype(torch::kUInt64);
167 | } else {
168 | C10_THROW_ERROR(ValueError,
169 | "could not determine torch dtype matching size_t"
170 | );
171 | }
172 |
173 | auto pairs = torch::from_blob(
174 | data_->pairs,
175 | {static_cast(data_->length), 2},
176 | size_t_options
177 | ).to(torch::kInt64);
178 |
179 | auto shifts = torch::Tensor();
180 | if (data_->shifts != nullptr) {
181 | auto int32_options = torch::TensorOptions()
182 | .device(vesin_to_torch_device(data_->device))
183 | .dtype(torch::kInt32);
184 |
185 | shifts = torch::from_blob(
186 | data_->shifts,
187 | {static_cast(data_->length), 3},
188 | int32_options
189 | );
190 |
191 | if (copy) {
192 | shifts = shifts.clone();
193 | }
194 | }
195 |
196 | auto double_options = torch::TensorOptions()
197 | .device(vesin_to_torch_device(data_->device))
198 | .dtype(torch::kDouble);
199 |
200 | auto distances = torch::Tensor();
201 | if (data_->distances != nullptr) {
202 | distances = torch::from_blob(
203 | data_->distances,
204 | {static_cast(data_->length)},
205 | double_options
206 | );
207 |
208 | if (copy) {
209 | distances = distances.clone();
210 | }
211 | }
212 |
213 | auto vectors = torch::Tensor();
214 | if (data_->vectors != nullptr) {
215 | vectors = torch::from_blob(
216 | data_->vectors,
217 | {static_cast(data_->length), 3},
218 | double_options
219 | );
220 |
221 | if (copy) {
222 | vectors = vectors.clone();
223 | }
224 | }
225 |
226 | // handle autograd
227 | if ((return_distances || return_vectors)) {
228 | // we use optional for these three because otherwise torch autograd
229 | // tries to access data inside the undefined `torch::Tensor()`.
230 | torch::optional shifts_optional = torch::nullopt;
231 | if (shifts.defined()) {
232 | shifts_optional = shifts;
233 | }
234 |
235 | torch::optional distances_optional = torch::nullopt;
236 | if (distances.defined()) {
237 | distances_optional = distances;
238 | }
239 |
240 | torch::optional vectors_optional = torch::nullopt;
241 | if (vectors.defined()) {
242 | vectors_optional = vectors;
243 | }
244 |
245 | auto outputs = AutogradNeighbors::apply(
246 | points,
247 | box,
248 | periodic,
249 | pairs,
250 | shifts_optional,
251 | distances_optional,
252 | vectors_optional
253 | );
254 |
255 | if (return_distances && return_vectors) {
256 | distances = outputs[0];
257 | vectors = outputs[1];
258 | } else if (return_distances) {
259 | distances = outputs[0];
260 | } else {
261 | assert(return_vectors);
262 | vectors = outputs[0];
263 | }
264 | }
265 |
266 | // assemble the output
267 | auto output = std::vector();
268 | for (auto c: quantities) {
269 | if (c == 'i') {
270 | output.push_back(pairs.index({torch::indexing::Slice(), 0}));
271 | } else if (c == 'j') {
272 | output.push_back(pairs.index({torch::indexing::Slice(), 1}));
273 | } else if (c == 'P') {
274 | output.push_back(pairs);
275 | } else if (c == 'S') {
276 | output.push_back(shifts);
277 | } else if (c == 'd') {
278 | output.push_back(distances);
279 | } else if (c == 'D') {
280 | output.push_back(vectors);
281 | } else {
282 | C10_THROW_ERROR(ValueError,
283 | "unexpected character in `quantities`: " + std::string(1, c)
284 | );
285 | }
286 | }
287 |
288 | return output;
289 | }
290 |
291 |
292 | TORCH_LIBRARY(vesin, m) {
293 | std::string DOCSTRING;
294 |
295 | m.class_("_NeighborList")
296 | .def(
297 | torch::init(), DOCSTRING,
298 | {torch::arg("cutoff"), torch::arg("full_list"), torch::arg("sorted") = false}
299 | )
300 | .def("compute", &NeighborListHolder::compute, DOCSTRING,
301 | {torch::arg("points"), torch::arg("box"), torch::arg("periodic"), torch::arg("quantities"), torch::arg("copy") = true}
302 | )
303 | ;
304 | }
305 |
306 | // ========================================================================== //
307 | // //
308 | // ========================================================================== //
309 |
310 | std::vector AutogradNeighbors::forward(
311 | torch::autograd::AutogradContext* ctx,
312 | torch::Tensor points,
313 | torch::Tensor box,
314 | bool periodic,
315 | torch::Tensor pairs,
316 | torch::optional shifts,
317 | torch::optional distances,
318 | torch::optional vectors
319 | ) {
320 | auto shifts_tensor = shifts.value_or(torch::Tensor());
321 | auto distances_tensor = distances.value_or(torch::Tensor());
322 | auto vectors_tensor = vectors.value_or(torch::Tensor());
323 |
324 | ctx->save_for_backward({points, box, pairs, shifts_tensor, distances_tensor, vectors_tensor});
325 | ctx->saved_data["periodic"] = periodic;
326 |
327 | auto return_distances = distances.has_value();
328 | auto return_vectors = vectors.has_value();
329 | ctx->saved_data["return_distances"] = return_distances;
330 | ctx->saved_data["return_vectors"] = return_vectors;
331 |
332 | // only return defined tensors to make sure torch can use `get_autograd_meta()`
333 | if (return_distances && return_vectors) {
334 | return {distances_tensor, vectors_tensor};
335 | } else if (return_distances) {
336 | return {distances_tensor};
337 | } else if (return_vectors) {
338 | return {vectors_tensor};
339 | } else {
340 | return {};
341 | }
342 | }
343 |
344 | std::vector AutogradNeighbors::backward(
345 | torch::autograd::AutogradContext* ctx,
346 | std::vector outputs_grad
347 | ) {
348 | auto saved_variables = ctx->get_saved_variables();
349 | auto points = saved_variables[0];
350 | auto box = saved_variables[1];
351 | auto periodic = ctx->saved_data["periodic"].toBool();
352 |
353 | auto pairs = saved_variables[2];
354 | auto shifts = saved_variables[3];
355 | auto distances = saved_variables[4];
356 | auto vectors = saved_variables[5];
357 |
358 | auto return_distances = ctx->saved_data["return_distances"].toBool();
359 | auto return_vectors = ctx->saved_data["return_vectors"].toBool();
360 |
361 | auto distances_grad = torch::Tensor();
362 | auto vectors_grad = torch::Tensor();
363 | if (return_distances && return_vectors) {
364 | distances_grad = outputs_grad[0];
365 | vectors_grad = outputs_grad[1];
366 | } else if (return_distances) {
367 | distances_grad = outputs_grad[0];
368 | } else if (return_vectors) {
369 | vectors_grad = outputs_grad[0];
370 | } else {
371 | // nothing to do
372 | return {
373 | torch::Tensor(),
374 | torch::Tensor(),
375 | torch::Tensor(),
376 | torch::Tensor(),
377 | torch::Tensor(),
378 | torch::Tensor(),
379 | torch::Tensor(),
380 | };
381 | }
382 |
383 | if (points.requires_grad() || box.requires_grad()) {
384 | // Do a first backward step from distances_grad to vectors_grad
385 | vectors_grad += distances_grad.index({torch::indexing::Slice(), torch::indexing::None})
386 | * vectors / distances.index({torch::indexing::Slice(), torch::indexing::None});
387 | }
388 |
389 | auto points_grad = torch::Tensor();
390 | if (points.requires_grad()) {
391 | points_grad = torch::zeros_like(points);
392 | points_grad = torch::index_add(
393 | points_grad,
394 | /*dim=*/0,
395 | /*index=*/pairs.index({torch::indexing::Slice(), 1}),
396 | /*source=*/vectors_grad,
397 | /*alpha=*/1.0
398 | );
399 | points_grad = torch::index_add(
400 | points_grad,
401 | /*dim=*/0,
402 | /*index=*/pairs.index({torch::indexing::Slice(), 0}),
403 | /*source=*/vectors_grad,
404 | /*alpha=*/-1.0
405 | );
406 | }
407 |
408 | auto box_grad = torch::Tensor();
409 | if (periodic && box.requires_grad()) {
410 | auto cell_shifts = shifts.to(box.scalar_type());
411 | box_grad = cell_shifts.t().matmul(vectors_grad);
412 | }
413 |
414 | return {
415 | points_grad,
416 | box_grad,
417 | torch::Tensor(),
418 | torch::Tensor(),
419 | torch::Tensor(),
420 | torch::Tensor(),
421 | torch::Tensor(),
422 | };
423 | }
424 |
--------------------------------------------------------------------------------
|