├── .flake8
├── .github
└── workflows
│ └── ubuntu.yml
├── .gitignore
├── LICENSE
├── MANIFEST.in
├── Makefile
├── README.md
├── assets
├── ant.png
├── humanoid.png
├── joint_angle_editor.gif
├── kuka.png
├── logo.png
└── simple_arm.png
├── examples
├── ant
│ └── ant.xml
├── ant_example.py
├── humanoid
│ └── humanoid.xml
├── humanoid_example.py
├── kuka_example.py
├── kuka_iiwa
│ ├── meshes
│ │ ├── link_0.obj
│ │ ├── link_1.obj
│ │ ├── link_2.obj
│ │ ├── link_3.obj
│ │ ├── link_4.obj
│ │ ├── link_5.obj
│ │ ├── link_6.obj
│ │ └── link_7.obj
│ └── model.urdf
├── mycobot
│ ├── G_base.ply
│ ├── camera_flange.ply
│ ├── joint1.ply
│ ├── joint1.png
│ ├── joint2.ply
│ ├── joint2.png
│ ├── joint3.ply
│ ├── joint3.png
│ ├── joint4.ply
│ ├── joint4.png
│ ├── joint5.ply
│ ├── joint5.png
│ ├── joint6.ply
│ ├── joint6.png
│ ├── joint7.ply
│ ├── joint7.png
│ ├── mycobot.urdf
│ └── pump_head.ply
├── mycobot_example.py
├── simple_arm
│ └── model.sdf
├── simple_arm_example.py
├── ur
│ └── ur.urdf
└── ur_example.py
├── kinpy
├── __init__.py
├── chain.py
├── frame.py
├── ik.py
├── jacobian.py
├── mjcf.py
├── mjcf_parser
│ ├── __init__.py
│ ├── attribute.py
│ ├── base.py
│ ├── constants.py
│ ├── copier.py
│ ├── debugging.py
│ ├── element.py
│ ├── io.py
│ ├── namescope.py
│ ├── parser.py
│ ├── schema.py
│ ├── schema.xml
│ └── util.py
├── sdf.py
├── transform.py
├── urdf.py
├── urdf_parser_py
│ ├── __init__.py
│ ├── sdf.py
│ ├── urdf.py
│ └── xml_reflection
│ │ ├── __init__.py
│ │ ├── basics.py
│ │ └── core.py
└── visualizer.py
├── pyproject.toml
├── scripts
└── kpviewer.py
├── setup.py
└── tests
├── __init__.py
├── test_fkik.py
├── test_jacobian.py
└── test_transform.py
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 119
3 | exclude = __init__.py,__pycache__
4 | ignore = E121,E123,E126,E133,E226,E203,E241,E242,E704,W503,W504,W505,E127,E266,E402,W605,W391,E701,E731
5 |
--------------------------------------------------------------------------------
/.github/workflows/ubuntu.yml:
--------------------------------------------------------------------------------
name: Ubuntu CI

on:
  push:
    branches: [ master ]
    tags: ['v*']
  pull_request:
    branches: [ master ]

jobs:
  test:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ['3.9', '3.10', '3.11']
    steps:
      - name: Checkout source code
        uses: actions/checkout@v4
        with:
          submodules: true
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          sudo apt-get update
          sudo apt-get install -y curl
      - name: Install poetry
        run: |
          curl -sSL https://install.python-poetry.org | python3 -
          # The official installer places the poetry executable in ~/.local/bin
          # on Linux; ~/.poetry/bin was only used by the legacy installer.
          echo "$HOME/.local/bin" >> $GITHUB_PATH
      - name: Setup
        run: make setup
      - name: Test
        run: make test
37 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.~
3 | *.egg-info
4 | build
5 | dist
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 neka-nat
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include kinpy/mjcf_parser/schema.xml
2 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
# These targets are commands, not files; mark them phony so a stray file
# with the same name never makes them look "up to date".
.PHONY: setup test

# Install dependencies with poetry, then install kinpy itself in editable mode.
setup:
	poetry install --no-interaction
	poetry run pip install -e .

