├── .DS_Store
├── .github
└── workflows
│ ├── build.yml
│ ├── docs.yml
│ ├── mypy.yml
│ ├── ruff.yml
│ └── tests.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── benchmarks
├── bench_global_ik_vmapped_output.py
├── bench_local_ik.py
└── bench_local_ik_vmapped_output.py
├── docs
├── README.md
├── barriers.rst
├── components.rst
├── conf.py
├── configuration.rst
├── constraints.rst
├── css
│ └── custom.css
├── developer-notes.rst
├── img
│ └── kuka_iiwa_14.png
├── index.rst
├── installation.rst
├── problem.rst
├── quick_start.rst
├── references.bib
├── references.rst
├── solvers.rst
├── tasks.rst
├── typing.rst
└── visualization.rst
├── examples
├── cassie_alip.py
├── cassie_squat.py
├── g1_description
│ ├── LICENSE
│ ├── README.md
│ ├── assets
│ │ ├── head_link.STL
│ │ ├── left_ankle_pitch_link.STL
│ │ ├── left_ankle_roll_link.STL
│ │ ├── left_elbow_pitch_link.STL
│ │ ├── left_elbow_roll_link.STL
│ │ ├── left_five_link.STL
│ │ ├── left_four_link.STL
│ │ ├── left_hip_pitch_link.STL
│ │ ├── left_hip_roll_link.STL
│ │ ├── left_hip_yaw_link.STL
│ │ ├── left_knee_link.STL
│ │ ├── left_one_link.STL
│ │ ├── left_palm_link.STL
│ │ ├── left_shoulder_pitch_link.STL
│ │ ├── left_shoulder_roll_link.STL
│ │ ├── left_shoulder_yaw_link.STL
│ │ ├── left_six_link.STL
│ │ ├── left_three_link.STL
│ │ ├── left_two_link.STL
│ │ ├── left_zero_link.STL
│ │ ├── logo_link.STL
│ │ ├── pelvis.STL
│ │ ├── pelvis_contour_link.STL
│ │ ├── right_ankle_pitch_link.STL
│ │ ├── right_ankle_roll_link.STL
│ │ ├── right_elbow_pitch_link.STL
│ │ ├── right_elbow_roll_link.STL
│ │ ├── right_five_link.STL
│ │ ├── right_four_link.STL
│ │ ├── right_hip_pitch_link.STL
│ │ ├── right_hip_roll_link.STL
│ │ ├── right_hip_yaw_link.STL
│ │ ├── right_knee_link.STL
│ │ ├── right_one_link.STL
│ │ ├── right_palm_link.STL
│ │ ├── right_shoulder_pitch_link.STL
│ │ ├── right_shoulder_roll_link.STL
│ │ ├── right_shoulder_yaw_link.STL
│ │ ├── right_six_link.STL
│ │ ├── right_three_link.STL
│ │ ├── right_two_link.STL
│ │ ├── right_zero_link.STL
│ │ └── torso_link.STL
│ ├── g1.png
│ ├── g1.xml
│ └── scene.xml
├── g1_heart.py
├── g1_heart_constrained.py
├── global_ik.py
├── global_ik_vmapped_input.py
├── global_ik_vmapped_output.py
├── go2_squat.py
├── local_ik.py
├── local_ik_vmapped_input.py
├── local_ik_vmapped_output.py
└── notebooks
│ └── turoial.ipynb
├── img
├── cassie_caravan.gif
├── g1_heart.gif
├── go2_stance.gif
├── local_ik_input.gif
├── local_ik_output.gif
└── logo.svg
├── mjinx
├── __init__.py
├── components
│ ├── __init__.py
│ ├── _base.py
│ ├── barriers
│ │ ├── __init__.py
│ │ ├── _base.py
│ │ ├── _joint_barrier.py
│ │ ├── _obj_barrier.py
│ │ ├── _obj_position_barrier.py
│ │ └── _self_collision_barrier.py
│ ├── constraints
│ │ ├── __init__.py
│ │ ├── _base.py
│ │ └── _equality_constraint.py
│ └── tasks
│ │ ├── __init__.py
│ │ ├── _base.py
│ │ ├── _com_task.py
│ │ ├── _joint_task.py
│ │ ├── _obj_frame_task.py
│ │ ├── _obj_position_task.py
│ │ └── _obj_task.py
├── configuration
│ ├── __init__.py
│ ├── _collision.py
│ ├── _lie.py
│ └── _model.py
├── model.py
├── problem.py
├── solvers
│ ├── __init__.py
│ ├── _base.py
│ ├── _global_ik.py
│ └── _local_ik.py
├── typing.py
└── visualize.py
├── mypy.ini
├── pyproject.toml
└── tests
├── __init__.py
├── example_tests
├── test_global_ik_jit.py
├── test_global_ik_vmap.py
├── test_local_ik_jit.py
└── test_local_ik_vmap.py
└── unit_tests
├── __init__.py
├── components
├── __init__.py
├── barriers
│ ├── __init__.py
│ ├── test_base_barrier.py
│ ├── test_joint_barrier.py
│ ├── test_obj_barrier.py
│ ├── test_obj_position_barrier.py
│ └── test_self_collision_barrier.py
├── constraints
│ ├── __init__.py
│ └── test_base_constraint.py
├── tasks
│ ├── __init__.py
│ ├── test_base_task.py
│ ├── test_com_task.py
│ ├── test_joint_task.py
│ ├── test_obj_frame_task.py
│ ├── test_obj_position_task.py
│ └── test_obj_task.py
└── test_base_component.py
├── configuration
├── __init__.py
├── test_collision_computation.py
├── test_collision_utils.py
└── test_configuration.py
├── solvers
├── __init__.py
├── test_global_ik.py
└── test_local_ik.py
└── test_problem.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/.DS_Store
--------------------------------------------------------------------------------
/.github/workflows/build.yml:
--------------------------------------------------------------------------------
1 | name: Upload Python Package
2 |
3 | on:
4 | push:
5 | tags:
6 | - "v*" # Triggers on any tag that starts with 'v'
7 |
8 | permissions:
9 | contents: read
10 | id-token: write # IMPORTANT: mandatory for trusted publishing
11 | jobs:
12 | # tests:
13 | # uses: ./.github/workflows/tests.yml # use the callable tests job to run tests
14 |
15 | deploy:
16 | runs-on: ubuntu-latest
17 | # run on tag only
18 | if: startsWith(github.ref, 'refs/tags/')
19 | # needs: [tests] # require tests to pass before deploy runs
20 | steps:
21 | - name: "Checkout Git repository"
22 | uses: actions/checkout@v3
23 |
24 | - name: Set up Python
25 | uses: actions/setup-python@v3
26 | with:
27 | python-version: "3.x"
28 |
29 | - name: Install dependencies
30 | run: |
31 | python -m pip install --upgrade pip
32 | pip install build
33 |
34 | - name: Build package
35 | run: python -m build --wheel
36 |
37 | - name: Publish package
38 | uses: pypa/gh-action-pypi-publish@release/v1
39 |
--------------------------------------------------------------------------------
/.github/workflows/docs.yml:
--------------------------------------------------------------------------------
1 | name: Documentation
2 |
3 | on:
4 | push:
5 | branches: [main, docs/github_pages]
6 | pull_request:
7 | branches: [main]
8 |
9 | jobs:
10 | docs:
11 | name: "GitHub Pages"
12 | runs-on: ubuntu-latest
13 | permissions:
14 | contents: write
15 | steps:
16 | - name: "Checkout Git repository"
17 | uses: actions/checkout@v3
18 |
19 | - name: "Set up Python"
20 | uses: actions/setup-python@v4
21 | with:
22 | python-version: "3.10" # Specify the Python version you need
23 |
24 | - name: "Install dependencies"
25 | run: |
26 | python -m pip install --upgrade pip
27 | pip install ".[all]"
28 |
29 | - name: "Build documentation"
30 | run: |
31 | sphinx-build docs _build
32 |
33 | - name: "Deploy to GitHub Pages"
34 | uses: peaceiris/actions-gh-pages@v3
35 | if: ${{ github.ref == 'refs/heads/main' }}
36 | with:
37 | github_token: ${{ secrets.GITHUB_TOKEN }}
38 | publish_dir: _build/
39 | force_orphan: true
40 |
--------------------------------------------------------------------------------
/.github/workflows/mypy.yml:
--------------------------------------------------------------------------------
1 | name: mypy
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | pull_request:
7 | branches: [main]
8 |
9 | jobs:
10 | mypy:
11 | runs-on: ubuntu-latest
12 | steps:
13 | - name: Check out the repository
14 | uses: actions/checkout@v3
15 |
16 | - name: Set up Python
17 | uses: actions/setup-python@v4
18 | with:
19 | python-version: "3.12"
20 |
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install -e .
25 | pip install mypy
26 |
27 | - name: Run mypy
28 | run: |
29 | mypy mjinx
30 |
--------------------------------------------------------------------------------
/.github/workflows/ruff.yml:
--------------------------------------------------------------------------------
1 | name: Formatting
2 |
3 | on:
4 | push:
5 | branches: [ main ]
6 | pull_request:
7 | branches: [ main ]
8 | permissions:
9 | contents: read
10 |
11 | jobs:
12 | ruff:
13 | runs-on: ubuntu-latest
14 | steps:
15 | - uses: actions/checkout@v3
16 | - name: Install Python
17 | uses: actions/setup-python@v4
18 | with:
19 | python-version: "3.10"
20 | - name: Install dependencies
21 | run: |
22 | python -m pip install --upgrade pip
23 | pip install ruff==0.6.3 # keep in sync with pyproject.toml
24 | # `--diff` prints the required formatting changes and makes the job fail if any file is not formatted.
25 | - name: Run Ruff
26 | run: ruff format --diff .
27 |
28 |
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | name: tests
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | pull_request:
7 | branches: [main]
8 | workflow_call: # allow this workflow to be called from other workflows
9 |
10 | jobs:
11 | unit-tests:
12 | name: "Unit tests"
13 | runs-on: ubuntu-latest
14 |
15 | steps:
16 | - name: "Checkout sources"
17 | uses: actions/checkout@v3
18 | with:
19 | submodules: recursive
20 |
21 | - name: "Set up Python"
22 | uses: actions/setup-python@v4
23 | with:
24 | python-version: "3.10"
25 |
26 | - name: "Install dependencies"
27 | run: |
28 | python -m pip install --upgrade pip
29 | python -m pip install coveralls
30 | pip install ".[dev]"
31 |
32 | - name: "Run unit tests"
33 | run: |
34 | coverage erase
35 | coverage run -m unittest discover --failfast
36 | coverage report --include="mjinx/*"
37 |
38 | test:
39 | runs-on: ubuntu-latest
40 | strategy:
41 | matrix:
42 | python-version: ["3.10", "3.11", "3.12", "3.13"]
43 |
44 | steps:
45 | - name: Check out the repository
46 | uses: actions/checkout@v3
47 |
48 | - name: Set up Python ${{ matrix.python-version }}
49 | uses: actions/setup-python@v4
50 | with:
51 | python-version: ${{ matrix.python-version }}
52 |
53 | - name: Install dependencies
54 | run: |
55 | python -m pip install --upgrade pip
56 | pip install ".[dev]"
57 |
58 | - name: Run tests
59 | run: |
60 | pytest tests/example_tests
61 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 | _build
164 | MUJOCO_LOG.TXT
165 | *.mp4
166 | *.prof
167 | benchmarks/logs/*
168 | *.vscode/*
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/astral-sh/ruff-pre-commit
3 | # Ruff version.
4 | rev: v0.6.3
5 | hooks:
6 | # Run the formatter.
7 | - id: ruff-format
8 | types_or: [python, pyi, jupyter]
9 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to mjinx
2 |
3 | We're thrilled that you're interested in contributing to mjinx! This document outlines the process for contributing to this project.
4 |
5 | ## How to Contribute
6 |
7 | There are many ways to contribute to mjinx:
8 |
9 | 1. Reporting bugs
10 | 2. Suggesting enhancements
11 | 3. Writing documentation
12 | 4. Submitting code changes
13 |
14 | ### Reporting Bugs
15 |
16 | 1. Check the [issue tracker](https://github.com/based-robotics/mjinx/issues) to see if the bug has already been reported.
17 | 2. If not, create a new issue. Provide a clear title and description, as much relevant information as possible, and a code sample or executable test case demonstrating the bug.
18 |
19 | ### Suggesting Enhancements
20 |
21 | 1. Check the [issue tracker](https://github.com/based-robotics/mjinx/issues) to see if the enhancement has already been suggested.
22 | 2. If not, create a new issue. Clearly describe the enhancement, why it would be useful, and any potential drawbacks.
23 |
24 | ### Writing Documentation
25 |
26 | Good documentation is crucial. If you notice any part of our documentation that could be improved or expanded, please let us know or submit a pull request with your suggested changes.
27 |
28 | ### Submitting Code Changes
29 |
30 | 1. Fork the repository.
31 | 2. Create a new branch for your changes.
32 | 3. Make your changes in your branch.
33 | > Please pay attention to code style and mypy typing; both should match the conventions used in the repository. The authors suggest using pre-commit:
34 | > ```bash
35 | > pip3 install pre-commit
36 | > pre-commit install
37 | > ```
38 | > Formatting is checked by the CI workflows, so if you skip this step, the checks won't pass.
39 |
40 | 4. Add or update tests as necessary.
41 | 5. Ensure the test suite passes.
42 | 6. Update the documentation as needed.
43 | 7. Push your branch and submit a pull request.
44 |
45 | ## Pull Request Process
46 |
47 | 1. Ensure your code follows the project's style guidelines.
48 | 2. Update the README.md or relevant documentation with details of changes, if applicable.
49 | 3. Add tests for your changes and ensure all tests pass.
50 | 4. Your pull request will be reviewed by the maintainers. They may suggest changes or improvements.
51 | 5. Once approved, your pull request will be merged.
52 |
53 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2024 Ivan Domrachev
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in all
11 | copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | [](https://github.com/based-robotics/mjinx/actions)
3 | [](https://github.com/based-robotics/mjinx/actions)
4 | [](https://based-robotics.github.io/mjinx/)
5 | [](https://pypi.org/project/mjinx/)
6 | [](https://pypistats.org/packages/mjinx)
7 | [](https://colab.research.google.com/github/based-robotics/mjinx/blob/main/examples/notebooks/turoial.ipynb)
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 | **MJINX** is a python library for auto-differentiable numerical inverse kinematics built on **JAX** and **Mujoco MJX**. It draws inspiration from similar tools like the Pinocchio-based [PINK](https://github.com/stephane-caron/pink/tree/main) and Mujoco-based [MINK](https://github.com/kevinzakka/mink/tree/main).
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 | ## Key Features
25 | 1. **Flexibility**. Problems are constructed using modular `Components` that enforce desired behaviors or maintain system safety constraints.
26 | 2. **Multiple Solution Strategies**. Leveraging JAX's efficient sampling and automatic differentiation capabilities, MJINX implements various solvers optimized for different robotics scenarios.
27 | 3. **Full JAX Compatibility**. Both the optimal control formulation and solvers are fully JAX-compatible, enabling JIT compilation and automatic vectorization across the entire pipeline.
28 | 4. **User-Friendly Interface**. The API is designed with a clean, intuitive interface that simplifies complex inverse kinematics tasks while maintaining advanced functionality.
29 |
30 | ## Installation
31 | The package is available on PyPI and can be installed via `pip`:
32 | ```bash
33 | pip install mjinx
34 | ```
35 |
36 | Different installation versions:
37 | 1. The visualization tool `mjinx.visualize.BatchVisualizer` is available in `mjinx[visual]`
38 | 2. To run examples, install `mjinx[examples]`
39 | 3. To install development version, install `mjinx[dev]` (preferably in editable mode)
40 | 4. To build docs, install `mjinx[docs]`
41 | 5. To install the repository with all dependencies, install `mjinx[all]`
42 |
43 | Note that by default, installing `mjinx` installs `jax` without CUDA support. If you need it, please install `jax>=0.5.0` with CUDA support manually.
44 |
45 | ## Usage
46 | Here is an example of `mjinx` usage:
47 |
48 | ```python
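49 | import mujoco as mj, mujoco.mjx as mjx
50 | from mjinx.problem import Problem
51 | # Also import: jax, numpy as np, FrameTask (mjinx.components.tasks), JointBarrier (mjinx.components.barriers), LocalIKSolver (mjinx.solvers), integrate (mjinx.configuration)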
52 | # Initialize the robot model using MuJoCo
53 | MJCF_PATH = "path_to_mjcf.xml"
54 | mj_model = mj.MjModel.from_xml_path(MJCF_PATH)
55 | mjx_model = mjx.put_model(mj_model)
56 |
57 | # Create instance of the problem
58 | problem = Problem(mjx_model)
59 |
60 | # Add tasks to track desired behavior
61 | frame_task = FrameTask("ee_task", cost=1, gain=20, obj_name="link7")
62 | problem.add_component(frame_task)
63 |
64 | # Add barriers to keep robot in a safety set
65 | joints_barrier = JointBarrier("jnt_range", gain=10)
66 | problem.add_component(joints_barrier)
67 |
68 | # Initialize the solver
69 | solver = LocalIKSolver(mjx_model)
70 |
71 | # Initializing initial condition
72 | q = np.zeros(7)
73 |
74 | # Initialize solver data
75 | solver_data = solver.init()
76 |
77 | # jit-compiling solve and integrate
78 | solve_jit = jax.jit(solver.solve)
79 | integrate_jit = jax.jit(integrate, static_argnames=["dt"])
80 | dt = 1e-2
81 | # === Control loop ===
82 | for t in np.arange(0, 5, dt):
83 | # Changing problem and compiling it
84 | frame_task.target_frame = np.array([0.1 * np.sin(t), 0.1 * np.cos(t), 0.1, 1, 0, 0, 0])
85 | problem_data = problem.compile()
86 |
87 | # Solving the instance of the problem
88 | opt_solution, solver_data = solve_jit(q, solver_data, problem_data)
89 |
90 | # Integrating
91 | q = integrate_jit(
92 | mjx_model,
93 | q,
94 | opt_solution.v_opt,
95 | dt,
96 | )
97 | ```
98 |
99 | ## Examples
100 | The list of examples includes:
101 | 1. `Kuka iiwa` local inverse kinematics ([single item](examples/local_ik.py), [vmap over desired trajectory](examples/local_ik_vmapped_output.py))
102 | 2. `Kuka iiwa` global inverse kinematics ([single item](examples/global_ik.py), [vmap over desired trajectory](examples/global_ik_vmapped_output.py))
103 | 3. `Go2` [batched squats](examples/go2_squat.py) example
104 |
105 | > **Note:** The Global IK functionality is currently under development and does not yet work as expected. It needs proper tuning and will be fixed in future updates. Use the Global IK examples with caution and expect suboptimal results.
106 |
107 |
108 | ## Citing MJINX
109 |
110 | If you use MJINX in your research, please cite it as follows:
111 |
112 | ```bibtex
113 | @software{mjinx25,
114 | author = {Domrachev, Ivan and Nedelchev, Simeon},
115 | license = {MIT},
116 | month = mar,
117 | title = {{MJINX: Differentiable GPU-accelerated inverse kinematics in JAX}},
118 | url = {https://github.com/based-robotics/mjinx},
119 | version = {0.1.1},
120 | year = {2025}
121 | }
122 | ```
123 |
124 | ## Contributing
125 | We welcome suggestions and contributions. Please see our [CONTRIBUTING.md](CONTRIBUTING.md) file for guidelines.
126 |
127 | ## Acknowledgements
128 | I am deeply grateful to Simeon Nedelchev, whose guidance and expertise were instrumental in bringing this project to life.
129 |
130 | This work draws significant inspiration from [`pink`](https://github.com/stephane-caron/pink) by Stéphane Caron and [`mink`](https://github.com/kevinzakka/mink) by Kevin Zakka. Their pioneering work in robotics and open source has been a guiding light for this project.
131 |
132 | The codebase incorporates utility functions from [`MuJoCo MJX`](https://github.com/google-deepmind/mujoco/tree/main/mjx). Beyond being an excellent tool for batched computations and machine learning, MJX's codebase serves as a masterclass in clean, informative implementation of physical simulations and JAX usage.
133 |
134 | Special thanks to [IRIS lab](http://iris.kaist.ac.kr/) for their support.
135 |
--------------------------------------------------------------------------------
/benchmarks/bench_global_ik_vmapped_output.py:
--------------------------------------------------------------------------------
"""Benchmark: batched (vmapped) global IK on the KUKA iiwa14.

Solves ``N_batch`` global IK problems in parallel. Each problem tracks a
phase-shifted end-effector target moving on a circle, and the script records
the wall-clock time of every batched solve. The first recorded step includes
JIT compilation and is reported separately as ``compilation_time``. Results are
written to ``<script dir>/logs/global_ik_benchmark_<timestamp>.npz``.
"""

import os
import traceback
from datetime import datetime
from time import perf_counter

import jax
import jax.numpy as jnp
import mujoco as mj
import mujoco.mjx as mjx
import numpy as np
from optax import adam
from robot_descriptions.iiwa14_mj_description import MJCF_PATH

from mjinx.components.barriers import JointBarrier, PositionBarrier
from mjinx.components.tasks import FrameTask
from mjinx.problem import Problem
from mjinx.solvers import GlobalIKSolver

# === Mujoco ===

mj_model = mj.MjModel.from_xml_path(MJCF_PATH)
mjx_model = mjx.put_model(mj_model)

q_min = mj_model.jnt_range[:, 0].copy()
q_max = mj_model.jnt_range[:, 1].copy()

# === Mjinx ===

# --- Constructing the problem ---
problem = Problem(mjx_model)

# Creating components of interest and adding them to the problem
frame_task = FrameTask("ee_task", cost=1, gain=20, obj_name="link7")
position_barrier = PositionBarrier(
    "ee_barrier",
    gain=0.1,
    obj_name="link7",
    limit_type="max",
    p_max=0.4,
    safe_displacement_gain=1e-2,
    mask=[1, 0, 0],
)
joints_barrier = JointBarrier("jnt_range", gain=0.1)

problem.add_component(frame_task)
problem.add_component(position_barrier)
problem.add_component(joints_barrier)

# Compiling the problem upon any parameters update
problem_data = problem.compile()

# Initializing the solver. The learning rate is defined once so the value
# stored in the results file always matches the one actually used.
LEARNING_RATE = 1e-2
solver = GlobalIKSolver(mjx_model, adam(learning_rate=LEARNING_RATE), dt=1e-2)

# Initializing initial conditions: N_batch uniform perturbations of q0,
# clipped to stay 0.1 inside the joint limits.
N_batch = 10000
np.random.seed(42)
q0 = jnp.array(
    [
        -1.4238753,
        -1.7268502,
        -0.84355015,
        2.0962472,
        2.1339328,
        2.0837479,
        -2.5521986,
    ]
)
# One (N_batch, nq) draw replaces N_batch separate draws of size nq.
q = jnp.array(
    np.clip(
        q0 + np.random.uniform(-0.1, 0.1, size=(N_batch, mj_model.nq)),
        q_min + 1e-1,
        q_max - 1e-1,
    )
)

# --- Batching ---
solver_data = jax.vmap(solver.init, in_axes=0)(q)
with problem.set_vmap_dimension() as empty_problem_data:
    empty_problem_data.components["ee_task"].target_frame = 0

solve_jit = jax.jit(jax.vmap(solver.solve, in_axes=(0, 0, empty_problem_data)))

# === Control loop ===
dt = 1e-2
ts = np.arange(0, 10, dt)

# Per-problem phase offsets that spread the batch evenly around the circle.
phases = 2 * np.pi * np.arange(N_batch) / N_batch

metrics: dict[str, list] = {
    "update_time_ms": [],
    "joint_positions": [],
    "joint_velocities": [],
    "target_frames": [],
    "optimization_steps": [],  # Track number of optimization steps
}

try:
    for t in ts:
        # Changing desired values: x/z on a circle of radius 0.3 around
        # (0.4, 0.4), fixed y = 0.2, orientation part (1, 0, 0, 0).
        angles = t + phases
        target_frames = np.zeros((N_batch, 7))
        target_frames[:, 0] = 0.4 + 0.3 * np.sin(angles)
        target_frames[:, 1] = 0.2
        target_frames[:, 2] = 0.4 + 0.3 * np.cos(angles)
        target_frames[:, 3] = 1.0
        frame_task.target_frame = target_frames
        problem_data = problem.compile()

        # Solving the instance of the problem. JAX dispatches work
        # asynchronously, so block until the outputs are ready; otherwise
        # only the dispatch time would be measured.
        t0 = perf_counter()
        opt_solution, solver_data = jax.block_until_ready(solve_jit(q, solver_data, problem_data))
        t1 = perf_counter()

        # Global IK returns joint positions directly; no integration needed.
        q = opt_solution.q_opt

        # Logging
        t_update_ms = (t1 - t0) * 1e3

        metrics["update_time_ms"].append(t_update_ms)
        metrics["joint_positions"].append(np.asarray(q))
        metrics["joint_velocities"].append(np.asarray(opt_solution.v_opt))
        metrics["target_frames"].append(target_frames)
        # NOTE(review): hard-coded, not read from the solver; confirm it matches
        # the number of optimization steps GlobalIKSolver actually performs.
        metrics["optimization_steps"].append(3)

        print(f"t={t:.2f}, Update time: {t_update_ms:.2f}ms")

except KeyboardInterrupt:
    print("Finalizing the simulation as requested...")
except Exception as e:
    # Best effort: report the failure (with traceback) but still save whatever
    # was collected before it happened.
    print(e)
    traceback.print_exc()

# Calculate and print summary statistics
n_steps = len(metrics["update_time_ms"])
if n_steps == 0:
    raise SystemExit("No benchmark steps were completed; nothing to save.")

# The first step includes JIT compilation, so steady-state statistics skip it.
compilation_time = metrics["update_time_ms"][0]
steady_times = metrics["update_time_ms"][1:]
mean_update_time = np.mean(steady_times) if steady_times else np.nan
std_update_time = np.std(steady_times) if steady_times else np.nan
print("Benchmark Summary:")
print(f"Mean update time: {mean_update_time:.2f} ± {std_update_time:.2f} ms")

# Save data. Timestamps are truncated so they align with the recorded metrics
# when the run was interrupted early.
save_data = {
    "timestamps": ts[:n_steps],
    "compilation_time": compilation_time,
    "n_batch": N_batch,
    "mean_update_time": mean_update_time,
    "std_update_time": std_update_time,
    "learning_rate": LEARNING_RATE,  # Save optimizer parameters
}

# Add metrics to save data
for metric_name, values in metrics.items():
    save_data[metric_name] = np.array(values)

# Save to npz file, creating the logs directory on first use.
# NOTE(review): the format is year_day_month; kept unchanged for compatibility
# with existing log names.
timestamp = datetime.now().strftime("%Y_%d_%m_%H_%M_%S")
log_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)), "logs")
os.makedirs(log_dir, exist_ok=True)

filename = os.path.join(log_dir, f"global_ik_benchmark_{timestamp}.npz")
np.savez_compressed(filename, **save_data)
print(f"\nBenchmark data saved to: {filename}")
--------------------------------------------------------------------------------
/benchmarks/bench_local_ik.py:
--------------------------------------------------------------------------------
1 | import os
2 | from datetime import datetime
3 | from time import perf_counter
4 |
5 | import jax
6 | import jax.numpy as jnp
7 | import mujoco as mj
8 | import mujoco.mjx as mjx
9 | import numpy as np
10 | from robot_descriptions.iiwa14_mj_description import MJCF_PATH
11 |
12 | from mjinx.components.barriers import JointBarrier, PositionBarrier, SelfCollisionBarrier
13 | from mjinx.components.tasks import FrameTask
14 | from mjinx.configuration import integrate
15 | from mjinx.problem import Problem
16 | from mjinx.solvers import LocalIKSolver
17 |
18 | # === Mujoco ===
19 |
20 | mj_model = mj.MjModel.from_xml_path(MJCF_PATH)
21 | mj_data = mj.MjData(mj_model)
22 | mjx_model = mjx.put_model(mj_model)
23 |
24 | # === Mjinx ===
25 | problem = Problem(mjx_model, v_min=-100, v_max=100)
26 |
27 | frame_task = FrameTask("ee_task", cost=1, gain=20, obj_name="link7")
28 | position_barrier = PositionBarrier(
29 | "ee_barrier",
30 | gain=100,
31 | obj_name="link7",
32 | limit_type="max",
33 | p_max=0.3,
34 | safe_displacement_gain=1e-2,
35 | mask=[1, 0, 0],
36 | )
37 | joints_barrier = JointBarrier("jnt_range", gain=10)
38 | self_collision_barrier = SelfCollisionBarrier(
39 | "self_collision_barrier",
40 | gain=1.0,
41 | d_min=0.01,
42 | )
43 |
44 | problem.add_component(frame_task)
45 | problem.add_component(position_barrier)
46 | problem.add_component(joints_barrier)
47 | problem.add_component(self_collision_barrier)
48 |
49 | problem_data = problem.compile()
50 |
51 | solver = LocalIKSolver(mjx_model, maxiter=20)
52 |
53 | q = jnp.array(
54 | [
55 | -1.4238753,
56 | -1.7268502,
57 | -0.84355015,
58 | 2.0962472,
59 | 2.1339328,
60 | 2.0837479,
61 | -2.5521986,
62 | ]
63 | )
64 | solver_data = solver.init()
65 |
66 | solve_jit = jax.jit(solver.solve)
67 | integrate_jit = jax.jit(integrate, static_argnames=["dt"])
68 |
69 |
70 | # === Control loop ===
71 | dt = 1e-2
72 | ts = np.arange(0, 10, dt)
73 |
74 | # Additional metrics for comprehensive benchmarking
75 | metrics: dict[str, list] = {
76 | "update_time_ms": [],
77 | "joint_positions": [],
78 | "joint_velocities": [],
79 | "target_frames": [],
80 | "solver_iterations": [],
81 | }
82 |
83 | for t in ts:
84 | # Changing desired values
85 | target_frame = np.array([0.2 + 0.2 * jnp.sin(t) ** 2, 0.2, 0.2, 1, 0, 0, 0])
86 | frame_task.target_frame = target_frame
87 | # After changes, recompiling the model
88 | problem_data = problem.compile()
89 |
90 | # Solving the instance of the problem
91 | t0 = perf_counter()
92 | opt_solution, solver_data = solve_jit(q, solver_data, problem_data)
93 | t1 = perf_counter()
94 |
95 | # Integrating
96 | q = integrate_jit(
97 | mjx_model,
98 | q,
99 | velocity=opt_solution.v_opt,
100 | dt=dt,
101 | )
102 |
103 | # Logging
104 | t_update_ms = (t1 - t0) * 1e3
105 |
106 | # Log additional metrics
107 | metrics["update_time_ms"].append(t_update_ms)
108 | metrics["joint_positions"].append(np.array(q.copy()))
109 | metrics["joint_velocities"].append(np.array(opt_solution.v_opt.copy()))
110 | metrics["target_frames"].append(target_frame)
111 | metrics["solver_iterations"].append(opt_solution.iterations)
112 |
113 | print(f"t={t:.2f}, Update time: {t_update_ms:.2f}ms, Iterations: {opt_solution.iterations}")
114 |
115 | # Calculate and print summary statistics
116 | compilation_time = metrics["update_time_ms"][0]
117 | mean_update_time = np.mean(metrics["update_time_ms"][1:])
118 | std_update_time = np.std(metrics["update_time_ms"][1:])
119 | mean_iterations = np.mean(metrics["solver_iterations"])
120 | print()
121 | print("Benchmark Summary:")
122 | print(f"Mean update time: {mean_update_time:.2f} ± {std_update_time:.2f} ms")
123 | print(f"Mean solver iterations: {mean_iterations:.2f}")
124 |
125 | # Convert lists to numpy arrays for saving
126 | save_data = {
127 | "timestamps": ts,
128 | "compilation_time": compilation_time,
129 | "mean_update_time": mean_update_time,
130 | "std_update_time": std_update_time,
131 | "mean_iterations": mean_iterations,
132 | }
133 |
134 |
135 | # Add metrics to save data
136 | for metric_name, values in metrics.items():
137 | save_data[metric_name] = np.array(values)
138 |
139 | # Save to npz file
140 | timestamp = datetime.now().strftime("%Y_%d_%m_%H_%M_%S")
141 | script_dir = os.path.abspath(os.path.dirname(__file__))
142 |
143 | filename = f"{script_dir}/logs/local_ik_benchmark_{timestamp}.npz"
144 | np.savez_compressed(filename, **save_data)
145 | print(f"\nBenchmark data saved to: {filename}")
146 |
--------------------------------------------------------------------------------
/benchmarks/bench_local_ik_vmapped_output.py:
--------------------------------------------------------------------------------
1 | import os
2 | from datetime import datetime
3 | from time import perf_counter
4 |
5 | import jax
6 | import jax.numpy as jnp
7 | import mujoco as mj
8 | import mujoco.mjx as mjx
9 | import numpy as np
10 | from robot_descriptions.iiwa14_mj_description import MJCF_PATH
11 |
12 | from mjinx.components.barriers import JointBarrier, PositionBarrier
13 | from mjinx.components.tasks import FrameTask
14 | from mjinx.configuration import integrate
15 | from mjinx.problem import Problem
16 | from mjinx.solvers import LocalIKSolver
17 |
18 | # === Mujoco ===
19 |
20 | mj_model = mj.MjModel.from_xml_path(MJCF_PATH)
21 | mj_data = mj.MjData(mj_model)
22 |
23 | mjx_model = mjx.put_model(mj_model)
24 |
25 | q_min = mj_model.jnt_range[:, 0].copy()
26 | q_max = mj_model.jnt_range[:, 1].copy()
27 |
28 | # === Mjinx ===
29 |
30 | # --- Constructing the problem ---
31 | # Creating problem formulation
32 | problem = Problem(mjx_model, v_min=-5, v_max=5)
33 |
34 | # Creating components of interest and adding them to the problem
35 | frame_task = FrameTask("ee_task", cost=1, gain=20, obj_name="link7")
36 | position_barrier = PositionBarrier(
37 | "ee_barrier",
38 | gain=100,
39 | obj_name="link7",
40 | limit_type="max",
41 | p_max=0.4,
42 | safe_displacement_gain=1e-2,
43 | mask=[1, 0, 0],
44 | )
45 | joints_barrier = JointBarrier("jnt_range", gain=10)
46 |
47 | problem.add_component(frame_task)
48 | problem.add_component(position_barrier)
49 | problem.add_component(joints_barrier)
50 |
51 | # Compiling the problem upon any parameters update
52 | problem_data = problem.compile()
53 |
54 | # Initializing solver and its initial state
55 | solver = LocalIKSolver(mjx_model, maxiter=20)
56 |
57 | # Initializing initial condition
58 | N_batch = 10000
59 | q0 = np.array(
60 | [
61 | -1.5878328,
62 | -2.0968683,
63 | -1.4339591,
64 | 1.6550868,
65 | 2.1080072,
66 | 1.646142,
67 | -2.982619,
68 | ]
69 | )
70 | q = jnp.array(
71 | [
72 | np.clip(
73 | q0
74 | + np.random.uniform(
75 | -0.1,
76 | 0.1,
77 | size=(mj_model.nq),
78 | ),
79 | q_min + 1e-1,
80 | q_max - 1e-1,
81 | )
82 | for _ in range(N_batch)
83 | ]
84 | )
85 |
86 | # --- Batching ---
87 | solver_data = jax.vmap(solver.init, in_axes=0)(v_init=jnp.zeros((N_batch, mjx_model.nv)))
88 | with problem.set_vmap_dimension() as empty_problem_data:
89 | empty_problem_data.components["ee_task"].target_frame = 0
90 | solve_jit = jax.jit(
91 | jax.vmap(
92 | solver.solve,
93 | in_axes=(0, 0, empty_problem_data),
94 | )
95 | )
96 | integrate_jit = jax.jit(jax.vmap(integrate, in_axes=(None, 0, 0, None)), static_argnames=["dt"])
97 |
98 |
99 | # === Control loop ===
100 | dt = 1e-2
101 | ts = np.arange(0, 10, dt)
102 |
103 | metrics: dict[str, list] = {
104 | "update_time_ms": [],
105 | "joint_positions": [],
106 | "joint_velocities": [],
107 | "target_frames": [],
108 | "solver_iterations": [],
109 | }
110 |
111 | try:
112 | for t in ts:
113 | # Changing desired values
114 | target_frames = np.array(
115 | [
116 | [
117 | 0.4 + 0.3 * np.sin(t + 2 * np.pi * i / N_batch),
118 | 0.2,
119 | 0.4 + 0.3 * np.cos(t + 2 * np.pi * i / N_batch),
120 | 1,
121 | 0,
122 | 0,
123 | 0,
124 | ]
125 | for i in range(N_batch)
126 | ]
127 | )
128 | frame_task.target_frame = target_frames
129 | problem_data = problem.compile()
130 |
131 | # Solving the instance of the problem
132 | t0 = perf_counter()
133 | opt_solution, solver_data = solve_jit(q, solver_data, problem_data)
134 | t1 = perf_counter()
135 |
136 | # Integrating
137 | q = integrate_jit(
138 | mjx_model,
139 | q,
140 | opt_solution.v_opt,
141 | dt,
142 | )
143 |
144 | # Logging
145 | t_update_ms = (t1 - t0) * 1e3
146 |
147 | # Log metrics
148 | metrics["update_time_ms"].append(t_update_ms)
149 | metrics["joint_positions"].append(np.array(q))
150 | metrics["joint_velocities"].append(np.array(opt_solution.v_opt))
151 | metrics["target_frames"].append(np.array(target_frames).copy())
152 | metrics["solver_iterations"].append(np.array(opt_solution.iterations))
153 |
154 | print(f"t={t:.2f}, Update time: {t_update_ms:.2f}ms, Mean iterations: {np.mean(opt_solution.iterations)}")
155 |
156 | except KeyboardInterrupt:
157 | print("Finalizing the simulation as requested...")
158 | except Exception as e:
159 | print(e)
160 |
161 | # Calculate and print summary statistics
162 | compilation_time = metrics["update_time_ms"][0]
163 | mean_update_time = np.mean(metrics["update_time_ms"][1:])
164 | std_update_time = np.std(metrics["update_time_ms"][1:])
165 | mean_iterations = np.mean(metrics["solver_iterations"])
166 |
167 | print()
168 | print("Benchmark Summary:")
169 | print(f"Mean update time: {mean_update_time:.2f} ± {std_update_time:.2f} ms")
170 | print(f"Mean solver iterations: {mean_iterations:.2f}")
171 |
172 | # Save data
173 | save_data = {
174 | "timestamps": ts,
175 | "n_batch": N_batch,
176 | "compilation_time": compilation_time,
177 | "mean_update_time": mean_update_time,
178 | "std_update_time": std_update_time,
179 | "mean_iterations": mean_iterations,
180 | }
181 |
182 | # Add component values and metrics to save data
183 | for metric_name, values in metrics.items():
184 | save_data[metric_name] = np.array(values)
185 |
186 | # Save to npz file
187 | timestamp = datetime.now().strftime("%Y_%d_%m_%H_%M_%S")
188 | script_dir = os.path.abspath(os.path.dirname(__file__))
189 |
190 | filename = f"{script_dir}/logs/local_ik_vmapped_benchmark_{timestamp}.npz"
191 | np.savez_compressed(filename, **save_data)
192 | print(f"\nBenchmark data saved to: {filename}")
193 |
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | # `mjinx` documentation
2 |
3 | The documentation is built using [Sphinx](https://www.sphinx-doc.org/en/master/), [Read the docs](https://docs.readthedocs.io/en/stable/) template and Python.
4 |
5 | The website is available at url.com.
6 |
7 | ## Building locally
8 | To build and test the website locally, do the following:
9 | ```bash
10 | pip install ".[docs]"
11 | rm -r _build && sphinx-build -M html docs _build
12 | ```
13 |
14 | And open the `file:///home//mjinx/_build/index.html` file in the browser.
--------------------------------------------------------------------------------
/docs/barriers.rst:
--------------------------------------------------------------------------------
1 | :github_url: https://github.com/based-robotics/mjinx/tree/docs/github_pages/docs/barriers.rst
2 |
3 | .. _Barriers:
4 |
5 |
6 | Barriers implement inequality constraints in the inverse kinematics problem. Each barrier defines a scalar function :math:`h: \mathcal{Q} \rightarrow \mathbb{R}` that must remain positive (:math:`h(q) \geq 0`), creating a boundary that the solver must respect. This approach is mathematically equivalent to control barrier functions (CBFs) :cite:`ames2019control`, which have been widely adopted in robotics for safety-critical control.
7 |
8 | The barrier formulation creates a continuous constraint boundary that prevents the system from entering prohibited regions of the configuration space. In the optimization problem, these constraints are typically enforced through linearization at each step:
9 |
10 | .. math::
11 |
12 | \nabla h(q)^T \dot{q} \geq -\alpha h(q)
13 |
14 | where :math:`\alpha` is a gain parameter that controls how aggressively the system is pushed away from the constraint boundary.
15 |
16 |
17 | All barriers follow a consistent mathematical formulation while adapting to specific constraint types, enabling systematic enforcement of safety and feasibility requirements.
18 |
19 | Base Barrier
20 | ------------
21 | The foundation for all barrier constraints, defining the core mathematical properties and interface.
22 |
23 | .. automodule:: mjinx.components.barriers._base
24 | :members:
25 | :member-order: bysource
26 |
27 | Joint Barrier
28 | -------------
29 | Enforces joint limit constraints, preventing the robot from exceeding mechanical limits.
30 |
31 | .. automodule:: mjinx.components.barriers._joint_barrier
32 | :members:
33 | :member-order: bysource
34 |
35 | Base Body Barrier
36 | -----------------
37 | The foundation for barriers applied to specific bodies, geometries, or sites in the robot model. This abstract class provides common functionality for all object-specific barriers.
38 |
39 | .. automodule:: mjinx.components.barriers._obj_barrier
40 | :members:
41 | :member-order: bysource
42 |
43 | Body Position Barrier
44 | ---------------------
45 | Enforces position constraints on specific bodies, geometries, or sites, useful for defining workspace limits.
46 |
47 | .. automodule:: mjinx.components.barriers._obj_position_barrier
48 | :members:
49 | :member-order: bysource
50 |
51 | Self Collision Barrier
52 | ----------------------
53 | Prevents different parts of the robot from colliding with each other, essential for complex manipulators.
54 |
55 | .. automodule:: mjinx.components.barriers._self_collision_barrier
56 | :members:
57 | :member-order: bysource
--------------------------------------------------------------------------------
/docs/components.rst:
--------------------------------------------------------------------------------
1 | :github_url: https://github.com/based-robotics/mjinx/tree/main/docs/components.rst
2 |
3 | ==========
4 | Components
5 | ==========
6 |
7 | MJINX employs a component-based architecture to formulate inverse kinematics problems through functional decomposition. Each component encapsulates a mathematical mapping from the configuration space to a task or constraint space, functioning either as an objective function (task) or an inequality constraint (barrier).
8 |
9 | This modular structure facilitates the systematic construction of complex kinematic problems through composition of elementary components. The approach aligns with established practices in robotics control theory, where complex behaviors emerge from the coordination of simpler control objectives.
10 |
11 | Components are characterized by several key attributes:
12 |
13 | - A unique identifier for reference within the problem formulation
14 | - A gain parameter that determines its relative weight in the optimization
15 | - An optional dimensional mask for selective application
16 | - A differentiable function mapping robot state to output values
17 |
18 | This formulation follows the task-priority framework established in robotics literature, where multiple objectives are managed through appropriate weighting and prioritization. The separation of concerns between tasks and constraints provides a natural expression of both the desired behavior and the feasible region of operation.
19 |
20 | When integrated into a Problem instance, components form a well-posed optimization problem. Tasks define the objective function to be minimized, while barriers establish the constraint manifold. The solver then computes solutions that optimize the weighted task objectives while maintaining feasibility with respect to all constraints.
21 |
22 | **************
23 | Base Component
24 | **************
25 |
26 | The Component class serves as the abstract base class from which all specific component implementations derive. This inheritance hierarchy ensures a consistent interface while enabling specialized behavior for different component types.
27 |
28 | .. automodule:: mjinx.components._base
29 | :members:
30 | :special-members: __call__
31 | :member-order: bysource
32 |
33 | ********
34 | Tasks
35 | ********
36 |
37 | Tasks define objective functions that map from configuration space to task space, with the solver minimizing weighted errors between current and desired values :cite:`kanoun2011kinematic`. Each task :math:`f: \mathcal{Q} \rightarrow \mathbb{R}^m` produces an error :math:`e(q) = f(q) - f_{desired}` that is minimized according to :math:`\|e(q)\|^2_W`.
38 |
39 | MJINX provides task implementations for common robotics objectives:
40 |
41 | .. toctree::
42 | :maxdepth: 1
43 |
44 | tasks
45 |
46 | ********
47 | Barriers
48 | ********
49 | Barriers implement inequality constraints through scalar functions :math:`h(q) \geq 0` that create boundaries the solver must respect. Based on control barrier functions (CBFs) :cite:`ames2019control`, these constraints are enforced through differential inequality: :math:`\nabla h(q)^T v \geq -\alpha h(q)`, with :math:`\alpha` controls constraint enforcement and :math:`v` is the velocity vector.
50 |
51 | MJINX provides several barrier implementations:
52 |
53 | .. toctree::
54 | :maxdepth: 1
55 |
56 | barriers
57 |
58 | ***********
59 | Constraints
60 | ***********
61 |
62 | Constraints represent a new type of component with a strictly enforced equality condition :math:`f(q) = 0`. Those constraints might be either treated strictly as differentiated exponentially stable equality: :math:`\nabla h(q)^T v = -\alpha h(q)`, with :math:`\alpha` controls constraint enforcement and :math:`v` is the velocity vector, or as a soft constraint -- task with high gain.
63 |
64 | Yet, only the following constraints are implemented:
65 |
66 | .. toctree::
67 | :maxdepth: 1
68 |
69 | constraints
--------------------------------------------------------------------------------
/docs/configuration.rst:
--------------------------------------------------------------------------------
1 | :github_url: https://github.com/based-robotics/mjinx/tree/docs/github_pages/docs/configuration.rst
2 |
3 | =============
4 | Configuration
5 | =============
6 |
7 | The configuration module serves as the mathematical foundation of MJINX, providing essential utilities for robot kinematics, transformations, and state management. These core functions enable precise control and manipulation of robotic systems throughout the library.
8 |
9 | The module is structured into three complementary categories, each addressing a critical aspect of robot configuration:
10 |
11 | 1. **Model** - Functions for manipulating the MuJoCo model and managing robot state
12 | 2. **Lie Algebra** - Specialized tools for handling rotations and transformations with mathematical rigor
13 | 3. **Collision** - Algorithms for detecting and responding to potential collisions
14 |
15 | ******
16 | Model
17 | ******
18 |
19 | The model operations provide fundamental capabilities for state integration, Jacobian computation, and frame transformations. These functions form the bridge between abstract mathematical representations and the physical robot model.
20 |
21 |
22 | .. automodule:: mjinx.configuration._model
23 | :members:
24 | :undoc-members:
25 | :show-inheritance:
26 | :member-order: bysource
27 |
28 | ************
29 | Lie Algebra
30 | ************
31 |
32 | The Lie algebra operations implement sophisticated mathematical tools for handling rotations, quaternions, and transformations in 3D space. Based on principles from differential geometry, these functions ensure proper handling of the SE(3) and SO(3) Lie groups.
33 |
34 | .. automodule:: mjinx.configuration._lie
35 | :members:
36 | :undoc-members:
37 | :show-inheritance:
38 | :member-order: bysource
39 |
40 | **********
41 | Collision
42 | **********
43 |
44 | The collision operations provide sophisticated algorithms for detecting potential collisions, computing distances between objects, and analyzing contact points. These functions are crucial for implementing safety constraints and realistic physical interactions.
45 |
46 | .. automodule:: mjinx.configuration._collision
47 | :members:
48 | :undoc-members:
49 | :show-inheritance:
50 | :member-order: bysource
--------------------------------------------------------------------------------
/docs/constraints.rst:
--------------------------------------------------------------------------------
1 | :github_url: https://github.com/based-robotics/mjinx/tree/docs/github_pages/docs/constraints.rst
2 |
3 | .. _Constraints:
4 |
5 | Constraints represent a new type of component with a strictly enforced equality, either hardly or softly.
6 |
7 | The constratin is formulated as :math:`f(q) = 0`. Those constraints might be either treated strictly as differentiated exponentially stable equality: :math:`\nabla h(q)^T v = -\alpha h(q)`, with :math:`\alpha` controls constraint enforcement and :math:`v` is the velocity vector, or as a soft constraint -- task with high gain.
8 |
9 | Base Constraint
10 | ---------------
11 | The foundational class that all constraints extend. It defines the core interface and mathematical properties for constraint objectives.
12 |
13 | .. automodule:: mjinx.components.constraints._base
14 | :members:
15 | :member-order: bysource
16 |
17 | Model Equality Constraint
18 | -------------------------
19 | The constraints, described in MuJoCo model as equality constrainst.
20 |
21 | .. automodule:: mjinx.components.constraints._equality_constraint
22 | :members:
23 | :member-order: bysource
24 |
25 | .. Base Task
26 | .. ---------
27 | .. The foundational class that all tasks extend. It defines the core interface and mathematical properties for task objectives.
28 |
29 | .. .. automodule:: mjinx.components.tasks._base
30 | .. :members:
31 | .. :member-order: bysource
32 |
--------------------------------------------------------------------------------
/docs/css/custom.css:
--------------------------------------------------------------------------------
1 | /* Jinx color palette */
2 | /* #03587f - dark blue */
3 | /* #a75a97 - pink */
4 | /* #a16f6f - brown */
5 | /* #d9ccc9 - beige */
6 | /* #aecadc - light blue */
7 | a, a:visited, a:focus {
8 | color: rgb(3, 88, 127); /* dark blue */
9 | text-decoration: none;
10 | }
11 |
12 | a:hover, a:active {
13 | color: rgb(3, 88, 127); /* dark blue */
14 | text-decoration: none;
15 | }
16 |
17 | .wy-menu-vertical header, .wy-menu-vertical p.caption {
18 | color: white;
19 | font-weight: 600;
20 | }
21 |
22 | .wy-menu-vertical a {
23 | color: rgb(167, 90, 151); /* pink */
24 | }
25 |
26 | .wy-menu-vertical a:hover {
27 | background: rgb(167, 90, 151); /* pink */
28 | color: white;
29 | }
30 |
31 | .wy-menu-vertical header, .wy-menu-vertical p.caption {
32 | color: black;
33 | font-weight: 600;
34 | }
35 |
36 | .wy-menu-vertical a{
37 | color: black;
38 | }
39 |
40 | .wy-nav-content-wrap, .wy-menu li.current > a {
41 | background-color: white;
42 | }
43 |
44 | .wy-nav-side {
45 | background: rgb(174, 202, 220); /* light blue */
46 | }
47 |
48 | .wy-nav-top {
49 | background: rgb(217, 204, 201); /* beige */
50 | }
51 |
52 | .wy-nav-top a {
53 | color: black;
54 | }
55 |
56 | .wy-nav-top i {
57 | color: rgb(167, 90, 151); /* pink */
58 | }
59 |
60 | .wy-side-nav-search {
61 | background: rgb(3, 88, 127); /* dark blue */
62 | }
63 |
64 | .wy-side-nav-search > div.version {
65 | color: white;
66 | }
67 |
68 | .wy-side-nav-search > a {
69 | color: white;
70 | }
71 |
72 | pre {
73 | background: #efd0f3; /* beige */
74 | border-left: 5px solid #c076b1; /* brown */
75 | }
76 |
77 | html.writer-html4 .rst-content dl:not(.docutils) > dt,
78 | html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.citation):not(.glossary):not(.simple) > dt {
79 | background: rgb(217, 204, 201); /* beige */
80 | border-top: 3px solid #a16f6f; /* brown */
81 | color: rgb(3, 88, 127); /* dark blue */
82 | }
83 |
84 | .rst-content .note {
85 | background: rgb(217, 204, 201); /* beige */
86 | }
87 |
88 | .rst-content .note .admonition-title {
89 | background: rgb(3, 88, 127); /* dark blue */
90 | }
91 |
92 | .rst-content code.literal {
93 | color: rgb(161, 111, 111); /* brown */
94 | }
95 |
--------------------------------------------------------------------------------
/docs/developer-notes.rst:
--------------------------------------------------------------------------------
1 | :github_url: https://github.com/based-robotics/mjinx/tree/docs/github_pages/docs/developer-notes.rst
2 |
3 | =================
4 | Developer Notes
5 | =================
6 |
7 | This section contains information for developers who want to contribute to MJINX or understand its internals better.
8 |
9 | *******************
10 | Code Organization
11 | *******************
12 |
13 | MJINX follows a modular architecture:
14 |
15 | - ``mjinx/components/`` - Task and barrier implementations
16 | - ``mjinx/configuration/`` - Kinematics and transformation utilities
17 | - ``mjinx/solvers/`` - Inverse kinematics solvers
18 | - ``mjinx/visualize.py`` - Visualization tools
19 | - ``mjinx/problem.py`` - Problem construction and management
20 | - ``mjinx/typing.py`` - Type definitions
21 |
22 | **************************
23 | Development Guidelines
24 | **************************
25 |
26 | When contributing to MJINX, please follow these guidelines:
27 |
28 | 1. **Type annotations**: Use type annotations throughout the code.
29 | 2. **Documentation**: Write clear docstrings with reStructuredText format.
30 | 3. **Testing**: Add tests for new features using pytest.
31 | 4. **JAX compatibility**: Ensure new code works with JAX transformations.
32 | 5. **Performance**: Consider computation efficiency, especially for operations in inner loops.
33 |
34 | ******************
35 | JAX Considerations
36 | ******************
37 |
38 | MJINX leverages JAX for automatic differentiation and acceleration. When working with JAX:
39 |
40 | - Use JAX's functional programming style
41 | - Avoid in-place mutations
42 | - Remember that JAX arrays are immutable
43 | - Use ``jit``, ``vmap``, and ``grad`` to leverage JAX's transformations
44 | - Test code with both CPU and GPU devices
45 |
46 | For more information, refer to the `JAX documentation `_.
47 |
48 | ****************
49 | Types
50 | ****************
51 |
52 | MJINX uses a comprehensive type system to ensure code correctness and improve developer experience. Understanding the type system is essential for contributing to the codebase and extending its functionality.
53 |
54 | The type system provides clear interfaces between components, helps catch errors at development time, and makes the code more maintainable. All new contributions should adhere to the established typing conventions.
55 |
56 | For detailed information about MJINX's type system, including type aliases and enumerations, see the :doc:`typing` documentation.
57 |
58 | .. toctree::
59 | :maxdepth: 2
60 |
61 | typing
62 |
--------------------------------------------------------------------------------
/docs/img/kuka_iiwa_14.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/docs/img/kuka_iiwa_14.png
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | :github_url: https://github.com/based-robotics/mjinx/tree/docs/github_pages/docs/index.rst
2 |
3 | .. title:: Table of Contents
4 |
5 | #####
6 | MJINX
7 | #####
8 |
9 | .. raw:: html
10 |
11 |
12 |
13 | |colab| |pypi_version| |pypi_downloads|
14 |
15 | .. |colab| image:: https://colab.research.google.com/assets/colab-badge.svg
16 | :target: https://colab.research.google.com/github/based-robotics/mjinx/blob/main/examples/notebooks/turoial.ipynb
17 | :alt: Open in Colab
18 |
19 | .. |pypi_version| image:: https://img.shields.io/pypi/v/mjinx?color=blue
20 | :target: https://pypi.org/project/mjinx/
21 | :alt: PyPI version
22 |
23 | .. |pypi_downloads| image:: https://img.shields.io/pypi/dm/mjinx?color=blue
24 | :target: https://pypistats.org/packages/mjinx
25 | :alt: PyPI downloads
26 |
27 | **MJINX** is a high-performance library for differentiable inverse kinematics, powered by `JAX `_
28 | and `MuJoCo MJX `_. The library was inspired by the Pinocchio-based tool `PINK `_ and Mujoco-based analogue `MINK `_.
29 |
30 | .. raw:: html
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 | *************
40 | Key Features
41 | *************
42 |
43 | 1. **Flexibility**: Each control problem is assembled via ``Components``, which enforce desired behavior or keep the system within a safety set.
44 | 2. **Multiple solution approaches**: JAX's efficient sampling and autodifferentiation enable various solvers optimized for different scenarios.
45 | 3. **Fully JAX-compatible**: Both the optimization problem and solver support JAX transformations, including JIT compilation and automatic vectorization.
46 | 4. **Convenience**: The API is designed to make complex inverse kinematics problems easy to express and solve.
47 |
48 | *************
49 | Citing MJINX
50 | *************
51 |
52 | If you use MJINX in your research, please cite it as follows:
53 |
54 | .. code-block:: bibtex
55 |
56 | @software{mjinx25,
57 | author = {Domrachev, Ivan and Nedelchev, Simeon},
58 | license = {MIT},
59 | month = mar,
60 | title = {{MJINX: Differentiable GPU-accelerated inverse kinematics in JAX}},
61 | url = {https://github.com/based-robotics/mjinx},
62 | version = {0.1.1},
63 | year = {2025}
64 | }
65 |
66 |
67 | .. toctree::
68 | :maxdepth: 2
69 | :caption: Contents:
70 |
71 | installation.rst
72 | quick_start.rst
73 | problem.rst
74 | configuration.rst
75 | components.rst
76 | solvers.rst
77 | visualization.rst
78 | references.rst
79 | developer-notes.rst
80 |
81 |
82 |
--------------------------------------------------------------------------------
/docs/installation.rst:
--------------------------------------------------------------------------------
1 | :github_url: https://github.com/based-robotics/mjinx/tree/docs/github_pages/docs/installation.rst
2 |
3 | ============
4 | Installation
5 | ============
6 |
7 | **********
8 | From PyPI
9 | **********
10 |
11 | The simplest way to install MJINX is via the Python Package Index:
12 |
13 | .. code:: bash
14 |
15 | pip install mjinx
16 |
17 | ********************
18 | Installation Options
19 | ********************
20 |
21 | MJINX offers several installation options with different dependencies:
22 |
23 | 1. **Visualization tools**: ``pip install mjinx[visual]``
24 | 2. **Example dependencies**: ``pip install mjinx[examples]``
25 | 3. **Development dependencies**: ``pip install mjinx[dev]`` (preferably in editable mode)
26 | 4. **Documentation dependencies**: ``pip install mjinx[docs]``
27 | 5. **Complete installation**: ``pip install mjinx[all]``
28 |
29 | *****************************
30 | From Source (Developer Mode)
31 | *****************************
32 |
33 | For developers or to access the latest features, you can clone the repository and install in editable mode:
34 |
35 | .. code:: bash
36 |
37 | git clone https://github.com/based-robotics/mjinx.git
38 | cd mjinx
39 | pip install -e .
40 |
41 | With editable mode, any changes you make to the source code take effect immediately without requiring reinstallation.
42 |
43 | For development work, we recommend installing with the development extras:
44 |
45 | .. code:: bash
46 |
47 | pip install -e ".[dev]"
--------------------------------------------------------------------------------
/docs/problem.rst:
--------------------------------------------------------------------------------
1 | :github_url: https://github.com/based-robotics/mjinx/tree/docs/github_pages/docs/problem.rst
2 |
3 | .. _Problem:
4 |
5 | =======
6 | Problem
7 | =======
8 |
9 | At the heart of MJINX lies the Problem module - a structured framework that elegantly handles inverse kinematics challenges. This module serves as the central hub where various components come together to form a cohesive mathematical formulation.
10 |
11 | When working with MJINX, a Problem instance orchestrates several key elements:
12 |
13 | - **Tasks**: Objective functions that define desired behaviors, such as reaching specific poses or following trajectories
14 | - **Barriers**: Smooth constraint functions that naturally keep the system within valid states
15 | - **Velocity limits**: Physical bounds on joint velocities to ensure feasible motion
16 |
17 | The module's modular architecture shines in its flexibility - users can begin with simple scenarios like positioning a single end-effector, then naturally build up to complex whole-body motions with multiple objectives and constraints.
18 |
19 | Under the hood, the Problem class transforms these high-level specifications into optimized computations through its JaxProblemData representation. By leveraging JAX's JIT compilation, it ensures that even sophisticated inverse kinematics problems run with maximum efficiency.
20 |
21 | .. automodule:: mjinx.problem
22 | :members:
23 | :member-order: bysource
24 |
25 |
--------------------------------------------------------------------------------
/docs/references.bib:
--------------------------------------------------------------------------------
1 |
2 | @article{kanoun2011kinematic,
3 | title={Kinematic Control of Redundant Manipulators: Generalizing the Task-Priority Framework to Inequality Task},
4 | author={Kanoun, Oussama and Lamiraux, Florent and Wieber, Pierre-Brice},
5 | journal={IEEE Transactions on Robotics},
6 | volume={27},
7 | number={4},
8 | pages={785--792},
9 | year={2011},
10 | publisher={IEEE}
11 | }
12 |
13 | @article{delprete2018joint,
14 | title={Joint Position and Velocity Bounds in Discrete-Time Acceleration/Torque Control of Robot Manipulators},
15 | author={Del Prete, Andrea},
16 | journal={IEEE Robotics and Automation Letters},
17 | volume={3},
18 | number={1},
19 | pages={281--288},
20 | year={2018},
21 | publisher={IEEE}
22 | }
23 |
24 | @article{jackson2021planning,
25 | title={Planning with Attitude},
26 | author={Jackson, Brian E and Tracy, Kevin and Manchester, Zachary},
27 | journal={IEEE Robotics and Automation Letters},
28 | volume={6},
29 | number={3},
30 | pages={5658--5664},
31 | year={2021},
32 | publisher={IEEE}
33 | }
34 |
35 | @inproceedings{ames2019control,
36 | title={Control barrier functions: Theory and applications},
37 | author={Ames, Aaron D and Coogan, Samuel and Egerstedt, Magnus and Notomista, Gennaro and Sreenath, Koushil and Tabuada, Paulo},
38 | booktitle={2019 18th European Control Conference (ECC)},
39 | pages={3420--3431},
40 | year={2019},
41 | organization={IEEE}
42 | }
43 |
44 | @inproceedings{brunke2024practical,
45 | title={Practical Considerations for Discrete-Time Implementations of Continuous-Time Control Barrier Function-Based Safety Filters},
46 | author={Brunke, Lukas and Zhou, Siqi and Che, Mingxuan and Schoellig, Angela P},
47 | booktitle={2024 American Control Conference (ACC)},
48 | pages={272--278},
49 | year={2024},
50 | organization={IEEE}
51 | }
52 |
53 | @misc{jax2018github,
54 | title={{JAX}: Composable Transformations of {Python+NumPy} Programs},
55 | author={Bradbury, James and Frostig, Roy and Hawkins, Peter and Johnson, Matthew J and Leary, Chris and Maclaurin, Dougal and Wanderman-Milne, Skye},
56 | year={2018},
57 | howpublished={\url{https://github.com/google/jax}}
58 | }
--------------------------------------------------------------------------------
/docs/references.rst:
--------------------------------------------------------------------------------
1 | :github_url: https://github.com/based-robotics/mjinx/tree/docs/github_pages/docs/references.rst
2 |
3 | **********
4 | References
5 | **********
6 |
7 | =====================
8 | Academic References
9 | =====================
10 |
11 | This section lists key academic papers and resources that influenced MJINX's design and implementation:
12 |
13 | Citations
14 | ---------
15 |
16 | This project builds on the following academic works:
17 |
18 | .. bibliography::
19 | :style: plain
20 | :all:
21 |
22 | ===================
23 | Software References
24 | ===================
25 |
26 | MJINX builds upon and integrates with the following software libraries:
27 |
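28 | - `JAX <https://github.com/google/jax>`_: Autograd and XLA for high-performance machine learning research.
29 | - `MuJoCo <https://github.com/google-deepmind/mujoco>`_: Physics engine for detailed, efficient robot simulation.
30 | - `PINK <https://github.com/stephane-caron/pink>`_: Python inverse kinematics based on Pinocchio.
31 | - `MINK <https://github.com/kevinzakka/mink>`_: MuJoCo-based inverse kinematics.
32 | - `Optax <https://github.com/google-deepmind/optax>`_: Gradient processing and optimization.
33 | - `JaxLie <https://github.com/brentyi/jaxlie>`_: JAX library for Lie groups.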
34 |
35 | =================
36 | Acknowledgements
37 | =================
38 |
39 | MJINX would not exist without the contributions and inspiration from several sources:
40 |
41 | - Simeon Nedelchev for guidance and contributions during development
42 | - Stéphane Caron and Kevin Zakka, whose work on PINK and MINK respectively provided significant inspiration
43 | - The MuJoCo MJX team for their excellent physics simulation tools
44 | - IRIS lab at KAIST
45 |
46 |
47 | =============
48 | Citing MJINX
49 | =============
50 |
51 | If you use MJINX in your research, please cite it as follows:
52 |
53 | .. code-block:: bibtex
54 |
55 | @software{mjinx25,
56 | author = {Domrachev, Ivan and Nedelchev, Simeon},
57 | license = {MIT},
58 | month = mar,
59 | title = {{MJINX: Differentiable GPU-accelerated inverse kinematics in JAX}},
60 | url = {https://github.com/based-robotics/mjinx},
61 | version = {0.1.1},
62 | year = {2025}
63 | }
64 |
65 |
66 |
--------------------------------------------------------------------------------
/docs/solvers.rst:
--------------------------------------------------------------------------------
1 | :github_url: https://github.com/based-robotics/mjinx/tree/docs/github_pages/docs/solvers.rst
2 |
3 | .. _Solvers:
4 |
5 | ========
6 | Solvers
7 | ========
8 |
9 | MJINX provides multiple solver implementations for inverse kinematics problems, each with different characteristics suitable for various applications.
10 |
11 | ***********
12 | Base Solver
13 | ***********
14 |
15 | The abstract base class defining the interface for all solvers.
16 |
17 | .. automodule:: mjinx.solvers._base
18 | :members:
19 | :undoc-members:
20 | :show-inheritance:
21 | :member-order: bysource
22 |
23 |
24 | ***************
25 | Local IK Solver
26 | ***************
27 |
28 | A Quadratic Programming (QP) based solver that linearizes the problem at each step. This solver is efficient for real-time control and tracking applications.
29 |
30 | .. automodule:: mjinx.solvers._local_ik
31 | :members:
32 | :undoc-members:
33 | :show-inheritance:
34 | :member-order: bysource
35 |
36 | ****************
37 | Global IK Solver
38 | ****************
39 |
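40 | A nonlinear optimization solver that directly optimizes joint positions, making it suitable for complex positioning tasks. Like any local optimizer, it is not guaranteed to avoid local minima; starting from several initial configurations (for example, batched with ``jax.vmap``) can help find better solutions.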
41 |
42 | .. automodule:: mjinx.solvers._global_ik
43 | :members:
44 | :undoc-members:
45 | :show-inheritance:
46 | :member-order: bysource
--------------------------------------------------------------------------------
/docs/tasks.rst:
--------------------------------------------------------------------------------
1 | :github_url: https://github.com/based-robotics/mjinx/tree/docs/github_pages/docs/tasks.rst
2 |
3 | .. _Tasks:
4 |
5 | Tasks define the objective functions in an inverse kinematics problem. Each task represents a mapping from the configuration space to a task space, with the solver minimizing the weighted error between the current and desired task values. This approach follows the task-priority framework established in robotics literature :cite:`kanoun2011kinematic`, where multiple objectives are managed through appropriate weighting and prioritization.
6 |
7 | Mathematically, a task is defined as a function :math:`f: \mathcal{Q} \rightarrow \mathbb{R}^m` that maps from the configuration space :math:`\mathcal{Q}` to a task space. The error is computed as :math:`e(q) = f(q) - f_{\mathrm{desired}}`, and the solver minimizes a weighted norm of this error: :math:`\|e(q)\|^2_W`, where :math:`W` is a positive-definite weight matrix.
8 |
9 | All tasks inherit from the base Task class and follow a consistent mathematical formulation, enabling systematic composition of complex behaviors through the combination of elementary tasks.
10 |
11 | Base Task
12 | ---------
13 | The foundational class that all tasks extend. It defines the core interface and mathematical properties for task objectives.
14 |
15 | .. automodule:: mjinx.components.tasks._base
16 | :members:
17 | :member-order: bysource
18 |
19 | Center of Mass Task
20 | -------------------
21 | Controls the position of the robot's center of mass, critical for maintaining balance in legged systems and manipulators.
22 |
23 | .. automodule:: mjinx.components.tasks._com_task
24 | :members:
25 | :member-order: bysource
26 |
27 | Joint Task
28 | ----------
29 | Directly controls joint positions, useful for regularization, posture optimization, and redundancy resolution.
30 |
31 | .. automodule:: mjinx.components.tasks._joint_task
32 | :members:
33 | :member-order: bysource
34 |
35 | Base Body Task
36 | --------------
37 | The foundation for tasks that target specific bodies, geometries, or sites in the robot model. This abstract class provides common functionality for all object-specific tasks.
38 |
39 | .. automodule:: mjinx.components.tasks._obj_task
40 | :members:
41 | :member-order: bysource
42 |
43 | Body Frame Task
44 | ---------------
45 | Controls the complete pose (position and orientation) of a body, geometry, or site using SE(3) representations.
46 |
47 | .. automodule:: mjinx.components.tasks._obj_frame_task
48 | :members:
49 | :member-order: bysource
50 |
51 | Body Position Task
52 | ------------------
53 | Controls just the position of a body, geometry, or site, ignoring orientation. Useful when only positional constraints matter.
54 |
55 | .. automodule:: mjinx.components.tasks._obj_position_task
56 | :members:
57 | :member-order: bysource
58 |
--------------------------------------------------------------------------------
/docs/typing.rst:
--------------------------------------------------------------------------------
1 | :github_url: https://github.com/based-robotics/mjinx/tree/docs/github_pages/docs/typing.rst
2 |
3 | ************
4 | Type System
5 | ************
6 |
7 | MJINX uses Python's type annotations throughout the codebase to enhance code clarity, enable better IDE support, and catch potential errors. This module provides the type definitions and aliases used across the library.
8 |
9 | .. automodule:: mjinx.typing
10 | :members:
11 | :undoc-members:
12 | :show-inheritance:
13 |
14 | ************
15 | Type Aliases
16 | ************
17 |
18 | The following type aliases are defined for common data structures and function signatures:
19 |
20 | .. data:: ndarray
21 | :no-index:
22 | :annotation: = np.ndarray | jnp.ndarray
23 |
24 | Type alias for NumPy or JAX NumPy arrays.
25 |
26 |
27 | .. data:: ArrayOrFloat
28 | :no-index:
29 | :annotation: = ndarray | float
30 |
31 | Type alias for an array or a scalar float value.
32 |
33 |
34 | .. data:: ClassKFunctions
35 | :no-index:
36 | :annotation: = Callable[[ndarray], ndarray]
37 |
38 | Type alias for class K functions — continuous, strictly increasing scalar functions :math:`\alpha` with :math:`\alpha(0) = 0` — expressed as callables that take and return ndarrays.
39 |
40 |
41 | .. data:: CollisionBody
42 | :no-index:
43 | :annotation: = int | str
44 |
45 | Type alias for collision body representation, either as an integer ID or a string name.
46 |
47 |
48 | .. data:: CollisionPair
49 | :no-index:
50 | :annotation: = tuple[int, int]
51 |
52 | Type alias for a pair of collision body IDs.
53 |
54 |
55 | ************
56 | Enumerations
57 | ************
58 |
59 | .. autoclass:: PositionLimitType
60 | :no-index:
61 | :members:
62 | :undoc-members:
63 | :show-inheritance:
64 |
65 | Enumeration of possible position limit types.
66 |
67 | .. attribute:: MIN
68 | :no-index:
69 | :value: 0
70 |
71 | Minimum position limit.
72 |
73 | .. attribute:: MAX
74 | :no-index:
75 | :value: 1
76 |
77 | Maximum position limit.
78 |
79 | .. attribute:: BOTH
80 | :no-index:
81 | :value: 2
82 |
83 | Both minimum and maximum position limits.
--------------------------------------------------------------------------------
/docs/visualization.rst:
--------------------------------------------------------------------------------
1 | :github_url: https://github.com/based-robotics/mjinx/tree/docs/github_pages/docs/visualization.rst
2 |
3 | =============
4 | Visualization
5 | =============
6 |
7 | MJINX provides utilities for visualizing robots, problem components, and optimization results. These visualization tools help with debugging, analysis, and presentation of inverse kinematics solutions.
8 | The visualization module offers:
9 |
10 | - Functions for rendering robot states
11 | - Tools for visualizing task errors and barrier values
12 | - Methods for displaying optimization trajectories
13 | - Integration with popular plotting libraries
14 |
15 | Using visualization alongside MJINX's computational capabilities allows for both development and deployment of robust inverse kinematics solutions.
16 |
17 | .. automodule:: mjinx.visualize
18 | :members:
19 | :member-order: bysource
20 |
21 |
--------------------------------------------------------------------------------
/examples/cassie_squat.py:
--------------------------------------------------------------------------------
1 | import traceback
2 | from time import perf_counter
3 |
4 | import jax
5 | import jax.numpy as jnp
6 | import mujoco as mj
7 | import mujoco.mjx as mjx
8 | import numpy as np
9 | from jaxlie import SE3, SO3
10 | from robot_descriptions.cassie_mj_description import MJCF_PATH
11 |
12 | from mjinx.components.barriers import JointBarrier
13 | from mjinx.components.constraints import ModelEqualityConstraint
14 | from mjinx.components.tasks import ComTask, FrameTask
15 | from mjinx.configuration import integrate, update
16 | from mjinx.problem import Problem
17 | from mjinx.solvers import LocalIKSolver
18 | from mjinx.visualize import BatchVisualizer
19 |
# === Mujoco ===

# Load the Cassie MJCF and create its MJX (JAX) counterpart used by mjinx.
mj_model = mj.MjModel.from_xml_path(MJCF_PATH)
mjx_model = mjx.put_model(mj_model)
# Print the number of generalized coordinates (nq) and DoFs (nv).
print(mjx_model.nq, mjx_model.nv)

# Lower / upper joint range bounds taken from the model.
# NOTE(review): q_min / q_max are not referenced anywhere else in this script.
q_min = mj_model.jnt_range[:, 0].copy()
q_max = mj_model.jnt_range[:, 1].copy()

# --- Mujoco visualization ---
# Initialize render window and launch it at the background.
# Shows 5 semi-transparent models (alpha=0.2). `record=True` is checked via
# `vis.record` at the end of the script to decide whether to save a video.
vis = BatchVisualizer(MJCF_PATH, n_models=5, alpha=0.2, record=True)

# === Mjinx ===
# --- Constructing the problem ---
# Creating problem formulation; v_min / v_max bound the joint velocities.
problem = Problem(mjx_model, v_min=-1, v_max=1)

# Creating components of interest and adding them to the problem.
# Barrier on joint positions, registered under the component name "jnt_range".
joints_barrier = JointBarrier(
    "jnt_range",
    gain=0.1,
)
# Center-of-mass task; mask=[1, 1, 1] tracks all three COM coordinates.
com_task = ComTask("com_task", cost=10.0, gain=50.0, mask=[1, 1, 1])
# Pelvis frame task. The mask presumably selects [x, y, z, rx, ry, rz]
# components, so only the orientation is tracked here — TODO confirm.
torso_task = FrameTask("torso_task", cost=1.0, gain=2.0, obj_name="cassie-pelvis", mask=[0, 0, 0, 1, 1, 1])

# Feet (in stance): hold both feet at their initial poses (targets set below).
# Under the mask layout assumed above, the fifth component (rotation about y)
# is left free — TODO confirm.
left_foot_task = FrameTask(
    "left_foot_task",
    cost=20.0,
    gain=10.0,
    obj_name="left-foot",
    mask=[1, 1, 1, 1, 0, 1],
)
right_foot_task = FrameTask(
    "right_foot_task",
    cost=20.0,
    gain=10.0,
    obj_name="right-foot",
    mask=[1, 1, 1, 1, 0, 1],
)

# Equality constraints from the model; presumably the MJCF <equality>
# constraints (e.g. Cassie's closed kinematic loops) — verify.
model_equality_constraint = ModelEqualityConstraint()

problem.add_component(com_task)
problem.add_component(torso_task)
# TODO: fix this
# NOTE(review): the original note does not say what needs fixing about the
# joint barrier; clarify the issue or remove this TODO.
problem.add_component(joints_barrier)
problem.add_component(left_foot_task)
problem.add_component(right_foot_task)
problem.add_component(model_equality_constraint)
# Initializing solver and its initial state; maxiter=10 presumably caps the
# solver iterations per solve call.
solver = LocalIKSolver(mjx_model, maxiter=10)

# Initializing initial condition: N_batch copies of the "home" keyframe,
# stacked into q with shape (N_batch, nq).
N_batch = 100
q0 = mj_model.keyframe("home").qpos
q = jnp.tile(q0, (N_batch, 1))

# Model data at q0; used below to read the initial COM and body poses.
mjx_data = update(mjx_model, jnp.array(q0))

# body_rootid[0] is the root of the world body (id 0), so this reads the
# subtree COM of the whole model.
com0 = np.array(mjx_data.subtree_com[mjx_model.body_rootid[0]])
com_task.target_com = com0
# Get torso orientation and set it as target. Frame targets are
# [position (3), quaternion (4)]; the position part is zeroed since the torso
# mask above excludes the position components.
torso_id = mjx.name2id(mjx_model, mj.mjtObj.mjOBJ_BODY, "cassie-pelvis")
torso_quat = mjx_data.xquat[torso_id]
torso_task.target_frame = np.concatenate([np.zeros(3), torso_quat])

# Pin each foot to its pose at q0.
left_foot_id = mjx.name2id(mjx_model, mj.mjtObj.mjOBJ_BODY, "left-foot")
left_foot_pos = mjx_data.xpos[left_foot_id]
left_foot_quat = mjx_data.xquat[left_foot_id]
left_foot_task.target_frame = jnp.concatenate([left_foot_pos, left_foot_quat])

right_foot_id = mjx.name2id(mjx_model, mj.mjtObj.mjOBJ_BODY, "right-foot")
right_foot_pos = mjx_data.xpos[right_foot_id]
right_foot_quat = mjx_data.xquat[right_foot_id]
right_foot_task.target_frame = jnp.concatenate([right_foot_pos, right_foot_quat])

# Compiling the problem upon any parameters update
problem_data = problem.compile()
# --- Batching ---
print("Setting up batched computations...")
# One solver state per batch element, each initialized with zero velocity.
solver_data = jax.vmap(solver.init, in_axes=0)(v_init=jnp.zeros((N_batch, mjx_model.nv)))

# Describe which problem-data fields carry a batch axis. Setting target_com to 0
# presumably marks com_task's target as batched along axis 0 (other fields
# unbatched). The resulting object is used as solve's in_axes spec below.
with problem.set_vmap_dimension() as empty_problem_data:
    empty_problem_data.components["com_task"].target_com = 0

# Vmapping solve and integrate functions.
# solve: q and solver_data are batched on axis 0; problem_data follows the
# axes described by empty_problem_data.
solve_jit = jax.jit(
    jax.vmap(
        solver.solve,
        in_axes=(0, 0, empty_problem_data),
    )
)
# integrate: the model and dt are shared (None); q and velocities are batched on
# axis 0. dt is static, so each distinct dt value triggers a recompilation.
integrate_jit = jax.jit(jax.vmap(integrate, in_axes=(None, 0, 0, None)), static_argnames=["dt"])
115 |
# === Control loop ===
dt = 1e-2
ts = np.arange(0, 20, dt)

try:
    # Per-batch phase offsets, spread evenly over one period so the batched
    # robots squat out of phase with each other. They do not depend on time,
    # so they are computed once instead of on every control step.
    phase_offsets = 2 * np.pi * np.arange(N_batch) / N_batch
    # Stride that picks `vis.n_models` evenly spaced configurations to render.
    vis_stride = N_batch // vis.n_models

    for t in ts:
        # Solving the instance of the problem.
        # Squatting COM trajectory: batch element i lowers its COM height by up
        # to 0.3 (model length units) following a sin^2 profile shifted by its
        # own phase offset. Vectorized form of the per-element expression
        #   com0[2] - 0.3 * sin(t + 2*pi*i/N_batch + pi/2) ** 2
        # NOTE(review): x/y targets are fixed at 0.0 rather than com0[:2], and
        # com_task's mask tracks all three axes — confirm this is intended.
        com_targets = np.zeros((N_batch, 3))
        com_targets[:, 2] = com0[2] - 0.3 * np.sin(t + phase_offsets + np.pi / 2) ** 2
        com_task.target_com = com_targets
        # Targets changed, so the problem data must be recompiled before solving.
        problem_data = problem.compile()
        opt_solution, solver_data = solve_jit(q, solver_data, problem_data)
        # Integrating the optimal velocities over one time step.
        q = integrate_jit(
            mjx_model,
            q,
            opt_solution.v_opt,
            dt,
        )
        # --- MuJoCo visualization ---
        vis.update(q[::vis_stride])

except KeyboardInterrupt:
    print("Finalizing the simulation as requested...")
except Exception:
    # Top-level boundary of the example: report the error, then clean up below.
    print(traceback.format_exc())
finally:
    if vis.record:
        vis.save_video(round(1 / dt))
    vis.close()
149 |
--------------------------------------------------------------------------------
/examples/g1_description/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2016-2023 HangZhou YuShu TECHNOLOGY CO.,LTD. ("Unitree Robotics")
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 |
7 | * Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 |
10 | * Redistributions in binary form must reproduce the above copyright notice,
11 | this list of conditions and the following disclaimer in the documentation
12 | and/or other materials provided with the distribution.
13 |
14 | * Neither the name of the copyright holder nor the names of its
15 | contributors may be used to endorse or promote products derived from
16 | this software without specific prior written permission.
17 |
18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 |
--------------------------------------------------------------------------------
/examples/g1_description/README.md:
--------------------------------------------------------------------------------
1 | # Unitree G1 Description (MJCF)
2 |
3 | > [!IMPORTANT]
4 | > Requires MuJoCo 2.3.4 or later.
5 |
6 | ## Overview
7 |
8 | This package contains a simplified robot description (MJCF) of the [G1 Humanoid
9 | Robot](https://www.unitree.com/g1/) developed by [Unitree
10 | Robotics](https://www.unitree.com/). It is derived from the [publicly available
11 | MJCF
12 | description](https://github.com/unitreerobotics/unitree_ros/tree/master/robots/g1_description).
13 |
14 |
15 |
16 |
17 |
18 | ## MJCF derivation steps
19 |
20 | 1. Copied the MJCF description from [g1_description](https://github.com/unitreerobotics/unitree_ros/tree/master/robots/g1_description).
21 | 2. Manually edited the MJCF to extract common properties into the `<default>` section.
22 | 3. Added sites for the IMU, head and feet.
23 | 4. Added IMU sensors (gyro, accelerometer, framequat).
24 | 5. Added stand keyframe.
25 | 6. Added spotlight and tracking light.
26 |
27 | ## License
28 |
29 | This model is released under a [BSD-3-Clause License](LICENSE).
30 |
--------------------------------------------------------------------------------
/examples/g1_description/assets/head_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/head_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/left_ankle_pitch_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_ankle_pitch_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/left_ankle_roll_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_ankle_roll_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/left_elbow_pitch_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_elbow_pitch_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/left_elbow_roll_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_elbow_roll_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/left_five_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_five_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/left_four_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_four_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/left_hip_pitch_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_hip_pitch_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/left_hip_roll_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_hip_roll_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/left_hip_yaw_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_hip_yaw_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/left_knee_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_knee_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/left_one_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_one_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/left_palm_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_palm_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/left_shoulder_pitch_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_shoulder_pitch_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/left_shoulder_roll_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_shoulder_roll_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/left_shoulder_yaw_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_shoulder_yaw_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/left_six_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_six_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/left_three_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_three_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/left_two_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_two_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/left_zero_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/left_zero_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/logo_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/logo_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/pelvis.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/pelvis.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/pelvis_contour_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/pelvis_contour_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/right_ankle_pitch_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_ankle_pitch_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/right_ankle_roll_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_ankle_roll_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/right_elbow_pitch_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_elbow_pitch_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/right_elbow_roll_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_elbow_roll_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/right_five_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_five_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/right_four_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_four_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/right_hip_pitch_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_hip_pitch_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/right_hip_roll_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_hip_roll_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/right_hip_yaw_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_hip_yaw_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/right_knee_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_knee_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/right_one_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_one_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/right_palm_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_palm_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/right_shoulder_pitch_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_shoulder_pitch_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/right_shoulder_roll_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_shoulder_roll_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/right_shoulder_yaw_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_shoulder_yaw_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/right_six_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_six_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/right_three_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_three_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/right_two_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_two_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/right_zero_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/right_zero_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/assets/torso_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/assets/torso_link.STL
--------------------------------------------------------------------------------
/examples/g1_description/g1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/examples/g1_description/g1.png
--------------------------------------------------------------------------------
/examples/g1_description/scene.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
--------------------------------------------------------------------------------
/examples/global_ik.py:
--------------------------------------------------------------------------------
1 | """
2 | Example of Global inverse kinematics for a Kuka iiwa robot.
3 |
4 | NOTE: The Global IK functionality does not yet work as expected and needs further tuning.
5 | This example will be fixed in future updates. Use with caution and expect suboptimal results.
6 | """
7 |
8 | import time
9 | from time import perf_counter
10 | from collections import defaultdict
11 |
12 | import jax
13 | import jax.numpy as jnp
14 | import mujoco as mj
15 | import mujoco.mjx as mjx
16 | import numpy as np
17 | from mujoco import viewer
18 | from optax import adam
19 | from robot_descriptions.iiwa14_mj_description import MJCF_PATH
20 |
21 | from mjinx.components.barriers import JointBarrier, PositionBarrier, SelfCollisionBarrier
22 | from mjinx.components.tasks import FrameTask
23 | from mjinx.configuration import integrate
24 | from mjinx.problem import Problem
25 | from mjinx.solvers import GlobalIKSolver
26 |
print("=== Initializing ===")

# === Mujoco ===
print("Loading MuJoCo model...")
mj_model = mj.MjModel.from_xml_path(MJCF_PATH)
# Single MjData instance: it is handed to the passive viewer below and updated in the loop.
mj_data = mj.MjData(mj_model)

mjx_model = mjx.put_model(mj_model)

q_min = mj_model.jnt_range[:, 0].copy()
q_max = mj_model.jnt_range[:, 1].copy()


# --- Mujoco visualization ---
print("Setting up visualization...")
# Initialize render window and launch it in the background
renderer = mj.Renderer(mj_model)
mj_viewer = viewer.launch_passive(
    mj_model,
    mj_data,
    show_left_ui=False,
    show_right_ui=False,
)

# Initialize a sphere marker for end-effector task
renderer.scene.ngeom += 1
mj_viewer.user_scn.ngeom = 1
mj.mjv_initGeom(
    mj_viewer.user_scn.geoms[0],
    mj.mjtGeom.mjGEOM_SPHERE,
    0.05 * np.ones(3),
    np.array([0.2, 0.2, 0.2]),
    np.eye(3).flatten(),
    np.array([0.565, 0.933, 0.565, 0.4]),
)

# === Mjinx ===
print("Setting up optimization problem...")
# --- Constructing the problem ---
# Creating problem formulation
problem = Problem(mjx_model, v_min=-100, v_max=100)

# Creating components of interest and adding them to the problem
frame_task = FrameTask("ee_task", cost=1, gain=20, obj_name="link7")
position_barrier = PositionBarrier(
    "ee_barrier",
    gain=0.1,
    obj_name="link7",
    limit_type="max",
    p_max=0.3,
    safe_displacement_gain=1e-2,
    mask=[1, 0, 0],
)
joints_barrier = JointBarrier("jnt_range", gain=0.1)
self_collision_barrier = SelfCollisionBarrier(
    "self_collision_barrier",
    gain=1e-4,
    d_min=0.01,
)

problem.add_component(frame_task)
problem.add_component(position_barrier)
problem.add_component(joints_barrier)
problem.add_component(self_collision_barrier)

# Compiling the problem upon any parameters update
problem_data = problem.compile()

# Initializing solver and its initial state
print("Initializing solver...")
solver = GlobalIKSolver(mjx_model, adam(learning_rate=1e-2), dt=1e-2)

# Initial condition
q = np.array(
    [
        -1.4238753,
        -1.7268502,
        -0.84355015,
        2.0962472,
        2.1339328,
        2.0837479,
        -2.5521986,
    ]
)
solver_data = solver.init(q)

# Jit-compiling the key functions for better efficiency
solve_jit = jax.jit(solver.solve)
integrate_jit = jax.jit(integrate, static_argnames=["dt"])

t_warmup = perf_counter()
print("Performing warmup calls...")
# Warmup iteration for JIT compilation
frame_task.target_frame = np.array([0.2, 0.2, 0.2, 1, 0, 0, 0])
problem_data = problem.compile()
opt_solution, solver_data = solve_jit(q, solver_data, problem_data)
# JAX dispatches work asynchronously: wait for the result so the measured time
# includes both compilation and execution, not just the dispatch.
jax.block_until_ready(opt_solution)

t_warmup_duration = perf_counter() - t_warmup
print(f"Warmup completed in {t_warmup_duration:.3f} seconds")

# === Control loop ===
print("\n=== Starting main loop ===")
dt = 1e-2
ts = np.arange(0, 20, dt)

# Performance tracking
solve_times = []
n_steps = 0

try:
    for t in ts:
        # Changing desired values (t is a host scalar, so plain NumPy is enough here)
        frame_task.target_frame = np.array([0.2 + 0.2 * np.sin(t) ** 2, 0.2, 0.2, 1, 0, 0, 0])

        # After changes, recompiling the model
        problem_data = problem.compile()

        # Solving the instance of the problem; block so the timing covers the
        # actual computation rather than the asynchronous dispatch only.
        t1 = perf_counter()
        opt_solution, solver_data = solve_jit(q, solver_data, problem_data)
        jax.block_until_ready(opt_solution)
        solve_times.append(perf_counter() - t1)

        # Two options for retrieving q:
        # Option 1, integrating:
        # q = integrate_jit(mjx_model, q, opt_solution.v_opt, dt=dt)
        # Option 2, direct:
        q = opt_solution.q_opt

        # --- MuJoCo visualization ---
        mj_data.qpos = q
        # Run forward kinematics so that xpos reflects the updated configuration
        mj.mj_forward(mj_model, mj_data)
        print(f"Position barrier: {mj_data.xpos[position_barrier.obj_id][0]} <= {position_barrier.p_max[0]}")
        # Only the viewer's user scene changes here, so no second mj_forward is needed
        mj.mjv_initGeom(
            mj_viewer.user_scn.geoms[0],
            mj.mjtGeom.mjGEOM_SPHERE,
            0.05 * np.ones(3),
            np.array(frame_task.target_frame.wxyz_xyz[-3:], dtype=np.float64),
            np.eye(3).flatten(),
            np.array([0.565, 0.933, 0.565, 0.4]),
        )
        mj_viewer.sync()
        n_steps += 1
except KeyboardInterrupt:
    print("\nSimulation interrupted by user")
except Exception as e:
    import traceback

    print(f"\nError occurred: {e}")
    # Keep the full traceback instead of only the message, so failures are debuggable
    traceback.print_exc()
finally:
    renderer.close()
    mj_viewer.close()

# Print performance report
print("\n=== Performance Report ===")
print(f"Total steps completed: {n_steps}")
print("\nComputation times per step:")
if solve_times:
    avg_solve = sum(solve_times) / len(solve_times)
    std_solve = np.std(solve_times)
    print(f"solve : {avg_solve * 1000:8.3f} ± {std_solve * 1000:8.3f} ms")
    print(f"\nAverage computation time per step: {avg_solve * 1000:.3f} ms")
    print(f"Effective computation rate: {1 / avg_solve:.1f} Hz")
196 |
--------------------------------------------------------------------------------
/examples/global_ik_vmapped_input.py:
--------------------------------------------------------------------------------
1 | """
2 | Example of Global inverse kinematics for a Kuka iiwa robot with vmapped input.
3 | This demonstrates how to use JAX's vmap to efficiently compute IK solutions for multiple initial configurations.
4 |
5 | NOTE: The Global IK functionality does not yet work as expected and needs further tuning.
6 | This example will be fixed in future updates. Use with caution and expect suboptimal results.
7 | """
8 |
9 | import jax
10 | import jax.numpy as jnp
11 | import mujoco as mj
12 | import mujoco.mjx as mjx
13 | import numpy as np
14 | from optax import adam
15 | from robot_descriptions.iiwa14_mj_description import MJCF_PATH
16 | from time import perf_counter
17 |
18 | from mjinx.components.barriers import JointBarrier, PositionBarrier
19 | from mjinx.components.tasks import FrameTask
20 | from mjinx.configuration import integrate
21 | from mjinx.problem import Problem
22 | from mjinx.solvers import GlobalIKSolver
23 | from mjinx.visualize import BatchVisualizer
24 |
print("=== Initializing ===")

# === Mujoco ===
print("Loading MuJoCo model...")
mj_model = mj.MjModel.from_xml_path(MJCF_PATH)
mjx_model = mjx.put_model(mj_model)

q_min = mj_model.jnt_range[:, 0].copy()
q_max = mj_model.jnt_range[:, 1].copy()


# --- Mujoco visualization ---
print("Setting up visualization...")
vis = BatchVisualizer(MJCF_PATH, n_models=5, alpha=0.5, record=False)

# Initialize a sphere marker for end-effector task
vis.add_markers(
    name="ee_marker",
    size=0.05,
    marker_alpha=0.4,
    color_begin=np.array([0, 1.0, 0.53]),
)
vis.add_markers(
    name="blocking_plane",
    marker_type=mj.mjtGeom.mjGEOM_PLANE,
    size=np.array([0.5, 0.5, 0.02]),
    marker_alpha=0.7,
    color_begin=np.array([1, 0, 0]),
)

# === Mjinx ===
print("Setting up optimization problem...")
# --- Constructing the problem ---
# Creating problem formulation
problem = Problem(mjx_model)

# Creating components of interest and adding them to the problem
frame_task = FrameTask("ee_task", cost=1, gain=20, obj_name="link7")
position_barrier = PositionBarrier(
    "ee_barrier",
    gain=0.1,
    obj_name="link7",
    limit_type="max",
    p_max=0.5,
    safe_displacement_gain=1e-2,
    mask=[1, 0, 0],
)
joints_barrier = JointBarrier("jnt_range", gain=0.1)
# Place the plane marker at the barrier's x-limit so the limit is visible
vis.marker_data["blocking_plane"].pos = np.array([position_barrier.p_max[0], 0, 0.3])
vis.marker_data["blocking_plane"].rot = np.array(
    [
        [0, 0, -1],
        [0, 1, 0],
        [1, 0, 0],
    ]
)

problem.add_component(frame_task)
problem.add_component(position_barrier)
problem.add_component(joints_barrier)

# Compiling the problem upon any parameters update
problem_data = problem.compile()

# Initializing solver and its initial state
print("Initializing solver...")
solver = GlobalIKSolver(mjx_model, adam(learning_rate=1e-2), dt=1e-2)

# Initializing initial condition: a batch of randomly perturbed copies of q0,
# clipped to stay strictly inside the joint limits
N_batch = 100
q0 = np.array(
    [
        -1.4238753,
        -1.7268502,
        -0.84355015,
        2.0962472,
        2.1339328,
        2.0837479,
        -2.5521986,
    ]
)
q = jnp.array(
    [
        np.clip(
            q0
            + np.random.uniform(
                -0.1,
                0.1,
                size=(mj_model.nq),
            ),
            q_min + 1e-1,
            q_max - 1e-1,
        )
        for _ in range(N_batch)
    ]
)


# --- Batching ---
print("Setting up batched computations...")
# First of all, data should be created via vmapped init function
solver_data = jax.vmap(solver.init, in_axes=0)(q)

# Vmapping solve and integrate functions.
solve_jit = jax.jit(jax.vmap(solver.solve, in_axes=(0, 0, None)))
integrate_jit = jax.jit(jax.vmap(integrate, in_axes=(None, 0, 0, None)), static_argnames=["dt"])

t_warmup = perf_counter()
print("Performing warmup calls...")
# Warmup iteration for JIT compilation
frame_task.target_frame = np.array([0.4, 0.2, 0.7, 1, 0, 0, 0])
problem_data = problem.compile()
opt_solution, solver_data = solve_jit(q, solver_data, problem_data)
# JAX dispatches asynchronously: wait so the timing includes execution, not only dispatch
jax.block_until_ready(opt_solution)

t_warmup_duration = perf_counter() - t_warmup
print(f"Warmup completed in {t_warmup_duration:.3f} seconds")

# === Control loop ===
print("\n=== Starting main loop ===")
dt = 1e-2
ts = np.arange(0, 20, dt)

# Performance tracking
solve_times = []
n_steps = 0

try:
    for t in ts:
        # Changing desired values
        frame_task.target_frame = np.array([0.4 + 0.3 * np.sin(t), 0.2, 0.4 + 0.3 * np.cos(t), 1, 0, 0, 0])

        # After changes, recompiling the model
        problem_data = problem.compile()

        # Solving the instance of the problem: several optimizer iterations per
        # control step. Each iteration must start from the previous iterate;
        # re-solving from the same q would only advance the optimizer state.
        t1 = perf_counter()
        for _ in range(3):
            opt_solution, solver_data = solve_jit(q, solver_data, problem_data)
            # Two options for retrieving q:
            # Option 1, integrating (dt positionally: vmap would try to map a keyword float):
            # q = integrate_jit(mjx_model, q, opt_solution.v_opt, dt)
            # Option 2, direct:
            q = opt_solution.q_opt
        # Block so the timing covers the actual computation, not just the dispatch
        jax.block_until_ready(opt_solution)
        solve_times.append(perf_counter() - t1)

        # --- MuJoCo visualization ---
        vis.marker_data["ee_marker"].pos = np.array(frame_task.target_frame.wxyz_xyz[-3:])
        vis.update(q[: vis.n_models])
        n_steps += 1

except KeyboardInterrupt:
    print("\nSimulation interrupted by user")
except Exception as e:
    import traceback

    print(f"\nError occurred: {e}")
    # Keep the full traceback instead of only the message, so failures are debuggable
    traceback.print_exc()
finally:
    if vis.record:
        vis.save_video(round(1 / dt))
    vis.close()

# Print performance report
print("\n=== Performance Report ===")
print(f"Total steps completed: {n_steps}")
print("\nComputation times per step:")
if solve_times:
    avg_solve = sum(solve_times) / len(solve_times)
    std_solve = np.std(solve_times)
    print(f"solve : {avg_solve * 1000:8.3f} ± {std_solve * 1000:8.3f} ms")
    print(f"\nAverage computation time per step: {avg_solve * 1000:.3f} ms")
    print(f"Effective computation rate: {1 / avg_solve:.1f} Hz")
200 |
--------------------------------------------------------------------------------
/img/cassie_caravan.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/img/cassie_caravan.gif
--------------------------------------------------------------------------------
/img/g1_heart.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/img/g1_heart.gif
--------------------------------------------------------------------------------
/img/go2_stance.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/img/go2_stance.gif
--------------------------------------------------------------------------------
/img/local_ik_input.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/img/local_ik_input.gif
--------------------------------------------------------------------------------
/img/local_ik_output.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/img/local_ik_output.gif
--------------------------------------------------------------------------------
/img/logo.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/mjinx/__init__.py:
--------------------------------------------------------------------------------
"""Differentiable GPU-accelerated inverse kinematics based on JAX and MuJoCo MJX.

This package provides a framework for solving inverse kinematics problems using
differentiable programming. It leverages JAX's automatic differentiation
and GPU acceleration alongside MuJoCo MJX's physics simulation capabilities.
"""

# Package version string.
__version__ = "0.1.0"
9 |
--------------------------------------------------------------------------------
/mjinx/components/__init__.py:
--------------------------------------------------------------------------------
1 | from ._base import Component, JaxComponent
2 |
3 | __all__ = ["Component", "JaxComponent"]
4 |
--------------------------------------------------------------------------------
/mjinx/components/barriers/__init__.py:
--------------------------------------------------------------------------------
1 | from ._base import Barrier, JaxBarrier
2 | from ._obj_barrier import ObjBarrier, JaxObjBarrier
3 | from ._obj_position_barrier import PositionBarrier, PositionLimitType
4 | from ._joint_barrier import JaxJointBarrier, JointBarrier
5 | from ._self_collision_barrier import JaxSelfCollisionBarrier, SelfCollisionBarrier
6 |
7 | __all__ = [
8 | "Barrier",
9 | "JaxBarrier",
10 | "ObjBarrier",
11 | "JaxObjBarrier",
12 | "PositionBarrier",
13 | "PositionLimitType",
14 | "JaxJointBarrier",
15 | "JointBarrier",
16 | "JaxSelfCollisionBarrier",
17 | "SelfCollisionBarrier",
18 | ]
19 |
--------------------------------------------------------------------------------
/mjinx/components/barriers/_base.py:
--------------------------------------------------------------------------------
1 | from typing import Generic, TypeVar, Callable
2 |
3 | import jax.numpy as jnp
4 | import jax_dataclasses as jdc
5 | import mujoco.mjx as mjx
6 |
7 | from mjinx.components import Component, JaxComponent
8 |
9 |
@jdc.pytree_dataclass
class JaxBarrier(JaxComponent):
    r"""
    A base class for implementing barrier functions in JAX.

    This class provides a framework for creating barrier functions that can be used
    in optimization problems, particularly for constraint handling in robotics applications.

    A barrier function is defined mathematically as a function:

    .. math::

        h(q) \geq 0

    where the constraint is satisfied when h(q) is non-negative. As h(q) approaches zero,
    the barrier "activates" to prevent constraint violation. In optimization problems,
    the barrier helps enforce constraints by adding a penalty term that increases rapidly
    as the system approaches constraint boundaries.

    :param safe_displacement_gain: The gain for computing safe displacements.
    """

    safe_displacement_gain: float

    def compute_barrier(self, data: mjx.Data) -> jnp.ndarray:
        """
        Compute the barrier function value.

        This method calculates the value of the barrier function h(q) at the current state
        by delegating to ``__call__``. The barrier is active when h(q) is close to zero and
        satisfied when h(q) > 0.

        :param data: The MuJoCo simulation data.
        :return: The computed barrier value.
        """
        return self.__call__(data)

    def compute_safe_displacement(self, data: mjx.Data) -> jnp.ndarray:
        r"""
        Compute a safe displacement to move away from constraint boundaries.

        Subclasses may override this method to return a displacement in joint space
        that pushes the system away from the constraint boundary, for example

        .. math::

            \Delta q_{safe} = \alpha \nabla h(q)

        where :math:`\alpha` is the safe_displacement_gain and :math:`\nabla h(q)` is the
        gradient of the barrier function.

        The base implementation does not apply any correction: it returns a zero vector
        of length ``self.model.nv``, regardless of ``data``.

        :param data: The MuJoCo simulation data.
        :return: A joint displacement vector (all zeros in this base class).
        """
        return jnp.zeros(self.model.nv)


AtomicBarrierType = TypeVar("AtomicBarrierType", bound=JaxBarrier)
68 |
69 |
class Barrier(Generic[AtomicBarrierType], Component[AtomicBarrierType]):
    r"""
    High-level wrapper around an atomic (JAX) barrier implementation.

    Barriers express constraints of the form :math:`h(q) \geq 0` and act like a
    potential field that keeps the system away from constraint boundaries. Inside an
    optimization problem they appear either as inequality constraints

    .. math::

        \min_{q} f(q) \quad \text{subject to} \quad h(q) \geq 0

    or as a penalty term in the objective

    .. math::

        \min_{q} f(q) - \lambda \log(h(q))

    where :math:`\lambda` is a weight parameter.

    :param safe_displacement_gain: Gain used when computing safe displacements.
    """

    # Gain forwarded to the atomic barrier for its safe-displacement computation.
    safe_displacement_gain: float

    def __init__(self, name, gain, gain_fn=None, safe_displacement_gain=0, mask=None):
        """
        Create a barrier component.

        :param name: Name of the barrier.
        :param gain: Gain of the barrier function.
        :param gain_fn: Optional function used to compute the gain dynamically.
        :param safe_displacement_gain: Gain used when computing safe displacements (defaults to 0).
        :param mask: Optional sequence of integers selecting which dimensions are used.
        """
        # Common component setup (name, gain, gain function, mask) lives in Component.
        super().__init__(name, gain, gain_fn=gain_fn, mask=mask)
        self.safe_displacement_gain = safe_displacement_gain
111 |
--------------------------------------------------------------------------------
/mjinx/components/constraints/__init__.py:
--------------------------------------------------------------------------------
1 | from ._base import Constraint, JaxConstraint
2 | from ._equality_constraint import JaxModelEqualityConstraint, ModelEqualityConstraint
3 |
4 | __all__ = [
5 | "Constraint",
6 | "JaxConstraint",
7 | "JaxModelEqualityConstraint",
8 | "ModelEqualityConstraint",
9 | ]
10 |
--------------------------------------------------------------------------------
/mjinx/components/constraints/_base.py:
--------------------------------------------------------------------------------
1 | from typing import Generic, Sequence, TypeVar, Callable
2 |
3 | import jax.numpy as jnp
4 | import jax_dataclasses as jdc
5 | import mujoco.mjx as mjx
6 |
7 | from mjinx.components import Component, JaxComponent
8 | from mjinx.typing import ArrayLike, ArrayOrFloat
9 |
10 |
@jdc.pytree_dataclass
class JaxConstraint(JaxComponent):
    """
    JAX-side representation of an equality constraint ``f(x) = 0``.

    Subclasses provide ``__call__``, which evaluates ``f`` on MuJoCo data; this class
    adds the fields that describe how the constraint is enforced.

    :param active: Whether the constraint is active.
    :param hard_constraint: True if the constraint is enforced strictly (hard), False if it is relaxed (soft).
    :param soft_constraint_cost: Cost matrix that penalizes the residual when the constraint is soft.
    """

    # Whether the constraint currently takes part in the problem.
    active: bool
    # Declared jdc.Static, i.e. treated as a compile-time constant rather than a traced value.
    hard_constraint: jdc.Static[bool]
    # Penalty matrix applied to the residual of a soft constraint.
    soft_constraint_cost: jnp.ndarray

    def compute_constraint(self, data: mjx.Data) -> jnp.ndarray:
        """
        Evaluate the constraint function ``f(x)`` for the given simulation data.

        Soft constraints have this value penalized by their cost later on, while hard
        constraints use it directly in the ``Ax = b`` formulation.

        :param data: The MuJoCo simulation data.
        :return: The constraint value.
        """
        residual = self.__call__(data)
        return residual


AtomicConstraintType = TypeVar("AtomicConstraintType", bound=JaxConstraint)
42 |
43 |
44 | class Constraint(Generic[AtomicConstraintType], Component[AtomicConstraintType]):
45 | r"""
46 | A high-level component for formulating equality constraints.
47 |
48 | This class wraps an atomic JAX constraint (JaxConstraint) and provides a framework to
49 | manage constraints within the optimization problem. Equality constraints are specified as:
50 |
51 | .. math::
52 |
53 | f(x) = 0
54 |
55 | They can be treated in one of two ways:
56 | - As a soft constraint. In this mode, the constraint violation is penalized by a cost
57 | (typically with a high gain), transforming the constraint into a task.
58 | - As a hard constraint. Here, the constraint is enforced as a strict equality in the following form:
59 |
60 | .. math::
61 |
62 | \nabla h(q)^T v = -\alpha h(q),
63 |
64 | where :math:`\alpha` controls constraint enforcement and :math:`v` is the velocity vector.
65 |
66 | :param matrix_cost: The cost matrix associated with the task.
67 | :param lm_damping: The Levenberg-Marquardt damping factor.
68 | :param active: Determines if the constraint is active.
69 | :param hard_constraint: Indicates whether the constraint is hard (True) or soft (False).
70 | """
71 |
72 | active: bool
73 | _soft_constraint_cost: jnp.ndarray | None
74 | hard_constraint: bool
75 |
76 | def __init__(
77 | self,
78 | name: str,
79 | gain: ArrayOrFloat,
80 | mask: Sequence[int] | None = None,
81 | hard_constraint: bool = False,
82 | soft_constraint_cost: ArrayLike | None = None,
83 | ):
84 | """
85 | Initialize a Constraint object.
86 |
87 | Sets up the constraint with the given parameters. Depending on whether it is
88 | a hard or soft constraint, further integration in the optimization problem will
89 | process it accordingly.
90 |
91 | :param name: The unique identifier for the constraint.
92 | :param gain: The gain for the constraint function, affecting its impact.
93 | :param mask: Indices to select specific dimensions for evaluation. If not provided, applies to all dimensions.
94 | :param hard_constraint: If True, the constraint is handled as a hard constraint. Defaults to False (i.e. soft constraint).
95 | :param soft_constraint_cost: The cost used to relax a soft constraint. If not provided, a default high gain cost matrix (scaled identity matrix) based on the component dimension will be used.
96 | """
97 | super().__init__(name, gain, gain_fn=None, mask=mask)
98 | self.active = True
99 | self.hard_constraint = hard_constraint
100 | self._soft_constraint_cost = jnp.array(soft_constraint_cost) if soft_constraint_cost is not None else None
101 |
102 | @property
103 | def soft_constraint_cost(self) -> jnp.ndarray:
104 | """
105 | Get the cost matrix associated with a soft constraint.
106 |
107 | For soft constraints, any violation is penalized by a cost if self._soft_constraint_cost is not None else 1e2 * jnp.eye(self.dim)"
108 |
109 | :return: The cost matrix for the soft constraint.
110 | """
111 | if self._soft_constraint_cost is None:
112 | return 1e2 * jnp.eye(self.dim)
113 |
114 | match self._soft_constraint_cost.ndim:
115 | case 0:
116 | return jnp.eye(self.dim) * self._soft_constraint_cost
117 | case 1:
118 | if len(self._soft_constraint_cost) != self.dim:
119 | raise ValueError(
120 | f"fail to construct matrix jnp.diag(({self.dim},)) from vector of length {self._soft_constraint_cost.shape}"
121 | )
122 | return jnp.diag(self._soft_constraint_cost)
123 | case 2:
124 | if self._soft_constraint_cost.shape != (
125 | self.dim,
126 | self.dim,
127 | ):
128 | raise ValueError(
129 | f"wrong shape of the cost: {self._soft_constraint_cost.shape} != ({self.dim}, {self.dim},)"
130 | )
131 | return self._soft_constraint_cost
132 | case _: # pragma: no cover
133 | raise ValueError("fail to construct cost, given dim > 2")
134 |
--------------------------------------------------------------------------------
/mjinx/components/constraints/_equality_constraint.py:
--------------------------------------------------------------------------------
1 | from collections.abc import Sequence
2 | from typing import TypeVar
3 |
4 | import jax.numpy as jnp # noqa: F401
5 | import jax_dataclasses as jdc
6 | import mujoco as mj
7 | import mujoco.mjx as mjx
8 |
9 | from mjinx.components.constraints._base import Constraint, JaxConstraint
10 | import mjinx.typing
11 |
12 |
@jdc.pytree_dataclass
class JaxModelEqualityConstraint(JaxConstraint):
    """
    A JAX-based equality constraint built from the simulation model's own equality constraints.

    It covers every constraint declared in the ``<equality>`` section of the MuJoCo XML model
    (see https://mujoco.readthedocs.io/en/stable/XMLreference.html#equality).

    The class relies on the equality constraints being evaluated together with the kinematics:
    their residuals are read from ``data.efc_pos`` and their Jacobians from ``data.efc_J``.
    Rows are selected where ``data.efc_type`` equals ``mjCNSTR_EQUALITY`` and then filtered
    by ``mask_idxs``.
    """

    def __call__(self, data: mjx.Data) -> jnp.ndarray:
        """
        Return the equality-constraint residuals contained in the simulation data.

        :param data: The MuJoCo simulation data.
        :return: The residuals of the equality rows of ``data.efc_pos``, restricted to ``mask_idxs``.
        """
        # NOTE(review): boolean-mask indexing needs a concrete (non-traced) mask in JAX;
        # confirm efc_type is concrete wherever this runs under jax.jit.
        is_equality_row = data.efc_type == mj.mjtConstraint.mjCNSTR_EQUALITY
        equality_residuals = data.efc_pos[is_equality_row]
        # The trailing comma wraps mask_idxs in a one-element index tuple, so it acts as a
        # single (fancy) index along the first axis even if mask_idxs is itself a tuple.
        return equality_residuals[self.mask_idxs,]

    def compute_jacobian(self, data: mjx.Data) -> jnp.ndarray:
        """
        Return the Jacobian rows of the equality constraints.

        :param data: The MuJoCo simulation data.
        :return: The equality rows of ``data.efc_J``, restricted to ``mask_idxs``.
        """
        is_equality_row = data.efc_type == mj.mjtConstraint.mjCNSTR_EQUALITY
        equality_jacobian = data.efc_J[is_equality_row, :]
        return equality_jacobian[self.mask_idxs, :]


AtomicModelEqualityConstraintType = TypeVar("AtomicModelEqualityConstraintType", bound=JaxModelEqualityConstraint)
54 |
55 |
56 | class ModelEqualityConstraint(Constraint[AtomicModelEqualityConstraintType]):
57 | """
58 | High-level component that aggregates all equality constraints from the simulation model.
59 |
60 | The main purpose of the wrapper is to recalculate mask based on the dimensions of the equality constrain
61 | and compute proper dimensionality.
62 |
63 |
64 | :param name: The unique identifier for the constraint.
65 | :param gain: The gain for the constraint function, affecting its impact.
66 | :param hard_constraint: If True, the constraint is handled as a hard constraint.
67 | Defaults to False (i.e., soft constraint).
68 | :param soft_constraint_cost: The cost used to relax a soft constraint. If not provided, a default high gain
69 | cost matrix (scaled identity matrix) based on the component dimension will be used.
70 | """
71 |
72 | JaxComponentType = JaxModelEqualityConstraint
73 |
74 | def __init__(
75 | self,
76 | name: str = "equality_constraint",
77 | gain: mjinx.typing.ArrayOrFloat = 100,
78 | hard_constraint: bool = False,
79 | soft_constraint_cost: mjinx.typing.ArrayLike | None = None,
80 | ):
81 | """
82 | Initialize a new equality constraint.
83 |
84 | Args:
85 | name (str): A unique identifier for the constraint. Defaults to 'equality_constraint'.
86 | gain (ArrayOrFloat): The gain used to weight or scale the constraint. Can be a float or an array-like type. Defaults to 100.
87 | hard_constraint (bool): Flag to determine if the constraint should be enforced strictly (hard constraint). Defaults to False.
88 | soft_constraint_cost (ArrayLike | None): Optional cost associated with treating the constraint as a soft constraint.
89 | When provided, this cost is used to penalize deviations.
90 |
91 | Attributes:
92 | active (bool): Indicates whether the constraint is active. Initialized to True.
93 | """
94 | super().__init__(
95 | name=name,
96 | gain=gain,
97 | mask=None,
98 | hard_constraint=hard_constraint,
99 | soft_constraint_cost=soft_constraint_cost,
100 | )
101 | self.active = True
102 |
103 | def update_model(self, model: mj.MjModel) -> None:
104 | """
105 | Update the equality constraint using the provided MuJoCo model.
106 |
107 | This method computes the total dimension of the equality constraints by iterating
108 | over the equality types in the model. Based on the type of each equality constraint,
109 | it sets the mask indices and component dimension accordingly.
110 |
111 | :param model: The MuJoCo model instance.
112 | :return: None
113 | """
114 | super().update_model(model)
115 |
116 | nefc = 0
117 | for i in range(self.model.neq):
118 | match self.model.eq_type[i]:
119 | case mj.mjtEq.mjEQ_CONNECT:
120 | nefc += 3
121 | case mj.mjtEq.mjEQ_WELD:
122 | nefc += 6
123 | case mj.mjtEq.mjEQ_JOINT:
124 | nefc += 1
125 | case _:
126 | raise ValueError(f"Unsupported equality constraint type {self.model.eq_type[i]}")
127 |
128 | self._mask_idxs = jnp.arange(nefc)
129 | self._dim = nefc
130 |
--------------------------------------------------------------------------------
/mjinx/components/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | from ._base import JaxTask, Task
2 | from ._obj_frame_task import FrameTask, JaxFrameTask
3 | from ._obj_position_task import JaxPositionTask, PositionTask
4 | from ._obj_task import ObjTask, JaxObjTask
5 | from ._com_task import ComTask, JaxComTask
6 | from ._joint_task import JaxJointTask, JointTask
7 |
# Public API of ``mjinx.components.tasks``: each task is exported both as its
# high-level wrapper and as its JAX counterpart, plus the shared base classes.
__all__ = [
    "JaxTask", "Task",
    "FrameTask", "JaxFrameTask",
    "JaxPositionTask", "PositionTask",
    "ObjTask", "JaxObjTask",
    "ComTask", "JaxComTask",
    "JaxJointTask", "JointTask",
]
22 |
--------------------------------------------------------------------------------
/mjinx/components/tasks/_com_task.py:
--------------------------------------------------------------------------------
1 | """Center of mass task implementation."""
2 |
3 | from collections.abc import Callable, Sequence
4 | from typing import final
5 |
6 | import jax
7 | import jax.numpy as jnp
8 | import jax_dataclasses as jdc
9 | import mujoco
10 | import mujoco.mjx as mjx
11 | from mujoco.mjx._src import scan
12 |
13 | from mjinx.components.tasks._base import JaxTask, Task
14 | from mjinx.configuration import jac_dq2v
15 | from mjinx.typing import ArrayOrFloat
16 |
17 |
@jdc.pytree_dataclass
class JaxComTask(JaxTask):
    r"""
    A JAX-based implementation of a center of mass (CoM) task for inverse kinematics.

    The task drives the robot's center of mass towards a target position.

    The task function maps joint positions to the robot's center of mass position:

    .. math::

        f(q) = \frac{\sum_i m_i p_i(q)}{\sum_i m_i}

    where:
        - :math:`m_i` is the mass of body i
        - :math:`p_i(q)` is the position of body i's center of mass

    The error is the difference between the current and the target CoM positions:

    .. math::

        e(q) = p_c(q) - p_{c_{target}}

    :param target_com: The target center of mass position to be achieved.
    """

    target_com: jnp.ndarray

    @final
    def __call__(self, data: mjx.Data) -> jnp.ndarray:
        r"""
        Compute the error between the current center of mass and the target center of mass.

        The error is given by:

        .. math::

            e(q) = p_c(q) - p_{c_{target}}

        :param data: The MuJoCo simulation data.
        :return: The error vector representing the difference between the current and target center of mass.
        """
        # The CoM is read from the subtree rooted at body_rootid[0] (the root of body 0),
        # restricted to the masked coordinates.
        root_body = self.model.body_rootid[0]
        current_com = data.subtree_com[root_body, self.mask_idxs]
        return current_com - self.target_com

    @final
    def compute_jacobian(self, data: mjx.Data) -> jnp.ndarray:
        """
        Compute the Jacobian of the CoM task by automatic differentiation.

        The task error is differentiated with :func:`jax.jacrev` with respect to ``qpos``. Each
        evaluation recomputes forward kinematics and the subtree centers of mass from scratch.
        When ``nq != nv``, the result is mapped from configuration to velocity coordinates with
        :func:`jac_dq2v`.

        :param data: The MuJoCo simulation data.
        :return: The Jacobian of the task error.
        """

        def accumulate_subtree(carry, xipos, body_mass):
            # Mass-weighted body position and mass, plus whatever the children already accumulated.
            weighted_pos = xipos * body_mass
            if carry is None:
                return weighted_pos, body_mass
            child_pos, child_mass = carry
            return weighted_pos + child_pos, body_mass + child_mass

        def data_with_subtree_com(model: mjx.Model, q: jnp.ndarray) -> mjx.Data:
            fresh = mjx.kinematics(model, mjx.make_data(model).replace(qpos=q))
            pos_sum, mass_sum = scan.body_tree(
                model, accumulate_subtree, "bb", "bb", fresh.xipos, model.body_mass, reverse=True
            )
            near_massless = jnp.tile(mass_sum < mujoco.mjMINVAL, (3, 1)).T
            # Clamp the divisor so the gradient flowing through jnp.where cannot become NaN.
            com = jax.vmap(jnp.divide)(pos_sum, jnp.maximum(mass_sum, mujoco.mjMINVAL))
            com = jnp.where(near_massless, fresh.xipos, com)
            return fresh.replace(subtree_com=com)

        def error_of_q(q, model=self.model):
            return self.__call__(data_with_subtree_com(model, q))

        jacobian = jax.jacrev(error_of_q, argnums=0)(data.qpos)
        if self.model.nq != self.model.nv:
            jacobian = jacobian @ jac_dq2v(self.model, data.qpos)
        return jacobian
95 |
96 |
class ComTask(Task[JaxComTask]):
    """
    A high-level representation of a center of mass (CoM) task for inverse kinematics.

    This class provides an interface for creating and manipulating center of mass tasks,
    which aim to achieve a specific target center of mass for the robot model.

    :param name: The name of the task.
    :param cost: The cost associated with the task.
    :param gain: The gain for the task.
    :param gain_fn: A function to compute the gain dynamically.
    :param lm_damping: The Levenberg-Marquardt damping factor.
    :param mask: A sequence of integers to mask certain dimensions of the task. When given, it must have length 3.
    :raises ValueError: If ``mask`` is given and its length is not 3.
    """

    JaxComponentType: type = JaxComTask
    _target_com: jnp.ndarray

    def __init__(
        self,
        name: str,
        cost: ArrayOrFloat,
        gain: ArrayOrFloat,
        gain_fn: Callable[[float], float] | None = None,
        lm_damping: float = 0,
        mask: Sequence[int] | None = None,
    ):
        if mask is not None and len(mask) != 3:
            raise ValueError("provided mask is too large, expected 1D vector of length 3")
        super().__init__(name, cost, gain, gain_fn, lm_damping, mask=mask)
        # Task dimension: 3 without a mask, otherwise the number of selected mask indices.
        self._dim = 3 if mask is None else len(self.mask_idxs)
        # Default target is the origin; assigned through the setter so it is validated.
        self.target_com = jnp.zeros(self._dim)

    @property
    def target_com(self) -> jnp.ndarray:
        """
        Get the current target center of mass for the task.

        :return: The current target center of mass as a JAX array.
        """
        return self._target_com

    @target_com.setter
    def target_com(self, value: Sequence):
        """
        Set the target center of mass for the task.

        :param value: The new target center of mass as a sequence of values.
        """
        self.update_target_com(value)

    def update_target_com(self, target_com: Sequence):
        """
        Update the target center of mass for the task.

        This method allows setting the target center of mass using a sequence of values.

        :param target_com: The new target center of mass as a sequence of values.
        :raises ValueError: If the value is a scalar or its last dimension differs from the task dimension.
        """
        target_com_jnp = jnp.array(target_com)
        # A 0-d array has no last axis; report it as a ValueError rather than an IndexError.
        if target_com_jnp.ndim == 0:
            raise ValueError(
                f"invalid target CoM: expected an array with last dimension {self._dim}, got a scalar"
            )
        if target_com_jnp.shape[-1] != self._dim:
            raise ValueError(
                f"invalid last dimension of target CoM : {target_com_jnp.shape[-1]} given, expected {self._dim} "
            )
        self._target_com = target_com_jnp
164 |
--------------------------------------------------------------------------------
/mjinx/components/tasks/_obj_frame_task.py:
--------------------------------------------------------------------------------
1 | """Frame task implementation."""
2 |
3 | from collections.abc import Callable, Sequence
4 | from typing import final
5 |
6 | import jax
7 | import jax.numpy as jnp
8 | import jax_dataclasses as jdc
9 | import mujoco as mj
10 | import mujoco.mjx as mjx
11 | from jaxlie import SE3, SO3
12 |
13 | from mjinx.components.tasks._obj_task import JaxObjTask, ObjTask
14 | from mjinx.configuration import get_frame_jacobian_local
15 | from mjinx.typing import ArrayOrFloat, ndarray
16 |
17 |
@jdc.pytree_dataclass
class JaxFrameTask(JaxObjTask):
    r"""
    A JAX-based implementation of a frame task for inverse kinematics.

    This class represents a task that aims to achieve a specific target frame
    for a given object in the robot model.

    The task function maps joint positions to the object's frame (pose) in the world frame:

    .. math::

        f(q) = T(q) \in SE(3)

    where :math:`T(q)` is the transformation matrix representing the object's pose.

    The error is computed using the logarithmic map in SE(3), which represents the
    relative transformation between current and target frames as a twist vector:

    .. math::

        e(q) = \log(T(q)^{-1} T_{target})

    This formulation provides a natural way to interpolate between frames and
    control both position and orientation simultaneously.

    :param target_frame: The target frame to be achieved.
    """

    target_frame: SE3

    @final
    def __call__(self, data: mjx.Data) -> jnp.ndarray:
        r"""
        Compute the error between the current frame and the target frame.

        The error is given by the logarithmic map in SE(3):

        .. math::

            e(q) = \log(T(q)^{-1} T_{target})

        This creates a 6D twist vector representing the relative transformation
        between the current and target frames.

        :param data: The MuJoCo simulation data.
        :return: The error vector representing the difference between the current and target frames.
        """
        return (self.get_frame(data).inverse() @ self.target_frame).log()[self.mask_idxs,]

    @final
    def compute_jacobian(self, data: mjx.Data) -> jnp.ndarray:
        r"""
        Compute the Jacobian of the frame task error.

        With a local (right) perturbation :math:`T(q + \delta q) = T(q) \exp(J_{frame} \delta q)`
        and the identity :math:`\log(X) = -\log(X^{-1})`, the error becomes

        .. math::

            e(q + \delta q) = -\log\left(T_{target}^{-1} T(q) \exp(J_{frame} \delta q)\right)

        so that

        .. math::

            J = -J_{\log}\left(T_{target}^{-1} T(q)\right) \cdot J_{frame}

        where :math:`J_{\log}(X) = \partial_\tau \log(X \exp(\tau)) |_{\tau = 0}`. This
        assumes :func:`get_frame_jacobian_local` is the right-perturbation (local-frame)
        Jacobian, as its name suggests. It has ``SE3.tangent_dim`` columns and is
        therefore transposed before the product.

        :param data: The MuJoCo simulation data.
        :return: The Jacobian matrix of the frame task.
        """
        # Current frame expressed in the target frame: T_target^{-1} T(q). This is the inverse of
        # the relative transform used in __call__, and the point where the log-Jacobian is evaluated.
        T_tb = self.target_frame.inverse() @ self.get_frame(data)

        def transform_log(tau):
            return (T_tb.multiply(SE3.exp(tau))).log()

        frame_jac = get_frame_jacobian_local(self.model, data, self.obj_id, self.obj_type)
        jlog = jax.jacobian(transform_log)(jnp.zeros(SE3.tangent_dim))
        return (-jlog @ frame_jac.T)[self.mask_idxs,]
95 |
96 |
class FrameTask(ObjTask[JaxFrameTask]):
    """
    A high-level representation of a frame task for inverse kinematics.

    This class provides an interface for creating and manipulating frame tasks,
    which aim to achieve a specific target frame for an object (body, geom, or site) in the robot model.

    :param name: The name of the task.
    :param cost: The cost associated with the task.
    :param gain: The gain for the task.
    :param obj_name: The name of the object to which the task is applied.
    :param obj_type: The MuJoCo type of the object. Defaults to ``mjOBJ_BODY``.
    :param gain_fn: A function to compute the gain dynamically.
    :param lm_damping: The Levenberg-Marquardt damping factor.
    :param mask: A sequence of integers to mask certain dimensions of the task.
    """

    JaxComponentType: type = JaxFrameTask
    _target_frame: SE3

    def __init__(
        self,
        name: str,
        cost: ArrayOrFloat,
        gain: ArrayOrFloat,
        obj_name: str,
        obj_type: mj.mjtObj = mj.mjtObj.mjOBJ_BODY,
        gain_fn: Callable[[float], float] | None = None,
        lm_damping: float = 0,
        mask: Sequence[int] | None = None,
    ):
        super().__init__(name, cost, gain, obj_name, obj_type, gain_fn, lm_damping, mask)
        self.target_frame = SE3.identity()
        # Task dimension: full SE(3) tangent space, or the number of selected mask indices.
        self._dim = SE3.tangent_dim if mask is None else len(self.mask_idxs)

    @property
    def target_frame(self) -> SE3:
        """
        Get the current target frame for the task.

        :return: The current target frame as an SE3 object.
        """
        return self._target_frame

    @target_frame.setter
    def target_frame(self, value: SE3 | Sequence | ndarray):
        """
        Set the target frame for the task.

        :param value: The new target frame, either as an SE3 object or a sequence of values.
        """
        self.update_target_frame(value)

    def update_target_frame(self, target_frame: SE3 | Sequence | ndarray):
        """
        Update the target frame for the task.

        The target frame can be given either as an SE3 object or as an array-like value laid out as
        ``[x, y, z, qw, qx, qy, qz]`` (position followed by a scalar-first quaternion). Array input
        is indexed with ``...``, so leading dimensions are passed through to jaxlie.

        :param target_frame: The new target frame, either as an SE3 object or a sequence of values.
        :raises ValueError: If array input is a scalar or its last dimension is not 7.
        """
        if isinstance(target_frame, SE3):
            self._target_frame = target_frame
            return

        target_frame_jnp = jnp.array(target_frame)
        # A 0-d array has no last axis: reject it with the documented ValueError instead of an IndexError.
        if target_frame_jnp.ndim == 0 or target_frame_jnp.shape[-1] != SE3.parameters_dim:
            raise ValueError(
                "Target frame provided via array must have length 7 (xyz + quaternion with scalar first)"
            )

        xyz, quat_wxyz = target_frame_jnp[..., :3], target_frame_jnp[..., 3:]
        # SO3.from_quaternion_xyzw expects scalar-last order, so reorder the scalar-first input.
        self._target_frame = SE3.from_rotation_and_translation(
            SO3.from_quaternion_xyzw(quat_wxyz[..., [1, 2, 3, 0]]),
            xyz,
        )
176 |
--------------------------------------------------------------------------------
/mjinx/components/tasks/_obj_position_task.py:
--------------------------------------------------------------------------------
1 | """Position task implementation."""
2 |
3 | from collections.abc import Callable, Sequence
4 |
5 | import jax.numpy as jnp
6 | import jax_dataclasses as jdc
7 | import mujoco as mj
8 | import mujoco.mjx as mjx
9 |
10 | from mjinx.components.tasks._obj_task import JaxObjTask, ObjTask
11 | from mjinx.typing import ArrayOrFloat
12 |
13 |
@jdc.pytree_dataclass
class JaxPositionTask(JaxObjTask):
    """
    A JAX-based implementation of a position task for inverse kinematics.

    The task drives an object (body, geometry, or site) of the robot model towards a target position.

    The task function maps joint positions to the object's position in the world frame:

    .. math::

        f(q) = p(q)

    where :math:`p(q)` is the position of the object in world coordinates.

    The error is the difference between the current and the target positions:

    .. math::

        e(q) = p(q) - p_{target}

    The Jacobian of this task is the object's position Jacobian, which relates changes in
    joint positions to changes in the object's position.

    :param target_pos: The target position to be achieved.
    """

    target_pos: jnp.ndarray

    def __call__(self, data: mjx.Data) -> jnp.ndarray:
        """
        Compute the error between the current position and the target position.

        The error is given by:

        .. math::

            e(q) = p(q) - p_{target}

        Only the masked coordinates of the object's position are compared with the target.

        :param data: The MuJoCo simulation data.
        :return: The error vector representing the difference between the current and target positions.
        """
        masked_pos = self.get_pos(data)[self.mask_idxs,]
        return masked_pos - self.target_pos
58 |
59 |
class PositionTask(ObjTask[JaxPositionTask]):
    """
    A high-level representation of a position task for inverse kinematics.

    This class provides an interface for creating and manipulating position tasks,
    which aim to achieve a specific target position for an object in the robot model.

    :param name: The name of the task.
    :param cost: The cost associated with the task.
    :param gain: The gain for the task.
    :param obj_name: The name of the object to which the task is applied.
    :param obj_type: The MuJoCo type of the object. Defaults to ``mjOBJ_BODY``.
    :param gain_fn: A function to compute the gain dynamically.
    :param lm_damping: The Levenberg-Marquardt damping factor.
    :param mask: A sequence of integers to mask certain dimensions of the task.
    """

    JaxComponentType: type = JaxPositionTask
    _target_pos: jnp.ndarray

    def __init__(
        self,
        name: str,
        cost: ArrayOrFloat,
        gain: ArrayOrFloat,
        obj_name: str,
        obj_type: mj.mjtObj = mj.mjtObj.mjOBJ_BODY,
        gain_fn: Callable[[float], float] | None = None,
        lm_damping: float = 0,
        mask: Sequence[int] | None = None,
    ):
        super().__init__(name, cost, gain, obj_name, obj_type, gain_fn, lm_damping, mask)
        # Task dimension: 3 without a mask, otherwise the number of selected mask indices.
        self._dim = 3 if mask is None else len(self.mask_idxs)
        self._target_pos = jnp.zeros(self._dim)

    @property
    def target_pos(self) -> jnp.ndarray:
        """
        Get the current target position for the task.

        :return: The current target position as a JAX array.
        """
        return self._target_pos

    @target_pos.setter
    def target_pos(self, value: Sequence):
        """
        Set the target position for the task.

        :param value: The new target position as a sequence of values.
        """
        self.update_target_pos(value)

    def update_target_pos(self, target_pos: Sequence):
        """
        Update the target position for the task.

        This method allows setting the target position using a sequence of values.

        :param target_pos: The new target position as a sequence of values.
        :raises ValueError: If the value is a scalar or its last dimension differs from the task dimension.
        """
        target_pos_array = jnp.array(target_pos)
        # A 0-d array has no last axis; report it as a ValueError rather than an IndexError.
        if target_pos_array.ndim == 0:
            raise ValueError(
                f"Invalid dimension of the target position: expected {self._dim}, got a scalar"
            )
        if target_pos_array.shape[-1] != self._dim:
            raise ValueError(
                f"Invalid dimension of the target position: expected {self._dim}, got {target_pos_array.shape[-1]}"
            )
        self._target_pos = target_pos_array
127 |
--------------------------------------------------------------------------------
/mjinx/configuration/__init__.py:
--------------------------------------------------------------------------------
1 | from ._collision import compute_collision_pairs, geom_groups, get_distance, sorted_pair
2 | from ._lie import attitude_jacobian, get_joint_zero, jac_dq2v, joint_difference, skew_symmetric
3 | from ._model import (
4 | geom_point_jacobian,
5 | get_configuration_limit,
6 | get_frame_jacobian_local,
7 | get_frame_jacobian_world_aligned,
8 | get_transform,
9 | get_transform_frame_to_world,
10 | integrate,
11 | update,
12 | )
13 |
--------------------------------------------------------------------------------
/mjinx/solvers/__init__.py:
--------------------------------------------------------------------------------
1 | from ._base import Solver, SolverData
2 | from ._global_ik import GlobalIKData, GlobalIKSolution, GlobalIKSolver
3 | from ._local_ik import LocalIKData, LocalIKSolution, LocalIKSolver
4 |
# Public API of ``mjinx.solvers``: the solver base classes, and the global and
# local IK solvers together with their associated ``*Data`` / ``*Solution`` types.
__all__ = [
    "Solver", "SolverData",
    "GlobalIKData", "GlobalIKSolution", "GlobalIKSolver",
    "LocalIKData", "LocalIKSolution", "LocalIKSolver",
]
15 |
--------------------------------------------------------------------------------
/mjinx/solvers/_base.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from typing import Generic, TypeVar
3 |
4 | import jax.numpy as jnp
5 | import jax_dataclasses as jdc
6 | import mujoco.mjx as mjx
7 |
8 | import mjinx.typing as mjt
9 | from mjinx import configuration
10 | from mjinx.problem import JaxProblemData
11 |
12 |
@jdc.pytree_dataclass
class SolverData:
    """Base container for the state a solver keeps between calls.

    It defines no fields itself. Concrete solvers subclass it to carry whatever they need from
    one iteration or call to the next, e.g. to warm-start subsequent optimization steps.
    """


@jdc.pytree_dataclass
class SolverSolution:
    """Base container for the result of a solver call.

    It holds the optimal velocity. Specific solvers may subclass it to return additional
    solution information.

    :param v_opt: Optimal velocity solution.
    """

    v_opt: jnp.ndarray


# Type variables tying a concrete solver to its own data and solution containers.
SolverDataType = TypeVar("SolverDataType", bound=SolverData)
SolverSolutionType = TypeVar("SolverSolutionType", bound=SolverSolution)
41 |
42 |
class Solver(Generic[SolverDataType, SolverSolutionType], abc.ABC):
    r"""Abstract base class for solvers.

    Defines the interface shared by the inverse kinematics solvers. A solver turns the tasks and
    barriers of a problem into an optimization problem whose solution yields joint configurations
    or velocities satisfying them.

    In general, the optimization problem can be formulated as:

    .. math::

        \min_{q} \sum_{i} \|e_i(q)\|^2_{W_i} \quad \text{subject to} \quad h_j(q) \geq 0

    where:
        - :math:`e_i(q)` are the task errors
        - :math:`W_i` are the task weight matrices
        - :math:`h_j(q)` are the barrier constraints

    Concrete solvers differ in how they tackle this problem, e.g. local linearization (QP)
    or global nonlinear optimization.

    :param model: The MuJoCo model used by the solver.
    """

    model: mjx.Model

    def __init__(self, model: mjx.Model):
        """Store the MuJoCo model the solver operates on.

        :param model: The MuJoCo model to be used by the solver.
        """
        self.model = model

    @abc.abstractmethod
    def solve_from_data(
        self, solver_data: SolverDataType, problem_data: JaxProblemData, model_data: mjx.Data
    ) -> tuple[SolverSolutionType, SolverDataType]:
        """Solve the inverse kinematics problem from already-updated model data.

        :param solver_data: Solver-specific data.
        :param problem_data: Problem-specific data.
        :param model_data: MuJoCo model data.
        :return: A tuple containing the solver solution and updated solver data.
        """
        ...

    def solve(
        self, q: jnp.ndarray, solver_data: SolverDataType, problem_data: JaxProblemData
    ) -> tuple[SolverSolutionType, SolverDataType]:
        """Solve the inverse kinematics problem for a given configuration.

        The model data for ``q`` is built here via ``configuration.update`` before delegating to
        :meth:`solve_from_data`. Call :meth:`solve_from_data` directly when up-to-date data is
        already available, to avoid this extra update.

        :param q: The current joint configuration.
        :param solver_data: Solver-specific data.
        :param problem_data: Problem-specific data.
        :return: A tuple containing the solver solution and updated solver data.
        :raises ValueError: If the input configuration does not have shape ``(nq,)``.
        """
        expected_shape = (self.model.nq,)
        if q.shape != expected_shape:
            raise ValueError(f"wrong dimension of the state: expected ({self.model.nq}, ), got {q.shape}")
        return self.solve_from_data(solver_data, problem_data, configuration.update(self.model, q))

    @abc.abstractmethod
    def init(self, q: mjt.ndarray) -> SolverDataType:
        """Create the initial solver-specific data.

        :param q: The initial joint configuration.
        :return: Initialized solver-specific data.
        """
        ...
116 |
--------------------------------------------------------------------------------
/mjinx/typing.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains type definitions and aliases used throughout the mjinx library.
3 | """
4 |
5 | from __future__ import annotations
6 |
7 | from collections.abc import Callable
8 | from enum import Enum
9 | from typing import NamedTuple, TypeAlias
10 |
11 | import jax
12 | import jax.numpy as jnp
13 | import numpy as np
14 | from mujoco.mjx._src.dataclasses import PyTreeNode
15 |
16 | ndarray: TypeAlias = np.ndarray | jnp.ndarray
17 | """Type alias for numpy or JAX numpy arrays."""
18 |
19 | ArrayOrFloat: TypeAlias = ndarray | float
20 | """Type alias for an array or a float value."""
21 |
22 | ClassKFunctions: TypeAlias = Callable[[ndarray], ndarray]
23 | """Type alias for Class K functions, which are scalar functions that take and return ndarrays."""
24 |
25 | CollisionBody: TypeAlias = int | str
26 | """Type alias for collision body representation, either as an integer ID or a string name."""
27 |
28 | CollisionPair: TypeAlias = tuple[int, int]
29 | """Type alias for a pair of collision body IDs."""
30 |
31 | ArrayLike: TypeAlias = np.typing.ArrayLike | jax.typing.ArrayLike
32 | """Type alias for an array-like object, either a numpy array or a JAX array-like object."""
33 |
34 |
class SimplifiedContact(PyTreeNode):
    """A lightweight pytree container for contact-related arrays.

    The fields are named after a subset of contact information (geoms, distance, position, frame).
    Their exact shapes and conventions are set by the code that builds this container.
    """

    geom: jnp.ndarray  # presumably ids of the geoms involved in the contact — TODO confirm against producer
    dist: jnp.ndarray  # presumably the distance between those geoms — TODO confirm sign convention
    pos: jnp.ndarray  # presumably the contact point position — TODO confirm frame of expression
    frame: jnp.ndarray  # presumably the contact frame — TODO confirm layout
40 |
41 |
42 | class PositionLimitType(Enum):
43 | """Type which describes possible position limits.
44 |
45 | The position limit could be only minimal, only maximal, or minimal and maximal.
46 | """
47 |
48 | MIN = 0
49 | MAX = 1
50 | BOTH = 2
51 |
52 | @staticmethod
53 | def from_str(type: str) -> PositionLimitType:
54 | """Generates position limit type from string.
55 |
56 | :param type: position limit type.
57 | :raises ValueError: limit name is not 'min', 'max', or 'both'.
58 | :return: corresponding enum type.
59 | """
60 | match type.lower():
61 | case "min":
62 | return PositionLimitType.MIN
63 | case "max":
64 | return PositionLimitType.MAX
65 | case "both":
66 | return PositionLimitType.BOTH
67 | case _:
68 | raise ValueError(
69 | f"[PositionLimitType] invalid position limit type: {type}. Expected {{'min', 'max', 'both'}}"
70 | )
71 |
72 | @staticmethod
73 | def includes_min(type: PositionLimitType) -> bool:
74 | """Either given limit includes minimum limit or not.
75 |
76 | Returns true, if limit is either MIN or BOTH, and false otherwise.
77 |
78 | :param type: limit to be processes.
79 | :return: True, if limit includes minimum limit, False otherwise.
80 | """
81 | return type == PositionLimitType.MIN or type == PositionLimitType.BOTH
82 |
83 | @staticmethod
84 | def includes_max(type: PositionLimitType) -> bool:
85 | """Either given limit includes maximum limit or not.
86 |
87 | Returns true, if limit is either MIN or BOTH, and false otherwise.
88 |
89 | :param type: limit to be processes.
90 | :return: True, if limit includes maximum limit, False otherwise.
91 | """
92 |
93 | return type == PositionLimitType.MAX or type == PositionLimitType.BOTH
94 |
--------------------------------------------------------------------------------
/mypy.ini:
--------------------------------------------------------------------------------
1 | [mypy]
2 | ignore_missing_imports = True
3 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "mjinx"
3 | version = "0.1.1"
4 | description = "Numerical Inverse Kinematics based on JAX + MJX"
5 | authors = [
6 | { name = "Ivan Domrachev", email = "domrachev10@gmail.com" },
7 | { name = "Simeon Nedelchev", email = "simkaned@gmail.com" },
8 | { name = "Lev Kozlov", email = "l.kozlov@kaist.ac.kr" },
9 | ]
10 | readme = "README.md"
11 | requires-python = ">=3.10"
12 |
13 | dependencies = [
14 | "mujoco",
15 | "mujoco_mjx",
16 | "jaxopt",
17 | "jax>=0.5",
18 | "jaxlie>=1.4",
19 | "jax_dataclasses>=1.6.0",
20 | "optax>=0.2",
21 | ]
22 |
23 | [project.optional-dependencies]
24 | dev = ["pre-commit", "ruff", "pytest", "robot_descriptions>=1.12"]
25 | docs = [
26 | "sphinx>=8",
27 | "sphinx-mathjax-offline",
28 | "sphinx-autodoc-typehints",
29 | "sphinx-rtd-theme>1",
30 | "sphinxcontrib-bibtex",
31 | ]
32 | visual = ["dm_control", "mediapy"]
33 | examples = ["mjinx[visual]", "robot_descriptions>=1.12"]
34 | all = ["mjinx[dev, visual, docs]"]
35 |
36 | [build-system]
37 | requires = ["setuptools>=43.0.0", "wheel"]
38 | build-backend = "setuptools.build_meta"
39 |
40 | [tool.setuptools.packages.find]
41 | where = ["."]
42 | include = ["mjinx*"]
43 |
[tool.ruff]
line-length = 120

# Linter options live under `lint`; top-level `select` is deprecated in Ruff and
# this keeps it consistent with the `lint.per-file-ignores` table below.
[tool.ruff.lint]
select = [
    "E", # pycodestyle errors
    "W", # pycodestyle warnings
    "F", # pyflakes
    "I", # isort
    "B", # flake8-bugbear
    "C4", # flake8-comprehensions
    "UP", # pyupgrade
]

[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401"]
58 |
59 | [tool.pytest]
60 | filterwarnings = "ignore:.*U.*mode is deprecated:DeprecationWarning"
61 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/tests/__init__.py
--------------------------------------------------------------------------------
/tests/example_tests/test_global_ik_jit.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import mujoco as mj
4 | import mujoco.mjx as mjx
5 | import numpy as np
6 | from optax import adam
7 | from robot_descriptions.iiwa14_mj_description import MJCF_PATH
8 |
9 | from mjinx.components.barriers import JointBarrier, PositionBarrier
10 | from mjinx.components.tasks import FrameTask
11 | from mjinx.problem import Problem
12 | from mjinx.solvers import GlobalIKSolver
13 |
14 |
def test_global_ik_jit():
    """Smoke-test a jit-compiled GlobalIKSolver on the KUKA iiwa14 model.

    Builds a problem with a frame task, an end-effector position barrier and a joint
    barrier, then runs one jit-compiled ``solve`` call per control step while the
    x-component of the frame target oscillates.
    """
    mj_model = mj.MjModel.from_xml_path(MJCF_PATH)
    mjx_model = mjx.put_model(mj_model)

    # Time step shared by the solver and the control loop (kept in one place).
    dt = 1e-2

    # === Mjinx ===

    # --- Constructing the problem ---
    # Creating problem formulation
    problem = Problem(mjx_model, v_min=-100, v_max=100)

    # Creating components of interest and adding them to the problem
    frame_task = FrameTask("ee_task", cost=1, gain=50, obj_name="link7")
    position_barrier = PositionBarrier(
        "ee_barrier",
        gain=0.1,
        obj_name="link7",
        limit_type="max",
        p_max=0.3,
        safe_displacement_gain=1e-2,
        mask=[1, 0, 0],
    )
    joints_barrier = JointBarrier("jnt_range", gain=0.1)

    problem.add_component(frame_task)
    problem.add_component(position_barrier)
    problem.add_component(joints_barrier)

    # Compiling the problem upon any parameters update
    problem_data = problem.compile()

    # Initializing solver and its initial state
    solver = GlobalIKSolver(mjx_model, adam(learning_rate=1e-2), dt=dt)

    # Initial condition
    q = jnp.array(
        [
            -1.4238753,
            -1.7268502,
            -0.84355015,
            2.0962472,
            2.1339328,
            2.0837479,
            -2.5521986,
        ]
    )
    solver_data = solver.init(q)

    # Jit-compiling the key functions for better efficiency
    solve_jit = jax.jit(solver.solve)

    # === Control loop ===
    ts = np.arange(0, 1, dt)

    for t in ts:
        # Changing desired values; np.sin keeps the target a plain NumPy array.
        frame_task.target_frame = np.array([0.2 + 0.2 * np.sin(t) ** 2, 0.2, 0.2, 1, 0, 0, 0])
        # After changes, recompiling the model
        problem_data = problem.compile()
        # Solving the instance of the problem (one solver step per control step)
        _, solver_data = solve_jit(q, solver_data, problem_data)
77 |
--------------------------------------------------------------------------------
/tests/example_tests/test_global_ik_vmap.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import mujoco as mj
4 | import mujoco.mjx as mjx
5 | import numpy as np
6 | from optax import adam
7 | from robot_descriptions.iiwa14_mj_description import MJCF_PATH
8 |
9 | from mjinx.components.barriers import JointBarrier, PositionBarrier
10 | from mjinx.components.tasks import FrameTask
11 | from mjinx.problem import Problem
12 | from mjinx.solvers import GlobalIKSolver
13 |
14 |
def test_global_ik_jit():
    """Smoke-test a jit-compiled, vmapped GlobalIKSolver on the KUKA iiwa14 model.

    NOTE(review): despite its name (shared with the non-batched example test), this
    test exercises the *vmapped* solver over a batch of initial configurations and a
    batch of frame targets.
    """
    mj_model = mj.MjModel.from_xml_path(MJCF_PATH)
    mjx_model = mjx.put_model(mj_model)

    q_min = mj_model.jnt_range[:, 0].copy()
    q_max = mj_model.jnt_range[:, 1].copy()

    # Time step shared by the solver and the control loop (kept in one place).
    dt = 1e-2

    # === Mjinx ===

    # --- Constructing the problem ---
    # Creating problem formulation
    problem = Problem(mjx_model, v_min=-100, v_max=100)

    # Creating components of interest and adding them to the problem
    frame_task = FrameTask("ee_task", cost=1, gain=50, obj_name="link7")
    position_barrier = PositionBarrier(
        "ee_barrier",
        gain=0.1,
        obj_name="link7",
        limit_type="max",
        p_max=0.3,
        safe_displacement_gain=1e-2,
        mask=[1, 0, 0],
    )
    joints_barrier = JointBarrier("jnt_range", gain=0.1)

    problem.add_component(frame_task)
    problem.add_component(position_barrier)
    problem.add_component(joints_barrier)

    # Compiling the problem upon any parameters update
    problem_data = problem.compile()

    # Initializing solver and its initial state
    solver = GlobalIKSolver(mjx_model, adam(learning_rate=1e-2), dt=dt)

    # Initializing initial condition: N_batch random perturbations of q0, clipped to
    # stay 0.1 away from the joint limits. A local seeded generator keeps the batch
    # reproducible without mutating NumPy's global random state.
    N_batch = 1000
    rng = np.random.default_rng(42)
    q0 = np.array(
        [
            -1.4238753,
            -1.7268502,
            -0.84355015,
            2.0962472,
            2.1339328,
            2.0837479,
            -2.5521986,
        ]
    )
    q = jnp.asarray(
        np.clip(
            q0 + rng.uniform(-0.5, 0.5, size=(N_batch, mj_model.nq)),
            q_min + 1e-1,
            q_max - 1e-1,
        )
    )

    # --- Batching ---
    # First of all, data should be created via vmapped init function
    solver_data = jax.vmap(solver.init, in_axes=0)(q)

    # To create a batch w.r.t. desired component's attributes, library defines a convenient wrapper
    # that sets all elements to None and allows user to mutate dataclasses of interest.
    # After exiting the Context Manager, you'll get immutable jax dataclass object.
    with problem.set_vmap_dimension() as empty_problem_data:
        empty_problem_data.components["ee_task"].target_frame = 0

    # Vmapping the solve function.
    # Note that for batching w.r.t. q both q and solver_data should be batched.
    # Other approaches might work, but it would be undefined behaviour, please stick to this format.
    solve_jit = jax.jit(jax.vmap(solver.solve, in_axes=(0, 0, empty_problem_data)))

    # === Control loop ===
    ts = np.arange(0, 1, dt)

    # Every batch element gets its own phase offset; the other six target entries are constant.
    phases = np.pi * np.arange(N_batch) / N_batch
    target_tail = np.tile([0.2, 0.2, 1.0, 0.0, 0.0, 0.0], (N_batch, 1))

    for t in ts:
        # Changing desired values (vectorized over the batch)
        target_x = 0.2 + 0.2 * np.sin(t + phases) ** 2
        frame_task.target_frame = np.column_stack((target_x, target_tail))
        # After changes, recompiling the model
        problem_data = problem.compile()
        # Solving the instance of the problem (one solver step per control step)
        _, solver_data = solve_jit(q, solver_data, problem_data)
110 |
--------------------------------------------------------------------------------
/tests/example_tests/test_local_ik_jit.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import mujoco as mj
4 | import mujoco.mjx as mjx
5 | import numpy as np
6 | from robot_descriptions.iiwa14_mj_description import MJCF_PATH
7 |
8 | from mjinx.components.barriers import JointBarrier, PositionBarrier
9 | from mjinx.components.tasks import FrameTask
10 | from mjinx.configuration import integrate
11 | from mjinx.problem import Problem
12 | from mjinx.solvers import LocalIKSolver
13 |
14 | # === Mujoco ===
15 |
16 |
def test_local_ik_jit():
    """Smoke-test a jit-compiled LocalIKSolver plus configuration integration on the iiwa14.

    Each control step solves for a velocity, integrates it into ``q``, and at the end
    checks that the configuration stayed finite (i.e. the loop did not diverge to NaN/inf).
    """
    mj_model = mj.MjModel.from_xml_path(MJCF_PATH)
    mjx_model = mjx.put_model(mj_model)

    # === Mjinx ===

    # --- Constructing the problem ---
    # Creating problem formulation
    problem = Problem(mjx_model, v_min=-100, v_max=100)

    # Creating components of interest and adding them to the problem
    frame_task = FrameTask("ee_task", cost=1, gain=20, obj_name="link7")
    position_barrier = PositionBarrier(
        "ee_barrier",
        gain=100,
        obj_name="link7",
        limit_type="max",
        p_max=0.3,
        safe_displacement_gain=1e-2,
        mask=[1, 0, 0],
    )
    joints_barrier = JointBarrier("jnt_range", gain=10)

    problem.add_component(frame_task)
    problem.add_component(position_barrier)
    problem.add_component(joints_barrier)

    # Compiling the problem upon any parameters update
    problem_data = problem.compile()

    # Initializing solver and its initial state
    solver = LocalIKSolver(mjx_model, maxiter=20)

    # Initial condition
    q = jnp.array(
        [
            -1.4238753,
            -1.7268502,
            -0.84355015,
            2.0962472,
            2.1339328,
            2.0837479,
            -2.5521986,
        ]
    )
    solver_data = solver.init()

    # Jit-compiling the key functions for better efficiency
    solve_jit = jax.jit(solver.solve)
    integrate_jit = jax.jit(integrate, static_argnames=["dt"])

    # === Control loop ===
    dt = 1e-2
    ts = np.arange(0, 1, dt)

    for t in ts:
        # Changing desired values; np.sin keeps the target a plain NumPy array.
        frame_task.target_frame = np.array([0.2 + 0.2 * np.sin(t) ** 2, 0.2, 0.2, 1, 0, 0, 0])
        # After changes, recompiling the model
        problem_data = problem.compile()

        # Solving the instance of the problem
        opt_solution, solver_data = solve_jit(q, solver_data, problem_data)

        # Integrating
        q = integrate_jit(
            mjx_model,
            q,
            velocity=opt_solution.v_opt,
            dt=dt,
        )

    # A smoke test should at least catch numerical blow-up of the control loop.
    assert jnp.all(jnp.isfinite(q)), "integrated configuration contains non-finite values"
88 |
--------------------------------------------------------------------------------
/tests/example_tests/test_local_ik_vmap.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import mujoco as mj
4 | import mujoco.mjx as mjx
5 | import numpy as np
6 | from robot_descriptions.iiwa14_mj_description import MJCF_PATH
7 |
8 | from mjinx.components.barriers import JointBarrier, PositionBarrier
9 | from mjinx.components.tasks import FrameTask
10 | from mjinx.configuration import integrate
11 | from mjinx.problem import Problem
12 | from mjinx.solvers import LocalIKSolver
13 |
14 |
15 | # === Mujoco ===
def test_local_ik_vmap():
    """Smoke-test a jit-compiled, vmapped LocalIKSolver plus integration on the iiwa14.

    A batch of identical initial configurations tracks a batch of phase-shifted frame
    targets; at the end the whole batch of configurations must still be finite.
    """
    mj_model = mj.MjModel.from_xml_path(MJCF_PATH)

    mjx_model = mjx.put_model(mj_model)

    # === Mjinx ===

    # --- Constructing the problem ---
    # Creating problem formulation
    problem = Problem(mjx_model, v_min=-100, v_max=100)

    # Creating components of interest and adding them to the problem
    frame_task = FrameTask("ee_task", cost=1, gain=20, obj_name="link7")
    position_barrier = PositionBarrier(
        "ee_barrier",
        gain=100,
        obj_name="link7",
        limit_type="max",
        p_max=0.3,
        safe_displacement_gain=1e-2,
        mask=[1, 0, 0],
    )
    joints_barrier = JointBarrier("jnt_range", gain=10)

    problem.add_component(frame_task)
    problem.add_component(position_barrier)
    problem.add_component(joints_barrier)

    # Compiling the problem upon any parameters update
    problem_data = problem.compile()

    # Initializing solver and its initial state
    solver = LocalIKSolver(mjx_model, maxiter=20)

    # Initializing initial condition: N_batch copies of the same configuration
    N_batch = 10000
    q0 = jnp.array(
        [
            -1.5878328,
            -2.0968683,
            -1.4339591,
            1.6550868,
            2.1080072,
            1.646142,
            -2.982619,
        ]
    )
    q = jnp.tile(q0, (N_batch, 1))

    # --- Batching ---
    # First of all, data should be created via vmapped init function
    solver_data = jax.vmap(solver.init, in_axes=0)(v_init=jnp.zeros((N_batch, mjx_model.nv)))

    # To create a batch w.r.t. desired component's attributes, library defines a convenient wrapper
    # that sets all elements to None and allows user to mutate dataclasses of interest.
    # After exiting the Context Manager, you'll get immutable jax dataclass object.
    with problem.set_vmap_dimension() as empty_problem_data:
        empty_problem_data.components["ee_task"].target_frame = 0

    # Vmapping solve and integrate functions.
    # Note that for batching w.r.t. q both q and solver_data should be batched.
    # Other approaches might work, but it would be undefined behaviour, please stick to this format.
    solve_jit = jax.jit(
        jax.vmap(
            solver.solve,
            in_axes=(0, 0, empty_problem_data),
        )
    )
    integrate_jit = jax.jit(jax.vmap(integrate, in_axes=(None, 0, 0, None)), static_argnames=["dt"])

    # === Control loop ===
    dt = 1e-2
    ts = np.arange(0, 1, dt)

    # Every batch element gets its own phase offset; the other six target entries are constant.
    phases = np.pi * np.arange(N_batch) / N_batch
    target_tail = np.tile([0.2, 0.2, 1.0, 0.0, 0.0, 0.0], (N_batch, 1))

    for t in ts:
        # Changing desired values (vectorized over the batch)
        target_x = 0.2 + 0.2 * np.sin(t + phases) ** 2
        frame_task.target_frame = np.column_stack((target_x, target_tail))
        problem_data = problem.compile()

        # Solving the instance of the problem
        opt_solution, solver_data = solve_jit(q, solver_data, problem_data)

        # Integrating
        q = integrate_jit(
            mjx_model,
            q,
            opt_solution.v_opt,
            dt,
        )

    # A smoke test should at least catch numerical blow-up anywhere in the batch.
    assert jnp.all(jnp.isfinite(q)), "integrated configurations contain non-finite values"
107 |
--------------------------------------------------------------------------------
/tests/unit_tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/tests/unit_tests/__init__.py
--------------------------------------------------------------------------------
/tests/unit_tests/components/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/tests/unit_tests/components/__init__.py
--------------------------------------------------------------------------------
/tests/unit_tests/components/barriers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/tests/unit_tests/components/barriers/__init__.py
--------------------------------------------------------------------------------
/tests/unit_tests/components/barriers/test_base_barrier.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import jax.numpy as jnp
4 | import mujoco.mjx as mjx
5 |
6 | from mjinx.components.barriers import Barrier, JaxBarrier
7 |
8 |
class DummyJaxBarrier(JaxBarrier):
    """Smallest non-abstract JaxBarrier, used only to instantiate the base classes in tests."""

    def __call__(self, data: mjx.Data) -> jnp.ndarray:
        """Return a vector of ones of length ``model.nv``, restricted to ``mask_idxs``."""
        ones = jnp.ones(self.model.nv)
        # Indexing with a 1-tuple that wraps mask_idxs, same as `ones[self.mask_idxs,]`.
        return ones[(self.mask_idxs,)]
14 |
15 |
class DummyBarrier(Barrier[DummyJaxBarrier]):
    """Smallest non-abstract Barrier; its JAX counterpart is ``DummyJaxBarrier``."""

    # Class of the JAX component this barrier produces.
    JaxComponentType: type = DummyJaxBarrier
20 |
21 |
class TestBarrier(unittest.TestCase):
    """Tests for behaviour shared by all Barrier subclasses."""

    def test_safe_displacement_gain(self):
        """Assigning safe_displacement_gain stores exactly the assigned value."""
        expected_gain = 5.0
        self.component = DummyBarrier("test_component", gain=2.0)
        self.component.safe_displacement_gain = expected_gain
        self.assertEqual(self.component.safe_displacement_gain, expected_gain)
28 |
--------------------------------------------------------------------------------
/tests/unit_tests/components/barriers/test_obj_position_barrier.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import jax.numpy as jnp
4 | import mujoco as mj
5 | import mujoco.mjx as mjx
6 | import numpy as np
7 |
8 | from mjinx.components.barriers._obj_position_barrier import JaxPositionBarrier, PositionBarrier
9 | from mjinx.typing import PositionLimitType
10 |
11 |
class TestPositionBarrier(unittest.TestCase):
    """Unit tests for PositionBarrier and its JAX counterpart JaxPositionBarrier."""

    def setUp(self):
        # NOTE(review): the model XML literal below is empty in this copy of the file. The
        # assertions expect body id 1 to be named "body1" and qpos to have 7 entries —
        # confirm against the upstream test.
        self.model = mjx.put_model(
            mj.MjModel.from_xml_string(
                """








"""
            )
        )
        self.data = mjx.make_data(self.model)

    def test_initialization(self):
        """Testing component initialization"""
        barrier = PositionBarrier(
            name="test_barrier",
            gain=1.0,
            obj_name="body1",
            p_min=[-1.0, -1.0, -1.0],
            p_max=[1.0, 1.0, 1.0],
        )
        barrier.update_model(self.model)

        self.assertEqual(barrier.name, "test_barrier")
        self.assertEqual(barrier.obj_name, "body1")
        self.assertEqual(barrier.obj_id, 1)
        np.testing.assert_array_equal(barrier.p_min, jnp.array([-1.0, -1.0, -1.0]))
        np.testing.assert_array_equal(barrier.p_max, jnp.array([1.0, 1.0, 1.0]))
        self.assertEqual(barrier.limit_type, PositionLimitType.BOTH)

    def test_update_p_min(self):
        """Testing minimal position updates"""
        barrier = PositionBarrier(
            name="test_barrier",
            gain=1.0,
            obj_name="body1",
            p_min=[-1.0, -1.0, -1.0],
            limit_type="min",
        )
        # A scalar is broadcast to all three coordinates.
        barrier.update_p_min(0.1)
        np.testing.assert_array_equal(barrier.p_min, 0.1 * jnp.ones(3))

        barrier.update_p_min([-2.0, -2.0, -2.0])
        np.testing.assert_array_equal(barrier.p_min, jnp.array([-2.0, -2.0, -2.0]))

        with self.assertRaises(ValueError):
            barrier.update_p_min([-2.0, -2.0])

        # Setting the unused (max) limit on a min-only barrier only warns.
        with self.assertWarns(Warning):
            barrier.p_max = [1.0, 1.0, 1.0]

    def test_update_p_max(self):
        """Testing maximal position updates"""
        barrier = PositionBarrier(
            name="test_barrier",
            gain=1.0,
            obj_name="body1",
            p_max=[1.0, 1.0, 1.0],
            limit_type="max",
        )
        # A scalar is broadcast to all three coordinates.
        barrier.update_p_max(0.1)
        np.testing.assert_array_equal(barrier.p_max, 0.1 * jnp.ones(3))

        barrier.update_p_max([2.0, 2.0, 2.0])
        np.testing.assert_array_equal(barrier.p_max, jnp.array([2.0, 2.0, 2.0]))

        with self.assertRaises(ValueError):
            barrier.update_p_max([2.0, 2.0])

        # Setting the unused (min) limit on a max-only barrier only warns.
        with self.assertWarns(Warning):
            barrier.p_min = [1.0, 1.0, 1.0]

    def test_limit_type_and_dimension(self):
        """Test limit type detection and dimension of the barrier"""
        barrier_min = PositionBarrier(
            name="test_barrier_min",
            gain=1.0,
            obj_name="body1",
            p_min=[-1.0, -1.0, -1.0],
            limit_type="min",
        )
        self.assertEqual(barrier_min.dim, 3)
        self.assertEqual(barrier_min.limit_type, PositionLimitType.MIN)

        barrier_max = PositionBarrier(
            name="test_barrier_max",
            gain=1.0,
            obj_name="body1",
            p_max=[1.0, 1.0, 1.0],
            limit_type="max",
        )
        self.assertEqual(barrier_max.dim, 3)
        self.assertEqual(barrier_max.limit_type, PositionLimitType.MAX)

        barrier_both = PositionBarrier(
            name="test_barrier_both",
            gain=1.0,
            obj_name="body1",
            p_min=[-1.0, -1.0, -1.0],
            p_max=[1.0, 1.0, 1.0],
            limit_type="both",
        )
        self.assertEqual(barrier_both.dim, 6)
        self.assertEqual(barrier_both.limit_type, PositionLimitType.BOTH)

        with self.assertRaises(ValueError):
            PositionBarrier(
                name="test_barrier_invalid",
                gain=1.0,
                obj_name="body1",
                limit_type="invalid",
            )

    def test_jax_component(self):
        """Test generating jax component"""
        barrier = PositionBarrier(
            name="test_barrier",
            gain=1.0,
            obj_name="body1",
            p_min=[-1.0, -1.0, -1.0],
            p_max=[1.0, 1.0, 1.0],
        )
        barrier.update_model(self.model)

        jax_component = barrier.jax_component

        self.assertIsInstance(jax_component, JaxPositionBarrier)
        self.assertEqual(jax_component.dim, 6)
        np.testing.assert_array_equal(jax_component.vector_gain, jnp.ones(6))
        self.assertEqual(jax_component.obj_id, 1)
        np.testing.assert_array_equal(jax_component.p_min, jnp.array([-1.0, -1.0, -1.0]))
        np.testing.assert_array_equal(jax_component.p_max, jnp.array([1.0, 1.0, 1.0]))

    def test_jax_call(self):
        """Test call of the jax barrier function"""
        barrier = JaxPositionBarrier(
            dim=3,
            model=self.model,
            vector_gain=jnp.ones(3),
            gain_fn=lambda x: x,
            mask_idxs=(0, 1, 2),
            safe_displacement_gain=0.0,
            obj_id=1,
            obj_type=mj.mjtObj.mjOBJ_BODY,
            p_min=jnp.array([-1.0, -1.0, -1.0]),
            p_max=jnp.array([1.0, 1.0, 1.0]),
            limit_type_mask_idxs=tuple(range(6)),
        )

        self.data = self.data.replace(qpos=jnp.array([0.0, 0.5, -0.5, 1.0, 0.0, 0.0, 0.0]))
        self.data = mjx.kinematics(self.model, self.data)
        result = barrier(self.data)
        result_compute_barrier = barrier.compute_barrier(self.data)
        np.testing.assert_array_equal(result, result_compute_barrier)
        # Expected: [p - p_min, p_max - p] for p = (0.0, 0.5, -0.5).
        np.testing.assert_array_almost_equal(result, jnp.array([1.0, 1.5, 0.5, 1.0, 0.5, 1.5]))
173 |
--------------------------------------------------------------------------------
/tests/unit_tests/components/barriers/test_self_collision_barrier.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import jax.numpy as jnp
4 | import mujoco as mj
5 | import mujoco.mjx as mjx
6 | import numpy as np
7 |
8 | from mjinx.components.barriers._self_collision_barrier import JaxSelfCollisionBarrier, SelfCollisionBarrier
9 | from mjinx.configuration import get_distance, sorted_pair
10 |
11 |
class TestSelfCollisionBarrier(unittest.TestCase):
    """Unit tests for SelfCollisionBarrier and its JAX counterpart JaxSelfCollisionBarrier."""

    def setUp(self):
        # NOTE(review): the model XML literal below is empty in this copy of the file. The
        # tests expect bodies named "body1"/"body2" (ids 1/2) and at least four geoms —
        # confirm against the upstream test.
        self.model = mjx.put_model(
            mj.MjModel.from_xml_string(
                """ких




















"""
            )
        )
        self.data = mjx.make_data(self.model)

    def test_initialization(self):
        """Test attributes after construction, before and after the model is attached."""
        barrier = SelfCollisionBarrier(
            name="test_barrier",
            gain=1.0,
            d_min=0.1,
            collision_bodies=["body1", "body2"],
            excluded_collisions=[("body1", "body2")],
        )
        # d_min_vec depends on the model, so it is unavailable before update_model.
        with self.assertRaises(ValueError):
            _ = barrier.d_min_vec

        barrier.update_model(self.model)

        self.assertEqual(barrier.name, "test_barrier")
        self.assertEqual(barrier.d_min, 0.1)
        self.assertEqual(barrier.collision_bodies, ["body1", "body2"])
        self.assertEqual(barrier.exclude_collisions, {sorted_pair(1, 2)})

    def test_generate_collision_pairs(self):
        """Test straight-away collision pair generation"""
        barrier = SelfCollisionBarrier(
            name="test_barrier",
            gain=1.0,
            d_min=0.1,
        )
        barrier.update_model(self.model)

        expected_pairs = [(0, 1), (0, 2)]
        self.assertEqual(barrier.collision_pairs, expected_pairs)

    def test_generate_collision_pairs_with_exclusion(self):
        """Test exclusion of collision pairs"""
        barrier = SelfCollisionBarrier(
            name="test_barrier",
            gain=1.0,
            d_min=0.1,
            excluded_collisions=[("body1", "body2")],
        )
        barrier.update_model(self.model)

        expected_pairs = [(0, 2)]
        self.assertEqual(barrier.collision_pairs, expected_pairs)

    def test_body2id(self):
        """body2id accepts a body name or an integer id and rejects other types."""
        barrier = SelfCollisionBarrier(
            name="test_barrier",
            gain=1.0,
            d_min=0.1,
        )
        barrier.update_model(self.model)

        self.assertEqual(barrier.body2id("body1"), 1)
        self.assertEqual(barrier.body2id(1), 1)

        with self.assertRaises(ValueError):
            barrier.body2id(1.5)

    def test_validate_body_pair(self):
        """Test body pairs validation"""
        barrier = SelfCollisionBarrier(
            name="test_barrier",
            gain=1.0,
            d_min=0.1,
        )
        barrier.update_model(self.model)

        # Valid body pairs
        self.assertTrue(barrier.validate_body_pair(1, 2))
        self.assertTrue(barrier.validate_body_pair(1, 3))
        self.assertTrue(barrier.validate_body_pair(2, 4))

        # Invalid pairs: consecutive bodies
        self.assertFalse(barrier.validate_body_pair(2, 3))
        self.assertFalse(barrier.validate_body_pair(3, 4))

    def test_validate_geom_pair(self):
        """Test geometry pairs validation"""
        barrier = SelfCollisionBarrier(
            name="test_barrier",
            gain=1.0,
            d_min=0.1,
        )
        barrier.update_model(self.model)

        # Valid geom pairs
        self.assertTrue(barrier.validate_geom_pair(0, 1))
        self.assertTrue(barrier.validate_geom_pair(0, 2))
        self.assertTrue(barrier.validate_geom_pair(1, 2))

        # Invalid geom pairs: geom 3 is filtered out by its contype/conaffinity settings.
        for i in range(self.model.ngeom - 1):
            self.assertFalse(barrier.validate_geom_pair(3, i))

    def test_d_min_vec(self):
        """d_min is broadcast to one entry per barrier dimension."""
        barrier = SelfCollisionBarrier(
            name="test_barrier",
            gain=1.0,
            d_min=0.1,
        )
        barrier.update_model(self.model)

        np.testing.assert_array_equal(barrier.d_min_vec, jnp.ones(barrier.dim) * 0.1)

    def test_jax_component(self):
        """The generated JAX component carries consistent dimension, gains and pairs."""
        barrier = SelfCollisionBarrier(
            name="test_barrier",
            gain=1.0,
            d_min=0.1,
        )
        barrier.update_model(self.model)

        jax_component = barrier.jax_component

        self.assertIsInstance(jax_component, JaxSelfCollisionBarrier)
        self.assertEqual(jax_component.dim, 2)
        np.testing.assert_array_equal(jax_component.vector_gain, jnp.ones(jax_component.dim))
        np.testing.assert_array_equal(jax_component.d_min_vec, jnp.ones(jax_component.dim) * 0.1)
        self.assertEqual(len(jax_component.collision_pairs), jax_component.dim)

    def test_call(self):
        """Test jax component actual computation"""
        collision_pairs = [(0, 1)]
        barrier = JaxSelfCollisionBarrier(
            dim=1,
            model=self.model,
            vector_gain=jnp.ones(1),
            gain_fn=lambda x: x,
            mask_idxs=(0,),
            safe_displacement_gain=0.0,
            d_min_vec=jnp.array([0.1]),
            collision_pairs=collision_pairs,
            n_closest_pairs=len(collision_pairs),
        )

        result = barrier(self.data)
        # Barrier value is the pair distance minus the required minimal distance.
        expected = get_distance(self.model, self.data, collision_pairs)[0] - barrier.d_min_vec
        np.testing.assert_array_almost_equal(result, expected)
178 |
--------------------------------------------------------------------------------
/tests/unit_tests/components/constraints/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/tests/unit_tests/components/constraints/__init__.py
--------------------------------------------------------------------------------
/tests/unit_tests/components/constraints/test_base_constraint.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import numpy as np
3 | from mjinx.components.constraints import Constraint, JaxConstraint
4 |
5 | # Language: python
6 | import jax.numpy as jnp
7 |
8 |
class DummyConstraint(Constraint):
    """Minimal concrete Constraint with a fixed dimension of 3, used by the tests below."""

    dim = 3
12 |
13 |
class TestJaxConstraint(unittest.TestCase):
    """Tests for the ``soft_constraint_cost`` property, exercised through DummyConstraint."""

    # Tolerances of jnp.allclose, kept explicit so comparisons stay as lenient as before.
    RTOL = 1e-5
    ATOL = 1e-8

    def setUp(self):
        # Dimension of DummyConstraint; expected matrices are built with this size.
        self.dim = DummyConstraint.dim

    def create_constraint(self, soft_cost):
        # Create a dummy constraint; its dimension comes from the DummyConstraint.dim class attribute.
        instance = DummyConstraint(name="dummy", gain=1.0, soft_constraint_cost=soft_cost)
        return instance

    def assert_close(self, actual, expected):
        """Assert element-wise closeness, reporting the mismatching entries on failure."""
        np.testing.assert_allclose(actual, expected, rtol=self.RTOL, atol=self.ATOL)

    def test_soft_constraint_cost_none(self):
        # Test when soft_constraint_cost is None, property returns 1e2 * identity matrix.
        constraint = self.create_constraint(None)
        expected = 1e2 * jnp.eye(self.dim)
        self.assert_close(constraint.soft_constraint_cost, expected)

    def test_soft_constraint_cost_scalar(self):
        # Test when soft_constraint_cost is a scalar.
        scalar_value = 10.0
        constraint = self.create_constraint(scalar_value)
        expected = jnp.eye(self.dim) * scalar_value
        self.assert_close(constraint.soft_constraint_cost, expected)

    def test_soft_constraint_cost_vector_correct(self):
        # Test when soft_constraint_cost is a vector with proper length.
        vector_cost = jnp.array([1.0, 2.0, 3.0])
        constraint = self.create_constraint(vector_cost)
        expected = jnp.diag(vector_cost)
        self.assert_close(constraint.soft_constraint_cost, expected)

    def test_soft_constraint_cost_vector_invalid(self):
        # Test when soft_constraint_cost is a vector with incorrect length.
        vector_cost = jnp.array([1.0, 2.0])  # length != self.dim
        constraint = self.create_constraint(vector_cost)
        with self.assertRaises(ValueError):
            _ = constraint.soft_constraint_cost

    def test_soft_constraint_cost_matrix_correct(self):
        # Test when soft_constraint_cost is a matrix of proper shape.
        matrix_cost = jnp.array([[1.0, 0, 0], [0, 2.0, 0], [0, 0, 3.0]])
        constraint = self.create_constraint(matrix_cost)
        self.assert_close(constraint.soft_constraint_cost, matrix_cost)

    def test_soft_constraint_cost_matrix_invalid(self):
        # Test when soft_constraint_cost is a matrix with an incorrect shape.
        matrix_cost = jnp.array([[1.0, 0], [0, 2.0]])
        constraint = self.create_constraint(matrix_cost)
        with self.assertRaises(ValueError):
            _ = constraint.soft_constraint_cost

    def test_soft_constraint_cost_ndim_invalid(self):
        # Test when soft_constraint_cost has ndim > 2.
        invalid_cost = jnp.ones((self.dim, self.dim, 1))
        constraint = self.create_constraint(invalid_cost)
        with self.assertRaises(ValueError):
            _ = constraint.soft_constraint_cost


if __name__ == "__main__":
    unittest.main()
74 |
--------------------------------------------------------------------------------
/tests/unit_tests/components/tasks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/tests/unit_tests/components/tasks/__init__.py
--------------------------------------------------------------------------------
/tests/unit_tests/components/tasks/test_base_task.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import jax.numpy as jnp
4 | import mujoco as mj
5 | import mujoco.mjx as mjx
6 | import numpy as np
7 |
8 | from mjinx.components.tasks import JaxTask, Task
9 | from mjinx.typing import ArrayOrFloat
10 |
11 |
class DummyJaxTask(JaxTask):
    """Minimal concrete JaxTask used by the base-task tests."""

    def __call__(self, data: mjx.Data) -> jnp.ndarray:
        """Imitate a CoM task: masked subtree CoM of the root of body 0."""
        root_body = self.model.body_rootid[0]
        return data.subtree_com[root_body, self.mask_idxs]
16 |
17 |
class DummyTask(Task[DummyJaxTask]):
    """Concrete Task backed by DummyJaxTask whose dimension tests can set explicitly."""

    # Class of the JAX component this task produces.
    JaxComponentType: type = DummyJaxTask

    def define_dim(self, dim: int):
        """Store ``dim`` as the task dimension (the private ``_dim`` attribute)."""
        self._dim = dim
23 |
24 |
class TestTask(unittest.TestCase):
    """Unit tests for the generic Task behaviour (cost handling, damping, error)."""

    dummy_dim: int = 3

    def setUp(self):
        # NOTE(review): the model XML literal below is empty in this copy of the file;
        # confirm its content against the upstream test.
        self.model = mjx.put_model(
            mj.MjModel.from_xml_string(
                """









"""
            )
        )
        self.task = DummyTask("test_task", cost=2.0, gain=1.0)
        self.task.update_model(self.model)

    def set_dim(self):
        """Give the task an explicit dimension of ``dummy_dim``."""
        self.task.define_dim(self.dummy_dim)

    def _set_cost_and_check(self, cost: ArrayOrFloat):
        """Checking that cost is transformed into jnp.ndarray, and values are correct"""
        self.task.update_cost(cost)
        np.testing.assert_array_equal(self.task.cost, cost)

    def test_update_gain(self):
        """Testing setting up different cost dimensions.

        NOTE: despite the method name, this exercises ``update_cost``, not the gain.
        """
        # Test scalar cost assignment
        self._set_cost_and_check(1.0)
        self._set_cost_and_check(np.array(2.0))
        self._set_cost_and_check(jnp.array(3.0))

        # Test vector cost assignment
        self._set_cost_and_check(np.zeros(self.dummy_dim))
        self._set_cost_and_check(jnp.ones(self.dummy_dim))

        # Test matrix cost assignment
        self._set_cost_and_check(np.eye(self.dummy_dim))
        self._set_cost_and_check(2 * jnp.eye(self.dummy_dim))

        # Test assigning cost with other dimension
        cost = np.zeros((self.dummy_dim, self.dummy_dim, self.dummy_dim))
        with self.assertRaises(ValueError):
            self.task.update_cost(cost)

    def test_matrix_cost(self):
        """Testing proper conversion of cost to matrix form"""
        # The matrix cannot be constructed until the dimension is specified
        self.task.cost = 1.0
        with self.assertRaises(ValueError):
            _ = self.task.matrix_cost

        self.set_dim()

        # Scalar -> jnp.eye(dim) * scalar
        self.task.update_cost(3.0)
        np.testing.assert_array_equal(self.task.matrix_cost, jnp.eye(self.task.dim) * 3)

        # Vector -> jnp.diag(vector)
        vector_cost = jnp.arange(self.task.dim)
        self.task.update_cost(vector_cost)
        np.testing.assert_array_equal(self.task.matrix_cost, jnp.diag(vector_cost))

        # Matrix -> matrix
        matrix_cost = jnp.eye(self.task.dim)
        self.task.update_cost(matrix_cost)
        np.testing.assert_array_equal(self.task.matrix_cost, matrix_cost)

        # For vector cost, the length has to be equal to dimension of the task
        vector_cost = jnp.ones(self.task.dim + 1)
        self.task.update_cost(vector_cost)
        with self.assertRaises(ValueError):
            _ = self.task.matrix_cost

        # For matrix cost, if the dimension turned out to be wrong, ValueError should be raised.
        # Note that the error is raised only when matrix_cost is accessed, even if the model is
        # already provided.
        matrix_cost = jnp.eye(self.task.dim + 1)
        self.task.update_cost(matrix_cost)
        with self.assertRaises(ValueError):
            _ = self.task.matrix_cost

    def test_lm_damping(self):
        """lm_damping is stored as given and must be non-negative."""
        lm_specific_task = DummyTask("lm_specific_task", 1.0, 1.0, lm_damping=5.0)
        self.assertEqual(lm_specific_task.lm_damping, 5.0)

        with self.assertRaises(ValueError):
            _ = DummyTask("negative_lm_task", 1.0, 1.0, lm_damping=-5.0)

    def test_error(self):
        """Calling the JAX task and compute_error give the same result."""
        self.set_dim()
        jax_task = self.task.jax_component
        data = mjx.fwd_position(self.model, mjx.make_data(self.model))
        np.testing.assert_array_almost_equal(jax_task(data), jax_task.compute_error(data))
124 |
--------------------------------------------------------------------------------
/tests/unit_tests/components/tasks/test_com_task.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import jax.numpy as jnp
4 | import mujoco as mj
5 | import mujoco.mjx as mjx
6 | import numpy as np
7 |
8 | from mjinx.components.tasks import ComTask
9 |
10 |
11 | class TestComTask(unittest.TestCase):
12 | def set_model(self, task: ComTask):
13 | self.model = mjx.put_model(
14 | mj.MjModel.from_xml_string(
15 | """
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 | w
24 |
25 | """
26 | )
27 | )
28 | task.update_model(self.model)
29 |
30 | def test_task_dim(self):
31 | """Test task dimensionality"""
32 | com_task_masked = ComTask("com_task_masked", cost=1.0, gain=1.0, mask=[True, False, True])
33 | self.assertEqual(com_task_masked.dim, 2)
34 |
35 | def test_mask_validation(self):
36 | """Test task's mask size validation"""
37 | with self.assertRaises(ValueError):
38 | _ = ComTask("com_task_small_mask", cost=1.0, gain=1.0, mask=[True, False])
39 |
40 | with self.assertRaises(ValueError):
41 | _ = ComTask("com_task_big_mask", cost=1.0, gain=1.0, mask=[True, False, True, False])
42 |
43 | def test_update_target_com(self):
44 | """Test target CoM parameter updates"""
45 | com_task = ComTask("com_task_default", cost=1.0, gain=1.0)
46 | new_target = [1.0, 2.0, 3.0]
47 | com_task.update_target_com(new_target)
48 | np.testing.assert_array_equal(com_task.target_com, new_target)
49 |
50 | too_big_target = range(4)
51 | with self.assertRaises(ValueError):
52 | com_task.update_target_com(too_big_target)
53 |
54 | def test_build_component(self):
55 | com_task = ComTask(
56 | "com_task", cost=1.0, gain=2.0, gain_fn=lambda x: 2 * x, lm_damping=0.5, mask=[True, True, False]
57 | )
58 | self.set_model(com_task)
59 | com_des = jnp.array((-0.3, 0.3))
60 | com_task.target_com = com_des
61 |
62 | jax_component = com_task.jax_component
63 | self.assertEqual(jax_component.dim, 2)
64 | np.testing.assert_array_equal(jax_component.matrix_cost, jnp.eye(2))
65 | np.testing.assert_array_equal(jax_component.vector_gain, jnp.ones(2) * 2.0)
66 | self.assertEqual(jax_component.gain_fn(4), 8)
67 | self.assertEqual(jax_component.lm_damping, 0.5)
68 | np.testing.assert_array_equal(jax_component.target_com, com_des)
69 | self.assertEqual(jax_component.mask_idxs, (0, 1))
70 |
71 | data = mjx.fwd_position(self.model, mjx.make_data(self.model))
72 | com_value = jax_component(data)
73 | np.testing.assert_array_equal(com_value, jnp.array([0.6, 0.0]))
74 |
--------------------------------------------------------------------------------
/tests/unit_tests/components/tasks/test_obj_frame_task.py:
--------------------------------------------------------------------------------
1 | """Center of mass task implementation."""
2 |
3 | import unittest
4 |
5 | import jax.numpy as jnp
6 | import mujoco as mj
7 | import mujoco.mjx as mjx
8 | import numpy as np
9 | from jaxlie import SE3
10 |
11 | from mjinx.components.tasks import FrameTask
12 |
13 |
14 | class TestObjFrameTask(unittest.TestCase):
15 | def setUp(self) -> None:
16 | self.to_wxyz_xyz: jnp.ndarray = jnp.array([3, 4, 5, 6, 0, 1, 2])
17 |
18 | def set_model(self, task: FrameTask):
19 | self.model = mjx.put_model(
20 | mj.MjModel.from_xml_string(
21 | """
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 | """
39 | )
40 | )
41 | task.update_model(self.model)
42 |
43 | def test_target_frame(self):
44 | """Testing manipulations with target frame"""
45 | frame_task = FrameTask("frame_task", cost=1.0, gain=2.0, obj_name="body3")
46 | # By default, it has to be identity
47 | np.testing.assert_array_almost_equal(SE3.identity().wxyz_xyz, frame_task.target_frame.wxyz_xyz)
48 |
49 | # Setting with SE3 object
50 | test_se3 = SE3(jnp.array([0.1, 0.2, -0.1, 0, 1, 0, 0]))
51 | frame_task.target_frame = test_se3
52 | np.testing.assert_array_almost_equal(test_se3.wxyz_xyz, frame_task.target_frame.wxyz_xyz)
53 |
54 | # Setting with sequence
55 | test_se3_seq = (-1, 0, 1, 0, 0, 1, 0)
56 | frame_task.target_frame = test_se3_seq
57 | np.testing.assert_array_almost_equal(
58 | jnp.array(test_se3_seq)[self.to_wxyz_xyz], frame_task.target_frame.wxyz_xyz
59 | )
60 |
61 | # Setting with sequence of wrong length
62 | with self.assertRaises(ValueError):
63 | frame_task.target_frame = (-1, 0, 1, 0, 0, 1, 0, 1)
64 |
65 | def test_build_component(self):
66 | frame_task = FrameTask(
67 | "frame_task",
68 | cost=1.0,
69 | gain=2.0,
70 | obj_name="body3",
71 | gain_fn=lambda x: 2 * x,
72 | lm_damping=0.5,
73 | mask=[True, False, True, True, False, True],
74 | )
75 | self.set_model(frame_task)
76 | frame_des = jnp.array([0.1, 0, -0.1, 1, 0, 0, 0])
77 | frame_task.target_frame = frame_des
78 |
79 | jax_component = frame_task.jax_component
80 |
81 | self.assertEqual(jax_component.dim, 4)
82 | np.testing.assert_array_equal(jax_component.matrix_cost, jnp.eye(jax_component.dim))
83 | np.testing.assert_array_equal(jax_component.vector_gain, jnp.ones(jax_component.dim) * 2.0)
84 | np.testing.assert_array_equal(jax_component.obj_id, frame_task.obj_id)
85 | self.assertEqual(jax_component.gain_fn(4), 8)
86 | self.assertEqual(jax_component.lm_damping, 0.5)
87 | np.testing.assert_array_almost_equal(jax_component.target_frame.wxyz_xyz, frame_des[self.to_wxyz_xyz])
88 |
89 | self.assertEqual(jax_component.mask_idxs, (0, 2, 3, 5))
90 |
91 | data = mjx.fwd_position(self.model, mjx.make_data(self.model))
92 | error = jax_component(data)
93 | np.testing.assert_array_equal(error, jnp.array([0.1, -0.1, 0.0, 0.0]))
94 |
95 | # Testing component jacobian
96 | jac = jax_component.compute_jacobian(mjx.make_data(self.model))
97 |
98 | self.assertEqual(jac.shape, (jax_component.dim, self.model.nv))
99 |
--------------------------------------------------------------------------------
/tests/unit_tests/components/tasks/test_obj_position_task.py:
--------------------------------------------------------------------------------
1 | """Center of mass task implementation."""
2 |
3 | import unittest
4 |
5 | import jax.numpy as jnp
6 | import mujoco as mj
7 | import mujoco.mjx as mjx
8 | import numpy as np
9 |
10 | from mjinx.components.tasks import PositionTask
11 |
12 |
13 | class TestObjPositionTask(unittest.TestCase):
14 | def set_model(self, task: PositionTask):
15 | self.model = mjx.put_model(
16 | mj.MjModel.from_xml_string(
17 | """
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 | """
35 | )
36 | )
37 | task.update_model(self.model)
38 |
39 | def test_target_frame(self):
40 | """Testing manipulations with target frame"""
41 | pos_task = PositionTask("pos_task", cost=1.0, gain=2.0, obj_name="body3")
42 | # By default, it has to be identity
43 | np.testing.assert_array_almost_equal(pos_task.target_pos, jnp.zeros(3))
44 |
45 | # Setting with sequence
46 | test_pos = (-1, 0, 1)
47 | pos_task.target_pos = test_pos
48 | np.testing.assert_array_almost_equal(jnp.array(test_pos), pos_task.target_pos)
49 |
50 | with self.assertRaises(ValueError):
51 | pos_task.target_pos = (0, 1, 2, 3, 4)
52 |
53 | def test_build_component(self):
54 | frame_task = PositionTask(
55 | "frame_task",
56 | cost=1.0,
57 | gain=2.0,
58 | obj_name="body3",
59 | gain_fn=lambda x: 2 * x,
60 | lm_damping=0.5,
61 | mask=[True, False, True],
62 | )
63 | self.set_model(frame_task)
64 | pos_des = jnp.array([0.1, -0.1])
65 | frame_task.target_pos = pos_des
66 |
67 | jax_component = frame_task.jax_component
68 |
69 | self.assertEqual(jax_component.dim, 2)
70 | np.testing.assert_array_equal(jax_component.matrix_cost, jnp.eye(jax_component.dim))
71 | np.testing.assert_array_equal(jax_component.vector_gain, jnp.ones(jax_component.dim) * 2.0)
72 | np.testing.assert_array_equal(jax_component.obj_id, frame_task.obj_id)
73 | self.assertEqual(jax_component.gain_fn(4), 8)
74 | self.assertEqual(jax_component.lm_damping, 0.5)
75 | np.testing.assert_array_almost_equal(jax_component.target_pos, pos_des)
76 |
77 | self.assertEqual(jax_component.mask_idxs, (0, 2))
78 |
79 | data = mjx.fwd_position(self.model, mjx.make_data(self.model))
80 | error = jax_component(data)
81 | np.testing.assert_array_equal(error, jnp.array([-0.1, 0.1]))
82 |
--------------------------------------------------------------------------------
/tests/unit_tests/configuration/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/tests/unit_tests/configuration/__init__.py
--------------------------------------------------------------------------------
/tests/unit_tests/configuration/test_collision_computation.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import jax.numpy as jnp
4 | import mujoco
5 | import mujoco.mjx as mjx
6 | import numpy as np
7 |
8 | from mjinx.configuration import compute_collision_pairs, update
9 | from mjinx.typing import CollisionPair
10 |
11 |
12 | class TestCollisionPairs(unittest.TestCase):
13 | """Test suite for collision pair computation functionality."""
14 |
15 | # MJCF template for the two-body test models; create_test_model fills it in via str.format
16 | TEST_MODEL_TEMPLATE = """
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 | """
29 |
30 | def create_test_model(self, type1, type2, pos1, pos2, size1, size2):
31 | """
32 | Helper method to create test models with different geometries.
33 |
34 | :param type1: Type of first geometry
35 | :param type2: Type of second geometry
36 | :param pos1: Position of first body [x, y, z]
37 | :param pos2: Position of second body [x, y, z]
38 | :param size1: Size parameters for first geometry (a scalar or a list/tuple)
39 | :param size2: Size parameters for second geometry (a scalar or a list/tuple)
40 | :return: Tuple of (mjx.Model, mjx.Data)
41 | """
42 | xml = self.TEST_MODEL_TEMPLATE.format(
43 | type1=type1,
44 | type2=type2,
45 | pos1=" ".join(map(str, pos1)),
46 | pos2=" ".join(map(str, pos2)),
47 | size1=" ".join(map(str, size1 if isinstance(size1, list | tuple) else [size1])),  # scalars are wrapped in a list so they format like sequences
48 | size2=" ".join(map(str, size2 if isinstance(size2, list | tuple) else [size2])),
49 | )
50 | mjx_model = mjx.put_model(mujoco.MjModel.from_xml_string(xml))
51 | mjx_data = update(mjx_model, jnp.zeros(1))  # NOTE(review): assumes the template model has nq == 1 - confirm against the XML
52 | return mjx_model, mjx_data
53 |
54 | def setUp(self):
55 | """Set up common test models."""
56 | # Model for sphere-sphere collision tests
57 | self.sphere_model, self.sphere_data = self.create_test_model(
58 | "sphere", "sphere", pos1=[0, 0, 0], pos2=[0.8, 0, 0], size1=0.5, size2=0.5
59 | )
60 |
61 | # Model for box-sphere collision tests
62 | self.box_sphere_model, self.box_sphere_data = self.create_test_model(
63 | "box", "sphere", pos1=[0, 0, 0], pos2=[0.8, 0, 0], size1=[0.5, 0.5, 0.5], size2=0.5
64 | )
65 |
66 | def test_sphere_sphere_collision(self):
67 | """Test collision detection between two spheres."""
68 | collision_pairs = [(0, 1)]
69 | contact = compute_collision_pairs(self.sphere_model, self.sphere_data, collision_pairs)
70 |
71 | # Expected distance: 0.8 - (0.5 + 0.5) = -0.2 (penetration)
72 | self.assertIsNotNone(contact.dist)
73 | np.testing.assert_allclose(contact.dist, -0.2, atol=1e-6)
74 | self.assertIsNotNone(contact.pos)
75 | np.testing.assert_allclose(contact.pos[0], [0.4, 0, 0], atol=1e-6)  # midway between the sphere surfaces at x = 0.5 and x = 0.3
76 |
77 | def test_sphere_sphere_no_collision(self):
78 | """Test non-colliding spheres."""
79 | model, data = self.create_test_model("sphere", "sphere", pos1=[0, 0, 0], pos2=[2, 0, 0], size1=0.5, size2=0.5)
80 |
81 | collision_pairs = [(0, 1)]
82 | contact = compute_collision_pairs(model, data, collision_pairs)
83 |
84 | self.assertIsNotNone(contact.dist)
85 | np.testing.assert_allclose(contact.dist, 1.0, atol=1e-6)  # 2 - (0.5 + 0.5) = 1.0 (separation)
86 |
87 | def test_box_sphere_collision(self):
88 | """Test collision detection between a box and a sphere."""
89 | collision_pairs = [(0, 1)]
90 | contact = compute_collision_pairs(self.box_sphere_model, self.box_sphere_data, collision_pairs)
91 |
92 | self.assertIsNotNone(contact.dist)
93 | np.testing.assert_allclose(contact.dist, -0.2, atol=1e-6)  # 0.8 - (0.5 + 0.5) = -0.2 (penetration)
94 |
95 | def test_multiple_collision_pairs(self):
96 | """Test handling multiple collision pairs simultaneously."""
97 | # Create model with three bodies
98 | xml_additional = """
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 | """
114 | model = mjx.put_model(mujoco.MjModel.from_xml_string(xml_additional))
115 | data = update(model, jnp.zeros(1))
116 |
117 | collision_pairs = [(0, 1), (0, 2)]
118 | contact = compute_collision_pairs(model, data, collision_pairs)
119 |
120 | self.assertIsNotNone(contact.dist)
121 | self.assertEqual(len(contact.dist), 2)  # one distance per requested pair
122 | np.testing.assert_allclose(contact.dist[0], -0.2, atol=1e-6) # penetration
123 | np.testing.assert_allclose(contact.dist[1], 0.2, atol=1e-6) # separation
124 |
125 | def test_invalid_collision_pairs(self):
126 | """Test behavior with invalid collision pair indices."""
127 | with self.assertRaises(IndexError):  # index 5 is out of range for this model
128 | compute_collision_pairs(self.sphere_model, self.sphere_data, [(0, 5)])
129 |
130 | def test_empty_collision_pairs(self):
131 | """Test behavior with empty collision pairs list."""
132 | contact = compute_collision_pairs(self.sphere_model, self.sphere_data, [])
133 |
134 | self.assertIsNotNone(contact.dist)
135 | self.assertEqual(len(contact.dist), 0)
136 | self.assertIsNotNone(contact.pos)
137 | self.assertEqual(len(contact.pos), 0)
138 |
139 | def test_contact_frame_orientation(self):
140 | """Test the orientation of contact frames."""
141 | model, data = self.create_test_model(
142 | "sphere",
143 | "sphere",
144 | pos1=[0, 0, 0],
145 | pos2=[0.8, 0.8, 0],  # Diagonal position
146 | size1=0.5,
147 | size2=0.5,
148 | )
149 |
150 | collision_pairs = [(0, 1)]
151 | contact = compute_collision_pairs(model, data, collision_pairs)
152 |
153 | self.assertIsNotNone(contact.frame)
154 | expected_direction = np.array([1, 1, 0]) / np.sqrt(2)  # unit vector from (0, 0, 0) toward (0.8, 0.8, 0)
155 | np.testing.assert_allclose(contact.frame[0, 0], expected_direction, atol=1e-6)  # first row of the first pair's frame is expected along that direction
156 |
--------------------------------------------------------------------------------
/tests/unit_tests/solvers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/based-robotics/mjinx/6ab93b7bda39a7937e78739b7270f99719140a0b/tests/unit_tests/solvers/__init__.py
--------------------------------------------------------------------------------
/tests/unit_tests/solvers/test_global_ik.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import mujoco as mj
4 | import mujoco.mjx as mjx
5 | import numpy as np
6 | import optax
7 |
8 | from mjinx.components.barriers import JointBarrier
9 | from mjinx.components.tasks import ComTask
10 | from mjinx.problem import Problem
11 | from mjinx.solvers import GlobalIKSolver
12 |
13 |
14 | class TestGlobalIK(unittest.TestCase):
15 | def setUp(self):
16 | """Setting up basic model, components, and problem data."""
17 |
18 | self.model = mjx.put_model(
19 | mj.MjModel.from_xml_string(
20 | """
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 | """
31 | )
32 | )
33 | self.dummy_task = ComTask("com_task", cost=1.0, gain=1.0)
34 | self.dummy_barrier = JointBarrier("joint_barrier", gain=1.0)
35 | self.problem = Problem(
36 | self.model,
37 | v_min=0,
38 | v_max=0,
39 | )
40 | self.problem.add_component(self.dummy_task)
41 | self.problem.add_component(self.dummy_barrier)
42 |
43 | self.problem_data = self.problem.compile()
44 |
45 | def test_init(self):
46 | """Testing solver initialization"""
47 |
48 | solver = GlobalIKSolver(model=self.model, optimizer=optax.adam(learning_rate=1e-3))
49 |
50 | with self.assertRaises(ValueError):  # q of the wrong length must be rejected
51 | _ = solver.init(q=np.arange(self.model.nq + 1))
52 |
53 | data = solver.init(q=np.arange(self.model.nq))
54 |
55 | # Check that it correctly initialized for chosen optimizer
56 | self.assertIsInstance(data.optax_state[0], optax.ScaleByAdamState)
57 | self.assertIsInstance(data.optax_state[1], optax.EmptyState)
58 |
59 | def test_loss_fn(self):  # with the barrier removed, only the CoM task contributes to the loss
60 | self.problem.remove_component(self.dummy_barrier.name)
61 | self.dummy_task.target_com = [0.3, 0.3, 0.3]
62 | new_problem_data = self.problem.compile()
63 |
64 | solver = GlobalIKSolver(model=self.model, optimizer=optax.adam(learning_rate=1e-3))
65 | loss = solver.loss_fn(np.zeros(self.model.nq), problem_data=new_problem_data)
66 |
67 | self.assertEqual(loss, 0)  # NOTE(review): exact float comparison; assertAlmostEqual would be more robust
68 |
69 | def test_loss_grad(self):  # the gradient is expected to have shape (nv,)
70 | solver = GlobalIKSolver(model=self.model, optimizer=optax.adam(learning_rate=1e-3))
71 | grad = solver.grad_fn(np.zeros(self.model.nq), problem_data=self.problem_data)
72 |
73 | self.assertEqual(grad.shape, (self.model.nv,))
74 |
75 | def test_solve(self):
76 | """Testing that solve and solve_from_data agree for the same configuration"""
77 | solver = GlobalIKSolver(model=self.model, optimizer=optax.adam(learning_rate=1e-2), dt=1e-3)
78 | solver_data = solver.init(q=np.ones(self.model.nq))
79 |
80 | with self.assertRaises(ValueError):  # q of the wrong length must be rejected
81 | solver.solve(np.ones(self.model.nq + 1), solver_data, self.problem_data)
82 |
83 | new_solution, _ = solver.solve(np.ones(self.model.nq), solver_data, self.problem_data)
84 |
85 | new_solution_from_data, _ = solver.solve_from_data(  # same configuration, passed via mjx.Data
86 | solver_data,
87 | self.problem_data,
88 | mjx.make_data(self.model).replace(qpos=np.ones(self.model.nq)),
89 | )
90 |
91 | np.testing.assert_almost_equal(new_solution.v_opt, new_solution_from_data.v_opt, decimal=3)
92 |
--------------------------------------------------------------------------------
/tests/unit_tests/test_problem.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import jax.numpy as jnp
4 | import mujoco as mj
5 | import mujoco.mjx as mjx
6 | import numpy as np
7 |
8 | from mjinx.components.barriers import JointBarrier
9 | from mjinx.components.tasks import ComTask, JaxComTask
10 | from mjinx.problem import Problem
11 |
12 |
13 | class TestProblem(unittest.TestCase):
14 | dummy_dim: int = 3
15 |
16 | def setUp(self):
17 | """Setting up component test based on ComTask example.
18 |
19 | Note that ComTask has one of the smallest additional functionality and processing
20 | compared to other components"""
21 |
22 | self.model = mjx.put_model(
23 | mj.MjModel.from_xml_string(
24 | """
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 | w
33 |
34 | """
35 | )
36 | )
37 | self.problem = Problem(
38 | self.model,
39 | v_min=0,
40 | v_max=0,
41 | )
42 |
43 | def test_v_min(self):
44 | """Testing proper v lower limit"""
45 |
46 | self.problem.v_min = 1
47 |
48 | self.assertIsInstance(self.problem.v_min, jnp.ndarray)
49 | np.testing.assert_array_equal(self.problem.v_min, jnp.ones(self.model.nv))
50 |
51 | with self.assertRaises(ValueError):
52 | self.problem.v_min = jnp.array([1, 2, 3])
53 |
54 | self.problem.v_min = np.arange(self.model.nv)
55 | self.assertIsInstance(self.problem.v_min, jnp.ndarray)
56 | np.testing.assert_array_equal(self.problem.v_min, jnp.arange(self.model.nv))
57 |
58 | jax_problem_data = self.problem.compile()
59 | self.assertIsInstance(jax_problem_data.v_min, jnp.ndarray)
60 | np.testing.assert_array_equal(jax_problem_data.v_min, jnp.arange(self.model.nv))
61 |
62 | with self.assertRaises(ValueError):
63 | self.problem.v_min = jnp.eye(self.model.nv)
64 |
65 | def test_v_max(self):
66 | """Testing proper v upper limit"""
67 |
68 | self.problem.v_max = 1
69 |
70 | np.testing.assert_array_equal(self.problem.v_max, jnp.ones(self.model.nv))
71 |
72 | with self.assertRaises(ValueError):
73 | self.problem.v_max = jnp.array([1, 2, 3])
74 |
75 | self.problem.v_max = np.arange(self.model.nv)
76 | self.assertIsInstance(self.problem.v_max, jnp.ndarray)
77 | np.testing.assert_array_equal(self.problem.v_max, jnp.arange(self.model.nv))
78 |
79 | jax_problem_data = self.problem.compile()
80 | self.assertIsInstance(jax_problem_data.v_max, jnp.ndarray)
81 | np.testing.assert_array_equal(jax_problem_data.v_max, jnp.arange(self.model.nv))
82 |
83 | with self.assertRaises(ValueError):
84 | self.problem.v_max = jnp.eye(self.model.nv)
85 |
86 | def test_add_component(self):
87 | """Testing adding component"""
88 |
89 | component = ComTask("test_component", 1.0, 1.0)
90 |
91 | self.problem.add_component(component)
92 | # Chech that model is defined and does not raise errors
93 | _ = component.model
94 |
95 | jax_problem_data = self.problem.compile()
96 |
97 | # False -> component was compiled as well
98 | self.assertFalse(component.modified)
99 | self.assertEqual(len(jax_problem_data.components), 1)
100 | self.assertIsInstance(jax_problem_data.components["test_component"], JaxComTask)
101 |
102 | component_with_same_name = JointBarrier("test_component", 1.0)
103 |
104 | with self.assertRaises(ValueError):
105 | self.problem.add_component(component_with_same_name)
106 |
107 | def test_remove_component(self):
108 | """Testing removing component"""
109 | component = ComTask("test_component", 1.0, 1.0)
110 |
111 | self.problem.add_component(component)
112 | self.problem.remove_component(component.name)
113 |
114 | jax_problem_data = self.problem.compile()
115 | self.assertEqual(len(jax_problem_data.components), 0)
116 |
117 | def test_components_access(self):
118 | """Test componnets access interface"""
119 |
120 | component = ComTask("test_component", 1.0, 1.0)
121 |
122 | self.problem.add_component(component)
123 |
124 | self.assertIsInstance(self.problem.component("test_component"), ComTask)
125 |
126 | with self.assertRaises(ValueError):
127 | _ = self.problem.component("non_existens_component")
128 |
129 | def test_tasks_access(self):
130 | """Test tasks access interface"""
131 |
132 | task = ComTask("test_task", 1.0, 1.0)
133 | barrier = JointBarrier("test_barrier", 1.0)
134 |
135 | self.problem.add_component(task)
136 | self.problem.add_component(barrier)
137 |
138 | self.assertIsInstance(self.problem.task("test_task"), ComTask)
139 |
140 | with self.assertRaises(ValueError):
141 | _ = self.problem.task("test_barrier")
142 |
143 | with self.assertRaises(ValueError):
144 | _ = self.problem.task("non_existent_task")
145 |
146 | def test_barriers_access(self):
147 | """Test barriers access interface"""
148 |
149 | task = ComTask("test_task", 1.0, 1.0)
150 | barrier = JointBarrier("test_barrier", 1.0)
151 |
152 | self.problem.add_component(task)
153 | self.problem.add_component(barrier)
154 |
155 | self.assertIsInstance(self.problem.barrier("test_barrier"), JointBarrier)
156 |
157 | with self.assertRaises(ValueError):
158 | _ = self.problem.barrier("test_task")
159 |
160 | with self.assertRaises(ValueError):
161 | _ = self.problem.barrier("non_existent_task")
162 |
163 | def test_setting_vmap_dimension(self):
164 | """Testing context manager for vmapping the dimensions"""
165 |
166 | task = ComTask("test_task", 1.0, 1.0)
167 | barrier = JointBarrier("test_barrier", 1.0)
168 |
169 | self.problem.add_component(task)
170 | self.problem.add_component(barrier)
171 |
172 | with self.problem.set_vmap_dimension() as empty_problem_data:
173 | empty_problem_data.components["test_task"].target_com = 0
174 |
175 | self.assertEqual(empty_problem_data.v_min, None)
176 | self.assertEqual(empty_problem_data.v_max, None)
177 |
178 | self.assertEqual(empty_problem_data.components["test_task"].target_com, 0)
179 |
--------------------------------------------------------------------------------