├── .github └── workflows │ ├── docs.yml │ ├── formatting.yaml │ ├── pyright.yml │ └── pytest.yml ├── .gitignore ├── LICENSE ├── README.md ├── benchmark └── ik_benchmark.py ├── docs ├── Makefile ├── README.md ├── requirements.txt ├── source │ ├── _static │ │ ├── basic_ik.mov │ │ ├── css │ │ │ └── custom.css │ │ ├── logo.svg │ │ └── logo_dark.svg │ ├── _templates │ │ └── sidebar │ │ │ └── brand.html │ ├── conf.py │ ├── examples │ │ ├── 01_basic_ik.rst │ │ ├── 02_bimanual_ik.rst │ │ ├── 03_mobile_ik.rst │ │ ├── 04_ik_with_coll.rst │ │ ├── 05_ik_with_manipulability.rst │ │ ├── 06_online_planning.rst │ │ ├── 07_trajopt.rst │ │ ├── 08_ik_with_mimic_joints.rst │ │ ├── 09_hand_retargeting.rst │ │ └── 10_humanoid_retargeting.rst │ ├── index.rst │ └── misc │ │ └── writing_manual_jac.rst └── update_example_docs.py ├── examples ├── 01_basic_ik.py ├── 02_bimanual_ik.py ├── 03_mobile_ik.py ├── 04_ik_with_coll.py ├── 05_ik_with_manipulability.py ├── 06_online_planning.py ├── 07_trajopt.py ├── 08_ik_with_mimic_joints.py ├── 09_hand_retargeting.py ├── 10_humanoid_retargeting.py ├── pyroki_snippets │ ├── __init__.py │ ├── _online_planning.py │ ├── _solve_ik.py │ ├── _solve_ik_with_base.py │ ├── _solve_ik_with_collision.py │ ├── _solve_ik_with_manipulability.py │ ├── _solve_ik_with_multiple_targets.py │ └── _trajopt.py └── retarget_helpers │ ├── _utils.py │ ├── hand │ ├── dexycb_motion.pkl │ └── shadowhand_urdf.zip │ └── humanoid │ ├── heightmap.npy │ ├── left_foot_contact.npy │ ├── right_foot_contact.npy │ └── smpl_keypoints.npy ├── pyproject.toml └── src └── pyroki ├── __init__.py ├── _robot.py ├── _robot_urdf_parser.py ├── collision ├── __init__.py ├── _collision.py ├── _geometry.py ├── _geometry_pairs.py ├── _robot_collision.py └── _utils.py ├── costs ├── __init__.py ├── _costs.py ├── _pose_cost_analytic_jac.py └── _pose_cost_numerical_jac.py ├── utils.py └── viewer ├── __init__.py ├── _batched_urdf.py ├── _manipulability_ellipse.py └── _weight_tuner.py 
/.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: docs 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | 7 | permissions: 8 | contents: write 9 | 10 | jobs: 11 | docs: 12 | runs-on: ubuntu-latest 13 | steps: 14 | # Check out source. 15 | - uses: actions/checkout@v2 16 | with: 17 | fetch-depth: 0 # This ensures the entire history is fetched so we can switch branches 18 | 19 | - name: Set up Python 20 | uses: actions/setup-python@v1 21 | with: 22 | python-version: "3.12" 23 | 24 | - name: Set up dependencies 25 | run: | 26 | sudo apt update 27 | sudo apt install -y libsuitesparse-dev 28 | pip install uv 29 | uv pip install --system -e ".[dev,examples]" 30 | uv pip install --system -r docs/requirements.txt 31 | 32 | # Build documentation. 33 | - name: Building documentation 34 | run: | 35 | sphinx-build docs/source docs/build -b dirhtml 36 | 37 | # Deploy to version-dependent subdirectory. 38 | - name: Deploy to GitHub Pages 39 | uses: peaceiris/actions-gh-pages@v4 40 | with: 41 | github_token: ${{ secrets.GITHUB_TOKEN }} 42 | publish_dir: ./docs/build 43 | -------------------------------------------------------------------------------- /.github/workflows/formatting.yaml: -------------------------------------------------------------------------------- 1 | name: formatting 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - uses: actions/checkout@v2 15 | - name: Set up Python 3.12 16 | uses: actions/setup-python@v1 17 | with: 18 | python-version: 3.12 19 | - name: Install dependencies 20 | run: | 21 | sudo apt update 22 | sudo apt install -y libsuitesparse-dev 23 | pip install uv 24 | uv pip install --system -e ".[dev,examples]" 25 | - name: Run Ruff 26 | run: ruff check docs/ src/ examples/ 27 | - name: Run Ruff format 28 | run: ruff format docs/ src/ examples/ --check 29 | 
-------------------------------------------------------------------------------- /.github/workflows/pyright.yml: -------------------------------------------------------------------------------- 1 | name: pyright 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | jobs: 10 | pyright: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: ["3.10", "3.11", "3.12", "3.13"] 15 | 16 | steps: 17 | - uses: actions/checkout@v2 18 | - name: Set up Python ${{ matrix.python-version }} 19 | uses: actions/setup-python@v1 20 | with: 21 | python-version: ${{ matrix.python-version }} 22 | - name: Install dependencies 23 | run: | 24 | sudo apt update 25 | sudo apt install -y libsuitesparse-dev 26 | pip install uv 27 | uv pip install --system -e ".[dev,examples]" 28 | - name: Run pyright 29 | run: | 30 | pyright . 31 | -------------------------------------------------------------------------------- /.github/workflows/pytest.yml: -------------------------------------------------------------------------------- 1 | name: pytest 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: ["3.10", "3.11", "3.12", "3.13"] 15 | 16 | steps: 17 | - uses: actions/checkout@v2 18 | - name: Set up Python ${{ matrix.python-version }} 19 | uses: actions/setup-python@v4 20 | with: 21 | python-version: ${{ matrix.python-version }} 22 | - name: Install dependencies 23 | run: | 24 | sudo apt update 25 | sudo apt install -y libsuitesparse-dev 26 | pip install uv 27 | uv pip install --system -e ".[dev,examples]" 28 | - name: Test with pytest 29 | run: | 30 | pytest 31 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | *.swo 3 | *.pyc 4 | *.egg-info 5 | *.ipynb_checkpoints 6 | 
__pycache__ 7 | .coverage 8 | htmlcov 9 | .mypy_cache 10 | .dmypy.json 11 | .hypothesis 12 | .envrc 13 | .lvimrc 14 | .DS_Store 15 | .envrc 16 | .vite 17 | build 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Chung Min Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # `PyRoki`: Python Robot Kinematics Library 2 | 3 | **[Project page](https://pyroki-toolkit.github.io/) • 4 | [arXiv](https://arxiv.org/abs/2505.03728)** 5 | 6 | `PyRoki` is a modular, extensible, and cross-platform toolkit for kinematic optimization, all in Python. 
7 | 8 | Core features include: 9 | 10 | - Differentiable robot forward kinematics model from a URDF. 11 | - Automatic generation of robot collision primitives (e.g., capsules). 12 | - Differentiable collision bodies with numpy broadcasting logic. 13 | - Common cost implementations (e.g., end effector pose, self/world-collision, manipulability). 14 | - Arbitrary costs, with autodiff or analytical Jacobians. 15 | - Integration with a [Levenberg-Marquardt Solver](https://github.com/brentyi/jaxls) that supports optimization on manifolds (e.g., [Lie groups](https://github.com/brentyi/jaxlie)). 16 | - Cross-platform support (CPU, GPU, TPU) via JAX. 17 | 18 | Please refer to the [documentation](https://chungmin99.github.io/pyroki/) for more details, features, and usage examples. 19 | 20 | --- 21 | 22 | ## Installation 23 | 24 | You can install `pyroki` with `pip`, on Python 3.12+: 25 | 26 | ``` 27 | git clone https://github.com/chungmin99/pyroki.git 28 | cd pyroki 29 | pip install -e . 30 | ``` 31 | 32 | Python 3.10-3.11 should also work, but support may be dropped in the future. 33 | 34 | ## Status 35 | 36 | _May 6, 2025_: Initial release 37 | 38 | We are preparing and will release by _May 16, 2025_: 39 | 40 | - [x] Examples + documentation for hand / humanoid motion retargeting 41 | - [x] Documentation for using manually defined Jacobians 42 | - [x] Support for Python 3.10+ 43 | 44 | ## Limitations 45 | 46 | - **Soft constraints only**: We use a nonlinear least-squares formulation and model joint limits, collision avoidance, etc. as soft penalties with high weights rather than hard constraints. 47 | - **Static shapes & JIT overhead**: JAX JIT compilation is triggered on first run and when input shapes change (e.g., number of targets, obstacles). Arrays can be pre-padded to vectorize over inputs with different shapes. 48 | - **No sampling-based planners**: We don't include sampling-based planners (e.g., graphs, trees). 
49 | - **Collision performance**: Speed and accuracy comparisons against other robot toolkits such as CuRobo have not been extensively performed, and PyRoki is likely slower than other toolkits for collision-heavy scenarios. 50 | 51 | The following are current implementation limitations that could potentially be addressed in future versions: 52 | 53 | - **Joint types**: We only support revolute, continuous, prismatic, and fixed joints. Other URDF joint types are treated as fixed joints. 54 | - **Collision geometry**: We are limited to sphere, capsule, halfspace, and heightmap geometries. Mesh collision is approximated as capsules. 55 | - **Kinematic structures**: We only support kinematic trees; no closed-loop mechanisms or parallel manipulators. 56 | 57 | ## Citation 58 | 59 | This codebase is released with the following preprint. 60 | 61 | 66 |
62 | Chung Min Kim*, Brent Yi*, Hongsuk Choi, Yi Ma, Ken Goldberg, Angjoo Kanazawa. 63 | PyRoki: A Modular Toolkit for Robot Kinematic Optimization 64 | arXiv, 2025. 65 |
67 | 68 | \*Equal Contribution, UC Berkeley. 69 | 70 | Please cite PyRoki if you find this work useful for your research: 71 | 72 | ``` 73 | @misc{pyroki2025, 74 | title={PyRoki: A Modular Toolkit for Robot Kinematic Optimization}, 75 | author={Chung Min Kim* and Brent Yi* and Hongsuk Choi and Yi Ma and Ken Goldberg and Angjoo Kanazawa}, 76 | year={2025}, 77 | eprint={2505.03728}, 78 | archivePrefix={arXiv}, 79 | primaryClass={cs.RO}, 80 | url={https://arxiv.org/abs/2505.03728}, 81 | } 82 | ``` 83 | 84 | Thanks! 85 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = viser 8 | SOURCEDIR = source 9 | BUILDDIR = ./build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # PyRoki Documentation 2 | 3 | This directory contains the documentation for PyRoki. 4 | 5 | ## Building the Documentation 6 | 7 | To build the documentation: 8 | 9 | 1. Install the documentation dependencies: 10 | 11 | ```bash 12 | pip install -r docs/requirements.txt 13 | ``` 14 | 15 | 2. Build the documentation: 16 | 17 | ```bash 18 | cd docs 19 | make html 20 | ``` 21 | 22 | 3. 
View the documentation: 23 | 24 | ```bash 25 | # On macOS 26 | open build/html/index.html 27 | 28 | # On Linux 29 | xdg-open build/html/index.html 30 | ``` 31 | 32 | ## Contributing Screenshots 33 | 34 | When adding new documentation, screenshots and visual examples significantly improve user understanding. 35 | 36 | We need screenshots for: 37 | 38 | - The Getting Started guide 39 | - GUI element examples 40 | - Scene API visualization examples 41 | - Customization/theming examples 42 | 43 | See [Contributing Visuals](./source/contributing_visuals.md) for guidelines on capturing and adding images to the documentation. 44 | 45 | ## Documentation Structure 46 | 47 | - `source/` - Source files for the documentation 48 | - `_static/` - Static files (CSS, images, etc.) 49 | - `images/` - Screenshots and other images 50 | - `examples/` - Example code with documentation 51 | - `*.md` - Markdown files for documentation pages 52 | - `conf.py` - Sphinx configuration 53 | 54 | ## Auto-Generated Example Documentation 55 | 56 | Example documentation is automatically generated from the examples in the `examples/` directory using the `update_example_docs.py` script. 
To update the example documentation after making changes to examples: 57 | 58 | ```bash 59 | cd docs 60 | python update_example_docs.py 61 | ``` 62 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx==8.0.2 2 | furo==2024.8.6 3 | docutils==0.20.1 4 | toml==0.10.2 5 | sphinxcontrib-video==0.4.1 6 | git+https://github.com/brentyi/sphinxcontrib-programoutput.git 7 | git+https://github.com/brentyi/ansi.git 8 | git+https://github.com/sphinx-contrib/googleanalytics.git 9 | 10 | snowballstemmer==2.2.0 # https://github.com/snowballstem/snowball/issues/229 11 | -------------------------------------------------------------------------------- /docs/source/_static/basic_ik.mov: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chungmin99/pyroki/f234516fe12c57795f31ca4b4fc9b2e04c6cb233/docs/source/_static/basic_ik.mov -------------------------------------------------------------------------------- /docs/source/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | img.sidebar-logo { 2 | width: 10em; 3 | margin: 1em 0 0 0; 4 | } 5 | -------------------------------------------------------------------------------- /docs/source/_static/logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /docs/source/_templates/sidebar/brand.html: -------------------------------------------------------------------------------- 1 | 11 | {%- endif %} {%- if theme_light_logo and theme_dark_logo %} 12 | 24 | {%- endif %} 25 | 26 | {% endblock brand_content %} 27 | 28 | 29 |
30 | 31 | 39 | Github 40 | 41 |
-------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Configuration file for the Sphinx documentation builder. 4 | # 5 | # This file does only contain a selection of the most common options. For a 6 | # full list see the documentation: 7 | # http://www.sphinx-doc.org/en/stable/config 8 | 9 | import os 10 | from typing import Dict, List 11 | 12 | import pyroki 13 | 14 | # -- Path setup -------------------------------------------------------------- 15 | 16 | # If extensions (or modules to document with autodoc) are in another directory, 17 | # add these directories to sys.path here. If the directory is relative to the 18 | # documentation root, use os.path.abspath to make it absolute, like shown here. 19 | # 20 | 21 | # -- Project information ----------------------------------------------------- 22 | 23 | project = "pyroki" # Change project name 24 | copyright = "2025" # Update copyright year/holder if needed 25 | author = "chungmin99" # Update author name 26 | 27 | version: str = os.environ.get( 28 | "PYROKI_VERSION_STR_OVERRIDE", pyroki.__version__ 29 | ) # Remove this 30 | 31 | # Formatting! 32 | # 0.1.30 => v0.1.30 33 | # dev => dev 34 | if not version.isalpha(): 35 | version = "v" + version 36 | 37 | # The full version, including alpha/beta/rc tags 38 | release = version # Use the same version for release for now 39 | 40 | 41 | # -- General configuration --------------------------------------------------- 42 | 43 | # If your documentation needs a minimal Sphinx version, state it here. 44 | # 45 | # needs_sphinx = '1.0' 46 | 47 | # Add any Sphinx extension module names here, as strings. They can be 48 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 49 | # ones. 
50 | extensions = [ 51 | "sphinx.ext.autodoc", 52 | "sphinx.ext.todo", 53 | "sphinx.ext.coverage", 54 | "sphinx.ext.mathjax", 55 | "sphinx.ext.githubpages", 56 | "sphinx.ext.napoleon", 57 | # "sphinx.ext.inheritance_diagram", 58 | "sphinxcontrib.video", 59 | "sphinx.ext.viewcode", 60 | "sphinxcontrib.programoutput", 61 | "sphinxcontrib.ansi", 62 | # "sphinxcontrib.googleanalytics", # google analytics extension https://github.com/sphinx-contrib/googleanalytics/tree/master 63 | ] 64 | programoutput_use_ansi = True 65 | html_ansi_stylesheet = "black-on-white.css" 66 | html_static_path = ["_static"] 67 | html_theme_options = { 68 | "light_css_variables": { 69 | "color-code-background": "#f4f4f4", 70 | "color-code-foreground": "#000", 71 | }, 72 | # Remove viser-specific footer icon 73 | "footer_icons": [ 74 | { 75 | "name": "GitHub", 76 | "url": "https://github.com/chungmin99/pyroki-dev", 77 | "html": """ 78 | 79 | 80 | 81 | """, 82 | "class": "", 83 | }, 84 | ], 85 | # Remove viser-specific logos 86 | "light_logo": "logo.svg", 87 | "dark_logo": "logo_dark.svg", 88 | } 89 | 90 | # Pull documentation types from hints 91 | autodoc_typehints = "both" 92 | autodoc_class_signature = "separated" 93 | autodoc_default_options = { 94 | "members": True, 95 | "member-order": "bysource", 96 | "undoc-members": True, 97 | "inherited-members": True, 98 | "exclude-members": "__init__, __post_init__", 99 | "imported-members": True, 100 | } 101 | 102 | # Add any paths that contain templates here, relative to this directory. 103 | templates_path = ["_templates"] 104 | 105 | # The suffix(es) of source filenames. 106 | # You can specify multiple suffix as a list of string: 107 | # 108 | source_suffix = ".rst" 109 | # source_suffix = ".rst" 110 | 111 | # The master toctree document. 112 | master_doc = "index" 113 | 114 | # The language for content autogenerated by Sphinx. Refer to documentation 115 | # for a list of supported languages. 
116 | # 117 | # This is also used if you do content translation via gettext catalogs. 118 | # Usually you set "language" from the command line for these cases. 119 | language: str = "en" 120 | 121 | # List of patterns, relative to source directory, that match files and 122 | # directories to ignore when looking for source files. 123 | # This pattern also affects html_static_path and html_extra_path . 124 | exclude_patterns: List[str] = [] 125 | 126 | # The name of the Pygments (syntax highlighting) style to use. 127 | pygments_style = "default" 128 | 129 | 130 | # -- Options for HTML output ------------------------------------------------- 131 | 132 | # The theme to use for HTML and HTML Help pages. See the documentation for 133 | # a list of builtin themes. 134 | # 135 | html_theme = "furo" 136 | html_title = "pyroki" # Update title 137 | 138 | 139 | # Theme options are theme-specific and customize the look and feel of a theme 140 | # further. For a list of options available for each theme, see the 141 | # documentation. 142 | # 143 | # html_theme_options = {} 144 | 145 | # Add any paths that contain custom static files (such as style sheets) here, 146 | # relative to this directory. They are copied after the builtin static files, 147 | # so a file named "default.css" will overwrite the builtin "default.css". 148 | # html_static_path = ["_static"] 149 | 150 | # Custom sidebar templates, must be a dictionary that maps document names 151 | # to template names. 152 | # 153 | # The default sidebars (for documents that don't match any pattern) are 154 | # defined by theme itself. Builtin themes are using these templates by 155 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 156 | # 'searchbox.html']``. 157 | # 158 | # html_sidebars = {} 159 | 160 | 161 | # -- Options for HTMLHelp output --------------------------------------------- 162 | 163 | # Output file base name for HTML help builder. 
164 | htmlhelp_basename = "pyroki_doc" # Update basename 165 | 166 | 167 | # -- Options for Github output ------------------------------------------------ 168 | 169 | sphinx_to_github = True 170 | sphinx_to_github_verbose = True 171 | sphinx_to_github_encoding = "utf-8" 172 | 173 | 174 | # -- Options for LaTeX output ------------------------------------------------ 175 | 176 | latex_elements: Dict[str, str] = { 177 | # The paper size ('letterpaper' or 'a4paper'). 178 | # 179 | # 'papersize': 'letterpaper', 180 | # The font size ('10pt', '11pt' or '12pt'). 181 | # 182 | # 'pointsize': '10pt', 183 | # Additional stuff for the LaTeX preamble. 184 | # 185 | # 'preamble': '', 186 | # Latex figure (float) alignment 187 | # 188 | # 'figure_align': 'htbp', 189 | } 190 | 191 | # Grouping the document tree into LaTeX files. List of tuples 192 | # (source start file, target name, title, 193 | # author, documentclass [howto, manual, or own class]). 194 | latex_documents = [ 195 | ( 196 | master_doc, 197 | "pyroki.tex", # Update target name 198 | "pyroki", # Update title 199 | "Your Name", # Update author 200 | "manual", 201 | ), 202 | ] 203 | 204 | 205 | # -- Options for manual page output ------------------------------------------ 206 | 207 | # One entry per manual page. List of tuples 208 | # (source start file, name, description, authors, manual section). 209 | man_pages = [ 210 | (master_doc, "pyroki", "pyroki documentation", [author], 1) 211 | ] # Update name and description 212 | 213 | 214 | # -- Options for Texinfo output ---------------------------------------------- 215 | 216 | # Grouping the document tree into Texinfo files. 
List of tuples 217 | # (source start file, target name, title, author, 218 | # dir menu entry, description, category) 219 | texinfo_documents = [ 220 | ( 221 | master_doc, 222 | "pyroki", # Update target name 223 | "pyroki", # Update title 224 | author, 225 | "pyroki", # Update dir menu entry 226 | "Python Robot Kinematics library", # Update description 227 | "Miscellaneous", 228 | ), 229 | ] 230 | 231 | 232 | # -- Extension configuration -------------------------------------------------- 233 | 234 | # Google Analytics ID 235 | # googleanalytics_id = "G-RRGY51J5ZH" # Remove this 236 | 237 | # -- Options for todo extension ---------------------------------------------- 238 | 239 | # If true, `todo` and `todoList` produce output, else they produce nothing. 240 | todo_include_todos = True 241 | 242 | # -- Setup function ---------------------------------------- 243 | 244 | 245 | def setup(app): 246 | app.add_css_file("css/custom.css") 247 | 248 | 249 | # -- Napoleon settings ------------------------------------------------------- 250 | 251 | # Settings for parsing non-sphinx style docstrings. We use Google style in this 252 | # project. 253 | napoleon_google_docstring = True 254 | napoleon_numpy_docstring = False 255 | napoleon_include_init_with_doc = False 256 | napoleon_include_private_with_doc = False 257 | napoleon_include_special_with_doc = True 258 | napoleon_use_admonition_for_examples = False 259 | napoleon_use_admonition_for_notes = False 260 | napoleon_use_admonition_for_references = False 261 | napoleon_use_ivar = False 262 | napoleon_use_param = True 263 | napoleon_use_rtype = True 264 | napoleon_preprocess_types = True 265 | napoleon_type_aliases = None 266 | napoleon_attr_annotations = True 267 | -------------------------------------------------------------------------------- /docs/source/examples/01_basic_ik.rst: -------------------------------------------------------------------------------- 1 | .. 
Comment: this file is automatically generated by `update_example_docs.py`. 2 | It should not be modified manually. 3 | 4 | Basic IK 5 | ========================================== 6 | 7 | 8 | Simplest Inverse Kinematics Example using PyRoki. 9 | 10 | All examples can be run by first cloning the PyRoki repository, which includes the ``pyroki_snippets`` implementation details. 11 | 12 | 13 | 14 | .. code-block:: python 15 | :linenos: 16 | 17 | 18 | import time 19 | 20 | import numpy as np 21 | import pyroki as pk 22 | import viser 23 | from robot_descriptions.loaders.yourdfpy import load_robot_description 24 | from viser.extras import ViserUrdf 25 | 26 | import pyroki_snippets as pks 27 | 28 | 29 | def main(): 30 | """Main function for basic IK.""" 31 | 32 | urdf = load_robot_description("panda_description") 33 | target_link_name = "panda_hand" 34 | 35 | # Create robot. 36 | robot = pk.Robot.from_urdf(urdf) 37 | 38 | # Set up visualizer. 39 | server = viser.ViserServer() 40 | server.scene.add_grid("/ground", width=2, height=2) 41 | urdf_vis = ViserUrdf(server, urdf, root_node_name="/base") 42 | 43 | # Create interactive controller with initial position. 44 | ik_target = server.scene.add_transform_controls( 45 | "/ik_target", scale=0.2, position=(0.61, 0.0, 0.56), wxyz=(0, 0, 1, 0) 46 | ) 47 | timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True) 48 | 49 | while True: 50 | # Solve IK. 51 | start_time = time.time() 52 | solution = pks.solve_ik( 53 | robot=robot, 54 | target_link_name=target_link_name, 55 | target_position=np.array(ik_target.position), 56 | target_wxyz=np.array(ik_target.wxyz), 57 | ) 58 | 59 | # Update timing handle. 60 | elapsed_time = time.time() - start_time 61 | timing_handle.value = 0.99 * timing_handle.value + 0.01 * (elapsed_time * 1000) 62 | 63 | # Update visualizer. 
64 | urdf_vis.update_cfg(solution) 65 | 66 | 67 | if __name__ == "__main__": 68 | main() 69 | -------------------------------------------------------------------------------- /docs/source/examples/02_bimanual_ik.rst: -------------------------------------------------------------------------------- 1 | .. Comment: this file is automatically generated by `update_example_docs.py`. 2 | It should not be modified manually. 3 | 4 | Bimanual IK 5 | ========================================== 6 | 7 | 8 | Same as 01_basic_ik.py, but with two end effectors! 9 | 10 | All examples can be run by first cloning the PyRoki repository, which includes the ``pyroki_snippets`` implementation details. 11 | 12 | 13 | 14 | .. code-block:: python 15 | :linenos: 16 | 17 | 18 | import time 19 | import viser 20 | from robot_descriptions.loaders.yourdfpy import load_robot_description 21 | import numpy as np 22 | 23 | import pyroki as pk 24 | from viser.extras import ViserUrdf 25 | import pyroki_snippets as pks 26 | 27 | 28 | def main(): 29 | """Main function for bimanual IK.""" 30 | 31 | urdf = load_robot_description("yumi_description") 32 | target_link_names = ["yumi_link_7_r", "yumi_link_7_l"] 33 | 34 | # Create robot. 35 | robot = pk.Robot.from_urdf(urdf) 36 | 37 | # Set up visualizer. 38 | server = viser.ViserServer() 39 | server.scene.add_grid("/ground", width=2, height=2) 40 | urdf_vis = ViserUrdf(server, urdf, root_node_name="/base") 41 | 42 | # Create interactive controller with initial position. 43 | ik_target_0 = server.scene.add_transform_controls( 44 | "/ik_target_0", scale=0.2, position=(0.41, -0.3, 0.56), wxyz=(0, 0, 1, 0) 45 | ) 46 | ik_target_1 = server.scene.add_transform_controls( 47 | "/ik_target_1", scale=0.2, position=(0.41, 0.3, 0.56), wxyz=(0, 0, 1, 0) 48 | ) 49 | timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True) 50 | 51 | while True: 52 | # Solve IK. 
53 | start_time = time.time() 54 | solution = pks.solve_ik_with_multiple_targets( 55 | robot=robot, 56 | target_link_names=target_link_names, 57 | target_positions=np.array([ik_target_0.position, ik_target_1.position]), 58 | target_wxyzs=np.array([ik_target_0.wxyz, ik_target_1.wxyz]), 59 | ) 60 | 61 | # Update timing handle. 62 | elapsed_time = time.time() - start_time 63 | timing_handle.value = 0.99 * timing_handle.value + 0.01 * (elapsed_time * 1000) 64 | 65 | # Update visualizer. 66 | urdf_vis.update_cfg(solution) 67 | 68 | 69 | if __name__ == "__main__": 70 | main() 71 | -------------------------------------------------------------------------------- /docs/source/examples/03_mobile_ik.rst: -------------------------------------------------------------------------------- 1 | .. Comment: this file is automatically generated by `update_example_docs.py`. 2 | It should not be modified manually. 3 | 4 | Mobile IK 5 | ========================================== 6 | 7 | 8 | Same as 01_basic_ik.py, but with a mobile base! 9 | 10 | All examples can be run by first cloning the PyRoki repository, which includes the ``pyroki_snippets`` implementation details. 11 | 12 | 13 | 14 | .. code-block:: python 15 | :linenos: 16 | 17 | 18 | import time 19 | import viser 20 | from robot_descriptions.loaders.yourdfpy import load_robot_description 21 | import numpy as np 22 | 23 | import pyroki as pk 24 | from viser.extras import ViserUrdf 25 | import pyroki_snippets as pks 26 | 27 | 28 | def main(): 29 | """Main function for IK with a mobile base. 30 | The base is fixed along the xy plane, and is biased towards being at the origin. 31 | """ 32 | 33 | urdf = load_robot_description("fetch_description") 34 | target_link_name = "gripper_link" 35 | 36 | # Create robot. 37 | robot = pk.Robot.from_urdf(urdf) 38 | 39 | # Set up visualizer. 
40 | server = viser.ViserServer() 41 | server.scene.add_grid("/ground", width=2, height=2) 42 | base_frame = server.scene.add_frame("/base", show_axes=False) 43 | urdf_vis = ViserUrdf(server, urdf, root_node_name="/base") 44 | 45 | # Create interactive controller with initial position. 46 | ik_target = server.scene.add_transform_controls( 47 | "/ik_target", scale=0.2, position=(0.61, 0.0, 0.56), wxyz=(0, 0.707, 0, -0.707) 48 | ) 49 | timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True) 50 | 51 | cfg = np.array(robot.joint_var_cls(0).default_factory()) 52 | 53 | while True: 54 | # Solve IK. 55 | start_time = time.time() 56 | base_pos, base_wxyz, cfg = pks.solve_ik_with_base( 57 | robot=robot, 58 | target_link_name=target_link_name, 59 | target_position=np.array(ik_target.position), 60 | target_wxyz=np.array(ik_target.wxyz), 61 | fix_base_position=(False, False, True), # Only free along xy plane. 62 | fix_base_orientation=(True, True, False), # Free along z-axis rotation. 63 | prev_pos=base_frame.position, 64 | prev_wxyz=base_frame.wxyz, 65 | prev_cfg=cfg, 66 | ) 67 | 68 | # Update timing handle. 69 | elapsed_time = time.time() - start_time 70 | timing_handle.value = 0.99 * timing_handle.value + 0.01 * (elapsed_time * 1000) 71 | 72 | # Update visualizer. 73 | urdf_vis.update_cfg(cfg) 74 | base_frame.position = np.array(base_pos) 75 | base_frame.wxyz = np.array(base_wxyz) 76 | 77 | 78 | if __name__ == "__main__": 79 | main() 80 | -------------------------------------------------------------------------------- /docs/source/examples/04_ik_with_coll.rst: -------------------------------------------------------------------------------- 1 | .. Comment: this file is automatically generated by `update_example_docs.py`. 2 | It should not be modified manually. 3 | 4 | IK with Collision 5 | ========================================== 6 | 7 | 8 | Basic Inverse Kinematics with Collision Avoidance using PyRoKi. 
9 | 10 | All examples can be run by first cloning the PyRoki repository, which includes the ``pyroki_snippets`` implementation details. 11 | 12 | 13 | 14 | .. code-block:: python 15 | :linenos: 16 | 17 | 18 | import time 19 | 20 | import numpy as np 21 | import pyroki as pk 22 | import viser 23 | from pyroki.collision import HalfSpace, RobotCollision, Sphere 24 | from robot_descriptions.loaders.yourdfpy import load_robot_description 25 | from viser.extras import ViserUrdf 26 | 27 | import pyroki_snippets as pks 28 | 29 | 30 | def main(): 31 | """Main function for basic IK with collision.""" 32 | urdf = load_robot_description("panda_description") 33 | target_link_name = "panda_hand" 34 | robot = pk.Robot.from_urdf(urdf) 35 | 36 | robot_coll = RobotCollision.from_urdf(urdf) 37 | plane_coll = HalfSpace.from_point_and_normal( 38 | np.array([0.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0]) 39 | ) 40 | sphere_coll = Sphere.from_center_and_radius( 41 | np.array([0.0, 0.0, 0.0]), np.array([0.05]) 42 | ) 43 | 44 | # Set up visualizer. 45 | server = viser.ViserServer() 46 | server.scene.add_grid("/ground", width=2, height=2, cell_size=0.1) 47 | urdf_vis = ViserUrdf(server, urdf, root_node_name="/robot") 48 | 49 | # Create interactive controller for IK target. 50 | ik_target_handle = server.scene.add_transform_controls( 51 | "/ik_target", scale=0.2, position=(0.5, 0.0, 0.5), wxyz=(0, 0, 1, 0) 52 | ) 53 | 54 | # Create interactive controller and mesh for the sphere obstacle. 
55 | sphere_handle = server.scene.add_transform_controls( 56 | "/obstacle", scale=0.2, position=(0.4, 0.3, 0.4) 57 | ) 58 | server.scene.add_mesh_trimesh("/obstacle/mesh", mesh=sphere_coll.to_trimesh()) 59 | 60 | timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True) 61 | 62 | while True: 63 | start_time = time.time() 64 | 65 | sphere_coll_world_current = sphere_coll.transform_from_wxyz_position( 66 | wxyz=np.array(sphere_handle.wxyz), 67 | position=np.array(sphere_handle.position), 68 | ) 69 | 70 | world_coll_list = [plane_coll, sphere_coll_world_current] 71 | solution = pks.solve_ik_with_collision( 72 | robot=robot, 73 | coll=robot_coll, 74 | world_coll_list=world_coll_list, 75 | target_link_name=target_link_name, 76 | target_position=np.array(ik_target_handle.position), 77 | target_wxyz=np.array(ik_target_handle.wxyz), 78 | ) 79 | 80 | # Update timing handle. 81 | timing_handle.value = (time.time() - start_time) * 1000 82 | 83 | # Update visualizer. 84 | urdf_vis.update_cfg(solution) 85 | 86 | 87 | if __name__ == "__main__": 88 | main() 89 | -------------------------------------------------------------------------------- /docs/source/examples/05_ik_with_manipulability.rst: -------------------------------------------------------------------------------- 1 | .. Comment: this file is automatically generated by `update_example_docs.py`. 2 | It should not be modified manually. 3 | 4 | IK with Manipulability 5 | ========================================== 6 | 7 | 8 | Inverse Kinematics with Manipulability using PyRoKi. 9 | 10 | All examples can be run by first cloning the PyRoki repository, which includes the ``pyroki_snippets`` implementation details. 11 | 12 | 13 | 14 | .. 
code-block:: python 15 | :linenos: 16 | 17 | 18 | import time 19 | import viser 20 | from robot_descriptions.loaders.yourdfpy import load_robot_description 21 | import numpy as np 22 | 23 | import pyroki as pk 24 | from viser.extras import ViserUrdf 25 | import pyroki_snippets as pks 26 | 27 | 28 | def main(): 29 | """Main function for basic IK.""" 30 | 31 | urdf = load_robot_description("panda_description") 32 | target_link_name = "panda_hand" 33 | 34 | # Create robot. 35 | robot = pk.Robot.from_urdf(urdf) 36 | 37 | # Set up visualizer. 38 | server = viser.ViserServer() 39 | server.scene.add_grid("/ground", width=2, height=2) 40 | urdf_vis = ViserUrdf(server, urdf, root_node_name="/base") 41 | 42 | # Create interactive controller with initial position. 43 | ik_target = server.scene.add_transform_controls( 44 | "/ik_target", scale=0.2, position=(0.61, 0.0, 0.56), wxyz=(0, 0, 1, 0) 45 | ) 46 | timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True) 47 | value_handle = server.gui.add_number("Yoshikawa Index", 0.001, disabled=True) 48 | weight_handle = server.gui.add_slider( 49 | "Manipulability Weight", 0.0, 10.0, 0.001, 0.0 50 | ) 51 | manip_ellipse = pk.viewer.ManipulabilityEllipse( 52 | server, 53 | robot, 54 | root_node_name="/manipulability", 55 | target_link_name=target_link_name, 56 | ) 57 | 58 | while True: 59 | # Solve IK. 60 | start_time = time.time() 61 | solution = pks.solve_ik_with_manipulability( 62 | robot=robot, 63 | target_link_name=target_link_name, 64 | target_position=np.array(ik_target.position), 65 | target_wxyz=np.array(ik_target.wxyz), 66 | manipulability_weight=weight_handle.value, 67 | ) 68 | 69 | manip_ellipse.update(solution) 70 | value_handle.value = manip_ellipse.manipulability 71 | 72 | # Update timing handle. 73 | elapsed_time = time.time() - start_time 74 | timing_handle.value = 0.99 * timing_handle.value + 0.01 * (elapsed_time * 1000) 75 | 76 | # Update visualizer. 
77 | urdf_vis.update_cfg(solution) 78 | 79 | 80 | if __name__ == "__main__": 81 | main() 82 | -------------------------------------------------------------------------------- /docs/source/examples/06_online_planning.rst: -------------------------------------------------------------------------------- 1 | .. Comment: this file is automatically generated by `update_example_docs.py`. 2 | It should not be modified manually. 3 | 4 | Online Planning 5 | ========================================== 6 | 7 | 8 | Run online planning in collision aware environments. 9 | 10 | All examples can be run by first cloning the PyRoki repository, which includes the ``pyroki_snippets`` implementation details. 11 | 12 | 13 | 14 | .. code-block:: python 15 | :linenos: 16 | 17 | 18 | import time 19 | 20 | import numpy as np 21 | import pyroki as pk 22 | import viser 23 | from pyroki.collision import HalfSpace, RobotCollision, Sphere 24 | from robot_descriptions.loaders.yourdfpy import load_robot_description 25 | from viser.extras import ViserUrdf 26 | 27 | import pyroki_snippets as pks 28 | 29 | 30 | def main(): 31 | """Main function for online planning with collision.""" 32 | urdf = load_robot_description("panda_description") 33 | target_link_name = "panda_hand" 34 | robot = pk.Robot.from_urdf(urdf) 35 | 36 | robot_coll = RobotCollision.from_urdf(urdf) 37 | plane_coll = HalfSpace.from_point_and_normal( 38 | np.array([0.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0]) 39 | ) 40 | sphere_coll = Sphere.from_center_and_radius( 41 | np.array([0.0, 0.0, 0.0]), np.array([0.05]) 42 | ) 43 | 44 | # Define the online planning parameters. 45 | len_traj, dt = 5, 0.1 46 | 47 | # Set up visualizer. 48 | server = viser.ViserServer() 49 | server.scene.add_grid("/ground", width=2, height=2, cell_size=0.1) 50 | urdf_vis = ViserUrdf(server, urdf, root_node_name="/robot") 51 | 52 | # Create interactive controller for IK target. 
53 | ik_target_handle = server.scene.add_transform_controls( 54 | "/ik_target", scale=0.2, position=(0.3, 0.0, 0.5), wxyz=(0, 0, 1, 0) 55 | ) 56 | 57 | # Create interactive controller and mesh for the sphere obstacle. 58 | sphere_handle = server.scene.add_transform_controls( 59 | "/obstacle", scale=0.2, position=(0.4, 0.3, 0.4) 60 | ) 61 | server.scene.add_mesh_trimesh("/obstacle/mesh", mesh=sphere_coll.to_trimesh()) 62 | target_frame_handle = server.scene.add_batched_axes( 63 | "target_frame", 64 | axes_length=0.05, 65 | axes_radius=0.005, 66 | batched_positions=np.zeros((25, 3)), 67 | batched_wxyzs=np.array([[1.0, 0.0, 0.0, 0.0]] * 25), 68 | ) 69 | 70 | timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True) 71 | 72 | sol_pos, sol_wxyz = None, None 73 | sol_traj = np.array( 74 | robot.joint_var_cls.default_factory()[None].repeat(len_traj, axis=0) 75 | ) 76 | while True: 77 | start_time = time.time() 78 | 79 | sphere_coll_world_current = sphere_coll.transform_from_wxyz_position( 80 | wxyz=np.array(sphere_handle.wxyz), 81 | position=np.array(sphere_handle.position), 82 | ) 83 | 84 | world_coll_list = [plane_coll, sphere_coll_world_current] 85 | sol_traj, sol_pos, sol_wxyz = pks.solve_online_planning( 86 | robot=robot, 87 | robot_coll=robot_coll, 88 | world_coll=world_coll_list, 89 | target_link_name=target_link_name, 90 | target_position=np.array(ik_target_handle.position), 91 | target_wxyz=np.array(ik_target_handle.wxyz), 92 | timesteps=len_traj, 93 | dt=dt, 94 | start_cfg=sol_traj[0], 95 | prev_sols=sol_traj, 96 | ) 97 | 98 | # Update timing handle. 99 | timing_handle.value = ( 100 | 0.99 * timing_handle.value + 0.01 * (time.time() - start_time) * 1000 101 | ) 102 | 103 | # Update visualizer. 104 | urdf_vis.update_cfg( 105 | sol_traj[0] 106 | ) # The first step of the online trajectory solution. 107 | 108 | # Update the planned trajectory visualization. 
109 | if hasattr(target_frame_handle, "batched_positions"): 110 | target_frame_handle.batched_positions = np.array(sol_pos) # type: ignore[attr-defined] 111 | target_frame_handle.batched_wxyzs = np.array(sol_wxyz) # type: ignore[attr-defined] 112 | else: 113 | # This is an older version of Viser. 114 | target_frame_handle.positions_batched = np.array(sol_pos) # type: ignore[attr-defined] 115 | target_frame_handle.wxyzs_batched = np.array(sol_wxyz) # type: ignore[attr-defined] 116 | 117 | 118 | if __name__ == "__main__": 119 | main() 120 | -------------------------------------------------------------------------------- /docs/source/examples/07_trajopt.rst: -------------------------------------------------------------------------------- 1 | .. Comment: this file is automatically generated by `update_example_docs.py`. 2 | It should not be modified manually. 3 | 4 | Trajectory Optimization 5 | ========================================== 6 | 7 | 8 | Basic Trajectory Optimization using PyRoKi. 9 | 10 | Robot going over a wall, while avoiding world-collisions. 11 | 12 | All examples can be run by first cloning the PyRoki repository, which includes the ``pyroki_snippets`` implementation details. 13 | 14 | 15 | 16 | .. code-block:: python 17 | :linenos: 18 | 19 | 20 | import time 21 | from typing import Literal 22 | 23 | import numpy as np 24 | import pyroki as pk 25 | import trimesh 26 | import tyro 27 | import viser 28 | from viser.extras import ViserUrdf 29 | from robot_descriptions.loaders.yourdfpy import load_robot_description 30 | 31 | import pyroki_snippets as pks 32 | 33 | 34 | def main(robot_name: Literal["ur5", "panda"] = "panda"): 35 | if robot_name == "ur5": 36 | urdf = load_robot_description("ur5_description") 37 | down_wxyz = np.array([0.707, 0, 0.707, 0]) 38 | target_link_name = "ee_link" 39 | 40 | # For UR5 it's important to initialize the robot in a safe configuration; 41 | # the zero-configuration puts the robot aligned with the wall obstacle. 
42 | default_cfg = np.zeros(6) 43 | default_cfg[1] = -1.308 44 | robot = pk.Robot.from_urdf(urdf, default_joint_cfg=default_cfg) 45 | 46 | elif robot_name == "panda": 47 | urdf = load_robot_description("panda_description") 48 | target_link_name = "panda_hand" 49 | down_wxyz = np.array([0, 0, 1, 0]) # for panda! 50 | robot = pk.Robot.from_urdf(urdf) 51 | 52 | else: 53 | raise ValueError(f"Invalid robot: {robot_name}") 54 | 55 | robot_coll = pk.collision.RobotCollision.from_urdf(urdf) 56 | 57 | # Define the trajectory problem: 58 | # - number of timesteps, timestep size 59 | timesteps, dt = 25, 0.02 60 | # - the start and end poses. 61 | start_pos, end_pos = np.array([0.5, -0.3, 0.2]), np.array([0.5, 0.3, 0.2]) 62 | 63 | # Define the obstacles: 64 | # - Ground 65 | ground_coll = pk.collision.HalfSpace.from_point_and_normal( 66 | np.array([0.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0]) 67 | ) 68 | # - Wall 69 | wall_height = 0.4 70 | wall_width = 0.1 71 | wall_length = 0.4 72 | wall_intervals = np.arange(start=0.3, stop=wall_length + 0.3, step=0.05) 73 | translation = np.concatenate( 74 | [ 75 | wall_intervals.reshape(-1, 1), 76 | np.full((wall_intervals.shape[0], 1), 0.0), 77 | np.full((wall_intervals.shape[0], 1), wall_height / 2), 78 | ], 79 | axis=1, 80 | ) 81 | wall_coll = pk.collision.Capsule.from_radius_height( 82 | position=translation, 83 | radius=np.full((translation.shape[0], 1), wall_width / 2), 84 | height=np.full((translation.shape[0], 1), wall_height), 85 | ) 86 | world_coll = [ground_coll, wall_coll] 87 | 88 | traj = pks.solve_trajopt( 89 | robot, 90 | robot_coll, 91 | world_coll, 92 | target_link_name, 93 | start_pos, 94 | down_wxyz, 95 | end_pos, 96 | down_wxyz, 97 | timesteps, 98 | dt, 99 | ) 100 | traj = np.array(traj) 101 | 102 | # Visualize! 
103 | server = viser.ViserServer() 104 | urdf_vis = ViserUrdf(server, urdf) 105 | server.scene.add_grid("/grid", width=2, height=2, cell_size=0.1) 106 | server.scene.add_mesh_trimesh( 107 | "wall_box", 108 | trimesh.creation.box( 109 | extents=(wall_length, wall_width, wall_height), 110 | transform=trimesh.transformations.translation_matrix( 111 | np.array([0.5, 0.0, wall_height / 2]) 112 | ), 113 | ), 114 | ) 115 | for name, pos in zip(["start", "end"], [start_pos, end_pos]): 116 | server.scene.add_frame( 117 | f"/{name}", 118 | position=pos, 119 | wxyz=down_wxyz, 120 | axes_length=0.05, 121 | axes_radius=0.01, 122 | ) 123 | 124 | slider = server.gui.add_slider( 125 | "Timestep", min=0, max=timesteps - 1, step=1, initial_value=0 126 | ) 127 | playing = server.gui.add_checkbox("Playing", initial_value=True) 128 | 129 | while True: 130 | if playing.value: 131 | slider.value = (slider.value + 1) % timesteps 132 | 133 | urdf_vis.update_cfg(traj[slider.value]) 134 | time.sleep(1.0 / 10.0) 135 | 136 | 137 | if __name__ == "__main__": 138 | tyro.cli(main) 139 | -------------------------------------------------------------------------------- /docs/source/examples/08_ik_with_mimic_joints.rst: -------------------------------------------------------------------------------- 1 | .. Comment: this file is automatically generated by `update_example_docs.py`. 2 | It should not be modified manually. 3 | 4 | IK with mimic joints 5 | ========================================== 6 | 7 | 8 | This is a simple test to ensure that mimic joints are handled correctly in the IK solver. 9 | 10 | We procedurally generate a "zig-zag" chain of links with mimic joints, where: 11 | 12 | 13 | * the first joint is driven directly, 14 | * and the remaining joints are driven indirectly via mimic joints. 15 | The multipliers alternate between -1 and 1, and the offsets are all 0. 
16 | 17 | All examples can be run by first cloning the PyRoki repository, which includes the ``pyroki_snippets`` implementation details. 18 | 19 | 20 | 21 | .. code-block:: python 22 | :linenos: 23 | 24 | 25 | import time 26 | 27 | import numpy as np 28 | import pyroki as pk 29 | import viser 30 | from viser.extras import ViserUrdf 31 | 32 | import pyroki_snippets as pks 33 | 34 | 35 | def create_chain_xml(length: float = 0.2, num_chains: int = 5) -> str: 36 | def create_link(idx): 37 | return f""" 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | """ 50 | 51 | def create_joint(idx, multiplier=1.0, offset=0.0): 52 | mimic = f'' 53 | return f""" 54 | 55 | 56 | 57 | 58 | 59 | {mimic if idx != 0 else ""} 60 | 61 | 62 | """ 63 | 64 | world_joint_origin_z = length / 2.0 65 | xml = f""" 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | """ 77 | # Create the definition + first link. 78 | xml += create_link(0) 79 | xml += create_link(1) 80 | xml += create_joint(0) 81 | 82 | # Procedurally add more links. 83 | assert num_chains >= 2 84 | for idx in range(2, num_chains): 85 | xml += create_link(idx) 86 | current_offset = 0.0 87 | current_multiplier = 1.0 * ((-1) ** (idx % 2)) 88 | xml += create_joint(idx - 1, current_multiplier, current_offset) 89 | 90 | xml += """ 91 | 92 | """ 93 | return xml 94 | 95 | 96 | def main(): 97 | """Main function for basic IK.""" 98 | 99 | import yourdfpy 100 | import tempfile 101 | 102 | xml = create_chain_xml(num_chains=10, length=0.1) 103 | with tempfile.NamedTemporaryFile(mode="w", suffix=".urdf") as f: 104 | f.write(xml) 105 | f.flush() 106 | urdf = yourdfpy.URDF.load(f.name) 107 | 108 | # Create robot. 109 | robot = pk.Robot.from_urdf(urdf) 110 | 111 | # Set up visualizer. 
112 | server = viser.ViserServer() 113 | server.scene.add_grid("/ground", width=2, height=2) 114 | urdf_vis = ViserUrdf(server, urdf, root_node_name="/base") 115 | target_link_name_handle = server.gui.add_dropdown( 116 | "Target Link", 117 | robot.links.names, 118 | initial_value=robot.links.names[-1], 119 | ) 120 | 121 | # Create interactive controller with initial position. 122 | ik_target = server.scene.add_transform_controls( 123 | "/ik_target", scale=0.2, position=(0.0, 0.1, 0.1), wxyz=(0, 0, 1, 0) 124 | ) 125 | timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True) 126 | 127 | while True: 128 | # Solve IK. 129 | start_time = time.time() 130 | solution = pks.solve_ik( 131 | robot=robot, 132 | target_link_name=target_link_name_handle.value, 133 | target_position=np.array(ik_target.position), 134 | target_wxyz=np.array(ik_target.wxyz), 135 | ) 136 | 137 | # Update timing handle. 138 | elapsed_time = time.time() - start_time 139 | timing_handle.value = 0.99 * timing_handle.value + 0.01 * (elapsed_time * 1000) 140 | 141 | # Update visualizer. 142 | urdf_vis.update_cfg(solution) 143 | 144 | 145 | if __name__ == "__main__": 146 | main() 147 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | PyRoki 2 | ========== 3 | 4 | `Project page `_ `•` `arXiv `_ `•` `Code `_ 5 | 6 | **PyRoki** is a library for robot kinematic optimization (Python Robot Kinematics). 7 | 8 | 1. **Modular**: Optimization variables and cost functions are decoupled, enabling reusable components across tasks. Objectives like collision avoidance and pose matching can be applied to both IK and trajectory optimization without reimplementation. 9 | 10 | 2. 
**Extensible**: ``PyRoki`` supports automatic differentiation for user-defined costs with Jacobian computation, a real-time cost-weight tuning interface, and optional analytical Jacobians for performance-critical use cases. 11 | 12 | 3. **Cross-Platform**: ``PyRoki`` runs on CPU, GPU, and TPU, allowing efficient scaling from single-robot use cases to large-scale parallel processing for motion datasets or planning. 13 | 14 | We demonstrate how ``PyRoki`` solves IK, trajectory optimization, and motion retargeting for robot hands and humanoids in a unified framework. It uses a Levenberg-Marquardt optimizer to efficiently solve these tasks, and we evaluate its performance on batched IK. 15 | 16 | Features include: 17 | 18 | - Differentiable robot forward kinematics model from a URDF. 19 | - Automatic generation of robot collision primitives (e.g., capsules). 20 | - Differentiable collision bodies with NumPy broadcasting logic. 21 | - Common cost factors (e.g., end effector pose, self/world-collision, manipulability). 22 | - Arbitrary costs, getting Jacobians either calculated :doc:`through autodiff or defined manually <misc/writing_manual_jac>`. 23 | - Integration with a `Levenberg-Marquardt Solver `_ that supports optimization on manifolds (e.g., `Lie groups `_). 24 | - Cross-platform support (CPU, GPU, TPU) via JAX. 25 | 26 | 27 | 28 | Installation 29 | ------------ 30 | 31 | You can install ``pyroki`` with ``pip``, on Python 3.12+: 32 | 33 | .. code-block:: bash 34 | 35 | git clone https://github.com/chungmin99/pyroki.git 36 | cd pyroki 37 | pip install -e . 38 | 39 | 40 | Python 3.10-3.11 should also work, but support may be dropped in the future. 41 | 42 | Limitations 43 | ----------- 44 | 45 | - **Soft constraints only**: We use a nonlinear least-squares formulation and model joint limits, collision avoidance, etc. as soft penalties with high weights rather than hard constraints.
46 | - **Static shapes & JIT overhead**: JAX JIT compilation is triggered on first run and when input shapes change (e.g., number of targets, obstacles). Arrays can be pre-padded to vectorize over inputs with different shapes. 47 | - **No sampling-based planners**: We don't include sampling-based planners (e.g., graphs, trees). 48 | - **Collision performance**: Speed and accuracy comparisons against other robot toolkits such as cuRobo have not been extensively performed, and ``PyRoki`` is likely slower than other toolkits for collision-heavy scenarios. 49 | 50 | The following are current implementation limitations that could potentially be addressed in future versions: 51 | 52 | - **Joint types**: We only support revolute, continuous, prismatic, and fixed joints. Other URDF joint types are treated as fixed joints. 53 | - **Collision geometry**: We are limited to sphere, capsule, halfspace, and heightmap geometries. Mesh collision is approximated as capsules. 54 | - **Kinematic structures**: We only support kinematic chains; no closed-loop mechanisms or parallel manipulators. 55 | 56 | Examples 57 | -------- 58 | 59 | .. toctree:: 60 | :maxdepth: 1 61 | :caption: Examples 62 | 63 | examples/01_basic_ik 64 | examples/02_bimanual_ik 65 | examples/03_mobile_ik 66 | examples/04_ik_with_coll 67 | examples/05_ik_with_manipulability 68 | examples/06_online_planning 69 | examples/07_trajopt 70 | examples/08_ik_with_mimic_joints 71 | examples/09_hand_retargeting 72 | examples/10_humanoid_retargeting 73 | 74 | 75 | Acknowledgements 76 | ---------------- 77 | ``PyRoki`` is heavily inspired by prior work, including but not limited to 78 | `Trac-IK `_, 79 | `cuRobo `_, 80 | `pink `_, 81 | `mink `_, 82 | `Drake `_, and 83 | `Dex-Retargeting `_. 84 | Thank you so much for your great work! 85 | 86 | 87 | Citation 88 | -------- 89 | 90 | If you find this work useful, please cite it as follows: 91 | 92 | .. 
code-block:: bibtex 93 | 94 | @misc{pyroki2025, 95 | title={PyRoki: A Modular Toolkit for Robot Kinematic Optimization}, 96 | author={Chung Min Kim* and Brent Yi* and Hongsuk Choi and Yi Ma and Ken Goldberg and Angjoo Kanazawa}, 97 | year={2025}, 98 | eprint={2505.03728}, 99 | archivePrefix={arXiv}, 100 | primaryClass={cs.RO}, 101 | url={https://arxiv.org/abs/2505.03728}, 102 | } 103 | 104 | Thanks for using ``PyRoki``! 105 | -------------------------------------------------------------------------------- /docs/source/misc/writing_manual_jac.rst: -------------------------------------------------------------------------------- 1 | :orphan: 2 | 3 | Defining Jacobians Manually 4 | ===================================== 5 | 6 | ``pyroki`` supports both autodiff and manually defined Jacobians for computing cost gradients. 7 | 8 | For reference, this is the robot pose matching cost :math:`C_\text{pose}`: 9 | 10 | .. math:: 11 | 12 | \sum_{i} \left( w_{p,i} \left\| \mathbf{p}_{i}(q) - \mathbf{p}_{i}^* \right\|^2 + w_{R,i} \left\| \text{log}(\mathbf{R}_{i}(q)^{-1} \mathbf{R}_{i}^*) \right\|^2 \right) 13 | 14 | 15 | where :math:`q` is the robot joint configuration, :math:`\mathbf{p}_{i}(q)` is the position of the :math:`i`-th link, :math:`\mathbf{R}_{i}(q)` is the rotation matrix of the :math:`i`-th link, and :math:`w_{p,i}` and :math:`w_{R,i}` are the position and orientation weights, respectively. 16 | 17 | The following is the most common way to define costs in ``pyroki`` -- with autodiff: 18 | 19 | .. 
code-block:: python 20 | 21 | @Cost.create_factory 22 | def pose_cost( 23 | vals: VarValues, 24 | robot: Robot, 25 | joint_var: Var[Array], 26 | target_pose: jaxlie.SE3, 27 | target_link_index: Array, 28 | pos_weight: Array | float, 29 | ori_weight: Array | float, 30 | ) -> Array: 31 | """Computes the residual for matching link poses to target poses.""" 32 | assert target_link_index.dtype == jnp.int32 33 | joint_cfg = vals[joint_var] 34 | Ts_link_world = robot.forward_kinematics(joint_cfg) 35 | pose_actual = jaxlie.SE3(Ts_link_world[..., target_link_index, :]) 36 | 37 | # Position residual = position error * weight 38 | pos_residual = (pose_actual.translation() - target_pose.translation()) * pos_weight 39 | # Orientation residual = log(actual_inv * target) * weight 40 | ori_residual = (pose_actual.rotation().inverse() @ target_pose.rotation()).log() * ori_weight 41 | 42 | return jnp.concatenate([pos_residual, ori_residual]).flatten() 43 | 44 | The alternative is to manually write out the Jacobian -- while automatic differentiation is convenient and works well for most use cases, analytical Jacobians can provide better performance, which we show in the `paper <https://arxiv.org/abs/2505.03728>`_. 45 | 46 | We provide two implementations of the pose matching cost with custom Jacobians: 47 | 48 | - an `analytically derived Jacobian <https://github.com/chungmin99/pyroki/blob/main/src/pyroki/costs/_pose_cost_analytic_jac.py>`_ (~200 lines), or 49 | - a `numerically approximated Jacobian <https://github.com/chungmin99/pyroki/blob/main/src/pyroki/costs/_pose_cost_numerical_jac.py>`_ through finite differences (~50 lines). 
50 | -------------------------------------------------------------------------------- /docs/update_example_docs.py: -------------------------------------------------------------------------------- 1 | """Helper script for updating the auto-generated examples pages in the documentation.""" 2 | 3 | from __future__ import annotations 4 | 5 | import dataclasses 6 | import pathlib 7 | import shutil 8 | from typing import Iterable 9 | 10 | import m2r2 11 | import tyro 12 | 13 | 14 | @dataclasses.dataclass 15 | class ExampleMetadata: 16 | index: str 17 | index_with_zero: str 18 | source: str 19 | title: str 20 | description: str 21 | 22 | @staticmethod 23 | def from_path(path: pathlib.Path) -> ExampleMetadata: 24 | # 01_functions -> 01, _, functions. 25 | index, _, _ = path.stem.partition("_") 26 | 27 | # 01 -> 1. 28 | index_with_zero = index 29 | index = str(int(index)) 30 | 31 | print("Parsing", path) 32 | source = path.read_text().strip() 33 | docstring = source.split('"""')[1].strip() 34 | 35 | title, _, description = docstring.partition("\n") 36 | 37 | description = description.strip() 38 | description += "\n" 39 | description += "\n" 40 | description += "All examples can be run by first cloning the PyRoki repository, which includes the `pyroki_snippets` implementation details." 
41 | 42 | return ExampleMetadata( 43 | index=index, 44 | index_with_zero=index_with_zero, 45 | source=source.partition('"""')[2].partition('"""')[2].strip(), 46 | title=title, 47 | description=description, 48 | ) 49 | 50 | 51 | def get_example_paths(examples_dir: pathlib.Path) -> Iterable[pathlib.Path]: 52 | return filter( 53 | lambda p: not p.name.startswith("_"), sorted(examples_dir.glob("*.py")) 54 | ) 55 | 56 | 57 | REPO_ROOT = pathlib.Path(__file__).absolute().parent.parent 58 | 59 | 60 | def main( 61 | examples_dir: pathlib.Path = REPO_ROOT / "examples", 62 | sphinx_source_dir: pathlib.Path = REPO_ROOT / "docs" / "source", 63 | ) -> None: 64 | example_doc_dir = sphinx_source_dir / "examples" 65 | shutil.rmtree(example_doc_dir) 66 | example_doc_dir.mkdir() 67 | 68 | for path in get_example_paths(examples_dir): 69 | ex = ExampleMetadata.from_path(path) 70 | 71 | relative_dir = path.parent.relative_to(examples_dir) 72 | target_dir = example_doc_dir / relative_dir 73 | target_dir.mkdir(exist_ok=True, parents=True) 74 | 75 | (target_dir / f"{path.stem}.rst").write_text( 76 | "\n".join( 77 | [ 78 | ( 79 | ".. Comment: this file is automatically generated by" 80 | " `update_example_docs.py`." 81 | ), 82 | " It should not be modified manually.", 83 | "", 84 | f"{ex.title}", 85 | "==========================================", 86 | "", 87 | m2r2.convert(ex.description), 88 | "", 89 | "", 90 | ".. code-block:: python", 91 | " :linenos:", 92 | "", 93 | "", 94 | "\n".join( 95 | f" {line}".rstrip() for line in ex.source.split("\n") 96 | ), 97 | "", 98 | ] 99 | ) 100 | ) 101 | 102 | 103 | if __name__ == "__main__": 104 | tyro.cli(main, description=__doc__) 105 | -------------------------------------------------------------------------------- /examples/01_basic_ik.py: -------------------------------------------------------------------------------- 1 | """Basic IK 2 | 3 | Simplest Inverse Kinematics Example using PyRoki. 
4 | """ 5 | 6 | import time 7 | 8 | import numpy as np 9 | import pyroki as pk 10 | import viser 11 | from robot_descriptions.loaders.yourdfpy import load_robot_description 12 | from viser.extras import ViserUrdf 13 | 14 | import pyroki_snippets as pks 15 | 16 | 17 | def main(): 18 | """Main function for basic IK.""" 19 | 20 | urdf = load_robot_description("panda_description") 21 | target_link_name = "panda_hand" 22 | 23 | # Create robot. 24 | robot = pk.Robot.from_urdf(urdf) 25 | 26 | # Set up visualizer. 27 | server = viser.ViserServer() 28 | server.scene.add_grid("/ground", width=2, height=2) 29 | urdf_vis = ViserUrdf(server, urdf, root_node_name="/base") 30 | 31 | # Create interactive controller with initial position. 32 | ik_target = server.scene.add_transform_controls( 33 | "/ik_target", scale=0.2, position=(0.61, 0.0, 0.56), wxyz=(0, 0, 1, 0) 34 | ) 35 | timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True) 36 | 37 | while True: 38 | # Solve IK. 39 | start_time = time.time() 40 | solution = pks.solve_ik( 41 | robot=robot, 42 | target_link_name=target_link_name, 43 | target_position=np.array(ik_target.position), 44 | target_wxyz=np.array(ik_target.wxyz), 45 | ) 46 | 47 | # Update timing handle. 48 | elapsed_time = time.time() - start_time 49 | timing_handle.value = 0.99 * timing_handle.value + 0.01 * (elapsed_time * 1000) 50 | 51 | # Update visualizer. 52 | urdf_vis.update_cfg(solution) 53 | 54 | 55 | if __name__ == "__main__": 56 | main() 57 | -------------------------------------------------------------------------------- /examples/02_bimanual_ik.py: -------------------------------------------------------------------------------- 1 | """Bimanual IK 2 | 3 | Same as 01_basic_ik.py, but with two end effectors! 
4 | """ 5 | 6 | import time 7 | import viser 8 | from robot_descriptions.loaders.yourdfpy import load_robot_description 9 | import numpy as np 10 | 11 | import pyroki as pk 12 | from viser.extras import ViserUrdf 13 | import pyroki_snippets as pks 14 | 15 | 16 | def main(): 17 | """Main function for bimanual IK.""" 18 | 19 | urdf = load_robot_description("yumi_description") 20 | target_link_names = ["yumi_link_7_r", "yumi_link_7_l"] 21 | 22 | # Create robot. 23 | robot = pk.Robot.from_urdf(urdf) 24 | 25 | # Set up visualizer. 26 | server = viser.ViserServer() 27 | server.scene.add_grid("/ground", width=2, height=2) 28 | urdf_vis = ViserUrdf(server, urdf, root_node_name="/base") 29 | 30 | # Create interactive controller with initial position. 31 | ik_target_0 = server.scene.add_transform_controls( 32 | "/ik_target_0", scale=0.2, position=(0.41, -0.3, 0.56), wxyz=(0, 0, 1, 0) 33 | ) 34 | ik_target_1 = server.scene.add_transform_controls( 35 | "/ik_target_1", scale=0.2, position=(0.41, 0.3, 0.56), wxyz=(0, 0, 1, 0) 36 | ) 37 | timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True) 38 | 39 | while True: 40 | # Solve IK. 41 | start_time = time.time() 42 | solution = pks.solve_ik_with_multiple_targets( 43 | robot=robot, 44 | target_link_names=target_link_names, 45 | target_positions=np.array([ik_target_0.position, ik_target_1.position]), 46 | target_wxyzs=np.array([ik_target_0.wxyz, ik_target_1.wxyz]), 47 | ) 48 | 49 | # Update timing handle. 50 | elapsed_time = time.time() - start_time 51 | timing_handle.value = 0.99 * timing_handle.value + 0.01 * (elapsed_time * 1000) 52 | 53 | # Update visualizer. 54 | urdf_vis.update_cfg(solution) 55 | 56 | 57 | if __name__ == "__main__": 58 | main() 59 | -------------------------------------------------------------------------------- /examples/03_mobile_ik.py: -------------------------------------------------------------------------------- 1 | """Mobile IK 2 | 3 | Same as 01_basic_ik.py, but with a mobile base! 
4 | """ 5 | 6 | import time 7 | import viser 8 | from robot_descriptions.loaders.yourdfpy import load_robot_description 9 | import numpy as np 10 | 11 | import pyroki as pk 12 | from viser.extras import ViserUrdf 13 | import pyroki_snippets as pks 14 | 15 | 16 | def main(): 17 | """Main function for IK with a mobile base. 18 | The base is fixed along the xy plane, and is biased towards being at the origin. 19 | """ 20 | 21 | urdf = load_robot_description("fetch_description") 22 | target_link_name = "gripper_link" 23 | 24 | # Create robot. 25 | robot = pk.Robot.from_urdf(urdf) 26 | 27 | # Set up visualizer. 28 | server = viser.ViserServer() 29 | server.scene.add_grid("/ground", width=2, height=2) 30 | base_frame = server.scene.add_frame("/base", show_axes=False) 31 | urdf_vis = ViserUrdf(server, urdf, root_node_name="/base") 32 | 33 | # Create interactive controller with initial position. 34 | ik_target = server.scene.add_transform_controls( 35 | "/ik_target", scale=0.2, position=(0.61, 0.0, 0.56), wxyz=(0, 0.707, 0, -0.707) 36 | ) 37 | timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True) 38 | 39 | cfg = np.array(robot.joint_var_cls(0).default_factory()) 40 | 41 | while True: 42 | # Solve IK. 43 | start_time = time.time() 44 | base_pos, base_wxyz, cfg = pks.solve_ik_with_base( 45 | robot=robot, 46 | target_link_name=target_link_name, 47 | target_position=np.array(ik_target.position), 48 | target_wxyz=np.array(ik_target.wxyz), 49 | fix_base_position=(False, False, True), # Only free along xy plane. 50 | fix_base_orientation=(True, True, False), # Free along z-axis rotation. 51 | prev_pos=base_frame.position, 52 | prev_wxyz=base_frame.wxyz, 53 | prev_cfg=cfg, 54 | ) 55 | 56 | # Update timing handle. 57 | elapsed_time = time.time() - start_time 58 | timing_handle.value = 0.99 * timing_handle.value + 0.01 * (elapsed_time * 1000) 59 | 60 | # Update visualizer. 
61 | urdf_vis.update_cfg(cfg) 62 | base_frame.position = np.array(base_pos) 63 | base_frame.wxyz = np.array(base_wxyz) 64 | 65 | 66 | if __name__ == "__main__": 67 | main() 68 | -------------------------------------------------------------------------------- /examples/04_ik_with_coll.py: -------------------------------------------------------------------------------- 1 | """IK with Collision 2 | 3 | Basic Inverse Kinematics with Collision Avoidance using PyRoKi. 4 | """ 5 | 6 | import time 7 | 8 | import numpy as np 9 | import pyroki as pk 10 | import viser 11 | from pyroki.collision import HalfSpace, RobotCollision, Sphere 12 | from robot_descriptions.loaders.yourdfpy import load_robot_description 13 | from viser.extras import ViserUrdf 14 | 15 | import pyroki_snippets as pks 16 | 17 | 18 | def main(): 19 | """Main function for basic IK with collision.""" 20 | urdf = load_robot_description("panda_description") 21 | target_link_name = "panda_hand" 22 | robot = pk.Robot.from_urdf(urdf) 23 | 24 | robot_coll = RobotCollision.from_urdf(urdf) 25 | plane_coll = HalfSpace.from_point_and_normal( 26 | np.array([0.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0]) 27 | ) 28 | sphere_coll = Sphere.from_center_and_radius( 29 | np.array([0.0, 0.0, 0.0]), np.array([0.05]) 30 | ) 31 | 32 | # Set up visualizer. 33 | server = viser.ViserServer() 34 | server.scene.add_grid("/ground", width=2, height=2, cell_size=0.1) 35 | urdf_vis = ViserUrdf(server, urdf, root_node_name="/robot") 36 | 37 | # Create interactive controller for IK target. 38 | ik_target_handle = server.scene.add_transform_controls( 39 | "/ik_target", scale=0.2, position=(0.5, 0.0, 0.5), wxyz=(0, 0, 1, 0) 40 | ) 41 | 42 | # Create interactive controller and mesh for the sphere obstacle. 
43 | sphere_handle = server.scene.add_transform_controls( 44 | "/obstacle", scale=0.2, position=(0.4, 0.3, 0.4) 45 | ) 46 | server.scene.add_mesh_trimesh("/obstacle/mesh", mesh=sphere_coll.to_trimesh()) 47 | 48 | timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True) 49 | 50 | while True: 51 | start_time = time.time() 52 | # Re-pose the sphere obstacle at the current pose of its transform control. 53 | sphere_coll_world_current = sphere_coll.transform_from_wxyz_position( 54 | wxyz=np.array(sphere_handle.wxyz), 55 | position=np.array(sphere_handle.position), 56 | ) 57 | # World obstacles: the half-space plane and the sphere at its current pose. 58 | world_coll_list = [plane_coll, sphere_coll_world_current] 59 | solution = pks.solve_ik_with_collision( 60 | robot=robot, 61 | coll=robot_coll, 62 | world_coll_list=world_coll_list, 63 | target_link_name=target_link_name, 64 | target_position=np.array(ik_target_handle.position), 65 | target_wxyz=np.array(ik_target_handle.wxyz), 66 | ) 67 | 68 | # Update timing handle (raw duration of this iteration, in ms). 69 | timing_handle.value = (time.time() - start_time) * 1000 70 | 71 | # Update visualizer. 72 | urdf_vis.update_cfg(solution) 73 | 74 | 75 | if __name__ == "__main__": 76 | main() 77 | -------------------------------------------------------------------------------- /examples/05_ik_with_manipulability.py: -------------------------------------------------------------------------------- 1 | """IK with Manipulability 2 | 3 | Inverse Kinematics with Manipulability using PyRoKi. 4 | """ 5 | 6 | import time 7 | import viser 8 | from robot_descriptions.loaders.yourdfpy import load_robot_description 9 | import numpy as np 10 | 11 | import pyroki as pk 12 | from viser.extras import ViserUrdf 13 | import pyroki_snippets as pks 14 | 15 | 16 | def main(): 17 | """Main function for IK with a manipulability cost.""" 18 | 19 | urdf = load_robot_description("panda_description") 20 | target_link_name = "panda_hand" 21 | 22 | # Create robot. 23 | robot = pk.Robot.from_urdf(urdf) 24 | 25 | # Set up visualizer.
26 | server = viser.ViserServer() 27 | server.scene.add_grid("/ground", width=2, height=2) 28 | urdf_vis = ViserUrdf(server, urdf, root_node_name="/base") 29 | 30 | # Create interactive controller with initial position. 31 | ik_target = server.scene.add_transform_controls( 32 | "/ik_target", scale=0.2, position=(0.61, 0.0, 0.56), wxyz=(0, 0, 1, 0) 33 | ) 34 | timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True) 35 | value_handle = server.gui.add_number("Yoshikawa Index", 0.001, disabled=True) 36 | weight_handle = server.gui.add_slider( 37 | "Manipulability Weight", 0.0, 10.0, 0.001, 0.0 38 | ) 39 | manip_ellipse = pk.viewer.ManipulabilityEllipse( 40 | server, 41 | robot, 42 | root_node_name="/manipulability", 43 | target_link_name=target_link_name, 44 | ) 45 | 46 | while True: 47 | # Solve IK. 48 | start_time = time.time() 49 | solution = pks.solve_ik_with_manipulability( 50 | robot=robot, 51 | target_link_name=target_link_name, 52 | target_position=np.array(ik_target.position), 53 | target_wxyz=np.array(ik_target.wxyz), 54 | manipulability_weight=weight_handle.value, 55 | ) 56 | 57 | manip_ellipse.update(solution) 58 | value_handle.value = manip_ellipse.manipulability 59 | 60 | # Update timing handle. 61 | elapsed_time = time.time() - start_time 62 | timing_handle.value = 0.99 * timing_handle.value + 0.01 * (elapsed_time * 1000) 63 | 64 | # Update visualizer. 65 | urdf_vis.update_cfg(solution) 66 | 67 | 68 | if __name__ == "__main__": 69 | main() 70 | -------------------------------------------------------------------------------- /examples/06_online_planning.py: -------------------------------------------------------------------------------- 1 | """Online Planning 2 | 3 | Run online planning in collision aware environments. 
4 | """ 5 | 6 | import time 7 | 8 | import numpy as np 9 | import pyroki as pk 10 | import viser 11 | from pyroki.collision import HalfSpace, RobotCollision, Sphere 12 | from robot_descriptions.loaders.yourdfpy import load_robot_description 13 | from viser.extras import ViserUrdf 14 | 15 | import pyroki_snippets as pks 16 | 17 | 18 | def main(): 19 | """Main function for online planning with collision.""" 20 | urdf = load_robot_description("panda_description") 21 | target_link_name = "panda_hand" 22 | robot = pk.Robot.from_urdf(urdf) 23 | 24 | robot_coll = RobotCollision.from_urdf(urdf) 25 | plane_coll = HalfSpace.from_point_and_normal( 26 | np.array([0.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0]) 27 | ) 28 | sphere_coll = Sphere.from_center_and_radius( 29 | np.array([0.0, 0.0, 0.0]), np.array([0.05]) 30 | ) 31 | 32 | # Define the online planning parameters. 33 | len_traj, dt = 5, 0.1 34 | 35 | # Set up visualizer. 36 | server = viser.ViserServer() 37 | server.scene.add_grid("/ground", width=2, height=2, cell_size=0.1) 38 | urdf_vis = ViserUrdf(server, urdf, root_node_name="/robot") 39 | 40 | # Create interactive controller for IK target. 41 | ik_target_handle = server.scene.add_transform_controls( 42 | "/ik_target", scale=0.2, position=(0.3, 0.0, 0.5), wxyz=(0, 0, 1, 0) 43 | ) 44 | 45 | # Create interactive controller and mesh for the sphere obstacle. 
46 | sphere_handle = server.scene.add_transform_controls( 47 | "/obstacle", scale=0.2, position=(0.4, 0.3, 0.4) 48 | ) 49 | server.scene.add_mesh_trimesh("/obstacle/mesh", mesh=sphere_coll.to_trimesh()) 50 | target_frame_handle = server.scene.add_batched_axes( 51 | "target_frame", 52 | axes_length=0.05, 53 | axes_radius=0.005, 54 | batched_positions=np.zeros((25, 3)), 55 | batched_wxyzs=np.array([[1.0, 0.0, 0.0, 0.0]] * 25), 56 | ) 57 | 58 | timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True) 59 | 60 | sol_pos, sol_wxyz = None, None 61 | sol_traj = np.array( 62 | robot.joint_var_cls.default_factory()[None].repeat(len_traj, axis=0) 63 | ) 64 | while True: 65 | start_time = time.time() 66 | 67 | sphere_coll_world_current = sphere_coll.transform_from_wxyz_position( 68 | wxyz=np.array(sphere_handle.wxyz), 69 | position=np.array(sphere_handle.position), 70 | ) 71 | 72 | world_coll_list = [plane_coll, sphere_coll_world_current] 73 | sol_traj, sol_pos, sol_wxyz = pks.solve_online_planning( 74 | robot=robot, 75 | robot_coll=robot_coll, 76 | world_coll=world_coll_list, 77 | target_link_name=target_link_name, 78 | target_position=np.array(ik_target_handle.position), 79 | target_wxyz=np.array(ik_target_handle.wxyz), 80 | timesteps=len_traj, 81 | dt=dt, 82 | start_cfg=sol_traj[0], 83 | prev_sols=sol_traj, 84 | ) 85 | 86 | # Update timing handle. 87 | timing_handle.value = ( 88 | 0.99 * timing_handle.value + 0.01 * (time.time() - start_time) * 1000 89 | ) 90 | 91 | # Update visualizer. 92 | urdf_vis.update_cfg( 93 | sol_traj[0] 94 | ) # The first step of the online trajectory solution. 95 | 96 | # Update the planned trajectory visualization. 97 | if hasattr(target_frame_handle, "batched_positions"): 98 | target_frame_handle.batched_positions = np.array(sol_pos) # type: ignore[attr-defined] 99 | target_frame_handle.batched_wxyzs = np.array(sol_wxyz) # type: ignore[attr-defined] 100 | else: 101 | # This is an older version of Viser. 
102 | target_frame_handle.positions_batched = np.array(sol_pos) # type: ignore[attr-defined] 103 | target_frame_handle.wxyzs_batched = np.array(sol_wxyz) # type: ignore[attr-defined] 104 | 105 | 106 | if __name__ == "__main__": 107 | main() 108 | -------------------------------------------------------------------------------- /examples/07_trajopt.py: -------------------------------------------------------------------------------- 1 | """Trajectory Optimization 2 | 3 | Basic Trajectory Optimization using PyRoKi. 4 | 5 | Robot going over a wall, while avoiding world-collisions. 6 | """ 7 | 8 | import time 9 | from typing import Literal 10 | 11 | import numpy as np 12 | import pyroki as pk 13 | import trimesh 14 | import tyro 15 | import viser 16 | from viser.extras import ViserUrdf 17 | from robot_descriptions.loaders.yourdfpy import load_robot_description 18 | 19 | import pyroki_snippets as pks 20 | 21 | 22 | def main(robot_name: Literal["ur5", "panda"] = "panda"): 23 | if robot_name == "ur5": 24 | urdf = load_robot_description("ur5_description") 25 | down_wxyz = np.array([0.707, 0, 0.707, 0]) 26 | target_link_name = "ee_link" 27 | 28 | # For UR5 it's important to initialize the robot in a safe configuration; 29 | # the zero-configuration puts the robot aligned with the wall obstacle. 30 | default_cfg = np.zeros(6) 31 | default_cfg[1] = -1.308 32 | robot = pk.Robot.from_urdf(urdf, default_joint_cfg=default_cfg) 33 | 34 | elif robot_name == "panda": 35 | urdf = load_robot_description("panda_description") 36 | target_link_name = "panda_hand" 37 | down_wxyz = np.array([0, 0, 1, 0]) # for panda! 38 | robot = pk.Robot.from_urdf(urdf) 39 | 40 | else: 41 | raise ValueError(f"Invalid robot: {robot_name}") 42 | 43 | robot_coll = pk.collision.RobotCollision.from_urdf(urdf) 44 | 45 | # Define the trajectory problem: 46 | # - number of timesteps, timestep size 47 | timesteps, dt = 25, 0.02 48 | # - the start and end poses. 
49 | start_pos, end_pos = np.array([0.5, -0.3, 0.2]), np.array([0.5, 0.3, 0.2]) 50 | 51 | # Define the obstacles: 52 | # - Ground 53 | ground_coll = pk.collision.HalfSpace.from_point_and_normal( 54 | np.array([0.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0]) 55 | ) 56 | # - Wall 57 | wall_height = 0.4 58 | wall_width = 0.1 59 | wall_length = 0.4 60 | wall_intervals = np.arange(start=0.3, stop=wall_length + 0.3, step=0.05) 61 | translation = np.concatenate( 62 | [ 63 | wall_intervals.reshape(-1, 1), 64 | np.full((wall_intervals.shape[0], 1), 0.0), 65 | np.full((wall_intervals.shape[0], 1), wall_height / 2), 66 | ], 67 | axis=1, 68 | ) 69 | wall_coll = pk.collision.Capsule.from_radius_height( 70 | position=translation, 71 | radius=np.full((translation.shape[0], 1), wall_width / 2), 72 | height=np.full((translation.shape[0], 1), wall_height), 73 | ) 74 | world_coll = [ground_coll, wall_coll] 75 | 76 | traj = pks.solve_trajopt( 77 | robot, 78 | robot_coll, 79 | world_coll, 80 | target_link_name, 81 | start_pos, 82 | down_wxyz, 83 | end_pos, 84 | down_wxyz, 85 | timesteps, 86 | dt, 87 | ) 88 | traj = np.array(traj) 89 | 90 | # Visualize! 
91 | server = viser.ViserServer() 92 | urdf_vis = ViserUrdf(server, urdf) 93 | server.scene.add_grid("/grid", width=2, height=2, cell_size=0.1) 94 | server.scene.add_mesh_trimesh( 95 | "wall_box", 96 | trimesh.creation.box( 97 | extents=(wall_length, wall_width, wall_height), 98 | transform=trimesh.transformations.translation_matrix( 99 | np.array([0.5, 0.0, wall_height / 2]) 100 | ), 101 | ), 102 | ) 103 | for name, pos in zip(["start", "end"], [start_pos, end_pos]): 104 | server.scene.add_frame( 105 | f"/{name}", 106 | position=pos, 107 | wxyz=down_wxyz, 108 | axes_length=0.05, 109 | axes_radius=0.01, 110 | ) 111 | 112 | slider = server.gui.add_slider( 113 | "Timestep", min=0, max=timesteps - 1, step=1, initial_value=0 114 | ) 115 | playing = server.gui.add_checkbox("Playing", initial_value=True) 116 | 117 | while True: 118 | if playing.value: 119 | slider.value = (slider.value + 1) % timesteps 120 | 121 | urdf_vis.update_cfg(traj[slider.value]) 122 | time.sleep(1.0 / 10.0) 123 | 124 | 125 | if __name__ == "__main__": 126 | tyro.cli(main) 127 | -------------------------------------------------------------------------------- /examples/08_ik_with_mimic_joints.py: -------------------------------------------------------------------------------- 1 | """IK with Mimic Joints 2 | 3 | This is a simple test to ensure that mimic joints are handled correctly in the IK solver. 4 | 5 | We procedurally generate a "zig-zag" chain of links with mimic joints, where: 6 | - the first joint is driven directly, 7 | - and the remaining joints are driven indirectly via mimic joints. 8 | The multipliers alternate between -1 and 1, and the offsets are all 0. 
9 | """ 10 | 11 | import time 12 | 13 | import numpy as np 14 | import pyroki as pk 15 | import viser 16 | from viser.extras import ViserUrdf 17 | 18 | import pyroki_snippets as pks 19 | 20 | 21 | def create_chain_xml(length: float = 0.2, num_chains: int = 5) -> str: 22 | def create_link(idx): 23 | return f""" 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | """ 36 | 37 | def create_joint(idx, multiplier=1.0, offset=0.0): 38 | mimic = f'' 39 | return f""" 40 | 41 | 42 | 43 | 44 | 45 | {mimic if idx != 0 else ""} 46 | 47 | 48 | """ 49 | 50 | world_joint_origin_z = length / 2.0 51 | xml = f""" 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | """ 63 | # Create the definition + first link. 64 | xml += create_link(0) 65 | xml += create_link(1) 66 | xml += create_joint(0) 67 | 68 | # Procedurally add more links; the mimic multiplier alternates in sign with the link index. 69 | assert num_chains >= 2 70 | for idx in range(2, num_chains): 71 | xml += create_link(idx) 72 | current_offset = 0.0 73 | current_multiplier = 1.0 * ((-1) ** (idx % 2)) 74 | xml += create_joint(idx - 1, current_multiplier, current_offset) 75 | 76 | xml += """ 77 | 78 | """ 79 | return xml 80 | 81 | 82 | def main(): 83 | """Main function for IK on a procedurally generated chain with mimic joints.""" 84 | 85 | import yourdfpy 86 | import tempfile 87 | 88 | xml = create_chain_xml(num_chains=10, length=0.1) 89 | with tempfile.NamedTemporaryFile(mode="w", suffix=".urdf") as f: 90 | f.write(xml) 91 | f.flush() 92 | urdf = yourdfpy.URDF.load(f.name) 93 | 94 | # Create robot. 95 | robot = pk.Robot.from_urdf(urdf) 96 | 97 | # Set up visualizer. 98 | server = viser.ViserServer() 99 | server.scene.add_grid("/ground", width=2, height=2) 100 | urdf_vis = ViserUrdf(server, urdf, root_node_name="/base") 101 | target_link_name_handle = server.gui.add_dropdown( 102 | "Target Link", 103 | robot.links.names, 104 | initial_value=robot.links.names[-1], 105 | ) 106 | 107 | # Create interactive controller with initial position.
108 | ik_target = server.scene.add_transform_controls( 109 | "/ik_target", scale=0.2, position=(0.0, 0.1, 0.1), wxyz=(0, 0, 1, 0) 110 | ) 111 | timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True) 112 | 113 | while True: 114 | # Solve IK. 115 | start_time = time.time() 116 | solution = pks.solve_ik( 117 | robot=robot, 118 | target_link_name=target_link_name_handle.value, 119 | target_position=np.array(ik_target.position), 120 | target_wxyz=np.array(ik_target.wxyz), 121 | ) 122 | 123 | # Update timing handle. 124 | elapsed_time = time.time() - start_time 125 | timing_handle.value = 0.99 * timing_handle.value + 0.01 * (elapsed_time * 1000) 126 | 127 | # Update visualizer. 128 | urdf_vis.update_cfg(solution) 129 | 130 | 131 | if __name__ == "__main__": 132 | main() 133 | -------------------------------------------------------------------------------- /examples/pyroki_snippets/__init__.py: -------------------------------------------------------------------------------- 1 | from ._online_planning import solve_online_planning as solve_online_planning 2 | from ._solve_ik import solve_ik as solve_ik 3 | from ._solve_ik_with_base import solve_ik_with_base as solve_ik_with_base 4 | from ._solve_ik_with_collision import solve_ik_with_collision as solve_ik_with_collision 5 | from ._solve_ik_with_manipulability import ( 6 | solve_ik_with_manipulability as solve_ik_with_manipulability, 7 | ) 8 | from ._trajopt import solve_trajopt as solve_trajopt 9 | from ._solve_ik_with_multiple_targets import ( 10 | solve_ik_with_multiple_targets as solve_ik_with_multiple_targets, 11 | ) 12 | -------------------------------------------------------------------------------- /examples/pyroki_snippets/_online_planning.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | import jax_dataclasses as jdc 6 | import jaxlie 7 | import jaxls 8 | import numpy as onp 9 | 
import pyroki as pk 10 | 11 | 12 | def solve_online_planning( 13 | robot: pk.Robot, 14 | robot_coll: pk.collision.RobotCollision, 15 | world_coll: Sequence[pk.collision.CollGeom], 16 | target_link_name: str, 17 | target_position: onp.ndarray, 18 | target_wxyz: onp.ndarray, 19 | timesteps: int, 20 | dt: float, 21 | start_cfg: onp.ndarray, 22 | prev_sols: onp.ndarray, 23 | ) -> tuple[onp.ndarray, onp.ndarray, onp.ndarray]: 24 | """Solve online planning with collision.""" 25 | 26 | assert target_position.shape == (3,) and target_wxyz.shape == (4,) 27 | target_link_indices = [robot.links.names.index(target_link_name)] 28 | 29 | target_poses = jaxlie.SE3( 30 | jnp.concatenate([jnp.array(target_wxyz), jnp.array(target_position)], axis=-1) 31 | ) 32 | target_links = jnp.array(target_link_indices) 33 | 34 | # Warm start: use previous solution shifted by one step. 35 | timesteps = timesteps + 1 # for start pose cost. 36 | 37 | sol_traj, sol_pos, sol_wxyz = _solve_online_planning_jax( 38 | robot, 39 | robot_coll, 40 | world_coll, 41 | target_poses, 42 | target_links, 43 | timesteps, 44 | dt, 45 | jnp.array(start_cfg), 46 | jnp.concatenate([prev_sols, prev_sols[-1:]], axis=0), 47 | ) 48 | sol_traj = sol_traj[1:] 49 | sol_pos = sol_pos[1:] 50 | sol_wxyz = sol_wxyz[1:] 51 | 52 | return onp.array(sol_traj), onp.array(sol_pos), onp.array(sol_wxyz) 53 | 54 | 55 | @jdc.jit 56 | def _solve_online_planning_jax( 57 | robot: pk.Robot, 58 | robot_coll: pk.collision.RobotCollision, 59 | world_coll: Sequence[pk.collision.CollGeom], 60 | target_poses: jaxlie.SE3, 61 | target_links: jnp.ndarray, 62 | timesteps: jdc.Static[int], 63 | dt: float, 64 | start_cfg: jnp.ndarray, 65 | prev_sols: jnp.ndarray, 66 | ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: 67 | num_targets = len(target_links) 68 | 69 | def batched_rplus( 70 | pose: jaxlie.SE3, 71 | delta: jax.Array, 72 | ) -> jaxlie.SE3: 73 | return jax.vmap(jaxlie.manifold.rplus)(pose, delta.reshape(num_targets, -1)) 74 | 75 | # Custom SE3 
variable to batch across multiple joint targets. 76 | # This is not to be confused with SE3Vars with ids, which we use here for timesteps. 77 | class BatchedSE3Var( # pylint: disable=missing-class-docstring 78 | jaxls.Var[jaxlie.SE3], 79 | default_factory=lambda: jaxlie.SE3.identity((num_targets,)), 80 | retract_fn=batched_rplus, 81 | tangent_dim=jaxlie.SE3.tangent_dim * num_targets, 82 | ): ... 83 | 84 | # --- Define Variables --- 85 | traj_var = robot.joint_var_cls(jnp.arange(0, timesteps)) 86 | traj_var_prev = robot.joint_var_cls(jnp.arange(0, timesteps - 1)) 87 | traj_var_next = robot.joint_var_cls(jnp.arange(1, timesteps)) 88 | pose_var = BatchedSE3Var(jnp.arange(0, timesteps)) 89 | pose_var_prev = BatchedSE3Var(jnp.arange(0, timesteps - 1)) 90 | pose_var_next = BatchedSE3Var(jnp.arange(1, timesteps)) 91 | 92 | init_pose_vals = jaxlie.SE3( 93 | robot.forward_kinematics(prev_sols)[..., target_links, :] 94 | ) 95 | 96 | # --- Define Costs --- 97 | factors: list[jaxls.Cost] = [] # Cost terms of the least-squares problem. 98 | # Ties a joint configuration to a pose variable via forward kinematics of the target links. 99 | @jaxls.Cost.create_factory(name="SE3PoseMatchJointCost") 100 | def match_joint_to_pose_cost( 101 | vals: jaxls.VarValues, 102 | joint_var: jaxls.Var[jnp.ndarray], 103 | pose_var: BatchedSE3Var, 104 | ): 105 | joint_cfg = vals[joint_var] 106 | target_pose = vals[pose_var] 107 | Ts_joint_world = robot.forward_kinematics(joint_cfg) 108 | residual = ( 109 | (jaxlie.SE3(Ts_joint_world[..., target_links, :])).inverse() @ (target_pose) 110 | ).log() 111 | return residual.flatten() * 100.0 112 | # Penalizes the relative transform between two pose variables. 113 | @jaxls.Cost.create_factory(name="SE3SmoothnessCost") 114 | def pose_smoothness_cost( 115 | vals: jaxls.VarValues, 116 | pose_var: BatchedSE3Var, 117 | pose_var_prev: BatchedSE3Var, 118 | ): 119 | return (vals[pose_var].inverse() @ vals[pose_var_prev]).log().flatten() * 1.0 120 | # Penalizes the relative transform between a pose variable and the target poses. 121 | @jaxls.Cost.create_factory(name="SE3PoseMatchCost") 122 | def pose_match_cost( 123 | vals: jaxls.VarValues, 124 | pose_var: BatchedSE3Var, 125 | ): 126 | return ( 127 |
(vals[pose_var].inverse() @ target_poses).log() 128 | * jnp.array([50.0] * 3 + [20.0] * 3) 129 | ).flatten() 130 | 131 | @jaxls.Cost.create_factory(name="MatchStartPoseCost") 132 | def match_start_pose_cost( 133 | vals: jaxls.VarValues, 134 | joint_var: jaxls.Var[jnp.ndarray], 135 | ): 136 | return (vals[joint_var] - start_cfg).flatten() * 100.0 137 | 138 | # Add pose costs. 139 | factors.extend( 140 | [ 141 | pose_match_cost( 142 | BatchedSE3Var(timesteps - 1), 143 | ), 144 | pose_smoothness_cost( 145 | pose_var_next, 146 | pose_var_prev, 147 | ), 148 | ] 149 | ) 150 | 151 | # Need to constrain the start joint cfg. 152 | factors.append(match_start_pose_cost(robot.joint_var_cls(0))) 153 | 154 | # Add joint costs. 155 | factors.extend( 156 | [ 157 | match_joint_to_pose_cost( 158 | traj_var, 159 | pose_var, 160 | ), 161 | pk.costs.smoothness_cost( 162 | traj_var_prev, 163 | traj_var_next, 164 | weight=10.0, 165 | ), 166 | pk.costs.limit_velocity_cost( 167 | jax.tree.map(lambda x: x[None], robot), 168 | traj_var_prev, 169 | traj_var_next, 170 | weight=10.0, 171 | dt=dt, 172 | ), 173 | pk.costs.limit_cost( 174 | jax.tree.map(lambda x: x[None], robot), 175 | traj_var, 176 | weight=100.0, 177 | ), 178 | pk.costs.rest_cost( 179 | traj_var, 180 | jnp.array(traj_var.default_factory())[None], 181 | weight=0.01, 182 | ), 183 | pk.costs.manipulability_cost( 184 | jax.tree.map(lambda x: x[None], robot), 185 | traj_var, 186 | weight=0.01, 187 | target_link_indices=target_links, 188 | ), 189 | pk.costs.self_collision_cost( 190 | jax.tree.map(lambda x: x[None], robot), 191 | jax.tree.map(lambda x: x[None], robot_coll), 192 | traj_var, 193 | weight=10.0, 194 | margin=0.02, 195 | ), 196 | ] 197 | ) 198 | factors.extend( 199 | [ 200 | pk.costs.world_collision_cost( 201 | jax.tree.map(lambda x: x[None], robot), 202 | jax.tree.map(lambda x: x[None], robot_coll), 203 | traj_var, 204 | jax.tree.map(lambda x: x[None], obs), 205 | weight=20.0, 206 | margin=0.1, 207 | ) 208 | for obs in 
world_coll 209 | ] 210 | ) 211 | # Initialize from prev_sols and the poses derived from it; the solve is capped at 20 iterations. 212 | solution = ( 213 | jaxls.LeastSquaresProblem(factors, [traj_var, pose_var]) 214 | .analyze() 215 | .solve( 216 | verbose=False, 217 | initial_vals=jaxls.VarValues.make( 218 | (traj_var.with_value(prev_sols), pose_var.with_value(init_pose_vals)) 219 | ), 220 | termination=jaxls.TerminationConfig(max_iterations=20), 221 | ) 222 | ) 223 | pose_traj = solution[pose_var] 224 | return ( 225 | solution[traj_var], 226 | pose_traj.translation(), 227 | pose_traj.rotation().wxyz, 228 | ) 229 | -------------------------------------------------------------------------------- /examples/pyroki_snippets/_solve_ik.py: -------------------------------------------------------------------------------- 1 | """ 2 | Solves the basic IK problem. 3 | """ 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import jax_dataclasses as jdc 8 | import jaxlie 9 | import jaxls 10 | import numpy as onp 11 | import pyroki as pk 12 | 13 | 14 | def solve_ik( 15 | robot: pk.Robot, 16 | target_link_name: str, 17 | target_wxyz: onp.ndarray, 18 | target_position: onp.ndarray, 19 | ) -> onp.ndarray: 20 | """ 21 | Solves the basic IK problem for a robot. 22 | 23 | Args: 24 | robot: PyRoKi Robot. 25 | target_link_name: String name of the link to be controlled. 26 | target_wxyz: onp.ndarray. Target orientation as a wxyz quaternion. Shape: (4,). 27 | target_position: onp.ndarray. Target position. Shape: (3,). 28 | 29 | Returns: 30 | cfg: onp.ndarray. Shape: (robot.joints.num_actuated_joints,).
31 | """ 32 | assert target_position.shape == (3,) and target_wxyz.shape == (4,) 33 | target_link_index = robot.links.names.index(target_link_name) 34 | cfg = _solve_ik_jax( 35 | robot, 36 | jnp.array(target_link_index), 37 | jnp.array(target_wxyz), 38 | jnp.array(target_position), 39 | ) 40 | assert cfg.shape == (robot.joints.num_actuated_joints,) 41 | return onp.array(cfg) 42 | 43 | 44 | @jdc.jit 45 | def _solve_ik_jax( 46 | robot: pk.Robot, 47 | target_link_index: jax.Array, 48 | target_wxyz: jax.Array, 49 | target_position: jax.Array, 50 | ) -> jax.Array: 51 | joint_var = robot.joint_var_cls(0) 52 | factors = [ 53 | pk.costs.pose_cost_analytic_jac( 54 | robot, 55 | joint_var, 56 | jaxlie.SE3.from_rotation_and_translation( 57 | jaxlie.SO3(target_wxyz), target_position 58 | ), 59 | target_link_index, 60 | pos_weight=50.0, 61 | ori_weight=10.0, 62 | ), 63 | pk.costs.limit_cost( 64 | robot, 65 | joint_var, 66 | weight=100.0, 67 | ), 68 | ] 69 | sol = ( 70 | jaxls.LeastSquaresProblem(factors, [joint_var]) 71 | .analyze() 72 | .solve( 73 | verbose=False, 74 | linear_solver="dense_cholesky", 75 | trust_region=jaxls.TrustRegionConfig(lambda_initial=1.0), 76 | ) 77 | ) 78 | return sol[joint_var] 79 | -------------------------------------------------------------------------------- /examples/pyroki_snippets/_solve_ik_with_base.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import jax_dataclasses as jdc 4 | import jaxlie 5 | import jaxls 6 | 7 | import numpy as onp 8 | 9 | import pyroki as pk 10 | 11 | 12 | def solve_ik_with_base( 13 | robot: pk.Robot, 14 | target_link_name: str, 15 | target_position: onp.ndarray, 16 | target_wxyz: onp.ndarray, 17 | fix_base_position: tuple[bool, bool, bool], 18 | fix_base_orientation: tuple[bool, bool, bool], 19 | prev_pos: onp.ndarray, 20 | prev_wxyz: onp.ndarray, 21 | prev_cfg: onp.ndarray, 22 | ) -> tuple[onp.ndarray, onp.ndarray, onp.ndarray]: 23 |
""" 24 | Solves the basic IK problem for a robot with a mobile base. 25 | 26 | Args: 27 | robot: PyRoKi Robot. 28 | target_link_name: str. Name of the link to be controlled. 29 | target_position: onp.ndarray. Shape: (3,). 30 | target_wxyz: onp.ndarray. Shape: (4,). 31 | fix_base_position: Whether to fix the base position (x, y, z). 32 | fix_base_orientation: Whether to fix the base orientation (w_x, w_y, w_z). 33 | prev_pos, prev_wxyz, prev_cfg: Previous base position, orientation, and joint configuration, used to initialize the solve for smooth motion. 34 | 35 | Returns: 36 | base_pos: onp.ndarray. Shape: (3,). 37 | base_wxyz: onp.ndarray. Shape: (4,). 38 | cfg: onp.ndarray. Shape: (robot.joints.num_actuated_joints,). 39 | """ 40 | assert target_position.shape == (3,) and target_wxyz.shape == (4,) 41 | assert prev_pos.shape == (3,) and prev_wxyz.shape == (4,) 42 | assert prev_cfg.shape == (robot.joints.num_actuated_joints,) 43 | target_link_idx = robot.links.names.index(target_link_name) 44 | # Pack the target pose as an SE3 from the concatenated [wxyz, position] vector. 45 | T_world_targets = jaxlie.SE3( 46 | jnp.concatenate([jnp.array(target_wxyz), jnp.array(target_position)], axis=-1) 47 | ) 48 | base_pose, cfg = _solve_ik_jax( 49 | robot, 50 | T_world_targets, 51 | jnp.array(target_link_idx), 52 | jnp.array(fix_base_position + fix_base_orientation), # 6 flags: position (x, y, z) then orientation. 53 | jnp.array(prev_pos), 54 | jnp.array(prev_wxyz), 55 | jnp.array(prev_cfg), 56 | ) 57 | assert cfg.shape == (robot.joints.num_actuated_joints,) 58 | 59 | base_pos = base_pose.translation() 60 | base_wxyz = base_pose.rotation().wxyz 61 | assert base_pos.shape == (3,) and base_wxyz.shape == (4,) 62 | 63 | return onp.array(base_pos), onp.array(base_wxyz), onp.array(cfg) 64 | 65 | 66 | @jdc.jit 67 | def _solve_ik_jax( 68 | robot: pk.Robot, 69 | T_world_target: jaxlie.SE3, 70 | target_joint_indices: jnp.ndarray, 71 | fix_base: jnp.ndarray, 72 | prev_pos: jnp.ndarray, 73 | prev_wxyz: jnp.ndarray, 74 | prev_cfg: jnp.ndarray, 75 | ) -> tuple[jaxlie.SE3, jax.Array]: 76 | joint_var = robot.joint_var_cls(0) 77 | 78 | def retract_fn(transform: jaxlie.SE3, delta: jax.Array) ->
jaxlie.SE3: 79 | """Same as jaxls.SE3Var.retract_fn, but removing updates on certain axes.""" 80 | delta = delta * (1 - fix_base) 81 | return jaxls.SE3Var.retract_fn(transform, delta) 82 | 83 | class ConstrainedSE3Var( 84 | jaxls.Var[jaxlie.SE3], 85 | default_factory=lambda: jaxlie.SE3.from_rotation_and_translation( 86 | jaxlie.SO3(prev_wxyz), 87 | prev_pos, 88 | ), 89 | tangent_dim=jaxlie.SE3.tangent_dim, 90 | retract_fn=retract_fn, 91 | ): ... 92 | 93 | base_var = ConstrainedSE3Var(0) 94 | 95 | factors = [ 96 | pk.costs.pose_cost_with_base( 97 | robot, 98 | joint_var, 99 | base_var, 100 | T_world_target, 101 | target_joint_indices, 102 | pos_weight=jnp.array(5.0), 103 | ori_weight=jnp.array(1.0), 104 | ), 105 | pk.costs.limit_cost( 106 | robot, 107 | joint_var, 108 | jnp.array(100.0), 109 | ), 110 | pk.costs.rest_with_base_cost( 111 | joint_var, 112 | base_var, 113 | jnp.array(joint_var.default_factory()), 114 | jnp.array( 115 | [0.01] * robot.joints.num_actuated_joints 116 | + [0.1] * 3 # Base position DoF. 117 | + [0.001] * 3, # Base orientation DoF. 118 | ), 119 | ), 120 | ] 121 | sol = ( 122 | jaxls.LeastSquaresProblem(factors, [joint_var, base_var]) 123 | .analyze() 124 | .solve( 125 | initial_vals=jaxls.VarValues.make( 126 | [joint_var.with_value(prev_cfg), base_var] 127 | ), 128 | verbose=False, 129 | ) 130 | ) 131 | return sol[base_var], sol[joint_var] 132 | -------------------------------------------------------------------------------- /examples/pyroki_snippets/_solve_ik_with_collision.py: -------------------------------------------------------------------------------- 1 | """ 2 | Solves the basic IK problem with collision avoidance. 
3 | """ 4 | 5 | from typing import Sequence 6 | 7 | import jax 8 | import jax.numpy as jnp 9 | import jax_dataclasses as jdc 10 | import jaxlie 11 | import jaxls 12 | import numpy as onp 13 | import pyroki as pk 14 | 15 | 16 | def solve_ik_with_collision( 17 | robot: pk.Robot, 18 | coll: pk.collision.RobotCollision, 19 | world_coll_list: Sequence[pk.collision.CollGeom], 20 | target_link_name: str, 21 | target_position: onp.ndarray, 22 | target_wxyz: onp.ndarray, 23 | ) -> onp.ndarray: 24 | """Solves the basic IK problem for a robot, with collision avoidance. 25 | 26 | Args: 27 | robot: PyRoKi Robot. 28 | coll: Collision model of the robot. 29 | world_coll_list: World collision geometries; one world-collision cost is added per entry. 30 | target_link_name: str. Name of the link to be controlled. 31 | target_position, target_wxyz: onp.ndarray. Shapes: (3,) and (4,). 32 | 33 | Returns: 34 | cfg: onp.ndarray. Shape: (robot.joints.num_actuated_joints,). 35 | """ 36 | assert target_position.shape == (3,) and target_wxyz.shape == (4,) 37 | target_link_idx = robot.links.names.index(target_link_name) 38 | 39 | T_world_targets = jaxlie.SE3( 40 | jnp.concatenate([jnp.array(target_wxyz), jnp.array(target_position)], axis=-1) 41 | ) 42 | cfg = _solve_ik_with_collision_jax( 43 | robot, 44 | coll, 45 | world_coll_list, 46 | T_world_targets, 47 | jnp.array(target_link_idx), 48 | ) 49 | assert cfg.shape == (robot.joints.num_actuated_joints,) 50 | 51 | return onp.array(cfg) 52 | 53 | 54 | @jdc.jit 55 | def _solve_ik_with_collision_jax( 56 | robot: pk.Robot, 57 | coll: pk.collision.RobotCollision, 58 | world_coll_list: Sequence[pk.collision.CollGeom], 59 | T_world_target: jaxlie.SE3, 60 | target_link_index: jax.Array, 61 | ) -> jax.Array: 62 | """Solves the basic IK problem with collision avoidance.
def solve_ik_with_manipulability(
    robot: pk.Robot,
    target_link_name: str,
    target_position: onp.ndarray,
    target_wxyz: onp.ndarray,
    manipulability_weight: float = 0.0,
) -> onp.ndarray:
    """
    Solves the basic IK problem for a robot, with manipulability cost.

    Args:
        robot: PyRoKi Robot.
        target_link_name: str. Must be an entry of `robot.links.names`.
        target_position: onp.ndarray. Shape: (3,).
        target_wxyz: onp.ndarray. Shape: (4,).
        manipulability_weight: float. Weight for the manipulability cost.

    Returns:
        cfg: onp.ndarray. Shape: (robot.joints.num_actuated_joints,).
    """
    # Only a single (unbatched) target is supported. This check used to be
    # written twice in a row; one copy is sufficient.
    assert target_position.shape == (3,) and target_wxyz.shape == (4,)
    target_link_idx = robot.links.names.index(target_link_name)

    # Pack the pose as quaternion (wxyz) followed by translation (xyz).
    T_world_target = jaxlie.SE3(
        jnp.concatenate([jnp.array(target_wxyz), jnp.array(target_position)], axis=-1)
    )
    cfg = _solve_ik_jax(
        robot,
        T_world_target,
        jnp.array(target_link_idx),
        # Passed as an array rather than a Python float.
        jnp.array(manipulability_weight),
    )
    assert cfg.shape == (robot.joints.num_actuated_joints,)

    return onp.array(cfg)
def solve_ik_with_multiple_targets(
    robot: pk.Robot,
    target_link_names: Sequence[str],
    target_wxyzs: onp.ndarray,
    target_positions: onp.ndarray,
) -> onp.ndarray:
    """
    Solves IK for one robot configuration that reaches several link targets at once.

    Args:
        robot: PyRoKi Robot.
        target_link_names: Sequence[str]. One entry of `robot.links.names` per target.
        target_wxyzs: onp.ndarray. Shape: (num_targets, 4). Target orientations.
        target_positions: onp.ndarray. Shape: (num_targets, 3). Target positions.

    Returns:
        cfg: onp.ndarray. Shape: (robot.joints.num_actuated_joints,).
    """
    # The number of link names fixes the expected leading dimension of both arrays.
    target_count = len(target_link_names)
    assert target_positions.shape == (target_count, 3)
    assert target_wxyzs.shape == (target_count, 4)

    # Translate link names into indices into `robot.links.names`.
    link_indices = list(map(robot.links.names.index, target_link_names))

    solution = _solve_ik_jax(
        robot,
        jnp.array(target_wxyzs),
        jnp.array(target_positions),
        jnp.array(link_indices),
    )
    assert solution.shape == (robot.joints.num_actuated_joints,)

    return onp.array(solution)
def solve_trajopt(
    robot: pk.Robot,
    robot_coll: pk.collision.RobotCollision,
    world_coll: Sequence[pk.collision.CollGeom],
    target_link_name: str,
    start_position: ArrayLike,
    start_wxyz: ArrayLike,
    end_position: ArrayLike,
    end_wxyz: ArrayLike,
    timesteps: int,
    dt: float,
) -> ArrayLike:
    """
    Plans a joint-space trajectory between two end-effector poses.

    The start and end configurations are found with collision-aware IK, the
    trajectory is initialized by linear interpolation between them, and the
    result is refined with a least-squares problem that combines rest, limit,
    swept-collision, endpoint, and smoothness costs.

    Args:
        robot: PyRoKi Robot.
        robot_coll: Collision model of the robot.
        world_coll: World obstacles; one swept-collision cost is added per entry.
        target_link_name: Name of the controlled link; must be in `robot.links.names`.
        start_position: Start position of the target link. Must be a NumPy or JAX array.
        start_wxyz: Start orientation (wxyz quaternion) of the target link.
        end_position: End position of the target link.
        end_wxyz: End orientation (wxyz quaternion) of the target link.
        timesteps: Number of trajectory waypoints.
        dt: Time step passed to the velocity / acceleration / jerk costs.

    Returns:
        The optimized trajectory, as a NumPy array if `start_position` is a
        NumPy array and as a JAX array if it is a JAX array.
        Presumably of shape (timesteps, actuated_count) -- TODO confirm.

    Raises:
        ValueError: If `start_position` is neither a NumPy nor a JAX array.
    """
    # Pick the array library of the return value from the type of `start_position`.
    if isinstance(start_position, onp.ndarray):
        np = onp
    elif isinstance(start_position, jnp.ndarray):
        np = jnp
    else:
        raise ValueError(f"Invalid type for `ArrayLike`: {type(start_position)}")

    # 1. Solve IK for the start and end poses.
    target_link_index = robot.links.names.index(target_link_name)
    start_cfg, end_cfg = solve_iks_with_collision(
        robot=robot,
        coll=robot_coll,
        world_coll_list=world_coll,
        target_link_index=target_link_index,
        target_position_0=jnp.array(start_position),
        target_wxyz_0=jnp.array(start_wxyz),
        target_position_1=jnp.array(end_position),
        target_wxyz_1=jnp.array(end_wxyz),
    )

    # 2. Initialize the trajectory through linearly interpolating the start and end poses.
    init_traj = jnp.linspace(start_cfg, end_cfg, timesteps)

    # 3. Optimize the trajectory.
    # One joint variable per timestep, with ids 0..timesteps-1.
    traj_vars = robot.joint_var_cls(jnp.arange(timesteps))

    # NOTE: `robot` and `robot_coll` are rebound here; everything below uses
    # the versions with a leading axis of size 1.
    robot = jax.tree.map(lambda x: x[None], robot)  # Add batch dimension.
    robot_coll = jax.tree.map(lambda x: x[None], robot_coll)  # Add batch dimension.

    # Basic regularization / limit costs.
    factors: list[jaxls.Cost] = [
        pk.costs.rest_cost(
            traj_vars,
            traj_vars.default_factory()[None],
            jnp.array([0.01])[None],
        ),
        pk.costs.limit_cost(
            robot,
            traj_vars,
            jnp.array([100.0])[None],
        ),
    ]

    # Collision avoidance.
    def compute_world_coll_residual(
        vals: jaxls.VarValues,
        robot: pk.Robot,
        robot_coll: pk.collision.RobotCollision,
        world_coll_obj: pk.collision.CollGeom,
        prev_traj_vars: jaxls.Var[jax.Array],
        curr_traj_vars: jaxls.Var[jax.Array],
    ):
        # Residual for one obstacle: capsules swept between the previous and
        # current configuration are collided against the obstacle, and the
        # resulting distances are turned into a cost with margin 0.1 and
        # weight 20.0.
        coll = robot_coll.get_swept_capsules(
            robot, vals[prev_traj_vars], vals[curr_traj_vars]
        )
        dist = pk.collision.collide(
            coll.reshape((-1, 1)), world_coll_obj.reshape((1, -1))
        )
        colldist = pk.collision.colldist_from_sdf(dist, 0.1)
        return (colldist * 20.0).flatten()

    for world_coll_obj in world_coll:
        factors.append(
            jaxls.Cost(
                compute_world_coll_residual,
                (
                    robot,
                    robot_coll,
                    jax.tree.map(lambda x: x[None], world_coll_obj),
                    # Consecutive waypoint pairs (t, t + 1).
                    robot.joint_var_cls(jnp.arange(0, timesteps - 1)),
                    robot.joint_var_cls(jnp.arange(1, timesteps)),
                ),
                name="World Collision (sweep)",
            )
        )

    # Start / end pose constraints.
    # The first two waypoints are pulled toward `start_cfg` and the last two
    # toward `end_cfg`, each with weight 100.0.
    factors.extend(
        [
            jaxls.Cost(
                lambda vals, var: ((vals[var] - start_cfg) * 100.0).flatten(),
                (robot.joint_var_cls(jnp.arange(0, 2)),),
                name="start_pose_constraint",
            ),
            jaxls.Cost(
                lambda vals, var: ((vals[var] - end_cfg) * 100.0).flatten(),
                (robot.joint_var_cls(jnp.arange(timesteps - 2, timesteps)),),
                name="end_pose_constraint",
            ),
        ]
    )

    # Velocity / acceleration / jerk minimization.
    # Each cost receives index-shifted views of the same trajectory variables.
    # NOTE(review): the jerk cost uses offsets up to 6, so this presumably
    # expects `timesteps` to be at least 7 -- confirm against the cost helpers.
    factors.extend(
        [
            pk.costs.smoothness_cost(
                robot.joint_var_cls(jnp.arange(1, timesteps)),
                robot.joint_var_cls(jnp.arange(0, timesteps - 1)),
                jnp.array([0.1])[None],
            ),
            pk.costs.five_point_velocity_cost(
                robot,
                robot.joint_var_cls(jnp.arange(4, timesteps)),
                robot.joint_var_cls(jnp.arange(3, timesteps - 1)),
                robot.joint_var_cls(jnp.arange(1, timesteps - 3)),
                robot.joint_var_cls(jnp.arange(0, timesteps - 4)),
                dt,
                jnp.array([10.0])[None],
            ),
            pk.costs.five_point_acceleration_cost(
                robot.joint_var_cls(jnp.arange(2, timesteps - 2)),
                robot.joint_var_cls(jnp.arange(4, timesteps)),
                robot.joint_var_cls(jnp.arange(3, timesteps - 1)),
                robot.joint_var_cls(jnp.arange(1, timesteps - 3)),
                robot.joint_var_cls(jnp.arange(0, timesteps - 4)),
                dt,
                jnp.array([0.1])[None],
            ),
            pk.costs.five_point_jerk_cost(
                robot.joint_var_cls(jnp.arange(6, timesteps)),
                robot.joint_var_cls(jnp.arange(5, timesteps - 1)),
                robot.joint_var_cls(jnp.arange(4, timesteps - 2)),
                robot.joint_var_cls(jnp.arange(2, timesteps - 4)),
                robot.joint_var_cls(jnp.arange(1, timesteps - 5)),
                robot.joint_var_cls(jnp.arange(0, timesteps - 6)),
                dt,
                jnp.array([0.1])[None],
            ),
        ]
    )

    # 4. Solve the optimization problem.
    # The solver starts from the linearly interpolated trajectory.
    solution = (
        jaxls.LeastSquaresProblem(
            factors,
            [traj_vars],
        )
        .analyze()
        .solve(
            initial_vals=jaxls.VarValues.make((traj_vars.with_value(init_traj),)),
        )
    )
    # Convert to the array library selected at the top of the function.
    return np.array(solution[traj_vars])
def create_conn_tree(robot: pk.Robot, link_indices: jnp.ndarray) -> jnp.ndarray:
    """
    Build an NxN connectivity matrix for the N entries of `link_indices`.

    Entry (i, j) is 1.0 when i == j, or when one of the two entries is reached
    by walking up `robot.joints.parent_indices` from the other without first
    passing through a different entry of `link_indices`. All other entries are
    0.0. The matrix is symmetric.
    """
    count = len(link_indices)

    def climbs_to(start, goal) -> bool:
        """Walk parents upward from `start`; True if `goal` is met before any other retargeted index."""
        parents = robot.joints.parent_indices
        node = start
        while node != -1:
            above = parents[node]
            if above == goal:
                return True
            if above in link_indices:
                # Another retargeted index sits between `start` and `goal`.
                return False
            node = above
        return False

    # Identity provides the self-connections on the diagonal.
    adjacency = jnp.eye(count)
    for row in range(count):
        first = link_indices[row]
        for col in range(row + 1, count):
            second = link_indices[col]
            # Try the walk from the second entry first, then from the first.
            if climbs_to(second, first) or climbs_to(first, second):
                adjacency = adjacency.at[row, col].set(1.0).at[col, row].set(1.0)

    return adjacency
def get_humanoid_retarget_indices() -> tuple[jnp.ndarray, jnp.ndarray]:
    """
    Return paired index arrays for SMPL-to-G1 retargeting.

    The first array holds positions in `SMPL_JOINT_NAMES`, the second holds
    positions in `G1_LINK_NAMES`; entry k of each array refers to the same
    (SMPL joint, G1 link) correspondence.
    """
    # (SMPL joint name, G1 link name) correspondences, in output order.
    name_pairs = (
        ("pelvis", "pelvis_contour_link"),
        ("left_hip", "left_hip_pitch_link"),
        ("right_hip", "right_hip_pitch_link"),
        ("left_knee", "left_knee_link"),
        ("right_knee", "right_knee_link"),
        ("left_ankle", "left_ankle_roll_link"),
        ("right_ankle", "right_ankle_roll_link"),
        ("left_shoulder", "left_shoulder_roll_link"),
        ("right_shoulder", "right_shoulder_roll_link"),
        ("left_elbow", "left_elbow_pitch_link"),
        ("right_elbow", "right_elbow_pitch_link"),
        ("left_wrist", "left_palm_link"),
        ("right_wrist", "right_palm_link"),
    )

    smpl_positions = [SMPL_JOINT_NAMES.index(smpl_name) for smpl_name, _ in name_pairs]
    g1_positions = [G1_LINK_NAMES.index(g1_name) for _, g1_name in name_pairs]

    return jnp.array(smpl_positions), jnp.array(g1_positions)
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/chungmin99/pyroki/f234516fe12c57795f31ca4b4fc9b2e04c6cb233/examples/retarget_helpers/hand/dexycb_motion.pkl -------------------------------------------------------------------------------- /examples/retarget_helpers/hand/shadowhand_urdf.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chungmin99/pyroki/f234516fe12c57795f31ca4b4fc9b2e04c6cb233/examples/retarget_helpers/hand/shadowhand_urdf.zip -------------------------------------------------------------------------------- /examples/retarget_helpers/humanoid/heightmap.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chungmin99/pyroki/f234516fe12c57795f31ca4b4fc9b2e04c6cb233/examples/retarget_helpers/humanoid/heightmap.npy -------------------------------------------------------------------------------- /examples/retarget_helpers/humanoid/left_foot_contact.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chungmin99/pyroki/f234516fe12c57795f31ca4b4fc9b2e04c6cb233/examples/retarget_helpers/humanoid/left_foot_contact.npy -------------------------------------------------------------------------------- /examples/retarget_helpers/humanoid/right_foot_contact.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chungmin99/pyroki/f234516fe12c57795f31ca4b4fc9b2e04c6cb233/examples/retarget_helpers/humanoid/right_foot_contact.npy -------------------------------------------------------------------------------- /examples/retarget_helpers/humanoid/smpl_keypoints.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/chungmin99/pyroki/f234516fe12c57795f31ca4b4fc9b2e04c6cb233/examples/retarget_helpers/humanoid/smpl_keypoints.npy -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "pyroki" 7 | version = "0.0.0" 8 | description = "Python Robot Kinematics Library" 9 | readme = "README.md" 10 | license = { text="MIT" } 11 | requires-python = ">=3.10" 12 | classifiers = [ 13 | "Programming Language :: Python :: 3.10", 14 | "Programming Language :: Python :: 3.11", 15 | "Programming Language :: Python :: 3.12", 16 | "Programming Language :: Python :: 3.13", 17 | "License :: OSI Approved :: MIT License", 18 | "Operating System :: OS Independent" 19 | ] 20 | dependencies = [ 21 | "tyro", 22 | "jax>=0.4.0", 23 | "jaxlib", 24 | "jaxlie>=1.0.0", 25 | "jax_dataclasses>=1.0.0", 26 | "jaxtyping", 27 | "loguru", 28 | "robot_descriptions", 29 | "jaxls @ git+https://github.com/brentyi/jaxls.git", 30 | "yourdfpy", 31 | "trimesh", 32 | "viser", 33 | "pyliblzfse", # Need for viser.extras import in viser==0.2.23 34 | ] 35 | 36 | [project.optional-dependencies] 37 | dev = [ 38 | "pyright>=1.1.308", 39 | "scikit-sparse", 40 | "ruff", 41 | "pytest", 42 | "m2r2", 43 | ] 44 | 45 | [tool.ruff.lint] 46 | select = [ 47 | "E", # pycodestyle errors. 48 | "F", # Pyflakes rules. 49 | "PLC", # Pylint convention warnings. 50 | "PLE", # Pylint errors. 51 | "PLR", # Pylint refactor recommendations. 52 | "PLW", # Pylint warnings. 53 | ] 54 | ignore = [ 55 | "E731", # Do not assign a lambda expression, use a def. 56 | "E741", # Ambiguous variable name. (l, O, or I) 57 | "E501", # Line too long. 58 | "E721", # Do not compare types, use `isinstance()`. 59 | "F722", # Forward annotation false positive from jaxtyping. 
@jdc.pytree_dataclass
class Robot:
    """A differentiable robot kinematics tree."""

    joints: JointInfo
    """Joint information for the robot."""

    links: LinkInfo
    """Link information for the robot."""

    joint_var_cls: jdc.Static[type[jaxls.Var[Array]]]
    """Variable class for the robot configuration."""

    @staticmethod
    def from_urdf(
        urdf: yourdfpy.URDF,
        default_joint_cfg: Float[ArrayLike, "*batch actuated_count"] | None = None,
    ) -> Robot:
        """
        Loads a robot kinematic tree from a URDF.
        Internally tracks a topological sort of the joints.

        Args:
            urdf: The URDF to load the robot from.
            default_joint_cfg: The default joint configuration to use for optimization.
                If None, the midpoint between the parsed lower and upper joint
                limits is used. The shape is asserted to be exactly
                `(actuated_count,)`, so batched defaults are rejected here.

        Returns:
            A Robot whose `joint_var_cls` defaults to the chosen configuration.
        """
        joints, links = RobotURDFParser.parse(urdf)

        # Compute default joint configuration.
        if default_joint_cfg is None:
            default_joint_cfg = (joints.lower_limits + joints.upper_limits) / 2
        else:
            default_joint_cfg = jnp.array(default_joint_cfg)
        assert default_joint_cfg.shape == (joints.num_actuated_joints,)

        # Variable class for the robot configuration.
        # A new class is defined per call so that its default value is this
        # robot's default configuration.
        class JointVar(  # pylint: disable=missing-class-docstring
            jaxls.Var[Array],
            default_factory=lambda: default_joint_cfg,
        ): ...

        robot = Robot(
            joints=joints,
            links=links,
            joint_var_cls=JointVar,
        )

        return robot

    @jdc.jit
    def forward_kinematics(
        self,
        cfg: Float[Array, "*batch actuated_count"],
        unroll_fk: jdc.Static[bool] = False,
    ) -> Float[Array, "*batch link_count 7"]:
        """Run forward kinematics on the robot's links, in the provided configuration.

        Computes the world pose of each link frame. The result is ordered
        corresponding to `self.links.names`.

        Args:
            cfg: The configuration of the actuated joints, in the format `(*batch actuated_count)`.
            unroll_fk: Static flag forwarded as the `unroll` argument of the
                `jax.lax.fori_loop` that accumulates the joint transforms.

        Returns:
            The SE(3) transforms of the links, ordered by `self.links.names`,
            in the format `(*batch, link_count, wxyz_xyz)`.
        """
        batch_axes = cfg.shape[:-1]
        assert cfg.shape == (*batch_axes, self.joints.num_actuated_joints)
        return self._link_poses_from_joint_poses(
            self._forward_kinematics_joints(cfg, unroll_fk)
        )

    def _link_poses_from_joint_poses(
        self, Ts_world_joint: Float[Array, "*batch joint_count 7"]
    ) -> Float[Array, "*batch link_count 7"]:
        """Map per-joint world poses to per-link world poses.

        Each link takes the pose stored at its parent joint index; links whose
        parent joint index is -1 (base links) get the identity pose.
        """
        (*batch_axes, _, _) = Ts_world_joint.shape
        # Get the link poses.
        base_link_mask = self.links.parent_joint_indices == -1
        # Replace -1 with 0 so the gather below uses a valid index; the
        # gathered value for those links is discarded by the `where` that follows.
        parent_joint_indices = jnp.where(
            base_link_mask, 0, self.links.parent_joint_indices
        )
        identity_pose = jaxlie.SE3.identity().wxyz_xyz
        Ts_world_link = jnp.where(
            base_link_mask[..., None],
            identity_pose,
            Ts_world_joint[..., parent_joint_indices, :],
        )
        assert Ts_world_link.shape == (*batch_axes, self.links.num_links, 7)
        return Ts_world_link

    def _forward_kinematics_joints(
        self,
        cfg: Float[Array, "*batch actuated_count"],
        unroll_fk: jdc.Static[bool] = False,
    ) -> Float[Array, "*batch joint_count 7"]:
        """Compute the world pose associated with every joint.

        Args:
            cfg: The configuration of the actuated joints, in the format `(*batch actuated_count)`.
            unroll_fk: Forwarded as the `unroll` argument of `jax.lax.fori_loop`.

        Returns:
            World transforms in the format `(*batch, joint_count, wxyz_xyz)`,
            in the original (unsorted) joint order.
        """
        (*batch_axes, _) = cfg.shape
        assert cfg.shape == (*batch_axes, self.joints.num_actuated_joints)

        # Calculate full configuration using the dedicated method
        q_full = self.joints.get_full_config(cfg)

        # Calculate delta transforms using the effective config and twists for all joints.
        tangents = self.joints.twists * q_full[..., None]
        assert tangents.shape == (*batch_axes, self.joints.num_joints, 6)
        delta_Ts = jaxlie.SE3.exp(tangents)  # Shape: (*batch_axes, self.joints.num_joints, 7)

        # Combine constant parent transform with variable joint delta transform.
        Ts_parent_child = (
            jaxlie.SE3(self.joints.parent_transforms) @ delta_Ts
        ).wxyz_xyz
        assert Ts_parent_child.shape == (*batch_axes, self.joints.num_joints, 7)

        # Topological sort helpers
        # `_topo_sort_inv` reorders joints into sorted order; its argsort is
        # used to map original joint indices to sorted positions and to undo
        # the reordering at the end.
        topo_order = jnp.argsort(self.joints._topo_sort_inv)
        Ts_parent_child_sorted = Ts_parent_child[..., self.joints._topo_sort_inv, :]
        parent_orig_for_sorted_child = self.joints.parent_indices[
            self.joints._topo_sort_inv
        ]
        # Parent of each sorted joint, expressed as a sorted index (-1 = no parent).
        idx_parent_joint_sorted = jnp.where(
            parent_orig_for_sorted_child == -1,
            -1,
            topo_order[parent_orig_for_sorted_child],
        )

        # Compute link transforms relative to world, indexed by sorted *joint* index.
        # Presumably the sorted order places every parent before its children,
        # so the parent's world transform is already filled in when it is read.
        def compute_transform(i: int, Ts_world_link_sorted: Array) -> Array:
            parent_sorted_idx = idx_parent_joint_sorted[i]
            # For a joint without a parent the gather uses index -1; that
            # value is replaced by the identity pose.
            T_world_parent_link = jnp.where(
                parent_sorted_idx == -1,
                jaxlie.SE3.identity().wxyz_xyz,
                Ts_world_link_sorted[..., parent_sorted_idx, :],
            )
            return Ts_world_link_sorted.at[..., i, :].set(
                (
                    jaxlie.SE3(T_world_parent_link)
                    @ jaxlie.SE3(Ts_parent_child_sorted[..., i, :])
                ).wxyz_xyz
            )

        Ts_world_link_init_sorted = jnp.zeros((*batch_axes, self.joints.num_joints, 7))
        Ts_world_link_sorted = jax.lax.fori_loop(
            lower=0,
            upper=self.joints.num_joints,
            body_fun=compute_transform,
            init_val=Ts_world_link_init_sorted,
            unroll=unroll_fk,
        )

        # Undo the sort so the result is indexed by original joint index.
        Ts_world_link_joint_indexed = Ts_world_link_sorted[..., topo_order, :]
        assert Ts_world_link_joint_indexed.shape == (
            *batch_axes,
            self.joints.num_joints,
            7,
        )  # This is the link poses indexed by parent *joint* index.

        return Ts_world_link_joint_indexed
37 | ) -> Callable[[CollGeom, CollGeom], Float[Array, "*batch"]]: 38 | """Get appropriate collision function (distance only) for given geometry types.""" 39 | func = COLLISION_FUNCTIONS.get((geom1_cls, geom2_cls)) 40 | if func is not None: 41 | return cast(Callable[[CollGeom, CollGeom], Float[Array, "*batch"]], func) 42 | 43 | func_swapped = COLLISION_FUNCTIONS.get((geom2_cls, geom1_cls)) 44 | if func_swapped is not None: 45 | return cast( 46 | Callable[[CollGeom, CollGeom], Float[Array, "*batch"]], 47 | lambda g1, g2: func_swapped(g2, g1), 48 | ) 49 | 50 | raise NotImplementedError( 51 | f"No collision function found for {geom1_cls.__name__} and {geom2_cls.__name__}" 52 | ) 53 | 54 | 55 | def collide(geom1: CollGeom, geom2: CollGeom) -> Float[Array, "*batch"]: 56 | """Calculate collision distance between two geometric objects, handling broadcasting.""" 57 | try: 58 | broadcast_shape = jnp.broadcast_shapes( 59 | geom1.get_batch_axes(), geom2.get_batch_axes() 60 | ) 61 | except ValueError as e: 62 | raise ValueError( 63 | f"Cannot broadcast geometry shapes {geom1.get_batch_axes()} and {geom2.get_batch_axes()}" 64 | ) from e 65 | 66 | geom1 = geom1.broadcast_to(broadcast_shape) 67 | geom2 = geom2.broadcast_to(broadcast_shape) 68 | geom1_cls = type(geom1) 69 | geom2_cls = type(geom2) 70 | return _get_coll_func(geom1_cls, geom2_cls)(geom1, geom2) 71 | 72 | 73 | def pairwise_collide(geom1: CollGeom, geom2: CollGeom) -> Float[Array, "*batch N M"]: 74 | """ 75 | Convenience wrapper around `collide` for computing pairwise distances with broadcasting. 76 | 77 | Args: 78 | geom1: First collection of geometries. Expected to have a shape like 79 | (*batch1, N, ...), where N is the number of geometries. 80 | geom2: Second collection of geometries. Expected to have a shape like 81 | (*batch2, M, ...), where M is the number of geometries. 82 | The batch dimensions (*batch1, *batch2) must be broadcast-compatible. 
83 | 84 | Returns: 85 | A matrix of distances with shape (*batch_combined, N, M), where 86 | *batch_combined is the result of broadcasting *batch1 and *batch2. 87 | dist[..., i, j] is the distance between geom1[..., i, :] and geom2[..., j, :]. 88 | """ 89 | # Input checks. 90 | axes1 = geom1.get_batch_axes() 91 | axes2 = geom2.get_batch_axes() 92 | assert len(axes1) >= 1, ( 93 | f"geom1 must have at least one batch dimension to map over, got shape {axes1}" 94 | ) 95 | assert len(axes2) >= 1, ( 96 | f"geom2 must have at least one batch dimension to map over, got shape {axes2}" 97 | ) 98 | 99 | # Determine expected output shape. 100 | batch1_shape = axes1[:-1] 101 | batch2_shape = axes2[:-1] 102 | N = axes1[-1] 103 | M = axes2[-1] 104 | try: 105 | batch_combined_shape = jnp.broadcast_shapes(batch1_shape, batch2_shape) 106 | except ValueError as e: 107 | raise ValueError( 108 | f"Cannot broadcast non-mapped batch shapes {batch1_shape} and {batch2_shape}" 109 | ) from e 110 | expected_output_shape = (*batch_combined_shape, N, M) 111 | result = collide( 112 | geom1.broadcast_to((*batch_combined_shape, N)) 113 | .reshape((*batch_combined_shape, N, 1)) 114 | .broadcast_to(expected_output_shape), 115 | geom2.broadcast_to((*batch_combined_shape, M)) 116 | .reshape((*batch_combined_shape, 1, M)) 117 | .broadcast_to(expected_output_shape), 118 | ) 119 | assert result.shape == expected_output_shape, ( 120 | f"Output shape mismatch. Expected {expected_output_shape}, got {result.shape}" 121 | ) 122 | return result 123 | 124 | 125 | def colldist_from_sdf( 126 | _dist: jax.Array, 127 | activation_dist: jax.Array | float, 128 | ) -> jax.Array: 129 | """ 130 | Convert a signed distance field to a collision distance field, 131 | based on https://arxiv.org/pdf/2310.17274#page=7.39. 132 | 133 | This function applies a smoothing transformation, useful for converting 134 | raw distances into values suitable for cost functions in optimization. 
135 | It returns values <= 0, where 0 corresponds to distances >= activation_dist, 136 | and increasingly negative values for deeper penetrations. 137 | 138 | Args: 139 | _dist: Signed distance field values (positive = separation, negative = penetration). 140 | activation_dist: The distance threshold (margin) below which the cost activates. 141 | 142 | Returns: 143 | Transformed collision distance field values (<= 0). 144 | """ 145 | _dist = jnp.minimum(_dist, activation_dist) 146 | _dist = jnp.where( 147 | _dist < 0, 148 | _dist - 0.5 * activation_dist, 149 | -0.5 / (activation_dist + 1e-6) * (_dist - activation_dist) ** 2, 150 | ) 151 | _dist = jnp.minimum(_dist, 0.0) 152 | return _dist 153 | -------------------------------------------------------------------------------- /src/pyroki/collision/_geometry_pairs.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import jax.numpy as jnp 4 | from jaxtyping import Float, Array 5 | 6 | from ._geometry import HalfSpace, Sphere, Capsule, Heightmap 7 | from . 
import _utils 8 | 9 | 10 | # --- HalfSpace Collision Implementations --- 11 | 12 | 13 | def _halfspace_sphere_dist( 14 | halfspace_normal: Float[Array, "*batch 3"], 15 | halfspace_point: Float[Array, "*batch 3"], 16 | sphere_pos: Float[Array, "*batch 3"], 17 | sphere_radius: Float[Array, "*batch"], 18 | ) -> Float[Array, "*batch"]: 19 | """Helper: Calculates distance between a halfspace boundary plane and sphere center, minus radius.""" 20 | dist = ( 21 | jnp.einsum("...i,...i->...", sphere_pos - halfspace_point, halfspace_normal) 22 | - sphere_radius 23 | ) 24 | return dist 25 | 26 | 27 | def halfspace_sphere(halfspace: HalfSpace, sphere: Sphere) -> Float[Array, "*batch"]: 28 | """Calculates distance between a halfspace and a sphere.""" 29 | dist = _halfspace_sphere_dist( 30 | halfspace.normal, 31 | halfspace.pose.translation(), 32 | sphere.pose.translation(), 33 | sphere.radius, 34 | ) 35 | return dist 36 | 37 | 38 | def halfspace_capsule(halfspace: HalfSpace, capsule: Capsule) -> Float[Array, "*batch"]: 39 | """Calculates distance between halfspace and capsule (closest end).""" 40 | halfspace_normal = halfspace.normal 41 | halfspace_point = halfspace.pose.translation() 42 | cap_center = capsule.pose.translation() 43 | cap_radius = capsule.radius 44 | cap_axis = capsule.axis 45 | segment_offset = cap_axis * capsule.height[..., None] / 2 46 | dist1 = _halfspace_sphere_dist( 47 | halfspace_normal, halfspace_point, cap_center + segment_offset, cap_radius 48 | ) 49 | dist2 = _halfspace_sphere_dist( 50 | halfspace_normal, halfspace_point, cap_center - segment_offset, cap_radius 51 | ) 52 | final_dist = jnp.minimum(dist1, dist2) 53 | return final_dist 54 | 55 | 56 | # --- Sphere/Capsule Collision Implementations --- 57 | 58 | 59 | def _sphere_sphere_dist( 60 | pos1: Float[Array, "*batch 3"], 61 | radius1: Float[Array, "*batch"], 62 | pos2: Float[Array, "*batch 3"], 63 | radius2: Float[Array, "*batch"], 64 | ) -> Float[Array, "*batch"]: 65 | """Helper: Calculates 
distance between two spheres.""" 66 | _, dist_center = _utils.normalize_with_norm(pos2 - pos1) 67 | dist = dist_center - (radius1 + radius2) 68 | return dist 69 | 70 | 71 | def sphere_sphere(sphere1: Sphere, sphere2: Sphere) -> Float[Array, "*batch"]: 72 | """Calculate distance between two spheres.""" 73 | dist = _sphere_sphere_dist( 74 | sphere1.pose.translation(), 75 | sphere1.radius, 76 | sphere2.pose.translation(), 77 | sphere2.radius, 78 | ) 79 | return dist 80 | 81 | 82 | def sphere_capsule(sphere: Sphere, capsule: Capsule) -> Float[Array, "*batch"]: 83 | """Calculate distance between sphere and capsule.""" 84 | cap_pos = capsule.pose.translation() 85 | sphere_pos = sphere.pose.translation() 86 | cap_axis = capsule.axis 87 | segment_offset = cap_axis * capsule.height[..., None] / 2 88 | cap_a = cap_pos - segment_offset 89 | cap_b = cap_pos + segment_offset 90 | pt_on_axis = _utils.closest_segment_point(cap_a, cap_b, sphere_pos) 91 | dist = _sphere_sphere_dist(sphere_pos, sphere.radius, pt_on_axis, capsule.radius) 92 | return dist 93 | 94 | 95 | def capsule_capsule(capsule1: Capsule, capsule2: Capsule) -> Float[Array, "*batch"]: 96 | """Calculate distance between two capsules.""" 97 | pos1 = capsule1.pose.translation() 98 | axis1 = capsule1.axis 99 | length1 = capsule1.height 100 | radius1 = capsule1.radius 101 | segment1_offset = axis1 * length1[..., None] / 2 102 | a1 = pos1 - segment1_offset 103 | b1 = pos1 + segment1_offset 104 | 105 | pos2 = capsule2.pose.translation() 106 | axis2 = capsule2.axis 107 | length2 = capsule2.height 108 | radius2 = capsule2.radius 109 | segment2_offset = axis2 * length2[..., None] / 2 110 | a2 = pos2 - segment2_offset 111 | b2 = pos2 + segment2_offset 112 | 113 | pt1_on_axis, pt2_on_axis = _utils.closest_segment_to_segment_points(a1, b1, a2, b2) 114 | dist = _sphere_sphere_dist(pt1_on_axis, radius1, pt2_on_axis, radius2) 115 | return dist 116 | 117 | 118 | # --- Heightmap Collision Implementations --- 119 | 120 | 121 | def 
heightmap_sphere(heightmap: Heightmap, sphere: Sphere) -> Float[Array, "*batch"]: 122 | """Calculate approximate distance between heightmap and sphere. 123 | 124 | Approximation: Considers the heightmap point directly below the sphere center 125 | using bilinear interpolation and calculates vertical distance minus radius. 126 | """ 127 | batch_axes = jnp.broadcast_shapes( 128 | heightmap.get_batch_axes(), sphere.get_batch_axes() 129 | ) 130 | 131 | sphere_pos_w = sphere.pose.translation() 132 | sphere_radius = sphere.radius 133 | interpolated_local_z = heightmap._interpolate_height_at_coords(sphere_pos_w) 134 | sphere_pos_h = heightmap.pose.inverse().apply(sphere_pos_w) 135 | sphere_local_z = sphere_pos_h[..., 2] 136 | dist = sphere_local_z - interpolated_local_z - sphere_radius 137 | 138 | assert dist.shape == batch_axes 139 | return dist 140 | 141 | 142 | def heightmap_capsule(heightmap: Heightmap, capsule: Capsule) -> Float[Array, "*batch"]: 143 | """Calculate approximate distance between heightmap and capsule, by 144 | checking heightmap points below capsule endpoints. 145 | 146 | Note that this may miss collisions when capsule body intersects but endpoints are above heightmap! 147 | """ 148 | batch_axes = jnp.broadcast_shapes( 149 | heightmap.get_batch_axes(), capsule.get_batch_axes() 150 | ) 151 | 152 | cap_pos_w = capsule.pose.translation() 153 | cap_radius = capsule.radius 154 | cap_axis_w = capsule.axis # World frame axis 155 | segment_offset_w = cap_axis_w * capsule.height[..., None] / 2 156 | 157 | # Calculate world positions of the two end-sphere centers. 158 | p1_w = cap_pos_w + segment_offset_w 159 | p2_w = cap_pos_w - segment_offset_w 160 | 161 | # Interpolate heightmap surface height (local Z) below each end-sphere center. 162 | h_surf1_local = heightmap._interpolate_height_at_coords(p1_w) 163 | h_surf2_local = heightmap._interpolate_height_at_coords(p2_w) 164 | 165 | # Get end-sphere centers Z coordinates in heightmap's local frame. 
166 | p1_h = heightmap.pose.inverse().apply(p1_w) 167 | p2_h = heightmap.pose.inverse().apply(p2_w) 168 | z1_local = p1_h[..., 2] 169 | z2_local = p2_h[..., 2] 170 | 171 | # Calculate vertical distance for each end sphere. 172 | dist1 = z1_local - h_surf1_local - cap_radius 173 | dist2 = z2_local - h_surf2_local - cap_radius 174 | 175 | # Return the minimum distance. 176 | min_dist = jnp.minimum(dist1, dist2) 177 | assert min_dist.shape == batch_axes 178 | return min_dist 179 | 180 | 181 | def heightmap_halfspace( 182 | heightmap: Heightmap, halfspace: HalfSpace 183 | ) -> Float[Array, "*batch"]: 184 | """Calculate approximate distance between heightmap and halfspace. 185 | 186 | Approximation: Finds the minimum signed distance between any heightmap vertex 187 | and the halfspace plane. 188 | """ 189 | batch_axes = jnp.broadcast_shapes( 190 | heightmap.get_batch_axes(), halfspace.get_batch_axes() 191 | ) 192 | 193 | # Heightmap vertices in world frame. 194 | verts_local = heightmap._get_vertices_local() # (*batch, N, 3), N=H*W 195 | verts_world = heightmap.pose.apply(verts_local) # (*batch, N, 3) 196 | 197 | # Halfspace plane properties (world frame). 198 | hs_normal_w = halfspace.normal # (*batch, 3) 199 | hs_point_w = halfspace.pose.translation() # (*batch, 3) 200 | 201 | # Ensure batch dimensions are compatible for broadcasting. 202 | batch_axes = jnp.broadcast_shapes( 203 | heightmap.get_batch_axes(), halfspace.get_batch_axes() 204 | ) 205 | # Expand dims for broadcasting against vertices. 
206 | hs_normal_w = jnp.broadcast_to(hs_normal_w, batch_axes + (3,))[..., None, :] 207 | hs_point_w = jnp.broadcast_to(hs_point_w, batch_axes + (3,))[..., None, :] 208 | verts_world = jnp.broadcast_to(verts_world, batch_axes + verts_world.shape[-2:]) 209 | 210 | # Calculate signed distance for each vertex to the plane: 211 | vertex_distances = jnp.einsum( 212 | "...vi,...i->...v", verts_world - hs_point_w, hs_normal_w.squeeze(-2) 213 | ) 214 | 215 | # Find the minimum distance among all vertices. 216 | min_dist = jnp.min(vertex_distances, axis=-1) 217 | assert min_dist.shape == batch_axes 218 | return min_dist 219 | -------------------------------------------------------------------------------- /src/pyroki/collision/_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | from typing import Tuple 6 | from jaxtyping import Float, Array 7 | 8 | _SAFE_EPS = 1e-6 9 | 10 | 11 | def make_frame(direction: jax.Array) -> jax.Array: 12 | """Make a frame from a direction vector, aligning the z-axis with the direction.""" 13 | # Based on `mujoco.mjx._src.math.make_frame`. 
14 | 15 | is_zero = jnp.isclose(direction, 0.0).all(axis=-1, keepdims=True) 16 | direction = jnp.where( 17 | is_zero, 18 | jnp.broadcast_to(jnp.array([1.0, 0.0, 0.0]), direction.shape), 19 | direction, 20 | ) 21 | direction /= jnp.linalg.norm(direction, axis=-1, keepdims=True) + _SAFE_EPS 22 | 23 | y = jnp.broadcast_to(jnp.array([0, 1, 0]), (*direction.shape[:-1], 3)) 24 | z = jnp.broadcast_to(jnp.array([0, 0, 1]), (*direction.shape[:-1], 3)) 25 | 26 | normal = jnp.where((-0.5 < direction[..., 1:2]) & (direction[..., 1:2] < 0.5), y, z) 27 | normal -= direction * jnp.einsum("...i,...i->...", normal, direction)[..., None] 28 | normal /= jnp.linalg.norm(normal, axis=-1, keepdims=True) + _SAFE_EPS 29 | 30 | return jnp.stack([jnp.cross(normal, direction), normal, direction], axis=-1) 31 | 32 | 33 | def normalize(x: Float[Array, "*batch N"]) -> Float[Array, "*batch N"]: 34 | """Normalizes a vector, handling the zero vector.""" 35 | norm = jnp.linalg.norm(x, axis=-1, keepdims=True) 36 | safe_norm = jnp.where(norm == 0.0, 1.0, norm) 37 | normalized_x = x / safe_norm 38 | return jnp.where(norm == 0.0, jnp.zeros_like(x), normalized_x) 39 | 40 | 41 | def normalize_with_norm( 42 | x: Float[Array, "*batch N"], 43 | ) -> Tuple[Float[Array, "*batch N"], Float[Array, "*batch"]]: 44 | """Normalizes a vector and returns the norm, handling the zero vector. The norm is taken of `x + 1e-6` (every component offset), so the returned norm differs slightly from the exact norm of `x`.""" 45 | norm = jnp.linalg.norm(x + 1e-6, axis=-1, keepdims=True) 46 | safe_norm = jnp.where(norm == 0.0, 1.0, norm) 47 | normalized_x = x / safe_norm 48 | result_vec = jnp.where(norm == 0.0, jnp.zeros_like(x), normalized_x) 49 | result_norm = norm[..., 0] 50 | return result_vec, result_norm 51 | 52 | 53 | def closest_segment_point( 54 | a: Float[Array, "*batch 3"], 55 | b: Float[Array, "*batch 3"], 56 | pt: Float[Array, "*batch 3"], 57 | ) -> Float[Array, "*batch 3"]: 58 | """Finds the closest point on the line segment [a, b] to point pt.""" 59 | ab = b - a 60 | t = jnp.einsum("...i,...i->...", pt - a, ab) / ( 61 | 
jnp.einsum("...i,...i->...", ab, ab) + _SAFE_EPS 62 | ) 63 | t_clamped = jnp.clip(t, 0.0, 1.0) 64 | return a + ab * t_clamped[..., None] 65 | 66 | 67 | def closest_segment_to_segment_points( 68 | a1: Float[Array, "*batch 3"], 69 | b1: Float[Array, "*batch 3"], 70 | a2: Float[Array, "*batch 3"], 71 | b2: Float[Array, "*batch 3"], 72 | ) -> Tuple[Float[Array, "*batch 3"], Float[Array, "*batch 3"]]: 73 | """Finds the closest points between two line segments [a1, b1] and [a2, b2].""" 74 | d1 = b1 - a1 # Direction vector of segment S1 75 | d2 = b2 - a2 # Direction vector of segment S2 76 | r = a1 - a2 77 | 78 | a = jnp.einsum("...i,...i->...", d1, d1) # Squared length of segment S1 79 | e = jnp.einsum("...i,...i->...", d2, d2) # Squared length of segment S2 80 | f = jnp.einsum("...i,...i->...", d2, r) 81 | c = jnp.einsum("...i,...i->...", d1, r) 82 | b = jnp.einsum("...i,...i->...", d1, d2) 83 | denom = a * e - b * b # Squared area of the parallelogram defined by d1, d2 84 | 85 | s_num = b * f - c * e 86 | t_num = a * f - b * c 87 | 88 | s_parallel = -c / (a + _SAFE_EPS) 89 | t_parallel = f / (e + _SAFE_EPS) 90 | 91 | s = jnp.where(denom < _SAFE_EPS, s_parallel, s_num / (denom + _SAFE_EPS)) 92 | t = jnp.where(denom < _SAFE_EPS, t_parallel, t_num / (denom + _SAFE_EPS)) 93 | 94 | s_clamped = jnp.clip(s, 0.0, 1.0) 95 | t_clamped = jnp.clip(t, 0.0, 1.0) 96 | 97 | t_recomp = jnp.einsum( 98 | "...i,...i->...", d2, (a1 + d1 * s_clamped[..., None]) - a2 99 | ) / (e + _SAFE_EPS) 100 | t_final = jnp.where( 101 | jnp.abs(s - s_clamped) > _SAFE_EPS, jnp.clip(t_recomp, 0.0, 1.0), t_clamped 102 | ) 103 | 104 | s_recomp = jnp.einsum("...i,...i->...", d1, (a2 + d2 * t_final[..., None]) - a1) / ( 105 | a + _SAFE_EPS 106 | ) 107 | s_final = jnp.where( 108 | jnp.abs(t - t_final) > _SAFE_EPS, jnp.clip(s_recomp, 0.0, 1.0), s_clamped 109 | ) 110 | 111 | c1 = a1 + d1 * s_final[..., None] 112 | c2 = a2 + d2 * t_final[..., None] 113 | return c1, c2 114 | 
-------------------------------------------------------------------------------- /src/pyroki/costs/__init__.py: -------------------------------------------------------------------------------- 1 | from ._costs import five_point_acceleration_cost as five_point_acceleration_cost 2 | from ._costs import five_point_jerk_cost as five_point_jerk_cost 3 | from ._costs import five_point_velocity_cost as five_point_velocity_cost 4 | from ._costs import limit_cost as limit_cost 5 | from ._costs import limit_velocity_cost as limit_velocity_cost 6 | from ._costs import manipulability_cost as manipulability_cost 7 | from ._costs import pose_cost as pose_cost 8 | from ._costs import pose_cost_with_base as pose_cost_with_base 9 | from ._costs import rest_cost as rest_cost 10 | from ._costs import rest_with_base_cost as rest_with_base_cost 11 | from ._costs import self_collision_cost as self_collision_cost 12 | from ._costs import smoothness_cost as smoothness_cost 13 | from ._costs import world_collision_cost as world_collision_cost 14 | from ._pose_cost_analytic_jac import pose_cost_analytic_jac as pose_cost_analytic_jac 15 | from ._pose_cost_numerical_jac import pose_cost_numerical_jac as pose_cost_numerical_jac 16 | -------------------------------------------------------------------------------- /src/pyroki/costs/_costs.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import jaxlie 4 | from jax import Array 5 | from jaxls import Cost, Var, VarValues 6 | 7 | from .._robot import Robot 8 | from ..collision import CollGeom, RobotCollision, colldist_from_sdf 9 | 10 | 11 | @Cost.create_factory 12 | def pose_cost( 13 | vals: VarValues, 14 | robot: Robot, 15 | joint_var: Var[Array], 16 | target_pose: jaxlie.SE3, 17 | target_link_index: Array, 18 | pos_weight: Array | float, 19 | ori_weight: Array | float, 20 | ) -> Array: 21 | """Computes the residual for matching link poses to target poses.""" 22 | 
assert target_link_index.dtype == jnp.int32 23 | joint_cfg = vals[joint_var] 24 | Ts_link_world = robot.forward_kinematics(joint_cfg) 25 | pose_actual = jaxlie.SE3(Ts_link_world[..., target_link_index, :]) 26 | residual = (pose_actual.inverse() @ target_pose).log() 27 | pos_residual = residual[..., :3] * pos_weight 28 | ori_residual = residual[..., 3:] * ori_weight 29 | return jnp.concatenate([pos_residual, ori_residual]).flatten() 30 | 31 | 32 | @Cost.create_factory 33 | def pose_cost_with_base( 34 | vals: VarValues, 35 | robot: Robot, 36 | joint_var: Var[Array], 37 | T_world_base_var: Var[jaxlie.SE3], 38 | target_pose: jaxlie.SE3, 39 | target_link_indices: Array, 40 | pos_weight: Array | float, 41 | ori_weight: Array | float, 42 | ) -> Array: 43 | """Computes the residual for matching link poses relative to a mobile base.""" 44 | assert target_link_indices.dtype == jnp.int32 45 | joint_cfg = vals[joint_var] 46 | T_world_base = vals[T_world_base_var] 47 | Ts_base_link = robot.forward_kinematics(joint_cfg) # FK is T_base_link 48 | T_base_target_link = jaxlie.SE3(Ts_base_link[..., target_link_indices, :]) 49 | T_world_target_link_actual = T_world_base @ T_base_target_link 50 | 51 | residual = (T_world_target_link_actual.inverse() @ target_pose).log() 52 | pos_residual = residual[..., :3] * pos_weight 53 | ori_residual = residual[..., 3:] * ori_weight 54 | return jnp.concatenate([pos_residual, ori_residual]).flatten() 55 | 56 | 57 | # --- Limit Costs --- 58 | 59 | 60 | @Cost.create_factory 61 | def limit_cost( 62 | vals: VarValues, 63 | robot: Robot, 64 | joint_var: Var[Array], 65 | weight: Array | float, 66 | ) -> Array: 67 | """Computes the residual penalizing joint limit violations.""" 68 | joint_cfg = vals[joint_var] 69 | joint_cfg_eff = robot.joints.get_full_config(joint_cfg) 70 | residual_upper = jnp.maximum(0.0, joint_cfg_eff - robot.joints.upper_limits_all) 71 | residual_lower = jnp.maximum(0.0, robot.joints.lower_limits_all - joint_cfg_eff) 72 | return 
((residual_upper + residual_lower) * weight).flatten() 73 | 74 | 75 | @Cost.create_factory 76 | def limit_velocity_cost( 77 | vals: VarValues, 78 | robot: Robot, 79 | joint_var: Var[Array], 80 | prev_joint_var: Var[Array], 81 | dt: float, 82 | weight: Array | float, 83 | ) -> Array: 84 | """Computes the residual penalizing joint velocity limit violations.""" 85 | joint_vel = (vals[joint_var] - vals[prev_joint_var]) / dt 86 | residual = jnp.maximum(0.0, jnp.abs(joint_vel) - robot.joints.velocity_limits) 87 | return (residual * weight).flatten() 88 | 89 | 90 | # --- Regularization Costs --- 91 | 92 | 93 | @Cost.create_factory 94 | def rest_cost( 95 | vals: VarValues, 96 | joint_var: Var[Array], 97 | rest_pose: Array, 98 | weight: Array | float, 99 | ) -> Array: 100 | """Computes the residual biasing joints towards a rest pose.""" 101 | return ((vals[joint_var] - rest_pose) * weight).flatten() 102 | 103 | 104 | @Cost.create_factory 105 | def rest_with_base_cost( 106 | vals: VarValues, 107 | joint_var: Var[Array], 108 | T_world_base_var: Var[jaxlie.SE3], 109 | rest_pose: Array, 110 | weight: Array | float, 111 | ) -> Array: 112 | """Computes the residual biasing joints and base towards rest/identity.""" 113 | residual_joints = vals[joint_var] - rest_pose 114 | residual_base = vals[T_world_base_var].log() 115 | return (jnp.concatenate([residual_joints, residual_base]) * weight).flatten() 116 | 117 | 118 | @Cost.create_factory 119 | def smoothness_cost( 120 | vals: VarValues, 121 | curr_joint_var: Var[Array], 122 | past_joint_var: Var[Array], 123 | weight: Array | float, 124 | ) -> Array: 125 | """Computes the residual penalizing joint configuration differences (velocity).""" 126 | return ((vals[curr_joint_var] - vals[past_joint_var]) * weight).flatten() 127 | 128 | 129 | # --- Manipulability Cost --- 130 | 131 | 132 | def _compute_manip_yoshikawa( 133 | cfg: Array, 134 | robot: Robot, 135 | target_link_index: jax.Array, 136 | ) -> Array: 137 | """Helper: Computes 
manipulability measure for a single link.""" 138 | jacobian = jax.jacfwd( 139 | lambda q: jaxlie.SE3(robot.forward_kinematics(q)).translation() 140 | )(cfg)[target_link_index] 141 | JJT = jacobian @ jacobian.T 142 | assert JJT.shape == (3, 3) 143 | return jnp.sqrt(jnp.maximum(0.0, jnp.linalg.det(JJT))) 144 | 145 | 146 | @Cost.create_factory 147 | def manipulability_cost( 148 | vals: VarValues, 149 | robot: Robot, 150 | joint_var: Var[Array], 151 | target_link_indices: Array, 152 | weight: Array | float, 153 | ) -> Array: 154 | """Computes the residual penalizing low manipulability (translation).""" 155 | cfg = vals[joint_var] 156 | if target_link_indices.ndim == 0: 157 | vmapped_manip = _compute_manip_yoshikawa(cfg, robot, target_link_indices) 158 | else: 159 | vmapped_manip = jax.vmap(_compute_manip_yoshikawa, in_axes=(None, None, 0))( 160 | cfg, robot, target_link_indices 161 | ) 162 | residual = 1.0 / (vmapped_manip + 1e-6) 163 | return (residual * weight).flatten() 164 | 165 | 166 | # --- Collision Costs --- 167 | 168 | 169 | @Cost.create_factory 170 | def self_collision_cost( 171 | vals: VarValues, 172 | robot: Robot, 173 | robot_coll: RobotCollision, 174 | joint_var: Var[Array], 175 | margin: float, 176 | weight: Array | float, 177 | ) -> Array: 178 | """Computes the residual penalizing self-collisions below a margin.""" 179 | cfg = vals[joint_var] 180 | active_distances = robot_coll.compute_self_collision_distance(robot, cfg) 181 | residual = colldist_from_sdf(active_distances, margin) 182 | return (residual * weight).flatten() 183 | 184 | 185 | @Cost.create_factory 186 | def world_collision_cost( 187 | vals: VarValues, 188 | robot: Robot, 189 | robot_coll: RobotCollision, 190 | joint_var: Var[Array], 191 | world_geom: CollGeom, 192 | margin: float, 193 | weight: Array | float, 194 | ) -> Array: 195 | """Computes the residual penalizing world collisions below a margin.""" 196 | cfg = vals[joint_var] 197 | dist_matrix = 
robot_coll.compute_world_collision_distance(robot, cfg, world_geom) 198 | residual = colldist_from_sdf(dist_matrix, margin) 199 | return (residual * weight).flatten() 200 | 201 | 202 | # --- Finite Difference Costs (Velocity, Acceleration, Jerk) --- 203 | 204 | 205 | @Cost.create_factory 206 | def five_point_velocity_cost( 207 | vals: VarValues, 208 | robot: Robot, # Needed for limits 209 | var_t_plus_2: Var[Array], 210 | var_t_plus_1: Var[Array], 211 | var_t_minus_1: Var[Array], 212 | var_t_minus_2: Var[Array], 213 | dt: float, 214 | weight: Array | float, 215 | ) -> Array: 216 | """Computes the residual penalizing velocity limit violations (5-point stencil).""" 217 | q_tm2 = vals[var_t_minus_2] 218 | q_tm1 = vals[var_t_minus_1] 219 | q_tp1 = vals[var_t_plus_1] 220 | q_tp2 = vals[var_t_plus_2] 221 | 222 | velocity = (-q_tp2 + 8 * q_tp1 - 8 * q_tm1 + q_tm2) / (12 * dt) 223 | vel_limits = robot.joints.velocity_limits 224 | limit_violation = jnp.maximum(0.0, jnp.abs(velocity) - vel_limits) 225 | return (limit_violation * weight).flatten() 226 | 227 | 228 | @Cost.create_factory 229 | def five_point_acceleration_cost( 230 | vals: VarValues, 231 | var_t: Var[Array], 232 | var_t_plus_2: Var[Array], 233 | var_t_plus_1: Var[Array], 234 | var_t_minus_1: Var[Array], 235 | var_t_minus_2: Var[Array], 236 | dt: float, 237 | weight: Array | float, 238 | ) -> Array: 239 | """Computes the residual minimizing joint acceleration (5-point stencil).""" 240 | q_tm2 = vals[var_t_minus_2] 241 | q_tm1 = vals[var_t_minus_1] 242 | q_t = vals[var_t] 243 | q_tp1 = vals[var_t_plus_1] 244 | q_tp2 = vals[var_t_plus_2] 245 | 246 | acceleration = (-q_tp2 + 16 * q_tp1 - 30 * q_t + 16 * q_tm1 - q_tm2) / (12 * dt**2) 247 | return (acceleration * weight).flatten() 248 | 249 | 250 | @Cost.create_factory 251 | def five_point_jerk_cost( 252 | vals: VarValues, 253 | var_t_plus_3: Var[Array], 254 | var_t_plus_2: Var[Array], 255 | var_t_plus_1: Var[Array], 256 | var_t_minus_1: Var[Array], 257 | 
var_t_minus_2: Var[Array], 258 | var_t_minus_3: Var[Array], 259 | dt: float, 260 | weight: Array | float, 261 | ) -> Array: 262 | """Computes the residual minimizing joint jerk (7-point stencil).""" 263 | q_tm3 = vals[var_t_minus_3] 264 | q_tm2 = vals[var_t_minus_2] 265 | q_tm1 = vals[var_t_minus_1] 266 | q_tp1 = vals[var_t_plus_1] 267 | q_tp2 = vals[var_t_plus_2] 268 | q_tp3 = vals[var_t_plus_3] 269 | 270 | jerk = (-q_tp3 + 8 * q_tp2 - 13 * q_tp1 + 13 * q_tm1 - 8 * q_tm2 + q_tm3) / ( 271 | 8 * dt**3 272 | ) 273 | return (jerk * weight).flatten() 274 | -------------------------------------------------------------------------------- /src/pyroki/costs/_pose_cost_analytic_jac.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | import jaxlie 6 | import jaxls 7 | 8 | from .._robot import Robot 9 | 10 | 11 | def _get_actuated_joints_applied_to_target( 12 | robot: Robot, 13 | target_joint_idx: jax.Array, 14 | ) -> jax.Array: 15 | """For each joint `i` in the robot, we return an index that is: 16 | 17 | 1) -1 if the joint is not in the path to the target link. 18 | 2) an actuated joint index if the joint is in the path to the target link. 19 | - If the joint `i` is actuated, this is just `i`. 20 | - If the joint `i` mimics joint `j`, this is set to `j`. 21 | 22 | The inputs and outputs should both be integer arrays of shape 23 | `(num_joints,)`. 24 | """ 25 | 26 | assert target_joint_idx.shape == () 27 | 28 | def body_fun(joint_idx, indices): 29 | # Find the corresponding active actuated joint. 30 | active_act_joint = jnp.where( 31 | robot.joints.actuated_indices[joint_idx] != -1, 32 | # The current joint is actuated. 33 | robot.joints.actuated_indices[joint_idx], 34 | # The current joint is not actuated; this is -1 if not a mimic joint. 35 | robot.joints.mimic_act_indices[joint_idx], 36 | ) 37 | 38 | # Find the parent of the current joint. 
39 | parent_joint = robot.joints.parent_indices[joint_idx] 40 | 41 | # Continue traversing up the kinematic tree, using the parent joint. 42 | # This value may either go up or down, since there's no guarantee that 43 | # the kinematic tree is topologically sorted. :-) 44 | next_indices = indices.at[joint_idx].set(active_act_joint) 45 | return (parent_joint, next_indices) 46 | 47 | idx_applied_to_target = jnp.full( 48 | (robot.joints.num_joints,), 49 | fill_value=-1, 50 | dtype=jnp.int32, 51 | ) 52 | idx_applied_to_target = jax.lax.while_loop( 53 | lambda carry: jnp.any(carry[0] >= 0), 54 | lambda carry: body_fun(*carry), 55 | (target_joint_idx, idx_applied_to_target), 56 | )[-1] 57 | return idx_applied_to_target 58 | 59 | 60 | _PoseCostJacCache = tuple[jax.Array, jax.Array, jaxlie.SE3] 61 | 62 | 63 | def pose_cost_analytic_jac( 64 | robot: Robot, 65 | joint_var: jaxls.Var[jax.Array], 66 | target_pose: jaxlie.SE3, 67 | target_link_index: jax.Array, 68 | pos_weight: jax.Array | float, 69 | ori_weight: jax.Array | float, 70 | ) -> jaxls.Cost: 71 | # We only check shape lengths because there might be (1,) axes for 72 | # broadcasting reasons. 73 | assert ( 74 | len(target_link_index.shape) 75 | == len(jnp.asarray(joint_var.id).shape) 76 | == len(robot.joints.twists.shape[:-2]) 77 | ), "Batch axes of inputs should match" 78 | 79 | # Broadcast the inputs for _get_actuated_joints_applied_to_target(). 80 | # Excluding the weights for now... 81 | batch_axes = jnp.broadcast_shapes( 82 | target_pose.get_batch_axes(), 83 | jnp.asarray(joint_var.id).shape, 84 | target_pose.get_batch_axes(), 85 | target_link_index.shape, 86 | ) 87 | broadcast_batch_axes = partial( 88 | jax.tree.map, 89 | lambda x: jnp.broadcast_to(x, batch_axes + x.shape[len(batch_axes) :]), 90 | ) 91 | get_actuated_joints = _get_actuated_joints_applied_to_target 92 | for _ in range(len(batch_axes)): 93 | get_actuated_joints = jax.vmap(get_actuated_joints) 94 | 95 | # Compute applied joints. 
# It's nice to pass arguments in explicitly instead of via closure in the
# `pose_cost_analytic_jac` wrapper. It helps jaxls vectorize repeated costs.
def _pose_cost_jac(
    vals: jaxls.VarValues,
    jac_cache: _PoseCostJacCache,
    robot: Robot,
    joint_var: jaxls.Var[jax.Array],
    target_pose: jaxlie.SE3,
    target_link_index: jax.Array,
    pos_weight: jax.Array | float,
    ori_weight: jax.Array | float,
    joints_applied_to_target: jax.Array,
) -> jax.Array:
    """Analytic Jacobian of the weighted pose residual.

    Args:
        vals: Unused; the needed quantities come from `jac_cache`.
        jac_cache: `(Ts_world_joint, Ts_world_link, pose_error)` as returned
            by the residual function.
        robot: Robot model providing joint twists and mimic multipliers.
        joint_var: Unused.
        target_pose: Unused.
        target_link_index: Index of the target link in the cached link poses.
        pos_weight: Weight applied to the first three Jacobian rows.
        ori_weight: Weight applied to the last three Jacobian rows.
        joints_applied_to_target: Per-joint actuated index, or -1 for joints
            that do not affect the target link.

    Returns:
        Array of shape `(6, num_actuated_joints)`.
    """
    del vals, joint_var, target_pose  # Unused!
    joint_poses_raw, link_poses_raw, pose_error = jac_cache

    T_world_target = jaxlie.SE3(link_poses_raw[target_link_index])
    T_world_joints = jaxlie.SE3(joint_poses_raw)
    R_target_world = T_world_target.rotation().inverse()
    R_world_joints = T_world_joints.rotation()

    # Per-joint twists, scaled so mimic joints move by their multiplier.
    twists = robot.joints.twists * robot.joints.mimic_multiplier[..., None]
    lin_local = twists[:, :3]
    ang_local = twists[:, 3:]

    # Angular part, expressed in the world frame.
    ang_world = R_world_joints @ ang_local

    # Linear part: v = omega x r + v_joint, with r pointing from each joint
    # to the target link origin.
    lever_arm = T_world_target.translation() - T_world_joints.translation()
    lin_world = jnp.cross(ang_world, lever_arm) + R_world_joints @ lin_local

    # Stack [v; omega] in the target frame and zero out joints off the path.
    on_path = joints_applied_to_target != -1
    per_joint = jnp.concatenate(
        [R_target_world @ lin_world, R_target_world @ ang_world],
        axis=1,
    )
    spatial_jac = jnp.where(on_path[:, None], per_joint, 0.0).T
    residual_jac = pose_error.jlog() @ spatial_jac

    # Jacobian of all joints => Jacobian of actuated joints. Several joints
    # (an actuated joint plus its mimics) may map to the same actuated column;
    # their contributions are summed.
    actuated_jac = (
        jnp.zeros((6, robot.joints.num_actuated_joints))
        .at[:, joints_applied_to_target]
        .add(on_path[None, :] * residual_jac)
    )

    # Apply weights.
    weights = jnp.array([pos_weight] * 3 + [ori_weight] * 3)
    return actuated_jac * weights[:, None]
@jaxls.Cost.create_factory(jac_custom_with_cache_fn=_pose_cost_jac)
def _pose_cost_analytical_jac(
    vals: jaxls.VarValues,
    robot: Robot,
    joint_var: jaxls.Var[jax.Array],
    target_pose: jaxlie.SE3,
    target_link_index: jax.Array,
    pos_weight: jax.Array | float,
    ori_weight: jax.Array | float,
    joints_applied_to_target: jax.Array,
) -> tuple[jax.Array, _PoseCostJacCache]:
    """Residual between a link's pose and its target pose.

    Returns:
        A pair `(residual, cache)`. `residual` is the weighted log of
        `target_pose^-1 @ T_world_link`; `cache` holds the joint poses, link
        poses and pose error for the custom Jacobian function.
    """
    # Only the Jacobian function consumes this argument.
    del joints_applied_to_target
    assert target_link_index.dtype == jnp.int32

    joint_poses = robot._forward_kinematics_joints(vals[joint_var])
    link_poses = robot._link_poses_from_joint_poses(joint_poses)

    T_world_target_link = jaxlie.SE3(link_poses[target_link_index, :])
    pose_error = target_pose.inverse() @ T_world_target_link

    weights = jnp.array([pos_weight] * 3 + [ori_weight] * 3)
    residual = pose_error.log() * weights
    # Second element is the cache handed to the custom Jacobian function.
    return residual, (joint_poses, link_poses, pose_error)
def _pose_cost_jac(
    vals: jaxls.VarValues,
    jac_cache: _PoseCostJacCache,
    robot: Robot,
    joint_var: jaxls.Var[jax.Array],
    target_pose: jaxlie.SE3,
    target_link_index: jax.Array,
    pos_weight: jax.Array | float,
    ori_weight: jax.Array | float,
    eps: float = 1e-4,
) -> jax.Array:
    """Forward-difference Jacobian of the weighted pose residual.

    Each entry of the joint configuration is perturbed by `eps` in turn; the
    change in the log pose error divided by `eps` forms one Jacobian column.

    Returns:
        Array of shape `(6, num_actuated_joints)`.
    """
    joint_cfg = vals[joint_var]
    _, _, pose_error = jac_cache

    def error_slope(idx: jax.Array) -> jax.Array:
        nudged_cfg = joint_cfg.at[idx].add(eps)
        link_poses = robot.forward_kinematics(nudged_cfg)
        T_world_ee = jaxlie.SE3(link_poses[..., target_link_index, :])
        nudged_error = target_pose.inverse() @ T_world_ee
        return (nudged_error.log() - pose_error.log()) / eps

    cfg_indices = jnp.arange(joint_cfg.shape[-1])
    # vmap yields one row per perturbed entry; transpose to (6, num_entries).
    jac = jax.vmap(error_slope)(cfg_indices).T
    assert jac.shape == (6, robot.joints.num_actuated_joints)

    weights = jnp.array([pos_weight] * 3 + [ori_weight] * 3)
    return jac * weights[:, None]


@jaxls.Cost.create_factory(jac_custom_with_cache_fn=_pose_cost_jac)
def pose_cost_numerical_jac(
    vals: jaxls.VarValues,
    robot: Robot,
    joint_var: jaxls.Var[jax.Array],
    target_pose: jaxlie.SE3,
    target_link_index: jax.Array,
    pos_weight: jax.Array | float,
    ori_weight: jax.Array | float,
    eps: float = 1e-4,
) -> tuple[jax.Array, _PoseCostJacCache]:
    """Residual between a link's pose and its target pose.

    `eps` is only consumed by the finite-difference Jacobian function.

    Returns:
        A pair `(residual, cache)`; `cache` holds the joint poses, link poses
        and pose error for the custom Jacobian function.
    """
    del eps  # Unused!
    assert target_link_index.dtype == jnp.int32

    joint_poses = robot._forward_kinematics_joints(vals[joint_var])
    link_poses = robot._link_poses_from_joint_poses(joint_poses)

    T_world_target_link = jaxlie.SE3(link_poses[target_link_index, :])
    pose_error = target_pose.inverse() @ T_world_target_link

    weights = jnp.array([pos_weight] * 3 + [ori_weight] * 3)
    residual = pose_error.log() * weights
    # Second element is the cache handed to the custom Jacobian function.
    return residual, (joint_poses, link_poses, pose_error)
@contextlib.contextmanager
def stopwatch(label: str = "unlabeled block") -> Generator[None, None, None]:
    """Context manager that prints how long the wrapped block took.

    A header with `label` is printed on entry; the elapsed time in seconds
    and a footer are printed on exit, including when the block raises (the
    exception still propagates).

    Args:
        label: Name shown in the header line.
    """
    # perf_counter() is monotonic, so the measurement is unaffected by
    # system clock adjustments.
    start_time = time.perf_counter()
    print("\n========")
    print(f"Running ({label})")
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start_time
        print(f"{termcolor.colored(str(elapsed), attrs=['bold'])} seconds")
        print("========")


def _log(fmt: str, *args, **kwargs) -> None:
    """Forward a message to loguru at info level, bound with `function="log"`."""
    logger.bind(function="log").info(fmt, *args, **kwargs)


def jax_log(fmt: str, *args, **kwargs) -> None:
    """Emit a loguru info message from a JITed JAX function.

    Args:
        fmt: Message template, passed to `logger.info` together with the
            arguments below.
        *args: Positional values delivered to the host via `jax.debug.callback`.
        **kwargs: Keyword values delivered the same way.
    """
    jax.debug.callback(partial(_log, fmt), *args, **kwargs)
-------------------------------------------------------------------------------- /src/pyroki/viewer/__init__.py: -------------------------------------------------------------------------------- 1 | # from ._batched_urdf import BatchedURDF as BatchedURDF 2 | from ._weight_tuner import WeightTuner as WeightTuner 3 | from ._manipulability_ellipse import ManipulabilityEllipse as ManipulabilityEllipse 4 | -------------------------------------------------------------------------------- /src/pyroki/viewer/_batched_urdf.py: -------------------------------------------------------------------------------- 1 | """Batched URDF rendering in Viser. This requires features not yet available in 2 | the PyPi version of Viser, we can re-enable it after the next Viser release.""" 3 | 4 | # import jax.numpy as jnp 5 | # import jaxlie 6 | # import numpy as onp 7 | # import viser 8 | # import viser.transforms as vtf 9 | # import yourdfpy 10 | # from jax.typing import ArrayLike 11 | # from pyroki._robot import Robot 12 | # 13 | # from viser.extras import BatchedGlbHandle 14 | # 15 | # 16 | # class BatchedURDF: 17 | # """ 18 | # Helper for rendering batched URDFs in Viser. 19 | # Similar to `viser.extras.ViserUrdf`, but batched using `pyroki`'s batched forward kinematics. 20 | # 21 | # If num_robots > 1 then the URDF meshes are rendered as batched meshes (instancing), 22 | # otherwise they are rendered as individual meshes. 23 | # 24 | # Args: 25 | # target: Viser server or client handle to add URDF to. 26 | # urdf: URDF to render. 27 | # num_robots: Number of robots in the batch. 28 | # root_node_name: Name of the root node in the Viser scene. 
29 | # """ 30 | # 31 | # def __init__( 32 | # self, 33 | # target: viser.ViserServer | viser.ClientHandle, 34 | # urdf: yourdfpy.URDF, 35 | # num_robots: int = 1, 36 | # root_node_name: str = "/", 37 | # ): 38 | # assert root_node_name.startswith("/") 39 | # robot = Robot.from_urdf(urdf) 40 | # 41 | # self._urdf = urdf 42 | # self._robot = robot 43 | # self._target = target 44 | # self._root_node_name = root_node_name 45 | # self._num_robots = num_robots 46 | # 47 | # # Initialize base transforms to identity. 48 | # self._base_transforms = jaxlie.SE3.identity(batch_axes=(num_robots,)) 49 | # self._last_cfg = None # Store the last configuration. 50 | # 51 | # self._populate() 52 | # 53 | # def _populate(self): 54 | # # Initialize with the correct batch size. 55 | # dummy_transform = vtf.SE3.identity(batch_axes=(self._num_robots,)) 56 | # dummy_position = dummy_transform.translation() 57 | # dummy_wxyz = dummy_transform.rotation().wxyz 58 | # 59 | # self._meshes: dict[str, list[BatchedGlbHandle | viser.GlbHandle]] = {} 60 | # self._link_to_meshes: dict[str, onp.ndarray] = {} 61 | # 62 | # # Check if add_batched_meshes_trimesh is available. 63 | # if ( 64 | # not hasattr(self._target.scene, "add_batched_meshes_trimesh") 65 | # and self._num_robots > 1 66 | # ): 67 | # raise NotImplementedError( 68 | # "num_robots > 1, but viser doesn't support instancing " 69 | # "(add_batched_meshes_trimesh is not available)." 70 | # ) 71 | # 72 | # for mesh_name, mesh in self._urdf.scene.geometry.items(): 73 | # link_name = self._urdf.scene.graph.transforms.parents[mesh_name] 74 | # if link_name not in self._meshes: 75 | # self._meshes[link_name] = [] 76 | # 77 | # # Put mesh in the link frame. 
78 | # T_parent_child = self._urdf.get_transform( 79 | # mesh_name, self._urdf.scene.graph.transforms.parents[mesh_name] 80 | # ) 81 | # mesh = mesh.copy() 82 | # mesh.apply_transform(T_parent_child) 83 | # 84 | # if self._num_robots > 1: 85 | # self._meshes[link_name].append( 86 | # self._target.scene.add_batched_meshes_trimesh( # type: ignore[attr-defined] 87 | # f"{self._root_node_name}/{mesh_name}", 88 | # mesh, 89 | # batched_positions=dummy_position, 90 | # batched_wxyzs=dummy_wxyz, 91 | # lod="auto", 92 | # ) 93 | # ) 94 | # else: 95 | # self._meshes[link_name].append( 96 | # self._target.scene.add_mesh_trimesh( 97 | # f"{self._root_node_name}/{mesh_name}", 98 | # mesh, 99 | # position=dummy_position[0], 100 | # wxyz=dummy_wxyz[0], 101 | # ) 102 | # ) 103 | # 104 | # self._link_to_meshes[link_name] = T_parent_child 105 | # 106 | # def remove(self): 107 | # for meshes in self._meshes.values(): 108 | # for mesh in meshes: 109 | # mesh.remove() 110 | # 111 | # def update_base_frame(self, base_transforms: ArrayLike): 112 | # """ 113 | # Update the base transforms for each robot in the batch. 114 | # 115 | # Args: 116 | # base_transforms: New base transforms. Should be a JAX-compatible array 117 | # representing SE(3) transforms (e.g., a jaxlie.SE3 object) 118 | # with shape (num_robots,). 
119 | # """ 120 | # base_transforms_jnp = jnp.array(base_transforms) 121 | # base_transforms_jnp = jnp.atleast_2d(base_transforms_jnp) 122 | # assert base_transforms_jnp.shape[0] == self._num_robots, ( 123 | # f"Expected first dimension of base_transforms to be {self._num_robots}, got {base_transforms_jnp.shape[0]}" 124 | # ) 125 | # 126 | # self._base_transforms = jaxlie.SE3(base_transforms_jnp) 127 | # 128 | # # Re-apply transforms if a configuration exists 129 | # if self._last_cfg is not None: 130 | # self._apply_transforms(self._last_cfg) 131 | # 132 | # def update_cfg(self, cfg: ArrayLike): 133 | # """ 134 | # Update the poses of the batched robots based on their configurations. 135 | # 136 | # Args: 137 | # cfg: Batched joint configurations. Shape should be (num_robots, num_dofs), or (num_dofs,). 138 | # """ 139 | # cfg_jax = jnp.array(cfg) # in case cfg is an onp.ndarray. 140 | # cfg_jax = jnp.atleast_2d(cfg_jax) 141 | # assert cfg_jax.shape[0] == self._num_robots, ( 142 | # f"Expected first dimension of cfg to be {self._num_robots}, got {cfg_jax.shape[0]}" 143 | # ) 144 | # 145 | # # Store the latest configuration 146 | # self._last_cfg = cfg_jax 147 | # self._apply_transforms(cfg_jax) 148 | # 149 | # def _apply_transforms(self, cfg_jax: jnp.ndarray): 150 | # """Helper method to apply FK and base transforms to update meshes.""" 151 | # # Ts_link_world should have shape (num_robots, num_links, ...) 152 | # Ts_link_world = self._robot.forward_kinematics(cfg_jax) 153 | # 154 | # for link_name, meshes in self._meshes.items(): 155 | # link_idx = self._robot.links.names.index(link_name) 156 | # # T_link_world has shape (num_robots, ...) 157 | # T_link_world = jaxlie.SE3( 158 | # Ts_link_world[:, link_idx] 159 | # ) # Select link transforms for all robots 160 | # 161 | # # Apply base transform: T_mesh_world = T_base * T_link_world 162 | # # Resulting shape is (num_robots, ...) 
class ManipulabilityEllipse:
    """Helper class to visualize the manipulability ellipsoid for a robot link."""

    def __init__(
        self,
        server: viser.ViserServer | viser.ClientHandle,
        robot: Robot,
        root_node_name: str = "/manipulability",
        target_link_name: Optional[str] = None,
        scaling_factor: float = 0.2,
        visible: bool = True,
        wireframe: bool = True,
        color: Tuple[int, int, int] = (200, 200, 255),
    ):
        """Initializes the manipulability ellipsoid visualizer.

        Args:
            server: The Viser server or client handle.
            robot: The Pyroki robot model.
            root_node_name: The base name for the ellipsoid mesh in the Viser scene.
            target_link_name: Optional name of the link to visualize the ellipsoid for initially.
            scaling_factor: Scaling factor applied to the ellipsoid dimensions.
            visible: Initial visibility state.
            wireframe: Whether to render the ellipsoid as a wireframe.
            color: The color of the ellipsoid mesh.
        """
        self._server = server
        self._robot = robot
        self._root_node_name = root_node_name
        self._target_link_name = target_link_name
        self._scaling_factor = scaling_factor
        # User-requested visibility. The mesh is only shown when this is True
        # *and* a valid target link is set.
        self._visible = visible
        self._wireframe = wireframe
        self._color = color

        self._base_manip_sphere = trimesh.creation.icosphere(radius=1.0)
        self._mesh_handle: Optional[viser.MeshHandle] = None
        self._target_link_index: Optional[int] = None
        self._last_joints: Optional[jnp.ndarray] = None
        # Defined before any method runs so the attribute always exists.
        self.manipulability = 0.0

        # Initial creation of the mesh handle (hidden if not visible)
        self._create_mesh_handle()

        # Set initial target link if provided
        self.set_target_link(target_link_name)

    def _create_mesh_handle(self):
        """Creates or recreates the mesh handle in the Viser scene."""
        if self._mesh_handle is not None:
            self._mesh_handle.remove()

        # Create with dummy data initially, will be updated
        self._mesh_handle = self._server.scene.add_mesh_simple(
            self._root_node_name,
            vertices=onp.zeros((1, 3), dtype=onp.float32),
            faces=onp.zeros((1, 3), dtype=onp.uint32),
            color=self._color,
            wireframe=self._wireframe,
            visible=self._visible,
        )

        # Viser version compatibility.
        if hasattr(self._mesh_handle, "cast_shadow"):
            self._mesh_handle.cast_shadow = (  # type: ignore[attr-defined]
                False  # Ellipsoids usually don't need shadows
            )

    def set_target_link(self, link_name: Optional[str]):
        """Sets the target link for which to display the ellipsoid.

        The mesh is hidden while there is no valid target link. The visibility
        requested through the constructor or `set_visibility` is left
        untouched, so the ellipsoid reappears once a valid link is set again.

        Args:
            link_name: The name of the target link, or None to disable.
        """
        if link_name is None:
            self._target_link_index = None
        else:
            try:
                self._target_link_index = self._robot.links.names.index(link_name)
            except ValueError:
                logger.warning(f"Link name '{link_name}' not found in robot model.")
                self._target_link_index = None

        if self._mesh_handle is None:
            return

        # Only touch the handle, never `self._visible`: hiding for lack of a
        # target must not overwrite what the user asked for.
        should_show = self._target_link_index is not None and self._visible
        if self._mesh_handle.visible != should_show:
            self._mesh_handle.visible = should_show

    def update(self, joints: ArrayLike):
        """Updates the ellipsoid based on the current joint configuration.

        Also stores the resulting manipulability measure in
        `self.manipulability`. If the computation fails, a warning is logged
        and the mesh is hidden.

        Args:
            joints: The current joint angles of the robot.
        """
        joints = jnp.asarray(joints)
        # Remember the configuration even while hidden, so that turning
        # visibility back on can rebuild the geometry right away.
        self._last_joints = joints

        if (
            self._target_link_index is None
            or not self._visible
            or self._mesh_handle is None
        ):
            # Ensure mesh is hidden if it shouldn't be shown
            if self._mesh_handle is not None and self._mesh_handle.visible:
                self._mesh_handle.visible = False
            return

        # Ensure mesh is visible if it should be
        if not self._mesh_handle.visible:
            self._mesh_handle.visible = True

        try:
            # --- Jacobian Calculation ---
            # Translational Jacobian of the target link w.r.t. the joints.
            jacobian = jax.jacfwd(
                lambda q: jaxlie.SE3(self._robot.forward_kinematics(q)).translation()
            )(joints)[self._target_link_index]
            assert jacobian.shape == (3, self._robot.joints.num_actuated_joints)

            # --- Manipulability Calculation ---
            # J @ J.T is both the matrix whose determinant gives the
            # manipulability and the one whose eigenvectors shape the ellipsoid.
            cov_matrix = jacobian @ jacobian.T
            assert cov_matrix.shape == (3, 3)
            self.manipulability = jnp.sqrt(
                jnp.maximum(0.0, jnp.linalg.det(cov_matrix))
            ).item()

            # --- Eigen decomposition ---
            # Use numpy for Eigh as it might be more stable for visualization
            vals, vecs = onp.linalg.eigh(onp.array(cov_matrix))
            vals = onp.maximum(vals, 1e-9)  # Clamp small eigenvalues for stability

            # --- Get Target Link Pose ---
            Ts_link_world_array = self._robot.forward_kinematics(joints)
            target_pose_array = Ts_link_world_array[self._target_link_index]
            target_pose = jaxlie.SE3(target_pose_array)
            target_pos = onp.array(target_pose.translation().squeeze())

            # --- Create and Transform Ellipsoid Mesh ---
            ellipsoid_mesh = self._base_manip_sphere.copy()
            tf = onp.eye(4)
            tf[:3, :3] = onp.array(vecs)  # Rotation from eigenvectors
            tf[:3, 3] = target_pos  # Translation to link origin

            # Apply scaling according to eigenvalues and the scaling factor
            ellipsoid_mesh.apply_scale(onp.sqrt(vals) * self._scaling_factor)
            # Apply the final transform
            ellipsoid_mesh.apply_transform(tf)

            # --- Update Viser Mesh ---
            self._mesh_handle.vertices = onp.array(
                ellipsoid_mesh.vertices, dtype=onp.float32
            )
            self._mesh_handle.faces = onp.array(ellipsoid_mesh.faces, dtype=onp.uint32)

        except Exception as e:
            # Best effort: a failed visualization should not crash the caller.
            logger.warning(f"Failed to update manipulability ellipsoid: {e}")
            # Hide the mesh on failure
            if self._mesh_handle is not None:
                self._mesh_handle.visible = False

    def set_visibility(self, visible: bool):
        """Sets the visibility of the ellipsoid mesh."""
        self._visible = visible
        if self._mesh_handle is not None:
            # If visibility is being turned on, and we have a target link and joints,
            # trigger an update to ensure the geometry is correct.
            if (
                visible
                and self._target_link_index is not None
                and self._last_joints is not None
            ):
                self.update(self._last_joints)  # Recalculate and show
            # Otherwise, just set the visibility flag on the handle
            elif self._mesh_handle.visible != visible:
                self._mesh_handle.visible = visible

    def remove(self):
        """Removes the ellipsoid mesh from the Viser scene."""
        if self._mesh_handle is not None:
            self._mesh_handle.remove()
            self._mesh_handle = None
        self._target_link_index = None  # Clear target when removed
class WeightTuner:
    """Creates and manages a set of Viser GUI sliders for tuning named weights.

    This class simplifies the process of adding interactive controls to a Viser
    application, typically used for adjusting weights in optimization problems
    (like inverse kinematics) or other numerical parameters in real-time.
    The sliders are grouped within a Viser GUI folder.
    """

    _server: viser.ViserServer
    _weight_handles: dict[str, viser.GuiSliderHandle]

    _min: dict[str, float]
    _max: dict[str, float]
    _step: dict[str, float]
    _default: dict[str, float]

    def __init__(
        self,
        server: viser.ViserServer,
        default: dict[str, float],
        *,
        folder_name: str = "Costs",
        min: dict[str, float] | None = None,
        max: dict[str, float] | None = None,
        step: dict[str, float] | None = None,
        default_min: float = 0.0,
        default_max: float = 100.0,
        default_step: float = 0.01,
    ):
        """Initializes the tuner and creates the Viser GUI sliders.

        The sliders show up in the order of the keys in the `default`
        dictionary, followed by a "Reset Weights" button.

        Args:
            server: The Viser server instance.
            default: Dictionary mapping each weight name to its initial
                (int or float) value. A copy is kept for `reset_weights`.
            folder_name: Name of the Viser GUI folder to contain the sliders.
            min: Minimum value for each slider.
            max: Maximum value for each slider.
            step: Step size for each slider.
            default_min: Minimum value for all sliders. `min` overrides this.
            default_max: Maximum value for all sliders. `max` overrides this.
            default_step: Step size for all sliders. `step` overrides this.
        """
        # Check the container type first so a non-dict input is reported as
        # such rather than as a bad leaf value.
        assert isinstance(default, dict)
        leaves = jax.tree.leaves(default)
        assert all(isinstance(leaf, (int, float)) for leaf in leaves), (
            "All default parameters must be ints or floats."
        )

        self._server = server
        self._weight_handles = {}

        # Start from the global bounds, then apply per-weight overrides.
        self._max = jax.tree.map(lambda _: default_max, default)
        if max is not None:
            cast(dict, self._max).update(max)

        self._min = jax.tree.map(lambda _: default_min, default)
        if min is not None:
            cast(dict, self._min).update(min)

        self._step = jax.tree.map(lambda _: default_step, default)
        if step is not None:
            cast(dict, self._step).update(step)

        # Snapshot the defaults: if the caller mutates their dict later,
        # "Reset Weights" must still restore the original values.
        self._default = dict(default)

        with server.gui.add_folder(folder_name):
            for field, default_weight in self._default.items():
                self._weight_handles[field] = server.gui.add_slider(
                    field,
                    min=self._min[field],
                    max=self._max[field],
                    step=self._step[field],
                    initial_value=default_weight,
                )

            reset_button = server.gui.add_button("Reset Weights")
            reset_button.on_click(lambda _: self.reset_weights())

    def get_weights(self) -> dict[str, float]:
        """Retrieves the current values of all tracked weights from the GUI sliders.

        Returns:
            A dictionary mapping weight names to their current float values
            as set by the sliders.
        """
        return {field: handle.value for field, handle in self._weight_handles.items()}

    def reset_weights(self):
        """Resets all weights to their initial values."""
        for name, handle in self._weight_handles.items():
            handle.value = self._default[name]