├── .github └── workflows │ ├── docs_to_pages.yaml │ └── pypi.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── CMakeLists.txt ├── LICENSE ├── README.md ├── docs ├── Makefile ├── _static │ ├── css │ │ └── custom_styles.css │ └── images │ │ ├── Graph2Mat.svg │ │ ├── graph2mat_overview.svg │ │ └── water_equivariant_matrix.png ├── _templates │ └── autosummary │ │ ├── custom-class-template.rst │ │ └── custom-module-template.rst ├── api │ ├── description.rst │ └── full_api.rst ├── conf.py ├── index.rst ├── server.rst └── tutorials │ ├── cli │ ├── index.rst │ └── siesta_MDaccel.rst │ └── python_api │ ├── Batching.ipynb │ ├── Computing a matrix.ipynb │ ├── Fitting matrices.ipynb │ ├── MACE+Graph2mat.ipynb │ └── index.rst ├── pyproject.toml └── src ├── CMakeLists.txt └── graph2mat ├── CMakeLists.txt ├── __init__.py ├── __version__.py ├── bindings ├── __init__.py ├── e3nn │ ├── __init__.py │ ├── irreps_tools.py │ └── modules │ │ ├── __init__.py │ │ ├── _utils.py │ │ ├── edge_operations.py │ │ ├── graph2mat.py │ │ ├── matrixblock.py │ │ ├── node_operations.py │ │ ├── preprocessing.py │ │ └── tests │ │ └── test_e3nngraph2mat.py └── torch │ ├── __init__.py │ ├── conftest.py │ ├── data │ ├── __init__.py │ ├── data.py │ ├── dataset.py │ ├── formats.py │ └── tests │ │ └── test_data.py │ ├── load.py │ ├── modules │ ├── __init__.py │ ├── graph2mat.py │ ├── matrixblock.py │ └── tests │ │ └── test_basis_matrix.py │ └── tests │ └── test_orbital_matrix.py ├── conftest.py ├── core ├── CMakeLists.txt ├── __init__.py ├── _docstrings.py ├── data │ ├── CMakeLists.txt │ ├── __init__.py │ ├── _sparse.py │ ├── basis.py │ ├── configuration.py │ ├── formats.py │ ├── matrices │ │ ├── __init__.py │ │ ├── basis_matrix.py │ │ └── physics │ │ │ ├── __init__.py │ │ │ ├── density_matrix.py │ │ │ └── orbital_matrix.py │ ├── metrics.py │ ├── neighborhood.py │ ├── node_feats.py │ ├── processing.py │ ├── sparse.py │ ├── table.py │ └── tests │ │ ├── test_basis.py │ │ ├── test_configuration.py │ │ ├── test_metrics.py │ │ ├── test_processing.py │ │ ├── test_sparse.py │ │ └── test_table.py └── modules │ ├── CMakeLists.txt │ ├── __init__.py │ ├── _labels_resort.py │ ├── graph2mat.py │ ├── matrixblock.py │ └── tests │ └── test_graph2mat.py ├── models ├── __init__.py └── mace.py └── tools ├── __init__.py ├── cli ├── __init__.py ├── _typer.py ├── cli.py ├── models │ ├── cli.py │ └── mace │ │ └── cli.py ├── request.py ├── serve.py └── siesta │ ├── main_cli.py │ └── md.py ├── lightning ├── __init__.py ├── callbacks.py ├── cli.py ├── data.py ├── model.py ├── models │ └── mace.py └── tests │ ├── test_lightning.py │ └── test_models.py ├── server ├── __init__.py ├── api_client.py ├── extrapolation.py ├── frontend │ ├── static │ │ ├── javascript │ │ │ └── form.js │ │ └── styles │ │ │ └── styles.css │ └── templates │ │ ├── about.html │ │ ├── index.html │ │ ├── model_action_page.html │ │ ├── model_action_picker.html │ │ ├── model_info_card.html │ │ ├── model_picker_sidebar.html │ │ ├── predict_page.html │ │ ├── test_page.html │ │ └── topbar.html └── server_app.py ├── siesta ├── __init__.py ├── md.py └── templates │ ├── fdf │ ├── dm_init_atomic.fdf │ ├── dm_init_ml.fdf │ └── dm_init_siesta_extrapolation.fdf │ └── lua │ └── graph2mat.lua └── viz ├── __init__.py └── sparse_plot.py /.github/workflows/docs_to_pages.yaml: -------------------------------------------------------------------------------- 1 | # This is a basic workflow to help you get started with Actions 2 | 3 | name: Documentation to github pages 4 | 5 | # Controls when the workflow 
will run
6 | on:
7 |   # Allows you to run this workflow manually from the Actions tab
8 |   workflow_dispatch:
9 |
10 | permissions:
11 |   contents: read
12 |   pages: write
13 |   id-token: write
14 |
15 | jobs:
16 |   # Build job
17 |   build:
18 |     runs-on: ubuntu-latest
19 |     steps:
20 |       - name: Checkout
21 |         uses: actions/checkout@v4
22 |       - name: Setup Pages
23 |         id: pages
24 |         uses: actions/configure-pages@v3
25 |
26 |       - name: Setup python environment
27 |         uses: actions/setup-python@v4
28 |         with:
29 |           python-version: '3.11'
30 |           cache: "pip"
31 |
32 |       - uses: r-lib/actions/setup-pandoc@v2
33 |
34 |       - name: Install graph2mat + documentation dependencies
35 |         run: |
36 |           python -m pip install .[docs]
37 |
38 |       - name: Build the documentation
39 |         run: |
40 |           cd docs
41 |           make html
42 |           rm -rf build/html/_sources
43 |           touch build/html/.nojekyll
44 |           cd ..
45 |
46 |       - name: Upload artifact
47 |         uses: actions/upload-pages-artifact@v3
48 |         with:
49 |           path: docs/build/html
50 |
51 |   # Deployment job
52 |   deploy:
53 |     environment:
54 |       name: github-pages
55 |       url: ${{steps.deployment.outputs.page_url}}
56 |     runs-on: ubuntu-latest
57 |     needs: build
58 |     steps:
59 |       - name: Deploy to GitHub Pages
60 |         id: deployment
61 |         uses: actions/deploy-pages@v4
62 |
-------------------------------------------------------------------------------- /.github/workflows/pypi.yaml: --------------------------------------------------------------------------------
1 | name: Wheel creation and publishing
2 |
3 | # Change this to whatever you want
4 | on:
5 |   push:
6 |     tags:
7 |       - 'v*'
8 |   workflow_dispatch:
9 |     inputs:
10 |       branch:
11 |         description: 'Which branch to build wheels for'
12 |         required: false
13 |         default: 'main'
14 |       release:
15 |         description: 'Whether to release or not'
16 |         type: boolean
17 |         required: false
18 |         default: false
19 |
20 | concurrency:
21 |   group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
22 |   cancel-in-progress: true
23 |
24 | jobs:
25 |
26 |   # cibuildwheel already manages multiple python versions automatically
27 |   # by just detecting the OS. However, it does everything in the same job
28 |   # and therefore in a serial manner. We build a matrix of operating systems
29 |   # and python versions so that builds are run in parallel.
30 |   # The job matrix is basically copied from https://github.com/scikit-learn/scikit-learn/blob/main/.github/workflows/wheels.yml
31 |   build_wheels:
32 |     name: Wheel building
33 |     runs-on: ${{ matrix.os }}
34 |     strategy:
35 |       # If one of the jobs fails, continue with the others.
36 |       fail-fast: false
37 |       matrix:
38 |         os: [macos-latest, ubuntu-latest]
39 |
40 |     steps:
41 |       - uses: actions/checkout@v4
42 |         with:
43 |           fetch-depth: 0
44 |           submodules: false
45 |
46 |       # We use the cibuildwheel action to take care of everything
47 |       - name: Build wheels (Mac)
48 |         if: runner.os == 'macOS'
49 |         uses: pypa/cibuildwheel@v2.16
50 |
51 |       - name: Build wheels (Linux)
52 |         if: runner.os == 'Linux'
53 |         uses: pypa/cibuildwheel@v2.16
54 |
55 |       - name: Build wheels (Windows)
56 |         if: runner.os == 'Windows'
57 |         uses: pypa/cibuildwheel@v2.16
58 |         env:
59 |           # when building on Windows the Cython-generated sources lack linking
60 |           # against -lpythonX.Y; I don't know why, or how to bypass this problem.
61 |           # Nothing apparent on the web... :(
62 |           CMAKE_GENERATOR: MinGW Makefiles
63 |           PIP_NO_CLEAN: "yes"
64 |
65 |       # Upload the wheel to the action's artifacts.
66 | - uses: actions/upload-artifact@v4 67 | with: 68 | name: artifact-${{ matrix.os }} 69 | path: ./wheelhouse/*.whl 70 | 71 | # Build the source distribution as well 72 | build_sdist: 73 | name: Build source distribution 74 | runs-on: ubuntu-latest 75 | steps: 76 | - uses: actions/checkout@v4 77 | with: 78 | fetch-depth: 0 79 | submodules: false 80 | 81 | - name: Build sdist 82 | run: pipx run build --sdist 83 | 84 | - uses: actions/upload-artifact@v4 85 | with: 86 | path: dist/*.tar.gz 87 | 88 | # Upload to testpypi 89 | # upload_testpypi: 90 | # needs: [build_sdist, build_wheels] 91 | # if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') 92 | # environment: 93 | # name: testpypi 94 | # url: https://test.pypi.org/p/graph2mat 95 | # permissions: 96 | # id-token: write 97 | # name: Publish package to TestPyPI 98 | # runs-on: ubuntu-latest 99 | # steps: 100 | # - uses: actions/download-artifact@v4 101 | # with: 102 | # path: dist 103 | # merge-multiple: True 104 | 105 | # - uses: pypa/gh-action-pypi-publish@v1.8.11 106 | # with: 107 | # repository-url: https://test.pypi.org/legacy/ 108 | 109 | # # Check that the testpypi installation works 110 | # test_testpypi: 111 | # needs: [upload_testpypi] 112 | # name: Test installation from TestPyPi 113 | # runs-on: ${{ matrix.os }} 114 | 115 | # strategy: 116 | # # If one of the jobs fails, continue with the others. 117 | # fail-fast: false 118 | # matrix: 119 | # os: [ubuntu-latest, macos-latest] 120 | 121 | # steps: 122 | # - name: Python installation 123 | # uses: actions/setup-python@v5 124 | # with: 125 | # python-version: "3.9" 126 | 127 | # # We should also wait for index to update on remote server 128 | # - name: Install graph2mat + dependencies 129 | # run: | 130 | # sleep 10 131 | # version=${GITHUB_REF#refs/*/v} 132 | # version=${version#refs/*/} 133 | # python -m pip install --progress-bar=off --find-links dist --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ graph2mat[test]==${version} 134 | # - name: Test the installation 135 | # run: | 136 | # pytest --pyargs graph2mat 137 | # Upload to PyPI on every tag 138 | upload_pypi: 139 | #needs: [test_testpypi] 140 | needs: [build_sdist, build_wheels] 141 | name: Publish package to Pypi 142 | runs-on: ubuntu-latest 143 | environment: 144 | name: pypi 145 | url: https://pypi.org/p/graph2mat 146 | permissions: 147 | id-token: write 148 | if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') 149 | # alternatively, to publish when a GitHub Release is created, use the following rule: 150 | # if: github.event_name == 'release' && github.event.action == 'published' 151 | steps: 152 | - uses: actions/download-artifact@v4 153 | with: 154 | path: dist 155 | merge-multiple: True 156 | 157 | - uses: pypa/gh-action-pypi-publish@v1.8.11 158 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Distribution / packaging 7 | build/ 8 | *.egg-info/ 9 | 10 | # pytest 11 | .pytest_cache/ 12 | 13 | # mypy 14 | .mypy_cache/ 15 | 16 | # IDE 17 | .idea/ 18 | .vscode/ 19 | 20 | # Cython files 21 | *.so 22 | *.c 23 | 24 | # Docs 25 | docs/build 26 | docs/api/api-generated 27 | docs/api/conversions.rst 28 | 29 | # Jupyter notebook checkpoints 30 | **/.ipynb_checkpoints 31 | 
-------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v4.5.0 6 | hooks: 7 | - id: check-yaml 8 | - id: end-of-file-fixer 9 | - id: trailing-whitespace 10 | # Using this mirror lets us use mypyc-compiled black, which is about 2x faster 11 | - repo: https://github.com/psf/black-pre-commit-mirror 12 | rev: 23.10.1 13 | hooks: 14 | - id: black-jupyter 15 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.21) 2 | # We will use cmake_path for file-name manipulation 3 | 4 | project(${SKBUILD_PROJECT_NAME} LANGUAGES C) 5 | 6 | find_package( 7 | Python 8 | COMPONENTS Interpreter Development.Module 9 | NumPy 10 | REQUIRED) 11 | 12 | find_program(CYTHON "cython") 13 | 14 | add_subdirectory("src") 15 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 BIG-MAP 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | graph2mat: Equivariant matrices meet machine learning 2 | ---------------------- 3 | 4 | ![graph2mat_overview](https://raw.githubusercontent.com/BIG-MAP/graph2mat/main/docs/_static/images/graph2mat_overview.svg) 5 | 6 | The aim of `graph2mat` is to pave your way into meaningful science by providing the **tools to interface to common machine learning frameworks** (e3nn, pytorch) **to learn equivariant matrices.** 7 | 8 | **[Documentation](https://big-map.github.io/graph2mat/)** 9 | 10 | It also provides a **set of tools** to facilitate the training and usage of the models created using the package: 11 | 12 | - **Training tools**: It contains custom `pytorch_lightning` modules to train, validate and test the orbital matrix models. 13 | - **Server**: A production ready server (and client) to serve predictions of the trained 14 | models. 
Implemented using `fastapi`.
15 | - **Siesta**: A set of tools to interface the machine learning models with SIESTA. These include tools for input preparation, analysis of performance...
16 |
17 | The package also implements a **command line interface** (CLI): `graph2mat`. The aim of this CLI is
18 | to make the usage of `graph2mat`'s tools as simple as possible. It has two objectives:
19 |
20 | - Make life easy for the model developers.
21 | - Facilitate the usage of the models by non-machine-learning scientists, who just want
22 |   good predictions for their systems.
23 |
24 | Installation
25 | ------------
26 |
27 | It can be installed with pip. Adding the `tools` extra will also install all the dependencies
28 | needed to use the tools provided.
29 |
30 | ```
31 | pip install graph2mat[tools]
32 | ```
33 |
34 | If you want to use `graph2mat` with e3nn, you can also ask for the `e3nn` extra dependencies:
35 |
36 | ```
37 | pip install graph2mat[tools,e3nn]
38 | ```
39 |
40 | What is an equivariant matrix?
41 | ------------------------------
42 |
43 | ![water_equivariant_matrix](https://raw.githubusercontent.com/BIG-MAP/graph2mat/main/docs/_static/images/water_equivariant_matrix.png)
44 |
45 |
46 | Contributions
47 | --------------
48 |
49 | We are very open to suggestions, contributions, discussions...
50 |
51 | - If you have questions or want to discuss an idea, please [start a discussion](https://github.com/BIG-MAP/graph2mat/discussions)
52 | - If you have a feature suggestion or bug report, please [open an issue](https://github.com/BIG-MAP/graph2mat/issues)
53 |
54 | We are looking forward to your contributions!
55 |
56 | The `graph2mat` package was originally created by Peter Bjørn Jørgensen (@peterbjorgensen) and Pol Febrer (@pfebrer) in the frame of a collaboration to machine learn density matrices.
57 |
58 | Since then, the following users have contributed to the code:
59 |
60 |
61 |
62 |
63 |
64 | Citation
65 | --------
66 |
67 | If you use `graph2mat` for one of your works, please cite our original paper:
68 |
69 | ```
70 | @article{febrer2025graph2mat,
71 |   title={Graph2Mat: universal graph to matrix conversion for electron density prediction},
72 |   author={Febrer, Pol and J{\o}rgensen, Peter Bj{\o}rn and Pruneda, Miguel and Garc{\'\i}a, Alberto and Ordej{\'o}n, Pablo and Bhowmik, Arghya},
73 |   journal={Machine Learning: Science and Technology},
74 |   volume={6},
75 |   number={2},
76 |   pages={025013},
77 |   year={2025},
78 |   publisher={IOP Publishing}
79 | }
80 | ```
81 |
82 | We'll be very happy to see what you have done with it :)
-------------------------------------------------------------------------------- /docs/Makefile: --------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 |
3 | .PHONY: default
4 | default: html
5 |
6 | # You can set these variables from the command line.
7 | SPHINXOPTS    ?=
8 | SPHINXBUILD   ?= sphinx-build
9 | SOURCEDIR     = .
10 | BUILDDIR      = build
11 |
12 | # Put it first so that "make" without argument is like "make help".
13 | help:
14 | 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
15 |
16 |
17 | .PHONY: help Makefile
18 |
19 | # Catch-all target: route all unknown targets to Sphinx using the new
20 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
21 | %: Makefile 22 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 23 | -------------------------------------------------------------------------------- /docs/_static/css/custom_styles.css: -------------------------------------------------------------------------------- 1 | /* Newlines (\a) and spaces (\20) before each parameter */ 2 | .sig-param:not(:only-of-type):before { 3 | content: "\a\20\20\20\20\20\20\20\20\20\20\20\20\20\20\20\20"; 4 | white-space: pre; 5 | } 6 | 7 | /* Closing bracket on a new line */ 8 | em.sig-param:not(:only-of-type) + .sig-paren::before { 9 | content: "\a"; 10 | white-space: pre; 11 | } 12 | 13 | /* To have blue background of width of the block (instead of width of content) */ 14 | dl.function > dt.sig, dl.class > dt.sig { 15 | display: block !important; 16 | } 17 | 18 | em.sig-param span.default_value { 19 | color: gray; 20 | } 21 | 22 | /* Black and bold argument names, including **kwargs */ 23 | em.sig-param span.n:nth-child(-n + 2) { 24 | color: black; 25 | font-weight: bold !important; 26 | } 27 | em.sig-param span.o:first-child { 28 | color: black; 29 | font-weight: bold !important; 30 | } 31 | 32 | /* Don't bold typehints */ 33 | em.sig-param span.n { 34 | font-weight: 300; 35 | } 36 | 37 | a.reference.external:visited { 38 | color: inherit; 39 | } 40 | 41 | /* Remove number from code cells in jupyter notebooks*/ 42 | div.nbinput .prompt, div.nboutput .prompt { 43 | display: none; 44 | } 45 | 46 | div.nblast.container { 47 | margin-bottom: 24px; 48 | } 49 | 50 | /* Format the header of conversion functions */ 51 | div.g2m-conversion-func-header { 52 | padding: 5px; 53 | background-color: rgb(196, 252, 191); 54 | } 55 | 56 | div.g2m-conversion-func-header p { 57 | padding: 0; 58 | margin: 0; 59 | } 60 | 61 | div.g2m-conversion-func-header + dl.py.function dt{ 62 | scroll-margin-top: 50px; 63 | } 64 | -------------------------------------------------------------------------------- /docs/_static/images/water_equivariant_matrix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIG-MAP/graph2mat/8eb40159b0a29b0fb524245714cc93dbe1885be4/docs/_static/images/water_equivariant_matrix.png -------------------------------------------------------------------------------- /docs/_templates/autosummary/custom-class-template.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | :members: 7 | :no-special-members: 8 | :show-inheritance: 9 | :no-inherited-members: 10 | 11 | {% block methods %} 12 | 13 | {% if methods %} 14 | .. rubric:: {{ _('Methods') }} 15 | 16 | .. autosummary:: 17 | {% for item in methods %} 18 | {%- if item not in inherited_members and not item.startswith("__")%} 19 | ~{{ name }}.{{ item }} 20 | {%- endif %} 21 | {%- endfor %} 22 | {% endif %} 23 | {% endblock %} 24 | 25 | {% block attributes %} 26 | {% if attributes %} 27 | .. rubric:: {{ _('Attributes') }} 28 | 29 | .. 
autosummary::
30 |    {% for item in attributes %}
31 |       {%- if item not in inherited_members %}
32 |       ~{{ name }}.{{ item }}
33 |       {%- endif %}
34 |    {%- endfor %}
35 |    {% endif %}
36 |    {% endblock %}
37 |
-------------------------------------------------------------------------------- /docs/_templates/autosummary/custom-module-template.rst: --------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. automodule:: {{ fullname }}
4 |
5 |    {% block attributes %}
6 |    {% if attributes %}
7 |    .. rubric:: Module Attributes
8 |
9 |    .. autosummary::
10 |       :toctree:
11 |    {% for item in attributes %}
12 |       {{ item }}
13 |    {%- endfor %}
14 |    {% endif %}
15 |    {% endblock %}
16 |
17 |    {% block functions %}
18 |    {% if functions %}
19 |    .. rubric:: {{ _('Functions') }}
20 |
21 |    .. autosummary::
22 |    {% for item in functions %}
23 |       {{ item }}
24 |    {%- endfor %}
25 |    {% endif %}
26 |    {% endblock %}
27 |
28 |    {% block classes %}
29 |    {% if classes %}
30 |    .. rubric:: {{ _('Classes') }}
31 |
32 |    .. autosummary::
33 |       :template: autosummary/custom-class-template.rst
34 |    {% for item in classes %}
35 |       {{ item }}
36 |    {%- endfor %}
37 |    {% endif %}
38 |    {% endblock %}
39 |
40 |    {% block exceptions %}
41 |    {% if exceptions %}
42 |    .. rubric:: {{ _('Exceptions') }}
43 |
44 |    .. autosummary::
45 |       :toctree:
46 |    {% for item in exceptions %}
47 |       {{ item }}
48 |    {%- endfor %}
49 |    {% endif %}
50 |    {% endblock %}
51 |
52 |    {% block modules %}
53 |    {% if modules %}
54 |    .. rubric:: Modules
55 |
56 |    .. autosummary::
57 |       :toctree:
58 |       :template: autosummary/custom-module-template.rst
59 |       :recursive:
60 |    {% for item in modules %}
61 |       {{ item }}
62 |    {%- endfor %}
63 |    {% endif %}
64 |    {% endblock %}
65 |
-------------------------------------------------------------------------------- /docs/api/full_api.rst: --------------------------------------------------------------------------------
1 | Full documentation
2 | ------------------
3 |
4 | This section describes the full structure of the `graph2mat` package.
5 | The documentation is automatically generated from the codebase. If you
6 | want a more guided overview of the classes in `graph2mat`, you can
7 | check the `more comprehensive description <./description.rst>`_.
8 |
9 | .. autosummary::
10 |    :toctree: api-generated/
11 |    :template: autosummary/custom-module-template.rst
12 |    :recursive:
13 |    :caption: Reference documentation
14 |
15 |    graph2mat
-------------------------------------------------------------------------------- /docs/index.rst: --------------------------------------------------------------------------------
1 |
2 | .. title:: graph2mat
3 | .. meta::
4 |    :description: graph2mat is a package to learn equivariant matrices with machine learning
5 |    :keywords: ML, e3nn, graphs
6 |
7 |
8 | graph2mat: Equivariant matrices meet machine learning
9 | ==============================================================
10 |
11 | .. image:: /_static/images/graph2mat_overview.svg
12 |     :align: center
13 |
14 |
15 |
16 | The aim of ``graph2mat`` is to **pave your way into meaningful science** by providing the tools
17 | to **interface to common machine learning frameworks** (``e3nn``, ``pytorch``) **to learn equivariant matrices**.
18 |
19 | Installation
20 | -------------
21 |
22 | Using ``pip``, installation is as simple as:
23 |
24 | .. code-block:: bash
25 |
26 |     pip install graph2mat
27 |
28 | I would like to...
29 | ------------------
30 |
31 | - **Learn and predict** matrices using built-in models: `CLI tutorials `_.
32 | - **Develop** my own matrix-predicting model: `Python API tutorials `_.
33 | - Get an **overview of the python API**: `API overview `_.
34 | - **Find documentation** for a particular function/class: `API documentation `_.
35 |
36 | Background
37 | -----------
38 |
39 | We use the term **equivariant matrix** to refer to a matrix whose rows and columns
40 | represent some basis made of spherical harmonics. The values of this matrix arise
41 | from the interaction of such basis functions, and therefore follow the equivariance properties
42 | of products of spherical harmonics.
43 |
44 | .. image:: /_static/images/water_equivariant_matrix.png
45 |     :align: center
46 |
47 | One particular case of equivariant matrices is that in which **rows and columns represent
48 | the same basis**. These matrices usually come up in physics when **atom-centered spherical
49 | harmonics** are used as basis functions. Some examples are Hamiltonian and overlap matrices in
50 | quantum chemistry. By the nature of the basis functions, which usually have a finite
51 | range determined by a radial function, these matrices tend to be sparse.
52 |
53 | Dealing with both the **equivariance and the sparsity** of these matrices within a machine
54 | learning framework is not a trivial task. This can easily deter people from implementing
55 | powerful applications that take full advantage of the properties of these matrices. With
56 | ``graph2mat``, we hope that people can explore the full potential of these matrices.
57 |
58 | .. toctree::
59 |    :maxdepth: 3
60 |    :caption: Tutorials
61 |    :hidden:
62 |
63 |    tutorials/cli/index
64 |    tutorials/python_api/index
65 |
66 | .. toctree::
67 |    :maxdepth: 3
68 |    :caption: API documentation
69 |    :hidden:
70 |
71 |    api/description
72 |    api/conversions
73 |    api/full_api
-------------------------------------------------------------------------------- /docs/server.rst: --------------------------------------------------------------------------------
1 | Usage
2 | ------
3 |
4 | ### Server
5 |
6 | To serve pretrained models, one can use the command `graph2mat serve`. You can do `graph2mat serve --help`
7 | to check how it works. Serving models is as simple as:
8 |
9 | ```bash
10 | graph2mat serve some.ckpt other.ckpt
11 | ```
12 |
13 | This will serve the models in the checkpoint files. However, we recommend that you organize your
14 | checkpoints into folders and then pass the names of the folders instead.
15 |
16 | ```bash
17 | graph2mat serve first_model second_model
18 | ```
19 |
20 | where `first_model` and `second_model` are folders that contain a `spec.yaml` file looking something like:
21 |
22 | ```yaml
23 | description: |
24 |   This model predicts single water molecules.
25 | authors:
26 |   - Pol Febrer (pol.febrer@icn2.cat)
27 |
28 | files: # All the files related to this model.
29 |   ckpt: best.ckpt
30 |   basis: "*.ion.nc"
31 |   structs: structs.xyz
32 |   sample_metrics: sample_metrics.csv
33 |   database: http://data.com/link/to/your/matrix_database
34 | ```
35 |
36 | Once your server is running, you will get the URL where the server is running, e.g. http://localhost:56000.
37 | You can interact with it in multiple ways:
38 | - Through the simple graphical interface included in the package, by opening `http://localhost:56000` in a browser.
39 | - Through the `ServerClient` class in `graph2mat.tools.server.api_client`.
40 | - By simply sending requests to the API of the server. These requests must be sent to `http://localhost:56000/api`.
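
For instance, interacting through the `ServerClient` class could look something like the sketch
below. This is only an orientative example: the constructor arguments and method names used here
are assumptions, so check `graph2mat.tools.server.api_client` for the actual API.

```python
# Hypothetical usage sketch; verify the real signatures in
# graph2mat.tools.server.api_client before relying on them.
from graph2mat.tools.server.api_client import ServerClient

# Point the client to the running server.
client = ServerClient(url="http://localhost:56000")

# E.g. ask for a matrix prediction for a structure file
# (illustrative method and arguments).
prediction = client.predict("structure.xyz", model="first_model")
```
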
You 41 | can check the documentation for the requests that the server understands under `http://localhost:56000/api/docs`. 42 | -------------------------------------------------------------------------------- /docs/tutorials/cli/index.rst: -------------------------------------------------------------------------------- 1 | 2 | CLI tutorials 3 | ------------------- 4 | 5 | These tutorials show you how to **get your hands dirty** using the built-in models 6 | through the Command Line Interface (CLI). 7 | 8 | **Applications are the main focus** of these tutorials. Implementation details 9 | are not discussed. If you are interested in developing your own thing, or just 10 | curious to see what's behind the curtain, see the `Python API tutorials <../python_api/index.rst>`_. 11 | 12 | .. toctree:: 13 | :hidden: 14 | 15 | siesta_MDaccel 16 | -------------------------------------------------------------------------------- /docs/tutorials/python_api/index.rst: -------------------------------------------------------------------------------- 1 | Python API tutorials 2 | -------------------- 3 | 4 | Are you interested in **developing your own thing**, or just a curious person 5 | that likes to explore? This section is for you! 6 | 7 | The following tutorials show you **how to tackle the different problems** that you 8 | might face **when integrating sparse equivariant matrices into a machine learning 9 | workflow**. 10 | 11 | 12 | .. toctree:: 13 | :hidden: 14 | 15 | Computing a matrix.ipynb 16 | Batching.ipynb 17 | Fitting matrices.ipynb 18 | MACE+Graph2mat.ipynb 19 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | requires-python = ">=3.9" 3 | 4 | name = "graph2mat" 5 | version = "0.0.9" 6 | description = "Utility package to work with equivariant matrices and graphs." 7 | readme = "README.md" 8 | license = {text = "MIT"} 9 | keywords = [ 10 | "machine learning", 11 | "equivariance", 12 | "e3nn", 13 | "matrix", 14 | ] 15 | 16 | authors = [ 17 | {name = "Pol Febrer", email = "pfebrer96@gmail.com"}, 18 | {name = "Peter B. 
Jørgensen", email = "peterbjorgensen@gmail.com"} 19 | ] 20 | 21 | dependencies = [ 22 | "numpy", 23 | "scipy", 24 | "ase", 25 | "sisl[viz]>=0.15.0", 26 | "typer", 27 | ] 28 | 29 | classifiers = [ 30 | "Programming Language :: Python :: 3.9", 31 | "Programming Language :: Python :: 3.10", 32 | "Programming Language :: Python :: 3.11", 33 | "Programming Language :: Python :: 3.12", 34 | "Topic :: Scientific/Engineering", 35 | ] 36 | 37 | [project.scripts] 38 | graph2mat = "graph2mat.tools.cli.cli:app" 39 | 40 | [build-system] 41 | requires = [ 42 | "setuptools_scm[toml]>=6.2", 43 | "scikit-build-core[pyproject]>=0.8", 44 | "cython>=3", 45 | "numpy>=2.0.0" 46 | ] 47 | build-backend = "scikit_build_core.build" 48 | 49 | [tool.scikit-build] 50 | # Consider adding 51 | # minimum-version to choose the fallback mechanism in scikit-build-core 52 | wheel.packages = ["src/graph2mat"] 53 | 54 | [project.optional-dependencies] 55 | server = [ 56 | "pyyaml", 57 | "fastapi", 58 | "uvicorn", 59 | "python-multipart", 60 | "jinja2" 61 | ] 62 | 63 | analysis = [ 64 | "plotly", 65 | "pandas" 66 | ] 67 | 68 | lightning = [ 69 | "pytorch-lightning", 70 | "jsonargparse[signatures]", 71 | "tensorboard" 72 | ] 73 | 74 | siesta = [ 75 | "jinja2" 76 | ] 77 | 78 | torch = [ 79 | "torch", 80 | "torch_geometric" 81 | ] 82 | 83 | e3nn = [ 84 | "torch<=2.5", 85 | "torch_geometric", 86 | "e3nn" 87 | ] 88 | 89 | mace = [ 90 | "torch<=2.5", 91 | "torch_geometric", 92 | "mace_torch" 93 | ] 94 | 95 | tools = [ 96 | "pyyaml", 97 | "fastapi", 98 | "uvicorn", 99 | "python-multipart", 100 | "plotly", 101 | "kaleido", 102 | "pandas", 103 | "pytorch-lightning", 104 | "jsonargparse[signatures]", 105 | "jinja2", 106 | "mace_torch", 107 | "tensorboard" 108 | ] 109 | 110 | docs = [ 111 | "pyyaml", 112 | "fastapi", 113 | "uvicorn", 114 | "plotly<6", # Until https://github.com/plotly/plotly.py/issues/5056 is fixed 115 | "pandas", 116 | "pytorch-lightning", 117 | "jsonargparse[signatures]", 118 | "jinja2", 119 | "sphinx", 120 | "sphinx_autodoc_typehints", 121 | "sphinx_rtd_theme", 122 | "nbsphinx", 123 | "ipykernel", 124 | "e3nn", 125 | "torch<=2.5", # MACE pins e3nn to 0.4.4, which doesn't work with torch 2.6 126 | "mace_torch", 127 | "torch_geometric" 128 | ] 129 | 130 | test = [ 131 | "pytest", 132 | "pyyaml", 133 | "plotly", 134 | "pandas", 135 | "pytorch-lightning", 136 | "e3nn", 137 | "torch", 138 | "mace_torch", 139 | "torch_geometric", 140 | "jsonargparse[signatures]", 141 | "jinja2" 142 | ] 143 | 144 | #"fastapi", "uvicorn" 145 | 146 | [tool.cibuildwheel] 147 | build-verbosity = 3 148 | skip = [ 149 | "pp*", 150 | "*i686", 151 | "*musllinux*", 152 | ] 153 | 154 | [tool.black] 155 | line-length = 88 156 | target-version = ["py39", "py310", "py311", "py312"] 157 | -------------------------------------------------------------------------------- /src/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory("graph2mat") 2 | -------------------------------------------------------------------------------- /src/graph2mat/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory("core") 2 | -------------------------------------------------------------------------------- /src/graph2mat/__init__.py: -------------------------------------------------------------------------------- 1 | """A package to generate and manipulate sparse equivariant matrices. 
2 |
3 | The package contains a `core` submodule implementing the main functionality,
4 | and other submodules that implement specific functionalities.
5 | """
6 |
7 | from .__version__ import __version__
8 |
9 | from .core import *
-------------------------------------------------------------------------------- /src/graph2mat/__version__.py: --------------------------------------------------------------------------------
1 | __version__ = "0.0.9"
-------------------------------------------------------------------------------- /src/graph2mat/bindings/__init__.py: --------------------------------------------------------------------------------
1 | """Specific interfaces to other codes.
2 |
3 | The core functionality of `graph2mat` is agnostic to the framework,
4 | and it is based on pure python and `numpy`.
5 |
6 | For running a specific ML workflow, we need interfaces of the core
7 | functionality with ML frameworks, e.g. `torch`, `e3nn`. The implementations
8 | of the interfaces are mostly thin wrappers around the core functionality,
9 | as well as functions that only make sense on the specific framework. For
10 | example, the `e3nn` bindings contain functions that use irreps.
11 |
12 | Whatever framework we interface with, it should not be a required import
13 | for the core functionality. So basically, the criterion for creating a new
14 | submodule in `bindings` is that we can't add the functionality to the core
15 | without requiring the framework as a dependency.
16 | """
-------------------------------------------------------------------------------- /src/graph2mat/bindings/e3nn/__init__.py: --------------------------------------------------------------------------------
1 | """Interface to e3nn, as well as functions that use irreps.
2 |
3 | `e3nn` uses `torch`.
4 |
5 | Therefore, the interface with `e3nn` takes as
6 | a starting point the interface with `torch`, defined in `graph2mat.bindings.torch`.
7 |
8 | This interface has two goals:
9 |
10 | - **Wrap the core functionality**. The main addition to the core functionality is two thin
11 |   wrappers around ``TorchGraph2Mat`` and ``TorchMatrixBlock`` to handle irreps. **There is no
12 |   need to wrap the data containers for now, as we don't define data in terms of irreps**. In
13 |   the future, one could envision for example training directly on irreps, and that would
14 |   require a wrapper around `TorchBasisMatrixData`.
15 |
16 | - **Implement equivariant functions that use** `e3nn` **irreps** and therefore cannot be implemented
17 |   in `graph2mat.core` without importing `e3nn`. These functions can be used as the blocks within `Graph2Mat`
18 |   to make the model equivariant.
19 |
20 | """
21 |
22 | from .modules import *
-------------------------------------------------------------------------------- /src/graph2mat/bindings/e3nn/irreps_tools.py: --------------------------------------------------------------------------------
1 | """Utility tools to deal with e3nn irreps.
2 |
3 | These are basically tools to convert from/to irreps.
4 |
5 | They are currently not being used anywhere in `graph2mat`.
6 | """
7 | from typing import Union, Sequence, Iterable
8 |
9 | import sisl
10 | import numpy as np
11 |
12 | from e3nn import o3
13 |
14 |
15 | def get_atom_irreps(atom: sisl.Atom):
16 |     """For a given atom, returns the irreps representation of its basis.
17 |
18 |     Parameters
19 |     ----------
20 |     atom: sisl.Atom
21 |         The atom for which we want the irreps of its basis.
22 |
23 |     Returns
24 |     ----------
25 |     o3.Irreps:
26 |         the basis irreps.
27 |     """
28 |
29 |     if atom.no == 0:
30 |         return o3.Irreps("")
31 |
32 |     atom_irreps = []
33 |
34 |     # Array that stores the number of orbitals for each l.
35 |     # We allocate 8 ls; we will probably never need more than that.
36 |     n_ls = np.zeros(8)
37 |
38 |     # Loop over all orbitals that this atom contains
39 |     for orbital in atom.orbitals:
40 |         # For each orbital, find its l quantum number
41 |         # and increment the total number of orbitals for that l
42 |         n_ls[orbital.l] += 1
43 |
44 |     # We don't really want to know the number of orbitals for a given l,
45 |     # but the number of SETS of orbitals. E.g. a set of l=1 has 3 orbitals.
46 |     n_ls /= 2 * np.arange(8) + 1
47 |
48 |     # Now just loop over all ls, and initialize as many irreps as we need
49 |     # for each of them. We build a list of tuples (n_irreps, (l, parity))
50 |     # to pass it to o3.Irreps.
51 |     for l, n_l in enumerate(n_ls):
52 |         if n_l != 0:
53 |             atom_irreps.append((int(n_l), (l, (-1) ** l)))
54 |
55 |     return o3.Irreps(atom_irreps)
56 |
57 |
58 | def get_atom_from_irreps(
59 |     irreps: Union[o3.Irreps, str],
60 |     orb_kwargs: Union[Iterable[dict], dict] = {},
61 |     atom_args: Sequence = (),
62 |     **kwargs,
63 | ):
64 |     """Returns a sisl atom with the basis specified by irreps."""
65 |     if isinstance(orb_kwargs, dict):
66 |         orb_kwargs = [orb_kwargs] * len(o3.Irreps(irreps).ls)
67 |
68 |     orbitals = []
69 |     for orbital_l, orbital_kwargs in zip(o3.Irreps(irreps).ls, orb_kwargs):
70 |         if len(orbital_kwargs) == 0:
71 |             orbital_kwargs = {
72 |                 "rf_or_func": None,
73 |             }
74 |
75 |         for m in range(-orbital_l, orbital_l + 1):
76 |             orbital = sisl.SphericalOrbital(l=orbital_l, m=m, **orbital_kwargs)
77 |
78 |             orbitals.append(orbital)
79 |
80 |     if len(atom_args) == 0:
81 |         kwargs = {
82 |             "Z": 1,
83 |             **kwargs,
84 |         }
85 |
86 |     return sisl.Atom(*atom_args, orbitals=orbitals, **kwargs)
-------------------------------------------------------------------------------- /src/graph2mat/bindings/e3nn/modules/__init__.py: --------------------------------------------------------------------------------
1 | """E3nn based functions to use in graph2mat.
2 |
3 | In `Graph2Mat`, if you are fitting an equivariant matrix, you might
4 | want to design an equivariant model.
5 |
6 | For this reason, this module implements the equivariant functions to
7 | use as the blocks within `Graph2Mat`. It also implements an
8 | `E3nnGraph2Mat` class which is just a subclass of `Graph2Mat` with the
9 | right defaults and an extra argument to pass e3nn's irreps.
10 | """
11 | from .graph2mat import *
12 | from .matrixblock import *
13 | from .edge_operations import *
14 | from .node_operations import *
15 | from .preprocessing import *
-------------------------------------------------------------------------------- /src/graph2mat/bindings/e3nn/modules/_utils.py: --------------------------------------------------------------------------------
1 | from typing import List, Tuple, Optional
2 |
3 | import torch
4 | from e3nn import o3
5 |
6 |
7 | # Taken directly from the MACE repository (mace.modules.irreps_tools).
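# Given the irreps of the two tensor product inputs and a set of target irreps,
# it computes the attainable output irreps (restricted to the targets), together
# with the "uvu" instructions that an o3.TensorProduct needs to produce them.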
8 | def tp_out_irreps_with_instructions(
9 |     irreps1: o3.Irreps, irreps2: o3.Irreps, target_irreps: o3.Irreps
10 | ) -> Tuple[o3.Irreps, List]:
11 |     """"""
12 |     trainable = True
13 |
14 |     # Collect possible irreps and their instructions
15 |     irreps_out_list: List[Tuple[int, o3.Irreps]] = []
16 |     instructions = []
17 |     for i, (mul, ir_in) in enumerate(irreps1):
18 |         for j, (_, ir_edge) in enumerate(irreps2):
19 |             for ir_out in ir_in * ir_edge:  # | l1 - l2 | <= l <= l1 + l2
20 |                 if ir_out in target_irreps:
21 |                     k = len(irreps_out_list)  # instruction index
22 |                     irreps_out_list.append((mul, ir_out))
23 |                     instructions.append((i, j, k, "uvu", trainable))
24 |
25 |     # We sort the output irreps of the tensor product so that we can simplify them
26 |     # when they are provided to the second o3.Linear
27 |     irreps_out = o3.Irreps(irreps_out_list)
28 |     irreps_out, permut, _ = irreps_out.sort()
29 |
30 |     # Permute the output indexes of the instructions to match the sorted irreps:
31 |     instructions = [
32 |         (i_in1, i_in2, permut[i_out], mode, train)
33 |         for i_in1, i_in2, i_out, mode, train in instructions
34 |     ]
35 |
36 |     instructions = sorted(instructions, key=lambda x: x[2])
37 |
38 |     return irreps_out, instructions
39 |
40 |
41 | # Taken directly from the MACE repository (mace.tools.scatter).
42 | def _broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
43 |     if dim < 0:
44 |         dim = other.dim() + dim
45 |     if src.dim() == 1:
46 |         for _ in range(0, dim):
47 |             src = src.unsqueeze(0)
48 |     for _ in range(src.dim(), other.dim()):
49 |         src = src.unsqueeze(-1)
50 |     src = src.expand_as(other)
51 |     return src
52 |
53 |
54 | # Taken directly from the MACE repository (mace.tools.scatter).
55 | def scatter_sum(
56 |     src: torch.Tensor,
57 |     index: torch.Tensor,
58 |     dim: int = -1,
59 |     out: Optional[torch.Tensor] = None,
60 |     dim_size: Optional[int] = None,
61 |     reduce: str = "sum",
62 | ) -> torch.Tensor:
63 |     assert reduce == "sum"  # for now, TODO
64 |     index = _broadcast(index, src, dim)
65 |     if out is None:
66 |         size = list(src.size())
67 |         if dim_size is not None:
68 |             size[dim] = dim_size
69 |         elif index.numel() == 0:
70 |             size[dim] = 0
71 |         else:
72 |             size[dim] = int(index.max()) + 1
73 |         out = torch.zeros(size, dtype=src.dtype, device=src.device)
74 |         return out.scatter_add_(dim, index, src)
75 |     else:
76 |         return out.scatter_add_(dim, index, src)
-------------------------------------------------------------------------------- /src/graph2mat/bindings/e3nn/modules/edge_operations.py: --------------------------------------------------------------------------------
1 | """E3nn operations to compute edge matrix blocks.
2 |
3 | In edge matrix blocks, you typically will have, for each edge,
4 | a different message coming from each atom in the edge. The edge block
5 | will typically not be symmetric, but it is common that
6 |
7 | ..
math:: 8 | B_{ij} = B_{ji}^T 9 | """ 10 | 11 | from e3nn import o3, nn 12 | import torch 13 | 14 | from typing import Tuple 15 | 16 | from ._utils import tp_out_irreps_with_instructions 17 | 18 | __all__ = [ 19 | "E3nnSimpleEdgeBlock", 20 | "E3nnEdgeBlockNodeMix", 21 | ] 22 | 23 | 24 | class E3nnSimpleEdgeBlock(torch.nn.Module): 25 | def __init__(self, irreps_in: o3.Irreps, irreps_out: o3.Irreps): 26 | super().__init__() 27 | 28 | if isinstance(irreps_in, (o3.Irreps, str)): 29 | irreps_in = [irreps_in] 30 | 31 | self.tensor_products = torch.nn.ModuleList( 32 | [ 33 | o3.FullyConnectedTensorProduct( 34 | this_irreps_in, this_irreps_in, irreps_out 35 | ) 36 | for this_irreps_in in irreps_in 37 | ] 38 | ) 39 | 40 | def forward( 41 | self, **tuple_kwargs: Tuple[torch.Tensor, torch.Tensor] 42 | ) -> torch.Tensor: 43 | assert len(tuple_kwargs) == len( 44 | self.tensor_products 45 | ), f"Number of input tuples ({len(tuple_kwargs)}) must match number of tensor square operations ({len(self.tensor_products)})." 46 | 47 | tensor_tuples = iter(tuple_kwargs.values()) 48 | 49 | final_value = self.tensor_products[0](*next(tensor_tuples)) 50 | for i, tensor_tuple in enumerate(tensor_tuples): 51 | final_value = final_value + self.tensor_products[i + 1](*tensor_tuple) 52 | 53 | return final_value 54 | 55 | 56 | class E3nnEdgeBlockNodeMix(torch.nn.Module): 57 | _data_get_edge_args = ("edge_feats",) 58 | 59 | def __init__( 60 | self, 61 | edge_feats_irreps: o3.Irreps, 62 | edge_messages_irreps: o3.Irreps, 63 | node_feats_irreps: o3.Irreps, 64 | irreps_out: o3.Irreps, 65 | ): 66 | super().__init__() 67 | 68 | # Network to reduce node representations to scalar features 69 | self.nodes_linear = o3.Linear(node_feats_irreps, edge_feats_irreps) 70 | 71 | # The weights of the tensor are produced by a fully connected neural network 72 | # that takes the scalar representations of nodes and edges as input 73 | irreps_mid, instructions = tp_out_irreps_with_instructions( 74 | edge_messages_irreps, 75 | edge_messages_irreps, 76 | irreps_out, 77 | ) 78 | # Tensor product between edge features from sender and receiver 79 | self.edges_tp = o3.TensorProduct( 80 | edge_messages_irreps, 81 | edge_messages_irreps, 82 | irreps_mid, 83 | instructions=instructions, 84 | shared_weights=False, 85 | internal_weights=False, 86 | ) 87 | irreps_mid = irreps_mid.simplify() 88 | 89 | edge_tp_input_irreps = edge_feats_irreps * 3 90 | assert edge_tp_input_irreps.lmax == 0 91 | input_dim = edge_tp_input_irreps.num_irreps 92 | self.edge_tp_weights = nn.FullyConnectedNet( 93 | [input_dim] + 2 * [128] + [self.edges_tp.weight_numel], 94 | torch.nn.SiLU(), 95 | ) 96 | 97 | # The final output is produced by a linear layer 98 | self.output_linear = o3.Linear(irreps_mid, irreps_out) 99 | 100 | def forward( 101 | self, 102 | edge_feats: Tuple[torch.Tensor, torch.Tensor], 103 | edge_messages: Tuple[torch.Tensor, torch.Tensor], 104 | node_feats: Tuple[torch.Tensor, torch.Tensor], 105 | ) -> torch.Tensor: 106 | # Convert nodes to scalar features 107 | scalar_node_feats_sender = self.nodes_linear(node_feats[0]) 108 | scalar_node_feats_receiver = self.nodes_linear(node_feats[1]) 109 | scalar_feats = torch.concatenate( 110 | (scalar_node_feats_sender, scalar_node_feats_receiver, edge_feats[0]), dim=1 111 | ) 112 | # Obtain weights for edge tensor product 113 | edge_tp_weights = self.edge_tp_weights(scalar_feats) 114 | 115 | # Compute edge tensor product 116 | edges_tp = self.edges_tp(edge_messages[0], edge_messages[1], edge_tp_weights) 117 | 118 | # Compute 
final output 119 | output = self.output_linear(edges_tp) 120 | 121 | return output 122 | -------------------------------------------------------------------------------- /src/graph2mat/bindings/e3nn/modules/matrixblock.py: -------------------------------------------------------------------------------- 1 | from e3nn import o3 2 | import torch 3 | from typing import Type, Dict 4 | import inspect 5 | 6 | from graph2mat import PointBasis 7 | from graph2mat.bindings.torch import TorchMatrixBlock 8 | 9 | __all__ = ["E3nnIrrepsMatrixBlock"] 10 | 11 | 12 | class E3nnIrrepsMatrixBlock(TorchMatrixBlock): 13 | """Computes a matrix block by computing its irreps first.""" 14 | 15 | def __init__( 16 | self, 17 | i_basis: PointBasis, 18 | j_basis: PointBasis, 19 | symmetry: str, 20 | operation_cls: Type, 21 | symm_transpose: bool = False, 22 | preprocessor=None, 23 | irreps: Dict[str, o3.Irreps] = {}, 24 | **operation_kwargs, 25 | ): 26 | torch.nn.Module.__init__(self) 27 | 28 | i_irreps = i_basis.e3nn_irreps 29 | j_irreps = j_basis.e3nn_irreps 30 | 31 | self.i_irreps = i_irreps 32 | self.j_irreps = j_irreps 33 | 34 | self.setup_reduced_tp(i_irreps=i_irreps, j_irreps=j_irreps, symmetry=symmetry) 35 | self.symm_transpose = symm_transpose 36 | 37 | operation_kwargs = { 38 | **self.get_init_kwargs(irreps, operation_cls), 39 | **operation_kwargs, 40 | } 41 | 42 | self.operation = operation_cls(**operation_kwargs) 43 | 44 | def get_summary(self): 45 | return f"{str(self.operation.__class__.__name__)}: ({self.i_irreps}) x ({self.j_irreps}) -> {self._irreps_out}" 46 | 47 | def setup_reduced_tp(self, i_irreps: o3.Irreps, j_irreps: o3.Irreps, symmetry: str): 48 | # Store the shape of the block. 49 | self.block_shape = (i_irreps.dim, j_irreps.dim) 50 | # And number of elements in the block. 51 | self.block_size = i_irreps.dim * j_irreps.dim 52 | 53 | # Understand the irreps out that we need in order to create the block. 54 | # The block is a i_irreps.dim X j_irreps.dim matrix, with possible symmetries that can 55 | # reduce the number of degrees of freedom. We indicate this to the ReducedTensorProducts, 56 | # which we only use as a helper. 57 | reduced_tp = o3.ReducedTensorProducts(symmetry, i=i_irreps, j=j_irreps) 58 | self._irreps_out = reduced_tp.irreps_out 59 | 60 | # We also store the change of basis, a matrix that will bring us from the irreps_out 61 | # to the actual matrix block that we want to calculate. 62 | self.register_buffer("change_of_basis", reduced_tp.change_of_basis) 63 | 64 | def _compute_block(self, *args, **kwargs): 65 | # Get the irreducible output 66 | irreducible_out = self.operation(*args, **kwargs) 67 | 68 | # And convert it to the actual block of the matrix, using the change of basis 69 | # matrix stored on initialization. 
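        # change_of_basis holds one (rows x cols) matrix per irreps component
        # (it is indexed "ixy" in the einsum below), so this contracts away the
        # irreps dimension and leaves one full block per node/edge.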
70 |         # n = number of nodes, i = dim of irreps, x = rows in block, y = cols in block
71 |         return self.numpy.einsum("ni,ixy->nxy", irreducible_out, self.change_of_basis)
72 |
73 |     def get_init_kwargs(self, irreps: Dict[str, o3.Irreps], operation_cls) -> dict:
74 |         kwargs = {}
75 |         op_sig = inspect.signature(operation_cls)
76 |
77 |         irreps = {**irreps}
78 |
79 |         irreps["irreps_in"] = [
80 |             irrep
81 |             for irrep in [
82 |                 irreps["node_feats_irreps"],
83 |                 irreps.get("edge_messages_irreps"),
84 |             ]
85 |             if irrep is not None
86 |         ]
87 |         irreps["irreps_out"] = self._irreps_out
88 |
89 |         for k in op_sig.parameters:
90 |             if k in irreps:
91 |                 kwargs[k] = irreps[k]
92 |
93 |         return kwargs
-------------------------------------------------------------------------------- /src/graph2mat/bindings/e3nn/modules/node_operations.py: --------------------------------------------------------------------------------
1 | from e3nn import o3
2 | import torch
3 |
4 | __all__ = [
5 |     "E3nnSimpleNodeBlock",
6 |     "E3nnSeparateTSQNodeBlock",
7 | ]
8 |
9 |
10 | class E3nnSimpleNodeBlock(torch.nn.Module):
11 |     """Sums all node features and then passes them to a tensor square.
12 |
13 |     All node features must have the same irreps.
14 |
15 |     Example
16 |     -------
17 |     If we construct an E3nnSimpleNodeBlock:
18 |
19 |     >>> irreps_in = o3.Irreps("2x0e + 2x1o")
20 |     >>> irreps_out = o3.Irreps("3x2e")
21 |     >>> node_block = E3nnSimpleNodeBlock(irreps_in, irreps_out)
22 |
23 |     and then use it with 2 different nodewise tensors:
24 |
25 |     >>> node_feats = torch.randn(10, irreps_in.dim)
26 |     >>> node_messages = torch.randn(10, irreps_in.dim)
27 |     >>> node_block(node_feats=node_feats, node_messages=node_messages)
28 |
29 |     this is equivalent to:
30 |
31 |     >>> tsq = o3.TensorSquare(irreps_in, irreps_out)
32 |     >>> output = tsq(node_feats + node_messages)
33 |     """
34 |
35 |     def __init__(self, irreps_in: o3.Irreps, irreps_out: o3.Irreps):
36 |         super().__init__()
37 |
38 |         if isinstance(irreps_in, (list, tuple)) and not isinstance(
39 |             irreps_in, o3.Irreps
40 |         ):
41 |             assert all(
42 |                 irreps == irreps_in[0] for irreps in irreps_in
43 |             ), "All input irreps must be the same."
44 |             irreps_in = irreps_in[0]
45 |
46 |         self.tsq = o3.TensorSquare(irreps_in, irreps_out)
47 |
48 |     def forward(self, **node_kwargs: torch.Tensor) -> torch.Tensor:
49 |         node_tensors = iter(node_kwargs.values())
50 |
51 |         node_feats = next(node_tensors)
52 |         for other_node_feats in node_tensors:
53 |             node_feats = node_feats + other_node_feats
54 |
55 |         return self.tsq(node_feats)
56 |
57 |
58 | class E3nnSeparateTSQNodeBlock(torch.nn.Module):
59 |     """Tensor squares each set of node features and then sums all outputs.
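    Unlike ``E3nnSimpleNodeBlock``, each input is passed through its own tensor
    square, so the input tensors do not need to share the same irreps.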
60 |
61 |     Example
62 |     -------
63 |     If we construct an E3nnSeparateTSQNodeBlock:
64 |
65 |     >>> irreps_in = o3.Irreps("3x0e + 2x1o")
66 |     >>> irreps_out = o3.Irreps("3x2e")
67 |     >>> node_block = E3nnSeparateTSQNodeBlock(irreps_in, irreps_out)
68 |
69 |     and then use it with 2 different nodewise tensors:
70 |
71 |     >>> node_feats = torch.randn(10, irreps_in.dim)
72 |     >>> node_messages = torch.randn(10, irreps_in.dim)
73 |     >>> output = node_block(node_feats=node_feats, node_messages=node_messages)
74 |
75 |     this is equivalent to:
76 |
77 |     >>> tsq1 = o3.TensorSquare(irreps_in, irreps_out)
78 |     >>> tsq2 = o3.TensorSquare(irreps_in, irreps_out)
79 |     >>> output = tsq1(node_feats) + tsq2(node_messages)
80 |     """
81 |
82 |     def __init__(self, irreps_in: o3.Irreps, irreps_out: o3.Irreps):
83 |         super().__init__()
84 |
85 |         if isinstance(irreps_in, (o3.Irreps, str)):
86 |             irreps_in = [irreps_in]
87 |
88 |         self.tensor_squares = torch.nn.ModuleList(
89 |             [
90 |                 o3.TensorSquare(this_irreps_in, irreps_out)
91 |                 for this_irreps_in in irreps_in
92 |             ]
93 |         )
94 |
95 |     def forward(self, **node_kwargs: torch.Tensor) -> torch.Tensor:
96 |         assert len(node_kwargs) == len(
97 |             self.tensor_squares
98 |         ), f"Number of input tensors ({len(node_kwargs)}) must match number of tensor square operations ({len(self.tensor_squares)})."
99 |
100 |         node_tensors = iter(node_kwargs.values())
101 |
102 |         node_feats = self.tensor_squares[0](next(node_tensors))
103 |         for i, other_node_feats in enumerate(node_tensors):
104 |             node_feats = node_feats + self.tensor_squares[i + 1](other_node_feats)
105 |
106 |         return node_feats
-------------------------------------------------------------------------------- /src/graph2mat/bindings/e3nn/modules/preprocessing.py: --------------------------------------------------------------------------------
1 | from typing import Tuple, Union, Dict
2 |
3 | import torch
4 | from e3nn import o3, nn
5 |
6 | from graph2mat.bindings.torch import TorchBasisMatrixData
7 | from ._utils import tp_out_irreps_with_instructions, scatter_sum
8 |
9 | __all__ = [
10 |     "E3nnInteraction",
11 |     "E3nnEdgeMessageBlock",
12 | ]
13 |
14 |
15 | class E3nnInteraction(torch.nn.Module):
16 |     """Basically MACE's RealAgnosticResidualInteractionBlock, without reshapes.
17 |
18 |     This function takes a graph and returns new states for the nodes.
19 |
20 |     This function can be used for the preprocessing step of both nodes and edges.
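
    Example
    -------
    A rough usage sketch. Here ``data`` is assumed to be an already-built
    ``TorchBasisMatrixData`` graph providing the ``edge_attrs``, ``edge_feats``
    and ``edge_index`` fields:

    >>> irreps = {
    ...     "node_feats_irreps": o3.Irreps("0e + 1o"),
    ...     "edge_attrs_irreps": o3.Irreps("0e + 1o"),
    ...     "edge_feats_irreps": o3.Irreps("8x0e"),
    ... }
    >>> interaction = E3nnInteraction(irreps, avg_num_neighbors=10)
    >>> node_feats = torch.randn(data.num_nodes, irreps["node_feats_irreps"].dim)
    >>> message = interaction(data, node_feats)  # [n_nodes, irreps_out.dim]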
21 | """ 22 | 23 | def __init__( 24 | self, 25 | irreps: Dict[str, o3.Irreps], 26 | avg_num_neighbors: float = 10, 27 | ) -> None: 28 | super().__init__() 29 | 30 | node_feats_irreps = irreps["node_feats_irreps"] 31 | # node_attrs_irreps = irreps["node_attrs_irreps"] 32 | edge_attrs_irreps = irreps["edge_attrs_irreps"] 33 | edge_feats_irreps = irreps["edge_feats_irreps"] 34 | target_irreps = irreps["node_feats_irreps"] 35 | # hidden_irreps = irreps["node_feats_irreps"] 36 | 37 | # First linear 38 | self.linear_up = o3.Linear( 39 | node_feats_irreps, 40 | node_feats_irreps, 41 | ) 42 | # TensorProduct 43 | irreps_mid, instructions = tp_out_irreps_with_instructions( 44 | node_feats_irreps, 45 | edge_attrs_irreps, 46 | target_irreps, 47 | ) 48 | self.conv_tp = o3.TensorProduct( 49 | node_feats_irreps, 50 | edge_attrs_irreps, 51 | irreps_mid, 52 | instructions=instructions, 53 | shared_weights=False, 54 | internal_weights=False, 55 | ) 56 | 57 | # Convolution weights 58 | input_dim = edge_feats_irreps.num_irreps 59 | self.conv_tp_weights = nn.FullyConnectedNet( 60 | [input_dim] + 3 * [64] + [self.conv_tp.weight_numel], 61 | torch.nn.SiLU(), 62 | ) 63 | 64 | # Linear 65 | irreps_mid = irreps_mid.simplify() 66 | self.irreps_mji = irreps_mid 67 | self.linear = o3.Linear( 68 | irreps_mid, target_irreps, internal_weights=True, shared_weights=True 69 | ) 70 | 71 | self.avg_num_neighbors = avg_num_neighbors 72 | 73 | self.irreps_out = target_irreps 74 | 75 | def forward( 76 | self, 77 | data: TorchBasisMatrixData, 78 | node_feats: torch.Tensor, 79 | ) -> Union[ 80 | Tuple[torch.Tensor, torch.Tensor], 81 | Tuple[torch.Tensor, torch.Tensor, torch.Tensor], 82 | ]: 83 | edge_attrs = data["edge_attrs"] 84 | edge_feats = data["edge_feats"] 85 | 86 | sender, receiver = data["edge_index"] 87 | num_nodes = node_feats.shape[0] 88 | 89 | node_feats = self.linear_up(node_feats) 90 | tp_weights = self.conv_tp_weights(edge_feats) 91 | mji = self.conv_tp( 92 | node_feats[sender], edge_attrs, tp_weights 93 | ) # [n_edges, irreps] 94 | del tp_weights 95 | message = scatter_sum( 96 | src=mji, index=receiver, dim=0, dim_size=num_nodes 97 | ) # [n_nodes, irreps] 98 | del mji 99 | message = self.linear(message) / self.avg_num_neighbors 100 | 101 | return message 102 | 103 | 104 | class E3nnEdgeMessageBlock(torch.nn.Module): 105 | """This is basically MACE's RealAgnosticResidualInteractionBlock, but only up to the part 106 | where it computes the partial mji messages. 107 | 108 | It computes a "message" for each edge in the graph. Note that the message 109 | is different for the edge (i, j) and the edge (j, i). 110 | 111 | This function can be used for the preprocessing step of edges. It has no effect when used 112 | as the preprocessing step of nodes. 
113 | """ 114 | 115 | def __init__( 116 | self, 117 | irreps: Dict[str, o3.Irreps], 118 | ) -> None: 119 | super().__init__() 120 | 121 | node_feats_irreps = irreps["node_feats_irreps"] 122 | edge_attrs_irreps = irreps["edge_attrs_irreps"] 123 | edge_feats_irreps = irreps["edge_feats_irreps"] 124 | target_irreps = irreps["edge_hidden_irreps"] 125 | 126 | # First linear 127 | self.linear_up = o3.Linear( 128 | node_feats_irreps, 129 | node_feats_irreps, 130 | ) 131 | 132 | # TensorProduct 133 | irreps_mid, instructions = tp_out_irreps_with_instructions( 134 | node_feats_irreps, 135 | edge_attrs_irreps, 136 | target_irreps, 137 | ) 138 | self.conv_tp = o3.TensorProduct( 139 | node_feats_irreps, 140 | edge_attrs_irreps, 141 | irreps_mid, 142 | instructions=instructions, 143 | shared_weights=False, 144 | internal_weights=False, 145 | ) 146 | 147 | # Convolution weights 148 | assert ( 149 | edge_feats_irreps.lmax == 0 150 | ), "Edge features must be a scalar array to preserve equivariance" 151 | input_dim = edge_feats_irreps.num_irreps 152 | self.conv_tp_weights = nn.FullyConnectedNet( 153 | [input_dim] + 3 * [64] + [self.conv_tp.weight_numel], 154 | torch.nn.SiLU(), 155 | ) 156 | 157 | irreps_mid = irreps_mid.simplify() 158 | 159 | self.linear = o3.Linear(irreps_mid, target_irreps) 160 | 161 | self.irreps_out = (None, target_irreps) 162 | 163 | def forward( 164 | self, 165 | data: TorchBasisMatrixData, 166 | node_feats: torch.Tensor, 167 | ) -> Tuple[None, torch.Tensor]: 168 | sender, receiver = data["edge_index"] 169 | 170 | edge_attrs = data["edge_attrs"] 171 | edge_feats = data["edge_feats"] 172 | 173 | node_feats = self.linear_up(node_feats) 174 | tp_weights = self.conv_tp_weights(edge_feats) 175 | mji = self.conv_tp( 176 | node_feats[sender], edge_attrs, tp_weights 177 | ) # [n_edges, irreps] 178 | 179 | del tp_weights 180 | 181 | # The first return is the node features [n_nodes, irreps], which we don't compute 182 | # The second return are the edge messages [n_edges, irreps] 183 | return None, self.linear(mji) 184 | -------------------------------------------------------------------------------- /src/graph2mat/bindings/e3nn/modules/tests/test_e3nngraph2mat.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from e3nn import o3 4 | import numpy as np 5 | import torch 6 | 7 | import copy 8 | 9 | from graph2mat import ( 10 | PointBasis, 11 | BasisTableWithEdges, 12 | MatrixDataProcessor, 13 | BasisConfiguration, 14 | ) 15 | from graph2mat.bindings.e3nn import E3nnGraph2Mat 16 | 17 | from graph2mat import conversions 18 | 19 | 20 | @pytest.fixture(scope="module", params=["point_type", "basis_shape", "max"]) 21 | def basis_grouping(request): 22 | return request.param 23 | 24 | 25 | @pytest.fixture(scope="module", params=[True, False]) 26 | def symmetric(request): 27 | return request.param 28 | 29 | 30 | def get_rotation_matrix(data, table, alpha, beta, gamma): 31 | R_0 = 1 32 | R_1 = o3.wigner_D(1, alpha, beta, gamma) 33 | 34 | dim = table.basis_size[data.point_types].sum() 35 | big_R = np.zeros((dim, dim)) 36 | 37 | for inv_els in (0, 1, 5, 9, 10): 38 | big_R[inv_els, inv_els] = R_0 39 | 40 | for start, end in ((2, 5), (6, 9), (11, 14)): 41 | big_R[start:end, start:end] = R_1 42 | 43 | return R_1, big_R 44 | 45 | 46 | @pytest.fixture(scope="module", params=[False, True]) 47 | def with_no_basis(request): 48 | """Whether the test should include a point with no basis.""" 49 | return request.param 50 | 51 | 52 | def 
test_equivariance(basis_grouping, symmetric, with_no_basis): 53 | # The basis 54 | point_1 = PointBasis("A", R=3, basis="2x0e + 1o", basis_convention="cartesian") 55 | point_2 = PointBasis("B", R=3, basis="0e + 1o", basis_convention="cartesian") 56 | point_3 = PointBasis("C", R=3, basis="2x0e + 1o", basis_convention="cartesian") 57 | 58 | basis = [point_1, point_2, point_3] 59 | 60 | # Add an extra point with no basis 61 | if with_no_basis: 62 | point_4 = PointBasis("F", R=3, basis_convention="cartesian") 63 | basis.append(point_4) 64 | 65 | # The basis table. 66 | table = BasisTableWithEdges(basis) 67 | # The data processor. 68 | processor = MatrixDataProcessor( 69 | basis_table=table, symmetric_matrix=symmetric, sub_point_matrix=False 70 | ) 71 | 72 | g2m = E3nnGraph2Mat( 73 | unique_basis=table, 74 | irreps={"node_feats_irreps": o3.Irreps("0e + 1o")}, 75 | symmetric=symmetric, 76 | basis_grouping=basis_grouping, 77 | ) 78 | 79 | point_types = ["A", "B", "C"] 80 | positions = [[1, 1, 1], [2.0, 1, 1], [4.0, 1, 1]] 81 | # Add an extra point with no basis. To make sure that things work 82 | # in the most general case, we add it in the middle of the list. 83 | if with_no_basis: 84 | point_types.insert(1, "F") 85 | positions.insert(1, [6.0, 1, 1]) 86 | 87 | config = BasisConfiguration( 88 | point_types=point_types, 89 | positions=np.array(positions), 90 | basis=basis, 91 | cell=np.eye(3) * 100, 92 | pbc=(False, False, False), 93 | ) 94 | 95 | conv = conversions.get_converter("basisconfiguration", "torch_basismatrixdata") 96 | data = conv(config, data_processor=processor) 97 | 98 | R1, big_R = get_rotation_matrix( 99 | data, table, torch.tensor(0), torch.tensor(0), torch.tensor(90 * np.pi / 180) 100 | ) 101 | 102 | data_rot = copy.copy(data) 103 | 104 | data_rot.positions = data.positions @ R1.T 105 | 106 | def get_mat(data): 107 | node_feats = torch.concatenate( 108 | [data.point_types.reshape(-1, 1) + 1, data.positions], axis=1 109 | ) 110 | nodes, edges = g2m(data, node_feats=node_feats) 111 | 112 | return processor.matrix_from_data( 113 | data, 114 | predictions={"node_labels": nodes, "edge_labels": edges}, 115 | out_format="numpy", 116 | ) 117 | 118 | pred = get_mat(data) 119 | pred_rot = get_mat(data_rot) 120 | 121 | post_rotated = big_R @ pred @ big_R.T 122 | 123 | max_diff = abs(post_rotated - pred_rot).max() 124 | 125 | assert max_diff < 1e-5, f"Equivariance error is too high: {max_diff:.2e}.\n" 126 | -------------------------------------------------------------------------------- /src/graph2mat/bindings/torch/__init__.py: -------------------------------------------------------------------------------- 1 | """Interface with pytorch. 2 | 3 | The interface with `torch` is simple to understand. Instead of 4 | using `numpy` we need to use `torch`. 5 | 6 | We wrap the `BasisMatrixData` and `Graph2Mat` to do that, generating 7 | `TorchBasisMatrixData` and `TorchGraph2Mat` respectively. Any framework 8 | that uses `torch` should use these classes instead of the original ones. 9 | 10 | Also, if extra bindings are needed for a framework that uses `torch`, 11 | they should take these bindings as a starting point. E.g. if bindings 12 | are implemented for `X`, `XGraph2Mat` should inherit from `TorchGraph2Mat`. 13 | 14 | These bindings contain no extra functionality, as all that we do is to 15 | make sure that the core functionality works with `torch` tensors. 
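
For example, to build a torch-based model you import the wrapped classes
instead of the core ones:

>>> from graph2mat.bindings.torch import TorchGraph2Mat, TorchBasisMatrixData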
16 | """ 17 | 18 | from .data import * 19 | from .modules import * 20 | -------------------------------------------------------------------------------- /src/graph2mat/bindings/torch/conftest.py: -------------------------------------------------------------------------------- 1 | # import pytest 2 | 3 | # from e3nn import o3 4 | # import numpy as np 5 | # from scipy.sparse import csr_matrix 6 | 7 | # from graph2mat.data.basis import PointBasis 8 | # from graph2mat.data.configuration import BasisConfiguration 9 | 10 | # from graph2mat.torch.modules import BasisMatrixReadout 11 | 12 | # from graph2mat import ( 13 | # PointBasis, 14 | # BasisTableWithEdges, 15 | # BasisConfiguration, 16 | # MatrixDataProcessor, 17 | # ) 18 | 19 | # from graph2mat.torch.data import TorchBasisMatrixData 20 | # from graph2mat.torch.modules import BasisMatrixReadout 21 | 22 | 23 | # @pytest.fixture(scope="module", params=["normal", "long_A", "nobasis_A"]) 24 | # def basis_type(request): 25 | # return request.param 26 | 27 | 28 | # @pytest.fixture(scope="module") 29 | # def ABA_basis_configuration(basis_type): 30 | # """Dummy basis configuration with""" 31 | 32 | # if basis_type == "nobasis_A": 33 | # point_1 = PointBasis("A", R=5) 34 | # else: 35 | # point_1 = PointBasis( 36 | # "A", 37 | # R=5 if basis_type == "long_A" else 2, 38 | # irreps=o3.Irreps("0e"), 39 | # basis_convention="spherical", 40 | # ) 41 | 42 | # point_2 = PointBasis("B", R=5, irreps=o3.Irreps("1o"), basis_convention="spherical") 43 | 44 | # positions = np.array([[0, 0, 0], [3.0, 0, 0], [5.0, 0, 0]]) 45 | 46 | # basis = [point_1, point_2] 47 | 48 | # config = BasisConfiguration( 49 | # point_types=["A", "B", "A"], 50 | # positions=positions, 51 | # basis=basis, 52 | # cell=None, 53 | # pbc=(False, False, False), 54 | # ) 55 | 56 | # return config 57 | -------------------------------------------------------------------------------- /src/graph2mat/bindings/torch/data/__init__.py: -------------------------------------------------------------------------------- 1 | """Wrappers for data handling in pytorch.""" 2 | 3 | from .data import * 4 | from .dataset import * 5 | from .formats import * 6 | -------------------------------------------------------------------------------- /src/graph2mat/bindings/torch/data/data.py: -------------------------------------------------------------------------------- 1 | """Implements the Data class to use in pytorch models.""" 2 | from typing import Any, Optional 3 | 4 | import torch 5 | 6 | from torch_geometric.data.data import Data 7 | 8 | from graph2mat.core.data import Formats, conversions 9 | from graph2mat.core.data.processing import BasisMatrixDataBase, BasisMatrixData 10 | 11 | __all__ = ["TorchBasisMatrixData"] 12 | 13 | 14 | class TorchBasisMatrixData(BasisMatrixDataBase[torch.Tensor], Data): 15 | """Extension of ``BasisMatrixDataBase`` to be used within pytorch. 16 | 17 | All this class implements is the conversion of numpy arrays to torch tensors 18 | and back. The rest of the functionality is inherited from ``BasisMatrixDataBase``. 19 | 20 | Please refer to the documentation of ``BasisMatrixDataBase`` for more information. 21 | 22 | See Also 23 | -------- 24 | graph2mat.BasisMatrixDataBase 25 | The class that implements the heavy lifting of the data processing. 
26 | """ 27 | 28 | _format = Formats.TORCH_BASISMATRIXDATA 29 | _data_format = Formats.TORCH_NODESEDGES 30 | _array_format = Formats.TORCH 31 | 32 | def __init__(self, *args, **kwargs): 33 | data = BasisMatrixDataBase._sanitize_data(self, **kwargs) 34 | Data.__init__(self, **data) 35 | 36 | def __getitem__(self, key: str) -> Any: 37 | return Data.__getitem__(self, key) 38 | 39 | @property 40 | def _data(self): 41 | return {**self._store} 42 | -------------------------------------------------------------------------------- /src/graph2mat/bindings/torch/data/dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from torch import multiprocessing 3 | from pathlib import Path 4 | from typing import Sequence, Union, Type, Optional, TypeVar, Generic 5 | import threading 6 | 7 | import numpy as np 8 | import sisl 9 | 10 | import torch.utils.data 11 | 12 | from graph2mat import BasisConfiguration, MatrixDataProcessor 13 | from .data import TorchBasisMatrixData 14 | 15 | __all__ = ["TorchBasisMatrixDataset", "InMemoryData", "RotatingPoolData"] 16 | 17 | 18 | class TorchBasisMatrixDataset(torch.utils.data.Dataset): 19 | """Stores all configuration info of a dataset. 20 | 21 | Given all of its arguments, it has information to generate all the 22 | `BasisMatrixTorchData` objects. However, **the objects are created on the fly 23 | as they are requested**. They are not stored by this class. 24 | 25 | `torch_geometric's` data loader can be used out of the box to load data from 26 | this dataset. 27 | 28 | Parameters 29 | ---------- 30 | input_data : 31 | A list of input data. Each item can be of any kind that is possible to 32 | convert to the class specified by `data_cls` using the `new` method. 33 | data_processor : 34 | A data processor object that is passed to `data_cls.new` to assist with 35 | the creation of the data objects from the `input_data`. 36 | data_cls : 37 | The class of the data objects that will be generated from this dataset. 38 | Must have a `new` method that takes the `input_data` and `data_processor` 39 | as arguments to create a new object. The `new` method also receives a 40 | `labels` argument specifying whether matrix labels should be loaded 41 | or not for the configurations. 42 | load_labels : 43 | Whether to load the matrix labels or not. 44 | 45 | See Also 46 | -------- 47 | InMemoryData 48 | A wrapper for a dataset that loads all data into memory. 49 | RotatingPoolData 50 | A wrapper for a dataset that continously loads data into a smaller pool. 51 | 52 | Examples 53 | -------- 54 | 55 | .. code-block:: python 56 | 57 | from graph2mat import BasisConfiguration, MatrixDataProcessor 58 | from graph2mat.bindings.torch import 59 | 60 | # Initialize basis configurations (substitute ... by appropriate arguments) 61 | config_1 = BasisConfiguration(...) 62 | config_2 = BasisConfiguration(...) 63 | 64 | # Initialize data processor (substitute ... by appropriate arguments) 65 | processor = MatrixDataProcessor(...) 
66 | 67 | # Initialize dataset 68 | dataset = TorchBasisMatrixDataset([config_1, config_2], processor) 69 | 70 | # Import the loader class from torch_geometric 71 | from torch_geometric.loader import DataLoader 72 | 73 | # Create a data loader from this dataset 74 | loader = DataLoader(dataset, batch_size=2) 75 | 76 | """ 77 | 78 | def __init__( 79 | self, 80 | input_data: Sequence[Union[BasisConfiguration, Path, str, sisl.Geometry]], 81 | data_processor: MatrixDataProcessor, 82 | data_cls: Type[TorchBasisMatrixData] = TorchBasisMatrixData, 83 | load_labels: bool = True, 84 | ): 85 | self.input_data = input_data 86 | self.data_processor = data_processor 87 | self.data_cls = data_cls 88 | self.load_labels = load_labels 89 | 90 | def __len__(self): 91 | return len(self.input_data) 92 | 93 | def __getitem__(self, index: int) -> TorchBasisMatrixData: 94 | item = self.input_data[index] 95 | 96 | return self.data_cls.new( 97 | item, data_processor=self.data_processor, labels=self.load_labels 98 | ) 99 | 100 | 101 | class InMemoryData(torch.utils.data.Dataset): 102 | """Wrapper for a dataset that loads all data into memory. 103 | 104 | Parameters 105 | ---------- 106 | dataset: 107 | The dataset to wrap. 108 | size: 109 | If not None, it truncates the dataset to the given size. 110 | """ 111 | 112 | def __init__( 113 | self, dataset: TorchBasisMatrixDataset, size: Optional[int] = None, **kwargs 114 | ): 115 | super().__init__(**kwargs) 116 | size = size or len(dataset) 117 | self.data_objects = [dataset[i] for i in range(size)] 118 | 119 | def __len__(self): 120 | return len(self.data_objects) 121 | 122 | def __getitem__(self, index: int) -> TorchBasisMatrixData: 123 | return self.data_objects[index] 124 | 125 | 126 | class SimpleCounter: 127 | def __init__(self): 128 | self.reset() 129 | 130 | def inc(self): 131 | self.count += 1 132 | 133 | def reset(self): 134 | self.count = 0 135 | 136 | def get_count(self): 137 | return self.count 138 | 139 | 140 | def _rotating_pool_worker(dataset, rng, queue): 141 | while True: 142 | for index in rng.permutation(len(dataset)).tolist(): 143 | queue.put(dataset[index]) 144 | 145 | 146 | def _transfer_thread( 147 | queue: multiprocessing.Queue, datalist: list, counter: SimpleCounter 148 | ): 149 | while True: 150 | for index in range(len(datalist)): 151 | datalist[index] = queue.get() 152 | counter.inc() 153 | 154 | 155 | class RotatingPoolData(torch.utils.data.Dataset): 156 | """Wrapper for a dataset that continously loads data into a smaller pool. 157 | 158 | The data loading is performed in a separate process and is assumed to be IO bound. 159 | 160 | Parameters 161 | ---------- 162 | dataset: 163 | The dataset to wrap. 164 | pool_size: 165 | The size of the pool to keep in memory. 
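
    Examples
    --------
    A sketch of wrapping an existing dataset so that only ``pool_size``
    examples are kept in memory at any given time (``dataset`` is assumed
    to be a ``TorchBasisMatrixDataset``):

    >>> pooled = RotatingPoolData(dataset, pool_size=100)
    >>> len(pooled)
    100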
166 | """ 167 | 168 | def __init__(self, dataset: TorchBasisMatrixDataset, pool_size: int, **kwargs): 169 | super().__init__(**kwargs) 170 | self.pool_size = pool_size 171 | self.parent_data = dataset 172 | self.rng = np.random.default_rng() 173 | self.counter = SimpleCounter() 174 | self.manager = multiprocessing.Manager() 175 | logging.debug("Filling rotating data pool of size %d" % pool_size) 176 | data_list = [ 177 | self.parent_data[i] 178 | for i in self.rng.integers( 179 | 0, high=len(self.parent_data), size=self.pool_size, endpoint=False 180 | ).tolist() 181 | ] 182 | self.data_pool = self.manager.list(data_list) 183 | self.loader_queue = multiprocessing.Queue(2) 184 | 185 | # Start loaders 186 | self.loader_process = multiprocessing.Process( 187 | target=_rotating_pool_worker, 188 | args=(self.parent_data, self.rng, self.loader_queue), 189 | daemon=True, 190 | ) 191 | self.transfer_thread = threading.Thread( 192 | target=_transfer_thread, 193 | args=(self.loader_queue, self.data_pool, self.counter), 194 | daemon=True, 195 | ) 196 | self.loader_process.start() 197 | self.transfer_thread.start() 198 | 199 | def __len__(self): 200 | return self.pool_size 201 | 202 | def __getitem__(self, index: int) -> TorchBasisMatrixData: 203 | return self.data_pool[index] 204 | 205 | def get_data_pool(self): 206 | """ 207 | Get the minimal dataset handle object for transfering to dataloader workers 208 | 209 | Returns 210 | ------- 211 | Multiprocessing proxy data object 212 | 213 | """ 214 | return self.data_pool 215 | -------------------------------------------------------------------------------- /src/graph2mat/bindings/torch/data/formats.py: -------------------------------------------------------------------------------- 1 | """Extensions to the registered formats/conversions for ``torch`` tensors.""" 2 | 3 | from typing import Optional 4 | 5 | import torch 6 | import numpy as np 7 | 8 | from graph2mat.core.data import conversions, Formats 9 | from graph2mat.core.data.sparse import _nodes_and_edges_to_coo 10 | 11 | Formats.add_alias(Formats.TORCH, torch.Tensor) 12 | Formats.add_alias(Formats.TORCH_COO, torch.sparse_coo_tensor) 13 | Formats.add_alias(Formats.TORCH_CSR, torch.sparse_csr_tensor) 14 | 15 | converter = conversions.converter 16 | 17 | 18 | @converter 19 | def _coo_to_csr(coo: torch.sparse_coo_tensor) -> torch.sparse_csr_tensor: 20 | return coo.to_sparse_csr() 21 | 22 | 23 | @converter 24 | def _coo_to_dense(coo: torch.sparse_coo_tensor) -> torch.Tensor: 25 | return coo.to_dense() 26 | 27 | 28 | @converter 29 | def _csr_to_coo(csr: torch.sparse_csr_tensor) -> torch.sparse_coo_tensor: 30 | return csr.to_sparse_coo() 31 | 32 | 33 | @converter 34 | def _csr_to_dense(csr: torch.sparse_csr_tensor) -> torch.Tensor: 35 | return csr.to_dense() 36 | 37 | 38 | @converter 39 | def _torch_to_numpy(tensor: torch.Tensor) -> np.ndarray: 40 | return tensor.numpy(force=True) 41 | 42 | 43 | @converter 44 | def _numpy_to_torch(array: np.ndarray) -> torch.Tensor: 45 | if issubclass(array.dtype.type, float): 46 | return torch.tensor(array, dtype=torch.get_default_dtype()) 47 | else: 48 | return torch.tensor(array) 49 | 50 | 51 | @converter(Formats.TORCH_NODESEDGES, Formats.TORCH_COO) 52 | def nodes_and_edges_to_coo( 53 | node_vals: torch.Tensor, 54 | edge_vals: torch.Tensor, 55 | edge_index: torch.Tensor, 56 | orbitals: torch.Tensor, 57 | n_supercells: int = 1, 58 | edge_neigh_isc: Optional[torch.Tensor] = None, 59 | threshold: Optional[float] = None, 60 | symmetrize_edges: bool = False, 61 | ) -> 
torch.sparse_coo_tensor: 62 | """Converts an orbital matrix from node and edges array to torch coo. 63 | 64 | Conversions to any other sparse structure can be done once we've got the coo array. 65 | 66 | Parameters 67 | ---------- 68 | node_vals 69 | Flat array containing the values of the node blocks. 70 | The order of the values is first by node index, then row then column. 71 | edge_vals 72 | Flat array containing the values of the edge blocks. 73 | The order of the values is first by edge index, then row then column. 74 | edge_index 75 | Array of shape (2, n_edges) containing the indices of the atoms 76 | that participate in each edge. 77 | orbitals 78 | Array of shape (n_nodes, ) containing the number of orbitals for each atom. 79 | n_supercells 80 | Number of auxiliary supercells. 81 | edge_neigh_isc 82 | Array of shape (n_edges, ) containing the supercell index of the second atom 83 | in each edge with respect to the first atom. 84 | If not provided, all interactions are assumed to be in the unit cell. 85 | threshold 86 | Matrix elements with a value below this number are set to 0. 87 | symmetrize_edges 88 | whether for each edge only one direction is provided. The edge block for the 89 | opposite direction is then created as the transpose. 90 | """ 91 | 92 | def _init_coo(data, rows, cols, shape): 93 | return torch.sparse_coo_tensor( 94 | torch.stack([torch.tensor(rows), torch.tensor(cols)]), data, shape 95 | ) 96 | 97 | return _nodes_and_edges_to_coo( 98 | concatenate=torch.concatenate, 99 | init_coo=_init_coo, 100 | node_vals=node_vals, 101 | edge_vals=edge_vals, 102 | edge_index=edge_index, 103 | orbitals=orbitals, 104 | n_supercells=n_supercells, 105 | edge_neigh_isc=edge_neigh_isc, 106 | threshold=threshold, 107 | symmetrize_edges=symmetrize_edges, 108 | ) 109 | -------------------------------------------------------------------------------- /src/graph2mat/bindings/torch/data/tests/test_data.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from graph2mat import ( 7 | PointBasis, 8 | BasisTableWithEdges, 9 | MatrixDataProcessor, 10 | BasisConfiguration, 11 | OrbitalConfiguration, 12 | ) 13 | 14 | from graph2mat.bindings.torch import TorchBasisMatrixData, TorchBasisMatrixDataset 15 | 16 | 17 | @pytest.fixture(scope="module") 18 | def positions(): 19 | return np.array([[0, 0, 0], [6.0, 0, 0]]) 20 | 21 | 22 | @pytest.fixture(scope="module", params=["cartesian", "spherical", "siesta_spherical"]) 23 | def basis_convention(request): 24 | return request.param 25 | 26 | 27 | @pytest.fixture(scope="module") 28 | def basis_table(basis_convention): 29 | point_1 = PointBasis("A", R=2, basis=[1], basis_convention=basis_convention) 30 | point_2 = PointBasis("B", R=5, basis=[2, 1], basis_convention=basis_convention) 31 | 32 | return BasisTableWithEdges([point_1, point_2]) 33 | 34 | 35 | @pytest.mark.parametrize("config_cls", [BasisConfiguration, OrbitalConfiguration]) 36 | @pytest.mark.parametrize("new_method", ["from_config", "new"]) 37 | def test_init_data(positions, basis_table, basis_convention, new_method, config_cls): 38 | # The data processor. 
39 | processor = MatrixDataProcessor( 40 | basis_table=basis_table, symmetric_matrix=True, sub_point_matrix=False 41 | ) 42 | 43 | config = config_cls( 44 | point_types=["A", "B"], 45 | positions=positions, 46 | basis=basis_table, 47 | cell=np.eye(3) * 100, 48 | pbc=(False, False, False), 49 | ) 50 | 51 | # Test from_config method 52 | new = getattr(TorchBasisMatrixData, new_method) 53 | data = new(config, processor) 54 | 55 | assert isinstance(data.positions, torch.Tensor) 56 | 57 | if basis_convention == "cartesian": 58 | assert np.all(data.positions.numpy() == positions) 59 | else: 60 | assert (data.positions.numpy() != positions).sum() == 2 61 | 62 | 63 | @pytest.mark.parametrize("config_cls", [BasisConfiguration, OrbitalConfiguration]) 64 | @pytest.mark.parametrize("new_method", ["from_config", "new"]) 65 | def test_dataset(positions, basis_table, basis_convention, new_method, config_cls): 66 | # The data processor. 67 | processor = MatrixDataProcessor( 68 | basis_table=basis_table, symmetric_matrix=True, sub_point_matrix=False 69 | ) 70 | 71 | config_1 = config_cls( 72 | point_types=["A", "B"], 73 | positions=positions, 74 | basis=basis_table, 75 | cell=np.eye(3) * 100, 76 | pbc=(False, False, False), 77 | ) 78 | 79 | config_2 = config_cls( 80 | point_types=["B", "A"], 81 | positions=positions, 82 | basis=basis_table, 83 | cell=np.eye(3) * 100, 84 | pbc=(False, False, False), 85 | ) 86 | 87 | dataset = TorchBasisMatrixDataset([config_1, config_2], data_processor=processor) 88 | 89 | assert len(dataset) == 2 90 | 91 | assert isinstance(dataset[1].positions, torch.Tensor) 92 | 93 | if basis_convention == "cartesian": 94 | assert np.all(dataset[1].positions.numpy() == positions) 95 | else: 96 | assert (dataset[1].positions.numpy() != positions).sum() == 2 97 | pass 98 | -------------------------------------------------------------------------------- /src/graph2mat/bindings/torch/load.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Type, Optional, Tuple 2 | 3 | from pathlib import Path 4 | 5 | from graph2mat import MatrixDataProcessor 6 | 7 | import torch 8 | 9 | 10 | def sanitize_checkpoint(checkpoint: dict) -> dict: 11 | """Makes sure that the checkpoint is compatible with the current version of e3nn_matrix.""" 12 | 13 | checkpoint = checkpoint.copy() 14 | 15 | if "z_table" in checkpoint: 16 | checkpoint["basis_table"] = checkpoint.pop("z_table") 17 | 18 | data_ckpt = checkpoint["datamodule_hyper_parameters"] 19 | model_ckpt = checkpoint["hyper_parameters"] 20 | 21 | for sub_ckpt in (data_ckpt, model_ckpt): 22 | if "z_table" in sub_ckpt: 23 | sub_ckpt["basis_table"] = sub_ckpt.pop("z_table") 24 | if "unique_atoms" in sub_ckpt: 25 | sub_ckpt["unique_basis"] = sub_ckpt.pop("unique_atoms") 26 | if "sub_atomic_matrix" in sub_ckpt: 27 | sub_ckpt["sub_point_matrix"] = sub_ckpt.pop("sub_atomic_matrix") 28 | 29 | return checkpoint 30 | 31 | 32 | def load_from_lit_ckpt( 33 | ckpt_file: Union[Path, str], 34 | cpu: bool = True, 35 | as_torch: bool = False, 36 | ) -> Tuple[torch.nn.Module, MatrixDataProcessor]: 37 | """Load a model from a Lightning checkpoint file. 38 | 39 | Parameters 40 | ---------- 41 | ckpt_file : Union[Path, str] 42 | Path to the checkpoint file. 43 | cpu : bool, optional 44 | If True, the model is loaded on the CPU regardless of whether 45 | it was in the GPU when saved, by default True. 
46 | as_torch : bool, optional 47 | If True, the model is returned as the bare torch.nn.Module, otherwise it is returned as a lightning module. 48 | 49 | Returns 50 | ------- 51 | torch.nn.Module 52 | The model 53 | MatrixDataProcessor 54 | The processor to use for processing inputs and outputs. 55 | """ 56 | from graph2mat.tools.lightning.models.mace import LitMACEMatrixModel 57 | 58 | ckpt = torch.load( 59 | ckpt_file, map_location="cpu" if cpu else None, weights_only=False 60 | ) 61 | 62 | ckpt = sanitize_checkpoint(ckpt) 63 | 64 | model = LitMACEMatrixModel.load_from_checkpoint( 65 | ckpt_file, 66 | basis_table=ckpt["basis_table"], 67 | map_location="cpu" if cpu else None, 68 | ) 69 | 70 | data_processor = MatrixDataProcessor( 71 | out_matrix=ckpt["datamodule_hyper_parameters"]["out_matrix"], 72 | sub_point_matrix=ckpt["datamodule_hyper_parameters"]["sub_point_matrix"], 73 | symmetric_matrix=ckpt["datamodule_hyper_parameters"]["symmetric_matrix"], 74 | basis_table=ckpt["basis_table"], 75 | node_attr_getters=model.initial_node_feats, 76 | ) 77 | 78 | if as_torch: 79 | model = model.model 80 | 81 | return model, data_processor 82 | -------------------------------------------------------------------------------- /src/graph2mat/bindings/torch/modules/__init__.py: -------------------------------------------------------------------------------- 1 | """Wrappers for graph2mat modules in pytorch. 2 | 3 | Torch does not add extra functionality to `graph2mat modules`, we just 4 | need to wrap the core functionality to work with `torch` tensors instead 5 | of `numpy`. 6 | """ 7 | 8 | from .graph2mat import * 9 | from .matrixblock import * 10 | -------------------------------------------------------------------------------- /src/graph2mat/bindings/torch/modules/graph2mat.py: -------------------------------------------------------------------------------- 1 | """Torch wrappers for Graph2Mat.""" 2 | import numpy as np 3 | import torch 4 | from types import ModuleType 5 | 6 | from graph2mat import Graph2Mat 7 | 8 | __all__ = ["TorchGraph2Mat"] 9 | 10 | 11 | class TorchGraph2Mat(Graph2Mat, torch.nn.Module): 12 | """Wrapper for Graph2Mat to make it use torch instead of numpy. 13 | 14 | It also makes `Graph2Mat` a `torch.nn.Module`, and it makes it 15 | store the list of node block functions as a `torch.nn.ModuleList` 16 | and the dictionary of edge block functions as a `torch.nn.ModuleDict`. 17 | 18 | Parameters 19 | ---------- 20 | **kwargs: 21 | Additional arguments passed to the `Graph2Mat` class. 22 | 23 | See Also 24 | -------- 25 | Graph2Mat 26 | The class that `TorchGraph2Mat` extends. Its documentation contains a more 27 | detailed explanation of the inner workings of the class. 
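
    Examples
    --------
    ``TorchGraph2Mat`` takes exactly the same arguments as ``Graph2Mat``, it
    only swaps the array backend. A hypothetical sketch, with ``table`` being
    a ``BasisTableWithEdges`` and assuming the default block operations:

    >>> g2m = TorchGraph2Mat(unique_basis=table, symmetric=True)
    >>> isinstance(g2m, torch.nn.Module)
    True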
28 | """ 29 | 30 | def __init__( 31 | self, 32 | *args, 33 | numpy: ModuleType = torch, 34 | self_interactions_list=torch.nn.ModuleList, 35 | interactions_dict=torch.nn.ModuleDict, 36 | **kwargs, 37 | ): 38 | super().__init__( 39 | *args, 40 | **kwargs, 41 | numpy=numpy, 42 | self_interactions_list=self_interactions_list, 43 | interactions_dict=interactions_dict, 44 | ) 45 | 46 | def _init_center_types(self, basis_grouping): 47 | super()._init_center_types(basis_grouping) 48 | 49 | for k in ("types_to_graph2mat", "edge_types_to_graph2mat"): 50 | array = getattr(self, k, None) 51 | if isinstance(array, np.ndarray): 52 | # Register the buffer as a torch tensor 53 | tensor = torch.from_numpy(getattr(self, k)) 54 | delattr(self, k) 55 | self.register_buffer(k, tensor, persistent=False) 56 | 57 | def _get_labels_resort_index( 58 | self, types: torch.Tensor, original_types: torch.Tensor, **kwargs 59 | ) -> torch.Tensor: 60 | """Wrapping of the method to use torch instead of numpy.""" 61 | types = types.numpy(force=True) 62 | original_types = original_types.numpy(force=True) 63 | 64 | indices = super()._get_labels_resort_index( 65 | types, original_types=original_types, **kwargs 66 | ) 67 | 68 | if self.basis_grouping != "max": 69 | indices = self.numpy.from_numpy(indices) 70 | 71 | return indices 72 | -------------------------------------------------------------------------------- /src/graph2mat/bindings/torch/modules/matrixblock.py: -------------------------------------------------------------------------------- 1 | """Torch wrappers for matrix block.""" 2 | import torch 3 | 4 | from graph2mat import MatrixBlock 5 | 6 | __all__ = ["TorchMatrixBlock"] 7 | 8 | 9 | class TorchMatrixBlock(MatrixBlock, torch.nn.Module): 10 | """Wrapper for matrix block to make it use torch instead of numpy.""" 11 | 12 | numpy = torch 13 | -------------------------------------------------------------------------------- /src/graph2mat/bindings/torch/modules/tests/test_basis_matrix.py: -------------------------------------------------------------------------------- 1 | # from e3nn import o3 2 | 3 | # from scipy.sparse import csr_matrix 4 | # import torch 5 | 6 | # from graph2mat.data.configuration import BasisConfiguration 7 | 8 | # from graph2mat.torch.modules import ( 9 | # BasisMatrixReadout, 10 | # SimpleEdgeBlock, 11 | # SimpleNodeBlock, 12 | # ) 13 | 14 | # from graph2mat.data.table import BasisTableWithEdges 15 | # from graph2mat.data.configuration import BasisConfiguration 16 | # from graph2mat.data.processing import MatrixDataProcessor 17 | 18 | # from graph2mat.torch.data import BasisMatrixTorchData 19 | # from graph2mat.torch.modules import BasisMatrixReadout 20 | 21 | 22 | # def test_irreps_in(ABA_basis_configuration: BasisConfiguration): 23 | # config = ABA_basis_configuration 24 | # basis = ABA_basis_configuration.basis 25 | 26 | # input_irreps = o3.Irreps("0e + 1o") 27 | 28 | # readout = BasisMatrixReadout( 29 | # unique_basis=basis, 30 | # node_operation_kwargs={"irreps_in": input_irreps}, 31 | # edge_operation_kwargs={ 32 | # "irreps_in": input_irreps, 33 | # }, 34 | # symmetric=True, 35 | # ) 36 | 37 | # readout2 = BasisMatrixReadout( 38 | # unique_basis=basis, 39 | # irreps_in=input_irreps, 40 | # symmetric=True, 41 | # ) 42 | 43 | # assert str(readout) == str(readout2) 44 | 45 | 46 | # def test_readout(ABA_basis_configuration: BasisConfiguration, basis_type: str): 47 | # config = ABA_basis_configuration 48 | # basis = ABA_basis_configuration.basis 49 | 50 | # input_irreps = o3.Irreps("0e + 1o") 51 | 
52 | # readout = BasisMatrixReadout( 53 | # unique_basis=basis, 54 | # node_operation_kwargs={"irreps_in": input_irreps}, 55 | # edge_operation_kwargs={ 56 | # "irreps_in": input_irreps, 57 | # }, 58 | # symmetric=True, 59 | # ) 60 | 61 | # # Create the basis table. 62 | # table = BasisTableWithEdges(basis) 63 | 64 | # # Initialize the processor. 65 | # processor = MatrixDataProcessor( 66 | # basis_table=table, symmetric_matrix=True, sub_point_matrix=False 67 | # ) 68 | 69 | # data = BasisMatrixTorchData.from_config(config, processor) 70 | 71 | # node_state = input_irreps.randn(3, -1, requires_grad=True) 72 | 73 | # node_labels, edge_labels = readout.forward( 74 | # node_types=data["point_types"], 75 | # edge_index=data["edge_index"], 76 | # edge_types=data["edge_types"], 77 | # edge_type_nlabels=data["edge_type_nlabels"], 78 | # node_operation_node_kwargs={ 79 | # "state": node_state, 80 | # }, 81 | # edge_operation_node_kwargs={ 82 | # "node_state": node_state, 83 | # }, 84 | # ) 85 | 86 | # matrix = processor.matrix_from_data( 87 | # data, 88 | # {"node_labels": node_labels, "edge_labels": edge_labels}, 89 | # ) 90 | 91 | # assert isinstance(matrix, csr_matrix) 92 | # assert matrix.shape == (5, 5) if basis_type != "nobasis_A" else (3, 3) 93 | # assert matrix.nnz == {"normal": 23, "long_A": 25, "nobasis_A": 9}[basis_type] 94 | 95 | 96 | # def test_readout_filtering( 97 | # ABA_basis_configuration: BasisConfiguration, basis_type: str 98 | # ): 99 | # config = ABA_basis_configuration 100 | # basis = ABA_basis_configuration.basis 101 | 102 | # input_irreps = o3.Irreps("0e + 1o") 103 | 104 | # class EdgeChecker(SimpleEdgeBlock): 105 | # """Extension of SimpleEdgeBlock that edge kwargs and node kwargs have been correctly filtered.""" 106 | 107 | # def forward(self, edge_types, node_types, **kwargs): 108 | # assert isinstance(edge_types, tuple) 109 | # assert len(edge_types) == 2 110 | # assert all(isinstance(x, torch.Tensor) for x in edge_types) 111 | # assert torch.all(edge_types[0] == edge_types[0][0, 0]) 112 | # assert torch.all(edge_types[0] == -edge_types[1]) 113 | 114 | # assert isinstance(node_types, tuple) 115 | # assert len(node_types) == 2 116 | # assert all(isinstance(x, torch.Tensor) for x in node_types) 117 | # assert torch.all(node_types[0] == node_types[0][0, 0]) 118 | # assert torch.all(node_types[1] == node_types[1][0, 0]) 119 | 120 | # return super().forward(**kwargs) 121 | 122 | # class NodeChecker(SimpleNodeBlock): 123 | # """Extension of SimpleEdgeBlock that edge kwargs and node kwargs have been correctly filtered.""" 124 | 125 | # def forward(self, node_types, **kwargs): 126 | # assert isinstance(node_types, torch.Tensor) 127 | # assert torch.all(node_types == node_types[0, 0]) 128 | 129 | # return super().forward(**kwargs) 130 | 131 | # readout = BasisMatrixReadout( 132 | # unique_basis=basis, 133 | # node_operation=NodeChecker, 134 | # node_operation_kwargs={"irreps_in": input_irreps}, 135 | # edge_operation=EdgeChecker, 136 | # edge_operation_kwargs={ 137 | # "irreps_in": input_irreps, 138 | # }, 139 | # symmetric=True, 140 | # ) 141 | 142 | # # Create the basis table. 143 | # table = BasisTableWithEdges(basis) 144 | 145 | # # Initialize the processor. 
146 | # processor = MatrixDataProcessor( 147 | # basis_table=table, symmetric_matrix=True, sub_point_matrix=False 148 | # ) 149 | 150 | # data = BasisMatrixTorchData.from_config(config, processor) 151 | 152 | # node_state = input_irreps.randn(3, -1, requires_grad=True) 153 | 154 | # edge_types = torch.tensor( 155 | # [*data["edge_types"]], dtype=torch.get_default_dtype() 156 | # ).reshape(-1, 1) 157 | # node_types = torch.tensor( 158 | # [*data["point_types"]], dtype=torch.get_default_dtype() 159 | # ).reshape(-1, 1) 160 | 161 | # node_labels, edge_labels = readout.forward( 162 | # node_types=data["point_types"], 163 | # edge_index=data["edge_index"], 164 | # edge_types=data["edge_types"], 165 | # edge_type_nlabels=data["edge_type_nlabels"], 166 | # node_operation_node_kwargs={ 167 | # "state": node_state, 168 | # "node_types": node_types, 169 | # }, 170 | # edge_operation_node_kwargs={ 171 | # "node_state": node_state, 172 | # "node_types": node_types, 173 | # }, 174 | # edge_kwargs={ 175 | # "edge_types": edge_types, 176 | # }, 177 | # ) 178 | 179 | # matrix = processor.matrix_from_data( 180 | # data, 181 | # {"node_labels": node_labels, "edge_labels": edge_labels}, 182 | # ) 183 | 184 | # assert isinstance(matrix, csr_matrix) 185 | # assert matrix.shape == (5, 5) if basis_type != "nobasis_A" else (3, 3) 186 | # assert matrix.nnz == {"normal": 23, "long_A": 25, "nobasis_A": 9}[basis_type] 187 | -------------------------------------------------------------------------------- /src/graph2mat/bindings/torch/tests/test_orbital_matrix.py: -------------------------------------------------------------------------------- 1 | # """Tests for input preparation""" 2 | # from graph2mat.data.sparse import csr_to_block_dict 3 | # from graph2mat.data.configuration import OrbitalConfiguration 4 | # from graph2mat.torch.data import BasisMatrixTorchData 5 | 6 | 7 | # def test_orbital_matrix_data(density_matrix, density_data_processor): 8 | # # For now we just test that we can get an OrbitalMatrixData object 9 | # # with and without matrix, and nothing breaks. 
10 | 11 | # block_dict = csr_to_block_dict( 12 | # density_matrix._csr, density_matrix.atoms, nsc=density_matrix.nsc 13 | # ) 14 | # config = OrbitalConfiguration.from_geometry( 15 | # geometry=density_matrix.geometry, matrix=block_dict 16 | # ) 17 | # data = BasisMatrixTorchData.from_config( 18 | # config, data_processor=density_data_processor 19 | # ) 20 | 21 | # no_matrix_config = OrbitalConfiguration.from_geometry( 22 | # geometry=density_matrix.geometry, matrix=block_dict 23 | # ) 24 | # no_matrix_data = BasisMatrixTorchData.from_config( 25 | # config, data_processor=density_data_processor 26 | # ) 27 | -------------------------------------------------------------------------------- /src/graph2mat/conftest.py: -------------------------------------------------------------------------------- 1 | """Tests for sparse structure conversion""" 2 | import pytest 3 | import sisl 4 | from numpy.random import RandomState 5 | import numpy as np 6 | 7 | from graph2mat.core.data.sparse import csr_to_block_dict 8 | from graph2mat import MatrixDataProcessor, OrbitalConfiguration, AtomicTableWithEdges 9 | from graph2mat.bindings.torch import TorchBasisMatrixData 10 | 11 | 12 | @pytest.fixture(scope="session", params=[True, False]) 13 | def periodic(request): 14 | return request.param 15 | 16 | 17 | @pytest.fixture(scope="session", params=[True, False]) 18 | def symmetric(request): 19 | return request.param 20 | 21 | 22 | @pytest.fixture(scope="session") 23 | def density_matrix(periodic, symmetric): 24 | rs = RandomState(32) 25 | 26 | r = np.linspace(0, 3) 27 | f = np.exp(-r) 28 | 29 | C = sisl.Atom( 30 | "C", 31 | orbitals=[ 32 | sisl.AtomicOrbital("2s", (r, f), q0=2), 33 | sisl.AtomicOrbital("2px", (r, f), q0=0.666), 34 | sisl.AtomicOrbital("2pz", (r, f), q0=0.666), 35 | sisl.AtomicOrbital("2py", (r, f), q0=0.666), 36 | ], 37 | ) 38 | 39 | N = sisl.Atom( 40 | "N", 41 | orbitals=[ 42 | sisl.AtomicOrbital("2s", (r, f), q0=2), 43 | sisl.AtomicOrbital("2px", (r, f), q0=1), 44 | sisl.AtomicOrbital("2pz", (r, f), q0=1), 45 | sisl.AtomicOrbital("2py", (r, f), q0=1), 46 | # sisl.AtomicOrbital("2px", R=10), sisl.AtomicOrbital("2pz", R=10), sisl.AtomicOrbital("2py", R=10), 47 | ], 48 | ) 49 | 50 | geom = sisl.geom.graphene_nanoribbon(width=3, atoms=[C, N]) 51 | if not periodic: 52 | # Don't consider periodicity 53 | geom.cell[0, 0] = 20 54 | geom.set_nsc([1, 1, 1]) 55 | else: 56 | geom.set_nsc([5, 1, 1]) 57 | 58 | dm = sisl.DensityMatrix( 59 | geom, 60 | ) 61 | 62 | rows = dm.geometry.firsto[:-1] 63 | cols = np.tile(rows, dm.n_s).reshape(dm.n_s, -1) + ( 64 | np.arange(dm.n_s) * dm.no 65 | ).reshape(-1, 1) 66 | cols = cols.ravel() 67 | 68 | vals = (rs.random(cols.shape[0]) * 2) - 1 69 | for row in rows: 70 | dists = dm.geometry.rij(dm.o2a(row), dm.o2a(cols)) 71 | dm[row, cols[dists < 6]] = vals[dists < 6] 72 | 73 | if symmetric: # and not periodic: 74 | dm = (dm + dm.transpose()) / 2 75 | 76 | return dm 77 | 78 | 79 | @pytest.fixture(scope="session") 80 | def density_z_table(density_matrix): 81 | return AtomicTableWithEdges(density_matrix.atoms.atom) 82 | 83 | 84 | @pytest.fixture(scope="session") 85 | def density_data_processor(density_z_table, symmetric): 86 | return MatrixDataProcessor( 87 | basis_table=density_z_table, 88 | sub_point_matrix=False, 89 | symmetric_matrix=symmetric, 90 | out_matrix="density_matrix", 91 | ) 92 | 93 | 94 | @pytest.fixture(scope="session") 95 | def density_config(density_matrix): 96 | geometry = density_matrix.geometry 97 | 98 | dm_block = csr_to_block_dict( 99 | 
density_matrix._csr, density_matrix.atoms, nsc=density_matrix.nsc
100 |     )
101 | 
102 |     return OrbitalConfiguration.from_geometry(geometry=geometry, matrix=dm_block)
103 | 
104 | 
105 | @pytest.fixture(scope="session")
106 | def density_data(density_config, density_data_processor):
107 |     return TorchBasisMatrixData.from_config(
108 |         density_config, data_processor=density_data_processor
109 |     )
110 | 
--------------------------------------------------------------------------------
/src/graph2mat/core/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | add_subdirectory("data")
2 | add_subdirectory("modules")
3 | 
--------------------------------------------------------------------------------
/src/graph2mat/core/__init__.py:
--------------------------------------------------------------------------------
1 | """Core functionality of the graph2mat package.
2 | 
3 | There are two main things that need to be implemented to make
4 | fitting matrices a reality: **data handling and computation**.
5 | 
6 | This submodule implements all the routines to deal with data
7 | containing graphs and sparse matrices related to graphs, as
8 | well as the skeleton of `Graph2Mat`, the function to convert
9 | graphs to matrices.
10 | 
11 | For now, this module doesn't implement the functions (modules/blocks)
12 | to be used within `Graph2Mat`, because we have only worked with
13 | equivariant functions and therefore we use the functions defined
14 | in `graph2mat.bindings.e3nn` as the working blocks.
15 | """
16 | 
17 | from . import data
18 | from . import modules
19 | 
20 | from .data import *
21 | from .modules import *
22 | 
--------------------------------------------------------------------------------
/src/graph2mat/core/data/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | add_custom_command(
2 |   OUTPUT _sparse.c
3 |   DEPENDS _sparse.py
4 |   VERBATIM
5 |   COMMAND "${CYTHON}" "${CMAKE_CURRENT_SOURCE_DIR}/_sparse.py" --output-file
6 |           "${CMAKE_CURRENT_BINARY_DIR}/_sparse.c")
7 | 
8 | python_add_library(_sparse MODULE "${CMAKE_CURRENT_BINARY_DIR}/_sparse.c"
9 |                    WITH_SOABI)
10 | 
11 | install(TARGETS _sparse DESTINATION ${SKBUILD_PROJECT_NAME}/core/data)
12 | 
--------------------------------------------------------------------------------
/src/graph2mat/core/data/__init__.py:
--------------------------------------------------------------------------------
1 | """Tools to create and manipulate data to interact with the models.
2 | 
3 | It implements the functionality needed to handle (sparse) matrices that
4 | are related to a graph. There are several things to take into account
5 | which make the problem of handling the data non-trivial and therefore
6 | this module useful:
7 | 
8 | - **Matrices are sparse**.
9 | - Matrices are in a basis which is centered around the points in the graph. Therefore
10 |   **elements of the matrix correspond to nodes or edges of the graph**.
11 | - Each point might have more than one basis function, therefore **the matrix is divided
12 |   in blocks (not just single elements)** that correspond to nodes or edges of the graph.
13 | - Different point types might have different basis size, which makes **the different
14 |   blocks in the matrix have different shapes**.
15 | - **The different block sizes and the sparsity of the matrices pose an extra
16 |   challenge when batching** examples for machine learning.
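
As a small illustration of the block structure: two points with basis sizes
2 and 3 produce a 5x5 matrix that splits into four blocks,

- the (2, 2) and (3, 3) diagonal blocks, which correspond to the two nodes, and
- the (2, 3) and (3, 2) off-diagonal blocks, which correspond to the edge
  between them (one block per direction).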
17 | 18 | The tools in this submodule are agnostic to the machine learning framework 19 | of choice, and they are based purely on `numpy`, with the extra dependency on `sisl` 20 | to handle the sparse matrices. The `sisl` dependency could eventually be lift off 21 | if needed. 22 | """ 23 | from .matrices import * 24 | 25 | from .basis import PointBasis 26 | from .configuration import BasisConfiguration, OrbitalConfiguration 27 | from .metrics import OrbitalMatrixMetric 28 | from . import metrics 29 | from .processing import MatrixDataProcessor, BasisMatrixData, BasisMatrixDataBase 30 | from .table import BasisTableWithEdges, AtomicTableWithEdges 31 | from .formats import Formats, conversions, ConversionManager 32 | -------------------------------------------------------------------------------- /src/graph2mat/core/data/_sparse.py: -------------------------------------------------------------------------------- 1 | """Functions whose performance is critical in sparse conversions. 2 | 3 | This file can either be used directly by the python interpreter 4 | or cythonized for increased performance. 5 | 6 | The cython compilation assumes that cython.int is int32 so that 7 | we don't have to cimport numpy. This might not be true in some 8 | machines (?). 9 | """ 10 | 11 | import numpy as np 12 | 13 | import cython 14 | 15 | 16 | def _csr_to_block_dict( 17 | data: cython.numeric[:], 18 | ptr: cython.int[:], 19 | cols: cython.int[:], 20 | atom_first_orb: cython.int[:], 21 | orbitals: cython.int[:], 22 | n_atoms: cython.int, 23 | fill_value: float = 0.0, 24 | ): 25 | # --- Cython annotations for increased performance (ignored if not compiled with cython) 26 | atom_i: cython.int 27 | atomi_firsto: cython.int 28 | atomi_lasto: cython.int 29 | atomi_norbs: cython.int 30 | 31 | orbital_i: cython.int 32 | atom_j: cython.int 33 | orbital_j: cython.int 34 | 35 | no: cython.int 36 | 37 | ival: cython.int 38 | val: cython.numeric 39 | sc_col: cython.int 40 | col: cython.int 41 | i_sc: cython.int 42 | row: cython.int 43 | 44 | rc_to_atom_index: cython.int[:] 45 | rc_to_orbital_index: cython.int[:] 46 | # ------ End of cython annotations. 47 | 48 | # Mapping from row/column index to atom index 49 | rc_to_atom_index = np.concatenate( 50 | [np.ones(o, dtype=np.int32) * i for i, o in enumerate(orbitals)] 51 | ) 52 | # Mapping from row/column index to orbital index within atom 53 | rc_to_orbital_index = np.concatenate( 54 | [np.arange(o) for o in orbitals], dtype=np.int32 55 | ) 56 | 57 | no = atom_first_orb[n_atoms] 58 | 59 | block_dict = {} 60 | for atom_i in range(n_atoms): 61 | # Get the orbital limits of the block 62 | atomi_firsto = atom_first_orb[atom_i] 63 | atomi_lasto = atom_first_orb[atom_i + 1] 64 | # And the size of the block 65 | atomi_norbs = atomi_lasto - atomi_firsto 66 | 67 | # Loop over rows in this atom. 68 | for orbital_i in range(atomi_norbs): 69 | # Get the row index 70 | row = atomi_firsto + orbital_i 71 | 72 | for ival in range(ptr[row], ptr[row + 1]): 73 | val = data[ival] 74 | sc_col = cols[ival] 75 | 76 | # Sisl SparseCSR allocates space in advance for values. Values that are 77 | # allocated but have not been set contain a col of -1. 
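                # They are assumed to come after all the set values of the row,
                # which is why we can stop scanning the row at the first one.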
78 | if sc_col < 0: 79 | break 80 | 81 | col = sc_col % no 82 | i_sc = sc_col // no 83 | 84 | atom_j = rc_to_atom_index[col] 85 | orbital_j = rc_to_orbital_index[col] 86 | try: 87 | block_dict[atom_i, atom_j, i_sc][orbital_i, orbital_j] = val 88 | except KeyError: 89 | block_dict[atom_i, atom_j, i_sc] = np.full( 90 | (orbitals[atom_i], orbitals[atom_j]), fill_value 91 | ) 92 | block_dict[atom_i, atom_j, i_sc][orbital_i, orbital_j] = val 93 | 94 | return block_dict 95 | -------------------------------------------------------------------------------- /src/graph2mat/core/data/matrices/__init__.py: -------------------------------------------------------------------------------- 1 | """Containers to store the raw matrices as a dictionary of blocks. 2 | 3 | The matrices are stored in this format in `BasisConfiguration`, until 4 | they are converted to flat arrays for training in `BasisMatrixData`. 5 | 6 | However, the user does not need to initialize these matrices explicitly, 7 | they are initialized appropiately when initializing a `BasisConfiguration` 8 | object using the `OrbitalConfiguration.new` method. 9 | 10 | There are different matrix classes. This is something that is probably 11 | not needed and is reminiscent of the initial development stages. 12 | """ 13 | from typing import Union, Type 14 | 15 | from warnings import warn 16 | 17 | import sisl 18 | 19 | from .basis_matrix import BasisMatrix 20 | from .physics.orbital_matrix import OrbitalMatrix 21 | from .physics.density_matrix import DensityMatrix 22 | 23 | __all__ = ["BasisMatrix", "OrbitalMatrix", "DensityMatrix", "get_matrix_cls"] 24 | 25 | _KEY_TO_MATRIX_CLS = { 26 | "density_matrix": DensityMatrix, 27 | sisl.DensityMatrix: DensityMatrix, 28 | } 29 | 30 | 31 | def get_matrix_cls(key: Union[str, sisl.SparseOrbital, None]) -> Type[OrbitalMatrix]: 32 | if key is None: 33 | return OrbitalMatrix 34 | else: 35 | if isinstance(key, str): 36 | key = key.lower() 37 | try: 38 | return _KEY_TO_MATRIX_CLS[key] 39 | except KeyError: 40 | warn( 41 | f"{key} is not a known matrix type key, falling back to generic OrbitalMatrix class." 42 | ) 43 | return OrbitalMatrix 44 | -------------------------------------------------------------------------------- /src/graph2mat/core/data/matrices/basis_matrix.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Dict, Tuple, Union 4 | 5 | from dataclasses import dataclass 6 | import numpy as np 7 | 8 | from ..table import BasisTableWithEdges 9 | 10 | BasisCount = np.ndarray # [num_points] 11 | 12 | 13 | @dataclass 14 | class BasisMatrix: 15 | """Container to store the raw matrices as a dictionary of blocks. 16 | 17 | The matrices are stored in this format in `BasisConfiguration`, until 18 | they are converted to flat arrays for training in `BasisMatrixData`. 19 | 20 | As a user you probably don't need to initialize these matrices explicitly, 21 | they are initialized appropiately when initializing a `BasisConfiguration` 22 | object using the `OrbitalConfiguration.new` method. 23 | """ 24 | 25 | #: Dictionary containing the blocks of the matrix. The keys are tuples 26 | #: `(i, j_uc, sc_j)` where `ì` and `j_uc` are the point indices of the block 27 | #: (in the unit cell) and `sc_j` is the index of the neighboring cell where 28 | # `j` is located, e.g. 0 for a unit cell connection. 29 | block_dict: Dict[Tuple[int, int, int], np.ndarray] 30 | # Size of the auxiliary supercell. 
This is the number of cells required 31 | # in each direction to account for all the interactions of the points in 32 | # the unit cell. If the point distribution is not periodic, this will 33 | # always be [1,1,1]. 34 | nsc: np.ndarray 35 | # Array containing the number of basis functions for each point 36 | basis_count: BasisCount 37 | 38 | def to_flat_nodes_and_edges( 39 | self, 40 | edge_index: np.ndarray, 41 | edge_sc_shifts: np.ndarray, 42 | points_order: Union[np.ndarray, None] = None, 43 | basis_table: Union[BasisTableWithEdges, None] = None, 44 | point_types: Union[np.ndarray, None] = None, 45 | sub_point_matrix: bool = False, 46 | ) -> Tuple[np.ndarray, np.ndarray]: 47 | """Converts the matrix to a flat representation of the nodes and edges values. 48 | 49 | This representation might be useful for training a neural network, where you will 50 | want to compare the output array to the target array. If you have flat arrays instead 51 | of block dicts you can easily compare the whole matrix at once. 52 | 53 | Parameters 54 | ---------- 55 | edge_index : np.ndarray 56 | Array of shape [2, n_edges] containing the indices of the points that form each edge. 57 | edge_sc_shifts : np.ndarray 58 | Array of shape [n_edges, 3] containing the supercell shifts of each edge. That is, if 59 | an edge is from point i to a periodic image of point j in the [1,0,0] cell, edge_sc_shifts 60 | for this edge should be [1,0,0]. Note that for the reverse edge (if there is one), the shift 61 | will be [-1,0,0]. 62 | points_order : Union[np.ndarray, None], optional 63 | Array of shape [n_points] containing the order in which the points should be flattened. 64 | If None, the order will simply be determined by their index. 65 | basis_table : Union[BasisTableWithEdges, None], optional 66 | Table containing the types of the points. Only needed if sub_point_matrix is True. 67 | """ 68 | 69 | if points_order is None: 70 | order = np.arange(len(self.basis_count)) 71 | else: 72 | order = points_order 73 | assert len(order) == len(self.basis_count) 74 | 75 | if sub_point_matrix: 76 | assert basis_table is not None and point_types is not None 77 | point_matrices = self.get_point_matrices(basis_table) 78 | blocks = [ 79 | (self.block_dict[i, i, 0] - point_matrices[point_types[i]]).flatten() 80 | for i in order 81 | if self.basis_count[i] > 0 82 | ] 83 | else: 84 | blocks = [ 85 | self.block_dict[i, i, 0].flatten() 86 | for i in order 87 | if self.basis_count[i] > 0 88 | ] 89 | 90 | node_values = np.concatenate(blocks) 91 | 92 | assert edge_index.shape[0] == 2, "edge_index is assumed to be [2, n_edges]" 93 | blocks = [ 94 | self.block_dict[edge[0], edge[1], sc_shift].flatten() 95 | for edge, sc_shift in zip(edge_index.transpose(), edge_sc_shifts) 96 | if self.basis_count[edge[0]] > 0 and self.basis_count[edge[1]] > 0 97 | ] 98 | 99 | if len(blocks) > 0: 100 | edge_values = np.concatenate(blocks) 101 | else: 102 | edge_values = np.array([]) 103 | 104 | return node_values, edge_values 105 | 106 | def get_point_matrices( 107 | self, basis_table: BasisTableWithEdges 108 | ) -> Dict[int, np.ndarray]: 109 | """This method should implement a way of retreiving the sub-matrices of each individual point. 110 | 111 | This is, the matrix that the point would have if it was the only point in the system. This matrix 112 | will depend on the type of the point, so the basis_table needs to be provided. 
113 | 114 | The user might choose to subtract this matrix from the block_dict matrix during training, so that 115 | the model only learns the interactions between points. 116 | """ 117 | raise NotImplementedError( 118 | f"{self.__class__.__name__} does not implement a way of retreiving point matrices." 119 | ) 120 | -------------------------------------------------------------------------------- /src/graph2mat/core/data/matrices/physics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIG-MAP/graph2mat/8eb40159b0a29b0fb524245714cc93dbe1885be4/src/graph2mat/core/data/matrices/physics/__init__.py -------------------------------------------------------------------------------- /src/graph2mat/core/data/matrices/physics/density_matrix.py: -------------------------------------------------------------------------------- 1 | import sisl 2 | import numpy as np 3 | 4 | from dataclasses import dataclass 5 | 6 | from ...table import AtomicTableWithEdges 7 | from .orbital_matrix import OrbitalMatrix 8 | 9 | 10 | @dataclass 11 | class DensityMatrix(OrbitalMatrix): 12 | def get_atomic_matrices(self, z_table: AtomicTableWithEdges): 13 | return z_table.atomic_DM 14 | 15 | 16 | def get_atomic_DM(atom: sisl.Atom) -> np.ndarray: 17 | """Gets the block corresponding to the atomic density matrix. 18 | 19 | Parameters 20 | ---------- 21 | atom: sisl.Atom 22 | The Atom object from which the density block is desired. 23 | It must contain the basis orbitals, with the initial occupation for each of them. 24 | This is how they come if you have read the basis from a SIESTA calculation or 25 | from an .ion file. 26 | 27 | Returns 28 | ---------- 29 | np.ndarray of shape atom.no x atom.no 30 | Square matrix encoding the isolated atom density matrix. 31 | """ 32 | return np.diag(atom.q0) 33 | -------------------------------------------------------------------------------- /src/graph2mat/core/data/matrices/physics/orbital_matrix.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Dict, Tuple, Union 4 | 5 | from dataclasses import dataclass, field 6 | import numpy as np 7 | 8 | from ..basis_matrix import BasisMatrix 9 | from ...table import AtomicTableWithEdges 10 | 11 | OrbitalCount = np.ndarray # [num_atoms] 12 | 13 | 14 | @dataclass 15 | class OrbitalMatrix(BasisMatrix): 16 | """Container to store the raw matrices as a dictionary of blocks. 17 | 18 | This class just adds some extra aliases to the `BasisMatrix` class, 19 | to use orbital terminology. 20 | """ 21 | 22 | #: Dictionary containing the blocks of the matrix. The keys are tuples 23 | #: `(i, j_uc, sc_j)` where `ì` and `j_uc` are the atomic indices of the block 24 | #: (in the unit cell) and `sc_j` is the index of the neighboring cell where 25 | # `j` is located, e.g. 0 for a unit cell connection. 26 | block_dict: Dict[Tuple[int, int, int], np.ndarray] 27 | 28 | #: Size of the auxiliary supercell. This is the number of cells required 29 | #: in each direction to account for all the interactions of the points in 30 | #: the unit cell. If the point distribution is not periodic, this will 31 | #: always be [1,1,1]. 32 | nsc: np.ndarray 33 | 34 | #: Alias for `basis_count`. 
Array containing the number of basis functions for each point 35 | orbital_count: OrbitalCount 36 | 37 | #: Array containing the number of basis functions for each point 38 | basis_count: OrbitalCount = field(init=False) 39 | 40 | def __post_init__(self): 41 | self.basis_count = self.orbital_count 42 | 43 | def get_point_matrices( 44 | self, basis_table: AtomicTableWithEdges 45 | ) -> Dict[int, np.ndarray]: 46 | return self.get_atomic_matrices(basis_table) 47 | 48 | def get_atomic_matrices(self, z_table: AtomicTableWithEdges): 49 | """Returns the isolated-atom matrices for each atom type.""" 50 | raise NotImplementedError( 51 | f"{self.__class__.__name__} does not implement a way of retrieving atomic matrices." 52 | ) 53 | -------------------------------------------------------------------------------- /src/graph2mat/core/data/neighborhood.py: -------------------------------------------------------------------------------- 1 | """Neighborhood construction. 2 | 3 | This uses ``ASE`` and is heavily based on former MACE code by 4 | Ilyes Batatia and Gregor Simm. 5 | 6 | A future refactoring will probably make use of ``sisl``, since this 7 | is the only place we depend on ASE explicitly. 8 | """ 9 | 10 | from typing import Optional, Tuple 11 | 12 | import ase.neighborlist 13 | import numpy as np 14 | 15 | 16 | def get_neighborhood( 17 | positions: np.ndarray,  # [num_positions, 3] 18 | cutoff: float, 19 | pbc: Optional[Tuple[bool, bool, bool]] = None, 20 | cell: Optional[np.ndarray] = None,  # [3, 3] 21 | true_self_interaction=False, 22 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 23 | if pbc is None: 24 | pbc = (False, False, False) 25 | 26 | if cell is None or not cell.any():  # an all-zero cell means no cell was given 27 | cell = np.identity(3, dtype=float) 28 | 29 | assert len(pbc) == 3 and all(isinstance(i, (bool, np.bool_)) for i in pbc) 30 | assert cell.shape == (3, 3) 31 | 32 | sender, receiver, unit_shifts = ase.neighborlist.primitive_neighbor_list( 33 | quantities="ijS", 34 | pbc=pbc, 35 | cell=cell, 36 | positions=positions, 37 | cutoff=cutoff, 38 | self_interaction=True,  # we want edges from atom to itself in different periodic images 39 | use_scaled_positions=False,  # positions are not scaled positions 40 | ) 41 | 42 | if not true_self_interaction: 43 | # Eliminate self-edges that don't cross periodic boundaries 44 | true_self_edge = sender == receiver 45 | true_self_edge &= np.all(unit_shifts == 0, axis=1) 46 | keep_edge = ~true_self_edge 47 | 48 | # Note: after eliminating self-edges, it can be that no edges remain in this system 49 | sender = sender[keep_edge] 50 | receiver = receiver[keep_edge] 51 | unit_shifts = unit_shifts[keep_edge] 52 | 53 | # Build output 54 | edge_index = np.stack((sender, receiver))  # [2, n_edges] 55 | 56 | # From the docs: With the shift vector S, the distances D between atoms can be computed from 57 | # D = positions[j]-positions[i]+S.dot(cell) 58 | shifts = np.dot(unit_shifts, cell)  # [n_edges, 3] 59 | 60 | return edge_index, unit_shifts, shifts 61 | -------------------------------------------------------------------------------- /src/graph2mat/core/data/node_feats.py: -------------------------------------------------------------------------------- 1 | """Experimental module for defining node features. 2 | 3 | You probably don't care about this module unless you have come 4 | here to find it because you want to play with node features.
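Subclasses of `NodeFeature` register themselves automatically in `NodeFeature.registry` (via `__init_subclass__`), so defining a new subclass is enough to make it selectable by its class name.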
5 | 6 | Usually node features/embeddings should be defined by the atomic 7 | environment descriptor of choice, but I was playing with these to 8 | incorporate the total dipole of water as a global descriptor into 9 | MACE. 10 | """ 11 | import numpy as np 12 | 13 | 14 | class NodeFeature: 15 | def __new__(cls, config, data_processor): 16 | return cls.get_feature(config, data_processor) 17 | 18 | registry = {} 19 | 20 | def __init_subclass__(cls) -> None: 21 | NodeFeature.registry[cls.__name__] = cls 22 | 23 | @staticmethod 24 | def get_feature(config: dict, data_processor) -> np.ndarray: 25 | raise NotImplementedError 26 | 27 | @staticmethod 28 | def get_e3nn_irreps(data_processor): 29 | raise NotImplementedError 30 | 31 | 32 | class OneHotZ(NodeFeature): 33 | @staticmethod 34 | def get_e3nn_irreps(basis_table): 35 | from e3nn import o3 36 | 37 | return o3.Irreps([(len(basis_table), (0, 1))]) 38 | 39 | @staticmethod 40 | def get_feature(config, data_processor): 41 | indices = data_processor.get_point_types(config) 42 | return data_processor.one_hot_encode(indices) 43 | 44 | 45 | class WaterDipole(NodeFeature): 46 | @staticmethod 47 | def get_e3nn_irreps(basis_table): 48 | from e3nn import o3 49 | 50 | return o3.Irreps("1x1o") 51 | 52 | @staticmethod 53 | def get_feature(config, data_processor): 54 | n_atoms = len(config.positions) 55 | 56 | z_dipole = np.array([0.0, 0.0, 0.0]) 57 | for position, point_type in zip(config.positions, config.point_types): 58 | if point_type == 8 or point_type == 1: 59 | z_dipole[2] += position[2] * (-2 if point_type == 8 else 1) 60 | 61 | z_dipole = data_processor.cartesian_to_basis(z_dipole) / 30 62 | z_dipole = np.tile(z_dipole, n_atoms).reshape(n_atoms, 3) 63 | 64 | return z_dipole 65 | 66 | 67 | class WaterDipoleInv(NodeFeature): 68 | @staticmethod 69 | def get_e3nn_irreps(basis_table): 70 | from e3nn import o3 71 | 72 | return o3.Irreps("1x0e") 73 | 74 | @staticmethod 75 | def get_feature(config, data_processor): 76 | n_atoms = len(config.positions) 77 | 78 | z_dipole = np.array([0.0]) 79 | for position, point_type in zip(config.positions, config.point_types): 80 | if point_type == 8 or point_type == 1: 81 | z_dipole[0] += position[2] * (-2 if point_type == 8 else 1) 82 | 83 | z_dipole = np.tile(z_dipole, n_atoms).reshape(n_atoms, 1) 84 | 85 | return z_dipole / 30 86 | 87 | 88 | class Nothing(NodeFeature): 89 | @staticmethod 90 | def get_e3nn_irreps(basis_table): 91 | from e3nn import o3 92 | 93 | return o3.Irreps("1x0e") 94 | 95 | @staticmethod 96 | def get_feature(config, data_processor): 97 | n_atoms = len(config.positions) 98 | 99 | z_dipole = np.array([0.0]) 100 | 101 | z_dipole = np.tile(z_dipole, n_atoms).reshape(n_atoms, 1) 102 | 103 | return z_dipole 104 | 105 | 106 | class NothingVector(NodeFeature): 107 | @staticmethod 108 | def get_e3nn_irreps(basis_table): 109 | from e3nn import o3 110 | 111 | return o3.Irreps("1x1o") 112 | 113 | @staticmethod 114 | def get_feature(config, data_processor): 115 | n_atoms = len(config.positions) 116 | 117 | z_dipole = np.array([0.0, 0.0, 0.0]) 118 | 119 | z_dipole = np.tile(z_dipole, n_atoms).reshape(n_atoms, 3) 120 | 121 | return z_dipole 122 | 123 | 124 | class One(NodeFeature): 125 | @staticmethod 126 | def get_e3nn_irreps(basis_table): 127 | from e3nn import o3 128 | 129 | return o3.Irreps("1x0e") 130 | 131 | @staticmethod 132 | def get_feature(config, data_processor): 133 | n_atoms = len(config.positions) 134 | 135 | z_dipole = np.array([1.0]) 136 | 137 | z_dipole = np.tile(z_dipole, 
n_atoms).reshape(n_atoms, 1) 138 | 139 | return z_dipole 140 | -------------------------------------------------------------------------------- /src/graph2mat/core/data/tests/test_basis.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import numpy as np 4 | import itertools 5 | 6 | import sisl 7 | 8 | from graph2mat import PointBasis 9 | from graph2mat.core.data.basis import get_atom_basis, get_change_of_basis 10 | 11 | 12 | def test_simplest(): 13 | basis = PointBasis("A", basis_convention="spherical", basis="3x0e + 2x1o", R=5) 14 | 15 | 16 | def test_list_basis(): 17 | basis_str = PointBasis("A", basis_convention="spherical", basis="3x0e + 2x1o", R=5) 18 | 19 | basis_list = PointBasis("A", basis_convention="spherical", basis=[3, 2], R=5) 20 | 21 | assert basis_str == basis_list 22 | 23 | 24 | def test_siesta_convention(): 25 | basis = PointBasis( 26 | "A", basis_convention="siesta_spherical", basis="3x0e + 2x1o", R=5 27 | ) 28 | 29 | 30 | def test_no_basis(): 31 | basis = PointBasis("A", R=5) 32 | 33 | 34 | def test_multiple_R(): 35 | basis = PointBasis( 36 | "A", 37 | basis_convention="spherical", 38 | basis="3x0e + 2x1o", 39 | R=np.array([5, 5, 5, 3, 3, 3, 3, 3, 3]), 40 | ) 41 | 42 | # Wrong number of Rs 43 | with pytest.raises(AssertionError): 44 | basis = PointBasis( 45 | "A", 46 | basis_convention="spherical", 47 | basis="3x0e + 2x1o", 48 | R=np.array([5, 5, 5, 3, 3]), 49 | ) 50 | 51 | 52 | def test_from_sisl_atom(): 53 | atom = sisl.Atom(1, orbitals=[sisl.AtomicOrbital(f"2p{ax}", R=4) for ax in "xyz"]) 54 | 55 | basis = PointBasis.from_sisl_atom(atom) 56 | 57 | assert basis.type == 1 58 | assert basis.basis_convention == "siesta_spherical" 59 | # assert basis.irreps == o3.Irreps("1x1o") 60 | assert isinstance(basis.R, np.ndarray) 61 | assert np.all(basis.R == 4) 62 | 63 | 64 | def test_to_sisl_atom(): 65 | basis = PointBasis( 66 | "A", 67 | basis_convention="siesta_spherical", 68 | basis="3x0e + 2x1o", 69 | R=np.array([5, 5, 5, 3, 3, 3, 3, 3, 3]), 70 | ) 71 | 72 | atom = basis.to_sisl_atom() 73 | 74 | assert isinstance(atom, sisl.Atom) 75 | assert len(atom.orbitals) == 9 76 | 77 | for orbital in atom.orbitals[:3]: 78 | assert orbital.l == 0 79 | assert orbital.R == 5 80 | 81 | for orbital in atom.orbitals[3:]: 82 | assert orbital.l == 1 83 | assert orbital.R == 3 84 | 85 | 86 | def test_get_atom_basis(): 87 | atom = sisl.Atom( 88 | 1, 89 | orbitals=[ 90 | sisl.AtomicOrbital("2s", R=4), 91 | *[sisl.AtomicOrbital(f"2p{ax}", R=4) for ax in "xyz"], 92 | *[sisl.AtomicOrbital(f"2p{ax}Z2", R=4) for ax in "xyz"], 93 | ], 94 | ) 95 | 96 | atom_basis = get_atom_basis(atom) 97 | 98 | assert atom_basis == [(1, 0, 1), (2, 1, -1)] 99 | 100 | 101 | def test_change_of_basis_consistent(): 102 | conventions = ["spherical", "siesta_spherical", "cartesian"] 103 | for original, target in itertools.combinations_with_replacement(conventions, 2): 104 | a, b = get_change_of_basis(original, target) 105 | c, d = get_change_of_basis(target, original) 106 | 107 | assert np.allclose(a, d), f"{original}-{target}" 108 | assert np.allclose(b, c), f"{original}-{target}" 109 | 110 | 111 | def test_change_of_basis_works(): 112 | """Check that we can go through different bases, go back 113 | to the original basis and get the original array.""" 114 | 115 | array = np.random.rand(4, 3) 116 | 117 | cob, _ = get_change_of_basis("cartesian", "siesta_spherical") 118 | 119 | in_siesta = array @ cob.T 120 | 121 | cob, _ = get_change_of_basis("siesta_spherical",
"spherical") 122 | 123 | spherical = in_siesta @ cob.T 124 | 125 | cob, _ = get_change_of_basis("spherical", "cartesian") 126 | 127 | array_again = spherical @ cob.T 128 | 129 | assert np.allclose(array, array_again) 130 | -------------------------------------------------------------------------------- /src/graph2mat/core/data/tests/test_configuration.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | 4 | from graph2mat import BasisConfiguration, OrbitalConfiguration, PointBasis 5 | 6 | 7 | @pytest.mark.parametrize("config_cls", [BasisConfiguration, OrbitalConfiguration]) 8 | def test_init_configuration(config_cls): 9 | # The basis 10 | point_1 = PointBasis("A", R=2, basis=[1], basis_convention="spherical") 11 | point_2 = PointBasis("B", R=5, basis=[2, 1], basis_convention="spherical") 12 | 13 | positions = np.array([[0, 0, 0], [6.0, 0, 0], [9, 0, 0]]) 14 | 15 | # Initialize configuration 16 | config_cls( 17 | point_types=["A", "B", "A"], 18 | positions=positions, 19 | basis=[point_1, point_2], 20 | cell=np.eye(3) * 100, 21 | pbc=(False, False, False), 22 | ) 23 | -------------------------------------------------------------------------------- /src/graph2mat/core/data/tests/test_metrics.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from graph2mat import OrbitalMatrixMetric 4 | 5 | 6 | # Test that all metrics run 7 | @pytest.mark.parametrize( 8 | "metric", 9 | [metric() for metric in OrbitalMatrixMetric.__subclasses__()], 10 | ids=lambda x: x.__class__.__name__, 11 | ) 12 | def test_metric_runs(density_data, density_z_table, metric): 13 | metric( 14 | nodes_pred=density_data.point_labels - 0.001, 15 | nodes_ref=density_data.point_labels, 16 | edges_pred=density_data.edge_labels, 17 | edges_ref=density_data.edge_labels, 18 | batch=density_data, 19 | basis_table=density_z_table, 20 | ) 21 | -------------------------------------------------------------------------------- /src/graph2mat/core/data/tests/test_processing.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import numpy as np 4 | 5 | from graph2mat import ( 6 | PointBasis, 7 | BasisTableWithEdges, 8 | MatrixDataProcessor, 9 | BasisConfiguration, 10 | BasisMatrixData, 11 | OrbitalConfiguration, 12 | ) 13 | 14 | 15 | @pytest.fixture(scope="module") 16 | def positions(): 17 | return np.array([[0.1, 0, 0], [6.0, 0, 0]]) 18 | 19 | 20 | @pytest.fixture(scope="module", params=["cartesian", "spherical", "siesta_spherical"]) 21 | def basis_convention(request): 22 | return request.param 23 | 24 | 25 | @pytest.fixture(scope="module") 26 | def basis_table(basis_convention): 27 | point_1 = PointBasis("A", R=2, basis=[1], basis_convention=basis_convention) 28 | point_2 = PointBasis("B", R=5, basis=[2, 1], basis_convention=basis_convention) 29 | 30 | return BasisTableWithEdges([point_1, point_2]) 31 | 32 | 33 | @pytest.mark.parametrize("load_matrix", [False, True]) 34 | @pytest.mark.parametrize("n_ats", [1, 2]) 35 | @pytest.mark.parametrize("config_cls", [BasisConfiguration, OrbitalConfiguration]) 36 | @pytest.mark.parametrize("new_method", ["from_config", "new"]) 37 | def test_init_data( 38 | positions, basis_table, basis_convention, n_ats, load_matrix, new_method, config_cls 39 | ): 40 | # The data processor. 
41 | processor = MatrixDataProcessor( 42 | basis_table=basis_table, symmetric_matrix=True, sub_point_matrix=False 43 | ) 44 | 45 | positions = positions[:n_ats] 46 | 47 | matrix = None 48 | if load_matrix: 49 | matrix = np.random.rand(6, 6) 50 | 51 | config = config_cls( 52 | point_types=["A", "B"][:n_ats], 53 | positions=positions, 54 | basis=basis_table, 55 | cell=np.eye(3) * 100, 56 | pbc=(False, False, False), 57 | matrix=matrix, 58 | ) 59 | 60 | # Test from_config method 61 | new = getattr(BasisMatrixData, new_method) 62 | data = new(config, processor) 63 | 64 | if basis_convention == "cartesian": 65 | assert np.all(data.positions == positions) 66 | else: 67 | assert (data.positions != positions).sum() == n_ats * 2 68 | -------------------------------------------------------------------------------- /src/graph2mat/core/data/tests/test_sparse.py: -------------------------------------------------------------------------------- 1 | """Tests for sparse structure conversion""" 2 | import sisl 3 | import numpy as np 4 | 5 | from graph2mat import conversions, Formats 6 | 7 | csr_to_basismatrix = conversions.get_converter(Formats.SCIPY_CSR, Formats.BASISMATRIX) 8 | block_dict_to_csr = conversions.get_converter(Formats.BLOCK_DICT, Formats.SCIPY_CSR) 9 | nodes_and_edges_to_csr = conversions.get_converter( 10 | Formats.NODESEDGES, Formats.SCIPY_CSR 11 | ) 12 | csr_to_sisl_sparse_orbital = conversions.get_converter(Formats.SCIPY_CSR, Formats.SISL) 13 | nodes_and_edges_to_sparse_orbital = conversions.get_converter( 14 | Formats.NODESEDGES, Formats.SISL 15 | ) 16 | 17 | 18 | def test_csr_to_basismatrix_simple(density_matrix): 19 | density_matrix = density_matrix.copy() 20 | 21 | density_matrix._csr.data[:] = 0 22 | density_matrix[0, 0] = 1 23 | density_matrix[0, density_matrix.orbitals[0]] = 2 24 | density_matrix[density_matrix.orbitals[0], 0] = 3 25 | 26 | basis_matrix = csr_to_basismatrix( 27 | density_matrix._csr, 28 | density_matrix.atoms, 29 | nsc=density_matrix.nsc, 30 | fill_value=np.nan, 31 | ) 32 | 33 | for (i_at, j_at), val in zip([(0, 0), (0, 1), (1, 0)], [1, 2, 3]): 34 | assert basis_matrix.block_dict[i_at, j_at, 0][0, 0] == val 35 | assert (~np.isnan(basis_matrix.block_dict[i_at, j_at, 0])).sum() == 1 36 | 37 | 38 | def test_block_dict_to_csr_simple(density_matrix): 39 | density_matrix = density_matrix.copy() 40 | 41 | first_orb_atom1 = density_matrix.orbitals[0] 42 | 43 | density_matrix._csr.data[:] = 0 44 | density_matrix[0, 0] = 1 45 | density_matrix[0, first_orb_atom1] = 2 46 | density_matrix[first_orb_atom1, 0] = 3 47 | 48 | basis_matrix = csr_to_basismatrix( 49 | density_matrix._csr, 50 | density_matrix.atoms, 51 | nsc=density_matrix.nsc, 52 | fill_value=np.nan, 53 | ) 54 | 55 | for (i_at, j_at), val in zip([(0, 0), (0, 1), (1, 0)], [1, 2, 3]): 56 | assert basis_matrix.block_dict[i_at, j_at, 0][0, 0] == val 57 | assert (~np.isnan(basis_matrix.block_dict[i_at, j_at, 0])).sum() == 1 58 | 59 | new_csr = block_dict_to_csr( 60 | basis_matrix.block_dict, density_matrix.firsto, n_supercells=density_matrix.n_s 61 | ) 62 | 63 | assert (new_csr.data != 0).sum() == 3 64 | for (i, j), val in zip( 65 | [(0, 0), (0, first_orb_atom1), (first_orb_atom1, 0)], [1, 2, 3] 66 | ): 67 | assert new_csr[i, j] == val 68 | 69 | 70 | def test_full_block_dict_csr(density_matrix): 71 | csr = density_matrix._csr 72 | basis_matrix = csr_to_basismatrix(csr, density_matrix.atoms, nsc=density_matrix.nsc) 73 | 74 | new_csr = block_dict_to_csr( 75 | basis_matrix.block_dict, density_matrix.firsto, 
n_supercells=density_matrix.n_s 76 | ) 77 | 78 | assert csr.shape[:-1] == new_csr.shape 79 | assert np.allclose(csr.tocsr().toarray(), new_csr.toarray()) 80 | 81 | 82 | def test_nodes_and_edges_to_csr( 83 | density_matrix, density_config, density_data, density_z_table, symmetric 84 | ): 85 | csr = density_matrix._csr 86 | 87 | edge_index = density_data.edge_index 88 | neigh_isc = density_data.neigh_isc 89 | if symmetric: 90 | edge_index = edge_index[:, ::2] 91 | neigh_isc = neigh_isc[::2] 92 | 93 | new_csr = nodes_and_edges_to_csr( 94 | density_data.point_labels, 95 | density_data.edge_labels, 96 | edge_index, 97 | edge_neigh_isc=neigh_isc, 98 | n_supercells=density_matrix.n_s, 99 | orbitals=density_config.atoms.orbitals, 100 | symmetrize_edges=symmetric, 101 | ) 102 | 103 | assert csr.shape[:-1] == new_csr.shape 104 | assert np.allclose(csr.tocsr().toarray(), new_csr.toarray()) 105 | 106 | 107 | def test_nodes_and_edges_to_dm( 108 | density_matrix, density_config, density_data, density_z_table, symmetric 109 | ): 110 | edge_index = density_data.edge_index 111 | neigh_isc = density_data.neigh_isc 112 | if symmetric: 113 | edge_index = edge_index[:, ::2] 114 | neigh_isc = neigh_isc[::2] 115 | 116 | new_csr = nodes_and_edges_to_csr( 117 | density_data.point_labels, 118 | density_data.edge_labels, 119 | edge_index, 120 | edge_neigh_isc=neigh_isc, 121 | n_supercells=density_matrix.n_s, 122 | orbitals=density_config.atoms.orbitals, 123 | symmetrize_edges=symmetric, 124 | ) 125 | 126 | new_dm = csr_to_sisl_sparse_orbital( 127 | new_csr, geometry=density_matrix.geometry, sp_class=sisl.DensityMatrix 128 | ) 129 | 130 | assert isinstance(new_dm, sisl.DensityMatrix) 131 | 132 | assert density_matrix.shape == new_dm.shape 133 | assert np.all(abs(new_dm - density_matrix)._csr.data < 1e-7) 134 | 135 | 136 | def test_nodes_and_edges_to_dm_direct( 137 | density_matrix, density_data, density_z_table, symmetric 138 | ): 139 | edge_index = density_data.edge_index 140 | neigh_isc = density_data.neigh_isc 141 | if symmetric: 142 | edge_index = edge_index[:, ::2] 143 | neigh_isc = neigh_isc[::2] 144 | 145 | new_dm = nodes_and_edges_to_sparse_orbital( 146 | density_data.point_labels, 147 | density_data.edge_labels, 148 | edge_index, 149 | edge_neigh_isc=neigh_isc, 150 | geometry=density_matrix.geometry, 151 | sp_class=sisl.DensityMatrix, 152 | symmetrize_edges=symmetric, 153 | ) 154 | 155 | assert isinstance(new_dm, sisl.DensityMatrix) 156 | 157 | assert density_matrix.shape == new_dm.shape 158 | assert np.all(abs(new_dm - density_matrix)._csr.data < 1e-7) 159 | -------------------------------------------------------------------------------- /src/graph2mat/core/data/tests/test_table.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from graph2mat import PointBasis, BasisTableWithEdges 4 | 5 | 6 | @pytest.mark.parametrize( 7 | "basis_convention", ["cartesian", "spherical", "siesta_spherical"] 8 | ) 9 | def test_init_table(basis_convention): 10 | point_1 = PointBasis("A", R=2, basis=[1], basis_convention=basis_convention) 11 | point_2 = PointBasis("B", R=5, basis=[2, 1], basis_convention=basis_convention) 12 | 13 | table = BasisTableWithEdges([point_1, point_2]) 14 | 15 | assert table.basis_convention == basis_convention 16 | 17 | 18 | def test_different_convention(): 19 | point_1 = PointBasis("A", R=2, basis=[1], basis_convention="cartesian") 20 | point_2 = PointBasis("B", R=5, basis=[2, 1], basis_convention="spherical") 21 | 22 | with 
pytest.raises(AssertionError): 23 | table = BasisTableWithEdges([point_1, point_2]) 24 | 25 | 26 | def test_no_basis(): 27 | point_1 = PointBasis("A", R=2, basis_convention="cartesian") 28 | point_2 = PointBasis("B", R=5) 29 | 30 | table = BasisTableWithEdges([point_1, point_2]) 31 | -------------------------------------------------------------------------------- /src/graph2mat/core/modules/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_custom_command( 2 | OUTPUT _labels_resort.c 3 | DEPENDS _labels_resort.py 4 | VERBATIM 5 | COMMAND "${CYTHON}" "${CMAKE_CURRENT_SOURCE_DIR}/_labels_resort.py" --output-file 6 | "${CMAKE_CURRENT_BINARY_DIR}/_labels_resort.c") 7 | 8 | python_add_library(_labels_resort MODULE "${CMAKE_CURRENT_BINARY_DIR}/_labels_resort.c" 9 | WITH_SOABI) 10 | 11 | install(TARGETS _labels_resort DESTINATION ${SKBUILD_PROJECT_NAME}/core/modules) 12 | -------------------------------------------------------------------------------- /src/graph2mat/core/modules/__init__.py: -------------------------------------------------------------------------------- 1 | """Core of the graph2mat models design. 2 | 3 | This module implements the classes that serve as a base for the 4 | architecture of the models in `graph2mat`. The main class is `Graph2Mat`. 5 | """ 6 | 7 | from .graph2mat import * 8 | from .matrixblock import * 9 | -------------------------------------------------------------------------------- /src/graph2mat/core/modules/_labels_resort.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import cython 4 | 5 | 6 | @cython.boundscheck(False) 7 | @cython.wraparound(False) 8 | def get_labels_resorting_array( 9 | types: cython.integral[:], 10 | shapes: cython.integral[:, :], 11 | transpose_neg: cython.bint = False, 12 | ): 13 | """ 14 | The problem this function solves is that graph2mat executes edge/node operations 15 | per edge/node type. In the case where there are 3 types, for example, you end up with 16 | three arrays: 17 | 18 | labels_0 = [...] (n_0, x_0, y_0) 19 | labels_1 = [...] (n_1, x_1, y_1) 20 | labels_2 = [...] (n_2, x_2, y_2) 21 | 22 | Where n_i is the number of edges/nodes of type i, and x_i, y_i are the number of rows 23 | and columns of the block of type i. 24 | 25 | Since each type has a different block shape, they are always 26 | raveled and concatenated into a single array: 27 | 28 | labels = np.concatenate([labels_0.ravel(), labels_1.ravel(), labels_2.ravel()]) 29 | 30 | The labels array is of course not in the same order as the target labels. 31 | 32 | This function receives the original order of types and then returns the indices 33 | to apply to the labels array to get the correct order. I.e.: 34 | 35 | sorted_labels = labels[indices] 36 | 37 | Extra complication for edges 38 | ---------------------------- 39 | 40 | An interesting fact is that when grouping the edges by size, there might be edges 41 | in one direction and edges in the other. E.g.: 42 | 43 | Original basis: A, B, C 44 | 45 | where shape A == shape C != shape B. Then when grouping by size, we have: 46 | 47 | Basis: B, (A, C) 48 | 49 | In the target, you will always have AB and BC edges (never BA or CB). But once you 50 | group, you have to face the problem that now the order is B(A, C), and therefore 51 | AB edges have been reversed. That is, the predicted blocks are the transpose of 52 | the target blocks.
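As a small worked example of the main resorting (illustrative numbers):
suppose types = [1, 0, 1], with a block size of 4 for type 0 and 2 for
type 1. The grouped labels array is then [B0 (4 values), B1a (2), B1b (2)],
while the target order is [B1a, B0, B1b], so the returned indices are
[4, 5, 0, 1, 2, 3, 6, 7] and labels[indices] recovers the target order.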
53 | 54 | I think this is only a problem for symmetric matrices where only 55 | one direction is predicted. 56 | """ 57 | n_entries = types.shape[0] 58 | n_types: cython.int = shapes.shape[1] 59 | 60 | type: cython.int 61 | rows: cython.int 62 | cols: cython.int 63 | jrow: cython.int 64 | jcol: cython.int 65 | 66 | type_nlabels: cython.long[:] = np.zeros(n_types, dtype=int) 67 | offset: cython.long[:] = np.zeros(n_types, dtype=int) 68 | 69 | # Compute the sizes for each type 70 | sizes: cython.int[:] = np.zeros(n_types, dtype=np.int32) 71 | for type in range(n_types): 72 | sizes[type] = shapes[0, type] * shapes[1, type] 73 | 74 | # Count the number of labels of each type 75 | for i_edge in range(n_entries): 76 | type: cython.int = abs(types[i_edge]) 77 | 78 | type_nlabels[type] += sizes[type] 79 | 80 | # Cumsum of type_nlabels to understand where the labels for 81 | # each type start. 82 | for type in range(1, n_types): 83 | offset[type] = offset[type - 1] + type_nlabels[type - 1] 84 | 85 | # Initialize the indices array. 86 | # (for each label value, index of the unsorted array where it is located) 87 | indices: cython.long[:] = np.empty( 88 | offset[n_types - 1] + type_nlabels[n_types - 1], dtype=int 89 | ) 90 | 91 | type_i: cython.long[:] = np.zeros_like(sizes, dtype=int) 92 | i: cython.int = 0 93 | 94 | for i_edge in range(n_entries): 95 | type = types[i_edge] 96 | abs_type: cython.int = abs(type) 97 | 98 | block_size: cython.int = sizes[abs_type] 99 | start: cython.int = offset[abs_type] + type_i[abs_type] 100 | 101 | if transpose_neg and type < 0: 102 | # Get the transposed shape 103 | cols, rows = shapes[0, abs_type], shapes[1, abs_type] 104 | for jrow in range(rows): 105 | for jcol in range(cols): 106 | indices[i] = start + jcol * rows + jrow 107 | i += 1 108 | 109 | else: 110 | for j in range(start, start + block_size): 111 | indices[i] = j 112 | i += 1 113 | 114 | type_i[abs_type] += block_size 115 | 116 | return np.asarray(indices) 117 | 118 | 119 | # HERE IS SOME CODE THAT COULD BE USED IN THE FUTURE TO GET RESORTING 120 | # INDICES WHEN basis_grouping="max". It is not used currently because 121 | # we just compute a mask (see Graph2Mat._get_labels_resort_index) 122 | 123 | # offsets = np.cumsum(original_sizes) 124 | # offsets = np.concatenate(([0], offsets)) 125 | 126 | # abs_original_types = np.abs(original_types) 127 | 128 | # n_vals = original_sizes[abs_original_types].sum() 129 | 130 | # max_size = self.graph2mat_table.point_block_size[0] 131 | 132 | # indices = np.empty(n_vals, dtype=np.int64) 133 | # ival = 0 134 | # ientry = 0 135 | # for type in abs_original_types: 136 | #     # Get the start and end of the block 137 | #     start = offsets[type] 138 | #     end = offsets[type + 1] 139 | 140 | #     # Get the indices for this type 141 | #     indices[ival : ival + end - start] = ( 142 | #         ientry * max_size + filters[start:end] 143 | #     ) 144 | #     ival += end - start 145 | #     ientry += 1 146 | 147 | # return indices 148 | -------------------------------------------------------------------------------- /src/graph2mat/core/modules/matrixblock.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Type, Tuple, TypeVar 3 | from types import ModuleType 4 | 5 | from ..data.basis import PointBasis 6 | 7 | __all__ = ["MatrixBlock"] 8 | 9 | # This type stands in for an array of whichever array framework is in use (e.g. numpy). 10 | ArrayType = TypeVar("ArrayType") 11 | 12 | 13 | class MatrixBlock: 14 | """Computes a fixed size matrix coming from the product of spherical harmonics.
15 | 16 | There are two things to note: 17 | - It computes a dense matrix. 18 | - It computes a fixed size matrix. 19 | 20 | It takes care of: 21 | - Determining which irreps are needed to reproduce a certain block. 22 | - Converting from those irreps to the actual values of the block 23 | using the appropriate change of basis. 24 | 25 | This module doesn't implement any computation, so you need to pass one 26 | via the ``operation_cls`` parameter. 27 | 28 | Parameters 29 | ---------- 30 | i_irreps: o3.Irreps 31 | The irreps of the matrix rows. 32 | j_irreps: o3.Irreps 33 | The irreps of the matrix columns. 34 | symmetry: str 35 | Symmetries that this matrix is expected to have. This should be indicated as documented 36 | in `e3nn.o3.ReducedTensorProducts`. As an example, for a symmetric matrix you would 37 | pass "ij=ji" here. 38 | operation_cls: Type[torch.nn.Module] 39 | Torch module used to actually do the computation. On initialization, it will receive 40 | the `irreps_out` argument from this module, specifying the shape of the output that 41 | it should produce. 42 | 43 | On forward, this module will just be a wrapper around the operation, so you should pass 44 | whatever arguments the operation expects. 45 | **operation_kwargs: dict 46 | Any arguments needed for the initialization of the `operation_cls`. 47 | 48 | Returns 49 | ------- 50 | matrix: ArrayType 51 | A 2D tensor of shape (i_irreps.dim, j_irreps.dim) containing the output matrix. 52 | """ 53 | 54 | block_shape: Tuple[int, int] 55 | block_size: int 56 | 57 | symm_transpose: bool 58 | 59 | numpy: ModuleType = np 60 | 61 | def __init__( 62 | self, 63 | i_basis: PointBasis, 64 | j_basis: PointBasis, 65 | operation_cls: Type, 66 | symm_transpose: bool = False, 67 | preprocessor=None, 68 | **operation_kwargs, 69 | ): 70 | super().__init__() 71 | self.symm_transpose = symm_transpose 72 | 73 | self.operation = operation_cls( 74 | i_basis=i_basis, j_basis=j_basis, **operation_kwargs 75 | ) 76 | 77 | def _compute_block(self, *args, **kwargs): 78 | return self.operation(*args, **kwargs) 79 | 80 | def forward(self, *args, **kwargs): 81 | if not self.symm_transpose: 82 | return self._compute_block(*args, **kwargs) 83 | else: 84 | forward = self._compute_block(*args, **kwargs) 85 | 86 | back_args = [ 87 | (arg[1], arg[0]) if isinstance(arg, tuple) and len(arg) == 2 else arg 88 | for arg in args 89 | ] 90 | back_kwargs = { 91 | key: (value[1], value[0]) 92 | if isinstance(value, tuple) and len(value) == 2 93 | else value 94 | for key, value in kwargs.items() 95 | } 96 | backward = self._compute_block(*back_args, **back_kwargs) 97 | 98 | return (forward + backward.transpose(-1, -2)) / 2 99 | 100 | def __call__(self, *args, **kwargs): 101 | return self.forward(*args, **kwargs) 102 | -------------------------------------------------------------------------------- /src/graph2mat/models/__init__.py: -------------------------------------------------------------------------------- 1 | """Models implemented using graph2mat.""" 2 | 3 | from .mace import MatrixMACE -------------------------------------------------------------------------------- /src/graph2mat/models/mace.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from e3nn import o3 3 | from copy import copy 4 | 5 | from mace.modules import MACE 6 | 7 | from graph2mat import Graph2Mat 8 | from graph2mat.bindings.e3nn import E3nnGraph2Mat 9 | from graph2mat.bindings.torch.data import TorchBasisMatrixData 10 | 11
| import torch 12 | from e3nn import o3 13 | 14 | from mace.modules import MACE 15 | 16 | from mace.modules.utils import ( 17 | get_edge_vectors_and_lengths, 18 | ) 19 | 20 | 21 | class MatrixMACE(torch.nn.Module): 22 | """Model that wraps a MACE model to produce a matrix output. 23 | 24 | Parameters 25 | ---------- 26 | mace : 27 | MACE model to wrap. 28 | readout_per_interaction : 29 | If ``True``, a separate readout is applied to the features of each 30 | message passing interaction. 31 | If ``False``, the features of all interactions are concatenated 32 | and passed to a single readout. 33 | graph2mat_cls : 34 | Class of the graph2mat model to use for the readouts. 35 | **kwargs : 36 | Additional keyword arguments to pass to ``graph2mat_cls`` for 37 | initialization. 38 | """ 39 | 40 | def __init__( 41 | self, 42 | mace: MACE, 43 | readout_per_interaction: bool = False, 44 | graph2mat_cls: type[Graph2Mat] = E3nnGraph2Mat, 45 | **kwargs, 46 | ): 47 | super().__init__() 48 | 49 | self.mace = mace 50 | 51 | self.readout_per_interaction = readout_per_interaction 52 | 53 | edge_hidden_irreps = kwargs.pop("edge_hidden_irreps", None) 54 | 55 | if self.readout_per_interaction: 56 | self.mace_inter_irreps = [ 57 | o3.Irreps(inter.hidden_irreps) for inter in self.mace.interactions 58 | ] 59 | 60 | self.matrix_readouts = torch.nn.ModuleList( 61 | [ 62 | graph2mat_cls( 63 | irreps=dict( 64 | node_attrs_irreps=inter.node_attrs_irreps, 65 | node_feats_irreps=o3.Irreps(inter.hidden_irreps), 66 | edge_attrs_irreps=inter.edge_attrs_irreps, 67 | edge_feats_irreps=inter.edge_feats_irreps, 68 | edge_hidden_irreps=edge_hidden_irreps, 69 | ), 70 | **kwargs, 71 | ) 72 | for inter in self.mace.interactions 73 | ] 74 | ) 75 | else: 76 | self.mace_inter_irreps = sum( 77 | [inter.hidden_irreps for inter in self.mace.interactions], o3.Irreps() 78 | ) 79 | 80 | self.matrix_readouts = graph2mat_cls( 81 | irreps=dict( 82 | node_attrs_irreps=self.mace.interactions[0].node_attrs_irreps, 83 | node_feats_irreps=self.mace_inter_irreps, 84 | edge_attrs_irreps=self.mace.interactions[0].edge_attrs_irreps, 85 | edge_feats_irreps=self.mace.interactions[0].edge_feats_irreps, 86 | edge_hidden_irreps=edge_hidden_irreps, 87 | ), 88 | **kwargs, 89 | ) 90 | 91 | def forward( 92 | self, data: TorchBasisMatrixData, compute_force: bool = False, **kwargs 93 | ) -> dict[str, torch.Tensor]: 94 | """Forward pass of the model. 95 | 96 | Parameters 97 | ---------- 98 | data : 99 | Input data. 100 | compute_force : 101 | Passed directly to the ``compute_force`` argument of the MACE model. 102 | **kwargs : 103 | Additional keyword arguments to pass to the MACE 104 | model for the forward pass. 105 | 106 | Returns 107 | ------- 108 | output : 109 | The output of the MACE model, with the additional keys "node_labels" 110 | and "edge_labels" containing the output of ``Graph2Mat``. 
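Examples
--------
A minimal, illustrative sketch of how this wrapper might be used (the
MACE constructor arguments and the ``unique_basis`` keyword below are
assumptions for the example, not requirements of this method)::

    mace_model = MACE(...)  # any configured MACE model
    model = MatrixMACE(mace_model, unique_basis=basis_table)
    out = model(data)
    out["node_labels"], out["edge_labels"]  # flat matrix predictions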
111 | """ 112 | mace_out = self.mace(data, compute_force=compute_force, **kwargs) 113 | 114 | # Compute edge feats and edge attrs from the modules in the mace model 115 | # (we can't access them from the model because they are not stored/outputted, 116 | # but they are very cheap to recompute) 117 | vectors, lengths = get_edge_vectors_and_lengths( 118 | positions=data["positions"], 119 | edge_index=data["edge_index"], 120 | shifts=data["shifts"], 121 | ) 122 | edge_attrs = self.mace.spherical_harmonics(vectors) 123 | edge_feats = self.mace.radial_embedding( 124 | lengths, data["node_attrs"], data["edge_index"], self.mace.atomic_numbers 125 | ) 126 | 127 | data_for_readout = copy(data) 128 | 129 | data_for_readout["edge_attrs"] = edge_attrs 130 | data_for_readout["edge_feats"] = edge_feats 131 | 132 | # data._edge_attrs_keys = (*data._edge_attrs_keys, "edge_attrs", "edge_feats") 133 | 134 | # Apply the readouts. 135 | if not self.readout_per_interaction: 136 | # Readout from the whole set of features 137 | node_labels, edge_labels = self.matrix_readouts( 138 | data=data_for_readout, 139 | node_feats=mace_out["node_feats"], 140 | ) 141 | else: 142 | # Go interaction by interaction and grab the features that each one produced 143 | # Apply the readout to each interaction and then sum them all. 144 | used = 0 145 | node_labels_list = [] 146 | edge_labels_list = [] 147 | for i, readout in enumerate(self.matrix_readouts): 148 | inter_dim = self.mace_inter_irreps[i].dim 149 | inter_node_feats = mace_out["node_feats"][:, used : used + inter_dim] 150 | used += inter_dim 151 | 152 | node_labels, edge_labels = readout( 153 | data=data_for_readout, 154 | node_feats=inter_node_feats, 155 | ) 156 | 157 | node_labels_list.append(node_labels) 158 | edge_labels_list.append(edge_labels) 159 | 160 | node_labels = torch.stack(node_labels_list).mean(axis=0) 161 | edge_labels = torch.stack(edge_labels_list).mean(axis=0) 162 | 163 | return {**mace_out, "node_labels": node_labels, "edge_labels": edge_labels} 164 | -------------------------------------------------------------------------------- /src/graph2mat/tools/__init__.py: -------------------------------------------------------------------------------- 1 | """Assortment of tools to help with the practical use of e3nn_matrix""" 2 | -------------------------------------------------------------------------------- /src/graph2mat/tools/cli/__init__.py: -------------------------------------------------------------------------------- 1 | """Implements ``e3nn_matrix``'s cli, ``e3mat``, which uses ``typer`` 2 | 3 | Do: 4 | 5 | ``` 6 | e3mat --help 7 | ``` 8 | 9 | for help on how the CLI works. 
10 | """ 11 | -------------------------------------------------------------------------------- /src/graph2mat/tools/cli/cli.py: -------------------------------------------------------------------------------- 1 | import typer 2 | 3 | from .siesta.main_cli import app as siesta_app 4 | from .models.cli import app as models_app 5 | from .serve import app as serve_app 6 | from .request import app as request_app 7 | 8 | app = typer.Typer( 9 | help="Command line interface for e3nn_matrix functionality.", 10 | pretty_exceptions_show_locals=False, 11 | rich_markup_mode="markdown", 12 | ) 13 | 14 | app.add_typer(models_app, name="models") 15 | app.add_typer(siesta_app, name="siesta") 16 | app.add_typer(serve_app, name="serve") 17 | app.add_typer(request_app, name="request") 18 | 19 | if __name__ == "__main__": 20 | app() 21 | -------------------------------------------------------------------------------- /src/graph2mat/tools/cli/models/cli.py: -------------------------------------------------------------------------------- 1 | import typer 2 | 3 | from .mace.cli import app as mace_app 4 | 5 | app = typer.Typer( 6 | help=""" 7 | Interface to ML models that have been adapted to use e3nn_matrix. 8 | For each model, we just defer to the Pytorch Lightning CLI, so if 9 | you do --help, you will see the Pytorch Lightning CLI help. 10 | 11 | NOTE: We did a 12 | """ 13 | ) 14 | 15 | app.add_typer(mace_app, name="mace") 16 | 17 | if __name__ == "__main__": 18 | app() 19 | -------------------------------------------------------------------------------- /src/graph2mat/tools/cli/models/mace/cli.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import typer 4 | 5 | app = typer.Typer(help="Interface to MACE models") 6 | 7 | 8 | @app.command( 9 | context_settings={"allow_extra_args": True, "ignore_unknown_options": True}, 10 | add_help_option=False, 11 | ) 12 | def main(ctx: typer.Context): 13 | """Main MACE model interface, using pytorch lightning CLI.""" 14 | from graph2mat.tools.lightning.models.mace import LitMACEMatrixModel 15 | from graph2mat.tools.lightning import ( 16 | OrbitalMatrixCLI, 17 | MatrixDataModule, 18 | SaveConfigSkipBasisTableCallback, 19 | ) 20 | 21 | sys.argv = [ctx.command_path, *ctx.args] 22 | OrbitalMatrixCLI( 23 | LitMACEMatrixModel, 24 | MatrixDataModule, 25 | save_config_callback=SaveConfigSkipBasisTableCallback, 26 | ) 27 | 28 | 29 | if __name__ == "__main__": 30 | app() 31 | -------------------------------------------------------------------------------- /src/graph2mat/tools/cli/request.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from typing_extensions import Annotated 3 | from requests.exceptions import HTTPError 4 | import warnings 5 | 6 | import typer 7 | 8 | app = typer.Typer(help="Utilities to easily interact with the e3nn server.") 9 | 10 | 11 | @app.command() 12 | def avail_models( 13 | host: Annotated[ 14 | str, 15 | typer.Option( 16 | help="Host where the server is running.", envvar="E3MAT_SERVER_HOST" 17 | ), 18 | ] = "localhost", 19 | port: Annotated[ 20 | int, 21 | typer.Option( 22 | help="Port where the server is listening.", envvar="E3MAT_SERVER_PORT" 23 | ), 24 | ] = 56000, 25 | url: Annotated[ 26 | Optional[str], 27 | typer.Option(help="URL of the server.", envvar="E3MAT_SERVER_URL"), 28 | ] = None, 29 | ): 30 | """Shows the models available in the server.""" 31 | from graph2mat.tools.server import ServerClient 32 | 33 | client = 
ServerClient(host=host, port=port, url=url) 34 | 35 | print(client.avail_models()) 36 | 37 | 38 | @app.command() 39 | def predict( 40 | geometry: Annotated[ 41 | str, 42 | typer.Argument( 43 | help="Path to the geometry file for which the prediction is desired." 44 | ), 45 | ], 46 | output: Annotated[ 47 | str, 48 | typer.Argument(help="Path to the file where the prediction should be saved."), 49 | ], 50 | model: Annotated[ 51 | str, typer.Option(help="Name of the model to use for the prediction.") 52 | ], 53 | local: Annotated[ 54 | bool, typer.Option(help="Whether the paths are in the local filesystem.") 55 | ] = False, 56 | host: Annotated[ 57 | str, 58 | typer.Option( 59 | help="Host where the server is running.", envvar="E3MAT_SERVER_HOST" 60 | ), 61 | ] = "localhost", 62 | port: Annotated[ 63 | int, 64 | typer.Option( 65 | help="Port where the server is listening.", envvar="E3MAT_SERVER_PORT" 66 | ), 67 | ] = 56000, 68 | url: Annotated[ 69 | Optional[str], 70 | typer.Option(help="URL of the server.", envvar="E3MAT_SERVER_URL"), 71 | ] = None, 72 | ): 73 | """Predict the matrix for a given geometry.""" 74 | # Import the server client class, which will be used to interact with the server. 75 | from graph2mat.tools.server import ServerClient 76 | 77 | client = ServerClient(host=host, port=port, url=url) 78 | 79 | try: 80 | client.predict(geometry=geometry, output=output, model=model, local=local) 81 | except HTTPError as e: 82 | if e.response.status_code == 422: 83 | avail_models = client.avail_models() 84 | if model not in avail_models: 85 | raise ValueError( 86 | f"Model '{model}' not available in the server. Available models are: {avail_models}" 87 | ) 88 | 89 | raise e 90 | 91 | 92 | if __name__ == "__main__": 93 | app() 94 | -------------------------------------------------------------------------------- /src/graph2mat/tools/cli/serve.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from typing_extensions import Annotated 3 | 4 | import typer 5 | 6 | app = typer.Typer() 7 | 8 | 9 | @app.callback(invoke_without_command=True) 10 | def main( 11 | models: Annotated[ 12 | List[str], 13 | typer.Argument( 14 | help="""List of models to load. Each model can be provided either as a .ckpt file, a .yaml specification 15 | or a directory that contains a 'spec.yaml' file. 16 | Regardless of what you provide, you can specify the name of the model like 'model_name:file.ckpt', 17 | that is, separated from the file name using a colon.""" 18 | ), 19 | ], 20 | host: Annotated[ 21 | str, typer.Option(help="Host to launch the server.", envvar="E3MAT_SERVER_HOST") 22 | ] = "localhost", 23 | port: Annotated[ 24 | int, 25 | typer.Option( 26 | help="Port where the server should listen.", envvar="E3MAT_SERVER_PORT" 27 | ), 28 | ] = 56000, 29 | cpu: Annotated[ 30 | bool, 31 | typer.Option( 32 | help="Load parameters on the CPU regardless of whether they were on the GPU." 33 | ), 34 | ] = True, 35 | local: Annotated[ 36 | bool, 37 | typer.Option( 38 | help="If True, the server allows the user to ask for changes in the local file system." 39 | ), 40 | ] = False, 41 | ): 42 | import uvicorn 43 | 44 | from graph2mat.tools.server import create_server_app_from_filesystem 45 | 46 | # Sanitize the ckpt files, building a dictionary with names and files.
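# For example (illustrative): ["my_model:weights.ckpt", "other.ckpt"] becomes
# {"my_model": "weights.ckpt", "1": "other.ckpt"}; models given without an
# explicit name are keyed by their position in the list.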
47 | ckpt_files_dict = {} 48 | for i, model_file in enumerate(models): 49 | splitted = model_file.split(":") 50 | 51 | if len(splitted) == 2: 52 | model_name, model_file = splitted 53 | else: 54 | model_name = str(i) 55 | 56 | ckpt_files_dict[model_name] = model_file 57 | 58 | # Then build the app 59 | fastapi_app = create_server_app_from_filesystem( 60 | ckpt_files_dict, cpu=cpu, local=local 61 | ) 62 | 63 | # And launch it. 64 | uvicorn.run(fastapi_app, host=host, port=port) 65 | 66 | 67 | if __name__ == "__main__": 68 | app() 69 | -------------------------------------------------------------------------------- /src/graph2mat/tools/cli/siesta/main_cli.py: -------------------------------------------------------------------------------- 1 | import typer 2 | 3 | from .md import app as md_app 4 | 5 | app = typer.Typer( 6 | help="Set of utilities to interface the machine learning models with SIESTA." 7 | ) 8 | 9 | app.add_typer(md_app, name="md") 10 | -------------------------------------------------------------------------------- /src/graph2mat/tools/cli/siesta/md.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Union 2 | from typing_extensions import Annotated 3 | 4 | from pathlib import Path 5 | 6 | import typer 7 | 8 | from graph2mat.tools.cli._typer import annotate_typer 9 | from graph2mat.tools.siesta.md import ( 10 | visualize_performance_table, 11 | setup, 12 | setup_store, 13 | ) 14 | 15 | app = typer.Typer(help="Utilities for molecular dynamics runs.") 16 | 17 | app.command("analyze")(annotate_typer(visualize_performance_table)) 18 | 19 | app.command()(annotate_typer(setup)) 20 | 21 | app.command()(annotate_typer(setup_store)) 22 | 23 | if __name__ == "__main__": 24 | app() 25 | -------------------------------------------------------------------------------- /src/graph2mat/tools/lightning/__init__.py: -------------------------------------------------------------------------------- 1 | """Interface to use the matrix models with ``pytorch_lightning``. 2 | 3 | Pytorch lightning is a very useful tool to streamline the training and deployment 4 | of machine learning models. It is the primary tool that ``graph2mat`` has 5 | chosen to support for this purpose. Keep in mind that you can use whatever 6 | you want, though! 7 | 8 | In this module we implement the data and model classes, which are just 9 | interfaces to the pure ``graph2mat`` classes. We also implement a CLI 10 | that is just ``pytorch_lightning``'s CLI with some tweaks that we think 11 | make its usage smoother for matrix learning.
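As an illustrative example (the exact arguments are an assumption here, see the
CLI help for the real interface), a training run could look something like
``e3mat models mace fit --config config.yaml``.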
12 | 13 | """ 14 | 15 | from .model import LitBasisMatrixModel 16 | from .callbacks import * 17 | from .data import MatrixDataModule 18 | from .cli import OrbitalMatrixCLI, SaveConfigSkipBasisTableCallback 19 | -------------------------------------------------------------------------------- /src/graph2mat/tools/lightning/model.py: -------------------------------------------------------------------------------- 1 | """Wrapping of raw models to use them in pytorch_lightning.""" 2 | 3 | from pathlib import Path 4 | from typing import Type, Union, Optional 5 | import warnings 6 | 7 | from e3nn import o3 8 | 9 | import pytorch_lightning as pl 10 | import torch 11 | 12 | # from context import mace 13 | from graph2mat.core.data.metrics import OrbitalMatrixMetric, block_type_mse 14 | from graph2mat import BasisTableWithEdges, AtomicTableWithEdges 15 | from graph2mat.bindings.torch.load import sanitize_checkpoint 16 | from graph2mat.core.data.node_feats import NodeFeature 17 | from graph2mat import __version__ 18 | 19 | 20 | class LitBasisMatrixModel(pl.LightningModule): 21 | """Base class to wrap a matrix model to use it in pytorch_lightning.""" 22 | 23 | basis_table: BasisTableWithEdges 24 | model: torch.nn.Module 25 | model_kwargs: dict 26 | 27 | def __init__( 28 | self, 29 | model_cls: Type[torch.nn.Module], 30 | root_dir: str = ".", 31 | basis_files: Union[str, None] = None, 32 | basis_table: Union[BasisTableWithEdges, None] = None, 33 | no_basis: Optional[dict] = None, 34 | loss: Type[OrbitalMatrixMetric] = block_type_mse, 35 | initial_node_feats: str = "OneHotZ", 36 | **kwargs, 37 | ): 38 | super().__init__() 39 | 40 | self.save_hyperparameters() 41 | 42 | if basis_table is None: 43 | if basis_files is None: 44 | self.basis_table = None 45 | else: 46 | self.basis_table = AtomicTableWithEdges.from_basis_glob( 47 | Path(root_dir).glob(basis_files), no_basis_atoms=no_basis 48 | ) 49 | else: 50 | self.basis_table = basis_table 51 | 52 | self.initial_node_feats = [ 53 | NodeFeature.registry[k] for k in initial_node_feats.split(" ") 54 | ] 55 | self.initial_node_feats_irreps = sum( 56 | [f.get_e3nn_irreps(self.basis_table) for f in self.initial_node_feats], 57 | o3.Irreps(), 58 | ).simplify() 59 | 60 | self.loss_fn = loss() 61 | 62 | self.model_cls = model_cls 63 | self.model = None # Subclasses are responsible for initializing the model by calling init_model. 
64 | 65 | def init_model(self, **kwargs): 66 | """Initializes the model, storing the arguments used.""" 67 | self.model_kwargs = kwargs 68 | self.model = self.model_cls(**self.model_kwargs) 69 | return self.model 70 | 71 | def forward(self, x): 72 | return self.model(x) 73 | 74 | def training_step(self, batch, batch_idx): 75 | out = self.model(batch) 76 | 77 | loss, stats = self.loss_fn( 78 | nodes_pred=out["node_labels"], 79 | nodes_ref=batch["point_labels"], 80 | edges_pred=out["edge_labels"], 81 | edges_ref=batch["edge_labels"], 82 | batch=batch, 83 | basis_table=self.basis_table, 84 | ) 85 | 86 | self.log( 87 | "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True 88 | ) 89 | 90 | for k, v in stats.items(): 91 | self.log( 92 | f"train_{k}", 93 | v, 94 | on_step=True, 95 | on_epoch=True, 96 | prog_bar=False, 97 | logger=True, 98 | ) 99 | 100 | return {**out, "loss": loss} 101 | 102 | def validation_step(self, batch, batch_idx): 103 | out = self.model(batch) 104 | 105 | loss, stats = self.loss_fn( 106 | nodes_pred=out["node_labels"], 107 | nodes_ref=batch["point_labels"], 108 | edges_pred=out["edge_labels"], 109 | edges_ref=batch["edge_labels"], 110 | batch=batch, 111 | basis_table=self.basis_table, 112 | log_verbose=True, 113 | ) 114 | 115 | self.log("val_loss", loss, prog_bar=True, logger=True) 116 | # save validation loss as the hyperparameter opt metric (used by tensorboard) 117 | self.log("hp_metric", loss) 118 | 119 | for k, v in stats.items(): 120 | self.log(f"val_{k}", v) 121 | 122 | return {**out, "loss": loss} 123 | 124 | def test_step(self, batch, batch_idx): 125 | out = self.model(batch) 126 | 127 | loss, stats = self.loss_fn( 128 | nodes_pred=out["node_labels"], 129 | nodes_ref=batch["point_labels"], 130 | edges_pred=out["edge_labels"], 131 | edges_ref=batch["edge_labels"], 132 | batch=batch, 133 | basis_table=self.basis_table, 134 | log_verbose=True, 135 | ) 136 | 137 | self.log("test_loss", loss, prog_bar=True, logger=True) 138 | 139 | for k, v in stats.items(): 140 | self.log(f"test_{k}", v) 141 | 142 | return out 143 | 144 | def on_save_checkpoint(self, checkpoint) -> None: 145 | "Objects to include in checkpoint file" 146 | checkpoint["basis_table"] = self.basis_table 147 | checkpoint["version"] = __version__ 148 | 149 | def on_load_checkpoint(self, checkpoint) -> None: 150 | "Objects to retrieve from checkpoint file" 151 | san_checkpoint = sanitize_checkpoint(checkpoint) 152 | checkpoint.update(san_checkpoint) 153 | 154 | try: 155 | self.basis_table = checkpoint["basis_table"] 156 | except KeyError: 157 | warnings.warn( 158 | "Failed to load basis_table from checkpoint: Key does not exist." 
159 | ) 160 | 161 | try: 162 | ckpt_version = checkpoint["version"] 163 | except KeyError: 164 | ckpt_version = None 165 | warnings.warn("Unable to determine version that created checkpoint file") 166 | if ckpt_version: 167 | if not (ckpt_version == __version__): 168 | warnings.warn( 169 | "The checkpoint version %s does not match the current package version %s" 170 | % (ckpt_version, __version__) 171 | ) 172 | -------------------------------------------------------------------------------- /src/graph2mat/tools/lightning/tests/test_lightning.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytorch_lightning as pl 3 | import sisl 4 | import tempfile 5 | 6 | import pytest 7 | 8 | from graph2mat import ( 9 | BasisConfiguration, 10 | PointBasis, 11 | BasisTableWithEdges, 12 | ) 13 | 14 | from graph2mat.tools.lightning.models.mace import LitMACEMatrixModel 15 | from graph2mat.tools.lightning import ( 16 | MatrixDataModule, 17 | MatrixWriter, 18 | SamplewiseMetricsLogger, 19 | PlotMatrixError, 20 | ) 21 | 22 | 23 | @pytest.mark.parametrize("out_matrix", ["density_matrix", "hamiltonian"]) 24 | def test_lightning_tools_python(out_matrix): 25 | # The basis 26 | point_1 = PointBasis("A", R=2, basis="0e", basis_convention="spherical") 27 | point_2 = PointBasis("B", R=5, basis="2x0e + 1o", basis_convention="spherical") 28 | 29 | basis = [point_1, point_2] 30 | 31 | # The basis table. 32 | table = BasisTableWithEdges(basis) 33 | 34 | # Lightning model 35 | model = LitMACEMatrixModel( 36 | basis_table=table, 37 | hidden_irreps="0e + 1o + 2e", 38 | symmetric_matrix=True, 39 | ) 40 | 41 | # Configurations (just one with a random matrix) 42 | config1 = BasisConfiguration( 43 | point_types=["A", "B", "A"], 44 | positions=np.array([[0, 0, 0], [6.0, 0, 0], [12, 0, 0]]), 45 | basis=basis, 46 | cell=np.eye(3) * 100, 47 | pbc=(False, False, False), 48 | matrix=np.random.random((7, 7)), 49 | ) 50 | configs = [config1] 51 | 52 | # Create the datamodule 53 | datamodule = MatrixDataModule( 54 | out_matrix, 55 | basis_table=table, 56 | symmetric_matrix=True, 57 | sub_point_matrix=False, 58 | train_runs=configs, 59 | val_runs=configs, 60 | test_runs=configs, 61 | ) 62 | 63 | # Temporary files for the callbacks 64 | matrix_file = tempfile.NamedTemporaryFile( 65 | delete=False, suffix=".DM" if out_matrix == "density_matrix" else ".HSX" 66 | ) 67 | metrics_file = tempfile.NamedTemporaryFile(delete=False, suffix=".csv") 68 | 69 | # List of callbacks to use during training 70 | callbacks = [ 71 | MatrixWriter(matrix_file.name, splits=["test"]), 72 | SamplewiseMetricsLogger(splits=["val"], output_file=metrics_file.name), 73 | PlotMatrixError(split="test", show=False, store_in_logger=False), 74 | ] 75 | 76 | # Create the trainer 77 | trainer = pl.Trainer( 78 | callbacks=callbacks, max_epochs=1, logger=False, enable_checkpointing=False 79 | ) 80 | 81 | # Run training (1 epoch) 82 | trainer.fit(model, datamodule=datamodule) 83 | 84 | # Run validation to get the validation metrics 85 | val_metrics, *_ = trainer.validate(model, datamodule=datamodule) 86 | # Make sure the matrix writer hasn't written anything yet 87 | # (it is set to only write on test) 88 | assert matrix_file.read() == bytes() 89 | 90 | # Run test to get the test metrics 91 | test_metrics, *_ = trainer.test(model, datamodule=datamodule) 92 | 93 | # The test set and validation sets are the same, therefore 94 | # the metrics should be the same 95 | for k in val_metrics: 96 | if not 
k.startswith("val_"): 97 | continue 98 | test_k = k.replace("val_", "test_") 99 | assert ( 100 | val_metrics[k] == test_metrics[test_k] 101 | ), f"Validation and test metrics do not match for {k} and {test_k}" 102 | 103 | # Check that the written matrix is correct (it should be the same 104 | # as the one returned by the model) 105 | data = next(iter(datamodule.train_dataloader())) 106 | out = model(data) 107 | pred_matrix = datamodule.data_processor.matrix_from_data(data, out)[0] 108 | if out_matrix == "density_matrix": 109 | read_matrix = sisl.get_sile(matrix_file.name).read_density_matrix() 110 | elif out_matrix == "hamiltonian": 111 | read_matrix = sisl.get_sile(matrix_file.name).read_hamiltonian() 112 | 113 | assert np.allclose(read_matrix.tocsr().toarray(), pred_matrix.tocsr().toarray()) 114 | -------------------------------------------------------------------------------- /src/graph2mat/tools/server/__init__.py: -------------------------------------------------------------------------------- 1 | """Utilites to serve and request matrix predictions. 2 | 3 | Initializing a model takes time. For this reason, sometimes 4 | it is useful to have a server running the model, waiting for 5 | requests to compute predictions. 6 | 7 | This module implements a very simple HTTP server to provide matrix 8 | predictions, as well as a very simple HTML front end and a very 9 | simple python client API so that you are not forced to use raw requests. 10 | 11 | Launching the server is easiest from the ``e3mat`` CLI with ``e3mat serve``. 12 | You can also use ``e3mat request``, which uses the client. 13 | """ 14 | 15 | from .api_client import ServerClient 16 | from .server_app import create_server_app, create_server_app_from_filesystem 17 | 18 | __all__ = ["ServerClient", "create_server_app", "create_server_app_from_filesystem"] 19 | -------------------------------------------------------------------------------- /src/graph2mat/tools/server/api_client.py: -------------------------------------------------------------------------------- 1 | """Simple HTTP client to interact with the server.""" 2 | 3 | from typing import Union, List, Type, Optional 4 | 5 | import tempfile 6 | import os 7 | import urllib.parse 8 | 9 | import requests 10 | from pathlib import Path 11 | 12 | import sisl 13 | 14 | 15 | class ServerClient: 16 | """Client to interact easily with the e3nn server. 17 | 18 | Parameters 19 | ---------- 20 | url : str or None, optional 21 | Root url where the server is running. 22 | 23 | If it is set to None, the environment variable E3MAT_SERVER_URL 24 | will be used if present. 25 | 26 | If it is set to None and the environment variable is not present, 27 | the url will be constructed from the values of ``host`` and ``port``. 28 | host : str or None, optional 29 | Host where the server is running. 30 | 31 | If it is set to None, the environment variable E3MAT_SERVER_HOST 32 | will be used if present, otherwise the default value of ``host`` will 33 | be used. 34 | port : int, optional 35 | Port where the server is listening. 36 | 37 | If it is set to None, the environment variable E3MAT_SERVER_PORT 38 | will be used if present, otherwise the default value of ``port`` will 39 | be used. 
40 | """ 41 | 42 | host: Union[str, None] 43 | port: Union[int, None] 44 | root_url: str 45 | 46 | def __init__( 47 | self, 48 | url: Optional[str] = None, 49 | host: Optional[str] = "localhost", 50 | port: Optional[int] = 56000, 51 | ): 52 | url = url or os.environ.get("E3MAT_SERVER_URL", None) 53 | 54 | if url is None: 55 | self.host = host or os.environ.get("E3MAT_SERVER_HOST", "localhost") 56 | self.port = port or int(os.environ.get("E3MAT_SERVER_PORT", 56000)) 57 | 58 | self.root_url = f"http://{self.host}:{self.port}" 59 | else: 60 | parsed_url = urllib.parse.urlparse(url) 61 | 62 | self.host = parsed_url.hostname 63 | self.port = parsed_url.port 64 | self.root_url = url 65 | 66 | self.api_url = f"{self.root_url}/api" 67 | 68 | def avail_models(self) -> List[str]: 69 | """Returns the models that are available in the server.""" 70 | response = requests.get(f"{self.api_url}/avail_models") 71 | response.raise_for_status() 72 | return response.json() 73 | 74 | def predict( 75 | self, 76 | geometry: Union[str, Path, sisl.Geometry], 77 | output: Union[str, Path, Type[sisl.SparseOrbital]], 78 | model: str, 79 | local: bool = False, 80 | ) -> Union[Path, sisl.SparseOrbital]: 81 | """Predicts the matrix for a given geometry. 82 | 83 | Parameters 84 | ---------- 85 | geometry : Union[str, Path, sisl.Geometry] 86 | Either the path to the geometry file or the geometry itself. 87 | output : Union[str, Path, Type[sisl.SparseOrbital]] 88 | Either the path to the file where the prediction should be saved 89 | or the type of the object to be returned. 90 | model : str 91 | Name of the model to use for the prediction. This model must be available 92 | in the server. You can check the available models with the `avail_models` method. 93 | local : bool 94 | Whether the given paths refer to the server's local filesystem, in which case no files are transferred. 95 | """ 96 | # We may need temporary files to transmit the geometry and the output 97 | input_tmp_file = None 98 | output_tmp_file = None 99 | 100 | # Regardless of what the argument 'geometry' is, we will pass a file 101 | # to the server. 102 | # If a geometry is provided, we store it to a temporary file. 103 | if isinstance(geometry, sisl.Geometry): 104 | input_tmp_file = tempfile.NamedTemporaryFile(suffix=".xyz", delete=False) 105 | geometry_path = Path(input_tmp_file.name).absolute() 106 | geometry.write(geometry_path) 107 | else: 108 | geometry_path = Path(geometry).absolute() 109 | 110 | # Regardless of what the argument 'output' is, the server's output will 111 | # be written to a file. 112 | # If the user wants a sisl object, we will just tell the server to output 113 | # to a temporary file, which we will then parse. 114 | if isinstance(output, type) and issubclass(output, sisl.SparseOrbital): 115 | suffix = { 116 | sisl.DensityMatrix: ".DM", 117 | sisl.Hamiltonian: ".TSHS", 118 | sisl.EnergyDensityMatrix: ".EDM", 119 | }[output] 120 | 121 | output_tmp_file = tempfile.NamedTemporaryFile(suffix=suffix, delete=False) 122 | output_path = Path(output_tmp_file.name).absolute() 123 | else: 124 | output_path = Path(output).absolute() 125 | 126 | # Now we can make the request to the server. 127 | if local: 128 | # We ask the server to read and write to its local filesystem.
129 | response = requests.get( 130 | f"{self.api_url}/models/{model}/local_write_predict", 131 | params={ 132 | "geometry_path": str(geometry_path), 133 | "output_path": str(output_path), 134 | }, 135 | ) 136 | 137 | response.raise_for_status() 138 | else: 139 | # We send the geometry file from our filesystem to the server. 140 | # We will also receive a file from the server (binary content). 141 | files = {"geometry_file": open(geometry_path, "rb")} 142 | 143 | response = requests.post( 144 | f"{self.api_url}/models/{model}/predict", files=files 145 | ) 146 | 147 | response.raise_for_status() 148 | 149 | with open(output_path, "wb") as f: 150 | f.write(response.content) 151 | 152 | # Remove the temporary files if they were created. 153 | if input_tmp_file is not None: 154 | Path(input_tmp_file.name).unlink() 155 | 156 | if output_tmp_file is not None: 157 | # In the case that we wanted a sisl object as output, we parse the 158 | # file that the server sent us, using the read method of the sisl 159 | # object. 160 | returns = output.read(output_path) 161 | Path(output_path).unlink() 162 | else: 163 | returns = output_path 164 | 165 | return returns 166 | -------------------------------------------------------------------------------- /src/graph2mat/tools/server/frontend/static/javascript/form.js: -------------------------------------------------------------------------------- 1 | /* Functions that handle interaction with the predictions form. 2 | This file must be included whenever the form is used. */ 3 | 4 | updateFileNames = function(input, output, remove) { 5 | 6 | if (remove || input.files.length == 0) { 7 | output.textContent = "No file selected"; 8 | } else { 9 | output.textContent = "Uploaded file: " + input.files.item(0).name; 10 | } 11 | 12 | } 13 | 14 | fileInputChange = function(event) { 15 | var input = event.target; 16 | var output = input.closest(".file-drop-container").querySelectorAll(".uploaded-files")[0]; 17 | 18 | updateFileNames(input, output); 19 | } 20 | 21 | fileDrop = function(event) { 22 | event.preventDefault(); 23 | var input = event.target.querySelectorAll("input[type=file]")[0]; 24 | var output = event.target.closest(".file-drop-container").querySelectorAll(".uploaded-files")[0]; 25 | 26 | input.files = event.dataTransfer.files; 27 | event.target.classList.remove("border-teal-600"); 28 | updateFileNames(input, output, false); 29 | } 30 | 31 | fileDrag = function(event) { 32 | event.preventDefault(); 33 | event.target.classList.add("border-teal-600"); 34 | } 35 | 36 | fileDragLeave = function(event) { 37 | event.preventDefault(); 38 | event.target.classList.remove("border-teal-600"); 39 | } 40 | 41 | formReset = function(event) { 42 | form = event.target; 43 | 44 | // Update all the file inputs 45 | var file_inputs = form.querySelectorAll(".file-drop-container"); 46 | file_inputs.forEach((input) => updateFileNames(input.querySelectorAll("input[type=file]")[0], input.querySelectorAll(".uploaded-files")[0], true)); 47 | 48 | document.getElementById("loading").classList.add("hidden"); 49 | document.getElementById("error").classList.add("hidden"); 50 | } 51 | 52 | formSubmit = function(event) { 53 | event.preventDefault(); 54 | 55 | form = event.target; 56 | 57 | document.getElementById("loading").classList.remove("hidden"); 58 | document.getElementById("error").classList.add("hidden"); 59 | document.getElementById("submit_button").disabled = true; 60 | 61 | // TODO do something here to show user that form is being submitted 62 | fetch(form.action, { 63 | method:
form.method, 64 | body: new FormData(form) 65 | }).then(async (response) => { 66 | 67 | document.getElementById("loading").classList.add("hidden"); 68 | document.getElementById("submit_button").disabled = false; 69 | 70 | if (!response.ok) { 71 | 72 | var error_element = document.getElementById("error") 73 | var error_m_element = document.getElementById("error_message") 74 | error_m_element.textContent = `HTTP error! Status: ${response.status}`; 75 | error_element.classList.remove("hidden"); 76 | throw new Error(`HTTP error! Status: ${response.status}`); // Bail out: an error response carries no file to download. 77 | } 78 | 79 | // Get the name of the received file 80 | var filename = response.headers.get("content-disposition").split("filename=")[1].split(";")[0].slice(1, -1) 81 | 82 | return {filename: filename, blob: await response.blob()}; 83 | }).then(({filename, blob}) => { 84 | var file_url = window.URL.createObjectURL(blob); 85 | 86 | let link = document.createElement('a'); 87 | link.href = file_url; 88 | link.download = filename; 89 | link.click(); 90 | 91 | }).catch((error) => { 92 | // TODO handle error 93 | console.warn(error) 94 | }); 95 | } 96 | 97 | formSubmitTest = function(event) { 98 | event.preventDefault(); 99 | 100 | form = event.target; 101 | 102 | document.getElementById("loading").classList.remove("hidden"); 103 | document.getElementById("error").classList.add("hidden"); 104 | document.getElementById("submit_button").disabled = true; 105 | 106 | // TODO do something here to show user that form is being submitted 107 | fetch(form.action, { 108 | method: form.method, 109 | body: new FormData(form) 110 | }).then(async (response) => { 111 | 112 | document.getElementById("loading").classList.add("hidden"); 113 | document.getElementById("submit_button").disabled = false; 114 | 115 | if (!response.ok) { 116 | 117 | var error_element = document.getElementById("error") 118 | var error_m_element = document.getElementById("error_message") 119 | error_m_element.textContent = `HTTP error! Status: ${response.status}`; 120 | error_element.classList.remove("hidden"); 121 | 122 | } 123 | 124 | return response.json(); 125 | }).then((json) => { 126 | 127 | var pre_element = document.createElement("pre"); 128 | pre_element.textContent = JSON.stringify(json, null, 2); 129 | 130 | output_div = document.getElementById("output_div") 131 | output_div.classList.remove("hidden"); 132 | 133 | // Add pre tag to output div 134 | output_div.appendChild(pre_element); 135 | 136 | }).catch((error) => { 137 | // TODO handle error 138 | console.warn(error) 139 | }); 140 | } 141 | -------------------------------------------------------------------------------- /src/graph2mat/tools/server/frontend/static/styles/styles.css: -------------------------------------------------------------------------------- 1 | /* Scrolling behavior is specified in the html templates using the 2 | overflow-y-auto class.
Target those elements and customize the scrollbar. */ 3 | 4 | /* For Firefox Browser */ 5 | .overflow-y-auto { 6 | scrollbar-width: thin; 7 | scrollbar-color: #000 #ccc; 8 | } 9 | 10 | /* For Chrome, EDGE, Opera, Others */ 11 | .overflow-y-auto::-webkit-scrollbar { 12 | width: 5px; 13 | } 14 | 15 | .overflow-y-auto::-webkit-scrollbar-track { 16 | background: #fff; 17 | } 18 | 19 | .overflow-y-auto::-webkit-scrollbar-thumb { 20 | background: #ccc; 21 | } 22 | 23 | /* Styling for the metrics table */ 24 | 25 | th, 26 | td { 27 | border: 1px solid #ccc; 28 | padding: 5px; 29 | text-align: center; 30 | vertical-align: middle; 31 | } 32 | -------------------------------------------------------------------------------- /src/graph2mat/tools/server/frontend/templates/about.html: -------------------------------------------------------------------------------- 1 | {% extends 'index.html' %} 2 | 3 | {% block content %} 4 |
5 |

Welcome to the matrix server!

6 | 7 |

8 | This server helps you use machine learning models to predict matrices without writing any code.

10 | 11 |

12 | The host of this server receives your inputs, runs the calculations, and returns the outputs, so you don't need to install anything. 13 | 14 |
This is a good place to start exploring what the models can do. If you need to use the models intensively, we recommend you 15 | download them and run them on your own computer. 16 |

17 |
18 | {% endblock %} 19 | -------------------------------------------------------------------------------- /src/graph2mat/tools/server/frontend/templates/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 |
11 | {% include 'topbar.html' %} 12 |
13 | {% block content %} 14 |
15 | {% include 'model_picker_sidebar.html' %} 16 |
17 |
18 |
19 | Pick a model on the left to start using it. 20 |
21 |
22 | {% endblock %} 23 |
24 |
25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /src/graph2mat/tools/server/frontend/templates/model_action_page.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 |
12 | {% include 'topbar.html' %} 13 |
14 |
15 | {% include 'model_picker_sidebar.html' %} 16 |
17 |
18 | {% include 'model_info_card.html' %} 19 |
20 |

Request form

21 |

Use this form to send requests to the model.

22 | 23 |

Pick an action

24 | {% include 'model_action_picker.html' %} 25 |

{{ model_action.short_help }}

26 | {% block model_action_form %} 27 |
28 | 29 |
30 | 31 | 32 |
33 |
34 | {% endblock %} 35 | 36 | 46 | 47 | 53 | 54 |
55 | 56 |
57 |
58 |
59 |
60 | 61 | 62 | -------------------------------------------------------------------------------- /src/graph2mat/tools/server/frontend/templates/model_action_picker.html: -------------------------------------------------------------------------------- 1 |
2 | 21 |
22 | -------------------------------------------------------------------------------- /src/graph2mat/tools/server/frontend/templates/model_info_card.html: -------------------------------------------------------------------------------- 1 |
2 |
3 |

Model information

4 |

Inspect the most relevant information about the model.

5 |
6 |
7 |
8 |
9 |
Model name
10 |
{{ model_name }}
11 |
12 |
13 |
Description
14 |
{{ model.description }}
15 |
16 |
17 |
Authors
18 |
19 |
    20 | {% for author in model.authors %} 21 |
  • {{ author }}
  • 22 | {% endfor %} 23 |
24 |
25 |
26 |
27 |
Data properties
28 |
29 |
    30 |
  • Symmetric matrix: {{ model.data_processor.symmetric_matrix }}
  • 31 |
  • Subtract point contributions: {{ model.data_processor.sub_point_matrix }}
  • 32 |
  • Matrix type: {{ model.data_processor.out_matrix }}
  • 33 |
34 |
35 |
36 |
37 |
Basis
38 |
39 |
    40 | {% for point_basis in model.data_processor.basis_table.basis %} 41 |
  • {{ point_basis }}
  • 42 | {% endfor %} 43 |
44 |
45 |
46 |
47 |
Test metrics summary
48 |
49 | {{ model.test_metrics_summary | safe }} 50 |
51 |
52 |
53 |
Files
54 |
55 |
    56 | {% for file in model.files.keys() %} 57 |
  • 58 |
    59 | 62 |
    63 | {{ file }} 64 | 65 |
    66 |
    67 |
    68 | Download 69 |
    70 |
  • 71 | {% endfor %} 72 |
73 |
74 |
75 |
76 |
77 |
78 | -------------------------------------------------------------------------------- /src/graph2mat/tools/server/frontend/templates/model_picker_sidebar.html: -------------------------------------------------------------------------------- 1 | 26 | -------------------------------------------------------------------------------- /src/graph2mat/tools/server/frontend/templates/predict_page.html: -------------------------------------------------------------------------------- 1 | {% extends 'model_action_page.html' %} 2 | 3 | {% block model_action_form %} 4 |
5 |
6 |
7 | 8 |
9 |
10 | 13 |
14 | 18 |

or drag and drop

19 |
20 |

Any file that sisl can parse into a geometry.

21 |
22 |
23 |

No file selected.

24 |
25 | 26 |
27 | 28 |
29 | 30 | 31 |
32 |
33 | {% endblock %} 34 | -------------------------------------------------------------------------------- /src/graph2mat/tools/server/frontend/templates/test_page.html: -------------------------------------------------------------------------------- 1 | {% extends 'model_action_page.html' %} 2 | 3 | {% block model_action_form %} 4 |
5 |
6 |
7 | 8 |
9 |
10 | 13 |
14 | 18 |

or drag and drop

19 |
20 |

Any file that sisl can parse into a geometry.

21 |
22 |
23 |

No file selected.

24 |
25 | 26 |
27 | 28 |
29 |
30 | 31 |
32 |
33 | 36 |
37 | 41 |

or drag and drop

42 |
43 |

Any file that sisl can parse into a matrix.

44 |
45 |
46 |

No file selected.

47 |
48 | 49 |
50 | 51 |
52 | 53 | 54 |
55 |
56 | {% endblock %} 57 | -------------------------------------------------------------------------------- /src/graph2mat/tools/server/frontend/templates/topbar.html: -------------------------------------------------------------------------------- 1 |
2 | 25 |
26 | -------------------------------------------------------------------------------- /src/graph2mat/tools/siesta/__init__.py: -------------------------------------------------------------------------------- 1 | """Utilities to manage the interaction with the SIESTA DFT code.""" -------------------------------------------------------------------------------- /src/graph2mat/tools/siesta/templates/fdf/dm_init_atomic.fdf: -------------------------------------------------------------------------------- 1 | DM.AllowReuse f 2 | DM.UseSaveDM f 3 | DM.AllowExtrapolation f 4 | DM.HistoryDepth 0 5 | 6 | {{ extra_string }} 7 | -------------------------------------------------------------------------------- /src/graph2mat/tools/siesta/templates/fdf/dm_init_ml.fdf: -------------------------------------------------------------------------------- 1 | DM.HistoryDepth 0 2 | DM.UseSaveDM t 3 | Lua.Script {{ lua_script }} 4 | -------------------------------------------------------------------------------- /src/graph2mat/tools/siesta/templates/fdf/dm_init_siesta_extrapolation.fdf: -------------------------------------------------------------------------------- 1 | DM.HistoryDepth {{ history_depth }} 2 | 3 | {{ extra_string }} 4 | -------------------------------------------------------------------------------- /src/graph2mat/tools/siesta/templates/lua/graph2mat.lua: -------------------------------------------------------------------------------- 1 | {% if include_store_code %} 2 | local store_dir = "{{ store_dir }}" 3 | local store_interval = {{ store_interval }} 4 | local store_step_prefix = "{{ store_step_prefix }}" 5 | local store_files = "{{ store_files }}" 6 | 7 | local istep_store = 0 8 | 9 | {% endif %} 10 | {% if include_server_code %} 11 | local server_address = "{{ server_address }}" 12 | local history_len = {{ history_len }} 13 | local main_fdf = "{{ main_fdf }}" 14 | local matrix_ref = "{{ ml_model_name }}" 15 | local work_dir = {{ work_dir }} 16 | 17 | local main_fdf_path = work_dir .. "/" .. main_fdf 18 | local istep_extrapolation = 0 19 | 20 | {% endif %} 21 | function siesta_comm() 22 | 23 | {% if include_store_code %} 24 | -- Initialize the storage directory if this is the beginning of the MD 25 | if siesta.state == siesta.INIT_MD then 26 | init_store_dir() 27 | end 28 | 29 | -- After each step, store the step files 30 | if siesta.state == siesta.FORCES then 31 | store_step(istep_store) 32 | istep_store = istep_store + 1 33 | end 34 | {% endif %} 35 | 36 | {% if include_server_code %} 37 | -- On initialization, set up the matrix history on the server 38 | if siesta.state == siesta.INITIALIZE then 39 | init_history(history_len, matrix_ref) 40 | end 41 | 42 | if siesta.state == siesta.INIT_MD then 43 | if istep_extrapolation == 0 then 44 | add_step(main_fdf_path, matrix_ref) 45 | end 46 | end 47 | 48 | if siesta.state == siesta.FORCES then 49 | add_matrix(main_fdf_path) 50 | end 51 | 52 | if siesta.state == siesta.AFTER_MOVE then 53 | istep_extrapolation = istep_extrapolation + 1 54 | 55 | add_step(main_fdf_path, matrix_ref) 56 | predict_next(work_dir .. "/siesta.DM") 57 | 58 | end 59 | {% endif %} 60 | 61 | end 62 | 63 | {% if include_store_code %} 64 | -- ---------------------------------------------------- 65 | -- MD STORAGE HELPER FUNCTIONS 66 | -- ---------------------------------------------------- 67 | 68 | function init_store_dir() 69 | 70 | if not siesta.IONode then 71 | -- only allow the IOnode to perform stuff...
72 | return 73 | end 74 | 75 | -- Create the directory where the dataset will be stored 76 | os.execute("mkdir " .. store_dir) 77 | 78 | -- Store the basis 79 | os.execute("mkdir " .. store_dir .. "/basis") 80 | os.execute("cp *.ion* " .. store_dir .. "/basis") 81 | end 82 | 83 | function store_step(istep) 84 | 85 | if not siesta.IONode then 86 | -- only allow the IOnode to perform stuff... 87 | return 88 | end 89 | 90 | -- If the step is a multiple of the store interval, store the frame 91 | if istep % store_interval == 0 then 92 | os.execute("mkdir " .. store_dir .. "/" .. store_step_prefix .. istep) 93 | os.execute("cp " .. store_files .. " " .. store_dir .. "/" .. store_step_prefix .. istep) 94 | end 95 | 96 | 97 | end 98 | {% endif %} 99 | 100 | {% if include_server_code %} 101 | -- ---------------------------------------------------- 102 | -- SERVER FUNCTIONALITY 103 | -- ---------------------------------------------------- 104 | 105 | function server_get(path) 106 | if not siesta.IONode then 107 | -- only IO node communicates with the server 108 | return 109 | end 110 | 111 | os.execute("curl '" .. server_address .. "/" .. path .. "'") 112 | end 113 | 114 | function init_history(len, matrix_ref) 115 | server_get("init?history_len=" .. len .. "&data_processor=" .. matrix_ref) 116 | end 117 | 118 | function add_step(path, matrix_ref) 119 | server_get("add_step?path=" .. path .. "&matrix_ref=" .. matrix_ref) 120 | end 121 | 122 | function add_matrix(path) 123 | server_get("add_matrix?path=" .. path) 124 | end 125 | 126 | function setup_processor(basis_dir, matrix) 127 | server_get("setup_processor?basis_dir=" .. basis_dir .. "&matrix=" .. matrix) 128 | end 129 | 130 | function predict_next(out) 131 | server_get("extrapolate?out=" .. out) 132 | end 133 | 134 | {% endif %} 135 | -------------------------------------------------------------------------------- /src/graph2mat/tools/viz/__init__.py: -------------------------------------------------------------------------------- 1 | """Visualization utilities""" 2 | 3 | from .sparse_plot import plot_basis_matrix 4 | -------------------------------------------------------------------------------- /src/graph2mat/tools/viz/sparse_plot.py: -------------------------------------------------------------------------------- 1 | import plotly.express as px 2 | import plotly.graph_objects as go 3 | 4 | import numpy as np 5 | from typing import Dict, Union 6 | 7 | import sisl 8 | from scipy.sparse import issparse, spmatrix 9 | 10 | from graph2mat import BasisConfiguration 11 | from graph2mat.bindings.e3nn.irreps_tools import get_atom_irreps 12 | 13 | 14 | def plot_basis_matrix( 15 | matrix: Union[np.ndarray, sisl.SparseCSR, spmatrix], 16 | configuration: Union[BasisConfiguration, sisl.Geometry, None] = None, 17 | point_lines: Union[bool, Dict] = False, 18 | basis_lines: Union[bool, Dict] = False, 19 | sc_lines: Union[bool, Dict] = False, 20 | colorscale: str = "RdBu", 21 | text: Union[bool, str] = False, 22 | basis_labels: bool = False, 23 | ) -> go.Figure: 24 | """Plots a matrix where rows and columns are spherical harmonics basis functions. 25 | 26 | Parameters 27 | ----------- 28 | matrix: 29 | the matrix, either as a numpy array or as a sisl sparse matrix. 30 | configuration: 31 | Should contain the point coordinates and types associated with the matrix. 32 | Only needed if separator lines or basis labels are requested. 33 | point_lines: 34 | If a boolean, whether to draw lines separating points, using default styles. 
35 | If a dict, draws the lines with the specified plotly line styles. 36 | basis_lines: 37 | If a boolean, whether to draw lines separating sets of basis functions, using default styles. 38 | If a dict, draws the lines with the specified plotly line styles. 39 | sc_lines: 40 | If a boolean, whether to draw lines separating the supercells, using default styles. 41 | If a dict, draws the lines with the specified plotly line styles. 42 | colorscale: 43 | A plotly colorscale. 44 | text: 45 | If a boolean, whether to show the value of each element as text on top of it, using plotly's 46 | default formatting. 47 | If a string, show text with the specified format. E.g. text=".3f" shows the value with three 48 | decimal places. 49 | basis_labels: 50 | Whether to label the axes with the basis function indices. If True, the labels will be of 51 | the form "P: (l, m)", where `P` is the index of the point and l and m are the indices of 52 | the spherical harmonic. 53 | """ 54 | mode = "orbitals" 55 | 56 | if isinstance(matrix, sisl.SparseOrbital): 57 | if configuration is None: 58 | configuration = matrix.geometry 59 | 60 | matrix = matrix._csr 61 | elif isinstance(matrix, sisl.SparseAtom): 62 | if configuration is None: 63 | configuration = matrix.geometry 64 | 65 | matrix = matrix._csr 66 | mode = "atoms" 67 | 68 | geometry = configuration 69 | if isinstance(geometry, BasisConfiguration): 70 | geometry = geometry.to_sisl_geometry() 71 | 72 | if isinstance(matrix, sisl.SparseCSR): 73 | matrix = matrix.tocsr() 74 | 75 | if issparse(matrix): 76 | matrix = matrix.toarray() 77 | matrix[matrix == 0] = np.nan 78 | 79 | matrix = np.array(matrix) 80 | 81 | color_midpoint = None 82 | if np.sum(matrix < 0) > 0 and np.sum(matrix > 0) > 0: 83 | color_midpoint = 0 84 | 85 | fig = px.imshow( 86 | matrix, 87 | color_continuous_midpoint=color_midpoint, 88 | color_continuous_scale=colorscale, 89 | text_auto=text is True, 90 | ) 91 | 92 | if point_lines is not False and mode == "orbitals": 93 | if point_lines is True: 94 | point_lines = {} 95 | 96 | point_lines = {"color": "orange", **point_lines} 97 | 98 | for atom_last_o in geometry.lasto[:-1]: 99 | line_pos = atom_last_o + 0.5 100 | fig.add_hline( 101 | y=line_pos, 102 | line=point_lines, 103 | ) 104 | 105 | for i_s in range(geometry.n_s): 106 | fig.add_vline(x=line_pos + (i_s * geometry.no), line=point_lines) 107 | 108 | if basis_lines is not False and mode == "orbitals": 109 | if basis_lines is True: 110 | basis_lines = {} 111 | 112 | basis_lines = {"color": "black", "dash": "dot", **basis_lines} 113 | 114 | atom_irreps = [get_atom_irreps(atom) for atom in geometry.atoms.atom] 115 | 116 | curr_l = 0 117 | for atom_specie, atom_last_o in zip(geometry.atoms.species, geometry.lasto): 118 | irreps = atom_irreps[atom_specie] 119 | 120 | for ir in irreps: 121 | m = ir[0] 122 | l = ir[1].l 123 | for _ in range(m): 124 | curr_l += 2 * l + 1 125 | 126 | if curr_l == atom_last_o + 1: 127 | continue 128 | 129 | line_pos = curr_l - 0.5 130 | 131 | fig.add_hline( 132 | y=line_pos, 133 | line=basis_lines, 134 | ) 135 | 136 | for i_s in range(geometry.n_s): 137 | fig.add_vline( 138 | x=line_pos + (i_s * geometry.no), line=basis_lines 139 | ) 140 | 141 | if sc_lines is not False: 142 | if sc_lines is True: 143 | sc_lines = {} 144 | 145 | sc_lines = {"color": "black", **sc_lines} 146 | sc_len = geometry.no if mode == "orbitals" else geometry.na 147 | 148 | for i_s in range(1, geometry.n_s): 149 | fig.add_vline(x=(i_s * sc_len) - 0.5, line=sc_lines, name=i_s) 150 | 151 | if 
isinstance(text, str): 152 | fig.update_traces( 153 | texttemplate="%{z:" + text + "}", selector={"type": "heatmap"} 154 | ) 155 | 156 | if basis_labels: 157 | atoms_ticks = [] 158 | atoms = geometry.atoms.atom 159 | for atom in atoms: 160 | atom_ticks = [] 161 | atoms_ticks.append(atom_ticks) 162 | for orb in atom.orbitals: 163 | atom_ticks.append(f"({orb.l}, {orb.m})") 164 | 165 | ticks = [] 166 | for i, specie in enumerate(geometry.atoms.species): 167 | ticks.extend([f"{i}: {orb}" for orb in atoms_ticks[specie]]) 168 | 169 | fig.update_layout( 170 | yaxis_ticktext=ticks, 171 | yaxis_tickvals=np.arange(geometry.no), 172 | xaxis_ticktext=ticks, 173 | xaxis_tickvals=np.arange(geometry.no), 174 | ) 175 | 176 | return fig 177 | --------------------------------------------------------------------------------
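The docstring of ``plot_basis_matrix`` above describes the options; as a quick orientation, here is a minimal usage sketch (an editorial illustration, not a file from the repository). It assumes only standard ``sisl`` functionality; the graphene system and tight-binding parameters are arbitrary toy values.

import sisl

from graph2mat.tools.viz import plot_basis_matrix

# Build a toy graphene Hamiltonian: one orbital per atom, an on-site
# term of 0 eV and a nearest-neighbour hopping of -2.7 eV.
geometry = sisl.geom.graphene()
H = sisl.Hamiltonian(geometry)
H.construct([(0.1, 1.5), (0.0, -2.7)])

# The sisl matrix carries its own geometry, so no separate
# `configuration` argument is needed.
fig = plot_basis_matrix(
    H,
    point_lines=True,  # lines separating atoms (default orange style)
    sc_lines={"color": "black", "dash": "dash"},  # custom supercell separators
    text=".2f",  # print each element with two decimal places
)
fig.show()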