├── .flake8 ├── .github └── workflows │ └── ubuntu.yml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── assets ├── ant.png ├── humanoid.png ├── joint_angle_editor.gif ├── kuka.png ├── logo.png └── simple_arm.png ├── examples ├── ant │ └── ant.xml ├── ant_example.py ├── humanoid │ └── humanoid.xml ├── humanoid_example.py ├── kuka_example.py ├── kuka_iiwa │ ├── meshes │ │ ├── link_0.obj │ │ ├── link_1.obj │ │ ├── link_2.obj │ │ ├── link_3.obj │ │ ├── link_4.obj │ │ ├── link_5.obj │ │ ├── link_6.obj │ │ └── link_7.obj │ └── model.urdf ├── mycobot │ ├── G_base.ply │ ├── camera_flange.ply │ ├── joint1.ply │ ├── joint1.png │ ├── joint2.ply │ ├── joint2.png │ ├── joint3.ply │ ├── joint3.png │ ├── joint4.ply │ ├── joint4.png │ ├── joint5.ply │ ├── joint5.png │ ├── joint6.ply │ ├── joint6.png │ ├── joint7.ply │ ├── joint7.png │ ├── mycobot.urdf │ └── pump_head.ply ├── mycobot_example.py ├── simple_arm │ └── model.sdf ├── simple_arm_example.py ├── ur │ └── ur.urdf └── ur_example.py ├── kinpy ├── __init__.py ├── chain.py ├── frame.py ├── ik.py ├── jacobian.py ├── mjcf.py ├── mjcf_parser │ ├── __init__.py │ ├── attribute.py │ ├── base.py │ ├── constants.py │ ├── copier.py │ ├── debugging.py │ ├── element.py │ ├── io.py │ ├── namescope.py │ ├── parser.py │ ├── schema.py │ ├── schema.xml │ └── util.py ├── sdf.py ├── transform.py ├── urdf.py ├── urdf_parser_py │ ├── __init__.py │ ├── sdf.py │ ├── urdf.py │ └── xml_reflection │ │ ├── __init__.py │ │ ├── basics.py │ │ └── core.py └── visualizer.py ├── pyproject.toml ├── scripts └── kpviewer.py ├── setup.py └── tests ├── __init__.py ├── test_fkik.py ├── test_jacobian.py └── test_transform.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 119 3 | exclude = __init__.py,__pycache__ 4 | ignore = E121,E123,E126,E133,E226,E203,E241,E242,E704,W503,W504,W505,E127,E266,E402,W605,W391,E701,E731 5 | 
-------------------------------------------------------------------------------- /.github/workflows/ubuntu.yml: -------------------------------------------------------------------------------- 1 | name: Ubuntu CI 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | tags: ['v*'] 7 | pull_request: 8 | branches: [ master ] 9 | 10 | jobs: 11 | test: 12 | runs-on: ubuntu-latest 13 | strategy: 14 | matrix: 15 | python-version: ['3.9', '3.10', '3.11'] 16 | steps: 17 | - name: Checkout source code 18 | uses: actions/checkout@v4 19 | with: 20 | submodules: true 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v5 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | - name: Install dependencies 26 | run: | 27 | sudo apt-get update 28 | sudo apt-get install -y curl 29 | - name: Install poetry 30 | run: | 31 | curl -sSL https://install.python-poetry.org | python3 - 32 | # The official installer puts the poetry executable in ~/.local/bin, not the legacy ~/.poetry/bin. 33 | echo "$HOME/.local/bin" >> $GITHUB_PATH 34 | - name: Setup 35 | run: make setup 36 | - name: Test 37 | run: make test 38 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.~ 3 | *.egg-info 4 | build 5 | dist -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 neka-nat 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and 
this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include kinpy/mjcf_parser/schema.xml 2 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | setup: 2 | poetry install --no-interaction 3 | poetry run pip install -e . 4 | 5 | test: 6 | find kinpy/. -maxdepth 1 -type f -name "*.py" | xargs poetry run flake8 7 | poetry run mypy kinpy/*.py 8 | poetry run python -m unittest discover -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 2 | 3 | [![Build status](https://github.com/neka-nat/kinpy/actions/workflows/ubuntu.yml/badge.svg)](https://github.com/neka-nat/kinpy/actions/workflows/ubuntu.yml/badge.svg) 4 | [![PyPI version](https://badge.fury.io/py/kinpy.svg)](https://badge.fury.io/py/kinpy) 5 | [![MIT License](http://img.shields.io/badge/license-MIT-blue.svg?style=flat)](LICENSE) 6 | [![Downloads](https://pepy.tech/badge/kinpy)](https://pepy.tech/project/kinpy) 7 | 8 | Simple kinematics body toolkit. 9 | 10 | ## Core features 11 | 12 | * Pure python library. 
13 | * Supports URDF, SDF and MJCF files. 14 | * Calculates FK, IK and Jacobians. 15 | 16 | ![joint_angle_editor](assets/joint_angle_editor.gif) 17 | 18 | ## Installation 19 | 20 | ``` 21 | pip install kinpy 22 | ``` 23 | 24 | ## Getting started 25 | Here is a program that reads a URDF file and builds a kinematic chain from it. 26 | 27 | ```py 28 | import kinpy as kp 29 | 30 | chain = kp.build_chain_from_urdf(open("kuka_iiwa/model.urdf").read()) 31 | print(chain) 32 | # lbr_iiwa_link_0_frame 33 | # └──── lbr_iiwa_link_1_frame 34 | # └──── lbr_iiwa_link_2_frame 35 | # └──── lbr_iiwa_link_3_frame 36 | # └──── lbr_iiwa_link_4_frame 37 | # └──── lbr_iiwa_link_5_frame 38 | # └──── lbr_iiwa_link_6_frame 39 | # └──── lbr_iiwa_link_7_frame 40 | ``` 41 | 42 | Print the names of the joint parameters in the chain. 43 | 44 | ```py 45 | print(chain.get_joint_parameter_names()) 46 | # ['lbr_iiwa_joint_1', 'lbr_iiwa_joint_2', 'lbr_iiwa_joint_3', 'lbr_iiwa_joint_4', 'lbr_iiwa_joint_5', 'lbr_iiwa_joint_6', 'lbr_iiwa_joint_7'] 47 | ``` 48 | 49 | Given joint angle values, calculate the forward kinematics. 50 | 51 | ```py 52 | import math 53 | th = {'lbr_iiwa_joint_2': math.pi / 4.0, 'lbr_iiwa_joint_4': math.pi / 2.0} 54 | ret = chain.forward_kinematics(th) 55 | # {'lbr_iiwa_link_0': Transform(rot=[1. 0. 0. 0.], pos=[0. 0. 0.]), 'lbr_iiwa_link_1': Transform(rot=[1. 0. 0. 0.], pos=[0. 0. 0.1575]), 'lbr_iiwa_link_2': Transform(rot=[-0.27059805 0.27059805 0.65328148 0.65328148], pos=[0. 0. 
0.36]), 'lbr_iiwa_link_3': Transform(rot=[-9.23879533e-01 3.96044251e-14 -3.82683432e-01 -1.96942462e-12], pos=[ 1.44603337e-01 -6.78179735e-13 5.04603337e-01]), 'lbr_iiwa_link_4': Transform(rot=[-0.65328148 -0.65328148 0.27059805 -0.27059805], pos=[ 2.96984848e-01 -3.37579445e-13 6.56984848e-01]), 'lbr_iiwa_link_5': Transform(rot=[ 2.84114655e-12 3.82683432e-01 -1.87377891e-12 -9.23879533e-01], pos=[ 1.66523647e-01 -1.00338887e-12 7.87446049e-01]), 'lbr_iiwa_link_6': Transform(rot=[-0.27059805 0.27059805 -0.65328148 -0.65328148], pos=[ 1.41421356e-02 -7.25873884e-13 9.39827561e-01]), 'lbr_iiwa_link_7': Transform(rot=[ 9.23879533e-01 2.61060896e-12 -3.82683432e-01 4.81056861e-12], pos=[-4.31335137e-02 -1.01819561e-12 9.97103210e-01])} 56 | ``` 57 | 58 | You can get the position and orientation of each link. 59 | 60 | If you want to use IK or Jacobian, you need to create a `SerialChain`. 61 | When creating a `SerialChain`, an end effector must be specified. 62 | 63 | ```py 64 | chain = kp.build_serial_chain_from_urdf(open("kuka_iiwa/model.urdf"), "lbr_iiwa_link_7") 65 | th = [0.0, -math.pi / 4.0, 0.0, math.pi / 2.0, 0.0, math.pi / 4.0, 0.0] 66 | ret = chain.forward_kinematics(th, end_only=True) 67 | # chain.inverse_kinematics(ret) 68 | # chain.jacobian(th) 69 | ``` 70 | 71 | ## Visualization 72 | 73 | ### KUKA IIWA 74 | ![kuka](https://raw.githubusercontent.com/neka-nat/kinpy/master/assets/kuka.png) 75 | 76 | ### Mujoco humanoid 77 | ![humanoid](https://raw.githubusercontent.com/neka-nat/kinpy/master/assets/humanoid.png) 78 | 79 | ### Mujoco ant 80 | ![ant](https://raw.githubusercontent.com/neka-nat/kinpy/master/assets/ant.png) 81 | 82 | ### Simple arm 83 | ![simple_arm](https://raw.githubusercontent.com/neka-nat/kinpy/master/assets/simple_arm.png) 84 | 85 | ## Citing 86 | 87 | ``` 88 | @software{kinpy, 89 | author = {{Kenta-Tanaka et al.}}, 90 | title = {kinpy}, 91 | url = {https://github.com/neka-nat/kinpy}, 92 | version = {0.0.3}, 93 | date = {2019-10-11}, 94 | } 
95 | ``` 96 | -------------------------------------------------------------------------------- /assets/ant.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/assets/ant.png -------------------------------------------------------------------------------- /assets/humanoid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/assets/humanoid.png -------------------------------------------------------------------------------- /assets/joint_angle_editor.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/assets/joint_angle_editor.gif -------------------------------------------------------------------------------- /assets/kuka.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/assets/kuka.png -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/assets/logo.png -------------------------------------------------------------------------------- /assets/simple_arm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/assets/simple_arm.png -------------------------------------------------------------------------------- /examples/ant/ant.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 72 | 
-------------------------------------------------------------------------------- /examples/ant_example.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import kinpy as kp 3 | 4 | chain = kp.build_chain_from_mjcf(open("ant/ant.xml")) 5 | print(chain) 6 | print(chain.get_joint_parameter_names()) 7 | th = { 8 | "hip_1": 0.0, 9 | "ankle_1": np.pi / 4.0, 10 | "hip_2": 0.0, 11 | "ankle_2": -np.pi / 4.0, 12 | "hip_3": 0.0, 13 | "ankle_3": -np.pi / 4.0, 14 | "hip_4": 0.0, 15 | "ankle_4": np.pi / 4.0, 16 | } 17 | ret = chain.forward_kinematics(th) 18 | print(ret) 19 | viz = kp.Visualizer() 20 | viz.add_robot(ret, chain.visuals_map()) 21 | viz.spin() 22 | -------------------------------------------------------------------------------- /examples/humanoid/humanoid.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 8 | 9 | 10 | 11 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 23 | 25 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | -------------------------------------------------------------------------------- /examples/humanoid_example.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import kinpy as kp 3 | 4 | chain = kp.build_chain_from_mjcf(open("humanoid/humanoid.xml")) 5 | print(chain) 6 | print(chain.get_joint_parameter_names()) 7 | th = {"left_knee": 0.0, "right_knee": 0.0} 8 | ret = 
chain.forward_kinematics(th) 9 | print(ret) 10 | viz = kp.Visualizer() 11 | viz.add_robot(ret, chain.visuals_map(), axes=True) 12 | viz.spin() 13 | -------------------------------------------------------------------------------- /examples/kuka_example.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import kinpy as kp 3 | 4 | chain = kp.build_serial_chain_from_urdf(open("kuka_iiwa/model.urdf"), "lbr_iiwa_link_7") 5 | print(chain) 6 | print(chain.get_joint_parameter_names()) 7 | th = [0.0, -np.pi / 4.0, 0.0, np.pi / 2.0, 0.0, np.pi / 4.0, 0.0] 8 | ret = chain.forward_kinematics(th, end_only=False) 9 | print(ret) 10 | viz = kp.Visualizer() 11 | viz.add_robot(ret, chain.visuals_map(), mesh_file_path="kuka_iiwa/", axes=True) 12 | viz.spin() 13 | -------------------------------------------------------------------------------- /examples/kuka_iiwa/model.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 
185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | 241 | 242 | 243 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | 254 | 255 | 256 | 257 | 258 | 259 | 260 | 261 | 262 | 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | 282 | 283 | 284 | 285 | 286 | 287 | 288 | 289 | 290 | -------------------------------------------------------------------------------- /examples/mycobot/G_base.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/examples/mycobot/G_base.ply -------------------------------------------------------------------------------- /examples/mycobot/camera_flange.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/examples/mycobot/camera_flange.ply -------------------------------------------------------------------------------- /examples/mycobot/joint1.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/examples/mycobot/joint1.ply -------------------------------------------------------------------------------- /examples/mycobot/joint1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/examples/mycobot/joint1.png -------------------------------------------------------------------------------- 
/examples/mycobot/joint2.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/examples/mycobot/joint2.ply -------------------------------------------------------------------------------- /examples/mycobot/joint2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/examples/mycobot/joint2.png -------------------------------------------------------------------------------- /examples/mycobot/joint3.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/examples/mycobot/joint3.ply -------------------------------------------------------------------------------- /examples/mycobot/joint3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/examples/mycobot/joint3.png -------------------------------------------------------------------------------- /examples/mycobot/joint4.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/examples/mycobot/joint4.ply -------------------------------------------------------------------------------- /examples/mycobot/joint4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/examples/mycobot/joint4.png -------------------------------------------------------------------------------- /examples/mycobot/joint5.ply: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/examples/mycobot/joint5.ply -------------------------------------------------------------------------------- /examples/mycobot/joint5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/examples/mycobot/joint5.png -------------------------------------------------------------------------------- /examples/mycobot/joint6.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/examples/mycobot/joint6.ply -------------------------------------------------------------------------------- /examples/mycobot/joint6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/examples/mycobot/joint6.png -------------------------------------------------------------------------------- /examples/mycobot/joint7.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/examples/mycobot/joint7.ply -------------------------------------------------------------------------------- /examples/mycobot/joint7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/examples/mycobot/joint7.png -------------------------------------------------------------------------------- /examples/mycobot/mycobot.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 
30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | 241 | -------------------------------------------------------------------------------- /examples/mycobot/pump_head.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/examples/mycobot/pump_head.ply -------------------------------------------------------------------------------- /examples/mycobot_example.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import kinpy as kp 3 | 4 | chain = kp.build_serial_chain_from_urdf(open("mycobot/mycobot.urdf").read(), "pump_head") 5 | print(chain) 6 | 7 | print(chain.get_joint_parameter_names()) 8 | th = np.deg2rad([0, 20, -130, 20, 0, 0]) 9 | viz = kp.JointAngleEditor(chain, 
"mycobot/", initial_state=th, axes=True) 10 | viz.spin() 11 | -------------------------------------------------------------------------------- /examples/simple_arm/model.sdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 0 0 0.00099 0 0 0 7 | 8 | 1.11 9 | 0 10 | 0 11 | 100.11 12 | 0 13 | 1.01 14 | 15 | 101.0 16 | 17 | 18 | 0 0 0.05 0 0 0 19 | 20 | 21 | 1.0 1.0 0.1 22 | 23 | 24 | 25 | 26 | 0 0 0.05 0 0 0 27 | 28 | 29 | 1.0 1.0 0.1 30 | 31 | 32 | 33 | 37 | 38 | 39 | 40 | 0 0 0.6 0 0 0 41 | 42 | 43 | 0.05 44 | 1.0 45 | 46 | 47 | 48 | 49 | 0 0 0.6 0 0 0 50 | 51 | 52 | 0.05 53 | 1.0 54 | 55 | 56 | 57 | 61 | 62 | 63 | 64 | 65 | 0 0 1.1 0 0 0 66 | 67 | 0.045455 0 0 0 0 0 68 | 69 | 0.011 70 | 0 71 | 0 72 | 0.0225 73 | 0 74 | 0.0135 75 | 76 | 1.1 77 | 78 | 79 | 0 0 0.05 0 0 0 80 | 81 | 82 | 0.05 83 | 0.1 84 | 85 | 86 | 87 | 88 | 0 0 0.05 0 0 0 89 | 90 | 91 | 0.05 92 | 0.1 93 | 94 | 95 | 96 | 100 | 101 | 102 | 103 | 0.55 0 0.05 0 0 0 104 | 105 | 106 | 1.0 0.05 0.1 107 | 108 | 109 | 110 | 111 | 0.55 0 0.05 0 0 0 112 | 113 | 114 | 1.0 0.05 0.1 115 | 116 | 117 | 118 | 122 | 123 | 124 | 125 | 126 | 1.05 0 1.1 0 0 0 127 | 128 | 0.0875 0 0.083333 0 0 0 129 | 130 | 0.031 131 | 0 132 | 0.005 133 | 0.07275 134 | 0 135 | 0.04475 136 | 137 | 1.2 138 | 139 | 140 | 0 0 0.1 0 0 0 141 | 142 | 143 | 0.05 144 | 0.2 145 | 146 | 147 | 148 | 149 | 0 0 0.1 0 0 0 150 | 151 | 152 | 0.05 153 | 0.2 154 | 155 | 156 | 157 | 161 | 162 | 163 | 164 | 0.3 0 0.15 0 0 0 165 | 166 | 167 | 0.5 0.03 0.1 168 | 169 | 170 | 171 | 172 | 0.3 0 0.15 0 0 0 173 | 174 | 175 | 0.5 0.03 0.1 176 | 177 | 178 | 179 | 183 | 184 | 185 | 186 | 0.55 0 0.15 0 0 0 187 | 188 | 189 | 0.05 190 | 0.3 191 | 192 | 193 | 194 | 195 | 0.55 0 0.15 0 0 0 196 | 197 | 198 | 0.05 199 | 0.3 200 | 201 | 202 | 203 | 207 | 208 | 209 | 210 | 211 | 1.6 0 1.05 0 0 0 212 | 213 | 0 0 0 0 0 0 214 | 215 | 0.01 216 | 0 217 | 0 218 | 0.01 219 | 0 220 | 0.001 221 | 222 | 0.1 223 | 224 | 225 | 0 0 
0.5 0 0 0 226 | 227 | 228 | 0.03 229 | 1.0 230 | 231 | 232 | 233 | 234 | 0 0 0.5 0 0 0 235 | 236 | 237 | 0.03 238 | 1.0 239 | 240 | 241 | 242 | 246 | 247 | 248 | 249 | 250 | 1.6 0 1.0 0 0 0 251 | 252 | 0 0 0 0 0 0 253 | 254 | 0.01 255 | 0 256 | 0 257 | 0.01 258 | 0 259 | 0.001 260 | 261 | 0.1 262 | 263 | 264 | 0 0 0.025 0 0 0 265 | 266 | 267 | 0.05 268 | 0.05 269 | 270 | 271 | 272 | 273 | 0 0 0.025 0 0 0 274 | 275 | 276 | 0.05 277 | 0.05 278 | 279 | 280 | 281 | 285 | 286 | 287 | 288 | 289 | arm_base 290 | arm_shoulder_pan 291 | 292 | 293 | 1.000000 294 | 0.000000 295 | 296 | 0 0 1 297 | true 298 | 299 | 300 | 301 | arm_shoulder_pan 302 | arm_elbow_pan 303 | 304 | 305 | 1.000000 306 | 0.000000 307 | 308 | 0 0 1 309 | true 310 | 311 | 312 | 313 | arm_elbow_pan 314 | arm_wrist_lift 315 | 316 | 317 | 1.000000 318 | 0.000000 319 | 320 | 321 | -0.8 322 | 0.1 323 | 324 | 0 0 1 325 | true 326 | 327 | 328 | 329 | arm_wrist_lift 330 | arm_wrist_roll 331 | 332 | 333 | 1.000000 334 | 0.000000 335 | 336 | 337 | -2.999994 338 | 2.999994 339 | 340 | 0 0 1 341 | true 342 | 343 | 344 | 357 | 358 | 359 | -------------------------------------------------------------------------------- /examples/simple_arm_example.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import kinpy as kp 3 | 4 | chain = kp.build_chain_from_sdf(open("simple_arm/model.sdf")) 5 | print(chain) 6 | print(chain.get_joint_parameter_names()) 7 | ret = chain.forward_kinematics({"arm_elbow_pan_joint": np.pi / 2.0, "arm_wrist_lift_joint": -0.5}) 8 | print(ret) 9 | viz = kp.Visualizer() 10 | viz.add_robot(ret, chain.visuals_map()) 11 | viz.spin() 12 | -------------------------------------------------------------------------------- /examples/ur/ur.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 26 | 27 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 
| 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | transmission_interface/SimpleTransmission 235 | 236 | PositionJointInterface 237 | 238 | 239 | 1 240 | 241 | 242 | 243 | transmission_interface/SimpleTransmission 244 | 245 | PositionJointInterface 246 | 247 | 248 | 1 249 | 250 | 251 | 252 | transmission_interface/SimpleTransmission 253 | 254 | PositionJointInterface 255 | 256 | 257 | 1 258 | 259 | 260 | 261 | transmission_interface/SimpleTransmission 262 | 263 | PositionJointInterface 264 | 265 | 266 | 1 267 | 268 | 269 | 270 | transmission_interface/SimpleTransmission 271 | 272 | PositionJointInterface 273 | 274 | 275 | 1 276 | 277 | 278 | 279 | transmission_interface/SimpleTransmission 280 | 281 | PositionJointInterface 282 | 283 | 284 | 1 285 | 286 | 287 | 288 | 289 | 290 | 291 | 295 | 296 | 297 | 298 | 299 | 300 | 301 | 302 | 303 | 304 | 305 | 306 | 307 | 308 | 309 | 310 | 311 | 312 | 313 | 314 | 
-------------------------------------------------------------------------------- /examples/ur_example.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import kinpy as kp 3 | 4 | 5 | arm = kp.build_serial_chain_from_urdf( 6 | open("ur/ur.urdf"), 7 | root_link_name="base_link", 8 | end_link_name="ee_link", 9 | ) 10 | fk_solution = arm.forward_kinematics(np.zeros(len(arm.get_joint_parameter_names()))) 11 | print(fk_solution) 12 | -------------------------------------------------------------------------------- /kinpy/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .chain import Chain 4 | from .mjcf import * 5 | from .sdf import * 6 | from .transform import * 7 | from .urdf import * 8 | from .visualizer import * 9 | 10 | 11 | def build_chain_from_file(filename: str) -> Chain: 12 | ext = os.path.splitext(filename)[-1] 13 | if ext == ".urdf": 14 | return build_chain_from_urdf(open(filename).read()) 15 | elif ext == ".sdf": 16 | return build_chain_from_sdf(open(filename).read()) 17 | elif ext == ".mjcf": 18 | return build_chain_from_mjcf(open(filename).read()) 19 | else: 20 | raise ValueError(f"Invalid file type: '{ext}' file.") 21 | -------------------------------------------------------------------------------- /kinpy/chain.py: -------------------------------------------------------------------------------- 1 | from functools import cached_property 2 | from typing import Dict, Iterator, List, Optional, Union 3 | 4 | import numpy as np 5 | 6 | from . 
import frame, ik, jacobian, transform 7 | 8 | 9 | class Chain: 10 | """Chain is a class that represents a kinematic chain.""" 11 | 12 | def __init__(self, root_frame: frame.Frame) -> None: 13 | self._root: Optional[frame.Frame] = root_frame 14 | 15 | def __str__(self) -> str: 16 | return str(self._root) 17 | 18 | def __iter__(self) -> Iterator[frame.Frame]: 19 | assert self._root is not None, "Root frame is None" 20 | yield from self._root.walk() 21 | 22 | @cached_property 23 | def dof(self): 24 | return len(self.get_joint_parameter_names()) 25 | 26 | @staticmethod 27 | def _find_frame_recursive(name: str, frame: frame.Frame) -> Optional[frame.Frame]: 28 | for child in frame.children: 29 | if child.name == name: 30 | return child 31 | ret = Chain._find_frame_recursive(name, child) 32 | if ret is not None: 33 | return ret 34 | return None 35 | 36 | def find_frame(self, name: str) -> Optional[frame.Frame]: 37 | """Find a frame by name. 38 | 39 | Parameters 40 | ---------- 41 | name : str 42 | Frame name. 43 | 44 | Returns 45 | ------- 46 | Optional[frame.Frame] 47 | Frame if found, None otherwise. 48 | """ 49 | assert self._root is not None, "Root frame is None" 50 | if self._root.name == name: 51 | return self._root 52 | return self._find_frame_recursive(name, self._root) 53 | 54 | @staticmethod 55 | def _find_link_recursive(name: str, frame: frame.Frame) -> Optional[frame.Link]: 56 | for child in frame.children: 57 | if child.link.name == name: 58 | return child.link 59 | ret = Chain._find_link_recursive(name, child) 60 | if ret is not None: 61 | return ret 62 | return None 63 | 64 | def find_link(self, name: str) -> Optional[frame.Link]: 65 | """Find a link by name. 66 | 67 | Parameters 68 | ---------- 69 | name : str 70 | Link name. 71 | 72 | Returns 73 | ------- 74 | Optional[frame.Link] 75 | Link if found, None otherwise. 
76 | """ 77 | assert self._root is not None, "Root frame is None" 78 | if self._root.link.name == name: 79 | return self._root.link 80 | return self._find_link_recursive(name, self._root) 81 | 82 | @staticmethod 83 | def _get_joint_parameter_names(frame: frame.Frame, exclude_fixed: bool = True) -> List[str]: 84 | joint_names = [] 85 | if not (exclude_fixed and frame.joint.joint_type == "fixed"): 86 | joint_names.append(frame.joint.name) 87 | for child in frame.children: 88 | joint_names.extend(Chain._get_joint_parameter_names(child, exclude_fixed)) 89 | return joint_names 90 | 91 | def get_joint_parameter_names(self, exclude_fixed: bool = True) -> List[str]: 92 | """Get joint parameter names. 93 | 94 | Parameters 95 | ---------- 96 | exclude_fixed : bool, optional 97 | Exclude fixed joints, by default True 98 | 99 | Returns 100 | ------- 101 | List[str] 102 | Joint parameter names. 103 | """ 104 | assert self._root is not None, "Root frame is None" 105 | names = self._get_joint_parameter_names(self._root, exclude_fixed) 106 | return list(sorted(set(names), key=names.index)) 107 | 108 | def add_frame(self, frame: frame.Frame, parent_name: str) -> None: 109 | parent_frame = self.find_frame(parent_name) 110 | if parent_frame is not None: 111 | parent_frame.add_child(frame) 112 | 113 | @staticmethod 114 | def _forward_kinematics( 115 | root: frame.Frame, th_dict: Dict[str, float], world: Optional[transform.Transform] = None 116 | ) -> Dict[str, transform.Transform]: 117 | world = world or transform.Transform() 118 | link_transforms = {} 119 | trans = world * root.get_transform(th_dict.get(root.joint.name, 0.0)) 120 | link_transforms[root.link.name] = trans * root.link.offset 121 | for child in root.children: 122 | link_transforms.update(Chain._forward_kinematics(child, th_dict, trans)) 123 | return link_transforms 124 | 125 | def forward_kinematics( 126 | self, th: Union[Dict[str, float], List[float]], world: Optional[transform.Transform] = None, **kwargs: Dict 127 | ) 
-> Dict[str, transform.Transform]: 128 | """Forward kinematics. 129 | 130 | Parameters 131 | ---------- 132 | th : Union[Dict[str, float], List[float]] 133 | Joint parameters. 134 | world : Optional[transform.Transform], optional 135 | World transform, by default None 136 | 137 | Returns 138 | ------- 139 | Dict[str, transform.Transform] 140 | Link transforms. 141 | """ 142 | assert self._root is not None, "Root frame is None" 143 | world = world or transform.Transform() 144 | if not isinstance(th, dict): 145 | jn = self.get_joint_parameter_names() 146 | assert len(jn) == len(th) 147 | th_dict = dict((j, th[i]) for i, j in enumerate(jn)) 148 | else: 149 | th_dict = th 150 | return self._forward_kinematics(self._root, th_dict, world) 151 | 152 | @staticmethod 153 | def _visuals_map(root: frame.Frame) -> Dict[str, List[frame.Visual]]: 154 | vmap = {root.link.name: root.link.visuals} 155 | for child in root.children: 156 | vmap.update(Chain._visuals_map(child)) 157 | return vmap 158 | 159 | def visuals_map(self): 160 | return self._visuals_map(self._root) 161 | 162 | 163 | class SerialChain(Chain): 164 | """SerialChain is a class that represents a serial kinematic chain.""" 165 | 166 | def __init__(self, chain: Chain, end_frame_name: str, root_frame_name: str = "") -> None: 167 | assert chain._root is not None, "Chain root frame is None" 168 | if root_frame_name == "": 169 | self._root = chain._root 170 | else: 171 | self._root = chain.find_frame(root_frame_name) 172 | if self._root is None: 173 | raise ValueError("Invalid root frame name %s." % root_frame_name) 174 | frames = self._generate_serial_chain_recurse(self._root, end_frame_name) 175 | if frames is None: 176 | raise ValueError("Invalid end frame name %s." 
% end_frame_name) 177 | self._serial_frames = [self._root] + frames 178 | 179 | @staticmethod 180 | def _generate_serial_chain_recurse(root_frame: frame.Frame, end_frame_name: str) -> Optional[List[frame.Frame]]: 181 | for child in root_frame.children: 182 | if child.name == end_frame_name: 183 | return [child] 184 | else: 185 | frames = SerialChain._generate_serial_chain_recurse(child, end_frame_name) 186 | if frames is not None: 187 | return [child] + frames 188 | return None 189 | 190 | def get_joint_parameter_names(self, exclude_fixed: bool = True) -> List[str]: 191 | assert self._serial_frames is not None, "Serial chain not initialized." 192 | names = [] 193 | for f in self._serial_frames: 194 | if exclude_fixed and f.joint.joint_type == "fixed": 195 | continue 196 | names.append(f.joint.name) 197 | return names 198 | 199 | def forward_kinematics( # type: ignore[override] 200 | self, 201 | th: Union[Dict[str, float], List[float]], 202 | world: Optional[transform.Transform] = None, 203 | end_only: bool = True, 204 | ) -> Union[transform.Transform, Dict[str, transform.Transform]]: 205 | assert self._serial_frames is not None, "Serial chain not initialized." 
206 | if isinstance(th, dict): 207 | link_transforms = super().forward_kinematics(th, world) 208 | if end_only: 209 | return link_transforms[self._serial_frames[-1].link.name] 210 | else: 211 | return link_transforms 212 | world = world or transform.Transform() 213 | cnt = 0 214 | link_transforms = {} 215 | trans = world 216 | for f in self._serial_frames: 217 | if f.joint.joint_type != "fixed": 218 | trans = trans * f.get_transform(th[cnt]) 219 | else: 220 | trans = trans * f.get_transform() 221 | link_transforms[f.link.name] = trans * f.link.offset 222 | if f.joint.joint_type != "fixed": 223 | cnt += 1 224 | return link_transforms[self._serial_frames[-1].link.name] if end_only else link_transforms 225 | 226 | def jacobian(self, th: List[float], end_only: bool = True) -> Union[np.ndarray, Dict[str, np.ndarray]]: 227 | assert self._serial_frames is not None, "Serial chain not initialized." 228 | if end_only: 229 | return jacobian.calc_jacobian(self, th) 230 | else: 231 | jacobians = {} 232 | for serial_frame in self._serial_frames: 233 | jac = jacobian.calc_jacobian_frames(self, th, link_name=serial_frame.link.name) 234 | jacobians[serial_frame.link.name] = jac 235 | return jacobians 236 | 237 | def inverse_kinematics(self, pose: transform.Transform, initial_state: Optional[np.ndarray] = None) -> np.ndarray: 238 | return ik.inverse_kinematics(self, pose, initial_state) 239 | -------------------------------------------------------------------------------- /kinpy/frame.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Iterator, List, Optional 2 | 3 | import numpy as np 4 | import transformations as tf 5 | 6 | from . 
import transform 7 | 8 | 9 | class Visual: 10 | TYPES = ["box", "cylinder", "sphere", "capsule", "mesh"] 11 | 12 | def __init__( 13 | self, 14 | offset: Optional[transform.Transform] = None, 15 | geom_type: Optional[str] = None, 16 | geom_param: Any = None, 17 | ) -> None: 18 | self.offset = offset or transform.Transform() 19 | self.geom_type = geom_type 20 | self.geom_param = geom_param 21 | 22 | def __repr__(self) -> str: 23 | return "Visual(offset={0}, geom_type='{1}', geom_param={2})".format( 24 | self.offset, self.geom_type, self.geom_param 25 | ) 26 | 27 | 28 | class Link: 29 | def __init__( 30 | self, name: Optional[str] = None, offset: Optional[transform.Transform] = None, visuals: Optional[List] = None 31 | ) -> None: 32 | self.name = name if name is not None else "none" 33 | self.offset = offset or transform.Transform() 34 | self.visuals = visuals or [] 35 | 36 | def __repr__(self) -> str: 37 | return "Link(name='{0}', offset={1}, visuals={2})".format(self.name, self.offset, self.visuals) 38 | 39 | 40 | class Joint: 41 | TYPES = ["fixed", "revolute", "prismatic"] 42 | 43 | def __init__( 44 | self, 45 | name: Optional[str] = None, 46 | offset: Optional[transform.Transform] = None, 47 | joint_type: str = "fixed", 48 | axis: Optional[List[float]] = None, 49 | ) -> None: 50 | self.name = name if name is not None else "none" 51 | self.offset = offset or transform.Transform() 52 | self.joint_type = joint_type 53 | if self.joint_type != "fixed" and axis is None: 54 | self.axis = np.array([0.0, 0.0, 1.0]) 55 | else: 56 | self.axis = np.array(axis) if axis is not None else np.array([0.0, 0.0, 1.0]) 57 | 58 | def __repr__(self) -> str: 59 | return "Joint(name='{0}', offset={1}, joint_type='{2}', axis={3})".format( 60 | self.name, self.offset, self.joint_type, self.axis 61 | ) 62 | 63 | 64 | class Frame: 65 | def __init__( 66 | self, 67 | name: Optional[str] = None, 68 | link: Optional[Link] = None, 69 | joint: Optional[Joint] = None, 70 | children: 
Optional[List["Frame"]] = None, 71 | ) -> None: 72 | self.name = "None" if name is None else name 73 | self.link = link or Link() 74 | self.joint = joint or Joint() 75 | self.children = children or [] 76 | 77 | def _ptree(self, indent_width: int = 4) -> str: 78 | def _inner_ptree(root: Frame, parent: Frame, grandpa: Optional[Frame] = None, indent: str = ""): 79 | show_str = "" 80 | if parent.name != root.name: 81 | show_str += " " + parent.name + ("" if grandpa is None else "\n") 82 | if not parent.children: 83 | return show_str 84 | for child in parent.children[:-1]: 85 | show_str += indent + "├" + "─" * indent_width 86 | show_str += _inner_ptree(root, child, parent, indent + "│" + " " * (indent_width + 1)) 87 | if parent.children: 88 | child = parent.children[-1] 89 | show_str += indent + "└" + "─" * indent_width 90 | show_str += _inner_ptree(root, child, parent, indent + " " * (indent_width + 2)) 91 | return show_str 92 | 93 | show_str = self.name + "\n" 94 | show_str += _inner_ptree(self, self) 95 | return show_str 96 | 97 | def __str__(self) -> str: 98 | return self._ptree() 99 | 100 | def add_child(self, child: "Frame") -> None: 101 | self.children.append(child) 102 | 103 | def is_end(self) -> bool: 104 | return len(self.children) == 0 105 | 106 | def get_transform(self, theta: float = 0.0) -> transform.Transform: 107 | if self.joint.joint_type == "revolute": 108 | t = transform.Transform(tf.quaternion_about_axis(theta, self.joint.axis)) 109 | elif self.joint.joint_type == "prismatic": 110 | t = transform.Transform(pos=theta * self.joint.axis) 111 | elif self.joint.joint_type == "fixed": 112 | t = transform.Transform() 113 | else: 114 | raise ValueError("Unsupported joint type %s." 
% self.joint.joint_type) 115 | return self.joint.offset * t 116 | 117 | def walk(self) -> Iterator["Frame"]: 118 | yield self 119 | for child in self.children: 120 | yield from child.walk() 121 | -------------------------------------------------------------------------------- /kinpy/ik.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | 3 | import numpy as np 4 | import scipy.optimize as sco 5 | 6 | from . import transform 7 | 8 | 9 | def inverse_kinematics( 10 | serial_chain: Any, pose: transform.Transform, initial_state: Optional[np.ndarray] = None 11 | ) -> np.ndarray: 12 | ndim = len(serial_chain.get_joint_parameter_names()) 13 | if initial_state is None: 14 | x0 = np.zeros(ndim) 15 | else: 16 | x0 = initial_state 17 | 18 | def object_fn(x): 19 | tf = serial_chain.forward_kinematics(x) 20 | obj = np.square(np.linalg.lstsq(pose.matrix(), tf.matrix(), rcond=-1)[0] - np.identity(4)).sum() 21 | return obj 22 | 23 | ret = sco.minimize(object_fn, x0, method="BFGS") 24 | return ret.x 25 | -------------------------------------------------------------------------------- /kinpy/jacobian.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Optional 2 | 3 | import numpy as np 4 | 5 | from . 
import transform 6 | 7 | 8 | def calc_jacobian(serial_chain: Any, th: List[float], tool: Optional[transform.Transform] = None) -> np.ndarray: 9 | tool = tool or transform.Transform() 10 | ndof = len(th) 11 | j_fl = np.zeros((6, ndof)) 12 | cur_transform = tool.matrix() 13 | 14 | cnt = 0 15 | for f in reversed(serial_chain._serial_frames): 16 | if f.joint.joint_type == "revolute": 17 | cnt += 1 18 | delta = np.dot(f.joint.axis, cur_transform[:3, :3]) 19 | d = np.dot(np.cross(f.joint.axis, cur_transform[:3, 3]), cur_transform[:3, :3]) 20 | j_fl[:, -cnt] = np.hstack((d, delta)) 21 | elif f.joint.joint_type == "prismatic": 22 | cnt += 1 23 | j_fl[:3, -cnt] = np.dot(f.joint.axis, cur_transform[:3, :3]) 24 | cur_frame_transform = f.get_transform(th[-cnt]).matrix() 25 | cur_transform = np.dot(cur_frame_transform, cur_transform) 26 | 27 | pose = serial_chain.forward_kinematics(th).matrix() 28 | rotation = pose[:3, :3] 29 | j_tr = np.zeros((6, 6)) 30 | j_tr[:3, :3] = rotation 31 | j_tr[3:, 3:] = rotation 32 | j_w = np.dot(j_tr, j_fl) 33 | return j_w 34 | 35 | 36 | def calc_jacobian_frames( 37 | serial_chain: Any, th: List[float], link_name: str, tool: Optional[transform.Transform] = None 38 | ) -> np.ndarray: 39 | tool = tool or transform.Transform() 40 | ndof = len(th) 41 | j_fl = np.zeros((6, ndof)) 42 | cur_transform = tool.matrix() 43 | 44 | # select first num_th movable joints 45 | serial_frames = [] 46 | num_movable_joints = 0 47 | for serial_frame in serial_chain._serial_frames: 48 | serial_frames.append(serial_frame) 49 | if serial_frame.joint.joint_type != "fixed": 50 | num_movable_joints += 1 51 | 52 | if serial_frame.link.name == link_name: 53 | break # found first n joints 54 | 55 | cnt = len(th) - num_movable_joints # only first num_th joints 56 | for f in reversed(serial_frames): 57 | if f.joint.joint_type == "revolute": 58 | cnt += 1 59 | delta = np.dot(f.joint.axis, cur_transform[:3, :3]) 60 | d = np.dot(np.cross(f.joint.axis, cur_transform[:3, 3]), 
cur_transform[:3, :3]) 61 | j_fl[:, -cnt] = np.hstack((d, delta)) 62 | elif f.joint.joint_type == "prismatic": 63 | cnt += 1 64 | j_fl[:3, -cnt] = np.dot(f.joint.axis, cur_transform[:3, :3]) 65 | cur_frame_transform = f.get_transform(th[-cnt]).matrix() 66 | cur_transform = np.dot(cur_frame_transform, cur_transform) 67 | 68 | poses = serial_chain.forward_kinematics(th, end_only=False) 69 | pose = poses[link_name].matrix() 70 | 71 | rotation = pose[:3, :3] 72 | j_tr = np.zeros((6, 6)) 73 | j_tr[:3, :3] = rotation 74 | j_tr[3:, 3:] = rotation 75 | j_w = np.dot(j_tr, j_fl) 76 | 77 | return j_w 78 | -------------------------------------------------------------------------------- /kinpy/mjcf.py: -------------------------------------------------------------------------------- 1 | import io 2 | from typing import Dict, Optional, TextIO, Union 3 | 4 | from . import chain, frame, mjcf_parser, transform 5 | 6 | JOINT_TYPE_MAP: Dict[str, str] = {"hinge": "revolute", "slide": "prismatic"} 7 | 8 | 9 | def geoms_to_visuals(geom, base: Optional[transform.Transform] = None): 10 | base = base or transform.Transform() 11 | visuals = [] 12 | for g in geom: 13 | if g.type == "capsule": 14 | param = (g.size[0], g.fromto) 15 | elif g.type == "sphere": 16 | param = g.size[0] 17 | else: 18 | raise ValueError("Invalid geometry type %s." 
% g.type) 19 | visuals.append( 20 | frame.Visual(offset=base * transform.Transform(g.quat, g.pos), geom_type=g.type, geom_param=param) 21 | ) 22 | return visuals 23 | 24 | 25 | def body_to_link(body, base: Optional[transform.Transform] = None): 26 | base = base or transform.Transform() 27 | return frame.Link(body.name, offset=base * transform.Transform(body.quat, body.pos)) 28 | 29 | 30 | def joint_to_joint(joint, base: Optional[transform.Transform] = None): 31 | base = base or transform.Transform() 32 | return frame.Joint( 33 | joint.name, 34 | offset=base * transform.Transform(pos=joint.pos), 35 | joint_type=JOINT_TYPE_MAP[joint.type], 36 | axis=joint.axis, 37 | ) 38 | 39 | 40 | def add_composite_joint(root_frame, joints, base: Optional[transform.Transform] = None): 41 | base = base or transform.Transform() 42 | if len(joints) > 0: 43 | root_frame.children = root_frame.children + [ 44 | frame.Frame(link=frame.Link(name=root_frame.link.name + "_child"), joint=joint_to_joint(joints[0], base)) 45 | ] 46 | ret, offset = add_composite_joint(root_frame.children[-1], joints[1:]) 47 | return ret, root_frame.joint.offset * offset 48 | else: 49 | return root_frame, root_frame.joint.offset 50 | 51 | 52 | def _build_chain_recurse(root_frame, root_body): 53 | base = root_frame.link.offset 54 | cur_frame, cur_base = add_composite_joint(root_frame, root_body.joint, base) 55 | jbase = cur_base.inverse() * base 56 | if len(root_body.joint) > 0: 57 | cur_frame.link.visuals = geoms_to_visuals(root_body.geom, jbase) 58 | else: 59 | cur_frame.link.visuals = geoms_to_visuals(root_body.geom) 60 | for b in root_body.body: 61 | cur_frame.children = cur_frame.children + [frame.Frame()] 62 | next_frame = cur_frame.children[-1] 63 | next_frame.name = b.name + "_frame" 64 | next_frame.link = body_to_link(b, jbase) 65 | _build_chain_recurse(next_frame, b) 66 | 67 | 68 | def build_chain_from_mjcf(data: Union[str, TextIO]) -> chain.Chain: 69 | """ 70 | Build a Chain object from MJCF data. 
71 | 72 | Parameters 73 | ---------- 74 | data : str or TextIO 75 | MJCF string data or file object. 76 | 77 | Returns 78 | ------- 79 | chain.Chain 80 | Chain object created from MJCF. 81 | """ 82 | if isinstance(data, io.TextIOBase): 83 | data = data.read() 84 | model = mjcf_parser.from_xml_string(data) 85 | root_body = model.worldbody.body[0] 86 | root_frame = frame.Frame(root_body.name + "_frame", link=body_to_link(root_body), joint=frame.Joint()) 87 | _build_chain_recurse(root_frame, root_body) 88 | return chain.Chain(root_frame) 89 | 90 | 91 | def build_serial_chain_from_mjcf( 92 | data: Union[str, TextIO], end_link_name: str, root_link_name: str = "" 93 | ) -> chain.SerialChain: 94 | """ 95 | Build a SerialChain object from MJCF data. 96 | 97 | Parameters 98 | ---------- 99 | data : str 100 | MJCF string data. 101 | end_link_name : str 102 | The name of the link that is the end effector. 103 | root_link_name : str, optional 104 | The name of the root link. 105 | 106 | Returns 107 | ------- 108 | chain.SerialChain 109 | SerialChain object created from MJCF. 110 | """ 111 | if isinstance(data, io.TextIOBase): 112 | data = data.read() 113 | mjcf_chain = build_chain_from_mjcf(data) 114 | return chain.SerialChain( 115 | mjcf_chain, end_link_name + "_frame", "" if root_link_name == "" else root_link_name + "_frame" 116 | ) 117 | -------------------------------------------------------------------------------- /kinpy/mjcf_parser/__init__.py: -------------------------------------------------------------------------------- 1 | from .parser import * 2 | -------------------------------------------------------------------------------- /kinpy/mjcf_parser/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Base class for all MJCF elements in the object model.""" 17 | 18 | from __future__ import absolute_import, division, print_function 19 | 20 | import abc 21 | 22 | import six 23 | 24 | 25 | @six.add_metaclass(abc.ABCMeta) 26 | class Element(object): 27 | """Abstract base class for an MJCF element. 28 | 29 | This class is provided so that `isinstance(foo, Element)` is `True` for all 30 | Element-like objects. We do not implement the actual element here because 31 | the actual object returned from traversing the object hierarchy is a 32 | weakproxy-like proxy to an actual element. This is because we do not allow 33 | orphaned non-root elements, so when a particular element is removed from the 34 | tree, all references held automatically become invalid. 35 | """ 36 | 37 | __slots__ = [] 38 | -------------------------------------------------------------------------------- /kinpy/mjcf_parser/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Magic constants used within `dm_control.mjcf`.""" 17 | 18 | from __future__ import absolute_import, division, print_function 19 | 20 | PREFIX_SEPARATOR = "/" 21 | PREFIX_SEPARATOR_ESCAPE = "\\" 22 | 23 | # Used to disambiguate namespaces between attachment frames. 24 | NAMESPACE_SEPARATOR = "@" 25 | 26 | # Magic attribute names 27 | BASEPATH = "basepath" 28 | CHILDCLASS = "childclass" 29 | CLASS = "class" 30 | DEFAULT = "default" 31 | DCLASS = "dclass" 32 | 33 | # Magic tags 34 | ACTUATOR = "actuator" 35 | BODY = "body" 36 | DEFAULT = "default" 37 | MESH = "mesh" 38 | SITE = "site" 39 | TENDON = "tendon" 40 | WORLDBODY = "worldbody" 41 | 42 | MJDATA_TRIGGERS_DIRTY = ["qpos", "qvel", "act", "ctrl", "qfrc_applied", "xfrc_applied"] 43 | MJMODEL_DOESNT_TRIGGER_DIRTY = ["rgba", "matid", "emission", "specular", "shininess", "reflectance"] 44 | 45 | # When writing into `model.{body,geom,site}_{pos,quat}` we must ensure that the 46 | # corresponding rows in `model.{body,geom,site}_sameframe` are set to zero, 47 | # otherwise MuJoCo will use the body or inertial frame instead of our modified 48 | # pos/quat values. We must do the same for `body_{ipos,iquat}` and 49 | # `body_simple`. 
50 | MJMODEL_DISABLE_ON_WRITE = { 51 | # Field name in MjModel: (attribute names of Binding instance to be zeroed) 52 | "body_pos": ("sameframe",), 53 | "body_quat": ("sameframe",), 54 | "geom_pos": ("sameframe",), 55 | "geom_quat": ("sameframe",), 56 | "site_pos": ("sameframe",), 57 | "site_quat": ("sameframe",), 58 | "body_ipos": ("simple", "sameframe"), 59 | "body_iquat": ("simple", "sameframe"), 60 | } 61 | 62 | # This is the actual upper limit on VFS filename length, despite what it says 63 | # in the header file (100) or the error message (99). 64 | MAX_VFS_FILENAME_LENGTH = 98 65 | 66 | # The prefix used in the schema to denote reference_namespace that are defined 67 | # via another attribute. 68 | INDIRECT_REFERENCE_NAMESPACE_PREFIX = "attrib:" 69 | -------------------------------------------------------------------------------- /kinpy/mjcf_parser/copier.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Helper object for keeping track of new elements created when copying MJCF.""" 17 | 18 | from __future__ import absolute_import, division, print_function 19 | 20 | from . 
import constants 21 | 22 | 23 | class Copier(object): 24 | """Helper for keeping track of new elements created when copying MJCF.""" 25 | 26 | def __init__(self, source): 27 | if source._attachments: # pylint: disable=protected-access 28 | raise NotImplementedError("Cannot copy from elements with attachments") 29 | self._source = source 30 | 31 | def copy_into(self, destination, override_attributes=False): 32 | """Copies this copier's element into a destination MJCF element.""" 33 | newly_created_elements = {} 34 | destination._check_valid_attachment(self._source) # pylint: disable=protected-access 35 | if override_attributes: 36 | destination.set_attributes(**self._source.get_attributes()) 37 | else: 38 | destination._sync_attributes(self._source, copying=True) # pylint: disable=protected-access 39 | for source_child in self._source.all_children(): 40 | dest_child = None 41 | # First, if source_child has an identifier, we look for an existing child 42 | # element of self with the same identifier to override. 43 | if source_child.spec.identifier and override_attributes: 44 | identifier_attr = source_child.spec.identifier 45 | if identifier_attr == constants.CLASS: 46 | identifier_attr = constants.DCLASS 47 | identifier = getattr(source_child, identifier_attr) 48 | if identifier: 49 | dest_child = destination.find(source_child.spec.namespace, identifier) 50 | if dest_child is not None and dest_child.parent is not destination: 51 | raise ValueError( 52 | "<{}> with identifier {!r} is already a child of another element".format( 53 | source_child.spec.namespace, identifier 54 | ) 55 | ) 56 | # Next, we cover the case where either the child is not a repeated element 57 | # or if source_child has an identifier attribute but it isn't set. 
58 | if not source_child.spec.repeated and dest_child is None: 59 | dest_child = destination.get_children(source_child.tag) 60 | 61 | # Add a new element if dest_child doesn't exist, either because it is 62 | # supposed to be a repeated child, or because it's an uncreated on-demand. 63 | if dest_child is None: 64 | dest_child = destination.add(source_child.tag, **source_child.get_attributes()) 65 | newly_created_elements[source_child] = dest_child 66 | override_child_attributes = True 67 | else: 68 | override_child_attributes = override_attributes 69 | 70 | # Finally, copy attributes into dest_child. 71 | child_copier = Copier(source_child) 72 | newly_created_elements.update(child_copier.copy_into(dest_child, override_child_attributes)) 73 | return newly_created_elements 74 | -------------------------------------------------------------------------------- /kinpy/mjcf_parser/debugging.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Implements PyMJCF debug mode. 17 | 18 | PyMJCF debug mode stores a stack trace each time the MJCF object is modified. 19 | If Mujoco raises a compile error on the generated XML model, we would then be 20 | able to find the original source line that created the offending element. 
21 | """ 22 | 23 | from __future__ import absolute_import, division, print_function 24 | 25 | import collections 26 | import contextlib 27 | import copy 28 | import os 29 | import re 30 | import sys 31 | import traceback 32 | 33 | import six 34 | from absl import flags 35 | from lxml import etree 36 | 37 | FLAGS = flags.FLAGS 38 | flags.DEFINE_boolean( 39 | "pymjcf_debug", 40 | False, 41 | "Enables PyMJCF debug mode (SLOW!). In this mode, a stack trace is logged " 42 | "each the MJCF object is modified. This may be helpful in locating the " 43 | "Python source line corresponding to a problematic element in the " 44 | "generated XML.", 45 | ) 46 | flags.DEFINE_string("pymjcf_debug_full_dump_dir", "", "Path to dump full debug info when Mujoco error is encountered.") 47 | 48 | StackTraceEntry = collections.namedtuple("StackTraceEntry", ("filename", "line_number", "function_name", "text")) 49 | 50 | ElementDebugInfo = collections.namedtuple("ElementDebugInfo", ("element", "init_stack", "attribute_stacks")) 51 | 52 | MODULE_PATH = os.path.dirname(sys.modules[__name__].__file__) 53 | DEBUG_METADATA_PREFIX = "pymjcfdebug" 54 | 55 | _DEBUG_METADATA_TAG_PREFIX = "".format(DEBUG_METADATA_PREFIX)) 57 | 58 | # Modified by `freeze_current_stack_trace`. 59 | _CURRENT_FROZEN_STACK = None 60 | 61 | # These globals will take their default values from the `--pymjcf_debug` and 62 | # `--pymjcf_debug_full_dump_dir` flags respectively. We cannot use `FLAGS` as 63 | # global variables because flag parsing might not have taken place (e.g. when 64 | # running `nosetests`). 
65 | _DEBUG_MODE_ENABLED = None 66 | _DEBUG_FULL_DUMP_DIR = None 67 | 68 | 69 | def debug_mode(): 70 | """Returns a boolean that indicates whether PyMJCF debug mode is enabled.""" 71 | global _DEBUG_MODE_ENABLED 72 | if _DEBUG_MODE_ENABLED is None: 73 | if FLAGS.is_parsed(): 74 | _DEBUG_MODE_ENABLED = FLAGS.pymjcf_debug 75 | else: 76 | _DEBUG_MODE_ENABLED = FLAGS["pymjcf_debug"].default 77 | return _DEBUG_MODE_ENABLED 78 | 79 | 80 | def enable_debug_mode(): 81 | """Enables PyMJCF debug mode.""" 82 | global _DEBUG_MODE_ENABLED 83 | _DEBUG_MODE_ENABLED = True 84 | 85 | 86 | def disable_debug_mode(): 87 | """Disables PyMJCF debug mode.""" 88 | global _DEBUG_MODE_ENABLED 89 | _DEBUG_MODE_ENABLED = False 90 | 91 | 92 | def get_full_dump_dir(): 93 | """Gets the directory to dump full debug info files.""" 94 | global _DEBUG_FULL_DUMP_DIR 95 | if _DEBUG_FULL_DUMP_DIR is None: 96 | if FLAGS.is_parsed(): 97 | _DEBUG_FULL_DUMP_DIR = FLAGS.pymjcf_debug_full_dump_dir 98 | else: 99 | _DEBUG_FULL_DUMP_DIR = FLAGS["pymjcf_debug_full_dump_dir"].default 100 | return _DEBUG_FULL_DUMP_DIR 101 | 102 | 103 | def set_full_dump_dir(dump_path): 104 | """Sets the directory to dump full debug info files.""" 105 | global _DEBUG_FULL_DUMP_DIR 106 | _DEBUG_FULL_DUMP_DIR = dump_path 107 | 108 | 109 | def get_current_stack_trace(): 110 | """Returns the stack trace of the current execution frame. 111 | 112 | Returns: 113 | A list of `StackTraceEntry` named tuples corresponding to the current stack 114 | trace of the process, truncated to immediately before entry into 115 | PyMJCF internal code. 116 | """ 117 | if _CURRENT_FROZEN_STACK: 118 | return copy.deepcopy(_CURRENT_FROZEN_STACK) 119 | else: 120 | return _get_actual_current_stack_trace() 121 | 122 | 123 | def _get_actual_current_stack_trace(): 124 | """Returns the stack trace of the current execution frame. 
125 | 126 | Returns: 127 | A list of `StackTraceEntry` named tuples corresponding to the current stack 128 | trace of the process, truncated to immediately before entry into 129 | PyMJCF internal code. 130 | """ 131 | raw_stack = traceback.extract_stack() 132 | processed_stack = [] 133 | for raw_stack_item in raw_stack: 134 | stack_item = StackTraceEntry(*raw_stack_item) 135 | if stack_item.filename.startswith(MODULE_PATH) and not stack_item.filename.endswith("_test.py"): 136 | break 137 | else: 138 | processed_stack.append(stack_item) 139 | return processed_stack 140 | 141 | 142 | @contextlib.contextmanager 143 | def freeze_current_stack_trace(): 144 | """A context manager that freezes the stack trace. 145 | 146 | AVOID USING THIS CONTEXT MANAGER OUTSIDE OF INTERNAL PYMJCF IMPLEMENTATION, 147 | AS IT REDUCES THE USEFULNESS OF DEBUG MODE. 148 | 149 | If PyMJCF debug mode is enabled, calls to `debugging.get_current_stack_trace` 150 | within this context will always return the stack trace from when this context 151 | was entered. 152 | 153 | The frozen stack is global to this debugging module. That is, if the context 154 | is entered while another one is still active, then the stack trace of the 155 | outermost one is returned. 156 | 157 | This context significantly speeds up bulk operations in debug mode, e.g. 158 | parsing an existing XML string or creating a deeply-nested element, as it 159 | prevents the same stack trace from being repeatedly constructed. 160 | 161 | Yields: 162 | `None` 163 | """ 164 | global _CURRENT_FROZEN_STACK 165 | if debug_mode() and _CURRENT_FROZEN_STACK is None: 166 | _CURRENT_FROZEN_STACK = _get_actual_current_stack_trace() 167 | yield 168 | _CURRENT_FROZEN_STACK = None 169 | else: 170 | yield 171 | 172 | 173 | class DebugContext(object): 174 | """A helper object to store debug information for a generated XML string. 175 | 176 | This class is intended for internal use within the PyMJCF implementation. 
177 | """ 178 | 179 | def __init__(self): 180 | self._xml_string = None 181 | self._debug_info_for_element_ids = {} 182 | 183 | def register_element_for_debugging(self, elem): 184 | """Registers an `Element` and returns debugging metadata for the XML. 185 | 186 | Args: 187 | elem: An `mjcf.Element`. 188 | 189 | Returns: 190 | An `lxml.etree.Comment` that represents debugging metadata in the 191 | generated XML. 192 | """ 193 | if not debug_mode(): 194 | return None 195 | else: 196 | self._debug_info_for_element_ids[id(elem)] = ElementDebugInfo( 197 | elem, 198 | copy.deepcopy(elem.get_init_stack()), 199 | copy.deepcopy(elem.get_last_modified_stacks_for_all_attributes()), 200 | ) 201 | return etree.Comment("{}:{}".format(DEBUG_METADATA_PREFIX, id(elem))) 202 | 203 | def commit_xml_string(self, xml_string): 204 | """Commits the XML string associated with this debug context. 205 | 206 | This function also formats the XML string to make sure that the debugging 207 | metadata appears on the same line as the corresponding XML element. 208 | 209 | Args: 210 | xml_string: A pretty-printed XML string. 211 | 212 | Returns: 213 | A reformatted XML string where all debugging metadata appears on the same 214 | line as the corresponding XML element. 215 | """ 216 | formatted = re.sub(r"\n\s*" + _DEBUG_METADATA_TAG_PREFIX, _DEBUG_METADATA_TAG_PREFIX, xml_string) 217 | self._xml_string = formatted 218 | return formatted 219 | 220 | def process_and_raise_last_exception(self): 221 | """Processes and re-raises the last mujoco.wrapper.Error caught. 222 | 223 | This function will insert the relevant line from the source XML to the error 224 | message. If debug mode is enabled, additional debugging information is 225 | appended to the error message. If debug mode is not enabled, the error 226 | message instructs the user to enable it by rerunning the executable with an 227 | appropriate flag. 
228 | """ 229 | err_type, err, stack = sys.exc_info() 230 | line_number_match = re.search(r"[Ll][Ii][Nn][Ee]\s*[:=]?\s*(\d+)", str(err)) 231 | if line_number_match: 232 | xml_line_number = int(line_number_match.group(1)) 233 | xml_line = self._xml_string.split("\n")[xml_line_number - 1] 234 | stripped_xml_line = xml_line.strip() 235 | comment_match = re.search(_DEBUG_METADATA_TAG_PREFIX, stripped_xml_line) 236 | if comment_match: 237 | stripped_xml_line = stripped_xml_line[: comment_match.start()] 238 | else: 239 | xml_line = "" 240 | stripped_xml_line = "" 241 | 242 | message_lines = [] 243 | if debug_mode(): 244 | if get_full_dump_dir(): 245 | self.dump_full_debug_info_to_disk() 246 | message_lines.extend(["Compile error raised by Mujoco.", str(err)]) 247 | if xml_line: 248 | message_lines.extend([stripped_xml_line, self._generate_debug_message_from_xml_line(xml_line)]) 249 | else: 250 | message_lines.extend( 251 | [ 252 | "Compile error raised by Mujoco; " 253 | "run again with --pymjcf_debug for additional debug information.", 254 | str(err), 255 | ] 256 | ) 257 | if xml_line: 258 | message_lines.append(stripped_xml_line) 259 | 260 | message = "\n".join(message_lines) 261 | six.reraise(err_type, err_type(message), stack) 262 | 263 | @property 264 | def default_dump_dir(self): 265 | return get_full_dump_dir() 266 | 267 | @property 268 | def debug_mode(self): 269 | return debug_mode() 270 | 271 | def dump_full_debug_info_to_disk(self, dump_dir=None): 272 | """Dumps full debug information to disk. 273 | 274 | Full debug information consists of an XML file whose elements are tagged 275 | with a unique ID, and a stack trace file for each element ID. Each stack 276 | trace file consists of a stack trace for when the element was created, and 277 | when each attribute was last modified. 278 | 279 | Args: 280 | dump_dir: Full path to the directory in which dump files are created. 
281 | 282 | Raises: 283 | ValueError: If neither `dump_dir` nor the global dump path is given. The 284 | global dump path can be specified either via the 285 | --pymjcf_debug_full_dump_dir flag or via `debugging.set_full_dump_dir`. 286 | """ 287 | dump_dir = dump_dir or self.default_dump_dir 288 | if not dump_dir: 289 | raise ValueError("`dump_dir` is not specified") 290 | section_separator = "\n" + ("=" * 80) + "\n" 291 | 292 | def dump_stack(header, stack, f): 293 | indent = " " 294 | f.write(header + "\n") 295 | for stack_entry in stack: 296 | f.write( 297 | indent 298 | + "`{}` at {}:{}\n".format( 299 | stack_entry.function_name, stack_entry.filename, stack_entry.line_number 300 | ) 301 | ) 302 | f.write((indent * 2) + str(stack_entry.text) + "\n") 303 | f.write(section_separator) 304 | 305 | with open(os.path.join(dump_dir, "model.xml"), "w") as f: 306 | f.write(self._xml_string) 307 | for elem_id, debug_info in six.iteritems(self._debug_info_for_element_ids): 308 | with open(os.path.join(dump_dir, str(elem_id) + ".dump"), "w") as f: 309 | f.write("{}:{}\n".format(DEBUG_METADATA_PREFIX, elem_id)) 310 | f.write(str(debug_info.element) + "\n") 311 | dump_stack("Element creation", debug_info.init_stack, f) 312 | for attrib_name, stack in six.iteritems(debug_info.attribute_stacks): 313 | attrib_value = debug_info.element.get_attribute_xml_string(attrib_name) 314 | if stack[-1] == debug_info.init_stack[-1]: 315 | if attrib_value is not None: 316 | f.write('Attribute {}="{}"\n'.format(attrib_name, attrib_value)) 317 | f.write(" was set when the element was created\n") 318 | f.write(section_separator) 319 | else: 320 | if attrib_value is not None: 321 | dump_stack('Attribute {}="{}"'.format(attrib_name, attrib_value), stack, f) 322 | else: 323 | dump_stack("Attribute {} was CLEARED".format(attrib_name), stack, f) 324 | 325 | def _generate_debug_message_from_xml_line(self, xml_line): 326 | """Generates a debug message by parsing the metadata on an XML line.""" 327 | 
metadata_match = _DEBUG_METADATA_SEARCH_PATTERN.search(xml_line) 328 | if metadata_match: 329 | elem_id = int(metadata_match.group(1)) 330 | return self._generate_debug_message_from_element_id(elem_id) 331 | else: 332 | return "" 333 | 334 | def _generate_debug_message_from_element_id(self, elem_id): 335 | """Generates a debug message for the specified Element.""" 336 | out = [] 337 | debug_info = self._debug_info_for_element_ids[elem_id] 338 | 339 | out.append("Debug summary for element:") 340 | if not get_full_dump_dir(): 341 | out.append( 342 | " * Full debug info can be dumped to disk by setting the " 343 | "flag --pymjcf_debug_full_dump_dir=path/to/dump>" 344 | ) 345 | out.append( 346 | " * Element object was created by `{}` at {}:{}".format( 347 | debug_info.init_stack[-1].function_name, 348 | debug_info.init_stack[-1].filename, 349 | debug_info.init_stack[-1].line_number, 350 | ) 351 | ) 352 | 353 | for attrib_name, stack in six.iteritems(debug_info.attribute_stacks): 354 | attrib_value = debug_info.element.get_attribute_xml_string(attrib_name) 355 | if stack[-1] == debug_info.init_stack[-1]: 356 | if attrib_value is not None: 357 | out.append(' * {}="{}" was set when the element was created'.format(attrib_name, attrib_value)) 358 | else: 359 | if attrib_value is not None: 360 | out.append( 361 | ' * {}="{}" was set by `{}` at `{}:{}`'.format( 362 | attrib_name, 363 | attrib_value, 364 | stack[-1].function_name, 365 | stack[-1].filename, 366 | stack[-1].line_number, 367 | ) 368 | ) 369 | else: 370 | out.append( 371 | " * {} was CLEARED by `{}` at {}:{}".format( 372 | attrib_name, stack[-1].function_name, stack[-1].filename, stack[-1].line_number 373 | ) 374 | ) 375 | 376 | return "\n".join(out) 377 | -------------------------------------------------------------------------------- /kinpy/mjcf_parser/io.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-2018 The dm_control Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """IO functions.""" 17 | 18 | from __future__ import absolute_import, division, print_function 19 | 20 | 21 | def GetResource(name, mode="rb"): 22 | with open(name, mode=mode) as f: 23 | return f.read() 24 | 25 | 26 | def GetResourceFilename(name, mode="rb"): 27 | del mode # Unused. 28 | return name 29 | 30 | 31 | GetResourceAsFile = open # pylint: disable=invalid-name 32 | -------------------------------------------------------------------------------- /kinpy/mjcf_parser/namescope.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | 16 | """An object to manage the scoping of identifiers in MJCF models.""" 17 | 18 | from __future__ import absolute_import, division, print_function 19 | 20 | import collections 21 | 22 | import six 23 | 24 | from . import constants 25 | 26 | 27 | class NameScope(object): 28 | """A name scoping context for an MJCF model. 29 | 30 | This object maintains the uniqueness of identifiers within each MJCF 31 | namespace. Examples of MJCF namespaces include 'body', 'joint', and 'geom'. 32 | Each namescope also carries a name, and can have a parent namescope. 33 | When MJCF models are merged, all identifiers gain a hierarchical prefix 34 | separated by '/', which is the concatenation of all scope names up to 35 | the root namescope. 36 | """ 37 | 38 | def __init__(self, name, mjcf_model, model_dir="", assets=None): 39 | """Initializes a scope with the given name. 40 | 41 | Args: 42 | name: The scope's name 43 | mjcf_model: The RootElement of the MJCF model associated with this scope. 44 | model_dir: (optional) Path to the directory containing the model XML file. 45 | This is used to prefix the paths of all asset files. 46 | assets: (optional) A dictionary of pre-loaded assets, of the form 47 | `{filename: bytestring}`. If present, PyMJCF will search for assets in 48 | this dictionary before attempting to load them from the filesystem. 
49 | """ 50 | self._parent = None 51 | self._name = name 52 | self._mjcf_model = mjcf_model 53 | self._namespaces = collections.defaultdict(dict) 54 | self._model_dir = model_dir 55 | self._files = set() 56 | self._assets = assets or {} 57 | self._revision = 0 58 | 59 | @property 60 | def revision(self): 61 | return self._revision 62 | 63 | def increment_revision(self): 64 | self._revision += 1 65 | for namescope in six.itervalues(self._namespaces["namescope"]): 66 | namescope.increment_revision() 67 | 68 | @property 69 | def name(self): 70 | """This scope's name.""" 71 | return self._name 72 | 73 | @property 74 | def files(self): 75 | """A set containing the `File` attributes registered in this scope.""" 76 | return self._files 77 | 78 | @property 79 | def assets(self): 80 | """A dictionary containing pre-loaded assets.""" 81 | return self._assets 82 | 83 | @property 84 | def model_dir(self): 85 | """Path to the directory containing the model XML file.""" 86 | return self._model_dir 87 | 88 | @name.setter 89 | def name(self, new_name): 90 | if self._parent: 91 | self._parent.add("namescope", new_name, self) 92 | self._parent.remove("namescope", self._name) 93 | self._name = new_name 94 | self.increment_revision() 95 | 96 | @property 97 | def mjcf_model(self): 98 | return self._mjcf_model 99 | 100 | @property 101 | def parent(self): 102 | """This parent `NameScope`, or `None` if this is a root scope.""" 103 | return self._parent 104 | 105 | @parent.setter 106 | def parent(self, new_parent): 107 | if self._parent: 108 | self._parent.remove("namescope", self._name) 109 | self._parent = new_parent 110 | if self._parent: 111 | self._parent.add("namescope", self._name, self) 112 | self.increment_revision() 113 | 114 | @property 115 | def root(self): 116 | if self._parent is None: 117 | return self 118 | else: 119 | return self._parent.root 120 | 121 | def full_prefix(self, prefix_root=None, as_list=False): 122 | """The prefix for identifiers belonging to this scope. 
123 | 124 | Args: 125 | prefix_root: (optional) A `NameScope` object to be treated as root 126 | for the purpose of calculating the prefix. If `None` then no prefix 127 | is produced. 128 | as_list: (optional) A boolean, if `True` return the list of prefix 129 | components. If `False`, return the full prefix string separated by 130 | `mjcf.constants.PREFIX_SEPARATOR`. 131 | 132 | Returns: 133 | The prefix string. 134 | """ 135 | prefix_root = prefix_root or self 136 | if prefix_root != self and self._parent: 137 | prefix_list = self._parent.full_prefix(prefix_root, as_list=True) 138 | prefix_list.append(self._name) 139 | else: 140 | prefix_list = [] 141 | if as_list: 142 | return prefix_list 143 | else: 144 | if prefix_list: 145 | prefix_list.append("") 146 | return constants.PREFIX_SEPARATOR.join(prefix_list) 147 | 148 | def _assign(self, namespace, identifier, obj): 149 | """Checks a proposed identifier's validity before assigning to an object.""" 150 | namespace_dict = self._namespaces[namespace] 151 | if not isinstance(identifier, str): 152 | raise ValueError("Identifier must be a string: got {}".format(type(identifier))) 153 | elif constants.PREFIX_SEPARATOR in identifier: 154 | raise ValueError("Identifier cannot contain {!r}: got {}".format(constants.PREFIX_SEPARATOR, identifier)) 155 | else: 156 | namespace_dict[identifier] = obj 157 | 158 | def add(self, namespace, identifier, obj): 159 | """Add an identifier to this name scope. 160 | 161 | Args: 162 | namespace: A string specifying the namespace to which the 163 | identifier belongs. 164 | identifier: The identifier string. 165 | obj: The object referred to by the identifier. 166 | 167 | Raises: 168 | ValueError: If `identifier` not valid. 
169 | """ 170 | namespace_dict = self._namespaces[namespace] 171 | if identifier in namespace_dict: 172 | raise ValueError("Duplicated identifier {!r} in namespace <{}>".format(identifier, namespace)) 173 | else: 174 | self._assign(namespace, identifier, obj) 175 | self.increment_revision() 176 | 177 | def replace(self, namespace, identifier, obj): 178 | """Reassociates an identifier with a different object. 179 | 180 | Args: 181 | namespace: A string specifying the namespace to which the 182 | identifier belongs. 183 | identifier: The identifier string. 184 | obj: The object referred to by the identifier. 185 | 186 | Raises: 187 | ValueError: If `identifier` not valid. 188 | """ 189 | self._assign(namespace, identifier, obj) 190 | self.increment_revision() 191 | 192 | def remove(self, namespace, identifier): 193 | """Removes an identifier from this name scope. 194 | 195 | Args: 196 | namespace: A string specifying the namespace to which the 197 | identifier belongs. 198 | identifier: The identifier string. 199 | 200 | Raises: 201 | KeyError: If `identifier` does not exist in this scope. 202 | """ 203 | del self._namespaces[namespace][identifier] 204 | self.increment_revision() 205 | 206 | def rename(self, namespace, old_identifier, new_identifier): 207 | obj = self.get(namespace, old_identifier) 208 | self.add(namespace, new_identifier, obj) 209 | self.remove(namespace, old_identifier) 210 | 211 | def get(self, namespace, identifier): 212 | return self._namespaces[namespace][identifier] 213 | 214 | def has_identifier(self, namespace, identifier): 215 | return identifier in self._namespaces[namespace] 216 | -------------------------------------------------------------------------------- /kinpy/mjcf_parser/parser.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The dm_control Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Functions for parsing XML into an MJCF object model.""" 17 | 18 | from __future__ import absolute_import, division, print_function 19 | 20 | import os 21 | import sys 22 | 23 | import six 24 | from lxml import etree 25 | 26 | from . import constants, debugging, element 27 | from . import io as resources 28 | 29 | 30 | def from_xml_string(xml_string, escape_separators=False, model_dir="", resolve_references=True, assets=None): 31 | """Parses an XML string into an MJCF object model. 32 | 33 | Args: 34 | xml_string: An XML string representing an MJCF model. 35 | escape_separators: (optional) A boolean, whether to replace '/' characters 36 | in element identifiers. If `False`, any '/' present in the XML causes 37 | a ValueError to be raised. 38 | model_dir: (optional) Path to the directory containing the model XML file. 39 | This is used to prefix the paths of all asset files. 40 | resolve_references: (optional) A boolean indicating whether the parser 41 | should attempt to resolve reference attributes to a corresponding element. 42 | assets: (optional) A dictionary of pre-loaded assets, of the form 43 | `{filename: bytestring}`. If present, PyMJCF will search for assets in 44 | this dictionary before attempting to load them from the filesystem. 45 | 46 | Returns: 47 | An `mjcf.RootElement`. 
48 | """ 49 | xml_root = etree.fromstring(xml_string) 50 | return _parse( 51 | xml_root, escape_separators, model_dir=model_dir, resolve_references=resolve_references, assets=assets 52 | ) 53 | 54 | 55 | def from_file(file_handle, escape_separators=False, model_dir="", resolve_references=True, assets=None): 56 | """Parses an XML file into an MJCF object model. 57 | 58 | Args: 59 | file_handle: A Python file-like handle. 60 | escape_separators: (optional) A boolean, whether to replace '/' characters 61 | in element identifiers. If `False`, any '/' present in the XML causes 62 | a ValueError to be raised. 63 | model_dir: (optional) Path to the directory containing the model XML file. 64 | This is used to prefix the paths of all asset files. 65 | resolve_references: (optional) A boolean indicating whether the parser 66 | should attempt to resolve reference attributes to a corresponding element. 67 | assets: (optional) A dictionary of pre-loaded assets, of the form 68 | `{filename: bytestring}`. If present, PyMJCF will search for assets in 69 | this dictionary before attempting to load them from the filesystem. 70 | 71 | Returns: 72 | An `mjcf.RootElement`. 73 | """ 74 | xml_root = etree.parse(file_handle).getroot() 75 | return _parse( 76 | xml_root, escape_separators, model_dir=model_dir, resolve_references=resolve_references, assets=assets 77 | ) 78 | 79 | 80 | def from_path(path, escape_separators=False, resolve_references=True, assets=None): 81 | """Parses an XML file into an MJCF object model. 82 | 83 | Args: 84 | path: A path to an XML file. This path should be loadable using 85 | `resources.GetResource`. 86 | escape_separators: (optional) A boolean, whether to replace '/' characters 87 | in element identifiers. If `False`, any '/' present in the XML causes 88 | a ValueError to be raised. 89 | resolve_references: (optional) A boolean indicating whether the parser 90 | should attempt to resolve reference attributes to a corresponding element. 
91 | assets: (optional) A dictionary of pre-loaded assets, of the form 92 | `{filename: bytestring}`. If present, PyMJCF will search for assets in 93 | this dictionary before attempting to load them from the filesystem. 94 | 95 | Returns: 96 | An `mjcf.RootElement`. 97 | """ 98 | model_dir, _ = os.path.split(path) 99 | contents = resources.GetResource(path) 100 | xml_root = etree.fromstring(contents) 101 | return _parse( 102 | xml_root, escape_separators, model_dir=model_dir, resolve_references=resolve_references, assets=assets 103 | ) 104 | 105 | 106 | def _parse(xml_root, escape_separators=False, model_dir="", resolve_references=True, assets=None): 107 | """Parses a complete MJCF model from an XML. 108 | 109 | Args: 110 | xml_root: An `etree.Element` object. 111 | escape_separators: (optional) A boolean, whether to replace '/' characters 112 | in element identifiers. If `False`, any '/' present in the XML causes 113 | a ValueError to be raised. 114 | model_dir: (optional) Path to the directory containing the model XML file. 115 | This is used to prefix the paths of all asset files. 116 | resolve_references: (optional) A boolean indicating whether the parser 117 | should attempt to resolve reference attributes to a corresponding element. 118 | assets: (optional) A dictionary of pre-loaded assets, of the form 119 | `{filename: bytestring}`. If present, PyMJCF will search for assets in 120 | this dictionary before attempting to load them from the filesystem. 121 | 122 | Returns: 123 | An `mjcf.RootElement`. 124 | 125 | Raises: 126 | ValueError: If `xml_root`'s tag is not 'mujoco'. 127 | """ 128 | 129 | assets = assets or {} 130 | 131 | if xml_root.tag != "mujoco": 132 | raise ValueError("Root element of the XML should be <mujoco>: got <{}>".format(xml_root.tag)) 133 | 134 | with debugging.freeze_current_stack_trace(): 135 | # Recursively parse any included XML files.
136 | to_include = [] 137 | for include_tag in xml_root.findall("include"): 138 | try: 139 | # First look for the path to the included XML file in the assets dict. 140 | path_or_xml_string = assets[include_tag.attrib["file"]] 141 | parsing_func = from_xml_string 142 | except KeyError: 143 | # If it's not present in the assets dict then attempt to load the XML 144 | # from the filesystem. 145 | path_or_xml_string = os.path.join(model_dir, include_tag.attrib["file"]) 146 | parsing_func = from_path 147 | included_mjcf = parsing_func( 148 | path_or_xml_string, 149 | escape_separators=escape_separators, 150 | resolve_references=resolve_references, 151 | assets=assets, 152 | ) 153 | to_include.append(included_mjcf) 154 | # We must remove <include/> tags before parsing the main XML file, since 155 | # these are a schema violation. 156 | xml_root.remove(include_tag) 157 | 158 | # Parse the main XML file. 159 | try: 160 | model = xml_root.attrib.pop("model") 161 | except KeyError: 162 | model = None 163 | mjcf_root = element.RootElement(model=model, model_dir=model_dir, assets=assets) 164 | _parse_children(xml_root, mjcf_root, escape_separators) 165 | 166 | # Merge in the included XML files. 167 | for included_mjcf in to_include: 168 | # The included MJCF might have been automatically assigned a model name 169 | # that conflicts with that of `mjcf_root`, so we override it here. 170 | included_mjcf.model = mjcf_root.model 171 | mjcf_root.include_copy(included_mjcf) 172 | 173 | if resolve_references: 174 | mjcf_root.resolve_references() 175 | return mjcf_root 176 | 177 | 178 | def _parse_children(xml_element, mjcf_element, escape_separators=False): 179 | """Parses all children of a given XML element into an MJCF element. 180 | 181 | Args: 182 | xml_element: The source `etree.Element` object. 183 | mjcf_element: The target `mjcf.Element` object. 184 | escape_separators: (optional) A boolean, whether to replace '/' characters 185 | in element identifiers.
If `False`, any '/' present in the XML causes 186 | a ValueError to be raised. 187 | """ 188 | for xml_child in xml_element: 189 | if xml_child.tag is etree.Comment or xml_child.tag is etree.PI: 190 | continue 191 | try: 192 | child_spec = mjcf_element.spec.children[xml_child.tag] 193 | if escape_separators: 194 | attributes = {} 195 | for name, value in six.iteritems(xml_child.attrib): 196 | new_value = value.replace(constants.PREFIX_SEPARATOR_ESCAPE, constants.PREFIX_SEPARATOR_ESCAPE * 2) 197 | new_value = new_value.replace(constants.PREFIX_SEPARATOR, constants.PREFIX_SEPARATOR_ESCAPE) 198 | attributes[name] = new_value 199 | else: 200 | attributes = dict(xml_child.attrib) 201 | if child_spec.repeated or child_spec.on_demand: 202 | mjcf_child = mjcf_element.add(xml_child.tag, **attributes) 203 | else: 204 | mjcf_child = getattr(mjcf_element, xml_child.tag) 205 | mjcf_child.set_attributes(**attributes) 206 | except: # pylint: disable=bare-except 207 | err_type, err, traceback = sys.exc_info() 208 | message = "Line {}: error while parsing element <{}>: {}".format(xml_child.sourceline, xml_child.tag, err) 209 | six.reraise(err_type, err_type(message), traceback) 210 | _parse_children(xml_child, mjcf_child, escape_separators) 211 | -------------------------------------------------------------------------------- /kinpy/mjcf_parser/schema.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """A Python object representation of Mujoco's MJCF schema. 17 | 18 | The root schema is provided as a module-level constant `schema.MUJOCO`. 19 | """ 20 | 21 | from __future__ import absolute_import, division, print_function 22 | 23 | import collections 24 | import copy 25 | import os 26 | import pkgutil 27 | 28 | import six 29 | from lxml import etree 30 | 31 | from . import attribute 32 | from . import io as resources 33 | 34 | _SCHEMA_XML_PATH = "mjcf_parser/schema.xml" 35 | 36 | _ARRAY_DTYPE_MAP = {"int": int, "float": float, "string": str} 37 | 38 | _SCALAR_TYPE_MAP = {"int": attribute.Integer, "float": attribute.Float, "string": attribute.String} 39 | 40 | ElementSpec = collections.namedtuple( 41 | "ElementSpec", ("name", "repeated", "on_demand", "identifier", "namespace", "attributes", "children") 42 | ) 43 | 44 | AttributeSpec = collections.namedtuple( 45 | "AttributeSpec", ("name", "type", "required", "conflict_allowed", "conflict_behavior", "other_kwargs") 46 | ) 47 | 48 | # Additional namespaces that are not present in the MJCF schema but can 49 | # be used in `find` and `find_all`. 50 | _ADDITIONAL_FINDABLE_NAMESPACES = frozenset(["attachment_frame"]) 51 | 52 | 53 | def _str2bool(string): 54 | """Converts either 'true' or 'false' (not case-sensitively) into a boolean.""" 55 | if string is None: 56 | return False 57 | else: 58 | string = string.lower() 59 | 60 | if string == "true": 61 | return True 62 | elif string == "false": 63 | return False 64 | else: 65 | raise ValueError("String should either be `true` or `false`: got {}".format(string)) 66 | 67 | 68 | def parse_schema(schema_path): 69 | """Parses the schema XML. 70 | 71 | Args: 72 | schema_path: Path to the schema XML file. 
73 | 74 | Returns: 75 | An `ElementSpec` for the root element in the schema. 76 | """ 77 | schema_xml = etree.fromstring(pkgutil.get_data("kinpy", schema_path)) 78 | return _parse_element(schema_xml) 79 | 80 | 81 | def _parse_element(element_xml): 82 | """Parses an element in the schema.""" 83 | name = element_xml.get("name") 84 | if not name: 85 | raise ValueError("Element must always have a name") 86 | repeated = _str2bool(element_xml.get("repeated")) 87 | on_demand = _str2bool(element_xml.get("on_demand")) 88 | 89 | attributes = collections.OrderedDict() 90 | attributes_xml = element_xml.find("attributes") 91 | if attributes_xml is not None: 92 | for attribute_xml in attributes_xml.findall("attribute"): 93 | attributes[attribute_xml.get("name")] = _parse_attribute(attribute_xml) 94 | 95 | identifier = None 96 | namespace = None 97 | for attribute_spec in six.itervalues(attributes): 98 | if attribute_spec.type == attribute.Identifier: 99 | identifier = attribute_spec.name 100 | namespace = element_xml.get("namespace") or name 101 | 102 | children = collections.OrderedDict() 103 | children_xml = element_xml.find("children") 104 | if children_xml is not None: 105 | for child_xml in children_xml.findall("element"): 106 | children[child_xml.get("name")] = _parse_element(child_xml) 107 | 108 | element_spec = ElementSpec(name, repeated, on_demand, identifier, namespace, attributes, children) 109 | 110 | recursive = _str2bool(element_xml.get("recursive")) 111 | if recursive: 112 | element_spec.children[name] = element_spec 113 | 114 | common_keys = set(element_spec.attributes).intersection(element_spec.children) 115 | if common_keys: 116 | raise RuntimeError( 117 | "Element '{}' contains the following attributes and children with " 118 | "the same name: '{}'. This violates the design assumptions of " 119 | "this library. Please file a bug report. 
Thank you.".format(name, sorted(common_keys)) 120 | ) 121 | 122 | return element_spec 123 | 124 | 125 | def _parse_attribute(attribute_xml): 126 | """Parses an element in the schema.""" 127 | name = attribute_xml.get("name") 128 | required = _str2bool(attribute_xml.get("required")) 129 | conflict_allowed = _str2bool(attribute_xml.get("conflict_allowed")) 130 | conflict_behavior = attribute_xml.get("conflict_behavior", "replace") 131 | attribute_type = attribute_xml.get("type") 132 | other_kwargs = {} 133 | if attribute_type == "keyword": 134 | attribute_callable = attribute.Keyword 135 | other_kwargs["valid_values"] = attribute_xml.get("valid_values").split(" ") 136 | elif attribute_type == "array": 137 | array_size_str = attribute_xml.get("array_size") 138 | attribute_callable = attribute.Array 139 | other_kwargs["length"] = int(array_size_str) if array_size_str else None 140 | other_kwargs["dtype"] = _ARRAY_DTYPE_MAP[attribute_xml.get("array_type")] 141 | elif attribute_type == "identifier": 142 | attribute_callable = attribute.Identifier 143 | elif attribute_type == "reference": 144 | attribute_callable = attribute.Reference 145 | other_kwargs["reference_namespace"] = attribute_xml.get("reference_namespace") or name 146 | elif attribute_type == "basepath": 147 | attribute_callable = attribute.BasePath 148 | other_kwargs["path_namespace"] = attribute_xml.get("path_namespace") 149 | elif attribute_type == "file": 150 | attribute_callable = attribute.File 151 | other_kwargs["path_namespace"] = attribute_xml.get("path_namespace") 152 | else: 153 | try: 154 | attribute_callable = _SCALAR_TYPE_MAP[attribute_type] 155 | except KeyError: 156 | raise ValueError("Invalid attribute type: {}".format(attribute_type)) 157 | 158 | return AttributeSpec( 159 | name=name, 160 | type=attribute_callable, 161 | required=required, 162 | conflict_allowed=conflict_allowed, 163 | conflict_behavior=conflict_behavior, 164 | other_kwargs=other_kwargs, 165 | ) 166 | 167 | 168 | def 
collect_namespaces(root_spec): 169 | """Constructs a set of namespaces in a given ElementSpec. 170 | 171 | Args: 172 | root_spec: An `ElementSpec` for the root element in the schema. 173 | 174 | Returns: 175 | A set of strings specifying the names of all the namespaces that are present 176 | in the spec. 177 | """ 178 | findable_namespaces = set() 179 | 180 | def update_namespaces_from_spec(spec): 181 | findable_namespaces.add(spec.namespace) 182 | for child_spec in six.itervalues(spec.children): 183 | if child_spec is not spec: 184 | update_namespaces_from_spec(child_spec) 185 | 186 | update_namespaces_from_spec(root_spec) 187 | return findable_namespaces 188 | 189 | 190 | MUJOCO = parse_schema(_SCHEMA_XML_PATH) 191 | FINDABLE_NAMESPACES = frozenset(collect_namespaces(MUJOCO).union(_ADDITIONAL_FINDABLE_NAMESPACES)) 192 | 193 | 194 | def _attachment_frame_spec(is_world_attachment): 195 | """Create specs for attachment frames. 196 | 197 | Attachment frames are specialized without an identifier. 198 | The only allowed children are joints which also don't have identifiers. 199 | 200 | Args: 201 | is_world_attachment: Whether we are creating a spec for attachments to 202 | worldbody. If `True`, allow as child. 203 | 204 | Returns: 205 | An `ElementSpec`. 206 | """ 207 | frame_spec = ElementSpec( 208 | "body", 209 | repeated=True, 210 | on_demand=False, 211 | identifier=None, 212 | namespace="body", 213 | attributes=collections.OrderedDict(), 214 | children=collections.OrderedDict(), 215 | ) 216 | 217 | body_spec = MUJOCO.children["worldbody"].children["body"] 218 | # 'name' and 'childclass' attributes are excluded. 
219 | for attrib_name in ("mocap", "pos", "quat", "axisangle", "xyaxes", "zaxis", "euler"): 220 | frame_spec.attributes[attrib_name] = copy.deepcopy(body_spec.attributes[attrib_name]) 221 | 222 | inertial_spec = body_spec.children["inertial"] 223 | frame_spec.children["inertial"] = copy.deepcopy(inertial_spec) 224 | joint_spec = body_spec.children["joint"] 225 | frame_spec.children["joint"] = ElementSpec( 226 | "joint", 227 | repeated=True, 228 | on_demand=False, 229 | identifier=None, 230 | namespace="joint", 231 | attributes=copy.deepcopy(joint_spec.attributes), 232 | children=collections.OrderedDict(), 233 | ) 234 | 235 | if is_world_attachment: 236 | freejoint_spec = MUJOCO.children["worldbody"].children["body"].children["freejoint"] 237 | frame_spec.children["freejoint"] = ElementSpec( 238 | "freejoint", 239 | repeated=False, 240 | on_demand=True, 241 | identifier=None, 242 | namespace="joint", 243 | attributes=copy.deepcopy(freejoint_spec.attributes), 244 | children=collections.OrderedDict(), 245 | ) 246 | 247 | return frame_spec 248 | 249 | 250 | ATTACHMENT_FRAME = _attachment_frame_spec(is_world_attachment=False) 251 | WORLD_ATTACHMENT_FRAME = _attachment_frame_spec(is_world_attachment=True) 252 | -------------------------------------------------------------------------------- /kinpy/mjcf_parser/util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-2018 The dm_control Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | 16 | """Various helper functions and classes.""" 17 | 18 | from __future__ import absolute_import, division, print_function 19 | 20 | import sys 21 | 22 | import six 23 | 24 | DEFAULT_ENCODING = sys.getdefaultencoding() 25 | 26 | 27 | def to_binary_string(s): 28 | """Convert text string to binary.""" 29 | if isinstance(s, six.binary_type): 30 | return s 31 | return s.encode(DEFAULT_ENCODING) 32 | 33 | 34 | def to_native_string(s): 35 | """Convert a text or binary string to the native string format.""" 36 | if six.PY3 and isinstance(s, six.binary_type): 37 | return s.decode(DEFAULT_ENCODING) 38 | elif six.PY2 and isinstance(s, six.text_type): 39 | return s.encode(DEFAULT_ENCODING) 40 | else: 41 | return s 42 | -------------------------------------------------------------------------------- /kinpy/sdf.py: -------------------------------------------------------------------------------- 1 | import io 2 | from typing import List, TextIO, Union 3 | 4 | import numpy as np 5 | 6 | from . 
import chain, frame, transform 7 | from .urdf_parser_py.sdf import SDF, Box, Cylinder, Mesh, Sphere 8 | 9 | JOINT_TYPE_MAP = {"revolute": "revolute", "prismatic": "prismatic", "fixed": "fixed"} 10 | 11 | 12 | def _convert_transform(pose: np.ndarray) -> transform.Transform: 13 | if pose is None: 14 | return transform.Transform() 15 | else: 16 | return transform.Transform(rot=pose[3:], pos=pose[:3]) 17 | 18 | 19 | def _convert_visuals(visuals: List) -> List: 20 | vlist = [] 21 | for v in visuals: 22 | v_tf = _convert_transform(v.pose) 23 | if isinstance(v.geometry, Mesh): 24 | g_type = "mesh" 25 | g_param = v.geometry.filename 26 | elif isinstance(v.geometry, Cylinder): 27 | g_type = "cylinder" 28 | v_tf = v_tf * transform.Transform(rot=np.deg2rad([90.0, 0.0, 0.0])) 29 | g_param = (v.geometry.radius, v.geometry.length) 30 | elif isinstance(v.geometry, Box): 31 | g_type = "box" 32 | g_param = v.geometry.size 33 | elif isinstance(v.geometry, Sphere): 34 | g_type = "sphere" 35 | g_param = v.geometry.radius 36 | else: 37 | g_type = None 38 | g_param = None 39 | vlist.append(frame.Visual(v_tf, g_type, g_param)) 40 | return vlist 41 | 42 | 43 | def _build_chain_recurse(root_frame, lmap, joints) -> List: 44 | children = [] 45 | for j in joints: 46 | if j.parent == root_frame.link.name: 47 | child_frame = frame.Frame(j.child + "_frame") 48 | link_p = lmap[j.parent] 49 | link_c = lmap[j.child] 50 | t_p = _convert_transform(link_p.pose) 51 | t_c = _convert_transform(link_c.pose) 52 | child_frame.joint = frame.Joint( 53 | j.name, offset=t_p.inverse() * t_c, joint_type=JOINT_TYPE_MAP[j.type], axis=j.axis.xyz 54 | ) 55 | child_frame.link = frame.Link( 56 | link_c.name, offset=transform.Transform(), visuals=_convert_visuals(link_c.visuals) 57 | ) 58 | child_frame.children = _build_chain_recurse(child_frame, lmap, joints) 59 | children.append(child_frame) 60 | return children 61 | 62 | 63 | def build_chain_from_sdf(data: Union[str, TextIO]) -> chain.Chain: 64 | """ 65 | Build a 
Chain object from SDF data. 66 | 67 | Parameters 68 | ---------- 69 | data : str or TextIO 70 | SDF string data or file object. 71 | 72 | Returns 73 | ------- 74 | chain.Chain 75 | Chain object created from SDF. 76 | """ 77 | if isinstance(data, io.TextIOBase): 78 | data = data.read() 79 | sdf = SDF.from_xml_string(data) 80 | robot = sdf.model 81 | lmap = robot.link_map 82 | joints = robot.joints 83 | n_joints = len(joints) 84 | has_root = [True for _ in range(len(joints))] 85 | for i in range(n_joints): 86 | for j in range(i + 1, n_joints): 87 | if joints[i].parent == joints[j].child: 88 | has_root[i] = False 89 | elif joints[j].parent == joints[i].child: 90 | has_root[j] = False 91 | for i in range(n_joints): 92 | if has_root[i]: 93 | root_link = lmap[joints[i].parent] 94 | break 95 | root_frame = frame.Frame(root_link.name + "_frame") 96 | root_frame.joint = frame.Joint(offset=_convert_transform(root_link.pose)) 97 | root_frame.link = frame.Link(root_link.name, transform.Transform(), _convert_visuals(root_link.visuals)) 98 | root_frame.children = _build_chain_recurse(root_frame, lmap, joints) 99 | return chain.Chain(root_frame) 100 | -------------------------------------------------------------------------------- /kinpy/transform.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | 3 | import numpy as np 4 | import transformations as tf 5 | 6 | 7 | class Transform: 8 | """This class calculates the rotation and translation of a 3D rigid body. 9 | 10 | Attributes 11 | ---------- 12 | rot : np.ndarray 13 | The rotation parameter. Give in quaternions or roll pitch yaw. 14 | pos : np.ndarray 15 | The translation parameter. 
16 | """ 17 | 18 | def __init__(self, rot: Union[List, np.ndarray, None] = None, pos: Optional[np.ndarray] = None) -> None: 19 | if rot is None: 20 | rot = [1.0, 0.0, 0.0, 0.0] 21 | if pos is None: 22 | pos = np.zeros(3) 23 | if len(rot) == 3: 24 | self.rot = tf.quaternion_from_euler(*rot) 25 | elif len(rot) == 4: 26 | self.rot = np.array(rot) 27 | else: 28 | raise ValueError("Size of rot must be 3 or 4.") 29 | self.pos = np.array(pos) 30 | 31 | def __repr__(self) -> str: 32 | return "Transform(rot={0}, pos={1})".format(self.rot, self.pos) 33 | 34 | @staticmethod 35 | def _rotation_vec(rot: np.ndarray, vec: np.ndarray) -> np.ndarray: 36 | v4 = np.hstack([np.array([0.0]), vec]) 37 | inv_rot = tf.quaternion_inverse(rot) 38 | ans = tf.quaternion_multiply(tf.quaternion_multiply(rot, v4), inv_rot) 39 | return ans[1:] 40 | 41 | def __mul__(self, other: "Transform") -> "Transform": 42 | rot = tf.quaternion_multiply(self.rot, other.rot) 43 | pos = self._rotation_vec(self.rot, other.pos) + self.pos 44 | return Transform(rot, pos) 45 | 46 | def inverse(self) -> "Transform": 47 | rot = tf.quaternion_inverse(self.rot) 48 | pos = -self._rotation_vec(rot, self.pos) 49 | return Transform(rot, pos) 50 | 51 | def matrix(self) -> np.ndarray: 52 | mat = tf.quaternion_matrix(self.rot) 53 | mat[:3, 3] = self.pos 54 | return mat 55 | 56 | @property 57 | def rot_mat(self) -> np.ndarray: 58 | return tf.quaternion_matrix(self.rot)[:3, :3] 59 | 60 | @property 61 | def rot_euler(self) -> np.ndarray: 62 | return tf.euler_from_quaternion(self.rot) 63 | -------------------------------------------------------------------------------- /kinpy/urdf.py: -------------------------------------------------------------------------------- 1 | import io 2 | from typing import List, TextIO, Union 3 | 4 | from . 
import chain, frame, transform 5 | from .urdf_parser_py import urdf 6 | 7 | JOINT_TYPE_MAP = {"revolute": "revolute", "continuous": "revolute", "prismatic": "prismatic", "fixed": "fixed"} 8 | 9 | 10 | def _convert_transform(origin) -> transform.Transform: 11 | if origin is None: 12 | return transform.Transform() 13 | else: 14 | return transform.Transform(rot=origin.rpy, pos=origin.xyz) 15 | 16 | 17 | def _convert_visual(visual) -> frame.Visual: 18 | if visual is None or visual.geometry is None: 19 | return frame.Visual() 20 | else: 21 | v_tf = _convert_transform(visual.origin) 22 | if isinstance(visual.geometry, urdf.Mesh): 23 | g_type = "mesh" 24 | g_param = visual.geometry.filename 25 | elif isinstance(visual.geometry, urdf.Cylinder): 26 | g_type = "cylinder" 27 | g_param = (visual.geometry.radius, visual.geometry.length) 28 | elif isinstance(visual.geometry, urdf.Box): 29 | g_type = "box" 30 | g_param = visual.geometry.size 31 | elif isinstance(visual.geometry, urdf.Sphere): 32 | g_type = "sphere" 33 | g_param = visual.geometry.radius 34 | else: 35 | g_type = None 36 | g_param = None 37 | return frame.Visual(v_tf, g_type, g_param) 38 | 39 | 40 | def _build_chain_recurse(root_frame, lmap, joints) -> List[frame.Frame]: 41 | children = [] 42 | for j in joints: 43 | if j.parent == root_frame.link.name: 44 | child_frame = frame.Frame(j.child + "_frame") 45 | child_frame.joint = frame.Joint( 46 | j.name, offset=_convert_transform(j.origin), joint_type=JOINT_TYPE_MAP[j.type], axis=j.axis 47 | ) 48 | link = lmap[j.child] 49 | child_frame.link = frame.Link( 50 | link.name, offset=_convert_transform(link.origin), visuals=[_convert_visual(link.visual)] 51 | ) 52 | child_frame.children = _build_chain_recurse(child_frame, lmap, joints) 53 | children.append(child_frame) 54 | return children 55 | 56 | 57 | def build_chain_from_urdf(data: Union[str, TextIO]) -> chain.Chain: 58 | """ 59 | Build a Chain object from URDF data. 
60 | 61 | Parameters 62 | ---------- 63 | data : str or TextIO 64 | URDF string data or file object. 65 | 66 | Returns 67 | ------- 68 | chain.Chain 69 | Chain object created from URDF. 70 | 71 | Example 72 | ------- 73 | >>> import kinpy as kp 74 | >>> data = ''' 75 | ... 76 | ... 77 | ... 78 | ... 79 | ... 80 | ... 81 | ... ''' 82 | >>> chain = kp.build_chain_from_urdf(data) 83 | >>> print(chain) 84 | link1_frame 85 | └──── link2_frame 86 | 87 | """ 88 | if isinstance(data, io.TextIOBase): 89 | data = data.read() 90 | robot = urdf.URDF.from_xml_string(data) 91 | lmap = robot.link_map 92 | joints = robot.joints 93 | n_joints = len(joints) 94 | has_root = [True for _ in range(len(joints))] 95 | for i in range(n_joints): 96 | for j in range(i + 1, n_joints): 97 | if joints[i].parent == joints[j].child: 98 | has_root[i] = False 99 | elif joints[j].parent == joints[i].child: 100 | has_root[j] = False 101 | for i in range(n_joints): 102 | if has_root[i]: 103 | root_link = lmap[joints[i].parent] 104 | break 105 | root_frame = frame.Frame(root_link.name + "_frame") 106 | root_frame.joint = frame.Joint() 107 | root_frame.link = frame.Link( 108 | root_link.name, _convert_transform(root_link.origin), [_convert_visual(root_link.visual)] 109 | ) 110 | root_frame.children = _build_chain_recurse(root_frame, lmap, joints) 111 | return chain.Chain(root_frame) 112 | 113 | 114 | def build_serial_chain_from_urdf( 115 | data: Union[str, TextIO], end_link_name: str, root_link_name: str = "" 116 | ) -> chain.SerialChain: 117 | """ 118 | Build a SerialChain object from urdf data. 119 | 120 | Parameters 121 | ---------- 122 | data : str or TextIO 123 | URDF string data or file object. 124 | end_link_name : str 125 | The name of the link that is the end effector. 126 | root_link_name : str, optional 127 | The name of the root link. 128 | 129 | Returns 130 | ------- 131 | chain.SerialChain 132 | SerialChain object created from URDF. 
133 | """ 134 | if isinstance(data, io.TextIOBase): 135 | data = data.read() 136 | urdf_chain = build_chain_from_urdf(data) 137 | return chain.SerialChain( 138 | urdf_chain, end_link_name + "_frame", "" if root_link_name == "" else root_link_name + "_frame" 139 | ) 140 | -------------------------------------------------------------------------------- /kinpy/urdf_parser_py/__init__.py: -------------------------------------------------------------------------------- 1 | from . import sdf, urdf 2 | -------------------------------------------------------------------------------- /kinpy/urdf_parser_py/sdf.py: -------------------------------------------------------------------------------- 1 | from . import xml_reflection as xmlr 2 | from .xml_reflection.basics import * 3 | 4 | # What is the scope of plugins? Model, World, Sensor? 5 | 6 | xmlr.start_namespace("sdf") 7 | 8 | 9 | name_attribute = xmlr.Attribute("name", str, False) 10 | pose_element = xmlr.Element("pose", "vector6", False) 11 | 12 | 13 | class Inertia(xmlr.Object): 14 | KEYS = ["ixx", "ixy", "ixz", "iyy", "iyz", "izz"] 15 | 16 | def __init__(self, ixx=0.0, ixy=0.0, ixz=0.0, iyy=0.0, iyz=0.0, izz=0.0): 17 | self.ixx = ixx 18 | self.ixy = ixy 19 | self.ixz = ixz 20 | self.iyy = iyy 21 | self.iyz = iyz 22 | self.izz = izz 23 | 24 | def to_matrix(self): 25 | return [[self.ixx, self.ixy, self.ixz], [self.ixy, self.iyy, self.iyz], [self.ixz, self.iyz, self.izz]] 26 | 27 | 28 | xmlr.reflect(Inertia, params=[xmlr.Element(key, float) for key in Inertia.KEYS]) 29 | 30 | # Pretty much copy-paste... Better method? 31 | # Use multiple inheritance to separate the objects out so they are unique? 
32 | 33 | 34 | class Inertial(xmlr.Object): 35 | def __init__(self, mass=0.0, inertia=None, pose=None): 36 | self.mass = mass 37 | self.inertia = inertia 38 | self.pose = pose 39 | 40 | 41 | xmlr.reflect(Inertial, params=[xmlr.Element("mass", float), xmlr.Element("inertia", Inertia), pose_element]) 42 | 43 | 44 | class Box(xmlr.Object): 45 | def __init__(self, size=None): 46 | self.size = size 47 | 48 | 49 | xmlr.reflect(Box, tag="box", params=[xmlr.Element("size", "vector3")]) 50 | 51 | 52 | class Cylinder(xmlr.Object): 53 | def __init__(self, radius=0.0, length=0.0): 54 | self.radius = radius 55 | self.length = length 56 | 57 | 58 | xmlr.reflect(Cylinder, tag="cylinder", params=[xmlr.Element("radius", float), xmlr.Element("length", float)]) 59 | 60 | 61 | class Sphere(xmlr.Object): 62 | def __init__(self, radius=0.0): 63 | self.radius = radius 64 | 65 | 66 | xmlr.reflect(Sphere, tag="sphere", params=[xmlr.Element("radius", float)]) 67 | 68 | 69 | class Mesh(xmlr.Object): 70 | def __init__(self, filename=None, scale=None): 71 | self.filename = filename 72 | self.scale = scale 73 | 74 | 75 | xmlr.reflect( 76 | Mesh, tag="mesh", params=[xmlr.Element("filename", str), xmlr.Element("scale", "vector3", required=False)] 77 | ) 78 | 79 | 80 | class GeometricType(xmlr.ValueType): 81 | def __init__(self): 82 | self.factory = xmlr.FactoryType( 83 | "geometric", {"box": Box, "cylinder": Cylinder, "sphere": Sphere, "mesh": Mesh} 84 | ) 85 | 86 | def from_xml(self, node, path): 87 | children = xml_children(node) 88 | assert len(children) == 1, "One element only for geometric" 89 | return self.factory.from_xml(children[0], path=path) 90 | 91 | def write_xml(self, node, obj): 92 | name = self.factory.get_name(obj) 93 | child = node_add(node, name) 94 | obj.write_xml(child) 95 | 96 | 97 | xmlr.add_type("geometric", GeometricType()) 98 | 99 | 100 | class Script(xmlr.Object): 101 | def __init__(self, uri=None, name=None): 102 | self.uri = uri 103 | self.name = name 104 | 105 | 106 
| xmlr.reflect(Script, tag="script", params=[xmlr.Element("name", str, False), xmlr.Element("uri", str, False)]) 107 | 108 | 109 | class Material(xmlr.Object): 110 | def __init__(self, name=None, script=None): 111 | self.name = name 112 | self.script = script 113 | 114 | 115 | xmlr.reflect(Material, tag="material", params=[name_attribute, xmlr.Element("script", Script, False)]) 116 | 117 | 118 | class Visual(xmlr.Object): 119 | def __init__(self, name=None, geometry=None, pose=None): 120 | self.name = name 121 | self.geometry = geometry 122 | self.pose = pose 123 | 124 | 125 | xmlr.reflect( 126 | Visual, 127 | tag="visual", 128 | params=[ 129 | name_attribute, 130 | xmlr.Element("geometry", "geometric"), 131 | xmlr.Element("material", Material, False), 132 | pose_element, 133 | ], 134 | ) 135 | 136 | 137 | class Collision(xmlr.Object): 138 | def __init__(self, name=None, geometry=None, pose=None): 139 | self.name = name 140 | self.geometry = geometry 141 | self.pose = pose 142 | 143 | 144 | xmlr.reflect(Collision, tag="collision", params=[name_attribute, xmlr.Element("geometry", "geometric"), pose_element]) 145 | 146 | 147 | class Dynamics(xmlr.Object): 148 | def __init__(self, damping=None, friction=None): 149 | self.damping = damping 150 | self.friction = friction 151 | 152 | 153 | xmlr.reflect( 154 | Dynamics, tag="dynamics", params=[xmlr.Element("damping", float, False), xmlr.Element("friction", float, False)] 155 | ) 156 | 157 | 158 | class Limit(xmlr.Object): 159 | def __init__(self, lower=None, upper=None): 160 | self.lower = lower 161 | self.upper = upper 162 | 163 | 164 | xmlr.reflect(Limit, tag="limit", params=[xmlr.Element("lower", float, False), xmlr.Element("upper", float, False)]) 165 | 166 | 167 | class Axis(xmlr.Object): 168 | def __init__(self, xyz=None, limit=None, dynamics=None, use_parent_model_frame=None): 169 | self.xyz = xyz 170 | self.limit = limit 171 | self.dynamics = dynamics 172 | self.use_parent_model_frame = use_parent_model_frame 173 
| 174 | 175 | xmlr.reflect( 176 | Axis, 177 | tag="axis", 178 | params=[ 179 | xmlr.Element("xyz", "vector3"), 180 | xmlr.Element("limit", Limit, False), 181 | xmlr.Element("dynamics", Dynamics, False), 182 | xmlr.Element("use_parent_model_frame", bool, False), 183 | ], 184 | ) 185 | 186 | 187 | class Joint(xmlr.Object): 188 | TYPES = ["unknown", "revolute", "gearbox", "revolute2", "prismatic", "ball", "screw", "universal", "fixed"] 189 | 190 | def __init__(self, name=None, parent=None, child=None, joint_type=None, axis=None, pose=None): 191 | self.aggregate_init() 192 | self.name = name 193 | self.parent = parent 194 | self.child = child 195 | self.type = joint_type 196 | self.axis = axis 197 | self.pose = pose 198 | 199 | # Aliases 200 | @property 201 | def joint_type(self): 202 | return self.type 203 | 204 | @joint_type.setter 205 | def joint_type(self, value): 206 | self.type = value 207 | 208 | 209 | xmlr.reflect( 210 | Joint, 211 | tag="joint", 212 | params=[ 213 | name_attribute, 214 | xmlr.Attribute("type", str, False), 215 | xmlr.Element("axis", Axis), 216 | xmlr.Element("parent", str), 217 | xmlr.Element("child", str), 218 | pose_element, 219 | ], 220 | ) 221 | 222 | 223 | class Link(xmlr.Object): 224 | def __init__(self, name=None, pose=None, inertial=None, kinematic=False): 225 | self.aggregate_init() 226 | self.name = name 227 | self.pose = pose 228 | self.inertial = inertial 229 | self.kinematic = kinematic 230 | self.visuals = [] 231 | self.collisions = [] 232 | 233 | 234 | xmlr.reflect( 235 | Link, 236 | tag="link", 237 | params=[ 238 | name_attribute, 239 | xmlr.Element("inertial", Inertial), 240 | xmlr.Attribute("kinematic", bool, False), 241 | xmlr.AggregateElement("visual", Visual, var="visuals"), 242 | xmlr.AggregateElement("collision", Collision, var="collisions"), 243 | pose_element, 244 | ], 245 | ) 246 | 247 | 248 | class Model(xmlr.Object): 249 | def __init__(self, name=None, pose=None): 250 | self.aggregate_init() 251 | self.name = name 
252 | self.pose = pose 253 | self.links = [] 254 | self.joints = [] 255 | self.joint_map = {} 256 | self.link_map = {} 257 | 258 | self.parent_map = {} 259 | self.child_map = {} 260 | 261 | def add_aggregate(self, typeName, elem): 262 | xmlr.Object.add_aggregate(self, typeName, elem) 263 | 264 | if typeName == "joint": 265 | joint = elem 266 | self.joint_map[joint.name] = joint 267 | self.parent_map[joint.child] = (joint.name, joint.parent) 268 | if joint.parent in self.child_map: 269 | self.child_map[joint.parent].append((joint.name, joint.child)) 270 | else: 271 | self.child_map[joint.parent] = [(joint.name, joint.child)] 272 | elif typeName == "link": 273 | link = elem 274 | self.link_map[link.name] = link 275 | 276 | def add_link(self, link): 277 | self.add_aggregate("link", link) 278 | 279 | def add_joint(self, joint): 280 | self.add_aggregate("joint", joint) 281 | 282 | 283 | xmlr.reflect( 284 | Model, 285 | tag="model", 286 | params=[ 287 | name_attribute, 288 | xmlr.AggregateElement("link", Link, var="links"), 289 | xmlr.AggregateElement("joint", Joint, var="joints"), 290 | pose_element, 291 | ], 292 | ) 293 | 294 | 295 | class SDF(xmlr.Object): 296 | def __init__(self, version=None): 297 | self.version = version 298 | 299 | 300 | xmlr.reflect( 301 | SDF, 302 | tag="sdf", 303 | params=[ 304 | xmlr.Attribute("version", str, False), 305 | xmlr.Element("model", Model, False), 306 | ], 307 | ) 308 | 309 | 310 | xmlr.end_namespace() 311 | -------------------------------------------------------------------------------- /kinpy/urdf_parser_py/urdf.py: -------------------------------------------------------------------------------- 1 | from . import xml_reflection as xmlr 2 | from .xml_reflection.basics import * 3 | 4 | # Add a 'namespace' for names to avoid a conflict between URDF and SDF? 5 | # A type registry? How to scope that? Just make a 'global' type pointer? 6 | # Or just qualify names? 
urdf.geometric, sdf.geometric 7 | 8 | xmlr.start_namespace("urdf") 9 | 10 | xmlr.add_type("element_link", xmlr.SimpleElementType("link", str)) 11 | xmlr.add_type("element_xyz", xmlr.SimpleElementType("xyz", "vector3")) 12 | 13 | verbose = True 14 | 15 | 16 | class Pose(xmlr.Object): 17 | def __init__(self, xyz=None, rpy=None): 18 | self.xyz = xyz 19 | self.rpy = rpy 20 | 21 | def check_valid(self): 22 | assert (self.xyz is None or len(self.xyz) == 3) and (self.rpy is None or len(self.rpy) == 3) 23 | 24 | # Aliases for backwards compatibility 25 | @property 26 | def rotation(self): 27 | return self.rpy 28 | 29 | @rotation.setter 30 | def rotation(self, value): 31 | self.rpy = value 32 | 33 | @property 34 | def position(self): 35 | return self.xyz 36 | 37 | @position.setter 38 | def position(self, value): 39 | self.xyz = value 40 | 41 | 42 | xmlr.reflect( 43 | Pose, 44 | tag="origin", 45 | params=[ 46 | xmlr.Attribute("xyz", "vector3", False, default=[0, 0, 0]), 47 | xmlr.Attribute("rpy", "vector3", False, default=[0, 0, 0]), 48 | ], 49 | ) 50 | 51 | 52 | # Common stuff 53 | name_attribute = xmlr.Attribute("name", str) 54 | origin_element = xmlr.Element("origin", Pose, False) 55 | 56 | 57 | class Color(xmlr.Object): 58 | def __init__(self, *args): 59 | # What about named colors? 
60 | count = len(args) 61 | if count == 4 or count == 3: 62 | self.rgba = args 63 | elif count == 1: 64 | self.rgba = args[0] 65 | elif count == 0: 66 | self.rgba = None 67 | if self.rgba is not None: 68 | if len(self.rgba) == 3: 69 | self.rgba += [1.0] 70 | if len(self.rgba) != 4: 71 | raise Exception("Invalid color argument count") 72 | 73 | 74 | xmlr.reflect(Color, tag="color", params=[xmlr.Attribute("rgba", "vector4")]) 75 | 76 | 77 | class JointDynamics(xmlr.Object): 78 | def __init__(self, damping=None, friction=None): 79 | self.damping = damping 80 | self.friction = friction 81 | 82 | 83 | xmlr.reflect( 84 | JointDynamics, 85 | tag="dynamics", 86 | params=[xmlr.Attribute("damping", float, False), xmlr.Attribute("friction", float, False)], 87 | ) 88 | 89 | 90 | class Box(xmlr.Object): 91 | def __init__(self, size=None): 92 | self.size = size 93 | 94 | 95 | xmlr.reflect(Box, tag="box", params=[xmlr.Attribute("size", "vector3")]) 96 | 97 | 98 | class Cylinder(xmlr.Object): 99 | def __init__(self, radius=0.0, length=0.0): 100 | self.radius = radius 101 | self.length = length 102 | 103 | 104 | xmlr.reflect(Cylinder, tag="cylinder", params=[xmlr.Attribute("radius", float), xmlr.Attribute("length", float)]) 105 | 106 | 107 | class Sphere(xmlr.Object): 108 | def __init__(self, radius=0.0): 109 | self.radius = radius 110 | 111 | 112 | xmlr.reflect(Sphere, tag="sphere", params=[xmlr.Attribute("radius", float)]) 113 | 114 | 115 | class Mesh(xmlr.Object): 116 | def __init__(self, filename=None, scale=None): 117 | self.filename = filename 118 | self.scale = scale 119 | 120 | 121 | xmlr.reflect( 122 | Mesh, tag="mesh", params=[xmlr.Attribute("filename", str), xmlr.Attribute("scale", "vector3", required=False)] 123 | ) 124 | 125 | 126 | class GeometricType(xmlr.ValueType): 127 | def __init__(self): 128 | self.factory = xmlr.FactoryType( 129 | "geometric", {"box": Box, "cylinder": Cylinder, "sphere": Sphere, "mesh": Mesh} 130 | ) 131 | 132 | def from_xml(self, node, path): 
133 | children = xml_children(node) 134 | assert len(children) == 1, "One element only for geometric" 135 | return self.factory.from_xml(children[0], path=path) 136 | 137 | def write_xml(self, node, obj): 138 | name = self.factory.get_name(obj) 139 | child = node_add(node, name) 140 | obj.write_xml(child) 141 | 142 | 143 | xmlr.add_type("geometric", GeometricType()) 144 | 145 | 146 | class Collision(xmlr.Object): 147 | def __init__(self, geometry=None, origin=None): 148 | self.geometry = geometry 149 | self.origin = origin 150 | 151 | 152 | xmlr.reflect(Collision, tag="collision", params=[origin_element, xmlr.Element("geometry", "geometric")]) 153 | 154 | 155 | class Texture(xmlr.Object): 156 | def __init__(self, filename=None): 157 | self.filename = filename 158 | 159 | 160 | xmlr.reflect(Texture, tag="texture", params=[xmlr.Attribute("filename", str)]) 161 | 162 | 163 | class Material(xmlr.Object): 164 | def __init__(self, name=None, color=None, texture=None): 165 | self.name = name 166 | self.color = color 167 | self.texture = texture 168 | 169 | def check_valid(self): 170 | if self.color is None and self.texture is None: 171 | xmlr.on_error("Material has neither a color nor texture.") 172 | 173 | 174 | xmlr.reflect( 175 | Material, 176 | tag="material", 177 | params=[name_attribute, xmlr.Element("color", Color, False), xmlr.Element("texture", Texture, False)], 178 | ) 179 | 180 | 181 | class LinkMaterial(Material): 182 | def check_valid(self): 183 | pass 184 | 185 | 186 | class Visual(xmlr.Object): 187 | def __init__(self, geometry=None, material=None, origin=None): 188 | self.geometry = geometry 189 | self.material = material 190 | self.origin = origin 191 | 192 | 193 | xmlr.reflect( 194 | Visual, 195 | tag="visual", 196 | params=[origin_element, xmlr.Element("geometry", "geometric"), xmlr.Element("material", LinkMaterial, False)], 197 | ) 198 | 199 | 200 | class Inertia(xmlr.Object): 201 | KEYS = ["ixx", "ixy", "ixz", "iyy", "iyz", "izz"] 202 | 203 | def 
__init__(self, ixx=0.0, ixy=0.0, ixz=0.0, iyy=0.0, iyz=0.0, izz=0.0): 204 | self.ixx = ixx 205 | self.ixy = ixy 206 | self.ixz = ixz 207 | self.iyy = iyy 208 | self.iyz = iyz 209 | self.izz = izz 210 | 211 | def to_matrix(self): 212 | return [[self.ixx, self.ixy, self.ixz], [self.ixy, self.iyy, self.iyz], [self.ixz, self.iyz, self.izz]] 213 | 214 | 215 | xmlr.reflect(Inertia, tag="inertia", params=[xmlr.Attribute(key, float) for key in Inertia.KEYS]) 216 | 217 | 218 | class Inertial(xmlr.Object): 219 | def __init__(self, mass=0.0, inertia=None, origin=None): 220 | self.mass = mass 221 | self.inertia = inertia 222 | self.origin = origin 223 | 224 | 225 | xmlr.reflect( 226 | Inertial, 227 | tag="inertial", 228 | params=[origin_element, xmlr.Element("mass", "element_value"), xmlr.Element("inertia", Inertia, False)], 229 | ) 230 | 231 | 232 | # FIXME: we are missing the reference position here. 233 | class JointCalibration(xmlr.Object): 234 | def __init__(self, rising=None, falling=None): 235 | self.rising = rising 236 | self.falling = falling 237 | 238 | 239 | xmlr.reflect( 240 | JointCalibration, 241 | tag="calibration", 242 | params=[xmlr.Attribute("rising", float, False, 0), xmlr.Attribute("falling", float, False, 0)], 243 | ) 244 | 245 | 246 | class JointLimit(xmlr.Object): 247 | def __init__(self, effort=None, velocity=None, lower=None, upper=None): 248 | self.effort = effort 249 | self.velocity = velocity 250 | self.lower = lower 251 | self.upper = upper 252 | 253 | 254 | xmlr.reflect( 255 | JointLimit, 256 | tag="limit", 257 | params=[ 258 | xmlr.Attribute("effort", float), 259 | xmlr.Attribute("lower", float, False, 0), 260 | xmlr.Attribute("upper", float, False, 0), 261 | xmlr.Attribute("velocity", float), 262 | ], 263 | ) 264 | 265 | # FIXME: we are missing __str__ here. 
266 | 267 | 268 | class JointMimic(xmlr.Object): 269 | def __init__(self, joint_name=None, multiplier=None, offset=None): 270 | self.joint = joint_name 271 | self.multiplier = multiplier 272 | self.offset = offset 273 | 274 | 275 | xmlr.reflect( 276 | JointMimic, 277 | tag="mimic", 278 | params=[ 279 | xmlr.Attribute("joint", str), 280 | xmlr.Attribute("multiplier", float, False), 281 | xmlr.Attribute("offset", float, False), 282 | ], 283 | ) 284 | 285 | 286 | class SafetyController(xmlr.Object): 287 | def __init__(self, velocity=None, position=None, lower=None, upper=None): 288 | self.k_velocity = velocity 289 | self.k_position = position 290 | self.soft_lower_limit = lower 291 | self.soft_upper_limit = upper 292 | 293 | 294 | xmlr.reflect( 295 | SafetyController, 296 | tag="safety_controller", 297 | params=[ 298 | xmlr.Attribute("k_velocity", float), 299 | xmlr.Attribute("k_position", float, False, 0), 300 | xmlr.Attribute("soft_lower_limit", float, False, 0), 301 | xmlr.Attribute("soft_upper_limit", float, False, 0), 302 | ], 303 | ) 304 | 305 | 306 | class Joint(xmlr.Object): 307 | TYPES = ["unknown", "revolute", "continuous", "prismatic", "floating", "planar", "fixed"] 308 | 309 | def __init__( 310 | self, 311 | name=None, 312 | parent=None, 313 | child=None, 314 | joint_type=None, 315 | axis=None, 316 | origin=None, 317 | limit=None, 318 | dynamics=None, 319 | safety_controller=None, 320 | calibration=None, 321 | mimic=None, 322 | ): 323 | self.name = name 324 | self.parent = parent 325 | self.child = child 326 | self.type = joint_type 327 | self.axis = axis 328 | self.origin = origin 329 | self.limit = limit 330 | self.dynamics = dynamics 331 | self.safety_controller = safety_controller 332 | self.calibration = calibration 333 | self.mimic = mimic 334 | 335 | def check_valid(self): 336 | assert self.type in self.TYPES, "Invalid joint type: {}".format(self.type) # noqa 337 | 338 | # Aliases 339 | @property 340 | def joint_type(self): 341 | return self.type 
342 | 343 | @joint_type.setter 344 | def joint_type(self, value): 345 | self.type = value 346 | 347 | 348 | xmlr.reflect( 349 | Joint, 350 | tag="joint", 351 | params=[ 352 | name_attribute, 353 | xmlr.Attribute("type", str), 354 | origin_element, 355 | xmlr.Element("axis", "element_xyz", False), 356 | xmlr.Element("parent", "element_link"), 357 | xmlr.Element("child", "element_link"), 358 | xmlr.Element("limit", JointLimit, False), 359 | xmlr.Element("dynamics", JointDynamics, False), 360 | xmlr.Element("safety_controller", SafetyController, False), 361 | xmlr.Element("calibration", JointCalibration, False), 362 | xmlr.Element("mimic", JointMimic, False), 363 | ], 364 | ) 365 | 366 | 367 | class Link(xmlr.Object): 368 | def __init__(self, name=None, visual=None, inertial=None, collision=None, origin=None): 369 | self.aggregate_init() 370 | self.name = name 371 | self.visuals = [] 372 | self.inertial = inertial 373 | self.collisions = [] 374 | self.origin = origin 375 | 376 | def __get_visual(self): 377 | """Return the first visual or None.""" 378 | if self.visuals: 379 | return self.visuals[0] 380 | 381 | def __set_visual(self, visual): 382 | """Set the first visual.""" 383 | if self.visuals: 384 | self.visuals[0] = visual 385 | else: 386 | self.visuals.append(visual) 387 | 388 | def __get_collision(self): 389 | """Return the first collision or None.""" 390 | if self.collisions: 391 | return self.collisions[0] 392 | 393 | def __set_collision(self, collision): 394 | """Set the first collision.""" 395 | if self.collisions: 396 | self.collisions[0] = collision 397 | else: 398 | self.collisions.append(collision) 399 | 400 | # Properties for backwards compatibility 401 | visual = property(__get_visual, __set_visual) 402 | collision = property(__get_collision, __set_collision) 403 | 404 | 405 | xmlr.reflect( 406 | Link, 407 | tag="link", 408 | params=[ 409 | name_attribute, 410 | origin_element, 411 | xmlr.AggregateElement("visual", Visual), 412 | 
xmlr.AggregateElement("collision", Collision), 413 | xmlr.Element("inertial", Inertial, False), 414 | ], 415 | ) 416 | 417 | 418 | class PR2Transmission(xmlr.Object): 419 | def __init__(self, name=None, joint=None, actuator=None, type=None, mechanicalReduction=1): 420 | self.name = name 421 | self.type = type 422 | self.joint = joint 423 | self.actuator = actuator 424 | self.mechanicalReduction = mechanicalReduction 425 | 426 | 427 | xmlr.reflect( 428 | PR2Transmission, 429 | tag="pr2_transmission", 430 | params=[ 431 | name_attribute, 432 | xmlr.Attribute("type", str), 433 | xmlr.Element("joint", "element_name"), 434 | xmlr.Element("actuator", "element_name"), 435 | xmlr.Element("mechanicalReduction", float), 436 | ], 437 | ) 438 | 439 | 440 | class Actuator(xmlr.Object): 441 | def __init__(self, name=None, mechanicalReduction=1): 442 | self.name = name 443 | self.mechanicalReduction = None 444 | 445 | 446 | xmlr.reflect( 447 | Actuator, tag="actuator", params=[name_attribute, xmlr.Element("mechanicalReduction", float, required=False)] 448 | ) 449 | 450 | 451 | class TransmissionJoint(xmlr.Object): 452 | def __init__(self, name=None): 453 | self.aggregate_init() 454 | self.name = name 455 | self.hardwareInterfaces = [] 456 | 457 | def check_valid(self): 458 | assert len(self.hardwareInterfaces) > 0, "no hardwareInterface defined" 459 | 460 | 461 | xmlr.reflect( 462 | TransmissionJoint, 463 | tag="joint", 464 | params=[ 465 | name_attribute, 466 | xmlr.AggregateElement("hardwareInterface", str), 467 | ], 468 | ) 469 | 470 | 471 | class Transmission(xmlr.Object): 472 | """New format: http://wiki.ros.org/urdf/XML/Transmission""" 473 | 474 | def __init__(self, name=None): 475 | self.aggregate_init() 476 | self.name = name 477 | self.joints = [] 478 | self.actuators = [] 479 | 480 | def check_valid(self): 481 | assert len(self.joints) > 0, "no joint defined" 482 | assert len(self.actuators) > 0, "no actuator defined" 483 | 484 | 485 | xmlr.reflect( 486 | Transmission, 
487 | tag="new_transmission", 488 | params=[ 489 | name_attribute, 490 | xmlr.Element("type", str), 491 | xmlr.AggregateElement("joint", TransmissionJoint), 492 | xmlr.AggregateElement("actuator", Actuator), 493 | ], 494 | ) 495 | 496 | xmlr.add_type("transmission", xmlr.DuckTypedFactory("transmission", [Transmission, PR2Transmission])) 497 | 498 | 499 | class Robot(xmlr.Object): 500 | def __init__(self, name=None): 501 | self.aggregate_init() 502 | 503 | self.name = name 504 | self.joints = [] 505 | self.links = [] 506 | self.materials = [] 507 | self.gazebos = [] 508 | self.transmissions = [] 509 | 510 | self.joint_map = {} 511 | self.link_map = {} 512 | 513 | self.parent_map = {} 514 | self.child_map = {} 515 | 516 | def add_aggregate(self, typeName, elem): 517 | xmlr.Object.add_aggregate(self, typeName, elem) 518 | 519 | if typeName == "joint": 520 | joint = elem 521 | self.joint_map[joint.name] = joint 522 | self.parent_map[joint.child] = (joint.name, joint.parent) 523 | if joint.parent in self.child_map: 524 | self.child_map[joint.parent].append((joint.name, joint.child)) 525 | else: 526 | self.child_map[joint.parent] = [(joint.name, joint.child)] 527 | elif typeName == "link": 528 | link = elem 529 | self.link_map[link.name] = link 530 | 531 | def add_link(self, link): 532 | self.add_aggregate("link", link) 533 | 534 | def add_joint(self, joint): 535 | self.add_aggregate("joint", joint) 536 | 537 | def get_chain(self, root, tip, joints=True, links=True, fixed=True): 538 | chain = [] 539 | if links: 540 | chain.append(tip) 541 | link = tip 542 | while link != root: 543 | (joint, parent) = self.parent_map[link] 544 | if joints: 545 | if fixed or self.joint_map[joint].joint_type != "fixed": 546 | chain.append(joint) 547 | if links: 548 | chain.append(parent) 549 | link = parent 550 | chain.reverse() 551 | return chain 552 | 553 | def get_root(self): 554 | root = None 555 | for link in self.link_map: 556 | if link not in self.parent_map: 557 | assert root is 
None, "Multiple roots detected, invalid URDF." 558 | root = link 559 | assert root is not None, "No roots detected, invalid URDF." 560 | return root 561 | 562 | @classmethod 563 | def from_parameter_server(cls, key="robot_description"): 564 | """ 565 | Retrieve the robot model on the parameter server 566 | and parse it to create a URDF robot structure. 567 | 568 | Warning: this requires roscore to be running. 569 | """ 570 | # Could move this into xml_reflection 571 | import rospy 572 | 573 | return cls.from_xml_string(rospy.get_param(key)) 574 | 575 | 576 | xmlr.reflect( 577 | Robot, 578 | tag="robot", 579 | params=[ 580 | xmlr.Attribute("name", str, False), # Is 'name' a required attribute? 581 | xmlr.AggregateElement("link", Link), 582 | xmlr.AggregateElement("joint", Joint), 583 | xmlr.AggregateElement("gazebo", xmlr.RawType()), 584 | xmlr.AggregateElement("transmission", "transmission"), 585 | xmlr.AggregateElement("material", Material), 586 | ], 587 | ) 588 | 589 | # Make an alias 590 | URDF = Robot 591 | 592 | xmlr.end_namespace() 593 | -------------------------------------------------------------------------------- /kinpy/urdf_parser_py/xml_reflection/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import * 2 | -------------------------------------------------------------------------------- /kinpy/urdf_parser_py/xml_reflection/basics.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import string 3 | 4 | import yaml 5 | from lxml import etree 6 | 7 | 8 | def xml_string(rootXml, addHeader=True): 9 | # Meh 10 | xmlString = etree.tostring(rootXml, pretty_print=True, encoding="unicode") 11 | if addHeader: 12 | xmlString = '\n' + xmlString 13 | return xmlString 14 | 15 | 16 | def dict_sub(obj, keys): 17 | return dict((key, obj[key]) for key in keys) 18 | 19 | 20 | def node_add(doc, sub): 21 | if sub is None: 22 | return None 23 | if 
type(sub) == str: 24 | return etree.SubElement(doc, sub) 25 | elif isinstance(sub, etree._Element): 26 | doc.append(sub) # This screws up the rest of the tree for prettyprint 27 | return sub 28 | else: 29 | raise Exception("Invalid sub value") 30 | 31 | 32 | def pfloat(x): 33 | return str(x).rstrip(".") 34 | 35 | 36 | def xml_children(node): 37 | children = node.getchildren() 38 | 39 | def predicate(node): 40 | return not isinstance(node, etree._Comment) 41 | 42 | return list(filter(predicate, children)) 43 | 44 | 45 | def isstring(obj): 46 | try: 47 | return isinstance(obj, basestring) 48 | except NameError: 49 | return isinstance(obj, str) 50 | 51 | 52 | def to_yaml(obj): 53 | """Simplify yaml representation for pretty printing""" 54 | # Is there a better way to do this by adding a representation with 55 | # yaml.Dumper? 56 | # Ordered dict: http://pyyaml.org/ticket/29#comment:11 57 | if obj is None or isstring(obj): 58 | out = str(obj) 59 | elif type(obj) in [int, float, bool]: 60 | return obj 61 | elif hasattr(obj, "to_yaml"): 62 | out = obj.to_yaml() 63 | elif isinstance(obj, etree._Element): 64 | out = etree.tostring(obj, pretty_print=True) 65 | elif type(obj) == dict: 66 | out = {} 67 | for (var, value) in obj.items(): 68 | out[str(var)] = to_yaml(value) 69 | elif hasattr(obj, "tolist"): 70 | # For numpy objects 71 | out = to_yaml(obj.tolist()) 72 | elif isinstance(obj, collections.Iterable): 73 | out = [to_yaml(item) for item in obj] 74 | else: 75 | out = str(obj) 76 | return out 77 | 78 | 79 | class SelectiveReflection(object): 80 | def get_refl_vars(self): 81 | return list(vars(self).keys()) 82 | 83 | 84 | class YamlReflection(SelectiveReflection): 85 | def to_yaml(self): 86 | raw = dict((var, getattr(self, var)) for var in self.get_refl_vars()) 87 | return to_yaml(raw) 88 | 89 | def __str__(self): 90 | # Good idea? Will it remove other important things? 
91 | return yaml.dump(self.to_yaml()).rstrip() 92 | -------------------------------------------------------------------------------- /kinpy/visualizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import defaultdict 3 | from functools import partial 4 | from typing import Dict, List, Optional, Tuple, Union 5 | 6 | import numpy as np 7 | import transformations as trf 8 | import vtk 9 | from vtk.util.colors import tomato 10 | 11 | from . import transform 12 | from .chain import Chain 13 | from .frame import Visual 14 | 15 | 16 | class Visualizer: 17 | def __init__(self, win_size: Tuple[int, int] = (640, 480)) -> None: 18 | self._actors: Dict[str, list] = defaultdict(list) 19 | self._axes: Dict[str, vtk.vtkAxesActor] = {} 20 | self._ren = vtk.vtkRenderer() 21 | self._ren.SetBackground(0.1, 0.2, 0.4) 22 | self._win = vtk.vtkRenderWindow() 23 | self._win.SetSize(*win_size) 24 | self._win.AddRenderer(self._ren) 25 | self._inter = vtk.vtkRenderWindowInteractor() 26 | self._inter.SetRenderWindow(self._win) 27 | 28 | def add_robot( 29 | self, 30 | transformations: Dict[str, transform.Transform], 31 | visuals_map: Dict[str, List[Visual]], 32 | mesh_file_path: str = "./", 33 | axes: bool = False, 34 | ) -> None: 35 | for k, trans in transformations.items(): 36 | if axes: 37 | self.add_axes(trans, geom_name=k) 38 | for v in visuals_map[k]: 39 | tf = trans * v.offset 40 | if v.geom_type == "mesh": 41 | self.add_mesh(os.path.join(mesh_file_path, v.geom_param), tf, geom_name=k) 42 | elif v.geom_type == "cylinder": 43 | self.add_cylinder(v.geom_param[0], v.geom_param[1], tf, geom_name=k) 44 | elif v.geom_type == "box": 45 | self.add_box(v.geom_param, tf, geom_name=k) 46 | elif v.geom_type == "sphere": 47 | self.add_sphere(v.geom_param, tf, geom_name=k) 48 | elif v.geom_type == "capsule": 49 | self.add_capsule(v.geom_param[0], v.geom_param[1], tf, geom_name=k) 50 | 51 | def add_shape_source( 52 | self, 
source: vtk.vtkAbstractPolyDataReader, transform: transform.Transform, geom_name: Optional[str] = None 53 | ) -> None: 54 | mapper = vtk.vtkPolyDataMapper() 55 | mapper.SetInputConnection(source.GetOutputPort()) 56 | actor = vtk.vtkActor() 57 | actor.SetMapper(mapper) 58 | actor.GetProperty().SetColor(tomato) 59 | actor.SetPosition(transform.pos) 60 | rpy = np.rad2deg(trf.euler_from_quaternion(transform.rot, "rxyz")) 61 | actor.RotateX(rpy[0]) 62 | actor.RotateY(rpy[1]) 63 | actor.RotateZ(rpy[2]) 64 | if geom_name: 65 | self._actors[geom_name].append(actor) 66 | self._ren.AddActor(actor) 67 | 68 | def add_axes(self, trans: transform.Transform, geom_name: Optional[str] = None) -> None: 69 | transform = vtk.vtkTransform() 70 | transform.Translate(trans.pos) 71 | rpy = np.rad2deg(trf.euler_from_quaternion(trans.rot, "rxyz")) 72 | transform.RotateX(rpy[0]) 73 | transform.RotateY(rpy[1]) 74 | transform.RotateZ(rpy[2]) 75 | axes = vtk.vtkAxesActor() 76 | axes.SetTotalLength(0.1, 0.1, 0.1) 77 | axes.AxisLabelsOff() 78 | axes.SetUserTransform(transform) 79 | if geom_name: 80 | self._axes[geom_name] = axes 81 | self._ren.AddActor(axes) 82 | 83 | def load_obj(self, filename: str) -> vtk.vtkOBJReader: 84 | reader = vtk.vtkOBJReader() 85 | reader.SetFileName(filename) 86 | return reader 87 | 88 | def load_ply(self, filename: str) -> vtk.vtkPLYReader: 89 | reader = vtk.vtkPLYReader() 90 | reader.SetFileName(filename) 91 | return reader 92 | 93 | def load_stl(self, filename: str) -> vtk.vtkSTLReader: 94 | reader = vtk.vtkSTLReader() 95 | reader.SetFileName(filename) 96 | return reader 97 | 98 | def add_cylinder( 99 | self, radius: float, height: float, tf: Optional[transform.Transform] = None, geom_name: Optional[str] = None 100 | ) -> None: 101 | tf = tf or transform.Transform() 102 | cylinder = vtk.vtkCylinderSource() 103 | cylinder.SetResolution(20) 104 | cylinder.SetRadius(radius) 105 | cylinder.SetHeight(height) 106 | self.add_shape_source(cylinder, tf, geom_name) 107 | 108 
| def add_box( 109 | self, size: List[float], tf: Optional[transform.Transform] = None, geom_name: Optional[str] = None 110 | ) -> None: 111 | tf = tf or transform.Transform() 112 | cube = vtk.vtkCubeSource() 113 | cube.SetXLength(size[0]) 114 | cube.SetYLength(size[1]) 115 | cube.SetZLength(size[2]) 116 | self.add_shape_source(cube, tf, geom_name) 117 | 118 | def add_sphere( 119 | self, radius: float, tf: Optional[transform.Transform] = None, geom_name: Optional[str] = None 120 | ) -> None: 121 | tf = tf or transform.Transform() 122 | sphere = vtk.vtkSphereSource() 123 | sphere.SetRadius(radius) 124 | self.add_shape_source(sphere, tf, geom_name) 125 | 126 | def add_capsule( 127 | self, 128 | radius: float, 129 | fromto: np.ndarray, 130 | tf: Optional[transform.Transform] = None, 131 | step: float = 0.05, 132 | geom_name: Optional[str] = None, 133 | ) -> None: 134 | tf = tf or transform.Transform() 135 | spheres = vtk.vtkAppendPolyData() 136 | d_norm = np.linalg.norm(fromto[3:] - fromto[:3]) 137 | direction = (fromto[3:] - fromto[:3]) / d_norm 138 | offset = transform.Transform( 139 | rot=trf.quaternion_about_axis( 140 | np.arccos(np.dot(tf.rot_mat[:, 2], direction)), np.cross(tf.rot_mat[:, 2], direction) 141 | ), 142 | pos=fromto[:3], 143 | ) 144 | for t in np.arange(0.0, 1.0, step): 145 | sphere = vtk.vtkSphereSource() 146 | sphere.SetRadius(radius) 147 | sphere.SetCenter(0, 0, t * d_norm) 148 | spheres.AddInputConnection(sphere.GetOutputPort()) 149 | self.add_shape_source(spheres, tf * offset, geom_name) 150 | 151 | def add_mesh( 152 | self, filename: str, tf: Optional[transform.Transform] = None, geom_name: Optional[str] = None 153 | ) -> None: 154 | tf = tf or transform.Transform() 155 | _, ext = os.path.splitext(filename) 156 | ext = ext.lower() 157 | if ext == ".stl": 158 | reader = self.load_stl(filename) 159 | elif ext == ".obj": 160 | reader = self.load_obj(filename) 161 | elif ext == ".ply": 162 | reader = self.load_ply(filename) 163 | else: 164 | raise 
ValueError("Unsupported file extension, '%s'." % ext) 165 | self.add_shape_source(reader, tf, geom_name) 166 | 167 | def spin(self) -> None: 168 | self._win.Render() 169 | self._inter.Initialize() 170 | self._inter.Start() 171 | 172 | 173 | class JointAngleEditor(Visualizer): 174 | def __init__( 175 | self, 176 | chain: Chain, 177 | mesh_file_path: str = "./", 178 | axes: bool = False, 179 | initial_state: Optional[Union[Dict[str, float], List[float]]] = None, 180 | ) -> None: 181 | super().__init__() 182 | if initial_state is None: 183 | initial_state = {} 184 | self._chain = chain 185 | if isinstance(initial_state, (list, np.ndarray)): 186 | initial_state = {k: v for k, v in zip(self._chain.get_joint_parameter_names(), initial_state)} 187 | self._joint_angles: Dict[str, float] = initial_state 188 | self._visuals_map = self._chain.visuals_map() 189 | self.add_robot( 190 | self._chain.forward_kinematics(self._joint_angles, end_only=False), # type: ignore 191 | self._visuals_map, 192 | mesh_file_path, 193 | axes, 194 | ) 195 | self._sliders = self._set_joint_slider(chain) 196 | 197 | def _update_joint_angle(self, obj: vtk.vtkSliderWidget, event: str, joint_name: str) -> None: 198 | slider_rep = obj.GetRepresentation() 199 | self._joint_angles[joint_name] = np.deg2rad(slider_rep.GetValue()) 200 | positions = self._chain.forward_kinematics(self._joint_angles, end_only=False) # type: ignore 201 | for k, position in positions.items(): 202 | if k in self._axes: 203 | transform = vtk.vtkTransform() 204 | transform.Translate(position.pos) 205 | rpy = np.rad2deg(trf.euler_from_quaternion(position.rot, "rxyz")) 206 | transform.RotateX(rpy[0]) 207 | transform.RotateY(rpy[1]) 208 | transform.RotateZ(rpy[2]) 209 | self._axes[k].SetUserTransform(transform) 210 | self._axes[k].Modified() 211 | actors = self._actors[k] 212 | for i, actor in enumerate(actors): 213 | actor.SetOrientation(0, 0, 0) 214 | trans = position * self._visuals_map[k][i].offset 215 | 
actor.SetPosition(trans.pos) 216 | rpy = np.rad2deg(trf.euler_from_quaternion(trans.rot, "rxyz")) 217 | actor.RotateX(rpy[0]) 218 | actor.RotateY(rpy[1]) 219 | actor.RotateZ(rpy[2]) 220 | actor.Modified() 221 | self._win.Render() 222 | 223 | def _set_joint_slider(self, chain: Chain) -> List[vtk.vtkSliderWidget]: 224 | sliders = [] 225 | for i, frame in enumerate(chain): 226 | if frame.joint.joint_type == "fixed": 227 | continue 228 | slider_rep = vtk.vtkSliderRepresentation2D() 229 | slider_rep.SetMinimumValue(-180) 230 | slider_rep.SetMaximumValue(180) 231 | slider_rep.SetValue(np.rad2deg(self._joint_angles.get(frame.joint.name, 0.0))) 232 | slider_rep.SetSliderLength(0.05) 233 | slider_rep.SetSliderWidth(0.02) 234 | slider_rep.SetEndCapLength(0.01) 235 | slider_rep.SetEndCapWidth(0.02) 236 | slider_rep.GetPoint1Coordinate().SetCoordinateSystemToNormalizedDisplay() 237 | slider_rep.GetPoint1Coordinate().SetValue(0.75, 1.0 - 0.05 * i) 238 | slider_rep.GetPoint2Coordinate().SetCoordinateSystemToNormalizedDisplay() 239 | slider_rep.GetPoint2Coordinate().SetValue(0.98, 1.0 - 0.05 * i) 240 | 241 | # Set the background color of the slider representation 242 | slider_rep.GetSliderProperty().SetColor(0.2, 0.2, 0.2) 243 | slider_rep.GetTubeProperty().SetColor(0.7, 0.7, 0.7) 244 | slider_rep.GetCapProperty().SetColor(0.2, 0.2, 0.2) 245 | 246 | slider_widget = vtk.vtkSliderWidget() 247 | slider_widget.SetInteractor(self._inter) 248 | slider_widget.SetRepresentation(slider_rep) 249 | slider_widget.AddObserver( 250 | vtk.vtkCommand.InteractionEvent, 251 | partial(self._update_joint_angle, joint_name=frame.joint.name), 252 | ) 253 | sliders.append(slider_widget) 254 | return sliders 255 | 256 | def spin(self) -> None: 257 | self._win.Render() 258 | self._inter.Initialize() 259 | for slider in self._sliders: 260 | slider.SetEnabled(True) 261 | slider.On() 262 | self._inter.Start() 263 | -------------------------------------------------------------------------------- 
/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "kinpy" 3 | version = "0.4.2" 4 | description = "" 5 | authors = ["neka-nat "] 6 | license = "MIT" 7 | 8 | [tool.poetry.dependencies] 9 | python = ">=3.9,<3.12" 10 | numpy = [ 11 | {version = ">=1.24.3,<=1.26.0", python = "3.9"}, 12 | {version = "^1.24.3", python = ">=3.10,<3.12"}, 13 | ] 14 | scipy = "^1.11.2" 15 | transformations = "^2020.1.1" 16 | absl-py = "^0.11.0" 17 | lxml = "^4.6.2" 18 | PyYAML = "^6.0.1" 19 | vtk = "^9.0.1" 20 | 21 | [tool.poetry.group.dev.dependencies] 22 | twine = "^3.3.0" 23 | isort = "^5.9.3" 24 | black = "^22.6.0" 25 | flake8 = "^5.0.4" 26 | flake8-bugbear = "^22.8.23" 27 | flake8-simplify = "^0.19.3" 28 | mypy = "^1.0.1" 29 | 30 | [tool.mypy] 31 | python_version = "3.9" 32 | ignore_missing_imports = true 33 | 34 | [build-system] 35 | requires = ["poetry-core>=1.0.0", "setuptools"] 36 | build-backend = "poetry.core.masonry.api" 37 | -------------------------------------------------------------------------------- /scripts/kpviewer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from typing import Any 3 | 4 | import kinpy as kp 5 | 6 | 7 | def main(args: Any) -> None: 8 | chain = kp.build_chain_from_file(args.filename) 9 | print(chain) 10 | if args.show3d: 11 | param_names = chain.get_joint_parameter_names() 12 | th = {name: 0.0 for name in param_names} 13 | ret = chain.forward_kinematics(th) 14 | viz = kp.Visualizer() 15 | viz.add_robot(ret, chain.visuals_map()) 16 | viz.spin() 17 | 18 | 19 | if __name__ == "__main__": 20 | import argparse 21 | parser = argparse.ArgumentParser(description="Robot model viewer.") 22 | parser.add_argument("filename", type=str, help="Robot model filename.") 23 | parser.add_argument("--show3d", action="store_true", help="Show 3D model on GUI window.") 24 | args = parser.parse_args() 25 | main(args) 26 | 
-------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | __version__ = '0.4.2' 4 | 5 | 6 | setup( 7 | name='kinpy', 8 | version=__version__, 9 | author='neka-nat', 10 | author_email='nekanat.stock@gmail.com', 11 | description='Robotics kinematic calculation toolkit', 12 | license='MIT', 13 | keywords='robot kinematics', 14 | url='http://github.com/neka-nat/kinpy', 15 | packages=find_packages(exclude=["tests"]), #['kinpy'], 16 | include_package_data = True, 17 | package_data = {'': ['kinpy/mjcf_parser/schema.xml']}, 18 | long_description=open('README.md').read(), 19 | long_description_content_type='text/markdown', 20 | install_requires=['numpy', 'scipy', 'absl-py', 'pyyaml', 21 | 'lxml', 'transformations', 'vtk'], 22 | ) 23 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_fkik.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | import kinpy as kp 4 | 5 | 6 | class TestFkIk(unittest.TestCase): 7 | def test_fkik(self): 8 | data = ''\ 9 | ''\ 10 | ''\ 11 | ''\ 12 | ''\ 13 | ''\ 14 | ''\ 15 | ''\ 16 | ''\ 17 | ''\ 18 | ''\ 19 | ''\ 20 | ''\ 21 | ''\ 22 | '' 23 | chain = kp.build_serial_chain_from_urdf(data, 'link3') 24 | th1 = np.random.rand(2) 25 | tg = chain.forward_kinematics(th1) 26 | th2 = chain.inverse_kinematics(tg) 27 | self.assertTrue(np.allclose(th1, th2, atol=1.0e-6)) 28 | 29 | 30 | if __name__ == "__main__": 31 | unittest.main() 
-------------------------------------------------------------------------------- /tests/test_jacobian.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | import kinpy as kp 4 | 5 | 6 | class TestJacobian(unittest.TestCase): 7 | def test_jacobian1(self): 8 | data = ''\ 9 | ''\ 10 | ''\ 11 | ''\ 12 | ''\ 13 | ''\ 14 | ''\ 15 | ''\ 16 | ''\ 17 | ''\ 18 | ''\ 19 | ''\ 20 | ''\ 21 | ''\ 22 | '' 23 | chain = kp.build_serial_chain_from_urdf(data, 'link3') 24 | jc = chain.jacobian([0.0, 0.0]) 25 | np.testing.assert_equal(np.array([[0.0, 0.0], 26 | [1.0, 0.0], 27 | [0.0, 0.0], 28 | [0.0, 0.0], 29 | [0.0, 0.0], 30 | [1.0, 1.0]]), jc) 31 | 32 | def test_jacobian2(self): 33 | data = ''\ 34 | ''\ 35 | ''\ 36 | ''\ 37 | ''\ 38 | ''\ 39 | ''\ 40 | ''\ 41 | ''\ 42 | ''\ 43 | ''\ 44 | ''\ 45 | ''\ 46 | ''\ 47 | '' 48 | chain = kp.build_serial_chain_from_urdf(data, 'link3') 49 | jc = chain.jacobian([0.0, 0.0]) 50 | np.testing.assert_equal(np.array([[0.0, 0.0], 51 | [1.0, 0.0], 52 | [0.0, 1.0], 53 | [0.0, 0.0], 54 | [0.0, 0.0], 55 | [1.0, 0.0]]), jc) 56 | 57 | def test_jacobian3(self): 58 | chain = kp.build_serial_chain_from_urdf(open("examples/kuka_iiwa/model.urdf").read(), "lbr_iiwa_link_7") 59 | th = [0.0, -np.pi / 4.0, 0.0, np.pi / 2.0, 0.0, np.pi / 4.0, 0.0] 60 | jc = chain.jacobian(th) 61 | np.testing.assert_almost_equal(np.array([[0, 1.41421356e-02, 0, 2.82842712e-01, 0, 0, 0], 62 | [-6.60827561e-01, 0, -4.57275649e-01, 0, 5.72756493e-02, 0, 0], 63 | [0, 6.60827561e-01, 0, -3.63842712e-01, 0, 8.10000000e-02, 0], 64 | [0, 0, -7.07106781e-01, 0, -7.07106781e-01, 0, -1], 65 | [0, 1, 0, -1, 0, 1, 0], 66 | [1, 0, 7.07106781e-01, 0, -7.07106781e-01, 0, 0]]), jc) 67 | 68 | if __name__ == "__main__": 69 | unittest.main() -------------------------------------------------------------------------------- /tests/test_transform.py: -------------------------------------------------------------------------------- 
1 | import unittest 2 | import numpy as np 3 | import kinpy as kp 4 | 5 | 6 | def random_transform(): 7 | euler = np.random.rand(3) * 2.0 * np.pi - np.pi 8 | return kp.Transform(euler, np.random.rand(3)) 9 | 10 | 11 | class TestTransform(unittest.TestCase): 12 | def test_multiply(self): 13 | t = random_transform() 14 | res = t * t.inverse() 15 | self.assertTrue(np.allclose(res.rot, np.array([1.0, 0.0, 0.0, 0.0]), atol=1.0e-6)) 16 | self.assertTrue(np.allclose(res.pos, np.array([0.0, 0.0, 0.0]), atol=1.0e-6)) 17 | 18 | 19 | if __name__ == "__main__": 20 | unittest.main() --------------------------------------------------------------------------------