├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── benchmark ├── benchmark.ipynb ├── benchmark.py ├── diamond_benchmark.png └── profile_nl.py ├── pyproject.toml ├── requirements.txt ├── setup.cfg ├── setup.py └── torch_nl ├── __init__.py ├── geometry.py ├── linked_cell.py ├── naive_impl.py ├── neighbor_list.py ├── test_nl.py ├── timer.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.prof 6 | *.code-workspace 7 | # C extensions 8 | *.so 9 | .DS_Store 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 felixmusil 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements.txt -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # torch_nl 2 | 3 | Provide a pytorch implementation of a naive (`compute_neighborlist_n2`) and a linked cell (`compute_neighborlist`) neighbor list that are compatible with TorchScript. 4 | 5 | Their correctness is tested against ASE's implementation. 6 | 7 | Note that contrary to ASE, the atoms positions are assumed to be wrapped inside the unit cell. 8 | # How to 9 | 10 | ## instal with pip 11 | 12 | ```bash 13 | pip install torch-nl 14 | ``` 15 | 16 | ## use the neighborlist 17 | 18 | ```python 19 | from torch_nl import compute_neighborlist, ase2data 20 | from ase.build import bulk, molecule 21 | 22 | frames = [bulk("Si", "diamond", a=6, cubic=True), molecule("CH3CH2NH2")] 23 | pos, cell, pbc, batch, n_atoms = ase2data(frames) 24 | 25 | mapping, batch_mapping, shifts_idx = compute_neighborlist( 26 | cutoff, pos, cell, pbc, batch, self_interaction 27 | ) 28 | ``` 29 | 30 | # Benchmarks 31 | 32 | ## Periodic structure 33 | 34 | ![](benchmark/diamond_benchmark.png) 35 | -------------------------------------------------------------------------------- /benchmark/benchmark.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 14, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import matplotlib.pyplot as plt\n", 10 | "import torch\n", 11 | "import ase\n", 12 | "import numpy as np\n", 13 | "import scipy\n", 14 | "from ase.build import molecule, bulk, make_supercell\n", 15 | "from ase.neighborlist import neighbor_list\n", 16 | "import pandas as pd\n", 17 | "from tqdm.notebook import tqdm\n", 18 | "import seaborn as sns" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 2, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "from torch_nl import compute_neighborlist, compute_neighborlist_n2, ase2data\n", 28 | "from torch_nl.timer import timeit" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "# Periodic systems" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "metadata": {}, 41 | "source": [ 42 | "## Metal " 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 3, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "frame = bulk('Si', 'diamond', a=4, cubic=True)\n", 52 | "aa = torch.arange(1, 6)\n", 53 | "Ps = torch.cartesian_prod(aa,aa,aa)\n", 54 | "Ps = Ps[torch.sort(Ps.sum(dim=1)).indices].to(torch.long).numpy()\n", 55 | "frames = []\n", 56 | "n_atoms = []\n", 57 | "for P in Ps:\n", 58 | " frames.append(make_supercell(frame, np.diag(P)))\n", 59 | " n_atoms.append(len(frames[-1]))\n", 60 | "n_atoms = np.array(n_atoms)" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 4, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "cutoff = 4" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 5, 75 | "metadata": {}, 76 | "outputs": [ 77 | { 78 | "data": { 79 | "application/json": { 80 | "ascii": false, 81 | "bar_format": null, 82 | "colour": null, 83 | "elapsed": 0.011367321014404297, 84 | "initial": 0, 85 | "n": 0, 86 | "ncols": null, 87 | "nrows": 54, 88 | "postfix": null, 89 | "prefix": "", 90 | "rate": null, 91 | "total": 125, 92 | "unit": "it", 93 | "unit_divisor": 1000, 94 | "unit_scale": false 95 | }, 96 | "application/vnd.jupyter.widget-view+json": { 97 | "model_id": "ab749187b35f41579193cdd675218261", 98 | "version_major": 2, 99 | "version_minor": 0 100 | }, 101 | "text/plain": [ 102 | " 0%| | 0/125 [00:00\n", 300 | "\n", 313 | "\n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | "
tagmeanstdevminmaxn_atomn_neighbor_per_atom_avg
0ASE0.0024980.0000410.0024390.002685828
1ASE0.0031270.0000200.0031010.0031681628
2ASE0.0031620.0000630.0031010.0033741628
3ASE0.0031450.0000430.0031000.0032971628
4ASE0.0044290.0000500.0043740.0045773228
........................
495torch_nl O(n) GPU0.0050460.0001900.0049300.00634860029
496torch_nl O(n) GPU0.0059240.0000940.0058090.00647680029
497torch_nl O(n) GPU0.0058860.0000380.0057950.00594380029
498torch_nl O(n) GPU0.0059130.0000680.0058030.00629280029
499torch_nl O(n) GPU0.0068390.0006310.0066510.011249100029
\n", 439 | "

500 rows × 7 columns

