├── JaxOptimization.ipynb ├── Opitimization_with_jax.pdf └── README.md /JaxOptimization.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "JaxOptimization.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "authorship_tag": "ABX9TyNnP0VFlaJ8gdipP2iXhWR0", 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | }, 16 | "accelerator": "TPU" 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "markdown", 21 | "metadata": { 22 | "id": "view-in-github", 23 | "colab_type": "text" 24 | }, 25 | "source": [ 26 | "\"Open" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "metadata": { 32 | "id": "YfAOPDZgJz0m", 33 | "colab_type": "code", 34 | "colab": {} 35 | }, 36 | "source": [ 37 | "import jax.numpy as np\n", 38 | "from jax import grad, jit, vmap\n", 39 | "from jax import random\n", 40 | "from jax import jacfwd, jacrev\n", 41 | "from jax.numpy import linalg\n", 42 | "\n", 43 | "from numpy import nanargmin,nanargmax \n", 44 | "\n", 45 | "key = random.PRNGKey(42)" 46 | ], 47 | "execution_count": 0, 48 | "outputs": [] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "metadata": { 53 | "id": "2slwiA7gOe45", 54 | "colab_type": "text" 55 | }, 56 | "source": [ 57 | "# Single Variable Optimization" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": { 63 | "id": "ebUpVGzwU8Eh", 64 | "colab_type": "text" 65 | }, 66 | "source": [ 67 | "Defining the Object Functions" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "metadata": { 73 | "id": "REq_NrY5Nlvk", 74 | "colab_type": "code", 75 | "colab": {} 76 | }, 77 | "source": [ 78 | "def y(x):\n", 79 | " return ((x * np.sqrt(12*x - 36 )) / (2*(x - 3)))\n", 80 | "def L(x):\n", 81 | " return np.sqrt( x**2 + y(x)**2)" 82 | ], 83 | "execution_count": 0, 84 | "outputs": [] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": { 89 | "id": "U3cXxxCxVCWB", 90 | "colab_type": "text" 91 | }, 92 | "source": [ 93 | "### Solving with Gradient Descent\n", 94 | "Using ***grad*** to find the derivative of the function ***L***\n", 95 | "\n", 96 | "Using ***vmap*** to map the ***minGD*** function over the ***domain***\n", 97 | "\n", 98 | "Using the gradient descent equation:\n", 99 | "\n", 100 | "$x_{n+1} = x_{n} - 0.01 L^{'}(x_{n})$" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "metadata": { 106 | "id": "hdF8OOpMO0LS", 107 | "colab_type": "code", 108 | "colab": {} 109 | }, 110 | "source": [ 111 | "gradL = grad(L)\n", 112 | "\n", 113 | "def minGD(x): return x - 0.01 * gradL(x)\n", 114 | "\n", 115 | "domain = np.linspace(3.0, 5.0, num=50)\n", 116 | "\n", 117 | "vfuncGD = vmap(minGD)\n", 118 | "#Recurrent loop of gradient descent\n", 119 | "for epoch in range(500):\n", 120 | " domain = vfuncGD(domain)\n", 121 | "\n", 122 | "minfunc = vmap(L)\n", 123 | "minimums = minfunc(domain)" 124 | ], 125 | "execution_count": 0, 126 | "outputs": [] 127 | }, 128 | { 129 | "cell_type": "markdown", 130 | "metadata": { 131 | "id": "XL1-aMJ_M3t7", 132 | "colab_type": "text" 133 | }, 134 | "source": [ 135 | "Finding the argmin and the objective minimum" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "metadata": { 141 | "id": "P8_NnP-xhvc1", 142 | "colab_type": "code", 143 | "outputId": "f9be9ee0-e098-4ab0-dd40-8f93e8604c28", 144 | "colab": { 145 | "base_uri": "https://localhost:8080/", 146 | "height": 34 147 | } 148 | }, 149 | "source": [ 150 | "arglist = nanargmin(minimums)\n", 151 | "argmin = domain[arglist]\n", 152 | "minimum = minimums[arglist]\n", 153 | "\n", 154 | "print(\"The minimum is {} the argmin is {}\".format(minimum,argmin))" 155 | ], 156 | "execution_count": 0, 157 | "outputs": [ 158 | { 159 | "output_type": "stream", 160 | "text": [ 161 | "The minimum is 7.794247150421143 the argmin is 4.505752086639404\n" 162 | ], 163 | "name": "stdout" 164 | } 165 | ] 166 | }, 167 | { 168 | "cell_type": "markdown", 169 | "metadata": { 170 | "id": "6cWQtt4RShjZ", 171 | "colab_type": "text" 172 | }, 173 | "source": [ 174 | "### Solving with Newton's Method\n", 175 | "Using ***grad*** to find the derivative of the function ***L***\n", 176 | "\n", 177 | "Using ***vmap*** to map the ***minGD*** function over the ***domain***\n", 178 | "\n", 179 | "Using the gradient descent equation:\n", 180 | "\n", 181 | "$x_{n+1} = x_{n} - \\frac{L^{'}(x_{n})}{L^{''}(x_{n})} $" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "metadata": { 187 | "id": "cSWJsxTAkKvl", 188 | "colab_type": "code", 189 | "colab": {} 190 | }, 191 | "source": [ 192 | "gradL = grad(L)\n", 193 | "gradL2 = grad(gradL)\n", 194 | "\n", 195 | "def minNewton(x): return x - gradL(x)/gradL2(x)\n", 196 | "\n", 197 | "domain = np.linspace(3.0, 5.0, num=50)\n", 198 | "vfuncNT = vmap(minNewton)\n", 199 | "for epoch in range(50):\n", 200 | " domain = vfuncNT(domain)\n", 201 | "\n", 202 | "minimums = minfunc(domain)" 203 | ], 204 | "execution_count": 0, 205 | "outputs": [] 206 | }, 207 | { 208 | "cell_type": "markdown", 209 | "metadata": { 210 | "id": "mw_BEFd8Tz19", 211 | "colab_type": "text" 212 | }, 213 | "source": [ 214 | "Finding the argmin and the objective minimum" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "metadata": { 220 | "id": "8FQhAYifzQs8", 221 | "colab_type": "code", 222 | "outputId": "73f4c934-8092-4488-bbf3-8876a1350dc7", 223 | "colab": { 224 | "base_uri": "https://localhost:8080/", 225 | "height": 34 226 | } 227 | }, 228 | "source": [ 229 | "arglist = nanargmin(minimums)\n", 230 | "argmin = domain[arglist]\n", 231 | "minimum = minimums[arglist]\n", 232 | "\n", 233 | "print(\"The minimum is {} the argmin is {}\".format(minimum,argmin))" 234 | ], 235 | "execution_count": 0, 236 | "outputs": [ 237 | { 238 | "output_type": "stream", 239 | "text": [ 240 | "The minimum is 7.794229030609131 the arg min is 4.5\n" 241 | ], 242 | "name": "stdout" 243 | } 244 | ] 245 | }, 246 | { 247 | "cell_type": "markdown", 248 | "metadata": { 249 | "id": "babrapVGOsNB", 250 | "colab_type": "text" 251 | }, 252 | "source": [ 253 | "# Multivariable Optimization" 254 | ] 255 | }, 256 | { 257 | "cell_type": "markdown", 258 | "metadata": { 259 | "id": "sKTRywtlUE_s", 260 | "colab_type": "text" 261 | }, 262 | "source": [ 263 | "Defining the Object Function" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "metadata": { 269 | "id": "bNXOUzoQzqW-", 270 | "colab_type": "code", 271 | "colab": {} 272 | }, 273 | "source": [ 274 | "def paraboloid(x): return (x[0]*x[1]-2)**2 + (x[1]-3)**2\n", 275 | "minfunc = vmap(paraboloid)\n", 276 | "\n", 277 | "J = jacfwd(paraboloid)" 278 | ], 279 | "execution_count": 0, 280 | "outputs": [] 281 | }, 282 | { 283 | "cell_type": "markdown", 284 | "metadata": { 285 | "id": "7ZfP3qpIUZwj", 286 | "colab_type": "text" 287 | }, 288 | "source": [ 289 | "### Solving with Gradient Descent using the Jacobian ($\\nabla f$)\n", 290 | "Using ***grad*** to find the jacobian of the function ***paraboloid***\n", 291 | "\n", 292 | "Using ***vmap*** to map the ***minJacobian*** function over the ***domain***\n", 293 | "\n", 294 | "Using the gradient descent equation:\n", 295 | "\n", 296 | "$X_{n+1} = X_{n} - 0.01\\nabla f(X_{n}) $\n", 297 | "\n", 298 | "Where $ X = \\left[x_1,x_2,\\ldots,x_n \\right]^T$" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "metadata": { 304 | "id": "FT4zkldxoIZX", 305 | "colab_type": "code", 306 | "colab": {} 307 | }, 308 | "source": [ 309 | "def minJacobian(x): return x - 0.1*J(x) \n", 310 | "\n", 311 | "domain = random.uniform(key, shape=(50,2), dtype='float32',\n", 312 | " minval=-5.0, maxval=5.0)\n", 313 | "\n", 314 | "vfuncHS = vmap(minJacobian)\n", 315 | "for epoch in range(150):\n", 316 | " domain = vfuncHS(domain)\n", 317 | "\n", 318 | "\n", 319 | "minimums = minfunc(domain)" 320 | ], 321 | "execution_count": 0, 322 | "outputs": [] 323 | }, 324 | { 325 | "cell_type": "markdown", 326 | "metadata": { 327 | "id": "MXFPZRIoZAiT", 328 | "colab_type": "text" 329 | }, 330 | "source": [ 331 | "Finding the argmin and the objective minimum" 332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "metadata": { 337 | "id": "6eNf0cj_w3hJ", 338 | "colab_type": "code", 339 | "outputId": "3c2c3172-4793-4f43-d7ba-0aebff4358d6", 340 | "colab": { 341 | "base_uri": "https://localhost:8080/", 342 | "height": 34 343 | } 344 | }, 345 | "source": [ 346 | "arglist = nanargmin(minimums)\n", 347 | "argmin = domain[arglist]\n", 348 | "minimum = minimums[arglist]\n", 349 | "\n", 350 | "print(\"The minimum is {} the arg min is ({},{})\".format(minimum,argmin[0],argmin[1]))" 351 | ], 352 | "execution_count": 0, 353 | "outputs": [ 354 | { 355 | "output_type": "stream", 356 | "text": [ 357 | "The minimum is 0.0 the arg min is (0.6666666865348816,3.0)\n" 358 | ], 359 | "name": "stdout" 360 | } 361 | ] 362 | }, 363 | { 364 | "cell_type": "markdown", 365 | "metadata": { 366 | "id": "oUx3v1ysZRzd", 367 | "colab_type": "text" 368 | }, 369 | "source": [ 370 | "Defining the Hessian as $\\nabla (\\nabla f) = \\nabla^{2}f$" 371 | ] 372 | }, 373 | { 374 | "cell_type": "code", 375 | "metadata": { 376 | "id": "7VUM43RM87dR", 377 | "colab_type": "code", 378 | "colab": {} 379 | }, 380 | "source": [ 381 | "def hessian(f):\n", 382 | " return jacfwd(jacrev(f))\n", 383 | " \n", 384 | "H = hessian(paraboloid)" 385 | ], 386 | "execution_count": 0, 387 | "outputs": [] 388 | }, 389 | { 390 | "cell_type": "markdown", 391 | "metadata": { 392 | "id": "kxCf6SfdZ2hA", 393 | "colab_type": "text" 394 | }, 395 | "source": [ 396 | "### Solving with Newton's Method using the Hessian ($\\nabla^{2} f$)\n", 397 | "Using ***hessian*** to find the Hessian of the function ***paraboloid***\n", 398 | "\n", 399 | "Using ***vmap*** to map the ***minHessian*** function over the ***domain***\n", 400 | "\n", 401 | "Using the gradient descent equation:\n", 402 | "\n", 403 | "$X_{n+1} = X_{n} - 0.1 H^{-1}(X_{n}) \\nabla f(X_{n}) $\n", 404 | "\n", 405 | "Where $ X = \\left[x_1,x_2,\\ldots,x_n \\right]^T$" 406 | ] 407 | }, 408 | { 409 | "cell_type": "code", 410 | "metadata": { 411 | "id": "TSb8zk0h9-Ju", 412 | "colab_type": "code", 413 | "colab": {} 414 | }, 415 | "source": [ 416 | "def minHessian(x): return x - 0.1*linalg.inv(H(x)) @ J(x) \n", 417 | "\n", 418 | "\n", 419 | "domain = random.uniform(key, shape=(50,2), dtype='float32',\n", 420 | " minval=-5.0, maxval=5.0)\n", 421 | "\n", 422 | "vfuncHS = vmap(minHessian)\n", 423 | "for epoch in range(150):\n", 424 | " domain = vfuncHS(domain)\n", 425 | "\n", 426 | "\n", 427 | "minimums = minfunc(domain)" 428 | ], 429 | "execution_count": 0, 430 | "outputs": [] 431 | }, 432 | { 433 | "cell_type": "markdown", 434 | "metadata": { 435 | "id": "ErsWvrRybaBe", 436 | "colab_type": "text" 437 | }, 438 | "source": [ 439 | "Finding the argmin and the objective minimum" 440 | ] 441 | }, 442 | { 443 | "cell_type": "code", 444 | "metadata": { 445 | "id": "T_N1J8l_-8mN", 446 | "colab_type": "code", 447 | "outputId": "df952ffe-10d2-44e2-fb35-2beb62b6244a", 448 | "colab": { 449 | "base_uri": "https://localhost:8080/", 450 | "height": 34 451 | } 452 | }, 453 | "source": [ 454 | "arglist = nanargmin(minimums)\n", 455 | "argmin = domain[arglist]\n", 456 | "minimum = minimums[arglist]\n", 457 | "\n", 458 | "print(\"The minimum is {} the arg min is ({},{})\".format(minimum,argmin[0],argmin[1]))" 459 | ], 460 | "execution_count": 0, 461 | "outputs": [ 462 | { 463 | "output_type": "stream", 464 | "text": [ 465 | "The minimum is 9.094947017729282e-13 the arg min is (0.6666664481163025,3.0000009536743164)\n" 466 | ], 467 | "name": "stdout" 468 | } 469 | ] 470 | }, 471 | { 472 | "cell_type": "markdown", 473 | "metadata": { 474 | "id": "FTw8u2VfO2nZ", 475 | "colab_type": "text" 476 | }, 477 | "source": [ 478 | "# Multivariable Constrained Optimization" 479 | ] 480 | }, 481 | { 482 | "cell_type": "markdown", 483 | "metadata": { 484 | "id": "uj12ffIludMv", 485 | "colab_type": "text" 486 | }, 487 | "source": [ 488 | "Defining the Object Function $f(x)$\n", 489 | "and The Constrained Function $g(x)$\n", 490 | "\n", 491 | "The Lagrangian is defined as ***Lagrange*** $f(x) - \\lambda g(x) = 0 $\n", 492 | "\n", 493 | "Therefore using Newton's Method we solve for $Lagrange(x)=0$\n", 494 | "\n", 495 | "Which is the same as minimizing the multivariable function $\\nabla Lagrange(x)$\n", 496 | "\n", 497 | "Thus the reccurent loop is:\n", 498 | "$X_{n+1} = X_{n} - \\nabla^{2} Lagrange^{-1}(X_{n}) \\nabla Lagrange(X_{n}) $\n", 499 | "\n", 500 | "Where $ X = \\left[x_1,x_2,\\ldots,x_n \\right]^T$\n" 501 | ] 502 | }, 503 | { 504 | "cell_type": "code", 505 | "metadata": { 506 | "id": "SIl0XEFQQmGk", 507 | "colab_type": "code", 508 | "colab": {} 509 | }, 510 | "source": [ 511 | "def f(x): return 4*(x[0]**2)*x[1]\n", 512 | "def g(x): return x[0]**2 + x[1]**2 - 3\n", 513 | "\n", 514 | "minfunc = vmap(f)\n", 515 | "\n", 516 | "def Lagrange(l): return f(l[0:2]) - l[3]*g(l[0:2])\n", 517 | "\n", 518 | "L = jacfwd(Lagrange)\n", 519 | "gradL = jacfwd(L)" 520 | ], 521 | "execution_count": 0, 522 | "outputs": [] 523 | }, 524 | { 525 | "cell_type": "code", 526 | "metadata": { 527 | "id": "QFR44_jUT0c2", 528 | "colab_type": "code", 529 | "colab": {} 530 | }, 531 | "source": [ 532 | "def solveLagrangian(l): return l - linalg.inv(gradL(l)) @ L(l) \n", 533 | "\n", 534 | "\n", 535 | "domain = random.uniform(key, shape=(50,3), dtype='float32',\n", 536 | " minval=-5.0, maxval=5.0)\n", 537 | "\n", 538 | "vfuncsLAG = vmap(solveLagrangian)\n", 539 | "for epoch in range(150):\n", 540 | " domain = vfuncsLAG(domain)\n", 541 | "\n", 542 | "minimums = minfunc(domain)\n" 543 | ], 544 | "execution_count": 0, 545 | "outputs": [] 546 | }, 547 | { 548 | "cell_type": "markdown", 549 | "metadata": { 550 | "id": "MHHGrgWyxiZk", 551 | "colab_type": "text" 552 | }, 553 | "source": [ 554 | "Finding the argmin and the objective minimum" 555 | ] 556 | }, 557 | { 558 | "cell_type": "code", 559 | "metadata": { 560 | "id": "p8ZwO0A8l-PU", 561 | "colab_type": "code", 562 | "outputId": "56f93c06-b239-4c20-bb45-3f6a0a38b18d", 563 | "colab": { 564 | "base_uri": "https://localhost:8080/", 565 | "height": 34 566 | } 567 | }, 568 | "source": [ 569 | "arglist = nanargmin(minimums)\n", 570 | "argmin = domain[arglist]\n", 571 | "minimum = minimums[arglist]\n", 572 | "\n", 573 | "print(\"The minimum is {}, the arg min is ({},{}), the lagrangian is {}\".format(minimum,argmin[0],argmin[1],argmin[2]))" 574 | ], 575 | "execution_count": 0, 576 | "outputs": [ 577 | { 578 | "output_type": "stream", 579 | "text": [ 580 | "The minimum is -7.999999523162842, the arg min is (-1.4142135381698608,-1.0), the lagrangian is -4.0\n" 581 | ], 582 | "name": "stdout" 583 | } 584 | ] 585 | }, 586 | { 587 | "cell_type": "code", 588 | "metadata": { 589 | "id": "pQizl0Pi2Vm1", 590 | "colab_type": "code", 591 | "colab": {} 592 | }, 593 | "source": [ 594 | "" 595 | ], 596 | "execution_count": 0, 597 | "outputs": [] 598 | }, 599 | { 600 | "cell_type": "markdown", 601 | "metadata": { 602 | "id": "GG8vL7Jd1GPU", 603 | "colab_type": "text" 604 | }, 605 | "source": [ 606 | "# Solving a Three Variable Multivariable Constrained Optimization\n", 607 | "\n", 608 | "Find the dimensions of the box with largest volume if the total surface area is $64 cm^2$\n", 609 | "\n", 610 | "$Volume = f(x_0,x_1,x_2) = x_0 x_1x_2$\n", 611 | "\n", 612 | "$Surface Area = g(x_0,x_1,x_2) = 2x_0x_1 + 2x_1x_1 + 2x_0x_2 = 64$" 613 | ] 614 | }, 615 | { 616 | "cell_type": "code", 617 | "metadata": { 618 | "id": "YdcKTxCE1MAv", 619 | "colab_type": "code", 620 | "colab": {} 621 | }, 622 | "source": [ 623 | "def f(x): return x[0]*x[1]*x[2]\n", 624 | "def g(x): return 2*x[0]*x[1] + 2*x[1]*x[2] + 2*x[0]*x[2] - 64\n", 625 | "\n", 626 | "minfunc = vmap(f)\n", 627 | "\n", 628 | "def Lagrange(l): return f(l[0:3]) - l[3]*g(l[0:3])\n", 629 | "\n", 630 | "L = jacfwd(Lagrange)\n", 631 | "gradL = jacfwd(L)" 632 | ], 633 | "execution_count": 0, 634 | "outputs": [] 635 | }, 636 | { 637 | "cell_type": "code", 638 | "metadata": { 639 | "id": "QU-U0Z4W2WOg", 640 | "colab_type": "code", 641 | "colab": {} 642 | }, 643 | "source": [ 644 | "def solveLagrangian(l): return l - 0.1*linalg.inv(gradL(l)) @ L(l) \n", 645 | "\n", 646 | "domain = random.uniform(key, shape=(50,4), dtype='float32',\n", 647 | " minval=0, maxval=10)\n", 648 | "\n", 649 | "vfuncsLAG = vmap(solveLagrangian)\n", 650 | "for epoch in range(200):\n", 651 | " domain = vfuncsLAG(domain)\n", 652 | "\n", 653 | "\n", 654 | "maximums = minfunc(domain)" 655 | ], 656 | "execution_count": 0, 657 | "outputs": [] 658 | }, 659 | { 660 | "cell_type": "code", 661 | "metadata": { 662 | "id": "dDS7U7dw2ZjM", 663 | "colab_type": "code", 664 | "outputId": "cd0568d1-b29b-485e-e3c7-4edbcc485f3f", 665 | "colab": { 666 | "base_uri": "https://localhost:8080/", 667 | "height": 34 668 | } 669 | }, 670 | "source": [ 671 | "arglist = nanargmax(maximums)\n", 672 | "argmin = domain[arglist]\n", 673 | "minimum = maximums[arglist]\n", 674 | "\n", 675 | "print(\"The minimum is {}, the argmin is ({},{},{}), the lagrangian is {}\".format(minimum,argmin[0],\n", 676 | " argmin[1],\n", 677 | " argmin[2],\n", 678 | " argmin[3]))" 679 | ], 680 | "execution_count": 0, 681 | "outputs": [ 682 | { 683 | "output_type": "stream", 684 | "text": [ 685 | "The minimum is 34.83720016479492, the argmin is (3.2659873962402344,3.2659854888916016,3.2659873962402344), the lagrangian is 0.8164968490600586\n" 686 | ], 687 | "name": "stdout" 688 | } 689 | ] 690 | }, 691 | { 692 | "cell_type": "markdown", 693 | "metadata": { 694 | "id": "H8eWqfsmH0df", 695 | "colab_type": "text" 696 | }, 697 | "source": [ 698 | "It should be noted that this gives a 0.0000118855014% error!" 699 | ] 700 | }, 701 | { 702 | "cell_type": "markdown", 703 | "metadata": { 704 | "id": "IvMt89tBEoTB", 705 | "colab_type": "text" 706 | }, 707 | "source": [ 708 | "# Multivariable MultiConstrained Optimization" 709 | ] 710 | }, 711 | { 712 | "cell_type": "markdown", 713 | "metadata": { 714 | "id": "a-F_k9fcM8pd", 715 | "colab_type": "text" 716 | }, 717 | "source": [ 718 | "Let's start by trying to maximize the object function $f(x_0,x_1)$ with the constraints $g(x_0,x_1)$ and $h(x_0,x_1)$. \n", 719 | "\n", 720 | "$f(x_0,x_1) = 13x_0*2 + 10x_0x_1+ 7x_1^2 + x_0 + x_1 +2$\n", 721 | "\n", 722 | "$g(x_0,x_1) = 2x_0 - 5x_1 - 2 $\n", 723 | "\n", 724 | "$h(x_0,x_1) = x_0 + x_1 -1$\n" 725 | ] 726 | }, 727 | { 728 | "cell_type": "code", 729 | "metadata": { 730 | "id": "TAfygBNrEpk0", 731 | "colab_type": "code", 732 | "colab": {} 733 | }, 734 | "source": [ 735 | "def f(x) : return 13*x[0]**2 + 10*x[0]*x[1] + 7*x[1]**2 + x[0] + x[1]\n", 736 | "def g(x) : return 2*x[0]-5*x[1]-2\n", 737 | "def h(x) : return x[0] + x[1] -1\n", 738 | "\n", 739 | "minfunc = vmap(f)\n", 740 | "\n", 741 | "def Lagrange(l): return f(l[0:2]) - l[2]*g(l[0:2]) - l[3]*h(l[0:2])\n", 742 | "\n", 743 | "L = jacfwd(Lagrange)\n", 744 | "gradL = jacfwd(L)" 745 | ], 746 | "execution_count": 0, 747 | "outputs": [] 748 | }, 749 | { 750 | "cell_type": "code", 751 | "metadata": { 752 | "id": "QJGFo_9BNm7D", 753 | "colab_type": "code", 754 | "colab": {} 755 | }, 756 | "source": [ 757 | "def solveLagrangian(l): return l - 0.1*linalg.inv(gradL(l)) @ L(l) \n", 758 | "\n", 759 | "\n", 760 | "domain = random.uniform(key, shape=(300,4), dtype='float32',\n", 761 | " minval=-4, maxval=1)\n", 762 | "\n", 763 | "\n", 764 | "vfuncsLAG = vmap(solveLagrangian)\n", 765 | "for epoch in range(300):\n", 766 | " domain = vfuncsLAG(domain)\n", 767 | "\n", 768 | "\n", 769 | "maximums = minfunc(domain)" 770 | ], 771 | "execution_count": 0, 772 | "outputs": [] 773 | }, 774 | { 775 | "cell_type": "code", 776 | "metadata": { 777 | "id": "DFhcdCiNN7eX", 778 | "colab_type": "code", 779 | "outputId": "76907cff-3055-46f4-9863-64ffd7e6e8a3", 780 | "colab": { 781 | "base_uri": "https://localhost:8080/", 782 | "height": 34 783 | } 784 | }, 785 | "source": [ 786 | "arglist = nanargmin(maximums)\n", 787 | "argmin = domain[arglist]\n", 788 | "minimum = maximums[arglist]\n", 789 | "\n", 790 | "print(\"The minimum is {}, the argmin is ({},{}), the lagrangians are {} and {}\".format(minimum,argmin[0],\n", 791 | " argmin[1],\n", 792 | " argmin[2],\n", 793 | " argmin[3]))" 794 | ], 795 | "execution_count": 98, 796 | "outputs": [ 797 | { 798 | "output_type": "stream", 799 | "text": [ 800 | "The minimum is 13.999992370605469, the argmin is (0.9999997019767761,-1.18244605218365e-08), the lagrangians are 2.2857134342193604 and 22.428564071655273\n" 801 | ], 802 | "name": "stdout" 803 | } 804 | ] 805 | }, 806 | { 807 | "cell_type": "code", 808 | "metadata": { 809 | "id": "4XSvBlIrstyV", 810 | "colab_type": "code", 811 | "colab": {} 812 | }, 813 | "source": [ 814 | "" 815 | ], 816 | "execution_count": 0, 817 | "outputs": [] 818 | } 819 | ] 820 | } -------------------------------------------------------------------------------- /Opitimization_with_jax.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mazy1998/Solving-Optimization-Problems-with-JAX/630b20676ae1cd8e4c5f5aaa11bb4b0a55d53cd2/Opitimization_with_jax.pdf -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Solving-Optimization-Problems-with-JAX 2 | 3 | You can read the medium article here: https://medium.com/swlh/solving-optimization-problems-with-jax-98376508bd4f 4 | 5 | Solving Optimization Problems with JAX, code and PDF 6 | --------------------------------------------------------------------------------