├── .clang-format ├── .github ├── dependabot.yml └── workflows │ ├── mirror-gitlab.yaml │ ├── release.yaml │ └── test.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yml ├── CMakeLists.txt ├── LICENSE ├── README.md ├── build_backend ├── __init__.py └── dp_backend.py ├── deepmd_gnn ├── __init__.py ├── __main__.py ├── argcheck.py ├── env.py ├── mace.py ├── nequip.py ├── op.py └── py.typed ├── docs ├── Makefile ├── conf.py ├── index.rst └── parameters.rst ├── examples ├── .gitignore ├── dprc │ ├── data │ │ ├── nopbc │ │ ├── set.000 │ │ │ ├── box.npy │ │ │ ├── coord.npy │ │ │ ├── energy.npy │ │ │ └── force.npy │ │ ├── type.raw │ │ └── type_map.raw │ ├── mace │ │ └── input.json │ └── nequip │ │ └── input.json └── water │ ├── data │ ├── data_0 │ │ ├── set.000 │ │ │ ├── box.npy │ │ │ ├── coord.npy │ │ │ ├── energy.npy │ │ │ └── force.npy │ │ ├── type.raw │ │ └── type_map.raw │ ├── data_1 │ │ ├── set.000 │ │ │ ├── box.npy │ │ │ ├── coord.npy │ │ │ ├── energy.npy │ │ │ └── force.npy │ │ ├── set.001 │ │ │ ├── box.npy │ │ │ ├── coord.npy │ │ │ ├── energy.npy │ │ │ └── force.npy │ │ ├── type.raw │ │ └── type_map.raw │ ├── data_2 │ │ ├── set.000 │ │ │ ├── box.npy │ │ │ ├── coord.npy │ │ │ ├── energy.npy │ │ │ └── force.npy │ │ ├── type.raw │ │ └── type_map.raw │ └── data_3 │ │ ├── set.000 │ │ ├── box.npy │ │ ├── coord.npy │ │ ├── energy.npy │ │ └── force.npy │ │ ├── type.raw │ │ └── type_map.raw │ ├── mace │ └── input.json │ └── nequip │ └── input.json ├── noxfile.py ├── op ├── CMakeLists.txt └── edge_index.cc ├── pyproject.toml ├── renovate.json └── tests ├── __init__.py ├── data ├── set.000 │ ├── box.npy │ ├── coord.npy │ ├── energy.npy │ └── force.npy ├── type.raw └── type_map.raw ├── mace.json ├── nequip.json ├── test_examples.py ├── test_model.py ├── test_op.py ├── test_training.py └── test_version.py /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | BasedOnStyle: Google 3 | BinPackParameters: 
false 4 | InsertBraces: true 5 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "github-actions" 4 | directory: "/" 5 | schedule: 6 | interval: "weekly" 7 | -------------------------------------------------------------------------------- /.github/workflows/mirror-gitlab.yaml: -------------------------------------------------------------------------------- 1 | name: Mirror to GitLab Repo 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | tags: 8 | - v* 9 | 10 | # Ensures that only one mirror task will run at a time. 11 | concurrency: 12 | group: git-mirror 13 | 14 | jobs: 15 | git-mirror: 16 | if: github.repository_owner == 'njzjz' 17 | runs-on: ubuntu-latest 18 | steps: 19 | - uses: wearerequired/git-mirror-action@v1 20 | env: 21 | SSH_PRIVATE_KEY: ${{ secrets.SYNC_GITLAB_PRIVATE_KEY }} 22 | with: 23 | source-repo: "https://github.com/njzjz/deepmd-gnn" 24 | destination-repo: "git@gitlab.com:RutgersLBSR/deepmd-gnn" 25 | -------------------------------------------------------------------------------- /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | branches: 4 | - master 5 | tags: 6 | - "v*" 7 | pull_request: 8 | name: Build and release to pypi 9 | permissions: 10 | contents: read 11 | jobs: 12 | release-build: 13 | name: Build and upload distributions 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v4 17 | - name: Set up uv 18 | uses: astral-sh/setup-uv@v5 19 | with: 20 | enable-cache: true 21 | cache-dependency-glob: | 22 | **/requirements*.txt 23 | **/pyproject.toml 24 | - name: Build dist 25 | run: uv tool run --with build[uv] --from build python -m build --installer uv --sdist 26 | - name: Upload release distributions 27 | uses: actions/upload-artifact@v4 28 | 
with: 29 | name: release-dists 30 | path: dist/ 31 | release-wheel: 32 | name: Build wheels for cp${{ matrix.python }}-${{ matrix.platform_id }} 33 | runs-on: ${{ matrix.os }} 34 | strategy: 35 | fail-fast: false 36 | matrix: 37 | include: 38 | # linux-64 39 | - os: ubuntu-latest 40 | python: 312 41 | platform_id: manylinux_x86_64 42 | # linux-aarch64 43 | - os: ubuntu-24.04-arm 44 | python: 312 45 | platform_id: manylinux_aarch64 46 | # macos-x86-64 47 | - os: macos-13 48 | python: 312 49 | platform_id: macosx_x86_64 50 | # macos-arm64 51 | - os: macos-14 52 | python: 312 53 | platform_id: macosx_arm64 54 | # win-64 55 | - os: windows-2022 56 | python: 312 57 | platform_id: win_amd64 58 | steps: 59 | - uses: actions/checkout@v4 60 | - name: Set up uv 61 | uses: astral-sh/setup-uv@v5 62 | with: 63 | enable-cache: true 64 | cache-dependency-glob: | 65 | **/requirements*.txt 66 | **/pyproject.toml 67 | if: runner.os != 'Linux' 68 | - name: Build wheels 69 | uses: pypa/cibuildwheel@v2.23 70 | env: 71 | CIBW_ARCHS: all 72 | CIBW_BUILD: cp${{ matrix.python }}-${{ matrix.platform_id }} 73 | CIBW_BUILD_FRONTEND: "build[uv]" 74 | - uses: actions/upload-artifact@v4 75 | with: 76 | name: release-cibw-cp${{ matrix.python }}-${{ matrix.platform_id }}-${{ strategy.job-index }} 77 | path: ./wheelhouse/*.whl 78 | pypi-publish: 79 | name: Release to pypi 80 | runs-on: ubuntu-latest 81 | environment: 82 | name: pypi_publish 83 | url: https://pypi.org/p/deepmd-gnn 84 | if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') 85 | needs: 86 | - release-build 87 | - release-wheel 88 | permissions: 89 | id-token: write 90 | steps: 91 | - name: Retrieve release distributions 92 | uses: actions/download-artifact@v4 93 | with: 94 | merge-multiple: true 95 | pattern: release-* 96 | path: dist/ 97 | - name: Publish release distributions to PyPI 98 | uses: pypa/gh-action-pypi-publish@release/v1 99 | 
-------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: Test Python package 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | 9 | env: 10 | UV_SYSTEM_PYTHON: 1 11 | jobs: 12 | build: 13 | runs-on: ubuntu-latest 14 | permissions: 15 | id-token: write 16 | contents: read 17 | steps: 18 | - uses: actions/checkout@v4 19 | - name: Set up uv 20 | uses: astral-sh/setup-uv@v5 21 | with: 22 | enable-cache: true 23 | cache-dependency-glob: | 24 | **/requirements*.txt 25 | **/pyproject.toml 26 | - name: Set up Python 27 | uses: actions/setup-python@v5 28 | with: 29 | python-version: 3.12 30 | - name: Install dependencies 31 | run: uv pip install nox[uv] 32 | - name: Test with pytest 33 | run: nox -db uv 34 | - name: Upload coverage reports to Codecov 35 | uses: codecov/codecov-action@v5 36 | with: 37 | use_oidc: ${{ !(github.event_name == 'pull_request' && github.event.pull_request.head.repo.fork) }} 38 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | .ruff_cache/ 164 | node_modules/ 165 | deepmd_gnn/_version.py 166 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v5.0.0 6 | hooks: 7 | - id: trailing-whitespace 8 | - id: end-of-file-fixer 9 | - id: check-yaml 10 | - id: check-json 11 | - id: check-added-large-files 12 | - id: check-merge-conflict 13 | - id: check-symlinks 14 | - id: check-toml 15 | - id: mixed-line-ending 16 | # Python 17 | - repo: https://github.com/astral-sh/ruff-pre-commit 18 | rev: v0.11.7 19 | hooks: 20 | - id: ruff 21 | args: ["--fix"] 22 | - id: ruff-format 23 | - repo: https://github.com/pre-commit/mirrors-mypy 24 | rev: "v1.15.0" 25 | hooks: 26 | - id: mypy 27 | additional_dependencies: [] 28 | - repo: https://github.com/pre-commit/mirrors-prettier 29 | rev: v4.0.0-alpha.8 30 | hooks: 31 | - id: prettier 32 | types_or: [javascript, css, html, markdown, yaml] 33 | # C++ 34 | - repo: https://github.com/pre-commit/mirrors-clang-format 35 | rev: v20.1.0 36 | hooks: 37 | - id: clang-format 38 | # CMake 39 | - repo: https://github.com/cheshirekow/cmake-format-precommit 40 | rev: v0.6.13 41 | hooks: 42 | - id: cmake-format 43 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | build: 3 | os: ubuntu-24.04 4 | tools: 5 | python: "3.12" 6 | jobs: 7 | install: 8 | - asdf plugin add uv 9 | - asdf install uv latest 10 | - asdf global uv latest 11 | - uv venv $READTHEDOCS_VIRTUALENV_PATH 12 | - VIRTUAL_ENV=$READTHEDOCS_VIRTUALENV_PATH uv pip install "deepmd-kit[torch]>=3.0.0b2" --extra-index-url
https://download.pytorch.org/whl/cpu 13 | - VIRTUAL_ENV=$READTHEDOCS_VIRTUALENV_PATH CMAKE_PREFIX_PATH=$(python -c "import torch;print(torch.utils.cmake_prefix_path)") uv pip install -e .[docs] 14 | - $READTHEDOCS_VIRTUALENV_PATH/bin/python -m sphinx -T -b html -d docs/_build/doctrees -D language=en docs $READTHEDOCS_OUTPUT/html 15 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.15) 2 | project(deepmd-gnn CXX) 3 | 4 | set(CMAKE_CXX_STANDARD 14) 5 | macro(set_if_higher VARIABLE VALUE) 6 | # ${VARIABLE} is a variable name, not a string 7 | if(${VARIABLE} LESS "${VALUE}") 8 | set(${VARIABLE} ${VALUE}) 9 | endif() 10 | endmacro() 11 | 12 | # build cpp or python interfaces 13 | option(BUILD_CPP_IF "Build C++ interfaces" ON) 14 | option(BUILD_PY_IF "Build Python interfaces" OFF) 15 | option(USE_PT_PYTHON_LIBS "Use PyTorch Python libraries" OFF) 16 | 17 | if((NOT BUILD_PY_IF) AND (NOT BUILD_CPP_IF)) 18 | # nothing to do 19 | message(FATAL_ERROR "Nothing to build.") 20 | endif() 21 | 22 | if(BUILD_CPP_IF 23 | AND USE_PT_PYTHON_LIBS 24 | AND NOT CMAKE_CROSSCOMPILING 25 | AND NOT SKBUILD 26 | OR "$ENV{CIBUILDWHEEL}" STREQUAL "1") 27 | find_package( 28 | Python 29 | COMPONENTS Interpreter 30 | REQUIRED) 31 | execute_process( 32 | COMMAND ${Python_EXECUTABLE} -c 33 | "import torch;print(torch.utils.cmake_prefix_path)" 34 | WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} 35 | OUTPUT_VARIABLE PYTORCH_CMAKE_PREFIX_PATH 36 | RESULT_VARIABLE PYTORCH_CMAKE_PREFIX_PATH_RESULT_VAR 37 | ERROR_VARIABLE PYTORCH_CMAKE_PREFIX_PATH_ERROR_VAR 38 | OUTPUT_STRIP_TRAILING_WHITESPACE) 39 | if(NOT ${PYTORCH_CMAKE_PREFIX_PATH_RESULT_VAR} EQUAL 0) 40 | message( 41 | FATAL_ERROR 42 | "Cannot determine PyTorch CMake prefix path, error code: ${PYTORCH_CMAKE_PREFIX_PATH_RESULT_VAR}, error message: ${PYTORCH_CMAKE_PREFIX_PATH_ERROR_VAR}" 43 | ) 
44 | endif() 45 | list(APPEND CMAKE_PREFIX_PATH ${PYTORCH_CMAKE_PREFIX_PATH}) 46 | endif() 47 | find_package(Torch REQUIRED) 48 | if(NOT Torch_VERSION VERSION_LESS "2.1.0") 49 | set_if_higher(CMAKE_CXX_STANDARD 17) 50 | elseif(NOT Torch_VERSION VERSION_LESS "1.5.0") 51 | set_if_higher(CMAKE_CXX_STANDARD 14) 52 | endif() 53 | string(REGEX MATCH "_GLIBCXX_USE_CXX11_ABI=([0-9]+)" CXXABI_PT_MATCH 54 | "${TORCH_CXX_FLAGS}") 55 | if(CXXABI_PT_MATCH) 56 | set(OP_CXX_ABI_PT ${CMAKE_MATCH_1}) 57 | message(STATUS "PyTorch CXX11 ABI: ${CMAKE_MATCH_1}") 58 | else() 59 | # Maybe in macos/windows 60 | set(OP_CXX_ABI_PT 0) 61 | endif() 62 | 63 | # define build type 64 | if((NOT DEFINED CMAKE_BUILD_TYPE) OR CMAKE_BUILD_TYPE STREQUAL "") 65 | set(CMAKE_BUILD_TYPE release) 66 | endif() 67 | 68 | add_subdirectory(op) 69 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU LESSER GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | 9 | This version of the GNU Lesser General Public License incorporates 10 | the terms and conditions of version 3 of the GNU General Public 11 | License, supplemented by the additional permissions listed below. 12 | 13 | 0. Additional Definitions. 14 | 15 | As used herein, "this License" refers to version 3 of the GNU Lesser 16 | General Public License, and the "GNU GPL" refers to version 3 of the GNU 17 | General Public License. 18 | 19 | "The Library" refers to a covered work governed by this License, 20 | other than an Application or a Combined Work as defined below. 21 | 22 | An "Application" is any work that makes use of an interface provided 23 | by the Library, but which is not otherwise based on the Library. 
24 | Defining a subclass of a class defined by the Library is deemed a mode 25 | of using an interface provided by the Library. 26 | 27 | A "Combined Work" is a work produced by combining or linking an 28 | Application with the Library. The particular version of the Library 29 | with which the Combined Work was made is also called the "Linked 30 | Version". 31 | 32 | The "Minimal Corresponding Source" for a Combined Work means the 33 | Corresponding Source for the Combined Work, excluding any source code 34 | for portions of the Combined Work that, considered in isolation, are 35 | based on the Application, and not on the Linked Version. 36 | 37 | The "Corresponding Application Code" for a Combined Work means the 38 | object code and/or source code for the Application, including any data 39 | and utility programs needed for reproducing the Combined Work from the 40 | Application, but excluding the System Libraries of the Combined Work. 41 | 42 | 1. Exception to Section 3 of the GNU GPL. 43 | 44 | You may convey a covered work under sections 3 and 4 of this License 45 | without being bound by section 3 of the GNU GPL. 46 | 47 | 2. Conveying Modified Versions. 48 | 49 | If you modify a copy of the Library, and, in your modifications, a 50 | facility refers to a function or data to be supplied by an Application 51 | that uses the facility (other than as an argument passed when the 52 | facility is invoked), then you may convey a copy of the modified 53 | version: 54 | 55 | a) under this License, provided that you make a good faith effort to 56 | ensure that, in the event an Application does not supply the 57 | function or data, the facility still operates, and performs 58 | whatever part of its purpose remains meaningful, or 59 | 60 | b) under the GNU GPL, with none of the additional permissions of 61 | this License applicable to that copy. 62 | 63 | 3. Object Code Incorporating Material from Library Header Files. 
64 | 65 | The object code form of an Application may incorporate material from 66 | a header file that is part of the Library. You may convey such object 67 | code under terms of your choice, provided that, if the incorporated 68 | material is not limited to numerical parameters, data structure 69 | layouts and accessors, or small macros, inline functions and templates 70 | (ten or fewer lines in length), you do both of the following: 71 | 72 | a) Give prominent notice with each copy of the object code that the 73 | Library is used in it and that the Library and its use are 74 | covered by this License. 75 | 76 | b) Accompany the object code with a copy of the GNU GPL and this license 77 | document. 78 | 79 | 4. Combined Works. 80 | 81 | You may convey a Combined Work under terms of your choice that, 82 | taken together, effectively do not restrict modification of the 83 | portions of the Library contained in the Combined Work and reverse 84 | engineering for debugging such modifications, if you also do each of 85 | the following: 86 | 87 | a) Give prominent notice with each copy of the Combined Work that 88 | the Library is used in it and that the Library and its use are 89 | covered by this License. 90 | 91 | b) Accompany the Combined Work with a copy of the GNU GPL and this license 92 | document. 93 | 94 | c) For a Combined Work that displays copyright notices during 95 | execution, include the copyright notice for the Library among 96 | these notices, as well as a reference directing the user to the 97 | copies of the GNU GPL and this license document. 
98 | 99 | d) Do one of the following: 100 | 101 | 0) Convey the Minimal Corresponding Source under the terms of this 102 | License, and the Corresponding Application Code in a form 103 | suitable for, and under terms that permit, the user to 104 | recombine or relink the Application with a modified version of 105 | the Linked Version to produce a modified Combined Work, in the 106 | manner specified by section 6 of the GNU GPL for conveying 107 | Corresponding Source. 108 | 109 | 1) Use a suitable shared library mechanism for linking with the 110 | Library. A suitable mechanism is one that (a) uses at run time 111 | a copy of the Library already present on the user's computer 112 | system, and (b) will operate properly with a modified version 113 | of the Library that is interface-compatible with the Linked 114 | Version. 115 | 116 | e) Provide Installation Information, but only if you would otherwise 117 | be required to provide such information under section 6 of the 118 | GNU GPL, and only to the extent that such information is 119 | necessary to install and execute a modified version of the 120 | Combined Work produced by recombining or relinking the 121 | Application with a modified version of the Linked Version. (If 122 | you use option 4d0, the Installation Information must accompany 123 | the Minimal Corresponding Source and Corresponding Application 124 | Code. If you use option 4d1, you must provide the Installation 125 | Information in the manner specified by section 6 of the GNU GPL 126 | for conveying Corresponding Source.) 127 | 128 | 5. Combined Libraries. 
129 | 130 | You may place library facilities that are a work based on the 131 | Library side by side in a single library together with other library 132 | facilities that are not Applications and are not covered by this 133 | License, and convey such a combined library under terms of your 134 | choice, if you do both of the following: 135 | 136 | a) Accompany the combined library with a copy of the same work based 137 | on the Library, uncombined with any other library facilities, 138 | conveyed under the terms of this License. 139 | 140 | b) Give prominent notice with the combined library that part of it 141 | is a work based on the Library, and explaining where to find the 142 | accompanying uncombined form of the same work. 143 | 144 | 6. Revised Versions of the GNU Lesser General Public License. 145 | 146 | The Free Software Foundation may publish revised and/or new versions 147 | of the GNU Lesser General Public License from time to time. Such new 148 | versions will be similar in spirit to the present version, but may 149 | differ in detail to address new problems or concerns. 150 | 151 | Each version is given a distinguishing version number. If the 152 | Library as you received it specifies that a certain numbered version 153 | of the GNU Lesser General Public License "or any later version" 154 | applies to it, you have the option of following the terms and 155 | conditions either of that published version or of any later version 156 | published by the Free Software Foundation. If the Library as you 157 | received it does not specify a version number of the GNU Lesser 158 | General Public License, you may choose any version of the GNU Lesser 159 | General Public License ever published by the Free Software Foundation. 
160 | 161 | If the Library as you received it specifies that a proxy can decide 162 | whether future versions of the GNU Lesser General Public License shall 163 | apply, that proxy's public statement of acceptance of any version is 164 | permanent authorization for you to choose that version for the 165 | Library. 166 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeePMD-kit plugin for various graph neural network models 2 | 3 | [![DOI:10.1021/acs.jcim.4c02441](https://img.shields.io/badge/DOI-10.1021%2Facs.jcim.4c02441-blue)](https://doi.org/10.1021/acs.jcim.4c02441) 4 | [![Citations](https://citations.njzjz.win/10.1021/acs.jcim.4c02441)](https://doi.org/10.1021/acs.jcim.4c02441) 5 | [![conda install](https://img.shields.io/conda/dn/conda-forge/deepmd-gnn?label=conda%20install)](https://anaconda.org/conda-forge/deepmd-gnn) 6 | [![PyPI - Version](https://img.shields.io/pypi/v/deepmd-gnn)](https://pypi.org/p/deepmd-gnn) 7 | 8 | `deepmd-gnn` is a [DeePMD-kit](https://github.com/deepmodeling/deepmd-kit) plugin for various graph neural network (GNN) models, which connects DeePMD-kit and atomistic GNN packages by enabling GNN models in DeePMD-kit. 9 | 10 | Supported packages and models include: 11 | 12 | - [MACE](https://github.com/ACEsuit/mace) (PyTorch version) 13 | - [NequIP](https://github.com/mir-group/nequip) (PyTorch version) 14 | 15 | After [installing the plugin](#installation), you can train the GNN models using DeePMD-kit, run active learning cycles for the GNN models using [DP-GEN](https://github.com/deepmodeling/dpgen), and perform simulations with the GNN models using molecular dynamics packages supported by DeePMD-kit, such as [LAMMPS](https://github.com/lammps/lammps) and [AMBER](https://ambermd.org/). 
16 | You can follow the [DeePMD-kit documentation](https://docs.deepmodeling.com/projects/deepmd/en/latest/) to train the GNN models with its PyTorch backend, using the [model parameters](#parameters) specific to each GNN model. 17 | 18 | ## Credits 19 | 20 | If you use this software, please cite the following paper: 21 | 22 | - Jinzhe Zeng, Timothy J. Giese, Duo Zhang, Han Wang, Darrin M. York, DeePMD-GNN: A DeePMD-kit Plugin for External Graph Neural Network Potentials, _J. Chem. Inf. Model._, 2025, 65, 7, 3154-3160, DOI: [10.1021/acs.jcim.4c02441](https://doi.org/10.1021/acs.jcim.4c02441). [![Citations](https://citations.njzjz.win/10.1021/acs.jcim.4c02441)](https://badge.dimensions.ai/details/doi/10.1021/acs.jcim.4c02441) 23 | 24 | ## Installation 25 | 26 | ### Install via conda 27 | 28 | If you are in a [conda environment](https://docs.deepmodeling.com/faq/conda.html) where DeePMD-kit is already installed from the conda-forge channel, 29 | you can use `conda` to install the DeePMD-GNN plugin: 30 | 31 | ```sh 32 | conda install deepmd-gnn -c conda-forge 33 | ``` 34 | 35 | ### Build from source 36 | 37 | First, clone this repository: 38 | 39 | ```sh 40 | git clone https://gitlab.com/RutgersLBSR/deepmd-gnn 41 | cd deepmd-gnn 42 | ``` 43 | 44 | #### Python interface plugin 45 | 46 | Python 3.9 or above is required. A C++ compiler that supports C++ 14 (for PyTorch 2.0) or C++ 17 (for PyTorch 2.1 or above) is required. 47 | 48 | Assuming you have installed [DeePMD-kit](https://github.com/deepmodeling/deepmd-kit) (v3.0.0b2 or above) and [PyTorch](https://github.com/pytorch/pytorch) in an environment, execute: 49 | 50 | ```sh 51 | # expose PyTorch CMake modules 52 | export CMAKE_PREFIX_PATH=$(python -c "import torch;print(torch.utils.cmake_prefix_path)") 53 | 54 | pip install . 55 | ``` 56 | 57 | #### C++ interface plugin 58 | 59 | DeePMD-kit v3.0.0b4 or later is required. 
60 | 61 | Follow the [DeePMD-kit documentation](https://docs.deepmodeling.com/projects/deepmd/en/latest/install/install-from-source.html#install-the-c-interface) to install the DeePMD-kit C++ interface with PyTorch backend support, along with other related MD packages. 62 | After that, you can build the plugin: 63 | 64 | ```sh 65 | # Assume libtorch is already in CMAKE_PREFIX_PATH 66 | mkdir -p build 67 | cd build 68 | cmake .. -D CMAKE_INSTALL_PREFIX=/prefix/to/install 69 | cmake --build . -j8 70 | cmake --install . 71 | ``` 72 | 73 | `libdeepmd_gnn.so` will be installed into the `lib` directory under the install prefix you specify. 74 | When using any DeePMD-kit C++ interface, set the following environment variable in advance: 75 | 76 | ```sh 77 | export DP_PLUGIN_PATH=/prefix/to/install/lib/libdeepmd_gnn.so 78 | ``` 79 | 80 | ## Usage 81 | 82 | Follow the [Parameters section](#parameters) to prepare a DeePMD-kit input file. 83 | 84 | ```sh 85 | dp --pt train input.json 86 | dp --pt freeze 87 | ``` 88 | 89 | A frozen model file named `frozen_model.pth` will be generated. You can use it in the MD packages or other interfaces. 90 | For details, see the [DeePMD-kit documentation](https://docs.deepmodeling.com/projects/deepmd/en/latest/). 91 | 92 | ### Running LAMMPS + MACE with periodic boundary conditions 93 | 94 | GNN models use message passing neural networks, 95 | so a neighbor list built with the traditional cutoff radius will not work, 96 | since neighbor lists must also be built for the ghost atoms. 97 | By default, the model requests the neighbor list with a cutoff radius of $r_c \times N_{L}$, 98 | where $r_c$ is set by `r_max` and $N_L$ is set by `num_interactions` (MACE) / `num_layers` (NequIP), 99 | and rebuilds the neighbor list for ghost atoms. 100 | However, this approach is very inefficient. 101 | 102 | The alternative approach for the MACE model (note: NequIP doesn't support this approach) is to use the mapping passed from LAMMPS, which does not support MPI. 
One needs to set `DP_GNN_USE_MAPPING` when freezing the models, 104 | 105 | ```sh 106 | DP_GNN_USE_MAPPING=1 dp --pt freeze 107 | ``` 108 | 109 | and request the mapping when using LAMMPS (also requires DeePMD-kit v3.0.0rc0 or above). 110 | By using the mapping, the ghost atoms will be mapped to the real atoms, 111 | so the regular neighbor list with a cutoff radius of $r_c$ can be used. 112 | 113 | ```lammps 114 | atom_modify map array 115 | ``` 116 | 117 | In the future, we will explore using MPI to communicate the neighbor list, 118 | although this approach requires invasive changes to external packages. 119 | 120 | ## Parameters 121 | 122 | ### MACE 123 | 124 | To use the MACE model, set `"type": "mace"` in the `model` section of the training script. 125 | Below are the default values for the MACE model, most of which follow the defaults in the MACE package: 126 | 127 | ```json 128 | "model": { 129 | "type": "mace", 130 | "type_map": [ 131 | "O", 132 | "H" 133 | ], 134 | "r_max": 5.0, 135 | "sel": "auto", 136 | "num_radial_basis": 8, 137 | "num_cutoff_basis": 5, 138 | "max_ell": 3, 139 | "interaction": "RealAgnosticResidualInteractionBlock", 140 | "num_interactions": 2, 141 | "hidden_irreps": "128x0e + 128x1o", 142 | "pair_repulsion": false, 143 | "distance_transform": "None", 144 | "correlation": 3, 145 | "gate": "silu", 146 | "MLP_irreps": "16x0e", 147 | "radial_type": "bessel", 148 | "radial_MLP": [64, 64, 64], 149 | "std": 1.0, 150 | "precision": "float32" 151 | } 152 | ``` 153 | 154 | ### NequIP 155 | 156 | ```json 157 | "model": { 158 | "type": "nequip", 159 | "type_map": [ 160 | "O", 161 | "H" 162 | ], 163 | "r_max": 5.0, 164 | "sel": "auto", 165 | "num_layers": 4, 166 | "l_max": 2, 167 | "num_features": 32, 168 | "nonlinearity_type": "gate", 169 | "parity": true, 170 | "num_basis": 8, 171 | "BesselBasis_trainable": true, 172 | "PolynomialCutoff_p": 6, 173 | "invariant_layers": 2, 174 | "invariant_neurons": 64, 175 | "use_sc": true, 176 | 
"irreps_edge_sh": "0e + 1e", 177 | "feature_irreps_hidden": "32x0o + 32x0e + 32x1o + 32x1e", 178 | "chemical_embedding_irreps_out": "32x0e", 179 | "conv_to_output_hidden_irreps_out": "16x0e", 180 | "precision": "float32" 181 | } 182 | ``` 183 | 184 | ## DPRc support 185 | 186 | In `deepmd-gnn`, the GNN model can be used in a [DPRc](https://docs.deepmodeling.com/projects/deepmd/en/latest/model/dprc.html) fashion. 187 | Types whose names start with `m` (such as `mH`), as well as `OW` and `HW`, are recognized as MM types. 188 | No edges are built between two MM atoms. 189 | Such GNN+DPRc models can be used directly in AmberTools24. 190 | 191 | ## Examples 192 | 193 | - [examples/water](examples/water) 194 | - [examples/dprc](examples/dprc) 195 | -------------------------------------------------------------------------------- /build_backend/__init__.py: -------------------------------------------------------------------------------- 1 | """Customized PEP-517 build backend.""" 2 | -------------------------------------------------------------------------------- /build_backend/dp_backend.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: LGPL-3.0-or-later 2 | """A PEP-517 backend to add customized dependencies.""" 3 | 4 | import os 5 | 6 | from scikit_build_core import build as _orig 7 | 8 | __all__ = [ 9 | "build_editable", 10 | "build_sdist", 11 | "build_wheel", 12 | "get_requires_for_build_editable", 13 | "get_requires_for_build_sdist", 14 | "get_requires_for_build_wheel", 15 | "prepare_metadata_for_build_editable", 16 | "prepare_metadata_for_build_wheel", 17 | ] 18 | 19 | 20 | def __dir__() -> list[str]: 21 | return __all__ 22 | 23 | 24 | prepare_metadata_for_build_wheel = _orig.prepare_metadata_for_build_wheel 25 | build_wheel = _orig.build_wheel 26 | build_sdist = _orig.build_sdist 27 | get_requires_for_build_sdist = _orig.get_requires_for_build_sdist 28 | prepare_metadata_for_build_editable = 
_orig.prepare_metadata_for_build_editable 29 | build_editable = _orig.build_editable 30 | get_requires_for_build_editable = _orig.get_requires_for_build_editable 31 | 32 | 33 | def cibuildwheel_dependencies() -> list[str]: 34 | if ( 35 | os.environ.get("CIBUILDWHEEL", "0") == "1" 36 | or os.environ.get("READTHEDOCS", "0") == "True" 37 | ): 38 | return [ 39 | "deepmd-kit[torch]>=3.0.0b2", 40 | ] 41 | return [] 42 | 43 | 44 | def get_requires_for_build_wheel( 45 | config_settings: dict, 46 | ) -> list[str]: 47 | """Return the dependencies for building a wheel.""" 48 | return ( 49 | _orig.get_requires_for_build_wheel(config_settings) 50 | + cibuildwheel_dependencies() 51 | ) 52 | -------------------------------------------------------------------------------- /deepmd_gnn/__init__.py: -------------------------------------------------------------------------------- 1 | """MACE plugin for DeePMD-kit.""" 2 | 3 | import os 4 | 5 | from ._version import __version__ 6 | from .argcheck import mace_model_args 7 | 8 | __email__ = "jinzhe.zeng@rutgers.edu" 9 | 10 | __all__ = [ 11 | "__version__", 12 | "mace_model_args", 13 | ] 14 | 15 | # make compatible with mace & e3nn & pytorch 2.6 16 | os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1" 17 | -------------------------------------------------------------------------------- /deepmd_gnn/__main__.py: -------------------------------------------------------------------------------- 1 | """Main entry point for the command line interface.""" 2 | 3 | if __name__ == "__main__": 4 | msg = "This module is not meant to be executed directly." 
5 | raise NotImplementedError(msg) 6 | -------------------------------------------------------------------------------- /deepmd_gnn/argcheck.py: -------------------------------------------------------------------------------- 1 | """Argument check for the MACE model.""" 2 | 3 | from dargs import Argument 4 | from deepmd.utils.argcheck import model_args_plugin 5 | 6 | 7 | @model_args_plugin.register("mace") 8 | def mace_model_args() -> Argument: 9 | """Arguments for the MACE model. 10 | 11 | Returns 12 | ------- 13 | Argument 14 | Arguments for the MACE model. 15 | """ 16 | doc_r_max = "distance cutoff (in Ang)" 17 | doc_num_radial_basis = "number of radial basis functions" 18 | doc_num_cutoff_basis = "number of basis functions for smooth cutoff" 19 | doc_max_ell = "highest ell of spherical harmonics" 20 | doc_interaction = "name of interaction block" 21 | doc_num_interactions = "number of interactions" 22 | doc_hidden_irreps = "hidden irreps" 23 | doc_pair_repulsion = "use pair repulsion term with ZBL potential" 24 | doc_distance_transform = "distance transform" 25 | doc_correlation = "correlation order at each layer" 26 | doc_gate = "non linearity for last readout" 27 | doc_mlp_irreps = "hidden irreps of the MLP in last readout" 28 | doc_radial_type = "type of radial basis functions" 29 | doc_radial_mlp = "width of the radial MLP" 30 | doc_std = "Standard deviation of force components in the training set" 31 | doc_precision = "Precision of the model, float32 or float64" 32 | return Argument( 33 | "mace", 34 | dict, 35 | [ 36 | Argument("sel", [int, str], optional=False), 37 | Argument("r_max", float, optional=True, default=5.0, doc=doc_r_max), 38 | Argument( 39 | "num_radial_basis", 40 | int, 41 | optional=True, 42 | default=8, 43 | doc=doc_num_radial_basis, 44 | ), 45 | Argument( 46 | "num_cutoff_basis", 47 | int, 48 | optional=True, 49 | default=5, 50 | doc=doc_num_cutoff_basis, 51 | ), 52 | Argument("max_ell", int, optional=True, default=3, doc=doc_max_ell), 53 
| Argument( 54 | "interaction", 55 | str, 56 | optional=True, 57 | default="RealAgnosticResidualInteractionBlock", 58 | doc=doc_interaction, 59 | ), 60 | Argument( 61 | "num_interactions", 62 | int, 63 | optional=True, 64 | default=2, 65 | doc=doc_num_interactions, 66 | ), 67 | Argument( 68 | "hidden_irreps", 69 | str, 70 | optional=True, 71 | default="128x0e + 128x1o", 72 | doc=doc_hidden_irreps, 73 | ), 74 | Argument( 75 | "pair_repulsion", 76 | bool, 77 | optional=True, 78 | default=False, 79 | doc=doc_pair_repulsion, 80 | ), 81 | Argument( 82 | "distance_transform", 83 | str, 84 | optional=True, 85 | default="None", 86 | doc=doc_distance_transform, 87 | ), 88 | Argument("correlation", int, optional=True, default=3, doc=doc_correlation), 89 | Argument("gate", str, optional=True, default="silu", doc=doc_gate), 90 | Argument( 91 | "MLP_irreps", 92 | str, 93 | optional=True, 94 | default="16x0e", 95 | doc=doc_mlp_irreps, 96 | ), 97 | Argument( 98 | "radial_type", 99 | str, 100 | optional=True, 101 | default="bessel", 102 | doc=doc_radial_type, 103 | ), 104 | Argument( 105 | "radial_MLP", 106 | list[int], 107 | optional=True, 108 | default=[64, 64, 64], 109 | doc=doc_radial_mlp, 110 | ), 111 | Argument("std", float, optional=True, doc=doc_std, default=1), 112 | Argument( 113 | "precision", 114 | str, 115 | optional=True, 116 | default="float32", 117 | doc=doc_precision, 118 | ), 119 | ], 120 | doc="MACE model", 121 | ) 122 | 123 | 124 | @model_args_plugin.register("nequip") 125 | def nequip_model_args() -> Argument: 126 | """Arguments for the NequIP model.""" 127 | doc_sel = "Maximum number of neighbor atoms." 
128 | doc_r_max = "distance cutoff (in Ang)" 129 | doc_num_layers = "number of interaction blocks, we find 3-5 to work best" 130 | doc_l_max = "the maximum irrep order (rotation order) for the network's features, l=1 is a good default, l=2 is more accurate but slower" 131 | doc_num_features = "the multiplicity of the features, 32 is a good default for accurate network, if you want to be more accurate, go larger, if you want to be faster, go lower" 132 | doc_nonlinearity_type = "may be 'gate' or 'norm', 'gate' is recommended" 133 | doc_parity = "whether to include features with odd mirror parityy; often turning parity off gives equally good results but faster networks, so do consider this" 134 | doc_num_basis = ( 135 | "number of basis functions used in the radial basis, 8 usually works best" 136 | ) 137 | doc_besselbasis_trainable = "set true to train the bessel weights" 138 | doc_polynomialcutoff_p = "p-exponent used in polynomial cutoff function, smaller p corresponds to stronger decay with distance" 139 | doc_invariant_layers = ( 140 | "number of radial layers, usually 1-3 works best, smaller is faster" 141 | ) 142 | doc_invariant_neurons = ( 143 | "number of hidden neurons in radial function, smaller is faster" 144 | ) 145 | doc_use_sc = "use self-connection or not, usually gives big improvement" 146 | doc_irreps_edge_sh = "irreps for the chemical embedding of species" 147 | doc_feature_irreps_hidden = "irreps used for hidden features, here we go up to lmax=1, with even and odd parities; for more accurate but slower networks, use l=2 or higher, smaller number of features is faster" 148 | doc_chemical_embedding_irreps_out = "irreps of the spherical harmonics used for edges. 
If a single integer, indicates the full SH up to L_max=that_integer" 149 | doc_conv_to_output_hidden_irreps_out = "irreps used in hidden layer of output block" 150 | doc_precision = "Precision of the model, float32 or float64" 151 | return Argument( 152 | "nequip", 153 | dict, 154 | [ 155 | Argument( 156 | "sel", 157 | [int, str], 158 | optional=False, 159 | doc=doc_sel, 160 | ), 161 | Argument( 162 | "r_max", 163 | float, 164 | optional=True, 165 | default=6.0, 166 | doc=doc_r_max, 167 | ), 168 | Argument( 169 | "num_layers", 170 | int, 171 | optional=True, 172 | default=4, 173 | doc=doc_num_layers, 174 | ), 175 | Argument( 176 | "l_max", 177 | int, 178 | optional=True, 179 | default=2, 180 | doc=doc_l_max, 181 | ), 182 | Argument( 183 | "num_features", 184 | int, 185 | optional=True, 186 | default=32, 187 | doc=doc_num_features, 188 | ), 189 | Argument( 190 | "nonlinearity_type", 191 | str, 192 | optional=True, 193 | default="gate", 194 | doc=doc_nonlinearity_type, 195 | ), 196 | Argument( 197 | "parity", 198 | bool, 199 | optional=True, 200 | default=True, 201 | doc=doc_parity, 202 | ), 203 | Argument( 204 | "num_basis", 205 | int, 206 | optional=True, 207 | default=8, 208 | doc=doc_num_basis, 209 | ), 210 | Argument( 211 | "BesselBasis_trainable", 212 | bool, 213 | optional=True, 214 | default=True, 215 | doc=doc_besselbasis_trainable, 216 | ), 217 | Argument( 218 | "PolynomialCutoff_p", 219 | int, 220 | optional=True, 221 | default=6, 222 | doc=doc_polynomialcutoff_p, 223 | ), 224 | Argument( 225 | "invariant_layers", 226 | int, 227 | optional=True, 228 | default=2, 229 | doc=doc_invariant_layers, 230 | ), 231 | Argument( 232 | "invariant_neurons", 233 | int, 234 | optional=True, 235 | default=64, 236 | doc=doc_invariant_neurons, 237 | ), 238 | Argument( 239 | "use_sc", 240 | bool, 241 | optional=True, 242 | default=True, 243 | doc=doc_use_sc, 244 | ), 245 | Argument( 246 | "irreps_edge_sh", 247 | str, 248 | optional=True, 249 | default="0e + 1e", 250 | 
doc=doc_irreps_edge_sh, 251 | ), 252 | Argument( 253 | "feature_irreps_hidden", 254 | str, 255 | optional=True, 256 | default="32x0o + 32x0e + 32x1o + 32x1e", 257 | doc=doc_feature_irreps_hidden, 258 | ), 259 | Argument( 260 | "chemical_embedding_irreps_out", 261 | str, 262 | optional=True, 263 | default="32x0e", 264 | doc=doc_chemical_embedding_irreps_out, 265 | ), 266 | Argument( 267 | "conv_to_output_hidden_irreps_out", 268 | str, 269 | optional=True, 270 | default="16x0e", 271 | doc=doc_conv_to_output_hidden_irreps_out, 272 | ), 273 | Argument( 274 | "precision", 275 | str, 276 | optional=True, 277 | default="float32", 278 | doc=doc_precision, 279 | ), 280 | ], 281 | doc="Nequip model", 282 | ) 283 | -------------------------------------------------------------------------------- /deepmd_gnn/env.py: -------------------------------------------------------------------------------- 1 | """Configurations read from environment variables.""" 2 | 3 | import os 4 | 5 | DP_GNN_USE_MAPPING = os.environ.get("DP_GNN_USE_MAPPING", "0") == "1" 6 | -------------------------------------------------------------------------------- /deepmd_gnn/mace.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: LGPL-3.0-or-later 2 | """Wrapper for MACE models.""" 3 | 4 | import json 5 | from copy import deepcopy 6 | from typing import Any, Optional 7 | 8 | import torch 9 | from deepmd.dpmodel.output_def import ( 10 | FittingOutputDef, 11 | ModelOutputDef, 12 | OutputVariableDef, 13 | ) 14 | from deepmd.pt.model.model.model import ( 15 | BaseModel, 16 | ) 17 | from deepmd.pt.model.model.transform_output import ( 18 | communicate_extended_output, 19 | ) 20 | from deepmd.pt.utils import env 21 | from deepmd.pt.utils.nlist import ( 22 | build_neighbor_list, 23 | extend_input_and_build_neighbor_list, 24 | ) 25 | from deepmd.pt.utils.stat import ( 26 | compute_output_stats, 27 | ) 28 | from deepmd.pt.utils.update_sel import ( 29 | 
UpdateSel, 30 | ) 31 | from deepmd.pt.utils.utils import ( 32 | to_numpy_array, 33 | to_torch_tensor, 34 | ) 35 | from deepmd.utils.data_system import ( 36 | DeepmdDataSystem, 37 | ) 38 | from deepmd.utils.path import ( 39 | DPPath, 40 | ) 41 | from deepmd.utils.version import ( 42 | check_version_compatibility, 43 | ) 44 | from e3nn import ( 45 | o3, 46 | ) 47 | from e3nn.util.jit import ( 48 | script, 49 | ) 50 | from mace.modules import ( 51 | ScaleShiftMACE, 52 | gate_dict, 53 | interaction_classes, 54 | ) 55 | 56 | import deepmd_gnn.op # noqa: F401 57 | from deepmd_gnn import env as deepmd_gnn_env 58 | 59 | ELEMENTS = [ 60 | "H", 61 | "He", 62 | "Li", 63 | "Be", 64 | "B", 65 | "C", 66 | "N", 67 | "O", 68 | "F", 69 | "Ne", 70 | "Na", 71 | "Mg", 72 | "Al", 73 | "Si", 74 | "P", 75 | "S", 76 | "Cl", 77 | "Ar", 78 | "K", 79 | "Ca", 80 | "Sc", 81 | "Ti", 82 | "V", 83 | "Cr", 84 | "Mn", 85 | "Fe", 86 | "Co", 87 | "Ni", 88 | "Cu", 89 | "Zn", 90 | "Ga", 91 | "Ge", 92 | "As", 93 | "Se", 94 | "Br", 95 | "Kr", 96 | "Rb", 97 | "Sr", 98 | "Y", 99 | "Zr", 100 | "Nb", 101 | "Mo", 102 | "Tc", 103 | "Ru", 104 | "Rh", 105 | "Pd", 106 | "Ag", 107 | "Cd", 108 | "In", 109 | "Sn", 110 | "Sb", 111 | "Te", 112 | "I", 113 | "Xe", 114 | "Cs", 115 | "Ba", 116 | "La", 117 | "Ce", 118 | "Pr", 119 | "Nd", 120 | "Pm", 121 | "Sm", 122 | "Eu", 123 | "Gd", 124 | "Tb", 125 | "Dy", 126 | "Ho", 127 | "Er", 128 | "Tm", 129 | "Yb", 130 | "Lu", 131 | "Hf", 132 | "Ta", 133 | "W", 134 | "Re", 135 | "Os", 136 | "Ir", 137 | "Pt", 138 | "Au", 139 | "Hg", 140 | "Tl", 141 | "Pb", 142 | "Bi", 143 | "Po", 144 | "At", 145 | "Rn", 146 | "Fr", 147 | "Ra", 148 | "Ac", 149 | "Th", 150 | "Pa", 151 | "U", 152 | "Np", 153 | "Pu", 154 | "Am", 155 | "Cm", 156 | "Bk", 157 | "Cf", 158 | "Es", 159 | "Fm", 160 | "Md", 161 | "No", 162 | "Lr", 163 | "Rf", 164 | "Db", 165 | "Sg", 166 | "Bh", 167 | "Hs", 168 | "Mt", 169 | "Ds", 170 | "Rg", 171 | "Cn", 172 | "Nh", 173 | "Fl", 174 | "Mc", 175 | "Lv", 176 | "Ts", 177 | "Og", 178 
| ] 179 | 180 | PeriodicTable = { 181 | **{ee: ii + 1 for ii, ee in enumerate(ELEMENTS)}, 182 | **{f"m{ee}": ii + 1 for ii, ee in enumerate(ELEMENTS)}, 183 | "HW": 1, 184 | "OW": 8, 185 | } 186 | 187 | 188 | @BaseModel.register("mace") 189 | class MaceModel(BaseModel): 190 | """Mace model. 191 | 192 | Parameters 193 | ---------- 194 | type_map : list[str] 195 | The name of each type of atoms 196 | sel : int 197 | Maximum number of neighbor atoms 198 | r_max : float, optional 199 | distance cutoff (in Ang) 200 | num_radial_basis : int, optional 201 | number of radial basis functions 202 | num_cutoff_basis : int, optional 203 | number of basis functions for smooth cutoff 204 | max_ell : int, optional 205 | highest ell of spherical harmonics 206 | interaction : str, optional 207 | name of interaction block 208 | num_interactions : int, optional 209 | number of interactions 210 | hidden_irreps : str, optional 211 | hidden irreps 212 | pair_repulsion : bool 213 | use amsgrad variant of optimizer 214 | distance_transform : str, optional 215 | distance transform 216 | correlation : int 217 | correlation order at each layer 218 | gate : str, optional 219 | non linearity for last readout 220 | MLP_irreps : str, optional 221 | hidden irreps of the MLP in last readout 222 | radial_type : str, optional 223 | type of radial basis functions 224 | radial_MLP : str, optional 225 | width of the radial MLP 226 | std : float, optional 227 | Standard deviation of force components in the training set 228 | """ 229 | 230 | mm_types: list[int] 231 | 232 | def __init__( 233 | self, 234 | type_map: list[str], 235 | sel: int, 236 | r_max: float = 5.0, 237 | num_radial_basis: int = 8, 238 | num_cutoff_basis: int = 5, 239 | max_ell: int = 3, 240 | interaction: str = "RealAgnosticResidualInteractionBlock", 241 | num_interactions: int = 2, 242 | hidden_irreps: str = "128x0e + 128x1o", 243 | pair_repulsion: bool = False, 244 | distance_transform: str = "None", 245 | correlation: int = 3, 246 | 
gate: str = "silu", 247 | MLP_irreps: str = "16x0e", 248 | radial_type: str = "bessel", 249 | radial_MLP: list[int] = [64, 64, 64], # noqa: B006 250 | std: float = 1, 251 | **kwargs: Any, # noqa: ANN401 252 | ) -> None: 253 | super().__init__(**kwargs) 254 | self.params = { 255 | "type_map": type_map, 256 | "sel": sel, 257 | "r_max": r_max, 258 | "num_radial_basis": num_radial_basis, 259 | "num_cutoff_basis": num_cutoff_basis, 260 | "max_ell": max_ell, 261 | "interaction": interaction, 262 | "num_interactions": num_interactions, 263 | "hidden_irreps": hidden_irreps, 264 | "pair_repulsion": pair_repulsion, 265 | "distance_transform": distance_transform, 266 | "correlation": correlation, 267 | "gate": gate, 268 | "MLP_irreps": MLP_irreps, 269 | "radial_type": radial_type, 270 | "radial_MLP": radial_MLP, 271 | "std": std, 272 | } 273 | self.type_map = type_map 274 | self.ntypes = len(type_map) 275 | self.rcut = r_max 276 | self.num_interactions = num_interactions 277 | atomic_numbers = [] 278 | self.preset_out_bias: dict[str, list] = {"energy": []} 279 | self.mm_types = [] 280 | self.sel = sel 281 | for ii, tt in enumerate(type_map): 282 | atomic_numbers.append(PeriodicTable[tt]) 283 | if not tt.startswith("m") and tt not in {"HW", "OW"}: 284 | self.preset_out_bias["energy"].append(None) 285 | else: 286 | self.preset_out_bias["energy"].append([0]) 287 | self.mm_types.append(ii) 288 | 289 | self.model = script( 290 | ScaleShiftMACE( 291 | r_max=r_max, 292 | num_bessel=num_radial_basis, 293 | num_polynomial_cutoff=num_cutoff_basis, 294 | max_ell=max_ell, 295 | interaction_cls=interaction_classes[interaction], 296 | num_interactions=num_interactions, 297 | num_elements=self.ntypes, 298 | hidden_irreps=o3.Irreps(hidden_irreps), 299 | atomic_energies=torch.zeros(self.ntypes), # pylint: disable=no-explicit-device,no-explicit-dtype 300 | avg_num_neighbors=sel, 301 | atomic_numbers=atomic_numbers, 302 | pair_repulsion=pair_repulsion, 303 | 
distance_transform=distance_transform, 304 | correlation=correlation, 305 | gate=gate_dict[gate], 306 | interaction_cls_first=interaction_classes[ 307 | "RealAgnosticInteractionBlock" 308 | ], 309 | MLP_irreps=o3.Irreps(MLP_irreps), 310 | atomic_inter_scale=std, 311 | atomic_inter_shift=0.0, 312 | radial_MLP=radial_MLP, 313 | radial_type=radial_type, 314 | ).to(env.DEVICE), 315 | ) 316 | self.atomic_numbers = atomic_numbers 317 | 318 | def compute_or_load_stat( 319 | self, 320 | sampled_func, # noqa: ANN001 321 | stat_file_path: Optional[DPPath] = None, 322 | ) -> None: 323 | """Compute or load the statistics parameters of the model. 324 | 325 | For example, mean and standard deviation of descriptors or the energy bias of 326 | the fitting net. When `sampled` is provided, all the statistics parameters will 327 | be calculated (or re-calculated for update), and saved in the 328 | `stat_file_path`(s). When `sampled` is not provided, it will check the existence 329 | of `stat_file_path`(s) and load the calculated statistics parameters. 330 | 331 | Parameters 332 | ---------- 333 | sampled_func 334 | The sampled data frames from different data systems. 335 | stat_file_path 336 | The path to the statistics files. 
337 | """ 338 | bias_out, _ = compute_output_stats( 339 | sampled_func, 340 | self.get_ntypes(), 341 | keys=["energy"], 342 | stat_file_path=stat_file_path, 343 | rcond=None, 344 | preset_bias=self.preset_out_bias, 345 | ) 346 | if "energy" in bias_out: 347 | self.model.atomic_energies_fn.atomic_energies = ( 348 | bias_out["energy"] 349 | .view(self.model.atomic_energies_fn.atomic_energies.shape) 350 | .to(self.model.atomic_energies_fn.atomic_energies.dtype) 351 | .to(self.model.atomic_energies_fn.atomic_energies.device) 352 | ) 353 | 354 | @torch.jit.export 355 | def fitting_output_def(self) -> FittingOutputDef: 356 | """Get the output def of developer implemented atomic models.""" 357 | return FittingOutputDef( 358 | [ 359 | OutputVariableDef( 360 | name="energy", 361 | shape=[1], 362 | reducible=True, 363 | r_differentiable=True, 364 | c_differentiable=True, 365 | ), 366 | ], 367 | ) 368 | 369 | @torch.jit.export 370 | def get_rcut(self) -> float: 371 | """Get the cut-off radius.""" 372 | if deepmd_gnn_env.DP_GNN_USE_MAPPING: 373 | return self.rcut 374 | return self.rcut * self.num_interactions 375 | 376 | @torch.jit.export 377 | def get_type_map(self) -> list[str]: 378 | """Get the type map.""" 379 | return self.type_map 380 | 381 | @torch.jit.export 382 | def get_sel(self) -> list[int]: 383 | """Return the number of selected atoms for each type.""" 384 | return [self.sel] 385 | 386 | @torch.jit.export 387 | def get_dim_fparam(self) -> int: 388 | """Get the number (dimension) of frame parameters of this atomic model.""" 389 | return 0 390 | 391 | @torch.jit.export 392 | def get_dim_aparam(self) -> int: 393 | """Get the number (dimension) of atomic parameters of this atomic model.""" 394 | return 0 395 | 396 | @torch.jit.export 397 | def get_sel_type(self) -> list[int]: 398 | """Get the selected atom types of this model. 399 | 400 | Only atoms with selected atom types have atomic contribution 401 | to the result of the model. 
402 | If returning an empty list, all atom types are selected. 403 | """ 404 | return [] 405 | 406 | @torch.jit.export 407 | def is_aparam_nall(self) -> bool: 408 | """Check whether the shape of atomic parameters is (nframes, nall, ndim). 409 | 410 | If False, the shape is (nframes, nloc, ndim). 411 | """ 412 | return False 413 | 414 | @torch.jit.export 415 | def mixed_types(self) -> bool: 416 | """Return whether the model is in mixed-types mode. 417 | 418 | If true, the model 419 | 1. assumes total number of atoms aligned across frames; 420 | 2. uses a neighbor list that does not distinguish different atomic types. 421 | If false, the model 422 | 1. assumes total number of atoms of each atom type aligned across frames; 423 | 2. uses a neighbor list that distinguishes different atomic types. 424 | """ 425 | return True 426 | 427 | @torch.jit.export 428 | def has_message_passing(self) -> bool: 429 | """Return whether the descriptor has message passing.""" 430 | return False 431 | 432 | @torch.jit.export 433 | def forward( 434 | self, 435 | coord: torch.Tensor, 436 | atype: torch.Tensor, 437 | box: Optional[torch.Tensor] = None, 438 | fparam: Optional[torch.Tensor] = None, 439 | aparam: Optional[torch.Tensor] = None, 440 | do_atomic_virial: bool = False, 441 | ) -> dict[str, torch.Tensor]: 442 | """Forward pass of the model. 443 | 444 | Parameters 445 | ---------- 446 | coord : torch.Tensor 447 | The coordinates of atoms. 448 | atype : torch.Tensor 449 | The atomic types of atoms. 450 | box : torch.Tensor, optional 451 | The box tensor. 452 | fparam : torch.Tensor, optional 453 | The frame parameters. 454 | aparam : torch.Tensor, optional 455 | The atomic parameters. 456 | do_atomic_virial : bool, optional 457 | Whether to compute atomic virial. 
458 | """ 459 | nloc = atype.shape[1] 460 | extended_coord, extended_atype, mapping, nlist = ( 461 | extend_input_and_build_neighbor_list( 462 | coord, 463 | atype, 464 | self.rcut, 465 | self.get_sel(), 466 | mixed_types=True, 467 | box=box, 468 | ) 469 | ) 470 | model_ret_lower = self.forward_lower_common( 471 | nloc, 472 | extended_coord, 473 | extended_atype, 474 | nlist, 475 | mapping=mapping, 476 | fparam=fparam, 477 | aparam=aparam, 478 | do_atomic_virial=do_atomic_virial, 479 | comm_dict=None, 480 | ) 481 | model_ret = communicate_extended_output( 482 | model_ret_lower, 483 | ModelOutputDef(self.fitting_output_def()), 484 | mapping, 485 | do_atomic_virial, 486 | ) 487 | model_predict = {} 488 | model_predict["atom_energy"] = model_ret["energy"] 489 | model_predict["energy"] = model_ret["energy_redu"] 490 | model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2) 491 | model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) 492 | if do_atomic_virial: 493 | model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3) 494 | return model_predict 495 | 496 | @torch.jit.export 497 | def forward_lower( 498 | self, 499 | extended_coord: torch.Tensor, 500 | extended_atype: torch.Tensor, 501 | nlist: torch.Tensor, 502 | mapping: Optional[torch.Tensor] = None, 503 | fparam: Optional[torch.Tensor] = None, 504 | aparam: Optional[torch.Tensor] = None, 505 | do_atomic_virial: bool = False, 506 | comm_dict: Optional[dict[str, torch.Tensor]] = None, 507 | ) -> dict[str, torch.Tensor]: 508 | """Forward lower pass of the model. 509 | 510 | Parameters 511 | ---------- 512 | extended_coord : torch.Tensor 513 | The extended coordinates of atoms. 514 | extended_atype : torch.Tensor 515 | The extended atomic types of atoms. 516 | nlist : torch.Tensor 517 | The neighbor list. 518 | mapping : torch.Tensor, optional 519 | The mapping tensor. 520 | fparam : torch.Tensor, optional 521 | The frame parameters. 
522 | aparam : torch.Tensor, optional 523 | The atomic parameters. 524 | do_atomic_virial : bool, optional 525 | Whether to compute atomic virial. 526 | comm_dict : dict[str, torch.Tensor], optional 527 | The communication dictionary. 528 | """ 529 | nloc = nlist.shape[1] 530 | nf, nall = extended_atype.shape 531 | # calculate nlist for ghost atoms, as LAMMPS does not calculate it 532 | if mapping is None and self.num_interactions > 1 and nloc < nall: 533 | if deepmd_gnn_env.DP_GNN_USE_MAPPING: 534 | # when setting DP_GNN_USE_MAPPING, ghost atoms are only built 535 | # for one message-passing layer 536 | msg = ( 537 | "When setting DP_GNN_USE_MAPPING, mapping is required. " 538 | "If you are using LAMMPS, set `atom_modify map yes`." 539 | ) 540 | raise ValueError(msg) 541 | nlist = build_neighbor_list( 542 | extended_coord.view(nf, -1), 543 | extended_atype, 544 | nall, 545 | self.rcut, 546 | self.sel, 547 | distinguish_types=False, 548 | ) 549 | 550 | model_ret = self.forward_lower_common( 551 | nloc, 552 | extended_coord, 553 | extended_atype, 554 | nlist, 555 | mapping, 556 | fparam, 557 | aparam, 558 | do_atomic_virial, 559 | comm_dict, 560 | ) 561 | model_predict = {} 562 | model_predict["atom_energy"] = model_ret["energy"] 563 | model_predict["energy"] = model_ret["energy_redu"] 564 | model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2) 565 | model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) 566 | if do_atomic_virial: 567 | model_predict["extended_virial"] = model_ret["energy_derv_c"].squeeze(-3) 568 | return model_predict 569 | 570 | def forward_lower_common( 571 | self, 572 | nloc: int, 573 | extended_coord: torch.Tensor, 574 | extended_atype: torch.Tensor, 575 | nlist: torch.Tensor, 576 | mapping: Optional[torch.Tensor] = None, 577 | fparam: Optional[torch.Tensor] = None, 578 | aparam: Optional[torch.Tensor] = None, 579 | do_atomic_virial: bool = False, # noqa: ARG002 580 | comm_dict: Optional[dict[str, 
torch.Tensor]] = None, 581 | ) -> dict[str, torch.Tensor]: 582 | """Forward lower common pass of the model. 583 | 584 | Parameters 585 | ---------- 586 | extended_coord : torch.Tensor 587 | The extended coordinates of atoms. 588 | extended_atype : torch.Tensor 589 | The extended atomic types of atoms. 590 | nlist : torch.Tensor 591 | The neighbor list. 592 | mapping : torch.Tensor, optional 593 | The mapping tensor. 594 | fparam : torch.Tensor, optional 595 | The frame parameters. 596 | aparam : torch.Tensor, optional 597 | The atomic parameters. 598 | do_atomic_virial : bool, optional 599 | Whether to compute atomic virial. 600 | comm_dict : dict[str, torch.Tensor], optional 601 | The communication dictionary. 602 | """ 603 | nf, nall = extended_atype.shape 604 | extended_coord = extended_coord.view(nf, nall, 3) 605 | extended_coord_ = extended_coord 606 | if fparam is not None: 607 | msg = "fparam is unsupported" 608 | raise ValueError(msg) 609 | if aparam is not None: 610 | msg = "aparam is unsupported" 611 | raise ValueError(msg) 612 | if comm_dict is not None: 613 | msg = "comm_dict is unsupported" 614 | raise ValueError(msg) 615 | nlist = nlist.to(torch.int64) 616 | extended_atype = extended_atype.to(torch.int64) 617 | nall = extended_coord.shape[1] 618 | 619 | # fake as one frame 620 | extended_coord_ff = extended_coord.view(nf * nall, 3) 621 | extended_atype_ff = extended_atype.view(nf * nall) 622 | edge_index = torch.ops.deepmd_gnn.edge_index( 623 | nlist, 624 | extended_atype, 625 | torch.tensor(self.mm_types, dtype=torch.int64, device="cpu"), 626 | ) 627 | edge_index = edge_index.T 628 | # to one hot 629 | indices = extended_atype_ff.unsqueeze(-1) 630 | oh = torch.zeros( 631 | (nf * nall, self.ntypes), 632 | device=extended_atype.device, 633 | dtype=torch.float64, 634 | ) 635 | # scatter_ is the in-place version of scatter 636 | oh.scatter_(dim=-1, index=indices, value=1) 637 | one_hot = oh.view((nf * nall, self.ntypes)) 638 | 639 | # cast to float32 640 
| default_dtype = self.model.atomic_energies_fn.atomic_energies.dtype 641 | extended_coord_ff = extended_coord_ff.to(default_dtype) 642 | extended_coord_ff.requires_grad_(True) # noqa: FBT003 643 | nedge = edge_index.shape[1] 644 | if self.num_interactions > 1 and mapping is not None and nloc < nall: 645 | # shift the edges for ghost atoms, and map the ghost atoms to real atoms 646 | mapping_ff = mapping.view(nf * nall) + torch.arange( 647 | 0, 648 | nf * nall, 649 | nall, 650 | dtype=mapping.dtype, 651 | device=mapping.device, 652 | ).unsqueeze(-1).expand(nf, nall).reshape(-1) 653 | shifts_atoms = extended_coord_ff - extended_coord_ff[mapping_ff] 654 | shifts = shifts_atoms[edge_index[1]] - shifts_atoms[edge_index[0]] 655 | edge_index = mapping_ff[edge_index] 656 | else: 657 | shifts = torch.zeros( 658 | (nedge, 3), 659 | dtype=torch.float64, 660 | device=extended_coord_.device, 661 | ) 662 | shifts = shifts.to(default_dtype) 663 | one_hot = one_hot.to(default_dtype) 664 | # it seems None is not allowed for data 665 | box = ( 666 | torch.eye( 667 | 3, 668 | dtype=extended_coord_ff.dtype, 669 | device=extended_coord_ff.device, 670 | ) 671 | * 1000.0 672 | ) 673 | 674 | ret = self.model.forward( 675 | { 676 | "positions": extended_coord_ff, 677 | "shifts": shifts, 678 | "cell": box, 679 | "edge_index": edge_index, 680 | "batch": torch.zeros( 681 | [nf * nall], 682 | dtype=torch.int64, 683 | device=extended_coord_ff.device, 684 | ), 685 | "node_attrs": one_hot, 686 | "ptr": torch.tensor( 687 | [0, nf * nall], 688 | dtype=torch.int64, 689 | device=extended_coord_ff.device, 690 | ), 691 | "weight": torch.tensor( 692 | [1.0], 693 | dtype=extended_coord_ff.dtype, 694 | device=extended_coord_ff.device, 695 | ), 696 | }, 697 | compute_force=False, 698 | compute_virials=False, 699 | compute_stress=False, 700 | compute_displacement=False, 701 | training=self.training, 702 | ) 703 | 704 | atom_energy = ret["node_energy"] 705 | if atom_energy is None: 706 | msg = "atom_energy 
is None" 707 | raise ValueError(msg) 708 | atom_energy = atom_energy.view(nf, nall).to(extended_coord_.dtype)[:, :nloc] 709 | energy = torch.sum(atom_energy, dim=1).view(nf, 1).to(extended_coord_.dtype) 710 | grad_outputs: list[Optional[torch.Tensor]] = [ 711 | torch.ones_like(energy), 712 | ] 713 | force = torch.autograd.grad( 714 | outputs=[energy], 715 | inputs=[extended_coord_ff], 716 | grad_outputs=grad_outputs, 717 | retain_graph=True, 718 | create_graph=self.training, 719 | )[0] 720 | if force is None: 721 | msg = "force is None" 722 | raise ValueError(msg) 723 | force = -force 724 | atomic_virial = force.unsqueeze(-1).to( 725 | extended_coord_.dtype, 726 | ) @ extended_coord_ff.unsqueeze(-2).to( 727 | extended_coord_.dtype, 728 | ) 729 | force = force.view(nf, nall, 3).to(extended_coord_.dtype) 730 | atomic_virial = atomic_virial.view(nf, nall, 1, 9) 731 | virial = torch.sum(atomic_virial, dim=1).view(nf, 9).to(extended_coord_.dtype) 732 | 733 | return { 734 | "energy_redu": energy.view(nf, 1), 735 | "energy_derv_r": force.view(nf, nall, 1, 3), 736 | "energy_derv_c_redu": virial.view(nf, 1, 9), 737 | # take the first nloc atoms to match other models 738 | "energy": atom_energy.view(nf, nloc, 1), 739 | # fake atom_virial 740 | "energy_derv_c": atomic_virial.view(nf, nall, 1, 9), 741 | } 742 | 743 | def serialize(self) -> dict: 744 | """Serialize the model.""" 745 | return { 746 | "@class": "Model", 747 | "@version": 1, 748 | "type": "mace", 749 | **self.params, 750 | "@variables": { 751 | kk: to_numpy_array(vv) for kk, vv in self.model.state_dict().items() 752 | }, 753 | } 754 | 755 | @classmethod 756 | def deserialize(cls, data: dict) -> "MaceModel": 757 | """Deserialize the model.""" 758 | data = data.copy() 759 | if not (data.pop("@class") == "Model" and data.pop("type") == "mace"): 760 | msg = "data is not a serialized MaceModel" 761 | raise ValueError(msg) 762 | check_version_compatibility(data.pop("@version"), 1, 1) 763 | variables = { 764 | kk: 
to_torch_tensor(vv) for kk, vv in data.pop("@variables").items() 765 | } 766 | model = cls(**data) 767 | model.model.load_state_dict(variables) 768 | return model 769 | 770 | @torch.jit.export 771 | def get_nnei(self) -> int: 772 | """Return the total number of selected neighboring atoms in cut-off radius.""" 773 | return self.sel 774 | 775 | @torch.jit.export 776 | def get_nsel(self) -> int: 777 | """Return the total number of selected neighboring atoms in cut-off radius.""" 778 | return self.sel 779 | 780 | @classmethod 781 | def update_sel( 782 | cls, 783 | train_data: DeepmdDataSystem, 784 | type_map: Optional[list[str]], 785 | local_jdata: dict, 786 | ) -> tuple[dict, Optional[float]]: 787 | """Update the selection and perform neighbor statistics. 788 | 789 | Parameters 790 | ---------- 791 | train_data : DeepmdDataSystem 792 | data used to do neighbor statictics 793 | type_map : list[str], optional 794 | The name of each type of atoms 795 | local_jdata : dict 796 | The local data refer to the current class 797 | 798 | Returns 799 | ------- 800 | dict 801 | The updated local data 802 | float 803 | The minimum distance between two atoms 804 | """ 805 | local_jdata_cpy = local_jdata.copy() 806 | min_nbor_dist, sel = UpdateSel().update_one_sel( 807 | train_data, 808 | type_map, 809 | local_jdata_cpy["r_max"], 810 | local_jdata_cpy["sel"], 811 | mixed_type=True, 812 | ) 813 | local_jdata_cpy["sel"] = sel[0] 814 | return local_jdata_cpy, min_nbor_dist 815 | 816 | @torch.jit.export 817 | def model_output_type(self) -> list[str]: 818 | """Get the output type for the model.""" 819 | return ["energy"] 820 | 821 | def translated_output_def(self) -> dict[str, Any]: 822 | """Get the translated output def for the model.""" 823 | out_def_data = self.model_output_def().get_data() 824 | output_def = { 825 | "atom_energy": deepcopy(out_def_data["energy"]), 826 | "energy": deepcopy(out_def_data["energy_redu"]), 827 | } 828 | output_def["force"] = 
deepcopy(out_def_data["energy_derv_r"]) 829 | output_def["force"].squeeze(-2) 830 | output_def["virial"] = deepcopy(out_def_data["energy_derv_c_redu"]) 831 | output_def["virial"].squeeze(-2) 832 | output_def["atom_virial"] = deepcopy(out_def_data["energy_derv_c"]) 833 | output_def["atom_virial"].squeeze(-3) 834 | if "mask" in out_def_data: 835 | output_def["mask"] = deepcopy(out_def_data["mask"]) 836 | return output_def 837 | 838 | def model_output_def(self) -> ModelOutputDef: 839 | """Get the output def for the model.""" 840 | return ModelOutputDef(self.fitting_output_def()) 841 | 842 | @classmethod 843 | def get_model(cls, model_params: dict) -> "MaceModel": 844 | """Get the model by the parameters. 845 | 846 | Parameters 847 | ---------- 848 | model_params : dict 849 | The model parameters 850 | 851 | Returns 852 | ------- 853 | BaseBaseModel 854 | The model 855 | """ 856 | model_params_old = model_params.copy() 857 | model_params = model_params.copy() 858 | model_params.pop("type", None) 859 | precision = model_params.pop("precision", "float32") 860 | if precision == "float32": 861 | torch.set_default_dtype(torch.float32) 862 | elif precision == "float64": 863 | torch.set_default_dtype(torch.float64) 864 | else: 865 | msg = f"precision {precision} not supported" 866 | raise ValueError(msg) 867 | model = cls(**model_params) 868 | model.model_def_script = json.dumps(model_params_old) 869 | return model 870 | -------------------------------------------------------------------------------- /deepmd_gnn/nequip.py: -------------------------------------------------------------------------------- 1 | """Nequip model.""" 2 | 3 | from copy import deepcopy 4 | from typing import Any, Optional 5 | 6 | import torch 7 | from deepmd.dpmodel.output_def import ( 8 | FittingOutputDef, 9 | ModelOutputDef, 10 | OutputVariableDef, 11 | ) 12 | from deepmd.pt.model.model.model import ( 13 | BaseModel, 14 | ) 15 | from deepmd.pt.model.model.transform_output import ( 16 | 
communicate_extended_output, 17 | ) 18 | from deepmd.pt.utils import ( 19 | env, 20 | ) 21 | from deepmd.pt.utils.nlist import ( 22 | build_neighbor_list, 23 | extend_input_and_build_neighbor_list, 24 | ) 25 | from deepmd.pt.utils.stat import ( 26 | compute_output_stats, 27 | ) 28 | from deepmd.pt.utils.update_sel import ( 29 | UpdateSel, 30 | ) 31 | from deepmd.pt.utils.utils import ( 32 | to_numpy_array, 33 | to_torch_tensor, 34 | ) 35 | from deepmd.utils.data_system import ( 36 | DeepmdDataSystem, 37 | ) 38 | from deepmd.utils.path import ( 39 | DPPath, 40 | ) 41 | from deepmd.utils.version import ( 42 | check_version_compatibility, 43 | ) 44 | from e3nn.util.jit import ( 45 | script, 46 | ) 47 | from nequip.model import model_from_config 48 | 49 | 50 | @BaseModel.register("nequip") 51 | class NequipModel(BaseModel): 52 | """Nequip model. 53 | 54 | Parameters 55 | ---------- 56 | type_map : list[str] 57 | The name of each type of atoms 58 | sel : int 59 | Maximum number of neighbor atoms 60 | r_max : float, optional 61 | distance cutoff (in Ang) 62 | num_layers : int 63 | number of interaction blocks, we find 3-5 to work best 64 | l_max : int 65 | the maximum irrep order (rotation order) for the network's features, l=1 is a good default, l=2 is more accurate but slower 66 | num_features : int 67 | the multiplicity of the features, 32 is a good default for accurate network, if you want to be more accurate, go larger, if you want to be faster, go lower 68 | nonlinearity_type : str 69 | may be 'gate' or 'norm', 'gate' is recommended 70 | parity : bool 71 | whether to include features with odd mirror parityy; often turning parity off gives equally good results but faster networks, so do consider this 72 | num_basis : int 73 | number of basis functions used in the radial basis, 8 usually works best 74 | BesselBasis_trainable : bool 75 | set true to train the bessel weights 76 | PolynomialCutoff_p : int 77 | p-exponent used in polynomial cutoff function, smaller p 
corresponds to stronger decay with distance 78 | invariant_layers : int 79 | number of radial layers, usually 1-3 works best, smaller is faster 80 | invariant_neurons : int 81 | number of hidden neurons in radial function, smaller is faster 82 | use_sc : bool 83 | use self-connection or not, usually gives big improvement 84 | irreps_edge_sh : str 85 | irreps of the spherical harmonics used for edges. If a single integer, indicates the full SH up to L_max=that_integer 86 | feature_irreps_hidden : str 87 | irreps used for hidden features, here we go up to lmax=1, with even and odd parities; for more accurate but slower networks, use l=2 or higher, smaller number of features is faster 88 | chemical_embedding_irreps_out : str 89 | irreps for the chemical embedding of species 90 | conv_to_output_hidden_irreps_out : str 91 | irreps used in hidden layer of output block 92 | """ 93 | 94 | mm_types: list[int] # indices of MM atom types (names starting with "m", or HW/OW) 95 | e0: torch.Tensor # per-type energy bias added to each local atomic energy 96 | 97 | def __init__( 98 | self, 99 | type_map: list[str], 100 | sel: int, 101 | r_max: float = 6.0, 102 | num_layers: int = 4, 103 | l_max: int = 2, 104 | num_features: int = 32, 105 | nonlinearity_type: str = "gate", 106 | parity: bool = True, 107 | num_basis: int = 8, 108 | BesselBasis_trainable: bool = True, 109 | PolynomialCutoff_p: int = 6, 110 | invariant_layers: int = 2, 111 | invariant_neurons: int = 64, 112 | use_sc: bool = True, 113 | irreps_edge_sh: str = "0e + 1e", 114 | feature_irreps_hidden: str = "32x0o + 32x0e + 32x1o + 32x1e", 115 | chemical_embedding_irreps_out: str = "32x0e", 116 | conv_to_output_hidden_irreps_out: str = "16x0e", 117 | precision: str = "float32", 118 | **kwargs: Any, # noqa: ANN401 119 | ) -> None: 120 | super().__init__(**kwargs) 121 | self.params = { 122 | "type_map": type_map, 123 | "sel": sel, 124 | "r_max": r_max, 125 | "num_layers": num_layers, 126 | "l_max": l_max, 127 | "num_features": num_features, 128 | "nonlinearity_type": nonlinearity_type, 129 | "parity": parity, 130 | "num_basis": num_basis, 131
| "BesselBasis_trainable": BesselBasis_trainable, 132 | "PolynomialCutoff_p": PolynomialCutoff_p, 133 | "invariant_layers": invariant_layers, 134 | "invariant_neurons": invariant_neurons, 135 | "use_sc": use_sc, 136 | "irreps_edge_sh": irreps_edge_sh, 137 | "feature_irreps_hidden": feature_irreps_hidden, 138 | "chemical_embedding_irreps_out": chemical_embedding_irreps_out, 139 | "conv_to_output_hidden_irreps_out": conv_to_output_hidden_irreps_out, 140 | "precision": precision, 141 | } 142 | self.type_map = type_map 143 | self.ntypes = len(type_map) 144 | self.preset_out_bias: dict[str, list] = {"energy": []} 145 | self.mm_types = [] 146 | self.sel = sel 147 | self.num_layers = num_layers 148 | for ii, tt in enumerate(type_map): 149 | if not tt.startswith("m") and tt not in {"HW", "OW"}: 150 | self.preset_out_bias["energy"].append(None) 151 | else: 152 | self.preset_out_bias["energy"].append([0]) 153 | self.mm_types.append(ii) 154 | 155 | self.rcut = r_max 156 | self.model = script( 157 | model_from_config( 158 | { 159 | "model_builders": ["EnergyModel"], 160 | "avg_num_neighbors": sel, 161 | "chemical_symbols": type_map, 162 | "num_types": self.ntypes, 163 | "r_max": r_max, 164 | "num_layers": num_layers, 165 | "l_max": l_max, 166 | "num_features": num_features, 167 | "nonlinearity_type": nonlinearity_type, 168 | "parity": parity, 169 | "num_basis": num_basis, 170 | "BesselBasis_trainable": BesselBasis_trainable, 171 | "PolynomialCutoff_p": PolynomialCutoff_p, 172 | "invariant_layers": invariant_layers, 173 | "invariant_neurons": invariant_neurons, 174 | "use_sc": use_sc, 175 | "irreps_edge_sh": irreps_edge_sh, 176 | "feature_irreps_hidden": feature_irreps_hidden, 177 | "chemical_embedding_irreps_out": chemical_embedding_irreps_out, 178 | "conv_to_output_hidden_irreps_out": conv_to_output_hidden_irreps_out, 179 | "model_dtype": precision, 180 | }, 181 | ), 182 | ) 183 | self.register_buffer( 184 | "e0", 185 | torch.zeros( 186 | self.ntypes, 187 | 
dtype=env.GLOBAL_PT_ENER_FLOAT_PRECISION, 188 | device=env.DEVICE, 189 | ), 190 | ) 191 | 192 | def compute_or_load_stat( 193 | self, 194 | sampled_func, # noqa: ANN001 195 | stat_file_path: Optional[DPPath] = None, 196 | ) -> None: 197 | """Compute or load the statistics parameters of the model. 198 | 199 | For example, mean and standard deviation of descriptors or the energy bias of 200 | the fitting net. When `sampled` is provided, all the statistics parameters will 201 | be calculated (or re-calculated for update), and saved in the 202 | `stat_file_path`(s). When `sampled` is not provided, it will check the existence 203 | of `stat_file_path`(s) and load the calculated statistics parameters. 204 | 205 | Parameters 206 | ---------- 207 | sampled_func 208 | The sampled data frames from different data systems. 209 | stat_file_path 210 | The path to the statistics files. 211 | """ 212 | bias_out, _ = compute_output_stats( 213 | sampled_func, 214 | self.get_ntypes(), 215 | keys=["energy"], 216 | stat_file_path=stat_file_path, 217 | rcond=None, 218 | preset_bias=self.preset_out_bias, 219 | ) 220 | if "energy" in bias_out: 221 | self.e0 = ( 222 | bias_out["energy"] 223 | .view(self.e0.shape) 224 | .to(self.e0.dtype) 225 | .to(self.e0.device) 226 | ) 227 | 228 | @torch.jit.export 229 | def fitting_output_def(self) -> FittingOutputDef: 230 | """Get the output def of developer implemented atomic models.""" 231 | return FittingOutputDef( 232 | [ 233 | OutputVariableDef( 234 | name="energy", 235 | shape=[1], 236 | reducible=True, 237 | r_differentiable=True, 238 | c_differentiable=True, 239 | ), 240 | ], 241 | ) 242 | 243 | @torch.jit.export 244 | def get_rcut(self) -> float: 245 | """Get the cut-off radius.""" 246 | return self.rcut * self.num_layers 247 | 248 | @torch.jit.export 249 | def get_type_map(self) -> list[str]: 250 | """Get the type map.""" 251 | return self.type_map 252 | 253 | @torch.jit.export 254 | def get_sel(self) -> list[int]: 255 | """Return the number 
of selected atoms for each type.""" 256 | return [self.sel] 257 | 258 | @torch.jit.export 259 | def get_dim_fparam(self) -> int: 260 | """Get the number (dimension) of frame parameters of this atomic model.""" 261 | return 0 262 | 263 | @torch.jit.export 264 | def get_dim_aparam(self) -> int: 265 | """Get the number (dimension) of atomic parameters of this atomic model.""" 266 | return 0 267 | 268 | @torch.jit.export 269 | def get_sel_type(self) -> list[int]: 270 | """Get the selected atom types of this model. 271 | 272 | Only atoms with selected atom types have atomic contribution 273 | to the result of the model. 274 | If returning an empty list, all atom types are selected. 275 | """ 276 | return [] 277 | 278 | @torch.jit.export 279 | def is_aparam_nall(self) -> bool: 280 | """Check whether the shape of atomic parameters is (nframes, nall, ndim). 281 | 282 | If False, the shape is (nframes, nloc, ndim). 283 | """ 284 | return False 285 | 286 | @torch.jit.export 287 | def mixed_types(self) -> bool: 288 | """Return whether the model is in mixed-types mode. 289 | 290 | If true, the model 291 | 1. assumes total number of atoms aligned across frames; 292 | 2. uses a neighbor list that does not distinguish different atomic types. 293 | If false, the model 294 | 1. assumes total number of atoms of each atom type aligned across frames; 295 | 2. uses a neighbor list that distinguishes different atomic types. 296 | """ 297 | return True 298 | 299 | @torch.jit.export 300 | def has_message_passing(self) -> bool: 301 | """Return whether the descriptor has message passing.""" 302 | return False 303 | 304 | @torch.jit.export 305 | def forward( 306 | self, 307 | coord: torch.Tensor, 308 | atype: torch.Tensor, 309 | box: Optional[torch.Tensor] = None, 310 | fparam: Optional[torch.Tensor] = None, 311 | aparam: Optional[torch.Tensor] = None, 312 | do_atomic_virial: bool = False, 313 | ) -> dict[str, torch.Tensor]: 314 | """Forward pass of the model. 
315 | 316 | Parameters 317 | ---------- 318 | coord : torch.Tensor 319 | The coordinates of atoms. 320 | atype : torch.Tensor 321 | The atomic types of atoms. 322 | box : torch.Tensor, optional 323 | The box tensor. 324 | fparam : torch.Tensor, optional 325 | The frame parameters. 326 | aparam : torch.Tensor, optional 327 | The atomic parameters. 328 | do_atomic_virial : bool, optional 329 | Whether to compute atomic virial. 330 | """ 331 | nloc = atype.shape[1] 332 | extended_coord, extended_atype, mapping, nlist = ( 333 | extend_input_and_build_neighbor_list( 334 | coord, 335 | atype, 336 | self.rcut, 337 | self.get_sel(), 338 | mixed_types=True, 339 | box=box, 340 | ) 341 | ) 342 | model_ret_lower = self.forward_lower_common( 343 | nloc, 344 | extended_coord, 345 | extended_atype, 346 | nlist, 347 | mapping=mapping, 348 | fparam=fparam, 349 | aparam=aparam, 350 | do_atomic_virial=do_atomic_virial, 351 | comm_dict=None, 352 | box=box, 353 | ) 354 | model_ret = communicate_extended_output( 355 | model_ret_lower, 356 | ModelOutputDef(self.fitting_output_def()), 357 | mapping, 358 | do_atomic_virial, 359 | ) 360 | model_predict = {} 361 | model_predict["atom_energy"] = model_ret["energy"] 362 | model_predict["energy"] = model_ret["energy_redu"] 363 | model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2) 364 | model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) 365 | if do_atomic_virial: 366 | model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3) 367 | return model_predict 368 | 369 | @torch.jit.export 370 | def forward_lower( 371 | self, 372 | extended_coord: torch.Tensor, 373 | extended_atype: torch.Tensor, 374 | nlist: torch.Tensor, 375 | mapping: Optional[torch.Tensor] = None, 376 | fparam: Optional[torch.Tensor] = None, 377 | aparam: Optional[torch.Tensor] = None, 378 | do_atomic_virial: bool = False, 379 | comm_dict: Optional[dict[str, torch.Tensor]] = None, 380 | ) -> dict[str, torch.Tensor]: 381 | """Forward lower 
pass of the model. 382 | 383 | Parameters 384 | ---------- 385 | extended_coord : torch.Tensor 386 | The extended coordinates of atoms. 387 | extended_atype : torch.Tensor 388 | The extended atomic types of atoms. 389 | nlist : torch.Tensor 390 | The neighbor list. 391 | mapping : torch.Tensor, optional 392 | The mapping tensor. 393 | fparam : torch.Tensor, optional 394 | The frame parameters. 395 | aparam : torch.Tensor, optional 396 | The atomic parameters. 397 | do_atomic_virial : bool, optional 398 | Whether to compute atomic virial. 399 | comm_dict : dict[str, torch.Tensor], optional 400 | The communication dictionary. 401 | """ 402 | nloc = nlist.shape[1] 403 | nf, nall = extended_atype.shape 404 | # recalculate nlist for ghost atoms 405 | if self.num_layers > 1 and nloc < nall: 406 | nlist = build_neighbor_list( 407 | extended_coord.view(nf, -1), 408 | extended_atype, 409 | nall, 410 | self.rcut * self.num_layers, 411 | self.sel, 412 | distinguish_types=False, 413 | ) 414 | model_ret = self.forward_lower_common( 415 | nloc, 416 | extended_coord, 417 | extended_atype, 418 | nlist, 419 | mapping, 420 | fparam, 421 | aparam, 422 | do_atomic_virial, 423 | comm_dict, 424 | ) 425 | model_predict = {} 426 | model_predict["atom_energy"] = model_ret["energy"] 427 | model_predict["energy"] = model_ret["energy_redu"] 428 | model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2) 429 | model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) 430 | if do_atomic_virial: 431 | model_predict["extended_virial"] = model_ret["energy_derv_c"].squeeze(-3) 432 | return model_predict 433 | 434 | def forward_lower_common( 435 | self, 436 | nloc: int, 437 | extended_coord: torch.Tensor, 438 | extended_atype: torch.Tensor, 439 | nlist: torch.Tensor, 440 | mapping: Optional[torch.Tensor] = None, 441 | fparam: Optional[torch.Tensor] = None, 442 | aparam: Optional[torch.Tensor] = None, 443 | do_atomic_virial: bool = False, # noqa: ARG002 444 | comm_dict: 
Optional[dict[str, torch.Tensor]] = None, 445 | box: Optional[torch.Tensor] = None, 446 | ) -> dict[str, torch.Tensor]: 447 | """Forward lower common pass of the model. 448 | 449 | Parameters 450 | ---------- 451 | extended_coord : torch.Tensor 452 | The extended coordinates of atoms. 453 | extended_atype : torch.Tensor 454 | The extended atomic types of atoms. 455 | nlist : torch.Tensor 456 | The neighbor list. 457 | mapping : torch.Tensor, optional 458 | The mapping tensor. 459 | fparam : torch.Tensor, optional 460 | The frame parameters. 461 | aparam : torch.Tensor, optional 462 | The atomic parameters. 463 | do_atomic_virial : bool, optional 464 | Whether to compute atomic virial. 465 | comm_dict : dict[str, torch.Tensor], optional 466 | The communication dictionary. 467 | box : torch.Tensor, optional 468 | The box tensor. 469 | """ 470 | nf, nall = extended_atype.shape 471 | 472 | extended_coord = extended_coord.view(nf, nall, 3) 473 | extended_coord_ = extended_coord 474 | if fparam is not None: 475 | msg = "fparam is unsupported" 476 | raise ValueError(msg) 477 | if aparam is not None: 478 | msg = "aparam is unsupported" 479 | raise ValueError(msg) 480 | if comm_dict is not None: 481 | msg = "comm_dict is unsupported" 482 | raise ValueError(msg) 483 | nlist = nlist.to(torch.int64) 484 | extended_atype = extended_atype.to(torch.int64) 485 | nall = extended_coord.shape[1] 486 | 487 | # fake as one frame 488 | extended_coord_ff = extended_coord.view(nf * nall, 3) 489 | extended_atype_ff = extended_atype.view(nf * nall) 490 | edge_index = torch.ops.deepmd_gnn.edge_index( 491 | nlist, 492 | extended_atype, 493 | torch.tensor(self.mm_types, dtype=torch.int64, device="cpu"), 494 | ) 495 | edge_index = edge_index.T 496 | # Nequip and MACE have different defination for edge_index 497 | edge_index = edge_index[[1, 0]] 498 | 499 | # nequip can convert dtype by itself 500 | default_dtype = torch.float64 501 | extended_coord_ff = extended_coord_ff.to(default_dtype) 502 | 
extended_coord_ff.requires_grad_(True) # noqa: FBT003 503 | 504 | input_dict = { 505 | "pos": extended_coord_ff, 506 | "edge_index": edge_index, 507 | "atom_types": extended_atype_ff, 508 | } 509 | if box is not None and mapping is not None: 510 | # pass box, map edge index to real; atoms are stored frame-major 511 | box_ff = box.to(extended_coord_ff.device) 512 | input_dict["cell"] = box_ff 513 | input_dict["pbc"] = torch.zeros( 514 | 3, 515 | dtype=torch.bool, 516 | device=box_ff.device, 517 | ) 518 | batch = torch.arange(nf, device=box_ff.device).repeat_interleave(nall) 519 | input_dict["batch"] = batch 520 | ptr = torch.arange( 521 | start=0, 522 | end=nf * nall + 1, 523 | step=nall, 524 | dtype=torch.int64, 525 | device=batch.device, 526 | ) 527 | input_dict["ptr"] = ptr 528 | mapping_ff = mapping.view(nf * nall) + torch.arange( 529 | 0, 530 | nf * nall, 531 | nall, 532 | dtype=mapping.dtype, 533 | device=mapping.device, 534 | ).unsqueeze(-1).expand(nf, nall).reshape(-1) 535 | shifts_atoms = extended_coord_ff - extended_coord_ff[mapping_ff] 536 | shifts = shifts_atoms[edge_index[1]] - shifts_atoms[edge_index[0]] 537 | edge_index = mapping_ff[edge_index] 538 | input_dict["edge_index"] = edge_index 539 | rec_cell, _ = torch.linalg.inv_ex(box.view(nf, 3, 3)) 540 | edge_cell_shift = torch.einsum( 541 | "ni,nij->nj", 542 | shifts, 543 | rec_cell[batch[edge_index[0]]], 544 | ) 545 | input_dict["edge_cell_shift"] = edge_cell_shift 546 | 547 | ret = self.model.forward( 548 | input_dict, 549 | ) 550 | 551 | atom_energy = ret["atomic_energy"] 552 | if atom_energy is None: 553 | msg = "atom_energy is None" 554 | raise ValueError(msg) 555 | atom_energy = atom_energy.view(nf, nall).to(extended_coord_.dtype)[:, :nloc] 556 | # adds e0 557 | atom_energy = atom_energy + self.e0[extended_atype[:, :nloc]].view( 558 | nf, 559 | nloc, 560 | ).to( 561 | atom_energy.dtype, 562 | ) 563 | energy = torch.sum(atom_energy, dim=1).view(nf, 1).to(extended_coord_.dtype) 564 | grad_outputs: list[Optional[torch.Tensor]] = [ 565 |
torch.ones_like(energy), 566 | ] 567 | force = torch.autograd.grad( 568 | outputs=[energy], 569 | inputs=[extended_coord_ff], 570 | grad_outputs=grad_outputs, 571 | retain_graph=True, 572 | create_graph=self.training, 573 | )[0] 574 | if force is None: 575 | msg = "force is None" 576 | raise ValueError(msg) 577 | force = -force 578 | atomic_virial = force.unsqueeze(-1).to( 579 | extended_coord_.dtype, 580 | ) @ extended_coord_ff.unsqueeze(-2).to( 581 | extended_coord_.dtype, 582 | ) 583 | force = force.view(nf, nall, 3).to(extended_coord_.dtype) 584 | atomic_virial = atomic_virial.view(nf, nall, 1, 9) 585 | virial = torch.sum(atomic_virial, dim=1).view(nf, 9).to(extended_coord_.dtype) 586 | 587 | return { 588 | "energy_redu": energy.view(nf, 1), 589 | "energy_derv_r": force.view(nf, nall, 1, 3), 590 | "energy_derv_c_redu": virial.view(nf, 1, 9), 591 | # take the first nloc atoms to match other models 592 | "energy": atom_energy.view(nf, nloc, 1), 593 | # fake atom_virial 594 | "energy_derv_c": atomic_virial.view(nf, nall, 1, 9), 595 | } 596 | 597 | def serialize(self) -> dict: 598 | """Serialize the model.""" 599 | return { 600 | "@class": "Model", 601 | "@version": 1, 602 | "type": "mace", 603 | **self.params, 604 | "@variables": { 605 | **{ 606 | kk: to_numpy_array(vv) for kk, vv in self.model.state_dict().items() 607 | }, 608 | "e0": to_numpy_array(self.e0), 609 | }, 610 | } 611 | 612 | @classmethod 613 | def deserialize(cls, data: dict) -> "NequipModel": 614 | """Deserialize the model.""" 615 | data = data.copy() 616 | if not (data.pop("@class") == "Model" and data.pop("type") == "mace"): 617 | msg = "data is not a serialized NequipModel" 618 | raise ValueError(msg) 619 | check_version_compatibility(data.pop("@version"), 1, 1) 620 | variables = { 621 | kk: to_torch_tensor(vv) for kk, vv in data.pop("@variables").items() 622 | } 623 | model = cls(**data) 624 | model.e0 = variables.pop("e0") 625 | model.model.load_state_dict(variables) 626 | return model 627 | 
628 | @torch.jit.export 629 | def get_nnei(self) -> int: 630 | """Return the total number of selected neighboring atoms in cut-off radius.""" 631 | return self.sel 632 | 633 | @torch.jit.export 634 | def get_nsel(self) -> int: 635 | """Return the total number of selected neighboring atoms in cut-off radius.""" 636 | return self.sel 637 | 638 | @classmethod 639 | def update_sel( 640 | cls, 641 | train_data: DeepmdDataSystem, 642 | type_map: Optional[list[str]], 643 | local_jdata: dict, 644 | ) -> tuple[dict, Optional[float]]: 645 | """Update the selection and perform neighbor statistics. 646 | 647 | Parameters 648 | ---------- 649 | train_data : DeepmdDataSystem 650 | data used to do neighbor statictics 651 | type_map : list[str], optional 652 | The name of each type of atoms 653 | local_jdata : dict 654 | The local data refer to the current class 655 | 656 | Returns 657 | ------- 658 | dict 659 | The updated local data 660 | float 661 | The minimum distance between two atoms 662 | """ 663 | local_jdata_cpy = local_jdata.copy() 664 | min_nbor_dist, sel = UpdateSel().update_one_sel( 665 | train_data, 666 | type_map, 667 | local_jdata_cpy["r_max"], 668 | local_jdata_cpy["sel"], 669 | mixed_type=True, 670 | ) 671 | local_jdata_cpy["sel"] = sel[0] 672 | return local_jdata_cpy, min_nbor_dist 673 | 674 | @torch.jit.export 675 | def model_output_type(self) -> list[str]: 676 | """Get the output type for the model.""" 677 | return ["energy"] 678 | 679 | def translated_output_def(self) -> dict[str, Any]: 680 | """Get the translated output def for the model.""" 681 | out_def_data = self.model_output_def().get_data() 682 | output_def = { 683 | "atom_energy": deepcopy(out_def_data["energy"]), 684 | "energy": deepcopy(out_def_data["energy_redu"]), 685 | } 686 | output_def["force"] = deepcopy(out_def_data["energy_derv_r"]) 687 | output_def["force"].squeeze(-2) 688 | output_def["virial"] = deepcopy(out_def_data["energy_derv_c_redu"]) 689 | output_def["virial"].squeeze(-2) 690 | 
output_def["atom_virial"] = deepcopy(out_def_data["energy_derv_c"]) 691 | output_def["atom_virial"].squeeze(-3) 692 | if "mask" in out_def_data: 693 | output_def["mask"] = deepcopy(out_def_data["mask"]) 694 | return output_def 695 | 696 | def model_output_def(self) -> ModelOutputDef: 697 | """Get the output def for the model.""" 698 | return ModelOutputDef(self.fitting_output_def()) 699 | -------------------------------------------------------------------------------- /deepmd_gnn/op.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: LGPL-3.0-or-later 2 | """Load OP library.""" 3 | 4 | from __future__ import annotations 5 | 6 | import platform 7 | from pathlib import Path 8 | 9 | import torch 10 | 11 | import deepmd_gnn.lib 12 | 13 | SHARED_LIB_DIR = Path(deepmd_gnn.lib.__path__[0]) 14 | 15 | 16 | def load_library(module_name: str) -> None: 17 | """Load OP library. 18 | 19 | Parameters 20 | ---------- 21 | module_name : str 22 | Name of the module 23 | 24 | Returns 25 | ------- 26 | None 27 | The shared library is loaded with ``torch.ops.load_library`` for its side effects 28 | """ 29 | if platform.system() == "Windows": 30 | ext = ".dll" 31 | prefix = "" 32 | else: 33 | ext = ".so" 34 | prefix = "lib" 35 | 36 | module_file = (SHARED_LIB_DIR / (prefix + module_name)).with_suffix(ext).resolve() 37 | 38 | torch.ops.load_library(module_file) 39 | 40 | 41 | load_library("deepmd_gnn") 42 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these
variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: LGPL-3.0-or-later 2 | """Configuration file for the Sphinx documentation builder.""" 3 | # 4 | # This file only contains a selection of the most common options. For a full 5 | # list see the documentation: 6 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 7 | 8 | # -- Path setup -------------------------------------------------------------- 9 | 10 | # If extensions (or modules to document with autodoc) are in another directory, 11 | # add these directories to sys.path here. If the directory is relative to the 12 | # documentation root, use os.path.abspath to make it absolute, like shown here. 13 | # 14 | 15 | # import sys 16 | # sys.path.insert(0, os.path.abspath('..')) 17 | from datetime import datetime, timezone 18 | 19 | # -- Project information ----------------------------------------------------- 20 | 21 | project = "DeePMD-GNN" 22 | copyright = f"2024-{datetime.now(tz=timezone.utc).year}, DeepModeling" # noqa: A001 23 | author = "Jinzhe Zeng" 24 | 25 | 26 | # -- General configuration --------------------------------------------------- 27 | 28 | # Add any Sphinx extension module names here, as strings. 
They can be 29 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 30 | # ones. 31 | extensions = [ 32 | "sphinx.ext.autosummary", 33 | "sphinx.ext.mathjax", 34 | "sphinx.ext.viewcode", 35 | "sphinx.ext.intersphinx", 36 | # "sphinxarg.ext", 37 | "myst_parser", 38 | # "sphinx_favicon", 39 | "deepmodeling_sphinx", 40 | "dargs.sphinx", 41 | # "sphinxcontrib.bibtex", 42 | # "sphinx_design", 43 | "autoapi.extension", 44 | ] 45 | 46 | # Add any paths that contain templates here, relative to this directory. 47 | # templates_path = ['_templates'] 48 | 49 | # List of patterns, relative to source directory, that match files and 50 | # directories to ignore when looking for source files. 51 | # This pattern also affects html_static_path and html_extra_path. 52 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 53 | 54 | 55 | # -- Options for HTML output ------------------------------------------------- 56 | 57 | # The theme to use for HTML and HTML Help pages. See the documentation for 58 | # a list of builtin themes. 
59 | # 60 | html_theme = "sphinx_book_theme" 61 | # html_logo = "_static/logo.svg" 62 | html_static_path = ["_static"] 63 | html_js_files: list[str] = [] 64 | html_css_files = ["css/custom.css"] 65 | html_extra_path = ["report.html", "fire.png", "bundle.js", "bundle.css"] 66 | 67 | html_theme_options = { 68 | "github_url": "https://github.com/deepmodeling/deepmd-gnn", 69 | "gitlab_url": "https://gitlab.com/RutgersLBSR/deepmd-gnn", 70 | "logo": { 71 | "text": "DeePMD-GNN", 72 | "alt_text": "DeePMD-GNN", 73 | }, 74 | } 75 | 76 | html_context = { 77 | "github_user": "deepmodeling", 78 | "github_repo": "deepmd-gnn", 79 | "github_version": "master", 80 | "doc_path": "docs", 81 | } 82 | 83 | myst_heading_anchors = 3 84 | 85 | # favicons = [ 86 | # { 87 | # "rel": "icon", 88 | # "static-file": "logo.svg", 89 | # "type": "image/svg+xml", 90 | # }, 91 | # ] 92 | 93 | enable_deepmodeling = False 94 | 95 | myst_enable_extensions = [ 96 | "dollarmath", 97 | "colon_fence", 98 | "attrs_inline", 99 | ] 100 | mathjax_path = ( 101 | "https://cdnjs.cloudflare.com/ajax/libs/mathjax/3.2.2/es5/tex-mml-chtml.min.js" 102 | ) 103 | mathjax_options = { 104 | "integrity": "sha512-6FaAxxHuKuzaGHWnV00ftWqP3luSBRSopnNAA2RvQH1fOfnF/A1wOfiUWF7cLIOFcfb1dEhXwo5VG3DAisocRw==", 105 | "crossorigin": "anonymous", 106 | } 107 | mathjax3_config = { 108 | "loader": {"load": ["[tex]/mhchem"]}, 109 | "tex": {"packages": {"[+]": ["mhchem"]}}, 110 | } 111 | 112 | execution_mode = "off" 113 | numpydoc_show_class_members = False 114 | 115 | # Add any paths that contain custom static files (such as style sheets) here, 116 | # relative to this directory. They are copied after the builtin static files, 117 | # so a file named "default.css" will overwrite the builtin "default.css". 
118 | # html_static_path = ['_static'] 119 | 120 | intersphinx_mapping = { 121 | "numpy": ("https://docs.scipy.org/doc/numpy/", None), 122 | "python": ("https://docs.python.org/", None), 123 | "deepmd": ("https://docs.deepmodeling.com/projects/deepmd/", None), 124 | "torch": ("https://pytorch.org/docs/stable/", None), 125 | } 126 | autoapi_dirs = ["../deepmd_gnn"] 127 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../README.md 2 | :parser: myst_parser.sphinx_ 3 | 4 | Table of contents 5 | ================= 6 | .. toctree:: 7 | :maxdepth: 2 8 | 9 | Overview 10 | parameters 11 | Python API 12 | 13 | Indices and tables 14 | ================== 15 | 16 | * :ref:`genindex` 17 | * :ref:`modindex` 18 | * :ref:`search` 19 | -------------------------------------------------------------------------------- /docs/parameters.rst: -------------------------------------------------------------------------------- 1 | Parameters 2 | ========== 3 | 4 | MACE 5 | ---- 6 | 7 | .. dargs:: 8 | :module: deepmd_gnn.argcheck 9 | :func: mace_model_args 10 | 11 | NequIP 12 | ------ 13 | 14 | .. 
dargs:: 15 | :module: deepmd_gnn.argcheck 16 | :func: nequip_model_args 17 | -------------------------------------------------------------------------------- /examples/.gitignore: -------------------------------------------------------------------------------- 1 | lcurve.out 2 | out.json 3 | input_v2_compat.json 4 | checkpoint 5 | *.pth 6 | *.pb 7 | *.pt 8 | model.ckpt* 9 | -------------------------------------------------------------------------------- /examples/dprc/data/nopbc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/dprc/data/nopbc -------------------------------------------------------------------------------- /examples/dprc/data/set.000/box.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/dprc/data/set.000/box.npy -------------------------------------------------------------------------------- /examples/dprc/data/set.000/coord.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/dprc/data/set.000/coord.npy -------------------------------------------------------------------------------- /examples/dprc/data/set.000/energy.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/dprc/data/set.000/energy.npy -------------------------------------------------------------------------------- /examples/dprc/data/set.000/force.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/dprc/data/set.000/force.npy -------------------------------------------------------------------------------- /examples/dprc/data/type.raw: -------------------------------------------------------------------------------- 1 | 1 2 | 0 3 | 1 4 | 3 5 | 0 6 | 1 7 | 1 8 | 0 9 | 1 10 | 0 11 | 1 12 | 3 13 | 3 14 | 5 15 | 3 16 | 3 17 | 3 18 | 0 19 | 1 20 | 1 21 | 0 22 | 1 23 | 1 24 | 1 25 | 4 26 | 2 27 | 2 28 | 4 29 | 4 30 | 2 31 | 4 32 | 2 33 | 2 34 | 4 35 | 2 36 | 2 37 | 4 38 | 2 39 | 2 40 | 4 41 | 2 42 | 2 43 | 2 44 | 4 45 | 2 46 | 2 47 | 4 48 | 2 49 | 2 50 | 4 51 | 2 52 | 2 53 | 4 54 | 2 55 | 2 56 | 4 57 | 2 58 | 2 59 | 4 60 | 2 61 | 2 62 | 4 63 | 2 64 | 2 65 | 4 66 | 2 67 | 2 68 | 4 69 | 2 70 | 2 71 | 4 72 | 2 73 | 2 74 | 4 75 | 2 76 | 2 77 | 4 78 | 2 79 | 2 80 | 4 81 | 2 82 | 2 83 | 4 84 | 2 85 | 2 86 | 4 87 | 2 88 | 2 89 | 4 90 | 2 91 | 2 92 | 4 93 | 2 94 | 2 95 | 4 96 | 2 97 | 2 98 | 4 99 | 2 100 | 2 101 | 4 102 | 2 103 | 2 104 | 4 105 | 2 106 | 2 107 | 4 108 | 2 109 | 2 110 | 4 111 | 2 112 | 2 113 | 4 114 | 2 115 | 2 116 | 4 117 | 2 118 | 2 119 | 4 120 | 2 121 | 2 122 | 4 123 | 2 124 | 2 125 | 4 126 | 2 127 | 4 128 | 2 129 | 2 130 | 4 131 | 2 132 | 2 133 | 4 134 | 2 135 | 2 136 | 4 137 | 2 138 | 2 139 | 4 140 | 2 141 | 2 142 | 4 143 | 2 144 | 2 145 | 4 146 | 2 147 | 2 148 | 4 149 | 2 150 | 2 151 | 4 152 | 2 153 | 4 154 | 2 155 | 4 156 | 2 157 | 2 158 | 4 159 | 2 160 | 2 161 | 4 162 | 2 163 | 2 164 | 4 165 | 2 166 | 2 167 | 4 168 | 2 169 | 2 170 | 4 171 | 2 172 | 2 173 | 4 174 | 2 175 | 2 176 | 4 177 | 2 178 | 2 179 | 4 180 | 2 181 | 2 182 | 4 183 | 2 184 | 2 185 | 4 186 | 2 187 | 2 188 | 4 189 | 2 190 | 4 191 | 2 192 | 2 193 | 4 194 | 2 195 | 2 196 | 4 197 | 2 198 | 2 199 | 4 200 | 2 201 | 2 202 | 2 203 | 4 204 | 2 205 | 2 206 | 4 207 | 2 208 | 2 209 | 4 210 | 2 211 | 2 212 | 2 213 | 4 214 | 2 215 | 2 216 | 2 217 | 4 218 | 2 219 | 2 220 | 4 221 | 2 222 | 2 223 
| 2 224 | 2 225 | 2 226 | 4 227 | 2 228 | 2 229 | 4 230 | 2 231 | 2 232 | 4 233 | 2 234 | 2 235 | 4 236 | 2 237 | 2 238 | 2 239 | 2 240 | 4 241 | 2 242 | 2 243 | 4 244 | 2 245 | 2 246 | 4 247 | 2 248 | 2 249 | 4 250 | 2 251 | 2 252 | 4 253 | 2 254 | 2 255 | 2 256 | 4 257 | 2 258 | 2 259 | 4 260 | 2 261 | 2 262 | 4 263 | 2 264 | 2 265 | 4 266 | 2 267 | 2 268 | 4 269 | 2 270 | 2 271 | 4 272 | 2 273 | 2 274 | 2 275 | 4 276 | -------------------------------------------------------------------------------- /examples/dprc/data/type_map.raw: -------------------------------------------------------------------------------- 1 | C 2 | H 3 | HW 4 | O 5 | OW 6 | P 7 | -------------------------------------------------------------------------------- /examples/dprc/mace/input.json: -------------------------------------------------------------------------------- 1 | { 2 | "_comment1": " model parameters", 3 | "model": { 4 | "type": "mace", 5 | "sel": 150, 6 | "type_map": [ 7 | "C", 8 | "P", 9 | "O", 10 | "H", 11 | "OW", 12 | "HW" 13 | ], 14 | "r_max": 6.0 15 | }, 16 | "learning_rate": { 17 | "type": "exp", 18 | "decay_steps": 5000, 19 | "start_lr": 0.001, 20 | "stop_lr": 3.51e-8, 21 | "_comment2": "that's all" 22 | }, 23 | "loss": { 24 | "type": "ener", 25 | "start_pref_e": 0.02, 26 | "limit_pref_e": 1, 27 | "start_pref_f": 1000, 28 | "limit_pref_f": 1, 29 | "start_pref_v": 0, 30 | "limit_pref_v": 0, 31 | "_comment3": " that's all" 32 | }, 33 | "training": { 34 | "training_data": { 35 | "systems": [ 36 | "../data" 37 | ], 38 | "batch_size": "auto", 39 | "_comment4": "that's all" 40 | }, 41 | "numb_steps": 100000, 42 | "seed": 10, 43 | "disp_file": "lcurve.out", 44 | "disp_freq": 100, 45 | "save_freq": 100, 46 | "_comment5": "that's all" 47 | }, 48 | "_comment6": "that's all" 49 | } 50 | -------------------------------------------------------------------------------- /examples/dprc/nequip/input.json: -------------------------------------------------------------------------------- 1 
| { 2 | "_comment1": " model parameters", 3 | "model": { 4 | "type": "nequip", 5 | "type_map": [ 6 | "C", 7 | "P", 8 | "O", 9 | "H", 10 | "OW", 11 | "HW" 12 | ], 13 | "r_max": 6.0, 14 | "sel": "auto", 15 | "l_max": 1, 16 | "_comment2": " that's all" 17 | }, 18 | 19 | "learning_rate": { 20 | "type": "exp", 21 | "decay_steps": 5000, 22 | "start_lr": 0.001, 23 | "stop_lr": 3.51e-8, 24 | "_comment3": "that's all" 25 | }, 26 | 27 | "loss": { 28 | "type": "ener", 29 | "start_pref_e": 0.02, 30 | "limit_pref_e": 1, 31 | "start_pref_f": 1000, 32 | "limit_pref_f": 1, 33 | "start_pref_v": 0, 34 | "limit_pref_v": 0, 35 | "_comment4": " that's all" 36 | }, 37 | 38 | "training": { 39 | "training_data": { 40 | "systems": [ 41 | "../data" 42 | ], 43 | "batch_size": "auto", 44 | "_comment4": "that's all" 45 | }, 46 | "numb_steps": 100000, 47 | "seed": 10, 48 | "disp_file": "lcurve.out", 49 | "disp_freq": 100, 50 | "save_freq": 100, 51 | "_comment5": "that's all" 52 | }, 53 | 54 | "_comment8": "that's all" 55 | } 56 | -------------------------------------------------------------------------------- /examples/water/data/data_0/set.000/box.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_0/set.000/box.npy -------------------------------------------------------------------------------- /examples/water/data/data_0/set.000/coord.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_0/set.000/coord.npy -------------------------------------------------------------------------------- /examples/water/data/data_0/set.000/energy.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_0/set.000/energy.npy -------------------------------------------------------------------------------- /examples/water/data/data_0/set.000/force.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_0/set.000/force.npy -------------------------------------------------------------------------------- /examples/water/data/data_0/type.raw: -------------------------------------------------------------------------------- 1 | 0 2 | 0 3 | 0 4 | 0 5 | 0 6 | 0 7 | 0 8 | 0 9 | 0 10 | 0 11 | 0 12 | 0 13 | 0 14 | 0 15 | 0 16 | 0 17 | 0 18 | 0 19 | 0 20 | 0 21 | 0 22 | 0 23 | 0 24 | 0 25 | 0 26 | 0 27 | 0 28 | 0 29 | 0 30 | 0 31 | 0 32 | 0 33 | 0 34 | 0 35 | 0 36 | 0 37 | 0 38 | 0 39 | 0 40 | 0 41 | 0 42 | 0 43 | 0 44 | 0 45 | 0 46 | 0 47 | 0 48 | 0 49 | 0 50 | 0 51 | 0 52 | 0 53 | 0 54 | 0 55 | 0 56 | 0 57 | 0 58 | 0 59 | 0 60 | 0 61 | 0 62 | 0 63 | 0 64 | 0 65 | 1 66 | 1 67 | 1 68 | 1 69 | 1 70 | 1 71 | 1 72 | 1 73 | 1 74 | 1 75 | 1 76 | 1 77 | 1 78 | 1 79 | 1 80 | 1 81 | 1 82 | 1 83 | 1 84 | 1 85 | 1 86 | 1 87 | 1 88 | 1 89 | 1 90 | 1 91 | 1 92 | 1 93 | 1 94 | 1 95 | 1 96 | 1 97 | 1 98 | 1 99 | 1 100 | 1 101 | 1 102 | 1 103 | 1 104 | 1 105 | 1 106 | 1 107 | 1 108 | 1 109 | 1 110 | 1 111 | 1 112 | 1 113 | 1 114 | 1 115 | 1 116 | 1 117 | 1 118 | 1 119 | 1 120 | 1 121 | 1 122 | 1 123 | 1 124 | 1 125 | 1 126 | 1 127 | 1 128 | 1 129 | 1 130 | 1 131 | 1 132 | 1 133 | 1 134 | 1 135 | 1 136 | 1 137 | 1 138 | 1 139 | 1 140 | 1 141 | 1 142 | 1 143 | 1 144 | 1 145 | 1 146 | 1 147 | 1 148 | 1 149 | 1 150 | 1 151 | 1 152 | 1 153 | 1 154 | 1 155 | 1 156 | 1 157 | 1 158 | 1 159 | 1 160 | 1 161 | 1 162 | 1 163 | 1 164 | 1 165 | 1 166 | 1 167 | 1 168 | 1 169 | 1 170 | 1 171 | 1 172 | 1 173 | 1 174 | 1 175 | 1 176 | 1 
177 | 1 178 | 1 179 | 1 180 | 1 181 | 1 182 | 1 183 | 1 184 | 1 185 | 1 186 | 1 187 | 1 188 | 1 189 | 1 190 | 1 191 | 1 192 | 1 193 | -------------------------------------------------------------------------------- /examples/water/data/data_0/type_map.raw: -------------------------------------------------------------------------------- 1 | O 2 | H 3 | -------------------------------------------------------------------------------- /examples/water/data/data_1/set.000/box.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_1/set.000/box.npy -------------------------------------------------------------------------------- /examples/water/data/data_1/set.000/coord.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_1/set.000/coord.npy -------------------------------------------------------------------------------- /examples/water/data/data_1/set.000/energy.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_1/set.000/energy.npy -------------------------------------------------------------------------------- /examples/water/data/data_1/set.000/force.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_1/set.000/force.npy -------------------------------------------------------------------------------- /examples/water/data/data_1/set.001/box.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_1/set.001/box.npy -------------------------------------------------------------------------------- /examples/water/data/data_1/set.001/coord.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_1/set.001/coord.npy -------------------------------------------------------------------------------- /examples/water/data/data_1/set.001/energy.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_1/set.001/energy.npy -------------------------------------------------------------------------------- /examples/water/data/data_1/set.001/force.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_1/set.001/force.npy -------------------------------------------------------------------------------- /examples/water/data/data_1/type.raw: -------------------------------------------------------------------------------- 1 | 0 2 | 0 3 | 0 4 | 0 5 | 0 6 | 0 7 | 0 8 | 0 9 | 0 10 | 0 11 | 0 12 | 0 13 | 0 14 | 0 15 | 0 16 | 0 17 | 0 18 | 0 19 | 0 20 | 0 21 | 0 22 | 0 23 | 0 24 | 0 25 | 0 26 | 0 27 | 0 28 | 0 29 | 0 30 | 0 31 | 0 32 | 0 33 | 0 34 | 0 35 | 0 36 | 0 37 | 0 38 | 0 39 | 0 40 | 0 41 | 0 42 | 0 43 | 0 44 | 0 45 | 0 46 | 0 47 | 0 48 | 0 49 | 0 50 | 0 51 | 0 52 | 0 53 | 0 54 | 0 55 | 0 56 | 0 57 | 0 58 | 0 59 | 0 60 | 0 61 | 0 62 | 0 63 | 0 64 | 0 65 | 1 66 | 1 67 | 1 68 | 1 69 | 1 70 | 1 71 | 1 72 | 1 73 | 1 74 | 1 75 | 1 76 | 1 77 | 1 78 | 1 79 | 1 80 | 1 81 | 1 82 | 1 83 | 1 84 | 1 85 | 1 86 | 1 87 
| 1 88 | 1 89 | 1 90 | 1 91 | 1 92 | 1 93 | 1 94 | 1 95 | 1 96 | 1 97 | 1 98 | 1 99 | 1 100 | 1 101 | 1 102 | 1 103 | 1 104 | 1 105 | 1 106 | 1 107 | 1 108 | 1 109 | 1 110 | 1 111 | 1 112 | 1 113 | 1 114 | 1 115 | 1 116 | 1 117 | 1 118 | 1 119 | 1 120 | 1 121 | 1 122 | 1 123 | 1 124 | 1 125 | 1 126 | 1 127 | 1 128 | 1 129 | 1 130 | 1 131 | 1 132 | 1 133 | 1 134 | 1 135 | 1 136 | 1 137 | 1 138 | 1 139 | 1 140 | 1 141 | 1 142 | 1 143 | 1 144 | 1 145 | 1 146 | 1 147 | 1 148 | 1 149 | 1 150 | 1 151 | 1 152 | 1 153 | 1 154 | 1 155 | 1 156 | 1 157 | 1 158 | 1 159 | 1 160 | 1 161 | 1 162 | 1 163 | 1 164 | 1 165 | 1 166 | 1 167 | 1 168 | 1 169 | 1 170 | 1 171 | 1 172 | 1 173 | 1 174 | 1 175 | 1 176 | 1 177 | 1 178 | 1 179 | 1 180 | 1 181 | 1 182 | 1 183 | 1 184 | 1 185 | 1 186 | 1 187 | 1 188 | 1 189 | 1 190 | 1 191 | 1 192 | 1 193 | -------------------------------------------------------------------------------- /examples/water/data/data_1/type_map.raw: -------------------------------------------------------------------------------- 1 | O 2 | H 3 | -------------------------------------------------------------------------------- /examples/water/data/data_2/set.000/box.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_2/set.000/box.npy -------------------------------------------------------------------------------- /examples/water/data/data_2/set.000/coord.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_2/set.000/coord.npy -------------------------------------------------------------------------------- /examples/water/data/data_2/set.000/energy.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_2/set.000/energy.npy -------------------------------------------------------------------------------- /examples/water/data/data_2/set.000/force.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_2/set.000/force.npy -------------------------------------------------------------------------------- /examples/water/data/data_2/type.raw: -------------------------------------------------------------------------------- 1 | 0 2 | 0 3 | 0 4 | 0 5 | 0 6 | 0 7 | 0 8 | 0 9 | 0 10 | 0 11 | 0 12 | 0 13 | 0 14 | 0 15 | 0 16 | 0 17 | 0 18 | 0 19 | 0 20 | 0 21 | 0 22 | 0 23 | 0 24 | 0 25 | 0 26 | 0 27 | 0 28 | 0 29 | 0 30 | 0 31 | 0 32 | 0 33 | 0 34 | 0 35 | 0 36 | 0 37 | 0 38 | 0 39 | 0 40 | 0 41 | 0 42 | 0 43 | 0 44 | 0 45 | 0 46 | 0 47 | 0 48 | 0 49 | 0 50 | 0 51 | 0 52 | 0 53 | 0 54 | 0 55 | 0 56 | 0 57 | 0 58 | 0 59 | 0 60 | 0 61 | 0 62 | 0 63 | 0 64 | 0 65 | 1 66 | 1 67 | 1 68 | 1 69 | 1 70 | 1 71 | 1 72 | 1 73 | 1 74 | 1 75 | 1 76 | 1 77 | 1 78 | 1 79 | 1 80 | 1 81 | 1 82 | 1 83 | 1 84 | 1 85 | 1 86 | 1 87 | 1 88 | 1 89 | 1 90 | 1 91 | 1 92 | 1 93 | 1 94 | 1 95 | 1 96 | 1 97 | 1 98 | 1 99 | 1 100 | 1 101 | 1 102 | 1 103 | 1 104 | 1 105 | 1 106 | 1 107 | 1 108 | 1 109 | 1 110 | 1 111 | 1 112 | 1 113 | 1 114 | 1 115 | 1 116 | 1 117 | 1 118 | 1 119 | 1 120 | 1 121 | 1 122 | 1 123 | 1 124 | 1 125 | 1 126 | 1 127 | 1 128 | 1 129 | 1 130 | 1 131 | 1 132 | 1 133 | 1 134 | 1 135 | 1 136 | 1 137 | 1 138 | 1 139 | 1 140 | 1 141 | 1 142 | 1 143 | 1 144 | 1 145 | 1 146 | 1 147 | 1 148 | 1 149 | 1 150 | 1 151 | 1 152 | 1 153 | 1 154 | 1 155 | 1 156 | 1 157 | 1 158 | 1 159 | 1 160 | 1 161 | 1 162 | 1 163 | 1 164 | 1 165 | 1 166 | 1 167 | 1 168 | 1 169 | 1 170 | 1 171 | 1 172 | 1 173 | 1 174 | 1 175 | 1 176 | 1 
177 | 1 178 | 1 179 | 1 180 | 1 181 | 1 182 | 1 183 | 1 184 | 1 185 | 1 186 | 1 187 | 1 188 | 1 189 | 1 190 | 1 191 | 1 192 | 1 193 | -------------------------------------------------------------------------------- /examples/water/data/data_2/type_map.raw: -------------------------------------------------------------------------------- 1 | O 2 | H 3 | -------------------------------------------------------------------------------- /examples/water/data/data_3/set.000/box.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_3/set.000/box.npy -------------------------------------------------------------------------------- /examples/water/data/data_3/set.000/coord.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_3/set.000/coord.npy -------------------------------------------------------------------------------- /examples/water/data/data_3/set.000/energy.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_3/set.000/energy.npy -------------------------------------------------------------------------------- /examples/water/data/data_3/set.000/force.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_3/set.000/force.npy -------------------------------------------------------------------------------- /examples/water/data/data_3/type.raw: -------------------------------------------------------------------------------- 1 | 0 2 | 0 3 | 0 4 | 0 5 | 0 6 | 0 
7 | 0 8 | 0 9 | 0 10 | 0 11 | 0 12 | 0 13 | 0 14 | 0 15 | 0 16 | 0 17 | 0 18 | 0 19 | 0 20 | 0 21 | 0 22 | 0 23 | 0 24 | 0 25 | 0 26 | 0 27 | 0 28 | 0 29 | 0 30 | 0 31 | 0 32 | 0 33 | 0 34 | 0 35 | 0 36 | 0 37 | 0 38 | 0 39 | 0 40 | 0 41 | 0 42 | 0 43 | 0 44 | 0 45 | 0 46 | 0 47 | 0 48 | 0 49 | 0 50 | 0 51 | 0 52 | 0 53 | 0 54 | 0 55 | 0 56 | 0 57 | 0 58 | 0 59 | 0 60 | 0 61 | 0 62 | 0 63 | 0 64 | 0 65 | 1 66 | 1 67 | 1 68 | 1 69 | 1 70 | 1 71 | 1 72 | 1 73 | 1 74 | 1 75 | 1 76 | 1 77 | 1 78 | 1 79 | 1 80 | 1 81 | 1 82 | 1 83 | 1 84 | 1 85 | 1 86 | 1 87 | 1 88 | 1 89 | 1 90 | 1 91 | 1 92 | 1 93 | 1 94 | 1 95 | 1 96 | 1 97 | 1 98 | 1 99 | 1 100 | 1 101 | 1 102 | 1 103 | 1 104 | 1 105 | 1 106 | 1 107 | 1 108 | 1 109 | 1 110 | 1 111 | 1 112 | 1 113 | 1 114 | 1 115 | 1 116 | 1 117 | 1 118 | 1 119 | 1 120 | 1 121 | 1 122 | 1 123 | 1 124 | 1 125 | 1 126 | 1 127 | 1 128 | 1 129 | 1 130 | 1 131 | 1 132 | 1 133 | 1 134 | 1 135 | 1 136 | 1 137 | 1 138 | 1 139 | 1 140 | 1 141 | 1 142 | 1 143 | 1 144 | 1 145 | 1 146 | 1 147 | 1 148 | 1 149 | 1 150 | 1 151 | 1 152 | 1 153 | 1 154 | 1 155 | 1 156 | 1 157 | 1 158 | 1 159 | 1 160 | 1 161 | 1 162 | 1 163 | 1 164 | 1 165 | 1 166 | 1 167 | 1 168 | 1 169 | 1 170 | 1 171 | 1 172 | 1 173 | 1 174 | 1 175 | 1 176 | 1 177 | 1 178 | 1 179 | 1 180 | 1 181 | 1 182 | 1 183 | 1 184 | 1 185 | 1 186 | 1 187 | 1 188 | 1 189 | 1 190 | 1 191 | 1 192 | 1 193 | -------------------------------------------------------------------------------- /examples/water/data/data_3/type_map.raw: -------------------------------------------------------------------------------- 1 | O 2 | H 3 | -------------------------------------------------------------------------------- /examples/water/mace/input.json: -------------------------------------------------------------------------------- 1 | { 2 | "_comment1": " model parameters", 3 | "model": { 4 | "type": "mace", 5 | "type_map": [ 6 | "O", 7 | "H" 8 | ], 9 | "r_max": 6.0, 10 | "sel": "auto", 11 | "hidden_irreps": 
"64x0e", 12 | "_comment2": " that's all" 13 | }, 14 | 15 | "learning_rate": { 16 | "type": "exp", 17 | "decay_steps": 5000, 18 | "start_lr": 0.001, 19 | "stop_lr": 3.51e-8, 20 | "_comment3": "that's all" 21 | }, 22 | 23 | "loss": { 24 | "type": "ener", 25 | "start_pref_e": 0.02, 26 | "limit_pref_e": 1, 27 | "start_pref_f": 1000, 28 | "limit_pref_f": 1, 29 | "start_pref_v": 0, 30 | "limit_pref_v": 0, 31 | "_comment4": " that's all" 32 | }, 33 | 34 | "training": { 35 | "training_data": { 36 | "systems": [ 37 | "../data/data_0/", 38 | "../data/data_1/", 39 | "../data/data_2/" 40 | ], 41 | "batch_size": "auto", 42 | "_comment5": "that's all" 43 | }, 44 | "validation_data": { 45 | "systems": [ 46 | "../data/data_3" 47 | ], 48 | "batch_size": 1, 49 | "numb_btch": 3, 50 | "_comment6": "that's all" 51 | }, 52 | "numb_steps": 1000000, 53 | "seed": 10, 54 | "disp_file": "lcurve.out", 55 | "disp_freq": 100, 56 | "save_freq": 1000, 57 | "_comment7": "that's all" 58 | }, 59 | 60 | "_comment8": "that's all" 61 | } 62 | -------------------------------------------------------------------------------- /examples/water/nequip/input.json: -------------------------------------------------------------------------------- 1 | { 2 | "_comment1": " model parameters", 3 | "model": { 4 | "type": "nequip", 5 | "type_map": [ 6 | "O", 7 | "H" 8 | ], 9 | "r_max": 6.0, 10 | "sel": "auto", 11 | "l_max": 1, 12 | "_comment2": " that's all" 13 | }, 14 | 15 | "learning_rate": { 16 | "type": "exp", 17 | "decay_steps": 5000, 18 | "start_lr": 0.001, 19 | "stop_lr": 3.51e-8, 20 | "_comment3": "that's all" 21 | }, 22 | 23 | "loss": { 24 | "type": "ener", 25 | "start_pref_e": 0.02, 26 | "limit_pref_e": 1, 27 | "start_pref_f": 1000, 28 | "limit_pref_f": 1, 29 | "start_pref_v": 0, 30 | "limit_pref_v": 0, 31 | "_comment4": " that's all" 32 | }, 33 | 34 | "training": { 35 | "training_data": { 36 | "systems": [ 37 | "../data/data_0/", 38 | "../data/data_1/", 39 | "../data/data_2/" 40 | ], 41 | "batch_size": 
"auto", 42 | "_comment5": "that's all" 43 | }, 44 | "validation_data": { 45 | "systems": [ 46 | "../data/data_3" 47 | ], 48 | "batch_size": 1, 49 | "numb_btch": 3, 50 | "_comment6": "that's all" 51 | }, 52 | "numb_steps": 1000000, 53 | "seed": 10, 54 | "disp_file": "lcurve.out", 55 | "disp_freq": 100, 56 | "save_freq": 1000, 57 | "_comment7": "that's all" 58 | }, 59 | 60 | "_comment8": "that's all" 61 | } 62 | -------------------------------------------------------------------------------- /noxfile.py: -------------------------------------------------------------------------------- 1 | """Nox configuration file.""" 2 | 3 | from __future__ import annotations 4 | 5 | import nox 6 | 7 | 8 | @nox.session 9 | def tests(session: nox.Session) -> None: 10 | """Run test suite with pytest.""" 11 | session.install( 12 | "numpy", 13 | "deepmd-kit[torch]>=3.0.0b2", 14 | "--extra-index-url", 15 | "https://download.pytorch.org/whl/cpu", 16 | ) 17 | cmake_prefix_path = session.run( 18 | "python", 19 | "-c", 20 | "import torch;print(torch.utils.cmake_prefix_path)", 21 | silent=True, 22 | ).strip() 23 | session.log(f"{cmake_prefix_path=}") 24 | session.install("-e.[test]", env={"CMAKE_PREFIX_PATH": cmake_prefix_path}) 25 | session.run( 26 | "pytest", 27 | "--cov", 28 | "--cov-config", 29 | "pyproject.toml", 30 | "--cov-report", 31 | "term", 32 | "--cov-report", 33 | "xml", 34 | ) 35 | -------------------------------------------------------------------------------- /op/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | file(GLOB OP_SRC edge_index.cc) 2 | 3 | add_library(deepmd_gnn MODULE ${OP_SRC}) 4 | # link: libdeepmd libtorch 5 | target_link_libraries(deepmd_gnn PRIVATE ${TORCH_LIBRARIES}) 6 | target_compile_definitions( 7 | deepmd_gnn 8 | PUBLIC "$<$:_GLIBCXX_USE_CXX11_ABI=${OP_CXX_ABI_PT}>") 9 | if(APPLE) 10 | set_target_properties(deepmd_gnn PROPERTIES INSTALL_RPATH "@loader_path") 11 | else() 12 | 
set_target_properties(deepmd_gnn PROPERTIES INSTALL_RPATH "$ORIGIN") 13 | endif() 14 | 15 | if(BUILD_PY_IF) 16 | install(TARGETS deepmd_gnn DESTINATION deepmd_gnn/lib/) 17 | file(TOUCH ${CMAKE_CURRENT_BINARY_DIR}/__init__.py) 18 | install(FILES ${CMAKE_CURRENT_BINARY_DIR}/__init__.py 19 | DESTINATION deepmd_gnn/lib) 20 | else(BUILD_PY_IF) 21 | install(TARGETS deepmd_gnn DESTINATION lib/) 22 | endif(BUILD_PY_IF) 23 | -------------------------------------------------------------------------------- /op/edge_index.cc: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: LGPL-3.0-or-later 2 | #include 3 | 4 | #include 5 | 6 | torch::Tensor edge_index_kernel(const torch::Tensor &nlist_tensor, 7 | const torch::Tensor &atype_tensor, 8 | const torch::Tensor &mm_tensor) { 9 | torch::Tensor nlist_tensor_ = nlist_tensor.cpu().contiguous(); 10 | torch::Tensor atype_tensor_ = atype_tensor.cpu().contiguous(); 11 | torch::Tensor mm_tensor_ = mm_tensor.cpu().contiguous(); 12 | if (nlist_tensor_.dim() == 2) { 13 | nlist_tensor_ = 14 | nlist_tensor_.view({1, nlist_tensor_.size(0), nlist_tensor_.size(1)}); 15 | if (atype_tensor_.dim() != 1) { 16 | throw std::invalid_argument("atype_tensor must be 1D"); 17 | } 18 | atype_tensor_ = atype_tensor_.view({1, atype_tensor_.size(0)}); 19 | } else if (nlist_tensor_.dim() == 3) { 20 | if (atype_tensor_.dim() != 2) { 21 | throw std::invalid_argument("atype_tensor must be 2D"); 22 | } 23 | } else { 24 | throw std::invalid_argument("nlist_tensor must be 2D or 3D"); 25 | } 26 | 27 | const int64_t nf = nlist_tensor_.size(0); 28 | const int64_t nloc = nlist_tensor_.size(1); 29 | const int64_t nnei = nlist_tensor_.size(2); 30 | if (atype_tensor_.size(0) != nf) { 31 | throw std::invalid_argument( 32 | "atype_tensor must have the same size as nlist_tensor"); 33 | } 34 | const int64_t nall = atype_tensor_.size(1); 35 | const int64_t nmm = mm_tensor_.size(0); 36 | int64_t *nlist = 
nlist_tensor_.view({-1}).data_ptr(); 37 | int64_t *atype = atype_tensor_.view({-1}).data_ptr(); 38 | int64_t *mm = mm_tensor_.view({-1}).data_ptr(); 39 | 40 | std::vector edge_index; 41 | edge_index.reserve(nf * nloc * nnei * 2); 42 | 43 | for (int64_t ff = 0; ff < nf; ff++) { 44 | for (int64_t ii = 0; ii < nloc; ii++) { 45 | for (int64_t jj = 0; jj < nnei; jj++) { 46 | int64_t idx = ff * nloc * nnei + ii * nnei + jj; 47 | int64_t kk = nlist[idx]; 48 | if (kk < 0) { 49 | continue; 50 | } 51 | int64_t global_kk = ff * nall + kk; 52 | int64_t global_ii = ff * nall + ii; 53 | // skip the edge when both atom ii and atom kk have a type listed in mm 54 | bool in_mm1 = false; 55 | for (int64_t mm_idx = 0; mm_idx < nmm; mm_idx++) { 56 | if (atype[global_ii] == mm[mm_idx]) { 57 | in_mm1 = true; 58 | break; 59 | } 60 | } 61 | bool in_mm2 = false; 62 | for (int64_t mm_idx = 0; mm_idx < nmm; mm_idx++) { 63 | if (atype[global_kk] == mm[mm_idx]) { 64 | in_mm2 = true; 65 | break; 66 | } 67 | } 68 | if (in_mm1 && in_mm2) { 69 | continue; 70 | } 71 | // add edge 72 | edge_index.push_back(global_kk); 73 | edge_index.push_back(global_ii); 74 | } 75 | } 76 | } 77 | // convert to tensor 78 | int64_t edge_size = edge_index.size() / 2; 79 | torch::Tensor edge_index_tensor = 80 | torch::tensor(edge_index, torch::kInt64).view({edge_size, 2}); 81 | // move the result back to the device of the input nlist_tensor 82 | return edge_index_tensor.to(nlist_tensor.device()); 83 | } 84 | 85 | TORCH_LIBRARY(deepmd_gnn, m) { m.def("edge_index", edge_index_kernel); } 86 | // compatibility with old models frozen by the deepmd_mace package 87 | TORCH_LIBRARY(deepmd_mace, m) { m.def("mace_edge_index", edge_index_kernel); } 88 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "scikit-build-core>=0.3.0", 4 | ] 5 | build-backend = "build_backend.dp_backend" 6 | backend-path = ["."] 7 | 8 | [project] 9
| name = "deepmd-gnn" 10 | dynamic = ["version"] 11 | description = "DeePMD-kit plugin for graph neural network models." 12 | authors = [ 13 | { name = "Jinzhe Zeng", email = "jinzhe.zeng@rutgers.edu"}, 14 | ] 15 | license = { file = "LICENSE" } 16 | classifiers = [ 17 | "Programming Language :: Python :: 3.9", 18 | "Programming Language :: Python :: 3.10", 19 | "Programming Language :: Python :: 3.11", 20 | "Programming Language :: Python :: 3.12", 21 | "Operating System :: POSIX :: Linux", 22 | "Operating System :: MacOS :: MacOS X", 23 | "Operating System :: Microsoft :: Windows", 24 | "License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)", 25 | ] 26 | dependencies = [ 27 | "torch", 28 | "deepmd-kit[torch]>=3.0.0b2", 29 | "mace-torch>=0.3.5", 30 | "nequip", 31 | "e3nn", 32 | "dargs", 33 | ] 34 | requires-python = ">=3.9" 35 | readme = "README.md" 36 | keywords = [ 37 | ] 38 | 39 | [project.scripts] 40 | 41 | [project.entry-points."deepmd.pt"] 42 | mace = "deepmd_gnn.mace:MaceModel" 43 | nequip = "deepmd_gnn.nequip:NequipModel" 44 | 45 | [project.urls] 46 | repository = "https://gitlab.com/RutgersLBSR/deepmd-gnn" 47 | 48 | [project.optional-dependencies] 49 | test = [ 50 | 'pytest', 51 | 'pytest-cov', 52 | "dargs>=0.4.8", 53 | ] 54 | docs = [ 55 | "sphinx", 56 | "sphinx-autoapi", 57 | "myst-parser", 58 | "deepmodeling-sphinx>=0.3.0", 59 | "sphinx-book-theme", 60 | "dargs", 61 | ] 62 | 63 | [tool.scikit-build] 64 | wheel.py-api = "py2.py3" 65 | metadata.version.provider = "scikit_build_core.metadata.setuptools_scm" 66 | sdist.include = [ 67 | "/deepmd_gnn/_version.py", 68 | ] 69 | 70 | [tool.scikit-build.cmake.define] 71 | BUILD_PY_IF = true 72 | BUILD_CPP_IF = false 73 | 74 | [tool.setuptools_scm] 75 | version_file = "deepmd_gnn/_version.py" 76 | 77 | [tool.ruff.lint] 78 | select = [ 79 | "ALL", 80 | ] 81 | ignore = [ 82 | "PLR0912", 83 | "PLR0913", # Too many arguments in function definition 84 | "PLR0915", 85 | "PLR2004", 86 | "FBT001", 87 | 
"FBT002", 88 | "N803", 89 | "FA100", 90 | "S603", 91 | "ANN101", 92 | "ANN102", 93 | "C901", 94 | "E501", 95 | ] 96 | 97 | [tool.ruff.lint.pydocstyle] 98 | convention = "numpy" 99 | 100 | [tool.ruff.lint.extend-per-file-ignores] 101 | "tests/**/*.py" = [ 102 | "S101", # asserts allowed in tests... 103 | "ANN", 104 | "D101", 105 | "D102", 106 | ] 107 | "docs/conf.py" = [ 108 | "ERA001", 109 | "INP001", 110 | ] 111 | 112 | [tool.coverage.report] 113 | include = ["deepmd_gnn/*"] 114 | 115 | 116 | [tool.cibuildwheel] 117 | test-command = [ 118 | """python -c "import deepmd_gnn.op" """, 119 | ] 120 | build = ["cp312-*"] 121 | skip = ["*-win32", "*-manylinux_i686", "*-musllinux*"] 122 | manylinux-x86_64-image = "manylinux_2_28" 123 | manylinux-aarch64-image = "manylinux_2_28" 124 | 125 | [tool.cibuildwheel.macos] 126 | repair-wheel-command = """delocate-wheel --require-archs {delocate_archs} -w {dest_dir} -v {wheel} --ignore-missing-dependencies""" 127 | 128 | [tool.cibuildwheel.linux] 129 | repair-wheel-command = "auditwheel repair --exclude libc10.so --exclude libtorch.so --exclude libtorch_cpu.so -w {dest_dir} {wheel}" 130 | environment-pass = [ 131 | "CIBW_BUILD", 132 | ] 133 | [tool.cibuildwheel.linux.environment] 134 | # use CPU version of torch for building, which should also work for GPU 135 | # note: uv has different behavior from pip on extra index url 136 | # https://github.com/astral-sh/uv/blob/main/PIP_COMPATIBILITY.md#packages-that-exist-on-multiple-indexes 137 | UV_EXTRA_INDEX_URL = "https://download.pytorch.org/whl/cpu" 138 | -------------------------------------------------------------------------------- /renovate.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://docs.renovatebot.com/renovate-schema.json", 3 | "extends": [ 4 | "local>njzjz/renovate-config" 5 | ] 6 | } 7 | -------------------------------------------------------------------------------- /tests/__init__.py: 
-------------------------------------------------------------------------------- 1 | """Tests.""" 2 | -------------------------------------------------------------------------------- /tests/data/set.000/box.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/tests/data/set.000/box.npy -------------------------------------------------------------------------------- /tests/data/set.000/coord.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/tests/data/set.000/coord.npy -------------------------------------------------------------------------------- /tests/data/set.000/energy.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/tests/data/set.000/energy.npy -------------------------------------------------------------------------------- /tests/data/set.000/force.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/tests/data/set.000/force.npy -------------------------------------------------------------------------------- /tests/data/type.raw: -------------------------------------------------------------------------------- 1 | 0 2 | 0 3 | 0 4 | 0 5 | 0 6 | 0 7 | 0 8 | 0 9 | 0 10 | 0 11 | 0 12 | 0 13 | 0 14 | 0 15 | 0 16 | 0 17 | 0 18 | 0 19 | 0 20 | 0 21 | 0 22 | 0 23 | 0 24 | 0 25 | 0 26 | 0 27 | 0 28 | 0 29 | 0 30 | 0 31 | 0 32 | 0 33 | 0 34 | 0 35 | 0 36 | 0 37 | 0 38 | 0 39 | 0 40 | 0 41 | 0 42 | 0 43 | 0 44 | 0 45 | 0 46 | 0 47 | 0 48 | 0 49 | 0 50 | 0 51 | 0 52 | 0 53 | 0 54 | 0 55 | 0 56 | 0 57 | 0 58 | 0 59 | 0 60 | 0 61 | 0 62 | 0 63 | 0 
64 | 0 65 | 1 66 | 1 67 | 1 68 | 1 69 | 1 70 | 1 71 | 1 72 | 1 73 | 1 74 | 1 75 | 1 76 | 1 77 | 1 78 | 1 79 | 1 80 | 1 81 | 1 82 | 1 83 | 1 84 | 1 85 | 1 86 | 1 87 | 1 88 | 1 89 | 1 90 | 1 91 | 1 92 | 1 93 | 1 94 | 1 95 | 1 96 | 1 97 | 1 98 | 1 99 | 1 100 | 1 101 | 1 102 | 1 103 | 1 104 | 1 105 | 1 106 | 1 107 | 1 108 | 1 109 | 1 110 | 1 111 | 1 112 | 1 113 | 1 114 | 1 115 | 1 116 | 1 117 | 1 118 | 1 119 | 1 120 | 1 121 | 1 122 | 1 123 | 1 124 | 1 125 | 1 126 | 1 127 | 1 128 | 1 129 | 1 130 | 1 131 | 1 132 | 1 133 | 1 134 | 1 135 | 1 136 | 1 137 | 1 138 | 1 139 | 1 140 | 1 141 | 1 142 | 1 143 | 1 144 | 1 145 | 1 146 | 1 147 | 1 148 | 1 149 | 1 150 | 1 151 | 1 152 | 1 153 | 1 154 | 1 155 | 1 156 | 1 157 | 1 158 | 1 159 | 1 160 | 1 161 | 1 162 | 1 163 | 1 164 | 1 165 | 1 166 | 1 167 | 1 168 | 1 169 | 1 170 | 1 171 | 1 172 | 1 173 | 1 174 | 1 175 | 1 176 | 1 177 | 1 178 | 1 179 | 1 180 | 1 181 | 1 182 | 1 183 | 1 184 | 1 185 | 1 186 | 1 187 | 1 188 | 1 189 | 1 190 | 1 191 | 1 192 | 1 193 | -------------------------------------------------------------------------------- /tests/data/type_map.raw: -------------------------------------------------------------------------------- 1 | O 2 | H 3 | -------------------------------------------------------------------------------- /tests/mace.json: -------------------------------------------------------------------------------- 1 | { 2 | "_comment1": " model parameters", 3 | "model": { 4 | "type": "mace", 5 | "type_map": [ 6 | "O", 7 | "H" 8 | ], 9 | "r_max": 6.0, 10 | "sel": "auto", 11 | "hidden_irreps": "16x0e", 12 | "_comment2": " that's all" 13 | }, 14 | 15 | "learning_rate": { 16 | "type": "exp", 17 | "decay_steps": 5000, 18 | "start_lr": 0.001, 19 | "stop_lr": 0.0009, 20 | "_comment3": "that's all" 21 | }, 22 | 23 | "loss": { 24 | "type": "ener", 25 | "start_pref_e": 0.02, 26 | "limit_pref_e": 1, 27 | "start_pref_f": 1000, 28 | "limit_pref_f": 1, 29 | "start_pref_v": 0, 30 | "limit_pref_v": 0, 31 | "_comment4": " that's 
all" 32 | }, 33 | 34 | "training": { 35 | "training_data": { 36 | "systems": [ 37 | "./data" 38 | ], 39 | "batch_size": "auto", 40 | "_comment5": "that's all" 41 | }, 42 | "validation_data": { 43 | "systems": [ 44 | "./data" 45 | ], 46 | "batch_size": 1, 47 | "numb_btch": 3, 48 | "_comment6": "that's all" 49 | }, 50 | "numb_steps": 2, 51 | "seed": 10, 52 | "disp_file": "lcurve.out", 53 | "disp_freq": 1, 54 | "save_freq": 1, 55 | "_comment7": "that's all" 56 | }, 57 | 58 | "_comment8": "that's all" 59 | } 60 | -------------------------------------------------------------------------------- /tests/nequip.json: -------------------------------------------------------------------------------- 1 | { 2 | "_comment1": " model parameters", 3 | "model": { 4 | "type": "mace", 5 | "type_map": [ 6 | "O", 7 | "H" 8 | ], 9 | "r_max": 6.0, 10 | "sel": "auto", 11 | "hidden_irreps": "16x0e", 12 | "_comment2": " that's all" 13 | }, 14 | 15 | "learning_rate": { 16 | "type": "exp", 17 | "decay_steps": 5000, 18 | "start_lr": 0.001, 19 | "stop_lr": 0.0009, 20 | "_comment3": "that's all" 21 | }, 22 | 23 | "loss": { 24 | "type": "ener", 25 | "start_pref_e": 0.02, 26 | "limit_pref_e": 1, 27 | "start_pref_f": 1000, 28 | "limit_pref_f": 1, 29 | "start_pref_v": 0, 30 | "limit_pref_v": 0, 31 | "_comment4": " that's all" 32 | }, 33 | 34 | "training": { 35 | "training_data": { 36 | "systems": [ 37 | "./data" 38 | ], 39 | "batch_size": "auto", 40 | "_comment5": "that's all" 41 | }, 42 | "validation_data": { 43 | "systems": [ 44 | "./data" 45 | ], 46 | "batch_size": 1, 47 | "numb_btch": 3, 48 | "_comment6": "that's all" 49 | }, 50 | "numb_steps": 2, 51 | "seed": 10, 52 | "disp_file": "lcurve.out", 53 | "disp_freq": 1, 54 | "save_freq": 1, 55 | "_comment7": "that's all" 56 | }, 57 | 58 | "_comment8": "that's all" 59 | } 60 | -------------------------------------------------------------------------------- /tests/test_examples.py: 
-------------------------------------------------------------------------------- 1 | """Test examples.""" 2 | 3 | import json 4 | from pathlib import Path 5 | 6 | import pytest 7 | from dargs.check import check 8 | from deepmd.utils.argcheck import gen_args 9 | 10 | from deepmd_gnn.argcheck import mace_model_args # noqa: F401 11 | 12 | example_path = Path(__file__).parent.parent / "examples" 13 | 14 | examples = ( 15 | example_path / "water" / "mace" / "input.json", 16 | example_path / "dprc" / "mace" / "input.json", 17 | example_path / "water" / "nequip" / "input.json", 18 | example_path / "dprc" / "nequip" / "input.json", 19 | ) 20 | 21 | 22 | @pytest.mark.parametrize("example", examples) 23 | def test_examples(example: Path) -> None: 24 | """Check whether examples meet arguments.""" 25 | with example.open("r") as f: 26 | data = json.load(f) 27 | check( 28 | gen_args(), 29 | data, 30 | ) 31 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: LGPL-3.0-or-later 2 | """Test models.""" 3 | 4 | import unittest 5 | from copy import deepcopy 6 | from typing import Any, Callable, ClassVar, Optional 7 | 8 | import deepmd.pt.model # noqa: F401 9 | import numpy as np 10 | import torch 11 | from deepmd.dpmodel.output_def import ( 12 | check_deriv, 13 | ) 14 | from deepmd.dpmodel.utils.nlist import ( 15 | build_neighbor_list, 16 | extend_coord_with_ghosts, 17 | extend_input_and_build_neighbor_list, 18 | ) 19 | from deepmd.dpmodel.utils.region import ( 20 | normalize_coord, 21 | ) 22 | from deepmd.pt.utils.utils import ( 23 | to_numpy_array, 24 | to_torch_tensor, 25 | ) 26 | 27 | from deepmd_gnn.mace import MaceModel 28 | from deepmd_gnn.nequip import NequipModel 29 | 30 | GLOBAL_SEED = 20240822 31 | 32 | torch.set_default_dtype(torch.float64) 33 | 34 | 35 | class PTTestCase: 36 | """Common test case.""" 37 | 38 | 
module: "torch.nn.Module" 39 | """PT module to test.""" 40 | 41 | skipTest: Callable[[str], None] # noqa: N815 42 | """Skip test method.""" 43 | 44 | @property 45 | def script_module(self) -> torch.jit.ScriptModule: 46 | """Script module.""" 47 | with torch.jit.optimized_execution(should_optimize=False): 48 | return torch.jit.script(self.module) 49 | 50 | @property 51 | def deserialized_module(self) -> "torch.nn.Module": 52 | """Deserialized module.""" 53 | return self.module.deserialize(self.module.serialize()) 54 | 55 | @property 56 | def modules_to_test(self) -> list["torch.nn.Module"]: 57 | """Modules to test.""" 58 | return [ 59 | self.module, 60 | self.deserialized_module, 61 | ] 62 | 63 | def test_jit(self) -> None: 64 | """Test jit.""" 65 | if getattr(self, "skip_test_jit", False): 66 | self.skipTest("Skip test jit.") 67 | self.script_module # noqa: B018 68 | 69 | @classmethod 70 | def convert_to_numpy(cls, xx: torch.Tensor) -> np.ndarray: 71 | """Convert to numpy array.""" 72 | return to_numpy_array(xx) 73 | 74 | @classmethod 75 | def convert_from_numpy(cls, xx: np.ndarray) -> torch.Tensor: 76 | """Convert from numpy array.""" 77 | return to_torch_tensor(xx) 78 | 79 | def forward_wrapper_cpu_ref(self, module): 80 | module.to("cpu") 81 | return self.forward_wrapper(module, on_cpu=True) 82 | 83 | def forward_wrapper(self, module, on_cpu=False): 84 | def create_wrapper_method(method): 85 | def wrapper_method(self, *args, **kwargs): # noqa: ARG001 86 | # convert to torch tensor 87 | args = [to_torch_tensor(arg) for arg in args] 88 | kwargs = {k: to_torch_tensor(v) for k, v in kwargs.items()} 89 | if on_cpu: 90 | args = [ 91 | arg.detach().cpu() if arg is not None else None for arg in args 92 | ] 93 | kwargs = { 94 | k: v.detach().cpu() if v is not None else None 95 | for k, v in kwargs.items() 96 | } 97 | # forward 98 | output = method(*args, **kwargs) 99 | # convert to numpy array 100 | if isinstance(output, tuple): 101 | output = tuple(to_numpy_array(o) for 
o in output) 102 | elif isinstance(output, dict): 103 | output = {k: to_numpy_array(v) for k, v in output.items()} 104 | else: 105 | output = to_numpy_array(output) 106 | return output 107 | 108 | return wrapper_method 109 | 110 | class WrapperModule: 111 | __call__ = create_wrapper_method(module.__call__) 112 | if hasattr(module, "forward_lower"): 113 | forward_lower = create_wrapper_method(module.forward_lower) 114 | 115 | return WrapperModule() 116 | 117 | 118 | class ModelTestCase: 119 | """Common test case for model.""" 120 | 121 | module: torch.nn.Module 122 | """Module to test.""" 123 | modules_to_test: list[torch.nn.Module] 124 | """Modules to test.""" 125 | expected_type_map: list[str] 126 | """Expected type map.""" 127 | expected_rcut: float 128 | """Expected cut-off radius.""" 129 | expected_dim_fparam: int 130 | """Expected number (dimension) of frame parameters.""" 131 | expected_dim_aparam: int 132 | """Expected number (dimension) of atomic parameters.""" 133 | expected_sel_type: list[int] 134 | """Expected selected atom types.""" 135 | expected_aparam_nall: bool 136 | """Expected shape of atomic parameters.""" 137 | expected_model_output_type: list[str] 138 | """Expected output type for the model.""" 139 | model_output_equivariant: list[str] 140 | """Outputs that are equivariant to the input rotation.""" 141 | expected_sel: list[int] 142 | """Expected number of neighbors.""" 143 | expected_has_message_passing: bool 144 | """Expected whether having message passing.""" 145 | expected_nmpnn: int 146 | """Expected number of MPNN.""" 147 | forward_wrapper: ClassVar[Callable[[Any, bool], Any]] 148 | """Class wrapper for forward method.""" 149 | forward_wrapper_cpu_ref: Callable[[Any], Any] 150 | """Convert model to CPU method.""" 151 | aprec_dict: dict[str, Optional[float]] 152 | """Dictionary of absolute precision in each test.""" 153 | rprec_dict: dict[str, Optional[float]] 154 | """Dictionary of relative precision in each test.""" 155 | epsilon_dict: 
dict[str, Optional[float]] 156 | """Dictionary of epsilons in each test.""" 157 | 158 | skipTest: Callable[[str], None] # noqa: N815 159 | """Skip test method.""" 160 | output_def: dict[str, Any] 161 | """Output definition.""" 162 | 163 | def test_get_type_map(self) -> None: 164 | """Test get_type_map.""" 165 | for module in self.modules_to_test: 166 | assert module.get_type_map() == self.expected_type_map 167 | 168 | def test_get_rcut(self) -> None: 169 | """Test get_rcut.""" 170 | for module in self.modules_to_test: 171 | assert module.get_rcut() == self.expected_rcut * self.expected_nmpnn 172 | 173 | def test_get_dim_fparam(self) -> None: 174 | """Test get_dim_fparam.""" 175 | for module in self.modules_to_test: 176 | assert module.get_dim_fparam() == self.expected_dim_fparam 177 | 178 | def test_get_dim_aparam(self) -> None: 179 | """Test get_dim_aparam.""" 180 | for module in self.modules_to_test: 181 | assert module.get_dim_aparam() == self.expected_dim_aparam 182 | 183 | def test_get_sel_type(self) -> None: 184 | """Test get_sel_type.""" 185 | for module in self.modules_to_test: 186 | assert module.get_sel_type() == self.expected_sel_type 187 | 188 | def test_is_aparam_nall(self) -> None: 189 | """Test is_aparam_nall.""" 190 | for module in self.modules_to_test: 191 | assert module.is_aparam_nall() == self.expected_aparam_nall 192 | 193 | def test_model_output_type(self) -> None: 194 | """Test model_output_type.""" 195 | for module in self.modules_to_test: 196 | assert module.model_output_type() == self.expected_model_output_type 197 | 198 | def test_get_nnei(self) -> None: 199 | """Test get_nnei.""" 200 | expected_nnei = sum(self.expected_sel) 201 | for module in self.modules_to_test: 202 | assert module.get_nnei() == expected_nnei 203 | 204 | def test_get_ntypes(self) -> None: 205 | """Test get_ntypes.""" 206 | for module in self.modules_to_test: 207 | assert module.get_ntypes() == len(self.expected_type_map) 208 | 209 | def test_has_message_passing(self) 
-> None: 210 | """Test has_message_passing.""" 211 | for module in self.modules_to_test: 212 | assert module.has_message_passing() == self.expected_has_message_passing 213 | 214 | def test_forward(self) -> None: 215 | """Test forward and forward_lower.""" 216 | test_spin = getattr(self, "test_spin", False) 217 | nf = 2 218 | natoms = 5 219 | aprec = ( 220 | 0 221 | if self.aprec_dict.get("test_forward", None) is None 222 | else self.aprec_dict["test_forward"] 223 | ) 224 | rng = np.random.default_rng(GLOBAL_SEED) 225 | coord = 4.0 * rng.random([1, natoms, 3]).repeat(nf, 0).reshape([nf, -1]) 226 | atype = np.array([[0, 0, 0, 1, 1] * nf], dtype=int).reshape([nf, -1]) 227 | spin = 0.5 * rng.random([1, natoms, 3]).repeat(nf, 0).reshape([nf, -1]) 228 | cell = 6.0 * np.repeat(np.eye(3)[None, ...], nf, axis=0).reshape([nf, 9]) 229 | coord_ext, atype_ext, mapping, nlist = extend_input_and_build_neighbor_list( 230 | coord, 231 | atype, 232 | self.expected_rcut + 1.0 if test_spin else self.expected_rcut, 233 | self.expected_sel, 234 | mixed_types=self.module.mixed_types(), 235 | box=cell, 236 | ) 237 | coord_normalized = normalize_coord( 238 | coord.reshape(nf, natoms, 3), 239 | cell.reshape(nf, 3, 3), 240 | ) 241 | coord_ext_large, atype_ext_large, mapping_large = extend_coord_with_ghosts( 242 | coord_normalized, 243 | atype, 244 | cell, 245 | self.module.get_rcut(), 246 | ) 247 | nlist_large = build_neighbor_list( 248 | coord_ext_large, 249 | atype_ext_large, 250 | natoms, 251 | self.expected_rcut, 252 | self.expected_sel, 253 | distinguish_types=(not self.module.mixed_types()), 254 | ) 255 | spin_ext = np.take_along_axis( 256 | spin.reshape(nf, -1, 3), 257 | np.repeat(np.expand_dims(mapping, axis=-1), 3, axis=-1), 258 | axis=1, 259 | ) 260 | aparam = None 261 | fparam = None 262 | if self.module.get_dim_aparam() > 0: 263 | aparam = rng.random([nf, natoms, self.module.get_dim_aparam()]) 264 | if self.module.get_dim_fparam() > 0: 265 | fparam = rng.random([nf, 
self.module.get_dim_fparam()]) 266 | ret = [] 267 | ret_lower = [] 268 | for _module in self.modules_to_test: 269 | module = self.forward_wrapper(_module) 270 | input_dict = { 271 | "coord": coord, 272 | "atype": atype, 273 | "box": cell, 274 | "aparam": aparam, 275 | "fparam": fparam, 276 | } 277 | if test_spin: 278 | input_dict["spin"] = spin 279 | ret.append(module(**input_dict)) 280 | 281 | input_dict_lower = { 282 | "extended_coord": coord_ext_large, 283 | "extended_atype": atype_ext_large, 284 | "nlist": nlist_large, 285 | "aparam": aparam, 286 | "fparam": fparam, 287 | "mapping": mapping_large, 288 | } 289 | if test_spin: 290 | input_dict_lower["extended_spin"] = spin_ext 291 | 292 | # use shuffled nlist, simulating the lammps interface 293 | rng.shuffle(input_dict_lower["nlist"], axis=-1) 294 | ret_lower.append(module.forward_lower(**input_dict_lower)) 295 | 296 | input_dict_lower = { 297 | "extended_coord": coord_ext_large, 298 | "extended_atype": atype_ext_large, 299 | "nlist": nlist_large, 300 | "aparam": aparam, 301 | "fparam": fparam, 302 | } 303 | if test_spin: 304 | input_dict_lower["extended_spin"] = spin_ext 305 | 306 | # use shuffled nlist, simulating the lammps interface 307 | rng.shuffle(input_dict_lower["nlist"], axis=-1) 308 | ret_lower.append(module.forward_lower(**input_dict_lower)) 309 | 310 | for kk in ret[0]: 311 | # ensure the first frame and the second frame are the same 312 | np.testing.assert_allclose( 313 | ret[0][kk][0], 314 | ret[0][kk][1], 315 | err_msg=f"compare {kk} between frame 0 and 1", 316 | ) 317 | 318 | subret = [rr[kk] for rr in ret if rr is not None] 319 | if subret: 320 | for ii, rr in enumerate(subret[1:]): 321 | if subret[0] is None: 322 | assert rr is None 323 | else: 324 | np.testing.assert_allclose( 325 | subret[0], 326 | rr, 327 | err_msg=f"compare {kk} between 0 and {ii}", 328 | ) 329 | for kk in ret_lower[0]: 330 | subret = [] 331 | for rr in ret_lower: 332 | if rr is not None: 333 | subret.append(rr[kk]) 334 | 
if len(subret): 335 | for ii, rr in enumerate(subret[1:]): 336 | if kk == "expanded_force": 337 | # use mapping to scatter sum the forces 338 | rr = np.take_along_axis( # noqa: PLW2901 339 | rr, 340 | np.repeat( 341 | np.expand_dims(mapping_large, axis=-1), 342 | 3, 343 | axis=-1, 344 | ), 345 | axis=1, 346 | ) 347 | if subret[0] is None: 348 | assert rr is None 349 | else: 350 | np.testing.assert_allclose( 351 | subret[0], 352 | rr, 353 | atol=1e-5, 354 | err_msg=f"compare {kk} between 0 and {ii}", 355 | ) 356 | same_keys = set(ret[0].keys()) & set(ret_lower[0].keys()) 357 | assert same_keys 358 | for key in same_keys: 359 | for rr in ret: 360 | if rr[key] is not None: 361 | rr1 = rr[key] 362 | break 363 | else: 364 | continue 365 | for rr in ret_lower: 366 | if rr[key] is not None: 367 | rr2 = rr[key] 368 | break 369 | else: 370 | continue 371 | np.testing.assert_allclose(rr1, rr2, atol=aprec) 372 | 373 | def test_permutation(self) -> None: 374 | """Test permutation.""" 375 | if getattr(self, "skip_test_permutation", False): 376 | self.skipTest("Skip test permutation.") 377 | test_spin = getattr(self, "test_spin", False) 378 | rng = np.random.default_rng(GLOBAL_SEED) 379 | natoms = 5 380 | nf = 1 381 | aprec = ( 382 | 0 383 | if self.aprec_dict.get("test_permutation", None) is None 384 | else self.aprec_dict["test_permutation"] 385 | ) 386 | idx = [0, 1, 2, 3, 4] 387 | idx_perm = [1, 0, 4, 3, 2] 388 | cell = rng.random([3, 3]) 389 | cell = (cell + cell.T) + 5.0 * np.eye(3) 390 | coord = rng.random([natoms, 3]) 391 | coord = np.matmul(coord, cell) 392 | spin = 0.1 * rng.random([natoms, 3]) 393 | atype = np.array([0, 0, 0, 1, 1]) 394 | coord_perm = coord[idx_perm] 395 | spin_perm = spin[idx_perm] 396 | atype_perm = atype[idx_perm] 397 | 398 | # reshape for input 399 | coord = coord.reshape([nf, -1]) 400 | coord_perm = coord_perm.reshape([nf, -1]) 401 | spin_perm = spin_perm.reshape([nf, -1]) 402 | atype = atype.reshape([nf, -1]) 403 | atype_perm = 
atype_perm.reshape([nf, -1]) 404 | cell = cell.reshape([nf, 9]) 405 | 406 | aparam = None 407 | fparam = None 408 | aparam_perm = None 409 | if self.module.get_dim_aparam() > 0: 410 | aparam = rng.random([nf, natoms, self.module.get_dim_aparam()]) 411 | aparam_perm = aparam[:, idx_perm, :] 412 | if self.module.get_dim_fparam() > 0: 413 | fparam = rng.random([nf, self.module.get_dim_fparam()]) 414 | 415 | ret = [] 416 | module = self.forward_wrapper(self.module) 417 | input_dict = { 418 | "coord": coord, 419 | "atype": atype, 420 | "box": cell, 421 | "aparam": aparam, 422 | "fparam": fparam, 423 | } 424 | if test_spin: 425 | input_dict["spin"] = spin 426 | ret.append(module(**input_dict)) 427 | # permutation 428 | input_dict["coord"] = coord_perm 429 | input_dict["atype"] = atype_perm 430 | input_dict["aparam"] = aparam_perm 431 | if test_spin: 432 | input_dict["spin"] = spin_perm 433 | ret.append(module(**input_dict)) 434 | 435 | for kk in ret[0]: 436 | if kk in self.output_def: 437 | if ret[0][kk] is None: 438 | assert ret[1][kk] is None 439 | continue 440 | atomic = self.output_def[kk].atomic 441 | if atomic: 442 | np.testing.assert_allclose( 443 | ret[0][kk][:, idx_perm], 444 | ret[1][kk][:, idx], # for extended output 445 | err_msg=f"compare {kk} before and after transform", 446 | atol=aprec, 447 | ) 448 | else: 449 | np.testing.assert_allclose( 450 | ret[0][kk], 451 | ret[1][kk], 452 | err_msg=f"compare {kk} before and after transform", 453 | atol=aprec, 454 | ) 455 | else: 456 | msg = f"Unknown output key: {kk}" 457 | raise RuntimeError(msg) 458 | 459 | def test_trans(self) -> None: 460 | """Test translation.""" 461 | if getattr(self, "skip_test_trans", False): 462 | self.skipTest("Skip test translation.") 463 | test_spin = getattr(self, "test_spin", False) 464 | rng = np.random.default_rng(GLOBAL_SEED) 465 | natoms = 5 466 | nf = 1 467 | aprec = ( 468 | 1e-14 469 | if self.aprec_dict.get("test_rot", None) is None 470 | else self.aprec_dict["test_rot"] 471 | 
) 472 | cell = rng.random([3, 3]) 473 | cell = (cell + cell.T) + 5.0 * np.eye(3) 474 | coord = rng.random([natoms, 3]) 475 | coord = np.matmul(coord, cell) 476 | spin = 0.1 * rng.random([natoms, 3]) 477 | atype = np.array([0, 0, 0, 1, 1]) 478 | shift = (rng.random([3]) - 0.5) * 2.0 479 | coord_s = np.matmul( 480 | np.remainder(np.matmul(coord + shift, np.linalg.inv(cell)), 1.0), 481 | cell, 482 | ) 483 | 484 | # reshape for input 485 | coord = coord.reshape([nf, -1]) 486 | spin = spin.reshape([nf, -1]) 487 | coord_s = coord_s.reshape([nf, -1]) 488 | atype = atype.reshape([nf, -1]) 489 | cell = cell.reshape([nf, 9]) 490 | 491 | aparam = None 492 | fparam = None 493 | if self.module.get_dim_aparam() > 0: 494 | aparam = rng.random([nf, natoms, self.module.get_dim_aparam()]) 495 | if self.module.get_dim_fparam() > 0: 496 | fparam = rng.random([nf, self.module.get_dim_fparam()]) 497 | 498 | ret = [] 499 | module = self.forward_wrapper(self.module) 500 | input_dict = { 501 | "coord": coord, 502 | "atype": atype, 503 | "box": cell, 504 | "aparam": aparam, 505 | "fparam": fparam, 506 | } 507 | if test_spin: 508 | input_dict["spin"] = spin 509 | ret.append(module(**input_dict)) 510 | # translation 511 | input_dict["coord"] = coord_s 512 | ret.append(module(**input_dict)) 513 | 514 | for kk in ret[0]: 515 | if kk in self.output_def: 516 | if ret[0][kk] is None: 517 | assert ret[1][kk] is None 518 | continue 519 | np.testing.assert_allclose( 520 | ret[0][kk], 521 | ret[1][kk], 522 | err_msg=f"compare {kk} before and after transform", 523 | atol=aprec, 524 | ) 525 | else: 526 | msg = f"Unknown output key: {kk}" 527 | raise RuntimeError(msg) 528 | 529 | def test_rot(self) -> None: 530 | """Test rotation.""" 531 | if getattr(self, "skip_test_rot", False): 532 | self.skipTest("Skip test rotation.") 533 | test_spin = getattr(self, "test_spin", False) 534 | rng = np.random.default_rng(GLOBAL_SEED) 535 | natoms = 5 536 | nf = 1 537 | aprec = ( 538 | 0 539 | if 
self.aprec_dict.get("test_rot", None) is None 540 | else self.aprec_dict["test_rot"] 541 | ) 542 | # rotate only coord and shift to the center of cell 543 | cell = 10.0 * np.eye(3) 544 | coord = 2.0 * rng.random([natoms, 3]) 545 | spin = 0.1 * rng.random([natoms, 3]) 546 | atype = np.array([0, 0, 0, 1, 1]) 547 | shift = np.array([4.0, 4.0, 4.0]) 548 | from scipy.stats import ( 549 | special_ortho_group, 550 | ) 551 | 552 | rmat = special_ortho_group.rvs(3) 553 | coord_rot = np.matmul(coord, rmat) 554 | spin_rot = np.matmul(spin, rmat) 555 | 556 | # reshape for input 557 | coord = (coord + shift).reshape([nf, -1]) 558 | spin = spin.reshape([nf, -1]) 559 | coord_rot = (coord_rot + shift).reshape([nf, -1]) 560 | spin_rot = spin_rot.reshape([nf, -1]) 561 | atype = atype.reshape([nf, -1]) 562 | cell = cell.reshape([nf, 9]) 563 | 564 | aparam = None 565 | fparam = None 566 | if self.module.get_dim_aparam() > 0: 567 | aparam = rng.random([nf, natoms, self.module.get_dim_aparam()]) 568 | if self.module.get_dim_fparam() > 0: 569 | fparam = rng.random([nf, self.module.get_dim_fparam()]) 570 | 571 | ret = [] 572 | module = self.forward_wrapper(self.module) 573 | input_dict = { 574 | "coord": coord, 575 | "atype": atype, 576 | "box": cell, 577 | "aparam": aparam, 578 | "fparam": fparam, 579 | } 580 | if test_spin: 581 | input_dict["spin"] = spin 582 | ret.append(module(**input_dict)) 583 | # rotation 584 | input_dict["coord"] = coord_rot 585 | if test_spin: 586 | input_dict["spin"] = spin_rot 587 | ret.append(module(**input_dict)) 588 | 589 | for kk in ret[0]: 590 | if kk in self.output_def: 591 | if ret[0][kk] is None: 592 | assert ret[1][kk] is None 593 | continue 594 | rot_equivariant = ( 595 | check_deriv(self.output_def[kk]) 596 | or kk in self.model_output_equivariant 597 | ) 598 | if not rot_equivariant: 599 | np.testing.assert_allclose( 600 | ret[0][kk], 601 | ret[1][kk], 602 | err_msg=f"compare {kk} before and after transform", 603 | atol=aprec, 604 | ) 605 | else: 
606 | v_size = self.output_def[kk].size 607 | if v_size == 3: 608 | rotated_ret_0 = np.matmul(ret[0][kk], rmat) 609 | ret_1 = ret[1][kk] 610 | elif v_size == 9: 611 | ret_0 = ret[0][kk].reshape(-1, 3, 3) 612 | batch_rmat_t = np.repeat( 613 | rmat.T.reshape(1, 3, 3), 614 | ret_0.shape[0], 615 | axis=0, 616 | ) 617 | batch_rmat = np.repeat( 618 | rmat.reshape(1, 3, 3), 619 | ret_0.shape[0], 620 | axis=0, 621 | ) 622 | rotated_ret_0 = np.matmul( 623 | batch_rmat_t, 624 | np.matmul(ret_0, batch_rmat), 625 | ) 626 | ret_1 = ret[1][kk].reshape(-1, 3, 3) 627 | else: 628 | # unsupported dim 629 | continue 630 | np.testing.assert_allclose( 631 | rotated_ret_0, 632 | ret_1, 633 | err_msg=f"compare {kk} before and after transform", 634 | atol=aprec, 635 | ) 636 | else: 637 | msg = f"Unknown output key: {kk}" 638 | raise RuntimeError(msg) 639 | 640 | # rotate coord and cell 641 | cell = rng.random([3, 3]) 642 | cell = (cell + cell.T) + 5.0 * np.eye(3) 643 | coord = rng.random([natoms, 3]) 644 | coord = np.matmul(coord, cell) 645 | spin = 0.1 * rng.random([natoms, 3]) 646 | atype = np.array([0, 0, 0, 1, 1]) 647 | coord_rot = np.matmul(coord, rmat) 648 | cell_rot = np.matmul(cell, rmat) 649 | spin_rot = np.matmul(spin, rmat) 650 | 651 | # reshape for input 652 | coord = coord.reshape([nf, -1]) 653 | spin = spin.reshape([nf, -1]) 654 | coord_rot = coord_rot.reshape([nf, -1]) 655 | spin_rot = spin_rot.reshape([nf, -1]) 656 | atype = atype.reshape([nf, -1]) 657 | cell = cell.reshape([nf, 9]) 658 | cell_rot = cell_rot.reshape([nf, 9]) 659 | 660 | ret = [] 661 | module = self.forward_wrapper(self.module) 662 | input_dict = { 663 | "coord": coord, 664 | "atype": atype, 665 | "box": cell, 666 | "aparam": aparam, 667 | "fparam": fparam, 668 | } 669 | if test_spin: 670 | input_dict["spin"] = spin 671 | ret.append(module(**input_dict)) 672 | # rotation 673 | input_dict["coord"] = coord_rot 674 | input_dict["box"] = cell_rot 675 | if test_spin: 676 | input_dict["spin"] = spin_rot 677 | 
ret.append(module(**input_dict)) 678 | 679 | for kk in ret[0]: 680 | if kk in self.output_def: 681 | if ret[0][kk] is None: 682 | assert ret[1][kk] is None 683 | continue 684 | rot_equivariant = ( 685 | check_deriv(self.output_def[kk]) 686 | or kk in self.model_output_equivariant 687 | ) 688 | if not rot_equivariant: 689 | np.testing.assert_allclose( 690 | ret[0][kk], 691 | ret[1][kk], 692 | err_msg=f"compare {kk} before and after transform", 693 | atol=aprec, 694 | ) 695 | else: 696 | v_size = self.output_def[kk].size 697 | if v_size == 3: 698 | rotated_ret_0 = np.matmul(ret[0][kk], rmat) 699 | ret_1 = ret[1][kk] 700 | elif v_size == 9: 701 | ret_0 = ret[0][kk].reshape(-1, 3, 3) 702 | batch_rmat_t = np.repeat( 703 | rmat.T.reshape(1, 3, 3), 704 | ret_0.shape[0], 705 | axis=0, 706 | ) 707 | batch_rmat = np.repeat( 708 | rmat.reshape(1, 3, 3), 709 | ret_0.shape[0], 710 | axis=0, 711 | ) 712 | rotated_ret_0 = np.matmul( 713 | batch_rmat_t, 714 | np.matmul(ret_0, batch_rmat), 715 | ) 716 | ret_1 = ret[1][kk].reshape(-1, 3, 3) 717 | else: 718 | # unsupported dim 719 | continue 720 | np.testing.assert_allclose( 721 | rotated_ret_0, 722 | ret_1, 723 | err_msg=f"compare {kk} before and after transform", 724 | atol=aprec, 725 | ) 726 | else: 727 | msg = f"Unknown output key: {kk}" 728 | raise RuntimeError(msg) 729 | 730 | def test_smooth(self) -> None: 731 | """Test smooth.""" 732 | if getattr(self, "skip_test_smooth", False): 733 | self.skipTest("Skip test smooth.") 734 | test_spin = getattr(self, "test_spin", False) 735 | rng = np.random.default_rng(GLOBAL_SEED) 736 | epsilon = ( 737 | 1e-5 738 | if self.epsilon_dict.get("test_smooth", None) is None 739 | else self.epsilon_dict["test_smooth"] 740 | ) 741 | assert epsilon is not None 742 | # required prec. 
743 | rprec = ( 744 | 1e-5 745 | if self.rprec_dict.get("test_smooth", None) is None 746 | else self.rprec_dict["test_smooth"] 747 | ) 748 | aprec = ( 749 | 1e-5 750 | if self.aprec_dict.get("test_smooth", None) is None 751 | else self.aprec_dict["test_smooth"] 752 | ) 753 | natoms = 10 754 | nf = 1 755 | cell = 10.0 * np.eye(3) 756 | atype0 = np.arange(2) 757 | atype1 = rng.integers(0, 2, size=natoms - 2) 758 | atype = np.concatenate([atype0, atype1]).reshape(natoms) 759 | spin = 0.1 * rng.random([natoms, 3]) 760 | coord0 = np.array( 761 | [ 762 | 0.0, 763 | 0.0, 764 | 0.0, 765 | self.expected_rcut - 0.5 * epsilon, 766 | 0.0, 767 | 0.0, 768 | 0.0, 769 | self.expected_rcut - 0.5 * epsilon, 770 | 0.0, 771 | ], 772 | ).reshape(-1, 3) 773 | coord1 = rng.random([natoms - coord0.shape[0], 3]) 774 | coord1 = np.matmul(coord1, cell) 775 | coord = np.concatenate([coord0, coord1], axis=0) 776 | 777 | coord0 = deepcopy(coord) 778 | coord1 = deepcopy(coord) 779 | coord1[1][0] += epsilon 780 | coord2 = deepcopy(coord) 781 | coord2[2][1] += epsilon 782 | coord3 = deepcopy(coord) 783 | coord3[1][0] += epsilon 784 | coord3[2][1] += epsilon 785 | 786 | # reshape for input 787 | coord0 = coord0.reshape([nf, -1]) 788 | coord1 = coord1.reshape([nf, -1]) 789 | coord2 = coord2.reshape([nf, -1]) 790 | coord3 = coord3.reshape([nf, -1]) 791 | spin = spin.reshape([nf, -1]) 792 | atype = atype.reshape([nf, -1]) 793 | cell = cell.reshape([nf, 9]) 794 | 795 | aparam = None 796 | fparam = None 797 | if self.module.get_dim_aparam() > 0: 798 | aparam = rng.random([nf, natoms, self.module.get_dim_aparam()]) 799 | if self.module.get_dim_fparam() > 0: 800 | fparam = rng.random([nf, self.module.get_dim_fparam()]) 801 | 802 | ret = [] 803 | module = self.forward_wrapper(self.module) 804 | input_dict = {"atype": atype, "box": cell, "aparam": aparam, "fparam": fparam} 805 | if test_spin: 806 | input_dict["spin"] = spin 807 | # coord0 808 | input_dict["coord"] = coord0 809 | 
ret.append(module(**input_dict)) 810 | # coord1 811 | input_dict["coord"] = coord1 812 | ret.append(module(**input_dict)) 813 | # coord2 814 | input_dict["coord"] = coord2 815 | ret.append(module(**input_dict)) 816 | # coord3 817 | input_dict["coord"] = coord3 818 | ret.append(module(**input_dict)) 819 | 820 | for kk in ret[0]: 821 | if kk in self.output_def: 822 | if ret[0][kk] is None: 823 | for ii in range(len(ret) - 1): 824 | assert ret[ii + 1][kk] is None 825 | continue 826 | for ii in range(len(ret) - 1): 827 | np.testing.assert_allclose( 828 | ret[0][kk], 829 | ret[ii + 1][kk], 830 | err_msg=f"compare {kk} before and after transform", 831 | atol=aprec, 832 | rtol=rprec, 833 | ) 834 | else: 835 | msg = f"Unknown output key: {kk}" 836 | raise RuntimeError(msg) 837 | 838 | def test_autodiff(self) -> None: 839 | """Test autodiff.""" 840 | if getattr(self, "skip_test_autodiff", False): 841 | self.skipTest("Skip test autodiff.") 842 | test_spin = getattr(self, "test_spin", False) 843 | 844 | places = 4 845 | delta = 1e-5 846 | 847 | def finite_difference(f, x, delta=1e-6): 848 | in_shape = x.shape 849 | y0 = f(x) 850 | out_shape = y0.shape 851 | res = np.empty(out_shape + in_shape) 852 | for idx in np.ndindex(*in_shape): 853 | diff = np.zeros(in_shape) 854 | diff[idx] += delta 855 | y1p = f(x + diff) 856 | y1n = f(x - diff) 857 | res[(Ellipsis, *idx)] = (y1p - y1n) / (2 * delta) 858 | return res 859 | 860 | def stretch_box(old_coord, old_box, new_box): 861 | ocoord = old_coord.reshape(-1, 3) 862 | obox = old_box.reshape(3, 3) 863 | nbox = new_box.reshape(3, 3) 864 | ncoord = ocoord @ np.linalg.inv(obox) @ nbox 865 | return ncoord.reshape(old_coord.shape) 866 | 867 | rng = np.random.default_rng(GLOBAL_SEED) 868 | natoms = 5 869 | nf = 1 870 | cell = rng.random([3, 3]) 871 | cell = (cell + cell.T) + 5.0 * np.eye(3) 872 | coord = rng.random([natoms, 3]) 873 | coord = np.matmul(coord, cell) 874 | spin = 0.1 * rng.random([natoms, 3]) 875 | atype = np.array([0, 0, 0, 1, 
1]) 876 | 877 | # reshape for input 878 | coord = coord.reshape([nf, -1]) 879 | spin = spin.reshape([nf, -1]) 880 | atype = atype.reshape([nf, -1]) 881 | cell = cell.reshape([nf, 9]) 882 | 883 | aparam = None 884 | fparam = None 885 | if self.module.get_dim_aparam() > 0: 886 | aparam = rng.random([nf, natoms, self.module.get_dim_aparam()]) 887 | if self.module.get_dim_fparam() > 0: 888 | fparam = rng.random([nf, self.module.get_dim_fparam()]) 889 | 890 | module = self.forward_wrapper(self.module) 891 | 892 | # only test force and virial for energy model 893 | def ff_coord(_coord): 894 | input_dict = { 895 | "coord": _coord, 896 | "atype": atype, 897 | "box": cell, 898 | "aparam": aparam, 899 | "fparam": fparam, 900 | } 901 | if test_spin: 902 | input_dict["spin"] = spin 903 | return module(**input_dict)["energy"] 904 | 905 | def ff_spin(_spin): 906 | input_dict = { 907 | "coord": coord, 908 | "atype": atype, 909 | "box": cell, 910 | "aparam": aparam, 911 | "fparam": fparam, 912 | } 913 | if test_spin: 914 | input_dict["spin"] = _spin 915 | return module(**input_dict)["energy"] 916 | 917 | fdf = -finite_difference(ff_coord, coord, delta=delta).squeeze() 918 | input_dict = { 919 | "coord": coord, 920 | "atype": atype, 921 | "box": cell, 922 | "aparam": aparam, 923 | "fparam": fparam, 924 | } 925 | if test_spin: 926 | input_dict["spin"] = spin 927 | rff = module(**input_dict)["force"] 928 | np.testing.assert_almost_equal( 929 | fdf.reshape(-1, 3), 930 | rff.reshape(-1, 3), 931 | decimal=places, 932 | ) 933 | 934 | if test_spin: 935 | # magnetic force 936 | fdf = -finite_difference(ff_spin, spin, delta=delta).squeeze() 937 | rff = module(**input_dict)["force_mag"] 938 | np.testing.assert_almost_equal( 939 | fdf.reshape(-1, 3), 940 | rff.reshape(-1, 3), 941 | decimal=places, 942 | ) 943 | 944 | if not test_spin: 945 | 946 | def ff_cell(bb): 947 | input_dict = { 948 | "coord": stretch_box(coord, cell, bb), 949 | "atype": atype, 950 | "box": bb, 951 | "aparam": aparam, 
952 | "fparam": fparam, 953 | } 954 | return module(**input_dict)["energy"] 955 | 956 | fdv = ( 957 | -( 958 | finite_difference(ff_cell, cell, delta=delta) 959 | .reshape(-1, 3, 3) 960 | .transpose(0, 2, 1) 961 | @ cell.reshape(-1, 3, 3) 962 | ) 963 | .squeeze() 964 | .reshape(9) 965 | ) 966 | input_dict = { 967 | "coord": stretch_box(coord, cell, cell), 968 | "atype": atype, 969 | "box": cell, 970 | "aparam": aparam, 971 | "fparam": fparam, 972 | } 973 | rfv = module(**input_dict)["virial"] 974 | np.testing.assert_almost_equal( 975 | fdv.reshape(-1, 9), 976 | rfv.reshape(-1, 9), 977 | decimal=places, 978 | ) 979 | else: 980 | # not support virial by far 981 | pass 982 | 983 | def test_device_consistence(self) -> None: 984 | """Test forward consistency between devices.""" 985 | test_spin = getattr(self, "test_spin", False) 986 | nf = 1 987 | natoms = 5 988 | rng = np.random.default_rng(GLOBAL_SEED) 989 | coord = 4.0 * rng.random([natoms, 3]).reshape([nf, -1]) 990 | atype = np.array([0, 0, 0, 1, 1], dtype=int).reshape([nf, -1]) 991 | spin = 0.5 * rng.random([natoms, 3]).reshape([nf, -1]) 992 | cell = 6.0 * np.eye(3).reshape([nf, 9]) 993 | aparam = None 994 | fparam = None 995 | if self.module.get_dim_aparam() > 0: 996 | aparam = rng.random([nf, natoms, self.module.get_dim_aparam()]) 997 | if self.module.get_dim_fparam() > 0: 998 | fparam = rng.random([nf, self.module.get_dim_fparam()]) 999 | ret = [] 1000 | device_module = self.forward_wrapper(self.module) 1001 | ref_module = self.forward_wrapper_cpu_ref(deepcopy(self.module)) 1002 | 1003 | for module in [device_module, ref_module]: 1004 | input_dict = { 1005 | "coord": coord, 1006 | "atype": atype, 1007 | "box": cell, 1008 | "aparam": aparam, 1009 | "fparam": fparam, 1010 | } 1011 | if test_spin: 1012 | input_dict["spin"] = spin 1013 | ret.append(module(**input_dict)) 1014 | for kk in ret[0]: 1015 | subret = [rr[kk] for rr in ret if rr is not None] 1016 | if subret: 1017 | for ii, rr in enumerate(subret[1:]): 1018 
| if subret[0] is None: 1019 | assert rr is None 1020 | else: 1021 | np.testing.assert_allclose( 1022 | subret[0], 1023 | rr, 1024 | err_msg=f"compare {kk} between 0 and {ii}", 1025 | atol=1e-10, 1026 | ) 1027 | 1028 | 1029 | class EnerModelTest(ModelTestCase): 1030 | @classmethod 1031 | def setUpClass(cls) -> None: 1032 | cls.expected_rcut = 5.0 1033 | cls.expected_type_map = ["O", "H"] 1034 | cls.expected_dim_fparam = 0 1035 | cls.expected_dim_aparam = 0 1036 | cls.expected_sel_type = [0, 1] 1037 | cls.expected_aparam_nall = False 1038 | cls.expected_model_output_type = ["energy"] 1039 | cls.model_output_equivariant = [] 1040 | cls.expected_sel = [46, 92] 1041 | cls.expected_sel_mix = sum(cls.expected_sel) # type: ignore[attr-defined] 1042 | cls.expected_has_message_passing = False 1043 | cls.aprec_dict = {} 1044 | cls.rprec_dict = {} 1045 | cls.epsilon_dict = {} 1046 | 1047 | 1048 | class TestMaceModel(unittest.TestCase, EnerModelTest, PTTestCase): # type: ignore[misc] 1049 | """Test MACE model.""" 1050 | 1051 | @property 1052 | def modules_to_test(self) -> list[torch.nn.Module]: # type: ignore[override] 1053 | """Modules to test.""" 1054 | skip_test_jit = getattr(self, "skip_test_jit", False) 1055 | modules = PTTestCase.modules_to_test.fget(self) # type: ignore[attr-defined] 1056 | if not skip_test_jit: 1057 | # for Model, we can test script module API 1058 | modules += [ 1059 | self._script_module 1060 | if hasattr(self, "_script_module") 1061 | else self.script_module, 1062 | ] 1063 | return modules 1064 | 1065 | _script_module: torch.jit.ScriptModule 1066 | 1067 | @classmethod 1068 | def setUpClass(cls) -> None: 1069 | """Set up class.""" 1070 | EnerModelTest.setUpClass() 1071 | 1072 | torch.manual_seed(GLOBAL_SEED + 1) 1073 | cls.module = MaceModel( 1074 | type_map=cls.expected_type_map, 1075 | sel=138, 1076 | precision="float64", 1077 | ) 1078 | with torch.jit.optimized_execution(should_optimize=False): 1079 | cls._script_module = 
torch.jit.script(cls.module) 1080 | cls.output_def = cls.module.translated_output_def() 1081 | cls.expected_has_message_passing = False 1082 | cls.expected_sel_type = [] 1083 | cls.expected_dim_fparam = 0 1084 | cls.expected_dim_aparam = 0 1085 | cls.expected_nmpnn = 2 1086 | 1087 | 1088 | class TestNequipModel(unittest.TestCase, EnerModelTest, PTTestCase): # type: ignore[misc] 1089 | """Test Nequip model.""" 1090 | 1091 | @property 1092 | def modules_to_test(self) -> list[torch.nn.Module]: # type: ignore[override] 1093 | """Modules to test.""" 1094 | skip_test_jit = getattr(self, "skip_test_jit", False) 1095 | modules = PTTestCase.modules_to_test.fget(self) # type: ignore[attr-defined] 1096 | if not skip_test_jit: 1097 | # for Model, we can test script module API 1098 | modules += [ 1099 | self._script_module 1100 | if hasattr(self, "_script_module") 1101 | else self.script_module, 1102 | ] 1103 | return modules 1104 | 1105 | _script_module: torch.jit.ScriptModule 1106 | 1107 | @classmethod 1108 | def setUpClass(cls) -> None: 1109 | """Set up class.""" 1110 | EnerModelTest.setUpClass() 1111 | 1112 | torch.manual_seed(GLOBAL_SEED + 1) 1113 | cls.module = NequipModel( 1114 | type_map=cls.expected_type_map, 1115 | sel=138, 1116 | r_max=cls.expected_rcut, 1117 | num_layers=2, 1118 | precision="float64", 1119 | ) 1120 | with torch.jit.optimized_execution(should_optimize=False): 1121 | cls._script_module = torch.jit.script(cls.module) 1122 | cls.output_def = cls.module.translated_output_def() 1123 | cls.expected_has_message_passing = False 1124 | cls.expected_sel_type = [] 1125 | cls.expected_dim_fparam = 0 1126 | cls.expected_dim_aparam = 0 1127 | cls.expected_nmpnn = 2 1128 | -------------------------------------------------------------------------------- /tests/test_op.py: -------------------------------------------------------------------------------- 1 | """Test custom operations.""" 2 | 3 | import torch 4 | 5 | import deepmd_gnn.op # noqa: F401 6 | 7 | 8 | def 
test_one_frame() -> None: 9 | """Test one frame.""" 10 | nlist_ff = torch.tensor( 11 | [ 12 | [1, 2, -1, -1], 13 | [2, 0, -1, -1], 14 | [0, 1, -1, -1], 15 | ], 16 | dtype=torch.int64, 17 | device="cpu", 18 | ) 19 | extended_atype_ff = torch.tensor( 20 | [0, 1, 2], 21 | dtype=torch.int64, 22 | device="cpu", 23 | ) 24 | mm_types = [1, 2] 25 | expected_edge_index = torch.tensor( 26 | [ 27 | [1, 0], 28 | [2, 0], 29 | [0, 1], 30 | [0, 2], 31 | ], 32 | dtype=torch.int64, 33 | device="cpu", 34 | ) 35 | 36 | edge_index = torch.ops.deepmd_gnn.edge_index( 37 | nlist_ff, 38 | extended_atype_ff, 39 | torch.tensor(mm_types, dtype=torch.int64, device="cpu"), 40 | ) 41 | 42 | assert torch.equal(edge_index, expected_edge_index) 43 | 44 | 45 | def test_two_frame() -> None: 46 | """Test one frame.""" 47 | nlist = torch.tensor( 48 | [ 49 | [ 50 | [1, 2, -1, -1], 51 | [2, 0, -1, -1], 52 | [0, 1, -1, -1], 53 | ], 54 | [ 55 | [1, 2, -1, -1], 56 | [2, 0, -1, -1], 57 | [0, 1, -1, -1], 58 | ], 59 | ], 60 | dtype=torch.int64, 61 | device="cpu", 62 | ) 63 | extended_atype = torch.tensor( 64 | [ 65 | [0, 1, 2], 66 | [0, 1, 2], 67 | ], 68 | dtype=torch.int64, 69 | device="cpu", 70 | ) 71 | mm_types = [1, 2] 72 | expected_edge_index = torch.tensor( 73 | [ 74 | [1, 0], 75 | [2, 0], 76 | [0, 1], 77 | [0, 2], 78 | [4, 3], 79 | [5, 3], 80 | [3, 4], 81 | [3, 5], 82 | ], 83 | dtype=torch.int64, 84 | device="cpu", 85 | ) 86 | 87 | edge_index = torch.ops.deepmd_gnn.edge_index( 88 | nlist, 89 | extended_atype, 90 | torch.tensor(mm_types, dtype=torch.int64, device="cpu"), 91 | ) 92 | 93 | assert torch.equal(edge_index, expected_edge_index) 94 | -------------------------------------------------------------------------------- /tests/test_training.py: -------------------------------------------------------------------------------- 1 | """Test training.""" 2 | 3 | import shutil 4 | import subprocess 5 | import sys 6 | import tempfile 7 | from pathlib import Path 8 | 9 | import pytest 10 | 11 | 12 | 
@pytest.mark.parametrize( 13 | "input_fn", 14 | [ 15 | "mace.json", 16 | "nequip.json", 17 | ], 18 | ) 19 | def test_e2e_training(input_fn) -> None: 20 | """Test training the model.""" 21 | model_fn = "model.pth" 22 | # create temp directory and copy example files 23 | with tempfile.TemporaryDirectory() as _tmpdir: 24 | tmpdir = Path(_tmpdir) 25 | this_dir = Path(__file__).parent 26 | data_path = this_dir / "data" 27 | input_path = this_dir / input_fn 28 | # copy data to tmpdir 29 | shutil.copytree(data_path, tmpdir / "data") 30 | # copy input.json to tmpdir 31 | shutil.copy(input_path, tmpdir / input_fn) 32 | 33 | subprocess.check_call( 34 | [ 35 | sys.executable, 36 | "-m", 37 | "deepmd", 38 | "--pt", 39 | "train", 40 | input_fn, 41 | ], 42 | cwd=tmpdir, 43 | ) 44 | subprocess.check_call( 45 | [ 46 | sys.executable, 47 | "-m", 48 | "deepmd", 49 | "--pt", 50 | "freeze", 51 | "-o", 52 | model_fn, 53 | ], 54 | cwd=tmpdir, 55 | ) 56 | assert (tmpdir / model_fn).exists() 57 | -------------------------------------------------------------------------------- /tests/test_version.py: -------------------------------------------------------------------------------- 1 | """Test version.""" 2 | 3 | from __future__ import annotations 4 | 5 | from importlib.metadata import version 6 | 7 | from deepmd_gnn import __version__ 8 | 9 | 10 | def test_version() -> None: 11 | """Test version.""" 12 | assert version("deepmd-gnn") == __version__ 13 | --------------------------------------------------------------------------------