├── docs ├── source │ ├── modules │ │ ├── ops.rst │ │ ├── nn.rst │ │ ├── utils.rst │ │ ├── irreps.rst │ │ ├── structs.rst │ │ ├── constants.rst │ │ └── transforms.rst │ ├── _static │ │ ├── css │ │ │ └── custom.css │ │ ├── logo.png │ │ ├── logo_.png │ │ └── logo_wide.png │ ├── api.rst │ ├── get_started │ │ └── installation.rst │ ├── index.rst │ └── conf.py ├── build │ ├── html │ │ ├── _static │ │ │ ├── css │ │ │ │ ├── custom.css │ │ │ │ ├── fonts │ │ │ │ │ ├── lato-bold.woff │ │ │ │ │ ├── lato-bold.woff2 │ │ │ │ │ ├── lato-normal.woff │ │ │ │ │ ├── lato-normal.woff2 │ │ │ │ │ ├── Roboto-Slab-Bold.woff │ │ │ │ │ ├── Roboto-Slab-Bold.woff2 │ │ │ │ │ ├── fontawesome-webfont.eot │ │ │ │ │ ├── fontawesome-webfont.ttf │ │ │ │ │ ├── lato-bold-italic.woff │ │ │ │ │ ├── lato-bold-italic.woff2 │ │ │ │ │ ├── lato-normal-italic.woff │ │ │ │ │ ├── Roboto-Slab-Regular.woff │ │ │ │ │ ├── Roboto-Slab-Regular.woff2 │ │ │ │ │ ├── fontawesome-webfont.woff │ │ │ │ │ ├── fontawesome-webfont.woff2 │ │ │ │ │ └── lato-normal-italic.woff2 │ │ │ │ └── badge_only.css │ │ │ ├── custom.css │ │ │ ├── file.png │ │ │ ├── logo.png │ │ │ ├── plus.png │ │ │ ├── logo_.png │ │ │ ├── minus.png │ │ │ ├── logo_wide.png │ │ │ ├── fonts │ │ │ │ ├── Lato │ │ │ │ │ ├── lato-bold.eot │ │ │ │ │ ├── lato-bold.ttf │ │ │ │ │ ├── lato-bold.woff │ │ │ │ │ ├── lato-bold.woff2 │ │ │ │ │ ├── lato-italic.eot │ │ │ │ │ ├── lato-italic.ttf │ │ │ │ │ ├── lato-italic.woff │ │ │ │ │ ├── lato-italic.woff2 │ │ │ │ │ ├── lato-regular.eot │ │ │ │ │ ├── lato-regular.ttf │ │ │ │ │ ├── lato-regular.woff │ │ │ │ │ ├── lato-bolditalic.eot │ │ │ │ │ ├── lato-bolditalic.ttf │ │ │ │ │ ├── lato-regular.woff2 │ │ │ │ │ ├── lato-bolditalic.woff │ │ │ │ │ └── lato-bolditalic.woff2 │ │ │ │ └── RobotoSlab │ │ │ │ │ ├── roboto-slab-v7-bold.eot │ │ │ │ │ ├── roboto-slab-v7-bold.ttf │ │ │ │ │ ├── roboto-slab-v7-bold.woff │ │ │ │ │ ├── roboto-slab-v7-bold.woff2 │ │ │ │ │ ├── roboto-slab-v7-regular.eot │ │ │ │ │ ├── roboto-slab-v7-regular.ttf │ │ 
│ │ │ ├── roboto-slab-v7-regular.woff │ │ │ │ │ └── roboto-slab-v7-regular.woff2 │ │ │ ├── documentation_options.js │ │ │ ├── github-banner.svg │ │ │ └── js │ │ │ │ └── badge_only.js │ │ ├── _sources │ │ │ ├── modules │ │ │ │ ├── ops.rst.txt │ │ │ │ ├── modules.rst.txt │ │ │ │ ├── nn.rst.txt │ │ │ │ ├── utils.rst.txt │ │ │ │ ├── irreps.rst.txt │ │ │ │ ├── structs.rst.txt │ │ │ │ ├── constants.rst.txt │ │ │ │ ├── transforms.rst.txt │ │ │ │ ├── equitorch.nn.init.rst.txt │ │ │ │ ├── equitorch.nn.norm.rst.txt │ │ │ │ ├── equitorch.nn.others.rst.txt │ │ │ │ ├── equitorch.nn.angular.rst.txt │ │ │ │ ├── equitorch.nn.cutoffs.rst.txt │ │ │ │ ├── equitorch.nn.dropout.rst.txt │ │ │ │ ├── equitorch.nn.linears.rst.txt │ │ │ │ ├── equitorch.nn.radials.rst.txt │ │ │ │ ├── equitorch.nn.wigner_d.rst.txt │ │ │ │ ├── equitorch.nn.rotations.rst.txt │ │ │ │ ├── equitorch.nn.activations.rst.txt │ │ │ │ ├── equitorch.nn.sphericals.rst.txt │ │ │ │ ├── equitorch.nn.normalization.rst.txt │ │ │ │ ├── equitorch.ops.kernel_dense.rst.txt │ │ │ │ ├── equitorch.ops.kernel_utils.rst.txt │ │ │ │ ├── equitorch.nn.functional.norm.rst.txt │ │ │ │ ├── equitorch.nn.tensor_products.rst.txt │ │ │ │ ├── equitorch.utils.rst.txt │ │ │ │ ├── equitorch.irreps.rst.txt │ │ │ │ ├── equitorch.nn.functional.angular.rst.txt │ │ │ │ ├── equitorch.nn.functional.cutoffs.rst.txt │ │ │ │ ├── equitorch.nn.functional.dropout.rst.txt │ │ │ │ ├── equitorch.nn.functional.linears.rst.txt │ │ │ │ ├── equitorch.structs.rst.txt │ │ │ │ ├── equitorch.nn.functional.wigner_d.rst.txt │ │ │ │ ├── equitorch.constants.rst.txt │ │ │ │ ├── equitorch.nn.functional.rotations.rst.txt │ │ │ │ ├── equitorch.nn.functional.sphericals.rst.txt │ │ │ │ ├── equitorch.ops.indexed_product_op.rst.txt │ │ │ │ ├── equitorch.ops.product_segment_op.rst.txt │ │ │ │ ├── equitorch.ops.spherical_harmonics.rst.txt │ │ │ │ ├── equitorch.nn.functional.activations.rst.txt │ │ │ │ ├── equitorch.transforms.rst.txt │ │ │ │ ├── 
equitorch.nn.functional.sparse_scale.rst.txt │ │ │ │ ├── equitorch.nn.functional.normalization.rst.txt │ │ │ │ ├── equitorch.nn.functional.sparse_product.rst.txt │ │ │ │ ├── equitorch.ops.batched_sparse_dense_op.rst.txt │ │ │ │ ├── equitorch.nn.functional.tensor_products.rst.txt │ │ │ │ ├── equitorch.ops.indexed_scale_segment_op.rst.txt │ │ │ │ ├── equitorch.ops.indexed_product_segment_op.rst.txt │ │ │ │ ├── equitorch.ops.indexed_product_scale_segment_op.rst.txt │ │ │ │ ├── equitorch.ops.accumulated_indexed_product_segment_op.rst.txt │ │ │ │ ├── equitorch.rst.txt │ │ │ │ ├── equitorch.ops.rst.txt │ │ │ │ ├── equitorch.nn.rst.txt │ │ │ │ └── equitorch.nn.functional.rst.txt │ │ │ ├── generated │ │ │ │ ├── equitorch.ops.kernel_dense.rst.txt │ │ │ │ ├── equitorch.ops.kernel_utils.rst.txt │ │ │ │ ├── equitorch.nn.functional.cutoffs.rst.txt │ │ │ │ ├── equitorch.nn.functional.norm.rst.txt │ │ │ │ ├── equitorch.ops.spherical_harmonics.rst.txt │ │ │ │ ├── equitorch.nn.functional.sparse_scale.rst.txt │ │ │ │ ├── equitorch.nn.angular.rst.txt │ │ │ │ ├── equitorch.nn.dropout.rst.txt │ │ │ │ ├── equitorch.nn.radials.rst.txt │ │ │ │ ├── equitorch.nn.activations.rst.txt │ │ │ │ ├── equitorch.nn.rotations.rst.txt │ │ │ │ ├── equitorch.nn.others.rst.txt │ │ │ │ ├── equitorch.constants.rst.txt │ │ │ │ ├── equitorch.nn.norm.rst.txt │ │ │ │ ├── equitorch.nn.functional.angular.rst.txt │ │ │ │ ├── equitorch.nn.normalization.rst.txt │ │ │ │ ├── equitorch.nn.functional.dropout.rst.txt │ │ │ │ ├── equitorch.rst.txt │ │ │ │ ├── equitorch.nn.cutoffs.rst.txt │ │ │ │ ├── equitorch.nn.functional.activations.rst.txt │ │ │ │ ├── equitorch.nn.functional.rotations.rst.txt │ │ │ │ ├── equitorch.nn.tensor_products.rst.txt │ │ │ │ ├── equitorch.transforms.rst.txt │ │ │ │ ├── equitorch.nn.init.rst.txt │ │ │ │ ├── equitorch.nn.linears.rst.txt │ │ │ │ ├── equitorch.nn.sphericals.rst.txt │ │ │ │ ├── equitorch.nn.wigner_d.rst.txt │ │ │ │ ├── equitorch.nn.functional.normalization.rst.txt │ │ │ │ ├── 
equitorch.nn.functional.sphericals.rst.txt │ │ │ │ ├── equitorch.nn.functional.wigner_d.rst.txt │ │ │ │ ├── equitorch.ops.indexed_product_op.rst.txt │ │ │ │ ├── equitorch.ops.product_segment_op.rst.txt │ │ │ │ ├── equitorch.ops.indexed_scale_segment_op.rst.txt │ │ │ │ ├── equitorch.ops.indexed_product_segment_op.rst.txt │ │ │ │ ├── equitorch.structs.rst.txt │ │ │ │ ├── equitorch.ops.indexed_product_scale_segment_op.rst.txt │ │ │ │ ├── equitorch.ops.rst.txt │ │ │ │ ├── equitorch.nn.functional.tensor_products.rst.txt │ │ │ │ ├── equitorch.irreps.rst.txt │ │ │ │ ├── equitorch.ops.accumulated_indexed_product_segment_op.rst.txt │ │ │ │ ├── equitorch.nn.functional.sparse_product.rst.txt │ │ │ │ ├── equitorch.nn.functional.linears.rst.txt │ │ │ │ ├── equitorch.utils.rst.txt │ │ │ │ ├── equitorch.ops.batched_sparse_dense_op.rst.txt │ │ │ │ ├── equitorch.nn.rst.txt │ │ │ │ └── equitorch.nn.functional.rst.txt │ │ │ ├── api.rst.txt │ │ │ ├── get_started │ │ │ │ └── installation.rst.txt │ │ │ └── index.rst.txt │ │ ├── objects.inv │ │ ├── _images │ │ │ └── logo_wide.png │ │ ├── .buildinfo │ │ ├── .buildinfo.bak │ │ ├── search.html │ │ └── generated │ │ │ ├── equitorch.nn.functional.cutoffs.html │ │ │ ├── equitorch.nn.radials.html │ │ │ ├── equitorch.nn.dropout.html │ │ │ └── equitorch.nn.rotations.html │ └── doctrees │ │ ├── api.doctree │ │ ├── index.doctree │ │ ├── environment.pickle │ │ ├── modules │ │ ├── nn.doctree │ │ ├── ops.doctree │ │ ├── irreps.doctree │ │ ├── utils.doctree │ │ ├── modules.doctree │ │ ├── structs.doctree │ │ ├── constants.doctree │ │ ├── equitorch.doctree │ │ ├── transforms.doctree │ │ ├── equitorch.nn.doctree │ │ ├── equitorch.ops.doctree │ │ ├── equitorch.irreps.doctree │ │ ├── equitorch.utils.doctree │ │ ├── equitorch.nn.init.doctree │ │ ├── equitorch.nn.norm.doctree │ │ ├── equitorch.structs.doctree │ │ ├── equitorch.constants.doctree │ │ ├── equitorch.nn.angular.doctree │ │ ├── equitorch.nn.cutoffs.doctree │ │ ├── equitorch.nn.dropout.doctree │ │ 
├── equitorch.nn.linears.doctree │ │ ├── equitorch.nn.others.doctree │ │ ├── equitorch.nn.radials.doctree │ │ ├── equitorch.nn.wigner_d.doctree │ │ ├── equitorch.transforms.doctree │ │ ├── equitorch.nn.functional.doctree │ │ ├── equitorch.nn.rotations.doctree │ │ ├── equitorch.nn.sphericals.doctree │ │ ├── equitorch.nn.activations.doctree │ │ ├── equitorch.nn.normalization.doctree │ │ ├── equitorch.ops.kernel_dense.doctree │ │ ├── equitorch.ops.kernel_utils.doctree │ │ ├── equitorch.nn.functional.norm.doctree │ │ ├── equitorch.nn.tensor_products.doctree │ │ ├── equitorch.nn.functional.angular.doctree │ │ ├── equitorch.nn.functional.cutoffs.doctree │ │ ├── equitorch.nn.functional.dropout.doctree │ │ ├── equitorch.nn.functional.linears.doctree │ │ ├── equitorch.nn.functional.rotations.doctree │ │ ├── equitorch.nn.functional.wigner_d.doctree │ │ ├── equitorch.ops.indexed_product_op.doctree │ │ ├── equitorch.ops.product_segment_op.doctree │ │ ├── equitorch.ops.spherical_harmonics.doctree │ │ ├── equitorch.nn.functional.activations.doctree │ │ ├── equitorch.nn.functional.sparse_scale.doctree │ │ ├── equitorch.nn.functional.sphericals.doctree │ │ ├── equitorch.nn.functional.normalization.doctree │ │ ├── equitorch.nn.functional.sparse_product.doctree │ │ ├── equitorch.ops.batched_sparse_dense_op.doctree │ │ ├── equitorch.ops.indexed_scale_segment_op.doctree │ │ ├── equitorch.nn.functional.tensor_products.doctree │ │ ├── equitorch.ops.indexed_product_segment_op.doctree │ │ ├── equitorch.ops.indexed_product_scale_segment_op.doctree │ │ └── equitorch.ops.accumulated_indexed_product_segment_op.doctree │ │ ├── generated │ │ ├── equitorch.doctree │ │ ├── equitorch.nn.doctree │ │ ├── equitorch.ops.doctree │ │ ├── equitorch.irreps.doctree │ │ ├── equitorch.utils.doctree │ │ ├── equitorch.constants.doctree │ │ ├── equitorch.nn.init.doctree │ │ ├── equitorch.nn.norm.doctree │ │ ├── equitorch.nn.others.doctree │ │ ├── equitorch.structs.doctree │ │ ├── equitorch.nn.angular.doctree │ 
│ ├── equitorch.nn.cutoffs.doctree │ │ ├── equitorch.nn.dropout.doctree │ │ ├── equitorch.nn.linears.doctree │ │ ├── equitorch.nn.radials.doctree │ │ ├── equitorch.nn.wigner_d.doctree │ │ ├── equitorch.transforms.doctree │ │ ├── equitorch.nn.activations.doctree │ │ ├── equitorch.nn.functional.doctree │ │ ├── equitorch.nn.rotations.doctree │ │ ├── equitorch.nn.sphericals.doctree │ │ ├── equitorch.nn.normalization.doctree │ │ ├── equitorch.ops.kernel_dense.doctree │ │ ├── equitorch.ops.kernel_utils.doctree │ │ ├── equitorch.nn.functional.norm.doctree │ │ ├── equitorch.nn.tensor_products.doctree │ │ ├── equitorch.nn.functional.angular.doctree │ │ ├── equitorch.nn.functional.cutoffs.doctree │ │ ├── equitorch.nn.functional.dropout.doctree │ │ ├── equitorch.nn.functional.linears.doctree │ │ ├── equitorch.nn.functional.rotations.doctree │ │ ├── equitorch.nn.functional.sphericals.doctree │ │ ├── equitorch.nn.functional.wigner_d.doctree │ │ ├── equitorch.ops.indexed_product_op.doctree │ │ ├── equitorch.ops.product_segment_op.doctree │ │ ├── equitorch.ops.spherical_harmonics.doctree │ │ ├── equitorch.nn.functional.activations.doctree │ │ ├── equitorch.nn.functional.sparse_scale.doctree │ │ ├── equitorch.nn.functional.normalization.doctree │ │ ├── equitorch.nn.functional.sparse_product.doctree │ │ ├── equitorch.nn.functional.tensor_products.doctree │ │ ├── equitorch.ops.batched_sparse_dense_op.doctree │ │ ├── equitorch.ops.indexed_scale_segment_op.doctree │ │ ├── equitorch.ops.indexed_product_segment_op.doctree │ │ ├── equitorch.ops.indexed_product_scale_segment_op.doctree │ │ └── equitorch.ops.accumulated_indexed_product_segment_op.doctree │ │ └── get_started │ │ └── installation.doctree ├── Makefile └── make.bat ├── equitorch ├── ops │ ├── __init__.py │ ├── product_segment_op.py │ ├── indexed_product_op.py │ ├── indexed_product_segment_op.py │ └── indexed_product_scale_segment_op.py ├── __init__.py ├── nn │ ├── functional │ │ ├── sparse_scale.py │ │ ├── angular.py │ │ ├── 
activations.py │ │ ├── cutoffs.py │ │ ├── dropout.py │ │ ├── __init__.py │ │ └── rotations.py │ ├── radials.py │ ├── rotations.py │ ├── __init__.py │ ├── angular.py │ └── dropout.py └── utils │ ├── __init__.py │ └── _random.py ├── img └── logo_wide.png ├── .readthedocs.yaml ├── README.md ├── LICENSE ├── setup.py ├── test ├── test_op │ └── test_indexed_scale_segment.py ├── test_other_modules │ └── test_irreps_split.ipynb ├── test_norm │ ├── test_grad_layer_norm.py │ └── test_grad_batch_norm.py └── test_linear │ └── test_grad_linear.py └── prompts └── README.md /docs/source/modules/ops.rst: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /equitorch/ops/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/build/html/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/source/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/ops.rst.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/build/html/_static/custom.css: -------------------------------------------------------------------------------- 1 | /* This file intentionally left blank. 
*/ 2 | -------------------------------------------------------------------------------- /img/logo_wide.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/img/logo_wide.png -------------------------------------------------------------------------------- /docs/build/html/objects.inv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/objects.inv -------------------------------------------------------------------------------- /docs/source/_static/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/source/_static/logo.png -------------------------------------------------------------------------------- /docs/source/_static/logo_.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/source/_static/logo_.png -------------------------------------------------------------------------------- /docs/build/doctrees/api.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/api.doctree -------------------------------------------------------------------------------- /docs/build/html/_static/file.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/file.png -------------------------------------------------------------------------------- /docs/build/html/_static/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/logo.png 
-------------------------------------------------------------------------------- /docs/build/html/_static/plus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/plus.png -------------------------------------------------------------------------------- /docs/build/doctrees/index.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/index.doctree -------------------------------------------------------------------------------- /docs/build/html/_static/logo_.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/logo_.png -------------------------------------------------------------------------------- /docs/build/html/_static/minus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/minus.png -------------------------------------------------------------------------------- /docs/source/_static/logo_wide.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/source/_static/logo_wide.png -------------------------------------------------------------------------------- /docs/build/html/_images/logo_wide.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_images/logo_wide.png -------------------------------------------------------------------------------- /docs/build/html/_static/logo_wide.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/logo_wide.png -------------------------------------------------------------------------------- /docs/build/doctrees/environment.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/environment.pickle -------------------------------------------------------------------------------- /docs/build/doctrees/modules/nn.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/nn.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/ops.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/ops.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/irreps.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/irreps.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/utils.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/utils.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/modules.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/modules.doctree 
-------------------------------------------------------------------------------- /docs/build/doctrees/modules/structs.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/structs.doctree -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/modules.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch 2 | ========= 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | equitorch 8 | -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/constants.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/constants.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/transforms.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/transforms.doctree -------------------------------------------------------------------------------- 
/docs/build/doctrees/modules/equitorch.nn.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.nn.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.ops.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.ops.doctree -------------------------------------------------------------------------------- /docs/build/html/_static/css/fonts/lato-bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/css/fonts/lato-bold.woff -------------------------------------------------------------------------------- /docs/build/html/_static/css/fonts/lato-bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/css/fonts/lato-bold.woff2 -------------------------------------------------------------------------------- /docs/build/html/_static/fonts/Lato/lato-bold.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/fonts/Lato/lato-bold.eot -------------------------------------------------------------------------------- /docs/build/html/_static/fonts/Lato/lato-bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/fonts/Lato/lato-bold.ttf -------------------------------------------------------------------------------- /docs/build/html/_static/fonts/Lato/lato-bold.woff: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/fonts/Lato/lato-bold.woff -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.nn.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.nn.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.ops.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.ops.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/get_started/installation.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/get_started/installation.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.irreps.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.irreps.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.utils.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.utils.doctree -------------------------------------------------------------------------------- /docs/build/html/_static/css/fonts/lato-normal.woff: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/css/fonts/lato-normal.woff -------------------------------------------------------------------------------- /docs/build/html/_static/css/fonts/lato-normal.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/css/fonts/lato-normal.woff2 -------------------------------------------------------------------------------- /docs/build/html/_static/fonts/Lato/lato-bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/fonts/Lato/lato-bold.woff2 -------------------------------------------------------------------------------- /docs/build/html/_static/fonts/Lato/lato-italic.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/fonts/Lato/lato-italic.eot -------------------------------------------------------------------------------- /docs/build/html/_static/fonts/Lato/lato-italic.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/fonts/Lato/lato-italic.ttf -------------------------------------------------------------------------------- /docs/build/html/_static/fonts/Lato/lato-italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/fonts/Lato/lato-italic.woff -------------------------------------------------------------------------------- /docs/build/html/_static/fonts/Lato/lato-italic.woff2: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/fonts/Lato/lato-italic.woff2 -------------------------------------------------------------------------------- /docs/build/html/_static/fonts/Lato/lato-regular.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/fonts/Lato/lato-regular.eot -------------------------------------------------------------------------------- /docs/build/html/_static/fonts/Lato/lato-regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/fonts/Lato/lato-regular.ttf -------------------------------------------------------------------------------- /docs/build/html/_static/fonts/Lato/lato-regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/fonts/Lato/lato-regular.woff -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.irreps.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.irreps.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.utils.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.utils.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.nn.init.doctree: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.nn.init.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.nn.norm.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.nn.norm.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.structs.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.structs.doctree -------------------------------------------------------------------------------- /docs/build/html/_static/fonts/Lato/lato-bolditalic.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/fonts/Lato/lato-bolditalic.eot -------------------------------------------------------------------------------- /docs/build/html/_static/fonts/Lato/lato-bolditalic.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/fonts/Lato/lato-bolditalic.ttf -------------------------------------------------------------------------------- /docs/build/html/_static/fonts/Lato/lato-regular.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/fonts/Lato/lato-regular.woff2 -------------------------------------------------------------------------------- /docs/source/modules/nn.rst: 
-------------------------------------------------------------------------------- 1 | equitorch.nn 2 | ============ 3 | 4 | 5 | .. currentmodule:: equitorch.nn 6 | 7 | 8 | .. automodule:: equitorch.nn 9 | :members: -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.constants.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.constants.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.nn.init.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.nn.init.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.nn.norm.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.nn.norm.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.nn.others.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.nn.others.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.structs.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.structs.doctree -------------------------------------------------------------------------------- 
/docs/build/doctrees/modules/equitorch.constants.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.constants.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.nn.angular.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.nn.angular.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.nn.cutoffs.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.nn.cutoffs.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.nn.dropout.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.nn.dropout.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.nn.linears.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.nn.linears.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.nn.others.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.nn.others.doctree 
-------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.nn.radials.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.nn.radials.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.nn.wigner_d.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.nn.wigner_d.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.transforms.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.transforms.doctree -------------------------------------------------------------------------------- /docs/build/html/_static/css/fonts/Roboto-Slab-Bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/css/fonts/Roboto-Slab-Bold.woff -------------------------------------------------------------------------------- /docs/build/html/_static/css/fonts/Roboto-Slab-Bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/css/fonts/Roboto-Slab-Bold.woff2 -------------------------------------------------------------------------------- /docs/build/html/_static/css/fonts/fontawesome-webfont.eot: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/css/fonts/fontawesome-webfont.eot -------------------------------------------------------------------------------- /docs/build/html/_static/css/fonts/fontawesome-webfont.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/css/fonts/fontawesome-webfont.ttf -------------------------------------------------------------------------------- /docs/build/html/_static/css/fonts/lato-bold-italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/css/fonts/lato-bold-italic.woff -------------------------------------------------------------------------------- /docs/build/html/_static/css/fonts/lato-bold-italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/css/fonts/lato-bold-italic.woff2 -------------------------------------------------------------------------------- /docs/build/html/_static/css/fonts/lato-normal-italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/css/fonts/lato-normal-italic.woff -------------------------------------------------------------------------------- /docs/build/html/_static/fonts/Lato/lato-bolditalic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/fonts/Lato/lato-bolditalic.woff -------------------------------------------------------------------------------- /docs/build/html/_static/fonts/Lato/lato-bolditalic.woff2: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/fonts/Lato/lato-bolditalic.woff2 -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.nn.angular.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.nn.angular.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.nn.cutoffs.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.nn.cutoffs.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.nn.dropout.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.nn.dropout.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.nn.linears.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.nn.linears.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.nn.radials.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.nn.radials.doctree -------------------------------------------------------------------------------- 
/docs/build/doctrees/generated/equitorch.nn.wigner_d.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.nn.wigner_d.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.transforms.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.transforms.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.nn.functional.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.nn.functional.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.nn.rotations.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.nn.rotations.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.nn.sphericals.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.nn.sphericals.doctree -------------------------------------------------------------------------------- /docs/build/html/_static/css/fonts/Roboto-Slab-Regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/css/fonts/Roboto-Slab-Regular.woff 
-------------------------------------------------------------------------------- /docs/build/html/_static/css/fonts/Roboto-Slab-Regular.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/css/fonts/Roboto-Slab-Regular.woff2 -------------------------------------------------------------------------------- /docs/build/html/_static/css/fonts/fontawesome-webfont.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/css/fonts/fontawesome-webfont.woff -------------------------------------------------------------------------------- /docs/build/html/_static/css/fonts/fontawesome-webfont.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/css/fonts/fontawesome-webfont.woff2 -------------------------------------------------------------------------------- /docs/build/html/_static/css/fonts/lato-normal-italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/css/fonts/lato-normal-italic.woff2 -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.nn.activations.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.nn.activations.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.nn.functional.doctree: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.nn.functional.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.nn.rotations.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.nn.rotations.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.nn.sphericals.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.nn.sphericals.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.nn.activations.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.nn.activations.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.nn.normalization.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.nn.normalization.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.ops.kernel_dense.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.ops.kernel_dense.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.ops.kernel_utils.doctree: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.ops.kernel_utils.doctree -------------------------------------------------------------------------------- /docs/source/modules/utils.rst: -------------------------------------------------------------------------------- 1 | equitorch.utils 2 | =============== 3 | 4 | 5 | .. currentmodule:: equitorch.utils 6 | 7 | 8 | .. automodule:: equitorch.utils 9 | :members: -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.nn.normalization.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.nn.normalization.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.ops.kernel_dense.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.ops.kernel_dense.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.ops.kernel_utils.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.ops.kernel_utils.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.nn.functional.norm.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.nn.functional.norm.doctree 
-------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.nn.tensor_products.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.nn.tensor_products.doctree -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/nn.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn 2 | ============ 3 | 4 | 5 | .. currentmodule:: equitorch.nn 6 | 7 | 8 | .. automodule:: equitorch.nn 9 | :members: -------------------------------------------------------------------------------- /docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.eot -------------------------------------------------------------------------------- /docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.ttf -------------------------------------------------------------------------------- /docs/source/modules/irreps.rst: -------------------------------------------------------------------------------- 1 | equitorch.irreps 2 | ================ 3 | 4 | 5 | .. currentmodule:: equitorch.irreps 6 | 7 | 8 | .. 
automodule:: equitorch.irreps 9 | :members: -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.nn.functional.norm.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.nn.functional.norm.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.nn.tensor_products.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.nn.tensor_products.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.nn.functional.angular.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.nn.functional.angular.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.nn.functional.cutoffs.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.nn.functional.cutoffs.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.nn.functional.dropout.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.nn.functional.dropout.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.nn.functional.linears.doctree: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.nn.functional.linears.doctree -------------------------------------------------------------------------------- /docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff -------------------------------------------------------------------------------- /docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff2 -------------------------------------------------------------------------------- /docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.eot -------------------------------------------------------------------------------- /docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.ttf -------------------------------------------------------------------------------- /docs/source/modules/structs.rst: -------------------------------------------------------------------------------- 1 | equitorch.structs 2 | ================= 3 | 4 | 5 | .. currentmodule:: equitorch.structs 6 | 7 | 8 | .. 
automodule:: equitorch.structs 9 | :members: -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.nn.functional.angular.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.nn.functional.angular.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.nn.functional.cutoffs.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.nn.functional.cutoffs.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.nn.functional.dropout.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.nn.functional.dropout.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.nn.functional.linears.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.nn.functional.linears.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.nn.functional.rotations.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.nn.functional.rotations.doctree -------------------------------------------------------------------------------- 
/docs/build/doctrees/modules/equitorch.nn.functional.wigner_d.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.nn.functional.wigner_d.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.ops.indexed_product_op.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.ops.indexed_product_op.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.ops.product_segment_op.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.ops.product_segment_op.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.ops.spherical_harmonics.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.ops.spherical_harmonics.doctree -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.ops.kernel_dense.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.ops.kernel\_dense 2 | =========================== 3 | 4 | .. 
automodule:: equitorch.ops.kernel_dense 5 | 6 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.ops.kernel_utils.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.ops.kernel\_utils 2 | =========================== 3 | 4 | .. automodule:: equitorch.ops.kernel_utils 5 | 6 | -------------------------------------------------------------------------------- /docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff -------------------------------------------------------------------------------- /docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/html/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff2 -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.nn.functional.rotations.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.nn.functional.rotations.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.nn.functional.sphericals.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.nn.functional.sphericals.doctree -------------------------------------------------------------------------------- 
/docs/build/doctrees/generated/equitorch.nn.functional.wigner_d.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.nn.functional.wigner_d.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.ops.indexed_product_op.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.ops.indexed_product_op.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.ops.product_segment_op.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.ops.product_segment_op.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.ops.spherical_harmonics.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.ops.spherical_harmonics.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.nn.functional.activations.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.nn.functional.activations.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.nn.functional.sparse_scale.doctree: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.nn.functional.sparse_scale.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.nn.functional.sphericals.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.nn.functional.sphericals.doctree -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/utils.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.utils 2 | =============== 3 | 4 | 5 | .. currentmodule:: equitorch.utils 6 | 7 | 8 | .. automodule:: equitorch.utils 9 | :members: -------------------------------------------------------------------------------- /docs/source/api.rst: -------------------------------------------------------------------------------- 1 | .. _api: 2 | 3 | API Reference 4 | ============= 5 | 6 | .. toctree:: 7 | :maxdepth: 2 8 | :caption: Modules: 9 | :glob: 10 | 11 | modules/* 12 | 13 | -------------------------------------------------------------------------------- /docs/source/modules/constants.rst: -------------------------------------------------------------------------------- 1 | equitorch.constants 2 | =================== 3 | 4 | 5 | .. currentmodule:: equitorch.constants 6 | 7 | 8 | .. 
automodule:: equitorch.constants 9 | :members: -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.nn.functional.activations.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.nn.functional.activations.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.nn.functional.sparse_scale.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.nn.functional.sparse_scale.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.nn.functional.normalization.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.nn.functional.normalization.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.nn.functional.sparse_product.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.nn.functional.sparse_product.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.ops.batched_sparse_dense_op.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.ops.batched_sparse_dense_op.doctree -------------------------------------------------------------------------------- 
/docs/build/doctrees/modules/equitorch.ops.indexed_scale_segment_op.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.ops.indexed_scale_segment_op.doctree -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/irreps.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.irreps 2 | ================ 3 | 4 | 5 | .. currentmodule:: equitorch.irreps 6 | 7 | 8 | .. automodule:: equitorch.irreps 9 | :members: -------------------------------------------------------------------------------- /docs/source/modules/transforms.rst: -------------------------------------------------------------------------------- 1 | equitorch.transforms 2 | ==================== 3 | 4 | 5 | .. currentmodule:: equitorch.transforms 6 | 7 | 8 | .. automodule:: equitorch.transforms 9 | :members: -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.nn.functional.normalization.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.nn.functional.normalization.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.nn.functional.sparse_product.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.nn.functional.sparse_product.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.nn.functional.tensor_products.doctree: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.nn.functional.tensor_products.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.ops.batched_sparse_dense_op.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.ops.batched_sparse_dense_op.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.ops.indexed_scale_segment_op.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.ops.indexed_scale_segment_op.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.nn.functional.tensor_products.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.nn.functional.tensor_products.doctree -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.ops.indexed_product_segment_op.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.ops.indexed_product_segment_op.doctree -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/structs.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.structs 2 | ================= 3 | 4 | 5 | .. 
currentmodule:: equitorch.structs 6 | 7 | 8 | .. automodule:: equitorch.structs 9 | :members: -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.ops.indexed_product_segment_op.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.ops.indexed_product_segment_op.doctree -------------------------------------------------------------------------------- /docs/build/html/_sources/api.rst.txt: -------------------------------------------------------------------------------- 1 | .. _api: 2 | 3 | API Reference 4 | ============= 5 | 6 | .. toctree:: 7 | :maxdepth: 2 8 | :caption: Modules: 9 | :glob: 10 | 11 | modules/* 12 | 13 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.nn.functional.cutoffs.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.functional.cutoffs 2 | =============================== 3 | 4 | .. automodule:: equitorch.nn.functional.cutoffs 5 | 6 | -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.ops.indexed_product_scale_segment_op.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.ops.indexed_product_scale_segment_op.doctree -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.nn.functional.norm.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.functional.norm 2 | ============================ 3 | 4 | .. currentmodule:: equitorch.nn.functional 5 | 6 | .. 
autofunction:: norm -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/constants.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.constants 2 | =================== 3 | 4 | 5 | .. currentmodule:: equitorch.constants 6 | 7 | 8 | .. automodule:: equitorch.constants 9 | :members: -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.ops.indexed_product_scale_segment_op.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.ops.indexed_product_scale_segment_op.doctree -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.ops.spherical_harmonics.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.ops.spherical\_harmonics 2 | ================================== 3 | 4 | .. automodule:: equitorch.ops.spherical_harmonics 5 | 6 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/transforms.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.transforms 2 | ==================== 3 | 4 | 5 | .. currentmodule:: equitorch.transforms 6 | 7 | 8 | .. 
automodule:: equitorch.transforms 9 | :members: -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | sphinx: 3 | configuration: docs/source/conf.py 4 | build: 5 | os: "ubuntu-20.04" 6 | tools: 7 | python: "miniconda3-3.12-24.1" 8 | conda: 9 | environment: "./docs/environment.yaml" -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.nn.init.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.init module 2 | ======================== 3 | 4 | .. automodule:: equitorch.nn.init 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.nn.norm.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.norm module 2 | ======================== 3 | 4 | .. automodule:: equitorch.nn.norm 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /docs/build/doctrees/modules/equitorch.ops.accumulated_indexed_product_segment_op.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/modules/equitorch.ops.accumulated_indexed_product_segment_op.doctree -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.nn.others.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.others module 2 | ========================== 3 | 4 | .. 
automodule:: equitorch.nn.others 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /docs/build/doctrees/generated/equitorch.ops.accumulated_indexed_product_segment_op.doctree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GTML-LAB/Equitorch/HEAD/docs/build/doctrees/generated/equitorch.ops.accumulated_indexed_product_segment_op.doctree -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.nn.angular.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.angular module 2 | =========================== 3 | 4 | .. automodule:: equitorch.nn.angular 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.nn.cutoffs.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.cutoffs module 2 | =========================== 3 | 4 | .. automodule:: equitorch.nn.cutoffs 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.nn.dropout.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.dropout module 2 | =========================== 3 | 4 | .. automodule:: equitorch.nn.dropout 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.nn.linears.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.linears module 2 | =========================== 3 | 4 | .. 
automodule:: equitorch.nn.linears 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.nn.radials.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.radials module 2 | =========================== 3 | 4 | .. automodule:: equitorch.nn.radials 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.nn.wigner_d.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.wigner\_d module 2 | ============================= 3 | 4 | .. automodule:: equitorch.nn.wigner_d 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.nn.rotations.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.rotations module 2 | ============================= 3 | 4 | .. automodule:: equitorch.nn.rotations 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.nn.functional.sparse_scale.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.functional.sparse\_scale 2 | ===================================== 3 | 4 | .. currentmodule:: equitorch.nn.functional 5 | 6 | .. 
autofunction:: sparse_scale -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.nn.activations.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.activations module 2 | =============================== 3 | 4 | .. automodule:: equitorch.nn.activations 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.nn.sphericals.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.sphericals module 2 | ============================== 3 | 4 | .. automodule:: equitorch.nn.sphericals 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.nn.normalization.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.normalization module 2 | ================================= 3 | 4 | .. automodule:: equitorch.nn.normalization 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.ops.kernel_dense.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.ops.kernel\_dense module 2 | ================================== 3 | 4 | .. 
automodule:: equitorch.ops.kernel_dense 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.ops.kernel_utils.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.ops.kernel\_utils module 2 | ================================== 3 | 4 | .. automodule:: equitorch.ops.kernel_utils 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.nn.angular.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.angular 2 | ==================== 3 | 4 | .. automodule:: equitorch.nn.angular 5 | 6 | 7 | .. rubric:: Classes 8 | 9 | .. autosummary:: 10 | 11 | SinCos 12 | -------------------------------------------------------------------------------- /docs/build/html/.buildinfo: -------------------------------------------------------------------------------- 1 | # Sphinx build info version 1 2 | # This file records the configuration used when building these files. When it is not found, a full rebuild will be done. 3 | config: fa6ca59b478ea84880120ab89fee1578 4 | tags: 645f666f9bcd5a90fca523b33c5a78b7 5 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.nn.dropout.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.dropout 2 | ==================== 3 | 4 | .. automodule:: equitorch.nn.dropout 5 | 6 | 7 | .. rubric:: Classes 8 | 9 | .. 
autosummary:: 10 | 11 | Dropout 12 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.nn.radials.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.radials 2 | ==================== 3 | 4 | .. automodule:: equitorch.nn.radials 5 | 6 | 7 | .. rubric:: Classes 8 | 9 | .. autosummary:: 10 | 11 | BesselBasis 12 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.nn.functional.norm.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.functional.norm module 2 | =================================== 3 | 4 | .. automodule:: equitorch.nn.functional.norm 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.nn.tensor_products.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.tensor\_products module 2 | ==================================== 3 | 4 | .. automodule:: equitorch.nn.tensor_products 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /docs/build/html/.buildinfo.bak: -------------------------------------------------------------------------------- 1 | # Sphinx build info version 1 2 | # This file records the configuration used when building these files. When it is not found, a full rebuild will be done. 
3 | config: 37973216bf391cc61e3198479ef8cc5e 4 | tags: 645f666f9bcd5a90fca523b33c5a78b7 5 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.utils.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.utils package 2 | ======================= 3 | 4 | Module contents 5 | --------------- 6 | 7 | .. automodule:: equitorch.utils 8 | :members: 9 | :show-inheritance: 10 | :undoc-members: 11 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.nn.activations.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.activations 2 | ======================== 3 | 4 | .. automodule:: equitorch.nn.activations 5 | 6 | 7 | .. rubric:: Classes 8 | 9 | .. autosummary:: 10 | 11 | Gate 12 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.irreps.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.irreps package 2 | ======================== 3 | 4 | Module contents 5 | --------------- 6 | 7 | .. automodule:: equitorch.irreps 8 | :members: 9 | :show-inheritance: 10 | :undoc-members: 11 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.nn.functional.angular.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.functional.angular module 2 | ====================================== 3 | 4 | .. 
automodule:: equitorch.nn.functional.angular 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.nn.functional.cutoffs.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.functional.cutoffs module 2 | ====================================== 3 | 4 | .. automodule:: equitorch.nn.functional.cutoffs 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.nn.functional.dropout.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.functional.dropout module 2 | ====================================== 3 | 4 | .. automodule:: equitorch.nn.functional.dropout 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.nn.functional.linears.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.functional.linears module 2 | ====================================== 3 | 4 | .. automodule:: equitorch.nn.functional.linears 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.structs.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.structs package 2 | ========================= 3 | 4 | Module contents 5 | --------------- 6 | 7 | .. 
automodule:: equitorch.structs 8 | :members: 9 | :show-inheritance: 10 | :undoc-members: 11 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.nn.rotations.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.rotations 2 | ====================== 3 | 4 | .. automodule:: equitorch.nn.rotations 5 | 6 | 7 | .. rubric:: Classes 8 | 9 | .. autosummary:: 10 | 11 | AnglesToMatrix 12 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.nn.functional.wigner_d.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.functional.wigner\_d module 2 | ======================================== 3 | 4 | .. automodule:: equitorch.nn.functional.wigner_d 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.nn.others.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.others 2 | =================== 3 | 4 | .. automodule:: equitorch.nn.others 5 | 6 | 7 | .. rubric:: Classes 8 | 9 | .. autosummary:: 10 | 11 | Separable 12 | SplitIrreps 13 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.constants.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.constants package 2 | =========================== 3 | 4 | Module contents 5 | --------------- 6 | 7 | .. 
automodule:: equitorch.constants 8 | :members: 9 | :show-inheritance: 10 | :undoc-members: 11 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.nn.functional.rotations.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.functional.rotations module 2 | ======================================== 3 | 4 | .. automodule:: equitorch.nn.functional.rotations 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.nn.functional.sphericals.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.functional.sphericals module 2 | ========================================= 3 | 4 | .. automodule:: equitorch.nn.functional.sphericals 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.ops.indexed_product_op.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.ops.indexed\_product\_op module 2 | ========================================= 3 | 4 | .. automodule:: equitorch.ops.indexed_product_op 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.ops.product_segment_op.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.ops.product\_segment\_op module 2 | ========================================= 3 | 4 | .. 
automodule:: equitorch.ops.product_segment_op 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.ops.spherical_harmonics.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.ops.spherical\_harmonics module 2 | ========================================= 3 | 4 | .. automodule:: equitorch.ops.spherical_harmonics 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.constants.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.constants 2 | =================== 3 | 4 | .. automodule:: equitorch.constants 5 | 6 | 7 | .. rubric:: Functions 8 | 9 | .. autosummary:: 10 | 11 | j_matrix 12 | so3_clebsch_gordan 13 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.nn.functional.activations.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.functional.activations module 2 | ========================================== 3 | 4 | .. automodule:: equitorch.nn.functional.activations 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.transforms.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.transforms package 2 | ============================ 3 | 4 | Module contents 5 | --------------- 6 | 7 | .. 
automodule:: equitorch.transforms 8 | :members: 9 | :show-inheritance: 10 | :undoc-members: 11 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.nn.functional.sparse_scale.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.functional.sparse\_scale module 2 | ============================================ 3 | 4 | .. automodule:: equitorch.nn.functional.sparse_scale 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.nn.norm.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.norm 2 | ================= 3 | 4 | .. automodule:: equitorch.nn.norm 5 | 6 | 7 | .. rubric:: Classes 8 | 9 | .. autosummary:: 10 | 11 | MeanSquaredNorm 12 | Norm 13 | SquaredNorm 14 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.nn.functional.normalization.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.functional.normalization module 2 | ============================================ 3 | 4 | .. automodule:: equitorch.nn.functional.normalization 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.nn.functional.sparse_product.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.functional.sparse\_product module 2 | ============================================== 3 | 4 | .. 
automodule:: equitorch.nn.functional.sparse_product 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.ops.batched_sparse_dense_op.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.ops.batched\_sparse\_dense\_op module 2 | =============================================== 3 | 4 | .. automodule:: equitorch.ops.batched_sparse_dense_op 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.nn.functional.angular.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.functional.angular 2 | =============================== 3 | 4 | .. automodule:: equitorch.nn.functional.angular 5 | 6 | 7 | .. rubric:: Functions 8 | 9 | .. autosummary:: 10 | 11 | sincos 12 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.nn.functional.tensor_products.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.functional.tensor\_products module 2 | =============================================== 3 | 4 | .. automodule:: equitorch.nn.functional.tensor_products 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.ops.indexed_scale_segment_op.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.ops.indexed\_scale\_segment\_op module 2 | ================================================ 3 | 4 | .. 
automodule:: equitorch.ops.indexed_scale_segment_op 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.nn.normalization.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.normalization 2 | ========================== 3 | 4 | .. automodule:: equitorch.nn.normalization 5 | 6 | 7 | .. rubric:: Classes 8 | 9 | .. autosummary:: 10 | 11 | BatchRMSNorm 12 | LayerRMSNorm 13 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.ops.indexed_product_segment_op.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.ops.indexed\_product\_segment\_op module 2 | ================================================== 3 | 4 | .. automodule:: equitorch.ops.indexed_product_segment_op 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.nn.functional.dropout.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.functional.dropout 2 | =============================== 3 | 4 | .. automodule:: equitorch.nn.functional.dropout 5 | 6 | 7 | .. rubric:: Functions 8 | 9 | .. autosummary:: 10 | 11 | irrep_wise_dropout 12 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch 2 | ========= 3 | 4 | .. automodule:: equitorch 5 | 6 | 7 | .. rubric:: Modules 8 | 9 | .. 
autosummary:: 10 | :toctree: 11 | :recursive: 12 | 13 | constants 14 | irreps 15 | nn 16 | structs 17 | transforms 18 | utils 19 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.nn.cutoffs.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.cutoffs 2 | ==================== 3 | 4 | .. automodule:: equitorch.nn.cutoffs 5 | 6 | 7 | .. rubric:: Classes 8 | 9 | .. autosummary:: 10 | 11 | CosineCutoff 12 | MollifierCutoff 13 | PolynomialCutoff 14 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.nn.functional.activations.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.functional.activations 2 | =================================== 3 | 4 | .. automodule:: equitorch.nn.functional.activations 5 | 6 | 7 | .. rubric:: Functions 8 | 9 | .. autosummary:: 10 | 11 | gating 12 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.nn.functional.rotations.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.functional.rotations 2 | ================================= 3 | 4 | .. automodule:: equitorch.nn.functional.rotations 5 | 6 | 7 | .. rubric:: Functions 8 | 9 | .. autosummary:: 10 | 11 | angles_to_matrix 12 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.nn.tensor_products.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.tensor\_products 2 | ============================= 3 | 4 | .. automodule:: equitorch.nn.tensor_products 5 | 6 | 7 | .. rubric:: Classes 8 | 9 | .. 
autosummary:: 10 | 11 | TensorDot 12 | TensorProduct 13 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.transforms.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.transforms 2 | ==================== 3 | 4 | .. automodule:: equitorch.transforms 5 | 6 | 7 | .. rubric:: Classes 8 | 9 | .. autosummary:: 10 | 11 | RadiusGraph 12 | AddSphericalHarmonics 13 | AddVectorNorm 14 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.nn.init.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.init 2 | ================= 3 | 4 | .. automodule:: equitorch.nn.init 5 | 6 | 7 | .. rubric:: Functions 8 | 9 | .. autosummary:: 10 | 11 | initialize_linear 12 | initialize_so3_so2_linear 13 | initialize_tensor_product 14 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.nn.linears.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.linears 2 | ==================== 3 | 4 | .. automodule:: equitorch.nn.linears 5 | 6 | 7 | .. rubric:: Classes 8 | 9 | .. autosummary:: 10 | 11 | IrrepWiseLinear 12 | IrrepsLinear 13 | SO2Linear 14 | SO3Linear 15 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.ops.indexed_product_scale_segment_op.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.ops.indexed\_product\_scale\_segment\_op module 2 | ========================================================= 3 | 4 | .. 
automodule:: equitorch.ops.indexed_product_scale_segment_op 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.nn.sphericals.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.sphericals 2 | ======================= 3 | 4 | .. automodule:: equitorch.nn.sphericals 5 | 6 | 7 | .. rubric:: Classes 8 | 9 | .. autosummary:: 10 | 11 | SphericalHarmonics 12 | SphericalToXYZ 13 | XYZToSinCos 14 | XYZToSpherical 15 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.nn.wigner_d.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.wigner\_d 2 | ====================== 3 | 4 | .. automodule:: equitorch.nn.wigner_d 5 | 6 | 7 | .. rubric:: Classes 8 | 9 | .. autosummary:: 10 | 11 | AlignToZWignerD 12 | DenseWignerRotation 13 | SparseWignerRotation 14 | WignerD 15 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.nn.functional.normalization.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.functional.normalization 2 | ===================================== 3 | 4 | .. automodule:: equitorch.nn.functional.normalization 5 | 6 | 7 | .. rubric:: Functions 8 | 9 | .. autosummary:: 10 | 11 | batch_rms_norm 12 | layer_rms_norm 13 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.ops.accumulated_indexed_product_segment_op.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.ops.accumulated\_indexed\_product\_segment\_op module 2 | =============================================================== 3 | 4 | .. 
automodule:: equitorch.ops.accumulated_indexed_product_segment_op 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: 8 | -------------------------------------------------------------------------------- /equitorch/__init__.py: -------------------------------------------------------------------------------- 1 | """Equitorch.""" 2 | from . import constants 3 | from . import irreps 4 | from . import nn 5 | # from . import ops 6 | from . import structs 7 | from . import transforms 8 | from . import utils 9 | 10 | __all__ = [ 11 | "constants", 12 | "irreps", 13 | "nn", 14 | # "ops", 15 | "structs", 16 | "transforms", 17 | "utils", 18 | ] -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.nn.functional.sphericals.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.functional.sphericals 2 | ================================== 3 | 4 | .. automodule:: equitorch.nn.functional.sphericals 5 | 6 | 7 | .. rubric:: Functions 8 | 9 | .. autosummary:: 10 | 11 | spherical_harmonics 12 | spherical_to_xyz 13 | xyz_to_sincos 14 | xyz_to_spherical 15 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.nn.functional.wigner_d.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.functional.wigner\_d 2 | ================================= 3 | 4 | .. automodule:: equitorch.nn.functional.wigner_d 5 | 6 | 7 | .. rubric:: Functions 8 | 9 | .. 
autosummary:: 10 | 11 | align_to_z_wigner_d 12 | dense_wigner_rotation 13 | sparse_wigner_rotation 14 | wigner_d_matrix 15 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.ops.indexed_product_op.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.ops.indexed\_product\_op 2 | ================================== 3 | 4 | .. automodule:: equitorch.ops.indexed_product_op 5 | 6 | 7 | .. rubric:: Functions 8 | 9 | .. autosummary:: 10 | 11 | indexed_inner 12 | indexed_mul 13 | indexed_outer 14 | indexed_vecmat 15 | indexed_vecsca 16 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.ops.product_segment_op.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.ops.product\_segment\_op 2 | ================================== 3 | 4 | .. automodule:: equitorch.ops.product_segment_op 5 | 6 | 7 | .. rubric:: Functions 8 | 9 | .. 
autosummary:: 10 | 11 | inner_segment 12 | mul_segment 13 | outer_segment 14 | vecmat_segment 15 | vecsca_segment 16 | -------------------------------------------------------------------------------- /docs/build/html/_static/documentation_options.js: -------------------------------------------------------------------------------- 1 | const DOCUMENTATION_OPTIONS = { 2 | VERSION: '', 3 | LANGUAGE: 'en', 4 | COLLAPSE_INDEX: false, 5 | BUILDER: 'html', 6 | FILE_SUFFIX: '.html', 7 | LINK_SUFFIX: '.html', 8 | HAS_SOURCE: true, 9 | SOURCELINK_SUFFIX: '.txt', 10 | NAVIGATION_WITH_KEYS: false, 11 | SHOW_SEARCH_SUMMARY: true, 12 | ENABLE_SEARCH_SHORTCUTS: true, 13 | }; -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.ops.indexed_scale_segment_op.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.ops.indexed\_scale\_segment\_op 2 | ========================================= 3 | 4 | .. automodule:: equitorch.ops.indexed_scale_segment_op 5 | 6 | 7 | .. rubric:: Functions 8 | 9 | .. autosummary:: 10 | 11 | indexed_scale_segment 12 | indexed_scale_segment_cpu 13 | indexed_scale_segment_gpu 14 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.ops.indexed_product_segment_op.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.ops.indexed\_product\_segment\_op 2 | =========================================== 3 | 4 | .. automodule:: equitorch.ops.indexed_product_segment_op 5 | 6 | 7 | .. rubric:: Functions 8 | 9 | .. 
autosummary:: 10 | 11 | indexed_inner_segment 12 | indexed_mul_segment 13 | indexed_outer_segment 14 | indexed_vecmat_segment 15 | indexed_vecsca_segment 16 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.structs.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.structs 2 | ================= 3 | 4 | .. automodule:: equitorch.structs 5 | 6 | 7 | .. rubric:: Functions 8 | 9 | .. autosummary:: 10 | 11 | add_operation_methods 12 | 13 | .. rubric:: Classes 14 | 15 | .. autosummary:: 16 | 17 | IrrepsInfo 18 | IrrepsLinearInfo 19 | SparseProductInfo 20 | SparseScaleInfo 21 | TensorProductInfo 22 | WignerRotationInfo 23 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch package 2 | ================= 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | equitorch.constants 11 | equitorch.irreps 12 | equitorch.nn 13 | equitorch.ops 14 | equitorch.structs 15 | equitorch.transforms 16 | equitorch.utils 17 | 18 | Module contents 19 | --------------- 20 | 21 | .. 
automodule:: equitorch 22 | :members: 23 | :show-inheritance: 24 | :undoc-members: 25 | -------------------------------------------------------------------------------- /docs/build/html/_static/github-banner.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.ops.indexed_product_scale_segment_op.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.ops.indexed\_product\_scale\_segment\_op 2 | ================================================== 3 | 4 | .. automodule:: equitorch.ops.indexed_product_scale_segment_op 5 | 6 | 7 | .. rubric:: Functions 8 | 9 | .. autosummary:: 10 | 11 | indexed_inner_scale_segment 12 | indexed_mul_scale_segment 13 | indexed_outer_scale_segment 14 | indexed_vecmat_scale_segment 15 | indexed_vecsca_scale_segment 16 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.ops.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.ops 2 | ============= 3 | 4 | .. automodule:: equitorch.ops 5 | 6 | 7 | .. rubric:: Modules 8 | 9 | .. 
autosummary:: 10 | :toctree: 11 | :recursive: 12 | 13 | accumulated_indexed_product_segment_op 14 | batched_sparse_dense_op 15 | indexed_product_op 16 | indexed_product_scale_segment_op 17 | indexed_product_segment_op 18 | indexed_scale_segment_op 19 | kernel_dense 20 | kernel_utils 21 | product_segment_op 22 | spherical_harmonics 23 | -------------------------------------------------------------------------------- /docs/build/html/_sources/get_started/installation.rst.txt: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | This package is based on `Pytorch `_ (>=2.4) and `Pytorch-Geometric `_ (>=2.4). Please make sure you have already installed the version that fit your device. (It is temporarily recommended to use `pip` to install the Pytorch-Geometric.) 5 | 6 | With these packages installed, you can install *Equitorch* by 7 | 8 | .. code-block:: bash 9 | 10 | pip install equitorch 11 | 12 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.nn.functional.tensor_products.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.functional.tensor\_products 2 | ======================================== 3 | 4 | .. automodule:: equitorch.nn.functional.tensor_products 5 | 6 | 7 | .. rubric:: Functions 8 | 9 | .. autosummary:: 10 | 11 | tensor_product_uuu 12 | tensor_product_uvw 13 | 14 | .. rubric:: Classes 15 | 16 | .. autosummary:: 17 | 18 | TensorDotUU 19 | TensorDotUV 20 | TensorProductUUUDummy 21 | TensorProductUVWDummy 22 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.irreps.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.irreps 2 | ================ 3 | 4 | .. automodule:: equitorch.irreps 5 | 6 | 7 | .. rubric:: Functions 8 | 9 | .. 
autosummary:: 10 | 11 | check_irreps 12 | element_degrees 13 | element_orders 14 | has_path 15 | irrep_degrees 16 | irrep_indices 17 | irrep_segments 18 | parse_irreps 19 | show_irreps 20 | unique_irreps 21 | 22 | .. rubric:: Classes 23 | 24 | .. autosummary:: 25 | 26 | Irrep 27 | Irreps 28 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.ops.accumulated_indexed_product_segment_op.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.ops.accumulated\_indexed\_product\_segment\_op 2 | ======================================================== 3 | 4 | .. automodule:: equitorch.ops.accumulated_indexed_product_segment_op 5 | 6 | 7 | .. rubric:: Functions 8 | 9 | .. autosummary:: 10 | 11 | accumulated_indexed_inner_segment 12 | accumulated_indexed_mul_segment 13 | accumulated_indexed_outer_segment 14 | accumulated_indexed_vecmat_segment 15 | accumulated_indexed_vecsca_segment 16 | -------------------------------------------------------------------------------- /docs/source/get_started/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | This package is based on `Pytorch `_ (>=2.4) and `Pytorch-Geometric `_ (>=2.4), `Triton `_ (>=3.2). Please make sure you have already installed the version that fit your device. (It is temporarily recommended to use `pip` to install the Pytorch-Geometric.) 5 | 6 | With these packages installed, you can install *Equitorch* by 7 | 8 | .. 
code-block:: bash 9 | 10 | pip install git+https://github.com/GTML-LAB/Equitorch.git 11 | 12 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.nn.functional.sparse_product.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.functional.sparse\_product 2 | ======================================= 3 | 4 | .. automodule:: equitorch.nn.functional.sparse_product 5 | 6 | 7 | .. rubric:: Functions 8 | 9 | .. autosummary:: 10 | 11 | sparse_inner 12 | sparse_mat_t_vec 13 | sparse_mul 14 | sparse_outer 15 | sparse_scavec 16 | sparse_vecmat 17 | sparse_vecsca 18 | 19 | .. rubric:: Classes 20 | 21 | .. autosummary:: 22 | 23 | SparseInner 24 | SparseMatTVec 25 | SparseMul 26 | SparseOuter 27 | SparseScaVec 28 | SparseVecMat 29 | SparseVecSca 30 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. 
equitorch documentation master file, created by 2 | sphinx-quickstart on Tue Sep 17 14:07:34 2024. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | 7 | 8 | Equitorch documentation 9 | ======================= 10 | 11 | .. image:: ./_static/logo_wide.png 12 | 13 | 14 | This is **Equitorch**, a modularized package for flexibily constructing O(3)/SO(3) equivariant (and invariant) neural networks with Triton_ accelerated operators. 15 | 16 | Github pages: ``_ 17 | 18 | 19 | .. toctree:: 20 | :maxdepth: 2 21 | 22 | api 23 | 24 | 25 | .. _Triton: https://triton-lang.org/main/index.html -------------------------------------------------------------------------------- /docs/build/html/_sources/index.rst.txt: -------------------------------------------------------------------------------- 1 | .. equitorch documentation master file, created by 2 | sphinx-quickstart on Tue Sep 17 14:07:34 2024. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | 7 | 8 | Equitorch documentation 9 | ======================= 10 | 11 | .. image:: ./_static/logo_wide.png 12 | 13 | 14 | This is **Equitorch**, a modularized package for flexibily constructing equivariant (and invariant) neural networks with Triton_ accelerated operators. 15 | 16 | Github pages: ``_ 17 | 18 | 19 | .. toctree:: 20 | :maxdepth: 2 21 | 22 | api 23 | 24 | 25 | .. _Triton: https://triton-lang.org/main/index.html -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.nn.functional.linears.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.functional.linears 2 | =============================== 3 | 4 | .. automodule:: equitorch.nn.functional.linears 5 | 6 | 7 | .. rubric:: Functions 8 | 9 | .. 
autosummary:: 10 | 11 | so3_linear_uu 12 | so3_linear_uv 13 | tensor_product_1uu 14 | tensor_product_1vu 15 | tensor_product_u1u 16 | tensor_product_u1v 17 | tensor_product_uu1 18 | tensor_product_vu1 19 | 20 | .. rubric:: Classes 21 | 22 | .. autosummary:: 23 | 24 | IrrepWiseLinear 25 | IrrepsLinear 26 | TensorProduct1UUDummy 27 | TensorProduct1VUDummy 28 | TensorProductU1UDummy 29 | TensorProductU1VDummy 30 | TensorProductUU1Dummy 31 | TensorProductVU1Dummy 32 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.utils.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.utils 2 | =============== 3 | 4 | .. automodule:: equitorch.utils 5 | 6 | 7 | .. rubric:: Functions 8 | 9 | .. autosummary:: 10 | 11 | rand_spherical_xyz 12 | rand_spherical_angles 13 | rand_rotation_angles 14 | rand_rotation_matrices 15 | expand_left 16 | extract_batch_segments 17 | sort_by_column_key 18 | extract_scatter_indices 19 | sparse_scale_info 20 | sparse_scale_infos 21 | sparse_product_info 22 | sparse_product_infos 23 | generate_fully_connected_tp_paths 24 | tp_info 25 | tp_infos 26 | generate_fully_connected_irreps_linear_paths 27 | irreps_linear_infos 28 | irreps_info 29 | z_rotation_infos 30 | j_matrix_info 31 | wigner_d_info 32 | irreps_blocks_infos 33 | so2_linear_info 34 | so2_linear_infos 35 | -------------------------------------------------------------------------------- /docs/build/html/_static/js/badge_only.js: -------------------------------------------------------------------------------- 1 | !function(e){var t={};function r(n){if(t[n])return t[n].exports;var o=t[n]={i:n,l:!1,exports:{}};return e[n].call(o.exports,o,o.exports,r),o.l=!0,o.exports}r.m=e,r.c=t,r.d=function(e,t,n){r.o(e,t)||Object.defineProperty(e,t,{enumerable:!0,get:n})},r.r=function(e){"undefined"!=typeof 
Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},r.t=function(e,t){if(1&t&&(e=r(e)),8&t)return e;if(4&t&&"object"==typeof e&&e&&e.__esModule)return e;var n=Object.create(null);if(r.r(n),Object.defineProperty(n,"default",{enumerable:!0,value:e}),2&t&&"string"!=typeof e)for(var o in e)r.d(n,o,function(t){return e[t]}.bind(null,o));return n},r.n=function(e){var t=e&&e.__esModule?function(){return e.default}:function(){return e};return r.d(t,"a",t),t},r.o=function(e,t){return Object.prototype.hasOwnProperty.call(e,t)},r.p="",r(r.s=4)}({4:function(e,t,r){}}); -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 
21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.ops.batched_sparse_dense_op.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.ops.batched\_sparse\_dense\_op 2 | ======================================== 3 | 4 | .. automodule:: equitorch.ops.batched_sparse_dense_op 5 | 6 | 7 | .. rubric:: Functions 8 | 9 | .. autosummary:: 10 | 11 | indexed_inner_scale_gather 12 | indexed_inner_scale_gather_cpu 13 | indexed_inner_scale_gather_gpu 14 | indexed_mat_t_vec_scale_gather 15 | indexed_mul_scale_gather 16 | indexed_mul_scale_gather_cpu 17 | indexed_mul_scale_gather_gpu 18 | indexed_outer_scale_gather 19 | indexed_outer_scale_gather_cpu 20 | indexed_outer_scale_gather_gpu 21 | indexed_scavec_scale_gather 22 | indexed_vecmat_scale_gather 23 | indexed_vecmat_scale_gather_cpu 24 | indexed_vecmat_scale_gather_gpu 25 | indexed_vecsca_scale_gather 26 | indexed_vecsca_scale_gather_cpu 27 | indexed_vecsca_scale_gather_gpu 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Equitorch 2 | 3 | ![Equitorch_logo](./img/logo_wide.png) 4 | 5 | *Equitorch* is a modularized package that can be used to flexibly constructing O(3)/SO(3) equivariant neural networks. 6 | 7 | **[Github Pages](https://github.com/GTML-LAB/Equitorch/tree/main)** 8 | 9 | **[Documentation](https://equitorch.readthedocs.io/en/latest/index.html)** 10 | 11 | > This package is still under development. 
12 | > We are actively adding more operations, documentations and tutorials. 13 | 14 | ### Installation 15 | 16 | This package is based on [Pytorch](https://pytorch.org/)(>=2.4), [Pytorch-Geometric](https://pytorch-geometric.readthedocs.io/en/latest/index.html)(>=2.4), and [Triton](http://triton-lang.org/)(>=3.2). Please make sure you have already installed the version that fit your device. (It is currently recommended to use `pip` to install the Pytorch-Geometric.) 17 | 18 | With these packages installed, you can install *Equitorch* by 19 | 20 | ```bash 21 | pip install git+https://github.com/GTML-LAB/Equitorch.git 22 | ``` -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 GTML-LAB 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.nn.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn 2 | ============ 3 | 4 | .. automodule:: equitorch.nn 5 | 6 | 7 | .. rubric:: Functions 8 | 9 | .. autosummary:: 10 | 11 | initialize_tensor_product 12 | initialize_so3_so2_linear 13 | initialize_linear 14 | 15 | .. rubric:: Classes 16 | 17 | .. autosummary:: 18 | 19 | SO3Linear 20 | IrrepWiseLinear 21 | IrrepsLinear 22 | SO2Linear 23 | SplitIrreps 24 | Separable 25 | SphericalHarmonics 26 | XYZToSpherical 27 | SphericalToXYZ 28 | XYZToSinCos 29 | BatchRMSNorm 30 | LayerRMSNorm 31 | TensorProduct 32 | TensorDot 33 | SparseWignerRotation 34 | DenseWignerRotation 35 | WignerD 36 | AlignToZWignerD 37 | PolynomialCutoff 38 | CosineCutoff 39 | MollifierCutoff 40 | Gate 41 | SinCos 42 | BesselBasis 43 | SquaredNorm 44 | Norm 45 | MeanSquaredNorm 46 | Dropout 47 | AnglesToMatrix 48 | 49 | .. rubric:: Modules 50 | 51 | .. 
autosummary:: 52 | :toctree: 53 | :recursive: 54 | 55 | functional 56 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | setuptools.setup( 4 | name="equitorch", # Package name set to equitorch 5 | version="1.0.0", # Default version 6 | author="Tong Wang", # Placeholder - Update if needed 7 | author_email="TongWang_2000@outlook.com", # Placeholder - Update if needed 8 | description="An efficient modularized package for SO(3)/O(3) equivariant neural networks", # Placeholder description 9 | long_description="""An efficient modularized package for SO(3)/O(3) equivariant neural networks""", # Placeholder long description 10 | long_description_content_type="text/markdown", 11 | url="https://equitorch.readthedocs.io/en/latest/index.html", # Placeholder URL - Update if needed 12 | packages=setuptools.find_packages(exclude=['test*']), # Automatically find packages in 'equitorch/' 13 | classifiers=[ 14 | "Programming Language :: Python :: 3", 15 | "License :: OSI Approved :: MIT License", # Placeholder License - Update if needed 16 | "Operating System :: OS Independent", 17 | ], 18 | python_requires='>=3.12', # Example Python version requirement 19 | install_requires=[ 20 | # Add actual dependencies here, for example: 21 | # 'torch>=2.4.0', 22 | 'triton>=3.2.0', 23 | ], 24 | ) 25 | -------------------------------------------------------------------------------- /docs/build/html/_sources/generated/equitorch.nn.functional.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.functional 2 | ======================= 3 | 4 | .. automodule:: equitorch.nn.functional 5 | 6 | 7 | .. rubric:: Functions 8 | 9 | .. 
autosummary:: 10 | 11 | tensor_product_u1u 12 | so3_linear_uu 13 | tensor_product_1uu 14 | tensor_product_uu1 15 | tensor_product_u1v 16 | so3_linear_uv 17 | tensor_product_1vu 18 | tensor_product_vu1 19 | irrep_wise_linear 20 | irreps_linear 21 | so2_linear_uu 22 | so2_linear_uv 23 | spherical_harmonics 24 | xyz_to_spherical 25 | spherical_to_xyz 26 | xyz_to_sincos 27 | batch_rms_norm 28 | layer_rms_norm 29 | sparse_scale 30 | tensor_product_uuu 31 | tensor_product_uvw 32 | tensor_dot_uu 33 | tensor_dot_uv 34 | sparse_mul 35 | sparse_outer 36 | sparse_inner 37 | sparse_vecmat 38 | sparse_vecsca 39 | sparse_scavec 40 | sparse_mat_t_vec 41 | sparse_wigner_rotation 42 | dense_wigner_rotation 43 | wigner_d_matrix 44 | align_to_z_wigner_d 45 | gating 46 | sincos 47 | squared_norm 48 | norm 49 | channel_mean_squared_norm 50 | batch_mean_squared_norm 51 | irrep_wise_dropout 52 | angles_to_matrix 53 | -------------------------------------------------------------------------------- /equitorch/nn/functional/sparse_scale.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import torch 3 | from torch import Tensor 4 | from torch.autograd import Function 5 | 6 | from ...ops.indexed_scale_segment_op import ( 7 | indexed_scale_segment 8 | ) 9 | 10 | from ...structs import SparseScaleInfo 11 | 12 | from equitorch.irreps import check_irreps, Irreps 13 | 14 | class SparseScale(Function): 15 | ''' 16 | Currently only support Square-matrix Transformation 17 | ''' 18 | @staticmethod 19 | def forward(ctx, input, info_fwd, info_bwd): 20 | 21 | ret = indexed_scale_segment( 22 | input, 23 | info_fwd.scale, 24 | info_fwd.index, 25 | info_fwd.seg_out, 26 | info_fwd.out_size, 27 | ) 28 | 29 | ctx.save_for_backward(input) 30 | ctx.infos = (info_fwd, info_bwd) 31 | return ret 32 | 33 | @staticmethod 34 | def backward(ctx, grad_output): 35 | grad = grad_output 36 | (input,) = ctx.saved_tensors 37 | info_fwd, info_bwd = 
def sincos(angle: Tensor, max_m: int, with_ones=True, component_normalize=False):
    r"""Expand an angle into its sin/cos harmonics for z-rotations.

    Functional form of :class:`~equitorch.nn.angular.SinCos`; see that module
    for more details.

    Args:
        angle (torch.Tensor): Input angles.
        max_m (int): Largest multiple of the angle to expand to.
        with_ones (bool, optional): Prepend the constant 1.0 component.
            Defaults to ``True``.
        component_normalize (bool, optional): Scale each sin/cos component by
            :math:`\sqrt{2}`. Defaults to ``False``.

    Returns:
        torch.Tensor: ``[1, sin(a), cos(a), ..., sin(max_m*a), cos(max_m*a)]``
        along a new trailing dimension (the leading 1 only if ``with_ones``).
    """
    if max_m == 0:
        # Only the scalar component; the z-rotation acts as identity.
        # NOTE(review): the constant 1.0 is returned here even when
        # with_ones=False — confirm this asymmetry is intended.
        return torch.ones_like(angle).unsqueeze(-1)

    multiples = torch.arange(1, max_m + 1, dtype=angle.dtype, device=angle.device)
    scaled_angles = angle.unsqueeze(-1) * multiples
    sin_part = torch.sin(scaled_angles)  # sin(|m| * angle)
    cos_part = torch.cos(scaled_angles)  # cos(|m| * angle)

    if component_normalize:
        sqrt2 = 2 ** 0.5
        sin_part = sin_part * sqrt2
        cos_part = cos_part * sqrt2

    # Interleave as [sin(1a), cos(1a), sin(2a), cos(2a), ...].
    interleaved = torch.stack([sin_part, cos_part], dim=-1).flatten(-2, -1)
    if not with_ones:
        return interleaved
    return torch.cat([torch.ones_like(angle).unsqueeze(-1), interleaved], dim=-1)
# Adapted from https://github.com/mir-group/nequip/blob/v0.6.2/nequip/nn/radial_basis.py
class BesselBasis(nn.Module):
    r"""Radial Bessel basis, as proposed in DimeNet: https://arxiv.org/abs/2003.03123

    Expands a distance :math:`r` into ``num_basis`` functions

    .. math::
        b_n(r) = \frac{2}{r_{\max}} \cdot \frac{\sin(n\pi r / r_{\max})}{r}.

    Args:
        r_max (float): Cutoff radius.
        num_basis (int, optional): Number of Bessel basis functions.
            Defaults to 8.
        trainable (bool, optional): Whether the :math:`n\pi` frequencies are
            learnable parameters. Defaults to ``True``.
    """
    r_max: float
    prefactor: float

    def __init__(self, r_max, num_basis=8, trainable=True):
        super().__init__()

        self.trainable = trainable
        self.num_basis = num_basis

        self.r_max = float(r_max)
        self.prefactor = 2.0 / self.r_max

        # Frequencies n * pi for n = 1 .. num_basis.
        frequencies = (
            torch.linspace(start=1.0, end=num_basis, steps=num_basis) * math.pi
        )
        if self.trainable:
            self.bessel_weights = nn.Parameter(frequencies)
        else:
            self.register_buffer("bessel_weights", frequencies)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Evaluate the Bessel basis at distances ``x``.

        Args:
            x (torch.Tensor): Input distances, shape ``(...)``.

        Returns:
            torch.Tensor: Basis values, shape ``(..., num_basis)``.

        NOTE(review): x == 0 yields a 0/0 NaN; callers presumably guarantee
        strictly positive distances — confirm.
        """
        distances = x.unsqueeze(-1)
        numerator = torch.sin(self.bessel_weights * distances / self.r_max)

        return self.prefactor * (numerator / distances)
class AnglesToMatrix(nn.Module):
    r"""Convert ZYZ Euler angles to 3x3 rotation matrices.

    For angles \(\alpha, \beta, \gamma\) this produces

    .. math::
        R(\alpha, \beta, \gamma) = R_z(\alpha) R_y(\beta) R_z(\gamma),

    i.e. a rotation by \(\gamma\) about z, then \(\beta\) about y, then
    \(\alpha\) about z. Stateless module wrapper around the functional
    :func:`~equitorch.nn.functional.rotations.angles_to_matrix`.
    """

    def __init__(self):
        super().__init__()

    def forward(
        self,
        alpha: Optional[torch.Tensor] = None,
        beta: Optional[torch.Tensor] = None,
        gamma: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        r"""Build rotation matrices from the given angles.

        Args:
            alpha: First rotation angle, about the z-axis (radians). Shape (...)
            beta: Second rotation angle, about the y-axis (radians). Shape (...)
            gamma: Third rotation angle, about the z-axis (radians). Shape (...)

        Returns:
            Rotation matrices of shape (..., 3, 3)
        """
        # Pure delegation; all of the math lives in the functional version.
        return angles_to_matrix(alpha=alpha, beta=beta, gamma=gamma)

    def extra_repr(self) -> str:
        # Stateless module: nothing configurable to report.
        return ""
class SinCos(nn.Module):
    r"""
    Module wrapper for the :func:`~equitorch.nn.functional.angular.sincos` function.

    Computes the sin/cos expansion of an angle \(a\):

    .. math::

        [1.0, \sin(a), \cos(a), \sin(2a), \cos(2a), \dots, \sin(\text{max_m} \cdot a), \cos(\text{max_m} \cdot a)]

    or, when ``component_normalize`` is ``True``,

    .. math::
        [1.0, \sqrt{2}\sin(a), \sqrt{2}\cos(a), \dots, \sqrt{2}\sin(\text{max_m} \cdot a), \sqrt{2}\cos(\text{max_m} \cdot a)]

    The leading 1.0 is excluded if `with_ones` is ``False``.

    Args:
        max_m (int): The maximum multiple of the angle \(a\) to compute \(\sin\) and \(\cos\) for.
        with_ones (bool, optional): Whether to include the leading 1.0 in the expansion. Defaults to ``True``.
        component_normalize (bool, optional): If ``True``, multiplies the \(\sin\) and \(\cos\) values by \(\sqrt{2}\)
            such that the expectation of the squared norm over \([0, 2\pi]\) is 1.
            Defaults to ``False``.

    Raises:
        ValueError: If ``max_m`` is not a non-negative integer.
    """
    def __init__(self, max_m: int, with_ones: bool = True, component_normalize: bool = False):
        super().__init__()
        if not isinstance(max_m, int) or max_m < 0:
            raise ValueError(f"max_m must be a non-negative integer, got {max_m}")
        self.max_m = max_m
        self.with_ones = with_ones
        self.component_normalize = component_normalize

    def forward(self, angle: Tensor) -> Tensor:
        r"""
        Args:
            angle (Tensor): Input angles.

        Returns:
            Tensor: The computed sin/cos tensor.
        """
        return sincos(angle, self.max_m, self.with_ones, component_normalize=self.component_normalize)

    def extra_repr(self) -> str:
        # BUG FIX: component_normalize was previously omitted from the repr,
        # so two differently-configured modules printed identically.
        return (f'max_m={self.max_m}, with_ones={self.with_ones}, '
                f'component_normalize={self.component_normalize}')
automodule:: equitorch.ops.indexed_product_op 27 | :members: 28 | :show-inheritance: 29 | :undoc-members: 30 | 31 | equitorch.ops.indexed\_product\_scale\_segment\_op module 32 | --------------------------------------------------------- 33 | 34 | .. automodule:: equitorch.ops.indexed_product_scale_segment_op 35 | :members: 36 | :show-inheritance: 37 | :undoc-members: 38 | 39 | equitorch.ops.indexed\_product\_segment\_op module 40 | -------------------------------------------------- 41 | 42 | .. automodule:: equitorch.ops.indexed_product_segment_op 43 | :members: 44 | :show-inheritance: 45 | :undoc-members: 46 | 47 | equitorch.ops.indexed\_scale\_segment\_op module 48 | ------------------------------------------------ 49 | 50 | .. automodule:: equitorch.ops.indexed_scale_segment_op 51 | :members: 52 | :show-inheritance: 53 | :undoc-members: 54 | 55 | equitorch.ops.kernel\_dense module 56 | ---------------------------------- 57 | 58 | .. automodule:: equitorch.ops.kernel_dense 59 | :members: 60 | :show-inheritance: 61 | :undoc-members: 62 | 63 | equitorch.ops.kernel\_utils module 64 | ---------------------------------- 65 | 66 | .. automodule:: equitorch.ops.kernel_utils 67 | :members: 68 | :show-inheritance: 69 | :undoc-members: 70 | 71 | equitorch.ops.product\_segment\_op module 72 | ----------------------------------------- 73 | 74 | .. automodule:: equitorch.ops.product_segment_op 75 | :members: 76 | :show-inheritance: 77 | :undoc-members: 78 | 79 | equitorch.ops.spherical\_harmonics module 80 | ----------------------------------------- 81 | 82 | .. automodule:: equitorch.ops.spherical_harmonics 83 | :members: 84 | :show-inheritance: 85 | :undoc-members: 86 | 87 | Module contents 88 | --------------- 89 | 90 | .. 
import sys


sys.path.append('..')

import torch
from torch_geometric.utils import segment

from test_utils import profile_funcs, compare_funcs, max_abs_diff
from irreps import check_irreps
# from so3_indicies import expand_left, tp_info
from ops.indexed_scale_segment_op import indexed_scale_segment
from utils._indices import expand_left
from utils._structs import tp_info

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

torch.set_default_dtype(torch.float32)
# torch.set_default_dtype(torch.float64)

torch.random.manual_seed(0)


def prepare_indexed_scale_segment(ir1, ir2, ir, device):
    """Build the (cg, index, M, seg) inputs for indexed_scale_segment from a TP info."""
    print('tp info start')
    # BUG FIX: the result used to be assigned to a local named `tp_info`,
    # shadowing the imported `tp_info` function — the call itself then raised
    # UnboundLocalError. Use a distinct local name instead.
    info, num_paths = tp_info(ir, ir1, ir2)
    info = info.to(device)
    print('tp info calculated')
    kM1M2 = info.kM1M2_MijM1M2
    # M = info.M_MijM1M2
    seg = info.M_seg_MijM1M2
    cg = info.cg_vals
    return cg, kM1M2, None, seg

def indexed_scale_segment_torch(input, scale, index, seg):
    """Reference implementation: index_select + scale, then segment-reduce."""
    input = input.index_select(-2, index) * scale.unsqueeze(-1)
    return segment(input, expand_left(input, seg, -2))

def init_indexed_scale_segment(ir1, ir2, ir, C, N, ones=False, device='cuda'):
    """Prepare random (or all-ones) inputs and the implementations under test."""
    cg, kM1M2, M, seg = prepare_indexed_scale_segment(ir1, ir2, ir, device)

    input = torch.randn(N, kM1M2.max()+1, C).to(device)

    print('shapes:')
    print(input.shape)

    if ones:
        input = torch.ones_like(input)

    funcs = [
        indexed_scale_segment, indexed_scale_segment_torch
    ]

    return input, cg, kM1M2, M, seg, funcs

def test_indexed_scale_segment(ir1, ir2, ir, C, N, ones=False):
    """Check the triton kernel against the torch reference, then profile both."""
    input, cg, kM1M2, M, seg, funcs = init_indexed_scale_segment(ir1, ir2, ir, C, N, ones)

    compare_funcs(funcs, max_abs_diff,
                  input, cg, kM1M2, seg)
    profile_funcs(funcs, ['triton', 'torch'], None, 5,
                  input, cg, kM1M2, seg)


if __name__ == '__main__':

    # irreps1 = irreps2 = irreps_out = check_irreps((0,0))
    # C = 1
    # N = 1

    # irreps1 = check_irreps((1,5))
    # irreps2 = check_irreps((2,7))
    # irreps_out = check_irreps((1,4))
    # C = 347
    # N = 51

    irreps1 = irreps2 = irreps_out = check_irreps((0,4))
    # C = 64
    C = 128
    # C = 512
    N = 256

    test_indexed_scale_segment(irreps1, irreps2, irreps_out, C, N)
# adapted from https://github.com/mir-group/nequip/blob/v0.6.2/nequip/nn/cutoffs.py
@torch.jit.script
def polynomial_cutoff(input: torch.Tensor, p: float = 6.0) -> torch.Tensor:
    r"""Polynomial cutoff function.

    Smoothly decays from 1 at :math:`u = 0` to 0 at :math:`u = 1`:

    .. math::
        f(u) = 1 - \frac{(p+1)(p+2)}{2}u^{p} + p(p+2)u^{p+1} - \frac{p(p+1)}{2}u^{p+2}

    This is the functional version of the :class:`~equitorch.nn.cutoffs.PolynomialCutoff` module.
    See :class:`~equitorch.nn.cutoffs.PolynomialCutoff` for more details.

    Args:
        input (torch.Tensor): Standardized distance tensor, :math:`u \in [0, 1]`.
        p (float, optional): Power parameter. Defaults to ``6.0``.

    Returns:
        torch.Tensor: Cutoff values.
    """
    coeff_a = (p + 1.0) * (p + 2.0) / 2.0
    coeff_b = p * (p + 2.0)
    coeff_c = p * (p + 1.0) / 2

    result = 1.0 - coeff_a * torch.pow(input, p)
    result = result + coeff_b * torch.pow(input, p + 1.0)
    result = result - coeff_c * torch.pow(input, p + 2.0)

    return result
72 | 73 | Args: 74 | input (torch.Tensor): Standardized distance tensor, :math:`u \in [0, 1]`. 75 | eps (float, optional): Small epsilon for numerical stability. Defaults to ``1e-12``. 76 | 77 | Returns: 78 | torch.Tensor: Cutoff values. 79 | """ 80 | return torch.exp(1-1/(1-input**2+eps)) -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.nn.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn package 2 | ==================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | equitorch.nn.functional 11 | 12 | Submodules 13 | ---------- 14 | 15 | equitorch.nn.activations module 16 | ------------------------------- 17 | 18 | .. automodule:: equitorch.nn.activations 19 | :members: 20 | :show-inheritance: 21 | :undoc-members: 22 | 23 | equitorch.nn.angular module 24 | --------------------------- 25 | 26 | .. automodule:: equitorch.nn.angular 27 | :members: 28 | :show-inheritance: 29 | :undoc-members: 30 | 31 | equitorch.nn.cutoffs module 32 | --------------------------- 33 | 34 | .. automodule:: equitorch.nn.cutoffs 35 | :members: 36 | :show-inheritance: 37 | :undoc-members: 38 | 39 | equitorch.nn.dropout module 40 | --------------------------- 41 | 42 | .. automodule:: equitorch.nn.dropout 43 | :members: 44 | :show-inheritance: 45 | :undoc-members: 46 | 47 | equitorch.nn.init module 48 | ------------------------ 49 | 50 | .. automodule:: equitorch.nn.init 51 | :members: 52 | :show-inheritance: 53 | :undoc-members: 54 | 55 | equitorch.nn.linears module 56 | --------------------------- 57 | 58 | .. automodule:: equitorch.nn.linears 59 | :members: 60 | :show-inheritance: 61 | :undoc-members: 62 | 63 | equitorch.nn.norm module 64 | ------------------------ 65 | 66 | .. 
automodule:: equitorch.nn.norm 67 | :members: 68 | :show-inheritance: 69 | :undoc-members: 70 | 71 | equitorch.nn.normalization module 72 | --------------------------------- 73 | 74 | .. automodule:: equitorch.nn.normalization 75 | :members: 76 | :show-inheritance: 77 | :undoc-members: 78 | 79 | equitorch.nn.others module 80 | -------------------------- 81 | 82 | .. automodule:: equitorch.nn.others 83 | :members: 84 | :show-inheritance: 85 | :undoc-members: 86 | 87 | equitorch.nn.radials module 88 | --------------------------- 89 | 90 | .. automodule:: equitorch.nn.radials 91 | :members: 92 | :show-inheritance: 93 | :undoc-members: 94 | 95 | equitorch.nn.rotations module 96 | ----------------------------- 97 | 98 | .. automodule:: equitorch.nn.rotations 99 | :members: 100 | :show-inheritance: 101 | :undoc-members: 102 | 103 | equitorch.nn.sphericals module 104 | ------------------------------ 105 | 106 | .. automodule:: equitorch.nn.sphericals 107 | :members: 108 | :show-inheritance: 109 | :undoc-members: 110 | 111 | equitorch.nn.tensor\_products module 112 | ------------------------------------ 113 | 114 | .. automodule:: equitorch.nn.tensor_products 115 | :members: 116 | :show-inheritance: 117 | :undoc-members: 118 | 119 | equitorch.nn.wigner\_d module 120 | ----------------------------- 121 | 122 | .. automodule:: equitorch.nn.wigner_d 123 | :members: 124 | :show-inheritance: 125 | :undoc-members: 126 | 127 | Module contents 128 | --------------- 129 | 130 | .. 
"""Thin segment-reduction wrappers around the batched sparse-dense gather ops.

Each function forwards to the corresponding ``indexed_*_scale_gather`` kernel
with only ``seg`` provided (no explicit indices or scale), exposing tuned
default block sizes for the segment-only case.

NOTE(review): ``mul_segment`` and ``outer_segment`` forward ``accumulated`` as
``out_accumulated=``, while ``inner_segment``, ``vecmat_segment`` and
``vecsca_segment`` forward it as ``accumulated=`` — confirm both keyword names
exist on the underlying kernels in ``batched_sparse_dense_op``.
"""
from .batched_sparse_dense_op import (
    indexed_mul_scale_gather,
    indexed_inner_scale_gather,
    indexed_outer_scale_gather,
    indexed_vecmat_scale_gather,
    indexed_vecsca_scale_gather
)



def mul_segment(
    input1,
    input2,
    seg,
    out=None,
    block_size_n=64,
    block_size_c=64,
    num_stages=2,
    accumulated=False
):
    r"""Segmented element-wise multiplication with gather (``seg`` must be provided, plus at least one index)."""

    return indexed_mul_scale_gather(
        input1, input2,
        seg=seg,
        out=out,
        block_size_n=block_size_n,
        block_size_c=block_size_c,
        num_stages=num_stages,
        out_accumulated=accumulated
    )

def outer_segment(
    input1,
    input2,
    seg,
    out=None,
    block_size_n=8,
    block_size_c1=32,
    block_size_c2=32,
    num_stages=1,
    accumulated=False
):
    r"""Segmented element-wise outer product with gather (``seg`` must be provided, plus at least one index)."""

    return indexed_outer_scale_gather(
        input1, input2,
        seg=seg,
        out=out,
        block_size_n=block_size_n,
        block_size_c1=block_size_c1,
        block_size_c2=block_size_c2,
        num_stages=num_stages,
        out_accumulated=accumulated
    )

def inner_segment(
    input1,
    input2,
    seg,
    out=None,
    block_size_n=32,
    block_size_c=32,
    num_stages=2,
    accumulated=False
):
    r"""Segmented element-wise inner product with gather (``seg`` must be provided, plus at least one index)."""

    return indexed_inner_scale_gather(
        input1, input2,
        seg=seg,
        out=out,
        block_size_n=block_size_n,
        block_size_c=block_size_c,
        num_stages=num_stages,
        accumulated=accumulated
    )

def vecmat_segment(
    input1,
    input2,
    seg,
    out=None,
    block_size_n=8,
    block_size_c_in=32,
    block_size_c_out=32,
    num_stages=1,
    accumulated=False
):
    r"""Segmented element-wise vector-matrix multiplication with gather (``seg`` must be provided, plus at least one index)."""

    return indexed_vecmat_scale_gather(
        input1, input2,
        seg=seg,
        out=out,
        block_size_n=block_size_n,
        block_size_c_in=block_size_c_in,
        block_size_c_out=block_size_c_out,
        num_stages=num_stages,
        accumulated=accumulated
    )

def vecsca_segment(
    input1,
    input2,
    seg,
    out=None,
    block_size_n=64,
    block_size_c=64,
    num_stages=2,
    accumulated=False
):
    r"""Segmented element-wise vector scaling with gather (``seg`` must be provided, plus at least one index)."""

    return indexed_vecsca_scale_gather(
        input1, input2,
        seg=seg,
        out=out,
        block_size_n=block_size_n,
        block_size_c=block_size_c,
        num_stages=num_stages,
        accumulated=accumulated
    )
.fa{width:.8em}ul.fas li .fa-large:before{vertical-align:baseline}.fa-book:before,.icon-book:before{content:"\f02d"}.fa-caret-down:before,.icon-caret-down:before{content:"\f0d7"}.fa-caret-up:before,.icon-caret-up:before{content:"\f0d8"}.fa-caret-left:before,.icon-caret-left:before{content:"\f0d9"}.fa-caret-right:before,.icon-caret-right:before{content:"\f0da"}.rst-versions{position:fixed;bottom:0;left:0;width:300px;color:#fcfcfc;background:#1f1d1d;font-family:Lato,proxima-nova,Helvetica Neue,Arial,sans-serif;z-index:400}.rst-versions a{color:#2980b9;text-decoration:none}.rst-versions .rst-badge-small{display:none}.rst-versions .rst-current-version{padding:12px;background-color:#272525;display:block;text-align:right;font-size:90%;cursor:pointer;color:#27ae60}.rst-versions .rst-current-version:after{clear:both;content:"";display:block}.rst-versions .rst-current-version .fa{color:#fcfcfc}.rst-versions .rst-current-version .fa-book,.rst-versions .rst-current-version .icon-book{float:left}.rst-versions .rst-current-version.rst-out-of-date{background-color:#e74c3c;color:#fff}.rst-versions .rst-current-version.rst-active-old-version{background-color:#f1c40f;color:#000}.rst-versions.shift-up{height:auto;max-height:100%;overflow-y:scroll}.rst-versions.shift-up .rst-other-versions{display:block}.rst-versions .rst-other-versions{font-size:90%;padding:12px;color:grey;display:none}.rst-versions .rst-other-versions hr{display:block;height:1px;border:0;margin:20px 0;padding:0;border-top:1px solid #413d3d}.rst-versions .rst-other-versions dd{display:inline-block;margin:0}.rst-versions .rst-other-versions dd a{display:inline-block;padding:6px;color:#fcfcfc}.rst-versions .rst-other-versions .rtd-current-item{font-weight:700}.rst-versions.rst-badge{width:auto;bottom:20px;right:20px;left:auto;border:none;max-width:300px;max-height:90%}.rst-versions.rst-badge .fa-book,.rst-versions.rst-badge .icon-book{float:none;line-height:30px}.rst-versions.rst-badge.shift-up 
.rst-current-version{text-align:right}.rst-versions.rst-badge.shift-up .rst-current-version .fa-book,.rst-versions.rst-badge.shift-up .rst-current-version .icon-book{float:left}.rst-versions.rst-badge>.rst-current-version{width:auto;height:30px;line-height:30px;padding:0 6px;display:block;text-align:center}@media screen and (max-width:768px){.rst-versions{width:85%;display:none}.rst-versions.shift{display:block}}#flyout-search-form{padding:6px} -------------------------------------------------------------------------------- /test/test_other_modules/test_irreps_split.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "87c3f1c4", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "data": { 11 | "text/plain": [ 12 | "" 13 | ] 14 | }, 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "output_type": "execute_result" 18 | } 19 | ], 20 | "source": [ 21 | "import sys\n", 22 | "\n", 23 | "\n", 24 | "sys.path.append('../..')\n", 25 | "\n", 26 | "import torch\n", 27 | "from irreps import Irreps, check_irreps\n", 28 | "\n", 29 | "from nn.others import SplitIrreps\n", 30 | "\n", 31 | "torch.random.manual_seed(0)\n" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 8, 37 | "id": "6f470a47", 38 | "metadata": {}, 39 | "outputs": [ 40 | { 41 | "name": "stdout", 42 | "output_type": "stream", 43 | "text": [ 44 | "torch.Size([5, 14, 10])\n", 45 | "torch.Size([5, 2, 10]) torch.Size([5, 4, 10]) torch.Size([5, 8, 10])\n" 46 | ] 47 | } 48 | ], 49 | "source": [ 50 | "irreps = Irreps(\"3x0e + 2x1o + 2\")\n", 51 | "split = SplitIrreps(irreps, [2,2,2])\n", 52 | "x = torch.randn(5, irreps.dim, 10)\n", 53 | "print(x.shape)\n", 54 | "sp = split(x)\n", 55 | "print(sp[0].shape, sp[1].shape, sp[2].shape)" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 9, 61 | "id": "2d110aab", 62 | "metadata": {}, 63 | "outputs": [ 64 | { 65 | 
"name": "stdout", 66 | "output_type": "stream", 67 | "text": [ 68 | "torch.Size([5, 14, 10])\n", 69 | "torch.Size([5, 6, 10]) torch.Size([5, 8, 10])\n" 70 | ] 71 | } 72 | ], 73 | "source": [ 74 | "split = SplitIrreps(irreps, [4,-1])\n", 75 | "x = torch.randn(5, irreps.dim, 10)\n", 76 | "print(x.shape)\n", 77 | "sp = split(x)\n", 78 | "print(sp[0].shape, sp[1].shape)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 12, 84 | "id": "6d4bc1e3", 85 | "metadata": {}, 86 | "outputs": [ 87 | { 88 | "name": "stdout", 89 | "output_type": "stream", 90 | "text": [ 91 | "torch.Size([5, 14, 10])\n", 92 | "torch.Size([5, 1, 10]) torch.Size([5, 5, 10]) torch.Size([5, 8, 10])\n" 93 | ] 94 | } 95 | ], 96 | "source": [ 97 | "split = SplitIrreps(irreps, [1,...,2])\n", 98 | "x = torch.randn(5, irreps.dim, 10)\n", 99 | "print(x.shape)\n", 100 | "sp = split(x)\n", 101 | "print(sp[0].shape, sp[1].shape, sp[2].shape)" 102 | ] 103 | } 104 | ], 105 | "metadata": { 106 | "kernelspec": { 107 | "display_name": "base", 108 | "language": "python", 109 | "name": "python3" 110 | }, 111 | "language_info": { 112 | "codemirror_mode": { 113 | "name": "ipython", 114 | "version": 3 115 | }, 116 | "file_extension": ".py", 117 | "mimetype": "text/x-python", 118 | "name": "python", 119 | "nbconvert_exporter": "python", 120 | "pygments_lexer": "ipython3", 121 | "version": "3.12.8" 122 | } 123 | }, 124 | "nbformat": 4, 125 | "nbformat_minor": 5 126 | } 127 | -------------------------------------------------------------------------------- /docs/build/html/_sources/modules/equitorch.nn.functional.rst.txt: -------------------------------------------------------------------------------- 1 | equitorch.nn.functional package 2 | =============================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | equitorch.nn.functional.activations module 8 | ------------------------------------------ 9 | 10 | .. 
automodule:: equitorch.nn.functional.activations 11 | :members: 12 | :show-inheritance: 13 | :undoc-members: 14 | 15 | equitorch.nn.functional.angular module 16 | -------------------------------------- 17 | 18 | .. automodule:: equitorch.nn.functional.angular 19 | :members: 20 | :show-inheritance: 21 | :undoc-members: 22 | 23 | equitorch.nn.functional.cutoffs module 24 | -------------------------------------- 25 | 26 | .. automodule:: equitorch.nn.functional.cutoffs 27 | :members: 28 | :show-inheritance: 29 | :undoc-members: 30 | 31 | equitorch.nn.functional.dropout module 32 | -------------------------------------- 33 | 34 | .. automodule:: equitorch.nn.functional.dropout 35 | :members: 36 | :show-inheritance: 37 | :undoc-members: 38 | 39 | equitorch.nn.functional.linears module 40 | -------------------------------------- 41 | 42 | .. automodule:: equitorch.nn.functional.linears 43 | :members: 44 | :show-inheritance: 45 | :undoc-members: 46 | 47 | equitorch.nn.functional.norm module 48 | ----------------------------------- 49 | 50 | .. automodule:: equitorch.nn.functional.norm 51 | :members: 52 | :show-inheritance: 53 | :undoc-members: 54 | 55 | equitorch.nn.functional.normalization module 56 | -------------------------------------------- 57 | 58 | .. automodule:: equitorch.nn.functional.normalization 59 | :members: 60 | :show-inheritance: 61 | :undoc-members: 62 | 63 | equitorch.nn.functional.rotations module 64 | ---------------------------------------- 65 | 66 | .. automodule:: equitorch.nn.functional.rotations 67 | :members: 68 | :show-inheritance: 69 | :undoc-members: 70 | 71 | equitorch.nn.functional.sparse\_product module 72 | ---------------------------------------------- 73 | 74 | .. automodule:: equitorch.nn.functional.sparse_product 75 | :members: 76 | :show-inheritance: 77 | :undoc-members: 78 | 79 | equitorch.nn.functional.sparse\_scale module 80 | -------------------------------------------- 81 | 82 | .. 
automodule:: equitorch.nn.functional.sparse_scale 83 | :members: 84 | :show-inheritance: 85 | :undoc-members: 86 | 87 | equitorch.nn.functional.sphericals module 88 | ----------------------------------------- 89 | 90 | .. automodule:: equitorch.nn.functional.sphericals 91 | :members: 92 | :show-inheritance: 93 | :undoc-members: 94 | 95 | equitorch.nn.functional.tensor\_products module 96 | ----------------------------------------------- 97 | 98 | .. automodule:: equitorch.nn.functional.tensor_products 99 | :members: 100 | :show-inheritance: 101 | :undoc-members: 102 | 103 | equitorch.nn.functional.wigner\_d module 104 | ---------------------------------------- 105 | 106 | .. automodule:: equitorch.nn.functional.wigner_d 107 | :members: 108 | :show-inheritance: 109 | :undoc-members: 110 | 111 | Module contents 112 | --------------- 113 | 114 | .. automodule:: equitorch.nn.functional 115 | :members: 116 | :show-inheritance: 117 | :undoc-members: 118 | -------------------------------------------------------------------------------- /equitorch/ops/indexed_product_op.py: -------------------------------------------------------------------------------- 1 | from .batched_sparse_dense_op import ( 2 | indexed_mul_scale_gather, 3 | indexed_inner_scale_gather, 4 | indexed_outer_scale_gather, 5 | indexed_vecmat_scale_gather, 6 | indexed_vecsca_scale_gather 7 | ) 8 | 9 | 10 | 11 | def indexed_mul( 12 | input1, input2, index1=None, index2=None, out = None, 13 | block_size_n = 64, 14 | block_size_c = 64, 15 | ): 16 | assert index1 is not None or index2 is not None 17 | 18 | return indexed_mul_scale_gather( 19 | input1, input2, 20 | index1=index1, index2=index2, 21 | out=out, 22 | block_size_n=block_size_n, 23 | block_size_c=block_size_c, 24 | num_stages=0 25 | ) 26 | 27 | def indexed_outer( 28 | input1, 29 | input2, 30 | index1=None, 31 | index2=None, 32 | out=None, 33 | block_size_n=16, # 外积需要更高并行度 34 | block_size_c1=64, 35 | block_size_c2=64, 36 | ): 37 | r"""带索引的批量外积运算""" 38 | 
def indexed_inner(
    input1,
    input2,
    index1=None,
    index2=None,
    out=None,
    block_size_n=32,  # inner product is compute-dense; use a larger N block
    block_size_c=32,
    loop_unroll_factor=4,  # explicit loop-unroll factor forwarded to the kernel
):
    r"""Indexed batched inner product.

    Gathers operands through the optional ``index1``/``index2`` index tensors
    and reduces over the shared channel dimension via
    :func:`indexed_inner_scale_gather`.

    Args:
        input1, input2: Dense operand tensors.
        index1, index2: Optional gather indices; at least one must be given.
        out: Optional preallocated output tensor.
        block_size_n, block_size_c: Kernel launch tile sizes.
        loop_unroll_factor: Unroll factor forwarded to the kernel.
    """
    assert index1 is not None or index2 is not None

    return indexed_inner_scale_gather(
        input1, input2,
        index1=index1,
        index2=index2,
        out=out,
        block_size_n=block_size_n,
        block_size_c=block_size_c,  # fixed channel tile size
        num_stages=2,
        loop_unroll_factor=loop_unroll_factor
    )

def indexed_vecmat(
    input1,  # [N, M1, C_in]
    input2,  # [N, M2, C_in, C_out]
    index1=None,
    index2=None,
    out=None,
    block_size_n=16,  # matrix multiply benefits from higher parallelism
    block_size_c_out=32,  # output-channel tile size
    block_size_c_in=32,  # input-channel tile size
):
    r"""Indexed vector-matrix multiplication.

    Multiplies gathered vectors from ``input1`` with gathered matrices from
    ``input2`` via :func:`indexed_vecmat_scale_gather`.

    Args:
        input1: Vector operand of shape ``[N, M1, C_in]``.
        input2: Matrix operand of shape ``[N, M2, C_in, C_out]``.
        index1, index2: Optional gather indices; at least one must be given.
        out: Optional preallocated output tensor.
        block_size_n, block_size_c_out, block_size_c_in: Kernel tile sizes.
    """
    assert index1 is not None or index2 is not None

    return indexed_vecmat_scale_gather(
        input1, input2,
        index1=index1,
        index2=index2,
        out=out,
        block_size_n=block_size_n,
        block_size_c_out=block_size_c_out,
        block_size_c_in=block_size_c_in,
        num_stages=1,  # fewer pipeline stages to reduce register pressure
        loop_unroll_factor=4
    )
缩放计算简单,可用更多流水线 120 | ) -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Project information ----------------------------------------------------- 7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 8 | 9 | import sys, os 10 | sys.path.insert(0,'../') 11 | sys.path.insert(0,os.path.abspath('../')) 12 | sys.path.insert(0,os.path.abspath('../..')) 13 | 14 | import equitorch 15 | 16 | project = 'equitorch' 17 | copyright = '2025, Tong Wang' 18 | author = 'Tong Wang' 19 | 20 | # -- General configuration --------------------------------------------------- 21 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 22 | 23 | extensions = [ 24 | 'sphinx.ext.autosummary', 25 | 'sphinx.ext.autodoc', 26 | 'sphinx.ext.intersphinx', 27 | 'sphinx.ext.mathjax', 28 | 'sphinx.ext.napoleon', 29 | ] 30 | 31 | autoapi_options = [ 32 | "members", 33 | # "undoc-members", 34 | "show-inheritance", 35 | "show-module-summary", 36 | "imported-members", 37 | ] 38 | 39 | # autodoc_default_options = { 40 | # 'member-order': 'bysource', 41 | # 'special-members': '__init__', 42 | # 'undoc-members': False, 43 | # 'private-members': False, 44 | # 'inherited-members': False, 45 | # 'show-inheritance': False, 46 | 47 | # } 48 | 49 | autoapi_member_order = 'bysource' 50 | 51 | napoleon_google_docstring = True 52 | napoleon_numpy_docstring = True 53 | napoleon_include_init_with_doc = True 54 | napoleon_include_private_with_doc = False 55 | napoleon_include_special_with_doc = False 56 | napoleon_use_admonition_for_examples = False 57 | napoleon_use_admonition_for_notes = False 
napoleon_use_admonition_for_references = False
napoleon_use_ivar = False
napoleon_use_param = True
napoleon_use_rtype = True
napoleon_preprocess_types = False
napoleon_attr_annotations = True

# Defaults applied to every autodoc directive in the project.
autodoc_default_options = {
    'member-order': 'bysource',
    'undoc-members': False,
    # Fixed: autodoc spells this option with a hyphen ('no-index');
    # the previous underscore key 'no_index' is not a recognized option.
    'no-index': True,
    'show-inheritance': False,
    'inherited-members': False,
    'exclude-members': '__init__, extra_repr, forward',
    # 'imported-members': False,
    # 'members': False,
}

# templates borrowed from https://github.com/pyg-team/pytorch_geometric/blob/master/docs/source/_templates/
templates_path = ['_templates']
exclude_patterns = []


# autosummary_generate = True
# autosummary_imported_members = True
autosummary_ignore_module_all = False
# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output

html_theme = 'sphinx_rtd_theme'
html_static_path = ['_static']
html_logo = '_static/logo_.png'
html_css_files = [
    'css/custom.css'
]
97 | """ 98 | # Make sure we're outputting HTML 99 | # if app.builder.format != 'html': 100 | # return 101 | src = source[0] 102 | rst_context = {'equitorch': equitorch} 103 | rendered = app.builder.templates.render_string( 104 | src, rst_context | app.config.html_context 105 | ) 106 | source[0] = rendered 107 | 108 | def setup(app): 109 | app.connect("source-read", rstjinja) -------------------------------------------------------------------------------- /equitorch/nn/dropout.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import Tensor 4 | import torch.nn.functional as F 5 | 6 | from ..utils._structs import irreps_info 7 | 8 | from ..irreps import Irreps 9 | from .functional.dropout import irrep_wise_dropout 10 | 11 | class Dropout(nn.Module): 12 | r""" 13 | Apply dropout to equivariant features. 14 | 15 | Can operate irrep-wise or on the entire feature vector (channel-wise). 16 | 17 | Args: 18 | p (float, optional): Probability of an element to be zeroed. 19 | Default: 0.5 20 | irreps (Irreps, optional): Irreps of the input tensor. 21 | Required if `irrep_wise` is True. 22 | Default: None 23 | irrep_wise (bool, optional): If True, applies dropout independently 24 | for each (irrep_instance, channel). 25 | If False, applies standard 1D dropout 26 | treating (irreps_dim, channels) as a 27 | single feature dimension for dropout. 28 | Default: True 29 | work_on_eval (bool, optional): If True, dropout is applied even during 30 | evaluation. 
    def __init__(self, p: float = 0.5,
                 irreps: Irreps = None,
                 irrep_wise: bool = True,
                 work_on_eval: bool = False):
        r"""Validate arguments and precompute the irreps lookup structure.

        Raises:
            ValueError: If ``p`` is outside ``[0, 1]``, or if ``irrep_wise``
                is True but no ``irreps`` were given.
        """
        super().__init__()
        if p < 0.0 or p > 1.0:
            raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
        if irrep_wise and irreps is None:
            raise ValueError("irreps must be provided if irrep_wise is True")

        self.p = p
        self.irrep_wise = irrep_wise
        self.work_on_eval = work_on_eval

        if irrep_wise:
            # We need IrrepsInfo for irrep_wise_dropout
            # Construct it once and store it if irreps are provided
            self.irreps_info = irreps_info(irreps) if irreps is not None else None
        else:
            self.irreps_info = None  # Not needed for non-irrep-wise

    def forward(self, input: Tensor) -> Tensor:
        r"""
        Args:
            input (Tensor): Input tensor of shape (N, irreps_dim, C)
        Returns:
            Tensor: Output tensor with dropout applied.
        """
        assert input.ndim >= 2, "Input tensor must have at least 2 dimensions (irreps_dim, channels)"

        # Identity unless we are training or explicitly forced on at eval time.
        if not self.work_on_eval and not self.training:
            return input

        if self.p == 0.0:  # No dropout if p is 0
            return input

        if self.irrep_wise:
            # Delegate to the functional irrep-wise implementation; the second
            # flag reproduces the "apply even in eval" behavior.
            return irrep_wise_dropout(input, self.p, self.training or self.work_on_eval, self.irreps_info)
        else:
            # Move channels next to the batch dim so dropout1d zeroes a whole
            # (irreps_dim,)-long feature row per dropped channel.
            x = input.transpose(-1, -2)
            x_dropout = F.dropout1d(x, self.p, self.training or self.work_on_eval)
            output = x_dropout.transpose(-1, -2)

            return output.contiguous()  # Ensure contiguous after transpose

    def extra_repr(self) -> str:
        # Shown in the module's repr(), mirroring the constructor arguments.
        return f'p={self.p}, irrep_wise={self.irrep_wise}, work_on_eval={self.work_on_eval}'

    def _apply(self, *args, **kwargs):
        # nn.Module._apply migrates registered parameters/buffers only;
        # irreps_info is a plain struct, so migrate it manually here.
        # NOTE(review): this assumes super()._apply returns self (true for
        # nn.Module), so reading self.irreps_info still yields the
        # pre-migration struct -- confirm if subclassing changes this.
        d = super()._apply(*args, **kwargs)
        if d.irreps_info is not None:
            d.irreps_info = self.irreps_info._apply(*args, **kwargs)
        return d
def _init_test_case(irreps, channels, batch_size, device, affine, scaled):
    r"""Build a ``LayerRMSNorm`` layer and a matching random input for gradcheck.

    Args:
        irreps: Irreps specification string (e.g. ``"1x0e + 1x1e"``).
        channels: Channel multiplicity of the input.
        batch_size: Leading batch dimension of the input.
        device: Device to place the layer and input on.
        affine: Whether the layer carries learnable affine parameters.
        scaled: Whether the layer uses its scaled variant.

    Returns:
        Tuple ``(layer, x)`` where ``x`` has shape
        ``(batch_size, Irreps(irreps).dim, channels)`` in float64 with
        ``requires_grad=True``.
    """
    layer = LayerRMSNorm(
        irreps=irreps,
        channels=channels,
        # NOTE(review): EPS=1e-9 is reused below as gradcheck's
        # finite-difference step, far below the 1e-6 default -- confirm intended.
        eps=EPS,
        affine=affine,
        scaled=scaled
    ).to(device)

    # Create input with requires_grad=True
    x = torch.randn(batch_size, Irreps(irreps).dim, channels,
                    device=device, dtype=torch.float64, requires_grad=True)

    return layer, x
# Test cases with different irreps and parameter combinations
def test_norm_case(irreps, channels, batch_size, affine, scaled):
    r"""Run first- and second-order gradient checks on ``LayerRMSNorm``.

    Builds a layer/input pair via :func:`_init_test_case`, executes
    :func:`_run_gradcheck`, and asserts that both ``gradcheck`` and
    ``gradgradcheck`` succeed for this configuration.
    """
    print(f"\nTesting LayerRMSNorm with irreps={irreps}, channels={channels}, "
          f"batch_size={batch_size}, affine={affine}, scaled={scaled}")

    layer, x = _init_test_case(irreps, channels, batch_size, device, affine, scaled)
    gradcheck_success, gradgradcheck_success = _run_gradcheck(layer, x)

    print(f"LayerRMSNorm - gradcheck: {gradcheck_success}, gradgradcheck: {gradgradcheck_success}")
    assert gradcheck_success and gradgradcheck_success
float, training: bool, irreps_info: IrrepsInfo) -> Tensor: 9 | r""" 10 | Apply dropout irrep-wise to an input tensor. 11 | 12 | This is the functional version of the irrep-wise mode of the 13 | :class:`~equitorch.nn.dropout.Dropout` module. 14 | See :class:`~equitorch.nn.dropout.Dropout` for more details on the dropout mechanism 15 | when ``irrep_wise=True``. 16 | 17 | Args: 18 | input (torch.Tensor): Input tensor of shape ``(..., irreps_dim, channels)``. 19 | p (float): Probability of an element to be zeroed. 20 | training (bool): Apply dropout if ``True``. 21 | irreps_info (IrrepsInfo): Contains irreps structure information, such as 22 | ``irreps_info.irrep_index`` to map components of ``input`` to their 23 | respective irrep instances, and ``irreps_info.num_irreps``. 24 | Must not be ``None``. 25 | 26 | Returns: 27 | torch.Tensor: Output tensor with dropout applied, of the same shape as ``input``. 28 | """ 29 | if p < 0.0 or p > 1.0: 30 | raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}") 31 | if p == 0.0 or not training: 32 | return input 33 | 34 | if irreps_info is None: 35 | raise ValueError("irreps_info cannot be None for irrep_wise_dropout") 36 | 37 | num_irreps = irreps_info.num_irreps # num_irrep_instances 38 | num_channels = input.shape[-1] 39 | 40 | # Create a mask for each (irrep_instance, channel) 41 | # The mask should have shape (num_irreps, num_channels) 42 | # This mask will be applied to input elements input[n, M, c] 43 | # where M corresponds to an irrep instance i, and we use mask[i, c] 44 | 45 | # Bernoulli distribution for the mask 46 | # The mask tensor will have dimensions corresponding to (num_irreps, num_channels) 47 | # It needs to be broadcastable or correctly indexed by sparse_mul 48 | # sparse_mul with index2=irreps_info.irrep_index will map the irreps_dim of input 49 | # to the first dimension of the mask (num_irreps). 50 | # The channel dimension will align directly. 
51 | mask_shape = (input.shape[0], num_irreps, num_channels) 52 | mask = torch.bernoulli(torch.full(mask_shape, 1.0 - p, dtype=input.dtype, device=input.device)) 53 | 54 | # Scale the mask by 1 / (1 - p) 55 | mask = mask.div_(1.0 - p).nan_to_num_(nan=0) 56 | 57 | # Define SparseProductInfo similar to Gating or BatchRMSNorm 58 | # input1 is 'input', input2 is 'mask' 59 | # index2 maps input's irreps_dim to mask's first dim (num_irreps) 60 | info_fwd = SparseProductInfo(index2=irreps_info.irrep_index) 61 | # For backward pass of sparse_mul(A, B): 62 | # grad_A = sparse_mul(B, grad_output, info_bwd1_for_A, info_bwd2_for_A, info_fwd_for_A) 63 | # grad_B = sparse_mul(grad_output, A, info_bwd1_for_B, info_bwd2_for_B, info_fwd_for_B) 64 | # Here, A=input, B=mask. 65 | # info_bwd1 (for grad_input) needs to configure sparse_mul(mask, grad_output) 66 | # - mask is input1, grad_output is input2 67 | # - index1 for mask should be irreps_info.irrep_index 68 | info_bwd1 = SparseProductInfo(index1=irreps_info.irrep_index) 69 | # info_bwd2 (for grad_mask, though mask is not learnable) needs to configure sparse_mul(grad_output, input) 70 | # - grad_output is input1, input is input2 71 | # - seg_out for grad_output should be irreps_info.irrep_seg to sum contributions for each (irrep_instance, channel) 72 | info_bwd2 = SparseProductInfo(seg_out=irreps_info.irrep_seg) 73 | 74 | return sparse_mul(input, mask, info_fwd, info_bwd1, info_bwd2) 75 | -------------------------------------------------------------------------------- /prompts/README.md: -------------------------------------------------------------------------------- 1 | # Prompts 2 | 3 | This folder provides some AI-generated backgrounds for related operations. 4 | The backgrounds are not carefully revised and may contain some errors, but it may provide a starting point for LLMs to understand basic ideas. 5 | 6 | 7 | ## File Summaries: 8 | 9 | ### 1. 
`irreps_introduction.md` 10 | * **Keywords:** Irreps, O(3), SO(3), tensor shape, `Irrep` class, `Irreps` class, multiplicity, channels, e3nn comparison, Schur's Lemma, Clebsch-Gordan coefficients. 11 | * **Introduction:** This document details the conventions for representing irreducible representations (irreps) of O(3)/SO(3) and the standard tensor shapes within the `equitorch` library. It focuses on the `Irrep` and `Irreps` classes, their attributes, initialization, methods (like tensor product), and the distinction between multiplicity and channels. It also covers tensor shape conventions, compares them with `e3nn`, discusses irrep interactions based on Schur's Lemma, and provides examples of module input/output shapes. 12 | 13 | ### 2. `equitorch_others.md` 14 | * **Keywords:** `SplitIrreps`, `Separable` module, equivariant tensors, `torch.split`, sub-modules, tensor manipulation. 15 | * **Introduction:** This document explains utility modules in `equitorch/nn/others.py`. It covers `SplitIrreps` for dividing tensors based on irrep dimensions and `Separable` for applying different transformations to these segments. It includes parameters, functionality, and usage examples for both modules. 16 | 17 | ### 3. `sparse_autograd_functions_documentation.md` 18 | 19 | * **Keywords:** Sparse Tensor Operations, Autograd Functions, SparseProductInfo, SparseScaleInfo, Segment Aggregation, Automatic Differentiation, Triton Kernels, Tensor Products, Sparse Indexing, Forward Pass, Backward Pass, Chain Rule. 20 | * **Introduction:** This document provides a mathematical and operational overview of `torch.autograd.Function` subclasses and `...Info` configuration structures in `equitorch`. 
It covers sparse tensor products and sparse scale/segment operations, detailing their mathematical formulations, configuration objects (`SparseProductInfo`, `SparseScaleInfo`), and the mechanics of their forward and backward passes for automatic differentiation, without delving into low-level Triton kernel implementations. 21 | 22 | ### 4. `product_kernels_documentation.md` 23 | 24 | * **Keywords:** Triton Kernels, Sparse Operations, Indexed Operations, Batched Operations, Segment Gathering, Kernel Utilities, Dense Product Kernels, `SparseProductInfo`, Autograd Wrappers, Tensor Shape Conventions, Forward/Backward Pass Relationships. 25 | * **Introduction:** This document details the Triton kernels and Python wrapper functions in `equitorch` for various product operations (e.g., mul, outer, inner, vecmat). It covers low-level kernel utilities (`kernel_utils.py`), core dense product kernels (`kernel_dense.py`), and batched sparse/indexed operations with scaling and segment gathering (`batched_sparse_dense_op.py`). It also explains the `SparseProductInfo` structure, autograd wrappers for differentiability (`sparse_product.py`), tensor shape conventions, and the mathematical relationships between forward and backward passes for these operations. 26 | 27 | ### 5. `indexed_scale_segment_documentation.md` 28 | 29 | * **Keywords:** Triton Kernel, Indexed Operation, Scaled Operation, Segment Reduction, Sparse Matrix Multiplication, Autograd Wrapper, `SparseScaleInfo`, Tensor Shapes. 30 | * **Introduction:** This document describes the Triton kernel (`indexed_scale_segment_kernel`) and Python wrapper (`indexed_scale_segment`) in `equitorch/ops/indexed_scale_segment_op.py` for performing an indexed, scaled segment reduction. 
def _init_test_case(irreps, channels, batch_size, device, affine, scaled, training):
    r"""Build a ``BatchRMSNorm`` layer and a matching random input for gradcheck.

    Args:
        irreps: Irreps specification string (e.g. ``"1x0e + 1x1e"``).
        channels: Channel multiplicity of the input.
        batch_size: Leading batch dimension of the input.
        device: Device to place the layer and input on.
        affine: Whether the layer carries learnable affine parameters.
        scaled: Whether the layer uses its scaled variant.
        training: Mode to put the layer in (affects running statistics).

    Returns:
        Tuple ``(layer, x)`` where ``x`` has shape
        ``(batch_size, Irreps(irreps).dim, channels)`` in float64 with
        ``requires_grad=True``.
    """
    layer = BatchRMSNorm(
        irreps=irreps,
        channels=channels,
        # NOTE(review): EPS=1e-9 is reused below as gradcheck's
        # finite-difference step, far below the 1e-6 default -- confirm intended.
        eps=EPS,
        momentum=0.1,
        affine=affine,
        scaled=scaled
    ).to(device)

    layer.train(training)  # Set training mode

    # Create input with requires_grad=True
    x = torch.randn(batch_size, Irreps(irreps).dim, channels,
                    device=device, dtype=torch.float64, requires_grad=True)

    return layer, x
# Test cases with different irreps and parameter combinations
def test_norm_case(irreps, channels, batch_size, affine, scaled, training):
    r"""Run first- and second-order gradient checks on ``BatchRMSNorm``.

    Builds a layer/input pair via :func:`_init_test_case` (in the requested
    training/eval mode), executes :func:`_run_gradcheck`, and asserts that
    both ``gradcheck`` and ``gradgradcheck`` succeed.
    """
    print(f"\nTesting BatchRMSNorm with irreps={irreps}, channels={channels}, "
          f"batch_size={batch_size}, affine={affine}, scaled={scaled}, training={training}")

    layer, x = _init_test_case(irreps, channels, batch_size, device, affine, scaled, training)
    gradcheck_success, gradgradcheck_success = _run_gradcheck(layer, x)

    print(f"BatchRMSNorm - gradcheck: {gradcheck_success}, gradgradcheck: {gradgradcheck_success}")
    assert gradcheck_success and gradgradcheck_success
def rand_spherical_xyz(shape: Tuple[int, ...], device=None, dtype=None) -> torch.Tensor:
    """Sample points uniformly distributed on the unit sphere.

    A standard-normal draw in R^3 is rotationally invariant, so normalizing
    it yields a uniform direction.

    Args:
        shape: Batch dimensions of the sample (e.g. ``(10,)`` or ``(5, 5)``).
        device: Target device for the result.
        dtype: Target dtype for the result.

    Returns:
        Tensor of shape ``(*shape, 3)``; every vector has unit Euclidean norm.
    """
    directions = torch.randn(*shape, 3, device=device, dtype=dtype)
    norms = directions.norm(dim=-1, keepdim=True)
    return directions / norms

def rand_spherical_angles(shape: Tuple[int, ...], device=None, dtype=None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sample spherical angles corresponding to a uniform sphere distribution.

    Args:
        shape: Batch dimensions of the sample.
        device: Target device for the results.
        dtype: Target dtype for the results.

    Returns:
        Tuple ``(theta, phi)`` where ``theta`` lies in ``[0, pi)`` (polar
        angle from +z) and ``phi`` in ``[0, 2*pi)`` (azimuth from +x).
    """
    # Inverse-transform sampling: cos(theta) uniform on [-1, 1).
    u = torch.rand(shape, device=device, dtype=dtype)
    theta = torch.acos(2 * u - 1)
    # Azimuth is plainly uniform.
    phi = torch.rand(shape, device=device, dtype=dtype) * (2 * math.pi)
    return theta, phi
41 | 42 | Args: 43 | shape: Tuple defining the batch dimensions 44 | device: Torch device for the output tensors 45 | dtype: Torch dtype for the output tensors 46 | 47 | Returns: 48 | Tuple of (alpha, beta, gamma) where: 49 | - alpha is in [0, 2π) (first rotation about z-axis) 50 | - beta is in [0, π) (rotation about y-axis) 51 | - gamma is in [0, 2π) (second rotation about z-axis) 52 | """ 53 | beta = torch.acos(2 * torch.rand(shape, device=device, dtype=dtype) - 1) 54 | alpha = 2 * math.pi * torch.rand(shape, device=device, dtype=dtype) 55 | gamma = 2 * math.pi * torch.rand(shape, device=device, dtype=dtype) 56 | return alpha, beta, gamma 57 | 58 | def rand_rotation_matrices(shape: Tuple[int, ...], device=None, dtype=None) -> torch.Tensor: 59 | """Generate random rotation matrices using Rodrigues' rotation formula. 60 | 61 | Args: 62 | shape: Tuple defining the batch dimensions 63 | device: Torch device for the output tensor 64 | dtype: Torch dtype for the output tensor 65 | 66 | Returns: 67 | Tensor of shape (*shape, 3, 3) containing valid rotation matrices 68 | """ 69 | # Generate random rotation axis (unit vector) and angle 70 | axis = rand_spherical_xyz(shape, device, dtype) 71 | angle = 2 * math.pi * torch.rand(shape, device=device, dtype=dtype) 72 | 73 | # Rodrigues' rotation formula components 74 | cos = torch.cos(angle) 75 | sin = torch.sin(angle) 76 | one_minus_cos = 1 - cos 77 | 78 | # Cross product matrix [a]× 79 | a_x = torch.zeros(*shape, 3, 3, device=device, dtype=dtype) 80 | a_x[..., 0, 1] = -axis[..., 2] 81 | a_x[..., 0, 2] = axis[..., 1] 82 | a_x[..., 1, 0] = axis[..., 2] 83 | a_x[..., 1, 2] = -axis[..., 0] 84 | a_x[..., 2, 0] = -axis[..., 1] 85 | a_x[..., 2, 1] = axis[..., 0] 86 | 87 | # Outer product a ⊗ a 88 | a_outer = axis.unsqueeze(-1) * axis.unsqueeze(-2) 89 | 90 | # Rodrigues' formula: R = cosθ I + sinθ [a]× + (1-cosθ) a ⊗ a 91 | eye = torch.eye(3, device=device, dtype=dtype).expand(*shape, 3, 3) 92 | return (cos[..., None, None] * eye + 
93 | sin[..., None, None] * a_x + 94 | one_minus_cos[..., None, None] * a_outer) 95 | -------------------------------------------------------------------------------- /equitorch/nn/functional/__init__.py: -------------------------------------------------------------------------------- 1 | """Provides functional implementations of equivariant neural network operations.""" 2 | from .linears import ( 3 | # TensorProductU1UDummy, 4 | tensor_product_u1u, 5 | so3_linear_uu, 6 | # TensorProduct1UUDummy, 7 | tensor_product_1uu, 8 | # TensorProductUU1Dummy, 9 | tensor_product_uu1, 10 | # TensorProductU1VDummy, 11 | tensor_product_u1v, 12 | so3_linear_uv, 13 | # TensorProduct1VUDummy, 14 | tensor_product_1vu, 15 | # TensorProductVU1Dummy, 16 | tensor_product_vu1, 17 | # IrrepWiseLinear, 18 | irrep_wise_linear, 19 | # IrrepsLinear, 20 | irreps_linear, 21 | so2_linear_uu, 22 | so2_linear_uv 23 | ) 24 | from .sphericals import ( 25 | spherical_harmonics, 26 | xyz_to_spherical, 27 | spherical_to_xyz, 28 | xyz_to_sincos 29 | ) 30 | from .normalization import ( 31 | batch_rms_norm, 32 | layer_rms_norm 33 | ) 34 | from .sparse_scale import ( 35 | # SparseScale, 36 | sparse_scale 37 | ) 38 | from .tensor_products import ( 39 | # TensorProductUUUDummy, 40 | tensor_product_uuu, 41 | # TensorProductUVWDummy, 42 | tensor_product_uvw, 43 | # TensorDotUU, 44 | tensor_dot_uu, 45 | # TensorDotUV, 46 | tensor_dot_uv 47 | ) 48 | from .sparse_product import ( 49 | # SparseMul, 50 | sparse_mul, 51 | # SparseOuter, 52 | sparse_outer, 53 | # SparseInner, 54 | sparse_inner, 55 | # SparseVecMat, 56 | sparse_vecmat, 57 | # SparseVecSca, 58 | sparse_vecsca, 59 | # SparseScaVec, 60 | sparse_scavec, 61 | # SparseMatTVec, 62 | sparse_mat_t_vec 63 | ) 64 | from .wigner_d import ( 65 | sparse_wigner_rotation, 66 | dense_wigner_rotation, 67 | wigner_d_matrix, 68 | align_to_z_wigner_d 69 | ) 70 | from .cutoffs import ( 71 | radial_standarize, 72 | polynomial_cutoff, 73 | cosine_cutoff, 74 | mollifier_cutoff 
75 | ) 76 | from .activations import ( 77 | gating 78 | ) 79 | from .angular import ( 80 | sincos 81 | ) 82 | from .norm import ( 83 | # SquaredNorm, 84 | squared_norm, 85 | # Norm, 86 | norm, 87 | # ChannelMeanSquaredNorm, 88 | channel_mean_squared_norm, 89 | # BatchMeanSquaredNorm, 90 | batch_mean_squared_norm 91 | ) 92 | from .dropout import ( 93 | irrep_wise_dropout 94 | ) 95 | from .rotations import ( 96 | angles_to_matrix 97 | ) 98 | 99 | __all__ = [ 100 | # "TensorProductU1UDummy", 101 | "tensor_product_u1u", 102 | "so3_linear_uu", 103 | # "TensorProduct1UUDummy", 104 | "tensor_product_1uu", 105 | # "TensorProductUU1Dummy", 106 | "tensor_product_uu1", 107 | # "TensorProductU1VDummy", 108 | "tensor_product_u1v", 109 | "so3_linear_uv", 110 | # "TensorProduct1VUDummy", 111 | "tensor_product_1vu", 112 | # "TensorProductVU1Dummy", 113 | "tensor_product_vu1", 114 | # "IrrepWiseLinear", 115 | "irrep_wise_linear", 116 | # "IrrepsLinear", 117 | "irreps_linear", 118 | "so2_linear_uu", 119 | "so2_linear_uv", 120 | "spherical_harmonics", 121 | "xyz_to_spherical", 122 | "spherical_to_xyz", 123 | "xyz_to_sincos", 124 | "batch_rms_norm", 125 | "layer_rms_norm", 126 | # "SparseScale", 127 | "sparse_scale", 128 | # "TensorProductUUUDummy", 129 | "tensor_product_uuu", 130 | # "TensorProductUVWDummy", 131 | "tensor_product_uvw", 132 | # "TensorDotUU", 133 | "tensor_dot_uu", 134 | # "TensorDotUV", 135 | "tensor_dot_uv", 136 | # "SparseMul", 137 | "sparse_mul", 138 | # "SparseOuter", 139 | "sparse_outer", 140 | # "SparseInner", 141 | "sparse_inner", 142 | # "SparseVecMat", 143 | "sparse_vecmat", 144 | # "SparseVecSca", 145 | "sparse_vecsca", 146 | # "SparseScaVec", 147 | "sparse_scavec", 148 | # "SparseMatTVec", 149 | "sparse_mat_t_vec", 150 | "sparse_wigner_rotation", 151 | "dense_wigner_rotation", 152 | "wigner_d_matrix", 153 | "align_to_z_wigner_d", 154 | "radial_standarize", 155 | "polynomial_cutoff", 156 | "cosine_cutoff", 157 | "mollifier_cutoff", 158 | "gating", 159 | 
"sincos", 160 | # "SquaredNorm", 161 | "squared_norm", 162 | # "Norm", 163 | "norm", 164 | # "ChannelMeanSquaredNorm", 165 | "channel_mean_squared_norm", 166 | # "BatchMeanSquaredNorm", 167 | "batch_mean_squared_norm", 168 | "irrep_wise_dropout", 169 | "angles_to_matrix" 170 | ] 171 | -------------------------------------------------------------------------------- /docs/build/html/search.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | Search — equitorch documentation 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 |
29 | 54 | 55 |
59 | 60 |
61 |
62 |
63 |
    64 |
  • 65 | 66 |
  • 67 |
  • 68 |
69 |
70 |
71 |
72 |
73 | 74 | 81 | 82 | 83 |
84 | 85 |
86 | 87 |
88 |
89 |
90 | 91 |
92 | 93 |
94 |

© Copyright 2025, Tong Wang.

95 |
96 | 97 | Built with Sphinx using a 98 | theme 99 | provided by Read the Docs. 100 | 101 | 102 |
103 |
104 |
105 |
106 |
107 | 112 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | -------------------------------------------------------------------------------- /docs/build/html/generated/equitorch.nn.functional.cutoffs.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | equitorch.nn.functional.cutoffs — equitorch documentation 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 |
27 | 52 | 53 |
57 | 58 |
59 |
60 |
61 | 68 |
69 |
70 |
71 |
72 | 73 |
74 |

equitorch.nn.functional.cutoffs

75 |
76 | 77 | 78 |
79 |
80 |
81 | 82 |
83 | 84 |
85 |

© Copyright 2025, Tong Wang.

86 |
87 | 88 | Built with Sphinx using a 89 | theme 90 | provided by Read the Docs. 91 | 92 | 93 |
94 |
95 |
96 |
97 |
98 | 103 | 104 | 105 | -------------------------------------------------------------------------------- /equitorch/ops/indexed_product_segment_op.py: -------------------------------------------------------------------------------- 1 | from .batched_sparse_dense_op import ( 2 | indexed_mul_scale_gather, 3 | indexed_inner_scale_gather, 4 | indexed_outer_scale_gather, 5 | indexed_vecmat_scale_gather, 6 | indexed_vecsca_scale_gather 7 | ) 8 | 9 | 10 | 11 | def indexed_mul_segment( 12 | input1, 13 | input2, 14 | index1=None, 15 | index2=None, 16 | seg=None, 17 | out=None, 18 | block_size_n=64, 19 | block_size_c=64, 20 | num_stages=2 21 | ): 22 | r"""分段聚集的索引乘法(必须提供seg且至少一个索引)""" 23 | assert seg is not None, "seg cannot be None for segment operations" 24 | assert index1 is not None or index2 is not None, "At least one of index1 or index2 must be provided" 25 | 26 | return indexed_mul_scale_gather( 27 | input1, input2, 28 | index1=index1, 29 | index2=index2, 30 | seg=seg, 31 | out=out, 32 | block_size_n=block_size_n, 33 | block_size_c=block_size_c, 34 | num_stages=num_stages 35 | ) 36 | 37 | def indexed_outer_segment( 38 | input1, 39 | input2, 40 | index1=None, 41 | index2=None, 42 | seg=None, 43 | out=None, 44 | accumulated=False, 45 | block_size_n=8, 46 | block_size_c1=32, 47 | block_size_c2=32, 48 | num_stages=1, 49 | ): 50 | r"""分段聚集的索引外积(必须提供seg且至少一个索引)""" 51 | assert seg is not None, "seg cannot be None for segment operations" 52 | assert index1 is not None or index2 is not None, "At least one of index1 or index2 must be provided" 53 | 54 | return indexed_outer_scale_gather( 55 | input1, input2, 56 | index1=index1, 57 | index2=index2, 58 | seg=seg, 59 | out=out, 60 | block_size_n=block_size_n, 61 | block_size_c1=block_size_c1, 62 | block_size_c2=block_size_c2, 63 | num_stages=num_stages, 64 | out_accumulated=accumulated 65 | ) 66 | 67 | def indexed_inner_segment( 68 | input1, 69 | input2, 70 | index1=None, 71 | index2=None, 72 | seg=None, 73 | out=None, 74 | 
block_size_n=32, 75 | block_size_c=32, 76 | num_stages=2 77 | ): 78 | r"""Segment-gathered indexed inner product (``seg`` is required and at least one of ``index1``/``index2`` must be given).""" 79 | assert seg is not None, "seg cannot be None for segment operations" 80 | assert index1 is not None or index2 is not None, "At least one of index1 or index2 must be provided" 81 | 82 | return indexed_inner_scale_gather( 83 | input1, input2, 84 | index1=index1, 85 | index2=index2, 86 | seg=seg, 87 | out=out, 88 | block_size_n=block_size_n, 89 | block_size_c=block_size_c, 90 | num_stages=num_stages 91 | ) 92 | 93 | def indexed_vecmat_segment( 94 | input1, 95 | input2, 96 | index1=None, 97 | index2=None, 98 | seg=None, 99 | out=None, 100 | block_size_n=8, 101 | block_size_c_in=32, 102 | block_size_c_out=32, 103 | num_stages=1 104 | ): 105 | r"""Segment-gathered indexed vector-matrix product (``seg`` is required and at least one of ``index1``/``index2`` must be given).""" 106 | assert seg is not None, "seg cannot be None for segment operations" 107 | assert index1 is not None or index2 is not None, "At least one of index1 or index2 must be provided" 108 | 109 | return indexed_vecmat_scale_gather( 110 | input1, input2, 111 | index1=index1, 112 | index2=index2, 113 | seg=seg, 114 | out=out, 115 | block_size_n=block_size_n, 116 | block_size_c_in=block_size_c_in, 117 | block_size_c_out=block_size_c_out, 118 | num_stages=num_stages 119 | ) 120 | 121 | def indexed_vecsca_segment( 122 | input1, 123 | input2, 124 | index1=None, 125 | index2=None, 126 | seg=None, 127 | out=None, 128 | block_size_n=64, 129 | block_size_c=64, 130 | num_stages=2 131 | ): 132 | r"""Segment-gathered indexed vector-scalar scaling (``seg`` is required and at least one of ``index1``/``index2`` must be given).""" 133 | assert seg is not None, "seg cannot be None for segment operations" 134 | assert index1 is not None or index2 is not None, "At least one of index1 or index2 must be provided" 135 | 136 | return indexed_vecsca_scale_gather( 137 | input1, input2, 138 | index1=index1, 139 | index2=index2, 140 | seg=seg, 141 | out=out, 142 | block_size_n=block_size_n, 143 | block_size_c=block_size_c, 144 | num_stages=num_stages 145 | ) 146 | 147 | 
-------------------------------------------------------------------------------- /equitorch/ops/indexed_product_scale_segment_op.py: -------------------------------------------------------------------------------- 1 | from .batched_sparse_dense_op import ( 2 | indexed_mul_scale_gather, 3 | indexed_inner_scale_gather, 4 | indexed_outer_scale_gather, 5 | indexed_vecmat_scale_gather, 6 | indexed_vecsca_scale_gather 7 | ) 8 | 9 | def indexed_mul_scale_segment( 10 | input1, 11 | input2, 12 | scale, 13 | index1=None, 14 | index2=None, 15 | seg=None, 16 | out=None, 17 | block_size_n=64, 18 | block_size_c=64, 19 | num_stages=2 20 | ): 21 | r"""分段聚集的索引乘法(必须提供seg且至少一个索引)""" 22 | assert seg is not None, "seg cannot be None for segment operations" 23 | assert index1 is not None or index2 is not None, "At least one of index1 or index2 must be provided" 24 | 25 | return indexed_mul_scale_gather( 26 | input1, input2, scale, 27 | index1=index1, 28 | index2=index2, 29 | seg=seg, 30 | out=out, 31 | block_size_n=block_size_n, 32 | block_size_c=block_size_c, 33 | num_stages=num_stages 34 | ) 35 | 36 | def indexed_outer_scale_segment( 37 | input1, 38 | input2, 39 | scale, 40 | index1=None, 41 | index2=None, 42 | seg=None, 43 | out=None, 44 | block_size_n=8, 45 | block_size_c1=32, 46 | block_size_c2=32, 47 | num_stages=1 48 | ): 49 | r"""分段聚集的索引外积(必须提供seg且至少一个索引)""" 50 | assert seg is not None, "seg cannot be None for segment operations" 51 | assert index1 is not None or index2 is not None, "At least one of index1 or index2 must be provided" 52 | 53 | return indexed_outer_scale_gather( 54 | input1, input2, scale, 55 | index1=index1, 56 | index2=index2, 57 | seg=seg, 58 | out=out, 59 | block_size_n=block_size_n, 60 | block_size_c1=block_size_c1, 61 | block_size_c2=block_size_c2, 62 | num_stages=num_stages 63 | ) 64 | 65 | def indexed_inner_scale_segment( 66 | input1, 67 | input2, 68 | scale, 69 | index1=None, 70 | index2=None, 71 | seg=None, 72 | out=None, 73 | block_size_n=32, 74 | 
block_size_c=32, 75 | num_stages=2 76 | ): 77 | r"""Segment-gathered scaled indexed inner product (``seg`` is required and at least one of ``index1``/``index2`` must be given).""" 78 | assert seg is not None, "seg cannot be None for segment operations" 79 | assert index1 is not None or index2 is not None, "At least one of index1 or index2 must be provided" 80 | 81 | return indexed_inner_scale_gather( 82 | input1, input2, scale, 83 | index1=index1, 84 | index2=index2, 85 | seg=seg, 86 | out=out, 87 | block_size_n=block_size_n, 88 | block_size_c=block_size_c, 89 | num_stages=num_stages 90 | ) 91 | 92 | def indexed_vecmat_scale_segment( 93 | input1, 94 | input2, 95 | scale, 96 | index1=None, 97 | index2=None, 98 | seg=None, 99 | out=None, 100 | block_size_n=8, 101 | block_size_c_in=32, 102 | block_size_c_out=32, 103 | num_stages=1 104 | ): 105 | r"""Segment-gathered scaled indexed vector-matrix product (``seg`` is required and at least one of ``index1``/``index2`` must be given).""" 106 | assert seg is not None, "seg cannot be None for segment operations" 107 | assert index1 is not None or index2 is not None, "At least one of index1 or index2 must be provided" 108 | 109 | return indexed_vecmat_scale_gather( 110 | input1, input2, scale, 111 | index1=index1, 112 | index2=index2, 113 | seg=seg, 114 | out=out, 115 | block_size_n=block_size_n, 116 | block_size_c_in=block_size_c_in, 117 | block_size_c_out=block_size_c_out, 118 | num_stages=num_stages 119 | ) 120 | 121 | def indexed_vecsca_scale_segment( 122 | input1, 123 | input2, 124 | scale, 125 | index1=None, 126 | index2=None, 127 | seg=None, 128 | out=None, 129 | block_size_n=64, 130 | block_size_c=64, 131 | num_stages=2 132 | ): 133 | r"""Segment-gathered scaled indexed vector-scalar scaling (``seg`` is required and at least one of ``index1``/``index2`` must be given).""" 134 | assert seg is not None, "seg cannot be None for segment operations" 135 | assert index1 is not None or index2 is not None, "At least one of index1 or index2 must be provided" 136 | 137 | return indexed_vecsca_scale_gather( 138 | input1, input2, scale, 139 | index1=index1, 140 | index2=index2, 141 | seg=seg, 142 | out=out, 143 | block_size_n=block_size_n, 144 | block_size_c=block_size_c, 145 | num_stages=num_stages 146 | ) 147 | 148 | 
-------------------------------------------------------------------------------- /equitorch/nn/functional/rotations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional, Tuple 3 | 4 | def angles_to_matrix( 5 | alpha: Optional[torch.Tensor] = None, 6 | beta: Optional[torch.Tensor] = None, 7 | gamma: Optional[torch.Tensor] = None 8 | ) -> torch.Tensor: 9 | """Convert Euler angles (ZYZ convention) to rotation matrices. 10 | 11 | Args: 12 | alpha: First rotation angle about z-axis (radians) 13 | beta: Second rotation angle about y-axis (radians) 14 | gamma: Third rotation angle about z-axis (radians) 15 | 16 | Returns: 17 | Rotation matrices of shape (..., 3, 3) 18 | """ 19 | if alpha is None and beta is None and gamma is None: 20 | raise ValueError("At least one of alpha, beta, or gamma must be provided.") 21 | 22 | # Determine broadcast shape 23 | shapes = [] 24 | if alpha is not None: 25 | shapes.append(alpha.shape) 26 | if beta is not None: 27 | shapes.append(beta.shape) 28 | if gamma is not None: 29 | shapes.append(gamma.shape) 30 | 31 | if not shapes: # Should be caught by the None check above, but as a safeguard 32 | return torch.eye(3) 33 | 34 | # Find a non-None tensor to get dtype and device 35 | ref_tensor = alpha if alpha is not None else beta if beta is not None else gamma 36 | assert ref_tensor is not None # Ensured by the first check 37 | 38 | # Broadcast all angles to the same shape 39 | # We need to handle the case where an angle is None by creating a zero tensor of the broadcasted shape 40 | # First, let's find the target broadcast shape using torch.broadcast_shapes 41 | # To do this, we need at least one tensor. If an angle is None, we can't directly use it. 42 | # Instead, we'll create dummy tensors of shape (1,) for None angles to participate in broadcasting shape calculation. 
43 | 44 | dummy_alpha = alpha if alpha is not None else torch.zeros((1,), dtype=ref_tensor.dtype, device=ref_tensor.device) 45 | dummy_beta = beta if beta is not None else torch.zeros((1,), dtype=ref_tensor.dtype, device=ref_tensor.device) 46 | dummy_gamma = gamma if gamma is not None else torch.zeros((1,), dtype=ref_tensor.dtype, device=ref_tensor.device) 47 | 48 | try: 49 | broadcast_shape = torch.broadcast_shapes(dummy_alpha.shape, dummy_beta.shape, dummy_gamma.shape) 50 | except RuntimeError as e: 51 | raise ValueError(f"Could not broadcast shapes of alpha, beta, gamma: {shapes}. Error: {e}") 52 | 53 | 54 | zeros = torch.zeros(broadcast_shape, dtype=ref_tensor.dtype, device=ref_tensor.device) 55 | ones = torch.ones(broadcast_shape, dtype=ref_tensor.dtype, device=ref_tensor.device) 56 | 57 | # Rotation around Z-axis by alpha 58 | if alpha is not None: 59 | expanded_alpha = alpha.expand_as(zeros) 60 | sin_a = torch.sin(expanded_alpha) 61 | cos_a = torch.cos(expanded_alpha) 62 | else: 63 | sin_a, cos_a = zeros, ones # No rotation 64 | 65 | # R_z(alpha) 66 | # yapf: disable 67 | R_alpha = torch.stack([ 68 | cos_a, -sin_a, zeros, 69 | sin_a, cos_a, zeros, 70 | zeros, zeros, ones 71 | ], dim=-1).reshape(broadcast_shape + (3, 3)) 72 | # yapf: enable 73 | 74 | # Rotation around Y-axis by beta 75 | if beta is not None: 76 | expanded_beta = beta.expand_as(zeros) 77 | sin_b = torch.sin(expanded_beta) 78 | cos_b = torch.cos(expanded_beta) 79 | else: 80 | sin_b, cos_b = zeros, ones # No rotation 81 | 82 | # R_y(beta) 83 | # yapf: disable 84 | R_beta = torch.stack([ 85 | cos_b, zeros, sin_b, 86 | zeros, ones, zeros, 87 | -sin_b, zeros, cos_b 88 | ], dim=-1).reshape(broadcast_shape + (3, 3)) 89 | # yapf: enable 90 | 91 | # Rotation around Z-axis by gamma 92 | if gamma is not None: 93 | expanded_gamma = gamma.expand_as(zeros) 94 | sin_g = torch.sin(expanded_gamma) 95 | cos_g = torch.cos(expanded_gamma) 96 | else: 97 | sin_g, cos_g = zeros, ones # No rotation 98 | 99 | # 
R_z(gamma) 100 | # yapf: disable 101 | R_gamma = torch.stack([ 102 | cos_g, -sin_g, zeros, 103 | sin_g, cos_g, zeros, 104 | zeros, zeros, ones 105 | ], dim=-1).reshape(broadcast_shape + (3, 3)) 106 | # yapf: enable 107 | 108 | R_matrix = R_gamma @ R_beta @ R_alpha # More concise 109 | 110 | return R_matrix 111 | -------------------------------------------------------------------------------- /docs/build/html/generated/equitorch.nn.radials.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | equitorch.nn.radials — equitorch documentation 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 |
27 | 52 | 53 |
57 | 58 |
59 |
60 |
61 | 68 |
69 |
70 |
71 |
72 | 73 |
74 |

equitorch.nn.radials

75 |

Classes

76 | 77 | 78 | 79 | 80 | 81 | 82 |

BesselBasis(r_max[, num_basis, trainable])

83 |
84 | 85 | 86 |
87 |
88 |
89 | 90 |
91 | 92 |
93 |

© Copyright 2025, Tong Wang.

94 |
95 | 96 | Built with Sphinx using a 97 | theme 98 | provided by Read the Docs. 99 | 100 | 101 |
102 |
103 |
104 |
105 |
106 | 111 | 112 | 113 | -------------------------------------------------------------------------------- /test/test_linear/test_grad_linear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import os 4 | from torch.autograd import gradcheck, gradgradcheck 5 | 6 | from equitorch.irreps import Irreps, check_irreps 7 | from equitorch.nn.linears import IrrepsLinear 8 | 9 | # Set environment and defaults 10 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 11 | torch.set_default_dtype(torch.float64) # Use float64 for gradcheck stability 12 | torch.random.manual_seed(0) 13 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 14 | 15 | EPS = 1e-9 16 | 17 | def _init_test_case(irreps_in, irreps_out, channels_in, channels_out, batch_size, device, shared_weight): 18 | r"""Initialize test case for gradcheck.""" 19 | layer = IrrepsLinear( 20 | irreps_in=irreps_in, 21 | irreps_out=irreps_out, 22 | channels_in=channels_in, 23 | channels_out=channels_out, 24 | path_norm=True, 25 | internal_weights=False 26 | ).to(device) 27 | 28 | # Create inputs with requires_grad=True 29 | x = torch.randn(batch_size, Irreps(irreps_in).dim, channels_in, 30 | device=device, dtype=torch.float64, requires_grad=True) 31 | 32 | # Create weights with requires_grad=True 33 | weight_shape = layer.weight_shape 34 | if not shared_weight: 35 | weight_shape = (batch_size,) + weight_shape 36 | W = torch.randn(*weight_shape, device=device, dtype=torch.float64, requires_grad=True) 37 | 38 | return layer, x, W 39 | 40 | def _run_gradcheck(layer, x, W): 41 | r"""Run gradcheck and gradgradcheck for the given inputs.""" 42 | def func(x, W): 43 | return layer(x, W) 44 | 45 | # Ensure inputs are leaf variables 46 | x = x.detach().requires_grad_(True) 47 | W = W.detach().requires_grad_(True) 48 | 49 | # Forward pass 50 | out = func(x, W) 51 | print(f"Forward output sum: {out.sum().item()}") 52 | 53 | # Backward pass with gradient checking 54 | 
out.sum().backward() 55 | 56 | # Print gradient norms for debugging 57 | print(f"x.grad norm: {x.grad.norm().item() if x.grad is not None else 'None'}") 58 | print(f"W.grad norm: {W.grad.norm().item() if W.grad is not None else 'None'}") 59 | 60 | # Reset gradients 61 | x.grad = None 62 | W.grad = None 63 | 64 | # Run gradcheck with relaxed tolerances 65 | gradcheck_success = gradcheck( 66 | func, (x, W), 67 | eps=EPS, atol=1e-5, rtol=1e-5, 68 | nondet_tol=1e-3, 69 | check_undefined_grad=False 70 | ) 71 | print('grad_check_passed') 72 | 73 | # Run gradgradcheck with same settings 74 | gradgradcheck_success = gradgradcheck( 75 | func, (x, W), 76 | eps=EPS, atol=1e-5, rtol=1e-5, 77 | nondet_tol=1e-3, 78 | check_undefined_grad=False 79 | ) 80 | print('gradgrad_check_passed') 81 | 82 | return gradcheck_success, gradgradcheck_success 83 | 84 | # Test cases with different irreps and parameter combinations 85 | def test_linear_case(irreps_in, irreps_out, channels_in, channels_out, batch_size, shared_weight): 86 | r"""Test linear layer with given parameters.""" 87 | print(f"\nTesting IrrepsLinear with irreps_in={irreps_in}, irreps_out={irreps_out}, " 88 | f"channels_in={channels_in}, channels_out={channels_out}, " 89 | f"batch_size={batch_size}, shared_weight={shared_weight}") 90 | 91 | layer, x, W = _init_test_case( 92 | irreps_in, irreps_out, channels_in, channels_out, 93 | batch_size, device, shared_weight 94 | ) 95 | 96 | gradcheck_success, gradgradcheck_success = _run_gradcheck(layer, x, W) 97 | print(f"IrrepsLinear - gradcheck: {gradcheck_success}, gradgradcheck: {gradgradcheck_success}") 98 | assert gradcheck_success and gradgradcheck_success 99 | 100 | # Main execution 101 | if __name__ == '__main__': 102 | print("Running gradient checks for IrrepsLinear...") 103 | 104 | # Test different combinations of parameters 105 | test_configs = [ 106 | ("1x0e", "1x0e", 3, 3, 5), # scalar only 107 | ("1x0e + 1x1e", "1x0e + 1x1e", 4, 4, 6), # scalar + vector 108 | ("1x0e + 
1x1e + 1x2e", "1x0e + 1x1e + 1x2e", 5, 5, 7) # scalar + vector + tensor 109 | ] 110 | 111 | for irreps_in, irreps_out, channels_in, channels_out, batch_size in test_configs: 112 | for shared_weight in [True, False]: 113 | test_linear_case( 114 | irreps_in, irreps_out, 115 | channels_in, channels_out, 116 | batch_size, shared_weight 117 | ) 118 | 119 | print("\nAll gradient checks completed.") 120 | -------------------------------------------------------------------------------- /docs/build/html/generated/equitorch.nn.dropout.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | equitorch.nn.dropout — equitorch documentation 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 |
27 | 52 | 53 |
57 | 58 |
59 |
60 |
61 | 68 |
69 |
70 |
71 |
72 | 73 |
74 |

equitorch.nn.dropout

75 |

Classes

76 | 77 | 78 | 79 | 80 | 81 | 82 |

Dropout([p, irreps, irrep_wise, work_on_eval])

Apply dropout to equivariant features.

83 |
84 | 85 | 86 |
87 |
88 |
89 | 90 |
91 | 92 |
93 |

© Copyright 2025, Tong Wang.

94 |
95 | 96 | Built with Sphinx using a 97 | theme 98 | provided by Read the Docs. 99 | 100 | 101 |
102 |
103 |
104 |
105 |
106 | 111 | 112 | 113 | -------------------------------------------------------------------------------- /docs/build/html/generated/equitorch.nn.rotations.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | equitorch.nn.rotations — equitorch documentation 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 |
27 | 52 | 53 |
57 | 58 |
59 |
60 |
61 | 68 |
69 |
70 |
71 |
72 | 73 |
74 |

equitorch.nn.rotations

75 |

Classes

76 | 77 | 78 | 79 | 80 | 81 | 82 |

AnglesToMatrix()

Module to convert Euler angles (ZYZ convention) to rotation matrices.

83 |
84 | 85 | 86 |
87 |
88 |
89 | 90 |
91 | 92 |
93 |

© Copyright 2025, Tong Wang.

94 |
95 | 96 | Built with Sphinx using a 97 | theme 98 | provided by Read the Docs. 99 | 100 | 101 |
102 |
103 |
104 |
105 |
106 | 111 | 112 | 113 | --------------------------------------------------------------------------------