├── .gitignore ├── .markdownlint.yaml ├── LICENSE ├── README.md ├── [1]_ODE_PINN.ipynb ├── [1]_ODE_PINN_ClassForm.ipynb ├── [2]_PDE_Burgers_PINN.ipynb ├── [3]_PDE_Laplace_PINN.ipynb ├── [3]_PDE_Laplace_PINN_ClassForm.ipynb ├── [4]_ODE_Supervised_and_PINN.ipynb ├── [5]_System_of_ODEs_PINN.ipynb ├── [6]_ODE_PINN_finite_difference.ipynb └── [7]_PDE_LAPLACE_FO_PINN.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | **/__pycache__ 3 | **/.ipynb_checkpoints 4 | **/.DS_Store 5 | *.Icon 6 | *.egg-info/ 7 | .trunk -------------------------------------------------------------------------------- /.markdownlint.yaml: -------------------------------------------------------------------------------- 1 | # Autoformatter friendly markdownlint config (all formatting rules disabled) 2 | default: true 3 | blank_lines: false 4 | bullet: false 5 | html: false 6 | indentation: false 7 | line_length: false 8 | spaces: false 9 | url: false 10 | whitespace: false 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 ASEM000 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [Physics informed neural network](https://maziarraissi.github.io/PINNs/) in [JAX](https://github.com/google/jax) 2 | 3 | Example notebooks demonstrating various applications of physics-informed neural networks (PINNs). 4 | 5 | | Description | Functional form | Class form ✨*New*✨ | 6 | |---|---|---| 7 | | **[ODE]** | Open In Colab | Open In Colab | 8 | | **[PDE]** Burgers | Open In Colab | | 9 | | **[PDE]** Laplace | Open In Colab | Open In Colab | 10 | | **[ODE]** Supervised loss + PINN | Open In Colab | | 11 | | **[ODE]** System of ODE | Open In Colab | | 12 | | **[ODE]** Finite difference | Open In Colab | | 13 | 14 | 15 | If you find it useful, please give it a ⭐ 16 | -------------------------------------------------------------------------------- /[1]_ODE_PINN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "colab_type": "text", 7 | "id": "view-in-github" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": { 17 | "id": "v77fdC1ZLyg1" 18 | }, 19 | "outputs": [], 20 | "source": [ 21 | "#Credits : Mahmoud Asem @Asem000 September 2021" 22 | ] 23 | }, 24 | { 25 | 
"cell_type": "code", 26 | "execution_count": 3, 27 | 
"metadata": { 28 | "colab": { 29 | "base_uri": "https://localhost:8080/" 30 | }, 31 | "id": "vAR0swbLX_ZI", 32 | "outputId": "23ef5ef3-495b-4582-96a2-3108e8dce0bb" 33 | }, 34 | "outputs": [ 35 | { 36 | "name": "stdout", 37 | "output_type": "stream", 38 | "text": [ 39 | "Collecting optax\n", 40 | " Downloading optax-0.0.9-py3-none-any.whl (118 kB)\n", 41 | "\u001b[K |████████████████████████████████| 118 kB 8.5 MB/s eta 0:00:01\n", 42 | "\u001b[?25hCollecting chex>=0.0.4\n", 43 | " Downloading chex-0.0.8-py3-none-any.whl (57 kB)\n", 44 | "\u001b[K |████████████████████████████████| 57 kB 6.1 MB/s eta 0:00:01\n", 45 | "\u001b[?25hRequirement already satisfied: jaxlib>=0.1.37 in /usr/local/lib/python3.7/dist-packages (from optax) (0.1.70+cuda110)\n", 46 | "Requirement already satisfied: jax>=0.1.55 in /usr/local/lib/python3.7/dist-packages (from optax) (0.2.19)\n", 47 | "Requirement already satisfied: absl-py>=0.7.1 in /usr/local/lib/python3.7/dist-packages (from optax) (0.12.0)\n", 48 | "Requirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.7/dist-packages (from optax) (1.19.5)\n", 49 | "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py>=0.7.1->optax) (1.15.0)\n", 50 | "Requirement already satisfied: dm-tree>=0.1.5 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax) (0.1.6)\n", 51 | "Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax) (0.11.1)\n", 52 | "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax>=0.1.55->optax) (3.3.0)\n", 53 | "Requirement already satisfied: flatbuffers<3.0,>=1.12 in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.37->optax) (1.12)\n", 54 | "Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.37->optax) (1.4.1)\n", 55 | "Installing collected packages: chex, optax\n", 56 | "Successfully installed chex-0.0.8 
optax-0.0.9\n" 57 | ] 58 | } 59 | ], 60 | "source": [ 61 | "#Imports\n", 62 | "import jax \n", 63 | "import jax.numpy as jnp\n", 64 | "import numpy as np\n", 65 | "import matplotlib.pyplot as plt\n", 66 | "from matplotlib import cm\n", 67 | "import matplotlib as mpl\n", 68 | "!pip install optax\n", 69 | "import optax" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 4, 75 | "metadata": { 76 | "id": "yoPHsh5lWvyP" 77 | }, 78 | "outputs": [], 79 | "source": [ 80 | "import sympy as sp" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": { 86 | "id": "7bg4nSbsXVwD" 87 | }, 88 | "source": [ 89 | "### Generate a differential equation and its solution using SymPy" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 19, 95 | "metadata": { 96 | "id": "P9664e-mVMTN" 97 | }, 98 | "outputs": [], 99 | "source": [ 100 | "t= sp.symbols('t')\n", 101 | "f = sp.Function('y')\n", 102 | "diffeq = sp.Eq(f(t).diff(t,t) + f(t).diff(t)-t*sp.cos(2*sp.pi*t),0)\n", 103 | "sol = sp.simplify(sp.dsolve(diffeq,ics={f(0):1,f(t).diff(t).subs(t,0):10}).rhs)" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 20, 109 | "metadata": { 110 | "colab": { 111 | "base_uri": "https://localhost:8080/", 112 | "height": 54 113 | }, 114 | "id": "klgFeU6bcTrC", 115 | "outputId": "cad6a477-37e0-414d-daf6-6e4f05a288cd" 116 | }, 117 | "outputs": [ 118 | { 119 | "data": { 120 | "text/latex": [ 121 | "$\\displaystyle - t \\cos{\\left(2 \\pi t \\right)} + \\frac{d}{d t} y{\\left(t \\right)} + \\frac{d^{2}}{d t^{2}} y{\\left(t \\right)} = 0$" 122 | ], 123 | "text/plain": [ 124 | "Eq(-t*cos(2*pi*t) + Derivative(y(t), t) + Derivative(y(t), (t, 2)), 0)" 125 | ] 126 | }, 127 | "execution_count": 20, 128 | "metadata": {}, 129 | "output_type": "execute_result" 130 | } 131 | ], 132 | "source": [ 133 | "diffeq" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": 21, 139 | 
"metadata": { 140 | "colab": { 141 | 
"base_uri": "https://localhost:8080/", 142 | "height": 60 143 | }, 144 | "id": "E4Uu2hbiYJtv", 145 | "outputId": "e58180aa-779a-49fa-9771-b9f4763bb037" 146 | }, 147 | "outputs": [ 148 | { 149 | "data": { 150 | "text/latex": [ 151 | "$\\displaystyle \\left. \\frac{d}{d t} y{\\left(t \\right)} \\right|_{\\substack{ t=0 }} = 10$" 152 | ], 153 | "text/plain": [ 154 | "Eq(Subs(Derivative(y(t), t), t, 0), 10)" 155 | ] 156 | }, 157 | "execution_count": 21, 158 | "metadata": {}, 159 | "output_type": "execute_result" 160 | } 161 | ], 162 | "source": [ 163 | "sp.Eq(f(t).diff(t).subs(t,0),10)" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 22, 169 | "metadata": { 170 | "colab": { 171 | "base_uri": "https://localhost:8080/", 172 | "height": 38 173 | }, 174 | "id": "29QUbt_2YwlJ", 175 | "outputId": "3d9edb85-a44f-42cc-8776-bccda97cbc7e" 176 | }, 177 | "outputs": [ 178 | { 179 | "data": { 180 | "text/latex": [ 181 | "$\\displaystyle y{\\left(0 \\right)} = 1$" 182 | ], 183 | "text/plain": [ 184 | "Eq(y(0), 1)" 185 | ] 186 | }, 187 | "execution_count": 22, 188 | "metadata": {}, 189 | "output_type": "execute_result" 190 | } 191 | ], 192 | "source": [ 193 | "sp.Eq(f(t).subs(t,0),1)" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 24, 199 | "metadata": { 200 | "colab": { 201 | "base_uri": "https://localhost:8080/", 202 | "height": 101 203 | }, 204 | "id": "r9KVq1yjYfld", 205 | "outputId": "b6e56baa-b419-49e0-bba9-3706d96b394c" 206 | }, 207 | "outputs": [ 208 | { 209 | "data": { 210 | "text/latex": [ 211 | "$\\displaystyle y{\\left(t \\right)} = \\frac{\\left(2 \\pi t e^{t} \\sin{\\left(2 \\pi t \\right)} + 8 \\pi^{3} t e^{t} \\sin{\\left(2 \\pi t \\right)} - 16 \\pi^{4} t e^{t} \\cos{\\left(2 \\pi t \\right)} - 4 \\pi^{2} t e^{t} \\cos{\\left(2 \\pi t \\right)} + 16 \\pi^{3} e^{t} \\sin{\\left(2 \\pi t \\right)} + e^{t} \\cos{\\left(2 \\pi t \\right)} + 12 \\pi^{2} e^{t} \\cos{\\left(2 \\pi t \\right)} - e^{t} + 36 \\pi^{2} 
e^{t} + 336 \\pi^{4} e^{t} + 704 \\pi^{6} e^{t} - 640 \\pi^{6} - 304 \\pi^{4} - 44 \\pi^{2}\\right) e^{- t}}{4 \\pi^{2} \\left(1 + 8 \\pi^{2} + 16 \\pi^{4}\\right)}$" 212 | ], 213 | "text/plain": [ 214 | "Eq(y(t), (2*pi*t*exp(t)*sin(2*pi*t) + 8*pi**3*t*exp(t)*sin(2*pi*t) - 16*pi**4*t*exp(t)*cos(2*pi*t) - 4*pi**2*t*exp(t)*cos(2*pi*t) + 16*pi**3*exp(t)*sin(2*pi*t) + exp(t)*cos(2*pi*t) + 12*pi**2*exp(t)*cos(2*pi*t) - exp(t) + 36*pi**2*exp(t) + 336*pi**4*exp(t) + 704*pi**6*exp(t) - 640*pi**6 - 304*pi**4 - 44*pi**2)*exp(-t)/(4*pi**2*(1 + 8*pi**2 + 16*pi**4)))" 215 | ] 216 | }, 217 | "execution_count": 24, 218 | "metadata": {}, 219 | "output_type": "execute_result" 220 | } 221 | ], 222 | "source": [ 223 | "sp.Eq(f(t),sol)" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": 25, 229 | "metadata": { 230 | "colab": { 231 | "base_uri": "https://localhost:8080/", 232 | "height": 37 233 | }, 234 | "id": "MNVOpPyCW-GU", 235 | "outputId": "c333d114-01de-4132-8ebf-effd19785112" 236 | }, 237 | "outputs": [ 238 | { 239 | "data": { 240 | "text/latex": [ 241 | "$\\displaystyle 0$" 242 | ], 243 | "text/plain": [ 244 | "0" 245 | ] 246 | }, 247 | "execution_count": 25, 248 | "metadata": {}, 249 | "output_type": "execute_result" 250 | } 251 | ], 252 | "source": [ 253 | "#verify solution\n", 254 | "sp.simplify(-t*sp.cos(sp.pi*2*t)+sol.diff(t)+sol.diff(t,t))" 255 | ] 256 | }, 257 | { 258 | "cell_type": "markdown", 259 | "metadata": { 260 | "id": "NQ61lEQeXgrc" 261 | }, 262 | "source": [ 263 | "### Constructing the MLP" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": 61, 269 | "metadata": { 270 | "id": "Lml6PGLPZgmr" 271 | }, 272 | "outputs": [], 273 | "source": [ 274 | "N_b = 1\n", 275 | "N_c = 100\n", 276 | "\n", 277 | "tmin,tmax=0. 
,jnp.pi\n", 278 | "\n", 279 | "'''boundary conditions'''\n", 280 | "\n", 281 | "\n", 282 | "# U[0] = 1\n", 283 | "t_0 = jnp.ones([N_b,1],dtype='float32')*0.\n", 284 | "ic_0 = jnp.ones_like(t_0) \n", 285 | "IC_0 = jnp.concatenate([t_0,ic_0],axis=1)\n", 286 | "\n", 287 | "# U_t[0] = 10\n", 288 | "t_b1 = jnp.zeros([N_b,1])\n", 289 | "bc_1 = jnp.ones_like(t_b1) * 10\n", 290 | "BC_1 = jnp.concatenate([t_b1,bc_1],axis=1)\n", 291 | "\n", 292 | "conds = [IC_0,BC_1]\n", 293 | "\n", 294 | "#collocation points\n", 295 | "\n", 296 | "key=jax.random.PRNGKey(0)\n", 297 | "\n", 298 | "t_c = jax.random.uniform(key,minval=tmin,maxval=tmax,shape=(N_c,1))\n", 299 | "colloc = t_c\n", 300 | "\n", 301 | "def ODE_loss(t,u):\n", 302 | " u_t=lambda t:jax.grad(lambda t:jnp.sum(u(t)))(t)\n", 303 | " u_tt=lambda t:jax.grad(lambda t : jnp.sum(u_t(t)))(t)\n", 304 | " return -t*jnp.cos(2*jnp.pi*t) + u_t(t) + u_tt(t)" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": 68, 310 | "metadata": { 311 | "id": "KoZZJl2TbI_n" 312 | }, 313 | "outputs": [], 314 | "source": [ 315 | "def init_params(layers):\n", 316 | " keys = jax.random.split(jax.random.PRNGKey(0),len(layers)-1)\n", 317 | " params = list()\n", 318 | " for key,n_in,n_out in zip(keys,layers[:-1],layers[1:]):\n", 319 | " lb, ub = -(1 / jnp.sqrt(n_in)), (1 / jnp.sqrt(n_in)) # fan-in scaled uniform bounds +-1/sqrt(n_in) (not Glorot/Xavier, whose bound is sqrt(6/(n_in+n_out)))\n", 320 | " W = lb + (ub-lb) * jax.random.uniform(key,shape=(n_in,n_out))\n", 321 | " B = jax.random.uniform(key,shape=(n_out,))\n", 322 | " params.append({'W':W,'B':B})\n", 323 | " return params\n", 324 | "\n", 325 | "def fwd(params,t):\n", 326 | " X = jnp.concatenate([t],axis=1)\n", 327 | " *hidden,last = params\n", 328 | " for layer in hidden :\n", 329 | " X = jax.nn.tanh(X@layer['W']+layer['B'])\n", 330 | " return X@last['W'] + last['B']\n", 331 | "\n", 332 | "@jax.jit\n", 333 | "def MSE(true,pred):\n", 334 | " 
return jnp.mean((true-pred)**2)\n", 335 | "\n", 336 | "def 
loss_fun(params,colloc,conds):\n", 337 | " t_c =colloc[:,[0]]\n", 338 | " ufunc = lambda t : fwd(params,t)\n", 339 | " ufunc_t=lambda t:jax.grad(lambda t:jnp.sum(ufunc(t)))(t)\n", 340 | " loss =jnp.mean(ODE_loss(t_c,ufunc) **2)\n", 341 | "\n", 342 | " t_ic,u_ic = conds[0][:,[0]],conds[0][:,[1]] \n", 343 | " loss += MSE(u_ic,ufunc(t_ic))\n", 344 | "\n", 345 | " t_bc,u_bc = conds[1][:,[0]],conds[1][:,[1]] \n", 346 | " loss += MSE(u_bc,ufunc_t(t_bc))\n", 347 | "\n", 348 | " return loss\n", 349 | "\n", 350 | "@jax.jit\n", 351 | "def update(opt_state,params,colloc,conds):\n", 352 | " # Get the gradient w.r.t. MLP params\n", 353 | " grads=jax.jit(jax.grad(loss_fun,0))(params,colloc,conds)\n", 354 | " \n", 355 | " #Update params\n", 356 | " updates, opt_state = optimizer.update(grads, opt_state)\n", 357 | " params = optax.apply_updates(params, updates)\n", 358 | "\n", 359 | " #Alternative: manual gradient-descent update without optax\n", 360 | " # return jax.tree_multimap(lambda params,grads : params-LR*grads, params,grads)\n", 361 | " return opt_state,params\n" 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "execution_count": 69, 367 | "metadata": { 368 | "id": "ae1ZDoy0c29c" 369 | }, 370 | "outputs": [], 371 | "source": [ 372 | "# construct the MLP with 4 hidden layers of 20 neurons each\n", 373 | "params = init_params([1] + [20]*4+[1])" 374 | ] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "execution_count": 70, 379 | "metadata": { 380 | "id": "jySmbUwic5yk" 381 | }, 382 | "outputs": [], 383 | "source": [ 384 | "lr = optax.piecewise_constant_schedule(1e-3,{10_000:5e-3,30_000:1e-3,50_000:5e-4,70_000:1e-4})\n", 385 | "optimizer = optax.adam(lr)\n", 386 | "opt_state = optimizer.init(params)" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": 71, 392 | "metadata": { 393 | "colab": { 394 | "base_uri": "https://localhost:8080/" 395 | }, 396 | "id": "kBzGA8OVc8C6", 397 | "outputId": 
"effacbde-f915-47cd-8f13-a44e6b697062" 398 | }, 399 | "outputs": [ 400 | { 401 | "name": 
"stdout", 402 | "output_type": "stream", 403 | "text": [ 404 | "Epoch=0\tloss=1.026e+02\n", 405 | "Epoch=100\tloss=1.500e+01\n", 406 | "Epoch=200\tloss=7.508e+00\n", 407 | "Epoch=300\tloss=5.177e+00\n", 408 | "Epoch=400\tloss=3.458e+00\n", 409 | "Epoch=500\tloss=2.615e+00\n", 410 | "Epoch=600\tloss=2.391e+00\n", 411 | "Epoch=700\tloss=2.254e+00\n", 412 | "Epoch=800\tloss=2.133e+00\n", 413 | "Epoch=900\tloss=2.013e+00\n", 414 | "Epoch=1000\tloss=1.855e+00\n", 415 | "Epoch=1100\tloss=1.549e+00\n", 416 | "Epoch=1200\tloss=1.018e+00\n", 417 | "Epoch=1300\tloss=9.915e-01\n", 418 | "Epoch=1400\tloss=9.756e-01\n", 419 | "Epoch=1500\tloss=9.618e-01\n", 420 | "Epoch=1600\tloss=9.480e-01\n", 421 | "Epoch=1700\tloss=9.327e-01\n", 422 | "Epoch=1800\tloss=9.145e-01\n", 423 | "Epoch=1900\tloss=8.902e-01\n", 424 | "Epoch=2000\tloss=7.379e-01\n", 425 | "Epoch=2100\tloss=3.372e-01\n", 426 | "Epoch=2200\tloss=1.753e-02\n", 427 | "Epoch=2300\tloss=4.962e-03\n", 428 | "Epoch=2400\tloss=3.304e-03\n", 429 | "Epoch=2500\tloss=2.670e-03\n", 430 | "Epoch=2600\tloss=2.238e-03\n", 431 | "Epoch=2700\tloss=1.883e-03\n", 432 | "Epoch=2800\tloss=1.577e-03\n", 433 | "Epoch=2900\tloss=1.311e-03\n", 434 | "Epoch=3000\tloss=1.082e-03\n", 435 | "Epoch=3100\tloss=8.856e-04\n", 436 | "Epoch=3200\tloss=7.198e-04\n", 437 | "Epoch=3300\tloss=5.825e-04\n", 438 | "Epoch=3400\tloss=4.711e-04\n", 439 | "Epoch=3500\tloss=6.700e-04\n", 440 | "Epoch=3600\tloss=3.161e-04\n", 441 | "Epoch=3700\tloss=3.296e-04\n", 442 | "Epoch=3800\tloss=2.278e-04\n", 443 | "Epoch=3900\tloss=2.032e-04\n", 444 | "Epoch=4000\tloss=1.785e-04\n", 445 | "Epoch=4100\tloss=1.624e-04\n", 446 | "Epoch=4200\tloss=1.494e-04\n", 447 | "Epoch=4300\tloss=1.377e-04\n", 448 | "Epoch=4400\tloss=2.545e-04\n", 449 | "Epoch=4500\tloss=1.195e-04\n", 450 | "Epoch=4600\tloss=2.162e-03\n", 451 | "Epoch=4700\tloss=1.041e-04\n", 452 | "Epoch=4800\tloss=1.321e-04\n", 453 | "Epoch=4900\tloss=9.035e-05\n", 454 | "Epoch=5000\tloss=8.385e-05\n", 455 | 
"Epoch=5100\tloss=8.181e-05\n", 456 | "Epoch=5200\tloss=7.205e-05\n", 457 | "Epoch=5300\tloss=2.472e-03\n", 458 | "Epoch=5400\tloss=6.137e-05\n", 459 | "Epoch=5500\tloss=6.695e-05\n", 460 | "Epoch=5600\tloss=5.197e-05\n", 461 | "Epoch=5700\tloss=4.714e-05\n", 462 | "Epoch=5800\tloss=7.266e-05\n", 463 | "Epoch=5900\tloss=3.922e-05\n", 464 | "Epoch=6000\tloss=1.470e-03\n", 465 | "Epoch=6100\tloss=3.234e-05\n", 466 | "Epoch=6200\tloss=1.919e-03\n", 467 | "Epoch=6300\tloss=2.676e-05\n", 468 | "Epoch=6400\tloss=2.389e-05\n", 469 | "Epoch=6500\tloss=2.301e-05\n", 470 | "Epoch=6600\tloss=1.962e-05\n", 471 | "Epoch=6700\tloss=3.187e-04\n", 472 | "Epoch=6800\tloss=1.618e-05\n", 473 | "Epoch=6900\tloss=2.400e-03\n", 474 | "Epoch=7000\tloss=1.355e-05\n", 475 | "Epoch=7100\tloss=2.970e-04\n", 476 | "Epoch=7200\tloss=1.146e-05\n", 477 | "Epoch=7300\tloss=1.040e-05\n", 478 | "Epoch=7400\tloss=3.527e-05\n", 479 | "Epoch=7500\tloss=9.057e-06\n", 480 | "Epoch=7600\tloss=4.930e-04\n", 481 | "Epoch=7700\tloss=3.244e-04\n", 482 | "Epoch=7800\tloss=1.300e-03\n", 483 | "Epoch=7900\tloss=7.630e-06\n", 484 | "Epoch=8000\tloss=1.313e-05\n", 485 | "Epoch=8100\tloss=1.451e-05\n", 486 | "Epoch=8200\tloss=7.152e-06\n", 487 | "Epoch=8300\tloss=6.462e-05\n", 488 | "Epoch=8400\tloss=5.616e-03\n", 489 | "Epoch=8500\tloss=6.044e-06\n", 490 | "Epoch=8600\tloss=7.831e-05\n", 491 | "Epoch=8700\tloss=5.532e-06\n", 492 | "Epoch=8800\tloss=8.422e-05\n", 493 | "Epoch=8900\tloss=1.461e-03\n", 494 | "Epoch=9000\tloss=5.241e-06\n", 495 | "Epoch=9100\tloss=2.537e-04\n", 496 | "Epoch=9200\tloss=4.977e-06\n", 497 | "Epoch=9300\tloss=4.646e-05\n", 498 | "Epoch=9400\tloss=4.768e-06\n", 499 | "Epoch=9500\tloss=1.605e-05\n", 500 | "Epoch=9600\tloss=4.656e-06\n", 501 | "Epoch=9700\tloss=3.441e-03\n", 502 | "Epoch=9800\tloss=4.752e-06\n", 503 | "Epoch=9900\tloss=6.980e-03\n", 504 | "CPU times: user 17.7 s, sys: 139 ms, total: 17.8 s\n", 505 | "Wall time: 17.7 s\n" 506 | ] 507 | } 508 | ], 509 | "source": [ 510 | 
"%%time\n", 511 | "epochs = 10_000\n", 512 | "for _ in range(epochs):\n", 513 | " opt_state,params = update(opt_state,params,colloc,conds)\n", 514 | "\n", 515 | " # print loss and epoch info\n", 516 | " if _ %(100) ==0:\n", 517 | " print(f'Epoch={_}\\tloss={loss_fun(params,colloc,conds):.3e}')" 518 | ] 519 | }, 520 | { 521 | "cell_type": "code", 522 | "execution_count": 80, 523 | "metadata": { 524 | "colab": { 525 | "base_uri": "https://localhost:8080/", 526 | "height": 282 527 | }, 528 | "id": "eWeNvDsdDEuI", 529 | "outputId": "32551eeb-25df-4d2e-8cae-82cc52b41ac5" 530 | }, 531 | "outputs": [ 532 | { 533 | "data": { 534 | "text/plain": [ 535 | "" 536 | ] 537 | }, 538 | "execution_count": 80, 539 | "metadata": {}, 540 | "output_type": "execute_result" 541 | }, 542 | { 543 | "data": { 544 | "image/png": "", 545 | "text/plain": [ 546 | "
" 547 | ] 548 | }, 549 | "metadata": { 550 | "needs_background": "light" 551 | }, 552 | "output_type": "display_data" 553 | } 554 | ], 555 | "source": [ 556 | "lam_sol= sp.lambdify(t,sol)\n", 557 | "\n", 558 | "dT = 1e-3\n", 559 | "Tf = jnp.pi\n", 560 | "T = np.arange(0,Tf+dT,dT)\n", 561 | "\n", 562 | "\n", 563 | "sym_sol =np.array([lam_sol(i) for i in T])\n", 564 | "\n", 565 | "plt.plot(T,sym_sol,'--r',label='sympy solution')\n", 566 | "plt.plot(T,fwd(params,T.reshape(-1,1))[:,0],'--k',label='NN solution')\n", 567 | "plt.legend()" 568 | ] 569 | } 570 | ], 571 | "metadata": { 572 | "colab": { 573 | "authorship_tag": "ABX9TyPlDK/9ZvMulH91+B+B32BC", 574 | "collapsed_sections": [], 575 | "include_colab_link": true, 576 | "name": "[1] ODE-PINN.ipynb", 577 | "provenance": [] 578 | }, 579 | "kernelspec": { 580 | "display_name": "Python 3.8.9 64-bit", 581 | "language": "python", 582 | "name": "python3" 583 | }, 584 | "language_info": { 585 | "codemirror_mode": { 586 | "name": "ipython", 587 | "version": 3 588 | }, 589 | "file_extension": ".py", 590 | "mimetype": "text/x-python", 591 | "name": "python", 592 | "nbconvert_exporter": "python", 593 | "pygments_lexer": "ipython3", 594 | "version": "3.8.9" 595 | }, 596 | "vscode": { 597 | "interpreter": { 598 | "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" 599 | } 600 | } 601 | }, 602 | "nbformat": 4, 603 | "nbformat_minor": 4 604 | } 605 | -------------------------------------------------------------------------------- /[1]_ODE_PINN_ClassForm.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "view-in-github", 7 | "colab_type": "text" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "source": [ 16 | "!pip install optax\n", 17 | "!pip install pytreeclass\n", 18 | "!pip install tqdm" 19 | ], 20 | "metadata": { 21 | "id": "PmiMsCTpOtKE", 22 | 
"colab": { 23 | "base_uri": "https://localhost:8080/" 24 | }, 25 | "outputId": "673af120-a5ba-45a4-953f-e77dee359c09" 26 | }, 27 | "execution_count": 1, 28 | "outputs": [ 29 | { 30 | "output_type": "stream", 31 | "name": "stdout", 32 | "text": [ 33 | "Requirement already satisfied: optax in /usr/local/lib/python3.10/dist-packages (0.1.7)\n", 34 | "Requirement already satisfied: absl-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from optax) (1.4.0)\n", 35 | "Requirement already satisfied: chex>=0.1.5 in /usr/local/lib/python3.10/dist-packages (from optax) (0.1.7)\n", 36 | "Requirement already satisfied: jax>=0.1.55 in /usr/local/lib/python3.10/dist-packages (from optax) (0.4.14)\n", 37 | "Requirement already satisfied: jaxlib>=0.1.37 in /usr/local/lib/python3.10/dist-packages (from optax) (0.4.14+cuda11.cudnn86)\n", 38 | "Requirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.10/dist-packages (from optax) (1.23.5)\n", 39 | "Requirement already satisfied: dm-tree>=0.1.5 in /usr/local/lib/python3.10/dist-packages (from chex>=0.1.5->optax) (0.1.8)\n", 40 | "Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.10/dist-packages (from chex>=0.1.5->optax) (0.12.0)\n", 41 | "Requirement already satisfied: typing-extensions>=4.2.0 in /usr/local/lib/python3.10/dist-packages (from chex>=0.1.5->optax) (4.7.1)\n", 42 | "Requirement already satisfied: ml-dtypes>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from jax>=0.1.55->optax) (0.2.0)\n", 43 | "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.10/dist-packages (from jax>=0.1.55->optax) (3.3.0)\n", 44 | "Requirement already satisfied: scipy>=1.7 in /usr/local/lib/python3.10/dist-packages (from jax>=0.1.55->optax) (1.10.1)\n", 45 | "Requirement already satisfied: pytreeclass in /usr/local/lib/python3.10/dist-packages (0.6.0)\n", 46 | "Requirement already satisfied: jax>=0.4.7 in /usr/local/lib/python3.10/dist-packages (from pytreeclass) (0.4.14)\n", 47 | 
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from pytreeclass) (4.7.1)\n", 48 | "Requirement already satisfied: ml-dtypes>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from jax>=0.4.7->pytreeclass) (0.2.0)\n", 49 | "Requirement already satisfied: numpy>=1.22 in /usr/local/lib/python3.10/dist-packages (from jax>=0.4.7->pytreeclass) (1.23.5)\n", 50 | "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.10/dist-packages (from jax>=0.4.7->pytreeclass) (3.3.0)\n", 51 | "Requirement already satisfied: scipy>=1.7 in /usr/local/lib/python3.10/dist-packages (from jax>=0.4.7->pytreeclass) (1.10.1)\n", 52 | "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (4.66.1)\n" 53 | ] 54 | } 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 2, 60 | "metadata": { 61 | "id": "vAR0swbLX_ZI" 62 | }, 63 | "outputs": [], 64 | "source": [ 65 | "# Imports\n", 66 | "from __future__ import annotations\n", 67 | "from typing import Callable\n", 68 | "import jax\n", 69 | "import jax.numpy as jnp\n", 70 | "import numpy as np\n", 71 | "import matplotlib.pyplot as plt\n", 72 | "import optax\n", 73 | "import sympy as sp\n", 74 | "import pytreeclass as pytc\n", 75 | "from tqdm.notebook import tqdm" 76 | ] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "metadata": { 81 | "id": "7bg4nSbsXVwD" 82 | }, 83 | "source": [ 84 | "### Generate a a differential equation and its solution using SymPy" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 3, 90 | "metadata": { 91 | "id": "P9664e-mVMTN" 92 | }, 93 | "outputs": [], 94 | "source": [ 95 | "t= sp.symbols('t')\n", 96 | "f = sp.Function('y')\n", 97 | "diffeq = sp.Eq(f(t).diff(t,t) + f(t).diff(t)-t*sp.cos(2*sp.pi*t),0)\n", 98 | "sol = sp.simplify(sp.dsolve(diffeq,ics={f(0):1,f(t).diff(t).subs(t,0):10}).rhs)" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 4, 104 | "metadata": { 105 | "id": 
"klgFeU6bcTrC", 106 | "colab": { 107 | "base_uri": "https://localhost:8080/", 108 | "height": 54 109 | }, 110 | "outputId": "060d7276-cb05-418f-8be9-91e7a671a8bb" 111 | }, 112 | "outputs": [ 113 | { 114 | "output_type": "execute_result", 115 | "data": { 116 | "text/plain": [ 117 | "Eq(-t*cos(2*pi*t) + Derivative(y(t), t) + Derivative(y(t), (t, 2)), 0)" 118 | ], 119 | "text/latex": "$\\displaystyle - t \\cos{\\left(2 \\pi t \\right)} + \\frac{d}{d t} y{\\left(t \\right)} + \\frac{d^{2}}{d t^{2}} y{\\left(t \\right)} = 0$" 120 | }, 121 | "metadata": {}, 122 | "execution_count": 4 123 | } 124 | ], 125 | "source": [ 126 | "diffeq" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 5, 132 | "metadata": { 133 | "id": "E4Uu2hbiYJtv", 134 | "colab": { 135 | "base_uri": "https://localhost:8080/", 136 | "height": 61 137 | }, 138 | "outputId": "d66ce7f6-f1b4-4667-faa7-8708227ff805" 139 | }, 140 | "outputs": [ 141 | { 142 | "output_type": "execute_result", 143 | "data": { 144 | "text/plain": [ 145 | "Eq(Subs(Derivative(y(t), t), t, 0), 10)" 146 | ], 147 | "text/latex": "$\\displaystyle \\left. 
\\frac{d}{d t} y{\\left(t \\right)} \\right|_{\\substack{ t=0 }} = 10$" 148 | }, 149 | "metadata": {}, 150 | "execution_count": 5 151 | } 152 | ], 153 | "source": [ 154 | "sp.Eq(f(t).diff(t).subs(t,0),10)" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": 6, 160 | "metadata": { 161 | "id": "29QUbt_2YwlJ", 162 | "colab": { 163 | "base_uri": "https://localhost:8080/", 164 | "height": 39 165 | }, 166 | "outputId": "abd0aca6-9156-49f1-fba3-9349e0b5fe9d" 167 | }, 168 | "outputs": [ 169 | { 170 | "output_type": "execute_result", 171 | "data": { 172 | "text/plain": [ 173 | "Eq(y(0), 1)" 174 | ], 175 | "text/latex": "$\\displaystyle y{\\left(0 \\right)} = 1$" 176 | }, 177 | "metadata": {}, 178 | "execution_count": 6 179 | } 180 | ], 181 | "source": [ 182 | "sp.Eq(f(t).subs(t,0),1)" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": 7, 188 | "metadata": { 189 | "id": "r9KVq1yjYfld", 190 | "colab": { 191 | "base_uri": "https://localhost:8080/", 192 | "height": 82 193 | }, 194 | "outputId": "c8ef8212-caed-48a7-fbe2-1a0bc700570e" 195 | }, 196 | "outputs": [ 197 | { 198 | "output_type": "execute_result", 199 | "data": { 200 | "text/plain": [ 201 | "Eq(y(t), (2*pi*t*exp(t)*sin(2*pi*t) + 8*pi**3*t*exp(t)*sin(2*pi*t) - 16*pi**4*t*exp(t)*cos(2*pi*t) - 4*pi**2*t*exp(t)*cos(2*pi*t) + 16*pi**3*exp(t)*sin(2*pi*t) + exp(t)*cos(2*pi*t) + 12*pi**2*exp(t)*cos(2*pi*t) - exp(t) + 36*pi**2*exp(t) + 336*pi**4*exp(t) + 704*pi**6*exp(t) - 640*pi**6 - 304*pi**4 - 44*pi**2)*exp(-t)/(4*pi**2*(1 + 8*pi**2 + 16*pi**4)))" 202 | ], 203 | "text/latex": "$\\displaystyle y{\\left(t \\right)} = \\frac{\\left(2 \\pi t e^{t} \\sin{\\left(2 \\pi t \\right)} + 8 \\pi^{3} t e^{t} \\sin{\\left(2 \\pi t \\right)} - 16 \\pi^{4} t e^{t} \\cos{\\left(2 \\pi t \\right)} - 4 \\pi^{2} t e^{t} \\cos{\\left(2 \\pi t \\right)} + 16 \\pi^{3} e^{t} \\sin{\\left(2 \\pi t \\right)} + e^{t} \\cos{\\left(2 \\pi t \\right)} + 12 \\pi^{2} e^{t} \\cos{\\left(2 \\pi t \\right)} - 
e^{t} + 36 \\pi^{2} e^{t} + 336 \\pi^{4} e^{t} + 704 \\pi^{6} e^{t} - 640 \\pi^{6} - 304 \\pi^{4} - 44 \\pi^{2}\\right) e^{- t}}{4 \\pi^{2} \\cdot \\left(1 + 8 \\pi^{2} + 16 \\pi^{4}\\right)}$" 204 | }, 205 | "metadata": {}, 206 | "execution_count": 7 207 | } 208 | ], 209 | "source": [ 210 | "sp.Eq(f(t),sol)" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": 8, 216 | "metadata": { 217 | "id": "MNVOpPyCW-GU", 218 | "colab": { 219 | "base_uri": "https://localhost:8080/", 220 | "height": 37 221 | }, 222 | "outputId": "a477b9e6-cad4-4584-e926-d3e8c9538160" 223 | }, 224 | "outputs": [ 225 | { 226 | "output_type": "execute_result", 227 | "data": { 228 | "text/plain": [ 229 | "0" 230 | ], 231 | "text/latex": "$\\displaystyle 0$" 232 | }, 233 | "metadata": {}, 234 | "execution_count": 8 235 | } 236 | ], 237 | "source": [ 238 | "#verify solution\n", 239 | "sp.simplify(-t*sp.cos(sp.pi*2*t)+sol.diff(t)+sol.diff(t,t))" 240 | ] 241 | }, 242 | { 243 | "cell_type": "markdown", 244 | "metadata": { 245 | "id": "NQ61lEQeXgrc" 246 | }, 247 | "source": [ 248 | "### Constructing the MLP" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": 9, 254 | "metadata": { 255 | "id": "Lml6PGLPZgmr", 256 | "colab": { 257 | "base_uri": "https://localhost:8080/" 258 | }, 259 | "outputId": "ad72cbc0-73f6-4a44-d384-3e2de933faf2" 260 | }, 261 | "outputs": [ 262 | { 263 | "output_type": "stream", 264 | "name": "stderr", 265 | "text": [ 266 | "WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. 
(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" 267 | ] 268 | } 269 | ], 270 | "source": [ 271 | "# construct data\n", 272 | "\n", 273 | "N_b = 1\n", 274 | "N_c = 100\n", 275 | "\n", 276 | "tmin, tmax = 0.0, jnp.pi\n", 277 | "\n", 278 | "\"\"\"boundary conditions\"\"\"\n", 279 | "\n", 280 | "\n", 281 | "# U[0] = 1\n", 282 | "t_0 = jnp.ones([N_b, 1], dtype=\"float32\") * 0.0\n", 283 | "ic_0 = jnp.ones_like(t_0)\n", 284 | "IC_0 = jnp.concatenate([t_0, ic_0], axis=1)\n", 285 | "\n", 286 | "# U_t[0] = 10\n", 287 | "t_b1 = jnp.zeros([N_b, 1])\n", 288 | "bc_1 = jnp.ones_like(t_b1) * 10\n", 289 | "BC_1 = jnp.concatenate([t_b1, bc_1], axis=1)\n", 290 | "\n", 291 | "conds: list[jax.Array] = [IC_0, BC_1]\n", 292 | "\n", 293 | "# collocation points\n", 294 | "\n", 295 | "key = jax.random.PRNGKey(0)\n", 296 | "\n", 297 | "t_c = jax.random.uniform(key, minval=tmin, maxval=tmax, shape=(N_c, 1))\n", 298 | "colloc = t_c" 299 | ] 300 | }, 301 | { 302 | "cell_type": "markdown", 303 | "source": [ 304 | "# Build Model" 305 | ], 306 | "metadata": { 307 | "id": "V_gR3d4AKHxJ" 308 | } 309 | }, 310 | { 311 | "cell_type": "code", 312 | "source": [ 313 | "init_func = jax.nn.initializers.glorot_uniform()\n", 314 | "\n", 315 | "\n", 316 | "class Linear(pytc.TreeClass):\n", 317 | " def __init__(\n", 318 | " self,\n", 319 | " in_features: int,\n", 320 | " out_features: int,\n", 321 | " key: jax.random.KeyArray = jax.random.PRNGKey(0),\n", 322 | " ):\n", 323 | " self.weight = init_func(key, (in_features, out_features))\n", 324 | " self.bias = jax.numpy.zeros((out_features,))\n", 325 | "\n", 326 | " def __call__(self, x: jax.Array) -> jax.Array:\n", 327 | " return x @ self.weight + self.bias\n", 328 | "\n", 329 | "\n", 330 | "class MLP(pytc.TreeClass):\n", 331 | " def __init__(self, key: jax.random.KeyArray = jax.random.PRNGKey(0)):\n", 332 | " k1, k2, k3, k4 = jax.random.split(key, 4)\n", 333 | " self.l1 = Linear(1, 20, key=k1)\n", 334 | " self.l2 = Linear(20, 20, key=k2)\n", 335 | " 
self.l3 = Linear(20, 20, key=k3)\n", 336 | " self.l4 = Linear(20, 1, key=k4)\n", 337 | "\n", 338 | " def __call__(self, x: jax.Array) -> jax.Array:\n", 339 | " x = self.l1(x)\n", 340 | " x = jax.nn.tanh(x)\n", 341 | " x = self.l2(x)\n", 342 | " x = jax.nn.tanh(x)\n", 343 | " x = self.l3(x)\n", 344 | " x = jax.nn.tanh(x)\n", 345 | " x = self.l4(x)\n", 346 | " return x\n", 347 | "\n", 348 | "\n", 349 | "model = MLP()\n", 350 | "print(pytc.tree_summary(model))" 351 | ], 352 | "metadata": { 353 | "id": "TD7IQp70F65_", 354 | "colab": { 355 | "base_uri": "https://localhost:8080/" 356 | }, 357 | "outputId": "0fdcd611-098e-4b46-b27f-33d0da1c725d" 358 | }, 359 | "execution_count": 10, 360 | "outputs": [ 361 | { 362 | "output_type": "stream", 363 | "name": "stdout", 364 | "text": [ 365 | "┌──────────┬──────────┬─────┬──────┐\n", 366 | "│Name │Type │Count│Size │\n", 367 | "├──────────┼──────────┼─────┼──────┤\n", 368 | "│.l1.weight│f32[1,20] │20 │80.00B│\n", 369 | "├──────────┼──────────┼─────┼──────┤\n", 370 | "│.l1.bias │f32[20] │20 │80.00B│\n", 371 | "├──────────┼──────────┼─────┼──────┤\n", 372 | "│.l2.weight│f32[20,20]│400 │1.56KB│\n", 373 | "├──────────┼──────────┼─────┼──────┤\n", 374 | "│.l2.bias │f32[20] │20 │80.00B│\n", 375 | "├──────────┼──────────┼─────┼──────┤\n", 376 | "│.l3.weight│f32[20,20]│400 │1.56KB│\n", 377 | "├──────────┼──────────┼─────┼──────┤\n", 378 | "│.l3.bias │f32[20] │20 │80.00B│\n", 379 | "├──────────┼──────────┼─────┼──────┤\n", 380 | "│.l4.weight│f32[20,1] │20 │80.00B│\n", 381 | "├──────────┼──────────┼─────┼──────┤\n", 382 | "│.l4.bias │f32[1] │1 │4.00B │\n", 383 | "├──────────┼──────────┼─────┼──────┤\n", 384 | "│Σ │MLP │901 │3.52KB│\n", 385 | "└──────────┴──────────┴─────┴──────┘\n" 386 | ] 387 | } 388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "execution_count": 11, 393 | "metadata": { 394 | "id": "KoZZJl2TbI_n" 395 | }, 396 | "outputs": [], 397 | "source": [ 398 | "def mse(true, pred):\n", 399 | " return jnp.mean((true - pred) 
** 2)\n", 400 | "\n", 401 | "\n", 402 | "def diff(func: Callable, *args, **kwargs):\n", 403 | " \"\"\"sum then grad\"\"\"\n", 404 | " return jax.grad(lambda *ar, **kws: jnp.sum(func(*ar, **kws)), *args, **kwargs)\n", 405 | "\n", 406 | "\n", 407 | "def ode_loss(t, u):\n", 408 | " u_t = diff(u)\n", 409 | " u_tt = diff(u_t)\n", 410 | " return -t * jnp.cos(2 * jnp.pi * t) + u_t(t) + u_tt(t)\n", 411 | "\n", 412 | "\n", 413 | "def loss_func(model, colloc, conds):\n", 414 | " t_c = colloc[:, [0]]\n", 415 | " ufunc = model\n", 416 | " ufunc_t = diff(model)\n", 417 | "\n", 418 | " loss = jnp.mean(ode_loss(t_c, ufunc) ** 2)\n", 419 | "\n", 420 | " t_ic, u_ic = conds[0][:, [0]], conds[0][:, [1]]\n", 421 | " loss += mse(u_ic, ufunc(t_ic))\n", 422 | "\n", 423 | " t_bc, u_bc = conds[1][:, [0]], conds[1][:, [1]]\n", 424 | " loss += mse(u_bc, ufunc_t(t_bc))\n", 425 | "\n", 426 | " return loss\n", 427 | "\n", 428 | "\n", 429 | "optim = optax.adam(1e-3)\n", 430 | "optim_state = optim.init(model)\n", 431 | "\n", 432 | "\n", 433 | "@jax.jit\n", 434 | "def train_step(\n", 435 | " model: MLP,\n", 436 | " optim_state: optax.OptState,\n", 437 | " colloc: jax.Array,\n", 438 | " conds: list[jax.Array],\n", 439 | "):\n", 440 | " # Get the gradient w.r.t to MLP params\n", 441 | " grads = jax.grad(loss_func)(model, colloc, conds)\n", 442 | "\n", 443 | " # Update model\n", 444 | " updates, optim_state = optim.update(grads, optim_state)\n", 445 | " model = optax.apply_updates(model, updates)\n", 446 | "\n", 447 | " return model, optim_state" 448 | ] 449 | }, 450 | { 451 | "cell_type": "code", 452 | "execution_count": 12, 453 | "metadata": { 454 | "id": "kBzGA8OVc8C6", 455 | "colab": { 456 | "base_uri": "https://localhost:8080/", 457 | "height": 1000, 458 | "referenced_widgets": [ 459 | "271df676bcac4937bada71edccf32886", 460 | "ee8a741f598f4dc7b751f21b2e975bf7", 461 | "cad59966e8754537bfb7f8ef5d8289ac", 462 | "949e74fbeb5d4892a5b3b47cacfde5c5", 463 | "8189284ea2fc475192db86f14b63580a", 464 | 
"2285fe8ee81e45229e30ddd255357255", 465 | "5f82eb29bbd64344b93a1f9565397184", 466 | "b47d9843e3894611a631df2b46cc9ac0", 467 | "0df7710f34444afbba03b6232de3e1fd", 468 | "e193d91f0e414ec7b527d9c0ebd47558", 469 | "43561c163e1f4dec929eac0ead24d266" 470 | ] 471 | }, 472 | "outputId": "8e5ce4f3-6102-435b-ec4b-62d519d6178d" 473 | }, 474 | "outputs": [ 475 | { 476 | "output_type": "display_data", 477 | "data": { 478 | "text/plain": [ 479 | " 0%| | 0/10000 [00:00" 640 | ] 641 | }, 642 | "metadata": {}, 643 | "execution_count": 13 644 | }, 645 | { 646 | "output_type": "display_data", 647 | "data": { 648 | "text/plain": [ 649 | "
" 650 | ], 651 | "image/png": "\n" 652 | }, 653 | "metadata": {} 654 | } 655 | ] 656 | } 657 | ], 658 | "metadata": { 659 | "colab": { 660 | "name": "[1] ODE-PINN-ClassForm.ipynb", 661 | "provenance": [], 662 | "include_colab_link": true 663 | }, 664 | "kernelspec": { 665 | "display_name": "Python 3.8.9 64-bit", 666 | "language": "python", 667 | "name": "python3" 668 | }, 669 | "language_info": { 670 | "codemirror_mode": { 671 | "name": "ipython", 672 | "version": 3 673 | }, 674 | "file_extension": ".py", 675 | "mimetype": "text/x-python", 676 | "name": "python", 677 | "nbconvert_exporter": "python", 678 | "pygments_lexer": "ipython3", 679 | "version": "3.8.9" 680 | }, 681 | "vscode": { 682 | "interpreter": { 683 | "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" 684 | } 685 | }, 686 | "widgets": { 687 | "application/vnd.jupyter.widget-state+json": { 688 | "271df676bcac4937bada71edccf32886": { 689 | "model_module": "@jupyter-widgets/controls", 690 | "model_name": "HBoxModel", 691 | "model_module_version": "1.5.0", 692 | "state": { 693 | "_dom_classes": [], 694 | "_model_module": "@jupyter-widgets/controls", 695 | "_model_module_version": "1.5.0", 696 | "_model_name": "HBoxModel", 697 | "_view_count": null, 698 | "_view_module": "@jupyter-widgets/controls", 699 | "_view_module_version": "1.5.0", 700 | "_view_name": "HBoxView", 701 | "box_style": "", 702 | "children": [ 703 | "IPY_MODEL_ee8a741f598f4dc7b751f21b2e975bf7", 704 | "IPY_MODEL_cad59966e8754537bfb7f8ef5d8289ac", 705 | "IPY_MODEL_949e74fbeb5d4892a5b3b47cacfde5c5" 706 | ], 707 | "layout": "IPY_MODEL_8189284ea2fc475192db86f14b63580a" 708 | } 709 | }, 710 | "ee8a741f598f4dc7b751f21b2e975bf7": { 711 | "model_module": "@jupyter-widgets/controls", 712 | "model_name": "HTMLModel", 713 | "model_module_version": "1.5.0", 714 | "state": { 715 | "_dom_classes": [], 716 | "_model_module": "@jupyter-widgets/controls", 717 | "_model_module_version": "1.5.0", 718 | "_model_name": "HTMLModel", 719 
| "_view_count": null, 720 | "_view_module": "@jupyter-widgets/controls", 721 | "_view_module_version": "1.5.0", 722 | "_view_name": "HTMLView", 723 | "description": "", 724 | "description_tooltip": null, 725 | "layout": "IPY_MODEL_2285fe8ee81e45229e30ddd255357255", 726 | "placeholder": "​", 727 | "style": "IPY_MODEL_5f82eb29bbd64344b93a1f9565397184", 728 | "value": "100%" 729 | } 730 | }, 731 | "cad59966e8754537bfb7f8ef5d8289ac": { 732 | "model_module": "@jupyter-widgets/controls", 733 | "model_name": "FloatProgressModel", 734 | "model_module_version": "1.5.0", 735 | "state": { 736 | "_dom_classes": [], 737 | "_model_module": "@jupyter-widgets/controls", 738 | "_model_module_version": "1.5.0", 739 | "_model_name": "FloatProgressModel", 740 | "_view_count": null, 741 | "_view_module": "@jupyter-widgets/controls", 742 | "_view_module_version": "1.5.0", 743 | "_view_name": "ProgressView", 744 | "bar_style": "success", 745 | "description": "", 746 | "description_tooltip": null, 747 | "layout": "IPY_MODEL_b47d9843e3894611a631df2b46cc9ac0", 748 | "max": 10000, 749 | "min": 0, 750 | "orientation": "horizontal", 751 | "style": "IPY_MODEL_0df7710f34444afbba03b6232de3e1fd", 752 | "value": 10000 753 | } 754 | }, 755 | "949e74fbeb5d4892a5b3b47cacfde5c5": { 756 | "model_module": "@jupyter-widgets/controls", 757 | "model_name": "HTMLModel", 758 | "model_module_version": "1.5.0", 759 | "state": { 760 | "_dom_classes": [], 761 | "_model_module": "@jupyter-widgets/controls", 762 | "_model_module_version": "1.5.0", 763 | "_model_name": "HTMLModel", 764 | "_view_count": null, 765 | "_view_module": "@jupyter-widgets/controls", 766 | "_view_module_version": "1.5.0", 767 | "_view_name": "HTMLView", 768 | "description": "", 769 | "description_tooltip": null, 770 | "layout": "IPY_MODEL_e193d91f0e414ec7b527d9c0ebd47558", 771 | "placeholder": "​", 772 | "style": "IPY_MODEL_43561c163e1f4dec929eac0ead24d266", 773 | "value": " 10000/10000 [00:29<00:00, 886.65it/s]" 774 | } 775 | }, 776 | 
"8189284ea2fc475192db86f14b63580a": { 777 | "model_module": "@jupyter-widgets/base", 778 | "model_name": "LayoutModel", 779 | "model_module_version": "1.2.0", 780 | "state": { 781 | "_model_module": "@jupyter-widgets/base", 782 | "_model_module_version": "1.2.0", 783 | "_model_name": "LayoutModel", 784 | "_view_count": null, 785 | "_view_module": "@jupyter-widgets/base", 786 | "_view_module_version": "1.2.0", 787 | "_view_name": "LayoutView", 788 | "align_content": null, 789 | "align_items": null, 790 | "align_self": null, 791 | "border": null, 792 | "bottom": null, 793 | "display": null, 794 | "flex": null, 795 | "flex_flow": null, 796 | "grid_area": null, 797 | "grid_auto_columns": null, 798 | "grid_auto_flow": null, 799 | "grid_auto_rows": null, 800 | "grid_column": null, 801 | "grid_gap": null, 802 | "grid_row": null, 803 | "grid_template_areas": null, 804 | "grid_template_columns": null, 805 | "grid_template_rows": null, 806 | "height": null, 807 | "justify_content": null, 808 | "justify_items": null, 809 | "left": null, 810 | "margin": null, 811 | "max_height": null, 812 | "max_width": null, 813 | "min_height": null, 814 | "min_width": null, 815 | "object_fit": null, 816 | "object_position": null, 817 | "order": null, 818 | "overflow": null, 819 | "overflow_x": null, 820 | "overflow_y": null, 821 | "padding": null, 822 | "right": null, 823 | "top": null, 824 | "visibility": null, 825 | "width": null 826 | } 827 | }, 828 | "2285fe8ee81e45229e30ddd255357255": { 829 | "model_module": "@jupyter-widgets/base", 830 | "model_name": "LayoutModel", 831 | "model_module_version": "1.2.0", 832 | "state": { 833 | "_model_module": "@jupyter-widgets/base", 834 | "_model_module_version": "1.2.0", 835 | "_model_name": "LayoutModel", 836 | "_view_count": null, 837 | "_view_module": "@jupyter-widgets/base", 838 | "_view_module_version": "1.2.0", 839 | "_view_name": "LayoutView", 840 | "align_content": null, 841 | "align_items": null, 842 | "align_self": null, 843 | "border": 
null, 844 | "bottom": null, 845 | "display": null, 846 | "flex": null, 847 | "flex_flow": null, 848 | "grid_area": null, 849 | "grid_auto_columns": null, 850 | "grid_auto_flow": null, 851 | "grid_auto_rows": null, 852 | "grid_column": null, 853 | "grid_gap": null, 854 | "grid_row": null, 855 | "grid_template_areas": null, 856 | "grid_template_columns": null, 857 | "grid_template_rows": null, 858 | "height": null, 859 | "justify_content": null, 860 | "justify_items": null, 861 | "left": null, 862 | "margin": null, 863 | "max_height": null, 864 | "max_width": null, 865 | "min_height": null, 866 | "min_width": null, 867 | "object_fit": null, 868 | "object_position": null, 869 | "order": null, 870 | "overflow": null, 871 | "overflow_x": null, 872 | "overflow_y": null, 873 | "padding": null, 874 | "right": null, 875 | "top": null, 876 | "visibility": null, 877 | "width": null 878 | } 879 | }, 880 | "5f82eb29bbd64344b93a1f9565397184": { 881 | "model_module": "@jupyter-widgets/controls", 882 | "model_name": "DescriptionStyleModel", 883 | "model_module_version": "1.5.0", 884 | "state": { 885 | "_model_module": "@jupyter-widgets/controls", 886 | "_model_module_version": "1.5.0", 887 | "_model_name": "DescriptionStyleModel", 888 | "_view_count": null, 889 | "_view_module": "@jupyter-widgets/base", 890 | "_view_module_version": "1.2.0", 891 | "_view_name": "StyleView", 892 | "description_width": "" 893 | } 894 | }, 895 | "b47d9843e3894611a631df2b46cc9ac0": { 896 | "model_module": "@jupyter-widgets/base", 897 | "model_name": "LayoutModel", 898 | "model_module_version": "1.2.0", 899 | "state": { 900 | "_model_module": "@jupyter-widgets/base", 901 | "_model_module_version": "1.2.0", 902 | "_model_name": "LayoutModel", 903 | "_view_count": null, 904 | "_view_module": "@jupyter-widgets/base", 905 | "_view_module_version": "1.2.0", 906 | "_view_name": "LayoutView", 907 | "align_content": null, 908 | "align_items": null, 909 | "align_self": null, 910 | "border": null, 911 | 
"bottom": null, 912 | "display": null, 913 | "flex": null, 914 | "flex_flow": null, 915 | "grid_area": null, 916 | "grid_auto_columns": null, 917 | "grid_auto_flow": null, 918 | "grid_auto_rows": null, 919 | "grid_column": null, 920 | "grid_gap": null, 921 | "grid_row": null, 922 | "grid_template_areas": null, 923 | "grid_template_columns": null, 924 | "grid_template_rows": null, 925 | "height": null, 926 | "justify_content": null, 927 | "justify_items": null, 928 | "left": null, 929 | "margin": null, 930 | "max_height": null, 931 | "max_width": null, 932 | "min_height": null, 933 | "min_width": null, 934 | "object_fit": null, 935 | "object_position": null, 936 | "order": null, 937 | "overflow": null, 938 | "overflow_x": null, 939 | "overflow_y": null, 940 | "padding": null, 941 | "right": null, 942 | "top": null, 943 | "visibility": null, 944 | "width": null 945 | } 946 | }, 947 | "0df7710f34444afbba03b6232de3e1fd": { 948 | "model_module": "@jupyter-widgets/controls", 949 | "model_name": "ProgressStyleModel", 950 | "model_module_version": "1.5.0", 951 | "state": { 952 | "_model_module": "@jupyter-widgets/controls", 953 | "_model_module_version": "1.5.0", 954 | "_model_name": "ProgressStyleModel", 955 | "_view_count": null, 956 | "_view_module": "@jupyter-widgets/base", 957 | "_view_module_version": "1.2.0", 958 | "_view_name": "StyleView", 959 | "bar_color": null, 960 | "description_width": "" 961 | } 962 | }, 963 | "e193d91f0e414ec7b527d9c0ebd47558": { 964 | "model_module": "@jupyter-widgets/base", 965 | "model_name": "LayoutModel", 966 | "model_module_version": "1.2.0", 967 | "state": { 968 | "_model_module": "@jupyter-widgets/base", 969 | "_model_module_version": "1.2.0", 970 | "_model_name": "LayoutModel", 971 | "_view_count": null, 972 | "_view_module": "@jupyter-widgets/base", 973 | "_view_module_version": "1.2.0", 974 | "_view_name": "LayoutView", 975 | "align_content": null, 976 | "align_items": null, 977 | "align_self": null, 978 | "border": null, 979 | 
"bottom": null, 980 | "display": null, 981 | "flex": null, 982 | "flex_flow": null, 983 | "grid_area": null, 984 | "grid_auto_columns": null, 985 | "grid_auto_flow": null, 986 | "grid_auto_rows": null, 987 | "grid_column": null, 988 | "grid_gap": null, 989 | "grid_row": null, 990 | "grid_template_areas": null, 991 | "grid_template_columns": null, 992 | "grid_template_rows": null, 993 | "height": null, 994 | "justify_content": null, 995 | "justify_items": null, 996 | "left": null, 997 | "margin": null, 998 | "max_height": null, 999 | "max_width": null, 1000 | "min_height": null, 1001 | "min_width": null, 1002 | "object_fit": null, 1003 | "object_position": null, 1004 | "order": null, 1005 | "overflow": null, 1006 | "overflow_x": null, 1007 | "overflow_y": null, 1008 | "padding": null, 1009 | "right": null, 1010 | "top": null, 1011 | "visibility": null, 1012 | "width": null 1013 | } 1014 | }, 1015 | "43561c163e1f4dec929eac0ead24d266": { 1016 | "model_module": "@jupyter-widgets/controls", 1017 | "model_name": "DescriptionStyleModel", 1018 | "model_module_version": "1.5.0", 1019 | "state": { 1020 | "_model_module": "@jupyter-widgets/controls", 1021 | "_model_module_version": "1.5.0", 1022 | "_model_name": "DescriptionStyleModel", 1023 | "_view_count": null, 1024 | "_view_module": "@jupyter-widgets/base", 1025 | "_view_module_version": "1.2.0", 1026 | "_view_name": "StyleView", 1027 | "description_width": "" 1028 | } 1029 | } 1030 | } 1031 | } 1032 | }, 1033 | "nbformat": 4, 1034 | "nbformat_minor": 0 1035 | } -------------------------------------------------------------------------------- /[5]_System_of_ODEs_PINN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "colab_type": "text", 7 | "id": "view-in-github" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "metadata": { 17 | "id": 
"v77fdC1ZLyg1" 18 | }, 19 | "outputs": [], 20 | "source": [ 21 | "#Credits : Mahmoud Asem @Asem000 October 2021" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "metadata": { 28 | "colab": { 29 | "base_uri": "https://localhost:8080/" 30 | }, 31 | "id": "vAR0swbLX_ZI", 32 | "outputId": "97823711-ee53-4921-bf97-c2f9a312e57f" 33 | }, 34 | "outputs": [ 35 | { 36 | "name": "stdout", 37 | "output_type": "stream", 38 | "text": [ 39 | "Requirement already satisfied: optax in /usr/local/lib/python3.7/dist-packages (0.0.9)\n", 40 | "Requirement already satisfied: absl-py>=0.7.1 in /usr/local/lib/python3.7/dist-packages (from optax) (0.12.0)\n", 41 | "Requirement already satisfied: jax>=0.1.55 in /usr/local/lib/python3.7/dist-packages (from optax) (0.2.21)\n", 42 | "Requirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.7/dist-packages (from optax) (1.19.5)\n", 43 | "Requirement already satisfied: chex>=0.0.4 in /usr/local/lib/python3.7/dist-packages (from optax) (0.0.8)\n", 44 | "Requirement already satisfied: jaxlib>=0.1.37 in /usr/local/lib/python3.7/dist-packages (from optax) (0.1.71+cuda111)\n", 45 | "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py>=0.7.1->optax) (1.15.0)\n", 46 | "Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax) (0.11.1)\n", 47 | "Requirement already satisfied: dm-tree>=0.1.5 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax) (0.1.6)\n", 48 | "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax>=0.1.55->optax) (3.3.0)\n", 49 | "Requirement already satisfied: scipy>=1.2.1 in /usr/local/lib/python3.7/dist-packages (from jax>=0.1.55->optax) (1.4.1)\n", 50 | "Requirement already satisfied: flatbuffers<3.0,>=1.12 in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.37->optax) (1.12)\n", 51 | "Requirement already satisfied: numba in 
/usr/local/lib/python3.7/dist-packages (0.51.2)\n", 52 | "Requirement already satisfied: numpy>=1.15 in /usr/local/lib/python3.7/dist-packages (from numba) (1.19.5)\n", 53 | "Requirement already satisfied: setuptools in /usr/local/lib/python3.7/dist-packages (from numba) (57.4.0)\n", 54 | "Requirement already satisfied: llvmlite<0.35,>=0.34.0.dev0 in /usr/local/lib/python3.7/dist-packages (from numba) (0.34.0)\n" 55 | ] 56 | } 57 | ], 58 | "source": [ 59 | "#Imports\n", 60 | "import jax \n", 61 | "import jax.numpy as jnp\n", 62 | "import numpy as np\n", 63 | "import matplotlib.pyplot as plt\n", 64 | "from matplotlib import cm\n", 65 | "import matplotlib as mpl\n", 66 | "!pip install optax\n", 67 | "import optax\n", 68 | "!pip install numba\n", 69 | "import numba" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "metadata": { 75 | "id": "7bg4nSbsXVwD" 76 | }, 77 | "source": [ 78 | "### System of ODEs numerical solution\n", 79 | "$\\large \\frac{dx}{dt} = x$
\n", 80 | "\n", 81 | "$\\large \\frac{dy}{dt} = x - y$
\n", 82 | "\n", 83 | "$x(t=0) = 1$\n", 84 | "\n", 85 | "$y(t=0) = 2$\n", 86 | "\n", 87 | "
\n", 88 | "$\\text{analytical solution}$\n", 89 | "\n", 90 | "$x(t) = e^{t}$\n", 91 | "\n", 92 | "$y(t) = \\frac{1}{2} e^{t} + \\frac{3}{2} e^{-t}$" 93 | ] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "metadata": { 98 | "id": "K89DstaYpwh0" 99 | }, 100 | "source": [] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 3, 105 | "metadata": { 106 | "id": "8uW0HW1-pxl8" 107 | }, 108 | "outputs": [], 109 | "source": [ 110 | "'''\n", 111 | "\n", 112 | "solve \n", 113 | "dx/dt = x\n", 114 | "dy/dt = x-y\n", 115 | "\n", 116 | "x(0) = 1\n", 117 | "y(0) = 2\n", 118 | "\n", 119 | "solution =\n", 120 | "\n", 121 | "x(t) = exp(t)\n", 122 | "y(t) = 0.5*exp(t) +1.5*exp(-t)\n", 123 | "\n", 124 | "'''\n", 125 | "\n", 126 | "@numba.njit\n", 127 | "def RK4(odefun,ics,h,span,degree):\n", 128 | " \n", 129 | " N= int( (span[1]-span[0])/h )\n", 130 | " \n", 131 | " tY = np.zeros((N+1,degree+1))\n", 132 | " tY[0,1:] = ics\n", 133 | " \n", 134 | " \n", 135 | " for i in range(N):\n", 136 | " tY[i+1,0] = tY[i,0] + h\n", 137 | "\n", 138 | " k1= odefun(tY[i,0] , tY[i,1:])\n", 139 | " k2= odefun(tY[i,0] +(h/2), tY[i,1:] +(h*k1)/2 )\n", 140 | " k3= odefun(tY[i,0] +(h/2), tY[i,1:] +(h*k2)/2)\n", 141 | " k4= odefun(tY[i,0] +(h) , tY[i,1:] +(h*k3))\n", 142 | " \n", 143 | " tY[i+1,1:] = tY[i,1:] + h*(1/6) * (k1+2*k2+2*k3+k4)\n", 144 | " \n", 145 | " return tY[:,0],tY[:,1:]\n", 146 | "\n", 147 | "@numba.njit\n", 148 | "def system_of_ode(t,V):\n", 149 | " y1,y2 = V[0],V[1]\n", 150 | " return np.array([y1,y1-y2])\n", 151 | "\n", 152 | "t,y=RK4(system_of_ode,\n", 153 | " ics=np.array([1,2]),\n", 154 | " h=1e-3,\n", 155 | " span=np.array([1e-4,np.pi]),\n", 156 | " degree =2)" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": 4, 162 | "metadata": { 163 | "colab": { 164 | "base_uri": "https://localhost:8080/", 165 | "height": 282 166 | }, 167 | "id": "bNctv_EFp-yD", 168 | "outputId": "9ede81ec-108c-4d4b-f928-d0c2493aea0c" 169 | }, 170 | "outputs": [ 171 | { 
172 | "data": { 173 | "text/plain": [ 174 | "" 175 | ] 176 | }, 177 | "execution_count": 4, 178 | "metadata": {}, 179 | "output_type": "execute_result" 180 | }, 181 | { 182 | "data": { 183 | "image/png": "", 184 | "text/plain": [ 185 | "
" 186 | ] 187 | }, 188 | "metadata": { 189 | "needs_background": "light" 190 | }, 191 | "output_type": "display_data" 192 | } 193 | ], 194 | "source": [ 195 | "plt.plot(t,y[:,0],'-r',label='RK4[1]')\n", 196 | "plt.plot(t,np.exp(t),'--k',label='Analytical[1]')\n", 197 | "\n", 198 | "plt.plot(t,y[:,1],'-g',label='RK4[2]')\n", 199 | "plt.plot(t,0.5*np.exp(t)+1.5*np.exp(-t),'--b',label='Analytical[2]')\n", 200 | "\n", 201 | "plt.legend()" 202 | ] 203 | }, 204 | { 205 | "cell_type": "markdown", 206 | "metadata": { 207 | "id": "NQ61lEQeXgrc" 208 | }, 209 | "source": [ 210 | "### Constructing the MLP" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": 5, 216 | "metadata": { 217 | "colab": { 218 | "base_uri": "https://localhost:8080/" 219 | }, 220 | "id": "Lml6PGLPZgmr", 221 | "outputId": "aaa25445-8e2d-4540-8cf3-20d6e5404fc8" 222 | }, 223 | "outputs": [ 224 | { 225 | "name": "stderr", 226 | "output_type": "stream", 227 | "text": [ 228 | "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" 229 | ] 230 | } 231 | ], 232 | "source": [ 233 | "N_b = 1_000\n", 234 | "N_c = 10_000\n", 235 | "\n", 236 | "tmin,tmax=0. 
,jnp.pi\n", 237 | "\n", 238 | "'''boundary conditions'''\n", 239 | "\n", 240 | "\n", 241 | "# y1[0] = 1\n", 242 | "y1_t0 = jnp.zeros([N_b,1],dtype='float32')\n", 243 | "y1_ic = jnp.ones_like(y1_t0) \n", 244 | "Y1_IC = jnp.concatenate([y1_t0,y1_ic],axis=1)\n", 245 | "\n", 246 | "# y2[0] = 2\n", 247 | "y2_t0 = jnp.zeros([N_b,1],dtype='float32')\n", 248 | "y2_ic = jnp.ones_like(y2_t0) * 2\n", 249 | "Y2_IC = jnp.concatenate([y2_t0,y2_ic],axis=1)\n", 250 | "\n", 251 | "conds = [Y1_IC,Y2_IC]\n", 252 | "\n", 253 | "#collocation points\n", 254 | "\n", 255 | "key=jax.random.PRNGKey(0)\n", 256 | "\n", 257 | "t_c = jax.random.uniform(key,minval=tmin,maxval=tmax,shape=(N_c,1))\n", 258 | "colloc = t_c\n", 259 | "\n", 260 | "def ODE_loss(t,y1,y2):\n", 261 | "\n", 262 | " y1_t=lambda t:jax.grad(lambda t:jnp.sum(y1(t)))(t)\n", 263 | " y2_t=lambda t:jax.grad(lambda t:jnp.sum(y2(t)))(t)\n", 264 | "\n", 265 | " return y1_t(t) - y1(t) , y2_t(t) - y1(t) + y2(t)\n" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": 6, 271 | "metadata": { 272 | "id": "KoZZJl2TbI_n" 273 | }, 274 | "outputs": [], 275 | "source": [ 276 | "def init_params(layers):\n", 277 | " keys = jax.random.split(jax.random.PRNGKey(0),len(layers)-1)\n", 278 | " params = list()\n", 279 | " for key,n_in,n_out in zip(keys,layers[:-1],layers[1:]):\n", 280 | " lb, ub = -(1 / jnp.sqrt(n_in)), (1 / jnp.sqrt(n_in)) # fan-in scaled uniform bounds +-1/sqrt(n_in) (classic heuristic, not Glorot/Xavier)\n", 281 | " W = lb + (ub-lb) * jax.random.uniform(key,shape=(n_in,n_out))\n", 282 | " B = jax.random.uniform(key,shape=(n_out,))\n", 283 | " params.append({'W':W,'B':B})\n", 284 | " return params\n", 285 | "\n", 286 | "def fwd(params,t):\n", 287 | " X = jnp.concatenate([t],axis=1)\n", 288 | " *hidden,last = params\n", 289 | " for layer in hidden :\n", 290 | " X = jax.nn.tanh(X@layer['W']+layer['B'])\n", 291 | " return X@last['W'] + last['B']\n", 292 | "\n", 293 | "@jax.jit\n", 294 | "def MSE(true,pred):\n", 295 | " return
jnp.mean((true-pred)**2)\n", 296 | "\n", 297 | "def loss_fun(params,colloc,conds):\n", 298 | " t_c =colloc[:,[0]]\n", 299 | "\n", 300 | " y1_func = lambda t : fwd(params,t)[:,[0]]\n", 301 | " y1_func_t=lambda t:jax.grad(lambda t:jnp.sum(y1_func(t)))(t)\n", 302 | "\n", 303 | " y2_func = lambda t : fwd(params,t)[:,[1]]\n", 304 | " y2_func_t=lambda t:jax.grad(lambda t:jnp.sum(y2_func(t)))(t)\n", 305 | "\n", 306 | " loss_y1,loss_y2 =ODE_loss(t_c,y1_func,y2_func)\n", 307 | "\n", 308 | " loss = jnp.mean( loss_y1 **2) \n", 309 | " loss+= jnp.mean(loss_y2 **2)\n", 310 | "\n", 311 | " t_ic,y1_ic = conds[0][:,[0]],conds[0][:,[1]] \n", 312 | " loss += MSE(y1_ic,y1_func(t_ic))\n", 313 | "\n", 314 | " t_ic,y2_ic = conds[1][:,[0]],conds[1][:,[1]] \n", 315 | " loss += MSE(y2_ic,y2_func(t_ic))\n", 316 | "\n", 317 | " return loss\n", 318 | "\n", 319 | "@jax.jit\n", 320 | "def update(opt_state,params,colloc,conds):\n", 321 | " # Get the gradient w.r.t to MLP params\n", 322 | " grads=jax.jit(jax.grad(loss_fun,0))(params,colloc,conds)\n", 323 | " \n", 324 | " #Update params\n", 325 | " updates, opt_state = optimizer.update(grads, opt_state)\n", 326 | " params = optax.apply_updates(params, updates)\n", 327 | "\n", 328 | " #Update params\n", 329 | " # return jax.tree_multimap(lambda params,grads : params-LR*grads, params,grads)\n", 330 | " return opt_state,params\n" 331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": 7, 336 | "metadata": { 337 | "id": "ae1ZDoy0c29c" 338 | }, 339 | "outputs": [], 340 | "source": [ 341 | "# construct the MLP with 2 hidden layers of 8 neurons each (1 input, 2 outputs)\n", 342 | "params = init_params([1] + [8]*2+[2])" 343 | ] 344 | }, 345 | { 346 | "cell_type": "code", 347 | "execution_count": 8, 348 | "metadata": { 349 | "id": "jySmbUwic5yk" 350 | }, 351 | "outputs": [], 352 | "source": [ 353 | "optimizer = optax.adam(1e-2)\n", 354 | "opt_state = optimizer.init(params)" 355 | ] 356 | }, 357 | { 358 | "cell_type": "code", 359 |
"execution_count": 9, 360 | "metadata": { 361 | "colab": { 362 | "base_uri": "https://localhost:8080/" 363 | }, 364 | "id": "VS0bosXPg1Oo", 365 | "outputId": "258603fc-cb6e-4dcb-eb86-4560004b71e2" 366 | }, 367 | "outputs": [ 368 | { 369 | "data": { 370 | "text/plain": [ 371 | "(10000, 1)" 372 | ] 373 | }, 374 | "execution_count": 9, 375 | "metadata": {}, 376 | "output_type": "execute_result" 377 | } 378 | ], 379 | "source": [ 380 | "fwd(params,t_c)[:,[0]].shape" 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": 10, 386 | "metadata": { 387 | "colab": { 388 | "base_uri": "https://localhost:8080/" 389 | }, 390 | "id": "kBzGA8OVc8C6", 391 | "outputId": "7dda4b83-5d7e-497c-fb98-64e7b1322528" 392 | }, 393 | "outputs": [ 394 | { 395 | "name": "stdout", 396 | "output_type": "stream", 397 | "text": [ 398 | "Epoch=0\tloss=4.391e+00\n", 399 | "Epoch=100\tloss=3.921e-01\n", 400 | "Epoch=200\tloss=3.891e-01\n", 401 | "Epoch=300\tloss=3.876e-01\n", 402 | "Epoch=400\tloss=3.866e-01\n", 403 | "Epoch=500\tloss=3.855e-01\n", 404 | "Epoch=600\tloss=3.840e-01\n", 405 | "Epoch=700\tloss=3.820e-01\n", 406 | "Epoch=800\tloss=3.788e-01\n", 407 | "Epoch=900\tloss=3.726e-01\n", 408 | "Epoch=1000\tloss=3.555e-01\n", 409 | "Epoch=1100\tloss=3.070e-01\n", 410 | "Epoch=1200\tloss=2.254e-01\n", 411 | "Epoch=1300\tloss=1.834e-01\n", 412 | "Epoch=1400\tloss=1.354e-01\n", 413 | "Epoch=1500\tloss=1.085e-01\n", 414 | "Epoch=1600\tloss=8.198e-02\n", 415 | "Epoch=1700\tloss=6.454e-02\n", 416 | "Epoch=1800\tloss=5.201e-02\n", 417 | "Epoch=1900\tloss=4.232e-02\n", 418 | "Epoch=2000\tloss=3.476e-02\n", 419 | "Epoch=2100\tloss=2.897e-02\n", 420 | "Epoch=2200\tloss=2.339e-02\n", 421 | "Epoch=2300\tloss=1.980e-02\n", 422 | "Epoch=2400\tloss=1.632e-02\n", 423 | "Epoch=2500\tloss=1.425e-02\n", 424 | "Epoch=2600\tloss=1.194e-02\n", 425 | "Epoch=2700\tloss=1.174e-02\n", 426 | "Epoch=2800\tloss=9.231e-03\n", 427 | "Epoch=2900\tloss=7.939e-03\n", 428 | "Epoch=3000\tloss=7.260e-03\n", 429 
| "Epoch=3100\tloss=6.298e-03\n", 430 | "Epoch=3200\tloss=7.007e-02\n", 431 | "Epoch=3300\tloss=5.132e-03\n", 432 | "Epoch=3400\tloss=4.541e-03\n", 433 | "Epoch=3500\tloss=5.706e-03\n", 434 | "Epoch=3600\tloss=3.834e-03\n", 435 | "Epoch=3700\tloss=3.433e-03\n", 436 | "Epoch=3800\tloss=3.316e-03\n", 437 | "Epoch=3900\tloss=2.964e-03\n", 438 | "Epoch=4000\tloss=2.688e-03\n", 439 | "Epoch=4100\tloss=2.604e-03\n", 440 | "Epoch=4200\tloss=2.364e-03\n", 441 | "Epoch=4300\tloss=1.091e-01\n", 442 | "Epoch=4400\tloss=2.085e-03\n", 443 | "Epoch=4500\tloss=1.907e-03\n", 444 | "Epoch=4600\tloss=2.153e-03\n", 445 | "Epoch=4700\tloss=1.716e-03\n", 446 | "Epoch=4800\tloss=1.577e-03\n", 447 | "Epoch=4900\tloss=1.593e-03\n", 448 | "Epoch=5000\tloss=1.462e-03\n", 449 | "Epoch=5100\tloss=1.348e-03\n", 450 | "Epoch=5200\tloss=1.934e-03\n", 451 | "Epoch=5300\tloss=1.247e-03\n", 452 | "Epoch=5400\tloss=1.154e-03\n", 453 | "Epoch=5500\tloss=1.163e-03\n", 454 | "Epoch=5600\tloss=1.062e-03\n", 455 | "Epoch=5700\tloss=6.481e-03\n", 456 | "Epoch=5800\tloss=9.879e-04\n", 457 | "Epoch=5900\tloss=9.198e-04\n", 458 | "Epoch=6000\tloss=9.548e-04\n", 459 | "Epoch=6100\tloss=8.567e-04\n", 460 | "Epoch=6200\tloss=5.164e-03\n", 461 | "Epoch=6300\tloss=8.136e-04\n", 462 | "Epoch=6400\tloss=7.632e-04\n", 463 | "Epoch=6500\tloss=8.119e-04\n", 464 | "Epoch=6600\tloss=7.259e-04\n", 465 | "Epoch=6700\tloss=6.844e-04\n", 466 | "Epoch=6800\tloss=6.932e-04\n", 467 | "Epoch=6900\tloss=6.498e-04\n", 468 | "Epoch=7000\tloss=1.022e-02\n", 469 | "Epoch=7100\tloss=6.268e-04\n", 470 | "Epoch=7200\tloss=5.936e-04\n", 471 | "Epoch=7300\tloss=3.431e-03\n", 472 | "Epoch=7400\tloss=5.711e-04\n", 473 | "Epoch=7500\tloss=5.428e-04\n", 474 | "Epoch=7600\tloss=6.806e-04\n", 475 | "Epoch=7700\tloss=5.217e-04\n", 476 | "Epoch=7800\tloss=7.266e-03\n", 477 | "Epoch=7900\tloss=5.146e-04\n", 478 | "Epoch=8000\tloss=4.886e-04\n", 479 | "Epoch=8100\tloss=1.107e-03\n", 480 | "Epoch=8200\tloss=4.893e-04\n", 481 | 
"Epoch=8300\tloss=4.574e-04\n", 482 | "Epoch=8400\tloss=9.877e-02\n", 483 | "Epoch=8500\tloss=4.501e-04\n", 484 | "Epoch=8600\tloss=4.283e-04\n", 485 | "Epoch=8700\tloss=2.029e-03\n", 486 | "Epoch=8800\tloss=4.216e-04\n", 487 | "Epoch=8900\tloss=4.044e-04\n", 488 | "Epoch=9000\tloss=4.368e-04\n", 489 | "Epoch=9100\tloss=4.015e-04\n", 490 | "Epoch=9200\tloss=3.859e-04\n", 491 | "Epoch=9300\tloss=1.088e-03\n", 492 | "Epoch=9400\tloss=3.847e-04\n", 493 | "Epoch=9500\tloss=3.701e-04\n", 494 | "Epoch=9600\tloss=7.795e-04\n", 495 | "Epoch=9700\tloss=3.653e-04\n", 496 | "Epoch=9800\tloss=8.272e-02\n", 497 | "Epoch=9900\tloss=3.632e-04\n", 498 | "CPU times: user 1min 39s, sys: 2.21 s, total: 1min 41s\n", 499 | "Wall time: 1min 31s\n" 500 | ] 501 | } 502 | ], 503 | "source": [ 504 | "%%time\n", 505 | "epochs = 10_000\n", 506 | "for _ in range(epochs):\n", 507 | " opt_state,params = update(opt_state,params,colloc,conds)\n", 508 | "\n", 509 | " # print loss and epoch info\n", 510 | " if _ %(100) ==0:\n", 511 | " print(f'Epoch={_}\\tloss={loss_fun(params,colloc,conds):.3e}')" 512 | ] 513 | }, 514 | { 515 | "cell_type": "code", 516 | "execution_count": 11, 517 | "metadata": { 518 | "colab": { 519 | "base_uri": "https://localhost:8080/", 520 | "height": 282 521 | }, 522 | "id": "eWeNvDsdDEuI", 523 | "outputId": "3ca4560a-ca9f-451f-db9d-0b61e081cfa2" 524 | }, 525 | "outputs": [ 526 | { 527 | "data": { 528 | "text/plain": [ 529 | "" 530 | ] 531 | }, 532 | "execution_count": 11, 533 | "metadata": {}, 534 | "output_type": "execute_result" 535 | }, 536 | { 537 | "data": { 538 | "image/png": "", 539 | "text/plain": [ 540 | "
" 541 | ] 542 | }, 543 | "metadata": { 544 | "needs_background": "light" 545 | }, 546 | "output_type": "display_data" 547 | } 548 | ], 549 | "source": [ 550 | "dT = 1e-3\n", 551 | "Tf = jnp.pi\n", 552 | "T = np.arange(0,Tf+dT,dT)\n", 553 | "\n", 554 | "plt.plot(t,np.exp(t),'-b',label='Analytical[y1]')\n", 555 | "plt.plot(T,fwd(params,T.reshape(-1,1))[:,0],'--k',label='NN[y1]',linewidth=2)\n", 556 | "\n", 557 | "plt.plot(t,0.5*np.exp(t)+1.5*np.exp(-t),'-r',label='Analytical[y2]')\n", 558 | "plt.plot(T,fwd(params,T.reshape(-1,1))[:,1],'--k',label='NN[y2]',linewidth=2)\n", 559 | "\n", 560 | "plt.legend()" 561 | ] 562 | } 563 | ], 564 | "metadata": { 565 | "colab": { 566 | "collapsed_sections": [], 567 | "include_colab_link": true, 568 | "name": "[5] System-of-ODE-PINN.ipynb", 569 | "provenance": [] 570 | }, 571 | "kernelspec": { 572 | "display_name": "Python 3.8.9 64-bit", 573 | "language": "python", 574 | "name": "python3" 575 | }, 576 | "language_info": { 577 | "codemirror_mode": { 578 | "name": "ipython", 579 | "version": 3 580 | }, 581 | "file_extension": ".py", 582 | "mimetype": "text/x-python", 583 | "name": "python", 584 | "nbconvert_exporter": "python", 585 | "pygments_lexer": "ipython3", 586 | "version": "3.8.9" 587 | }, 588 | "vscode": { 589 | "interpreter": { 590 | "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" 591 | } 592 | } 593 | }, 594 | "nbformat": 4, 595 | "nbformat_minor": 0 596 | } 597 | -------------------------------------------------------------------------------- /[6]_ODE_PINN_finite_difference.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "view-in-github", 7 | "colab_type": "text" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": { 17 | "id": "v77fdC1ZLyg1" 18 | }, 19 | "outputs": [], 20 | "source": [ 21 | 
"#Credits : Mahmoud Asem @Asem000" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": { 28 | "colab": { 29 | "base_uri": "https://localhost:8080/" 30 | }, 31 | "id": "vAR0swbLX_ZI", 32 | "outputId": "dab17b4c-908d-4213-ac91-c20d0f7ef122" 33 | }, 34 | "outputs": [ 35 | { 36 | "output_type": "stream", 37 | "name": "stdout", 38 | "text": [ 39 | "Requirement already satisfied: optax in /usr/local/lib/python3.7/dist-packages (0.1.1)\n", 40 | "Requirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.7/dist-packages (from optax) (1.21.5)\n", 41 | "Requirement already satisfied: jaxlib>=0.1.37 in /usr/local/lib/python3.7/dist-packages (from optax) (0.3.2+cuda11.cudnn805)\n", 42 | "Requirement already satisfied: absl-py>=0.7.1 in /usr/local/lib/python3.7/dist-packages (from optax) (1.0.0)\n", 43 | "Requirement already satisfied: typing-extensions>=3.10.0 in /usr/local/lib/python3.7/dist-packages (from optax) (3.10.0.2)\n", 44 | "Requirement already satisfied: jax>=0.1.55 in /usr/local/lib/python3.7/dist-packages (from optax) (0.3.4)\n", 45 | "Requirement already satisfied: chex>=0.0.4 in /usr/local/lib/python3.7/dist-packages (from optax) (0.1.1)\n", 46 | "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py>=0.7.1->optax) (1.15.0)\n", 47 | "Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax) (0.11.2)\n", 48 | "Requirement already satisfied: dm-tree>=0.1.5 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax) (0.1.6)\n", 49 | "Requirement already satisfied: scipy>=1.2.1 in /usr/local/lib/python3.7/dist-packages (from jax>=0.1.55->optax) (1.4.1)\n", 50 | "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax>=0.1.55->optax) (3.3.0)\n", 51 | "Requirement already satisfied: flatbuffers<3.0,>=1.12 in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.37->optax) 
(2.0)\n" 52 | ] 53 | } 54 | ], 55 | "source": [ 56 | "#Imports\n", 57 | "import jax \n", 58 | "import jax.numpy as jnp\n", 59 | "import numpy as np\n", 60 | "import matplotlib.pyplot as plt\n", 61 | "from matplotlib import cm\n", 62 | "import matplotlib as mpl\n", 63 | "!pip install optax\n", 64 | "import optax" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "metadata": { 71 | "id": "yoPHsh5lWvyP" 72 | }, 73 | "outputs": [], 74 | "source": [ 75 | "import sympy as sp" 76 | ] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "metadata": { 81 | "id": "7bg4nSbsXVwD" 82 | }, 83 | "source": [ 84 | "### Generate a differential equation and its solution using SymPy" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": { 91 | "id": "P9664e-mVMTN" 92 | }, 93 | "outputs": [], 94 | "source": [ 95 | "t= sp.symbols('t')\n", 96 | "f = sp.Function('y')\n", 97 | "diffeq = sp.Eq(f(t).diff(t,t) + f(t).diff(t)-t*sp.cos(2*sp.pi*t),0)\n", 98 | "sol = sp.simplify(sp.dsolve(diffeq,ics={f(0):1,f(t).diff(t).subs(t,0):10}).rhs)" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "metadata": { 105 | "colab": { 106 | "base_uri": "https://localhost:8080/", 107 | "height": 54 108 | }, 109 | "id": "klgFeU6bcTrC", 110 | "outputId": "5c0fc009-ce8b-472f-b98b-6da8afd6a66c" 111 | }, 112 | "outputs": [ 113 | { 114 | "output_type": "execute_result", 115 | "data": { 116 | "text/plain": [ 117 | "Eq(-t*cos(2*pi*t) + Derivative(y(t), t) + Derivative(y(t), (t, 2)), 0)" 118 | ], 119 | "text/latex": "$\\displaystyle - t \\cos{\\left(2 \\pi t \\right)} + \\frac{d}{d t} y{\\left(t \\right)} + \\frac{d^{2}}{d t^{2}} y{\\left(t \\right)} = 0$" 120 | }, 121 | "metadata": {}, 122 | "execution_count": 4 123 | } 124 | ], 125 | "source": [ 126 | "diffeq" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "metadata": { 133 | "colab": { 134 | "base_uri": 
"https://localhost:8080/", 135 | "height": 60 136 | }, 137 | "id": "E4Uu2hbiYJtv", 138 | "outputId": "6444bdb2-9243-46c9-e801-d19960dcb233" 139 | }, 140 | "outputs": [ 141 | { 142 | "output_type": "execute_result", 143 | "data": { 144 | "text/plain": [ 145 | "Eq(Subs(Derivative(y(t), t), t, 0), 10)" 146 | ], 147 | "text/latex": "$\\displaystyle \\left. \\frac{d}{d t} y{\\left(t \\right)} \\right|_{\\substack{ t=0 }} = 10$" 148 | }, 149 | "metadata": {}, 150 | "execution_count": 5 151 | } 152 | ], 153 | "source": [ 154 | "sp.Eq(f(t).diff(t).subs(t,0),10)" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "metadata": { 161 | "colab": { 162 | "base_uri": "https://localhost:8080/", 163 | "height": 38 164 | }, 165 | "id": "29QUbt_2YwlJ", 166 | "outputId": "0e3cce9d-29c7-4f26-c279-535093e2ad37" 167 | }, 168 | "outputs": [ 169 | { 170 | "output_type": "execute_result", 171 | "data": { 172 | "text/plain": [ 173 | "Eq(y(0), 1)" 174 | ], 175 | "text/latex": "$\\displaystyle y{\\left(0 \\right)} = 1$" 176 | }, 177 | "metadata": {}, 178 | "execution_count": 6 179 | } 180 | ], 181 | "source": [ 182 | "sp.Eq(f(t).subs(t,0),1)" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": null, 188 | "metadata": { 189 | "colab": { 190 | "base_uri": "https://localhost:8080/", 191 | "height": 81 192 | }, 193 | "id": "r9KVq1yjYfld", 194 | "outputId": "9a8437fc-c2fc-4fe7-ad9f-26ca25b974d4" 195 | }, 196 | "outputs": [ 197 | { 198 | "output_type": "execute_result", 199 | "data": { 200 | "text/plain": [ 201 | "Eq(y(t), (2*pi*t*exp(t)*sin(2*pi*t) + 8*pi**3*t*exp(t)*sin(2*pi*t) - 16*pi**4*t*exp(t)*cos(2*pi*t) - 4*pi**2*t*exp(t)*cos(2*pi*t) + 16*pi**3*exp(t)*sin(2*pi*t) + exp(t)*cos(2*pi*t) + 12*pi**2*exp(t)*cos(2*pi*t) - exp(t) + 36*pi**2*exp(t) + 336*pi**4*exp(t) + 704*pi**6*exp(t) - 640*pi**6 - 304*pi**4 - 44*pi**2)*exp(-t)/(4*pi**2*(1 + 8*pi**2 + 16*pi**4)))" 202 | ], 203 | "text/latex": "$\\displaystyle y{\\left(t \\right)} = 
\\frac{\\left(2 \\pi t e^{t} \\sin{\\left(2 \\pi t \\right)} + 8 \\pi^{3} t e^{t} \\sin{\\left(2 \\pi t \\right)} - 16 \\pi^{4} t e^{t} \\cos{\\left(2 \\pi t \\right)} - 4 \\pi^{2} t e^{t} \\cos{\\left(2 \\pi t \\right)} + 16 \\pi^{3} e^{t} \\sin{\\left(2 \\pi t \\right)} + e^{t} \\cos{\\left(2 \\pi t \\right)} + 12 \\pi^{2} e^{t} \\cos{\\left(2 \\pi t \\right)} - e^{t} + 36 \\pi^{2} e^{t} + 336 \\pi^{4} e^{t} + 704 \\pi^{6} e^{t} - 640 \\pi^{6} - 304 \\pi^{4} - 44 \\pi^{2}\\right) e^{- t}}{4 \\pi^{2} \\left(1 + 8 \\pi^{2} + 16 \\pi^{4}\\right)}$" 204 | }, 205 | "metadata": {}, 206 | "execution_count": 7 207 | } 208 | ], 209 | "source": [ 210 | "sp.Eq(f(t),sol)" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": null, 216 | "metadata": { 217 | "colab": { 218 | "base_uri": "https://localhost:8080/", 219 | "height": 37 220 | }, 221 | "id": "MNVOpPyCW-GU", 222 | "outputId": "53a70887-e440-43a8-f811-30c862b21866" 223 | }, 224 | "outputs": [ 225 | { 226 | "output_type": "execute_result", 227 | "data": { 228 | "text/plain": [ 229 | "0" 230 | ], 231 | "text/latex": "$\\displaystyle 0$" 232 | }, 233 | "metadata": {}, 234 | "execution_count": 8 235 | } 236 | ], 237 | "source": [ 238 | "#verify solution\n", 239 | "sp.simplify(-t*sp.cos(sp.pi*2*t)+sol.diff(t)+sol.diff(t,t))" 240 | ] 241 | }, 242 | { 243 | "cell_type": "markdown", 244 | "metadata": { 245 | "id": "NQ61lEQeXgrc" 246 | }, 247 | "source": [ 248 | "### Constructing the MLP" 249 | ] 250 | }, 251 | { 252 | "cell_type": "code", 253 | "execution_count": null, 254 | "metadata": { 255 | "id": "Lml6PGLPZgmr" 256 | }, 257 | "outputs": [], 258 | "source": [ 259 | "N_b = 1\n", 260 | "N_c = 100\n", 261 | "\n", 262 | "tmin,tmax=0. 
,jnp.pi\n", 263 | "\n", 264 | "'''boundary conditions'''\n", 265 | "\n", 266 | "\n", 267 | "# U[0] = 1\n", 268 | "t_0 = jnp.ones([N_b,1],dtype='float32')*0.\n", 269 | "ic_0 = jnp.ones_like(t_0) \n", 270 | "IC_0 = jnp.concatenate([t_0,ic_0],axis=1)\n", 271 | "\n", 272 | "# U_t[0] = 10\n", 273 | "t_b1 = jnp.zeros([N_b,1])\n", 274 | "bc_1 = jnp.ones_like(t_b1) * 10\n", 275 | "BC_1 = jnp.concatenate([t_b1,bc_1],axis=1)\n", 276 | "\n", 277 | "conds = [IC_0,BC_1]\n", 278 | "\n", 279 | "#collocation points\n", 280 | "\n", 281 | "key=jax.random.PRNGKey(0)\n", 282 | "\n", 283 | "t_c = jnp.linspace(tmin,tmax,N_c).reshape(-1,1)\n", 284 | "colloc = t_c\n", 285 | "\n", 286 | "def ODE_loss(t,u):\n", 287 | " dt = 0.03173326 # ~= (tmax-tmin)/(N_c-1) = pi/99, the collocation grid spacing\n", 288 | " u_t = lambda t: (-u(t+2*dt)+8*u(t+dt)-8*u(t-dt)+u(t-2*dt))/(12*dt)\n", 289 | " u_tt = lambda t: (-u(t+2*dt) + 16*u(t+dt) -30*u(t) + 16 * u(t-dt) - u(t-2*dt))/(12*dt**2)\n", 290 | " return -t*jnp.cos(2*jnp.pi*t) + u_t(t) + u_tt(t)" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": null, 296 | "metadata": { 297 | "id": "KoZZJl2TbI_n" 298 | }, 299 | "outputs": [], 300 | "source": [ 301 | "def init_params(layers):\n", 302 | " keys = jax.random.split(jax.random.PRNGKey(0),len(layers)-1)\n", 303 | " params = list()\n", 304 | " for key,n_in,n_out in zip(keys,layers[:-1],layers[1:]):\n", 305 | " lb, ub = -(1 / jnp.sqrt(n_in)), (1 / jnp.sqrt(n_in)) # uniform initialization bounds +-1/sqrt(n_in) (fan-in scaled; not the Xavier/Glorot bound)\n", 306 | " W = lb + (ub-lb) * jax.random.uniform(key,shape=(n_in,n_out))\n", 307 | " B = jax.random.uniform(key,shape=(n_out,))\n", 308 | " params.append({'W':W,'B':B})\n", 309 | " return params\n", 310 | "\n", 311 | "def fwd(params,t):\n", 312 | " X = jnp.concatenate([t],axis=1)\n", 313 | " *hidden,last = params\n", 314 | " for layer in hidden :\n", 315 | " X = jax.nn.tanh(X@layer['W']+layer['B'])\n", 316 | " return X@last['W'] + last['B']\n", 317 | "\n", 318 | "@jax.jit\n", 319 | "def MSE(true,pred):\n", 320 | " return 
jnp.mean((true-pred)**2)\n", 321 | "\n", 322 | "def loss_fun(params,colloc,conds):\n", 323 | " t_c =colloc[:,[0]]\n", 324 | " ufunc = lambda t : fwd(params,t)\n", 325 | " ufunc_t=lambda t:jax.grad(lambda t:jnp.sum(ufunc(t)))(t)\n", 326 | " loss =jnp.mean(ODE_loss(t_c,ufunc) **2)\n", 327 | "\n", 328 | " t_ic,u_ic = conds[0][:,[0]],conds[0][:,[1]] \n", 329 | " loss += MSE(u_ic,ufunc(t_ic))\n", 330 | "\n", 331 | " t_bc,u_bc = conds[1][:,[0]],conds[1][:,[1]] \n", 332 | " loss += MSE(u_bc,ufunc_t(t_bc))\n", 333 | "\n", 334 | " return loss\n", 335 | "\n", 336 | "@jax.jit\n", 337 | "def update(opt_state,params,colloc,conds):\n", 338 | " # Get the gradient w.r.t to MLP params\n", 339 | " grads=jax.jit(jax.grad(loss_fun,0))(params,colloc,conds)\n", 340 | " \n", 341 | " #Update params\n", 342 | " updates, opt_state = optimizer.update(grads, opt_state)\n", 343 | " params = optax.apply_updates(params, updates)\n", 344 | "\n", 345 | " #Update params\n", 346 | " # return jax.tree_multimap(lambda params,grads : params-LR*grads, params,grads)\n", 347 | " return opt_state,params\n" 348 | ] 349 | }, 350 | { 351 | "cell_type": "code", 352 | "execution_count": null, 353 | "metadata": { 354 | "id": "ae1ZDoy0c29c" 355 | }, 356 | "outputs": [], 357 | "source": [ 358 | "# construct the MLP of 4 hidden layers of 20 neurons for each layer\n", 359 | "params = init_params([1] + [20]*4+[1])" 360 | ] 361 | }, 362 | { 363 | "cell_type": "code", 364 | "execution_count": null, 365 | "metadata": { 366 | "id": "jySmbUwic5yk" 367 | }, 368 | "outputs": [], 369 | "source": [ 370 | "lr = optax.piecewise_constant_schedule(1e-3,{10_000:5e-3,30_000:1e-3,50_000:5e-4,70_000:1e-4})\n", 371 | "optimizer = optax.adam(1e-3)\n", 372 | "opt_state = optimizer.init(params)" 373 | ] 374 | }, 375 | { 376 | "cell_type": "code", 377 | "execution_count": null, 378 | "metadata": { 379 | "colab": { 380 | "base_uri": "https://localhost:8080/" 381 | }, 382 | "id": "kBzGA8OVc8C6", 383 | "outputId": 
"ed7565fa-5172-42c8-ce96-39b3cd357d5b" 384 | }, 385 | "outputs": [ 386 | { 387 | "output_type": "stream", 388 | "name": "stdout", 389 | "text": [ 390 | "Epoch=0\tloss=1.026e+02\n", 391 | "Epoch=1000\tloss=1.134e+01\n", 392 | "Epoch=2000\tloss=7.139e+00\n", 393 | "Epoch=3000\tloss=4.549e+00\n", 394 | "Epoch=4000\tloss=2.824e+00\n", 395 | "Epoch=5000\tloss=2.343e+00\n", 396 | "Epoch=6000\tloss=2.147e+00\n", 397 | "Epoch=7000\tloss=1.904e+00\n", 398 | "Epoch=8000\tloss=1.563e+00\n", 399 | "Epoch=9000\tloss=1.016e+00\n", 400 | "Epoch=10000\tloss=9.470e-01\n", 401 | "Epoch=11000\tloss=9.235e-01\n", 402 | "Epoch=12000\tloss=9.041e-01\n", 403 | "Epoch=13000\tloss=8.905e-01\n", 404 | "Epoch=14000\tloss=8.731e-01\n", 405 | "Epoch=15000\tloss=8.661e-01\n", 406 | "Epoch=16000\tloss=8.594e-01\n", 407 | "Epoch=17000\tloss=8.494e-01\n", 408 | "Epoch=18000\tloss=8.424e-01\n", 409 | "Epoch=19000\tloss=8.380e-01\n", 410 | "Epoch=20000\tloss=8.348e-01\n", 411 | "Epoch=21000\tloss=8.259e-01\n", 412 | "Epoch=22000\tloss=8.208e-01\n", 413 | "Epoch=23000\tloss=8.104e-01\n", 414 | "Epoch=24000\tloss=8.035e-01\n", 415 | "Epoch=25000\tloss=7.996e-01\n", 416 | "Epoch=26000\tloss=7.991e-01\n", 417 | "Epoch=27000\tloss=7.855e-01\n", 418 | "Epoch=28000\tloss=7.851e-01\n", 419 | "Epoch=29000\tloss=7.818e-01\n", 420 | "Epoch=30000\tloss=7.784e-01\n", 421 | "Epoch=31000\tloss=7.729e-01\n", 422 | "Epoch=32000\tloss=7.824e-01\n", 423 | "Epoch=33000\tloss=7.703e-01\n", 424 | "Epoch=34000\tloss=7.798e-01\n", 425 | "Epoch=35000\tloss=7.726e-01\n", 426 | "Epoch=36000\tloss=7.629e-01\n", 427 | "Epoch=37000\tloss=7.736e-01\n", 428 | "Epoch=38000\tloss=7.748e-01\n", 429 | "Epoch=39000\tloss=7.663e-01\n", 430 | "Epoch=40000\tloss=7.684e-01\n", 431 | "Epoch=41000\tloss=7.751e-01\n", 432 | "Epoch=42000\tloss=7.682e-01\n", 433 | "Epoch=43000\tloss=7.782e-01\n", 434 | "Epoch=44000\tloss=7.733e-01\n", 435 | "Epoch=45000\tloss=7.665e-01\n", 436 | "Epoch=46000\tloss=7.643e-01\n", 437 | 
"Epoch=47000\tloss=7.625e-01\n", 438 | "Epoch=48000\tloss=7.619e-01\n", 439 | "Epoch=49000\tloss=7.627e-01\n", 440 | "Epoch=50000\tloss=7.688e-01\n", 441 | "Epoch=51000\tloss=7.621e-01\n", 442 | "Epoch=52000\tloss=7.633e-01\n", 443 | "Epoch=53000\tloss=7.616e-01\n", 444 | "Epoch=54000\tloss=7.713e-01\n", 445 | "Epoch=55000\tloss=7.645e-01\n", 446 | "Epoch=56000\tloss=7.589e-01\n", 447 | "Epoch=57000\tloss=7.626e-01\n", 448 | "Epoch=58000\tloss=7.668e-01\n", 449 | "Epoch=59000\tloss=7.682e-01\n", 450 | "Epoch=60000\tloss=7.605e-01\n", 451 | "Epoch=61000\tloss=7.642e-01\n", 452 | "Epoch=62000\tloss=7.570e-01\n", 453 | "Epoch=63000\tloss=7.593e-01\n", 454 | "Epoch=64000\tloss=7.541e-01\n", 455 | "Epoch=65000\tloss=7.565e-01\n", 456 | "Epoch=66000\tloss=7.604e-01\n", 457 | "Epoch=67000\tloss=7.644e-01\n", 458 | "Epoch=68000\tloss=7.555e-01\n", 459 | "Epoch=69000\tloss=7.600e-01\n", 460 | "Epoch=70000\tloss=7.641e-01\n", 461 | "Epoch=71000\tloss=7.505e-01\n", 462 | "Epoch=72000\tloss=7.549e-01\n", 463 | "Epoch=73000\tloss=7.627e-01\n", 464 | "Epoch=74000\tloss=7.575e-01\n", 465 | "Epoch=75000\tloss=7.568e-01\n", 466 | "Epoch=76000\tloss=7.699e-01\n", 467 | "Epoch=77000\tloss=7.498e-01\n", 468 | "Epoch=78000\tloss=7.505e-01\n", 469 | "Epoch=79000\tloss=7.544e-01\n", 470 | "Epoch=80000\tloss=7.511e-01\n", 471 | "Epoch=81000\tloss=7.579e-01\n", 472 | "Epoch=82000\tloss=7.470e-01\n", 473 | "Epoch=83000\tloss=7.527e-01\n", 474 | "Epoch=84000\tloss=7.583e-01\n", 475 | "Epoch=85000\tloss=7.445e-01\n", 476 | "Epoch=86000\tloss=7.476e-01\n", 477 | "Epoch=87000\tloss=7.633e-01\n", 478 | "Epoch=88000\tloss=7.500e-01\n", 479 | "Epoch=89000\tloss=7.467e-01\n", 480 | "Epoch=90000\tloss=7.447e-01\n", 481 | "Epoch=91000\tloss=7.495e-01\n", 482 | "Epoch=92000\tloss=7.487e-01\n", 483 | "Epoch=93000\tloss=7.542e-01\n", 484 | "Epoch=94000\tloss=7.583e-01\n", 485 | "Epoch=95000\tloss=7.612e-01\n", 486 | "Epoch=96000\tloss=7.489e-01\n", 487 | "Epoch=97000\tloss=7.502e-01\n", 488 | 
"Epoch=98000\tloss=7.474e-01\n", 489 | "Epoch=99000\tloss=7.563e-01\n", 490 | "CPU times: user 1min 57s, sys: 256 ms, total: 1min 57s\n", 491 | "Wall time: 3min 3s\n" 492 | ] 493 | } 494 | ], 495 | "source": [ 496 | "%%time\n", 497 | "epochs = 100_000\n", 498 | "for _ in range(epochs):\n", 499 | " opt_state,params = update(opt_state,params,colloc,conds)\n", 500 | "\n", 501 | " # print loss and epoch info\n", 502 | " if _ %(1000) ==0:\n", 503 | " print(f'Epoch={_}\\tloss={loss_fun(params,colloc,conds):.3e}')" 504 | ] 505 | }, 506 | { 507 | "cell_type": "code", 508 | "execution_count": null, 509 | "metadata": { 510 | "colab": { 511 | "base_uri": "https://localhost:8080/", 512 | "height": 282 513 | }, 514 | "id": "eWeNvDsdDEuI", 515 | "outputId": "1b79ef5d-7f2d-40ee-f4a1-4d0ba1cb0edd" 516 | }, 517 | "outputs": [ 518 | { 519 | "output_type": "execute_result", 520 | "data": { 521 | "text/plain": [ 522 | "" 523 | ] 524 | }, 525 | "metadata": {}, 526 | "execution_count": 36 527 | }, 528 | { 529 | "output_type": "display_data", 530 | "data": { 531 | "text/plain": [ 532 | "
" 533 | ], 534 | "image/png": "\n" 535 | }, 536 | "metadata": { 537 | "needs_background": "light" 538 | } 539 | } 540 | ], 541 | "source": [ 542 | "lam_sol= sp.lambdify(t,sol)\n", 543 | "\n", 544 | "dT = 1e-3\n", 545 | "Tf = jnp.pi\n", 546 | "T = np.arange(0,Tf+dT,dT)\n", 547 | "\n", 548 | "\n", 549 | "sym_sol =np.array([lam_sol(i) for i in T])\n", 550 | "\n", 551 | "plt.plot(T,sym_sol,'--r',label='sympy solution')\n", 552 | "plt.plot(T,fwd(params,T.reshape(-1,1))[:,0],'--k',label='NN solution')\n", 553 | "plt.legend()" 554 | ] 555 | } 556 | ], 557 | "metadata": { 558 | "colab": { 559 | "collapsed_sections": [], 560 | "name": "[6] ODE-PINN finite difference.ipynb", 561 | "provenance": [], 562 | "include_colab_link": true 563 | }, 564 | "kernelspec": { 565 | "display_name": "Python 3", 566 | "language": "python", 567 | "name": "python3" 568 | }, 569 | "language_info": { 570 | "codemirror_mode": { 571 | "name": "ipython", 572 | "version": 3 573 | }, 574 | "file_extension": ".py", 575 | "mimetype": "text/x-python", 576 | "name": "python", 577 | "nbconvert_exporter": "python", 578 | "pygments_lexer": "ipython3", 579 | "version": "3.9.6" 580 | } 581 | }, 582 | "nbformat": 4, 583 | "nbformat_minor": 0 584 | } --------------------------------------------------------------------------------