├── .gitignore
├── .markdownlint.yaml
├── LICENSE
├── README.md
├── [1]_ODE_PINN.ipynb
├── [1]_ODE_PINN_ClassForm.ipynb
├── [2]_PDE_Burgers_PINN.ipynb
├── [3]_PDE_Laplace_PINN.ipynb
├── [3]_PDE_Laplace_PINN_ClassForm.ipynb
├── [4]_ODE_Supervised_and_PINN.ipynb
├── [5]_System_of_ODEs_PINN.ipynb
├── [6]_ODE_PINN_finite_difference.ipynb
└── [7]_PDE_LAPLACE_FO_PINN.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | **/__pycache__
3 | **/.ipynb_checkpoints
4 | **/.DS_Store
5 | *.Icon
6 | *.egg-info/
7 | .trunk
--------------------------------------------------------------------------------
/.markdownlint.yaml:
--------------------------------------------------------------------------------
1 | # Autoformatter friendly markdownlint config (all formatting rules disabled)
2 | default: true
3 | blank_lines: false
4 | bullet: false
5 | html: false
6 | indentation: false
7 | line_length: false
8 | spaces: false
9 | url: false
10 | whitespace: false
11 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 ASEM000
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [Physics informed neural network](https://maziarraissi.github.io/PINNs/) in [JAX](https://github.com/google/jax)
2 |
3 | Example notebooks for various applications of physics-informed neural networks (PINNs).
4 |
5 | | Description | Functional form | Class form ✨*New*✨ |
6 | |---|---|---|
7 | | **[ODE]** |
|
|
8 | | **[PDE]** Burgers |
| |
9 | | **[PDE]** Laplace |
|
|
10 | | **[ODE]** Supervised loss + PINN |
| |
11 | | **[ODE]** System of ODE |
| |
12 | | **[ODE]** Finite difference |
| |
13 |
14 |
15 | If you find this repository useful, please give it a ⭐
16 |
--------------------------------------------------------------------------------
/[1]_ODE_PINN.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "colab_type": "text",
7 | "id": "view-in-github"
8 | },
9 | "source": [
10 | "
"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "metadata": {
17 | "id": "v77fdC1ZLyg1"
18 | },
19 | "outputs": [],
20 | "source": [
21 | "#Credits : Mahmoud Asem @Asem000 September 2021"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": 3,
27 | "metadata": {
28 | "colab": {
29 | "base_uri": "https://localhost:8080/"
30 | },
31 | "id": "vAR0swbLX_ZI",
32 | "outputId": "23ef5ef3-495b-4582-96a2-3108e8dce0bb"
33 | },
34 | "outputs": [
35 | {
36 | "name": "stdout",
37 | "output_type": "stream",
38 | "text": [
39 | "Collecting optax\n",
40 | " Downloading optax-0.0.9-py3-none-any.whl (118 kB)\n",
41 | "\u001b[K |████████████████████████████████| 118 kB 8.5 MB/s eta 0:00:01\n",
42 | "\u001b[?25hCollecting chex>=0.0.4\n",
43 | " Downloading chex-0.0.8-py3-none-any.whl (57 kB)\n",
44 | "\u001b[K |████████████████████████████████| 57 kB 6.1 MB/s eta 0:00:01\n",
45 | "\u001b[?25hRequirement already satisfied: jaxlib>=0.1.37 in /usr/local/lib/python3.7/dist-packages (from optax) (0.1.70+cuda110)\n",
46 | "Requirement already satisfied: jax>=0.1.55 in /usr/local/lib/python3.7/dist-packages (from optax) (0.2.19)\n",
47 | "Requirement already satisfied: absl-py>=0.7.1 in /usr/local/lib/python3.7/dist-packages (from optax) (0.12.0)\n",
48 | "Requirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.7/dist-packages (from optax) (1.19.5)\n",
49 | "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py>=0.7.1->optax) (1.15.0)\n",
50 | "Requirement already satisfied: dm-tree>=0.1.5 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax) (0.1.6)\n",
51 | "Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax) (0.11.1)\n",
52 | "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax>=0.1.55->optax) (3.3.0)\n",
53 | "Requirement already satisfied: flatbuffers<3.0,>=1.12 in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.37->optax) (1.12)\n",
54 | "Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.37->optax) (1.4.1)\n",
55 | "Installing collected packages: chex, optax\n",
56 | "Successfully installed chex-0.0.8 optax-0.0.9\n"
57 | ]
58 | }
59 | ],
60 | "source": [
61 | "#Imports\n",
62 | "import jax \n",
63 | "import jax.numpy as jnp\n",
64 | "import numpy as np\n",
65 | "import matplotlib.pyplot as plt\n",
66 | "from matplotlib import cm\n",
67 | "import matplotlib as mpl\n",
68 | "!pip install optax\n",
69 | "import optax"
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": 4,
75 | "metadata": {
76 | "id": "yoPHsh5lWvyP"
77 | },
78 | "outputs": [],
79 | "source": [
80 | "import sympy as sp"
81 | ]
82 | },
83 | {
84 | "cell_type": "markdown",
85 | "metadata": {
86 | "id": "7bg4nSbsXVwD"
87 | },
88 | "source": [
89 | "### Generate a differential equation and its solution using SymPy"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": 19,
95 | "metadata": {
96 | "id": "P9664e-mVMTN"
97 | },
98 | "outputs": [],
99 | "source": [
100 | "t= sp.symbols('t')\n",
101 | "f = sp.Function('y')\n",
102 | "diffeq = sp.Eq(f(t).diff(t,t) + f(t).diff(t)-t*sp.cos(2*sp.pi*t),0)\n",
103 | "sol = sp.simplify(sp.dsolve(diffeq,ics={f(0):1,f(t).diff(t).subs(t,0):10}).rhs)"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": 20,
109 | "metadata": {
110 | "colab": {
111 | "base_uri": "https://localhost:8080/",
112 | "height": 54
113 | },
114 | "id": "klgFeU6bcTrC",
115 | "outputId": "cad6a477-37e0-414d-daf6-6e4f05a288cd"
116 | },
117 | "outputs": [
118 | {
119 | "data": {
120 | "text/latex": [
121 | "$\\displaystyle - t \\cos{\\left(2 \\pi t \\right)} + \\frac{d}{d t} y{\\left(t \\right)} + \\frac{d^{2}}{d t^{2}} y{\\left(t \\right)} = 0$"
122 | ],
123 | "text/plain": [
124 | "Eq(-t*cos(2*pi*t) + Derivative(y(t), t) + Derivative(y(t), (t, 2)), 0)"
125 | ]
126 | },
127 | "execution_count": 20,
128 | "metadata": {},
129 | "output_type": "execute_result"
130 | }
131 | ],
132 | "source": [
133 | "diffeq"
134 | ]
135 | },
136 | {
137 | "cell_type": "code",
138 | "execution_count": 21,
139 | "metadata": {
140 | "colab": {
141 | "base_uri": "https://localhost:8080/",
142 | "height": 60
143 | },
144 | "id": "E4Uu2hbiYJtv",
145 | "outputId": "e58180aa-779a-49fa-9771-b9f4763bb037"
146 | },
147 | "outputs": [
148 | {
149 | "data": {
150 | "text/latex": [
151 | "$\\displaystyle \\left. \\frac{d}{d t} y{\\left(t \\right)} \\right|_{\\substack{ t=0 }} = 10$"
152 | ],
153 | "text/plain": [
154 | "Eq(Subs(Derivative(y(t), t), t, 0), 10)"
155 | ]
156 | },
157 | "execution_count": 21,
158 | "metadata": {},
159 | "output_type": "execute_result"
160 | }
161 | ],
162 | "source": [
163 | "sp.Eq(f(t).diff(t).subs(t,0),10)"
164 | ]
165 | },
166 | {
167 | "cell_type": "code",
168 | "execution_count": 22,
169 | "metadata": {
170 | "colab": {
171 | "base_uri": "https://localhost:8080/",
172 | "height": 38
173 | },
174 | "id": "29QUbt_2YwlJ",
175 | "outputId": "3d9edb85-a44f-42cc-8776-bccda97cbc7e"
176 | },
177 | "outputs": [
178 | {
179 | "data": {
180 | "text/latex": [
181 | "$\\displaystyle y{\\left(0 \\right)} = 1$"
182 | ],
183 | "text/plain": [
184 | "Eq(y(0), 1)"
185 | ]
186 | },
187 | "execution_count": 22,
188 | "metadata": {},
189 | "output_type": "execute_result"
190 | }
191 | ],
192 | "source": [
193 | "sp.Eq(f(t).subs(t,0),1)"
194 | ]
195 | },
196 | {
197 | "cell_type": "code",
198 | "execution_count": 24,
199 | "metadata": {
200 | "colab": {
201 | "base_uri": "https://localhost:8080/",
202 | "height": 101
203 | },
204 | "id": "r9KVq1yjYfld",
205 | "outputId": "b6e56baa-b419-49e0-bba9-3706d96b394c"
206 | },
207 | "outputs": [
208 | {
209 | "data": {
210 | "text/latex": [
211 | "$\\displaystyle y{\\left(t \\right)} = \\frac{\\left(2 \\pi t e^{t} \\sin{\\left(2 \\pi t \\right)} + 8 \\pi^{3} t e^{t} \\sin{\\left(2 \\pi t \\right)} - 16 \\pi^{4} t e^{t} \\cos{\\left(2 \\pi t \\right)} - 4 \\pi^{2} t e^{t} \\cos{\\left(2 \\pi t \\right)} + 16 \\pi^{3} e^{t} \\sin{\\left(2 \\pi t \\right)} + e^{t} \\cos{\\left(2 \\pi t \\right)} + 12 \\pi^{2} e^{t} \\cos{\\left(2 \\pi t \\right)} - e^{t} + 36 \\pi^{2} e^{t} + 336 \\pi^{4} e^{t} + 704 \\pi^{6} e^{t} - 640 \\pi^{6} - 304 \\pi^{4} - 44 \\pi^{2}\\right) e^{- t}}{4 \\pi^{2} \\left(1 + 8 \\pi^{2} + 16 \\pi^{4}\\right)}$"
212 | ],
213 | "text/plain": [
214 | "Eq(y(t), (2*pi*t*exp(t)*sin(2*pi*t) + 8*pi**3*t*exp(t)*sin(2*pi*t) - 16*pi**4*t*exp(t)*cos(2*pi*t) - 4*pi**2*t*exp(t)*cos(2*pi*t) + 16*pi**3*exp(t)*sin(2*pi*t) + exp(t)*cos(2*pi*t) + 12*pi**2*exp(t)*cos(2*pi*t) - exp(t) + 36*pi**2*exp(t) + 336*pi**4*exp(t) + 704*pi**6*exp(t) - 640*pi**6 - 304*pi**4 - 44*pi**2)*exp(-t)/(4*pi**2*(1 + 8*pi**2 + 16*pi**4)))"
215 | ]
216 | },
217 | "execution_count": 24,
218 | "metadata": {},
219 | "output_type": "execute_result"
220 | }
221 | ],
222 | "source": [
223 | "sp.Eq(f(t),sol)"
224 | ]
225 | },
226 | {
227 | "cell_type": "code",
228 | "execution_count": 25,
229 | "metadata": {
230 | "colab": {
231 | "base_uri": "https://localhost:8080/",
232 | "height": 37
233 | },
234 | "id": "MNVOpPyCW-GU",
235 | "outputId": "c333d114-01de-4132-8ebf-effd19785112"
236 | },
237 | "outputs": [
238 | {
239 | "data": {
240 | "text/latex": [
241 | "$\\displaystyle 0$"
242 | ],
243 | "text/plain": [
244 | "0"
245 | ]
246 | },
247 | "execution_count": 25,
248 | "metadata": {},
249 | "output_type": "execute_result"
250 | }
251 | ],
252 | "source": [
253 | "#verify solution\n",
254 | "sp.simplify(-t*sp.cos(sp.pi*2*t)+sol.diff(t)+sol.diff(t,t))"
255 | ]
256 | },
257 | {
258 | "cell_type": "markdown",
259 | "metadata": {
260 | "id": "NQ61lEQeXgrc"
261 | },
262 | "source": [
263 | "### Constructing the MLP"
264 | ]
265 | },
266 | {
267 | "cell_type": "code",
268 | "execution_count": 61,
269 | "metadata": {
270 | "id": "Lml6PGLPZgmr"
271 | },
272 | "outputs": [],
273 | "source": [
274 | "N_b = 1\n",
275 | "N_c = 100\n",
276 | "\n",
277 | "tmin,tmax=0. ,jnp.pi\n",
278 | "\n",
279 | "'''boundary conditions'''\n",
280 | "\n",
281 | "\n",
282 | "# U[0] = 1\n",
283 | "t_0 = jnp.ones([N_b,1],dtype='float32')*0.\n",
284 | "ic_0 = jnp.ones_like(t_0) \n",
285 | "IC_0 = jnp.concatenate([t_0,ic_0],axis=1)\n",
286 | "\n",
287 | "# U_t[0] = 10\n",
288 | "t_b1 = jnp.zeros([N_b,1])\n",
289 | "bc_1 = jnp.ones_like(t_b1) * 10\n",
290 | "BC_1 = jnp.concatenate([t_b1,bc_1],axis=1)\n",
291 | "\n",
292 | "conds = [IC_0,BC_1]\n",
293 | "\n",
294 | "#collocation points\n",
295 | "\n",
296 | "key=jax.random.PRNGKey(0)\n",
297 | "\n",
298 | "t_c = jax.random.uniform(key,minval=tmin,maxval=tmax,shape=(N_c,1))\n",
299 | "colloc = t_c\n",
300 | "\n",
301 | "def ODE_loss(t,u):\n",
302 | " u_t=lambda t:jax.grad(lambda t:jnp.sum(u(t)))(t)\n",
303 | " u_tt=lambda t:jax.grad(lambda t : jnp.sum(u_t(t)))(t)\n",
304 | " return -t*jnp.cos(2*jnp.pi*t) + u_t(t) + u_tt(t)"
305 | ]
306 | },
307 | {
308 | "cell_type": "code",
309 | "execution_count": 68,
310 | "metadata": {
311 | "id": "KoZZJl2TbI_n"
312 | },
313 | "outputs": [],
314 | "source": [
315 | "def init_params(layers):\n",
316 | " keys = jax.random.split(jax.random.PRNGKey(0),len(layers)-1)\n",
317 | " params = list()\n",
318 | " for key,n_in,n_out in zip(keys,layers[:-1],layers[1:]):\n",
319 | " lb, ub = -(1 / jnp.sqrt(n_in)), (1 / jnp.sqrt(n_in)) # xavier initialization lower and upper bound\n",
320 | " W = lb + (ub-lb) * jax.random.uniform(key,shape=(n_in,n_out))\n",
321 | " B = jax.random.uniform(key,shape=(n_out,))\n",
322 | " params.append({'W':W,'B':B})\n",
323 | " return params\n",
324 | "\n",
325 | "def fwd(params,t):\n",
326 | " X = jnp.concatenate([t],axis=1)\n",
327 | " *hidden,last = params\n",
328 | " for layer in hidden :\n",
329 | " X = jax.nn.tanh(X@layer['W']+layer['B'])\n",
330 | " return X@last['W'] + last['B']\n",
331 | "\n",
332 | "@jax.jit\n",
333 | "def MSE(true,pred):\n",
334 | " return jnp.mean((true-pred)**2)\n",
335 | "\n",
336 | "def loss_fun(params,colloc,conds):\n",
337 | " t_c =colloc[:,[0]]\n",
338 | " ufunc = lambda t : fwd(params,t)\n",
339 | " ufunc_t=lambda t:jax.grad(lambda t:jnp.sum(ufunc(t)))(t)\n",
340 | " loss =jnp.mean(ODE_loss(t_c,ufunc) **2)\n",
341 | "\n",
342 | " t_ic,u_ic = conds[0][:,[0]],conds[0][:,[1]] \n",
343 | " loss += MSE(u_ic,ufunc(t_ic))\n",
344 | "\n",
345 | " t_bc,u_bc = conds[1][:,[0]],conds[1][:,[1]] \n",
346 | " loss += MSE(u_bc,ufunc_t(t_bc))\n",
347 | "\n",
348 | " return loss\n",
349 | "\n",
350 | "@jax.jit\n",
351 | "def update(opt_state,params,colloc,conds):\n",
352 | " # Get the gradient of the loss w.r.t. the MLP params\n",
353 | " grads=jax.jit(jax.grad(loss_fun,0))(params,colloc,conds)\n",
354 | " \n",
355 | " #Update params\n",
356 | " updates, opt_state = optimizer.update(grads, opt_state)\n",
357 | " params = optax.apply_updates(params, updates)\n",
358 | "\n",
359 | " # Alternative (kept for reference): plain gradient-descent update without optax\n",
360 | " # return jax.tree_multimap(lambda params,grads : params-LR*grads, params,grads)\n",
361 | " return opt_state,params\n"
362 | ]
363 | },
364 | {
365 | "cell_type": "code",
366 | "execution_count": 69,
367 | "metadata": {
368 | "id": "ae1ZDoy0c29c"
369 | },
370 | "outputs": [],
371 | "source": [
372 | "# construct the MLP with 4 hidden layers of 20 neurons each\n",
373 | "params = init_params([1] + [20]*4+[1])"
374 | ]
375 | },
376 | {
377 | "cell_type": "code",
378 | "execution_count": 70,
379 | "metadata": {
380 | "id": "jySmbUwic5yk"
381 | },
382 | "outputs": [],
383 | "source": [
384 | "lr = optax.piecewise_constant_schedule(1e-3,{10_000:5e-3,30_000:1e-3,50_000:5e-4,70_000:1e-4})\n",
385 | "optimizer = optax.adam(lr)\n",
386 | "opt_state = optimizer.init(params)"
387 | ]
388 | },
389 | {
390 | "cell_type": "code",
391 | "execution_count": 71,
392 | "metadata": {
393 | "colab": {
394 | "base_uri": "https://localhost:8080/"
395 | },
396 | "id": "kBzGA8OVc8C6",
397 | "outputId": "effacbde-f915-47cd-8f13-a44e6b697062"
398 | },
399 | "outputs": [
400 | {
401 | "name": "stdout",
402 | "output_type": "stream",
403 | "text": [
404 | "Epoch=0\tloss=1.026e+02\n",
405 | "Epoch=100\tloss=1.500e+01\n",
406 | "Epoch=200\tloss=7.508e+00\n",
407 | "Epoch=300\tloss=5.177e+00\n",
408 | "Epoch=400\tloss=3.458e+00\n",
409 | "Epoch=500\tloss=2.615e+00\n",
410 | "Epoch=600\tloss=2.391e+00\n",
411 | "Epoch=700\tloss=2.254e+00\n",
412 | "Epoch=800\tloss=2.133e+00\n",
413 | "Epoch=900\tloss=2.013e+00\n",
414 | "Epoch=1000\tloss=1.855e+00\n",
415 | "Epoch=1100\tloss=1.549e+00\n",
416 | "Epoch=1200\tloss=1.018e+00\n",
417 | "Epoch=1300\tloss=9.915e-01\n",
418 | "Epoch=1400\tloss=9.756e-01\n",
419 | "Epoch=1500\tloss=9.618e-01\n",
420 | "Epoch=1600\tloss=9.480e-01\n",
421 | "Epoch=1700\tloss=9.327e-01\n",
422 | "Epoch=1800\tloss=9.145e-01\n",
423 | "Epoch=1900\tloss=8.902e-01\n",
424 | "Epoch=2000\tloss=7.379e-01\n",
425 | "Epoch=2100\tloss=3.372e-01\n",
426 | "Epoch=2200\tloss=1.753e-02\n",
427 | "Epoch=2300\tloss=4.962e-03\n",
428 | "Epoch=2400\tloss=3.304e-03\n",
429 | "Epoch=2500\tloss=2.670e-03\n",
430 | "Epoch=2600\tloss=2.238e-03\n",
431 | "Epoch=2700\tloss=1.883e-03\n",
432 | "Epoch=2800\tloss=1.577e-03\n",
433 | "Epoch=2900\tloss=1.311e-03\n",
434 | "Epoch=3000\tloss=1.082e-03\n",
435 | "Epoch=3100\tloss=8.856e-04\n",
436 | "Epoch=3200\tloss=7.198e-04\n",
437 | "Epoch=3300\tloss=5.825e-04\n",
438 | "Epoch=3400\tloss=4.711e-04\n",
439 | "Epoch=3500\tloss=6.700e-04\n",
440 | "Epoch=3600\tloss=3.161e-04\n",
441 | "Epoch=3700\tloss=3.296e-04\n",
442 | "Epoch=3800\tloss=2.278e-04\n",
443 | "Epoch=3900\tloss=2.032e-04\n",
444 | "Epoch=4000\tloss=1.785e-04\n",
445 | "Epoch=4100\tloss=1.624e-04\n",
446 | "Epoch=4200\tloss=1.494e-04\n",
447 | "Epoch=4300\tloss=1.377e-04\n",
448 | "Epoch=4400\tloss=2.545e-04\n",
449 | "Epoch=4500\tloss=1.195e-04\n",
450 | "Epoch=4600\tloss=2.162e-03\n",
451 | "Epoch=4700\tloss=1.041e-04\n",
452 | "Epoch=4800\tloss=1.321e-04\n",
453 | "Epoch=4900\tloss=9.035e-05\n",
454 | "Epoch=5000\tloss=8.385e-05\n",
455 | "Epoch=5100\tloss=8.181e-05\n",
456 | "Epoch=5200\tloss=7.205e-05\n",
457 | "Epoch=5300\tloss=2.472e-03\n",
458 | "Epoch=5400\tloss=6.137e-05\n",
459 | "Epoch=5500\tloss=6.695e-05\n",
460 | "Epoch=5600\tloss=5.197e-05\n",
461 | "Epoch=5700\tloss=4.714e-05\n",
462 | "Epoch=5800\tloss=7.266e-05\n",
463 | "Epoch=5900\tloss=3.922e-05\n",
464 | "Epoch=6000\tloss=1.470e-03\n",
465 | "Epoch=6100\tloss=3.234e-05\n",
466 | "Epoch=6200\tloss=1.919e-03\n",
467 | "Epoch=6300\tloss=2.676e-05\n",
468 | "Epoch=6400\tloss=2.389e-05\n",
469 | "Epoch=6500\tloss=2.301e-05\n",
470 | "Epoch=6600\tloss=1.962e-05\n",
471 | "Epoch=6700\tloss=3.187e-04\n",
472 | "Epoch=6800\tloss=1.618e-05\n",
473 | "Epoch=6900\tloss=2.400e-03\n",
474 | "Epoch=7000\tloss=1.355e-05\n",
475 | "Epoch=7100\tloss=2.970e-04\n",
476 | "Epoch=7200\tloss=1.146e-05\n",
477 | "Epoch=7300\tloss=1.040e-05\n",
478 | "Epoch=7400\tloss=3.527e-05\n",
479 | "Epoch=7500\tloss=9.057e-06\n",
480 | "Epoch=7600\tloss=4.930e-04\n",
481 | "Epoch=7700\tloss=3.244e-04\n",
482 | "Epoch=7800\tloss=1.300e-03\n",
483 | "Epoch=7900\tloss=7.630e-06\n",
484 | "Epoch=8000\tloss=1.313e-05\n",
485 | "Epoch=8100\tloss=1.451e-05\n",
486 | "Epoch=8200\tloss=7.152e-06\n",
487 | "Epoch=8300\tloss=6.462e-05\n",
488 | "Epoch=8400\tloss=5.616e-03\n",
489 | "Epoch=8500\tloss=6.044e-06\n",
490 | "Epoch=8600\tloss=7.831e-05\n",
491 | "Epoch=8700\tloss=5.532e-06\n",
492 | "Epoch=8800\tloss=8.422e-05\n",
493 | "Epoch=8900\tloss=1.461e-03\n",
494 | "Epoch=9000\tloss=5.241e-06\n",
495 | "Epoch=9100\tloss=2.537e-04\n",
496 | "Epoch=9200\tloss=4.977e-06\n",
497 | "Epoch=9300\tloss=4.646e-05\n",
498 | "Epoch=9400\tloss=4.768e-06\n",
499 | "Epoch=9500\tloss=1.605e-05\n",
500 | "Epoch=9600\tloss=4.656e-06\n",
501 | "Epoch=9700\tloss=3.441e-03\n",
502 | "Epoch=9800\tloss=4.752e-06\n",
503 | "Epoch=9900\tloss=6.980e-03\n",
504 | "CPU times: user 17.7 s, sys: 139 ms, total: 17.8 s\n",
505 | "Wall time: 17.7 s\n"
506 | ]
507 | }
508 | ],
509 | "source": [
510 | "%%time\n",
511 | "epochs = 10_000\n",
512 | "for _ in range(epochs):\n",
513 | " opt_state,params = update(opt_state,params,colloc,conds)\n",
514 | "\n",
515 | " # print loss and epoch info\n",
516 | " if _ %(100) ==0:\n",
517 | " print(f'Epoch={_}\\tloss={loss_fun(params,colloc,conds):.3e}')"
518 | ]
519 | },
520 | {
521 | "cell_type": "code",
522 | "execution_count": 80,
523 | "metadata": {
524 | "colab": {
525 | "base_uri": "https://localhost:8080/",
526 | "height": 282
527 | },
528 | "id": "eWeNvDsdDEuI",
529 | "outputId": "32551eeb-25df-4d2e-8cae-82cc52b41ac5"
530 | },
531 | "outputs": [
532 | {
533 | "data": {
534 | "text/plain": [
535 | ""
536 | ]
537 | },
538 | "execution_count": 80,
539 | "metadata": {},
540 | "output_type": "execute_result"
541 | },
542 | {
543 | "data": {
544 | "image/png": "",
545 | "text/plain": [
546 | ""
547 | ]
548 | },
549 | "metadata": {
550 | "needs_background": "light"
551 | },
552 | "output_type": "display_data"
553 | }
554 | ],
555 | "source": [
556 | "lam_sol= sp.lambdify(t,sol)\n",
557 | "\n",
558 | "dT = 1e-3\n",
559 | "Tf = jnp.pi\n",
560 | "T = np.arange(0,Tf+dT,dT)\n",
561 | "\n",
562 | "\n",
563 | "sym_sol =np.array([lam_sol(i) for i in T])\n",
564 | "\n",
565 | "plt.plot(T,sym_sol,'--r',label='sympy solution')\n",
566 | "plt.plot(T,fwd(params,T.reshape(-1,1))[:,0],'--k',label='NN solution')\n",
567 | "plt.legend()"
568 | ]
569 | }
570 | ],
571 | "metadata": {
572 | "colab": {
573 | "authorship_tag": "ABX9TyPlDK/9ZvMulH91+B+B32BC",
574 | "collapsed_sections": [],
575 | "include_colab_link": true,
576 | "name": "[1] ODE-PINN.ipynb",
577 | "provenance": []
578 | },
579 | "kernelspec": {
580 | "display_name": "Python 3.8.9 64-bit",
581 | "language": "python",
582 | "name": "python3"
583 | },
584 | "language_info": {
585 | "codemirror_mode": {
586 | "name": "ipython",
587 | "version": 3
588 | },
589 | "file_extension": ".py",
590 | "mimetype": "text/x-python",
591 | "name": "python",
592 | "nbconvert_exporter": "python",
593 | "pygments_lexer": "ipython3",
594 | "version": "3.8.9"
595 | },
596 | "vscode": {
597 | "interpreter": {
598 | "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
599 | }
600 | }
601 | },
602 | "nbformat": 4,
603 | "nbformat_minor": 4
604 | }
605 |
--------------------------------------------------------------------------------
/[1]_ODE_PINN_ClassForm.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "view-in-github",
7 | "colab_type": "text"
8 | },
9 | "source": [
10 | "
"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "source": [
16 | "!pip install optax\n",
17 | "!pip install pytreeclass\n",
18 | "!pip install tqdm"
19 | ],
20 | "metadata": {
21 | "id": "PmiMsCTpOtKE",
22 | "colab": {
23 | "base_uri": "https://localhost:8080/"
24 | },
25 | "outputId": "673af120-a5ba-45a4-953f-e77dee359c09"
26 | },
27 | "execution_count": 1,
28 | "outputs": [
29 | {
30 | "output_type": "stream",
31 | "name": "stdout",
32 | "text": [
33 | "Requirement already satisfied: optax in /usr/local/lib/python3.10/dist-packages (0.1.7)\n",
34 | "Requirement already satisfied: absl-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from optax) (1.4.0)\n",
35 | "Requirement already satisfied: chex>=0.1.5 in /usr/local/lib/python3.10/dist-packages (from optax) (0.1.7)\n",
36 | "Requirement already satisfied: jax>=0.1.55 in /usr/local/lib/python3.10/dist-packages (from optax) (0.4.14)\n",
37 | "Requirement already satisfied: jaxlib>=0.1.37 in /usr/local/lib/python3.10/dist-packages (from optax) (0.4.14+cuda11.cudnn86)\n",
38 | "Requirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.10/dist-packages (from optax) (1.23.5)\n",
39 | "Requirement already satisfied: dm-tree>=0.1.5 in /usr/local/lib/python3.10/dist-packages (from chex>=0.1.5->optax) (0.1.8)\n",
40 | "Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.10/dist-packages (from chex>=0.1.5->optax) (0.12.0)\n",
41 | "Requirement already satisfied: typing-extensions>=4.2.0 in /usr/local/lib/python3.10/dist-packages (from chex>=0.1.5->optax) (4.7.1)\n",
42 | "Requirement already satisfied: ml-dtypes>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from jax>=0.1.55->optax) (0.2.0)\n",
43 | "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.10/dist-packages (from jax>=0.1.55->optax) (3.3.0)\n",
44 | "Requirement already satisfied: scipy>=1.7 in /usr/local/lib/python3.10/dist-packages (from jax>=0.1.55->optax) (1.10.1)\n",
45 | "Requirement already satisfied: pytreeclass in /usr/local/lib/python3.10/dist-packages (0.6.0)\n",
46 | "Requirement already satisfied: jax>=0.4.7 in /usr/local/lib/python3.10/dist-packages (from pytreeclass) (0.4.14)\n",
47 | "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from pytreeclass) (4.7.1)\n",
48 | "Requirement already satisfied: ml-dtypes>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from jax>=0.4.7->pytreeclass) (0.2.0)\n",
49 | "Requirement already satisfied: numpy>=1.22 in /usr/local/lib/python3.10/dist-packages (from jax>=0.4.7->pytreeclass) (1.23.5)\n",
50 | "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.10/dist-packages (from jax>=0.4.7->pytreeclass) (3.3.0)\n",
51 | "Requirement already satisfied: scipy>=1.7 in /usr/local/lib/python3.10/dist-packages (from jax>=0.4.7->pytreeclass) (1.10.1)\n",
52 | "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (4.66.1)\n"
53 | ]
54 | }
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": 2,
60 | "metadata": {
61 | "id": "vAR0swbLX_ZI"
62 | },
63 | "outputs": [],
64 | "source": [
65 | "# Imports\n",
66 | "from __future__ import annotations\n",
67 | "from typing import Callable\n",
68 | "import jax\n",
69 | "import jax.numpy as jnp\n",
70 | "import numpy as np\n",
71 | "import matplotlib.pyplot as plt\n",
72 | "import optax\n",
73 | "import sympy as sp\n",
74 | "import pytreeclass as pytc\n",
75 | "from tqdm.notebook import tqdm"
76 | ]
77 | },
78 | {
79 | "cell_type": "markdown",
80 | "metadata": {
81 | "id": "7bg4nSbsXVwD"
82 | },
83 | "source": [
84 | "### Generate a differential equation and its solution using SymPy"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": 3,
90 | "metadata": {
91 | "id": "P9664e-mVMTN"
92 | },
93 | "outputs": [],
94 | "source": [
95 | "t= sp.symbols('t')\n",
96 | "f = sp.Function('y')\n",
97 | "diffeq = sp.Eq(f(t).diff(t,t) + f(t).diff(t)-t*sp.cos(2*sp.pi*t),0)\n",
98 | "sol = sp.simplify(sp.dsolve(diffeq,ics={f(0):1,f(t).diff(t).subs(t,0):10}).rhs)"
99 | ]
100 | },
101 | {
102 | "cell_type": "code",
103 | "execution_count": 4,
104 | "metadata": {
105 | "id": "klgFeU6bcTrC",
106 | "colab": {
107 | "base_uri": "https://localhost:8080/",
108 | "height": 54
109 | },
110 | "outputId": "060d7276-cb05-418f-8be9-91e7a671a8bb"
111 | },
112 | "outputs": [
113 | {
114 | "output_type": "execute_result",
115 | "data": {
116 | "text/plain": [
117 | "Eq(-t*cos(2*pi*t) + Derivative(y(t), t) + Derivative(y(t), (t, 2)), 0)"
118 | ],
119 | "text/latex": "$\\displaystyle - t \\cos{\\left(2 \\pi t \\right)} + \\frac{d}{d t} y{\\left(t \\right)} + \\frac{d^{2}}{d t^{2}} y{\\left(t \\right)} = 0$"
120 | },
121 | "metadata": {},
122 | "execution_count": 4
123 | }
124 | ],
125 | "source": [
126 | "diffeq"
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": 5,
132 | "metadata": {
133 | "id": "E4Uu2hbiYJtv",
134 | "colab": {
135 | "base_uri": "https://localhost:8080/",
136 | "height": 61
137 | },
138 | "outputId": "d66ce7f6-f1b4-4667-faa7-8708227ff805"
139 | },
140 | "outputs": [
141 | {
142 | "output_type": "execute_result",
143 | "data": {
144 | "text/plain": [
145 | "Eq(Subs(Derivative(y(t), t), t, 0), 10)"
146 | ],
147 | "text/latex": "$\\displaystyle \\left. \\frac{d}{d t} y{\\left(t \\right)} \\right|_{\\substack{ t=0 }} = 10$"
148 | },
149 | "metadata": {},
150 | "execution_count": 5
151 | }
152 | ],
153 | "source": [
154 | "sp.Eq(f(t).diff(t).subs(t,0),10)"
155 | ]
156 | },
157 | {
158 | "cell_type": "code",
159 | "execution_count": 6,
160 | "metadata": {
161 | "id": "29QUbt_2YwlJ",
162 | "colab": {
163 | "base_uri": "https://localhost:8080/",
164 | "height": 39
165 | },
166 | "outputId": "abd0aca6-9156-49f1-fba3-9349e0b5fe9d"
167 | },
168 | "outputs": [
169 | {
170 | "output_type": "execute_result",
171 | "data": {
172 | "text/plain": [
173 | "Eq(y(0), 1)"
174 | ],
175 | "text/latex": "$\\displaystyle y{\\left(0 \\right)} = 1$"
176 | },
177 | "metadata": {},
178 | "execution_count": 6
179 | }
180 | ],
181 | "source": [
182 | "sp.Eq(f(t).subs(t,0),1)"
183 | ]
184 | },
185 | {
186 | "cell_type": "code",
187 | "execution_count": 7,
188 | "metadata": {
189 | "id": "r9KVq1yjYfld",
190 | "colab": {
191 | "base_uri": "https://localhost:8080/",
192 | "height": 82
193 | },
194 | "outputId": "c8ef8212-caed-48a7-fbe2-1a0bc700570e"
195 | },
196 | "outputs": [
197 | {
198 | "output_type": "execute_result",
199 | "data": {
200 | "text/plain": [
201 | "Eq(y(t), (2*pi*t*exp(t)*sin(2*pi*t) + 8*pi**3*t*exp(t)*sin(2*pi*t) - 16*pi**4*t*exp(t)*cos(2*pi*t) - 4*pi**2*t*exp(t)*cos(2*pi*t) + 16*pi**3*exp(t)*sin(2*pi*t) + exp(t)*cos(2*pi*t) + 12*pi**2*exp(t)*cos(2*pi*t) - exp(t) + 36*pi**2*exp(t) + 336*pi**4*exp(t) + 704*pi**6*exp(t) - 640*pi**6 - 304*pi**4 - 44*pi**2)*exp(-t)/(4*pi**2*(1 + 8*pi**2 + 16*pi**4)))"
202 | ],
203 | "text/latex": "$\\displaystyle y{\\left(t \\right)} = \\frac{\\left(2 \\pi t e^{t} \\sin{\\left(2 \\pi t \\right)} + 8 \\pi^{3} t e^{t} \\sin{\\left(2 \\pi t \\right)} - 16 \\pi^{4} t e^{t} \\cos{\\left(2 \\pi t \\right)} - 4 \\pi^{2} t e^{t} \\cos{\\left(2 \\pi t \\right)} + 16 \\pi^{3} e^{t} \\sin{\\left(2 \\pi t \\right)} + e^{t} \\cos{\\left(2 \\pi t \\right)} + 12 \\pi^{2} e^{t} \\cos{\\left(2 \\pi t \\right)} - e^{t} + 36 \\pi^{2} e^{t} + 336 \\pi^{4} e^{t} + 704 \\pi^{6} e^{t} - 640 \\pi^{6} - 304 \\pi^{4} - 44 \\pi^{2}\\right) e^{- t}}{4 \\pi^{2} \\cdot \\left(1 + 8 \\pi^{2} + 16 \\pi^{4}\\right)}$"
204 | },
205 | "metadata": {},
206 | "execution_count": 7
207 | }
208 | ],
209 | "source": [
210 | "sp.Eq(f(t),sol)"
211 | ]
212 | },
213 | {
214 | "cell_type": "code",
215 | "execution_count": 8,
216 | "metadata": {
217 | "id": "MNVOpPyCW-GU",
218 | "colab": {
219 | "base_uri": "https://localhost:8080/",
220 | "height": 37
221 | },
222 | "outputId": "a477b9e6-cad4-4584-e926-d3e8c9538160"
223 | },
224 | "outputs": [
225 | {
226 | "output_type": "execute_result",
227 | "data": {
228 | "text/plain": [
229 | "0"
230 | ],
231 | "text/latex": "$\\displaystyle 0$"
232 | },
233 | "metadata": {},
234 | "execution_count": 8
235 | }
236 | ],
237 | "source": [
238 | "#verify solution\n",
239 | "sp.simplify(-t*sp.cos(sp.pi*2*t)+sol.diff(t)+sol.diff(t,t))"
240 | ]
241 | },
242 | {
243 | "cell_type": "markdown",
244 | "metadata": {
245 | "id": "NQ61lEQeXgrc"
246 | },
247 | "source": [
248 | "### Constructing the MLP"
249 | ]
250 | },
251 | {
252 | "cell_type": "code",
253 | "execution_count": 9,
254 | "metadata": {
255 | "id": "Lml6PGLPZgmr",
256 | "colab": {
257 | "base_uri": "https://localhost:8080/"
258 | },
259 | "outputId": "ad72cbc0-73f6-4a44-d384-3e2de933faf2"
260 | },
261 | "outputs": [
262 | {
263 | "output_type": "stream",
264 | "name": "stderr",
265 | "text": [
266 | "WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
267 | ]
268 | }
269 | ],
270 | "source": [
271 | "# construct data\n",
272 | "\n",
273 | "N_b = 1\n",
274 | "N_c = 100\n",
275 | "\n",
276 | "tmin, tmax = 0.0, jnp.pi\n",
277 | "\n",
278 | "\"\"\"boundary conditions\"\"\"\n",
279 | "\n",
280 | "\n",
281 | "# U[0] = 1\n",
282 | "t_0 = jnp.ones([N_b, 1], dtype=\"float32\") * 0.0\n",
283 | "ic_0 = jnp.ones_like(t_0)\n",
284 | "IC_0 = jnp.concatenate([t_0, ic_0], axis=1)\n",
285 | "\n",
286 | "# U_t[0] = 10\n",
287 | "t_b1 = jnp.zeros([N_b, 1])\n",
288 | "bc_1 = jnp.ones_like(t_b1) * 10\n",
289 | "BC_1 = jnp.concatenate([t_b1, bc_1], axis=1)\n",
290 | "\n",
291 | "conds: list[jax.Array] = [IC_0, BC_1]\n",
292 | "\n",
293 | "# collocation points\n",
294 | "\n",
295 | "key = jax.random.PRNGKey(0)\n",
296 | "\n",
297 | "t_c = jax.random.uniform(key, minval=tmin, maxval=tmax, shape=(N_c, 1))\n",
298 | "colloc = t_c"
299 | ]
300 | },
301 | {
302 | "cell_type": "markdown",
303 | "source": [
304 | "# Build Model"
305 | ],
306 | "metadata": {
307 | "id": "V_gR3d4AKHxJ"
308 | }
309 | },
310 | {
311 | "cell_type": "code",
312 | "source": [
313 | "init_func = jax.nn.initializers.glorot_uniform()\n",
314 | "\n",
315 | "\n",
316 | "class Linear(pytc.TreeClass):\n",
317 | " def __init__(\n",
318 | " self,\n",
319 | " in_features: int,\n",
320 | " out_features: int,\n",
321 | " key: jax.random.KeyArray = jax.random.PRNGKey(0),\n",
322 | " ):\n",
323 | " self.weight = init_func(key, (in_features, out_features))\n",
324 | " self.bias = jax.numpy.zeros((out_features,))\n",
325 | "\n",
326 | " def __call__(self, x: jax.Array) -> jax.Array:\n",
327 | " return x @ self.weight + self.bias\n",
328 | "\n",
329 | "\n",
330 | "class MLP(pytc.TreeClass):\n",
331 | " def __init__(self, key: jax.random.KeyArray = jax.random.PRNGKey(0)):\n",
332 | " k1, k2, k3, k4 = jax.random.split(key, 4)\n",
333 | " self.l1 = Linear(1, 20, key=k1)\n",
334 | " self.l2 = Linear(20, 20, key=k2)\n",
335 | " self.l3 = Linear(20, 20, key=k3)\n",
336 | " self.l4 = Linear(20, 1, key=k4)\n",
337 | "\n",
338 | " def __call__(self, x: jax.Array) -> jax.Array:\n",
339 | " x = self.l1(x)\n",
340 | " x = jax.nn.tanh(x)\n",
341 | " x = self.l2(x)\n",
342 | " x = jax.nn.tanh(x)\n",
343 | " x = self.l3(x)\n",
344 | " x = jax.nn.tanh(x)\n",
345 | " x = self.l4(x)\n",
346 | " return x\n",
347 | "\n",
348 | "\n",
349 | "model = MLP()\n",
350 | "print(pytc.tree_summary(model))"
351 | ],
352 | "metadata": {
353 | "id": "TD7IQp70F65_",
354 | "colab": {
355 | "base_uri": "https://localhost:8080/"
356 | },
357 | "outputId": "0fdcd611-098e-4b46-b27f-33d0da1c725d"
358 | },
359 | "execution_count": 10,
360 | "outputs": [
361 | {
362 | "output_type": "stream",
363 | "name": "stdout",
364 | "text": [
365 | "┌──────────┬──────────┬─────┬──────┐\n",
366 | "│Name │Type │Count│Size │\n",
367 | "├──────────┼──────────┼─────┼──────┤\n",
368 | "│.l1.weight│f32[1,20] │20 │80.00B│\n",
369 | "├──────────┼──────────┼─────┼──────┤\n",
370 | "│.l1.bias │f32[20] │20 │80.00B│\n",
371 | "├──────────┼──────────┼─────┼──────┤\n",
372 | "│.l2.weight│f32[20,20]│400 │1.56KB│\n",
373 | "├──────────┼──────────┼─────┼──────┤\n",
374 | "│.l2.bias │f32[20] │20 │80.00B│\n",
375 | "├──────────┼──────────┼─────┼──────┤\n",
376 | "│.l3.weight│f32[20,20]│400 │1.56KB│\n",
377 | "├──────────┼──────────┼─────┼──────┤\n",
378 | "│.l3.bias │f32[20] │20 │80.00B│\n",
379 | "├──────────┼──────────┼─────┼──────┤\n",
380 | "│.l4.weight│f32[20,1] │20 │80.00B│\n",
381 | "├──────────┼──────────┼─────┼──────┤\n",
382 | "│.l4.bias │f32[1] │1 │4.00B │\n",
383 | "├──────────┼──────────┼─────┼──────┤\n",
384 | "│Σ │MLP │901 │3.52KB│\n",
385 | "└──────────┴──────────┴─────┴──────┘\n"
386 | ]
387 | }
388 | ]
389 | },
390 | {
391 | "cell_type": "code",
392 | "execution_count": 11,
393 | "metadata": {
394 | "id": "KoZZJl2TbI_n"
395 | },
396 | "outputs": [],
397 | "source": [
398 | "def mse(true, pred):\n",
399 | " return jnp.mean((true - pred) ** 2)\n",
400 | "\n",
401 | "\n",
402 | "def diff(func: Callable, *args, **kwargs):\n",
403 | " \"\"\"Return the gradient of the summed output of `func`; extra arguments are forwarded to `jax.grad`.\"\"\"\n",
404 | " return jax.grad(lambda *ar, **kws: jnp.sum(func(*ar, **kws)), *args, **kwargs)\n",
405 | "\n",
406 | "\n",
407 | "def ode_loss(t, u):\n",
408 | " u_t = diff(u)\n",
409 | " u_tt = diff(u_t)\n",
410 | " return -t * jnp.cos(2 * jnp.pi * t) + u_t(t) + u_tt(t)\n",
411 | "\n",
412 | "\n",
413 | "def loss_func(model, colloc, conds):\n",
414 | " t_c = colloc[:, [0]]\n",
415 | " ufunc = model\n",
416 | " ufunc_t = diff(model)\n",
417 | "\n",
418 | " loss = jnp.mean(ode_loss(t_c, ufunc) ** 2)\n",
419 | "\n",
420 | " t_ic, u_ic = conds[0][:, [0]], conds[0][:, [1]]\n",
421 | " loss += mse(u_ic, ufunc(t_ic))\n",
422 | "\n",
423 | " t_bc, u_bc = conds[1][:, [0]], conds[1][:, [1]]\n",
424 | " loss += mse(u_bc, ufunc_t(t_bc))\n",
425 | "\n",
426 | " return loss\n",
427 | "\n",
428 | "\n",
429 | "optim = optax.adam(1e-3)\n",
430 | "optim_state = optim.init(model)\n",
431 | "\n",
432 | "\n",
433 | "@jax.jit\n",
434 | "def train_step(\n",
435 | " model: MLP,\n",
436 | " optim_state: optax.OptState,\n",
437 | " colloc: jax.Array,\n",
438 | " conds: list[jax.Array],\n",
439 | "):\n",
440 | " # Get the gradient of the loss w.r.t. the MLP params\n",
441 | " grads = jax.grad(loss_func)(model, colloc, conds)\n",
442 | "\n",
443 | " # Update model\n",
444 | " updates, optim_state = optim.update(grads, optim_state)\n",
445 | " model = optax.apply_updates(model, updates)\n",
446 | "\n",
447 | " return model, optim_state"
448 | ]
449 | },
450 | {
451 | "cell_type": "code",
452 | "execution_count": 12,
453 | "metadata": {
454 | "id": "kBzGA8OVc8C6",
455 | "colab": {
456 | "base_uri": "https://localhost:8080/",
457 | "height": 1000,
458 | "referenced_widgets": [
459 | "271df676bcac4937bada71edccf32886",
460 | "ee8a741f598f4dc7b751f21b2e975bf7",
461 | "cad59966e8754537bfb7f8ef5d8289ac",
462 | "949e74fbeb5d4892a5b3b47cacfde5c5",
463 | "8189284ea2fc475192db86f14b63580a",
464 | "2285fe8ee81e45229e30ddd255357255",
465 | "5f82eb29bbd64344b93a1f9565397184",
466 | "b47d9843e3894611a631df2b46cc9ac0",
467 | "0df7710f34444afbba03b6232de3e1fd",
468 | "e193d91f0e414ec7b527d9c0ebd47558",
469 | "43561c163e1f4dec929eac0ead24d266"
470 | ]
471 | },
472 | "outputId": "8e5ce4f3-6102-435b-ec4b-62d519d6178d"
473 | },
474 | "outputs": [
475 | {
476 | "output_type": "display_data",
477 | "data": {
478 | "text/plain": [
479 | " 0%| | 0/10000 [00:00, ?it/s]"
480 | ],
481 | "application/vnd.jupyter.widget-view+json": {
482 | "version_major": 2,
483 | "version_minor": 0,
484 | "model_id": "271df676bcac4937bada71edccf32886"
485 | }
486 | },
487 | "metadata": {}
488 | },
489 | {
490 | "output_type": "stream",
491 | "name": "stdout",
492 | "text": [
493 | "Epoch=100\tloss=7.777e+00\n",
494 | "Epoch=200\tloss=4.808e+00\n",
495 | "Epoch=300\tloss=3.418e+00\n",
496 | "Epoch=400\tloss=2.729e+00\n",
497 | "Epoch=500\tloss=2.319e+00\n",
498 | "Epoch=600\tloss=2.135e+00\n",
499 | "Epoch=700\tloss=2.059e+00\n",
500 | "Epoch=800\tloss=2.009e+00\n",
501 | "Epoch=900\tloss=1.959e+00\n",
502 | "Epoch=1000\tloss=1.902e+00\n",
503 | "Epoch=1100\tloss=1.841e+00\n",
504 | "Epoch=1200\tloss=1.779e+00\n",
505 | "Epoch=1300\tloss=1.719e+00\n",
506 | "Epoch=1400\tloss=1.666e+00\n",
507 | "Epoch=1500\tloss=1.620e+00\n",
508 | "Epoch=1600\tloss=1.566e+00\n",
509 | "Epoch=1700\tloss=1.475e+00\n",
510 | "Epoch=1800\tloss=1.291e+00\n",
511 | "Epoch=1900\tloss=1.112e+00\n",
512 | "Epoch=2000\tloss=8.120e-01\n",
513 | "Epoch=2100\tloss=1.764e-01\n",
514 | "Epoch=2200\tloss=1.382e-02\n",
515 | "Epoch=2300\tloss=5.528e-03\n",
516 | "Epoch=2400\tloss=4.135e-03\n",
517 | "Epoch=2500\tloss=3.343e-03\n",
518 | "Epoch=2600\tloss=2.765e-03\n",
519 | "Epoch=2700\tloss=2.345e-03\n",
520 | "Epoch=2800\tloss=2.044e-03\n",
521 | "Epoch=2900\tloss=1.826e-03\n",
522 | "Epoch=3000\tloss=1.662e-03\n",
523 | "Epoch=3100\tloss=1.533e-03\n",
524 | "Epoch=3200\tloss=1.427e-03\n",
525 | "Epoch=3300\tloss=1.335e-03\n",
526 | "Epoch=3400\tloss=1.254e-03\n",
527 | "Epoch=3500\tloss=1.180e-03\n",
528 | "Epoch=3600\tloss=1.113e-03\n",
529 | "Epoch=3700\tloss=1.050e-03\n",
530 | "Epoch=3800\tloss=9.922e-04\n",
531 | "Epoch=3900\tloss=9.382e-04\n",
532 | "Epoch=4000\tloss=8.878e-04\n",
533 | "Epoch=4100\tloss=8.411e-04\n",
534 | "Epoch=4200\tloss=7.969e-04\n",
535 | "Epoch=4300\tloss=7.584e-04\n",
536 | "Epoch=4400\tloss=7.174e-04\n",
537 | "Epoch=4500\tloss=6.867e-04\n",
538 | "Epoch=4600\tloss=6.469e-04\n",
539 | "Epoch=4700\tloss=6.162e-04\n",
540 | "Epoch=4800\tloss=5.888e-04\n",
541 | "Epoch=4900\tloss=5.540e-04\n",
542 | "Epoch=5000\tloss=5.256e-04\n",
543 | "Epoch=5100\tloss=4.983e-04\n",
544 | "Epoch=5200\tloss=4.781e-04\n",
545 | "Epoch=5300\tloss=4.465e-04\n",
546 | "Epoch=5400\tloss=4.320e-04\n",
547 | "Epoch=5500\tloss=3.982e-04\n",
548 | "Epoch=5600\tloss=4.065e-04\n",
549 | "Epoch=5700\tloss=7.689e-04\n",
550 | "Epoch=5800\tloss=3.319e-04\n",
551 | "Epoch=5900\tloss=3.510e-04\n",
552 | "Epoch=6000\tloss=2.920e-04\n",
553 | "Epoch=6100\tloss=2.819e-04\n",
554 | "Epoch=6200\tloss=2.558e-04\n",
555 | "Epoch=6300\tloss=2.390e-04\n",
556 | "Epoch=6400\tloss=2.541e-04\n",
557 | "Epoch=6500\tloss=2.083e-04\n",
558 | "Epoch=6600\tloss=2.044e-04\n",
559 | "Epoch=6700\tloss=1.817e-04\n",
560 | "Epoch=6800\tloss=1.796e-04\n",
561 | "Epoch=6900\tloss=1.590e-04\n",
562 | "Epoch=7000\tloss=1.675e-04\n",
563 | "Epoch=7100\tloss=1.400e-04\n",
564 | "Epoch=7200\tloss=4.347e-04\n",
565 | "Epoch=7300\tloss=1.242e-04\n",
566 | "Epoch=7400\tloss=2.232e-03\n",
567 | "Epoch=7500\tloss=1.114e-04\n",
568 | "Epoch=7600\tloss=1.159e-04\n",
569 | "Epoch=7700\tloss=1.010e-04\n",
570 | "Epoch=7800\tloss=4.399e-04\n",
571 | "Epoch=7900\tloss=9.291e-05\n",
572 | "Epoch=8000\tloss=1.408e-04\n",
573 | "Epoch=8100\tloss=8.608e-05\n",
574 | "Epoch=8200\tloss=1.010e-04\n",
575 | "Epoch=8300\tloss=6.762e-04\n",
576 | "Epoch=8400\tloss=7.779e-05\n",
577 | "Epoch=8500\tloss=8.429e-05\n",
578 | "Epoch=8600\tloss=2.236e-04\n",
579 | "Epoch=8700\tloss=7.161e-05\n",
580 | "Epoch=8800\tloss=8.281e-05\n",
581 | "Epoch=8900\tloss=4.630e-04\n",
582 | "Epoch=9000\tloss=9.161e-05\n",
583 | "Epoch=9100\tloss=3.809e-04\n",
584 | "Epoch=9200\tloss=6.414e-05\n",
585 | "Epoch=9300\tloss=6.829e-05\n",
586 | "Epoch=9400\tloss=2.783e-04\n",
587 | "Epoch=9500\tloss=1.557e-04\n",
588 | "Epoch=9600\tloss=5.868e-05\n",
589 | "Epoch=9700\tloss=8.793e-05\n",
590 | "Epoch=9800\tloss=5.920e-05\n",
591 | "Epoch=9900\tloss=5.596e-05\n",
592 | "Epoch=10000\tloss=1.025e-03\n",
593 | "CPU times: user 19.7 s, sys: 198 ms, total: 19.9 s\n",
594 | "Wall time: 29.6 s\n"
595 | ]
596 | }
597 | ],
598 | "source": [
599 | "%%time\n",
600 | "epochs = 10_000\n",
601 | "for _ in tqdm(range(1,epochs+1)):\n",
602 | " model,optim_state = train_step(model, optim_state,colloc,conds)\n",
603 | "\n",
604 | " # print loss and epoch info\n",
605 | " if _ %(100) ==0:\n",
606 | " print(f'Epoch={_}\\tloss={loss_func(model,colloc,conds):.3e}')"
607 | ]
608 | },
609 | {
610 | "cell_type": "code",
611 | "source": [
612 | "lam_sol= sp.lambdify(t,sol)\n",
613 | "\n",
614 | "dT = 1e-3\n",
615 | "Tf = jnp.pi\n",
616 | "T = np.arange(0,Tf+dT,dT)\n",
617 | "\n",
618 | "\n",
619 | "sym_sol =np.array([lam_sol(i) for i in T])\n",
620 | "\n",
621 | "plt.plot(T,sym_sol,'--r',label='sympy solution')\n",
622 | "plt.plot(T,model(T.reshape(-1,1))[:,0],'--k',label='NN solution')\n",
623 | "plt.legend()"
624 | ],
625 | "metadata": {
626 | "colab": {
627 | "base_uri": "https://localhost:8080/",
628 | "height": 447
629 | },
630 | "id": "hIkxKojKR3A3",
631 | "outputId": "ffc9bf75-3c57-4494-cafd-156d5faf4cd5"
632 | },
633 | "execution_count": 13,
634 | "outputs": [
635 | {
636 | "output_type": "execute_result",
637 | "data": {
638 | "text/plain": [
639 | ""
640 | ]
641 | },
642 | "metadata": {},
643 | "execution_count": 13
644 | },
645 | {
646 | "output_type": "display_data",
647 | "data": {
648 | "text/plain": [
649 | ""
650 | ],
651 | "image/png": "\n"
652 | },
653 | "metadata": {}
654 | }
655 | ]
656 | }
657 | ],
658 | "metadata": {
659 | "colab": {
660 | "name": "[1] ODE-PINN-ClassForm.ipynb",
661 | "provenance": [],
662 | "include_colab_link": true
663 | },
664 | "kernelspec": {
665 | "display_name": "Python 3.8.9 64-bit",
666 | "language": "python",
667 | "name": "python3"
668 | },
669 | "language_info": {
670 | "codemirror_mode": {
671 | "name": "ipython",
672 | "version": 3
673 | },
674 | "file_extension": ".py",
675 | "mimetype": "text/x-python",
676 | "name": "python",
677 | "nbconvert_exporter": "python",
678 | "pygments_lexer": "ipython3",
679 | "version": "3.8.9"
680 | },
681 | "vscode": {
682 | "interpreter": {
683 | "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
684 | }
685 | },
686 | "widgets": {
687 | "application/vnd.jupyter.widget-state+json": {
688 | "271df676bcac4937bada71edccf32886": {
689 | "model_module": "@jupyter-widgets/controls",
690 | "model_name": "HBoxModel",
691 | "model_module_version": "1.5.0",
692 | "state": {
693 | "_dom_classes": [],
694 | "_model_module": "@jupyter-widgets/controls",
695 | "_model_module_version": "1.5.0",
696 | "_model_name": "HBoxModel",
697 | "_view_count": null,
698 | "_view_module": "@jupyter-widgets/controls",
699 | "_view_module_version": "1.5.0",
700 | "_view_name": "HBoxView",
701 | "box_style": "",
702 | "children": [
703 | "IPY_MODEL_ee8a741f598f4dc7b751f21b2e975bf7",
704 | "IPY_MODEL_cad59966e8754537bfb7f8ef5d8289ac",
705 | "IPY_MODEL_949e74fbeb5d4892a5b3b47cacfde5c5"
706 | ],
707 | "layout": "IPY_MODEL_8189284ea2fc475192db86f14b63580a"
708 | }
709 | },
710 | "ee8a741f598f4dc7b751f21b2e975bf7": {
711 | "model_module": "@jupyter-widgets/controls",
712 | "model_name": "HTMLModel",
713 | "model_module_version": "1.5.0",
714 | "state": {
715 | "_dom_classes": [],
716 | "_model_module": "@jupyter-widgets/controls",
717 | "_model_module_version": "1.5.0",
718 | "_model_name": "HTMLModel",
719 | "_view_count": null,
720 | "_view_module": "@jupyter-widgets/controls",
721 | "_view_module_version": "1.5.0",
722 | "_view_name": "HTMLView",
723 | "description": "",
724 | "description_tooltip": null,
725 | "layout": "IPY_MODEL_2285fe8ee81e45229e30ddd255357255",
726 | "placeholder": "",
727 | "style": "IPY_MODEL_5f82eb29bbd64344b93a1f9565397184",
728 | "value": "100%"
729 | }
730 | },
731 | "cad59966e8754537bfb7f8ef5d8289ac": {
732 | "model_module": "@jupyter-widgets/controls",
733 | "model_name": "FloatProgressModel",
734 | "model_module_version": "1.5.0",
735 | "state": {
736 | "_dom_classes": [],
737 | "_model_module": "@jupyter-widgets/controls",
738 | "_model_module_version": "1.5.0",
739 | "_model_name": "FloatProgressModel",
740 | "_view_count": null,
741 | "_view_module": "@jupyter-widgets/controls",
742 | "_view_module_version": "1.5.0",
743 | "_view_name": "ProgressView",
744 | "bar_style": "success",
745 | "description": "",
746 | "description_tooltip": null,
747 | "layout": "IPY_MODEL_b47d9843e3894611a631df2b46cc9ac0",
748 | "max": 10000,
749 | "min": 0,
750 | "orientation": "horizontal",
751 | "style": "IPY_MODEL_0df7710f34444afbba03b6232de3e1fd",
752 | "value": 10000
753 | }
754 | },
755 | "949e74fbeb5d4892a5b3b47cacfde5c5": {
756 | "model_module": "@jupyter-widgets/controls",
757 | "model_name": "HTMLModel",
758 | "model_module_version": "1.5.0",
759 | "state": {
760 | "_dom_classes": [],
761 | "_model_module": "@jupyter-widgets/controls",
762 | "_model_module_version": "1.5.0",
763 | "_model_name": "HTMLModel",
764 | "_view_count": null,
765 | "_view_module": "@jupyter-widgets/controls",
766 | "_view_module_version": "1.5.0",
767 | "_view_name": "HTMLView",
768 | "description": "",
769 | "description_tooltip": null,
770 | "layout": "IPY_MODEL_e193d91f0e414ec7b527d9c0ebd47558",
771 | "placeholder": "",
772 | "style": "IPY_MODEL_43561c163e1f4dec929eac0ead24d266",
773 | "value": " 10000/10000 [00:29<00:00, 886.65it/s]"
774 | }
775 | },
776 | "8189284ea2fc475192db86f14b63580a": {
777 | "model_module": "@jupyter-widgets/base",
778 | "model_name": "LayoutModel",
779 | "model_module_version": "1.2.0",
780 | "state": {
781 | "_model_module": "@jupyter-widgets/base",
782 | "_model_module_version": "1.2.0",
783 | "_model_name": "LayoutModel",
784 | "_view_count": null,
785 | "_view_module": "@jupyter-widgets/base",
786 | "_view_module_version": "1.2.0",
787 | "_view_name": "LayoutView",
788 | "align_content": null,
789 | "align_items": null,
790 | "align_self": null,
791 | "border": null,
792 | "bottom": null,
793 | "display": null,
794 | "flex": null,
795 | "flex_flow": null,
796 | "grid_area": null,
797 | "grid_auto_columns": null,
798 | "grid_auto_flow": null,
799 | "grid_auto_rows": null,
800 | "grid_column": null,
801 | "grid_gap": null,
802 | "grid_row": null,
803 | "grid_template_areas": null,
804 | "grid_template_columns": null,
805 | "grid_template_rows": null,
806 | "height": null,
807 | "justify_content": null,
808 | "justify_items": null,
809 | "left": null,
810 | "margin": null,
811 | "max_height": null,
812 | "max_width": null,
813 | "min_height": null,
814 | "min_width": null,
815 | "object_fit": null,
816 | "object_position": null,
817 | "order": null,
818 | "overflow": null,
819 | "overflow_x": null,
820 | "overflow_y": null,
821 | "padding": null,
822 | "right": null,
823 | "top": null,
824 | "visibility": null,
825 | "width": null
826 | }
827 | },
828 | "2285fe8ee81e45229e30ddd255357255": {
829 | "model_module": "@jupyter-widgets/base",
830 | "model_name": "LayoutModel",
831 | "model_module_version": "1.2.0",
832 | "state": {
833 | "_model_module": "@jupyter-widgets/base",
834 | "_model_module_version": "1.2.0",
835 | "_model_name": "LayoutModel",
836 | "_view_count": null,
837 | "_view_module": "@jupyter-widgets/base",
838 | "_view_module_version": "1.2.0",
839 | "_view_name": "LayoutView",
840 | "align_content": null,
841 | "align_items": null,
842 | "align_self": null,
843 | "border": null,
844 | "bottom": null,
845 | "display": null,
846 | "flex": null,
847 | "flex_flow": null,
848 | "grid_area": null,
849 | "grid_auto_columns": null,
850 | "grid_auto_flow": null,
851 | "grid_auto_rows": null,
852 | "grid_column": null,
853 | "grid_gap": null,
854 | "grid_row": null,
855 | "grid_template_areas": null,
856 | "grid_template_columns": null,
857 | "grid_template_rows": null,
858 | "height": null,
859 | "justify_content": null,
860 | "justify_items": null,
861 | "left": null,
862 | "margin": null,
863 | "max_height": null,
864 | "max_width": null,
865 | "min_height": null,
866 | "min_width": null,
867 | "object_fit": null,
868 | "object_position": null,
869 | "order": null,
870 | "overflow": null,
871 | "overflow_x": null,
872 | "overflow_y": null,
873 | "padding": null,
874 | "right": null,
875 | "top": null,
876 | "visibility": null,
877 | "width": null
878 | }
879 | },
880 | "5f82eb29bbd64344b93a1f9565397184": {
881 | "model_module": "@jupyter-widgets/controls",
882 | "model_name": "DescriptionStyleModel",
883 | "model_module_version": "1.5.0",
884 | "state": {
885 | "_model_module": "@jupyter-widgets/controls",
886 | "_model_module_version": "1.5.0",
887 | "_model_name": "DescriptionStyleModel",
888 | "_view_count": null,
889 | "_view_module": "@jupyter-widgets/base",
890 | "_view_module_version": "1.2.0",
891 | "_view_name": "StyleView",
892 | "description_width": ""
893 | }
894 | },
895 | "b47d9843e3894611a631df2b46cc9ac0": {
896 | "model_module": "@jupyter-widgets/base",
897 | "model_name": "LayoutModel",
898 | "model_module_version": "1.2.0",
899 | "state": {
900 | "_model_module": "@jupyter-widgets/base",
901 | "_model_module_version": "1.2.0",
902 | "_model_name": "LayoutModel",
903 | "_view_count": null,
904 | "_view_module": "@jupyter-widgets/base",
905 | "_view_module_version": "1.2.0",
906 | "_view_name": "LayoutView",
907 | "align_content": null,
908 | "align_items": null,
909 | "align_self": null,
910 | "border": null,
911 | "bottom": null,
912 | "display": null,
913 | "flex": null,
914 | "flex_flow": null,
915 | "grid_area": null,
916 | "grid_auto_columns": null,
917 | "grid_auto_flow": null,
918 | "grid_auto_rows": null,
919 | "grid_column": null,
920 | "grid_gap": null,
921 | "grid_row": null,
922 | "grid_template_areas": null,
923 | "grid_template_columns": null,
924 | "grid_template_rows": null,
925 | "height": null,
926 | "justify_content": null,
927 | "justify_items": null,
928 | "left": null,
929 | "margin": null,
930 | "max_height": null,
931 | "max_width": null,
932 | "min_height": null,
933 | "min_width": null,
934 | "object_fit": null,
935 | "object_position": null,
936 | "order": null,
937 | "overflow": null,
938 | "overflow_x": null,
939 | "overflow_y": null,
940 | "padding": null,
941 | "right": null,
942 | "top": null,
943 | "visibility": null,
944 | "width": null
945 | }
946 | },
947 | "0df7710f34444afbba03b6232de3e1fd": {
948 | "model_module": "@jupyter-widgets/controls",
949 | "model_name": "ProgressStyleModel",
950 | "model_module_version": "1.5.0",
951 | "state": {
952 | "_model_module": "@jupyter-widgets/controls",
953 | "_model_module_version": "1.5.0",
954 | "_model_name": "ProgressStyleModel",
955 | "_view_count": null,
956 | "_view_module": "@jupyter-widgets/base",
957 | "_view_module_version": "1.2.0",
958 | "_view_name": "StyleView",
959 | "bar_color": null,
960 | "description_width": ""
961 | }
962 | },
963 | "e193d91f0e414ec7b527d9c0ebd47558": {
964 | "model_module": "@jupyter-widgets/base",
965 | "model_name": "LayoutModel",
966 | "model_module_version": "1.2.0",
967 | "state": {
968 | "_model_module": "@jupyter-widgets/base",
969 | "_model_module_version": "1.2.0",
970 | "_model_name": "LayoutModel",
971 | "_view_count": null,
972 | "_view_module": "@jupyter-widgets/base",
973 | "_view_module_version": "1.2.0",
974 | "_view_name": "LayoutView",
975 | "align_content": null,
976 | "align_items": null,
977 | "align_self": null,
978 | "border": null,
979 | "bottom": null,
980 | "display": null,
981 | "flex": null,
982 | "flex_flow": null,
983 | "grid_area": null,
984 | "grid_auto_columns": null,
985 | "grid_auto_flow": null,
986 | "grid_auto_rows": null,
987 | "grid_column": null,
988 | "grid_gap": null,
989 | "grid_row": null,
990 | "grid_template_areas": null,
991 | "grid_template_columns": null,
992 | "grid_template_rows": null,
993 | "height": null,
994 | "justify_content": null,
995 | "justify_items": null,
996 | "left": null,
997 | "margin": null,
998 | "max_height": null,
999 | "max_width": null,
1000 | "min_height": null,
1001 | "min_width": null,
1002 | "object_fit": null,
1003 | "object_position": null,
1004 | "order": null,
1005 | "overflow": null,
1006 | "overflow_x": null,
1007 | "overflow_y": null,
1008 | "padding": null,
1009 | "right": null,
1010 | "top": null,
1011 | "visibility": null,
1012 | "width": null
1013 | }
1014 | },
1015 | "43561c163e1f4dec929eac0ead24d266": {
1016 | "model_module": "@jupyter-widgets/controls",
1017 | "model_name": "DescriptionStyleModel",
1018 | "model_module_version": "1.5.0",
1019 | "state": {
1020 | "_model_module": "@jupyter-widgets/controls",
1021 | "_model_module_version": "1.5.0",
1022 | "_model_name": "DescriptionStyleModel",
1023 | "_view_count": null,
1024 | "_view_module": "@jupyter-widgets/base",
1025 | "_view_module_version": "1.2.0",
1026 | "_view_name": "StyleView",
1027 | "description_width": ""
1028 | }
1029 | }
1030 | }
1031 | }
1032 | },
1033 | "nbformat": 4,
1034 | "nbformat_minor": 0
1035 | }
--------------------------------------------------------------------------------
/[5]_System_of_ODEs_PINN.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "colab_type": "text",
7 | "id": "view-in-github"
8 | },
9 | "source": [
10 | "
"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 1,
16 | "metadata": {
17 | "id": "v77fdC1ZLyg1"
18 | },
19 | "outputs": [],
20 | "source": [
21 | "#Credits : Mahmoud Asem @Asem000 October 2021"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": 2,
27 | "metadata": {
28 | "colab": {
29 | "base_uri": "https://localhost:8080/"
30 | },
31 | "id": "vAR0swbLX_ZI",
32 | "outputId": "97823711-ee53-4921-bf97-c2f9a312e57f"
33 | },
34 | "outputs": [
35 | {
36 | "name": "stdout",
37 | "output_type": "stream",
38 | "text": [
39 | "Requirement already satisfied: optax in /usr/local/lib/python3.7/dist-packages (0.0.9)\n",
40 | "Requirement already satisfied: absl-py>=0.7.1 in /usr/local/lib/python3.7/dist-packages (from optax) (0.12.0)\n",
41 | "Requirement already satisfied: jax>=0.1.55 in /usr/local/lib/python3.7/dist-packages (from optax) (0.2.21)\n",
42 | "Requirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.7/dist-packages (from optax) (1.19.5)\n",
43 | "Requirement already satisfied: chex>=0.0.4 in /usr/local/lib/python3.7/dist-packages (from optax) (0.0.8)\n",
44 | "Requirement already satisfied: jaxlib>=0.1.37 in /usr/local/lib/python3.7/dist-packages (from optax) (0.1.71+cuda111)\n",
45 | "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py>=0.7.1->optax) (1.15.0)\n",
46 | "Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax) (0.11.1)\n",
47 | "Requirement already satisfied: dm-tree>=0.1.5 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax) (0.1.6)\n",
48 | "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax>=0.1.55->optax) (3.3.0)\n",
49 | "Requirement already satisfied: scipy>=1.2.1 in /usr/local/lib/python3.7/dist-packages (from jax>=0.1.55->optax) (1.4.1)\n",
50 | "Requirement already satisfied: flatbuffers<3.0,>=1.12 in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.37->optax) (1.12)\n",
51 | "Requirement already satisfied: numba in /usr/local/lib/python3.7/dist-packages (0.51.2)\n",
52 | "Requirement already satisfied: numpy>=1.15 in /usr/local/lib/python3.7/dist-packages (from numba) (1.19.5)\n",
53 | "Requirement already satisfied: setuptools in /usr/local/lib/python3.7/dist-packages (from numba) (57.4.0)\n",
54 | "Requirement already satisfied: llvmlite<0.35,>=0.34.0.dev0 in /usr/local/lib/python3.7/dist-packages (from numba) (0.34.0)\n"
55 | ]
56 | }
57 | ],
58 | "source": [
59 | "#Imports\n",
60 | "import jax \n",
61 | "import jax.numpy as jnp\n",
62 | "import numpy as np\n",
63 | "import matplotlib.pyplot as plt\n",
64 | "from matplotlib import cm\n",
65 | "import matplotlib as mpl\n",
66 | "!pip install optax\n",
67 | "import optax\n",
68 | "!pip install numba\n",
69 | "import numba"
70 | ]
71 | },
72 | {
73 | "cell_type": "markdown",
74 | "metadata": {
75 | "id": "7bg4nSbsXVwD"
76 | },
77 | "source": [
78 | "### System of ODEs numerical solution\n",
79 | "$\\large \\frac{dx}{dt} = x$\n",
80 | "\n",
81 | "$\\large \\frac{dy}{dt} = x - y$\n",
82 | "\n",
83 | "$x(t=0) = 1$\n",
84 | "\n",
85 | "$y(t=0) = 2$\n",
86 | "\n",
87 | "\n",
88 | "$\\text{analytical solution}$\n",
89 | "\n",
90 | "$x(t) = e^{t}$\n",
91 | "\n",
92 | "$y(t) = \\frac{1}{2} e^{t} + \\frac{3}{2} e^{-t}$"
93 | ]
94 | },
95 | {
96 | "cell_type": "markdown",
97 | "metadata": {
98 | "id": "K89DstaYpwh0"
99 | },
100 | "source": []
101 | },
102 | {
103 | "cell_type": "code",
104 | "execution_count": 3,
105 | "metadata": {
106 | "id": "8uW0HW1-pxl8"
107 | },
108 | "outputs": [],
109 | "source": [
110 | "'''\n",
111 | "\n",
112 | "solve \n",
113 | "dx/dt = x\n",
114 | "dy/dt = x-y\n",
115 | "\n",
116 | "x(0) = 1\n",
117 | "y(0) = 2\n",
118 | "\n",
119 | "solution =\n",
120 | "\n",
121 | "x(t) = exp(t)\n",
122 | "y(t) = 0.5*exp(t) +1.5*exp(-t)\n",
123 | "\n",
124 | "'''\n",
125 | "\n",
126 | "@numba.njit\n",
127 | "def RK4(odefun,ics,h,span,degree):\n",
128 | " \n",
129 | " N= int( (span[1]-span[0])/h )\n",
130 | " \n",
131 | " tY = np.zeros((N+1,degree+1))\n",
132 | " tY[0,1:] = ics\n",
133 | " \n",
134 | " \n",
135 | " for i in range(N):\n",
136 | " tY[i+1,0] = tY[i,0] + h\n",
137 | "\n",
138 | " k1= odefun(tY[i,0] , tY[i,1:])\n",
139 | " k2= odefun(tY[i,0] +(h/2), tY[i,1:] +(h*k1)/2 )\n",
140 | " k3= odefun(tY[i,0] +(h/2), tY[i,1:] +(h*k2)/2)\n",
141 | " k4= odefun(tY[i,0] +(h) , tY[i,1:] +(h*k3))\n",
142 | " \n",
143 | " tY[i+1,1:] = tY[i,1:] + h*(1/6) * (k1+2*k2+2*k3+k4)\n",
144 | " \n",
145 | " return tY[:,0],tY[:,1:]\n",
146 | "\n",
147 | "@numba.njit\n",
148 | "def system_of_ode(t,V):\n",
149 | " y1,y2 = V[0],V[1]\n",
150 | " return np.array([y1,y1-y2])\n",
151 | "\n",
152 | "t,y=RK4(system_of_ode,\n",
153 | " ics=np.array([1,2]),\n",
154 | " h=1e-3,\n",
155 | " span=np.array([1e-4,np.pi]),\n",
156 | " degree =2)"
157 | ]
158 | },
159 | {
160 | "cell_type": "code",
161 | "execution_count": 4,
162 | "metadata": {
163 | "colab": {
164 | "base_uri": "https://localhost:8080/",
165 | "height": 282
166 | },
167 | "id": "bNctv_EFp-yD",
168 | "outputId": "9ede81ec-108c-4d4b-f928-d0c2493aea0c"
169 | },
170 | "outputs": [
171 | {
172 | "data": {
173 | "text/plain": [
174 | ""
175 | ]
176 | },
177 | "execution_count": 4,
178 | "metadata": {},
179 | "output_type": "execute_result"
180 | },
181 | {
182 | "data": {
183 | "image/png": "",
184 | "text/plain": [
185 | ""
186 | ]
187 | },
188 | "metadata": {
189 | "needs_background": "light"
190 | },
191 | "output_type": "display_data"
192 | }
193 | ],
194 | "source": [
195 | "plt.plot(t,y[:,0],'-r',label='RK4[1]')\n",
196 | "plt.plot(t,np.exp(t),'--k',label='Analytical[1]')\n",
197 | "\n",
198 | "plt.plot(t,y[:,1],'-g',label='RK4[2]')\n",
199 | "plt.plot(t,0.5*np.exp(t)+1.5*np.exp(-t),'--b',label='Analytical[2]')\n",
200 | "\n",
201 | "plt.legend()"
202 | ]
203 | },
204 | {
205 | "cell_type": "markdown",
206 | "metadata": {
207 | "id": "NQ61lEQeXgrc"
208 | },
209 | "source": [
210 | "### Constructing the MLP"
211 | ]
212 | },
213 | {
214 | "cell_type": "code",
215 | "execution_count": 5,
216 | "metadata": {
217 | "colab": {
218 | "base_uri": "https://localhost:8080/"
219 | },
220 | "id": "Lml6PGLPZgmr",
221 | "outputId": "aaa25445-8e2d-4540-8cf3-20d6e5404fc8"
222 | },
223 | "outputs": [
224 | {
225 | "name": "stderr",
226 | "output_type": "stream",
227 | "text": [
228 | "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
229 | ]
230 | }
231 | ],
232 | "source": [
233 | "N_b = 1_000\n",
234 | "N_c = 10_000\n",
235 | "\n",
236 | "tmin,tmax=0. ,jnp.pi\n",
237 | "\n",
238 | "'''boundary conditions'''\n",
239 | "\n",
240 | "\n",
241 | "# y1[0] = 1\n",
242 | "y1_t0 = jnp.zeros([N_b,1],dtype='float32')\n",
243 | "y1_ic = jnp.ones_like(y1_t0) \n",
244 | "Y1_IC = jnp.concatenate([y1_t0,y1_ic],axis=1)\n",
245 | "\n",
246 | "# y2[0] = 2\n",
247 | "y2_t0 = jnp.zeros([N_b,1],dtype='float32')\n",
248 | "y2_ic = jnp.ones_like(y2_t0) * 2\n",
249 | "Y2_IC = jnp.concatenate([y2_t0,y2_ic],axis=1)\n",
250 | "\n",
251 | "conds = [Y1_IC,Y2_IC]\n",
252 | "\n",
253 | "#collocation points\n",
254 | "\n",
255 | "key=jax.random.PRNGKey(0)\n",
256 | "\n",
257 | "t_c = jax.random.uniform(key,minval=tmin,maxval=tmax,shape=(N_c,1))\n",
258 | "colloc = t_c\n",
259 | "\n",
260 | "def ODE_loss(t,y1,y2):\n",
261 | "\n",
262 | " y1_t=lambda t:jax.grad(lambda t:jnp.sum(y1(t)))(t)\n",
263 | " y2_t=lambda t:jax.grad(lambda t:jnp.sum(y2(t)))(t)\n",
264 | "\n",
265 | " return y1_t(t) - y1(t) , y2_t(t) - y1(t) + y2(t)\n"
266 | ]
267 | },
268 | {
269 | "cell_type": "code",
270 | "execution_count": 6,
271 | "metadata": {
272 | "id": "KoZZJl2TbI_n"
273 | },
274 | "outputs": [],
275 | "source": [
276 | "def init_params(layers):\n",
277 | " keys = jax.random.split(jax.random.PRNGKey(0),len(layers)-1)\n",
278 | " params = list()\n",
279 | " for key,n_in,n_out in zip(keys,layers[:-1],layers[1:]):\n",
280 | " lb, ub = -(1 / jnp.sqrt(n_in)), (1 / jnp.sqrt(n_in)) # fan-in-scaled uniform bounds +/- 1/sqrt(n_in) (Glorot/Xavier uniform would use sqrt(6/(n_in+n_out)))\n",
281 | " W = lb + (ub-lb) * jax.random.uniform(key,shape=(n_in,n_out))\n",
282 | " B = jax.random.uniform(key,shape=(n_out,))\n",
283 | " params.append({'W':W,'B':B})\n",
284 | " return params\n",
285 | "\n",
286 | "def fwd(params,t):\n",
287 | " X = jnp.concatenate([t],axis=1)\n",
288 | " *hidden,last = params\n",
289 | " for layer in hidden :\n",
290 | " X = jax.nn.tanh(X@layer['W']+layer['B'])\n",
291 | " return X@last['W'] + last['B']\n",
292 | "\n",
293 | "@jax.jit\n",
294 | "def MSE(true,pred):\n",
295 | " return jnp.mean((true-pred)**2)\n",
296 | "\n",
297 | "def loss_fun(params,colloc,conds):\n",
298 | " t_c =colloc[:,[0]]\n",
299 | "\n",
300 | " y1_func = lambda t : fwd(params,t)[:,[0]]\n",
301 | " y1_func_t=lambda t:jax.grad(lambda t:jnp.sum(y1_func(t)))(t)\n",
302 | "\n",
303 | " y2_func = lambda t : fwd(params,t)[:,[1]]\n",
304 | " y2_func_t=lambda t:jax.grad(lambda t:jnp.sum(y2_func(t)))(t)\n",
305 | "\n",
306 | " loss_y1,loss_y2 =ODE_loss(t_c,y1_func,y2_func)\n",
307 | "\n",
308 | " loss = jnp.mean( loss_y1 **2) \n",
309 | " loss+= jnp.mean(loss_y2 **2)\n",
310 | "\n",
311 | " t_ic,y1_ic = conds[0][:,[0]],conds[0][:,[1]] \n",
312 | " loss += MSE(y1_ic,y1_func(t_ic))\n",
313 | "\n",
314 | " t_ic,y2_ic = conds[1][:,[0]],conds[1][:,[1]] \n",
315 | " loss += MSE(y2_ic,y2_func(t_ic))\n",
316 | "\n",
317 | " return loss\n",
318 | "\n",
319 | "@jax.jit\n",
320 | "def update(opt_state,params,colloc,conds):\n",
321 | " # Get the gradient of the loss w.r.t. the MLP params\n",
322 | " grads=jax.jit(jax.grad(loss_fun,0))(params,colloc,conds)\n",
323 | " \n",
324 | " #Update params\n",
325 | " updates, opt_state = optimizer.update(grads, opt_state)\n",
326 | " params = optax.apply_updates(params, updates)\n",
327 | "\n",
328 | " # Alternative (disabled): plain gradient-descent step without optax\n",
329 | " # return jax.tree_multimap(lambda params,grads : params-LR*grads, params,grads)\n",
330 | " return opt_state,params\n"
331 | ]
332 | },
333 | {
334 | "cell_type": "code",
335 | "execution_count": 7,
336 | "metadata": {
337 | "id": "ae1ZDoy0c29c"
338 | },
339 | "outputs": [],
340 | "source": [
341 | "# construct the MLP: 1 input (t), 2 hidden layers of 8 neurons each, 2 outputs (y1, y2)\n",
342 | "params = init_params([1] + [8]*2+[2])"
343 | ]
344 | },
345 | {
346 | "cell_type": "code",
347 | "execution_count": 8,
348 | "metadata": {
349 | "id": "jySmbUwic5yk"
350 | },
351 | "outputs": [],
352 | "source": [
353 | "optimizer = optax.adam(1e-2)\n",
354 | "opt_state = optimizer.init(params)"
355 | ]
356 | },
357 | {
358 | "cell_type": "code",
359 | "execution_count": 9,
360 | "metadata": {
361 | "colab": {
362 | "base_uri": "https://localhost:8080/"
363 | },
364 | "id": "VS0bosXPg1Oo",
365 | "outputId": "258603fc-cb6e-4dcb-eb86-4560004b71e2"
366 | },
367 | "outputs": [
368 | {
369 | "data": {
370 | "text/plain": [
371 | "(10000, 1)"
372 | ]
373 | },
374 | "execution_count": 9,
375 | "metadata": {},
376 | "output_type": "execute_result"
377 | }
378 | ],
379 | "source": [
380 | "fwd(params,t_c)[:,[0]].shape"
381 | ]
382 | },
383 | {
384 | "cell_type": "code",
385 | "execution_count": 10,
386 | "metadata": {
387 | "colab": {
388 | "base_uri": "https://localhost:8080/"
389 | },
390 | "id": "kBzGA8OVc8C6",
391 | "outputId": "7dda4b83-5d7e-497c-fb98-64e7b1322528"
392 | },
393 | "outputs": [
394 | {
395 | "name": "stdout",
396 | "output_type": "stream",
397 | "text": [
398 | "Epoch=0\tloss=4.391e+00\n",
399 | "Epoch=100\tloss=3.921e-01\n",
400 | "Epoch=200\tloss=3.891e-01\n",
401 | "Epoch=300\tloss=3.876e-01\n",
402 | "Epoch=400\tloss=3.866e-01\n",
403 | "Epoch=500\tloss=3.855e-01\n",
404 | "Epoch=600\tloss=3.840e-01\n",
405 | "Epoch=700\tloss=3.820e-01\n",
406 | "Epoch=800\tloss=3.788e-01\n",
407 | "Epoch=900\tloss=3.726e-01\n",
408 | "Epoch=1000\tloss=3.555e-01\n",
409 | "Epoch=1100\tloss=3.070e-01\n",
410 | "Epoch=1200\tloss=2.254e-01\n",
411 | "Epoch=1300\tloss=1.834e-01\n",
412 | "Epoch=1400\tloss=1.354e-01\n",
413 | "Epoch=1500\tloss=1.085e-01\n",
414 | "Epoch=1600\tloss=8.198e-02\n",
415 | "Epoch=1700\tloss=6.454e-02\n",
416 | "Epoch=1800\tloss=5.201e-02\n",
417 | "Epoch=1900\tloss=4.232e-02\n",
418 | "Epoch=2000\tloss=3.476e-02\n",
419 | "Epoch=2100\tloss=2.897e-02\n",
420 | "Epoch=2200\tloss=2.339e-02\n",
421 | "Epoch=2300\tloss=1.980e-02\n",
422 | "Epoch=2400\tloss=1.632e-02\n",
423 | "Epoch=2500\tloss=1.425e-02\n",
424 | "Epoch=2600\tloss=1.194e-02\n",
425 | "Epoch=2700\tloss=1.174e-02\n",
426 | "Epoch=2800\tloss=9.231e-03\n",
427 | "Epoch=2900\tloss=7.939e-03\n",
428 | "Epoch=3000\tloss=7.260e-03\n",
429 | "Epoch=3100\tloss=6.298e-03\n",
430 | "Epoch=3200\tloss=7.007e-02\n",
431 | "Epoch=3300\tloss=5.132e-03\n",
432 | "Epoch=3400\tloss=4.541e-03\n",
433 | "Epoch=3500\tloss=5.706e-03\n",
434 | "Epoch=3600\tloss=3.834e-03\n",
435 | "Epoch=3700\tloss=3.433e-03\n",
436 | "Epoch=3800\tloss=3.316e-03\n",
437 | "Epoch=3900\tloss=2.964e-03\n",
438 | "Epoch=4000\tloss=2.688e-03\n",
439 | "Epoch=4100\tloss=2.604e-03\n",
440 | "Epoch=4200\tloss=2.364e-03\n",
441 | "Epoch=4300\tloss=1.091e-01\n",
442 | "Epoch=4400\tloss=2.085e-03\n",
443 | "Epoch=4500\tloss=1.907e-03\n",
444 | "Epoch=4600\tloss=2.153e-03\n",
445 | "Epoch=4700\tloss=1.716e-03\n",
446 | "Epoch=4800\tloss=1.577e-03\n",
447 | "Epoch=4900\tloss=1.593e-03\n",
448 | "Epoch=5000\tloss=1.462e-03\n",
449 | "Epoch=5100\tloss=1.348e-03\n",
450 | "Epoch=5200\tloss=1.934e-03\n",
451 | "Epoch=5300\tloss=1.247e-03\n",
452 | "Epoch=5400\tloss=1.154e-03\n",
453 | "Epoch=5500\tloss=1.163e-03\n",
454 | "Epoch=5600\tloss=1.062e-03\n",
455 | "Epoch=5700\tloss=6.481e-03\n",
456 | "Epoch=5800\tloss=9.879e-04\n",
457 | "Epoch=5900\tloss=9.198e-04\n",
458 | "Epoch=6000\tloss=9.548e-04\n",
459 | "Epoch=6100\tloss=8.567e-04\n",
460 | "Epoch=6200\tloss=5.164e-03\n",
461 | "Epoch=6300\tloss=8.136e-04\n",
462 | "Epoch=6400\tloss=7.632e-04\n",
463 | "Epoch=6500\tloss=8.119e-04\n",
464 | "Epoch=6600\tloss=7.259e-04\n",
465 | "Epoch=6700\tloss=6.844e-04\n",
466 | "Epoch=6800\tloss=6.932e-04\n",
467 | "Epoch=6900\tloss=6.498e-04\n",
468 | "Epoch=7000\tloss=1.022e-02\n",
469 | "Epoch=7100\tloss=6.268e-04\n",
470 | "Epoch=7200\tloss=5.936e-04\n",
471 | "Epoch=7300\tloss=3.431e-03\n",
472 | "Epoch=7400\tloss=5.711e-04\n",
473 | "Epoch=7500\tloss=5.428e-04\n",
474 | "Epoch=7600\tloss=6.806e-04\n",
475 | "Epoch=7700\tloss=5.217e-04\n",
476 | "Epoch=7800\tloss=7.266e-03\n",
477 | "Epoch=7900\tloss=5.146e-04\n",
478 | "Epoch=8000\tloss=4.886e-04\n",
479 | "Epoch=8100\tloss=1.107e-03\n",
480 | "Epoch=8200\tloss=4.893e-04\n",
481 | "Epoch=8300\tloss=4.574e-04\n",
482 | "Epoch=8400\tloss=9.877e-02\n",
483 | "Epoch=8500\tloss=4.501e-04\n",
484 | "Epoch=8600\tloss=4.283e-04\n",
485 | "Epoch=8700\tloss=2.029e-03\n",
486 | "Epoch=8800\tloss=4.216e-04\n",
487 | "Epoch=8900\tloss=4.044e-04\n",
488 | "Epoch=9000\tloss=4.368e-04\n",
489 | "Epoch=9100\tloss=4.015e-04\n",
490 | "Epoch=9200\tloss=3.859e-04\n",
491 | "Epoch=9300\tloss=1.088e-03\n",
492 | "Epoch=9400\tloss=3.847e-04\n",
493 | "Epoch=9500\tloss=3.701e-04\n",
494 | "Epoch=9600\tloss=7.795e-04\n",
495 | "Epoch=9700\tloss=3.653e-04\n",
496 | "Epoch=9800\tloss=8.272e-02\n",
497 | "Epoch=9900\tloss=3.632e-04\n",
498 | "CPU times: user 1min 39s, sys: 2.21 s, total: 1min 41s\n",
499 | "Wall time: 1min 31s\n"
500 | ]
501 | }
502 | ],
503 | "source": [
504 | "%%time\n",
505 | "epochs = 10_000\n",
506 | "for _ in range(epochs):\n",
507 | " opt_state,params = update(opt_state,params,colloc,conds)\n",
508 | "\n",
509 | " # print loss and epoch info\n",
510 | " if _ %(100) ==0:\n",
511 | " print(f'Epoch={_}\\tloss={loss_fun(params,colloc,conds):.3e}')"
512 | ]
513 | },
514 | {
515 | "cell_type": "code",
516 | "execution_count": 11,
517 | "metadata": {
518 | "colab": {
519 | "base_uri": "https://localhost:8080/",
520 | "height": 282
521 | },
522 | "id": "eWeNvDsdDEuI",
523 | "outputId": "3ca4560a-ca9f-451f-db9d-0b61e081cfa2"
524 | },
525 | "outputs": [
526 | {
527 | "data": {
528 | "text/plain": [
529 | ""
530 | ]
531 | },
532 | "execution_count": 11,
533 | "metadata": {},
534 | "output_type": "execute_result"
535 | },
536 | {
537 | "data": {
538 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAD4CAYAAAD1jb0+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deVxU1f/H8ddhBxVRcIFAUcxdwQW1FNdcylxwV1xJrdyyzL6WWi6talammVsuae5bWWnmmprmvpumoqKoKIEgi8Cc3x8gPzVQVODOwOf5eMzDmXvv3Pu+TH24nDn3HKW1RgghhOWxMjqAEEKIJyMFXAghLJQUcCGEsFBSwIUQwkJJARdCCAtlk5MHc3Nz097e3jl5SCGEsHj79++/obUu8uDyHC3g3t7e7Nu3LycPKYQQFk8pdSG95dKEIoQQFkoKuBBCWCgp4EIIYaFytA08PYmJiYSGhhIfH290FJEOBwcHPD09sbW1NTqKEOIBhhfw0NBQChQogLe3N0opo+OIe2ituXnzJqGhoZQqVcroOEKIBxjehBIfH4+rq6sUbzOklMLV1VX+OhLCTBlewAEp3mZMPhshzJdZFHAhhMitrl+HYcPg6tWs37cU8FRr1qxBKcWpU6eeeB+9e/dmxYoVD93m448/vu/1888//0THGjNmDJMmTQKgYcOGlCtXjh9//PGh75k6dSplypRBKcWNGzfSli9dupQyZcrw8ssvP1EWIUTGvvoKvvgCoqKyft9SwFMtXryYevXqsXjx4mw9zoMFfNeuXVmy30WLFtG6deuHblO3bl1+//13SpYsed/yzp07M3v27CzJIYT4f1FR8PXXJjp0gHLlsn7/UsCBmJgYduzYwZw5c1iyZAkAW7dupWHDhnTo0IHy5csTFBTE3dmLxo0bh7+/P5UrV6Z///48OKvR5s2badu2bdrrjRs3EhgYyIgRI4iLi8PPz4+goCAA8ufPn7bdZ599RpUqVfD19WXEiBEAzJo1C39/f3x9fWnfvj2xsbEPPZezZ89SvXr1tNdnzpxJe12tWjVkLBohcs60aRAd3YuoqC6EhIRk+f4N70Z4r6FD4dChrN2nnx98+eXDt1m7di0tWrSgbNmyuLq6sn//fgAOHjzI8ePH8fDwoG7duuzcuZN69eoxaNAg3n//fQB69OjBunXraNWqVdr+GjVqxIABAwgPD6dIkSLMnTuX4OBgWrVqxdSpUzmUzkn++uuvrF27lj179uDk5ERERAQA7dq1o1+/fgCMGjWKOXPmMHjw4AzPxcfHh4IFC3Lo0CH8/PyYO3cuffr0eayfmRDi6cXGwqRJZ4Af2LzZCiurCVl+DLkCJ6X5pEuXLgB06dIlrRmlVq1aeHp6YmVlhZ+fX9pv0C1btlC7dm2qVKnC5s2bOX78+H37U0rRo0cPFi5cSGRkJH/++ScvvvjiQzP8/vvv9OnTBycnJwAKFy4MwLFjxwgICKBKlSosWrToP8dKT9++fZk7dy7JycksXbqUbt26PdbPQwjx9GbPhn//LUq/fuMYOnQoJUqUyPJjmNUV+KOulLNDREQEmzdv5ujRoyilSE5ORilFy5Ytsbe3T9vO2tqapKQk4uPjGTBgAPv27cPLy4sxY8ak20+6T58+tGrVCgcHBzp27IiNzZP9qHv37s2aNWvw9fVl3rx5bN269ZHvad++PWPHjqVx48bUqFEDV1fXJzq2EOLJ3LkDEydC/foFmTlzZLYdJ89fga9YsYIePXpw4cIFQkJCuHTpEqVKleKPP/5Id/u7xdrNzY2YmJgMe514eHjg4eHBhx9+eF8Thq2tLYmJif/ZvmnTpsydOzetjftuE0p0dDTu7u4kJiayaNGiTJ2Tg4MDzZs35/XXX5fmEyEMsHAhhIaaeO+97D1Oni/gixcvJjAw8L5l7du3z7A3iouLC/369aNy5co0b94cf3//DPcdFBSE
l5cXFSpUSFvWv39/qlatmvYl5l0tWrSgdevW1KxZEz8/v7QuguPHj6d27drUrVuX8uXLZ/q8goKCsLKyolmzZmnLpkyZgqenJ6GhoVStWpW+fftmen9CiMxJSoJx48KwsyvDsWOT/9PJIUtprXPsUaNGDf2gEydO/GdZbjFw4EA9e/bsbD9OgwYN9N69e+9bNnHiRD1q1KhM72PLli26ZcuW6a7LzZ+REFlt3jyt4W0N6Hbt2mXJPoF9Op2amuevwLNLjRo1OHLkCN27d8/2YxUuXJjevXun3cgTGBjIggULeOONNzL1/qVLlzJgwAAKFSqUnTGFyPWSkmDs2HCsrKYD8F42t6GY1ZeYucndrog5YdWqVfe9Xr169WO9v3PnznTu3DkrIwmRJ/3wA5w/Pwm4TcuWLalRo0a2Hk+uwIUQIgskJcGYMdexspoKpAx3kd3kClwIIbLA4sVw/vxEIJaXX36ZmjVrZvsx5QpcCCGeUlISjB9vwslpO5AzV98gV+BCCPHUliyBM2esWL78T4oW3ZHtbd93yRU4Kbe+Dxs2LO31pEmT0n6DjhkzBicnJ65fv562/u4AVCEhITg6OuLn5/fIY7Ro0QIXF5f/DNkaFBRE4cKFHzkMrRDCPCUnw/jxULUqtGtnRf369XPs2FLAAXt7e1atWnXfGNn3cnNz4/PPP093nY+PT7qDUz1o+PDhfP/99/9ZnplhYIUQ5mvJEjh9egWvv34RqxyuqFLAARsbG/r3788XX3yR7vrg4GCWLl2adnt7Rt5//32+vGdAl5EjR/LVV18B0KRJEwoUKJB1oYUQhktMhFGjrqJUD95441muXLmSo8c3uwKulMrwMXPmzLTtZs6c+dBtH9fAgQNZtGgRUelMm5E/f36Cg4PTinFGgoODWbBgAQAmk4klS5bkyI08QghjzJsHISGfoXU8L730Eh4eHjl6fLMr4EZxdnamZ8+eTJkyJd31Q4YMYf78+URHR2e4D29vb1xdXTl48CC//fYb1apVk5EAhcil4uNh9OiLKJVy1+UHH3yQ4xnMrheKzuTAL/3796d///5ZeuyhQ4dSvXr1dEfwc3FxoVu3bkybNu2h++jbty/z5s3j6tWrBAcHZ2k+IYT5mD4drl0bCyTQpUuXTHVmyGpyBX6PwoUL06lTJ+bMmZPu+rfeeosZM2aQlJSU4T4CAwNZv349e/fupXnz5tkVVQhhoOhoGDfuJDAPGxsbxo8fb0gOKeAPGDZs2EN7owQGBpKQkJDh++3s7GjUqBGdOnXC2to6bXlAQAAdO3Zk06ZNeHp6smHDhizPLoTIGV9+CZGRowETr7zyCmXKlDEkh9k1oRghJiYm7XmxYsXumzj4wTuqJk+ezOTJkzPcl8lkYvfu3Sxfvvy+5RlNECGEsCw3b8KkSdC48fu4uOi0+XGNIFfgT8Ha2pqoqKi0tq8TJ05QpkwZmjRpwrPPPpupfQQFBbFt2zYcHByyM6oQIotMmJDShDJlSlVWrlyZ4z1P7vXIK3CllBewACgGaGCm1vorpVRhYCngDYQAnbTW/2ZfVPPj5eXFpUuX0l5XrFiRc+fOPdY+MjtNmhDCeFeuwJQpkQQFFaRSpcfvrpzVMnMFngQM01pXBOoAA5VSFYERwCat9bPAptTXQgiRa40dq4mPb8aZM425cOGC0XEeXcC11mFa6wOpz6OBk8AzQBtgfupm84G22RVSCCGMdvIkzJ69GthLSMhJ3NzcjI70eG3gSilvoBqwByimtQ5LXXWVlCaW9N7TXym1Tym1Lzw8/CmiCiGEcYYPvwP8D4DRo0eTL18+YwPxGAVcKZUfWAkM1Vrfundd6qSb6d6Bo7WeqbWuqbWuWaRIkacKK4QQRti2DX7++VtMpn8oV65clt9E+KQyVcCVUrakFO9FWuu7EzBeU0q5p653B65n9H5LsGbNGpRSnDp16on30bt370cOC/vxxx/f9/r5559/omONGTOGSZMmAdCwYUPK
lSuXNqlxRoKCgihXrhyVK1cmODiYxMREIGVS4zJlyvxnqFshBJhMMHRoJFZWYwGYMGECtra2BqdK8cgCrlJGhpoDnNRa39sB+kegV+rzXsDarI+XcxYvXky9evVYvHhxth7nwQK+a9euLNlvZoalDQoK4tSpUxw9epS4uDhmz54NpExqfPe5EOJ+S5fCoUMfYTJF0KBBA1q1amV0pDSZuQKvC/QAGiulDqU+XgI+BZoqpc4AL6S+tkgxMTHs2LGDOXPmsGTJEgC2bt1Kw4YN6dChA+XLlycoKChtnJZx48bh7+9P5cqV6d+//3/Gb9m8eTNt2/7/d7obN24kMDCQESNGEBcXh5+fH0FBQcD/Tw4B8Nlnn1GlShV8fX0ZMSKlU8+sWbPw9/fH19eX9u3b33eTUXrOnj1L9erV016fOXMm7fVLL72UNlpjrVq1CA0NfdIfmRB5QkICvPceuLsXIl++fHz++edPNNppdnlkP3Ct9Q4go8RNsjTN0KGQickRHoufX8p9rw+xdu1aWrRoQdmyZXF1dWX//v0AHDx4kOPHj+Ph4UHdunXZuXMn9erVY9CgQWl3X/Xo0YN169bd91u5UaNGDBgwgPDwcIoUKcLcuXMJDg6mVatWTJ06Nd0JIH799VfWrl3Lnj17cHJySht7vF27dvTr1w+AUaNGMWfOHAYPHpzhufj4+FCwYEEOHTqEn58fc+fO/c/gXImJiXz//fePHB5XiLxu6lQICYGNG9/D338gBQsWNDrSfeROTFKaT7p06QJAly5d0ppRatWqhaenJ1ZWVvj5+RESEgLAli1bqF27NlWqVGHz5s0cP378vv0ppejRowcLFy4kMjKSP//8kxdffPGhGX7//Xf69OmDk5MTkDKwFsCxY8cICAigSpUqLFq06D/HSk/fvn2ZO3cuycnJLF26lG7dut23fsCAAdSvX5+AgIBH/3CEyKMiIuDDD6FFC3jhBcyueIO5jYXyiCvl7BAREcHmzZs5evQoSimSk5NRStGyZUvs7e3TtrO2tiYpKYn4+HgGDBjAvn378PLyYsyYMcTHx/9nv3369KFVq1Y4ODjQsWNHbGye7Efdu3dv1qxZg6+vL/PmzWPr1q2PfE/79u0ZO3YsjRs3pkaNGveNST527FjCw8OZMWPGE+URIq8YO1YTGdmBWrVakJwcfN/gdOYiz1+Br1ixgh49enDhwgVCQkK4dOkSpUqVynDwqbvF2s3NjZiYmAx7nXh4eODh4cGHH354XxOGra1tWu+PezVt2pS5c+emtXHfbUKJjo7G3d2dxMTETN927+DgQPPmzXn99dfvO/bs2bPZsGEDixcvxiqnJ+8TwoKcOAFTp64EVjF9+khu375tdKR05fn/ixcvXkxgYOB9y9q3b59hbxQXFxf69etH5cqVad68Of7+/hnuOygoCC8vLypUqJC2rH///lStWjXtS8y7WrRoQevWralZsyZ+fn5pXQTHjx9P7dq1qVu3LuXLl8/0eQUFBWFlZUWzZs3Slr322mtcu3aN5557Dj8/P8aNG5fp/QmRV2gNgwfHofXbQEqnBWdnZ4NTZUBrnWOPGjVq6AedOHHiP8tyi4EDB+rZs2dn+3EaNGig9+7de9+yiRMn6lGjRmV6H1u2bNEtW7ZMd11u/oyEeNDatVrDGA1oX19fnZSUZHQkDezT6dTUPH8Fnl1q1KjBkSNHcmRS48KFC9O7d++0G3kCAwNZsGABb7zxRqbev3TpUgYMGEChQoWyM6YQZi8hAQYPvoBSKb2ip0yZYpZt33eZ15eYucjdrog5YdWqVfe9Xr169WO9v3PnznTu3DkrIwlhkb78Ei5efBuIp0uXLtSvX9/oSA8lBVwIIYCwMBg//hYFCpwkOdmJCRMmGB3pkaSACyEE8O67kJjozKFDB4mJOYSXl5fRkR5JCrgQIs/bswfmz4f//Q8qVLAFMu5dZk7kS0whRJ6WnAz9+t3AyWkA/fpdMzrOY5ECTsqt
78OGDUt7PWnSpLTZ6MeMGYOTkxPXr///aLl3B6AKCQnB0dExbVLjjBw6dIjnnnuOSpUqUbVqVZYuXZq2LigoiMKFCz9yGFohRPaYPh2OHn2X2NjpDB/+utFxHosUcMDe3p5Vq1Zx48aNdNe7ubnx+eefp7vOx8cn3cGp7uXk5MSCBQs4fvw469evZ+jQoURGRgKZGwZWCJE9rl6F//1vFzAbW1tbPvroI6MjPRYp4ICNjQ39+/fniy++SHd9cHAwS5cuTbu9PSPvv/8+X94znsvIkSP56quvKFu2LM8++yyQcot90aJFkenlhDDem28mEhf3GgDvvPPOfXdNWwKzK+B3x6tO7zFz5sy07WbOnPnQbR/XwIEDWbRoEVFRUf9Zlz9/foKDgx85/GpwcDALFiwAwGQysWTJkv/cyPPXX39x584dfHx8HjujECLrbN4MS5Z8hdZHKV26NCNHjjQ60mMzuwJuFGdnZ3r27MmUKVPSXT9kyBDmz59PdHR0hvvw9vbG1dWVgwcP8ttvv1GtWrX7RgIMCwujR48ezJ07VwaTEsJAd+5Av34XUeoDAKZNm4ajo6PBqR6f2XUj1DrduZH/o3///lk+sejQoUOpXr36fyZAgJRBrLp168a0adMeuo++ffsyb948rl69SnBwcNryW7du0bJlSz766CPq1KmTpbmFEI9n0iQ4d249EEvHjh1p0aKF0ZGeiFwG3qNw4cJ06tSJOXPmpLv+rbfeYsaMGSQlJWW4j8DAQNavX8/evXtp3rw5AHfu3CEwMJCePXvSoUOHbMkuhMickJCUiRrat+/Pn3/+meF3X5ZACvgDhg0b9tDeKIGBgSQkJGT4fjs7Oxo1akSnTp3SBsFZtmwZ27dvZ968efj5+eHn5/fInitCiKynNbz2GlhbwxdfQJ06dXjmmWeMjvXEzK4JxQgxMTFpz4sVK3bfxMF3+4PfNXnyZCZPnpzhvkwmE7t372b58uVpy7p3754joxIKIR5u0SLYsOELBg6sgpfXC0bHeWpyBf4UrK2tiYqKSruR58SJE5QpU4YmTZqkdRt8lKCgILZt24aDg0N2RhUizwsPh0GD9gFvM316c86ePWt0pKcmV+BPwcvLi0uXLqW9rlixIufOnXusfWR2mjQhxNMZPPgOt24FAyaGDn0rV3TlNYsr8Mz2PBE5Tz4bkRv8/DMsXfpZWp/v8ePHGx0pSxhewB0cHLh586YUCjOktebmzZvSvCMsWnQ09O17AqVSivbs2bNxcnIyOFXWMLwJxdPTk9DQULm13Ew5ODjg6elpdAwhntiIEclcvfoKkEj//v1p1KiR0ZGyjOEF3NbWllKlShkdQwiRC+3aBd98cx4npxBcXDwsYpadx2F4ARdCiOwQGwu9e4O3dxm2bz/B9evnKFiwoNGxspQUcCFErvTuu3DmDGzZAl5ehfDyqmF0pCxn+JeYQgiR1bZsgSlTplKjxgieey7jO6ctnRRwIUSuEh0N3bufQqnh7N//GVu3bjU6UraRJhQhRK7y5ptJXLnSE4inV69eaYPK5UZyBS6EyDXWr4c5cz4B9uLl5fXISVgsnRRwIUSu8O+/0KPHAWAcAPPmzct1vU4eJAVcCGHxUoaJjefGjZ5AEkOGDKFx48ZGx8p20gYuhLB4CxfCsmWxlCtXCq0T+eSTT4yOlCOkgAshLNq5czBwIAQEFGbz5h+JiAjPNWOdPMojm1CUUt8ppa4rpY7ds2yMUuqyUupQ6uOl7I0phBD/lZQEnTtHolQiCxeCjY2iaNGiRsfKMZlpA58HpDfj5xdaa7/Uxy9ZG0sIIR5t3DjNvn1BFC0agMkUYnScHPfIJhSt9XallHf2RxFCiMzbsQM+/HAK8AsREYWxscl7LcJP0wtlkFLqSGoTS6GMNlJK9VdK7VNK7ZMhY4UQWSEqCjp1OoTW7wAwZ86cPDns8ZMW8OmAD+AHhAGfZ7Sh1nqm1rqm1rpmkSJFnvBwQgiRQmt45ZXb
hIV1Ae7w2muv0bZtW6NjGeKJCrjW+prWOllrbQJmAbWyNpYQQqRvxgxYuXIo8DcVK1bk888zvH7M9Z6ogCul3O95GQgcy2hbIYTIKgcPwpAhm4HZ2Nvbs2TJkjzTZTA9j2z1V0otBhoCbkqpUOADoKFSyg/QQAjwajZmFEIIoqKgY0coWrQRr776OR4eBalSpYrRsQyVmV4oXdNZPCcbsgghRLq0hr59ISQEtm1T1K37ltGRzIKMhSKEMHtTp2pWrBjL22+HULeu0WnMhxRwIYRZ27cP3nxzOjCGxYsbcOfOHaMjmQ0p4EIIs3XzJrRuvY/k5DcB+Oyzz7CzszM4lfmQAi6EMEvJydC+/Q3CwjoAdxgwYABdunQxOpZZkQIuhDBL776bxLZtnYEL+Pv7M3nyZKMjmR0p4EIIs7N8OUyc+A6wmaJFi7Jq1Srs7e2NjmV2pIALIczKsWPQpw94e3vj6OjIypUr8+Q4J5mR94bvEkKYrchICAyEAgVg584hWFt3plixYkbHMltSwIUQZiE5GTp2vM7587fYurUMHh4AUrwfRppQhBBmYfjwRH7/vQN2dv7ADqPjWAQp4EIIw82erfniiwHAHxQq5ISPj4/RkSyCFHAhhKG2bYNXX/0cmI2DgwOrV6/G3d39ke8TUsCFEAY6exZefnkNJlPKzDrff/89tWrJ9AKZJQVcCGGIyEh44YUDxMQEAZqPP/6YDh06GB3LokgvFCFEjktKgs6d4dKl09jYJBIU1IsRI0YYHcviSAEXQuQorWHAAPjtN5g9uwtVq/rg6+uLUsroaBZHCrgQIkeNHZvErFn/8N575XnlFQB/oyNZLGkDF0LkmO++04wd+yo2Nv40aLDR6DgWTwq4ECJHrF8Pfft+AHyHrW0yzs4FjI5k8aSACyGy3YED0KbNdLQej7W1NcuWLaNOnTpGx7J4UsCFENnq/Hlo3Hg1d+4MBGDGjBm8/PLLBqfKHaSACyGyzdWrEBCwg6ioroBm3LhxvJLyzaXIAtILRQiRLf79F5o1g/Dwm9jZQXDwa4waNcroWMZISgKbrC+3cgUuhMhyMTHQsiX8/TesW9eGvXv/YurUqXmzr/fOnVCpUsoPI4vJFbgQIkslJECLFiHs3n2JFSsCaNoUoKrRsYyxdy+89BIUKwYFC2b57uUKXAiRZZKSoG3bK+zc2QQbm2YULZqHx/U+dIhdTZpgKlwYNm+G4sWz/BBSwIUQWcJkgu7db7B+fVPgHL6+lalaNY9eeR8/zg8BAdSLjqavvz/6mWey5TBSwIUQT81kgj59oli6tDlwgkqVKrF+/XqcnZ2NjpbzTp9meb169IyJQQOlq1bNtrZ/KeBCiKdiMsErr0SxYEFz4AA+Pj5s3LgRV1dXo6PlvH/+Yc3zz9MtMpJkYPTo0dna80YKuBDiiWkNAwdq5s1rA+zB29ubTZs25c0Zdf7+m9W1atHp5k2SgP/973+MHTs2Ww8pBVwI8US0hsGD4dtvFR07vkv58uXZsmULJUuWNDpazjtxgg3PPUfHf/8lERg2bBiffPJJtneblG6EQojHpjUMGaKZNk0xbBhMnNic5OSj2GTDzSpm79gxaNyYOjY2VKtUiaatW/PRRx/lSJ/3PPjTFkI8DZMJXnstklmz2tCu3SgmTmyKUuTN4n34MLpxY5SDAwU3b2Z7iRI4ODjk2A1L0oQihMi05GTo3v0ms2Y1BbZz6tRQkpOTjI5ljAMH+Pb55wmOj8e0ZQuUK4ejo2OO3m0qBVwIkSmJidCuXRiLFzcA9lGqVCl+/fXXvHnlvW0bnz//PK/HxjIvNpbfQ0IMifHIAq6U+k4pdV0pdeyeZYWVUhuVUmdS/y2UvTGFEEaKj4cXXzzPjz/WA45ToUIF/vjjD0qUKGF0tByn167lf02a8HZCAgBff/01zZo1MyRLZq7A5wEtHlg2AtiktX4W2JT6WgiRC92+DY0anWTTpgDgHDVq1GD79u08k013F5qz
pLlz6RsYyITkZGxsbPj+++8ZNGiQYXkeWcC11tuBiAcWtwHmpz6fD7TN4lxCCDMQGQnNm8OePdexsblBQEAAmzZtws3NzehoOS5+wgQ6BgfzndY4Ojqydu1aunfvbmimJ228Kqa1Dkt9fhUoltGGSqn+QH8gT/65JYSlunwZWrRIGQV16dIGPPPMZvz8/HBycjI6Ws7SGkaPxvTRR1x3dcUlOZl169ZRt25do5M9fTdCrbVWSumHrJ8JzASoWbNmhtsJIczHyZMQELCS27et+eWXtrzwAsDzRsfKeYmJ8OqrMHcuTn378tPHHxN2/TqVKlUyOhnw5L1Qriml3AFS/72edZGEEEbatQtq1Piamzc7onVXvL3/MTqSMW7d4nD9+rw2dy7JI0fCzJkULlLEbIo3PHkB/xHolfq8F7A2a+IIIYy0dq2JBg3eIS5uCKB5//3R+Pj4GB0r54WGst7Xl3q7dzMDmFqkCJjhbEKPbEJRSi0GGgJuSqlQ4APgU2CZUuoV4ALQKTtDCiGy3zffJDBwYB9gMTY2NsyZM4eePXsaHSvnHTrEzIYNGRAVRTLQtWtXXn31VaNTpeuRBVxr3TWDVU2yOIsQwgAmEwwdGsnXX7cHNpM/f35WrlxpWN9mI5l+/ZV327RhQmIiACNHjmTcuHFYWZnnPY958BYqIcRdt29D9+6wZs1FbGx24+ZWnF9++YVq1aoZHS1naU3c55/Tc/hwVpAyrsuMGTMIDg42OtlDSQEXIo+6cgVatYJDh+Crr6pSvvwaypUrm/eGg01IgIEDsZ0zh3+LFME5IYEVK1bQNGU2ZrMmBVyIPOjgQWjceCZxcY78+GMPWrYEMP+CleWuXUMHBqL+/BOb0aNZOngw12/coEKFCkYnyxQp4ELkMStWJNG169skJX2Fra0dlSvXB/LYVTfAgQNMb9KENbduse6HH7Dt2hVXwLVIEaOTZZp5tswLIbKcyQRvvx1Bx44vpxZvW2bM+DbvNZkAdxYu5LVatRgQGclvJhM/OzoaHemJyBW4EHlAZCS0anWYHTsCgfO4urqyevVqAgICjI6Ws+7cIfT11+n83XfsAuzt7Zk9ezZt21rmcE5SwIXI5Y4fh6ZNfyQsrAsQR/Xq1Vm5cj74Y7UAABx+SURBVCXe3t5GR8tZoaFsbNaMbidPcgPw9PRk1apV+Pv7G53siUkTihC52KpVUKcO3LlTFicnG/r06cOOHTvyXvH+/Xe2V65M89Ti3axZMw4ePGjRxRvkClyIXCkxEd5+O5IpUwpSu7Zi5cryJCYeoWTJkjk65ZfhTCb45BMYPZp65cvzoq8vtRo3ZtSoUVhbWxud7qlJARcil7l0CVq02MSJE0EEBIxh48bXsLcH8DY4WQ67fp1dbdpQcvdununWDasZM/jR0TFXFO67pAlFiFzkp5+SKFt2FCdONAWu4eS0Bju7vDeKc9KvvzKudGkCdu+mi48PiXPnQv78uap4gxRwIXKFxEQYMOASrVs3Ij7+I6ysFGPHjuXnn3/OW00md+5woX9/Gr30Eh/cvo1WirodOpjlSIJZQZpQhLBwly5B8+Y/cfJkbyACd3cPFi/+gQYNGhgdLWedOcOy5s3pf/48UYCHuzsLvv+eJk1y77h7cgUuhAVbsgSqVEnm778/ACJ46aWXOHz4UN4q3lrD3Lm8XrEinVOLd5s2bTh85EiuLt4gBVwIixQVBd27a7p2hQoVrNmwYQmTJ0/mp59+oogF3Qr+1K5ehTZtIDiYEl5eONjbM336dFavXp0nJl6WJhQhLMzWrckEBk4gKuowH3ywmFGjFDY2ZXnhhbJGR8tRMfPnc3LIEPzv3IHJkxk+cCDtQ0IoWzbv/BykgAthIRIT4Y03zjJ9ek9gFwAvvvgmNja1jQ2W027c4I9Onei9ZQuR1tYc37KF4gEB2ECeKt4gTShCWISDBzWlS89k+nRfYBfu7h5s2LCB2rXzVvGOXbaMN0uWpMGWLZwDPCtW
JKpoUaNjGUYKuBBm7M4dGDLkPDVqNCU09FXgNp07d+bYsaN5a8qzsDB+CwigcufOfBkbi5W1NaNHj2bvvn2UK1fO6HSGkSYUIczUgQPQuzccPToL2EThwq5Mnfo1XbtmNE1tLmQywezZjBkyhLEJCQBUqVyZ7+bOpWbNmgaHM55cgQthZhIS4N1371CrFty8CStWjGbo0KGcPHkibxXvU6egYUN49VWaV6qEk6Mjn3zyCfsPHJDinUquwIUwI5s23aFr14mEh39Ht24HmDq1IIUKOdK+/RdGR8s58fGcfucd1n7zDcOdnWHOHJ7r04dL//5L4cKFjU5nVqSAC2EGbtyA7t23smHDAOAkAC1a/EihQj2MDZbDYlas4MN+/ZgcGUki4PfttzTt1AlAinc6pAlFCANpDV98cY1nnunBhg2NgJP4+JRh48aN9OiRd4q3/ucffqhenXIdO/JZavEODg6mei6/k/JpyRW4EAY5cQLatVvC33+/BkRhZ2fPyJHv8c477+Dg4GB0vJwRG8vhoUMZPHs2f+iUURP9a9Zk6rRp1KpVy+Bw5k+uwIXIYVFR8Pbb4OcHly8XBqJo3rwFJ04c5/33388bxVtrWLYMKlZkzqxZ/KE1RVxdmTNnDrv37JHinUlyBS5EDjGZYPLky4wbt56YmFfo0wc++aQZZ8/uok6dOnlm2Ne4TZs49+abVDp6FKpWZey6dTj98QcjRozAxcXF6HgWRQq4EDlg06ZYevSYRFjYZ0AcCxb40aNHDQCKFn3O2HA5JPnUKRZ2786o/fuxsbbm1IwZ2L/yCoWsrfm0ZUuj41kkaUIRIhtduqSpW3cxL7xQnrCwD4BY2rdvR0CAq9HRcs6NG2wMDKRGhQr03r+fUMClUiWuNG0KuWyGnJwmBVyIbBAZCd2776JkyefZtasbcAlf32ps3bqVFStW5I1Z4W/d4q/+/WlevDjN1qzhMODl4cH8+fPZf/AgpUqVMjqhxZMmFCGyUHw8fPMNfPQRRETMBXbj5laMzz77mF69euW6ORnTdfs2TJtG8qef0vXffzkHOOfPz7sjR/LGG2/g6OhodMJcQwq4EFnAZIIvvwxhwoQbXLtWk+bN4c03P2DnTneGDx9OgQIFjI6Y/RISOD52LMVmz8YtPBzrF1/kw7p1ORITw9tvv42rax5qNsohUsCFeApaw6JF13nzzY+4cWM6Dg5l2bDhMM2aWQOeNG8+zuiI2S8+nlOffsr4iRNZHBvLW56eTNqxA+rWpSuQh0ZvyXFSwIV4AlrDkiU3eOutz7l6dSoQg1KK9u2rUbt2DFDQ6IjZ7/ZtDowezSfTp7MyPh4N2NrYoDp3hrp1jU6XJ0gBF+IxaA2rVt1i8OCPCAubBtwG4KWXWvLJJx9TtWpVYwPmhMhIDo8cyf9mzWJDYiIAdra29O7dm/dGjqRkyZIGB8w7pIALkQlaw8aNMGYM/PmnLVZW84HbtGjxEmPGvJ83Zsa5dg2+/BKmTSM2OpoNQD5HR14bMIC33noLDw8PoxPmOU9VwJVSIUA0kAwkaa1lkF6Rq5hMMHfuZUaPnkpY2Lt4eTnz7beOuLnNpEQJd/z9/Y2OmO3i9u1j0dChHPjzT77RGjp04Ln33mP2/v20bdtWvpw0UFZcgTfSWt/Igv0IYTbu3IGJE08wceIkoqIWAom0bp2fZctGYm8P0NrghNlMa64uWcI3o0Yx/dw57v4P/vqPP1KlVSsAXvHzMy6fAKQJRYj7REdrRo7cyaxZE4iP/wkApaxo374jI0e2TC3euVhcHIc++YQvp0xhcVQUd1IXV69alTeHD6dc8+aGxhP3e9oCroHflFIamKG1nvngBkqp/kB/gBIlSjzl4YTIHmFhMH06TJgwmISEaQDY2TkQHNyHYcPeokyZMgYnzGb//APffsvt776j7r//EgsopWjbqhVvDhtGQEBAnhlsy5I8bQGvp7W+rJQqCmxUSp3SWm+/d4PUoj4ToGbNmvop
jydEltqyJYqvv45m3TpPkpLA3785J0/+wNChgxg0aBBFixY1OmL2SUri7KxZzJ84kVHnz2NnY0O+tm3przUmDw+GvPEGPj4+RqcUD/FUBVxrfTn13+tKqdVALWD7w98lhLGSkmDKlBNMmDCVa9cWYGPThoEDFzFoEJQu3ZLY2Ivkz5/f6JjZJuniRda98w7T16zht9SZ3it27EiXr74Cd3fy0OybFu+JC7hSKh9gpbWOTn3eDMgDt50JSxUWlsz//vczy5Z9TULC72nLGzaMYPJkE1ZWVoBV7izeCQkc/+Yb5k+dyvfnznE1dbGDnR2dO3em4ttvg7u7oRHF43uaK/BiwOrUdjEb4Aet9fosSSVEFtEatmyB8eO3sXVrbyAEAHt7J3r16sHgwYOoXLmykRGzj9Zw8CDMnQs//EC3iAiOpK4qV7o0rw4aRK9evWSyYAv2xAVca30O8M3CLEJkmWvXkpkyJZTly0ty5gw4O3sBIXh6lmLo0IEEBwdTqFAho2Nmi8TQUH4bO5Z5y5fzcVQUz9rbQ9u2DCxWjP1xcfTu0ydPzQCUm0k3QpFrmEywbNlFPvpoDseOfQcU4PnnjzN6tKJDh9IcPbqHGjVq5MohXZNv3mT7Z5+x5IcfWHH5MhGpy8u1aMGHP/wAhQqldAUTuYoUcGHxjhy5zfvvr2bDhu+Jj99ISu9WKFGiDCtXXqN48eIAuW+i3Nu34aefGPXBB8w5fTqtXRugoo8PPfr2pXv37pBL/9IQUsCFhYqKSpnUfNq0Axw+XJ+7g0rZ2NjRtm07Xn+9Hw0bNkz9YjL3MEVHs3vKFHwPHybfL7/A7dtccHTkKuDj6UmXnj3p0rVr7m3XF/eRAi4sxp07MHv2MRYuPMrBg12Jj4fy5Svj5ORAhQpVeeWVHnTq1CnXjc2RcOUKmydNYs2qVay9cIFrwPL8+ekQFARduzKicGHeuHOHGjVqSLt2HiMFXJi1pCRYuPA0X3+9jMOHl5OcfARwpF+/lvTr50zNmnb8++/pXNeTQl+6xLL332f1L7/wy/XrRN+zzrt4ceI//RR69QKgkjERhRmQAi7MjskE69aF89FHs9i/fxnJyYfT1uXLV5AuXTozblwsxYs7A+SK4q0TEzm+eDGVTp1C/for6tAhJgAHUtf7Pvssgd260TYwkKpVq8qVtgCkgAszkZQE69dHsnGjCytWwJUrMcBIAJycnAkMbEvXrp144YUXsM8lI0pFnTnD7199xa+//ML6kBAua80xKysqBQTAp5/yppUVN2xtadOmjczgLtIlBVwYJibGxNSpe1i0aC0nTqzFZNLY25+iRQvo0qUUBw+Ool69WjRr1ix3FO34eGI2beKLCRP4/cABdsXEkHTP6uIuLlycNYtKHToA0N2YlMKCSAEXOSo0NI5Jk35nzZq1XLjwE3A9bV2+fC4cOXKN0qWLAdCly3iDUmaN5IQEDi1ezOlff6XrjRuwcycOCQlMJGUWFGsrKwJ8fXmxQwdavPgivr6+ua7XjMheUsBFttIa9uyJZ/NmB379FXbt2orJ9P+TIRQrVpKOHdsQGNiGgIAAbG1tDcv6tJLi4zm0fDk7V61iy59/su3aNSIBe6BtlSo4DhiATZMmfHb6NMVKlqRx48a4uLgYHVtYMCngIsuFhcUxdeo21qxZz99/byA5uTywmurVYfjwxvz2WwBt2zalbds2VKlSxWK/kNPR0ai//oIdO9iwdi3tDh4k9oFtShcrRuOmTYn+/HMcU4emfb1ly5wPK3IlKeDiqSUnw7p155k1aw07d64nMnI7EJ+2vlChGI4dM+HhYQXY8+mnljfisCk5mTPbtrF39Wr+3L6dnWfOUD8+nilag1KUKVeOWODZ4sWpGxBAQIsWNG7cGG9vb6Oji1xMCrh4bCaT5rffzrJ3rwMHDniydStERv4IvJW2TblyNWjXrgUvvticOnXqYGtrYW274eGwdy+zvv2WJbt2sT8igih9/3wkNu7u
KSP91alDaWdnroWH5+4JIITZkQIuMmXnzhDmzNnC1q1buHBhCyZTKPAu3t4f0749lC3blL/+6k7bts1p1qyZxRQyk8nE+X37OPzzzxzetYt9x44xQSkqhYUBcArYnLrtMy4u1KxaldpNm1K3fv2UGekdHQFQYDHnLHIPKeDiP7SGU6dgxw74+uvRnDy5kKSkkPu2yZ/flV69rJk69e6SisD3OZz0MSUlwenTxOzZw9tTpnAkJISjUVHEPHBlHVizJpXeegv8/ellZ0eD8HBq1qyJh4eHQcGFSJ8UcEFUVCyLFv3FTz/t5ODBXSQkzCAy0hMAB4cwkpJCcHR0oU6dBrRu3YjGjRtRuXJls+3yFhURwclNmzi1cycnDx3i5Llz2EVHsyI+HuLjcSLlV83dLxw9nJ3xLV+eqrVrU71ePerXrw+pIxhWTX0IYY6kgOdB587FMW3aGrZv38Pff+8iOvog3HNLSYMGO+nZszP16kFi4lskJAzA19fXrMbRTk5O5tI//+By4wYuly/DiRPMXLeOMYcPE5aU9J/tHa2sSH7jDayrVcPK15dZhw5R3NOTqlWr4ubmZsAZCPH0pIDncqdPX2bZsr0cOfIvJlMf/voLLl0C6AEkA6CUFSVL+tGwYV2aNatLkyaN+P/m3IrGBE9liojgp7lzOXv4MOdOn+ZsaChnIyIIiYsjEZgNvAKgFHZuboQlJeFoY0M5d3cqlC9PBX9/yvv5UbFiRVSFCpD6V0O3qnJdLSyfFPBcJDT0JitW/MWWLfs5dGgvV67sJSkpLHWtK6VK9aZuXUWtWo7s2jWA8uXdaNDgeWrXrk2BAgVyNKvWmhvh4Vw8epRLhw9z6eRJLp4/z6XLlyE2liXFi8M//6AiIugJ3EpnH+4FChDfvj0MHQply9I2IYGGkZGUKFHCbJt3hMhKFlPAExMTLfouvayUmJjIX3/9zc8/HyYhoTLXrvly8CCcPLkErQfdt629fUHKl69Jgwb+fPZZAg4ODgC8+eaUbMmmtSYqKoqrFy9y9dQprp45w9Xz57kaGkqnkiWpnpQEFy/y8aFDjLpxI919OCqFfvZZVMeOKB8femzZgnZ2xsfXF5+KFfHx8aF06dI4OTnd9z4XR0e5s1HkKZZRwD/9lDrjx/N3QgJFHB0pkj8/RVxcKOLqSpFixWhUrx4vtW0LRYoQZ2VF2NWrFCpUCGdnZ7Nqt30SycmwbNkONm06wIEDhzl37hBRUceBhNQtRvHMM75UqwZ16vize3d9ateuRuPG/tSq5U+ZMmWe+mo0MjKSK5cucTMkhJsXLxJx+TI3w8K4ef06LlozolIlCAsj4coVXLZuJf6BXh13lQKqu7tDiRI84+1NoehovFxdKeHujpe3NyXKl8erXDm8SpZE16uHSs09dfjwp8ovRG5lGQXcxYVIrbmdnMztmBhCYmLg6v/PAKhXreKlt1JuItlrZ0eDO3fS1jnb2OBiZ5dydZYvHwt79MDL2xtcXFh19CgXbt0iv4sL+VxcyFeoEPkKF8apUCGKFitGmTJlUvavNSaTKdt+GWituXbtOlu2HGfbthMcO/YPpUp9wfHjipMnIT5+AHD0vvcULlyKChX86NmzEv3TZquthcm0hfj4eOLj44mLi+PcuXPExcURGxlJAaWo6O4Ot25x89Ilvl2+nFuRkURHRXErOppbt29z6/ZtouPjmV2hAtWSkuDmTUZfusTUhIQHYwNQFhixZQsUL469uzsONjbYKUXxAgUoXrgwxYsVo/gzz1Dc25ta7dpB6ryUvbSmt4XeQi+EuVA6g6ul7FCzZk29b9++J3rv3T/Nb9y4Qfj164RfvEj4+fOEh4ZSy8ODxh4eEB7Opn37CP7lFyITEriVTm+Ey8Dd3rxtgB8zOF5zKyvWu7pCvnyE29tT9O+/sbeywtHKCjsrK+ysrdMe0+rVo3GpUmBry4LTp1l48iR2NjbY2thgpRQq9eGaLx/T2nUgIkLx9+mrjNi4
gMi4GJKT4tGpXyjeVcOhNS2eqU5xtyR2h65lWcQ5bG1tsLKxxsrammSTiYSkJOKSktj3/PMUSkqCuDiaHTvGxpiYdM+pPbAi9XkIKVfEGfmpdGleLl0aXF2ZdPEis06fxrVAAVwLFcLV1RXXokUpXLw4nmXK0Ou11yC1GMfHx6c10wghsoZSar/WuuaDyy3jChxQSuHi4oKLi0vKlfHzz6e7XRPgQurz5ORkoqOjiYyMTHn8+y9FfH1TZvOOiiLwhx8odfw4MdHR3L59m9uxscTGx3M7Lo6KxYqBvz/ExhIXHo76+28STCYSTKb/HDN+1y7480+4c4fTsbFsTGcbSPnF8e3u3RQBigBnuHcw1ft1jf+RHmd/hLMQZ23N+8nJEJf+trFXr1KoYEEoUACHfPmwj43F0cYGRxsbHGxscLSzw9HOjlIVKqRMw+XsjJuNDSOWL8fZ1RVnNzecixTB2c2NAs7OFChQgLJly0LBggC8nfrIDCneQuQci7kCN5rWOq1ZIqVpIoSDB89z5Mh5qlcfzLVrLpw9C6tX1yA29gBWgDUpt1iT+m9hl5r077OJZ8to3ItHsuCH9/B4xg334sVxLlAAZW2Nrb09Nvb2VKpcmYpVqoCVFeHh4Rw5cgQbG5u0h7W1NQ4ODjg4OODt7Y2NjU1aTksd3U8Ikb6MrsClgKdDa01CQgKJiQ6cPw9btx5n0aLJXL58noiI88TFXYL7mjwuYmXlhacnJCR0JSpqI8WLl8XH51l8fZ/F378s5co9S5kyZXK8u54QwvJZfBNKVktOhrVrd7B//xnOnLnEhQuXuHo1lIiIC9y+HYK9/WvEx09O3foW8N1978+f34NixUrh7V2K0aNNPPcc2NnBnTvzsbOzy/HzEULkPbmugMfGxnL69HlOngzj1KnL/PNPSnEOC7tERMRlKlbcR2ioLZcvQ3LyO8Cf6e7H0/M6ffqAjw8ULVqB/fu/oVKlUpQuXYqSJUtm2NYrxVsIkVMspoBfvHiRixevcOrUVf755yohIWFcvnyVa9fCcHfvQL58PQkLg3Pn1nPrVvsM95OcfIUGDUri5QVHj7bg9u0ylC7tSYUKXpQt64WXlxelSpXC2dn5nne50KjR69l/kkII8RgsooCPHg0ff1wfk+lCuuuvXPGhXLmeeHqCt7cXO3aUp1ChYnh4eFKypCflynlRsaIXJUp4UrmyO/9/kfx+jp2DEEJkNYso4M8+C15efiQmFsHNzZ3ixYtTokRxfHzcKVOmOFWqVKRcubtb+wMnDUwrhBA5Q3qhCCGEmcuoF4oM2SaEEBZKCrgQQlgoKeBCCGGhpIALIYSFeqoCrpRqoZT6Wyn1j1JqRFaFEkII8WhPXMCVUtbANOBFUiZO7KqUMnYCRSGEyEOe5gq8FvCP1vqc1voOsISUIbaFEELkgKcp4M8Al+55HZq67D5Kqf5KqX1KqX3h4eFPcTghhBD3yvY7MbXWM4GZAEqpcKVU+vfDP5obkP4suJbD0s9B8hvP0s/B0vODMedQMr2FT1PALwNe97z2TF2WIa11kSc9mFJqX3p3IlkSSz8HyW88Sz8HS88P5nUOT9OEshd4VilVSillB3Qh4ykmhRBCZLEnvgLXWicppQYBG0iZPew7rfXxLEsmhBDioZ6qDVxr/QvwSxZleZSZOXSc7GTp5yD5jWfp52Dp+cGMziFHRyMUQgiRdeRWeiGEsFBSwIUQwkKZXQF/1PgqSil7pdTS1PV7lFLeOZ8yY5nI3zu1P/yh1EdfI3JmRCn1nVLqulLqWAbrlVJqSur5HVFKVc/pjI+SiXNoqJSKuuczMKu59ZRSXkqpLUqpE0qp40qpN9LZxmw/h0zmN/fPwEEp9ZdS6nDqOYxNZxvja5HW2mwepPRmOQuUBuyAw0DFB7YZAHyb+rwLsNTo3I+Zvzcw1eisDzmH+kB14FgG618CfgUUUAfYY3TmJziHhsA6o3M+JL87UD31eQHgdDr/HZnt55DJ
/Ob+GSggf+pzW2APUOeBbQyvReZ2BZ6Z8VXaAPNTn68AmiilVA5mfBiLHx9Ga70diHjIJm2ABTrFbsBFKeWeM+kyJxPnYNa01mFa6wOpz6NJmeT1wWEqzPZzyGR+s5b6c41JfWmb+niwx4fhtcjcCnhmxldJ20ZrnQREAa45ku7RMjU+DNA+9c/eFUopr3TWm7PMnqO5ey71z+NflVKVjA6TkdQ/y6uRcgV4L4v4HB6SH8z8M1BKWSulDgHXgY1a6ww/A6NqkbkV8LzgJ8Bba10V2Mj//wYXOecAUFJr7Qt8DawxOE+6lFL5gZXAUK31LaPzPK5H5Df7z0Brnay19iNlmJBaSqnKRmd6kLkV8MyMr5K2jVLKBigI3MyRdI/2yPxa65ta64TUl7OBGjmULas89hg45kZrfevun8c65WY0W6WUm8Gx7qOUsiWl+C3SWq9KZxOz/hweld8SPoO7tNaRwBagxQOrDK9F5lbAMzO+yo9Ar9TnHYDNOvVbBDPwyPwPtFO2JqV90JL8CPRM7QVRB4jSWocZHepxKKWK322rVErVIuX/A3O5CCA12xzgpNZ6cgabme3nkJn8FvAZFFFKuaQ+dwSaAqce2MzwWpTtw8k+Dp3B+CpKqXHAPq31j6T8h/G9UuofUr6o6mJc4vtlMv8QpVRrIImU/L0NC5wOpdRiUnoIuCmlQoEPSPkCB631t6QMnfAS8A8QC/QxJmnGMnEOHYDXlVJJQBzQxYwuAgDqAj2Ao6ltsADvASXAIj6HzOQ398/AHZivUmYeswKWaa3XmVstklvphRDCQplbE4oQQohMkgIuhBAWSgq4EEJYKCngQghhoaSACyGEhZICLoQQFkoKuBBCWKj/A0ehJdY29BBtAAAAAElFTkSuQmCC",
539 | "text/plain": [
540 | ""
541 | ]
542 | },
543 | "metadata": {
544 | "needs_background": "light"
545 | },
546 | "output_type": "display_data"
547 | }
548 | ],
549 | "source": [
550 | "dT = 1e-3\n",
551 | "Tf = jnp.pi\n",
552 | "T = np.arange(0,Tf+dT,dT)\n",
553 | "\n",
554 | "plt.plot(t,np.exp(t),'-b',label='Analytical[y1]')\n",
555 | "plt.plot(T,fwd(params,T.reshape(-1,1))[:,0],'--k',label='NN[y1]',linewidth=2)\n",
556 | "\n",
557 | "plt.plot(t,0.5*np.exp(t)+1.5*np.exp(-t),'-r',label='Analytical[y2]')\n",
558 | "plt.plot(T,fwd(params,T.reshape(-1,1))[:,1],'--k',label='NN[y2]',linewidth=2)\n",
559 | "\n",
560 | "plt.legend()"
561 | ]
562 | }
563 | ],
564 | "metadata": {
565 | "colab": {
566 | "collapsed_sections": [],
567 | "include_colab_link": true,
568 | "name": "[5] System-of-ODE-PINN.ipynb",
569 | "provenance": []
570 | },
571 | "kernelspec": {
572 | "display_name": "Python 3.8.9 64-bit",
573 | "language": "python",
574 | "name": "python3"
575 | },
576 | "language_info": {
577 | "codemirror_mode": {
578 | "name": "ipython",
579 | "version": 3
580 | },
581 | "file_extension": ".py",
582 | "mimetype": "text/x-python",
583 | "name": "python",
584 | "nbconvert_exporter": "python",
585 | "pygments_lexer": "ipython3",
586 | "version": "3.8.9"
587 | },
588 | "vscode": {
589 | "interpreter": {
590 | "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
591 | }
592 | }
593 | },
594 | "nbformat": 4,
595 | "nbformat_minor": 0
596 | }
597 |
--------------------------------------------------------------------------------
/[6]_ODE_PINN_finite_difference.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "view-in-github",
7 | "colab_type": "text"
8 | },
9 | "source": [
10 | "
"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "metadata": {
17 | "id": "v77fdC1ZLyg1"
18 | },
19 | "outputs": [],
20 | "source": [
21 | "#Credits : Mahmoud Asem @Asem000"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": null,
27 | "metadata": {
28 | "colab": {
29 | "base_uri": "https://localhost:8080/"
30 | },
31 | "id": "vAR0swbLX_ZI",
32 | "outputId": "dab17b4c-908d-4213-ac91-c20d0f7ef122"
33 | },
34 | "outputs": [
35 | {
36 | "output_type": "stream",
37 | "name": "stdout",
38 | "text": [
39 | "Requirement already satisfied: optax in /usr/local/lib/python3.7/dist-packages (0.1.1)\n",
40 | "Requirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.7/dist-packages (from optax) (1.21.5)\n",
41 | "Requirement already satisfied: jaxlib>=0.1.37 in /usr/local/lib/python3.7/dist-packages (from optax) (0.3.2+cuda11.cudnn805)\n",
42 | "Requirement already satisfied: absl-py>=0.7.1 in /usr/local/lib/python3.7/dist-packages (from optax) (1.0.0)\n",
43 | "Requirement already satisfied: typing-extensions>=3.10.0 in /usr/local/lib/python3.7/dist-packages (from optax) (3.10.0.2)\n",
44 | "Requirement already satisfied: jax>=0.1.55 in /usr/local/lib/python3.7/dist-packages (from optax) (0.3.4)\n",
45 | "Requirement already satisfied: chex>=0.0.4 in /usr/local/lib/python3.7/dist-packages (from optax) (0.1.1)\n",
46 | "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py>=0.7.1->optax) (1.15.0)\n",
47 | "Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax) (0.11.2)\n",
48 | "Requirement already satisfied: dm-tree>=0.1.5 in /usr/local/lib/python3.7/dist-packages (from chex>=0.0.4->optax) (0.1.6)\n",
49 | "Requirement already satisfied: scipy>=1.2.1 in /usr/local/lib/python3.7/dist-packages (from jax>=0.1.55->optax) (1.4.1)\n",
50 | "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax>=0.1.55->optax) (3.3.0)\n",
51 | "Requirement already satisfied: flatbuffers<3.0,>=1.12 in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.37->optax) (2.0)\n"
52 | ]
53 | }
54 | ],
55 | "source": [
56 | "#Imports\n",
57 | "import jax \n",
58 | "import jax.numpy as jnp\n",
59 | "import numpy as np\n",
60 | "import matplotlib.pyplot as plt\n",
61 | "from matplotlib import cm\n",
62 | "import matplotlib as mpl\n",
63 | "!pip install optax\n",
64 | "import optax"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": null,
70 | "metadata": {
71 | "id": "yoPHsh5lWvyP"
72 | },
73 | "outputs": [],
74 | "source": [
75 | "import sympy as sp"
76 | ]
77 | },
78 | {
79 | "cell_type": "markdown",
80 | "metadata": {
81 | "id": "7bg4nSbsXVwD"
82 | },
83 | "source": [
84 | "### Generate a differential equation and its solution using SymPy"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": null,
90 | "metadata": {
91 | "id": "P9664e-mVMTN"
92 | },
93 | "outputs": [],
94 | "source": [
95 | "t= sp.symbols('t')\n",
96 | "f = sp.Function('y')\n",
97 | "diffeq = sp.Eq(f(t).diff(t,t) + f(t).diff(t)-t*sp.cos(2*sp.pi*t),0)\n",
98 | "sol = sp.simplify(sp.dsolve(diffeq,ics={f(0):1,f(t).diff(t).subs(t,0):10}).rhs)"
99 | ]
100 | },
101 | {
102 | "cell_type": "code",
103 | "execution_count": null,
104 | "metadata": {
105 | "colab": {
106 | "base_uri": "https://localhost:8080/",
107 | "height": 54
108 | },
109 | "id": "klgFeU6bcTrC",
110 | "outputId": "5c0fc009-ce8b-472f-b98b-6da8afd6a66c"
111 | },
112 | "outputs": [
113 | {
114 | "output_type": "execute_result",
115 | "data": {
116 | "text/plain": [
117 | "Eq(-t*cos(2*pi*t) + Derivative(y(t), t) + Derivative(y(t), (t, 2)), 0)"
118 | ],
119 | "text/latex": "$\\displaystyle - t \\cos{\\left(2 \\pi t \\right)} + \\frac{d}{d t} y{\\left(t \\right)} + \\frac{d^{2}}{d t^{2}} y{\\left(t \\right)} = 0$"
120 | },
121 | "metadata": {},
122 | "execution_count": 4
123 | }
124 | ],
125 | "source": [
126 | "diffeq"
127 | ]
128 | },
129 | {
130 | "cell_type": "code",
131 | "execution_count": null,
132 | "metadata": {
133 | "colab": {
134 | "base_uri": "https://localhost:8080/",
135 | "height": 60
136 | },
137 | "id": "E4Uu2hbiYJtv",
138 | "outputId": "6444bdb2-9243-46c9-e801-d19960dcb233"
139 | },
140 | "outputs": [
141 | {
142 | "output_type": "execute_result",
143 | "data": {
144 | "text/plain": [
145 | "Eq(Subs(Derivative(y(t), t), t, 0), 10)"
146 | ],
147 | "text/latex": "$\\displaystyle \\left. \\frac{d}{d t} y{\\left(t \\right)} \\right|_{\\substack{ t=0 }} = 10$"
148 | },
149 | "metadata": {},
150 | "execution_count": 5
151 | }
152 | ],
153 | "source": [
154 | "sp.Eq(f(t).diff(t).subs(t,0),10)"
155 | ]
156 | },
157 | {
158 | "cell_type": "code",
159 | "execution_count": null,
160 | "metadata": {
161 | "colab": {
162 | "base_uri": "https://localhost:8080/",
163 | "height": 38
164 | },
165 | "id": "29QUbt_2YwlJ",
166 | "outputId": "0e3cce9d-29c7-4f26-c279-535093e2ad37"
167 | },
168 | "outputs": [
169 | {
170 | "output_type": "execute_result",
171 | "data": {
172 | "text/plain": [
173 | "Eq(y(0), 1)"
174 | ],
175 | "text/latex": "$\\displaystyle y{\\left(0 \\right)} = 1$"
176 | },
177 | "metadata": {},
178 | "execution_count": 6
179 | }
180 | ],
181 | "source": [
182 | "sp.Eq(f(t).subs(t,0),1)"
183 | ]
184 | },
185 | {
186 | "cell_type": "code",
187 | "execution_count": null,
188 | "metadata": {
189 | "colab": {
190 | "base_uri": "https://localhost:8080/",
191 | "height": 81
192 | },
193 | "id": "r9KVq1yjYfld",
194 | "outputId": "9a8437fc-c2fc-4fe7-ad9f-26ca25b974d4"
195 | },
196 | "outputs": [
197 | {
198 | "output_type": "execute_result",
199 | "data": {
200 | "text/plain": [
201 | "Eq(y(t), (2*pi*t*exp(t)*sin(2*pi*t) + 8*pi**3*t*exp(t)*sin(2*pi*t) - 16*pi**4*t*exp(t)*cos(2*pi*t) - 4*pi**2*t*exp(t)*cos(2*pi*t) + 16*pi**3*exp(t)*sin(2*pi*t) + exp(t)*cos(2*pi*t) + 12*pi**2*exp(t)*cos(2*pi*t) - exp(t) + 36*pi**2*exp(t) + 336*pi**4*exp(t) + 704*pi**6*exp(t) - 640*pi**6 - 304*pi**4 - 44*pi**2)*exp(-t)/(4*pi**2*(1 + 8*pi**2 + 16*pi**4)))"
202 | ],
203 | "text/latex": "$\\displaystyle y{\\left(t \\right)} = \\frac{\\left(2 \\pi t e^{t} \\sin{\\left(2 \\pi t \\right)} + 8 \\pi^{3} t e^{t} \\sin{\\left(2 \\pi t \\right)} - 16 \\pi^{4} t e^{t} \\cos{\\left(2 \\pi t \\right)} - 4 \\pi^{2} t e^{t} \\cos{\\left(2 \\pi t \\right)} + 16 \\pi^{3} e^{t} \\sin{\\left(2 \\pi t \\right)} + e^{t} \\cos{\\left(2 \\pi t \\right)} + 12 \\pi^{2} e^{t} \\cos{\\left(2 \\pi t \\right)} - e^{t} + 36 \\pi^{2} e^{t} + 336 \\pi^{4} e^{t} + 704 \\pi^{6} e^{t} - 640 \\pi^{6} - 304 \\pi^{4} - 44 \\pi^{2}\\right) e^{- t}}{4 \\pi^{2} \\left(1 + 8 \\pi^{2} + 16 \\pi^{4}\\right)}$"
204 | },
205 | "metadata": {},
206 | "execution_count": 7
207 | }
208 | ],
209 | "source": [
210 | "sp.Eq(f(t),sol)"
211 | ]
212 | },
213 | {
214 | "cell_type": "code",
215 | "execution_count": null,
216 | "metadata": {
217 | "colab": {
218 | "base_uri": "https://localhost:8080/",
219 | "height": 37
220 | },
221 | "id": "MNVOpPyCW-GU",
222 | "outputId": "53a70887-e440-43a8-f811-30c862b21866"
223 | },
224 | "outputs": [
225 | {
226 | "output_type": "execute_result",
227 | "data": {
228 | "text/plain": [
229 | "0"
230 | ],
231 | "text/latex": "$\\displaystyle 0$"
232 | },
233 | "metadata": {},
234 | "execution_count": 8
235 | }
236 | ],
237 | "source": [
238 | "#verify solution\n",
239 | "sp.simplify(-t*sp.cos(sp.pi*2*t)+sol.diff(t)+sol.diff(t,t))"
240 | ]
241 | },
242 | {
243 | "cell_type": "markdown",
244 | "metadata": {
245 | "id": "NQ61lEQeXgrc"
246 | },
247 | "source": [
248 | "### Constructing the MLP"
249 | ]
250 | },
251 | {
252 | "cell_type": "code",
253 | "execution_count": null,
254 | "metadata": {
255 | "id": "Lml6PGLPZgmr"
256 | },
257 | "outputs": [],
258 | "source": [
259 | "N_b = 1\n",
260 | "N_c = 100\n",
261 | "\n",
262 | "tmin,tmax=0. ,jnp.pi\n",
263 | "\n",
264 | "'''boundary conditions'''\n",
265 | "\n",
266 | "\n",
267 | "# U[0] = 1\n",
268 | "t_0 = jnp.ones([N_b,1],dtype='float32')*0.\n",
269 | "ic_0 = jnp.ones_like(t_0) \n",
270 | "IC_0 = jnp.concatenate([t_0,ic_0],axis=1)\n",
271 | "\n",
272 | "# U_t[0] = 10\n",
273 | "t_b1 = jnp.zeros([N_b,1])\n",
274 | "bc_1 = jnp.ones_like(t_b1) * 10\n",
275 | "BC_1 = jnp.concatenate([t_b1,bc_1],axis=1)\n",
276 | "\n",
277 | "conds = [IC_0,BC_1]\n",
278 | "\n",
279 | "#collocation points\n",
280 | "\n",
281 | "key=jax.random.PRNGKey(0)\n",
282 | "\n",
283 | "t_c = jnp.linspace(tmin,tmax,N_c).reshape(-1,1)\n",
284 | "colloc = t_c\n",
285 | "\n",
286 | "def ODE_loss(t,u):\n",
287 | " dt = 0.03173326\n",
288 | " u_t = lambda t: (-u(t+2*dt)+8*u(t+dt)-8*u(t-dt)+u(t-2*dt))/(12*dt)\n",
289 | " u_tt = lambda t: (-u(t+2*dt) + 16*u(t+dt) -30*u(t) + 16 * u(t-dt) - u(t-2*dt))/(12*dt**2)\n",
290 | " return -t*jnp.cos(2*jnp.pi*t) + u_t(t) + u_tt(t)"
291 | ]
292 | },
293 | {
294 | "cell_type": "code",
295 | "execution_count": null,
296 | "metadata": {
297 | "id": "KoZZJl2TbI_n"
298 | },
299 | "outputs": [],
300 | "source": [
301 | "def init_params(layers):\n",
302 | " keys = jax.random.split(jax.random.PRNGKey(0),len(layers)-1)\n",
303 | " params = list()\n",
304 | " for key,n_in,n_out in zip(keys,layers[:-1],layers[1:]):\n",
305 | " lb, ub = -(1 / jnp.sqrt(n_in)), (1 / jnp.sqrt(n_in)) # uniform lower and upper bound +/- 1/sqrt(n_in) (fan-in scaling; not the Xavier/Glorot sqrt(6/(n_in+n_out)) bound)\n",
306 | " W = lb + (ub-lb) * jax.random.uniform(key,shape=(n_in,n_out))\n",
307 | " B = jax.random.uniform(key,shape=(n_out,))\n",
308 | " params.append({'W':W,'B':B})\n",
309 | " return params\n",
310 | "\n",
311 | "def fwd(params,t):\n",
312 | " X = jnp.concatenate([t],axis=1)\n",
313 | " *hidden,last = params\n",
314 | " for layer in hidden :\n",
315 | " X = jax.nn.tanh(X@layer['W']+layer['B'])\n",
316 | " return X@last['W'] + last['B']\n",
317 | "\n",
318 | "@jax.jit\n",
319 | "def MSE(true,pred):\n",
320 | " return jnp.mean((true-pred)**2)\n",
321 | "\n",
322 | "def loss_fun(params,colloc,conds):\n",
323 | " t_c =colloc[:,[0]]\n",
324 | " ufunc = lambda t : fwd(params,t)\n",
325 | " ufunc_t=lambda t:jax.grad(lambda t:jnp.sum(ufunc(t)))(t)\n",
326 | " loss =jnp.mean(ODE_loss(t_c,ufunc) **2)\n",
327 | "\n",
328 | " t_ic,u_ic = conds[0][:,[0]],conds[0][:,[1]] \n",
329 | " loss += MSE(u_ic,ufunc(t_ic))\n",
330 | "\n",
331 | " t_bc,u_bc = conds[1][:,[0]],conds[1][:,[1]] \n",
332 | " loss += MSE(u_bc,ufunc_t(t_bc))\n",
333 | "\n",
334 | " return loss\n",
335 | "\n",
336 | "@jax.jit\n",
337 | "def update(opt_state,params,colloc,conds):\n",
338 | " # Get the gradient of the loss w.r.t. the MLP params\n",
339 | " grads=jax.jit(jax.grad(loss_fun,0))(params,colloc,conds)\n",
340 | " \n",
341 | " #Update params\n",
342 | " updates, opt_state = optimizer.update(grads, opt_state)\n",
343 | " params = optax.apply_updates(params, updates)\n",
344 | "\n",
345 | " # Alternative (unused): plain gradient-descent step without optax\n",
346 | " # return jax.tree_multimap(lambda params,grads : params-LR*grads, params,grads)\n",
347 | " return opt_state,params\n"
348 | ]
349 | },
350 | {
351 | "cell_type": "code",
352 | "execution_count": null,
353 | "metadata": {
354 | "id": "ae1ZDoy0c29c"
355 | },
356 | "outputs": [],
357 | "source": [
358 | "# construct the MLP: 1 input, 4 hidden layers of 20 neurons each, 1 output\n",
359 | "params = init_params([1] + [20]*4+[1])"
360 | ]
361 | },
362 | {
363 | "cell_type": "code",
364 | "execution_count": null,
365 | "metadata": {
366 | "id": "jySmbUwic5yk"
367 | },
368 | "outputs": [],
369 | "source": [
370 | "lr = optax.piecewise_constant_schedule(1e-3,{10_000:5e-3,30_000:1e-3,50_000:5e-4,70_000:1e-4})\n",
371 | "optimizer = optax.adam(1e-3)\n",
372 | "opt_state = optimizer.init(params)"
373 | ]
374 | },
375 | {
376 | "cell_type": "code",
377 | "execution_count": null,
378 | "metadata": {
379 | "colab": {
380 | "base_uri": "https://localhost:8080/"
381 | },
382 | "id": "kBzGA8OVc8C6",
383 | "outputId": "ed7565fa-5172-42c8-ce96-39b3cd357d5b"
384 | },
385 | "outputs": [
386 | {
387 | "output_type": "stream",
388 | "name": "stdout",
389 | "text": [
390 | "Epoch=0\tloss=1.026e+02\n",
391 | "Epoch=1000\tloss=1.134e+01\n",
392 | "Epoch=2000\tloss=7.139e+00\n",
393 | "Epoch=3000\tloss=4.549e+00\n",
394 | "Epoch=4000\tloss=2.824e+00\n",
395 | "Epoch=5000\tloss=2.343e+00\n",
396 | "Epoch=6000\tloss=2.147e+00\n",
397 | "Epoch=7000\tloss=1.904e+00\n",
398 | "Epoch=8000\tloss=1.563e+00\n",
399 | "Epoch=9000\tloss=1.016e+00\n",
400 | "Epoch=10000\tloss=9.470e-01\n",
401 | "Epoch=11000\tloss=9.235e-01\n",
402 | "Epoch=12000\tloss=9.041e-01\n",
403 | "Epoch=13000\tloss=8.905e-01\n",
404 | "Epoch=14000\tloss=8.731e-01\n",
405 | "Epoch=15000\tloss=8.661e-01\n",
406 | "Epoch=16000\tloss=8.594e-01\n",
407 | "Epoch=17000\tloss=8.494e-01\n",
408 | "Epoch=18000\tloss=8.424e-01\n",
409 | "Epoch=19000\tloss=8.380e-01\n",
410 | "Epoch=20000\tloss=8.348e-01\n",
411 | "Epoch=21000\tloss=8.259e-01\n",
412 | "Epoch=22000\tloss=8.208e-01\n",
413 | "Epoch=23000\tloss=8.104e-01\n",
414 | "Epoch=24000\tloss=8.035e-01\n",
415 | "Epoch=25000\tloss=7.996e-01\n",
416 | "Epoch=26000\tloss=7.991e-01\n",
417 | "Epoch=27000\tloss=7.855e-01\n",
418 | "Epoch=28000\tloss=7.851e-01\n",
419 | "Epoch=29000\tloss=7.818e-01\n",
420 | "Epoch=30000\tloss=7.784e-01\n",
421 | "Epoch=31000\tloss=7.729e-01\n",
422 | "Epoch=32000\tloss=7.824e-01\n",
423 | "Epoch=33000\tloss=7.703e-01\n",
424 | "Epoch=34000\tloss=7.798e-01\n",
425 | "Epoch=35000\tloss=7.726e-01\n",
426 | "Epoch=36000\tloss=7.629e-01\n",
427 | "Epoch=37000\tloss=7.736e-01\n",
428 | "Epoch=38000\tloss=7.748e-01\n",
429 | "Epoch=39000\tloss=7.663e-01\n",
430 | "Epoch=40000\tloss=7.684e-01\n",
431 | "Epoch=41000\tloss=7.751e-01\n",
432 | "Epoch=42000\tloss=7.682e-01\n",
433 | "Epoch=43000\tloss=7.782e-01\n",
434 | "Epoch=44000\tloss=7.733e-01\n",
435 | "Epoch=45000\tloss=7.665e-01\n",
436 | "Epoch=46000\tloss=7.643e-01\n",
437 | "Epoch=47000\tloss=7.625e-01\n",
438 | "Epoch=48000\tloss=7.619e-01\n",
439 | "Epoch=49000\tloss=7.627e-01\n",
440 | "Epoch=50000\tloss=7.688e-01\n",
441 | "Epoch=51000\tloss=7.621e-01\n",
442 | "Epoch=52000\tloss=7.633e-01\n",
443 | "Epoch=53000\tloss=7.616e-01\n",
444 | "Epoch=54000\tloss=7.713e-01\n",
445 | "Epoch=55000\tloss=7.645e-01\n",
446 | "Epoch=56000\tloss=7.589e-01\n",
447 | "Epoch=57000\tloss=7.626e-01\n",
448 | "Epoch=58000\tloss=7.668e-01\n",
449 | "Epoch=59000\tloss=7.682e-01\n",
450 | "Epoch=60000\tloss=7.605e-01\n",
451 | "Epoch=61000\tloss=7.642e-01\n",
452 | "Epoch=62000\tloss=7.570e-01\n",
453 | "Epoch=63000\tloss=7.593e-01\n",
454 | "Epoch=64000\tloss=7.541e-01\n",
455 | "Epoch=65000\tloss=7.565e-01\n",
456 | "Epoch=66000\tloss=7.604e-01\n",
457 | "Epoch=67000\tloss=7.644e-01\n",
458 | "Epoch=68000\tloss=7.555e-01\n",
459 | "Epoch=69000\tloss=7.600e-01\n",
460 | "Epoch=70000\tloss=7.641e-01\n",
461 | "Epoch=71000\tloss=7.505e-01\n",
462 | "Epoch=72000\tloss=7.549e-01\n",
463 | "Epoch=73000\tloss=7.627e-01\n",
464 | "Epoch=74000\tloss=7.575e-01\n",
465 | "Epoch=75000\tloss=7.568e-01\n",
466 | "Epoch=76000\tloss=7.699e-01\n",
467 | "Epoch=77000\tloss=7.498e-01\n",
468 | "Epoch=78000\tloss=7.505e-01\n",
469 | "Epoch=79000\tloss=7.544e-01\n",
470 | "Epoch=80000\tloss=7.511e-01\n",
471 | "Epoch=81000\tloss=7.579e-01\n",
472 | "Epoch=82000\tloss=7.470e-01\n",
473 | "Epoch=83000\tloss=7.527e-01\n",
474 | "Epoch=84000\tloss=7.583e-01\n",
475 | "Epoch=85000\tloss=7.445e-01\n",
476 | "Epoch=86000\tloss=7.476e-01\n",
477 | "Epoch=87000\tloss=7.633e-01\n",
478 | "Epoch=88000\tloss=7.500e-01\n",
479 | "Epoch=89000\tloss=7.467e-01\n",
480 | "Epoch=90000\tloss=7.447e-01\n",
481 | "Epoch=91000\tloss=7.495e-01\n",
482 | "Epoch=92000\tloss=7.487e-01\n",
483 | "Epoch=93000\tloss=7.542e-01\n",
484 | "Epoch=94000\tloss=7.583e-01\n",
485 | "Epoch=95000\tloss=7.612e-01\n",
486 | "Epoch=96000\tloss=7.489e-01\n",
487 | "Epoch=97000\tloss=7.502e-01\n",
488 | "Epoch=98000\tloss=7.474e-01\n",
489 | "Epoch=99000\tloss=7.563e-01\n",
490 | "CPU times: user 1min 57s, sys: 256 ms, total: 1min 57s\n",
491 | "Wall time: 3min 3s\n"
492 | ]
493 | }
494 | ],
495 | "source": [
496 | "%%time\n",
497 | "epochs = 100_000\n",
498 | "for _ in range(epochs):\n",
499 | " opt_state,params = update(opt_state,params,colloc,conds)\n",
500 | "\n",
501 | " # print loss and epoch info\n",
502 | " if _ %(1000) ==0:\n",
503 | " print(f'Epoch={_}\\tloss={loss_fun(params,colloc,conds):.3e}')"
504 | ]
505 | },
506 | {
507 | "cell_type": "code",
508 | "execution_count": null,
509 | "metadata": {
510 | "colab": {
511 | "base_uri": "https://localhost:8080/",
512 | "height": 282
513 | },
514 | "id": "eWeNvDsdDEuI",
515 | "outputId": "1b79ef5d-7f2d-40ee-f4a1-4d0ba1cb0edd"
516 | },
517 | "outputs": [
518 | {
519 | "output_type": "execute_result",
520 | "data": {
521 | "text/plain": [
522 | ""
523 | ]
524 | },
525 | "metadata": {},
526 | "execution_count": 36
527 | },
528 | {
529 | "output_type": "display_data",
530 | "data": {
531 | "text/plain": [
532 | ""
533 | ],
534 | "image/png": "\n"
535 | },
536 | "metadata": {
537 | "needs_background": "light"
538 | }
539 | }
540 | ],
541 | "source": [
542 | "lam_sol= sp.lambdify(t,sol)\n",
543 | "\n",
544 | "dT = 1e-3\n",
545 | "Tf = jnp.pi\n",
546 | "T = np.arange(0,Tf+dT,dT)\n",
547 | "\n",
548 | "\n",
549 | "sym_sol =np.array([lam_sol(i) for i in T])\n",
550 | "\n",
551 | "plt.plot(T,sym_sol,'--r',label='sympy solution')\n",
552 | "plt.plot(T,fwd(params,T.reshape(-1,1))[:,0],'--k',label='NN solution')\n",
553 | "plt.legend()"
554 | ]
555 | }
556 | ],
557 | "metadata": {
558 | "colab": {
559 | "collapsed_sections": [],
560 | "name": "[6] ODE-PINN finite difference.ipynb",
561 | "provenance": [],
562 | "include_colab_link": true
563 | },
564 | "kernelspec": {
565 | "display_name": "Python 3",
566 | "language": "python",
567 | "name": "python3"
568 | },
569 | "language_info": {
570 | "codemirror_mode": {
571 | "name": "ipython",
572 | "version": 3
573 | },
574 | "file_extension": ".py",
575 | "mimetype": "text/x-python",
576 | "name": "python",
577 | "nbconvert_exporter": "python",
578 | "pygments_lexer": "ipython3",
579 | "version": "3.9.6"
580 | }
581 | },
582 | "nbformat": 4,
583 | "nbformat_minor": 0
584 | }
--------------------------------------------------------------------------------