├── Dockerfile ├── experiments ├── dataloader.py ├── visualization.py ├── rnn.py ├── hidden_physics.py ├── direct_solution.py ├── time_stepper.py └── latent_neural_odes.py ├── .gitignore ├── requirements.txt ├── README.md └── docs ├── survey_structure.svg └── time_vs_state_plot.svg /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.8-slim-buster 2 | 3 | WORKDIR /app 4 | 5 | COPY requirements.txt requirements.txt 6 | 7 | RUN apt-get update && apt-get install build-essential -y 8 | RUN pip3 install --upgrade pip 9 | RUN pip3 install --upgrade pip setuptools wheel 10 | RUN pip3 install -r requirements.txt 11 | 12 | COPY . . -------------------------------------------------------------------------------- /experiments/dataloader.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import numpy as np 3 | import torch 4 | 5 | from enum import Enum 6 | from smt.sampling_methods import LHS 7 | from torchdyn.numerics import odeint 8 | 9 | class Sampling(Enum): 10 | RANDOM = 0 11 | GRID = 1 12 | 13 | 14 | def grid_init_samples(domain, n_trajectories: int) -> np.ndarray: 15 | x = np.linspace(domain[0][0], domain[0][1], n_trajectories) 16 | y = np.linspace(domain[1][0], domain[1][1], n_trajectories) 17 | 18 | xx, yy = np.meshgrid(x, y) 19 | return np.concatenate((xx.flatten()[..., np.newaxis], yy.flatten()[..., np.newaxis]), axis=1) 20 | 21 | 22 | def random_init_samples(domain, n_trajectories: int) -> np.ndarray: 23 | values = LHS(xlimits=np.array(domain)) 24 | return values(n_trajectories) 25 | 26 | 27 | def pendulum(t, y): 28 | θ = y[:, 0] 29 | ω = y[:, 1] 30 | 31 | dθ = ω 32 | dω = -torch.sin(θ) 33 | 34 | return torch.stack((dθ, dω), dim=1) 35 | 36 | 37 | def load_pendulum_data(t_span, y0s_domain=None, n_trajectories=100, sampling=Sampling.RANDOM, solver='rk4') -> Tuple[torch.Tensor, torch.Tensor]: 38 | if not y0s_domain: 39 | y0s_domain = [[-1., 
1.], [-1., 1.]] 40 | 41 | if sampling == Sampling.RANDOM: 42 | y0s = random_init_samples(y0s_domain, n_trajectories) 43 | elif sampling == Sampling.GRID: 44 | y0s = grid_init_samples(y0s_domain, n_trajectories) 45 | 46 | y0s = torch.tensor(y0s) 47 | _, ys = odeint(pendulum, y0s, t_span, solver) 48 | 49 | return y0s, ys 50 | -------------------------------------------------------------------------------- /experiments/visualization.py: -------------------------------------------------------------------------------- 1 | from matplotlib import cm, pyplot as plt 2 | from matplotlib.collections import LineCollection 3 | import torch 4 | import numpy as np 5 | 6 | 7 | def get_meshgrid(step_per_axis, domain=1.0, flatten=False, dtype=torch.float64): 8 | ε = 1e-10 9 | x = torch.arange(-domain, domain + ε, step_per_axis, dtype=dtype) 10 | y = x = torch.arange(-domain, domain + ε, step_per_axis, dtype=dtype) 11 | 12 | xy = torch.stack(torch.meshgrid((x, y))) 13 | if flatten: 14 | return xy.flatten(1).T 15 | else: 16 | return xy.T 17 | 18 | 19 | def plot_colored(fig, ax, t, x, cmap="jet", label=None, colorbar=False, **kwargs): 20 | """ 21 | t : (time x traj) 22 | x : (time x traj x state) 23 | """ 24 | # x = torch.tensor(x) 25 | # t = torch.tensor(t) 26 | x = torch.as_tensor(x) 27 | t = torch.as_tensor(t) 28 | 29 | if t.ndim == 1: 30 | t = t.unsqueeze(-1).expand(-1, x.shape[1]) 31 | 32 | norm = plt.Normalize(t.min(), t.max()) 33 | 34 | for i in range(t.shape[1]): 35 | xi = x[:, i] 36 | ti = t[:, i] 37 | segments = torch.stack([xi[:-1], xi[1:]], axis=1) 38 | 39 | if i == t.shape[1] - 1: 40 | lc = LineCollection(segments, cmap=cmap, norm=norm, label=label, **kwargs) 41 | else: 42 | lc = LineCollection(segments, cmap=cmap, norm=norm, **kwargs) 43 | lc.set_array(ti) 44 | ax.add_collection(lc) 45 | 46 | if colorbar: 47 | fig.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax) 48 | ax.set_xlim(x.min(), x.max()) 49 | ax.set_ylim(x.min(), x.max()) 50 | 
-------------------------------------------------------------------------------- /experiments/rnn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from matplotlib import pyplot as plt 5 | 6 | from tqdm import tqdm 7 | from torch import nn 8 | from torch.nn.functional import mse_loss 9 | from torchdyn.numerics import odeint 10 | 11 | from dataloader import Sampling, load_pendulum_data 12 | 13 | def run():  # fit an nn.RNN-based vector field to one-step pendulum data, then roll it out from a grid of initial states and plot 14 | """ 15 | Training data 16 | 17 | """ 18 | y0s_domain = [[-1., 1.], [-1., 1.]] 19 | n_steps = 1 20 | step_size = 0.01 21 | t_span = torch.arange(0., step_size * (n_steps + 1), step_size)  # intended: n_steps + 1 time points starting at 0 22 | 23 | y0s, ys = load_pendulum_data(t_span, y0s_domain, n_trajectories=20, sampling=Sampling.RANDOM) 24 | 25 | """ 26 | Network 27 | 28 | """ 29 | class RNN(nn.Module): 30 | 31 | def __init__(self, n_states): 32 | super().__init__() 33 | self.rnn = nn.RNN(n_states, n_states, batch_first=True).double() 34 | 35 | def forward(self, y0s): 36 | y0s = y0s.unsqueeze(dim=1) 37 | out, hidden = self.rnn(y0s, torch.swapaxes(y0s, 0, 1))  # the state is both the single input step and h_0 (nn.RNN expects h_0 as (num_layers, batch, hidden) even with batch_first) 38 | return hidden.squeeze(dim=0)  # final hidden state, one row per trajectory 39 | 40 | 41 | model = RNN(n_states=2) 42 | opt = torch.optim.Adam(model.parameters()) 43 | 44 | """ 45 | Training 46 | 47 | """ 48 | epochs = 20 49 | progress = tqdm(range(epochs), 'Training') 50 | losses = [] 51 | 52 | for _ in progress: 53 | _, y_pred = odeint(lambda t, y: model(y), y0s, t_span, 'euler')  # the RNN output is used as dy/dt and integrated with explicit Euler 54 | 55 | loss = mse_loss(y_pred, ys) 56 | loss.backward() 57 | opt.step() 58 | opt.zero_grad() 59 | 60 | losses.append(loss.item()) 61 | progress.set_description(f'loss: {loss.item()}') 62 | 63 | """ 64 | Test data 65 | 66 | """ 67 | y0s_domain = [[-1., 1.], [-1., 1.]] 68 | n_steps = 20 69 | step_size = 0.01 70 | t_span = torch.arange(0., step_size * (n_steps + 1), step_size) 71 | 72 | y0s, ys = load_pendulum_data(t_span, y0s_domain, n_trajectories=10, sampling=Sampling.GRID)  # GRID uses 10 points per axis, i.e. 100 initial states 73 | _, y_pred = odeint(lambda t, y: model(y), y0s, t_span,
'euler') 74 | 75 | """ 76 | Plot 77 | 78 | """ 79 | plt.plot(y_pred.detach().numpy()[:, :, 1], y_pred.detach().numpy()[:, :, 0], color='r')  # predicted trajectories (red) in the (ω, θ) plane 80 | plt.plot(ys.numpy()[:, :, 1], ys.numpy()[:, :, 0], color='b')  # reference trajectories (blue) 81 | plt.scatter(ys[0, :, 1], ys[0, :, 0], color='g')  # initial states (green) 82 | plt.ylim(y0s_domain[0]) 83 | plt.xlim(y0s_domain[1]) 84 | plt.show() 85 | 86 | 87 | 88 | if __name__ == "__main__": 89 | run() 90 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.13.0 2 | aiohttp==3.7.4.post0 3 | argon2-cffi==20.1.0 4 | async-timeout==3.0.1 5 | attrs==21.2.0 6 | backcall==0.2.0 7 | bleach==4.0.0 8 | boltons==21.0.0 9 | cachetools==4.2.2 10 | certifi==2021.5.30 11 | cffi==1.14.6 12 | chardet==4.0.0 13 | charset-normalizer==2.0.4 14 | cycler==0.10.0 15 | debugpy==1.4.1 16 | decorator==5.0.9 17 | defusedxml==0.7.1 18 | einops==0.3.0 19 | entrypoints==0.3 20 | fsspec==2021.7.0 21 | future==0.18.2 22 | google-auth==1.35.0 23 | google-auth-oauthlib==0.4.5 24 | grpcio==1.39.0 25 | idna==3.2 26 | ipykernel==6.2.0 27 | ipython==7.26.0 28 | ipython-genutils==0.2.0 29 | ipywidgets==7.6.3 30 | jedi==0.18.0 31 | Jinja2==3.0.1 32 | joblib==1.0.1 33 | jsonschema==3.2.0 34 | jupyter-client==6.1.12 35 | jupyter-core==4.7.1 36 | jupyterlab-pygments==0.1.2 37 | jupyterlab-widgets==1.0.0 38 | kiwisolver==1.3.1 39 | Markdown==3.3.4 40 | MarkupSafe==2.0.1 41 | matplotlib==3.4.3 42 | matplotlib-inline==0.1.2 43 | mistune==0.8.4 44 | multidict==5.1.0 45 | nbclient==0.5.4 46 | nbconvert==6.1.0 47 | nbformat==5.1.3 48 | nest-asyncio==1.5.1 49 | notebook==6.4.3 50 | numpy==1.21.2 51 | oauthlib==3.1.1 52 | packaging==21.0 53 | pandocfilters==1.4.3 54 | parso==0.8.2 55 | 
pastel==0.2.1 56 | pexpect==4.8.0 57 | pickleshare==0.7.5 58 | Pillow==8.3.1 59 | poethepoet==0.10.0 60 | prometheus-client==0.11.0 61 | prompt-toolkit==3.0.19 62 | protobuf==3.17.3 63 | ptyprocess==0.7.0 64 | pyasn1==0.4.8 65 | pyasn1-modules==0.2.8 66 | pycparser==2.20 67 | pyDeprecate==0.3.1 68 | pyDOE==0.3.8 69 | Pygments==2.10.0 70 | pyparsing==2.4.7 71 | pyrsistent==0.18.0 72 | python-dateutil==2.8.2 73 | pytorch-lightning==1.4.2 74 | PyYAML==5.4.1 75 | pyzmq==22.2.1 76 | requests==2.26.0 77 | requests-oauthlib==1.3.0 78 | rsa==4.7.2 79 | scikit-learn==0.24.2 80 | scipy==1.7.1 81 | Send2Trash==1.8.0 82 | six==1.16.0 83 | sklearn==0.0 84 | tensorboard==2.6.0 85 | tensorboard-data-server==0.6.1 86 | tensorboard-plugin-wit==1.8.0 87 | terminado==0.11.0 88 | testpath==0.5.0 89 | threadpoolctl==2.2.0 90 | tomlkit==0.7.2 91 | torch==1.9.0 92 | torchaudio==0.9.0 93 | torchcde==0.2.3 94 | torchdiffeq==0.2.2 95 | torchdyn==1.0.1 96 | torchmetrics==0.5.0 97 | torchsde==0.2.5 98 | torchvision==0.10.0 99 | tornado==6.1 100 | tqdm==4.62.1 101 | traitlets==5.0.5 102 | trampoline==0.1.2 103 | typing-extensions==3.10.0.0 104 | urllib3==1.26.6 105 | wcwidth==0.2.5 106 | webencodings==0.5.1 107 | Werkzeug==2.0.1 108 | widgetsnbextension==3.5.1 109 | yarl==1.6.3 110 | absl-py==0.13.0 111 | aiohttp==3.7.4.post0 112 | alabaster==0.7.12 113 | appdirs==1.4.4 114 | argon2-cffi==20.1.0 115 | async-timeout==3.0.1 116 | attrs==21.2.0 117 | Babel==2.9.1 118 | backcall==0.2.0 119 | black==21.7b0 120 | bleach==4.0.0 121 | boltons==21.0.0 122 | cachetools==4.2.2 123 | certifi==2021.5.30 124 | cffi==1.14.6 125 | chardet==4.0.0 126 | charset-normalizer==2.0.4 127 | click==8.0.1 128 | cycler==0.10.0 129 | debugpy==1.4.1 130 | decorator==5.0.9 131 | defusedxml==0.7.1 132 | docutils==0.17.1 133 | einops==0.3.0 134 | entrypoints==0.3 135 | fsspec==2021.7.0 136 | future==0.18.2 137 | google-auth==1.35.0 138 | google-auth-oauthlib==0.4.5 139 | grpcio==1.39.0 140 | idna==3.2 141 | imagesize==1.2.0 
142 | ipykernel==6.2.0 143 | ipython==7.26.0 144 | ipython-genutils==0.2.0 145 | ipywidgets==7.6.3 146 | jedi==0.18.0 147 | Jinja2==3.0.1 148 | joblib==1.0.1 149 | jsonschema==3.2.0 150 | jupyter-client==6.1.12 151 | jupyter-core==4.7.1 152 | jupyterlab-pygments==0.1.2 153 | jupyterlab-widgets==1.0.0 154 | kiwisolver==1.3.1 155 | Markdown==3.3.4 156 | MarkupSafe==2.0.1 157 | matplotlib==3.4.3 158 | matplotlib-inline==0.1.2 159 | mistune==0.8.4 160 | multidict==5.1.0 161 | mypy-extensions==0.4.3 162 | nbclient==0.5.4 163 | nbconvert==6.1.0 164 | nbformat==5.1.3 165 | nest-asyncio==1.5.1 166 | notebook==6.4.3 167 | numpy==1.21.2 168 | numpydoc==1.1.0 169 | oauthlib==3.1.1 170 | packaging==21.0 171 | pandocfilters==1.4.3 172 | parso==0.8.2 173 | pastel==0.2.1 174 | pathspec==0.9.0 175 | pexpect==4.8.0 176 | pickleshare==0.7.5 177 | Pillow==8.3.1 178 | poethepoet==0.10.0 179 | prometheus-client==0.11.0 180 | prompt-toolkit==3.0.19 181 | protobuf==3.17.3 182 | ptyprocess==0.7.0 183 | pyasn1==0.4.8 184 | pyasn1-modules==0.2.8 185 | pycparser==2.20 186 | pyDeprecate==0.3.1 187 | pyDOE==0.3.8 188 | pyDOE2==1.3.0 189 | Pygments==2.10.0 190 | pyparsing==2.4.7 191 | pyrsistent==0.18.0 192 | python-dateutil==2.8.2 193 | pytorch-lightning==1.4.2 194 | pytz==2021.1 195 | PyYAML==5.4.1 196 | pyzmq==22.2.1 197 | regex==2021.8.3 198 | requests==2.26.0 199 | requests-oauthlib==1.3.0 200 | rsa==4.7.2 201 | scikit-learn==0.24.2 202 | scipy==1.7.1 203 | Send2Trash==1.8.0 204 | six==1.16.0 205 | sklearn==0.0 206 | smt==1.0.0 207 | snowballstemmer==2.1.0 208 | Sphinx==4.1.2 209 | sphinxcontrib-applehelp==1.0.2 210 | sphinxcontrib-devhelp==1.0.2 211 | sphinxcontrib-htmlhelp==2.0.0 212 | sphinxcontrib-jsmath==1.0.1 213 | sphinxcontrib-qthelp==1.0.3 214 | sphinxcontrib-serializinghtml==1.1.5 215 | tensorboard==2.6.0 216 | tensorboard-data-server==0.6.1 217 | tensorboard-plugin-wit==1.8.0 218 | terminado==0.11.0 219 | testpath==0.5.0 220 | threadpoolctl==2.2.0 221 | tomli==1.2.1 222 | 
tomlkit==0.7.2 223 | torch==1.9.0 224 | torchaudio==0.9.0 225 | torchcde==0.2.3 226 | torchdiffeq==0.2.2 227 | torchdyn==1.0.1 228 | torchmetrics==0.5.0 229 | torchsde==0.2.5 230 | torchvision==0.10.0 231 | tornado==6.1 232 | tqdm==4.62.1 233 | traitlets==5.0.5 234 | trampoline==0.1.2 235 | typing-extensions==3.10.0.0 236 | urllib3==1.26.6 237 | wcwidth==0.2.5 238 | webencodings==0.5.1 239 | Werkzeug==2.0.1 240 | widgetsnbextension==3.5.1 241 | yarl==1.6.3 242 | -------------------------------------------------------------------------------- /experiments/hidden_physics.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from collections import defaultdict 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from scipy.integrate import solve_ivp 10 | from torch.autograd import grad 11 | from tqdm import tqdm 12 | 13 | 14 | def listify(A): 15 | return [a for a in A.flatten()] 16 | 17 | 18 | def xavier_init(module): 19 | for m in module.modules(): 20 | if type(m) == nn.Linear: 21 | nn.init.xavier_uniform_(m.weight) 22 | 23 | 24 | def plot_loss(losses): 25 | fig, ax = plt.subplots() 26 | fig.canvas.manager.set_window_title(f"loss terms") 27 | 28 | for loss_name, loss in losses.items(): 29 | ax.plot(loss, label=loss_name) 30 | 31 | ax.legend() 32 | ax.set_xlabel("epoch") 33 | 34 | 35 | def l_fun(t): 36 | return np.sin(t) + 2 37 | 38 | 39 | def f(t, y): 40 | θ, ω = y 41 | g = 1.0 42 | l = l_fun(t) 43 | 44 | dω = -(g / l) * np.sin(θ) 45 | dθ = ω 46 | 47 | return dθ, dω 48 | 49 | 50 | if __name__ == "__main__": 51 | 52 | parser = ArgumentParser() 53 | parser.add_argument("--device", default="cpu") 54 | parser.add_argument("--n_epochs", default=6000, type=int) 55 | parser.add_argument("--t_start", default=0.0, type=float) 56 | parser.add_argument("--t_end", type=float, default=np.pi * 4) 57 | args = 
parser.parse_args() 58 | 59 | device = args.device 60 | n_epochs = args.n_epochs 61 | 62 | """ 63 | generate data 64 | 65 | """ 66 | y0 = [np.pi / 4, 0] 67 | step_size = 0.01 68 | 69 | t_start = args.t_start 70 | t_end = args.t_end 71 | 72 | t_eval = np.arange(t_start, t_end, step_size) 73 | 74 | y_true = solve_ivp( 75 | f, t_span=(t_start, t_end), t_eval=t_eval, y0=y0, method="RK45" 76 | ).y 77 | θ, ω = y_true 78 | l = torch.tensor(l_fun(t_eval)).to(device) 79 | 80 | """ 81 | network 82 | 83 | """ 84 | nn_hidden = ( 85 | nn.Sequential( 86 | nn.Linear(1, 32), 87 | nn.Softplus(), 88 | nn.Linear(32, 32), 89 | nn.Softplus(), 90 | nn.Linear(32, 32), 91 | nn.Softplus(), 92 | nn.Linear(32, 32), 93 | nn.Softplus(), 94 | nn.Linear(32, 2), 95 | ) 96 | .double() 97 | .to(device) 98 | ) 99 | 100 | xavier_init(nn_hidden) 101 | optimizer = torch.optim.Adam(nn_hidden.parameters()) 102 | 103 | t_train = torch.tensor(t_eval, requires_grad=True).to(device) 104 | y_train = torch.tensor(y_true).to(device) 105 | 106 | subsample_every = int(2.5 / step_size) 107 | losses = defaultdict(list) 108 | 109 | """ 110 | training 111 | 112 | """ 113 | for _ in tqdm(range(n_epochs), "training hidden physics model"): 114 | θ_pred, l_pred = nn_hidden(t_train[..., None]).T 115 | 116 | ω_pred = grad( 117 | listify(θ_pred), 118 | t_train, 119 | only_inputs=True, 120 | retain_graph=True, 121 | create_graph=True, 122 | )[0] 123 | 124 | dω_pred = grad( 125 | listify(ω_pred), 126 | t_train, 127 | only_inputs=True, 128 | retain_graph=True, 129 | create_graph=True, 130 | )[0] 131 | 132 | dω_eq = -(1.0 / l_pred) * torch.sin(θ_pred) 133 | 134 | y_pred = torch.column_stack((θ_pred, ω_pred)).T 135 | 136 | loss_collocation = F.mse_loss( 137 | y_pred[:, ::subsample_every], y_train[:, ::subsample_every] 138 | ) 139 | loss_hidden = F.mse_loss(dω_pred, dω_eq) 140 | loss_length = F.mse_loss(l_pred, l) 141 | 142 | loss = loss_collocation + loss_hidden 143 | 144 | loss.backward() 145 | optimizer.step() 146 | 
optimizer.zero_grad() 147 | 148 | losses["collocation"].append(loss_collocation.item()) 149 | losses["hidden"].append(loss_hidden.item()) 150 | 151 | plot_loss(losses) 152 | 153 | predicted = { 154 | "θ(t)": θ_pred.detach().cpu().flatten(), 155 | "ω(t)": ω_pred.detach().cpu().flatten(), 156 | "l(t)": l_pred.detach().cpu().flatten(), 157 | } 158 | true = { 159 | "θ(t)": θ, 160 | "ω(t)": ω, 161 | "l(t)": l.detach().cpu(), 162 | } 163 | 164 | fig, (ax0, ax1, ax2) = plt.subplots(3, 1, sharex=True) 165 | fig.canvas.manager.set_window_title(f"states") 166 | 167 | ax0.set_ylabel("θ(t)") 168 | ax0.plot(t_eval, θ, c="black", label="true") 169 | ax0.plot( 170 | t_eval, 171 | θ_pred.detach().cpu().flatten(), 172 | c="b", 173 | linestyle="--", 174 | label="predicted", 175 | ) 176 | ax0.scatter( 177 | t_eval[::subsample_every], 178 | θ[::subsample_every], 179 | c="black", 180 | linestyle="None", 181 | label="collocation point", 182 | ) 183 | 184 | ax1.set_ylabel("ω(t)") 185 | ax1.plot(t_eval, ω, c="black", label="true") 186 | ax1.plot( 187 | t_eval, 188 | ω_pred.detach().cpu().flatten(), 189 | c="r", 190 | linestyle="--", 191 | label="predicted", 192 | ) 193 | ax1.scatter( 194 | t_eval[::subsample_every], 195 | ω[::subsample_every], 196 | c="black", 197 | linestyle="None", 198 | label="collocation point", 199 | ) 200 | 201 | ax2.set_ylabel("l(t)") 202 | ax2.set_xlabel("t") 203 | ax2.plot(t_eval, l, c="black", label="true") 204 | ax2.plot( 205 | t_eval, 206 | l_pred.detach().cpu().flatten(), 207 | c="g", 208 | linestyle="--", 209 | label="predicted", 210 | ) 211 | # skip drawing misleading collocation points, since none are used for the pendulum length 212 | 213 | ax1.legend() 214 | plt.tight_layout() 215 | 216 | plt.show() 217 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Constructing Neural Network-Based Models for Simulating Dynamical Systems 
2 | 3 | This is a companion repository for the review paper **Constructing Neural Network-Based Models for Simulating Dynamical Systems**, which provides a practical description of how models such as *Neural Ordinary Differential Equations* and *Physics-informed Neural Networks* can be implemented. 4 | The full paper can be accessed at: https://dl.acm.org/doi/10.1145/3567591 5 | 6 | The code in the repository is implemented in Python, using PyTorch to define and train the models. 7 | The scripts can be run with their default parameters to reproduce the plots shown in the paper, as well as additional figures, such as loss curves, that were omitted due to space constraints: 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | # Installing dependencies 16 | 17 | The dependencies needed to run the scripts can be installed with pip using the `requirements.txt` file: 18 | ``` bash 19 | python3 -m pip install -r requirements.txt 20 | ``` 21 | If you are using Conda, you can run the following command from a fresh environment: 22 | ``` bash 23 | conda install --file requirements.txt 24 | ``` 25 | 26 | # Running the experiments 27 | 28 | Each experiment can be run with default parameters by executing its script with the Python interpreter: 29 | ``` 30 | python3 experiments/EXPERIMENT_NAME.py ... 31 | ``` 32 | The table below lists the commands needed to train and evaluate the models described in the review paper.
33 | 34 | | Name | Section | Command | 35 | | -------------------------------------------- | ------- | --------------------------------------------------------------- | 36 | | Vanilla Direct-Solution | 3.2 | python3 experiments/direct_solution.py --model vanilla | 37 | | Automatic Differentiation in Direct-Solution | 3.3 | python3 experiments/direct_solution.py --model autodiff | 38 | | Physics Informed Neural Networks | 3.4 | python3 experiments/direct_solution.py --model pinn | 39 | | Hidden Physics Networks | 3.5 | python3 experiments/hidden_physics.py | 40 | | Direct Time-Stepper | 4.2.1 | python3 experiments/time_stepper.py --solver direct | 41 | | Residual Time-Stepper | 4.2.2 | python3 experiments/time_stepper.py --solver resnet | 42 | | Euler Time-Stepper | 4.2.3 | python3 experiments/time_stepper.py --solver euler | 43 | | Neural ODEs Time-Stepper | 4.2.4 | python3 experiments/time_stepper.py --solver {rk4,dopri5,tsit5} | 44 | | Neural State-Space Model | 4.3.1 | ... | 45 | | Neural ODEs with input | 4.3.2-3 | ... | 46 | | Lagrangian Time-Stepper | 4.4.1 | ... | 47 | | Hamiltonian Time-Stepper | 4.4.1 | ... | 48 | | Deep Potential Time-Stepper | 4.4.2 | ... | 49 | | Deep Markov-Model | 4.5.1 | ... | 50 | | Latent Neural ODEs | 4.5.2 | python3 experiments/latent_neural_odes.py | 51 | | Bayesian Neural ODEs | 4.5.3 | ... | 52 | | Neural SDEs | 4.5.4 | ... | 53 | 54 | 55 | 56 | # Docker Image 57 | To help ensure that the code can still be executed in the future, we provide a Docker image. 58 | The Docker image allows the code to be run in a Linux-based container on any platform supported by Docker. 59 | 60 | To use the Docker image, run the build command from the root of this repository: 61 | ``` bash 62 | docker build .
-t python_dynamical_systems 63 | ``` 64 | 65 | A container with the code and all of its dependencies can then be started with the `run` command: 66 | ``` bash 67 | docker run -ti python_dynamical_systems bash 68 | ``` 69 | This command opens an interactive shell inside the container. 70 | From there, you can run the code as if it were running on your host machine: 71 | ``` bash 72 | python3 experiments/time_stepper.py ... 73 | ``` 74 | 75 | # Citing the paper 76 | 77 | If you use this work, please consider citing it: 78 | ``` bibtex 79 | @article{10.1145/3567591, 80 | author = {Legaard, Christian and Schranz, Thomas and Schweiger, Gerald and Drgo\v{n}a, J\'{a}n and Falay, Basak and Gomes, Cl\'{a}udio and Iosifidis, Alexandros and Abkar, Mahdi and Larsen, Peter}, 81 | title = {Constructing Neural Network Based Models for Simulating Dynamical Systems}, 82 | year = {2023}, 83 | issue_date = {November 2023}, 84 | publisher = {Association for Computing Machinery}, 85 | address = {New York, NY, USA}, 86 | volume = {55}, 87 | number = {11}, 88 | issn = {0360-0300}, 89 | url = {https://doi.org/10.1145/3567591}, 90 | doi = {10.1145/3567591}, 91 | abstract = {Dynamical systems see widespread use in natural sciences like physics, biology, and chemistry, as well as engineering disciplines such as circuit analysis, computational fluid dynamics, and control. For simple systems, the differential equations governing the dynamics can be derived by applying fundamental physical laws. However, for more complex systems, this approach becomes exceedingly difficult. Data-driven modeling is an alternative paradigm that seeks to learn an approximation of the dynamics of a system using observations of the true system. In recent years, there has been an increased interest in applying data-driven modeling techniques to solve a wide range of problems in physics and engineering.
This article provides a survey of the different ways to construct models of dynamical systems using neural networks. In addition to the basic overview, we review the related literature and outline the most significant challenges from numerical simulations that this modeling paradigm must overcome. Based on the reviewed literature and identified challenges, we provide a discussion on promising research areas.}, 92 | journal = {ACM Comput. Surv.}, 93 | month = {feb}, 94 | articleno = {236}, 95 | numpages = {34}, 96 | keywords = {physics-informed neural networks, physics-based regularization, Neural ODEs} 97 | } 98 | ``` 99 | -------------------------------------------------------------------------------- /experiments/direct_solution.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from collections import defaultdict 3 | from math import ceil, sin 4 | from math import floor 5 | 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.autograd import grad 12 | from torch.nn.modules.container import T 13 | 14 | from tqdm import tqdm 15 | import matplotlib.pyplot as plt 16 | from scipy.integrate import solve_ivp 17 | import matplotlib as mpl 18 | 19 | 20 | if __name__ == "__main__": 21 | 22 | parser = ArgumentParser() 23 | parser.add_argument( 24 | "--model", 25 | choices=["vanilla", "autodiff", "pinn"], 26 | default="vanilla", 27 | ) 28 | parser.add_argument("--hidden_dim", default=32, type=int) 29 | parser.add_argument("--n_layers", default=5, type=int) 30 | parser.add_argument("--device", default="cpu") 31 | parser.add_argument("--n_epochs", default=2000, type=int) 32 | parser.add_argument("--t_start", default=0.0, type=float) 33 | parser.add_argument("--t_end", type=float, default=np.pi * 4) 34 | args = parser.parse_args() 35 | 36 | ############### Setup Experiment ############### 37 | y0 = [np.pi / 4, 0] 38 | step_size = 
0.01 39 | t_start = args.t_start 40 | t_end = args.t_end 41 | t_span = (t_start, t_end) 42 | t_eval = np.arange(t_start, t_end, step_size) # 0.0 , 0.01 43 | 44 | g = 1.0 # gravitational acceleration [m/s^2] 45 | l = 1.0 # length of pendulum [m] 46 | 47 | model = args.model 48 | n_epochs = args.n_epochs 49 | device = args.device 50 | subsample_every = int(2.5 / step_size) 51 | 52 | ############### Define Derivative ############### 53 | def f(t, y): 54 | θ, ω = y # state variables go in 55 | g = 1.0 56 | l = 1.0 57 | dω = -(g / l) * np.sin(θ) 58 | dθ = ω # special case (common for mechanical systems), the state variable ω is per definition dθ 59 | 60 | return dθ, dω # derivatives of state variables go out 61 | 62 | ############### Solve ODE ############### 63 | 64 | res = solve_ivp(f, t_span, t_eval=t_eval, y0=y0, method="RK45") 65 | 66 | ############### Plot ############### 67 | 68 | def plot_colored(ax, x, y, c, cmap=plt.cm.jet, steps=10, **kwargs): 69 | a = c.size 70 | c = np.asarray(c) 71 | c -= c.min() 72 | c = c / c.max() 73 | it = 0 74 | while it < c.size - steps: 75 | x_segm = x[it : it + steps + 1] 76 | y_segm = y[it : it + steps + 1] 77 | c_segm = cmap(c[it + steps // 2]) 78 | ax.plot(x_segm, y_segm, c=c_segm, **kwargs) 79 | it += steps 80 | 81 | θ, ω = res.y 82 | 83 | def xavier_init(module): 84 | for m in module.modules(): 85 | if type(m) == nn.Linear: 86 | nn.init.xavier_uniform_(m.weight) 87 | 88 | def construct_network(input_dim, output_dim, hidden_dim, hidden_layers): 89 | 90 | layers = [nn.Linear(input_dim, hidden_dim), nn.Softplus()] 91 | for _ in range(hidden_layers): 92 | layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.Softplus()]) 93 | layers.append(nn.Linear(hidden_dim, output_dim)) 94 | 95 | net = nn.Sequential(*layers).double().to(device) 96 | xavier_init(net) 97 | return net 98 | 99 | hidden_dim = args.hidden_dim 100 | hidden_layers = args.n_layers 101 | 102 | t = torch.tensor(t_eval, device=device, requires_grad=True) 103 | losses = 
defaultdict(lambda: defaultdict(list)) 104 | 105 | # `torch.autograd.grad` supports only lists of scalar values (e.g. single evaluation of network). 106 | # however the function accepts a list of these. 107 | def listify(A): 108 | return [a for a in A.flatten()] 109 | 110 | y_train = torch.tensor(res.y[:, ::subsample_every]).to(device) 111 | t_train = torch.tensor(t_eval[::subsample_every], requires_grad=True).to(device) 112 | θω_pred = None 113 | 114 | if model == "vanilla": 115 | nn_vanilla = construct_network(1, 2, hidden_dim, hidden_layers) 116 | 117 | opt_vanilla = torch.optim.Adam(nn_vanilla.parameters()) 118 | 119 | for epoch in tqdm(range(n_epochs), desc="vanilla: training epoch"): 120 | out = nn_vanilla(t_train.unsqueeze(-1)).T 121 | 122 | loss_collocation = F.mse_loss(out, y_train) 123 | 124 | loss_collocation.backward() 125 | opt_vanilla.step() 126 | nn_vanilla.zero_grad() 127 | losses["vanilla"]["collocation"].append(loss_collocation.item()) 128 | 129 | θω_pred = nn_vanilla(t.unsqueeze(-1)).detach().detach().cpu().T 130 | 131 | # autodiff 132 | elif model == "autodiff": 133 | nn_autodiff = construct_network(1, 1, hidden_dim, hidden_layers) 134 | opt_autodiff = torch.optim.Adam(nn_autodiff.parameters()) 135 | 136 | for epoch in tqdm(range(n_epochs), desc="autodiff: training epoch"): 137 | θ_pred = nn_autodiff(t_train.unsqueeze(-1)).T 138 | 139 | θ_listed = listify(θ_pred) 140 | 141 | # [0] since we differentiate with respect to an "single input", 142 | # which is coincidentially a tensor. 
143 | # in this case ω ≜ dθ 144 | ω_pred = grad( 145 | θ_listed, 146 | t_train, 147 | only_inputs=True, 148 | retain_graph=True, 149 | create_graph=True, 150 | )[0].unsqueeze(0) 151 | 152 | θω = torch.cat((θ_pred, ω_pred), dim=0) 153 | 154 | loss_collocation = F.mse_loss(θω, y_train) 155 | loss_collocation.backward() 156 | 157 | # sanity check 158 | max_grad = next(nn_autodiff.modules())[0].weight.grad.max() 159 | assert ( 160 | max_grad != 0.0 161 | ), "maximal gradient of first layer was zero, something is up!" 162 | 163 | opt_autodiff.step() 164 | nn_autodiff.zero_grad() 165 | 166 | losses["autodiff"]["collocation"].append(loss_collocation.item()) 167 | 168 | θ_autodiff = nn_autodiff(t.unsqueeze(-1)).T 169 | θ_autodiff_listed = listify(θ_autodiff) 170 | ω_autodiff = grad(θ_autodiff_listed, t, only_inputs=True)[0].unsqueeze(0) 171 | θω_pred = torch.cat((θ_autodiff, ω_autodiff), dim=0).detach().cpu() 172 | 173 | # pinn 174 | elif model == "pinn": 175 | nn_pinn = construct_network(1, 1, hidden_dim, hidden_layers) 176 | 177 | opt_pinn = torch.optim.Adam(nn_pinn.parameters()) 178 | t_train_dense = torch.tensor(t_eval, requires_grad=True).to(device) 179 | losses_pinn = {"collocation": [], "equation": []} 180 | for epoch in tqdm(range(n_epochs), desc="pinn: training epoch"): 181 | θ_pred = nn_pinn(t_train_dense.unsqueeze(-1)).T 182 | θ_listed = listify(θ_pred) 183 | 184 | ω_pred = grad( 185 | θ_listed, 186 | t_train_dense, 187 | only_inputs=True, 188 | retain_graph=True, 189 | create_graph=True, 190 | )[0].unsqueeze(0) 191 | ω_listed = listify(ω_pred) 192 | dω_pred = grad( 193 | ω_listed, 194 | t_train_dense, 195 | only_inputs=True, 196 | retain_graph=True, 197 | create_graph=True, 198 | )[0].unsqueeze(0) 199 | 200 | θω_dense = torch.cat((θ_pred, ω_pred), dim=0) 201 | 202 | # collocation loss is defined for sparse set of points 203 | θω = θω_dense[:, ::subsample_every] 204 | loss_collocation = F.mse_loss(θω, y_train) 205 | 206 | # equation based loss is defined for 
dense samples 207 | dω_eq = -torch.sin(θ_pred) 208 | loss_equation = F.mse_loss(dω_pred, dω_eq) 209 | 210 | loss_total = loss_collocation + loss_equation 211 | loss_total.backward() 212 | # next(net.modules())[0].weight.grad => this gives you gradients of the loss 213 | 214 | # sanity check 215 | max_grad = next(nn_pinn.modules())[0].weight.grad.max() 216 | assert ( 217 | max_grad != 0.0 218 | ), "maximal gradient of first layer was zero, something is up!" 219 | 220 | opt_pinn.step() 221 | nn_pinn.zero_grad() 222 | 223 | losses["pinn"]["collocation"].append(loss_collocation.item()) 224 | losses["pinn"]["equation"].append(loss_equation.item()) 225 | 226 | θ_pinn = nn_pinn(t.unsqueeze(-1)).T 227 | θ_pinn_listed = listify(θ_pinn) 228 | ω_pinn = grad(θ_pinn_listed, t, only_inputs=True)[0].unsqueeze(0) 229 | θω_pred = torch.cat((θ_pinn, ω_pinn), dim=0).detach().cpu() 230 | 231 | def plot_predictions(model, θω_pred): 232 | ω_numerical = np.diff(θω_pred[:1]) / step_size 233 | fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True) 234 | fig.canvas.manager.set_window_title(model) 235 | ax1.set_ylabel("θ(t)") 236 | ax2.set_ylabel("ω(t)") 237 | ax2.set_xlabel("t") 238 | 239 | ax1.plot(t_eval, θ, c="black", label="true") 240 | ax1.plot(t_eval, θω_pred[0], c="b", linestyle="--", label="predicted") 241 | 242 | ax2.plot(t_eval, ω, c="black", label="true") 243 | ax2.plot(t_eval, θω_pred[1], c="r", linestyle="--", label="predicted") 244 | ax2.plot( 245 | t_eval[1:], 246 | ω_numerical.T, 247 | c="r", 248 | linestyle="dotted", 249 | label="numerical", 250 | ) 251 | 252 | ax1.scatter( 253 | t_eval[::subsample_every], 254 | res.y[:, ::subsample_every][0], 255 | c="black", 256 | linestyle="None", 257 | label="collocation point", 258 | ) 259 | ax2.scatter( 260 | t_eval[::subsample_every], 261 | res.y[:, ::subsample_every][1], 262 | c="black", 263 | linestyle="None", 264 | label="collocation point", 265 | ) 266 | ax2.legend() 267 | plt.tight_layout() 268 | 269 | # plot loss functions as 
function of training steps 270 | def plot_losses(model, losses): 271 | fig, ax = plt.subplots() 272 | fig.canvas.manager.set_window_title(f"loss terms '{model}'") 273 | 274 | for loss_name, loss in losses.items(): 275 | ax.plot(loss, label=loss_name) 276 | 277 | ax.legend() 278 | ax.set_xlabel("epoch") 279 | 280 | plot_predictions(model, θω_pred) 281 | plot_losses(model, losses[model]) 282 | 283 | fig = plt.figure() 284 | x, y = np.meshgrid( 285 | np.arange(-np.pi, np.pi, 0.01), 286 | np.arange(-np.pi, np.pi, 0.01), 287 | ) 288 | dθ, dω = f(None, (x, y)) 289 | plt.streamplot(x, y, dθ, dω, density=2) 290 | plt.xlabel("θ") 291 | plt.ylabel("ω") 292 | fig.canvas.manager.set_window_title(f"phase portrait '{model}'") 293 | 294 | ax = plt.gca() 295 | plot_colored(ax, θ, ω, t_eval) 296 | cmap = plt.cm.jet 297 | 298 | norm = mpl.colors.Normalize(vmin=t_eval.min(), vmax=t_eval.max()) 299 | fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax) 300 | 301 | # draw initial state 302 | plt.scatter(θ[0], ω[0], label="$y_0$", marker="*", c="g", s=200, zorder=100) 303 | plt.legend(loc="upper right") 304 | ax.set_aspect(1) 305 | 306 | plt.show() 307 | -------------------------------------------------------------------------------- /experiments/time_stepper.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from numpy import double 3 | import torch.nn as nn 4 | from torch.optim import Adam 5 | from tqdm import tqdm 6 | from pyDOE import lhs 7 | import matplotlib.pyplot as plt 8 | from torch.nn.functional import mse_loss 9 | import numpy as np 10 | 11 | import torch 12 | from torchdyn.numerics import odeint 13 | from torchdyn.numerics.solvers import SolverTemplate 14 | import matplotlib.patches as mpatches 15 | from matplotlib.collections import LineCollection 16 | 17 | from visualization import get_meshgrid, plot_colored 18 | 19 | 20 | class DirectSolver(SolverTemplate): 21 | def __init__(self, 
dtype=torch.float32): 22 | super().__init__(order=1) 23 | self.dtype = dtype 24 | self.stepping_class = "fixed" 25 | 26 | def step(self, f, x, t, dt, k1=None): 27 | 28 | x_sol = f(t, x) 29 | return None, x_sol, None 30 | 31 | 32 | class ResnetSolver(SolverTemplate): 33 | def __init__(self, step_size=None, dtype=torch.float32): 34 | super().__init__(order=1) 35 | self.dtype = dtype 36 | self.stepping_class = "fixed" 37 | 38 | self.step_size = 1 if step_size is None else step_size 39 | 40 | def step(self, f, x, t, dt, k1=None): 41 | x_sol = x + f(t, x) * self.step_size 42 | return None, x_sol, None 43 | 44 | 45 | if __name__ == "__main__": 46 | 47 | parser = ArgumentParser() 48 | parser.add_argument( 49 | "--solver", 50 | choices=["direct", "resnet", "euler", "rk4", "dopri5", "tsit5"], 51 | default="euler", 52 | ) 53 | parser.add_argument("--hidden_dim", default=32, type=int) 54 | parser.add_argument("--n_layers", default=8, type=int) 55 | parser.add_argument("--device", default="cpu") 56 | parser.add_argument("--n_epochs", default=2000, type=int) 57 | parser.add_argument("--n_traj_train", default=100, type=int) 58 | parser.add_argument("--n_traj_validate", default=10, type=int) 59 | parser.add_argument("--t_start_train", default=0.0, type=float) 60 | parser.add_argument("--t_end_train", type=float, default=0.001) 61 | parser.add_argument("--t_start_validate", default=0.0, type=float) 62 | parser.add_argument("--t_end_validate", type=float, default=4 * np.pi) 63 | parser.add_argument("--step_size_train", default=0.001, type=float) 64 | parser.add_argument("--step_size_validate", default=0.001, type=float) 65 | parser.add_argument("--noise_std", default=0.0, type=float) 66 | parser.add_argument("--domain_train", default=1.0, type=float) 67 | parser.add_argument("--domain_validate", default=1.0, type=float) 68 | args = parser.parse_args() 69 | 70 | # generate data 71 | 72 | def f(t, x): 73 | θ = x[..., 0] 74 | ω = x[..., 1] 75 | 76 | dθ = ω 77 | dω = -torch.sin(θ) 78 | 
79 | return torch.stack((dθ, dω), dim=-1) 80 | 81 | domain_draw_factor = 1.3 82 | 83 | domain_train = args.domain_train 84 | domain_validate = args.domain_validate 85 | x0_train = ( 86 | torch.tensor(lhs(2, args.n_traj_train), device=args.device) * 2 - 1 87 | ) * domain_train 88 | x0_validate = ( 89 | torch.tensor(lhs(2, args.n_traj_validate), device=args.device) * 2 - 1 90 | ) * domain_validate 91 | x0_grid = get_meshgrid(step_per_axis=0.01, domain=domain_validate) 92 | x0_example = torch.tensor((0.6, 0)).double().unsqueeze(0).to(args.device) 93 | 94 | step_size_train = args.step_size_train 95 | ε = 1e-10 96 | t_span_train = torch.arange( 97 | args.t_start_train, args.t_end_train + ε, step_size_train 98 | ) 99 | t_span_validate = torch.arange( 100 | args.t_start_validate, 101 | args.t_end_validate + ε, 102 | args.step_size_validate, 103 | ) 104 | 105 | if args.solver.lower() == "direct": 106 | solver = DirectSolver() 107 | elif args.solver.lower() == "resnet": 108 | solver = ResnetSolver() 109 | else: 110 | solver = args.solver 111 | 112 | _, x_train = odeint(f, x0_train, t_span_train, solver="rk4") 113 | x_true = x_train 114 | x_train = x_train + torch.randn_like(x_train) * args.noise_std 115 | 116 | _, x_validate = odeint(f, x0_validate, t_span_validate, solver="rk4") 117 | _, x_example = odeint(f, x0_example, t_span_validate, solver="rk4") 118 | 119 | ##################### model ########################## 120 | layers = [] 121 | layers.append(nn.Linear(2, args.hidden_dim)) 122 | for _ in range(args.n_layers): 123 | layers.append(nn.Linear(args.hidden_dim, args.hidden_dim)) 124 | layers.append(nn.Softplus()) 125 | 126 | layers.append(nn.Linear(args.hidden_dim, 2)) 127 | 128 | net = nn.Sequential(*layers) 129 | net.to(args.device).double() 130 | 131 | for m in net.modules(): 132 | if type(m) == nn.Linear: 133 | nn.init.xavier_uniform_(m.weight) 134 | 135 | # optimizer 136 | 137 | opt = Adam(net.parameters()) 138 | 139 | # train 140 | losses = [] 141 | 142 | for _ 
in tqdm(range(args.n_epochs)): 143 | 144 | _, x_pred_train = odeint( 145 | lambda t, x: net(x), x0_train, t_span_train, solver=solver 146 | ) 147 | loss = mse_loss(x_pred_train, x_train) 148 | loss.backward() 149 | opt.step() 150 | opt.zero_grad() 151 | losses.append(loss.item()) 152 | 153 | _, x_pred_train = odeint(lambda t, x: net(x), x0_train, t_span_train, solver=solver) 154 | 155 | _, x_pred_validate = odeint( 156 | lambda t, x: net(x), x0_validate, t_span_validate, solver=solver 157 | ) 158 | 159 | _, x_pred_example = odeint( 160 | lambda t, x: net(x), x0_example, t_span_validate, solver=solver 161 | ) 162 | 163 | # derivatives 164 | # x0_grid_before = get_meshgrid(step_per_axis=0.01, domain=domain) 165 | x_derivative = f(None, x0_grid) 166 | 167 | if args.solver == "direct": 168 | 169 | out = net(x0_grid) 170 | # normalize for the step size used during training. If the network is trained with a step-size of 1/100 of a second 171 | # it will predict changes that are 100 times as small as those for 1 second. 
172 | x_derivative_pred = (out - x0_grid) / step_size_train 173 | 174 | elif args.solver == "resnet": 175 | x_derivative_pred = net(x0_grid) / step_size_train 176 | else: 177 | x_derivative_pred = net(x0_grid) 178 | 179 | # plot 180 | x_pred_train = x_pred_train.detach().numpy() 181 | x_pred_validate = x_pred_validate.detach().numpy() 182 | x_pred_example = x_pred_example.detach().numpy() 183 | x_derivative_pred = x_derivative_pred.detach().numpy() 184 | x_derivative = x_derivative.detach().numpy() 185 | x0_grid = x0_grid.detach().numpy() 186 | 187 | # streamplot 188 | density = 1 189 | fig, ax = plt.subplots() 190 | fig.canvas.manager.set_window_title("stream plot") 191 | ode_patch = mpatches.Patch(color="black", label="true") 192 | nn_patch = mpatches.Patch(color="blue", label="pred") 193 | 194 | ax.streamplot( 195 | x0_grid[..., 0], 196 | x0_grid[..., 1], 197 | x_derivative[..., 0], 198 | x_derivative[..., 1], 199 | color="black", 200 | density=density, 201 | ) 202 | 203 | ax.streamplot( 204 | x0_grid[..., 0], 205 | x0_grid[..., 1], 206 | x_derivative_pred[..., 0], 207 | x_derivative_pred[..., 1], 208 | color="blue", 209 | density=density, 210 | ) 211 | 212 | ax.set_xlabel("θ") 213 | ax.set_ylabel("ω") 214 | ax.set_xlim(-domain_validate, domain_validate) 215 | ax.set_ylim(-domain_validate, domain_validate) 216 | 217 | ax.legend(handles=[ode_patch, nn_patch]) 218 | 219 | # quiver 220 | fig, ax = plt.subplots() 221 | fig.canvas.manager.set_window_title("quiver") 222 | ax.quiver( 223 | x0_grid[..., 0], 224 | x0_grid[..., 1], 225 | x_derivative[..., 0], 226 | x_derivative[..., 1], 227 | color="black", 228 | angles="xy", 229 | scale_units="xy", 230 | label="true" 231 | # scale=1, 232 | ) 233 | 234 | ax.quiver( 235 | x0_grid[..., 0], 236 | x0_grid[..., 1], 237 | x_derivative_pred[..., 0], 238 | x_derivative_pred[..., 1], 239 | color="blue", 240 | angles="xy", 241 | scale_units="xy", 242 | # scale=1, 243 | label="pred", 244 | ) 245 | ax.legend() 246 | 247 | 
ax.set_xlabel("θ") 248 | ax.set_ylabel("ω") 249 | ax.set_ylim(-domain_validate, domain_validate) 250 | ax.set_xlim(-domain_validate, domain_validate) 251 | 252 | # phase space, training 253 | fig, ax = plt.subplots() 254 | fig.canvas.manager.set_window_title("phase space: training") 255 | 256 | lines_true = LineCollection( 257 | [x for x in x_true.swapaxes(0, 1)], color="black", label="true" 258 | ) 259 | 260 | # lines_noise = LineCollection( 261 | # [x for x in x_train.swapaxes(0, 1)], color="red", label="noisy" 262 | # ) 263 | 264 | lines_pred = LineCollection( 265 | [x for x in x_pred_train.swapaxes(0, 1)], color="blue", label="pred" 266 | ) 267 | 268 | ax.add_collection(lines_true) 269 | if args.noise_std != 0.0: 270 | ax.scatter(x_train[..., 0], x_train[..., 1], label="observations") 271 | # ax.add_collection(lines_noise) 272 | ax.add_collection(lines_pred) 273 | 274 | ax.set_xlabel("θ") 275 | ax.set_ylabel("ω") 276 | ax.set_xlim(-domain_draw_factor * domain_train, domain_draw_factor * domain_train) 277 | ax.set_ylim(-domain_draw_factor * domain_train, domain_draw_factor * domain_train) 278 | ax.legend() 279 | 280 | # phase space, validation 281 | fig, ax = plt.subplots() 282 | fig.canvas.manager.set_window_title("phase space: validation") 283 | lines_true = LineCollection( 284 | [x for x in x_validate.swapaxes(0, 1)], color="black", label="true" 285 | ) 286 | # lines_pred = LineCollection( 287 | # [x for x in x_pred_validate.swapaxes(0, 1)], color="blue", label="pred" 288 | # ) 289 | # plot_colored(fig, ax, t_span_validate, x_validate, label="true") 290 | ax.add_collection(lines_true) 291 | 292 | plot_colored( 293 | fig, 294 | ax, 295 | t_span_validate, 296 | x_pred_validate, 297 | label="pred", 298 | colorbar=True, 299 | # linestyle="dashed", 300 | ) 301 | # ax.add_collection(lines_pred) 302 | ax.set_xlabel("θ") 303 | ax.set_ylabel("ω") 304 | ax.set_xlim(-2 * domain_validate, 2 * domain_validate) 305 | ax.set_ylim(-2 * domain_validate, 2 * domain_validate) 
306 | ax.legend() 307 | 308 | # time series validation, specific idx 309 | fig, (ax1, ax2) = plt.subplots(2, sharex=True) 310 | example_idx = 0 311 | fig.canvas.manager.set_window_title(f"states vs time: validation idx={example_idx}") 312 | 313 | ax1.plot(t_span_validate, x_validate[..., example_idx, 0], color="black") 314 | ax1.plot( 315 | t_span_validate, 316 | x_pred_validate[..., example_idx, 0], 317 | linestyle="dashed", 318 | color="blue", 319 | ) 320 | ax2.plot( 321 | t_span_validate, x_validate[..., example_idx, 1], color="black", label="true" 322 | ) 323 | ax2.plot( 324 | t_span_validate, 325 | x_pred_validate[..., example_idx, 1], 326 | linestyle="dashed", 327 | color="blue", 328 | label="pred", 329 | ) 330 | ax1.set_ylabel("θ(t)") 331 | ax2.set_ylabel("ω(t)") 332 | ax2.set_xlabel("t") 333 | ax2.legend() 334 | 335 | # time series validation, example (0.6,0) 336 | fig, (ax1, ax2) = plt.subplots(2, sharex=True) 337 | fig.canvas.manager.set_window_title(f"states vs time: validation example") 338 | 339 | ax1.plot(t_span_validate, x_example[..., 0], color="black") 340 | ax1.plot( 341 | t_span_validate, 342 | x_pred_example[..., 0], 343 | linestyle="dashed", 344 | color="blue", 345 | ) 346 | ax2.plot(t_span_validate, x_example[..., 1], label="true", color="black") 347 | ax2.plot( 348 | t_span_validate, 349 | x_pred_example[..., 1], 350 | linestyle="dashed", 351 | color="blue", 352 | label="predicted", 353 | ) 354 | ax1.set_ylabel("θ(t)") 355 | ax2.set_ylabel("ω(t)") 356 | ax2.set_xlabel("t") 357 | ax2.legend() 358 | 359 | # show 360 | fig, ax = plt.subplots() 361 | ax.plot(losses) 362 | ax.set_xlabel("epoch") 363 | ax.set_ylabel("MSE") 364 | 365 | plt.show() 366 | -------------------------------------------------------------------------------- /experiments/latent_neural_odes.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from collections import defaultdict 3 | from itertools 
import chain 4 | from numpy.random.mtrand import choice 5 | 6 | import torch.nn as nn 7 | from torch.nn.modules.rnn import RNN 8 | from torch.optim import Adam 9 | from tqdm import tqdm 10 | from pyDOE import lhs 11 | import matplotlib.pyplot as plt 12 | from torch.nn.functional import mse_loss 13 | import numpy as np 14 | 15 | import torch 16 | from torchdyn.numerics import odeint 17 | from torchdyn.numerics.solvers import SolverTemplate 18 | import matplotlib.patches as mpatches 19 | from matplotlib.collections import LineCollection 20 | 21 | from visualization import get_meshgrid, plot_colored 22 | 23 | 24 | class DirectSolver(SolverTemplate): 25 | def __init__(self, dtype=torch.float32): 26 | super().__init__(order=1) 27 | self.dtype = dtype 28 | self.stepping_class = "fixed" 29 | 30 | def step(self, f, x, t, dt, k1=None): 31 | 32 | x_sol = f(t, x) 33 | return None, x_sol, None 34 | 35 | 36 | class ResnetSolver(SolverTemplate): 37 | def __init__(self, step_size=None, dtype=torch.float32): 38 | super().__init__(order=1) 39 | self.dtype = dtype 40 | self.stepping_class = "fixed" 41 | 42 | self.step_size = 1 if step_size is None else step_size 43 | 44 | def step(self, f, x, t, dt, k1=None): 45 | x_sol = x + f(t, x) * self.step_size 46 | return None, x_sol, None 47 | 48 | 49 | if __name__ == "__main__": 50 | 51 | parser = ArgumentParser() 52 | parser.add_argument( 53 | "--solver", choices=["direct", "resnet", "euler", "rk4"], default="euler" 54 | ) 55 | parser.add_argument("--hidden_dim", default=32, type=int) 56 | parser.add_argument("--n_layers", default=8, type=int) 57 | parser.add_argument("--device", default="cpu") 58 | parser.add_argument("--n_epochs", default=100, type=int) 59 | parser.add_argument("--n_traj_train", default=55, type=int) 60 | parser.add_argument("--n_traj_validate", default=10, type=int) 61 | parser.add_argument("--t_start_train", default=0.0, type=float) 62 | parser.add_argument("--t_end_train", default=1.0, type=float) 63 | 
parser.add_argument("--t_start_validate", default=0.0, type=float) 64 | parser.add_argument("--t_end_validate", default=4 * np.pi, type=float) 65 | parser.add_argument("--step_size_train", default=0.01, type=float) 66 | parser.add_argument("--step_size_validate", default=0.01, type=float) 67 | parser.add_argument("--noise_std", default=0.000001, type=float) 68 | parser.add_argument("--domain_train", default=1.0, type=float) 69 | parser.add_argument("--domain_validate", default=1.0, type=float) 70 | parser.add_argument("--latent_dim", default=4, type=int) 71 | parser.add_argument("--scatter", default=False, type=bool) 72 | parser.add_argument( 73 | "--criterion", default="elbo", choices=["elbo", "mse"], type=str 74 | ) 75 | args = parser.parse_args() 76 | 77 | # generate data 78 | 79 | def f(t, x): 80 | θ = x[..., 0] 81 | ω = x[..., 1] 82 | 83 | dθ = ω 84 | dω = -torch.sin(θ) 85 | 86 | return torch.stack((dθ, dω), dim=-1) 87 | 88 | domain_draw_factor = 1.3 89 | 90 | noise_log_var = 2 * torch.log(torch.tensor(args.noise_std)) 91 | 92 | latent_dim = args.latent_dim 93 | domain_train = args.domain_train 94 | domain_validate = args.domain_validate 95 | 96 | x0_train = ( 97 | torch.tensor(lhs(2, args.n_traj_train), device=args.device) * 2 - 1 98 | ) * domain_train 99 | x0_validate = ( 100 | torch.tensor(lhs(2, args.n_traj_validate), device=args.device) * 2 - 1 101 | ) * domain_validate 102 | x0_grid = get_meshgrid(step_per_axis=0.01, domain=domain_validate) 103 | x0_example = torch.tensor((0.6, 0)).double().unsqueeze(0).to(args.device) 104 | 105 | step_size_train = args.step_size_train 106 | ε = 1e-10 107 | t_span_train = torch.arange( 108 | args.t_start_train, args.t_end_train + ε, step_size_train 109 | ) 110 | t_span_validate = torch.arange( 111 | args.t_start_validate, 112 | args.t_end_validate + ε, 113 | args.step_size_validate, 114 | ) 115 | 116 | if args.solver.lower() == "direct": 117 | solver = DirectSolver() 118 | elif args.solver.lower() == "resnet": 119 | 
solver = ResnetSolver() 120 | else: 121 | solver = args.solver 122 | 123 | _, x_train = odeint(f, x0_train, t_span_train, solver="rk4") 124 | x_train = x_train + torch.randn_like(x_train) * args.noise_std 125 | 126 | _, x_validate = odeint(f, x0_validate, t_span_train, solver="rk4") 127 | _, x_example = odeint(f, x0_example, t_span_validate, solver="rk4") 128 | 129 | ##################### model ########################## 130 | device = args.device 131 | 132 | class Encoder(nn.Module): 133 | def __init__(self, input_size, hidden_size, output_size) -> None: 134 | super().__init__() 135 | self.rnn = nn.RNN(input_size, hidden_size) 136 | self.h2o = nn.Linear(hidden_size, output_size) 137 | 138 | def forward(self, x): 139 | x_flipped = torch.flip(x, (0,)) 140 | _, h = self.rnn(x_flipped) 141 | z0 = self.h2o(h) 142 | return z0 143 | 144 | class LatentODE(nn.Module): 145 | def __init__(self) -> None: 146 | super().__init__() 147 | 148 | self.encoder = Encoder(2, 25, latent_dim * 2) 149 | 150 | layers = [] 151 | layers.append(nn.Linear(latent_dim, args.hidden_dim)) 152 | for _ in range(args.n_layers): 153 | layers.append(nn.Linear(args.hidden_dim, args.hidden_dim)) 154 | layers.append(nn.Softplus()) 155 | 156 | layers.append(nn.Linear(args.hidden_dim, 4)) 157 | self.dynamics = nn.Sequential(*layers) 158 | 159 | self.decoder = ( 160 | nn.Sequential( 161 | nn.Linear(latent_dim, 20), nn.Softplus(), nn.Linear(20, 2) 162 | ) 163 | .to(device) 164 | .double() 165 | ) 166 | self.inference = False 167 | 168 | def forward(self, t, x): 169 | z = self.encoder(x)[-1] 170 | qz0_mean, qz0_logvar = z[:, :latent_dim], z[:, latent_dim:] 171 | qz0_std = torch.exp(0.5 * qz0_logvar) 172 | ε = torch.randn_like(qz0_mean) 173 | 174 | if self.inference: 175 | z0 = qz0_mean 176 | else: 177 | z0 = qz0_mean + ε * qz0_std # TODO 178 | # z0 = qz0_mean # TODO 179 | 180 | _, z_pred = odeint(lambda t, z: self.dynamics(z), z0, t, solver=solver) 181 | 182 | x_pred = self.decoder(z_pred) 183 | 184 | return 
qz0_mean, qz0_logvar, x_pred 185 | 186 | net = LatentODE().to(args.device).double() 187 | 188 | for m in chain(net.modules()): 189 | if type(m) == nn.Linear: 190 | nn.init.xavier_uniform_(m.weight) 191 | 192 | # optimizer 193 | 194 | opt = Adam(net.parameters()) 195 | 196 | # train 197 | losses = defaultdict(list) 198 | 199 | def log_normal_pdf(x, mean, logvar): 200 | px = -0.5 * ( 201 | torch.log(torch.tensor(np.pi)) 202 | + logvar 203 | + (x - mean) ** 2.0 / torch.exp(logvar) 204 | ) 205 | return px.sum((0, 2)) 206 | 207 | # return -.5 * (const + logvar + (x - mean) ** 2. / torch.exp(logvar)) 208 | 209 | def normal_kl(μ1, log_var1, μ2, log_var2): 210 | v1 = torch.exp(log_var1) 211 | v2 = torch.exp(log_var2) 212 | lstd1 = log_var1 / 2.0 213 | lstd2 = log_var2 / 2.0 214 | 215 | kl = lstd2 - lstd1 + ((v1 + (μ1 - μ2) ** 2.0) / (2.0 * v2)) - 0.5 216 | return kl.sum(-1) 217 | 218 | for _ in tqdm(range(args.n_epochs)): 219 | 220 | qz0_mean, qz0_logvar, x_pred_train = net(t_span_train, x_train) 221 | 222 | log_px = log_normal_pdf(x_train, x_pred_train, noise_log_var) 223 | log_normal_kl = normal_kl( 224 | qz0_mean, 225 | qz0_logvar, 226 | torch.zeros_like(qz0_mean), 227 | torch.zeros_like(qz0_logvar), # we want logvar = 1 => var = 1 228 | ) 229 | 230 | loss_elbo = torch.mean(log_normal_kl - log_px, dim=0) 231 | loss_mse = mse_loss(x_pred_train, x_train) 232 | 233 | if args.criterion == "mse": 234 | loss = loss_mse 235 | elif args.criterion == "elbo": 236 | loss = loss_elbo 237 | 238 | loss.backward() 239 | opt.step() 240 | opt.zero_grad() 241 | losses["log_px"].append( 242 | torch.mean(log_px).detach().numpy() 243 | ) # how likely is it to observe the targets, if the output distributions produced by the network, is as they claim. 
244 | losses["kl"].append(torch.mean(log_normal_kl).detach().numpy()) 245 | losses["mse"].append(loss_mse.detach().numpy()) 246 | losses["elbo"].append(loss_elbo.detach().numpy()) 247 | 248 | net.inference = True 249 | _, _, x_pred_train = net(t_span_train, x_train) 250 | _, _, x_pred_validate = net(t_span_validate, x_validate) 251 | 252 | # _, x_pred_example = odeint( 253 | # lambda t, x: dynamics(x), x0_example, t_span_validate, solver=solver 254 | # ) 255 | 256 | # derivatives 257 | # x0_grid_before = get_meshgrid(step_per_axis=0.01, domain=domain) 258 | x_derivative = f(None, x0_grid) 259 | 260 | # plot 261 | x_pred_train = x_pred_train.detach().numpy() 262 | x_pred_validate = x_pred_validate.detach().numpy() 263 | 264 | # phase space, training 265 | fig, ax = plt.subplots() 266 | fig.canvas.manager.set_window_title("phase space: training") 267 | 268 | lines_true = LineCollection( 269 | [x for x in x_train.swapaxes(0, 1)], 270 | color="black", 271 | label="true", 272 | ) 273 | # lines_pred = LineCollection( 274 | # [x for x in x_pred_train.swapaxes(0, 1)], color="blue", label="pred" 275 | # ) 276 | 277 | if args.scatter: 278 | ax.scatter(x_pred_train[..., 0], x_pred_train[..., 1]) 279 | plot_colored( 280 | fig, 281 | ax, 282 | t_span_train, 283 | x_pred_train, 284 | label="pred", 285 | colorbar=True, 286 | # linestyle="dashed", 287 | ) 288 | if args.scatter: 289 | ax.scatter(x_train[..., 0], x_train[..., 1]) 290 | ax.add_collection(lines_true) 291 | # ax.add_collection(lines_pred) 292 | 293 | ax.set_xlabel("θ") 294 | ax.set_ylabel("ω") 295 | ax.set_xlim(-domain_draw_factor * domain_train, domain_draw_factor * domain_train) 296 | ax.set_ylim(-domain_draw_factor * domain_train, domain_draw_factor * domain_train) 297 | ax.legend() 298 | 299 | # # phase space, validation 300 | fig, ax = plt.subplots() 301 | fig.canvas.manager.set_window_title("phase space: validation") 302 | 303 | lines_true = LineCollection( 304 | [x for x in x_validate.swapaxes(0, 1)], 305 | 
color="black", 306 | label="true", 307 | ) 308 | # lines_pred = LineCollection( 309 | # [x for x in x_pred_train.swapaxes(0, 1)], color="blue", label="pred" 310 | # ) 311 | if args.scatter: 312 | ax.scatter(x_pred_validate[..., 0], x_pred_validate[..., 1]) 313 | plot_colored( 314 | fig, 315 | ax, 316 | t_span_validate, 317 | x_pred_validate, 318 | label="pred", 319 | colorbar=True, 320 | # linestyle="dashed", 321 | ) 322 | 323 | if args.scatter: 324 | ax.scatter(x_validate[..., 0], x_validate[..., 1]) 325 | ax.add_collection(lines_true) 326 | # ax.add_collection(lines_pred) 327 | 328 | ax.set_xlabel("θ") 329 | ax.set_ylabel("ω") 330 | ax.set_xlim( 331 | -domain_draw_factor * domain_validate, domain_draw_factor * domain_validate 332 | ) 333 | ax.set_ylim( 334 | -domain_draw_factor * domain_validate, domain_draw_factor * domain_validate 335 | ) 336 | ax.legend() 337 | 338 | # # time series validation, example (0.6,0) 339 | # fig, (ax1, ax2) = plt.subplots(2, sharex=True) 340 | # fig.canvas.manager.set_window_title(f"states vs time: validation example") 341 | 342 | # ax1.plot(t_span_validate, x_example[..., 0], color="black") 343 | # ax1.plot( 344 | # t_span_validate, 345 | # x_pred_example[..., 0], 346 | # linestyle="dashed", 347 | # color="blue", 348 | # ) 349 | # ax2.plot(t_span_validate, x_example[..., 1], label="true", color="black") 350 | # ax2.plot( 351 | # t_span_validate, 352 | # x_pred_example[..., 1], 353 | # linestyle="dashed", 354 | # color="blue", 355 | # label="predicted", 356 | # ) 357 | # ax1.set_ylabel("θ(t)") 358 | # ax2.set_ylabel("ω(t)") 359 | # ax2.set_xlabel("t") 360 | # ax2.legend() 361 | 362 | # show 363 | fig, ax = plt.subplots() 364 | ax.set_xlabel("epoch") 365 | ax.set_ylabel("loss") 366 | for name, values in losses.items(): 367 | ax.plot(values, label=name) 368 | 369 | ax.legend() 370 | plt.show() 371 | -------------------------------------------------------------------------------- /docs/survey_structure.svg: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | SurveyStructureSurvey...Background (Sec. 2)Background...Differential Equations (Sec. 2.1)Differential Equations...Neural Networks(Sec. 2.2)Neural Networks...Direct-SolutionModels(Sec. 3)Direct-Solution...Hidden Physics Networks(Sec. 3.5)Hidden Physics N...Time-StepperModels(Sec. 4)Time-Stepper...Physics Informed Neural Networks(Sec. 3.1-4)Physics Informed N...Integration Schemes (Sec. 4.2)Integration Schemes...Neural ODEs(Sec. 4.2.1-4.2.4)Neural ODEs...Network Architecture(Sec. 4.4)Network Architectur...External Input(Sec. 4.3)External Inpu...Neural State-Space Models(Sec. 4.3.1)Neural State-Spa...Neural ODEswith input(Sec. 4.3.2-3)Neural ODEs...Graph Neural Networks(Sec. 4.4.3)Graph Neural Net...Hamiltonian/ Lagrangian NN(Sec 4.4.1)Hamiltonian/ Lag...Deep Potential NNs(Sec. 4.4.2)Deep Potential N...Uncertainty(Sec. 4.5)Uncertainty...Deep Markov Models(Sec. 4.5.1)Deep Markov...Baysian Neural ODEs(Sec. 4.5.3)Baysian Neur...Neural SDEs(Sec. 4.5.4)Neural SDEs...Model Taxonomy(Sec. 2.3)Model Taxonomy...Latent Neural ODEs(Sec. 
4.5.2)Latent Neura...Viewer does not support full SVG 1.1 -------------------------------------------------------------------------------- /docs/time_vs_state_plot.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 5 | 6 | 7 | 8 | 9 | 2022-11-19T11:39:59.949724 10 | image/svg+xml 11 | 12 | 13 | Matplotlib v3.4.3, https://matplotlib.org/ 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 30 | 31 | 32 | 33 | 39 | 40 | 41 | 42 | 43 | 44 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 119 | 140 | 147 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 239 | 252 | 273 | 286 | 287 | 288 | 289 | 290 | 291 | 292 | 293 | 294 | 295 | 415 | 416 | 417 | 535 | 536 | 537 | 540 | 541 | 542 | 545 | 546 | 547 | 550 | 551 | 552 | 555 | 556 | 557 | 558 | 559 | 565 | 566 | 567 | 568 | 569 | 570 | 571 | 572 | 573 | 574 | 575 | 576 | 577 | 578 | 579 | 580 | 581 | 582 | 583 | 584 | 585 | 586 | 587 | 588 | 589 | 590 | 614 | 615 | 616 | 617 | 618 | 619 | 620 | 621 | 622 | 623 | 624 | 625 | 626 | 627 | 628 | 629 | 648 | 649 | 650 | 651 | 652 | 653 | 654 | 655 | 656 | 657 | 658 | 659 | 660 | 661 | 662 | 663 | 693 | 694 | 695 | 696 | 697 | 698 | 699 | 700 | 701 | 702 | 703 | 704 | 705 | 706 | 707 | 708 | 747 | 748 | 749 | 750 | 751 | 752 | 753 | 754 | 755 | 756 | 757 | 758 | 759 | 760 | 761 | 762 | 776 | 777 | 778 | 779 | 780 | 781 | 782 | 783 | 784 | 785 | 786 | 787 | 788 | 789 | 790 | 791 | 792 | 793 | 794 | 795 | 796 | 797 | 798 | 799 | 800 | 801 | 802 | 
803 | 804 | 805 | 806 | 807 | 808 | 809 | 810 | 811 | 812 | 813 | 814 | 815 | 816 | 817 | 818 | 819 | 820 | 821 | 822 | 823 | 824 | 825 | 826 | 827 | 828 | 829 | 830 | 831 | 832 | 833 | 834 | 835 | 836 | 837 | 838 | 839 | 840 | 841 | 842 | 843 | 844 | 845 | 846 | 847 | 848 | 849 | 850 | 851 | 852 | 853 | 854 | 872 | 873 | 874 | 875 | 876 | 877 | 878 | 879 | 880 | 881 | 999 | 1000 | 1001 | 1121 | 1122 | 1123 | 1126 | 1127 | 1128 | 1131 | 1132 | 1133 | 1136 | 1137 | 1138 | 1141 | 1142 | 1143 | 1144 | 1155 | 1156 | 1157 | 1160 | 1161 | 1162 | 1163 | 1164 | 1165 | 1166 | 1183 | 1205 | 1230 | 1231 | 1232 | 1233 | 1234 | 1235 | 1236 | 1237 | 1238 | 1241 | 1242 | 1243 | 1244 | 1245 | 1246 | 1247 | 1273 | 1299 | 1300 | 1301 | 1302 | 1303 | 1304 | 1305 | 1306 | 1307 | 1308 | 1309 | 1310 | 1311 | 1312 | 1313 | 1314 | 1315 | 1316 | 1317 | 1318 | --------------------------------------------------------------------------------