├── .DS_Store ├── .github └── workflows │ ├── build.yml │ ├── docs.yml │ ├── mypy.yml │ ├── ruff.yml │ └── tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── benchmarks ├── bench_global_ik_vmapped_output.py ├── bench_local_ik.py └── bench_local_ik_vmapped_output.py ├── docs ├── README.md ├── barriers.rst ├── components.rst ├── conf.py ├── configuration.rst ├── constraints.rst ├── css │ └── custom.css ├── developer-notes.rst ├── img │ └── kuka_iiwa_14.png ├── index.rst ├── installation.rst ├── problem.rst ├── quick_start.rst ├── references.bib ├── references.rst ├── solvers.rst ├── tasks.rst ├── typing.rst └── visualization.rst ├── examples ├── cassie_alip.py ├── cassie_squat.py ├── g1_description │ ├── LICENSE │ ├── README.md │ ├── assets │ │ ├── head_link.STL │ │ ├── left_ankle_pitch_link.STL │ │ ├── left_ankle_roll_link.STL │ │ ├── left_elbow_pitch_link.STL │ │ ├── left_elbow_roll_link.STL │ │ ├── left_five_link.STL │ │ ├── left_four_link.STL │ │ ├── left_hip_pitch_link.STL │ │ ├── left_hip_roll_link.STL │ │ ├── left_hip_yaw_link.STL │ │ ├── left_knee_link.STL │ │ ├── left_one_link.STL │ │ ├── left_palm_link.STL │ │ ├── left_shoulder_pitch_link.STL │ │ ├── left_shoulder_roll_link.STL │ │ ├── left_shoulder_yaw_link.STL │ │ ├── left_six_link.STL │ │ ├── left_three_link.STL │ │ ├── left_two_link.STL │ │ ├── left_zero_link.STL │ │ ├── logo_link.STL │ │ ├── pelvis.STL │ │ ├── pelvis_contour_link.STL │ │ ├── right_ankle_pitch_link.STL │ │ ├── right_ankle_roll_link.STL │ │ ├── right_elbow_pitch_link.STL │ │ ├── right_elbow_roll_link.STL │ │ ├── right_five_link.STL │ │ ├── right_four_link.STL │ │ ├── right_hip_pitch_link.STL │ │ ├── right_hip_roll_link.STL │ │ ├── right_hip_yaw_link.STL │ │ ├── right_knee_link.STL │ │ ├── right_one_link.STL │ │ ├── right_palm_link.STL │ │ ├── right_shoulder_pitch_link.STL │ │ ├── right_shoulder_roll_link.STL │ │ ├── right_shoulder_yaw_link.STL │ │ ├── right_six_link.STL │ │ ├── 
right_three_link.STL │ │ ├── right_two_link.STL │ │ ├── right_zero_link.STL │ │ └── torso_link.STL │ ├── g1.png │ ├── g1.xml │ └── scene.xml ├── g1_heart.py ├── g1_heart_constrained.py ├── global_ik.py ├── global_ik_vmapped_input.py ├── global_ik_vmapped_output.py ├── go2_squat.py ├── local_ik.py ├── local_ik_vmapped_input.py ├── local_ik_vmapped_output.py └── notebooks │ └── turoial.ipynb ├── img ├── cassie_caravan.gif ├── g1_heart.gif ├── go2_stance.gif ├── local_ik_input.gif ├── local_ik_output.gif └── logo.svg ├── mjinx ├── __init__.py ├── components │ ├── __init__.py │ ├── _base.py │ ├── barriers │ │ ├── __init__.py │ │ ├── _base.py │ │ ├── _joint_barrier.py │ │ ├── _obj_barrier.py │ │ ├── _obj_position_barrier.py │ │ └── _self_collision_barrier.py │ ├── constraints │ │ ├── __init__.py │ │ ├── _base.py │ │ └── _equality_constraint.py │ └── tasks │ │ ├── __init__.py │ │ ├── _base.py │ │ ├── _com_task.py │ │ ├── _joint_task.py │ │ ├── _obj_frame_task.py │ │ ├── _obj_position_task.py │ │ └── _obj_task.py ├── configuration │ ├── __init__.py │ ├── _collision.py │ ├── _lie.py │ └── _model.py ├── model.py ├── problem.py ├── solvers │ ├── __init__.py │ ├── _base.py │ ├── _global_ik.py │ └── _local_ik.py ├── typing.py └── visualize.py ├── mypy.ini ├── pyproject.toml └── tests ├── __init__.py ├── example_tests ├── test_global_ik_jit.py ├── test_global_ik_vmap.py ├── test_local_ik_jit.py └── test_local_ik_vmap.py └── unit_tests ├── __init__.py ├── components ├── __init__.py ├── barriers │ ├── __init__.py │ ├── test_base_barrier.py │ ├── test_joint_barrier.py │ ├── test_obj_barrier.py │ ├── test_obj_position_barrier.py │ └── test_self_collision_barrier.py ├── constraints │ ├── __init__.py │ └── test_base_constraint.py ├── tasks │ ├── __init__.py │ ├── test_base_task.py │ ├── test_com_task.py │ ├── test_joint_task.py │ ├── test_obj_frame_task.py │ ├── test_obj_position_task.py │ └── test_obj_task.py └── test_base_component.py ├── configuration ├── __init__.py ├── 
test_collision_computation.py ├── test_collision_utils.py └── test_configuration.py ├── solvers ├── __init__.py ├── test_global_ik.py └── test_local_ik.py └── test_problem.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/.DS_Store -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Upload Python Package 2 | 3 | on: 4 | push: 5 | tags: 6 | - "v*" # Triggers on any tag that starts with 'v' 7 | 8 | permissions: 9 | contents: read 10 | id-token: write # IMPORTANT: mandatory for trusted publishing 11 | jobs: 12 | # tests: 13 | # uses: ./.github/workflows/tests.yml # use the callable tests job to run tests 14 | 15 | deploy: 16 | runs-on: ubuntu-latest 17 | # run on tag only 18 | if: startsWith(github.ref, 'refs/tags/') 19 | # needs: [tests] # require tests to pass before deploy runs 20 | steps: 21 | - name: "Checkout Git repository" 22 | uses: actions/checkout@v3 23 | 24 | - name: Set up Python 25 | uses: actions/setup-python@v3 26 | with: 27 | python-version: "3.x" 28 | 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install build 33 | 34 | - name: Build package 35 | run: python -m build --wheel 36 | 37 | - name: Publish package 38 | uses: pypa/gh-action-pypi-publish@release/v1 39 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: Documentation 2 | 3 | on: 4 | push: 5 | branches: [main, docs/github_pages] 6 | pull_request: 7 | branches: [main] 8 | 9 | jobs: 10 | docs: 11 | name: "GitHub Pages" 12 | runs-on: ubuntu-latest 13 | permissions: 14 | contents: write 15 | steps: 16 | - name: 
"Checkout Git repository" 17 | uses: actions/checkout@v3 18 | 19 | - name: "Set up Python" 20 | uses: actions/setup-python@v4 21 | with: 22 | python-version: "3.10" # Specify the Python version you need 23 | 24 | - name: "Install dependencies" 25 | run: | 26 | python -m pip install --upgrade pip 27 | pip install ".[all]" 28 | 29 | - name: "Build documentation" 30 | run: | 31 | sphinx-build docs _build 32 | 33 | - name: "Deploy to GitHub Pages" 34 | uses: peaceiris/actions-gh-pages@v3 35 | if: ${{ github.ref == 'refs/heads/main' }} 36 | with: 37 | github_token: ${{ secrets.GITHUB_TOKEN }} 38 | publish_dir: _build/ 39 | force_orphan: true 40 | -------------------------------------------------------------------------------- /.github/workflows/mypy.yml: -------------------------------------------------------------------------------- 1 | name: mypy 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | jobs: 10 | mypy: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Check out the repository 14 | uses: actions/checkout@v3 15 | 16 | - name: Set up Python 17 | uses: actions/setup-python@v4 18 | with: 19 | python-version: "3.12" 20 | 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install -e . 
25 | pip install mypy 26 | 27 | - name: Run mypy 28 | run: | 29 | mypy mjinx 30 | -------------------------------------------------------------------------------- /.github/workflows/ruff.yml: -------------------------------------------------------------------------------- 1 | name: Formatting 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | permissions: 9 | contents: read 10 | 11 | jobs: 12 | ruff: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v3 16 | - name: Install Python 17 | uses: actions/setup-python@v4 18 | with: 19 | python-version: "3.10" 20 | - name: Install dependencies 21 | run: | 22 | python -m pip install --upgrade pip 23 | pip install ruff==0.6.3 # keep in sync with pyproject.toml 24 | # Update output format to enable automatic inline annotations. 25 | - name: Run Ruff 26 | run: ruff format --diff . 27 | 28 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: tests 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | workflow_call: # allow this workflow to be called from other workflows 9 | 10 | jobs: 11 | unit-tests: 12 | name: "Unit tests" 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - name: "Checkout sources" 17 | uses: actions/checkout@v3 18 | with: 19 | submodules: recursive 20 | 21 | - name: "Set up Python" 22 | uses: actions/setup-python@v4 23 | with: 24 | python-version: "3.10" 25 | 26 | - name: "Install dependencies" 27 | run: | 28 | python -m pip install --upgrade pip 29 | python -m pip install coveralls 30 | pip install ".[dev]" 31 | 32 | - name: "Run unit tests" 33 | run: | 34 | coverage erase 35 | coverage run -m unittest discover --failfast 36 | coverage report --include="mjinx/*" 37 | 38 | test: 39 | runs-on: ubuntu-latest 40 | strategy: 41 | matrix: 42 | python-version: ["3.10", "3.11", 
"3.12", "3.13"] 43 | 44 | steps: 45 | - name: Check out the repository 46 | uses: actions/checkout@v3 47 | 48 | - name: Set up Python ${{ matrix.python-version }} 49 | uses: actions/setup-python@v4 50 | with: 51 | python-version: ${{ matrix.python-version }} 52 | 53 | - name: Install dependencies 54 | run: | 55 | python -m pip install --upgrade pip 56 | pip install ".[dev]" 57 | 58 | - name: Run tests 59 | run: | 60 | pytest tests/example_tests 61 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | _build 164 | MUJOCO_LOG.TXT 165 | *.mp4 166 | *.prof 167 | benchmarks/logs/* 168 | *.vscode/* -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | # Ruff version. 4 | rev: v0.6.3 5 | hooks: 6 | # Run the formatter. 
7 | - id: ruff-format 8 | types_or: [python, pyi, jupyter] 9 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to mjinx 2 | 3 | We're thrilled that you're interested in contributing to mjinx! This document outlines the process for contributing to this project. 4 | 5 | ## How to Contribute 6 | 7 | There are many ways to contribute to mjinx: 8 | 9 | 1. Reporting bugs 10 | 2. Suggesting enhancements 11 | 3. Writing documentation 12 | 4. Submitting code changes 13 | 14 | ### Reporting Bugs 15 | 16 | 1. Check the [issue tracker](https://github.com/based-robotics/mjinx/issues) to see if the bug has already been reported. 17 | 2. If not, create a new issue. Provide a clear title and description, as much relevant information as possible, and a code sample or executable test case demonstrating the bug. 18 | 19 | ### Suggesting Enhancements 20 | 21 | 1. Check the [issue tracker](https://github.com/based-robotics/mjinx/issues) to see if the enhancement has already been suggested. 22 | 2. If not, create a new issue. Clearly describe the enhancement, why it would be useful, and any potential drawbacks. 23 | 24 | ### Writing Documentation 25 | 26 | Good documentation is crucial. If you notice any part of our documentation that could be improved or expanded, please let us know or submit a pull request with your suggested changes. 27 | 28 | ### Submitting Code Changes 29 | 30 | 1. Fork the repository. 31 | 2. Create a new branch for your changes. 32 | 3. Make your changes in your branch. 33 | > Please pay attention to the code style and mypy typing; both should match the conventions used in the repository. The authors suggest using pre-commit: 34 | > ```bash 35 | > pip3 install pre-commit 36 | > pre-commit install 37 | > ``` 38 | > The same formatting check runs as part of the CI tests, so if you skip it, the tests won't pass. 39 | 40 | 4. 
Add or update tests as necessary. 41 | 5. Ensure the test suite passes. 42 | 6. Update the documentation as needed. 43 | 7. Push your branch and submit a pull request. 44 | 45 | ## Pull Request Process 46 | 47 | 1. Ensure your code follows the project's style guidelines. 48 | 2. Update the README.md or relevant documentation with details of changes, if applicable. 49 | 3. Add tests for your changes and ensure all tests pass. 50 | 4. Your pull request will be reviewed by the maintainers. They may suggest changes or improvements. 51 | 5. Once approved, your pull request will be merged. 52 | 53 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2024 Ivan Domrachev 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | [![mypy](https://img.shields.io/github/actions/workflow/status/based-robotics/mjinx/mypy.yml?branch=main&label=mypy)](https://github.com/based-robotics/mjinx/actions) 3 | [![ruff](https://img.shields.io/github/actions/workflow/status/based-robotics/mjinx/ruff.yml?branch=main&label=ruff)](https://github.com/based-robotics/mjinx/actions) 4 | [![docs](https://img.shields.io/github/actions/workflow/status/based-robotics/mjinx/docs.yml?branch=main&label=docs)](https://based-robotics.github.io/mjinx/) 5 | [![PyPI version](https://img.shields.io/pypi/v/mjinx?color=blue)](https://pypi.org/project/mjinx/) 6 | [![PyPI downloads](https://img.shields.io/pypi/dm/mjinx?color=blue)](https://pypistats.org/packages/mjinx) 7 | [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/based-robotics/mjinx/blob/main/examples/notebooks/turoial.ipynb) 8 | 9 | 10 |

11 | 12 |

13 | 14 | 15 | **MJINX** is a Python library for auto-differentiable numerical inverse kinematics built on **JAX** and **MuJoCo MJX**. It draws inspiration from similar tools like the Pinocchio-based [PINK](https://github.com/stephane-caron/pink/tree/main) and MuJoCo-based [MINK](https://github.com/kevinzakka/mink/tree/main). 16 | 17 |

18 | 19 | 20 | 21 | 22 |

23 | 24 | ## Key Features 25 | 1. **Flexibility**. Problems are constructed using modular `Components` that enforce desired behaviors or maintain system safety constraints. 26 | 2. **Multiple Solution Strategies**. Leveraging JAX's efficient sampling and automatic differentiation capabilities, MJINX implements various solvers optimized for different robotics scenarios. 27 | 3. **Full JAX Compatibility**. Both the optimal control formulation and solvers are fully JAX-compatible, enabling JIT compilation and automatic vectorization across the entire pipeline. 28 | 4. **User-Friendly Interface**. The API is designed with a clean, intuitive interface that simplifies complex inverse kinematics tasks while maintaining advanced functionality. 29 | 30 | ## Installation 31 | The package is available on PyPI and can be installed via `pip`: 32 | ```bash 33 | pip install mjinx 34 | ``` 35 | 36 | Optional dependency groups (extras): 37 | 1. The visualization tool `mjinx.visualize.BatchVisualizer` is available in `mjinx[visual]` 38 | 2. To run examples, install `mjinx[examples]` 39 | 3. To install the development version, install `mjinx[dev]` (preferably in editable mode) 40 | 4. To build docs, install `mjinx[docs]` 41 | 5. To install the repository with all dependencies, install `mjinx[all]` 42 | 43 | Note that the default installation of `mjinx` installs `jax` without CUDA support. If you need it, please install `jax>=0.5.0` with CUDA support manually.
44 | 45 | ## Usage 46 | Here is an example of `mjinx` usage: 47 | 48 | ```python 49 | import jax 50 | import mujoco as mj 51 | import numpy as np 52 | from mujoco import mjx 53 | 54 | from mjinx.components.barriers import JointBarrier 55 | from mjinx.components.tasks import FrameTask 56 | from mjinx.configuration import integrate 57 | from mjinx.problem import Problem 58 | from mjinx.solvers import LocalIKSolver 59 | 60 | # Initialize the robot model using MuJoCo 61 | MJCF_PATH = "path_to_mjcf.xml" 62 | mj_model = mj.MjModel.from_xml_path(MJCF_PATH) 63 | mjx_model = mjx.put_model(mj_model) 64 | 65 | # Create instance of the problem 66 | problem = Problem(mjx_model) 67 | 68 | # Add tasks to track desired behavior 69 | frame_task = FrameTask("ee_task", cost=1, gain=20, obj_name="link7") 70 | problem.add_component(frame_task) 71 | 72 | # Add barriers to keep robot in a safety set 73 | joints_barrier = JointBarrier("jnt_range", gain=10) 74 | problem.add_component(joints_barrier) 75 | 76 | # Initialize the solver and its data 77 | solver = LocalIKSolver(mjx_model) 78 | solver_data = solver.init() 79 | 80 | # Initial configuration (assumes a 7-DoF robot) 81 | q = np.zeros(7) 82 | 83 | # jit-compiling solve and integrate 84 | solve_jit = jax.jit(solver.solve) 85 | integrate_jit = jax.jit(integrate, static_argnames=["dt"]) 86 | 87 | # === Control loop === 88 | dt = 1e-2 89 | for t in np.arange(0, 5, dt): 90 | # Update the target (position + quaternion) and recompile the problem 91 | frame_task.target_frame = np.array([0.1 * np.sin(t), 0.1 * np.cos(t), 0.1, 1, 0, 0, 0]) 92 | problem_data = problem.compile() 93 | 94 | # Solve the problem instance, then integrate the optimal velocity 95 | opt_solution, solver_data = solve_jit(q, solver_data, problem_data) 96 | q = integrate_jit(mjx_model, q, opt_solution.v_opt, dt=dt) 97 | ``` 98 | 99 | ## Examples 100 | The list of examples includes: 101 | 1. `Kuka iiwa` local inverse kinematics ([single item](examples/local_ik.py), [vmap over desired trajectory](examples/local_ik_vmapped_output.py)) 102 | 2. `Kuka iiwa` global inverse kinematics ([single item](examples/global_ik.py), [vmap over desired trajectory](examples/global_ik_vmapped_output.py)) 103 | 3.
`Go2` [batched squats](examples/go2_squat.py) example 104 | 105 | > **Note:** The Global IK functionality is still under development and does not yet work as expected. It needs further tuning, which will be addressed in future updates. Use the Global IK examples with caution and expect suboptimal results. 106 | 107 | 108 | ## Citing MJINX 109 | 110 | If you use MJINX in your research, please cite it as follows: 111 | 112 | ```bibtex 113 | @software{mjinx25, 114 | author = {Domrachev, Ivan and Nedelchev, Simeon}, 115 | license = {MIT}, 116 | month = mar, 117 | title = {{MJINX: Differentiable GPU-accelerated inverse kinematics in JAX}}, 118 | url = {https://github.com/based-robotics/mjinx}, 119 | version = {0.1.1}, 120 | year = {2025} 121 | } 122 | ``` 123 | 124 | ## Contributing 125 | We welcome suggestions and contributions. Please see our [CONTRIBUTING.md](CONTRIBUTING.md) file for guidelines. 126 | 127 | ## Acknowledgements 128 | I am deeply grateful to Simeon Nedelchev, whose guidance and expertise were instrumental in bringing this project to life. 129 | 130 | This work draws significant inspiration from [`pink`](https://github.com/stephane-caron/pink) by Stéphane Caron and [`mink`](https://github.com/kevinzakka/mink) by Kevin Zakka. Their pioneering work in robotics and open source has been a guiding light for this project. 131 | 132 | The codebase incorporates utility functions from [`MuJoCo MJX`](https://github.com/google-deepmind/mujoco/tree/main/mjx). Beyond being an excellent tool for batched computations and machine learning, MJX's codebase serves as a masterclass in clean, informative implementation of physical simulations and JAX usage. 133 | 134 | Special thanks to [IRIS lab](http://iris.kaist.ac.kr/) for their support.
135 | -------------------------------------------------------------------------------- /benchmarks/bench_global_ik_vmapped_output.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | from time import perf_counter 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import mujoco as mj 8 | import mujoco.mjx as mjx 9 | import numpy as np 10 | from optax import adam 11 | from robot_descriptions.iiwa14_mj_description import MJCF_PATH 12 | 13 | from mjinx.components.barriers import JointBarrier, PositionBarrier 14 | from mjinx.components.tasks import FrameTask 15 | from mjinx.configuration import integrate 16 | from mjinx.problem import Problem 17 | from mjinx.solvers import GlobalIKSolver 18 | 19 | # === Mujoco === 20 | 21 | mj_model = mj.MjModel.from_xml_path(MJCF_PATH) 22 | mj_data = mj.MjData(mj_model) 23 | 24 | mjx_model = mjx.put_model(mj_model) 25 | 26 | q_min = mj_model.jnt_range[:, 0].copy() 27 | q_max = mj_model.jnt_range[:, 1].copy() 28 | 29 | # === Mjinx === 30 | 31 | # --- Constructing the problem --- 32 | problem = Problem(mjx_model) 33 | 34 | # Creating components of interest and adding them to the problem 35 | frame_task = FrameTask("ee_task", cost=1, gain=20, obj_name="link7") 36 | position_barrier = PositionBarrier( 37 | "ee_barrier", 38 | gain=0.1, 39 | obj_name="link7", 40 | limit_type="max", 41 | p_max=0.4, 42 | safe_displacement_gain=1e-2, 43 | mask=[1, 0, 0], 44 | ) 45 | joints_barrier = JointBarrier("jnt_range", gain=0.1) 46 | 47 | problem.add_component(frame_task) 48 | problem.add_component(position_barrier) 49 | problem.add_component(joints_barrier) 50 | 51 | # Compiling the problem upon any parameters update 52 | problem_data = problem.compile() 53 | 54 | # Initializing solver and its initial state 55 | solver = GlobalIKSolver(mjx_model, adam(learning_rate=1e-2), dt=1e-2) 56 | 57 | # Initializing initial condition 58 | N_batch = 10000 59 | np.random.seed(42) 60 | 
q0 = jnp.array( 61 | [ 62 | -1.4238753, 63 | -1.7268502, 64 | -0.84355015, 65 | 2.0962472, 66 | 2.1339328, 67 | 2.0837479, 68 | -2.5521986, 69 | ] 70 | ) 71 | q = jnp.array( 72 | [ 73 | np.clip( 74 | q0 75 | + np.random.uniform( 76 | -0.1, 77 | 0.1, 78 | size=(mj_model.nq), 79 | ), 80 | q_min + 1e-1, 81 | q_max - 1e-1, 82 | ) 83 | for _ in range(N_batch) 84 | ] 85 | ) 86 | 87 | # --- Batching --- 88 | solver_data = jax.vmap(solver.init, in_axes=0)(q) 89 | with problem.set_vmap_dimension() as empty_problem_data: 90 | empty_problem_data.components["ee_task"].target_frame = 0 91 | 92 | solve_jit = jax.jit(jax.vmap(solver.solve, in_axes=(0, 0, empty_problem_data))) 93 | integrate_jit = jax.jit(jax.vmap(integrate, in_axes=(None, 0, 0, None)), static_argnames=["dt"]) 94 | 95 | # === Control loop === 96 | dt = 1e-2 97 | ts = np.arange(0, 10, dt) 98 | 99 | metrics: dict[str, list] = { 100 | "update_time_ms": [], 101 | "joint_positions": [], 102 | "joint_velocities": [], 103 | "target_frames": [], 104 | "optimization_steps": [], # Track number of optimization steps 105 | } 106 | 107 | try: 108 | for t in ts: 109 | # Changing desired values 110 | target_frames = np.array( 111 | [ 112 | [ 113 | 0.4 + 0.3 * np.sin(t + 2 * np.pi * i / N_batch), 114 | 0.2, 115 | 0.4 + 0.3 * np.cos(t + 2 * np.pi * i / N_batch), 116 | 1, 117 | 0, 118 | 0, 119 | 0, 120 | ] 121 | for i in range(N_batch) 122 | ] 123 | ) 124 | frame_task.target_frame = target_frames 125 | problem_data = problem.compile() 126 | 127 | # Solving the instance of the problem 128 | t0 = perf_counter() 129 | opt_solution, solver_data = solve_jit(q, solver_data, problem_data) 130 | t1 = perf_counter() 131 | 132 | # Update positions directly for global IK 133 | q = opt_solution.q_opt 134 | 135 | # Logging 136 | t_update_ms = (t1 - t0) * 1e3 137 | 138 | # Log metrics 139 | metrics["update_time_ms"].append(t_update_ms) 140 | metrics["joint_positions"].append(np.array(q)) 141 | 
metrics["joint_velocities"].append(opt_solution.v_opt) 142 | metrics["target_frames"].append(target_frames) 143 | metrics["optimization_steps"].append(3) # Fixed number of steps used 144 | 145 | print(f"t={t:.2f}, Update time: {t_update_ms:.2f}ms") 146 | 147 | except KeyboardInterrupt: 148 | print("Finalizing the simulation as requested...") 149 | except Exception as e: 150 | print(e) 151 | 152 | # Calculate and print summary statistics 153 | compilation_time = metrics["update_time_ms"][0] 154 | mean_update_time = np.mean(metrics["update_time_ms"][1:]) 155 | std_update_time = np.std(metrics["update_time_ms"][1:]) 156 | print("Benchmark Summary:") 157 | print(f"Mean update time: {mean_update_time:.2f} ± {std_update_time:.2f} ms") 158 | 159 | # Save data 160 | save_data = { 161 | "timestamps": ts, 162 | "compilation_time": compilation_time, 163 | "n_batch": N_batch, 164 | "mean_update_time": mean_update_time, 165 | "std_update_time": std_update_time, 166 | "learning_rate": 1e-2, # Save optimizer parameters 167 | } 168 | 169 | # Add component values and metrics to save data 170 | for metric_name, values in metrics.items(): 171 | save_data[metric_name] = np.array(values) 172 | 173 | # Save to npz file 174 | timestamp = int(perf_counter()) 175 | 176 | # Save to npz file 177 | timestamp = datetime.now().strftime("%Y_%d_%m_%H_%M_%S") 178 | script_dir = os.path.abspath(os.path.dirname(__file__)) 179 | 180 | filename = f"{script_dir}/logs/local_ik_benchmark_{timestamp}.npz" 181 | np.savez_compressed(filename, **save_data) 182 | print(f"\nBenchmark data saved to: {filename}") 183 | -------------------------------------------------------------------------------- /benchmarks/bench_local_ik.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | from time import perf_counter 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import mujoco as mj 8 | import mujoco.mjx as mjx 9 | import numpy as np 10 
| from robot_descriptions.iiwa14_mj_description import MJCF_PATH 11 | 12 | from mjinx.components.barriers import JointBarrier, PositionBarrier, SelfCollisionBarrier 13 | from mjinx.components.tasks import FrameTask 14 | from mjinx.configuration import integrate 15 | from mjinx.problem import Problem 16 | from mjinx.solvers import LocalIKSolver 17 | 18 | # === Mujoco === 19 | 20 | mj_model = mj.MjModel.from_xml_path(MJCF_PATH) 21 | mj_data = mj.MjData(mj_model) 22 | mjx_model = mjx.put_model(mj_model) 23 | 24 | # === Mjinx === 25 | problem = Problem(mjx_model, v_min=-100, v_max=100) 26 | 27 | frame_task = FrameTask("ee_task", cost=1, gain=20, obj_name="link7") 28 | position_barrier = PositionBarrier( 29 | "ee_barrier", 30 | gain=100, 31 | obj_name="link7", 32 | limit_type="max", 33 | p_max=0.3, 34 | safe_displacement_gain=1e-2, 35 | mask=[1, 0, 0], 36 | ) 37 | joints_barrier = JointBarrier("jnt_range", gain=10) 38 | self_collision_barrier = SelfCollisionBarrier( 39 | "self_collision_barrier", 40 | gain=1.0, 41 | d_min=0.01, 42 | ) 43 | 44 | problem.add_component(frame_task) 45 | problem.add_component(position_barrier) 46 | problem.add_component(joints_barrier) 47 | problem.add_component(self_collision_barrier) 48 | 49 | problem_data = problem.compile() 50 | 51 | solver = LocalIKSolver(mjx_model, maxiter=20) 52 | 53 | q = jnp.array( 54 | [ 55 | -1.4238753, 56 | -1.7268502, 57 | -0.84355015, 58 | 2.0962472, 59 | 2.1339328, 60 | 2.0837479, 61 | -2.5521986, 62 | ] 63 | ) 64 | solver_data = solver.init() 65 | 66 | solve_jit = jax.jit(solver.solve) 67 | integrate_jit = jax.jit(integrate, static_argnames=["dt"]) 68 | 69 | 70 | # === Control loop === 71 | dt = 1e-2 72 | ts = np.arange(0, 10, dt) 73 | 74 | # Additional metrics for comprehensive benchmarking 75 | metrics: dict[str, list] = { 76 | "update_time_ms": [], 77 | "joint_positions": [], 78 | "joint_velocities": [], 79 | "target_frames": [], 80 | "solver_iterations": [], 81 | } 82 | 83 | for t in ts: 84 | # Changing 
desired values 85 | target_frame = np.array([0.2 + 0.2 * jnp.sin(t) ** 2, 0.2, 0.2, 1, 0, 0, 0]) 86 | frame_task.target_frame = target_frame 87 | # After changes, recompiling the model 88 | problem_data = problem.compile() 89 | 90 | # Solving the instance of the problem 91 | t0 = perf_counter() 92 | opt_solution, solver_data = solve_jit(q, solver_data, problem_data) 93 | t1 = perf_counter() 94 | 95 | # Integrating 96 | q = integrate_jit( 97 | mjx_model, 98 | q, 99 | velocity=opt_solution.v_opt, 100 | dt=dt, 101 | ) 102 | 103 | # Logging 104 | t_update_ms = (t1 - t0) * 1e3 105 | 106 | # Log additional metrics 107 | metrics["update_time_ms"].append(t_update_ms) 108 | metrics["joint_positions"].append(np.array(q.copy())) 109 | metrics["joint_velocities"].append(np.array(opt_solution.v_opt.copy())) 110 | metrics["target_frames"].append(target_frame) 111 | metrics["solver_iterations"].append(opt_solution.iterations) 112 | 113 | print(f"t={t:.2f}, Update time: {t_update_ms:.2f}ms, Iterations: {opt_solution.iterations}") 114 | 115 | # Calculate and print summary statistics 116 | compilation_time = metrics["update_time_ms"][0] 117 | mean_update_time = np.mean(metrics["update_time_ms"][1:]) 118 | std_update_time = np.std(metrics["update_time_ms"][1:]) 119 | mean_iterations = np.mean(metrics["solver_iterations"]) 120 | print() 121 | print("Benchmark Summary:") 122 | print(f"Mean update time: {mean_update_time:.2f} ± {std_update_time:.2f} ms") 123 | print(f"Mean solver iterations: {mean_iterations:.2f}") 124 | 125 | # Convert lists to numpy arrays for saving 126 | save_data = { 127 | "timestamps": ts, 128 | "compilation_time": compilation_time, 129 | "mean_update_time": mean_update_time, 130 | "std_update_time": std_update_time, 131 | "mean_iterations": mean_iterations, 132 | } 133 | 134 | 135 | # Add metrics to save data 136 | for metric_name, values in metrics.items(): 137 | save_data[metric_name] = np.array(values) 138 | 139 | # Save to npz file 140 | timestamp = 
datetime.now().strftime("%Y_%d_%m_%H_%M_%S") 141 | script_dir = os.path.abspath(os.path.dirname(__file__)) 142 | 143 | filename = f"{script_dir}/logs/local_ik_benchmark_{timestamp}.npz" 144 | np.savez_compressed(filename, **save_data) 145 | print(f"\nBenchmark data saved to: {filename}") 146 | -------------------------------------------------------------------------------- /benchmarks/bench_local_ik_vmapped_output.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | from time import perf_counter 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import mujoco as mj 8 | import mujoco.mjx as mjx 9 | import numpy as np 10 | from robot_descriptions.iiwa14_mj_description import MJCF_PATH 11 | 12 | from mjinx.components.barriers import JointBarrier, PositionBarrier 13 | from mjinx.components.tasks import FrameTask 14 | from mjinx.configuration import integrate 15 | from mjinx.problem import Problem 16 | from mjinx.solvers import LocalIKSolver 17 | 18 | # === Mujoco === 19 | 20 | mj_model = mj.MjModel.from_xml_path(MJCF_PATH) 21 | mj_data = mj.MjData(mj_model) 22 | 23 | mjx_model = mjx.put_model(mj_model) 24 | 25 | q_min = mj_model.jnt_range[:, 0].copy() 26 | q_max = mj_model.jnt_range[:, 1].copy() 27 | 28 | # === Mjinx === 29 | 30 | # --- Constructing the problem --- 31 | # Creating problem formulation 32 | problem = Problem(mjx_model, v_min=-5, v_max=5) 33 | 34 | # Creating components of interest and adding them to the problem 35 | frame_task = FrameTask("ee_task", cost=1, gain=20, obj_name="link7") 36 | position_barrier = PositionBarrier( 37 | "ee_barrier", 38 | gain=100, 39 | obj_name="link7", 40 | limit_type="max", 41 | p_max=0.4, 42 | safe_displacement_gain=1e-2, 43 | mask=[1, 0, 0], 44 | ) 45 | joints_barrier = JointBarrier("jnt_range", gain=10) 46 | 47 | problem.add_component(frame_task) 48 | problem.add_component(position_barrier) 49 | problem.add_component(joints_barrier) 50 | 51 
| # Compiling the problem upon any parameters update 52 | problem_data = problem.compile() 53 | 54 | # Initializing solver and its initial state 55 | solver = LocalIKSolver(mjx_model, maxiter=20) 56 | 57 | # Initializing initial condition 58 | N_batch = 10000 59 | q0 = np.array( 60 | [ 61 | -1.5878328, 62 | -2.0968683, 63 | -1.4339591, 64 | 1.6550868, 65 | 2.1080072, 66 | 1.646142, 67 | -2.982619, 68 | ] 69 | ) 70 | q = jnp.array( 71 | [ 72 | np.clip( 73 | q0 74 | + np.random.uniform( 75 | -0.1, 76 | 0.1, 77 | size=(mj_model.nq), 78 | ), 79 | q_min + 1e-1, 80 | q_max - 1e-1, 81 | ) 82 | for _ in range(N_batch) 83 | ] 84 | ) 85 | 86 | # --- Batching --- 87 | solver_data = jax.vmap(solver.init, in_axes=0)(v_init=jnp.zeros((N_batch, mjx_model.nv))) 88 | with problem.set_vmap_dimension() as empty_problem_data: 89 | empty_problem_data.components["ee_task"].target_frame = 0 90 | solve_jit = jax.jit( 91 | jax.vmap( 92 | solver.solve, 93 | in_axes=(0, 0, empty_problem_data), 94 | ) 95 | ) 96 | integrate_jit = jax.jit(jax.vmap(integrate, in_axes=(None, 0, 0, None)), static_argnames=["dt"]) 97 | 98 | 99 | # === Control loop === 100 | dt = 1e-2 101 | ts = np.arange(0, 10, dt) 102 | 103 | metrics: dict[str, list] = { 104 | "update_time_ms": [], 105 | "joint_positions": [], 106 | "joint_velocities": [], 107 | "target_frames": [], 108 | "solver_iterations": [], 109 | } 110 | 111 | try: 112 | for t in ts: 113 | # Changing desired values 114 | target_frames = np.array( 115 | [ 116 | [ 117 | 0.4 + 0.3 * np.sin(t + 2 * np.pi * i / N_batch), 118 | 0.2, 119 | 0.4 + 0.3 * np.cos(t + 2 * np.pi * i / N_batch), 120 | 1, 121 | 0, 122 | 0, 123 | 0, 124 | ] 125 | for i in range(N_batch) 126 | ] 127 | ) 128 | frame_task.target_frame = target_frames 129 | problem_data = problem.compile() 130 | 131 | # Solving the instance of the problem 132 | t0 = perf_counter() 133 | opt_solution, solver_data = solve_jit(q, solver_data, problem_data) 134 | t1 = perf_counter() 135 | 136 | # Integrating 137 | q 
= integrate_jit( 138 | mjx_model, 139 | q, 140 | opt_solution.v_opt, 141 | dt, 142 | ) 143 | 144 | # Logging 145 | t_update_ms = (t1 - t0) * 1e3 146 | 147 | # Log metrics 148 | metrics["update_time_ms"].append(t_update_ms) 149 | metrics["joint_positions"].append(np.array(q)) 150 | metrics["joint_velocities"].append(np.array(opt_solution.v_opt)) 151 | metrics["target_frames"].append(np.array(target_frames).copy()) 152 | metrics["solver_iterations"].append(np.array(opt_solution.iterations)) 153 | 154 | print(f"t={t:.2f}, Update time: {t_update_ms:.2f}ms, Mean iterations: {np.mean(opt_solution.iterations)}") 155 | 156 | except KeyboardInterrupt: 157 | print("Finalizing the simulation as requested...") 158 | except Exception as e: 159 | print(e) 160 | 161 | # Calculate and print summary statistics 162 | compilation_time = metrics["update_time_ms"][0] 163 | mean_update_time = np.mean(metrics["update_time_ms"][1:]) 164 | std_update_time = np.std(metrics["update_time_ms"][1:]) 165 | mean_iterations = np.mean(metrics["solver_iterations"]) 166 | 167 | print() 168 | print("Benchmark Summary:") 169 | print(f"Mean update time: {mean_update_time:.2f} ± {std_update_time:.2f} ms") 170 | print(f"Mean solver iterations: {mean_iterations:.2f}") 171 | 172 | # Save data 173 | save_data = { 174 | "timestamps": ts, 175 | "n_batch": N_batch, 176 | "compilation_time": compilation_time, 177 | "mean_update_time": mean_update_time, 178 | "std_update_time": std_update_time, 179 | "mean_iterations": mean_iterations, 180 | } 181 | 182 | # Add component values and metrics to save data 183 | for metric_name, values in metrics.items(): 184 | save_data[metric_name] = np.array(values) 185 | 186 | # Save to npz file 187 | timestamp = datetime.now().strftime("%Y_%d_%m_%H_%M_%S") 188 | script_dir = os.path.abspath(os.path.dirname(__file__)) 189 | 190 | filename = f"{script_dir}/logs/local_ik_vmapped_benchmark_{timestamp}.npz" 191 | np.savez_compressed(filename, **save_data) 192 | print(f"\nBenchmark 
data saved to: {filename}") 193 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # `mjinx` documentation 2 | 3 | The documentation is built using [Sphinx](https://www.sphinx-doc.org/en/master/), the [Read the Docs](https://docs.readthedocs.io/en/stable/) theme, and Python. 4 | 5 | The website is available at url.com. 6 | 7 | ## Building locally 8 | To build and test the website locally, do the following: 9 | ```bash 10 | pip install ".[docs]" 11 | rm -r _build && sphinx-build -M html docs _build 12 | ``` 13 | 14 | Then open the `file:///path/to/mjinx/_build/html/index.html` file in the browser (`sphinx-build -M html` places the HTML output in `_build/html/`). -------------------------------------------------------------------------------- /docs/barriers.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/based-robotics/mjinx/tree/docs/github_pages/docs/barriers.rst 2 | 3 | .. _Barriers: 4 | 5 | 6 | Barriers implement inequality constraints in the inverse kinematics problem. Each barrier defines a scalar function :math:`h: \mathcal{Q} \rightarrow \mathbb{R}` that must remain non-negative (:math:`h(q) \geq 0`), creating a boundary that the solver must respect. This approach is mathematically equivalent to control barrier functions (CBFs) :cite:`ames2019control`, which have been widely adopted in robotics for safety-critical control. 7 | 8 | The barrier formulation creates a continuous constraint boundary that prevents the system from entering prohibited regions of the configuration space. In the optimization problem, these constraints are typically enforced through linearization at each step: 9 | 10 | .. math:: 11 | 12 | \nabla h(q)^T \dot{q} \geq -\alpha h(q) 13 | 14 | where :math:`\alpha` is a gain parameter that controls how aggressively the system is pushed away from the constraint boundary.
15 | 16 | 17 | All barriers follow a consistent mathematical formulation while adapting to specific constraint types, enabling systematic enforcement of safety and feasibility requirements. 18 | 19 | Base Barrier 20 | ------------ 21 | The foundation for all barrier constraints, defining the core mathematical properties and interface. 22 | 23 | .. automodule:: mjinx.components.barriers._base 24 | :members: 25 | :member-order: bysource 26 | 27 | Joint Barrier 28 | ------------- 29 | Enforces joint limit constraints, preventing the robot from exceeding mechanical limits. 30 | 31 | .. automodule:: mjinx.components.barriers._joint_barrier 32 | :members: 33 | :member-order: bysource 34 | 35 | Base Body Barrier 36 | ----------------- 37 | The foundation for barriers applied to specific bodies, geometries, or sites in the robot model. This abstract class provides common functionality for all object-specific barriers. 38 | 39 | .. automodule:: mjinx.components.barriers._obj_barrier 40 | :members: 41 | :member-order: bysource 42 | 43 | Body Position Barrier 44 | --------------------- 45 | Enforces position constraints on specific bodies, geometries, or sites, useful for defining workspace limits. 46 | 47 | .. automodule:: mjinx.components.barriers._obj_position_barrier 48 | :members: 49 | :member-order: bysource 50 | 51 | Self Collision Barrier 52 | ---------------------- 53 | Prevents different parts of the robot from colliding with each other, essential for complex manipulators. 54 | 55 | .. 
automodule:: mjinx.components.barriers._self_collision_barrier 56 | :members: 57 | :member-order: bysource -------------------------------------------------------------------------------- /docs/components.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/based-robotics/mjinx/tree/main/docs/components.rst 2 | 3 | ========== 4 | Components 5 | ========== 6 | 7 | MJINX employs a component-based architecture to formulate inverse kinematics problems through functional decomposition. Each component encapsulates a mathematical mapping from the configuration space to a task or constraint space, functioning either as an objective function (task) or an inequality constraint (barrier). 8 | 9 | This modular structure facilitates the systematic construction of complex kinematic problems through composition of elementary components. The approach aligns with established practices in robotics control theory, where complex behaviors emerge from the coordination of simpler control objectives. 10 | 11 | Components are characterized by several key attributes: 12 | 13 | - A unique identifier for reference within the problem formulation 14 | - A gain parameter that determines its relative weight in the optimization 15 | - An optional dimensional mask for selective application 16 | - A differentiable function mapping robot state to output values 17 | 18 | This formulation follows the task-priority framework established in robotics literature, where multiple objectives are managed through appropriate weighting and prioritization. The separation of concerns between tasks and constraints provides a natural expression of both the desired behavior and the feasible region of operation. 19 | 20 | When integrated into a Problem instance, components form a well-posed optimization problem. Tasks define the objective function to be minimized, while barriers establish the constraint manifold. 
The solver then computes solutions that optimize the weighted task objectives while maintaining feasibility with respect to all constraints. 21 | 22 | ************** 23 | Base Component 24 | ************** 25 | 26 | The Component class serves as the abstract base class from which all specific component implementations derive. This inheritance hierarchy ensures a consistent interface while enabling specialized behavior for different component types. 27 | 28 | .. automodule:: mjinx.components._base 29 | :members: 30 | :special-members: __call__ 31 | :member-order: bysource 32 | 33 | ******** 34 | Tasks 35 | ******** 36 | 37 | Tasks define objective functions that map from configuration space to task space, with the solver minimizing weighted errors between current and desired values :cite:`kanoun2011kinematic`. Each task :math:`f: \mathcal{Q} \rightarrow \mathbb{R}^m` produces an error :math:`e(q) = f(q) - f_{desired}` that is minimized according to :math:`\|e(q)\|^2_W`. 38 | 39 | MJINX provides task implementations for common robotics objectives: 40 | 41 | .. toctree:: 42 | :maxdepth: 1 43 | 44 | tasks 45 | 46 | ******** 47 | Barriers 48 | ******** 49 | Barriers implement inequality constraints through scalar functions :math:`h(q) \geq 0` that create boundaries the solver must respect. Based on control barrier functions (CBFs) :cite:`ames2019control`, these constraints are enforced through the differential inequality :math:`\nabla h(q)^T v \geq -\alpha h(q)`, where :math:`\alpha` controls constraint enforcement and :math:`v` is the velocity vector. 50 | 51 | MJINX provides several barrier implementations: 52 | 53 | .. toctree:: 54 | :maxdepth: 1 55 | 56 | barriers 57 | 58 | *********** 59 | Constraints 60 | *********** 61 | 62 | Constraints represent a type of component that enforces an equality condition :math:`f(q) = 0`.
Such constraints can be enforced either strictly, as a differentiated exponentially stable equality :math:`\nabla f(q)^T v = -\alpha f(q)`, where :math:`\alpha` controls constraint enforcement and :math:`v` is the velocity vector, or softly, as a task with a high gain. 63 | 64 | Currently, the following constraints are implemented: 65 | 66 | .. toctree:: 67 | :maxdepth: 1 68 | 69 | constraints -------------------------------------------------------------------------------- /docs/configuration.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/based-robotics/mjinx/tree/docs/github_pages/docs/configuration.rst 2 | 3 | ============= 4 | Configuration 5 | ============= 6 | 7 | The configuration module serves as the mathematical foundation of MJINX, providing essential utilities for robot kinematics, transformations, and state management. These core functions enable precise control and manipulation of robotic systems throughout the library. 8 | 9 | The module is structured into three complementary categories, each addressing a critical aspect of robot configuration: 10 | 11 | 1. **Model** - Functions for manipulating the MuJoCo model and managing robot state 12 | 2. **Lie Algebra** - Specialized tools for handling rotations and transformations with mathematical rigor 13 | 3. **Collision** - Algorithms for detecting and responding to potential collisions 14 | 15 | ****** 16 | Model 17 | ****** 18 | 19 | The model operations provide fundamental capabilities for state integration, Jacobian computation, and frame transformations. These functions form the bridge between abstract mathematical representations and the physical robot model. 20 | 21 | 22 | ..
automodule:: mjinx.configuration._model 23 | :members: 24 | :undoc-members: 25 | :show-inheritance: 26 | :member-order: bysource 27 | 28 | ************ 29 | Lie Algebra 30 | ************ 31 | 32 | The Lie algebra operations implement sophisticated mathematical tools for handling rotations, quaternions, and transformations in 3D space. Based on principles from differential geometry, these functions ensure proper handling of the SE(3) and SO(3) Lie groups. 33 | 34 | .. automodule:: mjinx.configuration._lie 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | :member-order: bysource 39 | 40 | ********** 41 | Collision 42 | ********** 43 | 44 | The collision operations provide sophisticated algorithms for detecting potential collisions, computing distances between objects, and analyzing contact points. These functions are crucial for implementing safety constraints and realistic physical interactions. 45 | 46 | .. automodule:: mjinx.configuration._collision 47 | :members: 48 | :undoc-members: 49 | :show-inheritance: 50 | :member-order: bysource -------------------------------------------------------------------------------- /docs/constraints.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/based-robotics/mjinx/tree/docs/github_pages/docs/constraints.rst 2 | 3 | .. _Constraints: 4 | 5 | Constraints represent a type of component that enforces an equality condition, either strictly (hard constraint) or approximately (soft constraint). 6 | 7 | The constraint is formulated as :math:`f(q) = 0`. Such constraints can be enforced either strictly, as a differentiated exponentially stable equality :math:`\nabla f(q)^T v = -\alpha f(q)`, where :math:`\alpha` controls constraint enforcement and :math:`v` is the velocity vector, or softly, as a task with a high gain. 8 | 9 | Base Constraint 10 | --------------- 11 | The foundational class that all constraints extend.
It defines the core interface and mathematical properties for constraint objectives. 12 | 13 | .. automodule:: mjinx.components.constraints._base 14 | :members: 15 | :member-order: bysource 16 | 17 | Model Equality Constraint 18 | ------------------------- 19 | Constraints defined as equality constraints in the MuJoCo model. 20 | 21 | .. automodule:: mjinx.components.constraints._equality_constraint 22 | :members: 23 | :member-order: bysource 24 | 25 | .. Base Task 26 | .. --------- 27 | .. The foundational class that all tasks extend. It defines the core interface and mathematical properties for task objectives. 28 | 29 | .. .. automodule:: mjinx.components.tasks._base 30 | .. :members: 31 | .. :member-order: bysource 32 | -------------------------------------------------------------------------------- /docs/css/custom.css: -------------------------------------------------------------------------------- 1 | /* Jinx color palette */ 2 | /* #03587f - dark blue */ 3 | /* #a75a97 - pink */ 4 | /* #a16f6f - brown */ 5 | /* #d9ccc9 - beige */ 6 | /* #aecadc - light blue */ 7 | a, a:visited, a:focus { 8 | color: rgb(3, 88, 127); /* dark blue */ 9 | text-decoration: none; 10 | } 11 | 12 | a:hover, a:active { 13 | color: rgb(3, 88, 127); /* dark blue */ 14 | text-decoration: none; 15 | } 16 | 17 | .wy-menu-vertical header, .wy-menu-vertical p.caption { 18 | color: white; 19 | font-weight: 600; 20 | } 21 | 22 | .wy-menu-vertical a { 23 | color: rgb(167, 90, 151); /* pink */ 24 | } 25 | 26 | .wy-menu-vertical a:hover { 27 | background: rgb(167, 90, 151); /* pink */ 28 | color: white; 29 | } 30 | 31 | .wy-menu-vertical header, .wy-menu-vertical p.caption { 32 | color: black; 33 | font-weight: 600; 34 | } 35 | 36 | .wy-menu-vertical a{ 37 | color: black; 38 | } 39 | 40 | .wy-nav-content-wrap, .wy-menu li.current > a { 41 | background-color: white; 42 | } 43 | 44 | .wy-nav-side { 45 | background: rgb(174, 202, 220); /* light blue */ 46 | } 47 | 48 | .wy-nav-top {
background: rgb(217, 204, 201); /* beige */ 50 | } 51 | 52 | .wy-nav-top a { 53 | color: black; 54 | } 55 | 56 | .wy-nav-top i { 57 | color: rgb(167, 90, 151); /* pink */ 58 | } 59 | 60 | .wy-side-nav-search { 61 | background: rgb(3, 88, 127); /* dark blue */ 62 | } 63 | 64 | .wy-side-nav-search > div.version { 65 | color: white; 66 | } 67 | 68 | .wy-side-nav-search > a { 69 | color: white; 70 | } 71 | 72 | pre { 73 | background: #efd0f3; /* beige */ 74 | border-left: 5px solid #c076b1; /* brown */ 75 | } 76 | 77 | html.writer-html4 .rst-content dl:not(.docutils) > dt, 78 | html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.citation):not(.glossary):not(.simple) > dt { 79 | background: rgb(217, 204, 201); /* beige */ 80 | border-top: 3px solid #a16f6f; /* brown */ 81 | color: rgb(3, 88, 127); /* dark blue */ 82 | } 83 | 84 | .rst-content .note { 85 | background: rgb(217, 204, 201); /* beige */ 86 | } 87 | 88 | .rst-content .note .admonition-title { 89 | background: rgb(3, 88, 127); /* dark blue */ 90 | } 91 | 92 | .rst-content code.literal { 93 | color: rgb(161, 111, 111); /* brown */ 94 | } 95 | -------------------------------------------------------------------------------- /docs/developer-notes.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/based-robotics/mjinx/tree/docs/github_pages/docs/developer-notes.rst 2 | 3 | ================= 4 | Developer Notes 5 | ================= 6 | 7 | This section contains information for developers who want to contribute to MJINX or understand its internals better. 
8 | 9 | ******************* 10 | Code Organization 11 | ******************* 12 | 13 | MJINX follows a modular architecture: 14 | 15 | - ``mjinx/components/`` - Task and barrier implementations 16 | - ``mjinx/configuration/`` - Kinematics and transformation utilities 17 | - ``mjinx/solvers/`` - Inverse kinematics solvers 18 | - ``mjinx/visualize.py`` - Visualization tools 19 | - ``mjinx/problem.py`` - Problem construction and management 20 | - ``mjinx/typing.py`` - Type definitions 21 | 22 | ************************** 23 | Development Guidelines 24 | ************************** 25 | 26 | When contributing to MJINX, please follow these guidelines: 27 | 28 | 1. **Type annotations**: Use type annotations throughout the code. 29 | 2. **Documentation**: Write clear docstrings with reStructuredText format. 30 | 3. **Testing**: Add tests for new features using pytest. 31 | 4. **JAX compatibility**: Ensure new code works with JAX transformations. 32 | 5. **Performance**: Consider computation efficiency, especially for operations in inner loops. 33 | 34 | ****************** 35 | JAX Considerations 36 | ****************** 37 | 38 | MJINX leverages JAX for automatic differentiation and acceleration. When working with JAX: 39 | 40 | - Use JAX's functional programming style 41 | - Avoid in-place mutations 42 | - Remember that JAX arrays are immutable 43 | - Use ``jit``, ``vmap``, and ``grad`` to leverage JAX's transformations 44 | - Test code with both CPU and GPU devices 45 | 46 | For more information, refer to the `JAX documentation `_. 47 | 48 | **************** 49 | Types 50 | **************** 51 | 52 | MJINX uses a comprehensive type system to ensure code correctness and improve developer experience. Understanding the type system is essential for contributing to the codebase and extending its functionality. 53 | 54 | The type system provides clear interfaces between components, helps catch errors at development time, and makes the code more maintainable. 
All new contributions should adhere to the established typing conventions. 55 | 56 | For detailed information about MJINX's type system, including type aliases and enumerations, see the :doc:`typing` documentation. 57 | 58 | .. toctree:: 59 | :maxdepth: 2 60 | 61 | typing 62 | -------------------------------------------------------------------------------- /docs/img/kuka_iiwa_14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/docs/img/kuka_iiwa_14.png -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/based-robotics/mjinx/tree/docs/github_pages/docs/index.rst 2 | 3 | .. title:: Table of Contents 4 | 5 | ##### 6 | MJINX 7 | ##### 8 | 9 | .. raw:: html 10 | 11 |
12 | 13 | |colab| |pypi_version| |pypi_downloads| 14 | 15 | .. |colab| image:: https://colab.research.google.com/assets/colab-badge.svg 16 | :target: https://colab.research.google.com/github/based-robotics/mjinx/blob/main/examples/notebooks/turoial.ipynb 17 | :alt: Open in Colab 18 | 19 | .. |pypi_version| image:: https://img.shields.io/pypi/v/mjinx?color=blue 20 | :target: https://pypi.org/project/mjinx/ 21 | :alt: PyPI version 22 | 23 | .. |pypi_downloads| image:: https://img.shields.io/pypi/dm/mjinx?color=blue 24 | :target: https://pypistats.org/packages/mjinx 25 | :alt: PyPI downloads 26 | 27 | **MJINX** is a high-performance library for differentiable inverse kinematics, powered by `JAX `_ 28 | and `MuJoCo MJX `_. The library was inspired by the Pinocchio-based tool `PINK `_ and Mujoco-based analogue `MINK `_. 29 | 30 | .. raw:: html 31 | 32 |
33 | KUKA arm example 34 | GO2 robot example 35 | Heart path example 36 | Heart path example 37 |
38 | 39 | ************* 40 | Key Features 41 | ************* 42 | 43 | 1. **Flexibility**: Each control problem is assembled via ``Components``, which enforce desired behavior or keep the system within a safety set. 44 | 2. **Multiple solution approaches**: JAX's efficient sampling and autodifferentiation enable various solvers optimized for different scenarios. 45 | 3. **Fully JAX-compatible**: Both the optimization problem and solver support JAX transformations, including JIT compilation and automatic vectorization. 46 | 4. **Convenience**: The API is designed to make complex inverse kinematics problems easy to express and solve. 47 | 48 | ************* 49 | Citing MJINX 50 | ************* 51 | 52 | If you use MJINX in your research, please cite it as follows: 53 | 54 | .. code-block:: bibtex 55 | 56 | @software{mjinx25, 57 | author = {Domrachev, Ivan and Nedelchev, Simeon}, 58 | license = {MIT}, 59 | month = mar, 60 | title = {{MJINX: Differentiable GPU-accelerated inverse kinematics in JAX}}, 61 | url = {https://github.com/based-robotics/mjinx}, 62 | version = {0.1.1}, 63 | year = {2025} 64 | } 65 | 66 | 67 | .. toctree:: 68 | :maxdepth: 2 69 | :caption: Contents: 70 | 71 | installation.rst 72 | quick_start.rst 73 | problem.rst 74 | configuration.rst 75 | components.rst 76 | solvers.rst 77 | visualization.rst 78 | references.rst 79 | developer-notes.rst 80 | 81 | 82 | -------------------------------------------------------------------------------- /docs/installation.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/based-robotics/mjinx/tree/docs/github_pages/docs/installation.rst 2 | 3 | ============ 4 | Installation 5 | ============ 6 | 7 | ********** 8 | From PyPI 9 | ********** 10 | 11 | The simplest way to install MJINX is via the Python Package Index: 12 | 13 | .. 
code:: bash 14 | 15 | pip install mjinx 16 | 17 | ******************** 18 | Installation Options 19 | ******************** 20 | 21 | MJINX offers several installation options with different dependencies: 22 | 23 | 1. **Visualization tools**: ``pip install mjinx[visual]`` 24 | 2. **Example dependencies**: ``pip install mjinx[examples]`` 25 | 3. **Development dependencies**: ``pip install mjinx[dev]`` (preferably in editable mode) 26 | 4. **Documentation dependencies**: ``pip install mjinx[docs]`` 27 | 5. **Complete installation**: ``pip install mjinx[all]`` 28 | 29 | ***************************** 30 | From Source (Developer Mode) 31 | ***************************** 32 | 33 | For developers or to access the latest features, you can clone the repository and install in editable mode: 34 | 35 | .. code:: bash 36 | 37 | git clone https://github.com/based-robotics/mjinx.git 38 | cd mjinx 39 | pip install -e . 40 | 41 | With editable mode, any changes you make to the source code take effect immediately without requiring reinstallation. 42 | 43 | For development work, we recommend installing with the development extras: 44 | 45 | .. code:: bash 46 | 47 | pip install -e ".[dev]" -------------------------------------------------------------------------------- /docs/problem.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/based-robotics/mjinx/tree/docs/github_pages/docs/problem.rst 2 | 3 | .. _Problem: 4 | 5 | ======= 6 | Problem 7 | ======= 8 | 9 | At the heart of MJINX lies the Problem module - a structured framework that elegantly handles inverse kinematics challenges. This module serves as the central hub where various components come together to form a cohesive mathematical formulation. 
10 | 11 | When working with MJINX, a Problem instance orchestrates several key elements: 12 | 13 | - **Tasks**: Objective functions that define desired behaviors, such as reaching specific poses or following trajectories 14 | - **Barriers**: Smooth constraint functions that naturally keep the system within valid states 15 | - **Velocity limits**: Physical bounds on joint velocities to ensure feasible motion 16 | 17 | The module's modular architecture shines in its flexibility - users can begin with simple scenarios like positioning a single end-effector, then naturally build up to complex whole-body motions with multiple objectives and constraints. 18 | 19 | Under the hood, the Problem class transforms these high-level specifications into optimized computations through its JaxProblemData representation. By leveraging JAX's JIT compilation, it ensures that even sophisticated inverse kinematics problems run with maximum efficiency. 20 | 21 | .. automodule:: mjinx.problem 22 | :members: 23 | :member-order: bysource 24 | 25 | -------------------------------------------------------------------------------- /docs/references.bib: -------------------------------------------------------------------------------- 1 | 2 | @article{kanoun2011kinematic, 3 | title={Kinematic Control of Redundant Manipulators: Generalizing the Task-Priority Framework to Inequality Task}, 4 | author={Kanoun, Oussama and Lamiraux, Florent and Wieber, Pierre-Brice}, 5 | journal={IEEE Transactions on Robotics}, 6 | volume={27}, 7 | number={4}, 8 | pages={785--792}, 9 | year={2011}, 10 | publisher={IEEE} 11 | } 12 | 13 | @article{delprete2018joint, 14 | title={Joint Position and Velocity Bounds in Discrete-Time Acceleration/Torque Control of Robot Manipulators}, 15 | author={Del Prete, Andrea}, 16 | journal={IEEE Robotics and Automation Letters}, 17 | volume={3}, 18 | number={1}, 19 | pages={281--288}, 20 | year={2018}, 21 | publisher={IEEE} 22 | } 23 | 24 | @article{jackson2021planning, 25 | 
title={Planning with Attitude}, 26 | author={Jackson, Brian E and Tracy, Kevin and Manchester, Zachary}, 27 | journal={IEEE Robotics and Automation Letters}, 28 | volume={6}, 29 | number={3}, 30 | pages={5658--5664}, 31 | year={2021}, 32 | publisher={IEEE} 33 | } 34 | 35 | @inproceedings{ames2019control, 36 | title={Control barrier functions: Theory and applications}, 37 | author={Ames, Aaron D and Coogan, Samuel and Egerstedt, Magnus and Notomista, Gennaro and Sreenath, Koushil and Tabuada, Paulo}, 38 | booktitle={2019 18th European Control Conference (ECC)}, 39 | pages={3420--3431}, 40 | year={2019}, 41 | organization={IEEE} 42 | } 43 | 44 | @inproceedings{brunke2024practical, 45 | title={Practical Considerations for Discrete-Time Implementations of Continuous-Time Control Barrier Function-Based Safety Filters}, 46 | author={Brunke, Lukas and Zhou, Siqi and Che, Mingxuan and Schoellig, Angela P}, 47 | booktitle={2024 American Control Conference (ACC)}, 48 | pages={272--278}, 49 | year={2024}, 50 | organization={IEEE} 51 | } 52 | 53 | @misc{jax2018github, 54 | title={{JAX}: Composable Transformations of {Python+NumPy} Programs}, 55 | author={Bradbury, James and Frostig, Roy and Hawkins, Peter and Johnson, Matthew J and Leary, Chris and Maclaurin, Dougal and Necula, George and Paszke, Adam and VanderPlas, Jake and Wanderman-Milne, Skye and Zhang, Qiao}, 56 | year={2018}, 57 | howpublished={\url{https://github.com/google/jax}} 58 | } -------------------------------------------------------------------------------- /docs/references.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/based-robotics/mjinx/tree/docs/github_pages/docs/references.rst 2 | 3 | ********** 4 | References 5 | ********** 6 | 7 | ===================== 8 | Academic References 9 | ===================== 10 | 11 | This section lists key academic papers and resources that influenced MJINX's design and implementation: 12 | 13 | Citations 14 | --------- 15 | 16 | This project builds on the following
academic works: 17 | 18 | .. bibliography:: 19 | :style: plain 20 | :all: 21 | 22 | =================== 23 | Software References 24 | =================== 25 | 26 | MJINX builds upon and integrates with the following software libraries: 27 | 28 | - `JAX <https://github.com/jax-ml/jax>`_: Autograd and XLA for high-performance machine learning research. 29 | - `MuJoCo <https://github.com/google-deepmind/mujoco>`_: Physics engine for detailed, efficient robot simulation. 30 | - `PINK <https://github.com/stephane-caron/pink>`_: Python inverse kinematics based on Pinocchio. 31 | - `MINK <https://github.com/kevinzakka/mink>`_: Python inverse kinematics based on MuJoCo. 32 | - `Optax <https://github.com/google-deepmind/optax>`_: Gradient processing and optimization. 33 | - `JaxLie <https://github.com/brentyi/jaxlie>`_: JAX library for Lie groups. 34 | 35 | ================= 36 | Acknowledgements 37 | ================= 38 | 39 | MJINX would not exist without the contributions and inspiration from several sources: 40 | 41 | - Simeon Nedelchev for guidance and contributions during development 42 | - Stéphane Caron and Kevin Zakka, whose work on PINK and MINK respectively provided significant inspiration 43 | - The MuJoCo MJX team for their excellent physics simulation tools 44 | - The IRIS Lab at KAIST 45 | 46 | 47 | ============= 48 | Citing MJINX 49 | ============= 50 | 51 | If you use MJINX in your research, please cite it as follows: 52 | 53 | .. code-block:: bibtex 54 | 55 | @software{mjinx25, 56 | author = {Domrachev, Ivan and Nedelchev, Simeon}, 57 | license = {MIT}, 58 | month = mar, 59 | title = {{MJINX: Differentiable GPU-accelerated inverse kinematics in JAX}}, 60 | url = {https://github.com/based-robotics/mjinx}, 61 | version = {0.1.1}, 62 | year = {2025} 63 | } 64 | 65 | 66 | -------------------------------------------------------------------------------- /docs/solvers.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/based-robotics/mjinx/tree/docs/github_pages/docs/solvers.rst 2 | 3 | ..
_Solvers: 4 | 5 | ======== 6 | Solvers 7 | ======== 8 | 9 | MJINX provides multiple solver implementations for inverse kinematics problems, each with different characteristics suited to different applications. 10 | 11 | *********** 12 | Base Solver 13 | *********** 14 | 15 | The abstract base class defining the interface for all solvers. 16 | 17 | .. automodule:: mjinx.solvers._base 18 | :members: 19 | :undoc-members: 20 | :show-inheritance: 21 | :member-order: bysource 22 | 23 | 24 | *************** 25 | Local IK Solver 26 | *************** 27 | 28 | A Quadratic Programming (QP)-based solver that linearizes the problem at each step. It is efficient and well suited to real-time control and tracking applications. 29 | 30 | .. automodule:: mjinx.solvers._local_ik 31 | :members: 32 | :undoc-members: 33 | :show-inheritance: 34 | :member-order: bysource 35 | 36 | **************** 37 | Global IK Solver 38 | **************** 39 | 40 | A nonlinear optimization solver that directly optimizes joint positions instead of linearizing the problem at each step. It is suitable for complex positioning tasks, although, as with any iterative nonlinear optimizer, it can still converge to a local minimum depending on the initial configuration. 41 | 42 | .. automodule:: mjinx.solvers._global_ik 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | :member-order: bysource -------------------------------------------------------------------------------- /docs/tasks.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/based-robotics/mjinx/tree/docs/github_pages/docs/tasks.rst 2 | 3 | .. _Tasks: 4 | 5 | Tasks define the objective functions in an inverse kinematics problem. Each task represents a mapping from the configuration space to a task space, with the solver minimizing the weighted error between the current and desired task values. This approach follows the task-priority framework established in robotics literature :cite:`kanoun2011kinematic`, where multiple objectives are managed through appropriate weighting and prioritization.
6 | 7 | Mathematically, a task is defined as a function :math:`f: \mathcal{Q} \rightarrow \mathbb{R}^m` that maps the configuration space :math:`\mathcal{Q}` to the task space :math:`\mathbb{R}^m`. The error is computed as :math:`e(q) = f(q) - f_{\text{desired}}`, and the solver minimizes a weighted norm of this error, :math:`\|e(q)\|^2_W`, where :math:`W` is a positive-definite weight matrix. 8 | 9 | All tasks inherit from the base Task class and follow a consistent mathematical formulation, enabling systematic composition of complex behaviors from elementary tasks. 10 | 11 | Base Task 12 | --------- 13 | The foundational class that all tasks extend. It defines the core interface and mathematical properties for task objectives. 14 | 15 | .. automodule:: mjinx.components.tasks._base 16 | :members: 17 | :member-order: bysource 18 | 19 | Center of Mass Task 20 | ------------------- 21 | Controls the position of the robot's center of mass, which is critical for maintaining balance in legged and humanoid robots. 22 | 23 | .. automodule:: mjinx.components.tasks._com_task 24 | :members: 25 | :member-order: bysource 26 | 27 | Joint Task 28 | ---------- 29 | Directly controls joint positions, useful for regularization, posture optimization, and redundancy resolution. 30 | 31 | .. automodule:: mjinx.components.tasks._joint_task 32 | :members: 33 | :member-order: bysource 34 | 35 | Base Body Task 36 | -------------- 37 | The foundation for tasks that target specific bodies, geometries, or sites in the robot model. This abstract class provides common functionality for all object-specific tasks. 38 | 39 | .. automodule:: mjinx.components.tasks._obj_task 40 | :members: 41 | :member-order: bysource 42 | 43 | Body Frame Task 44 | --------------- 45 | Controls the complete pose (position and orientation) of a body, geometry, or site using SE(3) representations. 46 | 47 | ..
automodule:: mjinx.components.tasks._obj_frame_task 48 | :members: 49 | :member-order: bysource 50 | 51 | Body Position Task 52 | ------------------ 53 | Controls only the position of a body, geometry, or site, ignoring its orientation. Useful when only the position matters. 54 | 55 | .. automodule:: mjinx.components.tasks._obj_position_task 56 | :members: 57 | :member-order: bysource 58 | -------------------------------------------------------------------------------- /docs/typing.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/based-robotics/mjinx/tree/docs/github_pages/docs/typing.rst 2 | 3 | ************ 4 | Type System 5 | ************ 6 | 7 | MJINX uses Python's type annotations throughout the codebase to improve code clarity, enable better IDE support, and catch potential errors with static type checkers. This module provides the type definitions and aliases used across the library. 8 | 9 | .. automodule:: mjinx.typing 10 | :members: 11 | :undoc-members: 12 | :show-inheritance: 13 | 14 | ************ 15 | Type Aliases 16 | ************ 17 | 18 | The following type aliases are defined for common data structures and function signatures: 19 | 20 | .. data:: ndarray 21 | :no-index: 22 | 23 | Type alias for NumPy or JAX NumPy arrays. 24 | 25 | :annotation: = np.ndarray | jnp.ndarray 26 | 27 | .. data:: ArrayOrFloat 28 | :no-index: 29 | 30 | Type alias for an array or a scalar float value. 31 | 32 | :annotation: = ndarray | float 33 | 34 | .. data:: ClassKFunctions 35 | :no-index: 36 | 37 | Type alias for class :math:`\mathcal{K}` functions (continuous, strictly increasing functions :math:`\alpha` with :math:`\alpha(0) = 0`, as used in control barrier function conditions), represented as callables that take and return ndarrays. 38 | 39 | :annotation: = Callable[[ndarray], ndarray] 40 | 41 | .. data:: CollisionBody 42 | :no-index: 43 | 44 | Type alias for collision body representation, either as an integer ID or a string name. 45 | 46 | :annotation: = int | str 47 | 48 | ..
data:: CollisionPair 49 | :no-index: 50 | 51 | Type alias for a pair of collision body IDs. 52 | 53 | :annotation: = tuple[int, int] 54 | 55 | ************ 56 | Enumerations 57 | ************ 58 | 59 | .. autoclass:: PositionLimitType 60 | :no-index: 61 | :members: 62 | :undoc-members: 63 | :show-inheritance: 64 | 65 | Enumeration of possible position limit types. 66 | 67 | .. attribute:: MIN 68 | :no-index: 69 | :value: 0 70 | 71 | Minimum position limit. 72 | 73 | .. attribute:: MAX 74 | :no-index: 75 | :value: 1 76 | 77 | Maximum position limit. 78 | 79 | .. attribute:: BOTH 80 | :no-index: 81 | :value: 2 82 | 83 | Both minimum and maximum position limits. -------------------------------------------------------------------------------- /docs/visualization.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/based-robotics/mjinx/tree/docs/github_pages/docs/visualization.rst 2 | 3 | ============= 4 | Visualization 5 | ============= 6 | 7 | MJINX provides utilities for visualizing robots, problem components, and optimization results. These visualization tools help with debugging, analysis, and presentation of inverse kinematics solutions. 8 | The visualization module offers: 9 | 10 | - Functions for rendering robot states 11 | - Tools for visualizing task errors and barrier values 12 | - Methods for displaying optimization trajectories 13 | - Integration with popular plotting libraries 14 | 15 | Using visualization alongside MJINX's computational capabilities supports both the development and the deployment of robust inverse kinematics solutions. 16 | 17 | ..
automodule:: mjinx.visualize 18 | :members: 19 | :member-order: bysource 20 | 21 | -------------------------------------------------------------------------------- /examples/cassie_squat.py: -------------------------------------------------------------------------------- 1 | import traceback 2 | from time import perf_counter  # NOTE(review): perf_counter is not used in this script 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | import mujoco as mj 7 | import mujoco.mjx as mjx 8 | import numpy as np 9 | from jaxlie import SE3, SO3  # NOTE(review): SE3/SO3 are not used in this script 10 | from robot_descriptions.cassie_mj_description import MJCF_PATH 11 | 12 | from mjinx.components.barriers import JointBarrier 13 | from mjinx.components.constraints import ModelEqualityConstraint 14 | from mjinx.components.tasks import ComTask, FrameTask 15 | from mjinx.configuration import integrate, update 16 | from mjinx.problem import Problem 17 | from mjinx.solvers import LocalIKSolver 18 | from mjinx.visualize import BatchVisualizer 19 | # Batched Cassie squat example: N_batch copies of the robot track phase-shifted center-of-mass height targets while both feet are held at their initial poses (up to the rotation component left free by the foot masks); all problems are solved in parallel with a vmapped LocalIKSolver and a subset is drawn with BatchVisualizer. 20 | # === Mujoco === 21 | 22 | mj_model = mj.MjModel.from_xml_path(MJCF_PATH) 23 | mjx_model = mjx.put_model(mj_model) 24 | print(mjx_model.nq, mjx_model.nv)  # configuration (nq) and velocity (nv) dimensions 25 | 26 | q_min = mj_model.jnt_range[:, 0].copy()  # NOTE(review): q_min/q_max are not used below 27 | q_max = mj_model.jnt_range[:, 1].copy() 28 | 29 | # --- Mujoco visualization --- 30 | # Batched viewer showing n_models=5 of the configurations; with record=True, frames are saved as a video in the finally block below 31 | vis = BatchVisualizer(MJCF_PATH, n_models=5, alpha=0.2, record=True) 32 | 33 | # === Mjinx === 34 | # --- Constructing the problem --- 35 | # Create the problem; v_min/v_max bound the joint velocities 36 | problem = Problem(mjx_model, v_min=-1, v_max=1) 37 | 38 | # Creating components of interest and adding them to the problem 39 | joints_barrier = JointBarrier( 40 | "jnt_range", 41 | gain=0.1, 42 | ) 43 | com_task = ComTask("com_task", cost=10.0, gain=50.0, mask=[1, 1, 1]) 44 | torso_task = FrameTask("torso_task", cost=1.0, gain=2.0, obj_name="cassie-pelvis", mask=[0, 0, 0, 1, 1, 1]) 45 | 46 | # Feet (in stance) 47 | left_foot_task = FrameTask( 48 | "left_foot_task", 49 | cost=20.0, 50 | gain=10.0, 51 |
obj_name="left-foot", 52 | mask=[1, 1, 1, 1, 0, 1], 53 | ) 54 | right_foot_task = FrameTask( 55 | "right_foot_task", 56 | cost=20.0, 57 | gain=10.0, 58 | obj_name="right-foot", 59 | mask=[1, 1, 1, 1, 0, 1], 60 | ) 61 | 62 | model_equality_constraint = ModelEqualityConstraint() 63 | 64 | problem.add_component(com_task) 65 | problem.add_component(torso_task) 66 | # TODO: fix this (the original note does not say what needs fixing; presumably it refers to the joint barrier registered on the next line, confirm) 67 | problem.add_component(joints_barrier) 68 | problem.add_component(left_foot_task) 69 | problem.add_component(right_foot_task) 70 | problem.add_component(model_equality_constraint) 71 | # Create the solver; its per-batch state is initialized further below with jax.vmap(solver.init) 72 | solver = LocalIKSolver(mjx_model, maxiter=10) 73 | 74 | # Initial condition: the "home" keyframe, replicated for each of the N_batch problems 75 | N_batch = 100 76 | q0 = mj_model.keyframe("home").qpos 77 | q = jnp.tile(q0, (N_batch, 1)) 78 | 79 | mjx_data = update(mjx_model, jnp.array(q0)) 80 | # Reference targets below are read from mjx_data computed by update() at q0 81 | com0 = np.array(mjx_data.subtree_com[mjx_model.body_rootid[0]])  # CoM of the subtree rooted at body 0's root, i.e. the whole-model center of mass 82 | com_task.target_com = com0 83 | # Get torso orientation and set it as target (the position part is zeros; presumably ignored because of the torso mask) 84 | torso_id = mjx.name2id(mjx_model, mj.mjtObj.mjOBJ_BODY, "cassie-pelvis") 85 | torso_quat = mjx_data.xquat[torso_id] 86 | torso_task.target_frame = np.concatenate([np.zeros(3), torso_quat]) 87 | 88 | left_foot_id = mjx.name2id(mjx_model, mj.mjtObj.mjOBJ_BODY, "left-foot") 89 | left_foot_pos = mjx_data.xpos[left_foot_id] 90 | left_foot_quat = mjx_data.xquat[left_foot_id] 91 | left_foot_task.target_frame = jnp.concatenate([left_foot_pos, left_foot_quat]) 92 | 93 | right_foot_id = mjx.name2id(mjx_model, mj.mjtObj.mjOBJ_BODY, "right-foot") 94 | right_foot_pos = mjx_data.xpos[right_foot_id] 95 | right_foot_quat = mjx_data.xquat[right_foot_id] 96 | right_foot_task.target_frame = jnp.concatenate([right_foot_pos, right_foot_quat]) 97 | 98 | # Compile the problem; it must be recompiled after every parameter change (see the control loop) 99 | problem_data = problem.compile() 100 | # --- Batching --- 101 | print("Setting up batched computations...") 102 | solver_data = jax.vmap(solver.init,
in_axes=0)(v_init=jnp.zeros((N_batch, mjx_model.nv))) 103 | # Describe which problem-data fields are batched: only com_task.target_com is marked (the value 0 presumably denotes the vmapped axis; confirm in Problem.set_vmap_dimension) 104 | with problem.set_vmap_dimension() as empty_problem_data: 105 | empty_problem_data.components["com_task"].target_com = 0 106 | 107 | # Vmap solve over q and solver_data (axis 0) and over problem_data as described by empty_problem_data; vmap integrate over q and v with the model and dt shared, dt being static under jit 108 | solve_jit = jax.jit( 109 | jax.vmap( 110 | solver.solve, 111 | in_axes=(0, 0, empty_problem_data), 112 | ) 113 | ) 114 | integrate_jit = jax.jit(jax.vmap(integrate, in_axes=(None, 0, 0, None)), static_argnames=["dt"]) 115 | 116 | # === Control loop === 117 | dt = 1e-2 118 | ts = np.arange(0, 20, dt) 119 | 120 | try: 121 | for t in ts: 122 | # Solving the instance of the problem 123 | # com_task.target_com = com0 - np.array([0, 0, 0.2 * np.sin(t) ** 2]) 124 | com_task.target_com = np.array(  # one CoM target per batch element, phase-shifted by 2*pi*i/N_batch; NOTE(review): x/y are 0.0 rather than com0[0]/com0[1], confirm this is intended 125 | [[0.0, 0.0, com0[2] - 0.3 * np.sin(t + 2 * np.pi * i / N_batch + np.pi / 2) ** 2] for i in range(N_batch)] 126 | ) 127 | problem_data = problem.compile() 128 | opt_solution, solver_data = solve_jit(q, solver_data, problem_data) 129 | # Integrating 130 | q = integrate_jit( 131 | mjx_model, 132 | q, 133 | opt_solution.v_opt, 134 | dt, 135 | ) 136 | # --- MuJoCo visualization --- 137 | indices = np.arange(0, N_batch, N_batch // vis.n_models)  # NOTE(review): unused; vis.update below slices q with the same stride directly 138 | 139 | vis.update(q[:: N_batch // vis.n_models]) 140 | 141 | except KeyboardInterrupt: 142 | print("Finalizing the simulation as requested...") 143 | except Exception:  # report the error; cleanup still runs in the finally block 144 | print(traceback.format_exc()) 145 | finally: 146 | if vis.record: 147 | vis.save_video(round(1 / dt))  # presumably the frame rate: one frame per control step 148 | vis.close() 149 | -------------------------------------------------------------------------------- /examples/g1_description/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2016-2023 HangZhou YuShu TECHNOLOGY CO.,LTD. ("Unitree Robotics") 2 | All rights reserved.
3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | * Neither the name of the copyright holder nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /examples/g1_description/README.md: -------------------------------------------------------------------------------- 1 | # Unitree G1 Description (MJCF) 2 | 3 | > [!IMPORTANT] 4 | > Requires MuJoCo 2.3.4 or later. 
5 | 6 | ## Overview 7 | 8 | This package contains a simplified robot description (MJCF) of the [G1 Humanoid 9 | Robot](https://www.unitree.com/g1/) developed by [Unitree 10 | Robotics](https://www.unitree.com/). It is derived from the [publicly available 11 | MJCF 12 | description](https://github.com/unitreerobotics/unitree_ros/tree/master/robots/g1_description). 13 | 14 |

15 | 16 |

17 | 18 | ## MJCF derivation steps 19 | 20 | 1. Copied the MJCF description from [g1_description](https://github.com/unitreerobotics/unitree_ros/tree/master/robots/g1_description). 21 | 2. Manually edited the MJCF to extract common properties into the `<default>` section. 22 | 3. Added sites for the IMU, head and feet. 23 | 4. Added IMU sensors (gyro, accelerometer, framequat). 24 | 5. Added stand keyframe. 25 | 6. Added spotlight and tracking light. 26 | 27 | ## License 28 | 29 | This model is released under a [BSD-3-Clause License](LICENSE). 30 | -------------------------------------------------------------------------------- /examples/g1_description/assets/head_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/head_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/left_ankle_pitch_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_ankle_pitch_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/left_ankle_roll_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_ankle_roll_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/left_elbow_pitch_link.STL: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_elbow_pitch_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/left_elbow_roll_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_elbow_roll_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/left_five_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_five_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/left_four_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_four_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/left_hip_pitch_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_hip_pitch_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/left_hip_roll_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_hip_roll_link.STL 
-------------------------------------------------------------------------------- /examples/g1_description/assets/left_hip_yaw_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_hip_yaw_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/left_knee_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_knee_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/left_one_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_one_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/left_palm_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_palm_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/left_shoulder_pitch_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_shoulder_pitch_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/left_shoulder_roll_link.STL: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_shoulder_roll_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/left_shoulder_yaw_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_shoulder_yaw_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/left_six_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_six_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/left_three_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_three_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/left_two_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_two_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/left_zero_link.STL: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_zero_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/logo_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/logo_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/pelvis.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/pelvis.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/pelvis_contour_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/pelvis_contour_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/right_ankle_pitch_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_ankle_pitch_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/right_ankle_roll_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_ankle_roll_link.STL 
-------------------------------------------------------------------------------- /examples/g1_description/assets/right_elbow_pitch_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_elbow_pitch_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/right_elbow_roll_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_elbow_roll_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/right_five_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_five_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/right_four_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_four_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/right_hip_pitch_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_hip_pitch_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/right_hip_roll_link.STL: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_hip_roll_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/right_hip_yaw_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_hip_yaw_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/right_knee_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_knee_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/right_one_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_one_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/right_palm_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_palm_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/right_shoulder_pitch_link.STL: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_shoulder_pitch_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/right_shoulder_roll_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_shoulder_roll_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/right_shoulder_yaw_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_shoulder_yaw_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/right_six_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_six_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/right_three_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_three_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/right_two_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_two_link.STL 
-------------------------------------------------------------------------------- /examples/g1_description/assets/right_zero_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_zero_link.STL -------------------------------------------------------------------------------- /examples/g1_description/assets/torso_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/torso_link.STL -------------------------------------------------------------------------------- /examples/g1_description/g1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/g1.png -------------------------------------------------------------------------------- /examples/g1_description/scene.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /examples/global_ik.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example of Global inverse kinematics for a Kuka iiwa robot. 3 | 4 | NOTE: The Global IK functionality is not yet working properly as expected and needs proper tuning. 5 | This example will be fixed in future updates. Use with caution and expect suboptimal results. 
6 | """ 7 | 8 | import time 9 | from time import perf_counter 10 | from collections import defaultdict 11 | 12 | import jax 13 | import jax.numpy as jnp 14 | import mujoco as mj 15 | import mujoco.mjx as mjx 16 | import numpy as np 17 | from mujoco import viewer 18 | from optax import adam 19 | from robot_descriptions.iiwa14_mj_description import MJCF_PATH 20 | 21 | from mjinx.components.barriers import JointBarrier, PositionBarrier, SelfCollisionBarrier 22 | from mjinx.components.tasks import FrameTask 23 | from mjinx.configuration import integrate 24 | from mjinx.problem import Problem 25 | from mjinx.solvers import GlobalIKSolver 26 | 27 | print("=== Initializing ===") 28 | 29 | # === Mujoco === 30 | print("Loading MuJoCo model...") 31 | mj_model = mj.MjModel.from_xml_path(MJCF_PATH) 32 | mj_data = mj.MjData(mj_model) 33 | 34 | mjx_model = mjx.put_model(mj_model) 35 | 36 | q_min = mj_model.jnt_range[:, 0].copy() 37 | q_max = mj_model.jnt_range[:, 1].copy() 38 | 39 | 40 | # --- Mujoco visualization --- 41 | print("Setting up visualization...") 42 | # Initialize render window and launch it at the background 43 | mj_data = mj.MjData(mj_model) 44 | renderer = mj.Renderer(mj_model) 45 | mj_viewer = viewer.launch_passive( 46 | mj_model, 47 | mj_data, 48 | show_left_ui=False, 49 | show_right_ui=False, 50 | ) 51 | 52 | # Initialize a sphere marker for end-effector task 53 | renderer.scene.ngeom += 1 54 | mj_viewer.user_scn.ngeom = 1 55 | mj.mjv_initGeom( 56 | mj_viewer.user_scn.geoms[0], 57 | mj.mjtGeom.mjGEOM_SPHERE, 58 | 0.05 * np.ones(3), 59 | np.array([0.2, 0.2, 0.2]), 60 | np.eye(3).flatten(), 61 | np.array([0.565, 0.933, 0.565, 0.4]), 62 | ) 63 | 64 | # === Mjinx === 65 | print("Setting up optimization problem...") 66 | # --- Constructing the problem --- 67 | # Creating problem formulation 68 | problem = Problem(mjx_model, v_min=-100, v_max=100) 69 | 70 | # Creating components of interest and adding them to the problem 71 | frame_task = FrameTask("ee_task", 
cost=1, gain=20, obj_name="link7") 72 | position_barrier = PositionBarrier( 73 | "ee_barrier", 74 | gain=0.1, 75 | obj_name="link7", 76 | limit_type="max", 77 | p_max=0.3, 78 | safe_displacement_gain=1e-2, 79 | mask=[1, 0, 0], 80 | ) 81 | joints_barrier = JointBarrier("jnt_range", gain=0.1) 82 | self_collision_barrier = SelfCollisionBarrier( 83 | "self_collision_barrier", 84 | gain=1e-4, 85 | d_min=0.01, 86 | ) 87 | 88 | problem.add_component(frame_task) 89 | problem.add_component(position_barrier) 90 | problem.add_component(joints_barrier) 91 | problem.add_component(self_collision_barrier) 92 | 93 | # Compiling the problem upon any parameters update 94 | problem_data = problem.compile() 95 | 96 | # Initializing solver and its initial state 97 | print("Initializing solver...") 98 | solver = GlobalIKSolver(mjx_model, adam(learning_rate=1e-2), dt=1e-2) 99 | 100 | # Initial condition 101 | q = np.array( 102 | [ 103 | -1.4238753, 104 | -1.7268502, 105 | -0.84355015, 106 | 2.0962472, 107 | 2.1339328, 108 | 2.0837479, 109 | -2.5521986, 110 | ] 111 | ) 112 | solver_data = solver.init(q) 113 | 114 | # Jit-compiling the key functions for better efficiency 115 | solve_jit = jax.jit(solver.solve) 116 | integrate_jit = jax.jit(integrate, static_argnames=["dt"]) 117 | 118 | t_warmup = perf_counter() 119 | print("Performing warmup calls...") 120 | # Warmup iterations for JIT compilation 121 | frame_task.target_frame = np.array([0.2, 0.2, 0.2, 1, 0, 0, 0]) 122 | problem_data = problem.compile() 123 | opt_solution, solver_data = solve_jit(q, solver_data, problem_data) 124 | q_warmup = opt_solution.q_opt 125 | 126 | t_warmup_duration = perf_counter() - t_warmup 127 | print(f"Warmup completed in {t_warmup_duration:.3f} seconds") 128 | 129 | # === Control loop === 130 | print("\n=== Starting main loop ===") 131 | dt = 1e-2 132 | ts = np.arange(0, 20, dt) 133 | 134 | # Performance tracking 135 | solve_times = [] 136 | n_steps = 0 137 | 138 | try: 139 | for t in ts: 140 | # Changing 
desired values 141 | frame_task.target_frame = np.array([0.2 + 0.2 * jnp.sin(t) ** 2, 0.2, 0.2, 1, 0, 0, 0]) 142 | 143 | # After changes, recompiling the model 144 | problem_data = problem.compile() 145 | 146 | # Solving the instance of the problem 147 | t1 = perf_counter() 148 | for _ in range(1): 149 | opt_solution, solver_data = solve_jit(q, solver_data, problem_data) 150 | t2 = perf_counter() 151 | solve_times.append(t2 - t1) 152 | 153 | # Two options for retriving q: 154 | # Option 1, integrating: 155 | # q = integrate(mjx_model, q, opt_solution.v_opt, dt=dt) 156 | # Option 2, direct: 157 | q = opt_solution.q_opt 158 | 159 | # --- MuJoCo visualization --- 160 | mj_data.qpos = q 161 | mj.mj_forward(mj_model, mj_data) 162 | print(f"Position barrier: {mj_data.xpos[position_barrier.obj_id][0]} <= {position_barrier.p_max[0]}") 163 | mj.mjv_initGeom( 164 | mj_viewer.user_scn.geoms[0], 165 | mj.mjtGeom.mjGEOM_SPHERE, 166 | 0.05 * np.ones(3), 167 | np.array(frame_task.target_frame.wxyz_xyz[-3:], dtype=np.float64), 168 | np.eye(3).flatten(), 169 | np.array([0.565, 0.933, 0.565, 0.4]), 170 | ) 171 | 172 | # Run the forward dynamics to reflec 173 | # the updated state in the data 174 | mj.mj_forward(mj_model, mj_data) 175 | mj_viewer.sync() 176 | n_steps += 1 177 | except KeyboardInterrupt: 178 | print("\nSimulation interrupted by user") 179 | except Exception as e: 180 | print(f"\nError occurred: {e}") 181 | finally: 182 | renderer.close() 183 | 184 | # Print performance report 185 | print("\n=== Performance Report ===") 186 | print(f"Total steps completed: {n_steps}") 187 | print("\nComputation times per step:") 188 | if solve_times: 189 | avg_solve = sum(solve_times) / len(solve_times) 190 | std_solve = np.std(solve_times) 191 | print(f"solve : {avg_solve * 1000:8.3f} ± {std_solve * 1000:8.3f} ms") 192 | 193 | if solve_times: 194 | print(f"\nAverage computation time per step: {avg_solve * 1000:.3f} ms") 195 | print(f"Effective computation rate: {1 / avg_solve:.1f} 
Hz") 196 | -------------------------------------------------------------------------------- /examples/global_ik_vmapped_input.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example of Global inverse kinematics for a Kuka iiwa robot with vmapped input. 3 | This demonstrates how to use JAX's vmap to efficiently compute IK solutions for multiple initial configurations. 4 | 5 | NOTE: The Global IK functionality is not yet working properly as expected and needs proper tuning. 6 | This example will be fixed in future updates. Use with caution and expect suboptimal results. 7 | """ 8 | 9 | import jax 10 | import jax.numpy as jnp 11 | import mujoco as mj 12 | import mujoco.mjx as mjx 13 | import numpy as np 14 | from optax import adam 15 | from robot_descriptions.iiwa14_mj_description import MJCF_PATH 16 | from time import perf_counter 17 | 18 | from mjinx.components.barriers import JointBarrier, PositionBarrier 19 | from mjinx.components.tasks import FrameTask 20 | from mjinx.configuration import integrate 21 | from mjinx.problem import Problem 22 | from mjinx.solvers import GlobalIKSolver 23 | from mjinx.visualize import BatchVisualizer 24 | 25 | print("=== Initializing ===") 26 | 27 | # === Mujoco === 28 | print("Loading MuJoCo model...") 29 | mj_model = mj.MjModel.from_xml_path(MJCF_PATH) 30 | mjx_model = mjx.put_model(mj_model) 31 | 32 | q_min = mj_model.jnt_range[:, 0].copy() 33 | q_max = mj_model.jnt_range[:, 1].copy() 34 | 35 | 36 | # --- Mujoco visualization --- 37 | print("Setting up visualization...") 38 | vis = BatchVisualizer(MJCF_PATH, n_models=5, alpha=0.5, record=False) 39 | 40 | # Initialize a sphere marker for end-effector task 41 | vis.add_markers( 42 | name="ee_marker", 43 | size=0.05, 44 | marker_alpha=0.4, 45 | color_begin=np.array([0, 1.0, 0.53]), 46 | ) 47 | vis.add_markers( 48 | name="blocking_plane", 49 | marker_type=mj.mjtGeom.mjGEOM_PLANE, 50 | size=np.array([0.5, 0.5, 0.02]), 51 | 
marker_alpha=0.7, 52 | color_begin=np.array([1, 0, 0]), 53 | ) 54 | 55 | # === Mjinx === 56 | print("Setting up optimization problem...") 57 | # --- Constructing the problem --- 58 | # Creating problem formulation 59 | problem = Problem(mjx_model) 60 | 61 | # Creating components of interest and adding them to the problem 62 | frame_task = FrameTask("ee_task", cost=1, gain=20, obj_name="link7") 63 | position_barrier = PositionBarrier( 64 | "ee_barrier", 65 | gain=0.1, 66 | obj_name="link7", 67 | limit_type="max", 68 | p_max=0.5, 69 | safe_displacement_gain=1e-2, 70 | mask=[1, 0, 0], 71 | ) 72 | joints_barrier = JointBarrier("jnt_range", gain=0.1) 73 | # Set plane coodinate same to limiting one 74 | vis.marker_data["blocking_plane"].pos = np.array([position_barrier.p_max[0], 0, 0.3]) 75 | vis.marker_data["blocking_plane"].rot = np.array( 76 | [ 77 | [0, 0, -1], 78 | [0, 1, 0], 79 | [1, 0, 0], 80 | ] 81 | ) 82 | 83 | problem.add_component(frame_task) 84 | problem.add_component(position_barrier) 85 | problem.add_component(joints_barrier) 86 | 87 | # Compiling the problem upon any parameters update 88 | problem_data = problem.compile() 89 | 90 | # Initializing solver and its initial state 91 | print("Initializing solver...") 92 | solver = GlobalIKSolver(mjx_model, adam(learning_rate=1e-2), dt=1e-2) 93 | 94 | # Initializing initial condition 95 | N_batch = 100 96 | q0 = np.array( 97 | [ 98 | -1.4238753, 99 | -1.7268502, 100 | -0.84355015, 101 | 2.0962472, 102 | 2.1339328, 103 | 2.0837479, 104 | -2.5521986, 105 | ] 106 | ) 107 | q = jnp.array( 108 | [ 109 | np.clip( 110 | q0 111 | + np.random.uniform( 112 | -0.1, 113 | 0.1, 114 | size=(mj_model.nq), 115 | ), 116 | q_min + 1e-1, 117 | q_max - 1e-1, 118 | ) 119 | for _ in range(N_batch) 120 | ] 121 | ) 122 | 123 | 124 | # --- Batching --- 125 | print("Setting up batched computations...") 126 | # First of all, data should be created via vmapped init function 127 | solver_data = jax.vmap(solver.init, in_axes=0)(q) 128 | 129 | 
# Vmapping solve and integrate functions. 130 | solve_jit = jax.jit(jax.vmap(solver.solve, in_axes=(0, 0, None))) 131 | integrate_jit = jax.jit(jax.vmap(integrate, in_axes=(None, 0, 0, None)), static_argnames=["dt"]) 132 | 133 | t_warmup = perf_counter() 134 | print("Performing warmup calls...") 135 | # Warmup iterations for JIT compilation 136 | frame_task.target_frame = np.array([0.4, 0.2, 0.7, 1, 0, 0, 0]) 137 | problem_data = problem.compile() 138 | opt_solution, solver_data = solve_jit(q, solver_data, problem_data) 139 | q_warmup = opt_solution.q_opt 140 | 141 | t_warmup_duration = perf_counter() - t_warmup 142 | print(f"Warmup completed in {t_warmup_duration:.3f} seconds") 143 | 144 | # === Control loop === 145 | print("\n=== Starting main loop ===") 146 | dt = 1e-2 147 | ts = np.arange(0, 20, dt) 148 | 149 | # Performance tracking 150 | solve_times = [] 151 | n_steps = 0 152 | 153 | try: 154 | for t in ts: 155 | # Changing desired values 156 | frame_task.target_frame = np.array([0.4 + 0.3 * np.sin(t), 0.2, 0.4 + 0.3 * np.cos(t), 1, 0, 0, 0]) 157 | 158 | # After changes, recompiling the model 159 | problem_data = problem.compile() 160 | 161 | # Solving the instance of the problem 162 | t1 = perf_counter() 163 | for _ in range(3): 164 | opt_solution, solver_data = solve_jit(q, solver_data, problem_data) 165 | t2 = perf_counter() 166 | solve_times.append(t2 - t1) 167 | 168 | # Two options for retriving q: 169 | # Option 1, integrating: 170 | # q = integrate(mjx_model, q, opt_solution.v_opt, dt=dt) 171 | # Option 2, direct: 172 | q = opt_solution.q_opt 173 | 174 | # --- MuJoCo visualization --- 175 | vis.marker_data["ee_marker"].pos = np.array(frame_task.target_frame.wxyz_xyz[-3:]) 176 | vis.update(q[: vis.n_models]) 177 | n_steps += 1 178 | 179 | except KeyboardInterrupt: 180 | print("\nSimulation interrupted by user") 181 | except Exception as e: 182 | print(f"\nError occurred: {e}") 183 | finally: 184 | if vis.record: 185 | vis.save_video(round(1 / dt)) 186 | 
vis.close() 187 | 188 | # Print performance report 189 | print("\n=== Performance Report ===") 190 | print(f"Total steps completed: {n_steps}") 191 | print("\nComputation times per step:") 192 | if solve_times: 193 | avg_solve = sum(solve_times) / len(solve_times) 194 | std_solve = np.std(solve_times) 195 | print(f"solve : {avg_solve * 1000:8.3f} ± {std_solve * 1000:8.3f} ms") 196 | 197 | if solve_times: 198 | print(f"\nAverage computation time per step: {avg_solve * 1000:.3f} ms") 199 | print(f"Effective computation rate: {1 / avg_solve:.1f} Hz") 200 | -------------------------------------------------------------------------------- /img/cassie_caravan.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/img/cassie_caravan.gif -------------------------------------------------------------------------------- /img/g1_heart.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/img/g1_heart.gif -------------------------------------------------------------------------------- /img/go2_stance.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/img/go2_stance.gif -------------------------------------------------------------------------------- /img/local_ik_input.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/img/local_ik_input.gif -------------------------------------------------------------------------------- /img/local_ik_output.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/img/local_ik_output.gif -------------------------------------------------------------------------------- /img/logo.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mjinx/__init__.py: -------------------------------------------------------------------------------- 1 | """Differentiable GPU-accelerated inverse kinematics based on JAX and MuJoCo MJX. 2 | 3 | This package provides a framework for solving inverse kinematics problems using 4 | differential programming paradigms. It leverages JAX's automatic differentiation 5 | and GPU acceleration alongside MuJoCo MJX's physics simulation capabilities. 6 | """ 7 | 8 | __version__ = "0.1.0" 9 | -------------------------------------------------------------------------------- /mjinx/components/__init__.py: -------------------------------------------------------------------------------- 1 | from ._base import Component, JaxComponent 2 | 3 | __all__ = ["Component", "JaxComponent"] 4 | -------------------------------------------------------------------------------- /mjinx/components/barriers/__init__.py: -------------------------------------------------------------------------------- 1 | from ._base import Barrier, JaxBarrier 2 | from ._obj_barrier import ObjBarrier, JaxObjBarrier 3 | from ._obj_position_barrier import PositionBarrier, PositionLimitType 4 | from ._joint_barrier import JaxJointBarrier, JointBarrier 5 | from ._self_collision_barrier import JaxSelfCollisionBarrier, SelfCollisionBarrier 6 | 7 | __all__ = [ 8 | "Barrier", 9 | "JaxBarrier", 10 | "ObjBarrier", 11 | "JaxObjBarrier", 12 | "PositionBarrier", 13 | "PositionLimitType", 14 | "JaxJointBarrier", 15 | "JointBarrier", 16 | "JaxSelfCollisionBarrier", 17 | "SelfCollisionBarrier", 18 | ] 19 | 
-------------------------------------------------------------------------------- /mjinx/components/barriers/_base.py: -------------------------------------------------------------------------------- 1 | from typing import Generic, TypeVar, Callable 2 | 3 | import jax.numpy as jnp 4 | import jax_dataclasses as jdc 5 | import mujoco.mjx as mjx 6 | 7 | from mjinx.components import Component, JaxComponent 8 | 9 | 10 | @jdc.pytree_dataclass 11 | class JaxBarrier(JaxComponent): 12 | r""" 13 | A base class for implementing barrier functions in JAX. 14 | 15 | This class provides a framework for creating barrier functions that can be used 16 | in optimization problems, particularly for constraint handling in robotics applications. 17 | 18 | A barrier function is defined mathematically as a function: 19 | 20 | .. math:: 21 | 22 | h(q) \geq 0 23 | 24 | where the constraint is satisfied when h(q) is non-negative. As h(q) approaches zero, 25 | the barrier "activates" to prevent constraint violation. In optimization problems, 26 | the barrier helps enforce constraints by adding a penalty term that increases rapidly 27 | as the system approaches constraint boundaries. 28 | 29 | :param safe_displacement_gain: The gain for computing safe displacements. 30 | """ 31 | 32 | safe_displacement_gain: float 33 | 34 | def compute_barrier(self, data: mjx.Data) -> jnp.ndarray: 35 | """ 36 | Compute the barrier function value. 37 | 38 | This method calculates the value of the barrier function h(q) at the current state. 39 | The barrier is active when h(q) is close to zero and satisfied when h(q) > 0. 40 | 41 | :param data: The MuJoCo simulation data. 42 | :return: The computed barrier value. 43 | """ 44 | return self.__call__(data) 45 | 46 | def compute_safe_displacement(self, data: mjx.Data) -> jnp.ndarray: 47 | r""" 48 | Compute a safe displacement to move away from constraint boundaries. 
49 | 50 | When the barrier function value is close to zero, this method computes 51 | a displacement in joint space that helps move the system away from the constraint boundary: 52 | 53 | .. math:: 54 | 55 | \Delta q_{safe} = \alpha \nabla h(q) 56 | 57 | where: 58 | - :math:`\alpha` is the safe_displacement_gain 59 | - :math:`\nabla h(q)` is the gradient of the barrier function 60 | 61 | :param data: The MuJoCo simulation data. 62 | :return: A joint displacement vector to move away from constraint boundaries. 63 | """ 64 | return jnp.zeros(self.model.nv) 65 | 66 | 67 | AtomicBarrierType = TypeVar("AtomicBarrierType", bound=JaxBarrier) 68 | 69 | 70 | class Barrier(Generic[AtomicBarrierType], Component[AtomicBarrierType]): 71 | r""" 72 | A generic barrier class that wraps atomic barrier implementations. 73 | 74 | This class provides a high-level interface for barrier functions, allowing 75 | for easy integration into optimization problems. Barrier functions are used 76 | to enforce constraints by creating a potential field that pushes the system 77 | away from constraint boundaries. 78 | 79 | In optimization problems, barriers are typically integrated as inequality constraints 80 | or as penalty terms in the objective function: 81 | 82 | .. math:: 83 | 84 | \min_{q} f(q) \quad \text{subject to} \quad h(q) \geq 0 85 | 86 | or as a penalty term: 87 | 88 | .. math:: 89 | 90 | \min_{q} f(q) - \lambda \log(h(q)) 91 | 92 | where :math:`\lambda` is a weight parameter. 93 | 94 | :param safe_displacement_gain: The gain for computing safe displacements. 95 | """ 96 | 97 | safe_displacement_gain: float 98 | 99 | def __init__(self, name, gain, gain_fn=None, safe_displacement_gain=0, mask=None): 100 | """ 101 | Initialize the Barrier object. 102 | 103 | :param name: The name of the barrier. 104 | :param gain: The gain for the barrier function. 105 | :param gain_fn: A function to compute the gain dynamically. 
106 | :param safe_displacement_gain: The gain for computing safe displacements. 107 | :param mask: A sequence of integers to mask certain dimensions. 108 | """ 109 | super().__init__(name, gain, gain_fn, mask) 110 | self.safe_displacement_gain = safe_displacement_gain 111 | -------------------------------------------------------------------------------- /mjinx/components/constraints/__init__.py: -------------------------------------------------------------------------------- 1 | from ._base import Constraint, JaxConstraint 2 | from ._equality_constraint import JaxModelEqualityConstraint, ModelEqualityConstraint 3 | 4 | __all__ = [ 5 | "Constraint", 6 | "JaxConstraint", 7 | "JaxModelEqualityConstraint", 8 | "ModelEqualityConstraint", 9 | ] 10 | -------------------------------------------------------------------------------- /mjinx/components/constraints/_base.py: -------------------------------------------------------------------------------- 1 | from typing import Generic, Sequence, TypeVar, Callable 2 | 3 | import jax.numpy as jnp 4 | import jax_dataclasses as jdc 5 | import mujoco.mjx as mjx 6 | 7 | from mjinx.components import Component, JaxComponent 8 | from mjinx.typing import ArrayLike, ArrayOrFloat 9 | 10 | 11 | @jdc.pytree_dataclass 12 | class JaxConstraint(JaxComponent): 13 | """ 14 | A JAX-based representation of an equality constraint. 15 | 16 | This class defines an equality constraint function f(x) = 0 for optimization tasks. 17 | 18 | :param active: Indicates whether the constraint is active. 19 | :param hard_constraint: A flag that specifies if the constraint is hard (True) or soft (False). 20 | :param soft_constraint_cost: The cost matrix used for soft constraint relaxation. 21 | """ 22 | 23 | active: bool 24 | hard_constraint: jdc.Static[bool] 25 | soft_constraint_cost: jnp.ndarray 26 | 27 | def compute_constraint(self, data: mjx.Data) -> jnp.ndarray: 28 | """ 29 | Compute the equality constraint function value. 
30 | 31 | Evaluates the constraint function f(x) = 0 based on the current simulation data. 32 | For soft constraints, the computed value is later penalized by the associated cost; 33 | for hard constraints, the evaluation is directly used in the Ax = b formulation. 34 | 35 | :param data: The MuJoCo simulation data. 36 | :return: The computed constraint value. 37 | """ 38 | return self.__call__(data) 39 | 40 | 41 | AtomicConstraintType = TypeVar("AtomicConstraintType", bound=JaxConstraint) 42 | 43 | 44 | class Constraint(Generic[AtomicConstraintType], Component[AtomicConstraintType]): 45 | r""" 46 | A high-level component for formulating equality constraints. 47 | 48 | This class wraps an atomic JAX constraint (JaxConstraint) and provides a framework to 49 | manage constraints within the optimization problem. Equality constraints are specified as: 50 | 51 | .. math:: 52 | 53 | f(x) = 0 54 | 55 | They can be treated in one of two ways: 56 | - As a soft constraint. In this mode, the constraint violation is penalized by a cost 57 | (typically with a high gain), transforming the constraint into a task. 58 | - As a hard constraint. Here, the constraint is enforced as a strict equality in the following form: 59 | 60 | .. math:: 61 | 62 | \nabla h(q)^T v = -\alpha h(q), 63 | 64 | where :math:`\alpha` controls constraint enforcement and :math:`v` is the velocity vector. 65 | 66 | :param matrix_cost: The cost matrix associated with the task. 67 | :param lm_damping: The Levenberg-Marquardt damping factor. 68 | :param active: Determines if the constraint is active. 69 | :param hard_constraint: Indicates whether the constraint is hard (True) or soft (False). 
70 | """ 71 | 72 | active: bool 73 | _soft_constraint_cost: jnp.ndarray | None 74 | hard_constraint: bool 75 | 76 | def __init__( 77 | self, 78 | name: str, 79 | gain: ArrayOrFloat, 80 | mask: Sequence[int] | None = None, 81 | hard_constraint: bool = False, 82 | soft_constraint_cost: ArrayLike | None = None, 83 | ): 84 | """ 85 | Initialize a Constraint object. 86 | 87 | Sets up the constraint with the given parameters. Depending on whether it is 88 | a hard or soft constraint, further integration in the optimization problem will 89 | process it accordingly. 90 | 91 | :param name: The unique identifier for the constraint. 92 | :param gain: The gain for the constraint function, affecting its impact. 93 | :param mask: Indices to select specific dimensions for evaluation. If not provided, applies to all dimensions. 94 | :param hard_constraint: If True, the constraint is handled as a hard constraint. Defaults to False (i.e. soft constraint). 95 | :param soft_constraint_cost: The cost used to relax a soft constraint. If not provided, a default high gain cost matrix (scaled identity matrix) based on the component dimension will be used. 96 | """ 97 | super().__init__(name, gain, gain_fn=None, mask=mask) 98 | self.active = True 99 | self.hard_constraint = hard_constraint 100 | self._soft_constraint_cost = jnp.array(soft_constraint_cost) if soft_constraint_cost is not None else None 101 | 102 | @property 103 | def soft_constraint_cost(self) -> jnp.ndarray: 104 | """ 105 | Get the cost matrix associated with a soft constraint. 106 | 107 | For soft constraints, any violation is penalized by a cost if self._soft_constraint_cost is not None else 1e2 * jnp.eye(self.dim)" 108 | 109 | :return: The cost matrix for the soft constraint. 
110 | """ 111 | if self._soft_constraint_cost is None: 112 | return 1e2 * jnp.eye(self.dim) 113 | 114 | match self._soft_constraint_cost.ndim: 115 | case 0: 116 | return jnp.eye(self.dim) * self._soft_constraint_cost 117 | case 1: 118 | if len(self._soft_constraint_cost) != self.dim: 119 | raise ValueError( 120 | f"fail to construct matrix jnp.diag(({self.dim},)) from vector of length {self._soft_constraint_cost.shape}" 121 | ) 122 | return jnp.diag(self._soft_constraint_cost) 123 | case 2: 124 | if self._soft_constraint_cost.shape != ( 125 | self.dim, 126 | self.dim, 127 | ): 128 | raise ValueError( 129 | f"wrong shape of the cost: {self._soft_constraint_cost.shape} != ({self.dim}, {self.dim},)" 130 | ) 131 | return self._soft_constraint_cost 132 | case _: # pragma: no cover 133 | raise ValueError("fail to construct cost, given dim > 2") 134 | -------------------------------------------------------------------------------- /mjinx/components/constraints/_equality_constraint.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | from typing import TypeVar 3 | 4 | import jax.numpy as jnp # noqa: F401 5 | import jax_dataclasses as jdc 6 | import mujoco as mj 7 | import mujoco.mjx as mjx 8 | 9 | from mjinx.components.constraints._base import Constraint, JaxConstraint 10 | import mjinx.typing 11 | 12 | 13 | @jdc.pytree_dataclass 14 | class JaxModelEqualityConstraint(JaxConstraint): 15 | """ 16 | A JAX-based equality constraint derived from the simulation model. 17 | 18 | Equality constraints represent all constraints in MuJoCo that can be defined with the 19 | ``equality`` tag of the XML file. More details can be found here: https://mujoco.readthedocs.io/en/stable/XMLreference.html#equality.
20 | 21 | This class relies on the fact that, during kinematics computations, the equality values are computed as well: 22 | the Jacobian of the equalities is stored in data.efc_J, and the residuals are stored in data.efc_pos. 23 | 24 | Calling an instance with MuJoCo simulation data (see ``__call__``) returns a jax.numpy.ndarray 25 | with the positions corresponding to the (masked) equality constraints. 26 | """ 27 | 28 | def __call__(self, data: mjx.Data) -> jnp.ndarray: 29 | """ 30 | Compute the equality constraint values from the simulation data. 31 | 32 | This method selects the positions of equality constraints from the simulation data 33 | by filtering based on the constraint type and applying the mask indices. 34 | 35 | :param data: The MuJoCo simulation data. 36 | :return: A jax.numpy.ndarray containing the equality constraint values. 37 | """ 38 | return data.efc_pos[data.efc_type == mj.mjtConstraint.mjCNSTR_EQUALITY][self.mask_idxs,] 39 | 40 | def compute_jacobian(self, data: mjx.Data) -> jnp.ndarray: 41 | """ 42 | Compute the Jacobian matrix of the equality constraint. 43 | 44 | Retrieves the relevant rows of the Jacobian matrix from the simulation data, 45 | filtering by equality constraint type and applying the mask indices. 46 | 47 | :param data: The MuJoCo simulation data. 48 | :return: A jax.numpy.ndarray representing the Jacobian matrix of the equality constraints. 49 | """ 50 | return data.efc_J[data.efc_type == mj.mjtConstraint.mjCNSTR_EQUALITY, :][self.mask_idxs, :] 51 | 52 | 53 | AtomicModelEqualityConstraintType = TypeVar("AtomicModelEqualityConstraintType", bound=JaxModelEqualityConstraint) 54 | 55 | 56 | class ModelEqualityConstraint(Constraint[AtomicModelEqualityConstraintType]): 57 | """ 58 | High-level component that aggregates all equality constraints from the simulation model.
59 | 60 | The main purpose of the wrapper is to recalculate the mask based on the dimensions of the equality constraints 61 | and to compute the proper dimensionality. 62 | 63 | 64 | :param name: The unique identifier for the constraint. 65 | :param gain: The gain for the constraint function, affecting its impact. 66 | :param hard_constraint: If True, the constraint is handled as a hard constraint. 67 | Defaults to False (i.e., soft constraint). 68 | :param soft_constraint_cost: The cost used to relax a soft constraint. If not provided, a default high gain 69 | cost matrix (scaled identity matrix) based on the component dimension will be used. 70 | """ 71 | 72 | JaxComponentType = JaxModelEqualityConstraint 73 | 74 | def __init__( 75 | self, 76 | name: str = "equality_constraint", 77 | gain: mjinx.typing.ArrayOrFloat = 100, 78 | hard_constraint: bool = False, 79 | soft_constraint_cost: mjinx.typing.ArrayLike | None = None, 80 | ): 81 | """ 82 | Initialize a new equality constraint. 83 | 84 | Args: 85 | name (str): A unique identifier for the constraint. Defaults to 'equality_constraint'. 86 | gain (ArrayOrFloat): The gain used to weight or scale the constraint. Can be a float or an array-like type. Defaults to 100. 87 | hard_constraint (bool): Flag to determine if the constraint should be enforced strictly (hard constraint). Defaults to False. 88 | soft_constraint_cost (ArrayLike | None): Optional cost associated with treating the constraint as a soft constraint. 89 | When provided, this cost is used to penalize deviations. 90 | 91 | Attributes: 92 | active (bool): Indicates whether the constraint is active. Initialized to True. 93 | """ 94 | super().__init__( 95 | name=name, 96 | gain=gain, 97 | mask=None, 98 | hard_constraint=hard_constraint, 99 | soft_constraint_cost=soft_constraint_cost, 100 | ) 101 | self.active = True 102 | 103 | def update_model(self, model: mj.MjModel) -> None: 104 | """ 105 | Update the equality constraint using the provided MuJoCo model.
106 | 107 | This method computes the total dimension of the equality constraints by iterating 108 | over the equality types in the model. Based on the type of each equality constraint, 109 | it sets the mask indices and component dimension accordingly. 110 | 111 | :param model: The MuJoCo model instance. 112 | :return: None 113 | """ 114 | super().update_model(model) 115 | 116 | nefc = 0 117 | for i in range(self.model.neq): 118 | match self.model.eq_type[i]: 119 | case mj.mjtEq.mjEQ_CONNECT: 120 | nefc += 3 121 | case mj.mjtEq.mjEQ_WELD: 122 | nefc += 6 123 | case mj.mjtEq.mjEQ_JOINT: 124 | nefc += 1 125 | case _: 126 | raise ValueError(f"Unsupported equality constraint type {self.model.eq_type[i]}") 127 | 128 | self._mask_idxs = jnp.arange(nefc) 129 | self._dim = nefc 130 | -------------------------------------------------------------------------------- /mjinx/components/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from ._base import JaxTask, Task 2 | from ._obj_frame_task import FrameTask, JaxFrameTask 3 | from ._obj_position_task import JaxPositionTask, PositionTask 4 | from ._obj_task import ObjTask, JaxObjTask 5 | from ._com_task import ComTask, JaxComTask 6 | from ._joint_task import JaxJointTask, JointTask 7 | 8 | __all__ = [ 9 | "JaxTask", 10 | "Task", 11 | "FrameTask", 12 | "JaxFrameTask", 13 | "JaxPositionTask", 14 | "PositionTask", 15 | "ObjTask", 16 | "JaxObjTask", 17 | "ComTask", 18 | "JaxComTask", 19 | "JaxJointTask", 20 | "JointTask", 21 | ] 22 | -------------------------------------------------------------------------------- /mjinx/components/tasks/_com_task.py: -------------------------------------------------------------------------------- 1 | """Center of mass task implementation.""" 2 | 3 | from collections.abc import Callable, Sequence 4 | from typing import final 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | import jax_dataclasses as jdc 9 | import mujoco 10 | import mujoco.mjx 
as mjx 11 | from mujoco.mjx._src import scan 12 | 13 | from mjinx.components.tasks._base import JaxTask, Task 14 | from mjinx.configuration import jac_dq2v 15 | from mjinx.typing import ArrayOrFloat 16 | 17 | 18 | @jdc.pytree_dataclass 19 | class JaxComTask(JaxTask): 20 | r""" 21 | A JAX-based implementation of a center of mass (CoM) task for inverse kinematics. 22 | 23 | This class represents a task that aims to achieve a specific target center of mass 24 | for the robot model. 25 | 26 | The task function maps joint positions to the robot's center of mass position: 27 | 28 | .. math:: 29 | 30 | f(q) = \frac{\sum_i m_i p_i(q)}{\sum_i m_i} 31 | 32 | where: 33 | - :math:`m_i` is the mass of body i 34 | - :math:`p_i(q)` is the position of body i's center of mass 35 | 36 | The error is computed as the difference between the current and target CoM positions: 37 | 38 | .. math:: 39 | 40 | e(q) = p_c(q) - p_{c_{target}} 41 | 42 | :param target_com: The target center of mass position to be achieved. 43 | """ 44 | 45 | target_com: jnp.ndarray 46 | 47 | @final 48 | def __call__(self, data: mjx.Data) -> jnp.ndarray: 49 | r""" 50 | Compute the error between the current center of mass and the target center of mass. 51 | 52 | The error is given by: 53 | 54 | .. math:: 55 | 56 | e(q) = p_c(q) - p_{c_{target}} 57 | 58 | :param data: The MuJoCo simulation data. 59 | :return: The error vector representing the difference between the current and target center of mass. 
60 | """ 61 | error = data.subtree_com[self.model.body_rootid[0], self.mask_idxs] - self.target_com 62 | return error 63 | 64 | @final 65 | def compute_jacobian(self, data: mjx.Data) -> jnp.ndarray: # Jacobian of the CoM error via autodiff 66 | def specific_update(model: mjx.Model, q: jnp.ndarray) -> mjx.Data: # kinematics + subtree CoMs at q 67 | data = mjx.kinematics(model, mjx.make_data(model).replace(qpos=q)) 68 | 69 | # calculate center of mass of each subtree 70 | def subtree_sum(carry, xipos, body_mass): 71 | pos, mass = xipos * body_mass, body_mass 72 | if carry is not None: 73 | subtree_pos, subtree_mass = carry 74 | pos, mass = pos + subtree_pos, mass + subtree_mass 75 | return pos, mass 76 | 77 | pos, mass = scan.body_tree(model, subtree_sum, "bb", "bb", data.xipos, model.body_mass, reverse=True) 78 | cond = jnp.tile(mass < mujoco.mjMINVAL, (3, 1)).T 79 | # take maximum to avoid NaN in gradient of jnp.where 80 | subtree_com = jax.vmap(jnp.divide)(pos, jnp.maximum(mass, mujoco.mjMINVAL)) 81 | subtree_com = jnp.where(cond, data.xipos, subtree_com) 82 | data = data.replace(subtree_com=subtree_com) 83 | 84 | return data 85 | 86 | jac = jax.jacrev( 87 | lambda q, model=self.model: self.__call__( 88 | specific_update(model, q), 89 | ), 90 | argnums=0, 91 | )(data.qpos) 92 | if self.model.nq != self.model.nv: # map the d/dq Jacobian onto the nv-dimensional velocity space 93 | jac = jac @ jac_dq2v(self.model, data.qpos) 94 | return jac 95 | 96 | 97 | class ComTask(Task[JaxComTask]): 98 | """ 99 | A high-level representation of a center of mass (CoM) task for inverse kinematics. 100 | 101 | This class provides an interface for creating and manipulating center of mass tasks, 102 | which aim to achieve a specific target center of mass for the robot model. 103 | 104 | :param name: The name of the task. 105 | :param cost: The cost associated with the task. 106 | :param gain: The gain for the task. 107 | :param gain_fn: A function to compute the gain dynamically. 108 | :param lm_damping: The Levenberg-Marquardt damping factor. 109 | :param mask: A sequence of integers to mask certain dimensions of the task.
110 | """ 111 | 112 | JaxComponentType: type = JaxComTask 113 | _target_com: jnp.ndarray 114 | 115 | def __init__( 116 | self, 117 | name: str, 118 | cost: ArrayOrFloat, 119 | gain: ArrayOrFloat, 120 | gain_fn: Callable[[float], float] | None = None, 121 | lm_damping: float = 0, 122 | mask: Sequence[int] | None = None, 123 | ): 124 | if mask is not None and len(mask) != 3: 125 | raise ValueError("provided mask is too large, expected 1D vector of length 3") 126 | super().__init__(name, cost, gain, gain_fn, lm_damping, mask=mask) 127 | self._dim = 3 if mask is None else len(self.mask_idxs) 128 | self.target_com = jnp.zeros(self._dim) 129 | 130 | @property 131 | def target_com(self) -> jnp.ndarray: 132 | """ 133 | Get the current target center of mass for the task. 134 | 135 | :return: The current target center of mass as a numpy array. 136 | """ 137 | return self._target_com 138 | 139 | @target_com.setter 140 | def target_com(self, value: Sequence): 141 | """ 142 | Set the target center of mass for the task. 143 | 144 | :param value: The new target center of mass as a sequence of values. 145 | """ 146 | self.update_target_com(value) 147 | 148 | def update_target_com(self, target_com: Sequence): 149 | """ 150 | Update the target center of mass for the task. 151 | 152 | This method allows setting the target center of mass using a sequence of values. 153 | 154 | :param target_com: The new target center of mass as a sequence of values. 155 | :raises ValueError: If the provided sequence doesn't have the correct length. 
156 | """ 157 | 158 | target_com_jnp = jnp.array(target_com) 159 | if target_com_jnp.shape[-1] != self._dim: 160 | raise ValueError( 161 | f"invalid last dimension of target CoM : {target_com_jnp.shape[-1]} given, expected {self._dim} " 162 | ) 163 | self._target_com = target_com_jnp 164 | -------------------------------------------------------------------------------- /mjinx/components/tasks/_obj_frame_task.py: -------------------------------------------------------------------------------- 1 | """Frame task implementation.""" 2 | 3 | from collections.abc import Callable, Sequence 4 | from typing import final 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | import jax_dataclasses as jdc 9 | import mujoco as mj 10 | import mujoco.mjx as mjx 11 | from jaxlie import SE3, SO3 12 | 13 | from mjinx.components.tasks._obj_task import JaxObjTask, ObjTask 14 | from mjinx.configuration import get_frame_jacobian_local 15 | from mjinx.typing import ArrayOrFloat, ndarray 16 | 17 | 18 | @jdc.pytree_dataclass 19 | class JaxFrameTask(JaxObjTask): 20 | r""" 21 | A JAX-based implementation of a frame task for inverse kinematics. 22 | 23 | This class represents a task that aims to achieve a specific target frame 24 | for a given object in the robot model. 25 | 26 | The task function maps joint positions to the object's frame (pose) in the world frame: 27 | 28 | .. math:: 29 | 30 | f(q) = T(q) \in SE(3) 31 | 32 | where :math:`T(q)` is the transformation matrix representing the object's pose. 33 | 34 | The error is computed using the logarithmic map in SE(3), which represents the 35 | relative transformation between current and target frames as a twist vector: 36 | 37 | .. math:: 38 | 39 | e(q) = \log(T(q)^{-1} T_{target}) 40 | 41 | This formulation provides a natural way to interpolate between frames and 42 | control both position and orientation simultaneously. 43 | 44 | :param target_frame: The target frame to be achieved. 
45 | """ 46 | 47 | target_frame: SE3 48 | 49 | @final 50 | def __call__(self, data: mjx.Data) -> jnp.ndarray: 51 | r""" 52 | Compute the error between the current frame and the target frame. 53 | 54 | The error is given by the logarithmic map in SE(3): 55 | 56 | .. math:: 57 | 58 | e(q) = \log(T(q)^{-1} T_{target}) 59 | 60 | This creates a 6D twist vector representing the relative transformation 61 | between the current and target frames. 62 | 63 | :param data: The MuJoCo simulation data. 64 | :return: The error vector representing the difference between the current and target frames. 65 | """ 66 | return (self.get_frame(data).inverse() @ self.target_frame).log()[self.mask_idxs,] 67 | 68 | @final 69 | def compute_jacobian(self, data: mjx.Data) -> jnp.ndarray: 70 | r""" 71 | Compute the Jacobian of the frame task. 72 | 73 | The Jacobian relates changes in joint positions to changes in the frame task error. 74 | It is computed as: 75 | 76 | .. math:: 77 | 78 | J = -J_{\log} \cdot J_{frame}^T 79 | 80 | where: 81 | - :math:`J_{\log}` is the Jacobian of the logarithmic map at the relative transformation 82 | - :math:`J_{frame}` is the geometric Jacobian of the object's frame 83 | 84 | :param data: The MuJoCo simulation data. 85 | :return: The Jacobian matrix of the frame task. 86 | """ 87 | T_bt = self.target_frame.inverse() @ self.get_frame(data).inverse() 88 | 89 | def transform_log(tau): 90 | return (T_bt.multiply(SE3.exp(tau))).log() 91 | 92 | frame_jac = get_frame_jacobian_local(self.model, data, self.obj_id, self.obj_type) 93 | jlog = jax.jacobian(transform_log)(jnp.zeros(SE3.tangent_dim)) 94 | return (-jlog @ frame_jac.T)[self.mask_idxs,] 95 | 96 | 97 | class FrameTask(ObjTask[JaxFrameTask]): 98 | """ 99 | A high-level representation of a frame task for inverse kinematics. 100 | 101 | This class provides an interface for creating and manipulating frame tasks, 102 | which aim to achieve a specific target frame for a object (body, geom, or site) in the robot model. 
103 | 104 | :param name: The name of the task. 105 | :param cost: The cost associated with the task. 106 | :param gain: The gain for the task. 107 | :param obj_name: The name of the object to which the task is applied. 108 | :param gain_fn: A function to compute the gain dynamically. 109 | :param lm_damping: The Levenberg-Marquardt damping factor. 110 | :param mask: A sequence of integers to mask certain dimensions of the task. 111 | """ 112 | 113 | JaxComponentType: type = JaxFrameTask 114 | _target_frame: SE3 115 | 116 | def __init__( 117 | self, 118 | name: str, 119 | cost: ArrayOrFloat, 120 | gain: ArrayOrFloat, 121 | obj_name: str, 122 | obj_type: mj.mjtObj = mj.mjtObj.mjOBJ_BODY, 123 | gain_fn: Callable[[float], float] | None = None, 124 | lm_damping: float = 0, 125 | mask: Sequence[int] | None = None, 126 | ): 127 | super().__init__(name, cost, gain, obj_name, obj_type, gain_fn, lm_damping, mask) 128 | self.target_frame = SE3.identity() 129 | self._dim = SE3.tangent_dim if mask is None else len(self.mask_idxs) 130 | 131 | @property 132 | def target_frame(self) -> SE3: 133 | """ 134 | Get the current target frame for the task. 135 | 136 | :return: The current target frame as an SE3 object. 137 | """ 138 | return self._target_frame 139 | 140 | @target_frame.setter 141 | def target_frame(self, value: SE3 | Sequence | ndarray): 142 | """ 143 | Set the target frame for the task. 144 | 145 | :param value: The new target frame, either as an SE3 object or a sequence of values. 146 | """ 147 | self.update_target_frame(value) 148 | 149 | def update_target_frame(self, target_frame: SE3 | Sequence | ndarray): 150 | """ 151 | Update the target frame for the task. 152 | 153 | This method allows setting the target frame using either an SE3 object 154 | or a sequence of values representing the frame. 155 | 156 | :param target_frame: The new target frame, either as an SE3 object or a sequence of values. 
157 | :raises ValueError: If the provided sequence doesn't have the correct length. 158 | """ 159 | if not isinstance(target_frame, SE3): 160 | target_frame_jnp = jnp.array(target_frame) 161 | if target_frame_jnp.shape[-1] != SE3.parameters_dim: 162 | raise ValueError( 163 | "Target frame provided via array must have length 7 (xyz + quaternion with scalar first)" 164 | ) 165 | 166 | xyz, quat = target_frame_jnp[..., :3], target_frame_jnp[..., 3:] 167 | target_frame_se3 = SE3.from_rotation_and_translation( 168 | SO3.from_quaternion_xyzw( 169 | quat[..., [1, 2, 3, 0]], # wxyz (scalar first) -> xyzw 170 | ), 171 | xyz, 172 | ) 173 | else: 174 | target_frame_se3 = target_frame 175 | self._target_frame = target_frame_se3 176 | -------------------------------------------------------------------------------- /mjinx/components/tasks/_obj_position_task.py: -------------------------------------------------------------------------------- 1 | """Position task implementation.""" 2 | 3 | from collections.abc import Callable, Sequence 4 | 5 | import jax.numpy as jnp 6 | import jax_dataclasses as jdc 7 | import mujoco as mj 8 | import mujoco.mjx as mjx 9 | 10 | from mjinx.components.tasks._obj_task import JaxObjTask, ObjTask 11 | from mjinx.typing import ArrayOrFloat 12 | 13 | 14 | @jdc.pytree_dataclass 15 | class JaxPositionTask(JaxObjTask): 16 | """ 17 | A JAX-based implementation of a position task for inverse kinematics. 18 | 19 | This class represents a task that aims to achieve a specific target position 20 | for an object (body, geometry, or site) in the robot model. 21 | 22 | The task function maps joint positions to the object's position in the world frame: 23 | 24 | .. math:: 25 | 26 | f(q) = p(q) 27 | 28 | where :math:`p(q)` is the position of the object in world coordinates. 29 | 30 | The error is computed as the difference between the current and target positions: 31 | 32 | ..
math:: 33 | 34 | e(q) = p(q) - p_{target} 35 | 36 | The Jacobian of this task is the object's position Jacobian, which relates changes in 37 | joint positions to changes in the object's position. 38 | 39 | :param target_pos: The target position to be achieved. 40 | """ 41 | 42 | target_pos: jnp.ndarray 43 | 44 | def __call__(self, data: mjx.Data) -> jnp.ndarray: 45 | """ 46 | Compute the error between the current position and the target position. 47 | 48 | The error is given by: 49 | 50 | .. math:: 51 | 52 | e(q) = p(q) - p_{target} 53 | 54 | :param data: The MuJoCo simulation data. 55 | :return: The error vector representing the difference between the current and target positions. 56 | """ 57 | return self.get_pos(data)[self.mask_idxs,] - self.target_pos 58 | 59 | 60 | class PositionTask(ObjTask[JaxPositionTask]): 61 | """ 62 | A high-level representation of a position task for inverse kinematics. 63 | 64 | This class provides an interface for creating and manipulating position tasks, 65 | which aim to achieve a specific target position for an object in the robot model. 66 | 67 | :param name: The name of the task. 68 | :param cost: The cost associated with the task. 69 | :param gain: The gain for the task. 70 | :param obj_name: The name of the object to which the task is applied. 71 | :param gain_fn: A function to compute the gain dynamically. 72 | :param lm_damping: The Levenberg-Marquardt damping factor. 73 | :param mask: A sequence of integers to mask certain dimensions of the task. 
74 | """ 75 | 76 | JaxComponentType: type = JaxPositionTask 77 | _target_pos: jnp.ndarray 78 | 79 | def __init__( 80 | self, 81 | name: str, 82 | cost: ArrayOrFloat, 83 | gain: ArrayOrFloat, 84 | obj_name: str, 85 | obj_type: mj.mjtObj = mj.mjtObj.mjOBJ_BODY, 86 | gain_fn: Callable[[float], float] | None = None, 87 | lm_damping: float = 0, 88 | mask: Sequence[int] | None = None, 89 | ): 90 | super().__init__(name, cost, gain, obj_name, obj_type, gain_fn, lm_damping, mask) 91 | self._dim = 3 if mask is None else len(self.mask_idxs) 92 | self._target_pos = jnp.zeros(self._dim) 93 | 94 | @property 95 | def target_pos(self) -> jnp.ndarray: 96 | """ 97 | Get the current target position for the task. 98 | 99 | :return: The current target position as a numpy array. 100 | """ 101 | return self._target_pos 102 | 103 | @target_pos.setter 104 | def target_pos(self, value: Sequence): 105 | """ 106 | Set the target position for the task. 107 | 108 | :param value: The new target position as a sequence of values. 109 | """ 110 | self.update_target_pos(value) 111 | 112 | def update_target_pos(self, target_pos: Sequence): 113 | """ 114 | Update the target position for the task. 115 | 116 | This method allows setting the target position using a sequence of values. 117 | 118 | :param target_pos: The new target position as a sequence of values. 119 | :raises ValueError: If the provided sequence doesn't have the correct length. 
120 | """ 121 | target_pos_array = jnp.array(target_pos) 122 | if target_pos_array.shape[-1] != self._dim: 123 | raise ValueError( 124 | f"Invalid dimension of the target position: expected {self._dim}, got {target_pos_array.shape[-1]}" 125 | ) 126 | self._target_pos = target_pos_array 127 | -------------------------------------------------------------------------------- /mjinx/configuration/__init__.py: -------------------------------------------------------------------------------- 1 | from ._collision import compute_collision_pairs, geom_groups, get_distance, sorted_pair 2 | from ._lie import attitude_jacobian, get_joint_zero, jac_dq2v, joint_difference, skew_symmetric 3 | from ._model import ( 4 | geom_point_jacobian, 5 | get_configuration_limit, 6 | get_frame_jacobian_local, 7 | get_frame_jacobian_world_aligned, 8 | get_transform, 9 | get_transform_frame_to_world, 10 | integrate, 11 | update, 12 | ) 13 | -------------------------------------------------------------------------------- /mjinx/solvers/__init__.py: -------------------------------------------------------------------------------- 1 | from ._base import Solver, SolverData 2 | from ._global_ik import GlobalIKData, GlobalIKSolution, GlobalIKSolver 3 | from ._local_ik import LocalIKData, LocalIKSolution, LocalIKSolver 4 | 5 | __all__ = [ 6 | "Solver", 7 | "SolverData", 8 | "GlobalIKData", 9 | "GlobalIKSolution", 10 | "GlobalIKSolver", 11 | "LocalIKData", 12 | "LocalIKSolution", 13 | "LocalIKSolver", 14 | ] 15 | -------------------------------------------------------------------------------- /mjinx/solvers/_base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Generic, TypeVar 3 | 4 | import jax.numpy as jnp 5 | import jax_dataclasses as jdc 6 | import mujoco.mjx as mjx 7 | 8 | import mjinx.typing as mjt 9 | from mjinx import configuration 10 | from mjinx.problem import JaxProblemData 11 | 12 | 13 | @jdc.pytree_dataclass 14 | class 
SolverData: 15 | """Base class for solver-specific data. 16 | 17 | This class serves as a placeholder for any data that a specific solver needs to maintain 18 | between iterations or function calls. It enables solver implementations to preserve 19 | state information and warm-start subsequent optimization steps. 20 | """ 21 | 22 | pass 23 | 24 | 25 | @jdc.pytree_dataclass 26 | class SolverSolution: 27 | """Base class for solver solutions. 28 | 29 | This class provides the structure for returning optimization results. It contains 30 | the optimal velocity solution and can be extended by specific solvers to include 31 | additional solution information. 32 | 33 | :param v_opt: Optimal velocity solution. 34 | """ 35 | 36 | v_opt: jnp.ndarray 37 | 38 | 39 | SolverDataType = TypeVar("SolverDataType", bound=SolverData) 40 | SolverSolutionType = TypeVar("SolverSolutionType", bound=SolverSolution) 41 | 42 | 43 | class Solver(Generic[SolverDataType, SolverSolutionType], abc.ABC): 44 | r"""Abstract base class for solvers. 45 | 46 | This class defines the interface for solvers used in inverse kinematics problems. 47 | Solvers transform task and barrier constraints into optimization problems, which 48 | can be solved to find joint configurations or velocities that satisfy the constraints. 49 | 50 | In general, the optimization problem can be formulated as: 51 | 52 | .. math:: 53 | 54 | \min_{q} \sum_{i} \|e_i(q)\|^2_{W_i} \quad \text{subject to} \quad h_j(q) \geq 0 55 | 56 | where: 57 | - :math:`e_i(q)` are the task errors 58 | - :math:`W_i` are the task weight matrices 59 | - :math:`h_j(q)` are the barrier constraints 60 | 61 | Different solver implementations use different approaches to solve this problem, 62 | such as local linearization (QP) or global nonlinear optimization. 63 | 64 | :param model: The MuJoCo model used by the solver. 65 | """ 66 | 67 | model: mjx.Model 68 | 69 | def __init__(self, model: mjx.Model): 70 | """Initialize the solver with a MuJoCo model. 
71 | 72 | :param model: The MuJoCo model to be used by the solver. 73 | """ 74 | self.model = model 75 | 76 | @abc.abstractmethod 77 | def solve_from_data( 78 | self, solver_data: SolverDataType, problem_data: JaxProblemData, model_data: mjx.Data 79 | ) -> tuple[SolverSolutionType, SolverDataType]: 80 | """Solve the inverse kinematics problem using pre-computed model data. 81 | 82 | :param solver_data: Solver-specific data. 83 | :param problem_data: Problem-specific data. 84 | :param model_data: MuJoCo model data. 85 | :return: A tuple containing the solver solution and updated solver data. 86 | """ 87 | pass 88 | 89 | def solve( 90 | self, q: jnp.ndarray, solver_data: SolverDataType, problem_data: JaxProblemData 91 | ) -> tuple[SolverSolutionType, SolverDataType]: 92 | """Solve the inverse kinematics problem for a given configuration. 93 | 94 | This method creates mjx.Data instance and updates it under the hood. To avoid doing an extra 95 | update, consider solve_from_data method. 96 | 97 | :param q: The current joint configuration. 98 | :param solver_data: Solver-specific data. 99 | :param problem_data: Problem-specific data. 100 | :return: A tuple containing the solver solution and updated solver data. 101 | :raises ValueError: If the input configuration has incorrect dimensions. 102 | """ 103 | if q.shape != (self.model.nq,): 104 | raise ValueError(f"wrong dimension of the state: expected ({self.model.nq}, ), got {q.shape}") 105 | model_data = configuration.update(self.model, q) 106 | return self.solve_from_data(solver_data, problem_data, model_data) 107 | 108 | @abc.abstractmethod 109 | def init(self, q: mjt.ndarray) -> SolverDataType: 110 | """Initialize solver-specific data. 111 | 112 | :param q: The initial joint configuration. 113 | :return: Initialized solver-specific data. 
114 | """ 115 | pass 116 | -------------------------------------------------------------------------------- /mjinx/typing.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains type definitions and aliases used throughout the mjinx library. 3 | """ 4 | 5 | from __future__ import annotations 6 | 7 | from collections.abc import Callable 8 | from enum import Enum 9 | from typing import NamedTuple, TypeAlias 10 | 11 | import jax 12 | import jax.numpy as jnp 13 | import numpy as np 14 | from mujoco.mjx._src.dataclasses import PyTreeNode 15 | 16 | ndarray: TypeAlias = np.ndarray | jnp.ndarray 17 | """Type alias for numpy or JAX numpy arrays.""" 18 | 19 | ArrayOrFloat: TypeAlias = ndarray | float 20 | """Type alias for an array or a float value.""" 21 | 22 | ClassKFunctions: TypeAlias = Callable[[ndarray], ndarray] 23 | """Type alias for Class K functions, which are scalar functions that take and return ndarrays.""" 24 | 25 | CollisionBody: TypeAlias = int | str 26 | """Type alias for collision body representation, either as an integer ID or a string name.""" 27 | 28 | CollisionPair: TypeAlias = tuple[int, int] 29 | """Type alias for a pair of collision body IDs.""" 30 | 31 | ArrayLike: TypeAlias = np.typing.ArrayLike | jax.typing.ArrayLike 32 | """Type alias for an array-like object, either a numpy array or a JAX array-like object.""" 33 | 34 | 35 | class SimplifiedContact(PyTreeNode): 36 | geom: jnp.ndarray 37 | dist: jnp.ndarray 38 | pos: jnp.ndarray 39 | frame: jnp.ndarray 40 | 41 | 42 | class PositionLimitType(Enum): 43 | """Type which describes possible position limits. 44 | 45 | The position limit could be only minimal, only maximal, or minimal and maximal. 46 | """ 47 | 48 | MIN = 0 49 | MAX = 1 50 | BOTH = 2 51 | 52 | @staticmethod 53 | def from_str(type: str) -> PositionLimitType: 54 | """Generates position limit type from string. 55 | 56 | :param type: position limit type. 
57 | :raises ValueError: limit name is not 'min', 'max', or 'both'. 58 | :return: corresponding enum type. 59 | """ 60 | match type.lower(): 61 | case "min": 62 | return PositionLimitType.MIN 63 | case "max": 64 | return PositionLimitType.MAX 65 | case "both": 66 | return PositionLimitType.BOTH 67 | case _: 68 | raise ValueError( 69 | f"[PositionLimitType] invalid position limit type: {type}. Expected {{'min', 'max', 'both'}}" 70 | ) 71 | 72 | @staticmethod 73 | def includes_min(type: PositionLimitType) -> bool: 74 | """Check whether the given limit includes the minimum limit. 75 | 76 | Returns True if the limit is either MIN or BOTH, and False otherwise. 77 | 78 | :param type: limit to be processed. 79 | :return: True if the limit includes the minimum limit, False otherwise. 80 | """ 81 | return type == PositionLimitType.MIN or type == PositionLimitType.BOTH 82 | 83 | @staticmethod 84 | def includes_max(type: PositionLimitType) -> bool: 85 | """Check whether the given limit includes the maximum limit. 86 | 87 | Returns True if the limit is either MAX or BOTH, and False otherwise. 88 | 89 | :param type: limit to be processed. 90 | :return: True if the limit includes the maximum limit, False otherwise.
91 | """ 92 | 93 | return type == PositionLimitType.MAX or type == PositionLimitType.BOTH 94 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | ignore_missing_imports = True 3 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "mjinx" 3 | version = "0.1.1" 4 | description = "Numerical Inverse Kinematics based on JAX + MJX" 5 | authors = [ 6 | { name = "Ivan Domrachev", email = "domrachev10@gmail.com" }, 7 | { name = "Simeon Nedelchev", email = "simkaned@gmail.com" }, 8 | { name = "Lev Kozlov", email = "l.kozlov@kaist.ac.kr" }, 9 | ] 10 | readme = "README.md" 11 | requires-python = ">=3.10" 12 | 13 | dependencies = [ 14 | "mujoco", 15 | "mujoco_mjx", 16 | "jaxopt", 17 | "jax>=0.5", 18 | "jaxlie>=1.4", 19 | "jax_dataclasses>=1.6.0", 20 | "optax>=0.2", 21 | ] 22 | 23 | [project.optional-dependencies] 24 | dev = ["pre-commit", "ruff", "pytest", "robot_descriptions>=1.12"] 25 | docs = [ 26 | "sphinx>=8", 27 | "sphinx-mathjax-offline", 28 | "sphinx-autodoc-typehints", 29 | "sphinx-rtd-theme>1", 30 | "sphinxcontrib-bibtex", 31 | ] 32 | visual = ["dm_control", "mediapy"] 33 | examples = ["mjinx[visual]", "robot_descriptions>=1.12"] 34 | all = ["mjinx[dev, visual, docs]"] 35 | 36 | [build-system] 37 | requires = ["setuptools>=43.0.0", "wheel"] 38 | build-backend = "setuptools.build_meta" 39 | 40 | [tool.setuptools.packages.find] 41 | where = ["."] 42 | include = ["mjinx*"] 43 | 44 | [tool.ruff] 45 | line-length = 120 46 | 47 | [tool.ruff.lint] 48 | select = [ 49 | "E", # pycodestyle errors 50 | "W", # pycodestyle warnings 51 | "F", # pyflakes 52 | "I", # isort 53 | "B", # flake8-bugbear 54 | "C4", # flake8-comprehensions 55 | "UP", # pyupgrade 56 | ] 57 | 58 | [tool.ruff.lint.per-file-ignores] 59 | "__init__.py"
= ["F401"] 60 | 61 | [tool.pytest.ini_options] 62 | filterwarnings = ["ignore:.*U.*mode is deprecated:DeprecationWarning"] 63 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/tests/__init__.py -------------------------------------------------------------------------------- /tests/example_tests/test_global_ik_jit.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import mujoco as mj 4 | import mujoco.mjx as mjx 5 | import numpy as np 6 | from optax import adam 7 | from robot_descriptions.iiwa14_mj_description import MJCF_PATH 8 | 9 | from mjinx.components.barriers import JointBarrier, PositionBarrier 10 | from mjinx.components.tasks import FrameTask 11 | from mjinx.problem import Problem 12 | from mjinx.solvers import GlobalIKSolver 13 | 14 | 15 | def test_global_ik_jit(): 16 | mj_model = mj.MjModel.from_xml_path(MJCF_PATH) 17 | mjx_model = mjx.put_model(mj_model) 18 | 19 | # === Mjinx === 20 | 21 | # --- Constructing the problem --- 22 | # Creating problem formulation 23 | problem = Problem(mjx_model, v_min=-100, v_max=100) 24 | 25 | # Creating components of interest and adding them to the problem 26 | frame_task = FrameTask("ee_task", cost=1, gain=50, obj_name="link7") 27 | position_barrier = PositionBarrier( 28 | "ee_barrier", 29 | gain=0.1, 30 | obj_name="link7", 31 | limit_type="max", 32 | p_max=0.3, 33 | safe_displacement_gain=1e-2, 34 | mask=[1, 0, 0], 35 | ) 36 | joints_barrier = JointBarrier("jnt_range", gain=0.1) 37 | 38 | problem.add_component(frame_task) 39 | problem.add_component(position_barrier) 40 | problem.add_component(joints_barrier) 41 | 42 | # Compiling the problem upon any parameters update 43 | problem_data = problem.compile() 44 | 45 | # Initializing
solver and its initial state 46 | solver = GlobalIKSolver(mjx_model, adam(learning_rate=1e-2), dt=1e-2) 47 | 48 | # Initial condition 49 | q = np.array( 50 | [ 51 | -1.4238753, 52 | -1.7268502, 53 | -0.84355015, 54 | 2.0962472, 55 | 2.1339328, 56 | 2.0837479, 57 | -2.5521986, 58 | ] 59 | ) 60 | solver_data = solver.init(q) 61 | 62 | # Jit-compiling the key functions for better efficiency 63 | solve_jit = jax.jit(solver.solve) 64 | 65 | # === Control loop === 66 | dt = 1e-2 67 | ts = np.arange(0, 1, dt) 68 | 69 | for t in ts: 70 | # Changing desired values 71 | frame_task.target_frame = np.array([0.2 + 0.2 * jnp.sin(t) ** 2, 0.2, 0.2, 1, 0, 0, 0]) 72 | # After changes, recompiling the model 73 | problem_data = problem.compile() 74 | # Solving the instance of the problem 75 | for _ in range(1): 76 | opt_solution, solver_data = solve_jit(q, solver_data, problem_data) 77 | -------------------------------------------------------------------------------- /tests/example_tests/test_global_ik_vmap.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import mujoco as mj 4 | import mujoco.mjx as mjx 5 | import numpy as np 6 | from optax import adam 7 | from robot_descriptions.iiwa14_mj_description import MJCF_PATH 8 | 9 | from mjinx.components.barriers import JointBarrier, PositionBarrier 10 | from mjinx.components.tasks import FrameTask 11 | from mjinx.problem import Problem 12 | from mjinx.solvers import GlobalIKSolver 13 | 14 | 15 | def test_global_ik_jit(): 16 | mj_model = mj.MjModel.from_xml_path(MJCF_PATH) 17 | mjx_model = mjx.put_model(mj_model) 18 | 19 | q_min = mj_model.jnt_range[:, 0].copy() 20 | q_max = mj_model.jnt_range[:, 1].copy() 21 | 22 | # === Mjinx === 23 | 24 | # --- Constructing the problem --- 25 | # Creating problem formulation 26 | problem = Problem(mjx_model, v_min=-100, v_max=100) 27 | 28 | # Creating components of interest and adding them to the problem 29 | frame_task = 
FrameTask("ee_task", cost=1, gain=50, obj_name="link7") 30 | position_barrier = PositionBarrier( 31 | "ee_barrier", 32 | gain=0.1, 33 | obj_name="link7", 34 | limit_type="max", 35 | p_max=0.3, 36 | safe_displacement_gain=1e-2, 37 | mask=[1, 0, 0], 38 | ) 39 | joints_barrier = JointBarrier("jnt_range", gain=0.1) 40 | 41 | problem.add_component(frame_task) 42 | problem.add_component(position_barrier) 43 | problem.add_component(joints_barrier) 44 | 45 | # Compiling the problem upon any parameters update 46 | problem_data = problem.compile() 47 | 48 | # Initializing solver and its initial state 49 | solver = GlobalIKSolver(mjx_model, adam(learning_rate=1e-2), dt=1e-2) 50 | 51 | # Initializing initial condition 52 | N_batch = 1000 53 | np.random.seed(42) 54 | q0 = jnp.array( 55 | [ 56 | -1.4238753, 57 | -1.7268502, 58 | -0.84355015, 59 | 2.0962472, 60 | 2.1339328, 61 | 2.0837479, 62 | -2.5521986, 63 | ] 64 | ) 65 | q = jnp.array( 66 | [ 67 | np.clip( 68 | q0 69 | + np.random.uniform( 70 | -0.5, 71 | 0.5, 72 | size=(mj_model.nq), 73 | ), 74 | q_min + 1e-1, 75 | q_max - 1e-1, 76 | ) 77 | for _ in range(N_batch) 78 | ] 79 | ) 80 | 81 | # --- Batching --- 82 | # First of all, data should be created via vmapped init function 83 | solver_data = jax.vmap(solver.init, in_axes=0)(q) 84 | 85 | # To create a batch w.r.t. desired component's attributes, library defines convinient wrapper 86 | # That sets all elements to None and allows user to mutate dataclasses of interest. 87 | # After exiting the Context Manager, you'll get immutable jax dataclass object. 88 | with problem.set_vmap_dimension() as empty_problem_data: 89 | empty_problem_data.components["ee_task"].target_frame = 0 90 | 91 | # Vmapping solve and integrate functions. 92 | # Note that for batching w.r.t. q both q and solver_data should be batched. 93 | # Other approaches might work, but it would be undefined behaviour, please stick to this format. 
94 | solve_jit = jax.jit(jax.vmap(solver.solve, in_axes=(0, 0, empty_problem_data))) 95 | 96 | # === Control loop === 97 | dt = 1e-2 98 | ts = np.arange(0, 1, dt) 99 | 100 | for t in ts: 101 | # Changing desired values 102 | frame_task.target_frame = np.array( 103 | [[0.2 + 0.2 * np.sin(t + np.pi * i / N_batch) ** 2, 0.2, 0.2, 1, 0, 0, 0] for i in range(N_batch)] 104 | ) 105 | # After changes, recompiling the model 106 | problem_data = problem.compile() 107 | # Solving the instance of the problem 108 | for _ in range(1): 109 | opt_solution, solver_data = solve_jit(q, solver_data, problem_data) 110 | -------------------------------------------------------------------------------- /tests/example_tests/test_local_ik_jit.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import mujoco as mj 4 | import mujoco.mjx as mjx 5 | import numpy as np 6 | from robot_descriptions.iiwa14_mj_description import MJCF_PATH 7 | 8 | from mjinx.components.barriers import JointBarrier, PositionBarrier 9 | from mjinx.components.tasks import FrameTask 10 | from mjinx.configuration import integrate 11 | from mjinx.problem import Problem 12 | from mjinx.solvers import LocalIKSolver 13 | 14 | # === Mujoco === 15 | 16 | 17 | def test_local_ik_jit(): 18 | mj_model = mj.MjModel.from_xml_path(MJCF_PATH) 19 | mjx_model = mjx.put_model(mj_model) 20 | 21 | # === Mjinx === 22 | 23 | # --- Constructing the problem --- 24 | # Creating problem formulation 25 | problem = Problem(mjx_model, v_min=-100, v_max=100) 26 | 27 | # Creating components of interest and adding them to the problem 28 | frame_task = FrameTask("ee_task", cost=1, gain=20, obj_name="link7") 29 | position_barrier = PositionBarrier( 30 | "ee_barrier", 31 | gain=100, 32 | obj_name="link7", 33 | limit_type="max", 34 | p_max=0.3, 35 | safe_displacement_gain=1e-2, 36 | mask=[1, 0, 0], 37 | ) 38 | joints_barrier = JointBarrier("jnt_range", gain=10) 39 | 40 | 
problem.add_component(frame_task) 41 | problem.add_component(position_barrier) 42 | problem.add_component(joints_barrier) 43 | 44 | # Compiling the problem upon any parameters update 45 | problem_data = problem.compile() 46 | 47 | # Initializing solver and its initial state 48 | solver = LocalIKSolver(mjx_model, maxiter=20) 49 | 50 | # Initial condition 51 | q = jnp.array( 52 | [ 53 | -1.4238753, 54 | -1.7268502, 55 | -0.84355015, 56 | 2.0962472, 57 | 2.1339328, 58 | 2.0837479, 59 | -2.5521986, 60 | ] 61 | ) 62 | solver_data = solver.init() 63 | 64 | # Jit-compiling the key functions for better efficiency 65 | solve_jit = jax.jit(solver.solve) 66 | integrate_jit = jax.jit(integrate, static_argnames=["dt"]) 67 | 68 | # === Control loop === 69 | dt = 1e-2 70 | ts = np.arange(0, 1, dt) 71 | 72 | for t in ts: 73 | # Changing desired values 74 | frame_task.target_frame = np.array([0.2 + 0.2 * jnp.sin(t) ** 2, 0.2, 0.2, 1, 0, 0, 0]) 75 | # After changes, recompiling the model 76 | problem_data = problem.compile() 77 | 78 | # Solving the instance of the problem 79 | opt_solution, solver_data = solve_jit(q, solver_data, problem_data) 80 | 81 | # Integrating 82 | q = integrate_jit( 83 | mjx_model, 84 | q, 85 | velocity=opt_solution.v_opt, 86 | dt=dt, 87 | ) 88 | -------------------------------------------------------------------------------- /tests/example_tests/test_local_ik_vmap.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import mujoco as mj 4 | import mujoco.mjx as mjx 5 | import numpy as np 6 | from robot_descriptions.iiwa14_mj_description import MJCF_PATH 7 | 8 | from mjinx.components.barriers import JointBarrier, PositionBarrier 9 | from mjinx.components.tasks import FrameTask 10 | from mjinx.configuration import integrate 11 | from mjinx.problem import Problem 12 | from mjinx.solvers import LocalIKSolver 13 | 14 | 15 | # === Mujoco === 16 | def test_local_ik_vmap(): 17 | mj_model 
= mj.MjModel.from_xml_path(MJCF_PATH) 18 | 19 | mjx_model = mjx.put_model(mj_model) 20 | 21 | # === Mjinx === 22 | 23 | # --- Constructing the problem --- 24 | # Creating problem formulation 25 | problem = Problem(mjx_model, v_min=-100, v_max=100) 26 | 27 | # Creating components of interest and adding them to the problem 28 | frame_task = FrameTask("ee_task", cost=1, gain=20, obj_name="link7") 29 | position_barrier = PositionBarrier( 30 | "ee_barrier", 31 | gain=100, 32 | obj_name="link7", 33 | limit_type="max", 34 | p_max=0.3, 35 | safe_displacement_gain=1e-2, 36 | mask=[1, 0, 0], 37 | ) 38 | joints_barrier = JointBarrier("jnt_range", gain=10) 39 | 40 | problem.add_component(frame_task) 41 | problem.add_component(position_barrier) 42 | problem.add_component(joints_barrier) 43 | 44 | # Compiling the problem upon any parameters update 45 | problem_data = problem.compile() 46 | 47 | # Initializing solver and its initial state 48 | solver = LocalIKSolver(mjx_model, maxiter=20) 49 | 50 | # Initializing initial condition 51 | N_batch = 10000 52 | q0 = np.array( 53 | [ 54 | -1.5878328, 55 | -2.0968683, 56 | -1.4339591, 57 | 1.6550868, 58 | 2.1080072, 59 | 1.646142, 60 | -2.982619, 61 | ] 62 | ) 63 | q = jnp.array([q0.copy() for _ in range(N_batch)]) 64 | 65 | # --- Batching --- 66 | # First of all, data should be created via vmapped init function 67 | solver_data = jax.vmap(solver.init, in_axes=0)(v_init=jnp.zeros((N_batch, mjx_model.nv))) 68 | 69 | # To create a batch w.r.t. desired component's attributes, library defines convinient wrapper 70 | # That sets all elements to None and allows user to mutate dataclasses of interest. 71 | # After exiting the Context Manager, you'll get immutable jax dataclass object. 72 | with problem.set_vmap_dimension() as empty_problem_data: 73 | empty_problem_data.components["ee_task"].target_frame = 0 74 | 75 | # Vmapping solve and integrate functions. 76 | # Note that for batching w.r.t. q both q and solver_data should be batched. 
77 | # Other approaches might work, but it would be undefined behaviour, please stick to this format. 78 | solve_jit = jax.jit( 79 | jax.vmap( 80 | solver.solve, 81 | in_axes=(0, 0, empty_problem_data), 82 | ) 83 | ) 84 | integrate_jit = jax.jit(jax.vmap(integrate, in_axes=(None, 0, 0, None)), static_argnames=["dt"]) 85 | 86 | # === Control loop === 87 | dt = 1e-2 88 | ts = np.arange(0, 1, dt) 89 | 90 | for t in ts: 91 | # Changing desired values 92 | frame_task.target_frame = np.array( 93 | [[0.2 + 0.2 * np.sin(t + np.pi * i / N_batch) ** 2, 0.2, 0.2, 1, 0, 0, 0] for i in range(N_batch)] 94 | ) 95 | problem_data = problem.compile() 96 | 97 | # Solving the instance of the problem 98 | opt_solution, solver_data = solve_jit(q, solver_data, problem_data) 99 | 100 | # Integrating 101 | q = integrate_jit( 102 | mjx_model, 103 | q, 104 | opt_solution.v_opt, 105 | dt, 106 | ) 107 | -------------------------------------------------------------------------------- /tests/unit_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/tests/unit_tests/__init__.py -------------------------------------------------------------------------------- /tests/unit_tests/components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/tests/unit_tests/components/__init__.py -------------------------------------------------------------------------------- /tests/unit_tests/components/barriers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/tests/unit_tests/components/barriers/__init__.py -------------------------------------------------------------------------------- 
/tests/unit_tests/components/barriers/test_base_barrier.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import jax.numpy as jnp 4 | import mujoco.mjx as mjx 5 | 6 | from mjinx.components.barriers import Barrier, JaxBarrier 7 | 8 | 9 | class DummyJaxBarrier(JaxBarrier): 10 | """Dummy class to make minimal non-abstract jax barrier class""" 11 | 12 | def __call__(self, data: mjx.Data) -> jnp.ndarray: 13 | return jnp.ones(self.model.nv)[self.mask_idxs,] 14 | 15 | 16 | class DummyBarrier(Barrier[DummyJaxBarrier]): 17 | """Dummy class to make minimal non-abstract barrier class""" 18 | 19 | JaxComponentType: type = DummyJaxBarrier 20 | 21 | 22 | class TestBarrier(unittest.TestCase): 23 | def test_safe_displacement_gain(self): 24 | """Tests proper definition of safe displacement gain""" 25 | self.component = DummyBarrier("test_component", gain=2.0) 26 | self.component.safe_displacement_gain = 5.0 27 | self.assertEqual(self.component.safe_displacement_gain, 5.0) 28 | -------------------------------------------------------------------------------- /tests/unit_tests/components/barriers/test_obj_position_barrier.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import jax.numpy as jnp 4 | import mujoco as mj 5 | import mujoco.mjx as mjx 6 | import numpy as np 7 | 8 | from mjinx.components.barriers._obj_position_barrier import JaxPositionBarrier, PositionBarrier 9 | from mjinx.typing import PositionLimitType 10 | 11 | 12 | class TestPositionBarrier(unittest.TestCase): 13 | def setUp(self): 14 | self.model = mjx.put_model( 15 | mj.MjModel.from_xml_string( 16 | """ 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | """ 26 | ) 27 | ) 28 | self.data = mjx.make_data(self.model) 29 | 30 | def test_initialization(self): 31 | """Testing component initialization""" 32 | barrier = PositionBarrier( 33 | name="test_barrier", 34 | gain=1.0, 35 | obj_name="body1", 36 | 
p_min=[-1.0, -1.0, -1.0], 37 | p_max=[1.0, 1.0, 1.0], 38 | ) 39 | barrier.update_model(self.model) 40 | 41 | self.assertEqual(barrier.name, "test_barrier") 42 | self.assertEqual(barrier.obj_name, "body1") 43 | self.assertEqual(barrier.obj_id, 1) 44 | np.testing.assert_array_equal(barrier.p_min, jnp.array([-1.0, -1.0, -1.0])) 45 | np.testing.assert_array_equal(barrier.p_max, jnp.array([1.0, 1.0, 1.0])) 46 | self.assertEqual(barrier.limit_type, PositionLimitType.BOTH) 47 | 48 | def test_update_p_min(self): 49 | """Testing minimal position updates""" 50 | barrier = PositionBarrier( 51 | name="test_barrier", 52 | gain=1.0, 53 | obj_name="body1", 54 | p_min=[-1.0, -1.0, -1.0], 55 | limit_type="min", 56 | ) 57 | barrier.update_p_min(0.1) 58 | np.testing.assert_array_equal(barrier.p_min, 0.1 * jnp.ones(3)) 59 | 60 | barrier.update_p_min([-2.0, -2.0, -2.0]) 61 | np.testing.assert_array_equal(barrier.p_min, jnp.array([-2.0, -2.0, -2.0])) 62 | 63 | with self.assertRaises(ValueError): 64 | barrier.update_p_min([-2.0, -2.0]) 65 | 66 | with self.assertWarns(Warning): 67 | barrier.p_max = [1.0, 1.0, 1.0] 68 | 69 | def test_update_p_max(self): 70 | """Testing maximal position updates""" 71 | barrier = PositionBarrier( 72 | name="test_barrier", 73 | gain=1.0, 74 | obj_name="body1", 75 | p_max=[1.0, 1.0, 1.0], 76 | limit_type="max", 77 | ) 78 | barrier.update_p_max(0.1) 79 | np.testing.assert_array_equal(barrier.p_max, 0.1 * jnp.ones(3)) 80 | 81 | barrier.update_p_max([2.0, 2.0, 2.0]) 82 | np.testing.assert_array_equal(barrier.p_max, jnp.array([2.0, 2.0, 2.0])) 83 | 84 | with self.assertRaises(ValueError): 85 | barrier.update_p_max([2.0, 2.0]) 86 | 87 | with self.assertWarns(Warning): 88 | barrier.p_min = [1.0, 1.0, 1.0] 89 | 90 | def test_limit_type_and_dimension(self): 91 | """Test limit type detection and dimension of the barrier""" 92 | barrier_min = PositionBarrier( 93 | name="test_barrier_min", 94 | gain=1.0, 95 | obj_name="body1", 96 | p_min=[-1.0, -1.0, -1.0], 97 | 
limit_type="min", 98 | ) 99 | self.assertEqual(barrier_min.dim, 3) 100 | self.assertEqual(barrier_min.limit_type, PositionLimitType.MIN) 101 | 102 | barrier_max = PositionBarrier( 103 | name="test_barrier_max", 104 | gain=1.0, 105 | obj_name="body1", 106 | p_max=[1.0, 1.0, 1.0], 107 | limit_type="max", 108 | ) 109 | self.assertEqual(barrier_max.dim, 3) 110 | self.assertEqual(barrier_max.limit_type, PositionLimitType.MAX) 111 | 112 | barrier_both = PositionBarrier( 113 | name="test_barrier_max", 114 | gain=1.0, 115 | obj_name="body1", 116 | p_min=[-1.0, -1.0, -1.0], 117 | p_max=[1.0, 1.0, 1.0], 118 | limit_type="both", 119 | ) 120 | self.assertEqual(barrier_both.dim, 6) 121 | self.assertEqual(barrier_both.limit_type, PositionLimitType.BOTH) 122 | 123 | with self.assertRaises(ValueError): 124 | PositionBarrier( 125 | name="test_barrier_invalid", 126 | gain=1.0, 127 | obj_name="body1", 128 | limit_type="invalid", 129 | ) 130 | 131 | def test_jax_component(self): 132 | """Test generating jac component""" 133 | barrier = PositionBarrier( 134 | name="test_barrier", 135 | gain=1.0, 136 | obj_name="body1", 137 | p_min=[-1.0, -1.0, -1.0], 138 | p_max=[1.0, 1.0, 1.0], 139 | ) 140 | barrier.update_model(self.model) 141 | 142 | jax_component = barrier.jax_component 143 | 144 | self.assertIsInstance(jax_component, JaxPositionBarrier) 145 | self.assertEqual(jax_component.dim, 6) 146 | np.testing.assert_array_equal(jax_component.vector_gain, jnp.ones(6)) 147 | self.assertEqual(jax_component.obj_id, 1) 148 | np.testing.assert_array_equal(jax_component.p_min, jnp.array([-1.0, -1.0, -1.0])) 149 | np.testing.assert_array_equal(jax_component.p_max, jnp.array([1.0, 1.0, 1.0])) 150 | 151 | def test_jax_call(self): 152 | """Test call of the jax barrier function""" 153 | barrier = JaxPositionBarrier( 154 | dim=3, 155 | model=self.model, 156 | vector_gain=jnp.ones(3), 157 | gain_fn=lambda x: x, 158 | mask_idxs=(0, 1, 2), 159 | safe_displacement_gain=0.0, 160 | obj_id=1, 161 | 
obj_type=mj.mjtObj.mjOBJ_BODY, 162 | p_min=jnp.array([-1.0, -1.0, -1.0]), 163 | p_max=jnp.array([1.0, 1.0, 1.0]), 164 | limit_type_mask_idxs=tuple(i for i in range(6)), 165 | ) 166 | 167 | self.data = self.data.replace(qpos=jnp.array([0.0, 0.5, -0.5, 1.0, 0.0, 0.0, 0.0])) 168 | self.data = mjx.kinematics(self.model, self.data) 169 | result = barrier(self.data) 170 | result_compute_barrier = barrier.compute_barrier(self.data) 171 | np.testing.assert_array_equal(result, result_compute_barrier) 172 | np.testing.assert_array_almost_equal(result, jnp.array([1.0, 1.5, 0.5, 1.0, 0.5, 1.5])) 173 | -------------------------------------------------------------------------------- /tests/unit_tests/components/barriers/test_self_collision_barrier.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import jax.numpy as jnp 4 | import mujoco as mj 5 | import mujoco.mjx as mjx 6 | import numpy as np 7 | 8 | from mjinx.components.barriers._self_collision_barrier import JaxSelfCollisionBarrier, SelfCollisionBarrier 9 | from mjinx.configuration import get_distance, sorted_pair 10 | 11 | 12 | class TestSelfCollisionBarrier(unittest.TestCase): 13 | def setUp(self): 14 | self.model = mjx.put_model( 15 | mj.MjModel.from_xml_string( 16 | """ 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | """ 38 | ) 39 | ) 40 | self.data = mjx.make_data(self.model) 41 | 42 | def test_initialization(self): 43 | barrier = SelfCollisionBarrier( 44 | name="test_barrier", 45 | gain=1.0, 46 | d_min=0.1, 47 | collision_bodies=["body1", "body2"], 48 | excluded_collisions=[("body1", "body2")], 49 | ) 50 | with self.assertRaises(ValueError): 51 | _ = barrier.d_min_vec 52 | 53 | barrier.update_model(self.model) 54 | 55 | self.assertEqual(barrier.name, "test_barrier") 56 | self.assertEqual(barrier.d_min, 0.1) 57 | self.assertEqual(barrier.collision_bodies, ["body1", "body2"]) 58 | 
self.assertEqual(barrier.exclude_collisions, {sorted_pair(1, 2)}) 59 | 60 | def test_generate_collision_pairs(self): 61 | """Test straight-away collision pair generation""" 62 | barrier = SelfCollisionBarrier( 63 | name="test_barrier", 64 | gain=1.0, 65 | d_min=0.1, 66 | ) 67 | barrier.update_model(self.model) 68 | 69 | expected_pairs = [(0, 1), (0, 2)] 70 | self.assertEqual(barrier.collision_pairs, expected_pairs) 71 | 72 | def test_generate_collision_pairs_with_exclusion(self): 73 | """Test exclusion of collision pairs""" 74 | barrier = SelfCollisionBarrier( 75 | name="test_barrier", 76 | gain=1.0, 77 | d_min=0.1, 78 | excluded_collisions=[("body1", "body2")], 79 | ) 80 | barrier.update_model(self.model) 81 | 82 | expected_pairs = [(0, 2)] 83 | self.assertEqual(barrier.collision_pairs, expected_pairs) 84 | 85 | def test_body2id(self): 86 | barrier = SelfCollisionBarrier( 87 | name="test_barrier", 88 | gain=1.0, 89 | d_min=0.1, 90 | ) 91 | barrier.update_model(self.model) 92 | 93 | self.assertEqual(barrier.body2id("body1"), 1) 94 | self.assertEqual(barrier.body2id(1), 1) 95 | 96 | with self.assertRaises(ValueError): 97 | barrier.body2id(1.5) 98 | 99 | def test_validate_body_pair(self): 100 | """Test body pairs validation""" 101 | barrier = SelfCollisionBarrier( 102 | name="test_barrier", 103 | gain=1.0, 104 | d_min=0.1, 105 | ) 106 | barrier.update_model(self.model) 107 | 108 | # Valid body pairs 109 | self.assertTrue(barrier.validate_body_pair(1, 2)) 110 | self.assertTrue(barrier.validate_body_pair(1, 3)) 111 | self.assertTrue(barrier.validate_body_pair(2, 4)) 112 | 113 | # Invalid pairs: consequitive pairs 114 | self.assertFalse(barrier.validate_body_pair(2, 3)) 115 | self.assertFalse(barrier.validate_body_pair(3, 4)) 116 | 117 | def test_validate_geom_pair(self): 118 | """Test geometry pairs validation""" 119 | barrier = SelfCollisionBarrier( 120 | name="test_barrier", 121 | gain=1.0, 122 | d_min=0.1, 123 | ) 124 | barrier.update_model(self.model) 125 | 126 | # 
Valid geom pairs 127 | self.assertTrue(barrier.validate_geom_pair(0, 1)) 128 | self.assertTrue(barrier.validate_geom_pair(0, 2)) 129 | self.assertTrue(barrier.validate_geom_pair(1, 2)) 130 | 131 | for i in range(self.model.ngeom - 1): 132 | self.assertFalse(barrier.validate_geom_pair(3, i)) # Invalid geom pairs: condyn-affinity contradict. 133 | 134 | def test_d_min_vec(self): 135 | barrier = SelfCollisionBarrier( 136 | name="test_barrier", 137 | gain=1.0, 138 | d_min=0.1, 139 | ) 140 | barrier.update_model(self.model) 141 | 142 | np.testing.assert_array_equal(barrier.d_min_vec, jnp.ones(barrier.dim) * 0.1) 143 | 144 | def test_jax_component(self): 145 | barrier = SelfCollisionBarrier( 146 | name="test_barrier", 147 | gain=1.0, 148 | d_min=0.1, 149 | ) 150 | barrier.update_model(self.model) 151 | 152 | jax_component = barrier.jax_component 153 | 154 | self.assertIsInstance(jax_component, JaxSelfCollisionBarrier) 155 | self.assertEqual(jax_component.dim, 2) 156 | np.testing.assert_array_equal(jax_component.vector_gain, jnp.ones(jax_component.dim)) 157 | np.testing.assert_array_equal(jax_component.d_min_vec, jnp.ones(jax_component.dim) * 0.1) 158 | self.assertEqual(len(jax_component.collision_pairs), jax_component.dim) 159 | 160 | def test_call(self): 161 | """Test jax component actual computation""" 162 | collision_pairs = [(0, 1)] 163 | barrier = JaxSelfCollisionBarrier( 164 | dim=1, 165 | model=self.model, 166 | vector_gain=jnp.ones(1), 167 | gain_fn=lambda x: x, 168 | mask_idxs=(0,), 169 | safe_displacement_gain=0.0, 170 | d_min_vec=jnp.array([0.1]), 171 | collision_pairs=collision_pairs, 172 | n_closest_pairs=len(collision_pairs), 173 | ) 174 | 175 | result = barrier(self.data) 176 | expected = get_distance(self.model, self.data, collision_pairs)[0] - barrier.d_min_vec 177 | np.testing.assert_array_almost_equal(result, expected) 178 | -------------------------------------------------------------------------------- 
/tests/unit_tests/components/constraints/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/tests/unit_tests/components/constraints/__init__.py -------------------------------------------------------------------------------- /tests/unit_tests/components/constraints/test_base_constraint.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | from mjinx.components.constraints import Constraint, JaxConstraint 4 | 5 | # Language: python 6 | import jax.numpy as jnp 7 | 8 | 9 | class DummyConstraint(Constraint): 10 | dim = 3 11 | pass 12 | 13 | 14 | class TestJaxConstraint(unittest.TestCase): 15 | def setUp(self): 16 | # Use a fixed dimension for tests. 17 | self.dim = 3 18 | 19 | def create_constraint(self, soft_cost): 20 | # Create a dummy constraint instance and set the required dim attribute. 21 | instance = DummyConstraint(name="dummy", gain=1.0, soft_constraint_cost=soft_cost) 22 | return instance 23 | 24 | def test_soft_constraint_cost_none(self): 25 | # Test when soft_constraint_cost is None, property returns 1e2 * identity matrix. 26 | constraint = self.create_constraint(None) 27 | expected = 1e2 * jnp.eye(self.dim) 28 | self.assertTrue(jnp.allclose(constraint.soft_constraint_cost, expected)) 29 | 30 | def test_soft_constraint_cost_scalar(self): 31 | # Test when soft_constraint_cost is a scalar. 32 | scalar_value = 10.0 33 | constraint = self.create_constraint(scalar_value) 34 | expected = jnp.eye(self.dim) * scalar_value 35 | self.assertTrue(jnp.allclose(constraint.soft_constraint_cost, expected)) 36 | 37 | def test_soft_constraint_cost_vector_correct(self): 38 | # Test when soft_constraint_cost is a vector with proper length. 
39 | vector_cost = jnp.array([1.0, 2.0, 3.0]) 40 | constraint = self.create_constraint(vector_cost) 41 | expected = jnp.diag(vector_cost) 42 | self.assertTrue(jnp.allclose(constraint.soft_constraint_cost, expected)) 43 | 44 | def test_soft_constraint_cost_vector_invalid(self): 45 | # Test when soft_constraint_cost is a vector with incorrect length. 46 | vector_cost = jnp.array([1.0, 2.0]) # length != self.dim 47 | constraint = self.create_constraint(vector_cost) 48 | with self.assertRaises(ValueError): 49 | _ = constraint.soft_constraint_cost 50 | 51 | def test_soft_constraint_cost_matrix_correct(self): 52 | # Test when soft_constraint_cost is a matrix of proper shape. 53 | matrix_cost = jnp.array([[1.0, 0, 0], [0, 2.0, 0], [0, 0, 3.0]]) 54 | constraint = self.create_constraint(matrix_cost) 55 | self.assertTrue(jnp.allclose(constraint.soft_constraint_cost, matrix_cost)) 56 | 57 | def test_soft_constraint_cost_matrix_invalid(self): 58 | # Test when soft_constraint_cost is a matrix with an incorrect shape. 59 | matrix_cost = jnp.array([[1.0, 0], [0, 2.0]]) 60 | constraint = self.create_constraint(matrix_cost) 61 | with self.assertRaises(ValueError): 62 | _ = constraint.soft_constraint_cost 63 | 64 | def test_soft_constraint_cost_ndim_invalid(self): 65 | # Test when soft_constraint_cost has ndim > 2. 
66 | invalid_cost = jnp.ones((self.dim, self.dim, 1)) 67 | constraint = self.create_constraint(invalid_cost) 68 | with self.assertRaises(ValueError): 69 | _ = constraint.soft_constraint_cost 70 | 71 | 72 | if __name__ == "__main__": 73 | unittest.main() 74 | -------------------------------------------------------------------------------- /tests/unit_tests/components/tasks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/tests/unit_tests/components/tasks/__init__.py -------------------------------------------------------------------------------- /tests/unit_tests/components/tasks/test_base_task.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import jax.numpy as jnp 4 | import mujoco as mj 5 | import mujoco.mjx as mjx 6 | import numpy as np 7 | 8 | from mjinx.components.tasks import JaxTask, Task 9 | from mjinx.typing import ArrayOrFloat 10 | 11 | 12 | class DummyJaxTask(JaxTask): 13 | def __call__(self, data: mjx.Data) -> jnp.ndarray: 14 | """Immitating CoM task""" 15 | return data.subtree_com[self.model.body_rootid[0], self.mask_idxs] 16 | 17 | 18 | class DummyTask(Task[DummyJaxTask]): 19 | JaxComponentType: type = DummyJaxTask 20 | 21 | def define_dim(self, dim: int): 22 | self._dim = dim 23 | 24 | 25 | class TestTask(unittest.TestCase): 26 | dummy_dim: int = 3 27 | 28 | def setUp(self): 29 | self.model = mjx.put_model( 30 | mj.MjModel.from_xml_string( 31 | """ 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | """ 42 | ) 43 | ) 44 | self.task = DummyTask("test_task", cost=2.0, gain=1.0) 45 | self.task.update_model(self.model) 46 | 47 | def set_dim(self): 48 | self.task.define_dim(self.dummy_dim) 49 | 50 | def __set_cost_and_check(self, cost: ArrayOrFloat): 51 | """Checking that cost is transformed into jnp.ndarray, and values are correct""" 52 | 
self.task.update_cost(cost) 53 | np.testing.assert_array_equal(self.task.cost, cost) 54 | 55 | def test_update_gain(self): 56 | """Testing setting up different cost dimensions""" 57 | # Test scalar cost assignment 58 | self.__set_cost_and_check(1.0) 59 | self.__set_cost_and_check(np.array(2.0)) 60 | self.__set_cost_and_check(jnp.array(3.0)) 61 | 62 | # Test vector cost assignment 63 | self.__set_cost_and_check(np.zeros(self.dummy_dim)) 64 | self.__set_cost_and_check(jnp.ones(self.dummy_dim)) 65 | 66 | # Test matrix cost assignment 67 | self.__set_cost_and_check(np.eye(self.dummy_dim)) 68 | self.__set_cost_and_check(2 * jnp.eye(self.dummy_dim)) 69 | 70 | # Test assigning cost with other dimension 71 | cost = np.zeros((self.dummy_dim, self.dummy_dim, self.dummy_dim)) 72 | with self.assertRaises(ValueError): 73 | self.task.update_cost(cost) 74 | 75 | def test_matrix_cost(self): 76 | """Testing proper convertation of cost to matrix form""" 77 | # Matrix could not be constucted till dimension is specified 78 | self.task.cost = 1.0 79 | with self.assertRaises(ValueError): 80 | _ = self.task.matrix_cost 81 | 82 | self.set_dim() 83 | 84 | # Scalar -> jnp.eye(dim) * scalar 85 | self.task.update_cost(3.0) 86 | np.testing.assert_array_equal(self.task.matrix_cost, jnp.eye(self.task.dim) * 3) 87 | 88 | # Vector -> jnp.diag(vector) 89 | vector_cost = jnp.arange(self.task.dim) 90 | self.task.update_cost(vector_cost) 91 | np.testing.assert_array_equal(self.task.matrix_cost, jnp.diag(vector_cost)) 92 | 93 | # Matrix -> matrix 94 | matrix_cost = jnp.eye(self.task.dim) 95 | self.task.update_cost(matrix_cost) 96 | np.testing.assert_array_equal(self.task.matrix_cost, matrix_cost) 97 | 98 | # For vector cost, the length has to be equal to dimension of the task 99 | vector_cost = jnp.ones(self.task.dim + 1) 100 | self.task.update_cost(vector_cost) 101 | with self.assertRaises(ValueError): 102 | _ = self.task.matrix_cost 103 | 104 | # For matrix cost, if the dimension turned out to be 
wrong, ValueError should be raised 105 | # Note that error would be raised only when matrix_cost is accessed, even if model is already 106 | # provided 107 | matrix_cost = jnp.eye(self.task.dim + 1) 108 | self.task.update_cost(matrix_cost) 109 | with self.assertRaises(ValueError): 110 | _ = self.task.matrix_cost 111 | 112 | def test_lm_damping(self): 113 | lm_specific_task = DummyTask("lm_specific_task", 1.0, 1.0, lm_damping=5.0) 114 | self.assertEqual(lm_specific_task.lm_damping, 5.0) 115 | 116 | with self.assertRaises(ValueError): 117 | _ = DummyTask("negative_lm_task", 1.0, 1.0, lm_damping=-5.0) 118 | 119 | def test_error(self): 120 | self.set_dim() 121 | jax_task = self.task.jax_component 122 | data = mjx.fwd_position(self.model, mjx.make_data(self.model)) 123 | np.testing.assert_array_almost_equal(jax_task(data), jax_task.compute_error(data)) 124 | -------------------------------------------------------------------------------- /tests/unit_tests/components/tasks/test_com_task.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import jax.numpy as jnp 4 | import mujoco as mj 5 | import mujoco.mjx as mjx 6 | import numpy as np 7 | 8 | from mjinx.components.tasks import ComTask 9 | 10 | 11 | class TestComTask(unittest.TestCase): 12 | def set_model(self, task: ComTask): 13 | self.model = mjx.put_model( 14 | mj.MjModel.from_xml_string( 15 | """ 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | w 24 | 25 | """ 26 | ) 27 | ) 28 | task.update_model(self.model) 29 | 30 | def test_task_dim(self): 31 | """Test task dimensionality""" 32 | com_task_masked = ComTask("com_task_masked", cost=1.0, gain=1.0, mask=[True, False, True]) 33 | self.assertEqual(com_task_masked.dim, 2) 34 | 35 | def test_mask_validation(self): 36 | """Test task's mask size validation""" 37 | with self.assertRaises(ValueError): 38 | _ = ComTask("com_task_small_mask", cost=1.0, gain=1.0, mask=[True, False]) 39 | 40 | with self.assertRaises(ValueError): 
41 | _ = ComTask("com_task_big_mask", cost=1.0, gain=1.0, mask=[True, False, True, False]) 42 | 43 | def test_update_target_com(self): 44 | """Test target CoM parameter updates""" 45 | com_task = ComTask("com_task_default", cost=1.0, gain=1.0) 46 | new_target = [1.0, 2.0, 3.0] 47 | com_task.update_target_com(new_target) 48 | np.testing.assert_array_equal(com_task.target_com, new_target) 49 | 50 | too_big_target = range(4) 51 | with self.assertRaises(ValueError): 52 | com_task.update_target_com(too_big_target) 53 | 54 | def test_build_component(self): 55 | com_task = ComTask( 56 | "com_task", cost=1.0, gain=2.0, gain_fn=lambda x: 2 * x, lm_damping=0.5, mask=[True, True, False] 57 | ) 58 | self.set_model(com_task) 59 | com_des = jnp.array((-0.3, 0.3)) 60 | com_task.target_com = com_des 61 | 62 | jax_component = com_task.jax_component 63 | self.assertEqual(jax_component.dim, 2) 64 | np.testing.assert_array_equal(jax_component.matrix_cost, jnp.eye(2)) 65 | np.testing.assert_array_equal(jax_component.vector_gain, jnp.ones(2) * 2.0) 66 | self.assertEqual(jax_component.gain_fn(4), 8) 67 | self.assertEqual(jax_component.lm_damping, 0.5) 68 | np.testing.assert_array_equal(jax_component.target_com, com_des) 69 | self.assertEqual(jax_component.mask_idxs, (0, 1)) 70 | 71 | data = mjx.fwd_position(self.model, mjx.make_data(self.model)) 72 | com_value = jax_component(data) 73 | np.testing.assert_array_equal(com_value, jnp.array([0.6, 0.0])) 74 | -------------------------------------------------------------------------------- /tests/unit_tests/components/tasks/test_obj_frame_task.py: -------------------------------------------------------------------------------- 1 | """Center of mass task implementation.""" 2 | 3 | import unittest 4 | 5 | import jax.numpy as jnp 6 | import mujoco as mj 7 | import mujoco.mjx as mjx 8 | import numpy as np 9 | from jaxlie import SE3 10 | 11 | from mjinx.components.tasks import FrameTask 12 | 13 | 14 | class TestObjFrameTask(unittest.TestCase): 15 | 
def setUp(self) -> None: 16 | self.to_wxyz_xyz: jnp.ndarray = jnp.array([3, 4, 5, 6, 0, 1, 2]) 17 | 18 | def set_model(self, task: FrameTask): 19 | self.model = mjx.put_model( 20 | mj.MjModel.from_xml_string( 21 | """ 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | """ 39 | ) 40 | ) 41 | task.update_model(self.model) 42 | 43 | def test_target_frame(self): 44 | """Testing manipulations with target frame""" 45 | frame_task = FrameTask("frame_task", cost=1.0, gain=2.0, obj_name="body3") 46 | # By default, it has to be identity 47 | np.testing.assert_array_almost_equal(SE3.identity().wxyz_xyz, frame_task.target_frame.wxyz_xyz) 48 | 49 | # Setting with SE3 object 50 | test_se3 = SE3(jnp.array([0.1, 0.2, -0.1, 0, 1, 0, 0])) 51 | frame_task.target_frame = test_se3 52 | np.testing.assert_array_almost_equal(test_se3.wxyz_xyz, frame_task.target_frame.wxyz_xyz) 53 | 54 | # Setting with sequence 55 | test_se3_seq = (-1, 0, 1, 0, 0, 1, 0) 56 | frame_task.target_frame = test_se3_seq 57 | np.testing.assert_array_almost_equal( 58 | jnp.array(test_se3_seq)[self.to_wxyz_xyz], frame_task.target_frame.wxyz_xyz 59 | ) 60 | 61 | # Setting with sequence of wrong length 62 | with self.assertRaises(ValueError): 63 | frame_task.target_frame = (-1, 0, 1, 0, 0, 1, 0, 1) 64 | 65 | def test_build_component(self): 66 | frame_task = FrameTask( 67 | "frame_task", 68 | cost=1.0, 69 | gain=2.0, 70 | obj_name="body3", 71 | gain_fn=lambda x: 2 * x, 72 | lm_damping=0.5, 73 | mask=[True, False, True, True, False, True], 74 | ) 75 | self.set_model(frame_task) 76 | frame_des = jnp.array([0.1, 0, -0.1, 1, 0, 0, 0]) 77 | frame_task.target_frame = frame_des 78 | 79 | jax_component = frame_task.jax_component 80 | 81 | self.assertEqual(jax_component.dim, 4) 82 | np.testing.assert_array_equal(jax_component.matrix_cost, jnp.eye(jax_component.dim)) 83 | np.testing.assert_array_equal(jax_component.vector_gain, jnp.ones(jax_component.dim) * 2.0) 84 | 
np.testing.assert_array_equal(jax_component.obj_id, frame_task.obj_id) 85 | self.assertEqual(jax_component.gain_fn(4), 8) 86 | self.assertEqual(jax_component.lm_damping, 0.5) 87 | np.testing.assert_array_almost_equal(jax_component.target_frame.wxyz_xyz, frame_des[self.to_wxyz_xyz]) 88 | 89 | self.assertEqual(jax_component.mask_idxs, (0, 2, 3, 5)) 90 | 91 | data = mjx.fwd_position(self.model, mjx.make_data(self.model)) 92 | error = jax_component(data) 93 | np.testing.assert_array_equal(error, jnp.array([0.1, -0.1, 0.0, 0.0])) 94 | 95 | # Testing component jacobian 96 | jac = jax_component.compute_jacobian(mjx.make_data(self.model)) 97 | 98 | self.assertEqual(jac.shape, (jax_component.dim, self.model.nv)) 99 | -------------------------------------------------------------------------------- /tests/unit_tests/components/tasks/test_obj_position_task.py: -------------------------------------------------------------------------------- 1 | """Unit tests for the object position task (PositionTask).""" 2 | 3 | import unittest 4 | 5 | import jax.numpy as jnp 6 | import mujoco as mj 7 | import mujoco.mjx as mjx 8 | import numpy as np 9 | 10 | from mjinx.components.tasks import PositionTask 11 | 12 | 13 | class TestObjPositionTask(unittest.TestCase): 14 | def set_model(self, task: PositionTask): 15 | self.model = mjx.put_model( 16 | mj.MjModel.from_xml_string( 17 | """ 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | """ 35 | ) 36 | ) 37 | task.update_model(self.model) 38 | 39 | def test_target_frame(self): 40 | """Testing manipulations with target position""" 41 | pos_task = PositionTask("pos_task", cost=1.0, gain=2.0, obj_name="body3") 42 | # By default, it has to be the zero vector 43 | np.testing.assert_array_almost_equal(pos_task.target_pos, jnp.zeros(3)) 44 | 45 | # Setting with sequence 46 | test_pos = (-1, 0, 1) 47 | pos_task.target_pos = test_pos 48 | np.testing.assert_array_almost_equal(jnp.array(test_pos), pos_task.target_pos) 49 | 50 | with
self.assertRaises(ValueError): 51 | pos_task.target_pos = (0, 1, 2, 3, 4) 52 | 53 | def test_build_component(self): 54 | frame_task = PositionTask( 55 | "frame_task", 56 | cost=1.0, 57 | gain=2.0, 58 | obj_name="body3", 59 | gain_fn=lambda x: 2 * x, 60 | lm_damping=0.5, 61 | mask=[True, False, True], 62 | ) 63 | self.set_model(frame_task) 64 | pos_des = jnp.array([0.1, -0.1]) 65 | frame_task.target_pos = pos_des 66 | 67 | jax_component = frame_task.jax_component 68 | 69 | self.assertEqual(jax_component.dim, 2) 70 | np.testing.assert_array_equal(jax_component.matrix_cost, jnp.eye(jax_component.dim)) 71 | np.testing.assert_array_equal(jax_component.vector_gain, jnp.ones(jax_component.dim) * 2.0) 72 | np.testing.assert_array_equal(jax_component.obj_id, frame_task.obj_id) 73 | self.assertEqual(jax_component.gain_fn(4), 8) 74 | self.assertEqual(jax_component.lm_damping, 0.5) 75 | np.testing.assert_array_almost_equal(jax_component.target_pos, pos_des) 76 | 77 | self.assertEqual(jax_component.mask_idxs, (0, 2)) 78 | 79 | data = mjx.fwd_position(self.model, mjx.make_data(self.model)) 80 | error = jax_component(data) 81 | np.testing.assert_array_equal(error, jnp.array([-0.1, 0.1])) 82 | -------------------------------------------------------------------------------- /tests/unit_tests/configuration/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/tests/unit_tests/configuration/__init__.py -------------------------------------------------------------------------------- /tests/unit_tests/configuration/test_collision_computation.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import jax.numpy as jnp 4 | import mujoco 5 | import mujoco.mjx as mjx 6 | import numpy as np 7 | 8 | from mjinx.configuration import compute_collision_pairs, update 9 | from mjinx.typing import 
CollisionPair 10 | 11 | 12 | class TestCollisionPairs(unittest.TestCase): 13 | """Test suite for collision pair computation functionality.""" 14 | 15 | # XML template for creating test models 16 | TEST_MODEL_TEMPLATE = """ 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | """ 29 | 30 | def create_test_model(self, type1, type2, pos1, pos2, size1, size2): 31 | """ 32 | Helper method to create test models with different geometries. 33 | 34 | :param type1: Type of first geometry 35 | :param type2: Type of second geometry 36 | :param pos1: Position of first body [x, y, z] 37 | :param pos2: Position of second body [x, y, z] 38 | :param size1: Size parameters for first geometry 39 | :param size2: Size parameters for second geometry 40 | :return: Tuple of (mjx.Model, mjx.Data) 41 | """ 42 | xml = self.TEST_MODEL_TEMPLATE.format( 43 | type1=type1, 44 | type2=type2, 45 | pos1=" ".join(map(str, pos1)), 46 | pos2=" ".join(map(str, pos2)), 47 | size1=" ".join(map(str, size1 if isinstance(size1, list | tuple) else [size1])), 48 | size2=" ".join(map(str, size2 if isinstance(size2, list | tuple) else [size2])), 49 | ) 50 | mjx_model = mjx.put_model(mujoco.MjModel.from_xml_string(xml)) 51 | mjx_data = update(mjx_model, jnp.zeros(1)) 52 | return mjx_model, mjx_data 53 | 54 | def setUp(self): 55 | """Set up common test models.""" 56 | # Model for sphere-sphere collision tests 57 | self.sphere_model, self.sphere_data = self.create_test_model( 58 | "sphere", "sphere", pos1=[0, 0, 0], pos2=[0.8, 0, 0], size1=0.5, size2=0.5 59 | ) 60 | 61 | # Model for box-sphere collision tests 62 | self.box_sphere_model, self.box_sphere_data = self.create_test_model( 63 | "box", "sphere", pos1=[0, 0, 0], pos2=[0.8, 0, 0], size1=[0.5, 0.5, 0.5], size2=0.5 64 | ) 65 | 66 | def test_sphere_sphere_collision(self): 67 | """Test collision detection between two spheres.""" 68 | collision_pairs = [(0, 1)] 69 | contact = compute_collision_pairs(self.sphere_model, self.sphere_data, collision_pairs) 70 
| 71 | # Expected distance: 0.8 - (0.5 + 0.5) = -0.2 (penetration) 72 | self.assertIsNotNone(contact.dist) 73 | np.testing.assert_allclose(contact.dist, -0.2, atol=1e-6) 74 | self.assertIsNotNone(contact.pos) 75 | np.testing.assert_allclose(contact.pos[0], [0.4, 0, 0], atol=1e-6) 76 | 77 | def test_sphere_sphere_no_collision(self): 78 | """Test non-colliding spheres.""" 79 | model, data = self.create_test_model("sphere", "sphere", pos1=[0, 0, 0], pos2=[2, 0, 0], size1=0.5, size2=0.5) 80 | 81 | collision_pairs = [(0, 1)] 82 | contact = compute_collision_pairs(model, data, collision_pairs) 83 | 84 | self.assertIsNotNone(contact.dist) 85 | np.testing.assert_allclose(contact.dist, 1.0, atol=1e-6) 86 | 87 | def test_box_sphere_collision(self): 88 | """Test collision detection between a box and a sphere.""" 89 | collision_pairs = [(0, 1)] 90 | contact = compute_collision_pairs(self.box_sphere_model, self.box_sphere_data, collision_pairs) 91 | 92 | self.assertIsNotNone(contact.dist) 93 | np.testing.assert_allclose(contact.dist, -0.2, atol=1e-6) 94 | 95 | def test_multiple_collision_pairs(self): 96 | """Test handling multiple collision pairs simultaneously.""" 97 | # Create model with three bodies 98 | xml_additional = """ 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | """ 114 | model = mjx.put_model(mujoco.MjModel.from_xml_string(xml_additional)) 115 | data = update(model, jnp.zeros(1)) 116 | 117 | collision_pairs = [(0, 1), (0, 2)] 118 | contact = compute_collision_pairs(model, data, collision_pairs) 119 | 120 | self.assertIsNotNone(contact.dist) 121 | self.assertEqual(len(contact.dist), 2) 122 | np.testing.assert_allclose(contact.dist[0], -0.2, atol=1e-6) # penetration 123 | np.testing.assert_allclose(contact.dist[1], 0.2, atol=1e-6) # separation 124 | 125 | def test_invalid_collision_pairs(self): 126 | """Test behavior with invalid collision pair indices.""" 127 | with self.assertRaises(IndexError): 128 | 
compute_collision_pairs(self.sphere_model, self.sphere_data, [(0, 5)]) 129 | 130 | def test_empty_collision_pairs(self): 131 | """Test behavior with empty collision pairs list.""" 132 | contact = compute_collision_pairs(self.sphere_model, self.sphere_data, []) 133 | 134 | self.assertIsNotNone(contact.dist) 135 | self.assertEqual(len(contact.dist), 0) 136 | self.assertIsNotNone(contact.pos) 137 | self.assertEqual(len(contact.pos), 0) 138 | 139 | def test_contact_frame_orientation(self): 140 | """Test the orientation of contact frames.""" 141 | model, data = self.create_test_model( 142 | "sphere", 143 | "sphere", 144 | pos1=[0, 0, 0], 145 | pos2=[0.8, 0.8, 0], 146 | size1=0.5, 147 | size2=0.5, # Diagonal position 148 | ) 149 | 150 | collision_pairs = [(0, 1)] 151 | contact = compute_collision_pairs(model, data, collision_pairs) 152 | 153 | self.assertIsNotNone(contact.frame) 154 | expected_direction = np.array([1, 1, 0]) / np.sqrt(2) 155 | np.testing.assert_allclose(contact.frame[0, 0], expected_direction, atol=1e-6) 156 | -------------------------------------------------------------------------------- /tests/unit_tests/solvers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/tests/unit_tests/solvers/__init__.py -------------------------------------------------------------------------------- /tests/unit_tests/solvers/test_global_ik.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import mujoco as mj 4 | import mujoco.mjx as mjx 5 | import numpy as np 6 | import optax 7 | 8 | from mjinx.components.barriers import JointBarrier 9 | from mjinx.components.tasks import ComTask 10 | from mjinx.problem import Problem 11 | from mjinx.solvers import GlobalIKSolver 12 | 13 | 14 | class TestGlobalIK(unittest.TestCase): 15 | def setUp(self): 16 | """Setting up basic model, 
components, and problem data.""" 17 | 18 | self.model = mjx.put_model( 19 | mj.MjModel.from_xml_string( 20 | """ 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | """ 31 | ) 32 | ) 33 | self.dummy_task = ComTask("com_task", cost=1.0, gain=1.0) 34 | self.dummy_barrier = JointBarrier("joint_barrier", gain=1.0) 35 | self.problem = Problem( 36 | self.model, 37 | v_min=0, 38 | v_max=0, 39 | ) 40 | self.problem.add_component(self.dummy_task) 41 | self.problem.add_component(self.dummy_barrier) 42 | 43 | self.problem_data = self.problem.compile() 44 | 45 | def test_init(self): 46 | """Testing planner initialization""" 47 | 48 | solver = GlobalIKSolver(model=self.model, optimizer=optax.adam(learning_rate=1e-3)) 49 | 50 | with self.assertRaises(ValueError): 51 | _ = solver.init(q=np.arange(self.model.nq + 1)) 52 | 53 | data = solver.init(q=np.arange(self.model.nq)) 54 | 55 | # Check that it correctly initialized for chosen optimizer 56 | self.assertIsInstance(data.optax_state[0], optax.ScaleByAdamState) 57 | self.assertIsInstance(data.optax_state[1], optax.EmptyState) 58 | 59 | def test_loss_fn(self): 60 | self.problem.remove_component(self.dummy_barrier.name) 61 | self.dummy_task.target_com = [0.3, 0.3, 0.3] 62 | new_problem_data = self.problem.compile() 63 | 64 | solver = GlobalIKSolver(model=self.model, optimizer=optax.adam(learning_rate=1e-3)) 65 | loss = solver.loss_fn(np.zeros(self.model.nq), problem_data=new_problem_data) 66 | 67 | self.assertEqual(loss, 0) 68 | 69 | def test_loss_grad(self): 70 | solver = GlobalIKSolver(model=self.model, optimizer=optax.adam(learning_rate=1e-3)) 71 | grad = solver.grad_fn(np.zeros(self.model.nq), problem_data=self.problem_data) 72 | 73 | self.assertEqual(grad.shape, (self.model.nv,)) 74 | 75 | def test_solve(self): 76 | """Testing solving functions""" 77 | solver = GlobalIKSolver(model=self.model, optimizer=optax.adam(learning_rate=1e-2), dt=1e-3) 78 | solver_data = solver.init(q=np.ones(self.model.nq)) 79 | 80 | with 
self.assertRaises(ValueError): 81 | solver.solve(np.ones(self.model.nq + 1), solver_data, self.problem_data) 82 | 83 | new_solution, _ = solver.solve(np.ones(self.model.nq), solver_data, self.problem_data) 84 | 85 | new_solution_from_data, _ = solver.solve_from_data( 86 | solver_data, 87 | self.problem_data, 88 | mjx.make_data(self.model).replace(qpos=np.ones(self.model.nq)), 89 | ) 90 | 91 | np.testing.assert_almost_equal(new_solution.v_opt, new_solution_from_data.v_opt, decimal=3) 92 | -------------------------------------------------------------------------------- /tests/unit_tests/test_problem.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import jax.numpy as jnp 4 | import mujoco as mj 5 | import mujoco.mjx as mjx 6 | import numpy as np 7 | 8 | from mjinx.components.barriers import JointBarrier 9 | from mjinx.components.tasks import ComTask, JaxComTask 10 | from mjinx.problem import Problem 11 | 12 | 13 | class TestProblem(unittest.TestCase): 14 | dummy_dim: int = 3 15 | 16 | def setUp(self): 17 | """Setting up component test based on ComTask example. 
18 | 19 | Note that ComTask has one of the smallest additional functionality and processing 20 | compared to other components""" 21 | 22 | self.model = mjx.put_model( 23 | mj.MjModel.from_xml_string( 24 | """ 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | w 33 | 34 | """ 35 | ) 36 | ) 37 | self.problem = Problem( 38 | self.model, 39 | v_min=0, 40 | v_max=0, 41 | ) 42 | 43 | def test_v_min(self): 44 | """Testing proper v lower limit""" 45 | 46 | self.problem.v_min = 1 47 | 48 | self.assertIsInstance(self.problem.v_min, jnp.ndarray) 49 | np.testing.assert_array_equal(self.problem.v_min, jnp.ones(self.model.nv)) 50 | 51 | with self.assertRaises(ValueError): 52 | self.problem.v_min = jnp.array([1, 2, 3]) 53 | 54 | self.problem.v_min = np.arange(self.model.nv) 55 | self.assertIsInstance(self.problem.v_min, jnp.ndarray) 56 | np.testing.assert_array_equal(self.problem.v_min, jnp.arange(self.model.nv)) 57 | 58 | jax_problem_data = self.problem.compile() 59 | self.assertIsInstance(jax_problem_data.v_min, jnp.ndarray) 60 | np.testing.assert_array_equal(jax_problem_data.v_min, jnp.arange(self.model.nv)) 61 | 62 | with self.assertRaises(ValueError): 63 | self.problem.v_min = jnp.eye(self.model.nv) 64 | 65 | def test_v_max(self): 66 | """Testing proper v upper limit""" 67 | 68 | self.problem.v_max = 1 69 | 70 | np.testing.assert_array_equal(self.problem.v_max, jnp.ones(self.model.nv)) 71 | 72 | with self.assertRaises(ValueError): 73 | self.problem.v_max = jnp.array([1, 2, 3]) 74 | 75 | self.problem.v_max = np.arange(self.model.nv) 76 | self.assertIsInstance(self.problem.v_max, jnp.ndarray) 77 | np.testing.assert_array_equal(self.problem.v_max, jnp.arange(self.model.nv)) 78 | 79 | jax_problem_data = self.problem.compile() 80 | self.assertIsInstance(jax_problem_data.v_max, jnp.ndarray) 81 | np.testing.assert_array_equal(jax_problem_data.v_max, jnp.arange(self.model.nv)) 82 | 83 | with self.assertRaises(ValueError): 84 | self.problem.v_max = jnp.eye(self.model.nv) 85 | 86 | def 
test_add_component(self): 87 | """Testing adding component""" 88 | 89 | component = ComTask("test_component", 1.0, 1.0) 90 | 91 | self.problem.add_component(component) 92 | # Check that model is defined and does not raise errors 93 | _ = component.model 94 | 95 | jax_problem_data = self.problem.compile() 96 | 97 | # modified is False -> component was compiled as well 98 | self.assertFalse(component.modified) 99 | self.assertEqual(len(jax_problem_data.components), 1) 100 | self.assertIsInstance(jax_problem_data.components["test_component"], JaxComTask) 101 | 102 | component_with_same_name = JointBarrier("test_component", 1.0) 103 | 104 | with self.assertRaises(ValueError): 105 | self.problem.add_component(component_with_same_name) 106 | 107 | def test_remove_component(self): 108 | """Testing removing component""" 109 | component = ComTask("test_component", 1.0, 1.0) 110 | 111 | self.problem.add_component(component) 112 | self.problem.remove_component(component.name) 113 | 114 | jax_problem_data = self.problem.compile() 115 | self.assertEqual(len(jax_problem_data.components), 0) 116 | 117 | def test_components_access(self): 118 | """Test components access interface""" 119 | 120 | component = ComTask("test_component", 1.0, 1.0) 121 | 122 | self.problem.add_component(component) 123 | 124 | self.assertIsInstance(self.problem.component("test_component"), ComTask) 125 | 126 | with self.assertRaises(ValueError): 127 | _ = self.problem.component("non_existens_component") 128 | 129 | def test_tasks_access(self): 130 | """Test tasks access interface""" 131 | 132 | task = ComTask("test_task", 1.0, 1.0) 133 | barrier = JointBarrier("test_barrier", 1.0) 134 | 135 | self.problem.add_component(task) 136 | self.problem.add_component(barrier) 137 | 138 | self.assertIsInstance(self.problem.task("test_task"), ComTask) 139 | 140 | with self.assertRaises(ValueError): 141 | _ = self.problem.task("test_barrier") 142 | 143 | with self.assertRaises(ValueError): 144 | _ =
self.problem.task("non_existent_task") 145 | 146 | def test_barriers_access(self): 147 | """Test barriers access interface""" 148 | 149 | task = ComTask("test_task", 1.0, 1.0) 150 | barrier = JointBarrier("test_barrier", 1.0) 151 | 152 | self.problem.add_component(task) 153 | self.problem.add_component(barrier) 154 | 155 | self.assertIsInstance(self.problem.barrier("test_barrier"), JointBarrier) 156 | 157 | with self.assertRaises(ValueError): 158 | _ = self.problem.barrier("test_task") 159 | 160 | with self.assertRaises(ValueError): 161 | _ = self.problem.barrier("non_existent_task") 162 | 163 | def test_setting_vmap_dimension(self): 164 | """Testing context manager for vmapping the dimensions""" 165 | 166 | task = ComTask("test_task", 1.0, 1.0) 167 | barrier = JointBarrier("test_barrier", 1.0) 168 | 169 | self.problem.add_component(task) 170 | self.problem.add_component(barrier) 171 | 172 | with self.problem.set_vmap_dimension() as empty_problem_data: 173 | empty_problem_data.components["test_task"].target_com = 0 174 | 175 | self.assertEqual(empty_problem_data.v_min, None) 176 | self.assertEqual(empty_problem_data.v_max, None) 177 | 178 | self.assertEqual(empty_problem_data.components["test_task"].target_com, 0) 179 | --------------------------------------------------------------------------------