# Lint the top-level kinpy modules, type-check them, then run the unit tests.
test:
	find kinpy/. -maxdepth 1 -type f -name "*.py" | xargs poetry run flake8
	poetry run mypy kinpy/*.py
	poetry run python -m unittest discover
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | #
2 |
3 | [](https://github.com/neka-nat/kinpy/actions/workflows/ubuntu.yml/badge.svg)
4 | [](https://badge.fury.io/py/kinpy)
5 | [](LICENSE)
6 | [](https://pepy.tech/project/kinpy)
7 |
8 | A simple kinematics toolkit for articulated bodies.
9 |
10 | ## Core features
11 |
12 | * Pure Python library.
13 | * Supports URDF, SDF and MJCF files.
14 | * Calculates forward kinematics (FK), inverse kinematics (IK) and Jacobians.
15 |
16 | 
17 |
18 | ## Installation
19 |
20 | ```
21 | pip install kinpy
22 | ```
23 |
24 | ## Getting started
25 | Here is a program that reads a URDF file and builds a kinematic chain.
26 |
27 | ```py
28 | import kinpy as kp
29 |
30 | chain = kp.build_chain_from_urdf(open("kuka_iiwa/model.urdf").read())
31 | print(chain)
32 | # lbr_iiwa_link_0_frame
33 | # └──── lbr_iiwa_link_1_frame
34 | # └──── lbr_iiwa_link_2_frame
35 | # └──── lbr_iiwa_link_3_frame
36 | # └──── lbr_iiwa_link_4_frame
37 | # └──── lbr_iiwa_link_5_frame
38 | # └──── lbr_iiwa_link_6_frame
39 | # └──── lbr_iiwa_link_7_frame
40 | ```
41 |
42 | Display the names of the joint parameters (joint angles) in the chain.
43 |
44 | ```py
45 | print(chain.get_joint_parameter_names())
46 | # ['lbr_iiwa_joint_1', 'lbr_iiwa_joint_2', 'lbr_iiwa_joint_3', 'lbr_iiwa_joint_4', 'lbr_iiwa_joint_5', 'lbr_iiwa_joint_6', 'lbr_iiwa_joint_7']
47 | ```
48 |
49 | Given joint angle values, calculate forward kinematics.
50 |
51 | ```py
52 | import math
53 | th = {'lbr_iiwa_joint_2': math.pi / 4.0, 'lbr_iiwa_joint_4': math.pi / 2.0}
54 | ret = chain.forward_kinematics(th)
55 | # {'lbr_iiwa_link_0': Transform(rot=[1. 0. 0. 0.], pos=[0. 0. 0.]), 'lbr_iiwa_link_1': Transform(rot=[1. 0. 0. 0.], pos=[0. 0. 0.1575]), 'lbr_iiwa_link_2': Transform(rot=[-0.27059805 0.27059805 0.65328148 0.65328148], pos=[0. 0. 0.36]), 'lbr_iiwa_link_3': Transform(rot=[-9.23879533e-01 3.96044251e-14 -3.82683432e-01 -1.96942462e-12], pos=[ 1.44603337e-01 -6.78179735e-13 5.04603337e-01]), 'lbr_iiwa_link_4': Transform(rot=[-0.65328148 -0.65328148 0.27059805 -0.27059805], pos=[ 2.96984848e-01 -3.37579445e-13 6.56984848e-01]), 'lbr_iiwa_link_5': Transform(rot=[ 2.84114655e-12 3.82683432e-01 -1.87377891e-12 -9.23879533e-01], pos=[ 1.66523647e-01 -1.00338887e-12 7.87446049e-01]), 'lbr_iiwa_link_6': Transform(rot=[-0.27059805 0.27059805 -0.65328148 -0.65328148], pos=[ 1.41421356e-02 -7.25873884e-13 9.39827561e-01]), 'lbr_iiwa_link_7': Transform(rot=[ 9.23879533e-01 2.61060896e-12 -3.82683432e-01 4.81056861e-12], pos=[-4.31335137e-02 -1.01819561e-12 9.97103210e-01])}
56 | ```
57 |
58 | You can get the position and orientation of each link.
59 |
60 | To use IK or the Jacobian, you need to create a `SerialChain`.
61 | When creating a `SerialChain`, you must specify the end-effector link.
62 |
63 | ```py
64 | chain = kp.build_serial_chain_from_urdf(open("kuka_iiwa/model.urdf").read(), "lbr_iiwa_link_7")
65 | th = [0.0, -math.pi / 4.0, 0.0, math.pi / 2.0, 0.0, math.pi / 4.0, 0.0]
66 | ret = chain.forward_kinematics(th, end_only=True)
67 | # chain.inverse_kinematics(ret)
68 | # chain.jacobian(th)
69 | ```
70 |
71 | ## Visualization
72 |
73 | ### KUKA IIWA
74 | 
75 |
76 | ### Mujoco humanoid
77 | 
78 |
79 | ### Mujoco ant
80 | 
81 |
82 | ### Simple arm
83 | 
84 |
85 | ## Citing
86 |
87 | ```
88 | @software{kinpy,
89 | author = {{Kenta-Tanaka et al.}},
90 | title = {kinpy},
91 | url = {https://github.com/neka-nat/kinpy},
92 | version = {0.0.3},
93 | date = {2019-10-11},
94 | }
95 | ```
96 |
--------------------------------------------------------------------------------
/assets/ant.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/assets/ant.png
--------------------------------------------------------------------------------
/assets/humanoid.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/assets/humanoid.png
--------------------------------------------------------------------------------
/assets/joint_angle_editor.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/assets/joint_angle_editor.gif
--------------------------------------------------------------------------------
/assets/kuka.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/assets/kuka.png
--------------------------------------------------------------------------------
/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/assets/logo.png
--------------------------------------------------------------------------------
/assets/simple_arm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/assets/simple_arm.png
--------------------------------------------------------------------------------
/examples/ant/ant.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
--------------------------------------------------------------------------------
/examples/ant_example.py:
--------------------------------------------------------------------------------
"""Forward kinematics and visualization of the MuJoCo ant model.

Paths are relative to the ``examples`` directory; run the script from there.
"""
import numpy as np
import kinpy as kp

# Read the MJCF text inside a context manager so the file is closed promptly.
with open("ant/ant.xml") as f:
    chain = kp.build_chain_from_mjcf(f.read())
print(chain)
print(chain.get_joint_parameter_names())
th = {
    "hip_1": 0.0,
    "ankle_1": np.pi / 4.0,
    "hip_2": 0.0,
    "ankle_2": -np.pi / 4.0,
    "hip_3": 0.0,
    "ankle_3": -np.pi / 4.0,
    "hip_4": 0.0,
    "ankle_4": np.pi / 4.0,
}
ret = chain.forward_kinematics(th)
print(ret)
viz = kp.Visualizer()
viz.add_robot(ret, chain.visuals_map())
viz.spin()
22 |
--------------------------------------------------------------------------------
/examples/humanoid/humanoid.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
8 |
9 |
10 |
11 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
23 |
25 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
--------------------------------------------------------------------------------
/examples/humanoid_example.py:
--------------------------------------------------------------------------------
"""Forward kinematics and visualization of the MuJoCo humanoid model.

Paths are relative to the ``examples`` directory; run the script from there.
"""
import kinpy as kp

# Read the MJCF text inside a context manager so the file is closed promptly.
with open("humanoid/humanoid.xml") as f:
    chain = kp.build_chain_from_mjcf(f.read())
print(chain)
print(chain.get_joint_parameter_names())
th = {"left_knee": 0.0, "right_knee": 0.0}
ret = chain.forward_kinematics(th)
print(ret)
viz = kp.Visualizer()
viz.add_robot(ret, chain.visuals_map(), axes=True)
viz.spin()
13 |
--------------------------------------------------------------------------------
/examples/kuka_example.py:
--------------------------------------------------------------------------------
"""Forward kinematics and visualization of the KUKA LBR iiwa serial chain.

Paths are relative to the ``examples`` directory; run the script from there.
"""
import numpy as np
import kinpy as kp

# Read the URDF text inside a context manager so the file is closed promptly.
with open("kuka_iiwa/model.urdf") as f:
    chain = kp.build_serial_chain_from_urdf(f.read(), "lbr_iiwa_link_7")
print(chain)
print(chain.get_joint_parameter_names())
th = [0.0, -np.pi / 4.0, 0.0, np.pi / 2.0, 0.0, np.pi / 4.0, 0.0]
ret = chain.forward_kinematics(th, end_only=False)
print(ret)
viz = kp.Visualizer()
viz.add_robot(ret, chain.visuals_map(), mesh_file_path="kuka_iiwa/", axes=True)
viz.spin()
13 |
--------------------------------------------------------------------------------
/examples/kuka_iiwa/model.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
157 |
158 |
159 |
160 |
161 |
162 |
163 |
164 |
165 |
166 |
167 |
168 |
169 |
170 |
171 |
172 |
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 |
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 |
190 |
191 |
192 |
193 |
194 |
195 |
196 |
197 |
198 |
199 |
200 |
201 |
202 |
203 |
204 |
205 |
206 |
207 |
208 |
209 |
210 |
211 |
212 |
213 |
214 |
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 |
223 |
224 |
225 |
226 |
227 |
228 |
229 |
230 |
231 |
232 |
233 |
234 |
235 |
236 |
237 |
238 |
239 |
240 |
241 |
242 |
243 |
244 |
245 |
246 |
247 |
248 |
249 |
250 |
251 |
252 |
253 |
254 |
255 |
256 |
257 |
258 |
259 |
260 |
261 |
262 |
263 |
264 |
265 |
266 |
267 |
268 |
269 |
270 |
271 |
272 |
273 |
274 |
275 |
276 |
277 |
278 |
279 |
280 |
281 |
282 |
283 |
284 |
285 |
286 |
287 |
288 |
289 |
290 |
--------------------------------------------------------------------------------
/examples/mycobot/G_base.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/examples/mycobot/G_base.ply
--------------------------------------------------------------------------------
/examples/mycobot/camera_flange.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/examples/mycobot/camera_flange.ply
--------------------------------------------------------------------------------
/examples/mycobot/joint1.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/examples/mycobot/joint1.ply
--------------------------------------------------------------------------------
/examples/mycobot/joint1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/examples/mycobot/joint1.png
--------------------------------------------------------------------------------
/examples/mycobot/joint2.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/examples/mycobot/joint2.ply
--------------------------------------------------------------------------------
/examples/mycobot/joint2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/examples/mycobot/joint2.png
--------------------------------------------------------------------------------
/examples/mycobot/joint3.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/examples/mycobot/joint3.ply
--------------------------------------------------------------------------------
/examples/mycobot/joint3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/examples/mycobot/joint3.png
--------------------------------------------------------------------------------
/examples/mycobot/joint4.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/examples/mycobot/joint4.ply
--------------------------------------------------------------------------------
/examples/mycobot/joint4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/examples/mycobot/joint4.png
--------------------------------------------------------------------------------
/examples/mycobot/joint5.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/examples/mycobot/joint5.ply
--------------------------------------------------------------------------------
/examples/mycobot/joint5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/examples/mycobot/joint5.png
--------------------------------------------------------------------------------
/examples/mycobot/joint6.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/examples/mycobot/joint6.ply
--------------------------------------------------------------------------------
/examples/mycobot/joint6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/examples/mycobot/joint6.png
--------------------------------------------------------------------------------
/examples/mycobot/joint7.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/examples/mycobot/joint7.ply
--------------------------------------------------------------------------------
/examples/mycobot/joint7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/examples/mycobot/joint7.png
--------------------------------------------------------------------------------
/examples/mycobot/mycobot.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
157 |
158 |
159 |
160 |
161 |
162 |
163 |
164 |
165 |
166 |
167 |
168 |
169 |
170 |
171 |
172 |
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 |
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 |
190 |
191 |
192 |
193 |
194 |
195 |
196 |
197 |
198 |
199 |
200 |
201 |
202 |
203 |
204 |
205 |
206 |
207 |
208 |
209 |
210 |
211 |
212 |
213 |
214 |
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 |
223 |
224 |
225 |
227 |
228 |
229 |
230 |
231 |
232 |
233 |
234 |
235 |
236 |
237 |
238 |
239 |
240 |
241 |
--------------------------------------------------------------------------------
/examples/mycobot/pump_head.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/examples/mycobot/pump_head.ply
--------------------------------------------------------------------------------
/examples/mycobot_example.py:
--------------------------------------------------------------------------------
"""Interactive joint-angle editor for the myCobot serial chain.

Paths are relative to the ``examples`` directory; run the script from there.
"""
import numpy as np
import kinpy as kp

# Read the URDF text inside a context manager so the file is closed promptly.
with open("mycobot/mycobot.urdf") as f:
    chain = kp.build_serial_chain_from_urdf(f.read(), "pump_head")
print(chain)

print(chain.get_joint_parameter_names())
th = np.deg2rad([0, 20, -130, 20, 0, 0])
viz = kp.JointAngleEditor(chain, "mycobot/", initial_state=th, axes=True)
viz.spin()
11 |
--------------------------------------------------------------------------------
/examples/simple_arm/model.sdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | 0 0 0.00099 0 0 0
7 |
8 | 1.11
9 | 0
10 | 0
11 | 100.11
12 | 0
13 | 1.01
14 |
15 | 101.0
16 |
17 |
18 | 0 0 0.05 0 0 0
19 |
20 |
21 | 1.0 1.0 0.1
22 |
23 |
24 |
25 |
26 | 0 0 0.05 0 0 0
27 |
28 |
29 | 1.0 1.0 0.1
30 |
31 |
32 |
33 |
37 |
38 |
39 |
40 | 0 0 0.6 0 0 0
41 |
42 |
43 | 0.05
44 | 1.0
45 |
46 |
47 |
48 |
49 | 0 0 0.6 0 0 0
50 |
51 |
52 | 0.05
53 | 1.0
54 |
55 |
56 |
57 |
61 |
62 |
63 |
64 |
65 | 0 0 1.1 0 0 0
66 |
67 | 0.045455 0 0 0 0 0
68 |
69 | 0.011
70 | 0
71 | 0
72 | 0.0225
73 | 0
74 | 0.0135
75 |
76 | 1.1
77 |
78 |
79 | 0 0 0.05 0 0 0
80 |
81 |
82 | 0.05
83 | 0.1
84 |
85 |
86 |
87 |
88 | 0 0 0.05 0 0 0
89 |
90 |
91 | 0.05
92 | 0.1
93 |
94 |
95 |
96 |
100 |
101 |
102 |
103 | 0.55 0 0.05 0 0 0
104 |
105 |
106 | 1.0 0.05 0.1
107 |
108 |
109 |
110 |
111 | 0.55 0 0.05 0 0 0
112 |
113 |
114 | 1.0 0.05 0.1
115 |
116 |
117 |
118 |
122 |
123 |
124 |
125 |
126 | 1.05 0 1.1 0 0 0
127 |
128 | 0.0875 0 0.083333 0 0 0
129 |
130 | 0.031
131 | 0
132 | 0.005
133 | 0.07275
134 | 0
135 | 0.04475
136 |
137 | 1.2
138 |
139 |
140 | 0 0 0.1 0 0 0
141 |
142 |
143 | 0.05
144 | 0.2
145 |
146 |
147 |
148 |
149 | 0 0 0.1 0 0 0
150 |
151 |
152 | 0.05
153 | 0.2
154 |
155 |
156 |
157 |
161 |
162 |
163 |
164 | 0.3 0 0.15 0 0 0
165 |
166 |
167 | 0.5 0.03 0.1
168 |
169 |
170 |
171 |
172 | 0.3 0 0.15 0 0 0
173 |
174 |
175 | 0.5 0.03 0.1
176 |
177 |
178 |
179 |
183 |
184 |
185 |
186 | 0.55 0 0.15 0 0 0
187 |
188 |
189 | 0.05
190 | 0.3
191 |
192 |
193 |
194 |
195 | 0.55 0 0.15 0 0 0
196 |
197 |
198 | 0.05
199 | 0.3
200 |
201 |
202 |
203 |
207 |
208 |
209 |
210 |
211 | 1.6 0 1.05 0 0 0
212 |
213 | 0 0 0 0 0 0
214 |
215 | 0.01
216 | 0
217 | 0
218 | 0.01
219 | 0
220 | 0.001
221 |
222 | 0.1
223 |
224 |
225 | 0 0 0.5 0 0 0
226 |
227 |
228 | 0.03
229 | 1.0
230 |
231 |
232 |
233 |
234 | 0 0 0.5 0 0 0
235 |
236 |
237 | 0.03
238 | 1.0
239 |
240 |
241 |
242 |
246 |
247 |
248 |
249 |
250 | 1.6 0 1.0 0 0 0
251 |
252 | 0 0 0 0 0 0
253 |
254 | 0.01
255 | 0
256 | 0
257 | 0.01
258 | 0
259 | 0.001
260 |
261 | 0.1
262 |
263 |
264 | 0 0 0.025 0 0 0
265 |
266 |
267 | 0.05
268 | 0.05
269 |
270 |
271 |
272 |
273 | 0 0 0.025 0 0 0
274 |
275 |
276 | 0.05
277 | 0.05
278 |
279 |
280 |
281 |
285 |
286 |
287 |
288 |
289 | arm_base
290 | arm_shoulder_pan
291 |
292 |
293 | 1.000000
294 | 0.000000
295 |
296 | 0 0 1
297 | true
298 |
299 |
300 |
301 | arm_shoulder_pan
302 | arm_elbow_pan
303 |
304 |
305 | 1.000000
306 | 0.000000
307 |
308 | 0 0 1
309 | true
310 |
311 |
312 |
313 | arm_elbow_pan
314 | arm_wrist_lift
315 |
316 |
317 | 1.000000
318 | 0.000000
319 |
320 |
321 | -0.8
322 | 0.1
323 |
324 | 0 0 1
325 | true
326 |
327 |
328 |
329 | arm_wrist_lift
330 | arm_wrist_roll
331 |
332 |
333 | 1.000000
334 | 0.000000
335 |
336 |
337 | -2.999994
338 | 2.999994
339 |
340 | 0 0 1
341 | true
342 |
343 |
344 |
357 |
358 |
359 |
--------------------------------------------------------------------------------
/examples/simple_arm_example.py:
--------------------------------------------------------------------------------
"""Forward kinematics and visualization of the SDF simple arm model.

Paths are relative to the ``examples`` directory; run the script from there.
"""
import numpy as np
import kinpy as kp

# Read the SDF text inside a context manager so the file is closed promptly.
with open("simple_arm/model.sdf") as f:
    chain = kp.build_chain_from_sdf(f.read())
print(chain)
print(chain.get_joint_parameter_names())
ret = chain.forward_kinematics({"arm_elbow_pan_joint": np.pi / 2.0, "arm_wrist_lift_joint": -0.5})
print(ret)
viz = kp.Visualizer()
viz.add_robot(ret, chain.visuals_map())
viz.spin()
12 |
--------------------------------------------------------------------------------
/examples/ur/ur.urdf:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
26 |
27 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
157 |
158 |
159 |
160 |
161 |
162 |
163 |
164 |
165 |
166 |
167 |
168 |
169 |
170 |
171 |
172 |
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 |
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 |
190 |
191 |
192 |
193 |
194 |
195 |
196 |
197 |
198 |
199 |
200 |
201 |
202 |
203 |
204 |
205 |
206 |
207 |
208 |
209 |
210 |
211 |
212 |
213 |
214 |
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 |
223 |
224 |
225 |
226 |
227 |
228 |
229 |
230 |
231 |
232 |
233 |
234 | transmission_interface/SimpleTransmission
235 |
236 | PositionJointInterface
237 |
238 |
239 | 1
240 |
241 |
242 |
243 | transmission_interface/SimpleTransmission
244 |
245 | PositionJointInterface
246 |
247 |
248 | 1
249 |
250 |
251 |
252 | transmission_interface/SimpleTransmission
253 |
254 | PositionJointInterface
255 |
256 |
257 | 1
258 |
259 |
260 |
261 | transmission_interface/SimpleTransmission
262 |
263 | PositionJointInterface
264 |
265 |
266 | 1
267 |
268 |
269 |
270 | transmission_interface/SimpleTransmission
271 |
272 | PositionJointInterface
273 |
274 |
275 | 1
276 |
277 |
278 |
279 | transmission_interface/SimpleTransmission
280 |
281 | PositionJointInterface
282 |
283 |
284 | 1
285 |
286 |
287 |
288 |
289 |
290 |
291 |
295 |
296 |
297 |
298 |
299 |
300 |
301 |
302 |
303 |
304 |
305 |
306 |
307 |
308 |
309 |
310 |
311 |
312 |
313 |
314 |
--------------------------------------------------------------------------------
/examples/ur_example.py:
--------------------------------------------------------------------------------
"""Forward kinematics of a UR arm at the all-zero joint configuration.

Paths are relative to the ``examples`` directory; run the script from there.
"""
import numpy as np
import kinpy as kp

# Read the URDF text inside a context manager so the file is closed promptly.
with open("ur/ur.urdf") as f:
    urdf_data = f.read()

arm = kp.build_serial_chain_from_urdf(
    urdf_data,
    root_link_name="base_link",
    end_link_name="ee_link",
)
fk_solution = arm.forward_kinematics(np.zeros(len(arm.get_joint_parameter_names())))
print(fk_solution)
12 |
--------------------------------------------------------------------------------
/kinpy/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from .chain import Chain
4 | from .mjcf import *
5 | from .sdf import *
6 | from .transform import *
7 | from .urdf import *
8 | from .visualizer import *
9 |
10 |
def build_chain_from_file(filename: str) -> Chain:
    """Build a kinematic chain from a robot description file.

    The parser is selected from the file extension.

    Parameters
    ----------
    filename : str
        Path to a ``.urdf``, ``.sdf`` or ``.mjcf`` file.

    Returns
    -------
    Chain
        Kinematic chain built from the file contents.

    Raises
    ------
    ValueError
        If the extension is not ``.urdf``, ``.sdf`` or ``.mjcf``.
    """
    ext = os.path.splitext(filename)[-1]
    # Reject unsupported extensions before opening anything.
    if ext not in (".urdf", ".sdf", ".mjcf"):
        raise ValueError(f"Invalid file type: '{ext}' file.")
    # Use a context manager so the file handle is closed deterministically.
    with open(filename) as f:
        data = f.read()
    if ext == ".urdf":
        return build_chain_from_urdf(data)
    if ext == ".sdf":
        return build_chain_from_sdf(data)
    return build_chain_from_mjcf(data)
21 |
--------------------------------------------------------------------------------
/kinpy/chain.py:
--------------------------------------------------------------------------------
1 | from functools import cached_property
2 | from typing import Dict, Iterator, List, Optional, Union
3 |
4 | import numpy as np
5 |
6 | from . import frame, ik, jacobian, transform
7 |
8 |
9 | class Chain:
10 | """Chain is a class that represents a kinematic chain."""
11 |
def __init__(self, root_frame: frame.Frame) -> None:
    """Create a chain whose tree starts at ``root_frame``.

    Parameters
    ----------
    root_frame : frame.Frame
        Top-most frame of the kinematic tree.
    """
    # Typed as Optional because the other methods guard against a missing root.
    self._root: Optional[frame.Frame] = root_frame
14 |
def __str__(self) -> str:
    """Return the string form of the root frame (``"None"`` when there is no root)."""
    root = self._root
    return str(root)
17 |
def __iter__(self) -> Iterator[frame.Frame]:
    """Iterate over the frames of the chain.

    Delegates to the root frame's ``walk()``; the order is whatever it yields.
    """
    root = self._root
    assert root is not None, "Root frame is None"
    yield from root.walk()
21 |
@cached_property
def dof(self) -> int:
    """Number of joint parameters of the chain.

    Equals ``len(self.get_joint_parameter_names())`` (fixed joints excluded).
    The value is computed on first access and cached on the instance.
    """
    joint_names = self.get_joint_parameter_names()
    return len(joint_names)
25 |
@staticmethod
def _find_frame_recursive(name: str, frame: frame.Frame) -> Optional[frame.Frame]:
    # Depth-first, pre-order search of the descendants of ``frame`` (``frame``
    # itself is not compared); the first frame whose name matches is returned.
    for sub in frame.children:
        hit = sub if sub.name == name else Chain._find_frame_recursive(name, sub)
        if hit is not None:
            return hit
    return None
35 |
def find_frame(self, name: str) -> Optional[frame.Frame]:
    """Look up a frame by its name.

    The root frame is compared first, then its descendants depth-first in
    child order; the first match wins.

    Parameters
    ----------
    name : str
        Name of the frame to look for.

    Returns
    -------
    Optional[frame.Frame]
        The matching frame, or None if no frame has that name.
    """
    assert self._root is not None, "Root frame is None"
    root = self._root
    if root.name != name:
        return self._find_frame_recursive(name, root)
    return root
53 |
@staticmethod
def _find_link_recursive(name: str, frame: frame.Frame) -> Optional[frame.Link]:
    # Depth-first, pre-order search over the links of the descendants of
    # ``frame`` (``frame.link`` itself is not compared); first match wins.
    for sub in frame.children:
        hit = sub.link if sub.link.name == name else Chain._find_link_recursive(name, sub)
        if hit is not None:
            return hit
    return None
63 |
64 | def find_link(self, name: str) -> Optional[frame.Link]:
65 | """Find a link by name.
66 |
67 | Parameters
68 | ----------
69 | name : str
70 | Link name.
71 |
72 | Returns
73 | -------
74 | Optional[frame.Link]
75 | Link if found, None otherwise.
76 | """
77 | assert self._root is not None, "Root frame is None"
78 | if self._root.link.name == name:
79 | return self._root.link
80 | return self._find_link_recursive(name, self._root)
81 |
82 | @staticmethod
83 | def _get_joint_parameter_names(frame: frame.Frame, exclude_fixed: bool = True) -> List[str]:
84 | joint_names = []
85 | if not (exclude_fixed and frame.joint.joint_type == "fixed"):
86 | joint_names.append(frame.joint.name)
87 | for child in frame.children:
88 | joint_names.extend(Chain._get_joint_parameter_names(child, exclude_fixed))
89 | return joint_names
90 |
91 | def get_joint_parameter_names(self, exclude_fixed: bool = True) -> List[str]:
92 | """Get joint parameter names.
93 |
94 | Parameters
95 | ----------
96 | exclude_fixed : bool, optional
97 | Exclude fixed joints, by default True
98 |
99 | Returns
100 | -------
101 | List[str]
102 | Joint parameter names.
103 | """
104 | assert self._root is not None, "Root frame is None"
105 | names = self._get_joint_parameter_names(self._root, exclude_fixed)
106 | return list(sorted(set(names), key=names.index))
107 |
108 | def add_frame(self, frame: frame.Frame, parent_name: str) -> None:
109 | parent_frame = self.find_frame(parent_name)
110 | if parent_frame is not None:
111 | parent_frame.add_child(frame)
112 |
113 | @staticmethod
114 | def _forward_kinematics(
115 | root: frame.Frame, th_dict: Dict[str, float], world: Optional[transform.Transform] = None
116 | ) -> Dict[str, transform.Transform]:
117 | world = world or transform.Transform()
118 | link_transforms = {}
119 | trans = world * root.get_transform(th_dict.get(root.joint.name, 0.0))
120 | link_transforms[root.link.name] = trans * root.link.offset
121 | for child in root.children:
122 | link_transforms.update(Chain._forward_kinematics(child, th_dict, trans))
123 | return link_transforms
124 |
125 | def forward_kinematics(
126 | self, th: Union[Dict[str, float], List[float]], world: Optional[transform.Transform] = None, **kwargs: Dict
127 | ) -> Dict[str, transform.Transform]:
128 | """Forward kinematics.
129 |
130 | Parameters
131 | ----------
132 | th : Union[Dict[str, float], List[float]]
133 | Joint parameters.
134 | world : Optional[transform.Transform], optional
135 | World transform, by default None
136 |
137 | Returns
138 | -------
139 | Dict[str, transform.Transform]
140 | Link transforms.
141 | """
142 | assert self._root is not None, "Root frame is None"
143 | world = world or transform.Transform()
144 | if not isinstance(th, dict):
145 | jn = self.get_joint_parameter_names()
146 | assert len(jn) == len(th)
147 | th_dict = dict((j, th[i]) for i, j in enumerate(jn))
148 | else:
149 | th_dict = th
150 | return self._forward_kinematics(self._root, th_dict, world)
151 |
152 | @staticmethod
153 | def _visuals_map(root: frame.Frame) -> Dict[str, List[frame.Visual]]:
154 | vmap = {root.link.name: root.link.visuals}
155 | for child in root.children:
156 | vmap.update(Chain._visuals_map(child))
157 | return vmap
158 |
159 | def visuals_map(self):
160 | return self._visuals_map(self._root)
161 |
162 |
class SerialChain(Chain):
    """A kinematic chain restricted to the single path from a root frame to an end frame."""

    def __init__(self, chain: Chain, end_frame_name: str, root_frame_name: str = "") -> None:
        assert chain._root is not None, "Chain root frame is None"
        self._root = chain._root if root_frame_name == "" else chain.find_frame(root_frame_name)
        if self._root is None:
            raise ValueError("Invalid root frame name %s." % root_frame_name)
        path = self._generate_serial_chain_recurse(self._root, end_frame_name)
        if path is None:
            raise ValueError("Invalid end frame name %s." % end_frame_name)
        # Root first, end frame last.
        self._serial_frames = [self._root, *path]

    @staticmethod
    def _generate_serial_chain_recurse(root_frame: frame.Frame, end_frame_name: str) -> Optional[List[frame.Frame]]:
        # Return the frames strictly below `root_frame` leading to `end_frame_name`, or None.
        for child in root_frame.children:
            if child.name == end_frame_name:
                return [child]
            tail = SerialChain._generate_serial_chain_recurse(child, end_frame_name)
            if tail is not None:
                return [child, *tail]
        return None

    def get_joint_parameter_names(self, exclude_fixed: bool = True) -> List[str]:
        assert self._serial_frames is not None, "Serial chain not initialized."
        return [
            f.joint.name
            for f in self._serial_frames
            if not (exclude_fixed and f.joint.joint_type == "fixed")
        ]

    def forward_kinematics(  # type: ignore[override]
        self,
        th: Union[Dict[str, float], List[float]],
        world: Optional[transform.Transform] = None,
        end_only: bool = True,
    ) -> Union[transform.Transform, Dict[str, transform.Transform]]:
        assert self._serial_frames is not None, "Serial chain not initialized."
        end_link = self._serial_frames[-1].link.name
        if isinstance(th, dict):
            # Named joint values: reuse the tree-wide computation.
            everything = super().forward_kinematics(th, world)
            return everything[end_link] if end_only else everything
        current = world or transform.Transform()
        poses = {}
        joint_index = 0
        for f in self._serial_frames:
            if f.joint.joint_type == "fixed":
                current = current * f.get_transform()
            else:
                # Positional values are consumed one per movable joint, in chain order.
                current = current * f.get_transform(th[joint_index])
                joint_index += 1
            poses[f.link.name] = current * f.link.offset
        return poses[end_link] if end_only else poses

    def jacobian(self, th: List[float], end_only: bool = True) -> Union[np.ndarray, Dict[str, np.ndarray]]:
        assert self._serial_frames is not None, "Serial chain not initialized."
        if end_only:
            return jacobian.calc_jacobian(self, th)
        return {
            f.link.name: jacobian.calc_jacobian_frames(self, th, link_name=f.link.name)
            for f in self._serial_frames
        }

    def inverse_kinematics(self, pose: transform.Transform, initial_state: Optional[np.ndarray] = None) -> np.ndarray:
        return ik.inverse_kinematics(self, pose, initial_state)
239 |
--------------------------------------------------------------------------------
/kinpy/frame.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Iterator, List, Optional
2 |
3 | import numpy as np
4 | import transformations as tf
5 |
6 | from . import transform
7 |
8 |
class Visual:
    """A piece of geometry attached to a link.

    ``offset`` defaults to the identity transform. ``geom_param`` is stored
    untouched; its meaning depends on ``geom_type``.
    """

    TYPES = ["box", "cylinder", "sphere", "capsule", "mesh"]

    def __init__(
        self,
        offset: Optional[transform.Transform] = None,
        geom_type: Optional[str] = None,
        geom_param: Any = None,
    ) -> None:
        self.offset = offset or transform.Transform()
        self.geom_type = geom_type
        self.geom_param = geom_param

    def __repr__(self) -> str:
        return f"Visual(offset={self.offset}, geom_type='{self.geom_type}', geom_param={self.geom_param})"
26 |
27 |
class Link:
    """A named link with an offset transform (identity by default) and a list of visuals."""

    def __init__(
        self, name: Optional[str] = None, offset: Optional[transform.Transform] = None, visuals: Optional[List] = None
    ) -> None:
        self.name = "none" if name is None else name
        self.offset = offset or transform.Transform()
        self.visuals = visuals or []

    def __repr__(self) -> str:
        return f"Link(name='{self.name}', offset={self.offset}, visuals={self.visuals})"
38 |
39 |
class Joint:
    """A joint: name, fixed offset from the parent frame, type, and motion axis.

    ``joint_type`` is expected to be one of ``TYPES``. When no axis is given the
    z axis ``[0, 0, 1]`` is used, regardless of the joint type.
    """

    TYPES = ["fixed", "revolute", "prismatic"]

    def __init__(
        self,
        name: Optional[str] = None,
        offset: Optional[transform.Transform] = None,
        joint_type: str = "fixed",
        axis: Optional[List[float]] = None,
    ) -> None:
        self.name = "none" if name is None else name
        self.offset = offset or transform.Transform()
        self.joint_type = joint_type
        self.axis = np.array([0.0, 0.0, 1.0]) if axis is None else np.array(axis)

    def __repr__(self) -> str:
        return f"Joint(name='{self.name}', offset={self.offset}, joint_type='{self.joint_type}', axis={self.axis})"
62 |
63 |
class Frame:
    """A node of the kinematic tree: a joint, the link it carries, and child frames."""

    def __init__(
        self,
        name: Optional[str] = None,
        link: Optional[Link] = None,
        joint: Optional[Joint] = None,
        children: Optional[List["Frame"]] = None,
    ) -> None:
        self.name = "None" if name is None else name
        self.link = link or Link()
        self.joint = joint or Joint()
        self.children = children or []

    def _ptree(self, indent_width: int = 4) -> str:
        """Render this frame and all descendants as a box-drawing tree, one frame per line."""

        def _inner_ptree(root: Frame, parent: Frame, grandpa: Optional[Frame] = None, indent: str = ""):
            show_str = ""
            # Compare by identity, not by name: the root's name is printed by the
            # caller, but a descendant that merely shares it must still appear.
            if parent is not root:
                show_str += " " + parent.name + ("" if grandpa is None else "\n")
            if not parent.children:
                return show_str
            for child in parent.children[:-1]:
                show_str += indent + "├" + "─" * indent_width
                show_str += _inner_ptree(root, child, parent, indent + "│" + " " * (indent_width + 1))
            # The last child gets a corner connector and no continuing vertical bar.
            child = parent.children[-1]
            show_str += indent + "└" + "─" * indent_width
            show_str += _inner_ptree(root, child, parent, indent + " " * (indent_width + 2))
            return show_str

        show_str = self.name + "\n"
        show_str += _inner_ptree(self, self)
        return show_str

    def __str__(self) -> str:
        return self._ptree()

    def add_child(self, child: "Frame") -> None:
        self.children.append(child)

    def is_end(self) -> bool:
        """Return True when this frame has no children."""
        return len(self.children) == 0

    def get_transform(self, theta: float = 0.0) -> transform.Transform:
        """Return the joint offset composed with the joint motion at `theta`.

        Revolute joints rotate by `theta` about ``joint.axis``. Prismatic joints
        translate by ``theta * joint.axis``. Fixed joints ignore `theta`.

        Raises
        ------
        ValueError
            If the joint type is none of the above.
        """
        if self.joint.joint_type == "revolute":
            t = transform.Transform(tf.quaternion_about_axis(theta, self.joint.axis))
        elif self.joint.joint_type == "prismatic":
            t = transform.Transform(pos=theta * self.joint.axis)
        elif self.joint.joint_type == "fixed":
            t = transform.Transform()
        else:
            raise ValueError("Unsupported joint type %s." % self.joint.joint_type)
        return self.joint.offset * t

    def walk(self) -> Iterator["Frame"]:
        """Yield this frame and then all descendants, depth-first pre-order."""
        yield self
        for child in self.children:
            yield from child.walk()
121 |
--------------------------------------------------------------------------------
/kinpy/ik.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Optional
2 |
3 | import numpy as np
4 | import scipy.optimize as sco
5 |
6 | from . import transform
7 |
8 |
def inverse_kinematics(
    serial_chain: Any, pose: transform.Transform, initial_state: Optional[np.ndarray] = None
) -> np.ndarray:
    """Numerically search for joint values whose end pose matches `pose`.

    Starting from `initial_state` (zeros when omitted), BFGS minimises the
    squared deviation from the identity of the least-squares solution X of
    ``pose.matrix() @ X = fk.matrix()``.

    Returns
    -------
    np.ndarray
        The optimiser's final joint vector.
    """
    dof = len(serial_chain.get_joint_parameter_names())
    x0 = np.zeros(dof) if initial_state is None else initial_state

    def _pose_error(joint_values):
        current = serial_chain.forward_kinematics(joint_values)
        relative = np.linalg.lstsq(pose.matrix(), current.matrix(), rcond=-1)[0]
        return np.square(relative - np.identity(4)).sum()

    result = sco.minimize(_pose_error, x0, method="BFGS")
    return result.x
25 |
--------------------------------------------------------------------------------
/kinpy/jacobian.py:
--------------------------------------------------------------------------------
1 | from typing import Any, List, Optional
2 |
3 | import numpy as np
4 |
5 | from . import transform
6 |
7 |
def _jacobian_in_tool_frame(
    frames: List[Any], th: List[float], ndof: int, skip_columns: int, tool: transform.Transform
) -> np.ndarray:
    """Build the 6 x ndof Jacobian (linear rows first) in the tool frame.

    `frames` is walked from the last entry back to the first. The first movable
    joint met fills column ``-(skip_columns + 1)``, the next one the column
    before it, and so on. ``th[-k]`` is the value of the joint filling column ``-k``.
    """
    j_fl = np.zeros((6, ndof))
    cur_transform = tool.matrix()
    cnt = skip_columns
    for f in reversed(frames):
        joint_type = f.joint.joint_type
        if joint_type == "revolute":
            cnt += 1
            delta = np.dot(f.joint.axis, cur_transform[:3, :3])
            d = np.dot(np.cross(f.joint.axis, cur_transform[:3, 3]), cur_transform[:3, :3])
            j_fl[:, -cnt] = np.hstack((d, delta))
        elif joint_type == "prismatic":
            cnt += 1
            j_fl[:3, -cnt] = np.dot(f.joint.axis, cur_transform[:3, :3])
        # Fixed joints ignore theta. Passing 0.0 avoids reading th[-0] (== th[0]),
        # which raised IndexError for an empty `th`.
        theta = 0.0 if joint_type == "fixed" else th[-cnt]
        cur_transform = np.dot(f.get_transform(theta).matrix(), cur_transform)
    return j_fl


def _rotate_to_world(j_fl: np.ndarray, pose: np.ndarray) -> np.ndarray:
    """Rotate both 3-row halves of `j_fl` by the rotation block of the 4x4 `pose`."""
    rotation = pose[:3, :3]
    j_tr = np.zeros((6, 6))
    j_tr[:3, :3] = rotation
    j_tr[3:, 3:] = rotation
    return np.dot(j_tr, j_fl)


def calc_jacobian(serial_chain: Any, th: List[float], tool: Optional[transform.Transform] = None) -> np.ndarray:
    """Jacobian of the chain's end link for joint values `th`.

    Parameters
    ----------
    serial_chain : SerialChain
        Chain whose ``_serial_frames`` are differentiated.
    th : list of float
        One value per movable joint, in chain order.
    tool : transform.Transform, optional
        Offset from the end frame to the tool point; identity by default.

    Returns
    -------
    np.ndarray
        6 x len(th) matrix. Rows 0-2 are linear and rows 3-5 angular, rotated by
        the end link's world rotation.
    """
    tool = tool or transform.Transform()
    j_fl = _jacobian_in_tool_frame(serial_chain._serial_frames, th, len(th), 0, tool)
    pose = serial_chain.forward_kinematics(th).matrix()
    return _rotate_to_world(j_fl, pose)


def calc_jacobian_frames(
    serial_chain: Any, th: List[float], link_name: str, tool: Optional[transform.Transform] = None
) -> np.ndarray:
    """Jacobian of the link `link_name` of a serial chain for joint values `th`.

    Only the movable joints from the chain root up to `link_name` contribute.
    The columns of later joints stay zero.

    Returns
    -------
    np.ndarray
        6 x len(th) matrix with the same row layout as ``calc_jacobian``.

    Raises
    ------
    KeyError
        If `link_name` is not a link of the chain.
    """
    tool = tool or transform.Transform()

    # Select the frames from the root up to (and including) the target link.
    serial_frames = []
    num_movable_joints = 0
    for serial_frame in serial_chain._serial_frames:
        serial_frames.append(serial_frame)
        if serial_frame.joint.joint_type != "fixed":
            num_movable_joints += 1
        if serial_frame.link.name == link_name:
            break

    # Skip the trailing columns that belong to joints past the target link.
    j_fl = _jacobian_in_tool_frame(serial_frames, th, len(th), len(th) - num_movable_joints, tool)
    pose = serial_chain.forward_kinematics(th, end_only=False)[link_name].matrix()
    return _rotate_to_world(j_fl, pose)
78 |
--------------------------------------------------------------------------------
/kinpy/mjcf.py:
--------------------------------------------------------------------------------
1 | import io
2 | from typing import Dict, Optional, TextIO, Union
3 |
4 | from . import chain, frame, mjcf_parser, transform
5 |
# MJCF joint type -> kinpy joint type; only these MJCF joint types are supported.
JOINT_TYPE_MAP: Dict[str, str] = dict(hinge="revolute", slide="prismatic")
7 |
8 |
def geoms_to_visuals(geom, base: Optional[transform.Transform] = None):
    """Convert MJCF geoms into ``frame.Visual`` objects placed relative to `base`.

    Capsules become ``(size[0], fromto)`` and spheres ``size[0]``. Any other
    geometry type raises ValueError.
    """
    base = base or transform.Transform()

    def _to_visual(g):
        if g.type == "capsule":
            param = (g.size[0], g.fromto)
        elif g.type == "sphere":
            param = g.size[0]
        else:
            raise ValueError("Invalid geometry type %s." % g.type)
        return frame.Visual(offset=base * transform.Transform(g.quat, g.pos), geom_type=g.type, geom_param=param)

    return [_to_visual(g) for g in geom]
23 |
24 |
def body_to_link(body, base: Optional[transform.Transform] = None):
    """Create a ``frame.Link`` named after `body`, offset by ``base * (body pose)``."""
    origin = base or transform.Transform()
    body_pose = transform.Transform(body.quat, body.pos)
    return frame.Link(body.name, offset=origin * body_pose)
28 |
29 |
def joint_to_joint(joint, base: Optional[transform.Transform] = None):
    """Convert an MJCF joint element into a ``frame.Joint``.

    Parameters
    ----------
    joint : MJCF joint element
        Provides ``name``, ``type``, ``pos`` and ``axis``.
    base : transform.Transform, optional
        Transform prepended to the joint position; identity by default.

    Returns
    -------
    frame.Joint
        Joint whose type is mapped through ``JOINT_TYPE_MAP``.

    Raises
    ------
    ValueError
        If the MJCF joint type has no entry in ``JOINT_TYPE_MAP``.
    """
    base = base or transform.Transform()
    # MuJoCo's default joint type is "hinge". The parser presumably reports an
    # omitted attribute as None. NOTE(review): default-class values are not resolved here.
    mjcf_type = "hinge" if joint.type is None else joint.type
    if mjcf_type not in JOINT_TYPE_MAP:
        raise ValueError("Unsupported joint type %s." % mjcf_type)
    return frame.Joint(
        joint.name,
        offset=base * transform.Transform(pos=joint.pos),
        joint_type=JOINT_TYPE_MAP[mjcf_type],
        axis=joint.axis,
    )
38 |
39 |
def add_composite_joint(root_frame, joints, base: Optional[transform.Transform] = None):
    """Nest one new frame per entry of `joints` below `root_frame`.

    Each new frame holds one converted joint and a link named after its parent's
    link plus "_child". The frames form a single line: the first is attached to
    `root_frame`, the second to the first, and so on.

    Parameters
    ----------
    root_frame : frame.Frame
        Frame the first new frame is attached to. Its ``children`` attribute is
        replaced by a new list with the frame appended (not mutated in place).
    joints : sequence of MJCF joint elements
        Must support ``len``, indexing and slicing.
    base : transform.Transform, optional
        Prepended to the offset of the first joint only; identity by default.

    Returns
    -------
    tuple
        ``(last_frame, offset)``. ``last_frame`` is the deepest frame created, or
        `root_frame` when `joints` is empty. ``offset`` is the product of
        ``joint.offset`` from `root_frame` down to ``last_frame``.
    """
    base = base or transform.Transform()
    if len(joints) > 0:
        root_frame.children = root_frame.children + [
            frame.Frame(link=frame.Link(name=root_frame.link.name + "_child"), joint=joint_to_joint(joints[0], base))
        ]
        # Recurse without `base`: only the first joint is placed relative to it.
        ret, offset = add_composite_joint(root_frame.children[-1], joints[1:])
        return ret, root_frame.joint.offset * offset
    else:
        return root_frame, root_frame.joint.offset
50 |
51 |
def _build_chain_recurse(root_frame, root_body):
    """Expand `root_body` (joints, geoms and child bodies) under `root_frame`, recursively.

    The body's joints become a line of frames (see ``add_composite_joint``). Its
    geoms become visuals of the deepest of those frames. Each child body becomes
    a frame named "<body name>_frame" attached to that deepest frame.
    """
    link_offset = root_frame.link.offset
    tail_frame, joint_offset = add_composite_joint(root_frame, root_body.joint, link_offset)
    # Offset applied to geoms and child bodies: the inverse of the accumulated
    # joint offset composed with the link offset.
    relative = joint_offset.inverse() * link_offset
    geom_base = relative if len(root_body.joint) > 0 else None
    tail_frame.link.visuals = geoms_to_visuals(root_body.geom, geom_base)
    for child_body in root_body.body:
        child_frame = frame.Frame()
        tail_frame.children = tail_frame.children + [child_frame]
        child_frame.name = child_body.name + "_frame"
        child_frame.link = body_to_link(child_body, relative)
        _build_chain_recurse(child_frame, child_body)
66 |
67 |
def build_chain_from_mjcf(data: Union[str, TextIO]) -> chain.Chain:
    """
    Build a Chain object from MJCF data.

    Only the first body of ``<worldbody>`` (and its descendants) is converted.
    It becomes a frame named "<body name>_frame" with a fixed joint.

    Parameters
    ----------
    data : str or TextIO
        MJCF string data or a text file object (read to the end).

    Returns
    -------
    chain.Chain
        Chain object created from MJCF.
    """
    xml_string = data.read() if isinstance(data, io.TextIOBase) else data
    model = mjcf_parser.from_xml_string(xml_string)
    first_body = model.worldbody.body[0]
    root = frame.Frame(first_body.name + "_frame", link=body_to_link(first_body), joint=frame.Joint())
    _build_chain_recurse(root, first_body)
    return chain.Chain(root)
89 |
90 |
def build_serial_chain_from_mjcf(
    data: Union[str, TextIO], end_link_name: str, root_link_name: str = ""
) -> chain.SerialChain:
    """
    Build a SerialChain object from MJCF data.

    Parameters
    ----------
    data : str or TextIO
        MJCF string data or a text file object (read to the end).
    end_link_name : str
        Name of the MJCF body used as the end effector. The chain ends at the
        frame "<end_link_name>_frame".
    root_link_name : str, optional
        Name of the MJCF body at which the chain starts. By default the chain
        starts at the root of the full MJCF chain.

    Returns
    -------
    chain.SerialChain
        SerialChain object created from MJCF.

    Raises
    ------
    ValueError
        If no frame exists for `root_link_name` or `end_link_name`.
    """
    # build_chain_from_mjcf already reads file objects, so no pre-read is needed here.
    mjcf_chain = build_chain_from_mjcf(data)
    root_frame_name = "" if root_link_name == "" else root_link_name + "_frame"
    return chain.SerialChain(mjcf_chain, end_link_name + "_frame", root_frame_name)
117 |
--------------------------------------------------------------------------------
/kinpy/mjcf_parser/__init__.py:
--------------------------------------------------------------------------------
1 | from .parser import *
2 |
--------------------------------------------------------------------------------
/kinpy/mjcf_parser/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Base class for all MJCF elements in the object model."""
17 |
18 | from __future__ import absolute_import, division, print_function
19 |
20 | import abc
21 |
22 | import six
23 |
24 |
@six.add_metaclass(abc.ABCMeta)
class Element(object):
    """Abstract base class for an MJCF element.

    This class is provided so that `isinstance(foo, Element)` is `True` for all
    Element-like objects. We do not implement the actual element here because
    the actual object returned from traversing the object hierarchy is a
    weakproxy-like proxy to an actual element. This is because we do not allow
    orphaned non-root elements, so when a particular element is removed from the
    tree, all references held automatically become invalid.
    """

    # Empty slots: this base contributes no instance attributes (and no
    # per-instance __dict__) to subclasses that declare their own __slots__.
    __slots__ = []
38 |
--------------------------------------------------------------------------------
/kinpy/mjcf_parser/constants.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Magic constants used within `dm_control.mjcf`."""
17 |
18 | from __future__ import absolute_import, division, print_function
19 |
PREFIX_SEPARATOR = "/"
PREFIX_SEPARATOR_ESCAPE = "\\"

# Used to disambiguate namespaces between attachment frames.
NAMESPACE_SEPARATOR = "@"

# Magic attribute names
BASEPATH = "basepath"
CHILDCLASS = "childclass"
CLASS = "class"
# "default" is both a magic attribute name and a magic tag name; defined once.
DEFAULT = "default"
DCLASS = "dclass"

# Magic tags (DEFAULT, defined above, also names the <default> tag)
ACTUATOR = "actuator"
BODY = "body"
MESH = "mesh"
SITE = "site"
TENDON = "tendon"
WORLDBODY = "worldbody"

MJDATA_TRIGGERS_DIRTY = ["qpos", "qvel", "act", "ctrl", "qfrc_applied", "xfrc_applied"]
MJMODEL_DOESNT_TRIGGER_DIRTY = ["rgba", "matid", "emission", "specular", "shininess", "reflectance"]

# When writing into `model.{body,geom,site}_{pos,quat}` we must ensure that the
# corresponding rows in `model.{body,geom,site}_sameframe` are set to zero,
# otherwise MuJoCo will use the body or inertial frame instead of our modified
# pos/quat values. We must do the same for `body_{ipos,iquat}` and
# `body_simple`.
MJMODEL_DISABLE_ON_WRITE = {
    # Field name in MjModel: (attribute names of Binding instance to be zeroed)
    "body_pos": ("sameframe",),
    "body_quat": ("sameframe",),
    "geom_pos": ("sameframe",),
    "geom_quat": ("sameframe",),
    "site_pos": ("sameframe",),
    "site_quat": ("sameframe",),
    "body_ipos": ("simple", "sameframe"),
    "body_iquat": ("simple", "sameframe"),
}

# This is the actual upper limit on VFS filename length, despite what it says
# in the header file (100) or the error message (99).
MAX_VFS_FILENAME_LENGTH = 98

# The prefix used in the schema to denote reference_namespace that are defined
# via another attribute.
INDIRECT_REFERENCE_NAMESPACE_PREFIX = "attrib:"
69 |
--------------------------------------------------------------------------------
/kinpy/mjcf_parser/copier.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Helper object for keeping track of new elements created when copying MJCF."""
17 |
18 | from __future__ import absolute_import, division, print_function
19 |
20 | from . import constants
21 |
22 |
class Copier(object):
    """Helper for keeping track of new elements created when copying MJCF."""

    def __init__(self, source):
        """Wrap `source`, the element to be copied.

        Raises
        ------
        NotImplementedError
            If `source` has attachments.
        """
        if source._attachments:  # pylint: disable=protected-access
            raise NotImplementedError("Cannot copy from elements with attachments")
        self._source = source

    def copy_into(self, destination, override_attributes=False):
        """Copies this copier's element into a destination MJCF element.

        Parameters
        ----------
        destination : element
            Element receiving the source's attributes and children.
        override_attributes : bool
            If True, the source's attributes are set on `destination` and
            identified children override existing ones. Otherwise attributes are
            synced via ``_sync_attributes(..., copying=True)``.

        Returns
        -------
        dict
            Maps each source descendant for which a new element was added to
            that new destination element (collected recursively).

        Raises
        ------
        ValueError
            If an element matched by identifier already has a different parent.
        """
        newly_created_elements = {}
        destination._check_valid_attachment(self._source)  # pylint: disable=protected-access
        if override_attributes:
            destination.set_attributes(**self._source.get_attributes())
        else:
            destination._sync_attributes(self._source, copying=True)  # pylint: disable=protected-access
        for source_child in self._source.all_children():
            dest_child = None
            # First, if source_child has an identifier, we look for an existing child
            # element of self with the same identifier to override.
            if source_child.spec.identifier and override_attributes:
                identifier_attr = source_child.spec.identifier
                # The "class" identifier is read through the "dclass" attribute.
                if identifier_attr == constants.CLASS:
                    identifier_attr = constants.DCLASS
                identifier = getattr(source_child, identifier_attr)
                if identifier:
                    dest_child = destination.find(source_child.spec.namespace, identifier)
                if dest_child is not None and dest_child.parent is not destination:
                    raise ValueError(
                        "<{}> with identifier {!r} is already a child of another element".format(
                            source_child.spec.namespace, identifier
                        )
                    )
            # Next, we cover the case where either the child is not a repeated element
            # or if source_child has an identifier attribute but it isn't set.
            if not source_child.spec.repeated and dest_child is None:
                dest_child = destination.get_children(source_child.tag)

            # Add a new element if dest_child doesn't exist, either because it is
            # supposed to be a repeated child, or because it's an uncreated on-demand.
            if dest_child is None:
                dest_child = destination.add(source_child.tag, **source_child.get_attributes())
                newly_created_elements[source_child] = dest_child
                override_child_attributes = True
            else:
                override_child_attributes = override_attributes

            # Finally, copy attributes into dest_child.
            child_copier = Copier(source_child)
            newly_created_elements.update(child_copier.copy_into(dest_child, override_child_attributes))
        return newly_created_elements
74 |
--------------------------------------------------------------------------------
/kinpy/mjcf_parser/debugging.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Implements PyMJCF debug mode.
17 |
18 | PyMJCF debug mode stores a stack trace each time the MJCF object is modified.
19 | If Mujoco raises a compile error on the generated XML model, we would then be
20 | able to find the original source line that created the offending element.
21 | """
22 |
23 | from __future__ import absolute_import, division, print_function
24 |
25 | import collections
26 | import contextlib
27 | import copy
28 | import os
29 | import re
30 | import sys
31 | import traceback
32 |
33 | import six
34 | from absl import flags
35 | from lxml import etree
36 |
FLAGS = flags.FLAGS
flags.DEFINE_boolean(
    "pymjcf_debug",
    False,
    "Enables PyMJCF debug mode (SLOW!). In this mode, a stack trace is logged "
    "each the MJCF object is modified. This may be helpful in locating the "
    "Python source line corresponding to a problematic element in the "
    "generated XML.",
)
flags.DEFINE_string("pymjcf_debug_full_dump_dir", "", "Path to dump full debug info when Mujoco error is encountered.")

StackTraceEntry = collections.namedtuple("StackTraceEntry", ("filename", "line_number", "function_name", "text"))

ElementDebugInfo = collections.namedtuple("ElementDebugInfo", ("element", "init_stack", "attribute_stacks"))

# Directory of this package; stack frames from files under it are treated as internal.
MODULE_PATH = os.path.dirname(sys.modules[__name__].__file__)
DEBUG_METADATA_PREFIX = "pymjcfdebug"

# Debug metadata takes the form of XML comments "<!--pymjcfdebug:N-->". The tag
# prefix opens such a comment, and the search pattern captures N.
_DEBUG_METADATA_TAG_PREFIX = "<!--" + DEBUG_METADATA_PREFIX
_DEBUG_METADATA_SEARCH_PATTERN = re.compile(r"<!--{}:(\d+)-->".format(DEBUG_METADATA_PREFIX))

# Modified by `freeze_current_stack_trace`.
_CURRENT_FROZEN_STACK = None

# These globals will take their default values from the `--pymjcf_debug` and
# `--pymjcf_debug_full_dump_dir` flags respectively. We cannot use `FLAGS` as
# global variables because flag parsing might not have taken place (e.g. when
# running `nosetests`).
_DEBUG_MODE_ENABLED = None
_DEBUG_FULL_DUMP_DIR = None
67 |
68 |
def _flag_value_or_default(flag_name):
    """Return the parsed value of `flag_name`, or its declared default if flags are not parsed yet."""
    if FLAGS.is_parsed():
        return getattr(FLAGS, flag_name)
    return FLAGS[flag_name].default


def debug_mode():
    """Report whether PyMJCF debug mode is on.

    The first call initialises the setting from `--pymjcf_debug`, unless it was
    already set by `enable_debug_mode` / `disable_debug_mode`.
    """
    global _DEBUG_MODE_ENABLED
    if _DEBUG_MODE_ENABLED is None:
        _DEBUG_MODE_ENABLED = _flag_value_or_default("pymjcf_debug")
    return _DEBUG_MODE_ENABLED


def enable_debug_mode():
    """Turn PyMJCF debug mode on, regardless of the flag value."""
    global _DEBUG_MODE_ENABLED
    _DEBUG_MODE_ENABLED = True


def disable_debug_mode():
    """Turn PyMJCF debug mode off, regardless of the flag value."""
    global _DEBUG_MODE_ENABLED
    _DEBUG_MODE_ENABLED = False


def get_full_dump_dir():
    """Return the full-debug-info dump directory.

    The first call initialises it from `--pymjcf_debug_full_dump_dir`, unless it
    was already set.
    """
    global _DEBUG_FULL_DUMP_DIR
    if _DEBUG_FULL_DUMP_DIR is None:
        _DEBUG_FULL_DUMP_DIR = _flag_value_or_default("pymjcf_debug_full_dump_dir")
    return _DEBUG_FULL_DUMP_DIR


def set_full_dump_dir(dump_path):
    """Override the full-debug-info dump directory with `dump_path`."""
    global _DEBUG_FULL_DUMP_DIR
    _DEBUG_FULL_DUMP_DIR = dump_path
107 |
108 |
def get_current_stack_trace():
    """Returns the stack trace of the current execution frame.

    If a non-empty stack was frozen by `freeze_current_stack_trace`, a deep copy
    of it is returned instead of a freshly captured one.

    Returns:
      A list of `StackTraceEntry` named tuples corresponding to the current stack
      trace of the process, truncated to immediately before entry into
      PyMJCF internal code.
    """
    frozen = _CURRENT_FROZEN_STACK
    return copy.deepcopy(frozen) if frozen else _get_actual_current_stack_trace()
121 |
122 |
def _get_actual_current_stack_trace():
    """Returns the stack trace of the current execution frame.

    Frames are taken outermost first. Collection stops at the first frame whose
    file lies under `MODULE_PATH`, unless that file ends in "_test.py".

    Returns:
      A list of `StackTraceEntry` named tuples corresponding to the current stack
      trace of the process, truncated to immediately before entry into
      PyMJCF internal code.
    """
    entries = []
    for summary in traceback.extract_stack():
        entry = StackTraceEntry(*summary)
        is_internal = entry.filename.startswith(MODULE_PATH) and not entry.filename.endswith("_test.py")
        if is_internal:
            break
        entries.append(entry)
    return entries
140 |
141 |
@contextlib.contextmanager
def freeze_current_stack_trace():
    """A context manager that freezes the stack trace.

    AVOID USING THIS CONTEXT MANAGER OUTSIDE OF INTERNAL PYMJCF IMPLEMENTATION,
    AS IT REDUCES THE USEFULNESS OF DEBUG MODE.

    If PyMJCF debug mode is enabled, calls to `debugging.get_current_stack_trace`
    within this context will always return the stack trace from when this context
    was entered.

    The frozen stack is global to this debugging module. That is, if the context
    is entered while another one is still active, then the stack trace of the
    outermost one is returned.

    This context significantly speeds up bulk operations in debug mode, e.g.
    parsing an existing XML string or creating a deeply-nested element, as it
    prevents the same stack trace from being repeatedly constructed.

    The outermost context always unfreezes the stack on exit, including when
    the body raises an exception.

    Yields:
      `None`
    """
    global _CURRENT_FROZEN_STACK
    if debug_mode() and _CURRENT_FROZEN_STACK is None:
        _CURRENT_FROZEN_STACK = _get_actual_current_stack_trace()
        try:
            yield
        finally:
            # Unfreeze even if the body raised. Otherwise a stale stack would be
            # reported for every element created after the failure.
            _CURRENT_FROZEN_STACK = None
    else:
        # Either debug mode is off, or an outer context already froze the stack.
        yield
171 |
172 |
class DebugContext(object):
    """A helper object to store debug information for a generated XML string.

    This class is intended for internal use within the PyMJCF implementation.
    Elements are registered while the XML is generated, and the final XML
    string is then committed. If MuJoCo later reports a compile error, that
    error can be re-raised with the offending source line and, in debug mode,
    the recorded creation and modification stacks.
    """

    def __init__(self):
        # The XML string most recently passed to `commit_xml_string`, or `None`.
        self._xml_string = None
        # Maps `id(element)` to the `ElementDebugInfo` recorded for that element.
        self._debug_info_for_element_ids = {}

    def register_element_for_debugging(self, elem):
        """Registers an `Element` and returns debugging metadata for the XML.

        Nothing is recorded unless debug mode is enabled.

        Args:
          elem: An `mjcf.Element`.

        Returns:
          An `lxml.etree.Comment` that represents debugging metadata in the
          generated XML, or `None` if debug mode is disabled.
        """
        if not debug_mode():
            return None
        self._debug_info_for_element_ids[id(elem)] = ElementDebugInfo(
            elem,
            copy.deepcopy(elem.get_init_stack()),
            copy.deepcopy(elem.get_last_modified_stacks_for_all_attributes()),
        )
        return etree.Comment("{}:{}".format(DEBUG_METADATA_PREFIX, id(elem)))

    def commit_xml_string(self, xml_string):
        """Commits the XML string associated with this debug context.

        This function also formats the XML string to make sure that the debugging
        metadata appears on the same line as the corresponding XML element.

        Args:
          xml_string: A pretty-printed XML string.

        Returns:
          A reformatted XML string where all debugging metadata appears on the same
          line as the corresponding XML element.
        """
        formatted = re.sub(r"\n\s*" + _DEBUG_METADATA_TAG_PREFIX, _DEBUG_METADATA_TAG_PREFIX, xml_string)
        self._xml_string = formatted
        return formatted

    def _get_xml_line_for_error(self, error_text):
        """Returns the committed XML line that an error message refers to.

        Args:
          error_text: The error message. It is searched for a "line N" reference.

        Returns:
          Line N (1-based) of the committed XML string. Returns "" if any of the
          following holds:
            * the message has no line reference;
            * no XML string has been committed;
            * N falls outside the committed XML, for instance when the number
              belongs to a different file.
        """
        line_number_match = re.search(r"[Ll][Ii][Nn][Ee]\s*[:=]?\s*(\d+)", error_text)
        if not line_number_match or self._xml_string is None:
            return ""
        xml_lines = self._xml_string.split("\n")
        xml_line_number = int(line_number_match.group(1))
        if not 1 <= xml_line_number <= len(xml_lines):
            return ""
        return xml_lines[xml_line_number - 1]

    def process_and_raise_last_exception(self):
        """Processes and re-raises the last mujoco.wrapper.Error caught.

        This function will insert the relevant line from the source XML to the error
        message. If debug mode is enabled, additional debugging information is
        appended to the error message. If debug mode is not enabled, the error
        message instructs the user to enable it by rerunning the executable with an
        appropriate flag.

        Failures while locating the source line or while dumping debug info to
        disk never replace the original error; the error is always re-raised
        with its original type.
        """
        err_type, err, stack = sys.exc_info()
        xml_line = self._get_xml_line_for_error(str(err))
        stripped_xml_line = xml_line.strip()
        # Cut the line at the debug-metadata comment that `commit_xml_string`
        # moved onto it.
        comment_match = re.search(_DEBUG_METADATA_TAG_PREFIX, stripped_xml_line)
        if comment_match:
            stripped_xml_line = stripped_xml_line[: comment_match.start()]

        message_lines = []
        if debug_mode():
            dump_error = None
            if get_full_dump_dir():
                try:
                    self.dump_full_debug_info_to_disk()
                except (IOError, OSError) as e:
                    # Report the failure instead of raising, so the compile error survives.
                    dump_error = e
            message_lines.extend(["Compile error raised by Mujoco.", str(err)])
            if xml_line:
                message_lines.extend([stripped_xml_line, self._generate_debug_message_from_xml_line(xml_line)])
            if dump_error is not None:
                message_lines.append("Failed to dump full debug info to disk: {}".format(dump_error))
        else:
            message_lines.extend(
                [
                    "Compile error raised by Mujoco; "
                    "run again with --pymjcf_debug for additional debug information.",
                    str(err),
                ]
            )
            if xml_line:
                message_lines.append(stripped_xml_line)

        message = "\n".join(message_lines)
        try:
            annotated_err = err_type(message)
        except Exception:  # pylint: disable=broad-except
            # Some exception types cannot be built from a single message; re-raise
            # the original instead of masking it with a TypeError.
            annotated_err = err
        six.reraise(err_type, annotated_err, stack)

    @property
    def default_dump_dir(self):
        """The dump directory used when none is passed explicitly."""
        return get_full_dump_dir()

    @property
    def debug_mode(self):
        """Whether PyMJCF debug mode is currently enabled."""
        return debug_mode()

    def dump_full_debug_info_to_disk(self, dump_dir=None):
        """Dumps full debug information to disk.

        Full debug information consists of an XML file whose elements are tagged
        with a unique ID, and a stack trace file for each element ID. Each stack
        trace file consists of a stack trace for when the element was created, and
        when each attribute was last modified.

        The files are written to `dump_dir`:
          * `model.xml` contains the committed XML;
          * `<element id>.dump` is written once per registered element.

        Args:
          dump_dir: Full path to the directory in which dump files are created.

        Raises:
          ValueError: If neither `dump_dir` nor the global dump path is given. The
            global dump path can be specified either via the
            --pymjcf_debug_full_dump_dir flag or via `debugging.set_full_dump_dir`.
        """
        dump_dir = dump_dir or self.default_dump_dir
        if not dump_dir:
            raise ValueError("`dump_dir` is not specified")
        section_separator = "\n" + ("=" * 80) + "\n"

        def dump_stack(header, stack, f):
            # Writes `header`, then one location line plus one source line per frame.
            indent = " "
            f.write(header + "\n")
            for stack_entry in stack:
                f.write(
                    indent
                    + "`{}` at {}:{}\n".format(
                        stack_entry.function_name, stack_entry.filename, stack_entry.line_number
                    )
                )
                f.write((indent * 2) + str(stack_entry.text) + "\n")
            f.write(section_separator)

        with open(os.path.join(dump_dir, "model.xml"), "w") as f:
            f.write(self._xml_string)
        for elem_id, debug_info in six.iteritems(self._debug_info_for_element_ids):
            with open(os.path.join(dump_dir, str(elem_id) + ".dump"), "w") as f:
                f.write("{}:{}\n".format(DEBUG_METADATA_PREFIX, elem_id))
                f.write(str(debug_info.element) + "\n")
                dump_stack("Element creation", debug_info.init_stack, f)
                for attrib_name, stack in six.iteritems(debug_info.attribute_stacks):
                    attrib_value = debug_info.element.get_attribute_xml_string(attrib_name)
                    # An attribute whose last modification shares the creation
                    # frame was set by the element's constructor call.
                    if stack[-1] == debug_info.init_stack[-1]:
                        if attrib_value is not None:
                            f.write('Attribute {}="{}"\n'.format(attrib_name, attrib_value))
                            f.write(" was set when the element was created\n")
                            f.write(section_separator)
                    else:
                        if attrib_value is not None:
                            dump_stack('Attribute {}="{}"'.format(attrib_name, attrib_value), stack, f)
                        else:
                            dump_stack("Attribute {} was CLEARED".format(attrib_name), stack, f)

    def _generate_debug_message_from_xml_line(self, xml_line):
        """Generates a debug message by parsing the metadata on an XML line.

        Returns "" if the line carries no metadata, or if its element ID was not
        registered with this context.
        """
        metadata_match = _DEBUG_METADATA_SEARCH_PATTERN.search(xml_line)
        if not metadata_match:
            return ""
        elem_id = int(metadata_match.group(1))
        if elem_id not in self._debug_info_for_element_ids:
            # There is nothing recorded for this ID. Avoid masking the compile
            # error with a KeyError.
            return ""
        return self._generate_debug_message_from_element_id(elem_id)

    def _generate_debug_message_from_element_id(self, elem_id):
        """Generates a debug message for the specified Element.

        Args:
          elem_id: An element ID that was registered with this context.

        Returns:
          A multi-line summary. It names where the element was created and where
          each attribute was last set or cleared.
        """
        out = []
        debug_info = self._debug_info_for_element_ids[elem_id]

        out.append("Debug summary for element:")
        if not get_full_dump_dir():
            out.append(
                " * Full debug info can be dumped to disk by setting the "
                "flag --pymjcf_debug_full_dump_dir=<path/to/dump>"
            )
        out.append(
            " * Element object was created by `{}` at {}:{}".format(
                debug_info.init_stack[-1].function_name,
                debug_info.init_stack[-1].filename,
                debug_info.init_stack[-1].line_number,
            )
        )

        for attrib_name, stack in six.iteritems(debug_info.attribute_stacks):
            attrib_value = debug_info.element.get_attribute_xml_string(attrib_name)
            if stack[-1] == debug_info.init_stack[-1]:
                if attrib_value is not None:
                    out.append(' * {}="{}" was set when the element was created'.format(attrib_name, attrib_value))
            else:
                if attrib_value is not None:
                    out.append(
                        ' * {}="{}" was set by `{}` at `{}:{}`'.format(
                            attrib_name,
                            attrib_value,
                            stack[-1].function_name,
                            stack[-1].filename,
                            stack[-1].line_number,
                        )
                    )
                else:
                    out.append(
                        " * {} was CLEARED by `{}` at {}:{}".format(
                            attrib_name, stack[-1].function_name, stack[-1].filename, stack[-1].line_number
                        )
                    )

        return "\n".join(out)
377 |
--------------------------------------------------------------------------------
/kinpy/mjcf_parser/io.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017-2018 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """IO functions."""
17 |
18 | from __future__ import absolute_import, division, print_function
19 |
20 |
def GetResource(name, mode="rb"):
    """Reads and returns the whole contents of the file at `name`.

    Args:
      name: Path of the file to read.
      mode: Mode passed to `open`. The default "rb" returns `bytes`.

    Returns:
      The file contents.
    """
    with open(name, mode=mode) as handle:
        contents = handle.read()
    return contents
24 |
25 |
def GetResourceFilename(name, mode="rb"):
    """Returns `name` unchanged, since resources here are plain file paths.

    `mode` is accepted only for interface compatibility and is ignored.
    """
    del mode  # Unused.
    return name
29 |
30 |
# Opens a resource as a file object. Resources are ordinary filesystem paths in
# this module (see `GetResource`), so this is simply the builtin `open`.
GetResourceAsFile = open  # pylint: disable=invalid-name
32 |
--------------------------------------------------------------------------------
/kinpy/mjcf_parser/namescope.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """An object to manage the scoping of identifiers in MJCF models."""
17 |
18 | from __future__ import absolute_import, division, print_function
19 |
20 | import collections
21 |
22 | import six
23 |
24 | from . import constants
25 |
26 |
class NameScope(object):
    """A name scoping context for an MJCF model.

    This object maintains the uniqueness of identifiers within each MJCF
    namespace. Examples of MJCF namespaces include 'body', 'joint', and 'geom'.
    Each namescope also carries a name, and can have a parent namescope.
    When MJCF models are merged, all identifiers gain a hierarchical prefix
    separated by '/', which is the concatenation of all scope names up to
    the root namescope.

    A child scope is registered in its parent under the 'namescope' namespace,
    keyed by the child's name.
    """

    def __init__(self, name, mjcf_model, model_dir="", assets=None):
        """Initializes a scope with the given name.

        Args:
          name: The scope's name
          mjcf_model: The RootElement of the MJCF model associated with this scope.
          model_dir: (optional) Path to the directory containing the model XML file.
            This is used to prefix the paths of all asset files.
          assets: (optional) A dictionary of pre-loaded assets, of the form
            `{filename: bytestring}`. If present, PyMJCF will search for assets in
            this dictionary before attempting to load them from the filesystem.
        """
        self._parent = None
        self._name = name
        self._mjcf_model = mjcf_model
        # Maps a namespace name to an `{identifier: object}` dict.
        self._namespaces = collections.defaultdict(dict)
        self._model_dir = model_dir
        self._files = set()
        self._assets = assets or {}
        self._revision = 0

    @property
    def revision(self):
        """A counter incremented whenever this scope or any ancestor scope changes."""
        return self._revision

    def increment_revision(self):
        """Increments the revision of this scope and, recursively, of every child scope."""
        self._revision += 1
        for namescope in six.itervalues(self._namespaces["namescope"]):
            namescope.increment_revision()

    @property
    def name(self):
        """This scope's name."""
        return self._name

    @name.setter
    def name(self, new_name):
        # Register the new name before dropping the old one, so a duplicate-name
        # ValueError leaves the parent unchanged. Re-assigning the current name
        # needs no re-registration; adding it again would spuriously raise a
        # duplicate-identifier error.
        if self._parent and new_name != self._name:
            self._parent.add("namescope", new_name, self)
            self._parent.remove("namescope", self._name)
        self._name = new_name
        self.increment_revision()

    @property
    def files(self):
        """A set containing the `File` attributes registered in this scope."""
        return self._files

    @property
    def assets(self):
        """A dictionary containing pre-loaded assets."""
        return self._assets

    @property
    def model_dir(self):
        """Path to the directory containing the model XML file."""
        return self._model_dir

    @property
    def mjcf_model(self):
        """The RootElement of the MJCF model associated with this scope."""
        return self._mjcf_model

    @property
    def parent(self):
        """This parent `NameScope`, or `None` if this is a root scope."""
        return self._parent

    @parent.setter
    def parent(self, new_parent):
        old_parent = self._parent
        if old_parent:
            old_parent.remove("namescope", self._name)
        if new_parent:
            try:
                new_parent.add("namescope", self._name, self)
            except ValueError:
                # For example, the new parent already has a child with this name.
                # Restore the previous registration so the scope tree stays consistent.
                if old_parent:
                    old_parent.add("namescope", self._name, self)
                raise
        self._parent = new_parent
        self.increment_revision()

    @property
    def root(self):
        """The outermost ancestor scope, or this scope itself if it has no parent."""
        if self._parent is None:
            return self
        else:
            return self._parent.root

    def full_prefix(self, prefix_root=None, as_list=False):
        """The prefix for identifiers belonging to this scope.

        Args:
          prefix_root: (optional) A `NameScope` object to be treated as root
            for the purpose of calculating the prefix. If `None` then no prefix
            is produced.
          as_list: (optional) A boolean, if `True` return the list of prefix
            components. If `False`, return the full prefix string separated by
            `mjcf.constants.PREFIX_SEPARATOR`.

        Returns:
          The prefix string, including a trailing separator when non-empty.
          If `as_list` is `True`, returns the list of scope names from just
          below `prefix_root` down to this scope instead.
        """
        prefix_root = prefix_root or self
        if prefix_root != self and self._parent:
            prefix_list = self._parent.full_prefix(prefix_root, as_list=True)
            prefix_list.append(self._name)
        else:
            prefix_list = []
        if as_list:
            return prefix_list
        else:
            if prefix_list:
                # An empty final component makes `join` emit a trailing separator.
                prefix_list.append("")
            return constants.PREFIX_SEPARATOR.join(prefix_list)

    def _assign(self, namespace, identifier, obj):
        """Checks a proposed identifier's validity before assigning to an object."""
        namespace_dict = self._namespaces[namespace]
        if not isinstance(identifier, str):
            raise ValueError("Identifier must be a string: got {}".format(type(identifier)))
        elif constants.PREFIX_SEPARATOR in identifier:
            raise ValueError("Identifier cannot contain {!r}: got {}".format(constants.PREFIX_SEPARATOR, identifier))
        else:
            namespace_dict[identifier] = obj

    def add(self, namespace, identifier, obj):
        """Add an identifier to this name scope.

        Args:
          namespace: A string specifying the namespace to which the
            identifier belongs.
          identifier: The identifier string.
          obj: The object referred to by the identifier.

        Raises:
          ValueError: If `identifier` not valid or already present in `namespace`.
        """
        namespace_dict = self._namespaces[namespace]
        if identifier in namespace_dict:
            raise ValueError("Duplicated identifier {!r} in namespace <{}>".format(identifier, namespace))
        else:
            self._assign(namespace, identifier, obj)
            self.increment_revision()

    def replace(self, namespace, identifier, obj):
        """Reassociates an identifier with a different object.

        Args:
          namespace: A string specifying the namespace to which the
            identifier belongs.
          identifier: The identifier string.
          obj: The object referred to by the identifier.

        Raises:
          ValueError: If `identifier` not valid.
        """
        self._assign(namespace, identifier, obj)
        self.increment_revision()

    def remove(self, namespace, identifier):
        """Removes an identifier from this name scope.

        Args:
          namespace: A string specifying the namespace to which the
            identifier belongs.
          identifier: The identifier string.

        Raises:
          KeyError: If `identifier` does not exist in this scope.
        """
        del self._namespaces[namespace][identifier]
        self.increment_revision()

    def rename(self, namespace, old_identifier, new_identifier):
        """Re-registers the object known as `old_identifier` under `new_identifier`.

        Renaming an identifier to itself is a no-op.

        Raises:
          KeyError: If `old_identifier` does not exist in `namespace`.
          ValueError: If `new_identifier` is invalid or already taken.
        """
        obj = self.get(namespace, old_identifier)
        if new_identifier == old_identifier:
            return
        self.add(namespace, new_identifier, obj)
        self.remove(namespace, old_identifier)

    def get(self, namespace, identifier):
        """Returns the object registered as `identifier` in `namespace`.

        Raises:
          KeyError: If `identifier` is not registered in `namespace`.
        """
        return self._namespaces[namespace][identifier]

    def has_identifier(self, namespace, identifier):
        """Returns whether `identifier` is registered in `namespace`."""
        return identifier in self._namespaces[namespace]
216 |
--------------------------------------------------------------------------------
/kinpy/mjcf_parser/parser.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Functions for parsing XML into an MJCF object model."""
17 |
18 | from __future__ import absolute_import, division, print_function
19 |
20 | import os
21 | import sys
22 |
23 | import six
24 | from lxml import etree
25 |
26 | from . import constants, debugging, element
27 | from . import io as resources
28 |
29 |
def from_xml_string(xml_string, escape_separators=False, model_dir="", resolve_references=True, assets=None):
    """Builds an MJCF object model from an in-memory XML document.

    Args:
      xml_string: An XML string representing an MJCF model.
      escape_separators: (optional) A boolean, whether to replace '/' characters
        in element identifiers. If `False`, any '/' present in the XML causes
        a ValueError to be raised.
      model_dir: (optional) Path to the directory containing the model XML file.
        This is used to prefix the paths of all asset files.
      resolve_references: (optional) A boolean indicating whether the parser
        should attempt to resolve reference attributes to a corresponding element.
      assets: (optional) A dictionary of pre-loaded assets, of the form
        `{filename: bytestring}`. If present, PyMJCF will search for assets in
        this dictionary before attempting to load them from the filesystem.

    Returns:
      An `mjcf.RootElement`.
    """
    return _parse(
        etree.fromstring(xml_string),
        escape_separators,
        model_dir=model_dir,
        resolve_references=resolve_references,
        assets=assets,
    )
53 |
54 |
def from_file(file_handle, escape_separators=False, model_dir="", resolve_references=True, assets=None):
    """Builds an MJCF object model from an open XML file.

    Args:
      file_handle: A Python file-like handle.
      escape_separators: (optional) A boolean, whether to replace '/' characters
        in element identifiers. If `False`, any '/' present in the XML causes
        a ValueError to be raised.
      model_dir: (optional) Path to the directory containing the model XML file.
        This is used to prefix the paths of all asset files.
      resolve_references: (optional) A boolean indicating whether the parser
        should attempt to resolve reference attributes to a corresponding element.
      assets: (optional) A dictionary of pre-loaded assets, of the form
        `{filename: bytestring}`. If present, PyMJCF will search for assets in
        this dictionary before attempting to load them from the filesystem.

    Returns:
      An `mjcf.RootElement`.
    """
    document = etree.parse(file_handle)
    return _parse(
        document.getroot(),
        escape_separators,
        model_dir=model_dir,
        resolve_references=resolve_references,
        assets=assets,
    )
78 |
79 |
def from_path(path, escape_separators=False, resolve_references=True, assets=None):
    """Builds an MJCF object model from an XML file on disk.

    The directory containing `path` becomes the model directory, so asset
    paths in the model are resolved relative to the XML file.

    Args:
      path: A path to an XML file. This path should be loadable using
        `resources.GetResource`.
      escape_separators: (optional) A boolean, whether to replace '/' characters
        in element identifiers. If `False`, any '/' present in the XML causes
        a ValueError to be raised.
      resolve_references: (optional) A boolean indicating whether the parser
        should attempt to resolve reference attributes to a corresponding element.
      assets: (optional) A dictionary of pre-loaded assets, of the form
        `{filename: bytestring}`. If present, PyMJCF will search for assets in
        this dictionary before attempting to load them from the filesystem.

    Returns:
      An `mjcf.RootElement`.
    """
    model_dir = os.path.split(path)[0]
    raw_xml = resources.GetResource(path)
    return _parse(
        etree.fromstring(raw_xml),
        escape_separators,
        model_dir=model_dir,
        resolve_references=resolve_references,
        assets=assets,
    )
104 |
105 |
106 | def _parse(xml_root, escape_separators=False, model_dir="", resolve_references=True, assets=None):
107 | """Parses a complete MJCF model from an XML.
108 |
109 | Args:
110 | xml_root: An `etree.Element` object.
111 | escape_separators: (optional) A boolean, whether to replace '/' characters
112 | in element identifiers. If `False`, any '/' present in the XML causes
113 | a ValueError to be raised.
114 | model_dir: (optional) Path to the directory containing the model XML file.
115 | This is used to prefix the paths of all asset files.
116 | resolve_references: (optional) A boolean indicating whether the parser
117 | should attempt to resolve reference attributes to a corresponding element.
118 | assets: (optional) A dictionary of pre-loaded assets, of the form
119 | `{filename: bytestring}`. If present, PyMJCF will search for assets in
120 | this dictionary before attempting to load them from the filesystem.
121 |
122 | Returns:
123 | An `mjcf.RootElement`.
124 |
125 | Raises:
126 | ValueError: If `xml_root`'s tag is not 'mujoco'.
127 | """
128 |
129 | assets = assets or {}
130 |
131 | if xml_root.tag != "mujoco":
132 | raise ValueError("Root element of the XML should be : got <{}>".format(xml_root.tag))
133 |
134 | with debugging.freeze_current_stack_trace():
135 | # Recursively parse any included XML files.
136 | to_include = []
137 | for include_tag in xml_root.findall("include"):
138 | try:
139 | # First look for the path to the included XML file in the assets dict.
140 | path_or_xml_string = assets[include_tag.attrib["file"]]
141 | parsing_func = from_xml_string
142 | except KeyError:
143 | # If it's not present in the assets dict then attempt to load the XML
144 | # from the filesystem.
145 | path_or_xml_string = os.path.join(model_dir, include_tag.attrib["file"])
146 | parsing_func = from_path
147 | included_mjcf = parsing_func(
148 | path_or_xml_string,
149 | escape_separators=escape_separators,
150 | resolve_references=resolve_references,
151 | assets=assets,
152 | )
153 | to_include.append(included_mjcf)
154 | # We must remove tags before parsing the main XML file, since
155 | # these are a schema violation.
156 | xml_root.remove(include_tag)
157 |
158 | # Parse the main XML file.
159 | try:
160 | model = xml_root.attrib.pop("model")
161 | except KeyError:
162 | model = None
163 | mjcf_root = element.RootElement(model=model, model_dir=model_dir, assets=assets)
164 | _parse_children(xml_root, mjcf_root, escape_separators)
165 |
166 | # Merge in the included XML files.
167 | for included_mjcf in to_include:
168 | # The included MJCF might have been automatically assigned a model name
169 | # that conficts with that of `mjcf_root`, so we override it here.
170 | included_mjcf.model = mjcf_root.model
171 | mjcf_root.include_copy(included_mjcf)
172 |
173 | if resolve_references:
174 | mjcf_root.resolve_references()
175 | return mjcf_root
176 |
177 |
178 | def _parse_children(xml_element, mjcf_element, escape_separators=False):
179 | """Parses all children of a given XML element into an MJCF element.
180 |
181 | Args:
182 | xml_element: The source `etree.Element` object.
183 | mjcf_element: The target `mjcf.Element` object.
184 | escape_separators: (optional) A boolean, whether to replace '/' characters
185 | in element identifiers. If `False`, any '/' present in the XML causes
186 | a ValueError to be raised.
187 | """
188 | for xml_child in xml_element:
189 | if xml_child.tag is etree.Comment or xml_child.tag is etree.PI:
190 | continue
191 | try:
192 | child_spec = mjcf_element.spec.children[xml_child.tag]
193 | if escape_separators:
194 | attributes = {}
195 | for name, value in six.iteritems(xml_child.attrib):
196 | new_value = value.replace(constants.PREFIX_SEPARATOR_ESCAPE, constants.PREFIX_SEPARATOR_ESCAPE * 2)
197 | new_value = new_value.replace(constants.PREFIX_SEPARATOR, constants.PREFIX_SEPARATOR_ESCAPE)
198 | attributes[name] = new_value
199 | else:
200 | attributes = dict(xml_child.attrib)
201 | if child_spec.repeated or child_spec.on_demand:
202 | mjcf_child = mjcf_element.add(xml_child.tag, **attributes)
203 | else:
204 | mjcf_child = getattr(mjcf_element, xml_child.tag)
205 | mjcf_child.set_attributes(**attributes)
206 | except: # pylint: disable=bare-except
207 | err_type, err, traceback = sys.exc_info()
208 | message = "Line {}: error while parsing element <{}>: {}".format(xml_child.sourceline, xml_child.tag, err)
209 | six.reraise(err_type, err_type(message), traceback)
210 | _parse_children(xml_child, mjcf_child, escape_separators)
211 |
--------------------------------------------------------------------------------
/kinpy/mjcf_parser/schema.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """A Python object representation of Mujoco's MJCF schema.
17 |
18 | The root schema is provided as a module-level constant `schema.MUJOCO`.
19 | """
20 |
21 | from __future__ import absolute_import, division, print_function
22 |
23 | import collections
24 | import copy
25 | import os
26 | import pkgutil
27 |
28 | import six
29 | from lxml import etree
30 |
31 | from . import attribute
32 | from . import io as resources
33 |
# Location of the bundled schema XML. It is a resource path within the `kinpy`
# package, matching how `parse_schema` loads it via `pkgutil.get_data`.
_SCHEMA_XML_PATH = "mjcf_parser/schema.xml"

# Maps the schema's `array_type` values to Python element dtypes.
_ARRAY_DTYPE_MAP = dict(int=int, float=float, string=str)

# Maps scalar schema attribute types to the attribute classes that model them.
_SCALAR_TYPE_MAP = dict(int=attribute.Integer, float=attribute.Float, string=attribute.String)

# Parsed description of one schema element.
ElementSpec = collections.namedtuple(
    "ElementSpec", "name repeated on_demand identifier namespace attributes children"
)

# Parsed description of one schema attribute.
AttributeSpec = collections.namedtuple(
    "AttributeSpec", "name type required conflict_allowed conflict_behavior other_kwargs"
)

# Namespaces that are absent from the MJCF schema but that `find` and
# `find_all` still accept.
_ADDITIONAL_FINDABLE_NAMESPACES = frozenset({"attachment_frame"})
51 |
52 |
53 | def _str2bool(string):
54 | """Converts either 'true' or 'false' (not case-sensitively) into a boolean."""
55 | if string is None:
56 | return False
57 | else:
58 | string = string.lower()
59 |
60 | if string == "true":
61 | return True
62 | elif string == "false":
63 | return False
64 | else:
65 | raise ValueError("String should either be `true` or `false`: got {}".format(string))
66 |
67 |
def parse_schema(schema_path):
    """Loads a schema XML bundled with `kinpy` and parses its root element.

    Args:
      schema_path: Path of the schema XML file, relative to the `kinpy` package.

    Returns:
      The `ElementSpec` describing the root element in the schema.
    """
    schema_bytes = pkgutil.get_data("kinpy", schema_path)
    root_xml = etree.fromstring(schema_bytes)
    return _parse_element(root_xml)
79 |
80 |
def _parse_element(element_xml):
    """Parses an element in the schema, including all of its children.

    Args:
      element_xml: The `<element>` XML node to parse.

    Returns:
      An `ElementSpec` describing the element.

    Raises:
      ValueError: If the element has no name.
      RuntimeError: If an attribute and a child element share the same name.
    """
    name = element_xml.get("name")
    if not name:
        raise ValueError("Element must always have a name")
    repeated = _str2bool(element_xml.get("repeated"))
    on_demand = _str2bool(element_xml.get("on_demand"))

    attributes_xml = element_xml.find("attributes")
    attribute_nodes = [] if attributes_xml is None else attributes_xml.findall("attribute")
    attributes = collections.OrderedDict((node.get("name"), _parse_attribute(node)) for node in attribute_nodes)

    # The last identifier-typed attribute names the element; only identifiable
    # elements get a namespace (explicit, or defaulting to the element name).
    identifier_names = [spec.name for spec in attributes.values() if spec.type == attribute.Identifier]
    if identifier_names:
        identifier = identifier_names[-1]
        namespace = element_xml.get("namespace") or name
    else:
        identifier = None
        namespace = None

    children_xml = element_xml.find("children")
    child_nodes = [] if children_xml is None else children_xml.findall("element")
    children = collections.OrderedDict((node.get("name"), _parse_element(node)) for node in child_nodes)

    element_spec = ElementSpec(name, repeated, on_demand, identifier, namespace, attributes, children)

    # A recursive element may contain itself, so it is registered as its own child.
    if _str2bool(element_xml.get("recursive")):
        element_spec.children[name] = element_spec

    clashing_names = set(element_spec.attributes).intersection(element_spec.children)
    if clashing_names:
        raise RuntimeError(
            "Element '{}' contains the following attributes and children with "
            "the same name: '{}'. This violates the design assumptions of "
            "this library. Please file a bug report. Thank you.".format(name, sorted(clashing_names))
        )

    return element_spec
123 |
124 |
def _parse_attribute(attribute_xml):
    """Parses an attribute in the schema.

    Args:
      attribute_xml: The `<attribute>` XML node describing one attribute.

    Returns:
      An `AttributeSpec` for the attribute.

    Raises:
      ValueError: If the attribute type is unknown, or a `keyword` attribute
        does not declare its `valid_values`.
    """
    name = attribute_xml.get("name")
    required = _str2bool(attribute_xml.get("required"))
    conflict_allowed = _str2bool(attribute_xml.get("conflict_allowed"))
    conflict_behavior = attribute_xml.get("conflict_behavior", "replace")
    attribute_type = attribute_xml.get("type")
    other_kwargs = {}
    if attribute_type == "keyword":
        attribute_callable = attribute.Keyword
        valid_values = attribute_xml.get("valid_values")
        if valid_values is None:
            # Without this check the failure is an opaque AttributeError on `None.split`.
            raise ValueError("Keyword attribute '{}' must specify `valid_values`".format(name))
        other_kwargs["valid_values"] = valid_values.split(" ")
    elif attribute_type == "array":
        array_size_str = attribute_xml.get("array_size")
        attribute_callable = attribute.Array
        # A missing/empty `array_size` means the array length is unconstrained.
        other_kwargs["length"] = int(array_size_str) if array_size_str else None
        other_kwargs["dtype"] = _ARRAY_DTYPE_MAP[attribute_xml.get("array_type")]
    elif attribute_type == "identifier":
        attribute_callable = attribute.Identifier
    elif attribute_type == "reference":
        attribute_callable = attribute.Reference
        # Without an explicit namespace, the attribute's own name is used as the namespace.
        other_kwargs["reference_namespace"] = attribute_xml.get("reference_namespace") or name
    elif attribute_type == "basepath":
        attribute_callable = attribute.BasePath
        other_kwargs["path_namespace"] = attribute_xml.get("path_namespace")
    elif attribute_type == "file":
        attribute_callable = attribute.File
        other_kwargs["path_namespace"] = attribute_xml.get("path_namespace")
    else:
        try:
            attribute_callable = _SCALAR_TYPE_MAP[attribute_type]
        except KeyError:
            raise ValueError("Invalid attribute type: {}".format(attribute_type))

    return AttributeSpec(
        name=name,
        type=attribute_callable,
        required=required,
        conflict_allowed=conflict_allowed,
        conflict_behavior=conflict_behavior,
        other_kwargs=other_kwargs,
    )
166 |
167 |
def collect_namespaces(root_spec):
    """Gathers the namespace of every element spec reachable from `root_spec`.

    Args:
      root_spec: An `ElementSpec` for the root element in the schema.

    Returns:
      A set of strings naming every namespace present in the spec tree.
    """
    namespaces = set()

    def _visit(spec):
        namespaces.add(spec.namespace)
        for child in spec.children.values():
            # Recursive elements list themselves as a child; skip that self-edge
            # so the traversal terminates.
            if child is spec:
                continue
            _visit(child)

    _visit(root_spec)
    return namespaces
188 |
189 |
# Root `ElementSpec` of the schema, parsed once at import time.
MUJOCO = parse_schema(_SCHEMA_XML_PATH)
# Every namespace that can be searched: those declared by the schema plus the
# extra ones listed in `_ADDITIONAL_FINDABLE_NAMESPACES`.
FINDABLE_NAMESPACES = frozenset(collect_namespaces(MUJOCO)).union(_ADDITIONAL_FINDABLE_NAMESPACES)
192 |
193 |
def _attachment_frame_spec(is_world_attachment):
    """Create specs for attachment frames.

    Attachment frames are specialized `body` elements without an identifier
    (the `name` and `childclass` attributes are dropped). Their only allowed
    children are an `inertial` element and `joint` elements, which also don't
    have identifiers; attachments to the worldbody may additionally hold a
    `freejoint`.

    Args:
      is_world_attachment: Whether we are creating a spec for attachments to
        worldbody. If `True`, allow `freejoint` as child.

    Returns:
      An `ElementSpec`.
    """
    frame_spec = ElementSpec(
        "body",
        repeated=True,
        on_demand=False,
        identifier=None,
        namespace="body",
        attributes=collections.OrderedDict(),
        children=collections.OrderedDict(),
    )

    body_spec = MUJOCO.children["worldbody"].children["body"]
    # 'name' and 'childclass' attributes are excluded.
    # Deep copies keep the frame spec from sharing attribute specs with MUJOCO.
    for attrib_name in ("mocap", "pos", "quat", "axisangle", "xyaxes", "zaxis", "euler"):
        frame_spec.attributes[attrib_name] = copy.deepcopy(body_spec.attributes[attrib_name])

    frame_spec.children["inertial"] = copy.deepcopy(body_spec.children["inertial"])
    frame_spec.children["joint"] = ElementSpec(
        "joint",
        repeated=True,
        on_demand=False,
        identifier=None,
        namespace="joint",
        attributes=copy.deepcopy(body_spec.children["joint"].attributes),
        children=collections.OrderedDict(),
    )

    if is_world_attachment:
        frame_spec.children["freejoint"] = ElementSpec(
            "freejoint",
            repeated=False,
            on_demand=True,
            identifier=None,
            namespace="joint",
            attributes=copy.deepcopy(body_spec.children["freejoint"].attributes),
            children=collections.OrderedDict(),
        )

    return frame_spec
248 |
249 |
# Spec for attachment frames that are not attached to the worldbody (no `freejoint` child).
ATTACHMENT_FRAME = _attachment_frame_spec(is_world_attachment=False)
# Spec for attachment frames attached to the worldbody (a `freejoint` child is allowed).
WORLD_ATTACHMENT_FRAME = _attachment_frame_spec(is_world_attachment=True)
252 |
--------------------------------------------------------------------------------
/kinpy/mjcf_parser/util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017-2018 The dm_control Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 |
16 | """Various helper functions and classes."""
17 |
18 | from __future__ import absolute_import, division, print_function
19 |
20 | import sys
21 |
22 | import six
23 |
# Codec used by the string helpers in this module to convert between text and
# binary strings; it is the interpreter's default string encoding.
DEFAULT_ENCODING = sys.getdefaultencoding()
25 |
26 |
def to_binary_string(s):
    """Return `s` as a binary string, encoding text with `DEFAULT_ENCODING`.

    Binary input is returned unchanged.
    """
    return s if isinstance(s, six.binary_type) else s.encode(DEFAULT_ENCODING)
32 |
33 |
def to_native_string(s):
    """Return `s` converted to the interpreter's native `str` type.

    On Python 3 binary input is decoded with `DEFAULT_ENCODING`; on Python 2
    unicode input is encoded with it. Any other value is returned unchanged.
    """
    if isinstance(s, six.binary_type) and six.PY3:
        return s.decode(DEFAULT_ENCODING)
    if isinstance(s, six.text_type) and six.PY2:
        return s.encode(DEFAULT_ENCODING)
    return s
42 |
--------------------------------------------------------------------------------
/kinpy/sdf.py:
--------------------------------------------------------------------------------
1 | import io
2 | from typing import List, TextIO, Union
3 |
4 | import numpy as np
5 |
6 | from . import chain, frame, transform
7 | from .urdf_parser_py.sdf import SDF, Box, Cylinder, Mesh, Sphere
8 |
# Maps SDF joint type strings to the joint types passed to frame.Joint.
# Types not listed here (e.g. "ball", "universal") raise KeyError when the chain is built.
JOINT_TYPE_MAP = {"revolute": "revolute", "prismatic": "prismatic", "fixed": "fixed"}
10 |
11 |
def _convert_transform(pose: np.ndarray) -> transform.Transform:
    """Convert an SDF pose 6-vector into a Transform.

    The first three entries are the translation and the last three the Euler
    angles; ``None`` (no ``<pose>`` given) yields the identity transform.
    """
    if pose is not None:
        return transform.Transform(rot=pose[3:], pos=pose[:3])
    return transform.Transform()
17 |
18 |
def _convert_visuals(visuals: List) -> List:
    """Convert parsed SDF ``<visual>`` elements into ``frame.Visual`` objects.

    Each visual's pose becomes its offset. The geometry is mapped to a
    ``(type, parameter)`` pair: mesh -> filename, cylinder -> (radius, length),
    box -> size, sphere -> radius, anything else -> (None, None).
    """
    converted = []
    for visual in visuals:
        shape = visual.geometry
        offset = _convert_transform(visual.pose)
        if isinstance(shape, Mesh):
            kind, param = "mesh", shape.filename
        elif isinstance(shape, Cylinder):
            # Cylinders get an extra 90-degree rotation about x on top of their pose.
            # NOTE(review): presumably this adapts the cylinder axis to the
            # visualizer's convention — confirm.
            offset = offset * transform.Transform(rot=np.deg2rad([90.0, 0.0, 0.0]))
            kind, param = "cylinder", (shape.radius, shape.length)
        elif isinstance(shape, Box):
            kind, param = "box", shape.size
        elif isinstance(shape, Sphere):
            kind, param = "sphere", shape.radius
        else:
            kind, param = None, None
        converted.append(frame.Visual(offset, kind, param))
    return converted
41 |
42 |
def _build_chain_recurse(root_frame, lmap, joints) -> List:
    """Recursively build the child frames of ``root_frame`` from SDF joints.

    Every joint whose parent is ``root_frame``'s link yields a child frame named
    ``<child link>_frame``, whose own subtree is built recursively.

    Parameters
    ----------
    root_frame : frame.Frame
        Frame whose children are built; its ``link`` must already be set.
    lmap : dict
        Link name -> parsed SDF link.
    joints : list
        All parsed SDF joints of the model.

    Returns
    -------
    list of frame.Frame
        The child frames of ``root_frame``.
    """
    children = []
    for j in joints:
        if j.parent != root_frame.link.name:
            continue
        child_frame = frame.Frame(j.child + "_frame")
        link_p = lmap[j.parent]
        link_c = lmap[j.child]
        t_p = _convert_transform(link_p.pose)
        t_c = _convert_transform(link_c.pose)
        # Both link poses are taken to be in a common frame, so the joint offset is the
        # child pose expressed relative to the parent pose.
        joint_kwargs = {"offset": t_p.inverse() * t_c, "joint_type": JOINT_TYPE_MAP[j.type]}
        # A joint may have no <axis> (e.g. fixed joints), leaving ``j.axis`` as None; only
        # forward the axis when present so frame.Joint keeps its own default.
        if j.axis is not None:
            joint_kwargs["axis"] = j.axis.xyz
        child_frame.joint = frame.Joint(j.name, **joint_kwargs)
        child_frame.link = frame.Link(
            link_c.name, offset=transform.Transform(), visuals=_convert_visuals(link_c.visuals)
        )
        child_frame.children = _build_chain_recurse(child_frame, lmap, joints)
        children.append(child_frame)
    return children
61 |
62 |
def build_chain_from_sdf(data: Union[str, TextIO]) -> chain.Chain:
    """
    Build a Chain object from SDF data.

    The root frame is built from the parent link of the first joint whose parent
    is not the child of any joint; its joint offset is that link's pose. A model
    without joints yields a single-frame chain for its first link.

    Parameters
    ----------
    data : str or TextIO
        SDF string data or file object.

    Returns
    -------
    chain.Chain
        Chain object created from SDF.

    Raises
    ------
    ValueError
        If the SDF has no ``<model>``, the model has no links, or no root link can
        be determined (every joint's parent link is also some joint's child).
    """
    if isinstance(data, io.TextIOBase):
        data = data.read()
    sdf = SDF.from_xml_string(data)
    # <model> is optional in the SDF schema, so it may be missing or None.
    robot = getattr(sdf, "model", None)
    if robot is None:
        raise ValueError("SDF data does not contain a <model> element.")
    lmap = robot.link_map
    joints = robot.joints
    root_link = _find_root_link(lmap, joints)
    root_frame = frame.Frame(root_link.name + "_frame")
    root_frame.joint = frame.Joint(offset=_convert_transform(root_link.pose))
    root_frame.link = frame.Link(root_link.name, transform.Transform(), _convert_visuals(root_link.visuals))
    root_frame.children = _build_chain_recurse(root_frame, lmap, joints)
    return chain.Chain(root_frame)


def _find_root_link(lmap, joints):
    """Return the root link: the parent of the first joint whose parent is no joint's child.

    With no joints, the first link in ``lmap`` is returned.

    Raises
    ------
    ValueError
        If there are no links, or every joint's parent is also a child link (a cycle).
    """
    child_links = {j.child for j in joints}
    for j in joints:
        if j.parent not in child_links:
            return lmap[j.parent]
    if joints:
        raise ValueError("Could not find a root link: every joint's parent link is also a child link.")
    if not lmap:
        raise ValueError("The SDF model contains no links.")
    return next(iter(lmap.values()))
100 |
--------------------------------------------------------------------------------
/kinpy/transform.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Union
2 |
3 | import numpy as np
4 | import transformations as tf
5 |
6 |
class Transform:
    """Rigid-body transform made of a rotation and a translation.

    A point ``p`` is mapped to ``R(rot) @ p + pos`` (rotate, then translate), as in
    ``matrix()``. ``a * b`` composes two transforms so that ``b`` is applied first.

    Attributes
    ----------
    rot : np.ndarray
        Rotation as a quaternion in ``[w, x, y, z]`` order (the identity is
        ``[1, 0, 0, 0]``). The constructor also accepts three Euler angles
        (roll, pitch, yaw), converted with ``transformations.quaternion_from_euler``
        using that function's default axes.
    pos : np.ndarray
        The translation vector.
    """

    def __init__(
        self,
        rot: Union[List, np.ndarray, None] = None,
        pos: Optional[Union[List, np.ndarray]] = None,
    ) -> None:
        if rot is None:
            rot = [1.0, 0.0, 0.0, 0.0]  # identity quaternion, [w, x, y, z]
        if pos is None:
            pos = np.zeros(3)
        if len(rot) == 3:
            self.rot = tf.quaternion_from_euler(*rot)
        elif len(rot) == 4:
            self.rot = np.array(rot)
        else:
            raise ValueError("Size of rot must be 3 or 4.")
        self.pos = np.array(pos)

    def __repr__(self) -> str:
        return "Transform(rot={0}, pos={1})".format(self.rot, self.pos)

    @staticmethod
    def _rotation_vec(rot: np.ndarray, vec: np.ndarray) -> np.ndarray:
        """Rotate 3-vector ``vec`` by quaternion ``rot`` via ``rot * (0, vec) * rot^-1``."""
        v4 = np.hstack([np.array([0.0]), vec])
        inv_rot = tf.quaternion_inverse(rot)
        ans = tf.quaternion_multiply(tf.quaternion_multiply(rot, v4), inv_rot)
        return ans[1:]

    def __mul__(self, other: "Transform") -> "Transform":
        """Compose transforms: the result applies ``other`` first, then ``self``."""
        rot = tf.quaternion_multiply(self.rot, other.rot)
        pos = self._rotation_vec(self.rot, other.pos) + self.pos
        return Transform(rot, pos)

    def inverse(self) -> "Transform":
        """Return the transform that undoes this one."""
        rot = tf.quaternion_inverse(self.rot)
        pos = -self._rotation_vec(rot, self.pos)
        return Transform(rot, pos)

    def matrix(self) -> np.ndarray:
        """Return the 4x4 homogeneous transformation matrix."""
        mat = tf.quaternion_matrix(self.rot)
        mat[:3, 3] = self.pos
        return mat

    @property
    def rot_mat(self) -> np.ndarray:
        """The 3x3 rotation matrix."""
        return tf.quaternion_matrix(self.rot)[:3, :3]

    @property
    def rot_euler(self) -> tuple:
        """The rotation as a tuple of three Euler angles (``tf.euler_from_quaternion``'s default axes)."""
        return tf.euler_from_quaternion(self.rot)
63 |
--------------------------------------------------------------------------------
/kinpy/urdf.py:
--------------------------------------------------------------------------------
1 | import io
2 | from typing import List, TextIO, Union
3 |
4 | from . import chain, frame, transform
5 | from .urdf_parser_py import urdf
6 |
# Maps URDF joint type strings to the joint types passed to frame.Joint.
# "continuous" is treated as an (unlimited) revolute joint; types not listed here
# (e.g. "floating", "planar") raise KeyError when the chain is built.
JOINT_TYPE_MAP = {"revolute": "revolute", "continuous": "revolute", "prismatic": "prismatic", "fixed": "fixed"}
8 |
9 |
def _convert_transform(origin) -> transform.Transform:
    """Convert a parsed URDF ``<origin>`` (``xyz`` / ``rpy``) into a Transform; ``None`` gives the identity."""
    return transform.Transform() if origin is None else transform.Transform(rot=origin.rpy, pos=origin.xyz)
15 |
16 |
def _convert_visual(visual) -> frame.Visual:
    """Convert one parsed URDF ``<visual>`` into a ``frame.Visual``.

    A missing visual or geometry yields an empty ``frame.Visual()``. Otherwise the
    geometry is mapped to a ``(type, parameter)`` pair: mesh -> filename,
    cylinder -> (radius, length), box -> size, sphere -> radius, anything else
    -> (None, None); the visual's origin becomes its offset.
    """
    if visual is None or visual.geometry is None:
        return frame.Visual()
    shape = visual.geometry
    if isinstance(shape, urdf.Mesh):
        kind, param = "mesh", shape.filename
    elif isinstance(shape, urdf.Cylinder):
        kind, param = "cylinder", (shape.radius, shape.length)
    elif isinstance(shape, urdf.Box):
        kind, param = "box", shape.size
    elif isinstance(shape, urdf.Sphere):
        kind, param = "sphere", shape.radius
    else:
        kind, param = None, None
    return frame.Visual(_convert_transform(visual.origin), kind, param)
38 |
39 |
def _build_chain_recurse(root_frame, lmap, joints) -> List[frame.Frame]:
    """Recursively build the child frames of ``root_frame`` from URDF joints.

    Every joint whose parent is ``root_frame``'s link yields a child frame named
    ``<child link>_frame`` whose joint offset is the joint origin; its subtree is
    then built recursively.
    """
    parent_name = root_frame.link.name
    subframes = []
    for joint in joints:
        if joint.parent != parent_name:
            continue
        subframe = frame.Frame(joint.child + "_frame")
        subframe.joint = frame.Joint(
            joint.name, offset=_convert_transform(joint.origin), joint_type=JOINT_TYPE_MAP[joint.type], axis=joint.axis
        )
        child_link = lmap[joint.child]
        subframe.link = frame.Link(
            child_link.name, offset=_convert_transform(child_link.origin), visuals=[_convert_visual(child_link.visual)]
        )
        subframe.children = _build_chain_recurse(subframe, lmap, joints)
        subframes.append(subframe)
    return subframes
55 |
56 |
def build_chain_from_urdf(data: Union[str, TextIO]) -> chain.Chain:
    """
    Build a Chain object from URDF data.

    The root frame is built from the parent link of the first joint whose parent
    is not the child of any joint. A URDF without joints yields a single-frame
    chain for its first link.

    Parameters
    ----------
    data : str or TextIO
        URDF string data or file object.

    Returns
    -------
    chain.Chain
        Chain object created from URDF.

    Raises
    ------
    ValueError
        If the URDF has no links, or no root link can be determined (every
        joint's parent link is also some joint's child).

    Example
    -------
    >>> import kinpy as kp
    >>> data = '''<robot name="test_robot">
    ... <link name="link1" />
    ... <link name="link2" />
    ... <joint name="joint1" type="revolute">
    ... <parent link="link1"/>
    ... <child link="link2"/>
    ... </joint>
    ... </robot>'''
    >>> chain = kp.build_chain_from_urdf(data)
    >>> print(chain)
    link1_frame
    └──── link2_frame

    """
    if isinstance(data, io.TextIOBase):
        data = data.read()
    robot = urdf.URDF.from_xml_string(data)
    lmap = robot.link_map
    joints = robot.joints
    root_link = _find_root_link(lmap, joints)
    root_frame = frame.Frame(root_link.name + "_frame")
    root_frame.joint = frame.Joint()
    root_frame.link = frame.Link(
        root_link.name, _convert_transform(root_link.origin), [_convert_visual(root_link.visual)]
    )
    root_frame.children = _build_chain_recurse(root_frame, lmap, joints)
    return chain.Chain(root_frame)


def _find_root_link(lmap, joints):
    """Return the root link: the parent of the first joint whose parent is no joint's child.

    With no joints, the first link in ``lmap`` is returned.

    Raises
    ------
    ValueError
        If there are no links, or every joint's parent is also a child link (a cycle).
    """
    child_links = {j.child for j in joints}
    for j in joints:
        if j.parent not in child_links:
            return lmap[j.parent]
    if joints:
        raise ValueError("Could not find a root link: every joint's parent link is also a child link.")
    if not lmap:
        raise ValueError("The URDF contains no links.")
    return next(iter(lmap.values()))
112 |
113 |
def build_serial_chain_from_urdf(
    data: Union[str, TextIO], end_link_name: str, root_link_name: str = ""
) -> chain.SerialChain:
    """
    Build a SerialChain object from urdf data.

    The serial chain runs from ``root_link_name`` (or the chain's root when it is
    empty) to ``end_link_name``; link names are mapped to their ``<link>_frame``
    frame names.

    Parameters
    ----------
    data : str or TextIO
        URDF string data or file object.
    end_link_name : str
        The name of the link that is the end effector.
    root_link_name : str, optional
        The name of the root link.

    Returns
    -------
    chain.SerialChain
        SerialChain object created from URDF.
    """
    if isinstance(data, io.TextIOBase):
        data = data.read()
    full_chain = build_chain_from_urdf(data)
    end_frame_name = end_link_name + "_frame"
    root_frame_name = root_link_name + "_frame" if root_link_name != "" else ""
    return chain.SerialChain(full_chain, end_frame_name, root_frame_name)
140 |
--------------------------------------------------------------------------------
/kinpy/urdf_parser_py/__init__.py:
--------------------------------------------------------------------------------
1 | from . import sdf, urdf
2 |
--------------------------------------------------------------------------------
/kinpy/urdf_parser_py/sdf.py:
--------------------------------------------------------------------------------
1 | from . import xml_reflection as xmlr
2 | from .xml_reflection.basics import *
3 |
# TODO: Plugins are not parsed. Which elements should own them: model, world, sensor?

# Register the types below in the "sdf" namespace of xml_reflection; it is closed
# by `xmlr.end_namespace()` at the end of this module.
xmlr.start_namespace("sdf")


# Optional `name` attribute shared by several SDF elements below.
name_attribute = xmlr.Attribute("name", str, False)
# Optional `<pose>` child element, parsed as a 6-vector and shared by several SDF elements below.
pose_element = xmlr.Element("pose", "vector6", False)
11 |
12 |
class Inertia(xmlr.Object):
    """SDF ``<inertia>``: the six independent components of a symmetric inertia tensor."""

    # Component names; each is also the name of the child element holding its value.
    KEYS = ["ixx", "ixy", "ixz", "iyy", "iyz", "izz"]

    def __init__(self, ixx=0.0, ixy=0.0, ixz=0.0, iyy=0.0, iyz=0.0, izz=0.0):
        self.ixx = ixx
        self.ixy = ixy
        self.ixz = ixz
        self.iyy = iyy
        self.iyz = iyz
        self.izz = izz

    def to_matrix(self):
        """Return the full symmetric 3x3 inertia matrix as nested lists."""
        return [[self.ixx, self.ixy, self.ixz], [self.ixy, self.iyy, self.iyz], [self.ixz, self.iyz, self.izz]]


# In SDF each component is a float child element (e.g. `<ixx>`), not an XML attribute as in URDF.
xmlr.reflect(Inertia, params=[xmlr.Element(key, float) for key in Inertia.KEYS])

# TODO: Much of this module duplicates the URDF classes. Could a shared base
# (e.g. via multiple inheritance) remove the copy-paste while keeping the types distinct?


class Inertial(xmlr.Object):
    """SDF ``<inertial>``: mass, inertia tensor and an optional ``<pose>``."""

    def __init__(self, mass=0.0, inertia=None, pose=None):
        self.mass = mass
        self.inertia = inertia
        self.pose = pose


xmlr.reflect(Inertial, params=[xmlr.Element("mass", float), xmlr.Element("inertia", Inertia), pose_element])
42 |
43 |
class Box(xmlr.Object):
    """SDF ``<box>`` geometry; ``size`` is a 3-vector read from the ``<size>`` child element."""

    def __init__(self, size=None):
        self.size = size


xmlr.reflect(Box, tag="box", params=[xmlr.Element("size", "vector3")])


class Cylinder(xmlr.Object):
    """SDF ``<cylinder>`` geometry with ``<radius>`` and ``<length>`` child elements."""

    def __init__(self, radius=0.0, length=0.0):
        self.radius = radius
        self.length = length


xmlr.reflect(Cylinder, tag="cylinder", params=[xmlr.Element("radius", float), xmlr.Element("length", float)])


class Sphere(xmlr.Object):
    """SDF ``<sphere>`` geometry with a ``<radius>`` child element."""

    def __init__(self, radius=0.0):
        self.radius = radius


xmlr.reflect(Sphere, tag="sphere", params=[xmlr.Element("radius", float)])


class Mesh(xmlr.Object):
    """SDF ``<mesh>`` geometry: a ``<filename>`` and an optional 3-vector ``<scale>``."""

    def __init__(self, filename=None, scale=None):
        self.filename = filename
        self.scale = scale


# NOTE(review): SDFormat's <mesh> names its file with <uri>, not <filename>; files using
# <uri> will not populate `filename` here — confirm which form the inputs use.
xmlr.reflect(
    Mesh, tag="mesh", params=[xmlr.Element("filename", str), xmlr.Element("scale", "vector3", required=False)]
)


class GeometricType(xmlr.ValueType):
    """Value type for a ``<geometry>`` element, which wraps exactly one shape element.

    The shape class is chosen from the child's tag: ``box``, ``cylinder``, ``sphere`` or ``mesh``.
    """

    def __init__(self):
        self.factory = xmlr.FactoryType(
            "geometric", {"box": Box, "cylinder": Cylinder, "sphere": Sphere, "mesh": Mesh}
        )

    def from_xml(self, node, path):
        """Parse the single shape child of ``node``."""
        children = xml_children(node)
        # NOTE(review): `assert` is skipped under `python -O`, which disables this check.
        assert len(children) == 1, "One element only for geometric"
        return self.factory.from_xml(children[0], path=path)

    def write_xml(self, node, obj):
        """Write ``obj`` as a new child of ``node``, tagged with its shape name."""
        name = self.factory.get_name(obj)
        child = node_add(node, name)
        obj.write_xml(child)


# Registered under the type name used by `xmlr.Element("geometry", "geometric")` below.
xmlr.add_type("geometric", GeometricType())
98 |
99 |
class Script(xmlr.Object):
    """SDF material ``<script>`` with optional ``<name>`` and ``<uri>`` child elements."""

    def __init__(self, uri=None, name=None):
        self.uri = uri
        self.name = name


xmlr.reflect(Script, tag="script", params=[xmlr.Element("name", str, False), xmlr.Element("uri", str, False)])


class Material(xmlr.Object):
    """SDF ``<material>``: an optional ``name`` attribute and an optional ``<script>``."""

    def __init__(self, name=None, script=None):
        self.name = name
        self.script = script


xmlr.reflect(Material, tag="material", params=[name_attribute, xmlr.Element("script", Script, False)])
116 |
117 |
class Visual(xmlr.Object):
    """SDF ``<visual>``: a geometry with an optional material and ``<pose>``."""

    # NOTE(review): `material` is declared in the reflection params below but not
    # initialized here; presumably xml_reflection sets it while parsing — confirm.
    def __init__(self, name=None, geometry=None, pose=None):
        self.name = name
        self.geometry = geometry
        self.pose = pose


xmlr.reflect(
    Visual,
    tag="visual",
    params=[
        name_attribute,
        xmlr.Element("geometry", "geometric"),
        xmlr.Element("material", Material, False),
        pose_element,
    ],
)


class Collision(xmlr.Object):
    """SDF ``<collision>``: a geometry with an optional ``<pose>``."""

    def __init__(self, name=None, geometry=None, pose=None):
        self.name = name
        self.geometry = geometry
        self.pose = pose


xmlr.reflect(Collision, tag="collision", params=[name_attribute, xmlr.Element("geometry", "geometric"), pose_element])
145 |
146 |
class Dynamics(xmlr.Object):
    """SDF joint-axis ``<dynamics>`` with optional ``<damping>`` and ``<friction>``."""

    def __init__(self, damping=None, friction=None):
        self.damping = damping
        self.friction = friction


xmlr.reflect(
    Dynamics, tag="dynamics", params=[xmlr.Element("damping", float, False), xmlr.Element("friction", float, False)]
)


class Limit(xmlr.Object):
    """SDF joint-axis ``<limit>`` with optional ``<lower>`` and ``<upper>`` bounds."""

    def __init__(self, lower=None, upper=None):
        self.lower = lower
        self.upper = upper


xmlr.reflect(Limit, tag="limit", params=[xmlr.Element("lower", float, False), xmlr.Element("upper", float, False)])


class Axis(xmlr.Object):
    """SDF joint ``<axis>``: direction ``<xyz>`` plus optional limit, dynamics and frame flag."""

    def __init__(self, xyz=None, limit=None, dynamics=None, use_parent_model_frame=None):
        self.xyz = xyz
        self.limit = limit
        self.dynamics = dynamics
        # NOTE(review): presumably selects whether `xyz` is expressed in the parent model
        # frame instead of the joint frame (SDFormat semantics) — confirm.
        self.use_parent_model_frame = use_parent_model_frame


xmlr.reflect(
    Axis,
    tag="axis",
    params=[
        xmlr.Element("xyz", "vector3"),
        xmlr.Element("limit", Limit, False),
        xmlr.Element("dynamics", Dynamics, False),
        xmlr.Element("use_parent_model_frame", bool, False),
    ],
)
185 |
186 |
class Joint(xmlr.Object):
    """SDF ``<joint>`` connecting a ``<parent>`` link to a ``<child>`` link.

    ``axis`` is None when the joint has no ``<axis>`` element (e.g. fixed joints).
    """

    # Joint type names listed for SDF.
    TYPES = ["unknown", "revolute", "gearbox", "revolute2", "prismatic", "ball", "screw", "universal", "fixed"]

    def __init__(self, name=None, parent=None, child=None, joint_type=None, axis=None, pose=None):
        self.aggregate_init()
        self.name = name
        self.parent = parent
        self.child = child
        self.type = joint_type
        self.axis = axis
        self.pose = pose

    # Aliases
    @property
    def joint_type(self):
        """Alias of ``type``."""
        return self.type

    @joint_type.setter
    def joint_type(self, value):
        self.type = value


xmlr.reflect(
    Joint,
    tag="joint",
    params=[
        name_attribute,
        xmlr.Attribute("type", str, False),
        # <axis> is optional in SDFormat (fixed and ball joints have none), so a joint
        # without it must still parse; `axis` is then left as None.
        xmlr.Element("axis", Axis, False),
        xmlr.Element("parent", str),
        xmlr.Element("child", str),
        pose_element,
    ],
)
221 |
222 |
class Link(xmlr.Object):
    """SDF ``<link>``: optional inertial properties, visuals, collisions and ``<pose>``.

    ``visuals`` and ``collisions`` collect every ``<visual>`` / ``<collision>`` child.
    """

    def __init__(self, name=None, pose=None, inertial=None, kinematic=False):
        self.aggregate_init()
        self.name = name
        self.pose = pose
        self.inertial = inertial
        self.kinematic = kinematic
        self.visuals = []
        self.collisions = []


xmlr.reflect(
    Link,
    tag="link",
    params=[
        name_attribute,
        # <inertial> is optional in SDFormat (purely kinematic/visual links often omit it),
        # so a link without it must still parse; `inertial` is then left as None.
        xmlr.Element("inertial", Inertial, False),
        xmlr.Attribute("kinematic", bool, False),
        xmlr.AggregateElement("visual", Visual, var="visuals"),
        xmlr.AggregateElement("collision", Collision, var="collisions"),
        pose_element,
    ],
)
246 |
247 |
class Model(xmlr.Object):
    """SDF ``<model>``: a set of links connected by joints.

    Besides the ``links`` / ``joints`` lists, the model maintains lookup tables that are
    updated whenever a link or joint is added:

    * ``link_map`` / ``joint_map``: name -> element
    * ``parent_map``: child link name -> (joint name, parent link name)
    * ``child_map``: parent link name -> list of (joint name, child link name)
    """

    def __init__(self, name=None, pose=None):
        self.aggregate_init()
        self.name = name
        self.pose = pose
        self.links = []
        self.joints = []
        self.joint_map = {}
        self.link_map = {}

        self.parent_map = {}
        self.child_map = {}

    def add_aggregate(self, typeName, elem):
        """Record ``elem`` through the base class, then index it in the lookup tables."""
        xmlr.Object.add_aggregate(self, typeName, elem)

        if typeName == "link":
            self.link_map[elem.name] = elem
        elif typeName == "joint":
            self.joint_map[elem.name] = elem
            self.parent_map[elem.child] = (elem.name, elem.parent)
            self.child_map.setdefault(elem.parent, []).append((elem.name, elem.child))

    def add_link(self, link):
        """Add ``link`` to the model and its lookup tables."""
        self.add_aggregate("link", link)

    def add_joint(self, joint):
        """Add ``joint`` to the model and its lookup tables."""
        self.add_aggregate("joint", joint)


xmlr.reflect(
    Model,
    tag="model",
    params=[
        name_attribute,
        xmlr.AggregateElement("link", Link, var="links"),
        xmlr.AggregateElement("joint", Joint, var="joints"),
        pose_element,
    ],
)
293 |
294 |
class SDF(xmlr.Object):
    """Root ``<sdf>`` element: an optional ``version`` attribute and an optional single ``<model>``."""

    # NOTE(review): `model` is not initialized here; presumably xml_reflection sets it
    # while parsing — confirm before relying on it for a hand-built SDF object.
    def __init__(self, version=None):
        self.version = version


xmlr.reflect(
    SDF,
    tag="sdf",
    params=[
        xmlr.Attribute("version", str, False),
        xmlr.Element("model", Model, False),
    ],
)


# Close the "sdf" namespace opened at the top of this module.
xmlr.end_namespace()
311 |
--------------------------------------------------------------------------------
/kinpy/urdf_parser_py/urdf.py:
--------------------------------------------------------------------------------
1 | from . import xml_reflection as xmlr
2 | from .xml_reflection.basics import *
3 |
# Register the types below in the "urdf" namespace of xml_reflection, presumably so
# type names such as "geometric" do not clash with the SDF types registered in sdf.py.
# NOTE(review): the matching `xmlr.end_namespace()` is expected later in this module — confirm.

xmlr.start_namespace("urdf")

# Reusable element types: presumably an element whose value is its `link` attribute
# (e.g. `<parent link="..."/>`) and one whose value is its `xyz` 3-vector attribute.
xmlr.add_type("element_link", xmlr.SimpleElementType("link", str))
xmlr.add_type("element_xyz", xmlr.SimpleElementType("xyz", "vector3"))

# NOTE(review): module-level flag that is not read in the visible part of this module — confirm its users.
verbose = True
14 |
15 |
class Pose(xmlr.Object):
    """URDF ``<origin>``: translation ``xyz`` and roll/pitch/yaw rotation ``rpy``."""

    def __init__(self, xyz=None, rpy=None):
        self.xyz = xyz
        self.rpy = rpy

    def check_valid(self):
        """Check that ``xyz`` and ``rpy`` are each unset or exactly three elements long."""
        # NOTE(review): `assert` is skipped under `python -O`, which disables this check.
        assert (self.xyz is None or len(self.xyz) == 3) and (self.rpy is None or len(self.rpy) == 3)

    # Aliases for backwards compatibility
    @property
    def rotation(self):
        """Alias of ``rpy``."""
        return self.rpy

    @rotation.setter
    def rotation(self, value):
        self.rpy = value

    @property
    def position(self):
        """Alias of ``xyz``."""
        return self.xyz

    @position.setter
    def position(self, value):
        self.xyz = value


# Both attributes are optional and default to zero vectors.
# NOTE(review): if xml_reflection assigns `default` as-is, every Pose lacking an attribute
# shares the same list object, so in-place mutation would leak between poses — confirm.
xmlr.reflect(
    Pose,
    tag="origin",
    params=[
        xmlr.Attribute("xyz", "vector3", False, default=[0, 0, 0]),
        xmlr.Attribute("rpy", "vector3", False, default=[0, 0, 0]),
    ],
)
50 |
51 |
# Parameters shared by several URDF elements below.
# `name` attribute (no `required` flag is passed, so presumably required).
name_attribute = xmlr.Attribute("name", str)
# Optional `<origin>` child element, parsed into a Pose.
origin_element = xmlr.Element("origin", Pose, False)
55 |
56 |
class Color(xmlr.Object):
    """URDF ``<color>`` holding an RGBA value.

    Accepts no arguments (``rgba`` unset), one 3- or 4-element sequence, or 3 or 4
    separate components. A 3-component color gets an alpha of 1.0 appended.

    Raises
    ------
    Exception
        If the number of arguments or components is not valid.
    """

    def __init__(self, *args):
        # What about named colors?
        count = len(args)
        if count == 4 or count == 3:
            self.rgba = args
        elif count == 1:
            self.rgba = args[0]
        elif count == 0:
            self.rgba = None
        else:
            # Reject other argument counts explicitly instead of leaving `rgba` unset.
            raise Exception("Invalid color argument count")
        if self.rgba is not None:
            if len(self.rgba) == 3:
                # Build a new list: `args` is a tuple (tuple += list raises TypeError), and a
                # caller-supplied list must not be extended in place.
                self.rgba = list(self.rgba) + [1.0]
            if len(self.rgba) != 4:
                raise Exception("Invalid color argument count")


xmlr.reflect(Color, tag="color", params=[xmlr.Attribute("rgba", "vector4")])
75 |
76 |
class JointDynamics(xmlr.Object):
    """URDF joint ``<dynamics>`` with optional ``damping`` and ``friction`` attributes."""

    def __init__(self, damping=None, friction=None):
        self.damping = damping
        self.friction = friction


xmlr.reflect(
    JointDynamics,
    tag="dynamics",
    params=[xmlr.Attribute("damping", float, False), xmlr.Attribute("friction", float, False)],
)


class Box(xmlr.Object):
    """URDF ``<box>`` geometry; ``size`` is a 3-vector read from the ``size`` attribute."""

    def __init__(self, size=None):
        self.size = size


xmlr.reflect(Box, tag="box", params=[xmlr.Attribute("size", "vector3")])


class Cylinder(xmlr.Object):
    """URDF ``<cylinder>`` geometry with ``radius`` and ``length`` attributes."""

    def __init__(self, radius=0.0, length=0.0):
        self.radius = radius
        self.length = length


xmlr.reflect(Cylinder, tag="cylinder", params=[xmlr.Attribute("radius", float), xmlr.Attribute("length", float)])


class Sphere(xmlr.Object):
    """URDF ``<sphere>`` geometry with a ``radius`` attribute."""

    def __init__(self, radius=0.0):
        self.radius = radius


xmlr.reflect(Sphere, tag="sphere", params=[xmlr.Attribute("radius", float)])


class Mesh(xmlr.Object):
    """URDF ``<mesh>`` geometry: a ``filename`` and an optional 3-vector ``scale``."""

    def __init__(self, filename=None, scale=None):
        self.filename = filename
        self.scale = scale


xmlr.reflect(
    Mesh, tag="mesh", params=[xmlr.Attribute("filename", str), xmlr.Attribute("scale", "vector3", required=False)]
)


class GeometricType(xmlr.ValueType):
    """Value type for a ``<geometry>`` element, which wraps exactly one shape element.

    The shape class is chosen from the child's tag: ``box``, ``cylinder``, ``sphere`` or ``mesh``.
    """

    def __init__(self):
        self.factory = xmlr.FactoryType(
            "geometric", {"box": Box, "cylinder": Cylinder, "sphere": Sphere, "mesh": Mesh}
        )

    def from_xml(self, node, path):
        """Parse the single shape child of ``node``."""
        children = xml_children(node)
        # NOTE(review): `assert` is skipped under `python -O`, which disables this check.
        assert len(children) == 1, "One element only for geometric"
        return self.factory.from_xml(children[0], path=path)

    def write_xml(self, node, obj):
        """Write ``obj`` as a new child of ``node``, tagged with its shape name."""
        name = self.factory.get_name(obj)
        child = node_add(node, name)
        obj.write_xml(child)


# Registered under the type name used by `xmlr.Element("geometry", "geometric")` below.
xmlr.add_type("geometric", GeometricType())
144 |
145 |
class Collision(xmlr.Object):
    """URDF ``<collision>``: a geometry with an optional ``<origin>``."""

    def __init__(self, geometry=None, origin=None):
        self.geometry = geometry
        self.origin = origin


xmlr.reflect(Collision, tag="collision", params=[origin_element, xmlr.Element("geometry", "geometric")])


class Texture(xmlr.Object):
    """URDF material ``<texture>`` referencing an image via its ``filename`` attribute."""

    def __init__(self, filename=None):
        self.filename = filename


xmlr.reflect(Texture, tag="texture", params=[xmlr.Attribute("filename", str)])


class Material(xmlr.Object):
    """Named URDF ``<material>`` with an optional ``<color>`` and/or ``<texture>``."""

    def __init__(self, name=None, color=None, texture=None):
        self.name = name
        self.color = color
        self.texture = texture

    def check_valid(self):
        """Report (via ``xmlr.on_error``) a material that has neither a color nor a texture."""
        if self.color is None and self.texture is None:
            xmlr.on_error("Material has neither a color nor texture.")


xmlr.reflect(
    Material,
    tag="material",
    params=[name_attribute, xmlr.Element("color", Color, False), xmlr.Element("texture", Texture, False)],
)


class LinkMaterial(Material):
    """Material referenced from a link's visual.

    It may be only a name referring to a material defined elsewhere, so the
    color/texture check of ``Material`` is skipped.
    """

    def check_valid(self):
        pass


class Visual(xmlr.Object):
    """URDF ``<visual>``: a geometry with an optional material and ``<origin>``."""

    def __init__(self, geometry=None, material=None, origin=None):
        self.geometry = geometry
        self.material = material
        self.origin = origin


xmlr.reflect(
    Visual,
    tag="visual",
    params=[origin_element, xmlr.Element("geometry", "geometric"), xmlr.Element("material", LinkMaterial, False)],
)
198 |
199 |
class Inertia(xmlr.Object):
    """URDF ``<inertia>``: the six independent components of a symmetric inertia tensor."""

    # Component names; each is also the name of the XML attribute holding its value.
    KEYS = ["ixx", "ixy", "ixz", "iyy", "iyz", "izz"]

    def __init__(self, ixx=0.0, ixy=0.0, ixz=0.0, iyy=0.0, iyz=0.0, izz=0.0):
        self.ixx = ixx
        self.ixy = ixy
        self.ixz = ixz
        self.iyy = iyy
        self.iyz = iyz
        self.izz = izz

    def to_matrix(self):
        """Return the full symmetric 3x3 inertia matrix as nested lists."""
        return [[self.ixx, self.ixy, self.ixz], [self.ixy, self.iyy, self.iyz], [self.ixz, self.iyz, self.izz]]


xmlr.reflect(Inertia, tag="inertia", params=[xmlr.Attribute(key, float) for key in Inertia.KEYS])


class Inertial(xmlr.Object):
    """URDF ``<inertial>``: mass, optional inertia tensor and optional ``<origin>``."""

    def __init__(self, mass=0.0, inertia=None, origin=None):
        self.mass = mass
        self.inertia = inertia
        self.origin = origin


xmlr.reflect(
    Inertial,
    tag="inertial",
    params=[origin_element, xmlr.Element("mass", "element_value"), xmlr.Element("inertia", Inertia, False)],
)
230 |
231 |
# FIXME: we are missing the reference position here.
class JointCalibration(xmlr.Object):
    """Calibration values ``rising`` and ``falling`` of a joint."""

    def __init__(self, rising=None, falling=None):
        self.rising, self.falling = rising, falling
237 |
238 |
# Register <calibration>: float attributes ``rising`` and ``falling``. Both pass
# ``False, 0`` (presumably required=False with a default of 0 -- confirm in
# xml_reflection.core).
xmlr.reflect(
    JointCalibration,
    tag="calibration",
    params=[xmlr.Attribute("rising", float, False, 0), xmlr.Attribute("falling", float, False, 0)],
)
244 |
245 |
class JointLimit(xmlr.Object):
    """Joint limits: effort, velocity and the lower/upper position bounds."""

    def __init__(self, effort=None, velocity=None, lower=None, upper=None):
        self.effort = effort
        self.velocity = velocity
        self.lower, self.upper = lower, upper
252 |
253 |
# Register <limit>. ``effort`` and ``velocity`` use xmlr.Attribute's defaults for the
# remaining arguments. ``lower``/``upper`` pass ``False, 0`` (presumably
# required=False with a default of 0 -- confirm in xml_reflection.core).
xmlr.reflect(
    JointLimit,
    tag="limit",
    params=[
        xmlr.Attribute("effort", float),
        xmlr.Attribute("lower", float, False, 0),
        xmlr.Attribute("upper", float, False, 0),
        xmlr.Attribute("velocity", float),
    ],
)

# FIXME: we are missing __str__ here.
266 |
267 |
class JointMimic(xmlr.Object):
    """Mimic parameters of a joint: referenced joint name, multiplier and offset."""

    def __init__(self, joint_name=None, multiplier=None, offset=None):
        # ``joint_name`` is stored on ``joint``, the name registered for the XML attribute.
        self.joint = joint_name
        self.multiplier, self.offset = multiplier, offset
273 |
274 |
# Register <mimic>: a string ``joint`` attribute plus float ``multiplier`` and
# ``offset`` attributes passed ``False`` (presumably not required).
xmlr.reflect(
    JointMimic,
    tag="mimic",
    params=[
        xmlr.Attribute("joint", str),
        xmlr.Attribute("multiplier", float, False),
        xmlr.Attribute("offset", float, False),
    ],
)
284 |
285 |
class SafetyController(xmlr.Object):
    """Safety-controller gains and soft limits of a joint.

    Constructor arguments map onto the XML attribute names:
    ``velocity`` -> ``k_velocity``, ``position`` -> ``k_position``,
    ``lower`` -> ``soft_lower_limit``, ``upper`` -> ``soft_upper_limit``.
    """

    def __init__(self, velocity=None, position=None, lower=None, upper=None):
        self.k_velocity, self.k_position = velocity, position
        self.soft_lower_limit, self.soft_upper_limit = lower, upper
292 |
293 |
# Register <safety_controller>. ``k_velocity`` uses xmlr.Attribute's defaults; the
# other three pass ``False, 0`` (presumably required=False, default 0).
xmlr.reflect(
    SafetyController,
    tag="safety_controller",
    params=[
        xmlr.Attribute("k_velocity", float),
        xmlr.Attribute("k_position", float, False, 0),
        xmlr.Attribute("soft_lower_limit", float, False, 0),
        xmlr.Attribute("soft_upper_limit", float, False, 0),
    ],
)
304 |
305 |
class Joint(xmlr.Object):
    """A joint connecting a ``parent`` link to a ``child`` link.

    The joint kind lives in ``type`` (the XML attribute name) and can also be
    read or written through the ``joint_type`` alias property.
    """

    TYPES = ["unknown", "revolute", "continuous", "prismatic", "floating", "planar", "fixed"]

    def __init__(
        self,
        name=None,
        parent=None,
        child=None,
        joint_type=None,
        axis=None,
        origin=None,
        limit=None,
        dynamics=None,
        safety_controller=None,
        calibration=None,
        mimic=None,
    ):
        # Topology and kind.
        self.name, self.parent, self.child = name, parent, child
        self.type = joint_type
        # Geometry.
        self.axis, self.origin = axis, origin
        # Optional sub-elements.
        self.limit = limit
        self.dynamics = dynamics
        self.safety_controller = safety_controller
        self.calibration = calibration
        self.mimic = mimic

    def check_valid(self):
        """Assert that ``type`` is one of ``TYPES``."""
        is_known = self.type in self.TYPES
        assert is_known, "Invalid joint type: {}".format(self.type)  # noqa

    @property
    def joint_type(self):
        """Alias of ``type``."""
        return self.type

    @joint_type.setter
    def joint_type(self, value):
        self.type = value
346 |
347 |
# Register <joint>: name and string ``type`` attributes, the shared origin element,
# <parent>/<child> link references, and optional (third argument False) <axis>,
# <limit>, <dynamics>, <safety_controller>, <calibration> and <mimic> children.
xmlr.reflect(
    Joint,
    tag="joint",
    params=[
        name_attribute,
        xmlr.Attribute("type", str),
        origin_element,
        xmlr.Element("axis", "element_xyz", False),
        xmlr.Element("parent", "element_link"),
        xmlr.Element("child", "element_link"),
        xmlr.Element("limit", JointLimit, False),
        xmlr.Element("dynamics", JointDynamics, False),
        xmlr.Element("safety_controller", SafetyController, False),
        xmlr.Element("calibration", JointCalibration, False),
        xmlr.Element("mimic", JointMimic, False),
    ],
)
365 |
366 |
class Link(xmlr.Object):
    """A link holding any number of visuals and collisions plus optional inertial data.

    ``visual`` and ``collision`` are backwards-compatible properties that access
    the first entry of ``visuals`` / ``collisions``.
    """

    def __init__(self, name=None, visual=None, inertial=None, collision=None, origin=None):
        self.aggregate_init()
        self.name = name
        self.visuals = []
        self.inertial = inertial
        self.collisions = []
        self.origin = origin
        # These arguments used to be silently discarded. Register them as aggregates
        # (the same mechanism XML parsing uses for AggregateElement children) so that
        # they appear both in ``visuals``/``collisions`` and in the aggregate list.
        if visual is not None:
            self.add_aggregate("visual", visual)
        if collision is not None:
            self.add_aggregate("collision", collision)

    def __get_visual(self):
        """Return the first visual or None."""
        if self.visuals:
            return self.visuals[0]

    def __set_visual(self, visual):
        """Set the first visual."""
        if self.visuals:
            self.visuals[0] = visual
        else:
            self.visuals.append(visual)

    def __get_collision(self):
        """Return the first collision or None."""
        if self.collisions:
            return self.collisions[0]

    def __set_collision(self, collision):
        """Set the first collision."""
        if self.collisions:
            self.collisions[0] = collision
        else:
            self.collisions.append(collision)

    # Properties for backwards compatibility
    visual = property(__get_visual, __set_visual)
    collision = property(__get_collision, __set_collision)
403 |
404 |
# Register <link>: name attribute, the shared origin element, any number of <visual>
# and <collision> children (aggregates stored in ``visuals`` / ``collisions``) and
# an optional <inertial> child.
xmlr.reflect(
    Link,
    tag="link",
    params=[
        name_attribute,
        origin_element,
        xmlr.AggregateElement("visual", Visual),
        xmlr.AggregateElement("collision", Collision),
        xmlr.Element("inertial", Inertial, False),
    ],
)
416 |
417 |
class PR2Transmission(xmlr.Object):
    """Older transmission format: one joint, one actuator and a mechanical reduction.

    The new format is modelled by ``Transmission``.
    """

    def __init__(self, name=None, joint=None, actuator=None, type=None, mechanicalReduction=1):
        self.name = name
        self.type = type
        self.joint, self.actuator = joint, actuator
        self.mechanicalReduction = mechanicalReduction
425 |
426 |
# Register the old-format transmission under the internal tag "pr2_transmission":
# name and string ``type`` attributes, <joint>/<actuator> children read through the
# "element_name" type, and a float <mechanicalReduction> child.
xmlr.reflect(
    PR2Transmission,
    tag="pr2_transmission",
    params=[
        name_attribute,
        xmlr.Attribute("type", str),
        xmlr.Element("joint", "element_name"),
        xmlr.Element("actuator", "element_name"),
        xmlr.Element("mechanicalReduction", float),
    ],
)
438 |
439 |
class Actuator(xmlr.Object):
    """Actuator entry of a new-format transmission.

    Args:
        name: Actuator name.
        mechanicalReduction: Reduction ratio stored on ``mechanicalReduction``.
    """

    def __init__(self, name=None, mechanicalReduction=1):
        self.name = name
        # Previously the argument was ignored and the attribute was always None.
        self.mechanicalReduction = mechanicalReduction
444 |
445 |
# Register <actuator>: name attribute and an optional float <mechanicalReduction> child.
xmlr.reflect(
    Actuator, tag="actuator", params=[name_attribute, xmlr.Element("mechanicalReduction", float, required=False)]
)
449 |
450 |
class TransmissionJoint(xmlr.Object):
    """Joint reference inside a new-format transmission with its hardware interfaces."""

    def __init__(self, name=None):
        self.aggregate_init()
        self.name = name
        self.hardwareInterfaces = []

    def check_valid(self):
        """Assert that at least one ``hardwareInterface`` was given."""
        assert self.hardwareInterfaces, "no hardwareInterface defined"
459 |
460 |
# Register the transmission-level <joint>: name attribute plus any number of string
# <hardwareInterface> children collected in ``hardwareInterfaces``.
xmlr.reflect(
    TransmissionJoint,
    tag="joint",
    params=[
        name_attribute,
        xmlr.AggregateElement("hardwareInterface", str),
    ],
)
469 |
470 |
class Transmission(xmlr.Object):
    """New transmission format: http://wiki.ros.org/urdf/XML/Transmission

    Holds one or more joints and one or more actuators.
    """

    def __init__(self, name=None):
        self.aggregate_init()
        self.name = name
        self.joints = []
        self.actuators = []

    def check_valid(self):
        """Assert that at least one joint and one actuator are present."""
        assert self.joints, "no joint defined"
        assert self.actuators, "no actuator defined"
483 |
484 |
# Register the new-format transmission under the internal tag "new_transmission":
# name attribute, a string <type> child, and any number of <joint> / <actuator>
# children.
xmlr.reflect(
    Transmission,
    tag="new_transmission",
    params=[
        name_attribute,
        xmlr.Element("type", str),
        xmlr.AggregateElement("joint", TransmissionJoint),
        xmlr.AggregateElement("actuator", Actuator),
    ],
)

# Register a "transmission" type backed by both classes; Robot refers to this type
# name for its <transmission> children. DuckTypedFactory presumably picks whichever
# class parses the element successfully -- confirm in xml_reflection.core.
xmlr.add_type("transmission", xmlr.DuckTypedFactory("transmission", [Transmission, PR2Transmission]))
497 |
498 |
class Robot(xmlr.Object):
    """Root of a robot description: links, joints, materials, gazebo and transmissions.

    Besides the aggregate lists, ``add_aggregate`` maintains lookup tables:

    * ``joint_map`` / ``link_map``: element name -> element;
    * ``parent_map``: child link -> ``(joint name, parent link)``;
    * ``child_map``: parent link -> list of ``(joint name, child link)``.
    """

    def __init__(self, name=None):
        self.aggregate_init()

        self.name = name
        self.joints = []
        self.links = []
        self.materials = []
        self.gazebos = []
        self.transmissions = []

        self.joint_map = {}
        self.link_map = {}

        self.parent_map = {}
        self.child_map = {}

    def add_aggregate(self, typeName, elem):
        """Add ``elem`` through the base class and update the lookup tables."""
        xmlr.Object.add_aggregate(self, typeName, elem)

        if typeName == "link":
            self.link_map[elem.name] = elem
        elif typeName == "joint":
            self.joint_map[elem.name] = elem
            self.parent_map[elem.child] = (elem.name, elem.parent)
            self.child_map.setdefault(elem.parent, []).append((elem.name, elem.child))

    def add_link(self, link):
        self.add_aggregate("link", link)

    def add_joint(self, joint):
        self.add_aggregate("joint", joint)

    def get_chain(self, root, tip, joints=True, links=True, fixed=True):
        """Return the names on the path from ``root`` to ``tip`` (root first).

        Args:
            root: Link name at which the walk stops.
            tip: Link name the walk starts from.
            joints: Include joint names.
            links: Include link names.
            fixed: When False, joints whose ``joint_type`` is "fixed" are omitted.

        Raises:
            KeyError: if ``root`` is not an ancestor of ``tip``.
        """
        chain = [tip] if links else []
        current = tip
        while current != root:
            joint_name, parent_name = self.parent_map[current]
            if joints and (fixed or self.joint_map[joint_name].joint_type != "fixed"):
                chain.append(joint_name)
            if links:
                chain.append(parent_name)
            current = parent_name
        chain.reverse()
        return chain

    def get_root(self):
        """Return the single link that is not the child of any joint."""
        root = None
        for link_name in self.link_map:
            if link_name in self.parent_map:
                continue
            assert root is None, "Multiple roots detected, invalid URDF."
            root = link_name
        assert root is not None, "No roots detected, invalid URDF."
        return root

    @classmethod
    def from_parameter_server(cls, key="robot_description"):
        """
        Retrieve the robot model on the parameter server
        and parse it to create a URDF robot structure.

        Warning: this requires roscore to be running.
        """
        # Could move this into xml_reflection
        import rospy

        return cls.from_xml_string(rospy.get_param(key))
574 |
575 |
# Register <robot>: an optional name attribute and any number of <link>, <joint>,
# <gazebo> (kept as raw XML via RawType), <transmission> (the duck-typed type
# registered above) and <material> children.
xmlr.reflect(
    Robot,
    tag="robot",
    params=[
        xmlr.Attribute("name", str, False),  # Is 'name' a required attribute?
        xmlr.AggregateElement("link", Link),
        xmlr.AggregateElement("joint", Joint),
        xmlr.AggregateElement("gazebo", xmlr.RawType()),
        xmlr.AggregateElement("transmission", "transmission"),
        xmlr.AggregateElement("material", Material),
    ],
)
588 |
# Make an alias: ``URDF`` is the very same class object as ``Robot``.
URDF = Robot

# Close the xml_reflection namespace (presumably opened by a matching
# start_namespace call earlier in this module -- not visible here).
xmlr.end_namespace()
593 |
--------------------------------------------------------------------------------
/kinpy/urdf_parser_py/xml_reflection/__init__.py:
--------------------------------------------------------------------------------
1 | from .core import *
2 |
--------------------------------------------------------------------------------
/kinpy/urdf_parser_py/xml_reflection/basics.py:
--------------------------------------------------------------------------------
import collections
import collections.abc
import string

import yaml
from lxml import etree
6 |
7 |
def xml_string(rootXml, addHeader=True):
    """Serialize ``rootXml`` to a pretty-printed unicode string.

    When ``addHeader`` is true the text is prefixed with the header string below.
    """
    text = etree.tostring(rootXml, pretty_print=True, encoding="unicode")
    return '\n' + text if addHeader else text
14 |
15 |
def dict_sub(obj, keys):
    """Return a new dict with only ``keys`` (in iteration order) looked up in ``obj``.

    Raises KeyError for a key missing from ``obj``.
    """
    return {key: obj[key] for key in keys}
18 |
19 |
def node_add(doc, sub):
    """Attach ``sub`` under ``doc`` and return the attached node.

    * ``None``: nothing is added and ``None`` is returned.
    * a string: a new child element with that tag is created under ``doc``.
    * an lxml element: it is appended to ``doc`` as-is.

    Raises:
        Exception: for any other type of ``sub``.
    """
    if sub is None:
        return None
    if isinstance(sub, str):
        return etree.SubElement(doc, sub)
    if isinstance(sub, etree._Element):
        doc.append(sub)  # This screws up the rest of the tree for prettyprint
        return sub
    raise Exception("Invalid sub value")
30 |
31 |
def pfloat(x):
    """Return ``str(x)`` with trailing periods removed (e.g. ``"3."`` -> ``"3"``)."""
    text = str(x)
    return text.rstrip(".")
34 |
35 |
def xml_children(node):
    """Return the direct children of ``node`` as a list, skipping XML comments.

    Iterates the element directly instead of calling ``getchildren()``, which
    lxml documents as deprecated in favour of ``list(element)`` / iteration.
    """
    return [child for child in node if not isinstance(child, etree._Comment)]
43 |
44 |
def isstring(obj):
    """Return True if ``obj`` is a text string (``str``)."""
    # Python 3 only: the former ``basestring`` branch referenced an undefined name
    # and always fell through to this check anyway.
    return isinstance(obj, str)
50 |
51 |
def to_yaml(obj):
    """Simplify yaml representation for pretty printing.

    Recursively converts ``obj`` into plain data:

    * ``None`` and strings become ``str``; int/float/bool are returned as-is;
    * objects with a ``to_yaml`` method are converted by that method;
    * lxml elements are serialized with ``etree.tostring``;
    * plain dicts get stringified keys and converted values;
    * objects with ``tolist`` (e.g. numpy arrays) are converted via that list;
    * any other iterable becomes a list of converted items;
    * everything else falls back to ``str``.
    """
    # Is there a better way to do this by adding a representation with
    # yaml.Dumper?
    # Ordered dict: http://pyyaml.org/ticket/29#comment:11
    if obj is None or isstring(obj):
        out = str(obj)
    elif type(obj) in [int, float, bool]:
        return obj
    elif hasattr(obj, "to_yaml"):
        out = obj.to_yaml()
    elif isinstance(obj, etree._Element):
        out = etree.tostring(obj, pretty_print=True)
    elif type(obj) is dict:
        out = {str(var): to_yaml(value) for var, value in obj.items()}
    elif hasattr(obj, "tolist"):
        # For numpy objects
        out = to_yaml(obj.tolist())
    elif isinstance(obj, collections.abc.Iterable):
        # ``collections.Iterable`` no longer exists on Python 3.10+.
        out = [to_yaml(item) for item in obj]
    else:
        out = str(obj)
    return out
77 |
78 |
class SelectiveReflection(object):
    """Mixin exposing an object's instance attribute names for reflection."""

    def get_refl_vars(self):
        """Return the instance attribute names, in insertion order, as a list."""
        return list(vars(self))
82 |
83 |
class YamlReflection(SelectiveReflection):
    """Mixin rendering an object as YAML built from its reflected attributes."""

    def to_yaml(self):
        """Convert the reflected attributes with the module-level ``to_yaml``."""
        names = self.get_refl_vars()
        raw = {name: getattr(self, name) for name in names}
        return to_yaml(raw)

    def __str__(self):
        # Good idea? Will it remove other important things?
        dumped = yaml.dump(self.to_yaml())
        return dumped.rstrip()
92 |
--------------------------------------------------------------------------------
/kinpy/visualizer.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import defaultdict
3 | from functools import partial
4 | from typing import Dict, List, Optional, Tuple, Union
5 |
6 | import numpy as np
7 | import transformations as trf
8 | import vtk
9 | from vtk.util.colors import tomato
10 |
11 | from . import transform
12 | from .chain import Chain
13 | from .frame import Visual
14 |
15 |
class Visualizer:
    """Minimal VTK viewer for robot geometry and coordinate frames.

    Shape actors are recorded in ``self._actors`` (list per frame name) and axes
    actors in ``self._axes`` (one per frame name) so that subclasses can re-pose them.
    """

    def __init__(self, win_size: Tuple[int, int] = (640, 480)) -> None:
        self._actors: Dict[str, list] = defaultdict(list)
        self._axes: Dict[str, vtk.vtkAxesActor] = {}
        self._ren = vtk.vtkRenderer()
        self._ren.SetBackground(0.1, 0.2, 0.4)
        self._win = vtk.vtkRenderWindow()
        self._win.SetSize(*win_size)
        self._win.AddRenderer(self._ren)
        self._inter = vtk.vtkRenderWindowInteractor()
        self._inter.SetRenderWindow(self._win)

    def add_robot(
        self,
        transformations: Dict[str, transform.Transform],
        visuals_map: Dict[str, List[Visual]],
        mesh_file_path: str = "./",
        axes: bool = False,
    ) -> None:
        """Add every visual of every frame in ``transformations``.

        Args:
            transformations: Frame name -> pose of that frame.
            visuals_map: Frame name -> visuals of that frame; must contain every key
                of ``transformations``.
            mesh_file_path: Directory that mesh file names are joined onto.
            axes: Also draw a coordinate-axes actor per frame.

        Visuals whose ``geom_type`` is not mesh/cylinder/box/sphere/capsule are skipped.
        """
        for k, trans in transformations.items():
            if axes:
                self.add_axes(trans, geom_name=k)
            for v in visuals_map[k]:
                tf = trans * v.offset
                if v.geom_type == "mesh":
                    self.add_mesh(os.path.join(mesh_file_path, v.geom_param), tf, geom_name=k)
                elif v.geom_type == "cylinder":
                    self.add_cylinder(v.geom_param[0], v.geom_param[1], tf, geom_name=k)
                elif v.geom_type == "box":
                    self.add_box(v.geom_param, tf, geom_name=k)
                elif v.geom_type == "sphere":
                    self.add_sphere(v.geom_param, tf, geom_name=k)
                elif v.geom_type == "capsule":
                    self.add_capsule(v.geom_param[0], v.geom_param[1], tf, geom_name=k)

    def add_shape_source(
        self, source: vtk.vtkAbstractPolyDataReader, transform: transform.Transform, geom_name: Optional[str] = None
    ) -> None:
        """Wrap ``source`` in a tomato-coloured actor posed at ``transform`` and add it.

        The orientation is applied as RotateX/RotateY/RotateZ calls with the angles of
        the ``"rxyz"`` Euler decomposition of ``transform.rot``.
        """
        mapper = vtk.vtkPolyDataMapper()
        mapper.SetInputConnection(source.GetOutputPort())
        actor = vtk.vtkActor()
        actor.SetMapper(mapper)
        actor.GetProperty().SetColor(tomato)
        actor.SetPosition(transform.pos)
        rpy = np.rad2deg(trf.euler_from_quaternion(transform.rot, "rxyz"))
        actor.RotateX(rpy[0])
        actor.RotateY(rpy[1])
        actor.RotateZ(rpy[2])
        if geom_name:
            self._actors[geom_name].append(actor)
        self._ren.AddActor(actor)

    def add_axes(self, trans: transform.Transform, geom_name: Optional[str] = None) -> None:
        """Add a 0.1-long, unlabelled axes actor posed at ``trans``."""
        transform = vtk.vtkTransform()
        transform.Translate(trans.pos)
        rpy = np.rad2deg(trf.euler_from_quaternion(trans.rot, "rxyz"))
        transform.RotateX(rpy[0])
        transform.RotateY(rpy[1])
        transform.RotateZ(rpy[2])
        axes = vtk.vtkAxesActor()
        axes.SetTotalLength(0.1, 0.1, 0.1)
        axes.AxisLabelsOff()
        axes.SetUserTransform(transform)
        if geom_name:
            self._axes[geom_name] = axes
        self._ren.AddActor(axes)

    def load_obj(self, filename: str) -> vtk.vtkOBJReader:
        reader = vtk.vtkOBJReader()
        reader.SetFileName(filename)
        return reader

    def load_ply(self, filename: str) -> vtk.vtkPLYReader:
        reader = vtk.vtkPLYReader()
        reader.SetFileName(filename)
        return reader

    def load_stl(self, filename: str) -> vtk.vtkSTLReader:
        reader = vtk.vtkSTLReader()
        reader.SetFileName(filename)
        return reader

    def add_cylinder(
        self, radius: float, height: float, tf: Optional[transform.Transform] = None, geom_name: Optional[str] = None
    ) -> None:
        tf = tf or transform.Transform()
        cylinder = vtk.vtkCylinderSource()
        cylinder.SetResolution(20)
        cylinder.SetRadius(radius)
        cylinder.SetHeight(height)
        self.add_shape_source(cylinder, tf, geom_name)

    def add_box(
        self, size: List[float], tf: Optional[transform.Transform] = None, geom_name: Optional[str] = None
    ) -> None:
        tf = tf or transform.Transform()
        cube = vtk.vtkCubeSource()
        cube.SetXLength(size[0])
        cube.SetYLength(size[1])
        cube.SetZLength(size[2])
        self.add_shape_source(cube, tf, geom_name)

    def add_sphere(
        self, radius: float, tf: Optional[transform.Transform] = None, geom_name: Optional[str] = None
    ) -> None:
        tf = tf or transform.Transform()
        sphere = vtk.vtkSphereSource()
        sphere.SetRadius(radius)
        self.add_shape_source(sphere, tf, geom_name)

    def add_capsule(
        self,
        radius: float,
        fromto: np.ndarray,
        tf: Optional[transform.Transform] = None,
        step: float = 0.05,
        geom_name: Optional[str] = None,
    ) -> None:
        """Add a capsule approximated by spheres placed along a segment.

        Args:
            radius: Sphere (capsule) radius.
            fromto: Six values ``(x0, y0, z0, x1, y1, z1)``: the segment end points,
                expressed in the frame given by ``tf``.
            tf: Pose of that frame; identity when omitted.
            step: Spacing between consecutive spheres as a fraction of the segment length.
            geom_name: Frame name used to register the actor for later pose updates.
        """
        tf = tf or transform.Transform()
        fromto = np.asarray(fromto, dtype=float)
        start, end = fromto[:3], fromto[3:]
        # Sphere centres are written directly in the ``tf`` frame, so the actor pose is
        # exactly ``tf``. The previous version added an extra rotation computed from
        # ``tf``'s world z-axis (only correct for unrotated frames), divided by zero for
        # zero-length segments, and never placed a sphere at the end point. Keeping the
        # actor pose equal to ``tf`` also lets callers re-pose it with ``frame * offset``.
        ts = np.append(np.arange(0.0, 1.0, step), 1.0)
        spheres = vtk.vtkAppendPolyData()
        for t in ts:
            center = start + t * (end - start)
            sphere = vtk.vtkSphereSource()
            sphere.SetRadius(radius)
            sphere.SetCenter(float(center[0]), float(center[1]), float(center[2]))
            spheres.AddInputConnection(sphere.GetOutputPort())
        self.add_shape_source(spheres, tf, geom_name)

    def add_mesh(
        self, filename: str, tf: Optional[transform.Transform] = None, geom_name: Optional[str] = None
    ) -> None:
        """Add a mesh loaded from an .stl, .obj or .ply file (extension is case-insensitive).

        Raises:
            ValueError: for any other file extension.
        """
        tf = tf or transform.Transform()
        _, ext = os.path.splitext(filename)
        ext = ext.lower()
        if ext == ".stl":
            reader = self.load_stl(filename)
        elif ext == ".obj":
            reader = self.load_obj(filename)
        elif ext == ".ply":
            reader = self.load_ply(filename)
        else:
            raise ValueError("Unsupported file extension, '%s'." % ext)
        self.add_shape_source(reader, tf, geom_name)

    def spin(self) -> None:
        """Render once and start the interactor's event loop."""
        self._win.Render()
        self._inter.Initialize()
        self._inter.Start()
171 |
172 |
class JointAngleEditor(Visualizer):
    """Viewer with one slider (in degrees, -180..180) per non-fixed joint of ``chain``.

    Moving a slider updates the joint angle (stored in radians), recomputes forward
    kinematics and re-poses the registered actors and axes.
    """

    def __init__(
        self,
        chain: Chain,
        mesh_file_path: str = "./",
        axes: bool = False,
        initial_state: Optional[Union[Dict[str, float], List[float]]] = None,
    ) -> None:
        super().__init__()
        self._chain = chain
        if initial_state is None:
            initial_state = {}
        elif isinstance(initial_state, (list, np.ndarray)):
            # Positional values are matched to the chain's joint parameter names.
            initial_state = dict(zip(self._chain.get_joint_parameter_names(), initial_state))
        else:
            # Copy so that moving the sliders does not mutate the caller's dictionary.
            initial_state = dict(initial_state)
        self._joint_angles: Dict[str, float] = initial_state
        self._visuals_map = self._chain.visuals_map()
        self.add_robot(
            self._chain.forward_kinematics(self._joint_angles, end_only=False),  # type: ignore
            self._visuals_map,
            mesh_file_path,
            axes,
        )
        self._sliders = self._set_joint_slider(chain)

    def _update_joint_angle(self, obj: vtk.vtkSliderWidget, event: str, joint_name: str) -> None:
        """Slider callback: store the new angle and re-pose every axes actor and shape actor."""
        slider_rep = obj.GetRepresentation()
        self._joint_angles[joint_name] = np.deg2rad(slider_rep.GetValue())
        positions = self._chain.forward_kinematics(self._joint_angles, end_only=False)  # type: ignore
        for k, position in positions.items():
            if k in self._axes:
                transform = vtk.vtkTransform()
                transform.Translate(position.pos)
                rpy = np.rad2deg(trf.euler_from_quaternion(position.rot, "rxyz"))
                transform.RotateX(rpy[0])
                transform.RotateY(rpy[1])
                transform.RotateZ(rpy[2])
                self._axes[k].SetUserTransform(transform)
                self._axes[k].Modified()
            # NOTE(review): assumes actor i was created from visual i of this frame,
            # i.e. no visual of the frame was skipped by add_robot.
            actors = self._actors[k]
            for i, actor in enumerate(actors):
                actor.SetOrientation(0, 0, 0)
                trans = position * self._visuals_map[k][i].offset
                actor.SetPosition(trans.pos)
                rpy = np.rad2deg(trf.euler_from_quaternion(trans.rot, "rxyz"))
                actor.RotateX(rpy[0])
                actor.RotateY(rpy[1])
                actor.RotateZ(rpy[2])
                actor.Modified()
        self._win.Render()

    def _set_joint_slider(self, chain: Chain) -> List[vtk.vtkSliderWidget]:
        """Create one slider widget per non-fixed joint, stacked down the right side."""
        sliders = []
        # Only movable joints take a slot, so fixed frames no longer leave gaps (or push
        # sliders off-screen). Slots start at 1 to keep the top slider off the window edge.
        slot = 0
        for frame in chain:
            if frame.joint.joint_type == "fixed":
                continue
            slot += 1
            y = 1.0 - 0.05 * slot
            slider_rep = vtk.vtkSliderRepresentation2D()
            slider_rep.SetMinimumValue(-180)
            slider_rep.SetMaximumValue(180)
            slider_rep.SetValue(np.rad2deg(self._joint_angles.get(frame.joint.name, 0.0)))
            slider_rep.SetSliderLength(0.05)
            slider_rep.SetSliderWidth(0.02)
            slider_rep.SetEndCapLength(0.01)
            slider_rep.SetEndCapWidth(0.02)
            slider_rep.GetPoint1Coordinate().SetCoordinateSystemToNormalizedDisplay()
            slider_rep.GetPoint1Coordinate().SetValue(0.75, y)
            slider_rep.GetPoint2Coordinate().SetCoordinateSystemToNormalizedDisplay()
            slider_rep.GetPoint2Coordinate().SetValue(0.98, y)

            # Set the background color of the slider representation
            slider_rep.GetSliderProperty().SetColor(0.2, 0.2, 0.2)
            slider_rep.GetTubeProperty().SetColor(0.7, 0.7, 0.7)
            slider_rep.GetCapProperty().SetColor(0.2, 0.2, 0.2)

            slider_widget = vtk.vtkSliderWidget()
            slider_widget.SetInteractor(self._inter)
            slider_widget.SetRepresentation(slider_rep)
            slider_widget.AddObserver(
                vtk.vtkCommand.InteractionEvent,
                partial(self._update_joint_angle, joint_name=frame.joint.name),
            )
            sliders.append(slider_widget)
        return sliders

    def spin(self) -> None:
        """Render, enable all sliders and start the interactor's event loop."""
        self._win.Render()
        self._inter.Initialize()
        for slider in self._sliders:
            slider.SetEnabled(True)
            slider.On()
        self._inter.Start()
263 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "kinpy"
3 | version = "0.4.2"
4 | description = ""
5 | authors = ["neka-nat "]
6 | license = "MIT"
7 |
8 | [tool.poetry.dependencies]
9 | python = ">=3.9,<3.12"
10 | numpy = [
11 | {version = ">=1.24.3,<=1.26.0", python = "3.9"},
12 | {version = "^1.24.3", python = ">=3.10,<3.12"},
13 | ]
14 | scipy = "^1.11.2"
15 | transformations = "^2020.1.1"
16 | absl-py = "^0.11.0"
17 | lxml = "^4.6.2"
18 | PyYAML = "^6.0.1"
19 | vtk = "^9.0.1"
20 |
21 | [tool.poetry.group.dev.dependencies]
22 | twine = "^3.3.0"
23 | isort = "^5.9.3"
24 | black = "^22.6.0"
25 | flake8 = "^5.0.4"
26 | flake8-bugbear = "^22.8.23"
27 | flake8-simplify = "^0.19.3"
28 | mypy = "^1.0.1"
29 |
30 | [tool.mypy]
31 | python_version = "3.9"
32 | ignore_missing_imports = true
33 |
34 | [build-system]
35 | requires = ["poetry-core>=1.0.0", "setuptools"]
36 | build-backend = "poetry.core.masonry.api"
37 |
--------------------------------------------------------------------------------
/scripts/kpviewer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | from typing import Any
3 |
4 | import kinpy as kp
5 |
6 |
def main(args: Any) -> None:
    """Print the kinematic chain of ``args.filename`` and optionally render it.

    With ``args.show3d`` the model is drawn with all joint parameters at 0.0. Mesh
    file names from the model are joined onto the model file's directory, so the
    viewer works regardless of the current working directory.
    """
    import os

    chain = kp.build_chain_from_file(args.filename)
    print(chain)
    if not args.show3d:
        return
    th = {name: 0.0 for name in chain.get_joint_parameter_names()}
    ret = chain.forward_kinematics(th)
    viz = kp.Visualizer()
    # dirname("") is "", which makes os.path.join behave like the old "./" default.
    viz.add_robot(ret, chain.visuals_map(), mesh_file_path=os.path.dirname(args.filename))
    viz.spin()
17 |
18 |
if __name__ == "__main__":
    # Command-line entry point: a positional model filename and an optional --show3d flag.
    import argparse
    parser = argparse.ArgumentParser(description="Robot model viewer.")
    parser.add_argument("filename", type=str, help="Robot model filename.")
    parser.add_argument("--show3d", action="store_true", help="Show 3D model on GUI window.")
    args = parser.parse_args()
    main(args)
26 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
__version__ = '0.4.2'

# Read the README as UTF-8 inside a context manager: the former bare
# ``open('README.md').read()`` leaked the file handle and depended on the
# platform's default encoding.
with open('README.md', encoding='utf-8') as readme_file:
    long_description = readme_file.read()


setup(
    name='kinpy',
    version=__version__,
    author='neka-nat',
    author_email='nekanat.stock@gmail.com',
    description='Robotics kinematic calculation toolkit',
    license='MIT',
    keywords='robot kinematics',
    url='http://github.com/neka-nat/kinpy',
    packages=find_packages(exclude=["tests"]),
    include_package_data=True,
    # package_data patterns are relative to each package's own directory, so the
    # MJCF schema is listed under the package that contains it. The former
    # {'': ['kinpy/mjcf_parser/schema.xml']} was looked up as kinpy/kinpy/... .
    package_data={'kinpy.mjcf_parser': ['schema.xml']},
    long_description=long_description,
    long_description_content_type='text/markdown',
    install_requires=['numpy', 'scipy', 'absl-py', 'pyyaml',
                      'lxml', 'transformations', 'vtk'],
)
23 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neka-nat/kinpy/36f9add8d2ad3425892361de9995d63f23254dfa/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_fkik.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import numpy as np
3 | import kinpy as kp
4 |
5 |
class TestFkIk(unittest.TestCase):
    """Round trip: inverse kinematics should recover the angles used for forward kinematics."""

    def test_fkik(self):
        # NOTE(review): every URDF fragment below is an empty string in this copy, so
        # ``data`` is ''. The test needs a serial chain ending at 'link3' with two joint
        # parameters (see ``np.random.rand(2)``) -- restore the literal from upstream.
        data = ''\
            ''\
            ''\
            ''\
            ''\
            ''\
            ''\
            ''\
            ''\
            ''\
            ''\
            ''\
            ''\
            ''\
            ''
        chain = kp.build_serial_chain_from_urdf(data, 'link3')
        # Random angles in [0, 1) -> forward kinematics -> inverse kinematics.
        th1 = np.random.rand(2)
        tg = chain.forward_kinematics(th1)
        th2 = chain.inverse_kinematics(tg)
        self.assertTrue(np.allclose(th1, th2, atol=1.0e-6))
28 |
29 |
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
--------------------------------------------------------------------------------
/tests/test_jacobian.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import numpy as np
3 | import kinpy as kp
4 |
5 |
class TestJacobian(unittest.TestCase):
    """Jacobians of serial chains compared against expected 6xN matrices."""

    def test_jacobian1(self):
        # NOTE(review): the URDF fragments below are empty strings in this copy; the test
        # needs a two-joint chain ending at 'link3' -- restore the literal from upstream.
        data = ''\
            ''\
            ''\
            ''\
            ''\
            ''\
            ''\
            ''\
            ''\
            ''\
            ''\
            ''\
            ''\
            ''\
            ''
        chain = kp.build_serial_chain_from_urdf(data, 'link3')
        jc = chain.jacobian([0.0, 0.0])
        np.testing.assert_equal(np.array([[0.0, 0.0],
                                          [1.0, 0.0],
                                          [0.0, 0.0],
                                          [0.0, 0.0],
                                          [0.0, 0.0],
                                          [1.0, 1.0]]), jc)

    def test_jacobian2(self):
        # NOTE(review): same as above -- the URDF literal is empty in this copy.
        data = ''\
            ''\
            ''\
            ''\
            ''\
            ''\
            ''\
            ''\
            ''\
            ''\
            ''\
            ''\
            ''\
            ''\
            ''
        chain = kp.build_serial_chain_from_urdf(data, 'link3')
        jc = chain.jacobian([0.0, 0.0])
        np.testing.assert_equal(np.array([[0.0, 0.0],
                                          [1.0, 0.0],
                                          [0.0, 1.0],
                                          [0.0, 0.0],
                                          [0.0, 0.0],
                                          [1.0, 0.0]]), jc)

    def test_jacobian3(self):
        import os

        # Locate the model relative to this file (not the current working directory)
        # and close the file handle after reading, which the former bare open() leaked.
        urdf_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), os.pardir, "examples", "kuka_iiwa", "model.urdf"
        )
        with open(urdf_path) as urdf_file:
            chain = kp.build_serial_chain_from_urdf(urdf_file.read(), "lbr_iiwa_link_7")
        th = [0.0, -np.pi / 4.0, 0.0, np.pi / 2.0, 0.0, np.pi / 4.0, 0.0]
        jc = chain.jacobian(th)
        np.testing.assert_almost_equal(np.array([[0, 1.41421356e-02, 0, 2.82842712e-01, 0, 0, 0],
                                                 [-6.60827561e-01, 0, -4.57275649e-01, 0, 5.72756493e-02, 0, 0],
                                                 [0, 6.60827561e-01, 0, -3.63842712e-01, 0, 8.10000000e-02, 0],
                                                 [0, 0, -7.07106781e-01, 0, -7.07106781e-01, 0, -1],
                                                 [0, 1, 0, -1, 0, 1, 0],
                                                 [1, 0, 7.07106781e-01, 0, -7.07106781e-01, 0, 0]]), jc)
67 |
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
--------------------------------------------------------------------------------
/tests/test_transform.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import numpy as np
3 | import kinpy as kp
4 |
5 |
def random_transform():
    """Return a kp.Transform from random Euler angles in [-pi, pi) and a random position in [0, 1)^3."""
    angles = np.random.rand(3) * 2.0 * np.pi - np.pi
    position = np.random.rand(3)
    return kp.Transform(angles, position)
9 |
10 |
class TestTransform(unittest.TestCase):
    """Sanity checks for kp.Transform composition."""

    def test_multiply(self):
        # A transform composed with its inverse must give the identity: rotation
        # [1, 0, 0, 0] and zero translation.
        tf = random_transform()
        identity = tf * tf.inverse()
        expected_rot = np.array([1.0, 0.0, 0.0, 0.0])
        expected_pos = np.array([0.0, 0.0, 0.0])
        self.assertTrue(np.allclose(identity.rot, expected_rot, atol=1.0e-6))
        self.assertTrue(np.allclose(identity.pos, expected_pos, atol=1.0e-6))
17 |
18 |
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
--------------------------------------------------------------------------------