├── .gitignore ├── 0_introduction_to_numerical_robotics.ipynb ├── 1_motion_planning.ipynb ├── 2_inverse_kinematics.ipynb ├── 3_reinforcement_learning.ipynb ├── LICENSE ├── README.md ├── robotics-mva.yml └── utils ├── __init__.py ├── collision_wrapper.py ├── datastructures ├── bucketkdtree.py ├── mtree │ ├── __init__.py │ ├── faster.py │ ├── functions.py │ └── heap_queue.py ├── pathtree.py ├── storage.py └── tree.py ├── generate.py ├── load_ur5_parallel.py ├── load_ur5_with_obstacles.py ├── meshcat_viewer_wrapper ├── __init__.py ├── colors.py ├── tests.py ├── transformations.py └── visualizer.py ├── tests.py └── tiago_loader.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.zip 2 | *_solution.ipynb 3 | *_tensorboard 4 | .ipynb_checkpoints 5 | __pycache__ 6 | -------------------------------------------------------------------------------- /0_introduction_to_numerical_robotics.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Introduction to numerical robotics\n", 8 | "\n", 9 | "This notebook is a general introduction to Pinocchio. It shows how to manipulate the geometry model of a robot manipulator: set the configuration, compute the position of the end effector, check for collisions or the distance to an obstacle. The main idea is to give a brief introduction to the general topic: how to discover and learn robot movements constrained by the environment, using iterative optimization methods.\n" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "## Set up\n", 17 | "\n", 18 | "Let us load the UR5 robot model, the Pinocchio library, some optimization functions from SciPy and Matplotlib for plotting:" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "import pinocchio as pin\n", 28 | "from utils.meshcat_viewer_wrapper import MeshcatVisualizer\n", 29 | "import time\n", 30 | "import numpy as np\n", 31 | "from numpy.linalg import inv,norm,pinv,svd,eig\n", 32 | "from scipy.optimize import fmin_bfgs,fmin_slsqp\n", 33 | "from utils.load_ur5_with_obstacles import load_ur5_with_obstacles,Target\n", 34 | "import matplotlib.pylab as plt" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "Let's first load the robot model and display it. For this tutorial, a single utility function will load the robot model and create obstacles around it:" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "robot = load_ur5_with_obstacles(reduced=True)" 51 | ] 52 | }, 53 | { 54 | "cell_type": "markdown", 55 | "metadata": {}, 56 | "source": [ 57 | "The next few lines initialize a 3D viewer."
58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "viz = MeshcatVisualizer(robot)\n", 67 | "viz.display(robot.q0)" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "hasattr(viz.viewer, 'jupyter_cell') and viz.viewer.jupyter_cell()" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": {}, 82 | "source": [ 83 | "The robot and the red obstacles are encoded in the `robot` object (we will not look in depth at what is inside this object). You can display a new configuration of the robot with `viz.display`. It takes a `numpy.array` of dimension 2 as input:" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "viz.display(np.array([3.,-1.5]))" 93 | ] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "metadata": {}, 98 | "source": [ 99 | "We also set up a target which is visualized as a green dot:" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "target_pos = np.array([.5,.5])\n", 109 | "target = Target(viz,position = target_pos)" 110 | ] 111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "metadata": {}, 115 | "source": [ 116 | "The `Target` object is the green dot that the robot should reach. You can change the target position by editing `target.position`, and display the new position with `target.display()`." 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "metadata": {}, 122 | "source": [ 123 | "## Using the robot model\n", 124 | "The robot is originally a 6 degrees-of-freedom (DOF) manipulator. To keep the example simple, we will only use its joints 1 and 2. The model has simply been loaded with \"frozen\" extra joints, which therefore do not appear in this notebook. Reload the model with `reduced=False` if you want to recover a model with full DOF." 125 | ] 126 | }, 127 | { 128 | "cell_type": "markdown", 129 | "metadata": {}, 130 | "source": [ 131 | "The following function computes the position of the end effector (in 2d):" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "def endef(q):\n", 141 | " '''Return the 2d position of the end effector.'''\n", 142 | " pin.framesForwardKinematics(robot.model, robot.data, q)\n", 143 | " return robot.data.oMf[-1].translation[[0, 2]]\n" 144 | ] 145 | }, 146 | { 147 | "cell_type": "markdown", 148 | "metadata": {}, 149 | "source": [ 150 | "This function checks if the robot is in collision, and returns `True` if a collision is detected." 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "def coll(q):\n", 160 | " '''Return True if in collision, False otherwise.'''\n", 161 | " pin.updateGeometryPlacements(robot.model, robot.data, robot.collision_model, robot.collision_data, q)\n", 162 | " return pin.computeCollisions(robot.collision_model, robot.collision_data, False)\n" 163 | ] 164 | }, 165 | { 166 | "cell_type": "markdown", 167 | "metadata": {}, 168 | "source": [ 169 | "The next function computes the distance between the end effector and the target."
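, "\n", "\n", "A minimal sketch of one possible solution (using the `norm` function imported from `numpy.linalg` above):\n", "\n", "```python\n", "def dist(q):\n", "    '''Return the distance between the end effector and the target (2d).'''\n", "    return norm(endef(q) - target_pos)\n", "```"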
170 | ] 171 | }, 172 | { 173 | "cell_type": "markdown", 174 | "metadata": {}, 175 | "source": [ 176 | "Your code:" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": null, 182 | "metadata": {}, 183 | "outputs": [], 184 | "source": [ 185 | "def dist(q):\n", 186 | " '''Return the distance between the end effector and the target (2d).'''\n", 187 | " return 0.\n" 188 | ] 189 | }, 190 | { 191 | "cell_type": "markdown", 192 | "metadata": {}, 193 | "source": [ 194 | "Solution" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "%do_not_load tp0/generated/simple_path_planning_dist" 204 | ] 205 | }, 206 | { 207 | "cell_type": "markdown", 208 | "metadata": {}, 209 | "source": [ 210 | "## Random search of a valid configuration\n", 211 | "The free space is difficult to represent explicitly. We can sample the configuration space until a free configuration is found:" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": null, 217 | "metadata": {}, 218 | "outputs": [], 219 | "source": [ 220 | "def qrand(check=False):\n", 221 | " '''Return a random configuration. If `check` is True, this configuration is not in collision.'''\n", 222 | " pass" 223 | ] 224 | }, 225 | { 226 | "cell_type": "markdown", 227 | "metadata": {}, 228 | "source": [ 229 | "The solution if needed:" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": null, 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "%do_not_load tp0/generated/simple_path_planning_qrand" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": null, 244 | "metadata": {}, 245 | "outputs": [], 246 | "source": [ 247 | "viz.display(qrand(check=True))" 248 | ] 249 | }, 250 | { 251 | "cell_type": "markdown", 252 | "metadata": {}, 253 | "source": [ 254 | "Let's now find a valid configuration that is arbitrarily close to the target: sample until `dist` is small enough and `coll` is `False` (you may want to display the random trials inside the loop)." 255 | ] 256 | }, 257 | { 258 | "cell_type": "markdown", 259 | "metadata": {}, 260 | "source": [ 261 | "## From a random configuration to the target\n", 262 | "Let's now start from a random configuration. How can we find a path that brings the robot toward the target without touching the obstacles? Any ideas?"
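, "\n", "\n", "One simple idea is a random descent, sketched below (one possible approach, reusing the `dist`, `coll` and `qrand` functions defined above): draw a small random step, and keep it only if it decreases the distance to the target without entering in collision.\n", "\n", "```python\n", "dq = qrand() * 0.1  # small random step\n", "if dist(q + dq) < dist(q) and not coll(q + dq):\n", "    q = q + dq  # accept the step, otherwise draw another one\n", "```"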
263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": null, 268 | "metadata": {}, 269 | "outputs": [], 270 | "source": [ 271 | "# Random descent: crawling from one free configuration to the target with random\n", 272 | "# steps.\n", 273 | "def randomDescent(q0 = None):\n", 274 | " '''\n", 275 | " Make a random walk of 0.1-long steps toward the target.\n", 276 | " Return the list of configurations visited.\n", 277 | " '''\n", 278 | " q = qrand(check=True) if q0 is None else q0\n", 279 | " hist = [ q.copy() ]\n", 280 | " # Do the walk\n", 281 | " return hist" 282 | ] 283 | }, 284 | { 285 | "cell_type": "markdown", 286 | "metadata": {}, 287 | "source": [ 288 | "And the solution if needed:" 289 | ] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": null, 294 | "metadata": {}, 295 | "outputs": [], 296 | "source": [ 297 | "%do_not_load tp0/generated/simple_path_planning_random_descent" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": null, 303 | "metadata": {}, 304 | "outputs": [], 305 | "source": [ 306 | "randomDescent()" 307 | ] 308 | }, 309 | { 310 | "cell_type": "markdown", 311 | "metadata": {}, 312 | "source": [ 313 | "## Configuration space\n", 314 | "Let's try to have a better look at the configuration space. In this case, it is easy, as it has dimension 2: we can sample it exhaustively and plot it in 2D. For that, let's introduce another function to compute the distance to collision:" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": 46, 320 | "metadata": {}, 321 | "outputs": [], 322 | "source": [ 323 | "def collisionDistance(q):\n", 324 | " '''Return the minimal distance between robot and environment. '''\n", 325 | " pin.updateGeometryPlacements(robot.model,robot.data,robot.collision_model,robot.collision_data,q)\n", 326 | " if pin.computeCollisions(robot.collision_model,robot.collision_data,False):\n", 327 | " return 0.0\n", 328 | " idx = pin.computeDistances(robot.collision_model,robot.collision_data)\n", 329 | " return robot.collision_data.distanceResults[idx].min_distance" 330 | ] 331 | }, 332 | { 333 | "cell_type": "markdown", 334 | "metadata": {}, 335 | "source": [ 336 | "Now, let's sample the configuration space and plot the distance-to-target and the distance-to-obstacle fields (I put 500 samples to spare your CPU, but you need at least 10x more to obtain a good picture)."
337 | ] 338 | }, 339 | { 340 | "cell_type": "code", 341 | "execution_count": null, 342 | "metadata": {}, 343 | "outputs": [], 344 | "source": [ 345 | "def sampleSpace(nbSamples=500):\n", 346 | " '''\n", 347 | " Sample nbSamples configurations and store them in two lists depending\n", 348 | " on whether the configuration is in free space (hfree) or in collision (hcol), along\n", 349 | " with the distance to the target and the distance to the obstacles.\n", 350 | " '''\n", 351 | " hcol = []\n", 352 | " hfree = []\n", 353 | " for i in range(nbSamples):\n", 354 | " q = qrand(False)\n", 355 | " if not coll(q):\n", 356 | " hfree.append( list(q.flat) + [ dist(q), collisionDistance(q) ])\n", 357 | " else:\n", 358 | " hcol.append( list(q.flat) + [ dist(q), 1e-2 ])\n", 359 | " return hcol,hfree\n", 360 | "\n", 361 | "def plotConfigurationSpace(hcol,hfree,markerSize=20):\n", 362 | " '''\n", 363 | " Plot 2 \"scatter\" plots: the first one plots the distance to the target for \n", 364 | " each configuration, the second plots the distance to the obstacles (axis q1,q2, \n", 365 | " distance in the color space).\n", 366 | " '''\n", 367 | " htotal = hcol + hfree\n", 368 | " h=np.array(htotal)\n", 369 | " plt.subplot(2,1,1)\n", 370 | " plt.scatter(h[:,0],h[:,1],c=h[:,2],s=markerSize,lw=0)\n", 371 | " plt.title(\"Distance to the target\")\n", 372 | " plt.colorbar()\n", 373 | " plt.subplot(2,1,2)\n", 374 | " plt.scatter(h[:,0],h[:,1],c=h[:,3],s=markerSize,lw=0)\n", 375 | " plt.title(\"Distance to the obstacles\")\n", 376 | " plt.colorbar()" 377 | ] 378 | }, 379 | { 380 | "cell_type": "code", 381 | "execution_count": null, 382 | "metadata": {}, 383 | "outputs": [], 384 | "source": [ 385 | "hcol,hfree = sampleSpace(5000)\n", 386 | "plotConfigurationSpace(hcol,hfree)\n" 387 | ] 388 | }, 389 | { 390 | "cell_type": "markdown", 391 | "metadata": {}, 392 | "source": [ 393 | "You can try to match your representation of the free space of the robot with this plot. \n", 394 | "As an example, you can display on this plot a feasible trajectory discovered by random walk from an initial position." 395 | ] 396 | }, 397 | { 398 | "cell_type": "code", 399 | "execution_count": null, 400 | "metadata": {}, 401 | "outputs": [], 402 | "source": [ 403 | "traj = np.array([])\n", 404 | "qinit = np.array([-1.1, -3.
])" 405 | ] 406 | }, 407 | { 408 | "cell_type": "markdown", 409 | "metadata": {}, 410 | "source": [ 411 | "Here is a solution:" 412 | ] 413 | }, 414 | { 415 | "cell_type": "code", 416 | "execution_count": null, 417 | "metadata": {}, 418 | "outputs": [], 419 | "source": [ 420 | "%do_not_load tp0/generated/simple_path_planning_traj" 421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "execution_count": null, 426 | "metadata": {}, 427 | "outputs": [], 428 | "source": [ 429 | "# Add yourr traj to the plot, be carefull !\n", 430 | "plotConfigurationSpace(hcol,hfree)\n", 431 | "plt.plot(traj[:,0],traj[:,1],'r',lw=3)" 432 | ] 433 | }, 434 | { 435 | "cell_type": "markdown", 436 | "metadata": {}, 437 | "source": [ 438 | "## Optimize the distance under non-collision constraint\n", 439 | "Finally, let's use one of the optimizers from SciPy to search for a robot configuration that minimizes the distance to the target, under the constraint that the distance to collision is positive.\n", 440 | "For that, we define a *cost function* $cost: \\mathcal{C} \\to \\mathbb{R}$ (taking the robot configuration and returning a scalar) and a constraint function (taking again the robot configuration and returning a scalar or a vector of scalar that should be positive). We additionally use the \"callback\" functionnality of the solver to render the robot configuration corresponding to the current value of the decision variable inside the solver algorithm.\n", 441 | "We use the \"SLSQP\" solver from SciPy, which implements a \"sequential quadratic program\" algorithm and accepts constraints.\n" 442 | ] 443 | }, 444 | { 445 | "cell_type": "code", 446 | "execution_count": null, 447 | "metadata": {}, 448 | "outputs": [], 449 | "source": [ 450 | "def cost(q):\n", 451 | " pass\n", 452 | " \n", 453 | "def constraint(q):\n", 454 | " pass\n", 455 | " \n", 456 | "def callback(q):\n", 457 | " '''\n", 458 | " At each optimization step, display the robot configuration in gepetto-viewer.\n", 459 | " '''\n", 460 | " viz.display(q)\n", 461 | " time.sleep(.01)\n", 462 | "\n", 463 | "def optimize():\n", 464 | " '''\n", 465 | " Optimize from an initial random configuration to discover a collision-free\n", 466 | " configuration as close as possible to the target.\n", 467 | " USE fmin_slsqp, see doc online\n", 468 | " '''" 469 | ] 470 | }, 471 | { 472 | "cell_type": "markdown", 473 | "metadata": {}, 474 | "source": [ 475 | "Here is a valid solution:" 476 | ] 477 | }, 478 | { 479 | "cell_type": "code", 480 | "execution_count": null, 481 | "metadata": {}, 482 | "outputs": [], 483 | "source": [ 484 | "%do_not_load tp0/generated/simple_path_planning_optim" 485 | ] 486 | }, 487 | { 488 | "cell_type": "markdown", 489 | "metadata": {}, 490 | "source": [ 491 | "Look at the output of the solver. It always returns a variable value, but sometimes the algorithm fails being traped in an unfeasible region. Most of the time, the solver converges to a local minimum where the final distance to the target is nonzero" 492 | ] 493 | }, 494 | { 495 | "cell_type": "markdown", 496 | "metadata": {}, 497 | "source": [ 498 | "Now you can write a planner that try to optimize and retry until a valid solition is found!" 
499 | ] 500 | }, 501 | { 502 | "cell_type": "code", 503 | "execution_count": null, 504 | "metadata": {}, 505 | "outputs": [], 506 | "source": [ 507 | "# Your solution" 508 | ] 509 | }, 510 | { 511 | "cell_type": "markdown", 512 | "metadata": {}, 513 | "source": [ 514 | "And the solution if you need it:" 515 | ] 516 | }, 517 | { 518 | "cell_type": "code", 519 | "execution_count": null, 520 | "metadata": {}, 521 | "outputs": [], 522 | "source": [ 523 | "%do_not_load tp0/generated/simple_path_planning_useit" 524 | ] 525 | }, 526 | { 527 | "cell_type": "code", 528 | "execution_count": null, 529 | "metadata": {}, 530 | "outputs": [], 531 | "source": [] 532 | } 533 | ], 534 | "metadata": { 535 | "kernelspec": { 536 | "display_name": "Python 3 (ipykernel)", 537 | "language": "python", 538 | "name": "python3" 539 | }, 540 | "language_info": { 541 | "codemirror_mode": { 542 | "name": "ipython", 543 | "version": 3 544 | }, 545 | "file_extension": ".py", 546 | "mimetype": "text/x-python", 547 | "name": "python", 548 | "nbconvert_exporter": "python", 549 | "pygments_lexer": "ipython3", 550 | "version": "3.10.12" 551 | } 552 | }, 553 | "nbformat": 4, 554 | "nbformat_minor": 4 555 | } 556 | -------------------------------------------------------------------------------- /1_motion_planning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "66575b61-30a5-4471-9e4e-a45c6fc396a9", 6 | "metadata": {}, 7 | "source": [ 8 | "# Implement RRT and its variant on UR5" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "b5d79912-1a64-4466-8017-70724567b28c", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import example_robot_data as robex\n", 19 | "import hppfcl\n", 20 | "import math\n", 21 | "import numpy as np\n", 22 | "import pinocchio as pin\n", 23 | "import time\n", 24 | "from tqdm import tqdm" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "id": "8abf514d-0c36-4c32-934f-e1d013c57ea7", 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "import matplotlib.pylab as plt; plt.ion()" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "id": "bf038376-9d69-48e9-9bdb-0bcf3b641247", 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "from utils.meshcat_viewer_wrapper import MeshcatVisualizer, colors\n", 45 | "from utils.datastructures.storage import Storage\n", 46 | "from utils.datastructures.pathtree import PathTree\n", 47 | "from utils.datastructures.mtree import MTree\n", 48 | "from utils.collision_wrapper import CollisionWrapper" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "id": "58561974-a3d2-4692-821c-30f4d0caf96f", 54 | "metadata": {}, 55 | "source": [ 56 | "## Load UR5" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "id": "7217ea2c-5cac-430e-8496-93e80e78b859", 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "robot = robex.load('ur5')\n", 67 | "collision_model = robot.collision_model\n", 68 | "visual_model = robot.visual_model" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "id": "98805b5f-f993-4d4a-98ca-c465ff363424", 74 | "metadata": {}, 75 | "source": [ 76 | "Recall some placement for the UR5" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "id": "9f521701-7416-4736-9363-305a45258d14", 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "a = 
robot.placement(robot.q0, 6) # Placement of the end effector joint.\n", 87 | "b = robot.framePlacement(robot.q0, 22) # Placement of the end effector tip.\n", 88 | "\n", 89 | "tool_axis = b.rotation[:, 2] # Axis of the tool\n", 90 | "tool_position = b.translation" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "id": "d96626e3-1b22-446f-916a-f7ee54b78f04", 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "viz = MeshcatVisualizer(robot)" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "id": "e2d1d4ee-c1f4-4350-9370-ed8f32dee4cf", 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "viz.viewer.jupyter_cell()" 111 | ] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "id": "8ede57d4-8b0d-4a69-adbd-95901db503cd", 116 | "metadata": {}, 117 | "source": [ 118 | "Set a start and a goal configuration" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "id": "4b14d4cd-9e44-4e14-9a56-ec08256d0238", 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "q_i = np.array([1., -1.5, 2.1, -.5, -.5, 0])\n", 129 | "q_g = np.array([3., -1., 1, -.5, -.5, 0])\n", 130 | "radius = 0.05" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": null, 136 | "id": "a05303fd-5fa9-47bf-92cd-63d64e2212cf", 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "viz.display(q_i)\n", 141 | "M = robot.framePlacement(q_i, 22)\n", 142 | "name = \"world/sph_initial\"\n", 143 | "viz.addSphere(name, radius, [0., 1., 0., 1.])\n", 144 | "viz.applyConfiguration(name,M)" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "id": "dbc882d4-b335-4255-9db3-3aa61eeb8ee7", 151 | "metadata": { 152 | "tags": [] 153 | }, 154 | "outputs": [], 155 | "source": [ 156 | "viz.display(q_g)\n", 157 | "M = robot.framePlacement(q_g, 22)\n", 158 | "name = \"world/sph_goal\"\n", 159 | "viz.addSphere(name, radius, [0., 0., 1., 1.])\n", 160 | "viz.applyConfiguration(name,M)" 161 | ] 162 | }, 163 | { 164 | "cell_type": "markdown", 165 | "id": "6a4bce9b-0c81-4c7a-bee9-02c9aa59b829", 166 | "metadata": {}, 167 | "source": [ 168 | "## Implement everything needed for RRT" 169 | ] 170 | }, 171 | { 172 | "cell_type": "markdown", 173 | "id": "9d4c45e0-ee85-47b1-88ae-94c3c8fb50d8", 174 | "metadata": {}, 175 | "source": [ 176 | "We abstract the robot, the environment and their behaviour in a class called `System`.\n", 177 | "\n", 178 | "It must be able to:\n", 179 | "- generate random configurations which are not colliding, if needed (sampling)\n", 180 | "- implement a distance on the configuration space (distance)\n", 181 | "- generate a path between two configurations (steering)\n", 182 | "- check if a path is free between two configurations and return the last free configuration (directional free steering)\n", 183 | "along with some functions to display configurations.\n", 184 | "\n", 185 | "Recall that in the case of the UR5 the configuration space is $S_1^{6}$, where $S_1$ is the unit circle, which we can parametrize by $\\theta\\in[-\\pi,\\pi]$ such that $-\\pi$ and $\\pi$ are identified.\n", 186 | "\n", 187 | "In the next cell, you must implement the system behaviour for the UR5."
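, "\n", "\n", "For instance, on $S_1$ the distance between $\\theta_1 = 3$ and $\\theta_2 = -3$ is not $6$ but $2\\pi - 6 \\approx 0.28$, since the shortest arc goes through $\\pm\\pi$. A quick sanity check of this wrap-around logic with NumPy:\n", "\n", "```python\n", "e = np.mod(np.abs(3.0 - (-3.0)), 2 * np.pi)  # 6.0\n", "e = 2 * np.pi - e if e > np.pi else e  # ~0.28, going through +/- pi\n", "```"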
188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": null, 193 | "id": "b74e5455-eae8-49fe-9030-f403467df288", 194 | "metadata": {}, 195 | "outputs": [], 196 | "source": [ 197 | "class System():\n", 198 | "\n", 199 | " def __init__(self, robot):\n", 200 | " self.robot = robot\n", 201 | " robot.gmodel = robot.collision_model\n", 202 | " self.display_edge_count = 0\n", 203 | " self.colwrap = CollisionWrapper(robot) # For collision checking\n", 204 | " self.nq = self.robot.nq\n", 205 | " self.display_count = 0\n", 206 | " \n", 207 | " def distance(self, q1, q2):\n", 208 | " \"\"\"\n", 209 | " Must return a distance between q1 and q2, which can be a batch of configurations.\n", 210 | " \"\"\"\n", 211 | " if len(q2.shape) > len(q1.shape):\n", 212 | " q1 = q1[None, ...]\n", 213 | " e = np.mod(np.abs(q1 - q2), 2 * np.pi)\n", 214 | " e[e > np.pi] = 2 * np.pi - e[e > np.pi]\n", 215 | " return np.linalg.norm(e, axis=-1)\n", 216 | "\n", 217 | " def random_config(self, free=True):\n", 218 | " \"\"\"\n", 219 | " Must return a random configuration which is not in collision if free=True\n", 220 | " \"\"\"\n", 221 | " q = 2 * np.pi * np.random.rand(6) - np.pi\n", 222 | " if not free:\n", 223 | " return q\n", 224 | " while self.is_colliding(q):\n", 225 | " q = 2 * np.pi * np.random.rand(6) - np.pi\n", 226 | " return q\n", 227 | "\n", 228 | " def is_colliding(self, q):\n", 229 | " \"\"\"\n", 230 | " Use CollisionWrapper to decide if a configuration is in collision\n", 231 | " \"\"\"\n", 232 | " self.colwrap.computeCollisions(q)\n", 233 | " collisions = self.colwrap.getCollisionList()\n", 234 | " return (len(collisions) > 0)\n", 235 | "\n", 236 | " def get_path(self, q1, q2, l_min=None, l_max=None, eps=0.2):\n", 237 | " \"\"\"\n", 238 | " Generate a continuous path with precision eps between q1 and q2.\n", 239 | " If l_min or l_max is given, extrapolate or cut the path so that\n", 240 | " its length is clipped to [l_min, l_max].\n", 241 | " \"\"\"\n", 242 | " q1 = np.mod(q1 + np.pi, 2 * np.pi) - np.pi\n", 243 | " q2 = np.mod(q2 + np.pi, 2 * np.pi) - np.pi\n", 244 | "\n", 245 | " diff = q2 - q1\n", 246 | " query = np.abs(diff) > np.pi\n", 247 | " q2[query] = q2[query] - np.sign(diff[query]) * 2 * np.pi\n", 248 | "\n", 249 | " d = self.distance(q1, q2)\n", 250 | " if d < eps:\n", 251 | " return np.stack([q1, q2], axis=0)\n", 252 | " \n", 253 | " if l_min is not None or l_max is not None:\n", 254 | " new_d = np.clip(d, l_min, l_max)\n", 255 | " else:\n", 256 | " new_d = d\n", 257 | " \n", 258 | " N = int(new_d / eps + 2)\n", 259 | "\n", 260 | " return np.linspace(q1, q1 + (q2 - q1) * new_d / d, N)\n", 261 | " \n", 262 | " def is_free_path(self, q1, q2, l_min=0.2, l_max=1., eps=0.2):\n", 263 | " \"\"\"\n", 264 | " Create a path and check collision to return the last\n", 265 | " non-colliding configuration. 
Return (X, q) where X is a boolean\n", 266 | " stating whether the steering has worked.\n", 267 | " At least l_min must be covered without collision to validate the path.\n", 268 | " \"\"\"\n", 269 | " q_path = self.get_path(q1, q2, l_min, l_max, eps)\n", 270 | " N = len(q_path)\n", 271 | " N_min = N - 1 if l_min is None else min(N - 1, int(l_min / eps))\n", 272 | " for i in range(N):\n", 273 | " if self.is_colliding(q_path[i]):\n", 274 | " break\n", 275 | " if i < N_min:\n", 276 | " return False, None\n", 277 | " if i == N - 1:\n", 278 | " return True, q_path[-1]\n", 279 | " return True, q_path[i - 1]\n", 280 | "\n", 281 | " def reset(self):\n", 282 | " \"\"\"\n", 283 | " Reset the system visualization\n", 284 | " \"\"\"\n", 285 | " for i in range(self.display_count):\n", 286 | " viz.delete(f\"world/sph{i}\")\n", 287 | " viz.delete(f\"world/cil{i}\")\n", 288 | " self.display_count = 0\n", 289 | " \n", 290 | " def display_edge(self, q1, q2, radius=0.01, color=[1.,0.,0.,1]):\n", 291 | " M1 = self.robot.framePlacement(q1, 22) # Placement of the end effector tip.\n", 292 | " M2 = self.robot.framePlacement(q2, 22) # Placement of the end effector tip.\n", 293 | " middle = .5 * (M1.translation + M2.translation)\n", 294 | " direction = M2.translation - M1.translation\n", 295 | " length = np.linalg.norm(direction)\n", 296 | " dire = direction / length\n", 297 | " orth = np.cross(dire, np.array([0, 0, 1]))\n", 298 | " orth2 = np.cross(dire, orth)\n", 299 | " Mcyl = pin.SE3(np.stack([orth2, dire, orth], axis=1), middle)\n", 300 | " name = f\"world/sph{self.display_count}\"\n", 301 | " viz.addSphere(name, radius, [1.,0.,0.,1])\n", 302 | " viz.applyConfiguration(name,M2)\n", 303 | " name = f\"world/cil{self.display_count}\"\n", 304 | " viz.addCylinder(name, length, radius / 4, [0., 1., 0., 1])\n", 305 | " viz.applyConfiguration(name,Mcyl)\n", 306 | " self.display_count +=1\n", 307 | " \n", 308 | " def display_motion(self, qs, step=1e-1):\n", 309 | " # Given a list of waypoints, display the smooth motion between them\n", 310 | " for i in range(len(qs) - 1):\n", 311 | " for q in self.get_path(qs[i], qs[i+1])[:-1]:\n", 312 | " viz.display(q)\n", 313 | " time.sleep(step)\n", 314 | " viz.display(qs[-1])\n" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": null, 320 | "id": "51178d60-0474-4d57-9336-d8f1617b8a4a", 321 | "metadata": {}, 322 | "outputs": [], 323 | "source": [ 324 | "system = System(robot)" 325 | ] 326 | }, 327 | { 328 | "cell_type": "code", 329 | "execution_count": null, 330 | "id": "4b0ced1c-2211-459a-aeb3-fb9c3b25085d", 331 | "metadata": {}, 332 | "outputs": [], 333 | "source": [ 334 | "system.distance(q_i, q_g)" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": null, 340 | "id": "9fafd71f-87f0-44b3-a92c-48f53b6580ac", 341 | "metadata": {}, 342 | "outputs": [], 343 | "source": [ 344 | "system.display_motion(system.get_path(q_i, q_g))" 345 | ] 346 | }, 347 | { 348 | "cell_type": "markdown", 349 | "id": "914d0701-bc8f-46c0-8888-8e019f15cfc3", 350 | "metadata": {}, 351 | "source": [ 352 | "## RRT implementation" 353 | ] 354 | }, 355 | { 356 | "cell_type": "markdown", 357 | "id": "3a64cb4c-28b0-4e66-9324-1cba47374186", 358 | "metadata": {}, 359 | "source": [ 360 | "In its simplest form, RRT constructs a tree from the start, possibly with a bias toward the goal. In the following class, we add some memoization to avoid recomputing distances. The kNN (k Nearest Neighbors) structure works on node indices."
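, "\n", "\n", "For instance, with the `MTree` structure from `utils.datastructures` used below, a sketch of the intended usage is:\n", "\n", "```python\n", "tree = MTree(distance_fn)  # distance_fn maps a pair of node indices to a float\n", "tree.add_point(node_idx)  # insert a node, identified by its index in the storage\n", "near_idx, d = tree.nearest_neighbour(query_idx)  # closest stored node and its distance\n", "```"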
361 | ] 362 | }, 363 | { 364 | "cell_type": "markdown", 365 | "id": "3b1a4b31-07d2-463f-851e-cd4d2ca09bbf", 366 | "metadata": {}, 367 | "source": [ 368 | "Let us look at an implementation of the core algorithm:" 369 | ] 370 | }, 371 | { 372 | "cell_type": "code", 373 | "execution_count": null, 374 | "id": "a62a504d-9ac1-4bcc-a149-cd6516f90a0e", 375 | "metadata": {}, 376 | "outputs": [], 377 | "source": [ 378 | "class RRT():\n", 379 | " \"\"\"\n", 380 | " Could be split into an RRT base class, since different\n", 381 | " RRT variants share common logic.\n", 382 | " \"\"\"\n", 383 | " def __init__(\n", 384 | " self,\n", 385 | " system,\n", 386 | " node_max=500000,\n", 387 | " iter_max=1000000,\n", 388 | " N_bias=10,\n", 389 | " l_min=.2,\n", 390 | " l_max=.5,\n", 391 | " steer_delta=.1,\n", 392 | " ):\n", 393 | " \"\"\"\n", 394 | " [Here, in proper code, we would document the parameters of our function. Do that below,\n", 395 | " using the Google style for docstrings.]\n", 396 | " https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html\n", 397 | "\n", 398 | " Args:\n", 399 | " node_max: ...\n", 400 | " iter_max: ...\n", 401 | " ...\n", 402 | " \"\"\"\n", 403 | " self.system = system\n", 404 | " # params\n", 405 | " self.l_max = l_max\n", 406 | " self.l_min = l_min\n", 407 | " self.N_bias = N_bias\n", 408 | " self.node_max = node_max\n", 409 | " self.iter_max = iter_max\n", 410 | " self.steer_delta = steer_delta\n", 411 | " # internal\n", 412 | " self.NNtree = None\n", 413 | " self.storage = None\n", 414 | " self.pathtree = None\n", 415 | " # The distance function will be called on (N, dim) objects\n", 416 | " self.real_distance = self.system.distance\n", 417 | " # Internal caches to speed up distance computations\n", 418 | " self._candidate = None\n", 419 | " self._goal = None\n", 420 | " self._cached_dist_to_candidate = {}\n", 421 | " self._cached_dist_to_goal = {}\n", 422 | "\n", 423 | " def distance(self, q1_idx, q2_idx):\n", 424 | " if isinstance(q2_idx, int):\n", 425 | " if q1_idx == q2_idx:\n", 426 | " return 0.\n", 427 | " if q1_idx == -1 or q2_idx == -1:\n", 428 | " if q2_idx == -1:\n", 429 | " q1_idx, q2_idx = q2_idx, q1_idx\n", 430 | " if q2_idx not in self._cached_dist_to_candidate:\n", 431 | " self._cached_dist_to_candidate[q2_idx] = self.real_distance(\n", 432 | " self._candidate, self.storage[q2_idx]\n", 433 | " )\n", 434 | " return self._cached_dist_to_candidate[q2_idx]\n", 435 | " if q1_idx == -2 or q2_idx == -2:\n", 436 | " if q2_idx == -2:\n", 437 | " q1_idx, q2_idx = q2_idx, q1_idx\n", 438 | " if q2_idx not in self._cached_dist_to_goal:\n", 439 | " self._cached_dist_to_goal[q2_idx] = self.real_distance(\n", 440 | " self._goal, self.storage[q2_idx]\n", 441 | " )\n", 442 | " return self._cached_dist_to_goal[q2_idx]\n", 443 | " return self.real_distance(self.storage[q1_idx], self.storage[q2_idx])\n", 444 | " if q1_idx == -1:\n", 445 | " q = self._candidate\n", 446 | " elif q1_idx == -2:\n", 447 | " q = self._goal\n", 448 | " else:\n", 449 | " q = self.storage[q1_idx]\n", 450 | " return self.real_distance(q, self.storage[q2_idx])\n", 451 | "\n", 452 | " def new_candidate(self):\n", 453 | " q = self.system.random_config(free=True)\n", 454 | " self._candidate = q\n", 455 | " self._cached_dist_to_candidate = {}\n", 456 | " return q\n", 457 | "\n", 458 | " def solve(self, qi, validate, qg=None):\n", 459 | " self.system.reset()\n", 460 | " self._goal = qg\n", 461 | " \n", 462 | " # Reset internal datastructures\n", 463 | " self.storage = Storage(self.node_max,
self.system.nq)\n", 464 | " self.pathtree = PathTree(self.storage)\n", 465 | " self.NNtree = MTree(self.distance)\n", 466 | " qi_idx = self.storage.add_point(qi)\n", 467 | " self.NNtree.add_point(qi_idx)\n", 468 | " self.it_trace = []\n", 469 | "\n", 470 | " found = False\n", 471 | " iterator = range(self.iter_max)\n", 472 | " for i in tqdm(iterator):\n", 473 | " # New candidate\n", 474 | " if i % self.N_bias == 0:\n", 475 | " q_new = self._goal\n", 476 | " q_new_idx = -2\n", 477 | " else:\n", 478 | " q_new = self.new_candidate()\n", 479 | " q_new_idx = -1\n", 480 | "\n", 481 | " # Find the closest neighbour to q_new\n", 482 | " q_near_idx, d = self.NNtree.nearest_neighbour(q_new_idx)\n", 483 | " \n", 484 | " # Steer from it toward the new candidate, checking for collision\n", 485 | " success, q_prox = self.system.is_free_path(\n", 486 | " self.storage.data[q_near_idx],\n", 487 | " q_new,\n", 488 | " l_min=self.l_min,\n", 489 | " l_max=self.l_max,\n", 490 | " eps=self.steer_delta\n", 491 | " )\n", 492 | "\n", 493 | " if not success:\n", 494 | " self.it_trace.append(0)\n", 495 | " continue\n", 496 | " self.it_trace.append(1)\n", 497 | " \n", 498 | " # Add the points in data structures\n", 499 | " q_prox_idx = self.storage.add_point(q_prox)\n", 500 | " self.NNtree.add_point(q_prox_idx)\n", 501 | " self.pathtree.update_link(q_prox_idx, q_near_idx)\n", 502 | " self.system.display_edge(self.storage[q_near_idx], self.storage[q_prox_idx])\n", 503 | "\n", 504 | " # Test if it reaches the goal\n", 505 | " if validate(q_prox):\n", 506 | " q_g_idx = self.storage.add_point(q_prox)\n", 507 | " self.NNtree.add_point(q_g_idx)\n", 508 | " self.pathtree.update_link(q_g_idx, q_prox_idx)\n", 509 | " found = True\n", 510 | " break\n", 511 | " self.iter_done = i + 1\n", 512 | " self.found = found\n", 513 | " return found\n", 514 | "\n", 515 | " def get_path(self, q_g):\n", 516 | " assert self.found\n", 517 | " path = self.pathtree.get_path()\n", 518 | " return np.concatenate([path, q_g[None, :]])\n" 519 | ] 520 | }, 521 | { 522 | "cell_type": "markdown", 523 | "id": "73533b79-9692-4856-8094-b4874e3944ef", 524 | "metadata": {}, 525 | "source": [ 526 | "In proper code, we would document the parameters of our functions.\n", 527 | "\n", 528 | "- **Your turn:** Add docstrings to the code above, following the [Google style](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html).\n", 529 | "- Optional: you are welcome to add type annotations if you'd like.\n", 530 | "\n", 531 | "The constructor of the `RRT` class invites you to start."
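, "\n", "\n", "As a reminder, a Google-style docstring looks like the sketch below (the descriptions are illustrative; refine them as you document the actual semantics):\n", "\n", "```python\n", "def distance(self, q1_idx, q2_idx):\n", "    \"\"\"Compute the distance between two nodes identified by their indices.\n", "\n", "    Args:\n", "        q1_idx: Index of the first node (-1 denotes the candidate, -2 the goal).\n", "        q2_idx: Index of the second node, or a batch of indices.\n", "\n", "    Returns:\n", "        The distance between the two configurations.\n", "    \"\"\"\n", "```"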
532 | ] 533 | }, 534 | { 535 | "cell_type": "markdown", 536 | "id": "811cc5ae-b601-477d-9d55-e560d0e45262", 537 | "metadata": {}, 538 | "source": [ 539 | "For this problem, we will instantiate our RRT with the following parameters:" 540 | ] 541 | }, 542 | { 543 | "cell_type": "code", 544 | "execution_count": null, 545 | "id": "676b991b-32d1-4e2f-8dd9-ed85252e180c", 546 | "metadata": {}, 547 | "outputs": [], 548 | "source": [ 549 | "rrt = RRT(\n", 550 | " system,\n", 551 | " N_bias=20,\n", 552 | " l_min=0.2,\n", 553 | " l_max=0.5,\n", 554 | " steer_delta=0.1,\n", 555 | ")" 556 | ] 557 | }, 558 | { 559 | "cell_type": "markdown", 560 | "id": "f20ecc74-a2b6-4b91-8131-cbb7ff3e13a9", 561 | "metadata": {}, 562 | "source": [ 563 | "Now let's define our termination condition, and run the main function:" 564 | ] 565 | }, 566 | { 567 | "cell_type": "code", 568 | "execution_count": null, 569 | "id": "de20f3cc-6df3-4e5c-ab5b-e183616d4a2e", 570 | "metadata": {}, 571 | "outputs": [], 572 | "source": [ 573 | "eps_final = .1\n", 574 | "def validation(key):\n", 575 | " vec = robot.framePlacement(key, 22).translation - robot.framePlacement(q_g, 22).translation\n", 576 | " return (float(np.linalg.norm(vec)) < eps_final)\n", 577 | "\n", 578 | "rrt.solve(q_i, validation, qg=q_g)" 579 | ] 580 | }, 581 | { 582 | "cell_type": "code", 583 | "execution_count": null, 584 | "id": "7f3a6a8d-4fd9-419b-8176-214543c08a22", 585 | "metadata": {}, 586 | "outputs": [], 587 | "source": [ 588 | "system.display_motion(rrt.get_path(q_g))" 589 | ] 590 | }, 591 | { 592 | "cell_type": "code", 593 | "execution_count": null, 594 | "id": "d82012a2-c7c4-4362-85ae-2c8e67a2e2a4", 595 | "metadata": {}, 596 | "outputs": [], 597 | "source": [ 598 | "system.reset()" 599 | ] 600 | }, 601 | { 602 | "cell_type": "markdown", 603 | "id": "3bd7d92a-abea-482b-8a91-a5942a125585", 604 | "metadata": {}, 605 | "source": [ 606 | "## Create an environment with obstacles" 607 | ] 608 | }, 609 | { 610 | "cell_type": "markdown", 611 | "id": "8784c9ad-ab66-4440-9a0f-1a1e10ab4b2d", 612 | "metadata": {}, 613 | "source": [ 614 | "So far, our algorithms only had to find paths in a world free of obstacles. 
Let us now add some obstacles to the environment:" 615 | ] 616 | }, 617 | { 618 | "cell_type": "code", 619 | "execution_count": null, 620 | "id": "5efbc0a8-bd55-4122-b09a-168b64f8de19", 621 | "metadata": {}, 622 | "outputs": [], 623 | "source": [ 624 | "robot = robex.load('ur5')\n", 625 | "collision_model = robot.collision_model\n", 626 | "visual_model = robot.visual_model" 627 | ] 628 | }, 629 | { 630 | "cell_type": "code", 631 | "execution_count": null, 632 | "id": "189570f8-d592-4f6d-b3d7-89a466fdf898", 633 | "metadata": {}, 634 | "outputs": [], 635 | "source": [ 636 | "def addCylinderToUniverse(name, radius, length, placement, color=colors.red):\n", 637 | " geom = pin.GeometryObject(\n", 638 | " name,\n", 639 | " 0,\n", 640 | " hppfcl.Cylinder(radius, length),\n", 641 | " placement\n", 642 | " )\n", 643 | " new_id = collision_model.addGeometryObject(geom)\n", 644 | " geom.meshColor = np.array(color)\n", 645 | " visual_model.addGeometryObject(geom)\n", 646 | " \n", 647 | " for link_id in range(robot.model.nq):\n", 648 | " collision_model.addCollisionPair(\n", 649 | " pin.CollisionPair(link_id, new_id)\n", 650 | " )\n", 651 | " return geom" 652 | ] 653 | }, 654 | { 655 | "cell_type": "code", 656 | "execution_count": null, 657 | "id": "4c0b531f-992c-47cc-b101-d1113a2f0870", 658 | "metadata": {}, 659 | "outputs": [], 660 | "source": [ 661 | "from pinocchio.utils import rotate\n", 662 | "\n", 663 | "[collision_model.removeGeometryObject(e.name) for e in collision_model.geometryObjects if e.name.startswith('world/')]\n", 664 | "\n", 665 | "# Add four cylinders to the environment\n", 666 | "radius = 0.1\n", 667 | "length = 1.\n", 668 | "\n", 669 | "cylID = \"world/cyl1\"\n", 670 | "placement = pin.SE3(rotate('z',np.pi/2), np.array([-0.5,0.4,0.5]))\n", 671 | "addCylinderToUniverse(cylID,radius,length,placement,color=[.7,.7,0.98,1])\n", 672 | "\n", 673 | "\n", 674 | "cylID = \"world/cyl2\"\n", 675 | "placement = pin.SE3(rotate('z',np.pi/2), np.array([-0.5,-0.4,0.5]))\n", 676 | "addCylinderToUniverse(cylID,radius,length,placement,color=[.7,.7,0.98,1])\n", 677 | "\n", 678 | "cylID = \"world/cyl3\"\n", 679 | "placement = pin.SE3(rotate('z',np.pi/2), np.array([-0.5,0.7,0.5]))\n", 680 | "addCylinderToUniverse(cylID,radius,length,placement,color=[.7,.7,0.98,1])\n", 681 | "\n", 682 | "\n", 683 | "cylID = \"world/cyl4\"\n", 684 | "placement = pin.SE3(rotate('z',np.pi/2), np.array([-0.5,-0.7,0.5]))\n", 685 | "addCylinderToUniverse(cylID,radius,length,placement,color=[.7,.7,0.98,1])" 686 | ] 687 | }, 688 | { 689 | "cell_type": "code", 690 | "execution_count": null, 691 | "id": "6780c139-e2de-4c9f-8af4-61cb1e7e1b4e", 692 | "metadata": {}, 693 | "outputs": [], 694 | "source": [ 695 | "q_i = np.array([-1., -1.5, 2.1, -.5, -.5, 0])\n", 696 | "q_g = np.array([3.1, -1., 1, -.5, -.5, 0])\n", 697 | "radius = 0.05" 698 | ] 699 | }, 700 | { 701 | "cell_type": "markdown", 702 | "id": "c114f554-6126-47e9-8749-934d47d2c7c1", 703 | "metadata": {}, 704 | "source": [ 705 | "We need to reload the viewer" 706 | ] 707 | }, 708 | { 709 | "cell_type": "code", 710 | "execution_count": null, 711 | "id": "62ab0ec0-5faf-43fc-bcd8-9ad9cbe3ad56", 712 | "metadata": {}, 713 | "outputs": [], 714 | "source": [ 715 | "viz = MeshcatVisualizer(robot)" 716 | ] 717 | }, 718 | { 719 | "cell_type": "code", 720 | "execution_count": null, 721 | "id": "037ac2f5-e0d3-438a-ab10-9f313a8f8803", 722 | "metadata": {}, 723 | "outputs": [], 724 | "source": [ 725 | "viz.display(q_i)\n", 726 | "M = robot.framePlacement(q_i, 
22)\n", 727 | "name = \"world/sph_initial\"\n", 728 | "viz.addSphere(name, radius, [0., 1., 0., 1.])\n", 729 | "viz.applyConfiguration(name,M)" 730 | ] 731 | }, 732 | { 733 | "cell_type": "code", 734 | "execution_count": null, 735 | "id": "5e5dfe04-e29d-4455-ba00-c7c158a8930c", 736 | "metadata": { 737 | "tags": [] 738 | }, 739 | "outputs": [], 740 | "source": [ 741 | "viz.display(q_g)\n", 742 | "M = robot.framePlacement(q_g, 22)\n", 743 | "name = \"world/sph_goal\"\n", 744 | "viz.addSphere(name, radius, [0., 0., 1., 1.])\n", 745 | "viz.applyConfiguration(name,M)" 746 | ] 747 | }, 748 | { 749 | "cell_type": "code", 750 | "execution_count": null, 751 | "id": "6f17610e-e874-4e4f-aaf7-6c6d569f965e", 752 | "metadata": {}, 753 | "outputs": [], 754 | "source": [ 755 | "viz.display(q_g)" 756 | ] 757 | }, 758 | { 759 | "cell_type": "code", 760 | "execution_count": null, 761 | "id": "e19068e5-033b-412d-8d64-ffef4def949a", 762 | "metadata": {}, 763 | "outputs": [], 764 | "source": [ 765 | "system = System(robot)" 766 | ] 767 | }, 768 | { 769 | "cell_type": "code", 770 | "execution_count": null, 771 | "id": "661a6589-36ff-44c5-b761-c6f588a8eab7", 772 | "metadata": {}, 773 | "outputs": [], 774 | "source": [ 775 | "rrt = RRT(\n", 776 | " system,\n", 777 | " N_bias=20,\n", 778 | " l_min=0.2,\n", 779 | " l_max=0.5,\n", 780 | " steer_delta=0.1,\n", 781 | ")" 782 | ] 783 | }, 784 | { 785 | "cell_type": "code", 786 | "execution_count": null, 787 | "id": "ffa971ee-6396-4533-b875-5c4e7124bc7a", 788 | "metadata": {}, 789 | "outputs": [], 790 | "source": [ 791 | "eps_final = .1\n", 792 | "\n", 793 | "def validation(key):\n", 794 | " vec = robot.framePlacement(key, 22).translation - robot.framePlacement(q_g, 22).translation\n", 795 | " return (float(np.linalg.norm(vec)) < eps_final)\n", 796 | "\n", 797 | "rrt.solve(q_i, validation, qg=q_g)" 798 | ] 799 | }, 800 | { 801 | "cell_type": "code", 802 | "execution_count": null, 803 | "id": "cf22611e-ba2d-4c0e-93a8-03ff922b6073", 804 | "metadata": {}, 805 | "outputs": [], 806 | "source": [ 807 | "system.display_motion(rrt.get_path(q_g))" 808 | ] 809 | }, 810 | { 811 | "cell_type": "markdown", 812 | "id": "160cf161-755f-430e-8e91-e926f644fbe2", 813 | "metadata": {}, 814 | "source": [ 815 | "And solve with RRT again. It takes quite long, right? Let us implement a more efficient algorithm." 816 | ] 817 | }, 818 | { 819 | "cell_type": "markdown", 820 | "id": "c90f1487-cd74-4eb8-8f3f-a0f5e65a1aad", 821 | "metadata": {}, 822 | "source": [ 823 | "## Bi-RRT" 824 | ] 825 | }, 826 | { 827 | "cell_type": "markdown", 828 | "id": "d526af5b-9c37-425f-a30c-45bb80588404", 829 | "metadata": {}, 830 | "source": [ 831 | "Now it's your turn. Make a `BiRRT` class, similar to the `RRT` class above, but implementing the Bi-RRT algorithm. (It is not recommended to try to inherit from `RRT`, as you will end up re-implementing most functions.) 
Here is a template you are free to adapt, with some advice:" 832 | ] 833 | }, 834 | { 835 | "cell_type": "code", 836 | "execution_count": null, 837 | "id": "c003e3fe-a042-4e3a-b644-f607acb9faef", 838 | "metadata": {}, 839 | "outputs": [], 840 | "source": [ 841 | "class BiRRT():\n", 842 | " def __init__(\n", 843 | " self,\n", 844 | " system,\n", 845 | " node_max=500000,\n", 846 | " iter_max=1000000,\n", 847 | " l_min=.2,\n", 848 | " l_max=.5,\n", 849 | " steer_delta=.1,\n", 850 | " ):\n", 851 | " # Initialize attributes:\n", 852 | " # self.l_min = l_min\n", 853 | " # etc.\n", 854 | "\n", 855 | " # New: duplicate this attribute as dictionaries with two keys:\n", 856 | " # \"forward\" and \"backward\". See `solve` below.\n", 857 | " self._cached_dist_to_candidate = {}\n", 858 | " self.storage = {}\n", 859 | " self.pathtree = {}\n", 860 | " self.tree = {}\n", 861 | "\n", 862 | " def tree_distance(self, direction: str, q1_idx, q2_idx):\n", 863 | " # Adapt from RRT.distance\n", 864 | " # There is now a direction string to select the underlying tree,\n", 865 | " # either \"forward\" (from q_init) or \"backward\" (from q_goal).\n", 866 | " pass  # your code here\n", 867 | "\n", 868 | " def forward_distance(self, q1_idx, q2_idx):\n", 869 | " return self.tree_distance(\"forward\", q1_idx, q2_idx)\n", 870 | "\n", 871 | " def backward_distance(self, q1_idx, q2_idx):\n", 872 | " return self.tree_distance(\"backward\", q1_idx, q2_idx)\n", 873 | "\n", 874 | " def new_candidate(self):\n", 875 | " # A minor change is required to adapt RRT.new_candidate to this template.\n", 876 | " pass  # your code here\n", 877 | "\n", 878 | " def solve(self, qi, qg):\n", 879 | " # Reset internal datastructures\n", 880 | " for direction in (\"forward\", \"backward\"):\n", 881 | " self._cached_dist_to_candidate[direction] = {}\n", 882 | " self.storage[direction] = Storage(self.node_max, self.system.nq)\n", 883 | " self.pathtree[direction] = PathTree(self.storage[direction])\n", 884 | " self.tree = {\n", 885 | " \"forward\": MTree(self.forward_distance),\n", 886 | " \"backward\": MTree(self.backward_distance),\n", 887 | " }\n", 888 | "\n", 889 | " # Now datastructures are initialized\n", 890 | " # The rest is up to you! \n", 891 | "\n", 892 | " def get_path(self):\n", 893 | " assert self.found\n", 894 | " forward_path = self.pathtree[\"forward\"].get_path()\n", 895 | " backward_path = self.pathtree[\"backward\"].get_path()\n", 896 | " return np.concatenate([forward_path, backward_path[::-1]])" 897 | ] 898 | }, 899 | { 900 | "cell_type": "markdown", 901 | "id": "5254b62c-673f-407d-a0e6-effd3e75aabb", 902 | "metadata": {}, 903 | "source": [ 904 | "You should be able to call `BiRRT` similarly to `RRT`:" 905 | ] 906 | }, 907 | { 908 | "cell_type": "code", 909 | "execution_count": null, 910 | "id": "60443eb2-ea54-4c0b-8d24-1908a4ca710b", 911 | "metadata": {}, 912 | "outputs": [], 913 | "source": [ 914 | "system.reset()\n", 915 | "\n", 916 | "birrt = BiRRT(\n", 917 | " system,\n", 918 | " l_min=0.2,\n", 919 | " l_max=0.5,\n", 920 | " steer_delta=0.1,\n", 921 | ")\n", 922 | "\n", 923 | "birrt.solve(q_i, q_g)" 924 | ] 925 | }, 926 | { 927 | "cell_type": "code", 928 | "execution_count": null, 929 | "id": "b62f8275-b4d7-4168-ad6c-e437d0ea2887", 930 | "metadata": {}, 931 | "outputs": [], 932 | "source": [ 933 | "system.display_motion(birrt.get_path())" 934 | ] 935 | }, 936 | { 937 | "cell_type": "markdown", 938 | "id": "a291e25d-46a4-47b3-b894-8d5810196d39", 939 | "metadata": {}, 940 | "source": [ 941 | "How many iterations did it take to find a solution? Is it faster than previously with `RRT`?"
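, "\n", "\n", "One way to compare, assuming your `BiRRT.solve` records `self.iter_done` and `self.it_trace` like `RRT.solve` does:\n", "\n", "```python\n", "print(f\"RRT: {rrt.iter_done} iterations, {sum(rrt.it_trace)} successful extensions\")\n", "print(f\"Bi-RRT: {birrt.iter_done} iterations, {sum(birrt.it_trace)} successful extensions\")\n", "```"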
940 | ] 941 | }, 942 | { 943 | "cell_type": "markdown", 944 | "id": "84285a1d-ac59-40c9-aef1-23527d0bab54", 945 | "metadata": {}, 946 | "source": [ 947 | "## Bonus question: Bi-RRT*" 948 | ] 949 | }, 950 | { 951 | "cell_type": "markdown", 952 | "id": "6212a5aa-5794-48c6-91e5-f70c3ade34fa", 953 | "metadata": {}, 954 | "source": [ 955 | "Implement an optimal variant `BiRRTStar` of your `BiRRT` class and run it in the same configuration as the two algorithms above. What do you notice about the resulting tree? What is the improvement in overall path length between `RRT`, `BiRRT` and `BiRRTStar`?" 956 | ] 957 | } 958 | ], 959 | "metadata": { 960 | "kernelspec": { 961 | "display_name": "Python 3 (ipykernel)", 962 | "language": "python", 963 | "name": "python3" 964 | }, 965 | "language_info": { 966 | "codemirror_mode": { 967 | "name": "ipython", 968 | "version": 3 969 | }, 970 | "file_extension": ".py", 971 | "mimetype": "text/x-python", 972 | "name": "python", 973 | "nbconvert_exporter": "python", 974 | "pygments_lexer": "ipython3", 975 | "version": "3.10.12" 976 | } 977 | }, 978 | "nbformat": 4, 979 | "nbformat_minor": 5 980 | } 981 | -------------------------------------------------------------------------------- /3_reinforcement_learning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "589389e6-2aad-45b6-bc03-e9b37ab1514f", 6 | "metadata": {}, 7 | "source": [ 8 | "# Reinforcement learning for legged robots\n", 9 | "\n", 10 | "## Setup\n", 11 | "\n", 12 | "Before we start, you will need to update your conda environment to use Gymnasium (maintained) rather than OpenAI Gym (discontinued). You can simply run:\n", 13 | "\n", 14 | "```\n", 15 | "conda activate robotics-mva\n", 16 | "conda install -c conda-forge gymnasium imageio mujoco=2.3.7 stable-baselines3 tensorboard\n", 17 | "```\n", 18 | "\n", 19 | "Import Gymnasium and Stable Baselines3 to check that everything is working:" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": null, 25 | "id": "b3fa7b15-843f-4696-9bfc-141da71bf7d1", 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "import gymnasium as gym\n", 30 | "import stable_baselines3" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "id": "8e264559-88cc-48b7-8257-f7755fff3ce7", 36 | "metadata": {}, 37 | "source": [ 38 | "Let's import the usual suspects as well:" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "id": "86303cf2-f879-407d-b528-6c0a80b8df20", 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "import matplotlib.pylab as plt\n", 49 | "import numpy as np\n", 50 | "\n", 51 | "plt.ion()" 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "id": "b09e16a1-bf4a-403b-8709-3da11bc3c4b4", 57 | "metadata": {}, 58 | "source": [ 59 | "# Inverted pendulum environment\n", 60 | "\n", 61 | "The inverted pendulum model is not just a toy model reproducing the properties of real robot models for balancing: as it turns out, the inverted pendulum appears in the dynamics of *any* mobile robot, that is, a model with a floating-base joint at the root of the kinematic tree. 
(If you are curious: the inverted pendulum is a limit case of the [Newton-Euler equations](https://scaron.info/robotics/newton-euler-equations.html) corresponding to floating-base coordinates in the equations of motion $M \\ddot{q} + h = S^T \\tau + J_c^T f$, in the limit where the robot [does not vary its angular momentum](https://scaron.info/robotics/point-mass-model.html).) Thus, while we work on a simplified inverted pendulum in this notebook, concepts and tools are those used as-is on real robots, as you can verify by exploring the bonus section.\n", 62 | "\n", 63 | "Gymnasium is mainly a single-agent reinforcement learning API, but it also comes with simple environments, including an inverted pendulum sliding on a linear guide:" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "id": "11b5d942-85fa-435e-b5ef-8c85a74ba3db", 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "with gym.make(\"InvertedPendulum-v4\", render_mode=\"human\") as env:\n", 74 | " action = 0.0 * env.action_space.sample()\n", 75 | " observation, _ = env.reset()\n", 76 | " episode_return = 0.0\n", 77 | " for step in range(200):\n", 78 | " # action[0] = 5.0 * observation[1] + 0.3 * observation[0]\n", 79 | " observation, reward, terminated, truncated, _ = env.step(action)\n", 80 | " episode_return += reward\n", 81 | " if terminated or truncated:\n", 82 | " observation, _ = env.reset()\n", 83 | " print(f\"Return of the episode: {episode_return}\")" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "id": "d7322422-94db-4e12-b299-36bb40649cf7", 89 | "metadata": {}, 90 | "source": [ 91 | "The structure of the action and observation vectors are documented in [Inverted Pendulum - Gymnasium Documentation](https://gymnasium.farama.org/environments/mujoco/inverted_pendulum/). The observation, in particular, is a NumPy array with four coordinates that we recall here for reference:\n", 92 | "\n", 93 | "| Num | Observation | Min | Max | Unit |\n", 94 | "|-----|-------------|-----|-----|------|\n", 95 | "| 0 | position of the cart along the linear surface | -Inf | Inf | position (m) |\n", 96 | "| 1 | vertical angle of the pole on the cart | -Inf | Inf | angle (rad) |\n", 97 | "| 2 | linear velocity of the cart | -Inf | Inf | linear velocity (m/s) |\n", 98 | "| 3 | angular velocity of the pole on the cart | -Inf | Inf | angular velocity (rad/s) |\n", 99 | "\n", 100 | "We will use the following labels to annotate plots:" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "id": "a3231c70-f49d-49be-b260-aadbade7b403", 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "OBSERVATION_LEGEND = (\"position\", \"pitch\", \"linear_velocity\", \"angular_velocity\")" 111 | ] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "id": "aa062536-204c-4312-a858-f992f3db61d6", 116 | "metadata": {}, 117 | "source": [ 118 | "Check out the documentation for the definitions of the action and rewards." 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "id": "c285d7ce-3a97-4b07-8b5f-a9b04d7721ab", 124 | "metadata": {}, 125 | "source": [ 126 | "# PID control\n", 127 | "\n", 128 | "A *massively* used class of policies is the [PID controller](https://en.wikipedia.org/wiki/Proportional%E2%80%93integral%E2%80%93derivative_controller). Let's say we have a reference observation, like $o^* = [0\\ 0\\ 0\\ 0]$ for the inverted pendulum. 
Denoting by $e(t) = o^* - o(t)$ the *error* of the system when it observes a given state, a continuous-time PID controller will apply the action:\n", 129 | "\n", 130 | "$$\n", 131 | "a(t) = K_p^T e(t) + K_d^T \\dot{e}(t) + K_i^T \\int e(\\tau) \\mathrm{d} \\tau\n", 132 | "$$\n", 133 | "\n", 134 | "where $K_{p}, K_i, K_d \\in \\mathbb{R}^4$ are constants called *gains* and tuned by the user. In discrete time the idea is the same:\n", 135 | "\n", 136 | "$$\n", 137 | "a_k = K_p^T e_k + K_d^T \\frac{e_k - e_{k-1}}{\\delta t} + K_i^T \\sum_{i=0}^{k} e_i {\\delta t}\n", 138 | "$$" 139 | ] 140 | }, 141 | { 142 | "cell_type": "markdown", 143 | "id": "63c381eb-fca9-4ef4-8f99-3b1943231654", 144 | "metadata": {}, 145 | "source": [ 146 | "Let's refactor the rollout of an episode into a standalone function:" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "id": "9c839bc6-168a-42c3-8f1c-c6b0c5411901", 153 | "metadata": {}, 154 | "outputs": [], 155 | "source": [ 156 | "def rollout_from_env(env, policy):\n", 157 | " episode = []\n", 158 | " observation, _ = env.reset()\n", 159 | " episode.append(observation)\n", 160 | " for step in range(1000):\n", 161 | " action = policy(observation)\n", 162 | " observation, reward, terminated, truncated, _ = env.step(action)\n", 163 | " episode.extend([action, reward, observation])\n", 164 | " if terminated or truncated:\n", 165 | " return episode\n", 166 | " return episode\n", 167 | "\n", 168 | "def rollout(policy, show: bool = True):\n", 169 | " kwargs = {\"render_mode\": \"human\"} if show else {}\n", 170 | " with gym.make(\"InvertedPendulum-v4\", **kwargs) as env:\n", 171 | " episode = rollout_from_env(env, policy)\n", 172 | " return episode" 173 | ] 174 | }, 175 | { 176 | "cell_type": "markdown", 177 | "id": "79ff0dce-a4df-4917-bb17-2393353610a3", 178 | "metadata": {}, 179 | "source": [ 180 | "## Question 1: Write a PID controller that balances the inverted pendulum" 181 | ] 182 | }, 183 | { 184 | "cell_type": "markdown", 185 | "id": "e7cfb28b-ff73-42ff-9524-eac8ec12f8a1", 186 | "metadata": {}, 187 | "source": [ 188 | "You can use global variables to store the (discrete) derivative and integral terms; this will be OK here as we only roll out a single trajectory:" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": null, 194 | "id": "045ddcef-c0f7-4251-b73f-d5df5a0027e5", 195 | "metadata": {}, 196 | "outputs": [], 197 | "source": [ 198 | "def pid_policy(observation: np.ndarray) -> np.ndarray:\n", 199 | " my_action_value: float = 0.0 # your action here\n", 200 | " return np.array([my_action_value])\n", 201 | "\n", 202 | "episode = rollout(pid_policy, show=False)" 203 | ] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "id": "a0a005aa-87fa-4f98-8ace-f24421886bed", 208 | "metadata": {}, 209 | "source": [ 210 | "You can look at the system using `show=True`, but intuition usually builds faster when looking at relevant plots:" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": null, 216 | "id": "9aa5decd-779c-4f0d-84fd-3eb47358b7fa", 217 | "metadata": {}, 218 | "outputs": [], 219 | "source": [ 220 | "observations = np.array(episode[::3])\n", 221 | "\n", 222 | "plt.plot(observations)\n", 223 | "plt.legend(OBSERVATION_LEGEND)" 224 | ] 225 | }, 226 | { 227 | "cell_type": "markdown", 228 | "id": "98d50cd2-26fa-4d3c-a671-1ed0e1b9ee93", 229 | "metadata": {}, 230 | "source": [ 231 | "Can you reach the maximum return of 1000, i.e. balance for the full 1000 steps?"
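, "\n", "\n", "If you are stuck, a proportional-derivative feedback on the pole angle and cart position is already enough to balance this system. A sketch, with illustrative gains that may need tuning:\n", "\n", "```python\n", "def pid_policy(observation: np.ndarray) -> np.ndarray:\n", "    position, pitch, linear_velocity, angular_velocity = observation\n", "    action = 5.0 * pitch + 0.3 * position + 0.8 * angular_velocity + 0.1 * linear_velocity\n", "    return np.array([action])\n", "```"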
232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": null, 237 | "id": "8bacbd0a-2ac5-44cf-848b-8ebfb6fe35d7", 238 | "metadata": {}, 239 | "outputs": [], 240 | "source": [ 241 | "print(f\"Return of the episode: {sum(episode[2::3])}\")" 242 | ] 243 | }, 244 | { 245 | "cell_type": "markdown", 246 | "id": "b17cc998-1b23-416f-8e3b-810100c223fb", 247 | "metadata": {}, 248 | "source": [ 249 | "# Policy optimization\n", 250 | "\n", 251 | "Let us now train a policy, parameterized by a multilayer perceptron (MLP), to maximize the expected return over episodes on the inverted pendulum environment." 252 | ] 253 | }, 254 | { 255 | "cell_type": "markdown", 256 | "id": "d5631f0f-1b84-4ee6-8e9c-b4f2915bd281", 257 | "metadata": {}, 258 | "source": [ 259 | "## Our very first policy\n", 260 | "\n", 261 | "We will use the proximal policy optimization (PPO) algorithm for training, using the implementation from Stable Baselines3: [PPO - Stable Baselines3 documentation](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html)." 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": null, 267 | "id": "128867ca-e600-4ba1-abbd-1f918976fba2", 268 | "metadata": {}, 269 | "outputs": [], 270 | "source": [ 271 | "from stable_baselines3 import PPO\n", 272 | "\n", 273 | "with gym.make(\"InvertedPendulum-v4\", render_mode=\"human\") as env:\n", 274 | " first_policy = PPO(\"MlpPolicy\", env, verbose=0)\n", 275 | " first_policy.learn(total_timesteps=1000, progress_bar=False)" 276 | ] 277 | }, 278 | { 279 | "cell_type": "markdown", 280 | "id": "6323400b-18ca-43f6-a81e-a5e7f033a536", 281 | "metadata": {}, 282 | "source": [ 283 | "By instantiating the algorithm with no further ado, we let the library decide for us on a sane set of default hyperparameters, including:\n", 284 | "\n", 285 | "- Rollout buffers of `n_steps = 2048` steps, which we will visit `n_epochs = 10` times with mini-batches of size `batch_size = 64`.\n", 286 | "- Clipping range: ``0.2``.\n", 287 | "- No entropy regularization.\n", 288 | "- Learning rate for the Adam optimizer: ``3e-4``\n", 289 | "- Policy and value-function network architectures: two layers of 64 neurons with $\\tanh$ activation functions.\n", 290 | "\n", 291 | "We then called the `learn` function to execute PPO over a fixed total number of timesteps, here just a thousand." 292 | ] 293 | }, 294 | { 295 | "cell_type": "markdown", 296 | "id": "8b82173c-6609-4b83-8618-36f82c1c1373", 297 | "metadata": {}, 298 | "source": [ 299 | "Rendering actually took a significant chunk of time. 
Let's instantiate and keep an environment open without rendering:" 300 | ] 301 | },
302 | { 303 | "cell_type": "code", 304 | "execution_count": null, 305 | "id": "460fe1c7-ee3b-450a-b09c-03b96f9086bf", 306 | "metadata": {}, 307 | "outputs": [], 308 | "source": [ 309 | "env = gym.make(\"InvertedPendulum-v4\")" 310 | ] 311 | },
312 | { 313 | "cell_type": "markdown", 314 | "id": "a9bd090f-ca34-41e0-9900-52977eef9c4b", 315 | "metadata": {}, 316 | "source": [ 317 | "We can use it to train for many more steps in roughly the same time, reporting training metrics every `n_steps` steps:" 318 | ] 319 | },
320 | { 321 | "cell_type": "code", 322 | "execution_count": null, 323 | "id": "b7262602-c277-4697-8987-ba126a87e75b", 324 | "metadata": {}, 325 | "outputs": [], 326 | "source": [ 327 | "second_policy = PPO(\"MlpPolicy\", env, verbose=1)\n", 328 | "second_policy.learn(total_timesteps=10_000, progress_bar=False)" 329 | ] 330 | },
331 | { 332 | "cell_type": "markdown", 333 | "id": "6219aab8-1143-4606-a44f-b62fdffebbf1", 334 | "metadata": {}, 335 | "source": [ 336 | "Let's see how this policy performs:" 337 | ] 338 | },
339 | { 340 | "cell_type": "code", 341 | "execution_count": null, 342 | "id": "f17dc178-bb9c-4155-8047-feed1e575226", 343 | "metadata": {}, 344 | "outputs": [], 345 | "source": [ 346 | "def policy_closure(policy):\n", 347 | "    \"\"\"Utility function to turn our policy instance into a function.\n", 348 | "\n", 349 | "    Args:\n", 350 | "        policy: Policy to turn into a function.\n", 351 | "    \n", 352 | "    Returns:\n", 353 | "        Function from observation to policy action.\n", 354 | "    \"\"\"\n", 355 | "    def policy_function(observation):\n", 356 | "        action, _ = policy.predict(observation)\n", 357 | "        return action\n", 358 | "    return policy_function" 359 | ] 360 | },
361 | { 362 | "cell_type": "code", 363 | "execution_count": null, 364 | "id": "7e4e3cd4-4572-4c40-94b1-688d472a4b8c", 365 | "metadata": {}, 366 | "outputs": [], 367 | "source": [ 368 | "episode = rollout(policy_closure(second_policy), show=True)" 369 | ] 370 | },
371 | { 372 | "cell_type": "markdown", 373 | "id": "941219f1-9b3c-4e66-86e5-d42f2473b149", 374 | "metadata": {}, 375 | "source": [ 376 | "Okay, it looks like we didn't train for long enough!" 377 | ] 378 | },
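A side note on evaluation: Stable Baselines3's `predict` samples from the policy's action distribution by default, and exposes a `deterministic` flag to return the distribution's mean action instead, which is usually steadier at test time. A variant of the closure above using that flag could look like:

```python
def deterministic_policy_closure(policy):
    """Like policy_closure, but returns the mean action instead of a sample."""
    def policy_function(observation):
        action, _ = policy.predict(observation, deterministic=True)
        return action
    return policy_function
```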
379 | { 380 | "cell_type": "markdown", 381 | "id": "195f2a85-8dd7-4f3a-8368-7427f1caadca", 382 | "metadata": {}, 383 | "source": [ 384 | "## Monitoring performance during training\n", 385 | "\n", 386 | "Let's train for longer, and use TensorBoard to keep track. We don't know how long training will take, so let's set a rather large total number of steps (you can interrupt training once you observe convergence in TensorBoard):" 387 | ] 388 | },
389 | { 390 | "cell_type": "code", 391 | "execution_count": null, 392 | "id": "196da0ad-1e83-441e-ac10-6c9ecd83c224", 393 | "metadata": {}, 394 | "outputs": [], 395 | "source": [ 396 | "erudite_policy = PPO(\n", 397 | "    \"MlpPolicy\",\n", 398 | "    env,\n", 399 | "    tensorboard_log=\"./inverted_pendulum_tensorboard/\",\n", 400 | "    verbose=0,\n", 401 | ")\n", 402 | "\n", 403 | "erudite_policy.learn(\n", 404 | "    total_timesteps=1_000_000,\n", 405 | "    progress_bar=False,\n", 406 | "    tb_log_name=\"erudite\",\n", 407 | ")" 408 | ] 409 | },
410 | { 411 | "cell_type": "markdown", 412 | "id": "ad91d14e-53e2-443f-b7ab-edd69d480add", 413 | "metadata": {}, 414 | "source": [ 415 | "Run TensorBoard on the directory thus created to open a dashboard in your Web browser:\n", 416 | "\n", 417 | "```\n", 418 | "tensorboard --logdir ./inverted_pendulum_tensorboard/\n", 419 | "```\n", 420 | "\n", 421 | "The link will typically be http://localhost:6006 (port number increases if you run TensorBoard multiple times in parallel). Tips:\n", 422 | "\n", 423 | "- Click the Settings icon in the top-right corner and enable \"Reload data\"\n", 424 | "- Uncheck \"Ignore outliers in chart scaling\" (your preference)" 425 | ] 426 | },
427 | { 428 | "cell_type": "markdown", 429 | "id": "68771e35-48cd-43be-89ff-0055dc196d0b", 430 | "metadata": {}, 431 | "source": [ 432 | "## Saving our policy\n", 433 | "\n", 434 | "Now that you have spent some compute optimizing an actual policy, better save it to disk:" 435 | ] 436 | },
437 | { 438 | "cell_type": "code", 439 | "execution_count": null, 440 | "id": "effeccfc-8b95-48e1-98c4-1b96838bb28e", 441 | "metadata": {}, 442 | "outputs": [], 443 | "source": [ 444 | "erudite_policy.save(\"pendulum_erudite\")" 445 | ] 446 | },
447 | { 448 | "cell_type": "markdown", 449 | "id": "ea09c56e-8647-414d-aae0-5e1b16ba3a0f", 450 | "metadata": {}, 451 | "source": [ 452 | "You can then reload it later with:" 453 | ] 454 | },
455 | { 456 | "cell_type": "code", 457 | "execution_count": null, 458 | "id": "994dae4f-b651-4488-925e-2ba369eeedc7", 459 | "metadata": {}, 460 | "outputs": [], 461 | "source": [ 462 | "erudite_policy = PPO.load(\"pendulum_erudite\", env=env)" 463 | ] 464 | },
465 | { 466 | "cell_type": "markdown", 467 | "id": "dded4f1e-ae57-4c94-b2e1-3a6d019aecc4", 468 | "metadata": {}, 469 | "source": [ 470 | "## Question 2: How many steps does it take to train a successful policy?\n", 471 | "\n", 472 | "We consider a policy successful if it consistently achieves the maximum return of 1000." 473 | ] 474 | },
475 | { 476 | "cell_type": "raw", 477 | "id": "42c6d68d-4812-4222-97da-a6699803b986", 478 | "metadata": {}, 479 | "source": [ 480 | "== Your reply here ==" 481 | ] 482 | },
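To check the "consistently" part quantitatively, one possible approach (a sketch reusing the helpers defined earlier and the non-rendering `env`) is to roll out several episodes and look at the spread of their returns:

```python
returns = [
    sum(rollout_from_env(env, policy_closure(erudite_policy))[2::3])
    for _ in range(10)  # rewards sit at every third entry of an episode
]
print(f"Returns over 10 episodes: min={min(returns)}, max={max(returns)}")
```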
483 | { 484 | "cell_type": "markdown", 485 | "id": "553b846f-db13-43ba-81cb-57b039852c86", 486 | "metadata": {}, 487 | "source": [ 488 | "## A more realistic environment\n", 489 | "\n", 490 | "Real systems suffer from the two main issues we saw in the [Perception and estimation](https://scaron.info/robotics-mva/#5-perception-estimation) class: *bias* and *variance*. In this section, we model bias in actuation and perception by adding delays (via low-pass filtering) to the action and observation vectors, respectively. Empirically this is an effective model, as for instance it contributes to sim2real transfer on Upkie. To add these delays, we use an [`environment wrapper`](https://gymnasium.farama.org/api/wrappers/), which is a convenient way to compose environments, used in both the Gymnasium and Stable Baselines3 APIs:" 491 | ] 492 | },
493 | { 494 | "cell_type": "code", 495 | "execution_count": null, 496 | "id": "6e8a3140-7ee7-4d6f-afd9-19d6ca4816c0", 497 | "metadata": {}, 498 | "outputs": [], 499 | "source": [ 500 | "class DelayWrapper(gym.Wrapper):\n", 501 | "    def __init__(self, env, time_constant: float = 0.2):\n", 502 | "        \"\"\"Wrap environment with some actuation and perception modeling.\n", 503 | "\n", 504 | "        Args:\n", 505 | "            env: Environment to wrap.\n", 506 | "            time_constant: Constant of the internal low-pass filter, in seconds.\n", 507 | "                Feel free to play with different values but leave it to the default\n", 508 | "                of 0.2 seconds when handing in your homework.\n", 509 | "\n", 510 | "        Note:\n", 511 | "            Delays are implemented by a low-pass filter. The same time constant\n", 512 | "            is used for both actions and observations, which is not realistic, but\n", 513 | "            makes for less tutorial code ;)\n", 514 | "        \"\"\"\n", 515 | "        alpha = env.dt / time_constant\n", 516 | "        assert 0.0 < alpha < 1.0\n", 517 | "        super().__init__(env)\n", 518 | "        self._alpha = alpha\n", 519 | "        self._prev_action = 0.0 * env.action_space.sample()\n", 520 | "        self._prev_observation = np.zeros(4)\n", 521 | "\n", 522 | "    def low_pass_filter(self, old_value, new_value):\n", 523 | "        return old_value + self._alpha * (new_value - old_value)\n", 524 | "    \n", 525 | "    def step(self, action):\n", 526 | "        new_action = self.low_pass_filter(self._prev_action, action)\n", 527 | "        observation, reward, terminated, truncated, info = self.env.step(new_action)\n", 528 | "        new_observation = self.low_pass_filter(self._prev_observation, observation)\n", 529 | "        self._prev_action = new_action\n", 530 | "        self._prev_observation = new_observation\n", 531 | "        return new_observation, reward, terminated, truncated, info\n", 532 | "\n", 533 | "delay_env = DelayWrapper(env)" 534 | ] 535 | },
536 | { 537 | "cell_type": "markdown", 538 | "id": "b1b5de5e-50ca-4049-bb5f-b9203919e0ba", 539 | "metadata": {}, 540 | "source": [ 541 | "We can check how our current policy fares against the delayed environment. Spoiler alert: not great." 542 | ] 543 | },
544 | { 545 | "cell_type": "code", 546 | "execution_count": null, 547 | "id": "4e1508e6-e04f-4b22-8009-80baae1bae7d", 548 | "metadata": {}, 549 | "outputs": [], 550 | "source": [ 551 | "delay_episode = rollout_from_env(delay_env, policy_closure(erudite_policy))\n", 552 | "observations = np.array(delay_episode[::3])\n", 553 | "\n", 554 | "plt.plot(observations[:, :2])\n", 555 | "plt.legend(OBSERVATION_LEGEND[:2])" 556 | ] 557 | },
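To get a feel for how much lag the wrapper introduces, here is a quick sanity check of its exponential smoothing update, using the same `env.dt` and the default 0.2-second time constant: after one time constant, the filtered signal has only caught up to roughly 1 − 1/e ≈ 63% of a unit step (slightly more here, since alpha is not infinitesimal).

```python
alpha = env.dt / 0.2  # same alpha as DelayWrapper's default
y = 0.0
for _ in range(round(0.2 / env.dt)):
    y += alpha * (1.0 - y)  # low-pass filtering a unit step input
print(f"After one time constant: y = {y:.2f}")  # ~0.63-0.67 depending on alpha
```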
558 | { 559 | "cell_type": "markdown", 560 | "id": "70af3932-751e-47e4-8334-bd55be62aaa1", 561 | "metadata": {}, 562 | "source": [ 563 | "## Question 3: Can't we just re-train a policy on the new environment?\n", 564 | "\n", 565 | "At this point in the tutorial this is a rhetorical question, but we should check anyway. Re-train a policy on the delayed environment and comment on its performance:" 566 | ] 567 | },
568 | { 569 | "cell_type": "code", 570 | "execution_count": null, 571 | "id": "693aa97c-3ee2-4cbd-bc06-7cb224e8bc86", 572 | "metadata": {}, 573 | "outputs": [], 574 | "source": [ 575 | "# Your code here" 576 | ] 577 | },
578 | { 579 | "cell_type": "raw", 580 | "id": "48928906-bcd9-40d5-b17e-35fd06d6c6ac", 581 | "metadata": {}, 582 | "source": [ 583 | "== Your observations here ==" 584 | ] 585 | },
586 | { 587 | "cell_type": "markdown", 588 | "id": "c0e2df30-259f-477a-ab14-d39c17e5f15f", 589 | "metadata": {}, 590 | "source": [ 591 | "## The Real Question 3: Why do delays degrade both runtime and training performance?\n", 592 | "\n", 593 | "Loss in runtime performance refers to what we observed when executing a policy trained without delays on an environment with delays. Loss in training performance refers to the fact that, even when we train a new policy on the environment with delays, by the end of training it does not achieve maximum return." 594 | ] 595 | },
596 | { 597 | "cell_type": "raw", 598 | "id": "3b7459d5-93d0-49cb-85c7-2172e2b08073", 599 | "metadata": {}, 600 | "source": [ 601 | "== Your explanation here ==" 602 | ] 603 | },
604 | { 605 | "cell_type": "markdown", 606 | "id": "e63a441a-a84d-49ab-aecc-7362dee66b91", 607 | "metadata": {}, 608 | "source": [ 609 | "Propose and implement a way to overcome this. Store the resulting trained policy in a variable called `iron_policy`." 610 | ] 611 | },
612 | { 613 | "cell_type": "code", 614 | "execution_count": null, 615 | "id": "b22770ba-4e58-4989-b62c-d5aa1734336c", 616 | "metadata": {}, 617 | "outputs": [], 618 | "source": [ 619 | "# Your code here" 620 | ] 621 | },
622 | { 623 | "cell_type": "code", 624 | "execution_count": null, 625 | "id": "5a7a876f-e78e-47bb-9b42-9423618d1e42", 626 | "metadata": {}, 627 | "outputs": [], 628 | "source": [ 629 | "iron_policy.save(\"iron_policy\")" 630 | ] 631 | },
632 | { 633 | "cell_type": "markdown", 634 | "id": "d2a70b63-7fda-4c0f-b777-ef0dc2128ab2", 635 | "metadata": {}, 636 | "source": [ 637 | "Roll out an episode and plot the outcome to show that your policy handles delays properly." 638 | ] 639 | },
640 | { 641 | "cell_type": "code", 642 | "execution_count": null, 643 | "id": "45f13a2b-6a1c-4d44-bb84-2fffaf6bf6e3", 644 | "metadata": {}, 645 | "outputs": [], 646 | "source": [ 647 | "# Your episode rollout here\n", 648 | "\n", 649 | "plt.plot(np.array(observations)[:, :2])\n", 650 | "plt.legend(OBSERVATION_LEGEND[:2])" 651 | ] 652 | },
653 | { 654 | "cell_type": "markdown", 655 | "id": "1e12fcf1-88b9-4899-b79d-866c67e4a3f5", 656 | "metadata": {}, 657 | "source": [ 658 | "## Question 4: Can you improve sampling efficiency?\n", 659 | "\n", 660 | "This last question is open: what can you change in the pipeline to train a policy that achieves maximum return using fewer samples? Report on at least one thing that allowed you to train with fewer environment steps."
661 | ] 662 | },
663 | { 664 | "cell_type": "raw", 665 | "id": "0f5cb9a5-fd18-4077-a6fc-83fa5377de96", 666 | "metadata": {}, 667 | "source": [ 668 | "== Your report here ==" 669 | ] 670 | },
671 | { 672 | "cell_type": "markdown", 673 | "id": "131966f5-9524-4b44-9843-0c1a662ba2e1", 674 | "metadata": {}, 675 | "source": [ 676 | "Here is a state-of-the-art™ utility function if you want to experiment with scheduling some of the ``Callable[[float], float]`` [hyperparameters](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html#parameters):" 677 | ] 678 | },
679 | { 680 | "cell_type": "code", 681 | "execution_count": null, 682 | "id": "3de11ab9-2534-4723-8868-1582772d038c", 683 | "metadata": {}, 684 | "outputs": [], 685 | "source": [ 686 | "def affine_schedule(y_0: float, y_1: float):\n", 687 | "    \"\"\"Affine schedule as a function over the [0, 1] interval.\n", 688 | "\n", 689 | "    Args:\n", 690 | "        y_0: Function value at zero.\n", 691 | "        y_1: Function value at one.\n", 692 | "    \n", 693 | "    Returns:\n", 694 | "        Corresponding affine function.\n", 695 | "    \"\"\"\n", 696 | "    def schedule(x: float) -> float:\n", 697 | "        return y_0 + x * (y_1 - y_0)\n", 698 | "    return schedule" 699 | ] 700 | },
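For example, assuming Stable Baselines3's convention that such schedules are called with the remaining training progress (1 at the start of training, 0 at the end), annealing the learning rate could look like:

```python
annealing_ppo = PPO(
    "MlpPolicy",
    env,
    # progress goes 1 -> 0, so y_1 is the initial value and y_0 the final one
    learning_rate=affine_schedule(y_0=1e-5, y_1=3e-4),
)
```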
701 | { 702 | "cell_type": "markdown", 703 | "id": "b21d78dd-f80e-4183-8fa7-55c803e38404", 704 | "metadata": {}, 705 | "source": [ 706 | "And here is a wrapper template if you want to experiment with reward shaping:" 707 | ] 708 | },
709 | { 710 | "cell_type": "code", 711 | "execution_count": null, 712 | "id": "2cf9a3ed-8f76-4fac-98f2-3a23df818deb", 713 | "metadata": {}, 714 | "outputs": [], 715 | "source": [ 716 | "class CustomRewardWrapper(gym.Wrapper):\n", 717 | "    def __init__(self, env):\n", 718 | "        super().__init__(env)\n", 719 | "\n", 720 | "    def step(self, action):\n", 721 | "        observation, reward, terminated, truncated, info = self.env.step(action)\n", 722 | "        new_reward = 0.0  # your formula here\n", 723 | "        return observation, new_reward, terminated, truncated, info" 724 | ] 725 | },
726 | { 727 | "cell_type": "markdown", 728 | "id": "2c4dd0df-6dc7-4d51-b29b-c77b49bde437", 729 | "metadata": {}, 730 | "source": [ 731 | "# Bonus: training a policy for a real robot\n", 732 | "\n", 733 | "This section is entirely optional and will only work on Linux or macOS. In this part, we follow the same training pipeline but with the open source software of [Upkie](https://hackaday.io/project/185729-upkie-wheeled-biped-robots)." 734 | ] 735 | },
736 | { 737 | "cell_type": "markdown", 738 | "id": "9634ff93-f09f-4e0a-8d0f-547848f3900b", 739 | "metadata": {}, 740 | "source": [ 741 | "## Setup\n", 742 | "\n", 743 | "\n", 744 | "\n", 745 | "First, make sure you have a C++ compiler (setup one-liners: [Fedora](https://github.com/upkie/upkie/discussions/100), [Ubuntu](https://github.com/upkie/upkie/discussions/101)). You can run an Upkie simulation right from the command line. It won't install anything on your machine; everything runs locally from the repository:\n", 746 | "\n", 747 | "```console\n", 748 | "git clone https://github.com/upkie/upkie.git\n", 749 | "cd upkie\n", 750 | "git checkout fb9a0ab1f67a8014c08b34d7c0d317c7a8f71662\n", 751 | "./start_simulation.sh\n", 752 | "```\n", 753 | "\n", 754 | "**NB:** this tutorial is written for the specific commit checked out above. If some instructions don't work, it's likely that you forgot to check it out.\n", 755 | "\n", 756 | "We will use the Python API of the robot to test things from this notebook, or from custom scripts. Install it from PyPI in your Conda environment:\n", 757 | "\n", 758 | "```\n", 759 | "pip install upkie\n", 760 | "```" 761 | ] 762 | },
763 | { 764 | "cell_type": "markdown", 765 | "id": "ba44abc0-f7e9-4c2b-9d4e-a3579213e138", 766 | "metadata": {}, 767 | "source": [ 768 | "## Stepping the environment\n", 769 | "\n", 770 | "If everything worked well, you should be able to step an environment as follows:" 771 | ] 772 | },
773 | { 774 | "cell_type": "code", 775 | "execution_count": null, 776 | "id": "acedf0d6-fc2f-43f4-9ff6-a8e12dbd7ae0", 777 | "metadata": {}, 778 | "outputs": [], 779 | "source": [ 780 | "import gymnasium as gym\n", 781 | "import upkie.envs\n", 782 | "\n", 783 | "upkie.envs.register()\n", 784 | "\n", 785 | "episode_return = 0.0\n", 786 | "with gym.make(\"UpkieGroundVelocity-v1\", frequency=200.0) as env:\n", 787 | "    observation, _ = env.reset()  # connects to the spine (simulator or real robot)\n", 788 | "    action = 0.0 * env.action_space.sample()\n", 789 | "    for step in range(1000):\n", 790 | "        pitch = observation[0]\n", 791 | "        action[0] = 10.0 * pitch  # 1D action: [ground_velocity]\n", 792 | "        observation, reward, terminated, truncated, _ = env.step(action)\n", 793 | "        episode_return += reward\n", 794 | "        if terminated or truncated:\n", 795 | "            observation, _ = env.reset()\n", 796 | "\n", 797 | "print(f\"We have stepped the environment {step + 1} times\")\n", 798 | "print(f\"The return of our episode is {episode_return}\")" 799 | ] 800 | },
801 | { 802 | "cell_type": "markdown", 803 | "id": "031343b5-cf94-46ae-98f3-a4c5ebbc037c", 804 | "metadata": {}, 805 | "source": [ 806 | "(If you see a message \"Waiting for spine /vulp to start\", it means the simulation is not running.)" 807 | ] 808 | },
809 | { 810 | "cell_type": "markdown", 811 | "id": "aecfd91f-676c-4d6f-beb0-a286dc681ae3", 812 | "metadata": {}, 813 | "source": [ 814 | "We can double-check the last observation from the episode:" 815 | ] 816 | },
817 | { 818 | "cell_type": "code", 819 | "execution_count": null, 820 | "id": "ed6d972f-4cc9-4005-b9a1-4a7433a19938", 821 | "metadata": {}, 822 | "outputs": [], 823 | "source": [ 824 | "def report_last_observation(observation):\n", 825 | "    print(\"The last observation of the episode is:\")\n", 826 | "    print(f\"- Pitch from torso to world: {observation[0]:.2} rad\")\n", 827 | "    print(f\"- Ground position: {observation[1]:.2} m\")\n", 828 | "    print(f\"- Angular velocity from torso to world in torso: {observation[2]:.2} rad/s\")\n", 829 | "    print(f\"- Ground velocity: {observation[3]:.2} m/s\")\n", 830 | "    \n", 831 | "report_last_observation(observation)" 832 | ] 833 | },
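One pitfall when adapting your Question 1 controller to this robot: the observation layout differs from the Gym pendulum, as the printout above shows (pitch comes first here). A small sketch of the correspondence, with a hypothetical reindexing helper:

```python
# Gym InvertedPendulum-v4: [position, pitch, linear_velocity, angular_velocity]
# UpkieGroundVelocity-v1:  [pitch, position, angular_velocity, linear_velocity]
def upkie_to_pendulum_order(obs: np.ndarray) -> np.ndarray:
    """Hypothetical helper: reorder an Upkie observation Gym-pendulum style."""
    return np.array([obs[1], obs[0], obs[3], obs[2]])
```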
834 | { 835 | "cell_type": "markdown", 836 | "id": "d5a269e3-d876-4d05-88e5-b0d73be6f939", 837 | "metadata": {}, 838 | "source": [ 839 | "## Question B1: PID control\n", 840 | "\n", 841 | "Adapt your code from Question 1 to this environment:" 842 | ] 843 | },
844 | { 845 | "cell_type": "code", 846 | "execution_count": null, 847 | "id": "1e8e9f1e-f2a1-4a18-aa38-256a425d018c", 848 | "metadata": {}, 849 | "outputs": [], 850 | "source": [ 851 | "def policy_b1(observation):\n", 852 | "    return np.array([0.0])  # replace with your solution\n", 853 | "\n", 854 | "\n", 855 | "def run(policy, nb_steps: int):\n", 856 | "    episode_return = 0.0\n", 857 | "    with gym.make(\"UpkieGroundVelocity-v1\", frequency=200.0) as env:\n", 858 | "        observation, _ = env.reset()  # connects to the spine (simulator or real robot)\n", 859 | "        for step in range(nb_steps):\n", 860 | "            action = policy(observation)\n", 861 | "            observation, reward, terminated, truncated, _ = env.step(action)\n", "            episode_return += reward\n", 862 | "            if terminated or truncated:\n", 863 | "                print(\"Fall detected!\")\n", 864 | "                return episode_return\n", 865 | "        report_last_observation(observation)\n", 866 | "    return episode_return\n", 867 | "\n", 868 | "\n", 869 | "episode_return = run(policy_b1, 1000)\n", 870 | "print(f\"The return of our episode is {episode_return}\")" 871 | ] 872 | },
873 | { 874 | "cell_type": "markdown", 875 | "id": "e999eb22-a94a-4a58-ac06-9a7dbc15a7ee", 876 | "metadata": {}, 877 | "source": [ 878 | "## Training a new policy\n", 879 | "\n", 880 | "The Upkie repository ships three agents based on PID control, model predictive control and reinforcement learning. We now focus on the last one, called the \"PPO balancer\".\n", 881 | "\n", 882 | "Check that you can run the training part by running, from the root of the repository:\n", 883 | "\n", 884 | "```\n", 885 | "./tools/bazel run //agents/ppo_balancer:train -- --nb-envs 1 --show\n", 886 | "```\n", 887 | "\n", 888 | "A simulation window should pop up, and verbose output from SB3 should be printed to your terminal.\n", 889 | "\n", 890 | "By default, training data will be logged to `/tmp`. You can select a different output path by setting the `UPKIE_TRAINING_PATH` environment variable in your shell. For instance:\n", 891 | "\n", 892 | "```\n", 893 | "export UPKIE_TRAINING_PATH=\"${HOME}/src/upkie/training\"\n", 894 | "```\n", 895 | "\n", 896 | "Run TensorBoard from the training directory:\n", 897 | "\n", 898 | "```\n", 899 | "tensorboard --logdir ${UPKIE_TRAINING_PATH}  # or /tmp if you keep the default\n", 900 | "```\n", 901 | "\n", 902 | "Each training will be named after a word picked at random in an English dictionary." 903 | ] 904 | },
905 | { 906 | "cell_type": "markdown", 907 | "id": "a7e47aad-7787-409c-af7e-b83bfccaa592", 908 | "metadata": {}, 909 | "source": [ 910 | "## Selecting the number of processes\n", 911 | "\n", 912 | "We can increase the number of parallel CPU environments ``--nb-envs`` to a value suitable to your computer. Let training run for a minute and check `time/fps`. Increase the number of environments and compare the stationary regime of `time/fps`. You should see a performance increase when adding the first few environments, followed by a decline when there are too many parallel processes compared to your number of CPU cores. Pick the value that works best for you." 913 | ] 914 | },
915 | { 916 | "cell_type": "markdown", 917 | "id": "696a0943-cc10-4cd0-a2d8-d5313dbe37e5", 918 | "metadata": {}, 919 | "source": [ 920 | "## Running a trained policy\n", 921 | "\n", 922 | "Copy the file `final.zip` from your trained policy directory to `agents/ppo_balancer/policy/params.zip`. Start a simulation and run the policy with:\n", 923 | "\n", 924 | "```\n", 925 | "./tools/bazel run //agents/ppo_balancer\n", 926 | "```\n", 927 | "\n", 928 | "What kind of behavior do you observe?" 929 | ] 930 | },
931 | { 932 | "cell_type": "raw", 933 | "id": "eaabe73c-f412-44b5-a714-241077720d01", 934 | "metadata": {}, 935 | "source": [ 936 | "== Your observations here ==" 937 | ] 938 | },
939 | { 940 | "cell_type": "markdown", 941 | "id": "6c356c81-b5ef-4364-a5db-c8e2600e104a", 942 | "metadata": {}, 943 | "source": [ 944 | "## Question B2: Improve this baseline" 945 | ] 946 | },
947 | { 948 | "cell_type": "markdown", 949 | "id": "527ecb8c-7292-432f-b0d3-b90c36de8719", 950 | "metadata": {}, 951 | "source": [ 952 | "The policy you are testing here is not the one we saw in class. 
Open question: improve on it using any of the methods we discussed. Measure the improvement by `ep_len_mean` or any other quantitative criterion:" 953 | ] 954 | }, 955 | { 956 | "cell_type": "raw", 957 | "id": "ce7d720b-17b8-493d-8128-e66c6571d3ff", 958 | "metadata": {}, 959 | "source": [ 960 | "== Your experiments here ==\n", 961 | "\n", 962 | "- Tried: ... / Measured outcome: ..." 963 | ] 964 | } 965 | ], 966 | "metadata": { 967 | "kernelspec": { 968 | "display_name": "Python 3 (ipykernel)", 969 | "language": "python", 970 | "name": "python3" 971 | }, 972 | "language_info": { 973 | "codemirror_mode": { 974 | "name": "ipython", 975 | "version": 3 976 | }, 977 | "file_extension": ".py", 978 | "mimetype": "text/x-python", 979 | "name": "python", 980 | "nbconvert_exporter": "python", 981 | "pygments_lexer": "ipython3", 982 | "version": "3.10.13" 983 | } 984 | }, 985 | "nbformat": 4, 986 | "nbformat_minor": 5 987 | } 988 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2023, Stéphane Caron 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Robotics MVA 2023 2 | 3 | This repository contains tutorial notebooks for the 2023 [Robotics](https://www.master-mva.com/cours/robotics/) class at MVA. 4 | 5 | ## Get started 6 | 7 | Clone this repository: 8 | 9 | ```bash 10 | git clone https://github.com/stephane-caron/robotics-mva-2023.git 11 | ``` 12 | 13 | Install miniconda: 14 | 15 | - Linux: https://docs.conda.io/en/latest/miniconda.html 16 | - macOS: https://docs.conda.io/en/latest/miniconda.html 17 | - Windows: https://www.anaconda.com/download/ 18 | 19 | Don't forget to add the conda snippet to your shell configuration (for instance ``~/.bashrc``). 
After that, you can run all labs in a dedicated Python environment that will not affect your system's regular Python environment. 20 | 21 | ### Run a notebook 22 | 23 | - Go to your local copy of the repository. 24 | - Open a terminal. 25 | - Create the conda environment: 26 | 27 | ```bash 28 | conda env create -f robotics-mva.yml 29 | ``` 30 | 31 | From there on, to work on a notebook, you will only need to activate the environment: 32 | 33 | ```bash 34 | conda activate robotics-mva 35 | ``` 36 | 37 | Then launch the notebook with: 38 | 39 | ```bash 40 | jupyter-lab 41 | ``` 42 | 43 | The notebook will be accessible from your web browser at [localhost:8888](http://localhost:8888). 44 | 45 | Meshcat visualisation can be accessed in full page at `localhost:700N/static/` where N denotes the Nth MeshCat instance created by your notebook kernel. 46 | 47 | ## Troubleshooting 48 | 49 | - Make sure the virtual environment is activated for ``jupyter-lab`` to work. 50 | - Make sure you call ``jupyter-lab`` so that Python package paths are configured properly. 51 | - In particular, ``jupyter-notebook`` may not have paths configured properly, resulting in failed package imports. 52 | 53 | ## Updating the notebooks 54 | 55 | If the repository changes (for instance when new tutorials are pushed) you will need to update your local copy of it by "pulling" from the repository. To do so, go to the directory containing the tutorials and run: 56 | 57 | ``` 58 | git pull 59 | ``` 60 | 61 | If you already have local changes to a notebook `something.ipynb`, either you already know how to use git and you can commit them, or you don't and the safest way for you to update is to: 62 | 63 | - Copy your modified `something.ipynb` somewhere else 64 | - Revert it to its original version: ``git checkout -f something.ipynb`` 65 | - Pull updates from the remote repository: ``git pull`` 66 | -------------------------------------------------------------------------------- /robotics-mva.yml: -------------------------------------------------------------------------------- 1 | name: robotics-mva 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - example-robot-data 7 | - gymnasium 8 | - imageio 9 | - ipywidgets 10 | - jupyterlab 11 | - matplotlib 12 | - meshcat-python 13 | - mujoco=2.3.7 14 | - pinocchio 15 | - python=3.10 16 | - quadprog 17 | - rich 18 | - scipy 19 | - stable-baselines3 20 | - tensorboard 21 | - tqdm 22 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .load_ur5_parallel import load_ur5_parallel 2 | -------------------------------------------------------------------------------- /utils/collision_wrapper.py: -------------------------------------------------------------------------------- 1 | import pinocchio as pin 2 | import numpy as np 3 | 4 | class CollisionWrapper: 5 | def __init__(self,robot,viz=None): 6 | self.robot=robot 7 | self.viz=viz 8 | 9 | self.rmodel = robot.model 10 | self.rdata = self.rmodel.createData() 11 | self.gmodel = self.robot.gmodel 12 | self.gdata = self.gmodel.createData() 13 | self.gdata.collisionRequests.enable_contact = True 14 | 15 | 16 | def computeCollisions(self,q,vq=None): 17 | res = pin.computeCollisions(self.rmodel,self.rdata,self.gmodel,self.gdata,q,False) 18 | pin.computeDistances(self.rmodel,self.rdata,self.gmodel,self.gdata,q) 19 | pin.computeJointJacobians(self.rmodel,self.rdata,q) 20 | if vq is not None: 21 
| pin.forwardKinematics(self.rmodel,self.rdata,q,vq,0*vq) 22 | return res 23 | 24 | def getCollisionList(self): 25 | '''Return a list of triplets [ index,collision,result ] where index is the 26 | index of the collision pair, collision is gmodel.collisionPairs[index] 27 | and result is gdata.collisionResults[index]. 28 | ''' 29 | return [ [ir,self.gmodel.collisionPairs[ir],r] 30 | for ir,r in enumerate(self.gdata.collisionResults) if r.isCollision() ] 31 | 32 | def _getCollisionJacobian(self,col,res): 33 | '''Compute the jacobian for one collision only. ''' 34 | contact = res.getContact(0) 35 | g1 = self.gmodel.geometryObjects[col.first] 36 | g2 = self.gmodel.geometryObjects[col.second] 37 | oMc = pin.SE3(pin.Quaternion.FromTwoVectors(np.array([0,0,1]),contact.normal).matrix(),contact.pos) 38 | 39 | joint1 = g1.parentJoint 40 | joint2 = g2.parentJoint 41 | oMj1 = self.rdata.oMi[joint1] 42 | oMj2 = self.rdata.oMi[joint2] 43 | 44 | cMj1 = oMc.inverse()*oMj1 45 | cMj2 = oMc.inverse()*oMj2 46 | 47 | J1=pin.getJointJacobian(self.rmodel,self.rdata,joint1,pin.ReferenceFrame.LOCAL) 48 | J2=pin.getJointJacobian(self.rmodel,self.rdata,joint2,pin.ReferenceFrame.LOCAL) 49 | Jc1=cMj1.action@J1 50 | Jc2=cMj2.action@J2 51 | J = (Jc2-Jc1)[2,:] 52 | return J 53 | 54 | def _getCollisionJdotQdot(self,col,res): 55 | '''Compute the Coriolis acceleration for one collision only. ''' 56 | contact = res.getContact(0) 57 | g1 = self.gmodel.geometryObjects[col.first] 58 | g2 = self.gmodel.geometryObjects[col.second] 59 | oMc = pin.SE3(pin.Quaternion.FromTwoVectors(np.array([0,0,1]),contact.normal).matrix(),contact.pos) 60 | 61 | joint1 = g1.parentJoint 62 | joint2 = g2.parentJoint 63 | oMj1 = self.rdata.oMi[joint1] 64 | oMj2 = self.rdata.oMi[joint2] 65 | 66 | cMj1 = oMc.inverse()*oMj1 67 | cMj2 = oMc.inverse()*oMj2 68 | 69 | a1 = self.rdata.a[joint1] 70 | a2 = self.rdata.a[joint2] 71 | a = (cMj1*a1-cMj2*a2).linear[2] 72 | return a 73 | 74 | def getCollisionJacobian(self,collisions=None): 75 | '''From a collision list, return the Jacobian corresponding to the normal direction. 
''' 76 | if collisions is None: collisions = self.getCollisionList() 77 | if len(collisions)==0: return np.ndarray([0,self.rmodel.nv]) 78 | J = np.vstack([ self._getCollisionJacobian(c,r) for (i,c,r) in collisions ]) 79 | return J 80 | 81 | def getCollisionJdotQdot(self,collisions=None): 82 | if collisions is None: collisions = self.getCollisionList() 83 | if len(collisions)==0: return np.array([]) 84 | a0 = np.vstack([ self._getCollisionJdotQdot(c,r) for (i,c,r) in collisions ]) 85 | return a0.squeeze() 86 | 87 | def getCollisionDistances(self,collisions=None): 88 | if collisions is None: collisions = self.getCollisionList() 89 | if len(collisions)==0: return np.array([]) 90 | dist = np.array([ self.gdata.distanceResults[i].min_distance for (i,c,r) in collisions ]) 91 | return dist 92 | 93 | 94 | # --- DISPLAY ----------------------------------------------------------------------------------- 95 | # --- DISPLAY ----------------------------------------------------------------------------------- 96 | # --- DISPLAY ----------------------------------------------------------------------------------- 97 | 98 | def initDisplay(self,viz=None): 99 | if viz is not None: self.viz = viz 100 | assert(self.viz is not None) 101 | 102 | self.patchName = 'world/contact_%d_%s' 103 | self.ncollisions=10 104 | self.createDisplayPatchs(0) 105 | 106 | def createDisplayPatchs(self,ncollisions): 107 | 108 | if ncollisions == self.ncollisions: return 109 | elif ncollisions0: break 180 | if not i % 20: viz.display(q) 181 | 182 | viz.display(q) 183 | 184 | col.displayCollisions() 185 | p = cols[0][1] 186 | ci = cols[0][2].getContact(0) 187 | 188 | import pickle 189 | with open('/tmp/bug.pickle', 'wb') as file: 190 | pickle.dump([ col.gdata.oMg[11], 191 | col.gdata.oMg[3], 192 | #col.gmodel.geometryObjects[11].geometry, 193 | ] , file) 194 | 195 | dist=col.getCollisionDistances() 196 | J = col.getCollisionJacobian() 197 | -------------------------------------------------------------------------------- /utils/datastructures/bucketkdtree.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from cspace_metric.datastructures.tree import NodeBinaryTree 3 | 4 | 5 | class BucketKDNode(NodeBinaryTree): 6 | # Adapt to wraparound spaces ? 
7 | def __init__( 8 | self, 9 | parent=None, 10 | points=None, 11 | dim=3, 12 | dim_scale=None, 13 | bucketsize=10 14 | ): 15 | # Parent node 16 | self.parent = parent 17 | # Hyperparam 18 | self.bucketsize = bucketsize 19 | self.dim = dim 20 | self.dim_scale = dim_scale if dim_scale is not None else np.ones(dim) 21 | # Node are initially leaf without children and with a bucket 22 | self.is_leaf = True 23 | # Points buckets storage_data for leaf node 24 | self._points = np.zeros((bucketsize, dim)) 25 | self._n_points = 0 26 | # Value usefull for not leaf node 27 | self.split_dim = None 28 | self.split_val = None 29 | self.left = None 30 | self.right = None 31 | # Bounds tracker 32 | self.lower = np.ones((dim)) * np.inf 33 | self.upper = - np.ones((dim)) * np.inf 34 | # Add given points 35 | self.add_points(points) 36 | 37 | def _update_bounds(self, points): 38 | # ensure non null points 39 | if not isinstance(points, np.ndarray): 40 | if not points: 41 | return 42 | points = np.array(points) 43 | if points.size is 0: 44 | return 45 | 46 | mini = np.min(points, axis=0) 47 | maxi = np.max(points, axis=0) 48 | min_update = mini < self.lower 49 | self.lower[min_update] = mini[min_update] 50 | max_update = maxi > self.upper 51 | self.upper[max_update] = maxi[max_update] 52 | 53 | def add_point(self, p): 54 | self._update_bounds([p]) 55 | if self.is_leaf: 56 | # The node is still a leaf, just add to bucket 57 | # We are sure there is enough space 58 | self._points[self._n_points] = p 59 | self._n_points += 1 60 | if self._n_points == self.bucketsize: 61 | # It is now full, transform node as non leaf 62 | self._create_children() 63 | else: 64 | # Kd split to add to children 65 | if p[self.split_dim] <= self.split_val: 66 | self.left.add_point(p) 67 | else: 68 | self.right.add_point(p) 69 | 70 | def add_points(self, points): 71 | # ensure non null points 72 | if not isinstance(points, np.ndarray): 73 | if not points: 74 | return 75 | points = np.array(points) 76 | if points.size is 0: 77 | return 78 | 79 | if self.is_leaf: 80 | # We add the maximum we can to the bucket 81 | n = min(len(points), self.bucketsize - self._n_points) 82 | batch = points[:n] 83 | self._points[self._n_points:self._n_points + n] = batch 84 | self._n_points += n 85 | self._update_bounds(batch) 86 | if self._n_points == self.bucketsize: 87 | # It is full, transform node as non leaf 88 | self._create_children() 89 | # Add eventual remaining points 90 | self.add_points(points[n:]) 91 | else: 92 | # We add points to child given their position 93 | self._update_bounds(points) 94 | infe = points[:, self.split_dim] <= self.split_val 95 | self.left.add_points(points[infe]) 96 | self.right.add_points(points[~infe]) 97 | 98 | def _create_children(self): 99 | assert self.is_leaf and self.bucketsize == self._n_points 100 | # The creation must appeaar only for full leaf (after an add) 101 | # At this points the bounds are the one of the bucket 102 | ranges = self.upper - self.lower 103 | split_dim = np.argmax(ranges * self.dim_scale) 104 | # No more a leaf, create attribute for non leaf 105 | self.is_leaf = False 106 | self.split_dim = split_dim 107 | self.split_val = self.lower[split_dim] + ranges[split_dim] / 2 108 | self.left = BucketKDNode( 109 | self, 110 | dim=self.dim, dim_scale=self.dim_scale, bucketsize=self.bucketsize 111 | ) 112 | self.right = BucketKDNode( 113 | self, 114 | dim=self.dim, dim_scale=self.dim_scale, bucketsize=self.bucketsize 115 | ) 116 | # Now diffuse bucket points in children and erase local bucket 117 | 
self.add_points(self._points) 118 | self._points = None 119 | 120 | def nearest_neighbour(self, query, dist_to_many, max_dist=None): 121 | if self.is_leaf: 122 | dists = dist_to_many(query, self._points[:self._n_points]) 123 | i_min = np.argmin(dists) 124 | if max_dist is None or dists[i_min] < max_dist: 125 | return dists[i_min], self._points[i_min] 126 | return None, None 127 | else: 128 | cursor = self 129 | # Go down to the best leaf 130 | while not cursor.is_leaf: 131 | if query[cursor.split_dim] <= cursor.split_val: 132 | cursor = cursor.left 133 | else: 134 | cursor = cursor.right 135 | best_d, best_p = cursor.nearest_neighbour( 136 | query, dist_to_many, max_dist 137 | ) 138 | if best_d is not None: 139 | max_dist = best_d 140 | # Go up by recursively checking ambiguous split 141 | while cursor is not self: 142 | cursor = cursor.parent 143 | # check ambiguity to non coming child 144 | # Get nearest in the child if needed 145 | i = cursor.split_dim 146 | s = self.dim_scale[i] 147 | x = query[i] 148 | d, p = None, None 149 | if x <= cursor.split_val: 150 | # We come from left, check right 151 | if s * (cursor.right.lower[i] - x) < max_dist: 152 | # There is an ambiguity, check right 153 | d, p = cursor.right.nearest_neighbour( 154 | query, dist_to_many, max_dist 155 | ) 156 | else: 157 | # Same for right 158 | if s * (x - cursor.left.upper[i]) < max_dist: 159 | d, p = cursor.left.nearest_neighbour( 160 | query, dist_to_many, max_dist 161 | ) 162 | if d is not None: 163 | # We have found something better in ambiguity 164 | best_d, best_p = d, p 165 | max_dist = best_d 166 | 167 | return best_d, best_p 168 | 169 | 170 | class SBucketKDNode(NodeBinaryTree): 171 | """ 172 | Alternative with external storage_data 173 | """ 174 | def __init__( 175 | self, 176 | storage_data, 177 | parent=None, 178 | points_idx=None, 179 | dim=3, 180 | dim_scale=None, 181 | bucketsize=10 182 | ): 183 | # Store storage_data 184 | self.storage_data = storage_data 185 | # Parent node 186 | self.parent = parent 187 | # Hyperparam 188 | self.bucketsize = bucketsize 189 | self.dim = dim 190 | self.dim_scale = dim_scale if dim_scale is not None else np.ones(dim) 191 | # Node are initially leaf without children and with a bucket 192 | self.is_leaf = True 193 | # Points buckets storage_data for leaf node 194 | self._points_idx = np.zeros(bucketsize, dtype=np.intp) 195 | self._n_points = 0 196 | # Value usefull for not leaf node 197 | self.split_dim = None 198 | self.split_val = None 199 | self.left = None 200 | self.right = None 201 | # Bounds tracker 202 | self.lower = np.ones(dim, dtype=float) * np.inf 203 | self.upper = - np.ones(dim, dtype=float) * np.inf 204 | # Add given points 205 | self.add_points(points_idx) 206 | 207 | def _update_bounds(self, points_idx): 208 | # ensure non null points 209 | if not isinstance(points_idx, np.ndarray): 210 | if not points_idx: 211 | return 212 | points_idx = np.array(points_idx, dtype=np.intp) 213 | if points_idx.size is 0: 214 | return 215 | mini = np.min(self.storage_data[points_idx], axis=0) 216 | maxi = np.max(self.storage_data[points_idx], axis=0) 217 | min_update = mini < self.lower 218 | self.lower[min_update] = mini[min_update] 219 | max_update = maxi > self.upper 220 | self.upper[max_update] = maxi[max_update] 221 | 222 | def add_point(self, p_idx): 223 | self._update_bounds([p_idx]) 224 | if self.is_leaf: 225 | # The node is still a leaf, just add to bucket 226 | # We are sure there is enough space 227 | self._points_idx[self._n_points] = p_idx 228 | 
self._n_points += 1 229 | if self._n_points == self.bucketsize: 230 | # It is now full, transform node as non leaf 231 | self._create_children() 232 | else: 233 | # Kd split to add to children 234 | if self.storage_data[p_idx, self.split_dim] <= self.split_val: 235 | self.left.add_point(p_idx) 236 | else: 237 | self.right.add_point(p_idx) 238 | 239 | def add_points(self, points_idx): 240 | # ensure non null points 241 | if not isinstance(points_idx, np.ndarray): 242 | if not points_idx: 243 | return 244 | points_idx = np.array(points_idx, dtype=np.intp) 245 | if points_idx.size is 0: 246 | return 247 | 248 | if self.is_leaf: 249 | # We add the maximum we can to the bucket 250 | n = min(len(points_idx), self.bucketsize - self._n_points) 251 | batch = points_idx[:n] 252 | self._points_idx[self._n_points:self._n_points + n] = batch 253 | self._n_points += n 254 | self._update_bounds(batch) 255 | if self._n_points == self.bucketsize: 256 | # It is full, transform node as non leaf 257 | self._create_children() 258 | # Add eventual remaining points 259 | self.add_points(points_idx[n:]) 260 | else: 261 | # We add points to child given their position 262 | self._update_bounds(points_idx) 263 | infe = self.storage_data[points_idx, self.split_dim] <= self.split_val 264 | self.left.add_points(points_idx[infe]) 265 | self.right.add_points(points_idx[~infe]) 266 | 267 | def _create_children(self): 268 | assert self.is_leaf and self.bucketsize == self._n_points 269 | # The creation must appeaar only for full leaf (after an add) 270 | # At this points the bounds are the one of the bucket 271 | ranges = self.upper - self.lower 272 | split_dim = np.argmax(ranges * self.dim_scale) 273 | # No more a leaf, create attribute for non leaf 274 | self.is_leaf = False 275 | self.split_dim = split_dim 276 | self.split_val = self.lower[split_dim] + ranges[split_dim] / 2 277 | self.left = SBucketKDNode( 278 | self.storage_data, self, 279 | dim=self.dim, dim_scale=self.dim_scale, bucketsize=self.bucketsize 280 | ) 281 | self.right = SBucketKDNode( 282 | self.storage_data, self, 283 | dim=self.dim, dim_scale=self.dim_scale, bucketsize=self.bucketsize 284 | ) 285 | # Now diffuse bucket points in children and erase local bucket 286 | self.add_points(self._points_idx) 287 | self._points_idx = None 288 | 289 | def nearest_neighbour(self, query, dist_to_many, max_dist=None): 290 | if self.is_leaf: 291 | dists = dist_to_many( 292 | query, self.storage_data[self._points_idx[:self._n_points]] 293 | ) 294 | i_min = np.argmin(dists) 295 | if max_dist is None or dists[i_min] < max_dist: 296 | return dists[i_min], self._points_idx[i_min] 297 | return None, None 298 | else: 299 | cursor = self 300 | # Go down to the best leaf 301 | while not cursor.is_leaf: 302 | if query[cursor.split_dim] <= cursor.split_val: 303 | cursor = cursor.left 304 | else: 305 | cursor = cursor.right 306 | best_d, best_p = cursor.nearest_neighbour( 307 | query, dist_to_many, max_dist 308 | ) 309 | if best_d is not None: 310 | max_dist = best_d 311 | # Go up by recursively checking ambiguous split 312 | while cursor is not self: 313 | cursor = cursor.parent 314 | # check ambiguity to non coming child 315 | # Get nearest in the child if needed 316 | i = cursor.split_dim 317 | s = self.dim_scale[i] 318 | x = query[i] 319 | d, p = None, None 320 | if x <= cursor.split_val: 321 | # We come from left, check right 322 | if s * (cursor.right.lower[i] - x) < max_dist: 323 | # There is an ambiguity, check right 324 | d, p = cursor.right.nearest_neighbour( 325 | 
query, dist_to_many, max_dist 326 | ) 327 | else: 328 | # Same for right 329 | if s * (x - cursor.left.upper[i]) < max_dist: 330 | d, p = cursor.left.nearest_neighbour( 331 | query, dist_to_many, max_dist 332 | ) 333 | if d is not None: 334 | # We have found something better in ambiguity 335 | best_d, best_p = d, p 336 | max_dist = best_d 337 | 338 | return best_d, best_p 339 | -------------------------------------------------------------------------------- /utils/datastructures/mtree/__init__.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import utils.datastructures.mtree.functions as functions 4 | from utils.datastructures.mtree.heap_queue import HeapQueue 5 | 6 | 7 | _INFINITY = float("inf") 8 | _ItemWithDistances = namedtuple( 9 | '_ItemWithDistances', 'item, distance, min_distance' 10 | ) 11 | 12 | 13 | class _RootNodeReplacement(Exception): 14 | def __init__(self, new_root): 15 | super(_RootNodeReplacement, self).__init__(new_root) 16 | self.new_root = new_root 17 | 18 | 19 | class _SplitNodeReplacement(Exception): 20 | def __init__(self, new_nodes): 21 | super(_SplitNodeReplacement, self).__init__(new_nodes) 22 | self.new_nodes = new_nodes 23 | 24 | 25 | class _NodeUnderCapacity(Exception): 26 | pass 27 | 28 | 29 | class _IndexItem(object): 30 | def __init__(self, data): 31 | self.data = data 32 | self.radius = 0 33 | self.distance_to_parent = None 34 | 35 | def _check(self, mtree): 36 | self._check_data() 37 | self._check_radius() 38 | self._check_distance_to_parent() 39 | return 1 40 | 41 | def _check_data(self): 42 | assert self.data is not None 43 | 44 | def _check_radius(self): 45 | assert self.radius is not None 46 | assert self.radius >= 0 47 | 48 | def _check_distance_to_parent(self): 49 | assert not isinstance(self, _RootNodeTrait), self 50 | assert self.distance_to_parent is not None 51 | assert self.distance_to_parent >= 0 52 | 53 | 54 | class _Node(_IndexItem): 55 | 56 | def __init__(self, data): 57 | super(_Node, self).__init__(data) 58 | self.children = {} 59 | 60 | def add_data(self, data, distance, mtree): 61 | self.do_add_data(data, distance, mtree) 62 | self.check_max_capacity(mtree) 63 | 64 | def check_max_capacity(self, mtree): 65 | if len(self.children) > mtree.max_node_capacity: 66 | data_objects = frozenset(self.children.keys()) 67 | cached_distance_function = functions.make_cached_distance_function( 68 | mtree.distance_function 69 | ) 70 | 71 | (promoted_data1, partition1, 72 | promoted_data2, partition2) = mtree.split_function( 73 | data_objects, cached_distance_function 74 | ) 75 | 76 | split_node_repl_class = self.get_split_node_replacement_class() 77 | new_nodes = [] 78 | for promoted_data, partition in [(promoted_data1, partition1), 79 | (promoted_data2, partition2)]: 80 | new_node = split_node_repl_class(promoted_data) 81 | for data in partition: 82 | child = self.children[data] 83 | distance = cached_distance_function(promoted_data, data) 84 | new_node.add_child(child, distance, mtree) 85 | new_nodes.append(new_node) 86 | 87 | raise _SplitNodeReplacement(new_nodes) 88 | 89 | def remove_data(self, data, distance, mtree): 90 | self.do_remove_data(data, distance, mtree) 91 | if len(self.children) < self.get_min_capacity(mtree): 92 | raise _NodeUnderCapacity() 93 | 94 | def update_metrics(self, child, distance): 95 | child.distance_to_parent = distance 96 | self.update_radius(child) 97 | 98 | def update_radius(self, child): 99 | self.radius = max(self.radius, 
child.distance_to_parent + child.radius) 100 | 101 | def _check(self, mtree): 102 | super(_Node, self)._check(mtree) 103 | self._check_min_capacity(mtree) 104 | self._check_max_capacity(mtree) 105 | 106 | child_height = None 107 | for data, child in self.children.items(): 108 | assert child.data == data 109 | self._check_child_class(child) 110 | self._check_child_metrics(child, mtree) 111 | 112 | height = child._check(mtree) 113 | if child_height is None: 114 | child_height = height 115 | else: 116 | assert child_height == height 117 | 118 | return child_height + 1 119 | 120 | def _check_max_capacity(self, mtree): 121 | assert len(self.children) <= mtree.max_node_capacity 122 | 123 | def _check_child_class(self, child): 124 | expected_class = self._get_expected_child_class() 125 | assert isinstance(child, expected_class) 126 | 127 | def _check_child_metrics(self, child, mtree): 128 | dist = mtree.distance_function(child.data, self.data) 129 | assert child.distance_to_parent == dist, ( 130 | child.data, 131 | self.data, 132 | child.distance_to_parent, 133 | dist, 134 | abs(child.distance_to_parent - dist) 135 | ) 136 | assert child.distance_to_parent + child.radius <= self.radius 137 | 138 | 139 | class _RootNodeTrait(_Node): 140 | 141 | def _check_distance_to_parent(self): 142 | assert self.distance_to_parent is None 143 | 144 | 145 | class _NonRootNodeTrait(_Node): 146 | 147 | def get_min_capacity(self, mtree): 148 | return mtree.min_node_capacity 149 | 150 | def _check_min_capacity(self, mtree): 151 | assert len(self.children) >= mtree.min_node_capacity 152 | 153 | 154 | class _LeafNodeTrait(_Node): 155 | 156 | def do_add_data(self, data, distance, mtree): 157 | entry = _Entry(data) 158 | assert data not in self.children 159 | self.children[data] = entry 160 | assert data in self.children 161 | self.update_metrics(entry, distance) 162 | 163 | def add_child(self, child, distance, mtree): 164 | assert child.data not in self.children 165 | self.children[child.data] = child 166 | assert child.data in self.children 167 | self.update_metrics(child, distance) 168 | 169 | @staticmethod 170 | def get_split_node_replacement_class(): 171 | return _LeafNode 172 | 173 | def do_remove_data(self, data, distance, mtree): 174 | del self.children[data] 175 | 176 | @staticmethod 177 | def _get_expected_child_class(): 178 | return _Entry 179 | 180 | 181 | class _NonLeafNodeTrait(_Node): 182 | 183 | CandidateChild = namedtuple('CandidateChild', 'node, distance, metric') 184 | 185 | def do_add_data(self, data, distance, mtree): 186 | 187 | min_radius_increase_needed = self.CandidateChild(None, None, _INFINITY) 188 | nearest_distance = self.CandidateChild(None, None, _INFINITY) 189 | 190 | distances = mtree.distance_function( 191 | data, [child.data for child in self.children.values()] 192 | ) 193 | 194 | for distance, child in zip(distances, self.children.values()): 195 | if distance > child.radius: 196 | radius_increase = distance - child.radius 197 | if radius_increase < min_radius_increase_needed.metric: 198 | min_radius_increase_needed = self.CandidateChild( 199 | child, distance, radius_increase 200 | ) 201 | else: 202 | if distance < nearest_distance.metric: 203 | nearest_distance = self.CandidateChild( 204 | child, distance, distance 205 | ) 206 | 207 | if nearest_distance.node is not None: 208 | chosen = nearest_distance 209 | else: 210 | chosen = min_radius_increase_needed 211 | 212 | child = chosen.node 213 | try: 214 | child.add_data(data, chosen.distance, mtree) 215 | except _SplitNodeReplacement 
as e: 216 | assert len(e.new_nodes) == 2 217 | # Replace current child with new nodes 218 | del self.children[child.data] 219 | distances = mtree.distance_function( 220 | data, [new_child.data for new_child in e.new_nodes] 221 | ) 222 | for distance, new_child in zip(distances, e.new_nodes): 223 | self.add_child(new_child, distance, mtree) 224 | else: 225 | self.update_radius(child) 226 | 227 | def add_child(self, new_child, distance, mtree): 228 | new_children = [(new_child, distance)] 229 | while new_children: 230 | new_child, distance = new_children.pop() 231 | 232 | if new_child.data not in self.children: 233 | self.children[new_child.data] = new_child 234 | self.update_metrics(new_child, distance) 235 | else: 236 | existing_child = self.children[new_child.data] 237 | assert existing_child.data == new_child.data 238 | 239 | # Transfer the _children_ of the new_child to the existing_child 240 | for grandchild in new_child.children.values(): 241 | existing_child.add_child(grandchild, grandchild.distance_to_parent, mtree) 242 | 243 | try: 244 | existing_child.check_max_capacity(mtree) 245 | except _SplitNodeReplacement as e: 246 | del self.children[new_child.data] 247 | distances = mtree.distance_function( 248 | self.data, [new_node.data for new_node in e.new_nodes] 249 | ) 250 | for distance, new_node in zip(distances, e.new_nodes): 251 | new_children.append((new_node, distance)) 252 | 253 | @staticmethod 254 | def get_split_node_replacement_class(): 255 | return _InternalNode 256 | 257 | def do_remove_data(self, data, distance, mtree): 258 | for child in self.children.values(): 259 | if abs(distance - child.distance_to_parent) <= child.radius: # TODO: confirm 260 | distance_to_child = mtree.distance_function(data, child.data) 261 | if distance_to_child <= child.radius: 262 | try: 263 | child.remove_data(data, distance_to_child, mtree) 264 | except KeyError: 265 | # If KeyError was raised, then the data was not found in the child 266 | pass 267 | except _NodeUnderCapacity: 268 | expanded_child = self.balance_children(child, mtree) 269 | self.update_radius(expanded_child) 270 | return 271 | else: 272 | self.update_radius(child) 273 | return 274 | raise KeyError() 275 | 276 | def balance_children(self, the_child, mtree): 277 | # Tries to find another_child which can donate a grandchild to the_child. 
278 | 279 | nearest_donor = None 280 | distance_nearest_donor = _INFINITY 281 | 282 | nearest_merge_candidate = None 283 | distance_nearest_merge_candidate = _INFINITY 284 | 285 | distances = mtree.distance_function( 286 | the_child.data, 287 | [another_child.data for another_child in (child for child in self.children.values() if child is not the_child)] 288 | ) 289 | 290 | for distance, another_child in zip(distances, (child for child in self.children.values() if child is not the_child)): 291 | if len(another_child.children) > another_child.get_min_capacity(mtree): 292 | if distance < distance_nearest_donor: 293 | distance_nearest_donor = distance 294 | nearest_donor = another_child 295 | else: 296 | if distance < distance_nearest_merge_candidate: 297 | distance_nearest_merge_candidate = distance 298 | nearest_merge_candidate = another_child 299 | 300 | if nearest_donor is None: 301 | # Merge 302 | distances = mtree.distance_function( 303 | nearest_merge_candidate.data, 304 | [grandchild.data for grandchild in the_child.children.values()] 305 | 306 | ) 307 | for distance, grandchild in zip(distances, the_child.children.values()): 308 | nearest_merge_candidate.add_child(grandchild, distance, mtree) 309 | 310 | del self.children[the_child.data] 311 | return nearest_merge_candidate 312 | else: 313 | # Donate 314 | # Look for the nearest grandchild 315 | nearest_grandchild_distance = _INFINITY 316 | distances = mtree.distance_function( 317 | the_child.data, 318 | [grandchild.data for grandchild in nearest_donor.children.values()] 319 | ) 320 | for distance, grandchild in zip(distances, nearest_donor.children.values()): 321 | if distance < nearest_grandchild_distance: 322 | nearest_grandchild_distance = distance 323 | nearest_grandchild = grandchild 324 | 325 | del nearest_donor.children[nearest_grandchild.data] 326 | the_child.add_child(nearest_grandchild, nearest_grandchild_distance, mtree) 327 | return the_child 328 | 329 | @staticmethod 330 | def _get_expected_child_class(): 331 | return (_InternalNode, _LeafNode) 332 | 333 | 334 | class _RootLeafNode(_RootNodeTrait, _LeafNodeTrait): 335 | 336 | def remove_data(self, data, distance, mtree): 337 | try: 338 | super(_RootLeafNode, self).remove_data(data, distance, mtree) 339 | except _NodeUnderCapacity: 340 | assert len(self.children) == 0 341 | raise _RootNodeReplacement(None) 342 | 343 | @staticmethod 344 | def get_min_capacity(mtree): 345 | return 1 346 | 347 | def _check_min_capacity(self, mtree): 348 | assert len(self.children) >= 1 349 | 350 | 351 | class _RootNode(_RootNodeTrait, _NonLeafNodeTrait): 352 | 353 | def remove_data(self, data, distance, mtree): 354 | try: 355 | super(_RootNode, self).remove_data(data, distance, mtree) 356 | except _NodeUnderCapacity: 357 | # Promote the only child to root 358 | (the_child,) = self.children.values() 359 | if isinstance(the_child, _InternalNode): 360 | new_root_class = _RootNode 361 | else: 362 | assert isinstance(the_child, _LeafNode) 363 | new_root_class = _RootLeafNode 364 | 365 | new_root = new_root_class(the_child.data) 366 | distances = mtree.distance_function( 367 | new_root.data, 368 | [grandchild.data for grandchild in the_child.children.values()] 369 | ) 370 | 371 | for distance, grandchild in zip(distances, the_child.children.values()): 372 | new_root.add_child(grandchild, distance, mtree) 373 | 374 | raise _RootNodeReplacement(new_root) 375 | 376 | @staticmethod 377 | def get_min_capacity(mtree): 378 | return 2 379 | 380 | def _check_min_capacity(self, mtree): 381 | assert 
len(self.children) >= 2 382 | 383 | 384 | class _InternalNode(_NonRootNodeTrait, _NonLeafNodeTrait): 385 | pass 386 | 387 | 388 | class _LeafNode(_NonRootNodeTrait, _LeafNodeTrait): 389 | pass 390 | 391 | 392 | class _Entry(_IndexItem): 393 | pass 394 | 395 | 396 | class MTree(object): 397 | """ 398 | A data structure for indexing objects based on their proximity. 399 | 400 | The data objects must be any hashable object and the support functions 401 | (distance and split functions) must understand them. 402 | 403 | See http://en.wikipedia.org/wiki/M-tree 404 | """ 405 | 406 | ResultItem = namedtuple('ResultItem', 'data, distance') 407 | 408 | def __init__( 409 | self, 410 | distance_function, 411 | min_node_capacity=50, 412 | max_node_capacity=None, 413 | split_function=functions.make_split_function( 414 | functions.random_promotion, functions.balanced_partition 415 | ) 416 | ): 417 | """ 418 | Creates an M-Tree. 419 | 420 | The argument min_node_capacity must be at least 2. 421 | The argument max_node_capacity should be at least 2*min_node_capacity-1. 422 | The optional argument distance_function must be a function which calculates 423 | the distance between two data objects. 424 | The optional argument split_function must be a function which chooses two 425 | data objects and then partitions the set of data into two subsets 426 | according to the chosen objects. Its arguments are the set of data objects 427 | and the distance_function. Must return a sequence with the following four values: 428 | - First chosen data object. 429 | - Subset with at least [min_node_capacity] objects based on the first 430 | chosen data object. Must contain the first chosen data object. 431 | - Second chosen data object. 432 | - Subset with at least [min_node_capacity] objects based on the second 433 | chosen data object. Must contain the second chosen data object. 434 | """ 435 | if min_node_capacity < 2: 436 | raise ValueError("min_node_capacity must be at least 2") 437 | if max_node_capacity is None: 438 | max_node_capacity = 2 * min_node_capacity - 1 439 | if max_node_capacity <= min_node_capacity: 440 | raise ValueError("max_node_capacity must be greater than min_node_capacity") 441 | 442 | self.min_node_capacity = min_node_capacity 443 | self.max_node_capacity = max_node_capacity 444 | self.distance_function = distance_function 445 | self.split_function = split_function 446 | self.root = None 447 | 448 | def add(self, data): 449 | """ 450 | Adds and indexes an object. 451 | 452 | The object must not currently already be indexed! 453 | """ 454 | if self.root is None: 455 | self.root = _RootLeafNode(data) 456 | self.root.add_data(data, 0, self) 457 | else: 458 | distance = self.distance_function(data, self.root.data) 459 | try: 460 | self.root.add_data(data, distance, self) 461 | except _SplitNodeReplacement as e: 462 | assert len(e.new_nodes) == 2 463 | self.root = _RootNode(self.root.data) 464 | distances = self.distance_function( 465 | self.root.data, 466 | [new_node.data for new_node in e.new_nodes] 467 | ) 468 | for distance, new_node in zip(distances, e.new_nodes): 469 | self.root.add_child(new_node, distance, self) 470 | 471 | add_point = add 472 | 473 | def remove(self, data): 474 | """ 475 | Removes an object from the index. 
476 | """ 477 | if self.root is None: 478 | raise KeyError() 479 | 480 | distance_to_root = self.distance_function(data, self.root.data) 481 | try: 482 | self.root.remove_data(data, distance_to_root, self) 483 | except _RootNodeReplacement as e: 484 | self.root = e.new_root 485 | 486 | def get_nearest(self, query_data, range=_INFINITY, limit=_INFINITY): 487 | """ 488 | Returns an iterator on the indexed data nearest to the query_data. The 489 | returned items are tuples containing the data and its distance to the 490 | query_data, in increasing distance order. The results can be limited by 491 | the range (maximum distance from the query_data) and limit arguments. 492 | """ 493 | if self.root is None: 494 | # No indexed data! 495 | return 496 | 497 | distance = self.distance_function(query_data, self.root.data) 498 | min_distance = max(distance - self.root.radius, 0) 499 | 500 | pending_queue = HeapQueue( 501 | content=[_ItemWithDistances(item=self.root, distance=distance, min_distance=min_distance)], 502 | key=lambda iwd: iwd.min_distance, 503 | ) 504 | 505 | nearest_queue = HeapQueue(key=lambda iwd: iwd.distance) 506 | 507 | yielded_count = 0 508 | 509 | while pending_queue: 510 | pending = pending_queue.pop() 511 | 512 | node = pending.item 513 | assert isinstance(node, _Node) 514 | 515 | distances = self.distance_function( 516 | query_data, 517 | [child.data for child in node.children.values()] 518 | ) 519 | for child_distance, child in zip(distances, node.children.values()): 520 | if abs(pending.distance - child.distance_to_parent) - child.radius <= range: 521 | child_min_distance = max(child_distance - child.radius, 0) 522 | if child_min_distance <= range: 523 | iwd = _ItemWithDistances(item=child, distance=child_distance, min_distance=child_min_distance) 524 | if isinstance(child, _Entry): 525 | nearest_queue.push(iwd) 526 | else: 527 | pending_queue.push(iwd) 528 | 529 | # Tries to yield known results so far 530 | if pending_queue: 531 | next_pending = pending_queue.head() 532 | next_pending_min_distance = next_pending.min_distance 533 | else: 534 | next_pending_min_distance = _INFINITY 535 | 536 | while nearest_queue: 537 | next_nearest = nearest_queue.head() 538 | assert isinstance(next_nearest, _ItemWithDistances) 539 | if next_nearest.distance <= next_pending_min_distance: 540 | _ = nearest_queue.pop() 541 | assert _ is next_nearest 542 | 543 | yield self.ResultItem(data=next_nearest.item.data, distance=next_nearest.distance) 544 | yielded_count += 1 545 | if yielded_count >= limit: 546 | # Limit reached 547 | return 548 | else: 549 | break 550 | 551 | def nearest_neighbour(self, point): 552 | return next(self.get_nearest(point, limit=1)) 553 | 554 | def _check(self): 555 | if self.root is not None: 556 | self.root._check(self) 557 | -------------------------------------------------------------------------------- /utils/datastructures/mtree/faster.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import utils.datastructures.mtree.functions as functions 4 | from utils.datastructures.mtree.heap_queue import HeapQueue 5 | 6 | 7 | _INFINITY = float("inf") 8 | _ItemWithDistances = namedtuple( 9 | '_ItemWithDistances', 'item, distance, min_distance' 10 | ) 11 | 12 | 13 | class _RootNodeReplacement(Exception): 14 | def __init__(self, new_root): 15 | super(_RootNodeReplacement, self).__init__(new_root) 16 | self.new_root = new_root 17 | 18 | 19 | class _SplitNodeReplacement(Exception): 20 | def 
__init__(self, new_nodes): 21 | super(_SplitNodeReplacement, self).__init__(new_nodes) 22 | self.new_nodes = new_nodes 23 | 24 | 25 | class _NodeUnderCapacity(Exception): 26 | pass 27 | 28 | 29 | class _IndexItem(object): 30 | def __init__(self, data): 31 | self.data = data 32 | self.radius = 0 33 | self.distance_to_parent = None 34 | 35 | def _check(self, mtree): 36 | self._check_data() 37 | self._check_radius() 38 | self._check_distance_to_parent() 39 | return 1 40 | 41 | def _check_data(self): 42 | assert self.data is not None 43 | 44 | def _check_radius(self): 45 | assert self.radius is not None 46 | assert self.radius >= 0 47 | 48 | def _check_distance_to_parent(self): 49 | assert not isinstance(self, _RootNodeTrait), self 50 | assert self.distance_to_parent is not None 51 | assert self.distance_to_parent >= 0 52 | 53 | 54 | class _Node(_IndexItem): 55 | 56 | def __init__(self, data): 57 | super(_Node, self).__init__(data) 58 | self.children = {} 59 | 60 | def add_data(self, data, distance, mtree): 61 | self.do_add_data(data, distance, mtree) 62 | self.check_max_capacity(mtree) 63 | 64 | def check_max_capacity(self, mtree): 65 | if len(self.children) > mtree.max_node_capacity: 66 | data_objects = frozenset(self.children.keys()) 67 | cached_distance_function = functions.make_cached_distance_function( 68 | mtree.distance_function 69 | ) 70 | 71 | (promoted_data1, partition1, 72 | promoted_data2, partition2) = mtree.split_function( 73 | data_objects, cached_distance_function 74 | ) 75 | 76 | split_node_repl_class = self.get_split_node_replacement_class() 77 | new_nodes = [] 78 | for promoted_data, partition in [(promoted_data1, partition1), 79 | (promoted_data2, partition2)]: 80 | new_node = split_node_repl_class(promoted_data) 81 | for data in partition: 82 | child = self.children[data] 83 | distance = cached_distance_function(promoted_data, data) 84 | new_node.add_child(child, distance, mtree) 85 | new_nodes.append(new_node) 86 | 87 | raise _SplitNodeReplacement(new_nodes) 88 | 89 | def remove_data(self, data, distance, mtree): 90 | self.do_remove_data(data, distance, mtree) 91 | if len(self.children) < self.get_min_capacity(mtree): 92 | raise _NodeUnderCapacity() 93 | 94 | def update_metrics(self, child, distance): 95 | child.distance_to_parent = distance 96 | self.update_radius(child) 97 | 98 | def update_radius(self, child): 99 | self.radius = max(self.radius, child.distance_to_parent + child.radius) 100 | 101 | def _check(self, mtree): 102 | super(_Node, self)._check(mtree) 103 | self._check_min_capacity(mtree) 104 | self._check_max_capacity(mtree) 105 | 106 | child_height = None 107 | for data, child in self.children.items(): 108 | assert child.data == data 109 | self._check_child_class(child) 110 | self._check_child_metrics(child, mtree) 111 | 112 | height = child._check(mtree) 113 | if child_height is None: 114 | child_height = height 115 | else: 116 | assert child_height == height 117 | 118 | return child_height + 1 119 | 120 | def _check_max_capacity(self, mtree): 121 | assert len(self.children) <= mtree.max_node_capacity 122 | 123 | def _check_child_class(self, child): 124 | expected_class = self._get_expected_child_class() 125 | assert isinstance(child, expected_class) 126 | 127 | def _check_child_metrics(self, child, mtree): 128 | dist = mtree.distance_function(child.data, self.data) 129 | assert child.distance_to_parent == dist, ( 130 | child.data, 131 | self.data, 132 | child.distance_to_parent, 133 | dist, 134 | abs(child.distance_to_parent - dist) 135 | ) 136 | assert 
child.distance_to_parent + child.radius <= self.radius 137 | 138 | 139 | class _RootNodeTrait(_Node): 140 | 141 | def _check_distance_to_parent(self): 142 | assert self.distance_to_parent is None 143 | 144 | 145 | class _NonRootNodeTrait(_Node): 146 | 147 | def get_min_capacity(self, mtree): 148 | return mtree.min_node_capacity 149 | 150 | def _check_min_capacity(self, mtree): 151 | assert len(self.children) >= mtree.min_node_capacity 152 | 153 | 154 | class _LeafNodeTrait(_Node): 155 | 156 | def do_add_data(self, data, distance, mtree): 157 | entry = _Entry(data) 158 | assert data not in self.children 159 | self.children[data] = entry 160 | assert data in self.children 161 | self.update_metrics(entry, distance) 162 | 163 | def add_child(self, child, distance, mtree): 164 | assert child.data not in self.children 165 | self.children[child.data] = child 166 | assert child.data in self.children 167 | self.update_metrics(child, distance) 168 | 169 | @staticmethod 170 | def get_split_node_replacement_class(): 171 | return _LeafNode 172 | 173 | def do_remove_data(self, data, distance, mtree): 174 | del self.children[data] 175 | 176 | @staticmethod 177 | def _get_expected_child_class(): 178 | return _Entry 179 | 180 | 181 | class _NonLeafNodeTrait(_Node): 182 | 183 | CandidateChild = namedtuple('CandidateChild', 'node, distance, metric') 184 | 185 | def do_add_data(self, data, distance, mtree): 186 | 187 | min_radius_increase_needed = self.CandidateChild(None, None, _INFINITY) 188 | nearest_distance = self.CandidateChild(None, None, _INFINITY) 189 | 190 | distances = mtree.distance_function( 191 | data, [child.data for child in self.children.values()] 192 | ) 193 | 194 | for distance, child in zip(distances, self.children.values()): 195 | if distance > child.radius: 196 | radius_increase = distance - child.radius 197 | if radius_increase < min_radius_increase_needed.metric: 198 | min_radius_increase_needed = self.CandidateChild( 199 | child, distance, radius_increase 200 | ) 201 | else: 202 | if distance < nearest_distance.metric: 203 | nearest_distance = self.CandidateChild( 204 | child, distance, distance 205 | ) 206 | 207 | if nearest_distance.node is not None: 208 | chosen = nearest_distance 209 | else: 210 | chosen = min_radius_increase_needed 211 | 212 | child = chosen.node 213 | try: 214 | child.add_data(data, chosen.distance, mtree) 215 | except _SplitNodeReplacement as e: 216 | assert len(e.new_nodes) == 2 217 | # Replace current child with new nodes 218 | del self.children[child.data] 219 | distances = mtree.distance_function( 220 | data, [new_child.data for new_child in e.new_nodes] 221 | ) 222 | for distance, new_child in zip(distances, e.new_nodes): 223 | self.add_child(new_child, distance, mtree) 224 | else: 225 | self.update_radius(child) 226 | 227 | def add_child(self, new_child, distance, mtree): 228 | new_children = [(new_child, distance)] 229 | while new_children: 230 | new_child, distance = new_children.pop() 231 | 232 | if new_child.data not in self.children: 233 | self.children[new_child.data] = new_child 234 | self.update_metrics(new_child, distance) 235 | else: 236 | existing_child = self.children[new_child.data] 237 | assert existing_child.data == new_child.data 238 | 239 | # Transfer the _children_ of the new_child to the existing_child 240 | for grandchild in new_child.children.values(): 241 | existing_child.add_child(grandchild, grandchild.distance_to_parent, mtree) 242 | 243 | try: 244 | existing_child.check_max_capacity(mtree) 245 | except _SplitNodeReplacement as e: 
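                    # The merged-into child overflowed and split: drop it and
                    # re-queue the two replacement nodes, with distances
                    # recomputed from this node's routing object (self.data).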
246 | del self.children[new_child.data] 247 | distances = mtree.distance_function( 248 | self.data, [new_node.data for new_node in e.new_nodes] 249 | ) 250 | for distance, new_node in zip(distances, e.new_nodes): 251 | new_children.append((new_node, distance)) 252 | 253 | @staticmethod 254 | def get_split_node_replacement_class(): 255 | return _InternalNode 256 | 257 | def do_remove_data(self, data, distance, mtree): 258 | for child in self.children.values(): 259 | if abs(distance - child.distance_to_parent) <= child.radius: # TODO: confirm 260 | distance_to_child = mtree.distance_function(data, child.data) 261 | if distance_to_child <= child.radius: 262 | try: 263 | child.remove_data(data, distance_to_child, mtree) 264 | except KeyError: 265 | # If KeyError was raised, then the data was not found in the child 266 | pass 267 | except _NodeUnderCapacity: 268 | expanded_child = self.balance_children(child, mtree) 269 | self.update_radius(expanded_child) 270 | return 271 | else: 272 | self.update_radius(child) 273 | return 274 | raise KeyError() 275 | 276 | def balance_children(self, the_child, mtree): 277 | # Tries to find another_child which can donate a grandchild to the_child. 278 | 279 | nearest_donor = None 280 | distance_nearest_donor = _INFINITY 281 | 282 | nearest_merge_candidate = None 283 | distance_nearest_merge_candidate = _INFINITY 284 | 285 | distances = mtree.distance_function( 286 | the_child.data, 287 | [another_child.data for another_child in (child for child in self.children.values() if child is not the_child)] 288 | ) 289 | 290 | for distance, another_child in zip(distances, (child for child in self.children.values() if child is not the_child)): 291 | if len(another_child.children) > another_child.get_min_capacity(mtree): 292 | if distance < distance_nearest_donor: 293 | distance_nearest_donor = distance 294 | nearest_donor = another_child 295 | else: 296 | if distance < distance_nearest_merge_candidate: 297 | distance_nearest_merge_candidate = distance 298 | nearest_merge_candidate = another_child 299 | 300 | if nearest_donor is None: 301 | # Merge 302 | distances = mtree.distance_function( 303 | nearest_merge_candidate.data, 304 | [grandchild.data for grandchild in the_child.children.values()] 305 | 306 | ) 307 | for distance, grandchild in zip(distances, the_child.children.values()): 308 | nearest_merge_candidate.add_child(grandchild, distance, mtree) 309 | 310 | del self.children[the_child.data] 311 | return nearest_merge_candidate 312 | else: 313 | # Donate 314 | # Look for the nearest grandchild 315 | nearest_grandchild_distance = _INFINITY 316 | distances = mtree.distance_function( 317 | the_child.data, 318 | [grandchild.data for grandchild in nearest_donor.children.values()] 319 | ) 320 | for distance, grandchild in zip(distances, nearest_donor.children.values()): 321 | if distance < nearest_grandchild_distance: 322 | nearest_grandchild_distance = distance 323 | nearest_grandchild = grandchild 324 | 325 | del nearest_donor.children[nearest_grandchild.data] 326 | the_child.add_child(nearest_grandchild, nearest_grandchild_distance, mtree) 327 | return the_child 328 | 329 | @staticmethod 330 | def _get_expected_child_class(): 331 | return (_InternalNode, _LeafNode) 332 | 333 | 334 | class _RootLeafNode(_RootNodeTrait, _LeafNodeTrait): 335 | 336 | def remove_data(self, data, distance, mtree): 337 | try: 338 | super(_RootLeafNode, self).remove_data(data, distance, mtree) 339 | except _NodeUnderCapacity: 340 | assert len(self.children) == 0 341 | raise 
_RootNodeReplacement(None) 342 | 343 | @staticmethod 344 | def get_min_capacity(mtree): 345 | return 1 346 | 347 | def _check_min_capacity(self, mtree): 348 | assert len(self.children) >= 1 349 | 350 | 351 | class _RootNode(_RootNodeTrait, _NonLeafNodeTrait): 352 | 353 | def remove_data(self, data, distance, mtree): 354 | try: 355 | super(_RootNode, self).remove_data(data, distance, mtree) 356 | except _NodeUnderCapacity: 357 | # Promote the only child to root 358 | (the_child,) = self.children.values() 359 | if isinstance(the_child, _InternalNode): 360 | new_root_class = _RootNode 361 | else: 362 | assert isinstance(the_child, _LeafNode) 363 | new_root_class = _RootLeafNode 364 | 365 | new_root = new_root_class(the_child.data) 366 | distances = mtree.distance_function( 367 | new_root.data, 368 | [grandchild.data for grandchild in the_child.children.values()] 369 | ) 370 | 371 | for distance, grandchild in zip(distances, the_child.children.values()): 372 | new_root.add_child(grandchild, distance, mtree) 373 | 374 | raise _RootNodeReplacement(new_root) 375 | 376 | @staticmethod 377 | def get_min_capacity(mtree): 378 | return 2 379 | 380 | def _check_min_capacity(self, mtree): 381 | assert len(self.children) >= 2 382 | 383 | 384 | class _InternalNode(_NonRootNodeTrait, _NonLeafNodeTrait): 385 | pass 386 | 387 | 388 | class _LeafNode(_NonRootNodeTrait, _LeafNodeTrait): 389 | pass 390 | 391 | 392 | class _Entry(_IndexItem): 393 | pass 394 | 395 | 396 | class MTree(object): 397 | """ 398 | A data structure for indexing objects based on their proximity. 399 | 400 | The data objects must be any hashable object and the support functions 401 | (distance and split functions) must understand them. 402 | 403 | See http://en.wikipedia.org/wiki/M-tree 404 | """ 405 | 406 | ResultItem = namedtuple('ResultItem', 'data, distance') 407 | 408 | def __init__( 409 | self, 410 | distance_function, 411 | min_node_capacity=20, 412 | max_node_capacity=None, 413 | split_function=functions.make_split_function( 414 | functions.random_promotion, functions.balanced_partition 415 | ) 416 | ): 417 | """ 418 | Creates an M-Tree. 419 | 420 | The argument min_node_capacity must be at least 2. 421 | The argument max_node_capacity should be at least 2*min_node_capacity-1. 422 | The optional argument distance_function must be a function which calculates 423 | the distance between two data objects. 424 | The optional argument split_function must be a function which chooses two 425 | data objects and then partitions the set of data into two subsets 426 | according to the chosen objects. Its arguments are the set of data objects 427 | and the distance_function. Must return a sequence with the following four values: 428 | - First chosen data object. 429 | - Subset with at least [min_node_capacity] objects based on the first 430 | chosen data object. Must contain the first chosen data object. 431 | - Second chosen data object. 432 | - Subset with at least [min_node_capacity] objects based on the second 433 | chosen data object. Must contain the second chosen data object. 
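
        Note that the tree calls distance_function in two forms: with a single
        data object as second argument (expecting one distance back), and with
        a list of data objects (expecting a sequence of distances back). The
        provided function must support both call forms.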
434 | """ 435 | if min_node_capacity < 2: 436 | raise ValueError("min_node_capacity must be at least 2") 437 | if max_node_capacity is None: 438 | max_node_capacity = 2 * min_node_capacity - 1 439 | if max_node_capacity <= min_node_capacity: 440 | raise ValueError("max_node_capacity must be greater than min_node_capacity") 441 | 442 | self.min_node_capacity = min_node_capacity 443 | self.max_node_capacity = max_node_capacity 444 | self.distance_function = distance_function 445 | self.split_function = split_function 446 | self.root = None 447 | 448 | def add(self, data): 449 | """ 450 | Adds and indexes an object. 451 | 452 | The object must not currently already be indexed! 453 | """ 454 | if self.root is None: 455 | self.root = _RootLeafNode(data) 456 | self.root.add_data(data, 0, self) 457 | else: 458 | distance = self.distance_function(data, self.root.data) 459 | try: 460 | self.root.add_data(data, distance, self) 461 | except _SplitNodeReplacement as e: 462 | assert len(e.new_nodes) == 2 463 | self.root = _RootNode(self.root.data) 464 | distances = self.distance_function( 465 | self.root.data, 466 | [new_node.data for new_node in e.new_nodes] 467 | ) 468 | for distance, new_node in zip(distances, e.new_nodes): 469 | self.root.add_child(new_node, distance, self) 470 | 471 | add_point = add 472 | 473 | def remove(self, data): 474 | """ 475 | Removes an object from the index. 476 | """ 477 | if self.root is None: 478 | raise KeyError() 479 | 480 | distance_to_root = self.distance_function(data, self.root.data) 481 | try: 482 | self.root.remove_data(data, distance_to_root, self) 483 | except _RootNodeReplacement as e: 484 | self.root = e.new_root 485 | 486 | def get_nearest(self, query_data, range=_INFINITY, limit=_INFINITY): 487 | """ 488 | Returns an iterator on the indexed data nearest to the query_data. The 489 | returned items are tuples containing the data and its distance to the 490 | query_data, in increasing distance order. The results can be limited by 491 | the range (maximum distance from the query_data) and limit arguments. 492 | """ 493 | if self.root is None: 494 | # No indexed data! 
495 | return 496 | 497 | distance = self.distance_function(query_data, self.root.data) 498 | min_distance = max(distance - self.root.radius, 0) 499 | 500 | pending_queue = HeapQueue( 501 | content=[_ItemWithDistances(item=self.root, distance=distance, min_distance=min_distance)], 502 | key=lambda iwd: iwd.min_distance, 503 | ) 504 | 505 | nearest_queue = HeapQueue(key=lambda iwd: iwd.distance) 506 | 507 | yielded_count = 0 508 | 509 | while pending_queue: 510 | pending = pending_queue.pop() 511 | 512 | node = pending.item 513 | assert isinstance(node, _Node) 514 | 515 | distances = self.distance_function( 516 | query_data, 517 | [child.data for child in node.children.values()] 518 | ) 519 | for child_distance, child in zip(distances, node.children.values()): 520 | if abs(pending.distance - child.distance_to_parent) - child.radius <= range: 521 | child_min_distance = max(child_distance - child.radius, 0) 522 | if child_min_distance <= range: 523 | iwd = _ItemWithDistances(item=child, distance=child_distance, min_distance=child_min_distance) 524 | if isinstance(child, _Entry): 525 | nearest_queue.push(iwd) 526 | else: 527 | pending_queue.push(iwd) 528 | 529 | # Tries to yield known results so far 530 | if pending_queue: 531 | next_pending = pending_queue.head() 532 | next_pending_min_distance = next_pending.min_distance 533 | else: 534 | next_pending_min_distance = _INFINITY 535 | 536 | while nearest_queue: 537 | next_nearest = nearest_queue.head() 538 | assert isinstance(next_nearest, _ItemWithDistances) 539 | if next_nearest.distance <= next_pending_min_distance: 540 | _ = nearest_queue.pop() 541 | assert _ is next_nearest 542 | 543 | yield self.ResultItem(data=next_nearest.item.data, distance=next_nearest.distance) 544 | yielded_count += 1 545 | if yielded_count >= limit: 546 | # Limit reached 547 | return 548 | else: 549 | break 550 | 551 | def nearest_neighbour(self, point): 552 | return next(self.get_nearest(point, limit=1)) 553 | 554 | def _check(self): 555 | if self.root is not None: 556 | self.root._check(self) 557 | -------------------------------------------------------------------------------- /utils/datastructures/mtree/functions.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from utils.datastructures.mtree.heap_queue import HeapQueue 4 | 5 | 6 | def random_promotion(data_objects, distance_function): 7 | """ 8 | Randomly chooses two objects to be promoted. 9 | """ 10 | data_objects = list(data_objects) 11 | return random.sample(data_objects, 2) 12 | 13 | 14 | def balanced_partition( 15 | promoted_data1, promoted_data2, data_objects, distance_function 16 | ): 17 | partition1 = set() 18 | partition2 = set() 19 | 20 | queue1 = HeapQueue( 21 | data_objects, 22 | key=lambda data: distance_function(data, promoted_data1) 23 | ) 24 | queue2 = HeapQueue( 25 | data_objects, 26 | key=lambda data: distance_function(data, promoted_data2) 27 | ) 28 | 29 | while queue1 or queue2: 30 | while queue1: 31 | data = queue1.pop() 32 | if data not in partition2: 33 | partition1.add(data) 34 | break 35 | 36 | while queue2: 37 | data = queue2.pop() 38 | if data not in partition1: 39 | partition2.add(data) 40 | break 41 | 42 | return partition1, partition2 43 | 44 | 45 | def make_split_function(promotion_function, partition_function): 46 | """ 47 | Creates a splitting function. 
48 | The parameters must be callable objects: 49 | - promotion_function(data_objects, distance_function) 50 | Must return two objects chosen from the data_objects argument. 51 | - partition_function(promoted_data1, promoted_data2, data_objects, distance_function) 52 | Must return a sequence with two iterable objects containing a partition 53 | of the data_objects. The promoted_data1 and promoted_data2 arguments 54 | should be used as partitioning criteria and must be contained on the 55 | corresponding iterable subsets. 56 | """ 57 | def split_function(data_objects, distance_function): 58 | promoted_data1, promoted_data2 = promotion_function( 59 | data_objects, distance_function 60 | ) 61 | partition1, partition2 = partition_function( 62 | promoted_data1, promoted_data2, data_objects, distance_function 63 | ) 64 | 65 | return promoted_data1, partition1, promoted_data2, partition2 66 | return split_function 67 | 68 | 69 | def make_cached_distance_function(distance_function): 70 | cache = {} 71 | 72 | def cached_distance_function(data1, data2): 73 | try: 74 | distance = cache[data1][data2] 75 | except KeyError: 76 | distance = distance_function(data1, data2) 77 | 78 | if data1 in cache: 79 | cache[data1][data2] = distance 80 | else: 81 | cache[data1] = {data2: distance} 82 | 83 | if data2 in cache: 84 | cache[data2][data1] = distance 85 | else: 86 | cache[data2] = {data1: distance} 87 | 88 | return distance 89 | 90 | cached_distance_function.cache = cache 91 | 92 | return cached_distance_function 93 | -------------------------------------------------------------------------------- /utils/datastructures/mtree/heap_queue.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | 4 | _HeapItem = namedtuple('_HeapItem', 'k, value') 5 | 6 | class HeapQueue(object): 7 | 8 | def __init__(self, content=(), key=lambda x: x, max=False): 9 | if max: 10 | self.key = lambda x: -key(x) 11 | else: 12 | self.key = key 13 | self._items = [_HeapItem(self.key(value), value) for value in content] 14 | self.heapify() 15 | 16 | def _items_less_than(self, base, other): 17 | return self._items[base].k < self._items[other].k 18 | 19 | def _swap_items(self, base, other): 20 | self._items[base], self._items[other] = self._items[other], self._items[base] 21 | 22 | def _make_heap(self, i): 23 | smallest = i 24 | 25 | left = 2 * i + 1 26 | if left < len(self._items) and self._items_less_than(left, smallest): 27 | smallest = left 28 | 29 | right = 2 * i + 2 30 | if right < len(self._items) and self._items_less_than(right, smallest): 31 | smallest = right 32 | 33 | if smallest != i: 34 | self._swap_items(i, smallest) 35 | self._make_heap(smallest) 36 | 37 | def heapify(self): 38 | for i in range(len(self._items) // 2, -1, -1): 39 | self._make_heap(i) 40 | 41 | def head(self): 42 | return self._items[0].value 43 | 44 | def push(self, value): 45 | i = len(self._items) 46 | new_item = _HeapItem(self.key(value), value) 47 | self._items.append(new_item) 48 | while i > 0: 49 | p = int((i - 1) // 2) 50 | if self._items_less_than(p, i): 51 | break 52 | self._swap_items(i, p) 53 | i = p 54 | 55 | def pop(self): 56 | popped = self._items[0].value 57 | self._items[0] = self._items[-1] 58 | self._items.pop(-1) 59 | self._make_heap(0) 60 | return popped 61 | 62 | def pushpop(self, value): 63 | k = self.key(value) 64 | if k <= self._items[0].k: 65 | return value 66 | else: 67 | popped = self._items[0].value 68 | self._items[0] = _HeapItem(k, value) 69 | 
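            # Sift the new root down to restore the heap property.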
self._make_heap(0)
70 |             return popped
71 | 
72 |     def __len__(self):
73 |         return len(self._items)
74 | 
75 |     def extractor(self):
76 |         while self._items:
77 |             yield self.pop()
78 | 
--------------------------------------------------------------------------------
/utils/datastructures/pathtree.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle as pkl
3 | from utils.datastructures.storage import Storage
4 | 
5 | 
6 | class PathTree:
7 | 
8 |     @classmethod
9 |     def load(cls, path):
10 |         inst = cls(Storage.load(str(path) + '_storage.pkl'))
11 |         with open(str(path) + '_tree.pkl', 'rb') as f:
12 |             a = pkl.load(f)
13 | 
14 |         n = inst.storage.n
15 |         inst.parent[:n] = a['parent']
16 |         inst.cost[:n] = a['cost']
17 |         inst.depth[:n] = a['depth']
18 | 
19 |         return inst
20 | 
21 |     def __init__(self, storage):
22 |         self.storage = storage
23 |         self.parent = np.zeros(storage.N, dtype=np.intp)
24 |         self.cost = np.zeros(storage.N, dtype=float)
25 |         self.depth = np.zeros(storage.N, dtype=int)
26 | 
27 |     def update_link(self, q_idx, parent_idx, c=1.):
28 |         self.parent[q_idx] = parent_idx
29 |         self.depth[q_idx] = self.depth[parent_idx] + 1
30 |         self.cost[q_idx] = self.cost[parent_idx] + c
31 | 
32 |     def get_edges(self):
33 |         # TODO: use a generator to avoid creating all the data at once
34 |         res = np.zeros((self.storage.n - 1, 2, self.storage.dim), dtype=float)
35 |         res[:, 0, :] = self.storage.data[1:self.storage.n, :]
36 |         res[:, 1, :] = self.storage.data[self.parent[1:self.storage.n], :]
37 | 
38 |         costs = self.cost[1:self.storage.n]
39 |         return res, costs
40 | 
41 |     def get_path(self):
42 |         # TODO: use a generator to avoid creating all the data at once
43 |         i = self.storage.n - 1
44 |         len_path = self.depth[i] + 1
45 |         res = np.zeros((len_path, self.storage.dim))
46 |         j = len_path
47 |         while not i == 0:
48 |             j -= 1
49 |             res[j] = self.storage.data[i]
50 |             i = self.parent[i]
51 |         res[0] = self.storage.data[0]
52 |         return res
53 | 
54 |     def save(self, path):
55 |         n = self.storage.n
56 |         self.storage.save(str(path) + '_storage.pkl')
57 |         with open(str(path) + '_tree.pkl', 'wb') as f:
58 |             pkl.dump({
59 |                 'parent': self.parent[:n],
60 |                 'cost': self.cost[:n],
61 |                 'depth': self.depth[:n],
62 |             }, f)
63 | 
64 |     def get_estimated_start_goal(self):
65 |         return self.storage.data[0], self.storage.data[self.storage.n - 1]
66 | 
--------------------------------------------------------------------------------
/utils/datastructures/storage.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle as pkl
3 | 
4 | 
5 | class Storage:
6 | 
7 |     @classmethod
8 |     def load(cls, path):
9 |         with open(path, 'rb') as f:
10 |             a = pkl.load(f)
11 | 
12 |         inst = cls(a['N'], a['dim'])
13 | 
14 |         n = a['n']
15 |         inst.n = n
16 | 
17 |         inst.data[:n] = a['data']
18 | 
19 |         return inst
20 | 
21 |     def __init__(self, N, dim):
22 |         self.N = N
23 |         self.dim = dim
24 |         self.n = np.intp(0)
25 |         self.data = np.zeros((N, dim), dtype=float)
26 | 
27 |     def add_point(self, p):
28 |         assert not self.is_full
29 |         self.data[self.n] = p
30 |         self.n += 1
31 |         return self.n - 1
32 | 
33 |     def remove_last(self):
34 |         assert self.n
35 |         self.n -= 1
36 | 
37 |     def __getitem__(self, idx):
38 |         # assert idx < self.n
39 |         return self.data[idx]
40 | 
41 |     def __len__(self):
42 |         return self.n
43 | 
44 |     @property
45 |     def ndarray(self):
46 |         return self.data[:self.n]
47 | 
48 |     @property
49 |     def is_full(self):
50 |         return self.n == self.N
51 | 
52 |     def save(self, path):
53 |         with
open(path, 'wb') as f:
54 |             pkl.dump({
55 |                 'N': self.N,
56 |                 'dim': self.dim,
57 |                 'n': self.n,
58 |                 'data': self.data[:self.n]
59 |             }, f)
60 | 
--------------------------------------------------------------------------------
/utils/datastructures/tree.py:
--------------------------------------------------------------------------------
1 | class NodeBinaryTree:
2 |     """
3 |     Abstract binary tree implementing the classic traversals.
4 |     """
5 |     def __init__(
6 |         self, parent=None, left=None, right=None
7 |     ):
8 |         self.parent = parent
9 |         self.left = left
10 |         self.right = right
11 | 
12 |     def ascension(self):
13 |         yield self
14 |         if self.parent is not None:
15 |             for e in self.parent.ascension():
16 |                 yield e
17 | 
18 |     def depth_first(self):
19 |         yield self
20 |         if self.left is not None:
21 |             for e in self.left.depth_first():
22 |                 yield e
23 |         if self.right is not None:
24 |             for e in self.right.depth_first():
25 |                 yield e
26 | 
27 |     def _wide_first(self, i=0):
28 |         yield i, self  # yield (depth, node) pairs, as unpacked below
29 |         iter_left = (
30 |             iter(self.left._wide_first(i + 1))
31 |             if self.left is not None else None
32 |         )
33 |         iter_right = (
34 |             iter(self.right._wide_first(i + 1))
35 |             if self.right is not None else None
36 |         )
37 |         i_left, n_left = self.robust_next(iter_left)
38 |         i_right, n_right = self.robust_next(iter_right)
39 |         while not (i_left is None and i_right is None):
40 |             if i_left is not None and (i_right is None or i_left <= i_right):
41 |                 yield i_left, n_left
42 |                 i_left, n_left = self.robust_next(iter_left)
43 |             else:
44 |                 yield i_right, n_right
45 |                 i_right, n_right = self.robust_next(iter_right)
46 | 
47 |     def wide_first(self):
48 |         for _, e in self._wide_first():
49 |             yield e
50 | 
51 |     @staticmethod
52 |     def robust_next(iterator):
53 |         if iterator is None:
54 |             return (None, None)
55 |         return next(iterator, (None, None))
56 | 
--------------------------------------------------------------------------------
/utils/generate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Cut python files in bits loadable by ipython."""
3 | 
4 | from pathlib import Path
5 | import json
6 | 
7 | hashtags = ['jupyter_snippet']
8 | 
9 | def generate_from_id(tp_id : int):
10 |     folder = Path() / f'tp{tp_id}'
11 |     ipynb = next(Path().glob(f'{tp_id}_*.ipynb'))
12 |     generate(ipynb,folder)
13 | 
14 | def generate(ipynb, folder):
15 |     print(f'processing {ipynb} with scripts in {folder}')
16 |     with ipynb.open() as f:
17 |         data = json.load(f)
18 |     cells_copy = data['cells'].copy()
19 |     generated = folder / 'generated'
20 |     generated.mkdir(exist_ok=True)
21 |     for filename in folder.glob('*.py'):
22 |         print(f' processing {filename}')
23 |         content = []
24 |         hidden = False
25 |         dest = None
26 |         with filename.open() as f_in:
27 |             for line_number, line in enumerate(f_in):
28 |                 if any([ f'# %{hashtag}' in line for hashtag in hashtags ]):
29 |                     if dest is not None:
30 |                         raise SyntaxError(f'%{hashtags[0]} block opened twice at line {line_number + 1}')
31 |                     dest = generated / f'{filename.stem}_{line.split()[2]}'
32 |                     hidden = False
33 |                 elif any([ line.strip() == f'# %end_{hashtag}' for hashtag in hashtags ]):
34 |                     if dest is None:
35 |                         raise SyntaxError(f'%{hashtags[0]} block closed before being opened at line {line_number + 1}')
36 |                     with dest.open('w') as f_out:
37 |                         f_out.write(''.join(content))
38 |                     for cell_number, cell in enumerate(cells_copy):
39 |                         if len(cell['source'])==0: continue
40 |                         if cell['source'][0].endswith(f'%load {dest}'):
41 |                             data['cells'][cell_number]['source'] = [f'# %load {dest}\n'] + content
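                            # Refresh the notebook cell in place: keep its
                            # '# %load' header line and inline the snippet
                            # that was just written to dest.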
42 | #if f'%do_not_load {dest}' in cell['source'][0]: 43 | # data['cells'][cell_number]['source'] = [f'%do_not_load {dest}\n'] 44 | content = [] 45 | hidden = False 46 | dest = None 47 | elif dest is not None: 48 | content.append(line) 49 | with ipynb.open('w') as f: 50 | f.write(json.dumps(data, indent=1)) 51 | 52 | 53 | if __name__ == '__main__': 54 | for tp_number in [0,1,2,3,4,5]: 55 | generate_from_id(tp_number) 56 | 57 | for app in [ 'appendix_scipy_optimizers']: 58 | generate(next(Path().glob(app+'.ipynb')),Path()/'appendix') 59 | -------------------------------------------------------------------------------- /utils/load_ur5_parallel.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pinocchio as pin 3 | from example_robot_data import load 4 | 5 | 6 | def load_ur5_parallel(): 7 | """ 8 | Create a robot composed of 4 UR5 9 | 10 | >>> ur5 = load('ur5') 11 | >>> ur5.nq 12 | 6 13 | >>> len(ur5.visual_model.geometryObjects) 14 | 7 15 | >>> robot = load_ur5_parallel() 16 | >>> robot.nq 17 | 24 18 | >>> len(robot.visual_model.geometryObjects) 19 | 28 20 | """ 21 | robot = load('ur5') 22 | nbRobots = 4 23 | 24 | models = [robot.model.copy() for _ in range(nbRobots)] 25 | vmodels = [robot.visual_model.copy() for _ in range(nbRobots)] 26 | 27 | # Build the kinematic model by assembling 4 UR5 28 | fullmodel = pin.Model() 29 | 30 | for irobot, model in enumerate(models): 31 | # Change frame names 32 | for i, f in enumerate(model.frames): 33 | f.name = '%s_#%d' % (f.name, irobot) 34 | # Change joint names 35 | for i, n in enumerate(model.names): 36 | model.names[i] = '%s_#%d' % (n, irobot) 37 | 38 | # Choose the placement of the new arm to be added 39 | Mt = pin.SE3(np.eye(3), np.array([.3, 0, 0.])) # First robot is simply translated 40 | basePlacement = pin.SE3(pin.utils.rotate('z', np.pi * irobot / 2), np.zeros(3)) * Mt 41 | 42 | # Append the kinematic model 43 | fullmodel = pin.appendModel(fullmodel, model, 0, basePlacement) 44 | 45 | # Build the geometry model 46 | fullvmodel = pin.GeometryModel() 47 | 48 | for irobot, (model, vmodel) in enumerate(zip(models, vmodels)): 49 | # Change geometry names 50 | for i, g in enumerate(vmodel.geometryObjects): 51 | # Change the name to avoid conflict 52 | g.name = '%s_#%d' % (g.name, irobot) 53 | 54 | # Refere to new parent names in the full kinematic tree 55 | g.parentFrame = fullmodel.getFrameId(model.frames[g.parentFrame].name) 56 | g.parentJoint = fullmodel.getJointId(model.names[g.parentJoint]) 57 | 58 | # Append the geometry model 59 | fullvmodel.addGeometryObject(g) 60 | # print('add %s on frame %d "%s"' % (g.name, g.parentFrame, fullmodel.frames[g.parentFrame].name)) 61 | 62 | fullrobot = pin.RobotWrapper(fullmodel, fullvmodel, fullvmodel) 63 | # fullrobot.q0 = np.array([-0.375, -1.2 , 1.71 , -0.51 , -0.375, 0. ]*4) 64 | fullrobot.q0 = np.array([np.pi / 4, -np.pi / 4, -np.pi / 2, np.pi / 4, np.pi / 2, 0] * nbRobots) 65 | 66 | return fullrobot 67 | 68 | 69 | if __name__ == "__main__": 70 | from utils.meshcat_viewer_wrapper import MeshcatVisualizer 71 | 72 | robot = load_ur5_parallel() 73 | viz = MeshcatVisualizer(robot, url='classical') 74 | viz.display(robot.q0) 75 | -------------------------------------------------------------------------------- /utils/load_ur5_with_obstacles.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Load a UR5 robot model, display it in the viewer. 
Also create an obstacle
3 | field made of several capsules, display them in the viewer and create the
4 | collision detection to handle it.
5 | '''
6 | 
7 | import pinocchio as pin
8 | import example_robot_data as robex
9 | import numpy as np
10 | import itertools
11 | 
12 | 
13 | def XYZRPYtoSE3(xyzrpy):
14 |     rotate = pin.utils.rotate
15 |     R = rotate('x',xyzrpy[3]) @ rotate('y',xyzrpy[4]) @ rotate('z',xyzrpy[5])
16 |     p = np.array(xyzrpy[:3])
17 |     return pin.SE3(R,p)
18 | 
19 | def load_ur5_with_obstacles(robotname='ur5',reduced=False):
20 | 
21 |     ### Robot
22 |     # Load the robot
23 |     robot = robex.load(robotname)
24 | 
25 |     ### If reduced, then only keep the shoulder-lift and elbow joints, hence creating a simple R2 robot.
26 |     if reduced:
27 |         unlocks = [1,2]
28 |         robot.model,[robot.visual_model,robot.collision_model]\
29 |             = pin.buildReducedModel(robot.model,[robot.visual_model,robot.collision_model],
30 |                                     [ i+1 for i in range(robot.nq) if i not in unlocks ],robot.q0)
31 |         robot.data = robot.model.createData()
32 |         robot.collision_data = robot.collision_model.createData()
33 |         robot.visual_data = robot.visual_model.createData()
34 |         robot.q0 = robot.q0[unlocks].copy()
35 | 
36 |     ### Obstacle map
37 |     # Capsule obstacles will be placed at these XYZ-RPY parameters
38 |     oMobs = [ [ 0.40,  0.,  0.30, np.pi/2,0,0],
39 |               [-0.08, -0.,  0.69, np.pi/2,0,0],
40 |               [ 0.23, -0.,  0.04, np.pi/2, 0 ,0 ],
41 |               [-0.32,  0., -0.08, np.pi/2, 0, 0]]
42 | 
43 |     # Load visual objects and add them in collision/visual models
44 |     color = [ 1.0, 0.2, 0.2, 1.0 ]  # color of the capsules
45 |     rad,length = .1,0.4  # radius and length of capsules
46 |     for i,xyzrpy in enumerate(oMobs):
47 |         obs = pin.GeometryObject.CreateCapsule(rad,length)  # Pinocchio obstacle object
48 |         obs.meshColor = np.array(color)  # Don't forget me, otherwise I am transparent ...
49 |         obs.name = "obs%d"%i  # Set object name
50 |         obs.parentJoint = 0  # Set object parent = 0 = universe
51 |         obs.placement = XYZRPYtoSE3(xyzrpy)  # Set object placement wrt parent
52 |         robot.collision_model.addGeometryObject(obs)  # Add object to collision model
53 |         robot.visual_model.addGeometryObject(obs)  # Add object to visual model
54 | 
55 |     ### Collision pairs
56 |     nobs = len(oMobs)
57 |     nbodies = robot.collision_model.ngeoms-nobs
58 |     robotBodies = range(nbodies)
59 |     envBodies = range(nbodies,nbodies+nobs)
60 |     robot.collision_model.removeAllCollisionPairs()
61 |     for a,b in itertools.product(robotBodies,envBodies):
62 |         robot.collision_model.addCollisionPair(pin.CollisionPair(a,b))
63 | 
64 |     ### Geom data
65 |     # Collision/visual models have been modified => re-generate corresponding data.
66 |     robot.collision_data = pin.GeometryData(robot.collision_model)
67 |     robot.visual_data = pin.GeometryData(robot.visual_model)
68 | 
69 |     return robot
70 | 
71 | 
72 | class Target:
73 |     '''
74 |     Simple class target that stores and displays the position of a target.
75 | ''' 76 | def __init__(self,viz=None,color = [ .0, 1.0, 0.2, 1.0 ], radius = 0.05, position=None): 77 | self.position = position if position is not None else np.array([ 0.0, 0.0 ]) 78 | self.initVisual(viz,color,radius) 79 | self.display() 80 | 81 | def initVisual(self,viz,color,radius): 82 | self.viz = viz 83 | if viz is None: return 84 | self.name = "world/pinocchio/target" 85 | 86 | if isinstance(viz,pin.visualize.MeshcatVisualizer): 87 | import meshcat 88 | obj = meshcat.geometry.Sphere(radius) 89 | material = meshcat.geometry.MeshPhongMaterial() 90 | material.color = int(color[0] * 255) * 256**2 + int(color[1] * 255) * 256 + int(color[2] * 255) 91 | if float(color[3]) != 1.0: 92 | material.transparent = True 93 | material.opacity = float(color[3]) 94 | self.viz.viewer[self.name].set_object(obj, material) 95 | 96 | elif isinstance(viz,pin.visualize.GepettoVisualizer): 97 | self.viz.viewer.gui.addCapsule( self.name, radius,0., color) 98 | 99 | def display(self): 100 | if self.viz is None or self.position is None: return 101 | 102 | if isinstance(self.viz,pin.visualize.MeshcatVisualizer): 103 | T = np.eye(4) 104 | T[[0,2],3] = self.position 105 | self.viz.viewer[self.name].set_transform(T) 106 | elif isinstance(self.viz,pin.visualize.GepettoVisualizer): 107 | self.viz.viewer.gui.applyConfiguration( self.name, 108 | [ self.position[0], 0, self.position[1], 109 | 1.,0.,0.0,0. ]) 110 | self.viz.viewer.gui.refresh() 111 | -------------------------------------------------------------------------------- /utils/meshcat_viewer_wrapper/__init__.py: -------------------------------------------------------------------------------- 1 | from .visualizer import MeshcatVisualizer # noqa 2 | from .transformations import planar,translation2d 3 | -------------------------------------------------------------------------------- /utils/meshcat_viewer_wrapper/colors.py: -------------------------------------------------------------------------------- 1 | import meshcat 2 | 3 | 4 | def rgb2int(r, g, b): 5 | ''' 6 | Convert 3 integers (chars) 0 <= r, g, b < 256 into one single integer = 256**2*r+256*g+b, as expected by Meshcat. 
7 | 8 | >>> rgb2int(0, 0, 0) 9 | 0 10 | >>> rgb2int(0, 0, 255) 11 | 255 12 | >>> rgb2int(0, 255, 0) == 0x00FF00 13 | True 14 | >>> rgb2int(255, 0, 0) == 0xFF0000 15 | True 16 | >>> rgb2int(255, 255, 255) == 0xFFFFFF 17 | True 18 | ''' 19 | return int((r << 16) + (g << 8) + b) 20 | 21 | 22 | def material(color, transparent=False): 23 | mat = meshcat.geometry.MeshPhongMaterial() 24 | mat.color = color 25 | mat.transparent = transparent 26 | return mat 27 | 28 | 29 | red = material(color=rgb2int(255, 0, 0), transparent=False) 30 | blue = material(color=rgb2int(0, 0, 255), transparent=False) 31 | green = material(color=rgb2int(0, 255, 0), transparent=False) 32 | yellow = material(color=rgb2int(255, 255, 0), transparent=False) 33 | magenta = material(color=rgb2int(255, 0, 255), transparent=False) 34 | cyan = material(color=rgb2int(0, 255, 255), transparent=False) 35 | white = material(color=rgb2int(250, 250, 250), transparent=False) 36 | black = material(color=rgb2int(5, 5, 5), transparent=False) 37 | grey = material(color=rgb2int(120, 120, 120), transparent=False) 38 | 39 | colormap = { 40 | 'red': red, 41 | 'blue': blue, 42 | 'green': green, 43 | 'yellow': yellow, 44 | 'magenta': magenta, 45 | 'cyan': cyan, 46 | 'black': black, 47 | 'white': white, 48 | 'grey': grey 49 | } 50 | -------------------------------------------------------------------------------- /utils/meshcat_viewer_wrapper/tests.py: -------------------------------------------------------------------------------- 1 | import doctest 2 | 3 | from utils.meshcat_viewer_wrapper import colors 4 | 5 | 6 | def load_tests(loader, tests, pattern): 7 | tests.addTests(doctest.DocTestSuite(colors)) 8 | return tests 9 | -------------------------------------------------------------------------------- /utils/meshcat_viewer_wrapper/transformations.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Collection of super simple transformations to ease the use of the viewer. 3 | ''' 4 | 5 | import numpy as np 6 | 7 | def planar(x, y, theta): 8 | '''Convert a 3d vector (x,y,theta) into a transformation in the Y,Z plane.''' 9 | s,c=np.sin(theta/2),np.cos(theta / 2) 10 | return [0, x, y, s,0,0,c] # Rotation around X 11 | 12 | def translation2d(x,y): 13 | ''' Convert a 2d vector (x,y) into a 3d transformation translating the Y,Z plane. ''' 14 | return [0,x,y,1,0,0,0] 15 | 16 | -------------------------------------------------------------------------------- /utils/meshcat_viewer_wrapper/visualizer.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import meshcat 4 | import numpy as np 5 | import pinocchio as pin 6 | from pinocchio.visualize import MeshcatVisualizer as PMV 7 | 8 | from . 
import colors
9 | 
10 | 
11 | def materialFromColor(color):
12 |     if isinstance(color, meshcat.geometry.MeshPhongMaterial):
13 |         return color
14 |     elif isinstance(color, str):
15 |         material = colors.colormap[color]
16 |     elif isinstance(color, list):
17 |         material = meshcat.geometry.MeshPhongMaterial()
18 |         material.color = colors.rgb2int(*[int(c * 255) for c in color[:3]])
19 |         if len(color) == 3:
20 |             material.transparent = False
21 |         else:
22 |             material.transparent = color[3] < 1
23 |             material.opacity = float(color[3])
24 |     elif color is None:
25 |         material = colors.colormap[random.sample(list(colors.colormap), 1)[0]]  # sampling a dict yields a key (name), so look it up
26 |     else:
27 |         material = colors.black
28 |     return material
29 | 
30 | 
31 | class MeshcatVisualizer(PMV):
32 |     def __init__(self, robot=None, model=None, collision_model=None, visual_model=None, url=None):
33 |         if robot is not None:
34 |             super().__init__(robot.model, robot.collision_model, robot.visual_model)
35 |         elif model is not None:
36 |             super().__init__(model, collision_model, visual_model)
37 | 
38 |         if url is not None:
39 |             if url == 'classical':
40 |                 url = 'tcp://127.0.0.1:6000'
41 |             print('Wrapper tries to connect to server <%s>' % url)
42 |             server = meshcat.Visualizer(zmq_url=url)
43 |         else:
44 |             server = None
45 | 
46 |         if robot is not None or model is not None:
47 |             self.initViewer(loadModel=True, viewer=server)
48 |         else:
49 |             self.viewer = server if server is not None else meshcat.Visualizer()
50 | 
51 |     def addSphere(self, name, radius, color):
52 |         material = materialFromColor(color)
53 |         self.viewer[name].set_object(meshcat.geometry.Sphere(radius), material)
54 | 
55 |     def addCylinder(self, name, length, radius, color=None):
56 |         material = materialFromColor(color)
57 |         self.viewer[name].set_object(meshcat.geometry.Cylinder(length, radius), material)
58 | 
59 |     def addBox(self, name, dims, color):
60 |         material = materialFromColor(color)
61 |         self.viewer[name].set_object(meshcat.geometry.Box(dims), material)
62 | 
63 |     def applyConfiguration(self, name, placement):
64 |         if isinstance(placement, (list, tuple)):
65 |             placement = np.array(placement)
66 |         if isinstance(placement, pin.SE3):
67 |             R, p = placement.rotation, placement.translation
68 |             T = np.r_[np.c_[R, p], [[0, 0, 0, 1]]]
69 |         elif isinstance(placement, np.ndarray):
70 |             if placement.shape == (7, ):  # XYZ-quat
71 |                 R = pin.Quaternion(np.reshape(placement[3:], [4, 1])).matrix()
72 |                 p = placement[:3]
73 |                 T = np.r_[np.c_[R, p], [[0, 0, 0, 1]]]
74 |             else:
75 |                 print('Error: np.ndarray placement must be an XYZ-quat of shape (7,)')
76 |                 return False
77 |         else:
78 |             print('Error: this placement type is not accepted')
79 |             return False
80 |         self.viewer[name].set_transform(T)
81 | 
82 |     def delete(self, name):
83 |         self.viewer[name].delete()
84 | 
85 |     def __getitem__(self, name):
86 |         return self.viewer[name]
87 | 
--------------------------------------------------------------------------------
/utils/tests.py:
--------------------------------------------------------------------------------
1 | import doctest
2 | 
3 | from utils import load_ur5_parallel
4 | 
5 | 
6 | def load_tests(loader, tests, pattern):
7 |     tests.addTests(doctest.DocTestSuite(load_ur5_parallel))
8 |     return tests
9 | 
--------------------------------------------------------------------------------
/utils/tiago_loader.py:
--------------------------------------------------------------------------------
1 | '''
2 | Tiago loader accounting for the first planar joint giving robot mobility.
3 | '''
4 | 
5 | import numpy as np
6 | import pinocchio as pin
7 | import example_robot_data as robex
8 | import hppfcl
9 | from os.path import dirname, exists, join
10 | 
11 | 
12 | class TiagoLoader(object):
13 |     #path = ''
14 |     #urdf_filename = ''
15 |     srdf_filename = ''
16 |     urdf_subpath = 'robots'
17 |     srdf_subpath = 'srdf'
18 |     ref_posture = 'half_sitting'
19 |     has_rotor_parameters = False
20 |     free_flyer = True
21 |     verbose = False
22 |     path = "tiago_description"
23 |     urdf_filename = "tiago_no_hand.urdf"
24 | 
25 |     def __init__(self):
26 |         urdf_path = join(self.path, self.urdf_subpath, self.urdf_filename)
27 |         self.model_path = robex.getModelPath(urdf_path, self.verbose)
28 |         self.urdf_path = join(self.model_path, urdf_path)
29 |         self.robot = pin.RobotWrapper.BuildFromURDF(self.urdf_path, [join(self.model_path, '../..')],
30 |                                                     pin.JointModelPlanar() if self.free_flyer else None)
31 | 
32 |         if self.srdf_filename:
33 |             self.srdf_path = join(self.model_path, self.path, self.srdf_subpath, self.srdf_filename)
34 |             from example_robot_data.robots_loader import readParamsFromSrdf  # assumed import path; the function was previously used without being imported
35 |             self.q0 = readParamsFromSrdf(self.robot.model, self.srdf_path, self.verbose, self.has_rotor_parameters, self.ref_posture)
36 |         else:
37 |             self.srdf_path = None
38 |             self.q0 = None
39 | 
40 |         if self.free_flyer:
41 |             self.addFreeFlyerJointLimits()
42 | 
43 |     def addFreeFlyerJointLimits(self):
44 |         ub = self.robot.model.upperPositionLimit
45 |         ub[:self.robot.model.joints[1].nq] = 1
46 |         self.robot.model.upperPositionLimit = ub
47 |         lb = self.robot.model.lowerPositionLimit
48 |         lb[:self.robot.model.joints[1].nq] = -1
49 |         self.robot.model.lowerPositionLimit = lb
50 | 
51 | 
52 | def loadTiago(addGazeFrame=False):
53 |     '''
54 |     Load a Tiago model, without the hand, and with the two following modifications wrt example_robot_data:
55 |     - first, the root joint is a planar (x,y,cos,sin) joint, while the robot is fixed-base in example_robot_data;
56 |     - second, visual models of a frame have been attached to two new op-frames, "frametool" on the robot hand and "framebasis" in
57 |       front of the basis.
58 |     '''
59 | 
60 | 
61 |     robot = TiagoLoader().robot
62 |     geom = robot.visual_model
63 | 
64 |     X = pin.utils.rotate('y', np.pi/2)
65 |     Y = pin.utils.rotate('x',-np.pi/2)
66 |     Z = np.eye(3)
67 | 
68 |     L = .3
69 |     cyl=hppfcl.Cylinder(L/30,L)
70 |     med = np.array([0,0,L/2])
71 | 
72 |     # ---------------------------------------------------------------------------
73 |     # Add a frame visualisation in the effector.
74 | 
75 |     FIDX = robot.model.getFrameId('wrist_ft_tool_link')
76 |     JIDX = robot.model.frames[FIDX].parent
77 | 
78 |     eff = np.array([0,0,.08])
79 |     FIDX = robot.model.addFrame(pin.Frame('frametool',JIDX,FIDX,pin.SE3(Z,eff),pin.FrameType.OP_FRAME))
80 | 
81 |     geom.addGeometryObject(pin.GeometryObject('axis_x',FIDX,JIDX,cyl,pin.SE3(X,X@med+eff)))
82 |     geom.geometryObjects[-1].meshColor = np.array([1,0,0,1.])
83 | 
84 |     geom.addGeometryObject(pin.GeometryObject('axis_y',FIDX,JIDX,cyl,pin.SE3(Y,Y@med+eff)))
85 |     geom.geometryObjects[-1].meshColor = np.array([0,1,0,1.])
86 | 
87 |     geom.addGeometryObject(pin.GeometryObject('axis_z',FIDX,JIDX,cyl,pin.SE3(Z,Z@med+eff)))
88 |     geom.geometryObjects[-1].meshColor = np.array([0,0,1,1.])
89 | 
90 |     # ---------------------------------------------------------------------------
91 |     # Add a frame visualisation in front of the basis.
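    # (Same triad pattern as for the effector above: three thin cylinders,
    # rotated by X, Y, Z so that each one's long axis is aligned with one
    # axis of the new frame, and offset by half their length from its origin.)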
92 | 
93 |     FIDX = robot.model.getFrameId('base_link')
94 |     JIDX = robot.model.frames[FIDX].parent
95 | 
96 |     eff = np.array([.3,0,.15])
97 |     FIDX = robot.model.addFrame(pin.Frame('framebasis',JIDX,FIDX,pin.SE3(Z,eff),pin.FrameType.OP_FRAME))
98 | 
99 |     geom.addGeometryObject(pin.GeometryObject('axis2_x',FIDX,JIDX,cyl,pin.SE3(X,X@med+eff)))
100 |     geom.geometryObjects[-1].meshColor = np.array([1,0,0,1.])
101 | 
102 |     geom.addGeometryObject(pin.GeometryObject('axis2_y',FIDX,JIDX,cyl,pin.SE3(Y,Y@med+eff)))
103 |     geom.geometryObjects[-1].meshColor = np.array([0,1,0,1.])
104 | 
105 |     geom.addGeometryObject(pin.GeometryObject('axis2_z',FIDX,JIDX,cyl,pin.SE3(Z,Z@med+eff)))
106 |     geom.geometryObjects[-1].meshColor = np.array([0,0,1,1.])
107 | 
108 |     # ---------------------------------------------------------------------------
109 |     # Add a frame visualisation in front of the head.
110 | 
111 |     if addGazeFrame:
112 |         L = .05
113 |         cyl=hppfcl.Cylinder(L/30,L)
114 |         med = np.array([0,0,L/2])
115 | 
116 |         FIDX = robot.model.getFrameId('xtion_joint')
117 |         JIDX = robot.model.frames[FIDX].parent
118 | 
119 |         eff = np.array([0.4,0.0,0.0])
120 |         FIDX = robot.model.addFrame(pin.Frame('framegaze',JIDX,FIDX,pin.SE3(Z,eff),pin.FrameType.OP_FRAME))
121 | 
122 |         geom.addGeometryObject(pin.GeometryObject('axisgaze_x',FIDX,JIDX,cyl,pin.SE3(X,X@med+eff)))
123 |         geom.geometryObjects[-1].meshColor = np.array([1,0,0,1.])
124 | 
125 |         geom.addGeometryObject(pin.GeometryObject('axisgaze_y',FIDX,JIDX,cyl,pin.SE3(Y,Y@med+eff)))
126 |         geom.geometryObjects[-1].meshColor = np.array([0,1,0,1.])
127 | 
128 |         geom.addGeometryObject(pin.GeometryObject('axisgaze_z',FIDX,JIDX,cyl,pin.SE3(Z,Z@med+eff)))
129 |         geom.geometryObjects[-1].meshColor = np.array([0,0,1,1.])
130 | 
131 |     # -------------------------------------------------------------------------------
132 |     # Regenerate the data from the new models.
133 | 
134 |     robot.q0 = np.array([1,1,1,0]+[0]*(robot.model.nq-4))
135 | 
136 |     robot.data = robot.model.createData()
137 |     robot.visual_data = robot.visual_model.createData()
138 | 
139 |     return robot
140 | 
141 | # ------------------------------------------------------------------------------------------------
142 | # ------------------------------------------------------------------------------------------------
143 | # ------------------------------------------------------------------------------------------------
144 | 
145 | if __name__ == "__main__":
146 |     from utils.meshcat_viewer_wrapper import MeshcatVisualizer
147 | 
148 |     robot = loadTiago()
149 |     viz = MeshcatVisualizer(robot,url='classical')
150 | 
151 |     viz.display(robot.q0)
152 | 
153 | 
--------------------------------------------------------------------------------
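A minimal usage sketch for the MTree above (not part of the repository; it assumes the repository root is on PYTHONPATH). Points are stored as tuples, since indexed objects must be hashable, and the distance function supports the two call forms the tree uses, a single object or a list of objects:

import numpy as np
from utils.datastructures.mtree import MTree

def euclidean(a, b):
    # b is either one point or a list of points; return one distance or a
    # 1-d array of distances accordingly (the tree uses both call forms).
    a, b = np.asarray(a), np.asarray(b)
    if b.ndim == a.ndim:
        return float(np.linalg.norm(a - b))
    return np.linalg.norm(b - a, axis=-1)

tree = MTree(euclidean, min_node_capacity=4)
for p in [(0., 0.), (1., 0.), (0., 1.), (2., 2.)]:
    tree.add(p)

print(tree.nearest_neighbour((0.9, 0.1)))
# -> ResultItem(data=(1.0, 0.0), distance=0.141...)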