├── jaxlie ├── py.typed ├── utils │ ├── __init__.py │ └── _utils.py ├── __init__.py ├── manifold │ ├── __init__.py │ ├── _tree_utils.py │ ├── _backprop.py │ └── _deltas.py ├── hints │ └── __init__.py ├── _so2.py ├── _se3.py ├── _base.py ├── _se2.py └── _so3.py ├── docs ├── .gitignore ├── source │ ├── se3_basics.rst │ ├── vmap_usage.rst │ ├── se3_optimization.rst │ ├── index.rst │ └── conf.py ├── requirements.txt └── Makefile ├── mypy.ini ├── .gitignore ├── .flake8 ├── .github └── workflows │ ├── lint.yml │ ├── mypy.yml │ ├── docs.yml │ ├── build.yml │ ├── coverage.yml │ └── publish.yml ├── .coveragerc ├── LICENSE ├── tests ├── test_serialization.py ├── test_broadcast.py ├── test_group_axioms.py ├── test_examples.py ├── test_jlog.py ├── test_manifold.py ├── test_jac_left.py ├── utils.py ├── test_operations.py └── test_autodiff.py ├── setup.py ├── examples ├── se3_basics.py ├── vmap_example.py └── se3_optimization.py └── README.md /jaxlie/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | ignore_missing_imports = True 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | *.swo 3 | *.pyc 4 | *.egg-info 5 | __pycache__ 6 | .coverage 7 | htmlcov 8 | .mypy_cache 9 | .dmypy.json 10 | .hypothesis 11 | -------------------------------------------------------------------------------- /docs/source/se3_basics.rst: -------------------------------------------------------------------------------- 1 | Basics 2 | 
========================================== 3 | 4 | 5 | .. literalinclude:: ../../examples/se3_basics.py 6 | :language: python 7 | 8 | -------------------------------------------------------------------------------- /jaxlie/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from ._utils import broadcast_leading_axes, get_epsilon, register_lie_group 2 | 3 | __all__ = ["get_epsilon", "register_lie_group", "broadcast_leading_axes"] 4 | -------------------------------------------------------------------------------- /docs/source/vmap_usage.rst: -------------------------------------------------------------------------------- 1 | `jax.vmap` Usage 2 | ========================================== 3 | 4 | 5 | .. literalinclude:: ../../examples/vmap_example.py 6 | :language: python 7 | 8 | -------------------------------------------------------------------------------- /docs/source/se3_optimization.rst: -------------------------------------------------------------------------------- 1 | SE(3) Optimization 2 | ========================================== 3 | 4 | 5 | .. 
literalinclude:: ../../examples/se3_optimization.py 6 | :language: python 7 | 8 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | # E203: whitespace before : 3 | # E501: line too long ( characters) 4 | # W503: line break before binary operator 5 | ; ignore = E203,E501,D100,D101,D102,D103,W503 6 | ignore = E203,E501,W503 7 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx==7.2.6 2 | sphinx_rtd_theme 3 | sphinx_math_dollar 4 | sphinx-autoapi==3.0.0 5 | graphviz 6 | m2r2==0.3.3.post2 7 | git+https://github.com/brentyi/sphinxcontrib-programoutput.git 8 | git+https://github.com/brentyi/ansi.git 9 | -------------------------------------------------------------------------------- /jaxlie/__init__.py: -------------------------------------------------------------------------------- 1 | from . import hints as hints 2 | from . import manifold as manifold 3 | from . import utils as utils 4 | from ._base import MatrixLieGroup as MatrixLieGroup 5 | from ._base import SEBase as SEBase 6 | from ._base import SOBase as SOBase 7 | from ._se2 import SE2 as SE2 8 | from ._se3 import SE3 as SE3 9 | from ._so2 import SO2 as SO2 10 | from ._so3 import SO3 as SO3 11 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: lint 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | pull_request: 7 | branches: [master] 8 | 9 | jobs: 10 | black-check: 11 | runs-on: ubuntu-22.04 12 | steps: 13 | - uses: actions/checkout@v1 14 | - name: Black Code Formatter 15 | uses: lgeiger/black-action@master 16 | with: 17 | args: ". 
--check" 18 | -------------------------------------------------------------------------------- /jaxlie/manifold/__init__.py: -------------------------------------------------------------------------------- 1 | from ._backprop import grad as grad 2 | from ._backprop import value_and_grad as value_and_grad 3 | from ._backprop import zero_tangents as zero_tangents 4 | from ._deltas import rminus as rminus 5 | from ._deltas import rplus as rplus 6 | from ._deltas import ( 7 | rplus_jacobian_parameters_wrt_delta as rplus_jacobian_parameters_wrt_delta, 8 | ) 9 | from ._tree_utils import normalize_all as normalize_all 10 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [report] 2 | exclude_lines = 3 | # Have to re-enable the standard pragma 4 | pragma: no cover 5 | 6 | # Don't compute coverage for abstract methods, properties 7 | @abstract 8 | @abc\.abstract 9 | 10 | # or warnings 11 | warnings 12 | 13 | # or empty function bodies 14 | pass 15 | \.\.\. 16 | 17 | # or typing imports 18 | TYPE_CHECKING 19 | 20 | # or assert statements 21 | assert 22 | 23 | # or anything that's not implemented 24 | NotImplementedError() 25 | 26 | # or fallback imports 27 | except ImportError: 28 | -------------------------------------------------------------------------------- /jaxlie/hints/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple, Union 2 | 3 | import jax 4 | import numpy as onp 5 | 6 | # Type aliases for JAX/Numpy arrays; primarily for function inputs. 
 7 | 
 8 | Array = Union[onp.ndarray, jax.Array]
 9 | """Type alias for `Union[jax.Array, onp.ndarray]`."""
10 | 
11 | Scalar = Union[float, Array]
12 | """Type alias for `Union[float, Array]`."""
13 | 
14 | 
15 | class RollPitchYaw(NamedTuple):
16 |     """Tuple containing roll, pitch, and yaw Euler angles."""
17 | 
18 |     roll: Scalar
19 |     pitch: Scalar
20 |     yaw: Scalar
21 | 
22 | 
23 | __all__ = [
24 |     "Array",
25 |     "Scalar",
26 |     "RollPitchYaw",
27 | ]
28 | 
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
 1 | # Minimal makefile for Sphinx documentation
 2 | #
 3 | 
 4 | # You can set these variables from the command line.
 5 | SPHINXOPTS    =
 6 | SPHINXBUILD   = sphinx-build
 7 | SPHINXPROJ    = jaxlie
 8 | SOURCEDIR     = source
 9 | BUILDDIR      = ./build
