├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
├── benchmark
│   ├── benchmark.ipynb
│   ├── benchmark.py
│   ├── diamond_benchmark.png
│   └── profile_nl.py
├── pyproject.toml
├── requirements.txt
├── setup.cfg
├── setup.py
└── torch_nl
    ├── __init__.py
    ├── geometry.py
    ├── linked_cell.py
    ├── naive_impl.py
    ├── neighbor_list.py
    ├── test_nl.py
    ├── timer.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | *.prof
6 | *.code-workspace
7 | # C extensions
8 | *.so
9 | .DS_Store
10 | # Distribution / packaging
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | pip-wheel-metadata/
25 | share/python-wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | *.py,cover
52 | .hypothesis/
53 | .pytest_cache/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | .python-version
87 |
88 | # pipenv
89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
92 | # install all needed dependencies.
93 | #Pipfile.lock
94 |
95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
96 | __pypackages__/
97 |
98 | # Celery stuff
99 | celerybeat-schedule
100 | celerybeat.pid
101 |
102 | # SageMath parsed files
103 | *.sage.py
104 |
105 | # Environments
106 | .env
107 | .venv
108 | env/
109 | venv/
110 | ENV/
111 | env.bak/
112 | venv.bak/
113 |
114 | # Spyder project settings
115 | .spyderproject
116 | .spyproject
117 |
118 | # Rope project settings
119 | .ropeproject
120 |
121 | # mkdocs documentation
122 | /site
123 |
124 | # mypy
125 | .mypy_cache/
126 | .dmypy.json
127 | dmypy.json
128 |
129 | # Pyre type checker
130 | .pyre/
131 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 felixmusil
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include requirements.txt
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # torch_nl
2 |
3 | Provides PyTorch implementations of a naive (`compute_neighborlist_n2`) and a linked-cell (`compute_neighborlist`) neighbor list, both compatible with TorchScript.
4 |
5 | Their correctness is tested against ASE's implementation.
6 |
7 | Note that, contrary to ASE, atomic positions are assumed to be wrapped inside the unit cell.
8 | # How to
9 |
10 | ## Install with pip
11 |
12 | ```bash
13 | pip install torch-nl
14 | ```
15 |
16 | ## Use the neighbor list
17 |
18 | ```python
19 | from torch_nl import compute_neighborlist, ase2data
20 | from ase.build import bulk, molecule
21 |
22 | frames = [bulk("Si", "diamond", a=6, cubic=True), molecule("CH3CH2NH2")]
23 | pos, cell, pbc, batch, n_atoms = ase2data(frames)
24 | cutoff, self_interaction = 5.0, False  # cutoff radius; whether atoms count as their own neighbors
25 | mapping, batch_mapping, shifts_idx = compute_neighborlist(
26 | cutoff, pos, cell, pbc, batch, self_interaction
27 | )
28 | ```
29 |
30 | # Benchmarks
31 |
32 | ## Periodic structure
33 |
34 | ![Neighbor list timings for diamond silicon supercells](benchmark/diamond_benchmark.png)
35 |
--------------------------------------------------------------------------------
/benchmark/benchmark.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 14,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import matplotlib.pyplot as plt\n",
10 | "import torch\n",
11 | "import ase\n",
12 | "import numpy as np\n",
13 | "import scipy\n",
14 | "from ase.build import molecule, bulk, make_supercell\n",
15 | "from ase.neighborlist import neighbor_list\n",
16 | "import pandas as pd\n",
17 | "from tqdm.notebook import tqdm\n",
18 | "import seaborn as sns"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": 2,
24 | "metadata": {},
25 | "outputs": [],
26 | "source": [
27 | "from torch_nl import compute_neighborlist, compute_neighborlist_n2, ase2data\n",
28 | "from torch_nl.timer import timeit"
29 | ]
30 | },
31 | {
32 | "cell_type": "markdown",
33 | "metadata": {},
34 | "source": [
35 | "# Periodic systems"
36 | ]
37 | },
38 | {
39 | "cell_type": "markdown",
40 | "metadata": {},
41 | "source": [
42 | "## Diamond-structure silicon"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": 3,
48 | "metadata": {},
49 | "outputs": [],
50 | "source": [
51 | "frame = bulk('Si', 'diamond', a=4, cubic=True)\n",
52 | "aa = torch.arange(1, 6)\n",
53 | "Ps = torch.cartesian_prod(aa,aa,aa)\n",
54 | "Ps = Ps[torch.sort(Ps.sum(dim=1)).indices].to(torch.long).numpy()\n",
55 | "frames = []\n",
56 | "n_atoms = []\n",
57 | "for P in Ps:\n",
58 | " frames.append(make_supercell(frame, np.diag(P)))\n",
59 | " n_atoms.append(len(frames[-1]))\n",
60 | "n_atoms = np.array(n_atoms)"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": 4,
66 | "metadata": {},
67 | "outputs": [],
68 | "source": [
69 | "cutoff = 4"
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": 5,
75 | "metadata": {},
76 | "outputs": [
77 | {
78 | "data": {
79 | "application/json": {
80 | "ascii": false,
81 | "bar_format": null,
82 | "colour": null,
83 | "elapsed": 0.011367321014404297,
84 | "initial": 0,
85 | "n": 0,
86 | "ncols": null,
87 | "nrows": 54,
88 | "postfix": null,
89 | "prefix": "",
90 | "rate": null,
91 | "total": 125,
92 | "unit": "it",
93 | "unit_divisor": 1000,
94 | "unit_scale": false
95 | },
96 | "application/vnd.jupyter.widget-view+json": {
97 | "model_id": "ab749187b35f41579193cdd675218261",
98 | "version_major": 2,
99 | "version_minor": 0
100 | },
101 | "text/plain": [
102 | " 0%| | 0/125 [00:00, ?it/s]"
103 | ]
104 | },
105 | "metadata": {},
106 | "output_type": "display_data"
107 | }
108 | ],
109 | "source": [
110 | "tag = \"ASE\"\n",
111 | "datas = []\n",
112 | "for frame in tqdm(frames):\n",
113 | " timing = timeit(neighbor_list, ['ijS', frame, cutoff], tag=tag, warmup=1, nit=50)\n",
114 | " data = timing.dumps()\n",
115 | " i,j,S = neighbor_list('ijS', frame, cutoff)\n",
116 | " n_neighbor = np.bincount(i).mean()\n",
117 | " data.update(n_atom=len(frame), n_neighbor_per_atom_avg=int(n_neighbor))\n",
118 | " data.pop('samples')\n",
119 | " datas.append(data)\n",
120 | "# df = pd.DataFrame(datas)"
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "execution_count": 6,
126 | "metadata": {},
127 | "outputs": [
128 | {
129 | "data": {
130 | "application/json": {
131 | "ascii": false,
132 | "bar_format": null,
133 | "colour": null,
134 | "elapsed": 0.011538267135620117,
135 | "initial": 0,
136 | "n": 0,
137 | "ncols": null,
138 | "nrows": 54,
139 | "postfix": null,
140 | "prefix": "",
141 | "rate": null,
142 | "total": 3,
143 | "unit": "it",
144 | "unit_divisor": 1000,
145 | "unit_scale": false
146 | },
147 | "application/vnd.jupyter.widget-view+json": {
148 | "model_id": "7271363cbf3c488b8f304250ae5e5dae",
149 | "version_major": 2,
150 | "version_minor": 0
151 | },
152 | "text/plain": [
153 | " 0%| | 0/3 [00:00, ?it/s]"
154 | ]
155 | },
156 | "metadata": {},
157 | "output_type": "display_data"
158 | },
159 | {
160 | "data": {
161 | "application/json": {
162 | "ascii": false,
163 | "bar_format": null,
164 | "colour": null,
165 | "elapsed": 0.011441469192504883,
166 | "initial": 0,
167 | "n": 0,
168 | "ncols": null,
169 | "nrows": 54,
170 | "postfix": null,
171 | "prefix": "",
172 | "rate": null,
173 | "total": 125,
174 | "unit": "it",
175 | "unit_divisor": 1000,
176 | "unit_scale": false
177 | },
178 | "application/vnd.jupyter.widget-view+json": {
179 | "model_id": "64db2e509ee9433589ffa2fe314aec27",
180 | "version_major": 2,
181 | "version_minor": 0
182 | },
183 | "text/plain": [
184 | " 0%| | 0/125 [00:00, ?it/s]"
185 | ]
186 | },
187 | "metadata": {},
188 | "output_type": "display_data"
189 | },
190 | {
191 | "data": {
192 | "application/json": {
193 | "ascii": false,
194 | "bar_format": null,
195 | "colour": null,
196 | "elapsed": 0.009797096252441406,
197 | "initial": 0,
198 | "n": 0,
199 | "ncols": null,
200 | "nrows": 54,
201 | "postfix": null,
202 | "prefix": "",
203 | "rate": null,
204 | "total": 125,
205 | "unit": "it",
206 | "unit_divisor": 1000,
207 | "unit_scale": false
208 | },
209 | "application/vnd.jupyter.widget-view+json": {
210 | "model_id": "4af8f351b2d54b7e8a59d589d6ee0d79",
211 | "version_major": 2,
212 | "version_minor": 0
213 | },
214 | "text/plain": [
215 | " 0%| | 0/125 [00:00, ?it/s]"
216 | ]
217 | },
218 | "metadata": {},
219 | "output_type": "display_data"
220 | },
221 | {
222 | "data": {
223 | "application/json": {
224 | "ascii": false,
225 | "bar_format": null,
226 | "colour": null,
227 | "elapsed": 0.018729209899902344,
228 | "initial": 0,
229 | "n": 0,
230 | "ncols": null,
231 | "nrows": 54,
232 | "postfix": null,
233 | "prefix": "",
234 | "rate": null,
235 | "total": 125,
236 | "unit": "it",
237 | "unit_divisor": 1000,
238 | "unit_scale": false
239 | },
240 | "application/vnd.jupyter.widget-view+json": {
241 | "model_id": "a0754679bca646fe9edc4286a67f8004",
242 | "version_major": 2,
243 | "version_minor": 0
244 | },
245 | "text/plain": [
246 | " 0%| | 0/125 [00:00, ?it/s]"
247 | ]
248 | },
249 | "metadata": {},
250 | "output_type": "display_data"
251 | }
252 | ],
253 | "source": [
254 | "tags = [\n",
255 | " # \"torch_nl O(n^2) CPU\", \n",
256 | " \"torch_nl O(n^2) GPU\", \n",
257 | " \"torch_nl O(n) CPU\", \n",
258 | " \"torch_nl O(n) GPU\"\n",
259 | "]\n",
260 | "for tag in tqdm(tags):\n",
261 | " if \"CPU\" in tag:\n",
262 | " device = 'cpu'\n",
263 | " elif \"GPU\" in tag:\n",
264 | " device = 'cuda'\n",
265 | " \n",
266 | " if 'O(n^2)' in tag:\n",
267 | " nl_func = compute_neighborlist_n2\n",
268 | " elif 'O(n)' in tag:\n",
269 | " nl_func = compute_neighborlist\n",
270 | "\n",
271 | " for frame in tqdm(frames):\n",
272 | " pos, cell, pbc, batch, n_atoms = ase2data([frame], device=device)\n",
273 | " timing = timeit(nl_func, [cutoff, pos, cell, pbc, batch], tag=tag, warmup=10, nit=50)\n",
274 | " data = timing.dumps()\n",
275 | " data.pop('samples')\n",
276 | " mapping, mapping_batch, shifts_idx = nl_func(cutoff, pos, cell, pbc, batch)\n",
277 | " n_neighbor = np.bincount(mapping[0].cpu().numpy()).mean()\n",
278 | " data.update(n_atom=len(frame), n_neighbor_per_atom_avg=int(n_neighbor))\n",
279 | " datas.append(data)"
280 | ]
281 | },
282 | {
283 | "cell_type": "code",
284 | "execution_count": 7,
285 | "metadata": {},
286 | "outputs": [],
287 | "source": [
288 | "df = pd.DataFrame(datas)"
289 | ]
290 | },
291 | {
292 | "cell_type": "code",
293 | "execution_count": 8,
294 | "metadata": {},
295 | "outputs": [
296 | {
297 | "data": {
298 | "text/html": [
299 | "
\n",
300 | "\n",
313 | "
\n",
314 | " \n",
315 | " \n",
316 | " | \n",
317 | " tag | \n",
318 | " mean | \n",
319 | " stdev | \n",
320 | " min | \n",
321 | " max | \n",
322 | " n_atom | \n",
323 | " n_neighbor_per_atom_avg | \n",
324 | "
\n",
325 | " \n",
326 | " \n",
327 | " \n",
328 | " 0 | \n",
329 | " ASE | \n",
330 | " 0.002498 | \n",
331 | " 0.000041 | \n",
332 | " 0.002439 | \n",
333 | " 0.002685 | \n",
334 | " 8 | \n",
335 | " 28 | \n",
336 | "
\n",
337 | " \n",
338 | " 1 | \n",
339 | " ASE | \n",
340 | " 0.003127 | \n",
341 | " 0.000020 | \n",
342 | " 0.003101 | \n",
343 | " 0.003168 | \n",
344 | " 16 | \n",
345 | " 28 | \n",
346 | "
\n",
347 | " \n",
348 | " 2 | \n",
349 | " ASE | \n",
350 | " 0.003162 | \n",
351 | " 0.000063 | \n",
352 | " 0.003101 | \n",
353 | " 0.003374 | \n",
354 | " 16 | \n",
355 | " 28 | \n",
356 | "
\n",
357 | " \n",
358 | " 3 | \n",
359 | " ASE | \n",
360 | " 0.003145 | \n",
361 | " 0.000043 | \n",
362 | " 0.003100 | \n",
363 | " 0.003297 | \n",
364 | " 16 | \n",
365 | " 28 | \n",
366 | "
\n",
367 | " \n",
368 | " 4 | \n",
369 | " ASE | \n",
370 | " 0.004429 | \n",
371 | " 0.000050 | \n",
372 | " 0.004374 | \n",
373 | " 0.004577 | \n",
374 | " 32 | \n",
375 | " 28 | \n",
376 | "
\n",
377 | " \n",
378 | " ... | \n",
379 | " ... | \n",
380 | " ... | \n",
381 | " ... | \n",
382 | " ... | \n",
383 | " ... | \n",
384 | " ... | \n",
385 | " ... | \n",
386 | "
\n",
387 | " \n",
388 | " 495 | \n",
389 | " torch_nl O(n) GPU | \n",
390 | " 0.005046 | \n",
391 | " 0.000190 | \n",
392 | " 0.004930 | \n",
393 | " 0.006348 | \n",
394 | " 600 | \n",
395 | " 29 | \n",
396 | "
\n",
397 | " \n",
398 | " 496 | \n",
399 | " torch_nl O(n) GPU | \n",
400 | " 0.005924 | \n",
401 | " 0.000094 | \n",
402 | " 0.005809 | \n",
403 | " 0.006476 | \n",
404 | " 800 | \n",
405 | " 29 | \n",
406 | "
\n",
407 | " \n",
408 | " 497 | \n",
409 | " torch_nl O(n) GPU | \n",
410 | " 0.005886 | \n",
411 | " 0.000038 | \n",
412 | " 0.005795 | \n",
413 | " 0.005943 | \n",
414 | " 800 | \n",
415 | " 29 | \n",
416 | "
\n",
417 | " \n",
418 | " 498 | \n",
419 | " torch_nl O(n) GPU | \n",
420 | " 0.005913 | \n",
421 | " 0.000068 | \n",
422 | " 0.005803 | \n",
423 | " 0.006292 | \n",
424 | " 800 | \n",
425 | " 29 | \n",
426 | "
\n",
427 | " \n",
428 | " 499 | \n",
429 | " torch_nl O(n) GPU | \n",
430 | " 0.006839 | \n",
431 | " 0.000631 | \n",
432 | " 0.006651 | \n",
433 | " 0.011249 | \n",
434 | " 1000 | \n",
435 | " 29 | \n",
436 | "
\n",
437 | " \n",
438 | "
\n",
439 | "
500 rows × 7 columns
\n",
440 | "
"
441 | ],
442 | "text/plain": [
443 | " tag mean stdev min max n_atom \\\n",
444 | "0 ASE 0.002498 0.000041 0.002439 0.002685 8 \n",
445 | "1 ASE 0.003127 0.000020 0.003101 0.003168 16 \n",
446 | "2 ASE 0.003162 0.000063 0.003101 0.003374 16 \n",
447 | "3 ASE 0.003145 0.000043 0.003100 0.003297 16 \n",
448 | "4 ASE 0.004429 0.000050 0.004374 0.004577 32 \n",
449 | ".. ... ... ... ... ... ... \n",
450 | "495 torch_nl O(n) GPU 0.005046 0.000190 0.004930 0.006348 600 \n",
451 | "496 torch_nl O(n) GPU 0.005924 0.000094 0.005809 0.006476 800 \n",
452 | "497 torch_nl O(n) GPU 0.005886 0.000038 0.005795 0.005943 800 \n",
453 | "498 torch_nl O(n) GPU 0.005913 0.000068 0.005803 0.006292 800 \n",
454 | "499 torch_nl O(n) GPU 0.006839 0.000631 0.006651 0.011249 1000 \n",
455 | "\n",
456 | " n_neighbor_per_atom_avg \n",
457 | "0 28 \n",
458 | "1 28 \n",
459 | "2 28 \n",
460 | "3 28 \n",
461 | "4 28 \n",
462 | ".. ... \n",
463 | "495 29 \n",
464 | "496 29 \n",
465 | "497 29 \n",
466 | "498 29 \n",
467 | "499 29 \n",
468 | "\n",
469 | "[500 rows x 7 columns]"
470 | ]
471 | },
472 | "execution_count": 8,
473 | "metadata": {},
474 | "output_type": "execute_result"
475 | }
476 | ],
477 | "source": [
478 | "df"
479 | ]
480 | },
481 | {
482 | "cell_type": "code",
483 | "execution_count": 16,
484 | "metadata": {
485 | "scrolled": true
486 | },
487 | "outputs": [
488 | {
489 | "data": {
490 | "image/png": "\n",
491 | "text/plain": [
492 | ""
493 | ]
494 | },
495 | "metadata": {},
496 | "output_type": "display_data"
497 | }
498 | ],
499 | "source": [
500 | "with sns.plotting_context(\"notebook\"):\n",
501 | " plt.title(\"Compute neighborlist: diamond structure\")\n",
502 | " sns.lmplot(data=df, x='n_atom', y='mean', hue='tag',fit_reg=False)\n",
503 | " plt.xlabel('timing []')\n",
504 | " plt.savefig('diamond_benchmark.png', dpi=300, bbox_inches='tight')"
505 | ]
506 | },
507 | {
508 | "cell_type": "code",
509 | "execution_count": 15,
510 | "metadata": {},
511 | "outputs": [],
512 | "source": [
513 | "plt.savefig?"
514 | ]
515 | },
516 | {
517 | "cell_type": "code",
518 | "execution_count": null,
519 | "metadata": {},
520 | "outputs": [],
521 | "source": []
522 | }
523 | ],
524 | "metadata": {
525 | "kernelspec": {
526 | "display_name": "jax39",
527 | "language": "python",
528 | "name": "jax39"
529 | },
530 | "language_info": {
531 | "codemirror_mode": {
532 | "name": "ipython",
533 | "version": 3
534 | },
535 | "file_extension": ".py",
536 | "mimetype": "text/x-python",
537 | "name": "python",
538 | "nbconvert_exporter": "python",
539 | "pygments_lexer": "ipython3",
540 | "version": "3.9.13"
541 | },
542 | "toc": {
543 | "colors": {
544 | "hover_highlight": "#DAA520",
545 | "navigate_num": "#000000",
546 | "navigate_text": "#333333",
547 | "running_highlight": "#FF0000",
548 | "selected_highlight": "#FFD700",
549 | "sidebar_border": "#EEEEEE",
550 | "wrapper_background": "#FFFFFF"
551 | },
552 | "moveMenuLeft": true,
553 | "nav_menu": {
554 | "height": "48.9333px",
555 | "width": "251.8px"
556 | },
557 | "navigate_menu": true,
558 | "number_sections": true,
559 | "sideBar": true,
560 | "threshold": 4,
561 | "toc_cell": false,
562 | "toc_section_display": "block",
563 | "toc_window_display": false,
564 | "widenNotebook": false
565 | },
566 | "vscode": {
567 | "interpreter": {
568 | "hash": "f79d3df5ff5684964744ab9f5218f96eb946f2b40d3f02d5eb965bb50f364a25"
569 | }
570 | }
571 | },
572 | "nbformat": 4,
573 | "nbformat_minor": 2
574 | }
575 |
--------------------------------------------------------------------------------
/benchmark/benchmark.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import ase
3 | import sys
4 | import numpy as np
5 | import scipy
6 | from ase.build import molecule, bulk, make_supercell
7 | from ase.neighborlist import neighbor_list
8 | import pandas as pd
9 | from tqdm import tqdm
10 |
11 | # import seaborn as sns
12 | import matplotlib.pyplot as plt
13 |
14 | sys.path.insert(0, "../")
15 | from torch_nl import compute_neighborlist, compute_neighborlist_n2, ase2data
16 | from torch_nl.timer import timeit
17 |
cutoff = 4
tags = [
    # "torch_nl O(n^2) CPU",
    "torch_nl O(n^2) GPU",
    "torch_nl O(n) CPU",
    "torch_nl O(n) GPU",
]


def build_frames():
    """Build the diamond-Si benchmark structures.

    One supercell of the cubic Si cell (a=4) is made for every diagonal
    repetition ``(n1, n2, n3)`` with each ``n_i`` in 1..5.
    """
    base = bulk("Si", "diamond", a=4, cubic=True)
    reps = torch.arange(1, 6)
    Ps = torch.cartesian_prod(reps, reps, reps)
    # order the repetitions by n1 + n2 + n3 so the systems roughly grow in size
    Ps = Ps[torch.sort(Ps.sum(dim=1)).indices].to(torch.long).numpy()
    return [make_supercell(base, np.diag(P)) for P in Ps]


def parse_tag(tag):
    """Return ``(device, nl_func)`` encoded in a benchmark tag.

    Raises
    ------
    ValueError
        If the tag names no known device or no known algorithm. Without this
        check a bad tag would silently reuse the previous tag's settings.
    """
    if "CPU" in tag:
        device = "cpu"
    elif "GPU" in tag:
        device = "cuda"
    else:
        raise ValueError(f"Unknown device in benchmark tag {tag!r}")

    # test "O(n^2)" first: it is the more specific pattern
    if "O(n^2)" in tag:
        nl_func = compute_neighborlist_n2
    elif "O(n)" in tag:
        nl_func = compute_neighborlist
    else:
        raise ValueError(f"Unknown algorithm in benchmark tag {tag!r}")
    return device, nl_func


def benchmark_ase(frames):
    """Time ASE's ``neighbor_list`` on every frame; return one record per frame."""
    tag = "ASE"
    records = []
    for frame in tqdm(frames):
        timing = timeit(
            neighbor_list, ["ijS", frame, cutoff], tag=tag, warmup=1, nit=50
        )
        data = timing.dumps()
        data.pop("samples")
        i, _, _ = neighbor_list("ijS", frame, cutoff)
        n_neighbor = np.bincount(i).mean()
        data.update(n_atom=len(frame), n_neighbor_per_atom_avg=int(n_neighbor))
        records.append(data)
    return records


def benchmark_torch_nl(frames, tag):
    """Time the torch_nl implementation selected by `tag` on every frame."""
    device, nl_func = parse_tag(tag)
    records = []
    for frame in tqdm(frames):
        pos, cell, pbc, batch, _ = ase2data([frame], device=device)
        # CUDA autocast is only enabled for GPU runs (it cannot affect CPU
        # tensors, and enabling it without CUDA only emits a warning).
        # NOTE(review): benchmark.ipynb times without autocast, so numbers may
        # differ; also confirm `timeit` synchronizes CUDA before reading the
        # clock.
        with torch.autocast(device_type="cuda", enabled=device == "cuda"):
            timing = timeit(
                nl_func,
                [cutoff, pos, cell, pbc, batch],
                tag=tag,
                warmup=10,
                nit=50,
            )
        data = timing.dumps()
        data.pop("samples")
        mapping, _, _ = nl_func(cutoff, pos, cell, pbc, batch)
        n_neighbor = np.bincount(mapping[0].cpu().numpy()).mean()
        data.update(n_atom=len(frame), n_neighbor_per_atom_avg=int(n_neighbor))
        records.append(data)
    return records


def main():
    """Run all benchmarks and return the timings as a DataFrame."""
    torch.set_num_threads(4)
    frames = build_frames()
    print("Starting")
    datas = benchmark_ase(frames)
    for tag in tqdm(tags):
        datas.extend(benchmark_torch_nl(frames, tag))
    df = pd.DataFrame(datas)
    # Plotting is done in benchmark.ipynb; print the table so the results of a
    # script run are not thrown away.
    print(df)
    print("END")
    return df


if __name__ == "__main__":
    main()
89 | # %%
90 |
--------------------------------------------------------------------------------
/benchmark/diamond_benchmark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/felixmusil/torch_nl/ce69750db61343b983c4e9cd92450494a487885e/benchmark/diamond_benchmark.png
--------------------------------------------------------------------------------
/benchmark/profile_nl.py:
--------------------------------------------------------------------------------
1 | import torch.profiler
2 | import torch
3 |
4 | torch.jit.set_fusion_strategy([("STATIC", 3), ("DYNAMIC", 3)])
5 |
6 | import sys
7 |
8 | sys.path.insert(0, "../")
9 | import numpy as np
10 |
11 | from torch_nl import compute_neighborlist, compute_neighborlist_n2, ase2data
12 | from torch_nl.timer import timeit
13 |
14 | from ase.build import molecule, bulk, make_supercell
15 |
import os

device = "cuda"
cutoff = 4
# 6x6x6 supercell of cubic diamond Si (a=4)
frame = bulk("Si", "diamond", a=4, cubic=True)
frame = make_supercell(frame, 6 * np.eye(3))

pos, cell, pbc, batch, n_atoms = ase2data([frame], device=device)

# Output directory of the TensorBoard trace. It defaults to a path relative to
# the working directory instead of a user-specific absolute path; set
# TORCH_NL_PROFILE_DIR to write the trace elsewhere.
trace_dir = os.environ.get("TORCH_NL_PROFILE_DIR", "./nl_n.prof")

# The schedule consumes wait + warmup + active = 42 steps; the loop runs 50.
with torch.profiler.profile(
    schedule=torch.profiler.schedule(wait=20, warmup=20, active=2, repeat=1),
    on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir),
    record_shapes=False,
    profile_memory=False,
    with_stack=True,
) as prof:
    for _ in range(50):
        mapping, mapping_batch, shifts_idx = compute_neighborlist_n2(
            cutoff, pos, cell, pbc, batch
        )
        prof.step()
40 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[build-system]
requires = ["setuptools", "wheel", "ninja"]
build-backend = "setuptools.build_meta"

[tool.black]
line-length = 80
target-version = ['py38', 'py39']
include = '\.pyi?$'
extend-exclude = '''
/(

)/
'''

[tool.pytest.ini_options]
minversion = "6.0"
addopts = "-ra -q"
# the tests live inside the package (torch_nl/test_nl.py)
testpaths = [
    "torch_nl",
]
filterwarnings = [
    "ignore::DeprecationWarning:networkx.*"
]

[tool.coverage.run]
branch = true
source = ["torch_nl/"]
omit = [
    "**/test_*.py",
    "**/__init__.py",
]

[tool.coverage.report]
exclude_lines = [
    "if self.debug:" ,
    "pragma: no cover" ,
    "raise NotImplementedError" ,
    "@(abc\\.)?abstractmethod" ,
]
ignore_errors = true
41 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | ase
2 | numpy
3 | torch >=1.10
4 | # developer tools
5 | pytest
6 | black[jupyter]
7 |
8 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
[metadata]
name = torch_nl
version = attr: torch_nl.__version__
description = TorchScript-able neighbor list implementations (linear and quadratic scaling) for molecular modeling
long_description = file: README.md
long_description_content_type = text/markdown
# classifiers follow python_requires=">=3.8" in setup.py and the MIT LICENSE
classifiers =
    Programming Language :: Python :: 3.8
    Programming Language :: Python :: 3.9
    License :: OSI Approved :: MIT License
    Topic :: Scientific/Engineering :: Chemistry
    Topic :: Scientific/Engineering :: Physics
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | import re
3 |
from pathlib import Path

NAME = "torch_nl"

# Resolve data files relative to this script so the build does not depend on
# the current working directory.
HERE = Path(__file__).resolve().parent


def _read_version(init_path):
    """Return the string assigned to ``__version__`` in `init_path`.

    Any quoted version is accepted (``"0.3"``, ``"0.10"``, ``"1.2.3"``).

    Raises
    ------
    ValueError
        If no ``__version__ = "..."`` assignment is found.
    """
    pattern = re.compile(r"""^\s*__version__\s*=\s*['"]([^'"]+)['"]""")
    with open(init_path, "r", encoding="utf-8") as fp:
        for line in fp:
            match = pattern.match(line)
            if match:
                return match.group(1)
    raise ValueError("Version number not found.")


def _read_requirements(req_path):
    """Return the requirement specifiers listed in `req_path`.

    Blank lines are skipped. Comments are stripped following pip's
    requirements-file rules: a ``#`` at the start of a line or preceded by
    whitespace begins a comment. So ``pkg>=1  # note`` keeps ``pkg>=1``, and a
    URL ``#egg=`` fragment is left intact.
    """
    requirements = []
    with open(req_path, "r", encoding="utf-8") as f:
        for line in f:
            spec = re.sub(r"(^|\s)#.*$", "", line).strip()
            if spec:
                requirements.append(spec)
    return requirements


VERSION = _read_version(HERE / "torch_nl" / "__init__.py")

# NOTE(review): requirements.txt also lists developer tools (pytest,
# black[jupyter]); they become runtime install requirements here. Consider
# moving them to a separate dev requirements file or an extra.
install_requires = _read_requirements(HERE / "requirements.txt")

setup(
    name=NAME,
    version=VERSION,
    packages=find_packages(),
    zip_safe=True,
    python_requires=">=3.8",
    license="MIT",
    author="Fe" + "\u0301" + "lix Musil",
    install_requires=install_requires,
)
32 |
--------------------------------------------------------------------------------
/torch_nl/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.3"
2 |
3 | from .neighbor_list import (
4 | compute_neighborlist,
5 | compute_neighborlist_n2,
6 | strict_nl,
7 | )
8 | from .geometry import compute_distances, compute_cell_shifts
9 | from .naive_impl import build_naive_neighborhood
10 | from .linked_cell import linked_cell
11 | from .utils import ase2data
12 |
--------------------------------------------------------------------------------
/torch_nl/geometry.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Optional
3 |
4 |
def compute_distances(
    pos: torch.Tensor,
    mapping: torch.Tensor,
    cell_shifts: Optional[torch.Tensor] = None,
):
    """Length of the vector joining the two atoms of every pair in `mapping`.

    Parameters
    ----------
    pos : torch.Tensor
        Atomic positions; row ``k`` belongs to atom index ``k``.
    mapping : torch.Tensor [2, n_pairs]
        Pair list: ``mapping[0]`` holds the first atom of each pair and
        ``mapping[1]`` its neighbor.
    cell_shifts : Optional[torch.Tensor], optional
        Per-pair translation added to the pair vector, e.g. to reach a
        periodic image of the neighbor. It must broadcast against
        ``pos[mapping[1]]``. By default None (no translation).

    Returns
    -------
    torch.Tensor [n_pairs]
        Euclidean norm of ``pos[mapping[1]] - pos[mapping[0]] (+ cell_shifts)``.
    """
    # mapping must be a (2, n_pairs) index table
    assert mapping.dim() == 2
    assert mapping.shape[0] == 2

    first, second = mapping[0], mapping[1]
    pair_vec = pos[second] - pos[first]
    if cell_shifts is not None:
        pair_vec = pair_vec + cell_shifts
    return pair_vec.norm(p=2, dim=1)
19 |
20 |
def compute_cell_shifts(
    cell: torch.Tensor, shifts_idx: torch.Tensor, batch_mapping: torch.Tensor
):
    """Turn per-entry cell-shift coefficients into Cartesian translations.

    Parameters
    ----------
    cell : torch.Tensor
        Cell matrices, reshaped internally to [-1, 3, 3] (one 3x3 matrix per
        structure). When None, no translations are computed.
    shifts_idx : torch.Tensor [n, 3]
        Coefficients multiplying the rows of the selected cell matrix.
    batch_mapping : torch.Tensor [n]
        Index of the cell matrix to use for each entry.

    Returns
    -------
    torch.Tensor [n, 3] or None
        ``shifts_idx[k] @ cell[batch_mapping[k]]`` for every entry ``k``, or
        None when `cell` is None.
    """
    if cell is not None:
        # select the (3, 3) cell matrix belonging to each entry
        per_entry_cell = cell.view(-1, 3, 3)[batch_mapping]
        # row-vector times matrix, for all entries at once
        shifts = torch.einsum("jn,jnm->jm", shifts_idx, per_entry_cell)
    else:
        shifts = None
    return shifts
31 |
--------------------------------------------------------------------------------
/torch_nl/linked_cell.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | import torch
3 |
4 | from .utils import get_number_of_cell_repeats, get_cell_shift_idx, strides_of
5 | from .geometry import compute_cell_shifts
6 |
7 |
def ravel_3d(idx_3d: torch.Tensor, shape: torch.Tensor) -> torch.Tensor:
    """Flatten 3d indices into linear indices of a row-major (C-order) array.

    Parameters
    ----------
    idx_3d : torch.Tensor [-1, 3]
        3d indices ``(i, j, k)``, one per row.
    shape : torch.Tensor [3]
        size of the array along each of the three axes.

    Returns
    -------
    torch.Tensor [-1]
        linear indices ``k + shape[2] * (j + shape[1] * i)``.
    """
    i = idx_3d[:, 0]
    j = idx_3d[:, 1]
    k = idx_3d[:, 2]
    return k + shape[2] * (j + shape[1] * i)
28 |
29 |
def unravel_3d(idx_linear: torch.Tensor, shape: torch.Tensor) -> torch.Tensor:
    """Recover 3d indices from linear indices of a row-major array.

    This is the inverse of `ravel_3d` for indices inside the array.

    Parameters
    ----------
    idx_linear : torch.Tensor [-1]
        linear indices.
    shape : torch.Tensor [3]
        size of the array along each of the three axes.

    Returns
    -------
    torch.Tensor [-1, 3]
        3d indices ``(i, j, k)``, with the same dtype and device as
        `idx_linear`.
    """
    n_idx = idx_linear.shape[0]
    size_j = shape[1]
    size_k = shape[2]
    out = idx_linear.new_empty((n_idx, 3))
    # the last axis varies fastest in a row-major layout
    out[:, 2] = torch.remainder(idx_linear, size_k)
    out[:, 1] = torch.remainder(
        torch.div(idx_linear, size_k, rounding_mode="floor"), size_j
    )
    out[:, 0] = torch.div(idx_linear, size_j * size_k, rounding_mode="floor")
    return out
54 |
55 |
def get_linear_bin_idx(
    cell: torch.Tensor, pos: torch.Tensor, nbins_s: torch.Tensor
) -> torch.Tensor:
    """Linear bin index of every position inside a box split into bins.

    The box is spanned by the rows of `cell` and cut into ``nbins_s[d]``
    equal slabs along each cell direction ``d``. Positions whose fractional
    coordinates fall outside [0, 1) produce bin indices outside
    [0, nbins_s); nothing is clamped here.

    Parameters
    ----------
    cell : torch.Tensor [3, 3]
        box vectors, one per row.
    pos : torch.Tensor [-1, 3]
        Cartesian positions.
    nbins_s : torch.Tensor [3]
        number of bins along each box direction.

    Returns
    -------
    torch.Tensor [-1]
        linear bin index (row-major over the 3d bin grid) of each position.
    """
    # Solve cell^T x = pos^T: fractional coordinates such that pos = frac @ cell.
    frac = torch.linalg.solve(cell.t(), pos.t()).t()
    bins_3d = torch.floor(frac * nbins_s).to(torch.long)
    return ravel_3d(bins_3d, nbins_s)
79 |
def scatter_bin_index(
    nbins: int,
    max_n_atom_per_bin: int,
    n_images: int,
    bin_index: torch.Tensor,
):
    """Build the bin -> atoms table from the atom -> bin map `bin_index`.

    Parameters
    ----------
    nbins : int
        total number of bins.
    max_n_atom_per_bin : int
        number of slots per bin. It must be at least the number of atoms in
        the most populated bin; otherwise slots collide and atoms are lost.
    n_images : int
        total number of atoms including the periodic replicas. It is also
        used as the filler value for empty slots so they can be removed later.
    bin_index : torch.Tensor [n_atoms]
        ``bin_index[atom_index]`` is the bin that atom belongs to.

    Returns
    -------
    bin_id : torch.Tensor [nbins, max_n_atom_per_bin]
        row ``b`` lists the atom indices in bin ``b`` (in no guaranteed
        order), padded with `n_images`.
    """
    device = bin_index.device
    table = torch.full(
        (nbins * max_n_atom_per_bin,), n_images, device=device, dtype=torch.long
    )
    order_bins, order_atoms = torch.sort(bin_index)
    # After sorting, the atoms of one bin form a contiguous run. Their running
    # position modulo `max_n_atom_per_bin` therefore gives distinct slots, as
    # long as no bin holds more than `max_n_atom_per_bin` atoms.
    slot = torch.remainder(
        torch.arange(bin_index.shape[0], device=device), max_n_atom_per_bin
    )
    flat_target = order_bins * max_n_atom_per_bin + slot
    table.scatter_(dim=0, index=flat_target, src=order_atoms)
    return table.view((nbins, max_n_atom_per_bin))
116 |
117 |
def linked_cell(
    pos: torch.Tensor,
    cell: torch.Tensor,
    cutoff: float,
    num_repeats: torch.Tensor,
    self_interaction: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Determine the atomic neighborhood of the atoms of a given structure for a particular cutoff using the linked cell algorithm.

    The unit cell is replicated according to `num_repeats`. The bounding box
    of all the replicas is divided into cubic bins of side `cutoff`. Every
    image lying in the bin of an original atom, or in one of the 26 adjacent
    bins, is returned as a candidate neighbor. No distance check is applied
    here, so the result is a superset of the pairs strictly within `cutoff`.

    Parameters
    ----------
    pos : torch.Tensor [n_atom, 3]
        atomic positions in the unit cell (positions outside the cell boundaries will result in an undefined behaviour)
    cell : torch.Tensor [3, 3]
        unit cell vectors in the format V=[v_0, v_1, v_2]
    cutoff : float
        length used to determine neighborhood
    num_repeats : torch.Tensor [3]
        number of unit cell repetitions in each direction required to account for PBC
    self_interaction : bool, optional
        to keep the original atoms as their own neighbor, by default False

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor]
        neigh_atom : [2, n_neighbors]
            indices of the original atoms (neigh_atom[0]) with their neighbor index (neigh_atom[1]). The indices are meant to access the provided position array
        neigh_shift_idx : [n_neighbors, 3]
            cell shift indices to be used in reconstructing the neighbor atom positions.
    """
    device = pos.device
    dtype = pos.dtype
    n_atom = pos.shape[0]
    if n_atom == 0:
        # nothing to bin, and the min/max reductions below reject empty tensors
        return (
            torch.empty((2, 0), dtype=torch.long, device=device),
            torch.empty((0, 3), dtype=dtype, device=device),
        )
    # find all the integer shifts of the unit cell given the cutoff and periodicity
    shifts_idx = get_cell_shift_idx(num_repeats, dtype)
    n_cell_image = shifts_idx.shape[0]
    # one shift per image: the 1st n_atom rows hold the 1st shift, the next
    # n_atom rows the 2nd shift, ...
    shifts_idx = torch.repeat_interleave(
        shifts_idx, n_atom, dim=0, output_size=n_atom * n_cell_image
    )
    # every image belongs to the single structure of cell.view(-1, 3, 3);
    # allocate on `device` like the other tensors involved
    batch_image = torch.zeros(
        (shifts_idx.shape[0]), dtype=torch.long, device=device
    )
    cell_shifts = compute_cell_shifts(
        cell.view(-1, 3, 3), shifts_idx, batch_image
    )

    i_ids = torch.arange(n_atom, device=device, dtype=torch.long)
    i_ids = i_ids.repeat(n_cell_image)
    # compute the positions of the replicated unit cell (including the original)
    # they are organized such that: 1st n_atom are the non-shifted atoms (the
    # first shift produced by get_cell_shift_idx is (0, 0, 0)), 2nd n_atom are
    # moved by the same translation, ...
    images = pos[i_ids] + cell_shifts
    n_images = images.shape[0]
    # create a rectangular box at [0,0,0] that encompasses all the atoms (hence
    # shifting the atoms so that they lie inside the box)
    b_min = images.min(dim=0).values
    b_max = images.max(dim=0).values
    images -= b_min - 1e-5
    box_length = b_max - b_min + 1e-3
    # divide the box into cubic bins of side `cutoff` (at least one per direction)
    nbins_s = torch.maximum(torch.ceil(box_length / cutoff), pos.new_ones(3))
    # enlarge the box lengths so that they hold an integer number of bins
    box_vec = torch.diag_embed(nbins_s * cutoff)
    nbins_s = nbins_s.to(torch.long)
    nbins = int(torch.prod(nbins_s))
    # determine which bins the original atoms and the images belong to following
    # a linear indexing of the 3d bins
    bin_index_j = get_linear_bin_idx(box_vec, images, nbins_s)
    n_atom_j_per_bin = torch.bincount(bin_index_j, minlength=nbins)
    max_n_atom_per_bin = int(n_atom_j_per_bin.max())
    # convert the linear map bin_index_j into a 2d map. This allows for
    # fully vectorized neighbor assignment
    bin_id_j = scatter_bin_index(
        nbins, max_n_atom_per_bin, n_images, bin_index_j
    )

    # find which bins the original atoms belong to
    bin_index_i = bin_index_j[:n_atom]
    i_bins_l = torch.unique(bin_index_i)
    i_bins_s = unravel_3d(i_bins_l, nbins_s)

    # find the bin indices in the neighborhood of i_bins_l. Since the bins have
    # a side length of cutoff only 27 bins are in the neighborhood
    # (including itself)
    dd = torch.tensor([0, 1, -1], dtype=torch.long, device=device)
    bin_shifts = torch.cartesian_prod(dd, dd, dd)
    n_neigh_bins = bin_shifts.shape[0]
    bin_shifts = bin_shifts.repeat((i_bins_s.shape[0], 1))
    neigh_bins_s = (
        torch.repeat_interleave(
            i_bins_s,
            n_neigh_bins,
            dim=0,
            output_size=n_neigh_bins * i_bins_s.shape[0],
        )
        + bin_shifts
    )
    # some of the generated bin_idx might not be valid
    mask = torch.all(
        torch.logical_and(neigh_bins_s < nbins_s.view(1, 3), neigh_bins_s >= 0),
        dim=1,
    )

    # remove the bins that are outside of the search range, i.e. beyond the
    # borders of the box in the case of non-periodic directions.
    neigh_j_bins_l = ravel_3d(neigh_bins_s[mask], nbins_s)

    max_neigh_per_atom = max_n_atom_per_bin * n_neigh_bins
    # the row of bin_id_ij receiving each entry of neigh_j_bins_l: the valid
    # neighbor bins of i_bin b fill rows b*n_neigh_bins, b*n_neigh_bins + 1, ...
    # in order. The rank of each valid bin within its group is a cumulative sum
    # of the mask, which avoids a Python loop over the bins.
    mask_2d = mask.view(-1, n_neigh_bins)
    slot = torch.cumsum(mask_2d.to(torch.long), dim=1) - 1
    neigh_i_bins_l = (i_bins_l.view(-1, 1) * n_neigh_bins + slot)[mask_2d]
    # the linear neighborlist. make it as large as necessary
    neigh_atom = torch.empty(
        (2, n_atom * max_neigh_per_atom), dtype=torch.long, device=device
    )
    # fill the i_atom index
    neigh_atom[0] = (
        torch.arange(n_atom, device=device)
        .view(-1, 1)
        .repeat(1, max_neigh_per_atom)
        .view(-1)
    )
    # relate `bin_index` (row) with the `neighbor_atom_index` (stored in the
    # columns). empty entries are set to `n_images`
    bin_id_ij = torch.full(
        (nbins * n_neigh_bins, max_n_atom_per_bin),
        n_images,
        dtype=torch.long,
        device=device,
    )
    # fill the bins with neighbor atom indices
    bin_id_ij[neigh_i_bins_l] = bin_id_j[neigh_j_bins_l]
    bin_id_ij = bin_id_ij.view((nbins, max_neigh_per_atom))
    # map the neighbors in the bins to the central atoms
    neigh_atom[1] = bin_id_ij[bin_index_i].view(-1)
    # remove empty entries
    neigh_atom = neigh_atom[:, neigh_atom[1] != n_images]

    if not self_interaction:
        # neighbor atoms are still indexed from 0 to n_atom*n_cell_image; image
        # indices below n_atom are the unshifted atoms, so i == j is a self pair
        neigh_atom = neigh_atom[:, neigh_atom[0] != neigh_atom[1]]

    # sort neighbor list so that the i_atom indices increase
    sorted_ids = torch.argsort(neigh_atom[0])
    neigh_atom = neigh_atom[:, sorted_ids]
    # get the cell shift indices for each neighbor atom
    neigh_shift_idx = shifts_idx[neigh_atom[1]]
    # make sure the j_atom indices access the original positions
    neigh_atom[1] = torch.remainder(neigh_atom[1], n_atom)
    return neigh_atom, neigh_shift_idx
265 |
266 |
def build_linked_cell_neighborhood(
    positions: torch.Tensor,
    cell: torch.Tensor,
    pbc: torch.Tensor,
    cutoff: float,
    n_atoms: torch.Tensor,
    self_interaction: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build the candidate neighborlist of a batch of structures with the linked cell algorithm.

    Each structure is processed independently by `linked_cell`. Its local
    atom indices are then offset so that they index the concatenated
    `positions` array.

    Parameters
    ----------
    positions : torch.Tensor [n_atom_total, 3]
        atomic positions of all the structures, concatenated
    cell : torch.Tensor [3*n_structure, 3]
        unit cell vectors of each structure
    pbc : torch.Tensor [n_structure, 3] bool
        periodic boundary conditions to apply
    cutoff : float
        length used to determine neighborhood
    n_atoms : torch.Tensor [n_structure]
        number of atoms in each structure
    self_interaction : bool
        to keep the original atoms as their own neighbor

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        mapping : [2, n_neighbors]
            indices of the neighbor list for the given positions array, mapping[0/1] correspond respectively to the central/neighbor atom (or node in the graph terminology)
        batch_mapping : [n_neighbors]
            index of the structure each pair belongs to
        cell_shifts_idx : [n_neighbors, 3]
            cell shift indices to be used in reconstructing the neighbor atom positions.
    """
    device = positions.device
    cells = cell.view((-1, 3, 3))
    # number of cell replicas needed so that every atom of the unit cell has a
    # complete neighborhood (no minimum image convention assumed)
    num_repeats = get_number_of_cell_repeats(cutoff, cells, pbc.view((-1, 3)))
    stride = strides_of(n_atoms)

    pair_chunks = []
    structure_chunks = []
    shift_chunks = []
    for i_structure in range(n_atoms.shape[0]):
        first = stride[i_structure]
        neigh_atom, neigh_shift_idx = linked_cell(
            positions[stride[i_structure] : stride[i_structure + 1]],
            cells[i_structure],
            cutoff,
            num_repeats[i_structure],
            self_interaction,
        )
        # local atom indices -> indices into the concatenated `positions`
        pair_chunks.append(neigh_atom + first)
        structure_chunks.append(
            torch.full(
                (neigh_atom.shape[1],),
                i_structure,
                dtype=torch.long,
                device=device,
            )
        )
        shift_chunks.append(neigh_shift_idx)

    return (
        torch.cat(pair_chunks, dim=1),
        torch.cat(structure_chunks, dim=0),
        torch.cat(shift_chunks, dim=0),
    )
335 |
--------------------------------------------------------------------------------
/torch_nl/naive_impl.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Tuple
3 |
4 | from .utils import get_number_of_cell_repeats, get_cell_shift_idx, strides_of
5 |
6 |
def get_fully_connected_mapping(
    i_ids: torch.Tensor, shifts_idx: torch.Tensor, self_interaction: bool
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Enumerate every (i, j, cell shift) combination of the atoms of one structure.

    Parameters
    ----------
    i_ids : torch.Tensor [n_atom]
        indices of the atoms of the structure
    shifts_idx : torch.Tensor [n_cell_image, 3]
        cell shift indices. The first row must be the zero shift (0, 0, 0),
        which is how `get_cell_shift_idx` orders them; that row identifies the
        self pairs removed when `self_interaction` is False
    self_interaction : bool
        keep the pairs of an atom with itself in the unshifted cell

    Returns
    -------
    mapping : torch.Tensor [n_pairs, 2]
        (i, j) atom index pairs, ordered by i, then j, then cell shift
    shifts_idx : torch.Tensor [n_pairs, 3]
        cell shift index associated with each row of `mapping`
    """
    n_atom = i_ids.shape[0]
    n_cell_image = shifts_idx.shape[0]
    j_ids = torch.repeat_interleave(
        i_ids, n_cell_image, dim=0, output_size=n_cell_image * n_atom
    )
    mapping = torch.cartesian_prod(i_ids, j_ids)
    # row a*n_atom*n_cell_image + b*n_cell_image + c gets shift c
    shifts_idx = shifts_idx.repeat((n_atom * n_atom, 1))
    if not self_interaction:
        # The self pair (a, a, zero shift) sits at row
        # a*n_atom*n_cell_image + a*n_cell_image = a*n_cell_image*(n_atom + 1).
        # Unlike an arange with step n_atom*n_cell_image, this does not fail
        # with a zero step when the structure has no atoms.
        self_rows = torch.arange(n_atom, device=i_ids.device) * (
            n_cell_image * (n_atom + 1)
        )
        mask = torch.ones(
            mapping.shape[0], dtype=torch.bool, device=i_ids.device
        )
        mask[self_rows] = False
        mapping = mapping[mask, :]
        shifts_idx = shifts_idx[mask]
    return mapping, shifts_idx
31 |
32 |
def build_naive_neighborhood(
    positions: torch.Tensor,
    cell: torch.Tensor,
    pbc: torch.Tensor,
    cutoff: float,
    n_atoms: torch.Tensor,
    self_interaction: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build an all-pairs candidate neighborlist for a batch of structures.

    For every structure, each atom is paired with every atom of the
    structure in every cell replica given by `get_number_of_cell_repeats`.
    This is quadratic in the number of atoms per structure. No distance
    filtering is applied here.

    Parameters
    ----------
    positions : torch.Tensor [n_atom_total, 3]
        atomic positions of all the structures, concatenated
    cell : torch.Tensor [3*n_structure, 3]
        unit cell vectors of each structure
    pbc : torch.Tensor [n_structure, 3] bool
        periodic boundary conditions to apply
    cutoff : float
        length used to determine the number of cell replicas
    n_atoms : torch.Tensor [n_structure]
        number of atoms in each structure
    self_interaction : bool
        to keep the original atoms as their own neighbor

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        mapping : [2, n_pairs]
            indices into `positions` of the central (mapping[0]) and
            neighbor (mapping[1]) atoms
        batch_mapping : [n_pairs]
            index of the structure each pair belongs to
        shifts_idx : [n_pairs, 3]
            cell shift indices of the neighbor atoms
    """
    device = positions.device
    dtype = positions.dtype
    repeats_per_structure = get_number_of_cell_repeats(cutoff, cell, pbc)
    stride = strides_of(n_atoms)
    # global atom indices, sliced per structure below
    all_ids = torch.arange(positions.shape[0], device=device, dtype=torch.long)

    pair_chunks = []
    structure_chunks = []
    shift_chunks = []
    for i_structure in range(n_atoms.shape[0]):
        structure_shifts = get_cell_shift_idx(
            repeats_per_structure[i_structure], dtype
        )
        atom_ids = all_ids[stride[i_structure] : stride[i_structure + 1]]
        pairs, pair_shifts = get_fully_connected_mapping(
            atom_ids, structure_shifts, self_interaction
        )
        pair_chunks.append(pairs)
        structure_chunks.append(
            torch.full(
                (pairs.shape[0],),
                i_structure,
                dtype=torch.long,
                device=device,
            )
        )
        shift_chunks.append(pair_shifts)

    return (
        torch.cat(pair_chunks, dim=0).t(),
        torch.cat(structure_chunks, dim=0),
        torch.cat(shift_chunks, dim=0),
    )
74 |
--------------------------------------------------------------------------------
/torch_nl/neighbor_list.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .naive_impl import build_naive_neighborhood
4 | from .geometry import compute_cell_shifts
5 | from .linked_cell import build_linked_cell_neighborhood
6 |
7 |
def strict_nl(
    cutoff: float,
    pos: torch.Tensor,
    cell: torch.Tensor,
    mapping: torch.Tensor,
    batch_mapping: torch.Tensor,
    shifts_idx: torch.Tensor,
):
    """Keep only the pairs of `mapping` whose separation is strictly below `cutoff`.

    The separation of a pair (i, j) is computed as
    `pos[i] - pos[j] - cell_shift`. Here `cell_shift` is obtained from
    `compute_cell_shifts(cell, shifts_idx, batch_mapping)`; it is omitted if
    that call returns None.

    Parameters
    ----------
    cutoff : float
        cutoff radius; pairs at distance >= cutoff are dropped
    pos : torch.Tensor [n_atom, 3]
        atomic positions of all the structures
    cell : torch.Tensor
        unit cell vectors, as given to `compute_neighborlist`
        ([3*n_structure, 3])
    mapping : torch.Tensor [2, n_pairs]
        candidate (central, neighbor) atom indices into `pos`
    batch_mapping : torch.Tensor [n_pairs]
        structure index of each candidate pair
    shifts_idx : torch.Tensor [n_pairs, 3]
        cell shift indices of each candidate pair

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        `mapping`, `batch_mapping` and `shifts_idx` restricted to the pairs
        within the cutoff
    """
    cell_shifts = compute_cell_shifts(cell, shifts_idx, batch_mapping)
    separation = pos[mapping[0]] - pos[mapping[1]]
    if cell_shifts is not None:
        separation = separation - cell_shifts
    # compare squared distances to avoid a square root
    within_cutoff = separation.square().sum(dim=1) < cutoff * cutoff
    return (
        mapping[:, within_cutoff],
        batch_mapping[within_cutoff],
        shifts_idx[within_cutoff],
    )
53 |
54 |
@torch.jit.script
def compute_neighborlist_n2(
    cutoff: float,
    pos: torch.Tensor,
    cell: torch.Tensor,
    pbc: torch.Tensor,
    batch: torch.Tensor,
    self_interaction: bool = False,
):
    """Compute the neighborlist for a set of atomic structures using a naive (all pairs) neighbor search followed by a strict `cutoff`.

    The atoms positions `pos` should be wrapped inside their respective unit
    cells.

    Parameters
    ----------
    cutoff : float
        cutoff radius used for the neighbor search
    pos : torch.Tensor [n_atom, 3]
        set of atoms positions wrapped inside their respective unit cells
    cell : torch.Tensor [3*n_structure, 3]
        unit cell vectors in the format [a_1, a_2, a_3]
    pbc : torch.Tensor [n_structure, 3] bool
        periodic boundary conditions to apply. Partial PBC are not supported
        yet: a structure is treated as periodic only if all three directions
        are periodic
    batch : torch.Tensor torch.long [n_atom,]
        index of the structure the atom belongs to
    self_interaction : bool, optional
        to keep the center atoms as their own neighbor, by default False

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        mapping : [2, n_neighbors]
            indices of the neighbor list for the given positions array, mapping[0/1] correspond respectively to the central/neighbor atom (or node in the graph terminology)
        batch_mapping : [n_neighbors]
            index of the structure each neighbor pair belongs to
        shifts_idx : [n_neighbors, 3]
            cell shift indices to be used in reconstructing the neighbor atom positions.
    """
    atoms_per_structure = torch.bincount(batch)
    candidates, candidate_batch, candidate_shifts = build_naive_neighborhood(
        pos, cell, pbc, cutoff, atoms_per_structure, self_interaction
    )
    return strict_nl(
        cutoff, pos, cell, candidates, candidate_batch, candidate_shifts
    )
100 |
101 |
@torch.jit.script
def compute_neighborlist(
    cutoff: float,
    pos: torch.Tensor,
    cell: torch.Tensor,
    pbc: torch.Tensor,
    batch: torch.Tensor,
    self_interaction: bool = False,
):
    """Compute the neighborlist for a set of atomic structures using the linked cell algorithm followed by a strict `cutoff`.

    The atoms positions `pos` should be wrapped inside their respective unit
    cells.

    Parameters
    ----------
    cutoff : float
        cutoff radius used for the neighbor search
    pos : torch.Tensor [n_atom, 3]
        set of atoms positions wrapped inside their respective unit cells
    cell : torch.Tensor [3*n_structure, 3]
        unit cell vectors in the format [a_1, a_2, a_3]
    pbc : torch.Tensor [n_structure, 3] bool
        periodic boundary conditions to apply. Partial PBC are not supported
        yet: a structure is treated as periodic only if all three directions
        are periodic
    batch : torch.Tensor torch.long [n_atom,]
        index of the structure the atom belongs to
    self_interaction : bool, optional
        to keep the center atoms as their own neighbor, by default False

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        mapping : [2, n_neighbors]
            indices of the neighbor list for the given positions array, mapping[0/1] correspond respectively to the central/neighbor atom (or node in the graph terminology)
        batch_mapping : [n_neighbors]
            index of the structure each neighbor pair belongs to
        shifts_idx : [n_neighbors, 3]
            cell shift indices to be used in reconstructing the neighbor atom positions.
    """
    atoms_per_structure = torch.bincount(batch)
    candidates, candidate_batch, candidate_shifts = build_linked_cell_neighborhood(
        pos, cell, pbc, cutoff, atoms_per_structure, self_interaction
    )
    return strict_nl(
        cutoff, pos, cell, candidates, candidate_batch, candidate_shifts
    )
149 |
--------------------------------------------------------------------------------
/torch_nl/test_nl.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from ase.build import bulk, molecule
3 | import numpy as np
4 | from ase.neighborlist import neighbor_list
5 | from ase import Atoms
6 | import torch
7 |
8 | from .neighbor_list import (
9 | compute_neighborlist_n2,
10 | compute_cell_shifts,
11 | compute_neighborlist,
12 | )
13 | from .utils import ase2data
14 | from .geometry import compute_distances
15 |
# Triclinic CaCrP2O7 structure with 22 atoms (2 Ca, 2 Cr, 4 P, 14 O), given as
# keyword arguments for `ase.Atoms`.
CaCrP2O7_mvc_11955_symmetrized = {
    "positions": [
        [3.68954016, 5.03568186, 4.64369552],
        [5.12301681, 2.13482791, 2.66220405],
        [1.99411973, 0.94691001, 1.25068234],
        [6.81843724, 6.22359976, 6.05521724],
        [2.63005662, 4.16863452, 0.86090529],
        [6.18250036, 3.00187525, 6.44499428],
        [2.11497733, 1.98032773, 4.53610884],
        [6.69757964, 5.19018203, 2.76979073],
        [1.39215545, 2.94386142, 5.60917746],
        [7.42040152, 4.22664834, 1.69672212],
        [2.43224207, 5.4571615, 6.70305327],
        [6.3803149, 1.71334827, 0.6028463],
        [1.11265639, 1.50166318, 3.48760997],
        [7.69990058, 5.66884659, 3.8182896],
        [3.56971588, 5.20836551, 1.43673437],
        [5.2428411, 1.96214426, 5.8691652],
        [3.12282634, 2.72812741, 1.05450432],
        [5.68973063, 4.44238236, 6.25139525],
        [3.24868468, 2.83997522, 3.99842386],
        [5.56387229, 4.33053455, 3.30747571],
        [2.60835346, 0.74421609, 5.3236629],
        [6.20420351, 6.42629368, 1.98223667],
    ],
    "cell": [
        [6.19330899, 0.0, 0.0],
        [2.4074486111396207, 6.149627748674982, 0.0],
        [0.2117993724186579, 1.0208820183960539, 7.305899571570074],
    ],
    # atomic numbers: Ca (20), Cr (24), P (15), O (8)
    "numbers": [20] * 2 + [24] * 2 + [15] * 4 + [8] * 14,
    "pbc": [True] * 3,
}
73 |
74 |
def bulk_metal():
    """Return the periodic bulk structures used as test fixtures.

    Besides cubic, fcc and bct crystals, the list contains rhombohedral Bi
    cells with very small angles (alpha = 20, 10 and 5 degrees). It also
    contains the triclinic CaCrP2O7 structure built from
    `CaCrP2O7_mvc_11955_symmetrized`, so strongly skewed cells are covered.
    """
    return [
        bulk("Si", "diamond", a=6, cubic=True),
        bulk("Si", "diamond", a=6),
        bulk("Cu", "fcc", a=3.6),
        bulk("Si", "bct", a=6, c=3),
        # very skewed unit cells
        *[bulk("Bi", "rhombohedral", a=6, alpha=alpha) for alpha in (20, 10, 5)],
        bulk("SiCu", "rocksalt", a=6),
        bulk("SiFCu", "fluorite", a=6),
        Atoms(**CaCrP2O7_mvc_11955_symmetrized),
    ]
90 |
91 |
def atomic_structures():
    """Return a mix of non-periodic molecules and periodic bulk structures.

    Molecules appear both before and after the bulk structures, so periodic
    and non-periodic frames are interleaved once the list is batched.
    """
    leading = [
        molecule(name)
        for name in ("CH3CH2NH2", "H2O", "methylenecyclopropane")
    ]
    bulks = bulk_metal()
    trailing = [molecule(name) for name in ("OCHCHO", "C3H9C")]
    return leading + bulks + trailing
106 |
107 |
@pytest.mark.parametrize(
    "frames, cutoff, self_interaction",
    [
        (atomic_structures(), rc, self_interaction)
        for rc in [1, 3, 5, 7]
        for self_interaction in [True, False]
    ],
)
def test_neighborlist_n2(frames, cutoff, self_interaction):
    """Check that the naive neighbor list gives the same NL as ASE.

    The two lists are compared through their sorted pair distances, which
    does not depend on the order in which the pairs are returned.
    """
    pos, cell, pbc, batch, _ = ase2data(frames)

    mapping, batch_mapping, shifts_idx = compute_neighborlist_n2(
        cutoff, pos, cell, pbc, batch, self_interaction
    )
    cell_shifts = compute_cell_shifts(cell, shifts_idx, batch_mapping)
    dds = np.sort(compute_distances(pos, mapping, cell_shifts).numpy())

    dd_ref = []
    for frame in frames:
        # a single requested quantity is returned as a bare array
        dist = neighbor_list(
            "d", frame, cutoff=cutoff, self_interaction=self_interaction
        )
        dd_ref.extend(dist)
    dd_ref = np.sort(dd_ref)

    np.testing.assert_allclose(dd_ref, dds)
138 |
139 |
@pytest.mark.parametrize(
    "frames, cutoff, self_interaction",
    [
        (atomic_structures(), rc, self_interaction)
        for rc in [1, 3, 5, 7]
        for self_interaction in [False, True]
    ],
)
def test_neighborlist_linked_cell(frames, cutoff, self_interaction):
    """Check that the linked cell neighbor list gives the same NL as ASE.

    Two checks are made. Each structure must have the same number of
    neighbor pairs, which localises a mismatch to a frame. The sorted pair
    distances of the whole batch must also match.
    """
    pos, cell, pbc, batch, _ = ase2data(frames)

    mapping, batch_mapping, shifts_idx = compute_neighborlist(
        cutoff, pos, cell, pbc, batch, self_interaction
    )
    cell_shifts = compute_cell_shifts(cell, shifts_idx, batch_mapping)
    dds = np.sort(compute_distances(pos, mapping, cell_shifts).numpy())

    dd_ref = []
    n_ref_per_frame = []
    for frame in frames:
        # a single requested quantity is returned as a bare array
        dist = neighbor_list(
            "d", frame, cutoff=cutoff, self_interaction=self_interaction
        )
        dd_ref.extend(dist)
        n_ref_per_frame.append(len(dist))
    dd_ref = np.sort(dd_ref)

    n_per_frame = torch.bincount(batch_mapping, minlength=len(frames)).tolist()
    assert n_per_frame == n_ref_per_frame
    np.testing.assert_allclose(dd_ref, dds)
199 |
--------------------------------------------------------------------------------
/torch_nl/timer.py:
--------------------------------------------------------------------------------
1 | from timeit import default_timer as timer
2 | import numpy as np
3 | from typing import Mapping
4 | from types import GeneratorType
5 |
6 |
def eval_func(func, inp):
    """Wrap `func` into a one-argument callable that unpacks its input.

    The returned `inner` is meant to be called as `inner(inp)`, which is how
    `timeit` uses it:

    - a `Mapping` is unpacked as keyword arguments: `func(**inp)`;
    - a generator is consumed once, here, and its items are reused as the
      positional arguments of every call. A generator cannot be iterated a
      second time, so unpacking it on each call would pass no arguments from
      the second call on;
    - anything else is unpacked as positional arguments: `func(*inp)`.

    Parameters
    ----------
    func : callable
        function to wrap
    inp : Mapping, generator or iterable
        arguments for `func`

    Returns
    -------
    callable
        `inner(inp)` returning the result of `func`
    """
    if isinstance(inp, Mapping):

        def inner(kwargs):
            return func(**kwargs)

    elif isinstance(inp, GeneratorType):
        frozen_args = tuple(inp)

        def inner(_exhausted):
            return func(*frozen_args)

    else:

        def inner(args):
            return func(*args)

    return inner
15 |
16 |
def timeit(func, inp, tag="", warmup=10, nit=100):
    """Benchmark `func` on `inp` and return the `Timer` holding the samples.

    `func` is first called `warmup` times without timing. It is then called
    `nit` times, each call wrapped in the same `Timer`, so the returned timer
    holds `nit` samples. `inp` is unpacked as described in `eval_func`.

    Parameters
    ----------
    func : callable
        function to benchmark
    inp : Mapping, generator or iterable
        arguments passed to `func`
    tag : str
        label stored in the returned timer
    warmup : int
        number of untimed calls
    nit : int
        number of timed calls

    Returns
    -------
    Timer
        timer containing one elapsed-time sample per timed call
    """
    stopwatch = Timer(tag=tag)
    call = eval_func(func, inp)

    for _ in range(warmup):
        call(inp)

    for _ in range(nit):
        with stopwatch:
            call(inp)

    return stopwatch
27 |
28 |
class Timer:
    """Context manager accumulating the wall-clock duration of each `with` block.

    Every `with` block run under the same instance appends one sample to
    `elapsed`. Samples are in seconds, as measured by the module-level
    `timer` (`timeit.default_timer`). The statistics methods summarise those
    samples with numpy.

    Parameters
    ----------
    tag : str
        label stored in `dumps()` and printed by `repr()`
    logger : optional
        accepted for backward compatibility; it is not used
    """

    def __init__(self, tag="", logger=None):
        self.tag = tag
        self.elapsed = []
        self.start = None
        self.end = None

    def __enter__(self):
        self.start = timer()
        # returning the instance makes `with Timer() as t:` bind the timer
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end = timer()
        self.elapsed.append(self.end - self.start)
        # returning None lets any exception raised inside the block propagate

    def mean(self):
        """Mean of the recorded samples, in seconds."""
        return np.mean(self.elapsed)

    def stdev(self):
        """Standard deviation of the recorded samples, in seconds."""
        return np.std(self.elapsed)

    def min(self):
        """Shortest recorded sample, in seconds."""
        return np.min(self.elapsed)

    def max(self):
        """Longest recorded sample, in seconds."""
        return np.max(self.elapsed)

    def samples(self):
        """The recorded samples (the internal list itself, not a copy)."""
        return self.elapsed

    def dumps(self):
        """Summary dict with the tag, the statistics and the raw samples."""
        data = dict(
            tag=self.tag,
            mean=self.mean(),
            stdev=self.stdev(),
            min=self.min(),
            max=self.max(),
            samples=self.samples(),
        )
        return data

    def __repr__(self) -> str:
        if not self.elapsed:
            # np.min / np.max raise on empty input, which would break repr()
            return f"{self.tag} (no samples)"
        timings = self.dumps()
        return f'{timings["tag"]} ' + " / ".join(
            [
                f"{k}={timings[k]*1000:.5f} [ms]"
                for k in ["mean", "stdev", "min", "max"]
            ]
        )
77 |
--------------------------------------------------------------------------------
/torch_nl/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.types import _dtype
4 |
5 |
def ase2data(frames, device="cpu"):
    """Concatenate a sequence of ASE frames into flat torch tensors.

    Parameters
    ----------
    frames : iterable of ase.Atoms
        each frame must support `len()`, `get_positions()`,
        `get_cell().array` and `get_pbc()`; it is iterated exactly once
    device : str or torch.device
        device the returned tensors are moved to

    Returns
    -------
    pos : torch.Tensor [n_atom_total, 3]
        positions of all the frames, concatenated
    cell : torch.Tensor [3*n_frames, 3]
        cell vectors of all the frames, stacked
    pbc : torch.Tensor [3*n_frames] bool
        periodic boundary flags of all the frames, concatenated
    batch : torch.Tensor [n_atom_total] long
        index of the frame each atom belongs to
    n_atoms : torch.Tensor [n_frames] long
        number of atoms of each frame
    """
    pos, cell, pbc, counts = [], [], [], []
    for ff in frames:
        counts.append(len(ff))
        pos.append(torch.from_numpy(ff.get_positions()))
        cell.append(torch.from_numpy(ff.get_cell().array))
        pbc.append(torch.from_numpy(ff.get_pbc()))
    pos = torch.cat(pos)
    cell = torch.cat(cell)
    pbc = torch.cat(pbc)
    # built directly as int64: the float32 `torch.Tensor` constructor cannot
    # represent counts above 2**24 exactly
    n_atoms = torch.tensor(counts, dtype=torch.long)
    # frame index of every atom, e.g. n_atoms=[2, 3] -> [0, 0, 1, 1, 1]
    batch = torch.repeat_interleave(torch.arange(n_atoms.shape[0]), n_atoms)
    return (
        pos.to(device=device),
        cell.to(device=device),
        pbc.to(device=device),
        batch.to(device=device),
        n_atoms.to(device=device),
    )
31 |
32 |
def strides_of(v: torch.Tensor) -> torch.Tensor:
    """Exclusive-to-inclusive prefix sums of `v`, with a leading zero.

    For counts `v = [n_0, n_1, ...]` (flattened first), the result is
    `[0, n_0, n_0 + n_1, ...]`, i.e. the start/end offsets of consecutive
    segments. The result keeps the dtype of `v`.
    """
    flat = v.flatten()
    running_total = torch.cumsum(flat, dim=0, dtype=flat.dtype)
    return torch.cat([flat.new_zeros([1]), running_total])
39 |
40 |
def get_number_of_cell_repeats(
    cutoff: float, cell: torch.Tensor, pbc: torch.Tensor
) -> torch.Tensor:
    """Number of cell replicas needed along each lattice vector to cover `cutoff`.

    The distance between consecutive lattice planes of vector k is 1 / |b_k|,
    where b_k is the k-th row of the reciprocal cell inv(cell)^T. Hence
    `ceil(cutoff * |b_k|)` replicas are needed on each side.

    Only structures periodic along all three directions get replicas;
    structures with partial or no pbc get 0 in every direction.

    Parameters
    ----------
    cutoff : float
        neighbor cutoff radius
    cell : torch.Tensor [3*n_structure, 3] (or [n_structure, 3, 3])
        cell vectors of each structure
    pbc : torch.Tensor [n_structure, 3] bool
        periodic boundary flags of each structure

    Returns
    -------
    torch.Tensor [n_structure, 3] long
        number of replicas along each lattice vector of each structure
    """
    cells = cell.view((-1, 3, 3))
    periodic = pbc.view((-1, 3))

    fully_periodic = periodic.all(dim=1)
    reciprocal = torch.zeros_like(cells)
    reciprocal[fully_periodic, :, :] = torch.linalg.inv(
        cells[fully_periodic, :, :]
    ).transpose(-1, -2)

    repeats = torch.ceil(cutoff * reciprocal.norm(2, dim=-1)).to(torch.long)
    return repeats.masked_fill(~periodic, 0)
56 |
57 |
def get_cell_shift_idx(
    num_repeats: torch.Tensor, dtype: _dtype
) -> torch.Tensor:
    """All integer cell-shift triplets within `num_repeats` along each axis.

    Along axis k the shifts -n_k, ..., n_k are ordered by increasing absolute
    value, so the first row of the result is always (0, 0, 0). The relative
    order of +m and -m is whatever the (non-stable) sort yields.

    Parameters
    ----------
    num_repeats : torch.Tensor [3]
        maximum absolute shift along each lattice vector
    dtype : torch.dtype
        dtype of the returned shifts

    Returns
    -------
    torch.Tensor [prod(2 * num_repeats + 1), 3]
        cartesian product of the per-axis shifts
    """
    axis_shifts = []
    for axis in range(3):
        n = num_repeats[axis]
        candidates = torch.arange(
            -n,
            n + 1,
            device=num_repeats.device,
            dtype=dtype,
        )
        _, order = torch.abs(candidates).sort()
        axis_shifts.append(candidates[order])
    return torch.cartesian_prod(axis_shifts[0], axis_shifts[1], axis_shifts[2])
73 |
--------------------------------------------------------------------------------