├── .gitignore
├── README.md
├── assets
│   ├── lorenz_trajectories.png
│   └── pendulum_train.gif
├── setup.py
├── tests
│   ├── __init__.py
│   ├── ghnn_test.py
│   └── ghnn_test
│       └── pendulum_test_data.pkl
├── tutorial
│   ├── ghnn_tutorial.ipynb
│   └── weakform_tutorial.ipynb
└── weakformghnn
    ├── __init__.py
    └── _src
        ├── __init__.py
        ├── _models.py
        ├── _vector_calc.py
        └── _weak_form.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # mac misc files
132 | *.DS_Store
133 |
134 | # vscode setup
135 | *.vscode
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Weak Form Generalized Hamiltonian Learning
2 | This library provides a PyTorch implementation of Weak Form Generalized Hamiltonian Learning. The code accompanies the [NeurIPS 2020 paper](https://proceedings.neurips.cc/paper/2020/file/d93c96e6a23fff65b91b900aaa541998-Paper.pdf) [1] by Kevin L. Course, Trefor W. Evans, and Prasanth B. Nair.
3 |
4 | As everything is written in PyTorch, all algorithms provide full GPU support.
5 |
6 | ---
7 |
8 |
9 |
10 |
11 |
12 |
13 | Please cite our paper if you find this code useful in your research. The BibTeX entry for the paper is:
14 | ```bibtex
15 | @inproceedings{course_wfghnn_2020,
16 | author = {Course, Kevin and Evans, Trefor and Nair, Prasanth},
17 | booktitle = {Advances in Neural Information Processing Systems},
18 | editor = {H. Larochelle and M. Ranzato and R. Hadsell and M. F. Balcan and H. Lin},
19 | pages = {18716--18726},
20 | publisher = {Curran Associates, Inc.},
21 | title = {Weak Form Generalized Hamiltonian Learning},
22 | url = {https://proceedings.neurips.cc/paper/2020/file/d93c96e6a23fff65b91b900aaa541998-Paper.pdf},
23 | volume = {33},
24 | year = {2020}
25 | }
26 | ```
27 | Our experiments make heavy use of the [torchdiffeq package](https://github.com/rtqichen/torchdiffeq) by Chen *et al.* [2], and we use the rectified Huber unit from Kolter and Manek [3].
28 |
29 | ---
30 |
31 | ## Installation
32 | ```bash
33 | pip install git+https://github.com/coursekevin/weakformghnn#egg=weakformghnn
34 | ```
35 |
36 | ## Library Overview
37 | This library provides **three** main components:
38 |
39 | * GHNN class: This class provides a convenient interface for specifying **strong priors** on the form of the generalized Hamiltonian. It inherits from `torch.nn.Module`, so a GHNN can be trained in any of the usual ways one trains a model for a continuous-time ODE (see [1,2]). See the ghnn_tutorial for an example of how to use this class in context.
40 |
41 | * GHNNwHPrior class: This class has a similar structure to the main GHNN class except it should only be used when the generalized Hamiltonian is known. It can be trained in the same way as a standard GHNN.
42 |
43 | * weak_form_loss: Use this loss function to perform "weak form regression" with any model for a continuous time ODE. See the weakform_tutorial for an example.
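
To make "weak form regression" more concrete, the sketch below checks the integration-by-parts identity behind it using only the Python standard library. It is an illustration under stated assumptions, not the library's implementation: the Gaussian test function `exp(-(eps * (t - c))**2)` is inferred from the values checked in `tests/ghnn_test.py`, and `gauss_rbf_scalar`, `trapz`, `weak_residual`, and the toy ODE `dy/dt = -y` are hypothetical names and choices made for this example only.

```python
import math

def gauss_rbf_scalar(t, c, eps):
    """Gaussian test function and its time derivative at a single time t."""
    psi = math.exp(-(eps * (t - c)) ** 2)
    psi_dot = -2.0 * eps ** 2 * (t - c) * psi
    return psi, psi_dot

def trapz(values, ts):
    """Trapezoidal rule on a (possibly non-uniform) time grid."""
    return sum(0.5 * (values[i] + values[i + 1]) * (ts[i + 1] - ts[i])
               for i in range(len(ts) - 1))

def weak_residual(dydt, y, ts, c, eps):
    """Residual  int psi * dydt dt + int psi_dot * y dt  for one test function.

    If psi vanishes at both ends of the window, integration by parts makes
    this residual zero whenever dydt is the true derivative of y.
    """
    psi, psi_dot = zip(*(gauss_rbf_scalar(t, c, eps) for t in ts))
    lhs = trapz([p * f for p, f in zip(psi, dydt)], ts)
    rhs = trapz([pd * v for pd, v in zip(psi_dot, y)], ts)
    return lhs + rhs

# Toy trajectory y(t) = exp(-t), which solves dy/dt = -y (assumed example).
ts = [10.0 * i / 2000 for i in range(2001)]
y = [math.exp(-t) for t in ts]

# Residual for the correct right-hand side and for a deliberately wrong one.
good = weak_residual([-v for v in y], y, ts, c=5.0, eps=3.0)
bad = weak_residual([-2.0 * v for v in y], y, ts, c=5.0, eps=3.0)
squared_losses = (good ** 2, bad ** 2)
```

The point of the construction is that the residual never differentiates the (possibly noisy) measurements `y`; the derivative is moved onto the smooth test function instead. A squared weak form loss can then be built by summing such squared residuals over many test function centres.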
44 |
45 |
46 |
47 |
48 |
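
As a rough illustration of the kind of structure a generalized Hamiltonian model encodes, the sketch below integrates the damped Duffing oscillator from the weakform_tutorial in the form `dx/dt = (J + R) grad H(x)` and records the energy `H` along the trajectory. It uses only the standard library; the analytic gradient, the fixed-step RK4 integrator, and the names `hamiltonian`, `grad_hamiltonian`, `vector_field`, and `rk4_step` are assumptions made for this example and are not part of the package API.

```python
# Constant structure matrices taken from the Duffing example in the tutorial.
J = [[0.0, 1.0], [-1.0, 0.0]]     # skew-symmetric part (conserves H)
R = [[0.0, 0.0], [0.0, -0.35]]    # 3.5 * [[0, 0], [0, -0.1]] (dissipates H)

def hamiltonian(x):
    x1, x2 = x
    return 0.5 * x2 ** 2 - 0.5 * x1 ** 2 + 0.25 * x1 ** 4

def grad_hamiltonian(x):
    """Analytic gradient of H (the tutorial uses autograd instead)."""
    x1, x2 = x
    return [x1 ** 3 - x1, x2]

def vector_field(x):
    """dx/dt = (J + R) grad H(x)."""
    g = grad_hamiltonian(x)
    return [sum((J[i][k] + R[i][k]) * g[k] for k in range(2)) for i in range(2)]

def rk4_step(x, dt):
    """One classical fourth-order Runge-Kutta step."""
    def shifted(base, slope, scale):
        return [b + scale * s for b, s in zip(base, slope)]
    k1 = vector_field(x)
    k2 = vector_field(shifted(x, k1, 0.5 * dt))
    k3 = vector_field(shifted(x, k2, 0.5 * dt))
    k4 = vector_field(shifted(x, k3, dt))
    return [xi + dt / 6.0 * (a + 2 * b + 2 * c + d)
            for xi, a, b, c, d in zip(x, k1, k2, k3, k4)]

x = [-1.5, 3.0]            # initial condition used in the tutorial
energies = [hamiltonian(x)]
for _ in range(2000):      # integrate to t = 20 with dt = 0.01
    x = rk4_step(x, 0.01)
    energies.append(hamiltonian(x))
```

Because `J` is skew-symmetric, it contributes nothing to `dH/dt`, and the rate of change reduces to `grad H^T R grad H = -0.35 * x2**2`, which is never positive. The recorded `energies` therefore decay as the trajectory settles into one of the two wells of `H`.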
49 |
50 | ## References
51 | [1] Kevin L. Course, Trefor W. Evans, Prasanth B. Nair. "Weak Form Generalized Hamiltonian Learning." Advances in Neural Information Processing Systems. 2020.
52 |
53 | [2] Ricky T. Q. Chen, Yulia Rubanova, Jesse Bettencourt, David Duvenaud. "Neural Ordinary Differential Equations." Advances in Neural Information Processing Systems. 2018.
54 |
55 | [3] J. Z. Kolter, G. Manek. "Learning Stable Deep Dynamics Models." Advances in Neural Information Processing Systems 32. 2019, pp. 11126–11134.
--------------------------------------------------------------------------------
/assets/lorenz_trajectories.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coursekevin/weakformghnn/3c0ee28cf213e17efbf44ef15f4f339fbe52edb2/assets/lorenz_trajectories.png
--------------------------------------------------------------------------------
/assets/pendulum_train.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coursekevin/weakformghnn/3c0ee28cf213e17efbf44ef15f4f339fbe52edb2/assets/pendulum_train.gif
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 | with open("README.md", "r") as fh:
4 |     long_description = fh.read()
5 |
6 | setuptools.setup(
7 |     name="weakformghnn",
8 |     version="1.0.0",
9 |     author="Kevin L. Course, Trefor W. Evans, Prasanth B. Nair",
10 |     author_email="kevin.course@mail.utoronto.ca",
11 |     description="PyTorch implementation of weak form generalized Hamiltonian learning.",
12 |     long_description=long_description,
13 |     long_description_content_type="text/markdown",
14 |     url="https://github.com/coursekevin/weakformghnn",
15 |     packages=setuptools.find_packages(),
16 |     install_requires=['torch'],
17 |     classifiers=[
18 |         "Programming Language :: Python :: 3",
19 |         "Operating System :: OS Independent",
20 |     ],
21 | )
22 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coursekevin/weakformghnn/3c0ee28cf213e17efbf44ef15f4f339fbe52edb2/tests/__init__.py
--------------------------------------------------------------------------------
/tests/ghnn_test.py:
--------------------------------------------------------------------------------
1 | """
2 | Unit testing for generalized Hamiltonian learning
3 | """
4 | from weakformghnn import *
5 | import torch
6 | import torch.nn as nn
7 | import pytest
8 | import numpy as np
9 | import pickle
10 | from distutils import dir_util
11 | import os
12 | from torchdiffeq import odeint
13 | import matplotlib.pyplot as plt
14 |
15 | # Constants
16 | NDIM = 3
17 | NHIDDEN = 50
18 | NLAYERS = 3
19 |
20 |
21 | @pytest.fixture
22 | def GHNN_model():
23 | return GHNN(NDIM, NHIDDEN, NLAYERS)
24 |
25 |
26 | @pytest.fixture
27 | def convex_concave_data():
28 | t = torch.linspace(0., 10., 100)
29 | y_conv = torch.pow(t, 2)+10
30 | y_conc = -y_conv
31 | loss = torch.nn.MSELoss()
32 | return (t.view(-1, 1), y_conv.view(-1, 1), y_conc.view(-1, 1), loss)
33 |
34 |
35 | @pytest.fixture
36 | def ivp_integration_data(datadir):
37 | y0 = torch.tensor([np.pi - np.pi/32, 0.0], requires_grad=True)
38 | t_eval = torch.linspace(0., 20., 2000)
39 | data_path = datadir.join('pendulum_test_data.pkl')
40 | with open(data_path, 'rb') as handle:
41 | y = torch.tensor(pickle.load(handle))
42 | return (y0, t_eval, y)
43 |
44 |
45 | @pytest.fixture
46 | def vector_calc_data():
47 | def f(x):
48 | return torch.stack([2*x[:, 0]**3 + x[:, 1],
49 | x[:, 0]*torch.sin(x[:, 1]) + x[:, 2]**2, x[:, 2]**4], dim=1)
50 |
51 | def div_f(x):
52 | return 6*x[:, 0]**2 + x[:, 0]*torch.cos(x[:, 1]) + 4*x[:, 2]**3
53 |
54 | def curl_f(x):
55 | return torch.stack([torch.sin(x[:, 1])-1., torch.zeros(x[:, 0].shape), -2*x[:, 2]], dim=1)
56 |
57 | X = torch.randn(20, 3, requires_grad=True)
58 | f_val = f(X)
59 | div_val = div_f(X)
60 | curl_val = curl_f(X)
61 |
62 | return (f_val, div_val, curl_val, X)
63 | # Tests
64 |
65 |
66 | def test_PosLinear():
67 | pos_func = PosLinear(5, 1)
68 | x = torch.pow(torch.randn(20, 5), 2)
69 | np.testing.assert_array_less(torch.zeros(
70 | 20, 1), pos_func(x).detach().numpy())
71 | np.testing.assert_array_less(
72 | pos_func(-x).detach().numpy(), torch.zeros(20, 1))
73 |
74 |
75 | def test_ScalarFuncZero():
76 | zero_func = ScalarFuncZero(NDIM, NHIDDEN, NLAYERS)
77 | x = torch.zeros(10, 20, NDIM)
78 | np.testing.assert_almost_equal(torch.zeros(
79 | 10, 20, 1), zero_func(x).detach().numpy())
80 |
81 |
82 | def test_ScalarFuncPos():
83 | zero_func_pos = ScalarFuncZeroPos(NDIM, NHIDDEN, NLAYERS)
84 | x = 10.*torch.randn(10, 20, NDIM)-5.
85 | np.testing.assert_array_less(torch.zeros(
86 | 10, 20, 1), zero_func_pos(x).detach().numpy())
87 |
88 | z = torch.zeros(10, NDIM)
89 | np.testing.assert_array_almost_equal(torch.zeros(
90 | 10, 1).numpy(), zero_func_pos(z).detach().numpy())
91 |
92 |
93 | def test_ScalarFunc():
94 | scalar_func = ScalarFunc(0, NHIDDEN, NLAYERS)
95 | out1 = scalar_func(torch.randn(10, 0))
96 | out2 = scalar_func(torch.randn(10, 0))
97 | np.testing.assert_array_equal(out1.detach().numpy(), out2.detach().numpy())
98 |
99 |
100 | def test_ZeroDivMat():
101 | J = ZeroDivMat(NDIM, NHIDDEN, NLAYERS)
102 | x = torch.randn(10, NDIM)
103 | np.testing.assert_array_equal((10, NDIM, NDIM), J(x).shape)
104 |
105 | J = ZeroDivMat(2, NHIDDEN, NLAYERS)
106 | x = torch.randn(10, 2)
107 | J_out = J(x)
108 | np.testing.assert_array_equal(
109 | J_out[0].detach().numpy(), J_out[1].detach().numpy())
110 |
111 |
112 | def test_GHNN(GHNN_model):
113 | # note this function just tests that number of computed outputs is correct
114 | model = GHNN_model
115 | x = torch.randn(20, NDIM, requires_grad=True)
116 | np.testing.assert_array_equal((20, NDIM), model(0., x).shape)
117 |
118 | # test additional dimensions are handled correctly
119 | x = torch.randn(3, 5, 10, NDIM)
120 | x_shape = x.shape
121 | out_loop = []
122 | for i in range(x.shape[0]):
123 | out_tmp = []
124 | for j in range(x.shape[1]):
125 | out_tmp.append(model(0., x[i, j, :, :]))
126 | out_loop.append(torch.stack(out_tmp))
127 | out_loop = torch.stack(out_loop)
128 | np.testing.assert_array_almost_equal(
129 | out_loop.detach().numpy(), model(0., x).detach().numpy())
130 |
131 |
132 | def test_hessian(GHNN_model):
133 | model = GHNN_model
134 |
135 | def gr_f(xin):
136 | x = xin[:, 0]
137 | y = xin[:, 1]
138 | z = xin[:, 2]
139 | return torch.stack([2*x*y + z, torch.pow(x, 2), x + 3*torch.pow(z, 2)]).T
140 |
141 | def h(xin):
142 | x = xin[:, 0]
143 | y = xin[:, 1]
144 | z = xin[:, 2]
145 | hess = []
146 | for j in range(xin.shape[0]):
147 | hess.append(torch.tensor(
148 | [[2*y[j], 2*x[j], 1.], [2*x[j], 0., 0.], [1., 0., 6*z[j]]]))
149 | return torch.stack(hess)
150 | x = torch.randn(20, NDIM, requires_grad=True)
151 |
152 | np.testing.assert_array_almost_equal(
153 | model.hessian(gr_f(x), x).detach().numpy(), h(x))
154 |
155 |
156 | def test_ConvexFunc(convex_concave_data):
157 | t, yconv, yconc, loss = convex_concave_data
158 | conv_func1 = ConvexFunc(1, 30, 3)
159 | conv_func2 = ConvexFunc(1, 30, 3)
160 | optimizer1 = torch.optim.Adam(conv_func1.parameters(), lr=1e-1)
161 | optimizer2 = torch.optim.Adam(conv_func2.parameters(), lr=1e-1)
162 | for _ in range(500):
163 | L1 = loss(conv_func1(t), yconv)
164 | L1.backward()
165 | optimizer1.step()
166 | optimizer1.zero_grad()
167 | L2 = loss(conv_func2(t), yconc)
168 | L2.backward()
169 | optimizer2.step()
170 | optimizer2.zero_grad()
171 | assert L1 < 0.5
172 | assert L2 > 1000
173 |
174 |
175 | def test_ConcaveFunc(convex_concave_data):
176 | t, yconv, yconc, loss = convex_concave_data
177 | conc_func1 = ConcaveFunc(1, 30, 3)
178 | conc_func2 = ConcaveFunc(1, 30, 3)
179 | optimizer1 = torch.optim.Adam(conc_func1.parameters(), lr=1e-1)
180 | optimizer2 = torch.optim.Adam(conc_func2.parameters(), lr=1e-1)
181 | for _ in range(500):
182 | L1 = loss(conc_func1(t), yconv)
183 | L1.backward()
184 | optimizer1.step()
185 | optimizer1.zero_grad()
186 | L2 = loss(conc_func2(t), yconc)
187 | L2.backward()
188 | optimizer2.step()
189 | optimizer2.zero_grad()
190 | assert L2 < 0.5
191 | assert L1 > 1000
192 |
193 |
194 | def test_ConcaveZero(convex_concave_data):
195 |     t, _, yconc, loss = convex_concave_data
196 |     conc_zero = ConcaveFuncZero(1, 30, 3)
197 |     optimizer = torch.optim.Adam(conc_zero.parameters(), lr=1e-1)
198 |     for _ in range(500):
199 |         L = loss(conc_zero(t), yconc)
200 |         L.backward()
201 |         optimizer.step()
202 |         optimizer.zero_grad()
203 |     x_in = torch.zeros(20, 1)
204 |     np.testing.assert_array_almost_equal(
205 |         x_in, conc_zero(x_in).detach().numpy(), decimal=4)
206 |
207 |
208 | def test_gauss_rbf():
209 | t = torch.tensor([1., 2., 3.])
210 | c = torch.tensor([4., 5.])
211 | eps = 0.3
212 |
213 | t.requires_grad = True
214 |
215 | grbf, grbf_deriv = gauss_rbf(t, c, eps)
216 |
217 | true_grbf = torch.tensor([[0.4449, 0.2369],
218 | [0.6977, 0.4449],
219 | [0.9139, 0.6977]])
220 |
221 | true_grad_grbf = []
222 | for i in range(grbf.shape[1]):
223 | true_grad_grbf.append(torch.autograd.grad(
224 | grbf[:, i].sum(), t, retain_graph=True)[0].reshape(-1, 1))
225 | true_grad_grbf = torch.cat(true_grad_grbf, dim=1)
226 |
227 | # testing grbf val
228 | np.testing.assert_array_almost_equal(
229 | grbf.detach().numpy(), true_grbf.numpy(), decimal=4)
230 | # testing grbf time deriv
231 | np.testing.assert_array_almost_equal(
232 | grbf_deriv.detach().numpy(), true_grad_grbf.detach().numpy(), decimal=4)
233 |
234 |
235 | def test_poly_bf():
236 | deg = 3
237 | t = torch.tensor([1., 2., 3.])
238 | c = torch.tensor([4., 5.])
239 |
240 | t.requires_grad = True
241 | poly, poly_deriv = poly_bf(t, c, deg)
242 |
243 | expected = torch.tensor([[1., 1., -3., -4., 9., 16., -27., -64.],
244 | [1., 1., -2., -3., 4., 9., -8., -27.],
245 | [1., 1., -1., -2., 1., 4., -1., -8.]])
246 | grad_expected = []
247 | for i in range(poly.shape[1]):
248 | grad_expected.append(torch.autograd.grad(
249 | poly[:, i].sum(), t, retain_graph=True)[0].reshape(-1, 1))
250 | grad_expected = torch.cat(grad_expected, dim=1)
251 |
252 | # testing poly bf val
253 | np.testing.assert_almost_equal(poly.detach().numpy(), expected.numpy())
254 | # testing poly bf time deriv.
255 | np.testing.assert_almost_equal(
256 | poly_deriv.detach().numpy(), grad_expected.detach().numpy())
257 |
258 |
259 | def test_divergence_calc(vector_calc_data):
260 | f_val, div_val, _, X = vector_calc_data
261 |
262 | div_calc = divergence(f_val, X)
263 |
264 | np.testing.assert_array_equal(
265 | div_calc.detach().numpy(), div_val.detach().numpy())
266 |
267 |
268 | def test_curl_calc(vector_calc_data):
269 | f_val, _, curl_val, X = vector_calc_data
270 |
271 | curl_calc = curl(f_val, X)
272 | np.testing.assert_array_equal(
273 | curl_calc.detach().numpy(), curl_val.detach().numpy())
274 |
275 |
276 | if __name__ == "__main__":
277 | pytest.main()
278 |
--------------------------------------------------------------------------------
/tests/ghnn_test/pendulum_test_data.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/coursekevin/weakformghnn/3c0ee28cf213e17efbf44ef15f4f339fbe52edb2/tests/ghnn_test/pendulum_test_data.pkl
--------------------------------------------------------------------------------
/tutorial/weakform_tutorial.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Weak form loss tutorial\n",
8 | "\n",
9 | "This tutorial walks through how to learn a continuous-time model for an ODE from the weak form of the governing equations. Note that the weak form loss function can be used to train any continuous-time model of an ODE. In this tutorial we will use a generic fully connected neural network. \n",
10 | "\n",
11 | "The weakformghnn package provides a squared weak form loss function to be used to efficiently learn an ODE."
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": 1,
17 | "metadata": {},
18 | "outputs": [
19 | {
20 | "data": {
21 | "text/plain": [
22 | ""
23 | ]
24 | },
25 | "execution_count": 1,
26 | "metadata": {},
27 | "output_type": "execute_result"
28 | }
29 | ],
30 | "source": [
31 | "import torch\n",
32 | "import torch.nn as nn\n",
33 | "import numpy as np\n",
34 | "import math \n",
35 | "import matplotlib.pyplot as plt\n",
36 | "from torchdiffeq import odeint\n",
37 | "from weakformghnn import weak_form_loss, gauss_rbf\n",
38 | "torch.random.manual_seed(42)"
39 | ]
40 | },
41 | {
42 | "cell_type": "markdown",
43 | "metadata": {},
44 | "source": [
45 | "## Generating data\n",
46 | "\n",
47 | "Let's generate some training data. We will simulate a Duffing oscillator as it decays towards one of its two regions of local stability."
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": 2,
53 | "metadata": {},
54 | "outputs": [
55 | {
56 | "data": {
57 | "text/plain": [
58 | "Text(0, 0.5, 'y')"
59 | ]
60 | },
61 | "execution_count": 2,
62 | "metadata": {},
63 | "output_type": "execute_result"
64 | },
65 | {
66 | "data": {
67 | "image/png": "\n",
68 | "text/plain": [
69 | ""
70 | ]
71 | },
72 | "metadata": {
73 | "needs_background": "light"
74 | },
75 | "output_type": "display_data"
76 | }
77 | ],
78 | "source": [
79 | "t = torch.linspace(0., 20., 1000)\n",
80 | "class DuffingODE(nn.Module):\n",
81 | "    \"\"\" Duffing oscillator with friction\n",
82 | "    \"\"\"\n",
83 | "    def __init__(self):\n",
84 | "        super(DuffingODE, self).__init__()\n",
85 | "        self.xl = torch.tensor([-2., -2.])\n",
86 | "        self.xu = torch.tensor([2., 2.])\n",
87 | "        self.NDIM = 2\n",
88 | "        self._J = torch.tensor([[0., 1.], [-1., 0.]])\n",
89 | "        self._R = 3.5*torch.Tensor([[0., 0.], [0., -0.1]])\n",
90 | "    def H(self, x):\n",
91 | "        x2 = x[:, 1]\n",
92 | "        x1 = x[:, 0]\n",
93 | "        return 0.5*x2.pow(2) - 0.5*x1.pow(2) + 0.25*x1.pow(4)\n",
94 | "    def ode(self, t, x):\n",
95 | "        with torch.enable_grad():\n",
96 | "            if not x.requires_grad:\n",
97 | "                x.requires_grad = True\n",
98 | "            grad_h = torch.autograd.grad(\n",
99 | "                self.H(x).sum(), x, retain_graph=False)[0]\n",
100 | "        return torch.mm(grad_h, self._J.T) + torch.mm(grad_h, self._R.T)\n",
101 | "    def forward(self, t, x):\n",
102 | "        in_shape = x.shape\n",
103 | "        x = x.reshape(-1, self.NDIM)\n",
104 | "        return self.ode(t, x).reshape(in_shape)\n",
105 | "\n",
106 | "ode = DuffingODE()\n",
107 | "y = odeint(ode, torch.tensor([[-1.5,3.0]]), t).squeeze(1)\n",
108 | "y_meas = y + torch.randn(y.shape)*0.1\n",
109 | "\n",
110 | "plt.plot(y[:,0], y[:,1], label='true trajectory')\n",
111 | "plt.plot(y_meas[:,0], y_meas[:,1],'o',alpha=0.3, label='data')\n",
112 | "plt.legend()\n",
113 | "plt.xlabel('t')\n",
114 | "plt.ylabel('y')"
115 | ]
116 | },
117 | {
118 | "cell_type": "markdown",
119 | "metadata": {},
120 | "source": [
121 | "## Defining an ODE model\n",
122 | "\n",
123 | "Let's now define a neural network model for an autonomous ODE."
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": 3,
129 | "metadata": {},
130 | "outputs": [],
131 | "source": [
132 | "class ODEModel(nn.Module):\n",
133 | "    def __init__(self, num_inputs, nhidden):\n",
134 | "        super(ODEModel, self).__init__()\n",
135 | "        self.num_inputs = num_inputs\n",
136 | "        self.mlp = nn.Sequential(\n",
137 | "            nn.Linear(num_inputs, nhidden),\n",
138 | "            nn.ReLU(),\n",
139 | "            nn.Linear(nhidden, nhidden),\n",
140 | "            nn.ReLU(),\n",
141 | "            nn.Linear(nhidden, num_inputs)\n",
142 | "        )\n",
143 | "\n",
144 | "    def forward(self, t, x):\n",
145 | "        bs = x.shape[:-1] # reshape to handle multi-dim batches\n",
146 | "        return self.mlp(x.reshape(-1,self.num_inputs)).reshape(*bs, self.num_inputs)\n"
147 | ]
148 | },
149 | {
150 | "cell_type": "markdown",
151 | "metadata": {},
152 | "source": [
153 | "## Training the ODE with weak form loss\n"
154 | ]
155 | },
156 | {
157 | "cell_type": "markdown",
158 | "metadata": {},
159 | "source": [
160 | "First we define a helper function for sampling data"
161 | ]
162 | },
163 | {
164 | "cell_type": "code",
165 | "execution_count": 4,
166 | "metadata": {},
167 | "outputs": [],
168 | "source": [
169 | "def sample_data(t, y, batch_time_int, batch_size):\n",
170 | "    \"\"\" Draws batch_size samples from y and t with an integration time of batch_time_int\n",
171 | "        Assumes all points are evenly spaced \n",
172 | "    inputs:\n",
173 | "        t < tensor(num_t,) > \n",
174 | "        y < tensor(num_t, num_start_times, NDIM) > \n",
175 | "        batch_time_int < int > \n",
176 | "        batch_size < int > \n",
177 | "    outputs:\n",
178 | "        t_data < tensor(batch_time_int) > \n",
179 | "        y_data < tensor(batch_time_int, batch_size, NDIM) > \n",
180 | "    \"\"\"\n",
181 | "    num_t = len(t)\n",
182 | "    t_data = t[:batch_time_int]\n",
183 | "    batch_idx = np.random.choice(np.arange(num_t - batch_time_int + 1),\n",
184 | "                                 batch_size, replace=False)\n",
185 | "    run_idx = np.random.choice(np.arange(y.shape[1]), batch_size, replace=True)\n",
186 | "    y_data = torch.stack(\n",
187 | "        [y[batch_idx[i]:batch_idx[i] + batch_time_int, run_idx[i]]\n",
188 | "         for i in range(len(batch_idx))], dim=1)\n",
189 | "    return t_data, y_data"
190 | ]
191 | },
192 | {
193 | "cell_type": "markdown",
194 | "metadata": {},
195 | "source": [
196 | "We will run this example on GPU if it is available"
197 | ]
198 | },
199 | {
200 | "cell_type": "code",
201 | "execution_count": 5,
202 | "metadata": {},
203 | "outputs": [
204 | {
205 | "name": "stdout",
206 | "output_type": "stream",
207 | "text": [
208 | "Running on: cuda:0\n"
209 | ]
210 | }
211 | ],
212 | "source": [
213 | "device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
214 | "print(\"Running on: {}\".format(device))"
215 | ]
216 | },
217 | {
218 | "cell_type": "markdown",
219 | "metadata": {},
220 | "source": [
221 | "We will use 400 radial basis functions with a shape parameter of 10 in each integration time window."
222 | ]
223 | },
224 | {
225 | "cell_type": "code",
226 | "execution_count": 6,
227 | "metadata": {},
228 | "outputs": [
229 | {
230 | "name": "stdout",
231 | "output_type": "stream",
232 | "text": [
233 | "Epoch: 0050, loss: 1.53\n",
234 | "Epoch: 0100, loss: 1.16\n",
235 | "Epoch: 0150, loss: 1.11\n",
236 | "Epoch: 0200, loss: 1.07\n"
237 | ]
238 | }
239 | ],
240 | "source": [
241 | "epoch = 200\n",
242 | "model = ODEModel(2, 200)\n",
243 | "model.to(device)\n",
244 | "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-1)\n",
245 | "rbf_ep = 10.0\n",
247 | "for ep in range(1,epoch+1):\n",
248 | "    optimizer.zero_grad()\n",
249 | "    t_sample, y_sample = sample_data(t, y_meas.unsqueeze(1), 500, 80)\n",
250 | "    t_sample = t_sample.to(device)\n",
251 | "    y_sample = y_sample.to(device)\n",
252 | "    dydt_pred = model(t_sample, y_sample)\n",
253 | "    c = torch.linspace(t_sample[0], t_sample[-1], 400, device=device)\n",
254 | "    psi, psi_dot = gauss_rbf(t_sample, c, rbf_ep)\n",
255 | "    loss = weak_form_loss(dydt_pred.transpose(1, 0), y_sample.transpose(1, 0), \n",
256 | "                          t_sample, psi, psi_dot) \n",
257 | "    if ep % 50 == 0:\n",
258 | "        print(\"Epoch: {:04d}, loss: {:.2f}\".format(ep,loss.item()))\n",
259 | "    loss.backward()\n",
260 | "    optimizer.step()\n"
261 | ]
262 | },
263 | {
264 | "cell_type": "markdown",
265 | "metadata": {},
266 | "source": [
267 | "## Visualizing the results \n",
268 | "\n",
269 | "The learned model reproduces the data-generating trajectory well."
270 | ]
271 | },
272 | {
273 | "cell_type": "code",
274 | "execution_count": 7,
275 | "metadata": {},
276 | "outputs": [
277 | {
278 | "data": {
279 | "image/png": "\n",
280 | "text/plain": [
281 | ""
282 | ]
283 | },
284 | "metadata": {
285 | "needs_background": "light"
286 | },
287 | "output_type": "display_data"
288 | }
289 | ],
290 | "source": [
291 | "model.to('cpu')\n",
292 | "with torch.no_grad():  # no gradients are needed for evaluation\n",
293 | "    y_pred = odeint(model, y[0], t)\n",
294 | "    fig = plt.figure()\n",
295 | "    ax = fig.add_subplot(111)\n",
296 | "    ax.plot(t, y, label='true trajectory')\n",
297 | "    ax.plot(t, y_pred,'--', label='predicted trajectory')\n",
298 | "    ax.set_xlabel('t (s)')\n",
299 | "    ax.set_ylabel('y')\n",
300 | "    plt.legend()\n",
301 | "    plt.pause(1/60) "
302 | ]
303 | }
304 | ],
305 | "metadata": {
306 | "kernelspec": {
307 | "display_name": "Python 3",
308 | "language": "python",
309 | "name": "python3"
310 | },
311 | "language_info": {
312 | "codemirror_mode": {
313 | "name": "ipython",
314 | "version": 3
315 | },
316 | "file_extension": ".py",
317 | "mimetype": "text/x-python",
318 | "name": "python",
319 | "nbconvert_exporter": "python",
320 | "pygments_lexer": "ipython3",
321 | "version": "3.8.5"
322 | }
323 | },
324 | "nbformat": 4,
325 | "nbformat_minor": 4
326 | }
--------------------------------------------------------------------------------
/weakformghnn/__init__.py:
--------------------------------------------------------------------------------
1 | from ._src import *
2 |
--------------------------------------------------------------------------------
/weakformghnn/_src/__init__.py:
--------------------------------------------------------------------------------
1 | from ._models import *
2 | from ._weak_form import *
3 | from ._vector_calc import *
4 |
--------------------------------------------------------------------------------
/weakformghnn/_src/_models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.utils.data import Dataset
4 | from ._vector_calc import curl
5 | import math
6 |
7 | __all__ = ['PosLinear',
8 | 'ConvexFunc',
9 | 'ConvexFuncZero',
10 | 'ConcaveFunc',
11 | 'ConcaveFuncZero',
12 | 'ScalarFunc',
13 | 'ScalarFuncZero',
14 | 'ScalarFuncZeroPos',
15 | 'ScalarFuncPosUnbnd',
16 | 'ZeroDivMat',
17 | 'GHNN',
18 | 'HNN',
19 | 'ODEFCN',
20 | 'ReHU',
21 | 'GHNNwHPrior']
22 |
23 |
24 | class PosLinear(nn.Module):
25 | """ Linear layer with positive weights only
26 | """
27 | __constants__ = ['in_features', 'out_features']
28 |
29 | def __init__(self, in_features, out_features):
30 | super(PosLinear, self).__init__()
31 | self.in_features = in_features
32 | self.out_features = out_features
33 | self.weight = nn.parameter.Parameter(
34 | torch.Tensor(out_features, in_features))
35 | self.reset_parameters()
36 |
37 | def reset_parameters(self):
38 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
39 |
40 | def forward(self, input):
41 | return input @ torch.pow(self.weight, 2).T
42 |
43 | def extra_repr(self):
44 | return 'in_features={}, out_features={}'.format(
45 | self.in_features, self.out_features
46 | )
47 |
48 |
49 | class ConvexFunc(nn.Module):
50 | def __init__(self, ndim, nhidden, nlayers):
51 | super(ConvexFunc, self).__init__()
52 | self.linears = nn.ModuleList([nn.Linear(ndim, nhidden)])
53 | self.poslinear = nn.ModuleList()
54 | hidden_lin_layers = [nn.Linear(ndim, nhidden)
55 | for j in range(nlayers - 1)]
56 | hidden_pos_layers = [PosLinear(nhidden, nhidden)
57 | for j in range(nlayers-1)]
58 | self.linears.extend(hidden_lin_layers)
59 | self.poslinear.extend(hidden_pos_layers)
60 | self.linears.extend([nn.Linear(ndim, 1)])
61 | self.poslinear.extend([PosLinear(nhidden, 1)])
62 | self.activation = nn.Softplus()
63 | self.ndim = ndim
64 |
65 | def forward(self, x):
66 | z = self.activation(self.linears[0](x))
67 | for j in range(len(self.poslinear)):
68 | z = self.activation(self.poslinear[j](z) + self.linears[j+1](x))
69 | return z
70 |
71 |
72 | class ConvexFuncZero(ConvexFunc):
73 |     def __init__(self, ndim, nhidden, nlayers):
74 |         super(ConvexFuncZero, self).__init__(ndim, nhidden, nlayers)
75 |         self.conv_forward = super(ConvexFuncZero, self).forward
76 |         self.ndim = ndim
77 |         self.register_buffer('zero_input', torch.zeros(1, ndim))
78 |
79 |     def forward(self, x):
80 |         return self.conv_forward(x) - self.conv_forward(self.zero_input)
81 |
82 |
83 | class ConcaveFunc(ConvexFunc):
84 | def __init__(self, ndim, nhidden, nlayers):
85 | super(ConcaveFunc, self).__init__(ndim, nhidden, nlayers)
86 |
87 | def forward(self, x):
88 | conv_out = super(ConcaveFunc, self).forward(x)
89 | return -1*conv_out
90 |
91 |
92 | class ConcaveFuncZero(ConcaveFunc):
93 | def __init__(self, ndim, nhidden, nlayers):
94 | super(ConcaveFuncZero, self).__init__(ndim, nhidden, nlayers)
95 | self.conc_forward = super(ConcaveFuncZero, self).forward
96 | self.ndim = ndim
97 | self.register_buffer('zero_input', torch.zeros(1, ndim))
98 |
99 | def forward(self, x):
100 | return self.conc_forward(x) - self.conc_forward(self.zero_input)
101 |
102 |
103 | class ScalarFunc(nn.Module):
104 | def __init__(self, ndim, nhidden, nlayers):
105 | super(ScalarFunc, self).__init__()
106 | if ndim == 0:
107 | self.weight = nn.parameter.Parameter(torch.Tensor(1, 1))
108 | self._res_weight()
109 | self.mlp = lambda x: torch.ones(
110 | x.shape[0], device=x.device)*self.weight
111 | else:
112 | self.linears = nn.ModuleList([nn.Linear(ndim, nhidden)])
113 | hidden_layers = [nn.Linear(nhidden, nhidden)
114 | for j in range(nlayers - 1)]
115 | self.linears.extend(hidden_layers)
116 | self.activation = nn.Softplus()
117 | self.out = nn.Linear(nhidden, 1)
118 | self.mlp = self._mlp_forward
119 |
120 | def _mlp_forward(self, x):
121 | for _, l in enumerate(self.linears):
122 | x = self.activation(l(x))
123 | return self.out(x)
124 |
125 | def forward(self, x):
126 | return self.mlp(x)
127 |
128 | def _res_weight(self):
129 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
130 |
131 |
132 | class ScalarFuncZero(ScalarFunc):
133 | def __init__(self, ndim, nhidden, nlayers):
134 | super(ScalarFuncZero, self).__init__(ndim, nhidden, nlayers)
135 | self.ndim = ndim
136 | self.scalar = super(ScalarFuncZero, self).forward
137 | self.register_buffer('zero_input', torch.zeros(1, ndim))
138 |
139 | def forward(self, x):
140 | return self.scalar(x) - self.scalar(self.zero_input)
141 |
142 |
143 | class ScalarFuncZeroPos(ScalarFunc):
144 | def __init__(self, ndim, nhidden, nlayers):
145 | super(ScalarFuncZeroPos, self).__init__(ndim, nhidden, nlayers)
146 | self.ndim = ndim
147 | self.scalar = super(ScalarFuncZeroPos, self).forward
148 | self.register_buffer('zero_input', torch.zeros(1, ndim))
149 | self.make_pos = ReHU(1.)
150 |
151 | def sc_zero_pos(self, x):
152 | return self.make_pos(self.scalar(x)-self.scalar(self.zero_input))
153 | # return torch.pow(self.scalar(x)-self.scalar(self.zero_input), 2)
154 |
155 | def forward(self, x):
156 | return self.sc_zero_pos(x) + 0.01*torch.pow(x, 2).sum(dim=-1).unsqueeze(-1)
157 |
158 |
159 | class ScalarFuncPosUnbnd(ScalarFunc):
160 | def __init__(self, ndim, nhidden, nlayers):
161 | super(ScalarFuncPosUnbnd, self).__init__(ndim, nhidden, nlayers)
162 | self.ndim = ndim
163 | self.scalar = super(ScalarFuncPosUnbnd, self).forward
164 | self.register_buffer('zero_input', torch.zeros(1, ndim))
165 | self.make_pos = nn.Softplus()
166 |
167 | def sc_zero_pos(self, x):
168 | return self.make_pos(self.scalar(x))
169 |
170 | def forward(self, x):
171 | return self.sc_zero_pos(x) + 0.01*torch.pow(x, 2).sum(dim=-1).unsqueeze(-1) - self.sc_zero_pos(self.zero_input)
172 |
173 |
174 | class ZeroDivMat(nn.Module):
175 | """ Class for constructing a J matrix for generalized Hamiltonian
176 | TODO: write in terms of masked NN rather than a collection of NNs
177 | """
178 |
179 | def __init__(self, ndim, nhidden, nlayers):
180 | super(ZeroDivMat, self).__init__()
181 | self.ndim = ndim
182 | self.J = nn.ModuleDict()
183 | self.idx_list = []
184 | self.rel_terms = []
185 | for j in range(ndim):
186 | for i in range(ndim):
187 | if i >= j:
188 | continue
189 | else:
190 | self.rel_terms.append([ind for ind in range(
191 | self.ndim) if ind != i and ind != j])
192 | self.idx_list.append((i, j))
193 | self.J.update(
194 | [['{},{}'.format(i, j), ScalarFunc(ndim-2, nhidden, nlayers)]])
195 |
196 | def forward(self, x):
197 | J_ret = torch.zeros(x.shape[0], self.ndim, self.ndim, device=x.device)
198 | for ind, (i, j) in enumerate(self.idx_list):
199 | J_ret[:, i, j] = self.J['{},{}'.format(i, j)](
200 | x[:, self.rel_terms[ind]]).view(-1,)
201 | return J_ret - J_ret.transpose(2, 1)
202 |
203 |
204 | class GHNN(nn.Module):
205 | """ Main generalized Hamiltonian neural nets class
206 |
207 | INPUTS
208 | ndim < int > : number of input dimensions
209 | nhidden < int > : number of hidden units
210 | nlayers < int > : number of hidden layers
211 | prior < dict > : {
212 | 'H': choices = ['H0', 'H1', None], default = 'H0',
213 | 'dHdt': choices = ['decreasing', 'constant', None], default = None
214 | }
215 | Description of prior information dictionary:
216 | H (strong priors on the form of the Hamiltonian)
217 | 'H0': H(x) -> infty as x -> infty, H(x) + H(0) >= 0, H(0) = 0.
218 | - use if energy is bounded locally (i.e. damped Duffing oscillator)
219 | - makes sense in most cases
220 | 'H1': H(x) -> infty as x -> infty, H(0) = 0, H(x) > 0 forall x != 0
221 | - use if globally stable at x=0
222 | dHdt (strong priors on the energy flux rate)
223 | 'decreasing' : dHdt(x) < 0 along trajectories following dxdt = f(x)
224 | 'constant' : dHdt(x) = 0 along trajectories following dxdt = f(x)
225 |
226 | Regularization schemes:
227 | The forward function takes in an optional argument reg_schemes. This tuple
228 | should contain the names of regularization schemes you would like to use in training.
229 |
230 | The choices for regularization schemes are:
231 | - curl (use when performing curl regularization)
232 | - dhdt (use when imposing a soft energy flux rate prior)
233 |
234 | NOTE: for known gen. Hamiltonian use GHNNwHPrior class
235 | """
236 |
237 | def __init__(self, ndim, nhidden, nlayers, prior={'H': 'H0'}):
238 | super(GHNN, self).__init__()
239 | self.H_prior = prior.get('H', None)
240 | self.dHdt_prior = prior.get('dHdt', None)
241 |
242 | if self.H_prior == 'H0':
243 | self.H = ScalarFuncPosUnbnd(ndim, nhidden, nlayers)
244 | elif self.H_prior == 'H1':
245 | self.H = ScalarFuncZeroPos(ndim, nhidden, nlayers)
246 | else:
247 | self.H = ScalarFuncZero(ndim, nhidden, nlayers)
248 |
249 | self.J = ZeroDivMat(
250 | ndim, max(1, int(nhidden/(ndim*(ndim-1)/2))), nlayers)
251 | self.ndim = ndim
252 |
253 | self.hh_strict = False
254 | self.conservative = False
255 | if self.dHdt_prior == 'decreasing':
256 | self.v = ConcaveFuncZero(ndim, nhidden, nlayers)
257 | elif self.dHdt_prior == 'constant':
258 | self.conservative = True
259 | else:
260 | self.hh_strict = True
261 | self.fd = ScalarFuncZero(ndim, nhidden, nlayers)
262 |
263 | tri_l_ind = torch.ones(ndim, ndim).tril(-1) == 1
264 | self.register_buffer('tri_l_ind', tri_l_ind)
265 |
266 | def forward(self, t, x, reg_schemes=()):
267 | in_shape = x.shape
268 | x = x.reshape(-1, self.ndim).clone() # note: clone likely unnecessary
269 | with torch.enable_grad():
270 | if not x.requires_grad:
271 | x.requires_grad = True
272 | grad_H, J_grad_H = self.conservative_forward(x)
273 |
274 | if self.conservative:
275 | grad_v, R_grad_H = torch.zeros_like(
276 | grad_H), torch.zeros_like(J_grad_H)
277 | else:
278 | grad_v, R_grad_H = self.nonconservative_forward(x, grad_H)
279 | dxdt = J_grad_H + R_grad_H
280 | dxdt = dxdt.reshape(in_shape)
281 | if not reg_schemes:
282 | out = dxdt
283 | else:
284 | reg_vals = []
285 | for reg in reg_schemes:
286 | if reg == 'dhdt':
287 | reg_vals.append(torch.einsum(
288 | 'ij,ij->i', grad_H, R_grad_H).reshape(in_shape[:-1]))
289 | elif reg == 'curl':
290 | reg_vals.append(curl(R_grad_H, x))
291 | # reg_vals.append(self._curl_R_grad_H(x, grad_H, grad_v))
292 | else:
293 | print('Reg. scheme: {} not recognized'.format(reg))
294 | out = (dxdt, reg_vals)
295 | return out
296 |
297 | def forward_odeint(self, t, x):
298 | """ Forward method for ODE integration tools.
299 | NOTE: kept only for compatibility with ghnn v0.0.2; please use forward().
300 | """
301 | if not x.requires_grad:
302 | x.requires_grad = True
303 | out = self.forward(t, x)
304 | return out.detach()
305 |
306 | def conservative_forward(self, x):
307 | grad_H = torch.autograd.grad(self.H(x).sum(), x, create_graph=True)[0]
308 | J_grad_H = torch.einsum('ijk,ik->ij', self.J(x), grad_H)
309 | return grad_H, J_grad_H
310 |
311 | def nonconservative_forward(self, x, grad_H):
312 | if self.hh_strict:
313 | div_f = torch.autograd.grad(
314 | self.fd(x).sum(), x, create_graph=True)[0]
315 | R_grad_H = div_f
316 | grad_v = None
317 | elif self.conservative:
318 | grad_v = torch.zeros_like(grad_H, requires_grad=True)*x
319 | R_grad_H = torch.zeros_like(grad_H, requires_grad=True)*x
320 | else:
321 | grad_v = torch.autograd.grad(
322 | self.v(x).sum(), x, create_graph=True)[0]
323 | R_grad_H = torch.autograd.grad(
324 | grad_v, x, create_graph=True, grad_outputs=grad_H)[0]
325 | return grad_v, R_grad_H
326 |
327 | def _curl_R_grad_H(self, x, grad_H, grad_v):
328 | """ Note: needs serious optimization
329 | """
330 | if self.hh_strict:
331 | print('Warning, R*gradH is stable by construction.')
332 | return torch.zeros(x.shape, device=x.device)
333 | else:
334 | hess_H = self.hessian(grad_H, x)
335 | hess_v = self.hessian(grad_v, x)
336 | VtrH = torch.einsum('ijk,ikl->ijl', hess_H, hess_v)
337 | curl_mat = VtrH - VtrH.transpose(1, 2)
338 | return curl_mat[:, self.tri_l_ind]
339 |
340 | def hessian(self, grad_f, x):
341 | hess = []
342 | for j in range(x.shape[1]):
343 | hess.append(torch.autograd.grad(
344 | grad_f[:, j].sum(), x, create_graph=True)[0])
345 | return torch.stack(hess, dim=2)
346 |
347 | def _set_requires_grad_true(self, x):
348 | if not x.requires_grad:
349 | x.requires_grad = True
350 |
351 | def _set_requires_grad_false(self, x):
352 | if x.requires_grad:
353 | x.requires_grad = False
354 |
355 |
356 | class GHNNwHPrior(nn.Module):
357 | """
358 | INPUTS
359 | n_dim < int > : number of input dimensions
360 | n_hidden < int > : number of hidden units
361 | n_layers < int > : number of hidden layers
362 | H < function (tensor) -> (tensor) > : known gen. Hamiltonian
363 |
364 | Example:
365 | model = GHNNwHPrior(2, 200, 3, H)
366 | # (train the model)
367 | x = torch.randn(1,2)
368 | # get value of derivative at x
369 | dxdt = model(0., x)
370 | # get J at x (cached from the most recent forward call)
371 | J = model.J()
372 | # get R at x (cached from the most recent forward call)
373 | R = model.R()
374 | """
375 |
376 | def __init__(self, n_dim, n_hidden, n_layers, H):
377 | super(GHNNwHPrior, self).__init__()
378 | self.n_dim = n_dim
379 | self.n_hidden = n_hidden
380 | self.n_layers = n_layers
381 | self.H = H
382 | # setting up linear layers
383 | layers = [nn.Linear(n_dim, n_hidden), ]
384 | for j in range(n_layers-1):
385 | layers.append(nn.Softplus())
386 | layers.append(nn.Linear(n_hidden, n_hidden))
387 | layers.append(nn.Softplus())
388 | num_out = int(n_dim**2)
389 | layers.append(nn.Linear(n_hidden, num_out))
390 | self.mlp = nn.Sequential(*layers)
391 | # setting up W matrix
392 | # offset=n_dim makes tril_indices return every (row, col) index pair
393 | self.ind = torch.tril_indices(n_dim, n_dim, offset=n_dim)
394 | self.register_buffer('W', torch.zeros(n_dim, n_dim))
395 | self.W_tmp = None
396 |
397 | def forward(self, t, input):
398 | in_shape = input.shape
399 | x = input.reshape(-1, in_shape[-1])
400 | W = self.W.repeat(x.shape[0], 1, 1)
401 | W[:, self.ind[0], self.ind[1]] = self.mlp(x)
402 | self.W_tmp = W.detach().clone()
403 | with torch.enable_grad():
404 | if not x.requires_grad:
405 | x.requires_grad = True
406 | grad_H = torch.autograd.grad(
407 | self.H(x).sum(), x, create_graph=True)[0]
408 | out = (W @ grad_H.unsqueeze(-1)).squeeze(-1)
409 | return out.reshape(in_shape)
410 |
411 | def J(self):
412 | return 0.5*(self.W_tmp - self.W_tmp.transpose(2, 1))
413 |
414 | def R(self):
415 | return 0.5*(self.W_tmp + self.W_tmp.transpose(2, 1))
416 |
417 |
418 | class HNN(nn.Module):
419 | "Assumes ndim = 2n where n is the dimension of q"
420 |
421 | def __init__(self, ndim, nhidden, nlayers):
422 | super(HNN, self).__init__()
423 | self.H = ScalarFuncZero(ndim, nhidden, nlayers)
424 | self.ndim = ndim
425 | n = int(ndim/2)
426 | top = torch.cat([torch.zeros(n, n), torch.eye(n)], dim=1)
427 | bot = torch.cat([-torch.eye(n), torch.zeros(n, n)], dim=1)
428 | self.register_buffer('J', torch.cat([top, bot], dim=0))
429 |
430 | def forward(self, t, x):
431 | # assumes x is the form [[q1, q2, ..., qn, p1, p2, ..., pn], ...]
432 | in_shape = x.shape
433 | x = x.view(-1, self.ndim).clone()
434 | with torch.enable_grad():
435 | if not x.requires_grad:
436 | x.requires_grad = True
437 | H = self.H(x)
438 | gradH = torch.autograd.grad(H.sum(), x, create_graph=True)[0]
439 | dxdt = gradH @ self.J.T
440 |
441 | return dxdt.reshape(in_shape)
442 |
443 |
444 | class ODEFCN(nn.Module):
445 | def __init__(self, ndim, nhidden, nlayers):
446 | super(ODEFCN, self).__init__()
447 | self.linears = nn.ModuleList([nn.Linear(ndim, nhidden)])
448 | hidden_layers = [nn.Linear(nhidden, nhidden)
449 | for j in range(nlayers - 1)]
450 | self.linears.extend(hidden_layers)
451 | self.activation = nn.Softplus()
452 | self.out = nn.Linear(nhidden, ndim)
453 | self.mlp = self._mlp_forward
454 |
455 | def _mlp_forward(self, x):
456 | for layer in self.linears:
457 | x = self.activation(layer(x))
458 | return self.out(x)
459 |
460 | def forward(self, t, x):
461 | return self.mlp(x)
462 |
463 | def H(self, x):
464 | return 0.*x.sum(dim=-1)
465 |
466 |
467 | class ReHU(nn.Module):
468 | """ Rectified Huber unit
469 | from: https://github.com/locuslab/stable_dynamics/blob/master/models/stabledynamics.py
470 | """
471 |
472 | def __init__(self, d):
473 | super(ReHU, self).__init__()
474 | self.a = 1/d
475 | self.b = -d/2
476 |
477 | def forward(self, x):
478 | return torch.max(torch.clamp(torch.sign(x)*self.a/2*x**2, min=0, max=-self.b), x+self.b)
479 |
--------------------------------------------------------------------------------
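A short numerical check may help explain why `ZeroDivMat.forward` returns `J_ret - J_ret.transpose(2, 1)`. For any skew-symmetric J, the product grad_H · (J grad_H) vanishes, so the conservative term in `conservative_forward` cannot change H along a trajectory. The sketch below is only an illustration under stated assumptions: the quadratic `H` and the random matrix `A` are hypothetical stand-ins, not the networks defined in `_models.py`.

```python
import torch

torch.manual_seed(0)

# Hypothetical batch of 5 states in 3 dimensions.
x = torch.randn(5, 3, requires_grad=True)

# Stand-in Hamiltonian (assumption): H(x) = 0.5 * |x|^2.
H = 0.5 * (x ** 2).sum(dim=-1)
grad_H = torch.autograd.grad(H.sum(), x, create_graph=True)[0]

# Any matrix minus its transpose is skew-symmetric, as in ZeroDivMat.forward.
A = torch.randn(5, 3, 3)
J = A - A.transpose(2, 1)

# Same contraction pattern as conservative_forward.
J_grad_H = torch.einsum('ijk,ik->ij', J, grad_H)

# Rate of change of H under dx/dt = J grad_H; zero up to round-off.
dHdt = torch.einsum('ij,ij->i', grad_H, J_grad_H)
print(dHdt)
```

The same contraction, `torch.einsum('ij,ij->i', grad_H, R_grad_H)`, is what the `'dhdt'` regularization scheme evaluates for the non-conservative term.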
/weakformghnn/_src/_vector_calc.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | __all__ = ['divergence', 'curl']
4 |
5 |
6 | def divergence(f, x):
7 | """ Computes the divergence for a function f at points x
8 | INPUTS:
9 | f < tensor > : vector function values (mxn)
10 | x < tensor > : (mxn) input vector
11 | OUTPUTS
12 | divergence < tensor > : (m,)
13 | """
14 | div = []
15 | for j in range(x.shape[1]):
16 | grad_f = torch.autograd.grad(
17 | f[:, j].sum(), x, create_graph=True, allow_unused=True)[0]
18 | div.append(grad_f[:, j])
19 | return torch.stack(div).sum(dim=0)
20 |
21 |
22 | def curl(f, x):
23 | """ Computes the curl for a function f at points x
24 | INPUTS:
25 | f < tensor > : vector function values (mxn)
26 | x < tensor > : mxn input vector
27 |
28 | OUTPUTS:
29 | curl < tensor > : mx(n*(n-1)/2)
30 | """
31 | N = x.shape[1]
32 | grad_f_array = []
33 | for i in range(N):
34 | grad_f = torch.autograd.grad(
35 | f[:, i].sum(), x, allow_unused=True, create_graph=True)[0]
36 | grad_f_array.append(grad_f)
37 | cu = []
38 | for i in range(N):
39 | for j in range(N):
40 | if i >= j:
41 | continue
42 | else:
43 | c_ij = grad_f_array[j][:, i] - grad_f_array[i][:, j]
44 | cu.append(c_ij)
45 | return torch.stack(cu, dim=1)
46 |
--------------------------------------------------------------------------------
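To make the sign convention of `curl` concrete, the sketch below re-implements the same pairwise formula, c_ij = d f_j / d x_i - d f_i / d x_j for i < j, in compact form and applies it to two hand-picked fields. `curl_sketch`, the rotation field and the potential `phi` are illustrative names and choices assumed here, not part of the package.

```python
import torch


def curl_sketch(f, x):
    """Pairwise curl entries d f_j/d x_i - d f_i/d x_j for i < j (sketch)."""
    n = x.shape[1]
    grads = [
        torch.autograd.grad(f[:, i].sum(), x, create_graph=True,
                            allow_unused=True)[0]
        for i in range(n)
    ]
    cols = [grads[j][:, i] - grads[i][:, j]
            for i in range(n) for j in range(i + 1, n)]
    return torch.stack(cols, dim=1)


torch.manual_seed(0)
x = torch.randn(4, 2, requires_grad=True)

# Rigid rotation f = (-x1, x0): the single curl entry is 1 - (-1) = 2.
rot = torch.stack([-x[:, 1], x[:, 0]], dim=1)
curl_rot = curl_sketch(rot, x)

# Gradient of a scalar potential: the curl vanishes identically.
phi = (x ** 3).sum(dim=-1) + x[:, 0] * x[:, 1]
grad_phi = torch.autograd.grad(phi.sum(), x, create_graph=True)[0]
curl_grad = curl_sketch(grad_phi, x)

print(curl_rot.squeeze(-1), curl_grad.squeeze(-1))
```

The second case corresponds to the `hh_strict` branch of `GHNN`, where the dissipative term is the gradient of a scalar network and so has no curl to penalize.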
/weakformghnn/_src/_weak_form.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | __all__ = ['gauss_rbf',
4 | 'poly_bf',
5 | 'weak_form_loss']
6 |
7 |
8 | def gauss_rbf(t, c, eps, minimum=0.):
9 | """ Returns the gaussian rbf for a collection of points t
10 | INPUTS:
11 | t < torch.Tensor (m,) > : inputs to gaussian rbfs
12 | c < torch.Tensor (M,) > : centers of rbfs (where M is the number of GRBFS)
13 | eps < float > : shape parameter of gaussian radial basis functions
14 | minimum < float > : minimum grbf function value
15 | RETURNS:
16 | gauss_rbfs < torch.Tensor(m,M) > : grbfs evaluated at t's
17 | gauss_rbf_derivs < torch.Tensor(m,M) > : grbf derivs evaluated at t's
18 | """
19 | diff = t.view(-1, 1) - c.view(1, -1) # m x M
20 | grbf = torch.exp(-eps**2 * torch.pow(diff, 2)) + minimum
21 | grbf_deriv = torch.mul(-2*eps**2*diff, grbf)
22 | return grbf, grbf_deriv
23 |
24 |
25 | def poly_bf(t, c, deg):
26 | """ Returns polynomial basis funcs and their derivatives centered at c evaluated at points t
27 | INPUTS:
28 | t < torch.Tensor (m,) > : inputs to polynomial basis functions
29 | c < torch.Tensor (M,) > : centers of basis functions (where M is the number of centers)
30 | deg < int > : polynomial bf degree (note should be < ~ 10)
31 | RETURNS:
32 | poly_bfs < torch.Tensor(m, M*(deg + 1)) > : poly basis functions evaluated at t's
33 | poly_bf_derivs < torch.Tensor(m, M*(deg + 1)) > : poly basis function derivatives evaluated at t's
34 | """
35 | diff = t.view(-1, 1) - c.view(1, -1)
36 | poly = torch.cat([torch.pow(diff, j) for j in range(deg+1)], dim=1)
37 | poly_deriv = torch.cat(
38 | [torch.ones_like(diff)*j for j in range(deg+1)], dim=1)
39 | poly_deriv[:, 2*diff.shape[1]:] = poly_deriv[:, 2*diff.shape[1]:] * \
40 | poly[:, diff.shape[1]:-diff.shape[1]]
41 | return poly, poly_deriv
42 |
43 |
44 | def weak_form_loss(dx_est, x, t, psi, psi_dot):
45 | """ Returns the weak-form ode model loss
46 | INPUTS:
47 | dx_est < torch.Tensor, (Bs, m, ndim) > : estimate for derivative
48 | x < torch.Tensor, (Bs, m, ndim) > : state training point
49 | t < torch.Tensor, (m,) > : integration times
50 | psi < torch.Tensor, (m,M) > : test functions evaluated at measurement times
51 | psi_dot < torch.Tensor, (m,M) > : test function derivatives evaluated at measurement times
52 | OUTPUTS:
53 | weak_form_loss < torch.float > : weak-form squared loss
54 | """
55 | # boundary term
56 | RH_bound = torch.einsum('in,m->inm', x[:, -1, :], psi[-1])
57 | LH_bound = torch.einsum('in,m->inm', x[:, 0, :], psi[0])
58 | B = RH_bound - LH_bound
59 |
60 | x_psi_dot = torch.einsum('imn,ml->imnl', x, psi_dot)
61 | f_psi = torch.einsum('imn,ml->imnl', dx_est, psi)
62 |
63 | L = torch.trapz(f_psi + x_psi_dot, t, dim=1) - B
64 | return L.pow(2).sum(-1).mean()
65 |
--------------------------------------------------------------------------------
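The residual in `weak_form_loss` comes from integration by parts: for each test function psi, the integral of (f psi + x psi_dot) over the time window equals the boundary term x psi evaluated at the two ends whenever f is the true derivative of x. The sketch below checks this on a toy problem. It inlines the same formulas as `gauss_rbf` and `weak_form_loss` so that it runs on its own. The trajectory x(t) = exp(-t), the centers `c`, the shape parameter `eps` and the helper name `weak_residual` are all assumptions made for the example.

```python
import torch

# Measurement times and Gaussian test functions (values are illustrative).
t = torch.linspace(0., 2., 201)
c = torch.linspace(0.25, 1.75, 4)
eps = 2.0
diff = t.view(-1, 1) - c.view(1, -1)          # (m, M)
psi = torch.exp(-eps**2 * diff**2)            # test functions
psi_dot = -2 * eps**2 * diff * psi            # their time derivatives

# Toy trajectory x(t) = exp(-t), shaped (Bs=1, m, ndim=1); true dx/dt = -x.
x = torch.exp(-t).view(1, -1, 1)


def weak_residual(dx_est):
    """Same computation as weak_form_loss, inlined for a standalone check."""
    bound = (torch.einsum('in,m->inm', x[:, -1, :], psi[-1])
             - torch.einsum('in,m->inm', x[:, 0, :], psi[0]))
    integrand = (torch.einsum('imn,ml->imnl', dx_est, psi)
                 + torch.einsum('imn,ml->imnl', x, psi_dot))
    L = torch.trapz(integrand, t, dim=1) - bound
    return L.pow(2).sum(-1).mean()


loss_true = weak_residual(-x)          # correct derivative: near zero
loss_wrong = weak_residual(-0.5 * x)   # wrong decay rate: clearly larger
print(float(loss_true), float(loss_wrong))
```

Because the loss only needs `x`, `psi` and `psi_dot` at the measurement times, it never differentiates the data directly. What remains for the correct derivative is the trapezoid-rule discretization error.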