├── .github └── workflows │ ├── docs.yaml │ ├── pre-commit.yaml │ ├── requirements_cuda_ci.txt │ └── verify_extension_build.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── MANIFEST.in ├── README.md ├── docs ├── api.rst ├── conf.py ├── index.rst ├── installation.rst ├── livebuild.sh ├── models.rst ├── supported_ops.rst └── tests_and_benchmarks.rst ├── io ├── cif_to_graph.py └── load_nequip_configs.py ├── openequivariance ├── __init__.py ├── benchmark │ ├── ConvBenchmarkSuite.py │ ├── TestBenchmarkSuite.py │ ├── benchmark_utils.py │ ├── correctness_utils.py │ ├── logging_utils.py │ ├── perf_metrics_utils.py │ ├── plotting │ │ ├── __init__.py │ │ ├── plot_convolution.py │ │ ├── plot_double_backward.py │ │ ├── plot_roofline.py │ │ ├── plot_uvu.py │ │ ├── plot_uvw.py │ │ └── plotting_utils.py │ ├── problems.py │ ├── random_buffer_utils.py │ └── tpp_creation_utils.py ├── extension │ ├── CMakeLists.txt │ ├── convolution.hpp │ ├── generic_module.cpp │ ├── group_mm_cuda.hpp │ ├── group_mm_hip.hpp │ ├── libtorch_tp_jit.cpp │ ├── tensorproducts.hpp │ ├── test │ │ ├── CMakeLists.txt │ │ └── load_jitscript.cpp │ └── util │ │ ├── backend_cuda.hpp │ │ ├── backend_hip.hpp │ │ └── buffer.hpp ├── extlib │ ├── .empty │ └── __init__.py ├── implementations │ ├── CUETensorProduct.py │ ├── ComputationSchedule.py │ ├── E3NNTensorProduct.py │ ├── LoopUnrollTP.py │ ├── MultiplicityOuterProductTP.py │ ├── TensorProduct.py │ ├── TensorProductBase.py │ ├── convolution │ │ ├── CUEConv.py │ │ ├── ConvolutionBase.py │ │ ├── E3NNConv.py │ │ ├── LoopUnrollConv.py │ │ ├── TensorProductConv.py │ │ └── scatter.py │ ├── e3nn_lite.py │ ├── symmetric_contraction │ │ ├── __init__.py │ │ └── symmetric_contraction.py │ └── utils.py └── templates │ ├── common.cuh │ ├── jinja_utils.py │ ├── loop_unroll_batch.cuh │ ├── loop_unroll_conv_atomic.cuh │ ├── loop_unroll_conv_det.cuh │ ├── loop_unroll_tp.cuh │ ├── macros.jinja │ ├── subkernel_per_interaction_multirep.cuh │ └── wmm.cuh ├── pyproject.toml └── tests ├── batch_test.py ├── benchmark.py ├── conv_test.py ├── export_test.py ├── import_test.py ├── mace_driver.py └── multidevice_test.py /.github/workflows/docs.yaml: -------------------------------------------------------------------------------- 1 | name: Deploy documentation to Github Pages 2 | on: 3 | workflow_dispatch: 4 | 5 | permissions: write-all 6 | 7 | concurrency: 8 | group: "pages" 9 | cancel-in-progress: false 10 | 11 | jobs: 12 | build: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - name: Checkout 16 | uses: actions/checkout@v4 17 | - name: Setup Pages 18 | uses: actions/configure-pages@v3 19 | - name: Set up Python 3.10 20 | uses: actions/setup-python@v5 21 | with: 22 | python-version: "3.10" 23 | - name: Install dependencies 24 | run: | 25 | python -m pip install --upgrade pip 26 | pip install sphinx furo 27 | - name: Build website 28 | run: | 29 | sphinx-build -M dirhtml docs docs/_build 30 | 31 | - name: Fix permissions 32 | run: | 33 | chmod -c -R +rX "docs/_build/dirhtml/" | while read line; do 34 | echo "::warning title=Invalid file permissions automatically fixed::$line" 35 | done 36 | - name: Upload artifact 37 | uses: actions/upload-pages-artifact@v3 38 | with: 39 | path: './docs/_build/dirhtml' 40 | deploy: 41 | environment: 42 | name: github-pages 43 | url: ${{ steps.deployment.outputs.page_url }} 44 | runs-on: ubuntu-latest 45 | needs: build 46 | steps: 47 | - name: Deploy to GitHub Pages 48 | id: deployment 49 | uses: actions/deploy-pages@v4 
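To preview the documentation locally, you can mirror the workflow's build step (a sketch of the same two commands the CI runs; ``docs/livebuild.sh`` in this repository provides an auto-reloading variant):

```bash
# Install the same dependencies the CI uses, then build the site.
pip install sphinx furo
sphinx-build -M dirhtml docs docs/_build
# The rendered pages land in docs/_build/dirhtml.
```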
-------------------------------------------------------------------------------- /.github/workflows/pre-commit.yaml: -------------------------------------------------------------------------------- 1 | name: Pre-Commit Checks 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [main] 7 | 8 | jobs: 9 | pre-commit: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v4 13 | - uses: pre-commit/action@v3.0.1 -------------------------------------------------------------------------------- /.github/workflows/requirements_cuda_ci.txt: -------------------------------------------------------------------------------- 1 | numpy==2.2.5 2 | torch==2.7.0 --index-url https://download.pytorch.org/whl/cu128 3 | pytest==8.3.5 4 | ninja==1.11.1.4 5 | ruff==0.11.11 6 | pre-commit==4.2.0 -------------------------------------------------------------------------------- /.github/workflows/verify_extension_build.yml: -------------------------------------------------------------------------------- 1 | name: OEQ CUDA C++ Extension Build Verification 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | branches: [ "main" ] 8 | types: [ labeled ] 9 | 10 | permissions: 11 | contents: read 12 | 13 | jobs: 14 | verify_cuda_extension: 15 | if: ${{ github.event.label.name == 'ci-ready' || github.event_name != 'pull_request' }} 16 | runs-on: ubuntu-latest 17 | 18 | steps: 19 | - uses: actions/checkout@v4 20 | - name: Set up Python 21 | uses: actions/setup-python@v5 22 | with: 23 | python-version: "3.10" 24 | cache: 'pip' 25 | cache-dependency-path: '**/requirements_cuda_ci.txt' 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | sudo apt-get update 30 | sudo apt install nvidia-cuda-toolkit 31 | pip install -r .github/workflows/requirements_cuda_ci.txt 32 | pip install -e . 33 | 34 | - name: Test extension build via import 35 | run: | 36 | pytest tests/import_test.py -k test_import -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .rendered* 2 | *.lock 3 | *.so 4 | *.ipynb 5 | .DS_Store 6 | *.code-workspace 7 | __pycache__ 8 | 9 | # working folders 10 | build 11 | outputs/* 12 | visualization/* 13 | figures/* 14 | data/* 15 | experimental/* 16 | MACE_models/* 17 | triton_autotuning/* 18 | sandbox/* 19 | docs/_build 20 | docs/_static 21 | docs/_templates 22 | 23 | # bash scripts 24 | get_gpu_node.sh 25 | env.sh 26 | 27 | nvidia-mathdx* 28 | .vscode/* 29 | *.ncu-rep 30 | mace_dev 31 | mace_oeq_integration 32 | valid_indices* 33 | 34 | *.xyz 35 | scratch.txt 36 | triton_autotuning 37 | paper_benchmarks 38 | paper_benchmarks_v2 39 | paper_benchmarks_v3 40 | openequivariance/extlib/*.so 41 | 42 | get_node.sh 43 | *.egg-info -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | # Ruff version. 4 | rev: v0.11.11 5 | hooks: 6 | # Run the linter. 7 | - id: ruff-check 8 | # Run the formatter. 
9 | - id: ruff-format -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2025, The Regents of the University of California, through Lawrence Berkeley National Laboratory (subject to receipt of any required approvals from the U.S. Dept. of Energy). All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include openequivariance/extlib/*.so 2 | include openequivariance/extlib/*.empty 3 | 4 | include openequivariance/templates/*.cuh 5 | include openequivariance/templates/*.jinja 6 | 7 | include openequivariance/extension/* 8 | include openequivariance/extension/convolution/* 9 | include openequivariance/extension/tensorproducts/* 10 | include openequivariance/extension/util/* -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OpenEquivariance 2 | [![OEQ CUDA C++ Extension Build Verification](https://github.com/PASSIONLab/OpenEquivariance/actions/workflows/verify_extension_build.yml/badge.svg?event=push)](https://github.com/PASSIONLab/OpenEquivariance/actions/workflows/verify_extension_build.yml) 3 | [![License](https://img.shields.io/badge/License-BSD_3--Clause-blue.svg)](https://opensource.org/licenses/BSD-3-Clause) 4 | 5 | [[Examples]](#show-me-some-examples) 6 | [[Citation and Acknowledgements]](#citation-and-acknowledgements) 7 | 8 | OpenEquivariance is a CUDA and HIP kernel generator for the Clebsch-Gordon tensor product, 9 | a key kernel in rotation-equivariant deep neural networks. 10 | It implements some of the tensor products 11 | that [e3nn](https://e3nn.org/) supports, 12 | commonly found in graph neural networks 13 | (e.g. 
[Nequip](https://github.com/mir-group/nequip) or 14 | [MACE](https://github.com/ACEsuit/mace)). To get 15 | started, ensure that you have GCC 9+ on your system 16 | and install our package via 17 | 18 | ```bash 19 | pip install git+https://github.com/PASSIONLab/OpenEquivariance 20 | ``` 21 | 22 | We provide up to an order of magnitude acceleration over e3nn and perform on par with the latest 23 | version of [NVIDIA cuEquivariance](https://github.com/NVIDIA/cuEquivariance), 24 | which has a closed-source kernel package. 25 | We also offer fused equivariant graph 26 | convolutions that can reduce 27 | computation and memory consumption significantly. 28 | 29 | For detailed instructions on tests, benchmarks, MACE / Nequip, and our API, 30 | check out the [documentation](https://passionlab.github.io/OpenEquivariance). 31 | 32 | 📣 📣 OpenEquivariance was accepted to the 2025 SIAM Conference on Applied and 33 | Computational Discrete Algorithms (Proceedings Track)! Catch the talk in 34 | Montréal and check out the [camera-ready copy on Arxiv](https://arxiv.org/abs/2501.13986) (available May 12, 2025). 35 | 36 | ## Show me some examples 37 | Here's a CG tensor product implemented by e3nn: 38 | 39 | ```python 40 | import torch 41 | import e3nn.o3 as o3 42 | 43 | gen = torch.Generator(device='cuda') 44 | 45 | batch_size = 1000 46 | X_ir, Y_ir, Z_ir = o3.Irreps("1x2e"), o3.Irreps("1x3e"), o3.Irreps("1x2e") 47 | X = torch.rand(batch_size, X_ir.dim, device='cuda', generator=gen) 48 | Y = torch.rand(batch_size, Y_ir.dim, device='cuda', generator=gen) 49 | 50 | instructions=[(0, 0, 0, "uvu", True)] 51 | 52 | tp_e3nn = o3.TensorProduct(X_ir, Y_ir, Z_ir, instructions, 53 | shared_weights=False, internal_weights=False).to('cuda') 54 | W = torch.rand(batch_size, tp_e3nn.weight_numel, device='cuda', generator=gen) 55 | 56 | Z = tp_e3nn(X, Y, W) 57 | print(torch.norm(Z)) 58 | ``` 59 | 60 | And here's the same tensor product using OpenEquivariance. We require that your 61 | tensors are stored on a CUDA device for this to work: 62 | 63 | ```python 64 | import openequivariance as oeq 65 | 66 | problem = oeq.TPProblem(X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False) 67 | tp_fast = oeq.TensorProduct(problem, torch_op=True) 68 | 69 | Z = tp_fast(X, Y, W) # Reuse X, Y, W from earlier 70 | print(torch.norm(Z)) 71 | ``` 72 | 73 | Our interface for `oeq.TPProblem` is almost a strict superset of 74 | `o3.TensorProduct` (two key differences: we 75 | impose `internal_weights=False` and add support for multiple datatypes). 76 | You can pass e3nn `Irreps` instances directly or 77 | use `oeq.Irreps`, which is identical. 78 | 79 | We recommend reading the [e3nn documentation and API reference](https://docs.e3nn.org/en/latest/) first, then using our kernels 80 | as drop-in replacements. We support most "uvu" and "uvw" tensor products; 81 | see [this section](#tensor-products-we-accelerate) for an up-to-date list of supported configurations. 82 | 83 | **Important**: For many configurations, our kernels return results identical to 84 | e3nn up to floating point roundoff (this includes all "uvu" problems with 85 | multiplicity 1 for all irreps in the second input). For other configurations 86 | (e.g. any "uvw" connection modes), we return identical 87 | results up to a well-defined reordering of the weights relative to e3nn. 
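To see the equivalence concretely, you can compare both outputs directly (a quick sketch reusing `tp_e3nn`, `tp_fast`, and the tensors above; the tolerance is our suggestion, not part of the API):

```python
# For this "uvu" problem, every irrep in the second input has multiplicity 1,
# so the two results should match up to floating point roundoff.
Z_ref = tp_e3nn(X, Y, W)
Z_fast = tp_fast(X, Y, W)
print(torch.allclose(Z_ref, Z_fast, atol=1e-5))  # expect True
```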
88 | 89 | If you're executing tensor products as part of a message passing graph 90 | neural network, we offer fused kernels that save both memory and compute time: 91 | 92 | ```python 93 | from torch_geometric import EdgeIndex 94 | 95 | node_ct, nonzero_ct = 3, 4 96 | 97 | # Receiver, sender indices for message passing GNN 98 | edge_index = EdgeIndex( 99 | [[0, 1, 1, 2], # Receiver 100 | [1, 0, 2, 1]], # Sender 101 | device='cuda', 102 | dtype=torch.long) 103 | 104 | X = torch.rand(node_ct, X_ir.dim, device='cuda', generator=gen) 105 | Y = torch.rand(nonzero_ct, Y_ir.dim, device='cuda', generator=gen) 106 | W = torch.rand(nonzero_ct, problem.weight_numel, device='cuda', generator=gen) 107 | 108 | tp_conv = oeq.TensorProductConv(problem, torch_op=True, deterministic=False) # Reuse problem from earlier 109 | Z = tp_conv.forward(X, Y, W, edge_index[0], edge_index[1]) # Z has shape [node_ct, Z_ir.dim] 110 | print(torch.norm(Z)) 111 | ``` 112 | 113 | If you can guarantee `EdgeIndex` is sorted by receiver index and supply the transpose 114 | permutation, we can provide even greater speedup (and deterministic results) 115 | by avoiding atomics: 116 | 117 | ```python 118 | _, sender_perm = edge_index.sort_by("col") # Sort by sender index 119 | edge_index, receiver_perm = edge_index.sort_by("row") # Sort by receiver index 120 | 121 | # Now we can use the faster deterministic algorithm 122 | tp_conv = oeq.TensorProductConv(problem, torch_op=True, deterministic=True) 123 | Z = tp_conv.forward(X, Y[receiver_perm], W[receiver_perm], edge_index[0], edge_index[1], sender_perm) 124 | print(torch.norm(Z)) 125 | ``` 126 | **Note**: you don't need PyTorch Geometric to use our kernels. When 127 | `deterministic=False`, the `sender` and `receiver` indices can have 128 | arbitrary order. 129 | 130 | ## Citation and Acknowledgements 131 | If you find this code useful, please cite our paper: 132 | 133 | ```bibtex 134 | @inbook{openequivariance, 135 | author={Vivek Bharadwaj and Austin Glover and Aydin Buluc and James Demmel}, 136 | title={An Efficient Sparse Kernel Generator for O(3)-Equivariant Deep Networks}, 137 | booktitle = {SIAM Conference on Applied and Computational Discrete Algorithms (ACDA25)}, 138 | chapter = {}, 139 | url={https://arxiv.org/abs/2501.13986}, 140 | publisher={Society for Industrial and Applied Mathematics}, 141 | year={2025} 142 | } 143 | ``` 144 | 145 | Our codebase includes a lightweight clone of 146 | [e3nn](https://e3nn.org/)'s frontend interface (in particular, the 147 | `TensorProduct` and `Irreps` classes). We removed references to PyTorch 148 | and separated the implementation from the problem description (for future 149 | frontend support outside of torch). We also extracted the Wigner 3j tensor generating code from QuTiP. Thank you to the current 150 | developers and maintainers! 151 | 152 | ## Copyright 153 | 154 | Copyright (c) 2025, The Regents of the University of California, through Lawrence Berkeley National Laboratory (subject to receipt of any required approvals from the U.S. Dept. of Energy). All rights reserved. 155 | 156 | If you have questions about your rights to use or distribute this software, please contact Berkeley Lab's Intellectual Property Office at IPO@lbl.gov. 157 | 158 | NOTICE. This Software was developed under funding from the U.S. Department of Energy and the U.S. Government consequently retains certain rights. As such, the U.S. 
Government has been granted for itself and others acting on its behalf a paid-up, nonexclusive, irrevocable, worldwide license in the Software to reproduce, distribute copies to the public, prepare derivative works, and perform publicly and display publicly, and to permit others to do so. -------------------------------------------------------------------------------- /docs/api.rst: -------------------------------------------------------------------------------- 1 | OpenEquivariance API 2 | ============================== 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | :caption: Contents: 7 | 8 | OpenEquivariance exposes two key classes: :py:class:`openequivariance.TensorProduct`, which replaces 9 | ``o3.TensorProduct`` from e3nn, and :py:class:`openequivariance.TensorProductConv`, which fuses 10 | the CG tensor product with a subsequent graph convolution. Initializing either class triggers 11 | JIT compilation of a custom kernel, which can take a few seconds. 12 | 13 | Both classes require a configuration object specified 14 | by :py:class:`openequivariance.TPProblem`, which has a constructor 15 | almost identical to ``o3.TensorProduct``. 16 | We recommend reading the `e3nn documentation <https://docs.e3nn.org/en/latest/>`_ before 17 | trying our code. OpenEquivariance cannot accelerate all tensor products; see 18 | :doc:`this page <supported_ops>` for a list of supported configurations. 19 | 20 | .. autoclass:: openequivariance.TensorProduct 21 | :members: 22 | :undoc-members: 23 | :exclude-members: name 24 | 25 | .. autoclass:: openequivariance.TensorProductConv 26 | :members: 27 | :undoc-members: 28 | :exclude-members: name 29 | 30 | .. autoclass:: openequivariance.TPProblem 31 | :members: 32 | :undoc-members: 33 | 34 | .. autofunction:: openequivariance.torch_to_oeq_dtype 35 | 36 | .. autofunction:: openequivariance.torch_ext_so_path 37 | 38 | API Identical to e3nn 39 | --------------------- 40 | 41 | These remaining API members are identical to the corresponding 42 | objects in ``e3nn.o3``. You can freely mix these objects from 43 | both packages. 44 | 45 | .. autoclass:: openequivariance.Irreps 46 | :members: 47 | :undoc-members: -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | # Configuration file for the Sphinx documentation builder. 5 | # 6 | # For the full list of built-in configuration values, see the documentation: 7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 8 | 9 | # -- Project information ----------------------------------------------------- 10 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 11 | 12 | project = "OpenEquivariance" 13 | copyright = "2025, The Regents of the University of California, through Lawrence Berkeley National Laboratory." 
14 | author = "Vivek Bharadwaj, Austin Glover, Aydin Buluc, James Demmel" 15 | 16 | # -- General configuration --------------------------------------------------- 17 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 18 | 19 | extensions = [] 20 | 21 | templates_path = ["_templates"] 22 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 23 | 24 | 25 | # -- Options for HTML output ------------------------------------------------- 26 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 27 | 28 | html_theme = "furo" 29 | # html_static_path = ["_static"] 30 | 31 | extensions = [ 32 | "sphinx.ext.autodoc", 33 | ] 34 | 35 | sys.path.insert(0, str(Path("..").resolve())) 36 | 37 | autodoc_mock_imports = ["torch", "openequivariance.extlib", "jinja2", "numpy"] 38 | autodoc_typehints = "description" 39 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. OpenEquivariance documentation master file, created by 2 | sphinx-quickstart on Tue Jun 3 00:20:54 2025. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | OpenEquivariance 7 | ============================== 8 | 9 | `OpenEquivariance `_ is a CUDA and 10 | HIP kernel generator for the Clebsch-Gordon 11 | tensor product, a key kernel in equivariant graph neural networks. We offer 12 | an identical interface to e3nn and produce the same results 13 | (up to numerical roundoff). Our package exhibits up to an order of magnitude 14 | speedup over e3nn and competitive performance with NVIDIA's cuEquivariance. 15 | 16 | Here, you can find our API reference, installation instructions, 17 | and troubleshooting guide. We support for both NVIDIA and AMD GPUs through 18 | our PyTorch interface, including support for JITScript compilation accessible 19 | from C++. 20 | 21 | .. toctree:: 22 | :maxdepth: 1 23 | :caption: Contents: 24 | 25 | installation 26 | api 27 | supported_ops 28 | tests_and_benchmarks 29 | models 30 | -------------------------------------------------------------------------------- /docs/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============================== 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | :caption: Contents: 7 | 8 | You need the following to install OpenEquivariance: 9 | 10 | - A Linux system equipped with an NVIDIA / AMD graphics card. 11 | - PyTorch >= 2.4 (>= 2.8 for AOTI and export). 12 | - GCC 9+ and the CUDA / HIP toolkit. The command 13 | ``c++ --version`` should return >= 9.0; see below for details on 14 | setting an alternate compiler. 15 | 16 | Installation is one easy command, followed by import verification: 17 | 18 | .. code-block:: bash 19 | 20 | pip install git+https://github.com/PASSIONLab/OpenEquivariance 21 | python -c "import openequivariance" 22 | 23 | The second line triggers a build of the C++ extension we use to compile 24 | kernels, which can take a couple of minutes. Subsequent imports are 25 | much faster since this extension is cached. 26 | 27 | 28 | Compiling the Integrated PyTorch Extension 29 | ------------------------------------------ 30 | To support ``torch.compile``, ``torch.export``, and 31 | JITScript, OpenEquivariance needs to compile a C++ extension 32 | tightly integrated with PyTorch. 
If you see a warning that 33 | this extension could not be compiled, first check: 34 | 35 | .. code-block:: bash 36 | 37 | c++ --version 38 | 39 | To build the extension with an alternate compiler, set the 40 | ``CC`` and ``CXX`` 41 | environment variables and retry the import: 42 | 43 | .. code-block:: bash 44 | 45 | export CC=/path/to/your/gcc 46 | export CXX=/path/to/your/g++ 47 | python -c "import openequivariance" 48 | 49 | These configuration steps are required only ONCE after 50 | installation (or upgrade) with pip. 51 | 52 | Configurations on Major Platforms 53 | --------------------------------- 54 | OpenEquivariance has been tested on both supercomputers and lab clusters. 55 | Here are some tested environment configuration files. If you use OpenEquivariance 56 | on a widely-used platform, send us a pull request to add your configuration! 57 | 58 | NERSC Perlmutter (NVIDIA A100) 59 | """""""""""""""""""""""""""""" 60 | 61 | .. code-block:: bash 62 | :caption: env.sh (last updated June 2025) 63 | 64 | module load gcc 65 | module load conda 66 | 67 | # Deactivate any base environments 68 | for i in $(seq ${CONDA_SHLVL}); do 69 | conda deactivate 70 | done 71 | 72 | conda activate 73 | 74 | 75 | OLCF Frontier (AMD MI250x) 76 | """""""""""""""""""""""""" 77 | You need to install a HIP-enabled version of PyTorch to use our package. 78 | To do this, follow the steps `here `_. 79 | 80 | 81 | .. code-block:: bash 82 | :caption: env.sh (last updated June 2025) 83 | 84 | module load PrgEnv-gnu/8.6.0 85 | module load miniforge3/23.11.0-0 86 | module load rocm/6.4.0 87 | module load craype-accel-amd-gfx90a 88 | 89 | for i in $(seq ${CONDA_SHLVL}); do 90 | conda deactivate 91 | done 92 | 93 | conda activate 94 | export CC=cc 95 | export CXX=CC -------------------------------------------------------------------------------- /docs/livebuild.sh: -------------------------------------------------------------------------------- 1 | sphinx-autobuild -M dirhtml . _build --watch ../openequivariance -------------------------------------------------------------------------------- /docs/models.rst: -------------------------------------------------------------------------------- 1 | Running MACE and Nequip 2 | ============================== 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | :caption: Contents: 7 | 8 | MACE 9 | ---- 10 | 11 | We have modified MACE to use our accelerated kernels instead 12 | of the standard e3nn backend. Here are the steps to replicate 13 | our MACE benchmark: 14 | 15 | 1. Install ``oeq`` and our modified version of MACE via 16 | 17 | .. code-block:: bash 18 | 19 | pip uninstall mace-torch 20 | pip install git+https://github.com/vbharadwaj-bk/mace_oeq_integration.git@oeq_experimental 21 | 22 | 2. Download the ``carbon.xyz`` data file, available at 23 | ``_. 24 | 25 | This graph has 158K edges. With the original e3nn backend, you would need a GPU with 80GB 26 | of memory to run the experiments. ``oeq`` provides a memory-efficient equivariant convolution, 27 | so we expect the test to succeed. 28 | 29 | 3. Benchmark OpenEquivariance: 30 | 31 | .. code-block:: bash 32 | 33 | python tests/mace_driver.py carbon.xyz -o outputs/mace_tests -i oeq 34 | 35 | 4. If you have a GPU with 80GB of memory *or* supply a smaller molecular graph 36 | as the input file, you can run the full benchmark that includes ``e3nn`` and ``cue``: 37 | 38 | .. 
code-block:: bash 39 | 40 | python tests/mace_driver.py carbon.xyz -o outputs/mace_tests -i e3nn cue oeq 41 | 42 | Nequip 43 | ------ 44 | See the 45 | `official Nequip documentation `_ 46 | to use OpenEquivariance with Nequip. 47 | -------------------------------------------------------------------------------- /docs/supported_ops.rst: -------------------------------------------------------------------------------- 1 | Supported Operations 2 | ==================== 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | :caption: Contents: 7 | 8 | .. list-table:: 9 | :widths: 50 25 25 10 | :header-rows: 1 11 | 12 | * - Operation 13 | - CUDA 14 | - HIP 15 | * - UVU 16 | - ✅ 17 | - ✅ 18 | * - UVW 19 | - ✅ 20 | - ✅ 21 | * - UVU + Convolution 22 | - ✅ 23 | - ✅ 24 | * - UVW + Convolution 25 | - ✅ 26 | - ✅ 27 | * - Symmetric Tensor Product 28 | - ✅ (beta) 29 | - ✅ (beta) 30 | 31 | e3nn supports a variety of connection modes for CG tensor products. We support 32 | two that are commonly used in equivariant graph neural networks: 33 | "uvu" and "uvw". Our JIT compiled kernels should handle: 34 | 35 | 1. Pure "uvu" tensor products, which are most efficient when the input with higher 36 | multiplicities is the first argument. Our results are identical to e3nn when irreps in 37 | the second input have multiplicity 1, and otherwise identical up to a reordering 38 | of the input weights. 39 | 40 | 2. Pure "uvw" tensor products, which are currently more efficient when the input with 41 | higher multiplicities is the first argument. Our results are identical to e3nn up to a reordering 42 | of the input weights. 43 | 44 | Our code includes correctness checks, but the configuration space is large. If you notice 45 | a bug, let us know in a GitHub issue. We'll try our best to correct it or document the problem here. 46 | 47 | Unsupported Tensor Product Configurations 48 | ----------------------------------------- 49 | 50 | We do not (yet) support: 51 | 52 | - Mixing different instruction types in the same tensor product. 53 | - Instruction types besides "uvu" and "uvw". 54 | - Non-trainable instructions: all of your instructions must have weights associated. 55 | 56 | If you have a use case for any of the unsupported features above, let us know. 57 | 58 | Compilation with JITScript, Export, and AOTInductor 59 | --------------------------------------------------- 60 | 61 | OpenEquivariance supports model compilation with 62 | ``torch.compile``, JITScript, ``torch.export``, and AOTInductor. 63 | Demo the C++ model exports with 64 | 65 | .. code-block:: bash 66 | 67 | pytest tests/export_test.py 68 | 69 | 70 | NOTE: the AOTInductor test (and possibly export) fail 71 | unless you are using a Nightly 72 | build of PyTorch past 4/10/2025 due to incomplete support for 73 | TorchBind in earlier versions. 74 | 75 | 76 | Multiple Devices and Streams 77 | ---------------------------- 78 | OpenEquivariance compiles kernels based on the compute capability of the 79 | first visible GPU. On heterogeneous systems, our kernels 80 | will only execute correctly on devices that share the compute capability 81 | of this first device. 82 | 83 | We are working on support for CUDA streams! 84 | 85 | 86 | Symmetric Contraction (Beta) 87 | ---------------------------- 88 | 89 | We have recently added beta support for symmetric 90 | contraction acceleration. This primitive: 91 | 92 | - Is specific to MACE 93 | - Requires e3nn as a dependency. 
94 | - Currently has no support for compile / export 95 | 96 | As a result, we do not expose it in the package 97 | toplevel. You can use our implementation by running 98 | 99 | .. code-block:: python 100 | 101 | from openequivariance.implementations.symmetric_contraction import SymmetricContraction as OEQSymmetricContraction -------------------------------------------------------------------------------- /docs/tests_and_benchmarks.rst: -------------------------------------------------------------------------------- 1 | Tests and Benchmarks 2 | ============================== 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | :caption: Contents: 7 | 8 | OpenEquivariance is equipped with a comprehensive suite of tests 9 | and benchmarking utilities. You'll need some additional dependencies to run 10 | these; we provide instructions below. 11 | 12 | We recommend you clone our repository and use an editable install to run tests 13 | and benchmarks. You can still test our code with a non-editable install; just 14 | download the test folder and install only the dependencies with: 15 | 16 | .. code-block:: bash 17 | 18 | pip install "https://github.com/PASSIONLab/OpenEquivariance[dev]" --only-deps 19 | pip install "https://github.com/PASSIONLab/OpenEquivariance[bench]" --only-deps 20 | 21 | Correctness 22 | ------------------------------ 23 | To set up the editable install and run the entire testsuite, use: 24 | 25 | .. code-block:: bash 26 | 27 | git clone https://github.com/PASSIONLab/OpenEquivariance 28 | cd OpenEquivariance 29 | pip install -e .[dev] 30 | pytest 31 | 32 | Browse the ``tests`` directory to run specific components. 33 | 34 | 35 | Replicating our Benchmarks 36 | ------------------------------ 37 | We conducted our benchmarks on an NVIDIA A100-SXM-80GB GPU at Lawrence Berkeley National Laboratory. 38 | Your results may differ on a different GPU. The following invocations run the experiments 39 | and generate plots from our paper. 40 | 41 | .. code-block:: bash 42 | 43 | git clone https://github.com/PASSIONLab/OpenEquivariance 44 | cd OpenEquivariance 45 | pip install -e .[bench] 46 | python tests/benchmark.py -o outputs/uvu uvu --plot 47 | python tests/benchmark.py -o outputs/uvw uvw --plot 48 | python tests/benchmark.py -o outputs/roofline roofline --plot 49 | python tests/benchmark.py -o outputs/conv conv --plot --data data/molecular_structures 50 | python tests/benchmark.py -o outputs/kahan_conv kahan_conv --data data/molecular_structures/ 51 | 52 | If your GPU has limited memory, try the ``--limited-memory`` flag 53 | to disable some expensive tests and / or reduce the batch size with ``-b``. 54 | Run ``python tests/benchmark.py --help`` for a full list of flags. 55 | 56 | For example, here's a set of invocations for an NVIDIA A5000 GPU: 57 | 58 | .. code-block:: bash 59 | 60 | python tests/benchmark.py -o outputs/uvu uvu --limited-memory --plot 61 | python tests/benchmark.py -o outputs/uvw uvw -b 25000 --plot 62 | python tests/benchmark.py -o outputs/roofline roofline --plot 63 | python tests/benchmark.py -o outputs/conv conv --data data/molecular_structures --limited-memory 64 | 65 | For GPUs besides the NVIDIA A100, the roofline slope / peak will be incorrect. 66 | The plots for the convolution fusion experiments also require a GPU 67 | with a minimum of 40GB of memory. 68 | 69 | List of GPUs Tested 70 | -------------------------------- 71 | OpenEquivariance has been tested successfully on the following GPUs. Submit a pull 72 | request if you'd like to add your own! 
73 | 74 | - NVIDIA A100-SXM-40GB and A100-SXM-80GB (A. Glover, NERSC Perlmutter, June 2025) 75 | - NVIDIA A5000 (V. Bharadwaj, UCB SLICE, June 2025) 76 | - NVIDIA V100 (V. Bharadwaj, LBNL Einsteinium, June 2025) 77 | - AMD MI250x (V. Bharadwaj, OLCF Frontier, June 2025) 78 | - AMD MI300x (V. Bharadwaj, AMD Cloud, February 2025) -------------------------------------------------------------------------------- /io/cif_to_graph.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | from sklearn.neighbors import radius_neighbors_graph 4 | 5 | 6 | def cif_to_molecular_graph(cif_file, cp, radii): 7 | with open(f"../data/cif_files/{cif_file}", "r") as f: 8 | print("Started reading file...") 9 | lines = f.readlines() 10 | print("Finished reading file!") 11 | 12 | coords = [] 13 | for line in lines: 14 | if line.startswith("ATOM"): 15 | parts = line.split() 16 | coords.append( 17 | [float(parts[cp[0]]), float(parts[cp[1]]), float(parts[cp[2]])] 18 | ) 19 | 20 | coords = np.array(coords) 21 | 22 | for radius in radii: 23 | print(f"Starting radius neighbors calculation, r={radius}") 24 | A = radius_neighbors_graph( 25 | coords, radius, mode="connectivity", include_self=False 26 | ) 27 | print(f"Finished radius neighbors calculation, found {A.nnz} nonzeros.") 28 | 29 | # mmwrite(f'../data/molecular_structures/{cif_file.split(".")[0]}.mtx', A) 30 | 31 | coo_mat = A.tocoo() 32 | result = {"row": coo_mat.row, "col": coo_mat.col, "coords": coords} 33 | 34 | with open( 35 | f"../data/molecular_structures/{cif_file.split('.')[0]}_radius{radius}.pickle", 36 | "wb", 37 | ) as handle: 38 | pickle.dump(result, handle, protocol=pickle.HIGHEST_PROTOCOL) 39 | 40 | 41 | if __name__ == "__main__": 42 | # cif_to_molecular_graph('hiv_capsid.cif', (10, 11, 12), radii=[2.0, 2.5, 3.0, 3.5]) 43 | cif_to_molecular_graph("covid_spike.cif", (10, 11, 12), radii=[2.0, 2.5, 3.0, 3.5]) 44 | cif_to_molecular_graph("1drf.cif", (10, 11, 12), radii=[6.0]) 45 | -------------------------------------------------------------------------------- /io/load_nequip_configs.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script parses the repository of 3 | Nequip input files at 4 | https://github.com/mir-group/nequip-input-files. 5 | We extract the node / edge hidden feature representations. 
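Each config is expected to carry the two keys read below; an illustrative (hypothetical) fragment:

    feature_irreps_hidden: 32x0o + 32x0e + 32x1o + 32x1e
    irreps_edge_sh: 0e + 1o + 2e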
6 | """ 7 | 8 | import os 9 | import yaml 10 | 11 | 12 | def process_nequip_configs(): 13 | nequip_files = [] 14 | for root, dirs, files in os.walk("../data/nequip-input-files"): 15 | for file in files: 16 | if file.endswith(".yaml"): 17 | nequip_files.append(os.path.join(root, file)) 18 | 19 | irrep_pairs = [] 20 | configs = [] 21 | for file in nequip_files: 22 | with open(file, "r") as f: 23 | data = yaml.unsafe_load(f) 24 | filename = os.path.splitext(os.path.basename(file))[0] 25 | feature_irreps_hidden = data["feature_irreps_hidden"] 26 | irreps_edge_sh = data["irreps_edge_sh"] 27 | if (feature_irreps_hidden, irreps_edge_sh) not in irrep_pairs: 28 | irrep_pairs.append((feature_irreps_hidden, irreps_edge_sh)) 29 | configs.append((feature_irreps_hidden, irreps_edge_sh, filename)) 30 | 31 | for config in configs: 32 | print(config) 33 | 34 | 35 | if __name__ == "__main__": 36 | process_nequip_configs() 37 | -------------------------------------------------------------------------------- /openequivariance/__init__.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: F401 2 | import sys 3 | import openequivariance.extlib 4 | from pathlib import Path 5 | from importlib.metadata import version 6 | 7 | from openequivariance.implementations.e3nn_lite import TPProblem, Irreps 8 | from openequivariance.implementations.TensorProduct import TensorProduct 9 | from openequivariance.implementations.convolution.TensorProductConv import ( 10 | TensorProductConv, 11 | ) 12 | from openequivariance.implementations.utils import torch_to_oeq_dtype 13 | 14 | __version__ = None 15 | try: 16 | __version__ = version("openequivariance") 17 | except Exception as e: 18 | print(f"Warning: Could not determine oeq version: {e}", file=sys.stderr) 19 | 20 | 21 | def _check_package_editable(): 22 | import json 23 | from importlib.metadata import Distribution 24 | 25 | direct_url = Distribution.from_name("openequivariance").read_text("direct_url.json") 26 | return json.loads(direct_url).get("dir_info", {}).get("editable", False) 27 | 28 | 29 | _editable_install_output_path = Path(__file__).parent.parent / "outputs" 30 | 31 | 32 | def torch_ext_so_path(): 33 | """ 34 | :returns: Path to a ``.so`` file that must be linked to use OpenEquivariance 35 | from the PyTorch C++ Interface. 
36 | """ 37 | return openequivariance.extlib.torch_module.__file__ 38 | 39 | 40 | __all__ = [ 41 | "TPProblem", 42 | "Irreps", 43 | "TensorProduct", 44 | "TensorProductConv", 45 | "torch_to_oeq_dtype", 46 | "_check_package_editable", 47 | "torch_ext_so_path", 48 | ] 49 | -------------------------------------------------------------------------------- /openequivariance/benchmark/ConvBenchmarkSuite.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import time 4 | import pickle 5 | import pathlib 6 | import numpy as np 7 | 8 | import openequivariance as oeq 9 | from openequivariance.benchmark.logging_utils import getLogger 10 | from openequivariance.implementations.convolution.ConvolutionBase import CoordGraph 11 | 12 | logger = getLogger() 13 | 14 | 15 | def load_graph(filename): 16 | coords, rows, cols = [None] * 3 17 | name = pathlib.Path(filename).stem 18 | with open(filename, "rb") as f: 19 | logger.info(f"Loading {name} from pickle...") 20 | result = pickle.load(f) 21 | coords, rows, cols, name = result["coords"], result["row"], result["col"], name 22 | logger.info( 23 | f"Graph {name} loaded with {len(coords)} nodes and {len(rows)} edges." 24 | ) 25 | 26 | return CoordGraph(coords, rows.astype(np.int64), cols.astype(np.int64), name) 27 | 28 | 29 | class ConvBenchmarkSuite: 30 | def __init__( 31 | self, 32 | configs, 33 | num_warmup=10, 34 | num_iter=30, 35 | reference_impl=None, 36 | test_name=None, 37 | prng_seed=12345, 38 | correctness_threshold=1e-5, 39 | ): 40 | self.configs = configs 41 | self.num_warmup = num_warmup 42 | self.num_iter = num_iter 43 | self.reference_impl = reference_impl 44 | self.prng_seed = 12345 45 | self.correctness_threshold = correctness_threshold 46 | self.exp_count = 0 47 | self.test_name = test_name 48 | 49 | self.millis_since_epoch = round(time.time() * 1000) 50 | 51 | def run( 52 | self, 53 | graph, 54 | implementations, 55 | direction, 56 | output_folder=None, 57 | correctness=True, 58 | benchmark=True, 59 | high_precision_ref=False, 60 | ): 61 | if output_folder is None: 62 | if oeq._check_package_editable(): 63 | output_folder = ( 64 | oeq._editable_install_output_path / f"{self.millis_since_epoch}" 65 | ) 66 | else: 67 | raise ValueError( 68 | "output folder must be specified for non-editable installs." 
69 | ) 70 | else: 71 | output_folder = pathlib.Path(output_folder) 72 | output_folder.mkdir(parents=True, exist_ok=True) 73 | 74 | metadata = { 75 | "test_name": self.test_name, 76 | "configs": [str(config) for config in self.configs], 77 | "implementations": [impl.name() for impl in implementations], 78 | "graph": graph.name, 79 | } 80 | if self.exp_count == 0: 81 | with open(os.path.join(output_folder, "metadata.json"), "w") as f: 82 | json.dump(metadata, f, indent=2) 83 | 84 | for config in self.configs: 85 | for impl in implementations: 86 | tc_name = f"{config}, {impl.name()}" 87 | logger.info(f"Starting {tc_name}, graph {graph.name}, {direction}") 88 | conv = impl(config) 89 | 90 | if direction == "forward": 91 | if correctness: 92 | correctness = conv.test_correctness_forward( 93 | graph, 94 | thresh=self.correctness_threshold, 95 | prng_seed=self.prng_seed, 96 | reference_implementation=self.reference_impl, 97 | high_precision_ref=high_precision_ref, 98 | ) 99 | 100 | if benchmark: 101 | benchmark = conv.benchmark_forward( 102 | self.num_warmup, self.num_iter, graph, prng_seed=self.prng_seed 103 | ) 104 | 105 | if direction == "backward": 106 | if correctness: 107 | correctness = conv.test_correctness_backward( 108 | graph, 109 | thresh=self.correctness_threshold, 110 | prng_seed=self.prng_seed, 111 | reference_implementation=self.reference_impl, 112 | high_precision_ref=high_precision_ref, 113 | ) 114 | 115 | if benchmark: 116 | benchmark = conv.benchmark_backward( 117 | self.num_warmup, self.num_iter, graph, prng_seed=self.prng_seed 118 | ) 119 | 120 | if direction == "double_backward": 121 | if correctness: 122 | correctness = conv.test_correctness_double_backward( 123 | graph, 124 | thresh=self.correctness_threshold, 125 | prng_seed=self.prng_seed, 126 | reference_implementation=self.reference_impl, 127 | high_precision_ref=high_precision_ref, 128 | ) 129 | 130 | assert not benchmark 131 | 132 | result = { 133 | "config": str(config), 134 | "irrep_dtype": str(config.irrep_dtype), 135 | "weight_dtype": str(config.weight_dtype), 136 | "torch_overhead_included": conv.torch_op, 137 | "direction": direction, 138 | "graph": graph.name, 139 | "name": impl.name(), 140 | "correctness": correctness, 141 | "benchmark": benchmark, 142 | } 143 | 144 | fname = pathlib.Path( 145 | f"{output_folder}/{self.exp_count}_{impl.name()}_{graph.name}.json" 146 | ) 147 | with open(fname, "w") as f: 148 | json.dump(result, f, indent=2) 149 | self.exp_count += 1 150 | 151 | logger.info(f"Finished {tc_name}, graph {graph.name}") 152 | 153 | return output_folder 154 | -------------------------------------------------------------------------------- /openequivariance/benchmark/correctness_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | from openequivariance.implementations.TensorProductBase import TensorProductBase 4 | from openequivariance.implementations.CUETensorProduct import CUETensorProduct 5 | from openequivariance.implementations.e3nn_lite import TPProblem 6 | from openequivariance.benchmark.random_buffer_utils import ( 7 | get_random_buffers_forward, 8 | get_random_buffers_backward, 9 | ) 10 | from openequivariance.benchmark.logging_utils import getLogger, bcolors 11 | import numpy as np 12 | import numpy.linalg as la 13 | 14 | logger = getLogger() 15 | 16 | 17 | def check_similiarity( 18 | name: str, 19 | to_check: np.ndarray, 20 | ground_truth: np.ndarray, 21 | correctness_threshold: float, 22 | ): 23 | result = {} 24 | 
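# A shape mismatch means the two implementations disagree structurally; record an infinite error norm and fail the check outright.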
if to_check.shape != ground_truth.shape: 25 | result["shape_match"] = False 26 | result["diff_Linf_norm"] = np.inf 27 | result["pass"] = False 28 | logger.error( 29 | f"{bcolors.FAIL}Ground truth {name} shape does not match input! {to_check.shape=}, {ground_truth.shape=} {bcolors.ENDC}" 30 | ) 31 | else: 32 | result["shape_match"] = True 33 | diff_Linf_norm = float(la.norm((ground_truth - to_check).flatten(), ord=np.inf)) 34 | result["diff_Linf_norm"] = diff_Linf_norm 35 | result["pass"] = bool(diff_Linf_norm < correctness_threshold) 36 | if result["pass"]: 37 | logger.info( 38 | f" {bcolors.OKGREEN}{name} correctness check pass. {diff_Linf_norm=:.3e}, {correctness_threshold=} {bcolors.ENDC}" 39 | ) 40 | else: 41 | logger.error( 42 | f"{bcolors.FAIL}{name} correctness check fail! {diff_Linf_norm=:.3e}, {correctness_threshold=} {bcolors.ENDC}" 43 | ) 44 | 45 | return result 46 | 47 | 48 | def instantiate_implementation( 49 | implementation: Union[type[TensorProductBase], TensorProductBase], 50 | problem: TPProblem, 51 | ): 52 | if isinstance(implementation, type): 53 | test_tp = implementation(problem) 54 | else: 55 | test_tp = implementation 56 | 57 | if not isinstance(test_tp, TensorProductBase): 58 | raise TypeError( 59 | f"test_implementation must be a TensorProductBase or a subclass, got {type(implementation)}" 60 | ) 61 | 62 | return test_tp 63 | 64 | 65 | def correctness_forward( 66 | problem: TPProblem, 67 | test_implementation: Union[type[TensorProductBase], TensorProductBase], 68 | reference_implementation: Optional[type[TensorProductBase]], 69 | batch_size: int, 70 | correctness_threshold: float, 71 | prng_seed: int, 72 | ) -> dict: 73 | if reference_implementation is None: 74 | from openequivariance.implementations.E3NNTensorProduct import E3NNTensorProduct 75 | 76 | reference_implementation = E3NNTensorProduct 77 | 78 | result = {"thresh": correctness_threshold, "batch_size": batch_size} 79 | 80 | in1, in2, weights, out = get_random_buffers_forward(problem, batch_size, prng_seed) 81 | 82 | # run reference 83 | ref_tp = reference_implementation(problem) 84 | 85 | ref_out = out.copy() 86 | ref_tp.forward_cpu( 87 | L1_in=in1.copy(), L2_in=in2.copy(), L3_out=ref_out, weights=weights.copy() 88 | ) 89 | 90 | weights_copy = weights.copy() 91 | if problem.shared_weights and test_implementation == CUETensorProduct: 92 | weights_copy = weights[np.newaxis, :] 93 | 94 | # run test 95 | test_tp = instantiate_implementation(test_implementation, problem) 96 | test_out = out.copy() 97 | test_tp.forward_cpu( 98 | L1_in=in1.copy(), L2_in=in2.copy(), L3_out=test_out, weights=weights_copy 99 | ) 100 | 101 | for name, to_check, ground_truth in [("output", ref_out, test_out)]: 102 | result[name] = check_similiarity( 103 | name, to_check, ground_truth, correctness_threshold 104 | ) 105 | 106 | return result 107 | 108 | 109 | def correctness_backward( 110 | problem: TPProblem, 111 | test_implementation: Union[type[TensorProductBase], TensorProductBase], 112 | reference_implementation: Optional[type[TensorProductBase]], 113 | batch_size: int, 114 | correctness_threshold: float, 115 | prng_seed: int, 116 | ) -> dict: 117 | if reference_implementation is None: 118 | from openequivariance.implementations.E3NNTensorProduct import E3NNTensorProduct 119 | 120 | reference_implementation = E3NNTensorProduct 121 | 122 | result = {"thresh": correctness_threshold, "batch_size": batch_size} 123 | 124 | # run reference 125 | in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad = ( 126 | 
get_random_buffers_backward(problem, batch_size, prng_seed) 127 | ) 128 | 129 | ref_tp = reference_implementation(problem) 130 | 131 | ref_weights_grad = weights_grad.copy() 132 | ref_in1_grad = in1_grad.copy() 133 | ref_in2_grad = in2_grad.copy() 134 | 135 | ref_tp.backward_cpu( 136 | L1_in=in1.copy(), 137 | L1_grad=ref_in1_grad, 138 | L2_in=in2.copy(), 139 | L2_grad=ref_in2_grad, 140 | L3_grad=out_grad.copy(), 141 | weights=weights.copy(), 142 | weights_grad=ref_weights_grad, 143 | ) 144 | 145 | # run test version 146 | test_weights_grad = weights_grad.copy() 147 | test_in1_grad = in1_grad.copy() 148 | test_in2_grad = in2_grad.copy() 149 | 150 | weights_copy = weights.copy() 151 | 152 | if problem.shared_weights and test_implementation == CUETensorProduct: 153 | weights_copy = weights[np.newaxis, :] 154 | test_weights_grad = test_weights_grad[np.newaxis, :] 155 | 156 | test_tp = instantiate_implementation(test_implementation, problem) 157 | test_tp.backward_cpu( 158 | L1_in=in1.copy(), 159 | L1_grad=test_in1_grad, 160 | L2_in=in2.copy(), 161 | L2_grad=test_in2_grad, 162 | L3_grad=out_grad.copy(), 163 | weights=weights_copy, 164 | weights_grad=test_weights_grad, 165 | ) 166 | 167 | weight_threshold = ( 168 | correctness_threshold * batch_size 169 | if problem.shared_weights 170 | else correctness_threshold 171 | ) 172 | 173 | if problem.shared_weights: 174 | test_weights_grad = test_weights_grad.squeeze() 175 | 176 | for name, to_check, ground_truth, threshold in [ 177 | ("weight_grad", test_weights_grad, ref_weights_grad, weight_threshold), 178 | ("in1_grad", test_in1_grad, ref_in1_grad, correctness_threshold), 179 | ("in2_grad", test_in2_grad, ref_in2_grad, correctness_threshold), 180 | ]: 181 | result[name] = check_similiarity(name, to_check, ground_truth, threshold) 182 | 183 | return result 184 | 185 | 186 | def correctness_double_backward( 187 | problem: TPProblem, 188 | test_implementation: Union[type[TensorProductBase], TensorProductBase], 189 | reference_implementation: Optional[type[TensorProductBase]], 190 | batch_size: int, 191 | correctness_threshold: float, 192 | prng_seed: int, 193 | ): 194 | global torch 195 | import torch 196 | 197 | in1, in2, out_grad, weights, _, _, _ = get_random_buffers_backward( 198 | problem, batch_size, prng_seed 199 | ) 200 | rng = np.random.default_rng(seed=prng_seed * 2) 201 | dummy_grad = rng.standard_normal(1)[0] 202 | 203 | if reference_implementation is None: 204 | from openequivariance.implementations.E3NNTensorProduct import E3NNTensorProduct 205 | 206 | reference_implementation = E3NNTensorProduct 207 | 208 | result = {"thresh": correctness_threshold, "batch_size": batch_size} 209 | 210 | tensors = [] 211 | for i, impl in enumerate([test_implementation, reference_implementation]): 212 | tp = instantiate_implementation(impl, problem) 213 | 214 | if impl == CUETensorProduct and problem.shared_weights: 215 | weights = weights[np.newaxis, :] 216 | 217 | weights_reordered = np.zeros_like(weights) 218 | if tp.reorder_weights_e3nn_to_oeq is not None: 219 | tp.reorder_weights_e3nn_to_oeq( 220 | weights, weights_reordered, not tp.config.shared_weights 221 | ) 222 | else: 223 | weights_reordered = weights 224 | 225 | in1_torch = torch.tensor(in1, device="cuda", requires_grad=True) 226 | in2_torch = torch.tensor(in2, device="cuda", requires_grad=True) 227 | weights_torch = torch.tensor( 228 | weights_reordered, device="cuda", requires_grad=True 229 | ) 230 | 231 | out_torch = tp.forward(in1_torch, in2_torch, weights_torch) 232 | out_grad = 
out_torch.clone().detach().to(device="cuda").requires_grad_(True) 233 | 234 | in1_grad, in2_grad, w_grad = torch.autograd.grad( 235 | outputs=[out_torch], 236 | inputs=[in1_torch, in2_torch, weights_torch], 237 | grad_outputs=[out_grad], 238 | create_graph=True, 239 | ) 240 | 241 | dummy = torch.norm(in1_grad) + torch.norm(in2_grad) + torch.norm(w_grad) 242 | dummy_grad = torch.tensor(float(dummy_grad), device="cuda", requires_grad=True) 243 | 244 | dummy.backward( 245 | dummy_grad, 246 | retain_graph=True, 247 | inputs=[out_grad, in1_torch, in2_torch, weights_torch], 248 | ) 249 | 250 | weights_grad = weights_torch.grad.detach().cpu().numpy() 251 | if tp.reorder_weights_oeq_to_e3nn is not None: 252 | weights_grad_copy = weights_grad.copy() 253 | tp.reorder_weights_oeq_to_e3nn( 254 | weights_grad_copy, weights_grad, not tp.config.shared_weights 255 | ) 256 | 257 | tensors.append( 258 | ( 259 | out_grad.grad.detach().cpu().numpy(), 260 | in1_torch.grad.detach().cpu().numpy(), 261 | in2_torch.grad.detach().cpu().numpy(), 262 | weights_grad, 263 | ) 264 | ) 265 | 266 | for name, to_check, ground_truth in [ 267 | ("output_double_grad", tensors[0][0], tensors[1][0]), 268 | ("in1_grad", tensors[0][1], tensors[1][1]), 269 | ("in2_grad", tensors[0][2], tensors[1][2]), 270 | ("weights_grad", tensors[0][3], tensors[1][3]), 271 | ]: 272 | result[name] = check_similiarity( 273 | name, to_check, ground_truth, correctness_threshold 274 | ) 275 | 276 | return result 277 | -------------------------------------------------------------------------------- /openequivariance/benchmark/logging_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logger = logging.getLogger("ETP") 4 | logger.setLevel(logging.CRITICAL) 5 | ch = logging.StreamHandler() 6 | formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") 7 | ch.setFormatter(formatter) 8 | logger.addHandler(ch) 9 | 10 | 11 | def getLogger(): 12 | return logger 13 | 14 | 15 | class bcolors: 16 | HEADER = "\033[95m" 17 | OKBLUE = "\033[94m" 18 | OKCYAN = "\033[96m" 19 | OKGREEN = "\033[92m" 20 | WARNING = "\033[93m" 21 | FAIL = "\033[91m" 22 | ENDC = "\033[0m" 23 | BOLD = "\033[1m" 24 | UNDERLINE = "\033[4m" 25 | -------------------------------------------------------------------------------- /openequivariance/benchmark/perf_metrics_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from openequivariance.implementations.utils import ( 4 | count_cg_non_zero, 5 | sparse_outer_product_work, 6 | ) 7 | from openequivariance.implementations.TensorProductBase import TensorProductBase 8 | from openequivariance.implementations.e3nn_lite import TPProblem 9 | from openequivariance.benchmark.logging_utils import getLogger 10 | import numpy as np 11 | 12 | logger = getLogger() 13 | 14 | 15 | def calculate_minimum_memory_streamed_forward( 16 | tpp: TPProblem, batch_size: int 17 | ) -> dict[str, int]: 18 | """ 19 | This represents an absolute minimum amount of memory that could be streamed on an ideal machine 20 | It returns the number of bytes streamed total and from each source 21 | """ 22 | data_size = {} 23 | irrep_word_size = np.dtype(tpp.irrep_dtype).itemsize 24 | weight_word_size = np.dtype(tpp.weight_dtype).itemsize 25 | 26 | data_size["input 1"] = tpp.irreps_in1.dim * batch_size * irrep_word_size 27 | data_size["input 2"] = tpp.irreps_in2.dim * batch_size * irrep_word_size 28 | data_size["output"] = tpp.irreps_out.dim * 
batch_size * irrep_word_size 29 | data_size["weights"] = tpp.weight_numel * batch_size * weight_word_size 30 | data_size["total"] = sum(data_size.values()) 31 | return data_size 32 | 33 | 34 | def calculate_minimum_memory_streamed_backward(tpp: TPProblem, batch_size: int) -> dict: 35 | """ 36 | This represents an absolute minimum amount of memory that could be streamed on an ideal machine 37 | It returns the number of bytes streamed total and from each source 38 | """ 39 | data_size = {} 40 | irrep_word_size = np.dtype(tpp.irrep_dtype).itemsize 41 | weight_word_size = np.dtype(tpp.weight_dtype).itemsize 42 | 43 | data_size["input 1"] = tpp.irreps_in1.dim * batch_size * irrep_word_size 44 | data_size["input 1 grad"] = tpp.irreps_in1.dim * batch_size * irrep_word_size 45 | data_size["input 2"] = tpp.irreps_in2.dim * batch_size * irrep_word_size 46 | data_size["input 2 grad"] = tpp.irreps_in2.dim * batch_size * irrep_word_size 47 | data_size["output grad"] = tpp.irreps_out.dim * batch_size * irrep_word_size 48 | data_size["weights"] = tpp.weight_numel * batch_size * weight_word_size 49 | data_size["weights grad"] = tpp.weight_numel * batch_size * weight_word_size 50 | data_size["total"] = sum(data_size.values()) 51 | return data_size 52 | 53 | 54 | def calculate_minimum_flops_forward(tpp: TPProblem, batch_size: int) -> dict: 55 | """ 56 | This is not actually calculating the minimum value. 57 | Ideally you might share the outer product values between two inputs across multiple inputs. 58 | This is assuming that you form those values and reuse them once per CG decomp. 59 | """ 60 | logger.warning("Minimum FLOPs calculation is not the true minimum") 61 | flops_count = {} 62 | flops_count["outer_products"] = 0 63 | flops_count["CG_decomposition"] = 0 64 | flops_count["linear_combination"] = 0 65 | for ins in tpp.instructions: # type : Instruction 66 | l1, l2, l3 = ( 67 | tpp.irreps_in1[ins.i_in1].ir.l, 68 | tpp.irreps_in2[ins.i_in2].ir.l, 69 | tpp.irreps_out[ins.i_out].ir.l, 70 | ) 71 | 72 | flops_count["outer_products"] += sparse_outer_product_work( 73 | TensorProductBase.load_cg_tensor(l1, l2, l3) 74 | ) 75 | flops_count["CG_decomposition"] += count_cg_non_zero(l1, l2, l3) * ( 76 | ins.path_shape[0] * ins.path_shape[1] 77 | ) 78 | flops_count["linear_combination"] += ( 79 | (2 * l3 + 1) * math.prod(ins.path_shape) if ins.has_weight else 0 80 | ) 81 | 82 | flops_count["outer_products"] *= batch_size 83 | flops_count["CG_decomposition"] *= 2 * batch_size 84 | flops_count["linear_combination"] *= 2 * batch_size 85 | 86 | flops_count["total"] = sum(flops_count.values()) 87 | return flops_count 88 | 89 | 90 | def calculate_minimum_flops_backward(tpp: TPProblem, batch_size: int) -> dict: 91 | """ 92 | This is not actually calculating the minimum value. 93 | Ideally you might share the outer product values between two inputs across multiple inputs. 94 | This is assuming that you form those values and reuse them once per CG decomp. 
95 | """ 96 | raise NotImplementedError("this needs to be implemented properly") 97 | -------------------------------------------------------------------------------- /openequivariance/benchmark/plotting/__init__.py: -------------------------------------------------------------------------------- 1 | from openequivariance.benchmark.plotting.plot_uvu import plot_uvu 2 | from openequivariance.benchmark.plotting.plot_uvw import plot_uvw 3 | from openequivariance.benchmark.plotting.plot_roofline import plot_roofline 4 | from openequivariance.benchmark.plotting.plot_convolution import plot_convolution 5 | from openequivariance.benchmark.plotting.plot_double_backward import ( 6 | plot_double_backward, 7 | ) 8 | 9 | __all__ = [ 10 | "plot_uvu", 11 | "plot_uvw", 12 | "plot_roofline", 13 | "plot_convolution", 14 | "plot_double_backward", 15 | ] 16 | -------------------------------------------------------------------------------- /openequivariance/benchmark/plotting/plot_convolution.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import pathlib 4 | from openequivariance.benchmark.plotting.plotting_utils import ( 5 | set_grid, 6 | colormap, 7 | labelmap, 8 | hatchmap, 9 | dtypes, 10 | directions, 11 | dtype_labelmap, 12 | grouped_barchart, 13 | load_benchmarks, 14 | ) 15 | 16 | 17 | def plot_convolution(data_folder): 18 | data_folder = pathlib.Path(data_folder) 19 | benchmarks, metadata = load_benchmarks(data_folder) 20 | 21 | implementations = [ 22 | "CUEConvolution", 23 | "CUEConvolutionFused", 24 | "LoopUnrollConvScatterSum", 25 | "LoopUnrollConvAtomic", 26 | "LoopUnrollConvDeterministic", 27 | ] 28 | 29 | graphs = ["1drf_radius6.0", "covid_spike_radius3.0", "carbon_lattice_radius6.0"] 30 | graph_lmap = { 31 | "covid_spike_radius3.0": "COVID spike", 32 | "1drf_radius6.0": "DHFR", 33 | "carbon_lattice_radius6.0": "carbon-lattice", 34 | } 35 | 36 | data = {} 37 | 38 | for direction in directions: 39 | data[direction] = {} 40 | for dtype in dtypes: 41 | data[direction][dtype] = {} 42 | for graph in graphs: 43 | data[direction][dtype][graph_lmap[graph]] = {} 44 | for impl in implementations: 45 | exp = filter( 46 | benchmarks, 47 | { 48 | "graph": graph, 49 | "direction": direction, 50 | "name": impl, 51 | "irrep_dtype": dtype, 52 | }, 53 | match_one=True, 54 | ) 55 | 56 | data[direction][dtype][graph_lmap[graph]][labelmap[impl]] = np.mean( 57 | exp["benchmark"]["time_millis"] 58 | ) 59 | 60 | fig = plt.figure(figsize=(5, 5)) 61 | gs = fig.add_gridspec(2, 2, hspace=0, wspace=0) 62 | axes = gs.subplots(sharex="col", sharey="row") 63 | 64 | for i, direction in enumerate(directions): 65 | for j, dtype in enumerate(dtypes): 66 | for k, graph in enumerate(graphs): 67 | normalizing_value = data[direction][dtype][graph_lmap[graph]][ 68 | "cuE-scattersum" 69 | ] 70 | for impl in implementations: 71 | data[direction][dtype][graph_lmap[graph]][labelmap[impl]] = ( 72 | normalizing_value 73 | / data[direction][dtype][graph_lmap[graph]][labelmap[impl]] 74 | ) 75 | 76 | grouped_barchart( 77 | data[direction][dtype], 78 | axes[i][j], 79 | bar_height_fontsize=0, 80 | rotate_xlabels=True, 81 | colormap=colormap, 82 | hatchmap=hatchmap, 83 | group_spacing=6.0, 84 | ) 85 | 86 | axes[i][j].set_xlabel(dtype_labelmap[dtype]) 87 | axes[i][j].set_ylabel(direction) 88 | axes[i][j].axhline(1.0, ls="--", c=colormap["cuE"]) 89 | set_grid(axes[i][j]) 90 | 91 | axes[1][0].set_ylim(0, 3.8) 92 | for ax in fig.get_axes(): 93 | 
ax.label_outer() 94 | 95 | fig.supylabel("Speedup over cuE-scattersum", x=0.025, y=0.6) 96 | 97 | handles, labels = axes[0][0].get_legend_handles_labels() 98 | for i, l in enumerate(labels): 99 | if "fast" in l: 100 | labels[i] += " (ours)" 101 | 102 | unique = [ 103 | (h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i] 104 | ] 105 | fig.legend(*zip(*unique), loc="upper center", bbox_to_anchor=(0.55, 0.01)) 106 | 107 | fig.show() 108 | fig.tight_layout() 109 | fig.savefig(str(data_folder / "kernel_fusion_speedup.pdf"), bbox_inches="tight") 110 | -------------------------------------------------------------------------------- /openequivariance/benchmark/plotting/plot_double_backward.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import pathlib 4 | from openequivariance.benchmark.plotting.plotting_utils import ( 5 | set_grid, 6 | colormap, 7 | labelmap, 8 | grouped_barchart, 9 | load_benchmarks, 10 | ) 11 | 12 | 13 | def plot_double_backward(data_folder): 14 | data_folder = pathlib.Path(data_folder) 15 | benchmarks, metadata = load_benchmarks(data_folder) 16 | 17 | configs = metadata["config_labels"] 18 | implementations = ["E3NNTensorProduct", "CUETensorProduct", "LoopUnrollTP"] 19 | 20 | def calculate_tp_per_sec(exp): 21 | return exp["benchmark results"]["batch_size"] / ( 22 | np.mean(exp["benchmark results"]["time_millis"]) * 0.001 23 | ) 24 | 25 | dataf32 = {"double_backward": {}} 26 | for i, desc in enumerate(configs): 27 | for direction in ["double_backward"]: 28 | dataf32[direction][desc] = {} 29 | for impl in implementations: 30 | f32_benches = [ 31 | b 32 | for b in benchmarks 33 | if "float32" in b["benchmark results"]["rep_dtype"] 34 | ] 35 | exp = filter( 36 | f32_benches, 37 | { 38 | "config_label": desc, 39 | "direction": direction, 40 | "implementation_name": impl, 41 | }, 42 | match_one=True, 43 | ) 44 | dataf32[direction][desc][labelmap[impl]] = calculate_tp_per_sec(exp) 45 | 46 | dataf64 = {"double_backward": {}} 47 | for i, desc in enumerate(configs): 48 | for direction in ["double_backward"]: 49 | dataf64[direction][desc] = {} 50 | for impl in implementations: 51 | f64_benches = [ 52 | b 53 | for b in benchmarks 54 | if "float64" in b["benchmark results"]["rep_dtype"] 55 | ] 56 | 57 | exp = filter( 58 | f64_benches, 59 | { 60 | "config_label": desc, 61 | "direction": direction, 62 | "implementation_name": impl, 63 | }, 64 | match_one=True, 65 | ) 66 | 67 | if exp is None: 68 | print(desc) 69 | print(direction) 70 | print(impl) 71 | 72 | dataf64[direction][desc][labelmap[impl]] = calculate_tp_per_sec(exp) 73 | 74 | fig = plt.figure(figsize=(7, 3)) 75 | gs = fig.add_gridspec(1, 2, hspace=0, wspace=0.1) 76 | axs = gs.subplots(sharex="col", sharey="row") 77 | 78 | grouped_barchart( 79 | dataf32["double_backward"], 80 | axs[0], 81 | bar_height_fontsize=0, 82 | colormap=colormap, 83 | group_spacing=6.0, 84 | ) 85 | grouped_barchart( 86 | dataf64["double_backward"], 87 | axs[1], 88 | bar_height_fontsize=0, 89 | colormap=colormap, 90 | group_spacing=6.0, 91 | ) 92 | 93 | for i in range(2): 94 | set_grid(axs[i]) 95 | set_grid(axs[i]) 96 | 97 | axs[0].set_xlabel("float32") 98 | axs[1].set_xlabel("float64") 99 | 100 | handles, labels = axs[0].get_legend_handles_labels() 101 | unique = [ 102 | (h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i] 103 | ] 104 | axs[0].legend(*zip(*unique)) 105 | 106 | for ax in fig.get_axes(): 107 |
ax.label_outer() 108 | 109 | fig.supylabel("2nd Deriv. Throughput\n(# tensor products / s)", y=0.5) 110 | 111 | speedup_table = [] 112 | for direction in ["double_backward"]: 113 | for impl in ["e3nn", "cuE"]: 114 | for dtype_label, dtype_set in [("f32", dataf32), ("f64", dataf64)]: 115 | speedups = [ 116 | measurement["ours"] / measurement[impl] 117 | for _, measurement in dtype_set[direction].items() 118 | if impl in measurement 119 | ] 120 | stats = ( 121 | np.min(speedups), 122 | np.mean(speedups), 123 | np.median(speedups), 124 | np.max(speedups), 125 | ) 126 | stats = [f"{stat:.2f}" for stat in stats] 127 | 128 | dir_print = direction 129 | result = [dir_print, impl, dtype_label] + stats 130 | speedup_table.append(result) 131 | 132 | print("\t\t".join(["Direction", "Base", "dtype", "min", "mean", "med", "max"])) 133 | for row in speedup_table: 134 | print("\t\t".join(row)) 135 | 136 | fig.show() 137 | fig.tight_layout() 138 | fig.savefig( 139 | str(data_folder / "double_backward_throughput.pdf"), bbox_inches="tight" 140 | ) 141 | -------------------------------------------------------------------------------- /openequivariance/benchmark/plotting/plot_roofline.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pathlib 3 | from openequivariance.benchmark.plotting.plotting_utils import ( 4 | colormap, 5 | labelmap, 6 | load_benchmarks, 7 | roofline_plot, 8 | ) 9 | 10 | 11 | def plot_roofline(data_folder): 12 | data_folder = pathlib.Path(data_folder) 13 | benchmarks, metadata = load_benchmarks(data_folder) 14 | 15 | configs = metadata["config_labels"] 16 | implementations = ["LoopUnrollTP", "CUETensorProduct"] 17 | 18 | data = {"forward": {}, "backward": {}} 19 | for i, desc in enumerate(configs): 20 | for direction in ["forward", "backward"]: 21 | data[direction][desc] = {} 22 | for impl in implementations: 23 | exp = filter( 24 | benchmarks, 25 | { 26 | "config_label": desc, 27 | "direction": direction, 28 | "implementation_name": impl, 29 | }, 30 | match_one=True, 31 | ) 32 | data[direction][desc][labelmap[impl]] = ( 33 | exp["benchmark results"]["arithmetic_intensity (FLOPs / byte)"], 34 | np.mean(exp["benchmark results"]["throughputs_gflops"]), 35 | ) 36 | 37 | roofline_data = [] 38 | marker_map = { 39 | "forward-cuE": "+", 40 | "backward-cuE": "X", 41 | "forward-ours": "P", 42 | "backward-ours": "X", 43 | } 44 | for i, desc in enumerate(configs): 45 | for direction in ["forward", "backward"]: 46 | ai, throughput = data[direction][desc][labelmap["LoopUnrollTP"]] 47 | label = f"{direction}-ours" 48 | roofline_data.append( 49 | { 50 | "AI": float(ai), 51 | "throughput": throughput / 1000, 52 | "label": label, 53 | "marker": marker_map[label], 54 | "color": colormap["ours"], 55 | "markersize": 80, 56 | } 57 | ) 58 | 59 | label = f"{direction}-cuE" 60 | ai, throughput = data[direction][desc][labelmap["CUETensorProduct"]] 61 | roofline_data.append( 62 | { 63 | "AI": float(ai), 64 | "throughput": throughput / 1000, 65 | "label": label, 66 | "marker": marker_map[label], 67 | "color": colormap["cuE"], 68 | "markersize": 80, 69 | } 70 | ) 71 | 72 | cpu_roofs = {"A100-SXM-80GB FP32 Peak": 19.5} 73 | mem_bottlenecks = {"HBM2": 2.039} 74 | AI_v = {"": 9.56} 75 | 76 | draw_bounds = {"xmin": 0.4, "xmax": 15, "ymin": 0.15, "ymax": 25} 77 | fig, ax = roofline_plot( 78 | draw_bounds, 79 | cpu_roofs, 80 | mem_bottlenecks, 81 | AI_v, 82 | roofline_data, 83 | fig_ratio=1.8, 84 | fig_dimension=4, 85 | ) 86 | 87 | handles, labels = 
ax.get_legend_handles_labels() 88 | unique = [ 89 | (h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i] 90 | ] 91 | ax.legend(*zip(*unique)) 92 | 93 | fig.show() 94 | fig.savefig(str(data_folder / "roofline.pdf")) 95 | 96 | # Table of throughputs and arithmetic intensities 97 | 98 | header = r""" 99 | \begin{tabular}{cccccc} 100 | \toprule 101 | \multirow{2}{*}{ID} & \multirow{2}{*}{Description} & \multirow{2}{*}{Dir.} & \multirow{2}{*}{AI} & \multicolumn{2}{c}{TFLOP/s} \\ 102 | \cmidrule(r){5-6} 103 | & & & & \multicolumn{1}{l}{cuE} & ours \\ 104 | \midrule 105 | """ 106 | 107 | rows = [] 108 | 109 | dir_map = {"forward": "F", "backward": "B"} 110 | for i, desc in enumerate(sorted(configs)): 111 | for direction in ["forward", "backward"]: 112 | for impl in implementations: 113 | short_id, long_desc = desc.split("#") 114 | long_desc = long_desc.replace("->", "$\\rightarrow$").replace( 115 | " x ", r"$\ \times\ $" 116 | ) 117 | ai_ours, throughput_ours = data[direction][desc][ 118 | labelmap["LoopUnrollTP"] 119 | ] 120 | throughput_ours = f"{float(throughput_ours / 1000):.2f}" 121 | _, throughput_cue = data[direction][desc][labelmap["CUETensorProduct"]] 122 | throughput_cue = f"{float(throughput_cue / 1000):.2f}" 123 | 124 | result = [ 125 | r"\multirow{2}{*}{" + short_id + "}", 126 | r"\multirow{2}{*}{" + long_desc + "}", 127 | dir_map[direction], 128 | f"{ai_ours:.1f}", 129 | throughput_cue, 130 | throughput_ours, 131 | ] 132 | if direction == "backward": 133 | result[0] = "" 134 | result[1] = "" 135 | rows.append(result) 136 | 137 | print(header) 138 | result = "" 139 | for i, row in enumerate(rows): 140 | result += " & ".join(row) + r"\\" + "\n" 141 | if row[2] == "B" and i < len(rows) - 1: 142 | result += r"\cmidrule(r){3-6}" + "\n" 143 | 144 | print(result.replace("[", "").replace("]", "").replace("uvu", "B")) 145 | print("\\bottomrule\n\\end{tabular}") 146 | -------------------------------------------------------------------------------- /openequivariance/benchmark/plotting/plot_uvu.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import pathlib 4 | from openequivariance.benchmark.plotting.plotting_utils import ( 5 | set_grid, 6 | colormap, 7 | labelmap, 8 | grouped_barchart, 9 | load_benchmarks, 10 | ) 11 | 12 | 13 | def plot_uvu(data_folder): 14 | data_folder = pathlib.Path(data_folder) 15 | benchmarks, metadata = load_benchmarks(data_folder) 16 | configs = metadata["config_labels"] 17 | implementations = metadata["implementations"] 18 | 19 | for benchmark in benchmarks: 20 | if ( 21 | benchmark["implementation_name"] 22 | == "E3NNTensorProductCompiledMaxAutotuneCUDAGraphs" 23 | ): 24 | benchmark["implementation_name"] = "E3NNTensorProduct" 25 | 26 | for i, implementation in enumerate(implementations): 27 | if implementation == "E3NNTensorProductCompiledMaxAutotuneCUDAGraphs": 28 | implementations[i] = "E3NNTensorProduct" 29 | 30 | def calculate_tp_per_sec(exp): 31 | return exp["benchmark results"]["batch_size"] / ( 32 | np.mean(exp["benchmark results"]["time_millis"]) * 0.001 33 | ) 34 | 35 | dataf32 = {"forward": {}, "backward": {}} 36 | for i, desc in enumerate(configs): 37 | for direction in ["forward", "backward"]: 38 | dataf32[direction][desc] = {} 39 | for impl in implementations: 40 | f32_benches = [ 41 | b 42 | for b in benchmarks 43 | if "float32" in b["benchmark results"]["rep_dtype"] 44 | ] 45 | exp = filter( 46 | f32_benches, 47 | { 48 | "config_label":
desc, 49 | "direction": direction, 50 | "implementation_name": impl, 51 | }, 52 | match_one=True, 53 | ) 54 | if exp is not None: 55 | dataf32[direction][desc][labelmap[impl]] = calculate_tp_per_sec(exp) 56 | else: 57 | dataf32[direction][desc][labelmap[impl]] = 0.0 58 | 59 | dataf64 = {"forward": {}, "backward": {}} 60 | for i, desc in enumerate(configs): 61 | for direction in ["forward", "backward"]: 62 | dataf64[direction][desc] = {} 63 | for impl in implementations: 64 | f64_benches = [ 65 | b 66 | for b in benchmarks 67 | if "float64" in b["benchmark results"]["rep_dtype"] 68 | ] 69 | exp = filter( 70 | f64_benches, 71 | { 72 | "config_label": desc, 73 | "direction": direction, 74 | "implementation_name": impl, 75 | }, 76 | match_one=True, 77 | ) 78 | 79 | if exp is not None: 80 | dataf64[direction][desc][labelmap[impl]] = calculate_tp_per_sec(exp) 81 | else: 82 | dataf64[direction][desc][labelmap[impl]] = 0.0 83 | 84 | fig = plt.figure(figsize=(7, 7)) 85 | gs = fig.add_gridspec(2, 2) 86 | axs = gs.subplots(sharex=True) 87 | 88 | grouped_barchart( 89 | dataf32["forward"], 90 | axs[0][0], 91 | bar_height_fontsize=0, 92 | xticklabel=False, 93 | colormap=colormap, 94 | group_spacing=6.0, 95 | ) 96 | grouped_barchart( 97 | dataf32["backward"], 98 | axs[1][0], 99 | bar_height_fontsize=0, 100 | colormap=colormap, 101 | group_spacing=6.0, 102 | ) 103 | 104 | grouped_barchart( 105 | dataf64["forward"], 106 | axs[0][1], 107 | bar_height_fontsize=0, 108 | xticklabel=False, 109 | colormap=colormap, 110 | group_spacing=6.0, 111 | ) 112 | grouped_barchart( 113 | dataf64["backward"], 114 | axs[1][1], 115 | bar_height_fontsize=0, 116 | colormap=colormap, 117 | group_spacing=6.0, 118 | ) 119 | 120 | for i in range(2): 121 | for j in range(2): 122 | set_grid(axs[i][j]) 123 | set_grid(axs[i][j]) 124 | 125 | axs[0][0].set_ylabel("Forward") 126 | axs[1][0].set_ylabel("Backward") 127 | 128 | handles, labels = axs[0][0].get_legend_handles_labels() 129 | unique = [ 130 | (h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i] 131 | ] 132 | axs[0][0].legend(*zip(*unique)) 133 | 134 | fig.supylabel("Throughput (# tensor products / s)", x=0.036, y=0.605) 135 | 136 | axs[1][0].set_xlabel("float32") 137 | axs[1][1].set_xlabel("float64") 138 | 139 | fig.show() 140 | fig.tight_layout() 141 | fig.savefig(str(data_folder / "throughput_comparison.pdf")) 142 | 143 | speedup_table = [] 144 | for direction in ["forward", "backward"]: 145 | for impl in ["e3nn", "cuE"]: 146 | for dtype_label, dtype_set in [("f32", dataf32), ("f64", dataf64)]: 147 | speedups = [ 148 | measurement["ours"] / measurement[impl] 149 | for _, measurement in dtype_set[direction].items() 150 | if impl in measurement 151 | ] 152 | stats = ( 153 | np.min(speedups), 154 | np.mean(speedups), 155 | np.median(speedups), 156 | np.max(speedups), 157 | ) 158 | stats = [f"{stat:.2f}" for stat in stats] 159 | 160 | dir_print = direction 161 | if direction == "forward": 162 | dir_print += " " 163 | result = [dir_print, impl, dtype_label] + stats 164 | speedup_table.append(result) 165 | 166 | print("\t\t".join(["Direction", "Base", "dtype", "min", "mean", "med", "max"])) 167 | for row in speedup_table: 168 | print("\t\t".join(row)) 169 | -------------------------------------------------------------------------------- /openequivariance/benchmark/plotting/plot_uvw.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import pathlib 4 | from
openequivariance.benchmark.plotting.plotting_utils import ( 5 | set_grid, 6 | colormap, 7 | labelmap, 8 | grouped_barchart, 9 | calculate_tp_per_sec, 10 | load_benchmarks, 11 | ) 12 | 13 | 14 | def plot_uvw(data_folder): 15 | data_folder = pathlib.Path(data_folder) 16 | benchmarks, metadata = load_benchmarks(data_folder) 17 | 18 | configs = metadata["config_labels"] 19 | implementations = metadata["implementations"] 20 | metadata["directions"] 21 | 22 | dataf32 = {"forward": {}, "backward": {}} 23 | for i, desc in enumerate(configs): 24 | for direction in ["forward", "backward"]: 25 | dataf32[direction][desc] = {} 26 | for impl in implementations: 27 | if True: # direction == "forward" or impl != "CUETensorProduct" or 'mace' in desc: 28 | f32_benches = [ 29 | b 30 | for b in benchmarks 31 | if "float32" 32 | in b["benchmark results"]["rep_dtype"] 33 | ] 34 | exp = filter( 35 | f32_benches, 36 | { 37 | "config_label": desc, 38 | "direction": direction, 39 | "implementation_name": impl, 40 | }, 41 | match_one=True, 42 | ) 43 | dataf32[direction][desc][labelmap[impl]] = calculate_tp_per_sec(exp) 44 | 45 | dataf64 = {"forward": {}, "backward": {}} 46 | for i, desc in enumerate(configs): 47 | for direction in ["forward", "backward"]: 48 | dataf64[direction][desc] = {} 49 | for impl in implementations: 50 | if True: # direction == "forward" or impl != "CUETensorProduct" or 'mace' in desc: 51 | f64_benches = [ 52 | b 53 | for b in benchmarks 54 | if "float64" 55 | in b["benchmark results"]["rep_dtype"] 56 | ] 57 | exp = filter( 58 | f64_benches, 59 | { 60 | "config_label": desc, 61 | "direction": direction, 62 | "implementation_name": impl, 63 | }, 64 | match_one=True, 65 | ) 66 | dataf64[direction][desc][labelmap[impl]] = calculate_tp_per_sec(exp) 67 | 68 | plt.rcParams["font.family"] = "serif" 69 | plt.rcParams.update({"font.size": 11}) 70 | 71 | fig = plt.figure(figsize=(7, 7)) 72 | gs = fig.add_gridspec(2, 2) 73 | axs = gs.subplots(sharex=True, sharey="row") 74 | 75 | grouped_barchart( 76 | dataf32["forward"], 77 | axs[0][0], 78 | bar_height_fontsize=0, 79 | xticklabel=False, 80 | colormap=colormap, 81 | group_spacing=6.0, 82 | ) 83 | grouped_barchart( 84 | dataf32["backward"], 85 | axs[1][0], 86 | bar_height_fontsize=0, 87 | xticklabel=True, 88 | colormap=colormap, 89 | group_spacing=6.0, 90 | ) 91 | 92 | grouped_barchart( 93 | dataf64["forward"], 94 | axs[0][1], 95 | bar_height_fontsize=0, 96 | xticklabel=False, 97 | colormap=colormap, 98 | group_spacing=6.0, 99 | ) 100 | grouped_barchart( 101 | dataf64["backward"], 102 | axs[1][1], 103 | bar_height_fontsize=0, 104 | xticklabel=True, 105 | colormap=colormap, 106 | group_spacing=6.0, 107 | ) 108 | 109 | for i in range(2): 110 | for j in range(2): 111 | set_grid(axs[i][j]) 112 | 113 | fig.supylabel("Throughput (# tensor products / s)", x=0.03, y=0.56) 114 | 115 | axs[0][0].set_ylabel("Forward") 116 | axs[1][0].set_ylabel("Backward") 117 | 118 | axs[1][0].set_xlabel("float32") 119 | axs[1][1].set_xlabel("float64") 120 | 121 | handles, labels = axs[0][1].get_legend_handles_labels() 122 | unique = [ 123 | (h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i] 124 | ] 125 | axs[0][1].legend(*zip(*unique)) 126 | 127 | fig.show() 128 | fig.tight_layout() 129 | fig.savefig(str(data_folder / "uvw_throughput_comparison.pdf")) 130 | 131 | speedup_table = [] 132 | for direction in ["forward", "backward"]: 133 | for impl in ["e3nn", "cuE"]: 134 | for dtype_label, dtype_set in [("f32", dataf32), ("f64", dataf64)]: 135 | speedups = [ 136 |
measurement["ours"] / measurement[impl] 137 | for label, measurement in dtype_set[direction].items() 138 | if impl in measurement and "DiffDock" in label 139 | ] 140 | stats = ( 141 | np.min(speedups), 142 | np.mean(speedups), 143 | np.median(speedups), 144 | np.max(speedups), 145 | ) 146 | stats = [f"{stat:.2f}" for stat in stats] 147 | 148 | dir_print = direction 149 | if direction == "forward": 150 | dir_print += " " 151 | result = [dir_print, impl, dtype_label] + stats 152 | speedup_table.append(result) 153 | 154 | print("DiffDock") 155 | print("\t\t".join(["Direction", "Base", "dtype", "min", "mean", "med", "max"])) 156 | for row in speedup_table: 157 | print("\t\t".join(row)) 158 | -------------------------------------------------------------------------------- /openequivariance/benchmark/problems.py: -------------------------------------------------------------------------------- 1 | from openequivariance.benchmark.tpp_creation_utils import ( 2 | FullyConnectedTPProblem as FCTPP, 3 | ) 4 | from openequivariance.benchmark.tpp_creation_utils import ChannelwiseTPP as CTPP 5 | 6 | # source: https://github.com/e3nn/e3nn/blob/main/examples/tetris.py 7 | # running tetris will output the layers. I've only extracted the fully connected layers here. 8 | _e3nn_torch_tetris = [ 9 | # 0th Layer 10 | FCTPP("1x0e", "1x0e", "150x0e + 50x1o + 50x2e"), # sc 11 | FCTPP("1x0e", "1x0e", "1x0e"), # lin1 12 | FCTPP("1x0e + 1x1o + 1x2e", "1x0e", "150x0e + 50x1o + 50x2e"), # lin2 13 | FCTPP("1x0e + 1x1o + 1x2e", "1x0e", "1x0e"), # alpha 14 | # 1st Layer 15 | FCTPP( 16 | "50x0e + 50x1o + 50x2e", "1x0e", "250x0e + 50x1o + 50x1e + 50x2o + 50x2e" 17 | ), # sc 18 | FCTPP("50x0e + 50x1o + 50x2e", "1x0e", "50x0e + 50x1o + 50x2e"), # lin1 19 | # FCTPP("50x0e + 50x1o + 50x2e", "1x0e + 1x1o + 1x2e", "150x0e + 200x1o + 100x1e + 100x2o + 200x2e"), #tp 20 | FCTPP( 21 | "150x0e + 200x1o + 100x1e + 100x2o + 200x2e", 22 | "1x0e", 23 | "250x0e + 50x1o + 50x1e + 50x2o + 50x2e", 24 | ), # lin2 25 | FCTPP("150x0e + 200x1o + 100x1e + 100x2o + 200x2e", "1x0e", "1x0e"), # alpha 26 | # 2nd Layer 27 | FCTPP( 28 | "50x0e + 50x1o + 50x1e + 50x2o + 50x2e", 29 | "1x0e", 30 | "50x0o + 250x0e + 50x1o + 50x1e + 50x2o + 50x2e", 31 | ), # sc 32 | FCTPP( 33 | "50x0e + 50x1o + 50x1e + 50x2o + 50x2e", 34 | "1x0e", 35 | "50x0e + 50x1o + 50x1e + 50x2o + 50x2e", 36 | ), # lin1 37 | FCTPP( 38 | "100x0o + 150x0e + 300x1o + 250x1e + 250x2o + 300x2e", 39 | "1x0e", 40 | "50x0o + 250x0e + 50x1o + 50x1e + 50x2o + 50x2e", 41 | ), # lin2 42 | FCTPP( 43 | "100x0o + 150x0e + 300x1o + 250x1e + 250x2o + 300x2e", "1x0e", "1x0e" 44 | ), # alpha 45 | # 3rd Layer 46 | FCTPP("50x0o + 50x0e + 50x1o + 50x1e + 50x2o + 50x2e", "1x0e", "1x0o + 6x0e"), # sc 47 | FCTPP( 48 | "50x0o + 50x0e + 50x1o + 50x1e + 50x2o + 50x2e", 49 | "1x0e", 50 | "50x0o + 50x0e + 50x1o + 50x1e + 50x2o + 50x2e", 51 | ), # lin1 52 | FCTPP("150x0o + 150x0e", "1x0e", "1x0o + 6x0e"), # lin2 53 | FCTPP("150x0o + 150x0e", "1x0e", "1x0e"), # alpha 54 | ] 55 | 56 | 57 | def e3nn_torch_tetris_poly_problems(): 58 | # source: https://github.com/e3nn/e3nn/blob/f95297952303347a8a3cfe971efe449c710c43b2/examples/tetris_polynomial.py#L66-L68 59 | return [ 60 | FCTPP( 61 | "1x0e + 1x1o + 1x2e + 1x3o", 62 | "1x0e + 1x1o + 1x2e + 1x3o", 63 | "64x0e + 24x1e + 24x1o + 16x2e + 16x2o", 64 | label="tetris-poly-1", 65 | ), # tp1 66 | FCTPP( 67 | "64x0e + 24x1e + 24x1o + 16x2e + 16x2o", 68 | "1x0e + 1x1o + 1x2e", 69 | "0o + 6x0e", 70 | label="tetris-poly-2", 71 | ), # tp2 72 | ] 73 | 74 | 75 | # 
https://github.com/gcorso/DiffDock/blob/b4704d94de74d8cb2acbe7ec84ad234c09e78009/models/tensor_layers.py#L299 76 | # specific irreps come from Vivek's communication with DiffDock team 77 | def diffdock_problems(): 78 | return [ 79 | FCTPP( 80 | "10x1o + 10x1e + 48x0e + 48x0o", 81 | "1x0e + 1x1o", 82 | "10x1o + 10x1e + 48x0e + 48x0o", 83 | shared_weights=False, 84 | label="DiffDock-L=1", 85 | ), 86 | FCTPP( 87 | "10x1o + 10x1e + 48x0e + 48x0o", 88 | "1x0e + 1x1o + 1x2e", 89 | "10x1o + 10x1e + 48x0e + 48x0o", 90 | shared_weights=False, 91 | label="DiffDock-L=2", 92 | ), 93 | ] 94 | 95 | 96 | def mace_problems(): 97 | return [ 98 | CTPP(*config) 99 | for config in [ 100 | ( 101 | "128x0e+128x1o+128x2e", 102 | "1x0e+1x1o+1x2e+1x3o", 103 | "128x0e+128x1o+128x2e+128x3o", 104 | "mace-large", 105 | ), 106 | ( 107 | "128x0e+128x1o", 108 | "1x0e+1x1o+1x2e+1x3o", 109 | "128x0e+128x1o+128x2e", 110 | "mace-medium", 111 | ), 112 | ] 113 | ] 114 | 115 | 116 | def nequip_problems(): 117 | return [ 118 | CTPP(*config) 119 | for config in [ 120 | ( 121 | "32x0o + 32x0e + 32x1o + 32x1e + 32x2o + 32x2e", 122 | "0e + 1o + 2e", 123 | "32x0o + 32x0e + 32x1o + 32x1e + 32x2o + 32x2e", 124 | "nequip-lips", 125 | ), 126 | ( 127 | "64x0o + 64x0e + 64x1o + 64x1e", 128 | "0e + 1o", 129 | "64x0o + 64x0e + 64x1o + 64x1e", 130 | "nequip-revmd17-aspirin", 131 | ), 132 | ( 133 | "64x0o + 64x0e + 64x1o + 64x1e + 64x2o + 64x2e", 134 | "0e + 1o + 2e", 135 | "64x0o + 64x0e + 64x1o + 64x1e + 64x2o + 64x2e", 136 | "nequip-revmd17-toluene", 137 | ), 138 | ( 139 | "64x0o + 64x0e + 64x1o + 64x1e + 64x2o + 64x2e + 64x3o + 64x3e", 140 | "0e + 1o + 2e + 3o", 141 | "64x0o + 64x0e + 64x1o + 64x1e + 64x2o + 64x2e + 64x3o + 64x3e", 142 | "nequip-revmd17-benzene", 143 | ), 144 | ( 145 | "32x0o + 32x0e + 32x1o + 32x1e", 146 | "0e + 1o", 147 | "32x0o + 32x0e + 32x1o + 32x1e", 148 | "nequip-water", 149 | ), 150 | ] 151 | ] 152 | -------------------------------------------------------------------------------- /openequivariance/benchmark/random_buffer_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from openequivariance.implementations.e3nn_lite import TPProblem 4 | 5 | 6 | def get_random_buffers_forward( 7 | tpp: TPProblem, batch_size: int, prng_seed: int 8 | ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: 9 | """ 10 | Return properly sized numpy arrays needed to execute a tensor product in the forward direction 11 | Supports shared vs non-shared weights 12 | """ 13 | assert isinstance(tpp, TPProblem) 14 | rng = np.random.default_rng(prng_seed) 15 | 16 | in1 = np.array( 17 | rng.uniform(size=(batch_size, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype 18 | ) 19 | in2 = np.array( 20 | rng.uniform(size=(batch_size, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype 21 | ) 22 | 23 | weights_size = ( 24 | tuple([tpp.weight_numel]) 25 | if tpp.shared_weights 26 | else tuple([batch_size, tpp.weight_numel]) 27 | ) 28 | weights = np.array(rng.uniform(size=weights_size), dtype=tpp.weight_dtype) 29 | 30 | out = np.zeros(shape=(batch_size, tpp.irreps_out.dim), dtype=tpp.weight_dtype) 31 | 32 | return in1, in2, weights, out 33 | 34 | 35 | def get_random_buffers_backward( 36 | tpp: TPProblem, batch_size: int, prng_seed: int 37 | ) -> tuple[ 38 | np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray 39 | ]: 40 | """ 41 | Return properly sized numpy arrays needed to execute a tensor product in the backward direction 42 | Supports shared vs non-shared 
weights 43 | """ 44 | assert isinstance(tpp, TPProblem) 45 | rng = np.random.default_rng(prng_seed) 46 | 47 | in1 = np.array( 48 | rng.uniform(size=(batch_size, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype 49 | ) 50 | in2 = np.array( 51 | rng.uniform(size=(batch_size, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype 52 | ) 53 | out_grad = np.array( 54 | rng.uniform(size=(batch_size, tpp.irreps_out.dim)), dtype=tpp.irrep_dtype 55 | ) 56 | 57 | weights_size = ( 58 | tuple([tpp.weight_numel]) 59 | if tpp.shared_weights 60 | else tuple([batch_size, tpp.weight_numel]) 61 | ) 62 | weights = np.array(rng.uniform(size=weights_size), dtype=tpp.irrep_dtype) 63 | 64 | weights_grad = np.zeros_like(weights) 65 | in1_grad = np.zeros_like(in1) 66 | in2_grad = np.zeros_like(in2) 67 | 68 | return in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad 69 | 70 | 71 | def get_random_buffers_double_backward( 72 | tpp: TPProblem, batch_size: int, prng_seed: int 73 | ) -> tuple[ 74 | np.ndarray, 75 | np.ndarray, 76 | np.ndarray, 77 | np.ndarray, 78 | np.ndarray, 79 | np.ndarray, 80 | np.ndarray, 81 | np.ndarray, 82 | ]: 83 | """ 84 | Return properly sized numpy arrays needed to execute a tensor product in the double backward direction 85 | Supports shared vs non-shared weights 86 | """ 87 | assert isinstance(tpp, TPProblem) 88 | rng = np.random.default_rng(prng_seed) 89 | 90 | in1 = np.array( 91 | rng.uniform(size=(batch_size, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype 92 | ) 93 | in2 = np.array( 94 | rng.uniform(size=(batch_size, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype 95 | ) 96 | out_grad = np.array( 97 | rng.uniform(size=(batch_size, tpp.irreps_out.dim)), dtype=tpp.irrep_dtype 98 | ) 99 | 100 | weights_size = ( 101 | tuple([tpp.weight_numel]) 102 | if tpp.shared_weights 103 | else tuple([batch_size, tpp.weight_numel]) 104 | ) 105 | weights = np.array(rng.uniform(size=weights_size), dtype=tpp.irrep_dtype) 106 | 107 | weights_grad = np.zeros_like(weights) 108 | in1_grad = np.zeros_like(in1) 109 | in2_grad = np.zeros_like(in2) 110 | out_double_grad = np.zeros_like(out_grad) 111 | 112 | return ( 113 | in1, 114 | in2, 115 | out_grad, 116 | weights, 117 | weights_grad, 118 | in1_grad, 119 | in2_grad, 120 | out_double_grad, 121 | ) 122 | 123 | 124 | def get_random_buffers_forward_conv( 125 | tpp: TPProblem, node_count: int, edge_count: int, prng_seed: int 126 | ): 127 | rng = np.random.default_rng(prng_seed) 128 | 129 | in1 = np.array( 130 | rng.uniform(size=(node_count, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype 131 | ) 132 | in2 = np.array( 133 | rng.uniform(size=(edge_count, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype 134 | ) 135 | 136 | weights_size = ( 137 | tuple([tpp.weight_numel]) 138 | if tpp.shared_weights 139 | else tuple([edge_count, tpp.weight_numel]) 140 | ) 141 | weights = np.array(rng.uniform(size=weights_size), dtype=tpp.weight_dtype) 142 | 143 | out = np.zeros(shape=(node_count, tpp.irreps_out.dim), dtype=tpp.weight_dtype) 144 | 145 | return in1, in2, weights, out 146 | 147 | 148 | def get_random_buffers_backward_conv( 149 | tpp: TPProblem, node_count: int, edge_count: int, prng_seed: int 150 | ): 151 | """ 152 | Return properly sized numpy arrays needed to execute a tensor product in the backward direction 153 | Supports shared vs non-shared weights 154 | """ 155 | rng = np.random.default_rng(prng_seed) 156 | 157 | in1 = np.array( 158 | rng.uniform(size=(node_count, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype 159 | ) 160 | in2 = np.array( 161 | rng.uniform(size=(edge_count, 
tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype 162 | ) 163 | out_grad = np.array( 164 | rng.uniform(size=(node_count, tpp.irreps_out.dim)), dtype=tpp.irrep_dtype 165 | ) 166 | 167 | weights_size = ( 168 | tuple([tpp.weight_numel]) 169 | if tpp.shared_weights 170 | else tuple([edge_count, tpp.weight_numel]) 171 | ) 172 | weights = np.array(rng.uniform(size=weights_size), dtype=tpp.irrep_dtype) 173 | 174 | weights_grad = np.zeros_like(weights) 175 | in1_grad = np.zeros_like(in1) 176 | in2_grad = np.zeros_like(in2) 177 | 178 | return in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad 179 | -------------------------------------------------------------------------------- /openequivariance/benchmark/tpp_creation_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from typing import Iterator, Optional 4 | from openequivariance.implementations.e3nn_lite import Irrep, Irreps, TPProblem 5 | 6 | """ 7 | This was taken from 8 | 9 | https://github.com/e3nn/e3nn/blob/0.5.4/e3nn/o3/_tensor_product/_sub.py 10 | 11 | And adopted to create TPP's to avoid torch dependence 12 | """ 13 | 14 | 15 | class FullyConnectedTPProblem(TPProblem): 16 | def __init__(self, irreps_in1, irreps_in2, irreps_out, **kwargs) -> None: 17 | irreps_in1 = Irreps(irreps_in1) 18 | irreps_in2 = Irreps(irreps_in2) 19 | irreps_out = Irreps(irreps_out) 20 | 21 | instr = [ 22 | (i_1, i_2, i_out, "uvw", True, 1.0) 23 | for i_1, (_, ir_1) in enumerate(irreps_in1) 24 | for i_2, (_, ir_2) in enumerate(irreps_in2) 25 | for i_out, (_, ir_out) in enumerate(irreps_out) 26 | if ir_out in ir_1 * ir_2 27 | ] 28 | super().__init__( 29 | irreps_in1, 30 | irreps_in2, 31 | irreps_out, 32 | instr, 33 | **kwargs, 34 | ) 35 | 36 | 37 | class ElementwiseTPProblem(TPProblem): 38 | def __init__(self, irreps_in1, irreps_in2, filter_ir_out=None, **kwargs) -> None: 39 | irreps_in1 = Irreps(irreps_in1).simplify() 40 | irreps_in2 = Irreps(irreps_in2).simplify() 41 | if filter_ir_out is not None: 42 | try: 43 | filter_ir_out = [Irrep(ir) for ir in filter_ir_out] 44 | except ValueError: 45 | raise ValueError( 46 | f"filter_ir_out (={filter_ir_out}) must be an iterable of e3nn.o3.Irrep" 47 | ) 48 | 49 | assert irreps_in1.num_irreps == irreps_in2.num_irreps 50 | 51 | irreps_in1 = list(irreps_in1) 52 | irreps_in2 = list(irreps_in2) 53 | 54 | i = 0 55 | while i < len(irreps_in1): 56 | mul_1, ir_1 = irreps_in1[i] 57 | mul_2, ir_2 = irreps_in2[i] 58 | 59 | if mul_1 < mul_2: 60 | irreps_in2[i] = (mul_1, ir_2) 61 | irreps_in2.insert(i + 1, (mul_2 - mul_1, ir_2)) 62 | 63 | if mul_2 < mul_1: 64 | irreps_in1[i] = (mul_2, ir_1) 65 | irreps_in1.insert(i + 1, (mul_1 - mul_2, ir_1)) 66 | i += 1 67 | 68 | out = [] 69 | instr = [] 70 | for i, ((mul, ir_1), (mul_2, ir_2)) in enumerate(zip(irreps_in1, irreps_in2)): 71 | assert mul == mul_2 72 | for ir in ir_1 * ir_2: 73 | if filter_ir_out is not None and ir not in filter_ir_out: 74 | continue 75 | 76 | i_out = len(out) 77 | out.append((mul, ir)) 78 | instr += [(i, i, i_out, "uuu", False)] 79 | 80 | super().__init__(irreps_in1, irreps_in2, out, instr, **kwargs) 81 | 82 | 83 | class FullTPProblem(TPProblem): 84 | def __init__( 85 | self, 86 | irreps_in1: Irreps, 87 | irreps_in2: Irreps, 88 | filter_ir_out: Iterator[Irrep] = None, 89 | **kwargs, 90 | ) -> None: 91 | irreps_in1 = Irreps(irreps_in1).simplify() 92 | irreps_in2 = Irreps(irreps_in2).simplify() 93 | if filter_ir_out is not None: 94 | try: 95 | filter_ir_out = [Irrep(ir) for ir in filter_ir_out] 96 | 
except ValueError: 97 | raise ValueError( 98 | f"filter_ir_out (={filter_ir_out}) must be an iterable of e3nn.o3.Irrep" 99 | ) 100 | 101 | out = [] 102 | instr = [] 103 | for i_1, (mul_1, ir_1) in enumerate(irreps_in1): 104 | for i_2, (mul_2, ir_2) in enumerate(irreps_in2): 105 | for ir_out in ir_1 * ir_2: 106 | if filter_ir_out is not None and ir_out not in filter_ir_out: 107 | continue 108 | 109 | i_out = len(out) 110 | out.append((mul_1 * mul_2, ir_out)) 111 | instr += [(i_1, i_2, i_out, "uvuv", False)] 112 | 113 | out = Irreps(out) 114 | out, p, _ = out.sort() 115 | 116 | instr = [ 117 | (i_1, i_2, p[i_out], mode, train) for i_1, i_2, i_out, mode, train in instr 118 | ] 119 | 120 | super().__init__(irreps_in1, irreps_in2, out, instr, **kwargs) 121 | 122 | 123 | class ChannelwiseTPP(TPProblem): 124 | """ 125 | Modified from mace/mace/modules/irreps_tools.py. 126 | """ 127 | 128 | def __init__( 129 | self, 130 | irreps_in1: Irreps, 131 | irreps_in2: Irreps, 132 | irreps_out: Irreps, 133 | label: Optional[str] = None, 134 | irrep_dtype=np.float32, 135 | weight_dtype=np.float32, 136 | ): 137 | trainable = True 138 | irreps1 = Irreps(irreps_in1) 139 | irreps2 = Irreps(irreps_in2) 140 | irreps_out = Irreps(irreps_out) 141 | 142 | # Collect possible irreps and their instructions 143 | irreps_out_list = [] 144 | instructions = [] 145 | for i, (mul, ir_in) in enumerate(irreps1): 146 | for j, (_, ir_edge) in enumerate(irreps2): 147 | for ir_out in ir_in * ir_edge: # | l1 - l2 | <= l <= l1 + l2 148 | if ir_out in irreps_out: 149 | k = len(irreps_out_list) # instruction index 150 | irreps_out_list.append((mul, ir_out)) 151 | instructions.append((i, j, k, "uvu", trainable)) 152 | 153 | irreps_out = Irreps(irreps_out_list) 154 | irreps_out, permut, _ = irreps_out.sort() 155 | 156 | instructions = [ 157 | (i_in1, i_in2, permut[i_out], mode, train) 158 | for i_in1, i_in2, i_out, mode, train in instructions 159 | ] 160 | 161 | instructions = sorted(instructions, key=lambda x: x[2]) 162 | super().__init__( 163 | irreps1, 164 | irreps2, 165 | irreps_out, 166 | instructions, 167 | internal_weights=False, 168 | shared_weights=False, 169 | label=label, 170 | irrep_dtype=irrep_dtype, 171 | weight_dtype=weight_dtype, 172 | ) 173 | 174 | 175 | class SingleInstruction(TPProblem): 176 | def __init__( 177 | self, 178 | irreps_in1: Irreps, 179 | irreps_in2: Irreps, 180 | irreps_in3: Irreps, 181 | mode: str, 182 | label: Optional[str] = None, 183 | ): 184 | trainable = True 185 | irreps1 = Irreps(irreps_in1) 186 | irreps2 = Irreps(irreps_in2) 187 | irreps3 = Irreps(irreps_in3) 188 | instructions = [(0, 0, 0, mode, trainable)] 189 | 190 | super().__init__( 191 | irreps1, 192 | irreps2, 193 | irreps3, 194 | instructions, 195 | internal_weights=False, 196 | shared_weights=False, 197 | label=label, 198 | ) 199 | -------------------------------------------------------------------------------- /openequivariance/extension/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.15) 2 | project(equivariant_spmm LANGUAGES CXX) 3 | 4 | find_package(CUDAToolkit REQUIRED) 5 | find_package(pybind11 REQUIRED) 6 | find_package(Python COMPONENTS Interpreter Development) 7 | 8 | set(core_SOURCES 9 | util/buffer.hpp 10 | util/backend_cuda.hpp 11 | tensorproducts.hpp 12 | convolution.hpp 13 | ) 14 | 15 | add_library(kernel_wrapper MODULE ${core_SOURCES} kernel_wrapper.cpp) 16 | set_property(TARGET kernel_wrapper PROPERTY POSITION_INDEPENDENT_CODE ON) 17 | 18 
| target_include_directories(kernel_wrapper PUBLIC 19 | ${CMAKE_CURRENT_SOURCE_DIR}/util 20 | ) 21 | 22 | set_target_properties(kernel_wrapper PROPERTIES LINKER_LANGUAGE CXX) 23 | 24 | # Manually build the module to avoid SOABI renaming. Original command: 25 | # pybind11_add_module(kernel_wrapper kernel_wrapper.cpp) 26 | # target_link_libraries(kernel_wrapper PRIVATE espmm) 27 | 28 | target_link_libraries(kernel_wrapper PRIVATE 29 | pybind11::pybind11 30 | CUDA::cudart 31 | CUDA::cuda_driver 32 | CUDA::nvrtc) 33 | target_link_options(kernel_wrapper PRIVATE -Wl,-rpath='$ORIGIN') 34 | 35 | if(NOT MSVC AND NOT ${CMAKE_BUILD_TYPE} MATCHES Debug|RelWithDebInfo) 36 | # Strip unnecessary sections of the binary on Linux/macOS 37 | pybind11_strip(kernel_wrapper) 38 | endif() 39 | 40 | set_target_properties(kernel_wrapper PROPERTIES CXX_VISIBILITY_PRESET "hidden" 41 | CUDA_VISIBILITY_PRESET "hidden" 42 | PREFIX "") 43 | 44 | install(TARGETS kernel_wrapper DESTINATION .) -------------------------------------------------------------------------------- /openequivariance/extension/convolution.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | class __attribute__ ((visibility ("default"))) ConvolutionImpl { 10 | public: 11 | bool record_internal_stats = false; 12 | 13 | ConvolutionImpl() { 14 | } 15 | 16 | virtual void exec_conv( 17 | void* L1_in, 18 | void* L2_in, 19 | void* weights, 20 | void* L3_out, 21 | void* rows, 22 | void* cols, 23 | uint64_t nnz, 24 | uint64_t node_count, 25 | void* workspace) = 0; 26 | 27 | void exec_conv_rawptrs( 28 | uint64_t L1_in, 29 | uint64_t L2_in, 30 | uint64_t weights, 31 | uint64_t L3_out, 32 | uint64_t rows, 33 | uint64_t cols, 34 | uint64_t nnz, 35 | uint64_t node_count, 36 | uint64_t workspace) { 37 | 38 | exec_conv( 39 | reinterpret_cast(L1_in), 40 | reinterpret_cast(L2_in), 41 | reinterpret_cast(weights), 42 | reinterpret_cast(L3_out), 43 | reinterpret_cast(rows), 44 | reinterpret_cast(cols), 45 | nnz, 46 | node_count, 47 | reinterpret_cast(workspace)); 48 | } 49 | 50 | virtual void backward( 51 | void* L1_in, void* L1_grad, 52 | void* L2_in, void* L2_grad, 53 | void* weight, void* weight_grad, 54 | void* L3_grad, 55 | void* rows, void* cols, 56 | uint64_t nnz, uint64_t node_count, 57 | void* workspace, void* inverse_perm) = 0; 58 | 59 | void backward_rawptrs( 60 | uint64_t L1_in, uint64_t L1_grad, 61 | uint64_t L2_in, uint64_t L2_grad, 62 | uint64_t weight, uint64_t weight_grad, 63 | uint64_t L3_grad, 64 | uint64_t rows, uint64_t cols, 65 | uint64_t nnz, uint64_t node_count, 66 | uint64_t workspace, uint64_t inverse_perm) { 67 | 68 | backward( 69 | reinterpret_cast(L1_in), 70 | reinterpret_cast(L1_grad), 71 | reinterpret_cast(L2_in), 72 | reinterpret_cast(L2_grad), 73 | reinterpret_cast(weight), 74 | reinterpret_cast(weight_grad), 75 | reinterpret_cast(L3_grad), 76 | reinterpret_cast(rows), 77 | reinterpret_cast(cols), 78 | nnz, 79 | node_count, 80 | reinterpret_cast(workspace), 81 | reinterpret_cast(inverse_perm)); 82 | } 83 | 84 | virtual ~ConvolutionImpl() {}; 85 | }; 86 | 87 | struct ConvData { 88 | void* rows; 89 | void* cols; 90 | unsigned long nnz; 91 | unsigned long node_count; 92 | }; 93 | 94 | template 95 | class __attribute__ ((visibility ("default"))) JITConvImpl : public ConvolutionImpl{ 96 | public: 97 | JIT_IMPL jit; 98 | KernelLaunchConfig forward_config; 99 | KernelLaunchConfig backward_config; 100 | KernelLaunchConfig 
double_backward_config; 101 | bool is_uvw; 102 | 103 | JITConvImpl( 104 | std::string jit_kernel, 105 | KernelLaunchConfig forward_config_i, 106 | KernelLaunchConfig backward_config_i, 107 | KernelLaunchConfig double_backward_config_i, 108 | bool is_uvw_i) : 109 | jit(jit_kernel), 110 | forward_config(forward_config_i), 111 | backward_config(backward_config_i), 112 | double_backward_config(double_backward_config_i), 113 | is_uvw(is_uvw_i) { 114 | 115 | vector kernels = {"forward", "backward", "fixup_forward", "fixup_backward", "double_backward_A", "double_backward_B", "fixup_double_backwardB"}; 116 | 117 | int opt_level = 3; 118 | #ifdef HIP_BACKEND 119 | if(is_uvw) { 120 | opt_level = 1; 121 | } 122 | #endif 123 | jit.compile(kernels, {{}, {}, {}, {}, {}, {}, {}}, opt_level); 124 | 125 | if(forward_config.smem > 0) { 126 | jit.set_max_smem(0, forward_config.smem); 127 | jit.set_max_smem(4, forward_config.smem); 128 | } 129 | 130 | if(backward_config.smem > 0) { 131 | jit.set_max_smem(1, backward_config.smem); 132 | } 133 | 134 | if(double_backward_config.smem > 0) { 135 | jit.set_max_smem(5, double_backward_config.smem); 136 | } 137 | } 138 | 139 | JITConvImpl( 140 | std::string jit_kernel, 141 | std::unordered_map fwd_dict, 142 | std::unordered_map bwd_dict, 143 | std::unordered_map dbl_bwd_dict, 144 | std::unordered_map kernel_dims 145 | ) : JITConvImpl( 146 | jit_kernel, 147 | KernelLaunchConfig( 148 | fwd_dict["num_blocks"], 149 | fwd_dict["num_threads"], 150 | fwd_dict["smem"] 151 | ), 152 | KernelLaunchConfig( 153 | bwd_dict["num_blocks"], 154 | bwd_dict["num_threads"], 155 | bwd_dict["smem"] 156 | ), 157 | KernelLaunchConfig( 158 | dbl_bwd_dict["num_blocks"], 159 | dbl_bwd_dict["num_threads"], 160 | dbl_bwd_dict["smem"] 161 | ), 162 | kernel_dims["is_uvw"] == 1) { } 163 | 164 | void exec_conv( 165 | void* L1_in, 166 | void* L2_in, 167 | void* weights, 168 | void* L3_out, 169 | void* rows, 170 | void* cols, 171 | uint64_t nnz, 172 | uint64_t node_count, 173 | void* workspace) { 174 | 175 | ConvData conv_data = {rows, cols, nnz, node_count}; 176 | 177 | void *args[] = {&L1_in, &L2_in, &weights, &L3_out, &conv_data, &workspace}; 178 | jit.execute(0, args, forward_config); 179 | 180 | if(reinterpret_cast(workspace) != 0) { 181 | void *fixup_args[] = {&workspace, &L3_out}; 182 | 183 | KernelLaunchConfig fixup_config; 184 | fixup_config.num_blocks = forward_config.num_blocks; 185 | fixup_config.num_threads = forward_config.num_threads; 186 | fixup_config.smem = 0; 187 | 188 | jit.execute(2, fixup_args, fixup_config); 189 | } 190 | } 191 | 192 | void backward( 193 | void* L1_in, void* L1_grad, 194 | void* L2_in, void* L2_grad, 195 | void* weight, void* weight_grad, 196 | void* L3_grad, 197 | void* rows, void* cols, 198 | uint64_t nnz, uint64_t node_count, 199 | void* workspace, 200 | void* transpose_perm) { 201 | 202 | ConvData conv_data = {rows, cols, nnz, node_count}; 203 | void *args[] = {&L1_in, &L1_grad, &L2_in, &L2_grad, &weight, &weight_grad, &L3_grad, &conv_data, &workspace, &transpose_perm}; 204 | jit.execute(1, args, backward_config); 205 | 206 | if(reinterpret_cast(workspace) != 0) { 207 | void *fixup_args[] = {&workspace, &L1_grad}; 208 | 209 | KernelLaunchConfig fixup_config; 210 | fixup_config.num_blocks = backward_config.num_blocks; 211 | fixup_config.num_threads = backward_config.num_threads; 212 | fixup_config.smem = 0; 213 | 214 | jit.execute(3, fixup_args, fixup_config); 215 | } 216 | } 217 | 218 | void double_backward( 219 | void* L1_in, void* L2_in, void* W, void* 
L3_grad, 220 | void* L1_dgrad, void* L2_dgrad, void* w_dgrad, 221 | void* L1_grad, void* L2_grad, void* W_grad, void* L3_dgrad, 222 | void* rows, void* cols, 223 | uint64_t nnz, uint64_t node_count, 224 | void* wspace, void* transpose_perm) { 225 | 226 | ConvData conv_data = {rows, cols, nnz, node_count}; 227 | void* args[] = { 228 | &L1_in, &L2_in, &W, &L3_grad, &L1_dgrad, &L2_dgrad, &w_dgrad, 229 | &L1_grad, &L2_grad, &W_grad, &L3_dgrad, &conv_data, &wspace, &transpose_perm 230 | }; 231 | 232 | jit.execute(4, args, forward_config); 233 | if(reinterpret_cast(wspace) != 0) { 234 | void *fixup_args[] = {&wspace, &L3_dgrad}; 235 | KernelLaunchConfig fixup_config; 236 | fixup_config.num_blocks = forward_config.num_blocks; 237 | fixup_config.num_threads = forward_config.num_threads; 238 | fixup_config.smem = 0; 239 | jit.execute(2, fixup_args, fixup_config); 240 | } 241 | 242 | jit.execute(5, args, double_backward_config); 243 | if(reinterpret_cast(wspace) != 0) { 244 | void *fixup_args[] = {&wspace, &L1_grad}; 245 | KernelLaunchConfig fixup_config; 246 | fixup_config.num_blocks = double_backward_config.num_blocks; 247 | fixup_config.num_threads = double_backward_config.num_threads; 248 | fixup_config.smem = 0; 249 | jit.execute(6, fixup_args, fixup_config); 250 | } 251 | } 252 | 253 | ~JITConvImpl() = default; 254 | }; -------------------------------------------------------------------------------- /openequivariance/extension/generic_module.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #ifdef CUDA_BACKEND 8 | #include "backend_cuda.hpp" 9 | #include "group_mm_cuda.hpp" 10 | using JITKernel = CUJITKernel; 11 | using GPU_Allocator = CUDA_Allocator; 12 | 13 | template 14 | using GroupMM = GroupMMCUDA; 15 | #endif 16 | 17 | #ifdef HIP_BACKEND 18 | #include "backend_hip.hpp" 19 | #include "group_mm_hip.hpp" 20 | using JITKernel = HIPJITKernel; 21 | using GPU_Allocator = HIP_Allocator; 22 | 23 | template 24 | using GroupMM = GroupMMHIP; 25 | #endif 26 | 27 | #include "buffer.hpp" 28 | #include "tensorproducts.hpp" 29 | #include "convolution.hpp" 30 | 31 | using namespace std; 32 | namespace py=pybind11; 33 | 34 | PYBIND11_MODULE(generic_module, m) { 35 | //=========== Batch tensor products ========= 36 | py::class_(m, "GenericTensorProductImpl") 37 | .def("exec_tensor_product_rawptr", &GenericTensorProductImpl::exec_tensor_product_device_rawptrs) 38 | .def("backward_rawptr", &GenericTensorProductImpl::backward_device_rawptrs); 39 | py::class_, GenericTensorProductImpl>(m, "JITTPImpl") 40 | .def(py::init< std::string, 41 | std::unordered_map, 42 | std::unordered_map, 43 | std::unordered_map, 44 | std::unordered_map>()); 45 | 46 | //============= Convolutions =============== 47 | py::class_(m, "ConvolutionImpl") 48 | .def("exec_conv_rawptrs", &ConvolutionImpl::exec_conv_rawptrs) 49 | .def("backward_rawptrs", &ConvolutionImpl::backward_rawptrs); 50 | py::class_, ConvolutionImpl>(m, "JITConvImpl") 51 | .def(py::init< std::string, 52 | std::unordered_map, 53 | std::unordered_map, 54 | std::unordered_map, 55 | std::unordered_map>()); 56 | 57 | py::class_>(m, "GroupMM_F32") 58 | .def(py::init()) 59 | .def("group_gemm", &GroupMM::group_gemm_intptr); 60 | py::class_>(m, "GroupMM_F64") 61 | .def(py::init()) 62 | .def("group_gemm", &GroupMM::group_gemm_intptr); 63 | 64 | py::class_(m, "DeviceProp") 65 | .def(py::init()) 66 | .def_readonly("name", &DeviceProp::name) 67 | .def_readonly("warpsize", 
&DeviceProp::warpsize) 68 | .def_readonly("major", &DeviceProp::major) 69 | .def_readonly("minor", &DeviceProp::minor) 70 | .def_readonly("multiprocessorCount", &DeviceProp::multiprocessorCount) 71 | .def_readonly("maxSharedMemPerBlock", &DeviceProp::maxSharedMemPerBlock); 72 | 73 | py::class_>(m, "DeviceBuffer") 74 | .def(py::init()) 75 | .def(py::init()) 76 | .def("copy_to_host", &PyDeviceBuffer::copy_to_host) 77 | .def("data_ptr", &PyDeviceBuffer::data_ptr); 78 | 79 | py::class_(m, "GPUTimer") 80 | .def(py::init<>()) 81 | .def("start", &GPUTimer::start) 82 | .def("stop_clock_get_elapsed", &GPUTimer::stop_clock_get_elapsed) 83 | .def("clear_L2_cache", &GPUTimer::clear_L2_cache); 84 | } -------------------------------------------------------------------------------- /openequivariance/extension/group_mm_cuda.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cublas_v2.h" 4 | #include 5 | #include 6 | #include 7 | 8 | using namespace std; 9 | 10 | template 11 | class GroupMMCUDA { 12 | cublasStatus_t stat; 13 | cublasHandle_t handle; 14 | 15 | int num_W; 16 | int batch_size; 17 | 18 | T alpha; 19 | T beta; 20 | 21 | public: 22 | GroupMMCUDA(int num_W, int batch_size) : 23 | num_W(num_W), 24 | batch_size(batch_size), 25 | alpha(1.0), 26 | beta(0.0) { 27 | stat = cublasCreate(&handle); 28 | if (stat != CUBLAS_STATUS_SUCCESS) { 29 | throw std::logic_error("CUBLAS initialization failed"); 30 | } 31 | } 32 | 33 | void group_gemm(void* A_raw, void* B_raw, void* C_raw, 34 | int64_t* ragged_counts, int m, int k, int ragged_inner) { 35 | /* 36 | * Performs one of two grouped, batched GEMMs with a single ragged dimension: 37 | * 38 | * a) If ragged_inner = 0, multiplies each M x K row-major weight matrix A 39 | * against B, where B is stored in column-major order with each matrix of 40 | * dimensions K x [offset_diff]. Output has dimensions M x [offset_diff], 41 | * stored in column-major order. 42 | * b) If ragged_inner = 1, multiplies each M x [offset_diff] A matrix 43 | * against each B K x [offset_diff] matrix transposed to produce a 44 | * M x K matrix output. 
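 *
 * A worked example of case (a), with hypothetical values not drawn from any
 * caller: for num_W = 2, batch_size = 1, m = 4, k = 8, and ragged_counts =
 * {5, 3}, the first 4 x 8 weight matrix multiplies the leading 8 x 5 column
 * block of B and writes a 4 x 5 block of C; the second weight matrix
 * multiplies the next 8 x 3 block of B and writes the following 4 x 3 block
 * of C.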
45 | */ 46 | 47 | T* A_base = reinterpret_cast(A_raw); 48 | T* B_base = reinterpret_cast(B_raw); 49 | T* C_base = reinterpret_cast(C_raw); 50 | 51 | int64_t ragged_offset = 0; 52 | for(int i = 0; i < num_W; i++) { 53 | int M, K, N, lda, ldb, ldc; 54 | T *A, *B, *C; 55 | 56 | int strideA, strideB, strideC; 57 | cublasOperation_t transa, transb; 58 | 59 | if(ragged_inner == 0) { 60 | M = m; 61 | K = k; 62 | N = static_cast(ragged_counts[i]); 63 | 64 | A = A_base + (m * k * batch_size * i); 65 | lda = k; strideA = M * K; 66 | 67 | B = B_base + (k * batch_size * ragged_offset); 68 | ldb = K * batch_size; strideB = K; 69 | 70 | C = C_base + (m * batch_size * ragged_offset); 71 | ldc = M * batch_size; strideC = M; 72 | 73 | transa = CUBLAS_OP_T; 74 | transb = CUBLAS_OP_N; 75 | } 76 | else { 77 | M = k; 78 | K = static_cast(ragged_counts[i]); 79 | N = m; 80 | 81 | A = B_base + (k * batch_size * ragged_offset); 82 | lda = k * batch_size; strideA = M; 83 | 84 | B = A_base + (m * batch_size * ragged_offset); 85 | ldb = m * batch_size; strideB = N; 86 | 87 | C = C_base + (m * k * batch_size * i); 88 | ldc = k; strideC = M * N; 89 | 90 | transa = CUBLAS_OP_N; 91 | transb = CUBLAS_OP_T; 92 | } 93 | ragged_offset += ragged_counts[i]; 94 | 95 | if(ragged_counts[i] > 0) { 96 | if(std::is_same::value) { 97 | stat = cublasSgemmStridedBatched(handle, 98 | transa, transb, 99 | M, N, K, 100 | reinterpret_cast(&alpha), 101 | reinterpret_cast(A), lda, strideA, 102 | reinterpret_cast(B), ldb, strideB, 103 | reinterpret_cast(&beta), 104 | reinterpret_cast(C), ldc, strideC, 105 | batch_size); 106 | } 107 | else if(std::is_same::value) { 108 | stat = cublasDgemmStridedBatched(handle, 109 | transa, transb, 110 | M, N, K, 111 | reinterpret_cast(&alpha), 112 | reinterpret_cast(A), lda, strideA, 113 | reinterpret_cast(B), ldb, strideB, 114 | reinterpret_cast(&beta), 115 | reinterpret_cast(C), ldc, strideC, 116 | batch_size); 117 | } 118 | else { 119 | throw std::logic_error("Unsupported datatype for grouped GEMM!"); 120 | } 121 | if (stat != CUBLAS_STATUS_SUCCESS) { 122 | throw std::logic_error("Grouped GEMM failed!"); 123 | } 124 | } 125 | } 126 | } 127 | 128 | void group_gemm_intptr(uint64_t weights, 129 | uint64_t vectors, uint64_t output, 130 | uint64_t ragged_counts, int m, int k, int ragged_inner) { 131 | 132 | group_gemm( 133 | reinterpret_cast(weights), 134 | reinterpret_cast(vectors), 135 | reinterpret_cast(output), 136 | reinterpret_cast(ragged_counts), 137 | m, k, ragged_inner); 138 | } 139 | 140 | ~GroupMMCUDA() { 141 | cublasDestroy(handle); 142 | } 143 | }; -------------------------------------------------------------------------------- /openequivariance/extension/group_mm_hip.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "rocblas/rocblas.h" 4 | #include 5 | #include 6 | #include 7 | 8 | 9 | template 10 | class GroupMMHIP { 11 | rocblas_status stat; 12 | rocblas_handle handle; 13 | 14 | int num_W; 15 | int batch_size; 16 | 17 | T alpha; 18 | T beta; 19 | 20 | public: 21 | GroupMMHIP(int num_W, int batch_size) : 22 | num_W(num_W), 23 | batch_size(batch_size), 24 | alpha(1.0), 25 | beta(0.0) { 26 | if(rocblas_create_handle(&handle) != rocblas_status_success) { 27 | throw std::logic_error("rocBLAS initialization failed"); 28 | } 29 | } 30 | 31 | void group_gemm(void* A_raw, void* B_raw, void* C_raw, 32 | int64_t* ragged_counts, int m, int k, int ragged_inner) { 33 | 34 | T* A_base = reinterpret_cast(A_raw); 35 | T* B_base = 
reinterpret_cast(B_raw); 36 | T* C_base = reinterpret_cast(C_raw); 37 | 38 | int64_t ragged_offset = 0; 39 | for(int i = 0; i < num_W; i++) { 40 | int M, K, N, lda, ldb, ldc; 41 | T *A, *B, *C; 42 | 43 | int strideA, strideB, strideC; 44 | rocblas_operation transa, transb; 45 | 46 | if(ragged_inner == 0) { 47 | M = m; 48 | K = k; 49 | N = static_cast(ragged_counts[i]); 50 | 51 | A = A_base + (m * k * batch_size * i); 52 | lda = k; strideA = M * K; 53 | 54 | B = B_base + (k * batch_size * ragged_offset); 55 | ldb = K * batch_size; strideB = K; 56 | 57 | C = C_base + (m * batch_size * ragged_offset); 58 | ldc = M * batch_size; strideC = M; 59 | 60 | transa = rocblas_operation_transpose; 61 | transb = rocblas_operation_none; 62 | } 63 | else { 64 | M = k; 65 | K = static_cast(ragged_counts[i]); 66 | N = m; 67 | 68 | A = B_base + (k * batch_size * ragged_offset); 69 | lda = k * batch_size; strideA = M; 70 | 71 | B = A_base + (m * batch_size * ragged_offset); 72 | ldb = m * batch_size; strideB = N; 73 | 74 | C = C_base + (m * k * batch_size * i); 75 | ldc = k; strideC = M * N; 76 | 77 | transa = rocblas_operation_none; 78 | transb = rocblas_operation_transpose; 79 | } 80 | ragged_offset += ragged_counts[i]; 81 | 82 | if(ragged_counts[i] > 0) { 83 | if(std::is_same::value) { 84 | stat = rocblas_sgemm_strided_batched(handle, 85 | transa, transb, 86 | M, N, K, 87 | reinterpret_cast(&alpha), 88 | reinterpret_cast(A), lda, strideA, 89 | reinterpret_cast(B), ldb, strideB, 90 | reinterpret_cast(&beta), 91 | reinterpret_cast(C), ldc, strideC, 92 | batch_size); 93 | } 94 | else if(std::is_same::value) { 95 | stat = rocblas_dgemm_strided_batched(handle, 96 | transa, transb, 97 | M, N, K, 98 | reinterpret_cast(&alpha), 99 | reinterpret_cast(A), lda, strideA, 100 | reinterpret_cast(B), ldb, strideB, 101 | reinterpret_cast(&beta), 102 | reinterpret_cast(C), ldc, strideC, 103 | batch_size); 104 | } 105 | else { 106 | throw std::logic_error("Unsupported datatype for grouped GEMM!"); 107 | } 108 | if (stat != rocblas_status_success) { 109 | throw std::logic_error("Grouped GEMM failed!"); 110 | } 111 | } 112 | } 113 | } 114 | 115 | void group_gemm_intptr(uint64_t weights, 116 | uint64_t vectors, uint64_t output, 117 | uint64_t ragged_counts, int m, int k, int ragged_inner) { 118 | group_gemm( 119 | reinterpret_cast(weights), 120 | reinterpret_cast(vectors), 121 | reinterpret_cast(output), 122 | reinterpret_cast(ragged_counts), 123 | m, k, ragged_inner); 124 | } 125 | 126 | ~GroupMMHIP() { 127 | rocblas_destroy_handle(handle); 128 | } 129 | }; -------------------------------------------------------------------------------- /openequivariance/extension/tensorproducts.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | class __attribute__ ((visibility ("default"))) GenericTensorProductImpl { 10 | public: 11 | GenericTensorProductImpl() { } 12 | 13 | virtual void exec_tensor_product(uint64_t num_products, 14 | void* L1_in, void* L2_in, void* L3_out, void* weights) = 0; 15 | 16 | void exec_tensor_product_device_rawptrs(uint64_t num_products, 17 | uint64_t L1_in, uint64_t L2_in, uint64_t L3_out, uint64_t weights) { 18 | 19 | exec_tensor_product(num_products, 20 | reinterpret_cast(L1_in), 21 | reinterpret_cast(L2_in), 22 | reinterpret_cast(L3_out), 23 | reinterpret_cast(weights)); 24 | } 25 | 26 | virtual void backward(size_t num_products, 27 | void* L1_in, void* L1_grad, 28 | void* L2_in, 
    virtual void backward(size_t num_products,
            void* L1_in, void* L1_grad,
            void* L2_in, void* L2_grad,
            void* weight, void* weight_grad,
            void* L3_grad) {

        throw std::logic_error("Backward pass not implemented yet!");
    }

    void backward_device_rawptrs(uint64_t num_products,
            uint64_t L1_in, uint64_t L1_grad,
            uint64_t L2_in, uint64_t L2_grad,
            uint64_t weight, uint64_t weight_grad,
            uint64_t L3_grad) {

        backward(num_products,
            reinterpret_cast<void*>(L1_in), reinterpret_cast<void*>(L1_grad),
            reinterpret_cast<void*>(L2_in), reinterpret_cast<void*>(L2_grad),
            reinterpret_cast<void*>(weight), reinterpret_cast<void*>(weight_grad),
            reinterpret_cast<void*>(L3_grad)
        );
    }

    virtual ~GenericTensorProductImpl() {};
};

template<typename JIT_IMPL>
class __attribute__ ((visibility ("default"))) JITTPImpl : public GenericTensorProductImpl {
public:
    JIT_IMPL jit;
    KernelLaunchConfig forward_config, backward_config, double_backward_config;
    bool is_uvw;

    JITTPImpl(
        std::string jit_kernel,
        KernelLaunchConfig forward_config_i,
        KernelLaunchConfig backward_config_i,
        KernelLaunchConfig double_backward_config_i,
        bool is_uvw_i) :
            jit(jit_kernel),
            forward_config(forward_config_i),
            backward_config(backward_config_i),
            double_backward_config(double_backward_config_i),
            is_uvw(is_uvw_i) {
        std::vector<std::string> kernels = {"forward", "backward", "double_backward_A", "double_backward_B"};

        int opt_level = 3;
#ifdef HIP_BACKEND
        if(is_uvw) {
            opt_level = 1;
        }
#endif
        jit.compile(kernels, {{}, {}, {}, {}}, opt_level);

        if(forward_config.smem > 0) {
            jit.set_max_smem(0, forward_config.smem);
            jit.set_max_smem(2, forward_config.smem);
        }

        if(backward_config.smem > 0) {
            jit.set_max_smem(1, backward_config.smem);
        }
        if(double_backward_config.smem > 0) {
            jit.set_max_smem(3, double_backward_config.smem);
        }
    }

    JITTPImpl(
        std::string jit_kernel,
        std::unordered_map<std::string, int64_t> fwd_dict,
        std::unordered_map<std::string, int64_t> bwd_dict,
        std::unordered_map<std::string, int64_t> dbl_bwd_dict,
        std::unordered_map<std::string, int64_t> kernel_dims
    ) : JITTPImpl(
            jit_kernel,
            KernelLaunchConfig(
                fwd_dict["num_blocks"],
                fwd_dict["num_threads"],
                fwd_dict["smem"]
            ),
            KernelLaunchConfig(
                bwd_dict["num_blocks"],
                bwd_dict["num_threads"],
                bwd_dict["smem"]
            ),
            KernelLaunchConfig(
                dbl_bwd_dict["num_blocks"],
                dbl_bwd_dict["num_threads"],
                dbl_bwd_dict["smem"]
            ),
            kernel_dims["is_uvw"] == 1
        ) { }

    void exec_tensor_product(
            uint64_t num_products,
            void* L1_in,
            void* L2_in,
            void* L3_out,
            void* weights) {

        void *args[] = { &num_products, &L1_in, &L2_in, &L3_out, &weights};
        jit.execute(0, args, forward_config);
    }

    void backward(
            size_t num_products,
            void* L1_in, void* L1_grad,
            void* L2_in, void* L2_grad,
            void* weight, void* weight_grad,
            void* L3_grad) {
        void *args[] = { &num_products, &L1_in, &L1_grad, &L2_in, &L2_grad, &weight, &weight_grad, &L3_grad};
        jit.execute(1, args, backward_config);
    }
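    /*
     * Kernel indices follow the order of the `kernels` vector above:
     * 0 = forward, 1 = backward, 2 = double_backward_A, 3 = double_backward_B.
     * Note that double_backward_A reuses forward_config, both for its shared
     * memory limit (set in the constructor) and for its launch configuration
     * in double_backward() below.
     */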
    void double_backward(
            size_t num_products,
            void* L1_in, void* L2_in, void* W, void* L3_grad,   // Inputs of backward op
            void* L1_dgrad, void* L2_dgrad, void* w_dgrad,      // Gradients w.r.t. outputs of backward op
            void* L1_grad, void* L2_grad, void* W_grad, void* L3_dgrad) {

        void* args[] = {
            &num_products, &L1_in, &L2_in, &W, &L3_grad, &L1_dgrad, &L2_dgrad, &w_dgrad,
            &L1_grad, &L2_grad, &W_grad, &L3_dgrad
        };
        jit.execute(2, args, forward_config);
        jit.execute(3, args, double_backward_config);
    }

    ~JITTPImpl() = default;
};
--------------------------------------------------------------------------------
/openequivariance/extension/test/CMakeLists.txt:
--------------------------------------------------------------------------------
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(test_oeq_jitscript_load)

find_package(Torch REQUIRED)

add_executable(load_jitscript load_jitscript.cpp)
target_link_libraries(load_jitscript "${TORCH_LIBRARIES}")
target_link_libraries(load_jitscript -Wl,--no-as-needed "${OEQ_EXTLIB}")
set_property(TARGET load_jitscript PROPERTY CXX_STANDARD 17)
--------------------------------------------------------------------------------
/openequivariance/extension/test/load_jitscript.cpp:
--------------------------------------------------------------------------------
#include <torch/script.h>

#include <iostream>
#include <string>

/*
 * This program takes in two JITScript modules that execute
 * a tensor product in FP32 precision.
 * The first module is compiled from e3nn, the second is
 * OEQ's compiled module. The program checks that the
 * two outputs are comparable.
 */

int main(int argc, const char* argv[]) {
    if (argc != 7) {
        std::cerr << "usage: load_jitscript"
            << " <path-to-e3nn-module>"
            << " <path-to-oeq-module>"
            << " <L1-dim>"
            << " <L2-dim>"
            << " <weight-numel>"
            << " <batch-size>"
            << std::endl;

        return 1;
    }

    int64_t L1_dim = std::stoi(argv[3]);
    int64_t L2_dim = std::stoi(argv[4]);
    int64_t weight_numel = std::stoi(argv[5]);
    int64_t batch_size = std::stoi(argv[6]);

    torch::Device device(torch::kCUDA);
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(torch::randn({batch_size, L1_dim}, device));
    inputs.push_back(torch::randn({batch_size, L2_dim}, device));
    inputs.push_back(torch::randn({batch_size, weight_numel}, device));

    torch::jit::script::Module module_e3nn, module_oeq;
    try {
        module_e3nn = torch::jit::load(argv[1]);
        module_oeq = torch::jit::load(argv[2]);
    }
    catch (const c10::Error& e) {
        std::cerr << "error loading script module" << std::endl;
        return 1;
    }

    module_e3nn.to(device);
    module_oeq.to(device);

    at::Tensor output_e3nn = module_e3nn.forward(inputs).toTensor();
    at::Tensor output_oeq = module_oeq.forward(inputs).toTensor();

    if(at::allclose(output_e3nn, output_oeq, 1e-5, 1e-5)) {
        return 0;
    }
    else {
        std::cerr << "torch.allclose returned FALSE comparing model outputs."
            << std::endl;
        return 1;
    }
}
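// Illustrative invocation (the file names and sizes below are made-up
// examples, chosen only to show the argument order parsed above):
//   ./load_jitscript e3nn_tp.pt oeq_tp.pt 128 16 3072 1000
// Exit code 0 means the two module outputs agreed within tolerance.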
--------------------------------------------------------------------------------
/openequivariance/extension/util/buffer.hpp:
--------------------------------------------------------------------------------
#pragma once
#include <pybind11/pybind11.h>
#include <cstdint>

using namespace std;
namespace py = pybind11;

template<typename ALLOC_T>
class PyDeviceBuffer {
public:
    char* host_ptr;
    char* device_ptr;
    size_t size;

    PyDeviceBuffer(uint64_t size) {
        this->size = size;
        device_ptr = static_cast<char*>(ALLOC_T::gpu_alloc(size));
        host_ptr = nullptr;
    }

    PyDeviceBuffer(py::buffer host_data) {
        const py::buffer_info &info = host_data.request();
        host_ptr = static_cast<char*>(info.ptr);
        size = 1;
        for(int64_t i = 0; i < info.ndim; i++) {
            size *= info.shape[i];
        }
        size *= info.itemsize;

        device_ptr = static_cast<char*>(ALLOC_T::gpu_alloc(size));
        ALLOC_T::copy_host_to_device(host_ptr, device_ptr, size);
    }

    ~PyDeviceBuffer() {
        ALLOC_T::gpu_free(static_cast<void*>(device_ptr));
    }

    void copy_to_host() {
        ALLOC_T::copy_device_to_host(host_ptr, device_ptr, size);
    }

    uint64_t data_ptr() {
        return reinterpret_cast<uint64_t>(device_ptr);
    }
};
--------------------------------------------------------------------------------
/openequivariance/extlib/.empty:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PASSIONLab/OpenEquivariance/ce8e44ea8a9b0b78dee5ac28dc177218cb802db4/openequivariance/extlib/.empty
--------------------------------------------------------------------------------
/openequivariance/extlib/__init__.py:
--------------------------------------------------------------------------------
# ruff: noqa : F401, E402
import sys
import os
import warnings
from pathlib import Path

from openequivariance.benchmark.logging_utils import getLogger
from distutils import sysconfig

oeq_root = str(Path(__file__).parent.parent)

build_ext = True
TORCH_COMPILE = True
torch_module, generic_module = None, None
postprocess_kernel = lambda kernel: kernel  # noqa : E731

try:
    python_lib_dir = sysconfig.get_config_var("LIBDIR")
    major, minor = sys.version_info.major, sys.version_info.minor
    python_lib_name = f"python{major}.{minor}"

except Exception as e:
    print("Error while retrieving Python library information:", file=sys.stderr)
    print(e, file=sys.stderr)
    print("Sysconfig variable list:", file=sys.stderr)
    print(sysconfig.get_config_vars(), file=sys.stderr)
    exit(1)

if not build_ext:
    from openequivariance.extlib.generic_module import (
        GenericTensorProductImpl,
        JITTPImpl,
        ConvolutionImpl,
        JITConvImpl,
        GroupMM_F32,
        GroupMM_F64,
        DeviceProp,
        DeviceBuffer,
        GPUTimer,
    )
else:
    from torch.utils.cpp_extension import library_paths, include_paths

    global torch
    import torch

    extra_cflags = ["-O3"]
    generic_sources = ["generic_module.cpp"]
    torch_sources = ["libtorch_tp_jit.cpp"]

    include_dirs, extra_link_args = (
        ["util"],
        [
            f"-Wl,--no-as-needed,-rpath,{python_lib_dir}",
            f"-L{python_lib_dir}",
            f"-l{python_lib_name}",
        ],
    )

    if torch.version.cuda:
        extra_link_args.extend(["-lcuda", "-lcudart", "-lnvrtc"])

        try:
            torch_libs, 
cuda_libs = library_paths("cuda") 65 | extra_link_args.append("-Wl,-rpath," + torch_libs) 66 | extra_link_args.append("-L" + cuda_libs) 67 | if os.path.exists(cuda_libs + "/stubs"): 68 | extra_link_args.append("-L" + cuda_libs + "/stubs") 69 | except Exception as e: 70 | getLogger().info(str(e)) 71 | 72 | extra_cflags.append("-DCUDA_BACKEND") 73 | elif torch.version.hip: 74 | extra_link_args.extend(["-lhiprtc"]) 75 | torch_libs = library_paths("cuda")[0] 76 | extra_link_args.append("-Wl,-rpath," + torch_libs) 77 | 78 | def postprocess(kernel): 79 | kernel = kernel.replace("__syncwarp();", "__threadfence_block();") 80 | kernel = kernel.replace("__shfl_down_sync(FULL_MASK,", "__shfl_down(") 81 | kernel = kernel.replace("atomicAdd", "unsafeAtomicAdd") 82 | return kernel 83 | 84 | postprocess_kernel = postprocess 85 | 86 | extra_cflags.append("-DHIP_BACKEND") 87 | 88 | generic_sources = [oeq_root + "/extension/" + src for src in generic_sources] 89 | torch_sources = [oeq_root + "/extension/" + src for src in torch_sources] 90 | include_dirs = [oeq_root + "/extension/" + d for d in include_dirs] + include_paths( 91 | "cuda" 92 | ) 93 | 94 | torch_compile_exception = None 95 | with warnings.catch_warnings(): 96 | warnings.simplefilter("ignore") 97 | 98 | try: 99 | torch_module = torch.utils.cpp_extension.load( 100 | "libtorch_tp_jit", 101 | torch_sources, 102 | extra_cflags=extra_cflags, 103 | extra_include_paths=include_dirs, 104 | extra_ldflags=extra_link_args, 105 | ) 106 | torch.ops.load_library(torch_module.__file__) 107 | except Exception as e: 108 | # If compiling torch fails (e.g. low gcc version), we should fall back to the 109 | # version that takes integer pointers as args (but is untraceable to PyTorch JIT / export). 110 | TORCH_COMPILE = False 111 | torch_compile_exception = e 112 | 113 | generic_module = torch.utils.cpp_extension.load( 114 | "generic_module", 115 | generic_sources, 116 | extra_cflags=extra_cflags, 117 | extra_include_paths=include_dirs, 118 | extra_ldflags=extra_link_args, 119 | ) 120 | 121 | if not TORCH_COMPILE: 122 | warnings.warn( 123 | "Could not compile integrated PyTorch wrapper. 
Falling back to Pybind11" 124 | + f", but JITScript, compile fullgraph, and export will fail.\n {torch_compile_exception}" 125 | ) 126 | 127 | from generic_module import ( 128 | GenericTensorProductImpl, 129 | JITTPImpl, 130 | ConvolutionImpl, 131 | JITConvImpl, 132 | GroupMM_F32, 133 | GroupMM_F64, 134 | DeviceProp, 135 | DeviceBuffer, 136 | GPUTimer, 137 | ) 138 | -------------------------------------------------------------------------------- /openequivariance/implementations/E3NNTensorProduct.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | "E3NNTensorProduct", 3 | "E3NNTensorProductCompiled", 4 | "E3NNTensorProductCompiledCUDAGraphs", 5 | "E3NNTensorProductCompiledMaxAutotuneCUDAGraphs", 6 | ] 7 | 8 | import os 9 | import pathlib 10 | import numpy as np 11 | 12 | from openequivariance.implementations.TensorProductBase import TensorProductBase 13 | from openequivariance.implementations.e3nn_lite import TPProblem 14 | from openequivariance.benchmark.logging_utils import getLogger 15 | 16 | TORCH_COMPILE_AUTOTUNING_DIR = pathlib.Path("triton_autotuning") 17 | 18 | logger = getLogger() 19 | 20 | 21 | class E3NNTensorProduct(TensorProductBase): 22 | def __init__(self, config: TPProblem, torch_op=True): 23 | super().__init__(config, torch_op=torch_op) 24 | assert self.torch_op 25 | 26 | global torch 27 | global e3nn 28 | import torch 29 | import e3nn 30 | from e3nn import o3 31 | 32 | e3nn.set_optimization_defaults(jit_script_fx=False) 33 | 34 | assert config.irrep_dtype == config.weight_dtype 35 | if config.irrep_dtype == np.float64: 36 | torch.set_default_dtype(torch.float64) 37 | 38 | self.e3nn_tp = o3.TensorProduct( 39 | config.irreps_in1, 40 | config.irreps_in2, 41 | config.irreps_out, 42 | config.instructions_raw, 43 | in1_var=config.in1_var, 44 | in2_var=config.in2_var, 45 | out_var=config.out_var, 46 | irrep_normalization=config.irrep_normalization, 47 | path_normalization=config.path_normalization, 48 | internal_weights=config.internal_weights, 49 | shared_weights=config.shared_weights, 50 | ).to(device="cuda") 51 | 52 | if config.irrep_dtype == np.float64: 53 | torch.set_default_dtype(torch.float32) # Reset to default 54 | 55 | self.forward = self.e3nn_tp.__call__ 56 | 57 | def forward_cpu( 58 | self, 59 | L1_in: np.ndarray, 60 | L2_in: np.ndarray, 61 | L3_out: np.ndarray, 62 | weights: np.ndarray, 63 | ) -> None: 64 | torch_L1_in = torch.tensor(L1_in, device="cuda") 65 | torch_L2_in = torch.tensor(L2_in, device="cuda") 66 | torch_weights = torch.tensor(weights, device="cuda") 67 | 68 | torch_L3_out = self.e3nn_tp(torch_L1_in, torch_L2_in, torch_weights) 69 | 70 | L3_out[:] = torch_L3_out.numpy(force=True) 71 | 72 | def backward_cpu( 73 | self, 74 | L1_in: np.ndarray, 75 | L1_grad: np.ndarray, 76 | L2_in: np.ndarray, 77 | L2_grad: np.ndarray, 78 | L3_grad: np.ndarray, 79 | weights: np.ndarray, 80 | weights_grad: np.ndarray, 81 | ) -> None: 82 | torch_L1_in = torch.tensor(L1_in, requires_grad=True, device="cuda") 83 | torch_L2_in = torch.tensor(L2_in, requires_grad=True, device="cuda") 84 | torch_weights = torch.tensor(weights, requires_grad=True, device="cuda") 85 | 86 | torch_out = self.e3nn_tp(torch_L1_in, torch_L2_in, torch_weights) 87 | 88 | torch_L3_grad_in = torch.tensor(L3_grad, device="cuda") 89 | 90 | torch_out.backward(gradient=torch_L3_grad_in) 91 | 92 | L1_grad[:] = torch_L1_in.grad.numpy(force=True) 93 | L2_grad[:] = torch_L2_in.grad.numpy(force=True) 94 | weights_grad[:] = torch_weights.grad.numpy(force=True) 95 
| 96 | @classmethod 97 | def name(cls): 98 | return cls.__name__ 99 | 100 | 101 | class E3NNTensorProductCompiled(E3NNTensorProduct): 102 | def __init__( 103 | self, 104 | config: TPProblem, 105 | torch_compile_kwargs: dict, 106 | torch_op: bool = True, 107 | ): 108 | super().__init__(config, torch_op=torch_op) 109 | self.torch_compile_kwargs = torch_compile_kwargs 110 | 111 | logger.debug("Torch compiling e3nn TP") 112 | logger.debug(msg=f"{torch_compile_kwargs}") 113 | self.e3nn_tp = torch.compile(self.e3nn_tp, **self.torch_compile_kwargs) 114 | logger.debug("e3nn TP torch compiled") 115 | 116 | self.forward = self.e3nn_tp.__call__ 117 | 118 | 119 | class E3NNTensorProductCompiledCUDAGraphs(E3NNTensorProductCompiled): 120 | def __init__(self, config: TPProblem, torch_op=True): 121 | global torch 122 | import torch 123 | 124 | torch._dynamo.config.cache_size_limit = 64 125 | 126 | torch_compile_kwargs = { 127 | "fullgraph": True, 128 | "backend": "inductor", 129 | "options": {"triton.cudagraphs": True}, 130 | } 131 | super().__init__(config, torch_compile_kwargs, torch_op=torch_op) 132 | 133 | 134 | class E3NNTensorProductCompiledMaxAutotuneCUDAGraphs(E3NNTensorProductCompiled): 135 | def __init__(self, config: TPProblem, torch_op=True): 136 | global torch 137 | import torch 138 | 139 | TORCH_COMPILE_AUTOTUNING_DIR.mkdir(exist_ok=True) 140 | os.environ["TORCHINDUCTOR_CACHE_DIR"] = str(TORCH_COMPILE_AUTOTUNING_DIR) 141 | os.environ["TRITON_CACHE_DIR"] = str(TORCH_COMPILE_AUTOTUNING_DIR) 142 | torch._dynamo.config.cache_size_limit = 64 143 | 144 | torch_compile_kwargs = { 145 | "fullgraph": True, 146 | "backend": "inductor", 147 | "options": { 148 | "max_autotune": True, 149 | "triton.cudagraphs": True, 150 | "triton.unique_kernel_names": False, 151 | "coordinate_descent_tuning": False, 152 | }, 153 | } 154 | super().__init__(config, torch_compile_kwargs, torch_op=torch_op) 155 | -------------------------------------------------------------------------------- /openequivariance/implementations/MultiplicityOuterProductTP.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from openequivariance.implementations.utils import calc_weight_offsets 4 | from openequivariance.implementations.e3nn_lite import ( 5 | Irreps, 6 | TPProblem, 7 | Instruction, 8 | ) 9 | from openequivariance.implementations.TensorProductBase import TensorProductBase 10 | from openequivariance.benchmark.logging_utils import getLogger 11 | from jinja2 import Environment, PackageLoader 12 | 13 | from openequivariance.extlib import KernelLaunchConfig, JITTPImpl, DeviceProp 14 | 15 | logger = getLogger() 16 | 17 | 18 | def raise_helper(msg): 19 | raise Exception(msg) 20 | 21 | 22 | def divide(numerator, denominator): 23 | return numerator // denominator 24 | 25 | 26 | def sizeof(dtype): 27 | if dtype in ["float", "int", "unsigned int"]: 28 | return 4 29 | else: 30 | raise Exception("Provided undefined datatype to sizeof!") 31 | 32 | 33 | class MultiplicityOuterProductTP(TensorProductBase): 34 | def __init__(self, config: TPProblem, torch_op: bool = False): 35 | super().__init__(config, torch_op) 36 | 37 | for ins in config.instructions: # type : Instruction 38 | assert isinstance(ins, Instruction) 39 | assert ins.connection_mode == "uvw" 40 | assert ins.path_shape[0] <= 32 41 | assert ins.path_shape[1] <= 32 42 | assert ins.path_shape[2] <= 32 43 | 44 | irreps_in1 = config.irreps_in1 45 | irreps_in2 = config.irreps_in2 46 | irreps_out = config.irreps_out 47 | 48 | # 
================================================================================== 49 | 50 | env = Environment( 51 | loader=PackageLoader("openequivariance"), extensions=["jinja2.ext.do"] 52 | ) 53 | env.globals["raise"] = raise_helper 54 | env.globals["divide"] = divide 55 | env.globals["sizeof"] = sizeof 56 | env.globals["range"] = range 57 | env.globals["enumerate"] = enumerate 58 | env.globals["len"] = len 59 | main_template = env.get_template("subkernel_per_interaction_multirep.cuh") 60 | # forward_subkernel_template = env.get_template("subkernel_forward_thread.cu.jinja2") 61 | # backward_subkernel_template = env.get_template("subkernel_backward_thread.cu.jinja2") 62 | 63 | # ===================================================================== 64 | # Updated to work with TensorProductProblem 65 | 66 | class RepData: 67 | def __init__(self, irreps: Irreps): 68 | assert isinstance(irreps, Irreps) 69 | self.rep_len = irreps.dim 70 | self.irrep_lengths = [mul_irrep.ir.dim for mul_irrep in irreps] 71 | self.mults = [mul_irrep.mul for mul_irrep in irreps] 72 | 73 | offset = 0 74 | self.offsets = [] 75 | for mul_irrep in irreps: 76 | self.offsets.append(offset) 77 | offset += mul_irrep.dim 78 | 79 | # ===================================================================== 80 | # Strictly Copied from Loop Unroll TP 81 | 82 | class CGTensor: 83 | def __init__(self, l1, l2, l3): 84 | tensor = load_cg_tensor(l1, l2, l3) 85 | coord1, coord2, coord3 = [ 86 | arr.astype(np.int32).copy() for arr in np.nonzero(tensor) 87 | ] 88 | float_values = tensor[np.nonzero(tensor)].astype(np.float32).copy() 89 | values = [str(float.hex(float(val))) + "f" for val in float_values] 90 | 91 | self.tuples = [ 92 | (coord1[i], coord2[i], coord3[i], values[i]) 93 | for i in range(len(values)) 94 | ] 95 | # self.tuples.sort(key=lambda tup: (tup[1], tup[0], tup[2])) 96 | self.nnz = len(values) 97 | 98 | # ===================================================================== 99 | # FORWARD MEMORY ANALYSIS 100 | forward_thread_blocks_per_SM = 24 101 | forward_threads_per_thread_block = 32 102 | 103 | # ===================================================================== 104 | dp = DeviceProp(0) 105 | 106 | forward_launch_config = KernelLaunchConfig() 107 | forward_launch_config.num_blocks = ( 108 | dp.multiprocessorCount * forward_thread_blocks_per_SM 109 | ) 110 | forward_launch_config.num_threads = forward_threads_per_thread_block 111 | 112 | # IMPORTANT! 113 | smem_gemm_max_n = forward_threads_per_thread_block 114 | smem_gemm_L3_scratch = smem_gemm_max_n * max( 115 | RepData(config.irreps_out).irrep_lengths 116 | ) # this has space for the largest output size * 32 117 | smem_gemm_weights_scratch = ( 118 | max(RepData(config.irreps_out).mults) * smem_gemm_max_n 119 | ) 120 | 121 | smem_gemm_info = { 122 | "n": smem_gemm_max_n, 123 | "L3_scratch_elems": smem_gemm_L3_scratch, 124 | "weight_scratch_elems": smem_gemm_weights_scratch, 125 | } 126 | logger.debug(smem_gemm_info) 127 | # END OF IMPORTANT 128 | 129 | forward_launch_config.smem = ( 130 | ( 131 | irreps_in1.dim 132 | + irreps_in2.dim 133 | + irreps_out.dim 134 | + smem_gemm_L3_scratch 135 | + smem_gemm_weights_scratch 136 | ) 137 | * sizeof("float") 138 | * forward_launch_config.num_threads 139 | // forward_launch_config.warp_size 140 | ) 141 | 142 | logger.info( 143 | f"Forward pass needs {forward_launch_config.smem} bytes of shared memory." 
        )

        if forward_launch_config.smem > dp.maxSharedMemPerBlock:
            raise Exception(
                f"Error, requested shared memory {forward_launch_config.smem}B hits or exceeds maximum, {dp.maxSharedMemPerBlock}B !"
            )

        # =====================================================================

        backward_launch_config = KernelLaunchConfig()
        backward_launch_config.num_blocks = dp.multiprocessorCount * 1
        backward_launch_config.num_threads = 32
        backward_launch_config.smem = (
            (2 * irreps_in1.dim + 2 * irreps_in2.dim + 2 * irreps_out.dim)
            * sizeof("float")
            * backward_launch_config.num_threads
            // backward_launch_config.warp_size
        )
        logger.info(
            f"Backward pass needs {backward_launch_config.smem} bytes of shared memory."
        )

        if backward_launch_config.smem > dp.maxSharedMemPerBlock:
            raise Exception(
                f"Error, requested shared memory {backward_launch_config.smem}B hits or exceeds maximum, {dp.maxSharedMemPerBlock}B !"
            )

        # =====================================================================

        self.forward_config = forward_launch_config
        self.backward_config = backward_launch_config
        load_cg_tensor = self.load_cg_tensor

        # =====================================================================
        # weights_offsets
        weight_offsets = calc_weight_offsets(config)
        assert isinstance(weight_offsets, list)
        assert len(weight_offsets) == len(list(config.instructions))

        # =====================================================================
        # transform "e3nn instructions" into "interactions"
        instructions: list[Instruction] = config.instructions
        interactions = []
        for ins in instructions:
            u = ins.i_in1
            v = ins.i_in2
            w = ins.i_out
            interaction = (
                u,
                v,
                w,
                CGTensor(irreps_in1[u].ir.l, irreps_in2[v].ir.l, irreps_out[w].ir.l),
            )
            interactions.append(interaction)
        # interactions.sort(key=lambda x: (x[2], x[0], x[1]))

        assert len(interactions) != 0

        # =====================================================================
        kernel_text = main_template.render(
            L1=RepData(config.irreps_in1),
            L2=RepData(config.irreps_in2),
            L3=RepData(config.irreps_out),
            weight_numel=config.weight_numel,
            weight_offsets=weight_offsets,
            instructions=instructions,
            interactions=interactions,
            smem_gemm_info=smem_gemm_info,
            forward_config=forward_launch_config,
            backward_config=backward_launch_config,
        )

        self.jit_kernel = kernel_text

        logger.debug(kernel_text)

        logger.info("Starting NVRTC")
        self.internal = JITTPImpl(
            self.jit_kernel, self.forward_config, self.backward_config
        )
        logger.info("Kernel compiled!")

        if self.torch_op:
            self.setup_torch_custom_op()

    @staticmethod
    def name():
        return "MultiplicityOuterProductTP"
--------------------------------------------------------------------------------
/openequivariance/implementations/TensorProduct.py:
--------------------------------------------------------------------------------
from openequivariance.implementations.LoopUnrollTP import LoopUnrollTP
from openequivariance import TPProblem
import torch


class TensorProduct(torch.nn.Module, LoopUnrollTP):
    """
    Drop-in replacement for ``o3.TensorProduct`` from e3nn. Supports forward,
    backward, and double-backward passes using JIT-compiled kernels. Initialization
    fails if:

    * There are no visible GPUs.
    * The provided tensor product specification is unsupported.

    :param problem: Specification of the tensor product.
    """

    def __init__(self, problem: TPProblem, torch_op=True):
        torch.nn.Module.__init__(self)
        LoopUnrollTP.__init__(self, problem, torch_op)
        self.weight_numel = problem.weight_numel

    @staticmethod
    def name():
        return LoopUnrollTP.name()

    def forward(
        self, x: torch.Tensor, y: torch.Tensor, W: torch.Tensor
    ) -> torch.Tensor:
        """
        Computes :math:`W (x \otimes_{\\textrm{CG}} y)`, identical to
        ``o3.TensorProduct.forward``.

        :param x: Tensor of shape ``[batch_size, problem.irreps_in1.dim]``, datatype
            ``problem.irrep_dtype``.
        :param y: Tensor of shape ``[batch_size, problem.irreps_in2.dim]``, datatype
            ``problem.irrep_dtype``.
        :param W: Tensor of datatype ``problem.weight_dtype`` and shape

            * ``[batch_size, problem.weight_numel]`` if ``problem.shared_weights=False``
            * ``[problem.weight_numel]`` if ``problem.shared_weights=True``

        :return: Tensor of shape ``[batch_size, problem.irreps_out.dim]``, datatype ``problem.irrep_dtype``.
        """
        return torch.ops.libtorch_tp_jit.jit_tp_forward(self.internal, x, y, W)
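# A minimal usage sketch (kept as a comment; the problem definition and
# shapes below are illustrative assumptions, mirroring the test suite):
#
#   import torch
#   import openequivariance as oeq
#   from openequivariance.implementations.TensorProduct import TensorProduct
#
#   problem = oeq.TPProblem(
#       "32x1e", "1x1e", "32x1e",
#       [(0, 0, 0, "uvu", True)],
#       shared_weights=False,
#       internal_weights=False,
#   )
#   tp = TensorProduct(problem)
#   x = torch.randn(1000, problem.irreps_in1.dim, device="cuda")
#   y = torch.randn(1000, problem.irreps_in2.dim, device="cuda")
#   W = torch.randn(1000, problem.weight_numel, device="cuda")
#   z = tp(x, y, W)  # shape [1000, problem.irreps_out.dim]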
--------------------------------------------------------------------------------
/openequivariance/implementations/convolution/CUEConv.py:
--------------------------------------------------------------------------------
import numpy as np
import itertools
from typing import Iterator

from openequivariance.implementations.CUETensorProduct import CUETensorProduct
from openequivariance.implementations.convolution.ConvolutionBase import ConvolutionBase


class CUEConv(ConvolutionBase):
    def __init__(self, config, idx_dtype=np.int64, torch_op=True):
        super().__init__(config, idx_dtype, torch_op)

        global torch
        import torch

        self.reference_tp = CUETensorProduct(config, torch_op)
        self.cue_tp = self.reference_tp.cue_tp

        from openequivariance.implementations.convolution.scatter import scatter_sum

        self.scatter_sum = scatter_sum

    def forward(self, L1_in, L2_in, weights, rows, cols):
        tp_outputs = self.cue_tp(L1_in[cols], L2_in, weights)
        return self.scatter_sum(
            src=tp_outputs, index=rows, dim=0, dim_size=L1_in.shape[0]
        )

    @staticmethod
    def name():
        return "CUEConvolution"
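# The gather/compute/scatter pattern above is the unfused baseline: node
# features are gathered per edge with `L1_in[cols]`, a batched tensor product
# runs on the gathered rows, and `scatter_sum` accumulates the edge outputs
# back onto destination nodes indexed by `rows`. The fused variant below
# instead hands the index tensors directly to the cuEquivariance kernel.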
class CUEConvFused(ConvolutionBase):
    def __init__(self, config, idx_dtype=np.int64, torch_op=True):
        super().__init__(config, idx_dtype, torch_op)

        global torch
        import torch
        import e3nn.o3 as o3

        np_to_torch_dtype = {np.float32: torch.float32, np.float64: torch.float64}

        import cuequivariance as cue
        from cuequivariance_torch.primitives.tensor_product import (
            TensorProductUniform4x1dIndexed,
        )

        class O3_e3nn(cue.O3):
            def __mul__(  # pylint: disable=no-self-argument
                rep1: "O3_e3nn", rep2: "O3_e3nn"
            ) -> Iterator["O3_e3nn"]:
                return [O3_e3nn(l=ir.l, p=ir.p) for ir in cue.O3.__mul__(rep1, rep2)]

            @classmethod
            def clebsch_gordan(
                cls, rep1: "O3_e3nn", rep2: "O3_e3nn", rep3: "O3_e3nn"
            ) -> np.ndarray:
                rep1, rep2, rep3 = cls._from(rep1), cls._from(rep2), cls._from(rep3)

                if rep1.p * rep2.p == rep3.p:
                    return o3.wigner_3j(rep1.l, rep2.l, rep3.l).numpy()[None] * np.sqrt(
                        rep3.dim
                    )
                return np.zeros((0, rep1.dim, rep2.dim, rep3.dim))

            def __lt__(  # pylint: disable=no-self-argument
                rep1: "O3_e3nn", rep2: "O3_e3nn"
            ) -> bool:
                rep2 = rep1._from(rep2)
                return (rep1.l, rep1.p) < (rep2.l, rep2.p)

            @classmethod
            def iterator(cls) -> Iterator["O3_e3nn"]:
                for l in itertools.count(0):
                    yield O3_e3nn(l=l, p=1 * (-1) ** l)
                    yield O3_e3nn(l=l, p=-1 * (-1) ** l)

        descriptor = (
            cue.descriptors.channelwise_tensor_product(
                cue.Irreps(O3_e3nn, str(config.irreps_in1)),
                cue.Irreps(O3_e3nn, str(config.irreps_in2)),
                cue.Irreps(O3_e3nn, str(config.irreps_out)),
            )
            .squeeze_modes()
            .flatten_coefficient_modes()
        )

        self.tp = TensorProductUniform4x1dIndexed(
            descriptor.polynomial.operations[0][1],
            "cuda",
            math_dtype=np_to_torch_dtype[config.irrep_dtype],
        )

    def forward(self, L1_in, L2_in, weights, rows, cols):
        return self.tp(weights, L1_in, L2_in, None, rows, None, cols, L1_in.shape[0])

    @staticmethod
    def name():
        return "CUEConvolutionFused"
--------------------------------------------------------------------------------
/openequivariance/implementations/convolution/E3NNConv.py:
--------------------------------------------------------------------------------
import numpy as np

from openequivariance.implementations.convolution.ConvolutionBase import ConvolutionBase
from openequivariance.implementations.E3NNTensorProduct import E3NNTensorProduct


class E3NNConv(ConvolutionBase):
    def __init__(self, config, idx_dtype=np.int64, torch_op=True):
        assert torch_op
        super().__init__(config, idx_dtype, torch_op)

        from e3nn import o3
        import torch

        if config.irrep_dtype == np.float64:
            torch.set_default_dtype(torch.float64)

        self.e3nn_tp = o3.TensorProduct(
            config.irreps_in1,
            config.irreps_in2,
            config.irreps_out,
            config.instructions_raw,
            in1_var=config.in1_var,
            in2_var=config.in2_var,
            out_var=config.out_var,
            irrep_normalization=config.irrep_normalization,
            path_normalization=config.path_normalization,
            internal_weights=config.internal_weights,
            shared_weights=config.shared_weights,
        ).to(device="cuda")

        self.reference_tp = E3NNTensorProduct(config)

        if config.irrep_dtype == np.float64:
            torch.set_default_dtype(torch.float32)  # Reset to default

        from openequivariance.implementations.convolution.scatter import scatter_sum

        self.scatter_sum = scatter_sum

    def forward(self, L1_in, L2_in, weights, rows, cols):
        tp_outputs = self.reference_tp(L1_in[cols], L2_in, weights)
        return self.scatter_sum(
            src=tp_outputs, index=rows, dim=0, dim_size=L1_in.shape[0]
        )

    @staticmethod
    def name():
        return "E3NNConvolution"
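    # The two *_cpu methods below are NumPy reference paths (apparently used
    # by the correctness checks): np.add.at performs the scatter-accumulation
    # deterministically on the host, mirroring what the GPU convolution
    # kernels do with atomics.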
    def forward_cpu(self, L1_in, L2_in, weights, L3_out, graph):
        tp_outputs = np.zeros((graph.nnz, self.L3.dim), dtype=L3_out.dtype)
        self.reference_tp.forward_cpu(L1_in[graph.cols], L2_in, tp_outputs, weights)
        np.add.at(L3_out, graph.rows, tp_outputs)

    def backward_cpu(
        self,
        L1_in: np.ndarray,
        L1_grad: np.ndarray,
        L2_in: np.ndarray,
        L2_grad: np.ndarray,
        L3_grad: np.ndarray,
        weights: np.ndarray,
        weights_grad: np.ndarray,
        graph,
    ):
        L1_grad_bcast = np.zeros((graph.nnz, self.L1.dim), dtype=L1_grad.dtype)
        self.reference_tp.backward_cpu(
            L1_in[graph.cols],
            L1_grad_bcast,
            L2_in,
            L2_grad,
            L3_grad[graph.rows],
            weights,
            weights_grad,
        )
        np.add.at(L1_grad, graph.cols, L1_grad_bcast)
--------------------------------------------------------------------------------
/openequivariance/implementations/convolution/TensorProductConv.py:
--------------------------------------------------------------------------------
from typing import Optional
import types

import numpy as np
import torch

from openequivariance import extlib
from openequivariance.implementations.convolution.ConvolutionBase import ConvolutionBase
from openequivariance.implementations.convolution.LoopUnrollConv import LoopUnrollConv
from openequivariance.implementations.TensorProduct import TensorProduct
from openequivariance import TPProblem


class TensorProductConv(torch.nn.Module, LoopUnrollConv):
    """
    Given a **symmetric, directed** graph :math:`G = (V, E)`, inputs :math:`x_1...x_{|V|}`,
    :math:`y_1...y_{|E|}`, and weights :math:`W_1...W_{|E|}`, computes

    .. math::

        z_i = \sum_{(i, j, e) \in \mathcal{N}(i)} W_e (x_j \otimes_{\\textrm{CG}} y_e)

    where :math:`(i, j, e) \in \mathcal{N}(i)` indicates that node :math:`i` is connected to node :math:`j`
    via the edge indexed :math:`e`.

    This class offers multiple options to perform the summation: an atomic algorithm and a deterministic algorithm
    that relies on a sorted adjacency matrix input. If you use the deterministic algorithm, you must also supply
    a permutation to transpose the adjacency matrix.

    :param problem: Specification of the tensor product.
    :param deterministic: if ``False``, uses atomics for the convolution. If ``True``, uses a deterministic
        fixup-based algorithm. *Default*: ``False``.
    :param kahan: if ``True``, uses Kahan summation to improve accuracy during aggregation. To use this option,
        the input tensors must be in float32 precision AND you must set ``deterministic=True``. *Default*: ``False``.

    """

    def __init__(
        self,
        problem: TPProblem,
        deterministic: bool = False,
        kahan: bool = False,
        torch_op=True,
    ):
        torch.nn.Module.__init__(self)
        LoopUnrollConv.__init__(
            self,
            problem,
            idx_dtype=np.int64,
            torch_op=torch_op,
            deterministic=deterministic,
            kahan=kahan,
        )

        self.dummy_transpose_perm = torch.zeros(1, dtype=torch.int64, device="cuda")
        self.weight_numel = self.config.weight_numel

        if not extlib.TORCH_COMPILE:
            self.forward = types.MethodType(LoopUnrollConv.forward, self)

    def forward(
        self,
        X: torch.Tensor,
        Y: torch.Tensor,
        W: torch.Tensor,
        rows: torch.Tensor,
        cols: torch.Tensor,
        sender_perm: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Computes the fused CG tensor product + convolution.

        :param X: Tensor of shape ``[|V|, problem.irreps_in1.dim]``, datatype ``problem.irrep_dtype``.
        :param Y: Tensor of shape ``[|E|, problem.irreps_in2.dim]``, datatype ``problem.irrep_dtype``.
        :param W: Tensor of datatype ``problem.weight_dtype`` and shape

            * ``[|E|, problem.weight_numel]`` if ``problem.shared_weights=False``
            * ``[problem.weight_numel]`` if ``problem.shared_weights=True``

        :param rows: Tensor of shape ``[|E|]`` with row indices for each nonzero in the adjacency matrix,
            datatype ``torch.int64``. Must be row-major sorted along with ``cols`` when ``deterministic=True``.
        :param cols: Tensor of shape ``[|E|]`` with column indices for each nonzero in the adjacency matrix,
            datatype ``torch.int64``.
        :param sender_perm: Tensor of shape ``[|E|]`` and ``torch.int64`` datatype containing a
            permutation that transposes the adjacency matrix nonzeros from row-major to column-major order.
            Must be provided when ``deterministic=True``.

        :return: Tensor of shape ``[|V|, problem.irreps_out.dim]``, datatype ``problem.irrep_dtype``.
        """
        if sender_perm is None:
            return torch.ops.libtorch_tp_jit.jit_conv_forward(
                self.internal,
                X,
                Y,
                W,
                rows,
                cols,
                self.workspace_buffer,
                self.dummy_transpose_perm,
            )
        else:
            return torch.ops.libtorch_tp_jit.jit_conv_forward(
                self.internal,
                X,
                Y,
                W,
                rows,
                cols,
                self.workspace_buffer,
                sender_perm,
            )

    @staticmethod
    def name():
        return LoopUnrollConv.name()
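# A minimal usage sketch (comment only; the graph below is an illustrative
# assumption -- two edges on three nodes, COO indices sorted row-major):
#
#   import torch
#   import openequivariance as oeq
#
#   problem = oeq.TPProblem(
#       "32x1e", "1x1e", "32x1e",
#       [(0, 0, 0, "uvu", True)],
#       shared_weights=False,
#       internal_weights=False,
#   )
#   conv = TensorProductConv(problem, deterministic=False)
#   X = torch.randn(3, problem.irreps_in1.dim, device="cuda")
#   Y = torch.randn(2, problem.irreps_in2.dim, device="cuda")
#   W = torch.randn(2, problem.weight_numel, device="cuda")
#   rows = torch.tensor([0, 1], dtype=torch.int64, device="cuda")
#   cols = torch.tensor([1, 2], dtype=torch.int64, device="cuda")
#   Z = conv(X, Y, W, rows, cols)  # shape [3, problem.irreps_out.dim]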
# ==================================================================
# Reference implementations for benchmarking

# Note: TensorProductConv pins idx_dtype to int64 internally, so the
# idx_dtype argument below is accepted for API symmetry but not forwarded.


class TensorProductConvKahan(TensorProductConv):
    def __init__(self, config, idx_dtype=np.int64, torch_op=True):
        super().__init__(config, deterministic=True, kahan=True, torch_op=torch_op)

    @staticmethod
    def name():
        return "LoopUnrollConvKahan"


class TensorProductConvDeterministic(TensorProductConv):
    def __init__(self, config, idx_dtype=np.int64, torch_op=True):
        super().__init__(config, deterministic=True, torch_op=torch_op)

    @staticmethod
    def name():
        return "LoopUnrollConvDeterministic"


class TensorProductConvAtomic(TensorProductConv):
    def __init__(self, config, idx_dtype=np.int64, torch_op=True):
        super().__init__(config, deterministic=False, torch_op=torch_op)

    @staticmethod
    def name():
        return "LoopUnrollConvAtomic"


class TensorProductConvScatterSum(ConvolutionBase):
    def __init__(self, config, idx_dtype=np.int64, torch_op=True):
        assert torch_op
        global torch
        import torch

        super().__init__(config, idx_dtype, torch_op=torch_op, deterministic=False)

        self.reference_tp = TensorProduct(config, torch_op=torch_op)
        from openequivariance.implementations.convolution.scatter import scatter_sum

        self.scatter_sum = scatter_sum

    def forward(self, L1_in, L2_in, weights, rows, cols):
        tp_outputs = self.reference_tp(L1_in[cols], L2_in, weights)
        return self.scatter_sum(
            src=tp_outputs, index=rows, dim=0, dim_size=L1_in.shape[0]
        )

    def forward_cpu(self, L1_in, L2_in, weights, L3_out, graph):
        tp_outputs = np.zeros((graph.nnz, self.L3.dim), dtype=L3_out.dtype)
        self.reference_tp.forward_cpu(L1_in[graph.cols], L2_in, tp_outputs, weights)
        np.add.at(L3_out, graph.rows, tp_outputs)

    def backward_cpu(
        self,
        L1_in: np.ndarray,
        L1_grad: np.ndarray,
L2_in: np.ndarray, 178 | L2_grad: np.ndarray, 179 | L3_grad: np.ndarray, 180 | weights: np.ndarray, 181 | weights_grad: np.ndarray, 182 | graph, 183 | ): 184 | L1_grad_bcast = np.zeros((graph.nnz, self.L1.dim), dtype=L1_grad.dtype) 185 | self.reference_tp.backward_cpu( 186 | L1_in[graph.cols], 187 | L1_grad_bcast, 188 | L2_in, 189 | L2_grad, 190 | L3_grad[graph.rows], 191 | weights, 192 | weights_grad, 193 | ) 194 | np.add.at(L1_grad, graph.cols, L1_grad_bcast) 195 | 196 | @staticmethod 197 | def name(): 198 | return "LoopUnrollConvScatterSum" 199 | -------------------------------------------------------------------------------- /openequivariance/implementations/convolution/scatter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional 3 | 4 | """ 5 | Scatter sum operator from MACE. 6 | 7 | basic scatter_sum operations from torch_scatter from 8 | https://github.com/mir-group/pytorch_runstats/blob/main/torch_runstats/scatter_sum.py 9 | Using code from https://github.com/rusty1s/pytorch_scatter, but cut down to avoid a dependency. 10 | """ 11 | 12 | 13 | def _broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): 14 | if dim < 0: 15 | dim = other.dim() + dim 16 | if src.dim() == 1: 17 | for _ in range(0, dim): 18 | src = src.unsqueeze(0) 19 | for _ in range(src.dim(), other.dim()): 20 | src = src.unsqueeze(-1) 21 | src = src.expand_as(other) 22 | return src 23 | 24 | 25 | def scatter_sum( 26 | src: torch.Tensor, 27 | index: torch.Tensor, 28 | dim: int = -1, 29 | out: Optional[torch.Tensor] = None, 30 | dim_size: Optional[int] = None, 31 | reduce: str = "sum", 32 | ) -> torch.Tensor: 33 | assert reduce == "sum" # for now, TODO 34 | index = _broadcast(index, src, dim) 35 | if out is None: 36 | size = list(src.size()) 37 | if dim_size is not None: 38 | size[dim] = dim_size 39 | elif index.numel() == 0: 40 | size[dim] = 0 41 | else: 42 | size[dim] = int(index.max()) + 1 43 | out = torch.zeros(size, dtype=src.dtype, device=src.device) 44 | return out.scatter_add_(dim, index, src) 45 | else: 46 | return out.scatter_add_(dim, index, src) 47 | -------------------------------------------------------------------------------- /openequivariance/implementations/symmetric_contraction/__init__.py: -------------------------------------------------------------------------------- 1 | from openequivariance.implementations.symmetric_contraction.symmetric_contraction import ( 2 | SymmetricContraction, 3 | ) 4 | 5 | __all__ = ["SymmetricContraction"] 6 | -------------------------------------------------------------------------------- /openequivariance/implementations/utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import math 3 | 4 | import numpy as np 5 | 6 | from openequivariance.implementations.TensorProductBase import TensorProductBase 7 | from openequivariance.implementations.e3nn_lite import ( 8 | Irreps, 9 | Instruction, 10 | TPProblem, 11 | ) 12 | 13 | 14 | def sparse_outer_product_work(cg: np.ndarray) -> int: 15 | return np.sum(np.max(cg != 0, axis=2)) 16 | 17 | 18 | def convenience_namer(L1: Irreps, L2: Irreps, L3: Irreps): 19 | return f"({L1}x{L2}->{L3})" 20 | 21 | 22 | # Non Zeros 23 | @functools.lru_cache(typed=True) 24 | def count_cg_non_zero(l1, l2, l3) -> int: 25 | return np.count_nonzero(TensorProductBase.load_cg_tensor(l1, l2, l3)) 26 | 27 | 28 | def calculate_total_nnz(tpp: TPProblem) -> int: 29 | """ 30 | To make sure you don't over count 
    repeated CG tensors that are reused by multiple instructions.
    """
    nnz_by_l_combo = {}
    for ins in tpp.instructions:  # type : Instruction
        l1 = tpp.irreps_in1[ins.i_in1].ir.l
        l2 = tpp.irreps_in2[ins.i_in2].ir.l
        l3 = tpp.irreps_out[ins.i_out].ir.l
        assert isinstance(l1, int)
        assert isinstance(l2, int)
        assert isinstance(l3, int)
        nnz_by_l_combo[(l1, l2, l3)] = count_cg_non_zero(l1, l2, l3)
    return sum(nnz_by_l_combo.values())


def calc_weight_offsets(tpp: TPProblem) -> list[int]:
    """
    Returns a list of weight offsets for every instruction.
    """
    assert isinstance(tpp, TPProblem)
    offset = 0
    offsets = []
    for ins in tpp.instructions:
        assert isinstance(ins, Instruction)
        offsets.append(offset)
        if ins.has_weight:
            flatsize = math.prod(ins.path_shape)
            offset += flatsize
    return offsets


def filter_and_analyze_problem(problem):
    """
    Centralized function that rejects unhandled problem configurations and
    returns a dictionary of useful information about the problem.
    """
    for i, inst in enumerate(problem.instructions):
        assert inst.connection_mode == problem.instructions[0].connection_mode, (
            f"All instructions must have the same connection mode, got {inst.connection_mode} and {problem.instructions[0].connection_mode}"
        )

        assert inst.has_weight, (
            f"All instructions must have trainable weights, got {inst.has_weight} at index {i}"
        )

    assert problem.instructions[0].connection_mode in ["uvu", "uvw"], (
        f"Connection mode must be 'uvu' or 'uvw', got {problem.instructions[0].connection_mode}"
    )

    assert problem.irrep_dtype == problem.weight_dtype, (
        f"irrep_dtype and weight_dtype must be the same, got {problem.irrep_dtype} and {problem.weight_dtype}"
    )

    assert not problem.internal_weights, (
        f"OpenEquivariance does not support internal weights, got {problem.internal_weights}"
    )

    assert len(problem.instructions) > 0, "Tensor product has no valid instructions!"

    result = {
        "is_uvw": problem.instructions[0].connection_mode == "uvw",
    }
    return result


def torch_to_oeq_dtype(torch_dtype) -> type[np.generic]:
    """
    Convenience function; converts a torch datatype to the corresponding
    numpy datatype for use in TPProblem.

    :param torch_dtype: torch datatype (e.g., torch.float32, torch.float64)
    :return: numpy datatype (e.g., np.float32, np.float64)
    """

    global torch
    import torch

    if torch_dtype == torch.float32:
        return np.float32
    elif torch_dtype == torch.float64:
        return np.float64
    else:
        raise ValueError("Unsupported torch dtype!")
--------------------------------------------------------------------------------
/openequivariance/templates/common.cuh:
--------------------------------------------------------------------------------
#define ROW_OPERATION(ROW_LEN, LOOP_VAR, ...) 
\ 2 | _Pragma ("unroll") \ 3 | for(int LOOP_VAR = 0; LOOP_VAR < ROW_LEN; LOOP_VAR += THREADS_PER_WARP) { \ 4 | if(LOOP_VAR >= ROW_LEN - THREADS_PER_WARP) { \ 5 | if(lane_id < ROW_LEN - LOOP_VAR) { \ 6 | __VA_ARGS__ \ 7 | } \ 8 | } \ 9 | else { \ 10 | __VA_ARGS__ \ 11 | } \ 12 | } 13 | -------------------------------------------------------------------------------- /openequivariance/templates/jinja_utils.py: -------------------------------------------------------------------------------- 1 | from jinja2 import Environment, PackageLoader 2 | 3 | 4 | def raise_helper(msg): 5 | raise Exception(msg) 6 | 7 | 8 | def divide(numerator, denominator): 9 | return numerator // denominator 10 | 11 | 12 | def sizeof(dtype): 13 | if dtype in ["float", "int", "unsigned int"]: 14 | return 4 15 | else: 16 | raise Exception("Provided undefined datatype to sizeof!") 17 | 18 | 19 | def get_jinja_environment(): 20 | env = Environment( 21 | loader=PackageLoader("openequivariance"), extensions=["jinja2.ext.do"] 22 | ) 23 | env.globals["raise"] = raise_helper 24 | env.globals["divide"] = divide 25 | env.globals["sizeof"] = sizeof 26 | env.globals["enumerate"] = enumerate 27 | return env 28 | -------------------------------------------------------------------------------- /openequivariance/templates/macros.jinja: -------------------------------------------------------------------------------- 1 | {# 2 | First input argument consists of a dictionary with keys _common_ and _per_warp_. 3 | Keys map to lists of tuples with (name, dtype, num_elements) of each subarray. 4 | #} 5 | {%- macro declare_smem_arrays(arrays, warp_loc_var, config) %} 6 | {%- set warps_per_block = divide(config.num_threads, config.warp_size) %} 7 | extern __shared__ char s[]; 8 | {%- set ns = {"offset": 0, "total_warp_bytes": 0} %} 9 | {%- for name, dtype, num_elements in arrays["common"] %} 10 | {{dtype}}* {{name}} = ({{dtype}}*) (s + {{ ns["offset"] }}); 11 | {%- do ns.update({"offset": ns["offset"] + num_elements * sizeof(dtype)}) %} 12 | {%- if ns["offset"] > config.smem %} 13 | {{ raise("Error, required shared memory exceeds allocation maximum!") }} 14 | {%- endif %} 15 | {%- endfor %} 16 | 17 | {%- for name, dtype, num_elements in arrays["per_warp"] %} 18 | {% do ns.update({"total_warp_bytes": ns["total_warp_bytes"] + num_elements * sizeof(dtype)}) %} 19 | {%- endfor %} 20 | 21 | {%- if ns["offset"] + ns["total_warp_bytes"] * warps_per_block > config.smem %} 22 | {{ raise("Error, required shared memory exceeds allocation maximum!") }} 23 | {%- endif %} 24 | 25 | char* per_warp_smem = s + {{ns["offset"]}} + {{ns["total_warp_bytes"]}} * {{ warp_loc_var }}; 26 | 27 | {%- do ns.update({"offset": 0}) %} 28 | {%- for name, dtype, num_elements in arrays["per_warp"] %} 29 | {{dtype}}* {{name}} = ({{dtype}}*) (per_warp_smem + {{ ns["offset"] }}); 30 | {% do ns.update({"offset": ns["offset"] + num_elements * sizeof(dtype)}) %} 31 | {%- endfor %} 32 | {%- endmacro %} 33 | 34 | {# smem contains a mul_ir stored in row-major order as mul * rep, where mul 35 | is at most |warp_size|. reg is at least a |rep|-sized register array on each thread. 36 | Assumes: each thread has the lane_id. 
#} 37 | {%- macro transpose_load(mul, dim, smem, offset, reg) %} 38 | if(lane_id < {{mul}}) { 39 | {%- for i in range(dim) %} 40 | {{reg}}[{{i}}] = {{smem}}[{{offset}} + lane_id * {{dim}} + {{i}}]; 41 | {%- endfor %} 42 | } 43 | {%- endmacro %} 44 | 45 | {%- macro transpose_store(mul, dim, smem, offset, reg, op, coeff) %} 46 | if(lane_id < {{mul}}) { 47 | {%- for i in range(dim) %} 48 | {{smem}}[{{offset}} + lane_id * {{dim}} + {{i}}] {{op}} {{reg}}[{{i}}] * {{coeff}}; 49 | {%- endfor %} 50 | } 51 | {%- endmacro %} 52 | 53 | {%- macro declare_smem_variables(segment, smem_base) %} 54 | {%- for name in segment.smem %} 55 | {%- if name != "total" %} 56 | {%- set smem_rng = segment.smem[name] %} 57 | {{ smem_rng["dtype"] }}* {{name}}_smem = ({{smem_rng["dtype"]}}*) ({{smem_base}} + {{smem_rng["offset"]}}); 58 | {%- endif %} 59 | {%- endfor %} 60 | {%- endmacro %} 61 | 62 | {%- macro load_ir_segments(map, glb_ptr_shft, smem_ptr, loop_var) %} 63 | {%- if not map.persist_load %} 64 | {%- for (src_rng, dst_rng) in map.copy_ranges %} 65 | {%- set range_len = src_rng.stop - src_rng.start %} 66 | ROW_OPERATION({{range_len}}, {{loop_var}}, {{smem_ptr}}[{{loop_var}} + {{dst_rng.start}} + lane_id] = {{glb_ptr_shft}}[{{loop_var}} + {{src_rng.start}}];) 67 | {%- endfor %} 68 | {%- endif %} 69 | {%- endmacro %} 70 | 71 | {%- macro load_ir_segments_force(map, glb_ptr_shft, smem_ptr, loop_var) %} 72 | {%- for (src_rng, dst_rng) in map.copy_ranges %} 73 | {%- set range_len = src_rng.stop - src_rng.start %} 74 | ROW_OPERATION({{range_len}}, {{loop_var}}, {{smem_ptr}}[{{loop_var}} + {{dst_rng.start}} + lane_id] = {{glb_ptr_shft}}[{{loop_var}} + {{src_rng.start}}];) 75 | {%- endfor %} 76 | {%- endmacro %} 77 | 78 | {%- macro store_ir_segments(map, glb_ptr_shft, smem_ptr, loop_var) %} 79 | {%- if not map.persist_store %} 80 | {%- for i, src_rng in enumerate(map.original_src_ranges) %} 81 | {%- set idx = map.idxs[i] %} 82 | {%- set dst_rng = map.original_dst_ranges[i] %} 83 | {%- set range_len = src_rng.stop - src_rng.start %} 84 | {%- if map.storeback_procedure[idx] == "write" %} 85 | ROW_OPERATION({{range_len}}, {{loop_var}}, {{glb_ptr_shft}}[{{loop_var}} + {{src_rng.start}}] = {{smem_ptr}}[{{loop_var}} + {{dst_rng.start}} + lane_id];) 86 | {%- elif map.storeback_procedure[idx] == "accumulate" %} 87 | ROW_OPERATION({{range_len}}, {{loop_var}}, {{glb_ptr_shft}}[{{loop_var}} + {{src_rng.start}}] += {{smem_ptr}}[{{loop_var}} + {{dst_rng.start}} + lane_id];) 88 | {%- elif map.storeback_procedure[idx] == "atomic_accumulate" %} 89 | ROW_OPERATION({{range_len}}, {{loop_var}}, atomicAdd({{glb_ptr_shft}} + {{src_rng.start}} + {{loop_var}}, {{smem_ptr}}[{{dst_rng.start}} + lane_id + {{loop_var}}]);) 90 | {%- endif %} 91 | {%- endfor %} 92 | {% endif %} 93 | {%- endmacro %} 94 | 95 | {%- macro set_launch_bound_variables(config) %} 96 | {%- set threads_per_warp = config.warp_size %} 97 | {%- set warps_per_block = divide(config.num_threads, config.warp_size) %} 98 | int t_idx = blockIdx.x * blockDim.x + threadIdx.x; 99 | int warp_id = t_idx / {{ threads_per_warp }}; 100 | int lane_id = t_idx % {{ threads_per_warp }}; 101 | int warp_loc = warp_id % {{ warps_per_block }}; 102 | size_t warps_launched = blockDim.x * gridDim.x / {{ threads_per_warp }}; 103 | size_t nnz_per_warp = (num_products + warps_launched - 1) / warps_launched; 104 | 105 | size_t start = nnz_per_warp * ((size_t) warp_id); 106 | size_t end = min(start + nnz_per_warp, num_products); 107 | {%- endmacro %} 108 | 109 | {%- macro transpose_smem_A(irreps, smem_ptr) 
%}
{%- set slices = irreps.slices() %}
{%- for i, mul_ir in enumerate(irreps) %} {
    {%- set dim = mul_ir.ir.dim %}
    {%- set mul = mul_ir.mul %}
    IRREP_T t_regs[{{dim}}];
    if(lane_id < {{mul}}) {
        {%- set offset = slices[i].start %}
        {%- for i in range(dim) %}
        t_regs[{{i}}] = {{smem_ptr}}[{{offset}} + lane_id * {{dim}} + {{i}}];
        {%- endfor %}
        __syncwarp();
        {%- for i in range(dim) %}
        {{smem_ptr}}[{{offset}} + lane_id + {{i * mul}}] = t_regs[{{i}}];
        {%- endfor %}
    }
} {%- endfor %}
{%- endmacro %}

{%- macro transpose_smem_B(irreps, smem_ptr) %}
{%- set slices = irreps.slices() %}
{%- for i, mul_ir in enumerate(irreps) %} {
    {%- set dim = mul_ir.ir.dim %}
    {%- set mul = mul_ir.mul %}
    IRREP_T t_regs[{{dim}}];
    if(lane_id < {{mul}}) {
        {%- set offset = slices[i].start %}
        {%- for i in range(dim) %}
        t_regs[{{i}}] = {{smem_ptr}}[{{offset}} + lane_id + {{i * mul}}];
        {%- endfor %}
        __syncwarp();
        {%- for i in range(dim) %}
        {{smem_ptr}}[{{offset}} + lane_id * {{dim}} + {{i}}] = t_regs[{{i}}];
        {%- endfor %}
    }
} {%- endfor %}
{%- endmacro %}

{%- macro reg_load(mul, dim, smem, offset, reg) %}
if(lane_id < {{mul}}) {
    {%- for i in range(dim) %}
    {{reg}}[{{i}}] = {{smem}}[{{offset}} + lane_id + {{i * mul}}];
    {%- endfor %}
}
{%- endmacro %}

{%- macro reg_store(mul, dim, smem, offset, reg, op, coeff) %}
if(lane_id < {{mul}}) {
    {%- for i in range(dim) %}
    {{smem}}[{{offset}} + lane_id + {{i * mul}}] {{op}} {{reg}}[{{i}}] * {{coeff}};
    {%- endfor %}
}
{%- endmacro %}

{%- macro launch_bounds(schedule) %}
__launch_bounds__({{schedule.launch_config.num_threads}})
{%- endmacro %}
--------------------------------------------------------------------------------
/openequivariance/templates/wmm.cuh:
--------------------------------------------------------------------------------
{%- macro generate_matmul(name, M, N, K, TILES_PER_ROW, OUTPUT_RMAJOR, warp_size, A_CMAJOR=True, B_RMAJOR=True, accum=True) %}

{%- set TILES_PER_COL = warp_size // TILES_PER_ROW %}

template<typename T>
__device__ __forceinline__ void {{name}}(const T* __restrict__ A, const T* __restrict__ B, T* C) {
    int t_idx = threadIdx.x + blockIdx.x * blockDim.x;
    int lane_id = t_idx % {{warp_size}};

    int const rpt = {{(M + TILES_PER_COL - 1) // TILES_PER_COL}};
    int const cpt = {{(N + TILES_PER_ROW - 1) // TILES_PER_ROW}};

    T row[cpt];
    T col[rpt];
    T tile[rpt][cpt];

    int TI_idx = lane_id / {{TILES_PER_ROW}};
    int TJ_idx = lane_id % {{TILES_PER_ROW}};
    int is = TI_idx * rpt; int ie = (TI_idx + 1) * rpt;
    int js = TJ_idx * cpt; int je = (TJ_idx + 1) * cpt;
    int ist = min(is, {{M}}); int iet = min(ie, {{M}});
    int jst = min(js, {{N}}); int jet = min(je, {{N}});

    // Zero the output tile
    #pragma unroll
    for(int i = 0; i < rpt; i++) {
        #pragma unroll
        for(int j = 0; j < cpt; j++) {
            tile[i][j] = 0.0f;
        }
    }

    for(int k = 0; k < {{K}}; k++) {
        #pragma unroll
        for(int i = 0; i < rpt; i++) {
            if(ist + i < {{M}}) {
                {%- if A_CMAJOR %}
                col[i] = A[k * {{M}} + ist + i];
                {%- else %}
                col[i] = A[(ist + i) * {{K}} + k];
                {%- endif %}
            }
        }

        #pragma unroll
        for(int j = 0; j < cpt; j++) {
            if(jst + j < {{N}}) {
                {%- if B_RMAJOR %}
                row[j] = B[k * {{N}} + jst + j];
                {%- else %}
                row[j] = B[j * {{K}} + k];
                {%- endif %}
            }
        }

        #pragma unroll
        for(int i = 0; i < rpt; i++) {
            #pragma unroll
            for(int j = 0; j < cpt; j++) {
                if(ist + i < {{M}} && jst + j < {{N}}) {
                    tile[i][j] += col[i] * row[j];
                }
            }
        }
    }

    {%- if accum %}
    {%- set op = "+=" %}
    {%- else %}
    {%- set op = "=" %}
    {%- endif %}

    // Store the output
    #pragma unroll
    for(int i = 0; i < rpt; i++) {
        for(int j = 0; j < cpt; j++) {
            if(i + ist < {{M}} && j + jst < {{N}}) {
                {%- if OUTPUT_RMAJOR %}
                C[(i + ist) * {{N}} + j + jst] {{op}} tile[i][j];
                {%- else %}
                C[(j + jst) * {{M}} + i + ist] {{op}} tile[i][j];
                {%- endif %}
            }
        }
    }
}

{%- endmacro %}
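{# Reading aid, inferred from the index math above: the warp's lanes are
   arranged as a TILES_PER_COL x TILES_PER_ROW grid, so each lane owns an
   rpt x cpt sub-tile of the M x N output. For example, with warp_size = 32
   and TILES_PER_ROW = 4, TILES_PER_COL = 8 and each lane accumulates a
   ceil(M/8) x ceil(N/4) tile, loading a column strip of A and a row strip
   of B at every k iteration. #}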
cpt; j++) { 47 | if(jst + j < {{N}}) { 48 | {%- if B_RMAJOR %} 49 | row[j] = B[k * {{N}} + jst + j]; 50 | {%- else %} 51 | row[j] = B[j * {{K}} + k]; 52 | {%- endif %} 53 | } 54 | } 55 | 56 | #pragma unroll 57 | for(int i = 0; i < rpt; i++) { 58 | #pragma unroll 59 | for(int j = 0; j < cpt; j++) { 60 | if(ist + i < {{M}} && jst + j < {{N}}) { 61 | tile[i][j] += col[i] * row[j]; 62 | } 63 | } 64 | } 65 | } 66 | 67 | {%- if accum %} 68 | {%- set op = "+=" %} 69 | {%- else %} 70 | {%- set op = "=" %} 71 | {%- endif %} 72 | 73 | // Store the output 74 | #pragma unroll 75 | for(int i = 0; i < rpt; i++) { 76 | for(int j = 0; j < cpt; j++) { 77 | if(i + ist < {{M}} && j + jst < {{N}}) { 78 | {%- if OUTPUT_RMAJOR %} 79 | C[(i + ist) * {{N}} + j + jst] {{op}} tile[i][j]; 80 | {%- else %} 81 | C[(j + jst) * {{M}} + i + ist] {{op}} tile[i][j]; 82 | {%- endif %} 83 | } 84 | } 85 | } 86 | } 87 | 88 | {%- endmacro %} -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "openequivariance" 7 | version = "0.1.0" 8 | authors = [ 9 | { name="Austin Glover" }, 10 | { name="Vivek Bharadwaj" }, 11 | { name="Aydin Buluc" }, 12 | { name="James Demmel" } 13 | ] 14 | description = "A fast GPU JIT kernel generator for the Clebsch-Gordan Tensor Product" 15 | requires-python = ">=3.10" 16 | dependencies = [ 17 | "ninja", 18 | "jinja2", 19 | "numpy", 20 | "torch", 21 | ] 22 | 23 | [project.optional-dependencies] 24 | bench = [ 25 | "matplotlib", 26 | "tqdm", 27 | "e3nn", 28 | "cuequivariance", 29 | "cuequivariance-torch", 30 | "cuequivariance-ops-torch-cu12", 31 | ] 32 | 33 | dev = [ 34 | "e3nn", 35 | "pre-commit", 36 | "ruff", 37 | "pytest", 38 | "pytest-check", 39 | "torch_geometric", 40 | "cmake", 41 | "furo", 42 | "sphinx", 43 | "sphinx-autobuild" 44 | ] 45 | 46 | [tool.setuptools.packages.find] 47 | include = ["openequivariance*"] 48 | 49 | [tool.pytest.ini_options] 50 | addopts = [ 51 | "--import-mode=importlib", 52 | ] 53 | 54 | [tool.ruff] 55 | lint.ignore = ["E741"] -------------------------------------------------------------------------------- /tests/batch_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pytest_check import check 3 | 4 | import numpy as np 5 | import openequivariance as oeq 6 | from openequivariance.implementations.TensorProduct import TensorProduct 7 | from openequivariance.benchmark.correctness_utils import ( 8 | correctness_forward, 9 | correctness_backward, 10 | correctness_double_backward, 11 | ) 12 | from itertools import product 13 | 14 | 15 | class TPCorrectness: 16 | def thresh(self, direction): 17 | return {"fwd": 1e-5, "bwd": 3e-4, "double_bwd": 3e-4}[direction] 18 | 19 | def check_result(self, result, fieldname): 20 | with check: 21 | error = result[fieldname]["diff_Linf_norm"] 22 | thresh = result["thresh"] 23 | assert result[fieldname]["pass"], ( 24 | f"{fieldname} observed error={error:.5f} >= {thresh}" 25 | ) 26 | 27 | @pytest.fixture(params=[np.float32, np.float64], ids=["F32", "F64"], scope="class") 28 | def dtype(self, request): 29 | return request.param 30 | 31 | @pytest.fixture(scope="class") 32 | def tp_and_problem(self, problem): 33 | tp = TensorProduct(problem) 34 | return tp, problem 35 | 36 | def test_tp_fwd(self, tp_and_problem): 37 | tp, problem =
tp_and_problem 38 | result = correctness_forward( 39 | problem=problem, 40 | test_implementation=tp, 41 | reference_implementation=None, 42 | batch_size=1000, 43 | correctness_threshold=self.thresh("fwd"), 44 | prng_seed=12345, 45 | ) 46 | 47 | self.check_result(result, "output") 48 | 49 | def test_tp_bwd(self, tp_and_problem): 50 | tp, problem = tp_and_problem 51 | result = correctness_backward( 52 | problem=problem, 53 | test_implementation=tp, 54 | reference_implementation=None, 55 | batch_size=1000, 56 | correctness_threshold=self.thresh("bwd"), 57 | prng_seed=12345, 58 | ) 59 | 60 | self.check_result(result, "weight_grad") 61 | self.check_result(result, "in1_grad") 62 | self.check_result(result, "in2_grad") 63 | 64 | def test_tp_double_bwd(self, tp_and_problem): 65 | tp, problem = tp_and_problem 66 | result = correctness_double_backward( 67 | problem=problem, 68 | test_implementation=tp, 69 | reference_implementation=None, 70 | batch_size=200, 71 | correctness_threshold=self.thresh("double_bwd"), 72 | prng_seed=12345, 73 | ) 74 | 75 | self.check_result(result, "output_double_grad") 76 | self.check_result(result, "in1_grad") 77 | self.check_result(result, "in2_grad") 78 | self.check_result(result, "weights_grad") 79 | 80 | 81 | class TestProductionModels(TPCorrectness): 82 | from openequivariance.benchmark.problems import ( 83 | e3nn_torch_tetris_poly_problems, 84 | diffdock_problems, 85 | mace_problems, 86 | nequip_problems, 87 | ) 88 | 89 | production_model_tpps = ( 90 | mace_problems() 91 | + nequip_problems() 92 | + e3nn_torch_tetris_poly_problems() 93 | + diffdock_problems() 94 | ) 95 | 96 | @pytest.fixture(params=production_model_tpps, ids=lambda x: x.label, scope="class") 97 | def problem(self, request, dtype): 98 | request.param.irrep_dtype, request.param.weight_dtype = dtype, dtype 99 | return request.param 100 | 101 | 102 | class TestUVUSingleIrrep(TPCorrectness): 103 | muls = [ 104 | (1, 1, 1), 105 | (2, 1, 2), 106 | (4, 1, 4), 107 | (8, 1, 8), 108 | (16, 1, 16), 109 | (32, 1, 32), 110 | (5, 1, 5), 111 | (13, 1, 13), 112 | (19, 1, 19), 113 | (33, 1, 33), 114 | (49, 1, 49), 115 | (50, 1, 50), 116 | (123, 1, 123), 117 | (128, 1, 128), 118 | (256, 1, 256), 119 | (512, 1, 512), 120 | (1, 2, 1), 121 | (1, 4, 1), 122 | (1, 16, 1), 123 | (1, 32, 1), 124 | (16, 3, 16), 125 | (16, 9, 16), 126 | (24, 24, 24), 127 | (32, 32, 32), 128 | ] 129 | 130 | irs = [ 131 | (0, 0, 0), 132 | (1, 1, 1), 133 | (1, 0, 1), 134 | (1, 2, 1), 135 | (2, 0, 2), 136 | (2, 2, 4), 137 | (2, 2, 2), 138 | (5, 3, 5), 139 | (7, 2, 5), 140 | ] 141 | 142 | def id_func(m, i): 143 | return f"{m[0]}x{i[0]}e__x__{m[1]}x{i[1]}e---{m[2]}x{i[2]}e" 144 | 145 | @pytest.fixture( 146 | params=product(muls, irs), 147 | ids=lambda x: TestUVUSingleIrrep.id_func(x[0], x[1]), 148 | scope="class", 149 | ) 150 | def problem(self, request, dtype): 151 | m, i = request.param[0], request.param[1] 152 | instructions = [(0, 0, 0, "uvu", True)] 153 | return oeq.TPProblem( 154 | f"{m[0]}x{i[0]}e", 155 | f"{m[1]}x{i[1]}e", 156 | f"{m[2]}x{i[2]}e", 157 | instructions, 158 | shared_weights=False, 159 | internal_weights=False, 160 | irrep_dtype=dtype, 161 | weight_dtype=dtype, 162 | ) 163 | 164 | 165 | class TestUVWSingleIrrep(TPCorrectness): 166 | muls = [ 167 | (1, 1, 1), 168 | (2, 1, 2), 169 | (4, 1, 4), 170 | (8, 1, 8), 171 | (16, 1, 16), 172 | (32, 1, 32), 173 | (5, 1, 5), 174 | (13, 1, 13), 175 | (19, 1, 19), 176 | (33, 1, 33), 177 | (49, 1, 49), 178 | (50, 1, 50), 179 | (64, 1, 64), 180 | (1, 2, 1), 181 | (1, 4, 1), 182 | (1, 16, 1), 
183 | (1, 32, 1), 184 | (16, 3, 16), 185 | (16, 9, 16), 186 | (24, 24, 24), 187 | (32, 32, 32), 188 | ] 189 | 190 | irs = [ 191 | (0, 0, 0), 192 | (1, 1, 1), 193 | (1, 0, 1), 194 | (1, 2, 1), 195 | (2, 0, 2), 196 | (2, 2, 4), 197 | (2, 2, 2), 198 | (5, 3, 5), 199 | (7, 2, 5), 200 | ] 201 | 202 | def id_func(m, i): 203 | return f"{m[0]}x{i[0]}e__x__{m[1]}x{i[1]}e---{m[2]}x{i[2]}e" 204 | 205 | @pytest.fixture( 206 | params=product(muls, irs), 207 | ids=lambda x: TestUVWSingleIrrep.id_func(x[0], x[1]), 208 | scope="class", 209 | ) 210 | def problem(self, request, dtype): 211 | m, i = request.param[0], request.param[1] 212 | instructions = [(0, 0, 0, "uvw", True)] 213 | return oeq.TPProblem( 214 | f"{m[0]}x{i[0]}e", 215 | f"{m[1]}x{i[1]}e", 216 | f"{m[2]}x{i[2]}e", 217 | instructions, 218 | shared_weights=False, 219 | internal_weights=False, 220 | irrep_dtype=dtype, 221 | weight_dtype=dtype, 222 | ) 223 | 224 | 225 | class TestSharedWeights(TPCorrectness): 226 | from openequivariance.benchmark.problems import ( 227 | mace_problems, 228 | diffdock_problems, 229 | ) 230 | 231 | problems = [mace_problems()[0], diffdock_problems()[0]] 232 | 233 | def thresh(self, direction): 234 | return { 235 | "fwd": 1e-5, 236 | "bwd": 5e-4, # Expect higher errors for shared weights 237 | "double_bwd": 5e-4, 238 | }[direction] 239 | 240 | @pytest.fixture(params=problems, ids=lambda x: x.label, scope="class") 241 | def problem(self, request, dtype): 242 | problem = request.param 243 | problem.irrep_dtype, problem.weight_dtype = dtype, dtype 244 | problem.shared_weights = True 245 | return problem 246 | -------------------------------------------------------------------------------- /tests/conv_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import tempfile 3 | import urllib 4 | from pytest_check import check 5 | 6 | import numpy as np 7 | import openequivariance as oeq 8 | from openequivariance.benchmark.ConvBenchmarkSuite import load_graph 9 | from itertools import product 10 | 11 | 12 | class ConvCorrectness: 13 | def thresh(self, direction): 14 | return {"fwd": 3e-4, "bwd": 3e-4, "double_bwd": 3e-4}[direction] 15 | 16 | def check_result(self, result, fieldname): 17 | with check: 18 | error = result[fieldname]["diff_Linf_norm"] 19 | thresh = result["thresh"] 20 | assert result[fieldname]["pass"], ( 21 | f"{fieldname} observed error={error:.5f} >= {thresh}" 22 | ) 23 | 24 | @pytest.fixture(params=[np.float32, np.float64], ids=["F32", "F64"], scope="class") 25 | def dtype(self, request): 26 | return request.param 27 | 28 | @pytest.fixture(params=["1drf_radius3.5.pickle"], ids=["1drf"], scope="class") 29 | def graph(self, request): 30 | download_prefix = ( 31 | "https://portal.nersc.gov/project/m1982/equivariant_nn_graphs/" 32 | ) 33 | filename = request.param 34 | 35 | graph = None 36 | with tempfile.NamedTemporaryFile() as temp_file: 37 | urllib.request.urlretrieve(download_prefix + filename, temp_file.name) 38 | graph = load_graph(temp_file.name) 39 | 40 | # graph = load_graph("data/1drf_radius3.5.pickle") 41 | return graph 42 | 43 | @pytest.fixture(params=["atomic", "deterministic", "kahan"], scope="class") 44 | def conv_object(self, request, problem): 45 | if request.param == "atomic": 46 | return oeq.TensorProductConv(problem, deterministic=False) 47 | elif request.param == "deterministic": 48 | if not problem.shared_weights: 49 | return oeq.TensorProductConv(problem, deterministic=True) 50 | else: 51 | pytest.skip("Shared weights not supported 
with deterministic") 52 | elif request.param == "kahan": 53 | if problem.irrep_dtype == np.float32: 54 | if not problem.shared_weights: 55 | return oeq.TensorProductConv( 56 | problem, deterministic=True, kahan=True 57 | ) 58 | else: 59 | pytest.skip("Shared weights not supported with kahan") 60 | else: 61 | pytest.skip("Only Float32 supported with kahan") 62 | 63 | def test_tp_fwd(self, conv_object, graph): 64 | if conv_object is None: 65 | pytest.skip("'conv_object' fixture returned None, skipping") 66 | 67 | result = conv_object.test_correctness_forward( 68 | graph, 69 | thresh=self.thresh("fwd"), 70 | prng_seed=12345, 71 | reference_implementation=None, 72 | ) 73 | 74 | self.check_result(result, "output") 75 | 76 | def test_tp_bwd(self, conv_object, graph): 77 | if conv_object is None: 78 | pytest.skip("'conv_object' fixture returned None, skipping") 79 | 80 | result = conv_object.test_correctness_backward( 81 | graph, 82 | thresh=self.thresh("bwd"), 83 | prng_seed=12345, 84 | reference_implementation=None, 85 | ) 86 | 87 | self.check_result(result, "weight_grad") 88 | self.check_result(result, "in1_grad") 89 | self.check_result(result, "in2_grad") 90 | 91 | def test_tp_double_bwd(self, conv_object, graph): 92 | if conv_object is None: 93 | pytest.skip("'conv_object' fixture returned None, skipping") 94 | 95 | result = conv_object.test_correctness_double_backward( 96 | graph, 97 | thresh=self.thresh("double_bwd"), 98 | prng_seed=12345, 99 | reference_implementation=None, 100 | ) 101 | 102 | self.check_result(result, "output_grad") 103 | self.check_result(result, "in1_grad") 104 | self.check_result(result, "in2_grad") 105 | self.check_result(result, "weights_grad") 106 | 107 | 108 | class TestProductionModels(ConvCorrectness): 109 | from openequivariance.benchmark.problems import ( 110 | mace_problems, 111 | diffdock_problems, 112 | ) 113 | 114 | production_model_tpps = mace_problems() + diffdock_problems() 115 | 116 | @pytest.fixture(params=production_model_tpps, ids=lambda x: x.label, scope="class") 117 | def problem(self, request, dtype): 118 | request.param.irrep_dtype, request.param.weight_dtype = dtype, dtype 119 | return request.param 120 | 121 | 122 | class TestUVUSingleIrrep(ConvCorrectness): 123 | muls = [ 124 | (1, 1, 1), 125 | (8, 1, 8), 126 | (16, 1, 16), 127 | (32, 1, 32), 128 | (5, 1, 5), 129 | (13, 1, 13), 130 | (19, 1, 19), 131 | (33, 1, 33), 132 | (49, 1, 49), 133 | (128, 1, 128), 134 | (1, 2, 1), 135 | (1, 16, 1), 136 | (1, 32, 1), 137 | (16, 3, 16), 138 | ] 139 | 140 | irs = [(0, 0, 0), (1, 1, 1), (1, 0, 1), (1, 2, 1), (2, 0, 2), (5, 3, 5), (7, 2, 5)] 141 | 142 | def id_func(m, i): 143 | return f"{m[0]}x{i[0]}e__x__{m[1]}x{i[1]}e---{m[2]}x{i[2]}e" 144 | 145 | @pytest.fixture( 146 | params=product(muls, irs), 147 | ids=lambda x: TestUVUSingleIrrep.id_func(x[0], x[1]), 148 | scope="class", 149 | ) 150 | def problem(self, request, dtype): 151 | m, i = request.param[0], request.param[1] 152 | instructions = [(0, 0, 0, "uvu", True)] 153 | return oeq.TPProblem( 154 | f"{m[0]}x{i[0]}e", 155 | f"{m[1]}x{i[1]}e", 156 | f"{m[2]}x{i[2]}e", 157 | instructions, 158 | shared_weights=False, 159 | internal_weights=False, 160 | irrep_dtype=dtype, 161 | weight_dtype=dtype, 162 | ) 163 | 164 | 165 | class TestUVWSingleIrrep(ConvCorrectness): 166 | muls = [ 167 | (1, 1, 1), 168 | (4, 1, 4), 169 | (8, 1, 8), 170 | (16, 1, 16), 171 | (32, 1, 32), 172 | (5, 1, 5), 173 | (13, 1, 13), 174 | (33, 1, 33), 175 | (49, 1, 49), 176 | (64, 1, 64), 177 | (1, 2, 1), 178 | (1, 4, 1), 179 | (1, 16, 1), 
180 | (1, 32, 1), 181 | (16, 3, 16), 182 | ] 183 | 184 | irs = [(0, 0, 0), (1, 1, 1), (1, 0, 1), (1, 2, 1), (5, 3, 5), (7, 2, 5)] 185 | 186 | def id_func(m, i): 187 | return f"{m[0]}x{i[0]}e__x__{m[1]}x{i[1]}e---{m[2]}x{i[2]}e" 188 | 189 | @pytest.fixture( 190 | params=product(muls, irs), 191 | ids=lambda x: TestUVWSingleIrrep.id_func(x[0], x[1]), 192 | scope="class", 193 | ) 194 | def problem(self, request, dtype): 195 | m, i = request.param[0], request.param[1] 196 | instructions = [(0, 0, 0, "uvw", True)] 197 | return oeq.TPProblem( 198 | f"{m[0]}x{i[0]}e", 199 | f"{m[1]}x{i[1]}e", 200 | f"{m[2]}x{i[2]}e", 201 | instructions, 202 | shared_weights=False, 203 | internal_weights=False, 204 | irrep_dtype=dtype, 205 | weight_dtype=dtype, 206 | ) 207 | 208 | 209 | class TestAtomicSharedWeights(ConvCorrectness): 210 | from openequivariance.benchmark.problems import ( 211 | mace_problems, 212 | diffdock_problems, 213 | ) 214 | 215 | problems = [mace_problems()[0], diffdock_problems()[0]] 216 | 217 | def thresh(self, direction): 218 | return { 219 | "fwd": 1e-5, 220 | "bwd": 5e-2, # Expect higher errors for shared weights 221 | "double_bwd": 5e-2, 222 | }[direction] 223 | 224 | @pytest.fixture(params=problems, ids=lambda x: x.label, scope="class") 225 | def problem(self, request, dtype): 226 | problem = request.param 227 | problem.irrep_dtype, problem.weight_dtype = dtype, dtype 228 | problem.shared_weights = True 229 | return problem 230 | 231 | @pytest.fixture(scope="class") 232 | def conv_object(self, request, problem): 233 | return oeq.TensorProductConv(problem, deterministic=False) 234 | -------------------------------------------------------------------------------- /tests/export_test.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import torch 3 | import pytest 4 | import tempfile 5 | import subprocess 6 | import os 7 | import sys 8 | 9 | import numpy as np 10 | import openequivariance as oeq 11 | from torch_geometric import EdgeIndex 12 | import importlib.resources 13 | 14 | from openequivariance.implementations.E3NNTensorProduct import E3NNTensorProduct 15 | 16 | 17 | @pytest.fixture(scope="session") 18 | def problem_and_irreps(): 19 | X_ir, Y_ir, Z_ir = oeq.Irreps("32x5e"), oeq.Irreps("1x3e"), oeq.Irreps("32x5e") 20 | problem = oeq.TPProblem( 21 | X_ir, 22 | Y_ir, 23 | Z_ir, 24 | [(0, 0, 0, "uvu", True)], 25 | shared_weights=False, 26 | internal_weights=False, 27 | irrep_dtype=np.float32, 28 | weight_dtype=np.float32, 29 | ) 30 | 31 | gen = torch.Generator(device="cuda") 32 | gen.manual_seed(0) 33 | 34 | return ( 35 | problem, 36 | X_ir, 37 | Y_ir, 38 | Z_ir, 39 | ) 40 | 41 | 42 | @pytest.fixture(params=["batch", "conv_det", "conv_atomic"], scope="session") 43 | def tp_and_inputs(request, problem_and_irreps): 44 | problem, X_ir, Y_ir, _ = problem_and_irreps 45 | gen = torch.Generator(device="cuda") 46 | gen.manual_seed(0) 47 | 48 | if request.param == "batch": 49 | batch_size = 1000 50 | X = torch.rand(batch_size, X_ir.dim, device="cuda", generator=gen) 51 | Y = torch.rand(batch_size, Y_ir.dim, device="cuda", generator=gen) 52 | W = torch.rand(batch_size, problem.weight_numel, device="cuda", generator=gen) 53 | return oeq.TensorProduct(problem), (X, Y, W) 54 | else: 55 | node_ct, nonzero_ct = 3, 4 56 | 57 | # Receiver, sender indices for message passing GNN 58 | edge_index = EdgeIndex( 59 | [[0, 1, 1, 2], [1, 0, 2, 1]], device="cuda", dtype=torch.long 60 | ) 61 | 62 | _, sender_perm = edge_index.sort_by("col") 63 | edge_index, 
_ = edge_index.sort_by("row") 64 | edge_index = [edge_index[0].detach(), edge_index[1].detach()] 65 | 66 | X = torch.rand(node_ct, X_ir.dim, device="cuda", generator=gen) 67 | Y = torch.rand(nonzero_ct, Y_ir.dim, device="cuda", generator=gen) 68 | W = torch.rand(nonzero_ct, problem.weight_numel, device="cuda", generator=gen) 69 | 70 | if request.param == "conv_atomic": 71 | return oeq.TensorProductConv(problem, torch_op=True, deterministic=False), ( 72 | X, 73 | Y, 74 | W, 75 | edge_index[0], 76 | edge_index[1], 77 | ) 78 | elif request.param == "conv_det": 79 | return oeq.TensorProductConv(problem, torch_op=True, deterministic=True), ( 80 | X, 81 | Y, 82 | W, 83 | edge_index[0], 84 | edge_index[1], 85 | sender_perm, 86 | ) 87 | 88 | 89 | def test_jitscript(tp_and_inputs): 90 | tp, inputs = tp_and_inputs 91 | uncompiled_result = tp.forward(*inputs) 92 | 93 | scripted_tp = torch.jit.script(tp) 94 | loaded_tp = None 95 | with tempfile.NamedTemporaryFile(suffix=".pt") as tmp_file: 96 | scripted_tp.save(tmp_file.name) 97 | loaded_tp = torch.jit.load(tmp_file.name) 98 | 99 | compiled_result = loaded_tp(*inputs) 100 | assert torch.allclose(uncompiled_result, compiled_result, atol=1e-5) 101 | 102 | 103 | def test_compile(tp_and_inputs): 104 | tp, inputs = tp_and_inputs 105 | uncompiled_result = tp.forward(*inputs) 106 | 107 | compiled_tp = torch.compile(tp) 108 | compiled_result = compiled_tp(*inputs) 109 | assert torch.allclose(uncompiled_result, compiled_result, atol=1e-5) 110 | 111 | 112 | def test_export(tp_and_inputs): 113 | tp, inputs = tp_and_inputs 114 | uncompiled_result = tp.forward(*inputs) 115 | 116 | exported_tp = torch.export.export(tp, args=inputs, strict=False) 117 | exported_result = exported_tp.module()(*inputs) 118 | assert torch.allclose(uncompiled_result, exported_result, atol=1e-5) 119 | 120 | 121 | def test_aoti(tp_and_inputs): 122 | tp, inputs = tp_and_inputs 123 | uncompiled_result = tp.forward(*inputs) 124 | 125 | exported_tp = torch.export.export(tp, args=inputs, strict=False) 126 | aoti_model = None 127 | with tempfile.NamedTemporaryFile(suffix=".pt2") as tmp_file: 128 | try: 129 | output_path = torch._inductor.aoti_compile_and_package( 130 | exported_tp, package_path=tmp_file.name 131 | ) 132 | except Exception as e: 133 | err_msg = ( 134 | "AOTI compile_and_package failed. NOTE: OpenEquivariance only supports AOTI for " 135 | + "PyTorch version >= 2.8.0.dev20250410+cu126 due to incomplete TorchBind support " 136 | + "in prior versions. 
" 137 | + f"{e}" 138 | ) 139 | assert False, err_msg 140 | 141 | aoti_model = torch._inductor.aoti_load_package(output_path) 142 | 143 | aoti_result = aoti_model(*inputs) 144 | assert torch.allclose(uncompiled_result, aoti_result, atol=1e-5) 145 | 146 | 147 | def test_jitscript_cpp_interface(problem_and_irreps): 148 | problem, X_ir, Y_ir, _ = problem_and_irreps 149 | cmake_prefix_path = torch.utils.cmake_prefix_path 150 | torch_ext_so_path = oeq.torch_ext_so_path() 151 | 152 | oeq_tp = oeq.TensorProduct(problem).to("cuda") 153 | scripted_oeq = torch.jit.script(oeq_tp) 154 | 155 | e3nn_tp = E3NNTensorProduct(problem).e3nn_tp.to("cuda") 156 | scripted_e3nn = torch.jit.script(e3nn_tp) 157 | 158 | batch_size = 1000 159 | 160 | with ( 161 | tempfile.TemporaryDirectory() as tmpdir, 162 | tempfile.NamedTemporaryFile(suffix=".pt") as oeq_file, 163 | tempfile.NamedTemporaryFile(suffix=".pt") as e3nn_file, 164 | ): 165 | scripted_oeq.save(oeq_file.name) 166 | scripted_e3nn.save(e3nn_file.name) 167 | 168 | test_path = importlib.resources.files("openequivariance") / "extension" / "test" 169 | build_dir = os.path.join(tmpdir, "build") 170 | os.makedirs(build_dir, exist_ok=True) 171 | 172 | for item in test_path.iterdir(): 173 | shutil.copy(item, tmpdir) 174 | 175 | try: 176 | subprocess.run( 177 | [ 178 | "cmake", 179 | "..", 180 | "-DCMAKE_BUILD_TYPE=Release", 181 | "-DCMAKE_PREFIX_PATH=" + cmake_prefix_path, 182 | "-DOEQ_EXTLIB=" + torch_ext_so_path, 183 | ], 184 | cwd=build_dir, 185 | check=True, 186 | stdout=subprocess.PIPE, 187 | stderr=subprocess.PIPE, 188 | ) 189 | 190 | subprocess.run( 191 | ["make"], 192 | cwd=build_dir, 193 | check=True, 194 | stdout=subprocess.PIPE, 195 | stderr=subprocess.PIPE, 196 | ) 197 | 198 | subprocess.run( 199 | [ 200 | "./load_jitscript", 201 | e3nn_file.name, 202 | oeq_file.name, 203 | str(X_ir.dim), 204 | str(Y_ir.dim), 205 | str(problem.weight_numel), 206 | str(batch_size), 207 | ], 208 | cwd=build_dir, 209 | check=True, 210 | stdout=subprocess.PIPE, 211 | stderr=subprocess.PIPE, 212 | ) 213 | except subprocess.CalledProcessError as e: 214 | print(e.stdout.decode(), file=sys.stderr) 215 | print(e.stderr.decode(), file=sys.stderr) 216 | assert False 217 | -------------------------------------------------------------------------------- /tests/import_test.py: -------------------------------------------------------------------------------- 1 | def test_import(): 2 | import openequivariance 3 | 4 | assert openequivariance.__version__ is not None 5 | assert openequivariance.__version__ != "0.0.0" 6 | 7 | 8 | def test_tutorial(): 9 | import torch 10 | import e3nn.o3 as o3 11 | 12 | gen = torch.Generator(device="cuda") 13 | 14 | batch_size = 1000 15 | X_ir, Y_ir, Z_ir = o3.Irreps("1x2e"), o3.Irreps("1x3e"), o3.Irreps("1x2e") 16 | X = torch.rand(batch_size, X_ir.dim, device="cuda", generator=gen) 17 | Y = torch.rand(batch_size, Y_ir.dim, device="cuda", generator=gen) 18 | 19 | instructions = [(0, 0, 0, "uvu", True)] 20 | 21 | tp_e3nn = o3.TensorProduct( 22 | X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False 23 | ).to("cuda") 24 | W = torch.rand(batch_size, tp_e3nn.weight_numel, device="cuda", generator=gen) 25 | 26 | Z = tp_e3nn(X, Y, W) 27 | print(torch.norm(Z)) 28 | # =============================== 29 | 30 | # =============================== 31 | import openequivariance as oeq 32 | 33 | problem = oeq.TPProblem( 34 | X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False 35 | ) 36 | tp_fast = oeq.TensorProduct(problem, 
torch_op=True) 37 | 38 | Z = tp_fast(X, Y, W) # Reuse X, Y, W from earlier 39 | print(torch.norm(Z)) 40 | # =============================== 41 | 42 | # Graph Convolution 43 | # =============================== 44 | from torch_geometric import EdgeIndex 45 | 46 | node_ct, nonzero_ct = 3, 4 47 | 48 | # Receiver, sender indices for message passing GNN 49 | edge_index = EdgeIndex( 50 | [ 51 | [0, 1, 1, 2], # Receiver 52 | [1, 0, 2, 1], 53 | ], # Sender 54 | device="cuda", 55 | dtype=torch.long, 56 | ) 57 | 58 | X = torch.rand(node_ct, X_ir.dim, device="cuda", generator=gen) 59 | Y = torch.rand(nonzero_ct, Y_ir.dim, device="cuda", generator=gen) 60 | W = torch.rand(nonzero_ct, problem.weight_numel, device="cuda", generator=gen) 61 | 62 | tp_conv = oeq.TensorProductConv( 63 | problem, torch_op=True, deterministic=False 64 | ) # Reuse problem from earlier 65 | Z = tp_conv.forward( 66 | X, Y, W, edge_index[0], edge_index[1] 67 | ) # Z has shape [node_ct, z_ir.dim] 68 | print(torch.norm(Z)) 69 | # =============================== 70 | 71 | # =============================== 72 | _, sender_perm = edge_index.sort_by("col") # Sort by sender index 73 | edge_index, receiver_perm = edge_index.sort_by("row") # Sort by receiver index 74 | 75 | # Now we can use the faster deterministic algorithm 76 | tp_conv = oeq.TensorProductConv(problem, torch_op=True, deterministic=True) 77 | Z = tp_conv.forward( 78 | X, Y[receiver_perm], W[receiver_perm], edge_index[0], edge_index[1], sender_perm 79 | ) 80 | print(torch.norm(Z)) 81 | # =============================== 82 | assert True 83 | -------------------------------------------------------------------------------- /tests/multidevice_test.py: -------------------------------------------------------------------------------- 1 | import textwrap 2 | import torch 3 | import subprocess 4 | import os 5 | 6 | 7 | def test_multidevice(): 8 | result = subprocess.run( 9 | [ 10 | "python", 11 | "-m", 12 | "torch.distributed.run", 13 | "--standalone", 14 | "--nnodes=1", 15 | "--nproc-per-node=gpu", 16 | __file__, 17 | ], 18 | capture_output=True, 19 | check=False, 20 | ) 21 | 22 | if result.returncode != 0: 23 | error_string = f""" 24 | Invocation: {" ".join(result.args)} 25 | Test failed with return code {result.returncode}. 26 | \nOutput:\n\n{result.stdout.decode()} 27 | \nError:\n\n{result.stderr.decode()} 28 | """ 29 | assert False, textwrap.dedent(error_string) 30 | 31 | assert True 32 | 33 | 34 | if __name__ == "__main__": 35 | import openequivariance as oeq 36 | 37 | # Use MACE-large to test >64KB shared memory allocation 38 | from openequivariance.benchmark.problems import mace_problems 39 | 40 | problem = mace_problems()[0] 41 | 42 | local_rank = int(os.environ["LOCAL_RANK"]) 43 | device = f"cuda:{local_rank}" 44 | torch.set_default_device(device) 45 | 46 | X_ir, Y_ir, Z_ir = problem.irreps_in1, problem.irreps_in2, problem.irreps_out 47 | tp = oeq.TensorProduct(problem) 48 | 49 | batch_size = 1000 50 | gen = torch.Generator(device=device) 51 | gen.manual_seed(0) 52 | X = torch.rand(batch_size, X_ir.dim, device=device, generator=gen) 53 | Y = torch.rand(batch_size, Y_ir.dim, device=device, generator=gen) 54 | W = torch.rand(batch_size, problem.weight_numel, device=device, generator=gen) 55 | 56 | with torch.cuda.device(device): 57 | result = tp.forward(X, Y, W) 58 | --------------------------------------------------------------------------------
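For quick reference, the public API exercised by the tests above reduces to a handful of calls. A minimal sketch, assembled from tests/import_test.py and tests/export_test.py (the irreps, batch size, and single "uvu" instruction are illustrative choices, not requirements):

import numpy as np
import torch
import openequivariance as oeq

# Describe the tensor product once; a kernel is JIT-generated for this problem.
X_ir, Y_ir, Z_ir = oeq.Irreps("32x5e"), oeq.Irreps("1x3e"), oeq.Irreps("32x5e")
problem = oeq.TPProblem(
    X_ir, Y_ir, Z_ir,
    [(0, 0, 0, "uvu", True)],
    shared_weights=False,
    internal_weights=False,
    irrep_dtype=np.float32,
    weight_dtype=np.float32,
)
tp = oeq.TensorProduct(problem, torch_op=True)

batch_size = 1000
X = torch.rand(batch_size, X_ir.dim, device="cuda")
Y = torch.rand(batch_size, Y_ir.dim, device="cuda")
W = torch.rand(batch_size, problem.weight_numel, device="cuda")
Z = tp(X, Y, W)  # shape: [batch_size, Z_ir.dim]

The same problem object drives oeq.TensorProductConv for graph convolutions; passing deterministic=True (with row-sorted edge indices and a sender permutation) selects the reproducible accumulation path, as tests/import_test.py demonstrates.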