├── .DS_Store ├── .coveragerc ├── .gitattributes ├── .gitignore ├── .ipynb_checkpoints ├── NN_ex3-checkpoint.ipynb ├── PyTorch_Diffusion-checkpoint.ipynb ├── PyTorch_V1-checkpoint.ipynb ├── PyTorch_V2-checkpoint.ipynb ├── PyTorch_V3-checkpoint.ipynb ├── PyTorch_V4-checkpoint.ipynb ├── PyTorch_V4_multi_traj-checkpoint.ipynb ├── PyTorch_V_KS-checkpoint.ipynb ├── README-checkpoint.md └── setup-checkpoint.py ├── AUTHORS.rst ├── CHANGELOG.rst ├── LICENSE.txt ├── README.md ├── config ├── .DS_Store ├── start_notebook.sh └── stop_notebook.sh ├── examples ├── .DS_Store ├── .gitignore ├── ODE_Couple_spring_mass.ipynb ├── ODE_Example_Maxwell_model.ipynb ├── ODE_Example_coupled_nonlin.ipynb ├── ODE_Logistic_equation.ipynb ├── ODE_Lotka_Volterra.ipynb ├── PDE_2D_Advection-Diffusion.ipynb ├── PDE_Burgers.ipynb ├── PDE_KdV.ipynb ├── VE_datagen.py ├── data │ ├── Advection_diffusion.mat │ ├── burgers.npy │ └── kdv.npy └── runs │ └── .DS_Store ├── requirements.txt ├── setup.cfg ├── setup.py ├── src ├── .DS_Store ├── __pycache__ │ ├── DeepMod.cpython-36.pyc │ ├── DeepMod.cpython-37.pyc │ ├── library_function.cpython-36.pyc │ ├── library_function.cpython-37.pyc │ ├── neural_net.cpython-36.pyc │ ├── neural_net.cpython-37.pyc │ ├── sparsity.cpython-36.pyc │ └── sparsity.cpython-37.pyc └── deepymod_torch │ ├── DeepMod.py │ ├── __init__.py │ ├── library_functions.py │ ├── losses.py │ ├── network.py │ ├── output.py │ ├── sparsity.py │ ├── training.py │ └── utilities.py └── tests ├── Untitled.ipynb ├── burgers.py ├── data ├── burgers.npy └── keller_segel.npy ├── diffusion.py ├── keller_segel.py └── runs ├── Mar18_10-10-19_2a6de4b656b9 └── events.out.tfevents.1584526219.2a6de4b656b9.90002.0 ├── Mar18_10-16-15_2a6de4b656b9 └── events.out.tfevents.1584526575.2a6de4b656b9.90002.1 ├── Mar18_10-54-59_2a6de4b656b9 └── events.out.tfevents.1584528899.2a6de4b656b9.84437.0 ├── Mar18_10-55-36_2a6de4b656b9 └── events.out.tfevents.1584528936.2a6de4b656b9.84887.0 ├── Mar18_10-55-50_2a6de4b656b9 └── 
events.out.tfevents.1584528950.2a6de4b656b9.85113.0 ├── Mar18_10-59-05_2a6de4b656b9 └── events.out.tfevents.1584529145.2a6de4b656b9.85113.1 ├── Mar18_11-18-25_2a6de4b656b9 └── events.out.tfevents.1584530305.2a6de4b656b9.1129.0 ├── Mar18_11-20-38_2a6de4b656b9 └── events.out.tfevents.1584530438.2a6de4b656b9.1129.1 ├── Mar18_11-21-06_2a6de4b656b9 └── events.out.tfevents.1584530466.2a6de4b656b9.3011.0 ├── Mar18_11-23-42_2a6de4b656b9 └── events.out.tfevents.1584530622.2a6de4b656b9.3011.1 ├── Mar18_11-25-57_2a6de4b656b9 └── events.out.tfevents.1584530757.2a6de4b656b9.6415.0 ├── Mar18_11-28-57_2a6de4b656b9 └── events.out.tfevents.1584530937.2a6de4b656b9.6415.1 ├── Mar18_11-52-32_2a6de4b656b9 └── events.out.tfevents.1584532352.2a6de4b656b9.24739.0 ├── Mar18_11-58-06_2a6de4b656b9 └── events.out.tfevents.1584532686.2a6de4b656b9.28482.0 ├── Mar18_12-57-02_2a6de4b656b9 └── events.out.tfevents.1584536222.2a6de4b656b9.65888.0 ├── Mar18_12-59-42_2a6de4b656b9 └── events.out.tfevents.1584536382.2a6de4b656b9.65888.1 ├── Mar18_13-00-54_2a6de4b656b9 └── events.out.tfevents.1584536454.2a6de4b656b9.68447.0 ├── Mar18_13-04-42_2a6de4b656b9 └── events.out.tfevents.1584536682.2a6de4b656b9.68447.1 ├── Mar18_13-45-28_2a6de4b656b9 └── events.out.tfevents.1584539128.2a6de4b656b9.97639.0 ├── Mar18_13-47-36_2a6de4b656b9 └── events.out.tfevents.1584539256.2a6de4b656b9.99060.0 ├── Mar18_13-50-01_2a6de4b656b9 └── events.out.tfevents.1584539401.2a6de4b656b9.1039.0 ├── Mar18_13-51-19_2a6de4b656b9 └── events.out.tfevents.1584539479.2a6de4b656b9.1935.0 ├── Mar18_14-04-33_2a6de4b656b9 └── events.out.tfevents.1584540273.2a6de4b656b9.10769.0 ├── Mar18_14-04-57_2a6de4b656b9 └── events.out.tfevents.1584540297.2a6de4b656b9.11078.0 ├── Mar18_14-05-40_2a6de4b656b9 └── events.out.tfevents.1584540340.2a6de4b656b9.99504.0 ├── Mar18_14-06-48_2a6de4b656b9 └── events.out.tfevents.1584540408.2a6de4b656b9.99504.1 ├── Mar18_14-07-09_2a6de4b656b9 └── events.out.tfevents.1584540429.2a6de4b656b9.99504.2 ├── 
Mar18_14-09-01_2a6de4b656b9 └── events.out.tfevents.1584540541.2a6de4b656b9.13828.0 ├── Mar18_14-09-39_2a6de4b656b9 └── events.out.tfevents.1584540579.2a6de4b656b9.13828.1 ├── Mar18_14-20-25_2a6de4b656b9 └── events.out.tfevents.1584541225.2a6de4b656b9.22220.0 ├── Mar18_14-21-03_2a6de4b656b9 └── events.out.tfevents.1584541263.2a6de4b656b9.22220.1 ├── Mar18_14-25-46_2a6de4b656b9 └── events.out.tfevents.1584541546.2a6de4b656b9.26142.0 ├── Mar18_14-26-20_2a6de4b656b9 └── events.out.tfevents.1584541580.2a6de4b656b9.26142.1 └── Mar18_14-49-13_2a6de4b656b9 └── events.out.tfevents.1584542953.2a6de4b656b9.41183.0 /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/.DS_Store -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | # .coveragerc to control coverage.py 2 | [run] 3 | branch = True 4 | source = deepymod_torch 5 | # omit = bad_file.py 6 | 7 | [paths] 8 | source = 9 | src/ 10 | */site-packages/ 11 | 12 | [report] 13 | # Regexes for lines to exclude from consideration 14 | exclude_lines = 15 | # Have to re-enable the standard pragma 16 | pragma: no cover 17 | 18 | # Don't complain about missing debug-only code: 19 | def __repr__ 20 | if self\.debug 21 | 22 | # Don't complain if tests don't hit defensive assertion code: 23 | raise AssertionError 24 | raise NotImplementedError 25 | 26 | # Don't complain if non-runnable code isn't run: 27 | if 0: 28 | if __name__ == .__main__.: 29 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | 
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Temporary and binary files 2 | *~ 3 | *.py[cod] 4 | *.so 5 | *.cfg 6 | !.isort.cfg 7 | !setup.cfg 8 | *.orig 9 | *.log 10 | *.pot 11 | __pycache__/* 12 | .cache/* 13 | .*.swp 14 | */.ipynb_checkpoints/* 15 | 16 | # Project files 17 | .ropeproject 18 | .project 19 | .pydevproject 20 | .settings 21 | .idea 22 | tags 23 | 24 | # Package files 25 | *.egg 26 | *.eggs/ 27 | .installed.cfg 28 | *.egg-info 29 | 30 | # Unittest and coverage 31 | htmlcov/* 32 | .coverage 33 | .tox 34 | junit.xml 35 | coverage.xml 36 | .pytest_cache/ 37 | 38 | # Build and docs folder/files 39 | build/* 40 | dist/* 41 | sdist/* 42 | docs/api/* 43 | docs/_rst/* 44 | docs/_build/* 45 | cover/* 46 | MANIFEST 47 | 48 | # Per-project virtualenvs 49 | .venv*/ 50 | 51 | .vscode/ 52 | 53 | src/deepymod_torch/.ipynb_checkpoints/ 54 | 55 | tests/datasets/.ipynb_checkpoints/ 56 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/PyTorch_V1-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 80, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "sys.path.append('src/')" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 81, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import numpy as np\n", 20 | "import torch, torch.nn\n", 21 | "from library_function import library_1D_new\n", 22 | "from neural_net import LinNetwork\n", 23 | "from DeepMod import DeepMod\n", 24 | "import matplotlib.pyplot as plt\n", 25 | "plt.style.use('seaborn-notebook')\n", 26 | "import torch.nn as nn\n", 27 | "from torch.autograd import grad" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | 
"source": [ 34 | "# Preparing data" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 82, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "np.random.seed(34) \n", 44 | "number_of_samples = 1000\n", 45 | "\n", 46 | "data = np.load('data/burgers.npy', allow_pickle=True).item()\n", 47 | "\n", 48 | "X = np.transpose((data['x'].flatten(), data['t'].flatten()))\n", 49 | "y = np.real(data['u']).reshape((data['u'].size, 1))\n", 50 | "\n", 51 | "idx = np.random.permutation(y.size)\n", 52 | "X_train = torch.tensor(X[idx, :][:number_of_samples], dtype=torch.float32, requires_grad=True)\n", 53 | "y_train = torch.tensor(y[idx, :][:number_of_samples], dtype=torch.float32)" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "# Building network" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 83, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "network = LinNetwork(input_dim=2, hidden_dim=20, layers=5, output_dim=1)" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 88, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "weight_vector = torch.ones(6, 1, dtype=torch.float32, requires_grad=True)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 89, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "optimizer = torch.optim.Adam([{'params':network.parameters()}, {'params': weight_vector}])" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 90, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "optimizer.zero_grad() \n", 97 | "prediction = network(X_train)\n", 98 | "y = prediction" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 97, 104 | "metadata": {}, 105 | "outputs": [ 106 | { 107 | "name": "stdout", 108 | "output_type": "stream", 109 | "text": [ 110 | "0 tensor(3.7311e-06, grad_fn=) tensor(-0.1818, grad_fn=)\n", 111 | "100 
tensor(3.6724e-06, grad_fn=) tensor(-7.3358e-05, grad_fn=)\n", 112 | "200 tensor(3.6312e-06, grad_fn=) tensor(-4.9606e-05, grad_fn=)\n", 113 | "300 tensor(3.8892e-06, grad_fn=) tensor(9.5919e-05, grad_fn=)\n", 114 | "400 tensor(3.5963e-06, grad_fn=) tensor(-0.0001, grad_fn=)\n", 115 | "500 tensor(5.3377e-06, grad_fn=) tensor(-4.4078e-05, grad_fn=)\n", 116 | "600 tensor(5.0081e-06, grad_fn=) tensor(7.3165e-05, grad_fn=)\n", 117 | "700 tensor(3.8246e-06, grad_fn=) tensor(-0.0001, grad_fn=)\n", 118 | "800 tensor(3.4980e-06, grad_fn=) tensor(1.9535e-05, grad_fn=)\n", 119 | "900 tensor(3.4562e-06, grad_fn=) tensor(-5.7116e-05, grad_fn=)\n", 120 | "1000 tensor(3.4937e-06, grad_fn=) tensor(6.1989e-05, grad_fn=)\n", 121 | "1100 tensor(3.4262e-06, grad_fn=) tensor(-8.2478e-05, grad_fn=)\n", 122 | "1200 tensor(3.3935e-06, grad_fn=) tensor(-4.1530e-05, grad_fn=)\n", 123 | "1300 tensor(3.4495e-06, grad_fn=) tensor(6.9141e-05, grad_fn=)\n", 124 | "1400 tensor(3.3593e-06, grad_fn=) tensor(-6.8262e-05, grad_fn=)\n", 125 | "1500 tensor(3.3252e-06, grad_fn=) tensor(-3.8326e-05, grad_fn=)\n", 126 | "1600 tensor(3.3909e-06, grad_fn=) tensor(4.7579e-05, grad_fn=)\n", 127 | "1700 tensor(3.2960e-06, grad_fn=) tensor(-6.7174e-05, grad_fn=)\n", 128 | "1800 tensor(2.3697e-05, grad_fn=) tensor(0.0001, grad_fn=)\n", 129 | "1900 tensor(3.2710e-06, grad_fn=) tensor(-0.0001, grad_fn=)\n", 130 | "2000 tensor(3.2350e-06, grad_fn=) tensor(-4.8399e-05, grad_fn=)\n", 131 | "2100 tensor(9.2293e-06, grad_fn=) tensor(0.0002, grad_fn=)\n", 132 | "2200 tensor(3.2151e-06, grad_fn=) tensor(-0.0002, grad_fn=)\n", 133 | "2300 tensor(3.1807e-06, grad_fn=) tensor(-4.4554e-05, grad_fn=)\n", 134 | "2400 tensor(1.9931e-05, grad_fn=) tensor(3.9637e-05, grad_fn=)\n", 135 | "2500 tensor(3.1744e-06, grad_fn=) tensor(-1.3441e-05, grad_fn=)\n", 136 | "2600 tensor(3.1370e-06, grad_fn=) tensor(-4.9308e-05, grad_fn=)\n", 137 | "2700 tensor(3.1110e-06, grad_fn=) tensor(-2.2903e-05, grad_fn=)\n", 138 | "2800 
tensor(3.1280e-06, grad_fn=) tensor(7.5161e-05, grad_fn=)\n", 139 | "2900 tensor(3.0853e-06, grad_fn=) tensor(-6.2227e-05, grad_fn=)\n", 140 | "3000 tensor(3.0578e-06, grad_fn=) tensor(-2.6256e-05, grad_fn=)\n", 141 | "3100 tensor(3.1008e-06, grad_fn=) tensor(5.0247e-05, grad_fn=)\n", 142 | "3200 tensor(3.0387e-06, grad_fn=) tensor(-5.5552e-05, grad_fn=)\n", 143 | "3300 tensor(1.9379e-05, grad_fn=) tensor(2.4483e-05, grad_fn=)\n", 144 | "3400 tensor(3.0176e-06, grad_fn=) tensor(6.4075e-07, grad_fn=)\n", 145 | "3500 tensor(2.9882e-06, grad_fn=) tensor(-4.0457e-05, grad_fn=)\n", 146 | "3600 tensor(4.8089e-06, grad_fn=) tensor(9.1687e-05, grad_fn=)\n", 147 | "3700 tensor(2.9691e-06, grad_fn=) tensor(-7.2911e-05, grad_fn=)\n", 148 | "3800 tensor(2.9419e-06, grad_fn=) tensor(-3.0220e-05, grad_fn=)\n", 149 | "3900 tensor(3.9835e-06, grad_fn=) tensor(3.8549e-05, grad_fn=)\n", 150 | "4000 tensor(2.9320e-06, grad_fn=) tensor(-2.4557e-05, grad_fn=)\n", 151 | "4100 tensor(2.9056e-06, grad_fn=) tensor(-2.4766e-05, grad_fn=)\n", 152 | "4200 tensor(4.1978e-06, grad_fn=) tensor(9.8139e-05, grad_fn=)\n", 153 | "4300 tensor(2.8869e-06, grad_fn=) tensor(-6.6012e-05, grad_fn=)\n", 154 | "4400 tensor(2.8615e-06, grad_fn=) tensor(-2.8625e-05, grad_fn=)\n", 155 | "4500 tensor(3.3687e-06, grad_fn=) tensor(2.2247e-05, grad_fn=)\n", 156 | "4600 tensor(2.8442e-06, grad_fn=) tensor(-1.1802e-05, grad_fn=)\n", 157 | "4700 tensor(2.8304e-06, grad_fn=) tensor(-1.7643e-05, grad_fn=)\n", 158 | "4800 tensor(2.8527e-06, grad_fn=) tensor(8.3357e-05, grad_fn=)\n", 159 | "4900 tensor(2.8050e-06, grad_fn=) tensor(-5.7667e-05, grad_fn=)\n", 160 | "5000 tensor(2.7875e-06, grad_fn=) tensor(-2.6137e-05, grad_fn=)\n", 161 | "5100 tensor(2.7982e-06, grad_fn=) tensor(2.2650e-05, grad_fn=)\n", 162 | "5200 tensor(3.0202e-06, grad_fn=) tensor(-3.9071e-05, grad_fn=)\n", 163 | "5300 tensor(2.7841e-06, grad_fn=) tensor(6.8218e-05, grad_fn=)\n", 164 | "5400 tensor(2.7457e-06, grad_fn=) tensor(-7.2166e-05, 
grad_fn=)\n", 165 | "5500 tensor(1.0309e-05, grad_fn=) tensor(-2.2650e-06, grad_fn=)\n", 166 | "5600 tensor(2.7385e-06, grad_fn=) tensor(2.8595e-05, grad_fn=)\n", 167 | "5700 tensor(2.7120e-06, grad_fn=) tensor(-4.8965e-05, grad_fn=)\n", 168 | "5800 tensor(5.8005e-06, grad_fn=) tensor(-0.0002, grad_fn=)\n", 169 | "5900 tensor(2.7059e-06, grad_fn=) tensor(0.0002, grad_fn=)\n", 170 | "6000 tensor(2.6830e-06, grad_fn=) tensor(-4.0665e-05, grad_fn=)\n", 171 | "6100 tensor(4.3549e-06, grad_fn=) tensor(-9.0778e-05, grad_fn=)\n", 172 | "6200 tensor(2.6626e-06, grad_fn=) tensor(8.7872e-05, grad_fn=)\n", 173 | "6300 tensor(2.7573e-06, grad_fn=) tensor(6.4954e-05, grad_fn=)\n", 174 | "6400 tensor(2.6473e-06, grad_fn=) tensor(-6.1795e-05, grad_fn=)\n", 175 | "6500 tensor(1.1358e-05, grad_fn=) tensor(-0.0003, grad_fn=)\n", 176 | "6600 tensor(2.6398e-06, grad_fn=) tensor(0.0003, grad_fn=)\n", 177 | "6700 tensor(2.6210e-06, grad_fn=) tensor(-2.6017e-05, grad_fn=)\n", 178 | "6800 tensor(3.0068e-06, grad_fn=) tensor(6.2495e-05, grad_fn=)\n", 179 | "6900 tensor(2.6101e-06, grad_fn=) tensor(-6.8158e-05, grad_fn=)\n", 180 | "7000 tensor(4.2694e-06, grad_fn=) tensor(-0.0001, grad_fn=)\n", 181 | "7100 tensor(2.8782e-06, grad_fn=) tensor(0.0002, grad_fn=)\n", 182 | "7200 tensor(2.6215e-06, grad_fn=) tensor(-1.6689e-05, grad_fn=)\n", 183 | "7300 tensor(2.5803e-06, grad_fn=) tensor(-5.3614e-05, grad_fn=)\n", 184 | "7400 tensor(2.6122e-06, grad_fn=) tensor(3.4973e-05, grad_fn=)\n", 185 | "7500 tensor(2.5636e-06, grad_fn=) tensor(-6.4358e-05, grad_fn=)\n", 186 | "7600 tensor(2.5875e-06, grad_fn=) tensor(4.0919e-05, grad_fn=)\n", 187 | "7700 tensor(2.5568e-06, grad_fn=) tensor(-5.8919e-05, grad_fn=)\n", 188 | "7800 tensor(5.3939e-06, grad_fn=) tensor(-0.0002, grad_fn=)\n", 189 | "7900 tensor(2.5471e-06, grad_fn=) tensor(0.0002, grad_fn=)\n", 190 | "8000 tensor(2.5332e-06, grad_fn=) tensor(-3.0413e-05, grad_fn=)\n", 191 | "8100 tensor(2.5644e-06, grad_fn=) tensor(6.4805e-05, grad_fn=)\n", 192 
| "8200 tensor(2.5314e-06, grad_fn=) tensor(-7.6413e-05, grad_fn=)\n", 193 | "8300 tensor(2.5180e-06, grad_fn=) tensor(-1.8880e-05, grad_fn=)\n", 194 | "8400 tensor(2.5951e-06, grad_fn=) tensor(9.0465e-05, grad_fn=)\n", 195 | "8500 tensor(2.5153e-06, grad_fn=) tensor(-7.1764e-05, grad_fn=)\n", 196 | "8600 tensor(2.5032e-06, grad_fn=) tensor(-2.1592e-05, grad_fn=)\n", 197 | "8700 tensor(3.1537e-06, grad_fn=) tensor(0.0002, grad_fn=)\n", 198 | "8800 tensor(2.5065e-06, grad_fn=) tensor(-0.0002, grad_fn=)\n", 199 | "8900 tensor(2.4933e-06, grad_fn=) tensor(-2.3946e-05, grad_fn=)\n", 200 | "9000 tensor(5.0203e-06, grad_fn=) tensor(-1.9088e-05, grad_fn=)\n", 201 | "9100 tensor(2.5003e-06, grad_fn=) tensor(3.3617e-05, grad_fn=)\n", 202 | "9200 tensor(2.4816e-06, grad_fn=) tensor(-3.6657e-05, grad_fn=)\n", 203 | "9300 tensor(0.0001, grad_fn=) tensor(-4.0799e-05, grad_fn=)\n", 204 | "9400 tensor(2.4813e-06, grad_fn=) tensor(7.7516e-05, grad_fn=)\n", 205 | "9500 tensor(2.4662e-06, grad_fn=) tensor(-4.0948e-05, grad_fn=)\n", 206 | "9600 tensor(1.8595e-05, grad_fn=) tensor(-0.0002, grad_fn=)\n", 207 | "9700 tensor(2.4644e-06, grad_fn=) tensor(0.0002, grad_fn=)\n", 208 | "9800 tensor(2.4534e-06, grad_fn=) tensor(-1.6689e-05, grad_fn=)\n", 209 | "9900 tensor(2.4885e-06, grad_fn=) tensor(7.0810e-05, grad_fn=)\n" 210 | ] 211 | } 212 | ], 213 | "source": [ 214 | "l1=10**(-5)\n", 215 | "loss_L1_pre = l1*loss_L1(weight_vector,torch.zeros_like(weight_vector))\n", 216 | "for iteration in np.arange(10000):\n", 217 | " optimizer.zero_grad() \n", 218 | " prediction = network(X_train)\n", 219 | " y = prediction\n", 220 | " y_t, theta = library_1D_new(X_train, y,library_config={'poly_order':1,'diff_order':2})\n", 221 | " f = y_t - theta @ weight_vector\n", 222 | " loss_PI = torch.nn.MSELoss()\n", 223 | " loss_MSE = torch.nn.MSELoss()\n", 224 | " loss_L1 = nn.L1Loss()\n", 225 | " loss = loss_MSE(prediction, y_train) + loss_PI(f, torch.zeros_like(f)) + 
l1*loss_L1(weight_vector,torch.zeros_like(weight_vector))\n", 226 | " \n", 227 | " loss.backward()\n", 228 | " optimizer.step()\n", 229 | " \n", 230 | " if iteration % 100 == 0:\n", 231 | " print(iteration, loss, loss_L1_pre - loss_L1(weight_vector,torch.zeros_like(weight_vector)))\n", 232 | " loss_L1_pre = loss_L1(weight_vector,torch.zeros_like(weight_vector))" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": 98, 238 | "metadata": {}, 239 | "outputs": [], 240 | "source": [ 241 | "scaled_time = torch.norm(y_t).detach().numpy()\n", 242 | "scaled_theta = torch.norm(theta,dim=0).detach().numpy()" 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": 99, 248 | "metadata": {}, 249 | "outputs": [], 250 | "source": [ 251 | "scaled_weight_vector = np.squeeze(weight_vector.detach().numpy())*scaled_theta" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": 101, 257 | "metadata": {}, 258 | "outputs": [ 259 | { 260 | "data": { 261 | "text/plain": [ 262 | "tensor([[ 4.5746e-05],\n", 263 | " [-2.1027e-03],\n", 264 | " [ 9.9532e-02],\n", 265 | " [-5.5832e-04],\n", 266 | " [-9.9276e-01],\n", 267 | " [-1.1922e-03]], requires_grad=True)" 268 | ] 269 | }, 270 | "execution_count": 101, 271 | "metadata": {}, 272 | "output_type": "execute_result" 273 | } 274 | ], 275 | "source": [ 276 | "weight_vector" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": 4, 282 | "metadata": {}, 283 | "outputs": [ 284 | { 285 | "name": "stdout", 286 | "output_type": "stream", 287 | "text": [ 288 | "Sequential(\n", 289 | " (0): Linear(in_features=2, out_features=10, bias=True)\n", 290 | " (1): Tanh()\n", 291 | " (2): Linear(in_features=10, out_features=10, bias=True)\n", 292 | " (3): Tanh()\n", 293 | " (4): Linear(in_features=10, out_features=10, bias=True)\n", 294 | " (5): Tanh()\n", 295 | " (6): Linear(in_features=10, out_features=10, bias=True)\n", 296 | " (7): Tanh()\n", 297 | " (8): 
Linear(in_features=10, out_features=10, bias=True)\n", 298 | " (9): Tanh()\n", 299 | " (10): Linear(in_features=10, out_features=1, bias=True)\n", 300 | ")\n" 301 | ] 302 | } 303 | ], 304 | "source": [ 305 | "DeepMod(X_train, y_train, config={'input_dim':2, 'hidden_dim':10, 'layers':5, 'output_dim':1})" 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "execution_count": 4, 311 | "metadata": {}, 312 | "outputs": [], 313 | "source": [ 314 | "X = np.arange(0,1000,1).reshape(-1,1)\n", 315 | "y = X*X.reshape(-1,1)\n", 316 | "idx = np.random.permutation(y.size)" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": 5, 322 | "metadata": {}, 323 | "outputs": [], 324 | "source": [ 325 | "X_train = torch.tensor(X[:, :][:number_of_samples], dtype=torch.float32, requires_grad=True)\n", 326 | "y_train = torch.tensor(y[:, :][:number_of_samples], dtype=torch.float32)" 327 | ] 328 | }, 329 | { 330 | "cell_type": "code", 331 | "execution_count": 6, 332 | "metadata": {}, 333 | "outputs": [], 334 | "source": [ 335 | "network = LinNetwork(input_dim=1, hidden_dim=10, layers=5, output_dim=1)" 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": 7, 341 | "metadata": {}, 342 | "outputs": [], 343 | "source": [ 344 | "prediction = network(X_train)\n", 345 | "y = prediction" 346 | ] 347 | }, 348 | { 349 | "cell_type": "code", 350 | "execution_count": 8, 351 | "metadata": {}, 352 | "outputs": [ 353 | { 354 | "name": "stdout", 355 | "output_type": "stream", 356 | "text": [ 357 | "tensor([[1.0000, 0.1584, 0.0251],\n", 358 | " [1.0000, 0.1471, 0.0217],\n", 359 | " [1.0000, 0.1337, 0.0179],\n", 360 | " [1.0000, 0.1316, 0.0173],\n", 361 | " [1.0000, 0.1339, 0.0179],\n", 362 | " [1.0000, 0.1372, 0.0188],\n", 363 | " [1.0000, 0.1406, 0.0198],\n", 364 | " [1.0000, 0.1437, 0.0207],\n", 365 | " [1.0000, 0.1464, 0.0214],\n", 366 | " [1.0000, 0.1487, 0.0221]], grad_fn=)\n", 367 | "tensor([[ 1.0000e+00, -3.4018e-03, -1.2140e-02],\n", 368 | " [ 
1.0000e+00, -1.7684e-02, -2.4011e-03],\n", 369 | " [ 1.0000e+00, -6.8308e-03, 1.3013e-02],\n", 370 | " [ 1.0000e+00, 1.0594e-03, 3.8364e-03],\n", 371 | " [ 1.0000e+00, 3.0818e-03, 8.6494e-04],\n", 372 | " [ 1.0000e+00, 3.4536e-03, 2.6085e-05],\n", 373 | " [ 1.0000e+00, 3.2918e-03, -3.0473e-04],\n", 374 | " [ 1.0000e+00, 2.9105e-03, -4.3198e-04],\n", 375 | " [ 1.0000e+00, 2.4694e-03, -4.3464e-04],\n", 376 | " [ 1.0000e+00, 2.0625e-03, -3.7298e-04]], grad_fn=)\n" 377 | ] 378 | }, 379 | { 380 | "data": { 381 | "text/plain": [ 382 | "(tensor([], size=(10, 0), grad_fn=),\n", 383 | " tensor([[ 1.0000e+00, -3.4018e-03, -1.2140e-02, 1.0000e+00, -3.4018e-03,\n", 384 | " -1.2140e-02, 1.5843e-01, -5.3895e-04, -1.9233e-03],\n", 385 | " [ 1.0000e+00, -1.7684e-02, -2.4011e-03, 1.0000e+00, -1.7684e-02,\n", 386 | " -2.4011e-03, 1.4715e-01, -2.6021e-03, -3.5332e-04],\n", 387 | " [ 1.0000e+00, -6.8308e-03, 1.3013e-02, 1.0000e+00, -6.8308e-03,\n", 388 | " 1.3013e-02, 1.3367e-01, -9.1310e-04, 1.7395e-03],\n", 389 | " [ 1.0000e+00, 1.0594e-03, 3.8364e-03, 1.0000e+00, 1.0594e-03,\n", 390 | " 3.8364e-03, 1.3157e-01, 1.3938e-04, 5.0476e-04],\n", 391 | " [ 1.0000e+00, 3.0818e-03, 8.6494e-04, 1.0000e+00, 3.0818e-03,\n", 392 | " 8.6494e-04, 1.3388e-01, 4.1260e-04, 1.1580e-04],\n", 393 | " [ 1.0000e+00, 3.4536e-03, 2.6085e-05, 1.0000e+00, 3.4536e-03,\n", 394 | " 2.6085e-05, 1.3722e-01, 4.7390e-04, 3.5793e-06],\n", 395 | " [ 1.0000e+00, 3.2918e-03, -3.0473e-04, 1.0000e+00, 3.2918e-03,\n", 396 | " -3.0473e-04, 1.4062e-01, 4.6289e-04, -4.2850e-05],\n", 397 | " [ 1.0000e+00, 2.9105e-03, -4.3198e-04, 1.0000e+00, 2.9105e-03,\n", 398 | " -4.3198e-04, 1.4373e-01, 4.1832e-04, -6.2089e-05],\n", 399 | " [ 1.0000e+00, 2.4694e-03, -4.3464e-04, 1.0000e+00, 2.4694e-03,\n", 400 | " -4.3464e-04, 1.4642e-01, 3.6156e-04, -6.3640e-05],\n", 401 | " [ 1.0000e+00, 2.0625e-03, -3.7298e-04, 1.0000e+00, 2.0625e-03,\n", 402 | " -3.7298e-04, 1.4868e-01, 3.0665e-04, -5.5455e-05]],\n", 403 | " grad_fn=))" 404 | ] 405 | }, 
406 | "execution_count": 8, 407 | "metadata": {}, 408 | "output_type": "execute_result" 409 | } 410 | ], 411 | "source": [ 412 | "library_1D_new(X_train, y,library_config={'poly_order':2,'diff_order':2})" 413 | ] 414 | }, 415 | { 416 | "cell_type": "code", 417 | "execution_count": null, 418 | "metadata": {}, 419 | "outputs": [], 420 | "source": [] 421 | } 422 | ], 423 | "metadata": { 424 | "kernelspec": { 425 | "display_name": "Python 3", 426 | "language": "python", 427 | "name": "python3" 428 | }, 429 | "language_info": { 430 | "codemirror_mode": { 431 | "name": "ipython", 432 | "version": 3 433 | }, 434 | "file_extension": ".py", 435 | "mimetype": "text/x-python", 436 | "name": "python", 437 | "nbconvert_exporter": "python", 438 | "pygments_lexer": "ipython3", 439 | "version": "3.6.8" 440 | } 441 | }, 442 | "nbformat": 4, 443 | "nbformat_minor": 2 444 | } 445 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/PyTorch_V2-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "sys.path.append('src/')\n", 11 | "import numpy as np\n", 12 | "import torch, torch.nn\n", 13 | "from library_function import library_1D\n", 14 | "from neural_net import LinNetwork\n", 15 | "from DeepMod import DeepMod\n", 16 | "import matplotlib.pyplot as plt\n", 17 | "plt.style.use('seaborn-notebook')\n", 18 | "import torch.nn as nn\n", 19 | "from torch.autograd import grad" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": {}, 25 | "source": [ 26 | "# Preparing data" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 2, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "np.random.seed(34) \n", 36 | "number_of_samples = 1000\n", 37 | "\n", 38 | "data = np.load('data/burgers.npy', 
allow_pickle=True).item()\n", 39 | "\n", 40 | "X = np.transpose((data['x'].flatten(), data['t'].flatten()))\n", 41 | "y = np.real(data['u']).reshape((data['u'].size, 1))\n", 42 | "\n", 43 | "idx = np.random.permutation(y.size)\n", 44 | "X_train = torch.tensor(X[idx, :][:number_of_samples], dtype=torch.float32, requires_grad=True)\n", 45 | "y_train = torch.tensor(y[idx, :][:number_of_samples], dtype=torch.float32)" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "metadata": {}, 51 | "source": [ 52 | "# Building network" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 3, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "optim_config ={'lambda':1e-5,'max_iteration':5000}\n", 62 | "lib_config={'poly_order':2, 'diff_order':2, 'total_terms':9}\n", 63 | "network_config={'input_dim':2, 'hidden_dim':10, 'layers':5, 'output_dim':1}" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 4, 69 | "metadata": {}, 70 | "outputs": [ 71 | { 72 | "name": "stdout", 73 | "output_type": "stream", 74 | "text": [ 75 | "Epoch | Total loss | MSE | PI | L1 \n", 76 | "0 0.97975564 0.060335446 0.91941017 1e-05\n", 77 | "[0.999 1.0009997 0.9990019 1.001 0.9990086 1.0009941\n", 78 | " 0.9990001 0.9995901 0.99904853]\n" 79 | ] 80 | }, 81 | { 82 | "ename": "KeyboardInterrupt", 83 | "evalue": "", 84 | "output_type": "error", 85 | "traceback": [ 86 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 87 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 88 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mweight_vector\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mscaled_weight_vector\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msparsity\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mDeepMod\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_train\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0my_train\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mnetwork_config\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlib_config\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptim_config\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 89 | "\u001b[0;32m~/Documents/GitHub/DeepMoD_Torch/src/DeepMod.py\u001b[0m in \u001b[0;36mDeepMod\u001b[0;34m(data, target, network_config, library_config, optim_config)\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[0;31m# Print the output\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 36\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0miteration\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 37\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Epoch | Total loss | MSE | PI | L1 '\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 38\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0miteration\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;36m500\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 90 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 91 | ] 92 | } 93 | ], 94 | "source": [ 95 | "weight_vector, scaled_weight_vector, sparsity = DeepMod(X_train, y_train,network_config, lib_config, optim_config)" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 25, 101 | "metadata": {}, 102 | "outputs": [ 103 | { 104 | "data": { 105 | "text/plain": [ 106 | "array(0.09383126, dtype=float32)" 107 | ] 108 | }, 109 | "execution_count": 25, 110 | "metadata": {}, 111 | "output_type": "execute_result" 112 | } 113 | ], 114 | "source": [ 115 | "torch.var(test_tensor).detach().numpy()" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 28, 121 
| "metadata": {}, 122 | "outputs": [ 123 | { 124 | "data": { 125 | "text/plain": [ 126 | "tensor([[0.0000, 0.9044, 0.1192, 0.6631, 0.9442],\n", 127 | " [0.4493, 0.6197, 0.7321, 0.7207, 0.8996],\n", 128 | " [0.1552, 0.6388, 0.4838, 0.6342, 0.0000],\n", 129 | " [0.3081, 0.2653, 0.5313, 0.8170, 0.2314],\n", 130 | " [0.9311, 0.9999, 0.9976, 0.4489, 0.9990],\n", 131 | " [0.9396, 0.1914, 0.0000, 0.8966, 0.7329],\n", 132 | " [0.1867, 0.6454, 0.8032, 0.2288, 0.1029],\n", 133 | " [0.8257, 0.8126, 0.9281, 0.3943, 0.8439],\n", 134 | " [0.5643, 0.9060, 0.9728, 0.5139, 0.5190],\n", 135 | " [0.7834, 0.0000, 0.3767, 0.5115, 0.7732]])" 136 | ] 137 | }, 138 | "execution_count": 28, 139 | "metadata": {}, 140 | "output_type": "execute_result" 141 | } 142 | ], 143 | "source": [ 144 | "torch.where(torch.abs(test_tensor) > torch.var(test_tensor), test_tensor, torch.zeros_like(test_tensor))" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 8, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "test_tensor = torch.rand((10, 5))" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": 5, 159 | "metadata": {}, 160 | "outputs": [], 161 | "source": [ 162 | "test_tensor = torch.rand(1)" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 10, 168 | "metadata": {}, 169 | "outputs": [ 170 | { 171 | "name": "stdout", 172 | "output_type": "stream", 173 | "text": [ 174 | "[[0.71023726 0.1892488 0.6319442 0.19041932 0.9054366 ]\n", 175 | " [0.11108172 0.55048764 0.157606 0.01528466 0.08591795]\n", 176 | " [0.4742189 0.4514264 0.09886968 0.8400874 0.25091565]\n", 177 | " [0.5689354 0.8502427 0.6938656 0.5038133 0.9952059 ]\n", 178 | " [0.05185252 0.28007346 0.7712979 0.7902976 0.00522673]\n", 179 | " [0.88777936 0.02592885 0.92315173 0.5481899 0.3546307 ]\n", 180 | " [0.39075083 0.708404 0.5060967 0.87950134 0.7201277 ]\n", 181 | " [0.1006586 0.10669661 0.04020923 0.7579842 0.44361705]\n", 182 | " [0.8464011 
0.3912769 0.75543344 0.7260391 0.65598917]\n", 183 | " [0.45569474 0.35944712 0.81301725 0.9455671 0.40103966]]\n" 184 | ] 185 | } 186 | ], 187 | "source": [ 188 | "print(test_tensor.numpy())" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": 17, 194 | "metadata": {}, 195 | "outputs": [ 196 | { 197 | "ename": "TypeError", 198 | "evalue": "to() received an invalid combination of arguments - got (requires_grad=bool, dtype=torch.dtype, ), but expected one of:\n * (torch.device device, torch.dtype dtype, bool non_blocking, bool copy)\n * (torch.dtype dtype, bool non_blocking, bool copy)\n * (Tensor tensor, bool non_blocking, bool copy)\n", 199 | "output_type": "error", 200 | "traceback": [ 201 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 202 | "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", 203 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtest_tensor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat32\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrequires_grad\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 204 | "\u001b[0;31mTypeError\u001b[0m: to() received an invalid combination of arguments - got (requires_grad=bool, dtype=torch.dtype, ), but expected one of:\n * (torch.device device, torch.dtype dtype, bool non_blocking, bool copy)\n * (torch.dtype dtype, bool non_blocking, bool copy)\n * (Tensor tensor, bool non_blocking, bool copy)\n" 205 | ] 206 | } 207 | ], 208 | "source": [ 209 | "test_tensor.to(dtype=torch.float32, requires_grad=True)" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 9, 215 | "metadata": {}, 216 | "outputs": [ 217 | { 218 | "data": { 219 | 
"text/plain": [ 220 | "tensor([0.6499, 0.5928, 0.2442, 0.3834, 0.7453, 0.9785, 0.6684, 0.5740, 0.6104,\n", 221 | " 0.9307])" 222 | ] 223 | }, 224 | "execution_count": 9, 225 | "metadata": {}, 226 | "output_type": "execute_result" 227 | } 228 | ], 229 | "source": [ 230 | "np.squeeze(torch.rand((10, 1)).detach())" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": 21, 236 | "metadata": {}, 237 | "outputs": [ 238 | { 239 | "data": { 240 | "text/plain": [ 241 | "tensor([[0.3600, 0.6430, 0.9685, 0.4562, 0.5089],\n", 242 | " [0.8153, 0.4840, 0.3934, 0.6650, 0.7970],\n", 243 | " [0.7338, 0.5660, 0.4046, 0.8905, 0.0250],\n", 244 | " [0.7464, 0.3303, 0.8371, 0.6098, 0.0227],\n", 245 | " [0.4997, 0.6406, 0.3026, 0.7290, 0.3725],\n", 246 | " [0.5349, 0.9146, 0.1108, 0.4827, 0.9528],\n", 247 | " [0.0816, 0.8703, 0.3789, 0.6827, 0.8305],\n", 248 | " [0.4820, 0.6484, 0.8214, 0.2271, 0.6757],\n", 249 | " [0.2052, 0.4192, 0.0632, 0.6259, 0.9335],\n", 250 | " [0.0515, 0.8854, 0.7249, 0.0454, 0.2680]], requires_grad=True)" 251 | ] 252 | }, 253 | "execution_count": 21, 254 | "metadata": {}, 255 | "output_type": "execute_result" 256 | } 257 | ], 258 | "source": [ 259 | "torch.tensor(torch.rand((10, 5)).numpy(), requires_grad=True)" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": null, 265 | "metadata": {}, 266 | "outputs": [], 267 | "source": [] 268 | } 269 | ], 270 | "metadata": { 271 | "kernelspec": { 272 | "display_name": "Python 3", 273 | "language": "python", 274 | "name": "python3" 275 | }, 276 | "language_info": { 277 | "codemirror_mode": { 278 | "name": "ipython", 279 | "version": 3 280 | }, 281 | "file_extension": ".py", 282 | "mimetype": "text/x-python", 283 | "name": "python", 284 | "nbconvert_exporter": "python", 285 | "pygments_lexer": "ipython3", 286 | "version": "3.6.8" 287 | } 288 | }, 289 | "nbformat": 4, 290 | "nbformat_minor": 2 291 | } 292 | 
-------------------------------------------------------------------------------- /.ipynb_checkpoints/PyTorch_V3-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "sys.path.append('src/')\n", 11 | "import numpy as np\n", 12 | "import torch, torch.nn\n", 13 | "from library_function import library_1D\n", 14 | "from neural_net import LinNetwork\n", 15 | "from DeepMod import DeepMod\n", 16 | "import matplotlib.pyplot as plt\n", 17 | "plt.style.use('seaborn-notebook')\n", 18 | "import torch.nn as nn\n", 19 | "from torch.autograd import grad" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": {}, 25 | "source": [ 26 | "# Preparing data" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 2, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "np.random.seed(34) \n", 36 | "number_of_samples = 800\n", 37 | "\n", 38 | "data = np.load('data/burgers.npy', allow_pickle=True).item()\n", 39 | "\n", 40 | "X = np.transpose((data['x'].flatten(), data['t'].flatten()))\n", 41 | "y = np.real(data['u']).reshape((data['u'].size, 1))\n", 42 | "\n", 43 | "idx = np.random.permutation(y.size)\n", 44 | "X_train = torch.tensor(X[idx, :][:number_of_samples], dtype=torch.float32, requires_grad=True)\n", 45 | "y_train = torch.tensor(y[idx, :][:number_of_samples], dtype=torch.float32)" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "metadata": {}, 51 | "source": [ 52 | "# Building network" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 3, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "optim_config ={'lambda':1e-5,'max_iteration':500}\n", 62 | "lib_config={'poly_order':2, 'diff_order':2, 'total_terms':9}\n", 63 | "network_config={'input_dim':2, 'hidden_dim':10, 'layers':5, 'output_dim':1}" 64 | ] 
65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 4, 69 | "metadata": {}, 70 | "outputs": [ 71 | { 72 | "data": { 73 | "text/plain": [ 74 | "tensor([[1]])" 75 | ] 76 | }, 77 | "execution_count": 4, 78 | "metadata": {}, 79 | "output_type": "execute_result" 80 | } 81 | ], 82 | "source": [ 83 | "torch.tensor(1).reshape((1,-1))" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 5, 89 | "metadata": {}, 90 | "outputs": [ 91 | { 92 | "name": "stdout", 93 | "output_type": "stream", 94 | "text": [ 95 | "Epoch | Total loss | MSE | PI | L1 \n", 96 | "0 0.8294 0.1602 0.6692 0.0000\n", 97 | "[[1. 1. 1. 1. 1. 1. 1. 1. 1.]]\n", 98 | "tensor([0, 3])\n", 99 | "Epoch | Total loss | MSE | PI \n", 100 | "0 0.2386 0.1287 0.1099\n", 101 | "[[0.61 1.37]]\n", 102 | "1000 0.0313 0.0228 0.0085\n", 103 | "[[-0.04 1.53]]\n", 104 | "2000 0.0181 0.0099 0.0082\n", 105 | "[[-0.12 0.94]]\n" 106 | ] 107 | }, 108 | { 109 | "ename": "KeyboardInterrupt", 110 | "evalue": "", 111 | "output_type": "error", 112 | "traceback": [ 113 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 114 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 115 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0msparse_weight_vector\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msparsity_pattern\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mDeepMod\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_train\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mnetwork_config\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlib_config\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptim_config\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 116 | "\u001b[0;32m~/Documents/GitHub/DeepMoD_Torch/src/DeepMod.py\u001b[0m in \u001b[0;36mDeepMod\u001b[0;34m(data, target, network_config, library_config, 
optim_config)\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;31m# Final Training without L1 and with the sparsity pattern\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m \u001b[0msparse_weight_vector\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mFinal_Training\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptim_config\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlibrary_config\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnetwork\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msparse_weight_vector\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msparsity_mask\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 26\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0msparse_weight_vector\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msparsity_mask\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 117 | "\u001b[0;32m~/Documents/GitHub/DeepMoD_Torch/src/neural_net.py\u001b[0m in \u001b[0;36mFinal_Training\u001b[0;34m(data, target, optim_config, library_config, network, sparse_weight_vector, sparsity_mask)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[0;31m# Optimizwe step\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 87\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 88\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 89\u001b[0m 
\u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 118 | "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[1;32m 100\u001b[0m \u001b[0mproducts\u001b[0m\u001b[0;34m.\u001b[0m \u001b[0mDefaults\u001b[0m \u001b[0mto\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 101\u001b[0m \"\"\"\n\u001b[0;32m--> 102\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 103\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 104\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 119 | "\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[1;32m 88\u001b[0m Variable._execution_engine.run_backward(\n\u001b[1;32m 89\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 90\u001b[0;31m allow_unreachable=True) # allow_unreachable flag\n\u001b[0m\u001b[1;32m 91\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 120 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 121 | ] 122 | } 123 | ], 124 | "source": [ 125 | "sparse_weight_vector, sparsity_pattern = DeepMod(X_train, y_train,network_config, lib_config, optim_config)" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "metadata": {}, 132 | "outputs": [], 133 | "source": [ 134 | "scaled_weight_vector" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 9, 140 | "metadata": {}, 141 | "outputs": [ 142 | { 143 | "data": { 144 | "text/plain": [ 145 | "tensor([2, 4])" 146 | ] 147 | }, 148 | "execution_count": 9, 149 | "metadata": {}, 150 | "output_type": "execute_result" 151 | } 152 | ], 153 | "source": [ 154 | "sparsity" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": 6, 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "test_tensor = torch.rand((10, 5))" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 7, 169 | "metadata": {}, 170 | "outputs": [ 171 | { 172 | "data": { 173 | "text/plain": [ 174 | "torch.Size([10, 5])" 175 | ] 176 | }, 177 | "execution_count": 7, 178 | "metadata": {}, 179 | "output_type": "execute_result" 180 | } 181 | ], 182 | "source": [ 183 | "test_tensor.shape" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": 11, 189 | "metadata": {}, 190 | "outputs": [], 191 | "source": [ 192 | "test_tensor = torch.rand((10,1))\n", 193 | "test_tensor_2 = torch.rand((1,10))" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 12, 199 | "metadata": {}, 200 | "outputs": [ 201 | { 202 | "data": { 203 | "text/plain": [ 204 | 
"tensor([[0.1171, 0.2456, 0.5149, 0.6258, 0.2198, 0.7128, 0.5038, 0.2305, 0.3162,\n", 205 | " 0.9743],\n", 206 | " [0.1198, 0.2511, 0.5266, 0.6399, 0.2248, 0.7290, 0.5153, 0.2357, 0.3234,\n", 207 | " 0.9964],\n", 208 | " [0.0810, 0.1698, 0.3560, 0.4326, 0.1519, 0.4927, 0.3483, 0.1593, 0.2186,\n", 209 | " 0.6735],\n", 210 | " [0.0911, 0.1911, 0.4007, 0.4870, 0.1710, 0.5547, 0.3921, 0.1794, 0.2461,\n", 211 | " 0.7582],\n", 212 | " [0.0239, 0.0501, 0.1050, 0.1276, 0.0448, 0.1454, 0.1027, 0.0470, 0.0645,\n", 213 | " 0.1987],\n", 214 | " [0.0109, 0.0228, 0.0477, 0.0580, 0.0204, 0.0661, 0.0467, 0.0214, 0.0293,\n", 215 | " 0.0903],\n", 216 | " [0.0115, 0.0240, 0.0504, 0.0613, 0.0215, 0.0698, 0.0493, 0.0226, 0.0310,\n", 217 | " 0.0954],\n", 218 | " [0.0085, 0.0178, 0.0373, 0.0453, 0.0159, 0.0517, 0.0365, 0.0167, 0.0229,\n", 219 | " 0.0706],\n", 220 | " [0.0602, 0.1262, 0.2647, 0.3216, 0.1130, 0.3664, 0.2590, 0.1185, 0.1625,\n", 221 | " 0.5008],\n", 222 | " [0.0898, 0.1882, 0.3947, 0.4796, 0.1685, 0.5463, 0.3862, 0.1766, 0.2424,\n", 223 | " 0.7468]])" 224 | ] 225 | }, 226 | "execution_count": 12, 227 | "metadata": {}, 228 | "output_type": "execute_result" 229 | } 230 | ], 231 | "source": [ 232 | "test_tensor @ test_tensor_2" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": 17, 238 | "metadata": {}, 239 | "outputs": [ 240 | { 241 | "ename": "TypeError", 242 | "evalue": "to() received an invalid combination of arguments - got (requires_grad=bool, dtype=torch.dtype, ), but expected one of:\n * (torch.device device, torch.dtype dtype, bool non_blocking, bool copy)\n * (torch.dtype dtype, bool non_blocking, bool copy)\n * (Tensor tensor, bool non_blocking, bool copy)\n", 243 | "output_type": "error", 244 | "traceback": [ 245 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 246 | "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", 247 | "\u001b[0;32m\u001b[0m in 
\u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtest_tensor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat32\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrequires_grad\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 248 | "\u001b[0;31mTypeError\u001b[0m: to() received an invalid combination of arguments - got (requires_grad=bool, dtype=torch.dtype, ), but expected one of:\n * (torch.device device, torch.dtype dtype, bool non_blocking, bool copy)\n * (torch.dtype dtype, bool non_blocking, bool copy)\n * (Tensor tensor, bool non_blocking, bool copy)\n" 249 | ] 250 | } 251 | ], 252 | "source": [ 253 | "test_tensor.to(dtype=torch.float32, requires_grad=True)" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": 9, 259 | "metadata": {}, 260 | "outputs": [ 261 | { 262 | "data": { 263 | "text/plain": [ 264 | "tensor([0.6499, 0.5928, 0.2442, 0.3834, 0.7453, 0.9785, 0.6684, 0.5740, 0.6104,\n", 265 | " 0.9307])" 266 | ] 267 | }, 268 | "execution_count": 9, 269 | "metadata": {}, 270 | "output_type": "execute_result" 271 | } 272 | ], 273 | "source": [ 274 | "np.squeeze(torch.rand((10, 1)).detach())" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": 21, 280 | "metadata": {}, 281 | "outputs": [ 282 | { 283 | "data": { 284 | "text/plain": [ 285 | "tensor([[0.3600, 0.6430, 0.9685, 0.4562, 0.5089],\n", 286 | " [0.8153, 0.4840, 0.3934, 0.6650, 0.7970],\n", 287 | " [0.7338, 0.5660, 0.4046, 0.8905, 0.0250],\n", 288 | " [0.7464, 0.3303, 0.8371, 0.6098, 0.0227],\n", 289 | " [0.4997, 0.6406, 0.3026, 0.7290, 0.3725],\n", 290 | " [0.5349, 0.9146, 0.1108, 0.4827, 0.9528],\n", 291 | " [0.0816, 0.8703, 0.3789, 0.6827, 0.8305],\n", 292 | " [0.4820, 0.6484, 0.8214, 0.2271, 0.6757],\n", 293 | " [0.2052, 
0.4192, 0.0632, 0.6259, 0.9335],\n", 294 | " [0.0515, 0.8854, 0.7249, 0.0454, 0.2680]], requires_grad=True)" 295 | ] 296 | }, 297 | "execution_count": 21, 298 | "metadata": {}, 299 | "output_type": "execute_result" 300 | } 301 | ], 302 | "source": [ 303 | "torch.tensor(torch.rand((10, 5)).numpy(), requires_grad=True)" 304 | ] 305 | }, 306 | { 307 | "cell_type": "code", 308 | "execution_count": null, 309 | "metadata": {}, 310 | "outputs": [], 311 | "source": [] 312 | } 313 | ], 314 | "metadata": { 315 | "kernelspec": { 316 | "display_name": "Python 3", 317 | "language": "python", 318 | "name": "python3" 319 | }, 320 | "language_info": { 321 | "codemirror_mode": { 322 | "name": "ipython", 323 | "version": 3 324 | }, 325 | "file_extension": ".py", 326 | "mimetype": "text/x-python", 327 | "name": "python", 328 | "nbconvert_exporter": "python", 329 | "pygments_lexer": "ipython3", 330 | "version": "3.6.8" 331 | } 332 | }, 333 | "nbformat": 4, 334 | "nbformat_minor": 2 335 | } 336 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/PyTorch_V4_multi_traj-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import sys\n", 10 | "sys.path.append('src/')\n", 11 | "import numpy as np\n", 12 | "import torch, torch.nn\n", 13 | "from library_function import library_1D\n", 14 | "from neural_net import LinNetwork\n", 15 | "from DeepMod import DeepMod\n", 16 | "import matplotlib.pyplot as plt\n", 17 | "plt.style.use('seaborn-notebook')\n", 18 | "import torch.nn as nn\n", 19 | "from torch.autograd import grad\n", 20 | "from scipy.io import loadmat\n", 21 | "\n", 22 | "%load_ext autoreload\n", 23 | "%autoreload 2" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "# Preparing data" 31 | ] 32 | }, 33 | { 34 | 
"cell_type": "code", 35 | "execution_count": 2, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "np.random.seed(34) \n", 40 | "number_of_samples = 1000\n", 41 | "\n", 42 | "data = np.load('data/burgers.npy', allow_pickle=True).item()\n", 43 | "\n", 44 | "X = np.transpose((data['x'].flatten(), data['t'].flatten()))\n", 45 | "y = np.real(np.transpose((data['u'].flatten(),data['u'].flatten())))" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 4, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "idx = np.random.permutation(y.shape[0])\n", 55 | "X_train = torch.tensor(X[idx, :][:number_of_samples], dtype=torch.float32, requires_grad=True)\n", 56 | "y_train = torch.tensor(y[idx, :][:number_of_samples], dtype=torch.float32)" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 3, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "rawdata = loadmat('data/kinetics_new.mat')\n", 66 | "raw = np.real(rawdata['Expression1'])\n", 67 | "raw= raw.reshape((1901,3))\n", 68 | "t = raw[:-1,0].reshape(-1,1)\n", 69 | "X1= raw[:-1,1]\n", 70 | "X2 = raw[:-1,2]" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 4, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "X = np.float32(t.reshape(-1,1))" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 5, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "y= np.vstack((X1,X2))\n", 89 | "y = np.transpose(y)" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 6, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "number_of_samples = 1000" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 7, 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "idx = np.random.permutation(y.shape[0])" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 8, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 
116 | "X_train = torch.tensor(X[idx, :][:number_of_samples], dtype=torch.float32, requires_grad=True)" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 9, 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "y_train = torch.tensor(y[idx, :][:number_of_samples], dtype=torch.float32)" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 10, 131 | "metadata": {}, 132 | "outputs": [ 133 | { 134 | "data": { 135 | "text/plain": [ 136 | "torch.Size([1000, 2])" 137 | ] 138 | }, 139 | "execution_count": 10, 140 | "metadata": {}, 141 | "output_type": "execute_result" 142 | } 143 | ], 144 | "source": [ 145 | "y_train.shape" 146 | ] 147 | }, 148 | { 149 | "cell_type": "markdown", 150 | "metadata": {}, 151 | "source": [ 152 | "# Building network" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 11, 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "optim_config ={'lambda':1e-6,'max_iteration':50000}\n", 162 | "lib_config={'poly_order':1, 'diff_order':2, 'total_terms':4}\n", 163 | "network_config={'input_dim':1, 'hidden_dim':20, 'layers':5, 'output_dim':2}" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "metadata": {}, 170 | "outputs": [ 171 | { 172 | "name": "stdout", 173 | "output_type": "stream", 174 | "text": [ 175 | "Epoch | Total loss | MSE | PI | L1 \n", 176 | "0 2.0E+00 1.5E+00 4.2E-01 1.0E-06\n", 177 | "[[1. 1.]\n", 178 | " [1. 1.]\n", 179 | " [1. 1.]\n", 180 | " [1. 1.]]\n", 181 | "1000 2.4E-01 2.3E-01 1.0E-02 8.8E-07\n", 182 | "[[0.33 0.46]\n", 183 | " [0.25 0.26]\n", 184 | " [1.49 1.36]\n", 185 | " [1.49 1.39]]\n", 186 | "2000 6.4E-02 6.0E-02 3.3E-03 7.9E-07\n", 187 | "[[-0.22 0. 
]\n", 188 | " [-0.35 -0.3 ]\n", 189 | " [ 1.42 1.29]\n", 190 | " [ 1.42 1.29]]\n", 191 | "3000 8.0E-03 7.5E-03 5.2E-04 7.9E-07\n", 192 | "[[-0.54 -0.25]\n", 193 | " [-0.68 -0.6 ]\n", 194 | " [ 1.1 1.06]\n", 195 | " [ 1.08 1.01]]\n", 196 | "4000 6.6E-03 6.2E-03 3.8E-04 7.6E-07\n", 197 | "[[-0.57 -0.27]\n", 198 | " [-0.67 -0.62]\n", 199 | " [ 0.98 1.01]\n", 200 | " [ 1. 0.95]]\n", 201 | "5000 6.0E-03 5.7E-03 3.6E-04 7.1E-07\n", 202 | "[[-0.55 -0.25]\n", 203 | " [-0.6 -0.6 ]\n", 204 | " [ 0.85 0.98]\n", 205 | " [ 0.93 0.91]]\n", 206 | "6000 5.0E-03 4.7E-03 3.3E-04 6.3E-07\n", 207 | "[[-0.52 -0.22]\n", 208 | " [-0.48 -0.56]\n", 209 | " [ 0.65 0.93]\n", 210 | " [ 0.81 0.86]]\n", 211 | "7000 3.3E-03 3.1E-03 2.5E-04 4.9E-07\n", 212 | "[[-0.46 -0.17]\n", 213 | " [-0.3 -0.49]\n", 214 | " [ 0.31 0.83]\n", 215 | " [ 0.6 0.76]]\n", 216 | "8000 1.2E-03 1.1E-03 1.0E-04 3.2E-07\n", 217 | "[[-0.36 -0.1 ]\n", 218 | " [-0.05 -0.38]\n", 219 | " [-0.12 0.67]\n", 220 | " [ 0.27 0.62]]\n", 221 | "9000 1.5E-04 1.4E-04 1.5E-05 2.9E-07\n", 222 | "[[-0.27 -0.05]\n", 223 | " [ 0.16 -0.3 ]\n", 224 | " [-0.47 0.53]\n", 225 | " [-0.02 0.5 ]]\n", 226 | "10000 2.4E-05 1.4E-05 9.5E-06 3.1E-07\n", 227 | "[[-0.22 -0.03]\n", 228 | " [ 0.24 -0.28]\n", 229 | " [-0.62 0.5 ]\n", 230 | " [-0.15 0.47]]\n", 231 | "11000 8.1E-05 3.1E-05 4.9E-05 3.2E-07\n", 232 | "[[-0.21 -0.03]\n", 233 | " [ 0.26 -0.27]\n", 234 | " [-0.65 0.5 ]\n", 235 | " [-0.18 0.46]]\n", 236 | "12000 2.4E-06 1.6E-06 4.6E-07 3.2E-07\n", 237 | "[[-0.21 -0.03]\n", 238 | " [ 0.27 -0.28]\n", 239 | " [-0.66 0.51]\n", 240 | " [-0.18 0.46]]\n", 241 | "13000 2.1E-06 1.6E-06 2.3E-07 3.3E-07\n", 242 | "[[-0.21 -0.03]\n", 243 | " [ 0.27 -0.28]\n", 244 | " [-0.66 0.51]\n", 245 | " [-0.19 0.46]]\n", 246 | "14000 2.4E-05 1.1E-05 1.2E-05 3.3E-07\n", 247 | "[[-0.21 -0.03]\n", 248 | " [ 0.27 -0.28]\n", 249 | " [-0.66 0.51]\n", 250 | " [-0.19 0.45]]\n", 251 | "15000 2.0E-06 1.5E-06 1.2E-07 3.3E-07\n", 252 | "[[-0.21 -0.02]\n", 253 | " [ 0.27 -0.28]\n", 254 
| " [-0.66 0.52]\n", 255 | " [-0.19 0.45]]\n", 256 | "16000 9.2E-06 4.9E-06 4.0E-06 3.2E-07\n", 257 | "[[-0.21 -0.02]\n", 258 | " [ 0.27 -0.28]\n", 259 | " [-0.66 0.52]\n", 260 | " [-0.19 0.45]]\n", 261 | "17000 2.7E-06 2.2E-06 1.5E-07 3.2E-07\n", 262 | "[[-0.21 -0.02]\n", 263 | " [ 0.27 -0.28]\n", 264 | " [-0.66 0.52]\n", 265 | " [-0.19 0.44]]\n", 266 | "18000 1.5E-05 1.3E-05 1.9E-06 3.2E-07\n", 267 | "[[-0.21 -0.02]\n", 268 | " [ 0.27 -0.28]\n", 269 | " [-0.66 0.52]\n", 270 | " [-0.19 0.44]]\n", 271 | "19000 6.6E-06 5.9E-06 3.4E-07 3.2E-07\n", 272 | "[[-0.21 -0.02]\n", 273 | " [ 0.27 -0.28]\n", 274 | " [-0.66 0.53]\n", 275 | " [-0.19 0.44]]\n", 276 | "20000 1.8E-06 1.4E-06 1.2E-07 3.2E-07\n", 277 | "[[-0.21 -0.01]\n", 278 | " [ 0.27 -0.28]\n", 279 | " [-0.66 0.53]\n", 280 | " [-0.19 0.44]]\n", 281 | "21000 2.0E-05 1.9E-05 9.9E-07 3.2E-07\n", 282 | "[[-0.21 -0.01]\n", 283 | " [ 0.27 -0.28]\n", 284 | " [-0.66 0.53]\n", 285 | " [-0.19 0.43]]\n", 286 | "22000 1.9E-06 1.4E-06 1.9E-07 3.2E-07\n", 287 | "[[-0.21 -0.01]\n", 288 | " [ 0.27 -0.29]\n", 289 | " [-0.66 0.53]\n", 290 | " [-0.19 0.43]]\n", 291 | "23000 1.9E-06 1.4E-06 2.0E-07 3.2E-07\n", 292 | "[[-0.21 -0.01]\n", 293 | " [ 0.27 -0.29]\n", 294 | " [-0.66 0.54]\n", 295 | " [-0.19 0.43]]\n", 296 | "24000 4.7E-05 3.5E-05 1.1E-05 3.2E-07\n", 297 | "[[-0.21 -0. ]\n", 298 | " [ 0.27 -0.29]\n", 299 | " [-0.66 0.54]\n", 300 | " [-0.19 0.42]]\n", 301 | "25000 2.9E-05 1.0E-05 1.8E-05 3.2E-07\n", 302 | "[[-0.21 -0. ]\n", 303 | " [ 0.27 -0.29]\n", 304 | " [-0.66 0.54]\n", 305 | " [-0.19 0.42]]\n", 306 | "26000 2.2E-06 1.4E-06 4.7E-07 3.2E-07\n", 307 | "[[-0.21 0. ]\n", 308 | " [ 0.27 -0.29]\n", 309 | " [-0.66 0.55]\n", 310 | " [-0.19 0.42]]\n", 311 | "27000 4.1E-06 2.2E-06 1.6E-06 3.2E-07\n", 312 | "[[-0.21 0. 
]\n", 313 | " [ 0.27 -0.29]\n", 314 | " [-0.66 0.55]\n", 315 | " [-0.19 0.41]]\n" 316 | ] 317 | } 318 | ], 319 | "source": [ 320 | "sparse_weight_vector, sparsity_pattern, prediction, network = DeepMod(X_train, y_train,network_config, lib_config, optim_config)" 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": null, 326 | "metadata": {}, 327 | "outputs": [], 328 | "source": [ 329 | "prediction = network(torch.tensor(X, dtype=torch.float32))" 330 | ] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "execution_count": 15, 335 | "metadata": {}, 336 | "outputs": [], 337 | "source": [ 338 | "prediction = prediction.detach().numpy()" 339 | ] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "execution_count": 64, 344 | "metadata": {}, 345 | "outputs": [], 346 | "source": [ 347 | "x, y = np.meshgrid(X[:,0], X[:,1])" 348 | ] 349 | }, 350 | { 351 | "cell_type": "code", 352 | "execution_count": 145, 353 | "metadata": {}, 354 | "outputs": [], 355 | "source": [ 356 | "mask = torch.tensor((0,1,3))" 357 | ] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "execution_count": 157, 362 | "metadata": {}, 363 | "outputs": [ 364 | { 365 | "data": { 366 | "text/plain": [ 367 | "tensor([0, 1, 3])" 368 | ] 369 | }, 370 | "execution_count": 157, 371 | "metadata": {}, 372 | "output_type": "execute_result" 373 | } 374 | ], 375 | "source": [ 376 | "mask" 377 | ] 378 | }, 379 | { 380 | "cell_type": "code", 381 | "execution_count": 158, 382 | "metadata": {}, 383 | "outputs": [], 384 | "source": [ 385 | "sparse_coefs = torch.tensor((0.1,0.2,0.4)).reshape(-1,1)" 386 | ] 387 | }, 388 | { 389 | "cell_type": "code", 390 | "execution_count": 159, 391 | "metadata": {}, 392 | "outputs": [ 393 | { 394 | "data": { 395 | "text/plain": [ 396 | "tensor([[0.1000],\n", 397 | " [0.2000],\n", 398 | " [0.4000]])" 399 | ] 400 | }, 401 | "execution_count": 159, 402 | "metadata": {}, 403 | "output_type": "execute_result" 404 | } 405 | ], 406 | "source": [ 407 | "sparse_coefs" 408 | ] 
409 | }, 410 | { 411 | "cell_type": "code", 412 | "execution_count": 291, 413 | "metadata": {}, 414 | "outputs": [], 415 | "source": [ 416 | "dummy = torch.ones((5,3,1))\n", 417 | "dummy2 = torch.ones((5,1,4))" 418 | ] 419 | }, 420 | { 421 | "cell_type": "code", 422 | "execution_count": 292, 423 | "metadata": {}, 424 | "outputs": [ 425 | { 426 | "data": { 427 | "text/plain": [ 428 | "torch.Size([5, 3, 4])" 429 | ] 430 | }, 431 | "execution_count": 292, 432 | "metadata": {}, 433 | "output_type": "execute_result" 434 | } 435 | ], 436 | "source": [ 437 | "(dummy @ dummy2).shape" 438 | ] 439 | }, 440 | { 441 | "cell_type": "code", 442 | "execution_count": 293, 443 | "metadata": {}, 444 | "outputs": [ 445 | { 446 | "data": { 447 | "text/plain": [ 448 | "torch.Size([5, 3, 1])" 449 | ] 450 | }, 451 | "execution_count": 293, 452 | "metadata": {}, 453 | "output_type": "execute_result" 454 | } 455 | ], 456 | "source": [ 457 | "dummy.shape" 458 | ] 459 | }, 460 | { 461 | "cell_type": "code", 462 | "execution_count": 294, 463 | "metadata": {}, 464 | "outputs": [ 465 | { 466 | "data": { 467 | "text/plain": [ 468 | "torch.Size([5, 3, 1])" 469 | ] 470 | }, 471 | "execution_count": 294, 472 | "metadata": {}, 473 | "output_type": "execute_result" 474 | } 475 | ], 476 | "source": [ 477 | "dummy.reshape(-1,3,1).shape" 478 | ] 479 | }, 480 | { 481 | "cell_type": "code", 482 | "execution_count": 164, 483 | "metadata": {}, 484 | "outputs": [], 485 | "source": [ 486 | "dummy = dummy.reshape(2,2)" 487 | ] 488 | }, 489 | { 490 | "cell_type": "code", 491 | "execution_count": 128, 492 | "metadata": {}, 493 | "outputs": [ 494 | { 495 | "ename": "TypeError", 496 | "evalue": "'Tensor' object is not callable", 497 | "output_type": "error", 498 | "traceback": [ 499 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 500 | "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", 501 | "\u001b[0;32m\u001b[0m in 
\u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwhere\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcoefs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mcoefs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mdummy\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 502 | "\u001b[0;31mTypeError\u001b[0m: 'Tensor' object is not callable" 503 | ] 504 | } 505 | ], 506 | "source": [ 507 | "torch.where(coefs(mask),coefs,dummy)" 508 | ] 509 | }, 510 | { 511 | "cell_type": "code", 512 | "execution_count": 45, 513 | "metadata": {}, 514 | "outputs": [], 515 | "source": [ 516 | "x = np.linspace(0, 1, 100)\n", 517 | "X, Y = np.meshgrid(x, x)\n", 518 | "Z = np.sin(X)*np.sin(Y)\n" 519 | ] 520 | }, 521 | { 522 | "cell_type": "code", 523 | "execution_count": 30, 524 | "metadata": {}, 525 | "outputs": [], 526 | "source": [ 527 | "b = torch.ones((10, 2), dtype=torch.float32, requires_grad=True)" 528 | ] 529 | }, 530 | { 531 | "cell_type": "code", 532 | "execution_count": 31, 533 | "metadata": {}, 534 | "outputs": [], 535 | "source": [ 536 | "a = torch.tensor(np.ones((2,10)), dtype=torch.float32)" 537 | ] 538 | }, 539 | { 540 | "cell_type": "code", 541 | "execution_count": 13, 542 | "metadata": {}, 543 | "outputs": [], 544 | "source": [ 545 | "test=torch.tensor([[0.3073, 0.4409],\n", 546 | " [0.0212, 0.6602]])" 547 | ] 548 | }, 549 | { 550 | "cell_type": "code", 551 | "execution_count": 17, 552 | "metadata": {}, 553 | "outputs": [ 554 | { 555 | "data": { 556 | "text/plain": [ 557 | "tensor([[0.3073, 0.4409],\n", 558 | " [0.0000, 0.6602]])" 559 | ] 560 | }, 561 | "execution_count": 17, 562 | "metadata": {}, 563 | "output_type": "execute_result" 564 | } 565 | ], 566 | "source": [ 567 | "torch.where(test>torch.tensor(0.3),test, torch.zeros_like(test))" 568 | ] 569 | }, 570 | { 571 | "cell_type": "raw", 572 | "metadata": {}, 573 | 
"source": [ 574 | "test2=torch.reshape(test, (1,4))" 575 | ] 576 | }, 577 | { 578 | "cell_type": "code", 579 | "execution_count": 83, 580 | "metadata": {}, 581 | "outputs": [ 582 | { 583 | "data": { 584 | "text/plain": [ 585 | "tensor([[0.3073],\n", 586 | " [0.4409],\n", 587 | " [0.0212],\n", 588 | " [0.6602]])" 589 | ] 590 | }, 591 | "execution_count": 83, 592 | "metadata": {}, 593 | "output_type": "execute_result" 594 | } 595 | ], 596 | "source": [ 597 | "test2[0,:].reshape(-1,1)" 598 | ] 599 | }, 600 | { 601 | "cell_type": "code", 602 | "execution_count": 47, 603 | "metadata": {}, 604 | "outputs": [], 605 | "source": [ 606 | "mask=torch.nonzero(test2[0,:])" 607 | ] 608 | }, 609 | { 610 | "cell_type": "code", 611 | "execution_count": 39, 612 | "metadata": {}, 613 | "outputs": [ 614 | { 615 | "ename": "RuntimeError", 616 | "evalue": "shape '[1, 4]' is invalid for input of size 8", 617 | "output_type": "error", 618 | "traceback": [ 619 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 620 | "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", 621 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mmask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnonzero\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 622 | "\u001b[0;31mRuntimeError\u001b[0m: shape '[1, 4]' is invalid for input of size 8" 623 | ] 624 | } 625 | ], 626 | "source": [ 627 | "mask=torch.reshape(torch.nonzero(test2), (1,4))" 628 | ] 629 | }, 630 | { 631 | "cell_type": "code", 632 | "execution_count": 48, 633 | 
"metadata": {}, 634 | "outputs": [ 635 | { 636 | "data": { 637 | "text/plain": [ 638 | "tensor([[0],\n", 639 | " [1],\n", 640 | " [2],\n", 641 | " [3]])" 642 | ] 643 | }, 644 | "execution_count": 48, 645 | "metadata": {}, 646 | "output_type": "execute_result" 647 | } 648 | ], 649 | "source": [ 650 | "mask" 651 | ] 652 | }, 653 | { 654 | "cell_type": "code", 655 | "execution_count": 54, 656 | "metadata": {}, 657 | "outputs": [ 658 | { 659 | "ename": "RuntimeError", 660 | "evalue": "index 1 is out of bounds for dim with size 1", 661 | "output_type": "error", 662 | "traceback": [ 663 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 664 | "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", 665 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtest2\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 666 | "\u001b[0;31mRuntimeError\u001b[0m: index 1 is out of bounds for dim with size 1" 667 | ] 668 | } 669 | ], 670 | "source": [ 671 | "test2[mask[1]]" 672 | ] 673 | }, 674 | { 675 | "cell_type": "code", 676 | "execution_count": 49, 677 | "metadata": {}, 678 | "outputs": [ 679 | { 680 | "data": { 681 | "text/plain": [ 682 | "10" 683 | ] 684 | }, 685 | "execution_count": 49, 686 | "metadata": {}, 687 | "output_type": "execute_result" 688 | } 689 | ], 690 | "source": [ 691 | "a.shape[1]" 692 | ] 693 | }, 694 | { 695 | "cell_type": "code", 696 | "execution_count": null, 697 | "metadata": {}, 698 | "outputs": [], 699 | "source": [] 700 | } 701 | ], 702 | "metadata": { 703 | "kernelspec": { 704 | "display_name": "Python 3", 705 | "language": "python", 706 | "name": "python3" 707 | }, 708 | "language_info": { 709 | "codemirror_mode": { 710 | "name": "ipython", 711 | "version": 3 712 | }, 713 | 
"file_extension": ".py", 714 | "mimetype": "text/x-python", 715 | "name": "python", 716 | "nbconvert_exporter": "python", 717 | "pygments_lexer": "ipython3", 718 | "version": "3.6.8" 719 | } 720 | }, 721 | "nbformat": 4, 722 | "nbformat_minor": 2 723 | } 724 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/README-checkpoint.md: -------------------------------------------------------------------------------- 1 | # DeePyMoD 2 | 3 | DeePyMoD is a PyTorch-based implementation of the DeepMoD algorithm for model discovery of PDEs. We use a neural network to model our dataset, build a library of possible terms from the networks output and employ sparse regression to find the PDE underlying the dataset. More information can be found in our paper: [arXiv:1904.09406](http://arxiv.org/abs/1904.09406) 4 | 5 | **What's the use case?** Classical Model Discovery methods such as PDE-find struggle with elevated noise levels and sparse datasets due the low accuracy of numerical differentiation. DeepMoD can handle high noise and sparse datasets, making it well suited for model discovery on actual experimental data. 6 | 7 | **What types of models can you discover?** DeepMoD can discover non-linear, multi-dimensional and/or coupled PDEs. See our paper for a demonstration of each. 8 | 9 | **How hard is it to apply it to my data?** Not at all! We've designed the code to be accessible without having in-depth knowledge of deep learning or model discovery. You can load in the data, train the model and get the result in a few lines of code. We include a few notebooks with examples in the examples folder. Feel free to open an issue if you need any additional help. 10 | 11 | **How do I modify the code?** We provide two interfaces, an object-based and functional-based one. The object-based interface is simply a wrapper around the functional one. 
The code has been modularly designed and is well documented, so you should be able to plug-in another training regime, cost function or library function yourself pretty easily. 12 | 13 | # Features 14 | 15 | * **Fast** We implemented a neural network which also calculates the derivatives w.r.t. input on the forward pass. This saves a lot of calculations, making DeePyMoD at least 30% faster than a standard implementation. 16 | 17 | * **Extendable** DeePyMoD is designed to be easily extendable and modifiable. You can simply plug in your own cost function, library or training regime. 18 | 19 | * **Automatic library** The library and coefficient vectors are automatically constructed from the maximum order of polynomial and differentiation. If that doesn't cut it for your use case, it's easy to plug in your own library function. 20 | 21 | * **Extensive logging** We provide a simple command line logger to see how training is going and an extensive custom Tensorboard logger. 22 | 23 | # How to install 24 | We provide two ways to use DeePyMoD, either as a package or in a ready-to-use Docker container. 25 | 26 | ## Package 27 | DeePyMoD is released as a pip package, so simply run 28 | 29 | ``` pip install DeePyMoD``` 30 | 31 | to install. Alternatively, you can clone the 32 | We currently provide two ways to use our software, either in a docker container or as a normal package. If you want to use it as a package, simply clone the repo and run: 33 | 34 | ```python setup.py install``` 35 | 36 | 37 | ## Container 38 | A GPU-ready Docker image can also be used. Once you've cloned the repo, go into the config folder and run: 39 | 40 | ```./start_notebook.sh``` 41 | 42 | This pulls our lab's standard docker image from dockerhub, mounts the project directory inside the container and starts a jupyterlab server which can be accessed through localhost:8888. You can stop the container by running the stop_notebook script. 
This will stop the container; next time you run start_notebook.sh it will look if any containers from that project exist and restart them instead of building a new one, so your changes inside the container are maintained. 43 | 44 | 45 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /.ipynb_checkpoints/setup-checkpoint.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Setup file for deepymod_torch. 4 | Use setup.cfg to configure your project. 5 | 6 | This file was generated with PyScaffold 3.2. 7 | PyScaffold helps you to put up the scaffold of your new Python project. 8 | Learn more under: https://pyscaffold.org/ 9 | """ 10 | import sys 11 | 12 | from pkg_resources import require, VersionConflict 13 | from setuptools import setup 14 | 15 | try: 16 | require('setuptools>=38.3') 17 | except VersionConflict: 18 | print("Error: version of setuptools is too old (<38.3)!") 19 | sys.exit(1) 20 | 21 | 22 | if __name__ == "__main__": 23 | setup(use_pyscaffold=True) 24 | -------------------------------------------------------------------------------- /AUTHORS.rst: -------------------------------------------------------------------------------- 1 | ============ 2 | Contributors 3 | ============ 4 | 5 | * Gert-Jan 6 | -------------------------------------------------------------------------------- /CHANGELOG.rst: -------------------------------------------------------------------------------- 1 | ========= 2 | Changelog 3 | ========= 4 | 5 | Version 0.1 6 | =========== 7 | 8 | - Feature A added 9 | - FIX: nasty bug #1729 fixed 10 | - add your changes here! 
11 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2019 Gert-Jan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeePyMoD 2 | 3 | ## Deep learning based model discovery for ODEs and PDEs 4 | 5 | DeePyMoD is a PyTorch-based implementation of the DeepMoD algorithm for model discovery of PDEs and ODEs. We use a neural network to model our dataset, build a library of possible terms from the networks output and employ sparse regression to find the PDE underlying the dataset. 
More information can be found in our paper: [arXiv:1904.09406](http://arxiv.org/abs/1904.09406) 6 | 7 | **What's the use case?** Classical Model Discovery methods struggle with elevated noise levels and sparse datasets due to the low accuracy of numerical differentiation. DeepMoD can handle high noise and sparse datasets, making it well suited for model discovery on actual experimental data. 8 | 9 | **What types of models can you discover?** DeepMoD can discover non-linear, multi-dimensional and/or coupled ODEs and PDEs. See our paper and the examples folder for a demonstration of each. 10 | 11 | **How hard is it to apply it to my data?** Not at all! We've designed the code to be accessible without having in-depth knowledge of deep learning or model discovery. You can load in the data, train the model and get the result in a few lines of code. We include a few notebooks with examples in the examples folder. Feel free to open an issue if you need any additional help. 12 | 13 | **How do I modify the code?** We provide two interfaces, an object-based and functional-based one. The object-based interface is simply a wrapper around the functional one. The code has been modularly designed and is well documented, so you should be able to plug-in another training regime, cost function or library function yourself pretty easily. 14 | 15 | # Features 16 | 17 | * **Many example notebooks** We have implemented a variety of examples ranging from 2D Advection Diffusion, Burgers' equation to non-linear, higher-order ODEs. If you miss any example, don't hesitate to give us a heads-up. 18 | 19 | * **Extendable** DeePyMoD is designed to be easily extendable and modifiable. You can simply plug in your own cost function, library or training regime. 20 | 21 | * **Automatic library** The library and coefficient vectors are automatically constructed from the maximum order of polynomial and differentiation. 
If that doesn't cut it for your use case, it's easy to plug in your own library function. 22 | 23 | * **Extensive logging** We provide a simple command line logger to see how training is going and an extensive custom Tensorboard logger. 24 | 25 | * **Fast** Depending on the size of the data-set, running a model search with DeepMoD takes on the order of minutes to tens of minutes on a standard CPU. Running the code on GPUs drastically improves performance. 26 | 27 | # How to install 28 | We provide two ways to use DeePyMoD, either as a package or in a ready-to-use Docker container. 29 | 30 | ## Package 31 | DeePyMoD is released as a pip package, so simply run 32 | 33 | ``` pip install DeePyMoD``` 34 | 35 | to install. Alternatively, you can clone the repository and install from source. 36 | We currently provide two ways to use our software, either in a docker container or as a normal package. If you want to use it as a package, simply clone the repo and run: 37 | 38 | ```python setup.py install``` 39 | 40 | 41 | ## Container 42 | A GPU-ready Docker image can also be used. Once you've cloned the repo, go into the config folder and run: 43 | 44 | ```./start_notebook.sh``` 45 | 46 | This pulls our lab's standard docker image from dockerhub, mounts the project directory inside the container and starts a jupyterlab server which can be accessed through localhost:8888. You can stop the container by running the stop_notebook script. This will stop the container; next time you run start_notebook.sh it will check whether any containers from that project exist and restart them instead of building a new one, so your changes inside the container are maintained. 
47 | 48 | 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /config/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/config/.DS_Store -------------------------------------------------------------------------------- /config/start_notebook.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | container_name="deepymod_pytorch" 4 | image="phimal/projects:general" 5 | 6 | #First automatically find project directory 7 | cd ../ 8 | projectdir=$(pwd) 9 | cd config/ 10 | 11 | if [ "$(docker ps -aq -f name=$container_name)" ]; then 12 | echo "Restarting container." 13 | docker restart $container_name 14 | else 15 | if hash nvidia-docker 2>/dev/null; then 16 | echo 'Starting container with gpu.' 17 | docker run -d\ 18 | -p 8888:8888 -p 6006:6006 -p 8787:8787\ 19 | -v "$projectdir:/home/working/" \ 20 | --ipc=host \ 21 | --name=$container_name \ 22 | --runtime=nvidia \ 23 | $image bash -c "cd /home/working/ && \ 24 | python setup.py develop && \ 25 | jupyter lab --ip 0.0.0.0 --no-browser --allow-root --NotebookApp.token=''" 26 | else 27 | echo 'Starting container without gpu.' 
28 | docker run -d\ 29 | -p 8888:8888 -p 6006:6006 -p 8787:8787 \ 30 | -v "$projectdir:/home/working/" \ 31 | --ipc=host \ 32 | --name=$container_name \ 33 | $image bash -c "cd /home/working/ && \ 34 | python setup.py develop && \ 35 | jupyter lab --ip 0.0.0.0 --no-browser --allow-root --NotebookApp.token=''" 36 | fi 37 | fi 38 | 39 | # Also create a nice stop script 40 | echo "You can stop this container by running stop_notebook.sh" 41 | echo "docker stop $container_name" > stop_notebook.sh 42 | -------------------------------------------------------------------------------- /config/stop_notebook.sh: -------------------------------------------------------------------------------- 1 | docker stop deepymod_pytorch 2 | -------------------------------------------------------------------------------- /examples/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/examples/.DS_Store -------------------------------------------------------------------------------- /examples/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/examples/.gitignore -------------------------------------------------------------------------------- /examples/ODE_Logistic_equation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Logistic model " 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 3, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "# General imports\n", 17 | "import numpy as np\n", 18 | "import torch\n", 19 | "import matplotlib.pylab as plt\n", 20 | "\n", 21 | "# DeepMoD stuff\n", 22 | "from deepymod_torch.DeepMod import DeepMod\n", 23 | "from 
deepymod_torch.training import train_deepmod, train_mse\n", 24 | "from deepymod_torch.library_functions import library_1D_in\n", 25 | "\n", 26 | "from scipy.integrate import odeint\n", 27 | "\n", 28 | "# Settings for reproducibility\n", 29 | "np.random.seed(40)\n", 30 | "torch.manual_seed(0)\n", 31 | "\n", 32 | "%load_ext autoreload\n", 33 | "%autoreload 2" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 10, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "def dU_dt(U, t):\n", 43 | " # Here U is a vector such that u=U[0] and v=U[1]. This function should return [u, v]' \n", 44 | " # The ODE we solve here is u' = u*v and v' = -0.2v\n", 45 | " return [U[1]*U[0], -0.2*U[1]]\n", 46 | "U0 = [2.5, 0.4]\n", 47 | "ts = np.linspace(0, 20, 500)\n", 48 | "Y = odeint(dU_dt_sin, U0, ts)\n", 49 | "T = ts.reshape(-1,1)" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 11, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "T_rs = T\n", 59 | "Y_rs = Y/np.max(np.abs(Y),axis=0)" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 12, 65 | "metadata": {}, 66 | "outputs": [ 67 | { 68 | "data": { 69 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXQAAAEGCAYAAAB1iW6ZAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3dd3xUVf7/8dcnCSlAKgmEEnrvYOiKqIiACGJBEAuiIiqWddf9uqs/17q6rruuBVFEREBFWXVFlhURUemQ0CH0GiCFAClA+vn9cScYQwITmMmdmXyej8c8ptw7M5/cmbxzcu6954gxBqWUUt7Pz+4ClFJKuYYGulJK+QgNdKWU8hEa6Eop5SM00JVSykcE2PXG0dHRpmnTpna9vVJKeaXExMRjxpiY8pbZFuhNmzYlISHBrrdXSimvJCIHKlqmXS5KKeUjNNCVUspHaKArpZSP0EBXSikfoYGulFI+4oKBLiLTRSRNRLZUsFxE5C0R2S0im0Sku+vLVEopdSHOtNBnAIPPs3wI0MpxmQBMufSylFJKVdYFj0M3xvwiIk3Ps8oIYKaxxuFdJSIRIlLfGHPURTX+VnIibP8WBj7nlpdXSilXKS42HD+dT2pWruOSR2pWLle3rUvnRhEufz9XnFjUEDhU6n6y47FzAl1EJmC14mncuPHFvduRdbDsDWh/IzToenGvoZRSl6i42JCek8fhk2c4cvaSy9HMM6Rm5ZGWlUtadh6FxefOOVGndpDHBrqU81i5s2YYY6YCUwHi4+MvbmaNTrfC98/A+lka6Eopt8ktKCL5xGkOnfg1sI+ezLUCPPMMKZm5FBT9NsZqBwUQGx5MbFgwLVpEUy8siHphwdQLC6JumPV4dO0gAgPcczyKKwI9GYgrdb8RcMQFr1u+kAhoPwI2zYVBL0GNELe9lVLKdxljOHG6gAMZpzh4/DQHM05zwHF98PhpUrJyf7O+v58QGxZMw4gQujeOpEFECA0iQmgYEXz2dlhwDZt+GosrAn0eMElE5gC9gEy39Z+X6H4XbPocts2DLre59a2UUt4tK7eAvemn2JOWw570HPYdO8UBR2jn5BX+Zt16YUE0iapFv5bRNKlTk8ZRNWkUGULDyBDqhgbj71deh4TnuGCgi8hnwAAgWkSSgb8ANQCMMe8BC4ChwG7gNHCPu4o9q0k/iGoO62ZqoCulKC42HM3KPRvae9Jz2JN2ij3pOaRl551dL8BPaBxVkyZ1atKzWRRxUTVp4rjfKLImIYH+Nv4Ul86Zo1zGXGC5AR52WUXOEIFud8Li5yFjD9RpUaVvr5SyT0ZOHjtSsklKyWZHShbbU7LZlZrDmYKis+uEBQfQsm5trmwdQ4u6tWkeXYsWdWvTOKomNfx993xK24bPvWRdb4cfX7J2jg58zu5qlFIulltQxO60HLanZLP9aBY7UrNJOprNsZxfW9zRtYNoGxvKmJ6NaVm3Ni1irOCuUysQEc/uHnEH7w300FhofR1s+BSuegb8vfdHUaq6yyssYvvRbDYfzmRzciabD2eyMzX77CF/QQF+tK4XylVtYmgTG0q7+mG0iQ0lunaQzZV7Fu9OwW53wo4FsOt7aDvU7mqUUk7ILyxme0rWb8J7R8qv4R1RswadGoYzoU1zOjQIp239UJrWqeXxOyQ9gXcHeqtBUDvW6nbRQFfKI6Vn57Hu4AnWHThB4oETbD6cSV5hMQDhITXo3Cic+/s3p3PDcDo2DKdRZEi17C5xBe8OdP8A6DoGlr8FWUchrL7dFSlVrRUVG7anZJ0N73UHT3Lw+GkAAv396NAwjDt7N6Fr4wi6NIrQ8HYx7w50sLpdlr0BG2ZD/yftrkapaqWwqJhtR7NYtTeDVXuPs3bfcbIdx3bHhAZxWeNI7uzdhO5NIujQIJzgGt59WKCn8/5Ar9MCmg+AhBlw+RPgp18YpdzlfAHeIqYWw7s2oGezKLo3jtTWtw28P9ABetw
Hn98BOxdqX7pSLnYw4zQ/70pn6c50Vu7JOCfAezevQ6/mUdQNDba5UuUbgd56CIQ2gIQPNdCVukQ5eYWs3JPBLzvTWbornf0ZVh94o8gQhnVpQN8WGuCeyjcC3T8ALrsbfnoVju+1hgVQSjnFGMPutBwWJaXy8450Eg+coLDYUDPQnz7N63BPv2Zc0SqaZtG1tAvFw/lGoIM1YNfPr0HCRzDoRburUcqjFRQVs3b/cX7Ylsbi7akccLTCOzQI4/7+zenfKobuTSIICtB9Ut7EdwI9rIHV3bJ+Nlz1NNTQfweVKi0rt4CfdqTzw7ZUftqRRlZuIYEBfvRrUYcJ/ZtzTdt6xIbr7403851AB4i/F5K+hW3f6CiMSgGZZwr4YVsqCzYfZemuY+QXFVOnViDXdYhlYPt6XNEqmpqBvhUD1ZlvfZLNroQ6La2doxroqpoqHeK/7EqnoMjQIDyYu/o0YUinWLrGRepp9D7KtwLdzw/ix8PCP0PKZojtZHdFSlWJ0/mFLNyawrcbj7K0VIjf3acpQzvXp2ujCPw0xH2ebwU6WMPqLn4B1nwAw9+yuxql3Kao2LBizzG+XneY77amcDq/6GyIX9+5Pl3jIvSolGrG9wI9JBI6j4JNX1jjpNeMsrsipVwq6WgWX68/zDcbDpOalUdocADDuzRgZLeG9GgapS3xasz3Ah2g10RrerrEGXDFE3ZXo9QlO3Eqn6/WH2ZuwiG2p2QT4CcMaBPDs8MacU27ujpGigJ8NdDrdYBm/WHtNOj7CPjbOxO3UhfDGMPKvRnMWXOI77akkF9UTJe4CF4Y0YHrO9Wnjk7uoMrwzUAH6P0QfDbaOoyx4012V6OU09Kz8/hyXTKfrz3EvmOnCAsO4PZejRndM462sWF2l6c8mO8GeqvrILIZrH5PA115PGMMa/ef4OMV+1m4NYXCYkPPplE8cnVLhnaqr10qyim+G+h+ftDrAfjuKTicCA0vs7sipc6RW1DENxsOM2PFAZKOZhEeUoNxfZsyumccLeuG2l2e8jK+G+gAXcfCjy/Dqvfg5g/srkapsw6fPMOslQeYs/YgJ08X0DY2lFdu6sSNXRsSEqitcXVxfDvQg8Og2x3WztFBL0JorN0VqWou8cAJpi3dy8KtKQAMah/LuH5N6dUsSo8ZV5fMtwMdoNcEqx997Ydw9dN2V6OqoeJiw5Idabz/817W7D9OeEgNJvRvwR29G9Mosqbd5Skf4vuBHtUc2gyxxne5/HcQqL9AqmrkFxYzb+MRpv6yh52pOTSMCOHZYe25rUcctYJ8/1dPVb3q8a3q+wh8tAA2fAI977e7GuXjcvIK+Wz1QT5cto+UrFzaxobyr9u6cn3n+tTw97O7POXDqkegN+4DjXrAyneswbt0ImnlBjl5hXy8Yj8fLN3LydMF9Gleh1dv7sSVrWO0f1xVieoR6CLQ7zFrIumkedBhpN0VKR+SnVvAzJUHzgb5VW1iePSaVnRrHGl3aaqaqR6BDtBmKES1gGX/gvY3WiGv1CXIzi1wtMj3kXmmgKvb1uWxa1rRJS7C7tJUNVV9At3P3+pLn/847F9qjfWi1EU4nV/IR8v3M/WXvWSeKeCatnV5bGArOjfSIFf2cirQRWQw8CbgD0wzxrxaZnlj4GMgwrHOU8aYBS6u9dJ1GQNLXoblb2mgq0orKCpmzpqDvLl4N8dy8rimbV0eH9iaTo3C7S5NKcCJQBcRf2AycC2QDKwVkXnGmG2lVnsG+MIYM0VE2gMLgKZuqPfS1Ai2htb98UVI2QKxHe2uSHmB4mLDt5uO8M9FOzmQcZqeTaN4/87uXNZEx9pXnsWZY6h6AruNMXuNMfnAHGBEmXUMUDIMXDhwxHUluliPe6FGLVjxtt2VKA9njOHnnenc8M4yHpuzgZAa/nw0rgefP9Bbw1x5JGe6XBoCh0rdTwZ6lVnnOeB7EXkEqAUMdEl17hASCZfdDWumWmeORjS2uyLlgbanZPHS/CS
W7T5GXFQI/7qtK8O7NNDZgJRHc6aFXt432JS5PwaYYYxpBAwFZonIOa8tIhNEJEFEEtLT0ytfrav0eRgQWP6mfTUoj5SRk8fTX29m6JtL2Xw4k2eHtWfxEwO4sVtDDXPl8ZxpoScDcaXuN+LcLpV7gcEAxpiVIhIMRANppVcyxkwFpgLEx8eX/aNQdcIbWZNJr5sFV/wBwurbVoryDPmFxXy8Yj9vLd7F6YIi7urTlMcHtiKiZqDdpSnlNGda6GuBViLSTEQCgdHAvDLrHASuARCRdkAwYGMT3AmX/w6KC7UvvZozxvD91hQGvfEzLy9IIr5pJAsfv4LnhnfQMFde54ItdGNMoYhMAhZiHZI43RizVUReABKMMfOA3wMfiMjvsLpjxhlj7GuBOyOqGXQeBQnTrXCvHWN3RaqK7T92ir/M28rPO9NpWbc2M+7pwYA2de0uS6mLJnblbnx8vElISLDlvc9K3wmTe8Llj8PA5+ytRVWZ3IIipvy0hyk/7yHQ34/fXduau/o00YGzlFcQkURjTHx5y6rPmaLliWltjeuy5gPo+yjU1EPRfN2SHWn85ZutHDx+muFdGvD09e2oFxZsd1lKuYQ2Sfr/AfJzYPX7dlei3OjwyTNMnJXIPR+tJcBf+PS+Xrw1ppuGufIp1buFDlCvA7QdBqunWIczBodd+DnKaxQVG2au3M/fF+6g2BievK4N91/RnMAAbcso36OBDnDF72H7fKuVfuWTdlejXGRXajZ//HIT6w+eZECbGF66saNO+aZ8mgY6QMPu0HqIdQhjz/uss0mV18ovLObdn3YzecluagcF8MZtXbixa0OdZEL5PP2/s8RVf4a8TFg52e5K1CVYf/AEw95eyr9+2MWQjvX54YkrGdmtkYa5qha0hV6ifmdr4otVU6DXg1Crjt0VqUrILSji9YU7+HD5PmLDgpk+Lp6r29azuyylqpS20Esb8CfIPwXL/2V3JaoSNiWfZNjby5i2bB+392zM97/rr2GuqiVtoZdWt6119uiaD6DPJAjVUPBkBUXFvP2j1VceUzuImeN70r+1nvGrqi9toZd15f9BUT4s+6fdlajz2Jmazch3l/PW4l2M6NKAhb/rr2Guqj1toZdVp4U1EmPCdGsO0vBGdlekSikqNkxfto+/f7+D0KAA3rujO4M76miZSoG20Mt35R/BGPj5b3ZXokpJyczljmmreXlBEgNax7Dwd/01zJUqRVvo5YloDPHjYa2jLz2mjd0VVXuLtqXyx39vJLegmNdu7syt8XooolJlaQu9Ilf+0Zp79Ifn7a6kWsstKOLZb7Zw/8wEGkSEMP/RyxnVI07DXKlyaAu9IrWi4fLH4MeX4MBKaNLH7oqqnZ2p2Tzy6Xp2pGZz3+XNeHJwG4IC/O0uSymPpS308+n9MITWh0X/z+pTV1XCGMPsVQe44e1lZJzKY8Y9PXhmWHsNc6UuQAP9fAJrWicbJa+FpG/trqZayMkr5JHP1vPMf7bQq3kd/vdYf51FSCknaaBfSNexEN0GFj8PRQV2V+PTdqRkM/ydZSzYfJQ/Dm7DjHE9iAkNsrsspbyGBvqF+AfAtc9Dxm5YN9PuanzWl4nJjJi8jOzcQj69vzcPDWiJn5/u+FSqMjTQndF6MDTuCz+9ArlZdlfjU3ILinjqy038fu5GusZF8N9HL6d3cx0YTamLoYHuDBG47mU4lQ5LX7e7Gp9xIOMUN727gjlrD/HwVS2YfW8v6obqlHBKXSw9bNFZDbtb/emrpsBl4yCqud0VebVF21J54vMN+PmJDnWrlItoC70yrnkW/APh+/9ndyVeq7jY8OYPu7h/ZgJNo2vx30cv1zBXykU00CsjNBaueMKaf3TvT3ZX43Vy8gqZODuRN37YyU3dGzJ3Yh+d41MpF9JAr6zeD0NEE/juT1BUaHc1XmPfsVOMnLycxdvTeHZYe/5xaxeCa+iJQkq5kgZ6ZdUIhkEvQdo2WDfD7mq8wpIdaQx
/ZxnHcvKYNb4n4y9vpmOxKOUGGugXo90N0PQK+PFlOHPC7mo8ljGGd3/azfgZa2kUWZN5ky6nb8tou8tSymdpoF8MERj8KuSehMUv2l2NR8otKOLRORt47bsdXN+pPl892Je4KO0vV8qdNNAvVmxH6DXRmtnocKLd1XiUtOxcRk9dxfxNR/jj4Da8PaYbIYHaX66Uu2mgX4oBf4La9WD+E1BcZHc1HiHpaBYjJ69gR0o2U8ZexkMDWmp/uVJVRAP9UgSHweC/wtENVku9mvtxeyq3TFlBYXExcyf2YXDHWLtLUqpa0UC/VB1uguYDrL70nDS7q7GFMdbEzfd9bJ0s9M3Dl9OxYbjdZSlV7TgV6CIyWER2iMhuEXmqgnVGicg2EdkqIp+6tkwPJgJD/wGFZ6rlGaQFRcU8858tvDB/GwPb1WPuxD7Ehut4LErZ4YKBLiL+wGRgCNAeGCMi7cus0wr4E9DPGNMBeNwNtXqu6JbQ7zHYNAf2LbW7miqTnVvA+Blr+WT1QR64sjnv3XEZNQN1eCCl7OJMC70nsNsYs9cYkw/MAUaUWed+YLIx5gSAMab69T1c8XuIbArfPgYFuXZX43Ypmbnc+t5KVu7J4LWbO/OnIe10/HKlbOZMoDcEDpW6n+x4rLTWQGsRWS4iq0RkcHkvJCITRCRBRBLS09MvrmJPVSMEhv0Lju+Bn/9mdzVutTM1m5veXc6h46eZPq4Ho3rE2V2SUgrnAr28ZlfZGZMDgFbAAGAMME1EIs55kjFTjTHxxpj4mJiYytbq+VpcBV3vgOVvwtFNdlfjFqv2ZnDLlBUUFBs+f6AP/Vv74OeolJdyJtCTgdJNsEbAkXLW+cYYU2CM2QfswAr46mfQi1CzDsyb5HODd83fdIS7PlxDTGgQXz3YV49kUcrDOBPoa4FWItJMRAKB0cC8Muv8B7gKQESisbpg9rqyUK9RMwqGvgZHN8Kqd+2uxmWmLd3LpE/X0yUunC/1NH6lPNIFA90YUwhMAhYCScAXxpitIvKCiAx3rLYQyBCRbcAS4EljTIa7ivZ47W+ENtfDkr/Cce/+u1ZcbHjh22289N8khnSMZda9vYioGWh3WUqpcogxZbvDq0Z8fLxJSEiw5b2rRNYRmNwL6neBu+aBn/edw5VXWMQTn2/kv5uPMq5vU/7fsPb465EsStlKRBKNMfHlLfO+lPEWYQ2siaX3L4U1U+2uptJy8goZP2Mt/918lD8PbctfbtAwV8rTaaC7U7c7odV18MNf4Nguu6tx2vFT+Yz9YBWr9h7n9Vu7MKF/Cx1gSykvoIHuTiIw/C0ICIavJ3rFUS9HM88w6v2VJKVk894dl3HLZY3sLkkp5SQNdHcLjYXr/wGHE2DFm3ZXc1570nO4ZcpKUjJzmTm+J9e2r2d3SUqpStBArwqdboEOI2HJK5Cyxe5qyrXlcCaj3ltJbkERcyb0pnfzOnaXpJSqJA30qjL0HxASCV8/AIV5dlfzGyv3ZDB66iqCa/gzd2IfPWFIKS+lgV5VatWB4W9D6hb44Tm7qznr+60p3P3RGmLDg/n3g31oHlPb7pKUUhdJA70qtRkMPSdYZ5DuWmR3NXyZmMzE2Ym0qx/G3Af6UD88xO6SlFKXQAO9ql37ItTtYB31kp1qWxmzVx3g93M30rt5HT69rxeRtfTsT6W8nQZ6VasRDLdMh/xT8J+JUFxc5SVMW7qXZ/6zhavb1mX6uB7UCtJJKZTyBRrodqjb1ppces+PsGpylb71Oz/uOjsuy3t3XEZwDf8qfX+llPtooNvlsnug3Q3ww/NwZL3b384Yw98Xbuf173cysltD3h7TjcAA/fiV8iX6G20XEbjhLahdF/49HnIz3fZWxhhemL+NyUv2MKZnHP+4tQsB/vrRK+Vr9LfaTjWj4OYP4cQB+OZhcMPIl8XFhj9/vYWPlu/nnn5N+evITjr3p1I+SgPdbk36wLUvQNK3sPIdl750YVExv5+7kc/WHOT
hq1rw7LD2OsiWUj5MA90T9HkY2g2HRX+BAytc8pL5hcU88tl6vl5/mD8Mas2T17XVMFfKx2mgewIRGDEZIpvC3HGXfHx6XmERD32SyP+2pPDM9e2YdHX1nN5VqepGA91TBIfBbbMgN8vaSXqRQ+3mFhTx4Ox1/JCUxos3duS+K5q7uFCllKfSQPck9TrAsDfgwDJY/Fyln55bUMTE2Yn8uD2Nl0d25M7eTVxfo1LKY+kpgp6m6xhIXgMr3obYztB5lFNPyy0oYsKsRH7Zmc4rN3ViTM/Gbi5UKeVptIXuiQb/DZr0g28mQXLiBVfPLSji/pkJLN2Vzms3d9YwV6qa0kD3RAGBMGomhNaDObdD1tEKVz2TX8R9HyewbPcx/nZzZ0b1iKvCQpVSnkQD3VPViobRn0FethXqBWfOWeVMfhH3fryW5XuO8fdbujAqXsNcqepMA92TxXaEm96HI+tg3qO/OZP0dH4h42esZdXeDP45qotO5qyU0kD3eO1ugKuegc1fwLI3ADiVV8g9H61l9b4M/jmqKyO7aZgrpfQoF+/Q/w+QngSLnye3diPuWR1HwoHjvHFbV0Z0bWh3dUopD6EtdG8gAiPepSiuL/7fPIj/oeW8ObqbhrlS6jc00L1EdpE/43Mf54Cpy8yab3JDg2y7S1JKeRgNdC+QnVvA3dPXsPxwEclDZlIjMBhm3wLZKXaXppTyIBroHi47t4C7pq9hU3Im79zejQG9e8DYL+B0Bnw6CvJy7C5RKeUhNNA9WE5eIXdPX8Pm5Ezeub07gzvWtxY06Aa3zoCUzfD5HVCYZ2udSinP4FSgi8hgEdkhIrtF5KnzrHeLiBgRiXddidVTTl4h46avYaOjZT64Y+xvV2g9CIa/A3uXwJf3XfTojEop33HBQBcRf2AyMARoD4wRkfblrBcKPAqsdnWR1Y11nPka1h86ydtjuv3aMi+r21i47hVImgfzH3PLFHZKKe/hTAu9J7DbGLPXGJMPzAFGlLPei8BrQK4L66t2Sk4aWnfwJG+N7sbQThWEeYk+D8GV/wfrZ8PCpzXUlarGnAn0hsChUveTHY+dJSLdgDhjzPzzvZCITBCRBBFJSE9Pr3Sxvu50fiH3zFhL4sETvDm6K9d3vkCYlxjwJ+j5AKyaDL/83b1FKqU8ljNnipY3EeXZZqCI+AFvAOMu9ELGmKnAVID4+HhtSpZSMjZLwv7j/Gt0N4Z1buD8k0Vg8KuQlwVLXobAWtY8pUqpasWZQE8GSg/j1wg4Uup+KNAR+MkxCXEsME9EhhtjElxVqC87k1/EvTMSWLPPOp1/eJdKhHkJPz9rJ2nBaVj4Z0Cs7hilVLXhTKCvBVqJSDPgMDAauL1koTEmE4guuS8iPwF/0DB3TskQuCUDbV3S6fz+AXDzh1Y/+sI/WS333g+6rlillEe7YB+6MaYQmAQsBJKAL4wxW0XkBREZ7u4CfVnJTEMr92bwj1FduLGbC8Zm8a8Bt0y3Rmn87ilY9d6lv6ZSyis4NdqiMWYBsKDMY89WsO6ASy/L95WE+fI9x3j9li6uHQLXvwbc8hHMHQff/Z/VUu/1gOteXynlkfRMURuUhPmy3dZMQze7Y3KKklBvOwz+90dr0mmllE/TQK9iuQVFPDAr8ewcoG6daSgg0Ar19jfC98/Ajy/rcepK+TCd4KIK5RUWMXF2Ij/vTOe1mztXzRygAYFWn/q3ofDLa9ahjde9Yh0Vo5TyKRroVSSvsIiJsxL5aUc6r97UiVE9qnBCZz9/GP42BIVZJx/lZcMNb1lHxSilfIb+RleBvMIiHpy9jiU70nnlpk6M7tm46osQgetehuAw+OkVK9RvngYBQVVfi1LKLfT/bjfLLyzm4U/W8eP2NP46shNj7AjzEiIw4Cm47q/WgF6zboIzJ+2rRynlUhrobpRfWMzDn67jh6Q0XrqxI7f3sjHMS+vzMNw0DQ6thumD4eShCz9HKeXxNNDdJL+wmEmfrmPRtlR
eHNGBO3o3sbuk3+p8K9z5FWQdhg+vtSbLUEp5NQ10N8gvLOaRz9bx/bZUXhjRgTv7NLW7pPI16w/jvwPxg+lDYM8SuytSSl0CDXQXK+lmWbg1leduaM9dnhrmJep1gHsXQWQT+OQWSJhud0VKqYukge5C1tEsiSxytMzH9Wtmd0nOCW8I9/wPWlwN838HC57UKe2U8kIa6C5Scgbo4u1pvDyyo+e3zMsKDoMxc6DPJFgzFWbfBKeP212VUqoSNNBdoGRslp93WicNje3lYTtAneXnbx2rPuJdOLgSpl0D6Tvtrkop5SQN9EtUMp55ydgstpw05GrdxsLd31onH027BpK+tbsipZQTNNAvgTUH6BpW7sngH7d2qZqxWapK495w/xKo0xI+vwMWPav96kp5OA30i5STV8i46WvPTht3U3c3jppol4g467DG+Hth+ZswcwRkp9pdlVKqAhroFyE7t4Bx09eQePAEb47udmnTxnm6gCAY9k8Y+T4cToT3+8OBlXZXpZQqhwZ6JWXlFnDX9DVsOHSSd8Z044aLmdDZG3UZDfcvhsCaMON6WPpPKC62uyqlVCka6JVw4lQ+d05bzZbDmUwe250hnerbXVLVqtcBJvwE7YbB4udh1gjIOmJ3VUopBw10J6Vl5zJ66iqSUrJ5747LuK5DrN0l2SM4HG792BpfPTkBpvSF7f+1uyqlFBroTjl88gy3vb+KQydO89G4HlzTrp7dJdlLBLrfBQ/8AuFxMOd2mP8E5J+2uzKlqjUN9AvYf+wUo95bybGcPGbd25N+LaPtLslzRLeC+36Avo9Awocw9UpITrS7KqWqLQ3089iRks2t76/kTEERn93fm8uaRNldkucJCIJBL8GdX0P+KfhwIPzwHBTm2V2ZUtWOBnoFNiWf5LapK/ET+HxCbzo2DLe7JM/W4mp4aCV0vR2WvQHvXwmH19ldlVLVigZ6OdbuP87tH6ymdlAAcx/oS6t6oXaX5B2Cw2HEZBj7b8jNhGkDYfGL2lpXqopooJfxy8507vxwNXXDgpg7sQ+N69S0uyTv0+paq7XeZTQsfR2m9IP9y+yuSimfp4FeyryNR7j347U0i67NFw/0oX54iN0leb/UBigAABCZSURBVK+QCLjxXbjjSyjKt05G+s9DcCrD7sqU8lka6A4fLd/Ho5+tp1vjSD5/oDfRtYPsLsk3tBwID62Cy5+ATZ/DO/GwfjYYY3dlSvmcah/oxhheX7iD57/dxqD29Zg5vidhwTXsLsu3BNaEgX+BB5ZCdGv45mH4aCgc3WR3ZUr5lGod6IVFxfzpq828s2Q3Y3rG8e7Y7gTX8Le7LN9Vr7011d0Nb0L6dmugr3mPQk663ZUp5ROqbaDnFhTx0CfrmLP2EI9c3ZK/juxEgH+13RxVx88PLhsHj66D3g/Chk/g7e6w4h0ozLe7OqW8mlMJJiKDRWSHiOwWkafKWf6EiGwTkU0islhEPHoOtswzBdz14RoWJaXy3A3t+f2gNoiI3WVVLyGRMPgVeHAlxPWE75+GKX1gx3fav67URbpgoIuIPzAZGAK0B8aISPsyq60H4o0xnYF/A6+5ulBXOXT8NLdMWcH6Qyd4a3Q3xvVrZndJ1VtMa+tImNvnWvc/u806IubganvrUsoLOdNC7wnsNsbsNcbkA3OAEaVXMMYsMcaUjMy0CvDI6Xs2HjrJyHdXkJqVy8zxvarPWObeoPUgq7U+9HU4tgumD4LPxkBakt2VKeU1nAn0hsChUveTHY9V5F7gf+UtEJEJIpIgIgnp6VW7I+z7rSncNnUlwTX8+OqhvvRpUadK3185ISAQet4Pj22Aq5+xTkZ6tw98/SCcPGh3dUp5PGcCvbzO5XI7OUXkDiAe+Ht5y40xU40x8caY+JiYGOervATGGD5cto8HZifSJjaMrx/qR8u6eiq/RwusBf2fhMc2Qp+HYcuX8FZ3+PZxOHHA7uqU8ljOBHoyUHo6+0bAOdPUiMhA4Gl
guDHGIwbvKCo2PDdvKy/O38Z17WOZc39vYkL1hCGvUTMKrnvZOiKm+53WCUlvd4dvJsHxvXZXp5THcSbQ1wKtRKSZiAQCo4F5pVcQkW7A+1hhnub6Misv80wB9368lo9XHuD+K5rx7tjuhATqMeZeKbwRDHvD6oqJHw+bvoC3462umGO77a5OKY8hxolDxERkKPAvwB+Ybox5WUReABKMMfNE5AegE3DU8ZSDxpjh53vN+Ph4k5CQcGnVV2Bveg73zUzgYMZpnh/RgbG9PPooSlVZWUdhxVuQMN0aybHdDdD3UYjrYXdlSrmdiCQaY+LLXeZMoLuDuwL9553pTPp0HTX8/Zgytju9muvOT5+VnQqr3oWEjyAvExr3sWZPaj3EOoFJKR9ULQK9ZOfnXxck0bpeKB/cFU9clA59Wy3kZVv96yvfhcyDUKeltTO1823WDlalfIjPB3puQRF//nozX607zJCOsbx+axdqBQW45LWVFykqhG3/sbpjjm60JtzoOhZ63Ad1WthdnVIu4dOBvv/YKSbOTmR7SjaPD2zFo1e3ws9PT+Ov1oyBAytg7TRImgfFhdYUeT3uh9bXgZ/uHFfe63yB7tXN2O+2pPDk3I34+wsf3dODq9rUtbsk5QlEoGk/65KdAokfQ+JHMGcMhDeG+HHQ5XYIq293pUq5lFe20AuLinlt4Q6m/rKXLo3CmTy2O40itb9cnUdRAexYAGs+gP1LQfyg5bXQ7Q5oPdg6S1UpL+BTLfS0rFwmfbqeNfuPc2fvJjwzrB1BAfovtLoA/xrQfoR1ydhjDdu74VP44k6oWQc6j4ZuY6FeB7srVeqieV0L/a3Fu5jy0x5evbkTI7qeb0gZpS6gqBD2LoH1s2D7AigugNjO0OkW6HizdUKTUh7Gp3aKFhYVc+jEGZpF6+FoyoVOZcDmL6yzUI+ssx5r3Bc63Qztb4Ra0fbWp5SDTwW6Um6XsQe2fAWb58KxHSD+0OIqq9XeZog1OYdSNtFAV+piGAOpW2Dzv62AzzwIfgHQ9HJoOwzaXg9hOqa+qloa6EpdKmPg8DrY/i0kzYeMXdbjDeOh3TAr4KNb2VujqhY00JVytfQdkPQtbJ8PR9Zbj0U1h1aDrMMhm/aDGiH21qh8kga6Uu6UmWwdJbPre+sY98JcCAiBZldY4d5qoBX2SrmABrpSVaXgDOxfDrsXWQFfMhFHVAtoPgCa9YemV0AtHQVUXRwNdKXskrEHdv9gXQ6sgPwc6/F6naxwb9YfmvSF4DB761ReQwNdKU9QVGD1t+/7Gfb9AgdXQ1GedVhkg27QuLd1iesNtatmzl3lfTTQlfJEBbmQvMYK931LrROaivKtZVEtHOHey5q4I7qVNeiYqvZ8aiwXpXxGjeBfu13Amk7vyAY4tAoOroId/7PGnAEIiYJGPayWfMPu0KC7tuLVOTTQlfIUAUHQuJd16feYdex7xm4r3A+ugsMJ1o5WHP9Vh8dZAX825LtZk3qoaksDXSlPJWJ1tUS3gu53Wo/l5VizMR1ZZ53odGSdNYlHiajmUK8jxHayRo6s1xEiGmt3TTWhga6UNwmq/evkHSVOH7d2th5ZB0c3WcMVJH3L2ZZ8UNiv4V6vgxX2MW2t11I+RQNdKW9XMwpaXmNdSuTlQFoSpG6G1K2QsgU2zoH87F/XCWtktf5j2kB0a8d1G2tkSW3ReyUNdKV8UVBtiOthXUoYAycPWOGevh2O7bQu62ZBwalf1wuJtII9upUV9FHNrUtkUwjUmcE8mQa6UtWFiBXKkU2tAcVKGANZh63xaY7t/PV653fW5B+lhdaHyGaOkG/qCPpmENVMhxX2ABroSlV3ItbsTOGNftttA3DmBBzfByf2WcMYHN9vXe9ZDBuO/nbd4AiIiLOOvgmPs14vIs6amDu8EdSuq105bqaBrpSqWEgkNIy0DossK/8UnNhvBf7xvdbtzGTret/S3/bXA/gHQXhDK+wj4iCsIYTGWq3+kut
aMeCncwRfLA10pdTFCazlOHqmgom1z5yEzENWyJ885Lh9yLq9axHkpHH2SJwS4ge16zkCvsG5gV87xgr9mtEQEOj2H9HbaKArpdwjJMK6xHYqf3lRIZxKg+yjkJ1y7vWJ/XBwJZw5Xv7zg8KtI3JqxZzn2hH+IZHg7/tx5/s/oVLKM/kHWFP4XWgav8I8R9CnwKl0x+UYnD726/3je+HQajidAaa4/NcJCnf8kYms3MWL/hPQQFdKebaAIIhsYl0upLjI2pFbEvol12eOW4+XvmQe+vV2RX8EwJqsJDgMgkKtk7SCQh33w0vdvsCywNpVskNYA10p5Tv8/B3dLdHOP6e42NqBWzbwz5yA0ycgLxNysyAvG/KyrNs5adbtkscuSKxQD6xlnSNw1Z+h480X/WNWxKlAF5HBwJuAPzDNGPNqmeVBwEzgMiADuM0Ys9+1pSqllBv4+VmDmgWHW8foV1bJH4S8bEfwO4I+N/PXPwD5OdZRQXnZ1u2QKJf/GOBEoIuIPzAZuBZIBtaKyDxjzLZSq90LnDDGtBSR0cDfgNvcUbBSSnmU0n8QbB7s0s+JdXoCu40xe40x+cAcYESZdUYAHztu/xu4RkTPIFBKqarkTKA3BA6Vup/seKzcdYwxhUAmoLPgKqVUFXIm0MtraZedt86ZdRCRCSKSICIJ6enpztSnlFLKSc4EejIQV+p+I+BIReuISABWT9I5ZwMYY6YaY+KNMfExMTp9llJKuZIzgb4WaCUizUQkEBgNzCuzzjzgbsftW4AfjV2zTyulVDV1waNcjDGFIjIJWIh12OJ0Y8xWEXkBSDDGzAM+BGaJyG6slvlodxatlFLqXE4dh26MWQAsKPPYs6Vu5wK3urY0pZRSleFMl4tSSikvIHZ1dYtIOnDgIp8eDRxzYTmuonVVjtZVeZ5am9ZVOZdSVxNjTLlHldgW6JdCRBKMMfF211GW1lU5WlfleWptWlfluKsu7XJRSikfoYGulFI+wlsDfardBVRA66ocravyPLU2raty3FKXV/ahK6WUOpe3ttCVUkqVoYGulFI+wqMDXUQGi8gOEdktIk+VszxIRD53LF8tIk2roKY4EVkiIkkislVEHitnnQEikikiGxyXZ8t7LTfUtl9ENjveM6Gc5SIibzm21yYR6V4FNbUptR02iEiWiDxeZp0q214iMl1E0kRkS6nHokRkkYjsclxHVvDcux3r7BKRu8tbx4U1/V1Etjs+p69FJKKC5573M3dTbc+JyOFSn9fQCp573t9fN9T1eama9ovIhgqe65ZtVlE2VOn3yxjjkRescWP2AM2BQGAj0L7MOg8B7zlujwY+r4K66gPdHbdDgZ3l1DUAmG/DNtsPRJ9n+VDgf1jDHfcGVtvwmaZgnRhhy/YC+gPdgS2lHnsNeMpx+yngb+U8LwrY67iOdNyOdGNNg4AAx+2/lVeTM5+5m2p7DviDE5/1eX9/XV1XmeX/AJ6tym1WUTZU5ffLk1voHjlTkjHmqDFmneN2NpDEuRN+eKoRwExjWQVEiEj9Knz/a4A9xpiLPUP4khljfuHcoZ1Lf48+Bm4s56nXAYuMMceNMSeARcBgd9VkjPneWJPFAKzCGra6ylWwvZzhzO+vW+pyZMAo4DNXvZ+TNVWUDVX2/fLkQPf4mZIcXTzdgNXlLO4jIhtF5H8i0qGKSjLA9yKSKCITylnuzDZ1p9FU/Etmx/YqUc8YcxSsX0qgbjnr2LntxmP9Z1WeC33m7jLJ0R00vYIuBDu31xVAqjFmVwXL3b7NymRDlX2/PDnQXTZTkjuISG3gS+BxY0xWmcXrsLoVugBvA/+pipqAfsaY7sAQ4GER6V9muZ3bKxAYDswtZ7Fd26sybNl2IvI0UAh8UsEqF/rM3WEK0ALoChzF6t4oy7bvGjCG87fO3brNLpANFT6tnMcqvb08OdBdNlOSq4lIDawP7BNjzFdllxtjsowxOY7bC4AaIhLt7rqMMUcc12nA11j
/9pbmzDZ1lyHAOmNMatkFdm2vUlJLup4c12nlrFPl286xY2wYMNY4OlrLcuIzdzljTKoxpsgYUwx8UMF72vJdc+TATcDnFa3jzm1WQTZU2ffLkwPdI2dKcvTPfQgkGWP+WcE6sSV9+SLSE2s7Z7i5rloiElpyG2un2pYyq80D7hJLbyCz5F/BKlBhq8mO7VVG6e/R3cA35ayzEBgkIpGOLoZBjsfcQkQGA/8HDDfGnK5gHWc+c3fUVnq/y8gK3tOZ3193GAhsN8Ykl7fQndvsPNlQdd8vV+/pdfFe46FYe4r3AE87HnsB60sOEIz1L/xuYA3QvApquhzrX6FNwAbHZSgwEZjoWGcSsBVrz/4qoG8V1NXc8X4bHe9dsr1K1yXAZMf23AzEV9HnWBMroMNLPWbL9sL6o3IUKMBqFd2Ltd9lMbDLcR3lWDcemFbqueMd37XdwD1urmk3Vp9qyXes5GiuBsCC833mVbC9Zjm+P5uwwqp+2doc98/5/XVnXY7HZ5R8r0qtWyXb7DzZUGXfLz31XymlfIQnd7kopZSqBA10pZTyERroSinlIzTQlVLKR2igK6WUj9BAV6oUEYkQkYfsrkOpi6GBrtRvRWCN4qmU19FAV+q3XgVaOMbK/rvdxShVGXpikVKlOEbJm2+M6WhzKUpVmrbQlVLKR2igK6WUj9BAV+q3srGmD1PK62igK1WKMSYDWC4iW3SnqPI2ulNUKaV8hLbQlVLKR2igK6WUj9BAV0opH6GBrpRSPkIDXSmlfIQGulJK+QgNdKWU8hH/H2JQYubc8okFAAAAAElFTkSuQmCC\n", 70 | "text/plain": [ 71 | "
" 72 | ] 73 | }, 74 | "metadata": { 75 | "needs_background": "light" 76 | }, 77 | "output_type": "display_data" 78 | } 79 | ], 80 | "source": [ 81 | "fig, ax = plt.subplots()\n", 82 | "ax.plot(T_rs, Y_rs[:,0])\n", 83 | "ax.plot(T_rs, Y_rs[:,1])\n", 84 | "ax.set_xlabel('t')\n", 85 | "\n", 86 | "plt.show()" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 13, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "number_of_samples = 100\n", 96 | "\n", 97 | "idx = np.random.permutation(Y.shape[0])\n", 98 | "X_train = torch.tensor(T_rs[idx, :][:number_of_samples], dtype=torch.float32, requires_grad=True)\n", 99 | "y_train = torch.tensor(Y_rs[idx, :][:number_of_samples], dtype=torch.float32)" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 14, 105 | "metadata": {}, 106 | "outputs": [ 107 | { 108 | "name": "stdout", 109 | "output_type": "stream", 110 | "text": [ 111 | "torch.Size([100, 1]) torch.Size([100, 2])\n" 112 | ] 113 | } 114 | ], 115 | "source": [ 116 | "print(X_train.shape, y_train.shape)" 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "metadata": {}, 122 | "source": [ 123 | "# Setup a custom library" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 15, 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "from torch.autograd import grad\n", 133 | "from itertools import combinations, product\n", 134 | "from functools import reduce" 135 | ] 136 | }, 137 | { 138 | "cell_type": "markdown", 139 | "metadata": {}, 140 | "source": [ 141 | "Here we show an example where we create a custom library. 
$\\theta$ in this case containe $[1,u,v, u*v]$ to showcase that non-linear terms can easily be added to the library" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 17, 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "def library_non_linear_ODE(input, poly_order, diff_order):\n", 151 | " \n", 152 | " prediction, data = input\n", 153 | " samples = prediction.shape[0]\n", 154 | " \n", 155 | " # Construct the theta matrix\n", 156 | " C = torch.ones_like(prediction[:,0]).view(samples, -1)\n", 157 | " u = prediction[:,0].view(samples, -1)\n", 158 | " v = prediction[:,1].view(samples, -1)\n", 159 | " theta = torch.cat((C, u, v, u*v),dim=1)\n", 160 | "\n", 161 | " # Construct a list of time_derivatives \n", 162 | " time_deriv_list = []\n", 163 | " for output in torch.arange(prediction.shape[1]):\n", 164 | " dy = grad(prediction[:,output], data, grad_outputs=torch.ones_like(prediction[:,output]), create_graph=True)[0]\n", 165 | " time_deriv = dy[:, 0:1]\n", 166 | " time_deriv_list.append(time_deriv)\n", 167 | " \n", 168 | " return time_deriv_list, theta\n" 169 | ] 170 | }, 171 | { 172 | "cell_type": "markdown", 173 | "metadata": {}, 174 | "source": [ 175 | "## Configuring DeepMoD" 176 | ] 177 | }, 178 | { 179 | "cell_type": "markdown", 180 | "metadata": {}, 181 | "source": [ 182 | "We now setup the options for DeepMoD. The setup requires the dimensions of the neural network, a library function and some args for the library function:" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": 30, 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "## Running DeepMoD\n", 192 | "config = {'n_in': 1, 'hidden_dims': [20,20,20,20,20], 'n_out': 2, 'library_function': library_non_linear_ODE, 'library_args':{'poly_order': 1, 'diff_order': 0}}" 193 | ] 194 | }, 195 | { 196 | "cell_type": "markdown", 197 | "metadata": {}, 198 | "source": [ 199 | "Now we instantiate the model. 
Note that the learning rate of the coefficient vector can typically be set up to an order of magnitude higher to speed up convergence without loss in accuracy" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": 31, 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "model = DeepMod(**config)\n", 209 | "optimizer = torch.optim.Adam([{'params': model.network_parameters(), 'lr':0.001}, {'params': model.coeff_vector(), 'lr':0.005}])" 210 | ] 211 | }, 212 | { 213 | "cell_type": "markdown", 214 | "metadata": {}, 215 | "source": [ 216 | "## Run DeepMoD " 217 | ] 218 | }, 219 | { 220 | "cell_type": "markdown", 221 | "metadata": {}, 222 | "source": [ 223 | "We can now run DeepMoD using all the options we have set and the training data. We need to slightly preprocess the input data for the derivatives:" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": 32, 229 | "metadata": {}, 230 | "outputs": [ 231 | { 232 | "name": "stdout", 233 | "output_type": "stream", 234 | "text": [ 235 | "| Iteration | Progress | Time remaining | Cost | MSE | Reg | L1 |\n", 236 | " 37200 74.40% 0s 2.02e-04 1.27e-06 2.50e-06 1.98e-04 " 237 | ] 238 | }, 239 | { 240 | "ename": "KeyboardInterrupt", 241 | "evalue": "", 242 | "output_type": "error", 243 | "traceback": [ 244 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 245 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 246 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtrain_deepmod\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m50000\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m'l1'\u001b[0m\u001b[0;34m:\u001b[0m 
\u001b[0;36m1e-4\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 247 | "\u001b[0;32m~/Documents/GitHub/New_DeepMod_Simple/DeePyMoD_torch/src/deepymod_torch/training.py\u001b[0m in \u001b[0;36mtrain_deepmod\u001b[0;34m(model, data, target, optimizer, max_iterations, loss_func_args)\u001b[0m\n\u001b[1;32m 67\u001b[0m \u001b[0;34m'''Performs full deepmod cycle: trains model, thresholds and trains again for unbiased estimate. Updates model in-place.'''\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 68\u001b[0m \u001b[0;31m# Train first cycle and get prediction\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 69\u001b[0;31m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_iterations\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss_func_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 70\u001b[0m \u001b[0mprediction\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtime_deriv_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msparse_theta_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcoeff_vector_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 71\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 248 | "\u001b[0;32m~/Documents/GitHub/New_DeepMod_Simple/DeePyMoD_torch/src/deepymod_torch/training.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(model, data, target, optimizer, max_iterations, loss_func_args)\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0;31m# Optimizer step\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 
33\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 34\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 35\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[0mboard\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 249 | "\u001b[0;32m~/opt/anaconda3/lib/python3.7/site-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph)\u001b[0m\n\u001b[1;32m 193\u001b[0m \u001b[0mproducts\u001b[0m\u001b[0;34m.\u001b[0m \u001b[0mDefaults\u001b[0m \u001b[0mto\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;31m`\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 194\u001b[0m \"\"\"\n\u001b[0;32m--> 195\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 196\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 250 | "\u001b[0;32m~/opt/anaconda3/lib/python3.7/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables)\u001b[0m\n\u001b[1;32m 91\u001b[0m \u001b[0mgrad_tensors\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgrad_tensors\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 93\u001b[0;31m \u001b[0mgrad_tensors\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_make_grads\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 94\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mretain_graph\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[0mretain_graph\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 251 | "\u001b[0;32m~/opt/anaconda3/lib/python3.7/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36m_make_grads\u001b[0;34m(outputs, grads)\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"grad can be implicitly created only for scalar 
outputs\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 35\u001b[0;31m \u001b[0mnew_grads\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mones_like\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmemory_format\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpreserve_format\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 36\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 37\u001b[0m \u001b[0mnew_grads\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 252 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 253 | ] 254 | } 255 | ], 256 | "source": [ 257 | "train_deepmod(model, X_train, y_train, optimizer, 50000, {'l1': 1e-4})" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": 21, 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "solution = model(X_train)[0].detach().numpy()" 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": 22, 272 | "metadata": {}, 273 | "outputs": [ 274 | { 275 | "data": { 276 | "text/plain": [ 277 | "array([17.80820748, 0.4 ])" 278 | ] 279 | }, 280 | "execution_count": 22, 281 | "metadata": {}, 282 | "output_type": "execute_result" 283 | } 284 | ], 285 | "source": [ 286 | "np.max(np.abs(Y),axis=0)" 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": 27, 292 | "metadata": {}, 293 | "outputs": [ 294 | { 295 | "name": "stdout", 296 | "output_type": "stream", 297 | "text": [ 298 | "Parameter containing:\n", 299 | "tensor([[-4.7004e-04],\n", 300 | " [ 3.1881e-04],\n", 301 | " [ 
7.4061e-04],\n", 302 | " [ 3.8119e-01]], requires_grad=True) Parameter containing:\n", 303 | "tensor([[-1.3845e-04],\n", 304 | " [-1.7569e-04],\n", 305 | " [-1.9776e-01],\n", 306 | " [-1.9382e-04]], requires_grad=True)\n" 307 | ] 308 | } 309 | ], 310 | "source": [ 311 | "print(model.fit.coeff_vector[0],model.fit.coeff_vector[1])" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": 28, 317 | "metadata": {}, 318 | "outputs": [ 319 | { 320 | "data": { 321 | "text/plain": [ 322 | "[]" 323 | ] 324 | }, 325 | "execution_count": 28, 326 | "metadata": {}, 327 | "output_type": "execute_result" 328 | }, 329 | { 330 | "data": { 331 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAdCElEQVR4nO3de3zU9Z3v8dcnkwEmSAmC10CKWopVqZemQNfT1q0uomvVta2CvWity9nTes52a9niqQ9F18fWynHbbtdtD7audaXeWpuihxbb3faxezxFQQPGqCiihUwQUUgACZDL5/wxM3Ey/OYSMvd5Px8PHpn5/X6T+Tx+mXz45vO9mbsjIiKVr67UAYiISH4ooYuIVAkldBGRKqGELiJSJZTQRUSqRH2p3njKlCk+ffr0Ur29iEhFeuaZZ95y96OCzpUsoU+fPp1169aV6u1FRCqSmf0x3TmVXEREqoQSuohIlVBCFxGpEkroIiJVQgldRKRKKKGLiFQJJXQRkSqhhC4iUiWyTiwys3uAi4A33f20gPMGfA+4ENgHXO3uz+Y7UBGRStDaFuWGR5+jt28QgDqDK+c0c9ulswr+3rnMFL0X+CfgvjTnLwBmxP/NAX4Q/yoiUpVa26Lc8lgHu/b1DR2LhOuoM+OdgwPDrh10uH/NFoCCJ/WsCd3d/8PMpme45BLgPo9tfbTGzBrN7Dh335anGEVEiq61Lcqy1Rvp6u7l+MYI0ydHWLN5FwNpdnlLtMjTWbFmCx1du+nq7uUb80/msrOm5j3mfKzl0gRsTXreGT92SEI3s0XAIoDm5uY8vLWISP4kkni0u3fY8Wh37yHHRsqBhjEhPjbjKI5vjIzqe6WTj4RuAccC/wtz9+XAcoCWlhZtZioiRZOcrENmDLgPfW1qjPCnJx/Fz5+J0ts3kP2bHYaQGSuunVuQ752Qj4TeCUxLej4V6MrD9xURGZXkJG6829JMlE0SX6PdvUN17kJZOGda9otGKR/DFlcCX7CYuUCP6uciUmqx0SbtQ6WSUpUE6gw+N7dMRrmY2QPAOcAUM+sEbgbCAO7+Q2AVsSGLm4gNW/xioYIVEUnH3dm++wAvvbGbjW/s4bu/faVg5ZMgkXAd48Ihuvf1cXxjhMXnz+TSM5uK9v6Q2yiXhVnOO/CVvEUkIsKhQwMbI2GWXnwql57ZxMH+QV7evofnoz28sG03L72xh41v7KGnty/Ld80suSyTevx9R49n8459Q7X3hXOmFaXVPRLmaYbgFFpLS4trxyIRubG1nQee2jpsOGBjJMyeA/0MDA7PT3VA06QI23
cf4OBAbJjg+DEhZh47gZnHvoeTj50Qe3zMBC76/v8d0ciUSDjEpz7UxO9e2jGs47SpRK3tdMzsGXdvCTpXsi3oRERubG0P7IzsTtPSHgS27znA1WdP57Smicxqmsh7j2ygru7QwXaLz5/JDY+2Dyu7JFrgQaNcyilpHy4ldBEpiEwlk4SfHsbIkr7+Qf7nhR/Iel3ifZInB1VD0s5ECV1E8iaofJLQ3dvH1x9ez7rXdzIIPPvHXWSeWxlsJJNyLj2zqaoTeColdBEZlXSzK4P0O9z/1BYmjKvnzOZJbHxjz4iGE4brjMXnzzz8YKucErqI5Ky1LcrSlR1pa9y5Wn/TPEJ1lraGDrHkPaa+bmixq6CSjQynhC4iGeUriSc0NUYIxTsxE8P+Uss01dJJWWxK6CKSVmtblMWPbKBvMH/Dm1NLJrddOqvsxnNXKiV0ETlkqdhE63jZ6o15Teafm9usVncBKaGL1LjEmieJ8drR7l6W/Pw5nn5t56iXjE1Q/bs4lNBFatyy1RsPWfNkf/8gP316C3UW23EnV0rcpaVNokVqWOeufRlb4Xd+5nTCAbMwU0XCdXz3ijNYf/M8JfMSUgtdpMZs3bmPVe3bWNW+jQ2dPWmva2qM8BdnTcXMDhnlMqkhzM2fVEu83Cihi1Sh1E7OL59zEoPAL57t5Nkt3QDMaprIN+afTDhk3PnEy8PKLpFwaGg0Sq3NtqxkSugiVSaok/Obrc8D8P5jjuBv58/kolnH0zy5Yeg1U44YW1NrnlQrJXSRKuLu3PZ/Xgjc2OGoI8ay+qsfw+zQmrha4dVBCV2kCux85yCPPtvJg2u38tbeg4HXvLX3QGAyl+qhhC5SIVLr4tfPez9HTxjHA2u38ETHG/QNOGc1N9IYCQdO0x/JKoVSmZTQRcpYuuVoo929XP/wBhxobAjz+bnTueLD05h57IRDaugwvJNTqpcSukiZyrQSIcR23pnUEOYPN5zLuHBo6HgtbuwgMUroImUitaTS1ZN92n33vr5hyTxBnZy1SQldpMSClqfNdQ0V1cUlmRK6SAkF1btzpbq4pFJCFymhoIWxcqENICSIErpIibh71tJKQ7iOA/3OgDshMxbOmabNICQtJXSRAkvt7Pz6vPfTMLae7/32lYyvi4RD/P1ls9QKl5wpoYsUUNC6Kl+Ljx+fPrmBhR+exi/aouzvHxz2Oq1mKIdDCV2kgIJq5Inx47/92sepD9Ux58TJGjMueaGELlIgmWrk3fv6qA/F9pfRmHHJFyV0kQJYs/ltvv3rl9Ke1/hxKQRtQSeSRx1dPVz9L0+zYPkatnXv54qWaYyrH/5rpvHjUihqoYschtSRK188ezrt0R5+ub6LiZEwN1xwMlf9yXTGhUN85CTVyKU4zH0EW3rnUUtLi69bt64k7y0yGje2trNizRZSf3PCIeMvP3oi//XjJzExEi5JbFL9zOwZd28JOqcWusgItLZFA5M5wOTxY/nb+ScXPSaRBNXQRUZg2eqNgckcYPvu/UWNRSRVTi10M5sPfA8IAT9y99tTzjcDPwEa49cscfdVeY5VpOiSN5ioAwYzXKuRK1JqWRO6mYWAu4A/AzqBtWa20t1fSLrsRuBhd/+BmZ0CrAKmFyBekaKIzfB8jt6+d1N4pmRuoJErUnK5tNBnA5vcfTOAmT0IXAIkJ3QH3hN/PBHoymeQIsWUrtMzHQM+O7dZI1ek5HJJ6E3A1qTnncCclGuWAk+Y2X8HxgPnBX0jM1sELAJobm4eaawiBZep0zNZU2NEwxCl7OSS0C3gWOrnfSFwr7vfaWYfAf7VzE5z92F/pbr7cmA5xIYtHk7AIoV0x69fyprMQ2Y8ueQTRYlHZCRyGeXSCUxLej6VQ0sqXwIeBnD3PwDjgCn5CFCkWNo7e+jqyT5SZeGcaVmvESmFXBL6WmCGmZ1gZmOABcDKlGu2AOcCmNkHiCX0HfkMVKRQ9vcNcPuvXuLSf36SuqC/R+PqDD43t1
kbTEjZylpycfd+M7sOWE1sSOI97t5hZrcC69x9JXA9cLeZ/Q2xcszVXqopqCIZtLZFueWxDnbti23IfMSYEJEx9ezYe4CFs6cxq2kif/f4i8OWvE10eiqRS7nLaRx6fEz5qpRjNyU9fgE4O7+hieRXa1uUxT/bQN/Au22NvQcH2HtwgC+fc9LQLM+GMfVae0Uqkqb+S81YtnrjsGSe7Jfru4YSutYnl0qlqf9SMzJtyNyVZbNmkUqgFrpUpdTlba/48DTq64z+weAWuqbtSzVQQpeqE7Qx8z/85mUaxoTw/gEGUubwh+tM0/alKqjkIlUnaGNmgInjwtz5mTOY1PDuWuWNkTDLPnO6auZSFdRCl6qTrlb+xu796vCUqqYWulSVvQf6iYRDgedUJ5dqp4QuVaO9s4eL/vE/2d8/QH3KlE9tzCy1QCUXqViJkSzR7l4mjguz92A/R08Yy0OLPkJXd68mB0nNUUKXipQ6kqVnfx91Btf96fuYfcKRAErgUnNUcpGKFDSSZdDhn3//aokiEik9JXSpOO6ediSLZnxKLVNCl4rSe3CA6x/ZkPa8RrJILVMNXcpa8hT+oyeMJVRnbNu9n/mnHsvvN77J/v53p31qJIvUOrXQpWwlOj6j3b04sH3PAbp69rPooyfyw89/iNs/9UGaGiMYsT0+v3XZLHWESk1TC13KVrop/I8/t40bLvyAZn2KpFBCl7KSXGJJt+WVOj5FgimhS9lIHVuejjo+RYIpoUtZaG2Lcv3DGxjIshWtOj5F0lNCl5JLtMwzJXMDTeEXyUIJXUouXednQlNjhCeXfKKIEYlUJiV0KYlcOj9BJRaRkVBCl6LLtfMzZKax5SIjoIQuRTWSzk8lc5GRUUKXolHnp0hhKaFL0ajzU6SwtJaLFE26JW9BnZ8i+aAWuhRM8kiW4yaOIxwy+gYOLbeo81MkP5TQpSBSR7J09ewHoM5iOwslqPNTJH+U0KUgbnmsI7BePjESpmFMvTZvFikAJXTJuxtb29m1ry/wXPe+PtpumlfkiERqgzpFJa9a26KsWLMl7XmtlChSOGqhS14kOkAzjWQBNJJFpICU0GXUbmxtZ8WaLRnXZAFojIRVLxcpIJVcZFQSJZZsydyApRefWoyQRGqWEroctsS6LLkk88/ObVbrXKTAckroZjbfzDaa2SYzW5LmmsvN7AUz6zCzn+Y3TCk3N7a28zcPrc+6yFZTY4TvXHEGt106q0iRidSurDV0MwsBdwF/BnQCa81spbu/kHTNDOAG4Gx332VmRxcqYCm9XMosBnznijPUKhcpolxa6LOBTe6+2d0PAg8Cl6Rc85fAXe6+C8Dd38xvmFJOlq3emDWZq8QiUny5jHJpArYmPe8E5qRc834AM3sSCAFL3f3Xqd/IzBYBiwCam5sPJ14poVyGJobMuPPy05XMRUogl4RuAcdSG2j1wAzgHGAq8J9mdpq7dw97kftyYDlAS0tLtr40KSO57DJkoGQuUkK5lFw6gWlJz6cCXQHX/NLd+9z9NWAjsQQvVSLbWuYqs4iUXi4t9LXADDM7AYgCC4ArU65pBRYC95rZFGIlmM35DFSKL9eNnJu0yJZIWcia0N2938yuA1YTq4/f4+4dZnYrsM7dV8bPzTOzF4ABYLG7v13IwKWwWtuiLP7ZhsD1y5NplyGR8pHT1H93XwWsSjl2U9JjB74W/ydV4JbHOrImc+0yJFJetJaLDJMos6Rb/ha0kbNIuVJClyG5jGQBeO32Py9SRCIyElrLRYZkG8kCsRUTRaQ8KaHLkK4sa5mH60wrJoqUMZVchBtb23ngqa0amihS4ZTQa9xn7/4DT766M+35SDjEty6bpUQuUgGU0GvYja3tGZO5WuUilUUJvUZl28wZ0IQhkQqjTtEalW0J3JAFrckmIuVMLfQak8sSuAAL50zLeF5Eyo8Seg3JdeLQ2ScdqS3jRCqQEnoNyXUJXCVzkcqkhF5DMk0c0ogWkcqnhF
5DJh8xhrf2HjzkuJbAFakOGuVSI56P9rD3QP8h+wlqCVyR6qGEXgM2vbmXL9zzNJPHj+XmT55CU2MEI9Yy1yxQkeqhkkuVi3b38vkfP0WdGfdfO4cTpozn6rNPKHVYIlIASuhVKHmseX2dUR8yHv1vZ3PClPGlDk1ECkgllyqTGGuemDjUP+gMDsLL2/eUODIRKTQl9CoTNNb84MAgy1ZvLFFEIlIsSuhVJt2U/mybV4hI5VNCryIDg04kHAo8d3xjpMjRiEixqVO0wiV3gI4fE6K3b4D6OqN/8N21FDXWXKQ2qIVewVI7QN85GEvmC2ZP01hzkRqkFnoFC+oA7R90fvfSDk3lF6lBaqFXsHQdneoAFalNaqFXmETNvKu7F4PAXYfUASpSm5TQK0jqBhVByVwdoCK1Swm9gqTboCJkxqA7x2tNc5GapoReIVrbomknDQ2689rtf17kiESk3KhTtAIkSi3pqGYuIqCEXhEy7QWqmrmIJCihV4B0pRZAk4ZEZIgSegWYMDa4q6OpMaJkLiJDlNDL3L1PvsaeA/2EbPhuoCq1iEgqJfQy9vhzXdzy+Auc94FjuOPTH9T6LCKSUU7DFs1sPvA9IAT8yN1vT3Pdp4FHgA+7+7q8RVmD/t+rb/G1hzbwoeZJ/NOVZzIuHOJTH5pa6rBEpIxlbaGbWQi4C7gAOAVYaGanBFw3AfgfwFP5DrLWdHT1sOi+Z3jv5AZ+dFUL49KscS4ikiyXkstsYJO7b3b3g8CDwCUB1/0dcAewP4/x1ZytO/dx9b+sZcK4en5yzWwaG8aUOiQRqRC5JPQmYGvS8874sSFmdiYwzd0fz/SNzGyRma0zs3U7duwYcbDV7u29B/jCPU9zsH+Q+66ZrQlDIjIiuSR0Czg2tC6UmdUB3wGuz/aN3H25u7e4e8tRRx2Ve5Q14J0D/Vxz71q6unv58VUtzDhmQqlDEpEKk0tC7wSmJT2fCnQlPZ8AnAb83sxeB+YCK82sJV9BVru+gUG+8tNnaY/28P2FZ9Iy/chShyQiFSiXhL4WmGFmJ5jZGGABsDJx0t173H2Ku0939+nAGuBijXLJjbtzw6Pt/H7jDm67dBbzTj221CGJSIXKOmzR3fvN7DpgNbFhi/e4e4eZ3Qqsc/eVmb+DpErepGL82Hr2Hujnr8+dwZVzmksdmohUsJzGobv7KmBVyrGb0lx7zujDql6pm1Tsjc8CnT65ocSRiUil00zRIgtaOXHAnf/1xMslikhEqoU2uCiSRJkl3cqJ2thZREZLCb0IUsssQTTmXERGSyWXIsi0QQVo5UQRyQ+10IsgUzmlSRs7i0ieKKEXUKJu7mnONzVGeHLJJ4oak4hULyX0AslWN1eZRUTyTQm9QDLVzVVmEZFCUEIvkHR1cwOVWUSkIJTQ8yxb3VzDE0WkUJTQ80h1cxEpJSX0PMg2CxRUNxeRwlNCH6VcZoGqbi4ixaCZoqOUbRYoqG4uIsWhhD5K2RbVUt1cRIpFJZfDlG00C6huLiLFpYR+GHIZzfKty2YpkYtIUSmhj1BrW5TrH97AgAe3zdUqF5FSUUIfgUTLPF0y12gWESkldYqOwC2PdWiTChEpW0roOWpti7JrX1/a8xrNIiKlpoSeg0TdPJ2QmTpBRaTklNCzyFY3B7jz8tOVzEWk5NQpmkYu67MANEbCSuYiUhaU0APc2NrO/Wu2ZL0uEg6x9OJTixCRiEh2KrmkaG2L5pTMVTcXkXKjhJ5i2eqNWa+JhEOqm4tI2VHJJUW2xbY0E1REypUSeoqjJ4xl+54DgecmNYQ1E1REypZKLklef+sd9vcPBp4L1Rk3f1IdoCJSvpTQ4/749jssvHsNdQaL582kMRIeOjepIcydn1HNXETKm0ouwJa397Fw+Rp6+wZYce0cTj1+Il/5xPtKHZaIyIjUfAt96859LLx7De8cHOD+L8WSuYhIJarphN65ax8Llq9h74F+Vlw7h9OalMxFpHLVVMklMZ
2/q7uXo98zlr4Bp39gkBXXzlUyF5GKl1ML3czmm9lGM9tkZksCzn/NzF4ws+fM7N/M7L35D3V0EotsRbt7cWD77gPsfOcg1370RGZNVTIXkcqXNaGbWQi4C7gAOAVYaGanpFzWBrS4+weBnwF35DvQ0Vq2emPg5hQPrd1agmhERPIvlxb6bGCTu29294PAg8AlyRe4++/cfV/86Rpgan7DHL10M0CzzQwVEakUuST0JiC5GdsZP5bOl4BfBZ0ws0Vmts7M1u3YsSP3KEehtS3K2bf/O+lWM9e2cSJSLXLpFLWAY4H50cw+B7QAHw867+7LgeUALS0t6XeMyIPWtii3PNahbeNEpGbkktA7gWlJz6cCXakXmdl5wDeBj7t78GIoRZLoAM20obMW2RKRapNLQl8LzDCzE4AosAC4MvkCMzsT+N/AfHd/M+9RjtAtj3VkTOYGWmRLRKpO1hq6u/cD1wGrgReBh929w8xuNbOL45ctA44AHjGz9Wa2smARZ9DaFuWMW57IWGYB1c1FpDrlNLHI3VcBq1KO3ZT0+Lw8xzViuZRZQHVzEaleVTFTtLUtyvUPb2DAM/ezNkbCLL34VNXNRaQqVXxCT7TMc0nm62+eV6SoRESKr+IX50o3AzRZJBxi6cXanEJEqltFttBb26IsXdlBd2/mzk+IbU5x8ydVZhGR6ldxCb21LcriRzbQN5i5xBIy487LtcuQiNSOikrouXZ+RsIhvnXZLCVzEakpFZPQc+381AxQEalVFZPQc+n8bGqMaAaoiNSsihnlkm2Z23DINGFIRGpaxST0TNP1JzWEWfZpdYCKSG2rmIS++PyZRMKhYcci4RDfveIM2m6ap2QuIjWvYmroiYSd2OT5eHV+iogMUzEJHWJJXQlcRCRYxZRcREQkMyV0EZEqoYQuIlIllNBFRKqEErqISJVQQhcRqRJK6CIiVUIJXUSkSphnWY62YG9stgP442G+fArwVh7DyRfFNXLlGpviGhnFNTKjieu97n5U0ImSJfTRMLN17t5S6jhSKa6RK9fYFNfIKK6RKVRcKrmIiFQJJXQRkSpRqQl9eakDSENxjVy5xqa4RkZxjUxB4qrIGrqIiByqUlvoIiKSQgldRKRKlHVCN7P5ZrbRzDaZ2ZKA82PN7KH4+afMbHoRYppmZr8zsxfNrMPM/jrgmnPMrMfM1sf/3VTouOLv+7qZtcffc13AeTOzf4zfr+fM7KwixDQz6T6sN7PdZvbVlGuKdr/M7B4ze9PMnk86dqSZ/cbMXol/nZTmtVfFr3nFzK4qQlzLzOyl+M/qF2bWmOa1GX/uBYhrqZlFk35eF6Z5bcbf3wLE9VBSTK+b2fo0ry3I/UqXG4r6+XL3svwHhIBXgROBMcAG4JSUa74M/DD+eAHwUBHiOg44K/54AvByQFznAI+X4J69DkzJcP5C4FeAAXOBp0rwM32D2MSIktwv4GPAWcDzScfuAJbEHy8Bvh3wuiOBzfGvk+KPJxU4rnlAffzxt4PiyuXnXoC4lgJfz+FnnfH3N99xpZy/E7ipmPcrXW4o5uernFvos4FN7r7Z3Q8CDwKXpFxzCfCT+OOfAeeamRUyKHff5u7Pxh/vAV4EKmVfvEuA+zxmDdBoZscV8f3PBV5198OdITxq7v4fwM6Uw8mfo58Alwa89HzgN+6+0913Ab8B5hcyLnd/wt3740/XAFPz9X6jiStHufz+FiSueA64HHggX++XY0zpckPRPl/lnNCbgK1Jzzs5NHEOXRP/4PcAk4sSHRAv8ZwJPBVw+iNmtsHMfmVmpxYpJAeeMLNnzGxRwPlc7mkhLSD9L1kp7lfCMe6+DWK/lMDRAdeU+t5dQ+yvqyDZfu6FcF28FHRPmhJCKe/XR4Ht7v5KmvMFv18puaFon69yTuhBLe3UMZa5XFMQZnYE8HPgq+6+O+X0s8TKCqcD3wdaixETcLa7nwVcAHzFzD6Wcr6U92sMcDHwSMDpUt2vkSjlvf
sm0A+sSHNJtp97vv0AOAk4A9hGrLyRqmT3C1hI5tZ5Qe9XltyQ9mUBx0Z8v8o5oXcC05KeTwW60l1jZvXARA7vz8MRMbMwsR/YCnd/NPW8u+92973xx6uAsJlNKXRc7t4V//om8Atif/Ymy+WeFsoFwLPuvj31RKnuV5LtidJT/OubAdeU5N7FO8cuAj7r8WJrqhx+7nnl7tvdfcDdB4G707xfqe5XPXAZ8FC6awp5v9LkhqJ9vso5oa8FZpjZCfHW3QJgZco1K4FEb/CngX9P96HPl3h97sfAi+7+D2muOTZRyzez2cTu89sFjmu8mU1IPCbWofZ8ymUrgS9YzFygJ/GnYBGkbTWV4n6lSP4cXQX8MuCa1cA8M5sULzHMix8rGDObD3wDuNjd96W5Jpefe77jSu53+Ys075fL728hnAe85O6dQScLeb8y5Ibifb7y3dOb517jC4n1FL8KfDN+7FZiH3CAccT+hN8EPA2cWISY/guxP4WeA9bH/10I/BXwV/FrrgM6iPXsrwH+pAhxnRh/vw3x907cr+S4DLgrfj/bgZYi/RwbiCXoiUnHSnK/iP2nsg3oI9Yq+hKxfpd/A16Jfz0yfm0L8KOk114T/6xtAr5YhLg2EaurJj5niRFdxwOrMv3cCxzXv8Y/P88RS1bHpcYVf37I728h44ofvzfxuUq6tij3K0NuKNrnS1P/RUSqRDmXXEREZASU0EVEqoQSuohIlVBCFxGpEkroIiJVQgldRKRKKKGLiFSJ/w+UXzy3UobDvgAAAABJRU5ErkJggg==\n", 332 | "text/plain": [ 333 | "
" 334 | ] 335 | }, 336 | "metadata": { 337 | "needs_background": "light" 338 | }, 339 | "output_type": "display_data" 340 | } 341 | ], 342 | "source": [ 343 | "plt.scatter(X_train.detach().numpy().squeeze(),solution[:,0])\n", 344 | "plt.plot(T_rs,Y_rs[:,0])" 345 | ] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "execution_count": 29, 350 | "metadata": {}, 351 | "outputs": [ 352 | { 353 | "data": { 354 | "text/plain": [ 355 | "[]" 356 | ] 357 | }, 358 | "execution_count": 29, 359 | "metadata": {}, 360 | "output_type": "execute_result" 361 | }, 362 | { 363 | "data": { 364 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAerUlEQVR4nO3dfXTc1X3n8fdX0tiSH7AAO4CFjR1iHOw4RkQH2Ho3TSDBBhpQfRIekmzSlFMOW9gESJzaG5ZQygYnPtCmWbYtDWweoImBJYpbIE4apyeBxAQ5tnEMGIR5kgzYEMv4QZZH0nf/mBkzGv1+MyNpHn/zeZ2j45n53Zn58pvRl6vvvb97zd0REZHqV1fuAEREpDCU0EVEIkIJXUQkIpTQRUQiQgldRCQiGsr1xtOnT/c5c+aU6+1FRKrSpk2b3nT3GUHHypbQ58yZQ2dnZ7neXkSkKpnZy2HHVHIREYkIJXQRkYhQQhcRiQgldBGRiFBCFxGJCCV0EZGIUEIXEYkIJXQRkYjImdDN7B4z221mvw85bmb292bWZWZPmdmZhQ8zoWNzD0tWb2DuyodZsnoDHZt7ivVWIiJVJ58e+neAZVmOXwDMS/5cBfzD+MMaqWNzD6se2kZPbx8O9PT2seqhbUrqIiJJORO6u/8S+EOWJpcA3/OEjUCzmZ1UqABT1qzfQV98cNhjffFB1qzfUei3EhGpSoWoobcAr6bd704+NoKZXWVmnWbWuWfPnlG9ya7evlE9LiJSawqR0C3gscCNSt39Lndvc/e2GTMCFwsLNbO5KfDxOjOVXUREKExC7wZmpd0/GdhVgNcdZsXS+TTF6kc8PuiuWrqICIVJ6OuAzyRnu5wD7HP31wrwusO0t7Zw2/JF1NvIPwj64oPcvG57od9SRKSq5DNt8QfAb4D5ZtZtZlea2dVmdnWyySPATqAL+GfgL4sVbHtrC0MeWM2hty+uXrqI1LScG1y4+xU5jjtwTcEiymFmcxM9IQOha9bvoL01cDxWRCTyqu5K0RVL54ce04wXEallVZfQ21tbOHZSLPBY2EwYEZFaUHUJHeCrH1s4YsZLU6w+a+9dRCTqyrZJ9Hik6uRr1u84Wk+//qPzVD8XkZpmHjJrpNja2tq8s7Nz3K+zZ38/Z3/t32maUM+h/kFmNjexYul8JXcRiSQz2+TubUHHqrKHnu7xrjcxjIP9iXVeUot2AUrqIlJTqrKGnm7N+h0MZvyVoUW7RKQWVX1C16JdIiIJVZ/Qw6YqagqjiNSaqk/oYYt2HToyoKUARKSmVH1CTy3a1dw0/GKjvYfiWoVRRGpK1Sd0SCT1yRNHTtjR4KiI1JJIJHTQ4KiISGQSunY0EpFaF5mErh2NRKTWRSah59r
RSLV0EYm6yCR0yL6jkWrpIhJ1kUroEF5Ln9YUvIa6iEhURC6hr1g6n1jdyLLLQV1oJCIRF7mE3t7awpTGkXPS44POF+/fqqQuIpEVuYQO0HsoHvi4ZryISJRFMqFnW5hLM15EJKoimdDD5qSnaMaLiERRJBN6tjnpoKV1RSSaIpnQIZHUb7908YieelOsnhVL55cpKhGR4qn6PUWzSe0p+o2fPMuufYeZ2FDHbcsXaa9REYmkyPbQU9pbW/j1qvP48rL59A8MceqMKeUOSUSkKCKf0FM+fc4pNMbquPSffsPclQ+zZPUGTV8UkUiJdMkl3YZndjMw6AwMDQHQ09vHqoe2AagEIyKRUDM99DXrdzAwNHzhLs1JF5EoySuhm9kyM9thZl1mtjLg+Gwz+4WZbTazp8zswsKHOj7a0UhEoi5nQjezeuBO4AJgAXCFmS3IaHYjcL+7twKXA/+n0IGOV9jcc81JF5GoyKeHfhbQ5e473f0I8EPgkow2DhyTvD0N2FW4EAsj6OrRxoY6zUkXkcjIZ1C0BXg17X43cHZGm5uBn5rZfwcmAx8JeiEzuwq4CmD27NmjjXVcUgOfa9bvoCdZZvngaTM0ICoikZFPDz3o+vnMbYGuAL7j7icDFwLfN7MRr+3ud7l7m7u3zZgxY/TRjlN7awuPrzyXl1ZfxIWLTuTXL7xF76EjJY9DRKQY8kno3cCstPsnM7KkciVwP4C7/wZoBKYXIsBi+fx58zjQP8A9j71Y7lBERAoin4T+JDDPzOaa2QQSg57rMtq8ApwHYGank0joewoZaKG998RjuOB9J/J/H3+JfSHrp4uIVJOcCd3dB4BrgfXAMyRms2w3s1vM7OJksy8Cf2FmW4EfAH/mHrJbcwX5/Hnz2N8/wN2Pq5cuItUvrytF3f0R4JGMx25Ku/00sKSwoRXf6Scle+mPvciVS+YybZI2khaR6lUzV4qGSfXS71EvXUSqXM0n9NNPOoZlC0/knsdfZF+faukiUr1qPqFDspd+eIC7f7Wz3KGIiIyZEjqwYOYxXLjoRO5+7EXeOtBf7nBERMZECT3pho/Opy8+yJ2/eKHcoYiIjIkSetJ73jWFj3/gZO7d+PLRpQFERKqJEnqaL3zkNAC++e/PlTkSEZHRU0JP09LcxKfPOYUHN3XTtftAucMRERmVmtmCLl/XfPhU7nviZf7kW7+iPz7EzOYmViydr1UZRaTiqYee4VfPv8ngkHM4PoST2Hv0+rVbuLFjW7lDExHJSgk9Q9Deow7cu/EVOjb3lCcoEZE8KKFnyLbH6F//6/YSRiIiMjpK6Bmy7TG6V8vsikgFU0LPoD1GRaRaKaFnaG9tYVIs/LQsWb1BtXQRqUhK6AG+tvz9xOqCtlJNzHpZ9dA2JXURqThK6AHaW1tY84nFtITU0/vig6xZv6PEUYmIZKeEHqK9tYXHV55LcD890VNXL11EKokSeg7ZZr2o9CIilUQJPYcVS+fTFKsPPKbSi4hUEq3lkkNqDZfr1m4JPJ7tQiQRkVJSDz0P7a0toQOk05piJY5GRCSYEnqeViydHziV8eCRAdXRRaQiKKHnqb21hSmNIytU8UHni/dvVVIXkbJTQh+F3pC1XAbdNeNFRMpOCX0Usk1h1IwXESk3JfRRyDaFETTjRUTKSwl9FNpbW7ht+SLqLfj60Ww9eBGRYlNCH6X21hZuv3TxiJ76xPo6Lb0rImWlhD4GqZ56am66GcSHhrhu7RYtrysiZZNXQjezZWa2w8y6zGxlSJtLzexpM9tuZv9S2DArT2rxrr+77AxidXWktiHV8roiUi45E7qZ1QN3AhcAC4ArzGxBRpt5wCpgibsvBK4rQqwVac36HRwZHBr2mGa8iEg55NNDPwvocved7n4E+CFwSUabvwDudPe9AO6+u7BhVq6wmS09vX0qv4hISeWT0FuAV9PudycfS3cacJqZPW5mG81sWdALmdlVZtZpZp179uwZW8QVJtvMFpV
fRKSU8knoQXP0PON+AzAP+BBwBfBtM2se8ST3u9y9zd3bZsyYMdpYK1Kuuekqv4hIqeST0LuBWWn3TwZ2BbT5sbvH3f1FYAeJBB95mTNeguiCIxEphXwS+pPAPDOba2YTgMuBdRltOoAPA5jZdBIlmJ2FDLSSpWa8hCV1XXAkIqWQM6G7+wBwLbAeeAa43923m9ktZnZxstl64C0zexr4BbDC3d8qVtCVKqz8ogFSESkFc88sh5dGW1ubd3Z2luW9i6ljcw9r1u+gJ6DM0hSr57bli47ugiQiMlpmtsnd24KO6UrRAstWftEAqYgUkxJ6kYQNhGqAVESKRQm9SMIGQjVAKiLFooReJBogFZFSG7lJphREauAzaIA0dQVpejsRkfFSD72INEAqIqWkhF4CGiAVkVJQQi+BsIFQB9XTRaRglNBLINsCXlqRUUQKRQm9BHIt4KV6uogUghJ6iaQGSIPWIgbV00Vk/JTQS0wXHIlIsSihl1hYPf1zS+aUPhgRiRQl9BJLr6cbcMLUiTTF6vnR5h4OxwfLHZ6IVDFdKVoG7a0tw64Q3fDsG1z53U5WPbSNOy5djFlYpV1EJJx66BXg3PeewA0fOY0fbe7hO79+qdzhiEiVUkKvENd8+D2cv+AEbn34GTburLnNnkSkAJTQK0RdnXH7pYuZc/wkrrnvd5rGKCKjpoReQaY2xrjrM230Dwxx9b2bNEgqIqOihF5hTp0xhTsuXcxT3ftY9dA2yrXnq4hUHyX0CnT+whP50vmJQdJvbegqdzgiUiU0bbFCXfPh97DzzYPc8bPnmDN9MhcvnlnukESkwimhVygz47bli+j+Qx9femArL+w+wIObutnV28fM5iZWLJ2v3Y5EZBiVXCrYxIZ6/vG/foBjGhv45s+fp6e3D0dL7opIMCX0Cnfc5AnUBVw5qiV3RSSTEnoV2LO/P/Dxnt4+9dJF5Cgl9CqQbWldlV5EJEUJvQpk28JOpRcRSdEslyqQms1y3dotgce1TICIgHroVaO9tSV0T9I6M+aufJglqzeo/CJSw/JK6Ga2zMx2mFmXma3M0u7jZuZm1la4ECUlrPQy6K7pjCKSO6GbWT1wJ3ABsAC4wswWBLSbCnweeKLQQUpC5m5HQR+eauoitSufHvpZQJe773T3I8APgUsC2v0N8A3gcAHjkwztrS08vvJcXlx9EWHLdqmmLlKb8knoLcCrafe7k48dZWatwCx3/7dsL2RmV5lZp5l17tmzZ9TBynBh0xmzTXMUkejKJ6EHbXB5tHNoZnXA3wJfzPVC7n6Xu7e5e9uMGTPyj1IChdXUF86cypLVGzRQKlJj8pm22A3MSrt/MrAr7f5U4H3AfyQ3Nz4RWGdmF7t7Z6EClZFS0xnXrN/Brt4+TjimkYNHBvjp07uPtkkNlKa3F5FoslwbKJhZA/AccB7QAzwJfNLdt4e0/w/gS7mSeVtbm3d2Kt8X2jlf+zmvvz1yGKOluYnHV55bhohEpJDMbJO7B84kzFlycfcB4FpgPfAMcL+7bzezW8zs4sKGKuP1RkAyBw2UitSCvK4UdfdHgEcyHrsppO2Hxh+WjNXM5iZ6ApK3BkpFok9XikZM2EDp0oUnlCEaESklJfSIybz46KRjGjnl+El89zcv8+Cm7nKHJyJFpMW5Iqi9tWXYjJYD/QNc/f1NfOmBrfzyuT1senmvtrITiSD10GvAlIkN3P1nbbTOambd1l3ayk4kopTQa8TEhvrAGTBa+0UkOpTQa8hr+4KnNPb09umKUpEIUEKvIdmmLvb09nHd2i2c/j8fVWIXqVJK6DUk21Z2KX3xIW5Yu0VJXaQKKaHXkPQpjdkMATevC1zZQUQqmBJ6jUmtp54rqff2xVVXF6kySug1asXS+YHrIqfTtEaR6qKEXqPaW1v41Dmzc7briw9y3dot6q2LVAEl9Bp2a/siPp1HUodEb33FA1uV1EUqmBJ6jbu1fRF/d9kZOWvqAPEh12CpSAX
LucFFsWiDi8rTsbmHVQ9toy8+mLWdgdaBESmTbBtcaHEuOSp9S7ugNdVT0teBSX+eiJSXSi4yTGpa47GTYjnbasBUpLIooUugr35sIbH6XBMbE3p6+7h+7RZu7NhW5KhEJBsldAnU3trCmo8vPrpRRr1lT+4O3LfxFfXURcpIg6KSl3wHTAGOnRTjqx9bqNq6SBFoUFTGLd8BU4C9h+KseHDrsOeJSPGphy6j1rG5h+vXbiGfb06LpjeKFFS2Hrpq6DJqqWUD8hky1YCpSOkoocuY3Nq+iL+97Iycg6WQGDC9d+MrSuoiRaaELmPW3trC7ZcuJlaX3/RGzYIRKS4ldBmX9tYW1nxiMc1NuS9EctCG1CJFpEFRKah8BkxnTmvktX2HtR6MyBhoUFRKJp911nftO3x0PZjr1m6h9ZafqhQjUgBK6FJwqXXWMyvrYZX2vYfi2hlJpABUcpGi6djcw5r1O9jV28fM5qacFyTVmzHkrlKMSBbZSi5K6FIyS1ZvyJnU0zU3xbj5Yi0hIJJu3DV0M1tmZjvMrMvMVgYcv8HMnjazp8zs52Z2yniDluhZsXQ+TbH6vNv39qkUIzIaOXvoZlYPPAd8FOgGngSucPen09p8GHjC3Q+Z2X8DPuTul2V7XfXQa1PH5h5uXred3r543s9RKUbkHePtoZ8FdLn7Tnc/AvwQuCS9gbv/wt0PJe9uBE4eT8ASXe2tLWz56vlH9zHNZ2neQfdhuySpxy4SLJ+E3gK8mna/O/lYmCuBR4MOmNlVZtZpZp179uzJP0qJnNTOSC+uvojbL12cdymmLz6oi5NEQuSzfG5Q9ymwTmNmnwbagD8OOu7udwF3QaLkkmeMEnGpEspf/+t29h7KXYrp6e3D3fnxll3DZtGoHCO1Lp+E3g3MSrt/MrArs5GZfQT4CvDH7t5fmPCkVrS3ttDe2jJsqmOdGYMhYzxLVm9gz4F+4oOJ49q0WiS/ksuTwDwzm2tmE4DLgXXpDcysFfgn4GJ33134MKVW5CrFNDbU8cmzZrN7/zvJPEXlGKl1OXvo7j5gZtcC64F64B53325mtwCd7r4OWANMAR6wxADXK+5+cRHjlhqQvktSZlnlX377SuBzUvPcMy9qUjlGaoEuLJKqlO0ipTNnN7O95236B4eOPtYUq+e25YuU1KXqaXEuiZygi5QmNtTxx6fNYPMrvcOSOagcI7VBm0RLVcpWjpmz8uHA5/T09tGxuWfYbBotLyBRooQuVSs1MyZTS5aFwDLXau/ti7Piga1HX0+kmqnkIpETVI5pbKijsaEu8AKK+JCrHCORoB66RE5YOeb6tVtCn9PT20dPbx8tzU0qy0jV0iwXqRn5LN/7nndN4cU9BxgM+LWoM/jk2bO5tX1RkSIUyU2zXERIlGJi9SNXsojVGTdedDo3fPQ0XnzzYGAyBxhyuHfjK9zYsa3IkYqMjUouUjOC1ozJLKfc8bPncr7OfRtf4eGnXlNJRiqOErrUlLCZMSnZZsikOAxbREwzZaRSqOQikiasLJNLfMj5xk+eLUJEIvlTD10kzWiX8k23a99hvvzgVs47/QT+6NTjmdoYA7SujJSOZrmIZHFjxzZ+8MSrDLpTb0ZjrI6DRwYD2zbF6mmoN/YfHqChzjhjVjPTp0xkw7O7OaJ1ZaRAss1yUUIXGYWOzT2seHDriKV7Y3XGmk8s5qL3n0TnS3t5rGsPjz3/Jlu79wW+zsxpjfx61XlHX3PN+h309PZRn1wDvkU9eQmhhC5SQKO58ChsXRmAixadxMSGOh7e9hr9A0MjjqsnL0GU0EXKJOxipqZYPcdNnpBzRs2kWB39A86gO3WWWFHycHxItfgapguLRMokaF2ZVM/78ZXn5nz+ofjQ0W34hhz64kM472y517G5pxhhS5VSQhcpovbWFm5bvoiW5iaMxDz39DJKS3PTmF872xrvHZt7WLJ6A3NXPsyS1RuU+GuESi4iZdSxuYdVD22jLx48cyY
fyxaeyIKZx7Bw5jEsmHkMG194i//xo98Pe03V46MjW8lF89BFyih9ZcjMWS679vWRq7/VFKvnuTf2s/7p14+2rbNEeSZdqjcfltA1Vz4a1EMXqVA3dmzj3o3Bm2HD8F73gf4Bdrz+Ntt3vc1NP94e+pwL3nci8941hXknTGXeCVOYO30yj257fcRfCemv3bG5h5vXbae3LzGr59hJMb76Ma1dUy6a5SJSpdIvbMp3lkvYzJrGWB0nTWvi5bcOHu3B19cZBgxkdulJzJX/8rL3suKBrcQDjgfR/PniU0IXqSFBdfn0Hvfh+CA79xzk+d376dp9gG9t6Ap9rVi9jbiIaqyaYnXctvz9SvbjpBq6SA3JtoE2QGOsngXJAVSAh37XE9ijnzyhPnSZg7Hoiw9xQ3LXqPSknv5XSIp6+mOjHrpIjcvWo08N1hbSCVMn8qu/OpcJDXVZxwlidcaEhnfWztG68wnqoYtIqFw9+tHU0PPxxv5+TrvxUY6fPIG3Dh4JbRcfcuJpfyGMZd359Nk705pimEHvoXhkZ/Kohy4iWSV68E/RFx+53sxYNDfF+NySubz+9mF+8NvwWTxhpkxs4AvnzeP4KROYPmVi4mfqBI6bNIGG+neulcxnjr8ZuL9T4oF3ppAaic1MoLJm9mhQVETGLXNRsrGoA+647IyjifHUVY8Mq52PhxkcO2kCzZNiHDtpAtt79nE4YNGzMLF6Ayf0r5G6ZPL35HvF6owjaQPGTbE6GmP17D0UH/Y/g0JvLq6Si4iMW67t+yB4gDMlaJbLFWfPyjrXPsjMaY385PoP8ub+ft48cIQ3D/Qnf47w1oF+eg/F2XvoyKiSOZBzNk96nndnWDKHxKBv6q8Yz3he6r+xUEk9jHroIlJWQf8TaG6Ksb9/gMGM3nJq3fl8Sh9h8/HLxYCvLV/E5IkNLD55GqccP3lsr6OSi4hUm9GsOx/2/PGuk1Ms/+tP38enzj5lTM8dd8nFzJYB3wTqgW+7++qM4xOB7wEfAN4CLnP3l8YUrYgI+ZV4cj0fGDbLJT44FDq3PlcNfbzqgMdWnsvB/gGmT5lYlPfImdDNrB64E/go0A08aWbr3P3ptGZXAnvd/T1mdjnwdeCyYgQsIpKvoP8pZNvyD4JnuUxsqAvcVWo0PnnObGaOY7nkfOTTQz8L6HL3nQBm9kPgEiA9oV8C3Jy8/SDwv83MvFz1HBGRELl6/mHH0mv95Zzlkk0+Cb0FeDXtfjdwdlgbdx8ws33A8cCb6Y3M7CrgKoDZs2ePMWQRkdK7tX1RSZLyeOSzY5EFPJbZ886nDe5+l7u3uXvbjBkz8olPRETylE9C7wZmpd0/GdgV1sbMGoBpwB8KEaCIiOQnn4T+JDDPzOaa2QTgcmBdRpt1wGeTtz8ObFD9XESktHLW0JM18WuB9SSmLd7j7tvN7Bag093XAXcD3zezLhI988uLGbSIiIyU1zx0d38EeCTjsZvSbh8GPlHY0EREZDTyKbmIiEgVUEIXEYkIJXQRkYgo2+JcZrYHeHmMT59OxkVLFUJxjV6lxqa4Rkdxjc544jrF3QMv5ClbQh8PM+sMW22snBTX6FVqbIprdBTX6BQrLpVcREQiQgldRCQiqjWh31XuAEIortGr1NgU1+gortEpSlxVWUMXEZGRqrWHLiIiGZTQRUQioqITupktM7MdZtZlZisDjk80s7XJ40+Y2ZwSxDTLzH5hZs+Y2XYz+0JAmw+Z2T4z25L8uSnotYoQ20tmti35niN24LaEv0+er6fM7MwSxDQ/7TxsMbO3zey6jDYlO19mdo+Z7Taz36c9dpyZ/czMnk/+e2zIcz+bbPO8mX02qE2B41pjZs8mP6sfmVlzyHOzfu5FiOtmM+tJ+7wuDHlu1t/fIsS1Ni2ml8xsS8hzi3K+wnJDSb9f7l6RPyRWdnwBeDcwAdgKLMho85fAPyZvXw6sLUFcJwFnJm9PBZ4LiOtDwL+V4Zy9BEzPcvxC4FE
SG5KcAzxRhs/0dRIXRpTlfAEfBM4Efp/22DeAlcnbK4GvBzzvOGBn8t9jk7ePLXJc5wMNydtfD4orn8+9CHHdDHwpj8866+9voePKOH47cFMpz1dYbijl96uSe+hH9zJ19yNAai/TdJcA303efhA4z8yCdk8qGHd/zd1/l7y9H3iGxBZ81eAS4HuesBFoNrOTSvj+5wEvuPtYrxAeN3f/JSM3X0n/Hn0XaA946lLgZ+7+B3ffC/wMWFbMuNz9p+4+kLy7kcTmMiUVcr7ykc/vb1HiSuaAS4EfFOr98owpLDeU7PtVyQk9aC/TzMQ5bC9TILWXaUkkSzytwBMBh/+TmW01s0fNbGGJQnLgp2a2yRL7t2bK55wW0+WE/5KV43ylnODur0HilxJ4V0Cbcp+7Pyfx11WQXJ97MVybLAXdE1JCKOf5+i/AG+7+fMjxop+vjNxQsu9XJSf0gu1lWgxmNgX4f8B17v52xuHfkSgrLAa+BXSUIiZgibufCVwAXGNmH8w4Xs7zNQG4GHgg4HC5ztdolPPcfQUYAO4LaZLrcy+0fwBOBc4AXiNR3shUtvMFXEH23nlRz1eO3BD6tIDHRn2+KjmhV+xepmYWI/GB3efuD2Ued/e33f1A8vYjQMzMphc7Lnfflfx3N/AjEn/2psvnnBbLBcDv3P2NzAPlOl9p3kiVnpL/7g5oU5Zzlxwc+xPgU54stmbK43MvKHd/w90H3X0I+OeQ9yvX+WoAlgNrw9oU83yF5IaSfb8qOaFX5F6myfrc3cAz7n5HSJsTU7V8MzuLxHl+q8hxTTazqanbJAbUfp/RbB3wGUs4B9iX+lOwBEJ7TeU4XxnSv0efBX4c0GY9cL6ZHZssMZyffKxozGwZ8FfAxe5+KKRNPp97oeNKH3f505D3y+f3txg+Ajzr7t1BB4t5vrLkhtJ9vwo90lvgUeMLSYwUvwB8JfnYLSS+4ACNJP6E7wJ+C7y7BDH9ZxJ/Cj0FbEn+XAhcDVydbHMtsJ3EyP5G4I9KENe7k++3NfneqfOVHpcBdybP5zagrUSf4yQSCXpa2mNlOV8k/qfyGhAn0Su6ksS4y8+B55P/Hpds2wZ8O+25f578rnUBnytBXF0k6qqp71lqRtdM4JFsn3uR4/p+8vvzFIlkdVJmXMn7I35/ixlX8vHvpL5XaW1Lcr6y5IaSfb906b+ISERUcslFRERGQQldRCQilNBFRCJCCV1EJCKU0EVEIkIJXUQkIpTQRUQi4v8DELGZlJ5zOGMAAAAASUVORK5CYII=\n", 365 | "text/plain": [ 366 | "
" 367 | ] 368 | }, 369 | "metadata": { 370 | "needs_background": "light" 371 | }, 372 | "output_type": "display_data" 373 | } 374 | ], 375 | "source": [ 376 | "plt.scatter(X_train.detach().numpy().squeeze(),solution[:,1])\n", 377 | "plt.plot(T_rs,Y_rs[:,1])" 378 | ] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "execution_count": null, 383 | "metadata": {}, 384 | "outputs": [], 385 | "source": [] 386 | } 387 | ], 388 | "metadata": { 389 | "kernelspec": { 390 | "display_name": "Python 3", 391 | "language": "python", 392 | "name": "python3" 393 | }, 394 | "language_info": { 395 | "codemirror_mode": { 396 | "name": "ipython", 397 | "version": 3 398 | }, 399 | "file_extension": ".py", 400 | "mimetype": "text/x-python", 401 | "name": "python", 402 | "nbconvert_exporter": "python", 403 | "pygments_lexer": "ipython3", 404 | "version": "3.7.6" 405 | } 406 | }, 407 | "nbformat": 4, 408 | "nbformat_minor": 4 409 | } 410 | -------------------------------------------------------------------------------- /examples/PDE_2D_Advection-Diffusion.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 2D Advection-Diffusion equation" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "in this notebook we provide a simple example of the DeepMoD algorithm and apply it on the 2D advection-diffusion equation. " 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 22, 20 | "metadata": {}, 21 | "outputs": [ 22 | { 23 | "name": "stdout", 24 | "output_type": "stream", 25 | "text": [ 26 | "The autoreload extension is already loaded. 
To reload it, use:\n", 27 | " %reload_ext autoreload\n" 28 | ] 29 | } 30 | ], 31 | "source": [ 32 | "# General imports\n", 33 | "import numpy as np\n", 34 | "import torch\n", 35 | "import matplotlib.pylab as plt\n", 36 | "# DeepMoD stuff\n", 37 | "from deepymod_torch.DeepMod import DeepMod\n", 38 | "from deepymod_torch.library_functions import library_2Din_1Dout\n", 39 | "from deepymod_torch.training import train_deepmod, train_mse\n", 40 | "from scipy.io import loadmat\n", 41 | "# Settings for reproducibility\n", 42 | "np.random.seed(42)\n", 43 | "torch.manual_seed(0)\n", 44 | "\n", 45 | "\n", 46 | "%load_ext autoreload\n", 47 | "%autoreload 2" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "metadata": {}, 53 | "source": [ 54 | "## Prepare the data" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": {}, 60 | "source": [ 61 | "Next, we prepare the dataset." 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 23, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "data = loadmat('data/Advection_diffusion.mat')\n", 71 | "usol = np.real(data['Expression1'])\n", 72 | "usol= usol.reshape((51,51,61,4))\n", 73 | "\n", 74 | "x_v= usol[:,:,:,0]\n", 75 | "y_v = usol[:,:,:,1]\n", 76 | "t_v = usol[:,:,:,2]\n", 77 | "u_v = usol[:,:,:,3]" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "metadata": {}, 83 | "source": [ 84 | "Next we plot the dataset for three different time-points" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 3, 90 | "metadata": {}, 91 | "outputs": [ 92 | { 93 | "data": { 94 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAxsAAAEWCAYAAAAO34o+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3de9RldX3f8c83g8PFgUgyEQ2DzIRgBhhNgAkxpW0MGoMJgSa6GjChxmioqTSwIjVeUuwiTVNNaswqpGSWUkJCwjKRKkkxBIOJt+hiGEEYZ7BcdURFFOQiMIx8+8dzjpw5nMu+/G577/drrVnzPM/Zzz6/c57z3fv32b/f3tvcXQAAAAAQ2nflbgAAAACAfiJsAAAAAIiCsAEAAAAgCsIGAAAAgCgIGwAAAACiIGwAAAAAiIKwAQAAACAKwsaAmdldZvbSSOt+lZndbWaPmNkHzOx7YjwP0GexatTMnmtmV5nZPWbmZrZ+6vF9zewSM3vQzL5iZr8Zug1AX4WqWzN7kZlda2bfMLOvmdlfmdlzJx43M3uHmX199O+dZmZtnxcIjbCB4MzsGEl/IulMSYdI+pakP87aKACTnpT0d5JeMefx/yLpSEmHS/pJSW8ys5PTNA3AyMGStkhar5VafEjS/554/CxJ/0bSD0t6oaRTJP37tE0EljPuID5MZvZnkn5J0uOSvi3pAnd/Z6B1/zdJ6939VaPvj5C0Q9L3uvtDIZ4D6LuYNTrxHPtIekLSBne/a+LnX5L0Gnf/+9H3vyPpSHc/PeTzA30Ted96nKR/cvcDR99/UtKl7r5l9P1rJf2au78oxPMBoTCyMVDufqakL0j6OXdfM2tjaGbPM7MHFvx71ZzVHyPpponnul3SbknPj/FagD6KXKNzmdnBkr5fEzU8+vqYpq8FGIp5dbukTt9ccfX/WtL2ie/32teKOkWh9sndAJTL3b8g6VkNfnWNpG9O/eybkg5s3SgA39GiRhdZM/p/soapX6AFd29Vp2b2QknnSzpt4sfT+9pvSlpjZuZMW0FBGNlADA9LOmjqZwdpZb4pgLI9PPp/soapXyATM/tBSR+SdI67f2zioel97UGSHiZooDSEjWFbuEEaTdF4eMG/X5rzq9u1csLaeD0/IGlfSZ8P13RgEGLV6PwndL9f0pc1UcOjr7fP/g0AU55Wt0vq9K3zVmRmh0v6sKTfcfc/m3p4r32tqFMUimlUw/ZVST8w78HRFI018x5f4HJJ/2xm/0rSNkkXSLqSk8OB2mLVqMxsP0mrRt/ua2b7uftjo+8vk/TbZrZVK1eU+zVJr2nyPMAAPa1u3b12nZrZoZKuk3SRu188Y5HLJP2mmV2tlYDzRkn/s35zgbgY2Ri239NKh+IBMzsv1Erdfbuk12sldNyrlbne/yHU+oEBiVKjI4/qqSlTO0ffj71d0u2S7pb0T5J+393/LvDzA30Vqm5fp5XQ8vbJkZCJx/9E0t9IulnSLZL+7+hnQFG49C0AAACAKBjZAAAAABAFYQMAAAAYODM7zMw+YmY7zGy7mZ2zYNkfNbNvm9krl62XE8QBAAAA7JH0RnffZmYHSrrBzK51989NLmRmqyS9Q9I1VVbKyAYAAAAwcO7+ZXffNvr6IUk7JB06Y9H/KOn9WrkI0FKdGtlYvf/BfsCBs14zMEzf/Nr2+9z9+3K3YxbqFdhbyfUqUbPAtNJrVpJ+/OBn+QNP7Km07M5HHtku6bGJH21x9y2zljWz9ZKOlfTpqZ8fKunnJZ0k6UerPG+nwsYBBx6qn3jF+3M3AyjGVRdvvDt3G+ahXoG9lVyvEjULTCu9ZiXpgSf26LIf2VRp2RM+8enH3H3zsuXMbI1WRi7OdfcHpx5+t6Tfcvdvm1ml5+1U2AAAAAAQh5k9QytB43J3v3LGIpslXTEKGmsl/YyZ7XH3D8xbJ2EDAAAAGDhbSRDvlbTD3d81axl33zCx/KWS/nZR0JA
IGwAAAACkEyWdKelmM7tx9LO3SnqeJLn7xU1WStgAAAAABs7dPy6p2okYK8v/SpXluPQtAAAAgCgIGwAAAACiIGwAAAAAiIKwAQAAACAKwgYAAACAKAgbAAAAAKIgbAAAAACIgrABAAAAIArCBgAAAIAoCBsAAAAAoiBsAAAAAIiCsAEAAAAgCsIGAAAAgCiyhw0zW2VmnzGzv83dFgCLUa9At1CzAHLLHjYknSNpR+5GAKiEegW6hZoFkFXWsGFm6yT9rKT35GwHgOWoV6BbqFkAJcg9svFuSW+S9OS8BczsLDPbamZbdz96f7qWAZhGvQLdQs0CyC5b2DCzUyTd6+43LFrO3be4+2Z337x6/4MTtQ7AJOoV6BZqFkApco5snCjpVDO7S9IVkk4ysz/P2B4A81GvQLdQswBqMbPDzOwjZrbDzLab2TkzltloZv9sZo+b2XlV1pstbLj7W9x9nbuvl3S6pOvc/ZdztQfAfNQr0C3ULIAG9kh6o7sfJelFkt5gZkdPLfMNSb8h6Q+qrjT3ORsAAAAAMnP3L7v7ttHXD2nlSnaHTi1zr7tfL+mJquvdJ2grG3L3f5T0j5mbAaAC6hXoFmoWQF1mtl7SsZI+3XZdRYQNAAAAAPXse9B+2nDSMdUW/sSn15rZ1omfbHH3LdOLmdkaSe+XdK67P9i2jYQNAAAAoP/uc/fNixYws2doJWhc7u5XhnhSztkAAAAABs7MTNJ7Je1w93eFWi8jGwAAAABOlHSmpJvN7MbRz94q6XmS5O4Xm9lzJG2VdJCkJ83sXElHL5puRdgAAAAABs7dPy7JlizzFUnr6qyXaVQAAAAAoiBsAAAAAIiCsAEAAAAgCsIGAAAAgCgIGwAAAACiIGwAAAAAiIKwAQAAACAKwgYAAACAKAgbAAAAAKIgbAAAAACIgrABAAAAIArCBgAAAIAoCBsAAAAAoiBsAAAAAIiCsAEAAAAgCsIGAAAAgCgIGwAAAACiIGwAAAAAiIKwAQAAACAKwgYAAAAwcGZ2iZnda2a3zHn8u83sb8zsJjPbbmavqbJewgYAAACASyWdvODxN0j6nLv/sKQXS/ofZrZ62UoJGwAAAMDAuftHJX1j0SKSDjQzk7RmtOyeZevdJ0zzAAAAAKS06oADtOb4Y6suvtbMtk58v8Xdt9R4ugslXSXpHkkHSvpFd39y2S8RNgAAAID+u8/dN7f4/Z+WdKOkkyQdIelaM/uYuz+46JeYRgUAAABgmddIutJX3CbpTkkbl/0SYQMAAADAMl+Q9BJJMrNDJP2QpDuW/RLTqAAAAICBM7O/1MpVptaa2S5Jb5f0DEly94sl/Y6kS83sZkkm6bfc/b5l6yVsAEACRx13eJT17th2d5T1AgCGxd3PWPL4PZJeVne9hI0eidWZGaNTAywXuw6rPh/1CrRTp5apN2C+bGHDzA6TdJmk50h6UiuX3/qjXO3pmtQdmlnPycZ1WKjZ2XLUYhXU67BRr4uFrtsq66MGMVQ5Rzb2SHqju28zswMl3WBm17r75zK2qVgldmjozAwONTtSYj0uM9lmanUQqNeRUup1VjuoRQxBtrDh7l+W9OXR1w+Z2Q5Jh0oa3IZwkVI2klXQmem3oddsl2pxGWq1/4Zcr12qVQ7aYQiKOGfDzNZLOlbSp/O2pAxd2lDOM34NbDj7aUg124d6XITg0X9DqNe+1Cn1iD7KHjbMbI2k90s6d9YdCM3sLElnSdL+a74/cevS6cuGchobzv5ZVLN9qde+1uMyHCTonz7vY/tep+w/0RdZw4aZPUMrG8HL3f3KWcu4+xZJWyTpWc/e5Ambl0TfN5aT6Mh037Ka7Xq9DqkeF6FW+6GP+9ih1ijBA12W82pUJum9kna4+7tytSOXoW4wJToyXdXnmh1yPS5CrXZX3+qVGn0KdYmuyTmycaKkMyXdbGY3jn72Vne/OmObkmCjuYINZuf0rmapxWqo1U7qRb1So/NRl+iKnFej+rh
WbnU+GGw0ZzvquMPZWHZAn2qWWmyGzk13dL1eqdHqqEuULvsJ4kPARnM5NpZIhXpsjwMEiIX6bI79KEpF2IiMDWc9bCwRC7UYFrWKkKjPcKhNlOa7cjegr4467nA2ni3w3iEkPk/x8N6iDfaV8fDeohSEjQgo7jB4H9EWO9s0eJ/RBJ+ZNHifkRthIzCKOiw6MWiKz016vOeogu16erznyIlzNgIqtZA3bdy38rK37Hw8Ykua44RU1FFqLQ4B88WxCLWZF/tS5EDYCKCUjWedUFF3HSWEEDaSWKaUWmyi5NprgnrFpC7XZt9wQACpETZayrkBDREumj5Xrg4QHRjMU3Jnpk2tVvndUgMJ9Qqp7NocMuoTqRA2WsixAU0ZMBaZbEfqjg4bSEwrrTOTuk5LORgwC/U6XKXVJZ6O+sQkM7tE0imS7nX3TTMef7GkD0q6c/SjK939gmXrJWw0lHIjWkrAmCdH8GADibESOjSl1Whp4YN6HZ4S6hLVMK0KEy6VdKGkyxYs8zF3P6XOSgkbDaTaiJbWgali3OYUnRs6MMjdoelKjeYciRyjXocjd122EaKmc4f7pqhRuPtHzWx96PUSNmpKsRHtSgdmkVShg43jcOXq0HS9PlMeEJhGvfZfV4JGzDru8sUeqNHu+fa+B+iRI46ruvhaM9s68f0Wd99S8yl/3MxuknSPpPPcffuyXyBs1BB7I9r1TswsKTo2bByHZ8jnS4WSK3RQr/1VctAooX5Lm944DzXaa/e5++YWv79N0uHu/rCZ/YykD0g6ctkvETYqirkRLWEjGFvsjg0bx+FI3aHpe33mCB3Ua7+UGDK6ULclTG+chxrFLO7+4MTXV5vZH5vZWne/b9HvcQfxCgga4cR8vSXu8BBW6gszDKk+U79e6rUfSvs7drVux+0uqe2l/W2Rn5k9x8xs9PUJWskRX1/2e4xsLEHQCC/mkVSOxvQXF2ZIg4s8oKpSOqN9q9mSRjy4UtWwmNlfSnqxVs7t2CXp7ZKeIUnufrGkV0r6dTPbI+lRSae7uy9bL2Ejg75tGJvatHFfAgcqIWikF6s+p1Gv3VRC0BhCvea8oMMk6nQY3P2MJY9fqJVL49ZC2FggxsY0xcbxqEPuD7auHV89ONi6ZknVoUF3cQW4fLiqHGbJHTSGWK8lhA7qFE0RNuboUtAIGS6WrTtG+IixEWWj2A8EjTKkOChAzXZDzqBBreYPHdQpmiBszBB6Y9rFkFHlOUMHj9AdGjaK3db3S003rd/Yo43zcBlrcG+bcuQOHUAdhI3IQm8kcwSMecZtCdn5IXBA6l/QCFm389aVKoTEHuWgZjGJoLFYjqnI1CjqImxMCdnJCbmRLClkTAsdOjiPAzF17bypJs/JuVYIjfvblIt75aB03GdjQolB46hD7i86aEwK2dagQa2Aq6aguq5ebnr8+S+hXlO0hXvmDAdBoxuSj9hSp6iIsBFBiIIvpdPSRKi2EziGp4tBo/Rajdm+mDcho2bLwI00u4Wbc6JETKMaCVUwoYJGHxx1yP2tp3MwXWM4Yu20+nSBhjZinGM1xj1z+il10ChF1y7eMEvKfSd1imUIGyonmXet81JFiA5OqI0mG8ThidGB6XqdxgodHBjol6EEjT5dvGEagQOlIGwE1GaDmaoDs373zpk/v2v1xqjP23aUg8DRb124r03XQ8a0LlxNTqJm+yxHyMh98QYpbfjgErkoweDDRgnTp0Jv/OYFiia/EzKEhJhWhf4pPWj0LWRM68LV5AgcafXtZpql1XDKK8eNcWNO5MQJ4gGUEDTW7975nX8hhV5vm5NVg13hq5BpcyhfaZ2UmEK+1pLm36OevgSNkq4Ot0jKdiZ539m/YoZBh40QRZE7aMQIGMueK8Tz5Q4cKEOpoxpd6KTEUPLrphMT3/4HxN++xt6Gl/wZXiZF2wkcyGHw06hyabNBSRUuqrShzTSrptOqQgwHM9zbT125GlzbGi79HCuJE8bxdLH
vc9MXMa8cJ1GbSG+wIxs5RzW6HjQmtR3p6NMOAvWEPvpVctCYHBUMUcOh1zdLaffKkThi2mVDvc9NG7HvkRMTtYpJgw0bbaUOGimnSzWROnAE6ViyMcxmCEEjdhhI8VwEDoQQ6xLUfQ0Z02K9VgIHUhlk2MhVAG2CRhekDkScvwGpvKCR+8BASRd1GKNWh4t73YTTxcABSAMNG201Kc4mG4ncnZammrQ5186DIy/plfaeh7xQQ2n1Gjp0lKK0zxBmi3Gvm5I+hznEeA+inktDrUKZw4aZnWxmt5rZbWb25hTPmeOD3zRodFmqwMFRmbRy1Owibf/+oYJG6XJfRU6iVnPIWa/cVDOuLgUOdIeZXWJm95rZLXMe/yUz++zo3yfN7IerrDfb1ajMbJWkiyT9lKRdkq43s6vc/XO52lRF3YLMHTSeefu22r/zyBHHBXnu9bt31r5qTo4b/3FlqmpC1GzIsJ87aHQhZExrUpPT2tRoyKvgULeL5dzH9iVo1Knx2FeImyX0/jLWVaqo1U65VNKFki6b8/idkn7C3e83s5dL2iLpx5atdGnYMLOzJV3u7qGr/QRJt7n7HaPnuULSaZKibQjbdnRSJP8QHZgmAWPe77cNHk0ukVt3A8pl/JJJXrOx9OmKcHXlvGw1kspSryH3k6lCRqgrxM2S4jLVUro7kTdF4OgGd/+oma1f8PgnJ779lKR1VdZbZRrVc7RyROR9oyFZq7LiCg6V9MWJ73eNfrYXMzvLzLaa2dbdj3ZrGLXuhrLNBu+Zt2/7zr+QYq0XnbS0ZhfVaymjGkMOGpNyvZagnVHmgy9Sex/7rYe+lqxxy8QOGqnOs0r1PKHeL6ZT9d7acb2P/p3VYl2vlfShKgsuHdlw9982s/8s6WWSXiPpQjN7n6T3uvvtLRo5K7T4jOffopVhGj3r2Zue9ngqsadPNd0QpQwB4+dqMtpRd/pG6tENjrpUsrRmU9QrQSOcNqMcpUynwly197HPXX98q5oN1VGNeb+bnCafP8aIR6gRR6ZTdctu26/O5+k+d9/c9jnN7Ce1Ejb+ZZXlK52z4e5uZl+R9BVJeyQdLOmvzexad39Tw7buknTYxPfrJN3TcF1LlXwErAtBY9bz1g0dIeaLI6vGNVtC/ZV66emqdRzqXKpZmtZmCdOp6MDMlXQfW3LQyB0yZgkxnXGW0gMHus/MXijpPZJe7u5fr/I7S6dRmdlvmNkNkt4p6ROSXuDuvy7peEmvaNHe6yUdaWYbzGy1pNMlXdVifdHEHNVoshEsZVpTkzbUeb11dzoM/0aXvWZz3EwzpMlpiXXruM3vVtH0tTZ9b6nX6JLVa6lBo8TLUU8r9WacsZRw4AnNmdnzJF0p6Ux3/3zV36tyzsZaSb/g7j/t7n/l7k9Ikrs/KemURq1d+f09ks6WdI2kHZLe5+7bm66vi5oGjZKUFjjaYCO42NBqNlQHIOZ5T6HXXXrHbB5q9+m6Vq99urFmE6HbHOL95IDA8JjZX0r6Z0k/ZGa7zOy1ZvZ6M3v9aJHzJX2vpD82sxvNbGuV9VY5Z+P8BY/tqPIkC37/aklXt1lHFSl3RDE7x6UFjbEm06piTali6DeuJjUbqv5SjmqUcFW4ps8X4gpyqS5ZTb3GlWIfG6JDGmq/2bWAMUvI6VUhplTFqFGmPpbL3c9Y8vjrJL2u7nq5g/gSsZJ93Y1iqUFjUqw2ljwkjLLlCBq5pzmGeP5UN+VEt5USNLo4krFMqNdDXaIEhI2AqhZ1H4PGWJ22xto5tLpaEdMxgso9qlFXiMtPl6Jte1IFjmDz/andzgkVNPoqVIhq+z7H2P5Sr8PS+7DR5gNdwnzFkJ2Xh2/4zNx/IcXocHF0BnWlvs9NqVIHDgxD2/0jQaO6vgYODEelS99iuRijGm07MHVCxPSya44/ttVzV8XlcLFMk51cynBactAYK/keORLnbqCeoYSMSX3cV3LuxnAQNgrVpgMTYqRivI6moeOZt2+
r3LGJsRFt03lhAxjG/gd050hYH64MV0WduuwiajeNnKMaKYJGm9qOfT8cqfnJ421PGOegAJrq9TSqVFOoQo9qNN3QxZgS1WadoTtjTKUanhSjGkMJGmOxL1ct5T13A3H1MWiEvH9NzHvhjLV5H0rbj3LuxjD0OmwMSeiQEWr9VTe2QxwWR35DCxpjKQIHMK2koBE7EMR+nlyBg4MCaIJpVInEHNWIHTSmnyfV+RyzhLhuOLqhxHM1YnVMqtRw6LqLPaUq17kbTKWKJ1dHM2TQyH1ZaincVKs+nseBfmJko6XcQ5Kpgkab58y1cecIDBbJeQnqJleDi3H1uLqvidENNNV0XxnqM1fSpalDtqXp+1PS6AZTqfqPsDFDrk5q3Y1PjqAR87lL6siw8euemME/RMcgZFgIua6YgSP3wRiE0+r+RRmDRkkhY1qXAwdQR2/DRkmdxRid6JxBo4mQG3s2kP0XO/CnvgR1zHoNsf6SOmNB7kpd0PYf+ZT0uZ6n5DC0CKMbqKO3YSOFkJ3eOhubUoJGKe0A6tRiyhG0lDWS8rkY3RiWro1qdLED37a9jG6gZISNjimtg1+nPVU2pqE7gpy3gTZKugx17OftWucM/dU2aHRV25CUeioy+1dURdiYErp4qhR/lzeOQGgxd2Appk+VcEAgReCI2bGhE9MPTY6aDzVoTEodOEoZ3WAqVX8RNjqkhE7MLKW2C8NQyo5SKqsWSmpLjr8RHZcwmga/1H/zvgSNsa68Hg4MoArCBoKo2rEJNZUqxY6Mzkq/xB7VKKlzP9akTaWMbmB4mn6eutIxr6vp6+ry6Ab6ibDRUJXCDDmFqsSODIAVJddnKW2r25nhiGl+Kf8GJQeN6fvipDwnK2XgAGLp5R3EOSINdFPdzk2Mo3Fdut9NVQ/f8JladyCPfXdx9Fuqo+Qxgkadep61bJ06qypVPR51yP3a8dWDa//epo376padj4dpw3GHa8e2u4OsC+VgZGOA7rxu+3f+hVRqp4sjpIh1lK/Uz/wssdrKEVS01eQzFDpolHjDzUlNXi+1iSbM7GQzu9XMbjOzN894/HAz+wcz+6yZ/aOZrVu2TsJGB4TYaM0LGLGCxyJ9nV+L7uOz+ZQ+vBeMcjdX8onhIT+bMadE5boEdhtN/34c1OsHM1sl6SJJL5d0tKQzzOzoqcX+QNJl7v5CSRdI+r1l6yVsDEDVIJEycABt5T6hsWudCCl/m3P/zVCeXEffUwaBUM81lNENDhJkdYKk29z9DnffLekKSadNLXO0pH8Yff2RGY8/DWFjQupkXuKRw1ICRxc3kGgnVv1V/SzVqcfcnfY2Qt+IU4pXrxwt7ZYujGrkHG3IFTjq4qBAr601s60T/86aevxQSV+c+H7X6GeTbpL0itHXPy/pQDP73kVP2ssTxPGU1OGh7omobTQ9mQ0oSdUa3XDSMZFbAqxIFfLqhtQQQSO3cRva7CfrnjC+fvdO3bV6Y+PnqyrkieKo7rE9+9TpC93n7psXPG4zfuZT358n6UIz+xVJH5X0JUl7Fj0pIxs91jRolDK6AZQoVIel7rlSIc+tKqHTBaRU2me+tPYAI7skHTbx/TpJ90wu4O73uPsvuPuxkt42+tk3F62UsFE4NkhAfCmnNLYNDV2/oANTNPqp7t815ahGqfvRNu2q+37Ufb9z1ynnbWRzvaQjzWyDma2WdLqkqyYXMLO1ZjbOD2+RdMmylRI2AHRO1R1hn8/9KW0EstT3mk7LsJUaNMZKbx+Gxd33SDpb0jWSdkh6n7tvN7MLzOzU0WIvlnSrmX1e0iGSfnfZegkbAFBR245BSQGBTg5mSXG+RqpRja58xpu2M/boRhNc1KH73P1qd3++ux/h7r87+tn57n7V6Ou/dvcjR8u8zt2XnqhD2ACABGIEjVThJeeV8+i8lC/3lJtZuhI0xkpsb4l/V3QTYQNAdjk7lCk60jFDQUmjJUBoTeqzxI57LCVeQh+YRtgoXKrLyAJ
YrGkHJkUYaPocQ+qUoQylnttTghT1yPuPHAgbPcZ1+QGkREem2/bfb9Yl9vMZ4qhGk/b3dXSDizv0B2EDT9MmpDASg1LQ8QXiY15/vzX5+3KeFaYRNib08c6XdYNDl0ZDuHs4SpfyfIounbtBB3WYYh8A6PqoxlhfXgcwRtjI6JEjjsvdhGLdtXpj7iagUHRUgf7r69SgWOq8X4z6IrV9cjypmf2+pJ+TtFvS7ZJe4+4P5GjLEEyOVsw7+tmlEQ2kR80C3THEeo05GrBo1DDWvvPhGz7DtGT0RpawIelaSW9x9z1m9g6t3O78tzK1pXhrjj822IY0ZqioumFMPaLTx+lxGVCzPVW1U/PM27cxGtsd1GtLVaclTi7HQTtgtizTqNz970e3RJekT0laF3L9O7bdHXJ1wODFrtnSMYcaXZKqXmNNaaw7hSp0fTY9/yn0eVN1X1esqWdMXUVbJZyz8auSPpS7EXVVOTk55HkHDKeiIJ2s2SHo0kniSKaIeu3KeQJta6grNRj778EVqTApWtgwsw+b2S0z/p02sczbJO2RdPmC9ZxlZlvNbOvuR/uXrvsyLSHkFCpODs8jRM1O1uu3HvpaqqZjhGkcwxFjH/vgA8Ou2VBBoSuBowu410Y/RDtnw91fuuhxM3u1pFMkvcTdfcF6tkjaIknPevamucsNQchzN7ouxWVvhzYdL0TNTtbrc9cf35t6pfZQmhj72B/YuLlzNRuqLkMHhDuv2074B0ayTKMys5O1crLaqe7+rRxtmCfHycRdH91gilf/lVyzSKPr26khoV7riTUSEWK9HORAH+Q6Z+NCSQdKutbMbjSzizO1I7rQU4JK69jXaU+OzgpXogpmMDU7NKVtUxBEZ+uV+2u0w/uHEmW59K27/2CO5w1tx1cP5ioNgXG+RplKqlnqDlispHotXezzK5hOBZRxNaooujbfvs5R/1KOROYa1UhxvgYQAp0MAMDQ9TZstBF66k2Mo/W5A0fu5weWGdooWZeCDQcMAGA4CBsthdxp1j36n6vDX/d5q76ukjqHXRsZA0pQUg2j/4Zy8vRQXifKYGYnm9mtZnabmb15zjL/1sw+Z2bbzewvlq2TsJFIrJ1w6vFSF+AAABOoSURBVMCRe0SjTrjj5HCE1PSzn2LEoUujGhieUm/ol+p+GNx3A11hZqskXSTp5ZKOlnSGmR09tcyRkt4i6UR3P0bSucvWS9goTJNzG9Ycf2ySENDkObo4qoH0cgbDFFdJKzUM5D54AAAoygmSbnP3O9x9t6QrJJ02tcyvSbrI3e+XJHe/d9lKex022kyFqdP5qXq0vWqHumnnJ1bHoWmY4br8QHwpgkzOWmaEEgCCWWtmWyf+nTX1+KGSvjjx/a7RzyY9X9LzzewTZvap0X19Fspy6Vss98gRxzW6XvY4FISY45nqqGfVEJZqChXna2CeNncS33DSMcGnU5Q6YgIASOPRx7xOn+c+d9+84HGb8TOf+n4fSUdKerGkdZI+Zmab3P2BeSvt9chGiVJNFxqPRtQNDE1/bxqjGogp9GhiKiHDQdt1DWUKFQcPAKCyXZIOm/h+naR7ZizzQXd/wt3vlHSrVsLHXIxsLHDLzse1aeO+lZaNcaOxpqMb01J3KuoEjRijGkBdoWqtinFIaDrKUepoBrWMee5avbHIk8RjjDbOex6gI66XdKSZbZD0JUmnS3rV1DIfkHSGpEvNbK1WplXdsWilvR/ZKPGoVp2jrV0bIYgRNOpijjdiChXeN5x0TK1OSN3lF8l1Q04gpKGMzg3ldSI/d98j6WxJ10jaIel97r7dzC4ws1NHi10j6etm9jlJH5H0n9z964vWy8hGQDFGN6S0R13biNUpSXkktMRwin6bFSDuvG47R0MBAMm5+9WSrp762fkTX7uk3xz9q6T3IxttxTpKXveofulHF+u2j1ENTMtda3U+w7GPNMYMGjFGNUo7NwYoBQcNgIGEjZRHq+schW8SOEoMHTG
DBvO7MU/uzwZTG+qr+zfj4AFi61sYKLGPAAwibLRVd4cXM3BI5WxMSg0/TTGFCn36PM9CQELp+lSDfQsyQFOEjQI0DRw5N8pNnzvmqAZHQTFPrGk+Xeq8121rH6ZQcQABVZQcCrq0jQHmGUzYaLvTiTm60UbqwNEm5DB9CsuUUGd1P99d6Ax0oY1AEyGvDhdSyQEGSG0wYaN0bY4OjgNAzODRdv2xj362HdXgCCjaKLkz36RtMbYljFSidKECAkEjHPbN/TCosFH66EaIDnnI4BFqXXVfF6MaiCH2/W1KDByxg0bJU6hQlq58VtoGha4EjdIPAKJfuM9GZHXvvTHeAIS42+q8TsP4nh0ppmClCBqMagxbrPvbNLHm+GP18A2fyd0MSWWGn5So6/hKubdU6LobB4Y6dxePETJinWdVFwcA0dagRjZCaNKxbVKoMY86pDq5PMWRLI6e9E/Mv2ns0Q2pjE5+0zbEGtVgChW6aMNJxywMEePHuzKaAeQyuJGNHdvu1lHHHZ67GZXctXpjkBGO1JqGjBxHTzj6iUXqHmEdyznCkSJoAKWJWXM5wkQJBy2AUBjZaCDV6IbUnXmuYymDBkc/MRb73jZNrDn+2KQdhpTPx5XlEBvhtx7OtULJBhk2QhzNTh04St84tGljrs4IoxrlKilItu30xA4BIdZfUscuxN+e2i5T7P1YX0YD+vI6gLFBho2c2nSsSw0dbdrU9P0oqTOKMsQc3QjRGQ8dOkKtr+5rY1QDk/gb9xuzDhDCYMNGrtENqf3GuZTQ0bYdOYMGRz6RI3BIT4WEumGh6e8tEjNooHsefcxzN2EvfbkEdR2l3BenBOyn+2NwJ4iHdsvOx7Vp4761fy/EJQNDXia37nO2xYgGlmlSW7Evg9v0hPFFcnWOSuygcCCh/7p64ZMUUmwLOGCAHAY7siGF2ynlGuEYG48wxNqIhF5/7mF3OiMYa/KZLrGTXleT18DNOZHDEEc36ujD9gj9N+iwIeXveO746sFBd8qTwaBpOAixjnnavFZGNYYnxYUYhhY4Sg0a1Hd3lRgsuxY4ShzVKPHvim5iGlUgTadTjcWa/lHKkGnbjVaojkjucDkkj36r353HGFOqYutySKqC+m6v7b6sirpTqbp4v5s6+nRvHA4aYJbBj2xI+adTfacdPT2KUErQQDeVOrohrezsS9zhz9K0nRwNRZeVPsJRevswPGZ2spndama3mdmbZzz+ejO72cxuNLOPm9nRy9ZJ2AgsRODoy846xGsJGTQ46olF2owClhw42gSiVCOjHFDovtgBv02Nldqhb9Ou2FeSy90PYX+dh5mtknSRpJdLOlrSGTPCxF+4+wvc/UckvVPSu5atl7AxEvKDHeSKKh0PHCHaT9Dovpyjhk0+g20DR0mho217mrwXObdb1DgWKS1wlNaeEDho0AsnSLrN3e9w992SrpB02uQC7v7gxLfPlLT0mtmEjQklBo6uhY5QbWajhRBSBw4pf+gI8fwpgwa1Xp5Uf5PU97oJfWPNXG3o+6gGsjpU0hcnvt81+tlezOwNZna7VkY2fmPZSjlBPKJQJ9qNCz/m/QPaCrlxCr2j44hnXju23a2jjju89XpS3tMmxL0Axh2CVCeRhwo4pVxUAt0T+z43UpgLM+Q6cTxE0Clp9BRlePRbj9fp56w1s60T329x9y0T39uM33nayIW7XyTpIjN7laTflvTqRU9K2JgSqmM0FvLKHiWGjtBHQAgaKEWom49Ndg5CB4/QHY+mQSP3qAZ13l25bvI37vinCB2hRlNSXLa6BNRzVPe5++YFj++SdNjE9+sk3bNg+Ssk/a9lT5p1GpWZnWdmbmZrc7ZjWugPevAO9GiqUq6hzljPT9AoX+6aTX0DzdA76vEUp6ZTndr+/iJdDRqYr0295r5Z7SIhP/sxp1aVMG2rLup58K6XdKSZbTCz1ZJOl3TV5AJmduTEtz8r6f8tW2m2kQ0zO0zST0n6Qq42LFL
yCMek6Q1DrFGP2DsQNlTla1OzoeupiaZTPGIeeS1lSkTqoBESBxVmK30fO6lJjYW+z81kKGgz2hErXAxlVAN5ufseMztb0jWSVkm6xN23m9kFkra6+1WSzjazl0p6QtL9WjKFSso7jeoPJb1J0gcztiGpcYc65g2TZu3863SwcnQeYgQNOiBRFFGzbYJ7m8AhKct0j5jadEbabCs4uJBEtnpNce6GFO/GmosCwziIpByxSHVQooRRDfbd+bn71ZKunvrZ+RNfn1N3nVnChpmdKulL7n6T2axzUfZa9ixJZ0nS/mu+P0HrnhLraGyKO7ROKuHo4yyxOhxsrMKrWrOL6jVkPaWuobFc88tj6MNRT2p9tqb72IO+57C9HktZZ01rK1bgmCf1tKjS75EDVBHtnA0z+7CZ3TLj32mS3ibp/GXrkCR33+Lum9198+r903eaY+3Mhn5kj6BRnhA1m7teq2gbvu9avbHzO/K27WdUI78Y+9gDDvy+YO1LednpUqYjhpYyaJR6UBL9EG1kw91fOuvnZvYCSRskjY+4rJO0zcxOcPevxGpPGzFHOKS406pKE7OjQdBoJ0XNljK6EWKaRxdHOUKEpFKCxtDrPdU+NtcoYl2pRzhi60qAoqZRRfKrUbn7ze7+bHdf7+7rtXKZreNKDRpjMYvglp2PD+JoH0Gjm0qu2TafqRBH8royyhGqnaUEDcxXUr2mvqlmVzroy7R5HYxqoETcQbyG2B3avu6MhxKmUE1JwTDUTrbU0BGyXSV1SEr6DA1B6u1328DR1dDRtu2pt0Hs11FV9rAxOvpyX+52VJUicPSlgFO9FjoeaZVWs20/YyHvGTPu3OcOHqHb0Pb96cs2rYty12uue9x0LXC0bW+XL18tsR/vu+xho4tSFEWXQ0fKtrOB6qYSb5wZeqebOnTECDohghg36+yHHPujEIGj9NARoo05gkZX+yfII+d9Njot1U3KJgu65JP0cmx46HR0W4k3zoxxf4BZHYEQJ5bHDjIhwhcdEkjp7rsxT6knj4cIQrlHUYEqCBstpL4rconBI1dngqCBWUoNHNNK7yCUMrViGnXfXW1uqBkinI879iWEjhJGW0oa1aCu+4+w0VLqwDGWK3jkPlLJRqlfctXPMuMdcc6jsbmEChp0SPqnDzfTnOzopwweMQJG6QctgDHCRgC5O0yzduohdgi5g8U0OhuoImSHKPf0j5RCjmaUtu1AOLnubxPjvjaxRztijmC0CRqMaiA1wkYguQPHtL7t7Nkg9VeM2gkdOKT+jnKEnjIVY9tD/fdH28AhhTnnadKsUNAkgKSYHtV2NKPUKZLoN8JGQOMdYkmhow/oaPRf6YFD6ucoB0EDTbStrba1FGOUY1oJ51VMyx00+nYQE+lw6dsI2DmGsWPb3byXaCX4kH/Ae3LkFON10BFBSkM7X6GPr5f9+3AQNiKhiNrh/RueWH/zKEfbOxo6YrU7VtBgO1CuEDfTbKuPHfBZQrxORjWQE9OoImJaVTN0MIYr1rlP4x1l6CvpTO7AS51iFTsUETSGK/d0KineeRwlCBWmSgwa1PewMLKRAEVVDdOmIMWtl5hH58ajBqWMeMRuyy07HydooLVQn9G+jXKUEjSAEBjZSIRRjvnoWCClFPcKyDHikbJTwZQKjJV0I80+jHKEDE0htgmMaiAEwkZihI69sdHBLLEvJZ3y5mSzdvhtO1Y5j1bGDhpsE7qnpMAhdTN0hB6ZKTVoYJgIG5kMPXTQocAyKQKHFP48jiq6OrWBoIGYQl9euguhI8b0r5K3L9R4+czsZEl/JGmVpPe4+3+fenxfSZdJOl7S1yX9orvftWidnLOR2dDOUxja60U7KT4rHL1bLub5GWNsF7ot1OcjRkf5rtUbizunI1abQr1/bBeHycxWSbpI0sslHS3pDDM7emqx10q6391/UNIfSnrHsvUSNgrR5074+LX19fUhLgJHXineG7YN/VBy4JCe6uDnCh6xn7/
0oEGdd8IJkm5z9zvcfbekKySdNrXMaZL+dPT1X0t6iZnZopWauwdvaSxm9jVJIT+tayXdF3B9IdG25kpuX+i2He7u3xdwfcFQr0UpuX0lt00K275i61UKXrND+ruGRtuaG8w+dszM/k4rr7uK/SQ9NvH9FnffMrGuV0o62d1fN/r+TEk/5u5nTyxzy2iZXaPvbx8tM/d979Q5G6H/4Ga21d03h1xnKLStuZLbV3LbQqNey1Fy+0pum1R++0IKWbOlv28lt4+2NVd6+2Jw95MDrm7WCMX0qESVZfbCNCoAAAAAuyQdNvH9Okn3zFvGzPaR9N2SvrFopYQNAAAAANdLOtLMNpjZakmnS7pqapmrJL169PUrJV3nS87J6NQ0qgi2LF8kG9rWXMntK7ltpSv5vSu5bVLZ7Su5bVL57StV6e9bye2jbc2V3r6iufseMztb0jVaufTtJe6+3cwukLTV3a+S9F5Jf2Zmt2llROP0Zevt1AniAAAAALqDaVQAAAAAoiBsAAAAAIiCsCHJzM4zMzezqtcpTsLMft/MdprZZ83s/5jZswpo08lmdquZ3WZmb87dnjEzO8zMPmJmO8xsu5mdk7tNs5jZKjP7jJn9be62dFmJNUu91tOFmqVew6Beqyu1ZqlXtDH4sGFmh0n6KUlfyN2WGa6VtMndXyjp85LekrMxFW9jn8seSW9096MkvUjSGwpq26RzJO3I3YguK7hmqdd6ulCz1GtL1Gt1hdcs9YrGBh82JP2hpDdpyQ1JcnD3v3f3PaNvP6WV6x3nVOU29lm4+5fdfdvo64e0ssE5NG+r9mZm6yT9rKT35G5LxxVZs9RrPaXXLPUaDPVaXbE1S72ijUGHDTM7VdKX3P2m3G2p4FclfShzGw6V9MWJ73epoI3NmJmtl3SspE/nbcnTvFsrO90nczekqzpUs9RrDYXWLPXaEvVaWydqlnpFXb2/z4aZfVjSc2Y89DZJb5X0srQt2tui9rn7B0fLvE0rQ5iXp2zbDLVvUZ+ama2R9H5J57r7g7nbM2Zmp0i6191vMLMX525PyUquWeo1vBJrlnqtjnoNqviapV7RRO/Dhru/dNbPzewFkjZIusnMpJUh1G1mdoK7fyV3+8bM7NWSTpH0kmV3aEygym3sszGzZ2hlI3i5u1+Zuz1TTpR0qpn9jKT9JB1kZn/u7r+cuV3FKblmqdewCq5Z6rUi6jWoomuWekVT3NRvxMzukrTZ3e/L3ZYxMztZ0rsk/YS7f62A9uyjlRPpXiLpS1q5rf2r3H171oZJspW92Z9K+oa7n5u7PYuMjryc5+6n5G5Ll5VWs9RrPV2pWeo1DOp1uZJrlnpFG4M+Z6MDLpR0oKRrzexGM7s4Z2NGJ9ONb2O/Q9L7StgIjpwo6UxJJ43eqxtHRzmAVKjXeqhZ5FRUvUrF1yz1isYY2QAAAAAQBSMbAAAAAKIgbAAAAACIgrABAAAAIArCBgAAAIAoCBsAAAAAoiBsAAAAAIiCsAEAAAAgCsIGKjOzHzWzz5rZfmb2TDPbbmabcrcLwNNRr0B3UK/oM27qh1rM7L9K2k/S/pJ2ufvvZW4SgDmoV6A7qFf0FWEDtZjZaknXS3pM0r9w929nbhKAOahXoDuoV/QV06hQ1/dIWiPpQK0cgQFQLuoV6A7qFb3EyAZqMbOrJF0haYOk57r72ZmbBGAO6hXoDuoVfbVP7gagO8zs30na4+5/YWarJH3SzE5y9+tytw3A3qhXoDuoV/QZIxsAAAAAouCcDQAAAABREDYAAAAAREHYAAAAABAFYQMAAABAFIQNAAAAAFEQNgAAAABEQdgAAAAAEMX/B0ZS1Y6hTyFxAAAAAElFTkSuQmCC\n", 95 | "text/plain": [ 96 | "
" 97 | ] 98 | }, 99 | "metadata": { 100 | "needs_background": "light" 101 | }, 102 | "output_type": "display_data" 103 | } 104 | ], 105 | "source": [ 106 | "fig, axes = plt.subplots(ncols=3, figsize=(15, 4))\n", 107 | "\n", 108 | "im0 = axes[0].contourf(x_v[:,:,0], y_v[:,:,0], u_v[:,:,0], cmap='coolwarm')\n", 109 | "axes[0].set_xlabel('x')\n", 110 | "axes[0].set_ylabel('y')\n", 111 | "axes[0].set_title('t = 0')\n", 112 | "\n", 113 | "im1 = axes[1].contourf(x_v[:,:,10], y_v[:,:,10], u_v[:,:,10], cmap='coolwarm')\n", 114 | "axes[1].set_xlabel('x')\n", 115 | "axes[1].set_title('t = 10')\n", 116 | "\n", 117 | "im2 = axes[2].contourf(x_v[:,:,20], y_v[:,:,20], u_v[:,:,20], cmap='coolwarm')\n", 118 | "axes[2].set_xlabel('x')\n", 119 | "axes[2].set_title('t= 20')\n", 120 | "\n", 121 | "fig.colorbar(im1, ax=axes.ravel().tolist())\n", 122 | "\n", 123 | "plt.show()" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "metadata": {}, 129 | "source": [ 130 | "We flatten it to give it the right dimensions for feeding it to the network:" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 32, 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "X = np.transpose((t_v.flatten(),x_v.flatten(), y_v.flatten()))\n", 140 | "y = np.float32(u_v.reshape((u_v.size, 1)))" 141 | ] 142 | }, 143 | { 144 | "cell_type": "markdown", 145 | "metadata": {}, 146 | "source": [ 147 | "We select the noise level we add to the data-set" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": 33, 153 | "metadata": {}, 154 | "outputs": [], 155 | "source": [ 156 | "noise_level = 0.01" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": 34, 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "y_noisy = y + noise_level * np.std(y) * np.random.randn(y.size, 1)" 166 | ] 167 | }, 168 | { 169 | "cell_type": "markdown", 170 | "metadata": {}, 171 | "source": [ 172 | "Select the number of samples:" 173 | ] 
174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 35, 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 | "number_of_samples = 1000" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": 8, 187 | "metadata": {}, 188 | "outputs": [], 189 | "source": [ 190 | "idx = np.random.permutation(y.size)\n", 191 | "X_train = X[idx, :][:number_of_samples]\n", 192 | "y_train = y_noisy[idx, :][:number_of_samples]" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 36, 198 | "metadata": {}, 199 | "outputs": [], 200 | "source": [ 201 | "number_of_samples = 1000\n", 202 | "\n", 203 | "idx = np.random.permutation(y.shape[0])\n", 204 | "X_train = torch.tensor(X[idx, :][:number_of_samples], dtype=torch.float32, requires_grad=True)\n", 205 | "y_train = torch.tensor(y[idx, :][:number_of_samples], dtype=torch.float32)" 206 | ] 207 | }, 208 | { 209 | "cell_type": "markdown", 210 | "metadata": {}, 211 | "source": [ 212 | "## Configure the neural network" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": 37, 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [ 221 | "## Running DeepMoD\n", 222 | "config = {'n_in': 3, 'hidden_dims': [20, 20, 20, 20, 20, 20], 'n_out': 1, 'library_function': library_2Din_1Dout, 'library_args':{'poly_order': 1, 'diff_order': 2}}" 223 | ] 224 | }, 225 | { 226 | "cell_type": "markdown", 227 | "metadata": {}, 228 | "source": [ 229 | "Now we instantiate the model:" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 38, 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "model = DeepMod(**config)\n", 239 | "optimizer = torch.optim.Adam([{'params': model.network_parameters(), 'lr':0.001}, {'params': model.coeff_vector(), 'lr':0.005}])" 240 | ] 241 | }, 242 | { 243 | "cell_type": "markdown", 244 | "metadata": {}, 245 | "source": [ 246 | "## Run DeepMoD " 247 | ] 248 | }, 249 | { 250 | "cell_type": "markdown", 
251 | "metadata": {}, 252 | "source": [ 253 | "We can now run DeepMoD using all the options we have set and the training data. We need to slightly preprocess the input data for the derivatives:" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": 41, 259 | "metadata": {}, 260 | "outputs": [ 261 | { 262 | "name": "stdout", 263 | "output_type": "stream", 264 | "text": [ 265 | "| Iteration | Progress | Time remaining | Cost | MSE | Reg | L1 |\n", 266 | " 25000 100.00% 0s 2.60e-05 2.09e-06 5.13e-06 1.88e-05 \n", 267 | "[Parameter containing:\n", 268 | "tensor([[0.2501],\n", 269 | " [0.4927],\n", 270 | " [0.4943],\n", 271 | " [0.4837]], requires_grad=True)]\n", 272 | "[tensor([1, 2, 3, 4])]\n", 273 | "\n", 274 | "| Iteration | Progress | Time remaining | Cost | MSE | Reg | L1 |\n", 275 | " 25000 100.00% 0s 1.36e-05 2.09e-06 1.15e-05 0.00e+00 " 276 | ] 277 | } 278 | ], 279 | "source": [ 280 | "train_deepmod(model, X_train, y_train, optimizer, 25000, {'l1': 1e-5})" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": null, 286 | "metadata": {}, 287 | "outputs": [], 288 | "source": [] 289 | } 290 | ], 291 | "metadata": { 292 | "kernelspec": { 293 | "display_name": "Python 3", 294 | "language": "python", 295 | "name": "python3" 296 | }, 297 | "language_info": { 298 | "codemirror_mode": { 299 | "name": "ipython", 300 | "version": 3 301 | }, 302 | "file_extension": ".py", 303 | "mimetype": "text/x-python", 304 | "name": "python", 305 | "nbconvert_exporter": "python", 306 | "pygments_lexer": "ipython3", 307 | "version": "3.7.6" 308 | } 309 | }, 310 | "nbformat": 4, 311 | "nbformat_minor": 4 312 | } 313 | -------------------------------------------------------------------------------- /examples/VE_datagen.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.integrate as integ 3 | 4 | #Data generation routines 5 | def calculate_strain_stress(input_type, 
time_array, input_expr, E_mods, viscs, D_input_lambda=None): 6 | # In this incarnation, the kwarg is non-optional. 7 | 8 | # if D_input_lambda: 9 | input_lambda = input_expr 10 | # else: 11 | # t = sym.symbols('t', real=True) 12 | # D_input_expr = input_expr.diff(t) 13 | 14 | # input_lambda = sym.lambdify(t, input_expr) 15 | # D_input_lambda = sym.lambdify(t, D_input_expr) 16 | 17 | # The following function interprets the provided model parameters differently depending on the input_type. 18 | # If the input_type is 'Strain' then the parameters are assumed to refer to a Maxwell model, whereas 19 | # if the input_type is 'Stress' then the parameters are assumed to refer to a Kelvin model. 20 | relax_creep_lambda = relax_creep(E_mods, viscs, input_type) 21 | 22 | if relax_creep_lambda == False: 23 | return False, False 24 | 25 | start_time_point = time_array[0] 26 | 27 | integrand_lambda = lambda x, t: relax_creep_lambda(t-x)*D_input_lambda(x) 28 | integral_lambda = lambda t: integ.quad(integrand_lambda, start_time_point, t, args=(t))[0] 29 | 30 | output_array = np.array([]) 31 | input_array = np.array([]) 32 | for time_point in time_array: 33 | first_term = input_lambda(start_time_point)*relax_creep_lambda(time_point-start_time_point) 34 | second_term = integral_lambda(time_point) 35 | output_array = np.append(output_array, first_term + second_term) 36 | input_array = np.append(input_array, input_lambda(time_point)) 37 | 38 | if input_type == 'Strain': 39 | strain_array = input_array 40 | stress_array = output_array 41 | else: 42 | strain_array = output_array 43 | stress_array = input_array 44 | 45 | strain_array = strain_array.reshape(time_array.shape) 46 | stress_array = stress_array.reshape(time_array.shape) 47 | 48 | return strain_array, stress_array 49 | 50 | 51 | def relax_creep(E_mods, viscs, input_type): 52 | 53 | # The following function interprets the provided model parameters differently depending on the input_type. 
54 | # If the input_type is 'Strain' then the parameters are assumed to refer to a Maxwell model, whereas 55 | # if the input_type is 'Stress' then the parameters are assumed to refer to a Kelvin model. 56 | # The equations used thus allow the data to be generated according to the model now designated. 57 | 58 | E_mods_1plus_array = np.array(E_mods[1:]).reshape(-1,1) 59 | viscs_array = np.array(viscs).reshape(-1,1) 60 | 61 | taus = viscs_array/E_mods_1plus_array 62 | 63 | if input_type == 'Strain': 64 | relax_creep_lambda = lambda t: E_mods[0] + np.sum(np.exp(-t/taus)*E_mods_1plus_array) 65 | elif input_type == 'Stress': 66 | relax_creep_lambda = lambda t: 1/E_mods[0] + np.sum((1-np.exp(-t/taus))/E_mods_1plus_array) 67 | else: 68 | print('Incorrect input_type') 69 | relax_creep_lambda = False 70 | 71 | return relax_creep_lambda -------------------------------------------------------------------------------- /examples/data/Advection_diffusion.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/examples/data/Advection_diffusion.mat -------------------------------------------------------------------------------- /examples/data/burgers.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/examples/data/burgers.npy -------------------------------------------------------------------------------- /examples/data/kdv.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/examples/data/kdv.npy -------------------------------------------------------------------------------- /examples/runs/.DS_Store: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/examples/runs/.DS_Store -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # ============================================================================= 2 | # DEPRECATION WARNING: 3 | # 4 | # The file `requirements.txt` does not influence the package dependencies and 5 | # will not be automatically created in the next version of PyScaffold (v4.x). 6 | # 7 | # Please have look at the docs for better alternatives 8 | # (`Dependency Management` section). 9 | # ============================================================================= 10 | # 11 | # Add your pinned requirements so that they can be easily installed with: 12 | # pip install -r requirements.txt 13 | # Remember to also add them in setup.cfg but unpinned. 14 | # Example: 15 | # numpy==1.13.3 16 | # scipy==1.0 17 | # 18 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | # This file is used to configure your project. 2 | # Read more about the various options under: 3 | # http://setuptools.readthedocs.io/en/latest/setuptools.html#configuring-setup-using-setup-cfg-files 4 | 5 | [metadata] 6 | name = DeePyMoD 7 | description = DeePyMoD is a PyTorch-based implementation of the DeepMoD algorithm for model discovery of PDEs. 
8 | author = Gert-Jan 9 | author-email = gert-jan.both@cri-paris.com 10 | license = mit 11 | long-description = file: README.md 12 | long-description-content-type = text/markdown; charset=UTF-8 13 | url = https://github.com/pyscaffold/pyscaffold/ 14 | project-urls = 15 | Documentation = https://pyscaffold.org/ 16 | # Change if running only on Windows, Mac or Linux (comma-separated) 17 | platforms = any 18 | # Add here all kinds of additional classifiers as defined under 19 | # https://pypi.python.org/pypi?%3Aaction=list_classifiers 20 | classifiers = 21 | Development Status :: 4 - Beta 22 | Programming Language :: Python 23 | 24 | [options] 25 | zip_safe = False 26 | packages = find: 27 | include_package_data = True 28 | package_dir = 29 | =src 30 | # DON'T CHANGE THE FOLLOWING LINE! IT WILL BE UPDATED BY PYSCAFFOLD! 31 | setup_requires = pyscaffold>=3.2a0,<3.3a0 32 | # Add here dependencies of your project (semicolon/line-separated), e.g. 33 | install_requires = numpy; torch>=1.3.* ; tensorboard 34 | # The usage of test_requires is discouraged, see `Dependency Management` docs 35 | # tests_require = pytest; pytest-cov 36 | # Require a specific Python version, e.g. 
Python 2.7 or >= 3.4 37 | python_requires = >= 3.6.* 38 | [options.packages.find] 39 | where = src 40 | exclude = 41 | tests 42 | 43 | [options.extras_require] 44 | # Add here additional requirements for extra features, to install with: 45 | # `pip install DeePyMoD_torch[PDF]` like: 46 | # PDF = ReportLab; RXP 47 | # Add here test requirements (semicolon/line-separated) 48 | testing = 49 | pytest 50 | pytest-cov 51 | 52 | [options.entry_points] 53 | # Add here console scripts like: 54 | # console_scripts = 55 | # script_name = deepymod_torch.module:function 56 | # For example: 57 | # console_scripts = 58 | # fibonacci = deepymod_torch.skeleton:run 59 | # And any other entry points, for example: 60 | # pyscaffold.cli = 61 | # awesome = pyscaffoldext.awesome.extension:AwesomeExtension 62 | 63 | [test] 64 | # py.test options when running `python setup.py test` 65 | # addopts = --verbose 66 | extras = True 67 | 68 | [tool:pytest] 69 | # Options for py.test: 70 | # Specify command line options as you would do when invoking py.test directly. 71 | # e.g. --cov-report html (or xml) for html/xml output or --junitxml junit.xml 72 | # in order to write a coverage file that can be read by Jenkins. 
73 | addopts = 74 | --cov deepymod_torch --cov-report term-missing 75 | --verbose 76 | norecursedirs = 77 | dist 78 | build 79 | .tox 80 | testpaths = tests 81 | 82 | [aliases] 83 | dists = bdist_wheel 84 | 85 | [bdist_wheel] 86 | # Use this option if your package is pure-python 87 | universal = 1 88 | 89 | [build_sphinx] 90 | source_dir = docs 91 | build_dir = docs/_build 92 | 93 | [devpi:upload] 94 | # Options for the devpi: PyPI server and packaging tool 95 | # VCS export must be deactivated since we are using setuptools-scm 96 | no-vcs = 1 97 | formats = bdist_wheel 98 | 99 | [flake8] 100 | # Some sane defaults for the code style checker flake8 101 | exclude = 102 | .tox 103 | build 104 | dist 105 | .eggs 106 | docs/conf.py 107 | 108 | [pyscaffold] 109 | # PyScaffold's parameters when the project was created. 110 | # This will be used when updating. Do not change! 111 | version = 3.2 112 | package = deepymod_torch 113 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Setup file for deepymod_torch. 4 | Use setup.cfg to configure your project. 5 | 6 | This file was generated with PyScaffold 3.2. 7 | PyScaffold helps you to put up the scaffold of your new Python project. 
8 | Learn more under: https://pyscaffold.org/ 9 | """ 10 | import sys 11 | 12 | from pkg_resources import require, VersionConflict 13 | from setuptools import setup 14 | 15 | try: 16 | require('setuptools>=38.3') 17 | except VersionConflict: 18 | print("Error: version of setuptools is too old (<38.3)!") 19 | sys.exit(1) 20 | 21 | 22 | if __name__ == "__main__": 23 | setup(use_pyscaffold=True) 24 | -------------------------------------------------------------------------------- /src/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/src/.DS_Store -------------------------------------------------------------------------------- /src/__pycache__/DeepMod.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/src/__pycache__/DeepMod.cpython-36.pyc -------------------------------------------------------------------------------- /src/__pycache__/DeepMod.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/src/__pycache__/DeepMod.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/library_function.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/src/__pycache__/library_function.cpython-36.pyc -------------------------------------------------------------------------------- /src/__pycache__/library_function.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/src/__pycache__/library_function.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/neural_net.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/src/__pycache__/neural_net.cpython-36.pyc -------------------------------------------------------------------------------- /src/__pycache__/neural_net.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/src/__pycache__/neural_net.cpython-37.pyc -------------------------------------------------------------------------------- /src/__pycache__/sparsity.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/src/__pycache__/sparsity.cpython-36.pyc -------------------------------------------------------------------------------- /src/__pycache__/sparsity.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/src/__pycache__/sparsity.cpython-37.pyc -------------------------------------------------------------------------------- /src/deepymod_torch/DeepMod.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from deepymod_torch.network import Fitting, Library 4 | 5 | 6 | class DeepMod(nn.Module): 7 | ''' Class based interface for deepmod.''' 8 | def __init__(self, n_in, hidden_dims, n_out, library_function, library_args): 9 | super().__init__() 
10 | self.network = self.build_network(n_in, hidden_dims, n_out) 11 | self.library = Library(library_function, library_args) 12 | self.fit = self.build_fit_layer(n_in, n_out, library_function, library_args) 13 | 14 | def forward(self, input): 15 | prediction = self.network(input) 16 | time_deriv, theta = self.library((prediction, input)) 17 | sparse_theta, coeff_vector = self.fit(theta) 18 | return prediction, time_deriv, sparse_theta, coeff_vector 19 | 20 | def build_network(self, n_in, hidden_dims, n_out): 21 | # NN 22 | network = [] 23 | hs = [n_in] + hidden_dims + [n_out] 24 | for h0, h1 in zip(hs, hs[1:]): # Hidden layers 25 | network.append(nn.Linear(h0, h1)) 26 | network.append(nn.Tanh()) 27 | network.pop() # get rid of last activation function 28 | network = nn.Sequential(*network) 29 | 30 | return network 31 | 32 | def build_fit_layer(self, n_in, n_out, library_function, library_args): 33 | sample_input = torch.ones((1, n_in), dtype=torch.float32, requires_grad=True) 34 | n_terms = self.library((self.network(sample_input), sample_input))[1].shape[1] # do sample pass to infer shapes 35 | fit_layer = Fitting(n_terms, n_out) 36 | 37 | return fit_layer 38 | 39 | # Function below make life easier 40 | def network_parameters(self): 41 | return self.network.parameters() 42 | 43 | def coeff_vector(self): 44 | return self.fit.coeff_vector.parameters() 45 | -------------------------------------------------------------------------------- /src/deepymod_torch/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from pkg_resources import get_distribution, DistributionNotFound 3 | 4 | try: 5 | # Change here if project is renamed and does not equal the package name 6 | dist_name = 'DeePyMoD_torch' 7 | __version__ = get_distribution(dist_name).version 8 | except DistributionNotFound: 9 | __version__ = 'unknown' 10 | finally: 11 | del get_distribution, DistributionNotFound 12 | 
-------------------------------------------------------------------------------- /src/deepymod_torch/library_functions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.autograd import grad 4 | from itertools import combinations, product 5 | from functools import reduce 6 | 7 | def library_poly(prediction, max_order): 8 | # Calculate the polynomes of u 9 | u = torch.ones_like(prediction) 10 | for order in np.arange(1, max_order+1): 11 | u = torch.cat((u, u[:, order-1:order] * prediction), dim=1) 12 | 13 | return u 14 | 15 | 16 | def library_deriv(data, prediction, max_order): 17 | dy = grad(prediction, data, grad_outputs=torch.ones_like(prediction), create_graph=True)[0] 18 | time_deriv = dy[:, 0:1] 19 | 20 | if max_order == 0: 21 | du = torch.ones_like(time_deriv) 22 | else: 23 | du = torch.cat((torch.ones_like(time_deriv), dy[:, 1:2]), dim=1) 24 | if max_order >1: 25 | for order in np.arange(1, max_order): 26 | du = torch.cat((du, grad(du[:, order:order+1], data, grad_outputs=torch.ones_like(prediction), create_graph=True)[0][:, 1:2]), dim=1) 27 | 28 | return time_deriv, du 29 | 30 | 31 | def library_1D_in(input, poly_order, diff_order): 32 | prediction, data = input 33 | poly_list = [] 34 | deriv_list = [] 35 | time_deriv_list = [] 36 | 37 | # Creating lists for all outputs 38 | for output in torch.arange(prediction.shape[1]): 39 | time_deriv, du = library_deriv(data, prediction[:, output:output+1], diff_order) 40 | u = library_poly(prediction[:, output:output+1], poly_order) 41 | 42 | poly_list.append(u) 43 | deriv_list.append(du) 44 | time_deriv_list.append(time_deriv) 45 | 46 | samples = time_deriv_list[0].shape[0] 47 | total_terms = poly_list[0].shape[1] * deriv_list[0].shape[1] 48 | 49 | # Calculating theta 50 | if len(poly_list) == 1: 51 | theta = torch.matmul(poly_list[0][:, :, None], deriv_list[0][:, None, :]).view(samples, total_terms) # If we have a single output, 
we simply calculate and flatten matrix product between polynomials and derivatives to get library 52 | else: 53 | 54 | theta_uv = reduce((lambda x, y: (x[:, :, None] @ y[:, None, :]).view(samples, -1)), poly_list) 55 | theta_dudv = torch.cat([torch.matmul(du[:, :, None], dv[:, None, :]).view(samples, -1)[:, 1:] for du, dv in combinations(deriv_list, 2)], 1) # calculate all unique combinations of derivatives 56 | theta_udu = torch.cat([torch.matmul(u[:, 1:, None], du[:, None, 1:]).view(samples, (poly_list[0].shape[1]-1) * (deriv_list[0].shape[1]-1)) for u, dv in product(poly_list, deriv_list)], 1) # calculate all unique products of polynomials and derivatives 57 | theta = torch.cat([theta_uv, theta_dudv, theta_udu], dim=1) 58 | 59 | return time_deriv_list, theta 60 | 61 | 62 | def library_2Din_1Dout(input, poly_order, diff_order): 63 | ''' 64 | Constructs a library graph in 1D. Library config is dictionary with required terms. 65 | ''' 66 | prediction, data = input 67 | # Polynomial 68 | 69 | u = torch.ones_like(prediction) 70 | for order in np.arange(1, poly_order+1): 71 | u = torch.cat((u, u[:, order-1:order] * prediction), dim=1) 72 | 73 | # Gradients 74 | du = grad(prediction, data, grad_outputs=torch.ones_like(prediction), create_graph=True)[0] 75 | u_t = du[:, 0:1] 76 | u_x = du[:, 1:2] 77 | u_y = du[:, 2:3] 78 | du2 = grad(u_x, data, grad_outputs=torch.ones_like(prediction), create_graph=True)[0] 79 | u_xx = du2[:, 1:2] 80 | u_xy = du2[:, 2:3] 81 | u_yy = grad(u_y, data, grad_outputs=torch.ones_like(prediction), create_graph=True)[0][:, 2:3] 82 | 83 | du = torch.cat((torch.ones_like(u_x), u_x, u_y , u_xx, u_yy, u_xy), dim=1) 84 | 85 | samples= du.shape[0] 86 | # Bringing it together 87 | theta = torch.matmul(u[:, :, None], du[:, None, :]).view(samples,-1) 88 | 89 | return [u_t], theta 90 | -------------------------------------------------------------------------------- /src/deepymod_torch/losses.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from deepymod_torch.sparsity import scaling 4 | 5 | 6 | def reg_loss(time_deriv_list, sparse_theta_list, coeff_vector_list): 7 | '''Loss function for the regularisation loss. Calculates loss for each term in list.''' 8 | loss = torch.stack([torch.mean((time_deriv - theta @ coeff_vector)**2) for time_deriv, theta, coeff_vector in zip(time_deriv_list, sparse_theta_list, coeff_vector_list)]) 9 | return loss 10 | 11 | def mse_loss(prediction, target): 12 | '''Loss functions for the MSE loss. Calculates loss for each term in list.''' 13 | loss = torch.mean((prediction - target)**2, dim=0) 14 | return loss 15 | 16 | def l1_loss(coeff_vector_list, l1): 17 | '''Loss functions for the L1 loss on the coefficients. Calculates loss for each term in list.''' 18 | loss = torch.stack([torch.sum(torch.abs(coeff_vector)) for coeff_vector in coeff_vector_list]) 19 | return l1 * loss 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /src/deepymod_torch/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Library(nn.Module): 6 | def __init__(self, library_func, library_args={}): 7 | super().__init__() 8 | self.library_func = library_func 9 | self.library_args = library_args 10 | 11 | def forward(self, input): 12 | time_deriv_list, theta = self.library_func(input, **self.library_args) 13 | return time_deriv_list, theta 14 | 15 | 16 | class Fitting(nn.Module): 17 | def __init__(self, n_terms, n_out): 18 | super().__init__() 19 | self.coeff_vector = nn.ParameterList([torch.nn.Parameter(torch.rand((n_terms, 1), dtype=torch.float32)) for _ in torch.arange(n_out)]) 20 | self.sparsity_mask = [torch.arange(n_terms) for _ in torch.arange(n_out)] 21 | 22 | def forward(self, input): 23 | sparse_theta 
= self.apply_mask(input) 24 | return sparse_theta, self.coeff_vector 25 | 26 | def apply_mask(self, theta): 27 | sparse_theta = [theta[:, sparsity_mask] for sparsity_mask in self.sparsity_mask] 28 | return sparse_theta 29 | -------------------------------------------------------------------------------- /src/deepymod_torch/output.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys, time 3 | import torch 4 | from torch.utils.tensorboard import SummaryWriter 5 | 6 | class Tensorboard(): 7 | '''Tensorboard class for logging during deepmod training. ''' 8 | def __init__(self, number_of_terms): 9 | self.writer = SummaryWriter() 10 | self.writer.add_custom_scalars(custom_board(number_of_terms)) 11 | 12 | def write(self, iteration, loss, loss_mse, loss_reg, loss_l1, coeff_vector_list, coeff_vector_scaled_list): 13 | # Logs losses, costs and coeff vectors. 14 | self.writer.add_scalar('Total loss', loss, iteration) 15 | for idx in range(len(loss_mse)): 16 | self.writer.add_scalar('MSE '+str(idx), loss_mse[idx], iteration) 17 | self.writer.add_scalar('Regression '+str(idx), loss_reg[idx], iteration) 18 | self.writer.add_scalar('L1 '+str(idx), loss_l1[idx], iteration) 19 | for element_idx, element in enumerate(torch.unbind(coeff_vector_list[idx])): # Tensorboard doesnt have vectors, so we unbind and plot them in together in custom board 20 | self.writer.add_scalar('coeff ' + str(idx) + ' ' + str(element_idx), element, iteration) 21 | for element_idx, element in enumerate(torch.unbind(coeff_vector_scaled_list[idx])): 22 | self.writer.add_scalar('scaled_coeff ' + str(idx) + ' ' + str(element_idx), element, iteration) 23 | 24 | def close(self): 25 | self.writer.close() 26 | 27 | def custom_board(number_of_terms): 28 | '''Custom scalar board for tensorboard.''' 29 | number_of_eqs = len(number_of_terms) 30 | # Initial setup, including all the costs and losses 31 | custom_board = {'Costs': {'MSE': ['Multiline', 
['MSE_' + str(idx) for idx in np.arange(number_of_eqs)]], 32 | 'Regression': ['Multiline', ['Regression_' + str(idx) for idx in np.arange(number_of_eqs)]], 33 | 'L1': ['Multiline', ['L1_' + str(idx) for idx in np.arange(number_of_eqs)]]}, 34 | 'Coefficients': {}, 35 | 'Scaled coefficients': {}} 36 | 37 | # Add plot of normal and scaled coefficients for each equation, containing every component in single plot. 38 | for idx in np.arange(number_of_eqs): 39 | custom_board['Coefficients']['Vector_' + str(idx)] = ['Multiline', ['coeff_' + str(idx) + '_' + str(element_idx) for element_idx in np.arange(number_of_terms[idx])]] 40 | custom_board['Scaled coefficients']['Vector_' + str(idx)] = ['Multiline', ['scaled_coeff_' + str(idx) + '_' + str(element_idx) for element_idx in np.arange(number_of_terms[idx])]] 41 | 42 | return custom_board 43 | 44 | def progress(iteration, start_time, max_iteration, cost, MSE, PI, L1): 45 | '''Prints and updates progress of training cycle in command line.''' 46 | percent = iteration.item()/max_iteration * 100 47 | elapsed_time = time.time() - start_time 48 | time_left = elapsed_time * (max_iteration/iteration - 1) if iteration.item() != 0 else 0 49 | sys.stdout.write(f"\r {iteration:>9} {percent:>7.2f}% {time_left:>13.0f}s {cost:>8.2e} {MSE:>8.2e} {PI:>8.2e} {L1:>8.2e} ") 50 | sys.stdout.flush() -------------------------------------------------------------------------------- /src/deepymod_torch/sparsity.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def scaling_single_vec(coeff_vector, sparse_theta, time_deriv): 4 | ''' 5 | Rescales the weight vector according to vec_rescaled = vec * |library|/|time_deriv|. 6 | Columns in library correspond to elements of weight_vector. 
7 | ''' 8 | scaling_time = torch.norm(time_deriv, dim=0) 9 | scaling_theta = torch.norm(sparse_theta, dim=0)[:, None] 10 | coeff_vector_scaled = coeff_vector * (scaling_theta / scaling_time) 11 | 12 | return coeff_vector_scaled 13 | 14 | def scaling(coeff_vector_list, sparse_theta_list, time_deriv_list): 15 | '''Wrapper around scaling_single_vec to scale multiple eqs. See scaling_single_vec for more details. ''' 16 | coeff_vector_scaled_list = [scaling_single_vec(coeff_vector, sparse_theta, time_deriv) for time_deriv, sparse_theta, coeff_vector in zip(time_deriv_list, sparse_theta_list, coeff_vector_list)] 17 | return coeff_vector_scaled_list 18 | 19 | def threshold_single(coeff_vector_scaled, coeff_vector): 20 | '''Removes coefficient if |value| < std(coefficient_vec) and returns new coefficient vector and sparsity mask. ''' 21 | sparse_coeff_vector = torch.where(torch.abs(coeff_vector_scaled) > torch.std(coeff_vector_scaled, dim=0), coeff_vector, torch.zeros_like(coeff_vector_scaled)) 22 | sparsity_mask = torch.nonzero(sparse_coeff_vector)[:, 0].detach() # detach it so it doesn't get optimized and throws an error 23 | sparse_coeff_vector = torch.nn.Parameter(sparse_coeff_vector[sparsity_mask].clone().detach()) 24 | 25 | return sparse_coeff_vector, sparsity_mask 26 | 27 | def threshold(coeff_vector_list, sparse_theta_list, time_deriv_list): 28 | '''Wrapper around threshold_single to threshold list of vectors. 
Also performs scaling.''' 29 | 30 | coeff_vector_scaled_list = scaling(coeff_vector_list, sparse_theta_list, time_deriv_list) 31 | result = [threshold_single(coeff_vector_scaled, coeff_vector) for coeff_vector_scaled, coeff_vector in zip(coeff_vector_scaled_list, coeff_vector_list)] 32 | sparse_coeff_vector_list, sparsity_mask_list = map(list, zip(*result)) 33 | 34 | return sparse_coeff_vector_list, sparsity_mask_list -------------------------------------------------------------------------------- /src/deepymod_torch/training.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | 4 | from deepymod_torch.output import Tensorboard, progress 5 | from deepymod_torch.losses import reg_loss, mse_loss, l1_loss 6 | from deepymod_torch.sparsity import scaling, threshold 7 | 8 | def train(model, data, target, optimizer, max_iterations, loss_func_args={'l1':1e-5}): 9 | '''Trains the deepmod model with MSE, regression and l1 cost function. 
Updates model in-place.''' 10 | start_time = time.time() 11 | number_of_terms = [coeff_vec.shape[0] for coeff_vec in model(data)[3]] 12 | board = Tensorboard(number_of_terms) 13 | 14 | # Training 15 | print('| Iteration | Progress | Time remaining | Cost | MSE | Reg | L1 |') 16 | for iteration in torch.arange(0, max_iterations + 1): 17 | # Calculating prediction and library and scaling 18 | prediction, time_deriv_list, sparse_theta_list, coeff_vector_list = model(data) 19 | coeff_vector_scaled_list = scaling(coeff_vector_list, sparse_theta_list, time_deriv_list) 20 | 21 | # Calculating loss 22 | loss_reg = reg_loss(time_deriv_list, sparse_theta_list, coeff_vector_list) 23 | loss_mse = mse_loss(prediction, target) 24 | loss_l1 = l1_loss(coeff_vector_scaled_list, loss_func_args['l1']) 25 | loss = torch.sum(loss_reg) + torch.sum(loss_mse) + torch.sum(loss_l1) 26 | 27 | # Writing 28 | if iteration % 100 == 0: 29 | progress(iteration, start_time, max_iterations, loss.item(), torch.sum(loss_mse).item(), torch.sum(loss_reg).item(), torch.sum(loss_l1).item()) 30 | board.write(iteration, loss, loss_mse, loss_reg, loss_l1, coeff_vector_list, coeff_vector_scaled_list) 31 | 32 | # Optimizer step 33 | optimizer.zero_grad() 34 | loss.backward() 35 | optimizer.step() 36 | board.close() 37 | 38 | def train_mse(model, data, target, optimizer, max_iterations, loss_func_args={}): 39 | '''Trains the deepmod model only on the MSE. 
Updates model in-place.''' 40 | start_time = time.time() 41 | number_of_terms = [coeff_vec.shape[0] for coeff_vec in model(data)[3]] 42 | board = Tensorboard(number_of_terms) 43 | 44 | # Training 45 | print('| Iteration | Progress | Time remaining | Cost | MSE | Reg | L1 |') 46 | for iteration in torch.arange(0, max_iterations + 1): 47 | # Calculating prediction and library and scaling 48 | prediction, time_deriv_list, sparse_theta_list, coeff_vector_list = model(data) 49 | coeff_vector_scaled_list = scaling(coeff_vector_list, sparse_theta_list, time_deriv_list) 50 | 51 | # Calculating loss 52 | loss_mse = mse_loss(prediction, target) 53 | loss = torch.sum(loss_mse) 54 | 55 | # Writing 56 | if iteration % 100 == 0: 57 | progress(iteration, start_time, max_iterations, loss.item(), torch.sum(loss_mse).item(), 0, 0) 58 | board.write(iteration, loss, loss_mse, [0], [0], coeff_vector_list, coeff_vector_scaled_list) 59 | 60 | # Optimizer step 61 | optimizer.zero_grad() 62 | loss.backward() 63 | optimizer.step() 64 | board.close() 65 | 66 | def train_deepmod(model, data, target, optimizer, max_iterations, loss_func_args): 67 | '''Performs full deepmod cycle: trains model, thresholds and trains again for unbiased estimate. 
Updates model in-place.''' 68 | # Train first cycle and get prediction 69 | train(model, data, target, optimizer, max_iterations, loss_func_args) 70 | prediction, time_deriv_list, sparse_theta_list, coeff_vector_list = model(data) 71 | 72 | # Threshold, set sparsity mask and coeff vector 73 | sparse_coeff_vector_list, sparsity_mask_list = threshold(coeff_vector_list, sparse_theta_list, time_deriv_list) 74 | model.fit.sparsity_mask = sparsity_mask_list 75 | model.fit.coeff_vector = torch.nn.ParameterList(sparse_coeff_vector_list) 76 | 77 | print() 78 | print(sparse_coeff_vector_list) 79 | print(sparsity_mask_list) 80 | 81 | #Resetting optimizer for different shapes, train without l1 82 | optimizer.param_groups[0]['params'] = model.parameters() 83 | print() #empty line for correct printing 84 | train(model, data, target, optimizer, max_iterations, dict(loss_func_args, **{'l1': 0.0})) 85 | 86 | -------------------------------------------------------------------------------- /src/deepymod_torch/utilities.py: -------------------------------------------------------------------------------- 1 | from itertools import product, combinations 2 | import sys 3 | import torch 4 | 5 | def string_matmul(list_1, list_2): 6 | ''' Matrix multiplication with strings.''' 7 | prod = [element[0] + element[1] for element in product(list_1, list_2)] 8 | return prod 9 | 10 | 11 | def terms_definition(poly_list, deriv_list): 12 | ''' Calculates which terms are in the library.''' 13 | if len(poly_list) == 1: 14 | theta = string_matmul(poly_list[0], deriv_list[0]) # If we have a single output, we simply calculate and flatten matrix product between polynomials and derivatives to get library 15 | else: 16 | theta_uv = list(chain.from_iterable([string_matmul(u, v) for u, v in combinations(poly_list, 2)])) # calculate all unique combinations between polynomials 17 | theta_dudv = list(chain.from_iterable([string_matmul(du, dv)[1:] for du, dv in combinations(deriv_list, 2)])) # calculate all unique 
combinations of derivatives 18 | theta_udu = list(chain.from_iterable([string_matmul(u[1:], du[1:]) for u, du in product(poly_list, deriv_list)])) # calculate all unique combinations of derivatives 19 | theta = theta_uv + theta_dudv + theta_udu 20 | return theta 21 | 22 | def create_deriv_data(X, max_order): 23 | ''' 24 | Automatically creates data-deriv tuple to feed to derivative network. 25 | Shape before network is (sample x order x input). 26 | Shape after network will be (sample x order x input x output). 27 | ''' 28 | 29 | if max_order == 1: 30 | dX = (torch.eye(X.shape[1]) * torch.ones(X.shape[0])[:, None, None])[:, None, :] 31 | else: 32 | dX = [torch.eye(X.shape[1]) * torch.ones(X.shape[0])[:, None, None]] 33 | dX.extend([torch.zeros_like(dX[0]) for order in range(max_order-1)]) 34 | dX = torch.stack(dX, dim=1) 35 | 36 | return (X, dX) 37 | 38 | -------------------------------------------------------------------------------- /tests/Untitled.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 25, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "The autoreload extension is already loaded. 
To reload it, use:\n", 13 | " %reload_ext autoreload\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "# Burgers, tests 1D input\n", 19 | "\n", 20 | "# General imports\n", 21 | "import numpy as np\n", 22 | "import torch\n", 23 | "\n", 24 | "# DeepMoD stuff\n", 25 | "from deepymod_torch.DeepMod import DeepMod\n", 26 | "from deepymod_torch.library_functions import library_1D_in\n", 27 | "from deepymod_torch.training import train_deepmod, train_mse\n", 28 | "\n", 29 | "# Setting cuda\n", 30 | "if torch.cuda.is_available():\n", 31 | " torch.set_default_tensor_type('torch.cuda.FloatTensor')\n", 32 | "\n", 33 | "# Settings for reproducibility\n", 34 | "np.random.seed(42)\n", 35 | "torch.manual_seed(0)\n", 36 | "torch.backends.cudnn.deterministic = True\n", 37 | "torch.backends.cudnn.benchmark = False\n", 38 | "\n", 39 | "%load_ext autoreload\n", 40 | "%autoreload 2" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 26, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "# Loading data\n", 50 | "data = np.load('data/burgers.npy', allow_pickle=True).item()\n", 51 | "X = np.transpose((data['t'].flatten(), data['x'].flatten()))\n", 52 | "y = np.real(data['u']).reshape((data['u'].size, 1))\n", 53 | "number_of_samples = 1000\n", 54 | "\n", 55 | "idx = np.random.permutation(y.size)\n", 56 | "X_train = torch.tensor(X[idx, :][:number_of_samples], dtype=torch.float32, requires_grad=True)\n", 57 | "y_train = torch.tensor(y[idx, :][:number_of_samples], dtype=torch.float32, requires_grad=True)" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 27, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "## Running DeepMoD\n", 67 | "config = {'input_dim': 2, 'hidden_dims': [20, 20, 20, 20, 20, 20], 'output_dim': 1, 'library_function': library_1D_in, 'library_args':{'poly_order': 2, 'diff_order': 2}}" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 28, 73 | "metadata": {}, 74 | "outputs": [], 75 | 
"source": [ 76 | "model = DeepMod(config)" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 29, 82 | "metadata": {}, 83 | "outputs": [ 84 | { 85 | "data": { 86 | "text/plain": [ 87 | "DeepMod(\n", 88 | " (network): Sequential(\n", 89 | " (0): Linear(in_features=2, out_features=20, bias=True)\n", 90 | " (1): Tanh()\n", 91 | " (2): Linear(in_features=20, out_features=20, bias=True)\n", 92 | " (3): Tanh()\n", 93 | " (4): Linear(in_features=20, out_features=20, bias=True)\n", 94 | " (5): Tanh()\n", 95 | " (6): Linear(in_features=20, out_features=20, bias=True)\n", 96 | " (7): Tanh()\n", 97 | " (8): Linear(in_features=20, out_features=20, bias=True)\n", 98 | " (9): Tanh()\n", 99 | " (10): Linear(in_features=20, out_features=20, bias=True)\n", 100 | " (11): Tanh()\n", 101 | " (12): Linear(in_features=20, out_features=1, bias=True)\n", 102 | " )\n", 103 | " (library): Library()\n", 104 | " (fit): Fitting(\n", 105 | " (coeff_vector): ParameterList( (0): Parameter containing: [torch.FloatTensor of size 9x1])\n", 106 | " )\n", 107 | ")" 108 | ] 109 | }, 110 | "execution_count": 29, 111 | "metadata": {}, 112 | "output_type": "execute_result" 113 | } 114 | ], 115 | "source": [ 116 | "model" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 30, 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "optimizer = torch.optim.Adam([{'params': model.network.parameters(), 'lr':0.002}, {'params': model.fit.parameters(), 'lr':0.002}])" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 32, 131 | "metadata": {}, 132 | "outputs": [ 133 | { 134 | "data": { 135 | "text/plain": [ 136 | "DeepMod(\n", 137 | " (network): Sequential(\n", 138 | " (0): Linear(in_features=2, out_features=20, bias=True)\n", 139 | " (1): Tanh()\n", 140 | " (2): Linear(in_features=20, out_features=20, bias=True)\n", 141 | " (3): Tanh()\n", 142 | " (4): Linear(in_features=20, out_features=20, bias=True)\n", 143 | " (5): Tanh()\n", 
144 | " (6): Linear(in_features=20, out_features=20, bias=True)\n", 145 | " (7): Tanh()\n", 146 | " (8): Linear(in_features=20, out_features=20, bias=True)\n", 147 | " (9): Tanh()\n", 148 | " (10): Linear(in_features=20, out_features=20, bias=True)\n", 149 | " (11): Tanh()\n", 150 | " (12): Linear(in_features=20, out_features=1, bias=True)\n", 151 | " )\n", 152 | " (library): Library()\n", 153 | " (fit): Fitting(\n", 154 | " (coeff_vector): ParameterList( (0): Parameter containing: [torch.FloatTensor of size 9x1])\n", 155 | " )\n", 156 | ")" 157 | ] 158 | }, 159 | "execution_count": 32, 160 | "metadata": {}, 161 | "output_type": "execute_result" 162 | } 163 | ], 164 | "source": [ 165 | "model" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 38, 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [ 174 | "def train(model, data, target, optimizer, max_iterations, loss_func_args):\n", 175 | " '''Trains the deepmod model with MSE, regression and l1 cost function. 
Updates model in-place.'''\n", 176 | " start_time = 0#time.time()\n", 177 | " number_of_terms = [9]#[coeff_vec.shape[0] for coeff_vec in model(data)[3]]\n", 178 | " #board = Tensorboard(number_of_terms)\n", 179 | "\n", 180 | " # Training\n", 181 | " print('| Iteration | Progress | Time remaining | Cost | MSE | Reg | L1 |')\n", 182 | " for iteration in torch.arange(0, max_iterations + 1):\n", 183 | " # Calculating prediction and library and scaling\n", 184 | " prediction, time_deriv_list, sparse_theta_list, coeff_vector_list = model(data)\n", 185 | " coeff_vector_scaled_list = scaling(coeff_vector_list, sparse_theta_list, time_deriv_list) \n", 186 | " \n", 187 | " # Calculating loss\n", 188 | " loss_reg = reg_loss(time_deriv_list, sparse_theta_list, coeff_vector_list)\n", 189 | " loss_mse = mse_loss(prediction, target)\n", 190 | " loss_l1 = l1_loss(coeff_vector_scaled_list, loss_func_args['l1'])\n", 191 | " loss = torch.sum(loss_reg) + torch.sum(loss_mse) + torch.sum(loss_l1)\n", 192 | " \n", 193 | " # Writing\n", 194 | " if iteration % 100 == 0:\n", 195 | " progress(iteration, start_time, max_iterations, loss.item(), torch.sum(loss_mse).item(), torch.sum(loss_reg).item(), torch.sum(loss_l1).item())\n", 196 | " #board.write(iteration, loss, loss_mse, loss_reg, loss_l1, coeff_vector_list, coeff_vector_scaled_list)\n", 197 | "\n", 198 | " # Optimizer step\n", 199 | " optimizer.zero_grad()\n", 200 | " loss.backward()\n", 201 | " optimizer.step()\n", 202 | " #board.close()" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 39, 208 | "metadata": {}, 209 | "outputs": [ 210 | { 211 | "name": "stdout", 212 | "output_type": "stream", 213 | "text": [ 214 | "| Iteration | Progress | Time remaining | Cost | MSE | Reg | L1 |\n" 215 | ] 216 | }, 217 | { 218 | "ename": "NameError", 219 | "evalue": "name 'scaling' is not defined", 220 | "output_type": "error", 221 | "traceback": [ 222 | 
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 223 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", 224 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1000\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m'l1'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m1e-5\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 225 | "\u001b[0;32m\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(model, data, target, optimizer, max_iterations, loss_func_args)\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;31m# Calculating prediction and library and scaling\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0mprediction\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtime_deriv_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msparse_theta_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcoeff_vector_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0mcoeff_vector_scaled_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mscaling\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcoeff_vector_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msparse_theta_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtime_deriv_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0;31m# Calculating 
loss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 226 | "\u001b[0;31mNameError\u001b[0m: name 'scaling' is not defined" 227 | ] 228 | } 229 | ], 230 | "source": [ 231 | "train(model, X_train, y_train, optimizer, 1000, {'l1':1e-5})" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": null, 237 | "metadata": {}, 238 | "outputs": [], 239 | "source": [] 240 | } 241 | ], 242 | "metadata": { 243 | "kernelspec": { 244 | "display_name": "Python 3", 245 | "language": "python", 246 | "name": "python3" 247 | }, 248 | "language_info": { 249 | "codemirror_mode": { 250 | "name": "ipython", 251 | "version": 3 252 | }, 253 | "file_extension": ".py", 254 | "mimetype": "text/x-python", 255 | "name": "python", 256 | "nbconvert_exporter": "python", 257 | "pygments_lexer": "ipython3", 258 | "version": "3.6.9" 259 | } 260 | }, 261 | "nbformat": 4, 262 | "nbformat_minor": 4 263 | } 264 | -------------------------------------------------------------------------------- /tests/burgers.py: -------------------------------------------------------------------------------- 1 | # Burgers, tests 1D input 2 | 3 | # General imports 4 | import numpy as np 5 | import torch 6 | 7 | # DeepMoD stuff 8 | from deepymod_torch.DeepMod import DeepMod 9 | from deepymod_torch.library_functions import library_1D_in 10 | from deepymod_torch.training import train_deepmod, train_mse 11 | 12 | # Setting cuda 13 | if torch.cuda.is_available(): 14 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 15 | 16 | # Settings for reproducibility 17 | np.random.seed(42) 18 | torch.manual_seed(0) 19 | torch.backends.cudnn.deterministic = True 20 | torch.backends.cudnn.benchmark = False 21 | 22 | # Loading data 23 | data = np.load('data/burgers.npy', allow_pickle=True).item() 24 | X = np.transpose((data['t'].flatten(), data['x'].flatten())) 25 | y = np.real(data['u']).reshape((data['u'].size, 1)) 26 | number_of_samples = 1000 27 | 28 | idx = 
np.random.permutation(y.size) 29 | X_train = torch.tensor(X[idx, :][:number_of_samples], dtype=torch.float32, requires_grad=True) 30 | y_train = torch.tensor(y[idx, :][:number_of_samples], dtype=torch.float32, requires_grad=True) 31 | 32 | ## Running DeepMoD 33 | config = {'n_in': 2, 'hidden_dims': [20, 20, 20, 20, 20, 20], 'n_out': 1, 'library_function': library_1D_in, 'library_args':{'poly_order': 2, 'diff_order': 2}} 34 | 35 | model = DeepMod(**config) 36 | optimizer = torch.optim.Adam([{'params': model.network_parameters(), 'lr':0.002}, {'params': model.coeff_vector(), 'lr':0.002}]) 37 | #train_mse(model, X_train, y_train, optimizer, 1000) 38 | train_deepmod(model, X_train, y_train, optimizer, 1000, {'l1': 1e-5}) 39 | 40 | 41 | print() 42 | print(model.fit.sparsity_mask) 43 | print(model.fit.coeff_vector) -------------------------------------------------------------------------------- /tests/data/burgers.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/tests/data/burgers.npy -------------------------------------------------------------------------------- /tests/data/keller_segel.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/tests/data/keller_segel.npy -------------------------------------------------------------------------------- /tests/diffusion.py: -------------------------------------------------------------------------------- 1 | # Diffusion, tests 2D input, D = 0.5 2 | 3 | # General imports 4 | import numpy as np 5 | import torch 6 | 7 | # DeepMoD stuff 8 | from deepymod_torch.DeepMod import DeepMod 9 | from deepymod_torch.library_functions import library_basic 10 | from deepymod_torch.utilities import create_deriv_data 11 | 12 | # Setting cuda 13 | if torch.cuda.is_available(): 14 | 
torch.set_default_tensor_type('torch.cuda.FloatTensor') 15 | 16 | # Settings for reproducibility 17 | np.random.seed(42) 18 | torch.manual_seed(0) 19 | torch.backends.cudnn.deterministic = True 20 | torch.backends.cudnn.benchmark = False 21 | 22 | # Loading data 23 | data = np.load('data/diffusion_2D.npy', allow_pickle=True).item() 24 | X = np.transpose((data['t'].flatten(), data['x'].flatten(), data['y'].flatten())) 25 | y = np.real(data['u']).reshape((data['u'].size, 1)) 26 | number_of_samples = 1000 27 | 28 | idx = np.random.permutation(y.size) 29 | X_train = torch.tensor(X[idx, :][:number_of_samples], dtype=torch.float32, requires_grad=True) 30 | y_train = torch.tensor(y[idx, :][:number_of_samples], dtype=torch.float32) 31 | 32 | ## Running DeepMoD 33 | config = {'input_dim': 3, 'hidden_dim': 20, 'layers': 5, 'output_dim': 1, 'library_function': library_basic, 'library_args':{'poly_order': 1, 'diff_order': 2}} 34 | 35 | X_input = create_deriv_data(X_train, config['library_args']['diff_order']) 36 | 37 | model = DeepMod(config) 38 | optimizer = torch.optim.Adam(model.parameters(), lr=0.002) 39 | model.train(X_input, y_train, optimizer, 5000, type='deepmod') 40 | 41 | print() 42 | print(model.sparsity_mask_list) 43 | print(model.coeff_vector_list) -------------------------------------------------------------------------------- /tests/keller_segel.py: -------------------------------------------------------------------------------- 1 | # Keller Segel, tests coupled output. 
# General imports
import numpy as np
import torch

# DeepMoD stuff
from deepymod_torch.DeepMod import DeepMod
from deepymod_torch.library_functions import library_basic
from deepymod_torch.utilities import create_deriv_data

# Setting cuda
if torch.cuda.is_available():
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

# Settings for reproducibility
np.random.seed(42)
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Loading data: u and v are the two coupled Keller-Segel fields.
data = np.load('data/keller_segel.npy', allow_pickle=True).item()
X = np.transpose((data['t'].flatten(), data['x'].flatten()))
y = np.transpose((data['u'].flatten(), data['v'].flatten()))
number_of_samples = 5000

# Random subsample of the full space-time grid.
idx = np.random.permutation(y.shape[0])
X_train = torch.tensor(X[idx, :][:number_of_samples], dtype=torch.float32)
y_train = torch.tensor(y[idx, :][:number_of_samples], dtype=torch.float32)

## Running DeepMoD
config = {'input_dim': 2, 'hidden_dim': 20, 'layers': 5, 'output_dim': 2, 'library_function': library_basic, 'library_args': {'poly_order': 1, 'diff_order': 2}}

X_input = create_deriv_data(X_train, config['library_args']['diff_order'])
# BUG FIX: the model was never constructed, so the two lines below raised
# NameError on 'model'. Construct it exactly as the sibling diffusion.py test does.
model = DeepMod(config)
optimizer = torch.optim.Adam(model.parameters())
model.train(X_input, y_train, optimizer, 100000, type='deepmod')

print()
print(model.sparsity_mask_list)
print(model.coeff_vector_list)
-------------------------------------------------------------------------------- /tests/runs/Mar18_10-16-15_2a6de4b656b9/events.out.tfevents.1584526575.2a6de4b656b9.90002.1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/tests/runs/Mar18_10-16-15_2a6de4b656b9/events.out.tfevents.1584526575.2a6de4b656b9.90002.1 -------------------------------------------------------------------------------- /tests/runs/Mar18_10-54-59_2a6de4b656b9/events.out.tfevents.1584528899.2a6de4b656b9.84437.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/tests/runs/Mar18_10-54-59_2a6de4b656b9/events.out.tfevents.1584528899.2a6de4b656b9.84437.0 -------------------------------------------------------------------------------- /tests/runs/Mar18_10-55-36_2a6de4b656b9/events.out.tfevents.1584528936.2a6de4b656b9.84887.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/tests/runs/Mar18_10-55-36_2a6de4b656b9/events.out.tfevents.1584528936.2a6de4b656b9.84887.0 -------------------------------------------------------------------------------- /tests/runs/Mar18_10-55-50_2a6de4b656b9/events.out.tfevents.1584528950.2a6de4b656b9.85113.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/tests/runs/Mar18_10-55-50_2a6de4b656b9/events.out.tfevents.1584528950.2a6de4b656b9.85113.0 -------------------------------------------------------------------------------- /tests/runs/Mar18_10-59-05_2a6de4b656b9/events.out.tfevents.1584529145.2a6de4b656b9.85113.1: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/tests/runs/Mar18_10-59-05_2a6de4b656b9/events.out.tfevents.1584529145.2a6de4b656b9.85113.1 -------------------------------------------------------------------------------- /tests/runs/Mar18_11-18-25_2a6de4b656b9/events.out.tfevents.1584530305.2a6de4b656b9.1129.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/tests/runs/Mar18_11-18-25_2a6de4b656b9/events.out.tfevents.1584530305.2a6de4b656b9.1129.0 -------------------------------------------------------------------------------- /tests/runs/Mar18_11-20-38_2a6de4b656b9/events.out.tfevents.1584530438.2a6de4b656b9.1129.1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/tests/runs/Mar18_11-20-38_2a6de4b656b9/events.out.tfevents.1584530438.2a6de4b656b9.1129.1 -------------------------------------------------------------------------------- /tests/runs/Mar18_11-21-06_2a6de4b656b9/events.out.tfevents.1584530466.2a6de4b656b9.3011.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/tests/runs/Mar18_11-21-06_2a6de4b656b9/events.out.tfevents.1584530466.2a6de4b656b9.3011.0 -------------------------------------------------------------------------------- /tests/runs/Mar18_11-23-42_2a6de4b656b9/events.out.tfevents.1584530622.2a6de4b656b9.3011.1: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/tests/runs/Mar18_11-23-42_2a6de4b656b9/events.out.tfevents.1584530622.2a6de4b656b9.3011.1 -------------------------------------------------------------------------------- /tests/runs/Mar18_11-25-57_2a6de4b656b9/events.out.tfevents.1584530757.2a6de4b656b9.6415.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/tests/runs/Mar18_11-25-57_2a6de4b656b9/events.out.tfevents.1584530757.2a6de4b656b9.6415.0 -------------------------------------------------------------------------------- /tests/runs/Mar18_11-28-57_2a6de4b656b9/events.out.tfevents.1584530937.2a6de4b656b9.6415.1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/tests/runs/Mar18_11-28-57_2a6de4b656b9/events.out.tfevents.1584530937.2a6de4b656b9.6415.1 -------------------------------------------------------------------------------- /tests/runs/Mar18_11-52-32_2a6de4b656b9/events.out.tfevents.1584532352.2a6de4b656b9.24739.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/tests/runs/Mar18_11-52-32_2a6de4b656b9/events.out.tfevents.1584532352.2a6de4b656b9.24739.0 -------------------------------------------------------------------------------- /tests/runs/Mar18_11-58-06_2a6de4b656b9/events.out.tfevents.1584532686.2a6de4b656b9.28482.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/tests/runs/Mar18_11-58-06_2a6de4b656b9/events.out.tfevents.1584532686.2a6de4b656b9.28482.0 
-------------------------------------------------------------------------------- /tests/runs/Mar18_12-57-02_2a6de4b656b9/events.out.tfevents.1584536222.2a6de4b656b9.65888.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/tests/runs/Mar18_12-57-02_2a6de4b656b9/events.out.tfevents.1584536222.2a6de4b656b9.65888.0 -------------------------------------------------------------------------------- /tests/runs/Mar18_12-59-42_2a6de4b656b9/events.out.tfevents.1584536382.2a6de4b656b9.65888.1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/tests/runs/Mar18_12-59-42_2a6de4b656b9/events.out.tfevents.1584536382.2a6de4b656b9.65888.1 -------------------------------------------------------------------------------- /tests/runs/Mar18_13-00-54_2a6de4b656b9/events.out.tfevents.1584536454.2a6de4b656b9.68447.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/tests/runs/Mar18_13-00-54_2a6de4b656b9/events.out.tfevents.1584536454.2a6de4b656b9.68447.0 -------------------------------------------------------------------------------- /tests/runs/Mar18_13-04-42_2a6de4b656b9/events.out.tfevents.1584536682.2a6de4b656b9.68447.1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/tests/runs/Mar18_13-04-42_2a6de4b656b9/events.out.tfevents.1584536682.2a6de4b656b9.68447.1 -------------------------------------------------------------------------------- /tests/runs/Mar18_13-45-28_2a6de4b656b9/events.out.tfevents.1584539128.2a6de4b656b9.97639.0: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/tests/runs/Mar18_13-45-28_2a6de4b656b9/events.out.tfevents.1584539128.2a6de4b656b9.97639.0 -------------------------------------------------------------------------------- /tests/runs/Mar18_13-47-36_2a6de4b656b9/events.out.tfevents.1584539256.2a6de4b656b9.99060.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/tests/runs/Mar18_13-47-36_2a6de4b656b9/events.out.tfevents.1584539256.2a6de4b656b9.99060.0 -------------------------------------------------------------------------------- /tests/runs/Mar18_13-50-01_2a6de4b656b9/events.out.tfevents.1584539401.2a6de4b656b9.1039.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/tests/runs/Mar18_13-50-01_2a6de4b656b9/events.out.tfevents.1584539401.2a6de4b656b9.1039.0 -------------------------------------------------------------------------------- /tests/runs/Mar18_13-51-19_2a6de4b656b9/events.out.tfevents.1584539479.2a6de4b656b9.1935.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/tests/runs/Mar18_13-51-19_2a6de4b656b9/events.out.tfevents.1584539479.2a6de4b656b9.1935.0 -------------------------------------------------------------------------------- /tests/runs/Mar18_14-04-33_2a6de4b656b9/events.out.tfevents.1584540273.2a6de4b656b9.10769.0: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/tests/runs/Mar18_14-04-33_2a6de4b656b9/events.out.tfevents.1584540273.2a6de4b656b9.10769.0 -------------------------------------------------------------------------------- /tests/runs/Mar18_14-04-57_2a6de4b656b9/events.out.tfevents.1584540297.2a6de4b656b9.11078.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/tests/runs/Mar18_14-04-57_2a6de4b656b9/events.out.tfevents.1584540297.2a6de4b656b9.11078.0 -------------------------------------------------------------------------------- /tests/runs/Mar18_14-05-40_2a6de4b656b9/events.out.tfevents.1584540340.2a6de4b656b9.99504.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/tests/runs/Mar18_14-05-40_2a6de4b656b9/events.out.tfevents.1584540340.2a6de4b656b9.99504.0 -------------------------------------------------------------------------------- /tests/runs/Mar18_14-06-48_2a6de4b656b9/events.out.tfevents.1584540408.2a6de4b656b9.99504.1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/tests/runs/Mar18_14-06-48_2a6de4b656b9/events.out.tfevents.1584540408.2a6de4b656b9.99504.1 -------------------------------------------------------------------------------- /tests/runs/Mar18_14-07-09_2a6de4b656b9/events.out.tfevents.1584540429.2a6de4b656b9.99504.2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/tests/runs/Mar18_14-07-09_2a6de4b656b9/events.out.tfevents.1584540429.2a6de4b656b9.99504.2 
-------------------------------------------------------------------------------- /tests/runs/Mar18_14-09-01_2a6de4b656b9/events.out.tfevents.1584540541.2a6de4b656b9.13828.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/tests/runs/Mar18_14-09-01_2a6de4b656b9/events.out.tfevents.1584540541.2a6de4b656b9.13828.0 -------------------------------------------------------------------------------- /tests/runs/Mar18_14-09-39_2a6de4b656b9/events.out.tfevents.1584540579.2a6de4b656b9.13828.1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/tests/runs/Mar18_14-09-39_2a6de4b656b9/events.out.tfevents.1584540579.2a6de4b656b9.13828.1 -------------------------------------------------------------------------------- /tests/runs/Mar18_14-20-25_2a6de4b656b9/events.out.tfevents.1584541225.2a6de4b656b9.22220.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/tests/runs/Mar18_14-20-25_2a6de4b656b9/events.out.tfevents.1584541225.2a6de4b656b9.22220.0 -------------------------------------------------------------------------------- /tests/runs/Mar18_14-21-03_2a6de4b656b9/events.out.tfevents.1584541263.2a6de4b656b9.22220.1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/tests/runs/Mar18_14-21-03_2a6de4b656b9/events.out.tfevents.1584541263.2a6de4b656b9.22220.1 -------------------------------------------------------------------------------- /tests/runs/Mar18_14-25-46_2a6de4b656b9/events.out.tfevents.1584541546.2a6de4b656b9.26142.0: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/tests/runs/Mar18_14-25-46_2a6de4b656b9/events.out.tfevents.1584541546.2a6de4b656b9.26142.0 -------------------------------------------------------------------------------- /tests/runs/Mar18_14-26-20_2a6de4b656b9/events.out.tfevents.1584541580.2a6de4b656b9.26142.1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/tests/runs/Mar18_14-26-20_2a6de4b656b9/events.out.tfevents.1584541580.2a6de4b656b9.26142.1 -------------------------------------------------------------------------------- /tests/runs/Mar18_14-49-13_2a6de4b656b9/events.out.tfevents.1584542953.2a6de4b656b9.41183.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PhIMaL/DeePyMoD_torch/b39dff541d641369e9fde5136389ee589ce8f29f/tests/runs/Mar18_14-49-13_2a6de4b656b9/events.out.tfevents.1584542953.2a6de4b656b9.41183.0 --------------------------------------------------------------------------------