\n", 440 | "" 441 | ], 442 | "text/plain": [ 443 | " tag mean stdev min max n_atom \\\n", 444 | "0 ASE 0.002498 0.000041 0.002439 0.002685 8 \n", 445 | "1 ASE 0.003127 0.000020 0.003101 0.003168 16 \n", 446 | "2 ASE 0.003162 0.000063 0.003101 0.003374 16 \n", 447 | "3 ASE 0.003145 0.000043 0.003100 0.003297 16 \n", 448 | "4 ASE 0.004429 0.000050 0.004374 0.004577 32 \n", 449 | ".. ... ... ... ... ... ... \n", 450 | "495 torch_nl O(n) GPU 0.005046 0.000190 0.004930 0.006348 600 \n", 451 | "496 torch_nl O(n) GPU 0.005924 0.000094 0.005809 0.006476 800 \n", 452 | "497 torch_nl O(n) GPU 0.005886 0.000038 0.005795 0.005943 800 \n", 453 | "498 torch_nl O(n) GPU 0.005913 0.000068 0.005803 0.006292 800 \n", 454 | "499 torch_nl O(n) GPU 0.006839 0.000631 0.006651 0.011249 1000 \n", 455 | "\n", 456 | " n_neighbor_per_atom_avg \n", 457 | "0 28 \n", 458 | "1 28 \n", 459 | "2 28 \n", 460 | "3 28 \n", 461 | "4 28 \n", 462 | ".. ... \n", 463 | "495 29 \n", 464 | "496 29 \n", 465 | "497 29 \n", 466 | "498 29 \n", 467 | "499 29 \n", 468 | "\n", 469 | "[500 rows x 7 columns]" 470 | ] 471 | }, 472 | "execution_count": 8, 473 | "metadata": {}, 474 | "output_type": "execute_result" 475 | } 476 | ], 477 | "source": [ 478 | "df" 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": 16, 484 | "metadata": { 485 | "scrolled": true 486 | }, 487 | "outputs": [ 488 | { 489 | "data": { 490 | "image/png": "\n", 491 | "text/plain": [ 492 | "
" 493 | ] 494 | }, 495 | "metadata": {}, 496 | "output_type": "display_data" 497 | } 498 | ], 499 | "source": [ 500 | "with sns.plotting_context(\"notebook\"):\n", 501 | " plt.title(\"Compute neighborlist: diamond structure\")\n", 502 | " sns.lmplot(data=df, x='n_atom', y='mean', hue='tag',fit_reg=False)\n", 503 | " plt.xlabel('timing []')\n", 504 | " plt.savefig('diamond_benchmark.png', dpi=300, bbox_inches='tight')" 505 | ] 506 | }, 507 | { 508 | "cell_type": "code", 509 | "execution_count": 15, 510 | "metadata": {}, 511 | "outputs": [], 512 | "source": [ 513 | "plt.savefig?" 514 | ] 515 | }, 516 | { 517 | "cell_type": "code", 518 | "execution_count": null, 519 | "metadata": {}, 520 | "outputs": [], 521 | "source": [] 522 | } 523 | ], 524 | "metadata": { 525 | "kernelspec": { 526 | "display_name": "jax39", 527 | "language": "python", 528 | "name": "jax39" 529 | }, 530 | "language_info": { 531 | "codemirror_mode": { 532 | "name": "ipython", 533 | "version": 3 534 | }, 535 | "file_extension": ".py", 536 | "mimetype": "text/x-python", 537 | "name": "python", 538 | "nbconvert_exporter": "python", 539 | "pygments_lexer": "ipython3", 540 | "version": "3.9.13" 541 | }, 542 | "toc": { 543 | "colors": { 544 | "hover_highlight": "#DAA520", 545 | "navigate_num": "#000000", 546 | "navigate_text": "#333333", 547 | "running_highlight": "#FF0000", 548 | "selected_highlight": "#FFD700", 549 | "sidebar_border": "#EEEEEE", 550 | "wrapper_background": "#FFFFFF" 551 | }, 552 | "moveMenuLeft": true, 553 | "nav_menu": { 554 | "height": "48.9333px", 555 | "width": "251.8px" 556 | }, 557 | "navigate_menu": true, 558 | "number_sections": true, 559 | "sideBar": true, 560 | "threshold": 4, 561 | "toc_cell": false, 562 | "toc_section_display": "block", 563 | "toc_window_display": false, 564 | "widenNotebook": false 565 | }, 566 | "vscode": { 567 | "interpreter": { 568 | "hash": "f79d3df5ff5684964744ab9f5218f96eb946f2b40d3f02d5eb965bb50f364a25" 569 | } 570 | } 571 | }, 572 | "nbformat": 4, 573 | "nbformat_minor": 2 574 | } 575 | -------------------------------------------------------------------------------- /benchmark/benchmark.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import ase 3 | import sys 4 | import numpy as np 5 | import scipy 6 | from ase.build import molecule, bulk, make_supercell 7 | from ase.neighborlist import neighbor_list 8 | import pandas as pd 9 | from tqdm import tqdm 10 | 11 | # import seaborn as sns 12 | import matplotlib.pyplot as plt 13 | 14 | sys.path.insert(0, "../") 15 | from torch_nl import compute_neighborlist, compute_neighborlist_n2, ase2data 16 | from torch_nl.timer import timeit 17 | 18 | torch.set_num_threads(4) 19 | cutoff = 4 20 | tags = [ 21 | # "torch_nl O(n^2) CPU", 22 | "torch_nl O(n^2) GPU", 23 | "torch_nl O(n) CPU", 24 | "torch_nl O(n) GPU" 25 | ] 26 | 27 | frame = bulk("Si", "diamond", a=4, cubic=True) 28 | aa = torch.arange(1, 6) 29 | Ps = torch.cartesian_prod(aa, aa, aa) 30 | Ps = Ps[torch.sort(Ps.sum(dim=1)).indices].to(torch.long).numpy() 31 | frames = [] 32 | n_atoms = [] 33 | for P in Ps: 34 | frames.append(make_supercell(frame, np.diag(P))) 35 | n_atoms.append(len(frames[-1])) 36 | n_atoms = np.array(n_atoms) 37 | print("Starting") 38 | tag = "ASE" 39 | datas = [] 40 | for frame in tqdm(frames): 41 | timing = timeit( 42 | neighbor_list, ["ijS", frame, cutoff], tag=tag, warmup=1, nit=50 43 | ) 44 | data = timing.dumps() 45 | i, j, S = neighbor_list("ijS", frame, cutoff) 46 | n_neighbor = np.bincount(i).mean() 47 | data.update(n_atom=len(frame), n_neighbor_per_atom_avg=int(n_neighbor)) 48 | data.pop("samples") 49 | datas.append(data) 50 | 51 | 52 | for tag in tqdm(tags): 53 | if "CPU" in tag: 54 | device = "cpu" 55 | elif "GPU" in tag: 56 | device = "cuda" 57 | 58 | if "O(n^2)" in tag: 59 | nl_func = compute_neighborlist_n2 60 | elif "O(n)" in tag: 61 | nl_func = compute_neighborlist 62 | 63 | for frame in tqdm(frames): 64 | pos, cell, pbc, batch, n_atoms = ase2data([frame], device=device) 65 | with torch.cuda.amp.autocast(): 66 | timing = timeit( 67 | nl_func, 68 | [cutoff, pos, cell, pbc, batch], 69 | tag=tag, 70 | warmup=10, 71 | nit=50, 72 | ) 73 | data = timing.dumps() 74 | data.pop("samples") 75 | mapping, mapping_batch, shifts_idx = nl_func( 76 | cutoff, pos, cell, pbc, batch 77 | ) 78 | n_neighbor = np.bincount(mapping[0].cpu().numpy()).mean() 79 | data.update(n_atom=len(frame), n_neighbor_per_atom_avg=int(n_neighbor)) 80 | datas.append(data) 81 | 82 | df = pd.DataFrame(datas) 83 | 84 | # sns.lmplot(data=df, x='n_atom', y='mean', hue='tag',fit_reg=False) 85 | 86 | # plt.savefig('./test_0.png', dpi=300, bbox_inches='tight') 87 | # plt.show() 88 | print("END") 89 | # %% 90 | -------------------------------------------------------------------------------- /benchmark/diamond_benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/felixmusil/torch_nl/ce69750db61343b983c4e9cd92450494a487885e/benchmark/diamond_benchmark.png -------------------------------------------------------------------------------- /benchmark/profile_nl.py: -------------------------------------------------------------------------------- 1 | import torch.profiler 2 | import torch 3 | 4 | torch.jit.set_fusion_strategy([("STATIC", 3), ("DYNAMIC", 3)]) 5 | 6 | import sys 7 | 8 | sys.path.insert(0, "../") 9 | import numpy as np 10 | 11 | from torch_nl import compute_neighborlist, compute_neighborlist_n2, ase2data 12 | from torch_nl.timer import timeit 13 | 14 | from ase.build import molecule, bulk, make_supercell 15 | 16 | device = "cuda" 17 | cutoff = 4 18 | frame = bulk("Si", "diamond", a=4, cubic=True) 19 | 20 | frame = make_supercell(frame, 6 * np.eye(3)) 21 | 22 | 23 | pos, cell, pbc, batch, n_atoms = ase2data([frame], device=device) 24 | 25 | 26 | with torch.profiler.profile( 27 | schedule=torch.profiler.schedule(wait=20, warmup=20, active=2, repeat=1), 28 | on_trace_ready=torch.profiler.tensorboard_trace_handler( 29 | "/local_scratch/musil/nl_n.prof" 30 | ), 31 | record_shapes=False, 32 | profile_memory=False, 33 | with_stack=True, 34 | ) as prof: 35 | for _ in range(50): 36 | mapping, mapping_batch, shifts_idx = compute_neighborlist_n2( 37 | cutoff, pos, cell, pbc, batch 38 | ) 39 | prof.step() 40 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel", "ninja"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.black] 6 | line-length = 80 7 | target-version = ['py38', 'py39'] 8 | include = '\.pyi?$' 9 | extend-exclude = ''' 10 | /( 11 | 12 | )/ 13 | ''' 14 | 15 | [tool.pytest.ini_options] 16 | minversion = "6.0" 17 | addopts = "-ra -q" 18 | testpaths = [ 19 | "mlcg", 20 | ] 21 | filterwarnings = [ 22 | "ignore::DeprecationWarning:networkx.*" 23 | ] 24 | 25 | [tool.coverage.run] 26 | branch = true 27 | source = ["mlcg/"] 28 | omit = [ 29 | "**/test_*.py", 30 | "**/__init__.py", 31 | ] 32 | 33 | [tool.coverage.report] 34 | exclude_lines = [ 35 | "if self.debug:" , 36 | "pragma: no cover" , 37 | "raise NotImplementedError" , 38 | "@(abc\\.)?abstractmethod" , 39 | ] 40 | ignore_errors = true 41 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ase 2 | numpy 3 | torch >=1.10 4 | # developer tools 5 | pytest 6 | black[jupyter] 7 | 8 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = torch_nl 3 | version = attr: torch_nl.__version__ 4 | description = TorchScript-able neighbor lists implementations (linear and quadratic scaling) for molecular modeling 5 | long_description = file: README.md 6 | long_description_content_type = text/markdown 7 | classifiers = 8 | Programming Language :: Python :: 3.7 9 | Topic :: Scientific/Engineering :: Chemistry 10 | Topic :: Scientific/Engineering :: Physics -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | import re 3 | 4 | NAME = "torch_nl" 5 | 6 | # read the version number from the library 7 | pattern = r"[0-9]\.[0-9]" 8 | VERSION = None 9 | with open("./torch_nl/__init__.py", "r") as fp: 10 | for line in fp.readlines(): 11 | if "__version__" in line: 12 | VERSION = re.findall(pattern, line)[0] 13 | if VERSION is None: 14 | raise ValueError("Version number not found.") 15 | 16 | 17 | with open("requirements.txt") as f: 18 | install_requires = list( 19 | filter(lambda x: "#" not in x, (line.strip() for line in f)) 20 | ) 21 | 22 | setup( 23 | name=NAME, 24 | version=VERSION, 25 | packages=find_packages(), 26 | zip_safe=True, 27 | python_requires=">=3.8", 28 | license="MIT", 29 | author="Fe" + "\u0301" + "lix Musil", 30 | install_requires=install_requires, 31 | ) 32 | -------------------------------------------------------------------------------- /torch_nl/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.3" 2 | 3 | from .neighbor_list import ( 4 | compute_neighborlist, 5 | compute_neighborlist_n2, 6 | strict_nl, 7 | ) 8 | from .geometry import compute_distances, compute_cell_shifts 9 | from .naive_impl import build_naive_neighborhood 10 | from .linked_cell import linked_cell 11 | from .utils import ase2data 12 | -------------------------------------------------------------------------------- /torch_nl/geometry.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional 3 | 4 | 5 | def compute_distances( 6 | pos: torch.Tensor, 7 | mapping: torch.Tensor, 8 | cell_shifts: Optional[torch.Tensor] = None, 9 | ): 10 | assert mapping.dim() == 2 11 | assert mapping.shape[0] == 2 12 | 13 | if cell_shifts is None: 14 | dr = pos[mapping[1]] - pos[mapping[0]] 15 | else: 16 | dr = pos[mapping[1]] - pos[mapping[0]] + cell_shifts 17 | 18 | return dr.norm(p=2, dim=1) 19 | 20 | 21 | def compute_cell_shifts( 22 | cell: torch.Tensor, shifts_idx: torch.Tensor, batch_mapping: torch.Tensor 23 | ): 24 | if cell is None: 25 | cell_shifts = None 26 | else: 27 | cell_shifts = torch.einsum( 28 | "jn,jnm->jm", shifts_idx, cell.view(-1, 3, 3)[batch_mapping] 29 | ) 30 | return cell_shifts 31 | -------------------------------------------------------------------------------- /torch_nl/linked_cell.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import torch 3 | 4 | from .utils import get_number_of_cell_repeats, get_cell_shift_idx, strides_of 5 | from .geometry import compute_cell_shifts 6 | 7 | 8 | def ravel_3d(idx_3d: torch.Tensor, shape: torch.Tensor) -> torch.Tensor: 9 | """Convert 3d indices meant for an array of sizes `shape` into linear 10 | indices. 11 | 12 | Parameters 13 | ---------- 14 | idx_3d : [-1, 3] 15 | _description_ 16 | shape : [3] 17 | _description_ 18 | 19 | Returns 20 | ------- 21 | torch.Tensor 22 | linear indices 23 | """ 24 | idx_linear = idx_3d[:, 2] + shape[2] * ( 25 | idx_3d[:, 1] + shape[1] * idx_3d[:, 0] 26 | ) 27 | return idx_linear 28 | 29 | 30 | def unravel_3d(idx_linear: torch.Tensor, shape: torch.Tensor) -> torch.Tensor: 31 | """Convert linear indices meant for an array of sizes `shape` into 3d indices. 32 | 33 | Parameters 34 | ---------- 35 | idx_linear : torch.Tensor [-1] 36 | 37 | shape : torch.Tensor [3] 38 | 39 | 40 | Returns 41 | ------- 42 | torch.Tensor [-1, 3] 43 | 44 | """ 45 | idx_3d = idx_linear.new_empty((idx_linear.shape[0], 3)) 46 | idx_3d[:, 2] = torch.remainder(idx_linear, shape[2]) 47 | idx_3d[:, 1] = torch.remainder( 48 | torch.div(idx_linear, shape[2], rounding_mode="floor"), shape[1] 49 | ) 50 | idx_3d[:, 0] = torch.div( 51 | idx_linear, shape[1] * shape[2], rounding_mode="floor" 52 | ) 53 | return idx_3d 54 | 55 | 56 | def get_linear_bin_idx( 57 | cell: torch.Tensor, pos: torch.Tensor, nbins_s: torch.Tensor 58 | ) -> torch.Tensor: 59 | """Find the linear bin index of each input pos given a box defined by its cell vectors and a number of bins, contained in the box, for each directions of the box. 60 | 61 | Parameters 62 | ---------- 63 | cell : torch.Tensor [3, 3] 64 | cell vectors 65 | pos : torch.Tensor [-1, 3] 66 | set of positions 67 | nbins_s : torch.Tensor [3] 68 | number of bins in each directions 69 | 70 | Returns 71 | ------- 72 | torch.Tensor 73 | linear bin index 74 | """ 75 | scaled_pos = torch.linalg.solve(cell.t(), pos.t()).t() 76 | bin_index_s = torch.floor(scaled_pos * nbins_s).to(torch.long) 77 | bin_index_l = ravel_3d(bin_index_s, nbins_s) 78 | return bin_index_l 79 | 80 | def scatter_bin_index( 81 | nbins: int, 82 | max_n_atom_per_bin: int, 83 | n_images: int, 84 | bin_index: torch.Tensor, 85 | ): 86 | """convert the linear table `bin_index` into the table `bin_id`. Empty entries in `bin_id` are set to `n_images` so that they can be removed later. 87 | 88 | Parameters 89 | ---------- 90 | nbins : _type_ 91 | total number of bins 92 | max_n_atom_per_bin : _type_ 93 | maximum number of atoms per bin 94 | n_images : _type_ 95 | total number of atoms counting the pbc replicas 96 | bin_index : _type_ 97 | map relating `atom_index` to the `bin_index` that it belongs to such that `bin_index[atom_index] -> bin_index`. 98 | 99 | Returns 100 | ------- 101 | bin_id : torch.Tensor [nbins, max_n_atom_per_bin] 102 | relate `bin_index` (row) with the `atom_index` (stored in the columns). 103 | """ 104 | device = bin_index.device 105 | sorted_bin_index, sorted_id = torch.sort(bin_index) 106 | bin_id = torch.full( 107 | (nbins * max_n_atom_per_bin,), n_images, device=device, dtype=torch.long 108 | ) 109 | sorted_bin_id = torch.remainder( 110 | torch.arange(bin_index.shape[0], device=device), max_n_atom_per_bin 111 | ) 112 | sorted_bin_id = sorted_bin_index * max_n_atom_per_bin + sorted_bin_id 113 | bin_id.scatter_(dim=0, index=sorted_bin_id, src=sorted_id) 114 | bin_id = bin_id.view((nbins, max_n_atom_per_bin)) 115 | return bin_id 116 | 117 | 118 | def linked_cell( 119 | pos: torch.Tensor, 120 | cell: torch.Tensor, 121 | cutoff: float, 122 | num_repeats: torch.Tensor, 123 | self_interaction: bool = False, 124 | ) -> Tuple[torch.Tensor, torch.Tensor]: 125 | """Determine the atomic neighborhood of the atoms of a given structure for a particular cutoff using the linked cell algorithm. 126 | 127 | Parameters 128 | ---------- 129 | pos : torch.Tensor [n_atom, 3] 130 | atomic positions in the unit cell (positions outside the cell boundaries will result in an undifined behaviour) 131 | cell : torch.Tensor [3, 3] 132 | unit cell vectors in the format V=[v_0, v_1, v_2] 133 | cutoff : float 134 | length used to determine neighborhood 135 | num_repeats : torch.Tensor [3] 136 | number of unit cell repetitions in each directions required to account for PBC 137 | self_interaction : bool, optional 138 | to keep the original atoms as their own neighbor, by default False 139 | 140 | Returns 141 | ------- 142 | Tuple[torch.Tensor, torch.Tensor] 143 | neigh_atom : [2, n_neighbors] 144 | indices of the original atoms (neigh_atom[0]) with their neighbor index (neigh_atom[1]). The indices are meant to access the provided position array 145 | neigh_shift_idx : [n_neighbors, 3] 146 | cell shift indices to be used in reconstructing the neighbor atom positions. 147 | """ 148 | device = pos.device 149 | dtype = pos.dtype 150 | n_atom = pos.shape[0] 151 | # find all the integer shifts of the unit cell given the cutoff and periodicity 152 | shifts_idx = get_cell_shift_idx(num_repeats, dtype) 153 | n_cell_image = shifts_idx.shape[0] 154 | shifts_idx = torch.repeat_interleave( 155 | shifts_idx, n_atom, dim=0, output_size=n_atom * n_cell_image 156 | ) 157 | batch_image = torch.zeros((shifts_idx.shape[0]), dtype=torch.long) 158 | cell_shifts = compute_cell_shifts( 159 | cell.view(-1, 3, 3), shifts_idx, batch_image 160 | ) 161 | 162 | i_ids = torch.arange(n_atom, device=device, dtype=torch.long) 163 | i_ids = i_ids.repeat(n_cell_image) 164 | # compute the positions of the replicated unit cell (including the original) 165 | # they are organized such that: 1st n_atom are the non-shifted atom, 2nd n_atom are moved by the same translation, ... 166 | images = pos[i_ids] + cell_shifts 167 | n_images = images.shape[0] 168 | # create a rectangular box at [0,0,0] that encompases all the atoms (hence shifting the atoms so that they lie inside the box) 169 | b_min = images.min(dim=0).values 170 | b_max = images.max(dim=0).values 171 | images -= b_min - 1e-5 172 | box_length = b_max - b_min + 1e-3 173 | # divide the box into square bins of size cutoff in 3d 174 | nbins_s = torch.maximum(torch.ceil(box_length / cutoff), pos.new_ones(3)) 175 | # adapt the box lenghts so that it encompasses 176 | box_vec = torch.diag_embed(nbins_s * cutoff) 177 | nbins_s = nbins_s.to(torch.long) 178 | nbins = int(torch.prod(nbins_s)) 179 | # determine which bins the original atoms and the images belong to following a linear indexing of the 3d bins 180 | bin_index_j = get_linear_bin_idx(box_vec, images, nbins_s) 181 | n_atom_j_per_bin = torch.bincount(bin_index_j, minlength=nbins) 182 | max_n_atom_per_bin = int(n_atom_j_per_bin.max()) 183 | # convert the linear map bin_index_j into a 2d map. This allows for 184 | # fully vectorized neighbor assignment 185 | bin_id_j = scatter_bin_index( 186 | nbins, max_n_atom_per_bin, n_images, bin_index_j 187 | ) 188 | 189 | # find which bins the original atoms belong to 190 | bin_index_i = bin_index_j[:n_atom] 191 | i_bins_l = torch.unique(bin_index_i) 192 | i_bins_s = unravel_3d(i_bins_l, nbins_s) 193 | 194 | # find the bin indices in the neighborhood of i_bins_l. Since the bins have 195 | # a side length of cutoff only 27 bins are in the neighborhood 196 | # (including itself) 197 | dd = torch.tensor([0, 1, -1], dtype=torch.long, device=device) 198 | bin_shifts = torch.cartesian_prod(dd, dd, dd) 199 | n_neigh_bins = bin_shifts.shape[0] 200 | bin_shifts = bin_shifts.repeat((i_bins_s.shape[0], 1)) 201 | neigh_bins_s = ( 202 | torch.repeat_interleave( 203 | i_bins_s, 204 | n_neigh_bins, 205 | dim=0, 206 | output_size=n_neigh_bins * i_bins_s.shape[0], 207 | ) 208 | + bin_shifts 209 | ) 210 | # some of the generated bin_idx might not be valid 211 | mask = torch.all( 212 | torch.logical_and(neigh_bins_s < nbins_s.view(1, 3), neigh_bins_s >= 0), 213 | dim=1, 214 | ) 215 | 216 | # remove the bins that are outside of the search range, i.e. beyond the borders of the box in the case of non-periodic directions. 217 | neigh_j_bins_l = ravel_3d(neigh_bins_s[mask], nbins_s) 218 | 219 | max_neigh_per_atom = max_n_atom_per_bin * n_neigh_bins 220 | # the i_bin related to neigh_j_bins_l 221 | repeats = mask.view(-1, n_neigh_bins).sum(dim=1) 222 | neigh_i_bins_l = torch.cat( 223 | [ 224 | torch.arange(rr, device=device) + i_bins_l[ii] * n_neigh_bins 225 | for ii, rr in enumerate(repeats) 226 | ], 227 | dim=0, 228 | ) 229 | # the linear neighborlist. make it at large as necessary 230 | neigh_atom = torch.empty( 231 | (2, n_atom * max_neigh_per_atom), dtype=torch.long, device=device 232 | ) 233 | # fill the i_atom index 234 | neigh_atom[0] = ( 235 | torch.arange(n_atom).view(-1, 1).repeat(1, max_neigh_per_atom).view(-1) 236 | ) 237 | # relate `bin_index` (row) with the `neighbor_atom_index` (stored in the columns). empty entries are set to `n_images` 238 | bin_id_ij = torch.full( 239 | (nbins * n_neigh_bins, max_n_atom_per_bin), 240 | n_images, 241 | dtype=torch.long, 242 | device=device, 243 | ) 244 | # fill the bins with neighbor atom indices 245 | bin_id_ij[neigh_i_bins_l] = bin_id_j[neigh_j_bins_l] 246 | bin_id_ij = bin_id_ij.view((nbins, max_neigh_per_atom)) 247 | # map the neighbors in the bins to the central atoms 248 | neigh_atom[1] = bin_id_ij[bin_index_i].view(-1) 249 | # remove empty entries 250 | neigh_atom = neigh_atom[:, neigh_atom[1] != n_images] 251 | 252 | if not self_interaction: 253 | # neighbor atoms are still indexed from 0 to n_atom*n_cell_image 254 | neigh_atom = neigh_atom[:, neigh_atom[0] != neigh_atom[1]] 255 | 256 | # sort neighbor list so that the i_atom indices increase 257 | sorted_ids = torch.argsort(neigh_atom[0]) 258 | neigh_atom = neigh_atom[:, sorted_ids] 259 | # get the cell shift indices for each neighbor atom 260 | neigh_shift_idx = shifts_idx[neigh_atom[1]] 261 | # make sure the j_atom indices access the original positions 262 | neigh_atom[1] = torch.remainder(neigh_atom[1], n_atom) 263 | # print(neigh_atom) 264 | return neigh_atom, neigh_shift_idx 265 | 266 | 267 | def build_linked_cell_neighborhood( 268 | positions: torch.Tensor, 269 | cell: torch.Tensor, 270 | pbc: torch.Tensor, 271 | cutoff: float, 272 | n_atoms: torch.Tensor, 273 | self_interaction: bool = False, 274 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 275 | """Build the neighborlist of a given set of atomic structures using the linked cell algorithm. 276 | 277 | Parameters 278 | ---------- 279 | positions : torch.Tensor [-1, 3] 280 | set of atomic positions for each structures 281 | cell : torch.Tensor [3*n_structure, 3] 282 | set of unit cell vectors for each structures 283 | pbc : torch.Tensor [n_structures, 3] bool 284 | periodic boundary conditions to apply 285 | cutoff : float 286 | length used to determine neighborhood 287 | n_atoms : torch.Tensor 288 | number of atoms in each structures 289 | self_interaction : bool 290 | to keep the original atoms as their own neighbor 291 | 292 | Returns 293 | ------- 294 | Tuple[torch.Tensor, torch.Tensor, torch.Tensor] 295 | mapping : [2, n_neighbors] 296 | indices of the neighbor list for the given positions array, mapping[0/1] correspond respectively to the central/neighbor atom (or node in the graph terminology) 297 | batch_mapping : [n_neighbors] 298 | indices mapping the neighbor atom to each structures 299 | cell_shifts_idx : [n_neighbors, 3] 300 | cell shift indices to be used in reconstructing the neighbor atom positions. 301 | """ 302 | 303 | n_structure = n_atoms.shape[0] 304 | device = positions.device 305 | cell = cell.view((-1, 3, 3)) 306 | pbc = pbc.view((-1, 3)) 307 | # compute the number of cell replica necessary so that all the unit cell's atom have a complete neighborhood (no MIC assumed here) 308 | num_repeats = get_number_of_cell_repeats(cutoff, cell, pbc) 309 | 310 | stride = strides_of(n_atoms) 311 | 312 | mapping, batch_mapping, cell_shifts_idx = [], [], [] 313 | for i_structure in range(n_structure): 314 | # compute the neighborhood with the linked cell algorithm 315 | neigh_atom, neigh_shift_idx = linked_cell( 316 | positions[stride[i_structure] : stride[i_structure + 1]], 317 | cell[i_structure], 318 | cutoff, 319 | num_repeats[i_structure], 320 | self_interaction, 321 | ) 322 | 323 | batch_mapping.append( 324 | i_structure 325 | * torch.ones(neigh_atom.shape[1], dtype=torch.long, device=device) 326 | ) 327 | # shift the mapping indices so that they can access positions 328 | mapping.append(neigh_atom + stride[i_structure]) 329 | cell_shifts_idx.append(neigh_shift_idx) 330 | return ( 331 | torch.cat(mapping, dim=1), 332 | torch.cat(batch_mapping, dim=0), 333 | torch.cat(cell_shifts_idx, dim=0), 334 | ) 335 | -------------------------------------------------------------------------------- /torch_nl/naive_impl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Tuple 3 | 4 | from .utils import get_number_of_cell_repeats, get_cell_shift_idx, strides_of 5 | 6 | 7 | def get_fully_connected_mapping( 8 | i_ids: torch.Tensor, shifts_idx: torch.Tensor, self_interaction: bool 9 | ) -> Tuple[torch.Tensor, torch.Tensor]: 10 | n_atom = i_ids.shape[0] 11 | n_atom2 = n_atom * n_atom 12 | n_cell_image = shifts_idx.shape[0] 13 | j_ids = torch.repeat_interleave( 14 | i_ids, n_cell_image, dim=0, output_size=n_cell_image * n_atom 15 | ) 16 | mapping = torch.cartesian_prod(i_ids, j_ids) 17 | shifts_idx = shifts_idx.repeat((n_atom2, 1)) 18 | if not self_interaction: 19 | mask = torch.ones( 20 | mapping.shape[0], dtype=torch.bool, device=i_ids.device 21 | ) 22 | ids = n_cell_image * torch.arange( 23 | n_atom, device=i_ids.device 24 | ) + torch.arange( 25 | 0, mapping.shape[0], n_atom * n_cell_image, device=i_ids.device 26 | ) 27 | mask[ids] = False 28 | mapping = mapping[mask, :] 29 | shifts_idx = shifts_idx[mask] 30 | return mapping, shifts_idx 31 | 32 | 33 | def build_naive_neighborhood( 34 | positions: torch.Tensor, 35 | cell: torch.Tensor, 36 | pbc: torch.Tensor, 37 | cutoff: float, 38 | n_atoms: torch.Tensor, 39 | self_interaction: bool, 40 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 41 | """TODO: add doc""" 42 | device = positions.device 43 | dtype = positions.dtype 44 | 45 | num_repeats_ = get_number_of_cell_repeats(cutoff, cell, pbc) 46 | 47 | stride = strides_of(n_atoms) 48 | ids = torch.arange(positions.shape[0], device=device, dtype=torch.long) 49 | 50 | mapping, batch_mapping, shifts_idx_ = [], [], [] 51 | for i_structure in range(n_atoms.shape[0]): 52 | num_repeats = num_repeats_[i_structure] 53 | shifts_idx = get_cell_shift_idx(num_repeats, dtype) 54 | i_ids = ids[stride[i_structure] : stride[i_structure + 1]] 55 | 56 | s_mapping, shifts_idx = get_fully_connected_mapping( 57 | i_ids, shifts_idx, self_interaction 58 | ) 59 | mapping.append(s_mapping) 60 | batch_mapping.append( 61 | torch.full( 62 | (s_mapping.shape[0],), 63 | i_structure, 64 | dtype=torch.long, 65 | device=device, 66 | ) 67 | ) 68 | shifts_idx_.append(shifts_idx) 69 | return ( 70 | torch.cat(mapping, dim=0).t(), 71 | torch.cat(batch_mapping, dim=0), 72 | torch.cat(shifts_idx_, dim=0), 73 | ) 74 | -------------------------------------------------------------------------------- /torch_nl/neighbor_list.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .naive_impl import build_naive_neighborhood 4 | from .geometry import compute_cell_shifts 5 | from .linked_cell import build_linked_cell_neighborhood 6 | 7 | 8 | def strict_nl( 9 | cutoff: float, 10 | pos: torch.Tensor, 11 | cell: torch.Tensor, 12 | mapping: torch.Tensor, 13 | batch_mapping: torch.Tensor, 14 | shifts_idx: torch.Tensor, 15 | ): 16 | """Apply a strict cutoff to the neighbor list defined in mapping. 17 | 18 | Parameters 19 | ---------- 20 | cutoff : _type_ 21 | _description_ 22 | pos : _type_ 23 | _description_ 24 | cell : _type_ 25 | _description_ 26 | mapping : _type_ 27 | _description_ 28 | batch_mapping : _type_ 29 | _description_ 30 | shifts_idx : _type_ 31 | _description_ 32 | 33 | Returns 34 | ------- 35 | _type_ 36 | _description_ 37 | """ 38 | cell_shifts = compute_cell_shifts(cell, shifts_idx, batch_mapping) 39 | if cell_shifts is None: 40 | d2 = (pos[mapping[0]] - pos[mapping[1]]).square().sum(dim=1) 41 | else: 42 | d2 = ( 43 | (pos[mapping[0]] - pos[mapping[1]] - cell_shifts) 44 | .square() 45 | .sum(dim=1) 46 | ) 47 | 48 | mask = d2 < cutoff * cutoff 49 | mapping = mapping[:, mask] 50 | mapping_batch = batch_mapping[mask] 51 | shifts_idx = shifts_idx[mask] 52 | return mapping, mapping_batch, shifts_idx 53 | 54 | 55 | @torch.jit.script 56 | def compute_neighborlist_n2( 57 | cutoff: float, 58 | pos: torch.Tensor, 59 | cell: torch.Tensor, 60 | pbc: torch.Tensor, 61 | batch: torch.Tensor, 62 | self_interaction: bool = False, 63 | ): 64 | """Compute the neighborlist for a set of atomic structures using the naive a neighbor search before applying a strict `cutoff`. The atoms positions 65 | `pos` should be wrapped inside their respective unit cells. 66 | 67 | Parameters 68 | ---------- 69 | cutoff : float 70 | cutoff radius of used for the neighbor search 71 | pos : torch.Tensor [n_atom, 3] 72 | set of atoms positions wrapped inside their respective unit cells 73 | cell : torch.Tensor [3*n_structure, 3] 74 | unit cell vectors in the format [a_1, a_2, a_3] 75 | pbc : torch.Tensor [n_structure, 3] bool 76 | periodic boundary conditions to apply. Partial PBC are not supported yet 77 | batch : torch.Tensor torch.long [n_atom,] 78 | index of the structure in which the atom belongs to 79 | self_interaction : bool, optional 80 | to keep the center atoms as their own neighbor, by default False 81 | 82 | Returns 83 | ------- 84 | Tuple[torch.Tensor, torch.Tensor, torch.Tensor] 85 | mapping : [2, n_neighbors] 86 | indices of the neighbor list for the given positions array, mapping[0/1] correspond respectively to the central/neighbor atom (or node in the graph terminology) 87 | batch_mapping : [n_neighbors] 88 | indices mapping the neighbor atom to each structures 89 | shifts_idx : [n_neighbors, 3] 90 | cell shift indices to be used in reconstructing the neighbor atom positions. 91 | """ 92 | n_atoms = torch.bincount(batch) 93 | mapping, batch_mapping, shifts_idx = build_naive_neighborhood( 94 | pos, cell, pbc, cutoff, n_atoms, self_interaction 95 | ) 96 | mapping, mapping_batch, shifts_idx = strict_nl( 97 | cutoff, pos, cell, mapping, batch_mapping, shifts_idx 98 | ) 99 | return mapping, mapping_batch, shifts_idx 100 | 101 | 102 | @torch.jit.script 103 | def compute_neighborlist( 104 | cutoff: float, 105 | pos: torch.Tensor, 106 | cell: torch.Tensor, 107 | pbc: torch.Tensor, 108 | batch: torch.Tensor, 109 | self_interaction: bool = False, 110 | ): 111 | """Compute the neighborlist for a set of atomic structures using the linked 112 | cell algorithm before applying a strict `cutoff`. The atoms positions `pos` 113 | should be wrapped inside their respective unit cells. 114 | 115 | Parameters 116 | ---------- 117 | cutoff : float 118 | cutoff radius of used for the neighbor search 119 | pos : torch.Tensor [n_atom, 3] 120 | set of atoms positions wrapped inside their respective unit cells 121 | cell : torch.Tensor [3*n_structure, 3] 122 | unit cell vectors in the format [a_1, a_2, a_3] 123 | pbc : torch.Tensor [n_structure, 3] bool 124 | periodic boundary conditions to apply. Partial PBC are not supported yet 125 | batch : torch.Tensor torch.long [n_atom,] 126 | index of the structure in which the atom belongs to 127 | self_interaction : bool, optional 128 | to keep the center atoms as their own neighbor, by default False 129 | 130 | Returns 131 | ------- 132 | Tuple[torch.Tensor, torch.Tensor, torch.Tensor] 133 | mapping : [2, n_neighbors] 134 | indices of the neighbor list for the given positions array, mapping[0/1] correspond respectively to the central/neighbor atom (or node in the graph terminology) 135 | batch_mapping : [n_neighbors] 136 | indices mapping the neighbor atom to each structures 137 | shifts_idx : [n_neighbors, 3] 138 | cell shift indices to be used in reconstructing the neighbor atom positions. 139 | """ 140 | n_atoms = torch.bincount(batch) 141 | mapping, batch_mapping, shifts_idx = build_linked_cell_neighborhood( 142 | pos, cell, pbc, cutoff, n_atoms, self_interaction 143 | ) 144 | 145 | mapping, mapping_batch, shifts_idx = strict_nl( 146 | cutoff, pos, cell, mapping, batch_mapping, shifts_idx 147 | ) 148 | return mapping, mapping_batch, shifts_idx 149 | -------------------------------------------------------------------------------- /torch_nl/test_nl.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from ase.build import bulk, molecule 3 | import numpy as np 4 | from ase.neighborlist import neighbor_list 5 | from ase import Atoms 6 | import torch 7 | 8 | from .neighbor_list import ( 9 | compute_neighborlist_n2, 10 | compute_cell_shifts, 11 | compute_neighborlist, 12 | ) 13 | from .utils import ase2data 14 | from .geometry import compute_distances 15 | 16 | # triclinic atomic structure 17 | CaCrP2O7_mvc_11955_symmetrized = { 18 | "positions": [ 19 | [3.68954016, 5.03568186, 4.64369552], 20 | [5.12301681, 2.13482791, 2.66220405], 21 | [1.99411973, 0.94691001, 1.25068234], 22 | [6.81843724, 6.22359976, 6.05521724], 23 | [2.63005662, 4.16863452, 0.86090529], 24 | [6.18250036, 3.00187525, 6.44499428], 25 | [2.11497733, 1.98032773, 4.53610884], 26 | [6.69757964, 5.19018203, 2.76979073], 27 | [1.39215545, 2.94386142, 5.60917746], 28 | [7.42040152, 4.22664834, 1.69672212], 29 | [2.43224207, 5.4571615, 6.70305327], 30 | [6.3803149, 1.71334827, 0.6028463], 31 | [1.11265639, 1.50166318, 3.48760997], 32 | [7.69990058, 5.66884659, 3.8182896], 33 | [3.56971588, 5.20836551, 1.43673437], 34 | [5.2428411, 1.96214426, 5.8691652], 35 | [3.12282634, 2.72812741, 1.05450432], 36 | [5.68973063, 4.44238236, 6.25139525], 37 | [3.24868468, 2.83997522, 3.99842386], 38 | [5.56387229, 4.33053455, 3.30747571], 39 | [2.60835346, 0.74421609, 5.3236629], 40 | [6.20420351, 6.42629368, 1.98223667], 41 | ], 42 | "cell": [ 43 | [6.19330899, 0.0, 0.0], 44 | [2.4074486111396207, 6.149627748674982, 0.0], 45 | [0.2117993724186579, 1.0208820183960539, 7.305899571570074], 46 | ], 47 | "numbers": [ 48 | 20, 49 | 20, 50 | 24, 51 | 24, 52 | 15, 53 | 15, 54 | 15, 55 | 15, 56 | 8, 57 | 8, 58 | 8, 59 | 8, 60 | 8, 61 | 8, 62 | 8, 63 | 8, 64 | 8, 65 | 8, 66 | 8, 67 | 8, 68 | 8, 69 | 8, 70 | ], 71 | "pbc": [True, True, True], 72 | } 73 | 74 | 75 | def bulk_metal(): 76 | frames = [ 77 | bulk("Si", "diamond", a=6, cubic=True), 78 | bulk("Si", "diamond", a=6), 79 | bulk("Cu", "fcc", a=3.6), 80 | bulk("Si", "bct", a=6, c=3), 81 | # test very skewed unit cell 82 | bulk("Bi", "rhombohedral", a=6, alpha=20), 83 | bulk("Bi", "rhombohedral", a=6, alpha=10), 84 | bulk("Bi", "rhombohedral", a=6, alpha=5), 85 | bulk("SiCu", "rocksalt", a=6), 86 | bulk("SiFCu", "fluorite", a=6), 87 | Atoms(**CaCrP2O7_mvc_11955_symmetrized), 88 | ] 89 | return frames 90 | 91 | 92 | def atomic_structures(): 93 | frames = ( 94 | [ 95 | molecule("CH3CH2NH2"), 96 | molecule("H2O"), 97 | molecule("methylenecyclopropane"), 98 | ] 99 | + bulk_metal() 100 | + [ 101 | molecule("OCHCHO"), 102 | molecule("C3H9C"), 103 | ] 104 | ) 105 | return frames 106 | 107 | 108 | @pytest.mark.parametrize( 109 | "frames, cutoff, self_interaction", 110 | [ 111 | (atomic_structures(), rc, self_interaction) 112 | for rc in [1, 3, 5, 7] 113 | for self_interaction in [True, False] 114 | ], 115 | ) 116 | def test_neighborlist_n2(frames, cutoff, self_interaction): 117 | """Check that torch_neighbor_list gives the same NL as ASE by comparing 118 | the resulting sorted list of distances between neighbors.""" 119 | pos, cell, pbc, batch, n_atoms = ase2data(frames) 120 | 121 | dds = [] 122 | mapping, batch_mapping, shifts_idx = compute_neighborlist_n2( 123 | cutoff, pos, cell, pbc, batch, self_interaction 124 | ) 125 | cell_shifts = compute_cell_shifts(cell, shifts_idx, batch_mapping) 126 | dds = compute_distances(pos, mapping, cell_shifts) 127 | dds = np.sort(dds.numpy()) 128 | 129 | dd_ref = [] 130 | for frame in frames: 131 | idx_i, idx_j, idx_S, dist = neighbor_list( 132 | "ijSd", frame, cutoff=cutoff, self_interaction=self_interaction 133 | ) 134 | dd_ref.extend(dist) 135 | dd_ref = np.sort(dd_ref) 136 | 137 | np.testing.assert_allclose(dd_ref, dds) 138 | 139 | 140 | @pytest.mark.parametrize( 141 | "frames, cutoff, self_interaction", 142 | [ 143 | (atomic_structures(), rc, self_interaction) 144 | # for rc in [3] #[1, 3, 5, 7] 145 | # for self_interaction in [False] 146 | for rc in [1, 3, 5, 7] 147 | for self_interaction in [False, True] 148 | ], 149 | ) 150 | def test_neighborlist_linked_cell(frames, cutoff, self_interaction): 151 | """Check that torch_neighbor_list gives the same NL as ASE by comparing 152 | the resulting sorted list of distances between neighbors.""" 153 | pos, cell, pbc, batch, n_atoms = ase2data(frames) 154 | 155 | dds = [] 156 | mapping, batch_mapping, shifts_idx = compute_neighborlist( 157 | cutoff, pos, cell, pbc, batch, self_interaction 158 | ) 159 | cell_shifts = compute_cell_shifts(cell, shifts_idx, batch_mapping) 160 | dds = compute_distances(pos, mapping, cell_shifts) 161 | dds = np.sort(dds.numpy()) 162 | 163 | dd_ref = [] 164 | for frame in frames: 165 | idx_i, idx_j, idx_S, dist = neighbor_list( 166 | "ijSd", frame, cutoff=cutoff, self_interaction=self_interaction 167 | ) 168 | dd_ref.extend(dist) 169 | # nice for understanding if something goes wrong 170 | idx_S = torch.from_numpy(idx_S).to(torch.float64) 171 | 172 | print("idx_i", idx_i) 173 | print("idx_j", idx_j) 174 | missing_entries = [] 175 | for ineigh in range(idx_i.shape[0]): 176 | mask = torch.logical_and( 177 | idx_i[ineigh] == mapping[0], idx_j[ineigh] == mapping[1] 178 | ) 179 | 180 | if torch.any(torch.all(idx_S[ineigh] == shifts_idx[mask], dim=1)): 181 | pass 182 | else: 183 | missing_entries.append( 184 | (idx_i[ineigh], idx_j[ineigh], idx_S[ineigh]) 185 | ) 186 | print(missing_entries[-1]) 187 | print( 188 | compute_cell_shifts( 189 | cell, 190 | idx_S[ineigh].view((1, -1)), 191 | torch.tensor([0], dtype=torch.long), 192 | ) 193 | ) 194 | 195 | dd_ref = np.sort(dd_ref) 196 | print(dd_ref[-20:]) 197 | print(dds[-20:]) 198 | np.testing.assert_allclose(dd_ref, dds) 199 | -------------------------------------------------------------------------------- /torch_nl/timer.py: -------------------------------------------------------------------------------- 1 | from timeit import default_timer as timer 2 | import numpy as np 3 | from typing import Mapping 4 | from types import GeneratorType 5 | 6 | 7 | def eval_func(func, inp): 8 | if isinstance(inp, Mapping): 9 | inner = lambda inp: func(**inp) 10 | elif isinstance(inp, GeneratorType): 11 | inner = lambda inp: func(*inp) 12 | else: 13 | inner = lambda inp: func(*inp) 14 | return inner 15 | 16 | 17 | def timeit(func, inp, tag="", warmup=10, nit=100): 18 | timer = Timer(tag=tag) 19 | inner = eval_func(func, inp) 20 | 21 | for _ in range(warmup): 22 | inner(inp) 23 | for _ in range(nit): 24 | with timer: 25 | inner(inp) 26 | return timer 27 | 28 | 29 | class Timer(object): 30 | def __init__(self, tag="", logger=None): 31 | self.tag = tag 32 | self.elapsed = [] 33 | self.start = None 34 | self.end = None 35 | 36 | def __enter__(self): 37 | self.start = timer() 38 | 39 | def __exit__(self, type, value, traceback): 40 | self.end = timer() 41 | self.elapsed.append(self.end - self.start) 42 | 43 | def mean(self): 44 | return np.mean(self.elapsed) 45 | 46 | def stdev(self): 47 | return np.std(self.elapsed) 48 | 49 | def min(self): 50 | return np.min(self.elapsed) 51 | 52 | def max(self): 53 | return np.max(self.elapsed) 54 | 55 | def samples(self): 56 | return self.elapsed 57 | 58 | def dumps(self): 59 | data = dict( 60 | tag=self.tag, 61 | mean=self.mean(), 62 | stdev=self.stdev(), 63 | min=self.min(), 64 | max=self.max(), 65 | samples=self.samples(), 66 | ) 67 | return data 68 | 69 | def __repr__(self) -> str: 70 | timings = self.dumps() 71 | return f'{timings["tag"]} ' + " / ".join( 72 | [ 73 | f"{k}={timings[k]*1000:.5f} [ms]" 74 | for k in ["mean", "stdev", "min", "max"] 75 | ] 76 | ) 77 | -------------------------------------------------------------------------------- /torch_nl/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.types import _dtype 4 | 5 | 6 | def ase2data(frames, device="cpu"): 7 | n_atoms = [0] 8 | pos = [] 9 | cell = [] 10 | pbc = [] 11 | for ff in frames: 12 | n_atoms.append(len(ff)) 13 | pos.append(torch.from_numpy(ff.get_positions())) 14 | cell.append(torch.from_numpy(ff.get_cell().array)) 15 | pbc.append(torch.from_numpy(ff.get_pbc())) 16 | pos = torch.cat(pos) 17 | cell = torch.cat(cell) 18 | pbc = torch.cat(pbc) 19 | stride = torch.from_numpy(np.cumsum(n_atoms)) 20 | batch = torch.zeros(pos.shape[0], dtype=torch.long) 21 | for ii, (st, nd) in enumerate(zip(stride[:-1], stride[1:])): 22 | batch[st:nd] = ii 23 | n_atoms = torch.Tensor(n_atoms[1:]).to(dtype=torch.long) 24 | return ( 25 | pos.to(device=device), 26 | cell.to(device=device), 27 | pbc.to(device=device), 28 | batch.to(device=device), 29 | n_atoms.to(device=device), 30 | ) 31 | 32 | 33 | def strides_of(v: torch.Tensor) -> torch.Tensor: 34 | v = v.flatten() 35 | stride = v.new_empty(v.shape[0] + 1) 36 | stride[0] = 0 37 | torch.cumsum(v, dim=0, dtype=stride.dtype, out=stride[1:]) 38 | return stride 39 | 40 | 41 | def get_number_of_cell_repeats( 42 | cutoff: float, cell: torch.Tensor, pbc: torch.Tensor 43 | ) -> torch.Tensor: 44 | cell = cell.view((-1, 3, 3)) 45 | pbc = pbc.view((-1, 3)) 46 | 47 | has_pbc = pbc.prod(dim=1, dtype=torch.bool) 48 | reciprocal_cell = torch.zeros_like(cell) 49 | reciprocal_cell[has_pbc, :, :] = torch.linalg.inv( 50 | cell[has_pbc, :, :] 51 | ).transpose(2, 1) 52 | inv_distances = reciprocal_cell.norm(2, dim=-1) 53 | num_repeats = torch.ceil(cutoff * inv_distances).to(torch.long) 54 | num_repeats_ = torch.where(pbc, num_repeats, torch.zeros_like(num_repeats)) 55 | return num_repeats_ 56 | 57 | 58 | def get_cell_shift_idx( 59 | num_repeats: torch.Tensor, dtype: _dtype 60 | ) -> torch.Tensor: 61 | reps = [] 62 | for ii in range(3): 63 | r1 = torch.arange( 64 | -num_repeats[ii], 65 | num_repeats[ii] + 1, 66 | device=num_repeats.device, 67 | dtype=dtype, 68 | ) 69 | _, indices = torch.sort(torch.abs(r1)) 70 | reps.append(r1[indices]) 71 | shifts_idx = torch.cartesian_prod(reps[0], reps[1], reps[2]) 72 | return shifts_idx 73 | --------------------------------------------------------------------------------