├── .github
└── workflows
│ ├── docs.yml
│ ├── formatting.yaml
│ ├── pyright.yml
│ └── pytest.yml
├── .gitignore
├── LICENSE
├── README.md
├── benchmark
└── ik_benchmark.py
├── docs
├── Makefile
├── README.md
├── requirements.txt
├── source
│ ├── _static
│ │ ├── basic_ik.mov
│ │ ├── css
│ │ │ └── custom.css
│ │ ├── logo.svg
│ │ └── logo_dark.svg
│ ├── _templates
│ │ └── sidebar
│ │ │ └── brand.html
│ ├── conf.py
│ ├── examples
│ │ ├── 01_basic_ik.rst
│ │ ├── 02_bimanual_ik.rst
│ │ ├── 03_mobile_ik.rst
│ │ ├── 04_ik_with_coll.rst
│ │ ├── 05_ik_with_manipulability.rst
│ │ ├── 06_online_planning.rst
│ │ ├── 07_trajopt.rst
│ │ ├── 08_ik_with_mimic_joints.rst
│ │ ├── 09_hand_retargeting.rst
│ │ └── 10_humanoid_retargeting.rst
│ ├── index.rst
│ └── misc
│ │ └── writing_manual_jac.rst
└── update_example_docs.py
├── examples
├── 01_basic_ik.py
├── 02_bimanual_ik.py
├── 03_mobile_ik.py
├── 04_ik_with_coll.py
├── 05_ik_with_manipulability.py
├── 06_online_planning.py
├── 07_trajopt.py
├── 08_ik_with_mimic_joints.py
├── 09_hand_retargeting.py
├── 10_humanoid_retargeting.py
├── pyroki_snippets
│ ├── __init__.py
│ ├── _online_planning.py
│ ├── _solve_ik.py
│ ├── _solve_ik_with_base.py
│ ├── _solve_ik_with_collision.py
│ ├── _solve_ik_with_manipulability.py
│ ├── _solve_ik_with_multiple_targets.py
│ └── _trajopt.py
└── retarget_helpers
│ ├── _utils.py
│ ├── hand
│ ├── dexycb_motion.pkl
│ └── shadowhand_urdf.zip
│ └── humanoid
│ ├── heightmap.npy
│ ├── left_foot_contact.npy
│ ├── right_foot_contact.npy
│ └── smpl_keypoints.npy
├── pyproject.toml
└── src
└── pyroki
├── __init__.py
├── _robot.py
├── _robot_urdf_parser.py
├── collision
├── __init__.py
├── _collision.py
├── _geometry.py
├── _geometry_pairs.py
├── _robot_collision.py
└── _utils.py
├── costs
├── __init__.py
├── _costs.py
├── _pose_cost_analytic_jac.py
└── _pose_cost_numerical_jac.py
├── utils.py
└── viewer
├── __init__.py
├── _batched_urdf.py
├── _manipulability_ellipse.py
└── _weight_tuner.py
/.github/workflows/docs.yml:
--------------------------------------------------------------------------------
1 | name: docs
2 |
3 | on:
4 | push:
5 | branches: [main]
6 |
7 | permissions:
8 | contents: write
9 |
10 | jobs:
11 | docs:
12 | runs-on: ubuntu-latest
13 | steps:
14 | # Check out source.
15 | - uses: actions/checkout@v4
16 | with:
17 | fetch-depth: 0 # This ensures the entire history is fetched so we can switch branches
18 |
19 | - name: Set up Python
20 | uses: actions/setup-python@v5
21 | with:
22 | python-version: "3.12"
23 |
24 | - name: Set up dependencies
25 | run: |
26 | sudo apt update
27 | sudo apt install -y libsuitesparse-dev
28 | pip install uv
29 | uv pip install --system -e ".[dev,examples]"
30 | uv pip install --system -r docs/requirements.txt
31 |
32 | # Build documentation.
33 | - name: Building documentation
34 | run: |
35 | sphinx-build docs/source docs/build -b dirhtml
36 |
37 | # Deploy the built site (./docs/build) to GitHub Pages.
38 | - name: Deploy to GitHub Pages
39 | uses: peaceiris/actions-gh-pages@v4
40 | with:
41 | github_token: ${{ secrets.GITHUB_TOKEN }}
42 | publish_dir: ./docs/build
43 |
--------------------------------------------------------------------------------
/.github/workflows/formatting.yaml:
--------------------------------------------------------------------------------
1 | name: formatting
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | pull_request:
7 | branches: [main]
8 |
9 | jobs:
10 | build:
11 | runs-on: ubuntu-latest
12 |
13 | steps:
14 | - uses: actions/checkout@v4
15 | - name: Set up Python 3.12
16 | uses: actions/setup-python@v5
17 | with:
18 | python-version: "3.12"
19 | - name: Install dependencies
20 | run: |
21 | sudo apt update
22 | sudo apt install -y libsuitesparse-dev
23 | pip install uv
24 | uv pip install --system -e ".[dev,examples]"
25 | - name: Run Ruff
26 | run: ruff check docs/ src/ examples/
27 | - name: Run Ruff format
28 | run: ruff format docs/ src/ examples/ --check
29 |
--------------------------------------------------------------------------------
/.github/workflows/pyright.yml:
--------------------------------------------------------------------------------
1 | name: pyright
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | pull_request:
7 | branches: [main]
8 |
9 | jobs:
10 | pyright:
11 | runs-on: ubuntu-latest
12 | strategy:
13 | matrix:
14 | python-version: ["3.10", "3.11", "3.12", "3.13"]
15 |
16 | steps:
17 | - uses: actions/checkout@v4
18 | - name: Set up Python ${{ matrix.python-version }}
19 | uses: actions/setup-python@v5
20 | with:
21 | python-version: ${{ matrix.python-version }}
22 | - name: Install dependencies
23 | run: |
24 | sudo apt update
25 | sudo apt install -y libsuitesparse-dev
26 | pip install uv
27 | uv pip install --system -e ".[dev,examples]"
28 | - name: Run pyright
29 | run: |
30 | pyright .
31 |
--------------------------------------------------------------------------------
/.github/workflows/pytest.yml:
--------------------------------------------------------------------------------
1 | name: pytest
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | pull_request:
7 | branches: [main]
8 |
9 | jobs:
10 | build:
11 | runs-on: ubuntu-latest
12 | strategy:
13 | matrix:
14 | python-version: ["3.10", "3.11", "3.12", "3.13"]
15 |
16 | steps:
17 | - uses: actions/checkout@v4
18 | - name: Set up Python ${{ matrix.python-version }}
19 | uses: actions/setup-python@v5
20 | with:
21 | python-version: ${{ matrix.python-version }}
22 | - name: Install dependencies
23 | run: |
24 | sudo apt update
25 | sudo apt install -y libsuitesparse-dev
26 | pip install uv
27 | uv pip install --system -e ".[dev,examples]"
28 | - name: Test with pytest
29 | run: |
30 | pytest
31 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 | *.swo
3 | *.pyc
4 | *.egg-info
5 | *.ipynb_checkpoints
6 | __pycache__
7 | .coverage
8 | htmlcov
9 | .mypy_cache
10 | .dmypy.json
11 | .hypothesis
12 | .envrc
13 | .lvimrc
14 | .DS_Store
15 | .envrc
16 | .vite
17 | build
18 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Chung Min Kim
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # `PyRoki`: Python Robot Kinematics Library
2 |
3 | **[Project page](https://pyroki-toolkit.github.io/) •
4 | [arXiv](https://arxiv.org/abs/2505.03728)**
5 |
6 | `PyRoki` is a modular, extensible, and cross-platform toolkit for kinematic optimization, all in Python.
7 |
8 | Core features include:
9 |
10 | - Differentiable robot forward kinematics model from a URDF.
11 | - Automatic generation of robot collision primitives (e.g., capsules).
12 | - Differentiable collision bodies with numpy broadcasting logic.
13 | - Common cost implementations (e.g., end effector pose, self/world-collision, manipulability).
14 | - Arbitrary costs, autodiff or analytical Jacobians.
15 | - Integration with a [Levenberg-Marquardt Solver](https://github.com/brentyi/jaxls) that supports optimization on manifolds (e.g., [Lie groups](https://github.com/brentyi/jaxlie)).
16 | - Cross-platform support (CPU, GPU, TPU) via JAX.
17 |
18 | Please refer to the [documentation](https://chungmin99.github.io/pyroki/) for more details, features, and usage examples.
19 |
20 | ---
21 |
22 | ## Installation
23 |
24 | You can install `pyroki` with `pip`, on Python 3.12+:
25 |
26 | ```
27 | git clone https://github.com/chungmin99/pyroki.git
28 | cd pyroki
29 | pip install -e .
30 | ```
31 |
32 | Python 3.10-3.11 should also work, but support may be dropped in the future.
33 |
34 | ## Status
35 |
36 | _May 6, 2025_: Initial release
37 |
38 | We are preparing and will release by _May 16, 2025_:
39 |
40 | - [x] Examples + documentation for hand / humanoid motion retargeting
41 | - [x] Documentation for using manually defined Jacobians
42 | - [x] Support for Python 3.10+
43 |
44 | ## Limitations
45 |
46 | - **Soft constraints only**: We use a nonlinear least-squares formulation and model joint limits, collision avoidance, etc. as soft penalties with high weights rather than hard constraints.
47 | - **Static shapes & JIT overhead**: JAX JIT compilation is triggered on first run and when input shapes change (e.g., number of targets, obstacles). Arrays can be pre-padded to vectorize over inputs with different shapes.
48 | - **No sampling-based planners**: We don't include sampling-based planners (e.g., graphs, trees).
49 | - **Collision performance**: Speed and accuracy comparisons against other robot toolkits such as cuRobo have not been extensively performed, and PyRoki is likely slower than other toolkits for collision-heavy scenarios.
50 |
51 | The following are current implementation limitations that could potentially be addressed in future versions:
52 |
53 | - **Joint types**: We only support revolute, continuous, prismatic, and fixed joints. Other URDF joint types are treated as fixed joints.
54 | - **Collision geometry**: We are limited to sphere, capsule, halfspace, and heightmap geometries. Mesh collision is approximated as capsules.
55 | - **Kinematic structures**: We only support kinematic trees; no closed-loop mechanisms or parallel manipulators.
56 |
57 | ## Citation
58 |
59 | This codebase is released with the following preprint.
60 |
61 |
62 | Chung Min Kim*, Brent Yi*, Hongsuk Choi, Yi Ma, Ken Goldberg, Angjoo Kanazawa.
63 | PyRoki: A Modular Toolkit for Robot Kinematic Optimization
64 | arXiv, 2025.
65 | |
66 |
67 |
68 | \*Equal Contribution, UC Berkeley.
69 |
70 | Please cite PyRoki if you find this work useful for your research:
71 |
72 | ```
73 | @misc{pyroki2025,
74 | title={PyRoki: A Modular Toolkit for Robot Kinematic Optimization},
75 | author={Chung Min Kim* and Brent Yi* and Hongsuk Choi and Yi Ma and Ken Goldberg and Angjoo Kanazawa},
76 | year={2025},
77 | eprint={2505.03728},
78 | archivePrefix={arXiv},
79 | primaryClass={cs.RO},
80 | url={https://arxiv.org/abs/2505.03728},
81 | }
82 | ```
83 |
84 | Thanks!
85 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = sphinx-build
7 | SPHINXPROJ = pyroki
8 | SOURCEDIR = source
9 | BUILDDIR = ./build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | # PyRoki Documentation
2 |
3 | This directory contains the documentation for PyRoki.
4 |
5 | ## Building the Documentation
6 |
7 | To build the documentation:
8 |
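9 | 1. Install PyRoki (with its `dev` and `examples` extras, as in CI) and the documentation dependencies, from the repository root:
10 |
11 | ```bash
12 | pip install -e ".[dev,examples]" -r docs/requirements.txt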
13 | ```
14 |
15 | 2. Build the documentation:
16 |
17 | ```bash
18 | cd docs
19 | make html
20 | ```
21 |
22 | 3. View the documentation:
23 |
24 | ```bash
25 | # On macOS
26 | open build/html/index.html
27 |
28 | # On Linux
29 | xdg-open build/html/index.html
30 | ```
31 |
32 | ## Contributing Screenshots
33 |
34 | When adding new documentation, screenshots and visual examples significantly improve user understanding.
35 |
36 | Visuals are especially useful for:
37 |
38 | - The example pages in `source/examples/` (e.g., alongside the existing `_static/basic_ik.mov` recording)
39 | - IK examples (basic, bimanual, mobile base, collision, manipulability, mimic joints)
40 | - Online planning and trajectory optimization examples
41 | - Hand and humanoid retargeting examples
42 |
43 | Place new images and videos in `source/_static/` and reference them from the relevant `.rst` page.
44 |
45 | ## Documentation Structure
46 |
47 | - `source/` - Source files for the documentation
48 | - `_static/` - Static files (CSS, logos, videos, etc.)
49 | - `_templates/` - Template overrides (e.g., `sidebar/brand.html`)
50 | - `examples/` - Auto-generated example pages (see below)
51 | - `*.rst` - reStructuredText files for documentation pages (e.g., `index.rst`, `misc/`)
52 | - `conf.py` - Sphinx configuration
53 |
54 | ## Auto-Generated Example Documentation
55 |
56 | Example documentation is automatically generated from the examples in the `examples/` directory using the `update_example_docs.py` script. To update the example documentation after making changes to examples:
57 |
58 | ```bash
59 | cd docs
60 | python update_example_docs.py
61 | ```
62 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx==8.0.2
2 | furo==2024.8.6
3 | docutils==0.20.1
4 | toml==0.10.2
5 | sphinxcontrib-video==0.4.1
6 | git+https://github.com/brentyi/sphinxcontrib-programoutput.git
7 | git+https://github.com/brentyi/ansi.git
8 | git+https://github.com/sphinx-contrib/googleanalytics.git
9 |
10 | snowballstemmer==2.2.0 # https://github.com/snowballstem/snowball/issues/229
11 |
--------------------------------------------------------------------------------
/docs/source/_static/basic_ik.mov:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chungmin99/pyroki/f234516fe12c57795f31ca4b4fc9b2e04c6cb233/docs/source/_static/basic_ik.mov
--------------------------------------------------------------------------------
/docs/source/_static/css/custom.css:
--------------------------------------------------------------------------------
1 | img.sidebar-logo {
2 | width: 10em;
3 | margin: 1em 0 0 0;
4 | }
5 |
--------------------------------------------------------------------------------
/docs/source/_static/logo.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/docs/source/_templates/sidebar/brand.html:
--------------------------------------------------------------------------------
1 |
5 | {% block brand_content %} {%- if logo_url %}
6 |
11 | {%- endif %} {%- if theme_light_logo and theme_dark_logo %}
12 |
24 | {%- endif %}
25 |
26 | {% endblock brand_content %}
27 |
28 |
29 |
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Configuration file for the Sphinx documentation builder.
4 | #
5 | # This file does only contain a selection of the most common options. For a
6 | # full list see the documentation:
7 | # http://www.sphinx-doc.org/en/stable/config
8 |
9 | import os
10 | from typing import Dict, List
11 |
12 | import pyroki
13 |
14 | # -- Path setup --------------------------------------------------------------
15 |
16 | # If extensions (or modules to document with autodoc) are in another directory,
17 | # add these directories to sys.path here. If the directory is relative to the
18 | # documentation root, use os.path.abspath to make it absolute, like shown here.
19 | #
20 |
21 | # -- Project information -----------------------------------------------------
22 |
23 | project = "pyroki" # Change project name
24 | copyright = "2025" # Update copyright year/holder if needed
25 | author = "chungmin99" # Update author name
26 |
27 | version: str = os.environ.get(
28 | "PYROKI_VERSION_STR_OVERRIDE", pyroki.__version__
29 | ) # Remove this
30 |
31 | # Formatting!
32 | # 0.1.30 => v0.1.30
33 | # dev => dev
34 | if not version.isalpha():
35 | version = "v" + version
36 |
37 | # The full version, including alpha/beta/rc tags
38 | release = version # Use the same version for release for now
39 |
40 |
41 | # -- General configuration ---------------------------------------------------
42 |
43 | # If your documentation needs a minimal Sphinx version, state it here.
44 | #
45 | # needs_sphinx = '1.0'
46 |
47 | # Add any Sphinx extension module names here, as strings. They can be
48 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
49 | # ones.
50 | extensions = [
51 | "sphinx.ext.autodoc",
52 | "sphinx.ext.todo",
53 | "sphinx.ext.coverage",
54 | "sphinx.ext.mathjax",
55 | "sphinx.ext.githubpages",
56 | "sphinx.ext.napoleon",
57 | # "sphinx.ext.inheritance_diagram",
58 | "sphinxcontrib.video",
59 | "sphinx.ext.viewcode",
60 | "sphinxcontrib.programoutput",
61 | "sphinxcontrib.ansi",
62 | # "sphinxcontrib.googleanalytics", # google analytics extension https://github.com/sphinx-contrib/googleanalytics/tree/master
63 | ]
64 | programoutput_use_ansi = True
65 | html_ansi_stylesheet = "black-on-white.css"
66 | html_static_path = ["_static"]
67 | html_theme_options = {
68 | "light_css_variables": {
69 | "color-code-background": "#f4f4f4",
70 | "color-code-foreground": "#000",
71 | },
72 | # Remove viser-specific footer icon
73 | "footer_icons": [
74 | {
75 | "name": "GitHub",
76 | "url": "https://github.com/chungmin99/pyroki-dev",
77 | "html": """
78 |
81 | """,
82 | "class": "",
83 | },
84 | ],
85 | # Remove viser-specific logos
86 | "light_logo": "logo.svg",
87 | "dark_logo": "logo_dark.svg",
88 | }
89 |
90 | # Pull documentation types from hints
91 | autodoc_typehints = "both"
92 | autodoc_class_signature = "separated"
93 | autodoc_default_options = {
94 | "members": True,
95 | "member-order": "bysource",
96 | "undoc-members": True,
97 | "inherited-members": True,
98 | "exclude-members": "__init__, __post_init__",
99 | "imported-members": True,
100 | }
101 |
102 | # Add any paths that contain templates here, relative to this directory.
103 | templates_path = ["_templates"]
104 |
105 | # The suffix(es) of source filenames.
106 | # You can specify multiple suffix as a list of string:
107 | #
108 | source_suffix = ".rst"
109 | # source_suffix = ".rst"
110 |
111 | # The master toctree document.
112 | master_doc = "index"
113 |
114 | # The language for content autogenerated by Sphinx. Refer to documentation
115 | # for a list of supported languages.
116 | #
117 | # This is also used if you do content translation via gettext catalogs.
118 | # Usually you set "language" from the command line for these cases.
119 | language: str = "en"
120 |
121 | # List of patterns, relative to source directory, that match files and
122 | # directories to ignore when looking for source files.
123 | # This pattern also affects html_static_path and html_extra_path .
124 | exclude_patterns: List[str] = []
125 |
126 | # The name of the Pygments (syntax highlighting) style to use.
127 | pygments_style = "default"
128 |
129 |
130 | # -- Options for HTML output -------------------------------------------------
131 |
132 | # The theme to use for HTML and HTML Help pages. See the documentation for
133 | # a list of builtin themes.
134 | #
135 | html_theme = "furo"
136 | html_title = "pyroki" # Used instead of Sphinx's default "<project> <version> documentation" title.
137 |
138 |
139 | # Theme options are theme-specific and customize the look and feel of a theme
140 | # further. For a list of options available for each theme, see the
141 | # documentation.
142 | #
143 | # html_theme_options = {} (the active html_theme_options dict is defined in the General configuration section)
144 |
145 | # Add any paths that contain custom static files (such as style sheets) here,
146 | # relative to this directory. They are copied after the builtin static files,
147 | # so a file named "default.css" will overwrite the builtin "default.css".
148 | # html_static_path = ["_static"] (already set in the General configuration section)
149 |
150 | # Custom sidebar templates, must be a dictionary that maps document names
151 | # to template names.
152 | #
153 | # The default sidebars (for documents that don't match any pattern) are
154 | # defined by theme itself. Builtin themes are using these templates by
155 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
156 | # 'searchbox.html']``.
157 | #
158 | # html_sidebars = {}
159 |
160 |
161 | # -- Options for HTMLHelp output ---------------------------------------------
162 |
163 | # Output file base name for HTML help builder.
164 | htmlhelp_basename = "pyroki_doc"
165 |
166 |
167 | # -- Options for Github output ------------------------------------------------
168 |
169 | sphinx_to_github = True # NOTE(review): no `sphinxtogithub` extension is listed in `extensions`, so these three options appear to be unused.
170 | sphinx_to_github_verbose = True
171 | sphinx_to_github_encoding = "utf-8"
172 |
173 |
174 | # -- Options for LaTeX output ------------------------------------------------
175 |
176 | latex_elements: Dict[str, str] = {
177 | # The paper size ('letterpaper' or 'a4paper').
178 | #
179 | # 'papersize': 'letterpaper',
180 | # The font size ('10pt', '11pt' or '12pt').
181 | #
182 | # 'pointsize': '10pt',
183 | # Additional stuff for the LaTeX preamble.
184 | #
185 | # 'preamble': '',
186 | # Latex figure (float) alignment
187 | #
188 | # 'figure_align': 'htbp',
189 | }
190 |
191 | # Grouping the document tree into LaTeX files. List of tuples
192 | # (source start file, target name, title,
193 | # author, documentclass [howto, manual, or own class]).
194 | latex_documents = [
195 | (
196 | master_doc,
197 | "pyroki.tex", # Update target name
198 | "pyroki", # Update title
199 | "Your Name", # Update author
200 | "manual",
201 | ),
202 | ]
203 |
204 |
205 | # -- Options for manual page output ------------------------------------------
206 |
207 | # One entry per manual page. List of tuples
208 | # (source start file, name, description, authors, manual section).
209 | man_pages = [
210 | (master_doc, "pyroki", "pyroki documentation", [author], 1)
211 | ] # Update name and description
212 |
213 |
214 | # -- Options for Texinfo output ----------------------------------------------
215 |
216 | # Grouping the document tree into Texinfo files. List of tuples
217 | # (source start file, target name, title, author,
218 | # dir menu entry, description, category)
219 | texinfo_documents = [
220 | (
221 | master_doc,
222 | "pyroki", # Update target name
223 | "pyroki", # Update title
224 | author,
225 | "pyroki", # Update dir menu entry
226 | "Python Robot Kinematics library", # Update description
227 | "Miscellaneous",
228 | ),
229 | ]
230 |
231 |
232 | # -- Extension configuration --------------------------------------------------
233 |
234 | # Google Analytics ID
235 | # googleanalytics_id = "G-RRGY51J5ZH" # Remove this
236 |
237 | # -- Options for todo extension ----------------------------------------------
238 |
239 | # If true, `todo` and `todoList` produce output, else they produce nothing.
240 | todo_include_todos = True
241 |
242 | # -- Setup function ----------------------------------------
243 |
244 |
245 | def setup(app): # Sphinx calls this hook with the application object when it loads conf.py.
246 | app.add_css_file("css/custom.css") # Resolved relative to html_static_path, i.e. _static/css/custom.css.
247 |
248 |
249 | # -- Napoleon settings -------------------------------------------------------
250 |
251 | # Settings for parsing non-sphinx style docstrings. We use Google style in this
252 | # project.
253 | napoleon_google_docstring = True
254 | napoleon_numpy_docstring = False
255 | napoleon_include_init_with_doc = False # __init__ is also excluded via autodoc_default_options["exclude-members"].
256 | napoleon_include_private_with_doc = False
257 | napoleon_include_special_with_doc = True
258 | napoleon_use_admonition_for_examples = False
259 | napoleon_use_admonition_for_notes = False
260 | napoleon_use_admonition_for_references = False
261 | napoleon_use_ivar = False
262 | napoleon_use_param = True
263 | napoleon_use_rtype = True
264 | napoleon_preprocess_types = True
265 | napoleon_type_aliases = None
266 | napoleon_attr_annotations = True
267 |
--------------------------------------------------------------------------------
/docs/source/examples/01_basic_ik.rst:
--------------------------------------------------------------------------------
1 | .. Comment: this file is automatically generated by `update_example_docs.py`.
2 | It should not be modified manually.
3 |
4 | Basic IK
5 | ==========================================
6 |
7 |
8 | Simplest Inverse Kinematics Example using PyRoki.
9 |
10 | All examples can be run by first cloning the PyRoki repository, which includes the ``pyroki_snippets`` implementation details.
11 |
12 |
13 |
14 | .. code-block:: python
15 | :linenos:
16 |
17 |
18 | import time
19 |
20 | import numpy as np
21 | import pyroki as pk
22 | import viser
23 | from robot_descriptions.loaders.yourdfpy import load_robot_description
24 | from viser.extras import ViserUrdf
25 |
26 | import pyroki_snippets as pks
27 |
28 |
29 | def main():
30 | """Main function for basic IK: the Panda hand tracks an interactive transform-control target."""
31 |
32 | urdf = load_robot_description("panda_description")
33 | target_link_name = "panda_hand"
34 |
35 | # Create robot.
36 | robot = pk.Robot.from_urdf(urdf)
37 |
38 | # Set up visualizer.
39 | server = viser.ViserServer()
40 | server.scene.add_grid("/ground", width=2, height=2)
41 | urdf_vis = ViserUrdf(server, urdf, root_node_name="/base")
42 |
43 | # Create interactive controller with initial position.
44 | ik_target = server.scene.add_transform_controls(
45 | "/ik_target", scale=0.2, position=(0.61, 0.0, 0.56), wxyz=(0, 0, 1, 0)
46 | )
47 | timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True) # Read-only (disabled) GUI number for the smoothed solve time.
48 |
49 | while True: # Re-solve every iteration from the target's current pose; runs until the process is stopped.
50 | # Solve IK.
51 | start_time = time.time()
52 | solution = pks.solve_ik(
53 | robot=robot,
54 | target_link_name=target_link_name,
55 | target_position=np.array(ik_target.position),
56 | target_wxyz=np.array(ik_target.wxyz),
57 | )
58 |
59 | # Update timing handle.
60 | elapsed_time = time.time() - start_time
61 | timing_handle.value = 0.99 * timing_handle.value + 0.01 * (elapsed_time * 1000) # Exponential moving average (weights 0.99/0.01) of solve time, seconds converted to ms.
62 |
63 | # Update visualizer.
64 | urdf_vis.update_cfg(solution)
65 |
66 |
67 | if __name__ == "__main__":
68 | main()
69 |
--------------------------------------------------------------------------------
/docs/source/examples/02_bimanual_ik.rst:
--------------------------------------------------------------------------------
1 | .. Comment: this file is automatically generated by `update_example_docs.py`.
2 | It should not be modified manually.
3 |
4 | Bimanual IK
5 | ==========================================
6 |
7 |
8 | Same as 01_basic_ik.py, but with two end effectors!
9 |
10 | All examples can be run by first cloning the PyRoki repository, which includes the ``pyroki_snippets`` implementation details.
11 |
12 |
13 |
14 | .. code-block:: python
15 | :linenos:
16 |
17 |
18 | import time
19 | import viser
20 | from robot_descriptions.loaders.yourdfpy import load_robot_description
21 | import numpy as np
22 |
23 | import pyroki as pk
24 | from viser.extras import ViserUrdf
25 | import pyroki_snippets as pks
26 |
27 |
28 | def main():
29 | """Main function for bimanual IK: two YuMi links each track their own interactive target."""
30 |
31 | urdf = load_robot_description("yumi_description")
32 | target_link_names = ["yumi_link_7_r", "yumi_link_7_l"] # Order must match the stacked target arrays passed to the solver below.
33 |
34 | # Create robot.
35 | robot = pk.Robot.from_urdf(urdf)
36 |
37 | # Set up visualizer.
38 | server = viser.ViserServer()
39 | server.scene.add_grid("/ground", width=2, height=2)
40 | urdf_vis = ViserUrdf(server, urdf, root_node_name="/base")
41 |
42 | # Create interactive controller with initial position.
43 | ik_target_0 = server.scene.add_transform_controls(
44 | "/ik_target_0", scale=0.2, position=(0.41, -0.3, 0.56), wxyz=(0, 0, 1, 0)
45 | )
46 | ik_target_1 = server.scene.add_transform_controls(
47 | "/ik_target_1", scale=0.2, position=(0.41, 0.3, 0.56), wxyz=(0, 0, 1, 0)
48 | )
49 | timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True) # Read-only (disabled) GUI number for the smoothed solve time.
50 |
51 | while True: # Re-solve every iteration from the targets' current poses; runs until the process is stopped.
52 | # Solve IK.
53 | start_time = time.time()
54 | solution = pks.solve_ik_with_multiple_targets(
55 | robot=robot,
56 | target_link_names=target_link_names,
57 | target_positions=np.array([ik_target_0.position, ik_target_1.position]), # One row per entry of target_link_names.
58 | target_wxyzs=np.array([ik_target_0.wxyz, ik_target_1.wxyz]),
59 | )
60 |
61 | # Update timing handle.
62 | elapsed_time = time.time() - start_time
63 | timing_handle.value = 0.99 * timing_handle.value + 0.01 * (elapsed_time * 1000) # Exponential moving average (weights 0.99/0.01) of solve time, seconds converted to ms.
64 |
65 | # Update visualizer.
66 | urdf_vis.update_cfg(solution)
67 |
68 |
69 | if __name__ == "__main__":
70 | main()
71 |
--------------------------------------------------------------------------------
/docs/source/examples/03_mobile_ik.rst:
--------------------------------------------------------------------------------
1 | .. Comment: this file is automatically generated by `update_example_docs.py`.
2 | It should not be modified manually.
3 |
4 | Mobile IK
5 | ==========================================
6 |
7 |
8 | Same as 01_basic_ik.py, but with a mobile base!
9 |
10 | All examples can be run by first cloning the PyRoki repository, which includes the ``pyroki_snippets`` implementation details.
11 |
12 |
13 |
14 | .. code-block:: python
15 | :linenos:
16 |
17 |
18 | import time
19 | import viser
20 | from robot_descriptions.loaders.yourdfpy import load_robot_description
21 | import numpy as np
22 |
23 | import pyroki as pk
24 | from viser.extras import ViserUrdf
25 | import pyroki_snippets as pks
26 |
27 |
28 | def main():
29 | """Main function for IK with a mobile base.
30 | The base may only translate in x/y (z position fixed) and rotate about the z-axis, and is biased towards being at the origin.
31 | """
32 |
33 | urdf = load_robot_description("fetch_description")
34 | target_link_name = "gripper_link"
35 |
36 | # Create robot.
37 | robot = pk.Robot.from_urdf(urdf)
38 |
39 | # Set up visualizer.
40 | server = viser.ViserServer()
41 | server.scene.add_grid("/ground", width=2, height=2)
42 | base_frame = server.scene.add_frame("/base", show_axes=False)
43 | urdf_vis = ViserUrdf(server, urdf, root_node_name="/base")
44 |
45 | # Create interactive controller with initial position.
46 | ik_target = server.scene.add_transform_controls(
47 | "/ik_target", scale=0.2, position=(0.61, 0.0, 0.56), wxyz=(0, 0.707, 0, -0.707)
48 | )
49 | timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True) # Read-only (disabled) GUI number for the smoothed solve time.
50 |
51 | cfg = np.array(robot.joint_var_cls(0).default_factory()) # Initial joint configuration; replaced by the solver's output each iteration.
52 |
53 | while True: # Re-solve every iteration; runs until the process is stopped.
54 | # Solve IK.
55 | start_time = time.time()
56 | base_pos, base_wxyz, cfg = pks.solve_ik_with_base(
57 | robot=robot,
58 | target_link_name=target_link_name,
59 | target_position=np.array(ik_target.position),
60 | target_wxyz=np.array(ik_target.wxyz),
61 | fix_base_position=(False, False, True), # Only free along xy plane.
62 | fix_base_orientation=(True, True, False), # Free along z-axis rotation.
63 | prev_pos=base_frame.position, # Previous base pose, written back to base_frame after each solve.
64 | prev_wxyz=base_frame.wxyz,
65 | prev_cfg=cfg,
66 | )
67 |
68 | # Update timing handle.
69 | elapsed_time = time.time() - start_time
70 | timing_handle.value = 0.99 * timing_handle.value + 0.01 * (elapsed_time * 1000) # Exponential moving average (weights 0.99/0.01) of solve time, seconds converted to ms.
71 |
72 | # Update visualizer.
73 | urdf_vis.update_cfg(cfg)
74 | base_frame.position = np.array(base_pos)
75 | base_frame.wxyz = np.array(base_wxyz)
76 |
77 |
78 | if __name__ == "__main__":
79 | main()
80 |
--------------------------------------------------------------------------------
/docs/source/examples/04_ik_with_coll.rst:
--------------------------------------------------------------------------------
1 | .. Comment: this file is automatically generated by `update_example_docs.py`.
2 | It should not be modified manually.
3 |
4 | IK with Collision
5 | ==========================================
6 |
7 |
8 | Basic Inverse Kinematics with Collision Avoidance using PyRoKi.
9 |
10 | All examples can be run by first cloning the PyRoki repository, which includes the ``pyroki_snippets`` implementation details.
11 |
12 |
13 |
14 | .. code-block:: python
15 | :linenos:
16 |
17 |
18 | import time
19 |
20 | import numpy as np
21 | import pyroki as pk
22 | import viser
23 | from pyroki.collision import HalfSpace, RobotCollision, Sphere
24 | from robot_descriptions.loaders.yourdfpy import load_robot_description
25 | from viser.extras import ViserUrdf
26 |
27 | import pyroki_snippets as pks
28 |
29 |
30 | def main():
31 | """Main function for basic IK with collision."""
32 | urdf = load_robot_description("panda_description")
33 | target_link_name = "panda_hand"
34 | robot = pk.Robot.from_urdf(urdf)
35 |
36 | robot_coll = RobotCollision.from_urdf(urdf)
37 | plane_coll = HalfSpace.from_point_and_normal(
38 | np.array([0.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0])
39 | )
40 | sphere_coll = Sphere.from_center_and_radius(
41 | np.array([0.0, 0.0, 0.0]), np.array([0.05])
42 | )
43 |
44 | # Set up visualizer.
45 | server = viser.ViserServer()
46 | server.scene.add_grid("/ground", width=2, height=2, cell_size=0.1)
47 | urdf_vis = ViserUrdf(server, urdf, root_node_name="/robot")
48 |
49 | # Create interactive controller for IK target.
50 | ik_target_handle = server.scene.add_transform_controls(
51 | "/ik_target", scale=0.2, position=(0.5, 0.0, 0.5), wxyz=(0, 0, 1, 0)
52 | )
53 |
54 | # Create interactive controller and mesh for the sphere obstacle.
55 | sphere_handle = server.scene.add_transform_controls(
56 | "/obstacle", scale=0.2, position=(0.4, 0.3, 0.4)
57 | )
58 | server.scene.add_mesh_trimesh("/obstacle/mesh", mesh=sphere_coll.to_trimesh())
59 |
60 | timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True)
61 |
62 | while True:
63 | start_time = time.time()
64 |
65 | sphere_coll_world_current = sphere_coll.transform_from_wxyz_position(
66 | wxyz=np.array(sphere_handle.wxyz),
67 | position=np.array(sphere_handle.position),
68 | )
69 |
70 | world_coll_list = [plane_coll, sphere_coll_world_current]
71 | solution = pks.solve_ik_with_collision(
72 | robot=robot,
73 | coll=robot_coll,
74 | world_coll_list=world_coll_list,
75 | target_link_name=target_link_name,
76 | target_position=np.array(ik_target_handle.position),
77 | target_wxyz=np.array(ik_target_handle.wxyz),
78 | )
79 |
80 | # Update timing handle.
81 | timing_handle.value = (time.time() - start_time) * 1000
82 |
83 | # Update visualizer.
84 | urdf_vis.update_cfg(solution)
85 |
86 |
87 | if __name__ == "__main__":
88 | main()
89 |
--------------------------------------------------------------------------------
/docs/source/examples/05_ik_with_manipulability.rst:
--------------------------------------------------------------------------------
1 | .. Comment: this file is automatically generated by `update_example_docs.py`.
2 | It should not be modified manually.
3 |
4 | IK with Manipulability
5 | ==========================================
6 |
7 |
8 | Inverse Kinematics with Manipulability using PyRoKi.
9 |
10 | All examples can be run by first cloning the PyRoki repository, which includes the ``pyroki_snippets`` implementation details.
11 |
12 |
13 |
14 | .. code-block:: python
15 | :linenos:
16 |
17 |
18 | import time
19 | import viser
20 | from robot_descriptions.loaders.yourdfpy import load_robot_description
21 | import numpy as np
22 |
23 | import pyroki as pk
24 | from viser.extras import ViserUrdf
25 | import pyroki_snippets as pks
26 |
27 |
28 | def main():
29 | """Main function for basic IK."""
30 |
31 | urdf = load_robot_description("panda_description")
32 | target_link_name = "panda_hand"
33 |
34 | # Create robot.
35 | robot = pk.Robot.from_urdf(urdf)
36 |
37 | # Set up visualizer.
38 | server = viser.ViserServer()
39 | server.scene.add_grid("/ground", width=2, height=2)
40 | urdf_vis = ViserUrdf(server, urdf, root_node_name="/base")
41 |
42 | # Create interactive controller with initial position.
43 | ik_target = server.scene.add_transform_controls(
44 | "/ik_target", scale=0.2, position=(0.61, 0.0, 0.56), wxyz=(0, 0, 1, 0)
45 | )
46 | timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True)
47 | value_handle = server.gui.add_number("Yoshikawa Index", 0.001, disabled=True)
48 | weight_handle = server.gui.add_slider(
49 | "Manipulability Weight", 0.0, 10.0, 0.001, 0.0
50 | )
51 | manip_ellipse = pk.viewer.ManipulabilityEllipse(
52 | server,
53 | robot,
54 | root_node_name="/manipulability",
55 | target_link_name=target_link_name,
56 | )
57 |
58 | while True:
59 | # Solve IK.
60 | start_time = time.time()
61 | solution = pks.solve_ik_with_manipulability(
62 | robot=robot,
63 | target_link_name=target_link_name,
64 | target_position=np.array(ik_target.position),
65 | target_wxyz=np.array(ik_target.wxyz),
66 | manipulability_weight=weight_handle.value,
67 | )
68 |
69 | manip_ellipse.update(solution)
70 | value_handle.value = manip_ellipse.manipulability
71 |
72 | # Update timing handle.
73 | elapsed_time = time.time() - start_time
74 | timing_handle.value = 0.99 * timing_handle.value + 0.01 * (elapsed_time * 1000)
75 |
76 | # Update visualizer.
77 | urdf_vis.update_cfg(solution)
78 |
79 |
80 | if __name__ == "__main__":
81 | main()
82 |
--------------------------------------------------------------------------------
/docs/source/examples/06_online_planning.rst:
--------------------------------------------------------------------------------
1 | .. Comment: this file is automatically generated by `update_example_docs.py`.
2 | It should not be modified manually.
3 |
4 | Online Planning
5 | ==========================================
6 |
7 |
8 | Run online planning in collision aware environments.
9 |
10 | All examples can be run by first cloning the PyRoki repository, which includes the ``pyroki_snippets`` implementation details.
11 |
12 |
13 |
14 | .. code-block:: python
15 | :linenos:
16 |
17 |
18 | import time
19 |
20 | import numpy as np
21 | import pyroki as pk
22 | import viser
23 | from pyroki.collision import HalfSpace, RobotCollision, Sphere
24 | from robot_descriptions.loaders.yourdfpy import load_robot_description
25 | from viser.extras import ViserUrdf
26 |
27 | import pyroki_snippets as pks
28 |
29 |
30 | def main():
31 | """Main function for online planning with collision."""
32 | urdf = load_robot_description("panda_description")
33 | target_link_name = "panda_hand"
34 | robot = pk.Robot.from_urdf(urdf)
35 |
36 | robot_coll = RobotCollision.from_urdf(urdf)
37 | plane_coll = HalfSpace.from_point_and_normal(
38 | np.array([0.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0])
39 | )
40 | sphere_coll = Sphere.from_center_and_radius(
41 | np.array([0.0, 0.0, 0.0]), np.array([0.05])
42 | )
43 |
44 | # Define the online planning parameters.
45 | len_traj, dt = 5, 0.1
46 |
47 | # Set up visualizer.
48 | server = viser.ViserServer()
49 | server.scene.add_grid("/ground", width=2, height=2, cell_size=0.1)
50 | urdf_vis = ViserUrdf(server, urdf, root_node_name="/robot")
51 |
52 | # Create interactive controller for IK target.
53 | ik_target_handle = server.scene.add_transform_controls(
54 | "/ik_target", scale=0.2, position=(0.3, 0.0, 0.5), wxyz=(0, 0, 1, 0)
55 | )
56 |
57 | # Create interactive controller and mesh for the sphere obstacle.
58 | sphere_handle = server.scene.add_transform_controls(
59 | "/obstacle", scale=0.2, position=(0.4, 0.3, 0.4)
60 | )
61 | server.scene.add_mesh_trimesh("/obstacle/mesh", mesh=sphere_coll.to_trimesh())
62 | target_frame_handle = server.scene.add_batched_axes(
63 | "target_frame",
64 | axes_length=0.05,
65 | axes_radius=0.005,
66 | batched_positions=np.zeros((25, 3)),
67 | batched_wxyzs=np.array([[1.0, 0.0, 0.0, 0.0]] * 25),
68 | )
69 |
70 | timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True)
71 |
72 | sol_pos, sol_wxyz = None, None
73 | sol_traj = np.array(
74 | robot.joint_var_cls.default_factory()[None].repeat(len_traj, axis=0)
75 | )
76 | while True:
77 | start_time = time.time()
78 |
79 | sphere_coll_world_current = sphere_coll.transform_from_wxyz_position(
80 | wxyz=np.array(sphere_handle.wxyz),
81 | position=np.array(sphere_handle.position),
82 | )
83 |
84 | world_coll_list = [plane_coll, sphere_coll_world_current]
85 | sol_traj, sol_pos, sol_wxyz = pks.solve_online_planning(
86 | robot=robot,
87 | robot_coll=robot_coll,
88 | world_coll=world_coll_list,
89 | target_link_name=target_link_name,
90 | target_position=np.array(ik_target_handle.position),
91 | target_wxyz=np.array(ik_target_handle.wxyz),
92 | timesteps=len_traj,
93 | dt=dt,
94 | start_cfg=sol_traj[0],
95 | prev_sols=sol_traj,
96 | )
97 |
98 | # Update timing handle.
99 | timing_handle.value = (
100 | 0.99 * timing_handle.value + 0.01 * (time.time() - start_time) * 1000
101 | )
102 |
103 | # Update visualizer.
104 | urdf_vis.update_cfg(
105 | sol_traj[0]
106 | ) # The first step of the online trajectory solution.
107 |
108 | # Update the planned trajectory visualization.
109 | if hasattr(target_frame_handle, "batched_positions"):
110 | target_frame_handle.batched_positions = np.array(sol_pos) # type: ignore[attr-defined]
111 | target_frame_handle.batched_wxyzs = np.array(sol_wxyz) # type: ignore[attr-defined]
112 | else:
113 | # This is an older version of Viser.
114 | target_frame_handle.positions_batched = np.array(sol_pos) # type: ignore[attr-defined]
115 | target_frame_handle.wxyzs_batched = np.array(sol_wxyz) # type: ignore[attr-defined]
116 |
117 |
118 | if __name__ == "__main__":
119 | main()
120 |
--------------------------------------------------------------------------------
/docs/source/examples/07_trajopt.rst:
--------------------------------------------------------------------------------
1 | .. Comment: this file is automatically generated by `update_example_docs.py`.
2 | It should not be modified manually.
3 |
4 | Trajectory Optimization
5 | ==========================================
6 |
7 |
8 | Basic Trajectory Optimization using PyRoKi.
9 |
10 | Robot going over a wall, while avoiding world-collisions.
11 |
12 | All examples can be run by first cloning the PyRoki repository, which includes the ``pyroki_snippets`` implementation details.
13 |
14 |
15 |
16 | .. code-block:: python
17 | :linenos:
18 |
19 |
20 | import time
21 | from typing import Literal
22 |
23 | import numpy as np
24 | import pyroki as pk
25 | import trimesh
26 | import tyro
27 | import viser
28 | from viser.extras import ViserUrdf
29 | from robot_descriptions.loaders.yourdfpy import load_robot_description
30 |
31 | import pyroki_snippets as pks
32 |
33 |
34 | def main(robot_name: Literal["ur5", "panda"] = "panda"):
35 | if robot_name == "ur5":
36 | urdf = load_robot_description("ur5_description")
37 | down_wxyz = np.array([0.707, 0, 0.707, 0])
38 | target_link_name = "ee_link"
39 |
40 | # For UR5 it's important to initialize the robot in a safe configuration;
41 | # the zero-configuration puts the robot aligned with the wall obstacle.
42 | default_cfg = np.zeros(6)
43 | default_cfg[1] = -1.308
44 | robot = pk.Robot.from_urdf(urdf, default_joint_cfg=default_cfg)
45 |
46 | elif robot_name == "panda":
47 | urdf = load_robot_description("panda_description")
48 | target_link_name = "panda_hand"
49 | down_wxyz = np.array([0, 0, 1, 0]) # for panda!
50 | robot = pk.Robot.from_urdf(urdf)
51 |
52 | else:
53 | raise ValueError(f"Invalid robot: {robot_name}")
54 |
55 | robot_coll = pk.collision.RobotCollision.from_urdf(urdf)
56 |
57 | # Define the trajectory problem:
58 | # - number of timesteps, timestep size
59 | timesteps, dt = 25, 0.02
60 | # - the start and end poses.
61 | start_pos, end_pos = np.array([0.5, -0.3, 0.2]), np.array([0.5, 0.3, 0.2])
62 |
63 | # Define the obstacles:
64 | # - Ground
65 | ground_coll = pk.collision.HalfSpace.from_point_and_normal(
66 | np.array([0.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0])
67 | )
68 | # - Wall
69 | wall_height = 0.4
70 | wall_width = 0.1
71 | wall_length = 0.4
72 | wall_intervals = np.arange(start=0.3, stop=wall_length + 0.3, step=0.05)
73 | translation = np.concatenate(
74 | [
75 | wall_intervals.reshape(-1, 1),
76 | np.full((wall_intervals.shape[0], 1), 0.0),
77 | np.full((wall_intervals.shape[0], 1), wall_height / 2),
78 | ],
79 | axis=1,
80 | )
81 | wall_coll = pk.collision.Capsule.from_radius_height(
82 | position=translation,
83 | radius=np.full((translation.shape[0], 1), wall_width / 2),
84 | height=np.full((translation.shape[0], 1), wall_height),
85 | )
86 | world_coll = [ground_coll, wall_coll]
87 |
88 | traj = pks.solve_trajopt(
89 | robot,
90 | robot_coll,
91 | world_coll,
92 | target_link_name,
93 | start_pos,
94 | down_wxyz,
95 | end_pos,
96 | down_wxyz,
97 | timesteps,
98 | dt,
99 | )
100 | traj = np.array(traj)
101 |
102 | # Visualize!
103 | server = viser.ViserServer()
104 | urdf_vis = ViserUrdf(server, urdf)
105 | server.scene.add_grid("/grid", width=2, height=2, cell_size=0.1)
106 | server.scene.add_mesh_trimesh(
107 | "wall_box",
108 | trimesh.creation.box(
109 | extents=(wall_length, wall_width, wall_height),
110 | transform=trimesh.transformations.translation_matrix(
111 | np.array([0.5, 0.0, wall_height / 2])
112 | ),
113 | ),
114 | )
115 | for name, pos in zip(["start", "end"], [start_pos, end_pos]):
116 | server.scene.add_frame(
117 | f"/{name}",
118 | position=pos,
119 | wxyz=down_wxyz,
120 | axes_length=0.05,
121 | axes_radius=0.01,
122 | )
123 |
124 | slider = server.gui.add_slider(
125 | "Timestep", min=0, max=timesteps - 1, step=1, initial_value=0
126 | )
127 | playing = server.gui.add_checkbox("Playing", initial_value=True)
128 |
129 | while True:
130 | if playing.value:
131 | slider.value = (slider.value + 1) % timesteps
132 |
133 | urdf_vis.update_cfg(traj[slider.value])
134 | time.sleep(1.0 / 10.0)
135 |
136 |
137 | if __name__ == "__main__":
138 | tyro.cli(main)
139 |
--------------------------------------------------------------------------------
/docs/source/examples/08_ik_with_mimic_joints.rst:
--------------------------------------------------------------------------------
1 | .. Comment: this file is automatically generated by `update_example_docs.py`.
2 | It should not be modified manually.
3 |
4 | IK with mimic joints
5 | ==========================================
6 |
7 |
8 | This is a simple test to ensure that mimic joints are handled correctly in the IK solver.
9 |
10 | We procedurally generate a "zig-zag" chain of links with mimic joints, where:
11 |
12 |
13 | * the first joint is driven directly,
14 | * and the remaining joints are driven indirectly via mimic joints.
15 | The multipliers alternate between -1 and 1, and the offsets are all 0.
16 |
17 | All examples can be run by first cloning the PyRoki repository, which includes the ``pyroki_snippets`` implementation details.
18 |
19 |
20 |
21 | .. code-block:: python
22 | :linenos:
23 |
24 |
25 | import time
26 |
27 | import numpy as np
28 | import pyroki as pk
29 | import viser
30 | from viser.extras import ViserUrdf
31 |
32 | import pyroki_snippets as pks
33 |
34 |
35 | def create_chain_xml(length: float = 0.2, num_chains: int = 5) -> str:
36 | def create_link(idx):
37 | return f"""
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 | """
50 |
51 | def create_joint(idx, multiplier=1.0, offset=0.0):
52 | mimic = f''
53 | return f"""
54 |
55 |
56 |
57 |
58 |
59 | {mimic if idx != 0 else ""}
60 |
61 |
62 | """
63 |
64 | world_joint_origin_z = length / 2.0
65 | xml = f"""
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 | """
77 | # Create the definition + first link.
78 | xml += create_link(0)
79 | xml += create_link(1)
80 | xml += create_joint(0)
81 |
82 | # Procedurally add more links.
83 | assert num_chains >= 2
84 | for idx in range(2, num_chains):
85 | xml += create_link(idx)
86 | current_offset = 0.0
87 | current_multiplier = 1.0 * ((-1) ** (idx % 2))
88 | xml += create_joint(idx - 1, current_multiplier, current_offset)
89 |
90 | xml += """
91 |
92 | """
93 | return xml
94 |
95 |
96 | def main():
97 | """Main function for basic IK."""
98 |
99 | import yourdfpy
100 | import tempfile
101 |
102 | xml = create_chain_xml(num_chains=10, length=0.1)
103 | with tempfile.NamedTemporaryFile(mode="w", suffix=".urdf") as f:
104 | f.write(xml)
105 | f.flush()
106 | urdf = yourdfpy.URDF.load(f.name)
107 |
108 | # Create robot.
109 | robot = pk.Robot.from_urdf(urdf)
110 |
111 | # Set up visualizer.
112 | server = viser.ViserServer()
113 | server.scene.add_grid("/ground", width=2, height=2)
114 | urdf_vis = ViserUrdf(server, urdf, root_node_name="/base")
115 | target_link_name_handle = server.gui.add_dropdown(
116 | "Target Link",
117 | robot.links.names,
118 | initial_value=robot.links.names[-1],
119 | )
120 |
121 | # Create interactive controller with initial position.
122 | ik_target = server.scene.add_transform_controls(
123 | "/ik_target", scale=0.2, position=(0.0, 0.1, 0.1), wxyz=(0, 0, 1, 0)
124 | )
125 | timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True)
126 |
127 | while True:
128 | # Solve IK.
129 | start_time = time.time()
130 | solution = pks.solve_ik(
131 | robot=robot,
132 | target_link_name=target_link_name_handle.value,
133 | target_position=np.array(ik_target.position),
134 | target_wxyz=np.array(ik_target.wxyz),
135 | )
136 |
137 | # Update timing handle.
138 | elapsed_time = time.time() - start_time
139 | timing_handle.value = 0.99 * timing_handle.value + 0.01 * (elapsed_time * 1000)
140 |
141 | # Update visualizer.
142 | urdf_vis.update_cfg(solution)
143 |
144 |
145 | if __name__ == "__main__":
146 | main()
147 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | PyRoki
2 | ==========
3 |
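4 | `Project page <https://pyroki-toolkit.github.io/>`_ `•` `arXiv <https://arxiv.org/abs/2505.03728>`_ `•` `Code <https://github.com/chungmin99/pyroki>`_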
5 |
6 | **PyRoki** is a library for robot kinematic optimization (Python Robot Kinematics).
7 |
8 | 1. **Modular**: Optimization variables and cost functions are decoupled, enabling reusable components across tasks. Objectives like collision avoidance and pose matching can be applied to both IK and trajectory optimization without reimplementation.
9 |
10 | 2. **Extensible**: ``PyRoki`` supports automatic differentiation for user-defined costs with Jacobian computation, a real-time cost-weight tuning interface, and optional analytical Jacobians for performance-critical use cases.
11 |
12 | 3. **Cross-Platform**: ``PyRoki`` runs on CPU, GPU, and TPU, allowing efficient scaling from single-robot use cases to large-scale parallel processing for motion datasets or planning.
13 |
14 | We demonstrate how ``PyRoki`` solves IK, trajectory optimization, and motion retargeting for robot hands and humanoids in a unified framework. It uses a Levenberg-Marquardt optimizer to efficiently solve these tasks, and we evaluate its performance on batched IK.
15 |
16 | Features include:
17 |
18 | - Differentiable robot forward kinematics model from a URDF.
19 | - Automatic generation of robot collision primitives (e.g., capsules).
20 | - Differentiable collision bodies with numpy broadcasting logic.
21 | - Common cost factors (e.g., end effector pose, self/world-collision, manipulability).
22 | - Arbitrary costs, with Jacobians either computed :doc:`through autodiff or defined manually <misc/writing_manual_jac>`.
23 | - Integration with a `Levenberg-Marquardt solver <https://github.com/brentyi/jaxls>`_ that supports optimization on manifolds (e.g., `Lie groups <https://github.com/brentyi/jaxlie>`_).
24 | - Cross-platform support (CPU, GPU, TPU) via JAX.
25 |
26 |
27 |
28 | Installation
29 | ------------
30 |
31 | You can install ``pyroki`` with ``pip``, on Python 3.12+:
32 |
33 | .. code-block:: bash
34 |
35 | git clone https://github.com/chungmin99/pyroki.git
36 | cd pyroki
37 | pip install -e .
38 |
39 |
40 | Python 3.10-3.11 should also work, but support may be dropped in the future.
41 |
42 | Limitations
43 | -----------
44 |
45 | - **Soft constraints only**: We use a nonlinear least-squares formulation and model joint limits, collision avoidance, etc. as soft penalties with high weights rather than hard constraints.
46 | - **Static shapes & JIT overhead**: JAX JIT compilation is triggered on first run and when input shapes change (e.g., number of targets, obstacles). Arrays can be pre-padded to vectorize over inputs with different shapes.
47 | - **No sampling-based planners**: We don't include sampling-based planners (e.g., graphs, trees).
48 | - **Collision performance**: Speed and accuracy comparisons against other robot toolkits such as cuRobo have not been extensively performed, and ``PyRoki`` is likely slower than other toolkits for collision-heavy scenarios.
49 |
50 | The following are current implementation limitations that could potentially be addressed in future versions:
51 |
52 | - **Joint types**: We only support revolute, continuous, prismatic, and fixed joints. Other URDF joint types are treated as fixed joints.
53 | - **Collision geometry**: We are limited to sphere, capsule, halfspace, and heightmap geometries. Mesh collision is approximated as capsules.
54 | - **Kinematic structures**: We only support kinematic chains; no closed-loop mechanisms or parallel manipulators.
55 |
56 | Examples
57 | --------
58 |
59 | .. toctree::
60 | :maxdepth: 1
61 | :caption: Examples
62 |
63 | examples/01_basic_ik
64 | examples/02_bimanual_ik
65 | examples/03_mobile_ik
66 | examples/04_ik_with_coll
67 | examples/05_ik_with_manipulability
68 | examples/06_online_planning
69 | examples/07_trajopt
70 | examples/08_ik_with_mimic_joints
71 | examples/09_hand_retargeting
72 | examples/10_humanoid_retargeting
73 |
74 |
75 | Acknowledgements
76 | ----------------
77 | ``PyRoki`` is heavily inspired by prior work, including but not limited to
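78 | `Trac-IK <https://traclabs.com/projects/trac-ik/>`_,
79 | `cuRobo <https://curobo.org/>`_,
80 | `pink <https://github.com/stephane-caron/pink>`_,
81 | `mink <https://github.com/kevinzakka/mink>`_,
82 | `Drake <https://drake.mit.edu/>`_, and
83 | `Dex-Retargeting <https://github.com/dexsuite/dex-retargeting>`_.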
84 | Thank you so much for your great work!
85 |
86 |
87 | Citation
88 | --------
89 |
90 | If you find this work useful, please cite it as follows:
91 |
92 | .. code-block:: bibtex
93 |
94 | @misc{pyroki2025,
95 | title={PyRoki: A Modular Toolkit for Robot Kinematic Optimization},
96 | author={Chung Min Kim* and Brent Yi* and Hongsuk Choi and Yi Ma and Ken Goldberg and Angjoo Kanazawa},
97 | year={2025},
98 | eprint={2505.03728},
99 | archivePrefix={arXiv},
100 | primaryClass={cs.RO},
101 | url={https://arxiv.org/abs/2505.03728},
102 | }
103 |
104 | Thanks for using ``PyRoki``!
105 |
--------------------------------------------------------------------------------
/docs/source/misc/writing_manual_jac.rst:
--------------------------------------------------------------------------------
1 | :orphan:
2 |
3 | Defining Jacobians Manually
4 | =====================================
5 |
6 | ``pyroki`` supports both autodiff and manually defined Jacobians for computing cost gradients.
7 |
8 | For reference, this is the robot pose matching cost :math:`C_\text{pose}`:
9 |
10 | .. math::
11 |
12 | \sum_{i} \left( w_{p,i} \left\| \mathbf{p}_{i}(q) - \mathbf{p}_{i}^* \right\|^2 + w_{R,i} \left\| \text{log}(\mathbf{R}_{i}(q)^{-1} \mathbf{R}_{i}^*) \right\|^2 \right)
13 |
14 |
15 | where :math:`q` is the robot joint configuration, :math:`\mathbf{p}_{i}(q)` is the position of the :math:`i`-th link, :math:`\mathbf{R}_{i}(q)` is the rotation matrix of the :math:`i`-th link, and :math:`w_{p,i}` and :math:`w_{R,i}` are the position and orientation weights, respectively.
16 |
17 | The following is the most common way to define costs in ``pyroki`` -- with autodiff:
18 |
19 | .. code-block:: python
20 |
21 | @Cost.create_factory
22 | def pose_cost(
23 | vals: VarValues,
24 | robot: Robot,
25 | joint_var: Var[Array],
26 | target_pose: jaxlie.SE3,
27 | target_link_index: Array,
28 | pos_weight: Array | float,
29 | ori_weight: Array | float,
30 | ) -> Array:
31 | """Computes the residual for matching link poses to target poses."""
32 | assert target_link_index.dtype == jnp.int32
33 | joint_cfg = vals[joint_var]
34 | Ts_link_world = robot.forward_kinematics(joint_cfg)
35 | pose_actual = jaxlie.SE3(Ts_link_world[..., target_link_index, :])
36 |
37 | # Position residual = position error * weight
38 | pos_residual = (pose_actual.translation() - target_pose.translation()) * pos_weight
39 | # Orientation residual = log(actual_inv * target) * weight
40 | ori_residual = (pose_actual.rotation().inverse() @ target_pose.rotation()).log() * ori_weight
41 |
42 | return jnp.concatenate([pos_residual, ori_residual]).flatten()
43 |
44 | The alternative is to manually write out the Jacobian -- while automatic differentiation is convenient and works well for most use cases, analytical Jacobians can provide better performance, as we show in the `paper <https://arxiv.org/abs/2505.03728>`_.
45 |
46 | We provide two implementations of pose matching cost with custom Jacobians:
47 |
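48 | - an `analytically derived Jacobian <https://github.com/chungmin99/pyroki/blob/main/src/pyroki/costs/_pose_cost_analytic_jac.py>`_ (~200 lines), or
49 | - a `numerically approximated Jacobian <https://github.com/chungmin99/pyroki/blob/main/src/pyroki/costs/_pose_cost_numerical_jac.py>`_ through finite differences (~50 lines).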
50 |
--------------------------------------------------------------------------------
/docs/update_example_docs.py:
--------------------------------------------------------------------------------
1 | """Helper script for updating the auto-generated examples pages in the documentation."""
2 |
3 | from __future__ import annotations
4 |
5 | import dataclasses
6 | import pathlib
7 | import shutil
8 | from typing import Iterable
9 |
10 | import m2r2
11 | import tyro
12 |
13 |
@dataclasses.dataclass
class ExampleMetadata:
    """Metadata for one example script, used to render its documentation page."""

    # Example number without leading zeros, e.g. "1" for ``01_basic_ik.py``.
    index: str
    # Example number exactly as written in the filename, e.g. "01".
    index_with_zero: str
    # Script source with the leading module docstring removed.
    source: str
    # First line of the module docstring; used as the page title.
    title: str
    # Remaining docstring text followed by a shared note on running the examples.
    description: str

    @staticmethod
    def from_path(path: pathlib.Path) -> ExampleMetadata:
        """Parse an example script named like ``NN_name.py``.

        The script must begin with a triple-quoted module docstring: its first
        line becomes the title and the remaining lines the description.

        Args:
            path: Path to the example script.

        Returns:
            Metadata parsed from the filename and the script contents.

        Raises:
            ValueError: If the filename does not start with an integer prefix,
                or the script has no complete triple-quoted docstring.
        """
        # 01_functions -> 01, _, functions.
        index_with_zero, _, _ = path.stem.partition("_")

        # 01 -> 1.
        try:
            index = str(int(index_with_zero))
        except ValueError as e:
            raise ValueError(
                f"Example filename must start with an integer index: {path}"
            ) from e

        print("Parsing", path)
        # Decode explicitly as UTF-8 instead of relying on the platform locale.
        source = path.read_text(encoding="utf-8").strip()

        # Split into: <before> """ <docstring> """ <body>.
        _, found_open, after_open = source.partition('"""')
        docstring, found_close, body = after_open.partition('"""')
        if not (found_open and found_close):
            raise ValueError(f"Example has no module docstring: {path}")

        title, _, description = docstring.strip().partition("\n")

        description = description.strip()
        description += "\n"
        description += "\n"
        description += "All examples can be run by first cloning the PyRoki repository, which includes the `pyroki_snippets` implementation details."

        return ExampleMetadata(
            index=index,
            index_with_zero=index_with_zero,
            source=body.strip(),
            title=title.strip(),
            description=description,
        )
49 |
50 |
def get_example_paths(examples_dir: pathlib.Path) -> Iterable[pathlib.Path]:
    """Yield the example scripts directly inside ``examples_dir``, sorted by name.

    Modules whose filename starts with ``_`` are treated as private helpers
    and skipped.
    """
    candidates = sorted(examples_dir.glob("*.py"))
    return (script for script in candidates if not script.name.startswith("_"))
55 |
56 |
# Two directory levels above this script (it lives at <repo>/docs/).
REPO_ROOT = pathlib.Path(__file__).absolute().parents[1]
58 |
59 |
def main(
    examples_dir: pathlib.Path = REPO_ROOT / "examples",
    sphinx_source_dir: pathlib.Path = REPO_ROOT / "docs" / "source",
) -> None:
    """Regenerate the Sphinx page for every example script.

    The ``examples`` subdirectory of ``sphinx_source_dir`` is deleted and
    recreated, then one ``<stem>.rst`` page is written per script yielded by
    ``get_example_paths(examples_dir)``.

    Args:
        examples_dir: Directory holding the ``NN_name.py`` example scripts.
        sphinx_source_dir: Sphinx source directory; pages are written into its
            ``examples`` subdirectory.
    """
    example_doc_dir = sphinx_source_dir / "examples"
    # Rebuild from scratch so pages for removed examples don't linger, but
    # don't crash when the output directory doesn't exist yet.
    if example_doc_dir.exists():
        shutil.rmtree(example_doc_dir)
    example_doc_dir.mkdir(parents=True)

    for path in get_example_paths(examples_dir):
        ex = ExampleMetadata.from_path(path)

        relative_dir = path.parent.relative_to(examples_dir)
        target_dir = example_doc_dir / relative_dir
        target_dir.mkdir(exist_ok=True, parents=True)

        # reST requires a section underline at least as long as its title.
        underline = "=========================================="
        if len(ex.title) > len(underline):
            underline = "=" * len(ex.title)

        page = "\n".join(
            [
                (
                    ".. Comment: this file is automatically generated by"
                    " `update_example_docs.py`."
                ),
                " It should not be modified manually.",
                "",
                ex.title,
                underline,
                "",
                m2r2.convert(ex.description),
                "",
                "",
                ".. code-block:: python",
                " :linenos:",
                "",
                "",
                "\n".join(f" {line}".rstrip() for line in ex.source.split("\n")),
                "",
            ]
        )
        # Write UTF-8 explicitly so the output doesn't depend on the locale.
        (target_dir / f"{path.stem}.rst").write_text(page, encoding="utf-8")
101 |
102 |
if __name__ == "__main__":
    # Build the command line for `main` with tyro, using this module's
    # docstring as the CLI description.
    tyro.cli(main, description=__doc__)
105 |
--------------------------------------------------------------------------------
/examples/01_basic_ik.py:
--------------------------------------------------------------------------------
1 | """Basic IK
2 |
3 | Simplest Inverse Kinematics Example using PyRoki.
4 | """
5 |
6 | import time
7 |
8 | import numpy as np
9 | import pyroki as pk
10 | import viser
11 | from robot_descriptions.loaders.yourdfpy import load_robot_description
12 | from viser.extras import ViserUrdf
13 |
14 | import pyroki_snippets as pks
15 |
16 |
def main():
    """Run an interactive single-target IK demo (``panda_description``).

    A viser server shows the robot and a draggable transform gizmo. Each loop
    iteration re-solves IK so that ``panda_hand`` tracks the gizmo, and an
    exponentially smoothed solve time in milliseconds is shown in the GUI.
    The loop never exits.
    """
    urdf = load_robot_description("panda_description")
    ee_link_name = "panda_hand"
    robot = pk.Robot.from_urdf(urdf)

    # Scene: a ground grid plus the robot mounted under "/base".
    server = viser.ViserServer()
    server.scene.add_grid("/ground", width=2, height=2)
    robot_vis = ViserUrdf(server, urdf, root_node_name="/base")

    # Draggable IK target with a fixed initial pose.
    target_gizmo = server.scene.add_transform_controls(
        "/ik_target",
        scale=0.2,
        position=(0.61, 0.0, 0.56),
        wxyz=(0, 0, 1, 0),
    )
    solve_ms_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True)

    while True:
        tic = time.time()
        joint_cfg = pks.solve_ik(
            robot=robot,
            target_link_name=ee_link_name,
            target_position=np.array(target_gizmo.position),
            target_wxyz=np.array(target_gizmo.wxyz),
        )
        solve_ms = (time.time() - tic) * 1000

        # Exponential moving average keeps the readout from flickering.
        solve_ms_handle.value = 0.99 * solve_ms_handle.value + 0.01 * solve_ms

        robot_vis.update_cfg(joint_cfg)
53 |
54 |
if __name__ == "__main__":
    # Start the demo only when executed as a script, not when imported.
    main()
57 |
--------------------------------------------------------------------------------
/examples/02_bimanual_ik.py:
--------------------------------------------------------------------------------
1 | """Bimanual IK
2 |
3 | Same as 01_basic_ik.py, but with two end effectors!
4 | """
5 |
6 | import time
7 | import viser
8 | from robot_descriptions.loaders.yourdfpy import load_robot_description
9 | import numpy as np
10 |
11 | import pyroki as pk
12 | from viser.extras import ViserUrdf
13 | import pyroki_snippets as pks
14 |
15 |
def main():
    """Run an interactive two-target IK demo (``yumi_description``).

    Two draggable transform gizmos are shown; each loop iteration re-solves IK
    for both ``yumi_link_7_r`` and ``yumi_link_7_l`` at once, and an
    exponentially smoothed solve time in milliseconds is shown in the GUI.
    The loop never exits.
    """
    urdf = load_robot_description("yumi_description")
    # Presumably paired by index with the stacked targets below (gizmo 0 ->
    # first link, gizmo 1 -> second link); keep the two orders in sync.
    ee_link_names = ["yumi_link_7_r", "yumi_link_7_l"]
    robot = pk.Robot.from_urdf(urdf)

    # Scene: a ground grid plus the robot mounted under "/base".
    server = viser.ViserServer()
    server.scene.add_grid("/ground", width=2, height=2)
    robot_vis = ViserUrdf(server, urdf, root_node_name="/base")

    # One draggable IK target per end effector, with fixed initial poses.
    initial_wxyz = (0, 0, 1, 0)
    gizmo_r = server.scene.add_transform_controls(
        "/ik_target_0", scale=0.2, position=(0.41, -0.3, 0.56), wxyz=initial_wxyz
    )
    gizmo_l = server.scene.add_transform_controls(
        "/ik_target_1", scale=0.2, position=(0.41, 0.3, 0.56), wxyz=initial_wxyz
    )
    solve_ms_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True)

    gizmos = (gizmo_r, gizmo_l)
    while True:
        tic = time.time()
        joint_cfg = pks.solve_ik_with_multiple_targets(
            robot=robot,
            target_link_names=ee_link_names,
            target_positions=np.array([g.position for g in gizmos]),
            target_wxyzs=np.array([g.wxyz for g in gizmos]),
        )
        solve_ms = (time.time() - tic) * 1000

        # Exponential moving average keeps the readout from flickering.
        solve_ms_handle.value = 0.99 * solve_ms_handle.value + 0.01 * solve_ms

        robot_vis.update_cfg(joint_cfg)
55 |
56 |
57 | if __name__ == "__main__":
58 | main()
59 |
--------------------------------------------------------------------------------
/examples/03_mobile_ik.py:
--------------------------------------------------------------------------------
1 | """Mobile IK
2 |
3 | Same as 01_basic_ik.py, but with a mobile base!
4 | """
5 |
6 | import time
7 | import viser
8 | from robot_descriptions.loaders.yourdfpy import load_robot_description
9 | import numpy as np
10 |
11 | import pyroki as pk
12 | from viser.extras import ViserUrdf
13 | import pyroki_snippets as pks
14 |
15 |
def main():
    """Main function for IK with a mobile base.

    The base may translate within the xy plane (its height is locked) and may
    rotate about the z axis only (roll and pitch are locked). It is biased
    towards staying at the origin.
    """

    urdf = load_robot_description("fetch_description")
    target_link_name = "gripper_link"

    # Create robot.
    robot = pk.Robot.from_urdf(urdf)

    # Set up visualizer.
    server = viser.ViserServer()
    server.scene.add_grid("/ground", width=2, height=2)
    # The URDF is rooted under this frame, so moving the frame moves the base.
    base_frame = server.scene.add_frame("/base", show_axes=False)
    urdf_vis = ViserUrdf(server, urdf, root_node_name="/base")

    # Create interactive controller with initial position.
    ik_target = server.scene.add_transform_controls(
        "/ik_target", scale=0.2, position=(0.61, 0.0, 0.56), wxyz=(0, 0.707, 0, -0.707)
    )
    timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True)

    # Joint configuration carried across iterations as the warm start.
    cfg = np.array(robot.joint_var_cls(0).default_factory())

    while True:
        # Solve IK.
        start_time = time.time()
        base_pos, base_wxyz, cfg = pks.solve_ik_with_base(
            robot=robot,
            target_link_name=target_link_name,
            target_position=np.array(ik_target.position),
            target_wxyz=np.array(ik_target.wxyz),
            fix_base_position=(False, False, True),  # x, y free; z locked.
            fix_base_orientation=(True, True, False),  # Only rotation about z free.
            prev_pos=base_frame.position,
            prev_wxyz=base_frame.wxyz,
            prev_cfg=cfg,
        )

        # Update timing handle (exponential moving average).
        elapsed_time = time.time() - start_time
        timing_handle.value = 0.99 * timing_handle.value + 0.01 * (elapsed_time * 1000)

        # Update visualizer.
        urdf_vis.update_cfg(cfg)
        base_frame.position = np.array(base_pos)
        base_frame.wxyz = np.array(base_wxyz)


if __name__ == "__main__":
    main()
68 |
--------------------------------------------------------------------------------
/examples/04_ik_with_coll.py:
--------------------------------------------------------------------------------
1 | """IK with Collision
2 |
3 | Basic Inverse Kinematics with Collision Avoidance using PyRoKi.
4 | """
5 |
6 | import time
7 |
8 | import numpy as np
9 | import pyroki as pk
10 | import viser
11 | from pyroki.collision import HalfSpace, RobotCollision, Sphere
12 | from robot_descriptions.loaders.yourdfpy import load_robot_description
13 | from viser.extras import ViserUrdf
14 |
15 | import pyroki_snippets as pks
16 |
17 |
def main():
    """Interactive IK for the Panda hand that avoids the ground and a draggable sphere."""
    urdf = load_robot_description("panda_description")
    target_link_name = "panda_hand"
    robot = pk.Robot.from_urdf(urdf)

    robot_coll = RobotCollision.from_urdf(urdf)
    plane_coll = HalfSpace.from_point_and_normal(
        np.array([0.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0])
    )
    # Defined at the origin; re-posed every frame from its transform gizmo.
    sphere_coll = Sphere.from_center_and_radius(
        np.array([0.0, 0.0, 0.0]), np.array([0.05])
    )

    # Set up visualizer.
    server = viser.ViserServer()
    server.scene.add_grid("/ground", width=2, height=2, cell_size=0.1)
    urdf_vis = ViserUrdf(server, urdf, root_node_name="/robot")

    # Create interactive controller for IK target.
    ik_target_handle = server.scene.add_transform_controls(
        "/ik_target", scale=0.2, position=(0.5, 0.0, 0.5), wxyz=(0, 0, 1, 0)
    )

    # Create interactive controller and mesh for the sphere obstacle.
    sphere_handle = server.scene.add_transform_controls(
        "/obstacle", scale=0.2, position=(0.4, 0.3, 0.4)
    )
    server.scene.add_mesh_trimesh("/obstacle/mesh", mesh=sphere_coll.to_trimesh())

    timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True)

    while True:
        start_time = time.time()

        sphere_coll_world_current = sphere_coll.transform_from_wxyz_position(
            wxyz=np.array(sphere_handle.wxyz),
            position=np.array(sphere_handle.position),
        )

        world_coll_list = [plane_coll, sphere_coll_world_current]
        solution = pks.solve_ik_with_collision(
            robot=robot,
            coll=robot_coll,
            world_coll_list=world_coll_list,
            target_link_name=target_link_name,
            target_position=np.array(ik_target_handle.position),
            target_wxyz=np.array(ik_target_handle.wxyz),
        )

        # Update timing handle with the same exponential moving average as
        # the other examples, so the readout doesn't jitter every frame.
        elapsed_time = time.time() - start_time
        timing_handle.value = 0.99 * timing_handle.value + 0.01 * (elapsed_time * 1000)

        # Update visualizer.
        urdf_vis.update_cfg(solution)


if __name__ == "__main__":
    main()
77 |
--------------------------------------------------------------------------------
/examples/05_ik_with_manipulability.py:
--------------------------------------------------------------------------------
1 | """IK with Manipulability
2 |
3 | Inverse Kinematics with Manipulability using PyRoKi.
4 | """
5 |
6 | import time
7 | import viser
8 | from robot_descriptions.loaders.yourdfpy import load_robot_description
9 | import numpy as np
10 |
11 | import pyroki as pk
12 | from viser.extras import ViserUrdf
13 | import pyroki_snippets as pks
14 |
15 |
def main():
    """Interactive IK for the Panda hand with an adjustable manipulability term.

    The "Manipulability Weight" slider sets the weight passed to the IK
    snippet. The resulting Yoshikawa index and manipulability ellipse of the
    target link are displayed live.
    """

    urdf = load_robot_description("panda_description")
    target_link_name = "panda_hand"

    # Create robot.
    robot = pk.Robot.from_urdf(urdf)

    # Set up visualizer.
    server = viser.ViserServer()
    server.scene.add_grid("/ground", width=2, height=2)
    urdf_vis = ViserUrdf(server, urdf, root_node_name="/base")

    # Create interactive controller with initial position.
    ik_target = server.scene.add_transform_controls(
        "/ik_target", scale=0.2, position=(0.61, 0.0, 0.56), wxyz=(0, 0, 1, 0)
    )
    timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True)
    value_handle = server.gui.add_number("Yoshikawa Index", 0.001, disabled=True)
    # Positional args: min=0.0, max=10.0, step=0.001, initial=0.0.
    weight_handle = server.gui.add_slider(
        "Manipulability Weight", 0.0, 10.0, 0.001, 0.0
    )
    manip_ellipse = pk.viewer.ManipulabilityEllipse(
        server,
        robot,
        root_node_name="/manipulability",
        target_link_name=target_link_name,
    )

    while True:
        # Solve IK.
        start_time = time.time()
        solution = pks.solve_ik_with_manipulability(
            robot=robot,
            target_link_name=target_link_name,
            target_position=np.array(ik_target.position),
            target_wxyz=np.array(ik_target.wxyz),
            manipulability_weight=weight_handle.value,
        )

        manip_ellipse.update(solution)
        value_handle.value = manip_ellipse.manipulability

        # Update timing handle (exponential moving average). Note the
        # measured span includes the ellipse update above, not just IK.
        elapsed_time = time.time() - start_time
        timing_handle.value = 0.99 * timing_handle.value + 0.01 * (elapsed_time * 1000)

        # Update visualizer.
        urdf_vis.update_cfg(solution)


if __name__ == "__main__":
    main()
70 |
--------------------------------------------------------------------------------
/examples/06_online_planning.py:
--------------------------------------------------------------------------------
1 | """Online Planning
2 |
3 | Run online planning in collision aware environments.
4 | """
5 |
6 | import time
7 |
8 | import numpy as np
9 | import pyroki as pk
10 | import viser
11 | from pyroki.collision import HalfSpace, RobotCollision, Sphere
12 | from robot_descriptions.loaders.yourdfpy import load_robot_description
13 | from viser.extras import ViserUrdf
14 |
15 | import pyroki_snippets as pks
16 |
17 |
def main():
    """Run collision-aware online planning toward an interactive target.

    Each loop re-plans a short horizon of ``len_traj`` waypoints, executes the
    first one on the visualized robot, and draws the planned target-link
    poses as a row of axes markers.
    """
    urdf = load_robot_description("panda_description")
    target_link_name = "panda_hand"
    robot = pk.Robot.from_urdf(urdf)

    robot_coll = RobotCollision.from_urdf(urdf)
    plane_coll = HalfSpace.from_point_and_normal(
        np.array([0.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0])
    )
    # Defined at the origin; re-posed every frame from its transform gizmo.
    sphere_coll = Sphere.from_center_and_radius(
        np.array([0.0, 0.0, 0.0]), np.array([0.05])
    )

    # Define the online planning parameters.
    len_traj, dt = 5, 0.1

    # Set up visualizer.
    server = viser.ViserServer()
    server.scene.add_grid("/ground", width=2, height=2, cell_size=0.1)
    urdf_vis = ViserUrdf(server, urdf, root_node_name="/robot")

    # Create interactive controller for IK target.
    ik_target_handle = server.scene.add_transform_controls(
        "/ik_target", scale=0.2, position=(0.3, 0.0, 0.5), wxyz=(0, 0, 1, 0)
    )

    # Create interactive controller and mesh for the sphere obstacle.
    sphere_handle = server.scene.add_transform_controls(
        "/obstacle", scale=0.2, position=(0.4, 0.3, 0.4)
    )
    server.scene.add_mesh_trimesh("/obstacle/mesh", mesh=sphere_coll.to_trimesh())
    # One axes marker per planned waypoint, sized to the planning horizon
    # (previously a hard-coded 25 placeholders for a 5-step plan).
    target_frame_handle = server.scene.add_batched_axes(
        "target_frame",
        axes_length=0.05,
        axes_radius=0.005,
        batched_positions=np.zeros((len_traj, 3)),
        batched_wxyzs=np.array([[1.0, 0.0, 0.0, 0.0]] * len_traj),
    )
    # Newer viser names these properties `batched_*`; older versions use
    # `*_batched`. The attribute set does not change at runtime, so check once.
    uses_new_batched_api = hasattr(target_frame_handle, "batched_positions")

    timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True)

    # Initial warm start: the default configuration repeated over the horizon.
    sol_traj = np.array(
        robot.joint_var_cls.default_factory()[None].repeat(len_traj, axis=0)
    )
    while True:
        start_time = time.time()

        sphere_coll_world_current = sphere_coll.transform_from_wxyz_position(
            wxyz=np.array(sphere_handle.wxyz),
            position=np.array(sphere_handle.position),
        )

        world_coll_list = [plane_coll, sphere_coll_world_current]
        sol_traj, sol_pos, sol_wxyz = pks.solve_online_planning(
            robot=robot,
            robot_coll=robot_coll,
            world_coll=world_coll_list,
            target_link_name=target_link_name,
            target_position=np.array(ik_target_handle.position),
            target_wxyz=np.array(ik_target_handle.wxyz),
            timesteps=len_traj,
            dt=dt,
            start_cfg=sol_traj[0],
            prev_sols=sol_traj,
        )

        # Update timing handle (exponential moving average).
        timing_handle.value = (
            0.99 * timing_handle.value + 0.01 * (time.time() - start_time) * 1000
        )

        # Execute the first step of the online trajectory solution.
        urdf_vis.update_cfg(sol_traj[0])

        # Flatten per-(timestep, target) poses into one marker per row.
        positions = np.array(sol_pos).reshape(-1, 3)
        wxyzs = np.array(sol_wxyz).reshape(-1, 4)
        if uses_new_batched_api:
            target_frame_handle.batched_positions = positions  # type: ignore[attr-defined]
            target_frame_handle.batched_wxyzs = wxyzs  # type: ignore[attr-defined]
        else:
            # This is an older version of Viser.
            target_frame_handle.positions_batched = positions  # type: ignore[attr-defined]
            target_frame_handle.wxyzs_batched = wxyzs  # type: ignore[attr-defined]


if __name__ == "__main__":
    main()
108 |
--------------------------------------------------------------------------------
/examples/07_trajopt.py:
--------------------------------------------------------------------------------
1 | """Trajectory Optimization
2 |
3 | Basic Trajectory Optimization using PyRoKi.
4 |
5 | Robot going over a wall, while avoiding world-collisions.
6 | """
7 |
8 | import time
9 | from typing import Literal
10 |
11 | import numpy as np
12 | import pyroki as pk
13 | import trimesh
14 | import tyro
15 | import viser
16 | from viser.extras import ViserUrdf
17 | from robot_descriptions.loaders.yourdfpy import load_robot_description
18 |
19 | import pyroki_snippets as pks
20 |
21 |
def main(robot_name: Literal["ur5", "panda"] = "panda"):
    """Plan and play back a trajectory that carries the end effector over a wall.

    Args:
        robot_name: Which robot to plan for.

    Raises:
        ValueError: If ``robot_name`` is not a supported robot.
    """
    if robot_name == "ur5":
        urdf = load_robot_description("ur5_description")
        down_wxyz = np.array([0.707, 0, 0.707, 0])
        target_link_name = "ee_link"

        # For UR5 it's important to initialize the robot in a safe configuration;
        # the zero-configuration puts the robot aligned with the wall obstacle.
        default_cfg = np.zeros(6)
        default_cfg[1] = -1.308
        robot = pk.Robot.from_urdf(urdf, default_joint_cfg=default_cfg)

    elif robot_name == "panda":
        urdf = load_robot_description("panda_description")
        target_link_name = "panda_hand"
        down_wxyz = np.array([0, 0, 1, 0])  # for panda!
        robot = pk.Robot.from_urdf(urdf)

    else:
        raise ValueError(f"Invalid robot: {robot_name}")

    robot_coll = pk.collision.RobotCollision.from_urdf(urdf)

    # Define the trajectory problem:
    # - number of timesteps, timestep size
    timesteps, dt = 25, 0.02
    # - the start and end poses.
    start_pos, end_pos = np.array([0.5, -0.3, 0.2]), np.array([0.5, 0.3, 0.2])

    # Define the obstacles:
    # - Ground
    ground_coll = pk.collision.HalfSpace.from_point_and_normal(
        np.array([0.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0])
    )
    # - Wall: a row of vertical capsules along +x starting at `wall_x_start`.
    #   The single `wall_x_start` constant drives both the collision geometry
    #   and the box drawn below, so the two cannot drift apart.
    wall_height = 0.4
    wall_width = 0.1
    wall_length = 0.4
    wall_x_start = 0.3
    wall_intervals = np.arange(
        start=wall_x_start, stop=wall_length + wall_x_start, step=0.05
    )
    num_capsules = wall_intervals.shape[0]
    translation = np.concatenate(
        [
            wall_intervals.reshape(-1, 1),
            np.full((num_capsules, 1), 0.0),
            np.full((num_capsules, 1), wall_height / 2),
        ],
        axis=1,
    )
    wall_coll = pk.collision.Capsule.from_radius_height(
        position=translation,
        radius=np.full((num_capsules, 1), wall_width / 2),
        height=np.full((num_capsules, 1), wall_height),
    )
    world_coll = [ground_coll, wall_coll]

    traj = pks.solve_trajopt(
        robot,
        robot_coll,
        world_coll,
        target_link_name,
        start_pos,
        down_wxyz,
        end_pos,
        down_wxyz,
        timesteps,
        dt,
    )
    traj = np.array(traj)

    # Visualize!
    server = viser.ViserServer()
    urdf_vis = ViserUrdf(server, urdf)
    server.scene.add_grid("/grid", width=2, height=2, cell_size=0.1)
    # Box spanning x in [wall_x_start, wall_x_start + wall_length].
    wall_center = np.array([wall_x_start + wall_length / 2, 0.0, wall_height / 2])
    server.scene.add_mesh_trimesh(
        "wall_box",
        trimesh.creation.box(
            extents=(wall_length, wall_width, wall_height),
            transform=trimesh.transformations.translation_matrix(wall_center),
        ),
    )
    for name, pos in zip(["start", "end"], [start_pos, end_pos]):
        server.scene.add_frame(
            f"/{name}",
            position=pos,
            wxyz=down_wxyz,
            axes_length=0.05,
            axes_radius=0.01,
        )

    slider = server.gui.add_slider(
        "Timestep", min=0, max=timesteps - 1, step=1, initial_value=0
    )
    playing = server.gui.add_checkbox("Playing", initial_value=True)

    while True:
        if playing.value:
            slider.value = (slider.value + 1) % timesteps

        urdf_vis.update_cfg(traj[slider.value])
        # Playback advances one waypoint every 0.1 s, which is slower than
        # the planning dt.
        time.sleep(1.0 / 10.0)


if __name__ == "__main__":
    tyro.cli(main)
127 |
--------------------------------------------------------------------------------
/examples/08_ik_with_mimic_joints.py:
--------------------------------------------------------------------------------
1 | """IK with Mimic Joints
2 |
3 | This is a simple test to ensure that mimic joints are handled correctly in the IK solver.
4 |
5 | We procedurally generate a "zig-zag" chain of links with mimic joints, where:
6 | - the first joint is driven directly,
7 | - and the remaining joints are driven indirectly via mimic joints.
8 | The multipliers alternate between -1 and 1, and the offsets are all 0.
9 | """
10 |
11 | import time
12 |
13 | import numpy as np
14 | import pyroki as pk
15 | import viser
16 | from viser.extras import ViserUrdf
17 |
18 | import pyroki_snippets as pks
19 |
20 |
21 | def create_chain_xml(length: float = 0.2, num_chains: int = 5) -> str: # Builds the URDF text for a zig-zag chain of `num_chains` links. NOTE(review): the XML markup inside the f-string templates appears to have been stripped from this dump; restore from the original file before relying on it.
22 | def create_link(idx): # Returns the XML snippet for link `idx` (template body not visible here).
23 | return f"""
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 | """ # End of the link template.
36 |
37 | def create_joint(idx, multiplier=1.0, offset=0.0): # Returns the joint XML snippet; the mimic text is included only when idx != 0.
38 | mimic = f'' # NOTE(review): presumably a mimic tag built from multiplier/offset; its text appears stripped here.
39 | return f"""
40 |
41 |
42 |
43 |
44 |
45 | {mimic if idx != 0 else ""}
46 |
47 |
48 | """ # End of the joint template.
49 |
50 | world_joint_origin_z = length / 2.0 # Not referenced in the visible template text; presumably used by the stripped markup. Confirm.
51 | xml = f"""
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 | """ # End of the document-start template.
63 | # Links 0 and 1, connected by joint 0 (the directly driven joint: idx == 0 gets no mimic text).
64 | xml += create_link(0)
65 | xml += create_link(1)
66 | xml += create_joint(0)
67 |
68 | # Procedurally add more links.
69 | assert num_chains >= 2 # Links 0 and 1 are always emitted above.
70 | for idx in range(2, num_chains):
71 | xml += create_link(idx)
72 | current_offset = 0.0
73 | current_multiplier = 1.0 * ((-1) ** (idx % 2)) # +1.0 for even idx, -1.0 for odd idx.
74 | xml += create_joint(idx - 1, current_multiplier, current_offset) # Joint idx-1, presumably connecting link idx-1 to link idx.
75 |
76 | xml += """
77 |
78 | """ # Presumably closes the document; contents stripped in this dump.
79 | return xml
80 |
81 |
def main():
    """Interactive IK on a procedurally generated mimic-joint chain.

    Builds a 10-link zig-zag chain with ``create_chain_xml``, loads it as a
    URDF, and solves IK for whichever link is selected in the dropdown.
    """
    import tempfile
    from pathlib import Path

    import yourdfpy

    xml = create_chain_xml(num_chains=10, length=0.1)
    # Write into a temporary directory and load by path. Reopening a
    # still-open NamedTemporaryFile by its name is not supported on Windows.
    with tempfile.TemporaryDirectory() as tmp_dir:
        urdf_path = Path(tmp_dir) / "chain.urdf"
        urdf_path.write_text(xml, encoding="utf-8")
        urdf = yourdfpy.URDF.load(str(urdf_path))

    # Create robot.
    robot = pk.Robot.from_urdf(urdf)

    # Set up visualizer.
    server = viser.ViserServer()
    server.scene.add_grid("/ground", width=2, height=2)
    urdf_vis = ViserUrdf(server, urdf, root_node_name="/base")
    target_link_name_handle = server.gui.add_dropdown(
        "Target Link",
        robot.links.names,
        initial_value=robot.links.names[-1],
    )

    # Create interactive controller with initial position.
    ik_target = server.scene.add_transform_controls(
        "/ik_target", scale=0.2, position=(0.0, 0.1, 0.1), wxyz=(0, 0, 1, 0)
    )
    timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True)

    while True:
        # Solve IK.
        start_time = time.time()
        solution = pks.solve_ik(
            robot=robot,
            target_link_name=target_link_name_handle.value,
            target_position=np.array(ik_target.position),
            target_wxyz=np.array(ik_target.wxyz),
        )

        # Update timing handle (exponential moving average).
        elapsed_time = time.time() - start_time
        timing_handle.value = 0.99 * timing_handle.value + 0.01 * (elapsed_time * 1000)

        # Update visualizer.
        urdf_vis.update_cfg(solution)


if __name__ == "__main__":
    main()
133 |
--------------------------------------------------------------------------------
/examples/pyroki_snippets/__init__.py:
--------------------------------------------------------------------------------
1 | from ._online_planning import solve_online_planning as solve_online_planning
2 | from ._solve_ik import solve_ik as solve_ik
3 | from ._solve_ik_with_base import solve_ik_with_base as solve_ik_with_base
4 | from ._solve_ik_with_collision import solve_ik_with_collision as solve_ik_with_collision
5 | from ._solve_ik_with_manipulability import (
6 | solve_ik_with_manipulability as solve_ik_with_manipulability,
7 | )
8 | from ._trajopt import solve_trajopt as solve_trajopt
9 | from ._solve_ik_with_multiple_targets import (
10 | solve_ik_with_multiple_targets as solve_ik_with_multiple_targets,
11 | )
12 |
--------------------------------------------------------------------------------
/examples/pyroki_snippets/_online_planning.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence
2 |
3 | import jax
4 | import jax.numpy as jnp
5 | import jax_dataclasses as jdc
6 | import jaxlie
7 | import jaxls
8 | import numpy as onp
9 | import pyroki as pk
10 |
11 |
def solve_online_planning(
    robot: pk.Robot,
    robot_coll: pk.collision.RobotCollision,
    world_coll: Sequence[pk.collision.CollGeom],
    target_link_name: str,
    target_position: onp.ndarray,
    target_wxyz: onp.ndarray,
    timesteps: int,
    dt: float,
    start_cfg: onp.ndarray,
    prev_sols: onp.ndarray,
) -> tuple[onp.ndarray, onp.ndarray, onp.ndarray]:
    """Solve one step of online planning with collision avoidance.

    Args:
        robot: PyRoKi Robot.
        robot_coll: Collision model of the robot.
        world_coll: World obstacles to avoid.
        target_link_name: Name of the link to drive to the target pose.
        target_position: Target position, shape (3,).
        target_wxyz: Target orientation as a wxyz quaternion, shape (4,).
        timesteps: Number of planned waypoints to return.
        dt: Timestep passed to the joint-velocity limit cost.
        start_cfg: Joint configuration the plan is pinned to start from.
        prev_sols: Previous plan, one row per timestep (``timesteps`` rows),
            used as the warm start.

    Returns:
        traj: Planned joint configurations, ``timesteps`` rows.
        pos: Target-link positions along the plan, one entry per timestep.
        wxyz: Target-link orientations (wxyz) along the plan.

    Raises:
        ValueError: If ``target_link_name`` is not a link of ``robot``.
    """
    target_position = onp.asarray(target_position)
    target_wxyz = onp.asarray(target_wxyz)
    prev_sols = onp.asarray(prev_sols)
    assert target_position.shape == (3,) and target_wxyz.shape == (4,)
    # Fail early with a readable message instead of a shape error inside jit.
    assert prev_sols.shape[0] == timesteps, (
        f"prev_sols has {prev_sols.shape[0]} rows, expected timesteps={timesteps}"
    )

    try:
        target_link_index = robot.links.names.index(target_link_name)
    except ValueError:
        raise ValueError(
            f"Unknown target link {target_link_name!r}; "
            f"available links: {list(robot.links.names)}"
        ) from None
    target_links = jnp.array([target_link_index])

    target_poses = jaxlie.SE3(
        jnp.concatenate([jnp.array(target_wxyz), jnp.array(target_position)], axis=-1)
    )

    # The solver horizon has one extra leading step, which is pinned to
    # `start_cfg`. The warm start is the previous plan with its final waypoint
    # duplicated to fill the extra slot. The leading step is dropped from the
    # outputs, so callers always get `timesteps` rows back and the plan
    # advances by one step per call.
    sol_traj, sol_pos, sol_wxyz = _solve_online_planning_jax(
        robot,
        robot_coll,
        world_coll,
        target_poses,
        target_links,
        timesteps + 1,
        dt,
        jnp.array(start_cfg),
        jnp.concatenate([prev_sols, prev_sols[-1:]], axis=0),
    )
    return (
        onp.array(sol_traj[1:]),
        onp.array(sol_pos[1:]),
        onp.array(sol_wxyz[1:]),
    )
53 |
54 |
55 | @jdc.jit
56 | def _solve_online_planning_jax( # JIT-compiled core of solve_online_planning; `timesteps` is a static (compile-time) argument.
57 | robot: pk.Robot,
58 | robot_coll: pk.collision.RobotCollision,
59 | world_coll: Sequence[pk.collision.CollGeom],
60 | target_poses: jaxlie.SE3,
61 | target_links: jnp.ndarray,
62 | timesteps: jdc.Static[int],
63 | dt: float,
64 | start_cfg: jnp.ndarray,
65 | prev_sols: jnp.ndarray,
66 | ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: # (joint trajectory, target-link positions, target-link wxyz), each indexed by timestep first.
67 | num_targets = len(target_links)
68 |
69 | def batched_rplus( # Retraction for BatchedSE3Var: splits the flat tangent delta into one chunk per target and applies SE(3) rplus to each.
70 | pose: jaxlie.SE3,
71 | delta: jax.Array,
72 | ) -> jaxlie.SE3:
73 | return jax.vmap(jaxlie.manifold.rplus)(pose, delta.reshape(num_targets, -1))
74 |
75 | # Custom SE3 variable to batch across multiple joint targets.
76 | # This is not to be confused with SE3Vars with ids, which we use here for timesteps.
77 | class BatchedSE3Var( # pylint: disable=missing-class-docstring
78 | jaxls.Var[jaxlie.SE3],
79 | default_factory=lambda: jaxlie.SE3.identity((num_targets,)),
80 | retract_fn=batched_rplus,
81 | tangent_dim=jaxlie.SE3.tangent_dim * num_targets,
82 | ): ...
83 |
84 | # --- Define Variables ---
85 | traj_var = robot.joint_var_cls(jnp.arange(0, timesteps)) # One joint-config variable per timestep.
86 | traj_var_prev = robot.joint_var_cls(jnp.arange(0, timesteps - 1)) # Steps 0..timesteps-2; paired with traj_var_next in the smoothness/velocity costs.
87 | traj_var_next = robot.joint_var_cls(jnp.arange(1, timesteps)) # Steps 1..timesteps-1.
88 | pose_var = BatchedSE3Var(jnp.arange(0, timesteps)) # One batched (per-target) pose variable per timestep.
89 | pose_var_prev = BatchedSE3Var(jnp.arange(0, timesteps - 1))
90 | pose_var_next = BatchedSE3Var(jnp.arange(1, timesteps))
91 |
92 | init_pose_vals = jaxlie.SE3( # Warm start for pose_var: FK of prev_sols at the target links.
93 | robot.forward_kinematics(prev_sols)[..., target_links, :]
94 | )
95 |
96 | # --- Define Costs ---
97 | factors: list[jaxls.Cost] = [] # Cost terms of the least-squares problem.
98 |
99 | @jaxls.Cost.create_factory(name="SE3PoseMatchJointCost") # Ties each pose variable to FK of the same timestep's joints (weight 100).
100 | def match_joint_to_pose_cost(
101 | vals: jaxls.VarValues,
102 | joint_var: jaxls.Var[jnp.ndarray],
103 | pose_var: BatchedSE3Var,
104 | ):
105 | joint_cfg = vals[joint_var]
106 | target_pose = vals[pose_var]
107 | Ts_joint_world = robot.forward_kinematics(joint_cfg)
108 | residual = (
109 | (jaxlie.SE3(Ts_joint_world[..., target_links, :])).inverse() @ (target_pose)
110 | ).log()
111 | return residual.flatten() * 100.0
112 |
113 | @jaxls.Cost.create_factory(name="SE3SmoothnessCost") # Penalizes the relative transform between consecutive pose variables (weight 1).
114 | def pose_smoothness_cost(
115 | vals: jaxls.VarValues,
116 | pose_var: BatchedSE3Var,
117 | pose_var_prev: BatchedSE3Var,
118 | ):
119 | return (vals[pose_var].inverse() @ vals[pose_var_prev]).log().flatten() * 1.0
120 |
121 | @jaxls.Cost.create_factory(name="SE3PoseMatchCost") # Pulls a pose variable toward target_poses; weight 50 on translation, 20 on rotation (jaxlie tangent order).
122 | def pose_match_cost(
123 | vals: jaxls.VarValues,
124 | pose_var: BatchedSE3Var,
125 | ):
126 | return (
127 | (vals[pose_var].inverse() @ target_poses).log()
128 | * jnp.array([50.0] * 3 + [20.0] * 3)
129 | ).flatten()
130 |
131 | @jaxls.Cost.create_factory(name="MatchStartPoseCost") # Pins a joint config to start_cfg (weight 100).
132 | def match_start_pose_cost(
133 | vals: jaxls.VarValues,
134 | joint_var: jaxls.Var[jnp.ndarray],
135 | ):
136 | return (vals[joint_var] - start_cfg).flatten() * 100.0
137 |
138 | # Add pose costs.
139 | factors.extend(
140 | [
141 | pose_match_cost(
142 | BatchedSE3Var(timesteps - 1), # Goal applies only at the final timestep.
143 | ),
144 | pose_smoothness_cost(
145 | pose_var_next,
146 | pose_var_prev,
147 | ),
148 | ]
149 | )
150 |
151 | # Need to constrain the start joint cfg.
152 | factors.append(match_start_pose_cost(robot.joint_var_cls(0))) # Step 0 is pinned to start_cfg.
153 |
154 | # Add joint costs.
155 | factors.extend(
156 | [
157 | match_joint_to_pose_cost(
158 | traj_var,
159 | pose_var,
160 | ),
161 | pk.costs.smoothness_cost(
162 | traj_var_prev,
163 | traj_var_next,
164 | weight=10.0,
165 | ),
166 | pk.costs.limit_velocity_cost(
167 | jax.tree.map(lambda x: x[None], robot), # Leading singleton axis on every leaf, presumably so the robot broadcasts against the per-timestep variables.
168 | traj_var_prev,
169 | traj_var_next,
170 | weight=10.0,
171 | dt=dt,
172 | ),
173 | pk.costs.limit_cost(
174 | jax.tree.map(lambda x: x[None], robot),
175 | traj_var,
176 | weight=100.0,
177 | ),
178 | pk.costs.rest_cost(
179 | traj_var,
180 | jnp.array(traj_var.default_factory())[None],
181 | weight=0.01,
182 | ),
183 | pk.costs.manipulability_cost(
184 | jax.tree.map(lambda x: x[None], robot),
185 | traj_var,
186 | weight=0.01,
187 | target_link_indices=target_links,
188 | ),
189 | pk.costs.self_collision_cost(
190 | jax.tree.map(lambda x: x[None], robot),
191 | jax.tree.map(lambda x: x[None], robot_coll),
192 | traj_var,
193 | weight=10.0,
194 | margin=0.02,
195 | ),
196 | ]
197 | )
198 | factors.extend(
199 | [
200 | pk.costs.world_collision_cost(
201 | jax.tree.map(lambda x: x[None], robot),
202 | jax.tree.map(lambda x: x[None], robot_coll),
203 | traj_var,
204 | jax.tree.map(lambda x: x[None], obs),
205 | weight=20.0,
206 | margin=0.1,
207 | )
208 | for obs in world_coll # One collision cost per world obstacle.
209 | ]
210 | )
211 |
212 | solution = (
213 | jaxls.LeastSquaresProblem(factors, [traj_var, pose_var])
214 | .analyze()
215 | .solve(
216 | verbose=False,
217 | initial_vals=jaxls.VarValues.make(
218 | (traj_var.with_value(prev_sols), pose_var.with_value(init_pose_vals)) # Warm start from the previous trajectory.
219 | ),
220 | termination=jaxls.TerminationConfig(max_iterations=20), # Fixed iteration cap per call.
221 | )
222 | )
223 | pose_traj = solution[pose_var]
224 | return (
225 | solution[traj_var],
226 | pose_traj.translation(),
227 | pose_traj.rotation().wxyz,
228 | )
229 |
--------------------------------------------------------------------------------
/examples/pyroki_snippets/_solve_ik.py:
--------------------------------------------------------------------------------
1 | """
2 | Solves the basic IK problem.
3 | """
4 |
5 | import jax
6 | import jax.numpy as jnp
7 | import jax_dataclasses as jdc
8 | import jaxlie
9 | import jaxls
10 | import numpy as onp
11 | import pyroki as pk
12 |
13 |
def solve_ik(
    robot: pk.Robot,
    target_link_name: str,
    target_wxyz: onp.ndarray,
    target_position: onp.ndarray,
) -> onp.ndarray:
    """
    Solves the basic IK problem for a robot.

    Args:
        robot: PyRoKi Robot.
        target_link_name: Name of the link to drive to the target pose.
        target_wxyz: Target orientation as a wxyz quaternion, shape (4,).
            Any array-like is accepted.
        target_position: Target position, shape (3,). Any array-like is accepted.

    Returns:
        cfg: onp.ndarray. Shape: (robot.joints.num_actuated_joints,).

    Raises:
        ValueError: If ``target_link_name`` is not a link of ``robot``.
    """
    target_wxyz = onp.asarray(target_wxyz)
    target_position = onp.asarray(target_position)
    assert target_position.shape == (3,) and target_wxyz.shape == (4,)

    try:
        target_link_index = robot.links.names.index(target_link_name)
    except ValueError:
        raise ValueError(
            f"Unknown target link {target_link_name!r}; "
            f"available links: {list(robot.links.names)}"
        ) from None

    cfg = _solve_ik_jax(
        robot,
        jnp.array(target_link_index),
        jnp.array(target_wxyz),
        jnp.array(target_position),
    )
    assert cfg.shape == (robot.joints.num_actuated_joints,)
    return onp.array(cfg)
42 |
43 |
@jdc.jit
def _solve_ik_jax(
    robot: pk.Robot,
    target_link_index: jax.Array,
    target_wxyz: jax.Array,
    target_position: jax.Array,
) -> jax.Array:
    """JIT-compiled IK: a pose cost on one link plus a joint-limit cost.

    Returns the optimized joint configuration.
    """
    cfg_var = robot.joint_var_cls(0)
    T_world_target = jaxlie.SE3.from_rotation_and_translation(
        jaxlie.SO3(target_wxyz), target_position
    )

    pose_term = pk.costs.pose_cost_analytic_jac(
        robot,
        cfg_var,
        T_world_target,
        target_link_index,
        pos_weight=50.0,
        ori_weight=10.0,
    )
    limit_term = pk.costs.limit_cost(robot, cfg_var, weight=100.0)

    problem = jaxls.LeastSquaresProblem([pose_term, limit_term], [cfg_var]).analyze()
    solution = problem.solve(
        verbose=False,
        linear_solver="dense_cholesky",
        trust_region=jaxls.TrustRegionConfig(lambda_initial=1.0),
    )
    return solution[cfg_var]
79 |
--------------------------------------------------------------------------------
/examples/pyroki_snippets/_solve_ik_with_base.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import jax_dataclasses as jdc
4 | import jaxlie
5 | import jaxls
6 |
7 | import numpy as onp
8 |
9 | import pyroki as pk
10 |
11 |
def solve_ik_with_base(
    robot: pk.Robot,
    target_link_name: str,
    target_position: onp.ndarray,
    target_wxyz: onp.ndarray,
    fix_base_position: tuple[bool, bool, bool],
    fix_base_orientation: tuple[bool, bool, bool],
    prev_pos: onp.ndarray,
    prev_wxyz: onp.ndarray,
    prev_cfg: onp.ndarray,
) -> tuple[onp.ndarray, onp.ndarray, onp.ndarray]:
    """
    Solves the basic IK problem for a robot with a mobile base.

    Args:
        robot: PyRoKi Robot.
        target_link_name: Name of the link to drive to the target pose.
        target_position: Target position, shape (3,).
        target_wxyz: Target orientation as a wxyz quaternion, shape (4,).
        fix_base_position: Per-axis (x, y, z) flags; True locks that base
            translation axis.
        fix_base_orientation: Per-axis flags for base rotation about
            (x, y, z); True locks that rotation axis.
        prev_pos: Previous base position, shape (3,); initial guess for the base.
        prev_wxyz: Previous base orientation (wxyz), shape (4,).
        prev_cfg: Previous joint configuration, shape
            (robot.joints.num_actuated_joints,); warm start for the joints.

        All array arguments may be any array-like (e.g. tuples).

    Returns:
        base_pos: onp.ndarray. Shape: (3,).
        base_wxyz: onp.ndarray. Shape: (4,).
        cfg: onp.ndarray. Shape: (robot.joints.num_actuated_joints,).

    Raises:
        ValueError: If ``target_link_name`` is not a link of ``robot``.
    """
    target_position = onp.asarray(target_position)
    target_wxyz = onp.asarray(target_wxyz)
    prev_pos = onp.asarray(prev_pos)
    prev_wxyz = onp.asarray(prev_wxyz)
    prev_cfg = onp.asarray(prev_cfg)
    assert target_position.shape == (3,) and target_wxyz.shape == (4,)
    assert prev_pos.shape == (3,) and prev_wxyz.shape == (4,)
    assert prev_cfg.shape == (robot.joints.num_actuated_joints,)

    try:
        target_link_idx = robot.links.names.index(target_link_name)
    except ValueError:
        raise ValueError(
            f"Unknown target link {target_link_name!r}; "
            f"available links: {list(robot.links.names)}"
        ) from None

    T_world_targets = jaxlie.SE3(
        jnp.concatenate([jnp.array(target_wxyz), jnp.array(target_position)], axis=-1)
    )
    # Six lock flags over the base's SE(3) tangent: translation (x, y, z) then
    # rotation (x, y, z). `tuple(...)` also lets callers pass lists.
    fix_base = jnp.array(tuple(fix_base_position) + tuple(fix_base_orientation))
    base_pose, cfg = _solve_ik_jax(
        robot,
        T_world_targets,
        jnp.array(target_link_idx),
        fix_base,
        jnp.array(prev_pos),
        jnp.array(prev_wxyz),
        jnp.array(prev_cfg),
    )
    assert cfg.shape == (robot.joints.num_actuated_joints,)

    base_pos = base_pose.translation()
    base_wxyz = base_pose.rotation().wxyz
    assert base_pos.shape == (3,) and base_wxyz.shape == (4,)

    return onp.array(base_pos), onp.array(base_wxyz), onp.array(cfg)
64 |
65 |
@jdc.jit
def _solve_ik_jax(
    robot: pk.Robot,
    T_world_target: jaxlie.SE3,
    target_joint_indices: jnp.ndarray,
    fix_base: jnp.ndarray,
    prev_pos: jnp.ndarray,
    prev_wxyz: jnp.ndarray,
    prev_cfg: jnp.ndarray,
) -> tuple[jaxlie.SE3, jax.Array]:
    """JIT-compiled joint + base IK solve.

    ``fix_base`` masks the 6-D tangent update applied to the base pose.
    Entries that are set never receive updates, so those base DoF keep their
    initial values (the previous base pose). Returns (base pose, joint cfg).
    """
    cfg_var = robot.joint_var_cls(0)

    def masked_retract(transform: jaxlie.SE3, delta: jax.Array) -> jaxlie.SE3:
        # Zero the tangent components of locked axes, then retract as SE3Var does.
        return jaxls.SE3Var.retract_fn(transform, delta * (1 - fix_base))

    class ConstrainedSE3Var(
        jaxls.Var[jaxlie.SE3],
        default_factory=lambda: jaxlie.SE3.from_rotation_and_translation(
            jaxlie.SO3(prev_wxyz),
            prev_pos,
        ),
        tangent_dim=jaxlie.SE3.tangent_dim,
        retract_fn=masked_retract,
    ): ...

    base_var = ConstrainedSE3Var(0)

    # Rest weights: joints first, then base position, then base orientation.
    rest_weights = jnp.array(
        [0.01] * robot.joints.num_actuated_joints
        + [0.1] * 3  # Base position DoF.
        + [0.001] * 3  # Base orientation DoF.
    )
    costs = [
        pk.costs.pose_cost_with_base(
            robot,
            cfg_var,
            base_var,
            T_world_target,
            target_joint_indices,
            pos_weight=jnp.array(5.0),
            ori_weight=jnp.array(1.0),
        ),
        pk.costs.limit_cost(robot, cfg_var, jnp.array(100.0)),
        pk.costs.rest_with_base_cost(
            cfg_var,
            base_var,
            jnp.array(cfg_var.default_factory()),
            rest_weights,
        ),
    ]

    # Joints warm-start from prev_cfg; the base starts from its default_factory.
    initial_vals = jaxls.VarValues.make([cfg_var.with_value(prev_cfg), base_var])
    solution = (
        jaxls.LeastSquaresProblem(costs, [cfg_var, base_var])
        .analyze()
        .solve(initial_vals=initial_vals, verbose=False)
    )
    return solution[base_var], solution[cfg_var]
132 |
--------------------------------------------------------------------------------
/examples/pyroki_snippets/_solve_ik_with_collision.py:
--------------------------------------------------------------------------------
1 | """
2 | Solves the basic IK problem with collision avoidance.
3 | """
4 |
5 | from typing import Sequence
6 |
7 | import jax
8 | import jax.numpy as jnp
9 | import jax_dataclasses as jdc
10 | import jaxlie
11 | import jaxls
12 | import numpy as onp
13 | import pyroki as pk
14 |
15 |
def solve_ik_with_collision(
    robot: pk.Robot,
    coll: pk.collision.RobotCollision,
    world_coll_list: Sequence[pk.collision.CollGeom],
    target_link_name: str,
    target_position: onp.ndarray,
    target_wxyz: onp.ndarray,
) -> onp.ndarray:
    """
    Solves the basic IK problem for a robot, with self- and world-collision avoidance.

    Args:
        robot: PyRoKi Robot.
        coll: pk.collision.RobotCollision. Collision model of the robot.
        world_coll_list: Sequence[pk.collision.CollGeom]. World obstacles; one
            world-collision cost is added per entry.
        target_link_name: str. Name of the link to bring to the target pose.
        target_position: onp.ndarray. Shape: (3,).
        target_wxyz: onp.ndarray. Shape: (4,). Target orientation quaternion (wxyz).

    Returns:
        cfg: onp.ndarray. Shape: (robot.joints.num_actuated_joints,).

    Raises:
        ValueError: if `target_link_name` is not in `robot.links.names`.
    """
    assert target_position.shape == (3,) and target_wxyz.shape == (4,)
    target_link_idx = robot.links.names.index(target_link_name)

    # jaxlie SE3 parameters are laid out as (wxyz, xyz).
    T_world_targets = jaxlie.SE3(
        jnp.concatenate([jnp.array(target_wxyz), jnp.array(target_position)], axis=-1)
    )
    cfg = _solve_ik_with_collision_jax(
        robot,
        coll,
        world_coll_list,
        T_world_targets,
        jnp.array(target_link_idx),
    )
    assert cfg.shape == (robot.joints.num_actuated_joints,)

    return onp.array(cfg)
52 |
53 |
@jdc.jit
def _solve_ik_with_collision_jax(
    robot: pk.Robot,
    coll: pk.collision.RobotCollision,
    world_coll_list: Sequence[pk.collision.CollGeom],
    T_world_target: jaxlie.SE3,
    target_link_index: jax.Array,
) -> jax.Array:
    """Solves the basic IK problem with collision avoidance. Returns joint configuration.

    Costs: pose of the target link, joint limits, rest pose, self-collision,
    and one world-collision cost per geometry in `world_coll_list`.
    """
    joint_var = robot.joint_var_cls(0)

    # Weights and margins are defined directly in the costs.
    costs = [
        pk.costs.pose_cost(
            robot,
            joint_var,
            target_pose=T_world_target,
            target_link_index=target_link_index,
            pos_weight=5.0,
            ori_weight=1.0,
        ),
        pk.costs.limit_cost(
            robot,
            joint_var=joint_var,
            weight=100.0,
        ),
        pk.costs.rest_cost(
            joint_var,
            rest_pose=jnp.array(joint_var.default_factory()),
            weight=0.01,
        ),
        pk.costs.self_collision_cost(
            robot,
            robot_coll=coll,
            joint_var=joint_var,
            margin=0.02,
            weight=5.0,
        ),
    ]
    # NOTE(review): the positional 0.05 / 10.0 presumably mirror the
    # margin / weight of `self_collision_cost` above — confirm against pk.costs.
    costs.extend(
        pk.costs.world_collision_cost(robot, coll, joint_var, world_coll, 0.05, 10.0)
        for world_coll in world_coll_list
    )

    # `variables` instead of `vars`, which would shadow the builtin.
    variables = [joint_var]
    sol = (
        jaxls.LeastSquaresProblem(costs, variables)
        .analyze()
        .solve(verbose=False, linear_solver="dense_cholesky")
    )
    return sol[joint_var]
109 |
--------------------------------------------------------------------------------
/examples/pyroki_snippets/_solve_ik_with_manipulability.py:
--------------------------------------------------------------------------------
1 | """
2 | Solves the basic IK problem with an additional manipulability cost.
3 | """
4 |
5 | import jax
6 | import jax.numpy as jnp
7 | import jax_dataclasses as jdc
8 | import jaxlie
9 | import jaxls
10 |
11 | import numpy as onp
12 |
13 | import pyroki as pk
14 |
15 |
def solve_ik_with_manipulability(
    robot: pk.Robot,
    target_link_name: str,
    target_position: onp.ndarray,
    target_wxyz: onp.ndarray,
    manipulability_weight: float = 0.0,
) -> onp.ndarray:
    """
    Solves the basic IK problem for a robot, with manipulability cost.

    Args:
        robot: PyRoKi Robot.
        target_link_name: str. Name of the link to bring to the target pose.
        target_position: onp.ndarray. Shape: (3,).
        target_wxyz: onp.ndarray. Shape: (4,). Target orientation quaternion (wxyz).
        manipulability_weight: float. Weight for the manipulability cost (default 0.0).

    Returns:
        cfg: onp.ndarray. Shape: (robot.joints.num_actuated_joints,).

    Raises:
        ValueError: if `target_link_name` is not in `robot.links.names`.
    """
    assert target_position.shape == (3,) and target_wxyz.shape == (4,)
    target_link_idx = robot.links.names.index(target_link_name)

    # jaxlie SE3 parameters are laid out as (wxyz, xyz).
    T_world_target = jaxlie.SE3(
        jnp.concatenate([jnp.array(target_wxyz), jnp.array(target_position)], axis=-1)
    )
    cfg = _solve_ik_jax(
        robot,
        T_world_target,
        jnp.array(target_link_idx),
        jnp.array(manipulability_weight),
    )
    assert cfg.shape == (robot.joints.num_actuated_joints,)

    return onp.array(cfg)
53 |
54 |
@jdc.jit
def _solve_ik_jax(
    robot: pk.Robot,
    T_world_target: jaxlie.SE3,
    target_joint_idx: jnp.ndarray,
    manipulability_weight: jnp.ndarray,
) -> jax.Array:
    """Solve IK with pose, limit, rest and manipulability costs; returns the joint configuration.

    `manipulability_weight` is a traced array, so changing it does not
    trigger recompilation; the manipulability cost is always part of the problem.
    """
    joint_var = robot.joint_var_cls(0)
    factors = [
        pk.costs.pose_cost_analytic_jac(
            robot,
            joint_var,
            T_world_target,
            target_joint_idx,
            pos_weight=50.0,
            ori_weight=10.0,
        ),
        pk.costs.limit_cost(
            robot,
            joint_var,
            jnp.array([100.0] * robot.joints.num_joints),
        ),
        pk.costs.rest_cost(
            joint_var,
            jnp.array(joint_var.default_factory()),
            jnp.array([0.01] * robot.joints.num_actuated_joints),
        ),
        pk.costs.manipulability_cost(
            robot,
            joint_var,
            target_joint_idx,
            manipulability_weight,
        ),
    ]
    # `variables` instead of `vars`, which would shadow the builtin.
    variables = [joint_var]
    sol = (
        jaxls.LeastSquaresProblem(factors, variables)
        .analyze()
        .solve(verbose=False, linear_solver="dense_cholesky")
    )
    return sol[joint_var]
96 |
--------------------------------------------------------------------------------
/examples/pyroki_snippets/_solve_ik_with_multiple_targets.py:
--------------------------------------------------------------------------------
1 | """
2 | Solves the IK problem with multiple target links.
3 | """
4 |
5 | from typing import Sequence
6 |
7 | import jax
8 | import jax.numpy as jnp
9 | import jax_dataclasses as jdc
10 | import jaxlie
11 | import jaxls
12 | import numpy as onp
13 | import pyroki as pk
14 |
15 |
def solve_ik_with_multiple_targets(
    robot: pk.Robot,
    target_link_names: Sequence[str],
    target_wxyzs: onp.ndarray,
    target_positions: onp.ndarray,
) -> onp.ndarray:
    """
    Solve IK so that several links reach their target poses simultaneously.

    Args:
        robot: PyRoKi Robot.
        target_link_names: Sequence[str]. Names of the links to control.
        target_wxyzs: onp.ndarray. Shape: (num_targets, 4). Target orientations (wxyz).
        target_positions: onp.ndarray. Shape: (num_targets, 3). Target positions.

    Returns:
        cfg: onp.ndarray. Shape: (robot.joints.num_actuated_joints,).
    """
    n_targets = len(target_link_names)
    assert target_positions.shape == (n_targets, 3)
    assert target_wxyzs.shape == (n_targets, 4)

    all_link_names = robot.links.names
    link_idx_list = [all_link_names.index(link_name) for link_name in target_link_names]

    wxyz_arr = jnp.array(target_wxyzs)
    pos_arr = jnp.array(target_positions)
    link_idx_arr = jnp.array(link_idx_list)

    solution = _solve_ik_jax(robot, wxyz_arr, pos_arr, link_idx_arr)
    assert solution.shape == (robot.joints.num_actuated_joints,)
    return onp.array(solution)
48 |
49 |
@jdc.jit
def _solve_ik_jax(
    robot: pk.Robot,
    target_wxyz: jax.Array,
    target_position: jax.Array,
    target_joint_indices: jax.Array,
) -> jax.Array:
    """Solve the multi-target IK problem; returns the joint configuration.

    A single joint variable (id 0) serves every target. The pose cost is
    batched over the targets by passing a batch of variable ids that are all
    0, together with a robot pytree that has a leading length-1 axis.
    """
    JointVar = robot.joint_var_cls
    cfg_var = JointVar(0)

    # The variable/cost batch axes come from the target poses.
    # Batch axes of variables and cost arguments (e.g. target pose) must broadcast!
    T_world_targets = jaxlie.SE3.from_rotation_and_translation(
        jaxlie.SO3(target_wxyz), target_position
    )
    target_batch_axes = T_world_targets.get_batch_axes()
    batched_cfg_var = JointVar(jnp.full(target_batch_axes, 0))
    broadcast_robot = jax.tree.map(lambda leaf: leaf[None], robot)

    pose_term = pk.costs.pose_cost_analytic_jac(
        broadcast_robot,
        batched_cfg_var,
        T_world_targets,
        target_joint_indices,
        pos_weight=50.0,
        ori_weight=10.0,
    )
    rest_term = pk.costs.rest_cost(
        cfg_var,
        rest_pose=JointVar.default_factory(),
        weight=1.0,
    )
    limit_term = pk.costs.limit_cost(
        robot,
        cfg_var,
        jnp.array([100.0] * robot.joints.num_joints),
    )

    problem = jaxls.LeastSquaresProblem([pose_term, rest_term, limit_term], [cfg_var])
    solution = problem.analyze().solve(
        verbose=False,
        linear_solver="dense_cholesky",
        trust_region=jaxls.TrustRegionConfig(lambda_initial=10.0),
    )
    return solution[cfg_var]
96 |
--------------------------------------------------------------------------------
/examples/pyroki_snippets/_trajopt.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence
2 |
3 | import jax
4 | import jax.numpy as jnp
5 | import jax_dataclasses as jdc
6 | import jaxlie
7 | import jaxls
8 | import numpy as onp
9 | import pyroki as pk
10 | from jax.typing import ArrayLike
11 |
12 |
def solve_trajopt(
    robot: pk.Robot,
    robot_coll: pk.collision.RobotCollision,
    world_coll: Sequence[pk.collision.CollGeom],
    target_link_name: str,
    start_position: ArrayLike,
    start_wxyz: ArrayLike,
    end_position: ArrayLike,
    end_wxyz: ArrayLike,
    timesteps: int,
    dt: float,
) -> ArrayLike:
    """Plan a collision-aware joint trajectory between two end-effector poses.

    Steps: (1) solve collision-aware IK for the start and end poses,
    (2) initialize a trajectory by linearly interpolating the two joint
    configurations, (3) build regularization, limit, swept-collision,
    endpoint and smoothness costs, (4) solve them jointly.

    Args:
        robot: PyRoKi Robot.
        robot_coll: Collision model of the robot.
        world_coll: World obstacles; one swept-collision cost per entry.
        target_link_name: Name of the link that must reach the start/end poses.
        start_position, start_wxyz: Start pose of the target link (wxyz quaternion).
        end_position, end_wxyz: End pose of the target link (wxyz quaternion).
        timesteps: Number of trajectory samples.
        dt: Time step forwarded to the finite-difference smoothness costs.

    Returns:
        Joint trajectory, one row per timestep. It is a numpy array when
        `start_position` is a numpy array, and a jax array when it is a jax array.

    Raises:
        ValueError: if `start_position` is neither a numpy nor a jax array, or
            if `target_link_name` is not in `robot.links.names`.
    """
    # Select the array module for the return value from the input's type.
    if isinstance(start_position, onp.ndarray):
        np = onp
    elif isinstance(start_position, jnp.ndarray):
        np = jnp
    else:
        raise ValueError(f"Invalid type for `ArrayLike`: {type(start_position)}")

    # 1. Solve IK for the start and end poses.
    target_link_index = robot.links.names.index(target_link_name)
    start_cfg, end_cfg = solve_iks_with_collision(
        robot=robot,
        coll=robot_coll,
        world_coll_list=world_coll,
        target_link_index=target_link_index,
        target_position_0=jnp.array(start_position),
        target_wxyz_0=jnp.array(start_wxyz),
        target_position_1=jnp.array(end_position),
        target_wxyz_1=jnp.array(end_wxyz),
    )

    # 2. Initialize the trajectory by linearly interpolating the start and end
    # joint configurations (in joint space, not in pose space).
    init_traj = jnp.linspace(start_cfg, end_cfg, timesteps)

    # 3. Optimize the trajectory.
    # One joint variable per timestep (ids 0 .. timesteps-1).
    traj_vars = robot.joint_var_cls(jnp.arange(timesteps))

    # From here on `robot` / `robot_coll` are the batched pytrees; the static
    # `joint_var_cls` is unaffected by the tree map.
    robot = jax.tree.map(lambda x: x[None], robot)  # Add batch dimension.
    robot_coll = jax.tree.map(lambda x: x[None], robot_coll)  # Add batch dimension.

    # Basic regularization / limit costs.
    factors: list[jaxls.Cost] = [
        pk.costs.rest_cost(
            traj_vars,
            traj_vars.default_factory()[None],
            jnp.array([0.01])[None],
        ),
        pk.costs.limit_cost(
            robot,
            traj_vars,
            jnp.array([100.0])[None],
        ),
    ]

    # Collision avoidance.
    # Residual for the capsules swept between consecutive timesteps vs. one
    # world obstacle (margin 0.1, scaled by 20.0).
    def compute_world_coll_residual(
        vals: jaxls.VarValues,
        robot: pk.Robot,
        robot_coll: pk.collision.RobotCollision,
        world_coll_obj: pk.collision.CollGeom,
        prev_traj_vars: jaxls.Var[jax.Array],
        curr_traj_vars: jaxls.Var[jax.Array],
    ):
        coll = robot_coll.get_swept_capsules(
            robot, vals[prev_traj_vars], vals[curr_traj_vars]
        )
        # Pairwise distances: every swept capsule vs. every obstacle element.
        dist = pk.collision.collide(
            coll.reshape((-1, 1)), world_coll_obj.reshape((1, -1))
        )
        colldist = pk.collision.colldist_from_sdf(dist, 0.1)
        return (colldist * 20.0).flatten()

    # Pair timestep t with t+1 for t in [0, timesteps - 1).
    for world_coll_obj in world_coll:
        factors.append(
            jaxls.Cost(
                compute_world_coll_residual,
                (
                    robot,
                    robot_coll,
                    jax.tree.map(lambda x: x[None], world_coll_obj),
                    robot.joint_var_cls(jnp.arange(0, timesteps - 1)),
                    robot.joint_var_cls(jnp.arange(1, timesteps)),
                ),
                name="World Collision (sweep)",
            )
        )

    # Start / end pose constraints.
    # Soft constraints (weight 100.0) pinning the first two and the last two
    # timesteps to the IK solutions from step 1.
    factors.extend(
        [
            jaxls.Cost(
                lambda vals, var: ((vals[var] - start_cfg) * 100.0).flatten(),
                (robot.joint_var_cls(jnp.arange(0, 2)),),
                name="start_pose_constraint",
            ),
            jaxls.Cost(
                lambda vals, var: ((vals[var] - end_cfg) * 100.0).flatten(),
                (robot.joint_var_cls(jnp.arange(timesteps - 2, timesteps)),),
                name="end_pose_constraint",
            ),
        ]
    )

    # Velocity / acceleration / jerk minimization.
    # Each cost receives shifted index ranges of the same length. Relative to
    # the first timestep of each window, the offsets are: smoothness {1, 0},
    # velocity {4, 3, 1, 0}, acceleration {2, 4, 3, 1, 0} and jerk
    # {6, 5, 4, 2, 1, 0}. The jerk window spans 7 timesteps, so its ranges are
    # empty when timesteps <= 6 (how pk.costs handles that is not visible here).
    factors.extend(
        [
            pk.costs.smoothness_cost(
                robot.joint_var_cls(jnp.arange(1, timesteps)),
                robot.joint_var_cls(jnp.arange(0, timesteps - 1)),
                jnp.array([0.1])[None],
            ),
            pk.costs.five_point_velocity_cost(
                robot,
                robot.joint_var_cls(jnp.arange(4, timesteps)),
                robot.joint_var_cls(jnp.arange(3, timesteps - 1)),
                robot.joint_var_cls(jnp.arange(1, timesteps - 3)),
                robot.joint_var_cls(jnp.arange(0, timesteps - 4)),
                dt,
                jnp.array([10.0])[None],
            ),
            pk.costs.five_point_acceleration_cost(
                robot.joint_var_cls(jnp.arange(2, timesteps - 2)),
                robot.joint_var_cls(jnp.arange(4, timesteps)),
                robot.joint_var_cls(jnp.arange(3, timesteps - 1)),
                robot.joint_var_cls(jnp.arange(1, timesteps - 3)),
                robot.joint_var_cls(jnp.arange(0, timesteps - 4)),
                dt,
                jnp.array([0.1])[None],
            ),
            pk.costs.five_point_jerk_cost(
                robot.joint_var_cls(jnp.arange(6, timesteps)),
                robot.joint_var_cls(jnp.arange(5, timesteps - 1)),
                robot.joint_var_cls(jnp.arange(4, timesteps - 2)),
                robot.joint_var_cls(jnp.arange(2, timesteps - 4)),
                robot.joint_var_cls(jnp.arange(1, timesteps - 5)),
                robot.joint_var_cls(jnp.arange(0, timesteps - 6)),
                dt,
                jnp.array([0.1])[None],
            ),
        ]
    )

    # 4. Solve the optimization problem, warm-started from the interpolation.
    solution = (
        jaxls.LeastSquaresProblem(
            factors,
            [traj_vars],
        )
        .analyze()
        .solve(
            initial_vals=jaxls.VarValues.make((traj_vars.with_value(init_traj),)),
        )
    )
    return np.array(solution[traj_vars])
168 |
169 |
@jdc.jit
def solve_iks_with_collision(
    robot: pk.Robot,
    coll: pk.collision.RobotCollision,
    world_coll_list: Sequence[pk.collision.CollGeom],
    target_link_index: int,
    target_position_0: jax.Array,
    target_wxyz_0: jax.Array,
    target_position_1: jax.Array,
    target_wxyz_1: jax.Array,
) -> tuple[jax.Array, jax.Array]:
    """Jointly solve two collision-aware IK problems for the same target link.

    Two joint configurations (variable ids 0 and 1) are optimized together.
    Each has its own pose cost (target 0 / target 1). They share joint-limit,
    rest, self-collision and world-collision costs, and a small similarity
    cost pulls them toward each other.

    Returns:
        Tuple (cfg_0, cfg_1): configurations for target 0 and target 1.
    """
    joint_var_0 = robot.joint_var_cls(0)
    joint_var_1 = robot.joint_var_cls(1)
    joint_vars = robot.joint_var_cls(jnp.arange(2))

    # Batched (leading length-1 axis) pytrees for the costs over both variables.
    robot_batched = jax.tree.map(lambda x: x[None], robot)
    coll_batched = jax.tree.map(lambda x: x[None], coll)

    # Weights and margins defined directly in factors.
    factors = [
        pk.costs.pose_cost(
            robot,
            joint_var_0,
            jaxlie.SE3.from_rotation_and_translation(
                jaxlie.SO3(target_wxyz_0), target_position_0
            ),
            jnp.array(target_link_index),
            jnp.array([5.0] * 3),
            jnp.array([1.0] * 3),
        ),
        pk.costs.pose_cost(
            robot,
            joint_var_1,
            jaxlie.SE3.from_rotation_and_translation(
                jaxlie.SO3(target_wxyz_1), target_position_1
            ),
            jnp.array(target_link_index),
            jnp.array([5.0] * 3),
            jnp.array([1.0] * 3),
        ),
        pk.costs.limit_cost(
            robot_batched,
            joint_vars,
            jnp.array(100.0),
        ),
        pk.costs.rest_cost(
            joint_vars,
            jnp.array(joint_vars.default_factory()[None]),
            jnp.array(0.001),
        ),
        pk.costs.self_collision_cost(
            robot_batched,
            coll_batched,
            joint_vars,
            0.02,
            5.0,
        ),
    ]
    factors.extend(
        pk.costs.world_collision_cost(
            robot_batched,
            coll_batched,
            joint_vars,
            jax.tree.map(lambda x: x[None], world_coll),
            0.05,
            10.0,
        )
        for world_coll in world_coll_list
    )

    # Small cost to encourage the start + end configs to be close to each other.
    @jaxls.Cost.create_factory(name="JointSimilarityCost")
    def joint_similarity_cost(vals, var_0, var_1):
        return ((vals[var_0] - vals[var_1]) * 0.01).flatten()

    factors.append(joint_similarity_cost(joint_var_0, joint_var_1))

    # `variables` instead of `vars`, which would shadow the builtin.
    variables = [joint_vars]
    sol = jaxls.LeastSquaresProblem(factors, variables).analyze().solve(verbose=False)
    return sol[joint_var_0], sol[joint_var_1]
254 |
--------------------------------------------------------------------------------
/examples/retarget_helpers/_utils.py:
--------------------------------------------------------------------------------
1 | import pyroki as pk
2 | import jax.numpy as jnp
3 |
4 |
def create_conn_tree(robot: "pk.Robot", link_indices: jnp.ndarray) -> jnp.ndarray:
    """
    Create an NxN connectivity matrix for N retargeted links.

    Entry (i, j) is 1.0 when i == j. It is also 1.0 when walking up the tree
    from one of the two indices reaches the other before reaching any other
    entry of `link_indices`, i.e. the two are directly chained with no other
    retargeted index in between. All other entries are 0.0, and the matrix
    is symmetric.

    Args:
        robot: PyRoKi Robot. Only `robot.joints.parent_indices` is read; -1
            marks the end of an upward walk (a root).
        link_indices: 1D integer array (or sequence) of N indices.

    Returns:
        Float array of shape (N, N).
    """
    # NOTE(review): the inputs are documented as *link* indices, but ancestry is
    # walked through `robot.joints.parent_indices`. Confirm the two index spaces
    # coincide for the robots this is used with.

    # Convert to plain Python once. Scalar device reads, `in` tests against a
    # device array, and per-element `.at[].set` (each one copies the full
    # (n, n) matrix) inside the nested loops are all slow.
    indices = jnp.asarray(link_indices).tolist()
    parent_of = jnp.asarray(robot.joints.parent_indices).tolist()
    retargeted = set(indices)
    n = len(indices)

    def reaches_without_intermediate(start: int, target: int) -> bool:
        """True if `target` is an ancestor of `start` with no other retargeted index in between."""
        current = start
        while current != -1:
            parent = parent_of[current]
            if parent == target:
                return True
            if parent in retargeted:
                # Hit another retargeted index before finding `target`.
                return False
            current = parent
        return False

    rows: list[int] = []
    cols: list[int] = []
    for i in range(n):
        rows.append(i)  # Self-connection.
        cols.append(i)
        for j in range(i + 1, n):
            if reaches_without_intermediate(
                indices[j], indices[i]
            ) or reaches_without_intermediate(indices[i], indices[j]):
                rows += [i, j]
                cols += [j, i]

    # Single scatter instead of one functional update per entry.
    conn_matrix = jnp.zeros((n, n))
    if rows:
        conn_matrix = conn_matrix.at[jnp.array(rows), jnp.array(cols)].set(1.0)
    return conn_matrix
52 |
53 |
# Names of the SMPL(-X style) body keypoints, in keypoint-index order: a name's
# position in this list is the keypoint index returned by
# `get_humanoid_retarget_indices`.
# NOTE(review): presumably matches the keypoint axis of
# `retarget_helpers/humanoid/smpl_keypoints.npy` — verify against the data.
SMPL_JOINT_NAMES = [
    "pelvis",
    "left_hip",
    "right_hip",
    "spine_1",
    "left_knee",
    "right_knee",
    "spine_2",
    "left_ankle",
    "right_ankle",
    "spine_3",
    "left_foot",
    "right_foot",
    "neck",
    "left_collar",
    "right_collar",
    "head",
    "left_shoulder",
    "right_shoulder",
    "left_elbow",
    "right_elbow",
    "left_wrist",
    "right_wrist",
    "left_hand",
    "right_hand",
    # Face keypoints.
    "nose",
    "right_eye",
    "left_eye",
    "right_ear",
    "left_ear",
    # Foot keypoints.
    "left_big_toe",
    "left_small_toe",
    "left_heel",
    "right_big_toe",
    "right_small_toe",
    "right_heel",
    # Finger keypoints.
    "left_thumb",
    "left_index",
    "left_middle",
    "left_ring",
    "left_pinky",
    "right_thumb",
    "right_index",
    "right_middle",
    "right_ring",
    "right_pinky",
]
101 |
# Link names of the Unitree G1, when loaded from `g1_description`'s 23-DoF model.
# A name's position in this list is used as the G1 link index in
# `get_humanoid_retarget_indices`.
# NOTE(review): this assumes the order matches `robot.links.names` for that
# URDF — verify if the model or parser changes.
G1_LINK_NAMES = [
    "pelvis",
    "pelvis_contour_link",
    "left_hip_pitch_link",
    "left_hip_roll_link",
    "left_hip_yaw_link",
    "left_knee_link",
    "left_ankle_pitch_link",
    "left_ankle_roll_link",
    "right_hip_pitch_link",
    "right_hip_roll_link",
    "right_hip_yaw_link",
    "right_knee_link",
    "right_ankle_pitch_link",
    "right_ankle_roll_link",
    "torso_link",
    "head_link",
    "left_shoulder_pitch_link",
    "left_shoulder_roll_link",
    "left_shoulder_yaw_link",
    "left_elbow_pitch_link",
    "left_elbow_roll_link",
    "right_shoulder_pitch_link",
    "right_shoulder_roll_link",
    "right_shoulder_yaw_link",
    "right_elbow_pitch_link",
    "right_elbow_roll_link",
    "logo_link",
    "imu_link",
    "left_palm_link",
    "left_zero_link",
    "left_one_link",
    "left_two_link",
    "left_three_link",
    "left_four_link",
    "left_five_link",
    "left_six_link",
    "right_palm_link",
    "right_zero_link",
    "right_one_link",
    "right_two_link",
    "right_three_link",
    "right_four_link",
    "right_five_link",
    "right_six_link",
]
149 |
150 |
def get_humanoid_retarget_indices() -> tuple[jnp.ndarray, jnp.ndarray]:
    """Return paired SMPL-keypoint / G1-link indices used for humanoid retargeting.

    Returns:
        Tuple `(smpl_indices, g1_indices)` of equal length. Entry k of each
        describes the same correspondence: an index into `SMPL_JOINT_NAMES`
        and an index into `G1_LINK_NAMES`, respectively.
    """
    correspondences = (
        ("pelvis", "pelvis_contour_link"),
        ("left_hip", "left_hip_pitch_link"),
        ("right_hip", "right_hip_pitch_link"),
        ("left_knee", "left_knee_link"),
        ("right_knee", "right_knee_link"),
        ("left_ankle", "left_ankle_roll_link"),
        ("right_ankle", "right_ankle_roll_link"),
        ("left_shoulder", "left_shoulder_roll_link"),
        ("right_shoulder", "right_shoulder_roll_link"),
        ("left_elbow", "left_elbow_pitch_link"),
        ("right_elbow", "right_elbow_pitch_link"),
        ("left_wrist", "left_palm_link"),
        ("right_wrist", "right_palm_link"),
    )

    index_pairs = [
        (SMPL_JOINT_NAMES.index(smpl_name), G1_LINK_NAMES.index(g1_name))
        for smpl_name, g1_name in correspondences
    ]
    smpl_side, g1_side = zip(*index_pairs)
    return jnp.array(smpl_side), jnp.array(g1_side)
176 |
177 |
# Maps a MANO hand keypoint index to the Shadow Hand link name that
# corresponds to it. `get_mapping_from_mano_to_shadow` inverts this dict and
# matches the values against `robot.links.names`.
MANO_TO_SHADOW_MAPPING = {
    # Wrist
    0: "palm",
    # Thumb
    1: "thhub",
    2: "thmiddle",
    3: "thdistal",
    4: "thtip",
    # Index
    5: "ffproximal",
    6: "ffmiddle",
    7: "ffdistal",
    8: "fftip",
    # Middle
    9: "mfproximal",
    10: "mfmiddle",
    11: "mfdistal",
    12: "mftip",
    # Ring
    13: "rfproximal",
    14: "rfmiddle",
    15: "rfdistal",
    16: "rftip",
    # Little (pinky)
    17: "lfproximal",
    18: "lfmiddle",
    19: "lfdistal",
    20: "lftip",
}
207 |
208 |
def get_mapping_from_mano_to_shadow(robot: pk.Robot) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Pair Shadow Hand link indices with their MANO keypoint indices.

    Walks `robot.links.names` in order and keeps every link that appears as a
    value in `MANO_TO_SHADOW_MAPPING`.

    Returns:
        Tuple `(shadow_indices, mano_indices)`. Entry k of each describes the
        same correspondence. Both arrays are empty if no link name matches.
    """
    mano_index_for_link = {
        link_name: mano_idx for mano_idx, link_name in MANO_TO_SHADOW_MAPPING.items()
    }
    matches = [
        (link_idx, mano_index_for_link[link_name])
        for link_idx, link_name in enumerate(robot.links.names)
        if link_name in mano_index_for_link
    ]
    shadow_side = [link_idx for link_idx, _ in matches]
    mano_side = [mano_idx for _, mano_idx in matches]
    return jnp.array(shadow_side), jnp.array(mano_side)
221 |
--------------------------------------------------------------------------------
/examples/retarget_helpers/hand/dexycb_motion.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chungmin99/pyroki/f234516fe12c57795f31ca4b4fc9b2e04c6cb233/examples/retarget_helpers/hand/dexycb_motion.pkl
--------------------------------------------------------------------------------
/examples/retarget_helpers/hand/shadowhand_urdf.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chungmin99/pyroki/f234516fe12c57795f31ca4b4fc9b2e04c6cb233/examples/retarget_helpers/hand/shadowhand_urdf.zip
--------------------------------------------------------------------------------
/examples/retarget_helpers/humanoid/heightmap.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chungmin99/pyroki/f234516fe12c57795f31ca4b4fc9b2e04c6cb233/examples/retarget_helpers/humanoid/heightmap.npy
--------------------------------------------------------------------------------
/examples/retarget_helpers/humanoid/left_foot_contact.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chungmin99/pyroki/f234516fe12c57795f31ca4b4fc9b2e04c6cb233/examples/retarget_helpers/humanoid/left_foot_contact.npy
--------------------------------------------------------------------------------
/examples/retarget_helpers/humanoid/right_foot_contact.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chungmin99/pyroki/f234516fe12c57795f31ca4b4fc9b2e04c6cb233/examples/retarget_helpers/humanoid/right_foot_contact.npy
--------------------------------------------------------------------------------
/examples/retarget_helpers/humanoid/smpl_keypoints.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chungmin99/pyroki/f234516fe12c57795f31ca4b4fc9b2e04c6cb233/examples/retarget_helpers/humanoid/smpl_keypoints.npy
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "pyroki"
7 | version = "0.0.0"
8 | description = "Python Robot Kinematics Library"
9 | readme = "README.md"
10 | license = { text="MIT" }
11 | requires-python = ">=3.10"
12 | classifiers = [
13 | "Programming Language :: Python :: 3.10",
14 | "Programming Language :: Python :: 3.11",
15 | "Programming Language :: Python :: 3.12",
16 | "Programming Language :: Python :: 3.13",
17 | "License :: OSI Approved :: MIT License",
18 | "Operating System :: OS Independent"
19 | ]
20 | dependencies = [
21 | "tyro",
22 | "jax>=0.4.0",
23 | "jaxlib",
24 | "jaxlie>=1.0.0",
25 | "jax_dataclasses>=1.0.0",
26 | "jaxtyping",
27 | "loguru",
28 | "robot_descriptions",
29 | "jaxls @ git+https://github.com/brentyi/jaxls.git",
30 | "yourdfpy",
31 | "trimesh",
32 | "viser",
33 | "pyliblzfse", # Needed for viser.extras import in viser==0.2.23
34 | ]
35 |
36 | [project.optional-dependencies]
37 | dev = [
38 | "pyright>=1.1.308",
39 | "scikit-sparse",
40 | "ruff",
41 | "pytest",
42 | "m2r2",
43 | ]
44 |
45 | [tool.ruff.lint]
46 | select = [
47 | "E", # pycodestyle errors.
48 | "F", # Pyflakes rules.
49 | "PLC", # Pylint convention warnings.
50 | "PLE", # Pylint errors.
51 | "PLR", # Pylint refactor recommendations.
52 | "PLW", # Pylint warnings.
53 | ]
54 | ignore = [
55 | "E731", # Do not assign a lambda expression, use a def.
56 | "E741", # Ambiguous variable name. (l, O, or I)
57 | "E501", # Line too long.
58 | "E721", # Do not compare types, use `isinstance()`.
59 | "F722", # Forward annotation false positive from jaxtyping. Should be caught by pyright.
60 | "F821", # Forward annotation false positive from jaxtyping. Should be caught by pyright.
61 | "PLR2004", # Magic value used in comparison.
62 | "PLR0915", # Too many statements.
63 | "PLR0913", # Too many arguments.
64 | "PLC0414", # Import alias does not rename variable. (this is used for exporting names)
65 | "PLC1901", # Use falsey strings.
66 | "PLR5501", # Use `elif` instead of `else if`.
67 | "PLR0911", # Too many return statements.
68 | "PLR0912", # Too many branches.
69 | "PLW0603", # Global statement updates are discouraged.
70 | "PLW2901", # For loop variable overwritten.
71 | ]
72 |
--------------------------------------------------------------------------------
/src/pyroki/__init__.py:
--------------------------------------------------------------------------------
1 | from . import collision as collision
2 | from . import costs as costs
3 | from . import viewer as viewer
4 | from ._robot import Robot as Robot
5 |
# Package version; mirrors `version` in pyproject.toml, so keep the two in sync.
__version__ = "0.0.0"
7 |
--------------------------------------------------------------------------------
/src/pyroki/_robot.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import jax
4 | import jax_dataclasses as jdc
5 | import jaxlie
6 | import jaxls
7 | import yourdfpy
8 | from jax import Array
9 | from jax import numpy as jnp
10 | from jax.typing import ArrayLike
11 | from jaxtyping import Float
12 |
13 | from ._robot_urdf_parser import JointInfo, LinkInfo, RobotURDFParser
14 |
15 |
16 | @jdc.pytree_dataclass
17 | class Robot:
18 | """A differentiable robot kinematics tree."""
19 |
20 | joints: JointInfo
21 | """Joint information for the robot."""
22 |
23 | links: LinkInfo
24 | """Link information for the robot."""
25 |
26 | joint_var_cls: jdc.Static[type[jaxls.Var[Array]]]
27 | """Variable class for the robot configuration."""
28 |
29 | @staticmethod
30 | def from_urdf(
31 | urdf: yourdfpy.URDF,
32 | default_joint_cfg: Float[ArrayLike, "*batch actuated_count"] | None = None,
33 | ) -> Robot:
34 | """
35 | Loads a robot kinematic tree from a URDF.
36 | Internally tracks a topological sort of the joints.
37 |
38 | Args:
39 | urdf: The URDF to load the robot from.
40 | default_joint_cfg: The default joint configuration to use for optimization.
41 | """
42 | joints, links = RobotURDFParser.parse(urdf)
43 |
44 | # Compute default joint configuration.
45 | if default_joint_cfg is None:
46 | default_joint_cfg = (joints.lower_limits + joints.upper_limits) / 2
47 | else:
48 | default_joint_cfg = jnp.array(default_joint_cfg)
49 | assert default_joint_cfg.shape == (joints.num_actuated_joints,)
50 |
51 | # Variable class for the robot configuration.
52 | class JointVar( # pylint: disable=missing-class-docstring
53 | jaxls.Var[Array],
54 | default_factory=lambda: default_joint_cfg,
55 | ): ...
56 |
57 | robot = Robot(
58 | joints=joints,
59 | links=links,
60 | joint_var_cls=JointVar,
61 | )
62 |
63 | return robot
64 |
65 | @jdc.jit
66 | def forward_kinematics(
67 | self,
68 | cfg: Float[Array, "*batch actuated_count"],
69 | unroll_fk: jdc.Static[bool] = False,
70 | ) -> Float[Array, "*batch link_count 7"]:
71 | """Run forward kinematics on the robot's links, in the provided configuration.
72 |
73 | Computes the world pose of each link frame. The result is ordered
74 | corresponding to `self.link.names`.
75 |
76 | Args:
77 | cfg: The configuration of the actuated joints, in the format `(*batch actuated_count)`.
78 |
79 | Returns:
80 | The SE(3) transforms of the links, ordered by `self.link.names`,
81 | in the format `(*batch, link_count, wxyz_xyz)`.
82 | """
83 | batch_axes = cfg.shape[:-1]
84 | assert cfg.shape == (*batch_axes, self.joints.num_actuated_joints)
85 | return self._link_poses_from_joint_poses(
86 | self._forward_kinematics_joints(cfg, unroll_fk)
87 | )
88 |
def _link_poses_from_joint_poses(
    self, Ts_world_joint: Float[Array, "*batch joint_count 7"]
) -> Float[Array, "*batch link_count 7"]:
    """Convert joint-indexed world poses into per-link world poses.

    Each link takes the world pose stored at its parent joint's index
    (`self.links.parent_joint_indices`). Links whose parent joint index is -1
    get the identity pose.

    Args:
        Ts_world_joint: World poses indexed by joint, `(*batch, joint_count,
            wxyz_xyz)`, as produced by `_forward_kinematics_joints`.

    Returns:
        World poses of the links, `(*batch, link_count, wxyz_xyz)`.
    """
    (*batch_axes, _, _) = Ts_world_joint.shape
    parent_joint_indices = self.links.parent_joint_indices
    is_root_link = parent_joint_indices == -1
    # Replace -1 with a valid index so the gather below stays in bounds. The
    # value gathered for root links is discarded by the `where`.
    safe_parent_indices = jnp.where(is_root_link, 0, parent_joint_indices)
    Ts_world_link = jnp.where(
        is_root_link[..., None],
        jaxlie.SE3.identity().wxyz_xyz,
        Ts_world_joint[..., safe_parent_indices, :],
    )
    assert Ts_world_link.shape == (*batch_axes, self.links.num_links, 7)
    return Ts_world_link
106 |
def _forward_kinematics_joints(
    self,
    cfg: Float[Array, "*batch actuated_count"],
    unroll_fk: jdc.Static[bool] = False,
) -> Float[Array, "*batch joint_count 7"]:
    """Compute world poses indexed by joint.

    Entry `j` of the result is the world pose of the link whose parent joint
    is `j` (original joint order).

    Args:
        cfg: Actuated joint values, `(*batch, actuated_count)`.
        unroll_fk: Passed as `unroll` to `jax.lax.fori_loop`.

    Returns:
        `(*batch, joint_count, wxyz_xyz)` world poses.
    """
    *batch_axes, _ = cfg.shape
    assert cfg.shape == (*batch_axes, self.joints.num_actuated_joints)
    num_joints = self.joints.num_joints

    # One value per joint, expanded from the actuated values by
    # `get_full_config` (presumably resolving mimic joints).
    q_all = self.joints.get_full_config(cfg)

    # Joint motion = exp(twist * joint value).
    joint_tangents = self.joints.twists * q_all[..., None]
    assert joint_tangents.shape == (*batch_axes, num_joints, 6)
    T_joint_motion = jaxlie.SE3.exp(joint_tangents)

    # Local transform per joint: fixed parent transform, then joint motion.
    local_poses = (
        jaxlie.SE3(self.joints.parent_transforms) @ T_joint_motion
    ).wxyz_xyz
    assert local_poses.shape == (*batch_axes, num_joints, 7)

    # `sorted_to_orig[k]` is the original index of the k-th joint in
    # processing order. Its argsort maps original indices to sorted slots.
    sorted_to_orig = self.joints._topo_sort_inv
    orig_to_sorted = jnp.argsort(sorted_to_orig)
    local_poses_sorted = local_poses[..., sorted_to_orig, :]
    parents_orig = self.joints.parent_indices[sorted_to_orig]
    # Parent of each sorted joint, as a sorted slot (-1 for no parent).
    parents_sorted = jnp.where(parents_orig == -1, -1, orig_to_sorted[parents_orig])

    identity_pose = jaxlie.SE3.identity().wxyz_xyz

    def _chain_one(k, world_poses_sorted):
        # Reads the parent's slot, so this relies on every parent being
        # processed before its children in the sorted order.
        parent_k = parents_sorted[k]
        T_world_parent = jnp.where(
            parent_k == -1,
            identity_pose,
            world_poses_sorted[..., parent_k, :],
        )
        T_world_k = jaxlie.SE3(T_world_parent) @ jaxlie.SE3(
            local_poses_sorted[..., k, :]
        )
        return world_poses_sorted.at[..., k, :].set(T_world_k.wxyz_xyz)

    world_poses_sorted = jax.lax.fori_loop(
        0,
        num_joints,
        _chain_one,
        jnp.zeros((*batch_axes, num_joints, 7)),
        unroll=unroll_fk,
    )

    # Back to original joint order.
    Ts_world_joint = world_poses_sorted[..., orig_to_sorted, :]
    assert Ts_world_joint.shape == (*batch_axes, num_joints, 7)
    return Ts_world_joint
173 |
--------------------------------------------------------------------------------
/src/pyroki/collision/__init__.py:
--------------------------------------------------------------------------------
1 | """Collision detection primitives and utilities."""
2 |
3 | from ._collision import colldist_from_sdf as colldist_from_sdf
4 | from ._collision import collide as collide
5 | from ._geometry import Capsule as Capsule
6 | from ._geometry import CollGeom as CollGeom
7 | from ._geometry import HalfSpace as HalfSpace
8 | from ._geometry import Heightmap as Heightmap
9 | from ._geometry import Sphere as Sphere
10 | from ._robot_collision import RobotCollision as RobotCollision
11 |
--------------------------------------------------------------------------------
/src/pyroki/collision/_collision.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Callable, Dict, Tuple, cast
4 |
5 | import jax
6 | import jax.numpy as jnp
7 | from jaxtyping import Array, Float
8 |
9 | from ._geometry import Capsule, CollGeom, HalfSpace, Heightmap, Sphere
10 | from ._geometry_pairs import (
11 | capsule_capsule,
12 | halfspace_capsule,
13 | halfspace_sphere,
14 | heightmap_capsule,
15 | heightmap_halfspace,
16 | heightmap_sphere,
17 | sphere_capsule,
18 | sphere_sphere,
19 | )
20 |
# Distance functions keyed by (first geometry type, second geometry type).
# Each pair appears in one ordering only; `_get_coll_func` also tries the
# reversed key.
COLLISION_FUNCTIONS: Dict[
    Tuple[type[CollGeom], type[CollGeom]], Callable[..., Float[Array, "*batch"]]
] = {
    # Halfspace vs. primitives.
    (HalfSpace, Sphere): halfspace_sphere,
    (HalfSpace, Capsule): halfspace_capsule,
    # Sphere / capsule combinations.
    (Sphere, Sphere): sphere_sphere,
    (Sphere, Capsule): sphere_capsule,
    (Capsule, Capsule): capsule_capsule,
    # Heightmap vs. other geometry types.
    (Heightmap, Sphere): heightmap_sphere,
    (Heightmap, Capsule): heightmap_capsule,
    (Heightmap, HalfSpace): heightmap_halfspace,
}
33 |
34 |
def _get_coll_func(
    geom1_cls: type[CollGeom], geom2_cls: type[CollGeom]
) -> Callable[[CollGeom, CollGeom], Float[Array, "*batch"]]:
    """Look up the distance function for a pair of geometry types.

    `(geom1_cls, geom2_cls)` is tried first. If only the reversed key is
    registered, the returned callable swaps its arguments back, so callers can
    always pass `(geom1, geom2)`.

    Raises:
        NotImplementedError: If neither ordering is registered.
    """
    direct_func = COLLISION_FUNCTIONS.get((geom1_cls, geom2_cls))
    if direct_func is not None:
        return cast(Callable[[CollGeom, CollGeom], Float[Array, "*batch"]], direct_func)

    swapped_func = COLLISION_FUNCTIONS.get((geom2_cls, geom1_cls))
    if swapped_func is None:
        raise NotImplementedError(
            f"No collision function found for {geom1_cls.__name__} and {geom2_cls.__name__}"
        )

    def _call_swapped(g1: CollGeom, g2: CollGeom) -> Float[Array, "*batch"]:
        return swapped_func(g2, g1)

    return cast(Callable[[CollGeom, CollGeom], Float[Array, "*batch"]], _call_swapped)
53 |
54 |
def collide(geom1: CollGeom, geom2: CollGeom) -> Float[Array, "*batch"]:
    """Distance between two geometries after broadcasting their batch axes.

    Both geometries are broadcast to a common batch shape. The distance
    function is then chosen from their types.

    Raises:
        ValueError: If the batch shapes cannot be broadcast together.
        NotImplementedError: If no distance function exists for the type pair.
    """
    axes1 = geom1.get_batch_axes()
    axes2 = geom2.get_batch_axes()
    try:
        common_shape = jnp.broadcast_shapes(axes1, axes2)
    except ValueError as err:
        raise ValueError(
            f"Cannot broadcast geometry shapes {axes1} and {axes2}"
        ) from err

    lhs = geom1.broadcast_to(common_shape)
    rhs = geom2.broadcast_to(common_shape)
    dist_fn = _get_coll_func(type(lhs), type(rhs))
    return dist_fn(lhs, rhs)
71 |
72 |
def pairwise_collide(geom1: CollGeom, geom2: CollGeom) -> Float[Array, "*batch N M"]:
    """All-pairs distances between two collections of geometries.

    The last batch axis of each input is the "collection" axis that gets
    paired (size N for `geom1`, M for `geom2`). The remaining leading batch
    axes must broadcast against each other.

    Args:
        geom1: Geometries with batch axes `(*batch1, N)`.
        geom2: Geometries with batch axes `(*batch2, M)`.

    Returns:
        Distances of shape `(*batch_combined, N, M)`, where `batch_combined`
        is the broadcast of `batch1` and `batch2`. Entry `[..., i, j]` is the
        distance between the i-th geometry of `geom1` and the j-th of `geom2`.
    """
    axes1 = geom1.get_batch_axes()
    axes2 = geom2.get_batch_axes()
    assert len(axes1) >= 1, (
        f"geom1 must have at least one batch dimension to map over, got shape {axes1}"
    )
    assert len(axes2) >= 1, (
        f"geom2 must have at least one batch dimension to map over, got shape {axes2}"
    )

    lead1, count1 = axes1[:-1], axes1[-1]
    lead2, count2 = axes2[:-1], axes2[-1]
    try:
        lead = jnp.broadcast_shapes(lead1, lead2)
    except ValueError as e:
        raise ValueError(
            f"Cannot broadcast non-mapped batch shapes {lead1} and {lead2}"
        ) from e
    out_shape = (*lead, count1, count2)

    # geom1 varies along rows (axis -2), geom2 along columns (axis -1).
    rows = (
        geom1.broadcast_to((*lead, count1))
        .reshape((*lead, count1, 1))
        .broadcast_to(out_shape)
    )
    cols = (
        geom2.broadcast_to((*lead, count2))
        .reshape((*lead, 1, count2))
        .broadcast_to(out_shape)
    )
    result = collide(rows, cols)
    assert result.shape == out_shape, (
        f"Output shape mismatch. Expected {out_shape}, got {result.shape}"
    )
    return result
123 |
124 |
125 | def colldist_from_sdf(
126 | _dist: jax.Array,
127 | activation_dist: jax.Array | float,
128 | ) -> jax.Array:
129 | """
130 | Convert a signed distance field to a collision distance field,
131 | based on https://arxiv.org/pdf/2310.17274#page=7.39.
132 |
133 | This function applies a smoothing transformation, useful for converting
134 | raw distances into values suitable for cost functions in optimization.
135 | It returns values <= 0, where 0 corresponds to distances >= activation_dist,
136 | and increasingly negative values for deeper penetrations.
137 |
138 | Args:
139 | _dist: Signed distance field values (positive = separation, negative = penetration).
140 | activation_dist: The distance threshold (margin) below which the cost activates.
141 |
142 | Returns:
143 | Transformed collision distance field values (<= 0).
144 | """
145 | _dist = jnp.minimum(_dist, activation_dist)
146 | _dist = jnp.where(
147 | _dist < 0,
148 | _dist - 0.5 * activation_dist,
149 | -0.5 / (activation_dist + 1e-6) * (_dist - activation_dist) ** 2,
150 | )
151 | _dist = jnp.minimum(_dist, 0.0)
152 | return _dist
153 |
--------------------------------------------------------------------------------
/src/pyroki/collision/_geometry_pairs.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import jax.numpy as jnp
4 | from jaxtyping import Float, Array
5 |
6 | from ._geometry import HalfSpace, Sphere, Capsule, Heightmap
7 | from . import _utils
8 |
9 |
10 | # --- HalfSpace Collision Implementations ---
11 |
12 |
def _halfspace_sphere_dist(
    halfspace_normal: Float[Array, "*batch 3"],
    halfspace_point: Float[Array, "*batch 3"],
    sphere_pos: Float[Array, "*batch 3"],
    sphere_radius: Float[Array, "*batch"],
) -> Float[Array, "*batch"]:
    """Signed distance from a sphere's surface to a halfspace boundary plane.

    The sphere center is measured along `halfspace_normal` relative to
    `halfspace_point`, and the radius is subtracted. This is a true distance
    only if the normal has unit length (assumed, not checked here).
    """
    center_offset = sphere_pos - halfspace_point
    center_height = jnp.einsum("...i,...i->...", center_offset, halfspace_normal)
    return center_height - sphere_radius
25 |
26 |
def halfspace_sphere(halfspace: HalfSpace, sphere: Sphere) -> Float[Array, "*batch"]:
    """Signed distance between a halfspace and a sphere.

    The halfspace's boundary plane passes through its pose translation with
    normal `halfspace.normal`.
    """
    return _halfspace_sphere_dist(
        halfspace_normal=halfspace.normal,
        halfspace_point=halfspace.pose.translation(),
        sphere_pos=sphere.pose.translation(),
        sphere_radius=sphere.radius,
    )
36 |
37 |
def halfspace_capsule(halfspace: HalfSpace, capsule: Capsule) -> Float[Array, "*batch"]:
    """Signed distance between a halfspace and a capsule.

    The plane distance of a point is linear along the capsule's axis segment,
    so its minimum is attained at an endpoint. The result is therefore the
    smaller of the two endpoint-sphere distances.
    """
    plane_normal = halfspace.normal
    plane_point = halfspace.pose.translation()
    center = capsule.pose.translation()
    half_axis = capsule.axis * capsule.height[..., None] / 2

    top = center + half_axis
    bottom = center - half_axis
    return jnp.minimum(
        _halfspace_sphere_dist(plane_normal, plane_point, top, capsule.radius),
        _halfspace_sphere_dist(plane_normal, plane_point, bottom, capsule.radius),
    )
54 |
55 |
56 | # --- Sphere/Capsule Collision Implementations ---
57 |
58 |
def _sphere_sphere_dist(
    pos1: Float[Array, "*batch 3"],
    radius1: Float[Array, "*batch"],
    pos2: Float[Array, "*batch 3"],
    radius2: Float[Array, "*batch"],
) -> Float[Array, "*batch"]:
    """Signed distance between two spheres: center distance minus both radii.

    The center distance is the norm returned as the second output of
    `_utils.normalize_with_norm`.
    """
    center_dist = _utils.normalize_with_norm(pos2 - pos1)[1]
    return center_dist - (radius1 + radius2)
69 |
70 |
def sphere_sphere(sphere1: Sphere, sphere2: Sphere) -> Float[Array, "*batch"]:
    """Signed distance between two spheres (negative when overlapping)."""
    return _sphere_sphere_dist(
        pos1=sphere1.pose.translation(),
        radius1=sphere1.radius,
        pos2=sphere2.pose.translation(),
        radius2=sphere2.radius,
    )
80 |
81 |
def sphere_capsule(sphere: Sphere, capsule: Capsule) -> Float[Array, "*batch"]:
    """Signed distance between a sphere and a capsule.

    Finds the point on the capsule's axis segment closest to the sphere center,
    then treats it as a sphere with the capsule's radius.
    """
    sphere_center = sphere.pose.translation()
    capsule_center = capsule.pose.translation()
    half_axis = capsule.axis * capsule.height[..., None] / 2
    nearest_on_axis = _utils.closest_segment_point(
        capsule_center - half_axis, capsule_center + half_axis, sphere_center
    )
    return _sphere_sphere_dist(
        sphere_center, sphere.radius, nearest_on_axis, capsule.radius
    )
93 |
94 |
def capsule_capsule(capsule1: Capsule, capsule2: Capsule) -> Float[Array, "*batch"]:
    """Signed distance between two capsules.

    Finds the closest pair of points between the two axis segments, then
    treats them as sphere centers with the respective capsule radii.
    """

    def _axis_endpoints(cap: Capsule):
        # Segment of length `height` along `axis`, centered on the pose translation.
        center = cap.pose.translation()
        half_axis = cap.axis * cap.height[..., None] / 2
        return center - half_axis, center + half_axis

    start1, end1 = _axis_endpoints(capsule1)
    start2, end2 = _axis_endpoints(capsule2)
    nearest1, nearest2 = _utils.closest_segment_to_segment_points(
        start1, end1, start2, end2
    )
    return _sphere_sphere_dist(nearest1, capsule1.radius, nearest2, capsule2.radius)
116 |
117 |
118 | # --- Heightmap Collision Implementations ---
119 |
120 |
def heightmap_sphere(heightmap: Heightmap, sphere: Sphere) -> Float[Array, "*batch"]:
    """Approximate signed distance between a heightmap and a sphere.

    Approximation: this is a vertical gap, not a Euclidean distance. It is the
    sphere center's z in the heightmap's local frame, minus the interpolated
    surface height at the center's location (via
    `heightmap._interpolate_height_at_coords`), minus the radius.
    """
    expected_shape = jnp.broadcast_shapes(
        heightmap.get_batch_axes(), sphere.get_batch_axes()
    )

    center_world = sphere.pose.translation()
    surface_z = heightmap._interpolate_height_at_coords(center_world)
    center_z = heightmap.pose.inverse().apply(center_world)[..., 2]
    dist = center_z - surface_z - sphere.radius

    assert dist.shape == expected_shape
    return dist
140 |
141 |
def heightmap_capsule(heightmap: Heightmap, capsule: Capsule) -> Float[Array, "*batch"]:
    """Approximate signed distance between a heightmap and a capsule.

    Only the two end-sphere centers are tested. Each gets the same vertical-gap
    approximation as `heightmap_sphere`, and the smaller value is returned.

    Note that this may miss collisions when the capsule body intersects the
    surface while both endpoints stay above it.
    """
    expected_shape = jnp.broadcast_shapes(
        heightmap.get_batch_axes(), capsule.get_batch_axes()
    )

    cap_center = capsule.pose.translation()
    half_axis = capsule.axis * capsule.height[..., None] / 2
    world_to_local = heightmap.pose.inverse()

    def _endpoint_gap(center_w):
        # Local-frame z of the end-sphere center minus the surface height below it.
        surface_z = heightmap._interpolate_height_at_coords(center_w)
        center_z = world_to_local.apply(center_w)[..., 2]
        return center_z - surface_z - capsule.radius

    min_dist = jnp.minimum(
        _endpoint_gap(cap_center + half_axis),
        _endpoint_gap(cap_center - half_axis),
    )
    assert min_dist.shape == expected_shape
    return min_dist
179 |
180 |
def heightmap_halfspace(
    heightmap: Heightmap, halfspace: HalfSpace
) -> Float[Array, "*batch"]:
    """Approximate signed distance between a heightmap and a halfspace.

    Approximation: the minimum signed distance from any heightmap vertex to the
    halfspace's boundary plane. Only vertices are considered, not the surface
    between them.
    """
    batch_axes = jnp.broadcast_shapes(
        heightmap.get_batch_axes(), halfspace.get_batch_axes()
    )

    # Heightmap vertices in the world frame, broadcast to the common batch shape.
    verts_local = heightmap._get_vertices_local()  # (*batch, N, 3), N = H * W
    verts_world = heightmap.pose.apply(verts_local)
    verts_world = jnp.broadcast_to(verts_world, batch_axes + verts_world.shape[-2:])

    # Plane normal and point (world frame), broadcast to `(*batch, 3)`.
    hs_normal_w = jnp.broadcast_to(halfspace.normal, batch_axes + (3,))
    hs_point_w = jnp.broadcast_to(halfspace.pose.translation(), batch_axes + (3,))

    # Signed distance of each vertex to the plane. The plane point gets a vertex
    # axis so it broadcasts against `(*batch, N, 3)`.
    vertex_distances = jnp.einsum(
        "...vi,...i->...v", verts_world - hs_point_w[..., None, :], hs_normal_w
    )

    min_dist = jnp.min(vertex_distances, axis=-1)
    assert min_dist.shape == batch_axes
    return min_dist
219 |
--------------------------------------------------------------------------------
/src/pyroki/collision/_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import jax
4 | import jax.numpy as jnp
5 | from typing import Tuple
6 | from jaxtyping import Float, Array
7 |
# Small constant that keeps divisions and normalizations finite.
_SAFE_EPS = 1e-6


def make_frame(direction: jax.Array) -> jax.Array:
    """Build a 3x3 frame whose third column (z-axis) points along `direction`.

    Based on `mujoco.mjx._src.math.make_frame`.

    - If all components of `direction` are `jnp.isclose` to zero, +x is used.
    - The y-axis is Gram-Schmidt-orthogonalized from world +y when the
      normalized direction's y-component lies strictly inside (-0.5, 0.5),
      and from world +z otherwise.
    - Columns are (y cross z, y, z), giving an approximately orthonormal,
      right-handed frame.
    """
    degenerate = jnp.isclose(direction, 0.0).all(axis=-1, keepdims=True)
    fallback = jnp.broadcast_to(jnp.array([1.0, 0.0, 0.0]), direction.shape)
    z_axis = jnp.where(degenerate, fallback, direction)
    z_axis = z_axis / (jnp.linalg.norm(z_axis, axis=-1, keepdims=True) + _SAFE_EPS)

    lead_shape = z_axis.shape[:-1]
    world_y = jnp.broadcast_to(jnp.array([0, 1, 0]), (*lead_shape, 3))
    world_z = jnp.broadcast_to(jnp.array([0, 0, 1]), (*lead_shape, 3))
    y_component = z_axis[..., 1:2]
    prefer_world_y = (-0.5 < y_component) & (y_component < 0.5)

    y_axis = jnp.where(prefer_world_y, world_y, world_z)
    y_axis = y_axis - z_axis * jnp.einsum("...i,...i->...", y_axis, z_axis)[..., None]
    y_axis = y_axis / (jnp.linalg.norm(y_axis, axis=-1, keepdims=True) + _SAFE_EPS)

    x_axis = jnp.cross(y_axis, z_axis)
    return jnp.stack([x_axis, y_axis, z_axis], axis=-1)
31 |
32 |
def normalize(x: Float[Array, "*batch N"]) -> Float[Array, "*batch N"]:
    """Normalize vectors along the last axis, mapping zero vectors to zero.

    The norm is taken from a "safe" squared norm, with 1.0 substituted for
    exact zeros. This way, differentiating through a zero vector gives a
    finite (zero) gradient. Differentiating `jnp.linalg.norm` at the origin
    yields NaN, and a later `jnp.where` does not stop that NaN from
    propagating into the gradient.

    Args:
        x: Array of vectors; normalization is along the last axis.

    Returns:
        `x / ||x||` for non-zero vectors, and zeros for zero vectors.
    """
    sq_norm = jnp.sum(x * x, axis=-1, keepdims=True)
    is_zero = sq_norm == 0.0
    safe_norm = jnp.sqrt(jnp.where(is_zero, 1.0, sq_norm))
    return jnp.where(is_zero, jnp.zeros_like(x), x / safe_norm)
39 |
40 |
def normalize_with_norm(
    x: Float[Array, "*batch N"],
) -> Tuple[Float[Array, "*batch N"], Float[Array, "*batch"]]:
    """Normalize vectors along the last axis and also return their norms.

    NOTE(review): the norm is computed for `x + 1e-6`, not `x`. Both outputs
    are therefore offset by O(1e-6), depending on direction. Presumably this
    keeps the norm and its gradient finite when `x` is the zero vector. The
    `== 0.0` guards below trigger only if `x + 1e-6` is itself exactly zero.

    Returns:
        `(x / norm, norm)`, where `norm` has the last axis removed.
    """
    shifted_norm = jnp.linalg.norm(x + 1e-6, axis=-1, keepdims=True)
    is_zero = shifted_norm == 0.0
    divisor = jnp.where(is_zero, 1.0, shifted_norm)
    unit = jnp.where(is_zero, jnp.zeros_like(x), x / divisor)
    return unit, jnp.squeeze(shifted_norm, axis=-1)
51 |
52 |
def closest_segment_point(
    a: Float[Array, "*batch 3"],
    b: Float[Array, "*batch 3"],
    pt: Float[Array, "*batch 3"],
) -> Float[Array, "*batch 3"]:
    """Return the point on segment [a, b] closest to `pt`.

    `pt` is projected onto the line through `a` and `b`, and the line parameter
    is clamped to [0, 1]. `_SAFE_EPS` in the denominator makes a degenerate
    segment (a == b) resolve to `a`.
    """
    seg = b - a
    along = jnp.einsum("...i,...i->...", pt - a, seg)
    seg_len_sq = jnp.einsum("...i,...i->...", seg, seg)
    frac = jnp.clip(along / (seg_len_sq + _SAFE_EPS), 0.0, 1.0)
    return a + seg * frac[..., None]
65 |
66 |
def closest_segment_to_segment_points(
    a1: Float[Array, "*batch 3"],
    b1: Float[Array, "*batch 3"],
    a2: Float[Array, "*batch 3"],
    b2: Float[Array, "*batch 3"],
) -> Tuple[Float[Array, "*batch 3"], Float[Array, "*batch 3"]]:
    """Finds the closest points between two line segments [a1, b1] and [a2, b2].

    Vectorized (via `jnp.where`) version of the clamped closest-point scheme
    from Ericson, "Real-Time Collision Detection", Sec. 5.1.9:

    1. A segment whose squared length is <= `_SAFE_EPS` is treated as the
       point at its start.
    2. Otherwise, s (the parameter on segment 1) is the clamped
       infinite-line solution, or 0 when the segments are (near-)parallel.
    3. t (the parameter on segment 2) is the closest point on line 2 to the
       point at s. If t falls outside [0, 1], it is clamped and s is
       recomputed from the clamped t.

    Every returned (s, t) pair is mutually consistent. This matters for
    parallel segments: pairing independent projections of each start point
    onto the other segment can overestimate the distance badly, e.g. for
    overlapping anti-parallel segments.

    Returns:
        `(c1, c2)`: the closest point on segment 1 and on segment 2.
    """
    d1 = b1 - a1  # Direction vector of segment 1.
    d2 = b2 - a2  # Direction vector of segment 2.
    r = a1 - a2

    a = jnp.einsum("...i,...i->...", d1, d1)  # Squared length of segment 1.
    e = jnp.einsum("...i,...i->...", d2, d2)  # Squared length of segment 2.
    b = jnp.einsum("...i,...i->...", d1, d2)
    c = jnp.einsum("...i,...i->...", d1, r)
    f = jnp.einsum("...i,...i->...", d2, r)
    # Mathematically a * e * sin^2(angle between the segments), so >= 0.
    denom = a * e - b * b

    seg1_is_point = a <= _SAFE_EPS
    seg2_is_point = e <= _SAFE_EPS
    # Relative test (sin^2 of the angle vs. eps) so the parallel threshold does
    # not depend on how long the segments are.
    non_parallel = denom > _SAFE_EPS * a * e

    # Substitute 1.0 in denominators of branches that `where` discards, so the
    # unused branches produce no inf/NaN values or gradients.
    safe_a = jnp.where(seg1_is_point, 1.0, a)
    safe_e = jnp.where(seg2_is_point, 1.0, e)
    safe_denom = jnp.where(non_parallel, denom, 1.0)

    # General case: both segments have non-negligible length.
    s_gen = jnp.where(
        non_parallel, jnp.clip((b * f - c * e) / safe_denom, 0.0, 1.0), 0.0
    )
    t_gen_raw = (b * s_gen + f) / safe_e
    t_gen = jnp.clip(t_gen_raw, 0.0, 1.0)
    # If t had to be clamped, move s to the point closest to the clamped t.
    s_gen = jnp.where(
        t_gen_raw != t_gen,
        jnp.clip((b * t_gen - c) / safe_a, 0.0, 1.0),
        s_gen,
    )

    # Segment 2 is a point: project a2 onto segment 1.
    s_seg2_point = jnp.clip(-c / safe_a, 0.0, 1.0)
    # Segment 1 is a point: project a1 onto segment 2.
    t_seg1_point = jnp.clip(f / safe_e, 0.0, 1.0)

    s_final = jnp.where(
        seg1_is_point, 0.0, jnp.where(seg2_is_point, s_seg2_point, s_gen)
    )
    t_final = jnp.where(
        seg2_is_point, 0.0, jnp.where(seg1_is_point, t_seg1_point, t_gen)
    )

    c1 = a1 + d1 * s_final[..., None]
    c2 = a2 + d2 * t_final[..., None]
    return c1, c2
114 |
--------------------------------------------------------------------------------
/src/pyroki/costs/__init__.py:
--------------------------------------------------------------------------------
1 | from ._costs import five_point_acceleration_cost as five_point_acceleration_cost
2 | from ._costs import five_point_jerk_cost as five_point_jerk_cost
3 | from ._costs import five_point_velocity_cost as five_point_velocity_cost
4 | from ._costs import limit_cost as limit_cost
5 | from ._costs import limit_velocity_cost as limit_velocity_cost
6 | from ._costs import manipulability_cost as manipulability_cost
7 | from ._costs import pose_cost as pose_cost
8 | from ._costs import pose_cost_with_base as pose_cost_with_base
9 | from ._costs import rest_cost as rest_cost
10 | from ._costs import rest_with_base_cost as rest_with_base_cost
11 | from ._costs import self_collision_cost as self_collision_cost
12 | from ._costs import smoothness_cost as smoothness_cost
13 | from ._costs import world_collision_cost as world_collision_cost
14 | from ._pose_cost_analytic_jac import pose_cost_analytic_jac as pose_cost_analytic_jac
15 | from ._pose_cost_numerical_jac import pose_cost_numerical_jac as pose_cost_numerical_jac
16 |
--------------------------------------------------------------------------------
/src/pyroki/costs/_costs.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import jaxlie
4 | from jax import Array
5 | from jaxls import Cost, Var, VarValues
6 |
7 | from .._robot import Robot
8 | from ..collision import CollGeom, RobotCollision, colldist_from_sdf
9 |
10 |
@Cost.create_factory
def pose_cost(
    vals: VarValues,
    robot: Robot,
    joint_var: Var[Array],
    target_pose: jaxlie.SE3,
    target_link_index: Array,
    pos_weight: Array | float,
    ori_weight: Array | float,
) -> Array:
    """Residual between a link's forward-kinematics pose and a target pose.

    The error is the SE(3) log of `T_actual^-1 @ T_target`. Its first three
    components are scaled by `pos_weight` and its last three by `ori_weight`.
    The two parts are concatenated along the leading axis and flattened.
    `target_link_index` must be int32 (asserted).
    """
    assert target_link_index.dtype == jnp.int32
    Ts_world_link = robot.forward_kinematics(vals[joint_var])
    T_world_actual = jaxlie.SE3(Ts_world_link[..., target_link_index, :])
    tangent_err = (T_world_actual.inverse() @ target_pose).log()
    weighted_pos = tangent_err[..., :3] * pos_weight
    weighted_ori = tangent_err[..., 3:] * ori_weight
    return jnp.concatenate([weighted_pos, weighted_ori]).flatten()
30 |
31 |
@Cost.create_factory
def pose_cost_with_base(
    vals: VarValues,
    robot: Robot,
    joint_var: Var[Array],
    T_world_base_var: Var[jaxlie.SE3],
    target_pose: jaxlie.SE3,
    target_link_indices: Array,
    pos_weight: Array | float,
    ori_weight: Array | float,
) -> Array:
    """Residual between target pose(s) and link pose(s) of a robot on a mobile base.

    Forward kinematics gives link poses in the base frame. These are lifted to
    the world frame with the optimized base pose, then compared to
    `target_pose` via the SE(3) log of `T_actual^-1 @ T_target`. The first
    three components are weighted by `pos_weight`, the last three by
    `ori_weight`. `target_link_indices` must be int32 (asserted).
    """
    assert target_link_indices.dtype == jnp.int32
    Ts_base_link = robot.forward_kinematics(vals[joint_var])
    T_world_base = vals[T_world_base_var]
    T_base_target = jaxlie.SE3(Ts_base_link[..., target_link_indices, :])
    T_world_actual = T_world_base @ T_base_target

    tangent_err = (T_world_actual.inverse() @ target_pose).log()
    weighted_pos = tangent_err[..., :3] * pos_weight
    weighted_ori = tangent_err[..., 3:] * ori_weight
    return jnp.concatenate([weighted_pos, weighted_ori]).flatten()
55 |
56 |
57 | # --- Limit Costs ---
58 |
59 |
@Cost.create_factory
def limit_cost(
    vals: VarValues,
    robot: Robot,
    joint_var: Var[Array],
    weight: Array | float,
) -> Array:
    """Residual for joint-limit violations.

    Applies to the full joint configuration (`robot.joints.get_full_config`).
    Each entry is how far the joint lies above its upper limit or below its
    lower limit (0 when within limits), scaled by `weight`.
    """
    q_full = robot.joints.get_full_config(vals[joint_var])
    above_upper = jnp.maximum(0.0, q_full - robot.joints.upper_limits_all)
    below_lower = jnp.maximum(0.0, robot.joints.lower_limits_all - q_full)
    return ((above_upper + below_lower) * weight).flatten()
73 |
74 |
@Cost.create_factory
def limit_velocity_cost(
    vals: VarValues,
    robot: Robot,
    joint_var: Var[Array],
    prev_joint_var: Var[Array],
    dt: float,
    weight: Array | float,
) -> Array:
    """Residual for joint-velocity limit violations.

    Velocity is the backward difference `(q - q_prev) / dt`. Each entry is the
    amount by which |velocity| exceeds `robot.joints.velocity_limits`, scaled
    by `weight`.
    """
    q_dot = (vals[joint_var] - vals[prev_joint_var]) / dt
    excess = jnp.maximum(0.0, jnp.abs(q_dot) - robot.joints.velocity_limits)
    return (excess * weight).flatten()
88 |
89 |
90 | # --- Regularization Costs ---
91 |
92 |
@Cost.create_factory
def rest_cost(
    vals: VarValues,
    joint_var: Var[Array],
    rest_pose: Array,
    weight: Array | float,
) -> Array:
    """Weighted deviation of the joint configuration from `rest_pose`."""
    deviation = vals[joint_var] - rest_pose
    return (deviation * weight).flatten()
102 |
103 |
@Cost.create_factory
def rest_with_base_cost(
    vals: VarValues,
    joint_var: Var[Array],
    T_world_base_var: Var[jaxlie.SE3],
    rest_pose: Array,
    weight: Array | float,
) -> Array:
    """Weighted deviation of joints from `rest_pose` and of the base from identity.

    The base term is the SE(3) log of the base pose, which is zero at identity.
    It is concatenated after the joint deviations before weighting.
    """
    joint_dev = vals[joint_var] - rest_pose
    base_dev = vals[T_world_base_var].log()
    stacked = jnp.concatenate([joint_dev, base_dev])
    return (stacked * weight).flatten()
116 |
117 |
@Cost.create_factory
def smoothness_cost(
    vals: VarValues,
    curr_joint_var: Var[Array],
    past_joint_var: Var[Array],
    weight: Array | float,
) -> Array:
    """Weighted difference between consecutive joint configurations."""
    step = vals[curr_joint_var] - vals[past_joint_var]
    return (step * weight).flatten()
127 |
128 |
129 | # --- Manipulability Cost ---
130 |
131 |
def _compute_manip_yoshikawa(
    cfg: Array,
    robot: Robot,
    target_link_index: jax.Array,
) -> Array:
    """Yoshikawa manipulability sqrt(det(J J^T)) of one link's position.

    J is the (3, dof) forward-mode Jacobian of the link's world translation
    with respect to `cfg`. J J^T is positive semi-definite, so a negative
    determinant can only come from round-off; it is clamped to 0 before the
    square root.
    """

    def _all_link_positions(q: Array) -> Array:
        return jaxlie.SE3(robot.forward_kinematics(q)).translation()

    J = jax.jacfwd(_all_link_positions)(cfg)[target_link_index]
    JJT = J @ J.T
    assert JJT.shape == (3, 3)
    return jnp.sqrt(jnp.maximum(0.0, jnp.linalg.det(JJT)))
144 |
145 |
@Cost.create_factory
def manipulability_cost(
    vals: VarValues,
    robot: Robot,
    joint_var: Var[Array],
    target_link_indices: Array,
    weight: Array | float,
) -> Array:
    """Residual that grows as translational manipulability shrinks.

    Accepts a scalar link index or a 1-D array of indices (the latter is
    vmapped). The residual is `1 / (manipulability + 1e-6)`, scaled by
    `weight`.
    """
    cfg = vals[joint_var]
    manip = (
        _compute_manip_yoshikawa(cfg, robot, target_link_indices)
        if target_link_indices.ndim == 0
        else jax.vmap(_compute_manip_yoshikawa, in_axes=(None, None, 0))(
            cfg, robot, target_link_indices
        )
    )
    inverse_manip = 1.0 / (manip + 1e-6)
    return (inverse_manip * weight).flatten()
164 |
165 |
166 | # --- Collision Costs ---
167 |
168 |
@Cost.create_factory
def self_collision_cost(
    vals: VarValues,
    robot: Robot,
    robot_coll: RobotCollision,
    joint_var: Var[Array],
    margin: float,
    weight: Array | float,
) -> Array:
    """Residual penalizing self-collision distances below `margin`.

    Distances from `robot_coll.compute_self_collision_distance` are shaped
    by `colldist_from_sdf` (zero at or beyond `margin`), then weighted.
    """
    distances = robot_coll.compute_self_collision_distance(robot, vals[joint_var])
    return (colldist_from_sdf(distances, margin) * weight).flatten()
183 |
184 |
@Cost.create_factory
def world_collision_cost(
    vals: VarValues,
    robot: Robot,
    robot_coll: RobotCollision,
    joint_var: Var[Array],
    world_geom: CollGeom,
    margin: float,
    weight: Array | float,
) -> Array:
    """Residual penalizing robot-world distances below `margin`.

    Distances from `robot_coll.compute_world_collision_distance` against
    `world_geom` are shaped by `colldist_from_sdf`, then weighted.
    """
    distances = robot_coll.compute_world_collision_distance(
        robot, vals[joint_var], world_geom
    )
    return (colldist_from_sdf(distances, margin) * weight).flatten()
200 |
201 |
202 | # --- Finite Difference Costs (Velocity, Acceleration, Jerk) ---
203 |
204 |
@Cost.create_factory
def five_point_velocity_cost(
    vals: VarValues,
    robot: Robot,  # Needed for limits
    var_t_plus_2: Var[Array],
    var_t_plus_1: Var[Array],
    var_t_minus_1: Var[Array],
    var_t_minus_2: Var[Array],
    dt: float,
    weight: Array | float,
) -> Array:
    """Residual for velocity-limit violations, with velocity from a 5-point stencil.

    v(t) ~= (-q[t+2] + 8 q[t+1] - 8 q[t-1] + q[t-2]) / (12 dt). The center
    sample has zero weight. Each entry is the amount by which |v| exceeds
    `robot.joints.velocity_limits`, scaled by `weight`.
    """
    q_m2, q_m1 = vals[var_t_minus_2], vals[var_t_minus_1]
    q_p1, q_p2 = vals[var_t_plus_1], vals[var_t_plus_2]

    vel = (-q_p2 + 8 * q_p1 - 8 * q_m1 + q_m2) / (12 * dt)
    excess = jnp.maximum(0.0, jnp.abs(vel) - robot.joints.velocity_limits)
    return (excess * weight).flatten()
226 |
227 |
@Cost.create_factory
def five_point_acceleration_cost(
    vals: VarValues,
    var_t: Var[Array],
    var_t_plus_2: Var[Array],
    var_t_plus_1: Var[Array],
    var_t_minus_1: Var[Array],
    var_t_minus_2: Var[Array],
    dt: float,
    weight: Array | float,
) -> Array:
    """Weighted joint acceleration from a 5-point central stencil.

    a(t) ~= (-q[t+2] + 16 q[t+1] - 30 q[t] + 16 q[t-1] - q[t-2]) / (12 dt^2).
    """
    q_m2, q_m1 = vals[var_t_minus_2], vals[var_t_minus_1]
    q_0 = vals[var_t]
    q_p1, q_p2 = vals[var_t_plus_1], vals[var_t_plus_2]

    accel = (-q_p2 + 16 * q_p1 - 30 * q_0 + 16 * q_m1 - q_m2) / (12 * dt**2)
    return (accel * weight).flatten()
248 |
249 |
@Cost.create_factory
def five_point_jerk_cost(
    vals: VarValues,
    var_t_plus_3: Var[Array],
    var_t_plus_2: Var[Array],
    var_t_plus_1: Var[Array],
    var_t_minus_1: Var[Array],
    var_t_minus_2: Var[Array],
    var_t_minus_3: Var[Array],
    dt: float,
    weight: Array | float,
) -> Array:
    """Residual that drives joint jerk at timestep t toward zero.

    Despite the `five_point_` prefix, the estimate uses a seven-point central
    difference over t-3..t+3. The center sample has zero weight, so it is not
    an input:

        j_t = (-q[t+3] + 8 q[t+2] - 13 q[t+1] + 13 q[t-1] - 8 q[t-2] + q[t-3])
              / (8 dt^3)

    Returns:
        The weighted jerk estimate, flattened.
    """
    q_m3, q_m2, q_m1, q_p1, q_p2, q_p3 = (
        vals[var]
        for var in (
            var_t_minus_3,
            var_t_minus_2,
            var_t_minus_1,
            var_t_plus_1,
            var_t_plus_2,
            var_t_plus_3,
        )
    )

    stencil_sum = -q_p3 + 8 * q_p2 - 13 * q_p1 + 13 * q_m1 - 8 * q_m2 + q_m3
    jerk = stencil_sum / (8 * dt**3)
    return (jerk * weight).flatten()
274 |
--------------------------------------------------------------------------------
/src/pyroki/costs/_pose_cost_analytic_jac.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import jax
4 | import jax.numpy as jnp
5 | import jaxlie
6 | import jaxls
7 |
8 | from .._robot import Robot
9 |
10 |
def _get_actuated_joints_applied_to_target(
    robot: Robot,
    target_joint_idx: jax.Array,
) -> jax.Array:
    """Map every joint to the actuated joint it is credited to for one target.

    Starting at `target_joint_idx`, the kinematic chain is walked through
    `robot.joints.parent_indices` until a negative parent index is reached.
    For each joint `i`, the returned value is:

    1) -1 if joint `i` is not on the path from the target joint to the root.
    2) otherwise:
        - `robot.joints.actuated_indices[i]` if joint `i` is actuated
          (i.e. that entry is not -1). This is an index into the actuated
          joints, not necessarily `i` itself.
        - else `robot.joints.mimic_act_indices[i]`: the actuated joint that
          `i` mimics, or -1 if `i` is not a mimic joint.

    Args:
        robot: Robot model for a single (unbatched) problem. Callers vmap this
            function once per batch axis.
        target_joint_idx: Scalar integer index of the joint to trace from.

    Returns:
        int32 array of shape `(num_joints,)`.
    """
    assert target_joint_idx.shape == ()

    def body_fun(joint_idx, indices):
        # Find the corresponding active actuated joint.
        active_act_joint = jnp.where(
            robot.joints.actuated_indices[joint_idx] != -1,
            # The current joint is actuated.
            robot.joints.actuated_indices[joint_idx],
            # The current joint is not actuated; this is -1 if not a mimic joint.
            robot.joints.mimic_act_indices[joint_idx],
        )

        # Find the parent of the current joint.
        parent_joint = robot.joints.parent_indices[joint_idx]

        # Continue traversing up the kinematic tree, using the parent joint.
        # This value may either go up or down, since there's no guarantee that
        # the kinematic tree is topologically sorted.
        next_indices = indices.at[joint_idx].set(active_act_joint)
        return (parent_joint, next_indices)

    idx_applied_to_target = jnp.full(
        (robot.joints.num_joints,),
        fill_value=-1,
        dtype=jnp.int32,
    )
    # Loop until the walk steps past the root (negative parent index).
    idx_applied_to_target = jax.lax.while_loop(
        lambda carry: jnp.any(carry[0] >= 0),
        lambda carry: body_fun(*carry),
        (target_joint_idx, idx_applied_to_target),
    )[-1]
    return idx_applied_to_target
58 |
59 |
# Cache returned alongside the residual and handed to the custom Jacobian:
# (Ts_world_joint, Ts_world_link, pose_error), in that order.
_PoseCostJacCache = tuple[jax.Array, jax.Array, jaxlie.SE3]
61 |
62 |
def pose_cost_analytic_jac(
    robot: Robot,
    joint_var: jaxls.Var[jax.Array],
    target_pose: jaxlie.SE3,
    target_link_index: jax.Array,
    pos_weight: jax.Array | float,
    ori_weight: jax.Array | float,
) -> jaxls.Cost:
    """Create a link-pose cost whose Jacobian is computed analytically.

    The residual is `log(target_pose^-1 @ T_world_link)` scaled by
    `[pos_weight] * 3 + [ori_weight] * 3` (see `_pose_cost_analytical_jac`).
    This wrapper precomputes, for every joint, which actuated joint its motion
    is credited to when differentiating the target link's pose. That mapping is
    passed through to the custom Jacobian.

    Args:
        robot: Robot model. Must have as many leading batch axes as the other
            batched inputs (size-1 axes are broadcast).
        joint_var: Joint configuration variable(s).
        target_pose: Desired world pose(s) of the target link.
        target_link_index: int32 index (or indices) of the target link.
        pos_weight: Weight on the three translational residual terms.
        ori_weight: Weight on the three rotational residual terms.

    Returns:
        The constructed `jaxls.Cost`.
    """
    # We only check shape lengths because there might be (1,) axes for
    # broadcasting reasons.
    assert (
        len(target_link_index.shape)
        == len(jnp.asarray(joint_var.id).shape)
        == len(robot.joints.twists.shape[:-2])
    ), "Batch axes of inputs should match"

    # Broadcast the inputs for _get_actuated_joints_applied_to_target().
    # The weights are not broadcast here.
    batch_axes = jnp.broadcast_shapes(
        target_pose.get_batch_axes(),
        jnp.asarray(joint_var.id).shape,
        target_link_index.shape,
    )
    broadcast_batch_axes = partial(
        jax.tree.map,
        lambda x: jnp.broadcast_to(x, batch_axes + x.shape[len(batch_axes) :]),
    )

    # The helper handles a single target, so vmap once per batch axis.
    get_actuated_joints = _get_actuated_joints_applied_to_target
    for _ in range(len(batch_axes)):
        get_actuated_joints = jax.vmap(get_actuated_joints)

    # Compute applied joints.
    robot = broadcast_batch_axes(robot)

    # Links without a parent joint (-1) are mapped to joint 0 so the gather
    # below stays in bounds.
    # NOTE(review): a base-link target therefore traces joint 0's chain;
    # confirm this is the intended behavior.
    parent_joint_indices = jnp.where(
        robot.links.parent_joint_indices == -1,
        0,
        robot.links.parent_joint_indices,
    )

    # Gather each target link's parent joint along the last (link) axis.
    # `take_along_axis` works for any number of batch axes. An `arange`-based
    # fancy index only broadcasts correctly for zero or one batch axis.
    target_joint_idx = jnp.take_along_axis(
        parent_joint_indices,
        jnp.broadcast_to(target_link_index, batch_axes)[..., None],
        axis=-1,
    )[..., 0]
    joints_applied_to_target = get_actuated_joints(robot, target_joint_idx)

    return _pose_cost_analytical_jac(
        robot,
        joint_var,
        target_pose,
        target_link_index,
        pos_weight,
        ori_weight,
        joints_applied_to_target,
    )
118 |
119 |
# Arguments are passed in explicitly rather than captured by closure in the
# `pose_cost_analytic_jac` wrapper; this helps jaxls vectorize repeated costs.
def _pose_cost_jac(
    vals: jaxls.VarValues,
    jac_cache: _PoseCostJacCache,
    robot: Robot,
    joint_var: jaxls.Var[jax.Array],
    target_pose: jaxlie.SE3,
    target_link_index: jax.Array,
    pos_weight: jax.Array | float,
    ori_weight: jax.Array | float,
    joints_applied_to_target: jax.Array,
) -> jax.Array:
    """Analytic Jacobian of the weighted pose residual.

    Per joint, the twist (scaled by `mimic_multiplier`) is rotated into the
    world frame. Its linear part gains `omega x (p_ee - p_joint)`, and the
    result is expressed in the end-effector frame. Joints not on the target's
    chain are zeroed. The 6 x num_joints matrix is left-multiplied by
    `pose_error.jlog()`, scatter-added into actuated-joint columns (mimic
    joints accumulate onto the joint they follow), and row-scaled by the
    residual weights.

    Returns:
        Array of shape (6, num_actuated_joints).
    """
    del vals, joint_var, target_pose  # Unused!
    joint_poses_raw, link_poses_raw, pose_error = jac_cache

    T_world_ee = jaxlie.SE3(link_poses_raw[target_link_index])
    T_world_joints = jaxlie.SE3(joint_poses_raw)
    R_world_joints = T_world_joints.rotation()
    R_ee_world = T_world_ee.rotation().inverse()

    # Twists are scaled for mimic joints; layout is (linear, angular).
    twists = robot.joints.twists * robot.joints.mimic_multiplier[..., None]
    lin_local = twists[:, :3]
    ang_local = twists[:, 3:]

    ang_world = R_world_joints @ ang_local
    lever_arm = T_world_ee.translation() - T_world_joints.translation()
    # v = omega x r + v_joint, all in world coordinates.
    lin_world = jnp.cross(ang_world, lever_arm) + R_world_joints @ lin_local

    on_path = joints_applied_to_target != -1
    per_joint_cols = jnp.concatenate(
        [R_ee_world @ lin_world, R_ee_world @ ang_world],
        axis=1,
    )
    jac_all_joints = jnp.where(on_path[:, None], per_joint_cols, 0.0).T
    jac_all_joints = pose_error.jlog() @ jac_all_joints

    # Several joints (mimics) may map to one actuated joint; contributions sum.
    jac_actuated = (
        jnp.zeros((6, robot.joints.num_actuated_joints))
        .at[:, joints_applied_to_target]
        .add(on_path[None, :] * jac_all_joints)
    )

    row_weights = jnp.array([pos_weight] * 3 + [ori_weight] * 3)
    return jac_actuated * row_weights[:, None]
191 |
192 |
@jaxls.Cost.create_factory(jac_custom_with_cache_fn=_pose_cost_jac)
def _pose_cost_analytical_jac(
    vals: jaxls.VarValues,
    robot: Robot,
    joint_var: jaxls.Var[jax.Array],
    target_pose: jaxlie.SE3,
    target_link_index: jax.Array,
    pos_weight: jax.Array | float,
    ori_weight: jax.Array | float,
    joints_applied_to_target: jax.Array,
) -> tuple[jax.Array, _PoseCostJacCache]:
    """Residual between a link's world pose and its target pose.

    Returns:
        A pair `(residual, cache)`:
        - `residual` is `log(target_pose^-1 @ T_world_link)` scaled by
          `[pos_weight] * 3 + [ori_weight] * 3`.
        - `cache` is `(Ts_world_joint, Ts_world_link, pose_error)`, which is
          forwarded to the custom Jacobian function.
    """
    del joints_applied_to_target  # Only consumed by the custom Jacobian.
    assert target_link_index.dtype == jnp.int32

    joint_poses = robot._forward_kinematics_joints(vals[joint_var])
    link_poses = robot._link_poses_from_joint_poses(joint_poses)

    T_world_ee = jaxlie.SE3(link_poses[target_link_index, :])
    pose_error = target_pose.inverse() @ T_world_ee

    residual_weights = jnp.array([pos_weight] * 3 + [ori_weight] * 3)
    residual = pose_error.log() * residual_weights
    return residual, (joint_poses, link_poses, pose_error)
219 |
--------------------------------------------------------------------------------
/src/pyroki/costs/_pose_cost_numerical_jac.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | import jaxlie
4 | import jaxls
5 |
6 | from .._robot import Robot
7 |
# Cache returned alongside the residual and handed to the custom Jacobian:
# (Ts_world_joint, Ts_world_link, pose_error). Only pose_error is read by
# the numerical Jacobian.
_PoseCostJacCache = tuple[jax.Array, jax.Array, jaxlie.SE3]
9 |
10 |
def _pose_cost_jac(
    vals: jaxls.VarValues,
    jac_cache: _PoseCostJacCache,
    robot: Robot,
    joint_var: jaxls.Var[jax.Array],
    target_pose: jaxlie.SE3,
    target_link_index: jax.Array,
    pos_weight: jax.Array | float,
    ori_weight: jax.Array | float,
    eps: float = 1e-4,
) -> jax.Array:
    """Forward-difference Jacobian of the weighted pose residual.

    Each joint coordinate is bumped by `eps` in turn and forward kinematics is
    re-run. The resulting change in `log(target_pose^-1 @ T_world_ee)`,
    divided by `eps`, forms one column. Rows are then scaled by the same
    weights as the residual.

    Returns:
        Array of shape (6, num_actuated_joints).
    """
    q = vals[joint_var]
    pose_error = jac_cache[2]
    nominal_log = pose_error.log()
    T_target_inv = target_pose.inverse()

    def column(joint_idx: jax.Array) -> jax.Array:
        q_bumped = q.at[joint_idx].add(eps)
        link_poses = robot.forward_kinematics(q_bumped)
        T_world_ee = jaxlie.SE3(link_poses[..., target_link_index, :])
        bumped_log = (T_target_inv @ T_world_ee).log()
        return (bumped_log - nominal_log) / eps

    jac = jax.vmap(column)(jnp.arange(q.shape[-1])).T
    assert jac.shape == (6, robot.joints.num_actuated_joints)

    row_weights = jnp.array([pos_weight] * 3 + [ori_weight] * 3)
    return jac * row_weights[:, None]
39 |
40 |
@jaxls.Cost.create_factory(jac_custom_with_cache_fn=_pose_cost_jac)
def pose_cost_numerical_jac(
    vals: jaxls.VarValues,
    robot: Robot,
    joint_var: jaxls.Var[jax.Array],
    target_pose: jaxlie.SE3,
    target_link_index: jax.Array,
    pos_weight: jax.Array | float,
    ori_weight: jax.Array | float,
    eps: float = 1e-4,
) -> tuple[jax.Array, _PoseCostJacCache]:
    """Residual between a link's world pose and its target pose.

    `eps` is not used here; it is only consumed by the finite-difference
    Jacobian function.

    Returns:
        A pair `(residual, cache)`:
        - `residual` is `log(target_pose^-1 @ T_world_link)` scaled by
          `[pos_weight] * 3 + [ori_weight] * 3`.
        - `cache` is `(Ts_world_joint, Ts_world_link, pose_error)`.
    """
    del eps  # Unused!
    assert target_link_index.dtype == jnp.int32

    joint_poses = robot._forward_kinematics_joints(vals[joint_var])
    link_poses = robot._link_poses_from_joint_poses(joint_poses)

    T_world_ee = jaxlie.SE3(link_poses[target_link_index, :])
    pose_error = target_pose.inverse() @ T_world_ee

    residual_weights = jnp.array([pos_weight] * 3 + [ori_weight] * 3)
    residual = pose_error.log() * residual_weights
    return residual, (joint_poses, link_poses, pose_error)
67 |
--------------------------------------------------------------------------------
/src/pyroki/utils.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import time
3 | from functools import partial
4 | from typing import Generator
5 |
6 | import jax
7 | import termcolor
8 | from loguru import logger
9 |
10 |
@contextlib.contextmanager
def stopwatch(label: str = "unlabeled block") -> Generator[None, None, None]:
    """Context manager that prints how long its body took to run.

    Elapsed time is measured with `time.perf_counter()`, which is monotonic
    and high-resolution (unlike `time.time()`, which can jump if the system
    clock is adjusted). The timing is reported even if the body raises; the
    exception still propagates.

    Args:
        label: Name printed in the "Running (...)" header.
    """
    start_time = time.perf_counter()
    print("\n========")
    print(f"Running ({label})")
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start_time
        print(f"{termcolor.colored(str(elapsed), attrs=['bold'])} seconds")
        print("========")
20 |
21 |
def _log(fmt: str, *args, **kwargs) -> None:
    """Emit `fmt` at INFO level through loguru, bound with `function="log"`."""
    bound_logger = logger.bind(function="log")
    bound_logger.info(fmt, *args, **kwargs)
24 |
25 |
def jax_log(fmt: str, *args, **kwargs) -> None:
    """Log an INFO message through loguru from inside JIT-compiled JAX code.

    `args` and `kwargs` are routed through `jax.debug.callback`, so they may be
    traced arrays. On the host, they are substituted into `fmt` by loguru.
    """
    log_with_format = partial(_log, fmt)
    jax.debug.callback(log_with_format, *args, **kwargs)
29 |
--------------------------------------------------------------------------------
/src/pyroki/viewer/__init__.py:
--------------------------------------------------------------------------------
1 | # from ._batched_urdf import BatchedURDF as BatchedURDF
2 | from ._weight_tuner import WeightTuner as WeightTuner
3 | from ._manipulability_ellipse import ManipulabilityEllipse as ManipulabilityEllipse
4 |
--------------------------------------------------------------------------------
/src/pyroki/viewer/_batched_urdf.py:
--------------------------------------------------------------------------------
1 | """Batched URDF rendering in Viser. This requires features not yet available in
2 | the PyPI release of Viser; it can be re-enabled after the next Viser release."""
3 |
4 | # import jax.numpy as jnp
5 | # import jaxlie
6 | # import numpy as onp
7 | # import viser
8 | # import viser.transforms as vtf
9 | # import yourdfpy
10 | # from jax.typing import ArrayLike
11 | # from pyroki._robot import Robot
12 | #
13 | # from viser.extras import BatchedGlbHandle
14 | #
15 | #
16 | # class BatchedURDF:
17 | # """
18 | # Helper for rendering batched URDFs in Viser.
19 | # Similar to `viser.extras.ViserUrdf`, but batched using `pyroki`'s batched forward kinematics.
20 | #
21 | # If num_robots > 1 then the URDF meshes are rendered as batched meshes (instancing),
22 | # otherwise they are rendered as individual meshes.
23 | #
24 | # Args:
25 | # target: Viser server or client handle to add URDF to.
26 | # urdf: URDF to render.
27 | # num_robots: Number of robots in the batch.
28 | # root_node_name: Name of the root node in the Viser scene.
29 | # """
30 | #
31 | # def __init__(
32 | # self,
33 | # target: viser.ViserServer | viser.ClientHandle,
34 | # urdf: yourdfpy.URDF,
35 | # num_robots: int = 1,
36 | # root_node_name: str = "/",
37 | # ):
38 | # assert root_node_name.startswith("/")
39 | # robot = Robot.from_urdf(urdf)
40 | #
41 | # self._urdf = urdf
42 | # self._robot = robot
43 | # self._target = target
44 | # self._root_node_name = root_node_name
45 | # self._num_robots = num_robots
46 | #
47 | # # Initialize base transforms to identity.
48 | # self._base_transforms = jaxlie.SE3.identity(batch_axes=(num_robots,))
49 | # self._last_cfg = None # Store the last configuration.
50 | #
51 | # self._populate()
52 | #
53 | # def _populate(self):
54 | # # Initialize with the correct batch size.
55 | # dummy_transform = vtf.SE3.identity(batch_axes=(self._num_robots,))
56 | # dummy_position = dummy_transform.translation()
57 | # dummy_wxyz = dummy_transform.rotation().wxyz
58 | #
59 | # self._meshes: dict[str, list[BatchedGlbHandle | viser.GlbHandle]] = {}
60 | # self._link_to_meshes: dict[str, onp.ndarray] = {}
61 | #
62 | # # Check if add_batched_meshes_trimesh is available.
63 | # if (
64 | # not hasattr(self._target.scene, "add_batched_meshes_trimesh")
65 | # and self._num_robots > 1
66 | # ):
67 | # raise NotImplementedError(
68 | # "num_robots > 1, but viser doesn't support instancing "
69 | # "(add_batched_meshes_trimesh is not available)."
70 | # )
71 | #
72 | # for mesh_name, mesh in self._urdf.scene.geometry.items():
73 | # link_name = self._urdf.scene.graph.transforms.parents[mesh_name]
74 | # if link_name not in self._meshes:
75 | # self._meshes[link_name] = []
76 | #
77 | # # Put mesh in the link frame.
78 | # T_parent_child = self._urdf.get_transform(
79 | # mesh_name, self._urdf.scene.graph.transforms.parents[mesh_name]
80 | # )
81 | # mesh = mesh.copy()
82 | # mesh.apply_transform(T_parent_child)
83 | #
84 | # if self._num_robots > 1:
85 | # self._meshes[link_name].append(
86 | # self._target.scene.add_batched_meshes_trimesh( # type: ignore[attr-defined]
87 | # f"{self._root_node_name}/{mesh_name}",
88 | # mesh,
89 | # batched_positions=dummy_position,
90 | # batched_wxyzs=dummy_wxyz,
91 | # lod="auto",
92 | # )
93 | # )
94 | # else:
95 | # self._meshes[link_name].append(
96 | # self._target.scene.add_mesh_trimesh(
97 | # f"{self._root_node_name}/{mesh_name}",
98 | # mesh,
99 | # position=dummy_position[0],
100 | # wxyz=dummy_wxyz[0],
101 | # )
102 | # )
103 | #
104 | # self._link_to_meshes[link_name] = T_parent_child
105 | #
106 | # def remove(self):
107 | # for meshes in self._meshes.values():
108 | # for mesh in meshes:
109 | # mesh.remove()
110 | #
111 | # def update_base_frame(self, base_transforms: ArrayLike):
112 | # """
113 | # Update the base transforms for each robot in the batch.
114 | #
115 | # Args:
116 | # base_transforms: New base transforms. Should be a JAX-compatible array
117 | # representing SE(3) transforms (e.g., a jaxlie.SE3 object)
118 | # with shape (num_robots,).
119 | # """
120 | # base_transforms_jnp = jnp.array(base_transforms)
121 | # base_transforms_jnp = jnp.atleast_2d(base_transforms_jnp)
122 | # assert base_transforms_jnp.shape[0] == self._num_robots, (
123 | # f"Expected first dimension of base_transforms to be {self._num_robots}, got {base_transforms_jnp.shape[0]}"
124 | # )
125 | #
126 | # self._base_transforms = jaxlie.SE3(base_transforms_jnp)
127 | #
128 | # # Re-apply transforms if a configuration exists
129 | # if self._last_cfg is not None:
130 | # self._apply_transforms(self._last_cfg)
131 | #
132 | # def update_cfg(self, cfg: ArrayLike):
133 | # """
134 | # Update the poses of the batched robots based on their configurations.
135 | #
136 | # Args:
137 | # cfg: Batched joint configurations. Shape should be (num_robots, num_dofs), or (num_dofs,).
138 | # """
139 | # cfg_jax = jnp.array(cfg) # in case cfg is an onp.ndarray.
140 | # cfg_jax = jnp.atleast_2d(cfg_jax)
141 | # assert cfg_jax.shape[0] == self._num_robots, (
142 | # f"Expected first dimension of cfg to be {self._num_robots}, got {cfg_jax.shape[0]}"
143 | # )
144 | #
145 | # # Store the latest configuration
146 | # self._last_cfg = cfg_jax
147 | # self._apply_transforms(cfg_jax)
148 | #
149 | # def _apply_transforms(self, cfg_jax: jnp.ndarray):
150 | # """Helper method to apply FK and base transforms to update meshes."""
151 | # # Ts_link_world should have shape (num_robots, num_links, ...)
152 | # Ts_link_world = self._robot.forward_kinematics(cfg_jax)
153 | #
154 | # for link_name, meshes in self._meshes.items():
155 | # link_idx = self._robot.links.names.index(link_name)
156 | # # T_link_world has shape (num_robots, ...)
157 | # T_link_world = jaxlie.SE3(
158 | # Ts_link_world[:, link_idx]
159 | # ) # Select link transforms for all robots
160 | #
161 | # # Apply base transform: T_mesh_world = T_base * T_link_world
162 | # # Resulting shape is (num_robots, ...)
163 | # T_mesh_world = self._base_transforms @ T_link_world
164 | #
165 | # # Extract batched positions and orientations
166 | # position = onp.array(T_mesh_world.translation()) # Shape (num_robots, 3)
167 | # wxyz = onp.array(T_mesh_world.rotation().wxyz) # Shape (num_robots, 4)
168 | # for mesh in meshes:
169 | # if isinstance(mesh, viser.GlbHandle):
170 | # mesh.position = position[0]
171 | # mesh.wxyz = wxyz[0]
172 | # else:
173 | # assert isinstance(mesh, BatchedGlbHandle)
174 | # mesh.batched_positions = position
175 | # mesh.batched_wxyzs = wxyz
176 |
--------------------------------------------------------------------------------
/src/pyroki/viewer/_manipulability_ellipse.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 |
3 | import jax
4 | import jax.numpy as jnp
5 | import jaxlie
6 | import numpy as onp
7 | import trimesh.creation
8 | import viser
9 | from jax.typing import ArrayLike
10 | from loguru import logger
11 |
12 | from .._robot import Robot
13 |
14 |
class ManipulabilityEllipse:
    """Helper class to visualize the manipulability ellipsoid for a robot link.

    For the target link, the translational Jacobian J is computed with
    `jax.jacfwd` through forward kinematics. The ellipsoid axes are the
    eigenvectors of J @ J.T, and its radii are the square roots of the
    (clamped) eigenvalues times `scaling_factor`. `self.manipulability` holds
    sqrt(det(J @ J.T)) from the latest successful update.
    """

    def __init__(
        self,
        server: viser.ViserServer | viser.ClientHandle,
        robot: Robot,
        root_node_name: str = "/manipulability",
        target_link_name: Optional[str] = None,
        scaling_factor: float = 0.2,
        visible: bool = True,
        wireframe: bool = True,
        color: Tuple[int, int, int] = (200, 200, 255),
    ):
        """Initializes the manipulability ellipsoid visualizer.

        Args:
            server: The Viser server or client handle.
            robot: The Pyroki robot model.
            root_node_name: The base name for the ellipsoid mesh in the Viser scene.
            target_link_name: Optional name of the link to visualize the ellipsoid for initially.
            scaling_factor: Scaling factor applied to the ellipsoid dimensions.
            visible: Initial visibility state.
            wireframe: Whether to render the ellipsoid as a wireframe.
            color: The color of the ellipsoid mesh.
        """
        self._server = server
        self._robot = robot
        self._root_node_name = root_node_name
        self._target_link_name = target_link_name
        self._scaling_factor = scaling_factor
        # User-requested visibility. Whether the mesh is actually shown also
        # depends on having a valid target link; that state is not stored here.
        self._visible = visible
        self._wireframe = wireframe
        self._color = color

        self._base_manip_sphere = trimesh.creation.icosphere(radius=1.0)
        self._mesh_handle: Optional[viser.MeshHandle] = None
        self._target_link_index: Optional[int] = None
        self._last_joints: Optional[jnp.ndarray] = None
        # sqrt(det(J @ J.T)) from the latest successful `update()`.
        self.manipulability = 0.0

        # Initial creation of the mesh handle (hidden if not visible).
        self._create_mesh_handle()

        # Set initial target link if provided; the mesh is hidden if None.
        self.set_target_link(target_link_name)

    def _create_mesh_handle(self):
        """Creates or recreates the mesh handle in the Viser scene."""
        if self._mesh_handle is not None:
            self._mesh_handle.remove()

        # Create with dummy data initially; geometry is filled in by `update()`.
        self._mesh_handle = self._server.scene.add_mesh_simple(
            self._root_node_name,
            vertices=onp.zeros((1, 3), dtype=onp.float32),
            faces=onp.zeros((1, 3), dtype=onp.uint32),
            color=self._color,
            wireframe=self._wireframe,
            visible=self._visible,
        )

        # Viser version compatibility.
        if hasattr(self._mesh_handle, "cast_shadow"):
            self._mesh_handle.cast_shadow = (  # type: ignore[attr-defined]
                False  # Ellipsoids usually don't need shadows
            )

    def _hide_mesh(self) -> None:
        """Hide the mesh handle without changing the user-requested visibility."""
        if self._mesh_handle is not None and self._mesh_handle.visible:
            self._mesh_handle.visible = False

    def set_target_link(self, link_name: Optional[str]):
        """Sets the target link for which to display the ellipsoid.

        A missing or unknown link hides the mesh. The user's visibility setting
        is left untouched, so choosing a valid link later shows the mesh again
        (unless the user hid it via `set_visibility(False)`).

        Args:
            link_name: The name of the target link, or None to disable.
        """
        if link_name is None:
            self._target_link_index = None
            self._hide_mesh()
            return

        try:
            self._target_link_index = self._robot.links.names.index(link_name)
        except ValueError:
            logger.warning(f"Link name '{link_name}' not found in robot model.")
            self._target_link_index = None
            self._hide_mesh()
            return

        # Show again if we previously hid for lack of a target, but only when
        # the user has not explicitly set visibility to False.
        if self._mesh_handle is not None and self._visible:
            self._mesh_handle.visible = True

    def update(self, joints: ArrayLike):
        """Updates the ellipsoid based on the current joint configuration.

        The configuration is remembered even while the ellipsoid is hidden,
        so re-enabling visibility renders the latest pose. Failures during the
        computation are logged and hide the mesh rather than raising.

        Args:
            joints: The current joint angles of the robot.
        """
        joints = jnp.asarray(joints)
        self._last_joints = joints

        if (
            self._target_link_index is None
            or not self._visible
            or self._mesh_handle is None
        ):
            # Ensure mesh is hidden if it shouldn't be shown.
            self._hide_mesh()
            return

        # Ensure mesh is visible if it should be.
        if not self._mesh_handle.visible:
            self._mesh_handle.visible = True

        try:
            # --- Jacobian Calculation ---
            # Translational Jacobian of every link; keep the target link's rows.
            jacobian = jax.jacfwd(
                lambda q: jaxlie.SE3(self._robot.forward_kinematics(q)).translation()
            )(joints)[self._target_link_index]
            assert jacobian.shape == (3, self._robot.joints.num_actuated_joints)

            # J @ J.T serves as both the manipulability Gram matrix and the
            # ellipsoid covariance, so it is computed once.
            jjt = jacobian @ jacobian.T
            assert jjt.shape == (3, 3)

            # --- Manipulability Calculation ---
            self.manipulability = jnp.sqrt(
                jnp.maximum(0.0, jnp.linalg.det(jjt))
            ).item()

            # --- Eigen decomposition (numpy, for visualization) ---
            eigvals, eigvecs = onp.linalg.eigh(onp.array(jjt))
            eigvals = onp.maximum(eigvals, 1e-9)  # Clamp tiny eigenvalues.

            # --- Get Target Link Pose ---
            Ts_link_world_array = self._robot.forward_kinematics(joints)
            target_pose = jaxlie.SE3(Ts_link_world_array[self._target_link_index])
            target_pos = onp.array(target_pose.translation().squeeze())

            # --- Create and Transform Ellipsoid Mesh ---
            ellipsoid_mesh = self._base_manip_sphere.copy()
            # Scale along x/y/z first, then rotate those axes onto the eigenvectors.
            ellipsoid_mesh.apply_scale(onp.sqrt(eigvals) * self._scaling_factor)
            tf = onp.eye(4)
            tf[:3, :3] = onp.array(eigvecs)  # Rotation from eigenvectors
            tf[:3, 3] = target_pos  # Translation to link origin
            ellipsoid_mesh.apply_transform(tf)

            # --- Update Viser Mesh ---
            self._mesh_handle.vertices = onp.array(
                ellipsoid_mesh.vertices, dtype=onp.float32
            )
            self._mesh_handle.faces = onp.array(ellipsoid_mesh.faces, dtype=onp.uint32)

        except Exception as e:
            # Best-effort visualization: log and hide instead of crashing.
            logger.warning(f"Failed to update manipulability ellipsoid: {e}")
            self._hide_mesh()

    def set_visibility(self, visible: bool):
        """Sets the user-requested visibility of the ellipsoid mesh.

        The mesh is only actually shown when a target link is set.
        """
        self._visible = visible
        if self._mesh_handle is None:
            return

        if (
            visible
            and self._target_link_index is not None
            and self._last_joints is not None
        ):
            # Recalculate the geometry for the latest configuration and show it.
            self.update(self._last_joints)
        else:
            should_show = visible and self._target_link_index is not None
            if self._mesh_handle.visible != should_show:
                self._mesh_handle.visible = should_show

    def remove(self):
        """Removes the ellipsoid mesh from the Viser scene."""
        if self._mesh_handle is not None:
            self._mesh_handle.remove()
            self._mesh_handle = None
        self._target_link_index = None  # Clear target when removed
197 | self._target_link_index = None # Clear target when removed
198 |
--------------------------------------------------------------------------------
/src/pyroki/viewer/_weight_tuner.py:
--------------------------------------------------------------------------------
1 | """Provides a class for interactively tuning named weights using Viser GUIs."""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import cast
6 |
7 | import jax
8 | import viser
9 |
10 |
class WeightTuner:
    """Creates and manages a set of Viser GUI sliders for tuning named weights.

    This class simplifies the process of adding interactive controls to a Viser
    application, typically used for adjusting weights in optimization problems
    (like inverse kinematics) or other numerical parameters in real-time.
    The sliders and a "Reset Weights" button are grouped within a Viser GUI
    folder.
    """

    _server: viser.ViserServer
    _weight_handles: dict[str, viser.GuiSliderHandle]

    _min: dict[str, float]
    _max: dict[str, float]
    _step: dict[str, float]
    _default: dict[str, float]

    def __init__(
        self,
        server: viser.ViserServer,
        default: dict[str, float],
        *,
        folder_name: str = "Costs",
        min: dict[str, float] | None = None,
        max: dict[str, float] | None = None,
        step: dict[str, float] | None = None,
        default_min: float = 0.0,
        default_max: float = 100.0,
        default_step: float = 0.01,
    ):
        """Initializes the tuner and creates the Viser GUI sliders.

        One slider is created per key of `default`, in the dictionary's
        iteration order.

        Args:
            server: The Viser server instance.
            default: Dict mapping each weight name to its initial int/float value.
            folder_name: Name of the Viser GUI folder to contain the sliders.
            min: Per-weight minimum slider values; overrides `default_min`.
            max: Per-weight maximum slider values; overrides `default_max`.
            step: Per-weight slider step sizes; overrides `default_step`.
            default_min: Minimum value for sliders not listed in `min`.
            default_max: Maximum value for sliders not listed in `max`.
            default_step: Step size for sliders not listed in `step`.
        """
        # Check the container type first so a non-dict fails with a clear message.
        assert isinstance(default, dict), "`default` must be a dict."
        leaves = jax.tree.leaves(default)
        assert all(isinstance(leaf, (int, float)) for leaf in leaves), (
            "All default parameters must be ints or floats."
        )

        self._server = server
        self._weight_handles = {}
        self._max = self._per_weight(default, max, default_max)
        self._min = self._per_weight(default, min, default_min)
        self._step = self._per_weight(default, step, default_step)
        self._default = default

        with server.gui.add_folder(folder_name):
            for field, default_weight in default.items():
                self._weight_handles[field] = server.gui.add_slider(
                    field,
                    min=self._min[field],
                    max=self._max[field],
                    step=self._step[field],
                    initial_value=default_weight,
                )

            reset_button = server.gui.add_button("Reset Weights")
            reset_button.on_click(lambda _: self.reset_weights())

    @staticmethod
    def _per_weight(
        default: dict[str, float],
        overrides: dict[str, float] | None,
        fallback: float,
    ) -> dict[str, float]:
        """Return `fallback` for every key of `default`, updated by `overrides`."""
        values = cast(dict, jax.tree.map(lambda _: fallback, default))
        if overrides is not None:
            values.update(overrides)
        return values

    def get_weights(self) -> dict[str, float]:
        """Retrieves the current values of all tracked weights from the GUI sliders.

        Returns:
            A dictionary mapping weight names to their current float values
            as set by the sliders.
        """
        return {field: handle.value for field, handle in self._weight_handles.items()}

    def reset_weights(self):
        """Resets all weights to their initial values."""
        for name, handle in self._weight_handles.items():
            handle.value = self._default[name]
113 |
--------------------------------------------------------------------------------