├── LICENSE ├── README.md └── notebooks ├── ML_tutorial.ipynb ├── advection_diffusion_DL.ipynb └── lorenz_timestep.ipynb /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Joseph Bakarji 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | 3 | This repo is a collection of machine learning methods applied to dynamical systems and fluid dynamics problems. 
See related lectures on the website: http://www.databookuw.com/page-5/page-24/ 4 | 5 | -------------------------------------------------------------------------------- /notebooks/advection_diffusion_DL.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "ME599_deep_learning_for_PDEs.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [ 9 | "X3C3EuF19GOC", 10 | "UrdGk10n9AyE", 11 | "63ZyNs5qMcrz", 12 | "_uV4EY3i9WXk" 13 | ] 14 | }, 15 | "kernelspec": { 16 | "name": "python3", 17 | "display_name": "Python 3" 18 | }, 19 | "language_info": { 20 | "name": "python" 21 | } 22 | }, 23 | "cells": [ 24 | { 25 | "cell_type": "code", 26 | "metadata": { 27 | "id": "3grFXgv78ama" 28 | }, 29 | "source": [ 30 | "import tensorflow as tf\n", 31 | "import sklearn\n", 32 | "import numpy as np\n", 33 | "import matplotlib.pyplot as plt\n", 34 | "\n", 35 | "from tensorflow import keras\n", 36 | "from tensorflow.keras import layers\n", 37 | "from sklearn.model_selection import train_test_split" 38 | ], 39 | "execution_count": null, 40 | "outputs": [] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": { 45 | "id": "QbiqrnxK98E-" 46 | }, 47 | "source": [ 48 | "This notebook applies several machine learning techniques to the advection-diffusion equation, which appears in many physics and engineering applications, particularly in fluid dynamics. Let $u(x, t)$ be a quantity that depends on space and time, such as temperature or concentration in a fluid flow. Its advection-diffusion equation is\n", 49 | "\n", 50 | "$$\\frac{\\partial u}{\\partial t} = v \\frac{\\partial u}{\\partial x} + D \\frac{\\partial^2 u}{\\partial x^2},$$\n", 51 | "\n", 52 | "where $D$ is the diffusion coefficient and $v$ is the advection coefficient (the flow speed). With this sign convention, which the numerical solver below also uses, a positive $v$ transports $u$ toward negative $x$."
53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "metadata": { 58 | "id": "d6sT9BcZLJhn" 59 | }, 60 | "source": [ 61 | "### Good old numerical solver" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "metadata": { 67 | "colab": { 68 | "base_uri": "https://localhost:8080/" 69 | }, 70 | "id": "rT0Nd1rA8vxs", 71 | "outputId": "9b518760-4440-4b62-b7de-75b83e16d3c0" 72 | }, 73 | "source": [ 74 | "x_left = -2\n", 75 | "x_right = 2\n", 76 | "t_end = 2\n", 77 | "\n", 78 | "D = 0.02\n", 79 | "v = 0.1\n", 80 | "dx = 0.01\n", 81 | "dt = 0.001\n", 82 | "\n", 83 | "cfl = D*dt/dx**2  # diffusion number; explicit Euler with centered diffusion is stable only if it is <= 0.5\n", 84 | "print('CFL = %.4f'%(cfl))\n", 85 | "x_steps = int((x_right - x_left)/dx)\n", 86 | "t_steps = int(t_end/dt)\n", 87 | "\n", 88 | "xx = np.linspace(x_left, x_right, x_steps)\n", 89 | "tt = np.linspace(0, t_end, t_steps)\n", 90 | "uu = np.zeros((t_steps, x_steps))\n", 91 | "\n", 92 | "# ICs\n", 93 | "def gaussian(x, mu, sig, shift):\n", 94 | "    return shift + np.exp(-np.power(x - mu, 2.) / (2 * np.power(sig, 2.)))\n", 95 | "\n", 96 | "ic_mean = 0\n", 97 | "ic_std = 0.2\n", 98 | "ic_shift = 0\n", 99 | "uu[0, :] = gaussian(xx, ic_mean, ic_std, ic_shift)" 100 | ], 101 | "execution_count": null, 102 | "outputs": [ 103 | { 104 | "output_type": "stream", 105 | "text": [ 106 | "CFL = 0.2000\n" 107 | ], 108 | "name": "stdout" 109 | } 110 | ] 111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "metadata": { 115 | "id": "0kzXfn7CdbnD" 116 | }, 117 | "source": [ 118 | "Forward Euler in time with centered differences in space gives, at each interior grid point $x_j$, the update $$ u(x_j, t_{n+1}) = u(x_j, t_n) + D \\frac{\\Delta t}{\\Delta x^2} \\left(u(x_{j-1}, t_n) - 2 u(x_j, t_n) + u(x_{j+1}, t_n)\\right) + v \\frac{\\Delta t}{2\\Delta x} \\left(u(x_{j+1}, t_n) - u(x_{j-1}, t_n)\\right)$$" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "metadata": { 124 | "id": "CEVSLEqe8xP6" 125 | }, 126 | "source": [ 127 | "# Euler stepper for one interior grid point (3-point stencil ul, uc, ur)\n", 128 | "def adv_diff_euler_step(ul, uc, ur, D, v, dx, dt):\n", 129 | "    return uc + D * dt/dx**2 * (ul - 2 * uc + ur) + v * dt/(2*dx) * (ur - ul)\n", 130 | "\n",
131 | "# Solve\n", 132 | "inputs = []\n", 133 | "outputs = []\n", 134 | "for ti in range(t_steps - 1):\n", 135 | "    for xi in range(1, x_steps - 1):  # interior points only\n", 136 | "        uu[ti+1, xi] = adv_diff_euler_step(uu[ti, xi-1], uu[ti, xi], uu[ti, xi+1], D, v, dx, dt)\n", 137 | "\n", 138 | "        # Save data: (t, x) at the new time level and the value u(x, t) there\n", 139 | "        inputs.append([tt[ti+1], xx[xi]])\n", 140 | "        outputs.append(uu[ti+1, xi])\n", 141 | "\n", 142 | "    # zero-flux BC, applied once per time step\n", 143 | "    uu[ti+1, 0] = uu[ti+1, 1]\n", 144 | "    uu[ti+1, -1] = uu[ti+1, -2]\n" 145 | ], 146 | "execution_count": null, 147 | "outputs": [] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "metadata": { 152 | "colab": { 153 | "base_uri": "https://localhost:8080/", 154 | "height": 279 155 | }, 156 | "id": "XQ-_2s8q8zSE", 157 | "outputId": "2a4e3efa-c89b-4f0c-8761-6b970e5b0186" 158 | }, 159 | "source": [ 160 | "fig = plt.figure()\n", 161 | "plt.plot(xx, uu[0, :], lw=2)\n", 162 | "plt.plot(xx, uu[-1, :], lw=2)\n", 163 | "plt.xlabel('$x$')\n", 164 | "plt.ylabel('$u(x, t)$')\n", 165 | "plt.legend(['$t_0$', '$t_{end}$'])\n", 166 | "\n", 167 | "\n", 168 | "plt.show()" 169 | ], 170 | "execution_count": null, 171 | "outputs": [ 172 | { 173 | "output_type": "display_data", 174 | "data": { 175 | "image/png": "\n", 176 | "text/plain": [ 177 | "
" 178 | ] 179 | }, 180 | "metadata": { 181 | "tags": [], 182 | "needs_background": "light" 183 | } 184 | } 185 | ] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "metadata": { 190 | "id": "X3C3EuF19GOC" 191 | }, 192 | "source": [ 193 | "### Approximate the solution by a neural network?\n", 194 | "A *naive* first step would be to approximate the function $u(x, t)$ by building a neural network with inputs $(x, t)$ and output $u$. Such a network can be useful for surrogate modeling." 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "metadata": { 200 | "id": "uAEtfNmB80-B" 201 | }, 202 | "source": [ 203 | "test_ratio = 0.25\n", 204 | "dev_ratio = 0.2\n", 205 | "\n", 206 | "# Prepare data\n", 207 | "inputs_array = np.asarray(inputs)\n", 208 | "outputs_array = np.asarray(outputs)\n", 209 | "\n", 210 | "# Split into train-dev-test sets\n", 211 | "X_train, X_test, y_train, y_test = train_test_split(inputs_array, outputs_array, test_size=test_ratio, shuffle=False)\n", 212 | "X_train, X_dev, y_train, y_dev = train_test_split(X_train, y_train, test_size=dev_ratio, shuffle=False)" 213 | ], 214 | "execution_count": null, 215 | "outputs": [] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "metadata": { 220 | "colab": { 221 | "base_uri": "https://localhost:8080/" 222 | }, 223 | "id": "_9eHr3NP81cT", 224 | "outputId": "33d49ac7-e1da-460e-d934-853008daddad" 225 | }, 226 | "source": [ 227 | "\n", 228 | "# Build model\n", 229 | "deep_approx = keras.models.Sequential()\n", 230 | "deep_approx.add(layers.Dense(10, input_dim=2, activation='elu'))\n", 231 | "deep_approx.add(layers.Dense(10, activation='elu'))\n", 232 | "deep_approx.add(layers.Dense(1, activation='linear'))\n", 233 | "\n", 234 | "# Compile model\n", 235 | "deep_approx.compile(loss='mse', optimizer='adam')\n", 236 | "\n", 237 | "# Fit!\n", 238 | "history = deep_approx.fit(X_train, y_train, \n", 239 | " epochs=10, batch_size=32, \n", 240 | " validation_data=(X_dev, y_dev),\n", 241 | " 
callbacks=[keras.callbacks.EarlyStopping(patience=5)])" 242 | ], 243 | "execution_count": null, 244 | "outputs": [ 245 | { 246 | "output_type": "stream", 247 | "text": [ 248 | "Epoch 1/10\n", 249 | "14918/14918 [==============================] - 22s 1ms/step - loss: 0.0146 - val_loss: 4.2067e-04\n", 250 | "Epoch 2/10\n", 251 | "14918/14918 [==============================] - 21s 1ms/step - loss: 2.4455e-04 - val_loss: 1.5875e-04\n", 252 | "Epoch 3/10\n", 253 | "14918/14918 [==============================] - 21s 1ms/step - loss: 1.4117e-04 - val_loss: 1.3537e-04\n", 254 | "Epoch 4/10\n", 255 | "14918/14918 [==============================] - 21s 1ms/step - loss: 9.3146e-05 - val_loss: 1.6698e-04\n", 256 | "Epoch 5/10\n", 257 | "14918/14918 [==============================] - 21s 1ms/step - loss: 6.1776e-05 - val_loss: 1.5959e-04\n", 258 | "Epoch 6/10\n", 259 | "14918/14918 [==============================] - 21s 1ms/step - loss: 4.4186e-05 - val_loss: 1.6316e-04\n", 260 | "Epoch 7/10\n", 261 | "14918/14918 [==============================] - 21s 1ms/step - loss: 3.4193e-05 - val_loss: 1.7513e-04\n", 262 | "Epoch 8/10\n", 263 | "14918/14918 [==============================] - 21s 1ms/step - loss: 2.8389e-05 - val_loss: 1.0539e-04\n", 264 | "Epoch 9/10\n", 265 | "14918/14918 [==============================] - 21s 1ms/step - loss: 2.4172e-05 - val_loss: 1.0172e-04\n", 266 | "Epoch 10/10\n", 267 | "14918/14918 [==============================] - 21s 1ms/step - loss: 2.0476e-05 - val_loss: 7.5740e-05\n" 268 | ], 269 | "name": "stdout" 270 | } 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "metadata": { 276 | "colab": { 277 | "base_uri": "https://localhost:8080/" 278 | }, 279 | "id": "heB3InzVGd4j", 280 | "outputId": "003aba6a-b54f-4577-f983-18cc29a8ecbc" 281 | }, 282 | "source": [ 283 | "deep_approx.summary()\n" 284 | ], 285 | "execution_count": null, 286 | "outputs": [ 287 | { 288 | "output_type": "stream", 289 | "text": [ 290 | "Model: \"sequential_1\"\n", 291 |
"_________________________________________________________________\n", 292 | "Layer (type) Output Shape Param # \n", 293 | "=================================================================\n", 294 | "dense_3 (Dense) (None, 10) 30 \n", 295 | "_________________________________________________________________\n", 296 | "dense_4 (Dense) (None, 10) 110 \n", 297 | "_________________________________________________________________\n", 298 | "dense_5 (Dense) (None, 1) 11 \n", 299 | "=================================================================\n", 300 | "Total params: 151\n", 301 | "Trainable params: 151\n", 302 | "Non-trainable params: 0\n", 303 | "_________________________________________________________________\n" 304 | ], 305 | "name": "stdout" 306 | } 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "metadata": { 312 | "colab": { 313 | "base_uri": "https://localhost:8080/", 314 | "height": 279 315 | }, 316 | "id": "Syo-u0BSJxUQ", 317 | "outputId": "ccb46792-2b2e-4ade-fe95-a65b63cd8d06" 318 | }, 319 | "source": [ 320 | "# history.history contains loss information\n", 321 | "\n", 322 | "idx0 = 1  # skip the first epoch, whose loss is much larger than the rest\n", 323 | "plt.figure()\n", 324 | "plt.plot(history.history['loss'][idx0:], '.-', lw=2)\n", 325 | "plt.plot(history.history['val_loss'][idx0:], '.-', lw=2)\n", 326 | "plt.xlabel('epochs')\n", 327 | "plt.ylabel('Loss (MSE)')\n", 328 | "plt.legend(['training loss', 'validation loss'])\n", 329 | "plt.show()" 330 | ], 331 | "execution_count": null, 332 | "outputs": [ 333 | { 334 | "output_type": "display_data", 335 | "data": { 336 | "image/png": "\n", 337 | "text/plain": [ 338 | "
" 339 | ] 340 | }, 341 | "metadata": { 342 | "tags": [], 343 | "needs_background": "light" 344 | } 345 | } 346 | ] 347 | }, 348 | { 349 | "cell_type": "markdown", 350 | "metadata": { 351 | "id": "UrdGk10n9AyE" 352 | }, 353 | "source": [ 354 | "#### Does it extrapolate in time?" 355 | ] 356 | }, 357 | { 358 | "cell_type": "code", 359 | "metadata": { 360 | "colab": { 361 | "base_uri": "https://localhost:8080/", 362 | "height": 297 363 | }, 364 | "id": "ZZgBROuz9BPd", 365 | "outputId": "01202c42-6c0a-4dc9-d3e9-9fe7bc73fe31" 366 | }, 367 | "source": [ 368 | "import seaborn as sns\n", 369 | "c = sns.color_palette()\n", 370 | "\n", 371 | "nplots = 11\n", 372 | "rmin = 0\n", 373 | "rmax = 1\n", 374 | "idxes = np.arange(int(rmin*len(tt)), int(rmax*len(tt)), int((rmax-rmin)*len(tt)/nplots))\n", 375 | "e_mean = []\n", 376 | "tt_mean = []\n", 377 | "\n", 378 | "fig = plt.figure(figsize=(12, 4))\n", 379 | "ax0 = fig.add_subplot(121)\n", 380 | "ax1 = fig.add_subplot(122)\n", 381 | "for idx, i in enumerate(idxes):\n", 382 | "    data_in = np.array([[tt[i], x] for x in xx])\n", 383 | "    u_approx = deep_approx.predict(data_in)\n", 384 | "    ax0.plot(xx, u_approx, lw=2, color=c[idx%len(c)])\n", 385 | "    ax0.plot(xx, uu[i, :], lw=2, linestyle='--', color=c[idx%len(c)])  # numerical reference, same color\n", 386 | "    tt_mean.append(tt[i])\n", 387 | "    e_mean.append(np.mean((u_approx[:, 0] - uu[i, :])**2))  # mean squared error over x\n", 388 | "\n", 389 | "ax1.plot(tt_mean, e_mean, '.-', lw=2, color=c[0], markersize=10)\n", 390 | "ax1.plot([(1-test_ratio)*t_end]*2, [min(e_mean), max(e_mean)], ':', color=c[1])\n", 391 | "ax1.legend(['MSE', 'Train/dev time horizon'])\n", 392 | "\n", 393 | "ax0.set_xlabel('$x$')\n", 394 | "ax0.set_ylabel('$u(x, t)$')\n", 395 | "ax1.set_xlabel('$t$')\n", 396 | "ax1.set_ylabel('Mean squared error')\n", 397 | "\n", 398 | "fig.tight_layout()\n", 399 | "plt.show()" 400 | ], 401 | "execution_count": null, 402 | "outputs": [ 403 | { 404 | "output_type": "display_data", 405 | "data": { 406 | "image/png": "\n", 407 | "text/plain": [ 408 | "
" 409 | ] 410 | }, 411 | "metadata": { 412 | "tags": [], 413 | "needs_background": "light" 414 | } 415 | } 416 | ] 417 | }, 418 | { 419 | "cell_type": "markdown", 420 | "metadata": { 421 | "id": "63ZyNs5qMcrz" 422 | }, 423 | "source": [ 424 | "### Approximate the integrator by a neural network (applied recurrently, RNN-style)\n", 425 | "Instead of fitting $u(x, t)$ directly, we can fit the coefficients of the Euler stepping function used above, `adv_diff_euler_step(ul, uc, ur, D, v, dx, dt)`. Because that step is linear in the three stencil values, a single linear `Dense` layer can represent it exactly. If we ignore its dependence on the parameters $D$, $v$, $\\Delta x$, and $\\Delta t$, we are approximating the function $\\phi$ in\n", 426 | "\n", 427 | "$$u(x_j, t_{n+1}) = \\phi\\left(u(x_{j-1}, t_n), u(x_j, t_n), u(x_{j+1}, t_n)\\right)$$\n", 428 | "\n", 429 | "For dynamical systems whose state $u$ depends only on time, $\\phi$ is commonly referred to as the **flow map**, and we look to approximate\n", 430 | "$$ u(t_{n+1}) = \\phi(u(t_n)) $$\n", 431 | "\n", 432 | "This approach is used, for example, in:\n", 433 | "- https://www.youtube.com/watch?v=Jfl3dIlSTrU\n", 434 | "- " 435 | ] 436 | }, 437 | { 438 | "cell_type": "code", 439 | "metadata": { 440 | "id": "CHdlp1ee9Gse" 441 | }, 442 | "source": [ 443 | "# Let's run the simulation and save the data at each step\n", 444 | "\n", 445 | "inputs = []\n", 446 | "outputs = []\n", 447 | "for ti in range(t_steps - 1):\n", 448 | "    for xi in range(1, x_steps - 1):  # interior points only\n", 449 | "        uu[ti+1, xi] = adv_diff_euler_step(uu[ti, xi-1], uu[ti, xi], uu[ti, xi+1], D, v, dx, dt)\n", 450 | "        # Collect data: stencil at t_n -> center value at t_{n+1}\n", 451 | "        inputs.append([uu[ti, xi-1], uu[ti, xi], uu[ti, xi+1]])\n", 452 | "        outputs.append(uu[ti+1, xi])\n", 453 | "    # zero-flux BC, applied once per time step\n", 454 | "    uu[ti+1, 0] = uu[ti+1, 1]\n", 455 | "    uu[ti+1, -1] = uu[ti+1, -2]\n", 456 | "\n", 457 | "inputs_array = np.asarray(inputs)\n", 458 | "outputs_array = np.asarray(outputs)\n", 459 | "\n", 460 | "Xs_train, Xs_test, ys_train, ys_test = train_test_split(inputs_array, outputs_array, test_size=test_ratio, shuffle=False)\n", 461 | "Xs_train, Xs_dev, ys_train, ys_dev = train_test_split(Xs_train, ys_train,
test_size=dev_ratio, shuffle=True)" 462 | ], 463 | "execution_count": null, 464 | "outputs": [] 465 | }, 466 | { 467 | "cell_type": "code", 468 | "metadata": { 469 | "colab": { 470 | "base_uri": "https://localhost:8080/" 471 | }, 472 | "id": "jIBEvifZ9KhE", 473 | "outputId": "08a28eaa-04bb-4bab-9bcd-0825a1c7ef25" 474 | }, 475 | "source": [ 476 | "## Linear regression of the stepper: one linear Dense unit acting on (ul, uc, ur)\n", 477 | "\n", 478 | "# Build model\n", 479 | "deep_stepper = keras.models.Sequential()\n", 480 | "deep_stepper.add(layers.Dense(1, input_dim=3, activation='linear'))\n", 481 | "\n", 482 | "# Compile model\n", 483 | "deep_stepper.compile(loss='mse', optimizer='adam')\n", 484 | "\n", 485 | "# Fit!\n", 486 | "history = deep_stepper.fit(Xs_train, ys_train, epochs=3, batch_size=32,\n", 487 | "                           validation_data=(Xs_dev, ys_dev),\n", 488 | "                           callbacks=[keras.callbacks.EarlyStopping(patience=5)])" 489 | ], 490 | "execution_count": null, 491 | "outputs": [ 492 | { 493 | "output_type": "stream", 494 | "text": [ 495 | "Epoch 1/3\n", 496 | "14918/14918 [==============================] - 19s 1ms/step - loss: 0.0434 - val_loss: 6.9518e-06\n", 497 | "Epoch 2/3\n", 498 | "14918/14918 [==============================] - 17s 1ms/step - loss: 1.9073e-06 - val_loss: 4.8422e-10\n", 499 | "Epoch 3/3\n", 500 | "14918/14918 [==============================] - 17s 1ms/step - loss: 8.4980e-09 - val_loss: 5.1057e-10\n" 501 | ], 502 | "name": "stdout" 503 | } 504 | ] 505 | }, 506 | { 507 | "cell_type": "code", 508 | "metadata": { 509 | "colab": { 510 | "base_uri": "https://localhost:8080/" 511 | }, 512 | "id": "stoi2aPw9OTt", 513 | "outputId": "9db74818-9110-4905-e9e0-952b78b452a7" 514 | }, 515 | "source": [ 516 | "## Integrate with the neural network\n", 517 | "# WARNING (and lesson): this will take too long! 
(around 10 minutes)\n", 518 | "# Can be accelerated by evaluating all stencils of a time step in one batched call\n", 519 | "\n", 520 | "from tqdm import tqdm\n", 521 | "\n", 522 | "uu_deep = np.zeros(uu.shape)\n", 523 | "uu_deep[0, :] = uu[0, :]\n", 524 | "\n", 525 | "for ti in tqdm(range(t_steps - 1)):\n", 526 | "    for xi in range(1, x_steps - 1):  # rollout: the network sees its own previous state\n", 527 | "        input_stencil = np.array([[uu_deep[ti, xi-1], uu_deep[ti, xi], uu_deep[ti, xi+1]]])\n", 528 | "        uu_deep[ti+1, xi] = deep_stepper(input_stencil)[0][0].numpy()\n", 529 | "    uu_deep[ti+1, 0] = uu_deep[ti+1, 1]\n", 530 | "    uu_deep[ti+1, -1] = uu_deep[ti+1, -2]" 531 | ], 532 | "execution_count": null, 533 | "outputs": [ 534 | { 535 | "output_type": "stream", 536 | "text": [ 537 | "100%|██████████| 1999/1999 [09:23<00:00, 3.55it/s]\n" 538 | ], 539 | "name": "stderr" 540 | } 541 | ] 542 | }, 543 | { 544 | "cell_type": "code", 545 | "metadata": { 546 | "id": "_sUBfqYr9UNy", 547 | "colab": { 548 | "base_uri": "https://localhost:8080/", 549 | "height": 572 550 | }, 551 | "outputId": "cda9bd26-87dd-4f54-b5d0-694ad4216323" 552 | }, 553 | "source": [ 554 | "fig = plt.figure()\n", 555 | "plt.plot(xx, uu_deep[-1, :], lw=2)\n", 556 | "plt.plot(xx, uu[-1, :], '--', lw=2)\n", 557 | "plt.xlabel('$x$')\n", 558 | "plt.ylabel('$u(x, t_{end})$')\n", 559 | "plt.legend(['Deep stepper', 'Euler stepper'])\n", 560 | "\n", 561 | "idx_list = [0, int(len(tt)/4), int(len(tt)/2), int(len(tt))-1]\n", 562 | "leg = []\n", 563 | "fig = plt.figure()\n", 564 | "for idx in idx_list:\n", 565 | "    plt.plot(xx, (uu_deep[idx, :] - uu[idx, :])**2, lw=2)\n", 566 | "    leg.append('$t=%.2f s$'%(tt[idx]))\n", 567 | "plt.xlabel('$x$')\n", 568 | "plt.ylabel('Squared error')\n", 569 | "plt.legend(leg)\n" 570 | ], 571 | "execution_count": null, 572 | "outputs": [ 573 | { 574 | "output_type": "execute_result", 575 | "data": { 576 | "text/plain": [ 577 | "" 578 | ] 579 | }, 580 | "metadata": { 581 | "tags": [] 582 | }, 583 | "execution_count": 40 584 | }, 585 | { 586 | "output_type": "display_data", 587 | "data": {
588 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAEJCAYAAAB7UTvrAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deXxU5dn/8c81k5UAYUnYAwn7GrYERBBRBLcWtWor7rvW8mjtUx/x18e1m1aftq51r7Z1BxeqKFYFFQUhIPtOCBC2LJCEkH3m+v0xQwyQQBJm5iSZ6/165cWZmXvO+eaQ5Jqz3PctqooxxhgD4HI6gDHGmKbDioIxxphqVhSMMcZUs6JgjDGmmhUFY4wx1awoGGOMqRayoiAi54jIRhHZIiIz62jzUxFZJyJrReT1UGUzxhjjI6HopyAibmATMAXIBpYC01V1XY02/YC3gTNV9YCIdFLVnKCHM8YYUy0iRNsZA2xR1UwAEXkTuABYV6PNTcDTqnoAoL4FISEhQZOTkwOb1hhjWrBly5blqWpiba+Fqih0B3bWeJwNjD2qTX8AEfkGcAMPqOonJ1pxcnIyGRkZgcppjDEtnohsr+u1UBWF+ogA+gGTgB7AVyIyTFULjm4oIjcDNwP07NkzlBmNMaZFC9WF5l1AUo3HPfzP1ZQNzFHVSlXdhu8aRL/aVqaqz6tqmqqmJSbWegRkjDGmEUJVFJYC/UQkRUSigMuAOUe1eR/fUQIikoDvdFJmiPIZY4whRKePVLVKRGYA8/BdL3hZVdeKyENAhqrO8b82VUTWAR7gLlXND0U+Y8zJqaysJDs7m7KyMqejmBpiYmLo0aMHkZGR9X5PSG5JDaa0tDS1C83GOGvbtm20adOGjh07IiJOxzGAqpKfn8/BgwdJSUk54jURWaaqabW9z3o0G2NOWllZmRWEJkZE6NixY4OP3qwoGGMCwgpC09OY/5OmdEuqMU3ClpyDvLlkJ5PXziSOUkp7nMaQH82gdXwHp6OZ43C73QwbNozKykoiIiK4+uqrufPOO3G5QvvZd8GCBURFRXHqqaeGdLuBYkXBGD+vV3nii808+cUWPF5lTGQx49zLYPMScv/yMjumPsPgU89zOqapQ2xsLCtWrAAgJyeHyy+/nKKiIh588MGQ5liwYAGtW7d2rCh4PB7cbnej32+nj4wB1Otlyd9uJOOLd/Gqcll6Em2mPcLS0Y+xMWIAiRyg77wrWbvwA6ejmnro1KkTzz//PE899RSqisfj4a677iI9PZ3U1FSee+656raPPvpo9fP3338/AFlZWQwcOJArrriCQYMGcckll1BSUnLMdp544gkGDx5Mamoql112GVlZWTz77LP85S9/YcSIEXz99dfk5uZy8cUXk56eTnp6Ot988w0ADzzwAFdddRXjxo2jX79+vPDCC4CvqEycOJHzzz+fAQMGcOutt+L1egH49NNPGTduHKNGjeLSSy+luLgY8I3scPfddzNq1Cjeeeedk9t5qtqsv0aPHq3GnKzF/3pQ9f62mntfks5fu+OI16oqK3XRE1er3t9WC+/volnrlzuUsulat26d0xE0Li7umOfi4+N17969+txzz+lvf/tbVVUtKyvT0aNHa2Zmps6bN09vuukm9Xq96vF49Pzzz9cvv/xSt23bpoAuXLhQVVWvu+46ffTRR49Zf9euXbWsrExVVQ8cOKCqqvfff/8RbadPn65ff/21qqpu375dBw4cWN0uNTVVS0pKNDc3V3v06KG7du3S+fPna3R0tG7dulWrqqr0rLPO0nfeeUdzc3P1tNNO0+LiYlVVffjhh/XBBx9UVdVevXrpI488Uut+qe3/Bl9XgFr/ptrpIxP2Ni2bT/qmP4PAtvT7mDQ46YjX3RERpN/2Mt//Xy4jS75h/axf0+2e/xDptgPt2iTP/Cgo6816+PxGv/fTTz9l1apVzJo1C4DC
wkI2b97Mp59+yqeffsrIkSMBKC4uZvPmzfTs2ZOkpCTGjx8PwJVXXskTTzzBr3/96yPWm5qayhVXXMGFF17IhRdeWOu2P/vsM9at+2Hsz6KioupP+BdccAGxsbHExsZyxhlnsGTJEtq1a8eYMWPo3bs3ANOnT2fhwoXExMSwbt266kwVFRWMGzeuer0/+9nPGr1/arKiYMJaZUU5kXN/iUuURV2uYNyPbqy1ndvtZsBtb/DC4/fw54OTmfFVJr84o2+I05qGyMzMxO1206lTJ1SVJ598krPPPvuINvPmzeOee+7hlltuOeL5rKysY+7cqe1Ono8++oivvvqKf//73/z+979n9erVx7Txer0sXryYmJiYY16raxu1Pa+qTJkyhTfeeKPW7zcuLq7W5xusrkOI5vJlp4/MyVj0j3tV72+r2Q/000PFhSdsv3Bzrva6+0Pt95u5uj3vUAgSNg9N7fRRTk6OTpkyRe+77z5VVX3uuef0ggsu0IqKClVV3bhxoxYXF+u8efN0zJgxevDgQVVVzc7O1n379lWfPvr2229VVfWGG27Qxx577IjteTwe3bZtm6qqVlRUaNeuXfXAgQP62GOPVW9X1Xf66E9/+lP14++//15VfaePhg8frqWlpZqXl6dJSUnVp49iYmI0MzNTPR6PTp06VWfNmqU5OTmalJSkmzdvVlXV4uJi3bhxo6r6Th/l5ubWul8aevrIjn9N2DpYkM+grS8CkDfxD7SKa3vC94zvm8BFI7sTWXWIz97/e7AjmgYoLS1lxIgRDBkyhLPOOoupU6dWXzi+8cYbGTx4MKNGjWLo0KHccsstVFVVMXXqVC6//HLGjRvHsGHDuOSSSzh48CAAAwYM4Omnn2bQoEEcOHCAn//850dsz+PxcOWVVzJs2DBGjhzJ7bffTrt27fjxj3/Me++9V32h+YknniAjI4PU1FQGDx7Ms88+W72O1NRUzjjjDE455RTuvfdeunXrBkB6ejozZsxg0KBBpKSkcNFFF5GYmMgrr7zC9OnTSU1NZdy4cWzYsCHg+9GGuTBh6+33ZjF5xS/ZF5XMoHu+Qup5P3t2Tj6RT48igUJ2X7GApP4jgpy06Vu/fj2DBg1yOkbAZGVl8aMf/Yg1a9YEbRsPPPAArVu3PuY6xYIFC3jsscf48MMPA7Kd2v5vbJgLY45ysKyS361sw2nlj1N2/lP1LggAPTp1JKvDabhF2ffh74KY0pjQswvNJiy9k5FNUVkVY1K6MWpEwz/pJ11wL1V//4jhhV+QuzuLxG7JgQ9pHJOcnBzUowTwHSnUZtKkSUyaNCmo2z4eO1IwYcfr8ZD/9Yu0poTrx6ec+A216JY8gJWtJxApHrbOfSLACY1xjhUFE3bWfPUud5U/xQexD3HWwMbP3Bcz4TYA+mXPoqKsNFDxjHGUFQUTdrxLXwIgJ+UCIiIaP0bM4LFnk+lKpiOFrP7sX4GKZ4yjrCiYsJK/bydDD31Hpbrpf/atJ7UucbnIGXAF33kH8sX2igAlNMZZVhRMWNn82d+JEC9r48bQsXOPk17fgPPv4ErP/fwtO5l9RTYVpZPcbjcjRoyo/nr44YeP2/6VV15hxowZAc/xhz/8IeDrDCUrCiasJGa+B4A39bKArK9962jOHNgJr8L73+8KyDpN4xweOvvw18yZMwO6/qqqqnq1c7Io1Dfj8VhRMGFj27ql9PFkUkgcQyb9NGDr/cnI7gyTTGIW/Tlg6zSBk5ycTF5eHgAZGRm13u55ouGtx48fz1VXXXXEe/bs2cPEiRMZMWIEQ4cO5euvv2bmzJnVPauvuOIKAP71r38xZswYRowYwS233ILH4wGgdevW3HnnnQwZMoTJkyeTm5sL+G5JveOOO6rXu2TJEgAOHTrE9ddfz5gxYxg5ciQffOAbxv2VV15h2rRpnHnmmUyePPmk95cVBRM2Fm/YwXfegazvcBbRMa0Ctt4z+nXgteg/ck3Zv9i+4fuArdc0zOE/
xoe/3nrrrXq/94477uDOO+9k6dKlzJ49mxtv/GFgxHXr1vHZZ58dMxDd66+/ztlnn82KFStYuXJl9Smrw0csr732GuvXr+ett97im2++YcWKFbjdbl577TXA90c+LS2NtWvXcvrppx8xGVBJSQkrVqzgmWee4frrrwfg97//PWeeeSZLlixh/vz53HXXXRw6dAiA5cuXM2vWLL788stG77/DrPOaCRuv7khkQ8V9vDJ1dEDXGxUdzcr400gv/ITdi96k18CRAV1/s/RAfN2v/eivkHadbznj7/DhL4+znsJ6b7LmzGsNdbzhradNm0ZsbOwx70lPT+f666+nsrKSCy+8kBG1dIL8/PPPWbZsGenp6YCvcHXq1AkAl8tVPdz1lVdeyU9+8pPq902fPh2AiRMnUlRUREFBAZ9++ilz5szhscceA6CsrIwdO3YAMGXKFDp0CMx0sVYUTFjYlneIDXsP0iYmglP7dgr4+iNTL4KvP6Fz9ifAIwFfv2m8iIiI6pnLyspqvxngeMNb1zUk9cSJE/nqq6/46KOPuPbaa/nVr37F1VdffUQbVeWaa67hj3/84wlz1hwuu66hs2fPns2AAQOOeO27774L3LDZ2OkjEyZWLvyIVNnKWQM7ERUR+B/7QeOncVBj6e3JYufmVQFff7PzQGHdX4ePEsC3fLy2AZCcnMyyZcsAmD17dq1tpk6dypNPPln9uD5HHNu3b6dz587cdNNN3HjjjSxfvhyAyMhIKisrAZg8eTKzZs0iJycHgP3797N9+3bAV4gOT/rz+uuvM2HChOp1Hz71tXDhQuLj44mPj+fss8/mySef5PAgpt9/H5xTlVYUTFgYvOZR5kTfy5Ud1gdl/dExrdjQ7jQAdi06yTlyTaMcfU3h8N1H999/P3fccQdpaWl1Tmh/vOGt67JgwQKGDx/OyJEjeeutt7jjjjsAuPnmm6tnZBs8eDC/+93vmDp1KqmpqUyZMoU9e/YAviOQJUuWMHToUL744gvuu+++6nXHxMQwcuRIbr31Vl56ydfZ8t5776WyspLU1FSGDBnCvffee1L7qy4hGzpbRM4BHgfcwIuq+vBRr18LPAocvq/vKVV98UTrtaGzzYns27GZzi+nUaLRyP9sJTauTVC2s2zu3xm95JesixrK4P/3TVC20VS1tKGzQ6F169bV1y1qmjRpEo899hhpabWObN1gTXLobBFxA08D5wKDgekiMriWpm+p6gj/1wkLgjH1se2btwHY0Hps0AoCQL9Tp5GtCSwv7ULhIevhbJqnUF1oHgNsUdVMABF5E7gAWHfcdxkTAHHbPwfA0/+8oG6nbbuO3NLtXyzatp+2W/KYNrxbULdnmrfajhLAd1rKSaG6ptAd2Fnjcbb/uaNdLCKrRGSWiCSFJpppyUoPHaR/6Sq8KvQeNy3o2ztzUGcA5m/ICfq2jAmGpnSh+d9AsqqmAv8BXq2roYjcLCIZIpJxuBegMbXZtORjoqWSrRF96dipts8hgXXGwE604yAxG2bjCcCQA81Jc5/atyVqzP9JqIrCLqDmJ/8e/HBBGQBVzVfVcv/DF4E6exip6vOqmqaqaYmJjR8P37R8W7N2kKdtyet6Wki21ycxjjmxD/JHfYLNy78IyTabgpiYGPLz860wNCGqSn5+fq19L44nVNcUlgL9RCQFXzG4DLi8ZgMR6aqqe/wPpwHBuXfQhJUn89P47/JnmD1pVEi2JyLsThhPz9x3KFj5IYyZGpLtOq1Hjx5kZ2djR+5NS0xMDD16NGw04JAUBVWtEpEZwDx8t6S+rKprReQhIENV5wC3i8g0oArYD1wbimym5dqef4hteYdoGxNFakrXkG03buh5MP8dEvYuDNk2nRYZGUlKSuOmNjVNS8iGuVDVucDco567r8byPcA9ocpjWr5lK76nDSWc1q8PEe7QXT7rmzaV8i8i6V2VSUHeXtoldAnZto05WU3pQrMxAdU747d8H30zl8WvDel2Y+Na
syV6MC5RMpd+HNJtG3OyrCiYFqmyopx+JSuIEC/9R44P+fYPdvONY1O5ZX7It23MybCiYFqkrasWEidl7JBudO7eO+Tb75A6FY8KJQV24dU0L1YUTIt0YK2vF/Oe9umObL9P6gQmyktcd2gGO/eXOJLBmMawomBapNa7FwHg7jPRke27IyIY0rsXAN9uzXMkgzGNYUXBtDgV5WX0LVsDQK/RzvUTmNAvAVBWr9/gWAZjGspmXjMtzpbVixgsFWS5kkju0tOxHBO7KUujbyMiU1FvFuKqfSx/Y5oSO1IwLc5nhT1IL3uGj/s94GiOXkm9UHHTniJ2bGzc3MHGhJoVBdPiLNqaTy7tSB4a+ltRaxKXix1tfJO57139uaNZjKkvKwqmRSmrqGLZjgMAjO3d0eE04EkaB0DEzkUOJzGmfqwomBZl69J5fO76Lx5q9yEd4qKcjkPnYWcCkHRwBer1OpzGmBOzomBalKINX5DkyqV/26Yxl0HP/iM5QBs6sZ89WXYXkmn6rCiYFiV+72IAovue7nASH5fbTVarVACyV37mcBpjTsyKgmkxystK6FOxEYDeaU1nHoPtQ37OReUPMqvS2QvfxtSHFQXTYmxbs4hoqSTLlUR8h05Ox6nWZ/hEvtd+LM4qcjqKMSdkRcG0GAUbfZPa5MQPdzjJkQZ1bUPr6Ah27C9hT2Gp03GMOS4rCqbFiNqTAYAkjXE4yZEi3C5+kbiCf0b+gZ3fvOV0HGOOy4qCaRFUlefKzuLxqp+QmHqW03GOMaJNEae516BbbX4F07RZUTAtwq6CUuYV9+XlyOn07D3Y6TjHaD9oEgBdDixzNogxJ2BFwbQIy3cUADCqZztcLnE4zbF6Dz+NUo2il3cnB3L3OB3HmDpZUTAtgnf5v7jB/RETO5c7HaVWUdExbIseAMCOlQucDWPMcVhRMC3CkF1vc2/ka6S3LXA6Sp0KE0YCUJL5rcNJjKmbFQXT7JUcKiK5MhOPCinDT3M6Tp1ie/s6r8Xnfe9wEmPqZkXBNHvbVn1DpHjYFtGbuDbtnI5Tp14jJjHbM4F/lo6n0mOD45mmyYqCafaKNvk6reW1b1qd1o7WPqELT8ffxRuVE1m323o3m6bJioJp9mL3+W7zjOg51uEkJzaqV3sAlm0/4HASY2oX0qIgIueIyEYR2SIiM4/T7mIRURFJC2U+0/yo10uvkrUAdB06ydkw9TC2exRTXUtxrXnb6SjG1CoiVBsSETfwNDAFyAaWisgcVV13VLs2wB3Ad6HKZpqv7Xtz2OzpR7I7l77J/Z2Oc0Lp7Uu4NOov7NvbEbjb6TjGHCOURwpjgC2qmqmqFcCbwAW1tPst8AhQFsJsppnK2FPFTZX/zaO9/464mv7Z0J79R1BEHJ3JZ+/OLU7HMeYYofwt6g7srPE42/9cNREZBSSp6kchzGWascPn5kf7z9U3dS63m20xvmE4dq1a4GwYY2rRZD5aiYgL+DPw3/Voe7OIZIhIRm5ubvDDmSarNHMRsZQ1m6IAUNJ5NACVWYsdTmLMsUJZFHYBSTUe9/A/d1gbYCiwQESygFOAObVdbFbV51U1TVXTEhMTgxjZNGVFBfn8+eD/sCz65wztEut0nHpr08/Xia3jgRUOJzHmWKEsCkuBfiKSIiJRwGXAnMMvqmqhqiaoarKqJgOLgWmqmhHCjKYZyVr5FS5RdkYmExPTfIpCyvCJvt7XlVspKS50Oo4xRwhZUVDVKmAGMA9YD7ytqmtF5CERmRaqHKblOLTVN4bQgY4jHU7SMHFt2pEZ0YfN2p2Nmzc7HceYI4TsllQAVZ0LzD3qufvqaDspFJlM8xWX4+u0Fpl8isNJGu61YS/xyuJd3FXQjuZV0kxL12QuNBvTEF6Ph5RSXxeXHqmnO5ym4UYm+66FWc9m09RYUTDN0vaNy2kjpewlkc7dezsdp8FG9WwPKHu3b0S9NjieaTqsKJhmac+GJQBktxnmcJLG6dE+lo9j7mWu
/oKdW1Y5HceYalYUTLM0u2oCp5Q9SVbqL52O0igiQllcNwD2rv3a4TTG/MCKgmmWvt9xgL10pN/Apj1c9vGUd/V1wdGdSxxOYswPrCiYZmf/oQoy8w4RE+licLe2TsdptHb+TmydClY6nMSYH1hRMM3OjsXv8lnUr/lNu/8Q6W6+P8LJqeOpUDe9PDsoKsh3Oo4xgBUF0wyVbv2Wvq7d9G1d6XSUkxITG8e2yL64RNm+6iun4xgDWFEwzVBb/8T3sb3HOZzk5B3oMAKA4i3fOpzEGJ+Q9mg25mRVVlaQUr4RBHoOb36d1o5WlnoVl340gHhvOs2/xJmWwI4UTLOStXYJraScndKNDp26n/gNTVz/oWks1YF8l12K16tOxzHGioJpXvLX++7p39u2eXZaO1q3drF0aRvDwbIqtuYWOx3HGCsKpnmJ2L0UAG+PMQ4nCZzrO67m75GPkLvodaejGNPwoiAicSLiDkYYY07k1bLTebpqGh2HTXE6SsCkti7iDPdKIrZ/6XQUY05cFETEJSKXi8hHIpIDbAD2iMg6EXlURPoGP6YxsK+ojH8f7Mvf3FeS0j/V6TgB027ABAA6F9oYSMZ59TlSmA/0Ae4Buqhqkqp2Aibgmx3tERG5MogZjQFguX+Y6RFJ7XC7xOE0gZMybBzlGkkv704K99uc48ZZ9SkKZ6nqb1V1lapWj/GrqvtVdbaqXgy8FbyIxviULX+Ta9zzmNilwukoARUdHcu2qH4AZK20U0jGWScsCqp6RLdREblCRPofr40xwTBo51s8GPkqp7TOcTpKwBV08A3sV7rVOrEZZzXm7qNc4BkR+UpE3hWRhwMdypijlZWW0LtyM14VerWATmtHi0rxTSkal7vc4SQm3DW4KKjqp8B3qjoRuAZoHfBUxhwla823REkV2909iW+f4HScgEtKPYM5nnG8U5qOxzqxGQc1tp9CWxEZDZQDcQHMY0ytDmxcCEBuu5Zz11FNid168Ujc//CPikls2nfQ6TgmjDW2KPwKGA88C3wSuDjG1C56T4ZvIWmss0GCaFSv9gAs33HA4SQmnDW2KPwTGANEAqMDF8eYY6nXS89DqwHoMuQ0h9MEz5hukZzpWk7V6vedjmLCWGNHSV2kqo8DiEjHAOYx5hi7cvLZ4ulFb3cMSX1b5ukjgFPa7ueqqMfYuasbcIfTcUyYauyRwgUicoOI9FdVmzLKBNWyvRVcW3k3D/T6J+JqucN1JQ8ZS6lGkaS7OZC7x+k4Jkw19jfsSmA38BMReaG+bxKRc0Rko4hsEZGZtbx+q4isFpEVIrJQRAY3Mp9pQZb5ezKPTu7gcJLgioyKZlvUAAC2r7JObMYZjS0KdwM/VdWHgVn1eYN/EL2ngXOBwcD0Wv7ov66qw1R1BPAn4M+NzGdakMLMpURTwcie7ZyOEnSFCSMBKM1c5HASE64aWxS8wDb/8hn1fM8YYIuqZqpqBfAmcEHNBqpaVONhHGA3bIe5QwcL+L+CX7Ei+maGd4lxOk7QxaT47q5qk/u9w0lMuGpsUSgB4kUkEuhZz/d0B3bWeJztf+4IIvILEdmK70jh9kbmMy3EtlVfEyFedkUkERfX8vtJ9hw+CYCU8g1UVbasMZ5M81DvoiAiyTUe3g9sxXc6KKAzg6jq06raB98pqv+tI8vNIpIhIhm5uTaqZEt2cJOv01p++xEOJwmNjp17kC1d2asd2Jq51ek4Jgw15Ejh3cMLqlqlqs8AL6vqh/V8/y4gqcbjHv7n6vImcGFtL6jq86qapqppiYmJ9dy8aY7i9i0DwJ0SPtPaP9H/VSZX/B9L9sc6HcWEofpMsvNT/6B3bURkkIjUfM/zDdjWUqCfiKSISBRwGTDnqG31q/HwfGBzA9ZvWhivx0Ny2VoAeqTW99JV8zcsuTPww11XxoRSfTqvfQPEADfiuxtogIgU4LsltbS+G1LVKhGZAcwD3PiOMtaKyENAhqrOAWaIyFlAJXAA34B7Jkxt
\n", 589 | "text/plain": [ 590 | "
" 591 | ] 592 | }, 593 | "metadata": { 594 | "tags": [], 595 | "needs_background": "light" 596 | } 597 | }, 598 | { 599 | "output_type": "display_data", 600 | "data": { 601 | "image/png": "\n", 602 | "text/plain": [ 603 | "
" 604 | ] 605 | }, 606 | "metadata": { 607 | "tags": [], 608 | "needs_background": "light" 609 | } 610 | } 611 | ] 612 | }, 613 | { 614 | "cell_type": "code", 615 | "metadata": { 616 | "colab": { 617 | "base_uri": "https://localhost:8080/" 618 | }, 619 | "id": "4kOIsX8jO_K-", 620 | "outputId": "e72cdcee-5d73-45e1-f0ff-29f165a28271" 621 | }, 622 | "source": [ 623 | "# If the time stepper is linear, does it learn the right coefficients?\n", 624 | "\n", 625 | "# Forward Euler step with central differences: uc + D * dt/dx**2 * (ul - 2 * uc + ur) + v * dt/(2*dx) * (ur - ul)\n", 626 | "\n", 627 | "weights = deep_stepper.get_weights()[0].ravel() # flatten the first Dense kernel to a 1-D array\n", 628 | "bias = deep_stepper.get_weights()[1]\n", 629 | "\n", 630 | "print(\"actual coefficient of u_left is %.5f and the fit is %.5f\"%(D*dt/dx**2 - v*dt/(2*dx), weights[0]))\n", 631 | "print(\"actual coefficient of u_center is %.5f and the fit is %.5f\"%(-2*D*dt/dx**2 + 1, weights[1]))\n", 632 | "print(\"actual coefficient of u_right is %.5f and the fit is %.5f\"%(D*dt/dx**2 + v*dt/(2*dx), weights[2]))\n", 633 | "print(bias)\n", 634 | "\n", 635 | "## In general, the fit is not guaranteed to recover these coefficients."
636 | ], 637 | "execution_count": null, 638 | "outputs": [ 639 | { 640 | "output_type": "stream", 641 | "text": [ 642 | "actual coefficient of u_left is 0.19500 and the fit is 0.25546\n", 643 | "actual coefficient of u_center is 0.60000 and the fit is 0.47929\n", 644 | "actual coefficient of u_right is 0.20500 and the fit is 0.26528\n", 645 | "[-1.5215408e-05]\n" 646 | ], 647 | "name": "stdout" 648 | } 649 | ] 650 | }, 651 | { 652 | "cell_type": "code", 653 | "metadata": { 654 | "id": "zS4isykIckmC" 655 | }, 656 | "source": [ 657 | "## Nonlinear regression of stepper\n", 658 | "\n", 659 | "# Build model\n", 660 | "deep_stepper2 = keras.models.Sequential()\n", 661 | "deep_stepper2.add(layers.Dense(2, input_dim=3, activation='elu'))\n", 662 | "deep_stepper2.add(layers.Dense(10, activation='elu'))\n", 663 | "deep_stepper2.add(layers.Dense(1, activation='linear'))\n", 664 | "\n", 665 | "# Compile model\n", 666 | "deep_stepper2.compile(loss='mse', optimizer='adam')\n", 667 | "\n", 668 | "# Fit! (with only epochs=3, EarlyStopping(patience=5) can never trigger)\n", 669 | "history = deep_stepper2.fit(Xs_train, ys_train, epochs=3, batch_size=32,\n", 670 | " validation_data=(Xs_dev, ys_dev),\n", 671 | " callbacks=[keras.callbacks.EarlyStopping(patience=5)])" 672 | ], 673 | "execution_count": null, 674 | "outputs": [] 675 | }, 676 | { 677 | "cell_type": "code", 678 | "metadata": { 679 | "id": "oZWvSlptcyZE" 680 | }, 681 | "source": [ 682 | "## Integrate with the neural network\n", 683 | "# WARNING (and lesson): this will take too long! 
(around 10 minutes)\n", 684 | "# Can be accelerated by stacking all interior stencils of a time step into one batch\n", 685 | "\n", 686 | "uu_deep2 = np.zeros(uu.shape)\n", 687 | "uu_deep2[0, :] = uu[0, :]\n", 688 | "\n", 689 | "for ti, t in enumerate(tt[:-1]):\n", 690 | " # print(ti/len(tt))\n", 691 | " for xi in range(1, len(xx) - 1): # interior points only\n", 692 | " input_stencil = np.array([[uu_deep2[ti, xi-1], uu_deep2[ti, xi], uu_deep2[ti, xi+1]]]) # feed back the network's own prediction\n", 693 | " uu_deep2[ti+1, xi] = deep_stepper2(input_stencil)[0][0].numpy()\n", 694 | " uu_deep2[ti+1, 0] = uu_deep2[ti+1, 1] # zero-gradient boundaries\n", 695 | " uu_deep2[ti+1, -1] = uu_deep2[ti+1, -2]" 696 | ], 697 | "execution_count": null, 698 | "outputs": [] 699 | }, 700 | { 701 | "cell_type": "code", 702 | "metadata": { 703 | "id": "-dJ8HE8EcyZH" 704 | }, 705 | "source": [ 706 | "fig = plt.figure()\n", 707 | "plt.plot(xx, uu_deep[-1, :], lw=2)\n", 708 | "plt.plot(xx, uu_deep2[-1, :], lw=2)\n", 709 | "plt.plot(xx, uu[-1, :], '--', lw=2)\n", 710 | "plt.xlabel('$x$')\n", 711 | "plt.ylabel('$u(x, t)$')\n", 712 | "plt.legend(['Linear deep stepper', 'Nonlinear deep stepper', 'Euler stepper'])\n", 713 | "\n", 714 | "idx_list = [0, int(len(tt)/4), int(len(tt)/2), int(len(tt))-1]\n", 715 | "leg = []\n", 716 | "fig = plt.figure()\n", 717 | "for idx in idx_list:\n", 718 | " plt.plot(xx, (uu_deep2[idx, :] - uu[idx, :])**2, lw=2)\n", 719 | " leg.append(str(tt[idx]))\n", 720 | "plt.xlabel('$x$')\n", 721 | "plt.ylabel('Squared error')\n", 722 | "plt.legend(leg)\n" 723 | ], 724 | "execution_count": null, 725 | "outputs": [] 726 | }, 727 | { 728 | "cell_type": "markdown", 729 | "metadata": { 730 | "id": "_uV4EY3i9WXk" 731 | }, 732 | "source": [ 733 | "#### Generalize over $D$, $v$, $dx$, $dt$?" 
734 | ] 735 | } 736 | ] 737 | } --------------------------------------------------------------------------------
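The integration cell in the notebook above calls the network once per grid point, which is why it is slow. Below is a minimal sketch of the batched alternative. It assumes a stepper that maps an `(n, 3)` array of `(u_left, u_center, u_right)` stencils to `n` predicted values. The helper names `euler_coefficients`, `apply_neumann`, `rollout` and `rollout_loop` are hypothetical, and so are the parameters `D=1.0`, `v=0.5`, `dx=0.1`, `dt=0.002`; they were picked only because they reproduce the "actual" coefficients printed by the notebook (0.195, 0.6, 0.205). To keep the sketch self-contained, a plain NumPy dot product with the forward-Euler weights from the notebook's Euler-step comment stands in for the trained Keras model.

```python
import numpy as np


def euler_coefficients(D, v, dx, dt):
    """Stencil weights (u_left, u_center, u_right) of one forward-Euler step of
    u_t = v u_x + D u_xx with central differences in space."""
    r = D * dt / dx**2      # diffusion number
    c = v * dt / (2 * dx)   # half the advective Courant number
    return np.array([r - c, 1 - 2 * r, r + c])


def apply_neumann(u):
    """Zero-gradient boundaries, as in the notebook's integration loop (in place)."""
    u[0] = u[1]
    u[-1] = u[-2]


def rollout(step_fn, u0, n_steps):
    """Autoregressive integration: the stepper is fed its own predictions.
    All interior stencils of one time step go through step_fn in ONE call."""
    uu = np.zeros((n_steps + 1, u0.size))
    uu[0] = u0
    for ti in range(n_steps):
        u = uu[ti]
        stencils = np.stack([u[:-2], u[1:-1], u[2:]], axis=1)  # shape (N-2, 3)
        uu[ti + 1, 1:-1] = step_fn(stencils)
        apply_neumann(uu[ti + 1])  # uu[ti + 1] is a view, so this edits uu
    return uu


def rollout_loop(step_fn, u0, n_steps):
    """Reference version: one (1, 3) stencil per call, like the notebook loop."""
    uu = np.zeros((n_steps + 1, u0.size))
    uu[0] = u0
    for ti in range(n_steps):
        for xi in range(1, u0.size - 1):
            stencil = uu[ti, xi - 1:xi + 2].reshape(1, 3)
            uu[ti + 1, xi] = step_fn(stencil)[0]
        apply_neumann(uu[ti + 1])
    return uu


# Stand-in for a trained model: a linear stepper with the Euler weights.
w = euler_coefficients(D=1.0, v=0.5, dx=0.1, dt=0.002)


def linear_step(stencils):
    return stencils @ w


x = np.linspace(0.0, 10.0, 101)   # dx = 0.1
u0 = np.exp(-(x - 5.0) ** 2)      # Gaussian initial condition
fast = rollout(linear_step, u0, n_steps=200)
slow = rollout_loop(linear_step, u0, n_steps=200)
print(w)
print(np.allclose(fast, slow))
```

Both versions compute the same trajectory; the batched one simply replaces `len(xx) - 2` model calls per time step with a single call. With a trained Keras model, one would pass something like `lambda s: deep_stepper2(s).numpy()[:, 0]` as `step_fn`, which is an untested assumption about how the model's output should be reshaped.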