10 | 
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 | 
15 | .PHONY: help Makefile
16 | 
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /.github/workflows/mypy.yml: -------------------------------------------------------------------------------- 1 | name: mypy 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | pull_request: 7 | branches: [master] 8 | 9 | jobs: 10 | mypy: 11 | runs-on: ubuntu-22.04 12 | strategy: 13 | matrix: 14 | python-version: ["3.9", "3.10"] 15 | 16 | steps: 17 | - uses: actions/checkout@v2 18 | - name: Set up Python ${{ matrix.python-version }} 19 | uses: actions/setup-python@v1 20 | with: 21 | python-version: ${{ matrix.python-version }} 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip 25 | pip install -e ".[testing]" 26 | - name: Test with mypy 27 | run: | 28 | mypy . 29 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: docs 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | 7 | jobs: 8 | docs: 9 | runs-on: ubuntu-22.04 10 | container: 11 | image: python:3.9 12 | steps: 13 | 14 | # Check out source 15 | - uses: actions/checkout@v2 16 | 17 | # Build documentation 18 | - name: Building documentation 19 | run: | 20 | apt-get update 21 | apt-get install -y graphviz 22 | pip install -e . 
23 | pip install -r docs/requirements.txt 24 | sphinx-build docs/source docs/build -b dirhtml 25 | 26 | # Deploy 27 | - name: Deploy to GitHub Pages 28 | uses: peaceiris/actions-gh-pages@v3 29 | with: 30 | github_token: ${{ secrets.GITHUB_TOKEN }} 31 | publish_dir: ./docs/build 32 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | pull_request: 7 | branches: [master] 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-22.04 12 | strategy: 13 | matrix: 14 | python-version: ["3.8", "3.9", "3.10", "3.11"] 15 | 16 | steps: 17 | - uses: actions/checkout@v2 18 | - name: Set up Python ${{ matrix.python-version }} 19 | uses: actions/setup-python@v1 20 | with: 21 | python-version: ${{ matrix.python-version }} 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip 25 | pip install -e ".[testing]" 26 | - name: Test with pytest 27 | run: | 28 | # `-n auto` tells pytest-xdist to create 1 worker per CPU core. For 29 | # GitHub actions, this typically results in 2 workers. 
30 | pytest -n auto 31 | -------------------------------------------------------------------------------- /.github/workflows/coverage.yml: -------------------------------------------------------------------------------- 1 | name: coverage 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | pull_request: 7 | branches: [master] 8 | 9 | jobs: 10 | coverage: 11 | runs-on: ubuntu-22.04 12 | steps: 13 | - uses: actions/checkout@v2 14 | - name: Set up Python 3.8 15 | uses: actions/setup-python@v1 16 | with: 17 | python-version: 3.8 18 | - name: Install dependencies 19 | run: | 20 | python -m pip install --upgrade pip 21 | pip install -e ".[testing]" 22 | - name: Generate coverage report 23 | run: | 24 | pytest --cov=jaxlie --cov-report=xml 25 | - name: Upload to Codecov 26 | uses: codecov/codecov-action@v1 27 | with: 28 | token: ${{ secrets.CODECOV_TOKEN }} 29 | file: ./coverage.xml 30 | flags: unittests 31 | name: codecov-umbrella 32 | fail_ci_if_error: true 33 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | # This workflows will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | deploy: 12 | 13 | runs-on: ubuntu-22.04 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python 18 | uses: actions/setup-python@v1 19 | with: 20 | python-version: '3.x' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine 25 | - name: Build and publish 26 | env: 27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 29 | run: | 30 | python 
setup.py sdist bdist_wheel 31 | twine upload dist/* 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Brent Yi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /tests/test_serialization.py: -------------------------------------------------------------------------------- 1 | """Test transform serialization, for things like saving calibrated transforms to 2 | disk.""" 3 | 4 | from typing import Tuple, Type 5 | 6 | import flax.serialization 7 | from utils import assert_transforms_close, general_group_test, sample_transform 8 | 9 | import jaxlie 10 | 11 | 12 | @general_group_test 13 | def test_serialization_state_dict_bijective( 14 | Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...] 15 | ): 16 | """Check bijectivity of state dict representation conversions.""" 17 | T = sample_transform(Group, batch_axes) 18 | T_recovered = flax.serialization.from_state_dict( 19 | T, flax.serialization.to_state_dict(T) 20 | ) 21 | assert_transforms_close(T, T_recovered) 22 | 23 | 24 | @general_group_test 25 | def test_serialization_bytes_bijective( 26 | Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...] 
27 | ): 28 | """Check bijectivity of byte representation conversions.""" 29 | T = sample_transform(Group, batch_axes) 30 | T_recovered = flax.serialization.from_bytes(T, flax.serialization.to_bytes(T)) 31 | assert_transforms_close(T, T_recovered) 32 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | 6 | setup( 7 | name="jaxlie", 8 | version="1.5.0", 9 | description="Matrix Lie groups in JAX", 10 | long_description=long_description, 11 | long_description_content_type="text/markdown", 12 | url="http://github.com/brentyi/jaxlie", 13 | author="brentyi", 14 | author_email="brentyi@berkeley.edu", 15 | license="MIT", 16 | packages=find_packages(), 17 | package_data={"jaxlie": ["py.typed"]}, 18 | python_requires=">=3.8", 19 | install_requires=[ 20 | "jax>=0.3.18", # For jax.Array. 21 | "jax_dataclasses>=1.4.4", 22 | "numpy", 23 | "typing_extensions>=4.0.0", 24 | "tyro", # Only used in examples. 
25 | ], 26 | extras_require={ 27 | "testing": [ 28 | "mypy", 29 | # https://github.com/google/jax/issues/12536 30 | "jax!=0.3.19", 31 | "flax", 32 | "hypothesis[numpy]", 33 | "pytest", 34 | "pytest-xdist[psutil]", 35 | "pytest-cov", 36 | ] 37 | }, 38 | classifiers=[ 39 | "Programming Language :: Python :: 3.7", 40 | "Programming Language :: Python :: 3.8", 41 | "Programming Language :: Python :: 3.9", 42 | "Programming Language :: Python :: 3.10", 43 | "License :: OSI Approved :: MIT License", 44 | "Operating System :: OS Independent", 45 | ], 46 | ) 47 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | jaxlie documentation 2 | ========================================== 3 | 4 | |build| |nbsp| |mypy| |nbsp| |lint| |nbsp| |coverage| 5 | 6 | 7 | :code:`jaxlie` is a Lie theory library for rigid body transformations and 8 | optimization in JAX. 9 | 10 | 11 | .. autoapi-inheritance-diagram:: jaxlie.SO2 jaxlie.SO3 jaxlie.SE2 jaxlie.SE3 12 | :top-classes: jaxlie.MatrixLieGroup 13 | 14 | 15 | Current functionality: 16 | 17 | - SO(2), SE(2), SO(3), and SE(3) Lie groups implemented as high-level 18 | dataclasses. 19 | 20 | - :code:`exp()`, :code:`log()`, :code:`adjoint()`, :code:`multiply()`, 21 | :code:`inverse()`, and :code:`identity()` implementations for each Lie group. 22 | 23 | - Pytree registration for all dataclasses. 24 | 25 | - Broadcasting for leading axes. 26 | 27 | - Helpers + analytical Jacobians for tangent-space optimization 28 | (:code:`jaxlie.manifold`). 29 | 30 | Source code on `Github `_. 31 | 32 | 33 | .. toctree:: 34 | :caption: API Reference 35 | :maxdepth: 3 36 | :titlesonly: 37 | :glob: 38 | 39 | api/jaxlie/index 40 | 41 | 42 | .. toctree:: 43 | :maxdepth: 5 44 | :caption: Example usage 45 | 46 | se3_basics 47 | se3_optimization 48 | vmap_usage 49 | 50 | 51 | .. 
|build| image:: https://github.com/brentyi/jaxlie/workflows/build/badge.svg 52 | :alt: Build status icon 53 | .. |mypy| image:: https://github.com/brentyi/jaxlie/workflows/mypy/badge.svg?branch=master 54 | :alt: Mypy status icon 55 | .. |lint| image:: https://github.com/brentyi/jaxlie/workflows/lint/badge.svg 56 | :alt: Lint status icon 57 | .. |coverage| image:: https://codecov.io/gh/brentyi/jaxlie/branch/master/graph/badge.svg 58 | :alt: Test coverage status icon 59 | :target: https://codecov.io/gh/brentyi/jaxlie 60 | .. |nbsp| unicode:: 0xA0 61 | :trim: 62 | -------------------------------------------------------------------------------- /tests/test_broadcast.py: -------------------------------------------------------------------------------- 1 | """Shape tests for broadcasting.""" 2 | 3 | from typing import Tuple, Type 4 | 5 | import numpy as onp 6 | from utils import ( 7 | general_group_test, 8 | sample_transform, 9 | ) 10 | 11 | import jaxlie 12 | 13 | 14 | @general_group_test 15 | def test_broadcast_multiply( 16 | Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...] 17 | ): 18 | if batch_axes == (): 19 | return 20 | 21 | T = sample_transform(Group, batch_axes) @ sample_transform(Group) 22 | assert T.get_batch_axes() == batch_axes 23 | 24 | T = sample_transform(Group, batch_axes) @ sample_transform(Group, batch_axes=(1,)) 25 | assert T.get_batch_axes() == batch_axes 26 | 27 | T = sample_transform(Group, batch_axes) @ sample_transform( 28 | Group, batch_axes=(1,) * len(batch_axes) 29 | ) 30 | assert T.get_batch_axes() == batch_axes 31 | 32 | T = sample_transform(Group) @ sample_transform(Group, batch_axes) 33 | assert T.get_batch_axes() == batch_axes 34 | 35 | T = sample_transform(Group, batch_axes=(1,)) @ sample_transform(Group, batch_axes) 36 | assert T.get_batch_axes() == batch_axes 37 | 38 | 39 | @general_group_test 40 | def test_broadcast_apply( 41 | Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...] 
42 | ): 43 | if batch_axes == (): 44 | return 45 | 46 | T = sample_transform(Group, batch_axes) 47 | points = onp.random.randn(Group.space_dim) 48 | assert (T @ points).shape == (*batch_axes, Group.space_dim) 49 | 50 | T = sample_transform(Group, batch_axes) 51 | points = onp.random.randn(1, Group.space_dim) 52 | assert (T @ points).shape == (*batch_axes, Group.space_dim) 53 | 54 | T = sample_transform(Group, batch_axes) 55 | points = onp.random.randn(*((1,) * len(batch_axes)), Group.space_dim) 56 | assert (T @ points).shape == (*batch_axes, Group.space_dim) 57 | 58 | T = sample_transform(Group) 59 | points = onp.random.randn(*batch_axes, Group.space_dim) 60 | assert (T @ points).shape == (*batch_axes, Group.space_dim) 61 | 62 | T = sample_transform(Group, batch_axes=(1,)) 63 | points = onp.random.randn(*batch_axes, Group.space_dim) 64 | assert (T @ points).shape == (*batch_axes, Group.space_dim) 65 | -------------------------------------------------------------------------------- /examples/se3_basics.py: -------------------------------------------------------------------------------- 1 | import numpy as onp 2 | 3 | from jaxlie import SE3 4 | 5 | ############################# 6 | # (1) Constructing transforms. 7 | ############################# 8 | 9 | print("Constructing transforms.") 10 | 11 | # We can compute a w<-b transform by integrating over an se(3) screw, equivalent 12 | # to `SE3.from_matrix(expm(wedge(twist)))`. 13 | twist = onp.array([1.0, 0.0, 0.2, 0.0, 0.5, 0.0]) 14 | T_w_b = SE3.exp(twist) 15 | 16 | # We can print the (quaternion) rotation term; this is an `SO3` object: 17 | print(f"\t{T_w_b.rotation()=}") 18 | 19 | # Or print the translation; this is a simple array with shape (3,): 20 | print(f"\t{T_w_b.translation()=}") 21 | 22 | # Or the underlying parameters; this is a length-7 (quaternion, translation) array: 23 | print(f"\t{T_w_b.wxyz_xyz=}") # SE3-specific field. 24 | print(f"\t{T_w_b.parameters()=}") # Helper shared by all groups. 
25 | 26 | # There are also other helpers to generate transforms, eg from matrices: 27 | T_w_b = SE3.from_matrix(T_w_b.as_matrix()) 28 | 29 | # Or from explicit rotation and translation terms: 30 | T_w_b = SE3.from_rotation_and_translation( 31 | rotation=T_w_b.rotation(), 32 | translation=T_w_b.translation(), 33 | ) 34 | 35 | # Or with the dataclass constructor + the underlying length-7 parameterization: 36 | T_w_b = SE3(wxyz_xyz=T_w_b.wxyz_xyz) 37 | 38 | 39 | ############################# 40 | # (2) Applying transforms. 41 | ############################# 42 | 43 | print("Applying transforms.") 44 | 45 | # Transform points with the `@` operator: 46 | p_b = onp.random.randn(3) 47 | p_w = T_w_b @ p_b 48 | print(f"\t{p_w=}") 49 | 50 | # or `.apply()`: 51 | p_w = T_w_b.apply(p_b) 52 | print(f"\t{p_w=}") 53 | 54 | # or the homogeneous matrix form: 55 | p_w = (T_w_b.as_matrix() @ onp.append(p_b, 1.0))[:-1] 56 | print(f"\t{p_w=}") 57 | 58 | 59 | ############################# 60 | # (3) Composing transforms. 61 | ############################# 62 | 63 | print("Composing transforms.") 64 | 65 | # Compose transforms with the `@` operator: 66 | T_b_a = SE3.identity() 67 | T_w_a = T_w_b @ T_b_a 68 | print(f"\t{T_w_a=}") 69 | 70 | # or `.multiply()`: 71 | T_w_a = T_w_b.multiply(T_b_a) 72 | print(f"\t{T_w_a=}") 73 | 74 | 75 | ############################# 76 | # (4) Misc. 
77 | ############################# 78 | 79 | print("Misc.") 80 | 81 | # Compute inverses: 82 | T_b_w = T_w_b.inverse() 83 | identity = T_w_b @ T_b_w 84 | print(f"\t{identity=}") 85 | 86 | # Compute adjoints: 87 | adjoint_T_w_b = T_w_b.adjoint() 88 | print(f"\t{adjoint_T_w_b=}") 89 | 90 | # Recover our twist, equivalent to `vee(logm(T_w_b.as_matrix()))`: 91 | recovered_twist = T_w_b.log() 92 | print(f"\t{recovered_twist=}") 93 | -------------------------------------------------------------------------------- /jaxlie/utils/_utils.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Callable, Tuple, Type, TypeVar, Union, cast 2 | 3 | import jax_dataclasses as jdc 4 | from jax import numpy as jnp 5 | 6 | from jaxlie.hints import Array 7 | 8 | if TYPE_CHECKING: 9 | from .._base import MatrixLieGroup 10 | 11 | 12 | T = TypeVar("T", bound="MatrixLieGroup") 13 | 14 | 15 | def get_epsilon(dtype: jnp.dtype) -> float: 16 | """Helper for grabbing type-specific precision constants. 17 | 18 | Args: 19 | dtype: Datatype. 20 | 21 | Returns: 22 | Output float. 23 | """ 24 | return { 25 | jnp.dtype("float32"): 1e-5, 26 | jnp.dtype("float64"): 1e-10, 27 | }[dtype] 28 | 29 | 30 | def register_lie_group( 31 | *, 32 | matrix_dim: int, 33 | parameters_dim: int, 34 | tangent_dim: int, 35 | space_dim: int, 36 | ) -> Callable[[Type[T]], Type[T]]: 37 | """Decorator for registering Lie group dataclasses. 38 | 39 | Sets dimensionality class variables, and marks all methods for JIT compilation. 40 | """ 41 | 42 | def _wrap(cls: Type[T]) -> Type[T]: 43 | # Register dimensions as class attributes. 44 | cls.matrix_dim = matrix_dim 45 | cls.parameters_dim = parameters_dim 46 | cls.tangent_dim = tangent_dim 47 | cls.space_dim = space_dim 48 | 49 | # JIT all methods. 50 | for f in filter( 51 | lambda f: not f.startswith("_") 52 | and callable(getattr(cls, f)) 53 | and f != "get_batch_axes", # Avoid returning tracers. 
54 | dir(cls), 55 | ): 56 | setattr(cls, f, jdc.jit(getattr(cls, f))) 57 | 58 | return cls 59 | 60 | return _wrap 61 | 62 | 63 | TupleOfBroadcastable = TypeVar( 64 | "TupleOfBroadcastable", 65 | bound="Tuple[Union[MatrixLieGroup, Array], ...]", 66 | ) 67 | 68 | 69 | def broadcast_leading_axes(inputs: TupleOfBroadcastable) -> TupleOfBroadcastable: 70 | """Broadcast leading axes of arrays. Takes tuples of either: 71 | - an array, which we assume has shape (*, D). 72 | - a Lie group object.""" 73 | 74 | from .._base import MatrixLieGroup 75 | 76 | array_inputs = [ 77 | ( 78 | (x.parameters(), (x.parameters_dim,)) 79 | if isinstance(x, MatrixLieGroup) 80 | else (x, x.shape[-1:]) 81 | ) 82 | for x in inputs 83 | ] 84 | for array, shape_suffix in array_inputs: 85 | assert array.shape[-len(shape_suffix) :] == shape_suffix 86 | batch_axes = jnp.broadcast_shapes( 87 | *[array.shape[: -len(suffix)] for array, suffix in array_inputs] 88 | ) 89 | broadcasted_arrays = tuple( 90 | jnp.broadcast_to(array, batch_axes + shape_suffix) 91 | for (array, shape_suffix) in array_inputs 92 | ) 93 | return cast( 94 | TupleOfBroadcastable, 95 | tuple( 96 | array if not isinstance(inp, MatrixLieGroup) else type(inp)(array) 97 | for array, inp in zip(broadcasted_arrays, inputs) 98 | ), 99 | ) 100 | -------------------------------------------------------------------------------- /jaxlie/manifold/_tree_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, List, TypeVar 2 | 3 | import jax 4 | import numpy as onp 5 | from jax._src.tree_util import _registry # Dangerous! 6 | 7 | from .._base import MatrixLieGroup 8 | 9 | # Tangent structures are difficult to annotate, so we just mark everything via Any. 
10 | # 11 | # An annotation that would work in most cases is: 12 | # 13 | # def zero_tangents(structure: T) -> T 14 | # 15 | # But this is leaky; note that an input of List[SE3] should output List[jax.Array], 16 | # Dict[str, SE3] should output Dict[str, SE3], etc. 17 | # 18 | # Another tempting option is to define a wrapper class: 19 | # 20 | # @jdc.pytree_dataclass 21 | # class TangentPytree(Generic[PytreeType]): 22 | # wrapped: Any 23 | # 24 | # And have zero_tangents() return: 25 | # 26 | # def zero_tangents(structure: T) -> TangentPytree[T] 27 | # 28 | # which we could also use to make `jaxlie.manifold.rplus()` type safe by adding 29 | # overloads to make sure that the delta input is a TangentPytree, but it would be hard 30 | # to accurately annotate the `grad()` and `value_and_grad()` functions with this wrapper 31 | # type without sacrificing the ability to use them as drop-in replacements for 32 | # `jax.grad()` and `jax.value_and_grad()`. 33 | # 34 | # Finally, NewType is also attractive: 35 | # 36 | # TangentPytree: TypeAlias = NewType("TangentPytree", object) 37 | # 38 | # This seems reasonable, but doesn't play nice with how optax currently (a) annotates 39 | # everything using chex.ArrayTree and (b) doesn't use any generics, leading to a mess of 40 | # casts and `type: ignore` directives. We might consider using this if optax's gradient 41 | # transform annotations change. 42 | TangentPytree = Any 43 | 44 | 45 | def _map_group_trees( 46 | f_lie_groups: Callable, 47 | f_other_arrays: Callable, 48 | *tree_args, 49 | ) -> Any: 50 | if isinstance(tree_args[0], MatrixLieGroup): 51 | return f_lie_groups(*tree_args) 52 | elif isinstance(tree_args[0], (jax.Array, onp.ndarray)): 53 | return f_other_arrays(*tree_args) 54 | else: 55 | # Handle PyTrees recursively. 
56 | assert len(set(map(type, tree_args))) == 1 57 | registry_entry = _registry[type(tree_args[0])] # type: ignore 58 | 59 | children: List[List[Any]] = [] 60 | metadata: List[Any] = [] 61 | for tree in tree_args: 62 | childs, meta = registry_entry.to_iter(tree) 63 | children.append(childs) 64 | metadata.append(meta) 65 | 66 | assert len(set(metadata)) == 1 67 | 68 | return registry_entry.from_iter( 69 | metadata[0], 70 | [ 71 | _map_group_trees( 72 | f_lie_groups, 73 | f_other_arrays, 74 | *list(children[i][j] for i in range(len(children))), 75 | ) 76 | for j in range(len(children[0])) 77 | ], 78 | ) 79 | 80 | 81 | PytreeType = TypeVar("PytreeType") 82 | 83 | 84 | def normalize_all(pytree: PytreeType) -> PytreeType: 85 | """Call `.normalize()` on each Lie group instance in a pytree. 86 | 87 | Results in a naive projection of each group instance to its respective manifold. 88 | """ 89 | 90 | def _project(t: MatrixLieGroup) -> MatrixLieGroup: 91 | return t.normalize() 92 | 93 | return _map_group_trees( 94 | _project, 95 | lambda x: x, 96 | pytree, 97 | ) 98 | -------------------------------------------------------------------------------- /tests/test_group_axioms.py: -------------------------------------------------------------------------------- 1 | """Tests for group axioms. 
2 | 3 | https://proofwiki.org/wiki/Definition:Group_Axioms 4 | """ 5 | 6 | from typing import Tuple, Type 7 | 8 | import numpy as onp 9 | from utils import ( 10 | assert_arrays_close, 11 | assert_transforms_close, 12 | general_group_test, 13 | sample_transform, 14 | ) 15 | 16 | import jaxlie 17 | 18 | 19 | @general_group_test 20 | def test_closure(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]): 21 | """Check closure property.""" 22 | transform_a = sample_transform(Group, batch_axes) 23 | transform_b = sample_transform(Group, batch_axes) 24 | 25 | composed = transform_a @ transform_b 26 | assert_transforms_close(composed, composed.normalize()) 27 | composed = transform_b @ transform_a 28 | assert_transforms_close(composed, composed.normalize()) 29 | composed = Group.multiply(transform_a, transform_b) 30 | assert_transforms_close(composed, composed.normalize()) 31 | composed = Group.multiply(transform_b, transform_a) 32 | assert_transforms_close(composed, composed.normalize()) 33 | 34 | 35 | @general_group_test 36 | def test_identity(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]): 37 | """Check identity property.""" 38 | transform = sample_transform(Group, batch_axes) 39 | identity = Group.identity(batch_axes) 40 | assert_transforms_close(transform, identity @ transform) 41 | assert_transforms_close(transform, transform @ identity) 42 | assert_arrays_close( 43 | transform.as_matrix(), 44 | onp.einsum("...ij,...jk->...ik", identity.as_matrix(), transform.as_matrix()), 45 | ) 46 | assert_arrays_close( 47 | transform.as_matrix(), 48 | onp.einsum("...ij,...jk->...ik", transform.as_matrix(), identity.as_matrix()), 49 | ) 50 | 51 | 52 | @general_group_test 53 | def test_inverse(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]): 54 | """Check inverse property.""" 55 | transform = sample_transform(Group, batch_axes) 56 | identity = Group.identity(batch_axes) 57 | assert_transforms_close(identity, transform @ 
transform.inverse()) 58 | assert_transforms_close(identity, transform.inverse() @ transform) 59 | assert_transforms_close(identity, Group.multiply(transform, transform.inverse())) 60 | assert_transforms_close(identity, Group.multiply(transform.inverse(), transform)) 61 | assert_arrays_close( 62 | onp.broadcast_to( 63 | onp.eye(Group.matrix_dim), (*batch_axes, Group.matrix_dim, Group.matrix_dim) 64 | ), 65 | onp.einsum( 66 | "...ij,...jk->...ik", 67 | transform.as_matrix(), 68 | transform.inverse().as_matrix(), 69 | ), 70 | ) 71 | assert_arrays_close( 72 | onp.broadcast_to( 73 | onp.eye(Group.matrix_dim), (*batch_axes, Group.matrix_dim, Group.matrix_dim) 74 | ), 75 | onp.einsum( 76 | "...ij,...jk->...ik", 77 | transform.inverse().as_matrix(), 78 | transform.as_matrix(), 79 | ), 80 | ) 81 | 82 | 83 | @general_group_test 84 | def test_associative(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]): 85 | """Check associative property.""" 86 | transform_a = sample_transform(Group, batch_axes) 87 | transform_b = sample_transform(Group, batch_axes) 88 | transform_c = sample_transform(Group, batch_axes) 89 | assert_transforms_close( 90 | (transform_a @ transform_b) @ transform_c, 91 | transform_a @ (transform_b @ transform_c), 92 | ) 93 | -------------------------------------------------------------------------------- /tests/test_examples.py: -------------------------------------------------------------------------------- 1 | """Tests with explicit examples.""" 2 | 3 | import jaxlie 4 | import numpy as onp 5 | from hypothesis import given, settings 6 | from hypothesis import strategies as st 7 | 8 | from utils import assert_arrays_close, assert_transforms_close, sample_transform 9 | 10 | 11 | @settings(deadline=None) 12 | @given(_random_module=st.random_module()) 13 | def test_se2_translation(_random_module): 14 | """Simple test for SE(2) translation terms.""" 15 | translation = onp.random.randn(2) 16 | T = jaxlie.SE2.from_xy_theta(*translation, theta=0.0) 17 
| assert_arrays_close(T @ translation, translation * 2) 18 | 19 | 20 | @settings(deadline=None) 21 | @given(_random_module=st.random_module()) 22 | def test_se3_translation(_random_module): 23 | """Simple test for SE(3) translation terms.""" 24 | translation = onp.random.randn(3) 25 | T = jaxlie.SE3.from_rotation_and_translation( 26 | rotation=jaxlie.SO3.identity(), 27 | translation=translation, 28 | ) 29 | assert_arrays_close(T @ translation, translation * 2) 30 | 31 | 32 | def test_se2_rotation(): 33 | """Simple test for SE(2) rotation terms.""" 34 | T_w_b = jaxlie.SE2.from_rotation_and_translation( 35 | rotation=jaxlie.SO2.from_radians(onp.pi / 2.0), 36 | translation=onp.zeros(2), 37 | ) 38 | p_b = onp.array([1.0, 0.0]) 39 | p_w = onp.array([0.0, 1.0]) 40 | assert_arrays_close(T_w_b @ p_b, p_w) 41 | 42 | 43 | def test_se3_rotation(): 44 | """Simple test for SE(3) rotation terms.""" 45 | T_w_b = jaxlie.SE3.from_rotation_and_translation( 46 | rotation=jaxlie.SO3.from_rpy_radians(onp.pi / 2.0, 0.0, 0.0), 47 | translation=onp.zeros(3), 48 | ) 49 | T_w_b_alt = jaxlie.SE3.from_rotation( 50 | jaxlie.SO3.from_rpy_radians(onp.pi / 2.0, 0.0, 0.0), 51 | ) 52 | p_b = onp.array([0.0, 1.0, 0.0]) 53 | p_w = onp.array([0.0, 0.0, 1.0]) 54 | assert_arrays_close(T_w_b @ p_b, T_w_b_alt @ p_b, p_w) 55 | 56 | 57 | def test_se3_from_translation(): 58 | """Simple test for SE(3) rotation terms.""" 59 | T_w_b = jaxlie.SE3.from_rotation_and_translation( 60 | rotation=jaxlie.SO3.identity(), 61 | translation=onp.arange(3) * 1.0, 62 | ) 63 | T_w_b_alt = jaxlie.SE3.from_translation(onp.arange(3) * 1.0) 64 | p_b = onp.array([0.0, 1.0, 0.0]) 65 | p_w = onp.array([0.0, 2.0, 2.0]) 66 | assert_arrays_close(T_w_b @ p_b, T_w_b_alt @ p_b, p_w) 67 | 68 | 69 | def test_so3_xyzw_basic(): 70 | """Check that we can create an SO3 object from an xyzw quaternion.""" 71 | assert_transforms_close( 72 | jaxlie.SO3.from_quaternion_xyzw(onp.array([0.0, 0.0, 0.0, 1.0])), 73 | jaxlie.SO3.identity(), 74 | ) 75 | 76 
| 77 | # def test_so3_xyzw_dtype_error(): 78 | # """Check that an incorrect data-type results in an AssertionError.""" 79 | # with pytest.raises(AssertionError): 80 | # jaxlie.SO3(onp.array([1, 0, 0, 0])), 81 | # 82 | # 83 | # def test_so3_xyzw_shape_error(): 84 | # """Check that an incorrect shape results in an AssertionError.""" 85 | # with pytest.raises(AssertionError): 86 | # jaxlie.SO3(onp.array([1.0, 0.0, 0.0, 0.0, 0.0])) 87 | 88 | 89 | @settings(deadline=None) 90 | @given(_random_module=st.random_module()) 91 | def test_se3_compose(_random_module): 92 | """Compare SE3 composition in matrix form vs compact form.""" 93 | T1 = sample_transform(jaxlie.SE3) 94 | T2 = sample_transform(jaxlie.SE3) 95 | assert_arrays_close(T1.as_matrix() @ T2.as_matrix(), (T1 @ T2).as_matrix()) 96 | assert_transforms_close( 97 | jaxlie.SE3.from_matrix(T1.as_matrix() @ T2.as_matrix()), T1 @ T2 98 | ) 99 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # jaxlie 2 | 3 | ![build](https://github.com/brentyi/jaxlie/workflows/build/badge.svg) 4 | ![mypy](https://github.com/brentyi/jaxlie/workflows/mypy/badge.svg) 5 | ![lint](https://github.com/brentyi/jaxlie/workflows/lint/badge.svg) 6 | [![codecov](https://codecov.io/gh/brentyi/jaxlie/branch/master/graph/badge.svg)](https://codecov.io/gh/brentyi/jaxlie) 7 | [![pypi_downloads](https://pepy.tech/badge/jaxlie)](https://pypi.org/project/jaxlie) 8 | 9 | **[ [API reference](https://brentyi.github.io/jaxlie) ]** **[ 10 | [PyPI](https://pypi.org/project/jaxlie/) ]** 11 | 12 | `jaxlie` is a library containing implementations of Lie groups commonly used for 13 | rigid body transformations, targeted at computer vision & robotics 14 | applications written in JAX. Heavily inspired by the C++ library 15 | [Sophus](https://github.com/strasdat/Sophus). 
16 | 17 | We implement Lie groups as high-level (data)classes: 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 |
GroupDescriptionParameterization
jaxlie.SO2Rotations in 2D.(real, imaginary): unit complex (∈ S1)
jaxlie.SE2Proper rigid transforms in 2D.(real, imaginary, x, y): unit complex & translation
jaxlie.SO3Rotations in 3D.(qw, qx, qy, qz): wxyz quaternion (∈ S3)
jaxlie.SE3Proper rigid transforms in 3D.(qw, qx, qy, qz, x, y, z): wxyz quaternion & translation
50 | 51 | Where each group supports: 52 | 53 | - Forward- and reverse-mode AD-friendly **`exp()`**, **`log()`**, 54 | **`adjoint()`**, **`apply()`**, **`multiply()`**, **`inverse()`**, 55 | **`identity()`**, **`from_matrix()`**, and **`as_matrix()`** operations. (see 56 | [./examples/se3_basics.py](./examples/se3_basics.py)) 57 | - Taylor approximations near singularities. 58 | - Helpers for optimization on manifolds (see 59 | [./examples/se3_optimization.py](./examples/se3_optimization.py), 60 | jaxlie.manifold.\*). 61 | - Compatibility with standard JAX function transformations. (see 62 | [./examples/vmap_example.py](./examples/vmap_example.py)) 63 | - Broadcasting for leading axes. 64 | - (Un)flattening as pytree nodes. 65 | - Serialization using [flax](https://github.com/google/flax). 66 | 67 | We also implement various common utilities for things like uniform random 68 | sampling (**`sample_uniform()`**) and converting from/to Euler angles (in the 69 | `SO3` class). 70 | 71 | --- 72 | 73 | ### Install (Python >=3.7) 74 | 75 | ```bash 76 | # Python 3.6 releases also exist, but are no longer being updated. 
77 | pip install jaxlie 78 | ``` 79 | 80 | --- 81 | 82 | ### Misc 83 | 84 | `jaxlie` was originally written when I was learning about Lie groups for our IROS 2021 paper 85 | ([link](https://github.com/brentyi/dfgo)): 86 | 87 | ``` 88 | @inproceedings{yi2021iros, 89 | author={Brent Yi and Michelle Lee and Alina Kloss and Roberto Mart\'in-Mart\'in and Jeannette Bohg}, 90 | title = {Differentiable Factor Graph Optimization for Learning Smoothers}, 91 | year = 2021, 92 | BOOKTITLE = {2021 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)} 93 | } 94 | ``` 95 | -------------------------------------------------------------------------------- /tests/test_jlog.py: -------------------------------------------------------------------------------- 1 | import time 2 | from functools import partial 3 | from typing import Tuple, Type 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import jaxlie 8 | 9 | from utils import assert_arrays_close, general_group_test, sample_transform 10 | 11 | 12 | def autodiff_jlog(group_element) -> jnp.ndarray: 13 | """ 14 | Compute the Jacobian of the logarithm map for a Lie group element using automatic differentiation. 15 | 16 | Args: 17 | group_element (Union[SO2, SO3, SE2, SE3]): A Lie group element. 18 | 19 | Returns: 20 | jnp.ndarray: The Jacobian matrix. 21 | """ 22 | 23 | def wrapped_function(tau): 24 | Group = type(group_element) 25 | return (group_element @ Group.exp(tau)).log() 26 | 27 | return jax.jacobian(wrapped_function)(jnp.zeros(group_element.tangent_dim)) 28 | 29 | 30 | def analytical_jlog(group_element) -> jnp.ndarray: 31 | """ 32 | Analytical computation of the Jacobian of the logarithm map for a Lie group element. 33 | 34 | Args: 35 | group_element (Union[SO2, SO3, SE2, SE3]): A Lie group element. 36 | 37 | Returns: 38 | jnp.ndarray: The Jacobian matrix. 
39 | """ 40 | return group_element.jlog() 41 | 42 | 43 | @general_group_test 44 | def test_jlog_accuracy(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]): 45 | """Check accuracy of analytical jlog against autodiff jlog.""" 46 | transform = sample_transform(Group, batch_axes) 47 | 48 | # Create jitted versions of both functions. 49 | jitted_autodiff = jax.jit(autodiff_jlog) 50 | jitted_analytical = jax.jit(analytical_jlog) 51 | 52 | # Get results from both implementations. 53 | result_analytical = jitted_analytical(transform) 54 | result_autodiff = jitted_autodiff(transform) 55 | 56 | # Compare results with appropriate tolerance. 57 | assert_arrays_close(result_analytical, result_autodiff, rtol=1e-5, atol=1e-5) 58 | 59 | 60 | @partial(general_group_test, max_examples=1) 61 | def test_jlog_runtime(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]): 62 | """Compare runtime of analytical jlog and autodiff jlog.""" 63 | if Group is jaxlie.SO2: 64 | # Skip SO(2) since it has a trivial Jacobian. 65 | return 66 | 67 | transform = sample_transform(Group, batch_axes) 68 | 69 | # JIT compile both functions. 70 | jitted_autodiff = jax.jit(autodiff_jlog) 71 | jitted_analytical = jax.jit(analytical_jlog) 72 | 73 | # Warm-up run to ensure compilation happens before timing. 74 | jax.block_until_ready(jitted_autodiff(transform)) 75 | jax.block_until_ready(jitted_analytical(transform)) 76 | 77 | # Create a new transform for timing. 78 | num_runs = 30 79 | 80 | # Time autodiff implementation. 81 | times = [] 82 | for _ in range(num_runs): 83 | transform = jax.block_until_ready(sample_transform(Group, batch_axes)) 84 | start = time.perf_counter() 85 | result = jitted_autodiff(transform) 86 | result = jax.block_until_ready(result) # Wait for all operations to complete. 87 | times.append(time.perf_counter() - start) 88 | autodiff_runtime = min(times) * 1000 # Convert to ms. 89 | 90 | # Time analytical implementation. 
91 | times = [] 92 | for _ in range(num_runs): 93 | transform = jax.block_until_ready(sample_transform(Group, batch_axes)) 94 | start = time.perf_counter() 95 | result = jitted_analytical(transform) 96 | result = jax.block_until_ready(result) # Wait for all operations to complete. 97 | times.append(time.perf_counter() - start) 98 | analytical_runtime = min(times) * 1000 # Convert to ms. 99 | 100 | assert ( 101 | analytical_runtime <= autodiff_runtime 102 | ), f"Autodiff jlog is slower than analytical jlog: {analytical_runtime:.2f}ms vs {autodiff_runtime:.2f}ms" 103 | -------------------------------------------------------------------------------- /tests/test_manifold.py: -------------------------------------------------------------------------------- 1 | """Test manifold helpers.""" 2 | 3 | from typing import Tuple, Type 4 | 5 | import jax 6 | import jaxlie 7 | import numpy as onp 8 | import pytest 9 | from jax import numpy as jnp 10 | from jax import tree_util 11 | 12 | from utils import ( 13 | assert_arrays_close, 14 | assert_transforms_close, 15 | general_group_test, 16 | general_group_test_faster, 17 | sample_transform, 18 | ) 19 | 20 | 21 | @general_group_test 22 | def test_rplus_rminus(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]): 23 | """Check rplus and rminus on random inputs.""" 24 | T_wa = sample_transform(Group, batch_axes) 25 | T_wb = sample_transform(Group, batch_axes) 26 | T_ab = T_wa.inverse() @ T_wb 27 | 28 | assert_transforms_close(jaxlie.manifold.rplus(T_wa, T_ab.log()), T_wb) 29 | assert_arrays_close(jaxlie.manifold.rminus(T_wa, T_wb), T_ab.log()) 30 | 31 | 32 | @general_group_test 33 | def test_rplus_jacobian( 34 | Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...] 
35 | ): 36 | """Check analytical rplus Jacobian..""" 37 | T_wa = sample_transform(Group, batch_axes) 38 | 39 | J_ours = jaxlie.manifold.rplus_jacobian_parameters_wrt_delta(T_wa) 40 | 41 | if batch_axes == (): 42 | J_jacfwd = _rplus_jacobian_parameters_wrt_delta(T_wa) 43 | assert_arrays_close(J_ours, J_jacfwd) 44 | else: 45 | # Batch axes should match vmap. 46 | jacfunc = jaxlie.manifold.rplus_jacobian_parameters_wrt_delta 47 | for _ in batch_axes: 48 | jacfunc = jax.vmap(jacfunc) 49 | J_vmap = jacfunc(T_wa) 50 | assert_arrays_close(J_ours, J_vmap) 51 | 52 | 53 | @jax.jit 54 | def _rplus_jacobian_parameters_wrt_delta( 55 | transform: jaxlie.MatrixLieGroup, 56 | ) -> jax.Array: 57 | # Copied from docstring for `rplus_jacobian_parameters_wrt_delta()`. 58 | return jax.jacfwd( 59 | lambda delta: jaxlie.manifold.rplus(transform, delta).parameters() 60 | )(onp.zeros(transform.tangent_dim)) 61 | 62 | 63 | @general_group_test_faster 64 | def test_sgd(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]): 65 | def loss(transform: jaxlie.MatrixLieGroup): 66 | return (transform.log() ** 2).sum() 67 | 68 | transform = Group.exp(sample_transform(Group, batch_axes).log()) 69 | original_loss = loss(transform) 70 | 71 | assert_arrays_close( 72 | jaxlie.manifold.grad(lambda transform: (loss(transform), None), has_aux=True)( 73 | transform 74 | )[0], 75 | jaxlie.manifold.grad(loss)(transform), 76 | ) 77 | 78 | @jax.jit 79 | def step(t): 80 | return jaxlie.manifold.rplus(t, -1e-3 * jaxlie.manifold.grad(loss)(t)) 81 | 82 | for i in range(5): 83 | transform = step(transform) 84 | 85 | assert loss(transform) < original_loss 86 | 87 | 88 | def test_rplus_euclidean(): 89 | assert_arrays_close( 90 | jaxlie.manifold.rplus(jnp.ones(2), jnp.ones(2)), 2 * jnp.ones(2) 91 | ) 92 | 93 | 94 | def test_rminus_auto_vmap(): 95 | deltas = jaxlie.manifold.rminus( 96 | tree_util.tree_map( 97 | lambda *args: jnp.stack(args), 98 | [jaxlie.SE3.sample_uniform(jax.random.PRNGKey(0)), 
jaxlie.SE3.identity()], 99 | ), 100 | tree_util.tree_map( 101 | lambda *args: jnp.stack(args), 102 | [jaxlie.SE3.identity(), jaxlie.SE3.sample_uniform(jax.random.PRNGKey(0))], 103 | ), 104 | ) 105 | assert_arrays_close(deltas[0], -deltas[1]) 106 | 107 | 108 | def test_normalize(): 109 | container = {"key": (jaxlie.SO3(jnp.array([2.0, 0.0, 0.0, 0.0])),)} 110 | container_valid = {"key": (jaxlie.SO3(jnp.array([1.0, 0.0, 0.0, 0.0])),)} 111 | with pytest.raises(AssertionError): 112 | assert_transforms_close(container["key"][0], container_valid["key"][0]) 113 | assert_transforms_close( 114 | jaxlie.manifold.normalize_all(container)["key"][0], container_valid["key"][0] 115 | ) 116 | -------------------------------------------------------------------------------- /jaxlie/manifold/_backprop.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any, Callable, Sequence, Tuple, Union, overload 4 | 5 | import jax 6 | from jax import numpy as jnp 7 | from typing_extensions import ParamSpec 8 | 9 | from .._base import MatrixLieGroup 10 | from . import _deltas, _tree_utils 11 | 12 | 13 | def zero_tangents(pytree: Any) -> _tree_utils.TangentPytree: 14 | """Replace all values in a Pytree with zero vectors on the corresponding tangent 15 | spaces.""" 16 | 17 | def tangent_zero(t: MatrixLieGroup) -> jax.Array: 18 | return jnp.zeros(t.get_batch_axes() + (t.tangent_dim,)) 19 | 20 | return _tree_utils._map_group_trees( 21 | tangent_zero, 22 | lambda array: jnp.zeros_like(array), 23 | pytree, 24 | ) 25 | 26 | 27 | AxisName = Any 28 | 29 | P = ParamSpec("P") 30 | 31 | 32 | @overload 33 | def grad( 34 | fun: Callable[P, Any], 35 | argnums: int = 0, 36 | has_aux: bool = False, 37 | holomorphic: bool = False, 38 | allow_int: bool = False, 39 | reduce_axes: Sequence[AxisName] = (), 40 | ) -> Callable[P, _tree_utils.TangentPytree]: ... 
41 | 42 | 43 | @overload 44 | def grad( 45 | fun: Callable[P, Any], 46 | argnums: Sequence[int], 47 | has_aux: bool = False, 48 | holomorphic: bool = False, 49 | allow_int: bool = False, 50 | reduce_axes: Sequence[AxisName] = (), 51 | ) -> Callable[P, Tuple[_tree_utils.TangentPytree, ...]]: ... 52 | 53 | 54 | def grad( 55 | fun: Callable[P, Any], 56 | argnums: Union[int, Sequence[int]] = 0, 57 | has_aux: bool = False, 58 | holomorphic: bool = False, 59 | allow_int: bool = False, 60 | reduce_axes: Sequence[AxisName] = (), 61 | ): 62 | """Same as `jax.grad`, but computes gradients of Lie groups with respect to 63 | tangent spaces.""" 64 | 65 | compute_value_and_grad = value_and_grad( 66 | fun=fun, 67 | argnums=argnums, 68 | has_aux=has_aux, 69 | holomorphic=holomorphic, 70 | allow_int=allow_int, 71 | reduce_axes=reduce_axes, 72 | ) 73 | 74 | def grad_fun(*args, **kwargs): 75 | ret = compute_value_and_grad(*args, **kwargs) 76 | if has_aux: 77 | return ret[1], ret[0][1] 78 | else: 79 | return ret[1] 80 | 81 | return grad_fun 82 | 83 | 84 | @overload 85 | def value_and_grad( 86 | fun: Callable[P, Any], 87 | argnums: int = 0, 88 | has_aux: bool = False, 89 | holomorphic: bool = False, 90 | allow_int: bool = False, 91 | reduce_axes: Sequence[AxisName] = (), 92 | ) -> Callable[P, Tuple[Any, _tree_utils.TangentPytree]]: ... 93 | 94 | 95 | @overload 96 | def value_and_grad( 97 | fun: Callable[P, Any], 98 | argnums: Sequence[int], 99 | has_aux: bool = False, 100 | holomorphic: bool = False, 101 | allow_int: bool = False, 102 | reduce_axes: Sequence[AxisName] = (), 103 | ) -> Callable[P, Tuple[Any, Tuple[_tree_utils.TangentPytree, ...]]]: ... 
104 | 105 | 106 | def value_and_grad( 107 | fun: Callable[P, Any], 108 | argnums: Union[int, Sequence[int]] = 0, 109 | has_aux: bool = False, 110 | holomorphic: bool = False, 111 | allow_int: bool = False, 112 | reduce_axes: Sequence[AxisName] = (), 113 | ): 114 | """Same as `jax.value_and_grad`, but computes gradients of Lie groups with respect to 115 | tangent spaces.""" 116 | 117 | def wrapped_grad(*args, **kwargs): 118 | def tangent_fun(*tangent_args, **tangent_kwargs): 119 | return fun( # type: ignore 120 | *_deltas.rplus(args, tangent_args), 121 | **_deltas.rplus(kwargs, tangent_kwargs), 122 | ) 123 | 124 | # Put arguments onto tangent space. 125 | tangent_args = map(zero_tangents, args) 126 | tangent_kwargs = {k: zero_tangents(v) for k, v in kwargs.items()} 127 | 128 | return jax.value_and_grad( 129 | fun=tangent_fun, 130 | argnums=argnums, 131 | has_aux=has_aux, 132 | holomorphic=holomorphic, 133 | allow_int=allow_int, 134 | reduce_axes=reduce_axes, 135 | )(*tangent_args, **tangent_kwargs) 136 | 137 | return wrapped_grad # type: ignore 138 | -------------------------------------------------------------------------------- /examples/vmap_example.py: -------------------------------------------------------------------------------- 1 | """jaxlie implements numpy-style broadcasting for all operations. For more 2 | explicit vectorization, we can also use vmap function transformations. 3 | 4 | Omitted for brevity here, but in practice we usually want to JIT after 5 | vmapping.""" 6 | 7 | import jax 8 | import numpy as onp 9 | 10 | from jaxlie import SO3 11 | 12 | N = 100 13 | 14 | ############################# 15 | # (1) Setup. 16 | ############################# 17 | 18 | # We start by creating two rotation objects: 19 | # - R_single contains a standard single rotation. 20 | # - R_stacked contained `N` rotations stacked together! 
Note that all Lie group objects 21 | # are PyTrees, so this has the same structure as R_single but with a batch axis in the 22 | # contained parameters array. 23 | 24 | R_single = SO3.from_x_radians(onp.pi / 2.0) 25 | assert R_single.wxyz.shape == (4,) 26 | 27 | R_stacked = jax.vmap(SO3.from_x_radians)( 28 | onp.random.uniform(low=-onp.pi, high=onp.pi, size=(N,)) 29 | ) 30 | assert R_stacked.wxyz.shape == (N, 4) 31 | 32 | # We can also create two arrays containing points: one is a single point, the other is 33 | # `N` points stacked. 34 | p_single = onp.random.uniform(size=(3,)) 35 | p_stacked = onp.random.uniform(size=(N, 3)) 36 | 37 | ############################# 38 | # (2) Applying 1 transformation to 1 point. 39 | ############################# 40 | 41 | # Recall that these two approaches to transforming a point: 42 | p_transformed_single = R_single @ p_single 43 | assert p_transformed_single.shape == (3,) 44 | p_transformed_single = R_single.apply(p_single) 45 | assert p_transformed_single.shape == (3,) 46 | 47 | # Are just syntactic sugar for calling: 48 | p_transformed_single = SO3.apply(R_single, p_single) 49 | assert p_transformed_single.shape == (3,) 50 | 51 | 52 | ############################# 53 | # (3) Applying 1 transformation to N points. 54 | ############################# 55 | 56 | # This follows standard vmap semantics! 57 | p_transformed_stacked = jax.vmap(R_single.apply)(p_stacked) 58 | assert p_transformed_stacked.shape == (N, 3) 59 | 60 | # Note that this is equivalent to: 61 | p_transformed_stacked = jax.vmap(lambda p: SO3.apply(R_single, p))(p_stacked) 62 | assert p_transformed_stacked.shape == (N, 3) 63 | 64 | # We can also just rely on broadcasting. 65 | p_transformed_stacked = R_single @ p_stacked 66 | assert p_transformed_stacked.shape == (N, 3) 67 | 68 | ############################# 69 | # (4) Applying N transformations to N points. 
70 | ############################# 71 | 72 | # R_stacked and p_stacked both have an (N,) batch dimension compared to their "single" 73 | # counterparts. We can therefore vmap over both arguments of SO3.apply: 74 | p_transformed_stacked = jax.vmap(SO3.apply)(R_stacked, p_stacked) 75 | assert p_transformed_stacked.shape == (N, 3) 76 | 77 | # We can also just rely on broadcasting. 78 | p_transformed_stacked = R_stacked @ p_stacked 79 | assert p_transformed_stacked.shape == (N, 3) 80 | 81 | ############################# 82 | # (5) Applying N transformations to 1 point. 83 | ############################# 84 | 85 | p_transformed_stacked = jax.vmap(lambda R: SO3.apply(R, p_single))(R_stacked) 86 | assert p_transformed_stacked.shape == (N, 3) 87 | 88 | # We can also just rely on broadcasting. 89 | p_transformed_stacked = R_stacked @ p_single[None, :] 90 | assert p_transformed_stacked.shape == (N, 3) 91 | 92 | ############################# 93 | # (6) Multiplying transformations. 94 | ############################# 95 | 96 | # The same concepts as above apply to other operations! 97 | # For multiplication, these are all the same: 98 | assert (R_single @ R_single).wxyz.shape == (4,) 99 | assert (R_single.multiply(R_single)).wxyz.shape == (4,) 100 | assert (SO3.multiply(R_single, R_single)).wxyz.shape == (4,) 101 | 102 | # And therefore we can also do 1 x N multiplication: 103 | assert (jax.vmap(R_single.multiply)(R_stacked)).wxyz.shape == (N, 4) 104 | assert (jax.vmap(lambda R: SO3.multiply(R_single, R))(R_stacked)).wxyz.shape == (N, 4) 105 | 106 | # Or N x N multiplication: 107 | assert (jax.vmap(SO3.multiply)(R_stacked, R_stacked)).wxyz.shape == (N, 4) 108 | 109 | # Or N x 1 multiplication: 110 | assert (jax.vmap(lambda R: SO3.multiply(R, R_single))(R_stacked)).wxyz.shape == (N, 4) 111 | 112 | # Again, broadcasting also works. 
113 | assert (R_stacked @ R_stacked).wxyz.shape == (N, 4) 114 | assert (R_stacked @ SO3(R_single.wxyz[None, :])).wxyz.shape == (N, 4) 115 | -------------------------------------------------------------------------------- /tests/test_jac_left.py: -------------------------------------------------------------------------------- 1 | """Test left jacobian functions for SO2 and SO3 groups. These are submatrices 2 | of the left Jacobian of the Lie group and its inverse respectively. We can 3 | check our analytical implementations against autodiff.""" 4 | 5 | from typing import Callable, Dict, Tuple, Type 6 | 7 | import jax 8 | import jax.numpy as jnp 9 | import jaxlie 10 | 11 | from utils import assert_arrays_close, general_group_test, sample_transform 12 | 13 | # Dictionary mapping group classes to their corresponding left jacobian functions. 14 | _V_FUNCS: Dict[Type[jaxlie.MatrixLieGroup], Tuple[Callable, Callable]] = { 15 | jaxlie.SE2: ( 16 | jax.jit(jaxlie._se2._SE2_jac_left), 17 | jax.jit(jaxlie._se2._SE2_jac_left_inv), 18 | ), 19 | jaxlie.SO3: ( 20 | jax.jit(jaxlie._so3._SO3_jac_left), 21 | jax.jit(jaxlie._so3._SO3_jac_left_inv), 22 | ), 23 | } 24 | 25 | 26 | # Autodiff versions of left jacobian functions. We could very reasonably use these 27 | # directly in jaxlie, but the analytical versions give us a bit more control for 28 | # things like Taylor expansion. In the future we might be able to handle that 29 | # automatically with jet types though: 30 | # https://docs.jax.dev/en/latest/jax.experimental.jet.html 31 | # 32 | # For these autodiff implementations: 33 | # > https://arxiv.org/pdf/1812.01537 34 | @jax.jit 35 | def compute_autodiff_jac_left(transform: jaxlie.MatrixLieGroup): 36 | def left_plus(tangent_at_identity): 37 | Group = type(transform) 38 | return (Group.exp(tangent_at_identity) @ transform).log() 39 | 40 | # Jacobian of tangent at `transform` wrt tangent at identity. 
41 | pullback = jax.jacrev(left_plus)(jnp.zeros(transform.tangent_dim)) 42 | 43 | # The pushforward is the left Jacobian. This transforms tangent vectors at 44 | # identity to tangent vectors at `transform`. 45 | pushforward = jnp.linalg.inv(pullback) 46 | return pushforward 47 | 48 | 49 | compute_autodiff_jac_left_vmap = jax.jit(jax.vmap(compute_autodiff_jac_left)) 50 | 51 | 52 | @jax.jit 53 | def compute_autodiff_jac_left_inv(transform: jaxlie.MatrixLieGroup): 54 | def left_plus(tangent_at_identity): 55 | Group = type(transform) 56 | return (Group.exp(tangent_at_identity) @ transform).log() 57 | 58 | pullback = jax.jacrev(left_plus)(jnp.zeros(transform.tangent_dim)) 59 | return pullback 60 | 61 | 62 | compute_autodiff_jac_left_inv_vmap = jax.jit(jax.vmap(compute_autodiff_jac_left_inv)) 63 | 64 | 65 | @general_group_test 66 | def test_jac_left_autodiff( 67 | Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...] 68 | ): 69 | """Test left jacobian and its inverse against automatic differentiation.""" 70 | # Skip groups that don't have left jacobian functions. 71 | if Group not in _V_FUNCS: 72 | return 73 | 74 | # Create identity transform with appropriate batch shape. 75 | transform = sample_transform(Group, batch_axes) 76 | theta = transform.log() 77 | 78 | # For SE2, the input should be just the rotation part. 79 | if Group is jaxlie.SE2: 80 | theta = theta[..., 2:3] 81 | 82 | # Compute _V_inv using the implementation. 83 | V, V_inv = _V_FUNCS[Group] 84 | analytical_V_inv = V_inv(theta) 85 | if Group is jaxlie.SO3: 86 | analytical_V = V(theta, transform.as_matrix()) 87 | else: 88 | analytical_V = V(theta) 89 | assert_arrays_close( 90 | jnp.linalg.inv(analytical_V_inv), analytical_V, rtol=1e-5, atol=1e-5 91 | ) 92 | 93 | # Compute _V_inv using autodiff. 
94 | autodiff_jac_left_inv = ( 95 | compute_autodiff_jac_left_inv(transform) 96 | if len(transform.get_batch_axes()) == 0 97 | else compute_autodiff_jac_left_inv_vmap(transform) 98 | ) 99 | autodiff_jac_left = ( 100 | compute_autodiff_jac_left(transform) 101 | if len(transform.get_batch_axes()) == 0 102 | else compute_autodiff_jac_left_vmap(transform) 103 | ) 104 | 105 | # For SE2, the output should be just the translation part. 106 | if Group is jaxlie.SE2: 107 | autodiff_jac_left = autodiff_jac_left[..., :2, :2] 108 | autodiff_jac_left_inv = autodiff_jac_left_inv[..., :2, :2] 109 | 110 | # Compare the results. 111 | assert_arrays_close(analytical_V, autodiff_jac_left, rtol=1e-5, atol=1e-5) 112 | assert_arrays_close(analytical_V_inv, autodiff_jac_left_inv, rtol=1e-5, atol=1e-5) 113 | -------------------------------------------------------------------------------- /jaxlie/_so2.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Tuple 4 | 5 | import jax 6 | import jax_dataclasses as jdc 7 | from jax import numpy as jnp 8 | from typing_extensions import override 9 | 10 | from . import _base, hints 11 | from .utils import broadcast_leading_axes, register_lie_group 12 | 13 | 14 | @register_lie_group( 15 | matrix_dim=2, 16 | parameters_dim=2, 17 | tangent_dim=1, 18 | space_dim=2, 19 | ) 20 | @jdc.pytree_dataclass 21 | class SO2(_base.SOBase): 22 | """Special orthogonal group for 2D rotations. Broadcasting rules are the 23 | same as for `numpy`. 24 | 25 | Internal parameterization is `(cos, sin)`. Tangent parameterization is `(omega,)`. 26 | """ 27 | 28 | # SO2-specific. 29 | 30 | unit_complex: jax.Array 31 | """Internal parameters. `(cos, sin)`. 
Shape should be `(*, 2)`.""" 32 | 33 | @override 34 | def __repr__(self) -> str: 35 | unit_complex = jnp.round(self.unit_complex, 5) 36 | return f"{self.__class__.__name__}(unit_complex={unit_complex})" 37 | 38 | @staticmethod 39 | def from_radians(theta: hints.Scalar) -> SO2: 40 | """Construct a rotation object from a scalar angle.""" 41 | cos = jnp.cos(theta) 42 | sin = jnp.sin(theta) 43 | return SO2(unit_complex=jnp.stack([cos, sin], axis=-1)) 44 | 45 | def as_radians(self) -> jax.Array: 46 | """Compute a scalar angle from a rotation object.""" 47 | radians = self.log()[..., 0] 48 | return radians 49 | 50 | # Factory. 51 | 52 | @classmethod 53 | @override 54 | def identity(cls, batch_axes: jdc.Static[Tuple[int, ...]] = ()) -> SO2: 55 | return SO2( 56 | unit_complex=jnp.stack( 57 | [jnp.ones(batch_axes), jnp.zeros(batch_axes)], axis=-1 58 | ) 59 | ) 60 | 61 | @classmethod 62 | @override 63 | def from_matrix(cls, matrix: hints.Array) -> SO2: 64 | assert matrix.shape[-2:] == (2, 2) 65 | return SO2(unit_complex=jnp.asarray(matrix[..., :, 0])) 66 | 67 | # Accessors. 68 | 69 | @override 70 | def as_matrix(self) -> jax.Array: 71 | cos_sin = self.unit_complex 72 | out = jnp.stack( 73 | [ 74 | # [cos, -sin], 75 | cos_sin * jnp.array([1, -1]), 76 | # [sin, cos], 77 | cos_sin[..., ::-1], 78 | ], 79 | axis=-2, 80 | ) 81 | assert out.shape == (*self.get_batch_axes(), 2, 2) 82 | return out 83 | 84 | @override 85 | def parameters(self) -> jax.Array: 86 | return self.unit_complex 87 | 88 | # Operations. 
89 | 90 | @override 91 | def apply(self, target: hints.Array) -> jax.Array: 92 | assert target.shape[-1:] == (2,) 93 | self, target = broadcast_leading_axes((self, target)) 94 | return jnp.einsum("...ij,...j->...i", self.as_matrix(), target) 95 | 96 | @override 97 | def multiply(self, other: SO2) -> SO2: 98 | return SO2( 99 | unit_complex=jnp.einsum( 100 | "...ij,...j->...i", self.as_matrix(), other.unit_complex 101 | ) 102 | ) 103 | 104 | @classmethod 105 | @override 106 | def exp(cls, tangent: hints.Array) -> SO2: 107 | assert tangent.shape[-1] == 1 108 | cos = jnp.cos(tangent) 109 | sin = jnp.sin(tangent) 110 | return SO2(unit_complex=jnp.concatenate([cos, sin], axis=-1)) 111 | 112 | @override 113 | def log(self) -> jax.Array: 114 | return jnp.arctan2( 115 | self.unit_complex[..., 1, None], self.unit_complex[..., 0, None] 116 | ) 117 | 118 | @override 119 | def adjoint(self) -> jax.Array: 120 | return jnp.ones((*self.get_batch_axes(), 1, 1)) 121 | 122 | @override 123 | def inverse(self) -> SO2: 124 | return SO2(unit_complex=self.unit_complex * jnp.array([1, -1])) 125 | 126 | @override 127 | def normalize(self) -> SO2: 128 | return SO2( 129 | unit_complex=self.unit_complex 130 | / jnp.linalg.norm(self.unit_complex, axis=-1, keepdims=True) 131 | ) 132 | 133 | @override 134 | def jlog(self) -> jax.Array: 135 | batch_axes = self.get_batch_axes() 136 | ones = jnp.ones(batch_axes) 137 | return ones[..., None, None] 138 | 139 | @classmethod 140 | @override 141 | def sample_uniform( 142 | cls, key: jax.Array, batch_axes: jdc.Static[Tuple[int, ...]] = () 143 | ) -> SO2: 144 | out = SO2.from_radians( 145 | jax.random.uniform( 146 | key=key, shape=batch_axes, minval=0.0, maxval=2.0 * jnp.pi 147 | ) 148 | ) 149 | assert out.get_batch_axes() == batch_axes 150 | return out 151 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | import functools 
import random
from typing import Any, Callable, List, Tuple, Type, TypeVar, cast

import jax
import numpy as onp
import pytest
import scipy.optimize
from hypothesis import given, settings
from hypothesis import strategies as st
from jax import numpy as jnp

import jaxlie

# Run all tests with double-precision.
jax.config.update("jax_enable_x64", True)

T = TypeVar("T", bound=jaxlie.MatrixLieGroup)


def sample_transform(Group: Type[T], batch_axes: Tuple[int, ...] = ()) -> T:
    """Sample a random transform from a group.

    Picks one of three strategies at random: uniform sampling, exp() of a
    normally-sampled tangent vector, or exp() of a near-zero tangent vector.
    """
    seed = random.getrandbits(32)
    strategy = random.randint(0, 2)

    if strategy == 0:
        # Uniform sampling.
        return cast(
            T,
            Group.sample_uniform(
                key=jax.random.PRNGKey(seed=seed), batch_axes=batch_axes
            ),
        )
    elif strategy == 1:
        # Sample from normally-sampled tangent vector.
        return cast(T, Group.exp(onp.random.randn(*batch_axes, Group.tangent_dim)))
    elif strategy == 2:
        # Sample near identity.
        return cast(
            T, Group.exp(onp.random.randn(*batch_axes, Group.tangent_dim) * 1e-7)
        )
    else:
        assert False


def general_group_test(
    f: Callable[[Type[jaxlie.MatrixLieGroup], Tuple[int, ...]], None],
    max_examples: int = 30,
) -> Callable[[Type[jaxlie.MatrixLieGroup], Tuple[int, ...], Any], None]:
    """Decorator for defining tests that run on all group types."""

    # Disregard unused argument (`_random_module` only reseeds global RNG state).
    def f_wrapped(
        Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...], _random_module
    ) -> None:
        f(Group, batch_axes)

    # Disable timing check (first run requires JIT tracing and will be slower).
    f_wrapped = settings(deadline=None, max_examples=max_examples)(f_wrapped)

    # Add _random_module parameter.
    f_wrapped = given(_random_module=st.random_module())(f_wrapped)

    # Parametrize tests with each group type.
    f_wrapped = pytest.mark.parametrize(
        "Group",
        [
            jaxlie.SO2,
            jaxlie.SE2,
            jaxlie.SO3,
            jaxlie.SE3,
        ],
    )(f_wrapped)

    # Parametrize tests with each batch-axes configuration.
    f_wrapped = pytest.mark.parametrize(
        "batch_axes",
        [
            (),
            (1,),
            (3, 1, 2, 1),
        ],
    )(f_wrapped)
    return f_wrapped


general_group_test_faster = functools.partial(general_group_test, max_examples=5)


def assert_transforms_close(a: jaxlie.MatrixLieGroup, b: jaxlie.MatrixLieGroup):
    """Make sure two transforms are equivalent."""
    # Check matrix representation.
    assert_arrays_close(a.as_matrix(), b.as_matrix())

    # Flip signs for quaternions: q and -q represent the same rotation.
    # We use `jnp.asarray` here in case inputs are onp arrays and don't support `.at()`.
    p1 = jnp.asarray(a.parameters())
    p2 = jnp.asarray(b.parameters())
    if isinstance(a, jaxlie.SO3):
        p1 = p1 * jnp.sign(jnp.sum(p1, axis=-1, keepdims=True))
        p2 = p2 * jnp.sign(jnp.sum(p2, axis=-1, keepdims=True))
    elif isinstance(a, jaxlie.SE3):
        # Only the first four parameters (the quaternion) are sign-ambiguous.
        p1 = p1.at[..., :4].mul(jnp.sign(jnp.sum(p1[..., :4], axis=-1, keepdims=True)))
        p2 = p2.at[..., :4].mul(jnp.sign(jnp.sum(p2[..., :4], axis=-1, keepdims=True)))

    # Make sure parameters are equal.
    assert_arrays_close(p1, p2)


def assert_arrays_close(
    *arrays: jaxlie.hints.Array,
    rtol: float = 1e-8,
    atol: float = 1e-8,
):
    """Make sure arrays are close (and not NaN); each consecutive pair is compared."""
    for array1, array2 in zip(arrays[:-1], arrays[1:]):
        onp.testing.assert_allclose(array1, array2, rtol=rtol, atol=atol)
        assert not onp.any(onp.isnan(array1))
        assert not onp.any(onp.isnan(array2))


def jacnumerical(
    f: Callable[[jaxlie.hints.Array], jax.Array],
) -> Callable[[jaxlie.hints.Array], jax.Array]:
    """Decorator for computing numerical Jacobians of vector->vector functions."""

    def wrapped(primal: jaxlie.hints.Array) -> jax.Array:
        output_dim: int
        (output_dim,) = f(primal).shape

        # Build the Jacobian one row at a time via finite differences.
        jacobian_rows: List[onp.ndarray] = []
        for i in range(output_dim):
            jacobian_row: onp.ndarray = scipy.optimize.approx_fprime(
                primal, lambda p: f(p)[i], epsilon=1e-5
            )
            assert jacobian_row.shape == primal.shape
            jacobian_rows.append(jacobian_row)

        return jnp.stack(jacobian_rows, axis=0)

    return wrapped
--------------------------------------------------------------------------------
/tests/test_operations.py:
--------------------------------------------------------------------------------
"""Tests for general operation definitions."""

from typing import Tuple, Type

import numpy as onp
from hypothesis import given, settings
from hypothesis import strategies as st
from jax import numpy as jnp
from utils import (
    assert_arrays_close,
    assert_transforms_close,
    general_group_test,
    sample_transform,
)

import jaxlie


@general_group_test
def test_sample_uniform_valid(
    Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]
):
    """Check that sample_uniform() returns valid group members."""
    T = sample_transform(Group, batch_axes)  # Calls sample_uniform under the hood.
25 | assert_transforms_close(T, T.normalize()) 26 | 27 | 28 | @settings(deadline=None) 29 | @given(_random_module=st.random_module()) 30 | def test_so2_from_to_radians_bijective(_random_module): 31 | """Check that we can convert from and to radians.""" 32 | radians = onp.random.uniform(low=-onp.pi, high=onp.pi) 33 | assert_arrays_close(jaxlie.SO2.from_radians(radians).as_radians(), radians) 34 | 35 | 36 | @settings(deadline=None) 37 | @given(_random_module=st.random_module()) 38 | def test_so3_xyzw_bijective(_random_module): 39 | """Check that we can convert between xyzw and wxyz quaternions.""" 40 | T = sample_transform(jaxlie.SO3) 41 | assert_transforms_close(T, jaxlie.SO3.from_quaternion_xyzw(T.as_quaternion_xyzw())) 42 | 43 | 44 | @settings(deadline=None) 45 | @given(_random_module=st.random_module()) 46 | def test_so3_rpy_bijective(_random_module): 47 | """Check that we can convert between quaternions and Euler angles.""" 48 | T = sample_transform(jaxlie.SO3) 49 | assert_transforms_close(T, jaxlie.SO3.from_rpy_radians(*T.as_rpy_radians())) 50 | 51 | 52 | @general_group_test 53 | def test_log_exp_bijective( 54 | Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...] 55 | ): 56 | """Check 1-to-1 mapping for log <=> exp operations.""" 57 | transform = sample_transform(Group, batch_axes) 58 | 59 | tangent = transform.log() 60 | assert tangent.shape == (*batch_axes, Group.tangent_dim) 61 | 62 | exp_transform = Group.exp(tangent) 63 | assert_transforms_close(transform, exp_transform) 64 | assert_arrays_close(tangent, exp_transform.log()) 65 | 66 | 67 | @general_group_test 68 | def test_inverse_bijective( 69 | Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...] 
70 | ): 71 | """Check inverse of inverse.""" 72 | transform = sample_transform(Group, batch_axes) 73 | assert_transforms_close(transform, transform.inverse().inverse()) 74 | 75 | 76 | @general_group_test 77 | def test_matrix_bijective( 78 | Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...] 79 | ): 80 | """Check that we can convert to and from matrices.""" 81 | transform = sample_transform(Group, batch_axes) 82 | assert_transforms_close(transform, Group.from_matrix(transform.as_matrix())) 83 | 84 | 85 | @general_group_test 86 | def test_adjoint(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]): 87 | """Check adjoint definition.""" 88 | transform = sample_transform(Group, batch_axes) 89 | omega = onp.random.randn(*batch_axes, Group.tangent_dim) 90 | assert_transforms_close( 91 | transform @ Group.exp(omega), 92 | Group.exp(onp.einsum("...ij,...j->...i", transform.adjoint(), omega)) 93 | @ transform, 94 | ) 95 | 96 | 97 | @general_group_test 98 | def test_repr(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]): 99 | """Smoke test for __repr__ implementations.""" 100 | transform = sample_transform(Group, batch_axes) 101 | print(transform) 102 | 103 | 104 | @general_group_test 105 | def test_apply(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]): 106 | """Check group action interfaces.""" 107 | T_w_b = sample_transform(Group, batch_axes) 108 | p_b = onp.random.randn(*batch_axes, Group.space_dim) 109 | 110 | if Group.matrix_dim == Group.space_dim: 111 | assert_arrays_close( 112 | T_w_b @ p_b, 113 | T_w_b.apply(p_b), 114 | onp.einsum("...ij,...j->...i", T_w_b.as_matrix(), p_b), 115 | ) 116 | else: 117 | # Homogeneous coordinates. 
118 | assert Group.matrix_dim == Group.space_dim + 1 119 | assert_arrays_close( 120 | T_w_b @ p_b, 121 | T_w_b.apply(p_b), 122 | onp.einsum( 123 | "...ij,...j->...i", 124 | T_w_b.as_matrix(), 125 | onp.concatenate([p_b, onp.ones_like(p_b[..., :1])], axis=-1), 126 | )[..., :-1], 127 | ) 128 | 129 | 130 | @general_group_test 131 | def test_multiply(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]): 132 | """Check multiply interfaces.""" 133 | T_w_b = sample_transform(Group, batch_axes) 134 | T_b_a = sample_transform(Group, batch_axes) 135 | assert_arrays_close( 136 | onp.einsum( 137 | "...ij,...jk->...ik", T_w_b.as_matrix(), T_w_b.inverse().as_matrix() 138 | ), 139 | onp.broadcast_to( 140 | onp.eye(Group.matrix_dim), (*batch_axes, Group.matrix_dim, Group.matrix_dim) 141 | ), 142 | ) 143 | assert_arrays_close( 144 | onp.einsum( 145 | "...ij,...jk->...ik", T_w_b.as_matrix(), jnp.linalg.inv(T_w_b.as_matrix()) 146 | ), 147 | onp.broadcast_to( 148 | onp.eye(Group.matrix_dim), (*batch_axes, Group.matrix_dim, Group.matrix_dim) 149 | ), 150 | ) 151 | assert_transforms_close(T_w_b @ T_b_a, Group.multiply(T_w_b, T_b_a)) 152 | -------------------------------------------------------------------------------- /jaxlie/manifold/_deltas.py: -------------------------------------------------------------------------------- 1 | """Helpers for recursively applying tangent-space deltas.""" 2 | 3 | from typing import Any, TypeVar, Union, overload 4 | 5 | import jax 6 | import numpy as onp 7 | from jax import numpy as jnp 8 | 9 | from .. import hints 10 | from .._base import MatrixLieGroup 11 | from .._se2 import SE2 12 | from .._se3 import SE3 13 | from .._so2 import SO2 14 | from .._so3 import SO3 15 | from . 
import _tree_utils

PytreeType = TypeVar("PytreeType")
GroupType = TypeVar("GroupType", bound=MatrixLieGroup)


def _rplus(transform: GroupType, delta: jax.Array) -> GroupType:
    # Retraction for a single Lie group leaf: T' = T @ exp(delta).
    assert isinstance(transform, MatrixLieGroup)
    assert isinstance(delta, (jax.Array, onp.ndarray))
    return transform @ type(transform).exp(delta)


@overload
def rplus(
    transform: GroupType,
    delta: hints.Array,
) -> GroupType: ...


@overload
def rplus(
    transform: PytreeType,
    delta: _tree_utils.TangentPytree,
) -> PytreeType: ...


# Using our typevars in the overloaded signature will cause errors.
def rplus(
    transform: Union[MatrixLieGroup, Any],
    delta: Union[hints.Array, Any],
) -> Union[MatrixLieGroup, Any]:
    """Manifold right plus. Computes `T' = T @ exp(delta)`.

    Supports pytrees containing Lie group instances recursively; simple Euclidean
    addition will be performed for all other arrays.
    """
    return _tree_utils._map_group_trees(_rplus, jnp.add, transform, delta)


def _rminus(a: GroupType, b: GroupType) -> jax.Array:
    # Local difference for a single Lie group leaf: (a^-1 @ b).log().
    assert isinstance(a, MatrixLieGroup) and isinstance(b, MatrixLieGroup)
    return (a.inverse() @ b).log()


@overload
def rminus(a: GroupType, b: GroupType) -> jax.Array: ...


@overload
def rminus(a: PytreeType, b: PytreeType) -> _tree_utils.TangentPytree: ...


# Using our typevars in the overloaded signature will cause errors.
def rminus(
    a: Union[MatrixLieGroup, Any], b: Union[MatrixLieGroup, Any]
) -> Union[jax.Array, _tree_utils.TangentPytree]:
    """Manifold right minus. Computes
    `delta = T_ab.log() = (T_wa.inverse() @ T_wb).log()`.

    Supports pytrees containing Lie group instances recursively; simple Euclidean
    subtraction will be performed for all other arrays.
    """
    return _tree_utils._map_group_trees(_rminus, jnp.subtract, a, b)


@jax.jit
def rplus_jacobian_parameters_wrt_delta(transform: MatrixLieGroup) -> jax.Array:
    """Analytical Jacobians for `jaxlie.manifold.rplus()`, linearized around a zero
    local delta.

    Mostly useful for reducing JIT compile times for tangent-space optimization.

    Equivalent to --
    ```
    def rplus_jacobian_parameters_wrt_delta(transform: MatrixLieGroup) -> jax.Array:
        # Since transform objects are pytree containers, note that `jacfwd` returns a
        # transformation object itself and that the Jacobian terms corresponding to the
        # parameters are grabbed explicitly.
        return jax.jacfwd(
            jaxlie.manifold.rplus,  # Args are (transform, delta)
            argnums=1,  # Jacobian wrt delta
        )(transform, onp.zeros(transform.tangent_dim)).parameters()
    ```

    Args:
        transform: Transform to linearize around.

    Returns:
        Jacobian. Shape should be `(Group.parameters_dim, Group.tangent_dim)`.
    """
    if isinstance(transform, SO2):
        # Jacobian row indices: cos, sin
        # Jacobian col indices: theta

        J = jnp.zeros((*transform.get_batch_axes(), 2, 1))
        cos, sin = jnp.moveaxis(transform.unit_complex, -1, 0)
        J = J.at[..., 0, 0].set(-sin).at[..., 1, 0].set(cos)

    elif isinstance(transform, SE2):
        # Jacobian row indices: cos, sin, x, y
        # Jacobian col indices: vx, vy, omega
        J = jnp.zeros((*transform.get_batch_axes(), 4, 3))

        # Translation terms.
        J = J.at[..., 2:, :2].set(transform.rotation().as_matrix())

        # Rotation terms.
122 | J = J.at[..., :2, 2:3].set( 123 | rplus_jacobian_parameters_wrt_delta(transform.rotation()) 124 | ) 125 | 126 | elif isinstance(transform, SO3): 127 | # Jacobian row indices: qw, qx, qy, qz 128 | # Jacobian col indices: omega x, omega y, omega z 129 | w, x, y, z = jnp.moveaxis(transform.wxyz, -1, 0) 130 | neg_x = -x 131 | neg_y = -y 132 | neg_z = -z 133 | 134 | J = ( 135 | jnp.stack( 136 | [ 137 | neg_x, 138 | neg_y, 139 | neg_z, 140 | w, 141 | neg_z, 142 | y, 143 | z, 144 | w, 145 | neg_x, 146 | neg_y, 147 | x, 148 | w, 149 | ], 150 | axis=-1, 151 | ).reshape((*transform.get_batch_axes(), 4, 3)) 152 | / 2.0 153 | ) 154 | 155 | elif isinstance(transform, SE3): 156 | # Jacobian row indices: qw, qx, qy, qz, x, y, z 157 | # Jacobian col indices: vx, vy, vz, omega x, omega y, omega z 158 | J = jnp.zeros((*transform.get_batch_axes(), 7, 6)) 159 | 160 | # Translation terms. 161 | J = J.at[..., 4:, :3].set(transform.rotation().as_matrix()) 162 | 163 | # Rotation terms. 164 | J = J.at[..., :4, 3:6].set( 165 | rplus_jacobian_parameters_wrt_delta(transform.rotation()) 166 | ) 167 | 168 | else: 169 | assert False, f"Unsupported type: {type(transform)}" 170 | 171 | assert J.shape == ( 172 | *transform.get_batch_axes(), 173 | transform.parameters_dim, 174 | transform.tangent_dim, 175 | ) 176 | return J 177 | -------------------------------------------------------------------------------- /jaxlie/_se3.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Tuple, cast 4 | 5 | import jax 6 | import jax_dataclasses as jdc 7 | from jax import numpy as jnp 8 | from typing_extensions import override 9 | 10 | from . 
import _base, hints
from ._so3 import _SO3_jac_left as _SO3_V, SO3, _skew, _SO3_jac_left_inv as _SO3_V_inv
from .utils import broadcast_leading_axes, get_epsilon, register_lie_group


@register_lie_group(
    matrix_dim=4,
    parameters_dim=7,
    tangent_dim=6,
    space_dim=3,
)
@jdc.pytree_dataclass
class SE3(_base.SEBase[SO3]):
    """Special Euclidean group for proper rigid transforms in 3D. Broadcasting
    rules are the same as for numpy.

    Internal parameterization is `(qw, qx, qy, qz, x, y, z)`. Tangent parameterization
    is `(vx, vy, vz, omega_x, omega_y, omega_z)`.
    """

    # SE3-specific.

    wxyz_xyz: jax.Array
    """Internal parameters. wxyz quaternion followed by xyz translation. Shape should be `(*, 7)`."""

    @override
    def __repr__(self) -> str:
        # Round parameters for readability only; storage is unaffected.
        quat = jnp.round(self.wxyz_xyz[..., :4], 5)
        trans = jnp.round(self.wxyz_xyz[..., 4:], 5)
        return f"{self.__class__.__name__}(wxyz={quat}, xyz={trans})"

    # SE-specific.

    @classmethod
    @override
    def from_rotation_and_translation(
        cls,
        rotation: SO3,
        translation: hints.Array,
    ) -> SE3:
        """Construct from an SO3 rotation and an xyz translation; leading batch
        axes of the two inputs are broadcast against each other."""
        assert translation.shape[-1:] == (3,)
        rotation, translation = broadcast_leading_axes((rotation, translation))
        return SE3(wxyz_xyz=jnp.concatenate([rotation.wxyz, translation], axis=-1))

    @override
    def rotation(self) -> SO3:
        # First four parameters are the wxyz quaternion.
        return SO3(wxyz=self.wxyz_xyz[..., :4])

    @override
    def translation(self) -> jax.Array:
        # Last three parameters are the xyz translation.
        return self.wxyz_xyz[..., 4:]

    # Factory.
63 | 64 | @classmethod 65 | @override 66 | def identity(cls, batch_axes: jdc.Static[Tuple[int, ...]] = ()) -> SE3: 67 | return SE3( 68 | wxyz_xyz=jnp.broadcast_to( 69 | jnp.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), (*batch_axes, 7) 70 | ) 71 | ) 72 | 73 | @classmethod 74 | @override 75 | def from_matrix(cls, matrix: hints.Array) -> SE3: 76 | assert matrix.shape[-2:] == (4, 4) 77 | # Currently assumes bottom row is [0, 0, 0, 1]. 78 | return SE3.from_rotation_and_translation( 79 | rotation=SO3.from_matrix(matrix[..., :3, :3]), 80 | translation=matrix[..., :3, 3], 81 | ) 82 | 83 | # Accessors. 84 | 85 | @override 86 | def as_matrix(self) -> jax.Array: 87 | return ( 88 | jnp.zeros((*self.get_batch_axes(), 4, 4)) 89 | .at[..., :3, :3] 90 | .set(self.rotation().as_matrix()) 91 | .at[..., :3, 3] 92 | .set(self.translation()) 93 | .at[..., 3, 3] 94 | .set(1.0) 95 | ) 96 | 97 | @override 98 | def parameters(self) -> jax.Array: 99 | return self.wxyz_xyz 100 | 101 | # Operations. 102 | 103 | @classmethod 104 | @override 105 | def exp(cls, tangent: hints.Array) -> SE3: 106 | # Reference: 107 | # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se3.hpp#L761 108 | 109 | # (x, y, z, omega_x, omega_y, omega_z) 110 | assert tangent.shape[-1:] == (6,) 111 | theta = tangent[..., 3:] 112 | rotation = SO3.exp(theta) 113 | V = _SO3_V( 114 | cast(jax.Array, theta), rotation.as_matrix() 115 | ) # Using _SO3_jac_left via import alias 116 | return SE3.from_rotation_and_translation( 117 | rotation=rotation, 118 | translation=jnp.einsum("...ij,...j->...i", V, tangent[..., :3]), 119 | ) 120 | 121 | @override 122 | def log(self) -> jax.Array: 123 | # Reference: 124 | # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se3.hpp#L223 125 | theta = self.rotation().log() 126 | V_inv = _SO3_V_inv(theta) # Using _SO3_jac_left_inv via import alias 127 | return jnp.concatenate( 128 | [jnp.einsum("...ij,...j->...i", V_inv, 
self.translation()), theta], axis=-1 129 | ) 130 | 131 | @override 132 | def adjoint(self) -> jax.Array: 133 | R = self.rotation().as_matrix() 134 | return jnp.concatenate( 135 | [ 136 | jnp.concatenate( 137 | [R, jnp.einsum("...ij,...jk->...ik", _skew(self.translation()), R)], 138 | axis=-1, 139 | ), 140 | jnp.concatenate( 141 | [jnp.zeros((*self.get_batch_axes(), 3, 3)), R], axis=-1 142 | ), 143 | ], 144 | axis=-2, 145 | ) 146 | 147 | @override 148 | def jlog(self) -> jax.Array: 149 | rotation = self.rotation() 150 | translation = self.translation() 151 | 152 | jlog_so3 = rotation.jlog() 153 | 154 | w = rotation.log() 155 | theta = jnp.linalg.norm(w, axis=-1) 156 | theta_squared = jnp.sum(jnp.square(w), axis=-1) 157 | 158 | use_taylor = theta_squared < get_epsilon(theta.dtype) 159 | theta_inv = cast(jax.Array, jnp.where(use_taylor, 1.0, 1.0 / theta)) 160 | theta_squared_inv = theta_inv**2 161 | st, ct = jnp.sin(theta), jnp.cos(theta) 162 | inv_2_2ct = jnp.where(use_taylor, 0.5, 1 / (2 * (1 - ct))) 163 | 164 | # Use jnp.where for beta and beta_dot_over_theta. 
165 | beta = theta_squared_inv - st * theta_inv * inv_2_2ct 166 | beta_dot_over_theta = ( 167 | -2 * theta_squared_inv**2 168 | + (1 + st * theta_inv) * theta_squared_inv * inv_2_2ct 169 | ) 170 | wTp = jnp.sum(w * translation, axis=-1, keepdims=True) 171 | v3_tmp = (beta_dot_over_theta[..., None] * wTp) * w - ( 172 | theta_squared[..., None] * beta_dot_over_theta[..., None] 173 | + 2 * beta[..., None] 174 | ) * translation 175 | C = ( 176 | jnp.einsum("...i,...j->...ij", v3_tmp, w) 177 | + beta[..., None, None] * jnp.einsum("...i,...j->...ij", w, translation) 178 | + wTp[..., None] * beta[..., None, None] * jnp.eye(3) 179 | ) 180 | C = C + 0.5 * _skew(translation) 181 | 182 | B = jnp.einsum("...ij,...jk->...ik", C, jlog_so3) 183 | B_wh = jnp.where(use_taylor[..., None, None], 0.5 * _skew(translation), B) 184 | assert B_wh.shape == jlog_so3.shape 185 | 186 | jlog = ( 187 | jnp.zeros((*theta.shape, 6, 6)) 188 | .at[..., :3, :3] 189 | .set(jlog_so3) 190 | .at[..., 3:, 3:] 191 | .set(jlog_so3) 192 | .at[..., :3, 3:] 193 | .set(B_wh) 194 | ) 195 | return jlog 196 | 197 | @classmethod 198 | @override 199 | def sample_uniform( 200 | cls, key: jax.Array, batch_axes: jdc.Static[Tuple[int, ...]] = () 201 | ) -> SE3: 202 | key0, key1 = jax.random.split(key) 203 | return SE3.from_rotation_and_translation( 204 | rotation=SO3.sample_uniform(key0, batch_axes=batch_axes), 205 | translation=jax.random.uniform( 206 | key=key1, shape=(*batch_axes, 3), minval=-1.0, maxval=1.0 207 | ), 208 | ) 209 | -------------------------------------------------------------------------------- /tests/test_autodiff.py: -------------------------------------------------------------------------------- 1 | """Compare forward- and reverse-mode Jacobians with a numerical Jacobian.""" 2 | 3 | from functools import lru_cache 4 | from typing import Callable, Tuple, Type, cast 5 | 6 | import jax 7 | import numpy as onp 8 | from jax import numpy as jnp 9 | from utils import assert_arrays_close, 
general_group_test, jacnumerical 10 | 11 | import jaxlie 12 | 13 | # We cache JITed Jacobians to improve runtime. 14 | cached_jacfwd = lru_cache(maxsize=None)( 15 | lambda f: jax.jit(jax.jacfwd(f, argnums=1), static_argnums=0) 16 | ) 17 | cached_jacrev = lru_cache(maxsize=None)( 18 | lambda f: jax.jit(jax.jacrev(f, argnums=1), static_argnums=0) 19 | ) 20 | cached_jit = lru_cache(maxsize=None)(jax.jit) 21 | 22 | 23 | def _assert_jacobians_close( 24 | Group: Type[jaxlie.MatrixLieGroup], 25 | f: Callable[[Type[jaxlie.MatrixLieGroup], jax.Array], jax.Array], 26 | primal: jaxlie.hints.Array, 27 | ) -> None: 28 | jacobian_fwd = cached_jacfwd(f)(Group, primal) 29 | jacobian_rev = cached_jacrev(f)(Group, primal) 30 | jacobian_numerical = jacnumerical( 31 | lambda primal: cached_jit(f, static_argnums=0)(Group, primal) 32 | )(primal) 33 | 34 | assert_arrays_close(jacobian_fwd, jacobian_rev) 35 | assert_arrays_close(jacobian_fwd, jacobian_numerical, rtol=5e-4, atol=5e-4) 36 | 37 | 38 | # Exp tests. 39 | def _exp(Group: Type[jaxlie.MatrixLieGroup], generator: jax.Array) -> jax.Array: 40 | return cast(jax.Array, Group.exp(generator).parameters()) 41 | 42 | 43 | def test_so3_nan(): 44 | """Make sure we don't get NaNs from division when w == 0. 45 | 46 | https://github.com/brentyi/jaxlie/issues/9""" 47 | 48 | @jax.jit 49 | @jax.grad 50 | def func(x): 51 | return jaxlie.SO3.exp(x).log().sum() 52 | 53 | for omega in jnp.eye(3) * jnp.pi: 54 | a = jnp.array(omega, dtype=jnp.float32) 55 | assert all(onp.logical_not(onp.isnan(func(a)))) 56 | 57 | 58 | @general_group_test 59 | def test_exp_random(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]): 60 | """Check that exp Jacobians are consistent, with randomly sampled transforms.""" 61 | del batch_axes # Not used for autodiff tests. 
62 | generator = onp.random.randn(Group.tangent_dim) 63 | _assert_jacobians_close(Group=Group, f=_exp, primal=generator) 64 | 65 | 66 | @general_group_test 67 | def test_exp_identity(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]): 68 | """Check that exp Jacobians are consistent, with transforms close to the 69 | identity.""" 70 | del batch_axes # Not used for autodiff tests. 71 | generator = onp.random.randn(Group.tangent_dim) * 1e-6 72 | _assert_jacobians_close(Group=Group, f=_exp, primal=generator) 73 | 74 | 75 | # Log tests. 76 | def _log(Group: Type[jaxlie.MatrixLieGroup], params: jax.Array) -> jax.Array: 77 | return Group.log(Group(params)) 78 | 79 | 80 | @general_group_test 81 | def test_log_random(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]): 82 | """Check that log Jacobians are consistent, with randomly sampled transforms.""" 83 | del batch_axes # Not used for autodiff tests. 84 | params = Group.exp(onp.random.randn(Group.tangent_dim)).parameters() 85 | _assert_jacobians_close(Group=Group, f=_log, primal=params) 86 | 87 | 88 | @general_group_test 89 | def test_log_identity(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]): 90 | """Check that log Jacobians are consistent, with transforms close to the 91 | identity.""" 92 | params = Group.exp(onp.random.randn(Group.tangent_dim) * 1e-6).parameters() 93 | _assert_jacobians_close(Group=Group, f=_log, primal=params) 94 | 95 | 96 | # Adjoint tests. 97 | def _adjoint(Group: Type[jaxlie.MatrixLieGroup], params: jax.Array) -> jax.Array: 98 | return cast(jax.Array, Group(params).adjoint().flatten()) 99 | 100 | 101 | @general_group_test 102 | def test_adjoint_random( 103 | Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...] 104 | ): 105 | """Check that adjoint Jacobians are consistent, with randomly sampled transforms.""" 106 | del batch_axes # Not used for autodiff tests. 
107 | params = Group.exp(onp.random.randn(Group.tangent_dim)).parameters() 108 | _assert_jacobians_close(Group=Group, f=_adjoint, primal=params) 109 | 110 | 111 | @general_group_test 112 | def test_adjoint_identity( 113 | Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...] 114 | ): 115 | """Check that adjoint Jacobians are consistent, with transforms close to the 116 | identity.""" 117 | del batch_axes # Not used for autodiff tests. 118 | params = Group.exp(onp.random.randn(Group.tangent_dim) * 1e-6).parameters() 119 | _assert_jacobians_close(Group=Group, f=_adjoint, primal=params) 120 | 121 | 122 | # Apply tests. 123 | def _apply(Group: Type[jaxlie.MatrixLieGroup], params: jax.Array) -> jax.Array: 124 | return Group(params) @ onp.ones(Group.space_dim) 125 | 126 | 127 | @general_group_test 128 | def test_apply_random(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]): 129 | """Check that apply Jacobians are consistent, with randomly sampled transforms.""" 130 | del batch_axes # Not used for autodiff tests. 131 | params = Group.exp(onp.random.randn(Group.tangent_dim)).parameters() 132 | _assert_jacobians_close(Group=Group, f=_apply, primal=params) 133 | 134 | 135 | @general_group_test 136 | def test_apply_identity( 137 | Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...] 138 | ): 139 | """Check that apply Jacobians are consistent, with transforms close to the 140 | identity.""" 141 | del batch_axes # Not used for autodiff tests. 142 | params = Group.exp(onp.random.randn(Group.tangent_dim) * 1e-6).parameters() 143 | _assert_jacobians_close(Group=Group, f=_apply, primal=params) 144 | 145 | 146 | # Multiply tests. 147 | def _multiply(Group: Type[jaxlie.MatrixLieGroup], params: jax.Array) -> jax.Array: 148 | return cast(jax.Array, (Group(params) @ Group(params)).parameters()) 149 | 150 | 151 | @general_group_test 152 | def test_multiply_random( 153 | Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...] 
154 | ): 155 | """Check that multiply Jacobians are consistent, with randomly sampled 156 | transforms.""" 157 | del batch_axes # Not used for autodiff tests. 158 | params = Group.exp(onp.random.randn(Group.tangent_dim)).parameters() 159 | _assert_jacobians_close(Group=Group, f=_multiply, primal=params) 160 | 161 | 162 | @general_group_test 163 | def test_multiply_identity( 164 | Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...] 165 | ): 166 | """Check that multiply Jacobians are consistent, with transforms close to the 167 | identity.""" 168 | del batch_axes # Not used for autodiff tests. 169 | params = Group.exp(onp.random.randn(Group.tangent_dim) * 1e-6).parameters() 170 | _assert_jacobians_close(Group=Group, f=_multiply, primal=params) 171 | 172 | 173 | # Inverse tests. 174 | def _inverse(Group: Type[jaxlie.MatrixLieGroup], params: jax.Array) -> jax.Array: 175 | return cast(jax.Array, Group(params).inverse().parameters()) 176 | 177 | 178 | @general_group_test 179 | def test_inverse_random( 180 | Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...] 181 | ): 182 | """Check that inverse Jacobians are consistent, with randomly sampled transforms.""" 183 | del batch_axes # Not used for autodiff tests. 184 | params = Group.exp(onp.random.randn(Group.tangent_dim)).parameters() 185 | _assert_jacobians_close(Group=Group, f=_inverse, primal=params) 186 | 187 | 188 | @general_group_test 189 | def test_inverse_identity( 190 | Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...] 191 | ): 192 | """Check that inverse Jacobians are consistent, with transforms close to the 193 | identity.""" 194 | del batch_axes # Not used for autodiff tests. 
195 | params = Group.exp(onp.random.randn(Group.tangent_dim) * 1e-6).parameters() 196 | _assert_jacobians_close(Group=Group, f=_inverse, primal=params) 197 | -------------------------------------------------------------------------------- /examples/se3_optimization.py: -------------------------------------------------------------------------------- 1 | """Example that uses helpers in `jaxlie.manifold.*` to compare algorithms for running an 2 | ADAM optimizer on SE(3) variables. 3 | 4 | We compare three approaches: 5 | 6 | (1) Tangent-space ADAM: computing updates on a local tangent space, which are then 7 | retracted back to the global parameterization at each step. This should generally be the 8 | most stable. 9 | 10 | (2) Projected ADAM: running standard ADAM directly on the global parameterization, then 11 | projecting after each step. 12 | 13 | (3) Standard ADAM with exponential coordinates: using a log-space underlying 14 | parameterization lets us run ADAM without any modifications. 15 | 16 | Note that the number of training steps and learning rate can be configured, see: 17 | 18 | python se3_optimization.py --help 19 | 20 | """ 21 | 22 | from __future__ import annotations 23 | 24 | import time 25 | from typing import List, Literal, Tuple, Union 26 | 27 | import jax 28 | import jax_dataclasses as jdc 29 | import matplotlib.pyplot as plt 30 | import optax 31 | import tyro 32 | from jax import numpy as jnp 33 | from typing_extensions import assert_never 34 | 35 | import jaxlie 36 | 37 | 38 | @jdc.pytree_dataclass 39 | class Parameters: 40 | """Parameters to optimize over, in their global representation. Rotations are 41 | quaternions under the hood. 42 | 43 | Note that there's redundancy here: given T_ab and T_bc, T_ca can be computed as 44 | (T_ab @ T_bc).inverse(). Our optimization will be focused on making these redundant 45 | transforms consistent with each other. 
46 | """ 47 | 48 | T_ab: jaxlie.SE3 49 | T_bc: jaxlie.SE3 50 | T_ca: jaxlie.SE3 51 | 52 | 53 | @jdc.pytree_dataclass 54 | class ExponentialCoordinatesParameters: 55 | """Same as `Parameters`, but using exponential coordinates.""" 56 | 57 | log_T_ab: jax.Array 58 | log_T_bc: jax.Array 59 | log_T_ca: jax.Array 60 | 61 | @property 62 | def T_ab(self) -> jaxlie.SE3: 63 | return jaxlie.SE3.exp(self.log_T_ab) 64 | 65 | @property 66 | def T_bc(self) -> jaxlie.SE3: 67 | return jaxlie.SE3.exp(self.log_T_bc) 68 | 69 | @property 70 | def T_ca(self) -> jaxlie.SE3: 71 | return jaxlie.SE3.exp(self.log_T_ca) 72 | 73 | @staticmethod 74 | def from_global(params: Parameters) -> ExponentialCoordinatesParameters: 75 | return ExponentialCoordinatesParameters( 76 | params.T_ab.log(), 77 | params.T_bc.log(), 78 | params.T_ca.log(), 79 | ) 80 | 81 | 82 | def compute_loss( 83 | params: Union[Parameters, ExponentialCoordinatesParameters], 84 | ) -> jax.Array: 85 | """As our loss, we enforce (a) priors on our transforms and (b) a consistency 86 | constraint.""" 87 | T_ba_prior = jaxlie.SE3.sample_uniform(jax.random.PRNGKey(1)) 88 | T_cb_prior = jaxlie.SE3.sample_uniform(jax.random.PRNGKey(2)) 89 | 90 | return jnp.sum( 91 | # Consistency term. 92 | (params.T_ab @ params.T_bc @ params.T_ca).log() ** 2 93 | # Priors. 94 | + (params.T_ab @ T_ba_prior).log() ** 2 95 | + (params.T_bc @ T_cb_prior).log() ** 2 96 | ) 97 | 98 | 99 | Algorithm = Literal["tangent_space", "projected", "exponential_coordinates"] 100 | 101 | 102 | @jdc.pytree_dataclass 103 | class State: 104 | params: Union[Parameters, ExponentialCoordinatesParameters] 105 | optimizer: jdc.Static[optax.GradientTransformation] 106 | optimizer_state: optax.OptState 107 | algorithm: jdc.Static[Algorithm] 108 | 109 | @staticmethod 110 | def initialize(algorithm: Algorithm, learning_rate: float) -> State: 111 | """Initialize the state of our optimization problem. 
Note that the transforms 112 | parameters won't initially be consistent; `T_ab @ T_bc != T_ca.inverse()`. 113 | """ 114 | prngs = jax.random.split(jax.random.PRNGKey(0), num=1) 115 | global_params = Parameters( 116 | jaxlie.SE3.sample_uniform(prngs[0]), 117 | jaxlie.SE3.sample_uniform(prngs[1]), 118 | jaxlie.SE3.sample_uniform(prngs[2]), 119 | ) 120 | 121 | # Make optimizer. 122 | params: Union[Parameters, ExponentialCoordinatesParameters] 123 | optimizer = optax.adam(learning_rate=learning_rate) 124 | if algorithm == "tangent_space": 125 | # Initialize gradient statistics as on the tangent space. 126 | params = global_params 127 | optimizer_state = optimizer.init(jaxlie.manifold.zero_tangents(params)) 128 | elif algorithm == "projected": 129 | # Initialize gradient statistics directly in quaternion space. 130 | params = global_params 131 | optimizer_state = optimizer.init(params) 132 | elif algorithm == "exponential_coordinates": 133 | # Switch to a log-space parameterization. 134 | params = ExponentialCoordinatesParameters.from_global(global_params) 135 | optimizer_state = optimizer.init(params) 136 | else: 137 | assert_never(algorithm) 138 | 139 | return State( 140 | params=params, 141 | optimizer=optimizer, 142 | optimizer_state=optimizer_state, 143 | algorithm=algorithm, 144 | ) 145 | 146 | @jax.jit 147 | def step(self: State) -> Tuple[jax.Array, State]: 148 | """Take one ADAM optimization step.""" 149 | 150 | if self.algorithm == "tangent_space": 151 | # ADAM step on manifold. 152 | # 153 | # `jaxlie.manifold.value_and_grad()` is a drop-in replacement for 154 | # `jax.value_and_grad()`, but for Lie group instances computes gradients on 155 | # the tangent space. 
156 | loss, grads = jaxlie.manifold.value_and_grad(compute_loss)(self.params) 157 | updates, new_optimizer_state = self.optimizer.update( 158 | grads, 159 | self.optimizer_state, 160 | self.params, 161 | ) 162 | new_params = jaxlie.manifold.rplus(self.params, updates) 163 | 164 | elif self.algorithm == "projected": 165 | # Projection-based approach. 166 | loss, grads = jax.value_and_grad(compute_loss)(self.params) 167 | updates, new_optimizer_state = self.optimizer.update( 168 | grads, 169 | self.optimizer_state, 170 | self.params, 171 | ) 172 | new_params = optax.apply_updates(self.params, updates) 173 | 174 | # Project back to manifold. 175 | new_params = jaxlie.manifold.normalize_all(new_params) 176 | 177 | elif self.algorithm == "exponential_coordinates": 178 | # If we parameterize with exponential coordinates, we can 179 | loss, grads = jax.value_and_grad(compute_loss)(self.params) 180 | updates, new_optimizer_state = self.optimizer.update( 181 | grads, 182 | self.optimizer_state, 183 | self.params, 184 | ) 185 | new_params = optax.apply_updates(self.params, updates) 186 | 187 | else: 188 | assert assert_never(self.algorithm) 189 | 190 | # Return updated structure. 191 | with jdc.copy_and_mutate(self, validate=True) as new_state: 192 | new_state.params = new_params 193 | new_state.optimizer_state = new_optimizer_state 194 | 195 | return loss, new_state 196 | 197 | 198 | def run_experiment( 199 | algorithm: Algorithm, learning_rate: float, train_steps: int 200 | ) -> List[float]: 201 | """Run the optimization problem, either using a tangent-space approach or via 202 | projection.""" 203 | 204 | print(algorithm) 205 | state = State.initialize(algorithm, learning_rate) 206 | state.step() # Don't include JIT compile in timing. 
207 | 208 | start_time = time.time() 209 | losses = [] 210 | for i in range(train_steps): 211 | loss, state = state.step() 212 | if i % 20 == 0: 213 | print(f"\t(step {i:03d}) Loss", loss, flush=True) 214 | losses.append(float(loss)) 215 | print() 216 | print(f"\tConverged in {time.time() - start_time} seconds") 217 | print() 218 | print("\tAfter optimization, the following transforms should be consistent:") 219 | print(f"\t\t{state.params.T_ab @ state.params.T_bc=}") 220 | print(f"\t\t{state.params.T_ca.inverse()=}") 221 | 222 | return losses 223 | 224 | 225 | def main(train_steps: int = 1000, learning_rate: float = 1e-1) -> None: 226 | """Run pose optimization experiments. 227 | 228 | Args: 229 | train_steps: Number of training steps to take. 230 | learning_rate: Learning rate for our ADAM optimizers. 231 | """ 232 | xs = range(train_steps) 233 | 234 | algorithms: Tuple[Algorithm, ...] = ( 235 | "tangent_space", 236 | "projected", 237 | "exponential_coordinates", 238 | ) 239 | for algorithm in algorithms: 240 | plt.plot( 241 | xs, 242 | run_experiment(algorithm, learning_rate, train_steps), 243 | label=algorithm, 244 | ) 245 | print() 246 | plt.yscale("log", base=2) 247 | plt.legend() 248 | plt.show() 249 | 250 | 251 | if __name__ == "__main__": 252 | tyro.cli(main) 253 | -------------------------------------------------------------------------------- /jaxlie/_base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import ClassVar, Generic, Tuple, TypeVar, Union, overload 3 | 4 | import jax 5 | import numpy as onp 6 | from jax import numpy as jnp 7 | from typing_extensions import Self, final, get_args, override 8 | 9 | from . import hints 10 | 11 | 12 | class MatrixLieGroup(abc.ABC): 13 | """Interface definition for matrix Lie groups.""" 14 | 15 | # Class properties. 16 | # > These will be set in `_utils.register_lie_group()`. 
17 | 18 | matrix_dim: ClassVar[int] 19 | """Dimension of square matrix output from `.as_matrix()`.""" 20 | 21 | parameters_dim: ClassVar[int] 22 | """Dimension of underlying parameters, `.parameters()`.""" 23 | 24 | tangent_dim: ClassVar[int] 25 | """Dimension of tangent space.""" 26 | 27 | space_dim: ClassVar[int] 28 | """Dimension of coordinates that can be transformed.""" 29 | 30 | def __init__( 31 | # Notes: 32 | # - For the constructor signature to be consistent with subclasses, `parameters` 33 | # should be marked as positional-only. But this isn't possible in Python 3.7. 34 | # - This method is implicitly overriden by the dataclass decorator and 35 | # should _not_ be marked abstract. 36 | self, 37 | parameters: jax.Array, 38 | ): 39 | """Construct a group object from its underlying parameters.""" 40 | raise NotImplementedError() 41 | 42 | # Shared implementations. 43 | 44 | @overload 45 | def __matmul__(self, other: Self) -> Self: ... 46 | 47 | @overload 48 | def __matmul__(self, other: hints.Array) -> jax.Array: ... 49 | 50 | def __matmul__(self, other: Union[Self, hints.Array]) -> Union[Self, jax.Array]: 51 | """Overload for the `@` operator. 52 | 53 | Switches between the group action (`.apply()`) and multiplication 54 | (`.multiply()`) based on the type of `other`. 55 | """ 56 | if isinstance(other, (onp.ndarray, jax.Array)): 57 | return self.apply(target=other) 58 | elif isinstance(other, MatrixLieGroup): 59 | assert self.space_dim == other.space_dim 60 | return self.multiply(other=other) 61 | else: 62 | assert False, f"Invalid argument type for `@` operator: {type(other)}" 63 | 64 | # Factory. 65 | 66 | @classmethod 67 | @abc.abstractmethod 68 | def identity(cls, batch_axes: Tuple[int, ...] = ()) -> Self: 69 | """Returns identity element. 70 | 71 | Args: 72 | batch_axes: Any leading batch axes for the output transform. 73 | 74 | Returns: 75 | Identity element. 
76 | """ 77 | 78 | @classmethod 79 | @abc.abstractmethod 80 | def from_matrix(cls, matrix: hints.Array) -> Self: 81 | """Get group member from matrix representation. 82 | 83 | Args: 84 | matrix: Matrix representaiton. 85 | 86 | Returns: 87 | Group member. 88 | """ 89 | 90 | # Accessors. 91 | 92 | @abc.abstractmethod 93 | def as_matrix(self) -> jax.Array: 94 | """Get transformation as a matrix. Homogeneous for SE groups.""" 95 | 96 | @abc.abstractmethod 97 | def parameters(self) -> jax.Array: 98 | """Get underlying representation.""" 99 | 100 | # Operations. 101 | 102 | @abc.abstractmethod 103 | def apply(self, target: hints.Array) -> jax.Array: 104 | """Applies group action to a point. 105 | 106 | Args: 107 | target: Point to transform. 108 | 109 | Returns: 110 | Transformed point. 111 | """ 112 | 113 | @abc.abstractmethod 114 | def multiply(self, other: Self) -> Self: 115 | """Composes this transformation with another. 116 | 117 | Returns: 118 | self @ other 119 | """ 120 | 121 | @classmethod 122 | @abc.abstractmethod 123 | def exp(cls, tangent: hints.Array) -> Self: 124 | """Computes `expm(wedge(tangent))`. 125 | 126 | Args: 127 | tangent: Tangent vector to take the exponential of. 128 | 129 | Returns: 130 | Output. 131 | """ 132 | 133 | @abc.abstractmethod 134 | def log(self) -> jax.Array: 135 | """Computes `vee(logm(transformation matrix))`. 136 | 137 | Returns: 138 | Output. Shape should be `(tangent_dim,)`. 139 | """ 140 | 141 | @abc.abstractmethod 142 | def adjoint(self) -> jax.Array: 143 | """Computes the adjoint, which transforms tangent vectors between tangent 144 | spaces. 145 | 146 | More precisely, for a transform `GroupType`: 147 | ``` 148 | GroupType @ exp(omega) = exp(Adj_T @ omega) @ GroupType 149 | ``` 150 | 151 | In robotics, typically used for transforming twists, wrenches, and Jacobians 152 | across different reference frames. 153 | 154 | Returns: 155 | Output. Shape should be `(tangent_dim, tangent_dim)`. 
156 | """ 157 | 158 | @abc.abstractmethod 159 | def inverse(self) -> Self: 160 | """Computes the inverse of our transform. 161 | 162 | Returns: 163 | Output. 164 | """ 165 | 166 | @abc.abstractmethod 167 | def normalize(self) -> Self: 168 | """Normalize/projects values and returns. 169 | 170 | Returns: 171 | Normalized group member. 172 | """ 173 | 174 | @abc.abstractmethod 175 | def jlog(self) -> jax.Array: 176 | """ 177 | Computes the Jacobian of the logarithm of the group element when a 178 | local perturbation is applied. 179 | 180 | This is equivalent to the inverse of the right Jacobian, or: 181 | 182 | ``` 183 | jax.jacrev(lambda x: (T @ exp(x)).log())(jnp.zeros(tangent_dim)) 184 | ``` 185 | 186 | where `T` is the group element and `exp(x)` is the tangent vector. 187 | 188 | Returns: 189 | The Jacobian of the logarithm, having the dimensions `(tangent_dim, tangent_dim,)` or batch of these Jacobians. 190 | """ 191 | 192 | @classmethod 193 | @abc.abstractmethod 194 | def sample_uniform(cls, key: jax.Array, batch_axes: Tuple[int, ...] = ()) -> Self: 195 | """Draw a uniform sample from the group. Translations (if applicable) are in the 196 | range [-1, 1]. 197 | 198 | Args: 199 | key: PRNG key, as returned by `jax.random.PRNGKey()`. 200 | batch_axes: Any leading batch axes for the output transforms. Each 201 | sampled transform will be different. 202 | 203 | Returns: 204 | Sampled group member. 205 | """ 206 | 207 | @final 208 | def get_batch_axes(self) -> Tuple[int, ...]: 209 | """Return any leading batch axes in contained parameters. 
If an array of shape 210 | `(100, 4)` is placed in the wxyz field of an SO3 object, for example, this will 211 | return `(100,)`.""" 212 | return self.parameters().shape[:-1] 213 | 214 | 215 | class SOBase(MatrixLieGroup): 216 | """Base class for special orthogonal groups.""" 217 | 218 | 219 | ContainedSOType = TypeVar("ContainedSOType", bound=SOBase) 220 | 221 | 222 | class SEBase(Generic[ContainedSOType], MatrixLieGroup): 223 | """Base class for special Euclidean groups. 224 | 225 | Each SE(N) group member contains an SO(N) rotation, as well as an N-dimensional 226 | translation vector. 227 | """ 228 | 229 | # SE-specific interface. 230 | 231 | @classmethod 232 | @abc.abstractmethod 233 | def from_rotation_and_translation( 234 | cls, 235 | rotation: ContainedSOType, 236 | translation: hints.Array, 237 | ) -> Self: 238 | """Construct a rigid transform from a rotation and a translation. 239 | 240 | Args: 241 | rotation: Rotation term. 242 | translation: translation term. 243 | 244 | Returns: 245 | Constructed transformation. 246 | """ 247 | 248 | @final 249 | @classmethod 250 | def from_rotation(cls, rotation: ContainedSOType) -> Self: 251 | return cls.from_rotation_and_translation( 252 | rotation=rotation, 253 | translation=jnp.zeros( 254 | (*rotation.get_batch_axes(), cls.space_dim), 255 | dtype=rotation.parameters().dtype, 256 | ), 257 | ) 258 | 259 | @final 260 | @classmethod 261 | def from_translation(cls, translation: hints.Array) -> Self: 262 | # Extract rotation class from type parameter. 
263 | assert len(cls.__orig_bases__) == 1 # type: ignore 264 | return cls.from_rotation_and_translation( 265 | rotation=get_args(cls.__orig_bases__[0])[0].identity(), # type: ignore 266 | translation=translation, 267 | ) 268 | 269 | @abc.abstractmethod 270 | def rotation(self) -> ContainedSOType: 271 | """Returns a transform's rotation term.""" 272 | 273 | @abc.abstractmethod 274 | def translation(self) -> jax.Array: 275 | """Returns a transform's translation term.""" 276 | 277 | # Overrides. 278 | 279 | @final 280 | @override 281 | def apply(self, target: hints.Array) -> jax.Array: 282 | return self.rotation() @ target + self.translation() # type: ignore 283 | 284 | @final 285 | @override 286 | def multiply(self, other: Self) -> Self: 287 | return type(self).from_rotation_and_translation( 288 | rotation=self.rotation() @ other.rotation(), 289 | translation=(self.rotation() @ other.translation()) + self.translation(), 290 | ) 291 | 292 | @final 293 | @override 294 | def inverse(self) -> Self: 295 | R_inv = self.rotation().inverse() 296 | return type(self).from_rotation_and_translation( 297 | rotation=R_inv, 298 | translation=-(R_inv @ self.translation()), 299 | ) 300 | 301 | @final 302 | @override 303 | def normalize(self) -> Self: 304 | return type(self).from_rotation_and_translation( 305 | rotation=self.rotation().normalize(), 306 | translation=self.translation(), 307 | ) 308 | -------------------------------------------------------------------------------- /jaxlie/_se2.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, cast 2 | 3 | import jax 4 | import jax_dataclasses as jdc 5 | from jax import numpy as jnp 6 | from typing_extensions import override 7 | 8 | from . 
def _SE2_jac_left(tangent: jax.Array) -> jax.Array:
    """Compute the left jacobian for the given SO(2) tangent vector. This only
    includes the translation terms (2x2), the orientation ones are less
    useful."""
    theta = tangent.squeeze(axis=-1)
    del tangent

    use_taylor = jnp.abs(theta) < get_epsilon(theta.dtype)

    # Substitute a harmless nonzero value where the Taylor branch is taken;
    # NaNs in the untaken `jnp.where` branch break reverse-mode AD.
    safe_theta = cast(jax.Array, jnp.where(use_taylor, 1.0, theta))

    theta_sq = theta**2
    # sin(theta) / theta, with a Taylor expansion near zero.
    diag_term = cast(
        jax.Array,
        jnp.where(
            use_taylor,
            1.0 - theta_sq / 6.0,
            jnp.sin(safe_theta) / safe_theta,
        ),
    )
    # (1 - cos(theta)) / theta, with a Taylor expansion near zero.
    offdiag_term = cast(
        jax.Array,
        jnp.where(
            use_taylor,
            0.5 * theta - theta * theta_sq / 24.0,
            (1.0 - jnp.cos(safe_theta)) / safe_theta,
        ),
    )
    out = jnp.zeros((*theta.shape, 2, 2))
    out = out.at[..., 0, 0].set(diag_term)
    out = out.at[..., 0, 1].set(-offdiag_term)
    out = out.at[..., 1, 0].set(offdiag_term)
    out = out.at[..., 1, 1].set(diag_term)
    return out
def _SE2_jac_left_inv(tangent: jax.Array) -> jax.Array:
    """Compute the inverse of the left jacobian for the given SO(2) tangent
    vector. This only includes the translation terms (2x2), the orientation
    ones are less useful."""
    # Scalar rotation angle(s); the trailing length-1 axis is squeezed away.
    theta = tangent.squeeze(axis=-1)
    del tangent

    cos = jnp.cos(theta)
    cos_minus_one = cos - 1.0
    half_theta = theta / 2.0
    # Near theta = 0 the closed form divides by ~0; switch to a Taylor
    # expansion there.
    use_taylor = jnp.abs(cos_minus_one) < get_epsilon(theta.dtype)

    # Shim to avoid NaNs in jnp.where branches, which cause failures for
    # reverse-mode AD.
    safe_cos_minus_one = jnp.where(use_taylor, 1.0, cos_minus_one)
    # (theta/2) * cot(theta/2), written as -theta/2 * sin(theta)/(cos(theta)-1).
    half_theta_over_tan_half_theta = jnp.where(
        use_taylor,
        # Taylor approximation.
        1.0 - theta**2 / 12.0,
        # Default.
        -(half_theta * jnp.sin(theta)) / safe_cos_minus_one,
    )
    # Assemble the 2x2 inverse-left-jacobian block (batched).
    jac_left_inv = (
        jnp.zeros((*theta.shape, 2, 2))
        .at[..., 0, 0]
        .set(half_theta_over_tan_half_theta)
        .at[..., 0, 1]
        .set(half_theta)
        .at[..., 1, 0]
        .set(-half_theta)
        .at[..., 1, 1]
        .set(half_theta_over_tan_half_theta)
    )
    return jac_left_inv
133 | 134 | @classmethod 135 | @override 136 | def from_rotation_and_translation( 137 | cls, 138 | rotation: SO2, 139 | translation: hints.Array, 140 | ) -> "SE2": 141 | assert translation.shape[-1:] == (2,) 142 | rotation, translation = broadcast_leading_axes((rotation, translation)) 143 | return SE2( 144 | unit_complex_xy=jnp.concatenate( 145 | [rotation.unit_complex, translation], axis=-1 146 | ) 147 | ) 148 | 149 | @override 150 | def rotation(self) -> SO2: 151 | return SO2(unit_complex=self.unit_complex_xy[..., :2]) 152 | 153 | @override 154 | def translation(self) -> jax.Array: 155 | return self.unit_complex_xy[..., 2:] 156 | 157 | # Factory. 158 | 159 | @classmethod 160 | @override 161 | def identity(cls, batch_axes: jdc.Static[Tuple[int, ...]] = ()) -> "SE2": 162 | return SE2( 163 | unit_complex_xy=jnp.broadcast_to( 164 | jnp.array([1.0, 0.0, 0.0, 0.0]), (*batch_axes, 4) 165 | ) 166 | ) 167 | 168 | @classmethod 169 | @override 170 | def from_matrix(cls, matrix: hints.Array) -> "SE2": 171 | assert matrix.shape[-2:] == (3, 3) 172 | # Currently assumes bottom row is [0, 0, 1]. 173 | return SE2.from_rotation_and_translation( 174 | rotation=SO2.from_matrix(matrix[..., :2, :2]), 175 | translation=matrix[..., :2, 2], 176 | ) 177 | 178 | # Accessors. 179 | 180 | @override 181 | def parameters(self) -> jax.Array: 182 | return self.unit_complex_xy 183 | 184 | @override 185 | def as_matrix(self) -> jax.Array: 186 | cos, sin, x, y = jnp.moveaxis(self.unit_complex_xy, -1, 0) 187 | out = jnp.stack( 188 | [ 189 | cos, 190 | -sin, 191 | x, 192 | sin, 193 | cos, 194 | y, 195 | jnp.zeros_like(x), 196 | jnp.zeros_like(x), 197 | jnp.ones_like(x), 198 | ], 199 | axis=-1, 200 | ).reshape((*self.get_batch_axes(), 3, 3)) 201 | return out 202 | 203 | # Operations. 
204 | 205 | @classmethod 206 | @override 207 | def exp(cls, tangent: hints.Array) -> "SE2": 208 | # Reference: 209 | # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se2.hpp#L558 210 | # Also see: 211 | # > http://ethaneade.com/lie.pdf 212 | 213 | assert tangent.shape[-1:] == (3,) 214 | so2_tangent = cast(jax.Array, tangent[..., 2:3]) 215 | V = _SE2_jac_left(so2_tangent) 216 | return SE2.from_rotation_and_translation( 217 | rotation=SO2.exp(so2_tangent), 218 | translation=jnp.einsum("...ij,...j->...i", V, tangent[..., :2]), 219 | ) 220 | 221 | @override 222 | def log(self) -> jax.Array: 223 | # Reference: 224 | # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se2.hpp#L160 225 | # Also see: 226 | # > http://ethaneade.com/lie.pdf 227 | so2_tangent = self.rotation().log() 228 | V_inv = _SE2_jac_left_inv(so2_tangent) 229 | tangent = jnp.concatenate( 230 | [ 231 | jnp.einsum("...ij,...j->...i", V_inv, self.translation()), 232 | so2_tangent, 233 | ], 234 | axis=-1, 235 | ) 236 | return tangent 237 | 238 | @override 239 | def adjoint(self: "SE2") -> jax.Array: 240 | cos, sin, x, y = jnp.moveaxis(self.unit_complex_xy, -1, 0) 241 | return jnp.stack( 242 | [ 243 | cos, 244 | -sin, 245 | y, 246 | sin, 247 | cos, 248 | -x, 249 | jnp.zeros_like(x), 250 | jnp.zeros_like(x), 251 | jnp.ones_like(x), 252 | ], 253 | axis=-1, 254 | ).reshape((*self.get_batch_axes(), 3, 3)) 255 | 256 | @override 257 | def jlog(self) -> jax.Array: 258 | # Reference: 259 | # This is inverse of matrix (163) from Micro-Lie theory: 260 | # > https://arxiv.org/pdf/1812.01537 261 | 262 | tangent = self.log() 263 | theta = tangent[..., 2] 264 | 265 | # Handle the case where theta is small to avoid division by zero. 
266 | use_taylor = jnp.abs(theta) < get_epsilon(theta.dtype) 267 | 268 | V_inv_theta = _SE2_jac_left_inv(theta[..., None]) 269 | V_inv_theta_T = jnp.swapaxes( 270 | V_inv_theta, -2, -1 271 | ) # Transpose the last two dimensions 272 | 273 | # Calculate r, handling the small theta case separately. 274 | batch_shape = self.get_batch_axes() 275 | eye_2 = jnp.eye(2).reshape((1,) * len(batch_shape) + (2, 2)) 276 | 277 | # Shim to avoid NaNs in jnp.where branches, which cause failures for reverse-mode AD. 278 | safe_theta = jnp.where(use_taylor, jnp.ones_like(theta), theta) 279 | A = jnp.where( 280 | use_taylor[..., None, None], 281 | jnp.stack( 282 | [ 283 | jnp.stack([theta / 12.0, jnp.full_like(theta, 0.5)], axis=-1), 284 | jnp.stack([jnp.full_like(theta, -0.5), theta / 12.0], axis=-1), 285 | ], 286 | axis=-2, 287 | ), 288 | (eye_2 - V_inv_theta_T) / safe_theta[..., None, None], 289 | ) 290 | r = jnp.einsum("...ij,...j->...i", A, tangent[..., :2]) 291 | 292 | # Create the jlog matrix. 293 | jlog = ( 294 | jnp.zeros((*batch_shape, 3, 3)) 295 | .at[..., :2, :2] 296 | .set(V_inv_theta_T) 297 | .at[..., :2, 2] 298 | .set(r) 299 | .at[..., 2, 2] 300 | .set(1) 301 | ) 302 | return jlog 303 | 304 | @classmethod 305 | @override 306 | def sample_uniform( 307 | cls, key: jax.Array, batch_axes: jdc.Static[Tuple[int, ...]] = () 308 | ) -> "SE2": 309 | key0, key1 = jax.random.split(key) 310 | return SE2.from_rotation_and_translation( 311 | rotation=SO2.sample_uniform(key0, batch_axes=batch_axes), 312 | translation=jax.random.uniform( 313 | key=key1, 314 | shape=( 315 | *batch_axes, 316 | 2, 317 | ), 318 | minval=-1.0, 319 | maxval=1.0, 320 | ), 321 | ) 322 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional 2 | 3 | # -*- coding: utf-8 -*- 4 | # 5 | # Configuration file for the Sphinx documentation builder. 
6 | # 7 | # This file does only contain a selection of the most common options. For a 8 | # full list see the documentation: 9 | # http://www.sphinx-doc.org/en/stable/config 10 | 11 | # -- Path setup -------------------------------------------------------------- 12 | 13 | # If extensions (or modules to document with autodoc) are in another directory, 14 | # add these directories to sys.path here. If the directory is relative to the 15 | # documentation root, use os.path.abspath to make it absolute, like shown here. 16 | # 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = "jaxlie" 21 | copyright = "2020" 22 | author = "brentyi" 23 | 24 | # The short X.Y version 25 | version = "" 26 | # The full version, including alpha/beta/rc tags 27 | release = "" 28 | 29 | 30 | # -- General configuration --------------------------------------------------- 31 | 32 | napoleon_numpy_docstring = False # Force consistency, leave only Google 33 | napoleon_use_rtype = False # More legible 34 | 35 | # If your documentation needs a minimal Sphinx version, state it here. 36 | # 37 | # needs_sphinx = '1.0' 38 | 39 | # Add any Sphinx extension module names here, as strings. They can be 40 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 41 | # ones. 42 | extensions = [ 43 | "sphinx.ext.autodoc", 44 | "sphinx.ext.todo", 45 | "sphinx.ext.coverage", 46 | "sphinx.ext.mathjax", 47 | "sphinx.ext.githubpages", 48 | "sphinx.ext.napoleon", 49 | "sphinx.ext.inheritance_diagram", 50 | "autoapi.extension", 51 | "sphinx_math_dollar", 52 | "sphinx.ext.viewcode", 53 | ] 54 | 55 | # Pull documentation types from hints 56 | autodoc_typehints = "description" 57 | 58 | # Add any paths that contain templates here, relative to this directory. 59 | templates_path = ["_templates"] 60 | 61 | # The suffix(es) of source filenames. 
62 | # You can specify multiple suffix as a list of string: 63 | # 64 | # source_suffix = ['.rst', '.md'] 65 | source_suffix = ".rst" 66 | 67 | # The master toctree document. 68 | master_doc = "index" 69 | 70 | # The language for content autogenerated by Sphinx. Refer to documentation 71 | # for a list of supported languages. 72 | # 73 | # This is also used if you do content translation via gettext catalogs. 74 | # Usually you set "language" from the command line for these cases. 75 | language: Optional[str] = "en" 76 | 77 | # List of patterns, relative to source directory, that match files and 78 | # directories to ignore when looking for source files. 79 | # This pattern also affects html_static_path and html_extra_path . 80 | exclude_patterns: List[str] = [] 81 | 82 | # The name of the Pygments (syntax highlighting) style to use. 83 | pygments_style = "sphinx" 84 | 85 | 86 | # -- Options for HTML output ------------------------------------------------- 87 | 88 | # The theme to use for HTML and HTML Help pages. See the documentation for 89 | # a list of builtin themes. 90 | # 91 | html_theme = "sphinx_rtd_theme" 92 | 93 | 94 | # Theme options are theme-specific and customize the look and feel of a theme 95 | # further. For a list of options available for each theme, see the 96 | # documentation. 97 | # 98 | # html_theme_options = {} 99 | 100 | # Add any paths that contain custom static files (such as style sheets) here, 101 | # relative to this directory. They are copied after the builtin static files, 102 | # so a file named "default.css" will overwrite the builtin "default.css". 103 | html_static_path = ["_static"] 104 | 105 | # Custom sidebar templates, must be a dictionary that maps document names 106 | # to template names. 107 | # 108 | # The default sidebars (for documents that don't match any pattern) are 109 | # defined by theme itself. 
Builtin themes are using these templates by 110 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 111 | # 'searchbox.html']``. 112 | # 113 | # html_sidebars = {} 114 | 115 | 116 | # -- Options for HTMLHelp output --------------------------------------------- 117 | 118 | # Output file base name for HTML help builder. 119 | htmlhelp_basename = "jaxlie_doc" 120 | 121 | 122 | # -- Options for Github output ------------------------------------------------ 123 | 124 | sphinx_to_github = True 125 | sphinx_to_github_verbose = True 126 | sphinx_to_github_encoding = "utf-8" 127 | 128 | 129 | # -- Options for LaTeX output ------------------------------------------------ 130 | 131 | latex_elements: Dict[str, str] = { 132 | # The paper size ('letterpaper' or 'a4paper'). 133 | # 134 | # 'papersize': 'letterpaper', 135 | # The font size ('10pt', '11pt' or '12pt'). 136 | # 137 | # 'pointsize': '10pt', 138 | # Additional stuff for the LaTeX preamble. 139 | # 140 | # 'preamble': '', 141 | # Latex figure (float) alignment 142 | # 143 | # 'figure_align': 'htbp', 144 | } 145 | 146 | # Grouping the document tree into LaTeX files. List of tuples 147 | # (source start file, target name, title, 148 | # author, documentclass [howto, manual, or own class]). 149 | latex_documents = [ 150 | ( 151 | master_doc, 152 | "jaxlie.tex", 153 | "jaxlie documentation", 154 | "brentyi", 155 | "manual", 156 | ), 157 | ] 158 | 159 | 160 | # -- Options for manual page output ------------------------------------------ 161 | 162 | # One entry per manual page. List of tuples 163 | # (source start file, name, description, authors, manual section). 164 | man_pages = [(master_doc, "jaxlie", "jaxlie documentation", [author], 1)] 165 | 166 | 167 | # -- Options for Texinfo output ---------------------------------------------- 168 | 169 | # Grouping the document tree into Texinfo files. 
# Generate name aliases
def _gen_name_aliases():
    """Generate a name alias dictionary, which maps private names to ones in the
    public API. A little bit hardcoded/hacky.

    Returns:
        Dict mapping full (private) dotted names, e.g. `jaxlie._se2.SE2`, to
        their shortest public equivalents, e.g. `jaxlie.SE2`.
    """

    name_alias = {}

    def recurse(module, prefixes):
        # Only descend into jaxlie's own modules; anything external is skipped,
        # and depth is capped to guard against cyclic imports.
        if hasattr(module, "__name__") and module.__name__.startswith("jaxlie"):
            MAX_DEPTH = 5
            if len(prefixes) > MAX_DEPTH:
                # Prevent infinite loops from cyclic imports
                return
        else:
            return

        for member_name in dir(module):
            if member_name == "jaxlie":
                continue

            member = getattr(module, member_name)
            if callable(member):
                full_name = ".".join(["jaxlie"] + prefixes + [member_name])

                # Build the shortest public path to the same object by walking
                # down from the package root, skipping private path components.
                shortened_name = "jaxlie"
                current = jaxlie
                success = True
                for p in prefixes + [member_name]:
                    if p.startswith("_"):
                        continue
                    if not hasattr(current, p):
                        success = False
                        break
                    current = getattr(current, p)
                    shortened_name += "." + p

                if success and shortened_name != full_name:
                    if full_name in name_alias:
                        # Bug fix: the consistency check previously indexed the
                        # table with `shortened_name`, but keys are full
                        # (private) names — it would KeyError (or compare the
                        # wrong entries) on a repeated full name. Check that a
                        # repeated full name maps to the same public alias.
                        assert shortened_name == name_alias[full_name], full_name
                    else:
                        name_alias[full_name] = shortened_name
            elif not member_name.startswith("__"):
                recurse(member, prefixes + [member_name])

    import jaxlie

    recurse(jaxlie, prefixes=[])
    return name_alias
_apply_name_aliases(obj["return_annotation"]) 306 | orig_init(self, obj, **kwargs) 307 | 308 | autoapi.mappers.python.PythonFunction.__init__ = __init__ 309 | 310 | 311 | _override_function_documenter() 312 | 313 | 314 | # Apply our inheritance alias to autoapi attribute annotations 315 | def _override_attribute_documenter(): 316 | import autoapi 317 | 318 | orig_init = autoapi.mappers.python.PythonAttribute.__init__ 319 | 320 | def __init__(self, obj, **kwargs): 321 | obj["annotation"] = _apply_name_aliases(obj["annotation"]) 322 | orig_init(self, obj, **kwargs) 323 | 324 | autoapi.mappers.python.PythonAttribute.__init__ = __init__ 325 | 326 | 327 | _override_attribute_documenter() 328 | 329 | 330 | # -- Options for todo extension ---------------------------------------------- 331 | 332 | # If true, `todo` and `todoList` produce output, else they produce nothing. 333 | todo_include_todos = True 334 | 335 | # -- Enable Markdown -> RST conversion ---------------------------------------- 336 | 337 | import m2r2 338 | 339 | 340 | def docstring(app, what, name, obj, options, lines): 341 | md = "\n".join(lines) 342 | rst = m2r2.convert(md) 343 | lines.clear() 344 | lines += rst.splitlines() 345 | 346 | 347 | def setup(app): 348 | app.connect("autodoc-process-docstring", docstring) 349 | -------------------------------------------------------------------------------- /jaxlie/_so3.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Tuple, cast 4 | 5 | import jax 6 | import jax_dataclasses as jdc 7 | from jax import numpy as jnp 8 | from typing_extensions import override 9 | 10 | from . 
import _base, hints
from .utils import broadcast_leading_axes, get_epsilon, register_lie_group


def _skew(omega: hints.Array) -> jax.Array:
    """Returns the skew-symmetric form of a length-3 vector.

    Args:
        omega: Vector(s) with shape `(*, 3)`.

    Returns:
        Skew-symmetric matrices with shape `(*, 3, 3)`.
    """

    wx, wy, wz = jnp.moveaxis(omega, -1, 0)
    zeros = jnp.zeros_like(wx)
    return jnp.stack(
        [zeros, -wz, wy, wz, zeros, -wx, -wy, wx, zeros],
        axis=-1,
    ).reshape((*omega.shape[:-1], 3, 3))


def _SO3_jac_left(theta: jax.Array, rotation_matrix: jax.Array) -> jax.Array:
    """Compute the left jacobian for the given theta and rotation matrix.

    This function calculates the left jacobian, which is used in various geometric
    transformations. It handles both small and large theta values using different
    computation methods.

    Args:
        theta: The input angle(s) in axis-angle representation. Shape `(*, 3)`.
        rotation_matrix: The corresponding rotation matrix. Shape `(*, 3, 3)`.

    Returns:
        A 3x3 matrix (or batch of 3x3 matrices) representing the left jacobian.
    """
    theta_squared = jnp.sum(jnp.square(theta), axis=-1)
    use_taylor = theta_squared < get_epsilon(theta_squared.dtype)

    # Shim to avoid NaNs in jnp.where branches, which cause failures for
    # reverse-mode AD.
    theta_squared_safe = cast(
        jax.Array,
        jnp.where(
            use_taylor,
            # Any non-zero value should do here.
            jnp.ones_like(theta_squared),
            theta_squared,
        ),
    )
    del theta_squared
    theta_safe = jnp.sqrt(theta_squared_safe)

    skew_omega = _skew(theta)
    jac_left = jnp.where(
        use_taylor[..., None, None],
        rotation_matrix,
        (
            jnp.eye(3)
            + ((1.0 - jnp.cos(theta_safe)) / (theta_squared_safe))[..., None, None]
            * skew_omega
            + ((theta_safe - jnp.sin(theta_safe)) / (theta_squared_safe * theta_safe))[
                ..., None, None
            ]
            * jnp.einsum("...ij,...jk->...ik", skew_omega, skew_omega)
        ),
    )
    return jac_left


def _SO3_jac_left_inv(theta: jax.Array) -> jax.Array:
    """Compute the inverse of the left jacobian for the given theta.

    This function calculates the inverse of the left jacobian, which is used in
    various geometric transformations. It handles both small and large theta
    values using different computation methods.

    Args:
        theta: The input angle(s) in axis-angle representation. Shape `(*, 3)`.

    Returns:
        A 3x3 matrix (or batch of 3x3 matrices) representing the inverse left
        jacobian.
    """
    theta_squared = jnp.sum(jnp.square(theta), axis=-1)
    use_taylor = theta_squared < get_epsilon(theta_squared.dtype)

    # Shim to avoid NaNs in jnp.where branches, which cause failures for
    # reverse-mode AD.
    theta_squared_safe = jnp.where(
        use_taylor,
        jnp.ones_like(theta_squared),  # Any non-zero value should do here.
        theta_squared,
    )
    del theta_squared
    theta_safe = jnp.sqrt(theta_squared_safe)
    half_theta_safe = theta_safe / 2.0

    skew_omega = _skew(theta)
    jac_left_inv = jnp.where(
        use_taylor[..., None, None],
        # Second-order Taylor expansion around theta = 0.
        jnp.eye(3)
        - 0.5 * skew_omega
        + jnp.einsum("...ij,...jk->...ik", skew_omega, skew_omega) / 12.0,
        (
            jnp.eye(3)
            - 0.5 * skew_omega
            + (
                (
                    1.0
                    - theta_safe
                    * jnp.cos(half_theta_safe)
                    / (2.0 * jnp.sin(half_theta_safe))
                )
                / theta_squared_safe
            )[..., None, None]
            * jnp.einsum("...ij,...jk->...ik", skew_omega, skew_omega)
        ),
    )
    return jac_left_inv


@register_lie_group(
    matrix_dim=3,
    parameters_dim=4,
    tangent_dim=3,
    space_dim=3,
)
@jdc.pytree_dataclass
class SO3(_base.SOBase):
    """Special orthogonal group for 3D rotations. Broadcasting rules are the same as
    for numpy.

    Internal parameterization is `(qw, qx, qy, qz)`. Tangent parameterization is
    `(omega_x, omega_y, omega_z)`.
    """

    wxyz: jax.Array
    """Internal parameters. `(w, x, y, z)` quaternion. Shape should be `(*, 4)`."""

    @override
    def __repr__(self) -> str:
        wxyz = jnp.round(self.wxyz, 5)
        return f"{self.__class__.__name__}(wxyz={wxyz})"

    @staticmethod
    def from_x_radians(theta: hints.Scalar) -> SO3:
        """Generates a x-axis rotation.

        Args:
            theta: X rotation, in radians.

        Returns:
            Output.
        """
        zeros = jnp.zeros_like(theta)
        return SO3.exp(jnp.stack([theta, zeros, zeros], axis=-1))

    @staticmethod
    def from_y_radians(theta: hints.Scalar) -> SO3:
        """Generates a y-axis rotation.

        Args:
            theta: Y rotation, in radians.

        Returns:
            Output.
        """
        zeros = jnp.zeros_like(theta)
        return SO3.exp(jnp.stack([zeros, theta, zeros], axis=-1))

    @staticmethod
    def from_z_radians(theta: hints.Scalar) -> SO3:
        """Generates a z-axis rotation.

        Args:
            theta: Z rotation, in radians.

        Returns:
            Output.
        """
        zeros = jnp.zeros_like(theta)
        return SO3.exp(jnp.stack([zeros, zeros, theta], axis=-1))

    @staticmethod
    def from_rpy_radians(
        roll: hints.Scalar,
        pitch: hints.Scalar,
        yaw: hints.Scalar,
    ) -> SO3:
        """Generates a transform from a set of Euler angles. Uses the ZYX mobile robot
        convention.

        Args:
            roll: X rotation, in radians. Applied first.
            pitch: Y rotation, in radians. Applied second.
            yaw: Z rotation, in radians. Applied last.

        Returns:
            Output.
        """
        return (
            SO3.from_z_radians(yaw)
            @ SO3.from_y_radians(pitch)
            @ SO3.from_x_radians(roll)
        )

    @staticmethod
    def from_quaternion_xyzw(xyzw: hints.Array) -> SO3:
        """Construct a rotation from an `xyzw` quaternion.

        Note that `wxyz` quaternions can be constructed using the default dataclass
        constructor.

        Args:
            xyzw: xyzw quaternion. Shape should be (*, 4).

        Returns:
            Output.
        """
        assert xyzw.shape[-1:] == (4,)
        return SO3(jnp.roll(xyzw, axis=-1, shift=1))

    def as_quaternion_xyzw(self) -> jax.Array:
        """Grab parameters as xyzw quaternion."""
        return jnp.roll(self.wxyz, axis=-1, shift=-1)

    def as_rpy_radians(self) -> hints.RollPitchYaw:
        """Computes roll, pitch, and yaw angles. Uses the ZYX mobile robot convention.

        Returns:
            Named tuple containing Euler angles in radians.
        """
        return hints.RollPitchYaw(
            roll=self.compute_roll_radians(),
            pitch=self.compute_pitch_radians(),
            yaw=self.compute_yaw_radians(),
        )

    def compute_roll_radians(self) -> jax.Array:
        """Compute roll angle. Uses the ZYX mobile robot convention.

        Returns:
            Euler angle in radians.
        """
        # https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles#Quaternion_to_Euler_angles_conversion
        q0, q1, q2, q3 = jnp.moveaxis(self.wxyz, -1, 0)
        return jnp.arctan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1**2 + q2**2))

    def compute_pitch_radians(self) -> jax.Array:
        """Compute pitch angle. Uses the ZYX mobile robot convention.

        Returns:
            Euler angle in radians.
        """
        # https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles#Quaternion_to_Euler_angles_conversion
        q0, q1, q2, q3 = jnp.moveaxis(self.wxyz, -1, 0)
        return jnp.arcsin(2 * (q0 * q2 - q3 * q1))

    def compute_yaw_radians(self) -> jax.Array:
        """Compute yaw angle. Uses the ZYX mobile robot convention.

        Returns:
            Euler angle in radians.
        """
        # https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles#Quaternion_to_Euler_angles_conversion
        q0, q1, q2, q3 = jnp.moveaxis(self.wxyz, -1, 0)
        return jnp.arctan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2**2 + q3**2))

    # Factory.
272 | 273 | @classmethod 274 | @override 275 | def identity(cls, batch_axes: jdc.Static[Tuple[int, ...]] = ()) -> SO3: 276 | return SO3( 277 | wxyz=jnp.broadcast_to(jnp.array([1.0, 0.0, 0.0, 0.0]), (*batch_axes, 4)) 278 | ) 279 | 280 | @classmethod 281 | @override 282 | def from_matrix(cls, matrix: hints.Array) -> SO3: 283 | assert matrix.shape[-2:] == (3, 3) 284 | 285 | # Modified from: 286 | # > "Converting a Rotation Matrix to a Quaternion" from Mike Day 287 | # > https://d3cw3dd2w32x2b.cloudfront.net/wp-content/uploads/2015/01/matrix-to-quat.pdf 288 | 289 | def case0(m): 290 | t = 1 + m[..., 0, 0] - m[..., 1, 1] - m[..., 2, 2] 291 | q = jnp.stack( 292 | [ 293 | m[..., 2, 1] - m[..., 1, 2], 294 | t, 295 | m[..., 1, 0] + m[..., 0, 1], 296 | m[..., 0, 2] + m[..., 2, 0], 297 | ], 298 | axis=-1, 299 | ) 300 | return t, q 301 | 302 | def case1(m): 303 | t = 1 - m[..., 0, 0] + m[..., 1, 1] - m[..., 2, 2] 304 | q = jnp.stack( 305 | [ 306 | m[..., 0, 2] - m[..., 2, 0], 307 | m[..., 1, 0] + m[..., 0, 1], 308 | t, 309 | m[..., 2, 1] + m[..., 1, 2], 310 | ], 311 | axis=-1, 312 | ) 313 | return t, q 314 | 315 | def case2(m): 316 | t = 1 - m[..., 0, 0] - m[..., 1, 1] + m[..., 2, 2] 317 | q = jnp.stack( 318 | [ 319 | m[..., 1, 0] - m[..., 0, 1], 320 | m[..., 0, 2] + m[..., 2, 0], 321 | m[..., 2, 1] + m[..., 1, 2], 322 | t, 323 | ], 324 | axis=-1, 325 | ) 326 | return t, q 327 | 328 | def case3(m): 329 | t = 1 + m[..., 0, 0] + m[..., 1, 1] + m[..., 2, 2] 330 | q = jnp.stack( 331 | [ 332 | t, 333 | m[..., 2, 1] - m[..., 1, 2], 334 | m[..., 0, 2] - m[..., 2, 0], 335 | m[..., 1, 0] - m[..., 0, 1], 336 | ], 337 | axis=-1, 338 | ) 339 | return t, q 340 | 341 | # Compute four cases, then pick the most precise one. 342 | # Probably worth revisiting this! 
343 | case0_t, case0_q = case0(matrix) 344 | case1_t, case1_q = case1(matrix) 345 | case2_t, case2_q = case2(matrix) 346 | case3_t, case3_q = case3(matrix) 347 | 348 | cond0 = matrix[..., 2, 2] < 0 349 | cond1 = matrix[..., 0, 0] > matrix[..., 1, 1] 350 | cond2 = matrix[..., 0, 0] < -matrix[..., 1, 1] 351 | 352 | t = jnp.where( 353 | cond0, 354 | jnp.where(cond1, case0_t, case1_t), 355 | jnp.where(cond2, case2_t, case3_t), 356 | ) 357 | q = jnp.where( 358 | cond0[..., None], 359 | jnp.where(cond1[..., None], case0_q, case1_q), 360 | jnp.where(cond2[..., None], case2_q, case3_q), 361 | ) 362 | 363 | # We can also choose to branch, but this is slower. 364 | # t, q = jax.lax.cond( 365 | # matrix[2, 2] < 0, 366 | # true_fun=lambda matrix: jax.lax.cond( 367 | # matrix[0, 0] > matrix[1, 1], 368 | # true_fun=case0, 369 | # false_fun=case1, 370 | # operand=matrix, 371 | # ), 372 | # false_fun=lambda matrix: jax.lax.cond( 373 | # matrix[0, 0] < -matrix[1, 1], 374 | # true_fun=case2, 375 | # false_fun=case3, 376 | # operand=matrix, 377 | # ), 378 | # operand=matrix, 379 | # ) 380 | 381 | return SO3(wxyz=q * 0.5 / jnp.sqrt(t[..., None])) 382 | 383 | # Accessors. 
384 | 385 | @override 386 | def as_matrix(self) -> jax.Array: 387 | norm_sq = jnp.sum(jnp.square(self.wxyz), axis=-1, keepdims=True) 388 | q = self.wxyz * jnp.sqrt(2.0 / norm_sq) # (*, 4) 389 | q_outer = jnp.einsum("...i,...j->...ij", q, q) # (*, 4, 4) 390 | return jnp.stack( 391 | [ 392 | 1.0 - q_outer[..., 2, 2] - q_outer[..., 3, 3], 393 | q_outer[..., 1, 2] - q_outer[..., 3, 0], 394 | q_outer[..., 1, 3] + q_outer[..., 2, 0], 395 | q_outer[..., 1, 2] + q_outer[..., 3, 0], 396 | 1.0 - q_outer[..., 1, 1] - q_outer[..., 3, 3], 397 | q_outer[..., 2, 3] - q_outer[..., 1, 0], 398 | q_outer[..., 1, 3] - q_outer[..., 2, 0], 399 | q_outer[..., 2, 3] + q_outer[..., 1, 0], 400 | 1.0 - q_outer[..., 1, 1] - q_outer[..., 2, 2], 401 | ], 402 | axis=-1, 403 | ).reshape(*q.shape[:-1], 3, 3) 404 | 405 | @override 406 | def parameters(self) -> jax.Array: 407 | return self.wxyz 408 | 409 | # Operations. 410 | 411 | @override 412 | def apply(self, target: hints.Array) -> jax.Array: 413 | assert target.shape[-1:] == (3,) 414 | self, target = broadcast_leading_axes((self, target)) 415 | 416 | # Compute using quaternion multiplys. 417 | padded_target = jnp.concatenate( 418 | [jnp.zeros((*self.get_batch_axes(), 1)), target], axis=-1 419 | ) 420 | return (self @ SO3(wxyz=padded_target) @ self.inverse()).wxyz[..., 1:] 421 | 422 | @override 423 | def multiply(self, other: SO3) -> SO3: 424 | # Original implementation: 425 | # 426 | # w0, x0, y0, z0 = jnp.moveaxis(self.wxyz, -1, 0) 427 | # w1, x1, y1, z1 = jnp.moveaxis(other.wxyz, -1, 0) 428 | # return SO3( 429 | # wxyz=jnp.stack( 430 | # [ 431 | # -x0 * x1 - y0 * y1 - z0 * z1 + w0 * w1, 432 | # x0 * w1 + y0 * z1 - z0 * y1 + w0 * x1, 433 | # -x0 * z1 + y0 * w1 + z0 * x1 + w0 * y1, 434 | # x0 * y1 - y0 * x1 + z0 * w1 + w0 * z1, 435 | # ], 436 | # axis=-1, 437 | # ) 438 | # ) 439 | # 440 | # This is great/fine/standard, but there are a lot of operations. This 441 | # puts a lot of burden on the JIT compiler. 
442 | # 443 | # Here's another implementation option. The JIT time is much faster, but the 444 | # runtime is ~10% slower: 445 | # 446 | # inds = jnp.array([0, 1, 2, 3, 1, 0, 3, 2, 2, 3, 0, 1, 3, 2, 1, 0]) 447 | # signs = jnp.array([1, -1, -1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, 1, 1]) 448 | # return SO3( 449 | # wxyz=jnp.einsum( 450 | # "...ij,...j->...i", 451 | # (self.wxyz[..., inds] * signs).reshape((*self.wxyz.shape, 4)), 452 | # other.wxyz, 453 | # ) 454 | # ) 455 | # 456 | # For pose graph optimization on the sphere2500 dataset, the following 457 | # speeds up *overall* JIT times by over 35%, without any runtime 458 | # penalties. 459 | 460 | # Hamilton product constants. 461 | terms_i = jnp.array([[0, 1, 2, 3], [0, 1, 2, 3], [0, 2, 3, 1], [0, 3, 1, 2]]) 462 | terms_j = jnp.array([[0, 1, 2, 3], [1, 0, 3, 2], [2, 0, 1, 3], [3, 0, 2, 1]]) 463 | signs = jnp.array( 464 | [ 465 | [1, -1, -1, -1], 466 | [1, 1, 1, -1], 467 | [1, 1, 1, -1], 468 | [1, 1, 1, -1], 469 | ] 470 | ) 471 | 472 | # Compute all components at once 473 | q_outer = jnp.einsum("...i,...j->...ij", self.wxyz, other.wxyz) 474 | return SO3( 475 | jnp.sum( 476 | signs * q_outer[..., terms_i, terms_j], 477 | axis=-1, 478 | ) 479 | ) 480 | 481 | @classmethod 482 | @override 483 | def exp(cls, tangent: hints.Array) -> SO3: 484 | # Reference: 485 | # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/so3.hpp#L583 486 | 487 | assert tangent.shape[-1:] == (3,) 488 | 489 | theta_squared = jnp.sum(jnp.square(tangent), axis=-1) 490 | theta_pow_4 = theta_squared * theta_squared 491 | use_taylor = theta_squared < get_epsilon(tangent.dtype) 492 | 493 | # Shim to avoid NaNs in jnp.where branches, which cause failures for 494 | # reverse-mode AD. 495 | safe_theta = jnp.sqrt( 496 | jnp.where( 497 | use_taylor, 498 | # Any constant value should do here. 
499 | jnp.ones_like(theta_squared), 500 | theta_squared, 501 | ) 502 | ) 503 | safe_half_theta = 0.5 * safe_theta 504 | 505 | real_factor = jnp.where( 506 | use_taylor, 507 | 1.0 - theta_squared / 8.0 + theta_pow_4 / 384.0, 508 | jnp.cos(safe_half_theta), 509 | ) 510 | 511 | imaginary_factor = jnp.where( 512 | use_taylor, 513 | 0.5 - theta_squared / 48.0 + theta_pow_4 / 3840.0, 514 | jnp.sin(safe_half_theta) / safe_theta, 515 | ) 516 | 517 | return SO3( 518 | wxyz=jnp.concatenate( 519 | [ 520 | real_factor[..., None], 521 | imaginary_factor[..., None] * tangent, 522 | ], 523 | axis=-1, 524 | ) 525 | ) 526 | 527 | @override 528 | def log(self) -> jax.Array: 529 | # Reference: 530 | # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/so3.hpp#L247 531 | 532 | w = self.wxyz[..., 0] 533 | norm_sq = jnp.sum(jnp.square(self.wxyz[..., 1:]), axis=-1) 534 | use_taylor = norm_sq < get_epsilon(norm_sq.dtype) 535 | 536 | # Shim to avoid NaNs in jnp.where branches, which cause failures for 537 | # reverse-mode AD. 538 | norm_safe = jnp.sqrt( 539 | jnp.where( 540 | use_taylor, 541 | 1.0, # Any non-zero value should do here. 542 | norm_sq, 543 | ) 544 | ) 545 | w_safe = jnp.where(use_taylor, w, 1.0) 546 | atan_n_over_w = jnp.arctan2( 547 | jnp.where(w < 0, -norm_safe, norm_safe), 548 | jnp.abs(w), 549 | ) 550 | atan_factor = jnp.where( 551 | use_taylor, 552 | 2.0 / w_safe - 2.0 / 3.0 * norm_sq / w_safe**3, 553 | jnp.where( 554 | jnp.abs(w) < get_epsilon(w.dtype), 555 | jnp.where(w > 0, 1.0, -1.0) * jnp.pi / norm_safe, 556 | 2.0 * atan_n_over_w / norm_safe, 557 | ), 558 | ) 559 | 560 | return atan_factor[..., None] * self.wxyz[..., 1:] 561 | 562 | @override 563 | def adjoint(self) -> jax.Array: 564 | return self.as_matrix() 565 | 566 | @override 567 | def inverse(self) -> SO3: 568 | # Negate complex terms. 
569 | return SO3(wxyz=self.wxyz * jnp.array([1, -1, -1, -1])) 570 | 571 | @override 572 | def normalize(self) -> SO3: 573 | return SO3(wxyz=self.wxyz / jnp.linalg.norm(self.wxyz, axis=-1, keepdims=True)) 574 | 575 | @override 576 | def jlog(self) -> jax.Array: 577 | # Reference: 578 | # Equations (144, 147, 174) from Micro-Lie theory: 579 | # > https://arxiv.org/pdf/1812.01537 580 | V_inv = _SO3_jac_left_inv(self.log()) 581 | return jnp.swapaxes(V_inv, -1, -2) # Transpose the last two dimensions 582 | 583 | @classmethod 584 | @override 585 | def sample_uniform( 586 | cls, key: jax.Array, batch_axes: jdc.Static[Tuple[int, ...]] = () 587 | ) -> SO3: 588 | # Uniformly sample over S^3. 589 | # > Reference: http://planning.cs.uiuc.edu/node198.html 590 | u1, u2, u3 = jnp.moveaxis( 591 | jax.random.uniform( 592 | key=key, 593 | shape=(*batch_axes, 3), 594 | minval=jnp.zeros(3), 595 | maxval=jnp.array([1.0, 2.0 * jnp.pi, 2.0 * jnp.pi]), 596 | ), 597 | -1, 598 | 0, 599 | ) 600 | a = jnp.sqrt(1.0 - u1) 601 | b = jnp.sqrt(u1) 602 | 603 | return SO3( 604 | wxyz=jnp.stack( 605 | [ 606 | a * jnp.sin(u2), 607 | a * jnp.cos(u2), 608 | b * jnp.sin(u3), 609 | b * jnp.cos(u3), 610 | ], 611 | axis=-1, 612 | ) 613 | ) 614 | --------------------------------------------------------------------------------