├── .clang-format
├── .github
├── dependabot.yml
└── workflows
│ ├── mirror-gitlab.yaml
│ ├── release.yaml
│ └── test.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── .readthedocs.yml
├── CMakeLists.txt
├── LICENSE
├── README.md
├── build_backend
├── __init__.py
└── dp_backend.py
├── deepmd_gnn
├── __init__.py
├── __main__.py
├── argcheck.py
├── env.py
├── mace.py
├── nequip.py
├── op.py
└── py.typed
├── docs
├── Makefile
├── conf.py
├── index.rst
└── parameters.rst
├── examples
├── .gitignore
├── dprc
│ ├── data
│ │ ├── nopbc
│ │ ├── set.000
│ │ │ ├── box.npy
│ │ │ ├── coord.npy
│ │ │ ├── energy.npy
│ │ │ └── force.npy
│ │ ├── type.raw
│ │ └── type_map.raw
│ ├── mace
│ │ └── input.json
│ └── nequip
│ │ └── input.json
└── water
│ ├── data
│ ├── data_0
│ │ ├── set.000
│ │ │ ├── box.npy
│ │ │ ├── coord.npy
│ │ │ ├── energy.npy
│ │ │ └── force.npy
│ │ ├── type.raw
│ │ └── type_map.raw
│ ├── data_1
│ │ ├── set.000
│ │ │ ├── box.npy
│ │ │ ├── coord.npy
│ │ │ ├── energy.npy
│ │ │ └── force.npy
│ │ ├── set.001
│ │ │ ├── box.npy
│ │ │ ├── coord.npy
│ │ │ ├── energy.npy
│ │ │ └── force.npy
│ │ ├── type.raw
│ │ └── type_map.raw
│ ├── data_2
│ │ ├── set.000
│ │ │ ├── box.npy
│ │ │ ├── coord.npy
│ │ │ ├── energy.npy
│ │ │ └── force.npy
│ │ ├── type.raw
│ │ └── type_map.raw
│ └── data_3
│ │ ├── set.000
│ │ ├── box.npy
│ │ ├── coord.npy
│ │ ├── energy.npy
│ │ └── force.npy
│ │ ├── type.raw
│ │ └── type_map.raw
│ ├── mace
│ └── input.json
│ └── nequip
│ └── input.json
├── noxfile.py
├── op
├── CMakeLists.txt
└── edge_index.cc
├── pyproject.toml
├── renovate.json
└── tests
├── __init__.py
├── data
├── set.000
│ ├── box.npy
│ ├── coord.npy
│ ├── energy.npy
│ └── force.npy
├── type.raw
└── type_map.raw
├── mace.json
├── nequip.json
├── test_examples.py
├── test_model.py
├── test_op.py
├── test_training.py
└── test_version.py
/.clang-format:
--------------------------------------------------------------------------------
1 | ---
2 | BasedOnStyle: Google
3 | BinPackParameters: false
4 | InsertBraces: true
5 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 | - package-ecosystem: "github-actions"
4 | directory: "/"
5 | schedule:
6 | interval: "weekly"
7 |
--------------------------------------------------------------------------------
/.github/workflows/mirror-gitlab.yaml:
--------------------------------------------------------------------------------
1 | name: Mirror to GitLab Repo
2 |
3 | on:
4 | push:
5 | branches:
6 | - master
7 | tags:
8 | - v*
9 |
10 | # Ensures that only one mirror task will run at a time.
11 | concurrency:
12 | group: git-mirror
13 |
14 | jobs:
15 | git-mirror:
16 | if: github.repository_owner == 'njzjz'
17 | runs-on: ubuntu-latest
18 | steps:
19 | - uses: wearerequired/git-mirror-action@v1
20 | env:
21 | SSH_PRIVATE_KEY: ${{ secrets.SYNC_GITLAB_PRIVATE_KEY }}
22 | with:
23 | source-repo: "https://github.com/njzjz/deepmd-gnn"
24 | destination-repo: "git@gitlab.com:RutgersLBSR/deepmd-gnn"
25 |
--------------------------------------------------------------------------------
/.github/workflows/release.yaml:
--------------------------------------------------------------------------------
1 | on:
2 | push:
3 | branches:
4 | - master
5 | tags:
6 | - "v*"
7 | pull_request:
8 | name: Build and release to pypi
9 | permissions:
10 | contents: read
11 | jobs:
12 | release-build:
13 | name: Build and upload distributions
14 | runs-on: ubuntu-latest
15 | steps:
16 | - uses: actions/checkout@v4
17 | - name: Set up uv
18 | uses: astral-sh/setup-uv@v5
19 | with:
20 | enable-cache: true
21 | cache-dependency-glob: |
22 | **/requirements*.txt
23 | **/pyproject.toml
24 | - name: Build dist
25 | run: uv tool run --with build[uv] --from build python -m build --installer uv --sdist
26 | - name: Upload release distributions
27 | uses: actions/upload-artifact@v4
28 | with:
29 | name: release-dists
30 | path: dist/
31 | release-wheel:
32 | name: Build wheels for cp${{ matrix.python }}-${{ matrix.platform_id }}
33 | runs-on: ${{ matrix.os }}
34 | strategy:
35 | fail-fast: false
36 | matrix:
37 | include:
38 | # linux-64
39 | - os: ubuntu-latest
40 | python: 312
41 | platform_id: manylinux_x86_64
42 | # linux-aarch64
43 | - os: ubuntu-24.04-arm
44 | python: 312
45 | platform_id: manylinux_aarch64
46 | # macos-x86-64
47 | - os: macos-13
48 | python: 312
49 | platform_id: macosx_x86_64
50 | # macos-arm64
51 | - os: macos-14
52 | python: 312
53 | platform_id: macosx_arm64
54 | # win-64
55 | - os: windows-2019
56 | python: 312
57 | platform_id: win_amd64
58 | steps:
59 | - uses: actions/checkout@v4
60 | - name: Set up uv
61 | uses: astral-sh/setup-uv@v5
62 | with:
63 | enable-cache: true
64 | cache-dependency-glob: |
65 | **/requirements*.txt
66 | **/pyproject.toml
67 | if: runner.os != 'Linux'
68 | - name: Build wheels
69 | uses: pypa/cibuildwheel@v2.23
70 | env:
71 | CIBW_ARCHS: all
72 | CIBW_BUILD: cp${{ matrix.python }}-${{ matrix.platform_id }}
73 | CIBW_BUILD_FRONTEND: "build[uv]"
74 | - uses: actions/upload-artifact@v4
75 | with:
76 | name: release-cibw-cp${{ matrix.python }}-${{ matrix.platform_id }}-${{ strategy.job-index }}
77 | path: ./wheelhouse/*.whl
78 | pypi-publish:
79 | name: Release to pypi
80 | runs-on: ubuntu-latest
81 | environment:
82 | name: pypi_publish
 83 |       url: https://pypi.org/p/deepmd-gnn
84 | if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v')
85 | needs:
86 | - release-build
87 | - release-wheel
88 | permissions:
89 | id-token: write
90 | steps:
91 | - name: Retrieve release distributions
92 | uses: actions/download-artifact@v4
93 | with:
94 | merge-multiple: true
95 | pattern: release-*
96 | path: dist/
97 | - name: Publish release distributions to PyPI
98 | uses: pypa/gh-action-pypi-publish@release/v1
99 |
--------------------------------------------------------------------------------
/.github/workflows/test.yaml:
--------------------------------------------------------------------------------
1 | name: Test Python package
2 |
3 | on:
4 | push:
5 | branches:
6 | - master
7 | pull_request:
8 |
9 | env:
10 | UV_SYSTEM_PYTHON: 1
11 | jobs:
12 | build:
13 | runs-on: ubuntu-latest
14 | permissions:
15 | id-token: write
16 | contents: read
17 | steps:
18 | - uses: actions/checkout@v4
19 | - name: Set up uv
20 | uses: astral-sh/setup-uv@v5
21 | with:
22 | enable-cache: true
23 | cache-dependency-glob: |
24 | **/requirements*.txt
25 | **/pyproject.toml
26 | - name: Set up Python
27 | uses: actions/setup-python@v5
28 | with:
29 | python-version: 3.12
30 | - name: Install dependencies
31 | run: uv pip install nox[uv]
32 | - name: Test with pytest
33 | run: nox -db uv
34 | - name: Upload coverage reports to Codecov
35 | uses: codecov/codecov-action@v5
36 | with:
37 | use_oidc: ${{ !(github.event_name == 'pull_request' && github.event.pull_request.head.repo.fork) }}
38 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 | .ruff_cache/
164 | node_modules/
165 | deepmd_gnn/_version.py
166 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # See https://pre-commit.com for more information
2 | # See https://pre-commit.com/hooks.html for more hooks
3 | repos:
4 | - repo: https://github.com/pre-commit/pre-commit-hooks
5 | rev: v5.0.0
6 | hooks:
7 | - id: trailing-whitespace
8 | - id: end-of-file-fixer
9 | - id: check-yaml
10 | - id: check-json
11 | - id: check-added-large-files
12 | - id: check-merge-conflict
13 | - id: check-symlinks
14 | - id: check-toml
15 | - id: mixed-line-ending
16 | # Python
17 | - repo: https://github.com/astral-sh/ruff-pre-commit
18 | rev: v0.11.7
19 | hooks:
20 | - id: ruff
21 | args: ["--fix"]
22 | - id: ruff-format
23 | - repo: https://github.com/pre-commit/mirrors-mypy
24 | rev: "v1.15.0"
25 | hooks:
26 | - id: mypy
27 | additional_dependencies: []
28 | - repo: https://github.com/pre-commit/mirrors-prettier
29 | rev: v4.0.0-alpha.8
30 | hooks:
31 | - id: prettier
32 | types_or: [javascript, css, html, markdown, yaml]
33 | # C++
34 | - repo: https://github.com/pre-commit/mirrors-clang-format
35 | rev: v20.1.0
36 | hooks:
37 | - id: clang-format
38 | # CMake
39 | - repo: https://github.com/cheshirekow/cmake-format-precommit
40 | rev: v0.6.13
41 | hooks:
42 | - id: cmake-format
43 |
--------------------------------------------------------------------------------
/.readthedocs.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | build:
3 | os: ubuntu-24.04
4 | tools:
5 | python: "3.12"
6 | jobs:
7 | install:
8 | - asdf plugin add uv
9 | - asdf install uv latest
10 | - asdf global uv latest
11 | - uv venv $READTHEDOCS_VIRTUALENV_PATH
12 | - VIRTUAL_ENV=$READTHEDOCS_VIRTUALENV_PATH uv pip install deepmd-kit[torch]>=3.0.0b2 --extra-index-url https://download.pytorch.org/whl/cpu
13 | - VIRTUAL_ENV=$READTHEDOCS_VIRTUALENV_PATH CMAKE_PREFIX_PATH=$(python -c "import torch;print(torch.utils.cmake_prefix_path)") uv pip install -e .[docs]
14 | - $READTHEDOCS_VIRTUALENV_PATH/bin/python -m sphinx -T -b html -d docs/_build/doctrees -D language=en docs $READTHEDOCS_OUTPUT/html
15 |
--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.15)
2 | project(deepmd-gnn CXX)
3 |
4 | set(CMAKE_CXX_STANDARD 14)
5 | macro(set_if_higher VARIABLE VALUE)
6 | # ${VARIABLE} is a variable name, not a string
7 | if(${VARIABLE} LESS "${VALUE}")
8 | set(${VARIABLE} ${VALUE})
9 | endif()
10 | endmacro()
11 |
12 | # build cpp or python interfaces
13 | option(BUILD_CPP_IF "Build C++ interfaces" ON)
14 | option(BUILD_PY_IF "Build Python interfaces" OFF)
15 | option(USE_PT_PYTHON_LIBS "Use PyTorch Python libraries" OFF)
16 |
17 | if((NOT BUILD_PY_IF) AND (NOT BUILD_CPP_IF))
18 | # nothing to do
19 | message(FATAL_ERROR "Nothing to build.")
20 | endif()
21 |
22 | if(BUILD_CPP_IF
23 | AND USE_PT_PYTHON_LIBS
24 | AND NOT CMAKE_CROSSCOMPILING
25 | AND NOT SKBUILD
26 | OR "$ENV{CIBUILDWHEEL}" STREQUAL "1")
27 | find_package(
28 | Python
29 | COMPONENTS Interpreter
30 | REQUIRED)
31 | execute_process(
32 | COMMAND ${Python_EXECUTABLE} -c
33 | "import torch;print(torch.utils.cmake_prefix_path)"
34 | WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}
35 | OUTPUT_VARIABLE PYTORCH_CMAKE_PREFIX_PATH
36 | RESULT_VARIABLE PYTORCH_CMAKE_PREFIX_PATH_RESULT_VAR
37 | ERROR_VARIABLE PYTORCH_CMAKE_PREFIX_PATH_ERROR_VAR
38 | OUTPUT_STRIP_TRAILING_WHITESPACE)
39 | if(NOT ${PYTORCH_CMAKE_PREFIX_PATH_RESULT_VAR} EQUAL 0)
40 | message(
41 | FATAL_ERROR
 42 |           "Cannot determine PyTorch CMake prefix path, error code: ${PYTORCH_CMAKE_PREFIX_PATH_RESULT_VAR}, error message: ${PYTORCH_CMAKE_PREFIX_PATH_ERROR_VAR}"
43 | )
44 | endif()
45 | list(APPEND CMAKE_PREFIX_PATH ${PYTORCH_CMAKE_PREFIX_PATH})
46 | endif()
47 | find_package(Torch REQUIRED)
48 | if(NOT Torch_VERSION VERSION_LESS "2.1.0")
49 | set_if_higher(CMAKE_CXX_STANDARD 17)
50 | elseif(NOT Torch_VERSION VERSION_LESS "1.5.0")
51 | set_if_higher(CMAKE_CXX_STANDARD 14)
52 | endif()
53 | string(REGEX MATCH "_GLIBCXX_USE_CXX11_ABI=([0-9]+)" CXXABI_PT_MATCH
54 | "${TORCH_CXX_FLAGS}")
55 | if(CXXABI_PT_MATCH)
56 | set(OP_CXX_ABI_PT ${CMAKE_MATCH_1})
57 | message(STATUS "PyTorch CXX11 ABI: ${CMAKE_MATCH_1}")
58 | else()
59 | # Maybe in macos/windows
60 | set(OP_CXX_ABI_PT 0)
61 | endif()
62 |
63 | # define build type
64 | if((NOT DEFINED CMAKE_BUILD_TYPE) OR CMAKE_BUILD_TYPE STREQUAL "")
65 | set(CMAKE_BUILD_TYPE release)
66 | endif()
67 |
68 | add_subdirectory(op)
69 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU LESSER GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc.
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 |
9 | This version of the GNU Lesser General Public License incorporates
10 | the terms and conditions of version 3 of the GNU General Public
11 | License, supplemented by the additional permissions listed below.
12 |
13 | 0. Additional Definitions.
14 |
15 | As used herein, "this License" refers to version 3 of the GNU Lesser
16 | General Public License, and the "GNU GPL" refers to version 3 of the GNU
17 | General Public License.
18 |
19 | "The Library" refers to a covered work governed by this License,
20 | other than an Application or a Combined Work as defined below.
21 |
22 | An "Application" is any work that makes use of an interface provided
23 | by the Library, but which is not otherwise based on the Library.
24 | Defining a subclass of a class defined by the Library is deemed a mode
25 | of using an interface provided by the Library.
26 |
27 | A "Combined Work" is a work produced by combining or linking an
28 | Application with the Library. The particular version of the Library
29 | with which the Combined Work was made is also called the "Linked
30 | Version".
31 |
32 | The "Minimal Corresponding Source" for a Combined Work means the
33 | Corresponding Source for the Combined Work, excluding any source code
34 | for portions of the Combined Work that, considered in isolation, are
35 | based on the Application, and not on the Linked Version.
36 |
37 | The "Corresponding Application Code" for a Combined Work means the
38 | object code and/or source code for the Application, including any data
39 | and utility programs needed for reproducing the Combined Work from the
40 | Application, but excluding the System Libraries of the Combined Work.
41 |
42 | 1. Exception to Section 3 of the GNU GPL.
43 |
44 | You may convey a covered work under sections 3 and 4 of this License
45 | without being bound by section 3 of the GNU GPL.
46 |
47 | 2. Conveying Modified Versions.
48 |
49 | If you modify a copy of the Library, and, in your modifications, a
50 | facility refers to a function or data to be supplied by an Application
51 | that uses the facility (other than as an argument passed when the
52 | facility is invoked), then you may convey a copy of the modified
53 | version:
54 |
55 | a) under this License, provided that you make a good faith effort to
56 | ensure that, in the event an Application does not supply the
57 | function or data, the facility still operates, and performs
58 | whatever part of its purpose remains meaningful, or
59 |
60 | b) under the GNU GPL, with none of the additional permissions of
61 | this License applicable to that copy.
62 |
63 | 3. Object Code Incorporating Material from Library Header Files.
64 |
65 | The object code form of an Application may incorporate material from
66 | a header file that is part of the Library. You may convey such object
67 | code under terms of your choice, provided that, if the incorporated
68 | material is not limited to numerical parameters, data structure
69 | layouts and accessors, or small macros, inline functions and templates
70 | (ten or fewer lines in length), you do both of the following:
71 |
72 | a) Give prominent notice with each copy of the object code that the
73 | Library is used in it and that the Library and its use are
74 | covered by this License.
75 |
76 | b) Accompany the object code with a copy of the GNU GPL and this license
77 | document.
78 |
79 | 4. Combined Works.
80 |
81 | You may convey a Combined Work under terms of your choice that,
82 | taken together, effectively do not restrict modification of the
83 | portions of the Library contained in the Combined Work and reverse
84 | engineering for debugging such modifications, if you also do each of
85 | the following:
86 |
87 | a) Give prominent notice with each copy of the Combined Work that
88 | the Library is used in it and that the Library and its use are
89 | covered by this License.
90 |
91 | b) Accompany the Combined Work with a copy of the GNU GPL and this license
92 | document.
93 |
94 | c) For a Combined Work that displays copyright notices during
95 | execution, include the copyright notice for the Library among
96 | these notices, as well as a reference directing the user to the
97 | copies of the GNU GPL and this license document.
98 |
99 | d) Do one of the following:
100 |
101 | 0) Convey the Minimal Corresponding Source under the terms of this
102 | License, and the Corresponding Application Code in a form
103 | suitable for, and under terms that permit, the user to
104 | recombine or relink the Application with a modified version of
105 | the Linked Version to produce a modified Combined Work, in the
106 | manner specified by section 6 of the GNU GPL for conveying
107 | Corresponding Source.
108 |
109 | 1) Use a suitable shared library mechanism for linking with the
110 | Library. A suitable mechanism is one that (a) uses at run time
111 | a copy of the Library already present on the user's computer
112 | system, and (b) will operate properly with a modified version
113 | of the Library that is interface-compatible with the Linked
114 | Version.
115 |
116 | e) Provide Installation Information, but only if you would otherwise
117 | be required to provide such information under section 6 of the
118 | GNU GPL, and only to the extent that such information is
119 | necessary to install and execute a modified version of the
120 | Combined Work produced by recombining or relinking the
121 | Application with a modified version of the Linked Version. (If
122 | you use option 4d0, the Installation Information must accompany
123 | the Minimal Corresponding Source and Corresponding Application
124 | Code. If you use option 4d1, you must provide the Installation
125 | Information in the manner specified by section 6 of the GNU GPL
126 | for conveying Corresponding Source.)
127 |
128 | 5. Combined Libraries.
129 |
130 | You may place library facilities that are a work based on the
131 | Library side by side in a single library together with other library
132 | facilities that are not Applications and are not covered by this
133 | License, and convey such a combined library under terms of your
134 | choice, if you do both of the following:
135 |
136 | a) Accompany the combined library with a copy of the same work based
137 | on the Library, uncombined with any other library facilities,
138 | conveyed under the terms of this License.
139 |
140 | b) Give prominent notice with the combined library that part of it
141 | is a work based on the Library, and explaining where to find the
142 | accompanying uncombined form of the same work.
143 |
144 | 6. Revised Versions of the GNU Lesser General Public License.
145 |
146 | The Free Software Foundation may publish revised and/or new versions
147 | of the GNU Lesser General Public License from time to time. Such new
148 | versions will be similar in spirit to the present version, but may
149 | differ in detail to address new problems or concerns.
150 |
151 | Each version is given a distinguishing version number. If the
152 | Library as you received it specifies that a certain numbered version
153 | of the GNU Lesser General Public License "or any later version"
154 | applies to it, you have the option of following the terms and
155 | conditions either of that published version or of any later version
156 | published by the Free Software Foundation. If the Library as you
157 | received it does not specify a version number of the GNU Lesser
158 | General Public License, you may choose any version of the GNU Lesser
159 | General Public License ever published by the Free Software Foundation.
160 |
161 | If the Library as you received it specifies that a proxy can decide
162 | whether future versions of the GNU Lesser General Public License shall
163 | apply, that proxy's public statement of acceptance of any version is
164 | permanent authorization for you to choose that version for the
165 | Library.
166 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DeePMD-kit plugin for various graph neural network models
2 |
3 | [](https://doi.org/10.1021/acs.jcim.4c02441)
4 | [](https://doi.org/10.1021/acs.jcim.4c02441)
5 | [](https://anaconda.org/conda-forge/deepmd-gnn)
6 | [](https://pypi.org/p/deepmd-gnn)
7 |
8 | `deepmd-gnn` is a [DeePMD-kit](https://github.com/deepmodeling/deepmd-kit) plugin for various graph neural network (GNN) models, which connects DeePMD-kit and atomistic GNN packages by enabling GNN models in DeePMD-kit.
9 |
10 | Supported packages and models include:
11 |
12 | - [MACE](https://github.com/ACEsuit/mace) (PyTorch version)
13 | - [NequIP](https://github.com/mir-group/nequip) (PyTorch version)
14 |
 15 | After [installing the plugin](#installation), you can train the GNN models using DeePMD-kit, run active learning cycles for the GNN models using [DP-GEN](https://github.com/deepmodeling/dpgen), and perform simulations with the MACE model using molecular dynamics packages supported by DeePMD-kit, such as [LAMMPS](https://github.com/lammps/lammps) and [AMBER](https://ambermd.org/).
16 | You can follow [DeePMD-kit documentation](https://docs.deepmodeling.com/projects/deepmd/en/latest/) to train the GNN models using its PyTorch backend, after using the specific [model parameters](#parameters).
17 |
18 | ## Credits
19 |
20 | If you use this software, please cite the following paper:
21 |
22 | - Jinzhe Zeng, Timothy J. Giese, Duo Zhang, Han Wang, Darrin M. York, DeePMD-GNN: A DeePMD-kit Plugin for External Graph Neural Network Potentials, _J. Chem. Inf. Model._, 2025, 65, 7, 3154-3160, DOI: [10.1021/acs.jcim.4c02441](https://doi.org/10.1021/acs.jcim.4c02441). [](https://badge.dimensions.ai/details/doi/10.1021/acs.jcim.4c02441)
23 |
24 | ## Installation
25 |
26 | ### Install via conda
27 |
28 | If you are in a [conda environment](https://docs.deepmodeling.com/faq/conda.html) where DeePMD-kit is already installed from the conda-forge channel,
29 | you can use `conda` to install the DeePMD-GNN plugin:
30 |
31 | ```sh
32 | conda install deepmd-gnn -c conda-forge
33 | ```
34 |
35 | ### Build from source
36 |
37 | First, clone this repository:
38 |
39 | ```sh
40 | git clone https://gitlab.com/RutgersLBSR/deepmd-gnn
41 | cd deepmd-gnn
42 | ```
43 |
44 | #### Python interface plugin
45 |
46 | Python 3.9 or above is required. A C++ compiler that supports C++ 14 (for PyTorch 2.0) or C++ 17 (for PyTorch 2.1 or above) is required.
47 |
48 | Assume you have installed [DeePMD-kit](https://github.com/deepmodeling/deepmd-kit) (v3.0.0b2 or above) and [PyTorch](https://github.com/pytorch/pytorch) in an environment, then execute
49 |
50 | ```sh
51 | # expose PyTorch CMake modules
52 | export CMAKE_PREFIX_PATH=$(python -c "import torch;print(torch.utils.cmake_prefix_path)")
53 |
54 | pip install .
55 | ```
56 |
57 | #### C++ interface plugin
58 |
59 | DeePMD-kit version should be v3.0.0b4 or later.
60 |
61 | Follow [DeePMD-kit documentation](https://docs.deepmodeling.com/projects/deepmd/en/latest/install/install-from-source.html#install-the-c-interface) to install DeePMD-kit C++ interface with PyTorch backend support and other related MD packages.
62 | After that, you can build the plugin
63 |
64 | ```sh
65 | # Assume libtorch has been contained in CMAKE_PREFIX_PATH
66 | mkdir -p build
67 | cd build
68 | cmake .. -D CMAKE_INSTALL_PREFIX=/prefix/to/install
69 | cmake --build . -j8
70 | cmake --install .
71 | ```
72 |
73 | `libdeepmd_gnn.so` will be installed into the directory you assign.
74 | When using any DeePMD-kit C++ interface, set the following environment variable in advance:
75 |
76 | ```sh
77 | export DP_PLUGIN_PATH=/prefix/to/install/lib/libdeepmd_gnn.so
78 | ```
79 |
80 | ## Usage
81 |
82 | Follow [Parameters section](#parameters) to prepare a DeePMD-kit input file.
83 |
84 | ```sh
85 | dp --pt train input.json
86 | dp --pt freeze
87 | ```
88 |
89 | A frozen model file named `frozen_model.pth` will be generated. You can use it in the MD packages or other interfaces.
90 | For details, follow [DeePMD-kit documentation](https://docs.deepmodeling.com/projects/deepmd/en/latest/).
91 |
 92 | ### Running LAMMPS + MACE with periodic boundary conditions
93 |
94 | GNN models use message passing neural networks,
95 | so the neighbor list built with traditional cutoff radius will not work,
 96 | since the ghost atoms also need to build neighbor lists.
97 | By default, the model requests the neighbor list with a cutoff radius of $r_c \times N_{L}$,
98 | where $r_c$ is set by `r_max` and $N_L$ is set by `num_interactions` (MACE) / `num_layers` (NequIP),
99 | and rebuilds the neighbor list for ghost atoms.
100 | However, this approach is very inefficient.
101 |
102 | The alternative approach for the MACE model (note: NequIP doesn't support such approach) is to use the mapping passed from LAMMPS, which does not support MPI.
103 | One needs to set `DP_GNN_USE_MAPPING` when freezing the models,
104 |
105 | ```sh
106 | DP_GNN_USE_MAPPING=1 dp --pt freeze
107 | ```
108 |
109 | and request the mapping when using LAMMPS (also requires DeePMD-kit v3.0.0rc0 or above).
110 | By using the mapping, the ghost atoms will be mapped to the real atoms,
111 | so the regular neighbor list with a cutoff radius of $r_c$ can be used.
112 |
113 | ```lammps
114 | atom_modify map array
115 | ```
116 |
117 | In the future, we will explore utilizing the MPI to communicate the neighbor list,
118 | while this approach requires a deep hack for external packages.
119 |
120 | ## Parameters
121 |
122 | ### MACE
123 |
124 | To use the MACE model, set `"type": "mace"` in the `model` section of the training script.
 125 | Below are the default values for the MACE model, most of which follow the default values in the MACE package:
126 |
127 | ```json
128 | "model": {
129 | "type": "mace",
130 | "type_map": [
131 | "O",
132 | "H"
133 | ],
134 | "r_max": 5.0,
135 | "sel": "auto",
136 | "num_radial_basis": 8,
137 | "num_cutoff_basis": 5,
138 | "max_ell": 3,
139 | "interaction": "RealAgnosticResidualInteractionBlock",
140 | "num_interactions": 2,
141 | "hidden_irreps": "128x0e + 128x1o",
142 | "pair_repulsion": false,
143 | "distance_transform": "None",
144 | "correlation": 3,
145 | "gate": "silu",
146 | "MLP_irreps": "16x0e",
147 | "radial_type": "bessel",
148 | "radial_MLP": [64, 64, 64],
149 | "std": 1.0,
150 | "precision": "float32"
151 | }
152 | ```
153 |
154 | ### NequIP
155 |
156 | ```json
157 | "model": {
158 | "type": "nequip",
159 | "type_map": [
160 | "O",
161 | "H"
162 | ],
163 | "r_max": 5.0,
164 | "sel": "auto",
165 | "num_layers": 4,
166 | "l_max": 2,
167 | "num_features": 32,
168 | "nonlinearity_type": "gate",
169 | "parity": true,
170 | "num_basis": 8,
171 | "BesselBasis_trainable": true,
172 | "PolynomialCutoff_p": 6,
173 | "invariant_layers": 2,
174 | "invariant_neurons": 64,
175 | "use_sc": true,
176 | "irreps_edge_sh": "0e + 1e",
177 | "feature_irreps_hidden": "32x0o + 32x0e + 32x1o + 32x1e",
178 | "chemical_embedding_irreps_out": "32x0e",
179 | "conv_to_output_hidden_irreps_out": "16x0e",
180 | "precision": "float32"
181 | }
182 | ```
183 |
184 | ## DPRc support
185 |
186 | In `deepmd-gnn`, the GNN model can be used in a [DPRc](https://docs.deepmodeling.com/projects/deepmd/en/latest/model/dprc.html) way.
 187 | Type map entries that start with `m` (such as `mH`), or that are `OW` or `HW`, will be recognized as MM types.
188 | Two MM atoms will not build edges with each other.
189 | Such GNN+DPRc model can be directly used in AmberTools24.
190 |
191 | ## Examples
192 |
193 | - [examples/water](examples/water)
194 | - [examples/dprc](examples/dprc)
195 |
--------------------------------------------------------------------------------
/build_backend/__init__.py:
--------------------------------------------------------------------------------
1 | """Customized PEP-517 build backend."""
2 |
--------------------------------------------------------------------------------
/build_backend/dp_backend.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: LGPL-3.0-or-later
2 | """A PEP-517 backend to add customized dependencies."""
3 |
4 | import os
5 |
6 | from scikit_build_core import build as _orig
7 |
8 | __all__ = [
9 | "build_editable",
10 | "build_sdist",
11 | "build_wheel",
12 | "get_requires_for_build_editable",
13 | "get_requires_for_build_sdist",
14 | "get_requires_for_build_wheel",
15 | "prepare_metadata_for_build_editable",
16 | "prepare_metadata_for_build_wheel",
17 | ]
18 |
19 |
def __dir__() -> list[str]:
    """Return the public names exported by this backend module."""
    return __all__
22 |
23 |
# Re-export the scikit-build-core PEP 517 hooks unchanged.
# get_requires_for_build_wheel is intentionally absent here: it is redefined
# below to append extra build requirements on cibuildwheel/Read the Docs.
prepare_metadata_for_build_wheel = _orig.prepare_metadata_for_build_wheel
build_wheel = _orig.build_wheel
build_sdist = _orig.build_sdist
get_requires_for_build_sdist = _orig.get_requires_for_build_sdist
prepare_metadata_for_build_editable = _orig.prepare_metadata_for_build_editable
build_editable = _orig.build_editable
get_requires_for_build_editable = _orig.get_requires_for_build_editable
31 |
32 |
def cibuildwheel_dependencies() -> list[str]:
    """Return extra build requirements for cibuildwheel or Read the Docs runs.

    Outside those environments an empty list is returned.
    """
    in_cibuildwheel = os.environ.get("CIBUILDWHEEL", "0") == "1"
    on_readthedocs = os.environ.get("READTHEDOCS", "0") == "True"
    if in_cibuildwheel or on_readthedocs:
        return ["deepmd-kit[torch]>=3.0.0b2"]
    return []
42 |
43 |
def get_requires_for_build_wheel(
    config_settings: dict,
) -> list[str]:
    """Return the dependencies for building a wheel."""
    # Start from scikit-build-core's own requirements, then append ours.
    base_requires = _orig.get_requires_for_build_wheel(config_settings)
    return base_requires + cibuildwheel_dependencies()
52 |
--------------------------------------------------------------------------------
/deepmd_gnn/__init__.py:
--------------------------------------------------------------------------------
1 | """MACE plugin for DeePMD-kit."""
2 |
3 | import os
4 |
5 | from ._version import __version__
6 | from .argcheck import mace_model_args
7 |
8 | __email__ = "jinzhe.zeng@rutgers.edu"
9 |
10 | __all__ = [
11 | "__version__",
12 | "mace_model_args",
13 | ]
14 |
15 | # make compatible with mace & e3nn & pytorch 2.6
16 | os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1"
17 |
--------------------------------------------------------------------------------
/deepmd_gnn/__main__.py:
--------------------------------------------------------------------------------
1 | """Main entry point for the command line interface."""
2 |
3 | if __name__ == "__main__":
4 | msg = "This module is not meant to be executed directly."
5 | raise NotImplementedError(msg)
6 |
--------------------------------------------------------------------------------
/deepmd_gnn/argcheck.py:
--------------------------------------------------------------------------------
1 | """Argument check for the MACE model."""
2 |
3 | from dargs import Argument
4 | from deepmd.utils.argcheck import model_args_plugin
5 |
6 |
7 | @model_args_plugin.register("mace")
8 | def mace_model_args() -> Argument:
9 | """Arguments for the MACE model.
10 |
11 | Returns
12 | -------
13 | Argument
14 | Arguments for the MACE model.
15 | """
16 | doc_r_max = "distance cutoff (in Ang)"
17 | doc_num_radial_basis = "number of radial basis functions"
18 | doc_num_cutoff_basis = "number of basis functions for smooth cutoff"
19 | doc_max_ell = "highest ell of spherical harmonics"
20 | doc_interaction = "name of interaction block"
21 | doc_num_interactions = "number of interactions"
22 | doc_hidden_irreps = "hidden irreps"
23 | doc_pair_repulsion = "use pair repulsion term with ZBL potential"
24 | doc_distance_transform = "distance transform"
25 | doc_correlation = "correlation order at each layer"
26 | doc_gate = "non linearity for last readout"
27 | doc_mlp_irreps = "hidden irreps of the MLP in last readout"
28 | doc_radial_type = "type of radial basis functions"
29 | doc_radial_mlp = "width of the radial MLP"
30 | doc_std = "Standard deviation of force components in the training set"
31 | doc_precision = "Precision of the model, float32 or float64"
32 | return Argument(
33 | "mace",
34 | dict,
35 | [
36 | Argument("sel", [int, str], optional=False),
37 | Argument("r_max", float, optional=True, default=5.0, doc=doc_r_max),
38 | Argument(
39 | "num_radial_basis",
40 | int,
41 | optional=True,
42 | default=8,
43 | doc=doc_num_radial_basis,
44 | ),
45 | Argument(
46 | "num_cutoff_basis",
47 | int,
48 | optional=True,
49 | default=5,
50 | doc=doc_num_cutoff_basis,
51 | ),
52 | Argument("max_ell", int, optional=True, default=3, doc=doc_max_ell),
53 | Argument(
54 | "interaction",
55 | str,
56 | optional=True,
57 | default="RealAgnosticResidualInteractionBlock",
58 | doc=doc_interaction,
59 | ),
60 | Argument(
61 | "num_interactions",
62 | int,
63 | optional=True,
64 | default=2,
65 | doc=doc_num_interactions,
66 | ),
67 | Argument(
68 | "hidden_irreps",
69 | str,
70 | optional=True,
71 | default="128x0e + 128x1o",
72 | doc=doc_hidden_irreps,
73 | ),
74 | Argument(
75 | "pair_repulsion",
76 | bool,
77 | optional=True,
78 | default=False,
79 | doc=doc_pair_repulsion,
80 | ),
81 | Argument(
82 | "distance_transform",
83 | str,
84 | optional=True,
85 | default="None",
86 | doc=doc_distance_transform,
87 | ),
88 | Argument("correlation", int, optional=True, default=3, doc=doc_correlation),
89 | Argument("gate", str, optional=True, default="silu", doc=doc_gate),
90 | Argument(
91 | "MLP_irreps",
92 | str,
93 | optional=True,
94 | default="16x0e",
95 | doc=doc_mlp_irreps,
96 | ),
97 | Argument(
98 | "radial_type",
99 | str,
100 | optional=True,
101 | default="bessel",
102 | doc=doc_radial_type,
103 | ),
104 | Argument(
105 | "radial_MLP",
106 | list[int],
107 | optional=True,
108 | default=[64, 64, 64],
109 | doc=doc_radial_mlp,
110 | ),
111 | Argument("std", float, optional=True, doc=doc_std, default=1),
112 | Argument(
113 | "precision",
114 | str,
115 | optional=True,
116 | default="float32",
117 | doc=doc_precision,
118 | ),
119 | ],
120 | doc="MACE model",
121 | )
122 |
123 |
124 | @model_args_plugin.register("nequip")
125 | def nequip_model_args() -> Argument:
126 | """Arguments for the NequIP model."""
127 | doc_sel = "Maximum number of neighbor atoms."
128 | doc_r_max = "distance cutoff (in Ang)"
129 | doc_num_layers = "number of interaction blocks, we find 3-5 to work best"
130 | doc_l_max = "the maximum irrep order (rotation order) for the network's features, l=1 is a good default, l=2 is more accurate but slower"
131 | doc_num_features = "the multiplicity of the features, 32 is a good default for accurate network, if you want to be more accurate, go larger, if you want to be faster, go lower"
132 | doc_nonlinearity_type = "may be 'gate' or 'norm', 'gate' is recommended"
133 | doc_parity = "whether to include features with odd mirror parityy; often turning parity off gives equally good results but faster networks, so do consider this"
134 | doc_num_basis = (
135 | "number of basis functions used in the radial basis, 8 usually works best"
136 | )
137 | doc_besselbasis_trainable = "set true to train the bessel weights"
138 | doc_polynomialcutoff_p = "p-exponent used in polynomial cutoff function, smaller p corresponds to stronger decay with distance"
139 | doc_invariant_layers = (
140 | "number of radial layers, usually 1-3 works best, smaller is faster"
141 | )
142 | doc_invariant_neurons = (
143 | "number of hidden neurons in radial function, smaller is faster"
144 | )
145 | doc_use_sc = "use self-connection or not, usually gives big improvement"
146 | doc_irreps_edge_sh = "irreps for the chemical embedding of species"
147 | doc_feature_irreps_hidden = "irreps used for hidden features, here we go up to lmax=1, with even and odd parities; for more accurate but slower networks, use l=2 or higher, smaller number of features is faster"
148 | doc_chemical_embedding_irreps_out = "irreps of the spherical harmonics used for edges. If a single integer, indicates the full SH up to L_max=that_integer"
149 | doc_conv_to_output_hidden_irreps_out = "irreps used in hidden layer of output block"
150 | doc_precision = "Precision of the model, float32 or float64"
151 | return Argument(
152 | "nequip",
153 | dict,
154 | [
155 | Argument(
156 | "sel",
157 | [int, str],
158 | optional=False,
159 | doc=doc_sel,
160 | ),
161 | Argument(
162 | "r_max",
163 | float,
164 | optional=True,
165 | default=6.0,
166 | doc=doc_r_max,
167 | ),
168 | Argument(
169 | "num_layers",
170 | int,
171 | optional=True,
172 | default=4,
173 | doc=doc_num_layers,
174 | ),
175 | Argument(
176 | "l_max",
177 | int,
178 | optional=True,
179 | default=2,
180 | doc=doc_l_max,
181 | ),
182 | Argument(
183 | "num_features",
184 | int,
185 | optional=True,
186 | default=32,
187 | doc=doc_num_features,
188 | ),
189 | Argument(
190 | "nonlinearity_type",
191 | str,
192 | optional=True,
193 | default="gate",
194 | doc=doc_nonlinearity_type,
195 | ),
196 | Argument(
197 | "parity",
198 | bool,
199 | optional=True,
200 | default=True,
201 | doc=doc_parity,
202 | ),
203 | Argument(
204 | "num_basis",
205 | int,
206 | optional=True,
207 | default=8,
208 | doc=doc_num_basis,
209 | ),
210 | Argument(
211 | "BesselBasis_trainable",
212 | bool,
213 | optional=True,
214 | default=True,
215 | doc=doc_besselbasis_trainable,
216 | ),
217 | Argument(
218 | "PolynomialCutoff_p",
219 | int,
220 | optional=True,
221 | default=6,
222 | doc=doc_polynomialcutoff_p,
223 | ),
224 | Argument(
225 | "invariant_layers",
226 | int,
227 | optional=True,
228 | default=2,
229 | doc=doc_invariant_layers,
230 | ),
231 | Argument(
232 | "invariant_neurons",
233 | int,
234 | optional=True,
235 | default=64,
236 | doc=doc_invariant_neurons,
237 | ),
238 | Argument(
239 | "use_sc",
240 | bool,
241 | optional=True,
242 | default=True,
243 | doc=doc_use_sc,
244 | ),
245 | Argument(
246 | "irreps_edge_sh",
247 | str,
248 | optional=True,
249 | default="0e + 1e",
250 | doc=doc_irreps_edge_sh,
251 | ),
252 | Argument(
253 | "feature_irreps_hidden",
254 | str,
255 | optional=True,
256 | default="32x0o + 32x0e + 32x1o + 32x1e",
257 | doc=doc_feature_irreps_hidden,
258 | ),
259 | Argument(
260 | "chemical_embedding_irreps_out",
261 | str,
262 | optional=True,
263 | default="32x0e",
264 | doc=doc_chemical_embedding_irreps_out,
265 | ),
266 | Argument(
267 | "conv_to_output_hidden_irreps_out",
268 | str,
269 | optional=True,
270 | default="16x0e",
271 | doc=doc_conv_to_output_hidden_irreps_out,
272 | ),
273 | Argument(
274 | "precision",
275 | str,
276 | optional=True,
277 | default="float32",
278 | doc=doc_precision,
279 | ),
280 | ],
281 | doc="Nequip model",
282 | )
283 |
--------------------------------------------------------------------------------
/deepmd_gnn/env.py:
--------------------------------------------------------------------------------
1 | """Configurations read from environment variables."""
2 |
3 | import os
4 |
# When set to "1", the model relies on the caller-provided ghost-atom mapping
# and reports the bare cutoff instead of rcut * num_interactions
# (see MaceModel.get_rcut / forward_lower in mace.py).
DP_GNN_USE_MAPPING = os.environ.get("DP_GNN_USE_MAPPING", "0") == "1"
6 |
--------------------------------------------------------------------------------
/deepmd_gnn/mace.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: LGPL-3.0-or-later
2 | """Wrapper for MACE models."""
3 |
4 | import json
5 | from copy import deepcopy
6 | from typing import Any, Optional
7 |
8 | import torch
9 | from deepmd.dpmodel.output_def import (
10 | FittingOutputDef,
11 | ModelOutputDef,
12 | OutputVariableDef,
13 | )
14 | from deepmd.pt.model.model.model import (
15 | BaseModel,
16 | )
17 | from deepmd.pt.model.model.transform_output import (
18 | communicate_extended_output,
19 | )
20 | from deepmd.pt.utils import env
21 | from deepmd.pt.utils.nlist import (
22 | build_neighbor_list,
23 | extend_input_and_build_neighbor_list,
24 | )
25 | from deepmd.pt.utils.stat import (
26 | compute_output_stats,
27 | )
28 | from deepmd.pt.utils.update_sel import (
29 | UpdateSel,
30 | )
31 | from deepmd.pt.utils.utils import (
32 | to_numpy_array,
33 | to_torch_tensor,
34 | )
35 | from deepmd.utils.data_system import (
36 | DeepmdDataSystem,
37 | )
38 | from deepmd.utils.path import (
39 | DPPath,
40 | )
41 | from deepmd.utils.version import (
42 | check_version_compatibility,
43 | )
44 | from e3nn import (
45 | o3,
46 | )
47 | from e3nn.util.jit import (
48 | script,
49 | )
50 | from mace.modules import (
51 | ScaleShiftMACE,
52 | gate_dict,
53 | interaction_classes,
54 | )
55 |
56 | import deepmd_gnn.op # noqa: F401
57 | from deepmd_gnn import env as deepmd_gnn_env
58 |
# Chemical symbols of elements 1-118, indexed by atomic number minus one.
ELEMENTS = (
    "H He Li Be B C N O F Ne "
    "Na Mg Al Si P S Cl Ar "
    "K Ca Sc Ti V Cr Mn Fe Co Ni Cu Zn Ga Ge As Se Br Kr "
    "Rb Sr Y Zr Nb Mo Tc Ru Rh Pd Ag Cd In Sn Sb Te I Xe "
    "Cs Ba La Ce Pr Nd Pm Sm Eu Gd Tb Dy Ho Er Tm Yb Lu "
    "Hf Ta W Re Os Ir Pt Au Hg Tl Pb Bi Po At Rn "
    "Fr Ra Ac Th Pa U Np Pu Am Cm Bk Cf Es Fm Md No Lr "
    "Rf Db Sg Bh Hs Mt Ds Rg Cn Nh Fl Mc Lv Ts Og"
).split()

# Symbol -> atomic number, including the MM variants: "mX" shares X's atomic
# number, and "HW"/"OW" are water hydrogen/oxygen.  Insertion order matches
# the original literal: plain symbols, then "m"-prefixed, then HW/OW.
PeriodicTable = {
    symbol: number for number, symbol in enumerate(ELEMENTS, start=1)
}
PeriodicTable.update(
    {f"m{symbol}": number for number, symbol in enumerate(ELEMENTS, start=1)},
)
PeriodicTable["HW"] = 1
PeriodicTable["OW"] = 8
186 |
187 |
188 | @BaseModel.register("mace")
189 | class MaceModel(BaseModel):
190 | """Mace model.
191 |
192 | Parameters
193 | ----------
194 | type_map : list[str]
195 | The name of each type of atoms
196 | sel : int
197 | Maximum number of neighbor atoms
198 | r_max : float, optional
199 | distance cutoff (in Ang)
200 | num_radial_basis : int, optional
201 | number of radial basis functions
202 | num_cutoff_basis : int, optional
203 | number of basis functions for smooth cutoff
204 | max_ell : int, optional
205 | highest ell of spherical harmonics
206 | interaction : str, optional
207 | name of interaction block
208 | num_interactions : int, optional
209 | number of interactions
210 | hidden_irreps : str, optional
211 | hidden irreps
212 | pair_repulsion : bool
213 | use amsgrad variant of optimizer
214 | distance_transform : str, optional
215 | distance transform
216 | correlation : int
217 | correlation order at each layer
218 | gate : str, optional
219 | non linearity for last readout
220 | MLP_irreps : str, optional
221 | hidden irreps of the MLP in last readout
222 | radial_type : str, optional
223 | type of radial basis functions
224 | radial_MLP : str, optional
225 | width of the radial MLP
226 | std : float, optional
227 | Standard deviation of force components in the training set
228 | """
229 |
230 | mm_types: list[int]
231 |
    def __init__(
        self,
        type_map: list[str],
        sel: int,
        r_max: float = 5.0,
        num_radial_basis: int = 8,
        num_cutoff_basis: int = 5,
        max_ell: int = 3,
        interaction: str = "RealAgnosticResidualInteractionBlock",
        num_interactions: int = 2,
        hidden_irreps: str = "128x0e + 128x1o",
        pair_repulsion: bool = False,
        distance_transform: str = "None",
        correlation: int = 3,
        gate: str = "silu",
        MLP_irreps: str = "16x0e",
        radial_type: str = "bessel",
        radial_MLP: list[int] = [64, 64, 64],  # noqa: B006
        std: float = 1,
        **kwargs: Any,  # noqa: ANN401
    ) -> None:
        super().__init__(**kwargs)
        # Keep the raw constructor arguments so serialize() can round-trip them.
        self.params = {
            "type_map": type_map,
            "sel": sel,
            "r_max": r_max,
            "num_radial_basis": num_radial_basis,
            "num_cutoff_basis": num_cutoff_basis,
            "max_ell": max_ell,
            "interaction": interaction,
            "num_interactions": num_interactions,
            "hidden_irreps": hidden_irreps,
            "pair_repulsion": pair_repulsion,
            "distance_transform": distance_transform,
            "correlation": correlation,
            "gate": gate,
            "MLP_irreps": MLP_irreps,
            "radial_type": radial_type,
            "radial_MLP": radial_MLP,
            "std": std,
        }
        self.type_map = type_map
        self.ntypes = len(type_map)
        self.rcut = r_max
        self.num_interactions = num_interactions
        atomic_numbers = []
        # Per-type preset energy bias: None means "fit from data"; MM types
        # (names starting with "m", or HW/OW) are pinned to zero bias.
        self.preset_out_bias: dict[str, list] = {"energy": []}
        self.mm_types = []
        self.sel = sel
        for ii, tt in enumerate(type_map):
            atomic_numbers.append(PeriodicTable[tt])
            if not tt.startswith("m") and tt not in {"HW", "OW"}:
                self.preset_out_bias["energy"].append(None)
            else:
                self.preset_out_bias["energy"].append([0])
                self.mm_types.append(ii)

        # Build the underlying MACE network and TorchScript-compile it.
        self.model = script(
            ScaleShiftMACE(
                r_max=r_max,
                num_bessel=num_radial_basis,
                num_polynomial_cutoff=num_cutoff_basis,
                max_ell=max_ell,
                interaction_cls=interaction_classes[interaction],
                num_interactions=num_interactions,
                num_elements=self.ntypes,
                hidden_irreps=o3.Irreps(hidden_irreps),
                atomic_energies=torch.zeros(self.ntypes),  # pylint: disable=no-explicit-device,no-explicit-dtype
                avg_num_neighbors=sel,
                atomic_numbers=atomic_numbers,
                pair_repulsion=pair_repulsion,
                distance_transform=distance_transform,
                correlation=correlation,
                gate=gate_dict[gate],
                interaction_cls_first=interaction_classes[
                    "RealAgnosticInteractionBlock"
                ],
                MLP_irreps=o3.Irreps(MLP_irreps),
                atomic_inter_scale=std,
                atomic_inter_shift=0.0,
                radial_MLP=radial_MLP,
                radial_type=radial_type,
            ).to(env.DEVICE),
        )
        self.atomic_numbers = atomic_numbers
317 |
    def compute_or_load_stat(
        self,
        sampled_func,  # noqa: ANN001
        stat_file_path: Optional[DPPath] = None,
    ) -> None:
        """Compute or load the statistics parameters of the model.

        For example, mean and standard deviation of descriptors or the energy bias of
        the fitting net. When `sampled` is provided, all the statistics parameters will
        be calculated (or re-calculated for update), and saved in the
        `stat_file_path`(s). When `sampled` is not provided, it will check the existence
        of `stat_file_path`(s) and load the calculated statistics parameters.

        Parameters
        ----------
        sampled_func
            The sampled data frames from different data systems.
        stat_file_path
            The path to the statistics files.
        """
        # preset_out_bias pins MM types to zero bias; other types are fitted.
        bias_out, _ = compute_output_stats(
            sampled_func,
            self.get_ntypes(),
            keys=["energy"],
            stat_file_path=stat_file_path,
            rcond=None,
            preset_bias=self.preset_out_bias,
        )
        if "energy" in bias_out:
            # Copy the fitted per-type energy bias into MACE's atomic-energies
            # table, matching its shape, dtype, and device.
            self.model.atomic_energies_fn.atomic_energies = (
                bias_out["energy"]
                .view(self.model.atomic_energies_fn.atomic_energies.shape)
                .to(self.model.atomic_energies_fn.atomic_energies.dtype)
                .to(self.model.atomic_energies_fn.atomic_energies.device)
            )
353 |
354 | @torch.jit.export
355 | def fitting_output_def(self) -> FittingOutputDef:
356 | """Get the output def of developer implemented atomic models."""
357 | return FittingOutputDef(
358 | [
359 | OutputVariableDef(
360 | name="energy",
361 | shape=[1],
362 | reducible=True,
363 | r_differentiable=True,
364 | c_differentiable=True,
365 | ),
366 | ],
367 | )
368 |
369 | @torch.jit.export
370 | def get_rcut(self) -> float:
371 | """Get the cut-off radius."""
372 | if deepmd_gnn_env.DP_GNN_USE_MAPPING:
373 | return self.rcut
374 | return self.rcut * self.num_interactions
375 |
    @torch.jit.export
    def get_type_map(self) -> list[str]:
        """Get the type map."""
        return self.type_map

    @torch.jit.export
    def get_sel(self) -> list[int]:
        """Return the number of selected atoms for each type."""
        # Mixed-types model: one sel value covers all atom types.
        return [self.sel]

    @torch.jit.export
    def get_dim_fparam(self) -> int:
        """Get the number (dimension) of frame parameters of this atomic model."""
        # Frame parameters are unsupported (forward_lower_common raises on them).
        return 0

    @torch.jit.export
    def get_dim_aparam(self) -> int:
        """Get the number (dimension) of atomic parameters of this atomic model."""
        # Atomic parameters are unsupported (forward_lower_common raises on them).
        return 0

    @torch.jit.export
    def get_sel_type(self) -> list[int]:
        """Get the selected atom types of this model.

        Only atoms with selected atom types have atomic contribution
        to the result of the model.
        If returning an empty list, all atom types are selected.
        """
        return []

    @torch.jit.export
    def is_aparam_nall(self) -> bool:
        """Check whether the shape of atomic parameters is (nframes, nall, ndim).

        If False, the shape is (nframes, nloc, ndim).
        """
        return False

    @torch.jit.export
    def mixed_types(self) -> bool:
        """Return whether the model is in mixed-types mode.

        If true, the model
        1. assumes total number of atoms aligned across frames;
        2. uses a neighbor list that does not distinguish different atomic types.
        If false, the model
        1. assumes total number of atoms of each atom type aligned across frames;
        2. uses a neighbor list that distinguishes different atomic types.
        """
        return True

    @torch.jit.export
    def has_message_passing(self) -> bool:
        """Return whether the descriptor has message passing."""
        # NOTE(review): False even though MACE is message-passing -- presumably
        # because get_rcut() already expands the cutoff by num_interactions;
        # confirm against DeePMD-kit's use of this flag.
        return False
431 |
    @torch.jit.export
    def forward(
        self,
        coord: torch.Tensor,
        atype: torch.Tensor,
        box: Optional[torch.Tensor] = None,
        fparam: Optional[torch.Tensor] = None,
        aparam: Optional[torch.Tensor] = None,
        do_atomic_virial: bool = False,
    ) -> dict[str, torch.Tensor]:
        """Forward pass of the model.

        Parameters
        ----------
        coord : torch.Tensor
            The coordinates of atoms.
        atype : torch.Tensor
            The atomic types of atoms.
        box : torch.Tensor, optional
            The box tensor.
        fparam : torch.Tensor, optional
            The frame parameters.
        aparam : torch.Tensor, optional
            The atomic parameters.
        do_atomic_virial : bool, optional
            Whether to compute atomic virial.

        Returns
        -------
        dict[str, torch.Tensor]
            "atom_energy", "energy", "force", "virial", and, when
            do_atomic_virial is True, "atom_virial" -- all on local atoms.
        """
        nloc = atype.shape[1]
        # Build the ghost-extended system and its neighbor list from the
        # local coordinates and box.
        extended_coord, extended_atype, mapping, nlist = (
            extend_input_and_build_neighbor_list(
                coord,
                atype,
                self.rcut,
                self.get_sel(),
                mixed_types=True,
                box=box,
            )
        )
        model_ret_lower = self.forward_lower_common(
            nloc,
            extended_coord,
            extended_atype,
            nlist,
            mapping=mapping,
            fparam=fparam,
            aparam=aparam,
            do_atomic_virial=do_atomic_virial,
            comm_dict=None,
        )
        # Fold per-extended-atom outputs back onto the local frame.
        model_ret = communicate_extended_output(
            model_ret_lower,
            ModelOutputDef(self.fitting_output_def()),
            mapping,
            do_atomic_virial,
        )
        model_predict = {}
        model_predict["atom_energy"] = model_ret["energy"]
        model_predict["energy"] = model_ret["energy_redu"]
        model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2)
        model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
        if do_atomic_virial:
            model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3)
        return model_predict
495 |
    @torch.jit.export
    def forward_lower(
        self,
        extended_coord: torch.Tensor,
        extended_atype: torch.Tensor,
        nlist: torch.Tensor,
        mapping: Optional[torch.Tensor] = None,
        fparam: Optional[torch.Tensor] = None,
        aparam: Optional[torch.Tensor] = None,
        do_atomic_virial: bool = False,
        comm_dict: Optional[dict[str, torch.Tensor]] = None,
    ) -> dict[str, torch.Tensor]:
        """Forward lower pass of the model.

        Parameters
        ----------
        extended_coord : torch.Tensor
            The extended coordinates of atoms.
        extended_atype : torch.Tensor
            The extended atomic types of atoms.
        nlist : torch.Tensor
            The neighbor list.
        mapping : torch.Tensor, optional
            The mapping tensor.
        fparam : torch.Tensor, optional
            The frame parameters.
        aparam : torch.Tensor, optional
            The atomic parameters.
        do_atomic_virial : bool, optional
            Whether to compute atomic virial.
        comm_dict : dict[str, torch.Tensor], optional
            The communication dictionary.

        Returns
        -------
        dict[str, torch.Tensor]
            "atom_energy", "energy", "extended_force", "virial", and, when
            do_atomic_virial is True, "extended_virial".
        """
        nloc = nlist.shape[1]
        nf, nall = extended_atype.shape
        # calculate nlist for ghost atoms, as LAMMPS does not calculate it
        if mapping is None and self.num_interactions > 1 and nloc < nall:
            if deepmd_gnn_env.DP_GNN_USE_MAPPING:
                # when setting DP_GNN_USE_MAPPING, ghost atoms are only built
                # for one message-passing layer
                msg = (
                    "When setting DP_GNN_USE_MAPPING, mapping is required. "
                    "If you are using LAMMPS, set `atom_modify map yes`."
                )
                raise ValueError(msg)
            # Rebuild a neighbor list over all (local + ghost) atoms so that
            # every message-passing layer sees complete neighborhoods.
            nlist = build_neighbor_list(
                extended_coord.view(nf, -1),
                extended_atype,
                nall,
                self.rcut,
                self.sel,
                distinguish_types=False,
            )

        model_ret = self.forward_lower_common(
            nloc,
            extended_coord,
            extended_atype,
            nlist,
            mapping,
            fparam,
            aparam,
            do_atomic_virial,
            comm_dict,
        )
        model_predict = {}
        model_predict["atom_energy"] = model_ret["energy"]
        model_predict["energy"] = model_ret["energy_redu"]
        model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2)
        model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
        if do_atomic_virial:
            model_predict["extended_virial"] = model_ret["energy_derv_c"].squeeze(-3)
        return model_predict
569 |
    def forward_lower_common(
        self,
        nloc: int,
        extended_coord: torch.Tensor,
        extended_atype: torch.Tensor,
        nlist: torch.Tensor,
        mapping: Optional[torch.Tensor] = None,
        fparam: Optional[torch.Tensor] = None,
        aparam: Optional[torch.Tensor] = None,
        do_atomic_virial: bool = False,  # noqa: ARG002
        comm_dict: Optional[dict[str, torch.Tensor]] = None,
    ) -> dict[str, torch.Tensor]:
        """Forward lower common pass of the model.

        Parameters
        ----------
        nloc : int
            The number of local (non-ghost) atoms.
        extended_coord : torch.Tensor
            The extended coordinates of atoms.
        extended_atype : torch.Tensor
            The extended atomic types of atoms.
        nlist : torch.Tensor
            The neighbor list.
        mapping : torch.Tensor, optional
            The mapping tensor.
        fparam : torch.Tensor, optional
            The frame parameters.
        aparam : torch.Tensor, optional
            The atomic parameters.
        do_atomic_virial : bool, optional
            Whether to compute atomic virial.
        comm_dict : dict[str, torch.Tensor], optional
            The communication dictionary.

        Returns
        -------
        dict[str, torch.Tensor]
            "energy_redu", "energy_derv_r", "energy_derv_c_redu", "energy"
            (first nloc atoms only), and "energy_derv_c".

        Raises
        ------
        ValueError
            If fparam, aparam, or comm_dict is given (all unsupported), or if
            the MACE model returns no node energies / forces.
        """
        nf, nall = extended_atype.shape
        extended_coord = extended_coord.view(nf, nall, 3)
        extended_coord_ = extended_coord
        if fparam is not None:
            msg = "fparam is unsupported"
            raise ValueError(msg)
        if aparam is not None:
            msg = "aparam is unsupported"
            raise ValueError(msg)
        if comm_dict is not None:
            msg = "comm_dict is unsupported"
            raise ValueError(msg)
        nlist = nlist.to(torch.int64)
        extended_atype = extended_atype.to(torch.int64)
        nall = extended_coord.shape[1]

        # fake as one frame
        extended_coord_ff = extended_coord.view(nf * nall, 3)
        extended_atype_ff = extended_atype.view(nf * nall)
        # Custom op: turn the neighbor list into (2, nedge) edge pairs,
        # skipping MM-MM edges (see self.mm_types).
        edge_index = torch.ops.deepmd_gnn.edge_index(
            nlist,
            extended_atype,
            torch.tensor(self.mm_types, dtype=torch.int64, device="cpu"),
        )
        edge_index = edge_index.T
        # to one hot
        indices = extended_atype_ff.unsqueeze(-1)
        oh = torch.zeros(
            (nf * nall, self.ntypes),
            device=extended_atype.device,
            dtype=torch.float64,
        )
        # scatter_ is the in-place version of scatter
        oh.scatter_(dim=-1, index=indices, value=1)
        one_hot = oh.view((nf * nall, self.ntypes))

        # cast to the model's default dtype (float32 or float64)
        default_dtype = self.model.atomic_energies_fn.atomic_energies.dtype
        extended_coord_ff = extended_coord_ff.to(default_dtype)
        # coordinates need gradients so forces can be taken by autograd below
        extended_coord_ff.requires_grad_(True)  # noqa: FBT003
        nedge = edge_index.shape[1]
        if self.num_interactions > 1 and mapping is not None and nloc < nall:
            # shift the edges for ghost atoms, and map the ghost atoms to real atoms
            # (mapping is flattened with a per-frame offset of `nall`)
            mapping_ff = mapping.view(nf * nall) + torch.arange(
                0,
                nf * nall,
                nall,
                dtype=mapping.dtype,
                device=mapping.device,
            ).unsqueeze(-1).expand(nf, nall).reshape(-1)
            shifts_atoms = extended_coord_ff - extended_coord_ff[mapping_ff]
            shifts = shifts_atoms[edge_index[1]] - shifts_atoms[edge_index[0]]
            edge_index = mapping_ff[edge_index]
        else:
            shifts = torch.zeros(
                (nedge, 3),
                dtype=torch.float64,
                device=extended_coord_.device,
            )
        shifts = shifts.to(default_dtype)
        one_hot = one_hot.to(default_dtype)
        # it seems None is not allowed for data; use a huge dummy cell instead
        box = (
            torch.eye(
                3,
                dtype=extended_coord_ff.dtype,
                device=extended_coord_ff.device,
            )
            * 1000.0
        )

        # Feed the flattened single-frame batch through the MACE model;
        # derivatives are computed manually below, so all compute_* are off.
        ret = self.model.forward(
            {
                "positions": extended_coord_ff,
                "shifts": shifts,
                "cell": box,
                "edge_index": edge_index,
                "batch": torch.zeros(
                    [nf * nall],
                    dtype=torch.int64,
                    device=extended_coord_ff.device,
                ),
                "node_attrs": one_hot,
                "ptr": torch.tensor(
                    [0, nf * nall],
                    dtype=torch.int64,
                    device=extended_coord_ff.device,
                ),
                "weight": torch.tensor(
                    [1.0],
                    dtype=extended_coord_ff.dtype,
                    device=extended_coord_ff.device,
                ),
            },
            compute_force=False,
            compute_virials=False,
            compute_stress=False,
            compute_displacement=False,
            training=self.training,
        )

        atom_energy = ret["node_energy"]
        if atom_energy is None:
            msg = "atom_energy is None"
            raise ValueError(msg)
        atom_energy = atom_energy.view(nf, nall).to(extended_coord_.dtype)[:, :nloc]
        energy = torch.sum(atom_energy, dim=1).view(nf, 1).to(extended_coord_.dtype)
        grad_outputs: list[Optional[torch.Tensor]] = [
            torch.ones_like(energy),
        ]
        # Forces are minus the gradient of the energy w.r.t. coordinates.
        force = torch.autograd.grad(
            outputs=[energy],
            inputs=[extended_coord_ff],
            grad_outputs=grad_outputs,
            retain_graph=True,
            create_graph=self.training,
        )[0]
        if force is None:
            msg = "force is None"
            raise ValueError(msg)
        force = -force
        # Per-atom virial as the outer product force_i (x) coord_i.
        atomic_virial = force.unsqueeze(-1).to(
            extended_coord_.dtype,
        ) @ extended_coord_ff.unsqueeze(-2).to(
            extended_coord_.dtype,
        )
        force = force.view(nf, nall, 3).to(extended_coord_.dtype)
        atomic_virial = atomic_virial.view(nf, nall, 1, 9)
        virial = torch.sum(atomic_virial, dim=1).view(nf, 9).to(extended_coord_.dtype)

        return {
            "energy_redu": energy.view(nf, 1),
            "energy_derv_r": force.view(nf, nall, 1, 3),
            "energy_derv_c_redu": virial.view(nf, 1, 9),
            # take the first nloc atoms to match other models
            "energy": atom_energy.view(nf, nloc, 1),
            # fake atom_virial
            "energy_derv_c": atomic_virial.view(nf, nall, 1, 9),
        }
742 |
743 | def serialize(self) -> dict:
744 | """Serialize the model."""
745 | return {
746 | "@class": "Model",
747 | "@version": 1,
748 | "type": "mace",
749 | **self.params,
750 | "@variables": {
751 | kk: to_numpy_array(vv) for kk, vv in self.model.state_dict().items()
752 | },
753 | }
754 |
755 | @classmethod
756 | def deserialize(cls, data: dict) -> "MaceModel":
757 | """Deserialize the model."""
758 | data = data.copy()
759 | if not (data.pop("@class") == "Model" and data.pop("type") == "mace"):
760 | msg = "data is not a serialized MaceModel"
761 | raise ValueError(msg)
762 | check_version_compatibility(data.pop("@version"), 1, 1)
763 | variables = {
764 | kk: to_torch_tensor(vv) for kk, vv in data.pop("@variables").items()
765 | }
766 | model = cls(**data)
767 | model.model.load_state_dict(variables)
768 | return model
769 |
770 | @torch.jit.export
771 | def get_nnei(self) -> int:
772 | """Return the total number of selected neighboring atoms in cut-off radius."""
773 | return self.sel
774 |
775 | @torch.jit.export
776 | def get_nsel(self) -> int:
777 | """Return the total number of selected neighboring atoms in cut-off radius."""
778 | return self.sel
779 |
780 | @classmethod
781 | def update_sel(
782 | cls,
783 | train_data: DeepmdDataSystem,
784 | type_map: Optional[list[str]],
785 | local_jdata: dict,
786 | ) -> tuple[dict, Optional[float]]:
787 | """Update the selection and perform neighbor statistics.
788 |
789 | Parameters
790 | ----------
791 | train_data : DeepmdDataSystem
792 |             data used to do neighbor statistics
793 | type_map : list[str], optional
794 | The name of each type of atoms
795 | local_jdata : dict
796 | The local data refer to the current class
797 |
798 | Returns
799 | -------
800 | dict
801 | The updated local data
802 | float
803 | The minimum distance between two atoms
804 | """
805 | local_jdata_cpy = local_jdata.copy()
806 | min_nbor_dist, sel = UpdateSel().update_one_sel(
807 | train_data,
808 | type_map,
809 | local_jdata_cpy["r_max"],
810 | local_jdata_cpy["sel"],
811 | mixed_type=True,
812 | )
813 | local_jdata_cpy["sel"] = sel[0]
814 | return local_jdata_cpy, min_nbor_dist
815 |
816 | @torch.jit.export
817 | def model_output_type(self) -> list[str]:
818 | """Get the output type for the model."""
819 | return ["energy"]
820 |
821 | def translated_output_def(self) -> dict[str, Any]:
822 | """Get the translated output def for the model."""
823 | out_def_data = self.model_output_def().get_data()
824 | output_def = {
825 | "atom_energy": deepcopy(out_def_data["energy"]),
826 | "energy": deepcopy(out_def_data["energy_redu"]),
827 | }
828 | output_def["force"] = deepcopy(out_def_data["energy_derv_r"])
829 | output_def["force"].squeeze(-2)
830 | output_def["virial"] = deepcopy(out_def_data["energy_derv_c_redu"])
831 | output_def["virial"].squeeze(-2)
832 | output_def["atom_virial"] = deepcopy(out_def_data["energy_derv_c"])
833 | output_def["atom_virial"].squeeze(-3)
834 | if "mask" in out_def_data:
835 | output_def["mask"] = deepcopy(out_def_data["mask"])
836 | return output_def
837 |
838 | def model_output_def(self) -> ModelOutputDef:
839 | """Get the output def for the model."""
840 | return ModelOutputDef(self.fitting_output_def())
841 |
842 | @classmethod
843 | def get_model(cls, model_params: dict) -> "MaceModel":
844 | """Get the model by the parameters.
845 |
846 | Parameters
847 | ----------
848 | model_params : dict
849 | The model parameters
850 |
851 | Returns
852 | -------
853 | BaseBaseModel
854 | The model
855 | """
856 | model_params_old = model_params.copy()
857 | model_params = model_params.copy()
858 | model_params.pop("type", None)
859 | precision = model_params.pop("precision", "float32")
860 | if precision == "float32":
861 | torch.set_default_dtype(torch.float32)
862 | elif precision == "float64":
863 | torch.set_default_dtype(torch.float64)
864 | else:
865 | msg = f"precision {precision} not supported"
866 | raise ValueError(msg)
867 | model = cls(**model_params)
868 | model.model_def_script = json.dumps(model_params_old)
869 | return model
870 |
--------------------------------------------------------------------------------
/deepmd_gnn/nequip.py:
--------------------------------------------------------------------------------
1 | """Nequip model."""
2 |
3 | from copy import deepcopy
4 | from typing import Any, Optional
5 |
6 | import torch
7 | from deepmd.dpmodel.output_def import (
8 | FittingOutputDef,
9 | ModelOutputDef,
10 | OutputVariableDef,
11 | )
12 | from deepmd.pt.model.model.model import (
13 | BaseModel,
14 | )
15 | from deepmd.pt.model.model.transform_output import (
16 | communicate_extended_output,
17 | )
18 | from deepmd.pt.utils import (
19 | env,
20 | )
21 | from deepmd.pt.utils.nlist import (
22 | build_neighbor_list,
23 | extend_input_and_build_neighbor_list,
24 | )
25 | from deepmd.pt.utils.stat import (
26 | compute_output_stats,
27 | )
28 | from deepmd.pt.utils.update_sel import (
29 | UpdateSel,
30 | )
31 | from deepmd.pt.utils.utils import (
32 | to_numpy_array,
33 | to_torch_tensor,
34 | )
35 | from deepmd.utils.data_system import (
36 | DeepmdDataSystem,
37 | )
38 | from deepmd.utils.path import (
39 | DPPath,
40 | )
41 | from deepmd.utils.version import (
42 | check_version_compatibility,
43 | )
44 | from e3nn.util.jit import (
45 | script,
46 | )
47 | from nequip.model import model_from_config
48 |
49 |
50 | @BaseModel.register("nequip")
51 | class NequipModel(BaseModel):
52 | """Nequip model.
53 |
54 | Parameters
55 | ----------
56 | type_map : list[str]
57 | The name of each type of atoms
58 | sel : int
59 | Maximum number of neighbor atoms
60 | r_max : float, optional
61 | distance cutoff (in Ang)
62 | num_layers : int
63 | number of interaction blocks, we find 3-5 to work best
64 | l_max : int
65 | the maximum irrep order (rotation order) for the network's features, l=1 is a good default, l=2 is more accurate but slower
66 | num_features : int
67 | the multiplicity of the features, 32 is a good default for accurate network, if you want to be more accurate, go larger, if you want to be faster, go lower
68 | nonlinearity_type : str
69 | may be 'gate' or 'norm', 'gate' is recommended
70 | parity : bool
 71 |         whether to include features with odd mirror parity; often turning parity off gives equally good results but faster networks, so do consider this
72 | num_basis : int
73 | number of basis functions used in the radial basis, 8 usually works best
74 | BesselBasis_trainable : bool
75 | set true to train the bessel weights
76 | PolynomialCutoff_p : int
77 | p-exponent used in polynomial cutoff function, smaller p corresponds to stronger decay with distance
78 | invariant_layers : int
79 | number of radial layers, usually 1-3 works best, smaller is faster
80 | invariant_neurons : int
81 | number of hidden neurons in radial function, smaller is faster
82 | use_sc : bool
83 | use self-connection or not, usually gives big improvement
 84 |     irreps_edge_sh : str
 85 |         irreps of the spherical harmonics used for edges. If a single integer, indicates the full SH up to L_max=that_integer
 86 |     feature_irreps_hidden : str
 87 |         irreps used for hidden features, here we go up to lmax=1, with even and odd parities; for more accurate but slower networks, use l=2 or higher, smaller number of features is faster
 88 |     chemical_embedding_irreps_out : str
 89 |         irreps for the chemical embedding of species
90 | conv_to_output_hidden_irreps_out : str
91 | irreps used in hidden layer of output block
92 | """
93 |
94 | mm_types: list[int]
95 | e0: torch.Tensor
96 |
97 | def __init__(
98 | self,
99 | type_map: list[str],
100 | sel: int,
101 | r_max: float = 6.0,
102 | num_layers: int = 4,
103 | l_max: int = 2,
104 | num_features: int = 32,
105 | nonlinearity_type: str = "gate",
106 | parity: bool = True,
107 | num_basis: int = 8,
108 | BesselBasis_trainable: bool = True,
109 | PolynomialCutoff_p: int = 6,
110 | invariant_layers: int = 2,
111 | invariant_neurons: int = 64,
112 | use_sc: bool = True,
113 | irreps_edge_sh: str = "0e + 1e",
114 | feature_irreps_hidden: str = "32x0o + 32x0e + 32x1o + 32x1e",
115 | chemical_embedding_irreps_out: str = "32x0e",
116 | conv_to_output_hidden_irreps_out: str = "16x0e",
117 | precision: str = "float32",
118 | **kwargs: Any, # noqa: ANN401
119 | ) -> None:
120 | super().__init__(**kwargs)
121 | self.params = {
122 | "type_map": type_map,
123 | "sel": sel,
124 | "r_max": r_max,
125 | "num_layers": num_layers,
126 | "l_max": l_max,
127 | "num_features": num_features,
128 | "nonlinearity_type": nonlinearity_type,
129 | "parity": parity,
130 | "num_basis": num_basis,
131 | "BesselBasis_trainable": BesselBasis_trainable,
132 | "PolynomialCutoff_p": PolynomialCutoff_p,
133 | "invariant_layers": invariant_layers,
134 | "invariant_neurons": invariant_neurons,
135 | "use_sc": use_sc,
136 | "irreps_edge_sh": irreps_edge_sh,
137 | "feature_irreps_hidden": feature_irreps_hidden,
138 | "chemical_embedding_irreps_out": chemical_embedding_irreps_out,
139 | "conv_to_output_hidden_irreps_out": conv_to_output_hidden_irreps_out,
140 | "precision": precision,
141 | }
142 | self.type_map = type_map
143 | self.ntypes = len(type_map)
144 | self.preset_out_bias: dict[str, list] = {"energy": []}
145 | self.mm_types = []
146 | self.sel = sel
147 | self.num_layers = num_layers
148 | for ii, tt in enumerate(type_map):
149 | if not tt.startswith("m") and tt not in {"HW", "OW"}:
150 | self.preset_out_bias["energy"].append(None)
151 | else:
152 | self.preset_out_bias["energy"].append([0])
153 | self.mm_types.append(ii)
154 |
155 | self.rcut = r_max
156 | self.model = script(
157 | model_from_config(
158 | {
159 | "model_builders": ["EnergyModel"],
160 | "avg_num_neighbors": sel,
161 | "chemical_symbols": type_map,
162 | "num_types": self.ntypes,
163 | "r_max": r_max,
164 | "num_layers": num_layers,
165 | "l_max": l_max,
166 | "num_features": num_features,
167 | "nonlinearity_type": nonlinearity_type,
168 | "parity": parity,
169 | "num_basis": num_basis,
170 | "BesselBasis_trainable": BesselBasis_trainable,
171 | "PolynomialCutoff_p": PolynomialCutoff_p,
172 | "invariant_layers": invariant_layers,
173 | "invariant_neurons": invariant_neurons,
174 | "use_sc": use_sc,
175 | "irreps_edge_sh": irreps_edge_sh,
176 | "feature_irreps_hidden": feature_irreps_hidden,
177 | "chemical_embedding_irreps_out": chemical_embedding_irreps_out,
178 | "conv_to_output_hidden_irreps_out": conv_to_output_hidden_irreps_out,
179 | "model_dtype": precision,
180 | },
181 | ),
182 | )
183 | self.register_buffer(
184 | "e0",
185 | torch.zeros(
186 | self.ntypes,
187 | dtype=env.GLOBAL_PT_ENER_FLOAT_PRECISION,
188 | device=env.DEVICE,
189 | ),
190 | )
191 |
192 | def compute_or_load_stat(
193 | self,
194 | sampled_func, # noqa: ANN001
195 | stat_file_path: Optional[DPPath] = None,
196 | ) -> None:
197 | """Compute or load the statistics parameters of the model.
198 |
199 | For example, mean and standard deviation of descriptors or the energy bias of
200 | the fitting net. When `sampled` is provided, all the statistics parameters will
201 | be calculated (or re-calculated for update), and saved in the
202 | `stat_file_path`(s). When `sampled` is not provided, it will check the existence
203 | of `stat_file_path`(s) and load the calculated statistics parameters.
204 |
205 | Parameters
206 | ----------
207 | sampled_func
208 | The sampled data frames from different data systems.
209 | stat_file_path
210 | The path to the statistics files.
211 | """
212 | bias_out, _ = compute_output_stats(
213 | sampled_func,
214 | self.get_ntypes(),
215 | keys=["energy"],
216 | stat_file_path=stat_file_path,
217 | rcond=None,
218 | preset_bias=self.preset_out_bias,
219 | )
220 | if "energy" in bias_out:
221 | self.e0 = (
222 | bias_out["energy"]
223 | .view(self.e0.shape)
224 | .to(self.e0.dtype)
225 | .to(self.e0.device)
226 | )
227 |
228 | @torch.jit.export
229 | def fitting_output_def(self) -> FittingOutputDef:
230 | """Get the output def of developer implemented atomic models."""
231 | return FittingOutputDef(
232 | [
233 | OutputVariableDef(
234 | name="energy",
235 | shape=[1],
236 | reducible=True,
237 | r_differentiable=True,
238 | c_differentiable=True,
239 | ),
240 | ],
241 | )
242 |
243 | @torch.jit.export
244 | def get_rcut(self) -> float:
245 | """Get the cut-off radius."""
246 | return self.rcut * self.num_layers
247 |
248 | @torch.jit.export
249 | def get_type_map(self) -> list[str]:
250 | """Get the type map."""
251 | return self.type_map
252 |
253 | @torch.jit.export
254 | def get_sel(self) -> list[int]:
255 | """Return the number of selected atoms for each type."""
256 | return [self.sel]
257 |
258 | @torch.jit.export
259 | def get_dim_fparam(self) -> int:
260 | """Get the number (dimension) of frame parameters of this atomic model."""
261 | return 0
262 |
263 | @torch.jit.export
264 | def get_dim_aparam(self) -> int:
265 | """Get the number (dimension) of atomic parameters of this atomic model."""
266 | return 0
267 |
268 | @torch.jit.export
269 | def get_sel_type(self) -> list[int]:
270 | """Get the selected atom types of this model.
271 |
272 | Only atoms with selected atom types have atomic contribution
273 | to the result of the model.
274 | If returning an empty list, all atom types are selected.
275 | """
276 | return []
277 |
278 | @torch.jit.export
279 | def is_aparam_nall(self) -> bool:
280 | """Check whether the shape of atomic parameters is (nframes, nall, ndim).
281 |
282 | If False, the shape is (nframes, nloc, ndim).
283 | """
284 | return False
285 |
286 | @torch.jit.export
287 | def mixed_types(self) -> bool:
288 | """Return whether the model is in mixed-types mode.
289 |
290 | If true, the model
291 | 1. assumes total number of atoms aligned across frames;
292 | 2. uses a neighbor list that does not distinguish different atomic types.
293 | If false, the model
294 | 1. assumes total number of atoms of each atom type aligned across frames;
295 | 2. uses a neighbor list that distinguishes different atomic types.
296 | """
297 | return True
298 |
299 | @torch.jit.export
300 | def has_message_passing(self) -> bool:
301 | """Return whether the descriptor has message passing."""
302 | return False
303 |
304 | @torch.jit.export
305 | def forward(
306 | self,
307 | coord: torch.Tensor,
308 | atype: torch.Tensor,
309 | box: Optional[torch.Tensor] = None,
310 | fparam: Optional[torch.Tensor] = None,
311 | aparam: Optional[torch.Tensor] = None,
312 | do_atomic_virial: bool = False,
313 | ) -> dict[str, torch.Tensor]:
314 | """Forward pass of the model.
315 |
316 | Parameters
317 | ----------
318 | coord : torch.Tensor
319 | The coordinates of atoms.
320 | atype : torch.Tensor
321 | The atomic types of atoms.
322 | box : torch.Tensor, optional
323 | The box tensor.
324 | fparam : torch.Tensor, optional
325 | The frame parameters.
326 | aparam : torch.Tensor, optional
327 | The atomic parameters.
328 | do_atomic_virial : bool, optional
329 | Whether to compute atomic virial.
330 | """
331 | nloc = atype.shape[1]
332 | extended_coord, extended_atype, mapping, nlist = (
333 | extend_input_and_build_neighbor_list(
334 | coord,
335 | atype,
336 | self.rcut,
337 | self.get_sel(),
338 | mixed_types=True,
339 | box=box,
340 | )
341 | )
342 | model_ret_lower = self.forward_lower_common(
343 | nloc,
344 | extended_coord,
345 | extended_atype,
346 | nlist,
347 | mapping=mapping,
348 | fparam=fparam,
349 | aparam=aparam,
350 | do_atomic_virial=do_atomic_virial,
351 | comm_dict=None,
352 | box=box,
353 | )
354 | model_ret = communicate_extended_output(
355 | model_ret_lower,
356 | ModelOutputDef(self.fitting_output_def()),
357 | mapping,
358 | do_atomic_virial,
359 | )
360 | model_predict = {}
361 | model_predict["atom_energy"] = model_ret["energy"]
362 | model_predict["energy"] = model_ret["energy_redu"]
363 | model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2)
364 | model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
365 | if do_atomic_virial:
366 | model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3)
367 | return model_predict
368 |
369 | @torch.jit.export
370 | def forward_lower(
371 | self,
372 | extended_coord: torch.Tensor,
373 | extended_atype: torch.Tensor,
374 | nlist: torch.Tensor,
375 | mapping: Optional[torch.Tensor] = None,
376 | fparam: Optional[torch.Tensor] = None,
377 | aparam: Optional[torch.Tensor] = None,
378 | do_atomic_virial: bool = False,
379 | comm_dict: Optional[dict[str, torch.Tensor]] = None,
380 | ) -> dict[str, torch.Tensor]:
381 | """Forward lower pass of the model.
382 |
383 | Parameters
384 | ----------
385 | extended_coord : torch.Tensor
386 | The extended coordinates of atoms.
387 | extended_atype : torch.Tensor
388 | The extended atomic types of atoms.
389 | nlist : torch.Tensor
390 | The neighbor list.
391 | mapping : torch.Tensor, optional
392 | The mapping tensor.
393 | fparam : torch.Tensor, optional
394 | The frame parameters.
395 | aparam : torch.Tensor, optional
396 | The atomic parameters.
397 | do_atomic_virial : bool, optional
398 | Whether to compute atomic virial.
399 | comm_dict : dict[str, torch.Tensor], optional
400 | The communication dictionary.
401 | """
402 | nloc = nlist.shape[1]
403 | nf, nall = extended_atype.shape
404 | # recalculate nlist for ghost atoms
405 | if self.num_layers > 1 and nloc < nall:
406 | nlist = build_neighbor_list(
407 | extended_coord.view(nf, -1),
408 | extended_atype,
409 | nall,
410 | self.rcut * self.num_layers,
411 | self.sel,
412 | distinguish_types=False,
413 | )
414 | model_ret = self.forward_lower_common(
415 | nloc,
416 | extended_coord,
417 | extended_atype,
418 | nlist,
419 | mapping,
420 | fparam,
421 | aparam,
422 | do_atomic_virial,
423 | comm_dict,
424 | )
425 | model_predict = {}
426 | model_predict["atom_energy"] = model_ret["energy"]
427 | model_predict["energy"] = model_ret["energy_redu"]
428 | model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2)
429 | model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
430 | if do_atomic_virial:
431 | model_predict["extended_virial"] = model_ret["energy_derv_c"].squeeze(-3)
432 | return model_predict
433 |
434 | def forward_lower_common(
435 | self,
436 | nloc: int,
437 | extended_coord: torch.Tensor,
438 | extended_atype: torch.Tensor,
439 | nlist: torch.Tensor,
440 | mapping: Optional[torch.Tensor] = None,
441 | fparam: Optional[torch.Tensor] = None,
442 | aparam: Optional[torch.Tensor] = None,
443 | do_atomic_virial: bool = False, # noqa: ARG002
444 | comm_dict: Optional[dict[str, torch.Tensor]] = None,
445 | box: Optional[torch.Tensor] = None,
446 | ) -> dict[str, torch.Tensor]:
447 | """Forward lower common pass of the model.
448 |
449 | Parameters
450 | ----------
451 | extended_coord : torch.Tensor
452 | The extended coordinates of atoms.
453 | extended_atype : torch.Tensor
454 | The extended atomic types of atoms.
455 | nlist : torch.Tensor
456 | The neighbor list.
457 | mapping : torch.Tensor, optional
458 | The mapping tensor.
459 | fparam : torch.Tensor, optional
460 | The frame parameters.
461 | aparam : torch.Tensor, optional
462 | The atomic parameters.
463 | do_atomic_virial : bool, optional
464 | Whether to compute atomic virial.
465 | comm_dict : dict[str, torch.Tensor], optional
466 | The communication dictionary.
467 | box : torch.Tensor, optional
468 | The box tensor.
469 | """
470 | nf, nall = extended_atype.shape
471 |
472 | extended_coord = extended_coord.view(nf, nall, 3)
473 | extended_coord_ = extended_coord
474 | if fparam is not None:
475 | msg = "fparam is unsupported"
476 | raise ValueError(msg)
477 | if aparam is not None:
478 | msg = "aparam is unsupported"
479 | raise ValueError(msg)
480 | if comm_dict is not None:
481 | msg = "comm_dict is unsupported"
482 | raise ValueError(msg)
483 | nlist = nlist.to(torch.int64)
484 | extended_atype = extended_atype.to(torch.int64)
485 | nall = extended_coord.shape[1]
486 |
487 | # fake as one frame
488 | extended_coord_ff = extended_coord.view(nf * nall, 3)
489 | extended_atype_ff = extended_atype.view(nf * nall)
490 | edge_index = torch.ops.deepmd_gnn.edge_index(
491 | nlist,
492 | extended_atype,
493 | torch.tensor(self.mm_types, dtype=torch.int64, device="cpu"),
494 | )
495 | edge_index = edge_index.T
496 |         # Nequip and MACE have different definitions for edge_index
497 | edge_index = edge_index[[1, 0]]
498 |
499 | # nequip can convert dtype by itself
500 | default_dtype = torch.float64
501 | extended_coord_ff = extended_coord_ff.to(default_dtype)
502 | extended_coord_ff.requires_grad_(True) # noqa: FBT003
503 |
504 | input_dict = {
505 | "pos": extended_coord_ff,
506 | "edge_index": edge_index,
507 | "atom_types": extended_atype_ff,
508 | }
509 | if box is not None and mapping is not None:
510 | # pass box, map edge index to real
511 | box_ff = box.to(extended_coord_ff.device)
512 | input_dict["cell"] = box_ff
513 | input_dict["pbc"] = torch.zeros(
514 | 3,
515 | dtype=torch.bool,
516 | device=box_ff.device,
517 | )
518 | batch = torch.arange(nf, device=box_ff.device).repeat(nall)
519 | input_dict["batch"] = batch
520 | ptr = torch.arange(
521 | start=0,
522 | end=nf * nall + 1,
523 | step=nall,
524 | dtype=torch.int64,
525 | device=batch.device,
526 | )
527 | input_dict["ptr"] = ptr
528 | mapping_ff = mapping.view(nf * nall) + torch.arange(
529 | 0,
530 | nf * nall,
531 | nall,
532 | dtype=mapping.dtype,
533 | device=mapping.device,
534 | ).unsqueeze(-1).expand(nf, nall).reshape(-1)
535 | shifts_atoms = extended_coord_ff - extended_coord_ff[mapping_ff]
536 | shifts = shifts_atoms[edge_index[1]] - shifts_atoms[edge_index[0]]
537 | edge_index = mapping_ff[edge_index]
538 | input_dict["edge_index"] = edge_index
539 | rec_cell, _ = torch.linalg.inv_ex(box.view(nf, 3, 3))
540 | edge_cell_shift = torch.einsum(
541 | "ni,nij->nj",
542 | shifts,
543 | rec_cell[batch[edge_index[0]]],
544 | )
545 | input_dict["edge_cell_shift"] = edge_cell_shift
546 |
547 | ret = self.model.forward(
548 | input_dict,
549 | )
550 |
551 | atom_energy = ret["atomic_energy"]
552 | if atom_energy is None:
553 | msg = "atom_energy is None"
554 | raise ValueError(msg)
555 | atom_energy = atom_energy.view(nf, nall).to(extended_coord_.dtype)[:, :nloc]
556 | # adds e0
557 | atom_energy = atom_energy + self.e0[extended_atype[:, :nloc]].view(
558 | nf,
559 | nloc,
560 | ).to(
561 | atom_energy.dtype,
562 | )
563 | energy = torch.sum(atom_energy, dim=1).view(nf, 1).to(extended_coord_.dtype)
564 | grad_outputs: list[Optional[torch.Tensor]] = [
565 | torch.ones_like(energy),
566 | ]
567 | force = torch.autograd.grad(
568 | outputs=[energy],
569 | inputs=[extended_coord_ff],
570 | grad_outputs=grad_outputs,
571 | retain_graph=True,
572 | create_graph=self.training,
573 | )[0]
574 | if force is None:
575 | msg = "force is None"
576 | raise ValueError(msg)
577 | force = -force
578 | atomic_virial = force.unsqueeze(-1).to(
579 | extended_coord_.dtype,
580 | ) @ extended_coord_ff.unsqueeze(-2).to(
581 | extended_coord_.dtype,
582 | )
583 | force = force.view(nf, nall, 3).to(extended_coord_.dtype)
584 | atomic_virial = atomic_virial.view(nf, nall, 1, 9)
585 | virial = torch.sum(atomic_virial, dim=1).view(nf, 9).to(extended_coord_.dtype)
586 |
587 | return {
588 | "energy_redu": energy.view(nf, 1),
589 | "energy_derv_r": force.view(nf, nall, 1, 3),
590 | "energy_derv_c_redu": virial.view(nf, 1, 9),
591 | # take the first nloc atoms to match other models
592 | "energy": atom_energy.view(nf, nloc, 1),
593 | # fake atom_virial
594 | "energy_derv_c": atomic_virial.view(nf, nall, 1, 9),
595 | }
596 |
597 | def serialize(self) -> dict:
598 | """Serialize the model."""
599 | return {
600 | "@class": "Model",
601 | "@version": 1,
602 | "type": "mace",
603 | **self.params,
604 | "@variables": {
605 | **{
606 | kk: to_numpy_array(vv) for kk, vv in self.model.state_dict().items()
607 | },
608 | "e0": to_numpy_array(self.e0),
609 | },
610 | }
611 |
612 | @classmethod
613 | def deserialize(cls, data: dict) -> "NequipModel":
614 | """Deserialize the model."""
615 | data = data.copy()
616 | if not (data.pop("@class") == "Model" and data.pop("type") == "mace"):
617 | msg = "data is not a serialized NequipModel"
618 | raise ValueError(msg)
619 | check_version_compatibility(data.pop("@version"), 1, 1)
620 | variables = {
621 | kk: to_torch_tensor(vv) for kk, vv in data.pop("@variables").items()
622 | }
623 | model = cls(**data)
624 | model.e0 = variables.pop("e0")
625 | model.model.load_state_dict(variables)
626 | return model
627 |
628 | @torch.jit.export
629 | def get_nnei(self) -> int:
630 | """Return the total number of selected neighboring atoms in cut-off radius."""
631 | return self.sel
632 |
633 | @torch.jit.export
634 | def get_nsel(self) -> int:
635 | """Return the total number of selected neighboring atoms in cut-off radius."""
636 | return self.sel
637 |
638 | @classmethod
639 | def update_sel(
640 | cls,
641 | train_data: DeepmdDataSystem,
642 | type_map: Optional[list[str]],
643 | local_jdata: dict,
644 | ) -> tuple[dict, Optional[float]]:
645 | """Update the selection and perform neighbor statistics.
646 |
647 | Parameters
648 | ----------
649 | train_data : DeepmdDataSystem
650 |             data used to do neighbor statistics
651 | type_map : list[str], optional
652 | The name of each type of atoms
653 | local_jdata : dict
654 | The local data refer to the current class
655 |
656 | Returns
657 | -------
658 | dict
659 | The updated local data
660 | float
661 | The minimum distance between two atoms
662 | """
663 | local_jdata_cpy = local_jdata.copy()
664 | min_nbor_dist, sel = UpdateSel().update_one_sel(
665 | train_data,
666 | type_map,
667 | local_jdata_cpy["r_max"],
668 | local_jdata_cpy["sel"],
669 | mixed_type=True,
670 | )
671 | local_jdata_cpy["sel"] = sel[0]
672 | return local_jdata_cpy, min_nbor_dist
673 |
674 | @torch.jit.export
675 | def model_output_type(self) -> list[str]:
676 | """Get the output type for the model."""
677 | return ["energy"]
678 |
679 | def translated_output_def(self) -> dict[str, Any]:
680 | """Get the translated output def for the model."""
681 | out_def_data = self.model_output_def().get_data()
682 | output_def = {
683 | "atom_energy": deepcopy(out_def_data["energy"]),
684 | "energy": deepcopy(out_def_data["energy_redu"]),
685 | }
686 | output_def["force"] = deepcopy(out_def_data["energy_derv_r"])
687 | output_def["force"].squeeze(-2)
688 | output_def["virial"] = deepcopy(out_def_data["energy_derv_c_redu"])
689 | output_def["virial"].squeeze(-2)
690 | output_def["atom_virial"] = deepcopy(out_def_data["energy_derv_c"])
691 | output_def["atom_virial"].squeeze(-3)
692 | if "mask" in out_def_data:
693 | output_def["mask"] = deepcopy(out_def_data["mask"])
694 | return output_def
695 |
696 | def model_output_def(self) -> ModelOutputDef:
697 | """Get the output def for the model."""
698 | return ModelOutputDef(self.fitting_output_def())
699 |
--------------------------------------------------------------------------------
/deepmd_gnn/op.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: LGPL-3.0-or-later
2 | """Load OP library."""
3 |
4 | from __future__ import annotations
5 |
6 | import platform
7 | from pathlib import Path
8 |
9 | import torch
10 |
11 | import deepmd_gnn.lib
12 |
13 | SHARED_LIB_DIR = Path(deepmd_gnn.lib.__path__[0])
14 |
15 |
16 | def load_library(module_name: str) -> None:
17 | """Load OP library.
18 |
19 | Parameters
20 | ----------
21 | module_name : str
22 | Name of the module
23 |
24 | Returns
25 | -------
26 | bool
27 | Whether the library is loaded successfully
28 | """
29 | if platform.system() == "Windows":
30 | ext = ".dll"
31 | prefix = ""
32 | else:
33 | ext = ".so"
34 | prefix = "lib"
35 |
36 | module_file = (SHARED_LIB_DIR / (prefix + module_name)).with_suffix(ext).resolve()
37 |
38 | torch.ops.load_library(module_file)
39 |
40 |
41 | load_library("deepmd_gnn")
42 |
--------------------------------------------------------------------------------
/deepmd_gnn/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/deepmd_gnn/py.typed
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: LGPL-3.0-or-later
2 | """Configuration file for the Sphinx documentation builder."""
3 | #
4 | # This file only contains a selection of the most common options. For a full
5 | # list see the documentation:
6 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
7 |
8 | # -- Path setup --------------------------------------------------------------
9 |
10 | # If extensions (or modules to document with autodoc) are in another directory,
11 | # add these directories to sys.path here. If the directory is relative to the
12 | # documentation root, use os.path.abspath to make it absolute, like shown here.
13 | #
14 |
15 | # import sys
16 | # sys.path.insert(0, os.path.abspath('..'))
17 | from datetime import datetime, timezone
18 |
19 | # -- Project information -----------------------------------------------------
20 |
21 | project = "DeePMD-GNN"
22 | copyright = f"2024-{datetime.now(tz=timezone.utc).year}, DeepModeling" # noqa: A001
23 | author = "Jinzhe Zeng"
24 |
25 |
26 | # -- General configuration ---------------------------------------------------
27 |
28 | # Add any Sphinx extension module names here, as strings. They can be
29 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
30 | # ones.
31 | extensions = [
32 | "sphinx.ext.autosummary",
33 | "sphinx.ext.mathjax",
34 | "sphinx.ext.viewcode",
35 | "sphinx.ext.intersphinx",
36 | # "sphinxarg.ext",
37 | "myst_parser",
38 | # "sphinx_favicon",
39 | "deepmodeling_sphinx",
40 | "dargs.sphinx",
41 | # "sphinxcontrib.bibtex",
42 | # "sphinx_design",
43 | "autoapi.extension",
44 | ]
45 |
46 | # Add any paths that contain templates here, relative to this directory.
47 | # templates_path = ['_templates']
48 |
49 | # List of patterns, relative to source directory, that match files and
50 | # directories to ignore when looking for source files.
51 | # This pattern also affects html_static_path and html_extra_path.
52 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
53 |
54 |
55 | # -- Options for HTML output -------------------------------------------------
56 |
57 | # The theme to use for HTML and HTML Help pages. See the documentation for
58 | # a list of builtin themes.
59 | #
60 | html_theme = "sphinx_book_theme"
61 | # html_logo = "_static/logo.svg"
62 | html_static_path = ["_static"]
63 | html_js_files: list[str] = []
64 | html_css_files = ["css/custom.css"]
65 | html_extra_path = ["report.html", "fire.png", "bundle.js", "bundle.css"]
66 |
67 | html_theme_options = {
68 | "github_url": "https://github.com/deepmodeling/deepmd-gnn",
69 | "gitlab_url": "https://gitlab.com/RutgersLBSR/deepmd-gnn",
70 | "logo": {
71 | "text": "DeePMD-GNN",
72 | "alt_text": "DeePMD-GNN",
73 | },
74 | }
75 |
76 | html_context = {
77 | "github_user": "deepmodeling",
78 | "github_repo": "deepmd-gnn",
79 | "github_version": "master",
80 | "doc_path": "docs",
81 | }
82 |
83 | myst_heading_anchors = 3
84 |
85 | # favicons = [
86 | # {
87 | # "rel": "icon",
88 | # "static-file": "logo.svg",
89 | # "type": "image/svg+xml",
90 | # },
91 | # ]
92 |
93 | enable_deepmodeling = False
94 |
95 | myst_enable_extensions = [
96 | "dollarmath",
97 | "colon_fence",
98 | "attrs_inline",
99 | ]
100 | mathjax_path = (
101 | "https://cdnjs.cloudflare.com/ajax/libs/mathjax/3.2.2/es5/tex-mml-chtml.min.js"
102 | )
103 | mathjax_options = {
104 | "integrity": "sha512-6FaAxxHuKuzaGHWnV00ftWqP3luSBRSopnNAA2RvQH1fOfnF/A1wOfiUWF7cLIOFcfb1dEhXwo5VG3DAisocRw==",
105 | "crossorigin": "anonymous",
106 | }
107 | mathjax3_config = {
108 | "loader": {"load": ["[tex]/mhchem"]},
109 | "tex": {"packages": {"[+]": ["mhchem"]}},
110 | }
111 |
112 | execution_mode = "off"
113 | numpydoc_show_class_members = False
114 |
115 | # Add any paths that contain custom static files (such as style sheets) here,
116 | # relative to this directory. They are copied after the builtin static files,
117 | # so a file named "default.css" will overwrite the builtin "default.css".
118 | # html_static_path = ['_static']
119 |
120 | intersphinx_mapping = {
121 | "numpy": ("https://docs.scipy.org/doc/numpy/", None),
122 | "python": ("https://docs.python.org/", None),
123 | "deepmd": ("https://docs.deepmodeling.com/projects/deepmd/", None),
124 | "torch": ("https://pytorch.org/docs/stable/", None),
125 | }
126 | autoapi_dirs = ["../deepmd_gnn"]
127 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. include:: ../README.md
2 | :parser: myst_parser.sphinx_
3 |
4 | Table of contents
5 | =================
6 | .. toctree::
7 | :maxdepth: 2
8 |
9 | Overview
10 | parameters
11 | Python API
12 |
13 | Indices and tables
14 | ==================
15 |
16 | * :ref:`genindex`
17 | * :ref:`modindex`
18 | * :ref:`search`
19 |
--------------------------------------------------------------------------------
/docs/parameters.rst:
--------------------------------------------------------------------------------
1 | Parameters
2 | ==========
3 |
4 | MACE
5 | ----
6 |
7 | .. dargs::
8 | :module: deepmd_gnn.argcheck
9 | :func: mace_model_args
10 |
11 | NequIP
12 | ------
13 |
14 | .. dargs::
15 | :module: deepmd_gnn.argcheck
16 | :func: nequip_model_args
17 |
--------------------------------------------------------------------------------
/examples/.gitignore:
--------------------------------------------------------------------------------
1 | lcurve.out
2 | out.json
3 | input_v2_compat.json
4 | checkpoint
5 | *.pth
6 | *.pb
7 | *.pt
8 | model.ckpt*
9 |
--------------------------------------------------------------------------------
/examples/dprc/data/nopbc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/dprc/data/nopbc
--------------------------------------------------------------------------------
/examples/dprc/data/set.000/box.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/dprc/data/set.000/box.npy
--------------------------------------------------------------------------------
/examples/dprc/data/set.000/coord.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/dprc/data/set.000/coord.npy
--------------------------------------------------------------------------------
/examples/dprc/data/set.000/energy.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/dprc/data/set.000/energy.npy
--------------------------------------------------------------------------------
/examples/dprc/data/set.000/force.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/dprc/data/set.000/force.npy
--------------------------------------------------------------------------------
/examples/dprc/data/type.raw:
--------------------------------------------------------------------------------
1 | 1
2 | 0
3 | 1
4 | 3
5 | 0
6 | 1
7 | 1
8 | 0
9 | 1
10 | 0
11 | 1
12 | 3
13 | 3
14 | 5
15 | 3
16 | 3
17 | 3
18 | 0
19 | 1
20 | 1
21 | 0
22 | 1
23 | 1
24 | 1
25 | 4
26 | 2
27 | 2
28 | 4
29 | 4
30 | 2
31 | 4
32 | 2
33 | 2
34 | 4
35 | 2
36 | 2
37 | 4
38 | 2
39 | 2
40 | 4
41 | 2
42 | 2
43 | 2
44 | 4
45 | 2
46 | 2
47 | 4
48 | 2
49 | 2
50 | 4
51 | 2
52 | 2
53 | 4
54 | 2
55 | 2
56 | 4
57 | 2
58 | 2
59 | 4
60 | 2
61 | 2
62 | 4
63 | 2
64 | 2
65 | 4
66 | 2
67 | 2
68 | 4
69 | 2
70 | 2
71 | 4
72 | 2
73 | 2
74 | 4
75 | 2
76 | 2
77 | 4
78 | 2
79 | 2
80 | 4
81 | 2
82 | 2
83 | 4
84 | 2
85 | 2
86 | 4
87 | 2
88 | 2
89 | 4
90 | 2
91 | 2
92 | 4
93 | 2
94 | 2
95 | 4
96 | 2
97 | 2
98 | 4
99 | 2
100 | 2
101 | 4
102 | 2
103 | 2
104 | 4
105 | 2
106 | 2
107 | 4
108 | 2
109 | 2
110 | 4
111 | 2
112 | 2
113 | 4
114 | 2
115 | 2
116 | 4
117 | 2
118 | 2
119 | 4
120 | 2
121 | 2
122 | 4
123 | 2
124 | 2
125 | 4
126 | 2
127 | 4
128 | 2
129 | 2
130 | 4
131 | 2
132 | 2
133 | 4
134 | 2
135 | 2
136 | 4
137 | 2
138 | 2
139 | 4
140 | 2
141 | 2
142 | 4
143 | 2
144 | 2
145 | 4
146 | 2
147 | 2
148 | 4
149 | 2
150 | 2
151 | 4
152 | 2
153 | 4
154 | 2
155 | 4
156 | 2
157 | 2
158 | 4
159 | 2
160 | 2
161 | 4
162 | 2
163 | 2
164 | 4
165 | 2
166 | 2
167 | 4
168 | 2
169 | 2
170 | 4
171 | 2
172 | 2
173 | 4
174 | 2
175 | 2
176 | 4
177 | 2
178 | 2
179 | 4
180 | 2
181 | 2
182 | 4
183 | 2
184 | 2
185 | 4
186 | 2
187 | 2
188 | 4
189 | 2
190 | 4
191 | 2
192 | 2
193 | 4
194 | 2
195 | 2
196 | 4
197 | 2
198 | 2
199 | 4
200 | 2
201 | 2
202 | 2
203 | 4
204 | 2
205 | 2
206 | 4
207 | 2
208 | 2
209 | 4
210 | 2
211 | 2
212 | 2
213 | 4
214 | 2
215 | 2
216 | 2
217 | 4
218 | 2
219 | 2
220 | 4
221 | 2
222 | 2
223 | 2
224 | 2
225 | 2
226 | 4
227 | 2
228 | 2
229 | 4
230 | 2
231 | 2
232 | 4
233 | 2
234 | 2
235 | 4
236 | 2
237 | 2
238 | 2
239 | 2
240 | 4
241 | 2
242 | 2
243 | 4
244 | 2
245 | 2
246 | 4
247 | 2
248 | 2
249 | 4
250 | 2
251 | 2
252 | 4
253 | 2
254 | 2
255 | 2
256 | 4
257 | 2
258 | 2
259 | 4
260 | 2
261 | 2
262 | 4
263 | 2
264 | 2
265 | 4
266 | 2
267 | 2
268 | 4
269 | 2
270 | 2
271 | 4
272 | 2
273 | 2
274 | 2
275 | 4
276 |
--------------------------------------------------------------------------------
/examples/dprc/data/type_map.raw:
--------------------------------------------------------------------------------
1 | C
2 | H
3 | HW
4 | O
5 | OW
6 | P
7 |
--------------------------------------------------------------------------------
/examples/dprc/mace/input.json:
--------------------------------------------------------------------------------
1 | {
2 | "_comment1": " model parameters",
3 | "model": {
4 | "type": "mace",
5 | "sel": 150,
6 | "type_map": [
7 | "C",
8 | "P",
9 | "O",
10 | "H",
11 | "OW",
12 | "HW"
13 | ],
14 | "r_max": 6.0
15 | },
16 | "learning_rate": {
17 | "type": "exp",
18 | "decay_steps": 5000,
19 | "start_lr": 0.001,
20 | "stop_lr": 3.51e-8,
21 | "_comment2": "that's all"
22 | },
23 | "loss": {
24 | "type": "ener",
25 | "start_pref_e": 0.02,
26 | "limit_pref_e": 1,
27 | "start_pref_f": 1000,
28 | "limit_pref_f": 1,
29 | "start_pref_v": 0,
30 | "limit_pref_v": 0,
31 | "_comment3": " that's all"
32 | },
33 | "training": {
34 | "training_data": {
35 | "systems": [
36 | "../data"
37 | ],
38 | "batch_size": "auto",
39 | "_comment4": "that's all"
40 | },
41 | "numb_steps": 100000,
42 | "seed": 10,
43 | "disp_file": "lcurve.out",
44 | "disp_freq": 100,
45 | "save_freq": 100,
46 | "_comment5": "that's all"
47 | },
48 | "_comment6": "that's all"
49 | }
50 |
--------------------------------------------------------------------------------
/examples/dprc/nequip/input.json:
--------------------------------------------------------------------------------
1 | {
2 | "_comment1": " model parameters",
3 | "model": {
4 | "type": "nequip",
5 | "type_map": [
6 | "C",
7 | "P",
8 | "O",
9 | "H",
10 | "OW",
11 | "HW"
12 | ],
13 | "r_max": 6.0,
14 | "sel": "auto",
15 | "l_max": 1,
16 | "_comment2": " that's all"
17 | },
18 |
19 | "learning_rate": {
20 | "type": "exp",
21 | "decay_steps": 5000,
22 | "start_lr": 0.001,
23 | "stop_lr": 3.51e-8,
24 | "_comment3": "that's all"
25 | },
26 |
27 | "loss": {
28 | "type": "ener",
29 | "start_pref_e": 0.02,
30 | "limit_pref_e": 1,
31 | "start_pref_f": 1000,
32 | "limit_pref_f": 1,
33 | "start_pref_v": 0,
34 | "limit_pref_v": 0,
35 | "_comment4": " that's all"
36 | },
37 |
38 | "training": {
39 | "training_data": {
40 | "systems": [
41 | "../data"
42 | ],
43 | "batch_size": "auto",
44 | "_comment4": "that's all"
45 | },
46 | "numb_steps": 100000,
47 | "seed": 10,
48 | "disp_file": "lcurve.out",
49 | "disp_freq": 100,
50 | "save_freq": 100,
51 | "_comment5": "that's all"
52 | },
53 |
54 | "_comment8": "that's all"
55 | }
56 |
--------------------------------------------------------------------------------
/examples/water/data/data_0/set.000/box.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_0/set.000/box.npy
--------------------------------------------------------------------------------
/examples/water/data/data_0/set.000/coord.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_0/set.000/coord.npy
--------------------------------------------------------------------------------
/examples/water/data/data_0/set.000/energy.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_0/set.000/energy.npy
--------------------------------------------------------------------------------
/examples/water/data/data_0/set.000/force.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_0/set.000/force.npy
--------------------------------------------------------------------------------
/examples/water/data/data_0/type.raw:
--------------------------------------------------------------------------------
1 | 0
2 | 0
3 | 0
4 | 0
5 | 0
6 | 0
7 | 0
8 | 0
9 | 0
10 | 0
11 | 0
12 | 0
13 | 0
14 | 0
15 | 0
16 | 0
17 | 0
18 | 0
19 | 0
20 | 0
21 | 0
22 | 0
23 | 0
24 | 0
25 | 0
26 | 0
27 | 0
28 | 0
29 | 0
30 | 0
31 | 0
32 | 0
33 | 0
34 | 0
35 | 0
36 | 0
37 | 0
38 | 0
39 | 0
40 | 0
41 | 0
42 | 0
43 | 0
44 | 0
45 | 0
46 | 0
47 | 0
48 | 0
49 | 0
50 | 0
51 | 0
52 | 0
53 | 0
54 | 0
55 | 0
56 | 0
57 | 0
58 | 0
59 | 0
60 | 0
61 | 0
62 | 0
63 | 0
64 | 0
65 | 1
66 | 1
67 | 1
68 | 1
69 | 1
70 | 1
71 | 1
72 | 1
73 | 1
74 | 1
75 | 1
76 | 1
77 | 1
78 | 1
79 | 1
80 | 1
81 | 1
82 | 1
83 | 1
84 | 1
85 | 1
86 | 1
87 | 1
88 | 1
89 | 1
90 | 1
91 | 1
92 | 1
93 | 1
94 | 1
95 | 1
96 | 1
97 | 1
98 | 1
99 | 1
100 | 1
101 | 1
102 | 1
103 | 1
104 | 1
105 | 1
106 | 1
107 | 1
108 | 1
109 | 1
110 | 1
111 | 1
112 | 1
113 | 1
114 | 1
115 | 1
116 | 1
117 | 1
118 | 1
119 | 1
120 | 1
121 | 1
122 | 1
123 | 1
124 | 1
125 | 1
126 | 1
127 | 1
128 | 1
129 | 1
130 | 1
131 | 1
132 | 1
133 | 1
134 | 1
135 | 1
136 | 1
137 | 1
138 | 1
139 | 1
140 | 1
141 | 1
142 | 1
143 | 1
144 | 1
145 | 1
146 | 1
147 | 1
148 | 1
149 | 1
150 | 1
151 | 1
152 | 1
153 | 1
154 | 1
155 | 1
156 | 1
157 | 1
158 | 1
159 | 1
160 | 1
161 | 1
162 | 1
163 | 1
164 | 1
165 | 1
166 | 1
167 | 1
168 | 1
169 | 1
170 | 1
171 | 1
172 | 1
173 | 1
174 | 1
175 | 1
176 | 1
177 | 1
178 | 1
179 | 1
180 | 1
181 | 1
182 | 1
183 | 1
184 | 1
185 | 1
186 | 1
187 | 1
188 | 1
189 | 1
190 | 1
191 | 1
192 | 1
193 |
--------------------------------------------------------------------------------
/examples/water/data/data_0/type_map.raw:
--------------------------------------------------------------------------------
1 | O
2 | H
3 |
--------------------------------------------------------------------------------
/examples/water/data/data_1/set.000/box.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_1/set.000/box.npy
--------------------------------------------------------------------------------
/examples/water/data/data_1/set.000/coord.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_1/set.000/coord.npy
--------------------------------------------------------------------------------
/examples/water/data/data_1/set.000/energy.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_1/set.000/energy.npy
--------------------------------------------------------------------------------
/examples/water/data/data_1/set.000/force.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_1/set.000/force.npy
--------------------------------------------------------------------------------
/examples/water/data/data_1/set.001/box.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_1/set.001/box.npy
--------------------------------------------------------------------------------
/examples/water/data/data_1/set.001/coord.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_1/set.001/coord.npy
--------------------------------------------------------------------------------
/examples/water/data/data_1/set.001/energy.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_1/set.001/energy.npy
--------------------------------------------------------------------------------
/examples/water/data/data_1/set.001/force.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_1/set.001/force.npy
--------------------------------------------------------------------------------
/examples/water/data/data_1/type.raw:
--------------------------------------------------------------------------------
1 | 0
2 | 0
3 | 0
4 | 0
5 | 0
6 | 0
7 | 0
8 | 0
9 | 0
10 | 0
11 | 0
12 | 0
13 | 0
14 | 0
15 | 0
16 | 0
17 | 0
18 | 0
19 | 0
20 | 0
21 | 0
22 | 0
23 | 0
24 | 0
25 | 0
26 | 0
27 | 0
28 | 0
29 | 0
30 | 0
31 | 0
32 | 0
33 | 0
34 | 0
35 | 0
36 | 0
37 | 0
38 | 0
39 | 0
40 | 0
41 | 0
42 | 0
43 | 0
44 | 0
45 | 0
46 | 0
47 | 0
48 | 0
49 | 0
50 | 0
51 | 0
52 | 0
53 | 0
54 | 0
55 | 0
56 | 0
57 | 0
58 | 0
59 | 0
60 | 0
61 | 0
62 | 0
63 | 0
64 | 0
65 | 1
66 | 1
67 | 1
68 | 1
69 | 1
70 | 1
71 | 1
72 | 1
73 | 1
74 | 1
75 | 1
76 | 1
77 | 1
78 | 1
79 | 1
80 | 1
81 | 1
82 | 1
83 | 1
84 | 1
85 | 1
86 | 1
87 | 1
88 | 1
89 | 1
90 | 1
91 | 1
92 | 1
93 | 1
94 | 1
95 | 1
96 | 1
97 | 1
98 | 1
99 | 1
100 | 1
101 | 1
102 | 1
103 | 1
104 | 1
105 | 1
106 | 1
107 | 1
108 | 1
109 | 1
110 | 1
111 | 1
112 | 1
113 | 1
114 | 1
115 | 1
116 | 1
117 | 1
118 | 1
119 | 1
120 | 1
121 | 1
122 | 1
123 | 1
124 | 1
125 | 1
126 | 1
127 | 1
128 | 1
129 | 1
130 | 1
131 | 1
132 | 1
133 | 1
134 | 1
135 | 1
136 | 1
137 | 1
138 | 1
139 | 1
140 | 1
141 | 1
142 | 1
143 | 1
144 | 1
145 | 1
146 | 1
147 | 1
148 | 1
149 | 1
150 | 1
151 | 1
152 | 1
153 | 1
154 | 1
155 | 1
156 | 1
157 | 1
158 | 1
159 | 1
160 | 1
161 | 1
162 | 1
163 | 1
164 | 1
165 | 1
166 | 1
167 | 1
168 | 1
169 | 1
170 | 1
171 | 1
172 | 1
173 | 1
174 | 1
175 | 1
176 | 1
177 | 1
178 | 1
179 | 1
180 | 1
181 | 1
182 | 1
183 | 1
184 | 1
185 | 1
186 | 1
187 | 1
188 | 1
189 | 1
190 | 1
191 | 1
192 | 1
193 |
--------------------------------------------------------------------------------
/examples/water/data/data_1/type_map.raw:
--------------------------------------------------------------------------------
1 | O
2 | H
3 |
--------------------------------------------------------------------------------
/examples/water/data/data_2/set.000/box.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_2/set.000/box.npy
--------------------------------------------------------------------------------
/examples/water/data/data_2/set.000/coord.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_2/set.000/coord.npy
--------------------------------------------------------------------------------
/examples/water/data/data_2/set.000/energy.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_2/set.000/energy.npy
--------------------------------------------------------------------------------
/examples/water/data/data_2/set.000/force.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_2/set.000/force.npy
--------------------------------------------------------------------------------
/examples/water/data/data_2/type.raw:
--------------------------------------------------------------------------------
1 | 0
2 | 0
3 | 0
4 | 0
5 | 0
6 | 0
7 | 0
8 | 0
9 | 0
10 | 0
11 | 0
12 | 0
13 | 0
14 | 0
15 | 0
16 | 0
17 | 0
18 | 0
19 | 0
20 | 0
21 | 0
22 | 0
23 | 0
24 | 0
25 | 0
26 | 0
27 | 0
28 | 0
29 | 0
30 | 0
31 | 0
32 | 0
33 | 0
34 | 0
35 | 0
36 | 0
37 | 0
38 | 0
39 | 0
40 | 0
41 | 0
42 | 0
43 | 0
44 | 0
45 | 0
46 | 0
47 | 0
48 | 0
49 | 0
50 | 0
51 | 0
52 | 0
53 | 0
54 | 0
55 | 0
56 | 0
57 | 0
58 | 0
59 | 0
60 | 0
61 | 0
62 | 0
63 | 0
64 | 0
65 | 1
66 | 1
67 | 1
68 | 1
69 | 1
70 | 1
71 | 1
72 | 1
73 | 1
74 | 1
75 | 1
76 | 1
77 | 1
78 | 1
79 | 1
80 | 1
81 | 1
82 | 1
83 | 1
84 | 1
85 | 1
86 | 1
87 | 1
88 | 1
89 | 1
90 | 1
91 | 1
92 | 1
93 | 1
94 | 1
95 | 1
96 | 1
97 | 1
98 | 1
99 | 1
100 | 1
101 | 1
102 | 1
103 | 1
104 | 1
105 | 1
106 | 1
107 | 1
108 | 1
109 | 1
110 | 1
111 | 1
112 | 1
113 | 1
114 | 1
115 | 1
116 | 1
117 | 1
118 | 1
119 | 1
120 | 1
121 | 1
122 | 1
123 | 1
124 | 1
125 | 1
126 | 1
127 | 1
128 | 1
129 | 1
130 | 1
131 | 1
132 | 1
133 | 1
134 | 1
135 | 1
136 | 1
137 | 1
138 | 1
139 | 1
140 | 1
141 | 1
142 | 1
143 | 1
144 | 1
145 | 1
146 | 1
147 | 1
148 | 1
149 | 1
150 | 1
151 | 1
152 | 1
153 | 1
154 | 1
155 | 1
156 | 1
157 | 1
158 | 1
159 | 1
160 | 1
161 | 1
162 | 1
163 | 1
164 | 1
165 | 1
166 | 1
167 | 1
168 | 1
169 | 1
170 | 1
171 | 1
172 | 1
173 | 1
174 | 1
175 | 1
176 | 1
177 | 1
178 | 1
179 | 1
180 | 1
181 | 1
182 | 1
183 | 1
184 | 1
185 | 1
186 | 1
187 | 1
188 | 1
189 | 1
190 | 1
191 | 1
192 | 1
193 |
--------------------------------------------------------------------------------
/examples/water/data/data_2/type_map.raw:
--------------------------------------------------------------------------------
1 | O
2 | H
3 |
--------------------------------------------------------------------------------
/examples/water/data/data_3/set.000/box.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_3/set.000/box.npy
--------------------------------------------------------------------------------
/examples/water/data/data_3/set.000/coord.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_3/set.000/coord.npy
--------------------------------------------------------------------------------
/examples/water/data/data_3/set.000/energy.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_3/set.000/energy.npy
--------------------------------------------------------------------------------
/examples/water/data/data_3/set.000/force.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/examples/water/data/data_3/set.000/force.npy
--------------------------------------------------------------------------------
/examples/water/data/data_3/type.raw:
--------------------------------------------------------------------------------
1 | 0
2 | 0
3 | 0
4 | 0
5 | 0
6 | 0
7 | 0
8 | 0
9 | 0
10 | 0
11 | 0
12 | 0
13 | 0
14 | 0
15 | 0
16 | 0
17 | 0
18 | 0
19 | 0
20 | 0
21 | 0
22 | 0
23 | 0
24 | 0
25 | 0
26 | 0
27 | 0
28 | 0
29 | 0
30 | 0
31 | 0
32 | 0
33 | 0
34 | 0
35 | 0
36 | 0
37 | 0
38 | 0
39 | 0
40 | 0
41 | 0
42 | 0
43 | 0
44 | 0
45 | 0
46 | 0
47 | 0
48 | 0
49 | 0
50 | 0
51 | 0
52 | 0
53 | 0
54 | 0
55 | 0
56 | 0
57 | 0
58 | 0
59 | 0
60 | 0
61 | 0
62 | 0
63 | 0
64 | 0
65 | 1
66 | 1
67 | 1
68 | 1
69 | 1
70 | 1
71 | 1
72 | 1
73 | 1
74 | 1
75 | 1
76 | 1
77 | 1
78 | 1
79 | 1
80 | 1
81 | 1
82 | 1
83 | 1
84 | 1
85 | 1
86 | 1
87 | 1
88 | 1
89 | 1
90 | 1
91 | 1
92 | 1
93 | 1
94 | 1
95 | 1
96 | 1
97 | 1
98 | 1
99 | 1
100 | 1
101 | 1
102 | 1
103 | 1
104 | 1
105 | 1
106 | 1
107 | 1
108 | 1
109 | 1
110 | 1
111 | 1
112 | 1
113 | 1
114 | 1
115 | 1
116 | 1
117 | 1
118 | 1
119 | 1
120 | 1
121 | 1
122 | 1
123 | 1
124 | 1
125 | 1
126 | 1
127 | 1
128 | 1
129 | 1
130 | 1
131 | 1
132 | 1
133 | 1
134 | 1
135 | 1
136 | 1
137 | 1
138 | 1
139 | 1
140 | 1
141 | 1
142 | 1
143 | 1
144 | 1
145 | 1
146 | 1
147 | 1
148 | 1
149 | 1
150 | 1
151 | 1
152 | 1
153 | 1
154 | 1
155 | 1
156 | 1
157 | 1
158 | 1
159 | 1
160 | 1
161 | 1
162 | 1
163 | 1
164 | 1
165 | 1
166 | 1
167 | 1
168 | 1
169 | 1
170 | 1
171 | 1
172 | 1
173 | 1
174 | 1
175 | 1
176 | 1
177 | 1
178 | 1
179 | 1
180 | 1
181 | 1
182 | 1
183 | 1
184 | 1
185 | 1
186 | 1
187 | 1
188 | 1
189 | 1
190 | 1
191 | 1
192 | 1
193 |
--------------------------------------------------------------------------------
/examples/water/data/data_3/type_map.raw:
--------------------------------------------------------------------------------
1 | O
2 | H
3 |
--------------------------------------------------------------------------------
/examples/water/mace/input.json:
--------------------------------------------------------------------------------
1 | {
2 | "_comment1": " model parameters",
3 | "model": {
4 | "type": "mace",
5 | "type_map": [
6 | "O",
7 | "H"
8 | ],
9 | "r_max": 6.0,
10 | "sel": "auto",
11 | "hidden_irreps": "64x0e",
12 | "_comment2": " that's all"
13 | },
14 |
15 | "learning_rate": {
16 | "type": "exp",
17 | "decay_steps": 5000,
18 | "start_lr": 0.001,
19 | "stop_lr": 3.51e-8,
20 | "_comment3": "that's all"
21 | },
22 |
23 | "loss": {
24 | "type": "ener",
25 | "start_pref_e": 0.02,
26 | "limit_pref_e": 1,
27 | "start_pref_f": 1000,
28 | "limit_pref_f": 1,
29 | "start_pref_v": 0,
30 | "limit_pref_v": 0,
31 | "_comment4": " that's all"
32 | },
33 |
34 | "training": {
35 | "training_data": {
36 | "systems": [
37 | "../data/data_0/",
38 | "../data/data_1/",
39 | "../data/data_2/"
40 | ],
41 | "batch_size": "auto",
42 | "_comment5": "that's all"
43 | },
44 | "validation_data": {
45 | "systems": [
46 | "../data/data_3"
47 | ],
48 | "batch_size": 1,
49 | "numb_btch": 3,
50 | "_comment6": "that's all"
51 | },
52 | "numb_steps": 1000000,
53 | "seed": 10,
54 | "disp_file": "lcurve.out",
55 | "disp_freq": 100,
56 | "save_freq": 1000,
57 | "_comment7": "that's all"
58 | },
59 |
60 | "_comment8": "that's all"
61 | }
62 |
--------------------------------------------------------------------------------
/examples/water/nequip/input.json:
--------------------------------------------------------------------------------
1 | {
2 | "_comment1": " model parameters",
3 | "model": {
4 | "type": "nequip",
5 | "type_map": [
6 | "O",
7 | "H"
8 | ],
9 | "r_max": 6.0,
10 | "sel": "auto",
11 | "l_max": 1,
12 | "_comment2": " that's all"
13 | },
14 |
15 | "learning_rate": {
16 | "type": "exp",
17 | "decay_steps": 5000,
18 | "start_lr": 0.001,
19 | "stop_lr": 3.51e-8,
20 | "_comment3": "that's all"
21 | },
22 |
23 | "loss": {
24 | "type": "ener",
25 | "start_pref_e": 0.02,
26 | "limit_pref_e": 1,
27 | "start_pref_f": 1000,
28 | "limit_pref_f": 1,
29 | "start_pref_v": 0,
30 | "limit_pref_v": 0,
31 | "_comment4": " that's all"
32 | },
33 |
34 | "training": {
35 | "training_data": {
36 | "systems": [
37 | "../data/data_0/",
38 | "../data/data_1/",
39 | "../data/data_2/"
40 | ],
41 | "batch_size": "auto",
42 | "_comment5": "that's all"
43 | },
44 | "validation_data": {
45 | "systems": [
46 | "../data/data_3"
47 | ],
48 | "batch_size": 1,
49 | "numb_btch": 3,
50 | "_comment6": "that's all"
51 | },
52 | "numb_steps": 1000000,
53 | "seed": 10,
54 | "disp_file": "lcurve.out",
55 | "disp_freq": 100,
56 | "save_freq": 1000,
57 | "_comment7": "that's all"
58 | },
59 |
60 | "_comment8": "that's all"
61 | }
62 |
--------------------------------------------------------------------------------
/noxfile.py:
--------------------------------------------------------------------------------
1 | """Nox configuration file."""
2 |
3 | from __future__ import annotations
4 |
5 | import nox
6 |
7 |
8 | @nox.session
9 | def tests(session: nox.Session) -> None:
10 | """Run test suite with pytest."""
11 | session.install(
12 | "numpy",
13 | "deepmd-kit[torch]>=3.0.0b2",
14 | "--extra-index-url",
15 | "https://download.pytorch.org/whl/cpu",
16 | )
17 | cmake_prefix_path = session.run(
18 | "python",
19 | "-c",
20 | "import torch;print(torch.utils.cmake_prefix_path)",
21 | silent=True,
22 | ).strip()
23 | session.log(f"{cmake_prefix_path=}")
24 | session.install("-e.[test]", env={"CMAKE_PREFIX_PATH": cmake_prefix_path})
25 | session.run(
26 | "pytest",
27 | "--cov",
28 | "--cov-config",
29 | "pyproject.toml",
30 | "--cov-report",
31 | "term",
32 | "--cov-report",
33 | "xml",
34 | )
35 |
--------------------------------------------------------------------------------
/op/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | file(GLOB OP_SRC edge_index.cc)
2 |
3 | add_library(deepmd_gnn MODULE ${OP_SRC})
4 | # link: libdeepmd libtorch
5 | target_link_libraries(deepmd_gnn PRIVATE ${TORCH_LIBRARIES})
6 | target_compile_definitions(
7 | deepmd_gnn
8 | PUBLIC "$<$<COMPILE_LANGUAGE:CXX>:_GLIBCXX_USE_CXX11_ABI=${OP_CXX_ABI_PT}>")
9 | if(APPLE)
10 | set_target_properties(deepmd_gnn PROPERTIES INSTALL_RPATH "@loader_path")
11 | else()
12 | set_target_properties(deepmd_gnn PROPERTIES INSTALL_RPATH "$ORIGIN")
13 | endif()
14 |
15 | if(BUILD_PY_IF)
16 | install(TARGETS deepmd_gnn DESTINATION deepmd_gnn/lib/)
17 | file(TOUCH ${CMAKE_CURRENT_BINARY_DIR}/__init__.py)
18 | install(FILES ${CMAKE_CURRENT_BINARY_DIR}/__init__.py
19 | DESTINATION deepmd_gnn/lib)
20 | else(BUILD_PY_IF)
21 | install(TARGETS deepmd_gnn DESTINATION lib/)
22 | endif(BUILD_PY_IF)
23 |
--------------------------------------------------------------------------------
/op/edge_index.cc:
--------------------------------------------------------------------------------
1 | // SPDX-License-Identifier: LGPL-3.0-or-later
2 | #include <torch/torch.h>
3 |
4 | #include <vector>
5 |
6 | torch::Tensor edge_index_kernel(const torch::Tensor &nlist_tensor,
7 | const torch::Tensor &atype_tensor,
8 | const torch::Tensor &mm_tensor) {
9 | torch::Tensor nlist_tensor_ = nlist_tensor.cpu().contiguous();
10 | torch::Tensor atype_tensor_ = atype_tensor.cpu().contiguous();
11 | torch::Tensor mm_tensor_ = mm_tensor.cpu().contiguous();
12 | if (nlist_tensor_.dim() == 2) {
13 | nlist_tensor_ =
14 | nlist_tensor_.view({1, nlist_tensor_.size(0), nlist_tensor_.size(1)});
15 | if (atype_tensor_.dim() != 1) {
16 | throw std::invalid_argument("atype_tensor must be 1D");
17 | }
18 | atype_tensor_ = atype_tensor_.view({1, atype_tensor_.size(0)});
19 | } else if (nlist_tensor_.dim() == 3) {
20 | if (atype_tensor_.dim() != 2) {
21 | throw std::invalid_argument("atype_tensor must be 2D");
22 | }
23 | } else {
24 | throw std::invalid_argument("nlist_tensor must be 2D or 3D");
25 | }
26 |
27 | const int64_t nf = nlist_tensor_.size(0);
28 | const int64_t nloc = nlist_tensor_.size(1);
29 | const int64_t nnei = nlist_tensor_.size(2);
30 | if (atype_tensor_.size(0) != nf) {
31 | throw std::invalid_argument(
32 | "atype_tensor must have the same size as nlist_tensor");
33 | }
34 | const int64_t nall = atype_tensor_.size(1);
35 | const int64_t nmm = mm_tensor_.size(0);
36 | int64_t *nlist = nlist_tensor_.view({-1}).data_ptr<int64_t>();
37 | int64_t *atype = atype_tensor_.view({-1}).data_ptr<int64_t>();
38 | int64_t *mm = mm_tensor_.view({-1}).data_ptr<int64_t>();
39 |
40 | std::vector<int64_t> edge_index;
41 | edge_index.reserve(nf * nloc * nnei * 2);
42 |
43 | for (int64_t ff = 0; ff < nf; ff++) {
44 | for (int64_t ii = 0; ii < nloc; ii++) {
45 | for (int64_t jj = 0; jj < nnei; jj++) {
46 | int64_t idx = ff * nloc * nnei + ii * nnei + jj;
47 | int64_t kk = nlist[idx];
48 | if (kk < 0) {
49 | continue;
50 | }
51 | int64_t global_kk = ff * nall + kk;
52 | int64_t global_ii = ff * nall + ii;
53 | // check if both atype[ii] and atype[kk] are in mm
54 | bool in_mm1 = false;
55 | for (int64_t mm_idx = 0; mm_idx < nmm; mm_idx++) {
56 | if (atype[global_ii] == mm[mm_idx]) {
57 | in_mm1 = true;
58 | break;
59 | }
60 | }
61 | bool in_mm2 = false;
62 | for (int64_t mm_idx = 0; mm_idx < nmm; mm_idx++) {
63 | if (atype[global_kk] == mm[mm_idx]) {
64 | in_mm2 = true;
65 | break;
66 | }
67 | }
68 | if (in_mm1 && in_mm2) {
69 | continue;
70 | }
71 | // add edge
72 | edge_index.push_back(global_kk);
73 | edge_index.push_back(global_ii);
74 | }
75 | }
76 | }
77 | // convert to tensor
78 | int64_t edge_size = edge_index.size() / 2;
79 | torch::Tensor edge_index_tensor =
80 | torch::tensor(edge_index, torch::kInt64).view({edge_size, 2});
81 | // to nlist_tensor.device
82 | return edge_index_tensor.to(nlist_tensor.device());
83 | }
84 |
// Register edge_index as a TorchScript custom op under the deepmd_gnn namespace.
TORCH_LIBRARY(deepmd_gnn, m) { m.def("edge_index", edge_index_kernel); }
// compatibility with old models frozen by the deepmd_mace package
TORCH_LIBRARY(deepmd_mace, m) { m.def("mace_edge_index", edge_index_kernel); }
88 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = [
3 | "scikit-build-core>=0.3.0",
4 | ]
5 | build-backend = "build_backend.dp_backend"
6 | backend-path = ["."]
7 |
8 | [project]
9 | name = "deepmd-gnn"
10 | dynamic = ["version"]
11 | description = "DeePMD-kit plugin for graph neural network models."
12 | authors = [
13 | { name = "Jinzhe Zeng", email = "jinzhe.zeng@rutgers.edu"},
14 | ]
15 | license = { file = "LICENSE" }
16 | classifiers = [
17 | "Programming Language :: Python :: 3.9",
18 | "Programming Language :: Python :: 3.10",
19 | "Programming Language :: Python :: 3.11",
20 | "Programming Language :: Python :: 3.12",
21 | "Operating System :: POSIX :: Linux",
22 | "Operating System :: MacOS :: MacOS X",
23 | "Operating System :: Microsoft :: Windows",
24 | "License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)",
25 | ]
26 | dependencies = [
27 | "torch",
28 | "deepmd-kit[torch]>=3.0.0b2",
29 | "mace-torch>=0.3.5",
30 | "nequip",
31 | "e3nn",
32 | "dargs",
33 | ]
34 | requires-python = ">=3.9"
35 | readme = "README.md"
36 | keywords = [
37 | ]
38 |
39 | [project.scripts]
40 |
41 | [project.entry-points."deepmd.pt"]
42 | mace = "deepmd_gnn.mace:MaceModel"
43 | nequip = "deepmd_gnn.nequip:NequipModel"
44 |
45 | [project.urls]
46 | repository = "https://gitlab.com/RutgersLBSR/deepmd-gnn"
47 |
48 | [project.optional-dependencies]
49 | test = [
50 | 'pytest',
51 | 'pytest-cov',
52 | "dargs>=0.4.8",
53 | ]
54 | docs = [
55 | "sphinx",
56 | "sphinx-autoapi",
57 | "myst-parser",
58 | "deepmodeling-sphinx>=0.3.0",
59 | "sphinx-book-theme",
60 | "dargs",
61 | ]
62 |
63 | [tool.scikit-build]
64 | wheel.py-api = "py2.py3"
65 | metadata.version.provider = "scikit_build_core.metadata.setuptools_scm"
66 | sdist.include = [
67 | "/deepmd_gnn/_version.py",
68 | ]
69 |
70 | [tool.scikit-build.cmake.define]
71 | BUILD_PY_IF = true
72 | BUILD_CPP_IF = false
73 |
74 | [tool.setuptools_scm]
75 | version_file = "deepmd_gnn/_version.py"
76 |
77 | [tool.ruff.lint]
78 | select = [
79 | "ALL",
80 | ]
81 | ignore = [
82 | "PLR0912",
83 | "PLR0913", # Too many arguments in function definition
84 | "PLR0915",
85 | "PLR2004",
86 | "FBT001",
87 | "FBT002",
88 | "N803",
89 | "FA100",
90 | "S603",
91 | "ANN101",
92 | "ANN102",
93 | "C901",
94 | "E501",
95 | ]
96 |
97 | [tool.ruff.lint.pydocstyle]
98 | convention = "numpy"
99 |
100 | [tool.ruff.lint.extend-per-file-ignores]
101 | "tests/**/*.py" = [
102 | "S101", # asserts allowed in tests...
103 | "ANN",
104 | "D101",
105 | "D102",
106 | ]
107 | "docs/conf.py" = [
108 | "ERA001",
109 | "INP001",
110 | ]
111 |
112 | [tool.coverage.report]
113 | include = ["deepmd_gnn/*"]
114 |
115 |
116 | [tool.cibuildwheel]
117 | test-command = [
118 | """python -c "import deepmd_gnn.op" """,
119 | ]
120 | build = ["cp312-*"]
121 | skip = ["*-win32", "*-manylinux_i686", "*-musllinux*"]
122 | manylinux-x86_64-image = "manylinux_2_28"
123 | manylinux-aarch64-image = "manylinux_2_28"
124 |
125 | [tool.cibuildwheel.macos]
126 | repair-wheel-command = """delocate-wheel --require-archs {delocate_archs} -w {dest_dir} -v {wheel} --ignore-missing-dependencies"""
127 |
128 | [tool.cibuildwheel.linux]
129 | repair-wheel-command = "auditwheel repair --exclude libc10.so --exclude libtorch.so --exclude libtorch_cpu.so -w {dest_dir} {wheel}"
130 | environment-pass = [
131 | "CIBW_BUILD",
132 | ]
133 | [tool.cibuildwheel.linux.environment]
134 | # use CPU version of torch for building, which should also work for GPU
135 | # note: uv has different behavior from pip on extra index url
136 | # https://github.com/astral-sh/uv/blob/main/PIP_COMPATIBILITY.md#packages-that-exist-on-multiple-indexes
137 | UV_EXTRA_INDEX_URL = "https://download.pytorch.org/whl/cpu"
138 |
--------------------------------------------------------------------------------
/renovate.json:
--------------------------------------------------------------------------------
1 | {
2 | "$schema": "https://docs.renovatebot.com/renovate-schema.json",
3 | "extends": [
4 | "local>njzjz/renovate-config"
5 | ]
6 | }
7 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests."""
2 |
--------------------------------------------------------------------------------
/tests/data/set.000/box.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/tests/data/set.000/box.npy
--------------------------------------------------------------------------------
/tests/data/set.000/coord.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/tests/data/set.000/coord.npy
--------------------------------------------------------------------------------
/tests/data/set.000/energy.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/tests/data/set.000/energy.npy
--------------------------------------------------------------------------------
/tests/data/set.000/force.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepmodeling/deepmd-gnn/2c47b8b345a369bad8bd26cfc91a6cf5cecf3818/tests/data/set.000/force.npy
--------------------------------------------------------------------------------
/tests/data/type.raw:
--------------------------------------------------------------------------------
1 | 0
2 | 0
3 | 0
4 | 0
5 | 0
6 | 0
7 | 0
8 | 0
9 | 0
10 | 0
11 | 0
12 | 0
13 | 0
14 | 0
15 | 0
16 | 0
17 | 0
18 | 0
19 | 0
20 | 0
21 | 0
22 | 0
23 | 0
24 | 0
25 | 0
26 | 0
27 | 0
28 | 0
29 | 0
30 | 0
31 | 0
32 | 0
33 | 0
34 | 0
35 | 0
36 | 0
37 | 0
38 | 0
39 | 0
40 | 0
41 | 0
42 | 0
43 | 0
44 | 0
45 | 0
46 | 0
47 | 0
48 | 0
49 | 0
50 | 0
51 | 0
52 | 0
53 | 0
54 | 0
55 | 0
56 | 0
57 | 0
58 | 0
59 | 0
60 | 0
61 | 0
62 | 0
63 | 0
64 | 0
65 | 1
66 | 1
67 | 1
68 | 1
69 | 1
70 | 1
71 | 1
72 | 1
73 | 1
74 | 1
75 | 1
76 | 1
77 | 1
78 | 1
79 | 1
80 | 1
81 | 1
82 | 1
83 | 1
84 | 1
85 | 1
86 | 1
87 | 1
88 | 1
89 | 1
90 | 1
91 | 1
92 | 1
93 | 1
94 | 1
95 | 1
96 | 1
97 | 1
98 | 1
99 | 1
100 | 1
101 | 1
102 | 1
103 | 1
104 | 1
105 | 1
106 | 1
107 | 1
108 | 1
109 | 1
110 | 1
111 | 1
112 | 1
113 | 1
114 | 1
115 | 1
116 | 1
117 | 1
118 | 1
119 | 1
120 | 1
121 | 1
122 | 1
123 | 1
124 | 1
125 | 1
126 | 1
127 | 1
128 | 1
129 | 1
130 | 1
131 | 1
132 | 1
133 | 1
134 | 1
135 | 1
136 | 1
137 | 1
138 | 1
139 | 1
140 | 1
141 | 1
142 | 1
143 | 1
144 | 1
145 | 1
146 | 1
147 | 1
148 | 1
149 | 1
150 | 1
151 | 1
152 | 1
153 | 1
154 | 1
155 | 1
156 | 1
157 | 1
158 | 1
159 | 1
160 | 1
161 | 1
162 | 1
163 | 1
164 | 1
165 | 1
166 | 1
167 | 1
168 | 1
169 | 1
170 | 1
171 | 1
172 | 1
173 | 1
174 | 1
175 | 1
176 | 1
177 | 1
178 | 1
179 | 1
180 | 1
181 | 1
182 | 1
183 | 1
184 | 1
185 | 1
186 | 1
187 | 1
188 | 1
189 | 1
190 | 1
191 | 1
192 | 1
193 |
--------------------------------------------------------------------------------
/tests/data/type_map.raw:
--------------------------------------------------------------------------------
1 | O
2 | H
3 |
--------------------------------------------------------------------------------
/tests/mace.json:
--------------------------------------------------------------------------------
1 | {
2 | "_comment1": " model parameters",
3 | "model": {
4 | "type": "mace",
5 | "type_map": [
6 | "O",
7 | "H"
8 | ],
9 | "r_max": 6.0,
10 | "sel": "auto",
11 | "hidden_irreps": "16x0e",
12 | "_comment2": " that's all"
13 | },
14 |
15 | "learning_rate": {
16 | "type": "exp",
17 | "decay_steps": 5000,
18 | "start_lr": 0.001,
19 | "stop_lr": 0.0009,
20 | "_comment3": "that's all"
21 | },
22 |
23 | "loss": {
24 | "type": "ener",
25 | "start_pref_e": 0.02,
26 | "limit_pref_e": 1,
27 | "start_pref_f": 1000,
28 | "limit_pref_f": 1,
29 | "start_pref_v": 0,
30 | "limit_pref_v": 0,
31 | "_comment4": " that's all"
32 | },
33 |
34 | "training": {
35 | "training_data": {
36 | "systems": [
37 | "./data"
38 | ],
39 | "batch_size": "auto",
40 | "_comment5": "that's all"
41 | },
42 | "validation_data": {
43 | "systems": [
44 | "./data"
45 | ],
46 | "batch_size": 1,
47 | "numb_btch": 3,
48 | "_comment6": "that's all"
49 | },
50 | "numb_steps": 2,
51 | "seed": 10,
52 | "disp_file": "lcurve.out",
53 | "disp_freq": 1,
54 | "save_freq": 1,
55 | "_comment7": "that's all"
56 | },
57 |
58 | "_comment8": "that's all"
59 | }
60 |
--------------------------------------------------------------------------------
/tests/nequip.json:
--------------------------------------------------------------------------------
1 | {
2 | "_comment1": " model parameters",
3 | "model": {
    "type": "nequip",
    "type_map": [
      "O",
      "H"
    ],
    "r_max": 6.0,
    "sel": "auto",
    "num_features": 16,
12 | "_comment2": " that's all"
13 | },
14 |
15 | "learning_rate": {
16 | "type": "exp",
17 | "decay_steps": 5000,
18 | "start_lr": 0.001,
19 | "stop_lr": 0.0009,
20 | "_comment3": "that's all"
21 | },
22 |
23 | "loss": {
24 | "type": "ener",
25 | "start_pref_e": 0.02,
26 | "limit_pref_e": 1,
27 | "start_pref_f": 1000,
28 | "limit_pref_f": 1,
29 | "start_pref_v": 0,
30 | "limit_pref_v": 0,
31 | "_comment4": " that's all"
32 | },
33 |
34 | "training": {
35 | "training_data": {
36 | "systems": [
37 | "./data"
38 | ],
39 | "batch_size": "auto",
40 | "_comment5": "that's all"
41 | },
42 | "validation_data": {
43 | "systems": [
44 | "./data"
45 | ],
46 | "batch_size": 1,
47 | "numb_btch": 3,
48 | "_comment6": "that's all"
49 | },
50 | "numb_steps": 2,
51 | "seed": 10,
52 | "disp_file": "lcurve.out",
53 | "disp_freq": 1,
54 | "save_freq": 1,
55 | "_comment7": "that's all"
56 | },
57 |
58 | "_comment8": "that's all"
59 | }
60 |
--------------------------------------------------------------------------------
/tests/test_examples.py:
--------------------------------------------------------------------------------
1 | """Test examples."""
2 |
3 | import json
4 | from pathlib import Path
5 |
6 | import pytest
7 | from dargs.check import check
8 | from deepmd.utils.argcheck import gen_args
9 |
10 | from deepmd_gnn.argcheck import mace_model_args # noqa: F401
11 |
# Root of the bundled example configurations (repository ``examples/`` dir).
example_path = Path(__file__).parent.parent / "examples"

# Every example input file that must satisfy the argument schema.
examples = tuple(
    example_path / system / model / "input.json"
    for model in ("mace", "nequip")
    for system in ("water", "dprc")
)
20 |
21 |
@pytest.mark.parametrize("example", examples)
def test_examples(example: Path) -> None:
    """Validate one example input file against the argument schema."""
    data = json.loads(example.read_text())
    check(
        gen_args(),
        data,
    )
31 |
--------------------------------------------------------------------------------
/tests/test_model.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: LGPL-3.0-or-later
2 | """Test models."""
3 |
4 | import unittest
5 | from copy import deepcopy
6 | from typing import Any, Callable, ClassVar, Optional
7 |
8 | import deepmd.pt.model # noqa: F401
9 | import numpy as np
10 | import torch
11 | from deepmd.dpmodel.output_def import (
12 | check_deriv,
13 | )
14 | from deepmd.dpmodel.utils.nlist import (
15 | build_neighbor_list,
16 | extend_coord_with_ghosts,
17 | extend_input_and_build_neighbor_list,
18 | )
19 | from deepmd.dpmodel.utils.region import (
20 | normalize_coord,
21 | )
22 | from deepmd.pt.utils.utils import (
23 | to_numpy_array,
24 | to_torch_tensor,
25 | )
26 |
27 | from deepmd_gnn.mace import MaceModel
28 | from deepmd_gnn.nequip import NequipModel
29 |
# Fixed RNG seed so every test in this module is reproducible.
GLOBAL_SEED = 20240822

# Run all torch computations in double precision; the numerical
# comparisons in these tests rely on float64 accuracy.
torch.set_default_dtype(torch.float64)
33 |
34 |
class PTTestCase:
    """Common test case.

    Mixin providing JIT-scripting, (de)serialization, and forward-wrapping
    helpers for a PyTorch module under test.  Subclasses must set
    ``module`` and provide ``skipTest`` (e.g. via ``unittest.TestCase``).
    """

    module: "torch.nn.Module"
    """PT module to test."""

    skipTest: Callable[[str], None]  # noqa: N815
    """Skip test method."""

    @property
    def script_module(self) -> torch.jit.ScriptModule:
        """Script module."""
        # Disable the JIT optimizer so scripting failures surface directly.
        with torch.jit.optimized_execution(should_optimize=False):
            return torch.jit.script(self.module)

    @property
    def deserialized_module(self) -> "torch.nn.Module":
        """Deserialized module."""
        # Round-trip through the DeePMD serialize/deserialize protocol.
        return self.module.deserialize(self.module.serialize())

    @property
    def modules_to_test(self) -> list["torch.nn.Module"]:
        """Modules to test."""
        # The original module and its serialization round-trip must behave
        # identically in every test that iterates over this list.
        return [
            self.module,
            self.deserialized_module,
        ]

    def test_jit(self) -> None:
        """Test jit."""
        if getattr(self, "skip_test_jit", False):
            self.skipTest("Skip test jit.")
        # Accessing the property is the test: it raises if scripting fails.
        self.script_module  # noqa: B018

    @classmethod
    def convert_to_numpy(cls, xx: torch.Tensor) -> np.ndarray:
        """Convert to numpy array."""
        return to_numpy_array(xx)

    @classmethod
    def convert_from_numpy(cls, xx: np.ndarray) -> torch.Tensor:
        """Convert from numpy array."""
        return to_torch_tensor(xx)

    def forward_wrapper_cpu_ref(self, module):
        """Move ``module`` to CPU and wrap it for numpy-in/numpy-out calls."""
        module.to("cpu")
        return self.forward_wrapper(module, on_cpu=True)

    def forward_wrapper(self, module, on_cpu=False):
        """Wrap ``module`` so its calls accept and return numpy arrays.

        Returns an object exposing ``__call__`` (and ``forward_lower`` when
        the wrapped module has one); each call converts numpy inputs to
        torch tensors, optionally detaches them to CPU, runs the wrapped
        method, and converts the output (tuple/dict/tensor) back to numpy.
        """
        def create_wrapper_method(method):
            # `self` below is the WrapperModule instance and is unused;
            # `method` stays bound to the original module via this closure.
            def wrapper_method(self, *args, **kwargs):  # noqa: ARG001
                # convert to torch tensor
                args = [to_torch_tensor(arg) for arg in args]
                kwargs = {k: to_torch_tensor(v) for k, v in kwargs.items()}
                if on_cpu:
                    # None entries (absent optional inputs) are passed through.
                    args = [
                        arg.detach().cpu() if arg is not None else None for arg in args
                    ]
                    kwargs = {
                        k: v.detach().cpu() if v is not None else None
                        for k, v in kwargs.items()
                    }
                # forward
                output = method(*args, **kwargs)
                # convert to numpy array
                if isinstance(output, tuple):
                    output = tuple(to_numpy_array(o) for o in output)
                elif isinstance(output, dict):
                    output = {k: to_numpy_array(v) for k, v in output.items()}
                else:
                    output = to_numpy_array(output)
                return output

            return wrapper_method

        # Built per call so the wrapper methods close over this `module`.
        class WrapperModule:
            __call__ = create_wrapper_method(module.__call__)
            if hasattr(module, "forward_lower"):
                forward_lower = create_wrapper_method(module.forward_lower)

        return WrapperModule()
116 |
117 |
118 | class ModelTestCase:
119 | """Common test case for model."""
120 |
121 | module: torch.nn.Module
122 | """Module to test."""
123 | modules_to_test: list[torch.nn.Module]
124 | """Modules to test."""
125 | expected_type_map: list[str]
126 | """Expected type map."""
127 | expected_rcut: float
128 | """Expected cut-off radius."""
129 | expected_dim_fparam: int
130 | """Expected number (dimension) of frame parameters."""
131 | expected_dim_aparam: int
132 | """Expected number (dimension) of atomic parameters."""
133 | expected_sel_type: list[int]
134 | """Expected selected atom types."""
135 | expected_aparam_nall: bool
136 | """Expected shape of atomic parameters."""
137 | expected_model_output_type: list[str]
138 | """Expected output type for the model."""
139 | model_output_equivariant: list[str]
140 | """Outputs that are equivariant to the input rotation."""
141 | expected_sel: list[int]
142 | """Expected number of neighbors."""
143 | expected_has_message_passing: bool
144 | """Expected whether having message passing."""
145 | expected_nmpnn: int
146 | """Expected number of MPNN."""
147 | forward_wrapper: ClassVar[Callable[[Any, bool], Any]]
148 | """Class wrapper for forward method."""
149 | forward_wrapper_cpu_ref: Callable[[Any], Any]
150 | """Convert model to CPU method."""
151 | aprec_dict: dict[str, Optional[float]]
152 | """Dictionary of absolute precision in each test."""
153 | rprec_dict: dict[str, Optional[float]]
154 | """Dictionary of relative precision in each test."""
155 | epsilon_dict: dict[str, Optional[float]]
156 | """Dictionary of epsilons in each test."""
157 |
158 | skipTest: Callable[[str], None] # noqa: N815
159 | """Skip test method."""
160 | output_def: dict[str, Any]
161 | """Output definition."""
162 |
163 | def test_get_type_map(self) -> None:
164 | """Test get_type_map."""
165 | for module in self.modules_to_test:
166 | assert module.get_type_map() == self.expected_type_map
167 |
168 | def test_get_rcut(self) -> None:
169 | """Test get_rcut."""
170 | for module in self.modules_to_test:
171 | assert module.get_rcut() == self.expected_rcut * self.expected_nmpnn
172 |
173 | def test_get_dim_fparam(self) -> None:
174 | """Test get_dim_fparam."""
175 | for module in self.modules_to_test:
176 | assert module.get_dim_fparam() == self.expected_dim_fparam
177 |
178 | def test_get_dim_aparam(self) -> None:
179 | """Test get_dim_aparam."""
180 | for module in self.modules_to_test:
181 | assert module.get_dim_aparam() == self.expected_dim_aparam
182 |
183 | def test_get_sel_type(self) -> None:
184 | """Test get_sel_type."""
185 | for module in self.modules_to_test:
186 | assert module.get_sel_type() == self.expected_sel_type
187 |
188 | def test_is_aparam_nall(self) -> None:
189 | """Test is_aparam_nall."""
190 | for module in self.modules_to_test:
191 | assert module.is_aparam_nall() == self.expected_aparam_nall
192 |
193 | def test_model_output_type(self) -> None:
194 | """Test model_output_type."""
195 | for module in self.modules_to_test:
196 | assert module.model_output_type() == self.expected_model_output_type
197 |
198 | def test_get_nnei(self) -> None:
199 | """Test get_nnei."""
200 | expected_nnei = sum(self.expected_sel)
201 | for module in self.modules_to_test:
202 | assert module.get_nnei() == expected_nnei
203 |
204 | def test_get_ntypes(self) -> None:
205 | """Test get_ntypes."""
206 | for module in self.modules_to_test:
207 | assert module.get_ntypes() == len(self.expected_type_map)
208 |
209 | def test_has_message_passing(self) -> None:
210 | """Test has_message_passing."""
211 | for module in self.modules_to_test:
212 | assert module.has_message_passing() == self.expected_has_message_passing
213 |
    def test_forward(self) -> None:
        """Test forward and forward_lower.

        Builds two identical frames, runs the high-level ``forward`` and the
        low-level ``forward_lower`` (with and without ``mapping``) on every
        module variant, then cross-checks: frame 0 vs frame 1, variant vs
        variant, and forward vs forward_lower on their shared output keys.
        """
        test_spin = getattr(self, "test_spin", False)
        nf = 2
        natoms = 5
        aprec = (
            0
            if self.aprec_dict.get("test_forward", None) is None
            else self.aprec_dict["test_forward"]
        )
        rng = np.random.default_rng(GLOBAL_SEED)
        # One random frame repeated nf times, so frames 0 and 1 are identical.
        coord = 4.0 * rng.random([1, natoms, 3]).repeat(nf, 0).reshape([nf, -1])
        atype = np.array([[0, 0, 0, 1, 1] * nf], dtype=int).reshape([nf, -1])
        spin = 0.5 * rng.random([1, natoms, 3]).repeat(nf, 0).reshape([nf, -1])
        cell = 6.0 * np.repeat(np.eye(3)[None, ...], nf, axis=0).reshape([nf, 9])
        # Only `mapping` from this call is used below (for spin_ext);
        # coord_ext/atype_ext/nlist are unused.
        coord_ext, atype_ext, mapping, nlist = extend_input_and_build_neighbor_list(
            coord,
            atype,
            self.expected_rcut + 1.0 if test_spin else self.expected_rcut,
            self.expected_sel,
            mixed_types=self.module.mixed_types(),
            box=cell,
        )
        coord_normalized = normalize_coord(
            coord.reshape(nf, natoms, 3),
            cell.reshape(nf, 3, 3),
        )
        # Larger ghost region: get_rcut() is the effective (stacked) cutoff,
        # so message-passing models see all required neighbors.
        coord_ext_large, atype_ext_large, mapping_large = extend_coord_with_ghosts(
            coord_normalized,
            atype,
            cell,
            self.module.get_rcut(),
        )
        nlist_large = build_neighbor_list(
            coord_ext_large,
            atype_ext_large,
            natoms,
            self.expected_rcut,
            self.expected_sel,
            distinguish_types=(not self.module.mixed_types()),
        )
        # Scatter per-atom spins onto the extended (ghost-including) atoms.
        spin_ext = np.take_along_axis(
            spin.reshape(nf, -1, 3),
            np.repeat(np.expand_dims(mapping, axis=-1), 3, axis=-1),
            axis=1,
        )
        aparam = None
        fparam = None
        if self.module.get_dim_aparam() > 0:
            aparam = rng.random([nf, natoms, self.module.get_dim_aparam()])
        if self.module.get_dim_fparam() > 0:
            fparam = rng.random([nf, self.module.get_dim_fparam()])
        ret = []
        ret_lower = []
        for _module in self.modules_to_test:
            module = self.forward_wrapper(_module)
            input_dict = {
                "coord": coord,
                "atype": atype,
                "box": cell,
                "aparam": aparam,
                "fparam": fparam,
            }
            if test_spin:
                input_dict["spin"] = spin
            ret.append(module(**input_dict))

            # forward_lower once WITH the ghost-atom mapping ...
            input_dict_lower = {
                "extended_coord": coord_ext_large,
                "extended_atype": atype_ext_large,
                "nlist": nlist_large,
                "aparam": aparam,
                "fparam": fparam,
                "mapping": mapping_large,
            }
            if test_spin:
                input_dict_lower["extended_spin"] = spin_ext

            # use shuffled nlist, simulating the lammps interface
            rng.shuffle(input_dict_lower["nlist"], axis=-1)
            ret_lower.append(module.forward_lower(**input_dict_lower))

            # ... and once WITHOUT mapping; results must agree.
            input_dict_lower = {
                "extended_coord": coord_ext_large,
                "extended_atype": atype_ext_large,
                "nlist": nlist_large,
                "aparam": aparam,
                "fparam": fparam,
            }
            if test_spin:
                input_dict_lower["extended_spin"] = spin_ext

            # use shuffled nlist, simulating the lammps interface
            rng.shuffle(input_dict_lower["nlist"], axis=-1)
            ret_lower.append(module.forward_lower(**input_dict_lower))

        for kk in ret[0]:
            # ensure the first frame and the second frame are the same
            np.testing.assert_allclose(
                ret[0][kk][0],
                ret[0][kk][1],
                err_msg=f"compare {kk} between frame 0 and 1",
            )

            # All module variants must agree with the first one.
            subret = [rr[kk] for rr in ret if rr is not None]
            if subret:
                for ii, rr in enumerate(subret[1:]):
                    if subret[0] is None:
                        assert rr is None
                    else:
                        # NOTE(review): `ii` indexes subret[1:], so the
                        # message reports an index that is off by one.
                        np.testing.assert_allclose(
                            subret[0],
                            rr,
                            err_msg=f"compare {kk} between 0 and {ii}",
                        )
        for kk in ret_lower[0]:
            subret = []
            for rr in ret_lower:
                if rr is not None:
                    subret.append(rr[kk])
            if len(subret):
                for ii, rr in enumerate(subret[1:]):
                    if kk == "expanded_force":
                        # use mapping to scatter sum the forces
                        rr = np.take_along_axis(  # noqa: PLW2901
                            rr,
                            np.repeat(
                                np.expand_dims(mapping_large, axis=-1),
                                3,
                                axis=-1,
                            ),
                            axis=1,
                        )
                    if subret[0] is None:
                        assert rr is None
                    else:
                        np.testing.assert_allclose(
                            subret[0],
                            rr,
                            atol=1e-5,
                            err_msg=f"compare {kk} between 0 and {ii}",
                        )
        # Finally, forward and forward_lower must agree on shared keys.
        same_keys = set(ret[0].keys()) & set(ret_lower[0].keys())
        assert same_keys
        for key in same_keys:
            # Pick the first non-None value from each side; skip the key
            # entirely if one side never produced it (for-else).
            for rr in ret:
                if rr[key] is not None:
                    rr1 = rr[key]
                    break
            else:
                continue
            for rr in ret_lower:
                if rr[key] is not None:
                    rr2 = rr[key]
                    break
            else:
                continue
            np.testing.assert_allclose(rr1, rr2, atol=aprec)
372 |
    def test_permutation(self) -> None:
        """Test permutation.

        Permutes atoms with ``idx_perm`` and checks that atomic outputs are
        permuted the same way while global outputs are unchanged.
        """
        if getattr(self, "skip_test_permutation", False):
            self.skipTest("Skip test permutation.")
        test_spin = getattr(self, "test_spin", False)
        rng = np.random.default_rng(GLOBAL_SEED)
        natoms = 5
        nf = 1
        aprec = (
            0
            if self.aprec_dict.get("test_permutation", None) is None
            else self.aprec_dict["test_permutation"]
        )
        idx = [0, 1, 2, 3, 4]
        idx_perm = [1, 0, 4, 3, 2]
        # Random symmetric, well-conditioned cell; coords in fractional→cartesian.
        cell = rng.random([3, 3])
        cell = (cell + cell.T) + 5.0 * np.eye(3)
        coord = rng.random([natoms, 3])
        coord = np.matmul(coord, cell)
        spin = 0.1 * rng.random([natoms, 3])
        atype = np.array([0, 0, 0, 1, 1])
        coord_perm = coord[idx_perm]
        spin_perm = spin[idx_perm]
        atype_perm = atype[idx_perm]

        # reshape for input
        # NOTE(review): `spin` itself is not reshaped here (unlike test_trans);
        # this only matters when test_spin is set — verify if spin is enabled.
        coord = coord.reshape([nf, -1])
        coord_perm = coord_perm.reshape([nf, -1])
        spin_perm = spin_perm.reshape([nf, -1])
        atype = atype.reshape([nf, -1])
        atype_perm = atype_perm.reshape([nf, -1])
        cell = cell.reshape([nf, 9])

        aparam = None
        fparam = None
        aparam_perm = None
        if self.module.get_dim_aparam() > 0:
            aparam = rng.random([nf, natoms, self.module.get_dim_aparam()])
            aparam_perm = aparam[:, idx_perm, :]
        if self.module.get_dim_fparam() > 0:
            fparam = rng.random([nf, self.module.get_dim_fparam()])

        ret = []
        module = self.forward_wrapper(self.module)
        input_dict = {
            "coord": coord,
            "atype": atype,
            "box": cell,
            "aparam": aparam,
            "fparam": fparam,
        }
        if test_spin:
            input_dict["spin"] = spin
        ret.append(module(**input_dict))
        # permutation
        input_dict["coord"] = coord_perm
        input_dict["atype"] = atype_perm
        input_dict["aparam"] = aparam_perm
        if test_spin:
            input_dict["spin"] = spin_perm
        ret.append(module(**input_dict))

        for kk in ret[0]:
            if kk in self.output_def:
                if ret[0][kk] is None:
                    assert ret[1][kk] is None
                    continue
                atomic = self.output_def[kk].atomic
                if atomic:
                    # Per-atom outputs: permuting the input must permute
                    # the output identically.
                    np.testing.assert_allclose(
                        ret[0][kk][:, idx_perm],
                        ret[1][kk][:, idx],  # for extended output
                        err_msg=f"compare {kk} before and after transform",
                        atol=aprec,
                    )
                else:
                    # Global outputs: must be permutation invariant.
                    np.testing.assert_allclose(
                        ret[0][kk],
                        ret[1][kk],
                        err_msg=f"compare {kk} before and after transform",
                        atol=aprec,
                    )
            else:
                msg = f"Unknown output key: {kk}"
                raise RuntimeError(msg)
458 |
459 | def test_trans(self) -> None:
460 | """Test translation."""
461 | if getattr(self, "skip_test_trans", False):
462 | self.skipTest("Skip test translation.")
463 | test_spin = getattr(self, "test_spin", False)
464 | rng = np.random.default_rng(GLOBAL_SEED)
465 | natoms = 5
466 | nf = 1
467 | aprec = (
468 | 1e-14
469 | if self.aprec_dict.get("test_rot", None) is None
470 | else self.aprec_dict["test_rot"]
471 | )
472 | cell = rng.random([3, 3])
473 | cell = (cell + cell.T) + 5.0 * np.eye(3)
474 | coord = rng.random([natoms, 3])
475 | coord = np.matmul(coord, cell)
476 | spin = 0.1 * rng.random([natoms, 3])
477 | atype = np.array([0, 0, 0, 1, 1])
478 | shift = (rng.random([3]) - 0.5) * 2.0
479 | coord_s = np.matmul(
480 | np.remainder(np.matmul(coord + shift, np.linalg.inv(cell)), 1.0),
481 | cell,
482 | )
483 |
484 | # reshape for input
485 | coord = coord.reshape([nf, -1])
486 | spin = spin.reshape([nf, -1])
487 | coord_s = coord_s.reshape([nf, -1])
488 | atype = atype.reshape([nf, -1])
489 | cell = cell.reshape([nf, 9])
490 |
491 | aparam = None
492 | fparam = None
493 | if self.module.get_dim_aparam() > 0:
494 | aparam = rng.random([nf, natoms, self.module.get_dim_aparam()])
495 | if self.module.get_dim_fparam() > 0:
496 | fparam = rng.random([nf, self.module.get_dim_fparam()])
497 |
498 | ret = []
499 | module = self.forward_wrapper(self.module)
500 | input_dict = {
501 | "coord": coord,
502 | "atype": atype,
503 | "box": cell,
504 | "aparam": aparam,
505 | "fparam": fparam,
506 | }
507 | if test_spin:
508 | input_dict["spin"] = spin
509 | ret.append(module(**input_dict))
510 | # translation
511 | input_dict["coord"] = coord_s
512 | ret.append(module(**input_dict))
513 |
514 | for kk in ret[0]:
515 | if kk in self.output_def:
516 | if ret[0][kk] is None:
517 | assert ret[1][kk] is None
518 | continue
519 | np.testing.assert_allclose(
520 | ret[0][kk],
521 | ret[1][kk],
522 | err_msg=f"compare {kk} before and after transform",
523 | atol=aprec,
524 | )
525 | else:
526 | msg = f"Unknown output key: {kk}"
527 | raise RuntimeError(msg)
528 |
    def test_rot(self) -> None:
        """Test rotation.

        Two scenarios: (1) rotate only the coordinates inside a fixed cubic
        cell; (2) rotate both coordinates and cell.  Scalar outputs must be
        invariant; derivative/equivariant outputs must transform with the
        rotation matrix (rank-1 as ``v @ R``, rank-2 as ``R^T M R``).
        """
        if getattr(self, "skip_test_rot", False):
            self.skipTest("Skip test rotation.")
        test_spin = getattr(self, "test_spin", False)
        rng = np.random.default_rng(GLOBAL_SEED)
        natoms = 5
        nf = 1
        aprec = (
            0
            if self.aprec_dict.get("test_rot", None) is None
            else self.aprec_dict["test_rot"]
        )
        # rotate only coord and shift to the center of cell
        cell = 10.0 * np.eye(3)
        coord = 2.0 * rng.random([natoms, 3])
        spin = 0.1 * rng.random([natoms, 3])
        atype = np.array([0, 0, 0, 1, 1])
        shift = np.array([4.0, 4.0, 4.0])
        from scipy.stats import (
            special_ortho_group,
        )

        # Random proper rotation matrix (uniform over SO(3)).
        rmat = special_ortho_group.rvs(3)
        coord_rot = np.matmul(coord, rmat)
        spin_rot = np.matmul(spin, rmat)

        # reshape for input
        coord = (coord + shift).reshape([nf, -1])
        spin = spin.reshape([nf, -1])
        coord_rot = (coord_rot + shift).reshape([nf, -1])
        spin_rot = spin_rot.reshape([nf, -1])
        atype = atype.reshape([nf, -1])
        cell = cell.reshape([nf, 9])

        aparam = None
        fparam = None
        if self.module.get_dim_aparam() > 0:
            aparam = rng.random([nf, natoms, self.module.get_dim_aparam()])
        if self.module.get_dim_fparam() > 0:
            fparam = rng.random([nf, self.module.get_dim_fparam()])

        ret = []
        module = self.forward_wrapper(self.module)
        input_dict = {
            "coord": coord,
            "atype": atype,
            "box": cell,
            "aparam": aparam,
            "fparam": fparam,
        }
        if test_spin:
            input_dict["spin"] = spin
        ret.append(module(**input_dict))
        # rotation
        input_dict["coord"] = coord_rot
        if test_spin:
            input_dict["spin"] = spin_rot
        ret.append(module(**input_dict))

        for kk in ret[0]:
            if kk in self.output_def:
                if ret[0][kk] is None:
                    assert ret[1][kk] is None
                    continue
                rot_equivariant = (
                    check_deriv(self.output_def[kk])
                    or kk in self.model_output_equivariant
                )
                if not rot_equivariant:
                    # Scalar output: must be rotation invariant.
                    np.testing.assert_allclose(
                        ret[0][kk],
                        ret[1][kk],
                        err_msg=f"compare {kk} before and after transform",
                        atol=aprec,
                    )
                else:
                    v_size = self.output_def[kk].size
                    if v_size == 3:
                        # Rank-1 (vector) output: rotate as v @ R.
                        rotated_ret_0 = np.matmul(ret[0][kk], rmat)
                        ret_1 = ret[1][kk]
                    elif v_size == 9:
                        # Rank-2 (3x3 tensor) output: rotate as R^T M R,
                        # batched over all leading dimensions.
                        ret_0 = ret[0][kk].reshape(-1, 3, 3)
                        batch_rmat_t = np.repeat(
                            rmat.T.reshape(1, 3, 3),
                            ret_0.shape[0],
                            axis=0,
                        )
                        batch_rmat = np.repeat(
                            rmat.reshape(1, 3, 3),
                            ret_0.shape[0],
                            axis=0,
                        )
                        rotated_ret_0 = np.matmul(
                            batch_rmat_t,
                            np.matmul(ret_0, batch_rmat),
                        )
                        ret_1 = ret[1][kk].reshape(-1, 3, 3)
                    else:
                        # unsupported dim
                        continue
                    np.testing.assert_allclose(
                        rotated_ret_0,
                        ret_1,
                        err_msg=f"compare {kk} before and after transform",
                        atol=aprec,
                    )
            else:
                msg = f"Unknown output key: {kk}"
                raise RuntimeError(msg)

        # rotate coord and cell
        cell = rng.random([3, 3])
        cell = (cell + cell.T) + 5.0 * np.eye(3)
        coord = rng.random([natoms, 3])
        coord = np.matmul(coord, cell)
        spin = 0.1 * rng.random([natoms, 3])
        atype = np.array([0, 0, 0, 1, 1])
        coord_rot = np.matmul(coord, rmat)
        cell_rot = np.matmul(cell, rmat)
        spin_rot = np.matmul(spin, rmat)

        # reshape for input
        coord = coord.reshape([nf, -1])
        spin = spin.reshape([nf, -1])
        coord_rot = coord_rot.reshape([nf, -1])
        spin_rot = spin_rot.reshape([nf, -1])
        atype = atype.reshape([nf, -1])
        cell = cell.reshape([nf, 9])
        cell_rot = cell_rot.reshape([nf, 9])

        ret = []
        module = self.forward_wrapper(self.module)
        input_dict = {
            "coord": coord,
            "atype": atype,
            "box": cell,
            "aparam": aparam,
            "fparam": fparam,
        }
        if test_spin:
            input_dict["spin"] = spin
        ret.append(module(**input_dict))
        # rotation
        input_dict["coord"] = coord_rot
        input_dict["box"] = cell_rot
        if test_spin:
            input_dict["spin"] = spin_rot
        ret.append(module(**input_dict))

        # Same comparison as above, now for the rotated-cell scenario.
        for kk in ret[0]:
            if kk in self.output_def:
                if ret[0][kk] is None:
                    assert ret[1][kk] is None
                    continue
                rot_equivariant = (
                    check_deriv(self.output_def[kk])
                    or kk in self.model_output_equivariant
                )
                if not rot_equivariant:
                    np.testing.assert_allclose(
                        ret[0][kk],
                        ret[1][kk],
                        err_msg=f"compare {kk} before and after transform",
                        atol=aprec,
                    )
                else:
                    v_size = self.output_def[kk].size
                    if v_size == 3:
                        rotated_ret_0 = np.matmul(ret[0][kk], rmat)
                        ret_1 = ret[1][kk]
                    elif v_size == 9:
                        ret_0 = ret[0][kk].reshape(-1, 3, 3)
                        batch_rmat_t = np.repeat(
                            rmat.T.reshape(1, 3, 3),
                            ret_0.shape[0],
                            axis=0,
                        )
                        batch_rmat = np.repeat(
                            rmat.reshape(1, 3, 3),
                            ret_0.shape[0],
                            axis=0,
                        )
                        rotated_ret_0 = np.matmul(
                            batch_rmat_t,
                            np.matmul(ret_0, batch_rmat),
                        )
                        ret_1 = ret[1][kk].reshape(-1, 3, 3)
                    else:
                        # unsupported dim
                        continue
                    np.testing.assert_allclose(
                        rotated_ret_0,
                        ret_1,
                        err_msg=f"compare {kk} before and after transform",
                        atol=aprec,
                    )
            else:
                msg = f"Unknown output key: {kk}"
                raise RuntimeError(msg)
729 |
730 | def test_smooth(self) -> None:
731 | """Test smooth."""
732 | if getattr(self, "skip_test_smooth", False):
733 | self.skipTest("Skip test smooth.")
734 | test_spin = getattr(self, "test_spin", False)
735 | rng = np.random.default_rng(GLOBAL_SEED)
736 | epsilon = (
737 | 1e-5
738 | if self.epsilon_dict.get("test_smooth", None) is None
739 | else self.epsilon_dict["test_smooth"]
740 | )
741 | assert epsilon is not None
742 | # required prec.
743 | rprec = (
744 | 1e-5
745 | if self.rprec_dict.get("test_smooth", None) is None
746 | else self.rprec_dict["test_smooth"]
747 | )
748 | aprec = (
749 | 1e-5
750 | if self.aprec_dict.get("test_smooth", None) is None
751 | else self.aprec_dict["test_smooth"]
752 | )
753 | natoms = 10
754 | nf = 1
755 | cell = 10.0 * np.eye(3)
756 | atype0 = np.arange(2)
757 | atype1 = rng.integers(0, 2, size=natoms - 2)
758 | atype = np.concatenate([atype0, atype1]).reshape(natoms)
759 | spin = 0.1 * rng.random([natoms, 3])
760 | coord0 = np.array(
761 | [
762 | 0.0,
763 | 0.0,
764 | 0.0,
765 | self.expected_rcut - 0.5 * epsilon,
766 | 0.0,
767 | 0.0,
768 | 0.0,
769 | self.expected_rcut - 0.5 * epsilon,
770 | 0.0,
771 | ],
772 | ).reshape(-1, 3)
773 | coord1 = rng.random([natoms - coord0.shape[0], 3])
774 | coord1 = np.matmul(coord1, cell)
775 | coord = np.concatenate([coord0, coord1], axis=0)
776 |
777 | coord0 = deepcopy(coord)
778 | coord1 = deepcopy(coord)
779 | coord1[1][0] += epsilon
780 | coord2 = deepcopy(coord)
781 | coord2[2][1] += epsilon
782 | coord3 = deepcopy(coord)
783 | coord3[1][0] += epsilon
784 | coord3[2][1] += epsilon
785 |
786 | # reshape for input
787 | coord0 = coord0.reshape([nf, -1])
788 | coord1 = coord1.reshape([nf, -1])
789 | coord2 = coord2.reshape([nf, -1])
790 | coord3 = coord3.reshape([nf, -1])
791 | spin = spin.reshape([nf, -1])
792 | atype = atype.reshape([nf, -1])
793 | cell = cell.reshape([nf, 9])
794 |
795 | aparam = None
796 | fparam = None
797 | if self.module.get_dim_aparam() > 0:
798 | aparam = rng.random([nf, natoms, self.module.get_dim_aparam()])
799 | if self.module.get_dim_fparam() > 0:
800 | fparam = rng.random([nf, self.module.get_dim_fparam()])
801 |
802 | ret = []
803 | module = self.forward_wrapper(self.module)
804 | input_dict = {"atype": atype, "box": cell, "aparam": aparam, "fparam": fparam}
805 | if test_spin:
806 | input_dict["spin"] = spin
807 | # coord0
808 | input_dict["coord"] = coord0
809 | ret.append(module(**input_dict))
810 | # coord1
811 | input_dict["coord"] = coord1
812 | ret.append(module(**input_dict))
813 | # coord2
814 | input_dict["coord"] = coord2
815 | ret.append(module(**input_dict))
816 | # coord3
817 | input_dict["coord"] = coord3
818 | ret.append(module(**input_dict))
819 |
820 | for kk in ret[0]:
821 | if kk in self.output_def:
822 | if ret[0][kk] is None:
823 | for ii in range(len(ret) - 1):
824 | assert ret[ii + 1][kk] is None
825 | continue
826 | for ii in range(len(ret) - 1):
827 | np.testing.assert_allclose(
828 | ret[0][kk],
829 | ret[ii + 1][kk],
830 | err_msg=f"compare {kk} before and after transform",
831 | atol=aprec,
832 | rtol=rprec,
833 | )
834 | else:
835 | msg = f"Unknown output key: {kk}"
836 | raise RuntimeError(msg)
837 |
    def test_autodiff(self) -> None:
        """Test autodiff.

        Checks that the model's analytic derivatives agree with central
        finite differences of the energy: force vs. d(energy)/d(coord),
        magnetic force vs. d(energy)/d(spin) (spin models only), and
        virial vs. the cell-derivative contraction (non-spin models only).
        """
        if getattr(self, "skip_test_autodiff", False):
            self.skipTest("Skip test autodiff.")
        test_spin = getattr(self, "test_spin", False)

        # decimal places for assert_almost_equal, and FD step size
        places = 4
        delta = 1e-5

        def finite_difference(f, x, delta=1e-6):
            # Central finite difference of f at x, one perturbation per
            # element of x; result has shape f(x).shape + x.shape.
            in_shape = x.shape
            y0 = f(x)
            out_shape = y0.shape
            res = np.empty(out_shape + in_shape)
            for idx in np.ndindex(*in_shape):
                diff = np.zeros(in_shape)
                diff[idx] += delta
                y1p = f(x + diff)
                y1n = f(x - diff)
                res[(Ellipsis, *idx)] = (y1p - y1n) / (2 * delta)
            return res

        def stretch_box(old_coord, old_box, new_box):
            # Map coordinates from fractional positions in old_box into
            # new_box, i.e. deform atoms affinely with the cell.
            ocoord = old_coord.reshape(-1, 3)
            obox = old_box.reshape(3, 3)
            nbox = new_box.reshape(3, 3)
            ncoord = ocoord @ np.linalg.inv(obox) @ nbox
            return ncoord.reshape(old_coord.shape)

        rng = np.random.default_rng(GLOBAL_SEED)
        natoms = 5
        nf = 1
        # random but well-conditioned (diagonally dominant) cell
        cell = rng.random([3, 3])
        cell = (cell + cell.T) + 5.0 * np.eye(3)
        coord = rng.random([natoms, 3])
        coord = np.matmul(coord, cell)
        spin = 0.1 * rng.random([natoms, 3])
        atype = np.array([0, 0, 0, 1, 1])

        # reshape for input
        coord = coord.reshape([nf, -1])
        spin = spin.reshape([nf, -1])
        atype = atype.reshape([nf, -1])
        cell = cell.reshape([nf, 9])

        aparam = None
        fparam = None
        if self.module.get_dim_aparam() > 0:
            aparam = rng.random([nf, natoms, self.module.get_dim_aparam()])
        if self.module.get_dim_fparam() > 0:
            fparam = rng.random([nf, self.module.get_dim_fparam()])

        module = self.forward_wrapper(self.module)

        # only test force and virial for energy model
        def ff_coord(_coord):
            # energy as a function of coordinates (everything else fixed)
            input_dict = {
                "coord": _coord,
                "atype": atype,
                "box": cell,
                "aparam": aparam,
                "fparam": fparam,
            }
            if test_spin:
                input_dict["spin"] = spin
            return module(**input_dict)["energy"]

        def ff_spin(_spin):
            # energy as a function of spins (everything else fixed)
            input_dict = {
                "coord": coord,
                "atype": atype,
                "box": cell,
                "aparam": aparam,
                "fparam": fparam,
            }
            if test_spin:
                input_dict["spin"] = _spin
            return module(**input_dict)["energy"]

        # force = -dE/dx, hence the leading minus sign
        fdf = -finite_difference(ff_coord, coord, delta=delta).squeeze()
        input_dict = {
            "coord": coord,
            "atype": atype,
            "box": cell,
            "aparam": aparam,
            "fparam": fparam,
        }
        if test_spin:
            input_dict["spin"] = spin
        rff = module(**input_dict)["force"]
        np.testing.assert_almost_equal(
            fdf.reshape(-1, 3),
            rff.reshape(-1, 3),
            decimal=places,
        )

        if test_spin:
            # magnetic force
            fdf = -finite_difference(ff_spin, spin, delta=delta).squeeze()
            rff = module(**input_dict)["force_mag"]
            np.testing.assert_almost_equal(
                fdf.reshape(-1, 3),
                rff.reshape(-1, 3),
                decimal=places,
            )

        if not test_spin:

            def ff_cell(bb):
                # energy as a function of the cell, with coordinates
                # stretched along with the cell deformation
                input_dict = {
                    "coord": stretch_box(coord, cell, bb),
                    "atype": atype,
                    "box": bb,
                    "aparam": aparam,
                    "fparam": fparam,
                }
                return module(**input_dict)["energy"]

            # virial = -(dE/dh)^T @ h, with h the 3x3 cell matrix
            fdv = (
                -(
                    finite_difference(ff_cell, cell, delta=delta)
                    .reshape(-1, 3, 3)
                    .transpose(0, 2, 1)
                    @ cell.reshape(-1, 3, 3)
                )
                .squeeze()
                .reshape(9)
            )
            input_dict = {
                "coord": stretch_box(coord, cell, cell),
                "atype": atype,
                "box": cell,
                "aparam": aparam,
                "fparam": fparam,
            }
            rfv = module(**input_dict)["virial"]
            np.testing.assert_almost_equal(
                fdv.reshape(-1, 9),
                rfv.reshape(-1, 9),
                decimal=places,
            )
        else:
            # not support virial by far
            pass
982 |
983 | def test_device_consistence(self) -> None:
984 | """Test forward consistency between devices."""
985 | test_spin = getattr(self, "test_spin", False)
986 | nf = 1
987 | natoms = 5
988 | rng = np.random.default_rng(GLOBAL_SEED)
989 | coord = 4.0 * rng.random([natoms, 3]).reshape([nf, -1])
990 | atype = np.array([0, 0, 0, 1, 1], dtype=int).reshape([nf, -1])
991 | spin = 0.5 * rng.random([natoms, 3]).reshape([nf, -1])
992 | cell = 6.0 * np.eye(3).reshape([nf, 9])
993 | aparam = None
994 | fparam = None
995 | if self.module.get_dim_aparam() > 0:
996 | aparam = rng.random([nf, natoms, self.module.get_dim_aparam()])
997 | if self.module.get_dim_fparam() > 0:
998 | fparam = rng.random([nf, self.module.get_dim_fparam()])
999 | ret = []
1000 | device_module = self.forward_wrapper(self.module)
1001 | ref_module = self.forward_wrapper_cpu_ref(deepcopy(self.module))
1002 |
1003 | for module in [device_module, ref_module]:
1004 | input_dict = {
1005 | "coord": coord,
1006 | "atype": atype,
1007 | "box": cell,
1008 | "aparam": aparam,
1009 | "fparam": fparam,
1010 | }
1011 | if test_spin:
1012 | input_dict["spin"] = spin
1013 | ret.append(module(**input_dict))
1014 | for kk in ret[0]:
1015 | subret = [rr[kk] for rr in ret if rr is not None]
1016 | if subret:
1017 | for ii, rr in enumerate(subret[1:]):
1018 | if subret[0] is None:
1019 | assert rr is None
1020 | else:
1021 | np.testing.assert_allclose(
1022 | subret[0],
1023 | rr,
1024 | err_msg=f"compare {kk} between 0 and {ii}",
1025 | atol=1e-10,
1026 | )
1027 |
1028 |
class EnerModelTest(ModelTestCase):
    """Shared expected properties for the energy-model test classes below."""

    @classmethod
    def setUpClass(cls) -> None:
        """Set the baseline expectations checked by ModelTestCase."""
        cls.expected_rcut = 5.0
        cls.expected_type_map = ["O", "H"]
        cls.expected_dim_fparam = 0
        cls.expected_dim_aparam = 0
        cls.expected_sel_type = [0, 1]
        cls.expected_aparam_nall = False
        cls.expected_model_output_type = ["energy"]
        # no rotation-equivariant extra outputs beyond the derivatives
        cls.model_output_equivariant = []
        cls.expected_sel = [46, 92]
        # mixed-type sel is the total neighbor count over all types
        cls.expected_sel_mix = sum(cls.expected_sel)  # type: ignore[attr-defined]
        cls.expected_has_message_passing = False
        # per-test precision/epsilon overrides; empty means use defaults
        cls.aprec_dict = {}
        cls.rprec_dict = {}
        cls.epsilon_dict = {}
1046 |
1047 |
class TestMaceModel(unittest.TestCase, EnerModelTest, PTTestCase):  # type: ignore[misc]
    """Test MACE model."""

    @property
    def modules_to_test(self) -> list[torch.nn.Module]:  # type: ignore[override]
        """Modules to test: the eager module plus, unless skipped, its scripted form."""
        skip_test_jit = getattr(self, "skip_test_jit", False)
        modules = PTTestCase.modules_to_test.fget(self)  # type: ignore[attr-defined]
        if not skip_test_jit:
            # for Model, we can test script module API
            modules += [
                self._script_module
                if hasattr(self, "_script_module")
                else self.script_module,
            ]
        return modules

    # TorchScript version of cls.module, compiled once in setUpClass
    _script_module: torch.jit.ScriptModule

    @classmethod
    def setUpClass(cls) -> None:
        """Set up class: build a float64 MACE model and its scripted twin."""
        EnerModelTest.setUpClass()

        torch.manual_seed(GLOBAL_SEED + 1)
        cls.module = MaceModel(
            type_map=cls.expected_type_map,
            # 138 == sum of the per-type sel [46, 92] set in EnerModelTest
            sel=138,
            precision="float64",
        )
        # disable the JIT optimizer so scripting stays deterministic/fast
        with torch.jit.optimized_execution(should_optimize=False):
            cls._script_module = torch.jit.script(cls.module)
        cls.output_def = cls.module.translated_output_def()
        # NOTE(review): reported as no message passing even though MACE is a
        # message-passing architecture — presumably because neighbor exchange
        # is handled inside the extended region; confirm against MaceModel.
        cls.expected_has_message_passing = False
        cls.expected_sel_type = []
        cls.expected_dim_fparam = 0
        cls.expected_dim_aparam = 0
        cls.expected_nmpnn = 2
1086 |
1087 |
class TestNequipModel(unittest.TestCase, EnerModelTest, PTTestCase):  # type: ignore[misc]
    """Test Nequip model."""

    @property
    def modules_to_test(self) -> list[torch.nn.Module]:  # type: ignore[override]
        """Modules to test: the eager module plus, unless skipped, its scripted form."""
        skip_test_jit = getattr(self, "skip_test_jit", False)
        modules = PTTestCase.modules_to_test.fget(self)  # type: ignore[attr-defined]
        if not skip_test_jit:
            # for Model, we can test script module API
            modules += [
                self._script_module
                if hasattr(self, "_script_module")
                else self.script_module,
            ]
        return modules

    # TorchScript version of cls.module, compiled once in setUpClass
    _script_module: torch.jit.ScriptModule

    @classmethod
    def setUpClass(cls) -> None:
        """Set up class: build a float64 NequIP model and its scripted twin."""
        EnerModelTest.setUpClass()

        torch.manual_seed(GLOBAL_SEED + 1)
        cls.module = NequipModel(
            type_map=cls.expected_type_map,
            # 138 == sum of the per-type sel [46, 92] set in EnerModelTest
            sel=138,
            r_max=cls.expected_rcut,
            num_layers=2,
            precision="float64",
        )
        # disable the JIT optimizer so scripting stays deterministic/fast
        with torch.jit.optimized_execution(should_optimize=False):
            cls._script_module = torch.jit.script(cls.module)
        cls.output_def = cls.module.translated_output_def()
        # NOTE(review): reported as no message passing despite num_layers=2 —
        # presumably neighbor exchange stays within the extended region;
        # confirm against NequipModel.
        cls.expected_has_message_passing = False
        cls.expected_sel_type = []
        cls.expected_dim_fparam = 0
        cls.expected_dim_aparam = 0
        cls.expected_nmpnn = 2
1128 |
--------------------------------------------------------------------------------
/tests/test_op.py:
--------------------------------------------------------------------------------
1 | """Test custom operations."""
2 |
3 | import torch
4 |
5 | import deepmd_gnn.op # noqa: F401
6 |
7 |
def test_one_frame() -> None:
    """Check the edge_index op on a single (unbatched) frame."""
    device = "cpu"
    # neighbor list: each of the 3 atoms sees the other two; -1 pads
    nlist_ff = torch.tensor(
        [[1, 2, -1, -1], [2, 0, -1, -1], [0, 1, -1, -1]],
        dtype=torch.int64,
        device=device,
    )
    extended_atype_ff = torch.tensor([0, 1, 2], dtype=torch.int64, device=device)
    # types 1 and 2 are treated as MM atoms
    mm_types = torch.tensor([1, 2], dtype=torch.int64, device=device)
    # edges between MM atoms are dropped, leaving only pairs involving atom 0
    expected_edge_index = torch.tensor(
        [[1, 0], [2, 0], [0, 1], [0, 2]],
        dtype=torch.int64,
        device=device,
    )

    edge_index = torch.ops.deepmd_gnn.edge_index(
        nlist_ff,
        extended_atype_ff,
        mm_types,
    )

    assert torch.equal(edge_index, expected_edge_index)
43 |
44 |
def test_two_frame() -> None:
    """Test two frames."""
    # two identical frames of 3 atoms each
    nlist = torch.tensor(
        [
            [
                [1, 2, -1, -1],
                [2, 0, -1, -1],
                [0, 1, -1, -1],
            ],
            [
                [1, 2, -1, -1],
                [2, 0, -1, -1],
                [0, 1, -1, -1],
            ],
        ],
        dtype=torch.int64,
        device="cpu",
    )
    extended_atype = torch.tensor(
        [
            [0, 1, 2],
            [0, 1, 2],
        ],
        dtype=torch.int64,
        device="cpu",
    )
    # types 1 and 2 are treated as MM atoms
    mm_types = [1, 2]
    # second frame's atom indices are offset by 3 (the per-frame atom count)
    expected_edge_index = torch.tensor(
        [
            [1, 0],
            [2, 0],
            [0, 1],
            [0, 2],
            [4, 3],
            [5, 3],
            [3, 4],
            [3, 5],
        ],
        dtype=torch.int64,
        device="cpu",
    )

    edge_index = torch.ops.deepmd_gnn.edge_index(
        nlist,
        extended_atype,
        torch.tensor(mm_types, dtype=torch.int64, device="cpu"),
    )

    assert torch.equal(edge_index, expected_edge_index)
94 |
--------------------------------------------------------------------------------
/tests/test_training.py:
--------------------------------------------------------------------------------
1 | """Test training."""
2 |
3 | import shutil
4 | import subprocess
5 | import sys
6 | import tempfile
7 | from pathlib import Path
8 |
9 | import pytest
10 |
11 |
@pytest.mark.parametrize(
    "input_fn",
    [
        "mace.json",
        "nequip.json",
    ],
)
def test_e2e_training(input_fn) -> None:
    """End-to-end test: train a model from an example input, then freeze it.

    Runs ``deepmd --pt train`` and ``deepmd --pt freeze`` as subprocesses in
    a temporary working directory seeded with the test data and input file.
    """
    model_fn = "model.pth"
    this_dir = Path(__file__).parent
    # create temp directory and copy example files
    with tempfile.TemporaryDirectory() as _tmpdir:
        workdir = Path(_tmpdir)
        # stage data and the input script into the isolated working dir
        shutil.copytree(this_dir / "data", workdir / "data")
        shutil.copy(this_dir / input_fn, workdir / input_fn)

        base_cmd = [sys.executable, "-m", "deepmd", "--pt"]
        subprocess.check_call([*base_cmd, "train", input_fn], cwd=workdir)
        subprocess.check_call([*base_cmd, "freeze", "-o", model_fn], cwd=workdir)
        assert (workdir / model_fn).exists()
57 |
--------------------------------------------------------------------------------
/tests/test_version.py:
--------------------------------------------------------------------------------
1 | """Test version."""
2 |
3 | from __future__ import annotations
4 |
5 | from importlib.metadata import version
6 |
7 | from deepmd_gnn import __version__
8 |
9 |
def test_version() -> None:
    """Check the installed distribution version matches the package attribute."""
    installed = version("deepmd-gnn")
    assert installed == __version__
13 |
--------------------------------------------------------------------------------