├── .github
│   └── workflows
│       ├── docs.yaml
│       ├── pre-commit.yaml
│       ├── requirements_cuda_ci.txt
│       └── verify_extension_build.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── MANIFEST.in
├── README.md
├── docs
│   ├── api.rst
│   ├── conf.py
│   ├── index.rst
│   ├── installation.rst
│   ├── livebuild.sh
│   ├── models.rst
│   ├── supported_ops.rst
│   └── tests_and_benchmarks.rst
├── io
│   ├── cif_to_graph.py
│   └── load_nequip_configs.py
├── openequivariance
│   ├── __init__.py
│   ├── benchmark
│   │   ├── ConvBenchmarkSuite.py
│   │   ├── TestBenchmarkSuite.py
│   │   ├── benchmark_utils.py
│   │   ├── correctness_utils.py
│   │   ├── logging_utils.py
│   │   ├── perf_metrics_utils.py
│   │   ├── plotting
│   │   │   ├── __init__.py
│   │   │   ├── plot_convolution.py
│   │   │   ├── plot_double_backward.py
│   │   │   ├── plot_roofline.py
│   │   │   ├── plot_uvu.py
│   │   │   ├── plot_uvw.py
│   │   │   └── plotting_utils.py
│   │   ├── problems.py
│   │   ├── random_buffer_utils.py
│   │   └── tpp_creation_utils.py
│   ├── extension
│   │   ├── CMakeLists.txt
│   │   ├── convolution.hpp
│   │   ├── generic_module.cpp
│   │   ├── group_mm_cuda.hpp
│   │   ├── group_mm_hip.hpp
│   │   ├── libtorch_tp_jit.cpp
│   │   ├── tensorproducts.hpp
│   │   ├── test
│   │   │   ├── CMakeLists.txt
│   │   │   └── load_jitscript.cpp
│   │   └── util
│   │       ├── backend_cuda.hpp
│   │       ├── backend_hip.hpp
│   │       └── buffer.hpp
│   ├── extlib
│   │   ├── .empty
│   │   └── __init__.py
│   ├── implementations
│   │   ├── CUETensorProduct.py
│   │   ├── ComputationSchedule.py
│   │   ├── E3NNTensorProduct.py
│   │   ├── LoopUnrollTP.py
│   │   ├── MultiplicityOuterProductTP.py
│   │   ├── TensorProduct.py
│   │   ├── TensorProductBase.py
│   │   ├── convolution
│   │   │   ├── CUEConv.py
│   │   │   ├── ConvolutionBase.py
│   │   │   ├── E3NNConv.py
│   │   │   ├── LoopUnrollConv.py
│   │   │   ├── TensorProductConv.py
│   │   │   └── scatter.py
│   │   ├── e3nn_lite.py
│   │   ├── symmetric_contraction
│   │   │   ├── __init__.py
│   │   │   └── symmetric_contraction.py
│   │   └── utils.py
│   └── templates
│       ├── common.cuh
│       ├── jinja_utils.py
│       ├── loop_unroll_batch.cuh
│       ├── loop_unroll_conv_atomic.cuh
│       ├── loop_unroll_conv_det.cuh
│       ├── loop_unroll_tp.cuh
│       ├── macros.jinja
│       ├── subkernel_per_interaction_multirep.cuh
│       └── wmm.cuh
├── pyproject.toml
└── tests
    ├── batch_test.py
    ├── benchmark.py
    ├── conv_test.py
    ├── export_test.py
    ├── import_test.py
    ├── mace_driver.py
    └── multidevice_test.py
/.github/workflows/docs.yaml:
--------------------------------------------------------------------------------
1 | name: Deploy documentation to GitHub Pages
2 | on:
3 | workflow_dispatch:
4 |
5 | permissions: write-all
6 |
7 | concurrency:
8 | group: "pages"
9 | cancel-in-progress: false
10 |
11 | jobs:
12 | build:
13 | runs-on: ubuntu-latest
14 | steps:
15 | - name: Checkout
16 | uses: actions/checkout@v4
17 | - name: Setup Pages
18 | uses: actions/configure-pages@v3
19 | - name: Set up Python 3.10
20 | uses: actions/setup-python@v5
21 | with:
22 | python-version: "3.10"
23 | - name: Install dependencies
24 | run: |
25 | python -m pip install --upgrade pip
26 | pip install sphinx furo
27 | - name: Build website
28 | run: |
29 | sphinx-build -M dirhtml docs docs/_build
30 |
31 | - name: Fix permissions
32 | run: |
33 | chmod -c -R +rX "docs/_build/dirhtml/" | while read line; do
34 | echo "::warning title=Invalid file permissions automatically fixed::$line"
35 | done
36 | - name: Upload artifact
37 | uses: actions/upload-pages-artifact@v3
38 | with:
39 | path: './docs/_build/dirhtml'
40 | deploy:
41 | environment:
42 | name: github-pages
43 | url: ${{ steps.deployment.outputs.page_url }}
44 | runs-on: ubuntu-latest
45 | needs: build
46 | steps:
47 | - name: Deploy to GitHub Pages
48 | id: deployment
49 | uses: actions/deploy-pages@v4
--------------------------------------------------------------------------------
/.github/workflows/pre-commit.yaml:
--------------------------------------------------------------------------------
1 | name: Pre-Commit Checks
2 |
3 | on:
4 | pull_request:
5 | push:
6 | branches: [main]
7 |
8 | jobs:
9 | pre-commit:
10 | runs-on: ubuntu-latest
11 | steps:
12 | - uses: actions/checkout@v4
13 | - uses: pre-commit/action@v3.0.1
--------------------------------------------------------------------------------
/.github/workflows/requirements_cuda_ci.txt:
--------------------------------------------------------------------------------
1 | numpy==2.2.5
2 | torch==2.7.0 --index-url https://download.pytorch.org/whl/cu128
3 | pytest==8.3.5
4 | ninja==1.11.1.4
5 | ruff==0.11.11
6 | pre-commit==4.2.0
--------------------------------------------------------------------------------
/.github/workflows/verify_extension_build.yml:
--------------------------------------------------------------------------------
1 | name: OEQ CUDA C++ Extension Build Verification
2 |
3 | on:
4 | push:
5 | branches: [ "main" ]
6 | pull_request:
7 | branches: [ "main" ]
8 | types: [ labeled ]
9 |
10 | permissions:
11 | contents: read
12 |
13 | jobs:
14 | verify_cuda_extension:
15 | if: ${{ github.event.label.name == 'ci-ready' || github.event_name != 'pull_request' }}
16 | runs-on: ubuntu-latest
17 |
18 | steps:
19 | - uses: actions/checkout@v4
20 | - name: Set up Python
21 | uses: actions/setup-python@v5
22 | with:
23 | python-version: "3.10"
24 | cache: 'pip'
25 | cache-dependency-path: '**/requirements_cuda_ci.txt'
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | sudo apt-get update
30 | sudo apt install nvidia-cuda-toolkit
31 | pip install -r .github/workflows/requirements_cuda_ci.txt
32 | pip install -e .
33 |
34 | - name: Test extension build via import
35 | run: |
36 | pytest tests/import_test.py -k test_import
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .rendered*
2 | *.lock
3 | *.so
4 | *.ipynb
5 | .DS_Store
6 | *.code-workspace
7 | __pycache__
8 |
9 | # working folders
10 | build
11 | outputs/*
12 | visualization/*
13 | figures/*
14 | data/*
15 | experimental/*
16 | MACE_models/*
17 | triton_autotuning/*
18 | sandbox/*
19 | docs/_build
20 | docs/_static
21 | docs/_templates
22 |
23 | # bash scripts
24 | get_gpu_node.sh
25 | env.sh
26 |
27 | nvidia-mathdx*
28 | .vscode/*
29 | *.ncu-rep
30 | mace_dev
31 | mace_oeq_integration
32 | valid_indices*
33 |
34 | *.xyz
35 | scratch.txt
36 | triton_autotuning
37 | paper_benchmarks
38 | paper_benchmarks_v2
39 | paper_benchmarks_v3
40 | openequivariance/extlib/*.so
41 |
42 | get_node.sh
43 | *.egg-info
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/astral-sh/ruff-pre-commit
3 | # Ruff version.
4 | rev: v0.11.11
5 | hooks:
6 | # Run the linter.
7 | - id: ruff-check
8 | # Run the formatter.
9 | - id: ruff-format
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2025, The Regents of the University of California, through Lawrence Berkeley National Laboratory (subject to receipt of any required approvals from the U.S. Dept. of Energy). All rights reserved.
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | 1. Redistributions of source code must retain the above copyright notice, this
9 | list of conditions and the following disclaimer.
10 |
11 | 2. Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | 3. Neither the name of the copyright holder nor the names of its
16 | contributors may be used to endorse or promote products derived from
17 | this software without specific prior written permission.
18 |
19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include openequivariance/extlib/*.so
2 | include openequivariance/extlib/*.empty
3 |
4 | include openequivariance/templates/*.cuh
5 | include openequivariance/templates/*.jinja
6 |
7 | include openequivariance/extension/*
8 | include openequivariance/extension/convolution/*
9 | include openequivariance/extension/tensorproducts/*
10 | include openequivariance/extension/util/*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # OpenEquivariance
2 | [](https://github.com/PASSIONLab/OpenEquivariance/actions/workflows/verify_extension_build.yml)
3 | [](https://opensource.org/licenses/BSD-3-Clause)
4 |
5 | [[Examples]](#show-me-some-examples)
6 | [[Citation and Acknowledgements]](#citation-and-acknowledgements)
7 |
8 | OpenEquivariance is a CUDA and HIP kernel generator for the Clebsch-Gordon tensor product,
9 | a key kernel in rotation-equivariant deep neural networks.
10 | It implements some of the tensor products
11 | that [e3nn](https://e3nn.org/) supports
12 | commonly found in graph neural networks
13 | (e.g. [Nequip](https://github.com/mir-group/nequip) or
14 | [MACE](https://github.com/ACEsuit/mace)). To get
15 | started, ensure that you have GCC 9+ on your system
16 | and install our package via
17 |
18 | ```bash
19 | pip install git+https://github.com/PASSIONLab/OpenEquivariance
20 | ```
21 |
22 | We provide up to an order of magnitude acceleration over e3nn and perform on par with the latest
23 | version of [NVIDIA cuEquivariance](https://github.com/NVIDIA/cuEquivariance),
24 | which has a closed-source kernel package.
25 | We also offer fused equivariant graph
26 | convolutions that can reduce
27 | computation and memory consumption significantly.
28 |
29 | For detailed instructions on tests, benchmarks, MACE / Nequip, and our API,
30 | check out the [documentation](https://passionlab.github.io/OpenEquivariance).
31 |
32 | 📣 📣 OpenEquivariance was accepted to the 2025 SIAM Conference on Applied and
33 | Computational Discrete Algorithms (Proceedings Track)! Catch the talk in
34 | Montréal and check out the [camera-ready copy on Arxiv](https://arxiv.org/abs/2501.13986) (available May 12, 2025).
35 |
36 | ## Show me some examples
37 | Here's a CG tensor product implemented by e3nn:
38 |
39 | ```python
40 | import torch
41 | import e3nn.o3 as o3
42 |
43 | gen = torch.Generator(device='cuda')
44 |
45 | batch_size = 1000
46 | X_ir, Y_ir, Z_ir = o3.Irreps("1x2e"), o3.Irreps("1x3e"), o3.Irreps("1x2e")
47 | X = torch.rand(batch_size, X_ir.dim, device='cuda', generator=gen)
48 | Y = torch.rand(batch_size, Y_ir.dim, device='cuda', generator=gen)
49 |
50 | instructions=[(0, 0, 0, "uvu", True)]
51 |
52 | tp_e3nn = o3.TensorProduct(X_ir, Y_ir, Z_ir, instructions,
53 | shared_weights=False, internal_weights=False).to('cuda')
54 | W = torch.rand(batch_size, tp_e3nn.weight_numel, device='cuda', generator=gen)
55 |
56 | Z = tp_e3nn(X, Y, W)
57 | print(torch.norm(Z))
58 | ```
59 |
60 | And here's the same tensor product using openequivariance. We require that your
61 | tensors are stored on a CUDA device for this to work:
62 |
63 | ```python
64 | import openequivariance as oeq
65 |
66 | problem = oeq.TPProblem(X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False)
67 | tp_fast = oeq.TensorProduct(problem, torch_op=True)
68 |
69 | Z = tp_fast(X, Y, W) # Reuse X, Y, W from earlier
70 | print(torch.norm(Z))
71 | ```
72 |
73 | Our interface for `oeq.TPProblem` is almost a strict superset of
74 | `o3.TensorProduct` (two key differences: we
75 | impose `internal_weights=False` and add support for multiple datatypes).
76 | You can pass e3nn `Irreps` instances directly or
77 | use `oeq.Irreps`, which is identical.
78 |
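For example, here's a quick sketch of the equivalence (the irreps string is arbitrary):

```python
import e3nn.o3 as o3
import openequivariance as oeq

# oeq.Irreps mirrors e3nn's Irreps, so either object works in a TPProblem.
irreps_e3nn = o3.Irreps("32x1e + 8x2e")  # e3nn object, passed directly
irreps_oeq = oeq.Irreps("32x1e + 8x2e")  # oeq's identical clone
assert irreps_e3nn.dim == irreps_oeq.dim
```
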
79 | We recommend reading the [e3nn documentation and API reference](https://docs.e3nn.org/en/latest/) first, then using our kernels
80 | as drop-in replacements. We support most "uvu" and "uvw" tensor products;
81 | see [this section](#tensor-products-we-accelerate) for an up-to-date list of supported configurations.
82 |
83 | **Important**: For many configurations, our kernels return results identical to
84 | e3nn up to floating point roundoff (this includes all "uvu" problems with
85 | multiplicity 1 for all irreps in the second input). For other configurations
86 | (e.g. any "uvw" connection modes), we return identical
87 | results up to a well-defined reordering of the weights relative to e3nn.
88 |
89 | If you're executing tensor products as part of a message passing graph
90 | neural network, we offer fused kernels that save both memory and compute time:
91 |
92 | ```python
93 | from torch_geometric import EdgeIndex
94 |
95 | node_ct, nonzero_ct = 3, 4
96 |
97 | # Receiver, sender indices for message passing GNN
98 | edge_index = EdgeIndex(
99 | [[0, 1, 1, 2], # Receiver
100 | [1, 0, 2, 1]], # Sender
101 | device='cuda',
102 | dtype=torch.long)
103 |
104 | X = torch.rand(node_ct, X_ir.dim, device='cuda', generator=gen)
105 | Y = torch.rand(nonzero_ct, Y_ir.dim, device='cuda', generator=gen)
106 | W = torch.rand(nonzero_ct, problem.weight_numel, device='cuda', generator=gen)
107 |
108 | tp_conv = oeq.TensorProductConv(problem, torch_op=True, deterministic=False) # Reuse problem from earlier
109 | Z = tp_conv.forward(X, Y, W, edge_index[0], edge_index[1]) # Z has shape [node_ct, Z_ir.dim]
110 | print(torch.norm(Z))
111 | ```
112 |
113 | If you can guarantee `EdgeIndex` is sorted by receiver index and supply the transpose
114 | permutation, we can provide even greater speedup (and deterministic results)
115 | by avoiding atomics:
116 |
117 | ```python
118 | _, sender_perm = edge_index.sort_by("col") # Sort by sender index
119 | edge_index, receiver_perm = edge_index.sort_by("row") # Sort by receiver index
120 |
121 | # Now we can use the faster deterministic algorithm
122 | tp_conv = oeq.TensorProductConv(problem, torch_op=True, deterministic=True)
123 | Z = tp_conv.forward(X, Y[receiver_perm], W[receiver_perm], edge_index[0], edge_index[1], sender_perm)
124 | print(torch.norm(Z))
125 | ```
126 | **Note**: you don't need PyTorch Geometric to use our kernels. When
127 | `deterministic=False`, the `sender` and `receiver` indices can have
128 | arbitrary order.
129 |
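If you want to prepare these inputs without PyTorch Geometric, here's a minimal
sketch using plain `torch` (our `argsort` calls stand in for the `sort_by`
permutations above; verify against your own graph layout):

```python
import torch

edge_index = torch.tensor(
    [[0, 1, 1, 2],   # receiver
     [1, 0, 2, 1]],  # sender
    device='cuda', dtype=torch.long)

sender_perm = torch.argsort(edge_index[1], stable=True)    # ~ sort_by("col")
receiver_perm = torch.argsort(edge_index[0], stable=True)  # ~ sort_by("row")
edge_index = edge_index[:, receiver_perm]                  # now sorted by receiver
```
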
130 | ## Citation and Acknowledgements
131 | If you find this code useful, please cite our paper:
132 |
133 | ```bibtex
134 | @inbook{openequivariance,
135 | author={Vivek Bharadwaj and Austin Glover and Aydin Buluc and James Demmel},
136 | title={An Efficient Sparse Kernel Generator for O(3)-Equivariant Deep Networks},
137 | booktitle = {SIAM Conference on Applied and Computational Discrete Algorithms (ACDA25)},
138 | chapter = {},
139 | url={https://arxiv.org/abs/2501.13986},
140 | publisher={Society for Industrial and Applied Mathematics},
141 | year={2025}
142 | }
143 | ```
144 |
145 | Our codebase includes a lightweight clone of
146 | [e3nn](https://e3nn.org/)'s frontend interface (in particular, the
147 | `TensorProduct` and `Irreps` classes). We removed references to PyTorch
148 | and separated the implementation from the problem description (for future
149 | frontend support outside of torch). We also extracted the Wigner 3j tensor generating code from QuTiP. Thank you to the current
150 | developers and maintainers!
151 |
152 | ## Copyright
153 |
154 | Copyright (c) 2025, The Regents of the University of California, through Lawrence Berkeley National Laboratory (subject to receipt of any required approvals from the U.S. Dept. of Energy). All rights reserved.
155 |
156 | If you have questions about your rights to use or distribute this software, please contact Berkeley Lab's Intellectual Property Office at IPO@lbl.gov.
157 |
158 | NOTICE. This Software was developed under funding from the U.S. Department of Energy and the U.S. Government consequently retains certain rights. As such, the U.S. Government has been granted for itself and others acting on its behalf a paid-up, nonexclusive, irrevocable, worldwide license in the Software to reproduce, distribute copies to the public, prepare derivative works, and perform publicly and display publicly, and to permit others to do so.
--------------------------------------------------------------------------------
/docs/api.rst:
--------------------------------------------------------------------------------
1 | OpenEquivariance API
2 | ==============================
3 |
4 | .. toctree::
5 | :maxdepth: 1
6 | :caption: Contents:
7 |
8 | OpenEquivariance exposes two key classes: :py:class:`openequivariance.TensorProduct`, which replaces
9 | ``o3.TensorProduct`` from e3nn, and :py:class:`openequivariance.TensorProductConv`, which fuses
10 | the CG tensor product with a subsequent graph convolution. Initializing either class triggers
11 | JIT compilation of a custom kernel, which can take a few seconds.
12 |
13 | Both classes require a configuration object specified
14 | by :py:class:`openequivariance.TPProblem`, which has a constructor
15 | almost identical to ``o3.TensorProduct``.
16 | We recommend reading the `e3nn documentation <https://docs.e3nn.org/en/latest/>`_ before
17 | trying our code. OpenEquivariance cannot accelerate all tensor products; see
18 | :doc:`this page <supported_ops>` for a list of supported configurations.
19 |
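A minimal usage sketch (mirroring the README example; the irreps and the
instruction list below are arbitrary choices, not requirements):

.. code-block:: python

   import torch
   import openequivariance as oeq

   X_ir, Y_ir, Z_ir = oeq.Irreps("1x2e"), oeq.Irreps("1x3e"), oeq.Irreps("1x2e")
   problem = oeq.TPProblem(X_ir, Y_ir, Z_ir, [(0, 0, 0, "uvu", True)],
                           shared_weights=False, internal_weights=False)
   tp = oeq.TensorProduct(problem, torch_op=True)  # JIT-compiles a kernel

   batch_size = 1000
   X = torch.rand(batch_size, X_ir.dim, device='cuda')
   Y = torch.rand(batch_size, Y_ir.dim, device='cuda')
   W = torch.rand(batch_size, problem.weight_numel, device='cuda')
   Z = tp(X, Y, W)
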
20 | .. autoclass:: openequivariance.TensorProduct
21 | :members:
22 | :undoc-members:
23 | :exclude-members: name
24 |
25 | .. autoclass:: openequivariance.TensorProductConv
26 | :members:
27 | :undoc-members:
28 | :exclude-members: name
29 |
30 | .. autoclass:: openequivariance.TPProblem
31 | :members:
32 | :undoc-members:
33 |
34 | .. autofunction:: openequivariance.torch_to_oeq_dtype
35 |
36 | .. autofunction:: openequivariance.torch_ext_so_path
37 |
38 | API Identical to e3nn
39 | ---------------------
40 |
41 | These remaining API members are identical to the corresponding
42 | objects in ``e3nn.o3``. You can freely mix these objects from
43 | both packages.
44 |
45 | .. autoclass:: openequivariance.Irreps
46 | :members:
47 | :undoc-members:
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from pathlib import Path
3 |
4 | # Configuration file for the Sphinx documentation builder.
5 | #
6 | # For the full list of built-in configuration values, see the documentation:
7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
8 |
9 | # -- Project information -----------------------------------------------------
10 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
11 |
12 | project = "OpenEquivariance"
13 | copyright = "2025, The Regents of the University of California, through Lawrence Berkeley National Laboratory."
14 | author = "Vivek Bharadwaj, Austin Glover, Aydin Buluc, James Demmel"
15 |
16 | # -- General configuration ---------------------------------------------------
17 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
18 |
19 | extensions = []
20 |
21 | templates_path = ["_templates"]
22 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
23 |
24 |
25 | # -- Options for HTML output -------------------------------------------------
26 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
27 |
28 | html_theme = "furo"
29 | # html_static_path = ["_static"]
30 |
31 | extensions = [
32 | "sphinx.ext.autodoc",
33 | ]
34 |
35 | sys.path.insert(0, str(Path("..").resolve()))
36 |
37 | autodoc_mock_imports = ["torch", "openequivariance.extlib", "jinja2", "numpy"]
38 | autodoc_typehints = "description"
39 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. OpenEquivariance documentation master file, created by
2 | sphinx-quickstart on Tue Jun 3 00:20:54 2025.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | OpenEquivariance
7 | ==============================
8 |
9 | `OpenEquivariance `_ is a CUDA and
10 | HIP kernel generator for the Clebsch-Gordon
11 | tensor product, a key kernel in equivariant graph neural networks. We offer
12 | an identical interface to e3nn and produce the same results
13 | (up to numerical roundoff). Our package exhibits up to an order of magnitude
14 | speedup over e3nn and competitive performance with NVIDIA's cuEquivariance.
15 |
16 | Here, you can find our API reference, installation instructions,
17 | and troubleshooting guide. We support both NVIDIA and AMD GPUs through
18 | our PyTorch interface, including support for JITScript compilation accessible
19 | from C++.
20 |
21 | .. toctree::
22 | :maxdepth: 1
23 | :caption: Contents:
24 |
25 | installation
26 | api
27 | supported_ops
28 | tests_and_benchmarks
29 | models
30 |
--------------------------------------------------------------------------------
/docs/installation.rst:
--------------------------------------------------------------------------------
1 | Installation
2 | ==============================
3 |
4 | .. toctree::
5 | :maxdepth: 1
6 | :caption: Contents:
7 |
8 | You need the following to install OpenEquivariance:
9 |
10 | - A Linux system equipped with an NVIDIA / AMD graphics card.
11 | - PyTorch >= 2.4 (>= 2.8 for AOTI and export).
12 | - GCC 9+ and the CUDA / HIP toolkit. The command
13 | ``c++ --version`` should report a version >= 9.0; see below for details on
14 | setting an alternate compiler.
15 |
16 | Installation is one easy command, followed by import verification:
17 |
18 | .. code-block:: bash
19 |
20 | pip install git+https://github.com/PASSIONLab/OpenEquivariance
21 | python -c "import openequivariance"
22 |
23 | The second line triggers a build of the C++ extension we use to compile
24 | kernels, which can take a couple of minutes. Subsequent imports are
25 | much faster since this extension is cached.
26 |
27 |
28 | Compiling the Integrated PyTorch Extension
29 | ------------------------------------------
30 | To support ``torch.compile``, ``torch.export``, and
31 | JITScript, OpenEquivariance needs to compile a C++ extension
32 | tightly integrated with PyTorch. If you see a warning that
33 | this extension could not be compiled, first check:
34 |
35 | .. code-block:: bash
36 |
37 | c++ --version
38 |
39 | To build the extension with an alternate compiler, set the
40 | ``CC`` and ``CXX``
41 | environment variables and retry the import:
42 |
43 | .. code-block:: bash
44 |
45 | export CC=/path/to/your/gcc
46 | export CXX=/path/to/your/g++
47 | python -c "import openequivariance"
48 |
49 | These configuration steps are required only ONCE after
50 | installation (or upgrade) with pip.
51 |
52 | Configurations on Major Platforms
53 | ---------------------------------
54 | OpenEquivariance has been tested on both supercomputers and lab clusters.
55 | Here are some tested environment configuration files. If you use OpenEquivariance
56 | on a widely-used platform, send us a pull request to add your configuration!
57 |
58 | NERSC Perlmutter (NVIDIA A100)
59 | """"""""""""""""""""""""""""""
60 |
61 | .. code-block:: bash
62 | :caption: env.sh (last updated June 2025)
63 |
64 | module load gcc
65 | module load conda
66 |
67 | # Deactivate any base environments
68 | for i in $(seq ${CONDA_SHLVL}); do
69 | conda deactivate
70 | done
71 |
72 | conda activate
73 |
74 |
75 | OLCF Frontier (AMD MI250x)
76 | """"""""""""""""""""""""""
77 | You need to install a HIP-enabled version of PyTorch to use our package.
78 | To do this, follow the steps `here `_.
79 |
80 |
81 | .. code-block:: bash
82 | :caption: env.sh (last updated June 2025)
83 |
84 | module load PrgEnv-gnu/8.6.0
85 | module load miniforge3/23.11.0-0
86 | module load rocm/6.4.0
87 | module load craype-accel-amd-gfx90a
88 |
89 | for i in $(seq ${CONDA_SHLVL}); do
90 | conda deactivate
91 | done
92 |
93 | conda activate
94 | export CC=cc
95 | export CXX=CC
--------------------------------------------------------------------------------
/docs/livebuild.sh:
--------------------------------------------------------------------------------
1 | sphinx-autobuild -M dirhtml . _build --watch ../openequivariance
--------------------------------------------------------------------------------
/docs/models.rst:
--------------------------------------------------------------------------------
1 | Running MACE and Nequip
2 | ==============================
3 |
4 | .. toctree::
5 | :maxdepth: 1
6 | :caption: Contents:
7 |
8 | MACE
9 | ----
10 |
11 | We have modified MACE to use our accelerated kernels instead
12 | of the standard e3nn backend. Here are the steps to replicate
13 | our MACE benchmark:
14 |
15 | 1. Install ``oeq`` and our modified version of MACE via
16 |
17 | .. code-block:: bash
18 |
19 | pip uninstall mace-torch
20 | pip install git+https://github.com/vbharadwaj-bk/mace_oeq_integration.git@oeq_experimental
21 |
22 | 2. Download the ``carbon.xyz`` data file, available at
23 | ``_.
24 |
25 | This graph has 158K edges. With the original e3nn backend, you would need a GPU with 80GB
26 | of memory to run the experiments. ``oeq`` provides a memory-efficient equivariant convolution,
27 | so we expect the test to succeed.
28 |
29 | 3. Benchmark OpenEquivariance:
30 |
31 | .. code-block:: bash
32 |
33 | python tests/mace_driver.py carbon.xyz -o outputs/mace_tests -i oeq
34 |
35 | 4. If you have a GPU with 80GB of memory *or* supply a smaller molecular graph
36 | as the input file, you can run the full benchmark that includes ``e3nn`` and ``cue``:
37 |
38 | .. code-block:: bash
39 |
40 | python tests/mace_driver.py carbon.xyz -o outputs/mace_tests -i e3nn cue oeq
41 |
42 | Nequip
43 | ------
44 | See the
45 | `official Nequip documentation `_
46 | to use OpenEquivariance with Nequip.
47 |
--------------------------------------------------------------------------------
/docs/supported_ops.rst:
--------------------------------------------------------------------------------
1 | Supported Operations
2 | ====================
3 |
4 | .. toctree::
5 | :maxdepth: 1
6 | :caption: Contents:
7 |
8 | .. list-table::
9 | :widths: 50 25 25
10 | :header-rows: 1
11 |
12 | * - Operation
13 | - CUDA
14 | - HIP
15 | * - UVU
16 | - ✅
17 | - ✅
18 | * - UVW
19 | - ✅
20 | - ✅
21 | * - UVU + Convolution
22 | - ✅
23 | - ✅
24 | * - UVW + Convolution
25 | - ✅
26 | - ✅
27 | * - Symmetric Tensor Product
28 | - ✅ (beta)
29 | - ✅ (beta)
30 |
31 | e3nn supports a variety of connection modes for CG tensor products. We support
32 | two that are commonly used in equivariant graph neural networks:
33 | "uvu" and "uvw". Our JIT-compiled kernels should handle:
34 |
35 | 1. Pure "uvu" tensor products, which are most efficient when the input with higher
36 | multiplicities is the first argument. Our results are identical to e3nn when irreps in
37 | the second input have multiplicity 1, and otherwise identical up to a reordering
38 | of the input weights.
39 |
40 | 2. Pure "uvw" tensor products, which are currently more efficient when the input with
41 | higher multiplicities is the first argument. Our results are identical to e3nn up to a reordering
42 | of the input weights.
43 |
44 | Our code includes correctness checks, but the configuration space is large. If you notice
45 | a bug, let us know in a GitHub issue. We'll try our best to correct it or document the problem here.
46 |
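For instance, a pure "uvw" problem can be specified as follows (a sketch in the
style of the README example; the irreps here are arbitrary):

.. code-block:: python

   import openequivariance as oeq

   # One trainable "uvw" instruction per output irrep; every instruction
   # uses the same connection mode, per the restrictions below.
   X_ir = oeq.Irreps("32x1e + 32x2e")
   Y_ir = oeq.Irreps("1x1e")
   Z_ir = oeq.Irreps("16x1e + 16x2e")
   instructions = [(0, 0, 0, "uvw", True), (1, 0, 1, "uvw", True)]
   problem = oeq.TPProblem(X_ir, Y_ir, Z_ir, instructions,
                           shared_weights=False, internal_weights=False)
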
47 | Unsupported Tensor Product Configurations
48 | -----------------------------------------
49 |
50 | We do not (yet) support:
51 |
52 | - Mixing different instruction types in the same tensor product.
53 | - Instruction types besides "uvu" and "uvw".
54 | - Non-trainable instructions: all of your instructions must have weights associated.
55 |
56 | If you have a use case for any of the unsupported features above, let us know.
57 |
58 | Compilation with JITScript, Export, and AOTInductor
59 | ---------------------------------------------------
60 |
61 | OpenEquivariance supports model compilation with
62 | ``torch.compile``, JITScript, ``torch.export``, and AOTInductor.
63 | Demo the C++ model exports with
64 |
65 | .. code-block:: bash
66 |
67 | pytest tests/export_test.py
68 |
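As a sketch, ``torch.compile`` wraps our modules directly (``problem`` here is
any valid :py:class:`openequivariance.TPProblem`, with inputs shaped as in the
README example):

.. code-block:: python

   import torch
   import openequivariance as oeq

   tp = oeq.TensorProduct(problem, torch_op=True)
   tp_compiled = torch.compile(tp)  # then call tp_compiled(X, Y, W) as usual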
69 |
70 | NOTE: the AOTInductor test (and possibly the export test) fails
71 | unless you are using a nightly
72 | build of PyTorch newer than 4/10/2025, due to incomplete support for
73 | TorchBind in earlier versions.
74 |
75 |
76 | Multiple Devices and Streams
77 | ----------------------------
78 | OpenEquivariance compiles kernels based on the compute capability of the
79 | first visible GPU. On heterogeneous systems, our kernels
80 | will only execute correctly on devices that share the compute capability
81 | of this first device.
82 |
83 | We are working on support for CUDA streams!
84 |
85 |
86 | Symmetric Contraction (Beta)
87 | ----------------------------
88 |
89 | We have recently added beta support for symmetric
90 | contraction acceleration. This primitive:
91 |
92 | - Is specific to MACE.
93 | - Requires e3nn as a dependency.
94 | - Currently has no support for compile / export.
95 |
96 | As a result, we do not expose it in the package
97 | toplevel. You can use our implementation by running
98 |
99 | .. code-block::
100 |
101 | from openequivariance.implementations.symmetric_contraction import SymmetricContraction as OEQSymmetricContraction
--------------------------------------------------------------------------------
/docs/tests_and_benchmarks.rst:
--------------------------------------------------------------------------------
1 | Tests and Benchmarks
2 | ==============================
3 |
4 | .. toctree::
5 | :maxdepth: 1
6 | :caption: Contents:
7 |
8 | OpenEquivariance is equipped with a comprehensive suite of tests
9 | and benchmarking utilities. You'll need some additional dependencies to run
10 | these; we provide instructions below.
11 |
12 | We recommend you clone our repository and use an editable install to run tests
13 | and benchmarks. You can still test our code with a non-editable install; just
14 | download the test folder and install only the dependencies with:
15 |
16 | .. code-block:: bash
17 |
18 | pip install "https://github.com/PASSIONLab/OpenEquivariance[dev]" --only-deps
19 | pip install "https://github.com/PASSIONLab/OpenEquivariance[bench]" --only-deps
20 |
21 | Correctness
22 | ------------------------------
23 | To set up the editable install and run the entire testsuite, use:
24 |
25 | .. code-block:: bash
26 |
27 | git clone https://github.com/PASSIONLab/OpenEquivariance
28 | cd OpenEquivariance
29 | pip install -e .[dev]
30 | pytest
31 |
32 | Browse the ``tests`` directory to run specific components.
33 |
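For example, to run only the batched tensor product or convolution tests:

.. code-block:: bash

   pytest tests/batch_test.py
   pytest tests/conv_test.py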
34 |
35 | Replicating our Benchmarks
36 | ------------------------------
37 | We conducted our benchmarks on an NVIDIA A100-SXM-80GB GPU at Lawrence Berkeley National Laboratory.
38 | Your results may differ on a different GPU. The following invocations run the experiments
39 | and generate plots from our paper.
40 |
41 | .. code-block:: bash
42 |
43 | git clone https://github.com/PASSIONLab/OpenEquivariance
44 | cd OpenEquivariance
45 | pip install -e .[bench]
46 | python tests/benchmark.py -o outputs/uvu uvu --plot
47 | python tests/benchmark.py -o outputs/uvw uvw --plot
48 | python tests/benchmark.py -o outputs/roofline roofline --plot
49 | python tests/benchmark.py -o outputs/conv conv --plot --data data/molecular_structures
50 | python tests/benchmark.py -o outputs/kahan_conv kahan_conv --data data/molecular_structures/
51 |
52 | If your GPU has limited memory, try the ``--limited-memory`` flag
53 | to disable some expensive tests and / or reduce the batch size with ``-b``.
54 | Run ``python tests/benchmark.py --help`` for a full list of flags.
55 |
56 | For example, here's a set of invocations for an NVIDIA A5000 GPU:
57 |
58 | .. code-block:: bash
59 |
60 | python tests/benchmark.py -o outputs/uvu uvu --limited-memory --plot
61 | python tests/benchmark.py -o outputs/uvw uvw -b 25000 --plot
62 | python tests/benchmark.py -o outputs/roofline roofline --plot
63 | python tests/benchmark.py -o outputs/conv conv --data data/molecular_structures --limited-memory
64 |
65 | For GPUs besides the NVIDIA A100, the roofline slope / peak will be incorrect.
66 | The plots for the convolution fusion experiments also require a GPU
67 | with a minimum of 40GB of memory.
68 |
69 | List of GPUs Tested
70 | --------------------------------
71 | OpenEquivariance has been tested successfully on the following GPUs. Submit a pull
72 | request if you'd like to add your own!
73 |
74 | - NVIDIA A100-SXM-40GB and A100-SXM-80GB (A. Glover, NERSC Perlmutter, June 2025)
75 | - NVIDIA A5000 (V. Bharadwaj, UCB SLICE, June 2025)
76 | - NVIDIA V100 (V. Bharadwaj, LBNL Einsteinium, June 2025)
77 | - AMD MI250x (V. Bharadwaj, OLCF Frontier, June 2025)
78 | - AMD MI300x (V. Bharadwaj, AMD Cloud, February 2025)
--------------------------------------------------------------------------------
/io/cif_to_graph.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import numpy as np
3 | from sklearn.neighbors import radius_neighbors_graph
4 |
5 |
6 | def cif_to_molecular_graph(cif_file, cp, radii):
7 | with open(f"../data/cif_files/{cif_file}", "r") as f:
8 | print("Started reading file...")
9 | lines = f.readlines()
10 | print("Finished reading file!")
11 |
12 | coords = []
13 | for line in lines:
14 | if line.startswith("ATOM"):
15 | parts = line.split()
16 | coords.append(
17 | [float(parts[cp[0]]), float(parts[cp[1]]), float(parts[cp[2]])]
18 | )
19 |
20 | coords = np.array(coords)
21 |
22 | for radius in radii:
23 | print(f"Starting radius neighbors calculation, r={radius}")
24 | A = radius_neighbors_graph(
25 | coords, radius, mode="connectivity", include_self=False
26 | )
27 | print(f"Finished radius neighbors calculation, found {A.nnz} nonzeros.")
28 |
29 | # mmwrite(f'../data/molecular_structures/{cif_file.split(".")[0]}.mtx', A)
30 |
31 | coo_mat = A.tocoo()
32 | result = {"row": coo_mat.row, "col": coo_mat.col, "coords": coords}
33 |
34 | with open(
35 | f"../data/molecular_structures/{cif_file.split('.')[0]}_radius{radius}.pickle",
36 | "wb",
37 | ) as handle:
38 | pickle.dump(result, handle, protocol=pickle.HIGHEST_PROTOCOL)
39 |
40 |
41 | if __name__ == "__main__":
42 | # cif_to_molecular_graph('hiv_capsid.cif', (10, 11, 12), radii=[2.0, 2.5, 3.0, 3.5])
43 | cif_to_molecular_graph("covid_spike.cif", (10, 11, 12), radii=[2.0, 2.5, 3.0, 3.5])
44 | cif_to_molecular_graph("1drf.cif", (10, 11, 12), radii=[6.0])
45 |
--------------------------------------------------------------------------------
/io/load_nequip_configs.py:
--------------------------------------------------------------------------------
1 | """
2 | This script parses the repository of
3 | Nequip input files at
4 | https://github.com/mir-group/nequip-input-files.
5 | We extract the node / edge hidden feature representations.
6 | """
7 |
8 | import os
9 | import yaml
10 |
11 |
12 | def process_nequip_configs():
13 | nequip_files = []
14 | for root, dirs, files in os.walk("../data/nequip-input-files"):
15 | for file in files:
16 | if file.endswith(".yaml"):
17 | nequip_files.append(os.path.join(root, file))
18 |
19 | irrep_pairs = []
20 | configs = []
21 | for file in nequip_files:
22 | with open(file, "r") as f:
23 | data = yaml.unsafe_load(f)
24 | filename = os.path.splitext(os.path.basename(file))[0]
25 | feature_irreps_hidden = data["feature_irreps_hidden"]
26 | irreps_edge_sh = data["irreps_edge_sh"]
27 | if (feature_irreps_hidden, irreps_edge_sh) not in irrep_pairs:
28 | irrep_pairs.append((feature_irreps_hidden, irreps_edge_sh))
29 | configs.append((feature_irreps_hidden, irreps_edge_sh, filename))
30 |
31 | for config in configs:
32 | print(config)
33 |
34 |
35 | if __name__ == "__main__":
36 | process_nequip_configs()
37 |
--------------------------------------------------------------------------------
/openequivariance/__init__.py:
--------------------------------------------------------------------------------
1 | # ruff: noqa: F401
2 | import sys
3 | import openequivariance.extlib
4 | from pathlib import Path
5 | from importlib.metadata import version
6 |
7 | from openequivariance.implementations.e3nn_lite import TPProblem, Irreps
8 | from openequivariance.implementations.TensorProduct import TensorProduct
9 | from openequivariance.implementations.convolution.TensorProductConv import (
10 | TensorProductConv,
11 | )
12 | from openequivariance.implementations.utils import torch_to_oeq_dtype
13 |
14 | __version__ = None
15 | try:
16 | __version__ = version("openequivariance")
17 | except Exception as e:
18 | print(f"Warning: Could not determine oeq version: {e}", file=sys.stderr)
19 |
20 |
21 | def _check_package_editable():
22 | import json
23 | from importlib.metadata import Distribution
24 |
25 | direct_url = Distribution.from_name("openequivariance").read_text("direct_url.json")
26 | return json.loads(direct_url).get("dir_info", {}).get("editable", False)
27 |
28 |
29 | _editable_install_output_path = Path(__file__).parent.parent / "outputs"
30 |
31 |
32 | def torch_ext_so_path():
33 | """
34 | :returns: Path to a ``.so`` file that must be linked to use OpenEquivariance
35 | from the PyTorch C++ Interface.
36 | """
37 | return openequivariance.extlib.torch_module.__file__
38 |
39 |
40 | __all__ = [
41 | "TPProblem",
42 | "Irreps",
43 | "TensorProduct",
44 | "TensorProductConv",
45 | "torch_to_oeq_dtype",
46 | "_check_package_editable",
47 | "torch_ext_so_path",
48 | ]
49 |
--------------------------------------------------------------------------------
/openequivariance/benchmark/ConvBenchmarkSuite.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import time
4 | import pickle
5 | import pathlib
6 | import numpy as np
7 |
8 | import openequivariance as oeq
9 | from openequivariance.benchmark.logging_utils import getLogger
10 | from openequivariance.implementations.convolution.ConvolutionBase import CoordGraph
11 |
12 | logger = getLogger()
13 |
14 |
15 | def load_graph(filename):
16 | coords, rows, cols = [None] * 3
17 | name = pathlib.Path(filename).stem
18 | with open(filename, "rb") as f:
19 | logger.info(f"Loading {name} from pickle...")
20 | result = pickle.load(f)
21 | coords, rows, cols, name = result["coords"], result["row"], result["col"], name
22 | logger.info(
23 | f"Graph {name} loaded with {len(coords)} nodes and {len(rows)} edges."
24 | )
25 |
26 | return CoordGraph(coords, rows.astype(np.int64), cols.astype(np.int64), name)
27 |
28 |
29 | class ConvBenchmarkSuite:
30 | def __init__(
31 | self,
32 | configs,
33 | num_warmup=10,
34 | num_iter=30,
35 | reference_impl=None,
36 | test_name=None,
37 | prng_seed=12345,
38 | correctness_threshold=1e-5,
39 | ):
40 | self.configs = configs
41 | self.num_warmup = num_warmup
42 | self.num_iter = num_iter
43 | self.reference_impl = reference_impl
44 | self.prng_seed = prng_seed
45 | self.correctness_threshold = correctness_threshold
46 | self.exp_count = 0
47 | self.test_name = test_name
48 |
49 | self.millis_since_epoch = round(time.time() * 1000)
50 |
51 | def run(
52 | self,
53 | graph,
54 | implementations,
55 | direction,
56 | output_folder=None,
57 | correctness=True,
58 | benchmark=True,
59 | high_precision_ref=False,
60 | ):
61 | if output_folder is None:
62 | if oeq._check_package_editable():
63 | output_folder = (
64 | oeq._editable_install_output_path / f"{self.millis_since_epoch}"
65 | )
66 | else:
67 | raise ValueError(
68 | "output folder must be specified for non-editable installs."
69 | )
70 | else:
71 | output_folder = pathlib.Path(output_folder)
72 | output_folder.mkdir(parents=True, exist_ok=True)
73 |
74 | metadata = {
75 | "test_name": self.test_name,
76 | "configs": [str(config) for config in self.configs],
77 | "implementations": [impl.name() for impl in implementations],
78 | "graph": graph.name,
79 | }
80 | if self.exp_count == 0:
81 | with open(os.path.join(output_folder, "metadata.json"), "w") as f:
82 | json.dump(metadata, f, indent=2)
83 |
84 | for config in self.configs:
85 | for impl in implementations:
86 | tc_name = f"{config}, {impl.name()}"
87 | logger.info(f"Starting {tc_name}, graph {graph.name}, {direction}")
88 | conv = impl(config)
89 |
90 | if direction == "forward":
91 | if correctness:
92 | correctness = conv.test_correctness_forward(
93 | graph,
94 | thresh=self.correctness_threshold,
95 | prng_seed=self.prng_seed,
96 | reference_implementation=self.reference_impl,
97 | high_precision_ref=high_precision_ref,
98 | )
99 |
100 | if benchmark:
101 | benchmark = conv.benchmark_forward(
102 | self.num_warmup, self.num_iter, graph, prng_seed=self.prng_seed
103 | )
104 |
105 | if direction == "backward":
106 | if correctness:
107 | correctness = conv.test_correctness_backward(
108 | graph,
109 | thresh=self.correctness_threshold,
110 | prng_seed=self.prng_seed,
111 | reference_implementation=self.reference_impl,
112 | high_precision_ref=high_precision_ref,
113 | )
114 |
115 | if benchmark:
116 | benchmark = conv.benchmark_backward(
117 | self.num_warmup, self.num_iter, graph, prng_seed=self.prng_seed
118 | )
119 |
120 | if direction == "double_backward":
121 | if correctness:
122 | correctness = conv.test_correctness_double_backward(
123 | graph,
124 | thresh=self.correctness_threshold,
125 | prng_seed=self.prng_seed,
126 | reference_implementation=self.reference_impl,
127 | high_precision_ref=high_precision_ref,
128 | )
129 |
130 | assert not benchmark
131 |
132 | result = {
133 | "config": str(config),
134 | "irrep_dtype": str(config.irrep_dtype),
135 | "weight_dtype": str(config.weight_dtype),
136 | "torch_overhead_included": conv.torch_op,
137 | "direction": direction,
138 | "graph": graph.name,
139 | "name": impl.name(),
140 | "correctness": correctness,
141 | "benchmark": benchmark,
142 | }
143 |
144 | fname = pathlib.Path(
145 | f"{output_folder}/{self.exp_count}_{impl.name()}_{graph.name}.json"
146 | )
147 | with open(fname, "w") as f:
148 | json.dump(result, f, indent=2)
149 | self.exp_count += 1
150 |
151 | logger.info(f"Finished {tc_name}, graph {graph.name}")
152 |
153 | return output_folder
154 |
--------------------------------------------------------------------------------
/openequivariance/benchmark/correctness_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union
2 |
3 | from openequivariance.implementations.TensorProductBase import TensorProductBase
4 | from openequivariance.implementations.CUETensorProduct import CUETensorProduct
5 | from openequivariance.implementations.e3nn_lite import TPProblem
6 | from openequivariance.benchmark.random_buffer_utils import (
7 | get_random_buffers_forward,
8 | get_random_buffers_backward,
9 | )
10 | from openequivariance.benchmark.logging_utils import getLogger, bcolors
11 | import numpy as np
12 | import numpy.linalg as la
13 |
14 | logger = getLogger()
15 |
16 |
17 | def check_similiarity(
18 | name: str,
19 | to_check: np.ndarray,
20 | ground_truth: np.ndarray,
21 | correctness_threshold: float,
22 | ):
23 | result = {}
24 | if to_check.shape != ground_truth.shape:
25 | result["shape_match"] = False
26 | result["diff_Linf_norm"] = np.inf
27 | result["pass"] = False
28 | logger.error(
29 | f"{bcolors.FAIL}Ground truth {name} shape does not match input! {to_check.shape=}, {ground_truth.shape=} {bcolors.ENDC}"
30 | )
31 | else:
32 | result["shape_match"] = True
33 | diff_Linf_norm = float(la.norm((ground_truth - to_check).flatten(), ord=np.inf))
34 | result["diff_Linf_norm"] = diff_Linf_norm
35 | result["pass"] = bool(diff_Linf_norm < correctness_threshold)
36 | if result["pass"]:
37 | logger.info(
38 | f" {bcolors.OKGREEN}{name} correctness check pass. {diff_Linf_norm=:.3e}, {correctness_threshold=} {bcolors.ENDC}"
39 | )
40 | else:
41 | logger.error(
42 | f"{bcolors.FAIL}{name} correctness check fail! {diff_Linf_norm=:.3e}, {correctness_threshold=} {bcolors.ENDC}"
43 | )
44 |
45 | return result
46 |
47 |
48 | def instantiate_implementation(
49 | implementation: Union[type[TensorProductBase], TensorProductBase],
50 | problem: TPProblem,
51 | ):
52 | if isinstance(implementation, type):
53 | test_tp = implementation(problem)
54 | else:
55 | test_tp = implementation
56 |
57 | if not isinstance(test_tp, TensorProductBase):
58 | raise TypeError(
59 | f"test_implementation must be a TensorProductBase or a subclass, got {type(implementation)}"
60 | )
61 |
62 | return test_tp
63 |
64 |
65 | def correctness_forward(
66 | problem: TPProblem,
67 | test_implementation: Union[type[TensorProductBase], TensorProductBase],
68 | reference_implementation: Optional[type[TensorProductBase]],
69 | batch_size: int,
70 | correctness_threshold: float,
71 | prng_seed: int,
72 | ) -> dict:
73 | if reference_implementation is None:
74 | from openequivariance.implementations.E3NNTensorProduct import E3NNTensorProduct
75 |
76 | reference_implementation = E3NNTensorProduct
77 |
78 | result = {"thresh": correctness_threshold, "batch_size": batch_size}
79 |
80 | in1, in2, weights, out = get_random_buffers_forward(problem, batch_size, prng_seed)
81 |
82 | # run reference
83 | ref_tp = reference_implementation(problem)
84 |
85 | ref_out = out.copy()
86 | ref_tp.forward_cpu(
87 | L1_in=in1.copy(), L2_in=in2.copy(), L3_out=ref_out, weights=weights.copy()
88 | )
89 |
90 | weights_copy = weights.copy()
91 | if problem.shared_weights and test_implementation == CUETensorProduct:
92 | weights_copy = weights[np.newaxis, :]
93 |
94 | # run test
95 | test_tp = instantiate_implementation(test_implementation, problem)
96 | test_out = out.copy()
97 | test_tp.forward_cpu(
98 | L1_in=in1.copy(), L2_in=in2.copy(), L3_out=test_out, weights=weights_copy
99 | )
100 |
101 | for name, to_check, ground_truth in [("output", ref_out, test_out)]:
102 | result[name] = check_similiarity(
103 | name, to_check, ground_truth, correctness_threshold
104 | )
105 |
106 | return result
107 |
108 |
109 | def correctness_backward(
110 | problem: TPProblem,
111 | test_implementation: Union[type[TensorProductBase], TensorProductBase],
112 | reference_implementation: Optional[type[TensorProductBase]],
113 | batch_size: int,
114 | correctness_threshold: float,
115 | prng_seed: int,
116 | ) -> dict:
117 | if reference_implementation is None:
118 | from openequivariance.implementations.E3NNTensorProduct import E3NNTensorProduct
119 |
120 | reference_implementation = E3NNTensorProduct
121 |
122 | result = {"thresh": correctness_threshold, "batch_size": batch_size}
123 |
124 | # run reference
125 | in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad = (
126 | get_random_buffers_backward(problem, batch_size, prng_seed)
127 | )
128 |
129 | ref_tp = reference_implementation(problem)
130 |
131 | ref_weights_grad = weights_grad.copy()
132 | ref_in1_grad = in1_grad.copy()
133 | ref_in2_grad = in2_grad.copy()
134 |
135 | ref_tp.backward_cpu(
136 | L1_in=in1.copy(),
137 | L1_grad=ref_in1_grad,
138 | L2_in=in2.copy(),
139 | L2_grad=ref_in2_grad,
140 | L3_grad=out_grad.copy(),
141 | weights=weights.copy(),
142 | weights_grad=ref_weights_grad,
143 | )
144 |
145 | # run test version
146 | test_weights_grad = weights_grad.copy()
147 | test_in1_grad = in1_grad.copy()
148 | test_in2_grad = in2_grad.copy()
149 |
150 | weights_copy = weights.copy()
151 |
152 | if problem.shared_weights and test_implementation == CUETensorProduct:
153 | weights_copy = weights[np.newaxis, :]
154 | test_weights_grad = test_weights_grad[np.newaxis, :]
155 |
156 | test_tp = instantiate_implementation(test_implementation, problem)
157 | test_tp.backward_cpu(
158 | L1_in=in1.copy(),
159 | L1_grad=test_in1_grad,
160 | L2_in=in2.copy(),
161 | L2_grad=test_in2_grad,
162 | L3_grad=out_grad.copy(),
163 | weights=weights_copy,
164 | weights_grad=test_weights_grad,
165 | )
166 |
167 | weight_threshold = (
168 | correctness_threshold * batch_size
169 | if problem.shared_weights
170 | else correctness_threshold
171 | )
172 |
173 | if problem.shared_weights:
174 | test_weights_grad = test_weights_grad.squeeze()
175 |
176 | for name, to_check, ground_truth, threshold in [
177 | ("weight_grad", test_weights_grad, ref_weights_grad, weight_threshold),
178 | ("in1_grad", test_in1_grad, ref_in1_grad, correctness_threshold),
179 | ("in2_grad", test_in2_grad, ref_in2_grad, correctness_threshold),
180 | ]:
181 | result[name] = check_similiarity(name, to_check, ground_truth, threshold)
182 |
183 | return result
184 |
185 |
186 | def correctness_double_backward(
187 | problem: TPProblem,
188 | test_implementation: Union[type[TensorProductBase], TensorProductBase],
189 | reference_implementation: Optional[type[TensorProductBase]],
190 | batch_size: int,
191 | correctness_threshold: float,
192 | prng_seed: int,
193 | ):
194 | global torch
195 | import torch
196 |
197 | in1, in2, out_grad, weights, _, _, _ = get_random_buffers_backward(
198 | problem, batch_size, prng_seed
199 | )
200 | rng = np.random.default_rng(seed=prng_seed * 2)
201 | dummy_grad = rng.standard_normal(1)[0]
202 |
203 | if reference_implementation is None:
204 | from openequivariance.implementations.E3NNTensorProduct import E3NNTensorProduct
205 |
206 | reference_implementation = E3NNTensorProduct
207 |
208 | result = {"thresh": correctness_threshold, "batch_size": batch_size}
209 |
210 | tensors = []
211 | for i, impl in enumerate([test_implementation, reference_implementation]):
212 | tp = instantiate_implementation(impl, problem)
213 |
214 | if impl == CUETensorProduct and problem.shared_weights:
215 | weights = weights[np.newaxis, :]
216 |
217 | weights_reordered = np.zeros_like(weights)
218 | if tp.reorder_weights_e3nn_to_oeq is not None:
219 | tp.reorder_weights_e3nn_to_oeq(
220 | weights, weights_reordered, not tp.config.shared_weights
221 | )
222 | else:
223 | weights_reordered = weights
224 |
225 | in1_torch = torch.tensor(in1, device="cuda", requires_grad=True)
226 | in2_torch = torch.tensor(in2, device="cuda", requires_grad=True)
227 | weights_torch = torch.tensor(
228 | weights_reordered, device="cuda", requires_grad=True
229 | )
230 |
231 | out_torch = tp.forward(in1_torch, in2_torch, weights_torch)
232 | out_grad = out_torch.clone().detach().to(device="cuda").requires_grad_(True)
233 |
234 | in1_grad, in2_grad, w_grad = torch.autograd.grad(
235 | outputs=[out_torch],
236 | inputs=[in1_torch, in2_torch, weights_torch],
237 | grad_outputs=[out_grad],
238 | create_graph=True,
239 | )
240 |
241 | dummy = torch.norm(in1_grad) + torch.norm(in2_grad) + torch.norm(w_grad)
242 | dummy_grad = torch.tensor(float(dummy_grad), device="cuda", requires_grad=True)
243 |
244 | dummy.backward(
245 | dummy_grad,
246 | retain_graph=True,
247 | inputs=[out_grad, in1_torch, in2_torch, weights_torch],
248 | )
249 |
250 | weights_grad = weights_torch.grad.detach().cpu().numpy()
251 | if tp.reorder_weights_oeq_to_e3nn is not None:
252 | weights_grad_copy = weights_grad.copy()
253 | tp.reorder_weights_oeq_to_e3nn(
254 | weights_grad_copy, weights_grad, not tp.config.shared_weights
255 | )
256 |
257 | tensors.append(
258 | (
259 | out_grad.grad.detach().cpu().numpy(),
260 | in1_torch.grad.detach().cpu().numpy(),
261 | in2_torch.grad.detach().cpu().numpy(),
262 | weights_grad,
263 | )
264 | )
265 |
266 | for name, to_check, ground_truth in [
267 | ("output_double_grad", tensors[0][0], tensors[1][0]),
268 | ("in1_grad", tensors[0][1], tensors[1][1]),
269 | ("in2_grad", tensors[0][2], tensors[1][2]),
270 | ("weights_grad", tensors[0][3], tensors[1][3]),
271 | ]:
272 | result[name] = check_similiarity(
273 | name, to_check, ground_truth, correctness_threshold
274 | )
275 |
276 | return result
277 |
--------------------------------------------------------------------------------
/openequivariance/benchmark/logging_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | logger = logging.getLogger("ETP")
4 | logger.setLevel(logging.CRITICAL)
5 | ch = logging.StreamHandler()
6 | formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
7 | ch.setFormatter(formatter)
8 | logger.addHandler(ch)
9 |
10 |
11 | def getLogger():
12 | return logger
13 |
14 |
15 | class bcolors:
16 | HEADER = "\033[95m"
17 | OKBLUE = "\033[94m"
18 | OKCYAN = "\033[96m"
19 | OKGREEN = "\033[92m"
20 | WARNING = "\033[93m"
21 | FAIL = "\033[91m"
22 | ENDC = "\033[0m"
23 | BOLD = "\033[1m"
24 | UNDERLINE = "\033[4m"
25 |
--------------------------------------------------------------------------------
/openequivariance/benchmark/perf_metrics_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | from openequivariance.implementations.utils import (
4 | count_cg_non_zero,
5 | sparse_outer_product_work,
6 | )
7 | from openequivariance.implementations.TensorProductBase import TensorProductBase
8 | from openequivariance.implementations.e3nn_lite import TPProblem
9 | from openequivariance.benchmark.logging_utils import getLogger
10 | import numpy as np
11 |
12 | logger = getLogger()
13 |
14 |
15 | def calculate_minimum_memory_streamed_forward(
16 | tpp: TPProblem, batch_size: int
17 | ) -> dict[str, int]:
18 | """
19 | This represents an absolute minimum amount of memory that could be streamed on an ideal machine
20 | It returns the number of bytes streamed total and from each source
21 | """
22 | data_size = {}
23 | irrep_word_size = np.dtype(tpp.irrep_dtype).itemsize
24 | weight_word_size = np.dtype(tpp.weight_dtype).itemsize
25 |
26 | data_size["input 1"] = tpp.irreps_in1.dim * batch_size * irrep_word_size
27 | data_size["input 2"] = tpp.irreps_in2.dim * batch_size * irrep_word_size
28 | data_size["output"] = tpp.irreps_out.dim * batch_size * irrep_word_size
29 | data_size["weights"] = tpp.weight_numel * batch_size * weight_word_size
30 | data_size["total"] = sum(data_size.values())
31 | return data_size
32 |
33 |
34 | def calculate_minimum_memory_streamed_backward(tpp: TPProblem, batch_size: int) -> dict:
35 | """
36 | This represents an absolute minimum amount of memory that could be streamed on an ideal machine
37 | It returns the number of bytes streamed total and from each source
38 | """
39 | data_size = {}
40 | irrep_word_size = np.dtype(tpp.irrep_dtype).itemsize
41 | weight_word_size = np.dtype(tpp.weight_dtype).itemsize
42 |
43 | data_size["input 1"] = tpp.irreps_in1.dim * batch_size * irrep_word_size
44 | data_size["input 1 grad"] = tpp.irreps_in1.dim * batch_size * irrep_word_size
45 | data_size["input 2"] = tpp.irreps_in2.dim * batch_size * irrep_word_size
46 | data_size["input 2 grad"] = tpp.irreps_in2.dim * batch_size * irrep_word_size
47 | data_size["output grad"] = tpp.irreps_out.dim * batch_size * irrep_word_size
48 | data_size["weights"] = tpp.weight_numel * batch_size * weight_word_size
49 | data_size["weights grad"] = tpp.weight_numel * batch_size * weight_word_size
50 | data_size["total"] = sum(data_size.values())
51 | return data_size
52 |
53 |
54 | def calculate_minimum_flops_forward(tpp: TPProblem, batch_size: int) -> dict:
55 | """
56 | This is not actually calculating the minimum value.
57 | Ideally you might share the outer product values of the two inputs across multiple instructions.
58 | This is assuming that you form those values and reuse them once per CG decomp.
59 | """
60 | logger.warning("Minimum flops Calculation is not the true minimum")
61 | flops_count = {}
62 | flops_count["outer_products"] = 0
63 | flops_count["CG_decomposition"] = 0
64 | flops_count["linear_combination"] = 0
65 | for ins in tpp.instructions: # type : Instruction
66 | l1, l2, l3 = (
67 | tpp.irreps_in1[ins.i_in1].ir.l,
68 | tpp.irreps_in2[ins.i_in2].ir.l,
69 | tpp.irreps_out[ins.i_out].ir.l,
70 | )
71 |
72 | flops_count["outer_products"] += sparse_outer_product_work(
73 | TensorProductBase.load_cg_tensor(l1, l2, l3)
74 | )
75 | flops_count["CG_decomposition"] += count_cg_non_zero(l1, l2, l3) * (
76 | ins.path_shape[0] * ins.path_shape[1]
77 | )
78 | flops_count["linear_combination"] += (
79 | (2 * l3 + 1) * math.prod(ins.path_shape) if ins.has_weight else 0
80 | )
81 |
82 | flops_count["outer_products"] *= batch_size
83 | flops_count["CG_decomposition"] *= 2 * batch_size
84 | flops_count["linear_combination"] *= 2 * batch_size
85 |
86 | flops_count["total"] = sum(flops_count.values())
87 | return flops_count
88 |
89 |
90 | def calculate_minimum_flops_backward(tpp: TPProblem, batch_size: int) -> dict:
91 | """
92 |     This is not actually calculating the minimum value.
93 |     Ideally you might share the outer-product values of an input pair across multiple instructions.
94 |     This assumes that you form those values once and reuse them once per CG decomposition.
95 | """
96 | raise NotImplementedError("this needs to be implemented properly")
97 |
--------------------------------------------------------------------------------
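Usage sketch for the estimators above; the problem instance comes from
mace_problems() in openequivariance/benchmark/problems.py (shown later in this
listing), and the printed figures are illustrative only:

    from openequivariance.benchmark.problems import mace_problems
    from openequivariance.benchmark.perf_metrics_utils import (
        calculate_minimum_memory_streamed_forward,
        calculate_minimum_flops_forward,
    )

    tpp = mace_problems()[0]  # "mace-large"
    mem = calculate_minimum_memory_streamed_forward(tpp, batch_size=50_000)
    flops = calculate_minimum_flops_forward(tpp, batch_size=50_000)

    print(f"{mem['total'] / 1e9:.2f} GB streamed (lower bound)")
    print(f"{flops['total'] / 1e9:.2f} GFLOPs")
    print(f"~{flops['total'] / mem['total']:.2f} FLOPs / byte")
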
/openequivariance/benchmark/plotting/__init__.py:
--------------------------------------------------------------------------------
1 | from openequivariance.benchmark.plotting.plot_uvu import plot_uvu
2 | from openequivariance.benchmark.plotting.plot_uvw import plot_uvw
3 | from openequivariance.benchmark.plotting.plot_roofline import plot_roofline
4 | from openequivariance.benchmark.plotting.plot_convolution import plot_convolution
5 | from openequivariance.benchmark.plotting.plot_double_backward import (
6 | plot_double_backward,
7 | )
8 |
9 | __all__ = [
10 | "plot_uvu",
11 | "plot_uvw",
12 | "plot_roofline",
13 | "plot_convolution",
14 | "plot_double_backward",
15 | ]
16 |
--------------------------------------------------------------------------------
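Each entry point takes a folder of benchmark results produced by the suite and
writes a PDF back into that folder. A typical driver (folder names are
hypothetical):

    from openequivariance.benchmark.plotting import plot_uvu, plot_roofline

    plot_uvu("outputs/uvu_batch")      # writes throughput_comparison.pdf
    plot_roofline("outputs/roofline")  # writes roofline.pdf
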
/openequivariance/benchmark/plotting/plot_convolution.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import pathlib
4 | from openequivariance.benchmark.plotting.plotting_utils import (
5 | set_grid,
6 | colormap,
7 | labelmap,
8 | hatchmap,
9 | dtypes,
10 | directions,
11 | dtype_labelmap,
12 | grouped_barchart,
13 |     load_benchmarks, filter,  # 'filter' assumed exported by plotting_utils (not the builtin)
14 | )
15 |
16 |
17 | def plot_convolution(data_folder):
18 | data_folder = pathlib.Path(data_folder)
19 | benchmarks, metadata = load_benchmarks(data_folder)
20 |
21 | implementations = [
22 | "CUEConvolution",
23 | "CUEConvolutionFused",
24 | "LoopUnrollConvScatterSum",
25 | "LoopUnrollConvAtomic",
26 | "LoopUnrollConvDeterministic",
27 | ]
28 |
29 | graphs = ["1drf_radius6.0", "covid_spike_radius3.0", "carbon_lattice_radius6.0"]
30 | graph_lmap = {
31 | "covid_spike_radius3.0": "COVID spike",
32 | "1drf_radius6.0": "DHFR",
33 | "carbon_lattice_radius6.0": "carbon-lattice",
34 | }
35 |
36 | data = {}
37 |
38 | for direction in directions:
39 | data[direction] = {}
40 | for dtype in dtypes:
41 | data[direction][dtype] = {}
42 | for graph in graphs:
43 | data[direction][dtype][graph_lmap[graph]] = {}
44 | for impl in implementations:
45 | exp = filter(
46 | benchmarks,
47 | {
48 | "graph": graph,
49 | "direction": direction,
50 | "name": impl,
51 | "irrep_dtype": dtype,
52 | },
53 | match_one=True,
54 | )
55 |
56 | data[direction][dtype][graph_lmap[graph]][labelmap[impl]] = np.mean(
57 | exp["benchmark"]["time_millis"]
58 | )
59 |
60 | fig = plt.figure(figsize=(5, 5))
61 | gs = fig.add_gridspec(2, 2, hspace=0, wspace=0)
62 | axes = gs.subplots(sharex="col", sharey="row")
63 |
64 | for i, direction in enumerate(directions):
65 | for j, dtype in enumerate(dtypes):
66 | for k, graph in enumerate(graphs):
67 | normalizing_value = data[direction][dtype][graph_lmap[graph]][
68 | "cuE-scattersum"
69 | ]
70 | for impl in implementations:
71 | data[direction][dtype][graph_lmap[graph]][labelmap[impl]] = (
72 | normalizing_value
73 | / data[direction][dtype][graph_lmap[graph]][labelmap[impl]]
74 | )
75 |
76 | grouped_barchart(
77 | data[direction][dtype],
78 | axes[i][j],
79 | bar_height_fontsize=0,
80 | rotate_xlabels=True,
81 | colormap=colormap,
82 | hatchmap=hatchmap,
83 | group_spacing=6.0,
84 | )
85 |
86 | axes[i][j].set_xlabel(dtype_labelmap[dtype])
87 | axes[i][j].set_ylabel(direction)
88 | axes[i][j].axhline(1.0, ls="--", c=colormap["cuE"])
89 | set_grid(axes[i][j])
90 |
91 | axes[1][0].set_ylim(0, 3.8)
92 | for ax in fig.get_axes():
93 | ax.label_outer()
94 |
95 | fig.supylabel("Speedup over cuE-scattersum", x=0.025, y=0.6)
96 |
97 | handles, labels = axes[0][0].get_legend_handles_labels()
98 | for i, l in enumerate(labels):
99 | if "fast" in l:
100 | labels[i] += " (ours)"
101 |
102 | unique = [
103 | (h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i]
104 | ]
105 | fig.legend(*zip(*unique), loc="upper center", bbox_to_anchor=(0.55, 0.01))
106 |
107 | fig.show()
108 | fig.tight_layout()
109 | fig.savefig(str(data_folder / "kernel_fusion_speedup.pdf"), bbox_inches="tight")
110 |
--------------------------------------------------------------------------------
/openequivariance/benchmark/plotting/plot_double_backward.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import pathlib
4 | from openequivariance.benchmark.plotting.plotting_utils import (
5 | set_grid,
6 | colormap,
7 | labelmap,
8 | grouped_barchart,
9 |     load_benchmarks, filter,  # 'filter' assumed exported by plotting_utils (not the builtin)
10 | )
11 |
12 |
13 | def plot_double_backward(data_folder):
14 | data_folder = pathlib.Path(data_folder)
15 | benchmarks, metadata = load_benchmarks(data_folder)
16 |
17 | configs = metadata["config_labels"]
18 | implementations = ["E3NNTensorProduct", "CUETensorProduct", "LoopUnrollTP"]
19 |
20 | def calculate_tp_per_sec(exp):
21 | return exp["benchmark results"]["batch_size"] / (
22 | np.mean(exp["benchmark results"]["time_millis"]) * 0.001
23 | )
24 |
25 | dataf32 = {"double_backward": {}}
26 | for i, desc in enumerate(configs):
27 | for direction in ["double_backward"]:
28 | dataf32[direction][desc] = {}
29 | for impl in implementations:
30 | f32_benches = [
31 | b
32 | for b in benchmarks
33 | if b["benchmark results"]["rep_dtype"] == ""
34 | ]
35 | exp = filter(
36 | f32_benches,
37 | {
38 | "config_label": desc,
39 | "direction": direction,
40 | "implementation_name": impl,
41 | },
42 | match_one=True,
43 | )
44 | dataf32[direction][desc][labelmap[impl]] = calculate_tp_per_sec(exp)
45 |
46 | dataf64 = {"double_backward": {}}
47 | for i, desc in enumerate(configs):
48 | for direction in ["double_backward"]:
49 | dataf64[direction][desc] = {}
50 | for impl in implementations:
51 | f64_benches = [
52 | b
53 | for b in benchmarks
54 | if "float64" in b["benchmark results"]["rep_dtype"]
55 | ]
56 |
57 | exp = filter(
58 | f64_benches,
59 | {
60 | "config_label": desc,
61 | "direction": direction,
62 | "implementation_name": impl,
63 | },
64 | match_one=True,
65 | )
66 |
67 | if exp is None:
68 | print(desc)
69 | print(direction)
70 | print(impl)
71 |
72 | dataf64[direction][desc][labelmap[impl]] = calculate_tp_per_sec(exp)
73 |
74 | fig = plt.figure(figsize=(7, 3))
75 | gs = fig.add_gridspec(1, 2, hspace=0, wspace=0.1)
76 | axs = gs.subplots(sharex="col", sharey="row")
77 |
78 | grouped_barchart(
79 | dataf32["double_backward"],
80 | axs[0],
81 | bar_height_fontsize=0,
82 | colormap=colormap,
83 | group_spacing=6.0,
84 | )
85 | grouped_barchart(
86 | dataf64["double_backward"],
87 | axs[1],
88 | bar_height_fontsize=0,
89 | colormap=colormap,
90 | group_spacing=6.0,
91 | )
92 |
93 | for i in range(2):
94 | set_grid(axs[i])
95 |
96 |
97 | axs[0].set_xlabel("float32")
98 | axs[1].set_xlabel("float64")
99 |
100 | handles, labels = axs[0].get_legend_handles_labels()
101 | unique = [
102 | (h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i]
103 | ]
104 | axs[0].legend(*zip(*unique))
105 |
106 | for ax in fig.get_axes():
107 | ax.label_outer()
108 |
109 | fig.supylabel("2nd Deriv. Throughput\n(# tensor products / s)", y=0.5)
110 |
111 | speedup_table = []
112 | for direction in ["double_backward"]:
113 | for impl in ["e3nn", "cuE"]:
114 | for dtype_label, dtype_set in [("f32", dataf32), ("f64", dataf64)]:
115 | speedups = [
116 | measurement["ours"] / measurement[impl]
117 | for _, measurement in dtype_set[direction].items()
118 | if impl in measurement
119 | ]
120 | stats = (
121 | np.min(speedups),
122 | np.mean(speedups),
123 | np.median(speedups),
124 | np.max(speedups),
125 | )
126 | stats = [f"{stat:.2f}" for stat in stats]
127 |
128 | dir_print = direction
129 | result = [dir_print, impl, dtype_label] + stats
130 | speedup_table.append(result)
131 |
132 | print("\t\t".join(["Direction", "Base", "dtype", "min", "mean", "med", "max"]))
133 | for row in speedup_table:
134 | print("\t\t".join(row))
135 |
136 | fig.show()
137 | fig.tight_layout()
138 | fig.savefig(
139 | str(data_folder / "double_backward_throughput.pdf"), bbox_inches="tight"
140 | )
141 |
--------------------------------------------------------------------------------
/openequivariance/benchmark/plotting/plot_roofline.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pathlib
3 | from openequivariance.benchmark.plotting.plotting_utils import (
4 | colormap,
5 | labelmap,
6 | load_benchmarks,
7 |     roofline_plot, filter,  # 'filter' assumed exported by plotting_utils (not the builtin)
8 | )
9 |
10 |
11 | def plot_roofline(data_folder):
12 | data_folder = pathlib.Path(data_folder)
13 | benchmarks, metadata = load_benchmarks(data_folder)
14 |
15 | configs = metadata["config_labels"]
16 | implementations = ["LoopUnrollTP", "CUETensorProduct"]
17 |
18 | data = {"forward": {}, "backward": {}}
19 | for i, desc in enumerate(configs):
20 | for direction in ["forward", "backward"]:
21 | data[direction][desc] = {}
22 | for impl in implementations:
23 | exp = filter(
24 | benchmarks,
25 | {
26 | "config_label": desc,
27 | "direction": direction,
28 | "implementation_name": impl,
29 | },
30 | match_one=True,
31 | )
32 | data[direction][desc][labelmap[impl]] = (
33 | exp["benchmark results"]["arithmetic_intensity (FLOPs / byte)"],
34 | np.mean(exp["benchmark results"]["throughputs_gflops"]),
35 | )
36 |
37 | roofline_data = []
38 | marker_map = {
39 | "forward-cuE": "+",
40 | "backward-cuE": "X",
41 | "forward-ours": "P",
42 | "backward-ours": "X",
43 | }
44 | for i, desc in enumerate(configs):
45 | for direction in ["forward", "backward"]:
46 | ai, throughput = data[direction][desc][labelmap["LoopUnrollTP"]]
47 | label = f"{direction}-ours"
48 | roofline_data.append(
49 | {
50 | "AI": float(ai),
51 | "throughput": throughput / 1000,
52 | "label": label,
53 | "marker": marker_map[label],
54 | "color": colormap["ours"],
55 | "markersize": 80,
56 | }
57 | )
58 |
59 | label = f"{direction}-cuE"
60 | ai, throughput = data[direction][desc][labelmap["CUETensorProduct"]]
61 | roofline_data.append(
62 | {
63 | "AI": float(ai),
64 | "throughput": throughput / 1000,
65 | "label": label,
66 | "marker": marker_map[label],
67 | "color": colormap["cuE"],
68 | "markersize": 80,
69 | }
70 | )
71 |
72 | cpu_roofs = {"A100-SXM-80GB FP32 Peak": 19.5}
73 | mem_bottlenecks = {"HBM2": 2.039}
74 | AI_v = {"": 9.56}
75 |
76 | draw_bounds = {"xmin": 0.4, "xmax": 15, "ymin": 0.15, "ymax": 25}
77 | fig, ax = roofline_plot(
78 | draw_bounds,
79 | cpu_roofs,
80 | mem_bottlenecks,
81 | AI_v,
82 | roofline_data,
83 | fig_ratio=1.8,
84 | fig_dimension=4,
85 | )
86 |
87 | handles, labels = ax.get_legend_handles_labels()
88 | unique = [
89 | (h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i]
90 | ]
91 | ax.legend(*zip(*unique))
92 |
93 | fig.show()
94 | fig.savefig(str(data_folder / "roofline.pdf"))
95 |
96 | # Table of throughputs and arithmetic intensities
97 |
98 | header = r"""
99 | \begin{tabular}{cccccc}
100 | \toprule
101 | \multirow{2}{*}{ID} & \multirow{2}{*}{Description} & \multirow{2}{*}{Dir.} & \multirow{2}{*}{AI} & \multicolumn{2}{c}{TFLOP/s} \\
102 | \cmidrule(r){5-6}
103 | & & & & \multicolumn{1}{l}{cuE} & ours \\
104 | \midrule
105 | """
106 |
107 | rows = []
108 |
109 | dir_map = {"forward": "F", "backward": "B"}
110 | for i, desc in enumerate(sorted(configs)):
111 | for direction in ["forward", "backward"]:
112 |             short_id, long_desc = desc.split("#")
113 |             long_desc = long_desc.replace("->", r"$\rightarrow$").replace(
114 |                 " x ", r"$\ \times\ $"
115 |             )
116 |             ai_ours, throughput_ours = data[direction][desc][
117 |                 labelmap["LoopUnrollTP"]
118 |             ]
119 |             throughput_ours = f"{float(throughput_ours / 1000):.2f}"
120 |             _, throughput_cue = data[direction][desc][labelmap["CUETensorProduct"]]
121 |             throughput_cue = f"{float(throughput_cue / 1000):.2f}"
122 |
123 |             result = [
124 |                 r"\multirow{2}{*}{" + short_id + "}",
125 |                 r"\multirow{2}{*}{" + long_desc + "}",
126 |                 dir_map[direction],
127 |                 f"{ai_ours:.1f}",
128 |                 throughput_cue,
129 |                 throughput_ours,
130 |             ]
131 |             if direction == "backward":
132 |                 result[0] = ""
133 |                 result[1] = ""
134 |             rows.append(result)
135 |
136 |
137 | print(header)
138 | result = ""
139 | for i, row in enumerate(rows):
140 | result += " & ".join(row) + r"\\" + "\n"
141 | if row[2] == "B" and i < len(rows) - 1:
142 | result += "\cmidrule(r){3-6}" + "\n"
143 |
144 | print(result.replace("[", "").replace("]", "").replace("uvu", "B"))
145 | print("\\bottomrule\n\\end{tabular}")
146 |
--------------------------------------------------------------------------------
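The constants hardcoded in plot_roofline pin down the roofline itself:
attainable throughput is min(peak, AI * bandwidth), and the ridge point
19.5 / 2.039 ~= 9.56 FLOPs/byte matches the AI_v vertical line. A quick check:

    peak_tflops = 19.5  # A100-SXM-80GB FP32 peak (cpu_roofs above)
    bw_tbs = 2.039      # HBM2 bandwidth in TB/s (mem_bottlenecks above)

    def attainable_tflops(ai):
        # Memory-bound below the ridge point, compute-bound above it.
        return min(peak_tflops, ai * bw_tbs)

    print(peak_tflops / bw_tbs)     # ridge point: ~9.56 FLOPs/byte
    print(attainable_tflops(4.0))   # memory-bound: ~8.16 TFLOP/s
    print(attainable_tflops(12.0))  # compute-bound: 19.5 TFLOP/s
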
/openequivariance/benchmark/plotting/plot_uvu.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import pathlib
4 | from openequivariance.benchmark.plotting.plotting_utils import (
5 | set_grid,
6 | colormap,
7 | labelmap,
8 | grouped_barchart,
9 |     load_benchmarks, filter,  # 'filter' assumed exported by plotting_utils (not the builtin)
10 | )
11 |
12 |
13 | def plot_uvu(data_folder):
14 | data_folder = pathlib.Path(data_folder)
15 | benchmarks, metadata = load_benchmarks(data_folder)
16 | configs = metadata["config_labels"]
17 | implementations = metadata["implementations"]
18 |
19 | for benchmark in benchmarks:
20 | if (
21 | benchmark["implementation_name"]
22 | == "E3NNTensorProductCompiledMaxAutotuneCUDAGraphs"
23 | ):
24 | benchmark["implementation_name"] = "E3NNTensorProduct"
25 |
26 | for i, implementation in enumerate(implementations):
27 | if implementation == "E3NNTensorProductCompiledMaxAutotuneCUDAGraphs":
28 | implementations[i] = "E3NNTensorProduct"
29 |
30 | def calculate_tp_per_sec(exp):
31 | return exp["benchmark results"]["batch_size"] / (
32 | np.mean(exp["benchmark results"]["time_millis"]) * 0.001
33 | )
34 |
35 | dataf32 = {"forward": {}, "backward": {}}
36 | for i, desc in enumerate(configs):
37 | for direction in ["forward", "backward"]:
38 | dataf32[direction][desc] = {}
39 | for impl in implementations:
40 | f32_benches = [
41 | b
42 | for b in benchmarks
43 | if b["benchmark results"]["rep_dtype"] == ""
44 | ]
45 | exp = filter(
46 | f32_benches,
47 | {
48 | "config_label": desc,
49 | "direction": direction,
50 | "implementation_name": impl,
51 | },
52 | match_one=True,
53 | )
54 | if exp is not None:
55 | dataf32[direction][desc][labelmap[impl]] = calculate_tp_per_sec(exp)
56 | else:
57 | dataf32[direction][desc][labelmap[impl]] = 0.0
58 |
59 | dataf64 = {"forward": {}, "backward": {}}
60 | for i, desc in enumerate(configs):
61 | for direction in ["forward", "backward"]:
62 | dataf64[direction][desc] = {}
63 | for impl in implementations:
64 | f64_benches = [
65 | b
66 | for b in benchmarks
67 | if b["benchmark results"]["rep_dtype"] == ""
68 | ]
69 | exp = filter(
70 | f64_benches,
71 | {
72 | "config_label": desc,
73 | "direction": direction,
74 | "implementation_name": impl,
75 | },
76 | match_one=True,
77 | )
78 |
79 | if exp is not None:
80 | dataf64[direction][desc][labelmap[impl]] = calculate_tp_per_sec(exp)
81 | else:
82 | dataf64[direction][desc][labelmap[impl]] = 0.0
83 |
84 | fig = plt.figure(figsize=(7, 7))
85 | gs = fig.add_gridspec(2, 2)
86 | axs = gs.subplots(sharex=True)
87 |
88 | grouped_barchart(
89 | dataf32["forward"],
90 | axs[0][0],
91 | bar_height_fontsize=0,
92 | xticklabel=False,
93 | colormap=colormap,
94 | group_spacing=6.0,
95 | )
96 | grouped_barchart(
97 | dataf32["backward"],
98 | axs[1][0],
99 | bar_height_fontsize=0,
100 | colormap=colormap,
101 | group_spacing=6.0,
102 | )
103 |
104 | grouped_barchart(
105 | dataf64["forward"],
106 | axs[0][1],
107 | bar_height_fontsize=0,
108 | xticklabel=False,
109 | colormap=colormap,
110 | group_spacing=6.0,
111 | )
112 | grouped_barchart(
113 | dataf64["backward"],
114 | axs[1][1],
115 | bar_height_fontsize=0,
116 | colormap=colormap,
117 | group_spacing=6.0,
118 | )
119 |
120 | for i in range(2):
121 | for j in range(2):
122 | set_grid(axs[i][j])
123 |
124 |
125 | axs[0][0].set_ylabel("Forward")
126 | axs[1][0].set_ylabel("Backward")
127 |
128 | handles, labels = axs[0][0].get_legend_handles_labels()
129 | unique = [
130 | (h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i]
131 | ]
132 | axs[0][0].legend(*zip(*unique))
133 |
134 | fig.supylabel("Throughput (# tensor products / s)", x=0.036, y=0.605)
135 |
136 | axs[1][0].set_xlabel("float32")
137 | axs[1][1].set_xlabel("float64")
138 |
139 | fig.show()
140 | fig.tight_layout()
141 | fig.savefig(str(data_folder / "throughput_comparison.pdf"))
142 |
143 | speedup_table = []
144 | for direction in ["forward", "backward"]:
145 | for impl in ["e3nn", "cuE"]:
146 | for dtype_label, dtype_set in [("f32", dataf32), ("f64", dataf64)]:
147 | speedups = [
148 | measurement["ours"] / measurement[impl]
149 | for _, measurement in dtype_set[direction].items()
150 | if impl in measurement
151 | ]
152 | stats = (
153 | np.min(speedups),
154 | np.mean(speedups),
155 | np.median(speedups),
156 | np.max(speedups),
157 | )
158 | stats = [f"{stat:.2f}" for stat in stats]
159 |
160 | dir_print = direction
161 | if direction == "forward":
162 | dir_print += " "
163 | result = [dir_print, impl, dtype_label] + stats
164 | speedup_table.append(result)
165 |
166 | print("\t\t".join(["Direction", "Base", "dtype", "min", "mean", "med", "max"]))
167 | for row in speedup_table:
168 | print("\t\t".join(row))
169 |
--------------------------------------------------------------------------------
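A worked example of the throughput conversion used by calculate_tp_per_sec:

    batch_size = 100_000    # tensor products per kernel launch
    mean_time_millis = 2.0  # mean over benchmark repetitions
    throughput = batch_size / (mean_time_millis * 0.001)
    print(throughput)       # 50000000.0 tensor products / s
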
/openequivariance/benchmark/plotting/plot_uvw.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import pathlib
4 | from openequivariance.benchmark.plotting.plotting_utils import (
5 | set_grid,
6 | colormap,
7 | labelmap,
8 | grouped_barchart,
9 | calculate_tp_per_sec,
10 |     load_benchmarks, filter,  # 'filter' assumed exported by plotting_utils (not the builtin)
11 | )
12 |
13 |
14 | def plot_uvw(data_folder):
15 | data_folder = pathlib.Path(data_folder)
16 | benchmarks, metadata = load_benchmarks(data_folder)
17 |
18 | configs = metadata["config_labels"]
19 | implementations = metadata["implementations"]
20 | metadata["directions"]
21 |
22 | dataf32 = {"forward": {}, "backward": {}}
23 | for i, desc in enumerate(configs):
24 | for direction in ["forward", "backward"]:
25 | dataf32[direction][desc] = {}
26 | for impl in implementations:
27 | if True: # direction == "forward" or impl != "CUETensorProduct" or 'mace' in desc:
28 | f32_benches = [
29 | b
30 | for b in benchmarks
31 | if b["benchmark results"]["rep_dtype"]
32 | == ""
33 | ]
34 | exp = filter(
35 | f32_benches,
36 | {
37 | "config_label": desc,
38 | "direction": direction,
39 | "implementation_name": impl,
40 | },
41 | match_one=True,
42 | )
43 | dataf32[direction][desc][labelmap[impl]] = calculate_tp_per_sec(exp)
44 |
45 | dataf64 = {"forward": {}, "backward": {}}
46 | for i, desc in enumerate(configs):
47 | for direction in ["forward", "backward"]:
48 | dataf64[direction][desc] = {}
49 | for impl in implementations:
50 | if True: # direction == "forward" or impl != "CUETensorProduct" or 'mace' in desc:
51 | f64_benches = [
52 | b
53 | for b in benchmarks
54 | if b["benchmark results"]["rep_dtype"]
55 | == ""
56 | ]
57 | exp = filter(
58 | f64_benches,
59 | {
60 | "config_label": desc,
61 | "direction": direction,
62 | "implementation_name": impl,
63 | },
64 | match_one=True,
65 | )
66 | dataf64[direction][desc][labelmap[impl]] = calculate_tp_per_sec(exp)
67 |
68 | plt.rcParams["font.family"] = "serif"
69 | plt.rcParams.update({"font.size": 11})
70 |
71 | fig = plt.figure(figsize=(7, 7))
72 | gs = fig.add_gridspec(2, 2)
73 | axs = gs.subplots(sharex=True, sharey="row")
74 |
75 | grouped_barchart(
76 | dataf32["forward"],
77 | axs[0][0],
78 | bar_height_fontsize=0,
79 | xticklabel=False,
80 | colormap=colormap,
81 | group_spacing=6.0,
82 | )
83 | grouped_barchart(
84 | dataf32["backward"],
85 | axs[1][0],
86 | bar_height_fontsize=0,
87 | xticklabel=True,
88 | colormap=colormap,
89 | group_spacing=6.0,
90 | )
91 |
92 | grouped_barchart(
93 | dataf64["forward"],
94 | axs[0][1],
95 | bar_height_fontsize=0,
96 | xticklabel=False,
97 | colormap=colormap,
98 | group_spacing=6.0,
99 | )
100 | grouped_barchart(
101 | dataf64["backward"],
102 | axs[1][1],
103 | bar_height_fontsize=0,
104 | xticklabel=True,
105 | colormap=colormap,
106 | group_spacing=6.0,
107 | )
108 |
109 | for i in range(2):
110 | for j in range(2):
111 | set_grid(axs[i][j])
112 |
113 | fig.supylabel("Throughput (# tensor products / s)", x=0.03, y=0.56)
114 |
115 | axs[0][0].set_ylabel("Forward")
116 | axs[1][0].set_ylabel("Backward")
117 |
118 | axs[1][0].set_xlabel("float32")
119 | axs[1][1].set_xlabel("float64")
120 |
121 | handles, labels = axs[0][1].get_legend_handles_labels()
122 | unique = [
123 | (h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i]
124 | ]
125 | axs[0][1].legend(*zip(*unique))
126 |
127 | fig.show()
128 | fig.tight_layout()
129 | fig.savefig(str(data_folder / "uvw_throughput_comparison.pdf"))
130 |
131 | speedup_table = []
132 | for direction in ["forward", "backward"]:
133 | for impl in ["e3nn", "cuE"]:
134 | for dtype_label, dtype_set in [("f32", dataf32), ("f64", dataf64)]:
135 | speedups = [
136 | measurement["ours"] / measurement[impl]
137 | for label, measurement in dtype_set[direction].items()
138 | if impl in measurement and "DiffDock" in label
139 | ]
140 | stats = (
141 | np.min(speedups),
142 | np.mean(speedups),
143 | np.median(speedups),
144 | np.max(speedups),
145 | )
146 | stats = [f"{stat:.2f}" for stat in stats]
147 |
148 | dir_print = direction
149 | if direction == "forward":
150 | dir_print += " "
151 | result = [dir_print, impl, dtype_label] + stats
152 | speedup_table.append(result)
153 |
154 | print("DiffDock")
155 | print("\t\t".join(["Direction", "Base", "dtype", "min", "mean", "med", "max"]))
156 | for row in speedup_table:
157 | print("\t\t".join(row))
158 |
--------------------------------------------------------------------------------
/openequivariance/benchmark/problems.py:
--------------------------------------------------------------------------------
1 | from openequivariance.benchmark.tpp_creation_utils import (
2 | FullyConnectedTPProblem as FCTPP,
3 | )
4 | from openequivariance.benchmark.tpp_creation_utils import ChannelwiseTPP as CTPP
5 |
6 | # source: https://github.com/e3nn/e3nn/blob/main/examples/tetris.py
7 | # Running tetris prints the layers; only the fully connected layers are extracted here.
8 | _e3nn_torch_tetris = [
9 | # 0th Layer
10 | FCTPP("1x0e", "1x0e", "150x0e + 50x1o + 50x2e"), # sc
11 | FCTPP("1x0e", "1x0e", "1x0e"), # lin1
12 | FCTPP("1x0e + 1x1o + 1x2e", "1x0e", "150x0e + 50x1o + 50x2e"), # lin2
13 | FCTPP("1x0e + 1x1o + 1x2e", "1x0e", "1x0e"), # alpha
14 | # 1st Layer
15 | FCTPP(
16 | "50x0e + 50x1o + 50x2e", "1x0e", "250x0e + 50x1o + 50x1e + 50x2o + 50x2e"
17 | ), # sc
18 | FCTPP("50x0e + 50x1o + 50x2e", "1x0e", "50x0e + 50x1o + 50x2e"), # lin1
19 | # FCTPP("50x0e + 50x1o + 50x2e", "1x0e + 1x1o + 1x2e", "150x0e + 200x1o + 100x1e + 100x2o + 200x2e"), #tp
20 | FCTPP(
21 | "150x0e + 200x1o + 100x1e + 100x2o + 200x2e",
22 | "1x0e",
23 | "250x0e + 50x1o + 50x1e + 50x2o + 50x2e",
24 | ), # lin2
25 | FCTPP("150x0e + 200x1o + 100x1e + 100x2o + 200x2e", "1x0e", "1x0e"), # alpha
26 | # 2nd Layer
27 | FCTPP(
28 | "50x0e + 50x1o + 50x1e + 50x2o + 50x2e",
29 | "1x0e",
30 | "50x0o + 250x0e + 50x1o + 50x1e + 50x2o + 50x2e",
31 | ), # sc
32 | FCTPP(
33 | "50x0e + 50x1o + 50x1e + 50x2o + 50x2e",
34 | "1x0e",
35 | "50x0e + 50x1o + 50x1e + 50x2o + 50x2e",
36 | ), # lin1
37 | FCTPP(
38 | "100x0o + 150x0e + 300x1o + 250x1e + 250x2o + 300x2e",
39 | "1x0e",
40 | "50x0o + 250x0e + 50x1o + 50x1e + 50x2o + 50x2e",
41 | ), # lin2
42 | FCTPP(
43 | "100x0o + 150x0e + 300x1o + 250x1e + 250x2o + 300x2e", "1x0e", "1x0e"
44 | ), # alpha
45 | # 3rd Layer
46 | FCTPP("50x0o + 50x0e + 50x1o + 50x1e + 50x2o + 50x2e", "1x0e", "1x0o + 6x0e"), # sc
47 | FCTPP(
48 | "50x0o + 50x0e + 50x1o + 50x1e + 50x2o + 50x2e",
49 | "1x0e",
50 | "50x0o + 50x0e + 50x1o + 50x1e + 50x2o + 50x2e",
51 | ), # lin1
52 | FCTPP("150x0o + 150x0e", "1x0e", "1x0o + 6x0e"), # lin2
53 | FCTPP("150x0o + 150x0e", "1x0e", "1x0e"), # alpha
54 | ]
55 |
56 |
57 | def e3nn_torch_tetris_poly_problems():
58 | # source: https://github.com/e3nn/e3nn/blob/f95297952303347a8a3cfe971efe449c710c43b2/examples/tetris_polynomial.py#L66-L68
59 | return [
60 | FCTPP(
61 | "1x0e + 1x1o + 1x2e + 1x3o",
62 | "1x0e + 1x1o + 1x2e + 1x3o",
63 | "64x0e + 24x1e + 24x1o + 16x2e + 16x2o",
64 | label="tetris-poly-1",
65 | ), # tp1
66 | FCTPP(
67 | "64x0e + 24x1e + 24x1o + 16x2e + 16x2o",
68 | "1x0e + 1x1o + 1x2e",
69 | "0o + 6x0e",
70 | label="tetris-poly-2",
71 | ), # tp2
72 | ]
73 |
74 |
75 | # https://github.com/gcorso/DiffDock/blob/b4704d94de74d8cb2acbe7ec84ad234c09e78009/models/tensor_layers.py#L299
76 | # specific irreps come from Vivek's communication with DiffDock team
77 | def diffdock_problems():
78 | return [
79 | FCTPP(
80 | "10x1o + 10x1e + 48x0e + 48x0o",
81 | "1x0e + 1x1o",
82 | "10x1o + 10x1e + 48x0e + 48x0o",
83 | shared_weights=False,
84 | label="DiffDock-L=1",
85 | ),
86 | FCTPP(
87 | "10x1o + 10x1e + 48x0e + 48x0o",
88 | "1x0e + 1x1o + 1x2e",
89 | "10x1o + 10x1e + 48x0e + 48x0o",
90 | shared_weights=False,
91 | label="DiffDock-L=2",
92 | ),
93 | ]
94 |
95 |
96 | def mace_problems():
97 | return [
98 | CTPP(*config)
99 | for config in [
100 | (
101 | "128x0e+128x1o+128x2e",
102 | "1x0e+1x1o+1x2e+1x3o",
103 | "128x0e+128x1o+128x2e+128x3o",
104 | "mace-large",
105 | ),
106 | (
107 | "128x0e+128x1o",
108 | "1x0e+1x1o+1x2e+1x3o",
109 | "128x0e+128x1o+128x2e",
110 | "mace-medium",
111 | ),
112 | ]
113 | ]
114 |
115 |
116 | def nequip_problems():
117 | return [
118 | CTPP(*config)
119 | for config in [
120 | (
121 | "32x0o + 32x0e + 32x1o + 32x1e + 32x2o + 32x2e",
122 | "0e + 1o + 2e",
123 | "32x0o + 32x0e + 32x1o + 32x1e + 32x2o + 32x2e",
124 | "nequip-lips",
125 | ),
126 | (
127 | "64x0o + 64x0e + 64x1o + 64x1e",
128 | "0e + 1o",
129 | "64x0o + 64x0e + 64x1o + 64x1e",
130 | "nequip-revmd17-aspirin",
131 | ),
132 | (
133 | "64x0o + 64x0e + 64x1o + 64x1e + 64x2o + 64x2e",
134 | "0e + 1o + 2e",
135 | "64x0o + 64x0e + 64x1o + 64x1e + 64x2o + 64x2e",
136 | "nequip-revmd17-toluene",
137 | ),
138 | (
139 | "64x0o + 64x0e + 64x1o + 64x1e + 64x2o + 64x2e + 64x3o + 64x3e",
140 | "0e + 1o + 2e + 3o",
141 | "64x0o + 64x0e + 64x1o + 64x1e + 64x2o + 64x2e + 64x3o + 64x3e",
142 | "nequip-revmd17-benzene",
143 | ),
144 | (
145 | "32x0o + 32x0e + 32x1o + 32x1e",
146 | "0e + 1o",
147 | "32x0o + 32x0e + 32x1o + 32x1e",
148 | "nequip-water",
149 | ),
150 | ]
151 | ]
152 |
--------------------------------------------------------------------------------
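The factories above return plain TPProblem objects; a quick inspection sketch
(label, weight_numel, and irreps_out are attributes used elsewhere in this
benchmark code):

    from openequivariance.benchmark.problems import mace_problems, diffdock_problems

    for p in mace_problems() + diffdock_problems():
        print(p.label, "| weights:", p.weight_numel, "| out dim:", p.irreps_out.dim)
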
/openequivariance/benchmark/random_buffer_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from openequivariance.implementations.e3nn_lite import TPProblem
4 |
5 |
6 | def get_random_buffers_forward(
7 | tpp: TPProblem, batch_size: int, prng_seed: int
8 | ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
9 | """
10 |     Return properly sized numpy arrays needed to execute a tensor product in the forward direction.
11 |     Supports shared and non-shared weights.
12 | """
13 | assert isinstance(tpp, TPProblem)
14 | rng = np.random.default_rng(prng_seed)
15 |
16 | in1 = np.array(
17 | rng.uniform(size=(batch_size, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype
18 | )
19 | in2 = np.array(
20 | rng.uniform(size=(batch_size, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype
21 | )
22 |
23 | weights_size = (
24 | tuple([tpp.weight_numel])
25 | if tpp.shared_weights
26 | else tuple([batch_size, tpp.weight_numel])
27 | )
28 | weights = np.array(rng.uniform(size=weights_size), dtype=tpp.weight_dtype)
29 |
30 |     out = np.zeros(shape=(batch_size, tpp.irreps_out.dim), dtype=tpp.irrep_dtype)
31 |
32 | return in1, in2, weights, out
33 |
34 |
35 | def get_random_buffers_backward(
36 | tpp: TPProblem, batch_size: int, prng_seed: int
37 | ) -> tuple[
38 | np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray
39 | ]:
40 | """
41 |     Return properly sized numpy arrays needed to execute a tensor product in the backward direction.
42 |     Supports shared and non-shared weights.
43 | """
44 | assert isinstance(tpp, TPProblem)
45 | rng = np.random.default_rng(prng_seed)
46 |
47 | in1 = np.array(
48 | rng.uniform(size=(batch_size, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype
49 | )
50 | in2 = np.array(
51 | rng.uniform(size=(batch_size, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype
52 | )
53 | out_grad = np.array(
54 | rng.uniform(size=(batch_size, tpp.irreps_out.dim)), dtype=tpp.irrep_dtype
55 | )
56 |
57 | weights_size = (
58 | tuple([tpp.weight_numel])
59 | if tpp.shared_weights
60 | else tuple([batch_size, tpp.weight_numel])
61 | )
62 |     weights = np.array(rng.uniform(size=weights_size), dtype=tpp.weight_dtype)
63 |
64 | weights_grad = np.zeros_like(weights)
65 | in1_grad = np.zeros_like(in1)
66 | in2_grad = np.zeros_like(in2)
67 |
68 | return in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad
69 |
70 |
71 | def get_random_buffers_double_backward(
72 | tpp: TPProblem, batch_size: int, prng_seed: int
73 | ) -> tuple[
74 | np.ndarray,
75 | np.ndarray,
76 | np.ndarray,
77 | np.ndarray,
78 | np.ndarray,
79 | np.ndarray,
80 | np.ndarray,
81 | np.ndarray,
82 | ]:
83 | """
84 |     Return properly sized numpy arrays needed to execute a tensor product in the double backward direction.
85 |     Supports shared and non-shared weights.
86 | """
87 | assert isinstance(tpp, TPProblem)
88 | rng = np.random.default_rng(prng_seed)
89 |
90 | in1 = np.array(
91 | rng.uniform(size=(batch_size, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype
92 | )
93 | in2 = np.array(
94 | rng.uniform(size=(batch_size, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype
95 | )
96 | out_grad = np.array(
97 | rng.uniform(size=(batch_size, tpp.irreps_out.dim)), dtype=tpp.irrep_dtype
98 | )
99 |
100 | weights_size = (
101 | tuple([tpp.weight_numel])
102 | if tpp.shared_weights
103 | else tuple([batch_size, tpp.weight_numel])
104 | )
105 |     weights = np.array(rng.uniform(size=weights_size), dtype=tpp.weight_dtype)
106 |
107 | weights_grad = np.zeros_like(weights)
108 | in1_grad = np.zeros_like(in1)
109 | in2_grad = np.zeros_like(in2)
110 | out_double_grad = np.zeros_like(out_grad)
111 |
112 | return (
113 | in1,
114 | in2,
115 | out_grad,
116 | weights,
117 | weights_grad,
118 | in1_grad,
119 | in2_grad,
120 | out_double_grad,
121 | )
122 |
123 |
124 | def get_random_buffers_forward_conv(
125 | tpp: TPProblem, node_count: int, edge_count: int, prng_seed: int
126 | ):
127 | rng = np.random.default_rng(prng_seed)
128 |
129 | in1 = np.array(
130 | rng.uniform(size=(node_count, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype
131 | )
132 | in2 = np.array(
133 | rng.uniform(size=(edge_count, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype
134 | )
135 |
136 | weights_size = (
137 | tuple([tpp.weight_numel])
138 | if tpp.shared_weights
139 | else tuple([edge_count, tpp.weight_numel])
140 | )
141 | weights = np.array(rng.uniform(size=weights_size), dtype=tpp.weight_dtype)
142 |
143 |     out = np.zeros(shape=(node_count, tpp.irreps_out.dim), dtype=tpp.irrep_dtype)
144 |
145 | return in1, in2, weights, out
146 |
147 |
148 | def get_random_buffers_backward_conv(
149 | tpp: TPProblem, node_count: int, edge_count: int, prng_seed: int
150 | ):
151 | """
152 |     Return properly sized numpy arrays needed to execute a tensor product in the backward direction.
153 |     Supports shared and non-shared weights.
154 | """
155 | rng = np.random.default_rng(prng_seed)
156 |
157 | in1 = np.array(
158 | rng.uniform(size=(node_count, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype
159 | )
160 | in2 = np.array(
161 | rng.uniform(size=(edge_count, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype
162 | )
163 | out_grad = np.array(
164 | rng.uniform(size=(node_count, tpp.irreps_out.dim)), dtype=tpp.irrep_dtype
165 | )
166 |
167 | weights_size = (
168 | tuple([tpp.weight_numel])
169 | if tpp.shared_weights
170 | else tuple([edge_count, tpp.weight_numel])
171 | )
172 |     weights = np.array(rng.uniform(size=weights_size), dtype=tpp.weight_dtype)
173 |
174 | weights_grad = np.zeros_like(weights)
175 | in1_grad = np.zeros_like(in1)
176 | in2_grad = np.zeros_like(in2)
177 |
178 | return in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad
179 |
--------------------------------------------------------------------------------
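Shape sketch for the forward buffers; ChannelwiseTPP (defined in the next file)
sets shared_weights=False, so the weights pick up a batch dimension:

    from openequivariance.benchmark.random_buffer_utils import get_random_buffers_forward
    from openequivariance.benchmark.tpp_creation_utils import ChannelwiseTPP

    tpp = ChannelwiseTPP("128x0e+128x1o", "1x0e+1x1o", "128x0e+128x1o+128x2e")
    in1, in2, weights, out = get_random_buffers_forward(tpp, batch_size=64, prng_seed=0)
    print(in1.shape, in2.shape, out.shape)  # (64, in1 dim) (64, in2 dim) (64, out dim)
    print(weights.shape)                    # (64, weight_numel) since shared_weights=False
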
/openequivariance/benchmark/tpp_creation_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from typing import Iterator, Optional
4 | from openequivariance.implementations.e3nn_lite import Irrep, Irreps, TPProblem
5 |
6 | """
7 | This was taken from
8 |
9 | https://github.com/e3nn/e3nn/blob/0.5.4/e3nn/o3/_tensor_product/_sub.py
10 |
11 | and adapted to create TPProblems while avoiding a torch dependency.
12 | """
13 |
14 |
15 | class FullyConnectedTPProblem(TPProblem):
16 | def __init__(self, irreps_in1, irreps_in2, irreps_out, **kwargs) -> None:
17 | irreps_in1 = Irreps(irreps_in1)
18 | irreps_in2 = Irreps(irreps_in2)
19 | irreps_out = Irreps(irreps_out)
20 |
21 | instr = [
22 | (i_1, i_2, i_out, "uvw", True, 1.0)
23 | for i_1, (_, ir_1) in enumerate(irreps_in1)
24 | for i_2, (_, ir_2) in enumerate(irreps_in2)
25 | for i_out, (_, ir_out) in enumerate(irreps_out)
26 | if ir_out in ir_1 * ir_2
27 | ]
28 | super().__init__(
29 | irreps_in1,
30 | irreps_in2,
31 | irreps_out,
32 | instr,
33 | **kwargs,
34 | )
35 |
36 |
37 | class ElementwiseTPProblem(TPProblem):
38 | def __init__(self, irreps_in1, irreps_in2, filter_ir_out=None, **kwargs) -> None:
39 | irreps_in1 = Irreps(irreps_in1).simplify()
40 | irreps_in2 = Irreps(irreps_in2).simplify()
41 | if filter_ir_out is not None:
42 | try:
43 | filter_ir_out = [Irrep(ir) for ir in filter_ir_out]
44 | except ValueError:
45 | raise ValueError(
46 | f"filter_ir_out (={filter_ir_out}) must be an iterable of e3nn.o3.Irrep"
47 | )
48 |
49 | assert irreps_in1.num_irreps == irreps_in2.num_irreps
50 |
51 | irreps_in1 = list(irreps_in1)
52 | irreps_in2 = list(irreps_in2)
53 |
54 | i = 0
55 | while i < len(irreps_in1):
56 | mul_1, ir_1 = irreps_in1[i]
57 | mul_2, ir_2 = irreps_in2[i]
58 |
59 | if mul_1 < mul_2:
60 | irreps_in2[i] = (mul_1, ir_2)
61 | irreps_in2.insert(i + 1, (mul_2 - mul_1, ir_2))
62 |
63 | if mul_2 < mul_1:
64 | irreps_in1[i] = (mul_2, ir_1)
65 | irreps_in1.insert(i + 1, (mul_1 - mul_2, ir_1))
66 | i += 1
67 |
68 | out = []
69 | instr = []
70 | for i, ((mul, ir_1), (mul_2, ir_2)) in enumerate(zip(irreps_in1, irreps_in2)):
71 | assert mul == mul_2
72 | for ir in ir_1 * ir_2:
73 | if filter_ir_out is not None and ir not in filter_ir_out:
74 | continue
75 |
76 | i_out = len(out)
77 | out.append((mul, ir))
78 | instr += [(i, i, i_out, "uuu", False)]
79 |
80 | super().__init__(irreps_in1, irreps_in2, out, instr, **kwargs)
81 |
82 |
83 | class FullTPProblem(TPProblem):
84 | def __init__(
85 | self,
86 | irreps_in1: Irreps,
87 | irreps_in2: Irreps,
88 | filter_ir_out: Iterator[Irrep] = None,
89 | **kwargs,
90 | ) -> None:
91 | irreps_in1 = Irreps(irreps_in1).simplify()
92 | irreps_in2 = Irreps(irreps_in2).simplify()
93 | if filter_ir_out is not None:
94 | try:
95 | filter_ir_out = [Irrep(ir) for ir in filter_ir_out]
96 | except ValueError:
97 | raise ValueError(
98 | f"filter_ir_out (={filter_ir_out}) must be an iterable of e3nn.o3.Irrep"
99 | )
100 |
101 | out = []
102 | instr = []
103 | for i_1, (mul_1, ir_1) in enumerate(irreps_in1):
104 | for i_2, (mul_2, ir_2) in enumerate(irreps_in2):
105 | for ir_out in ir_1 * ir_2:
106 | if filter_ir_out is not None and ir_out not in filter_ir_out:
107 | continue
108 |
109 | i_out = len(out)
110 | out.append((mul_1 * mul_2, ir_out))
111 | instr += [(i_1, i_2, i_out, "uvuv", False)]
112 |
113 | out = Irreps(out)
114 | out, p, _ = out.sort()
115 |
116 | instr = [
117 | (i_1, i_2, p[i_out], mode, train) for i_1, i_2, i_out, mode, train in instr
118 | ]
119 |
120 | super().__init__(irreps_in1, irreps_in2, out, instr, **kwargs)
121 |
122 |
123 | class ChannelwiseTPP(TPProblem):
124 | """
125 | Modified from mace/mace/modules/irreps_tools.py.
126 | """
127 |
128 | def __init__(
129 | self,
130 | irreps_in1: Irreps,
131 | irreps_in2: Irreps,
132 | irreps_out: Irreps,
133 | label: Optional[str] = None,
134 | irrep_dtype=np.float32,
135 | weight_dtype=np.float32,
136 | ):
137 | trainable = True
138 | irreps1 = Irreps(irreps_in1)
139 | irreps2 = Irreps(irreps_in2)
140 | irreps_out = Irreps(irreps_out)
141 |
142 | # Collect possible irreps and their instructions
143 | irreps_out_list = []
144 | instructions = []
145 | for i, (mul, ir_in) in enumerate(irreps1):
146 | for j, (_, ir_edge) in enumerate(irreps2):
147 | for ir_out in ir_in * ir_edge: # | l1 - l2 | <= l <= l1 + l2
148 | if ir_out in irreps_out:
149 | k = len(irreps_out_list) # instruction index
150 | irreps_out_list.append((mul, ir_out))
151 | instructions.append((i, j, k, "uvu", trainable))
152 |
153 | irreps_out = Irreps(irreps_out_list)
154 | irreps_out, permut, _ = irreps_out.sort()
155 |
156 | instructions = [
157 | (i_in1, i_in2, permut[i_out], mode, train)
158 | for i_in1, i_in2, i_out, mode, train in instructions
159 | ]
160 |
161 | instructions = sorted(instructions, key=lambda x: x[2])
162 | super().__init__(
163 | irreps1,
164 | irreps2,
165 | irreps_out,
166 | instructions,
167 | internal_weights=False,
168 | shared_weights=False,
169 | label=label,
170 | irrep_dtype=irrep_dtype,
171 | weight_dtype=weight_dtype,
172 | )
173 |
174 |
175 | class SingleInstruction(TPProblem):
176 | def __init__(
177 | self,
178 | irreps_in1: Irreps,
179 | irreps_in2: Irreps,
180 | irreps_in3: Irreps,
181 | mode: str,
182 | label: Optional[str] = None,
183 | ):
184 | trainable = True
185 | irreps1 = Irreps(irreps_in1)
186 | irreps2 = Irreps(irreps_in2)
187 | irreps3 = Irreps(irreps_in3)
188 | instructions = [(0, 0, 0, mode, trainable)]
189 |
190 | super().__init__(
191 | irreps1,
192 | irreps2,
193 | irreps3,
194 | instructions,
195 | internal_weights=False,
196 | shared_weights=False,
197 | label=label,
198 | )
199 |
--------------------------------------------------------------------------------
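Construction sketch: FullyConnectedTPProblem emits one "uvw" instruction for
every (input1, input2, output) index triple whose output irrep is reachable
under the selection rule, each with a trainable weight block:

    from openequivariance.benchmark.tpp_creation_utils import FullyConnectedTPProblem

    tpp = FullyConnectedTPProblem("32x0e + 32x1o", "1x0e + 1x1o", "32x0e + 32x1o")
    for ins in tpp.instructions:
        print(ins)  # (i_in1, i_in2, i_out, "uvw", has_weight, path_weight, ...)
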
/openequivariance/extension/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.15)
2 | project(equivariant_spmm LANGUAGES CXX)
3 |
4 | find_package(CUDAToolkit REQUIRED)
5 | find_package(pybind11 REQUIRED)
6 | find_package(Python COMPONENTS Interpreter Development)
7 |
8 | set(core_SOURCES
9 | util/buffer.hpp
10 | util/backend_cuda.hpp
11 | tensorproducts.hpp
12 | convolution.hpp
13 | )
14 |
15 | add_library(kernel_wrapper MODULE ${core_SOURCES} kernel_wrapper.cpp)
16 | set_property(TARGET kernel_wrapper PROPERTY POSITION_INDEPENDENT_CODE ON)
17 |
18 | target_include_directories(kernel_wrapper PUBLIC
19 | ${CMAKE_CURRENT_SOURCE_DIR}/util
20 | )
21 |
22 | set_target_properties(kernel_wrapper PROPERTIES LINKER_LANGUAGE CXX)
23 |
24 | # Manually build the module to avoid SOABI renaming. Original command:
25 | # pybind11_add_module(kernel_wrapper kernel_wrapper.cpp)
26 | # target_link_libraries(kernel_wrapper PRIVATE espmm)
27 |
28 | target_link_libraries(kernel_wrapper PRIVATE
29 | pybind11::pybind11
30 | CUDA::cudart
31 | CUDA::cuda_driver
32 | CUDA::nvrtc)
33 | target_link_options(kernel_wrapper PRIVATE -Wl,-rpath='$ORIGIN')
34 |
35 | if(NOT MSVC AND NOT ${CMAKE_BUILD_TYPE} MATCHES Debug|RelWithDebInfo)
36 | # Strip unnecessary sections of the binary on Linux/macOS
37 | pybind11_strip(kernel_wrapper)
38 | endif()
39 |
40 | set_target_properties(kernel_wrapper PROPERTIES CXX_VISIBILITY_PRESET "hidden"
41 | CUDA_VISIBILITY_PRESET "hidden"
42 | PREFIX "")
43 |
44 | install(TARGETS kernel_wrapper DESTINATION .)
--------------------------------------------------------------------------------
/openequivariance/extension/convolution.hpp:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <string>
4 | #include <vector>
5 | #include <unordered_map>
6 | #include <stdexcept>
7 | #include <cstdint>
8 |
9 | class __attribute__ ((visibility ("default"))) ConvolutionImpl {
10 | public:
11 | bool record_internal_stats = false;
12 |
13 | ConvolutionImpl() {
14 | }
15 |
16 | virtual void exec_conv(
17 | void* L1_in,
18 | void* L2_in,
19 | void* weights,
20 | void* L3_out,
21 | void* rows,
22 | void* cols,
23 | uint64_t nnz,
24 | uint64_t node_count,
25 | void* workspace) = 0;
26 |
27 | void exec_conv_rawptrs(
28 | uint64_t L1_in,
29 | uint64_t L2_in,
30 | uint64_t weights,
31 | uint64_t L3_out,
32 | uint64_t rows,
33 | uint64_t cols,
34 | uint64_t nnz,
35 | uint64_t node_count,
36 | uint64_t workspace) {
37 |
38 | exec_conv(
39 |             reinterpret_cast<void*>(L1_in),
40 |             reinterpret_cast<void*>(L2_in),
41 |             reinterpret_cast<void*>(weights),
42 |             reinterpret_cast<void*>(L3_out),
43 |             reinterpret_cast<void*>(rows),
44 |             reinterpret_cast<void*>(cols),
45 |             nnz,
46 |             node_count,
47 |             reinterpret_cast<void*>(workspace));
48 | }
49 |
50 | virtual void backward(
51 | void* L1_in, void* L1_grad,
52 | void* L2_in, void* L2_grad,
53 | void* weight, void* weight_grad,
54 | void* L3_grad,
55 | void* rows, void* cols,
56 | uint64_t nnz, uint64_t node_count,
57 | void* workspace, void* inverse_perm) = 0;
58 |
59 | void backward_rawptrs(
60 | uint64_t L1_in, uint64_t L1_grad,
61 | uint64_t L2_in, uint64_t L2_grad,
62 | uint64_t weight, uint64_t weight_grad,
63 | uint64_t L3_grad,
64 | uint64_t rows, uint64_t cols,
65 | uint64_t nnz, uint64_t node_count,
66 | uint64_t workspace, uint64_t inverse_perm) {
67 |
68 | backward(
69 |             reinterpret_cast<void*>(L1_in),
70 |             reinterpret_cast<void*>(L1_grad),
71 |             reinterpret_cast<void*>(L2_in),
72 |             reinterpret_cast<void*>(L2_grad),
73 |             reinterpret_cast<void*>(weight),
74 |             reinterpret_cast<void*>(weight_grad),
75 |             reinterpret_cast<void*>(L3_grad),
76 |             reinterpret_cast<void*>(rows),
77 |             reinterpret_cast<void*>(cols),
78 |             nnz,
79 |             node_count,
80 |             reinterpret_cast<void*>(workspace),
81 |             reinterpret_cast<void*>(inverse_perm));
82 | }
83 |
84 | virtual ~ConvolutionImpl() {};
85 | };
86 |
87 | struct ConvData {
88 | void* rows;
89 | void* cols;
90 | unsigned long nnz;
91 | unsigned long node_count;
92 | };
93 |
94 | template<typename JIT_IMPL>
95 | class __attribute__ ((visibility ("default"))) JITConvImpl : public ConvolutionImpl {
96 | public:
97 | JIT_IMPL jit;
98 | KernelLaunchConfig forward_config;
99 | KernelLaunchConfig backward_config;
100 | KernelLaunchConfig double_backward_config;
101 | bool is_uvw;
102 |
103 | JITConvImpl(
104 | std::string jit_kernel,
105 | KernelLaunchConfig forward_config_i,
106 | KernelLaunchConfig backward_config_i,
107 | KernelLaunchConfig double_backward_config_i,
108 | bool is_uvw_i) :
109 | jit(jit_kernel),
110 | forward_config(forward_config_i),
111 | backward_config(backward_config_i),
112 | double_backward_config(double_backward_config_i),
113 | is_uvw(is_uvw_i) {
114 |
115 |         std::vector<std::string> kernels = {"forward", "backward", "fixup_forward", "fixup_backward", "double_backward_A", "double_backward_B", "fixup_double_backwardB"};
116 |
117 | int opt_level = 3;
118 | #ifdef HIP_BACKEND
119 | if(is_uvw) {
120 | opt_level = 1;
121 | }
122 | #endif
123 | jit.compile(kernels, {{}, {}, {}, {}, {}, {}, {}}, opt_level);
124 |
125 | if(forward_config.smem > 0) {
126 | jit.set_max_smem(0, forward_config.smem);
127 | jit.set_max_smem(4, forward_config.smem);
128 | }
129 |
130 | if(backward_config.smem > 0) {
131 | jit.set_max_smem(1, backward_config.smem);
132 | }
133 |
134 | if(double_backward_config.smem > 0) {
135 | jit.set_max_smem(5, double_backward_config.smem);
136 | }
137 | }
138 |
139 | JITConvImpl(
140 | std::string jit_kernel,
141 |         std::unordered_map<std::string, int64_t> fwd_dict,
142 |         std::unordered_map<std::string, int64_t> bwd_dict,
143 |         std::unordered_map<std::string, int64_t> dbl_bwd_dict,
144 |         std::unordered_map<std::string, int64_t> kernel_dims
145 | ) : JITConvImpl(
146 | jit_kernel,
147 | KernelLaunchConfig(
148 | fwd_dict["num_blocks"],
149 | fwd_dict["num_threads"],
150 | fwd_dict["smem"]
151 | ),
152 | KernelLaunchConfig(
153 | bwd_dict["num_blocks"],
154 | bwd_dict["num_threads"],
155 | bwd_dict["smem"]
156 | ),
157 | KernelLaunchConfig(
158 | dbl_bwd_dict["num_blocks"],
159 | dbl_bwd_dict["num_threads"],
160 | dbl_bwd_dict["smem"]
161 | ),
162 | kernel_dims["is_uvw"] == 1) { }
163 |
164 | void exec_conv(
165 | void* L1_in,
166 | void* L2_in,
167 | void* weights,
168 | void* L3_out,
169 | void* rows,
170 | void* cols,
171 | uint64_t nnz,
172 | uint64_t node_count,
173 | void* workspace) {
174 |
175 | ConvData conv_data = {rows, cols, nnz, node_count};
176 |
177 | void *args[] = {&L1_in, &L2_in, &weights, &L3_out, &conv_data, &workspace};
178 | jit.execute(0, args, forward_config);
179 |
180 |         if(reinterpret_cast<uint64_t>(workspace) != 0) {
181 | void *fixup_args[] = {&workspace, &L3_out};
182 |
183 | KernelLaunchConfig fixup_config;
184 | fixup_config.num_blocks = forward_config.num_blocks;
185 | fixup_config.num_threads = forward_config.num_threads;
186 | fixup_config.smem = 0;
187 |
188 | jit.execute(2, fixup_args, fixup_config);
189 | }
190 | }
191 |
192 | void backward(
193 | void* L1_in, void* L1_grad,
194 | void* L2_in, void* L2_grad,
195 | void* weight, void* weight_grad,
196 | void* L3_grad,
197 | void* rows, void* cols,
198 | uint64_t nnz, uint64_t node_count,
199 | void* workspace,
200 | void* transpose_perm) {
201 |
202 | ConvData conv_data = {rows, cols, nnz, node_count};
203 | void *args[] = {&L1_in, &L1_grad, &L2_in, &L2_grad, &weight, &weight_grad, &L3_grad, &conv_data, &workspace, &transpose_perm};
204 | jit.execute(1, args, backward_config);
205 |
206 |         if(reinterpret_cast<uint64_t>(workspace) != 0) {
207 | void *fixup_args[] = {&workspace, &L1_grad};
208 |
209 | KernelLaunchConfig fixup_config;
210 | fixup_config.num_blocks = backward_config.num_blocks;
211 | fixup_config.num_threads = backward_config.num_threads;
212 | fixup_config.smem = 0;
213 |
214 | jit.execute(3, fixup_args, fixup_config);
215 | }
216 | }
217 |
218 | void double_backward(
219 | void* L1_in, void* L2_in, void* W, void* L3_grad,
220 | void* L1_dgrad, void* L2_dgrad, void* w_dgrad,
221 | void* L1_grad, void* L2_grad, void* W_grad, void* L3_dgrad,
222 | void* rows, void* cols,
223 | uint64_t nnz, uint64_t node_count,
224 | void* wspace, void* transpose_perm) {
225 |
226 | ConvData conv_data = {rows, cols, nnz, node_count};
227 | void* args[] = {
228 | &L1_in, &L2_in, &W, &L3_grad, &L1_dgrad, &L2_dgrad, &w_dgrad,
229 | &L1_grad, &L2_grad, &W_grad, &L3_dgrad, &conv_data, &wspace, &transpose_perm
230 | };
231 |
232 | jit.execute(4, args, forward_config);
233 |         if(reinterpret_cast<uint64_t>(wspace) != 0) {
234 | void *fixup_args[] = {&wspace, &L3_dgrad};
235 | KernelLaunchConfig fixup_config;
236 | fixup_config.num_blocks = forward_config.num_blocks;
237 | fixup_config.num_threads = forward_config.num_threads;
238 | fixup_config.smem = 0;
239 | jit.execute(2, fixup_args, fixup_config);
240 | }
241 |
242 | jit.execute(5, args, double_backward_config);
243 |         if(reinterpret_cast<uint64_t>(wspace) != 0) {
244 | void *fixup_args[] = {&wspace, &L1_grad};
245 | KernelLaunchConfig fixup_config;
246 | fixup_config.num_blocks = double_backward_config.num_blocks;
247 | fixup_config.num_threads = double_backward_config.num_threads;
248 | fixup_config.smem = 0;
249 | jit.execute(6, fixup_args, fixup_config);
250 | }
251 | }
252 |
253 | ~JITConvImpl() = default;
254 | };
--------------------------------------------------------------------------------
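Semantically, exec_conv performs an edge-wise tensor product over a COO graph
(rows, cols, nnz) and scatter-sums the per-edge results into node outputs; the
workspace plus fixup kernel above is one way to make that reduction
deterministic on a GPU. A NumPy sketch of the reduction, assuming rows holds
destination node indices and using random messages as a stand-in for the
per-edge tensor products:

    import numpy as np

    node_count, nnz, out_dim = 4, 6, 8
    rows = np.array([0, 0, 1, 2, 3, 3])           # destination node per edge
    edge_messages = np.random.rand(nnz, out_dim)  # stand-in for per-edge tensor products

    L3_out = np.zeros((node_count, out_dim))
    np.add.at(L3_out, rows, edge_messages)        # deterministic scatter-sum
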
/openequivariance/extension/generic_module.cpp:
--------------------------------------------------------------------------------
1 | #include <pybind11/pybind11.h>
2 | #include <pybind11/stl.h>
3 | #include <string>
4 | #include <unordered_map>
5 | #include <cstdint>
6 |
7 | #ifdef CUDA_BACKEND
8 | #include "backend_cuda.hpp"
9 | #include "group_mm_cuda.hpp"
10 | using JITKernel = CUJITKernel;
11 | using GPU_Allocator = CUDA_Allocator;
12 |
13 | template<typename T>
14 | using GroupMM = GroupMMCUDA<T>;
15 | #endif
16 |
17 | #ifdef HIP_BACKEND
18 | #include "backend_hip.hpp"
19 | #include "group_mm_hip.hpp"
20 | using JITKernel = HIPJITKernel;
21 | using GPU_Allocator = HIP_Allocator;
22 |
23 | template<typename T>
24 | using GroupMM = GroupMMHIP<T>;
25 | #endif
26 |
27 | #include "buffer.hpp"
28 | #include "tensorproducts.hpp"
29 | #include "convolution.hpp"
30 |
31 | using namespace std;
32 | namespace py=pybind11;
33 |
34 | PYBIND11_MODULE(generic_module, m) {
35 | //=========== Batch tensor products =========
36 | py::class_(m, "GenericTensorProductImpl")
37 | .def("exec_tensor_product_rawptr", &GenericTensorProductImpl::exec_tensor_product_device_rawptrs)
38 | .def("backward_rawptr", &GenericTensorProductImpl::backward_device_rawptrs);
39 |     py::class_<JITTPImpl<JITKernel>, GenericTensorProductImpl>(m, "JITTPImpl")
40 |         .def(py::init< std::string,
41 |             std::unordered_map<std::string, int64_t>,
42 |             std::unordered_map<std::string, int64_t>,
43 |             std::unordered_map<std::string, int64_t>,
44 |             std::unordered_map<std::string, int64_t>>());
45 |
46 | //============= Convolutions ===============
47 | py::class_(m, "ConvolutionImpl")
48 | .def("exec_conv_rawptrs", &ConvolutionImpl::exec_conv_rawptrs)
49 | .def("backward_rawptrs", &ConvolutionImpl::backward_rawptrs);
50 |     py::class_<JITConvImpl<JITKernel>, ConvolutionImpl>(m, "JITConvImpl")
51 |         .def(py::init< std::string,
52 |             std::unordered_map<std::string, int64_t>,
53 |             std::unordered_map<std::string, int64_t>,
54 |             std::unordered_map<std::string, int64_t>,
55 |             std::unordered_map<std::string, int64_t>>());
56 |
57 | py::class_>(m, "GroupMM_F32")
58 | .def(py::init())
59 | .def("group_gemm", &GroupMM::group_gemm_intptr);
60 | py::class_>(m, "GroupMM_F64")
61 | .def(py::init())
62 | .def("group_gemm", &GroupMM::group_gemm_intptr);
63 |
64 | py::class_(m, "DeviceProp")
65 | .def(py::init())
66 | .def_readonly("name", &DeviceProp::name)
67 | .def_readonly("warpsize", &DeviceProp::warpsize)
68 | .def_readonly("major", &DeviceProp::major)
69 | .def_readonly("minor", &DeviceProp::minor)
70 | .def_readonly("multiprocessorCount", &DeviceProp::multiprocessorCount)
71 | .def_readonly("maxSharedMemPerBlock", &DeviceProp::maxSharedMemPerBlock);
72 |
73 | py::class_>(m, "DeviceBuffer")
74 | .def(py::init())
75 | .def(py::init())
76 | .def("copy_to_host", &PyDeviceBuffer::copy_to_host)
77 | .def("data_ptr", &PyDeviceBuffer::data_ptr);
78 |
79 | py::class_(m, "GPUTimer")
80 | .def(py::init<>())
81 | .def("start", &GPUTimer::start)
82 | .def("stop_clock_get_elapsed", &GPUTimer::stop_clock_get_elapsed)
83 | .def("clear_L2_cache", &GPUTimer::clear_L2_cache);
84 | }
--------------------------------------------------------------------------------
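The bindings above are consumed from Python. A sketch using the GPUTimer
binding, assuming the built extension is exposed as
openequivariance.extlib.generic_module:

    from openequivariance.extlib import generic_module

    timer = generic_module.GPUTimer()
    timer.start()
    # ... launch GPU work here ...
    elapsed = timer.stop_clock_get_elapsed()  # elapsed time (milliseconds assumed)
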
/openequivariance/extension/group_mm_cuda.hpp:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "cublas_v2.h"
4 | #include <stdexcept>
5 | #include <type_traits>
6 | #include <cstdint>
7 |
8 | using namespace std;
9 |
10 | template<typename T>
11 | class GroupMMCUDA {
12 | cublasStatus_t stat;
13 | cublasHandle_t handle;
14 |
15 | int num_W;
16 | int batch_size;
17 |
18 | T alpha;
19 | T beta;
20 |
21 | public:
22 | GroupMMCUDA(int num_W, int batch_size) :
23 | num_W(num_W),
24 | batch_size(batch_size),
25 | alpha(1.0),
26 | beta(0.0) {
27 | stat = cublasCreate(&handle);
28 | if (stat != CUBLAS_STATUS_SUCCESS) {
29 | throw std::logic_error("CUBLAS initialization failed");
30 | }
31 | }
32 |
33 | void group_gemm(void* A_raw, void* B_raw, void* C_raw,
34 | int64_t* ragged_counts, int m, int k, int ragged_inner) {
35 | /*
36 | * Performs one of two grouped, batched GEMMs with a single ragged dimension:
37 | *
38 | * a) If ragged_inner = 0, multiplies each M x K row-major weight matrix A
39 | * against B, where B is stored in column-major order with each matrix of
40 | * dimensions K x [offset_diff]. Output has dimensions M x [offset_diff],
41 | * stored in column-major order.
42 | * b) If ragged_inner = 1, multiplies each M x [offset_diff] A matrix
43 | * against each B K x [offset_diff] matrix transposed to produce a
44 | * M x K matrix output.
45 | */
46 |
47 |         T* A_base = reinterpret_cast<T*>(A_raw);
48 |         T* B_base = reinterpret_cast<T*>(B_raw);
49 |         T* C_base = reinterpret_cast<T*>(C_raw);
50 |
51 | int64_t ragged_offset = 0;
52 | for(int i = 0; i < num_W; i++) {
53 | int M, K, N, lda, ldb, ldc;
54 | T *A, *B, *C;
55 |
56 | int strideA, strideB, strideC;
57 | cublasOperation_t transa, transb;
58 |
59 | if(ragged_inner == 0) {
60 | M = m;
61 | K = k;
62 |                 N = static_cast<int>(ragged_counts[i]);
63 |
64 | A = A_base + (m * k * batch_size * i);
65 | lda = k; strideA = M * K;
66 |
67 | B = B_base + (k * batch_size * ragged_offset);
68 | ldb = K * batch_size; strideB = K;
69 |
70 | C = C_base + (m * batch_size * ragged_offset);
71 | ldc = M * batch_size; strideC = M;
72 |
73 | transa = CUBLAS_OP_T;
74 | transb = CUBLAS_OP_N;
75 | }
76 | else {
77 | M = k;
78 |                 K = static_cast<int>(ragged_counts[i]);
79 | N = m;
80 |
81 | A = B_base + (k * batch_size * ragged_offset);
82 | lda = k * batch_size; strideA = M;
83 |
84 | B = A_base + (m * batch_size * ragged_offset);
85 | ldb = m * batch_size; strideB = N;
86 |
87 | C = C_base + (m * k * batch_size * i);
88 | ldc = k; strideC = M * N;
89 |
90 | transa = CUBLAS_OP_N;
91 | transb = CUBLAS_OP_T;
92 | }
93 | ragged_offset += ragged_counts[i];
94 |
95 | if(ragged_counts[i] > 0) {
96 |                 if(std::is_same<T, float>::value) {
97 |                     stat = cublasSgemmStridedBatched(handle,
98 |                         transa, transb,
99 |                         M, N, K,
100 |                         reinterpret_cast<float*>(&alpha),
101 |                         reinterpret_cast<float*>(A), lda, strideA,
102 |                         reinterpret_cast<float*>(B), ldb, strideB,
103 |                         reinterpret_cast<float*>(&beta),
104 |                         reinterpret_cast<float*>(C), ldc, strideC,
105 |                         batch_size);
106 |                 }
107 |                 else if(std::is_same<T, double>::value) {
108 |                     stat = cublasDgemmStridedBatched(handle,
109 |                         transa, transb,
110 |                         M, N, K,
111 |                         reinterpret_cast<double*>(&alpha),
112 |                         reinterpret_cast<double*>(A), lda, strideA,
113 |                         reinterpret_cast<double*>(B), ldb, strideB,
114 |                         reinterpret_cast<double*>(&beta),
115 |                         reinterpret_cast<double*>(C), ldc, strideC,
116 |                         batch_size);
117 |                 }
118 | else {
119 | throw std::logic_error("Unsupported datatype for grouped GEMM!");
120 | }
121 | if (stat != CUBLAS_STATUS_SUCCESS) {
122 | throw std::logic_error("Grouped GEMM failed!");
123 | }
124 | }
125 | }
126 | }
127 |
128 | void group_gemm_intptr(uint64_t weights,
129 | uint64_t vectors, uint64_t output,
130 | uint64_t ragged_counts, int m, int k, int ragged_inner) {
131 |
132 | group_gemm(
133 |             reinterpret_cast<void*>(weights),
134 |             reinterpret_cast<void*>(vectors),
135 |             reinterpret_cast<void*>(output),
136 |             reinterpret_cast<int64_t*>(ragged_counts),
137 | m, k, ragged_inner);
138 | }
139 |
140 | ~GroupMMCUDA() {
141 | cublasDestroy(handle);
142 | }
143 | };
--------------------------------------------------------------------------------
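A NumPy sketch of the ragged_inner == 0 case documented in group_gemm,
simplified to batch_size = 1 so the leading-dimension bookkeeping drops out:

    import numpy as np

    num_W, m, k = 2, 4, 3
    ragged_counts = [5, 7]
    A = np.random.rand(num_W, m, k)            # one M x K weight matrix per group
    B = np.random.rand(k, sum(ragged_counts))  # K x [ragged] input slab
    C = np.zeros((m, sum(ragged_counts)))      # M x [ragged] output slab

    offset = 0
    for i, n in enumerate(ragged_counts):
        C[:, offset:offset + n] = A[i] @ B[:, offset:offset + n]
        offset += n
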
/openequivariance/extension/group_mm_hip.hpp:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "rocblas/rocblas.h"
4 | #include <stdexcept>
5 | #include <type_traits>
6 | #include <cstdint>
7 |
8 |
9 | template<typename T>
10 | class GroupMMHIP {
11 | rocblas_status stat;
12 | rocblas_handle handle;
13 |
14 | int num_W;
15 | int batch_size;
16 |
17 | T alpha;
18 | T beta;
19 |
20 | public:
21 | GroupMMHIP(int num_W, int batch_size) :
22 | num_W(num_W),
23 | batch_size(batch_size),
24 | alpha(1.0),
25 | beta(0.0) {
26 | if(rocblas_create_handle(&handle) != rocblas_status_success) {
27 | throw std::logic_error("rocBLAS initialization failed");
28 | }
29 | }
30 |
31 | void group_gemm(void* A_raw, void* B_raw, void* C_raw,
32 | int64_t* ragged_counts, int m, int k, int ragged_inner) {
33 |
34 |         T* A_base = reinterpret_cast<T*>(A_raw);
35 |         T* B_base = reinterpret_cast<T*>(B_raw);
36 |         T* C_base = reinterpret_cast<T*>(C_raw);
37 |
38 | int64_t ragged_offset = 0;
39 | for(int i = 0; i < num_W; i++) {
40 | int M, K, N, lda, ldb, ldc;
41 | T *A, *B, *C;
42 |
43 | int strideA, strideB, strideC;
44 | rocblas_operation transa, transb;
45 |
46 | if(ragged_inner == 0) {
47 | M = m;
48 | K = k;
49 |                 N = static_cast<int>(ragged_counts[i]);
50 |
51 | A = A_base + (m * k * batch_size * i);
52 | lda = k; strideA = M * K;
53 |
54 | B = B_base + (k * batch_size * ragged_offset);
55 | ldb = K * batch_size; strideB = K;
56 |
57 | C = C_base + (m * batch_size * ragged_offset);
58 | ldc = M * batch_size; strideC = M;
59 |
60 | transa = rocblas_operation_transpose;
61 | transb = rocblas_operation_none;
62 | }
63 | else {
64 | M = k;
65 |                 K = static_cast<int>(ragged_counts[i]);
66 | N = m;
67 |
68 | A = B_base + (k * batch_size * ragged_offset);
69 | lda = k * batch_size; strideA = M;
70 |
71 | B = A_base + (m * batch_size * ragged_offset);
72 | ldb = m * batch_size; strideB = N;
73 |
74 | C = C_base + (m * k * batch_size * i);
75 | ldc = k; strideC = M * N;
76 |
77 | transa = rocblas_operation_none;
78 | transb = rocblas_operation_transpose;
79 | }
80 | ragged_offset += ragged_counts[i];
81 |
82 | if(ragged_counts[i] > 0) {
83 | if(std::is_same::value) {
84 | stat = rocblas_sgemm_strided_batched(handle,
85 | transa, transb,
86 | M, N, K,
87 | reinterpret_cast(&alpha),
88 | reinterpret_cast(A), lda, strideA,
89 | reinterpret_cast(B), ldb, strideB,
90 | reinterpret_cast(&beta),
91 | reinterpret_cast(C), ldc, strideC,
92 | batch_size);
93 | }
94 | else if(std::is_same::value) {
95 | stat = rocblas_dgemm_strided_batched(handle,
96 | transa, transb,
97 | M, N, K,
98 | reinterpret_cast(&alpha),
99 | reinterpret_cast(A), lda, strideA,
100 | reinterpret_cast(B), ldb, strideB,
101 | reinterpret_cast(&beta),
102 | reinterpret_cast(C), ldc, strideC,
103 | batch_size);
104 | }
105 | else {
106 | throw std::logic_error("Unsupported datatype for grouped GEMM!");
107 | }
108 | if (stat != rocblas_status_success) {
109 | throw std::logic_error("Grouped GEMM failed!");
110 | }
111 | }
112 | }
113 | }
114 |
115 | void group_gemm_intptr(uint64_t weights,
116 | uint64_t vectors, uint64_t output,
117 | uint64_t ragged_counts, int m, int k, int ragged_inner) {
118 | group_gemm(
119 |             reinterpret_cast<void*>(weights),
120 |             reinterpret_cast<void*>(vectors),
121 |             reinterpret_cast<void*>(output),
122 |             reinterpret_cast<int64_t*>(ragged_counts),
123 | m, k, ragged_inner);
124 | }
125 |
126 | ~GroupMMHIP() {
127 | rocblas_destroy_handle(handle);
128 | }
129 | };
--------------------------------------------------------------------------------
/openequivariance/extension/tensorproducts.hpp:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <string>
4 | #include <vector>
5 | #include <unordered_map>
6 | #include <stdexcept>
7 | #include <cstdint>
8 |
9 | class __attribute__ ((visibility ("default"))) GenericTensorProductImpl {
10 | public:
11 | GenericTensorProductImpl() { }
12 |
13 | virtual void exec_tensor_product(uint64_t num_products,
14 | void* L1_in, void* L2_in, void* L3_out, void* weights) = 0;
15 |
16 | void exec_tensor_product_device_rawptrs(uint64_t num_products,
17 | uint64_t L1_in, uint64_t L2_in, uint64_t L3_out, uint64_t weights) {
18 |
19 | exec_tensor_product(num_products,
20 |             reinterpret_cast<void*>(L1_in),
21 |             reinterpret_cast<void*>(L2_in),
22 |             reinterpret_cast<void*>(L3_out),
23 |             reinterpret_cast<void*>(weights));
24 | }
25 |
26 | virtual void backward(size_t num_products,
27 | void* L1_in, void* L1_grad,
28 | void* L2_in, void* L2_grad,
29 | void* weight, void* weight_grad,
30 | void* L3_grad) {
31 |
32 | throw std::logic_error("Backward pass not implemented yet!");
33 | }
34 |
35 | void backward_device_rawptrs(uint64_t num_products,
36 | uint64_t L1_in, uint64_t L1_grad,
37 | uint64_t L2_in, uint64_t L2_grad,
38 | uint64_t weight, uint64_t weight_grad,
39 | uint64_t L3_grad) {
40 |
41 | backward(num_products,
42 |             reinterpret_cast<void*>(L1_in), reinterpret_cast<void*>(L1_grad),
43 |             reinterpret_cast<void*>(L2_in), reinterpret_cast<void*>(L2_grad),
44 |             reinterpret_cast<void*>(weight), reinterpret_cast<void*>(weight_grad),
45 |             reinterpret_cast<void*>(L3_grad)
46 | );
47 | }
48 |
49 | virtual ~GenericTensorProductImpl() {};
50 | };
51 |
52 | template<typename JIT_IMPL>
53 | class __attribute__ ((visibility ("default"))) JITTPImpl : public GenericTensorProductImpl {
54 | public:
55 | JIT_IMPL jit;
56 | KernelLaunchConfig forward_config, backward_config, double_backward_config;
57 | bool is_uvw;
58 |
59 | JITTPImpl(
60 | std::string jit_kernel,
61 | KernelLaunchConfig forward_config_i,
62 | KernelLaunchConfig backward_config_i,
63 | KernelLaunchConfig double_backward_config_i,
64 | bool is_uvw_i) :
65 | jit(jit_kernel),
66 | forward_config(forward_config_i),
67 | backward_config(backward_config_i),
68 | double_backward_config(double_backward_config_i),
69 | is_uvw(is_uvw_i) {
70 |         vector<string> kernels = {"forward", "backward", "double_backward_A", "double_backward_B"};
71 |
72 | int opt_level = 3;
73 | #ifdef HIP_BACKEND
74 | if(is_uvw) {
75 | opt_level = 1;
76 | }
77 | #endif
78 | jit.compile(kernels, {{}, {}, {}, {}}, opt_level);
79 |
80 | if(forward_config.smem > 0) {
81 | jit.set_max_smem(0, forward_config.smem);
82 | jit.set_max_smem(2, forward_config.smem);
83 | }
84 |
85 |         if(backward_config.smem > 0) {
86 |             jit.set_max_smem(1, backward_config.smem);
87 |         }
88 |
89 | if(double_backward_config.smem > 0) {
90 | jit.set_max_smem(3, double_backward_config.smem);
91 | }
92 | }
93 |
94 | JITTPImpl(
95 | std::string jit_kernel,
96 |         std::unordered_map<std::string, int64_t> fwd_dict,
97 |         std::unordered_map<std::string, int64_t> bwd_dict,
98 |         std::unordered_map<std::string, int64_t> dbl_bwd_dict,
99 |         std::unordered_map<std::string, int64_t> kernel_dims
100 | ) : JITTPImpl(
101 | jit_kernel,
102 | KernelLaunchConfig(
103 | fwd_dict["num_blocks"],
104 | fwd_dict["num_threads"],
105 | fwd_dict["smem"]
106 | ),
107 | KernelLaunchConfig(
108 | bwd_dict["num_blocks"],
109 | bwd_dict["num_threads"],
110 | bwd_dict["smem"]
111 | ),
112 | KernelLaunchConfig(
113 | dbl_bwd_dict["num_blocks"],
114 | dbl_bwd_dict["num_threads"],
115 | dbl_bwd_dict["smem"]
116 | ),
117 | kernel_dims["is_uvw"] == 1
118 | ) { }
119 |
120 | void exec_tensor_product(
121 | uint64_t num_products,
122 | void* L1_in,
123 | void* L2_in,
124 | void* L3_out,
125 | void* weights) {
126 |
127 | void *args[] = { &num_products, &L1_in, &L2_in, &L3_out, &weights};
128 | jit.execute(0, args, forward_config);
129 | }
130 |
131 | void backward(
132 | size_t num_products,
133 | void* L1_in, void* L1_grad,
134 | void* L2_in, void* L2_grad,
135 | void* weight, void* weight_grad,
136 | void* L3_grad) {
137 | void *args[] = { &num_products, &L1_in, &L1_grad, &L2_in, &L2_grad, &weight, &weight_grad, &L3_grad};
138 | jit.execute(1, args, backward_config);
139 | }
140 |
141 | void double_backward(
142 | size_t num_products,
143 | void* L1_in, void* L2_in, void* W, void* L3_grad, // Inputs of backward op
144 | void* L1_dgrad, void* L2_dgrad, void* w_dgrad, // Gradients w.r.t outputs of backward op
145 | void* L1_grad, void* L2_grad, void* W_grad, void* L3_dgrad) {
146 |
147 | void* args[] = {
148 | &num_products, &L1_in, &L2_in, &W, &L3_grad, &L1_dgrad, &L2_dgrad, &w_dgrad,
149 | &L1_grad, &L2_grad, &W_grad, &L3_dgrad
150 | };
151 | jit.execute(2, args, forward_config);
152 | jit.execute(3, args, double_backward_config);
153 | }
154 |
155 | ~JITTPImpl() = default;
156 | };
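157 |
158 | /* A hypothetical construction through the dictionary-based constructor above;
159 |  * the key names come from the code, while the launch values, kernel source
160 |  * string, and backend type are placeholders:
161 |  *
162 |  *   std::unordered_map<std::string, int64_t> cfg =
163 |  *       {{"num_blocks", 1024}, {"num_threads", 128}, {"smem", 0}};
164 |  *   JITTPImpl<SomeJITBackend> tp(kernel_source, cfg, cfg, cfg, {{"is_uvw", 1}});
165 |  */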
--------------------------------------------------------------------------------
/openequivariance/extension/test/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
2 | project(test_oeq_jitscript_load)
3 |
4 | find_package(Torch REQUIRED)
5 |
6 | add_executable(load_jitscript load_jitscript.cpp)
7 | target_link_libraries(load_jitscript "${TORCH_LIBRARIES}")
8 | target_link_libraries(load_jitscript -Wl,--no-as-needed "${OEQ_EXTLIB}")
9 | set_property(TARGET load_jitscript PROPERTY CXX_STANDARD 17)
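10 |
11 | # Example configure/build invocation (paths are illustrative; OEQ_EXTLIB should
12 | # point at the compiled OpenEquivariance torch extension referenced above, and
13 | # CMAKE_PREFIX_PATH at a LibTorch install):
14 | #   cmake -B build -DCMAKE_PREFIX_PATH=/path/to/libtorch -DOEQ_EXTLIB=/path/to/libtorch_tp_jit.so
15 | #   cmake --build build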
--------------------------------------------------------------------------------
/openequivariance/extension/test/load_jitscript.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/script.h>
2 |
3 | #include <iostream>
4 | #include <memory>
5 |
6 | /*
7 | * This program takes in two JITScript modules that execute
8 | * a tensor product in FP32 precision.
9 | * The first module is compiled from e3nn, the second is
10 | * OEQ's compiled module. The program checks that the
11 | * two outputs are comparable.
12 | */
13 |
14 | int main(int argc, const char* argv[]) {
15 | if (argc != 7) {
16 | std::cerr << "usage: load_jitscript "
17 |                   << "<path_to_e3nn_module> "
18 |                   << "<path_to_oeq_module> "
19 |                   << "<L1_dim> "
20 |                   << "<L2_dim> "
21 |                   << "<weight_numel> "
22 |                   << "<batch_size> "
23 | << std::endl;
24 |
25 | return 1;
26 | }
27 |
28 | int64_t L1_dim = std::stoi(argv[3]);
29 | int64_t L2_dim = std::stoi(argv[4]);
30 | int64_t weight_numel = std::stoi(argv[5]);
31 | int64_t batch_size = std::stoi(argv[6]);
32 |
33 | torch::Device device(torch::kCUDA);
34 |     std::vector<torch::jit::IValue> inputs;
35 | inputs.push_back(torch::randn({batch_size, L1_dim}, device));
36 | inputs.push_back(torch::randn({batch_size, L2_dim}, device));
37 | inputs.push_back(torch::randn({batch_size, weight_numel}, device));
38 |
39 | torch::jit::script::Module module_e3nn, module_oeq;
40 | try {
41 | module_e3nn = torch::jit::load(argv[1]);
42 | module_oeq = torch::jit::load(argv[2]);
43 | }
44 | catch (const c10::Error& e) {
45 | std::cerr << "error loading script module" << std::endl;
46 | return 1;
47 | }
48 |
49 | module_e3nn.to(device);
50 | module_oeq.to(device);
51 |
52 | at::Tensor output_e3nn = module_e3nn.forward(inputs).toTensor();
53 | at::Tensor output_oeq = module_oeq.forward(inputs).toTensor();
54 |
55 | if(at::allclose(output_e3nn, output_oeq, 1e-5, 1e-5)) {
56 | return 0;
57 | }
58 | else {
59 | std::cerr << "torch.allclose returned FALSE comparing model outputs." << std::endl;
60 | return 1;
61 | }
62 | }
--------------------------------------------------------------------------------
/openequivariance/extension/util/buffer.hpp:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <pybind11/pybind11.h>
3 | #include <cstdint>
4 |
5 | using namespace std;
6 | namespace py = pybind11;
7 |
8 | template<typename ALLOC_T>
9 | class PyDeviceBuffer {
10 | public:
11 | char* host_ptr;
12 | char* device_ptr;
13 | size_t size;
14 |
15 | PyDeviceBuffer(uint64_t size) {
16 | this->size = size;
17 |         device_ptr = static_cast<char*>(ALLOC_T::gpu_alloc(size));
18 | host_ptr = nullptr;
19 | }
20 |
21 | PyDeviceBuffer(py::buffer host_data) {
22 | const py::buffer_info &info = host_data.request();
23 |         host_ptr = static_cast<char*>(info.ptr);
24 | size = 1;
25 | for(int64_t i = 0; i < info.ndim; i++) {
26 | size *= info.shape[i];
27 | }
28 | size *= info.itemsize;
29 |
30 |         device_ptr = static_cast<char*>(ALLOC_T::gpu_alloc(size));
31 | ALLOC_T::copy_host_to_device(host_ptr, device_ptr, size);
32 | }
33 |
34 | ~PyDeviceBuffer() {
35 |         ALLOC_T::gpu_free(static_cast<void*>(device_ptr));
36 | }
37 |
38 | void copy_to_host() {
39 | ALLOC_T::copy_device_to_host(host_ptr, device_ptr, size);
40 | }
41 |
42 | uint64_t data_ptr() {
43 |         return reinterpret_cast<uint64_t>(device_ptr);
44 | }
45 | };
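46 |
47 | /* PyDeviceBuffer expects ALLOC_T to provide a static allocator interface. A
48 |  * minimal sketch, with signatures inferred from the calls above (the real
49 |  * policies live in backend_cuda.hpp / backend_hip.hpp):
50 |  *
51 |  *   struct ExampleAllocator {
52 |  *       static void* gpu_alloc(size_t size);
53 |  *       static void gpu_free(void* ptr);
54 |  *       static void copy_host_to_device(void* host, void* device, size_t size);
55 |  *       static void copy_device_to_host(void* host, void* device, size_t size);
56 |  *   };
57 |  */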
--------------------------------------------------------------------------------
/openequivariance/extlib/.empty:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PASSIONLab/OpenEquivariance/ce8e44ea8a9b0b78dee5ac28dc177218cb802db4/openequivariance/extlib/.empty
--------------------------------------------------------------------------------
/openequivariance/extlib/__init__.py:
--------------------------------------------------------------------------------
1 | # ruff: noqa : F401, E402
2 | import sys
3 | import os
4 | import warnings
5 | from pathlib import Path
6 |
7 | from openequivariance.benchmark.logging_utils import getLogger
8 | from distutils import sysconfig
9 |
10 | oeq_root = str(Path(__file__).parent.parent)
11 |
12 | build_ext = True
13 | TORCH_COMPILE = True
14 | torch_module, generic_module = None, None
15 | postprocess_kernel = lambda kernel: kernel # noqa : E731
16 |
17 | try:
18 | python_lib_dir = sysconfig.get_config_var("LIBDIR")
19 | major, minor = sys.version_info.major, sys.version_info.minor
20 | python_lib_name = f"python{major}.{minor}"
21 |
22 | except Exception as e:
23 | print("Error while retrieving Python library information:", file=sys.stderr)
24 | print(e, file=sys.stderr)
25 |     print("Sysconfig variable list:", file=sys.stderr)
26 | print(sysconfig.get_config_vars(), file=sys.stderr)
27 | exit(1)
28 |
29 | if not build_ext:
30 | from openequivariance.extlib.generic_module import (
31 | GenericTensorProductImpl,
32 | JITTPImpl,
33 | ConvolutionImpl,
34 | JITConvImpl,
35 | GroupMM_F32,
36 | GroupMM_F64,
37 | DeviceProp,
38 | DeviceBuffer,
39 | GPUTimer,
40 | )
41 | else:
42 | from torch.utils.cpp_extension import library_paths, include_paths
43 |
44 | global torch
45 | import torch
46 |
47 | extra_cflags = ["-O3"]
48 | generic_sources = ["generic_module.cpp"]
49 | torch_sources = ["libtorch_tp_jit.cpp"]
50 |
51 | include_dirs, extra_link_args = (
52 | ["util"],
53 | [
54 | f"-Wl,--no-as-needed,-rpath,{python_lib_dir}",
55 | f"-L{python_lib_dir}",
56 | f"-l{python_lib_name}",
57 | ],
58 | )
59 |
60 | if torch.version.cuda:
61 | extra_link_args.extend(["-lcuda", "-lcudart", "-lnvrtc"])
62 |
63 | try:
64 | torch_libs, cuda_libs = library_paths("cuda")
65 | extra_link_args.append("-Wl,-rpath," + torch_libs)
66 | extra_link_args.append("-L" + cuda_libs)
67 | if os.path.exists(cuda_libs + "/stubs"):
68 | extra_link_args.append("-L" + cuda_libs + "/stubs")
69 | except Exception as e:
70 | getLogger().info(str(e))
71 |
72 | extra_cflags.append("-DCUDA_BACKEND")
73 | elif torch.version.hip:
74 | extra_link_args.extend(["-lhiprtc"])
75 | torch_libs = library_paths("cuda")[0]
76 | extra_link_args.append("-Wl,-rpath," + torch_libs)
77 |
78 | def postprocess(kernel):
79 | kernel = kernel.replace("__syncwarp();", "__threadfence_block();")
80 | kernel = kernel.replace("__shfl_down_sync(FULL_MASK,", "__shfl_down(")
81 | kernel = kernel.replace("atomicAdd", "unsafeAtomicAdd")
82 | return kernel
83 |
84 | postprocess_kernel = postprocess
85 |
86 | extra_cflags.append("-DHIP_BACKEND")
87 |
88 | generic_sources = [oeq_root + "/extension/" + src for src in generic_sources]
89 | torch_sources = [oeq_root + "/extension/" + src for src in torch_sources]
90 | include_dirs = [oeq_root + "/extension/" + d for d in include_dirs] + include_paths(
91 | "cuda"
92 | )
93 |
94 | torch_compile_exception = None
95 | with warnings.catch_warnings():
96 | warnings.simplefilter("ignore")
97 |
98 | try:
99 | torch_module = torch.utils.cpp_extension.load(
100 | "libtorch_tp_jit",
101 | torch_sources,
102 | extra_cflags=extra_cflags,
103 | extra_include_paths=include_dirs,
104 | extra_ldflags=extra_link_args,
105 | )
106 | torch.ops.load_library(torch_module.__file__)
107 | except Exception as e:
108 | # If compiling torch fails (e.g. low gcc version), we should fall back to the
109 | # version that takes integer pointers as args (but is untraceable to PyTorch JIT / export).
110 | TORCH_COMPILE = False
111 | torch_compile_exception = e
112 |
113 | generic_module = torch.utils.cpp_extension.load(
114 | "generic_module",
115 | generic_sources,
116 | extra_cflags=extra_cflags,
117 | extra_include_paths=include_dirs,
118 | extra_ldflags=extra_link_args,
119 | )
120 |
121 | if not TORCH_COMPILE:
122 | warnings.warn(
123 | "Could not compile integrated PyTorch wrapper. Falling back to Pybind11"
124 | + f", but JITScript, compile fullgraph, and export will fail.\n {torch_compile_exception}"
125 | )
126 |
127 | from generic_module import (
128 | GenericTensorProductImpl,
129 | JITTPImpl,
130 | ConvolutionImpl,
131 | JITConvImpl,
132 | GroupMM_F32,
133 | GroupMM_F64,
134 | DeviceProp,
135 | DeviceBuffer,
136 | GPUTimer,
137 | )
138 |
--------------------------------------------------------------------------------
/openequivariance/implementations/E3NNTensorProduct.py:
--------------------------------------------------------------------------------
1 | __all__ = [
2 | "E3NNTensorProduct",
3 | "E3NNTensorProductCompiled",
4 | "E3NNTensorProductCompiledCUDAGraphs",
5 | "E3NNTensorProductCompiledMaxAutotuneCUDAGraphs",
6 | ]
7 |
8 | import os
9 | import pathlib
10 | import numpy as np
11 |
12 | from openequivariance.implementations.TensorProductBase import TensorProductBase
13 | from openequivariance.implementations.e3nn_lite import TPProblem
14 | from openequivariance.benchmark.logging_utils import getLogger
15 |
16 | TORCH_COMPILE_AUTOTUNING_DIR = pathlib.Path("triton_autotuning")
17 |
18 | logger = getLogger()
19 |
20 |
21 | class E3NNTensorProduct(TensorProductBase):
22 | def __init__(self, config: TPProblem, torch_op=True):
23 | super().__init__(config, torch_op=torch_op)
24 | assert self.torch_op
25 |
26 | global torch
27 | global e3nn
28 | import torch
29 | import e3nn
30 | from e3nn import o3
31 |
32 | e3nn.set_optimization_defaults(jit_script_fx=False)
33 |
34 | assert config.irrep_dtype == config.weight_dtype
35 | if config.irrep_dtype == np.float64:
36 | torch.set_default_dtype(torch.float64)
37 |
38 | self.e3nn_tp = o3.TensorProduct(
39 | config.irreps_in1,
40 | config.irreps_in2,
41 | config.irreps_out,
42 | config.instructions_raw,
43 | in1_var=config.in1_var,
44 | in2_var=config.in2_var,
45 | out_var=config.out_var,
46 | irrep_normalization=config.irrep_normalization,
47 | path_normalization=config.path_normalization,
48 | internal_weights=config.internal_weights,
49 | shared_weights=config.shared_weights,
50 | ).to(device="cuda")
51 |
52 | if config.irrep_dtype == np.float64:
53 | torch.set_default_dtype(torch.float32) # Reset to default
54 |
55 | self.forward = self.e3nn_tp.__call__
56 |
57 | def forward_cpu(
58 | self,
59 | L1_in: np.ndarray,
60 | L2_in: np.ndarray,
61 | L3_out: np.ndarray,
62 | weights: np.ndarray,
63 | ) -> None:
64 | torch_L1_in = torch.tensor(L1_in, device="cuda")
65 | torch_L2_in = torch.tensor(L2_in, device="cuda")
66 | torch_weights = torch.tensor(weights, device="cuda")
67 |
68 | torch_L3_out = self.e3nn_tp(torch_L1_in, torch_L2_in, torch_weights)
69 |
70 | L3_out[:] = torch_L3_out.numpy(force=True)
71 |
72 | def backward_cpu(
73 | self,
74 | L1_in: np.ndarray,
75 | L1_grad: np.ndarray,
76 | L2_in: np.ndarray,
77 | L2_grad: np.ndarray,
78 | L3_grad: np.ndarray,
79 | weights: np.ndarray,
80 | weights_grad: np.ndarray,
81 | ) -> None:
82 | torch_L1_in = torch.tensor(L1_in, requires_grad=True, device="cuda")
83 | torch_L2_in = torch.tensor(L2_in, requires_grad=True, device="cuda")
84 | torch_weights = torch.tensor(weights, requires_grad=True, device="cuda")
85 |
86 | torch_out = self.e3nn_tp(torch_L1_in, torch_L2_in, torch_weights)
87 |
88 | torch_L3_grad_in = torch.tensor(L3_grad, device="cuda")
89 |
90 | torch_out.backward(gradient=torch_L3_grad_in)
91 |
92 | L1_grad[:] = torch_L1_in.grad.numpy(force=True)
93 | L2_grad[:] = torch_L2_in.grad.numpy(force=True)
94 | weights_grad[:] = torch_weights.grad.numpy(force=True)
95 |
96 | @classmethod
97 | def name(cls):
98 | return cls.__name__
99 |
100 |
101 | class E3NNTensorProductCompiled(E3NNTensorProduct):
102 | def __init__(
103 | self,
104 | config: TPProblem,
105 | torch_compile_kwargs: dict,
106 | torch_op: bool = True,
107 | ):
108 | super().__init__(config, torch_op=torch_op)
109 | self.torch_compile_kwargs = torch_compile_kwargs
110 |
111 | logger.debug("Torch compiling e3nn TP")
112 | logger.debug(msg=f"{torch_compile_kwargs}")
113 | self.e3nn_tp = torch.compile(self.e3nn_tp, **self.torch_compile_kwargs)
114 | logger.debug("e3nn TP torch compiled")
115 |
116 | self.forward = self.e3nn_tp.__call__
117 |
118 |
119 | class E3NNTensorProductCompiledCUDAGraphs(E3NNTensorProductCompiled):
120 | def __init__(self, config: TPProblem, torch_op=True):
121 | global torch
122 | import torch
123 |
124 | torch._dynamo.config.cache_size_limit = 64
125 |
126 | torch_compile_kwargs = {
127 | "fullgraph": True,
128 | "backend": "inductor",
129 | "options": {"triton.cudagraphs": True},
130 | }
131 | super().__init__(config, torch_compile_kwargs, torch_op=torch_op)
132 |
133 |
134 | class E3NNTensorProductCompiledMaxAutotuneCUDAGraphs(E3NNTensorProductCompiled):
135 | def __init__(self, config: TPProblem, torch_op=True):
136 | global torch
137 | import torch
138 |
139 | TORCH_COMPILE_AUTOTUNING_DIR.mkdir(exist_ok=True)
140 | os.environ["TORCHINDUCTOR_CACHE_DIR"] = str(TORCH_COMPILE_AUTOTUNING_DIR)
141 | os.environ["TRITON_CACHE_DIR"] = str(TORCH_COMPILE_AUTOTUNING_DIR)
142 | torch._dynamo.config.cache_size_limit = 64
143 |
144 | torch_compile_kwargs = {
145 | "fullgraph": True,
146 | "backend": "inductor",
147 | "options": {
148 | "max_autotune": True,
149 | "triton.cudagraphs": True,
150 | "triton.unique_kernel_names": False,
151 | "coordinate_descent_tuning": False,
152 | },
153 | }
154 | super().__init__(config, torch_compile_kwargs, torch_op=torch_op)
155 |
--------------------------------------------------------------------------------
/openequivariance/implementations/MultiplicityOuterProductTP.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from openequivariance.implementations.utils import calc_weight_offsets
4 | from openequivariance.implementations.e3nn_lite import (
5 | Irreps,
6 | TPProblem,
7 | Instruction,
8 | )
9 | from openequivariance.implementations.TensorProductBase import TensorProductBase
10 | from openequivariance.benchmark.logging_utils import getLogger
11 | from jinja2 import Environment, PackageLoader
12 |
13 | from openequivariance.extlib import KernelLaunchConfig, JITTPImpl, DeviceProp
14 |
15 | logger = getLogger()
16 |
17 |
18 | def raise_helper(msg):
19 | raise Exception(msg)
20 |
21 |
22 | def divide(numerator, denominator):
23 | return numerator // denominator
24 |
25 |
26 | def sizeof(dtype):
27 | if dtype in ["float", "int", "unsigned int"]:
28 | return 4
29 | else:
30 | raise Exception("Provided undefined datatype to sizeof!")
31 |
32 |
33 | class MultiplicityOuterProductTP(TensorProductBase):
34 | def __init__(self, config: TPProblem, torch_op: bool = False):
35 | super().__init__(config, torch_op)
36 |
37 | for ins in config.instructions: # type : Instruction
38 | assert isinstance(ins, Instruction)
39 | assert ins.connection_mode == "uvw"
40 | assert ins.path_shape[0] <= 32
41 | assert ins.path_shape[1] <= 32
42 | assert ins.path_shape[2] <= 32
43 |
44 | irreps_in1 = config.irreps_in1
45 | irreps_in2 = config.irreps_in2
46 | irreps_out = config.irreps_out
47 |
48 | # ==================================================================================
49 |
50 | env = Environment(
51 | loader=PackageLoader("openequivariance"), extensions=["jinja2.ext.do"]
52 | )
53 | env.globals["raise"] = raise_helper
54 | env.globals["divide"] = divide
55 | env.globals["sizeof"] = sizeof
56 | env.globals["range"] = range
57 | env.globals["enumerate"] = enumerate
58 | env.globals["len"] = len
59 | main_template = env.get_template("subkernel_per_interaction_multirep.cuh")
60 | # forward_subkernel_template = env.get_template("subkernel_forward_thread.cu.jinja2")
61 | # backward_subkernel_template = env.get_template("subkernel_backward_thread.cu.jinja2")
62 |
63 | # =====================================================================
64 | # Updated to work with TensorProductProblem
65 |
66 | class RepData:
67 | def __init__(self, irreps: Irreps):
68 | assert isinstance(irreps, Irreps)
69 | self.rep_len = irreps.dim
70 | self.irrep_lengths = [mul_irrep.ir.dim for mul_irrep in irreps]
71 | self.mults = [mul_irrep.mul for mul_irrep in irreps]
72 |
73 | offset = 0
74 | self.offsets = []
75 | for mul_irrep in irreps:
76 | self.offsets.append(offset)
77 | offset += mul_irrep.dim
78 |
79 | # =====================================================================
80 | # Strictly Copied from Loop Unroll TP
81 |
82 | class CGTensor:
83 | def __init__(self, l1, l2, l3):
84 | tensor = load_cg_tensor(l1, l2, l3)
85 | coord1, coord2, coord3 = [
86 | arr.astype(np.int32).copy() for arr in np.nonzero(tensor)
87 | ]
88 | float_values = tensor[np.nonzero(tensor)].astype(np.float32).copy()
89 | values = [str(float.hex(float(val))) + "f" for val in float_values]
90 |
91 | self.tuples = [
92 | (coord1[i], coord2[i], coord3[i], values[i])
93 | for i in range(len(values))
94 | ]
95 | # self.tuples.sort(key=lambda tup: (tup[1], tup[0], tup[2]))
96 | self.nnz = len(values)
97 |
98 | # =====================================================================
99 | # FORWARD MEMORY ANALYSIS
100 | forward_thread_blocks_per_SM = 24
101 | forward_threads_per_thread_block = 32
102 |
103 | # =====================================================================
104 | dp = DeviceProp(0)
105 |
106 | forward_launch_config = KernelLaunchConfig()
107 | forward_launch_config.num_blocks = (
108 | dp.multiprocessorCount * forward_thread_blocks_per_SM
109 | )
110 | forward_launch_config.num_threads = forward_threads_per_thread_block
111 |
112 | # IMPORTANT!
113 | smem_gemm_max_n = forward_threads_per_thread_block
114 | smem_gemm_L3_scratch = smem_gemm_max_n * max(
115 | RepData(config.irreps_out).irrep_lengths
116 | ) # this has space for the largest output size * 32
117 | smem_gemm_weights_scratch = (
118 | max(RepData(config.irreps_out).mults) * smem_gemm_max_n
119 | )
120 |
121 | smem_gemm_info = {
122 | "n": smem_gemm_max_n,
123 | "L3_scratch_elems": smem_gemm_L3_scratch,
124 | "weight_scratch_elems": smem_gemm_weights_scratch,
125 | }
126 | logger.debug(smem_gemm_info)
127 | # END OF IMPORTANT
128 |
129 | forward_launch_config.smem = (
130 | (
131 | irreps_in1.dim
132 | + irreps_in2.dim
133 | + irreps_out.dim
134 | + smem_gemm_L3_scratch
135 | + smem_gemm_weights_scratch
136 | )
137 | * sizeof("float")
138 | * forward_launch_config.num_threads
139 | // forward_launch_config.warp_size
140 | )
141 |
142 | logger.info(
143 | f"Forward pass needs {forward_launch_config.smem} bytes of shared memory."
144 | )
145 |
146 | if forward_launch_config.smem > dp.maxSharedMemPerBlock:
147 | raise Exception(
148 | f"Error, requested shared memory {forward_launch_config.smem}B hits or exceeds maximum, {dp.maxSharedMemPerBlock}B !"
149 | )
150 |
151 | # =====================================================================
152 |
153 | backward_launch_config = KernelLaunchConfig()
154 | backward_launch_config.num_blocks = dp.multiprocessorCount * 1
155 | backward_launch_config.num_threads = 32
156 | backward_launch_config.smem = (
157 |             (2 * irreps_in1.dim + 2 * irreps_in2.dim + 2 * irreps_out.dim)
158 | * sizeof("float")
159 | * backward_launch_config.num_threads
160 | // backward_launch_config.warp_size
161 | )
162 | logger.info(
163 | f"Backward pass needs {backward_launch_config.smem} bytes of shared memory."
164 | )
165 |
166 | if backward_launch_config.smem > dp.maxSharedMemPerBlock:
167 | raise Exception(
168 | f"Error, requested shared memory {backward_launch_config.smem}B hits or exceeds maximum, {dp.maxSharedMemPerBlock}B !"
169 | )
170 |
171 | # =====================================================================
172 |
173 | self.forward_config = forward_launch_config
174 | self.backward_config = backward_launch_config
175 | load_cg_tensor = self.load_cg_tensor
176 |
177 | # =====================================================================
178 | # weights_offsets
179 | weight_offsets = calc_weight_offsets(config)
180 | assert isinstance(weight_offsets, list)
181 | assert len(weight_offsets) == len(list(config.instructions))
182 |
183 | # =====================================================================
184 |         # transform "e3nn instructions" into "interactions"
185 | instructions: list[Instruction] = config.instructions
186 | interactions = []
187 | for ins in instructions:
188 | u = ins.i_in1
189 | v = ins.i_in2
190 | w = ins.i_out
191 | interaction = (
192 | u,
193 | v,
194 | w,
195 | CGTensor(irreps_in1[u].ir.l, irreps_in2[v].ir.l, irreps_out[w].ir.l),
196 | )
197 | interactions.append(interaction)
198 | # interactions.sort(key=lambda x: (x[2], x[0], x[1]))
199 |
200 | assert len(interactions) != 0
201 |
202 | # =====================================================================
203 | kernel_text = main_template.render(
204 | L1=RepData(config.irreps_in1),
205 | L2=RepData(config.irreps_in2),
206 | L3=RepData(config.irreps_out),
207 | weight_numel=config.weight_numel,
208 | weight_offsets=weight_offsets,
209 | instructions=instructions,
210 | interactions=interactions,
211 | smem_gemm_info=smem_gemm_info,
212 | forward_config=forward_launch_config,
213 | backward_config=backward_launch_config,
214 | )
215 |
216 | self.jit_kernel = kernel_text
217 |
218 | logger.debug(kernel_text)
219 |
220 | logger.info("Starting NVRTC")
221 | self.internal = JITTPImpl(
222 | self.jit_kernel, self.forward_config, self.backward_config
223 | )
224 | logger.info("Kernel compiled!")
225 |
226 | if self.torch_op:
227 | self.setup_torch_custom_op()
228 |
229 | @staticmethod
230 | def name():
231 | return "MultiplicityOuterProductTP"
232 |
--------------------------------------------------------------------------------
/openequivariance/implementations/TensorProduct.py:
--------------------------------------------------------------------------------
1 | from openequivariance.implementations.LoopUnrollTP import LoopUnrollTP
2 | from openequivariance import TPProblem
3 | import torch
4 |
5 |
6 | class TensorProduct(torch.nn.Module, LoopUnrollTP):
7 | """
8 | Drop-in replacement for ``o3.TensorProduct`` from e3nn. Supports forward,
9 | backward, and double-backward passes using JIT-compiled kernels. Initialization
10 | fails if:
11 |
12 | * There are no visible GPUs.
13 | * The provided tensor product specification is unsupported.
14 |
15 | :param problem: Specification of the tensor product.
16 | """
17 |
18 | def __init__(self, problem: TPProblem, torch_op=True):
19 | torch.nn.Module.__init__(self)
20 | LoopUnrollTP.__init__(self, problem, torch_op)
21 | self.weight_numel = problem.weight_numel
22 |
23 | @staticmethod
24 | def name():
25 | return LoopUnrollTP.name()
26 |
27 | def forward(
28 | self, x: torch.Tensor, y: torch.Tensor, W: torch.Tensor
29 | ) -> torch.Tensor:
30 | """
31 | Computes :math:`W (x \otimes_{\\textrm{CG}} y)`, identical to
32 | ``o3.TensorProduct.forward``.
33 |
34 | :param x: Tensor of shape ``[batch_size, problem.irreps_in1.dim()]``, datatype
35 | ``problem.irrep_dtype``.
36 | :param y: Tensor of shape ``[batch_size, problem.irreps_in2.dim()]``, datatype
37 | ``problem.irrep_dtype``.
38 | :param W: Tensor of datatype ``problem.weight_dtype`` and shape
39 |
40 | * ``[batch_size, problem.weight_numel]`` if ``problem.shared_weights=False``
41 | * ``[problem.weight_numel]`` if ``problem.shared_weights=True``
42 |
43 | :return: Tensor of shape ``[batch_size, problem.irreps_out.dim()]``, datatype ``problem.irrep_dtype``.
44 | """
45 | return torch.ops.libtorch_tp_jit.jit_tp_forward(self.internal, x, y, W)
46 |
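47 |
48 | # A minimal usage sketch. The irreps strings, instruction tuple, and batch size
49 | # below are illustrative only, and a CUDA device must be visible:
50 | #
51 | #   import torch
52 | #   from openequivariance import TPProblem, TensorProduct
53 | #
54 | #   problem = TPProblem("32x1e", "1x1e", "32x1e", [(0, 0, 0, "uvu", True)],
55 | #                       shared_weights=False, internal_weights=False)
56 | #   tp = TensorProduct(problem)
57 | #   x = torch.randn(128, problem.irreps_in1.dim, device="cuda")
58 | #   y = torch.randn(128, problem.irreps_in2.dim, device="cuda")
59 | #   W = torch.randn(128, tp.weight_numel, device="cuda")
60 | #   z = tp(x, y, W)  # shape [128, problem.irreps_out.dim]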
--------------------------------------------------------------------------------
/openequivariance/implementations/convolution/CUEConv.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import itertools
3 | from typing import Iterator
4 |
5 | from openequivariance.implementations.CUETensorProduct import CUETensorProduct
6 | from openequivariance.implementations.convolution.ConvolutionBase import ConvolutionBase
7 |
8 |
9 | class CUEConv(ConvolutionBase):
10 | def __init__(self, config, idx_dtype=np.int64, torch_op=True):
11 | super().__init__(config, idx_dtype, torch_op)
12 |
13 | global torch
14 | import torch
15 |
16 | self.reference_tp = CUETensorProduct(config, torch_op)
17 | self.cue_tp = self.reference_tp.cue_tp
18 |
19 | from openequivariance.implementations.convolution.scatter import scatter_sum
20 |
21 | self.scatter_sum = scatter_sum
22 |
23 | def forward(self, L1_in, L2_in, weights, rows, cols):
24 | tp_outputs = self.cue_tp(L1_in[cols], L2_in, weights)
25 | return self.scatter_sum(
26 | src=tp_outputs, index=rows, dim=0, dim_size=L1_in.shape[0]
27 | )
28 |
29 | @staticmethod
30 | def name():
31 | return "CUEConvolution"
32 |
33 |
34 | class CUEConvFused(ConvolutionBase):
35 | def __init__(self, config, idx_dtype=np.int64, torch_op=True):
36 | super().__init__(config, idx_dtype, torch_op)
37 |
38 | global torch
39 | import torch
40 | import e3nn.o3 as o3
41 |
42 | np_to_torch_dtype = {np.float32: torch.float32, np.float64: torch.float64}
43 |
44 | import cuequivariance as cue
45 | from cuequivariance_torch.primitives.tensor_product import (
46 | TensorProductUniform4x1dIndexed,
47 | )
48 |
49 | class O3_e3nn(cue.O3):
50 | def __mul__( # pylint: disable=no-self-argument
51 | rep1: "O3_e3nn", rep2: "O3_e3nn"
52 | ) -> Iterator["O3_e3nn"]:
53 | return [O3_e3nn(l=ir.l, p=ir.p) for ir in cue.O3.__mul__(rep1, rep2)]
54 |
55 | @classmethod
56 | def clebsch_gordan(
57 | cls, rep1: "O3_e3nn", rep2: "O3_e3nn", rep3: "O3_e3nn"
58 | ) -> np.ndarray:
59 | rep1, rep2, rep3 = cls._from(rep1), cls._from(rep2), cls._from(rep3)
60 |
61 | if rep1.p * rep2.p == rep3.p:
62 | return o3.wigner_3j(rep1.l, rep2.l, rep3.l).numpy()[None] * np.sqrt(
63 | rep3.dim
64 | )
65 | return np.zeros((0, rep1.dim, rep2.dim, rep3.dim))
66 |
67 | def __lt__( # pylint: disable=no-self-argument
68 | rep1: "O3_e3nn", rep2: "O3_e3nn"
69 | ) -> bool:
70 | rep2 = rep1._from(rep2)
71 | return (rep1.l, rep1.p) < (rep2.l, rep2.p)
72 |
73 | @classmethod
74 | def iterator(cls) -> Iterator["O3_e3nn"]:
75 | for l in itertools.count(0):
76 | yield O3_e3nn(l=l, p=1 * (-1) ** l)
77 | yield O3_e3nn(l=l, p=-1 * (-1) ** l)
78 |
79 | descriptor = (
80 | cue.descriptors.channelwise_tensor_product(
81 | cue.Irreps(O3_e3nn, str(config.irreps_in1)),
82 | cue.Irreps(O3_e3nn, str(config.irreps_in2)),
83 | cue.Irreps(O3_e3nn, str(config.irreps_out)),
84 | )
85 | .squeeze_modes()
86 | .flatten_coefficient_modes()
87 | )
88 |
89 | self.tp = TensorProductUniform4x1dIndexed(
90 | descriptor.polynomial.operations[0][1],
91 | "cuda",
92 | math_dtype=np_to_torch_dtype[config.irrep_dtype],
93 | )
94 |
95 | def forward(self, L1_in, L2_in, weights, rows, cols):
96 | return self.tp(weights, L1_in, L2_in, None, rows, None, cols, L1_in.shape[0])
97 |
98 | @staticmethod
99 | def name():
100 | return "CUEConvolutionFused"
101 |
--------------------------------------------------------------------------------
/openequivariance/implementations/convolution/E3NNConv.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from openequivariance.implementations.convolution.ConvolutionBase import ConvolutionBase
4 | from openequivariance.implementations.E3NNTensorProduct import E3NNTensorProduct
5 |
6 |
7 | class E3NNConv(ConvolutionBase):
8 | def __init__(self, config, idx_dtype=np.int64, torch_op=True):
9 | assert torch_op
10 | super().__init__(config, idx_dtype, torch_op)
11 |
12 | from e3nn import o3
13 | import torch
14 |
15 | if config.irrep_dtype == np.float64:
16 | torch.set_default_dtype(torch.float64)
17 |
18 | self.e3nn_tp = o3.TensorProduct(
19 | config.irreps_in1,
20 | config.irreps_in2,
21 | config.irreps_out,
22 | config.instructions_raw,
23 | in1_var=config.in1_var,
24 | in2_var=config.in2_var,
25 | out_var=config.out_var,
26 | irrep_normalization=config.irrep_normalization,
27 | path_normalization=config.path_normalization,
28 | internal_weights=config.internal_weights,
29 | shared_weights=config.shared_weights,
30 | ).to(device="cuda")
31 |
32 | self.reference_tp = E3NNTensorProduct(config)
33 |
34 | if config.irrep_dtype == np.float64:
35 | torch.set_default_dtype(torch.float32) # Reset to default
36 |
37 | from openequivariance.implementations.convolution.scatter import scatter_sum
38 |
39 | self.scatter_sum = scatter_sum
40 |
41 | def forward(self, L1_in, L2_in, weights, rows, cols):
42 | tp_outputs = self.reference_tp(L1_in[cols], L2_in, weights)
43 | return self.scatter_sum(
44 | src=tp_outputs, index=rows, dim=0, dim_size=L1_in.shape[0]
45 | )
46 |
47 | @staticmethod
48 | def name():
49 | return "E3NNConvolution"
50 |
51 | def forward_cpu(self, L1_in, L2_in, weights, L3_out, graph):
52 | tp_outputs = np.zeros((graph.nnz, self.L3.dim), dtype=L3_out.dtype)
53 | self.reference_tp.forward_cpu(L1_in[graph.cols], L2_in, tp_outputs, weights)
54 | np.add.at(L3_out, graph.rows, tp_outputs)
55 |
56 | def backward_cpu(
57 | self,
58 | L1_in: np.ndarray,
59 | L1_grad: np.ndarray,
60 | L2_in: np.ndarray,
61 | L2_grad: np.ndarray,
62 | L3_grad: np.ndarray,
63 | weights: np.ndarray,
64 | weights_grad: np.ndarray,
65 | graph,
66 | ):
67 | L1_grad_bcast = np.zeros((graph.nnz, self.L1.dim), dtype=L1_grad.dtype)
68 | self.reference_tp.backward_cpu(
69 | L1_in[graph.cols],
70 | L1_grad_bcast,
71 | L2_in,
72 | L2_grad,
73 | L3_grad[graph.rows],
74 | weights,
75 | weights_grad,
76 | )
77 | np.add.at(L1_grad, graph.cols, L1_grad_bcast)
78 |
--------------------------------------------------------------------------------
/openequivariance/implementations/convolution/TensorProductConv.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | import types
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from openequivariance import extlib
8 | from openequivariance.implementations.convolution.ConvolutionBase import ConvolutionBase
9 | from openequivariance.implementations.convolution.LoopUnrollConv import LoopUnrollConv
10 | from openequivariance.implementations.TensorProduct import TensorProduct
11 | from openequivariance import TPProblem
12 |
13 |
14 | class TensorProductConv(torch.nn.Module, LoopUnrollConv):
15 | """
16 | Given a **symmetric, directed** graph :math:`G = (V, E)`, inputs :math:`x_1...x_{|V|}`,
17 | :math:`y_1...y_{|E|}`, and weights :math:`W_1...W_{|E|}`, computes
18 |
19 | .. math::
20 |
21 | z_i = \sum_{(i, j, e) \in \mathcal{N}(i)} W_e (x_j \otimes_{\\textrm{CG}} y_e)
22 |
23 | where :math:`(i, j, e) \in \mathcal{N}(i)` indicates that node :math:`i` is connected to node :math:`j`
24 | via the edge indexed :math:`e`.
25 |
26 |     This class offers two ways to perform the summation: an atomic algorithm and a deterministic algorithm
27 |     that relies on a sorted adjacency matrix input. If you use the deterministic algorithm, you must also supply
28 |     a permutation that transposes the adjacency matrix.
29 |
30 |     :param problem: Specification of the tensor product.
31 |     :param deterministic: if ``False``, uses atomics for the convolution. If ``True``, uses a deterministic
32 |         fixup-based algorithm. *Default*: ``False``.
33 |     :param kahan: if ``True``, uses Kahan summation to improve accuracy during aggregation. To use this option,
34 |         the input tensors must be in float32 precision AND you must set ``deterministic=True``. *Default*: ``False``.
35 |
36 | """
37 |
38 | def __init__(
39 | self,
40 | problem: TPProblem,
41 | deterministic: bool = False,
42 | kahan: bool = False,
43 | torch_op=True,
44 | ):
45 | torch.nn.Module.__init__(self)
46 | LoopUnrollConv.__init__(
47 | self,
48 | problem,
49 | idx_dtype=np.int64,
50 | torch_op=torch_op,
51 | deterministic=deterministic,
52 | kahan=kahan,
53 | )
54 |
55 | self.dummy_transpose_perm = torch.zeros(1, dtype=torch.int64, device="cuda")
56 | self.weight_numel = self.config.weight_numel
57 |
58 | if not extlib.TORCH_COMPILE:
59 | self.forward = types.MethodType(LoopUnrollConv.forward, self)
60 |
61 | def forward(
62 | self,
63 | X: torch.Tensor,
64 | Y: torch.Tensor,
65 | W: torch.Tensor,
66 | rows: torch.Tensor,
67 | cols: torch.Tensor,
68 | sender_perm: Optional[torch.Tensor] = None,
69 | ) -> torch.Tensor:
70 | """
71 | Computes the fused CG tensor product + convolution.
72 |
73 | :param X: Tensor of shape ``[|V|, problem.irreps_in1.dim()]``, datatype ``problem.irrep_dtype``.
74 |     :param Y: Tensor of shape ``[|E|, problem.irreps_in2.dim()]``, datatype ``problem.irrep_dtype``.
75 | :param W: Tensor of datatype ``problem.weight_dtype`` and shape
76 |
77 | * ``[|E|, problem.weight_numel]`` if ``problem.shared_weights=False``
78 | * ``[problem.weight_numel]`` if ``problem.shared_weights=True``
79 |
80 | :param rows: Tensor of shape ``[|E|]`` with row indices for each nonzero in the adjacency matrix,
81 | datatype ``torch.int64``. Must be row-major sorted along with ``cols`` when ``deterministic=True``.
82 | :param cols: Tensor of shape ``[|E|]`` with column indices for each nonzero in the adjacency matrix,
83 | datatype ``torch.int64``.
84 | :param sender_perm: Tensor of shape ``[|E|]`` and ``torch.int64`` datatype containing a
85 | permutation that transposes the adjacency matrix nonzeros from row-major to column-major order.
86 | Must be provided when ``deterministic=True``.
87 |
88 | :return: Tensor of shape ``[|V|, problem.irreps_out.dim()]``, datatype ``problem.irrep_dtype``.
89 | """
90 | if sender_perm is None:
91 | return torch.ops.libtorch_tp_jit.jit_conv_forward(
92 | self.internal,
93 | X,
94 | Y,
95 | W,
96 | rows,
97 | cols,
98 | self.workspace_buffer,
99 | self.dummy_transpose_perm,
100 | )
101 | else:
102 | return torch.ops.libtorch_tp_jit.jit_conv_forward(
103 | self.internal,
104 | X,
105 | Y,
106 | W,
107 | rows,
108 | cols,
109 | self.workspace_buffer,
110 | sender_perm,
111 | )
112 |
113 | @staticmethod
114 | def name():
115 | return LoopUnrollConv.name()
116 |
117 |
118 | # ==================================================================
119 | # Reference implementations for benchmarking
120 |
121 |
122 | class TensorProductConvKahan(TensorProductConv):
123 | def __init__(self, config, idx_dtype=np.int64, torch_op=True):
124 |         super().__init__(config, deterministic=True, kahan=True, torch_op=torch_op)
125 |
126 | @staticmethod
127 | def name():
128 | return "LoopUnrollConvKahan"
129 |
130 |
131 | class TensorProductConvDeterministic(TensorProductConv):
132 | def __init__(self, config, idx_dtype=np.int64, torch_op=True):
133 |         super().__init__(config, deterministic=True, torch_op=torch_op)
134 |
135 | @staticmethod
136 | def name():
137 | return "LoopUnrollConvDeterministic"
138 |
139 |
140 | class TensorProductConvAtomic(TensorProductConv):
141 | def __init__(self, config, idx_dtype=np.int64, torch_op=True):
142 |         super().__init__(config, deterministic=False, torch_op=torch_op)
143 |
144 | @staticmethod
145 | def name():
146 | return "LoopUnrollConvAtomic"
147 |
148 |
149 | class TensorProductConvScatterSum(ConvolutionBase):
150 | def __init__(self, config, idx_dtype=np.int64, torch_op=True):
151 | assert torch_op
152 | global torch
153 | import torch
154 |
155 | super().__init__(config, idx_dtype, torch_op=torch_op, deterministic=False)
156 |
157 | self.reference_tp = TensorProduct(config, torch_op=torch_op)
158 | from openequivariance.implementations.convolution.scatter import scatter_sum
159 |
160 | self.scatter_sum = scatter_sum
161 |
162 | def forward(self, L1_in, L2_in, weights, rows, cols):
163 | tp_outputs = self.reference_tp(L1_in[cols], L2_in, weights)
164 | return self.scatter_sum(
165 | src=tp_outputs, index=rows, dim=0, dim_size=L1_in.shape[0]
166 | )
167 |
168 | def forward_cpu(self, L1_in, L2_in, weights, L3_out, graph):
169 | tp_outputs = np.zeros((graph.nnz, self.L3.dim), dtype=L3_out.dtype)
170 | self.reference_tp.forward_cpu(L1_in[graph.cols], L2_in, tp_outputs, weights)
171 | np.add.at(L3_out, graph.rows, tp_outputs)
172 |
173 | def backward_cpu(
174 | self,
175 | L1_in: np.ndarray,
176 | L1_grad: np.ndarray,
177 | L2_in: np.ndarray,
178 | L2_grad: np.ndarray,
179 | L3_grad: np.ndarray,
180 | weights: np.ndarray,
181 | weights_grad: np.ndarray,
182 | graph,
183 | ):
184 | L1_grad_bcast = np.zeros((graph.nnz, self.L1.dim), dtype=L1_grad.dtype)
185 | self.reference_tp.backward_cpu(
186 | L1_in[graph.cols],
187 | L1_grad_bcast,
188 | L2_in,
189 | L2_grad,
190 | L3_grad[graph.rows],
191 | weights,
192 | weights_grad,
193 | )
194 | np.add.at(L1_grad, graph.cols, L1_grad_bcast)
195 |
196 | @staticmethod
197 | def name():
198 | return "LoopUnrollConvScatterSum"
199 |
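200 |
201 | # A minimal usage sketch for TensorProductConv. Shapes and the `problem` object
202 | # are illustrative (construct a TPProblem as for TensorProduct); a CUDA device
203 | # is required:
204 | #
205 | #   conv = TensorProductConv(problem, deterministic=False)
206 | #   X = torch.randn(num_nodes, problem.irreps_in1.dim, device="cuda")
207 | #   Y = torch.randn(num_edges, problem.irreps_in2.dim, device="cuda")
208 | #   W = torch.randn(num_edges, problem.weight_numel, device="cuda")
209 | #   Z = conv(X, Y, W, rows, cols)  # rows/cols: int64 tensors of shape [num_edges]
210 | #   # Z has shape [num_nodes, problem.irreps_out.dim]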
--------------------------------------------------------------------------------
/openequivariance/implementations/convolution/scatter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Optional
3 |
4 | """
5 | Scatter sum operator from MACE.
6 |
7 | Basic scatter_sum operation adapted from torch_scatter via
8 | https://github.com/mir-group/pytorch_runstats/blob/main/torch_runstats/scatter_sum.py,
9 | using code from https://github.com/rusty1s/pytorch_scatter, cut down to avoid a dependency.
10 | """
11 |
12 |
13 | def _broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
14 | if dim < 0:
15 | dim = other.dim() + dim
16 | if src.dim() == 1:
17 | for _ in range(0, dim):
18 | src = src.unsqueeze(0)
19 | for _ in range(src.dim(), other.dim()):
20 | src = src.unsqueeze(-1)
21 | src = src.expand_as(other)
22 | return src
23 |
24 |
25 | def scatter_sum(
26 | src: torch.Tensor,
27 | index: torch.Tensor,
28 | dim: int = -1,
29 | out: Optional[torch.Tensor] = None,
30 | dim_size: Optional[int] = None,
31 | reduce: str = "sum",
32 | ) -> torch.Tensor:
33 | assert reduce == "sum" # for now, TODO
34 | index = _broadcast(index, src, dim)
35 | if out is None:
36 | size = list(src.size())
37 | if dim_size is not None:
38 | size[dim] = dim_size
39 | elif index.numel() == 0:
40 | size[dim] = 0
41 | else:
42 | size[dim] = int(index.max()) + 1
43 | out = torch.zeros(size, dtype=src.dtype, device=src.device)
44 | return out.scatter_add_(dim, index, src)
45 | else:
46 | return out.scatter_add_(dim, index, src)
47 |
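48 |
49 | # Example of the semantics: three source rows summed into two output slots.
50 | #
51 | #   src = torch.tensor([[1.0], [2.0], [3.0]])
52 | #   index = torch.tensor([0, 1, 0])
53 | #   scatter_sum(src, index, dim=0, dim_size=2)  # -> tensor([[4.], [2.]])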
--------------------------------------------------------------------------------
/openequivariance/implementations/symmetric_contraction/__init__.py:
--------------------------------------------------------------------------------
1 | from openequivariance.implementations.symmetric_contraction.symmetric_contraction import (
2 | SymmetricContraction,
3 | )
4 |
5 | __all__ = ["SymmetricContraction"]
6 |
--------------------------------------------------------------------------------
/openequivariance/implementations/utils.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import math
3 |
4 | import numpy as np
5 |
6 | from openequivariance.implementations.TensorProductBase import TensorProductBase
7 | from openequivariance.implementations.e3nn_lite import (
8 | Irreps,
9 | Instruction,
10 | TPProblem,
11 | )
12 |
13 |
14 | def sparse_outer_product_work(cg: np.ndarray) -> int:
15 | return np.sum(np.max(cg != 0, axis=2))
16 |
17 |
18 | def convenience_namer(L1: Irreps, L2: Irreps, L3: Irreps):
19 | return f"({L1}x{L2}->{L3})"
20 |
21 |
22 | # Non Zeros
23 | @functools.lru_cache(typed=True)
24 | def count_cg_non_zero(l1, l2, l3) -> int:
25 | return np.count_nonzero(TensorProductBase.load_cg_tensor(l1, l2, l3))
26 |
27 |
28 | def calculate_total_nnz(tpp: TPProblem) -> int:
29 | """
30 |     Counts nonzero CG coefficients once per distinct (l1, l2, l3) combination, so repeated CG tensors are not double-counted.
31 | """
32 | nnz_by_l_combo = {}
33 | for ins in tpp.instructions: # type : Instruction
34 | l1 = tpp.irreps_in1[ins.i_in1].ir.l
35 | l2 = tpp.irreps_in2[ins.i_in2].ir.l
36 | l3 = tpp.irreps_out[ins.i_out].ir.l
37 | assert isinstance(l1, int)
38 | assert isinstance(l2, int)
39 | assert isinstance(l3, int)
40 | nnz_by_l_combo[(l1, l2, l3)] = count_cg_non_zero(l1, l2, l3)
41 | return sum(nnz_by_l_combo.values())
42 |
43 |
44 | def calc_weight_offsets(tpp: TPProblem) -> list[int]:
45 | """
46 | Returns a list of weight offsets for every instruction.
47 | """
48 | assert isinstance(tpp, TPProblem)
49 | offset = 0
50 | offsets = []
51 | for ins in tpp.instructions:
52 | assert isinstance(ins, Instruction)
53 | offsets.append(offset)
54 | if ins.has_weight:
55 | flatsize = math.prod(ins.path_shape)
56 | offset += flatsize
57 | return offsets
58 |
59 |
60 | def filter_and_analyze_problem(problem):
61 | """
62 | Centralized function that stops unhandled problem configurations,
63 | returns a dictionary of useful information about the problem.
64 | """
65 | for i, inst in enumerate(problem.instructions):
66 | assert inst.connection_mode == problem.instructions[0].connection_mode, (
67 | f"All instructions must have the same connection mode, got {inst.connection_mode} and {problem.instructions[0].connection_mode}"
68 | )
69 |
70 | assert inst.has_weight, (
71 | f"All instructions must have trainable weights, got {inst.has_weight} at index {i}"
72 | )
73 |
74 | assert problem.instructions[0].connection_mode in ["uvu", "uvw"], (
75 | f"Connection mode must be 'uvu' or 'uvw', got {problem.instructions[0].connection_mode}"
76 | )
77 |
78 | assert problem.irrep_dtype == problem.weight_dtype, (
79 | f"irrep_dtype and weight_dtype must be the same, got {problem.irrep_dtype} and {problem.weight_dtype}"
80 | )
81 |
82 | assert not problem.internal_weights, (
83 | f"Openequivariance does not support internal weights, got {problem.internal_weights}"
84 | )
85 |
86 | assert len(problem.instructions) > 0, "Tensor product has no valid instructions!"
87 |
88 | result = {
89 | "is_uvw": problem.instructions[0].connection_mode == "uvw",
90 | }
91 | return result
92 |
93 |
94 | def torch_to_oeq_dtype(torch_dtype) -> type[np.generic]:
95 | """
96 | Convenience function; converts a torch datatype to the corresponding
97 | numpy datatype for use in TPProblem.
98 |
99 | :param torch_dtype: torch datatype (e.g., torch.float32, torch.float64)
100 | :return: numpy datatype (e.g., np.float32, np.float64)
101 | """
102 |
103 | global torch
104 | import torch
105 |
106 | if torch_dtype == torch.float32:
107 | return np.float32
108 | elif torch_dtype == torch.float64:
109 | return np.float64
110 | else:
111 | raise ValueError("Unsupported torch dtype!")
112 |
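113 |
114 | # Worked example for calc_weight_offsets: a problem whose two weighted
115 | # instructions have path shapes (2, 2, 2) and (3, 3, 3) yields offsets [0, 8],
116 | # since the second instruction's weights begin after the 2 * 2 * 2 = 8 weights
117 | # of the first.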
--------------------------------------------------------------------------------
/openequivariance/templates/common.cuh:
--------------------------------------------------------------------------------
1 | #define ROW_OPERATION(ROW_LEN, LOOP_VAR, ...) \
2 | _Pragma ("unroll") \
3 | for(int LOOP_VAR = 0; LOOP_VAR < ROW_LEN; LOOP_VAR += THREADS_PER_WARP) { \
4 | if(LOOP_VAR >= ROW_LEN - THREADS_PER_WARP) { \
5 | if(lane_id < ROW_LEN - LOOP_VAR) { \
6 | __VA_ARGS__ \
7 | } \
8 | } \
9 | else { \
10 | __VA_ARGS__ \
11 | } \
12 | }
13 |
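14 | // Example expansion: one warp copies ROW_LEN contiguous elements, with the
15 | // final iteration guarded so lanes past the end of the row stay idle.
16 | // Assumes THREADS_PER_WARP and lane_id are in scope, as in the kernels that
17 | // include this header:
18 | //
19 | //   ROW_OPERATION(rep_len, j, dst[j + lane_id] = src[j + lane_id];)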
--------------------------------------------------------------------------------
/openequivariance/templates/jinja_utils.py:
--------------------------------------------------------------------------------
1 | from jinja2 import Environment, PackageLoader
2 |
3 |
4 | def raise_helper(msg):
5 | raise Exception(msg)
6 |
7 |
8 | def divide(numerator, denominator):
9 | return numerator // denominator
10 |
11 |
12 | def sizeof(dtype):
13 | if dtype in ["float", "int", "unsigned int"]:
14 | return 4
15 | else:
16 | raise Exception("Provided undefined datatype to sizeof!")
17 |
18 |
19 | def get_jinja_environment():
20 | env = Environment(
21 | loader=PackageLoader("openequivariance"), extensions=["jinja2.ext.do"]
22 | )
23 | env.globals["raise"] = raise_helper
24 | env.globals["divide"] = divide
25 | env.globals["sizeof"] = sizeof
26 | env.globals["enumerate"] = enumerate
27 | return env
28 |
--------------------------------------------------------------------------------
/openequivariance/templates/macros.jinja:
--------------------------------------------------------------------------------
1 | {#
2 | First input argument consists of a dictionary with keys _common_ and _per_warp_.
3 | Keys map to lists of tuples with (name, dtype, num_elements) of each subarray.
4 | #}
5 | {%- macro declare_smem_arrays(arrays, warp_loc_var, config) %}
6 | {%- set warps_per_block = divide(config.num_threads, config.warp_size) %}
7 | extern __shared__ char s[];
8 | {%- set ns = {"offset": 0, "total_warp_bytes": 0} %}
9 | {%- for name, dtype, num_elements in arrays["common"] %}
10 | {{dtype}}* {{name}} = ({{dtype}}*) (s + {{ ns["offset"] }});
11 | {%- do ns.update({"offset": ns["offset"] + num_elements * sizeof(dtype)}) %}
12 | {%- if ns["offset"] > config.smem %}
13 | {{ raise("Error, required shared memory exceeds allocation maximum!") }}
14 | {%- endif %}
15 | {%- endfor %}
16 |
17 | {%- for name, dtype, num_elements in arrays["per_warp"] %}
18 | {% do ns.update({"total_warp_bytes": ns["total_warp_bytes"] + num_elements * sizeof(dtype)}) %}
19 | {%- endfor %}
20 |
21 | {%- if ns["offset"] + ns["total_warp_bytes"] * warps_per_block > config.smem %}
22 | {{ raise("Error, required shared memory exceeds allocation maximum!") }}
23 | {%- endif %}
24 |
25 | char* per_warp_smem = s + {{ns["offset"]}} + {{ns["total_warp_bytes"]}} * {{ warp_loc_var }};
26 |
27 | {%- do ns.update({"offset": 0}) %}
28 | {%- for name, dtype, num_elements in arrays["per_warp"] %}
29 | {{dtype}}* {{name}} = ({{dtype}}*) (per_warp_smem + {{ ns["offset"] }});
30 | {% do ns.update({"offset": ns["offset"] + num_elements * sizeof(dtype)}) %}
31 | {%- endfor %}
32 | {%- endmacro %}
33 |
34 | {# smem contains a mul_ir stored in row-major order as mul * rep, where mul
35 | is at most |warp_size|. reg is at least a |rep|-sized register array on each thread.
36 | Assumes: each thread has the lane_id. #}
37 | {%- macro transpose_load(mul, dim, smem, offset, reg) %}
38 | if(lane_id < {{mul}}) {
39 | {%- for i in range(dim) %}
40 | {{reg}}[{{i}}] = {{smem}}[{{offset}} + lane_id * {{dim}} + {{i}}];
41 | {%- endfor %}
42 | }
43 | {%- endmacro %}
44 |
45 | {%- macro transpose_store(mul, dim, smem, offset, reg, op, coeff) %}
46 | if(lane_id < {{mul}}) {
47 | {%- for i in range(dim) %}
48 | {{smem}}[{{offset}} + lane_id * {{dim}} + {{i}}] {{op}} {{reg}}[{{i}}] * {{coeff}};
49 | {%- endfor %}
50 | }
51 | {%- endmacro %}
52 |
53 | {%- macro declare_smem_variables(segment, smem_base) %}
54 | {%- for name in segment.smem %}
55 | {%- if name != "total" %}
56 | {%- set smem_rng = segment.smem[name] %}
57 | {{ smem_rng["dtype"] }}* {{name}}_smem = ({{smem_rng["dtype"]}}*) ({{smem_base}} + {{smem_rng["offset"]}});
58 | {%- endif %}
59 | {%- endfor %}
60 | {%- endmacro %}
61 |
62 | {%- macro load_ir_segments(map, glb_ptr_shft, smem_ptr, loop_var) %}
63 | {%- if not map.persist_load %}
64 | {%- for (src_rng, dst_rng) in map.copy_ranges %}
65 | {%- set range_len = src_rng.stop - src_rng.start %}
66 | ROW_OPERATION({{range_len}}, {{loop_var}}, {{smem_ptr}}[{{loop_var}} + {{dst_rng.start}} + lane_id] = {{glb_ptr_shft}}[{{loop_var}} + {{src_rng.start}}];)
67 | {%- endfor %}
68 | {%- endif %}
69 | {%- endmacro %}
70 |
71 | {%- macro load_ir_segments_force(map, glb_ptr_shft, smem_ptr, loop_var) %}
72 | {%- for (src_rng, dst_rng) in map.copy_ranges %}
73 | {%- set range_len = src_rng.stop - src_rng.start %}
74 | ROW_OPERATION({{range_len}}, {{loop_var}}, {{smem_ptr}}[{{loop_var}} + {{dst_rng.start}} + lane_id] = {{glb_ptr_shft}}[{{loop_var}} + {{src_rng.start}}];)
75 | {%- endfor %}
76 | {%- endmacro %}
77 |
78 | {%- macro store_ir_segments(map, glb_ptr_shft, smem_ptr, loop_var) %}
79 | {%- if not map.persist_store %}
80 | {%- for i, src_rng in enumerate(map.original_src_ranges) %}
81 | {%- set idx = map.idxs[i] %}
82 | {%- set dst_rng = map.original_dst_ranges[i] %}
83 | {%- set range_len = src_rng.stop - src_rng.start %}
84 | {%- if map.storeback_procedure[idx] == "write" %}
85 | ROW_OPERATION({{range_len}}, {{loop_var}}, {{glb_ptr_shft}}[{{loop_var}} + {{src_rng.start}}] = {{smem_ptr}}[{{loop_var}} + {{dst_rng.start}} + lane_id];)
86 | {%- elif map.storeback_procedure[idx] == "accumulate" %}
87 | ROW_OPERATION({{range_len}}, {{loop_var}}, {{glb_ptr_shft}}[{{loop_var}} + {{src_rng.start}}] += {{smem_ptr}}[{{loop_var}} + {{dst_rng.start}} + lane_id];)
88 | {%- elif map.storeback_procedure[idx] == "atomic_accumulate" %}
89 | ROW_OPERATION({{range_len}}, {{loop_var}}, atomicAdd({{glb_ptr_shft}} + {{src_rng.start}} + {{loop_var}}, {{smem_ptr}}[{{dst_rng.start}} + lane_id + {{loop_var}}]);)
90 | {%- endif %}
91 | {%- endfor %}
92 | {% endif %}
93 | {%- endmacro %}
94 |
95 | {%- macro set_launch_bound_variables(config) %}
96 | {%- set threads_per_warp = config.warp_size %}
97 | {%- set warps_per_block = divide(config.num_threads, config.warp_size) %}
98 | int t_idx = blockIdx.x * blockDim.x + threadIdx.x;
99 | int warp_id = t_idx / {{ threads_per_warp }};
100 | int lane_id = t_idx % {{ threads_per_warp }};
101 | int warp_loc = warp_id % {{ warps_per_block }};
102 | size_t warps_launched = blockDim.x * gridDim.x / {{ threads_per_warp }};
103 | size_t nnz_per_warp = (num_products + warps_launched - 1) / warps_launched;
104 |
105 | size_t start = nnz_per_warp * ((size_t) warp_id);
106 | size_t end = min(start + nnz_per_warp, num_products);
107 | {%- endmacro %}
108 |
109 | {%- macro transpose_smem_A(irreps, smem_ptr) %}
110 | {%- set slices = irreps.slices() %}
111 | {%- for i, mul_ir in enumerate(irreps) %} {
112 | {%- set dim = mul_ir.ir.dim %}
113 | {%- set mul = mul_ir.mul %}
114 | IRREP_T t_regs[{{dim}}];
115 | if(lane_id < {{mul}}) {
116 | {%- set offset = slices[i].start %}
117 | {%- for i in range(dim) %}
118 | t_regs[{{i}}] = {{smem_ptr}}[{{offset}} + lane_id * {{dim}} + {{i}}];
119 | {%- endfor %}
120 | __syncwarp();
121 | {%- for i in range(dim) %}
122 | {{smem_ptr}}[{{offset}} + lane_id + {{i * mul}}] = t_regs[{{i}}];
123 | {%- endfor %}
124 | }
125 | } {%- endfor %}
126 | {%- endmacro %}
127 |
128 | {%- macro transpose_smem_B(irreps, smem_ptr) %}
129 | {%- set slices = irreps.slices() %}
130 | {%- for i, mul_ir in enumerate(irreps) %} {
131 | {%- set dim = mul_ir.ir.dim %}
132 | {%- set mul = mul_ir.mul %}
133 | IRREP_T t_regs[{{dim}}];
134 | if(lane_id < {{mul}}) {
135 | {%- set offset = slices[i].start %}
136 | {%- for i in range(dim) %}
137 | t_regs[{{i}}] = {{smem_ptr}}[{{offset}} + lane_id + {{i * mul}}];
138 | {%- endfor %}
139 | __syncwarp();
140 | {%- for i in range(dim) %}
141 | {{smem_ptr}}[{{offset}} + lane_id * {{dim}} + {{i}}] = t_regs[{{i}}];
142 | {%- endfor %}
143 | }
144 | } {%- endfor %}
145 | {%- endmacro %}
146 |
147 | {%- macro reg_load(mul, dim, smem, offset, reg) %}
148 | if(lane_id < {{mul}}) {
149 | {%- for i in range(dim) %}
150 | {{reg}}[{{i}}] = {{smem}}[{{offset}} + lane_id + {{i * mul}}];
151 | {%- endfor %}
152 | }
153 | {%- endmacro %}
154 |
155 | {%- macro reg_store(mul, dim, smem, offset, reg, op, coeff) %}
156 | if(lane_id < {{mul}}) {
157 | {%- for i in range(dim) %}
158 | {{smem}}[{{offset}} + lane_id + {{i * mul}}] {{op}} {{reg}}[{{i}}] * {{coeff}};
159 | {%- endfor %}
160 | }
161 | {%- endmacro %}
162 |
163 | {%- macro launch_bounds(schedule) %}
164 | __launch_bounds__({{schedule.launch_config.num_threads}})
165 | {%- endmacro %}
166 |
--------------------------------------------------------------------------------
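The `set_launch_bound_variables` macro above splits the `num_products` nonzeros evenly across every launched warp with a ceiling division, so each warp owns one contiguous `[start, end)` range. A minimal Python sketch of the same arithmetic (the function name and signature are illustrative, not part of the templates):

    # Per-warp work partition, mirroring set_launch_bound_variables.
    def warp_ranges(num_products, num_blocks, num_threads, warp_size):
        warps_launched = (num_blocks * num_threads) // warp_size
        # Ceiling division: the last warp may receive a short (or empty) range.
        nnz_per_warp = (num_products + warps_launched - 1) // warps_launched
        return [
            (w * nnz_per_warp, min((w + 1) * nnz_per_warp, num_products))
            for w in range(warps_launched)
        ]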
/openequivariance/templates/wmm.cuh:
--------------------------------------------------------------------------------
1 | {%- macro generate_matmul(name, M, N, K, TILES_PER_ROW, OUTPUT_RMAJOR, warp_size, A_CMAJOR=True, B_RMAJOR=True, accum=True) %}
2 |
3 | {%- set TILES_PER_COL = warp_size // TILES_PER_ROW %}
4 |
5 | template<typename T>
6 | __device__ __forceinline__ void {{name}}(const T* __restrict__ A, const T* __restrict__ B, T* C) {
7 | int t_idx = threadIdx.x + blockIdx.x * blockDim.x;
8 | int lane_id = t_idx % {{warp_size}};
9 |
10 | int const rpt = {{(M + TILES_PER_COL - 1) // TILES_PER_COL}};
11 | int const cpt = {{(N + TILES_PER_ROW - 1) // TILES_PER_ROW}};
12 |
13 | T row[cpt];
14 | T col[rpt];
15 | T tile[rpt][cpt];
16 |
17 | int TI_idx = lane_id / {{TILES_PER_ROW}};
18 | int TJ_idx = lane_id % {{TILES_PER_ROW}};
19 | int is = TI_idx * rpt; int ie = (TI_idx + 1) * rpt;
20 | int js = TJ_idx * cpt; int je = (TJ_idx + 1) * cpt;
21 | int ist = min(is, {{M}}); int iet = min(ie, {{M}});
22 | int jst = min(js, {{N}}); int jet = min(je, {{N}});
23 |
24 | // Zero the output tile
25 | #pragma unroll
26 | for(int i = 0; i < rpt; i++) {
27 | #pragma unroll
28 | for(int j = 0; j < cpt; j++) {
29 | tile[i][j] = 0.0f;
30 | }
31 | }
32 |
33 | for(int k = 0; k < {{K}}; k++) {
34 | #pragma unroll
35 | for(int i = 0; i < rpt; i++) {
36 | if(ist + i < {{M}}) {
37 | {%- if A_CMAJOR %}
38 | col[i] = A[k * {{M}} + ist + i];
39 | {%- else %}
40 | col[i] = A[(ist + i) * {{K}} + k];
41 | {%- endif %}
42 | }
43 | }
44 |
45 | #pragma unroll
46 | for(int j = 0; j < cpt; j++) {
47 | if(jst + j < {{N}}) {
48 | {%- if B_RMAJOR %}
49 | row[j] = B[k * {{N}} + jst + j];
50 | {%- else %}
51 | row[j] = B[(jst + j) * {{K}} + k];
52 | {%- endif %}
53 | }
54 | }
55 |
56 | #pragma unroll
57 | for(int i = 0; i < rpt; i++) {
58 | #pragma unroll
59 | for(int j = 0; j < cpt; j++) {
60 | if(ist + i < {{M}} && jst + j < {{N}}) {
61 | tile[i][j] += col[i] * row[j];
62 | }
63 | }
64 | }
65 | }
66 |
67 | {%- if accum %}
68 | {%- set op = "+=" %}
69 | {%- else %}
70 | {%- set op = "=" %}
71 | {%- endif %}
72 |
73 | // Store the output
74 | #pragma unroll
75 | for(int i = 0; i < rpt; i++) {
76 | for(int j = 0; j < cpt; j++) {
77 | if(i + ist < {{M}} && j + jst < {{N}}) {
78 | {%- if OUTPUT_RMAJOR %}
79 | C[(i + ist) * {{N}} + j + jst] {{op}} tile[i][j];
80 | {%- else %}
81 | C[(j + jst) * {{M}} + i + ist] {{op}} tile[i][j];
82 | {%- endif %}
83 | }
84 | }
85 | }
86 | }
87 |
88 | {%- endmacro %}
--------------------------------------------------------------------------------
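Setting the tiling aside, the function emitted by `generate_matmul` computes a warp-cooperative `C (+)= A @ B` for an M x K by K x N product; the layout flags only change how the three flat buffers are indexed. A NumPy reference for the intended result, as a sketch under those layout assumptions (not part of the library):

    import numpy as np

    def wmm_reference(A, B, C, M, N, K, A_CMAJOR=True, B_RMAJOR=True,
                      OUTPUT_RMAJOR=True, accum=True):
        # A, B, C are flat 1-D buffers, like the pointers the device function receives.
        A2 = A.reshape(K, M).T if A_CMAJOR else A.reshape(M, K)
        B2 = B.reshape(K, N) if B_RMAJOR else B.reshape(N, K).T
        C2 = C.reshape(M, N) if OUTPUT_RMAJOR else C.reshape(N, M).T
        if accum:
            C2 += A2 @ B2    # accum=True emits "+="
        else:
            C2[:] = A2 @ B2  # accum=False emits "="
        return C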
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "openequivariance"
7 | version = "0.1.0"
8 | authors = [
9 | { name="Austin Glover" },
10 | { name="Vivek Bharadwaj" },
11 | { name="Aydin Buluc" },
12 | { name="James Demmel" }
13 | ]
14 | description = "A fast GPU JIT kernel generator for the Clebsch-Gordan Tensor Product"
15 | requires-python = ">=3.10"
16 | dependencies = [
17 | "ninja",
18 | "jinja2",
19 | "numpy",
20 | "torch",
21 | ]
22 |
23 | [project.optional-dependencies]
24 | bench = [
25 | "matplotlib",
26 | "tqdm",
27 | "e3nn",
28 | "cuequivariance",
29 | "cuequivariance-torch",
30 | "cuequivariance-ops-torch-cu12",
31 | ]
32 |
33 | dev = [
34 | "e3nn",
35 | "pre-commit",
36 | "ruff",
37 | "pytest",
38 | "pytest-check",
39 | "torch_geometric",
40 | "cmake",
41 | "furo",
42 | "sphinx",
43 | "sphinx-autobuild"
44 | ]
45 |
46 | [tool.setuptools.packages.find]
47 | include = ["openequivariance*"]
48 |
49 | [tool.pytest.ini_options]
50 | addopts = [
51 | "--import-mode=importlib",
52 | ]
53 |
54 | [tool.ruff]
55 | lint.ignore = ["E741"]
--------------------------------------------------------------------------------
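The two optional-dependency tables map directly to pip extras: `pip install openequivariance` gives the core package, `pip install "openequivariance[bench]"` adds the benchmarking stack (e3nn, cuEquivariance, plotting), and `pip install -e ".[dev]"` from a source checkout adds the linting, testing, and documentation tooling.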
/tests/batch_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pytest_check import check
3 |
4 | import numpy as np
5 | import openequivariance as oeq
6 | from openequivariance.implementations.TensorProduct import TensorProduct
7 | from openequivariance.benchmark.correctness_utils import (
8 | correctness_forward,
9 | correctness_backward,
10 | correctness_double_backward,
11 | )
12 | from itertools import product
13 |
14 |
15 | class TPCorrectness:
16 | def thresh(self, direction):
17 | return {"fwd": 1e-5, "bwd": 3e-4, "double_bwd": 3e-4}[direction]
18 |
19 | def check_result(self, result, fieldname):
20 | with check:
21 | error = result[fieldname]["diff_Linf_norm"]
22 | thresh = result["thresh"]
23 | assert result[fieldname]["pass"], (
24 | f"{fieldname} observed error={error:.5f} >= {thresh}"
25 | )
26 |
27 | @pytest.fixture(params=[np.float32, np.float64], ids=["F32", "F64"], scope="class")
28 | def dtype(self, request):
29 | return request.param
30 |
31 | @pytest.fixture(scope="class")
32 | def tp_and_problem(self, problem):
33 | tp = TensorProduct(problem)
34 | return tp, problem
35 |
36 | def test_tp_fwd(self, tp_and_problem):
37 | tp, problem = tp_and_problem
38 | result = correctness_forward(
39 | problem=problem,
40 | test_implementation=tp,
41 | reference_implementation=None,
42 | batch_size=1000,
43 | correctness_threshold=self.thresh("fwd"),
44 | prng_seed=12345,
45 | )
46 |
47 | self.check_result(result, "output")
48 |
49 | def test_tp_bwd(self, tp_and_problem):
50 | tp, problem = tp_and_problem
51 | result = correctness_backward(
52 | problem=problem,
53 | test_implementation=tp,
54 | reference_implementation=None,
55 | batch_size=1000,
56 | correctness_threshold=self.thresh("bwd"),
57 | prng_seed=12345,
58 | )
59 |
60 | self.check_result(result, "weight_grad")
61 | self.check_result(result, "in1_grad")
62 | self.check_result(result, "in2_grad")
63 |
64 | def test_tp_double_bwd(self, tp_and_problem):
65 | tp, problem = tp_and_problem
66 | result = correctness_double_backward(
67 | problem=problem,
68 | test_implementation=tp,
69 | reference_implementation=None,
70 | batch_size=200,
71 | correctness_threshold=self.thresh("double_bwd"),
72 | prng_seed=12345,
73 | )
74 |
75 | self.check_result(result, "output_double_grad")
76 | self.check_result(result, "in1_grad")
77 | self.check_result(result, "in2_grad")
78 | self.check_result(result, "weights_grad")
79 |
80 |
81 | class TestProductionModels(TPCorrectness):
82 | from openequivariance.benchmark.problems import (
83 | e3nn_torch_tetris_poly_problems,
84 | diffdock_problems,
85 | mace_problems,
86 | nequip_problems,
87 | )
88 |
89 | production_model_tpps = (
90 | mace_problems()
91 | + nequip_problems()
92 | + e3nn_torch_tetris_poly_problems()
93 | + diffdock_problems()
94 | )
95 |
96 | @pytest.fixture(params=production_model_tpps, ids=lambda x: x.label, scope="class")
97 | def problem(self, request, dtype):
98 | request.param.irrep_dtype, request.param.weight_dtype = dtype, dtype
99 | return request.param
100 |
101 |
102 | class TestUVUSingleIrrep(TPCorrectness):
103 | muls = [
104 | (1, 1, 1),
105 | (2, 1, 2),
106 | (4, 1, 4),
107 | (8, 1, 8),
108 | (16, 1, 16),
109 | (32, 1, 32),
110 | (5, 1, 5),
111 | (13, 1, 13),
112 | (19, 1, 19),
113 | (33, 1, 33),
114 | (49, 1, 49),
115 | (50, 1, 50),
116 | (123, 1, 123),
117 | (128, 1, 128),
118 | (256, 1, 256),
119 | (512, 1, 512),
120 | (1, 2, 1),
121 | (1, 4, 1),
122 | (1, 16, 1),
123 | (1, 32, 1),
124 | (16, 3, 16),
125 | (16, 9, 16),
126 | (24, 24, 24),
127 | (32, 32, 32),
128 | ]
129 |
130 | irs = [
131 | (0, 0, 0),
132 | (1, 1, 1),
133 | (1, 0, 1),
134 | (1, 2, 1),
135 | (2, 0, 2),
136 | (2, 2, 4),
137 | (2, 2, 2),
138 | (5, 3, 5),
139 | (7, 2, 5),
140 | ]
141 |
142 | def id_func(m, i):
143 | return f"{m[0]}x{i[0]}e__x__{m[1]}x{i[1]}e---{m[2]}x{i[2]}e"
144 |
145 | @pytest.fixture(
146 | params=product(muls, irs),
147 | ids=lambda x: TestUVUSingleIrrep.id_func(x[0], x[1]),
148 | scope="class",
149 | )
150 | def problem(self, request, dtype):
151 | m, i = request.param[0], request.param[1]
152 | instructions = [(0, 0, 0, "uvu", True)]
153 | return oeq.TPProblem(
154 | f"{m[0]}x{i[0]}e",
155 | f"{m[1]}x{i[1]}e",
156 | f"{m[2]}x{i[2]}e",
157 | instructions,
158 | shared_weights=False,
159 | internal_weights=False,
160 | irrep_dtype=dtype,
161 | weight_dtype=dtype,
162 | )
163 |
164 |
165 | class TestUVWSingleIrrep(TPCorrectness):
166 | muls = [
167 | (1, 1, 1),
168 | (2, 1, 2),
169 | (4, 1, 4),
170 | (8, 1, 8),
171 | (16, 1, 16),
172 | (32, 1, 32),
173 | (5, 1, 5),
174 | (13, 1, 13),
175 | (19, 1, 19),
176 | (33, 1, 33),
177 | (49, 1, 49),
178 | (50, 1, 50),
179 | (64, 1, 64),
180 | (1, 2, 1),
181 | (1, 4, 1),
182 | (1, 16, 1),
183 | (1, 32, 1),
184 | (16, 3, 16),
185 | (16, 9, 16),
186 | (24, 24, 24),
187 | (32, 32, 32),
188 | ]
189 |
190 | irs = [
191 | (0, 0, 0),
192 | (1, 1, 1),
193 | (1, 0, 1),
194 | (1, 2, 1),
195 | (2, 0, 2),
196 | (2, 2, 4),
197 | (2, 2, 2),
198 | (5, 3, 5),
199 | (7, 2, 5),
200 | ]
201 |
202 | def id_func(m, i):
203 | return f"{m[0]}x{i[0]}e__x__{m[1]}x{i[1]}e---{m[2]}x{i[2]}e"
204 |
205 | @pytest.fixture(
206 | params=product(muls, irs),
207 | ids=lambda x: TestUVWSingleIrrep.id_func(x[0], x[1]),
208 | scope="class",
209 | )
210 | def problem(self, request, dtype):
211 | m, i = request.param[0], request.param[1]
212 | instructions = [(0, 0, 0, "uvw", True)]
213 | return oeq.TPProblem(
214 | f"{m[0]}x{i[0]}e",
215 | f"{m[1]}x{i[1]}e",
216 | f"{m[2]}x{i[2]}e",
217 | instructions,
218 | shared_weights=False,
219 | internal_weights=False,
220 | irrep_dtype=dtype,
221 | weight_dtype=dtype,
222 | )
223 |
224 |
225 | class TestSharedWeights(TPCorrectness):
226 | from openequivariance.benchmark.problems import (
227 | mace_problems,
228 | diffdock_problems,
229 | )
230 |
231 | problems = [mace_problems()[0], diffdock_problems()[0]]
232 |
233 | def thresh(self, direction):
234 | return {
235 | "fwd": 1e-5,
236 | "bwd": 5e-4, # Expect higher errors for shared weights
237 | "double_bwd": 5e-4,
238 | }[direction]
239 |
240 | @pytest.fixture(params=problems, ids=lambda x: x.label, scope="class")
241 | def problem(self, request, dtype):
242 | problem = request.param
243 | problem.irrep_dtype, problem.weight_dtype = dtype, dtype
244 | problem.shared_weights = True
245 | return problem
246 |
--------------------------------------------------------------------------------
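Adding a correctness case follows the same pattern as the classes above: subclass `TPCorrectness` and supply a class-scoped `problem` fixture; the base class contributes the dtype parametrization and the fwd/bwd/double-bwd tests. A hypothetical minimal example, assuming this file's existing imports:

    class TestMySingleProblem(TPCorrectness):
        @pytest.fixture(scope="class")
        def problem(self, dtype):
            # One 'uvw' interaction on small irreps; dtype is supplied by the
            # base-class fixture (float32 or float64).
            return oeq.TPProblem(
                "8x1e",
                "1x1e",
                "8x1e",
                [(0, 0, 0, "uvw", True)],
                shared_weights=False,
                internal_weights=False,
                irrep_dtype=dtype,
                weight_dtype=dtype,
            )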
/tests/conv_test.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import tempfile
3 | import urllib
4 | from pytest_check import check
5 |
6 | import numpy as np
7 | import openequivariance as oeq
8 | from openequivariance.benchmark.ConvBenchmarkSuite import load_graph
9 | from itertools import product
10 |
11 |
12 | class ConvCorrectness:
13 | def thresh(self, direction):
14 | return {"fwd": 3e-4, "bwd": 3e-4, "double_bwd": 3e-4}[direction]
15 |
16 | def check_result(self, result, fieldname):
17 | with check:
18 | error = result[fieldname]["diff_Linf_norm"]
19 | thresh = result["thresh"]
20 | assert result[fieldname]["pass"], (
21 | f"{fieldname} observed error={error:.5f} >= {thresh}"
22 | )
23 |
24 | @pytest.fixture(params=[np.float32, np.float64], ids=["F32", "F64"], scope="class")
25 | def dtype(self, request):
26 | return request.param
27 |
28 | @pytest.fixture(params=["1drf_radius3.5.pickle"], ids=["1drf"], scope="class")
29 | def graph(self, request):
30 | download_prefix = (
31 | "https://portal.nersc.gov/project/m1982/equivariant_nn_graphs/"
32 | )
33 | filename = request.param
34 |
35 | graph = None
36 | with tempfile.NamedTemporaryFile() as temp_file:
37 | urllib.request.urlretrieve(download_prefix + filename, temp_file.name)
38 | graph = load_graph(temp_file.name)
39 |
40 | # graph = load_graph("data/1drf_radius3.5.pickle")
41 | return graph
42 |
43 | @pytest.fixture(params=["atomic", "deterministic", "kahan"], scope="class")
44 | def conv_object(self, request, problem):
45 | if request.param == "atomic":
46 | return oeq.TensorProductConv(problem, deterministic=False)
47 | elif request.param == "deterministic":
48 | if not problem.shared_weights:
49 | return oeq.TensorProductConv(problem, deterministic=True)
50 | else:
51 | pytest.skip("Shared weights not supported with deterministic")
52 | elif request.param == "kahan":
53 | if problem.irrep_dtype == np.float32:
54 | if not problem.shared_weights:
55 | return oeq.TensorProductConv(
56 | problem, deterministic=True, kahan=True
57 | )
58 | else:
59 | pytest.skip("Shared weights not supported with kahan")
60 | else:
61 | pytest.skip("Only Float32 supported with kahan")
62 |
63 | def test_tp_fwd(self, conv_object, graph):
64 | if conv_object is None:
65 | pytest.skip("'conv_object' fixture returned None, skipping")
66 |
67 | result = conv_object.test_correctness_forward(
68 | graph,
69 | thresh=self.thresh("fwd"),
70 | prng_seed=12345,
71 | reference_implementation=None,
72 | )
73 |
74 | self.check_result(result, "output")
75 |
76 | def test_tp_bwd(self, conv_object, graph):
77 | if conv_object is None:
78 | pytest.skip("'conv_object' fixture returned None, skipping")
79 |
80 | result = conv_object.test_correctness_backward(
81 | graph,
82 | thresh=self.thresh("bwd"),
83 | prng_seed=12345,
84 | reference_implementation=None,
85 | )
86 |
87 | self.check_result(result, "weight_grad")
88 | self.check_result(result, "in1_grad")
89 | self.check_result(result, "in2_grad")
90 |
91 | def test_tp_double_bwd(self, conv_object, graph):
92 | if conv_object is None:
93 | pytest.skip("'conv_object' fixture returned None, skipping")
94 |
95 | result = conv_object.test_correctness_double_backward(
96 | graph,
97 | thresh=self.thresh("double_bwd"),
98 | prng_seed=12345,
99 | reference_implementation=None,
100 | )
101 |
102 | self.check_result(result, "output_grad")
103 | self.check_result(result, "in1_grad")
104 | self.check_result(result, "in2_grad")
105 | self.check_result(result, "weights_grad")
106 |
107 |
108 | class TestProductionModels(ConvCorrectness):
109 | from openequivariance.benchmark.problems import (
110 | mace_problems,
111 | diffdock_problems,
112 | )
113 |
114 | production_model_tpps = mace_problems() + diffdock_problems()
115 |
116 | @pytest.fixture(params=production_model_tpps, ids=lambda x: x.label, scope="class")
117 | def problem(self, request, dtype):
118 | request.param.irrep_dtype, request.param.weight_dtype = dtype, dtype
119 | return request.param
120 |
121 |
122 | class TestUVUSingleIrrep(ConvCorrectness):
123 | muls = [
124 | (1, 1, 1),
125 | (8, 1, 8),
126 | (16, 1, 16),
127 | (32, 1, 32),
128 | (5, 1, 5),
129 | (13, 1, 13),
130 | (19, 1, 19),
131 | (33, 1, 33),
132 | (49, 1, 49),
133 | (128, 1, 128),
134 | (1, 2, 1),
135 | (1, 16, 1),
136 | (1, 32, 1),
137 | (16, 3, 16),
138 | ]
139 |
140 | irs = [(0, 0, 0), (1, 1, 1), (1, 0, 1), (1, 2, 1), (2, 0, 2), (5, 3, 5), (7, 2, 5)]
141 |
142 | def id_func(m, i):
143 | return f"{m[0]}x{i[0]}e__x__{m[1]}x{i[1]}e---{m[2]}x{i[2]}e"
144 |
145 | @pytest.fixture(
146 | params=product(muls, irs),
147 | ids=lambda x: TestUVUSingleIrrep.id_func(x[0], x[1]),
148 | scope="class",
149 | )
150 | def problem(self, request, dtype):
151 | m, i = request.param[0], request.param[1]
152 | instructions = [(0, 0, 0, "uvu", True)]
153 | return oeq.TPProblem(
154 | f"{m[0]}x{i[0]}e",
155 | f"{m[1]}x{i[1]}e",
156 | f"{m[2]}x{i[2]}e",
157 | instructions,
158 | shared_weights=False,
159 | internal_weights=False,
160 | irrep_dtype=dtype,
161 | weight_dtype=dtype,
162 | )
163 |
164 |
165 | class TestUVWSingleIrrep(ConvCorrectness):
166 | muls = [
167 | (1, 1, 1),
168 | (4, 1, 4),
169 | (8, 1, 8),
170 | (16, 1, 16),
171 | (32, 1, 32),
172 | (5, 1, 5),
173 | (13, 1, 13),
174 | (33, 1, 33),
175 | (49, 1, 49),
176 | (64, 1, 64),
177 | (1, 2, 1),
178 | (1, 4, 1),
179 | (1, 16, 1),
180 | (1, 32, 1),
181 | (16, 3, 16),
182 | ]
183 |
184 | irs = [(0, 0, 0), (1, 1, 1), (1, 0, 1), (1, 2, 1), (5, 3, 5), (7, 2, 5)]
185 |
186 | def id_func(m, i):
187 | return f"{m[0]}x{i[0]}e__x__{m[1]}x{i[1]}e---{m[2]}x{i[2]}e"
188 |
189 | @pytest.fixture(
190 | params=product(muls, irs),
191 | ids=lambda x: TestUVWSingleIrrep.id_func(x[0], x[1]),
192 | scope="class",
193 | )
194 | def problem(self, request, dtype):
195 | m, i = request.param[0], request.param[1]
196 | instructions = [(0, 0, 0, "uvw", True)]
197 | return oeq.TPProblem(
198 | f"{m[0]}x{i[0]}e",
199 | f"{m[1]}x{i[1]}e",
200 | f"{m[2]}x{i[2]}e",
201 | instructions,
202 | shared_weights=False,
203 | internal_weights=False,
204 | irrep_dtype=dtype,
205 | weight_dtype=dtype,
206 | )
207 |
208 |
209 | class TestAtomicSharedWeights(ConvCorrectness):
210 | from openequivariance.benchmark.problems import (
211 | mace_problems,
212 | diffdock_problems,
213 | )
214 |
215 | problems = [mace_problems()[0], diffdock_problems()[0]]
216 |
217 | def thresh(self, direction):
218 | return {
219 | "fwd": 1e-5,
220 | "bwd": 5e-2, # Expect higher errors for shared weights
221 | "double_bwd": 5e-2,
222 | }[direction]
223 |
224 | @pytest.fixture(params=problems, ids=lambda x: x.label, scope="class")
225 | def problem(self, request, dtype):
226 | problem = request.param
227 | problem.irrep_dtype, problem.weight_dtype = dtype, dtype
228 | problem.shared_weights = True
229 | return problem
230 |
231 | @pytest.fixture(scope="class")
232 | def conv_object(self, request, problem):
233 | return oeq.TensorProductConv(problem, deterministic=False)
234 |
--------------------------------------------------------------------------------
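The `conv_object` fixture above encodes which flag combinations are valid: atomic accumulation is unrestricted, the deterministic path requires per-edge (non-shared) weights, and Kahan summation additionally requires float32 irreps. In constructor form (a sketch mirroring the fixture's skip logic):

    conv_atomic = oeq.TensorProductConv(problem, deterministic=False)
    # deterministic=True: only valid when problem.shared_weights is False
    conv_det = oeq.TensorProductConv(problem, deterministic=True)
    # kahan=True: additionally requires problem.irrep_dtype == np.float32
    conv_kahan = oeq.TensorProductConv(problem, deterministic=True, kahan=True)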
/tests/export_test.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | import torch
3 | import pytest
4 | import tempfile
5 | import subprocess
6 | import os
7 | import sys
8 |
9 | import numpy as np
10 | import openequivariance as oeq
11 | from torch_geometric import EdgeIndex
12 | import importlib.resources
13 |
14 | from openequivariance.implementations.E3NNTensorProduct import E3NNTensorProduct
15 |
16 |
17 | @pytest.fixture(scope="session")
18 | def problem_and_irreps():
19 | X_ir, Y_ir, Z_ir = oeq.Irreps("32x5e"), oeq.Irreps("1x3e"), oeq.Irreps("32x5e")
20 | problem = oeq.TPProblem(
21 | X_ir,
22 | Y_ir,
23 | Z_ir,
24 | [(0, 0, 0, "uvu", True)],
25 | shared_weights=False,
26 | internal_weights=False,
27 | irrep_dtype=np.float32,
28 | weight_dtype=np.float32,
29 | )
30 |
31 | gen = torch.Generator(device="cuda")
32 | gen.manual_seed(0)
33 |
34 | return (
35 | problem,
36 | X_ir,
37 | Y_ir,
38 | Z_ir,
39 | )
40 |
41 |
42 | @pytest.fixture(params=["batch", "conv_det", "conv_atomic"], scope="session")
43 | def tp_and_inputs(request, problem_and_irreps):
44 | problem, X_ir, Y_ir, _ = problem_and_irreps
45 | gen = torch.Generator(device="cuda")
46 | gen.manual_seed(0)
47 |
48 | if request.param == "batch":
49 | batch_size = 1000
50 | X = torch.rand(batch_size, X_ir.dim, device="cuda", generator=gen)
51 | Y = torch.rand(batch_size, Y_ir.dim, device="cuda", generator=gen)
52 | W = torch.rand(batch_size, problem.weight_numel, device="cuda", generator=gen)
53 | return oeq.TensorProduct(problem), (X, Y, W)
54 | else:
55 | node_ct, nonzero_ct = 3, 4
56 |
57 | # Receiver, sender indices for message passing GNN
58 | edge_index = EdgeIndex(
59 | [[0, 1, 1, 2], [1, 0, 2, 1]], device="cuda", dtype=torch.long
60 | )
61 |
62 | _, sender_perm = edge_index.sort_by("col")
63 | edge_index, _ = edge_index.sort_by("row")
64 | edge_index = [edge_index[0].detach(), edge_index[1].detach()]
65 |
66 | X = torch.rand(node_ct, X_ir.dim, device="cuda", generator=gen)
67 | Y = torch.rand(nonzero_ct, Y_ir.dim, device="cuda", generator=gen)
68 | W = torch.rand(nonzero_ct, problem.weight_numel, device="cuda", generator=gen)
69 |
70 | if request.param == "conv_atomic":
71 | return oeq.TensorProductConv(problem, torch_op=True, deterministic=False), (
72 | X,
73 | Y,
74 | W,
75 | edge_index[0],
76 | edge_index[1],
77 | )
78 | elif request.param == "conv_det":
79 | return oeq.TensorProductConv(problem, torch_op=True, deterministic=True), (
80 | X,
81 | Y,
82 | W,
83 | edge_index[0],
84 | edge_index[1],
85 | sender_perm,
86 | )
87 |
88 |
89 | def test_jitscript(tp_and_inputs):
90 | tp, inputs = tp_and_inputs
91 | uncompiled_result = tp.forward(*inputs)
92 |
93 | scripted_tp = torch.jit.script(tp)
94 | loaded_tp = None
95 | with tempfile.NamedTemporaryFile(suffix=".pt") as tmp_file:
96 | scripted_tp.save(tmp_file.name)
97 | loaded_tp = torch.jit.load(tmp_file.name)
98 |
99 | compiled_result = loaded_tp(*inputs)
100 | assert torch.allclose(uncompiled_result, compiled_result, atol=1e-5)
101 |
102 |
103 | def test_compile(tp_and_inputs):
104 | tp, inputs = tp_and_inputs
105 | uncompiled_result = tp.forward(*inputs)
106 |
107 | compiled_tp = torch.compile(tp)
108 | compiled_result = compiled_tp(*inputs)
109 | assert torch.allclose(uncompiled_result, compiled_result, atol=1e-5)
110 |
111 |
112 | def test_export(tp_and_inputs):
113 | tp, inputs = tp_and_inputs
114 | uncompiled_result = tp.forward(*inputs)
115 |
116 | exported_tp = torch.export.export(tp, args=inputs, strict=False)
117 | exported_result = exported_tp.module()(*inputs)
118 | assert torch.allclose(uncompiled_result, exported_result, atol=1e-5)
119 |
120 |
121 | def test_aoti(tp_and_inputs):
122 | tp, inputs = tp_and_inputs
123 | uncompiled_result = tp.forward(*inputs)
124 |
125 | exported_tp = torch.export.export(tp, args=inputs, strict=False)
126 | aoti_model = None
127 | with tempfile.NamedTemporaryFile(suffix=".pt2") as tmp_file:
128 | try:
129 | output_path = torch._inductor.aoti_compile_and_package(
130 | exported_tp, package_path=tmp_file.name
131 | )
132 | except Exception as e:
133 | err_msg = (
134 | "AOTI compile_and_package failed. NOTE: OpenEquivariance only supports AOTI for "
135 | + "PyTorch version >= 2.8.0.dev20250410+cu126 due to incomplete TorchBind support "
136 | + "in prior versions. "
137 | + f"{e}"
138 | )
139 | assert False, err_msg
140 |
141 | aoti_model = torch._inductor.aoti_load_package(output_path)
142 |
143 | aoti_result = aoti_model(*inputs)
144 | assert torch.allclose(uncompiled_result, aoti_result, atol=1e-5)
145 |
146 |
147 | def test_jitscript_cpp_interface(problem_and_irreps):
148 | problem, X_ir, Y_ir, _ = problem_and_irreps
149 | cmake_prefix_path = torch.utils.cmake_prefix_path
150 | torch_ext_so_path = oeq.torch_ext_so_path()
151 |
152 | oeq_tp = oeq.TensorProduct(problem).to("cuda")
153 | scripted_oeq = torch.jit.script(oeq_tp)
154 |
155 | e3nn_tp = E3NNTensorProduct(problem).e3nn_tp.to("cuda")
156 | scripted_e3nn = torch.jit.script(e3nn_tp)
157 |
158 | batch_size = 1000
159 |
160 | with (
161 | tempfile.TemporaryDirectory() as tmpdir,
162 | tempfile.NamedTemporaryFile(suffix=".pt") as oeq_file,
163 | tempfile.NamedTemporaryFile(suffix=".pt") as e3nn_file,
164 | ):
165 | scripted_oeq.save(oeq_file.name)
166 | scripted_e3nn.save(e3nn_file.name)
167 |
168 | test_path = importlib.resources.files("openequivariance") / "extension" / "test"
169 | build_dir = os.path.join(tmpdir, "build")
170 | os.makedirs(build_dir, exist_ok=True)
171 |
172 | for item in test_path.iterdir():
173 | shutil.copy(item, tmpdir)
174 |
175 | try:
176 | subprocess.run(
177 | [
178 | "cmake",
179 | "..",
180 | "-DCMAKE_BUILD_TYPE=Release",
181 | "-DCMAKE_PREFIX_PATH=" + cmake_prefix_path,
182 | "-DOEQ_EXTLIB=" + torch_ext_so_path,
183 | ],
184 | cwd=build_dir,
185 | check=True,
186 | stdout=subprocess.PIPE,
187 | stderr=subprocess.PIPE,
188 | )
189 |
190 | subprocess.run(
191 | ["make"],
192 | cwd=build_dir,
193 | check=True,
194 | stdout=subprocess.PIPE,
195 | stderr=subprocess.PIPE,
196 | )
197 |
198 | subprocess.run(
199 | [
200 | "./load_jitscript",
201 | e3nn_file.name,
202 | oeq_file.name,
203 | str(X_ir.dim),
204 | str(Y_ir.dim),
205 | str(problem.weight_numel),
206 | str(batch_size),
207 | ],
208 | cwd=build_dir,
209 | check=True,
210 | stdout=subprocess.PIPE,
211 | stderr=subprocess.PIPE,
212 | )
213 | except subprocess.CalledProcessError as e:
214 | print(e.stdout.decode(), file=sys.stderr)
215 | print(e.stderr.decode(), file=sys.stderr)
216 | assert False
217 |
--------------------------------------------------------------------------------
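The AOTI package produced in `test_aoti` is a standalone artifact; loading it back needs only the inductor runtime, not the original module. A hypothetical consumer (the `model.pt2` path and the `X, Y, W` inputs are assumed):

    import torch

    # Load the ahead-of-time compiled package and call it like a module.
    model = torch._inductor.aoti_load_package("model.pt2")
    out = model(X, Y, W)  # same positional inputs as the exported TensorProduct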
/tests/import_test.py:
--------------------------------------------------------------------------------
1 | def test_import():
2 | import openequivariance
3 |
4 | assert openequivariance.__version__ is not None
5 | assert openequivariance.__version__ != "0.0.0"
6 |
7 |
8 | def test_tutorial():
9 | import torch
10 | import e3nn.o3 as o3
11 |
12 | gen = torch.Generator(device="cuda")
13 |
14 | batch_size = 1000
15 | X_ir, Y_ir, Z_ir = o3.Irreps("1x2e"), o3.Irreps("1x3e"), o3.Irreps("1x2e")
16 | X = torch.rand(batch_size, X_ir.dim, device="cuda", generator=gen)
17 | Y = torch.rand(batch_size, Y_ir.dim, device="cuda", generator=gen)
18 |
19 | instructions = [(0, 0, 0, "uvu", True)]
20 |
21 | tp_e3nn = o3.TensorProduct(
22 | X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False
23 | ).to("cuda")
24 | W = torch.rand(batch_size, tp_e3nn.weight_numel, device="cuda", generator=gen)
25 |
26 | Z = tp_e3nn(X, Y, W)
27 | print(torch.norm(Z))
28 | # ===============================
29 |
30 | # ===============================
31 | import openequivariance as oeq
32 |
33 | problem = oeq.TPProblem(
34 | X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False
35 | )
36 | tp_fast = oeq.TensorProduct(problem, torch_op=True)
37 |
38 | Z = tp_fast(X, Y, W) # Reuse X, Y, W from earlier
39 | print(torch.norm(Z))
40 | # ===============================
41 |
42 | # Graph Convolution
43 | # ===============================
44 | from torch_geometric import EdgeIndex
45 |
46 | node_ct, nonzero_ct = 3, 4
47 |
48 | # Receiver, sender indices for message passing GNN
49 | edge_index = EdgeIndex(
50 | [
51 | [0, 1, 1, 2], # Receiver
52 | [1, 0, 2, 1], # Sender
53 | ],
54 | device="cuda",
55 | dtype=torch.long,
56 | )
57 |
58 | X = torch.rand(node_ct, X_ir.dim, device="cuda", generator=gen)
59 | Y = torch.rand(nonzero_ct, Y_ir.dim, device="cuda", generator=gen)
60 | W = torch.rand(nonzero_ct, problem.weight_numel, device="cuda", generator=gen)
61 |
62 | tp_conv = oeq.TensorProductConv(
63 | problem, torch_op=True, deterministic=False
64 | ) # Reuse problem from earlier
65 | Z = tp_conv.forward(
66 | X, Y, W, edge_index[0], edge_index[1]
67 | ) # Z has shape [node_ct, Z_ir.dim]
68 | print(torch.norm(Z))
69 | # ===============================
70 |
71 | # ===============================
72 | _, sender_perm = edge_index.sort_by("col") # Sort by sender index
73 | edge_index, receiver_perm = edge_index.sort_by("row") # Sort by receiver index
74 |
75 | # Now we can use the faster deterministic algorithm
76 | tp_conv = oeq.TensorProductConv(problem, torch_op=True, deterministic=True)
77 | Z = tp_conv.forward(
78 | X, Y[receiver_perm], W[receiver_perm], edge_index[0], edge_index[1], sender_perm
79 | )
80 | print(torch.norm(Z))
81 | # ===============================
82 | assert True
83 |
--------------------------------------------------------------------------------
/tests/multidevice_test.py:
--------------------------------------------------------------------------------
1 | import textwrap
2 | import torch
3 | import subprocess
4 | import os
5 |
6 |
7 | def test_multidevice():
8 | result = subprocess.run(
9 | [
10 | "python",
11 | "-m",
12 | "torch.distributed.run",
13 | "--standalone",
14 | "--nnodes=1",
15 | "--nproc-per-node=gpu",
16 | __file__,
17 | ],
18 | capture_output=True,
19 | check=False,
20 | )
21 |
22 | if result.returncode != 0:
23 | error_string = f"""
24 | Invocation: {" ".join(result.args)}
25 | Test failed with return code {result.returncode}.
26 | \nOutput:\n\n{result.stdout.decode()}
27 | \nError:\n\n{result.stderr.decode()}
28 | """
29 | assert False, textwrap.dedent(error_string)
30 |
31 | assert True
32 |
33 |
34 | if __name__ == "__main__":
35 | import openequivariance as oeq
36 |
37 | # Use MACE-large to test >64KB shared memory allocation
38 | from openequivariance.benchmark.problems import mace_problems
39 |
40 | problem = mace_problems()[0]
41 |
42 | local_rank = int(os.environ["LOCAL_RANK"])
43 | device = f"cuda:{local_rank}"
44 | torch.set_default_device(device)
45 |
46 | X_ir, Y_ir, Z_ir = problem.irreps_in1, problem.irreps_in2, problem.irreps_out
47 | tp = oeq.TensorProduct(problem)
48 |
49 | batch_size = 1000
50 | gen = torch.Generator(device=device)
51 | gen.manual_seed(0)
52 | X = torch.rand(batch_size, X_ir.dim, device=device, generator=gen)
53 | Y = torch.rand(batch_size, Y_ir.dim, device=device, generator=gen)
54 | W = torch.rand(batch_size, problem.weight_numel, device=device, generator=gen)
55 |
56 | with torch.cuda.device(device):
57 | result = tp.forward(X, Y, W)
58 |
--------------------------------------------------------------------------------