├── .gitignore ├── README.md ├── demo.ipynb ├── handson.ipynb ├── ic ├── 3-body.tsv ├── circle.tsv ├── ellipse.tsv ├── figure-8.tsv └── solar-system.tsv └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !*/ 3 | 4 | !*.md 5 | !*.py 6 | !*.ipynb 7 | 8 | !*.png 9 | !*.jpg 10 | !*.svg 11 | 12 | __pycache__/ 13 | .ipynb_checkpoints/ 14 | _build/ 15 | plots/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Speeding Up Your Python Codes 1000x 2 | 3 | Welcome to this workshop! 4 | 5 | In this repository, you will find Jupyter Notebooks 6 | [Demo](demo.ipynb) and 7 | [Hands-on](handson.ipynb) 8 | that walk you through step-by-step optimizations using a classic 9 | n-body simulation. 10 | We start with simple improvements like list comprehensions and 11 | reducing operation counts, then move on to advanced techniques with 12 | high-performance libraries such as `NumPy` and Google `JAX`, 13 | just-in-time compilation, and GPU acceleration. 14 | 15 | The result is a performance boost of over 1000x, empowering you to 16 | process complex datasets, train AI models, or run detailed simulations 17 | efficiently. 18 | Enjoy the journey and unlock Python's true potential! 
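As a small taste of the first optimization mentioned above, the following sketch times a plain for-loop against an equivalent list comprehension using the standard `timeit` module. The function names `double_loop` and `double_comprehension` and the input size are illustrative assumptions, not code taken from the notebooks, and the measured ratio will vary with your machine and Python version.

```python
import timeit


def double_loop(values):
    """Build a list with an explicit for-loop and append."""
    out = []
    for x in values:
        out.append(x * 2)
    return out


def double_comprehension(values):
    """Build the same list with a list comprehension."""
    return [x * 2 for x in values]


values = list(range(10_000))  # illustrative input size (an assumption)

# Both versions must agree before a timing comparison means anything.
assert double_loop(values) == double_comprehension(values)

# Run each version 100 times and report the total elapsed time.
t_loop = timeit.timeit(lambda: double_loop(values), number=100)
t_comp = timeit.timeit(lambda: double_comprehension(values), number=100)
print(f"for-loop: {t_loop:.4f} s, comprehension: {t_comp:.4f} s")
```

The notebooks apply the same measure-then-change loop to the n-body integrator, where the gains from later steps are far larger than this toy case suggests.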
19 | -------------------------------------------------------------------------------- /demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "0", 6 | "metadata": {}, 7 | "source": [ 8 | "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/rndsrc/orbits-py/blob/main/demo.ipynb)" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "1", 14 | "metadata": {}, 15 | "source": [ 16 | "# Speeding Up Your Python Codes 1000x" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "id": "2", 22 | "metadata": {}, 23 | "source": [ 24 | "## Introduction\n", 25 | "\n", 26 | "Welcome to \"Speeding Up Your Python Codes 1000x\"!\n", 27 | "\n", 28 | "In the next 45 minutes, we will learn tricks to unlock python performance and increase a sample code's speed by a factor of 1000x!" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "id": "3", 34 | "metadata": {}, 35 | "source": [ 36 | "Here is our plan:\n", 37 | "* **Understand the problem:**\n", 38 | " we will start by outlining the $n$-body problem and how to use the leapfrog algorithm to solve it numerically.\n", 39 | "* **Benchmark and profile:**\n", 40 | " we will learn how to measure performance and identify bottlenecks using tools like timeit and cProfile.\n", 41 | "* **Optimize the code:**\n", 42 | " step through a series of optimizations:\n", 43 | " * Use list comprehensions and reduce operation counts.\n", 44 | " * Replace slow operations with faster ones.\n", 45 | " * Leverage high-performance libraries such as NumPy.\n", 46 | " * Explore lower precision where applicable.\n", 47 | " * Harness Google JAX for just-in-time compilation and GPU acceleration.\n", 48 | "* **Apply the integrator:** Finally, we will use our optimized code to simulate the mesmerizing figure-8 self-gravitating orbit." 
49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "id": "4", 54 | "metadata": {}, 55 | "source": [ 56 | "By the end of this workshop, you will see that Python isn't just great for rapid prototyping.\n", 57 | "It can also deliver serious performance when optimized right. Let’s dive in!" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "id": "5", 63 | "metadata": {}, 64 | "source": [ 65 | "## The $n$-body Problem and the Leapfrog Algorithm\n", 66 | "\n", 67 | "The $n$-body problem involves simulating the motion of many bodies interacting with each other through (usually) gravitational forces.\n", 68 | "\n", 69 | "This problem is physically interesting, as it models the [solar system](https://rebound.readthedocs.io/en/latest/), [galaxy dynamics, or even the whole cosmic structure formation](https://wwwmpa.mpa-garching.mpg.de/gadget/)!\n", 70 | "\n", 71 | "This problem is also computationally interesting.\n", 72 | "Each body experiences a force from every other body, making the simulation computationally intensive.\n", 73 | "It is ideal for exploring optimization techniques.\n", 74 | "\n", 75 | "The leapfrog algorithm is a simple, robust numerical integrator that alternates between updating velocities (\"kicks\") and positions (\"drifts\").\n", 76 | "It is popular because it conserves energy and momentum well over long simulation periods." 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "id": "6", 82 | "metadata": {}, 83 | "source": [ 84 | "We need to solve Newton's second law of motion for a collection of $n$ bodies.\n", 85 | "The forces between the bodies are given by Newton's law of universal gravitation." 
86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "id": "7", 91 | "metadata": {}, 92 | "source": [ 93 | "### Gravitational Force/Acceleration Equation\n", 94 | "\n", 95 | "For any two bodies labeled by $i$ and $j$, the direct gravitational force exerted on body $i$ by body $j$ is given by:\n", 96 | "\\begin{align}\n", 97 | " \\mathbf{f}_{ij} = -G m_i m_j\\frac{\\mathbf{r}_i - \\mathbf{r}_j\\ \\ \\ }{|\\mathbf{r}_i - \\mathbf{r}_j|^3},\n", 98 | "\\end{align}\n", 99 | "where:\n", 100 | "* $G$ is the gravitational constant.\n", 101 | "* $m_i$ and $m_j$ are the masses of the bodies.\n", 102 | "* $\\mathbf{r}_i$ and $\\mathbf{r}_j$ are their position vectors." 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "id": "8", 108 | "metadata": {}, 109 | "source": [ 110 | "Summing the contributions from all other bodies gives the net force (and thus the acceleration) on body $i$:\n", 111 | "\\begin{align}\n", 112 | " \\mathbf{f}_i\n", 113 | " = \\sum_{j\\ne i} \\mathbf{f}_{ij}\n", 114 | " = -G m_i \\sum_{j\\ne i} m_j \\frac{\\mathbf{r}_i - \\mathbf{r}_j\\ \\ \\ }{|\\mathbf{r}_i - \\mathbf{r}_j|^3}.\n", 115 | "\\end{align}" 116 | ] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "id": "9", 121 | "metadata": {}, 122 | "source": [ 123 | "Given Newton's law $f = m a$, the \"acceleration\" applied on body $i$ caused by all other bodies is:\n", 124 | "\\begin{align}\n", 125 | " \\mathbf{a}_i\n", 126 | " = \\sum_{j\\ne i} \\mathbf{a}_{ij}\n", 127 | " = -G \\sum_{j\\ne i} m_j \\frac{\\mathbf{r}_i - \\mathbf{r}_j\\ \\ \\ }{|\\mathbf{r}_i - \\mathbf{r}_j|^3}.\n", 128 | "\\end{align}" 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "id": "10", 134 | "metadata": {}, 135 | "source": [ 136 | "We choose units such that $G = 1$.\n", 137 | "Here is a pure Python implementation of the gravitational acceleration:" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "id": "11", 144 | "metadata": {}, 145 | "outputs": 
[], 146 | "source": [ 147 | "def acc1(m, r):\n", 148 | " \n", 149 | " n = len(m)\n", 150 | " a = []\n", 151 | " for i in range(n):\n", 152 | " axi, ayi, azi = 0, 0, 0\n", 153 | " for j in range(n):\n", 154 | " if j != i:\n", 155 | " xi, yi, zi = r[i]\n", 156 | " xj, yj, zj = r[j]\n", 157 | "\n", 158 | " axij = - m[j] * (xi - xj) / ((xi - xj)**2 + (yi - yj)**2 + (zi - zj)**2)**(3/2)\n", 159 | " ayij = - m[j] * (yi - yj) / ((xi - xj)**2 + (yi - yj)**2 + (zi - zj)**2)**(3/2)\n", 160 | " azij = - m[j] * (zi - zj) / ((xi - xj)**2 + (yi - yj)**2 + (zi - zj)**2)**(3/2)\n", 161 | "\n", 162 | " axi += axij\n", 163 | " ayi += ayij\n", 164 | " azi += azij\n", 165 | "\n", 166 | " a.append((axi, ayi, azi))\n", 167 | "\n", 168 | " return a" 169 | ] 170 | }, 171 | { 172 | "cell_type": "markdown", 173 | "id": "12", 174 | "metadata": {}, 175 | "source": [ 176 | "We can use this function to evaluate the gravitational acceleration between two particles, each with mass $m = 1$, separated by a distance of $2$.\n", 177 | "We expect the \"acceleration\" to be $-1/4$ along the direction of their separation." 
178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": null, 183 | "id": "13", 184 | "metadata": {}, 185 | "outputs": [], 186 | "source": [ 187 | "m = [1.0, 1.0]\n", 188 | "r0 = [\n", 189 | " (+1.0, 0.0, 0.0),\n", 190 | " (-1.0, 0.0, 0.0),\n", 191 | "]\n", 192 | "a0ref = [\n", 193 | " (-0.25, 0.0, 0.0),\n", 194 | " (+0.25, 0.0, 0.0),\n", 195 | "]" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "id": "14", 202 | "metadata": {}, 203 | "outputs": [], 204 | "source": [ 205 | "a0 = acc1(m, r0)" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": null, 211 | "id": "15", 212 | "metadata": {}, 213 | "outputs": [], 214 | "source": [ 215 | "a0" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": null, 221 | "id": "16", 222 | "metadata": {}, 223 | "outputs": [], 224 | "source": [ 225 | "assert a0 == a0ref" 226 | ] 227 | }, 228 | { 229 | "cell_type": "markdown", 230 | "id": "17", 231 | "metadata": {}, 232 | "source": [ 233 | "### Leapfrog Integrator\n", 234 | "\n", 235 | "Our simulation solves Newton's second law of motion, $\\mathbf{f} = m \\mathbf{a}$, numerically.\n", 236 | "The equation tells us that the acceleration of a body is determined by the net force acting on it.\n", 237 | "In a system where forces are conservative, this law guarantees the conservation of energy and momentum.\n", 238 | "These are critical properties when simulating long-term dynamics.\n", 239 | "\n", 240 | "However, many standard integration methods tend to drift over time, gradually losing these conservation properties.\n", 241 | "The leapfrog algorithm is a symplectic integrator designed to address this issue.\n", 242 | "Its staggered (or \"leapfrogging\") updates of velocity and position help preserve energy and momentum much more effectively over extended simulations.\n", 243 | "\n", 244 | "1. 
Half-step velocity update (Kick):\n", 245 | " start by updating the velocity by half a time step using the current acceleration $a(t)$:\n", 246 | " \\begin{align}\n", 247 | " v\\left(t+\\frac{1}{2}\\Delta t\\right) = v(t) + \\frac{1}{2}\\Delta t\\, a(t).\n", 248 | " \\end{align}\n", 249 | "\n", 250 | "2. Full-step position update (Drift):\n", 251 | " update the position for a full time step using the half-stepped velocity:\n", 252 | " \\begin{align}\n", 253 | " x(t+\\Delta t) = x(t) + \\Delta t\\, v\\left(t+\\frac{1}{2}\\Delta t\\right).\n", 254 | " \\end{align}\n", 255 | "\n", 256 | "3. Half-step velocity update (Kick):\n", 257 | " finally, update the velocity by another half time step using the new acceleration $a(t+\\Delta t)$ computed from the updated positions:\n", 258 | " \\begin{align}\n", 259 | " v(t+\\Delta t) = v\\left(t+\\frac{1}{2}\\Delta t\\right) + \\frac{1}{2}\\Delta t\\, a(t+\\Delta t).\n", 260 | " \\end{align}\n", 261 | "\n", 262 | "Here is a pure Python implementation of the leapfrog integrator:" 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": null, 268 | "id": "18", 269 | "metadata": {}, 270 | "outputs": [], 271 | "source": [ 272 | "def leapfrog1(m, r0, v0, dt, acc=acc1):\n", 273 | "\n", 274 | " n = len(m)\n", 275 | "\n", 276 | " # vh = v0 + 0.5 * dt * a0\n", 277 | " a0 = acc(m, r0)\n", 278 | " vh = []\n", 279 | " for i in range(n):\n", 280 | " a0xi, a0yi, a0zi = a0[i]\n", 281 | " v0xi, v0yi, v0zi = v0[i]\n", 282 | " vhxi = v0xi + 0.5 * dt * a0xi\n", 283 | " vhyi = v0yi + 0.5 * dt * a0yi\n", 284 | " vhzi = v0zi + 0.5 * dt * a0zi\n", 285 | " vh.append((vhxi, vhyi, vhzi))\n", 286 | "\n", 287 | " # r1 = r0 + dt * vh\n", 288 | " r1 = []\n", 289 | " for i in range(n):\n", 290 | " vhxi, vhyi, vhzi = vh[i]\n", 291 | " x0i, y0i, z0i = r0[i]\n", 292 | " x1i = x0i + dt * vhxi\n", 293 | " y1i = y0i + dt * vhyi\n", 294 | " z1i = z0i + dt * vhzi\n", 295 | " r1.append((x1i, y1i, z1i))\n", 296 | "\n", 297 | " # v1 = vh + 0.5 
* dt * a1\n", 298 | " a1 = acc(m, r1)\n", 299 | " v1 = []\n", 300 | " for i in range(n):\n", 301 | " a1xi, a1yi, a1zi = a1[i]\n", 302 | " vhxi, vhyi, vhzi = vh[i]\n", 303 | " v1xi = vhxi + 0.5 * dt * a1xi\n", 304 | " v1yi = vhyi + 0.5 * dt * a1yi\n", 305 | " v1zi = vhzi + 0.5 * dt * a1zi\n", 306 | " v1.append((v1xi, v1yi, v1zi))\n", 307 | "\n", 308 | " return r1, v1" 309 | ] 310 | }, 311 | { 312 | "cell_type": "markdown", 313 | "id": "19", 314 | "metadata": {}, 315 | "source": [ 316 | "We may try using this function to evolve the two particles under gravitational force, with some initial velocities." 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": null, 322 | "id": "20", 323 | "metadata": {}, 324 | "outputs": [], 325 | "source": [ 326 | "v0 = [\n", 327 | " (0.0, +0.5, 0.0),\n", 328 | " (0.0, -0.5, 0.0),\n", 329 | "]\n", 330 | "v1ref = [\n", 331 | " (-0.024984345739803245, +0.4993750014648409, 0.0),\n", 332 | " (+0.024984345739803245, -0.4993750014648409, 0.0),\n", 333 | "]" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": null, 339 | "id": "21", 340 | "metadata": {}, 341 | "outputs": [], 342 | "source": [ 343 | "r1, v1 = leapfrog1(m, r0, v0, 0.1)" 344 | ] 345 | }, 346 | { 347 | "cell_type": "code", 348 | "execution_count": null, 349 | "id": "22", 350 | "metadata": {}, 351 | "outputs": [], 352 | "source": [ 353 | "v1" 354 | ] 355 | }, 356 | { 357 | "cell_type": "code", 358 | "execution_count": null, 359 | "id": "23", 360 | "metadata": {}, 361 | "outputs": [], 362 | "source": [ 363 | "assert v1 == v1ref" 364 | ] 365 | }, 366 | { 367 | "cell_type": "markdown", 368 | "id": "24", 369 | "metadata": {}, 370 | "source": [ 371 | "### Test the Integrator\n", 372 | "\n", 373 | "We test the integrator by evolving the two particles for a time $T = 2\\pi$ using $N = 64$ steps." 
374 | ] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "execution_count": null, 379 | "id": "25", 380 | "metadata": {}, 381 | "outputs": [], 382 | "source": [ 383 | "from math import pi\n", 384 | "\n", 385 | "N = 64\n", 386 | "T = 2 * pi" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": null, 392 | "id": "26", 393 | "metadata": {}, 394 | "outputs": [], 395 | "source": [ 396 | "dt = T / N\n", 397 | "R = [r0]\n", 398 | "V = [v0]\n", 399 | "for _ in range(N):\n", 400 | " r, v = leapfrog1(m, R[-1], V[-1], dt)\n", 401 | " R.append(r)\n", 402 | " V.append(v)" 403 | ] 404 | }, 405 | { 406 | "cell_type": "markdown", 407 | "id": "27", 408 | "metadata": {}, 409 | "source": [ 410 | "Although our numerical scheme is implemented in pure Python, it is still handy to use `numpy` to slice through the data..." 411 | ] 412 | }, 413 | { 414 | "cell_type": "code", 415 | "execution_count": null, 416 | "id": "28", 417 | "metadata": {}, 418 | "outputs": [], 419 | "source": [ 420 | "import numpy as np\n", 421 | "\n", 422 | "R = np.array(R)\n", 423 | "X = R[:,:,0]\n", 424 | "Y = R[:,:,1]" 425 | ] 426 | }, 427 | { 428 | "cell_type": "markdown", 429 | "id": "29", 430 | "metadata": {}, 431 | "source": [ 432 | "and plot the result..." 433 | ] 434 | }, 435 | { 436 | "cell_type": "code", 437 | "execution_count": null, 438 | "id": "30", 439 | "metadata": {}, 440 | "outputs": [], 441 | "source": [ 442 | "from matplotlib import pyplot as plt\n", 443 | "\n", 444 | "plt.plot(X, Y, '.-')\n", 445 | "plt.gca().set_aspect('equal')" 446 | ] 447 | }, 448 | { 449 | "cell_type": "markdown", 450 | "id": "31", 451 | "metadata": {}, 452 | "source": [ 453 | "In addition, we may create an animation." 
454 | ] 455 | }, 456 | { 457 | "cell_type": "code", 458 | "execution_count": null, 459 | "id": "32", 460 | "metadata": {}, 461 | "outputs": [], 462 | "source": [ 463 | "from matplotlib.animation import ArtistAnimation\n", 464 | "from IPython.display import HTML\n", 465 | "from tqdm import tqdm\n", 466 | "\n", 467 | "def animate(X, Y, ntail=10):\n", 468 | " fig, ax = plt.subplots(1, 1, figsize=(5,5))\n", 469 | " ax.set_xlabel('x')\n", 470 | " ax.set_ylabel('y')\n", 471 | "\n", 472 | " frames = []\n", 473 | " for i in tqdm(range(len(X))):\n", 474 | " b,e = max(0, i-ntail), i+1\n", 475 | " ax.set_prop_cycle(None)\n", 476 | " f = ax.plot(X[b:e,:], Y[b:e,:], '.-', animated=True)\n", 477 | " frames.append(f)\n", 478 | "\n", 479 | " anim = ArtistAnimation(fig, frames, interval=50)\n", 480 | " plt.close()\n", 481 | " \n", 482 | " return anim" 483 | ] 484 | }, 485 | { 486 | "cell_type": "code", 487 | "execution_count": null, 488 | "id": "33", 489 | "metadata": {}, 490 | "outputs": [], 491 | "source": [ 492 | "anim = animate(X, Y)\n", 493 | "\n", 494 | "HTML(anim.to_html5_video()) # display animation\n", 495 | "# anim.save('orbitss.mp4') # save animation" 496 | ] 497 | }, 498 | { 499 | "cell_type": "markdown", 500 | "id": "34", 501 | "metadata": {}, 502 | "source": [ 503 | "## Benchmarking and Profiling Your Code\n", 504 | "\n", 505 | "Before diving into optimizations, it is essential to understand where our code spends most of its time.\n", 506 | "By benchmarking and profiling, we can pinpoint performance bottlenecks and measure improvements after optimization." 507 | ] 508 | }, 509 | { 510 | "cell_type": "markdown", 511 | "id": "35", 512 | "metadata": {}, 513 | "source": [ 514 | "### Benchmarking vs. 
Profiling\n", 515 | "\n", 516 | "* Benchmarking:\n", 517 | " Measures the overall runtime of your code.\n", 518 | " Tools like Python's `timeit` module run your code multiple times to provide an accurate average runtime.\n", 519 | " This helps in comparing the performance before and after optimizations.\n", 520 | "\n", 521 | "* Profiling:\n", 522 | " Provides detailed insights into which parts of your code are consuming the most time.\n", 523 | " For example, `cProfile` generates reports showing function call times and frequencies.\n", 524 | " Note that `cProfile` typically runs the code once, so its focus is on identifying hotspots rather than providing averaged timings." 525 | ] 526 | }, 527 | { 528 | "cell_type": "markdown", 529 | "id": "36", 530 | "metadata": {}, 531 | "source": [ 532 | "### Quick Benchmark Example\n", 533 | "\n", 534 | "`timeit` is a module designed for benchmarking by executing code multiple times.\n", 535 | "It is excellent for obtaining reliable runtime measurements." 536 | ] 537 | }, 538 | { 539 | "cell_type": "code", 540 | "execution_count": null, 541 | "id": "37", 542 | "metadata": {}, 543 | "outputs": [], 544 | "source": [ 545 | "n = 1000\n", 546 | "m = np.random.lognormal(size=n).tolist()\n", 547 | "r0 = np.random.normal(size=(n, 3)).tolist()\n", 548 | "v0 = np.random.normal(size=(n, 3)).tolist()" 549 | ] 550 | }, 551 | { 552 | "cell_type": "code", 553 | "execution_count": null, 554 | "id": "38", 555 | "metadata": {}, 556 | "outputs": [], 557 | "source": [ 558 | "%timeit r1, v1 = leapfrog1(m, r0, v0, dt)" 559 | ] 560 | }, 561 | { 562 | "cell_type": "markdown", 563 | "id": "39", 564 | "metadata": {}, 565 | "source": [ 566 | "It takes about 1.31 seconds on my laptop to run a single step of leapfrog for $n = 1000$ bodies.\n", 567 | "Your mileage may vary.\n", 568 | "But this is pretty slow by today's standards." 
569 | ] 570 | }, 571 | { 572 | "cell_type": "markdown", 573 | "id": "40", 574 | "metadata": {}, 575 | "source": [ 576 | "### Quick Profiling Example\n", 577 | "\n", 578 | "Here is a snippet using `cProfile` to profile our leapfrog stepper:" 579 | ] 580 | }, 581 | { 582 | "cell_type": "code", 583 | "execution_count": null, 584 | "id": "41", 585 | "metadata": {}, 586 | "outputs": [], 587 | "source": [ 588 | "import cProfile\n", 589 | "\n", 590 | "cProfile.run(\"r1, v1 = leapfrog1(m, r0, v0, dt)\")" 591 | ] 592 | }, 593 | { 594 | "cell_type": "markdown", 595 | "id": "42", 596 | "metadata": {}, 597 | "source": [ 598 | "From cProfile's result, the most frequently called function is `method 'append' of 'list' objects`.\n", 599 | "This is not surprising given that we have been using for-loops with `append`, e.g.,\n", 600 | "```\n", 601 | " v1 = []\n", 602 | " for i in range(n):\n", 603 | " ...\n", 604 | " v1.append(...)\n", 605 | "```\n", 606 | "in both `acc1()` and `leapfrog1()`.\n", 607 | "This is neither Pythonic nor efficient." 608 | ] 609 | }, 610 | { 611 | "cell_type": "markdown", 612 | "id": "43", 613 | "metadata": {}, 614 | "source": [ 615 | "## Optimization 1: Use List Comprehension Over For Loop\n", 616 | "\n", 617 | "Python's list comprehensions are a concise way to create lists by iterating over an iterable and applying an expression---all in a single, compact line.\n", 618 | "Internally, they are optimized by the interpreter, making them generally faster than a standard Python for-loop." 
619 | ] 620 | }, 621 | { 622 | "cell_type": "code", 623 | "execution_count": null, 624 | "id": "44", 625 | "metadata": {}, 626 | "outputs": [], 627 | "source": [ 628 | "iterable = range(10)\n", 629 | "\n", 630 | "# Standard python for-loop\n", 631 | "l1 = []\n", 632 | "for x in iterable:\n", 633 | " l1.append(x * 2)\n", 634 | "\n", 635 | "# Using a list comprehension\n", 636 | "l2 = [x * 2 for x in iterable]\n", 637 | "\n", 638 | "# Compare results\n", 639 | "print(l1)\n", 640 | "print(l2)" 641 | ] 642 | }, 643 | { 644 | "cell_type": "markdown", 645 | "id": "45", 646 | "metadata": {}, 647 | "source": [ 648 | "We may use list comprehensions to rewrite our leapfrog algorithm:" 649 | ] 650 | }, 651 | { 652 | "cell_type": "code", 653 | "execution_count": null, 654 | "id": "46", 655 | "metadata": {}, 656 | "outputs": [], 657 | "source": [ 658 | "def leapfrog2(m, r0, v0, dt, acc=acc1):\n", 659 | "\n", 660 | " n = len(m)\n", 661 | "\n", 662 | " # vh = v0 + 0.5 * dt * a0\n", 663 | " a0 = acc(m, r0)\n", 664 | " vh = [[v0[i][k] + 0.5 * dt * a0[i][k]\n", 665 | " for k in range(3)]\n", 666 | " for i in range(n)]\n", 667 | "\n", 668 | " # r1 = r0 + dt * vh\n", 669 | " r1 = [[r0[i][k] + dt * vh[i][k]\n", 670 | " for k in range(3)]\n", 671 | " for i in range(n)]\n", 672 | "\n", 673 | " # v1 = vh + 0.5 * dt * a1\n", 674 | " a1 = acc(m, r1)\n", 675 | " v1 = [[vh[i][k] + 0.5 * dt * a1[i][k]\n", 676 | " for k in range(3)]\n", 677 | " for i in range(n)]\n", 678 | "\n", 679 | " return r1, v1" 680 | ] 681 | }, 682 | { 683 | "cell_type": "code", 684 | "execution_count": null, 685 | "id": "47", 686 | "metadata": {}, 687 | "outputs": [], 688 | "source": [ 689 | "%timeit r1, v1 = leapfrog2(m, r0, v0, dt)" 690 | ] 691 | }, 692 | { 693 | "cell_type": "code", 694 | "execution_count": null, 695 | "id": "48", 696 | "metadata": {}, 697 | "outputs": [], 698 | "source": [ 699 | "cProfile.run(\"r1, v1 = leapfrog2(m, r0, v0, dt)\")" 700 | ] 701 | }, 702 | { 703 | "cell_type": "markdown", 704 | "id": 
"49", 705 | "metadata": {}, 706 | "source": [ 707 | "Depending on your Python version, you may see a slight performance increase from `timeit` and a decrease in the number of calls to `append`.\n", 708 | "It takes about 1.29 seconds on my laptop to run a single step of leapfrog for $n = 1000$ bodies." 709 | ] 710 | }, 711 | { 712 | "cell_type": "markdown", 713 | "id": "50", 714 | "metadata": {}, 715 | "source": [ 716 | "## Optimization 2: Reduce Operation Count\n", 717 | "\n", 718 | "Reducing the operation count means cutting down on unnecessary calculations and redundant function calls.\n", 719 | "When operations are executed millions of times---as in inner loops or simulations---even small optimizations can yield significant speed improvements.\n", 720 | "\n", 721 | "For example, precomputing constant values or combining multiple arithmetic steps into one reduces repetitive work.\n", 722 | "In the context of the $n$-body problem, calculate invariant quantities once outside of critical loops instead of recalculating them every time.\n", 723 | "This streamlined approach not only speeds up your code but also helps further optimizations like vectorization and JIT compilation." 724 | ] 725 | }, 726 | { 727 | "cell_type": "markdown", 728 | "id": "51", 729 | "metadata": {}, 730 | "source": [ 731 | "Recall the benchmark using `leapfrog2()` with `acc1()`." 732 | ] 733 | }, 734 | { 735 | "cell_type": "code", 736 | "execution_count": null, 737 | "id": "52", 738 | "metadata": {}, 739 | "outputs": [], 740 | "source": [ 741 | "%timeit r1, v1 = leapfrog2(m, r0, v0, dt, acc=acc1)" 742 | ] 743 | }, 744 | { 745 | "cell_type": "markdown", 746 | "id": "53", 747 | "metadata": {}, 748 | "source": [ 749 | "We notice that $r_{ij}^3 = |\\mathbf{r}_i - \\mathbf{r}_j|^3$ is used in all components of $\\mathbf{a}_{ij}$.\n", 750 | "Instead of recomputing it, we can simply cache it in a variable `rrr`." 
751 | ] 752 | }, 753 | { 754 | "cell_type": "code", 755 | "execution_count": null, 756 | "id": "54", 757 | "metadata": {}, 758 | "outputs": [], 759 | "source": [ 760 | "def acc2(m, r):\n", 761 | " \n", 762 | " n = len(m)\n", 763 | " a = []\n", 764 | " for i in range(n):\n", 765 | " axi, ayi, azi = 0, 0, 0\n", 766 | " for j in range(n):\n", 767 | " if j != i:\n", 768 | " xi, yi, zi = r[i]\n", 769 | " xj, yj, zj = r[j]\n", 770 | "\n", 771 | " # \"Cache\" r^3\n", 772 | " rrr = ((xi - xj)**2 + (yi - yj)**2 + (zi - zj)**2)**(3/2)\n", 773 | " \n", 774 | " axij = - m[j] * (xi - xj) / rrr\n", 775 | " ayij = - m[j] * (yi - yj) / rrr\n", 776 | " azij = - m[j] * (zi - zj) / rrr\n", 777 | "\n", 778 | " axi += axij\n", 779 | " ayi += ayij\n", 780 | " azi += azij\n", 781 | "\n", 782 | " a.append((axi, ayi, azi))\n", 783 | "\n", 784 | " return a" 785 | ] 786 | }, 787 | { 788 | "cell_type": "code", 789 | "execution_count": null, 790 | "id": "55", 791 | "metadata": {}, 792 | "outputs": [], 793 | "source": [ 794 | "%timeit r1, v1 = leapfrog2(m, r0, v0, dt, acc=acc2)" 795 | ] 796 | }, 797 | { 798 | "cell_type": "markdown", 799 | "id": "56", 800 | "metadata": {}, 801 | "source": [ 802 | "This already reduces the benchmark time by about 45%!" 803 | ] 804 | }, 805 | { 806 | "cell_type": "markdown", 807 | "id": "57", 808 | "metadata": {}, 809 | "source": [ 810 | "But we don't have to stop here.\n", 811 | "The different components of $\\mathbf{r}_{ij} = \\mathbf{r}_i - \\mathbf{r}_j$ can be cached as `dx`, `dy`, `dz`, too." 
812 | ] 813 | }, 814 | { 815 | "cell_type": "code", 816 | "execution_count": null, 817 | "id": "58", 818 | "metadata": {}, 819 | "outputs": [], 820 | "source": [ 821 | "def acc3(m, r):\n", 822 | " \n", 823 | " n = len(m)\n", 824 | " a = []\n", 825 | " for i in range(n):\n", 826 | " axi, ayi, azi = 0, 0, 0\n", 827 | " for j in range(n):\n", 828 | " if j != i:\n", 829 | " xi, yi, zi = r[i]\n", 830 | " xj, yj, zj = r[j]\n", 831 | "\n", 832 | " # \"Cache\" the components of dr = r_ij = ri - rj\n", 833 | " dx = xi - xj\n", 834 | " dy = yi - yj\n", 835 | " dz = zi - zj\n", 836 | "\n", 837 | " rrr = (dx**2 + dy**2 + dz**2)**(3/2)\n", 838 | " \n", 839 | " axij = - m[j] * dx / rrr\n", 840 | " ayij = - m[j] * dy / rrr\n", 841 | " azij = - m[j] * dz / rrr\n", 842 | "\n", 843 | " axi += axij\n", 844 | " ayi += ayij\n", 845 | " azi += azij\n", 846 | "\n", 847 | " a.append((axi, ayi, azi))\n", 848 | "\n", 849 | " return a" 850 | ] 851 | }, 852 | { 853 | "cell_type": "code", 854 | "execution_count": null, 855 | "id": "59", 856 | "metadata": {}, 857 | "outputs": [], 858 | "source": [ 859 | "%timeit r1, v1 = leapfrog2(m, r0, v0, dt, acc=acc3)" 860 | ] 861 | }, 862 | { 863 | "cell_type": "markdown", 864 | "id": "60", 865 | "metadata": {}, 866 | "source": [ 867 | "Similarly, we can cache $-m_j / |\\mathbf{r}_{ij}|^3$." 
868 | ] 869 | }, 870 | { 871 | "cell_type": "code", 872 | "execution_count": null, 873 | "id": "61", 874 | "metadata": {}, 875 | "outputs": [], 876 | "source": [ 877 | "def acc4(m, r):\n", 878 | " \n", 879 | " n = len(m)\n", 880 | " a = []\n", 881 | " for i in range(n):\n", 882 | " axi, ayi, azi = 0, 0, 0\n", 883 | " for j in range(n):\n", 884 | " if j != i:\n", 885 | " xi, yi, zi = r[i]\n", 886 | " xj, yj, zj = r[j]\n", 887 | "\n", 888 | " dx = xi - xj\n", 889 | " dy = yi - yj\n", 890 | " dz = zi - zj\n", 891 | "\n", 892 | " rrr = (dx**2 + dy**2 + dz**2)**(3/2)\n", 893 | " f = - m[j] / rrr # \"cache\" -m_j / r_{ij}^3\n", 894 | "\n", 895 | " axi += f * dx\n", 896 | " ayi += f * dy\n", 897 | " azi += f * dz\n", 898 | "\n", 899 | " a.append((axi, ayi, azi))\n", 900 | "\n", 901 | " return a" 902 | ] 903 | }, 904 | { 905 | "cell_type": "code", 906 | "execution_count": null, 907 | "id": "62", 908 | "metadata": {}, 909 | "outputs": [], 910 | "source": [ 911 | "%timeit r1, v1 = leapfrog2(m, r0, v0, dt, acc=acc4)" 912 | ] 913 | }, 914 | { 915 | "cell_type": "markdown", 916 | "id": "63", 917 | "metadata": {}, 918 | "source": [ 919 | "Finally, we notice the symmetry $m_i \\mathbf{a}_{ij} = -m_j \\mathbf{a}_{ji}$, which follows from Newton's third law, $\\mathbf{f}_{ij} = -\\mathbf{f}_{ji}$.\n", 920 | "In principle, we only need to evaluate each pair once, e.g., for $j < i$.\n", 921 | "However, this requires that we pre-allocate a list-of-list.\n", 922 | "Just creating the list-of-list and using it to keep track of the acceleration actually increases the benchmark time." 
923 | ] 924 | }, 925 | { 926 | "cell_type": "code", 927 | "execution_count": null, 928 | "id": "64", 929 | "metadata": {}, 930 | "outputs": [], 931 | "source": [ 932 | "def acc5(m, r):\n", 933 | " \n", 934 | " n = len(m)\n", 935 | " a = [[0.0]*3 for _ in range(n)] # create a list-of-list; `[[0]*3]*n` would alias one row n times\n", 936 | " for i in range(n):\n", 937 | " for j in range(n):\n", 938 | " if j != i:\n", 939 | " xi, yi, zi = r[i]\n", 940 | " xj, yj, zj = r[j]\n", 941 | "\n", 942 | " dx = xi - xj\n", 943 | " dy = yi - yj\n", 944 | " dz = zi - zj\n", 945 | "\n", 946 | " rrr = (dx**2 + dy**2 + dz**2)**(3/2)\n", 947 | " f = - m[j] / rrr\n", 948 | "\n", 949 | " # Use the list-of-list to keep track of the acceleration\n", 950 | " a[i][0] += f * dx\n", 951 | " a[i][1] += f * dy\n", 952 | " a[i][2] += f * dz\n", 953 | "\n", 954 | " return a" 955 | ] 956 | }, 957 | { 958 | "cell_type": "code", 959 | "execution_count": null, 960 | "id": "65", 961 | "metadata": {}, 962 | "outputs": [], 963 | "source": [ 964 | "%timeit r1, v1 = leapfrog2(m, r0, v0, dt, acc=acc5)" 965 | ] 966 | }, 967 | { 968 | "cell_type": "markdown", 969 | "id": "66", 970 | "metadata": {}, 971 | "source": [ 972 | "But once we have the list-of-list, we may change the upper bound of the inner loop to cut the computation almost in half.\n", 973 | "This is the fastest code so far." 
974 | ] 975 | }, 976 | { 977 | "cell_type": "code", 978 | "execution_count": null, 979 | "id": "67", 980 | "metadata": {}, 981 | "outputs": [], 982 | "source": [ 983 | "def acc6(m, r):\n", 984 | " \n", 985 | " n = len(m)\n", 986 | " a = [[0.0]*3 for _ in range(n)] # independent rows; `[[0]*3]*n` would alias one row n times\n", 987 | " for i in range(n):\n", 988 | " for j in range(i): # adjust the upper bound of the inner loop\n", 989 | " xi, yi, zi = r[i]\n", 990 | " xj, yj, zj = r[j]\n", 991 | "\n", 992 | " dx = xi - xj\n", 993 | " dy = yi - yj\n", 994 | " dz = zi - zj\n", 995 | "\n", 996 | " rrr = (dx**2 + dy**2 + dz**2)**(3/2)\n", 997 | " fi = m[i] / rrr\n", 998 | " fj = - m[j] / rrr\n", 999 | " \n", 1000 | " a[i][0] += fj * dx\n", 1001 | " a[i][1] += fj * dy\n", 1002 | " a[i][2] += fj * dz\n", 1003 | "\n", 1004 | " # Accumulate the equal-and-opposite contribution on the j-th body.\n", 1005 | " a[j][0] += fi * dx\n", 1006 | " a[j][1] += fi * dy\n", 1007 | " a[j][2] += fi * dz\n", 1008 | "\n", 1009 | " return a" 1010 | ] 1011 | }, 1012 | { 1013 | "cell_type": "code", 1014 | "execution_count": null, 1015 | "id": "68", 1016 | "metadata": {}, 1017 | "outputs": [], 1018 | "source": [ 1019 | "%timeit r1, v1 = leapfrog2(m, r0, v0, dt, acc=acc6)" 1020 | ] 1021 | }, 1022 | { 1023 | "cell_type": "markdown", 1024 | "id": "69", 1025 | "metadata": {}, 1026 | "source": [ 1027 | "However, `acc?()` is not the only function we can optimize to reduce operation count.\n", 1028 | "If we study `leapfrog?()` carefully, the acceleration `a` in the second \"kick\" calculation can actually be reused by the first \"kick\" of the next step.\n", 1029 | "This requires modifying the function prototype a bit." 
1030 | ] 1031 | }, 1032 | { 1033 | "cell_type": "code", 1034 | "execution_count": null, 1035 | "id": "70", 1036 | "metadata": {}, 1037 | "outputs": [], 1038 | "source": [ 1039 | "def leapfrog3(m, r0, v0, a0, dt, acc=acc6):\n", 1040 | "\n", 1041 | " n = len(m)\n", 1042 | "\n", 1043 | " # vh = v0 + 0.5 * dt * a0\n", 1044 | " # a0 = acc(m, r0) <--- we comment this out, and reuse the acceleration computed in the second \"kick\" of the previous step\n", 1045 | " vh = [[v0[i][k] + 0.5 * dt * a0[i][k]\n", 1046 | " for k in range(3)]\n", 1047 | " for i in range(n)]\n", 1048 | "\n", 1049 | " # r1 = r0 + dt * vh\n", 1050 | " r1 = [[r0[i][k] + dt * vh[i][k]\n", 1051 | " for k in range(3)]\n", 1052 | " for i in range(n)]\n", 1053 | "\n", 1054 | " # v1 = vh + 0.5 * dt * a1\n", 1055 | " a1 = acc(m, r1)\n", 1056 | " v1 = [[vh[i][k] + 0.5 * dt * a1[i][k]\n", 1057 | " for k in range(3)]\n", 1058 | " for i in range(n)]\n", 1059 | "\n", 1060 | " return r1, v1, a1 # <--- to reuse the acceleration computed in the second \"kick\", let's return it." 
1061 | ] 1062 | }, 1063 | { 1064 | "cell_type": "code", 1065 | "execution_count": null, 1066 | "id": "71", 1067 | "metadata": {}, 1068 | "outputs": [], 1069 | "source": [ 1070 | "# Precompute `a0`\n", 1071 | "a0 = acc6(m, r0) " 1072 | ] 1073 | }, 1074 | { 1075 | "cell_type": "code", 1076 | "execution_count": null, 1077 | "id": "72", 1078 | "metadata": {}, 1079 | "outputs": [], 1080 | "source": [ 1081 | "%timeit r1, v1, a1 = leapfrog3(m, r0, v0, a0, dt, acc=acc4)" 1082 | ] 1083 | }, 1084 | { 1085 | "cell_type": "code", 1086 | "execution_count": null, 1087 | "id": "73", 1088 | "metadata": {}, 1089 | "outputs": [], 1090 | "source": [ 1091 | "%timeit r1, v1, a1 = leapfrog3(m, r0, v0, a0, dt, acc=acc6)" 1092 | ] 1093 | }, 1094 | { 1095 | "cell_type": "markdown", 1096 | "id": "74", 1097 | "metadata": {}, 1098 | "source": [ 1099 | "Remarkably, it takes about 241ms on my laptop to run a single step of leapfrog for $n = 1000$ bodies.\n", 1100 | "This is almost a 92% reduction in benchmark time!" 1101 | ] 1102 | }, 1103 | { 1104 | "cell_type": "markdown", 1105 | "id": "75", 1106 | "metadata": {}, 1107 | "source": [ 1108 | "## Optimization 3: Use Fast Operations\n", 1109 | "\n", 1110 | "Certain operations in Python can be surprisingly slow.\n", 1111 | "For example, the power operator (`**`) often calls C's `pow()` function, which for non-integer or variable exponents is typically implemented as:\n", 1112 | "\\begin{align}\n", 1113 | " x^y = \\exp[y \\cdot \\ln(x)]\n", 1114 | "\\end{align}\n", 1115 | "This method involves calculating a logarithm, a multiplication, and an exponential---operations that are much slower than a simple multiplication." 
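A quick way to check this on your own machine is a micro-benchmark with the standard `timeit` module. The sketch below is only illustrative: the values of `x` and `rr` are arbitrary, and the `rr*sqrt(rr)` variant is an extra comparison that the notebook itself does not use. Timings depend on the interpreter version and hardware, so no expected numbers are quoted here.

```python
import math
import timeit

x = 1.2345
rr = 3.21  # arbitrary stand-in for a squared distance

# Squaring: generic power operator vs. a single multiplication
t_pow = timeit.timeit("x**2", globals={"x": x}, number=200_000)
t_mul = timeit.timeit("x*x", globals={"x": x}, number=200_000)

# Cubing a distance: fractional power vs. multiplication plus square root
t_pow32 = timeit.timeit("rr**(3/2)", globals={"rr": rr}, number=200_000)
t_sqrt = timeit.timeit("rr*sqrt(rr)", globals={"rr": rr, "sqrt": math.sqrt},
                       number=200_000)

print(f"x**2      : {t_pow:.4f} s    x*x        : {t_mul:.4f} s")
print(f"rr**(3/2) : {t_pow32:.4f} s    rr*sqrt(rr): {t_sqrt:.4f} s")

# Both spellings give the same value up to floating-point rounding
same_square = math.isclose(x**2, x*x, rel_tol=1e-12)
same_cube = math.isclose(rr**(3/2), rr*math.sqrt(rr), rel_tol=1e-12)
print(same_square, same_cube)
```

In pure Python the function-call overhead of `sqrt` can cancel part of the gain, which is one more reason to measure rather than assume.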
1116 | ] 1117 | }, 1118 | { 1119 | "cell_type": "markdown", 1120 | "id": "76", 1121 | "metadata": {}, 1122 | "source": [ 1123 | "For instance, if you need to square a number, it is usually faster to write `x*x` rather than `x**2`.\n", 1124 | "When the exponent is a known small integer, manually multiplying the base is usually the quickest route.\n", 1125 | "\n", 1126 | "By choosing fast operations over more generic ones, you can shave off microseconds in code that runs millions of times, which adds up to a measurable performance boost." 1127 | ] 1128 | }, 1129 | { 1130 | "cell_type": "code", 1131 | "execution_count": null, 1132 | "id": "77", 1133 | "metadata": {}, 1134 | "outputs": [], 1135 | "source": [ 1136 | "def acc7(m, r): # this is a modification of acc4()\n", 1137 | " \n", 1138 | " n = len(m)\n", 1139 | " a = []\n", 1140 | " for i in range(n):\n", 1141 | " axi, ayi, azi = 0, 0, 0\n", 1142 | " for j in range(n):\n", 1143 | " if j != i:\n", 1144 | " xi, yi, zi = r[i]\n", 1145 | " xj, yj, zj = r[j]\n", 1146 | "\n", 1147 | " dx = xi - xj\n", 1148 | " dy = yi - yj\n", 1149 | " dz = zi - zj\n", 1150 | "\n", 1151 | " rrr = (dx*dx + dy*dy + dz*dz)**(3/2) # replace dx**2 by dx*dx etc\n", 1152 | " f = - m[j] / rrr\n", 1153 | "\n", 1154 | " axi += f * dx\n", 1155 | " ayi += f * dy\n", 1156 | " azi += f * dz\n", 1157 | "\n", 1158 | " a.append((axi, ayi, azi))\n", 1159 | "\n", 1160 | " return a" 1161 | ] 1162 | }, 1163 | { 1164 | "cell_type": "code", 1165 | "execution_count": null, 1166 | "id": "78", 1167 | "metadata": {}, 1168 | "outputs": [], 1169 | "source": [ 1170 | "%timeit r1, v1, a1 = leapfrog3(m, r0, v0, a0, dt, acc=acc7)" 1171 | ] 1172 | }, 1173 | { 1174 | "cell_type": "code", 1175 | "execution_count": null, 1176 | "id": "79", 1177 | "metadata": {}, 1178 | "outputs": [], 1179 | "source": [ 1180 | "def acc8(m, r): # this is a modification of acc6()\n", 1181 | " \n", 1182 | " n = len(m)\n", 1183 | " a = [[0.0]*3 for _ in range(n)] # independent rows; [[0]*3]*n would alias one row n times\n", 1184 | " for i in range(n):\n", 1185 | " for j in 
range(i):\n", 1186 | " xi, yi, zi = r[i]\n", 1187 | " xj, yj, zj = r[j]\n", 1188 | "\n", 1189 | " dx = xi - xj\n", 1190 | " dy = yi - yj\n", 1191 | " dz = zi - zj\n", 1192 | "\n", 1193 | " rrr = (dx*dx + dy*dy + dz*dz)**(3/2) # replace dx**2 by dx*dx etc\n", 1194 | " fi = m[i] / rrr\n", 1195 | " fj = - m[j] / rrr\n", 1196 | " \n", 1197 | " a[i][0] += fj * dx\n", 1198 | " a[i][1] += fj * dy\n", 1199 | " a[i][2] += fj * dz\n", 1200 | "\n", 1201 | " a[j][0] += fi * dx\n", 1202 | " a[j][1] += fi * dy\n", 1203 | " a[j][2] += fi * dz\n", 1204 | "\n", 1205 | " return a" 1206 | ] 1207 | }, 1208 | { 1209 | "cell_type": "code", 1210 | "execution_count": null, 1211 | "id": "80", 1212 | "metadata": {}, 1213 | "outputs": [], 1214 | "source": [ 1215 | "%timeit r1, v1, a1 = leapfrog3(m, r0, v0, a0, dt, acc=acc8)" 1216 | ] 1217 | }, 1218 | { 1219 | "cell_type": "markdown", 1220 | "id": "81", 1221 | "metadata": {}, 1222 | "source": [ 1223 | "Surprisingly, although `acc7()` does not take advantage of the symmetry of $\\mathbf{a}_{ij}$, it is now faster than `acc6()`.\n", 1224 | "On the other hand, `acc8()` is still slightly faster than `acc7()`." 1225 | ] 1226 | }, 1227 | { 1228 | "cell_type": "markdown", 1229 | "id": "82", 1230 | "metadata": {}, 1231 | "source": [ 1232 | "Without external libraries, we've cut benchmark time by nearly 93%---a 14x speedup over our original $n$-body implementation.\n", 1233 | "While impressive even compared to compiled languages, this is just the beginning.\n", 1234 | "By leveraging high-performance libraries, we can overcome Python’s inherent slowness and move closer to our goal of a 1000x speedup." 
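The two ways of quoting progress used above, percentage reduction in run time and speedup factor, are related by speedup = 1 / (1 - reduction). The small helper functions below are hypothetical conveniences, not part of the notebook; they make it easy to convert between the two.

```python
def speedup_from_reduction(reduction):
    """Convert a fractional run-time reduction (0.93 for 93%) to a speedup factor."""
    return 1.0 / (1.0 - reduction)

def reduction_from_speedup(speedup):
    """Convert a speedup factor (14 for 14x) to a fractional run-time reduction."""
    return 1.0 - 1.0 / speedup

# A 93% reduction corresponds to roughly a 14x speedup.
print(f"{speedup_from_reduction(0.93):.1f}x")

# Large speedups all crowd into the last percent of reduction.
for s in (14, 100, 1000):
    print(f"{s:5d}x -> {100 * reduction_from_speedup(s):.2f}% reduction")
```

Because every speedup beyond 100x maps to a reduction above 99%, the speedup factor is the more informative number once the code gets fast.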
1235 | ] 1236 | }, 1237 | { 1238 | "cell_type": "markdown", 1239 | "id": "83", 1240 | "metadata": {}, 1241 | "source": [ 1242 | "## Optimization 4: Use High Performance Libraries\n", 1243 | "\n", 1244 | "Leveraging high-performance libraries like NumPy is key to accelerating Python code.\n", 1245 | "NumPy's vectorized operations, implemented in C, allow you to perform complex computations on large arrays far more efficiently than native Python loops.\n", 1246 | "This means you can offload heavy calculations to optimized, low-level routines and achieve significant speedups." 1247 | ] 1248 | }, 1249 | { 1250 | "cell_type": "code", 1251 | "execution_count": null, 1252 | "id": "84", 1253 | "metadata": {}, 1254 | "outputs": [], 1255 | "source": [ 1256 | "def acc9(m, r): # this is a modification of acc7(); acc8() is difficult to take advantage of numpy\n", 1257 | "\n", 1258 | " # idx: i j i j\n", 1259 | " # v v v v\n", 1260 | " dr = r[:,None,:] - r[None,:,:]\n", 1261 | " rr = np.sum(dr * dr, axis=-1) # sum over vector components\n", 1262 | "\n", 1263 | " # Ensure rr is non-zero\n", 1264 | " rr = np.maximum(rr, 1e-24)\n", 1265 | "\n", 1266 | " # idx: i j\n", 1267 | " # v v\n", 1268 | " f = -m[None,:] / rr**(3/2)\n", 1269 | " a = np.sum(f[:,:,None] * dr, axis=1) # sum over j\n", 1270 | " \n", 1271 | " return a" 1272 | ] 1273 | }, 1274 | { 1275 | "cell_type": "code", 1276 | "execution_count": null, 1277 | "id": "85", 1278 | "metadata": {}, 1279 | "outputs": [], 1280 | "source": [ 1281 | "def leapfrog4(m, r0, v0, a0, dt, acc=acc9):\n", 1282 | " ht = 0.5 * dt\n", 1283 | "\n", 1284 | " vh = v0 + ht * a0\n", 1285 | " r1 = r0 + dt * vh\n", 1286 | " \n", 1287 | " a1 = acc(m, r1)\n", 1288 | " v1 = vh + ht * a1\n", 1289 | "\n", 1290 | " return r1, v1, a1" 1291 | ] 1292 | }, 1293 | { 1294 | "cell_type": "code", 1295 | "execution_count": null, 1296 | "id": "86", 1297 | "metadata": {}, 1298 | "outputs": [], 1299 | "source": [ 1300 | "m = np.array(m)\n", 1301 | "r0 = np.array(r0)\n", 
1302 | "v0 = np.array(v0)\n", 1303 | "a0 = np.array(a0)" 1304 | ] 1305 | }, 1306 | { 1307 | "cell_type": "code", 1308 | "execution_count": null, 1309 | "id": "87", 1310 | "metadata": {}, 1311 | "outputs": [], 1312 | "source": [ 1313 | "%timeit r1, v1, a1 = leapfrog4(m, r0, v0, a0, dt, acc=acc9)" 1314 | ] 1315 | }, 1316 | { 1317 | "cell_type": "code", 1318 | "execution_count": null, 1319 | "id": "88", 1320 | "metadata": {}, 1321 | "outputs": [], 1322 | "source": [ 1323 | "cProfile.run(\"r1, v1, a1 = leapfrog4(m, r0, v0, a0, dt, acc=acc9)\")" 1324 | ] 1325 | }, 1326 | { 1327 | "cell_type": "markdown", 1328 | "id": "89", 1329 | "metadata": {}, 1330 | "source": [ 1331 | "This runs at 38.3ms on my laptop.\n", 1332 | "This cuts the benchmark time by 99% (which is no longer a good indicator) and reaches an 82x speedup!" 1333 | ] 1334 | }, 1335 | { 1336 | "cell_type": "markdown", 1337 | "id": "90", 1338 | "metadata": {}, 1339 | "source": [ 1340 | "## Optimization 4: Use Lower Precision\n", 1341 | "\n", 1342 | "Using lower precision arithmetic can reduce memory usage and potentially speed up calculations.\n", 1343 | "However, the benefits depend heavily on your hardware.\n", 1344 | "For instance, on Intel x86 platforms, single precision may actually be slower than double precision due to conversion overhead and how the hardware is optimized.\n", 1345 | "Always test and profile on your target system before switching to lower precision." 
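Speed is only half of the trade-off: single precision also keeps fewer significant digits. The sketch below illustrates this with the standard library's `array` module rather than NumPy, assuming the usual IEEE-754 layout in which type code `"f"` is a 4-byte float (like `np.single`) and `"d"` is an 8-byte float (like `np.double`). The sample values are arbitrary.

```python
from array import array

values = [0.1, 1.0 / 3.0, 1.0 + 1e-8]

single = array("f", values)  # 4-byte floats, comparable to np.single
double = array("d", values)  # 8-byte floats, comparable to np.double

# Half the memory per element...
print(single.itemsize, double.itemsize)

# ...but only about 7 significant decimal digits instead of about 16.
rel_errors = [abs(s - d) / abs(d) for s, d in zip(single, double)]
for d, s, e in zip(double, single, rel_errors):
    print(f"{d!r:>22} stored in single precision as {s!r:<22} (rel. error {e:.1e})")
```

The last value shows the practical risk for an integrator: an increment of `1e-8` added to a position of order one is lost entirely in single precision, so very small time steps can stop moving the bodies.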
1346 | ] 1347 | }, 1348 | { 1349 | "cell_type": "code", 1350 | "execution_count": null, 1351 | "id": "91", 1352 | "metadata": {}, 1353 | "outputs": [], 1354 | "source": [ 1355 | "m = np.array(m, dtype=np.single)\n", 1356 | "r0 = np.array(r0, dtype=np.single)\n", 1357 | "v0 = np.array(v0, dtype=np.single)\n", 1358 | "a0 = np.array(a0, dtype=np.single)" 1359 | ] 1360 | }, 1361 | { 1362 | "cell_type": "code", 1363 | "execution_count": null, 1364 | "id": "92", 1365 | "metadata": {}, 1366 | "outputs": [], 1367 | "source": [ 1368 | "%timeit r1, v1, a1 = leapfrog4(m, r0, v0, a0, dt, acc=acc9)" 1369 | ] 1370 | }, 1371 | { 1372 | "cell_type": "markdown", 1373 | "id": "93", 1374 | "metadata": {}, 1375 | "source": [ 1376 | "## Optimization 5: Google `JAX`\n", 1377 | "\n", 1378 | "Google `JAX` is a high-performance library that extends NumPy with automatic differentiation and just-in-time (JIT) compilation.\n", 1379 | "It transforms Python functions into highly optimized machine code using XLA, which can dramatically accelerate numerical computations.\n", 1380 | "\n", 1381 | "One of `JAX`'s standout features is its seamless support for GPUs.\n", 1382 | "Depending on your hardware, running your code on a GPU with `JAX` can lead to even greater speedups compared to CPU execution.\n", 1383 | "Additionally, `JAX`'s vectorization tools, like `vmap`, let you apply functions over arrays efficiently without explicit Python loops.\n", 1384 | "\n", 1385 | "By replacing NumPy with `JAX`, we can harness the power of hardware acceleration and achieve the ambitious 1000x speedup!" 
1386 | ] 1387 | }, 1388 | { 1389 | "cell_type": "code", 1390 | "execution_count": null, 1391 | "id": "94", 1392 | "metadata": {}, 1393 | "outputs": [], 1394 | "source": [ 1395 | "from jax import numpy as jnp" 1396 | ] 1397 | }, 1398 | { 1399 | "cell_type": "code", 1400 | "execution_count": null, 1401 | "id": "95", 1402 | "metadata": {}, 1403 | "outputs": [], 1404 | "source": [ 1405 | "def acc10(m, r):\n", 1406 | " dr = r[:,None,:] - r[None,:,:]\n", 1407 | " rr = jnp.sum(dr * dr, axis=-1)\n", 1408 | " rr = jnp.maximum(rr, 1e-24)\n", 1409 | " f = -m[None,:] / rr**(3/2)\n", 1410 | " a = jnp.sum(f[:,:,None] * dr, axis=1)\n", 1411 | " return a" 1412 | ] 1413 | }, 1414 | { 1415 | "cell_type": "code", 1416 | "execution_count": null, 1417 | "id": "96", 1418 | "metadata": {}, 1419 | "outputs": [], 1420 | "source": [ 1421 | "m = jnp.array(m)\n", 1422 | "r0 = jnp.array(r0)\n", 1423 | "v0 = jnp.array(v0)\n", 1424 | "a0 = jnp.array(a0)" 1425 | ] 1426 | }, 1427 | { 1428 | "cell_type": "code", 1429 | "execution_count": null, 1430 | "id": "97", 1431 | "metadata": {}, 1432 | "outputs": [], 1433 | "source": [ 1434 | "%timeit r1, v1, a1 = leapfrog4(m, r0, v0, a0, dt, acc=acc10)" 1435 | ] 1436 | }, 1437 | { 1438 | "cell_type": "code", 1439 | "execution_count": null, 1440 | "id": "98", 1441 | "metadata": {}, 1442 | "outputs": [], 1443 | "source": [ 1444 | "cProfile.run(\"r1, v1, a1 = leapfrog4(m, r0, v0, a0, dt, acc=acc10)\")" 1445 | ] 1446 | }, 1447 | { 1448 | "cell_type": "markdown", 1449 | "id": "99", 1450 | "metadata": {}, 1451 | "source": [ 1452 | "The timeit results on my laptop show that the slowest run took ~ 11 times longer than the fastest, with an average execution time of 3ms.\n", 1453 | "This variability is likely due to the initialization of the JAX library and potential data transfer overhead.\n", 1454 | "Nevertheless, these results are astonishing---achieving a 1023x speedup over our original implementation!" 
1455 | ] 1456 | }, 1457 | { 1458 | "cell_type": "markdown", 1459 | "id": "100", 1460 | "metadata": {}, 1461 | "source": [ 1462 | "## Optimization 6: Just-in-Time Compilation\n", 1463 | "\n", 1464 | "Just-in-Time (JIT) compilation converts Python code into optimized machine code at runtime.\n", 1465 | "This means that hot functions can be compiled on the fly, eliminating Python's overhead and allowing them to run at speeds much closer to those of native C code.\n", 1466 | "\n", 1467 | "`JAX` offers powerful JIT capabilities through its `jax.jit` decorator.\n", 1468 | "When applied, it compiles your numerical functions into efficient, low-level code, which can be further accelerated on GPUs.\n", 1469 | "This results in dramatic performance improvements, exceeding our goal of a 1000x speedup." 1470 | ] 1471 | }, 1472 | { 1473 | "cell_type": "code", 1474 | "execution_count": null, 1475 | "id": "101", 1476 | "metadata": {}, 1477 | "outputs": [], 1478 | "source": [ 1479 | "from jax import jit" 1480 | ] 1481 | }, 1482 | { 1483 | "cell_type": "code", 1484 | "execution_count": null, 1485 | "id": "102", 1486 | "metadata": {}, 1487 | "outputs": [], 1488 | "source": [ 1489 | "@jit\n", 1490 | "def leapfrog5(m, r0, v0, a0, dt):\n", 1491 | " return leapfrog4(m, r0, v0, a0, dt, acc=acc10)" 1492 | ] 1493 | }, 1494 | { 1495 | "cell_type": "code", 1496 | "execution_count": null, 1497 | "id": "103", 1498 | "metadata": {}, 1499 | "outputs": [], 1500 | "source": [ 1501 | "%timeit r1, v1, a1 = leapfrog5(m, r0, v0, a0, dt)" 1502 | ] 1503 | }, 1504 | { 1505 | "cell_type": "code", 1506 | "execution_count": null, 1507 | "id": "104", 1508 | "metadata": {}, 1509 | "outputs": [], 1510 | "source": [ 1511 | "cProfile.run(\"r1, v1, a1 = leapfrog5(m, r0, v0, a0, dt)\")" 1512 | ] 1513 | }, 1514 | { 1515 | "cell_type": "markdown", 1516 | "id": "105", 1517 | "metadata": {}, 1518 | "source": [ 1519 | "Our final results are truly remarkable.\n", 1520 | "By applying `JAX`'s JIT compilation, we 
compiled away all overhead (see the output of `cProfile`) and achieved a 1710x speedup over the original implementation.\n", 1521 | "This demonstrates how powerful the optimization techniques we have introduced are for improving Python's performance." 1522 | ] 1523 | }, 1524 | { 1525 | "cell_type": "markdown", 1526 | "id": "106", 1527 | "metadata": {}, 1528 | "source": [ 1529 | "## Applying the $n$-Body Integrator\n", 1530 | "\n", 1531 | "With our optimized $n$-body integrator in hand, we can now explore its applications.\n", 1532 | "This efficient Python code lets us simulate various dynamical systems---from celestial mechanics to particle interactions---with different initial conditions.\n", 1533 | "By reading in data files that specify masses, positions, and velocities, we can easily tailor simulations to real-world scenarios." 1534 | ] 1535 | }, 1536 | { 1537 | "cell_type": "code", 1538 | "execution_count": null, 1539 | "id": "107", 1540 | "metadata": {}, 1541 | "outputs": [], 1542 | "source": [ 1543 | "def integrate(fname, T, N):\n", 1544 | "\n", 1545 | " # Load data\n", 1546 | " ic = np.genfromtxt(fname)\n", 1547 | "\n", 1548 | " # Setup initial conditions\n", 1549 | " m = jnp.array(ic[:,0])\n", 1550 | " R = [jnp.array(ic[:,1:4])]\n", 1551 | " V = [jnp.array(ic[:,4:7])]\n", 1552 | " A = [acc10(m, R[-1])]\n", 1553 | "\n", 1554 | " # Main loop\n", 1555 | " T, dt = jnp.linspace(0, T, N+1, retstep=True)\n", 1556 | " for _ in tqdm(range(N)):\n", 1557 | " r, v, a = leapfrog4(m, R[-1], V[-1], A[-1], dt, acc=acc10)\n", 1558 | " R.append(r)\n", 1559 | " V.append(v)\n", 1560 | " A.append(a)\n", 1561 | "\n", 1562 | " # Return results\n", 1563 | " return T, jnp.array(R), jnp.array(V), jnp.array(A)" 1564 | ] 1565 | }, 1566 | { 1567 | "cell_type": "markdown", 1568 | "id": "108", 1569 | "metadata": {}, 1570 | "source": [ 1571 | "The figure-8 orbit is a fascinating solution to the three-body problem where three equal-mass bodies follow a single, intertwined figure-8 path.\n", 1572 | 
"Each body moves along the same curve, perfectly choreographed so that they never collide, yet their gravitational interactions keep them in a stable, periodic dance." 1573 | ] 1574 | }, 1575 | { 1576 | "cell_type": "markdown", 1577 | "id": "109", 1578 | "metadata": {}, 1579 | "source": [ 1580 | "We first download the initial condition from GitHub if the file doesn't exist:" 1581 | ] 1582 | }, 1583 | { 1584 | "cell_type": "code", 1585 | "execution_count": null, 1586 | "id": "110", 1587 | "metadata": {}, 1588 | "outputs": [], 1589 | "source": [ 1590 | "! [ -f ic/figure-8.tsv ] || wget -P ic https://raw.githubusercontent.com/rndsrc/orbits-py/refs/heads/main/ic/figure-8.tsv" 1591 | ] 1592 | }, 1593 | { 1594 | "cell_type": "code", 1595 | "execution_count": null, 1596 | "id": "111", 1597 | "metadata": {}, 1598 | "outputs": [], 1599 | "source": [ 1600 | "# We then run the integrate() code\n", 1601 | "\n", 1602 | "T, R, V, A = integrate(\"ic/figure-8.tsv\", 2.5, 250)" 1603 | ] 1604 | }, 1605 | { 1606 | "cell_type": "code", 1607 | "execution_count": null, 1608 | "id": "112", 1609 | "metadata": {}, 1610 | "outputs": [], 1611 | "source": [ 1612 | "X = R[:,:,0]\n", 1613 | "Y = R[:,:,1]\n", 1614 | "\n", 1615 | "plt.plot(X, Y, '-')" 1616 | ] 1617 | }, 1618 | { 1619 | "cell_type": "code", 1620 | "execution_count": null, 1621 | "id": "113", 1622 | "metadata": {}, 1623 | "outputs": [], 1624 | "source": [ 1625 | "anim = animate(X, Y)\n", 1626 | "\n", 1627 | "HTML(anim.to_html5_video()) # display animation\n", 1628 | "# anim.save('orbitss.mp4') # save animation" 1629 | ] 1630 | }, 1631 | { 1632 | "cell_type": "markdown", 1633 | "id": "114", 1634 | "metadata": {}, 1635 | "source": [ 1636 | "## Conclusion and Discussion\n", 1637 | "\n", 1638 | "In this workshop, we've journeyed from a straightforward, pure Python implementation of the $n$-body simulation to a highly optimized version that achieves a 1710x speedup.\n", 1639 | "We started by cleaning up our code with list comprehensions 
and cutting out unnecessary operations, then moved on to swapping slow functions for faster ones and leveraging high-performance libraries like NumPy.\n", 1640 | "\n", 1641 | "The real game-changer came with Google `JAX`.\n", 1642 | "Its just-in-time compilation and GPU support pushed our code another order of magnitude in performance.\n", 1643 | "These optimizations show that Python, when used right, isn't just for quick-and-dirty prototypes.\n", 1644 | "It can also deliver serious speed and efficiency.\n", 1645 | "\n", 1646 | "This is especially important in hackathons and rapid prototyping environments.\n", 1647 | "You can quickly iterate on your ideas and still end up with code that's robust enough for real-world applications.\n", 1648 | "Python truly bridges the gap between ease of development and high performance, letting you have your cake and eat it too." 1649 | ] 1650 | } 1651 | ], 1652 | "metadata": { 1653 | "kernelspec": { 1654 | "display_name": "Python 3 (ipykernel)", 1655 | "language": "python", 1656 | "name": "python3" 1657 | }, 1658 | "language_info": { 1659 | "codemirror_mode": { 1660 | "name": "ipython", 1661 | "version": 3 1662 | }, 1663 | "file_extension": ".py", 1664 | "mimetype": "text/x-python", 1665 | "name": "python", 1666 | "nbconvert_exporter": "python", 1667 | "pygments_lexer": "ipython3", 1668 | "version": "3.13.2" 1669 | } 1670 | }, 1671 | "nbformat": 4, 1672 | "nbformat_minor": 5 1673 | } 1674 | -------------------------------------------------------------------------------- /handson.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "0", 6 | "metadata": {}, 7 | "source": [ 8 | "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/rndsrc/orbits-py/blob/main/handson.ipynb)" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "1", 14 | "metadata": {}, 15 | 
"source": [ 16 | "# Speeding Up Your Python Codes 1000x" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "id": "2", 22 | "metadata": {}, 23 | "source": [ 24 | "## Introduction\n", 25 | "\n", 26 | "Welcome to \"Speeding Up Your Python Codes 1000x\"!\n", 27 | "\n", 28 | "In the next 45 minutes, we will learn tricks to unlock python performance and increase a sample code's speed by a factor of 1000x!" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "id": "3", 34 | "metadata": {}, 35 | "source": [ 36 | "Here is our plan:\n", 37 | "* **Understand the problem:**\n", 38 | " we will start by outlining the $n$-body problem and how to use the leapfrog algorithm to solve it numerically.\n", 39 | "* **Benchmark and profile:**\n", 40 | " we will learn how to measure performance and identify bottlenecks using tools like timeit and cProfile.\n", 41 | "* **Optimize the code:**\n", 42 | " step through a series of optimizations:\n", 43 | " * Use list comprehensions and reduce operation counts.\n", 44 | " * Replace slow operations with faster ones.\n", 45 | " * Leverage high-performance libraries such as NumPy.\n", 46 | " * Explore lower precision where applicable.\n", 47 | " * Harness Google JAX for just-in-time compilation and GPU acceleration.\n", 48 | "* **Apply the integrator:** Finally, we will use our optimized code to simulate the mesmerizing figure-8 self-gravitating orbit." 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "id": "4", 54 | "metadata": {}, 55 | "source": [ 56 | "By the end of this workshop, you will see that Python isn't just great for rapid prototyping.\n", 57 | "It can also deliver serious performance when optimized right. Let’s dive in!" 
58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "id": "5", 63 | "metadata": {}, 64 | "source": [ 65 | "## The $n$-body Problem and the Leapfrog Algorithm\n", 66 | "\n", 67 | "The $n$-body problem involves simulating the motion of many bodies interacting with each other through (usually) gravitational forces.\n", 68 | "\n", 69 | "This problem is physically interesting, as it models the [solar system](https://rebound.readthedocs.io/en/latest/), [galaxy dynamics, or even the whole cosmic structure formation](https://wwwmpa.mpa-garching.mpg.de/gadget/)!\n", 70 | "\n", 71 | "This problem is also computationally interesting.\n", 72 | "Each body experiences a force from every other body, making the simulation computationally intensive.\n", 73 | "It is ideal for exploring optimization techniques.\n", 74 | "\n", 75 | "The leapfrog algorithm is a simple, robust numerical integrator that alternates between updating velocities (\"kicks\") and positions (\"drifts\").\n", 76 | "It is popular because it conserves energy and momentum well over long simulation periods." 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "id": "6", 82 | "metadata": {}, 83 | "source": [ 84 | "We need to solve Newton's second law of motion for a collection of $n$ bodies.\n", 85 | "The forces between the bodies are given by Newton's Law of Universal Gravitation." 
86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "id": "7", 91 | "metadata": {}, 92 | "source": [ 93 | "### Gravitational Force/Acceleration Equation\n", 94 | "\n", 95 | "For any two bodies labeled by $i$ and $j$, the direct gravitational force exerted on body $i$ by body $j$ is given by:\n", 96 | "\\begin{align}\n", 97 | " \\mathbf{f}_{ij} = -G m_i m_j\\frac{\\mathbf{r}_i - \\mathbf{r}_j\\ \\ \\ }{|\\mathbf{r}_i - \\mathbf{r}_j|^3},\n", 98 | "\\end{align}\n", 99 | "where:\n", 100 | "* $G$ is the gravitational constant.\n", 101 | "* $m_i$ and $m_j$ are the masses of the bodies.\n", 102 | "* $\\mathbf{r}_i$ and $\\mathbf{r}_j$ are their position vectors." 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "id": "8", 108 | "metadata": {}, 109 | "source": [ 110 | "Summing the contributions from all other bodies gives the net force (and thus the acceleration) on body $i$:\n", 111 | "\\begin{align}\n", 112 | " \\mathbf{f}_i\n", 113 | " = \\sum_{j\\ne i} \\mathbf{f}_{ij}\n", 114 | " = -G m_i \\sum_{j\\ne i} m_j \\frac{\\mathbf{r}_i - \\mathbf{r}_j\\ \\ \\ }{|\\mathbf{r}_i - \\mathbf{r}_j|^3}.\n", 115 | "\\end{align}" 116 | ] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "id": "9", 121 | "metadata": {}, 122 | "source": [ 123 | "Given Newton's law $f = m a$, the \"acceleration\" applied on body $i$ caused by all other bodies is:\n", 124 | "\\begin{align}\n", 125 | " \\mathbf{a}_i\n", 126 | " = \\sum_{j\\ne i} \\mathbf{a}_{ij}\n", 127 | " = -G \\sum_{j\\ne i} m_j \\frac{\\mathbf{r}_i - \\mathbf{r}_j\\ \\ \\ }{|\\mathbf{r}_i - \\mathbf{r}_j|^3}.\n", 128 | "\\end{align}" 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "id": "10", 134 | "metadata": {}, 135 | "source": [ 136 | "We choose units such that $G = 1$.\n", 137 | "Here is a pure Python implementation of the gravitational acceleration:" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "id": "11", 144 | "metadata": {}, 145 | "outputs": 
[], 146 | "source": [ 147 | "def acc1(m, r):\n", 148 | " \n", 149 | " n = len(m)\n", 150 | " a = []\n", 151 | " for i in range(n):\n", 152 | " axi, ayi, azi = 0, 0, 0\n", 153 | " for j in range(n):\n", 154 | " if j != i:\n", 155 | " xi, yi, zi = r[i]\n", 156 | " xj, yj, zj = r[j]\n", 157 | "\n", 158 | " axij = - m[j] * (xi - xj) / ((xi - xj)**2 + (yi - yj)**2 + (zi - zj)**2)**(3/2)\n", 159 | " ayij = - m[j] * (yi - yj) / ((xi - xj)**2 + (yi - yj)**2 + (zi - zj)**2)**(3/2)\n", 160 | " azij = - m[j] * (zi - zj) / ((xi - xj)**2 + (yi - yj)**2 + (zi - zj)**2)**(3/2)\n", 161 | "\n", 162 | " axi += axij\n", 163 | " ayi += ayij\n", 164 | " azi += azij\n", 165 | "\n", 166 | " a.append((axi, ayi, azi))\n", 167 | "\n", 168 | " return a" 169 | ] 170 | }, 171 | { 172 | "cell_type": "markdown", 173 | "id": "12", 174 | "metadata": {}, 175 | "source": [ 176 | "We can try this function by evaluating the gravitational acceleration between two particles, each with mass $m = 1$, at a distance of $2$.\n", 177 | "We expect the \"acceleration\" to be $-1/4$ in the direction of their separation." 
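The expected value follows directly from the acceleration formula with $G = 1$. As a worked check, assume the two unit masses sit at $\mathbf{r}_1 = (+1, 0, 0)$ and $\mathbf{r}_2 = (-1, 0, 0)$, which is one way to realize a separation of $2$:

```latex
\begin{align}
  \mathbf{a}_1
  = -m_2 \frac{\mathbf{r}_1 - \mathbf{r}_2}{|\mathbf{r}_1 - \mathbf{r}_2|^3}
  = -\frac{(2, 0, 0)}{2^3}
  = \left(-\frac{1}{4}, 0, 0\right),
  \qquad
  \mathbf{a}_2 = -\mathbf{a}_1 = \left(+\frac{1}{4}, 0, 0\right).
\end{align}
```

Each body is pulled toward the other with magnitude $1/4$, and the two accelerations are equal and opposite because the masses are equal.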
178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": null, 183 | "id": "13", 184 | "metadata": {}, 185 | "outputs": [], 186 | "source": [ 187 | "m = [1.0, 1.0]\n", 188 | "r0 = [\n", 189 | " (+1.0, 0.0, 0.0),\n", 190 | " (-1.0, 0.0, 0.0),\n", 191 | "]\n", 192 | "a0ref = [\n", 193 | " (-0.25, 0.0, 0.0),\n", 194 | " (+0.25, 0.0, 0.0),\n", 195 | "]" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "id": "14", 202 | "metadata": {}, 203 | "outputs": [], 204 | "source": [ 205 | "a0 = acc1(m, r0)" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": null, 211 | "id": "15", 212 | "metadata": {}, 213 | "outputs": [], 214 | "source": [ 215 | "a0" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": null, 221 | "id": "16", 222 | "metadata": {}, 223 | "outputs": [], 224 | "source": [ 225 | "assert a0 == a0ref" 226 | ] 227 | }, 228 | { 229 | "cell_type": "markdown", 230 | "id": "17", 231 | "metadata": {}, 232 | "source": [ 233 | "### Leapfrog Integrator\n", 234 | "\n", 235 | "Our simulation solves Newton's second law of motion, $\\mathbf{f} = m \\mathbf{a}$, numerically.\n", 236 | "The equation tells us that the acceleration of a body is determined by the net force acting on it.\n", 237 | "In a system where forces are conservative, this law guarantees the conservation of energy and momentum.\n", 238 | "These are critical properties when simulating long-term dynamics.\n", 239 | "\n", 240 | "However, many standard integration methods tend to drift over time, gradually losing these conservation properties.\n", 241 | "The leapfrog algorithm is a symplectic integrator designed to address this issue.\n", 242 | "Its staggered (or \"leapfrogging\") updates of velocity and position help preserve energy and momentum much more effectively over extended simulations.\n", 243 | "\n", 244 | "1. 
Half-step velocity update (Kick):\n", 245 | " start by updating the velocity by half a time step using the current acceleration $a(t)$:\n", 246 | " \\begin{align}\n", 247 | " v\\left(t+\\frac{1}{2}\\Delta t\\right) = v(t) + \\frac{1}{2}\\Delta t\\, a(t).\n", 248 | " \\end{align}\n", 249 | "\n", 250 | "2. Full-step position update (Drift):\n", 251 | " update the position for a full time step using the half-stepped velocity:\n", 252 | " \\begin{align}\n", 253 | " x(t+\\Delta t) = x(t) + \\Delta t\\, v\\left(t+\\frac{1}{2}\\Delta t\\right).\n", 254 | " \\end{align}\n", 255 | "\n", 256 | "3. Half-step velocity update (Kick):\n", 257 | " finally, update the velocity by another half time step using the new acceleration $a(t+\\Delta t)$ computed from the updated positions:\n", 258 | " \\begin{align}\n", 259 | " v(t+\\Delta t) = v\\left(t+\\frac{1}{2}\\Delta t\\right) + \\frac{1}{2}\\Delta t\\, a(t+\\Delta t).\n", 260 | " \\end{align}\n", 261 | "\n", 262 | "Here is a pure Python implementation of the leapfrog integrator:" 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": null, 268 | "id": "18", 269 | "metadata": {}, 270 | "outputs": [], 271 | "source": [ 272 | "def leapfrog1(m, r0, v0, dt, acc=acc1):\n", 273 | "\n", 274 | " n = len(m)\n", 275 | "\n", 276 | " # vh = v0 + 0.5 * dt * a0\n", 277 | " a0 = acc(m, r0)\n", 278 | " vh = []\n", 279 | " for i in range(n):\n", 280 | " a0xi, a0yi, a0zi = a0[i]\n", 281 | " v0xi, v0yi, v0zi = v0[i]\n", 282 | " vhxi = v0xi + 0.5 * dt * a0xi\n", 283 | " vhyi = v0yi + 0.5 * dt * a0yi\n", 284 | " vhzi = v0zi + 0.5 * dt * a0zi\n", 285 | " vh.append((vhxi, vhyi, vhzi))\n", 286 | "\n", 287 | " # r1 = r0 + dt * vh\n", 288 | " r1 = []\n", 289 | " for i in range(n):\n", 290 | " vhxi, vhyi, vhzi = vh[i]\n", 291 | " x0i, y0i, z0i = r0[i]\n", 292 | " x1i = x0i + dt * vhxi\n", 293 | " y1i = y0i + dt * vhyi\n", 294 | " z1i = z0i + dt * vhzi\n", 295 | " r1.append((x1i, y1i, z1i))\n", 296 | "\n", 297 | " # v1 = vh + 0.5 
* dt * a1\n", 298 | " a1 = acc(m, r1)\n", 299 | " v1 = []\n", 300 | " for i in range(n):\n", 301 | " a1xi, a1yi, a1zi = a1[i]\n", 302 | " vhxi, vhyi, vhzi = vh[i]\n", 303 | " v1xi = vhxi + 0.5 * dt * a1xi\n", 304 | " v1yi = vhyi + 0.5 * dt * a1yi\n", 305 | " v1zi = vhzi + 0.5 * dt * a1zi\n", 306 | " v1.append((v1xi, v1yi, v1zi))\n", 307 | "\n", 308 | " return r1, v1" 309 | ] 310 | }, 311 | { 312 | "cell_type": "markdown", 313 | "id": "19", 314 | "metadata": {}, 315 | "source": [ 316 | "We can now use this function to evolve the two particles under their mutual gravity, starting from some initial velocities." 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": null, 322 | "id": "20", 323 | "metadata": {}, 324 | "outputs": [], 325 | "source": [ 326 | "v0 = [\n", 327 | " (0.0, +0.5, 0.0),\n", 328 | " (0.0, -0.5, 0.0),\n", 329 | "]\n", 330 | "v1ref = [\n", 331 | " (-0.024984345739803245, +0.4993750014648409, 0.0),\n", 332 | " (+0.024984345739803245, -0.4993750014648409, 0.0),\n", 333 | "]" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": null, 339 | "id": "21", 340 | "metadata": {}, 341 | "outputs": [], 342 | "source": [ 343 | "r1, v1 = leapfrog1(m, r0, v0, 0.1)" 344 | ] 345 | }, 346 | { 347 | "cell_type": "code", 348 | "execution_count": null, 349 | "id": "22", 350 | "metadata": {}, 351 | "outputs": [], 352 | "source": [ 353 | "v1" 354 | ] 355 | }, 356 | { 357 | "cell_type": "code", 358 | "execution_count": null, 359 | "id": "23", 360 | "metadata": {}, 361 | "outputs": [], 362 | "source": [ 363 | "assert v1 == v1ref" 364 | ] 365 | }, 366 | { 367 | "cell_type": "markdown", 368 | "id": "24", 369 | "metadata": {}, 370 | "source": [ 371 | "### Test the Integrator\n", 372 | "\n", 373 | "We test the integrator by evolving the two particles over a time $T = 2\\pi$ using $N = 64$ steps." 
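It is useful to know what to expect from this test. For unit masses at $x = \pm 1$ moving at speed $0.5$, the centripetal condition $v^2/r = 1/4$ matches the gravitational acceleration, so each body moves on a unit circle with period $2\pi r/v = 4\pi$. A run of $T = 2\pi$ is therefore half an orbit, and the two bodies should swap places. The sketch below checks this with a compact two-dimensional, two-body leapfrog; `acc_pair` and `step` are hypothetical helpers written for this check and are not the notebook's functions.

```python
from math import pi, hypot

def acc_pair(p):
    """Accelerations of two unit-mass bodies at 2D positions p[0], p[1] (G = 1)."""
    dx = p[0][0] - p[1][0]
    dy = p[0][1] - p[1][1]
    rrr = hypot(dx, dy) ** 3
    return [(-dx / rrr, -dy / rrr), (dx / rrr, dy / rrr)]

def step(p, v, a, dt):
    """One kick-drift-kick leapfrog step, reusing the incoming acceleration."""
    vh = [(vx + 0.5 * dt * ax, vy + 0.5 * dt * ay)
          for (vx, vy), (ax, ay) in zip(v, a)]
    p1 = [(x + dt * vx, y + dt * vy)
          for (x, y), (vx, vy) in zip(p, vh)]
    a1 = acc_pair(p1)
    v1 = [(vx + 0.5 * dt * ax, vy + 0.5 * dt * ay)
          for (vx, vy), (ax, ay) in zip(vh, a1)]
    return p1, v1, a1

N, T = 64, 2 * pi
dt = T / N

p = [(+1.0, 0.0), (-1.0, 0.0)]
v = [(0.0, +0.5), (0.0, -0.5)]
a = acc_pair(p)
for _ in range(N):
    p, v, a = step(p, v, a, dt)

# After half an orbit the first body should be close to (-1, 0).
print(p[0], p[1])
```

Small deviations from $(-1, 0)$ are the integrator's truncation error, which shrinks as $N$ grows.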
374 | ] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "execution_count": null, 379 | "id": "25", 380 | "metadata": {}, 381 | "outputs": [], 382 | "source": [ 383 | "from math import pi\n", 384 | "\n", 385 | "N = 64\n", 386 | "T = 2 * pi" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": null, 392 | "id": "26", 393 | "metadata": {}, 394 | "outputs": [], 395 | "source": [ 396 | "dt = T / N\n", 397 | "R = [r0]\n", 398 | "V = [v0]\n", 399 | "for _ in range(N):\n", 400 | " r, v = leapfrog1(m, R[-1], V[-1], dt)\n", 401 | " R.append(r)\n", 402 | " V.append(v)" 403 | ] 404 | }, 405 | { 406 | "cell_type": "markdown", 407 | "id": "27", 408 | "metadata": {}, 409 | "source": [ 410 | "Although our numerical scheme is implemented in pure Python, it is still handy to use `numpy` to slice through the data..." 411 | ] 412 | }, 413 | { 414 | "cell_type": "code", 415 | "execution_count": null, 416 | "id": "28", 417 | "metadata": {}, 418 | "outputs": [], 419 | "source": [ 420 | "import numpy as np\n", 421 | "\n", 422 | "R = np.array(R)\n", 423 | "X = R[:,:,0]\n", 424 | "Y = R[:,:,1]" 425 | ] 426 | }, 427 | { 428 | "cell_type": "markdown", 429 | "id": "29", 430 | "metadata": {}, 431 | "source": [ 432 | "and plot the result..." 433 | ] 434 | }, 435 | { 436 | "cell_type": "code", 437 | "execution_count": null, 438 | "id": "30", 439 | "metadata": {}, 440 | "outputs": [], 441 | "source": [ 442 | "from matplotlib import pyplot as plt\n", 443 | "\n", 444 | "plt.plot(X, Y, '.-')\n", 445 | "plt.gca().set_aspect('equal')" 446 | ] 447 | }, 448 | { 449 | "cell_type": "markdown", 450 | "id": "31", 451 | "metadata": {}, 452 | "source": [ 453 | "In addition, we may create an animation."
454 | ] 455 | }, 456 | { 457 | "cell_type": "code", 458 | "execution_count": null, 459 | "id": "32", 460 | "metadata": {}, 461 | "outputs": [], 462 | "source": [ 463 | "from matplotlib.animation import ArtistAnimation\n", 464 | "from IPython.display import HTML\n", 465 | "from tqdm import tqdm\n", 466 | "\n", 467 | "def animate(X, Y, ntail=10):\n", 468 | " fig, ax = plt.subplots(1, 1, figsize=(5,5))\n", 469 | " ax.set_xlabel('x')\n", 470 | " ax.set_ylabel('y')\n", 471 | "\n", 472 | " frames = []\n", 473 | " for i in tqdm(range(len(X))):\n", 474 | " b,e = max(0, i-ntail), i+1\n", 475 | " ax.set_prop_cycle(None)\n", 476 | " f = ax.plot(X[b:e,:], Y[b:e,:], '.-', animated=True)\n", 477 | " frames.append(f)\n", 478 | "\n", 479 | " anim = ArtistAnimation(fig, frames, interval=50)\n", 480 | " plt.close()\n", 481 | " \n", 482 | " return anim" 483 | ] 484 | }, 485 | { 486 | "cell_type": "code", 487 | "execution_count": null, 488 | "id": "33", 489 | "metadata": {}, 490 | "outputs": [], 491 | "source": [ 492 | "anim = animate(X, Y)\n", 493 | "\n", 494 | "HTML(anim.to_html5_video()) # display animation\n", 495 | "# anim.save('orbitss.mp4') # save animation" 496 | ] 497 | }, 498 | { 499 | "cell_type": "markdown", 500 | "id": "34", 501 | "metadata": {}, 502 | "source": [ 503 | "## Benchmarking and Profiling Your Code\n", 504 | "\n", 505 | "Before diving into optimizations, it is essential to understand where our code spends most of its time.\n", 506 | "By benchmarking and profiling, we can pinpoint performance bottlenecks and measure improvements after optimization." 507 | ] 508 | }, 509 | { 510 | "cell_type": "markdown", 511 | "id": "35", 512 | "metadata": {}, 513 | "source": [ 514 | "### Benchmarking vs. 
Profiling\n", 515 | "\n", 516 | "* Benchmarking:\n", 517 | " Measures the overall runtime of your code.\n", 518 | " Tools like Python's `timeit` module run your code multiple times to provide an accurate average runtime.\n", 519 | " This helps in comparing the performance before and after optimizations.\n", 520 | "\n", 521 | "* Profiling:\n", 522 | " Provides detailed insights into which parts of your code are consuming the most time.\n", 523 | " For example, `cProfile` generates reports showing function call times and frequencies.\n", 524 | " Note that `cProfile` typically runs the code once, so its focus is on identifying hotspots rather than providing averaged timings." 525 | ] 526 | }, 527 | { 528 | "cell_type": "markdown", 529 | "id": "36", 530 | "metadata": {}, 531 | "source": [ 532 | "### Quick Benchmark Example\n", 533 | "\n", 534 | "`timeit` is a module designed for benchmarking by executing code multiple times.\n", 535 | "It is excellent for obtaining reliable runtime measurements." 536 | ] 537 | }, 538 | { 539 | "cell_type": "code", 540 | "execution_count": null, 541 | "id": "37", 542 | "metadata": {}, 543 | "outputs": [], 544 | "source": [ 545 | "n = 1000\n", 546 | "m = np.random.lognormal(size=n).tolist()\n", 547 | "r0 = np.random.normal(size=(n, 3)).tolist()\n", 548 | "v0 = np.random.normal(size=(n, 3)).tolist()" 549 | ] 550 | }, 551 | { 552 | "cell_type": "code", 553 | "execution_count": null, 554 | "id": "38", 555 | "metadata": {}, 556 | "outputs": [], 557 | "source": [ 558 | "%timeit r1, v1 = leapfrog1(m, r0, v0, dt)" 559 | ] 560 | }, 561 | { 562 | "cell_type": "markdown", 563 | "id": "39", 564 | "metadata": {}, 565 | "source": [ 566 | "It takes about 1.31 seconds on my laptop to run a single step of leapfrog for $n = 1000$ bodies.\n", 567 | "Your mileage may vary.\n", 568 | "But this is pretty slow by today's standards."
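The `%timeit` magic only exists inside IPython and Jupyter. In a plain script, a similar measurement can be sketched with the standard `timeit` module. Here `work` is a hypothetical stand-in for an expensive call such as one leapfrog step, and the `repeat` and `number` values are arbitrary choices.

```python
import timeit

def work(n=200):
    """Hypothetical stand-in for an expensive call such as a leapfrog step."""
    return sum(i * i for i in range(n))

number = 1000  # calls per timing
repeat = 5     # independent timings

# Each entry is the total time in seconds for `number` calls
times = timeit.repeat(work, repeat=repeat, number=number)

# The minimum is usually the least noisy estimate of the per-call cost
best = min(times) / number
print(f"best of {repeat}: {best * 1e6:.2f} microseconds per call")
```

Reporting the best of several repeats, rather than a single run, reduces the influence of other processes running on the machine.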
569 | ] 570 | }, 571 | { 572 | "cell_type": "markdown", 573 | "id": "40", 574 | "metadata": {}, 575 | "source": [ 576 | "### Quick Profiling Example\n", 577 | "\n", 578 | "Here is a snippet using cProfile to profile our leapfrog stepper:" 579 | ] 580 | }, 581 | { 582 | "cell_type": "code", 583 | "execution_count": null, 584 | "id": "41", 585 | "metadata": {}, 586 | "outputs": [], 587 | "source": [ 588 | "import cProfile\n", 589 | "\n", 590 | "cProfile.run(\"r1, v1 = leapfrog1(m, r0, v0, dt)\")" 591 | ] 592 | }, 593 | { 594 | "cell_type": "markdown", 595 | "id": "42", 596 | "metadata": {}, 597 | "source": [ 598 | "From cProfile's result, the most frequently called function is `method 'append' of 'list' objects`.\n", 599 | "This is not surprising given that we have been using for-loops with `append`, e.g.,\n", 600 | "```\n", 601 | " v1 = []\n", 602 | " for i in range(n):\n", 603 | " ...\n", 604 | " v1.append(...)\n", 605 | "```\n", 606 | "in both `acc1()` and `leapfrog1()`.\n", 607 | "This is neither pythonic nor efficient." 608 | ] 609 | }, 610 | { 611 | "cell_type": "markdown", 612 | "id": "43", 613 | "metadata": {}, 614 | "source": [ 615 | "## Optimization 1: Use List Comprehension Over For Loop\n", 616 | "\n", 617 | "Python's list comprehensions are a concise way to create lists by iterating over an iterable and applying an expression---all in a single, compact expression.\n", 618 | "In CPython, they compile to specialized bytecode that avoids the repeated `append` method lookup and call, making them generally faster than an equivalent Python for-loop."
619 | ] 620 | }, 621 | { 622 | "cell_type": "code", 623 | "execution_count": null, 624 | "id": "44", 625 | "metadata": {}, 626 | "outputs": [], 627 | "source": [ 628 | "iterable = range(10)\n", 629 | "\n", 630 | "# Standard python for-loop\n", 631 | "l1 = []\n", 632 | "for x in iterable:\n", 633 | " l1.append(x * 2)\n", 634 | "\n", 635 | "# Using a list comprehension\n", 636 | "l2 = [x * 2 for x in iterable]\n", 637 | "\n", 638 | "# Compare results\n", 639 | "print(l1)\n", 640 | "print(l2)" 641 | ] 642 | }, 643 | { 644 | "cell_type": "markdown", 645 | "id": "45", 646 | "metadata": {}, 647 | "source": [ 648 | "We may use list comprehension to rewrite our leapfrog algorithm:" 649 | ] 650 | }, 651 | { 652 | "cell_type": "code", 653 | "execution_count": null, 654 | "id": "46", 655 | "metadata": {}, 656 | "outputs": [], 657 | "source": [ 658 | "# O1. Use List Comprehension Over For Loop\n", 659 | "\n", 660 | "def leapfrog2(m, r0, v0, dt, acc=acc1):\n", 661 | "\n", 662 | " n = len(m)\n", 663 | "\n", 664 | " # vh = v0 + 0.5 * dt * a0\n", 665 | " a0 = acc(m, r0)\n", 666 | " vh = [[v0[i][k] + 0.5 * dt * a0[i][k]\n", 667 | " for k in range(3)]\n", 668 | " for i in range(n)]\n", 669 | "\n", 670 | " # r1 = r0 + dt * vh\n", 671 | " # TODO: use list comprehension to rewrite the remaining of our leapfrog algorithm\n", 672 | "\n", 673 | " # v1 = vh + 0.5 * dt * a1\n", 674 | " # TODO: use list comprehension to rewrite the remaining of our leapfrog algorithm\n", 675 | "\n", 676 | " # TODO: return ..." 677 | ] 678 | }, 679 | { 680 | "cell_type": "code", 681 | "execution_count": null, 682 | "id": "47", 683 | "metadata": { 684 | "jupyter": { 685 | "source_hidden": true 686 | } 687 | }, 688 | "outputs": [], 689 | "source": [ 690 | "# O1. 
Use List Comprehension Over For Loop\n", 691 | "\n", 692 | "def leapfrog2(m, r0, v0, dt, acc=acc1):\n", 693 | "\n", 694 | " n = len(m)\n", 695 | "\n", 696 | " # vh = v0 + 0.5 * dt * a0\n", 697 | " a0 = acc(m, r0)\n", 698 | " vh = [[v0xi + 0.5 * dt * a0xi\n", 699 | " for v0xi, a0xi in zip(v0i, a0i)]\n", 700 | " for v0i, a0i in zip(v0, a0 )]\n", 701 | "\n", 702 | " # r1 = r0 + dt * vh\n", 703 | " r1 = [[x0i + dt * vhxi\n", 704 | " for x0i, vhxi in zip(r0i, vhi)]\n", 705 | " for r0i, vhi in zip(r0, vh )]\n", 706 | "\n", 707 | " # v1 = vh + 0.5 * dt * a1\n", 708 | " a1 = acc(m, r1)\n", 709 | " v1 = [[vhxi + 0.5 * dt * a1xi\n", 710 | " for vhxi, a1xi in zip(vhi, a1i)]\n", 711 | " for vhi, a1i, in zip(vh, a1 )]\n", 712 | "\n", 713 | " return r1, v1" 714 | ] 715 | }, 716 | { 717 | "cell_type": "code", 718 | "execution_count": null, 719 | "id": "48", 720 | "metadata": {}, 721 | "outputs": [], 722 | "source": [ 723 | "%timeit r1, v1 = leapfrog2(m, r0, v0, dt)" 724 | ] 725 | }, 726 | { 727 | "cell_type": "code", 728 | "execution_count": null, 729 | "id": "49", 730 | "metadata": {}, 731 | "outputs": [], 732 | "source": [ 733 | "cProfile.run(\"r1, v1 = leapfrog2(m, r0, v0, dt)\")" 734 | ] 735 | }, 736 | { 737 | "cell_type": "markdown", 738 | "id": "50", 739 | "metadata": {}, 740 | "source": [ 741 | "Depending on your Python version, you may see a slight performance increase from `timeit` and a decrease in the number of calls to `append`.\n", 742 | "It takes about 1.29 seconds on my laptop to run a single step of leapfrog for $n = 1000$ bodies."
743 | ] 744 | }, 745 | { 746 | "cell_type": "markdown", 747 | "id": "51", 748 | "metadata": {}, 749 | "source": [ 750 | "## Optimization 2: Reduce Operation Count\n", 751 | "\n", 752 | "Reducing the operation count means cutting down on unnecessary calculations and redundant function calls.\n", 753 | "When operations are executed millions of times---as in inner loops or simulations---even small optimizations can yield significant speed improvements.\n", 754 | "\n", 755 | "For example, precomputing constant values or combining multiple arithmetic steps into one reduces repetitive work.\n", 756 | "In the context of the $n$-body problem, calculate invariant quantities once outside of critical loops instead of recalculating them every time.\n", 757 | "This streamlined approach not only speeds up your code but also helps further optimizations like vectorization and JIT compilation." 758 | ] 759 | }, 760 | { 761 | "cell_type": "markdown", 762 | "id": "52", 763 | "metadata": {}, 764 | "source": [ 765 | "Recall the benchmark using `leapfrog2()` with `acc1()`." 766 | ] 767 | }, 768 | { 769 | "cell_type": "code", 770 | "execution_count": null, 771 | "id": "53", 772 | "metadata": {}, 773 | "outputs": [], 774 | "source": [ 775 | "%timeit r1, v1 = leapfrog2(m, r0, v0, dt, acc=acc1)" 776 | ] 777 | }, 778 | { 779 | "cell_type": "markdown", 780 | "id": "54", 781 | "metadata": {}, 782 | "source": [ 783 | "We notice that $|\\mathbf{r}_{ij}|^3 = |\\mathbf{r}_i - \\mathbf{r}_j|^3$ is used in all components of $\\mathbf{a}_{ij}$.\n", 784 | "Instead of recomputing it, we can simply cache it in a variable `rrr`." 785 | ] 786 | }, 787 | { 788 | "cell_type": "code", 789 | "execution_count": null, 790 | "id": "55", 791 | "metadata": {}, 792 | "outputs": [], 793 | "source": [ 794 | "# O2a.
\"Cache\" r^3\n", 795 | "\n", 796 | "def acc2(m, r):\n", 797 | " \n", 798 | " n = len(m)\n", 799 | " a = []\n", 800 | " for i in range(n):\n", 801 | " axi, ayi, azi = 0, 0, 0\n", 802 | " for j in range(n):\n", 803 | " if j != i:\n", 804 | " xi, yi, zi = r[i]\n", 805 | " xj, yj, zj = r[j]\n", 806 | "\n", 807 | " # TODO: \"Cache\" r^3 below\n", 808 | " axij = - m[j] * (xi - xj) / ((xi - xj)**2 + (yi - yj)**2 + (zi - zj)**2)**(3/2)\n", 809 | " ayij = - m[j] * (yi - yj) / ((xi - xj)**2 + (yi - yj)**2 + (zi - zj)**2)**(3/2)\n", 810 | " azij = - m[j] * (zi - zj) / ((xi - xj)**2 + (yi - yj)**2 + (zi - zj)**2)**(3/2)\n", 811 | " \n", 812 | " axi += axij\n", 813 | " ayi += ayij\n", 814 | " azi += azij\n", 815 | "\n", 816 | " a.append((axi, ayi, azi))\n", 817 | "\n", 818 | " return a" 819 | ] 820 | }, 821 | { 822 | "cell_type": "code", 823 | "execution_count": null, 824 | "id": "56", 825 | "metadata": { 826 | "jupyter": { 827 | "source_hidden": true 828 | } 829 | }, 830 | "outputs": [], 831 | "source": [ 832 | "# O2a. 
\"Cache\" r^3\n", 833 | "\n", 834 | "def acc2(m, r):\n", 835 | " \n", 836 | " n = len(m)\n", 837 | " a = []\n", 838 | " for i in range(n):\n", 839 | " axi, ayi, azi = 0, 0, 0\n", 840 | " for j in range(n):\n", 841 | " if j != i:\n", 842 | " xi, yi, zi = r[i]\n", 843 | " xj, yj, zj = r[j]\n", 844 | "\n", 845 | " # \"Cache\" r^3\n", 846 | " rrr = ((xi - xj)**2 + (yi - yj)**2 + (zi - zj)**2)**(3/2)\n", 847 | " \n", 848 | " axij = - m[j] * (xi - xj) / rrr\n", 849 | " ayij = - m[j] * (yi - yj) / rrr\n", 850 | " azij = - m[j] * (zi - zj) / rrr\n", 851 | "\n", 852 | " axi += axij\n", 853 | " ayi += ayij\n", 854 | " azi += azij\n", 855 | "\n", 856 | " a.append((axi, ayi, azi))\n", 857 | "\n", 858 | " return a" 859 | ] 860 | }, 861 | { 862 | "cell_type": "code", 863 | "execution_count": null, 864 | "id": "57", 865 | "metadata": {}, 866 | "outputs": [], 867 | "source": [ 868 | "%timeit r1, v1 = leapfrog2(m, r0, v0, dt, acc=acc2)" 869 | ] 870 | }, 871 | { 872 | "cell_type": "markdown", 873 | "id": "58", 874 | "metadata": {}, 875 | "source": [ 876 | "This reduces the benchmark time by about 45% already!" 877 | ] 878 | }, 879 | { 880 | "cell_type": "markdown", 881 | "id": "59", 882 | "metadata": {}, 883 | "source": [ 884 | "But we don't have to stop here.\n", 885 | "The different components of $\\mathbf{r}_{ij} = \\mathbf{r}_i - \\mathbf{r}_j$ can be cached as `dx`, `dy`, `dz`, too." 886 | ] 887 | }, 888 | { 889 | "cell_type": "code", 890 | "execution_count": null, 891 | "id": "60", 892 | "metadata": {}, 893 | "outputs": [], 894 | "source": [ 895 | "# O2b.
\"Cache\" the components of dr = r_ij = ri - rj\n", 896 | "\n", 897 | "def acc3(m, r):\n", 898 | " \n", 899 | " n = len(m)\n", 900 | " a = []\n", 901 | " for i in range(n):\n", 902 | " axi, ayi, azi = 0, 0, 0\n", 903 | " for j in range(n):\n", 904 | " if j != i:\n", 905 | " xi, yi, zi = r[i]\n", 906 | " xj, yj, zj = r[j]\n", 907 | "\n", 908 | " # TODO: \"Cache\" the components of dr = r_ij = ri - rj\n", 909 | " rrr = ((xi - xj)**2 + (yi - yj)**2 + (zi - zj)**2)**(3/2)\n", 910 | " \n", 911 | " axij = - m[j] * (xi - xj) / rrr\n", 912 | " ayij = - m[j] * (yi - yj) / rrr\n", 913 | " azij = - m[j] * (zi - zj) / rrr\n", 914 | "\n", 915 | " axi += axij\n", 916 | " ayi += ayij\n", 917 | " azi += azij\n", 918 | "\n", 919 | " a.append((axi, ayi, azi))\n", 920 | "\n", 921 | " return a" 922 | ] 923 | }, 924 | { 925 | "cell_type": "code", 926 | "execution_count": null, 927 | "id": "61", 928 | "metadata": { 929 | "jupyter": { 930 | "source_hidden": true 931 | } 932 | }, 933 | "outputs": [], 934 | "source": [ 935 | "# O2b. 
\"Cache\" the components of dr = r_ij = ri - rj\n", 936 | "\n", 937 | "def acc3(m, r):\n", 938 | " \n", 939 | " n = len(m)\n", 940 | " a = []\n", 941 | " for i in range(n):\n", 942 | " axi, ayi, azi = 0, 0, 0\n", 943 | " for j in range(n):\n", 944 | " if j != i:\n", 945 | " xi, yi, zi = r[i]\n", 946 | " xj, yj, zj = r[j]\n", 947 | "\n", 948 | " # \"Cache\" the components of dr = r_ij = ri - rj\n", 949 | " dx = xi - xj\n", 950 | " dy = yi - yj\n", 951 | " dz = zi - zj\n", 952 | "\n", 953 | " rrr = (dx**2 + dy**2 + dz**2)**(3/2)\n", 954 | " \n", 955 | " axij = - m[j] * dx / rrr\n", 956 | " ayij = - m[j] * dy / rrr\n", 957 | " azij = - m[j] * dz / rrr\n", 958 | "\n", 959 | " axi += axij\n", 960 | " ayi += ayij\n", 961 | " azi += azij\n", 962 | "\n", 963 | " a.append((axi, ayi, azi))\n", 964 | "\n", 965 | " return a" 966 | ] 967 | }, 968 | { 969 | "cell_type": "code", 970 | "execution_count": null, 971 | "id": "62", 972 | "metadata": {}, 973 | "outputs": [], 974 | "source": [ 975 | "%timeit r1, v1 = leapfrog2(m, r0, v0, dt, acc=acc3)" 976 | ] 977 | }, 978 | { 979 | "cell_type": "markdown", 980 | "id": "63", 981 | "metadata": {}, 982 | "source": [ 983 | "Similarly, we can cache $-m_j / |\\mathbf{r}_{ij}|^3$." 984 | ] 985 | }, 986 | { 987 | "cell_type": "code", 988 | "execution_count": null, 989 | "id": "64", 990 | "metadata": {}, 991 | "outputs": [], 992 | "source": [ 993 | "# O2c. 
\"cache\" -m_j / r_{ij}^3\n", 994 | "\n", 995 | "def acc4(m, r):\n", 996 | " \n", 997 | " n = len(m)\n", 998 | " a = []\n", 999 | " for i in range(n):\n", 1000 | " axi, ayi, azi = 0, 0, 0\n", 1001 | " for j in range(n):\n", 1002 | " if j != i:\n", 1003 | " xi, yi, zi = r[i]\n", 1004 | " xj, yj, zj = r[j]\n", 1005 | "\n", 1006 | " dx = xi - xj\n", 1007 | " dy = yi - yj\n", 1008 | " dz = zi - zj\n", 1009 | "\n", 1010 | " rrr = (dx**2 + dy**2 + dz**2)**(3/2)\n", 1011 | "\n", 1012 | " # TODO: \"cache\" -m_j / r_{ij}^3\n", 1013 | " axij = - m[j] * dx / rrr\n", 1014 | " ayij = - m[j] * dy / rrr\n", 1015 | " azij = - m[j] * dz / rrr\n", 1016 | "\n", 1017 | " axi += axij\n", 1018 | " ayi += ayij\n", 1019 | " azi += azij\n", 1020 | "\n", 1021 | " a.append((axi, ayi, azi))\n", 1022 | "\n", 1023 | " return a" 1024 | ] 1025 | }, 1026 | { 1027 | "cell_type": "code", 1028 | "execution_count": null, 1029 | "id": "65", 1030 | "metadata": { 1031 | "jupyter": { 1032 | "source_hidden": true 1033 | } 1034 | }, 1035 | "outputs": [], 1036 | "source": [ 1037 | "# O2c. 
\"cache\" -m_j / r_{ij}^3\n", 1038 | "\n", 1039 | "def acc4(m, r):\n", 1040 | " \n", 1041 | " n = len(m)\n", 1042 | " a = []\n", 1043 | " for i in range(n):\n", 1044 | " axi, ayi, azi = 0, 0, 0\n", 1045 | " for j in range(n):\n", 1046 | " if j != i:\n", 1047 | " xi, yi, zi = r[i]\n", 1048 | " xj, yj, zj = r[j]\n", 1049 | "\n", 1050 | " dx = xi - xj\n", 1051 | " dy = yi - yj\n", 1052 | " dz = zi - zj\n", 1053 | "\n", 1054 | " rrr = (dx**2 + dy**2 + dz**2)**(3/2)\n", 1055 | " f = - m[j] / rrr # \"cache\" -m_j / r_{ij}^3\n", 1056 | "\n", 1057 | " axi += f * dx\n", 1058 | " ayi += f * dy\n", 1059 | " azi += f * dz\n", 1060 | "\n", 1061 | " a.append((axi, ayi, azi))\n", 1062 | "\n", 1063 | " return a" 1064 | ] 1065 | }, 1066 | { 1067 | "cell_type": "code", 1068 | "execution_count": null, 1069 | "id": "66", 1070 | "metadata": {}, 1071 | "outputs": [], 1072 | "source": [ 1073 | "%timeit r1, v1 = leapfrog2(m, r0, v0, dt, acc=acc4)" 1074 | ] 1075 | }, 1076 | { 1077 | "cell_type": "markdown", 1078 | "id": "67", 1079 | "metadata": {}, 1080 | "source": [ 1081 | "Finally, we notice a symmetry between $\\mathbf{a}_{ij}$ and $\\mathbf{a}_{ji}$: both share the factor $\\mathbf{r}_{ij} / |\\mathbf{r}_{ij}|^3$, so that $m_i \\mathbf{a}_{ij} = -m_j \\mathbf{a}_{ji}$.\n", 1082 | "In principle, we only need to evaluate this factor for $j < i$.\n", 1083 | "However, this requires that we pre-allocate a list-of-lists.\n", 1084 | "Just creating a list-of-lists and using it to keep track of the acceleration actually increases the benchmark time." 1085 | ] 1086 | }, 1087 | { 1088 | "cell_type": "code", 1089 | "execution_count": null, 1090 | "id": "68", 1091 | "metadata": {}, 1092 | "outputs": [], 1093 | "source": [ 1094 | "# O2d.
Use list-of-list to keep track of the acceleration\n", 1095 | "\n", 1096 | "def acc5(m, r):\n", 1097 | " \n", 1098 | " n = len(m)\n", 1099 | " a = [] # TODO: create a list-of-list\n", 1100 | " for i in range(n):\n", 1101 | " axi, ayi, azi = 0, 0, 0\n", 1102 | " for j in range(n):\n", 1103 | " if j != i:\n", 1104 | " xi, yi, zi = r[i]\n", 1105 | " xj, yj, zj = r[j]\n", 1106 | "\n", 1107 | " dx = xi - xj\n", 1108 | " dy = yi - yj\n", 1109 | " dz = zi - zj\n", 1110 | "\n", 1111 | " rrr = (dx**2 + dy**2 + dz**2)**(3/2)\n", 1112 | " f = - m[j] / rrr\n", 1113 | "\n", 1114 | " # TODO: Use the list-of-list to keep track of the acceleration\n", 1115 | " axi += f * dx\n", 1116 | " ayi += f * dy\n", 1117 | " azi += f * dz\n", 1118 | "\n", 1119 | " a.append((axi, ayi, azi))\n", 1120 | "\n", 1121 | " return a" 1122 | ] 1123 | }, 1124 | { 1125 | "cell_type": "code", 1126 | "execution_count": null, 1127 | "id": "69", 1128 | "metadata": { 1129 | "jupyter": { 1130 | "source_hidden": true 1131 | } 1132 | }, 1133 | "outputs": [], 1134 | "source": [ 1135 | "# O2d. 
Use list-of-list to keep track of the acceleration\n", 1136 | "\n", 1137 | "def acc5(m, r):\n", 1138 | " \n", 1139 | " n = len(m)\n", 1140 | " a = [[0.0]*3 for _ in range(n)] # create a list-of-list with independent rows; [[0]*3]*n would alias one row n times\n", 1141 | " for i in range(n):\n", 1142 | " for j in range(n):\n", 1143 | " if j != i:\n", 1144 | " xi, yi, zi = r[i]\n", 1145 | " xj, yj, zj = r[j]\n", 1146 | "\n", 1147 | " dx = xi - xj\n", 1148 | " dy = yi - yj\n", 1149 | " dz = zi - zj\n", 1150 | "\n", 1151 | " rrr = (dx**2 + dy**2 + dz**2)**(3/2)\n", 1152 | " f = - m[j] / rrr\n", 1153 | "\n", 1154 | " # Use the list-of-list to keep track of the acceleration\n", 1155 | " a[i][0] += f * dx\n", 1156 | " a[i][1] += f * dy\n", 1157 | " a[i][2] += f * dz\n", 1158 | "\n", 1159 | " return a" 1160 | ] 1161 | }, 1162 | { 1163 | "cell_type": "code", 1164 | "execution_count": null, 1165 | "id": "70", 1166 | "metadata": {}, 1167 | "outputs": [], 1168 | "source": [ 1169 | "%timeit r1, v1 = leapfrog2(m, r0, v0, dt, acc=acc5)" 1170 | ] 1171 | }, 1172 | { 1173 | "cell_type": "markdown", 1174 | "id": "71", 1175 | "metadata": {}, 1176 | "source": [ 1177 | "But once we have the list-of-list, we may change the upper bound of the inner loop to cut the computation to almost half.\n", 1178 | "This is the fastest code so far." 1179 | ] 1180 | }, 1181 | { 1182 | "cell_type": "code", 1183 | "execution_count": null, 1184 | "id": "72", 1185 | "metadata": {}, 1186 | "outputs": [], 1187 | "source": [ 1188 | "# O2e.
Take advantage of the symmetry of a_ij\n", 1189 | "\n", 1190 | "def acc6(m, r):\n", 1191 | " \n", 1192 | " n = len(m)\n", 1193 | " a = [[0.0]*3 for _ in range(n)] # independent rows\n", 1194 | " for i in range(n):\n", 1195 | " for j in range(n): # TODO: adjust the upper bound of the inner loop\n", 1196 | " if j != i:\n", 1197 | " xi, yi, zi = r[i]\n", 1198 | " xj, yj, zj = r[j]\n", 1199 | "\n", 1200 | " dx = xi - xj\n", 1201 | " dy = yi - yj\n", 1202 | " dz = zi - zj\n", 1203 | "\n", 1204 | " rrr = (dx**2 + dy**2 + dz**2)**(3/2)\n", 1205 | " f = - m[j] / rrr\n", 1206 | "\n", 1207 | " a[i][0] += f * dx\n", 1208 | " a[i][1] += f * dy\n", 1209 | " a[i][2] += f * dz\n", 1210 | "\n", 1211 | " # TODO: Add the corresponding acceleration to the j-th body.\n", 1212 | "\n", 1213 | " return a" 1214 | ] 1215 | }, 1216 | { 1217 | "cell_type": "code", 1218 | "execution_count": null, 1219 | "id": "73", 1220 | "metadata": { 1221 | "jupyter": { 1222 | "source_hidden": true 1223 | } 1224 | }, 1225 | "outputs": [], 1226 | "source": [ 1227 | "# O2e. Take advantage of the symmetry of a_ij\n", 1228 | "\n", 1229 | "def acc6(m, r):\n", 1230 | " \n", 1231 | " n = len(m)\n", 1232 | " a = [[0.0]*3 for _ in range(n)] # independent rows\n", 1233 | " for i in range(n):\n", 1234 | " for j in range(i): # adjust the upper bound of the inner loop\n", 1235 | " xi, yi, zi = r[i]\n", 1236 | " xj, yj, zj = r[j]\n", 1237 | "\n", 1238 | " dx = xi - xj\n", 1239 | " dy = yi - yj\n", 1240 | " dz = zi - zj\n", 1241 | "\n", 1242 | " rrr = (dx**2 + dy**2 + dz**2)**(3/2)\n", 1243 | " fi = m[i] / rrr\n", 1244 | " fj = - m[j] / rrr\n", 1245 | " \n", 1246 | " a[i][0] += fj * dx\n", 1247 | " a[i][1] += fj * dy\n", 1248 | " a[i][2] += fj * dz\n", 1249 | "\n", 1250 | " # Add the corresponding acceleration to the j-th body.\n", 1251 | " a[j][0] += fi * dx\n", 1252 | " a[j][1] += fi * dy\n", 1253 | " a[j][2] += fi * dz\n", 1254 | "\n", 1255 | " return a" 1256 | ] 1257 | }, 1258 | { 1259 | "cell_type": "code", 1260 | "execution_count": null, 1261 | "id": "74", 1262 | "metadata": {}, 1263 | "outputs": [], 1264 |
"source": [ 1265 | "%timeit r1, v1 = leapfrog2(m, r0, v0, dt, acc=acc6)" 1266 | ] 1267 | }, 1268 | { 1269 | "cell_type": "markdown", 1270 | "id": "75", 1271 | "metadata": {}, 1272 | "source": [ 1273 | "However, `acc?()` is not the only function we can optimize to reduce operation count.\n", 1274 | "If we study `leapfrog?()` carefully, the acceleration `a` in the second \"kick\" calculation can actually be reused by the first \"kick\" of the next step.\n", 1275 | "This requires modifying the function prototype a bit." 1276 | ] 1277 | }, 1278 | { 1279 | "cell_type": "code", 1280 | "execution_count": null, 1281 | "id": "76", 1282 | "metadata": {}, 1283 | "outputs": [], 1284 | "source": [ 1285 | "# O2f. Reuse computation\n", 1286 | "\n", 1287 | "def leapfrog3(m, r0, v0, a0, dt, acc=acc6):\n", 1288 | "\n", 1289 | " n = len(m)\n", 1290 | "\n", 1291 | " # vh = v0 + 0.5 * dt * a0\n", 1292 | " # a0 = acc(m, r0) <--- we comment this out, and reuse the acceleration computed in the second \"kick\" of the previous step\n", 1293 | " vh = [[v0[i][k] + 0.5 * dt * a0[i][k]\n", 1294 | " for k in range(3)]\n", 1295 | " for i in range(n)]\n", 1296 | "\n", 1297 | " # r1 = r0 + dt * vh\n", 1298 | " r1 = [[r0[i][k] + dt * vh[i][k]\n", 1299 | " for k in range(3)]\n", 1300 | " for i in range(n)]\n", 1301 | "\n", 1302 | " # v1 = vh + 0.5 * dt * a1\n", 1303 | " a1 = acc(m, r1)\n", 1304 | " v1 = [[vh[i][k] + 0.5 * dt * a1[i][k]\n", 1305 | " for k in range(3)]\n", 1306 | " for i in range(n)]\n", 1307 | "\n", 1308 | " return r1, v1, a1 # <--- to reuse the acceleration computed in the second \"kick\", let's return it." 
1309 | ] 1310 | }, 1311 | { 1312 | "cell_type": "code", 1313 | "execution_count": null, 1314 | "id": "77", 1315 | "metadata": {}, 1316 | "outputs": [], 1317 | "source": [ 1318 | "# Precompute `a0`\n", 1319 | "a0 = acc4(m, r0) " 1320 | ] 1321 | }, 1322 | { 1323 | "cell_type": "code", 1324 | "execution_count": null, 1325 | "id": "78", 1326 | "metadata": {}, 1327 | "outputs": [], 1328 | "source": [ 1329 | "%timeit r1, v1, a1 = leapfrog3(m, r0, v0, a0, dt, acc=acc4)" 1330 | ] 1331 | }, 1332 | { 1333 | "cell_type": "code", 1334 | "execution_count": null, 1335 | "id": "79", 1336 | "metadata": {}, 1337 | "outputs": [], 1338 | "source": [ 1339 | "%timeit r1, v1, a1 = leapfrog3(m, r0, v0, a0, dt, acc=acc6)" 1340 | ] 1341 | }, 1342 | { 1343 | "cell_type": "markdown", 1344 | "id": "80", 1345 | "metadata": {}, 1346 | "source": [ 1347 | "Remarkable, it takes about 241ms on my laptop to run a single step of leapfrog for $n = 1000$ bodies.\n", 1348 | "This is almost a 92% reduction in benchmark time!!!" 1349 | ] 1350 | }, 1351 | { 1352 | "cell_type": "markdown", 1353 | "id": "81", 1354 | "metadata": {}, 1355 | "source": [ 1356 | "## Optimizing 3: Use Fast Operations\n", 1357 | "\n", 1358 | "Certain operations in Python can be surprisingly slow.\n", 1359 | "For example, the power operator (`**`) often calls C's pow() function, which for non-integer or variable exponents is typically implemented as:\n", 1360 | "\\begin{align}\n", 1361 | " x^y = \\exp[y \\cdot \\ln(x)]\n", 1362 | "\\end{align}\n", 1363 | "This method involves calculating a logarithm, a multiplication, and an exponential---operations that are much slower than simple multiplication." 
1364 | ] 1365 | }, 1366 | { 1367 | "cell_type": "markdown", 1368 | "id": "82", 1369 | "metadata": {}, 1370 | "source": [ 1371 | "For instance, if you need to square a number, it's much faster to write `x*x` rather than `x**2`.\n", 1372 | "When the exponent is a known small integer, manually multiplying the base is usually the quickest route.\n", 1373 | "\n", 1374 | "By choosing fast operations over more generic ones, you can shave off microseconds in code that runs millions of times, ultimately contributing to some performance boost." 1375 | ] 1376 | }, 1377 | { 1378 | "cell_type": "code", 1379 | "execution_count": null, 1380 | "id": "83", 1381 | "metadata": {}, 1382 | "outputs": [], 1383 | "source": [ 1384 | "# O3. Use fast operations\n", 1385 | "\n", 1386 | "def acc7(m, r): # this is the same as acc4()\n", 1387 | " \n", 1388 | " n = len(m)\n", 1389 | " a = []\n", 1390 | " for i in range(n):\n", 1391 | " axi, ayi, azi = 0, 0, 0\n", 1392 | " for j in range(n):\n", 1393 | " if j != i:\n", 1394 | " xi, yi, zi = r[i]\n", 1395 | " xj, yj, zj = r[j]\n", 1396 | "\n", 1397 | " dx = xi - xj\n", 1398 | " dy = yi - yj\n", 1399 | " dz = zi - zj\n", 1400 | "\n", 1401 | " rrr = (dx**2 + dy**2 + dz**2)**(3/2) # TODO: replace dx**2 by dx*dx etc\n", 1402 | " f = - m[j] / rrr\n", 1403 | "\n", 1404 | " axi += f * dx\n", 1405 | " ayi += f * dy\n", 1406 | " azi += f * dz\n", 1407 | "\n", 1408 | " a.append((axi, ayi, azi))\n", 1409 | "\n", 1410 | " return a" 1411 | ] 1412 | }, 1413 | { 1414 | "cell_type": "code", 1415 | "execution_count": null, 1416 | "id": "84", 1417 | "metadata": { 1418 | "jupyter": { 1419 | "source_hidden": true 1420 | } 1421 | }, 1422 | "outputs": [], 1423 | "source": [ 1424 | "# O3. 
Use fast operations\n", 1425 | "\n", 1426 | "def acc7(m, r): # this is a modification of acc4()\n", 1427 | " \n", 1428 | " n = len(m)\n", 1429 | " a = []\n", 1430 | " for i in range(n):\n", 1431 | " axi, ayi, azi = 0, 0, 0\n", 1432 | " for j in range(n):\n", 1433 | " if j != i:\n", 1434 | " xi, yi, zi = r[i]\n", 1435 | " xj, yj, zj = r[j]\n", 1436 | "\n", 1437 | " dx = xi - xj\n", 1438 | " dy = yi - yj\n", 1439 | " dz = zi - zj\n", 1440 | "\n", 1441 | " rrr = (dx*dx + dy*dy + dz*dz)**(3/2) # replace dx**2 by dx*dx etc\n", 1442 | " f = - m[j] / rrr\n", 1443 | "\n", 1444 | " axi += f * dx\n", 1445 | " ayi += f * dy\n", 1446 | " azi += f * dz\n", 1447 | "\n", 1448 | " a.append((axi, ayi, azi))\n", 1449 | "\n", 1450 | " return a" 1451 | ] 1452 | }, 1453 | { 1454 | "cell_type": "code", 1455 | "execution_count": null, 1456 | "id": "85", 1457 | "metadata": {}, 1458 | "outputs": [], 1459 | "source": [ 1460 | "%timeit r1, v1, a1 = leapfrog3(m, r0, v0, a0, dt, acc=acc7)" 1461 | ] 1462 | }, 1463 | { 1464 | "cell_type": "code", 1465 | "execution_count": null, 1466 | "id": "86", 1467 | "metadata": {}, 1468 | "outputs": [], 1469 | "source": [ 1470 | "# O2f. 
Reuse computation\n", 1471 | "\n", 1472 | "def leapfrog3(m, r0, v0, a0, dt, acc=acc6):\n", 1473 | "\n", 1474 | " n = len(m)\n", 1475 | "\n", 1476 | " # vh = v0 + 0.5 * dt * a0\n", 1477 | " # a0 = acc(m, r0) <--- we comment this out, and reuse the acceleration computed in the second \"kick\" of the previous step\n", 1478 | " vh = [[v0xi + 0.5 * dt * a0xi\n", 1479 | " for v0xi, a0xi in zip(v0i, a0i)]\n", 1480 | " for v0i, a0i in zip(v0, a0 )]\n", 1481 | "\n", 1482 | " # r1 = r0 + dt * vh\n", 1483 | " r1 = [[x0i + dt * vhxi\n", 1484 | " for x0i, vhxi in zip(r0i, vhi)]\n", 1485 | " for r0i, vhi in zip(r0, vh )]\n", 1486 | "\n", 1487 | " # v1 = vh + 0.5 * dt * a1\n", 1488 | " a1 = acc(m, r1)\n", 1489 | " v1 = [[vhxi + 0.5 * dt * a1xi\n", 1490 | " for vhxi, a1xi in zip(vhi, a1i)]\n", 1491 | " for vhi, a1i, in zip(vh, a1 )]\n", 1492 | "\n", 1493 | " return r1, v1, a1 # <--- to reuse the acceleration computed in the second \"kick\", let's return it." 1494 | ] 1495 | }, 1496 | { 1497 | "cell_type": "code", 1498 | "execution_count": null, 1499 | "id": "87", 1500 | "metadata": {}, 1501 | "outputs": [], 1502 | "source": [ 1503 | "# O3. 
Use fast operations\n", 1504 | "\n", 1505 | "def acc8(m, r): # this is the same as acc6()\n", 1506 | " \n", 1507 | " n = len(m)\n", 1508 | " a = [[0.0]*3 for _ in range(n)] # independent rows\n", 1509 | " for i in range(n):\n", 1510 | " for j in range(i):\n", 1511 | " xi, yi, zi = r[i]\n", 1512 | " xj, yj, zj = r[j]\n", 1513 | "\n", 1514 | " dx = xi - xj\n", 1515 | " dy = yi - yj\n", 1516 | " dz = zi - zj\n", 1517 | "\n", 1518 | " rrr = (dx**2 + dy**2 + dz**2)**(3/2) # TODO: replace dx**2 by dx*dx etc\n", 1519 | " fi = m[i] / rrr\n", 1520 | " fj = - m[j] / rrr\n", 1521 | " \n", 1522 | " a[i][0] += fj * dx\n", 1523 | " a[i][1] += fj * dy\n", 1524 | " a[i][2] += fj * dz\n", 1525 | "\n", 1526 | " a[j][0] += fi * dx\n", 1527 | " a[j][1] += fi * dy\n", 1528 | " a[j][2] += fi * dz\n", 1529 | "\n", 1530 | " return a" 1531 | ] 1532 | }, 1533 | { 1534 | "cell_type": "code", 1535 | "execution_count": null, 1536 | "id": "88", 1537 | "metadata": { 1538 | "jupyter": { 1539 | "source_hidden": true 1540 | } 1541 | }, 1542 | "outputs": [], 1543 | "source": [ 1544 | "# O3.
Use fast operations\n", 1545 | "\n", 1546 | "def acc8(m, r): # this is a modification of acc6()\n", 1547 | " \n", 1548 | " n = len(m)\n", 1549 | " a = [[0]*3 for _ in range(n)] # one independent row per body; [[0]*3]*n would alias a single row\n", 1550 | " for i in range(n):\n", 1551 | " for j in range(i):\n", 1552 | " xi, yi, zi = r[i]\n", 1553 | " xj, yj, zj = r[j]\n", 1554 | "\n", 1555 | " dx = xi - xj\n", 1556 | " dy = yi - yj\n", 1557 | " dz = zi - zj\n", 1558 | "\n", 1559 | " rrr = (dx*dx + dy*dy + dz*dz)**(3/2) # replace dx**2 by dx*dx etc\n", 1560 | " fi = m[i] / rrr\n", 1561 | " fj = - m[j] / rrr\n", 1562 | " \n", 1563 | " a[i][0] += fj * dx\n", 1564 | " a[i][1] += fj * dy\n", 1565 | " a[i][2] += fj * dz\n", 1566 | "\n", 1567 | " a[j][0] += fi * dx\n", 1568 | " a[j][1] += fi * dy\n", 1569 | " a[j][2] += fi * dz\n", 1570 | "\n", 1571 | " return a" 1572 | ] 1573 | }, 1574 | { 1575 | "cell_type": "code", 1576 | "execution_count": null, 1577 | "id": "89", 1578 | "metadata": {}, 1579 | "outputs": [], 1580 | "source": [ 1581 | "%timeit r1, v1, a1 = leapfrog3(m, r0, v0, a0, dt, acc=acc8)" 1582 | ] 1583 | }, 1584 | { 1585 | "cell_type": "markdown", 1586 | "id": "90", 1587 | "metadata": {}, 1588 | "source": [ 1589 | "Surprisingly, although `acc7()` does not take advantage of the symmetry of $\\mathbf{a}_{ij}$, it is now faster than `acc6()`.\n", 1590 | "On the other hand, `acc8()` is still slightly faster than `acc7()`." 1591 | ] 1592 | }, 1593 | { 1594 | "cell_type": "markdown", 1595 | "id": "91", 1596 | "metadata": {}, 1597 | "source": [ 1598 | "Without external libraries, we've cut benchmark time by nearly 93%---a 14x speedup over our original $n$-body implementation.\n", 1599 | "While impressive even compared to compiled languages, this is just the beginning.\n", 1600 | "By leveraging high-performance libraries, we can overcome Python’s inherent slowness and move closer to our goal of a 1000x speedup." 
1601 | ] 1602 | }, 1603 | { 1604 | "cell_type": "markdown", 1605 | "id": "92", 1606 | "metadata": {}, 1607 | "source": [ 1608 | "## Optimization 4: Use High Performance Libraries\n", 1609 | "\n", 1610 | "Leveraging high-performance libraries like NumPy is key to accelerating Python code.\n", 1611 | "NumPy's vectorized operations, implemented in C, allow you to perform complex computations on large arrays far more efficiently than native Python loops.\n", 1612 | "This means you can offload heavy calculations to optimized, low-level routines and achieve significant speedups." 1613 | ] 1614 | }, 1615 | { 1616 | "cell_type": "code", 1617 | "execution_count": null, 1618 | "id": "93", 1619 | "metadata": {}, 1620 | "outputs": [], 1621 | "source": [ 1622 | "# O4. Use NumPy\n", 1623 | "\n", 1624 | "def acc9(m, r): # this is the same as acc7()\n", 1625 | " \n", 1626 | " n = len(m)\n", 1627 | " \n", 1628 | " # TODO: rewrite the following code using numpy\n", 1629 | " a = []\n", 1630 | " for i in range(n):\n", 1631 | " axi, ayi, azi = 0, 0, 0\n", 1632 | " for j in range(n):\n", 1633 | " if j != i:\n", 1634 | " xi, yi, zi = r[i]\n", 1635 | " xj, yj, zj = r[j]\n", 1636 | "\n", 1637 | " dx = xi - xj\n", 1638 | " dy = yi - yj\n", 1639 | " dz = zi - zj\n", 1640 | "\n", 1641 | " rrr = (dx*dx + dy*dy + dz*dz)**(3/2)\n", 1642 | " f = - m[j] / rrr\n", 1643 | "\n", 1644 | " axi += f * dx\n", 1645 | " ayi += f * dy\n", 1646 | " azi += f * dz\n", 1647 | "\n", 1648 | " a.append((axi, ayi, azi))\n", 1649 | "\n", 1650 | " return a" 1651 | ] 1652 | }, 1653 | { 1654 | "cell_type": "code", 1655 | "execution_count": null, 1656 | "id": "94", 1657 | "metadata": { 1658 | "jupyter": { 1659 | "source_hidden": true 1660 | } 1661 | }, 1662 | "outputs": [], 1663 | "source": [ 1664 | "# O4. 
Use NumPy\n", 1665 | "\n", 1666 | "def acc9(m, r): # this is a modification of acc7(); the pairwise loop in acc8() is hard to vectorize with numpy\n", 1667 | "\n", 1668 | " # idx: i j i j\n", 1669 | " # v v v v\n", 1670 | " dr = r[:,None,:] - r[None,:,:]\n", 1671 | " rr = np.sum(dr * dr, axis=-1) # sum over vector components\n", 1672 | "\n", 1673 | " # Ensure rr is non-zero (on the i == j diagonal dr is 0, so those terms add nothing to the sum)\n", 1674 | " rr = np.maximum(rr, 1e-24)\n", 1675 | "\n", 1676 | " # idx: i j\n", 1677 | " # v v\n", 1678 | " f = -m[None,:] / rr**(3/2)\n", 1679 | " a = np.sum(f[:,:,None] * dr, axis=1) # sum over j\n", 1680 | " \n", 1681 | " return a" 1682 | ] 1683 | }, 1684 | { 1685 | "cell_type": "markdown", 1686 | "id": "95", 1687 | "metadata": {}, 1688 | "source": [ 1689 | "We may do the same for leapfrog:" 1690 | ] 1691 | }, 1692 | { 1693 | "cell_type": "code", 1694 | "execution_count": null, 1695 | "id": "96", 1696 | "metadata": {}, 1697 | "outputs": [], 1698 | "source": [ 1699 | "# O4. 
Use NumPy\n", 1700 | "\n", 1701 | "def leapfrog4(m, r0, v0, a0, dt, acc=acc6):\n", 1702 | "\n", 1703 | " # TODO: rewrite the following code using numpy\n", 1704 | " \n", 1705 | " n = len(m)\n", 1706 | "\n", 1707 | " # vh = v0 + 0.5 * dt * a0\n", 1708 | " vh = [[v0xi + 0.5 * dt * a0xi\n", 1709 | " for v0xi, a0xi in zip(v0i, a0i)]\n", 1710 | " for v0i, a0i in zip(v0, a0 )]\n", 1711 | "\n", 1712 | " # r1 = r0 + dt * vh\n", 1713 | " r1 = [[x0i + dt * vhxi\n", 1714 | " for x0i, vhxi in zip(r0i, vhi)]\n", 1715 | " for r0i, vhi in zip(r0, vh )]\n", 1716 | "\n", 1717 | " # v1 = vh + 0.5 * dt * a1\n", 1718 | " a1 = acc(m, r1)\n", 1719 | " v1 = [[vhxi + 0.5 * dt * a1xi\n", 1720 | " for vhxi, a1xi in zip(vhi, a1i)]\n", 1721 | " for vhi, a1i, in zip(vh, a1 )]\n", 1722 | "\n", 1723 | " return r1, v1, a1" 1724 | ] 1725 | }, 1726 | { 1727 | "cell_type": "code", 1728 | "execution_count": null, 1729 | "id": "97", 1730 | "metadata": { 1731 | "jupyter": { 1732 | "source_hidden": true 1733 | } 1734 | }, 1735 | "outputs": [], 1736 | "source": [ 1737 | "# 
O4. Use NumPy\n", 1738 | "\n", 1739 | "def leapfrog4(m, r0, v0, a0, dt, acc=acc9):\n", 1740 | " ht = 0.5 * dt\n", 1741 | "\n", 1742 | " vh = v0 + ht * a0\n", 1743 | " r1 = r0 + dt * vh\n", 1744 | " \n", 1745 | " a1 = acc(m, r1)\n", 1746 | " v1 = vh + ht * a1\n", 1747 | "\n", 1748 | " return r1, v1, a1" 1749 | ] 1750 | }, 1751 | { 1752 | "cell_type": "code", 1753 | "execution_count": null, 1754 | "id": "98", 1755 | "metadata": {}, 1756 | "outputs": [], 1757 | "source": [ 1758 | "m = np.array(m)\n", 1759 | "r0 = np.array(r0)\n", 1760 | "v0 = np.array(v0)\n", 1761 | "a0 = np.array(a0)" 1762 | ] 1763 | }, 1764 | { 1765 | "cell_type": "code", 1766 | "execution_count": null, 1767 | "id": "99", 1768 | "metadata": {}, 1769 | "outputs": [], 1770 | "source": [ 1771 | "%timeit r1, v1, a1 = leapfrog4(m, r0, v0, a0, dt, acc=acc9)" 1772 | ] 1773 | }, 1774 | { 1775 | "cell_type": "code", 1776 | "execution_count": null, 1777 | "id": "100", 1778 | "metadata": {}, 1779 | "outputs": [], 1780 | "source": [ 1781 | "cProfile.run(\"r1, v1, a1 = leapfrog4(m, r0, v0, a0, dt, acc=acc9)\")" 1782 | ] 1783 | }, 1784 | { 1785 | "cell_type": "markdown", 1786 | "id": "101", 1787 | "metadata": {}, 1788 | "source": [ 1789 | "This runs at 38.3ms on my laptop.\n", 1790 | "This cuts the benchmark time by 99% (which is no longer a good indicator) and reaches an 82x speedup!" 1791 | ] 1792 | }, 1793 | { 1794 | "cell_type": "markdown", 1795 | "id": "102", 1796 | "metadata": {}, 1797 | "source": [ 1798 | "## Optimization 4: Use Lower Precision\n", 1799 | "\n", 1800 | "Using lower precision arithmetic can reduce memory usage and potentially speed up calculations.\n", 1801 | "However, the benefits depend heavily on your hardware.\n", 1802 | "For instance, on Intel x86 platforms, single precision may actually be slower than double precision due to conversion overhead and how the hardware is optimized.\n", 1803 | "Always test and profile on your target system before switching to lower precision." 
1804 | ] 1805 | }, 1806 | { 1807 | "cell_type": "code", 1808 | "execution_count": null, 1809 | "id": "103", 1810 | "metadata": {}, 1811 | "outputs": [], 1812 | "source": [ 1813 | "# O4. Use Lower Precision\n", 1814 | "\n", 1815 | "# TODO: cast the arrays `m`, `r0`, `v0`, and `a0` to single precision." 1816 | ] 1817 | }, 1818 | { 1819 | "cell_type": "code", 1820 | "execution_count": null, 1821 | "id": "104", 1822 | "metadata": { 1823 | "jupyter": { 1824 | "source_hidden": true 1825 | } 1826 | }, 1827 | "outputs": [], 1828 | "source": [ 1829 | "m = np.array(m, dtype=np.single)\n", 1830 | "r0 = np.array(r0, dtype=np.single)\n", 1831 | "v0 = np.array(v0, dtype=np.single)\n", 1832 | "a0 = np.array(a0, dtype=np.single)" 1833 | ] 1834 | }, 1835 | { 1836 | "cell_type": "code", 1837 | "execution_count": null, 1838 | "id": "105", 1839 | "metadata": {}, 1840 | "outputs": [], 1841 | "source": [ 1842 | "%timeit r1, v1, a1 = leapfrog4(m, r0, v0, a0, dt, acc=acc9)" 1843 | ] 1844 | }, 1845 | { 1846 | "cell_type": "markdown", 1847 | "id": "106", 1848 | "metadata": {}, 1849 | "source": [ 1850 | "## Optimization 5: Google `JAX`\n", 1851 | "\n", 1852 | "Google `JAX` is a high-performance library that extends NumPy with automatic differentiation and just-in-time (JIT) compilation.\n", 1853 | "It transforms Python functions into highly optimized machine code using XLA, which can dramatically accelerate numerical computations.\n", 1854 | "\n", 1855 | "One of `JAX`'s standout features is its seamless support for GPUs.\n", 1856 | "Depending on your hardware, running your code on a GPU with `JAX` can lead to even greater speedups compared to CPU execution.\n", 1857 | "Additionally, `JAX`'s vectorization tools, like `vmap`, let you apply functions over arrays efficiently without explicit Python loops.\n", 1858 | "\n", 1859 | "By replacing NumPy with `JAX`, we can harness the power of hardware acceleration and achieve the ambitious 1000x speedup!" 
1860 | ] 1861 | }, 1862 | { 1863 | "cell_type": "code", 1864 | "execution_count": null, 1865 | "id": "107", 1866 | "metadata": {}, 1867 | "outputs": [], 1868 | "source": [ 1869 | "from jax import numpy as jnp" 1870 | ] 1871 | }, 1872 | { 1873 | "cell_type": "code", 1874 | "execution_count": null, 1875 | "id": "108", 1876 | "metadata": {}, 1877 | "outputs": [], 1878 | "source": [ 1879 | "# O5. Use JAX\n", 1880 | "\n", 1881 | "def acc9(m, r):\n", 1882 | " # TODO: replace `np` by `jnp`.\n", 1883 | " dr = r[:,None,:] - r[None,:,:]\n", 1884 | " rr = np.sum(dr * dr, axis=-1)\n", 1885 | " rr = np.maximum(rr, 1e-24)\n", 1886 | " f = -m[None,:] / rr**(3/2)\n", 1887 | " a = np.sum(f[:,:,None] * dr, axis=1)\n", 1888 | " return a" 1889 | ] 1890 | }, 1891 | { 1892 | "cell_type": "code", 1893 | "execution_count": null, 1894 | "id": "109", 1895 | "metadata": { 1896 | "jupyter": { 1897 | "source_hidden": true 1898 | } 1899 | }, 1900 | "outputs": [], 1901 | "source": [ 1902 | "# O5. Use JAX\n", 1903 | "\n", 1904 | "def acc10(m, r):\n", 1905 | " dr = r[:,None,:] - r[None,:,:]\n", 1906 | " rr = jnp.sum(dr * dr, axis=-1)\n", 1907 | " rr = jnp.maximum(rr, 1e-24)\n", 1908 | " f = -m[None,:] / rr**(3/2)\n", 1909 | " a = jnp.sum(f[:,:,None] * dr, axis=1)\n", 1910 | " return a" 1911 | ] 1912 | }, 1913 | { 1914 | "cell_type": "code", 1915 | "execution_count": null, 1916 | "id": "110", 1917 | "metadata": {}, 1918 | "outputs": [], 1919 | "source": [ 1920 | "m = jnp.array(m)\n", 1921 | "r0 = jnp.array(r0)\n", 1922 | "v0 = jnp.array(v0)\n", 1923 | "a0 = jnp.array(a0)" 1924 | ] 1925 | }, 1926 | { 1927 | "cell_type": "code", 1928 | "execution_count": null, 1929 | "id": "111", 1930 | "metadata": {}, 1931 | "outputs": [], 1932 | "source": [ 1933 | "%timeit r1, v1, a1 = leapfrog4(m, r0, v0, a0, dt, acc=acc10)" 1934 | ] 1935 | }, 1936 | { 1937 | "cell_type": "code", 1938 | "execution_count": null, 1939 | "id": "112", 1940 | "metadata": {}, 1941 | "outputs": [], 1942 | "source": [ 1943 | 
"cProfile.run(\"r1, v1, a1 = leapfrog4(m, r0, v0, a0, dt, acc=acc10)\")" 1944 | ] 1945 | }, 1946 | { 1947 | "cell_type": "markdown", 1948 | "id": "113", 1949 | "metadata": {}, 1950 | "source": [ 1951 | "The timeit results on my laptop show that the slowest run took ~11 times longer than the fastest, with an average execution time of 3ms.\n", 1952 | "This variability is likely due to the initialization of the JAX library and potential data transfer overhead.\n", 1953 | "Nevertheless, these results are astonishing---achieving a 1023x speedup over our original implementation!" 1954 | ] 1955 | }, 1956 | { 1957 | "cell_type": "markdown", 1958 | "id": "114", 1959 | "metadata": {}, 1960 | "source": [ 1961 | "## Optimization 6: Just-in-Time Compilation\n", 1962 | "\n", 1963 | "Just-in-Time (JIT) compilation converts Python code into optimized machine code at runtime.\n", 1964 | "This means that hot functions can be compiled on the fly, eliminating Python's overhead and allowing them to run at speeds much closer to those of native C code.\n", 1965 | "\n", 1966 | "`JAX` offers powerful JIT capabilities through its `jax.jit` decorator.\n", 1967 | "When applied, it compiles your numerical functions into efficient, low-level code, which can be further accelerated on GPUs.\n", 1968 | "This results in dramatic performance improvements, exceeding our goal of a 1000x speedup." 1969 | ] 1970 | }, 1971 | { 1972 | "cell_type": "code", 1973 | "execution_count": null, 1974 | "id": "115", 1975 | "metadata": {}, 1976 | "outputs": [], 1977 | "source": [ 1978 | "from jax import jit" 1979 | ] 1980 | }, 1981 | { 1982 | "cell_type": "code", 1983 | "execution_count": null, 1984 | "id": "116", 1985 | "metadata": {}, 1986 | "outputs": [], 1987 | "source": [ 1988 | "# O6. 
JIT the leapfrog stepper\n", 1989 | "\n", 1990 | "# TODO: Use the `@jit` decorator to compile your code" 1991 | ] 1992 | }, 1993 | { 1994 | "cell_type": "code", 1995 | "execution_count": null, 1996 | "id": "117", 1997 | "metadata": { 1998 | "jupyter": { 1999 | "source_hidden": true 2000 | } 2001 | }, 2002 | "outputs": [], 2003 | "source": [ 2004 | "# O6. JIT the leapfrog stepper\n", 2005 | "\n", 2006 | "@jit\n", 2007 | "def leapfrog5(m, r0, v0, a0, dt):\n", 2008 | " return leapfrog4(m, r0, v0, a0, dt, acc=acc10)" 2009 | ] 2010 | }, 2011 | { 2012 | "cell_type": "code", 2013 | "execution_count": null, 2014 | "id": "118", 2015 | "metadata": {}, 2016 | "outputs": [], 2017 | "source": [ 2018 | "%timeit r1, v1, a1 = leapfrog5(m, r0, v0, a0, dt)" 2019 | ] 2020 | }, 2021 | { 2022 | "cell_type": "code", 2023 | "execution_count": null, 2024 | "id": "119", 2025 | "metadata": {}, 2026 | "outputs": [], 2027 | "source": [ 2028 | "cProfile.run(\"r1, v1, a1 = leapfrog5(m, r0, v0, a0, dt)\")" 2029 | ] 2030 | }, 2031 | { 2032 | "cell_type": "markdown", 2033 | "id": "120", 2034 | "metadata": {}, 2035 | "source": [ 2036 | "Our final results are truly remarkable.\n", 2037 | "By applying `JAX`'s JIT compilation, we compiled away all overhead (see the output of `cProfile`) and achieved a 1710x speedup over the original implementation.\n", 2038 | "This demonstrates how powerful the optimization techniques we have introduced are for improving Python's performance." 
2039 | ] 2040 | }, 2041 | { 2042 | "cell_type": "markdown", 2043 | "id": "121", 2044 | "metadata": {}, 2045 | "source": [ 2046 | "## Applying the $n$-Body Integrator\n", 2047 | "\n", 2048 | "With our optimized $n$-body integrator in hand, we can now explore its applications.\n", 2049 | "This efficient Python code lets us simulate various dynamical systems---from celestial mechanics to particle interactions---with different initial conditions.\n", 2050 | "By reading in data files that specify masses, positions, and velocities, we can easily tailor simulations to real-world scenarios." 2051 | ] 2052 | }, 2053 | { 2054 | "cell_type": "code", 2055 | "execution_count": null, 2056 | "id": "122", 2057 | "metadata": {}, 2058 | "outputs": [], 2059 | "source": [ 2060 | "def integrate(fname, T, N):\n", 2061 | "\n", 2062 | " # Load data\n", 2063 | " ic = np.genfromtxt(fname)\n", 2064 | "\n", 2065 | " # Setup initial conditions\n", 2066 | " m = jnp.array(ic[:,0])\n", 2067 | " R = [jnp.array(ic[:,1:4])]\n", 2068 | " V = [jnp.array(ic[:,4:7])]\n", 2069 | " A = [acc10(m, R[-1])]\n", 2070 | "\n", 2071 | " # Main loop\n", 2072 | " T, dt = jnp.linspace(0, T, N+1, retstep=True)\n", 2073 | " for _ in tqdm(range(N)):\n", 2074 | " r, v, a = leapfrog4(m, R[-1], V[-1], A[-1], dt, acc=acc10)\n", 2075 | " R.append(r)\n", 2076 | " V.append(v)\n", 2077 | " A.append(a)\n", 2078 | "\n", 2079 | " # Return results\n", 2080 | " return T, jnp.array(R), jnp.array(V), jnp.array(A)" 2081 | ] 2082 | }, 2083 | { 2084 | "cell_type": "markdown", 2085 | "id": "123", 2086 | "metadata": {}, 2087 | "source": [ 2088 | "The figure-8 orbit is a fascinating solution to the three-body problem where three equal-mass bodies follow a single, intertwined figure-8 path.\n", 2089 | "Each body moves along the same curve, perfectly choreographed so that they never collide, yet their gravitational interactions keep them in a stable, periodic dance." 
2090 | ] 2091 | }, 2092 | { 2093 | "cell_type": "markdown", 2094 | "id": "124", 2095 | "metadata": {}, 2096 | "source": [ 2097 | "We first download the initial condition from GitHub if the file doesn't exist:" 2098 | ] 2099 | }, 2100 | { 2101 | "cell_type": "code", 2102 | "execution_count": null, 2103 | "id": "125", 2104 | "metadata": {}, 2105 | "outputs": [], 2106 | "source": [ 2107 | "! [ -f ic/figure-8.tsv ] || wget -P ic https://raw.githubusercontent.com/rndsrc/orbits-py/refs/heads/main/ic/figure-8.tsv" 2108 | ] 2109 | }, 2110 | { 2111 | "cell_type": "code", 2112 | "execution_count": null, 2113 | "id": "126", 2114 | "metadata": {}, 2115 | "outputs": [], 2116 | "source": [ 2117 | "# We then run the integrate() code\n", 2118 | "\n", 2119 | "T, R, V, A = integrate(\"ic/figure-8.tsv\", 2.5, 250)" 2120 | ] 2121 | }, 2122 | { 2123 | "cell_type": "code", 2124 | "execution_count": null, 2125 | "id": "127", 2126 | "metadata": {}, 2127 | "outputs": [], 2128 | "source": [ 2129 | "X = R[:,:,0]\n", 2130 | "Y = R[:,:,1]\n", 2131 | "\n", 2132 | "plt.plot(X, Y, '-')" 2133 | ] 2134 | }, 2135 | { 2136 | "cell_type": "code", 2137 | "execution_count": null, 2138 | "id": "128", 2139 | "metadata": {}, 2140 | "outputs": [], 2141 | "source": [ 2142 | "anim = animate(X, Y)\n", 2143 | "\n", 2144 | "HTML(anim.to_html5_video()) # display animation\n", 2145 | "# anim.save('orbits.mp4') # save animation" 2146 | ] 2147 | }, 2148 | { 2149 | "cell_type": "markdown", 2150 | "id": "129", 2151 | "metadata": {}, 2152 | "source": [ 2153 | "## Conclusion and Discussion\n", 2154 | "\n", 2155 | "In this workshop, we've journeyed from a straightforward, pure Python implementation of the $n$-body simulation to a highly optimized version that achieves a 1710x speedup.\n", 2156 | "We started by cleaning up our code with list comprehensions and cutting out unnecessary operations, then moved on to swapping slow functions for faster ones and leveraging high-performance libraries like NumPy.\n", 2157 | "\n", 
2158 | "The real game-changer came with Google `JAX`.\n", 2159 | "Its just-in-time compilation and GPU support pushed our code another order of magnitude in performance.\n", 2160 | "These optimizations show that Python, when used right, isn't just for quick-and-dirty prototypes.\n", 2161 | "It can also deliver serious speed and efficiency.\n", 2162 | "\n", 2163 | "This is especially important in hackathons and rapid prototyping environments.\n", 2164 | "You can quickly iterate on your ideas and still end up with code that's robust enough for real-world applications.\n", 2165 | "Python truly bridges the gap between ease of development and high performance, letting you have your cake and eat it too." 2166 | ] 2167 | } 2168 | ], 2169 | "metadata": { 2170 | "kernelspec": { 2171 | "display_name": "Python 3 (ipykernel)", 2172 | "language": "python", 2173 | "name": "python3" 2174 | }, 2175 | "language_info": { 2176 | "codemirror_mode": { 2177 | "name": "ipython", 2178 | "version": 3 2179 | }, 2180 | "file_extension": ".py", 2181 | "mimetype": "text/x-python", 2182 | "name": "python", 2183 | "nbconvert_exporter": "python", 2184 | "pygments_lexer": "ipython3", 2185 | "version": "3.13.2" 2186 | } 2187 | }, 2188 | "nbformat": 4, 2189 | "nbformat_minor": 5 2190 | } 2191 | -------------------------------------------------------------------------------- /ic/3-body.tsv: -------------------------------------------------------------------------------- 1 | 2 0 0 0 0 0 0 2 | 1 0 1 0 1 0 0 3 | 1 0 -1 0 -1 0 0 4 | -------------------------------------------------------------------------------- /ic/circle.tsv: -------------------------------------------------------------------------------- 1 | 1 1 0 0 0 0.5 0 2 | 1 -1 0 0 0 -0.5 0 3 | -------------------------------------------------------------------------------- /ic/ellipse.tsv: -------------------------------------------------------------------------------- 1 | 0.5 1 0 0 0 0.4 0 2 | 2 -1 0 0 0 -0.1 0 3 | 
-------------------------------------------------------------------------------- /ic/figure-8.tsv: -------------------------------------------------------------------------------- 1 | 1 0.9700436 -0.24308753 0 0.466203685 0.43236573 0 2 | 1 -0.9700436 0.24308753 0 0.466203685 0.43236573 0 3 | 1 0 0 0 -0.93240737 -0.86473146 0 4 | -------------------------------------------------------------------------------- /ic/solar-system.tsv: -------------------------------------------------------------------------------- 1 | 3.9478416442871094e+01 5.6078135967254639e-03 -3.1279660761356354e-03 -1.7757424211595207e-04 2.0783834625035524e-03 2.5170932058244944e-03 6.7283457610756159e-05 2 | 6.5201911638723686e-06 2.5227290391921997e-01 -3.9525583386421204e-01 -5.4849412292242050e-02 6.9892196655273438e+00 4.2829818725585938e+00 2.9164445400238037e-01 3 | 9.6602037956472486e-05 -1.3234636187553406e-01 -7.1000230312347412e-01 -1.8743724795058370e-03 7.2622036933898926e+00 -1.4557237625122070e+00 4.3909132480621338e-01 4 | 1.1849479051306844e-04 2.0726967602968216e-02 9.8006278276443481e-01 -1.7673018737696111e-04 -6.3882884979248047e+00 9.3532137572765350e-02 1.0085605026688427e-06 5 | 1.2683108252531383e-05 9.1781175136566162e-01 1.0471463203430176e+00 -6.1226513935253024e-04 -4.0642089843750000e+00 3.7812654972076416e+00 -1.7918924987316132e-01 6 | 3.7671882659196854e-02 -3.6230869293212891e+00 3.3889400959014893e+00 6.6945947706699371e-02 -2.0039451122283936e+00 -2.0727291107177734e+00 -5.3540796041488647e-02 7 | 1.1279828846454620e-02 -8.5656147003173828e+00 3.2432060241699219e+00 2.8352031111717224e-01 -6.7099428176879883e-01 -2.0045726299285889e+00 -6.1672039330005646e-02 8 | 1.7230373341590166e-03 1.7148736953735352e+01 1.0144646644592285e+01 -1.8480630218982697e-01 -6.6980826854705811e-01 1.2072296142578125e+00 -1.3070498593151569e-02 9 | 2.0324734505265951e-03 -8.7731943130493164e+00 -2.8684883117675781e+01 7.9179602861404419e-01 1.0948790311813354e+00 
-3.4561598300933838e-01 1.8136950209736824e-02 10 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jax 2 | jaxlib 3 | jupyterlab 4 | matplotlib 5 | numpy 6 | tqdm 7 | --------------------------------------------------------------------------------