├── .gitignore
├── 0_introduction_to_numerical_robotics.ipynb
├── 1_motion_planning.ipynb
├── 2_inverse_kinematics.ipynb
├── 3_reinforcement_learning.ipynb
├── LICENSE
├── README.md
├── robotics-mva.yml
└── utils
    ├── __init__.py
    ├── collision_wrapper.py
    ├── datastructures
    │   ├── bucketkdtree.py
    │   ├── mtree
    │   │   ├── __init__.py
    │   │   ├── faster.py
    │   │   ├── functions.py
    │   │   └── heap_queue.py
    │   ├── pathtree.py
    │   ├── storage.py
    │   └── tree.py
    ├── generate.py
    ├── load_ur5_parallel.py
    ├── load_ur5_with_obstacles.py
    ├── meshcat_viewer_wrapper
    │   ├── __init__.py
    │   ├── colors.py
    │   ├── tests.py
    │   ├── transformations.py
    │   └── visualizer.py
    ├── tests.py
    └── tiago_loader.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.zip
2 | *_solution.ipynb
3 | *_tensorboard
4 | .ipynb_checkpoints
5 | __pycache__
6 |
--------------------------------------------------------------------------------
/0_introduction_to_numerical_robotics.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Introduction to numerical robotics\n",
8 | "\n",
9 | "This notebook is a general introduction to Pinocchio. It shows how to manipulate the geometry model of a robot manipulator: set the configuration, compute the position of the end effector, check for collisions or the distance to an obstacle. The main idea is to give a brief introduction to the general topic: how to discover and learn robot movements constrained by the environment, using iterative optimization methods.\n"
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {},
15 | "source": [
16 | "## Set up\n",
17 | "\n",
18 | "Let us load the UR5 robot model, the Pinocchio library, some optimization functions from SciPy, and Matplotlib for plotting:"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": null,
24 | "metadata": {},
25 | "outputs": [],
26 | "source": [
27 | "import pinocchio as pin\n",
28 | "from utils.meshcat_viewer_wrapper import MeshcatVisualizer\n",
29 | "import time\n",
30 | "import numpy as np\n",
31 | "from numpy.linalg import inv,norm,pinv,svd,eig\n",
32 | "from scipy.optimize import fmin_bfgs,fmin_slsqp\n",
33 | "from utils.load_ur5_with_obstacles import load_ur5_with_obstacles,Target\n",
34 | "import matplotlib.pylab as plt"
35 | ]
36 | },
37 | {
38 | "cell_type": "markdown",
39 | "metadata": {},
40 | "source": [
41 | "Let's first load the robot model and display it. For this tutorial, a single utility function loads the model and creates obstacles around it:"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": null,
47 | "metadata": {},
48 | "outputs": [],
49 | "source": [
50 | "robot = load_ur5_with_obstacles(reduced=True)"
51 | ]
52 | },
53 | {
54 | "cell_type": "markdown",
55 | "metadata": {},
56 | "source": [
57 | "The next few lines initialize a 3D viewer."
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": null,
63 | "metadata": {},
64 | "outputs": [],
65 | "source": [
66 | "viz = MeshcatVisualizer(robot)\n",
67 | "viz.display(robot.q0)"
68 | ]
69 | },
70 | {
71 | "cell_type": "code",
72 | "execution_count": null,
73 | "metadata": {},
74 | "outputs": [],
75 | "source": [
76 | "hasattr(viz.viewer, 'jupyter_cell') and viz.viewer.jupyter_cell()"
77 | ]
78 | },
79 | {
80 | "cell_type": "markdown",
81 | "metadata": {},
82 | "source": [
83 | "The robot and the red obstacles are encoded in the `robot` object (we will not look in depth at what is inside this object). You can display a new configuration of the robot with `viz.display`. It takes a `numpy.array` of dimension 2 as input:"
84 | ]
85 | },
86 | {
87 | "cell_type": "code",
88 | "execution_count": null,
89 | "metadata": {},
90 | "outputs": [],
91 | "source": [
92 | "viz.display(np.array([3.,-1.5]))"
93 | ]
94 | },
95 | {
96 | "cell_type": "markdown",
97 | "metadata": {},
98 | "source": [
99 | "We also set up a target, which is visualized as a green dot:"
100 | ]
101 | },
102 | {
103 | "cell_type": "code",
104 | "execution_count": null,
105 | "metadata": {},
106 | "outputs": [],
107 | "source": [
108 | "target_pos = np.array([.5,.5])\n",
109 | "target = Target(viz,position = target_pos)"
110 | ]
111 | },
112 | {
113 | "cell_type": "markdown",
114 | "metadata": {},
115 | "source": [
116 | "The `Target` object is the green dot that the robot should reach. You can change the target position by editing `target.position`, and display the new position with `target.display()`."
117 | ]
118 | },
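119 | {
120 | "cell_type": "markdown",
121 | "metadata": {},
122 | "source": [
123 | "For example, the following quick check moves the target and refreshes its marker, using only the `position` attribute and the `display()` method described above:"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": null,
129 | "metadata": {},
130 | "outputs": [],
131 | "source": [
132 | "# Move the target and refresh its green marker in the viewer.\n",
133 | "target.position = np.array([.4, .6])\n",
134 | "target.display()"
135 | ]
136 | },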
119 | {
120 | "cell_type": "markdown",
121 | "metadata": {},
122 | "source": [
123 | "## Using the robot model\n",
124 | "The robot is originally a 6 degrees-of-freedom (DOF) manipulator. Yet, to keep the example simple, we will only use its joints 1 and 2. The model has simply been loaded with the extra joints \"frozen\", so they will not appear in this notebook. Reload the model with `reduced=False` if you want to recover the full 6-DOF model."
125 | ]
126 | },
127 | {
128 | "cell_type": "markdown",
129 | "metadata": {},
130 | "source": [
131 | "The following function computes the position of the end effector (in 2d):"
132 | ]
133 | },
134 | {
135 | "cell_type": "code",
136 | "execution_count": null,
137 | "metadata": {},
138 | "outputs": [],
139 | "source": [
140 | "def endef(q):\n",
141 | " '''Return the 2d position of the end effector.'''\n",
142 | " pin.framesForwardKinematics(robot.model, robot.data, q)\n",
143 | " return robot.data.oMf[-1].translation[[0, 2]]\n"
144 | ]
145 | },
146 | {
147 | "cell_type": "markdown",
148 | "metadata": {},
149 | "source": [
150 | "This function checks if the robot is in collision, and returns `True` if a collision is detected."
151 | ]
152 | },
153 | {
154 | "cell_type": "code",
155 | "execution_count": null,
156 | "metadata": {},
157 | "outputs": [],
158 | "source": [
159 | "def coll(q):\n",
160 | " '''Return True if in collision, false otherwise.'''\n",
161 | " pin.updateGeometryPlacements(robot.model, robot.data, robot.collision_model, robot.collision_data, q)\n",
162 | " return pin.computeCollisions(robot.collision_model, robot.collision_data, False)\n"
163 | ]
164 | },
165 | {
166 | "cell_type": "markdown",
167 | "metadata": {},
168 | "source": [
169 | "The next function computes the distance between the end effector and the target."
170 | ]
171 | },
172 | {
173 | "cell_type": "markdown",
174 | "metadata": {},
175 | "source": [
176 | "Your code:"
177 | ]
178 | },
179 | {
180 | "cell_type": "code",
181 | "execution_count": null,
182 | "metadata": {},
183 | "outputs": [],
184 | "source": [
185 | "def dist(q):\n",
186 | "    '''Return the distance between the end effector and the target (2d).'''\n",
187 | " return 0.\n"
188 | ]
189 | },
190 | {
191 | "cell_type": "markdown",
192 | "metadata": {},
193 | "source": [
194 | "Solution"
195 | ]
196 | },
197 | {
198 | "cell_type": "code",
199 | "execution_count": null,
200 | "metadata": {},
201 | "outputs": [],
202 | "source": [
203 | "%do_not_load tp0/generated/simple_path_planning_dist"
204 | ]
205 | },
206 | {
207 | "cell_type": "markdown",
208 | "metadata": {},
209 | "source": [
210 | "## Random search of a valid configuration\n",
211 | "The free space is difficult to represent explicitly. We can sample the configuration space until a free configuration is found:"
212 | ]
213 | },
214 | {
215 | "cell_type": "code",
216 | "execution_count": null,
217 | "metadata": {},
218 | "outputs": [],
219 | "source": [
220 | "def qrand(check=False):\n",
221 | "    '''Return a random configuration. If `check` is True, this configuration is not in collision.'''\n",
222 | " pass"
223 | ]
224 | },
225 | {
226 | "cell_type": "markdown",
227 | "metadata": {},
228 | "source": [
229 | "The solution if needed:"
230 | ]
231 | },
232 | {
233 | "cell_type": "code",
234 | "execution_count": null,
235 | "metadata": {},
236 | "outputs": [],
237 | "source": [
238 | "%do_not_load tp0/generated/simple_path_planning_qrand"
239 | ]
240 | },
241 | {
242 | "cell_type": "code",
243 | "execution_count": null,
244 | "metadata": {},
245 | "outputs": [],
246 | "source": [
247 | "viz.display(qrand(check=True))"
248 | ]
249 | },
250 | {
251 | "cell_type": "markdown",
252 | "metadata": {},
253 | "source": [
254 | "Let's now find a valid configuration that is arbitrarily close to the target: sample until dist is small enough and coll is false (you may want to display the random trials inside the loop)."
255 | ]
256 | },
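257 | {
258 | "cell_type": "markdown",
259 | "metadata": {},
260 | "source": [
261 | "A minimal sketch of this search (the `5e-2` threshold below is an arbitrary choice, not prescribed by the exercise):"
262 | ]
263 | },
264 | {
265 | "cell_type": "code",
266 | "execution_count": null,
267 | "metadata": {},
268 | "outputs": [],
269 | "source": [
270 | "# Rejection sampling: draw collision-free configurations until one lands\n",
271 | "# close enough to the target.\n",
272 | "q = qrand(check=True)\n",
273 | "while dist(q) > 5e-2:\n",
274 | "    q = qrand(check=True)\n",
275 | "    viz.display(q)\n",
276 | "viz.display(q)"
277 | ]
278 | },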
257 | {
258 | "cell_type": "markdown",
259 | "metadata": {},
260 | "source": [
261 | "## From a random configuration to the target\n",
262 | "Let's now start from a random configuration. How can we find a path that brings the robot toward the target without touching the obstacles? Any idea?"
263 | ]
264 | },
265 | {
266 | "cell_type": "code",
267 | "execution_count": null,
268 | "metadata": {},
269 | "outputs": [],
270 | "source": [
271 | "# Random descent: crawling from one free configuration to the target with random\n",
272 | "# steps.\n",
273 | "def randomDescent(q0 = None):\n",
274 | " '''\n",
275 | "    Make a random walk with steps of size 0.1 toward the target.\n",
276 | " Return the list of configurations visited\n",
277 | " '''\n",
278 | " q = qrand(check=True) if q0 is None else q0\n",
279 | " hist = [ q.copy() ]\n",
280 | "    # Do the walk\n",
281 | " return hist"
282 | ]
283 | },
284 | {
285 | "cell_type": "markdown",
286 | "metadata": {},
287 | "source": [
288 | "And the solution if needed:"
289 | ]
290 | },
291 | {
292 | "cell_type": "code",
293 | "execution_count": null,
294 | "metadata": {},
295 | "outputs": [],
296 | "source": [
297 | "%do_not_load tp0/generated/simple_path_planning_random_descent"
298 | ]
299 | },
300 | {
301 | "cell_type": "code",
302 | "execution_count": null,
303 | "metadata": {},
304 | "outputs": [],
305 | "source": [
306 | "randomDescent()"
307 | ]
308 | },
309 | {
310 | "cell_type": "markdown",
311 | "metadata": {},
312 | "source": [
313 | "## Configuration space\n",
314 | "Let's try to get a better look at the configuration space. In this case, it is easy, as it has dimension 2: we can sample it exhaustively and plot it in 2D. For that, let's introduce another function to compute the distance to collision:"
315 | ]
316 | },
317 | {
318 | "cell_type": "code",
319 | "execution_count": 46,
320 | "metadata": {},
321 | "outputs": [],
322 | "source": [
323 | "def collisionDistance(q):\n",
324 | "    '''Return the minimal distance between the robot and the environment.'''\n",
325 | " pin.updateGeometryPlacements(robot.model,robot.data,robot.collision_model,robot.collision_data,q)\n",
326 | " if pin.computeCollisions(robot.collision_model,robot.collision_data,False):\n",
327 | " return 0.0\n",
328 | " idx = pin.computeDistances(robot.collision_model,robot.collision_data)\n",
329 | " return robot.collision_data.distanceResults[idx].min_distance"
330 | ]
331 | },
332 | {
333 | "cell_type": "markdown",
334 | "metadata": {},
335 | "source": [
336 | "Now, let's sample the configuration space and plot the distance-to-target and distance-to-obstacle fields (the default is 500 samples to spare your CPU, but you need at least 10x more to obtain a good picture)."
337 | ]
338 | },
339 | {
340 | "cell_type": "code",
341 | "execution_count": null,
342 | "metadata": {},
343 | "outputs": [],
344 | "source": [
345 | "def sampleSpace(nbSamples=500):\n",
346 | " '''\n",
347 | "    Sample nbSamples configurations and store them in two lists depending\n",
348 | "    on whether the configuration is in free space (hfree) or in collision (hcol), along\n",
349 | " with the distance to the target and the distance to the obstacles.\n",
350 | " '''\n",
351 | " hcol = []\n",
352 | " hfree = []\n",
353 | " for i in range(nbSamples):\n",
354 | " q = qrand(False)\n",
355 | " if not coll(q):\n",
356 | " hfree.append( list(q.flat) + [ dist(q), collisionDistance(q) ])\n",
357 | " else:\n",
358 | " hcol.append( list(q.flat) + [ dist(q), 1e-2 ])\n",
359 | " return hcol,hfree\n",
360 | "\n",
361 | "def plotConfigurationSpace(hcol,hfree,markerSize=20):\n",
362 | " '''\n",
363 | "    Plot 2 \"scatter\" plots: the first one plots the distance to the target for\n",
364 | "    each configuration, the second one plots the distance to the obstacles (axes q1, q2,\n",
365 | "    distance encoded as color).\n",
366 | " '''\n",
367 | " htotal = hcol + hfree\n",
368 | " h=np.array(htotal)\n",
369 | " plt.subplot(2,1,1)\n",
370 | " plt.scatter(h[:,0],h[:,1],c=h[:,2],s=markerSize,lw=0)\n",
371 | " plt.title(\"Distance to the target\")\n",
372 | " plt.colorbar()\n",
373 | " plt.subplot(2,1,2)\n",
374 | " plt.scatter(h[:,0],h[:,1],c=h[:,3],s=markerSize,lw=0)\n",
375 | " plt.title(\"Distance to the obstacles\")\n",
376 | " plt.colorbar()"
377 | ]
378 | },
379 | {
380 | "cell_type": "code",
381 | "execution_count": null,
382 | "metadata": {},
383 | "outputs": [],
384 | "source": [
385 | "hcol,hfree = sampleSpace(5000)\n",
386 | "plotConfigurationSpace(hcol,hfree)\n"
387 | ]
388 | },
389 | {
390 | "cell_type": "markdown",
391 | "metadata": {},
392 | "source": [
393 | "You can try to match your representation of the robot's free space with this plot.\n",
394 | "As an example, you can display on this plot a feasible trajectory discovered by a random walk from an initial position."
395 | ]
396 | },
397 | {
398 | "cell_type": "code",
399 | "execution_count": null,
400 | "metadata": {},
401 | "outputs": [],
402 | "source": [
403 | "traj = np.array([])\n",
404 | "qinit = np.array([-1.1, -3. ])"
405 | ]
406 | },
407 | {
408 | "cell_type": "markdown",
409 | "metadata": {},
410 | "source": [
411 | "Here is a solution:"
412 | ]
413 | },
414 | {
415 | "cell_type": "code",
416 | "execution_count": null,
417 | "metadata": {},
418 | "outputs": [],
419 | "source": [
420 | "%do_not_load tp0/generated/simple_path_planning_traj"
421 | ]
422 | },
423 | {
424 | "cell_type": "code",
425 | "execution_count": null,
426 | "metadata": {},
427 | "outputs": [],
428 | "source": [
429 | "# Add your trajectory to the plot. Careful: traj must be a non-empty (N, 2) array!\n",
430 | "plotConfigurationSpace(hcol,hfree)\n",
431 | "plt.plot(traj[:,0],traj[:,1],'r',lw=3)"
432 | ]
433 | },
434 | {
435 | "cell_type": "markdown",
436 | "metadata": {},
437 | "source": [
438 | "## Optimize the distance under non-collision constraint\n",
439 | "Finally, let's use one of the optimizers from SciPy to search for a robot configuration that minimizes the distance to the target, under the constraint that the distance to collision is positive.\n",
440 | "For that, we define a *cost function* $cost: \\mathcal{C} \\to \\mathbb{R}$ (taking the robot configuration and returning a scalar) and a constraint function (taking again the robot configuration and returning a scalar or a vector of scalars that should be positive). We additionally use the \"callback\" functionality of the solver to render the robot configuration corresponding to the current value of the decision variable inside the solver algorithm.\n",
441 | "We use the \"SLSQP\" solver from SciPy, which implements a \"sequential quadratic program\" algorithm and accepts constraints.\n"
442 | ]
443 | },
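444 | {
445 | "cell_type": "markdown",
446 | "metadata": {},
447 | "source": [
448 | "As a reminder of the `fmin_slsqp` call convention, here is a toy problem unrelated to the robot (the numbers are arbitrary; `f_ieqcons` returns an array that the solver keeps nonnegative):"
449 | ]
450 | },
451 | {
452 | "cell_type": "code",
453 | "execution_count": null,
454 | "metadata": {},
455 | "outputs": [],
456 | "source": [
457 | "# Minimize ||x - [1, 2]||^2 subject to 1 - x[0] - x[1] >= 0.\n",
458 | "fmin_slsqp(\n",
459 | "    func=lambda x: np.sum((x - np.array([1., 2.])) ** 2),\n",
460 | "    x0=np.zeros(2),\n",
461 | "    f_ieqcons=lambda x: np.array([1. - x[0] - x[1]]),\n",
462 | ")"
463 | ]
464 | },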
444 | {
445 | "cell_type": "code",
446 | "execution_count": null,
447 | "metadata": {},
448 | "outputs": [],
449 | "source": [
450 | "def cost(q):\n",
451 | " pass\n",
452 | " \n",
453 | "def constraint(q):\n",
454 | " pass\n",
455 | " \n",
456 | "def callback(q):\n",
457 | " '''\n",
458 | "    At each optimization step, display the robot configuration in the viewer.\n",
459 | " '''\n",
460 | " viz.display(q)\n",
461 | " time.sleep(.01)\n",
462 | "\n",
463 | "def optimize():\n",
464 | " '''\n",
465 | " Optimize from an initial random configuration to discover a collision-free\n",
466 | " configuration as close as possible to the target.\n",
467 | " USE fmin_slsqp, see doc online\n",
468 | " '''"
469 | ]
470 | },
471 | {
472 | "cell_type": "markdown",
473 | "metadata": {},
474 | "source": [
475 | "Here is a valid solution:"
476 | ]
477 | },
478 | {
479 | "cell_type": "code",
480 | "execution_count": null,
481 | "metadata": {},
482 | "outputs": [],
483 | "source": [
484 | "%do_not_load tp0/generated/simple_path_planning_optim"
485 | ]
486 | },
487 | {
488 | "cell_type": "markdown",
489 | "metadata": {},
490 | "source": [
491 | "Look at the output of the solver. It always returns a value for the decision variable, but sometimes the algorithm fails, being trapped in an infeasible region. Most of the time, the solver converges to a local minimum where the final distance to the target is nonzero."
492 | ]
493 | },
494 | {
495 | "cell_type": "markdown",
496 | "metadata": {},
497 | "source": [
498 | "Now you can write a planner that tries to optimize and retries until a valid solution is found!"
499 | ]
500 | },
501 | {
502 | "cell_type": "code",
503 | "execution_count": null,
504 | "metadata": {},
505 | "outputs": [],
506 | "source": [
507 | "# Your solution"
508 | ]
509 | },
510 | {
511 | "cell_type": "markdown",
512 | "metadata": {},
513 | "source": [
514 | "And the solution if you need it:"
515 | ]
516 | },
517 | {
518 | "cell_type": "code",
519 | "execution_count": null,
520 | "metadata": {},
521 | "outputs": [],
522 | "source": [
523 | "%do_not_load tp0/generated/simple_path_planning_useit"
524 | ]
525 | },
526 | {
527 | "cell_type": "code",
528 | "execution_count": null,
529 | "metadata": {},
530 | "outputs": [],
531 | "source": []
532 | }
533 | ],
534 | "metadata": {
535 | "kernelspec": {
536 | "display_name": "Python 3 (ipykernel)",
537 | "language": "python",
538 | "name": "python3"
539 | },
540 | "language_info": {
541 | "codemirror_mode": {
542 | "name": "ipython",
543 | "version": 3
544 | },
545 | "file_extension": ".py",
546 | "mimetype": "text/x-python",
547 | "name": "python",
548 | "nbconvert_exporter": "python",
549 | "pygments_lexer": "ipython3",
550 | "version": "3.10.12"
551 | }
552 | },
553 | "nbformat": 4,
554 | "nbformat_minor": 4
555 | }
556 |
--------------------------------------------------------------------------------
/1_motion_planning.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "66575b61-30a5-4471-9e4e-a45c6fc396a9",
6 | "metadata": {},
7 | "source": [
8 | "# Implement RRT and its variant on UR5"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": null,
14 | "id": "b5d79912-1a64-4466-8017-70724567b28c",
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "import example_robot_data as robex\n",
19 | "import hppfcl\n",
20 | "import math\n",
21 | "import numpy as np\n",
22 | "import pinocchio as pin\n",
23 | "import time\n",
24 | "from tqdm import tqdm"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": null,
30 | "id": "8abf514d-0c36-4c32-934f-e1d013c57ea7",
31 | "metadata": {},
32 | "outputs": [],
33 | "source": [
34 | "import matplotlib.pylab as plt; plt.ion()"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": null,
40 | "id": "bf038376-9d69-48e9-9bdb-0bcf3b641247",
41 | "metadata": {},
42 | "outputs": [],
43 | "source": [
44 | "from utils.meshcat_viewer_wrapper import MeshcatVisualizer, colors\n",
45 | "from utils.datastructures.storage import Storage\n",
46 | "from utils.datastructures.pathtree import PathTree\n",
47 | "from utils.datastructures.mtree import MTree\n",
48 | "from utils.collision_wrapper import CollisionWrapper"
49 | ]
50 | },
51 | {
52 | "cell_type": "markdown",
53 | "id": "58561974-a3d2-4692-821c-30f4d0caf96f",
54 | "metadata": {},
55 | "source": [
56 | "## Load UR5"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": null,
62 | "id": "7217ea2c-5cac-430e-8496-93e80e78b859",
63 | "metadata": {},
64 | "outputs": [],
65 | "source": [
66 | "robot = robex.load('ur5')\n",
67 | "collision_model = robot.collision_model\n",
68 | "visual_model = robot.visual_model"
69 | ]
70 | },
71 | {
72 | "cell_type": "markdown",
73 | "id": "98805b5f-f993-4d4a-98ca-c465ff363424",
74 | "metadata": {},
75 | "source": [
76 | "Recall some placements for the UR5:"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": null,
82 | "id": "9f521701-7416-4736-9363-305a45258d14",
83 | "metadata": {},
84 | "outputs": [],
85 | "source": [
86 | "a = robot.placement(robot.q0, 6) # Placement of the end effector joint.\n",
87 | "b = robot.framePlacement(robot.q0, 22) # Placement of the end effector tip.\n",
88 | "\n",
89 | "tool_axis = b.rotation[:, 2] # Axis of the tool\n",
90 | "tool_position = b.translation"
91 | ]
92 | },
93 | {
94 | "cell_type": "code",
95 | "execution_count": null,
96 | "id": "d96626e3-1b22-446f-916a-f7ee54b78f04",
97 | "metadata": {},
98 | "outputs": [],
99 | "source": [
100 | "viz = MeshcatVisualizer(robot)"
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": null,
106 | "id": "e2d1d4ee-c1f4-4350-9370-ed8f32dee4cf",
107 | "metadata": {},
108 | "outputs": [],
109 | "source": [
110 | "viz.viewer.jupyter_cell()"
111 | ]
112 | },
113 | {
114 | "cell_type": "markdown",
115 | "id": "8ede57d4-8b0d-4a69-adbd-95901db503cd",
116 | "metadata": {},
117 | "source": [
118 | "Set a start and a goal configuration"
119 | ]
120 | },
121 | {
122 | "cell_type": "code",
123 | "execution_count": null,
124 | "id": "4b14d4cd-9e44-4e14-9a56-ec08256d0238",
125 | "metadata": {},
126 | "outputs": [],
127 | "source": [
128 | "q_i = np.array([1., -1.5, 2.1, -.5, -.5, 0])\n",
129 | "q_g = np.array([3., -1., 1, -.5, -.5, 0])\n",
130 | "radius = 0.05"
131 | ]
132 | },
133 | {
134 | "cell_type": "code",
135 | "execution_count": null,
136 | "id": "a05303fd-5fa9-47bf-92cd-63d64e2212cf",
137 | "metadata": {},
138 | "outputs": [],
139 | "source": [
140 | "viz.display(q_i)\n",
141 | "M = robot.framePlacement(q_i, 22)\n",
142 | "name = \"world/sph_initial\"\n",
143 | "viz.addSphere(name, radius, [0., 1., 0., 1.])\n",
144 | "viz.applyConfiguration(name,M)"
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": null,
150 | "id": "dbc882d4-b335-4255-9db3-3aa61eeb8ee7",
151 | "metadata": {
152 | "tags": []
153 | },
154 | "outputs": [],
155 | "source": [
156 | "viz.display(q_g)\n",
157 | "M = robot.framePlacement(q_g, 22)\n",
158 | "name = \"world/sph_goal\"\n",
159 | "viz.addSphere(name, radius, [0., 0., 1., 1.])\n",
160 | "viz.applyConfiguration(name,M)"
161 | ]
162 | },
163 | {
164 | "cell_type": "markdown",
165 | "id": "6a4bce9b-0c81-4c7a-bee9-02c9aa59b829",
166 | "metadata": {},
167 | "source": [
168 | "## Implement everything needed for RRT"
169 | ]
170 | },
171 | {
172 | "cell_type": "markdown",
173 | "id": "9d4c45e0-ee85-47b1-88ae-94c3c8fb50d8",
174 | "metadata": {},
175 | "source": [
176 | "We abstract the robot, the environment, and its behaviour in a class called `System`.\n",
177 | "\n",
178 | "It must be able to:\n",
179 | "- generate random configurations which are not colliding, if needed (sampling)\n",
180 | "- implement a distance on the configuration space (distance)\n",
181 | "- generate a path between two configurations (steering)\n",
182 | "- check if a path between two configurations is free and return the last free configuration (directional free steering)\n",
183 | "\n",
184 | "as well as some functions to display configurations.\n",
184 | "\n",
185 | "Recall that in the case of the UR5 the configuration space is $S_1^{6}$, where $S_1$ is the unit circle, which we can parametrize by $\\theta\\in[-\\pi,\\pi]$ such that $-\\pi$ and $\\pi$ are identified.\n",
186 | "\n",
187 | "In the next cell, you must implement the system behaviour for the UR5."
188 | ]
189 | },
190 | {
191 | "cell_type": "code",
192 | "execution_count": null,
193 | "id": "b74e5455-eae8-49fe-9030-f403467df288",
194 | "metadata": {},
195 | "outputs": [],
196 | "source": [
197 | "class System():\n",
198 | "\n",
199 | " def __init__(self, robot):\n",
200 | " self.robot = robot\n",
201 | " robot.gmodel = robot.collision_model\n",
202 | " self.display_edge_count = 0\n",
203 | " self.colwrap = CollisionWrapper(robot) # For collision checking\n",
204 | " self.nq = self.robot.nq\n",
205 | " self.display_count = 0\n",
206 | " \n",
207 | " def distance(self, q1, q2):\n",
208 | " \"\"\"\n",
209 | " Must return a distance between q1 and q2 which can be a batch of config.\n",
210 | " \"\"\"\n",
211 | " if len(q2.shape) > len(q1.shape):\n",
212 | " q1 = q1[None, ...]\n",
213 | " e = np.mod(np.abs(q1 - q2), 2 * np.pi)\n",
214 | " e[e > np.pi] = 2 * np.pi - e[e > np.pi]\n",
215 | " return np.linalg.norm(e, axis=-1)\n",
216 | "\n",
217 | " def random_config(self, free=True):\n",
218 | " \"\"\"\n",
219 | " Must return a random configuration which is not in collision if free=True\n",
220 | " \"\"\"\n",
221 | " q = 2 * np.pi * np.random.rand(6) - np.pi\n",
222 | " if not free:\n",
223 | " return q\n",
224 | " while self.is_colliding(q):\n",
225 | " q = 2 * np.pi * np.random.rand(6) - np.pi\n",
226 | " return q\n",
227 | "\n",
228 | " def is_colliding(self, q):\n",
229 | " \"\"\"\n",
230 | " Use CollisionWrapper to decide if a configuration is in collision\n",
231 | " \"\"\"\n",
232 | " self.colwrap.computeCollisions(q)\n",
233 | " collisions = self.colwrap.getCollisionList()\n",
234 | " return (len(collisions) > 0)\n",
235 | "\n",
236 | " def get_path(self, q1, q2, l_min=None, l_max=None, eps=0.2):\n",
237 | " \"\"\"\n",
238 | "        Generate a continuous path with precision eps between q1 and q2.\n",
239 | "        If l_min or l_max is given, extrapolate or cut the path so that\n",
240 | "        its length is clipped to [l_min, l_max].\n",
241 | " \"\"\"\n",
242 | " q1 = np.mod(q1 + np.pi, 2 * np.pi) - np.pi\n",
243 | " q2 = np.mod(q2 + np.pi, 2 * np.pi) - np.pi\n",
244 | "\n",
245 | " diff = q2 - q1\n",
246 | " query = np.abs(diff) > np.pi\n",
247 | " q2[query] = q2[query] - np.sign(diff[query]) * 2 * np.pi\n",
248 | "\n",
249 | " d = self.distance(q1, q2)\n",
250 | " if d < eps:\n",
251 | " return np.stack([q1, q2], axis=0)\n",
252 | " \n",
253 | " if l_min is not None or l_max is not None:\n",
254 | " new_d = np.clip(d, l_min, l_max)\n",
255 | " else:\n",
256 | " new_d = d\n",
257 | " \n",
258 | " N = int(new_d / eps + 2)\n",
259 | "\n",
260 | " return np.linspace(q1, q1 + (q2 - q1) * new_d / d, N)\n",
261 | " \n",
262 | " def is_free_path(self, q1, q2, l_min=0.2, l_max=1., eps=0.2):\n",
263 | " \"\"\"\n",
264 | "        Create a path and check collisions to return the last\n",
265 | "        non-colliding configuration. Return X, q where X is a boolean\n",
266 | "        stating whether the steering has worked.\n",
267 | "        At least a length l_min must be covered without collision to validate the path.\n",
268 | " \"\"\"\n",
269 | " q_path = self.get_path(q1, q2, l_min, l_max, eps)\n",
270 | " N = len(q_path)\n",
271 | " N_min = N - 1 if l_min is None else min(N - 1, int(l_min / eps))\n",
272 | " for i in range(N):\n",
273 | " if self.is_colliding(q_path[i]):\n",
274 | " break\n",
275 | " if i < N_min:\n",
276 | " return False, None\n",
277 | " if i == N - 1:\n",
278 | " return True, q_path[-1]\n",
279 | " return True, q_path[i - 1]\n",
280 | "\n",
281 | " def reset(self):\n",
282 | " \"\"\"\n",
283 | " Reset the system visualization\n",
284 | " \"\"\"\n",
285 | " for i in range(self.display_count):\n",
286 | " viz.delete(f\"world/sph{i}\")\n",
287 | " viz.delete(f\"world/cil{i}\")\n",
288 | " self.display_count = 0\n",
289 | " \n",
290 | " def display_edge(self, q1, q2, radius=0.01, color=[1.,0.,0.,1]):\n",
291 | " M1 = self.robot.framePlacement(q1, 22) # Placement of the end effector tip.\n",
292 | " M2 = self.robot.framePlacement(q2, 22) # Placement of the end effector tip.\n",
293 | " middle = .5 * (M1.translation + M2.translation)\n",
294 | " direction = M2.translation - M1.translation\n",
295 | " length = np.linalg.norm(direction)\n",
296 | " dire = direction / length\n",
297 | " orth = np.cross(dire, np.array([0, 0, 1]))\n",
298 | " orth2 = np.cross(dire, orth)\n",
299 | " Mcyl = pin.SE3(np.stack([orth2, dire, orth], axis=1), middle)\n",
300 | " name = f\"world/sph{self.display_count}\"\n",
301 | " viz.addSphere(name, radius, [1.,0.,0.,1])\n",
302 | " viz.applyConfiguration(name,M2)\n",
303 | " name = f\"world/cil{self.display_count}\"\n",
304 | " viz.addCylinder(name, length, radius / 4, [0., 1., 0., 1])\n",
305 | " viz.applyConfiguration(name,Mcyl)\n",
306 | " self.display_count +=1\n",
307 | " \n",
308 | " def display_motion(self, qs, step=1e-1):\n",
309 | " # Given a point path display the smooth movement\n",
310 | " for i in range(len(qs) - 1):\n",
311 | " for q in self.get_path(qs[i], qs[i+1])[:-1]:\n",
312 | " viz.display(q)\n",
313 | " time.sleep(step)\n",
314 | " viz.display(qs[-1])\n"
315 | ]
316 | },
317 | {
318 | "cell_type": "code",
319 | "execution_count": null,
320 | "id": "51178d60-0474-4d57-9336-d8f1617b8a4a",
321 | "metadata": {},
322 | "outputs": [],
323 | "source": [
324 | "system = System(robot)"
325 | ]
326 | },
327 | {
328 | "cell_type": "code",
329 | "execution_count": null,
330 | "id": "4b0ced1c-2211-459a-aeb3-fb9c3b25085d",
331 | "metadata": {},
332 | "outputs": [],
333 | "source": [
334 | "system.distance(q_i, q_g)"
335 | ]
336 | },
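337 | {
338 | "cell_type": "markdown",
339 | "id": "2f6a1c0e-7b3d-4a5e-9c8f-aa01bb02cc03",
340 | "metadata": {},
341 | "source": [
342 | "A quick sanity check of the wraparound metric: since $-\\pi$ and $\\pi$ are identified, the two configurations below should be at (numerically) zero distance, not $2\\pi\\sqrt{6}$. (The `1e-6` offsets are arbitrary.)"
343 | ]
344 | },
345 | {
346 | "cell_type": "code",
347 | "execution_count": null,
348 | "id": "3a7b2d1f-8c4e-4b6f-a0d9-bb12cc23dd34",
349 | "metadata": {},
350 | "outputs": [],
351 | "source": [
352 | "# Two configurations on either side of the pi/-pi seam: nearly identical on\n",
353 | "# the torus, even though their coordinates differ by almost 2*pi.\n",
354 | "qa = np.full(6, np.pi - 1e-6)\n",
355 | "qb = np.full(6, -np.pi + 1e-6)\n",
356 | "system.distance(qa, qb)"
357 | ]
358 | },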
337 | {
338 | "cell_type": "code",
339 | "execution_count": null,
340 | "id": "9fafd71f-87f0-44b3-a92c-48f53b6580ac",
341 | "metadata": {},
342 | "outputs": [],
343 | "source": [
344 | "system.display_motion(system.get_path(q_i, q_g))"
345 | ]
346 | },
347 | {
348 | "cell_type": "markdown",
349 | "id": "914d0701-bc8f-46c0-8888-8e019f15cfc3",
350 | "metadata": {},
351 | "source": [
352 | "## RRT implementation"
353 | ]
354 | },
355 | {
356 | "cell_type": "markdown",
357 | "id": "3a64cb4c-28b0-4e66-9324-1cba47374186",
358 | "metadata": {},
359 | "source": [
360 | "In its simplest form, RRT constructs a tree from the start configuration, possibly with a bias toward the goal. In the following class, we add some memoization to avoid recomputing distances. The kNN (k nearest neighbors) structure works on node indices."
361 | ]
362 | },
363 | {
364 | "cell_type": "markdown",
365 | "id": "3b1a4b31-07d2-463f-851e-cd4d2ca09bbf",
366 | "metadata": {},
367 | "source": [
368 | "Let us look at an implementation of the core algorithm:"
369 | ]
370 | },
371 | {
372 | "cell_type": "code",
373 | "execution_count": null,
374 | "id": "a62a504d-9ac1-4bcc-a149-cd6516f90a0e",
375 | "metadata": {},
376 | "outputs": [],
377 | "source": [
378 | "class RRT():\n",
379 | " \"\"\"\n",
380 | "    Could be split into a base RRT class, as different RRT\n",
381 | "    variants share common logic.\n",
382 | " \"\"\"\n",
383 | " def __init__(\n",
384 | " self,\n",
385 | " system,\n",
386 | " node_max=500000,\n",
387 | " iter_max=1000000,\n",
388 | " N_bias=10,\n",
389 | " l_min=.2,\n",
390 | " l_max=.5,\n",
391 | " steer_delta=.1,\n",
392 | " ):\n",
393 | " \"\"\"\n",
394 | " [Here, in proper code, we would document the parameters of our function. Do that below,\n",
395 | " using the Google style for docstrings.]\n",
396 | " https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html\n",
397 | "\n",
398 | " Args:\n",
399 | " node_max: ...\n",
400 | " iter_max: ...\n",
401 | " ...\n",
402 | " \"\"\"\n",
403 | " self.system = system\n",
404 | " # params\n",
405 | " self.l_max = l_max\n",
406 | " self.l_min = l_min\n",
407 | " self.N_bias = N_bias\n",
408 | " self.node_max = node_max\n",
409 | " self.iter_max = iter_max\n",
410 | " self.steer_delta = steer_delta\n",
411 | " # intern\n",
412 | " self.NNtree = None\n",
413 | " self.storage = None\n",
414 | " self.pathtree = None\n",
415 | " # The distance function will be called on N, dim object\n",
416 | " self.real_distance = self.system.distance\n",
417 | " # Internal for computational_opti in calculating distance\n",
418 | " self._candidate = None\n",
419 | " self._goal = None\n",
420 | " self._cached_dist_to_candidate = {}\n",
421 | " self._cached_dist_to_goal = {}\n",
422 | "\n",
423 | " def distance(self, q1_idx, q2_idx):\n",
424 | " if isinstance(q2_idx, int):\n",
425 | " if q1_idx == q2_idx:\n",
426 | " return 0.\n",
427 | " if q1_idx == -1 or q2_idx == -1:\n",
428 | " if q2_idx == -1:\n",
429 | " q1_idx, q2_idx = q2_idx, q1_idx\n",
430 | " if q2_idx not in self._cached_dist_to_candidate:\n",
431 | " self._cached_dist_to_candidate[q2_idx] = self.real_distance(\n",
432 | " self._candidate, self.storage[q2_idx]\n",
433 | " )\n",
434 | " return self._cached_dist_to_candidate[q2_idx]\n",
435 | " if q1_idx == -2 or q2_idx == -2:\n",
436 | " if q2_idx == -2:\n",
437 | " q1_idx, q2_idx = q2_idx, q1_idx\n",
438 | " if q2_idx not in self._cached_dist_to_goal:\n",
439 | " self._cached_dist_to_goal[q2_idx] = self.real_distance(\n",
440 | " self._goal, self.storage[q2_idx]\n",
441 | " )\n",
442 | " return self._cached_dist_to_goal[q2_idx]\n",
443 | " return self.real_distance(self.storage[q1_idx], self.storage[q2_idx])\n",
444 | " if q1_idx == -1:\n",
445 | " q = self._candidate\n",
446 | " elif q1_idx == -2:\n",
447 | " q = self._goal\n",
448 | " else:\n",
449 | " q = self.storage[q1_idx]\n",
450 | " return self.real_distance(q, self.storage[q2_idx])\n",
451 | "\n",
452 | " def new_candidate(self):\n",
453 | " q = self.system.random_config(free=True)\n",
454 | " self._candidate = q\n",
455 | " self._cached_dist_to_candidate = {}\n",
456 | " return q\n",
457 | "\n",
458 | " def solve(self, qi, validate, qg=None):\n",
459 | " self.system.reset()\n",
460 | " self._goal = qg\n",
461 | " \n",
462 | " # Reset internal datastructures\n",
463 | " self.storage = Storage(self.node_max, self.system.nq)\n",
464 | " self.pathtree = PathTree(self.storage)\n",
465 | " self.NNtree = MTree(self.distance)\n",
466 | " qi_idx = self.storage.add_point(qi)\n",
467 | " self.NNtree.add_point(qi_idx)\n",
468 | " self.it_trace = []\n",
469 | "\n",
470 | " found = False\n",
471 | " iterator = range(self.iter_max)\n",
472 | " for i in tqdm(iterator):\n",
473 | " # New candidate\n",
474 | " if i % self.N_bias == 0:\n",
475 | " q_new = self._goal\n",
476 | " q_new_idx = -2\n",
477 | " else:\n",
478 | " q_new = self.new_candidate()\n",
479 | " q_new_idx = -1\n",
480 | "\n",
481 | "        # Find the closest neighbour to q_new\n",
482 | " q_near_idx, d = self.NNtree.nearest_neighbour(q_new_idx)\n",
483 | " \n",
484 | "        # Steer from it toward the new configuration, checking for collision\n",
485 | " success, q_prox = self.system.is_free_path(\n",
486 | " self.storage.data[q_near_idx],\n",
487 | " q_new,\n",
488 | " l_min=self.l_min,\n",
489 | " l_max=self.l_max,\n",
490 | " eps=self.steer_delta\n",
491 | " )\n",
492 | "\n",
493 | " if not success:\n",
494 | " self.it_trace.append(0)\n",
495 | " continue\n",
496 | " self.it_trace.append(1)\n",
497 | " \n",
498 | " # Add the points in data structures\n",
499 | " q_prox_idx = self.storage.add_point(q_prox)\n",
500 | " self.NNtree.add_point(q_prox_idx)\n",
501 | " self.pathtree.update_link(q_prox_idx, q_near_idx)\n",
502 | " self.system.display_edge(self.storage[q_near_idx], self.storage[q_prox_idx])\n",
503 | "\n",
504 | "        # Test if it reaches the goal\n",
505 | " if validate(q_prox):\n",
506 | " q_g_idx = self.storage.add_point(q_prox)\n",
507 | " self.NNtree.add_point(q_g_idx)\n",
508 | " self.pathtree.update_link(q_g_idx, q_prox_idx)\n",
509 | " found = True\n",
510 | " break\n",
511 | " self.iter_done = i + 1\n",
512 | " self.found = found\n",
513 | " return found\n",
514 | "\n",
515 | " def get_path(self, q_g):\n",
516 | " assert self.found\n",
517 | " path = self.pathtree.get_path()\n",
518 | " return np.concatenate([path, q_g[None, :]])\n"
519 | ]
520 | },
521 | {
522 | "cell_type": "markdown",
523 | "id": "73533b79-9692-4856-8094-b4874e3944ef",
524 | "metadata": {},
525 | "source": [
526 | "In proper code, we would document the parameters of our functions.\n",
527 | "\n",
528 | "- **Your turn:** Add docstrings to the code above, following the [Google style](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html).\n",
529 | "- Optional: you are welcome to add type annotations if you'd like.\n",
530 | "\n",
531 | "The constructor of the `RRT` class invites you to start."
532 | ]
533 | },
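534 | {
535 | "cell_type": "markdown",
536 | "id": "4c8d3e2a-9d5f-4c70-b1ea-cc23dd34ee45",
537 | "metadata": {},
538 | "source": [
539 | "For reference, here is what a Google-style docstring looks like on a toy function (this function is only an illustration, not part of the RRT code):"
540 | ]
541 | },
542 | {
543 | "cell_type": "code",
544 | "execution_count": null,
545 | "id": "5d9e4f3b-ae60-4d81-82fb-dd34ee45ff56",
546 | "metadata": {},
547 | "outputs": [],
548 | "source": [
549 | "def clip_length(length: float, l_min: float = 0.2, l_max: float = 0.5) -> float:\n",
550 | "    \"\"\"Clip a path length to the admissible steering interval.\n",
551 | "\n",
552 | "    Args:\n",
553 | "        length: Length of the candidate path.\n",
554 | "        l_min: Minimum steering length. Defaults to 0.2.\n",
555 | "        l_max: Maximum steering length. Defaults to 0.5.\n",
556 | "\n",
557 | "    Returns:\n",
558 | "        The length clipped to the interval [l_min, l_max].\n",
559 | "    \"\"\"\n",
560 | "    return min(max(length, l_min), l_max)"
561 | ]
562 | },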
534 | {
535 | "cell_type": "markdown",
536 | "id": "811cc5ae-b601-477d-9d55-e560d0e45262",
537 | "metadata": {},
538 | "source": [
539 | "For this problem, we will instantiate our RRT with the following parameters:"
540 | ]
541 | },
542 | {
543 | "cell_type": "code",
544 | "execution_count": null,
545 | "id": "676b991b-32d1-4e2f-8dd9-ed85252e180c",
546 | "metadata": {},
547 | "outputs": [],
548 | "source": [
549 | "rrt = RRT(\n",
550 | " system,\n",
551 | " N_bias=20,\n",
552 | " l_min=0.2,\n",
553 | " l_max=0.5,\n",
554 | " steer_delta=0.1,\n",
555 | ")"
556 | ]
557 | },
558 | {
559 | "cell_type": "markdown",
560 | "id": "f20ecc74-a2b6-4b91-8131-cbb7ff3e13a9",
561 | "metadata": {},
562 | "source": [
563 | "Now let's define our termination condition, and run the main function:"
564 | ]
565 | },
566 | {
567 | "cell_type": "code",
568 | "execution_count": null,
569 | "id": "de20f3cc-6df3-4e5c-ab5b-e183616d4a2e",
570 | "metadata": {},
571 | "outputs": [],
572 | "source": [
573 | "eps_final = .1\n",
574 | "def validation(key):\n",
575 | " vec = robot.framePlacement(key, 22).translation - robot.framePlacement(q_g, 22).translation\n",
576 | " return (float(np.linalg.norm(vec)) < eps_final)\n",
577 | "\n",
578 | "rrt.solve(q_i, validation, qg=q_g)"
579 | ]
580 | },
581 | {
582 | "cell_type": "code",
583 | "execution_count": null,
584 | "id": "7f3a6a8d-4fd9-419b-8176-214543c08a22",
585 | "metadata": {},
586 | "outputs": [],
587 | "source": [
588 | "system.display_motion(rrt.get_path(q_g))"
589 | ]
590 | },
591 | {
592 | "cell_type": "code",
593 | "execution_count": null,
594 | "id": "d82012a2-c7c4-4362-85ae-2c8e67a2e2a4",
595 | "metadata": {},
596 | "outputs": [],
597 | "source": [
598 | "system.reset()"
599 | ]
600 | },
601 | {
602 | "cell_type": "markdown",
603 | "id": "3bd7d92a-abea-482b-8a91-a5942a125585",
604 | "metadata": {},
605 | "source": [
606 | "## Create an environment with obstacles"
607 | ]
608 | },
609 | {
610 | "cell_type": "markdown",
611 | "id": "8784c9ad-ab66-4440-9a0f-1a1e10ab4b2d",
612 | "metadata": {},
613 | "source": [
614 | "So far, our algorithms found paths in an easy setting, *i.e.* without obstacles. Let us now add some obstacles to the environment:"
615 | ]
616 | },
617 | {
618 | "cell_type": "code",
619 | "execution_count": null,
620 | "id": "5efbc0a8-bd55-4122-b09a-168b64f8de19",
621 | "metadata": {},
622 | "outputs": [],
623 | "source": [
624 | "robot = robex.load('ur5')\n",
625 | "collision_model = robot.collision_model\n",
626 | "visual_model = robot.visual_model"
627 | ]
628 | },
629 | {
630 | "cell_type": "code",
631 | "execution_count": null,
632 | "id": "189570f8-d592-4f6d-b3d7-89a466fdf898",
633 | "metadata": {},
634 | "outputs": [],
635 | "source": [
636 | "def addCylinderToUniverse(name, radius, length, placement, color=colors.red):\n",
637 | " geom = pin.GeometryObject(\n",
638 | " name,\n",
639 | " 0,\n",
640 | " hppfcl.Cylinder(radius, length),\n",
641 | " placement\n",
642 | " )\n",
643 | " new_id = collision_model.addGeometryObject(geom)\n",
644 | " geom.meshColor = np.array(color)\n",
645 | " visual_model.addGeometryObject(geom)\n",
646 | " \n",
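647 | "    # Register a collision pair between each robot link geometry and the\n",
648 | "    # new obstacle (nq is used here as a proxy for the number of link geometries).\n",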
647 | " for link_id in range(robot.model.nq):\n",
648 | " collision_model.addCollisionPair(\n",
649 | " pin.CollisionPair(link_id, new_id)\n",
650 | " )\n",
651 | " return geom"
652 | ]
653 | },
654 | {
655 | "cell_type": "code",
656 | "execution_count": null,
657 | "id": "4c0b531f-992c-47cc-b101-d1113a2f0870",
658 | "metadata": {},
659 | "outputs": [],
660 | "source": [
661 | "from pinocchio.utils import rotate\n",
662 | "\n",
663 | "[collision_model.removeGeometryObject(e.name) for e in collision_model.geometryObjects if e.name.startswith('world/')]\n",
664 | "\n",
665 | "# Add four cylinder obstacles to the collision and visual models\n",
666 | "radius = 0.1\n",
667 | "length = 1.\n",
668 | "\n",
669 | "for i, y in enumerate([0.4, -0.4, 0.7, -0.7], start=1):\n",
670 | "    cylID = f\"world/cyl{i}\"\n",
671 | "    placement = pin.SE3(rotate('z', np.pi / 2), np.array([-0.5, y, 0.5]))\n",
672 | "    addCylinderToUniverse(cylID, radius, length, placement, color=[.7, .7, 0.98, 1])"
686 | ]
687 | },
688 | {
689 | "cell_type": "code",
690 | "execution_count": null,
691 | "id": "6780c139-e2de-4c9f-8af4-61cb1e7e1b4e",
692 | "metadata": {},
693 | "outputs": [],
694 | "source": [
695 | "q_i = np.array([-1., -1.5, 2.1, -.5, -.5, 0])\n",
696 | "q_g = np.array([3.1, -1., 1, -.5, -.5, 0])\n",
697 | "radius = 0.05"
698 | ]
699 | },
700 | {
701 | "cell_type": "markdown",
702 | "id": "c114f554-6126-47e9-8749-934d47d2c7c1",
703 | "metadata": {},
704 | "source": [
705 | "We need to reload the viewer:"
706 | ]
707 | },
708 | {
709 | "cell_type": "code",
710 | "execution_count": null,
711 | "id": "62ab0ec0-5faf-43fc-bcd8-9ad9cbe3ad56",
712 | "metadata": {},
713 | "outputs": [],
714 | "source": [
715 | "viz = MeshcatVisualizer(robot)"
716 | ]
717 | },
718 | {
719 | "cell_type": "code",
720 | "execution_count": null,
721 | "id": "037ac2f5-e0d3-438a-ab10-9f313a8f8803",
722 | "metadata": {},
723 | "outputs": [],
724 | "source": [
725 | "viz.display(q_i)\n",
726 | "M = robot.framePlacement(q_i, 22)\n",
727 | "name = \"world/sph_initial\"\n",
728 | "viz.addSphere(name, radius, [0., 1., 0., 1.])\n",
729 | "viz.applyConfiguration(name,M)"
730 | ]
731 | },
732 | {
733 | "cell_type": "code",
734 | "execution_count": null,
735 | "id": "5e5dfe04-e29d-4455-ba00-c7c158a8930c",
736 | "metadata": {
737 | "tags": []
738 | },
739 | "outputs": [],
740 | "source": [
741 | "viz.display(q_g)\n",
742 | "M = robot.framePlacement(q_g, 22)\n",
743 | "name = \"world/sph_goal\"\n",
744 | "viz.addSphere(name, radius, [0., 0., 1., 1.])\n",
745 | "viz.applyConfiguration(name,M)"
746 | ]
747 | },
748 | {
749 | "cell_type": "code",
750 | "execution_count": null,
751 | "id": "6f17610e-e874-4e4f-aaf7-6c6d569f965e",
752 | "metadata": {},
753 | "outputs": [],
754 | "source": [
755 | "viz.display(q_g)"
756 | ]
757 | },
758 | {
759 | "cell_type": "code",
760 | "execution_count": null,
761 | "id": "e19068e5-033b-412d-8d64-ffef4def949a",
762 | "metadata": {},
763 | "outputs": [],
764 | "source": [
765 | "system = System(robot)"
766 | ]
767 | },
768 | {
769 | "cell_type": "code",
770 | "execution_count": null,
771 | "id": "661a6589-36ff-44c5-b761-c6f588a8eab7",
772 | "metadata": {},
773 | "outputs": [],
774 | "source": [
775 | "rrt = RRT(\n",
776 | " system,\n",
777 | " N_bias=20,\n",
778 | " l_min=0.2,\n",
779 | " l_max=0.5,\n",
780 | " steer_delta=0.1,\n",
781 | ")"
782 | ]
783 | },
784 | {
785 | "cell_type": "code",
786 | "execution_count": null,
787 | "id": "ffa971ee-6396-4533-b875-5c4e7124bc7a",
788 | "metadata": {},
789 | "outputs": [],
790 | "source": [
791 | "eps_final = .1\n",
792 | "\n",
793 | "def validation(key):\n",
794 | " vec = robot.framePlacement(key, 22).translation - robot.framePlacement(q_g, 22).translation\n",
795 | " return (float(np.linalg.norm(vec)) < eps_final)\n",
796 | "\n",
797 | "rrt.solve(q_i, validation, qg=q_g)"
798 | ]
799 | },
800 | {
801 | "cell_type": "code",
802 | "execution_count": null,
803 | "id": "cf22611e-ba2d-4c0e-93a8-03ff922b6073",
804 | "metadata": {},
805 | "outputs": [],
806 | "source": [
807 | "system.display_motion(rrt.get_path(q_g))"
808 | ]
809 | },
810 | {
811 | "cell_type": "markdown",
812 | "id": "160cf161-755f-430e-8e91-e926f644fbe2",
813 | "metadata": {},
814 | "source": [
815 | "And solve with RRT again. It takes quite long, right? Let us implement a more efficient algorithm."
816 | ]
817 | },
818 | {
819 | "cell_type": "markdown",
820 | "id": "c90f1487-cd74-4eb8-8f3f-a0f5e65a1aad",
821 | "metadata": {},
822 | "source": [
823 | "## Bi-RRT"
824 | ]
825 | },
826 | {
827 | "cell_type": "markdown",
828 | "id": "d526af5b-9c37-425f-a30c-45bb80588404",
829 | "metadata": {},
830 | "source": [
831 | "Now it's your turn. Make a `BiRRT` class, similar to the `RRT` class above, but implementing the Bi-RRT algorithm. (It is not recommended to try to inherit from `RRT`, as you will end up re-implementing most functions.) Here is a template you are free to adapt, with some advice:"
832 | ]
833 | },
834 | {
835 | "cell_type": "code",
836 | "execution_count": null,
837 | "id": "c003e3fe-a042-4e3a-b644-f607acb9faef",
838 | "metadata": {},
839 | "outputs": [],
840 | "source": [
841 | "class BiRRT:\n",
842 | " def __init__(\n",
843 | " self,\n",
844 | " system,\n",
845 | " node_max=500000,\n",
846 | " iter_max=1000000,\n",
847 | " l_min=.2,\n",
848 | " l_max=.5,\n",
849 | " steer_delta=.1,\n",
850 | " ):\n",
851 | " # Initialize attributes:\n",
852 | " # self.l_min = l_min\n",
853 | " # etc.\n",
854 | "\n",
855 | " # New: duplicate this attribute as dictionaries with two keys:\n",
856 | " # \"forward\" and \"backward\". See `solve` below.\n",
857 | " self._cached_dist_to_candidate = {}\n",
858 | " self.storage = {}\n",
859 | " self.pathtree = {}\n",
860 | " self.tree = {}\n",
861 | "\n",
862 | " def tree_distance(self, direction: str, q1_idx, q2_idx):\n",
863 | " # Adapt from RRT.distance\n",
864 | "        # There is now a direction string to select the underlying tree,\n",
865 | "        # either \"forward\" (from q_init) or \"backward\" (from q_goal).\n",
866 | "        ...  # your code here\n",
866 | "\n",
867 | " def forward_distance(self, q1_idx, q2_idx):\n",
868 | " return self.tree_distance(\"forward\", q1_idx, q2_idx)\n",
869 | "\n",
870 | " def backward_distance(self, q1_idx, q2_idx):\n",
871 | " return self.tree_distance(\"backward\", q1_idx, q2_idx)\n",
872 | "\n",
873 | " def new_candidate(self):\n",
874 | "        # A minor change is required to adapt RRT.new_candidate to this template.\n",
875 | "        ...  # your code here\n",
875 | "\n",
876 | " def solve(self, qi, qg):\n",
877 | " # Reset internal datastructures\n",
878 | " for direction in (\"forward\", \"backward\"):\n",
879 | " self._cached_dist_to_candidate[direction] = {}\n",
880 | "            self.storage[direction] = Storage(self.node_max, self.system.nq)\n",
881 | " self.pathtree[direction] = PathTree(self.storage[direction])\n",
882 | " self.tree = {\n",
883 | " \"forward\": MTree(self.forward_distance),\n",
884 | " \"backward\": MTree(self.backward_distance),\n",
885 | " }\n",
886 | "\n",
887 | " # Now datastructures are initialized\n",
888 | " # The rest is up to you! \n",
889 | "\n",
890 | " def get_path(self):\n",
891 | " assert self.found\n",
892 | " forward_path = self.pathtree[\"forward\"].get_path()\n",
893 | " backward_path = self.pathtree[\"backward\"].get_path()\n",
894 | " return np.concatenate([forward_path, backward_path[::-1]])"
895 | ]
896 | },
897 | {
898 | "cell_type": "markdown",
899 | "id": "5254b62c-673f-407d-a0e6-effd3e75aabb",
900 | "metadata": {},
901 | "source": [
902 | "You should be able to call `BiRRT` similarly to `RRT`:"
903 | ]
904 | },
905 | {
906 | "cell_type": "code",
907 | "execution_count": null,
908 | "id": "60443eb2-ea54-4c0b-8d24-1908a4ca710b",
909 | "metadata": {},
910 | "outputs": [],
911 | "source": [
912 | "system.reset()\n",
913 | "\n",
914 | "birrt = BiRRT(\n",
915 | " system,\n",
916 | " l_min=0.2,\n",
917 | " l_max=0.5,\n",
918 | " steer_delta=0.1,\n",
919 | ")\n",
920 | "\n",
921 | "birrt.solve(q_i, q_g)"
922 | ]
923 | },
924 | {
925 | "cell_type": "code",
926 | "execution_count": null,
927 | "id": "b62f8275-b4d7-4168-ad6c-e437d0ea2887",
928 | "metadata": {},
929 | "outputs": [],
930 | "source": [
931 | "system.display_motion(birrt.get_path())"
932 | ]
933 | },
934 | {
935 | "cell_type": "markdown",
936 | "id": "a291e25d-46a4-47b3-b894-8d5810196d39",
937 | "metadata": {},
938 | "source": [
939 | "How many iterations did it take to find a solution? Is it faster than previously with `RRT`?"
940 | ]
941 | },
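942 | {
943 | "cell_type": "markdown",
944 | "id": "0b1c2d3e-4f50-4617-8a9b-0c1d2e3f4a5b",
945 | "metadata": {},
946 | "source": [
947 | "One way to check (this assumes your `BiRRT` records an `iter_done` counter, as `RRT.solve` does):"
948 | ]
949 | },
950 | {
951 | "cell_type": "code",
952 | "execution_count": null,
953 | "id": "1c2d3e4f-5a6b-4728-9bac-1d2e3f4a5b6c",
954 | "metadata": {},
955 | "outputs": [],
956 | "source": [
957 | "print(f\"RRT iterations:    {rrt.iter_done}\")\n",
958 | "print(f\"Bi-RRT iterations: {birrt.iter_done}\")"
959 | ]
960 | },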
942 | {
943 | "cell_type": "markdown",
944 | "id": "84285a1d-ac59-40c9-aef1-23527d0bab54",
945 | "metadata": {},
946 | "source": [
947 | "## Bonus question: Bi-RRT*"
948 | ]
949 | },
950 | {
951 | "cell_type": "markdown",
952 | "id": "6212a5aa-5794-48c6-91e5-f70c3ade34fa",
953 | "metadata": {},
954 | "source": [
955 | "Implement an optimal variant `BiRRTStar` of your `BiRRT` class and run it in the same configuration as the two algorithms above. What do you notice about the resulting tree? What is the improvement in overall path length between `RRT`, `BiRRT` and `BiRRTStar`?"
956 | ]
957 | }
958 | ],
959 | "metadata": {
960 | "kernelspec": {
961 | "display_name": "Python 3 (ipykernel)",
962 | "language": "python",
963 | "name": "python3"
964 | },
965 | "language_info": {
966 | "codemirror_mode": {
967 | "name": "ipython",
968 | "version": 3
969 | },
970 | "file_extension": ".py",
971 | "mimetype": "text/x-python",
972 | "name": "python",
973 | "nbconvert_exporter": "python",
974 | "pygments_lexer": "ipython3",
975 | "version": "3.10.12"
976 | }
977 | },
978 | "nbformat": 4,
979 | "nbformat_minor": 5
980 | }
981 |
--------------------------------------------------------------------------------
/3_reinforcement_learning.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "589389e6-2aad-45b6-bc03-e9b37ab1514f",
6 | "metadata": {},
7 | "source": [
8 | "# Reinforcement learning for legged robots\n",
9 | "\n",
10 | "## Setup\n",
11 | "\n",
12 | "Before we start, you will need to update your conda environment to use Gymnasium (maintained) rather than OpenAI Gym (discontinued). You can simply run:\n",
13 | "\n",
14 | "```\n",
15 | "conda activate robotics-mva\n",
16 | "conda install -c conda-forge gymnasium imageio mujoco=2.3.7 stable-baselines3 tensorboard\n",
17 | "```\n",
18 | "\n",
19 | "Import Gymnasium and Stable Baselines3 to check that everything is working:"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": null,
25 | "id": "b3fa7b15-843f-4696-9bfc-141da71bf7d1",
26 | "metadata": {},
27 | "outputs": [],
28 | "source": [
29 | "import gymnasium as gym\n",
30 | "import stable_baselines3"
31 | ]
32 | },
33 | {
34 | "cell_type": "markdown",
35 | "id": "8e264559-88cc-48b7-8257-f7755fff3ce7",
36 | "metadata": {},
37 | "source": [
38 | "Let's import the usual suspects as well:"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": null,
44 | "id": "86303cf2-f879-407d-b528-6c0a80b8df20",
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "import matplotlib.pylab as plt\n",
49 | "import numpy as np\n",
50 | "\n",
51 | "plt.ion()"
52 | ]
53 | },
54 | {
55 | "cell_type": "markdown",
56 | "id": "b09e16a1-bf4a-403b-8709-3da11bc3c4b4",
57 | "metadata": {},
58 | "source": [
59 | "# Inverted pendulum environment\n",
60 | "\n",
61 | "The inverted pendulum model is not just a toy model reproducing the properties of real robot models for balancing: as it turns out, the inverted pendulum appears in the dynamics of *any* mobile robot, that is, a model with a floating-base joint at the root of the kinematic tree. (If you are curious: the inverted pendulum is a limit case of the [Newton-Euler equations](https://scaron.info/robotics/newton-euler-equations.html) corresponding to floating-base coordinates in the equations of motion $M \\ddot{q} + h = S^T \\tau + J_c^T f$, in the limit where the robot [does not vary its angular momentum](https://scaron.info/robotics/point-mass-model.html).) Thus, while we work on a simplified inverted pendulum in this notebook, concepts and tools are those used as-is on real robots, as you can verify by exploring the bonus section.\n",
62 | "\n",
63 | "Gymnasium is mainly a single-agent reinforcement learning API, but it also comes with simple environments, including an inverted pendulum sliding on a linear guide:"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": null,
69 | "id": "11b5d942-85fa-435e-b5ef-8c85a74ba3db",
70 | "metadata": {},
71 | "outputs": [],
72 | "source": [
73 | "with gym.make(\"InvertedPendulum-v4\", render_mode=\"human\") as env:\n",
74 | " action = 0.0 * env.action_space.sample()\n",
75 | " observation, _ = env.reset()\n",
76 | " episode_return = 0.0\n",
77 | " for step in range(200):\n",
78 | " # action[0] = 5.0 * observation[1] + 0.3 * observation[0]\n",
79 | " observation, reward, terminated, truncated, _ = env.step(action)\n",
80 | " episode_return += reward\n",
81 | " if terminated or truncated:\n",
82 | " observation, _ = env.reset()\n",
83 | " print(f\"Return of the episode: {episode_return}\")"
84 | ]
85 | },
86 | {
87 | "cell_type": "markdown",
88 | "id": "d7322422-94db-4e12-b299-36bb40649cf7",
89 | "metadata": {},
90 | "source": [
91 | "The structure of the action and observation vectors are documented in [Inverted Pendulum - Gymnasium Documentation](https://gymnasium.farama.org/environments/mujoco/inverted_pendulum/). The observation, in particular, is a NumPy array with four coordinates that we recall here for reference:\n",
92 | "\n",
93 | "| Num | Observation | Min | Max | Unit |\n",
94 | "|-----|-------------|-----|-----|------|\n",
95 | "| 0 | position of the cart along the linear surface | -Inf | Inf | position (m) |\n",
96 | "| 1 | vertical angle of the pole on the cart | -Inf | Inf | angle (rad) |\n",
97 | "| 2 | linear velocity of the cart | -Inf | Inf | linear velocity (m/s) |\n",
98 | "| 3 | angular velocity of the pole on the cart | -Inf | Inf | angular velocity (rad/s) |\n",
99 | "\n",
100 | "We will use the following labels to annotate plots:"
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": null,
106 | "id": "a3231c70-f49d-49be-b260-aadbade7b403",
107 | "metadata": {},
108 | "outputs": [],
109 | "source": [
110 | "OBSERVATION_LEGEND = (\"position\", \"pitch\", \"linear_velocity\", \"angular_velocity\")"
111 | ]
112 | },
113 | {
114 | "cell_type": "markdown",
115 | "id": "aa062536-204c-4312-a858-f992f3db61d6",
116 | "metadata": {},
117 | "source": [
118 | "Check out the documentation for the definitions of the action and rewards."
119 | ]
120 | },
121 | {
122 | "cell_type": "markdown",
123 | "id": "c285d7ce-3a97-4b07-8b5f-a9b04d7721ab",
124 | "metadata": {},
125 | "source": [
126 | "# PID control\n",
127 | "\n",
128 | "A *massively* used class of policies is the [PID controller](https://en.wikipedia.org/wiki/Proportional%E2%80%93integral%E2%80%93derivative_controller). Let's say we have a reference observation, like $o^* = [0\\ 0\\ 0\\ 0]$ for the inverted pendulum. Denoting by $e(t) = o^* - o(t)$ the *error* of the system when it observes a given state, a continuous-time PID controller will apply the action:\n",
129 | "\n",
130 | "$$\n",
131 | "a(t) = K_p^T e(t) + K_d^T \\dot{e}(t) + K_i^T \\int e(\\tau) \\mathrm{d} \\tau\n",
132 | "$$\n",
133 | "\n",
134 | "where $K_{p}, K_i, K_d \\in \\mathbb{R}^4$ are constants called *gains* and tuned by the user. In discrete time the idea is the same:\n",
135 | "\n",
136 | "$$\n",
137 | "a_k = K_p^T e_k + K_d^T \\frac{e_k - e_{k-1}}{\\delta t} + K_i^T \\sum_{i=0}^{k} e_i {\\delta t}\n",
138 | "$$"
139 | ]
140 | },
141 | {
142 | "cell_type": "markdown",
143 | "id": "63c381eb-fca9-4ef4-8f99-3b1943231654",
144 | "metadata": {},
145 | "source": [
146 | "Let's refactor the rolling out of our episode into a standalone function:"
147 | ]
148 | },
149 | {
150 | "cell_type": "code",
151 | "execution_count": null,
152 | "id": "9c839bc6-168a-42c3-8f1c-c6b0c5411901",
153 | "metadata": {},
154 | "outputs": [],
155 | "source": [
156 | "def rollout_from_env(env, policy):\n",
157 | " episode = []\n",
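158 | "    # Flat layout: [obs_0, act_1, rew_1, obs_1, act_2, rew_2, obs_2, ...],\n",
159 | "    # so observations are episode[::3] and rewards are episode[2::3].\n",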
158 | " observation, _ = env.reset()\n",
159 | " episode.append(observation)\n",
160 | " for step in range(1000):\n",
161 | " action = policy(observation)\n",
162 | " observation, reward, terminated, truncated, _ = env.step(action)\n",
163 | " episode.extend([action, reward, observation])\n",
164 | " if terminated or truncated:\n",
165 | " return episode\n",
166 | " return episode\n",
167 | "\n",
168 | "def rollout(policy, show: bool = True):\n",
169 | " kwargs = {\"render_mode\": \"human\"} if show else {}\n",
170 | " with gym.make(\"InvertedPendulum-v4\", **kwargs) as env:\n",
171 | " episode = rollout_from_env(env, policy)\n",
172 | " return episode"
173 | ]
174 | },
175 | {
176 | "cell_type": "markdown",
177 | "id": "79ff0dce-a4df-4917-bb17-2393353610a3",
178 | "metadata": {},
179 | "source": [
180 | "## Question 1: Write a PID controller that balances the inverted pendulum"
181 | ]
182 | },
183 | {
184 | "cell_type": "markdown",
185 | "id": "e7cfb28b-ff73-42ff-9524-eac8ec12f8a1",
186 | "metadata": {},
187 | "source": [
188 | "You can use global variables to store the (discrete) derivative and integral terms; this is OK here, as we only roll out a single trajectory:"
189 | ]
190 | },
191 | {
192 | "cell_type": "code",
193 | "execution_count": null,
194 | "id": "045ddcef-c0f7-4251-b73f-d5df5a0027e5",
195 | "metadata": {},
196 | "outputs": [],
197 | "source": [
198 | "def pid_policy(observation: np.ndarray) -> np.ndarray:\n",
199 | " my_action_value: float = 0.0 # your action here\n",
200 | " return np.array([my_action_value])\n",
201 | "\n",
202 | "episode = rollout(pid_policy, show=False)"
203 | ]
204 | },
205 | {
206 | "cell_type": "markdown",
207 | "id": "a0a005aa-87fa-4f98-8ace-f24421886bed",
208 | "metadata": {},
209 | "source": [
210 | "You can look at the system using `show=True`, but intuition usually builds faster when looking at relevant plots:"
211 | ]
212 | },
213 | {
214 | "cell_type": "code",
215 | "execution_count": null,
216 | "id": "9aa5decd-779c-4f0d-84fd-3eb47358b7fa",
217 | "metadata": {},
218 | "outputs": [],
219 | "source": [
220 | "observations = np.array(episode[::3])\n",
221 | "\n",
222 | "plt.plot(observations)\n",
223 | "plt.legend(OBSERVATION_LEGEND)"
224 | ]
225 | },
226 | {
227 | "cell_type": "markdown",
228 | "id": "98d50cd2-26fa-4d3c-a671-1ed0e1b9ee93",
229 | "metadata": {},
230 | "source": [
231 | "Can you reach the full reward of 1000 steps?"
232 | ]
233 | },
234 | {
235 | "cell_type": "code",
236 | "execution_count": null,
237 | "id": "8bacbd0a-2ac5-44cf-848b-8ebfb6fe35d7",
238 | "metadata": {},
239 | "outputs": [],
240 | "source": [
241 | "print(f\"Return of the episode: {sum(episode[2::3])}\")"
242 | ]
243 | },
244 | {
245 | "cell_type": "markdown",
246 | "id": "b17cc998-1b23-416f-8e3b-810100c223fb",
247 | "metadata": {},
248 | "source": [
249 | "# Policy optimization\n",
250 | "\n",
251 | "Let us now train a policy, parameterized by a multilayer perceptron (MLP), to maximize the expected return over episodes on the inverted pendulum environment."
252 | ]
253 | },
254 | {
255 | "cell_type": "markdown",
256 | "id": "d5631f0f-1b84-4ee6-8e9c-b4f2915bd281",
257 | "metadata": {},
258 | "source": [
259 | "## Our very first policy\n",
260 | "\n",
261 | "We will use the proximal policy optimization (PPO) algorithm for training, using the implementation from Stable Baselines3: [PPO - Stable Baselines3 documentation](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html)."
262 | ]
263 | },
264 | {
265 | "cell_type": "code",
266 | "execution_count": null,
267 | "id": "128867ca-e600-4ba1-abbd-1f918976fba2",
268 | "metadata": {},
269 | "outputs": [],
270 | "source": [
271 | "from stable_baselines3 import PPO\n",
272 | "\n",
273 | "with gym.make(\"InvertedPendulum-v4\", render_mode=\"human\") as env:\n",
274 | " first_policy = PPO(\"MlpPolicy\", env, verbose=0)\n",
275 | " first_policy.learn(total_timesteps=1000, progress_bar=False)"
276 | ]
277 | },
278 | {
279 | "cell_type": "markdown",
280 | "id": "6323400b-18ca-43f6-a81e-a5e7f033a536",
281 | "metadata": {},
282 | "source": [
283 | "By instantiating the algorithm with no further ado, we let the library decide for us on a sane set of default hyperparameters, including:\n",
284 | "\n",
285 | "- Rollout buffers of `n_steps = 2048` steps, which we will visit `n_epochs = 10` times with mini-batches of size `batch_size = 64`.\n",
286 | "- Clipping range: ``0.2``.\n",
287 | "- No entropy regularization.\n",
288 | "- Learning rate for the Adam optimizer: ``3e-4``\n",
289 | "- Policy and value-function network architectures: two layers of 64 neurons with $\\tanh$ activation functions.\n",
290 | "\n",
291 | "We then called the `learn` function to execute PPO over a fixed total number of timesteps, here just a thousand."
292 | ]
293 | },
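  {
   "cell_type": "markdown",
   "id": "explicit-defaults-leadin",
   "metadata": {},
   "source": [
    "For reference, here is a sketch that spells those defaults out explicitly (the keyword arguments are from the Stable Baselines3 PPO API; the `net_arch` dictionary format assumes a recent SB3 version):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "explicit-defaults-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "with gym.make(\"InvertedPendulum-v4\") as tmp_env:\n",
    "    explicit_policy = PPO(\n",
    "        \"MlpPolicy\",\n",
    "        tmp_env,\n",
    "        n_steps=2048,\n",
    "        batch_size=64,\n",
    "        n_epochs=10,\n",
    "        clip_range=0.2,\n",
    "        ent_coef=0.0,  # no entropy regularization\n",
    "        learning_rate=3e-4,\n",
    "        policy_kwargs=dict(\n",
    "            net_arch=dict(pi=[64, 64], vf=[64, 64]),\n",
    "            activation_fn=torch.nn.Tanh,\n",
    "        ),\n",
    "        verbose=0,\n",
    "    )"
   ]
  },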
294 | {
295 | "cell_type": "markdown",
296 | "id": "8b82173c-6609-4b83-8618-36f82c1c1373",
297 | "metadata": {},
298 | "source": [
299 | "Rendering actually took a significant chunk of time. Let's instantiate and keep an environment open without rendering:"
300 | ]
301 | },
302 | {
303 | "cell_type": "code",
304 | "execution_count": null,
305 | "id": "460fe1c7-ee3b-450a-b09c-03b96f9086bf",
306 | "metadata": {},
307 | "outputs": [],
308 | "source": [
309 | "env = gym.make(\"InvertedPendulum-v4\")"
310 | ]
311 | },
312 | {
313 | "cell_type": "markdown",
314 | "id": "a9bd090f-ca34-41e0-9900-52977eef9c4b",
315 | "metadata": {},
316 | "source": [
317 | "We can use it to train much more steps in roughly the same time, reporting training metrics every `n_steps` step:"
318 | ]
319 | },
320 | {
321 | "cell_type": "code",
322 | "execution_count": null,
323 | "id": "b7262602-c277-4697-8987-ba126a87e75b",
324 | "metadata": {},
325 | "outputs": [],
326 | "source": [
327 | "second_policy = PPO(\"MlpPolicy\", env, verbose=1)\n",
328 | "second_policy.learn(total_timesteps=10_000, progress_bar=False)"
329 | ]
330 | },
331 | {
332 | "cell_type": "markdown",
333 | "id": "6219aab8-1143-4606-a44f-b62fdffebbf1",
334 | "metadata": {},
335 | "source": [
336 | "Let's see how this policy performs:"
337 | ]
338 | },
339 | {
340 | "cell_type": "code",
341 | "execution_count": null,
342 | "id": "f17dc178-bb9c-4155-8047-feed1e575226",
343 | "metadata": {},
344 | "outputs": [],
345 | "source": [
346 | "def policy_closure(policy):\n",
347 | " \"\"\"Utility function to turn our policy instance into a function.\n",
348 | "\n",
349 | " Args:\n",
350 | " policy: Policy to turn into a function.\n",
351 | " \n",
352 | " Returns:\n",
353 | " Function from observation to policy action.\n",
354 | " \"\"\"\n",
355 | " def policy_function(observation):\n",
356 | " action, _ = policy.predict(observation)\n",
357 | " return action\n",
358 | " return policy_function"
359 | ]
360 | },
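  {
   "cell_type": "markdown",
   "id": "deterministic-note",
   "metadata": {},
   "source": [
    "Note that `predict` samples from the policy distribution by default. For evaluation, you may prefer the deterministic (mean) action, as in this variant of the closure:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "deterministic-closure",
   "metadata": {},
   "outputs": [],
   "source": [
    "def deterministic_policy_closure(policy):\n",
    "    def policy_function(observation):\n",
    "        action, _ = policy.predict(observation, deterministic=True)\n",
    "        return action\n",
    "    return policy_function"
   ]
  },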
361 | {
362 | "cell_type": "code",
363 | "execution_count": null,
364 | "id": "7e4e3cd4-4572-4c40-94b1-688d472a4b8c",
365 | "metadata": {},
366 | "outputs": [],
367 | "source": [
368 | "episode = rollout(policy_closure(second_policy), show=True)"
369 | ]
370 | },
371 | {
372 | "cell_type": "markdown",
373 | "id": "941219f1-9b3c-4e66-86e5-d42f2473b149",
374 | "metadata": {},
375 | "source": [
376 | "Okay, it looks like we didn't train for long enough!"
377 | ]
378 | },
379 | {
380 | "cell_type": "markdown",
381 | "id": "195f2a85-8dd7-4f3a-8368-7427f1caadca",
382 | "metadata": {},
383 | "source": [
384 | "## Monitoring performance during training\n",
385 | "\n",
386 | "Let's train for longer, and use TensorBoard to keep track. We don't know how long training will take so let's put a rather large total number of steps (you can interrupt training once you observed convergence in TensorBoard):"
387 | ]
388 | },
389 | {
390 | "cell_type": "code",
391 | "execution_count": null,
392 | "id": "196da0ad-1e83-441e-ac10-6c9ecd83c224",
393 | "metadata": {},
394 | "outputs": [],
395 | "source": [
396 | "erudite_policy = PPO(\n",
397 | " \"MlpPolicy\",\n",
398 | " env,\n",
399 | " tensorboard_log=\"./inverted_pendulum_tensorboard/\",\n",
400 | " verbose=0,\n",
401 | ")\n",
402 | "\n",
403 | "erudite_policy.learn(\n",
404 | " total_timesteps=1_000_000,\n",
405 | " progress_bar=False,\n",
406 | " tb_log_name=\"erudite\",\n",
407 | ")"
408 | ]
409 | },
410 | {
411 | "cell_type": "markdown",
412 | "id": "ad91d14e-53e2-443f-b7ab-edd69d480add",
413 | "metadata": {},
414 | "source": [
415 | "Run TensorBoard on the directory thus created to open a dashboard in your Web browser:\n",
416 | "\n",
417 | "```\n",
418 | "tensorboard --logdir ./inverted_pendulum_tensorboard/\n",
419 | "```\n",
420 | "\n",
421 | "The link will typically be http://localhost:6006 (port number increases if you run TensorBoard multiple times in parallel). Tips:\n",
422 | "\n",
423 | "- Click the Settings icon in the top-right corner and enable \"Reload data\"\n",
424 | "- Uncheck \"Ignore outliers in chart scaling\" (your preference)"
425 | ]
426 | },
427 | {
428 | "cell_type": "markdown",
429 | "id": "68771e35-48cd-43be-89ff-0055dc196d0b",
430 | "metadata": {},
431 | "source": [
432 | "## Saving our policy\n",
433 | "\n",
434 | "Now that you spent some computing to optimize an actual policy, better save it to disk:"
435 | ]
436 | },
437 | {
438 | "cell_type": "code",
439 | "execution_count": null,
440 | "id": "effeccfc-8b95-48e1-98c4-1b96838bb28e",
441 | "metadata": {},
442 | "outputs": [],
443 | "source": [
444 | "erudite_policy.save(\"pendulum_erudite\")"
445 | ]
446 | },
447 | {
448 | "cell_type": "markdown",
449 | "id": "ea09c56e-8647-414d-aae0-5e1b16ba3a0f",
450 | "metadata": {},
451 | "source": [
452 | "You can then reload it later by:"
453 | ]
454 | },
455 | {
456 | "cell_type": "code",
457 | "execution_count": null,
458 | "id": "994dae4f-b651-4488-925e-2ba369eeedc7",
459 | "metadata": {},
460 | "outputs": [],
461 | "source": [
462 | "erudite_policy = PPO.load(\"pendulum_erudite\", env=env)"
463 | ]
464 | },
465 | {
466 | "cell_type": "markdown",
467 | "id": "dded4f1e-ae57-4c94-b2e1-3a6d019aecc4",
468 | "metadata": {},
469 | "source": [
470 | "## Question 2: How many steps does it take to train a successful policy?\n",
471 | "\n",
472 | "We consider a policy successful if it consistently achieves the maximum return of 1000."
473 | ]
474 | },
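  {
   "cell_type": "markdown",
   "id": "eval-sketch-leadin",
   "metadata": {},
   "source": [
    "One way to check consistency is to average returns over several evaluation episodes, as in this sketch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eval-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "returns = [\n",
    "    sum(rollout(policy_closure(erudite_policy), show=False)[2::3])\n",
    "    for _ in range(10)\n",
    "]\n",
    "print(f\"mean return: {np.mean(returns)}, min return: {np.min(returns)}\")"
   ]
  },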
475 | {
476 | "cell_type": "raw",
477 | "id": "42c6d68d-4812-4222-97da-a6699803b986",
478 | "metadata": {},
479 | "source": [
480 | "== Your reply here =="
481 | ]
482 | },
483 | {
484 | "cell_type": "markdown",
485 | "id": "553b846f-db13-43ba-81cb-57b039852c86",
486 | "metadata": {},
487 | "source": [
488 | "## A more realistic environment\n",
489 | "\n",
490 | "Real systems suffer from the two main issues we saw in the [Perception and estimation](https://scaron.info/robotics-mva/#5-perception-estimation) class: *bias* and *variance*. In this section, we model bias in actuation and perception by adding delays (via low-pass filtering) to respectively the action and observation vectors. Empirically this is an effective model, as for instance it contributes to sim2real transfer on Upkie. To add these delays, we use an [`environment wrapper`](https://gymnasium.farama.org/api/wrappers/), which is a convenient way to compose environments, used in both the Gymnasium and Stable Baselines3 APIs:"
491 | ]
492 | },
493 | {
494 | "cell_type": "code",
495 | "execution_count": null,
496 | "id": "6e8a3140-7ee7-4d6f-afd9-19d6ca4816c0",
497 | "metadata": {},
498 | "outputs": [],
499 | "source": [
500 | "class DelayWrapper(gym.Wrapper):\n",
501 | " def __init__(self, env, time_constant: float = 0.2):\n",
502 | " \"\"\"Wrap environment with some actuation and perception modeling.\n",
503 | "\n",
504 | " Args:\n",
505 | " env: Environment to wrap.\n",
506 | " time_constant: Constant of the internal low-pass filter, in seconds.\n",
507 | " Feel free to play with different values but leave it to the default\n",
508 | of 0.2 seconds when handing in your homework.\n",
509 | "\n",
510 | " Note:\n",
511 | " Delays are implemented by a low-pass filter. The same time constant\n",
512 | " is used for both actions and observations, which is not realistic, but\n",
513 | " makes for less tutorial code ;)\n",
514 | " \"\"\"\n",
515 | " alpha = env.dt / time_constant\n",
516 | " assert 0.0 < alpha < 1.0\n",
517 | " super().__init__(env)\n",
518 | " self._alpha = alpha\n",
519 | " self._prev_action = 0.0 * env.action_space.sample()\n",
520 | " self._prev_observation = np.zeros(4)\n",
521 | "\n",
522 | " def low_pass_filter(self, old_value, new_value):\n",
523 | " return old_value + self._alpha * (new_value - old_value)\n",
524 | " \n",
525 | " def step(self, action):\n",
526 | " new_action = self.low_pass_filter(self._prev_action, action)\n",
527 | " observation, reward, terminated, truncated, info = self.env.step(new_action)\n",
528 | " new_observation = self.low_pass_filter(self._prev_observation, observation)\n",
529 | " self._prev_action = new_action\n",
530 | " self._prev_observation = new_observation\n",
531 | " return new_observation, reward, terminated, truncated, info\n",
532 | "\n",
533 | "delay_env = DelayWrapper(env)"
534 | ]
535 | },
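  {
   "cell_type": "markdown",
   "id": "filter-step-leadin",
   "metadata": {},
   "source": [
    "To visualize the delay introduced by this filter, here is a small sketch (pure NumPy, with an illustrative timestep of 0.04 s, not necessarily the environment's) plotting its response to a unit step; the filtered signal lags the input by roughly the time constant:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "filter-step-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "dt = 0.04  # illustrative timestep\n",
    "alpha = dt / 0.2  # same formula as in DelayWrapper\n",
    "step_input = np.ones(50)\n",
    "filtered = np.zeros(50)\n",
    "value = 0.0\n",
    "for k, x in enumerate(step_input):\n",
    "    value += alpha * (x - value)  # low_pass_filter update\n",
    "    filtered[k] = value\n",
    "t = dt * np.arange(50)\n",
    "plt.plot(t, step_input, label=\"input step\")\n",
    "plt.plot(t, filtered, label=\"filtered\")\n",
    "plt.xlabel(\"time (s)\")\n",
    "plt.legend()"
   ]
  },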
536 | {
537 | "cell_type": "markdown",
538 | "id": "b1b5de5e-50ca-4049-bb5f-b9203919e0ba",
539 | "metadata": {},
540 | "source": [
541 | "We can check how our current policy fares against the delayed environment. Spoiler alert: no great."
542 | ]
543 | },
544 | {
545 | "cell_type": "code",
546 | "execution_count": null,
547 | "id": "4e1508e6-e04f-4b22-8009-80baae1bae7d",
548 | "metadata": {},
549 | "outputs": [],
550 | "source": [
551 | "delay_episode = rollout_from_env(delay_env, policy_closure(erudite_policy))\n",
552 | "observations = np.array(delay_episode[::3])\n",
553 | "\n",
554 | "plt.plot(observations[:, :2])\n",
555 | "plt.legend(OBSERVATION_LEGEND)"
556 | ]
557 | },
558 | {
559 | "cell_type": "markdown",
560 | "id": "70af3932-751e-47e4-8334-bd55be62aaa1",
561 | "metadata": {},
562 | "source": [
563 | "## Question 3: Can't we just re-train a policy on the new environment?\n",
564 | "\n",
565 | "At this point of the tutorial this is a rethorical question, but we should check anyway. Re-train a policy on the delayed environment and comment on its performance:"
566 | ]
567 | },
568 | {
569 | "cell_type": "code",
570 | "execution_count": null,
571 | "id": "693aa97c-3ee2-4cbd-bc06-7cb224e8bc86",
572 | "metadata": {},
573 | "outputs": [],
574 | "source": [
575 | "# Your code here"
576 | ]
577 | },
578 | {
579 | "cell_type": "raw",
580 | "id": "48928906-bcd9-40d5-b17e-35fd06d6c6ac",
581 | "metadata": {},
582 | "source": [
583 | "== Your observations here =="
584 | ]
585 | },
586 | {
587 | "cell_type": "markdown",
588 | "id": "c0e2df30-259f-477a-ab14-d39c17e5f15f",
589 | "metadata": {},
590 | "source": [
591 | "## The Real Question 3: Why do delays degrade both runtime and training performance?\n",
592 | "\n",
593 | "Loss in runtime performance refers to the one we observed when executing a policy trained without delay on an environment with delays. Loss in training performance refers to the fact that, even when we train a new policy on the environment with delays, by the end of training it does not achieve maximum return."
594 | ]
595 | },
596 | {
597 | "cell_type": "raw",
598 | "id": "3b7459d5-93d0-49cb-85c7-2172e2b08073",
599 | "metadata": {},
600 | "source": [
601 | "== Your explanation here =="
602 | ]
603 | },
604 | {
605 | "cell_type": "markdown",
606 | "id": "e63a441a-a84d-49ab-aecc-7362dee66b91",
607 | "metadata": {},
608 | "source": [
609 | "Propose and implement a way to overcome this. Train the resulting policy in a variable called `iron_policy`."
610 | ]
611 | },
612 | {
613 | "cell_type": "code",
614 | "execution_count": null,
615 | "id": "b22770ba-4e58-4989-b62c-d5aa1734336c",
616 | "metadata": {},
617 | "outputs": [],
618 | "source": [
619 | "# Your code here"
620 | ]
621 | },
622 | {
623 | "cell_type": "code",
624 | "execution_count": null,
625 | "id": "5a7a876f-e78e-47bb-9b42-9423618d1e42",
626 | "metadata": {},
627 | "outputs": [],
628 | "source": [
629 | "iron_policy.save(\"iron_policy\")"
630 | ]
631 | },
632 | {
633 | "cell_type": "markdown",
634 | "id": "d2a70b63-7fda-4c0f-b777-ef0dc2128ab2",
635 | "metadata": {},
636 | "source": [
637 | "Roll out an episode and plot the outcome to show that your policy handles delays properly."
638 | ]
639 | },
640 | {
641 | "cell_type": "code",
642 | "execution_count": null,
643 | "id": "45f13a2b-6a1c-4d44-bb84-2fffaf6bf6e3",
644 | "metadata": {},
645 | "outputs": [],
646 | "source": [
647 | "# Your episode rollout here\n",
648 | "\n",
649 | "plt.plot(np.array(observations)[:, :2])\n",
650 | "plt.legend(OBSERVATION_LEGEND)"
651 | ]
652 | },
653 | {
654 | "cell_type": "markdown",
655 | "id": "1e12fcf1-88b9-4899-b79d-866c67e4a3f5",
656 | "metadata": {},
657 | "source": [
658 | "## Question 4: Can you improve sampling efficiency?\n",
659 | "\n",
660 | "This last question is open: what can you change in the pipeline to train a policy that achieves maximum return using less samples? Report on at least one thing that allowed you to train with less environment steps."
661 | ]
662 | },
663 | {
664 | "cell_type": "raw",
665 | "id": "0f5cb9a5-fd18-4077-a6fc-83fa5377de96",
666 | "metadata": {},
667 | "source": [
668 | "== Your report here =="
669 | ]
670 | },
671 | {
672 | "cell_type": "markdown",
673 | "id": "131966f5-9524-4b44-9843-0c1a662ba2e1",
674 | "metadata": {},
675 | "source": [
676 | "Here is a state-of-the-art™ utility function if you want to experiment with scheduling some of the ``Callable[[float], float]`` [hyperparameters](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html#parameters):"
677 | ]
678 | },
679 | {
680 | "cell_type": "code",
681 | "execution_count": null,
682 | "id": "3de11ab9-2534-4723-8868-1582772d038c",
683 | "metadata": {},
684 | "outputs": [],
685 | "source": [
686 | "def affine_schedule(y_0: float, y_1: float):\n",
687 | " \"\"\"Affine schedule as a function over the [0, 1] interval.\n",
688 | "\n",
689 | " Args:\n",
690 | " y_0: Function value at zero.\n",
691 | " y_1: Function value at one.\n",
692 | " \n",
693 | " Returns:\n",
694 | " Corresponding affine function.\n",
695 | " \"\"\"\n",
696 | " def schedule(x: float) -> float:\n",
697 | " return y_0 + x * (y_1 - y_0)\n",
698 | " return schedule"
699 | ]
700 | },
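  {
   "cell_type": "markdown",
   "id": "schedule-usage-leadin",
   "metadata": {},
   "source": [
    "For instance, a decaying learning rate can be sketched as below. Note that Stable Baselines3 calls a schedule with the *remaining* progress, which goes from 1 down to 0 over training, so the second argument is the initial value:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "schedule-usage-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "scheduled_policy = PPO(\n",
    "    \"MlpPolicy\",\n",
    "    env,\n",
    "    learning_rate=affine_schedule(1e-5, 3e-4),  # decays from 3e-4 down to 1e-5\n",
    "    verbose=0,\n",
    ")"
   ]
  },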
701 | {
702 | "cell_type": "markdown",
703 | "id": "b21d78dd-f80e-4183-8fa7-55c803e38404",
704 | "metadata": {},
705 | "source": [
706 | "And here is a wrapper template if you want to experiment with reward shaping:"
707 | ]
708 | },
709 | {
710 | "cell_type": "code",
711 | "execution_count": null,
712 | "id": "2cf9a3ed-8f76-4fac-98f2-3a23df818deb",
713 | "metadata": {},
714 | "outputs": [],
715 | "source": [
716 | "class CustomRewardWrapper(gym.Wrapper):\n",
717 | " def __init__(self, env):\n",
718 | " super().__init__(env)\n",
719 | "\n",
720 | " def step(self, action):\n",
721 | " observation, reward, terminated, truncated, info = self.env.step(action)\n",
722 | " new_reward = 0.0 # your formula here\n",
723 | " return observation, new_reward, terminated, truncated, info"
724 | ]
725 | },
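  {
   "cell_type": "markdown",
   "id": "wrapper-usage-leadin",
   "metadata": {},
   "source": [
    "Being wrappers, these compose like any other environment. For instance (a sketch; the reward formula above is still yours to fill in):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "wrapper-usage-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "shaped_delay_env = CustomRewardWrapper(DelayWrapper(gym.make(\"InvertedPendulum-v4\")))"
   ]
  },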
726 | {
727 | "cell_type": "markdown",
728 | "id": "2c4dd0df-6dc7-4d51-b29b-c77b49bde437",
729 | "metadata": {},
730 | "source": [
731 | "# Bonus: training a policy for a real robot\n",
732 | "\n",
733 | "This section is entirely optional and will only work on Linux or macOS. In this part, we follow the same training pipeline but with the open source software of [Upkie](https://hackaday.io/project/185729-upkie-wheeled-biped-robots)."
734 | ]
735 | },
736 | {
737 | "cell_type": "markdown",
738 | "id": "9634ff93-f09f-4e0a-8d0f-547848f3900b",
739 | "metadata": {},
740 | "source": [
741 | "## Setup\n",
742 | "\n",
743 | "
\n",
744 | "\n",
745 | "First, make sure you have a C++ compiler (setup one-liners: [Fedora](https://github.com/upkie/upkie/discussions/100), [Ubuntu](https://github.com/upkie/upkie/discussions/101)). You can run an Upkie simulation right from the command line. It won't install anything on your machine, everything will run locally from the repository:\n",
746 | "\n",
747 | "```console\n",
748 | "git clone https://github.com/upkie/upkie.git\n",
749 | "cd upkie\n",
750 | "git checkout fb9a0ab1f67a8014c08b34d7c0d317c7a8f71662\n",
751 | "./start_simulation.sh\n",
752 | "```\n",
753 | "\n",
754 | "**NB:** this tutorial is written for the specific commit checked out above. If some instructions don't work it's likely you forgot to check it out.\n",
755 | "\n",
756 | "We will use the Python API of the robot to test things from this notebook, or from custom scripts. Install it from PyPI in your Conda environment:\n",
757 | "\n",
758 | "```\n",
759 | "pip install upkie\n",
760 | "```"
761 | ]
762 | },
763 | {
764 | "cell_type": "markdown",
765 | "id": "ba44abc0-f7e9-4c2b-9d4e-a3579213e138",
766 | "metadata": {},
767 | "source": [
768 | "## Stepping the environment\n",
769 | "\n",
770 | "If everything worked well, you should be able to step an environment as follows:"
771 | ]
772 | },
773 | {
774 | "cell_type": "code",
775 | "execution_count": null,
776 | "id": "acedf0d6-fc2f-43f4-9ff6-a8e12dbd7ae0",
777 | "metadata": {},
778 | "outputs": [],
779 | "source": [
780 | "import gymnasium as gym\n",
781 | "import upkie.envs\n",
782 | "\n",
783 | "upkie.envs.register()\n",
784 | "\n",
785 | "episode_return = 0.0\n",
786 | "with gym.make(\"UpkieGroundVelocity-v1\", frequency=200.0) as env:\n",
787 | " observation, _ = env.reset() # connects to the spine (simulator or real robot)\n",
788 | " action = 0.0 * env.action_space.sample()\n",
789 | " for step in range(1000):\n",
790 | " pitch = observation[0]\n",
791 | " action[0] = 10.0 * pitch # 1D action: [ground_velocity]\n",
792 | " observation, reward, terminated, truncated, _ = env.step(action)\n",
793 | " episode_return += reward\n",
794 | " if terminated or truncated:\n",
795 | " observation, _ = env.reset()\n",
796 | "\n",
797 | "print(f\"We have stepped the environment {step + 1} times\")\n",
798 | "print(f\"The return of our episode is {episode_return}\")"
799 | ]
800 | },
801 | {
802 | "cell_type": "markdown",
803 | "id": "031343b5-cf94-46ae-98f3-a4c5ebbc037c",
804 | "metadata": {},
805 | "source": [
806 | "(If you see a message \"Waiting for spine /vulp to start\", it means the simulation is not running.)"
807 | ]
808 | },
809 | {
810 | "cell_type": "markdown",
811 | "id": "aecfd91f-676c-4d6f-beb0-a286dc681ae3",
812 | "metadata": {},
813 | "source": [
814 | "We can double-check the last observation from the episode:"
815 | ]
816 | },
817 | {
818 | "cell_type": "code",
819 | "execution_count": null,
820 | "id": "ed6d972f-4cc9-4005-b9a1-4a7433a19938",
821 | "metadata": {},
822 | "outputs": [],
823 | "source": [
824 | "def report_last_observation(observation):\n",
825 | " print(\"The last observation of the episode is:\")\n",
826 | " print(f\"- Pitch from torso to world: {observation[0]:.2} rad\")\n",
827 | " print(f\"- Ground position: {observation[1]:.2} m\")\n",
828 | " print(f\"- Angular velocity from torso to world in torso: {observation[2]:.2} rad/s\")\n",
829 | " print(f\"- Ground velocity: {observation[3]:.2} m/s\")\n",
830 | " \n",
831 | "report_last_observation(observation)"
832 | ]
833 | },
834 | {
835 | "cell_type": "markdown",
836 | "id": "d5a269e3-d876-4d05-88e5-b0d73be6f939",
837 | "metadata": {},
838 | "source": [
839 | "## Question B1: PID control\n",
840 | "\n",
841 | "Adapt your code from Question 1 to this environment:"
842 | ]
843 | },
844 | {
845 | "cell_type": "code",
846 | "execution_count": null,
847 | "id": "1e8e9f1e-f2a1-4a18-aa38-256a425d018c",
848 | "metadata": {},
849 | "outputs": [],
850 | "source": [
851 | "def policy_b1(observation):\n",
852 | " return np.array([0.0]) # replace with your solution\n",
853 | "\n",
854 | "\n",
855 | "def run(policy, nb_steps: int):\n",
856 | " episode_return = 0.0\n",
857 | " with gym.make(\"UpkieGroundVelocity-v1\", frequency=200.0) as env:\n",
858 | " observation, _ = env.reset() # connects to the spine (simulator or real robot)\n",
859 | " for step in range(nb_steps):\n",
860 | " action = policy_b1(observation)\n",
861 | " observation, reward, terminated, truncated, _ = env.step(action)\n",
862 | " if terminated or truncated:\n",
863 | " print(\"Fall detected!\")\n",
864 | " return episode_return\n",
865 | " report_last_observation(observation)\n",
866 | " return episode_return\n",
867 | "\n",
868 | "\n",
869 | "episode_return = run(policy_b1, 1000)\n",
870 | "print(f\"The return of our episode is {episode_return}\")"
871 | ]
872 | },
873 | {
874 | "cell_type": "markdown",
875 | "id": "e999eb22-a94a-4a58-ac06-9a7dbc15a7ee",
876 | "metadata": {},
877 | "source": [
878 | "## Training a new policy\n",
879 | "\n",
880 | "The Upkie repository ships three agents based on PID control, model predictive control and reinforcement learning. We now focus on the latter, called the \"PPO balancer\".\n",
881 | "\n",
882 | "Check that you can run the training part by running, from the root of the repository:\n",
883 | "\n",
884 | "```\n",
885 | "./tools/bazel run //agents/ppo_balancer:train -- --nb-envs 1 --show\n",
886 | "```\n",
887 | "\n",
888 | "A simulation window should pop, and verbose output from SB3 should be printed to your terminal.\n",
889 | "\n",
890 | "By default, training data will be logged to `/tmp`. You can select a different output path by setting the `UPKIE_TRAINING_PATH` environment variable in your shell. For instance:\n",
891 | "\n",
892 | "```\n",
893 | "export UPKIE_TRAINING_PATH=\"${HOME}/src/upkie/training\"\n",
894 | "```\n",
895 | "\n",
896 | "Run TensorBoard from the training directory:\n",
897 | "\n",
898 | "```\n",
899 | "tensorboard --logdir ${UPKIE_TRAINING_PATH} # or /tmp if you keep the default\n",
900 | "```\n",
901 | "\n",
902 | "Each training will be named after a word picked at random in an English dictionary."
903 | ]
904 | },
905 | {
906 | "cell_type": "markdown",
907 | "id": "a7e47aad-7787-409c-af7e-b83bfccaa592",
908 | "metadata": {},
909 | "source": [
910 | "## Selecting the number of processes\n",
911 | "\n",
912 | "We can increase the number of parallel CPU environments ``--nb-envs`` to a value suitable to your computer. Let training run for a minute and check `time/fps`. Increase the number of environments and compare the stationary regime of `time/fps`. You should see a performance increase when adding the first few environments, followed by a declined when there are two many parallel processes compared to your number of CPU cores. Pick the value that works best for you."
913 | ]
914 | },
915 | {
916 | "cell_type": "markdown",
917 | "id": "696a0943-cc10-4cd0-a2d8-d5313dbe37e5",
918 | "metadata": {},
919 | "source": [
920 | "## Running a trained policy\n",
921 | "\n",
922 | "Copy the file `final.zip` from your trained policy directory to `agents/ppo_balancer/policy/params.zip`. Start a simulation and run the policy by:\n",
923 | "\n",
924 | "```\n",
925 | "./tools/bazel run //agents/ppo_balancer\n",
926 | "```\n",
927 | "\n",
928 | "What kind of behavior do you observe?"
929 | ]
930 | },
931 | {
932 | "cell_type": "raw",
933 | "id": "eaabe73c-f412-44b5-a714-241077720d01",
934 | "metadata": {},
935 | "source": [
936 | "== Your observations here =="
937 | ]
938 | },
939 | {
940 | "cell_type": "markdown",
941 | "id": "6c356c81-b5ef-4364-a5db-c8e2600e104a",
942 | "metadata": {},
943 | "source": [
944 | "## Question B2: Improve this baseline"
945 | ]
946 | },
947 | {
948 | "cell_type": "markdown",
949 | "id": "527ecb8c-7292-432f-b0d3-b90c36de8719",
950 | "metadata": {},
951 | "source": [
952 | "The policy you are testing here is not the one we saw in class. Open question: improve on it using any of the methods we discussed. Measure the improvement by `ep_len_mean` or any other quantitative criterion:"
953 | ]
954 | },
955 | {
956 | "cell_type": "raw",
957 | "id": "ce7d720b-17b8-493d-8128-e66c6571d3ff",
958 | "metadata": {},
959 | "source": [
960 | "== Your experiments here ==\n",
961 | "\n",
962 | "- Tried: ... / Measured outcome: ..."
963 | ]
964 | }
965 | ],
966 | "metadata": {
967 | "kernelspec": {
968 | "display_name": "Python 3 (ipykernel)",
969 | "language": "python",
970 | "name": "python3"
971 | },
972 | "language_info": {
973 | "codemirror_mode": {
974 | "name": "ipython",
975 | "version": 3
976 | },
977 | "file_extension": ".py",
978 | "mimetype": "text/x-python",
979 | "name": "python",
980 | "nbconvert_exporter": "python",
981 | "pygments_lexer": "ipython3",
982 | "version": "3.10.13"
983 | }
984 | },
985 | "nbformat": 4,
986 | "nbformat_minor": 5
987 | }
988 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2023, Stéphane Caron
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | 1. Redistributions of source code must retain the above copyright notice, this
9 | list of conditions and the following disclaimer.
10 |
11 | 2. Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | 3. Neither the name of the copyright holder nor the names of its
16 | contributors may be used to endorse or promote products derived from
17 | this software without specific prior written permission.
18 |
19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Robotics MVA 2023
2 |
3 | This repository contains tutorial notebooks for the 2023 [Robotics](https://www.master-mva.com/cours/robotics/) class at MVA.
4 |
5 | ## Get started
6 |
7 | Clone this repository:
8 |
9 | ```bash
10 | git clone https://github.com/stephane-caron/robotics-mva-2023.git
11 | ```
12 |
13 | Install miniconda:
14 |
15 | - Linux: https://docs.conda.io/en/latest/miniconda.html
16 | - macOS: https://docs.conda.io/en/latest/miniconda.html
17 | - Windows: https://www.anaconda.com/download/
18 |
19 | Don't forget to add the conda snippet to your shell configuration (for instance ``~/.bashrc``). After that, you can run all labs in a dedicated Python environment that will not affect your system's regular Python environment.
20 |
21 | ### Run a notebook
22 |
23 | - Go to your local copy of the repository.
24 | - Open a terminal.
25 | - Create the conda environment:
26 |
27 | ```bash
28 | conda env create -f robotics-mva.yml
29 | ```
30 |
31 | From there on, to work on a notebook, you will only need to activate the environment:
32 |
33 | ```bash
34 | conda activate robotics-mva
35 | ```
36 |
37 | Then launch the notebook with:
38 |
39 | ```bash
40 | jupyter-lab
41 | ```
42 |
43 | The notebook will be accessible from your web browser at [localhost:8888](http://localhost:8888).
44 |
45 | Meshcat visualisation can be accessed in full page at `localhost:700N/static/` where N denotes the Nth MeshCat instance created by your notebook kernel.
46 |
47 | ## Troubleshooting
48 |
49 | - Make sure the virtual environment is activated for ``jupyter-lab`` to work.
50 | - Make sure you call ``jupyter-lab`` so that Python package paths are configured properly.
51 | - In particular, ``jupyter-notebook`` may not have paths configured properly, resulting in failed package imports.
52 |
53 | ## Updating the notebooks
54 |
55 | If the repository changes (for instance when new tutorials are pushed) you will need to update your local copy of it by "pulling" from the repository. To do so, go to the directory containing the tutorials and run:
56 |
57 | ```
58 | git pull
59 | ```
60 |
61 | If you already have local changes to a notebook `something.ipynb`, either you already know how to use git and you can commit them, or you don't and the safest way for you to update is to:
62 |
63 | - Copy your modified `something.ipynb` somewhere else
64 | - Revert it to its original version: ``git checkout -f something.ipynb``
65 | - Pull updates from the remote repository: ``git pull``
66 |
--------------------------------------------------------------------------------
/robotics-mva.yml:
--------------------------------------------------------------------------------
1 | name: robotics-mva
2 | channels:
3 | - conda-forge
4 | - defaults
5 | dependencies:
6 | - example-robot-data
7 | - gymnasium
8 | - imageio
9 | - ipywidgets
10 | - jupyterlab
11 | - matplotlib
12 | - meshcat-python
13 | - mujoco=2.3.7
14 | - pinocchio
15 | - python=3.10
16 | - quadprog
17 | - rich
18 | - scipy
19 | - stable-baselines3
20 | - tensorboard
21 | - tqdm
22 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .load_ur5_parallel import load_ur5_parallel
2 |
--------------------------------------------------------------------------------
/utils/collision_wrapper.py:
--------------------------------------------------------------------------------
1 | import pinocchio as pin
2 | import numpy as np
3 |
4 | class CollisionWrapper:
5 | def __init__(self,robot,viz=None):
6 | self.robot=robot
7 | self.viz=viz
8 |
9 | self.rmodel = robot.model
10 | self.rdata = self.rmodel.createData()
11 | self.gmodel = self.robot.gmodel
12 | self.gdata = self.gmodel.createData()
13 | self.gdata.collisionRequests.enable_contact = True
14 |
15 |
16 | def computeCollisions(self,q,vq=None):
17 | res = pin.computeCollisions(self.rmodel,self.rdata,self.gmodel,self.gdata,q,False)
18 | pin.computeDistances(self.rmodel,self.rdata,self.gmodel,self.gdata,q)
19 | pin.computeJointJacobians(self.rmodel,self.rdata,q)
20 | if vq is not None:
21 | pin.forwardKinematics(self.rmodel,self.rdata,q,vq,0*vq)
22 | return res
23 |
24 | def getCollisionList(self):
25 | '''Return a list of triplets [ index,collision,result ] where index is the
26 |         index of the collision pair, collision is gmodel.collisionPairs[index]
27 | and result is gdata.collisionResults[index].
28 | '''
29 | return [ [ir,self.gmodel.collisionPairs[ir],r]
30 | for ir,r in enumerate(self.gdata.collisionResults) if r.isCollision() ]
31 |
32 | def _getCollisionJacobian(self,col,res):
33 | '''Compute the jacobian for one collision only. '''
34 | contact = res.getContact(0)
35 | g1 = self.gmodel.geometryObjects[col.first]
36 | g2 = self.gmodel.geometryObjects[col.second]
37 | oMc = pin.SE3(pin.Quaternion.FromTwoVectors(np.array([0,0,1]),contact.normal).matrix(),contact.pos)
38 |
39 | joint1 = g1.parentJoint
40 | joint2 = g2.parentJoint
41 | oMj1 = self.rdata.oMi[joint1]
42 | oMj2 = self.rdata.oMi[joint2]
43 |
44 | cMj1 = oMc.inverse()*oMj1
45 | cMj2 = oMc.inverse()*oMj2
46 |
47 | J1=pin.getJointJacobian(self.rmodel,self.rdata,joint1,pin.ReferenceFrame.LOCAL)
48 | J2=pin.getJointJacobian(self.rmodel,self.rdata,joint2,pin.ReferenceFrame.LOCAL)
49 | Jc1=cMj1.action@J1
50 | Jc2=cMj2.action@J2
51 | J = (Jc2-Jc1)[2,:]
52 | return J
53 |
54 | def _getCollisionJdotQdot(self,col,res):
55 | '''Compute the Coriolis acceleration for one collision only. '''
56 | contact = res.getContact(0)
57 | g1 = self.gmodel.geometryObjects[col.first]
58 | g2 = self.gmodel.geometryObjects[col.second]
59 | oMc = pin.SE3(pin.Quaternion.FromTwoVectors(np.array([0,0,1]),contact.normal).matrix(),contact.pos)
60 |
61 | joint1 = g1.parentJoint
62 | joint2 = g2.parentJoint
63 | oMj1 = self.rdata.oMi[joint1]
64 | oMj2 = self.rdata.oMi[joint2]
65 |
66 | cMj1 = oMc.inverse()*oMj1
67 | cMj2 = oMc.inverse()*oMj2
68 |
69 | a1 = self.rdata.a[joint1]
70 | a2 = self.rdata.a[joint2]
71 | a = (cMj1*a1-cMj2*a2).linear[2]
72 | return a
73 |
74 | def getCollisionJacobian(self,collisions=None):
75 | '''From a collision list, return the Jacobian corresponding to the normal direction. '''
76 | if collisions is None: collisions = self.getCollisionList()
77 |         if len(collisions)==0: return np.zeros([0,self.rmodel.nv])
78 | J = np.vstack([ self._getCollisionJacobian(c,r) for (i,c,r) in collisions ])
79 | return J
80 |
81 | def getCollisionJdotQdot(self,collisions=None):
82 | if collisions is None: collisions = self.getCollisionList()
83 | if len(collisions)==0: return np.array([])
84 | a0 = np.vstack([ self._getCollisionJdotQdot(c,r) for (i,c,r) in collisions ])
85 | return a0.squeeze()
86 |
87 | def getCollisionDistances(self,collisions=None):
88 | if collisions is None: collisions = self.getCollisionList()
89 | if len(collisions)==0: return np.array([])
90 | dist = np.array([ self.gdata.distanceResults[i].min_distance for (i,c,r) in collisions ])
91 | return dist
92 |
93 |
94 | # --- DISPLAY -----------------------------------------------------------------------------------
95 | # --- DISPLAY -----------------------------------------------------------------------------------
96 | # --- DISPLAY -----------------------------------------------------------------------------------
97 |
98 | def initDisplay(self,viz=None):
99 | if viz is not None: self.viz = viz
100 | assert(self.viz is not None)
101 |
102 | self.patchName = 'world/contact_%d_%s'
103 | self.ncollisions=10
104 | self.createDisplayPatchs(0)
105 |
106 | def createDisplayPatchs(self,ncollisions):
107 |
108 | if ncollisions == self.ncollisions: return
180 | if not i % 20: viz.display(q)
181 |
182 | viz.display(q)
183 |
184 | col.displayCollisions()
185 | p = cols[0][1]
186 | ci = cols[0][2].getContact(0)
187 |
188 | import pickle
189 | with open('/tmp/bug.pickle', 'wb') as file:
190 | pickle.dump([ col.gdata.oMg[11],
191 | col.gdata.oMg[3],
192 | #col.gmodel.geometryObjects[11].geometry,
193 | ] , file)
194 |
195 | dist=col.getCollisionDistances()
196 | J = col.getCollisionJacobian()
197 |
--------------------------------------------------------------------------------
/utils/datastructures/bucketkdtree.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from utils.datastructures.tree import NodeBinaryTree
3 |
4 |
5 | class BucketKDNode(NodeBinaryTree):
6 | # Adapt to wraparound spaces ?
7 | def __init__(
8 | self,
9 | parent=None,
10 | points=None,
11 | dim=3,
12 | dim_scale=None,
13 | bucketsize=10
14 | ):
15 | # Parent node
16 | self.parent = parent
17 | # Hyperparam
18 | self.bucketsize = bucketsize
19 | self.dim = dim
20 | self.dim_scale = dim_scale if dim_scale is not None else np.ones(dim)
21 |         # Nodes are initially leaves, without children and with a bucket
22 |         self.is_leaf = True
23 |         # Bucket of points stored while this node is a leaf
24 | self._points = np.zeros((bucketsize, dim))
25 | self._n_points = 0
26 |         # Values used only by non-leaf nodes
27 | self.split_dim = None
28 | self.split_val = None
29 | self.left = None
30 | self.right = None
31 | # Bounds tracker
32 | self.lower = np.ones((dim)) * np.inf
33 | self.upper = - np.ones((dim)) * np.inf
34 | # Add given points
35 | self.add_points(points)
36 |
37 | def _update_bounds(self, points):
38 | # ensure non null points
39 | if not isinstance(points, np.ndarray):
40 | if not points:
41 | return
42 | points = np.array(points)
43 |         if points.size == 0:
44 | return
45 |
46 | mini = np.min(points, axis=0)
47 | maxi = np.max(points, axis=0)
48 | min_update = mini < self.lower
49 | self.lower[min_update] = mini[min_update]
50 | max_update = maxi > self.upper
51 | self.upper[max_update] = maxi[max_update]
52 |
53 | def add_point(self, p):
54 | self._update_bounds([p])
55 | if self.is_leaf:
56 | # The node is still a leaf, just add to bucket
57 | # We are sure there is enough space
58 | self._points[self._n_points] = p
59 | self._n_points += 1
60 | if self._n_points == self.bucketsize:
61 | # It is now full, transform node as non leaf
62 | self._create_children()
63 | else:
64 | # Kd split to add to children
65 | if p[self.split_dim] <= self.split_val:
66 | self.left.add_point(p)
67 | else:
68 | self.right.add_point(p)
69 |
70 | def add_points(self, points):
71 | # ensure non null points
72 | if not isinstance(points, np.ndarray):
73 | if not points:
74 | return
75 | points = np.array(points)
76 |         if points.size == 0:
77 | return
78 |
79 | if self.is_leaf:
80 | # We add the maximum we can to the bucket
81 | n = min(len(points), self.bucketsize - self._n_points)
82 | batch = points[:n]
83 | self._points[self._n_points:self._n_points + n] = batch
84 | self._n_points += n
85 | self._update_bounds(batch)
86 | if self._n_points == self.bucketsize:
87 | # It is full, transform node as non leaf
88 | self._create_children()
89 |             # Add any remaining points
90 | self.add_points(points[n:])
91 | else:
92 | # We add points to child given their position
93 | self._update_bounds(points)
94 | infe = points[:, self.split_dim] <= self.split_val
95 | self.left.add_points(points[infe])
96 | self.right.add_points(points[~infe])
97 |
98 | def _create_children(self):
99 | assert self.is_leaf and self.bucketsize == self._n_points
100 |         # The split must happen only on a full leaf (right after an add)
101 |         # At this point the bounds are those of the bucket
102 | ranges = self.upper - self.lower
103 | split_dim = np.argmax(ranges * self.dim_scale)
104 | # No more a leaf, create attribute for non leaf
105 | self.is_leaf = False
106 | self.split_dim = split_dim
107 | self.split_val = self.lower[split_dim] + ranges[split_dim] / 2
108 | self.left = BucketKDNode(
109 | self,
110 | dim=self.dim, dim_scale=self.dim_scale, bucketsize=self.bucketsize
111 | )
112 | self.right = BucketKDNode(
113 | self,
114 | dim=self.dim, dim_scale=self.dim_scale, bucketsize=self.bucketsize
115 | )
116 | # Now diffuse bucket points in children and erase local bucket
117 | self.add_points(self._points)
118 | self._points = None
119 |
120 | def nearest_neighbour(self, query, dist_to_many, max_dist=None):
121 | if self.is_leaf:
122 | dists = dist_to_many(query, self._points[:self._n_points])
123 | i_min = np.argmin(dists)
124 | if max_dist is None or dists[i_min] < max_dist:
125 | return dists[i_min], self._points[i_min]
126 | return None, None
127 | else:
128 | cursor = self
129 | # Go down to the best leaf
130 | while not cursor.is_leaf:
131 | if query[cursor.split_dim] <= cursor.split_val:
132 | cursor = cursor.left
133 | else:
134 | cursor = cursor.right
135 | best_d, best_p = cursor.nearest_neighbour(
136 | query, dist_to_many, max_dist
137 | )
138 | if best_d is not None:
139 | max_dist = best_d
140 | # Go up by recursively checking ambiguous split
141 | while cursor is not self:
142 | cursor = cursor.parent
143 |                 # Check the sibling we did not descend into:
144 |                 # get its nearest point if the split is ambiguous
145 | i = cursor.split_dim
146 | s = self.dim_scale[i]
147 | x = query[i]
148 | d, p = None, None
149 | if x <= cursor.split_val:
150 | # We come from left, check right
151 | if s * (cursor.right.lower[i] - x) < max_dist:
152 | # There is an ambiguity, check right
153 | d, p = cursor.right.nearest_neighbour(
154 | query, dist_to_many, max_dist
155 | )
156 | else:
157 | # Same for right
158 | if s * (x - cursor.left.upper[i]) < max_dist:
159 | d, p = cursor.left.nearest_neighbour(
160 | query, dist_to_many, max_dist
161 | )
162 | if d is not None:
163 | # We have found something better in ambiguity
164 | best_d, best_p = d, p
165 | max_dist = best_d
166 |
167 | return best_d, best_p
168 |
169 |
170 | class SBucketKDNode(NodeBinaryTree):
171 | """
172 | Alternative with external storage_data
173 | """
174 | def __init__(
175 | self,
176 | storage_data,
177 | parent=None,
178 | points_idx=None,
179 | dim=3,
180 | dim_scale=None,
181 | bucketsize=10
182 | ):
183 | # Store storage_data
184 | self.storage_data = storage_data
185 | # Parent node
186 | self.parent = parent
187 | # Hyperparam
188 | self.bucketsize = bucketsize
189 | self.dim = dim
190 | self.dim_scale = dim_scale if dim_scale is not None else np.ones(dim)
191 |         # Nodes are initially leaves, without children and with a bucket
192 |         self.is_leaf = True
193 |         # Bucket of point indices stored while this node is a leaf
194 | self._points_idx = np.zeros(bucketsize, dtype=np.intp)
195 | self._n_points = 0
196 |         # Values used only by non-leaf nodes
197 | self.split_dim = None
198 | self.split_val = None
199 | self.left = None
200 | self.right = None
201 | # Bounds tracker
202 | self.lower = np.ones(dim, dtype=float) * np.inf
203 | self.upper = - np.ones(dim, dtype=float) * np.inf
204 | # Add given points
205 | self.add_points(points_idx)
206 |
207 | def _update_bounds(self, points_idx):
208 | # ensure non null points
209 | if not isinstance(points_idx, np.ndarray):
210 | if not points_idx:
211 | return
212 | points_idx = np.array(points_idx, dtype=np.intp)
213 |         if points_idx.size == 0:
214 | return
215 | mini = np.min(self.storage_data[points_idx], axis=0)
216 | maxi = np.max(self.storage_data[points_idx], axis=0)
217 | min_update = mini < self.lower
218 | self.lower[min_update] = mini[min_update]
219 | max_update = maxi > self.upper
220 | self.upper[max_update] = maxi[max_update]
221 |
222 | def add_point(self, p_idx):
223 | self._update_bounds([p_idx])
224 | if self.is_leaf:
225 | # The node is still a leaf, just add to bucket
226 | # We are sure there is enough space
227 | self._points_idx[self._n_points] = p_idx
228 | self._n_points += 1
229 | if self._n_points == self.bucketsize:
230 | # It is now full, transform node as non leaf
231 | self._create_children()
232 | else:
233 | # Kd split to add to children
234 | if self.storage_data[p_idx, self.split_dim] <= self.split_val:
235 | self.left.add_point(p_idx)
236 | else:
237 | self.right.add_point(p_idx)
238 |
239 | def add_points(self, points_idx):
240 | # ensure non null points
241 | if not isinstance(points_idx, np.ndarray):
242 | if not points_idx:
243 | return
244 | points_idx = np.array(points_idx, dtype=np.intp)
245 |         if points_idx.size == 0:
246 | return
247 |
248 | if self.is_leaf:
249 | # We add the maximum we can to the bucket
250 | n = min(len(points_idx), self.bucketsize - self._n_points)
251 | batch = points_idx[:n]
252 | self._points_idx[self._n_points:self._n_points + n] = batch
253 | self._n_points += n
254 | self._update_bounds(batch)
255 | if self._n_points == self.bucketsize:
256 | # It is full, transform node as non leaf
257 | self._create_children()
258 |             # Add any remaining points
259 | self.add_points(points_idx[n:])
260 | else:
261 | # We add points to child given their position
262 | self._update_bounds(points_idx)
263 | infe = self.storage_data[points_idx, self.split_dim] <= self.split_val
264 | self.left.add_points(points_idx[infe])
265 | self.right.add_points(points_idx[~infe])
266 |
267 | def _create_children(self):
268 | assert self.is_leaf and self.bucketsize == self._n_points
269 |         # The split must happen only on a full leaf (right after an add)
270 |         # At this point the bounds are those of the bucket
271 | ranges = self.upper - self.lower
272 | split_dim = np.argmax(ranges * self.dim_scale)
273 | # No more a leaf, create attribute for non leaf
274 | self.is_leaf = False
275 | self.split_dim = split_dim
276 | self.split_val = self.lower[split_dim] + ranges[split_dim] / 2
277 | self.left = SBucketKDNode(
278 | self.storage_data, self,
279 | dim=self.dim, dim_scale=self.dim_scale, bucketsize=self.bucketsize
280 | )
281 | self.right = SBucketKDNode(
282 | self.storage_data, self,
283 | dim=self.dim, dim_scale=self.dim_scale, bucketsize=self.bucketsize
284 | )
285 | # Now diffuse bucket points in children and erase local bucket
286 | self.add_points(self._points_idx)
287 | self._points_idx = None
288 |
289 | def nearest_neighbour(self, query, dist_to_many, max_dist=None):
290 | if self.is_leaf:
291 | dists = dist_to_many(
292 | query, self.storage_data[self._points_idx[:self._n_points]]
293 | )
294 | i_min = np.argmin(dists)
295 | if max_dist is None or dists[i_min] < max_dist:
296 | return dists[i_min], self._points_idx[i_min]
297 | return None, None
298 | else:
299 | cursor = self
300 | # Go down to the best leaf
301 | while not cursor.is_leaf:
302 | if query[cursor.split_dim] <= cursor.split_val:
303 | cursor = cursor.left
304 | else:
305 | cursor = cursor.right
306 | best_d, best_p = cursor.nearest_neighbour(
307 | query, dist_to_many, max_dist
308 | )
309 | if best_d is not None:
310 | max_dist = best_d
311 | # Go up by recursively checking ambiguous split
312 | while cursor is not self:
313 | cursor = cursor.parent
314 |                 # Check the sibling we did not descend into:
315 |                 # get its nearest point if the split is ambiguous
316 | i = cursor.split_dim
317 | s = self.dim_scale[i]
318 | x = query[i]
319 | d, p = None, None
320 | if x <= cursor.split_val:
321 | # We come from left, check right
322 | if s * (cursor.right.lower[i] - x) < max_dist:
323 | # There is an ambiguity, check right
324 | d, p = cursor.right.nearest_neighbour(
325 | query, dist_to_many, max_dist
326 | )
327 | else:
328 | # Same for right
329 | if s * (x - cursor.left.upper[i]) < max_dist:
330 | d, p = cursor.left.nearest_neighbour(
331 | query, dist_to_many, max_dist
332 | )
333 | if d is not None:
334 | # We have found something better in ambiguity
335 | best_d, best_p = d, p
336 | max_dist = best_d
337 |
338 | return best_d, best_p
339 |
--------------------------------------------------------------------------------
/utils/datastructures/mtree/__init__.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 |
3 | import utils.datastructures.mtree.functions as functions
4 | from utils.datastructures.mtree.heap_queue import HeapQueue
5 |
6 |
7 | _INFINITY = float("inf")
8 | _ItemWithDistances = namedtuple(
9 | '_ItemWithDistances', 'item, distance, min_distance'
10 | )
11 |
12 |
13 | class _RootNodeReplacement(Exception):
14 | def __init__(self, new_root):
15 | super(_RootNodeReplacement, self).__init__(new_root)
16 | self.new_root = new_root
17 |
18 |
19 | class _SplitNodeReplacement(Exception):
20 | def __init__(self, new_nodes):
21 | super(_SplitNodeReplacement, self).__init__(new_nodes)
22 | self.new_nodes = new_nodes
23 |
24 |
25 | class _NodeUnderCapacity(Exception):
26 | pass
27 |
28 |
29 | class _IndexItem(object):
30 | def __init__(self, data):
31 | self.data = data
32 | self.radius = 0
33 | self.distance_to_parent = None
34 |
35 | def _check(self, mtree):
36 | self._check_data()
37 | self._check_radius()
38 | self._check_distance_to_parent()
39 | return 1
40 |
41 | def _check_data(self):
42 | assert self.data is not None
43 |
44 | def _check_radius(self):
45 | assert self.radius is not None
46 | assert self.radius >= 0
47 |
48 | def _check_distance_to_parent(self):
49 | assert not isinstance(self, _RootNodeTrait), self
50 | assert self.distance_to_parent is not None
51 | assert self.distance_to_parent >= 0
52 |
53 |
54 | class _Node(_IndexItem):
55 |
56 | def __init__(self, data):
57 | super(_Node, self).__init__(data)
58 | self.children = {}
59 |
60 | def add_data(self, data, distance, mtree):
61 | self.do_add_data(data, distance, mtree)
62 | self.check_max_capacity(mtree)
63 |
64 | def check_max_capacity(self, mtree):
65 | if len(self.children) > mtree.max_node_capacity:
66 | data_objects = frozenset(self.children.keys())
67 | cached_distance_function = functions.make_cached_distance_function(
68 | mtree.distance_function
69 | )
70 |
71 | (promoted_data1, partition1,
72 | promoted_data2, partition2) = mtree.split_function(
73 | data_objects, cached_distance_function
74 | )
75 |
76 | split_node_repl_class = self.get_split_node_replacement_class()
77 | new_nodes = []
78 | for promoted_data, partition in [(promoted_data1, partition1),
79 | (promoted_data2, partition2)]:
80 | new_node = split_node_repl_class(promoted_data)
81 | for data in partition:
82 | child = self.children[data]
83 | distance = cached_distance_function(promoted_data, data)
84 | new_node.add_child(child, distance, mtree)
85 | new_nodes.append(new_node)
86 |
87 | raise _SplitNodeReplacement(new_nodes)
88 |
89 | def remove_data(self, data, distance, mtree):
90 | self.do_remove_data(data, distance, mtree)
91 | if len(self.children) < self.get_min_capacity(mtree):
92 | raise _NodeUnderCapacity()
93 |
94 | def update_metrics(self, child, distance):
95 | child.distance_to_parent = distance
96 | self.update_radius(child)
97 |
98 | def update_radius(self, child):
99 | self.radius = max(self.radius, child.distance_to_parent + child.radius)
100 |
101 | def _check(self, mtree):
102 | super(_Node, self)._check(mtree)
103 | self._check_min_capacity(mtree)
104 | self._check_max_capacity(mtree)
105 |
106 | child_height = None
107 | for data, child in self.children.items():
108 | assert child.data == data
109 | self._check_child_class(child)
110 | self._check_child_metrics(child, mtree)
111 |
112 | height = child._check(mtree)
113 | if child_height is None:
114 | child_height = height
115 | else:
116 | assert child_height == height
117 |
118 | return child_height + 1
119 |
120 | def _check_max_capacity(self, mtree):
121 | assert len(self.children) <= mtree.max_node_capacity
122 |
123 | def _check_child_class(self, child):
124 | expected_class = self._get_expected_child_class()
125 | assert isinstance(child, expected_class)
126 |
127 | def _check_child_metrics(self, child, mtree):
128 | dist = mtree.distance_function(child.data, self.data)
129 | assert child.distance_to_parent == dist, (
130 | child.data,
131 | self.data,
132 | child.distance_to_parent,
133 | dist,
134 | abs(child.distance_to_parent - dist)
135 | )
136 | assert child.distance_to_parent + child.radius <= self.radius
137 |
138 |
139 | class _RootNodeTrait(_Node):
140 |
141 | def _check_distance_to_parent(self):
142 | assert self.distance_to_parent is None
143 |
144 |
145 | class _NonRootNodeTrait(_Node):
146 |
147 | def get_min_capacity(self, mtree):
148 | return mtree.min_node_capacity
149 |
150 | def _check_min_capacity(self, mtree):
151 | assert len(self.children) >= mtree.min_node_capacity
152 |
153 |
154 | class _LeafNodeTrait(_Node):
155 |
156 | def do_add_data(self, data, distance, mtree):
157 | entry = _Entry(data)
158 | assert data not in self.children
159 | self.children[data] = entry
160 | assert data in self.children
161 | self.update_metrics(entry, distance)
162 |
163 | def add_child(self, child, distance, mtree):
164 | assert child.data not in self.children
165 | self.children[child.data] = child
166 | assert child.data in self.children
167 | self.update_metrics(child, distance)
168 |
169 | @staticmethod
170 | def get_split_node_replacement_class():
171 | return _LeafNode
172 |
173 | def do_remove_data(self, data, distance, mtree):
174 | del self.children[data]
175 |
176 | @staticmethod
177 | def _get_expected_child_class():
178 | return _Entry
179 |
180 |
181 | class _NonLeafNodeTrait(_Node):
182 |
183 | CandidateChild = namedtuple('CandidateChild', 'node, distance, metric')
184 |
185 | def do_add_data(self, data, distance, mtree):
186 |
187 | min_radius_increase_needed = self.CandidateChild(None, None, _INFINITY)
188 | nearest_distance = self.CandidateChild(None, None, _INFINITY)
189 |
190 | distances = mtree.distance_function(
191 | data, [child.data for child in self.children.values()]
192 | )
193 |
194 | for distance, child in zip(distances, self.children.values()):
195 | if distance > child.radius:
196 | radius_increase = distance - child.radius
197 | if radius_increase < min_radius_increase_needed.metric:
198 | min_radius_increase_needed = self.CandidateChild(
199 | child, distance, radius_increase
200 | )
201 | else:
202 | if distance < nearest_distance.metric:
203 | nearest_distance = self.CandidateChild(
204 | child, distance, distance
205 | )
206 |
207 | if nearest_distance.node is not None:
208 | chosen = nearest_distance
209 | else:
210 | chosen = min_radius_increase_needed
211 |
212 | child = chosen.node
213 | try:
214 | child.add_data(data, chosen.distance, mtree)
215 | except _SplitNodeReplacement as e:
216 | assert len(e.new_nodes) == 2
217 | # Replace current child with new nodes
218 | del self.children[child.data]
219 | distances = mtree.distance_function(
220 | data, [new_child.data for new_child in e.new_nodes]
221 | )
222 | for distance, new_child in zip(distances, e.new_nodes):
223 | self.add_child(new_child, distance, mtree)
224 | else:
225 | self.update_radius(child)
226 |
227 | def add_child(self, new_child, distance, mtree):
228 | new_children = [(new_child, distance)]
229 | while new_children:
230 | new_child, distance = new_children.pop()
231 |
232 | if new_child.data not in self.children:
233 | self.children[new_child.data] = new_child
234 | self.update_metrics(new_child, distance)
235 | else:
236 | existing_child = self.children[new_child.data]
237 | assert existing_child.data == new_child.data
238 |
239 | # Transfer the _children_ of the new_child to the existing_child
240 | for grandchild in new_child.children.values():
241 | existing_child.add_child(grandchild, grandchild.distance_to_parent, mtree)
242 |
243 | try:
244 | existing_child.check_max_capacity(mtree)
245 | except _SplitNodeReplacement as e:
246 | del self.children[new_child.data]
247 | distances = mtree.distance_function(
248 | self.data, [new_node.data for new_node in e.new_nodes]
249 | )
250 | for distance, new_node in zip(distances, e.new_nodes):
251 | new_children.append((new_node, distance))
252 |
253 | @staticmethod
254 | def get_split_node_replacement_class():
255 | return _InternalNode
256 |
257 | def do_remove_data(self, data, distance, mtree):
258 | for child in self.children.values():
259 | if abs(distance - child.distance_to_parent) <= child.radius:  # triangle inequality: |d(data,this)-d(child,this)| lower-bounds d(data,child)
260 | distance_to_child = mtree.distance_function(data, child.data)
261 | if distance_to_child <= child.radius:
262 | try:
263 | child.remove_data(data, distance_to_child, mtree)
264 | except KeyError:
265 | # If KeyError was raised, then the data was not found in the child
266 | pass
267 | except _NodeUnderCapacity:
268 | expanded_child = self.balance_children(child, mtree)
269 | self.update_radius(expanded_child)
270 | return
271 | else:
272 | self.update_radius(child)
273 | return
274 | raise KeyError()
275 |
276 | def balance_children(self, the_child, mtree):
277 | # Tries to find another_child which can donate a grandchild to the_child.
278 |
279 | nearest_donor = None
280 | distance_nearest_donor = _INFINITY
281 |
282 | nearest_merge_candidate = None
283 | distance_nearest_merge_candidate = _INFINITY
284 |
285 | distances = mtree.distance_function(
286 | the_child.data,
287 | [another_child.data for another_child in (child for child in self.children.values() if child is not the_child)]
288 | )
289 |
290 | for distance, another_child in zip(distances, (child for child in self.children.values() if child is not the_child)):
291 | if len(another_child.children) > another_child.get_min_capacity(mtree):
292 | if distance < distance_nearest_donor:
293 | distance_nearest_donor = distance
294 | nearest_donor = another_child
295 | else:
296 | if distance < distance_nearest_merge_candidate:
297 | distance_nearest_merge_candidate = distance
298 | nearest_merge_candidate = another_child
299 |
300 | if nearest_donor is None:
301 | # Merge
302 | distances = mtree.distance_function(
303 | nearest_merge_candidate.data,
304 | [grandchild.data for grandchild in the_child.children.values()]
305 |
306 | )
307 | for distance, grandchild in zip(distances, the_child.children.values()):
308 | nearest_merge_candidate.add_child(grandchild, distance, mtree)
309 |
310 | del self.children[the_child.data]
311 | return nearest_merge_candidate
312 | else:
313 | # Donate
314 | # Look for the nearest grandchild
315 | nearest_grandchild_distance = _INFINITY
316 | distances = mtree.distance_function(
317 | the_child.data,
318 | [grandchild.data for grandchild in nearest_donor.children.values()]
319 | )
320 | for distance, grandchild in zip(distances, nearest_donor.children.values()):
321 | if distance < nearest_grandchild_distance:
322 | nearest_grandchild_distance = distance
323 | nearest_grandchild = grandchild
324 |
325 | del nearest_donor.children[nearest_grandchild.data]
326 | the_child.add_child(nearest_grandchild, nearest_grandchild_distance, mtree)
327 | return the_child
328 |
329 | @staticmethod
330 | def _get_expected_child_class():
331 | return (_InternalNode, _LeafNode)
332 |
333 |
334 | class _RootLeafNode(_RootNodeTrait, _LeafNodeTrait):
335 |
336 | def remove_data(self, data, distance, mtree):
337 | try:
338 | super(_RootLeafNode, self).remove_data(data, distance, mtree)
339 | except _NodeUnderCapacity:
340 | assert len(self.children) == 0
341 | raise _RootNodeReplacement(None)
342 |
343 | @staticmethod
344 | def get_min_capacity(mtree):
345 | return 1
346 |
347 | def _check_min_capacity(self, mtree):
348 | assert len(self.children) >= 1
349 |
350 |
351 | class _RootNode(_RootNodeTrait, _NonLeafNodeTrait):
352 |
353 | def remove_data(self, data, distance, mtree):
354 | try:
355 | super(_RootNode, self).remove_data(data, distance, mtree)
356 | except _NodeUnderCapacity:
357 | # Promote the only child to root
358 | (the_child,) = self.children.values()
359 | if isinstance(the_child, _InternalNode):
360 | new_root_class = _RootNode
361 | else:
362 | assert isinstance(the_child, _LeafNode)
363 | new_root_class = _RootLeafNode
364 |
365 | new_root = new_root_class(the_child.data)
366 | distances = mtree.distance_function(
367 | new_root.data,
368 | [grandchild.data for grandchild in the_child.children.values()]
369 | )
370 |
371 | for distance, grandchild in zip(distances, the_child.children.values()):
372 | new_root.add_child(grandchild, distance, mtree)
373 |
374 | raise _RootNodeReplacement(new_root)
375 |
376 | @staticmethod
377 | def get_min_capacity(mtree):
378 | return 2
379 |
380 | def _check_min_capacity(self, mtree):
381 | assert len(self.children) >= 2
382 |
383 |
384 | class _InternalNode(_NonRootNodeTrait, _NonLeafNodeTrait):
385 | pass
386 |
387 |
388 | class _LeafNode(_NonRootNodeTrait, _LeafNodeTrait):
389 | pass
390 |
391 |
392 | class _Entry(_IndexItem):
393 | pass
394 |
395 |
396 | class MTree(object):
397 | """
398 | A data structure for indexing objects based on their proximity.
399 |
400 | The data objects must be any hashable object and the support functions
401 | (distance and split functions) must understand them.
402 |
403 | See http://en.wikipedia.org/wiki/M-tree
404 | """
405 |
406 | ResultItem = namedtuple('ResultItem', 'data, distance')
407 |
408 | def __init__(
409 | self,
410 | distance_function,
411 | min_node_capacity=50,
412 | max_node_capacity=None,
413 | split_function=functions.make_split_function(
414 | functions.random_promotion, functions.balanced_partition
415 | )
416 | ):
417 | """
418 | Creates an M-Tree.
419 |
420 | The argument min_node_capacity must be at least 2.
421 | The argument max_node_capacity should be at least 2*min_node_capacity-1.
422 | The argument distance_function must be a function which calculates
423 | the distance between two data objects.
424 | The optional argument split_function must be a function which chooses two
425 | data objects and then partitions the set of data into two subsets
426 | according to the chosen objects. Its arguments are the set of data objects
427 | and the distance_function. Must return a sequence with the following four values:
428 | - First chosen data object.
429 | - Subset with at least [min_node_capacity] objects based on the first
430 | chosen data object. Must contain the first chosen data object.
431 | - Second chosen data object.
432 | - Subset with at least [min_node_capacity] objects based on the second
433 | chosen data object. Must contain the second chosen data object.
434 | """
435 | if min_node_capacity < 2:
436 | raise ValueError("min_node_capacity must be at least 2")
437 | if max_node_capacity is None:
438 | max_node_capacity = 2 * min_node_capacity - 1
439 | if max_node_capacity <= min_node_capacity:
440 | raise ValueError("max_node_capacity must be greater than min_node_capacity")
441 |
442 | self.min_node_capacity = min_node_capacity
443 | self.max_node_capacity = max_node_capacity
444 | self.distance_function = distance_function
445 | self.split_function = split_function
446 | self.root = None
447 |
448 | def add(self, data):
449 | """
450 | Adds and indexes an object.
451 |
452 | The object must not already be indexed!
453 | """
454 | if self.root is None:
455 | self.root = _RootLeafNode(data)
456 | self.root.add_data(data, 0, self)
457 | else:
458 | distance = self.distance_function(data, self.root.data)
459 | try:
460 | self.root.add_data(data, distance, self)
461 | except _SplitNodeReplacement as e:
462 | assert len(e.new_nodes) == 2
463 | self.root = _RootNode(self.root.data)
464 | distances = self.distance_function(
465 | self.root.data,
466 | [new_node.data for new_node in e.new_nodes]
467 | )
468 | for distance, new_node in zip(distances, e.new_nodes):
469 | self.root.add_child(new_node, distance, self)
470 |
471 | add_point = add
472 |
473 | def remove(self, data):
474 | """
475 | Removes an object from the index.
476 | """
477 | if self.root is None:
478 | raise KeyError()
479 |
480 | distance_to_root = self.distance_function(data, self.root.data)
481 | try:
482 | self.root.remove_data(data, distance_to_root, self)
483 | except _RootNodeReplacement as e:
484 | self.root = e.new_root
485 |
486 | def get_nearest(self, query_data, range=_INFINITY, limit=_INFINITY):
487 | """
488 | Returns an iterator on the indexed data nearest to the query_data. The
489 | returned items are tuples containing the data and its distance to the
490 | query_data, in increasing distance order. The results can be limited by
491 | the range (maximum distance from the query_data) and limit arguments.
492 | """
493 | if self.root is None:
494 | # No indexed data!
495 | return
496 |
497 | distance = self.distance_function(query_data, self.root.data)
498 | min_distance = max(distance - self.root.radius, 0)
499 |
500 | pending_queue = HeapQueue(
501 | content=[_ItemWithDistances(item=self.root, distance=distance, min_distance=min_distance)],
502 | key=lambda iwd: iwd.min_distance,
503 | )
504 |
505 | nearest_queue = HeapQueue(key=lambda iwd: iwd.distance)
506 |
507 | yielded_count = 0
508 |
509 | while pending_queue:
510 | pending = pending_queue.pop()
511 |
512 | node = pending.item
513 | assert isinstance(node, _Node)
514 |
515 | distances = self.distance_function(
516 | query_data,
517 | [child.data for child in node.children.values()]
518 | )
519 | for child_distance, child in zip(distances, node.children.values()):
520 | if abs(pending.distance - child.distance_to_parent) - child.radius <= range:
521 | child_min_distance = max(child_distance - child.radius, 0)
522 | if child_min_distance <= range:
523 | iwd = _ItemWithDistances(item=child, distance=child_distance, min_distance=child_min_distance)
524 | if isinstance(child, _Entry):
525 | nearest_queue.push(iwd)
526 | else:
527 | pending_queue.push(iwd)
528 |
529 | # Tries to yield known results so far
530 | if pending_queue:
531 | next_pending = pending_queue.head()
532 | next_pending_min_distance = next_pending.min_distance
533 | else:
534 | next_pending_min_distance = _INFINITY
535 |
536 | while nearest_queue:
537 | next_nearest = nearest_queue.head()
538 | assert isinstance(next_nearest, _ItemWithDistances)
539 | if next_nearest.distance <= next_pending_min_distance:
540 | _ = nearest_queue.pop()
541 | assert _ is next_nearest
542 |
543 | yield self.ResultItem(data=next_nearest.item.data, distance=next_nearest.distance)
544 | yielded_count += 1
545 | if yielded_count >= limit:
546 | # Limit reached
547 | return
548 | else:
549 | break
550 |
551 | def nearest_neighbour(self, point):
552 | return next(self.get_nearest(point, limit=1))
553 |
554 | def _check(self):
555 | if self.root is not None:
556 | self.root._check(self)
557 |
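A minimal usage sketch of this MTree, assuming a Euclidean distance on hashable point tuples. Note the dual calling convention: the tree calls distance_function pairwise at the root (in add, remove and get_nearest) but with a list as second argument inside the nodes, so the function below handles both forms.

    import numpy as np
    from utils.datastructures.mtree import MTree

    def dist(a, b):
        a = np.asarray(a)
        if isinstance(b, list):  # batched form used inside the nodes
            return [float(np.linalg.norm(a - np.asarray(x))) for x in b]
        return float(np.linalg.norm(a - np.asarray(b)))

    tree = MTree(dist, min_node_capacity=4)
    for p in [(0., 0.), (1., 0.), (0., 1.), (2., 2.), (5., 5.)]:
        tree.add(p)  # data objects must be hashable, hence tuples

    print(tree.nearest_neighbour((0.9, 0.1)))       # ResultItem(data=(1.0, 0.0), distance=...)
    for item in tree.get_nearest((0., 0.), range=1.5, limit=3):
        print(item.data, item.distance)             # the three points within distance 1.5
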
--------------------------------------------------------------------------------
/utils/datastructures/mtree/faster.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 |
3 | import utils.datastructures.mtree.functions as functions
4 | from utils.datastructures.mtree.heap_queue import HeapQueue
5 |
6 |
7 | _INFINITY = float("inf")
8 | _ItemWithDistances = namedtuple(
9 | '_ItemWithDistances', 'item, distance, min_distance'
10 | )
11 |
12 |
13 | class _RootNodeReplacement(Exception):
14 | def __init__(self, new_root):
15 | super(_RootNodeReplacement, self).__init__(new_root)
16 | self.new_root = new_root
17 |
18 |
19 | class _SplitNodeReplacement(Exception):
20 | def __init__(self, new_nodes):
21 | super(_SplitNodeReplacement, self).__init__(new_nodes)
22 | self.new_nodes = new_nodes
23 |
24 |
25 | class _NodeUnderCapacity(Exception):
26 | pass
27 |
28 |
29 | class _IndexItem(object):
30 | def __init__(self, data):
31 | self.data = data
32 | self.radius = 0
33 | self.distance_to_parent = None
34 |
35 | def _check(self, mtree):
36 | self._check_data()
37 | self._check_radius()
38 | self._check_distance_to_parent()
39 | return 1
40 |
41 | def _check_data(self):
42 | assert self.data is not None
43 |
44 | def _check_radius(self):
45 | assert self.radius is not None
46 | assert self.radius >= 0
47 |
48 | def _check_distance_to_parent(self):
49 | assert not isinstance(self, _RootNodeTrait), self
50 | assert self.distance_to_parent is not None
51 | assert self.distance_to_parent >= 0
52 |
53 |
54 | class _Node(_IndexItem):
55 |
56 | def __init__(self, data):
57 | super(_Node, self).__init__(data)
58 | self.children = {}
59 |
60 | def add_data(self, data, distance, mtree):
61 | self.do_add_data(data, distance, mtree)
62 | self.check_max_capacity(mtree)
63 |
64 | def check_max_capacity(self, mtree):
65 | if len(self.children) > mtree.max_node_capacity:
66 | data_objects = frozenset(self.children.keys())
67 | cached_distance_function = functions.make_cached_distance_function(
68 | mtree.distance_function
69 | )
70 |
71 | (promoted_data1, partition1,
72 | promoted_data2, partition2) = mtree.split_function(
73 | data_objects, cached_distance_function
74 | )
75 |
76 | split_node_repl_class = self.get_split_node_replacement_class()
77 | new_nodes = []
78 | for promoted_data, partition in [(promoted_data1, partition1),
79 | (promoted_data2, partition2)]:
80 | new_node = split_node_repl_class(promoted_data)
81 | for data in partition:
82 | child = self.children[data]
83 | distance = cached_distance_function(promoted_data, data)
84 | new_node.add_child(child, distance, mtree)
85 | new_nodes.append(new_node)
86 |
87 | raise _SplitNodeReplacement(new_nodes)
88 |
89 | def remove_data(self, data, distance, mtree):
90 | self.do_remove_data(data, distance, mtree)
91 | if len(self.children) < self.get_min_capacity(mtree):
92 | raise _NodeUnderCapacity()
93 |
94 | def update_metrics(self, child, distance):
95 | child.distance_to_parent = distance
96 | self.update_radius(child)
97 |
98 | def update_radius(self, child):
99 | self.radius = max(self.radius, child.distance_to_parent + child.radius)
100 |
101 | def _check(self, mtree):
102 | super(_Node, self)._check(mtree)
103 | self._check_min_capacity(mtree)
104 | self._check_max_capacity(mtree)
105 |
106 | child_height = None
107 | for data, child in self.children.items():
108 | assert child.data == data
109 | self._check_child_class(child)
110 | self._check_child_metrics(child, mtree)
111 |
112 | height = child._check(mtree)
113 | if child_height is None:
114 | child_height = height
115 | else:
116 | assert child_height == height
117 |
118 | return child_height + 1
119 |
120 | def _check_max_capacity(self, mtree):
121 | assert len(self.children) <= mtree.max_node_capacity
122 |
123 | def _check_child_class(self, child):
124 | expected_class = self._get_expected_child_class()
125 | assert isinstance(child, expected_class)
126 |
127 | def _check_child_metrics(self, child, mtree):
128 | dist = mtree.distance_function(child.data, self.data)
129 | assert child.distance_to_parent == dist, (
130 | child.data,
131 | self.data,
132 | child.distance_to_parent,
133 | dist,
134 | abs(child.distance_to_parent - dist)
135 | )
136 | assert child.distance_to_parent + child.radius <= self.radius
137 |
138 |
139 | class _RootNodeTrait(_Node):
140 |
141 | def _check_distance_to_parent(self):
142 | assert self.distance_to_parent is None
143 |
144 |
145 | class _NonRootNodeTrait(_Node):
146 |
147 | def get_min_capacity(self, mtree):
148 | return mtree.min_node_capacity
149 |
150 | def _check_min_capacity(self, mtree):
151 | assert len(self.children) >= mtree.min_node_capacity
152 |
153 |
154 | class _LeafNodeTrait(_Node):
155 |
156 | def do_add_data(self, data, distance, mtree):
157 | entry = _Entry(data)
158 | assert data not in self.children
159 | self.children[data] = entry
160 | assert data in self.children
161 | self.update_metrics(entry, distance)
162 |
163 | def add_child(self, child, distance, mtree):
164 | assert child.data not in self.children
165 | self.children[child.data] = child
166 | assert child.data in self.children
167 | self.update_metrics(child, distance)
168 |
169 | @staticmethod
170 | def get_split_node_replacement_class():
171 | return _LeafNode
172 |
173 | def do_remove_data(self, data, distance, mtree):
174 | del self.children[data]
175 |
176 | @staticmethod
177 | def _get_expected_child_class():
178 | return _Entry
179 |
180 |
181 | class _NonLeafNodeTrait(_Node):
182 |
183 | CandidateChild = namedtuple('CandidateChild', 'node, distance, metric')
184 |
185 | def do_add_data(self, data, distance, mtree):
186 |
187 | min_radius_increase_needed = self.CandidateChild(None, None, _INFINITY)
188 | nearest_distance = self.CandidateChild(None, None, _INFINITY)
189 |
190 | distances = mtree.distance_function(
191 | data, [child.data for child in self.children.values()]
192 | )
193 |
194 | for distance, child in zip(distances, self.children.values()):
195 | if distance > child.radius:
196 | radius_increase = distance - child.radius
197 | if radius_increase < min_radius_increase_needed.metric:
198 | min_radius_increase_needed = self.CandidateChild(
199 | child, distance, radius_increase
200 | )
201 | else:
202 | if distance < nearest_distance.metric:
203 | nearest_distance = self.CandidateChild(
204 | child, distance, distance
205 | )
206 |
207 | if nearest_distance.node is not None:
208 | chosen = nearest_distance
209 | else:
210 | chosen = min_radius_increase_needed
211 |
212 | child = chosen.node
213 | try:
214 | child.add_data(data, chosen.distance, mtree)
215 | except _SplitNodeReplacement as e:
216 | assert len(e.new_nodes) == 2
217 | # Replace current child with new nodes
218 | del self.children[child.data]
219 | distances = mtree.distance_function(
220 | data, [new_child.data for new_child in e.new_nodes]
221 | )
222 | for distance, new_child in zip(distances, e.new_nodes):
223 | self.add_child(new_child, distance, mtree)
224 | else:
225 | self.update_radius(child)
226 |
227 | def add_child(self, new_child, distance, mtree):
228 | new_children = [(new_child, distance)]
229 | while new_children:
230 | new_child, distance = new_children.pop()
231 |
232 | if new_child.data not in self.children:
233 | self.children[new_child.data] = new_child
234 | self.update_metrics(new_child, distance)
235 | else:
236 | existing_child = self.children[new_child.data]
237 | assert existing_child.data == new_child.data
238 |
239 | # Transfer the _children_ of the new_child to the existing_child
240 | for grandchild in new_child.children.values():
241 | existing_child.add_child(grandchild, grandchild.distance_to_parent, mtree)
242 |
243 | try:
244 | existing_child.check_max_capacity(mtree)
245 | except _SplitNodeReplacement as e:
246 | del self.children[new_child.data]
247 | distances = mtree.distance_function(
248 | self.data, [new_node.data for new_node in e.new_nodes]
249 | )
250 | for distance, new_node in zip(distances, e.new_nodes):
251 | new_children.append((new_node, distance))
252 |
253 | @staticmethod
254 | def get_split_node_replacement_class():
255 | return _InternalNode
256 |
257 | def do_remove_data(self, data, distance, mtree):
258 | for child in self.children.values():
259 | if abs(distance - child.distance_to_parent) <= child.radius:  # triangle inequality: |d(data,this)-d(child,this)| lower-bounds d(data,child)
260 | distance_to_child = mtree.distance_function(data, child.data)
261 | if distance_to_child <= child.radius:
262 | try:
263 | child.remove_data(data, distance_to_child, mtree)
264 | except KeyError:
265 | # If KeyError was raised, then the data was not found in the child
266 | pass
267 | except _NodeUnderCapacity:
268 | expanded_child = self.balance_children(child, mtree)
269 | self.update_radius(expanded_child)
270 | return
271 | else:
272 | self.update_radius(child)
273 | return
274 | raise KeyError()
275 |
276 | def balance_children(self, the_child, mtree):
277 | # Tries to find another_child which can donate a grandchild to the_child.
278 |
279 | nearest_donor = None
280 | distance_nearest_donor = _INFINITY
281 |
282 | nearest_merge_candidate = None
283 | distance_nearest_merge_candidate = _INFINITY
284 |
285 | distances = mtree.distance_function(
286 | the_child.data,
287 | [another_child.data for another_child in (child for child in self.children.values() if child is not the_child)]
288 | )
289 |
290 | for distance, another_child in zip(distances, (child for child in self.children.values() if child is not the_child)):
291 | if len(another_child.children) > another_child.get_min_capacity(mtree):
292 | if distance < distance_nearest_donor:
293 | distance_nearest_donor = distance
294 | nearest_donor = another_child
295 | else:
296 | if distance < distance_nearest_merge_candidate:
297 | distance_nearest_merge_candidate = distance
298 | nearest_merge_candidate = another_child
299 |
300 | if nearest_donor is None:
301 | # Merge
302 | distances = mtree.distance_function(
303 | nearest_merge_candidate.data,
304 | [grandchild.data for grandchild in the_child.children.values()]
305 |
306 | )
307 | for distance, grandchild in zip(distances, the_child.children.values()):
308 | nearest_merge_candidate.add_child(grandchild, distance, mtree)
309 |
310 | del self.children[the_child.data]
311 | return nearest_merge_candidate
312 | else:
313 | # Donate
314 | # Look for the nearest grandchild
315 | nearest_grandchild_distance = _INFINITY
316 | distances = mtree.distance_function(
317 | the_child.data,
318 | [grandchild.data for grandchild in nearest_donor.children.values()]
319 | )
320 | for distance, grandchild in zip(distances, nearest_donor.children.values()):
321 | if distance < nearest_grandchild_distance:
322 | nearest_grandchild_distance = distance
323 | nearest_grandchild = grandchild
324 |
325 | del nearest_donor.children[nearest_grandchild.data]
326 | the_child.add_child(nearest_grandchild, nearest_grandchild_distance, mtree)
327 | return the_child
328 |
329 | @staticmethod
330 | def _get_expected_child_class():
331 | return (_InternalNode, _LeafNode)
332 |
333 |
334 | class _RootLeafNode(_RootNodeTrait, _LeafNodeTrait):
335 |
336 | def remove_data(self, data, distance, mtree):
337 | try:
338 | super(_RootLeafNode, self).remove_data(data, distance, mtree)
339 | except _NodeUnderCapacity:
340 | assert len(self.children) == 0
341 | raise _RootNodeReplacement(None)
342 |
343 | @staticmethod
344 | def get_min_capacity(mtree):
345 | return 1
346 |
347 | def _check_min_capacity(self, mtree):
348 | assert len(self.children) >= 1
349 |
350 |
351 | class _RootNode(_RootNodeTrait, _NonLeafNodeTrait):
352 |
353 | def remove_data(self, data, distance, mtree):
354 | try:
355 | super(_RootNode, self).remove_data(data, distance, mtree)
356 | except _NodeUnderCapacity:
357 | # Promote the only child to root
358 | (the_child,) = self.children.values()
359 | if isinstance(the_child, _InternalNode):
360 | new_root_class = _RootNode
361 | else:
362 | assert isinstance(the_child, _LeafNode)
363 | new_root_class = _RootLeafNode
364 |
365 | new_root = new_root_class(the_child.data)
366 | distances = mtree.distance_function(
367 | new_root.data,
368 | [grandchild.data for grandchild in the_child.children.values()]
369 | )
370 |
371 | for distance, grandchild in zip(distances, the_child.children.values()):
372 | new_root.add_child(grandchild, distance, mtree)
373 |
374 | raise _RootNodeReplacement(new_root)
375 |
376 | @staticmethod
377 | def get_min_capacity(mtree):
378 | return 2
379 |
380 | def _check_min_capacity(self, mtree):
381 | assert len(self.children) >= 2
382 |
383 |
384 | class _InternalNode(_NonRootNodeTrait, _NonLeafNodeTrait):
385 | pass
386 |
387 |
388 | class _LeafNode(_NonRootNodeTrait, _LeafNodeTrait):
389 | pass
390 |
391 |
392 | class _Entry(_IndexItem):
393 | pass
394 |
395 |
396 | class MTree(object):
397 | """
398 | A data structure for indexing objects based on their proximity.
399 |
400 | The data objects must be any hashable object and the support functions
401 | (distance and split functions) must understand them.
402 |
403 | See http://en.wikipedia.org/wiki/M-tree
404 | """
405 |
406 | ResultItem = namedtuple('ResultItem', 'data, distance')
407 |
408 | def __init__(
409 | self,
410 | distance_function,
411 | min_node_capacity=20,
412 | max_node_capacity=None,
413 | split_function=functions.make_split_function(
414 | functions.random_promotion, functions.balanced_partition
415 | )
416 | ):
417 | """
418 | Creates an M-Tree.
419 |
420 | The argument min_node_capacity must be at least 2.
421 | The argument max_node_capacity should be at least 2*min_node_capacity-1.
422 | The argument distance_function must be a function which calculates
423 | the distance between two data objects.
424 | The optional argument split_function must be a function which chooses two
425 | data objects and then partitions the set of data into two subsets
426 | according to the chosen objects. Its arguments are the set of data objects
427 | and the distance_function. Must return a sequence with the following four values:
428 | - First chosen data object.
429 | - Subset with at least [min_node_capacity] objects based on the first
430 | chosen data object. Must contain the first chosen data object.
431 | - Second chosen data object.
432 | - Subset with at least [min_node_capacity] objects based on the second
433 | chosen data object. Must contain the second chosen data object.
434 | """
435 | if min_node_capacity < 2:
436 | raise ValueError("min_node_capacity must be at least 2")
437 | if max_node_capacity is None:
438 | max_node_capacity = 2 * min_node_capacity - 1
439 | if max_node_capacity <= min_node_capacity:
440 | raise ValueError("max_node_capacity must be greater than min_node_capacity")
441 |
442 | self.min_node_capacity = min_node_capacity
443 | self.max_node_capacity = max_node_capacity
444 | self.distance_function = distance_function
445 | self.split_function = split_function
446 | self.root = None
447 |
448 | def add(self, data):
449 | """
450 | Adds and indexes an object.
451 |
452 | The object must not already be indexed!
453 | """
454 | if self.root is None:
455 | self.root = _RootLeafNode(data)
456 | self.root.add_data(data, 0, self)
457 | else:
458 | distance = self.distance_function(data, self.root.data)
459 | try:
460 | self.root.add_data(data, distance, self)
461 | except _SplitNodeReplacement as e:
462 | assert len(e.new_nodes) == 2
463 | self.root = _RootNode(self.root.data)
464 | distances = self.distance_function(
465 | self.root.data,
466 | [new_node.data for new_node in e.new_nodes]
467 | )
468 | for distance, new_node in zip(distances, e.new_nodes):
469 | self.root.add_child(new_node, distance, self)
470 |
471 | add_point = add
472 |
473 | def remove(self, data):
474 | """
475 | Removes an object from the index.
476 | """
477 | if self.root is None:
478 | raise KeyError()
479 |
480 | distance_to_root = self.distance_function(data, self.root.data)
481 | try:
482 | self.root.remove_data(data, distance_to_root, self)
483 | except _RootNodeReplacement as e:
484 | self.root = e.new_root
485 |
486 | def get_nearest(self, query_data, range=_INFINITY, limit=_INFINITY):
487 | """
488 | Returns an iterator on the indexed data nearest to the query_data. The
489 | returned items are tuples containing the data and its distance to the
490 | query_data, in increasing distance order. The results can be limited by
491 | the range (maximum distance from the query_data) and limit arguments.
492 | """
493 | if self.root is None:
494 | # No indexed data!
495 | return
496 |
497 | distance = self.distance_function(query_data, self.root.data)
498 | min_distance = max(distance - self.root.radius, 0)
499 |
500 | pending_queue = HeapQueue(
501 | content=[_ItemWithDistances(item=self.root, distance=distance, min_distance=min_distance)],
502 | key=lambda iwd: iwd.min_distance,
503 | )
504 |
505 | nearest_queue = HeapQueue(key=lambda iwd: iwd.distance)
506 |
507 | yielded_count = 0
508 |
509 | while pending_queue:
510 | pending = pending_queue.pop()
511 |
512 | node = pending.item
513 | assert isinstance(node, _Node)
514 |
515 | distances = self.distance_function(
516 | query_data,
517 | [child.data for child in node.children.values()]
518 | )
519 | for child_distance, child in zip(distances, node.children.values()):
520 | if abs(pending.distance - child.distance_to_parent) - child.radius <= range:
521 | child_min_distance = max(child_distance - child.radius, 0)
522 | if child_min_distance <= range:
523 | iwd = _ItemWithDistances(item=child, distance=child_distance, min_distance=child_min_distance)
524 | if isinstance(child, _Entry):
525 | nearest_queue.push(iwd)
526 | else:
527 | pending_queue.push(iwd)
528 |
529 | # Tries to yield known results so far
530 | if pending_queue:
531 | next_pending = pending_queue.head()
532 | next_pending_min_distance = next_pending.min_distance
533 | else:
534 | next_pending_min_distance = _INFINITY
535 |
536 | while nearest_queue:
537 | next_nearest = nearest_queue.head()
538 | assert isinstance(next_nearest, _ItemWithDistances)
539 | if next_nearest.distance <= next_pending_min_distance:
540 | _ = nearest_queue.pop()
541 | assert _ is next_nearest
542 |
543 | yield self.ResultItem(data=next_nearest.item.data, distance=next_nearest.distance)
544 | yielded_count += 1
545 | if yielded_count >= limit:
546 | # Limit reached
547 | return
548 | else:
549 | break
550 |
551 | def nearest_neighbour(self, point):
552 | return next(self.get_nearest(point, limit=1))
553 |
554 | def _check(self):
555 | if self.root is not None:
556 | self.root._check(self)
557 |
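In the portion reproduced here, this faster variant differs from the package __init__ only by its default min_node_capacity (20 instead of 50). A quick self-test sketch, assuming the same dual-form distance function as in the previous sketch: with min_node_capacity=2 the max capacity defaults to 3, so a handful of insertions forces node splits, and _check() then walks the tree asserting the radius, capacity and height invariants.

    import numpy as np
    from utils.datastructures.mtree.faster import MTree

    dist = lambda a, b: ([float(np.linalg.norm(np.subtract(a, x))) for x in b]
                         if isinstance(b, list) else float(np.linalg.norm(np.subtract(a, b))))

    tree = MTree(dist, min_node_capacity=2)   # max capacity defaults to 2*2-1 = 3
    for k in range(10):
        tree.add((float(k), 0.0))
    tree._check()                             # raises AssertionError if an invariant is violated
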
--------------------------------------------------------------------------------
/utils/datastructures/mtree/functions.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | from utils.datastructures.mtree.heap_queue import HeapQueue
4 |
5 |
6 | def random_promotion(data_objects, distance_function):
7 | """
8 | Randomly chooses two objects to be promoted.
9 | """
10 | data_objects = list(data_objects)
11 | return random.sample(data_objects, 2)
12 |
13 |
14 | def balanced_partition(
15 | promoted_data1, promoted_data2, data_objects, distance_function
16 | ):
17 | partition1 = set()
18 | partition2 = set()
19 |
20 | queue1 = HeapQueue(
21 | data_objects,
22 | key=lambda data: distance_function(data, promoted_data1)
23 | )
24 | queue2 = HeapQueue(
25 | data_objects,
26 | key=lambda data: distance_function(data, promoted_data2)
27 | )
28 |
29 | while queue1 or queue2:
30 | while queue1:
31 | data = queue1.pop()
32 | if data not in partition2:
33 | partition1.add(data)
34 | break
35 |
36 | while queue2:
37 | data = queue2.pop()
38 | if data not in partition1:
39 | partition2.add(data)
40 | break
41 |
42 | return partition1, partition2
43 |
44 |
45 | def make_split_function(promotion_function, partition_function):
46 | """
47 | Creates a splitting function.
48 | The parameters must be callable objects:
49 | - promotion_function(data_objects, distance_function)
50 | Must return two objects chosen from the data_objects argument.
51 | - partition_function(promoted_data1, promoted_data2, data_objects, distance_function)
52 | Must return a sequence with two iterable objects containing a partition
53 | of the data_objects. The promoted_data1 and promoted_data2 arguments
54 | should be used as partitioning criteria and must be contained on the
55 | corresponding iterable subsets.
56 | """
57 | def split_function(data_objects, distance_function):
58 | promoted_data1, promoted_data2 = promotion_function(
59 | data_objects, distance_function
60 | )
61 | partition1, partition2 = partition_function(
62 | promoted_data1, promoted_data2, data_objects, distance_function
63 | )
64 |
65 | return promoted_data1, partition1, promoted_data2, partition2
66 | return split_function
67 |
68 |
69 | def make_cached_distance_function(distance_function):
70 | cache = {}
71 |
72 | def cached_distance_function(data1, data2):
73 | try:
74 | distance = cache[data1][data2]
75 | except KeyError:
76 | distance = distance_function(data1, data2)
77 |
78 | if data1 in cache:
79 | cache[data1][data2] = distance
80 | else:
81 | cache[data1] = {data2: distance}
82 |
83 | if data2 in cache:
84 | cache[data2][data1] = distance
85 | else:
86 | cache[data2] = {data1: distance}
87 |
88 | return distance
89 |
90 | cached_distance_function.cache = cache
91 |
92 | return cached_distance_function
93 |
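The split_function contract documented in MTree.__init__ is exactly what make_split_function assembles from a promotion policy and a partition policy. A sketch of a custom policy (hypothetical, for illustration only): promote the two objects of a small random sample that are farthest apart, reusing balanced_partition unchanged. The distance_function received here is the cached pairwise function built in _Node.check_max_capacity.

    import random
    from utils.datastructures.mtree.functions import make_split_function, balanced_partition

    def sampled_farthest_promotion(data_objects, distance_function):
        # data_objects has at least max_node_capacity + 1 >= 4 elements when a split occurs
        sample = random.sample(list(data_objects), min(5, len(data_objects)))
        best, best_d = None, -1.0
        for i, a in enumerate(sample):
            for b in sample[i + 1:]:
                d = distance_function(a, b)
                if d > best_d:
                    best, best_d = (a, b), d
        return best

    split = make_split_function(sampled_farthest_promotion, balanced_partition)
    # tree = MTree(dist, split_function=split)
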
--------------------------------------------------------------------------------
/utils/datastructures/mtree/heap_queue.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 |
3 |
4 | _HeapItem = namedtuple('_HeapItem', 'k, value')
5 |
6 | class HeapQueue(object):
7 |
8 | def __init__(self, content=(), key=lambda x: x, max=False):
9 | if max:
10 | self.key = lambda x: -key(x)
11 | else:
12 | self.key = key
13 | self._items = [_HeapItem(self.key(value), value) for value in content]
14 | self.heapify()
15 |
16 | def _items_less_than(self, base, other):
17 | return self._items[base].k < self._items[other].k
18 |
19 | def _swap_items(self, base, other):
20 | self._items[base], self._items[other] = self._items[other], self._items[base]
21 |
22 | def _make_heap(self, i):
23 | smallest = i
24 |
25 | left = 2 * i + 1
26 | if left < len(self._items) and self._items_less_than(left, smallest):
27 | smallest = left
28 |
29 | right = 2 * i + 2
30 | if right < len(self._items) and self._items_less_than(right, smallest):
31 | smallest = right
32 |
33 | if smallest != i:
34 | self._swap_items(i, smallest)
35 | self._make_heap(smallest)
36 |
37 | def heapify(self):
38 | for i in range(len(self._items) // 2, -1, -1):
39 | self._make_heap(i)
40 |
41 | def head(self):
42 | return self._items[0].value
43 |
44 | def push(self, value):
45 | i = len(self._items)
46 | new_item = _HeapItem(self.key(value), value)
47 | self._items.append(new_item)
48 | while i > 0:
49 | p = int((i - 1) // 2)
50 | if self._items_less_than(p, i):
51 | break
52 | self._swap_items(i, p)
53 | i = p
54 |
55 | def pop(self):
56 | popped = self._items[0].value
57 | self._items[0] = self._items[-1]
58 | self._items.pop(-1)
59 | self._make_heap(0)
60 | return popped
61 |
62 | def pushpop(self, value):
63 | k = self.key(value)
64 | if k <= self._items[0].k:
65 | return value
66 | else:
67 | popped = self._items[0].value
68 | self._items[0] = _HeapItem(k, value)
69 | self._make_heap(0)
70 | return popped
71 |
72 | def __len__(self):
73 | return len(self._items)
74 |
75 | def extractor(self):
76 | while self._items:
77 | yield self.pop()
78 |
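HeapQueue is a plain array-based binary min-heap ordered by key(value); passing max=True negates the key to obtain a max-heap. A short usage sketch:

    from utils.datastructures.mtree.heap_queue import HeapQueue

    q = HeapQueue([5, 1, 4], key=lambda x: x)
    q.push(2)
    assert q.head() == 1                        # peek at the smallest key
    assert list(q.extractor()) == [1, 2, 4, 5]  # pops in increasing key order
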
--------------------------------------------------------------------------------
/utils/datastructures/pathtree.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle as pkl
3 | from utils.datastructures.storage import Storage
4 |
5 |
6 | class PathTree:
7 |
8 | @classmethod
9 | def load(cls, path):
10 | inst = cls(Storage.load(str(path) + '_storage.pkl'))
11 | with open(str(path) + '_tree.pkl', 'rb') as f:
12 | a = pkl.load(f)
13 |
14 | n = inst.storage.n
15 | inst.parent[:n] = a['parent']
16 | inst.cost[:n] = a['cost']
17 | inst.depth[:n] = a['depth']
18 |
19 | return inst
20 |
21 | def __init__(self, storage):
22 | self.storage = storage
23 | self.parent = np.zeros(storage.N, dtype=np.intp)
24 | self.cost = np.zeros(storage.N, dtype=float)
25 | self.depth = np.zeros(storage.N, dtype=int)
26 |
27 | def update_link(self, q_idx, parent_idx, c=1.):
28 | self.parent[q_idx] = parent_idx
29 | self.depth[q_idx] = self.depth[parent_idx] + 1
30 | self.cost[q_idx] = self.cost[parent_idx] + c
31 |
32 | def get_edges(self):
33 | # TODO: yield the edges lazily instead of materializing the full array
34 | res = np.zeros((self.storage.n - 1, 2, self.storage.dim), dtype=float)  # np.float was removed in NumPy 1.24
35 | res[:, 0, :] = self.storage.data[1:self.storage.n, :]
36 | res[:, 1, :] = self.storage.data[self.parent[1:self.storage.n], :]
37 |
38 | costs = self.cost[1:self.storage.n]
39 | return res, costs
40 |
41 | def get_path(self):
42 | # TODO: yield the path lazily instead of materializing the full array
43 | i = self.storage.n - 1
44 | len_path = self.depth[i] + 1
45 | res = np.zeros((len_path, self.storage.dim))
46 | j = len_path
47 | while i != 0:
48 | j -= 1
49 | res[j] = self.storage.data[i]
50 | i = self.parent[i]
51 | res[0] = self.storage.data[0]
52 | return res
53 |
54 | def save(self, path):
55 | n = self.storage.n
56 | self.storage.save(str(path) + '_storage.pkl')
57 | with open(str(path) + '_tree.pkl', 'wb') as f:
58 | pkl.dump({
59 | 'parent': self.parent[:n],
60 | 'cost': self.cost[:n],
61 | 'depth': self.depth[:n],
62 | }, f)
63 |
64 | def get_estimated_start_goal(self):
65 | return self.storage.data[0], self.storage.data[self.storage.n - 1]
66 |
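PathTree pairs a Storage of configurations with parallel parent/cost/depth arrays, i.e. an RRT-style search tree in which row indices of the Storage act as node identifiers; get_path backtracks from the last inserted point to the root. A small sketch:

    import numpy as np
    from utils.datastructures.storage import Storage
    from utils.datastructures.pathtree import PathTree

    store = Storage(100, 2)
    tree = PathTree(store)
    root = store.add_point(np.zeros(2))      # index 0: the root; its parent fields stay 0
    i = store.add_point(np.array([1., 0.]))
    tree.update_link(i, root, c=1.)          # cost[i] = cost[root] + 1, depth[i] = 1
    j = store.add_point(np.array([1., 1.]))
    tree.update_link(j, i, c=1.)
    print(tree.get_path())                   # [[0. 0.] [1. 0.] [1. 1.]]
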
--------------------------------------------------------------------------------
/utils/datastructures/storage.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle as pkl
3 |
4 |
5 | class Storage:
6 |
7 | @classmethod
8 | def load(cls, path):
9 | with open(path, 'rb') as f:
10 | a = pkl.load(f)
11 |
12 | inst = cls(a['N'], a['dim'])
13 |
14 | n = a['n']
15 | inst.n = n
16 |
17 | inst.data[:n] = a['data']
18 |
19 | return inst
20 |
21 | def __init__(self, N, dim):
22 | self.N = N
23 | self.dim = dim
24 | self.n = np.intp(0)
25 | self.data = np.zeros((N, dim), dtype=float)
26 |
27 | def add_point(self, p):
28 | assert not self.is_full
29 | self.data[self.n] = p
30 | self.n += 1
31 | return self.n - 1
32 |
33 | def remove_last(self):
34 | assert self.n
35 | self.n -= 1
36 |
37 | def __getitem__(self, idx):
38 | # assert idx < self.n
39 | return self.data[idx]
40 |
41 | def __len__(self):
42 | return self.n
43 |
44 | @property
45 | def ndarray(self):
46 | return self.data[:self.n]
47 |
48 | @property
49 | def is_full(self):
50 | return self.n == self.N
51 |
52 | def save(self, path):
53 | with open(path, 'wb') as f:
54 | pkl.dump({
55 | 'N': self.N,
56 | 'dim': self.dim,
57 | 'n': self.n,
58 | 'data': self.data[:self.n]
59 | }, f)
60 |
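Storage is a fixed-capacity array of points: memory for N rows is allocated once and a counter n tracks how many are in use, so add_point and remove_last never reallocate.

    from utils.datastructures.storage import Storage

    s = Storage(4, 3)                # room for 4 points of dimension 3
    idx = s.add_point([1., 2., 3.])  # returns the row index of the new point (0)
    print(len(s), s.is_full)         # 1 False
    print(s.ndarray.shape)           # (1, 3): a view on the first n rows
    s.remove_last()                  # only the counter moves; the row is kept
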
--------------------------------------------------------------------------------
/utils/datastructures/tree.py:
--------------------------------------------------------------------------------
1 | class NodeBinaryTree:
2 | """
3 | Abstract tree to implement the classic search
4 | """
5 | def __init__(
6 | self, parent=None, left=None, right=None
7 | ):
8 | self.parent = parent
9 | self.left = left
10 | self.right = right
11 |
12 | def ascension(self):
13 | yield self
14 | if self.parent is not None:
15 | for e in self.parent.ascension():
16 | yield e
17 |
18 | def depth_first(self):
19 | yield self
20 | if self.left is not None:
21 | for e in self.left.depth_first():
22 | yield e
23 | if self.right is not None:
24 | for e in self.right.depth_first():
25 | yield e
26 |
27 | def _wide_first(self, i=0):
28 | yield i, self  # yield (depth, node) pairs: the merge below compares depths, and wide_first keeps the node
29 | iter_left = (
30 | iter(self.left._wide_first(i + 1))
31 | if self.left is not None else None
32 | )
33 | iter_right = (
34 | iter(self.right._wide_first(i + 1))
35 | if self.right is not None else None
36 | )
37 | i_left, n_left = self.robust_next(iter_left)
38 | i_right, n_right = self.robust_next(iter_right)
39 | while not (i_left is None and i_right is None):
40 | if i_left is not None and (i_right is None or i_left <= i_right):
41 | yield i_left, n_left
42 | i_left, n_left = self.robust_next(iter_left)
43 | else:
44 | yield i_right, n_right
45 | i_right, n_right = self.robust_next(iter_right)
46 |
47 | def wide_first(self):
48 | for _, e in self._wide_first():
49 | yield e
50 |
51 | @staticmethod
52 | def robust_next(iterator):
53 | if iterator is None:
54 | return (None, None)
55 | return next(iterator, (None, None))
56 |
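A traversal sketch on a four-node tree (note that wide_first relies on _wide_first yielding (depth, node) pairs, which its merge compares by depth):

    from utils.datastructures.tree import NodeBinaryTree

    #     a
    #    / \
    #   b   c
    #  /
    # d
    a, b, c, d = (NodeBinaryTree() for _ in range(4))
    a.left, b.parent = b, a
    a.right, c.parent = c, a
    b.left, d.parent = d, b

    assert list(a.depth_first()) == [a, b, d, c]  # pre-order
    assert list(a.wide_first()) == [a, b, c, d]   # level order
    assert list(d.ascension()) == [d, b, a]       # walk up to the root
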
--------------------------------------------------------------------------------
/utils/generate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Cut python files in bits loadable by ipython."""
3 |
4 | from pathlib import Path
5 | import json
6 |
7 | hashtags = ['jupyter_snippet']
8 |
9 | def generate_from_id(tp_id : int):
10 | folder = Path() / f'tp{tp_id}'
11 | ipynb = next(Path().glob(f'{tp_id}_*.ipynb'))
12 | generate(ipynb,folder)
13 |
14 | def generate(ipynb, folder):
15 | print(f'processing {ipynb} with scripts in {folder}')
16 | with ipynb.open() as f:
17 | data = json.load(f)
18 | cells_copy = data['cells'].copy()
19 | generated = folder / 'generated'
20 | generated.mkdir(exist_ok=True)
21 | for filename in folder.glob('*.py'):
22 | print(f' processing {filename}')
23 | content = []
24 | hidden = False
25 | dest = None
26 | with filename.open() as f_in:
27 | for line_number, line in enumerate(f_in):
28 | if any([ f'# %{hashtag}' in line for hashtag in hashtags ]):
29 | if dest is not None:
30 | raise SyntaxError(f'%{hashtags[0]} block opened twice at line {line_number + 1}')
31 | dest = generated / f'{filename.stem}_{line.split()[2]}'
32 | hidden = False
33 | elif any([ line.strip() == f'# %end_{hashtag}' for hashtag in hashtags ]):
34 | if dest is None:
35 | raise SyntaxError(f'%{hashtags[0]} block closed before being opened at line {line_number + 1}')
36 | with dest.open('w') as f_out:
37 | f_out.write(''.join(content))
38 | for cell_number, cell in enumerate(cells_copy):
39 | if len(cell['source'])==0: continue
40 | if cell['source'][0].endswith(f'%load {dest}'):
41 | data['cells'][cell_number]['source'] = [f'# %load {dest}\n'] + content
42 | #if f'%do_not_load {dest}' in cell['source'][0]:
43 | # data['cells'][cell_number]['source'] = [f'%do_not_load {dest}\n']
44 | content = []
45 | hidden = False
46 | dest = None
47 | elif dest is not None:
48 | content.append(line)
49 | with ipynb.open('w') as f:
50 | f.write(json.dumps(data, indent=1))
51 |
52 |
53 | if __name__ == '__main__':
54 | for tp_number in [0,1,2,3,4,5]:
55 | generate_from_id(tp_number)
56 |
57 | for app in [ 'appendix_scipy_optimizers']:
58 | generate(next(Path().glob(app+'.ipynb')),Path()/'appendix')
59 |
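The markers generate.py looks for in the tp*/ source files follow this shape (a hypothetical tp1/foo.py): the block below would be written to tp1/generated/foo_example and spliced into any notebook cell whose first source line ends with "%load tp1/generated/foo_example".

    # %jupyter_snippet example
    x = 1 + 1
    print(x)
    # %end_jupyter_snippet
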
--------------------------------------------------------------------------------
/utils/load_ur5_parallel.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pinocchio as pin
3 | from example_robot_data import load
4 |
5 |
6 | def load_ur5_parallel():
7 | """
8 | Create a robot composed of 4 UR5
9 |
10 | >>> ur5 = load('ur5')
11 | >>> ur5.nq
12 | 6
13 | >>> len(ur5.visual_model.geometryObjects)
14 | 7
15 | >>> robot = load_ur5_parallel()
16 | >>> robot.nq
17 | 24
18 | >>> len(robot.visual_model.geometryObjects)
19 | 28
20 | """
21 | robot = load('ur5')
22 | nbRobots = 4
23 |
24 | models = [robot.model.copy() for _ in range(nbRobots)]
25 | vmodels = [robot.visual_model.copy() for _ in range(nbRobots)]
26 |
27 | # Build the kinematic model by assembling 4 UR5
28 | fullmodel = pin.Model()
29 |
30 | for irobot, model in enumerate(models):
31 | # Change frame names
32 | for i, f in enumerate(model.frames):
33 | f.name = '%s_#%d' % (f.name, irobot)
34 | # Change joint names
35 | for i, n in enumerate(model.names):
36 | model.names[i] = '%s_#%d' % (n, irobot)
37 |
38 | # Choose the placement of the new arm to be added
39 | Mt = pin.SE3(np.eye(3), np.array([.3, 0, 0.])) # First robot is simply translated
40 | basePlacement = pin.SE3(pin.utils.rotate('z', np.pi * irobot / 2), np.zeros(3)) * Mt
41 |
42 | # Append the kinematic model
43 | fullmodel = pin.appendModel(fullmodel, model, 0, basePlacement)
44 |
45 | # Build the geometry model
46 | fullvmodel = pin.GeometryModel()
47 |
48 | for irobot, (model, vmodel) in enumerate(zip(models, vmodels)):
49 | # Change geometry names
50 | for i, g in enumerate(vmodel.geometryObjects):
51 | # Change the name to avoid conflict
52 | g.name = '%s_#%d' % (g.name, irobot)
53 |
54 | # Refer to the new parent names in the full kinematic tree
55 | g.parentFrame = fullmodel.getFrameId(model.frames[g.parentFrame].name)
56 | g.parentJoint = fullmodel.getJointId(model.names[g.parentJoint])
57 |
58 | # Append the geometry model
59 | fullvmodel.addGeometryObject(g)
60 | # print('add %s on frame %d "%s"' % (g.name, g.parentFrame, fullmodel.frames[g.parentFrame].name))
61 |
62 | fullrobot = pin.RobotWrapper(fullmodel, fullvmodel, fullvmodel)
63 | # fullrobot.q0 = np.array([-0.375, -1.2 , 1.71 , -0.51 , -0.375, 0. ]*4)
64 | fullrobot.q0 = np.array([np.pi / 4, -np.pi / 4, -np.pi / 2, np.pi / 4, np.pi / 2, 0] * nbRobots)
65 |
66 | return fullrobot
67 |
68 |
69 | if __name__ == "__main__":
70 | from utils.meshcat_viewer_wrapper import MeshcatVisualizer
71 |
72 | robot = load_ur5_parallel()
73 | viz = MeshcatVisualizer(robot, url='classical')
74 | viz.display(robot.q0)
75 |
--------------------------------------------------------------------------------
/utils/load_ur5_with_obstacles.py:
--------------------------------------------------------------------------------
1 | '''
2 | Load a UR5 robot model, display it in the viewer. Also create an obstacle
3 | field made of several capsules, display them in the viewer and set up the
4 | collision detection to handle them.
5 | '''
6 |
7 | import pinocchio as pin
8 | import example_robot_data as robex
9 | import numpy as np
10 | import itertools
11 |
12 |
13 | def XYZRPYtoSE3(xyzrpy):
14 | rotate = pin.utils.rotate
15 | R = rotate('x',xyzrpy[3]) @ rotate('y',xyzrpy[4]) @ rotate('z',xyzrpy[5])
16 | p = np.array(xyzrpy[:3])
17 | return pin.SE3(R,p)
18 |
19 | def load_ur5_with_obstacles(robotname='ur5',reduced=False):
20 |
21 | ### Robot
22 | # Load the robot
23 | robot = robex.load(robotname)
24 |
25 | ### If reduced, then only keep the shoulder-lift and elbow joints, hence creating a simple R2 robot.
26 | if reduced:
27 | unlocks = [1,2]
28 | robot.model,[robot.visual_model,robot.collision_model]\
29 | = pin.buildReducedModel(robot.model,[robot.visual_model,robot.collision_model],
30 | [ i+1 for i in range(robot.nq) if i not in unlocks ],robot.q0)
31 | robot.data = robot.model.createData()
32 | robot.collision_data = robot.collision_model.createData()
33 | robot.visual_data = robot.visual_model.createData()
34 | robot.q0 = robot.q0[unlocks].copy()
35 |
36 | ### Obstacle map
37 | # Capsule obstacles will be placed at these XYZ-RPY parameters
38 | oMobs = [ [ 0.40, 0., 0.30, np.pi/2,0,0],
39 | [-0.08, -0., 0.69, np.pi/2,0,0],
40 | [ 0.23, -0., 0.04, np.pi/2, 0 ,0 ],
41 | [-0.32, 0., -0.08, np.pi/2, 0, 0]]
42 |
43 | # Load visual objects and add them in collision/visual models
44 | color = [ 1.0, 0.2, 0.2, 1.0 ] # color of the capsules
45 | rad,length = .1,0.4 # radius and length of capsules
46 | for i,xyzrpy in enumerate(oMobs):
47 | obs = pin.GeometryObject.CreateCapsule(rad,length) # Pinocchio obstacle object
48 | obs.meshColor = np.array(color)  # required, otherwise the capsule is rendered transparent
49 | obs.name = "obs%d"%i # Set object name
50 | obs.parentJoint = 0 # Set object parent = 0 = universe
51 | obs.placement = XYZRPYtoSE3(xyzrpy) # Set object placement wrt parent
52 | robot.collision_model.addGeometryObject(obs) # Add object to collision model
53 | robot.visual_model .addGeometryObject(obs) # Add object to visual model
54 |
55 | ### Collision pairs
56 | nobs = len(oMobs)
57 | nbodies = robot.collision_model.ngeoms-nobs
58 | robotBodies = range(nbodies)
59 | envBodies = range(nbodies,nbodies+nobs)
60 | robot.collision_model.removeAllCollisionPairs()
61 | for a,b in itertools.product(robotBodies,envBodies):
62 | robot.collision_model.addCollisionPair(pin.CollisionPair(a,b))
63 |
64 | ### Geom data
65 | # Collision/visual models have been modified => re-generate corresponding data.
66 | robot.collision_data = pin.GeometryData(robot.collision_model)
67 | robot.visual_data = pin.GeometryData(robot.visual_model )
68 |
69 | return robot
70 |
71 |
72 | class Target:
73 | '''
74 | Simple class that stores and displays the position of a target.
75 | '''
76 | def __init__(self,viz=None,color = [ .0, 1.0, 0.2, 1.0 ], radius = 0.05, position=None):
77 | self.position = position if position is not None else np.array([ 0.0, 0.0 ])
78 | self.initVisual(viz,color,radius)
79 | self.display()
80 |
81 | def initVisual(self,viz,color,radius):
82 | self.viz = viz
83 | if viz is None: return
84 | self.name = "world/pinocchio/target"
85 |
86 | if isinstance(viz,pin.visualize.MeshcatVisualizer):
87 | import meshcat
88 | obj = meshcat.geometry.Sphere(radius)
89 | material = meshcat.geometry.MeshPhongMaterial()
90 | material.color = int(color[0] * 255) * 256**2 + int(color[1] * 255) * 256 + int(color[2] * 255)
91 | if float(color[3]) != 1.0:
92 | material.transparent = True
93 | material.opacity = float(color[3])
94 | self.viz.viewer[self.name].set_object(obj, material)
95 |
96 | elif isinstance(viz,pin.visualize.GepettoVisualizer):
97 | self.viz.viewer.gui.addCapsule( self.name, radius,0., color)
98 |
99 | def display(self):
100 | if self.viz is None or self.position is None: return
101 |
102 | if isinstance(self.viz,pin.visualize.MeshcatVisualizer):
103 | T = np.eye(4)
104 | T[[0,2],3] = self.position
105 | self.viz.viewer[self.name].set_transform(T)
106 | elif isinstance(self.viz,pin.visualize.GepettoVisualizer):
107 | self.viz.viewer.gui.applyConfiguration( self.name,
108 | [ self.position[0], 0, self.position[1],
109 | 1.,0.,0.0,0. ])
110 | self.viz.viewer.gui.refresh()
111 |
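With the robot-vs-obstacle collision pairs registered above, a configuration can be tested in one call. A sketch assuming the standard Pinocchio collision API:

    import numpy as np
    import pinocchio as pin
    from utils.load_ur5_with_obstacles import load_ur5_with_obstacles

    robot = load_ur5_with_obstacles(reduced=True)
    q = np.array([1.0, -1.5])                    # the reduced model has nq = 2
    in_collision = pin.computeCollisions(robot.model, robot.data,
                                         robot.collision_model, robot.collision_data,
                                         q, False)  # True if any registered pair collides
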
--------------------------------------------------------------------------------
/utils/meshcat_viewer_wrapper/__init__.py:
--------------------------------------------------------------------------------
1 | from .visualizer import MeshcatVisualizer # noqa
2 | from .transformations import planar,translation2d
3 |
--------------------------------------------------------------------------------
/utils/meshcat_viewer_wrapper/colors.py:
--------------------------------------------------------------------------------
1 | import meshcat
2 |
3 |
4 | def rgb2int(r, g, b):
5 | '''
6 | Convert 3 integers (chars) 0 <= r, g, b < 256 into one single integer = 256**2*r+256*g+b, as expected by Meshcat.
7 |
8 | >>> rgb2int(0, 0, 0)
9 | 0
10 | >>> rgb2int(0, 0, 255)
11 | 255
12 | >>> rgb2int(0, 255, 0) == 0x00FF00
13 | True
14 | >>> rgb2int(255, 0, 0) == 0xFF0000
15 | True
16 | >>> rgb2int(255, 255, 255) == 0xFFFFFF
17 | True
18 | '''
19 | return int((r << 16) + (g << 8) + b)
20 |
21 |
22 | def material(color, transparent=False):
23 | mat = meshcat.geometry.MeshPhongMaterial()
24 | mat.color = color
25 | mat.transparent = transparent
26 | return mat
27 |
28 |
29 | red = material(color=rgb2int(255, 0, 0), transparent=False)
30 | blue = material(color=rgb2int(0, 0, 255), transparent=False)
31 | green = material(color=rgb2int(0, 255, 0), transparent=False)
32 | yellow = material(color=rgb2int(255, 255, 0), transparent=False)
33 | magenta = material(color=rgb2int(255, 0, 255), transparent=False)
34 | cyan = material(color=rgb2int(0, 255, 255), transparent=False)
35 | white = material(color=rgb2int(250, 250, 250), transparent=False)
36 | black = material(color=rgb2int(5, 5, 5), transparent=False)
37 | grey = material(color=rgb2int(120, 120, 120), transparent=False)
38 |
39 | colormap = {
40 | 'red': red,
41 | 'blue': blue,
42 | 'green': green,
43 | 'yellow': yellow,
44 | 'magenta': magenta,
45 | 'cyan': cyan,
46 | 'black': black,
47 | 'white': white,
48 | 'grey': grey
49 | }
50 |
--------------------------------------------------------------------------------
/utils/meshcat_viewer_wrapper/tests.py:
--------------------------------------------------------------------------------
1 | import doctest
2 |
3 | from utils.meshcat_viewer_wrapper import colors
4 |
5 |
6 | def load_tests(loader, tests, pattern):
7 | tests.addTests(doctest.DocTestSuite(colors))
8 | return tests
9 |
--------------------------------------------------------------------------------
/utils/meshcat_viewer_wrapper/transformations.py:
--------------------------------------------------------------------------------
1 | '''
2 | Collection of super simple transformations to ease the use of the viewer.
3 | '''
4 |
5 | import numpy as np
6 |
7 | def planar(x, y, theta):
8 | '''Convert a planar configuration (x, y, theta) into an XYZ-quaternion placement in the Y,Z plane.'''
9 | s, c = np.sin(theta / 2), np.cos(theta / 2)
10 | return [0, x, y, s, 0, 0, c]  # quaternion (x=s, y=0, z=0, w=c): rotation around X
11 |
12 | def translation2d(x,y):
13 | ''' Convert a 2d vector (x,y) into a 3d transformation translating the Y,Z plane. '''
14 | return [0,x,y,1,0,0,0]
15 |
16 |
--------------------------------------------------------------------------------
/utils/meshcat_viewer_wrapper/visualizer.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import meshcat
4 | import numpy as np
5 | import pinocchio as pin
6 | from pinocchio.visualize import MeshcatVisualizer as PMV
7 |
8 | from . import colors
9 |
10 |
11 | def materialFromColor(color):
12 | if isinstance(color, meshcat.geometry.MeshPhongMaterial):
13 | return color
14 | elif isinstance(color, str):
15 | material = colors.colormap[color]
16 | elif isinstance(color, list):
17 | material = meshcat.geometry.MeshPhongMaterial()
18 | material.color = colors.rgb2int(*[int(c * 255) for c in color[:3]])
19 | if len(color) == 3:
20 | material.transparent = False
21 | else:
22 | material.transparent = color[3] < 1
23 | material.opacity = float(color[3])
24 | elif color is None:
25 | material = random.choice(list(colors.colormap.values()))  # pick a random predefined material, not a key
26 | else:
27 | material = colors.black
28 | return material
29 |
30 |
31 | class MeshcatVisualizer(PMV):
32 | def __init__(self, robot=None, model=None, collision_model=None, visual_model=None, url=None):
33 | if robot is not None:
34 | super().__init__(robot.model, robot.collision_model, robot.visual_model)
35 | elif model is not None:
36 | super().__init__(model, collision_model, visual_model)
37 |
38 | if url is not None:
39 | if url == 'classical':
40 | url = 'tcp://127.0.0.1:6000'
41 | print('Wrapper tries to connect to server <%s>' % url)
42 | server = meshcat.Visualizer(zmq_url=url)
43 | else:
44 | server = None
45 |
46 | if robot is not None or model is not None:
47 | self.initViewer(loadModel=True, viewer=server)
48 | else:
49 | self.viewer = server if server is not None else meshcat.Visualizer()
50 |
51 | def addSphere(self, name, radius, color):
52 | material = materialFromColor(color)
53 | self.viewer[name].set_object(meshcat.geometry.Sphere(radius), material)
54 |
55 | def addCylinder(self, name, length, radius, color=None):
56 | material = materialFromColor(color)
57 | self.viewer[name].set_object(meshcat.geometry.Cylinder(length, radius), material)
58 |
59 | def addBox(self, name, dims, color):
60 | material = materialFromColor(color)
61 | self.viewer[name].set_object(meshcat.geometry.Box(dims), material)
62 |
63 | def applyConfiguration(self, name, placement):
64 | if isinstance(placement, list) or isinstance(placement, tuple):
65 | placement = np.array(placement)
66 | if isinstance(placement, pin.SE3):
67 | R, p = placement.rotation, placement.translation
68 | T = np.r_[np.c_[R, p], [[0, 0, 0, 1]]]
69 | elif isinstance(placement, np.ndarray):
70 | if placement.shape == (7, ): # XYZ-quat
71 | R = pin.Quaternion(np.reshape(placement[3:], [4, 1])).matrix()
72 | p = placement[:3]
73 | T = np.r_[np.c_[R, p], [[0, 0, 0, 1]]]
74 | else:
75 | print('Error: np.ndarray placement must have shape (7,) (XYZ-quat)')
76 | return False
77 | else:
78 | print('Error: placement must be a list/tuple, np.ndarray or pin.SE3')
79 | return False
80 | self.viewer[name].set_transform(T)
81 |
82 | def delete(self, name):
83 | self.viewer[name].delete()
84 |
85 | def __getitem__(self, name):
86 | return self.viewer[name]
87 |
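A typical standalone use of the wrapper, together with the placement helpers from transformations.py (no robot model, so a fresh Meshcat viewer is created):

    import numpy as np
    from utils.meshcat_viewer_wrapper import MeshcatVisualizer, planar, translation2d

    viz = MeshcatVisualizer()
    viz.addSphere('world/ball', radius=0.1, color='red')
    viz.addBox('world/box', [0.1, 0.2, 0.3], color=[0, 1, 0, 0.5])  # RGBA list
    viz.applyConfiguration('world/ball', translation2d(0.1, 0.2))
    viz.applyConfiguration('world/box', planar(0.3, 0.4, np.pi / 3))
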
--------------------------------------------------------------------------------
/utils/tests.py:
--------------------------------------------------------------------------------
1 | import doctest
2 |
3 | from utils import load_ur5_parallel
4 |
5 |
6 | def load_tests(loader, tests, pattern):
7 | tests.addTests(doctest.DocTestSuite(load_ur5_parallel))
8 | return tests
9 |
--------------------------------------------------------------------------------
/utils/tiago_loader.py:
--------------------------------------------------------------------------------
1 | '''
2 | Tiago loader accounting for the first planar joint giving robot mobility.
3 | '''
4 |
5 | import numpy as np
6 | import pinocchio as pin
7 | import example_robot_data as robex
8 | import hppfcl
9 | from os.path import dirname, exists, join
10 |
11 |
12 | class TiagoLoader(object):
13 | #path = ''
14 | #urdf_filename = ''
15 | srdf_filename = ''
16 | urdf_subpath = 'robots'
17 | srdf_subpath = 'srdf'
18 | ref_posture = 'half_sitting'
19 | has_rotor_parameters = False
20 | free_flyer = True
21 | verbose = False
22 | path = "tiago_description"
23 | urdf_filename = "tiago_no_hand.urdf"
24 |
25 | def __init__(self):
26 | urdf_path = join(self.path, self.urdf_subpath, self.urdf_filename)
27 | self.model_path = robex.getModelPath(urdf_path, self.verbose)
28 | self.urdf_path = join(self.model_path, urdf_path)
29 | self.robot = pin.RobotWrapper.BuildFromURDF(self.urdf_path, [join(self.model_path, '../..')],
30 | pin.JointModelPlanar() if self.free_flyer else None)
31 |
32 | if self.srdf_filename:
33 | self.srdf_path = join(self.model_path, self.path, self.srdf_subpath, self.srdf_filename)
34 |             self.q0 = robex.robots_loader.readParamsFromSrdf(self.robot.model, self.srdf_path, self.verbose,
35 |                                                              self.has_rotor_parameters, self.ref_posture)
36 | else:
37 | self.srdf_path = None
38 | self.q0 = None
39 |
40 | if self.free_flyer:
41 | self.addFreeFlyerJointLimits()
42 |
43 | def addFreeFlyerJointLimits(self):
44 |         ub = self.robot.model.upperPositionLimit
45 |         ub[:self.robot.model.joints[1].nq] = 1  # bound the planar root joint's (x, y, cos, sin), unbounded in the URDF
46 |         self.robot.model.upperPositionLimit = ub
47 |         lb = self.robot.model.lowerPositionLimit
48 |         lb[:self.robot.model.joints[1].nq] = -1
49 |         self.robot.model.lowerPositionLimit = lb
50 |
51 |
52 | def loadTiago(addGazeFrame=False):
53 |     '''
54 |     Load a Tiago model, without the hand, with two modifications w.r.t. example_robot_data:
55 |     - first, the root joint is a planar (x, y, cos, sin) joint, whereas the robot has a fixed
56 |       base in example_robot_data;
57 |     - second, visual models of a frame have been added to two new operational frames:
58 |       "frametool" on the robot hand and "framebasis" in front of the base.
59 |     '''
60 |
61 | robot = TiagoLoader().robot
62 | geom = robot.visual_model
63 |
64 |     X = pin.utils.rotate('y', np.pi / 2)   # sends the cylinder axis (z) onto x
65 |     Y = pin.utils.rotate('x', -np.pi / 2)  # sends the cylinder axis (z) onto y
66 |     Z = np.eye(3)                          # identity: cylinder axis already on z
67 |
68 |     L = .3
69 |     cyl = hppfcl.Cylinder(L / 30, L)       # thin cylinder used to draw each frame axis
70 |     med = np.array([0, 0, L / 2])          # offset from frame origin to cylinder center
71 |
72 | # ---------------------------------------------------------------------------
73 |     # Add a frame visualisation at the effector.
74 |
75 | FIDX = robot.model.getFrameId('wrist_ft_tool_link')
76 | JIDX = robot.model.frames[FIDX].parent
77 |
78 | eff = np.array([0,0,.08])
79 | FIDX = robot.model.addFrame(pin.Frame('frametool',JIDX,FIDX,pin.SE3(Z,eff),pin.FrameType.OP_FRAME))
80 |
81 | geom.addGeometryObject(pin.GeometryObject('axis_x',FIDX,JIDX,cyl,pin.SE3(X,X@med+eff)))
82 | geom.geometryObjects[-1].meshColor = np.array([1,0,0,1.])
83 |
84 | geom.addGeometryObject(pin.GeometryObject('axis_y',FIDX,JIDX,cyl,pin.SE3(Y,Y@med+eff)))
85 | geom.geometryObjects[-1].meshColor = np.array([0,1,0,1.])
86 |
87 | geom.addGeometryObject(pin.GeometryObject('axis_z',FIDX,JIDX,cyl,pin.SE3(Z,Z@med+eff)))
88 | geom.geometryObjects[-1].meshColor = np.array([0,0,1,1.])
89 |
90 | # ---------------------------------------------------------------------------
91 |     # Add a frame visualisation in front of the base.
92 |
93 | FIDX = robot.model.getFrameId('base_link')
94 | JIDX = robot.model.frames[FIDX].parent
95 |
96 | eff = np.array([.3,0,.15])
97 | FIDX = robot.model.addFrame(pin.Frame('framebasis',JIDX,FIDX,pin.SE3(Z,eff),pin.FrameType.OP_FRAME))
98 |
99 | geom.addGeometryObject(pin.GeometryObject('axis2_x',FIDX,JIDX,cyl,pin.SE3(X,X@med+eff)))
100 | geom.geometryObjects[-1].meshColor = np.array([1,0,0,1.])
101 |
102 |     geom.addGeometryObject(pin.GeometryObject('axis2_y',FIDX,JIDX,cyl,pin.SE3(Y,Y@med+eff)))
103 | geom.geometryObjects[-1].meshColor = np.array([0,1,0,1.])
104 |
105 | geom.addGeometryObject(pin.GeometryObject('axis2_z',FIDX,JIDX,cyl,pin.SE3(Z,Z@med+eff)))
106 | geom.geometryObjects[-1].meshColor = np.array([0,0,1,1.])
107 |
108 | # ---------------------------------------------------------------------------
109 | # Add a frame visualisation in front of the head.
110 |
111 | if addGazeFrame:
112 |         L = .05
113 |         cyl = hppfcl.Cylinder(L / 30, L)
114 |         med = np.array([0, 0, L / 2])
115 |
116 | FIDX = robot.model.getFrameId('xtion_joint')
117 | JIDX = robot.model.frames[FIDX].parent
118 |
119 | eff = np.array([0.4,0.0,0.0])
120 | FIDX = robot.model.addFrame(pin.Frame('framegaze',JIDX,FIDX,pin.SE3(Z,eff),pin.FrameType.OP_FRAME))
121 |
122 | geom.addGeometryObject(pin.GeometryObject('axisgaze_x',FIDX,JIDX,cyl,pin.SE3(X,X@med+eff)))
123 | geom.geometryObjects[-1].meshColor = np.array([1,0,0,1.])
124 |
125 | geom.addGeometryObject(pin.GeometryObject('axisgaze_y',FIDX,JIDX,cyl,pin.SE3(Y,Y@med+eff)))
126 | geom.geometryObjects[-1].meshColor = np.array([0,1,0,1.])
127 |
128 | geom.addGeometryObject(pin.GeometryObject('axisgaze_z',FIDX,JIDX,cyl,pin.SE3(Z,Z@med+eff)))
129 | geom.geometryObjects[-1].meshColor = np.array([0,0,1,1.])
130 |
131 | # -------------------------------------------------------------------------------
132 | # Regenerate the data from the new models.
133 |
134 |     robot.q0 = np.array([1, 1, 1, 0] + [0] * (robot.model.nq - 4))  # planar joint at (x, y, cos, sin) = (1, 1, 1, 0), remaining joints at 0
135 |
136 | robot.data = robot.model.createData()
137 | robot.visual_data = robot.visual_model.createData()
138 |
139 | return robot
140 |
141 | # ------------------------------------------------------------------------------------------------
142 | # ------------------------------------------------------------------------------------------------
143 | # ------------------------------------------------------------------------------------------------
144 |
145 | if __name__ == "__main__":
146 | from utils.meshcat_viewer_wrapper import MeshcatVisualizer
147 |
148 | robot = loadTiago()
149 |     viz = MeshcatVisualizer(robot, url='classical')
150 |
151 | viz.display(robot.q0)
152 |
153 |
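154 |     # The operational frames added by loadTiago can be used like any other frame,
155 |     # e.g. (sketch):
156 |     #   pin.framesForwardKinematics(robot.model, robot.data, robot.q0)
157 |     #   print(robot.data.oMf[robot.model.getFrameId('frametool')])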
--------------------------------------------------------------------------------