├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── __init__.py
│   ├── diff_source_2d.py
│   ├── lorenz.py
│   ├── lorenz_noise.py
│   ├── nlse_1d.py
│   ├── reac_diff_2d.py
│   ├── rossler.py
│   └── utils.py
├── diff_source_2d_model.py
├── encoder
│   ├── __init__.py
│   ├── embedding.py
│   └── utils.py
├── lorenz_model.py
├── lorenz_model_extrahidden.py
├── lorenz_model_noise.py
├── nlse_1d_model_embed.py
├── reac_diff_2d_model.py
├── rossler_model.py
├── symder
│   ├── __init__.py
│   ├── odeint_zero.py
│   ├── sym_models.py
│   └── symder.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # Data files
7 | *.npy
8 | *.npz
9 | *.pt
10 | 
11 | # Output files
12 | out_*
13 | 
14 | # Notebook checkpoints
15 | *.ipynb_checkpoints/
16 | 
17 | # VSCode settings
18 | .vscode/
19 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 Peter Y. Lu
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # **SymDer**: **Sym**bolic **Der**ivative Network for *Discovering Sparse Interpretable Dynamics from Partial Observations*
2 | 
3 | Implementation of a machine learning method for identifying the governing equations of a nonlinear dynamical system using only partial observations. Our framework combines an encoder for state reconstruction with a sparse symbolic model. To train the model by matching time derivatives, we implement an algorithmic trick (see `symder/odeint_zero.py`) for taking higher-order derivatives of a variable that is implicitly defined by a differential equation (i.e., the symbolic model); a toy sketch of this idea is included under Usage below.
4 | 
5 | This is the official repository for the paper "**Discovering sparse interpretable dynamics from partial observations**" (https://doi.org/10.1038/s42005-022-00987-z); please cite it and see the paper for more details.
6 | 
7 | ## Requirements
8 | 
9 | JAX >= 0.2.8, Haiku >= 0.0.4, scikit-learn, NumPy, SciPy
10 | 
11 | ## Usage
12 | 
13 | Data generation scripts are contained in `data/`. Encoder models and related tools are contained in `encoder/`. 
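For example, each dataset can be generated on the fly (or reloaded if it already exists) through `data.utils.get_dataset`. The snippet below is a minimal sketch for the Lorenz system that simply mirrors the defaults used in `lorenz_model.py`; the file path and keyword arguments are illustrative:

```python
from data.utils import get_dataset
from data.lorenz import generate_dataset

dt = 1e-2
# Generates ./data/lorenz.npz on the first call and reloads it afterwards.
scaled_data, scale = get_dataset(
    "./data/lorenz.npz",
    generate_dataset,
    dt=dt,
    tmax=100 + 2 * dt,
    num_visible=2,
    visible_vars=[0, 1],
    num_der=2,
)
# scaled_data: (time, num_visible, num_der + 1) observations and their
# finite-difference time derivatives; scale: (num_visible, num_der + 1).
print(scaled_data.shape, scale.shape)
```

The derivative-matching trick mentioned above lives in `symder/odeint_zero.py` and is not reproduced here; purely as a toy illustration of the underlying idea, if `dz/dt = f(z)` then `d^2z/dt^2 = J_f(z) f(z)`, which `jax.jvp` computes without ever forming the Jacobian:

```python
import jax
import jax.numpy as jnp


def first_two_time_derivatives(f, z):
    """Toy sketch: time derivatives of z(t) implicitly defined by dz/dt = f(z)."""
    dzdt = f(z)
    # Pushing the tangent dz/dt through f gives d/dt f(z(t)) = J_f(z) @ dz/dt.
    _, d2zdt2 = jax.jvp(f, (z,), (dzdt,))
    return dzdt, d2zdt2


# Example: harmonic oscillator z = (x, v) with dz/dt = (v, -x).
f = lambda z: jnp.stack([z[1], -z[0]])
print(first_two_time_derivatives(f, jnp.array([1.0, 0.0])))  # (0, -1) and (-1, 0)
```
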
Symbolic models and the tools for taking higher order symbolic time derivatives are contained in `symder/`. The individual `*_model.py` files provide examples of how to use our method on a variety of ODE and PDE systems. -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["utils", "lorenz", "reac_diff_2d", "nlse_1d"] 2 | from . import * 3 | -------------------------------------------------------------------------------- /data/diff_source_2d.py: -------------------------------------------------------------------------------- 1 | from jax import lax 2 | import numpy as np 3 | from numpy.fft import fftfreq, fft2, ifft2 4 | from sklearn.preprocessing import StandardScaler 5 | 6 | from .utils import solve_ETDRK4, generate_diff_kernels 7 | 8 | __all__ = ["generate_dataset"] 9 | 10 | 11 | def generate_dataset( 12 | sys_size=64, mesh=64, dt=5e-2, tspan=None, num_der=2, seed=0, raw_sol=False, 13 | ): 14 | if tspan is None: 15 | tspan = (0, 50 + 2 * dt) 16 | 17 | kx = np.expand_dims(2 * np.pi * fftfreq(mesh, d=sys_size / mesh), axis=-1) 18 | ky = np.expand_dims(2 * np.pi * fftfreq(mesh, d=sys_size / mesh), axis=0) 19 | 20 | # Initial condition 21 | np.random.seed(seed) 22 | krange = 1 23 | envelope = np.exp(-1 / (2 * krange ** 2) * (kx ** 2 + ky ** 2)) 24 | v0 = envelope * ( 25 | np.random.normal(loc=0, scale=1.0, size=(2, mesh, mesh)) 26 | + 1j * np.random.normal(loc=0, scale=1.0, size=(2, mesh, mesh)) 27 | ) 28 | u0 = np.real(ifft2(v0)) 29 | # normalize 30 | u0 = u0 / np.max(np.abs(u0), axis=(-2, -1), keepdims=True) 31 | 32 | n_rects = 50 33 | u0[1] = np.zeros((1, mesh, mesh)) 34 | rect_pos = ( 35 | np.random.uniform(0, sys_size, size=(n_rects, 2)) * mesh / sys_size 36 | ).astype(int) 37 | rect_size = ( 38 | np.random.uniform(0, 0.05 * sys_size, size=(n_rects, 2)) * mesh / sys_size 39 | ).astype(int) 40 | rect_value = np.random.uniform(0, 0.2, size=(n_rects,)) 41 | for i in range(n_rects): 42 | rect = np.zeros((mesh, mesh), dtype=bool) 43 | rect[: rect_size[i, 0], : rect_size[i, 1]] = True 44 | rect = np.roll(np.roll(rect, rect_pos[i, 0], axis=0), rect_pos[i, 1], axis=1) 45 | u0[1, :, :] = u0[1, :, :] * (1 - rect) + rect_value[i] * rect 46 | 47 | # Differential equation definition 48 | D2 = -(kx ** 2 + ky ** 2) 49 | L = np.stack((0.2 * D2, np.zeros_like(D2))) 50 | 51 | def N(v): 52 | v2 = v[..., 1, :, :] 53 | dv = np.stack([1 * v2, -0.1 * v2], axis=-3) 54 | return dv 55 | 56 | # Solve using ETDRK4 method 57 | print("Generating 2D diffusion with source dataset...") 58 | sol_u = solve_ETDRK4(L, N, fft2(u0), tspan, dt, lambda v: np.real(ifft2(v))) 59 | data = sol_u[:, 0].reshape(sol_u.shape[0], 1 * mesh ** 2) 60 | data = data.T 61 | 62 | # Compute finite difference derivatives 63 | kernels = generate_diff_kernels(num_der) 64 | data = lax.conv(data[:, None, :], kernels[:, None, :], (1,), "VALID") 65 | # time, mesh**2, num_visible, num_der+1 66 | data = data[None, ...].transpose((3, 1, 0, 2)) 67 | 68 | # Rescale/normalize data 69 | reshaped_data = data.reshape(-1, data.shape[2] * data.shape[3]) 70 | scaler = StandardScaler(with_mean=False) 71 | scaler.fit(reshaped_data) 72 | # scaler.scale_ /= scaler.scale_[0] 73 | scaled_data = scaler.transform(reshaped_data) 74 | # time, mesh, mesh, num_visible, num_der+1 75 | scaled_data = scaled_data.reshape(-1, mesh, mesh, 1, num_der + 1) 76 | 77 | return ( 78 | scaled_data, 79 | scaler.scale_.reshape(1, num_der + 1), 80 | 
sol_u if raw_sol else None, 81 | ) 82 | -------------------------------------------------------------------------------- /data/lorenz.py: -------------------------------------------------------------------------------- 1 | from jax import lax 2 | import numpy as np 3 | from scipy.integrate import solve_ivp 4 | from sklearn.preprocessing import StandardScaler 5 | 6 | from .utils import generate_diff_kernels 7 | 8 | __all__ = ["generate_dataset"] 9 | 10 | 11 | def generate_dataset( 12 | dt=1e-2, tmax=None, num_visible=2, num_der=2, visible_vars=None, raw_sol=False, 13 | ): 14 | if tmax is None: 15 | tmax = 100 + 2 * dt 16 | if visible_vars is None: 17 | visible_vars = list(range(num_visible)) 18 | else: 19 | assert len(visible_vars) == num_visible 20 | 21 | def lorenz(t, y0, sigma, beta, rho): 22 | """Lorenz equations""" 23 | u, v, w = y0[..., 0], y0[..., 1], y0[..., 2] 24 | up = -sigma * (u - v) 25 | vp = rho * u - v - u * w 26 | wp = -beta * w + u * v 27 | return np.stack((up, vp, wp), axis=-1) 28 | 29 | # Lorenz parameters and initial conditions 30 | sigma, beta, rho = 10, 8 / 3.0, 28 31 | u0, v0, w0 = 0, 1, 1.05 32 | 33 | # Integrate the Lorenz equations on the time grid t 34 | print("Generating Lorenz system dataset...") 35 | t_eval = np.arange(0, tmax, dt) 36 | sol = solve_ivp( 37 | lorenz, 38 | (0, tmax), 39 | y0=np.stack((u0, v0, w0), axis=-1), 40 | t_eval=t_eval, 41 | args=(sigma, beta, rho), 42 | ) 43 | data = sol.y[visible_vars] 44 | 45 | # Compute finite difference derivatives 46 | kernels = generate_diff_kernels(num_der) 47 | data = lax.conv(data[:, None, :], kernels[:, None, :], (1,), "VALID") 48 | 49 | # Rescale/normalize data 50 | reshaped_data = data.reshape(-1, data.shape[2]) 51 | scaler = StandardScaler(with_mean=False) 52 | scaled_data = scaler.fit_transform(reshaped_data.T) 53 | scaled_data = scaled_data.reshape( 54 | scaled_data.shape[0], data.shape[0], data.shape[1] 55 | ) 56 | 57 | return ( 58 | scaled_data, 59 | scaler.scale_.reshape(num_visible, num_der + 1), 60 | sol.y if raw_sol else None, 61 | ) 62 | -------------------------------------------------------------------------------- /data/lorenz_noise.py: -------------------------------------------------------------------------------- 1 | from jax import lax 2 | import numpy as np 3 | from scipy.integrate import solve_ivp 4 | from sklearn.preprocessing import StandardScaler 5 | 6 | from .utils import generate_diff_kernels 7 | from scipy.signal import savgol_filter 8 | 9 | __all__ = ["generate_dataset"] 10 | 11 | 12 | def generate_dataset( 13 | dt=1e-2, 14 | tmax=None, 15 | num_visible=2, 16 | num_der=2, 17 | visible_vars=None, 18 | noise=None, 19 | rng=np.random.default_rng(0), 20 | smoothing_params=None, 21 | raw_sol=False, 22 | ): 23 | if tmax is None: 24 | tmax = 100 + 2 * dt 25 | if visible_vars is None: 26 | visible_vars = list(range(num_visible)) 27 | else: 28 | assert len(visible_vars) == num_visible 29 | 30 | def lorenz(t, y0, sigma, beta, rho): 31 | """Lorenz equations""" 32 | u, v, w = y0[..., 0], y0[..., 1], y0[..., 2] 33 | up = -sigma * (u - v) 34 | vp = rho * u - v - u * w 35 | wp = -beta * w + u * v 36 | return np.stack((up, vp, wp), axis=-1) 37 | 38 | # Lorenz parameters and initial conditions 39 | sigma, beta, rho = 10, 8 / 3.0, 28 40 | u0, v0, w0 = 0, 1, 1.05 41 | 42 | # Integrate the Lorenz equations on the time grid t 43 | print("Generating Lorenz system dataset...") 44 | t_eval = np.arange(0, tmax, dt) 45 | sol = solve_ivp( 46 | lorenz, 47 | (0, tmax), 48 | y0=np.stack((u0, v0, w0), axis=-1), 49 
| t_eval=t_eval, 50 | args=(sigma, beta, rho), 51 | ) 52 | data = sol.y[visible_vars] 53 | 54 | if noise is not None: 55 | # Add noise 56 | data_no_noise = data.copy() 57 | data += rng.normal(scale=noise, size=data.shape) 58 | 59 | if smoothing_params is not None: 60 | # Smoothing 61 | data = savgol_filter(data, smoothing_params[0], smoothing_params[1]) 62 | 63 | print(np.mean(np.abs(data - data_no_noise))) 64 | 65 | # Compute finite difference derivatives 66 | kernels = generate_diff_kernels(num_der) 67 | data = lax.conv(data[:, None, :], kernels[:, None, :], (1,), "VALID") 68 | 69 | # Rescale/normalize data 70 | reshaped_data = data.reshape(-1, data.shape[2]) 71 | scaler = StandardScaler(with_mean=False) 72 | scaled_data = scaler.fit_transform(reshaped_data.T) 73 | scaled_data = scaled_data.reshape( 74 | scaled_data.shape[0], data.shape[0], data.shape[1] 75 | ) 76 | 77 | return ( 78 | scaled_data, 79 | scaler.scale_.reshape(num_visible, num_der + 1), 80 | sol.y if raw_sol else None, 81 | ) 82 | -------------------------------------------------------------------------------- /data/nlse_1d.py: -------------------------------------------------------------------------------- 1 | from jax import lax 2 | import numpy as np 3 | from numpy.fft import fftfreq, fft, ifft 4 | from scipy.signal import resample 5 | from sklearn.preprocessing import StandardScaler 6 | 7 | from .utils import solve_ETDRK4, generate_diff_kernels 8 | 9 | __all__ = ["generate_dataset"] 10 | 11 | 12 | def generate_dataset( 13 | sys_size=2 * np.pi, 14 | mesh=64, 15 | dt=1e-3, # 10 * 1e-4 16 | tspan=None, 17 | pool=16, 18 | tpool=10, 19 | num_der=2, 20 | seed=0, 21 | squared=False, 22 | raw_sol=False, 23 | ): 24 | if tspan is None: 25 | tspan = (0, 0.5 + 4 * dt) 26 | 27 | out_mesh, mesh = mesh, pool * mesh 28 | dt /= tpool 29 | 30 | k = 2 * np.pi * fftfreq(mesh, d=sys_size / mesh) 31 | 32 | # Initial condition 33 | np.random.seed(seed) 34 | krange = 1.0 35 | envelope = np.exp(-1 / (2 * krange ** 2) * k ** 2) 36 | np.random.seed(0) 37 | v0 = envelope * ( 38 | np.random.normal(loc=0, scale=1.0, size=(1, mesh)) 39 | + 1j * np.random.normal(loc=0, scale=1.0, size=(1, mesh)) 40 | ) 41 | u0 = ifft(v0) 42 | u0 = ( 43 | np.sqrt(2 * mesh) * u0 / np.linalg.norm(u0, axis=(-2, -1), keepdims=True) 44 | ) # normalize 45 | v0 = fft(u0) 46 | 47 | # Differential equation definition 48 | L = -0.5j * k ** 2 49 | 50 | def N(v): 51 | u = ifft(v) 52 | kappa = -1 53 | return -1j * kappa * fft(np.abs(u) ** 2 * u) 54 | 55 | # Solve using ETDRK4 method 56 | print("Generating 1D nonlinear Schrödinger dataset...") 57 | sol_u = solve_ETDRK4(L, N, v0, tspan, dt, lambda v: ifft(v)) 58 | sol_u = resample(sol_u[::tpool], out_mesh, axis=-1) 59 | if squared: 60 | data = (np.abs(sol_u[:, 0]) ** 2).reshape(-1, 1 * out_mesh) 61 | else: 62 | data = np.abs(sol_u[:, 0]).reshape(-1, 1 * out_mesh) 63 | data = data.T 64 | 65 | # Compute finite difference derivatives 66 | kernels = generate_diff_kernels(num_der) 67 | data = lax.conv(data[:, None, :], kernels[:, None, :], (1,), "VALID") 68 | # time, out_mesh, num_visible, num_der+1 69 | data = data[None, ...].transpose((3, 1, 0, 2)) 70 | 71 | # Rescale/normalize data 72 | reshaped_data = data.reshape(-1, data.shape[2] * data.shape[3]) 73 | scaler = StandardScaler(with_mean=False) 74 | scaler.fit(reshaped_data) 75 | scaler.scale_ /= scaler.scale_[0] 76 | scaled_data = scaler.transform(reshaped_data) 77 | # time, out_mesh, num_visible, num_der+1 78 | scaled_data = scaled_data.reshape(-1, out_mesh, 1, num_der + 1) 79 | 80 | 
return ( 81 | scaled_data, 82 | scaler.scale_.reshape(1, num_der + 1), 83 | sol_u if raw_sol else None, 84 | ) 85 | -------------------------------------------------------------------------------- /data/reac_diff_2d.py: -------------------------------------------------------------------------------- 1 | from jax import lax 2 | import numpy as np 3 | from numpy.fft import fftfreq, fft2, ifft2 4 | from sklearn.preprocessing import StandardScaler 5 | 6 | from .utils import solve_ETDRK4, generate_diff_kernels 7 | 8 | __all__ = ["generate_dataset"] 9 | 10 | 11 | def generate_dataset( 12 | sys_size=64, 13 | mesh=64, 14 | dt=5e-2, 15 | tspan=None, 16 | init_mode="rect", 17 | num_der=2, 18 | seed=0, 19 | raw_sol=False, 20 | ): 21 | if tspan is None: 22 | tspan = (0, 50 + 2 * dt) 23 | 24 | kx = np.expand_dims(2 * np.pi * fftfreq(mesh, d=sys_size / mesh), axis=-1) 25 | ky = np.expand_dims(2 * np.pi * fftfreq(mesh, d=sys_size / mesh), axis=0) 26 | 27 | # Initial condition 28 | np.random.seed(seed) 29 | if init_mode == "fourier": 30 | krange = 5 31 | envelope = np.exp(-1 / (2 * krange ** 2) * (kx ** 2 + ky ** 2)) 32 | v0 = envelope * ( 33 | np.random.normal(loc=0, scale=1.0, size=(2, mesh, mesh)) 34 | + 1j * np.random.normal(loc=0, scale=1.0, size=(2, mesh, mesh)) 35 | ) 36 | u0 = np.real(ifft2(v0)) 37 | # normalize 38 | u0 = 0.55 + 0.45 * u0 / np.max(np.abs(u0), axis=(-2, -1), keepdims=True) 39 | u0[..., 0, :, :] = 0.5 40 | elif init_mode == "rect": 41 | n_rects = 50 42 | u0 = 0.5 * np.ones((2, mesh, mesh)) 43 | rect_pos = ( 44 | np.random.uniform(0, sys_size, size=(n_rects, 2)) * mesh / sys_size 45 | ).astype(int) 46 | rect_size = ( 47 | np.random.uniform(0, 0.2 * sys_size, size=(n_rects, 2)) * mesh / sys_size 48 | ).astype(int) 49 | rect_value = np.random.uniform(0.1, 1, size=(n_rects,)) 50 | for i in range(n_rects): 51 | rect = np.zeros((mesh, mesh), dtype=bool) 52 | rect[: rect_size[i, 0], : rect_size[i, 1]] = True 53 | rect = np.roll( 54 | np.roll(rect, rect_pos[i, 0], axis=0), rect_pos[i, 1], axis=1 55 | ) 56 | u0[1, :, :] = u0[1, :, :] * (1 - rect) + rect_value[i] * rect 57 | else: 58 | raise ValueError( 59 | f"init_mode = '{init_mode}' is not valid." 
60 | "init_mode must be in ['fourier', 'rect']" 61 | ) 62 | 63 | # Differential equation definition 64 | D2 = -(kx ** 2 + ky ** 2) 65 | L = np.stack((0.05 * D2, 0.1 * D2)) 66 | 67 | def N(v): 68 | u = np.real(ifft2(v)) 69 | u1 = u[..., 0, :, :] 70 | u2 = u[..., 1, :, :] 71 | du = np.stack([7 / 3.0 * u1 - 8 / 3.0 * u1 * u2, -u2 + u1 * u2], axis=-3) 72 | return fft2(du) 73 | 74 | # Solve using ETDRK4 method 75 | print("Generating 2D reaction-diffusion (diffusive Lotka-Volterra) dataset...") 76 | sol_u = solve_ETDRK4(L, N, fft2(u0), tspan, dt, lambda v: np.real(ifft2(v))) 77 | data = sol_u[:, 0].reshape(sol_u.shape[0], 1 * mesh ** 2) 78 | data = data.T 79 | 80 | # Compute finite difference derivatives 81 | kernels = generate_diff_kernels(num_der) 82 | data = lax.conv(data[:, None, :], kernels[:, None, :], (1,), "VALID") 83 | # time, mesh**2, num_visible, num_der+1 84 | data = data[None, ...].transpose((3, 1, 0, 2)) 85 | 86 | # Rescale/normalize data 87 | reshaped_data = data.reshape(-1, data.shape[2] * data.shape[3]) 88 | scaler = StandardScaler(with_mean=False) 89 | scaler.fit(reshaped_data) 90 | scaler.scale_ /= scaler.scale_[0] 91 | scaled_data = scaler.transform(reshaped_data) 92 | # time, mesh, mesh, num_visible, num_der+1 93 | scaled_data = scaled_data.reshape(-1, mesh, mesh, 1, num_der + 1) 94 | 95 | return ( 96 | scaled_data, 97 | scaler.scale_.reshape(1, num_der + 1), 98 | sol_u if raw_sol else None, 99 | ) 100 | -------------------------------------------------------------------------------- /data/rossler.py: -------------------------------------------------------------------------------- 1 | from jax import lax 2 | import numpy as np 3 | from scipy.integrate import solve_ivp 4 | from sklearn.preprocessing import StandardScaler 5 | 6 | from .utils import generate_diff_kernels 7 | 8 | __all__ = ["generate_dataset"] 9 | 10 | 11 | def generate_dataset(dt=1e-2, tmax=None, num_visible=2, num_der=2, raw_sol=False): 12 | if tmax is None: 13 | tmax = 100 + 2 * dt 14 | 15 | def rossler(t, y0, a, b, c): 16 | """Rossler equations""" 17 | u, v, w = y0[..., 0], y0[..., 1], y0[..., 2] 18 | up = -v - w 19 | vp = u + a * v 20 | wp = b + w * (u - c) 21 | return np.stack((up, vp, wp), axis=-1) 22 | 23 | # Rossler parameters and initial conditions 24 | a, b, c = 0.2, 0.2, 5.7 25 | u0, v0, w0 = 0, 1, 1.05 26 | 27 | # Integrate the Rossler equations on the time grid t 28 | print("Generating Rossler system dataset...") 29 | t_eval = np.arange(0, tmax, dt) 30 | sol = solve_ivp( 31 | rossler, 32 | (0, tmax), 33 | y0=np.stack((u0, v0, w0), axis=-1), 34 | t_eval=t_eval, 35 | args=(a, b, c), 36 | ) 37 | data = sol.y[range(num_visible)] 38 | 39 | # Compute finite difference derivatives 40 | kernels = generate_diff_kernels(num_der) 41 | data = lax.conv(data[:, None, :], kernels[:, None, :], (1,), "VALID") 42 | 43 | # Rescale/normalize data 44 | reshaped_data = data.reshape(-1, data.shape[2]) 45 | scaler = StandardScaler(with_mean=False) 46 | scaled_data = scaler.fit_transform(reshaped_data.T) 47 | scaled_data = scaled_data.reshape( 48 | scaled_data.shape[0], data.shape[0], data.shape[1] 49 | ) 50 | 51 | return ( 52 | scaled_data, 53 | scaler.scale_.reshape(num_visible, num_der + 1), 54 | sol.y if raw_sol else None, 55 | ) 56 | -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | # import jax.numpy as jnp 2 | from jax import lax 3 | import numpy as np 4 | 5 | import os.path 6 | from 
tqdm.auto import tqdm 7 | 8 | __all__ = ["solve_ETDRK4", "generate_diff_kernels", "save_dataset", "load_dataset"] 9 | 10 | 11 | def solve_ETDRK4(L, N, v0, tspan, dt, output_func): 12 | """ETDRK4 method""" 13 | E = np.exp(dt * L) 14 | E2 = np.exp(dt * L / 2.0) 15 | 16 | contour_radius = 1 17 | M = 16 18 | r = contour_radius * np.exp(1j * np.pi * (np.arange(1, M + 1) - 0.5) / M) 19 | 20 | LR = dt * L 21 | LR = np.expand_dims(LR, axis=-1) + r 22 | 23 | Q = dt * np.real(np.mean((np.exp(LR / 2.0) - 1) / LR, axis=-1)) 24 | f1 = dt * np.real( 25 | np.mean( 26 | (-4.0 - LR + np.exp(LR) * (4.0 - 3.0 * LR + LR ** 2)) / LR ** 3, axis=-1 27 | ) 28 | ) 29 | f2 = dt * np.real(np.mean((2.0 + LR + np.exp(LR) * (-2.0 + LR)) / LR ** 3, axis=-1)) 30 | f3 = dt * np.real( 31 | np.mean( 32 | (-4.0 - 3.0 * LR - LR ** 2 + np.exp(LR) * (4.0 - LR)) / LR ** 3, axis=-1 33 | ) 34 | ) 35 | 36 | u = [] 37 | v = v0 38 | for t in tqdm(np.arange(tspan[0], tspan[1], dt)): 39 | u.append(output_func(v)) 40 | 41 | Nv = N(v) 42 | a = E2 * v + Q * Nv 43 | Na = N(a) 44 | b = E2 * v + Q * Na 45 | Nb = N(b) 46 | c = E2 * a + Q * (2.0 * Nb - Nv) 47 | Nc = N(c) 48 | v = E * v + Nv * f1 + 2.0 * (Na + Nb) * f2 + Nc * f3 49 | 50 | return np.stack(u) 51 | 52 | 53 | def generate_diff_kernels(order): 54 | p = int(np.floor((order + 1) / 2)) 55 | 56 | rev_d1 = np.array((0.5, 0.0, -0.5)) 57 | d2 = np.array((1.0, -2.0, 1.0)) 58 | 59 | even_kernels = [np.pad(np.array((1.0,)), (p,))] 60 | for i in range(order // 2): 61 | even_kernels.append(np.convolve(even_kernels[-1], d2, mode="same")) 62 | 63 | even_kernels = np.stack(even_kernels) 64 | odd_kernels = lax.conv( 65 | even_kernels[:, None, :], rev_d1[None, None, :], (1,), "SAME" 66 | ).squeeze(1) 67 | 68 | kernels = np.stack((even_kernels, odd_kernels), axis=1).reshape(-1, 2 * p + 1) 69 | if order % 2 == 0: 70 | kernels = kernels[:-1] 71 | 72 | return kernels 73 | 74 | 75 | def get_dataset(filename, generate_dataset, get_raw_sol=False, **gen_kwargs): 76 | if os.path.isfile(filename): 77 | scaled_data, scale, loaded_gen_kwargs, raw_sol = load_dataset( 78 | filename, get_raw_sol 79 | ) 80 | assert gen_kwargs == loaded_gen_kwargs 81 | else: 82 | scaled_data, scale, raw_sol = generate_dataset( 83 | raw_sol=get_raw_sol, **gen_kwargs 84 | ) 85 | save_dataset(filename, scaled_data, scale, gen_kwargs, raw_sol) 86 | return (scaled_data, scale, raw_sol) if get_raw_sol else (scaled_data, scale) 87 | 88 | 89 | def save_dataset(filename, scaled_data, scale, gen_kwargs, raw_sol=None): 90 | if not os.path.isfile(filename): 91 | print(f"Saving dataset to file: {filename}") 92 | np.savez( 93 | filename, 94 | scaled_data=scaled_data, 95 | scale=scale, 96 | gen_kwargs=gen_kwargs, 97 | raw_sol=raw_sol, 98 | ) 99 | else: 100 | raise FileExistsError(f"{filename} already exists! 
Dataset is not saved.") 101 | 102 | 103 | def load_dataset(filename, get_raw_sol=False): 104 | print(f"Loading dataset from file: {filename}") 105 | dataset = np.load(filename, allow_pickle=True) 106 | return ( 107 | dataset["scaled_data"], 108 | dataset["scale"], 109 | dataset["gen_kwargs"], 110 | dataset["raw_sol"] if get_raw_sol else None, 111 | ) 112 | -------------------------------------------------------------------------------- /diff_source_2d_model.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from jax import lax 4 | import numpy as np 5 | 6 | import haiku as hk 7 | import optax 8 | 9 | import os.path 10 | import argparse 11 | from functools import partial 12 | 13 | # from tqdm.auto import tqdm 14 | 15 | from data.utils import get_dataset 16 | from data.diff_source_2d import generate_dataset 17 | 18 | from encoder.utils import append_dzdt, concat_visible 19 | from symder.sym_models import SymModel, Quadratic, SpatialDerivative2D, rescale_z 20 | from symder.symder import get_symder_apply, get_model_apply 21 | 22 | from utils import loss_fn, init_optimizers, save_pytree # , load_pytree 23 | 24 | 25 | def get_model(num_visible, num_hidden, num_der, mesh, dx, dt, scale, get_dzdt=False): 26 | 27 | # Define encoder 28 | hidden_size = 64 29 | pad = 2 30 | 31 | def encoder(x): 32 | return hk.Sequential( 33 | [ 34 | lambda x: jnp.pad( 35 | x, ((0, 0), (0, 0), (pad, pad), (pad, pad), (0, 0)), "wrap" 36 | ), 37 | hk.Conv3D(hidden_size, kernel_shape=5, padding="VALID"), 38 | jax.nn.relu, 39 | hk.Conv3D(hidden_size, kernel_shape=1), 40 | jax.nn.relu, 41 | hk.Conv3D(num_hidden, kernel_shape=1), 42 | ] 43 | )(x) 44 | 45 | encoder = hk.without_apply_rng(hk.transform(encoder)) 46 | encoder_apply = append_dzdt(encoder.apply) if get_dzdt else encoder.apply 47 | encoder_apply = concat_visible( 48 | encoder_apply, visible_transform=lambda x: x[:, pad:-pad] 49 | ) 50 | 51 | # Define symbolic model 52 | n_dims = num_visible + num_hidden 53 | scale_vec = jnp.concatenate((scale[:, 0], jnp.ones(num_hidden))) 54 | 55 | @partial(rescale_z, scale_vec=scale_vec) 56 | def sym_model(z, t): 57 | return SymModel( 58 | 1e1 * dt, 59 | ( 60 | SpatialDerivative2D(mesh, np.sqrt(1e1) * dx, init=jnp.zeros), 61 | hk.Linear(n_dims, w_init=jnp.zeros, b_init=jnp.zeros), 62 | Quadratic(n_dims, init=jnp.zeros), 63 | ), 64 | )(z, t) 65 | 66 | sym_model = hk.without_apply_rng(hk.transform(sym_model)) 67 | 68 | # Define SymDer function which automatically computes 69 | # higher order time derivatives of symbolic model 70 | symder_apply = get_symder_apply( 71 | sym_model.apply, 72 | num_der=num_der, 73 | transform=lambda z: z[..., :num_visible], 74 | get_dzdt=get_dzdt, 75 | ) 76 | 77 | # Define full model, combining encoder and symbolic model 78 | model_apply = get_model_apply( 79 | encoder_apply, 80 | symder_apply, 81 | hidden_transform=lambda z: z[..., -num_hidden:], 82 | get_dzdt=get_dzdt, 83 | ) 84 | model_init = {"encoder": encoder.init, "sym_model": sym_model.init} 85 | 86 | return model_apply, model_init, {"pad": pad} 87 | 88 | 89 | def train( 90 | n_steps, 91 | model_apply, 92 | params, 93 | scaled_data, 94 | loss_fn_args={}, 95 | data_args={}, 96 | optimizers={}, 97 | sparse_thres=None, 98 | sparse_interval=None, 99 | key_seq=hk.PRNGSequence(42), 100 | multi_gpu=False, 101 | ): 102 | 103 | # JIT compile/PMAP gradient function 104 | loss_fn_apply = partial(loss_fn, model_apply, **loss_fn_args) 105 | if multi_gpu: 106 | # Take mean of 
gradient across multiple devices 107 | def grad_loss(params, batch, target): 108 | grad_out = jax.grad(loss_fn_apply, has_aux=True)(params, batch, target) 109 | return lax.pmean(grad_out, axis_name="devices") 110 | 111 | grad_loss = jax.pmap(grad_loss, axis_name="devices") 112 | else: 113 | grad_loss = jax.jit(jax.grad(loss_fn_apply, has_aux=True)) 114 | 115 | # Initialize sparse mask 116 | sparsify = sparse_thres is not None and sparse_interval is not None 117 | if multi_gpu: 118 | sparse_mask = jax.tree_map( 119 | jax.pmap(lambda x: jnp.ones_like(x, dtype=bool)), params["sym_model"] 120 | ) 121 | else: 122 | sparse_mask = jax.tree_map( 123 | lambda x: jnp.ones_like(x, dtype=bool), params["sym_model"] 124 | ) 125 | 126 | # Initialize optimizers 127 | update_params, opt_state = init_optimizers(params, optimizers, sparsify, multi_gpu) 128 | if multi_gpu: 129 | update_params = jax.pmap(update_params, axis_name="devices") 130 | else: 131 | update_params = jax.jit(update_params) 132 | 133 | # Get batch and target 134 | # TODO: replace this with call to a data generator/data loader 135 | if multi_gpu: 136 | n_devices = jax.device_count() 137 | pad = data_args["pad"] 138 | time_size = (scaled_data.shape[0] - 2 * pad) // n_devices 139 | batch = [] 140 | target = [] 141 | for i in range(n_devices): 142 | start, end = i * time_size, (i + 1) * time_size + 2 * pad 143 | if loss_fn_args["reg_dzdt"] is not None: 144 | # batch, time, mesh, mesh, num_visible, 2 145 | batch.append(scaled_data[None, start:end, :, :, :, :2]) 146 | else: 147 | # batch, time, mesh, mesh, num_visible 148 | batch.append(scaled_data[None, start:end, :, :, :, 0]) 149 | # batch, time, mesh, mesh, num_visible, num_der 150 | target.append(scaled_data[None, start + pad : end - pad, :, :, :, 1:]) 151 | 152 | batch = jax.device_put_sharded(batch, jax.devices()) 153 | target = jax.device_put_sharded(target, jax.devices()) 154 | 155 | else: 156 | if loss_fn_args["reg_dzdt"] is not None: 157 | # batch, time, mesh, mesh, num_visible, 2 158 | batch = scaled_data[None, :, :, :, :, :2] 159 | else: 160 | # batch, time, mesh, mesh, num_visible 161 | batch = scaled_data[None, :, :, :, :, 0] 162 | pad = data_args["pad"] 163 | # batch, time, mesh, mesh, num_visible, num_der 164 | target = scaled_data[None, pad:-pad, :, :, :, 1:] 165 | 166 | batch = jnp.asarray(batch) 167 | target = jnp.asarray(target) 168 | 169 | # Training loop 170 | if multi_gpu: 171 | print(f"Training for {n_steps} steps on {n_devices} devices...") 172 | else: 173 | print(f"Training for {n_steps} steps...") 174 | 175 | best_loss = np.float("inf") 176 | best_params = None 177 | 178 | def thres_fn(x): 179 | return jnp.abs(x) > sparse_thres 180 | 181 | if multi_gpu: 182 | thres_fn = jax.pmap(thres_fn) 183 | 184 | for step in range(n_steps): 185 | # Compute gradients and losses 186 | grads, loss_list = grad_loss(params, batch, target) 187 | 188 | # Save best params if loss is lower than best_loss 189 | loss = loss_list[0][0] if multi_gpu else loss_list[0] 190 | if loss < best_loss: 191 | best_loss = loss 192 | best_params = jax.tree_map(lambda x: x.copy(), params) 193 | 194 | # Update sparse_mask based on a threshold 195 | if sparsify and step > 0 and step % sparse_interval == 0: 196 | sparse_mask = jax.tree_map(thres_fn, best_params["sym_model"]) 197 | 198 | # Update params based on optimizers 199 | params, opt_state, sparse_mask = update_params( 200 | grads, opt_state, params, sparse_mask 201 | ) 202 | 203 | # Print loss 204 | if step % 100 == 0: 205 | loss, mse, reg_dzdt, 
reg_l1_sparse = loss_list 206 | if multi_gpu: 207 | (loss, mse, reg_dzdt, reg_l1_sparse) = ( 208 | loss[0], 209 | mse[0], 210 | reg_dzdt[0], 211 | reg_l1_sparse[0], 212 | ) 213 | print( 214 | f"Loss[{step}] = {loss}, MSE = {mse}, " 215 | f"Reg. dz/dt = {reg_dzdt}, Reg. L1 Sparse = {reg_l1_sparse}" 216 | ) 217 | if multi_gpu: 218 | print(jax.tree_map(lambda x: x[0], params["sym_model"])) 219 | else: 220 | print(params["sym_model"]) 221 | 222 | if multi_gpu: 223 | best_params = jax.tree_map(lambda x: x[0], best_params) 224 | sparse_mask = jax.tree_map(lambda x: x[0], sparse_mask) 225 | 226 | print("\nBest loss:", best_loss) 227 | print("Best sym_model params:", best_params["sym_model"]) 228 | return best_loss, best_params, sparse_mask 229 | 230 | 231 | if __name__ == "__main__": 232 | 233 | parser = argparse.ArgumentParser( 234 | description="Run SymDer model on 2D diffusion with source data." 235 | ) 236 | parser.add_argument( 237 | "-o", 238 | "--output", 239 | type=str, 240 | default="./diff_source_2d_run0/", 241 | help="Output folder path. Default: ./diff_source_2d_run0/", 242 | ) 243 | parser.add_argument( 244 | "-d", 245 | "--dataset", 246 | type=str, 247 | default="./data/diff_source_2d.npz", 248 | help=( 249 | "Path to 2D diffusion with source dataset (generated and saved " 250 | "if it does not exist). Default: ./data/diff_source_2d.npz" 251 | ), 252 | ) 253 | args = parser.parse_args() 254 | 255 | # Seed random number generator 256 | key_seq = hk.PRNGSequence(42) 257 | 258 | # Set SymDer parameters 259 | num_visible = 1 260 | num_hidden = 1 261 | num_der = 2 262 | 263 | # Set dataset parameters and load/generate dataset 264 | sys_size = 64 265 | mesh = 64 266 | dt = 5e-2 267 | tspan = (0, 50 + 2 * dt) 268 | scaled_data, scale, raw_sol = get_dataset( 269 | args.dataset, 270 | generate_dataset, 271 | get_raw_sol=True, 272 | sys_size=sys_size, 273 | mesh=mesh, 274 | dt=dt, 275 | tspan=tspan, 276 | num_der=num_der, 277 | ) 278 | 279 | # Set training hyperparameters 280 | n_steps = 50000 281 | sparse_thres = 5e-3 282 | sparse_interval = 1000 283 | multi_gpu = True 284 | 285 | # Define optimizers 286 | optimizers = { 287 | "encoder": optax.adabelief(1e-4, eps=1e-16), 288 | "sym_model": optax.adabelief(1e-4, eps=1e-16), 289 | } 290 | 291 | # Set loss function hyperparameters 292 | loss_fn_args = { 293 | "scale": jnp.array(scale), 294 | "deriv_weight": jnp.array([1.0, 10.0]), 295 | "reg_dzdt": 0, 296 | "reg_l1_sparse": 0, 297 | } 298 | get_dzdt = loss_fn_args["reg_dzdt"] is not None 299 | 300 | # Check dataset shapes 301 | assert scaled_data.shape[-2] == num_visible 302 | assert scaled_data.shape[-1] == num_der + 1 303 | assert scale.shape[0] == num_visible 304 | assert scale.shape[1] == num_der + 1 305 | 306 | # Define model 307 | model_apply, model_init, model_args = get_model( 308 | num_visible, 309 | num_hidden, 310 | num_der, 311 | mesh, 312 | sys_size / mesh, 313 | dt, 314 | scale, 315 | get_dzdt=get_dzdt, 316 | ) 317 | 318 | # Initialize parameters 319 | params = {} 320 | params["encoder"] = model_init["encoder"]( 321 | next(key_seq), jnp.ones([1, scaled_data.shape[0], mesh, mesh, num_visible]) 322 | ) 323 | params["sym_model"] = model_init["sym_model"]( 324 | next(key_seq), jnp.ones([1, 1, mesh, mesh, num_visible + num_hidden]), 0.0 325 | ) 326 | if multi_gpu: 327 | for name in params.keys(): 328 | params[name] = jax.device_put_replicated(params[name], jax.devices()) 329 | 330 | # Train 331 | best_loss, best_params, sparse_mask = train( 332 | n_steps, 333 | model_apply, 334 | 
params, 335 | scaled_data, 336 | loss_fn_args=loss_fn_args, 337 | data_args={"pad": model_args["pad"]}, 338 | optimizers=optimizers, 339 | sparse_thres=sparse_thres, 340 | sparse_interval=sparse_interval, 341 | key_seq=key_seq, 342 | multi_gpu=multi_gpu, 343 | ) 344 | 345 | # Save model parameters and sparse mask 346 | print(f"Saving best model parameters in output folder: {args.output}") 347 | save_pytree( 348 | os.path.join(args.output, "best.pt"), 349 | {"params": best_params, "sparse_mask": sparse_mask}, 350 | ) 351 | -------------------------------------------------------------------------------- /encoder/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["embedding", "utils"] 2 | from . import * 3 | -------------------------------------------------------------------------------- /encoder/embedding.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import haiku as hk 3 | 4 | __all__ = ["DirectEmbedding"] 5 | 6 | 7 | class DirectEmbedding(hk.Module): 8 | def __init__(self, shape, init=jnp.zeros, concat_visible=False, get_dzdt=False): 9 | super().__init__() 10 | self.shape = shape 11 | self.init = init 12 | self.concat_visible = concat_visible 13 | self.get_dzdt = get_dzdt 14 | 15 | def __call__(self, x, *args): 16 | z_hidden = hk.get_parameter( 17 | "z_hidden", (x.shape[0], x.shape[1], *self.shape), init=self.init 18 | ) 19 | 20 | z = jnp.concatenate((x, z_hidden), axis=-1) if self.concat_visible else z_hidden 21 | 22 | if self.get_dzdt: 23 | dz = jnp.diff(z, axis=1) 24 | dz = jnp.concatenate((dz, dz[:, [-1]]), axis=1) 25 | return z, dz 26 | else: 27 | return z 28 | # return z if not self.get_dzdt else (z, jnp.gradient(z_hidden, axis=1)) 29 | 30 | -------------------------------------------------------------------------------- /encoder/utils.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | __all__ = ['append_dzdt', 'concat_visible', 'normalize_by_magnitude', 'to_complex'] 5 | 6 | 7 | def append_dzdt(encoder_apply, finite_difference=False): 8 | if finite_difference: 9 | def encoder_with_dzdt(params, x, *args, **kwargs): 10 | z_hidden = encoder_apply(params, x, *args, **kwargs) 11 | dz_hidden = jnp.diff(z_hidden, axis=1) 12 | dz_hidden = jnp.concatenate((dz_hidden, dz_hidden[:, [-1]]), axis=1) 13 | return z_hidden, dz_hidden # jnp.gradient(z_hidden, axis=1) 14 | else: 15 | def encoder_with_dzdt(params, x, dxdt): 16 | zero_params = jax.tree_map(jnp.zeros_like, params) 17 | 18 | # dz/dt = dz/dx * dx/dt 19 | z_hidden, dzdt_hidden = jax.jvp( 20 | encoder_apply, (params, x), (zero_params, dxdt)) 21 | return z_hidden, dzdt_hidden 22 | 23 | return encoder_with_dzdt 24 | 25 | 26 | def concat_visible(encoder_apply, visible_transform=None): 27 | def encoder_concat_visible(params, x, *args, **kwargs): 28 | z_visible = (visible_transform(x) 29 | if visible_transform is not None else x) 30 | out = encoder_apply(params, x, *args, **kwargs) 31 | if isinstance(out, list) or isinstance(out, tuple): 32 | z_hidden, *out_args = out 33 | else: 34 | z_hidden = out 35 | out_args = () 36 | z = jnp.concatenate((z_visible, z_hidden), axis=-1) 37 | return z, *out_args 38 | return encoder_concat_visible 39 | 40 | 41 | def normalize_by_magnitude(encoder_apply, pad=None, squared=False): 42 | def encoder_normalized(params, x, *args, **kwargs): 43 | z_phase = encoder_apply(params, x, *args, **kwargs) 44 | 
z_phase = z_phase/jnp.linalg.norm(z_phase, axis=-1, keepdims=True) 45 | 46 | if pad is not None and pad > 0: 47 | x = x[:, pad:-pad] 48 | z_mag = jnp.sqrt(x) if squared else x 49 | 50 | return z_mag * z_phase 51 | return encoder_normalized 52 | 53 | 54 | def to_complex(encoder_apply): 55 | def encoder_complex(params, x, *args, **kwargs): 56 | out = encoder_apply(params, x, *args, **kwargs) 57 | if isinstance(out, list) or isinstance(out, tuple): 58 | z, *out_args = out 59 | else: 60 | z = out 61 | out_args = () 62 | return z[..., [0]] + 1j * z[..., [1]], *out_args 63 | return encoder_complex 64 | -------------------------------------------------------------------------------- /lorenz_model.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | 5 | import haiku as hk 6 | import optax 7 | 8 | import os.path 9 | import argparse 10 | from functools import partial 11 | 12 | # from tqdm.auto import tqdm 13 | 14 | from data.utils import get_dataset 15 | from data.lorenz import generate_dataset 16 | 17 | from encoder.utils import append_dzdt, concat_visible 18 | from symder.sym_models import SymModel, Quadratic, rescale_z 19 | from symder.symder import get_symder_apply, get_model_apply 20 | 21 | from utils import loss_fn, init_optimizers, save_pytree # , load_pytree 22 | 23 | 24 | def get_model(num_visible, num_hidden, num_der, dt, scale, get_dzdt=False): 25 | 26 | # Define encoder 27 | hidden_size = 128 28 | pad = 4 29 | 30 | def encoder(x): 31 | return hk.Sequential( 32 | [ 33 | hk.Conv1D(hidden_size, kernel_shape=9, padding="VALID"), 34 | jax.nn.relu, 35 | hk.Conv1D(hidden_size, kernel_shape=1), 36 | jax.nn.relu, 37 | hk.Conv1D(num_hidden, kernel_shape=1), 38 | ] 39 | )(x) 40 | 41 | encoder = hk.without_apply_rng(hk.transform(encoder)) 42 | encoder_apply = append_dzdt(encoder.apply) if get_dzdt else encoder.apply 43 | encoder_apply = concat_visible( 44 | encoder_apply, visible_transform=lambda x: x[:, pad:-pad] 45 | ) 46 | 47 | # Define symbolic model 48 | n_dims = num_visible + num_hidden 49 | scale_vec = jnp.concatenate((scale[:, 0], jnp.ones(num_hidden))) 50 | 51 | @partial(rescale_z, scale_vec=scale_vec) 52 | def sym_model(z, t): 53 | return SymModel( 54 | 1e2 * dt, 55 | ( 56 | hk.Linear(n_dims, w_init=jnp.zeros, b_init=jnp.zeros), 57 | Quadratic(n_dims, init=jnp.zeros), 58 | ), 59 | )(z, t) 60 | 61 | sym_model = hk.without_apply_rng(hk.transform(sym_model)) 62 | 63 | # Define SymDer function which automatically computes 64 | # higher order time derivatives of symbolic model 65 | symder_apply = get_symder_apply( 66 | sym_model.apply, 67 | num_der=num_der, 68 | transform=lambda z: z[..., :num_visible], 69 | get_dzdt=get_dzdt, 70 | ) 71 | 72 | # Define full model, combining encoder and symbolic model 73 | model_apply = get_model_apply( 74 | encoder_apply, 75 | symder_apply, 76 | hidden_transform=lambda z: z[..., -num_hidden:], 77 | get_dzdt=get_dzdt, 78 | ) 79 | model_init = {"encoder": encoder.init, "sym_model": sym_model.init} 80 | 81 | return model_apply, model_init, {"pad": pad} 82 | 83 | 84 | def train( 85 | n_steps, 86 | model_apply, 87 | params, 88 | scaled_data, 89 | loss_fn_args={}, 90 | data_args={}, 91 | optimizers={}, 92 | sparse_thres=None, 93 | sparse_interval=None, 94 | key_seq=hk.PRNGSequence(42), 95 | ): 96 | 97 | # JIT compile gradient function 98 | loss_fn_apply = partial(loss_fn, model_apply, **loss_fn_args) 99 | grad_loss = jax.jit(jax.grad(loss_fn_apply, has_aux=True)) 100 
| 101 | # Initialize sparse mask 102 | sparsify = sparse_thres is not None and sparse_interval is not None 103 | sparse_mask = jax.tree_map( 104 | lambda x: jnp.ones_like(x, dtype=bool), params["sym_model"] 105 | ) 106 | 107 | # Initialize optimizers 108 | update_params, opt_state = init_optimizers(params, optimizers, sparsify) 109 | update_params = jax.jit(update_params) 110 | 111 | # Get batch and target 112 | # TODO: replace this with call to a data generator/data loader 113 | if loss_fn_args["reg_dzdt"] is not None: 114 | batch = scaled_data[None, :, :, :2] # batch, time, num_visible, 2 115 | else: 116 | batch = scaled_data[None, :, :, 0] # batch, time, num_visible 117 | pad = data_args["pad"] 118 | # batch, time, num_visible, num_der 119 | target = scaled_data[None, pad:-pad, :, 1:] 120 | 121 | batch = jnp.asarray(batch) 122 | target = jnp.asarray(target) 123 | 124 | # Training loop 125 | print(f"Training for {n_steps} steps...") 126 | 127 | best_loss = np.float("inf") 128 | best_params = None 129 | 130 | for step in range(n_steps): 131 | 132 | # Compute gradients and losses 133 | grads, loss_list = grad_loss(params, batch, target) 134 | 135 | # Save best params if loss is lower than best_loss 136 | loss = loss_list[0] 137 | if loss < best_loss: 138 | best_loss = loss 139 | best_params = jax.tree_map(lambda x: x.copy(), params) 140 | 141 | # Update sparse_mask based on a threshold 142 | if sparsify and step > 0 and step % sparse_interval == 0: 143 | sparse_mask = jax.tree_map( 144 | lambda x: jnp.abs(x) > sparse_thres, best_params["sym_model"] 145 | ) 146 | 147 | # Update params based on optimizers 148 | params, opt_state, sparse_mask = update_params( 149 | grads, opt_state, params, sparse_mask 150 | ) 151 | 152 | # Print loss 153 | if step % 1000 == 0: 154 | loss, mse, reg_dzdt, reg_l1_sparse = loss_list 155 | print( 156 | f"Loss[{step}] = {loss}, MSE = {mse}, " 157 | f"Reg. dz/dt = {reg_dzdt}, Reg. L1 Sparse = {reg_l1_sparse}" 158 | ) 159 | print(params["sym_model"]) 160 | 161 | print("\nBest loss:", best_loss) 162 | print("Best sym_model params:", best_params["sym_model"]) 163 | return best_loss, best_params, sparse_mask 164 | 165 | 166 | if __name__ == "__main__": 167 | 168 | parser = argparse.ArgumentParser( 169 | description="Run SymDer model on Lorenz system data." 170 | ) 171 | parser.add_argument( 172 | "-o", 173 | "--output", 174 | type=str, 175 | default="./lorenz_run0/", 176 | help="Output folder path. Default: ./lorenz_run0/", 177 | ) 178 | parser.add_argument( 179 | "-d", 180 | "--dataset", 181 | type=str, 182 | default="./data/lorenz.npz", 183 | help=( 184 | "Path to Lorenz system dataset (generated and saved " 185 | "if it does not exist). Default: ./data/lorenz.npz" 186 | ), 187 | ) 188 | parser.add_argument( 189 | "-v", 190 | "--visible", 191 | type=int, 192 | nargs="+", 193 | default=[0, 1], 194 | help="List of visible variables (0, 1, and/or 2). 
Default: 0 1", 195 | ) 196 | args = parser.parse_args() 197 | 198 | # Seed random number generator 199 | key_seq = hk.PRNGSequence(42) 200 | 201 | # Set SymDer parameters 202 | num_visible = len(args.visible) # 2 203 | num_hidden = 3 - num_visible # 1 204 | num_der = 2 205 | 206 | # Set dataset parameters and load/generate dataset 207 | dt = 1e-2 208 | tmax = 100 + 2 * dt 209 | scaled_data, scale, raw_sol = get_dataset( 210 | args.dataset, 211 | generate_dataset, 212 | get_raw_sol=True, 213 | dt=dt, 214 | tmax=tmax, 215 | num_visible=num_visible, 216 | visible_vars=args.visible, 217 | num_der=num_der, 218 | ) 219 | 220 | # Set training hyperparameters 221 | n_steps = 50000 222 | sparse_thres = 1e-3 # 5e-3 223 | sparse_interval = 5000 224 | 225 | # Define optimizers 226 | optimizers = { 227 | "encoder": optax.adabelief(1e-3, eps=1e-16), 228 | "sym_model": optax.adabelief(1e-3, eps=1e-16), 229 | } 230 | 231 | # Set loss function hyperparameters 232 | loss_fn_args = { 233 | "scale": jnp.array(scale), 234 | "deriv_weight": jnp.array([1.0, 1.0]), 235 | "reg_dzdt": 0, 236 | "reg_l1_sparse": 0, 237 | } 238 | get_dzdt = loss_fn_args["reg_dzdt"] is not None 239 | 240 | # Check dataset shapes 241 | assert scaled_data.shape[-2] == num_visible 242 | assert scaled_data.shape[-1] == num_der + 1 243 | assert scale.shape[0] == num_visible 244 | assert scale.shape[1] == num_der + 1 245 | 246 | # Define model 247 | model_apply, model_init, model_args = get_model( 248 | num_visible, num_hidden, num_der, dt, scale, get_dzdt=get_dzdt 249 | ) 250 | 251 | # Initialize parameters 252 | params = {} 253 | params["encoder"] = model_init["encoder"]( 254 | next(key_seq), jnp.ones([1, scaled_data.shape[0], num_visible]) 255 | ) 256 | params["sym_model"] = model_init["sym_model"]( 257 | next(key_seq), jnp.ones([1, 1, num_visible + num_hidden]), 0.0 258 | ) 259 | 260 | # Train 261 | best_loss, best_params, sparse_mask = train( 262 | n_steps, 263 | model_apply, 264 | params, 265 | scaled_data, 266 | loss_fn_args=loss_fn_args, 267 | data_args={"pad": model_args["pad"]}, 268 | optimizers=optimizers, 269 | sparse_thres=sparse_thres, 270 | sparse_interval=sparse_interval, 271 | key_seq=key_seq, 272 | ) 273 | 274 | # Save model parameters and sparse mask 275 | print(f"Saving best model parameters in output folder: {args.output}") 276 | save_pytree( 277 | os.path.join(args.output, "best.pt"), 278 | {"params": best_params, "sparse_mask": sparse_mask}, 279 | ) 280 | -------------------------------------------------------------------------------- /lorenz_model_extrahidden.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | 5 | import haiku as hk 6 | import optax 7 | 8 | import os.path 9 | import argparse 10 | from functools import partial 11 | 12 | # from tqdm.auto import tqdm 13 | 14 | from data.utils import get_dataset 15 | from data.lorenz import generate_dataset 16 | 17 | from encoder.utils import append_dzdt, concat_visible 18 | from symder.sym_models import SymModel, Quadratic, rescale_z 19 | from symder.symder import get_symder_apply, get_model_apply 20 | 21 | from utils import loss_fn, init_optimizers, save_pytree # , load_pytree 22 | 23 | 24 | def get_model(num_visible, num_hidden, num_der, dt, scale, get_dzdt=False): 25 | 26 | # Define encoder 27 | hidden_size = 128 28 | pad = 4 29 | 30 | def encoder(x): 31 | return hk.Sequential( 32 | [ 33 | hk.Conv1D(hidden_size, kernel_shape=9, padding="VALID"), 34 | jax.nn.relu, 35 
| hk.Conv1D(hidden_size, kernel_shape=1), 36 | jax.nn.relu, 37 | hk.Conv1D(num_hidden, kernel_shape=1), 38 | ] 39 | )(x) 40 | 41 | encoder = hk.without_apply_rng(hk.transform(encoder)) 42 | encoder_apply = append_dzdt(encoder.apply) if get_dzdt else encoder.apply 43 | encoder_apply = concat_visible( 44 | encoder_apply, visible_transform=lambda x: x[:, pad:-pad] 45 | ) 46 | 47 | # Define symbolic model 48 | n_dims = num_visible + num_hidden 49 | scale_vec = jnp.concatenate((scale[:, 0], jnp.ones(num_hidden))) 50 | 51 | @partial(rescale_z, scale_vec=scale_vec) 52 | def sym_model(z, t): 53 | return SymModel( 54 | 1e2 * dt, 55 | ( 56 | hk.Linear(n_dims, w_init=jnp.zeros, b_init=jnp.zeros), 57 | Quadratic(n_dims, init=jnp.zeros), 58 | ), 59 | )(z, t) 60 | 61 | sym_model = hk.without_apply_rng(hk.transform(sym_model)) 62 | 63 | # Define SymDer function which automatically computes 64 | # higher order time derivatives of symbolic model 65 | symder_apply = get_symder_apply( 66 | sym_model.apply, 67 | num_der=num_der, 68 | transform=lambda z: z[..., :num_visible], 69 | get_dzdt=get_dzdt, 70 | ) 71 | 72 | # Define full model, combining encoder and symbolic model 73 | model_apply = get_model_apply( 74 | encoder_apply, 75 | symder_apply, 76 | hidden_transform=lambda z: z[..., -num_hidden:], 77 | get_dzdt=get_dzdt, 78 | ) 79 | model_init = {"encoder": encoder.init, "sym_model": sym_model.init} 80 | 81 | return model_apply, model_init, {"pad": pad} 82 | 83 | 84 | def train( 85 | n_steps, 86 | model_apply, 87 | params, 88 | scaled_data, 89 | loss_fn_args={}, 90 | data_args={}, 91 | optimizers={}, 92 | sparse_thres=None, 93 | sparse_interval=None, 94 | key_seq=hk.PRNGSequence(42), 95 | ): 96 | 97 | # JIT compile gradient function 98 | loss_fn_apply = partial(loss_fn, model_apply, **loss_fn_args) 99 | grad_loss = jax.jit(jax.grad(loss_fn_apply, has_aux=True)) 100 | 101 | # Initialize sparse mask 102 | sparsify = sparse_thres is not None and sparse_interval is not None 103 | sparse_mask = jax.tree_map( 104 | lambda x: jnp.ones_like(x, dtype=bool), params["sym_model"] 105 | ) 106 | 107 | # # Remove hidden state 108 | # flat_mask, tree = jax.tree_flatten(sparse_mask) 109 | # flat_mask[0] = jnp.array([True, True, False]) 110 | # flat_mask[1] = jnp.array( 111 | # [[True, True, False], [True, True, False], [False, False, False]] 112 | # ) 113 | # flat_mask[2] = jnp.array( 114 | # [ 115 | # [[True, True, False], [False, True, False], [False, False, False]], 116 | # [[True, True, False], [False, True, False], [False, False, False]], 117 | # [[False, False, False], [False, False, False], [False, False, False]], 118 | # ] 119 | # ) 120 | # sparse_mask = jax.tree_unflatten(tree, flat_mask) 121 | 122 | # Initialize optimizers 123 | update_params, opt_state = init_optimizers(params, optimizers, sparsify) 124 | update_params = jax.jit(update_params) 125 | 126 | # Get batch and target 127 | # TODO: replace this with call to a data generator/data loader 128 | if loss_fn_args["reg_dzdt"] is not None: 129 | batch = scaled_data[None, :, :, :2] # batch, time, num_visible, 2 130 | else: 131 | batch = scaled_data[None, :, :, 0] # batch, time, num_visible 132 | pad = data_args["pad"] 133 | # batch, time, num_visible, num_der 134 | target = scaled_data[None, pad:-pad, :, 1:] 135 | 136 | batch = jnp.asarray(batch) 137 | target = jnp.asarray(target) 138 | 139 | # Training loop 140 | print(f"Training for {n_steps} steps...") 141 | 142 | best_loss = np.float("inf") 143 | best_params = None 144 | 145 | for step in range(n_steps): 146 
| 147 | # Compute gradients and losses 148 | grads, loss_list = grad_loss(params, batch, target) 149 | 150 | # Save best params if loss is lower than best_loss 151 | loss = loss_list[0] 152 | if loss < best_loss: 153 | best_loss = loss 154 | best_params = jax.tree_map(lambda x: x.copy(), params) 155 | 156 | # Update sparse_mask based on a threshold 157 | if sparsify and step > 0 and step % sparse_interval == 0: 158 | sparse_mask = jax.tree_map( 159 | lambda x: jnp.abs(x) > sparse_thres, best_params["sym_model"] 160 | ) 161 | 162 | # Update params based on optimizers 163 | params, opt_state, sparse_mask = update_params( 164 | grads, opt_state, params, sparse_mask 165 | ) 166 | 167 | # Print loss 168 | if step % 1000 == 0: 169 | loss, mse, reg_dzdt, reg_l1_sparse = loss_list 170 | print( 171 | f"Loss[{step}] = {loss}, MSE = {mse}, " 172 | f"Reg. dz/dt = {reg_dzdt}, Reg. L1 Sparse = {reg_l1_sparse}" 173 | ) 174 | print(params["sym_model"]) 175 | 176 | print("\nBest loss:", best_loss) 177 | print("Best sym_model params:", best_params["sym_model"]) 178 | return best_loss, best_params, sparse_mask 179 | 180 | 181 | if __name__ == "__main__": 182 | 183 | parser = argparse.ArgumentParser( 184 | description="Run SymDer model on Lorenz system data." 185 | ) 186 | parser.add_argument( 187 | "-o", 188 | "--output", 189 | type=str, 190 | default="./lorenz_run0/", 191 | help="Output folder path. Default: ./lorenz_run0/", 192 | ) 193 | parser.add_argument( 194 | "-d", 195 | "--dataset", 196 | type=str, 197 | default="./data/lorenz.npz", 198 | help=( 199 | "Path to Lorenz system dataset (generated and saved " 200 | "if it does not exist). Default: ./data/lorenz.npz" 201 | ), 202 | ) 203 | parser.add_argument( 204 | "-v", 205 | "--visible", 206 | type=int, 207 | nargs="+", 208 | default=[0, 1], 209 | help="List of visible variables (0, 1, and/or 2). Default: 0 1", 210 | ) 211 | parser.add_argument( 212 | "-e", 213 | "--extrahidden", 214 | type=int, 215 | help="Number of extra hidden variables. 
Default: 0", 216 | ) 217 | args = parser.parse_args() 218 | 219 | # Seed random number generator 220 | key_seq = hk.PRNGSequence(42) 221 | 222 | # Set SymDer parameters 223 | num_visible = len(args.visible) # 2 224 | num_hidden = 3 - num_visible + args.extrahidden 225 | num_der = 2 226 | 227 | # Set dataset parameters and load/generate dataset 228 | dt = 1e-2 229 | tmax = 100 + 2 * dt 230 | scaled_data, scale, raw_sol = get_dataset( 231 | args.dataset, 232 | generate_dataset, 233 | get_raw_sol=True, 234 | dt=dt, 235 | tmax=tmax, 236 | num_visible=num_visible, 237 | visible_vars=args.visible, 238 | num_der=num_der, 239 | ) 240 | 241 | # Set training hyperparameters 242 | n_steps = 50000 243 | sparse_thres = 1e-3 # 5e-3 244 | sparse_interval = 5000 245 | 246 | # Define optimizers 247 | optimizers = { 248 | "encoder": optax.adabelief(1e-3, eps=1e-16), 249 | "sym_model": optax.adabelief(1e-3, eps=1e-16), 250 | } 251 | 252 | # Set loss function hyperparameters 253 | loss_fn_args = { 254 | "scale": jnp.array(scale), 255 | "deriv_weight": jnp.array([1.0, 1.0]), 256 | "reg_dzdt": 0, 257 | "reg_l1_sparse": 0, 258 | } 259 | get_dzdt = loss_fn_args["reg_dzdt"] is not None 260 | 261 | # Check dataset shapes 262 | assert scaled_data.shape[-2] == num_visible 263 | assert scaled_data.shape[-1] == num_der + 1 264 | assert scale.shape[0] == num_visible 265 | assert scale.shape[1] == num_der + 1 266 | 267 | # Define model 268 | model_apply, model_init, model_args = get_model( 269 | num_visible, num_hidden, num_der, dt, scale, get_dzdt=get_dzdt 270 | ) 271 | 272 | # Initialize parameters 273 | params = {} 274 | params["encoder"] = model_init["encoder"]( 275 | next(key_seq), jnp.ones([1, scaled_data.shape[0], num_visible]) 276 | ) 277 | params["sym_model"] = model_init["sym_model"]( 278 | next(key_seq), jnp.ones([1, 1, num_visible + num_hidden]), 0.0 279 | ) 280 | 281 | # Train 282 | best_loss, best_params, sparse_mask = train( 283 | n_steps, 284 | model_apply, 285 | params, 286 | scaled_data, 287 | loss_fn_args=loss_fn_args, 288 | data_args={"pad": model_args["pad"]}, 289 | optimizers=optimizers, 290 | sparse_thres=sparse_thres, 291 | sparse_interval=sparse_interval, 292 | key_seq=key_seq, 293 | ) 294 | 295 | # Save model parameters and sparse mask 296 | print(f"Saving best model parameters in output folder: {args.output}") 297 | save_pytree( 298 | os.path.join(args.output, "best.pt"), 299 | {"params": best_params, "sparse_mask": sparse_mask}, 300 | ) 301 | -------------------------------------------------------------------------------- /lorenz_model_noise.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | 5 | import haiku as hk 6 | import optax 7 | 8 | import os.path 9 | import argparse 10 | from functools import partial 11 | 12 | # from tqdm.auto import tqdm 13 | 14 | from data.utils import get_dataset 15 | from data.lorenz_noise import generate_dataset 16 | 17 | from encoder.utils import append_dzdt, concat_visible 18 | from symder.sym_models import SymModel, Quadratic, rescale_z 19 | from symder.symder import get_symder_apply, get_model_apply 20 | 21 | from utils import loss_fn, init_optimizers, save_pytree # , load_pytree 22 | 23 | 24 | def get_model(num_visible, num_hidden, num_der, dt, scale, get_dzdt=False): 25 | 26 | # Define encoder 27 | hidden_size = 128 28 | pad = 4 29 | 30 | def encoder(x): 31 | return hk.Sequential( 32 | [ 33 | hk.Conv1D(hidden_size, kernel_shape=9, padding="VALID"), 34 | 
jax.nn.relu, 35 | hk.Conv1D(hidden_size, kernel_shape=1), 36 | jax.nn.relu, 37 | hk.Conv1D(num_hidden, kernel_shape=1), 38 | ] 39 | )(x) 40 | 41 | encoder = hk.without_apply_rng(hk.transform(encoder)) 42 | encoder_apply = append_dzdt(encoder.apply) if get_dzdt else encoder.apply 43 | encoder_apply = concat_visible( 44 | encoder_apply, visible_transform=lambda x: x[:, pad:-pad] 45 | ) 46 | 47 | # Define symbolic model 48 | n_dims = num_visible + num_hidden 49 | scale_vec = jnp.concatenate((scale[:, 0], jnp.ones(num_hidden))) 50 | 51 | @partial(rescale_z, scale_vec=scale_vec) 52 | def sym_model(z, t): 53 | return SymModel( 54 | 1e2 * dt, 55 | ( 56 | hk.Linear(n_dims, w_init=jnp.zeros, b_init=jnp.zeros), 57 | Quadratic(n_dims, init=jnp.zeros), 58 | ), 59 | )(z, t) 60 | 61 | sym_model = hk.without_apply_rng(hk.transform(sym_model)) 62 | 63 | # Define SymDer function which automatically computes 64 | # higher order time derivatives of symbolic model 65 | symder_apply = get_symder_apply( 66 | sym_model.apply, 67 | num_der=num_der, 68 | transform=lambda z: z[..., :num_visible], 69 | get_dzdt=get_dzdt, 70 | ) 71 | 72 | # Define full model, combining encoder and symbolic model 73 | model_apply = get_model_apply( 74 | encoder_apply, 75 | symder_apply, 76 | hidden_transform=lambda z: z[..., -num_hidden:], 77 | get_dzdt=get_dzdt, 78 | ) 79 | model_init = {"encoder": encoder.init, "sym_model": sym_model.init} 80 | 81 | return model_apply, model_init, {"pad": pad} 82 | 83 | 84 | def train( 85 | n_steps, 86 | model_apply, 87 | params, 88 | scaled_data, 89 | loss_fn_args={}, 90 | data_args={}, 91 | optimizers={}, 92 | sparse_thres=None, 93 | sparse_interval=None, 94 | key_seq=hk.PRNGSequence(42), 95 | ): 96 | 97 | # JIT compile gradient function 98 | loss_fn_apply = partial(loss_fn, model_apply, **loss_fn_args) 99 | grad_loss = jax.jit(jax.grad(loss_fn_apply, has_aux=True)) 100 | 101 | # Initialize sparse mask 102 | sparsify = sparse_thres is not None and sparse_interval is not None 103 | sparse_mask = jax.tree_map( 104 | lambda x: jnp.ones_like(x, dtype=bool), params["sym_model"] 105 | ) 106 | 107 | # # TEMPORARY: Init mask 108 | # flat_mask, tree = jax.tree_flatten(sparse_mask) 109 | # flat_mask[0] = jnp.array([False, False, True]) 110 | # flat_mask[1] = jnp.array( 111 | # [[True, True, False], [True, True, False], [False, False, True]] 112 | # ) 113 | # flat_mask[2] = jnp.array( 114 | # [ 115 | # [[False, False, False], [False, False, False], [False, False, False]], 116 | # [[False, False, True], [False, False, False], [False, False, False]], 117 | # [[False, True, False], [False, False, False], [False, False, False]], 118 | # ] 119 | # ) 120 | # sparse_mask = jax.tree_unflatten(tree, flat_mask) 121 | 122 | # Initialize optimizers 123 | update_params, opt_state = init_optimizers(params, optimizers, sparsify) 124 | update_params = jax.jit(update_params) 125 | 126 | # Get batch and target 127 | # TODO: replace this with call to a data generator/data loader 128 | if loss_fn_args["reg_dzdt"] is not None: 129 | batch = scaled_data[None, :, :, :2] # batch, time, num_visible, 2 130 | else: 131 | batch = scaled_data[None, :, :, 0] # batch, time, num_visible 132 | pad = data_args["pad"] 133 | # batch, time, num_visible, num_der 134 | target = scaled_data[None, pad:-pad, :, 1:] 135 | 136 | batch = jnp.asarray(batch) 137 | target = jnp.asarray(target) 138 | 139 | # Training loop 140 | print(f"Training for {n_steps} steps...") 141 | 142 | best_loss = np.float("inf") 143 | best_params = None 144 | 145 | for step in 
range(n_steps): 146 | 147 | # Compute gradients and losses 148 | grads, loss_list = grad_loss(params, batch, target) 149 | 150 | # Save best params if loss is lower than best_loss 151 | loss = loss_list[0] 152 | if loss < best_loss: 153 | best_loss = loss 154 | best_params = jax.tree_map(lambda x: x.copy(), params) 155 | 156 | # Update sparse_mask based on a threshold 157 | if sparsify and step > 0 and step % sparse_interval == 0: 158 | sparse_mask = jax.tree_map( 159 | lambda x: jnp.abs(x) > sparse_thres, best_params["sym_model"] 160 | ) 161 | 162 | # Update params based on optimizers 163 | params, opt_state, sparse_mask = update_params( 164 | grads, opt_state, params, sparse_mask 165 | ) 166 | 167 | # Print loss 168 | if step % 1000 == 0: 169 | loss, mse, reg_dzdt, reg_l1_sparse = loss_list 170 | print( 171 | f"Loss[{step}] = {loss}, MSE = {mse}, " 172 | f"Reg. dz/dt = {reg_dzdt}, Reg. L1 Sparse = {reg_l1_sparse}" 173 | ) 174 | print(params["sym_model"]) 175 | 176 | print("\nBest loss:", best_loss) 177 | print("Best sym_model params:", best_params["sym_model"]) 178 | return best_loss, best_params, sparse_mask 179 | 180 | 181 | if __name__ == "__main__": 182 | 183 | parser = argparse.ArgumentParser( 184 | description="Run SymDer model on Lorenz system data with added noise." 185 | ) 186 | parser.add_argument( 187 | "-o", 188 | "--output", 189 | type=str, 190 | default="./lorenz_noise_run0/", 191 | help="Output folder path. Default: ./lorenz_noise_run0/", 192 | ) 193 | parser.add_argument( 194 | "-d", 195 | "--dataset", 196 | type=str, 197 | default="./data/lorenz_noise.npz", 198 | help=( 199 | "Path to Lorenz system dataset (generated and saved " 200 | "if it does not exist). Default: ./data/lorenz_noise.npz" 201 | ), 202 | ) 203 | parser.add_argument( 204 | "-v", 205 | "--visible", 206 | type=int, 207 | nargs="+", 208 | default=[0, 1], 209 | help="List of visible variables (0, 1, and/or 2). Default: 0 1", 210 | ) 211 | parser.add_argument( 212 | "--noise", 213 | type=float, 214 | default=0.0, 215 | help="Gaussian noise standard deviation. Default: 0.0", 216 | ) 217 | parser.add_argument( 218 | "--smooth", 219 | type=int, 220 | nargs="+", 221 | default=None, 222 | help="Smoothing window size, polynomial order. 
Default: None", 223 | ) 224 | args = parser.parse_args() 225 | 226 | # Seed random number generator 227 | key_seq = hk.PRNGSequence(42) 228 | 229 | # Set SymDer parameters 230 | num_visible = len(args.visible) # 2 231 | num_hidden = 3 - num_visible # 1 232 | num_der = 2 233 | 234 | # Set dataset parameters and load/generate dataset 235 | dt = 1e-2 236 | tmax = 100 + 2 * dt 237 | scaled_data, scale, raw_sol = get_dataset( 238 | args.dataset, 239 | generate_dataset, 240 | get_raw_sol=True, 241 | dt=dt, 242 | tmax=tmax, 243 | num_visible=num_visible, 244 | visible_vars=args.visible, 245 | num_der=num_der, 246 | noise=args.noise, 247 | smoothing_params=args.smooth, 248 | ) 249 | 250 | # Set training hyperparameters 251 | n_steps = 50000 252 | sparse_thres = 1e-3 253 | sparse_interval = 5000 254 | 255 | # Define optimizers 256 | optimizers = { 257 | "encoder": optax.adabelief(1e-3, eps=1e-16), 258 | "sym_model": optax.adabelief(1e-3, eps=1e-16), 259 | } 260 | 261 | # Set loss function hyperparameters 262 | loss_fn_args = { 263 | "scale": jnp.array(scale), 264 | "deriv_weight": jnp.array([1.0, 1.0]), 265 | "reg_dzdt": 0, 266 | "reg_l1_sparse": 0, 267 | } 268 | get_dzdt = loss_fn_args["reg_dzdt"] is not None 269 | 270 | # Check dataset shapes 271 | assert scaled_data.shape[-2] == num_visible 272 | assert scaled_data.shape[-1] == num_der + 1 273 | assert scale.shape[0] == num_visible 274 | assert scale.shape[1] == num_der + 1 275 | 276 | # Define model 277 | model_apply, model_init, model_args = get_model( 278 | num_visible, num_hidden, num_der, dt, scale, get_dzdt=get_dzdt 279 | ) 280 | 281 | # Initialize parameters 282 | params = {} 283 | params["encoder"] = model_init["encoder"]( 284 | next(key_seq), jnp.ones([1, scaled_data.shape[0], num_visible]) 285 | ) 286 | params["sym_model"] = model_init["sym_model"]( 287 | next(key_seq), jnp.ones([1, 1, num_visible + num_hidden]), 0.0 288 | ) 289 | 290 | # Train 291 | best_loss, best_params, sparse_mask = train( 292 | n_steps, 293 | model_apply, 294 | params, 295 | scaled_data, 296 | loss_fn_args=loss_fn_args, 297 | data_args={"pad": model_args["pad"]}, 298 | optimizers=optimizers, 299 | sparse_thres=sparse_thres, 300 | sparse_interval=sparse_interval, 301 | key_seq=key_seq, 302 | ) 303 | 304 | # Save model parameters and sparse mask 305 | print(f"Saving best model parameters in output folder: {args.output}") 306 | save_pytree( 307 | os.path.join(args.output, "best.pt"), 308 | {"params": best_params, "sparse_mask": sparse_mask}, 309 | ) 310 | -------------------------------------------------------------------------------- /nlse_1d_model_embed.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from jax import lax 4 | import numpy as np 5 | 6 | import haiku as hk 7 | import optax 8 | 9 | import os.path 10 | import argparse 11 | from functools import partial 12 | 13 | # from tqdm.auto import tqdm 14 | 15 | from data.utils import get_dataset 16 | from data.nlse_1d import generate_dataset 17 | 18 | from encoder.embedding import DirectEmbedding 19 | from encoder.utils import append_dzdt, normalize_by_magnitude, to_complex 20 | from symder.sym_models import ( 21 | SymModel, 22 | PointwisePolynomial, 23 | SpatialDerivative1D, 24 | rescale_z, 25 | ) 26 | from symder.symder import get_symder_apply, get_model_apply 27 | 28 | from utils import loss_fn_weighted, init_optimizers, save_pytree # , load_pytree 29 | 30 | 31 | def get_model( 32 | num_visible, num_hidden, num_der, mesh, dx, 
dt, scale, squared=False, get_dzdt=False 33 | ): 34 | # Define encoder 35 | pad = None 36 | 37 | def encoder(x, *args): 38 | return DirectEmbedding( 39 | (mesh, num_hidden), 40 | init=lambda *x: jnp.ones(*x) + 0.1 * hk.initializers.RandomNormal()(*x), 41 | )(x, *args) 42 | 43 | encoder = hk.without_apply_rng(hk.transform(encoder)) 44 | encoder_apply = normalize_by_magnitude(encoder.apply, pad=pad, squared=squared) 45 | if get_dzdt: 46 | encoder_apply = append_dzdt(encoder_apply, finite_difference=True) 47 | encoder_apply = to_complex(encoder_apply) 48 | 49 | # Define symbolic model 50 | scale_vec = jnp.sqrt(scale[:, 0]) if squared else jnp.asarray(scale[:, 0]) 51 | 52 | @partial(rescale_z, scale_vec=scale_vec) 53 | def sym_model(z, t): 54 | return SymModel( 55 | -1j * (1e1 * dt), 56 | ( 57 | SpatialDerivative1D( 58 | mesh, dx, deriv_orders=(1, 2, 3, 4), init=jnp.zeros 59 | ), 60 | lambda u: u 61 | * PointwisePolynomial(poly_terms=(2, 4, 6, 8), init=jnp.zeros)( 62 | jnp.abs(u) 63 | ), 64 | ), 65 | )(z, t) 66 | 67 | sym_model = hk.without_apply_rng(hk.transform(sym_model)) 68 | 69 | # Define SymDer function which automatically computes 70 | # higher order time derivatives of symbolic model 71 | symder_apply = get_symder_apply( 72 | sym_model.apply, 73 | num_der=num_der, 74 | transform=(lambda z: jnp.abs(z) ** 2) if squared else jnp.abs, 75 | get_dzdt=get_dzdt, 76 | ) 77 | 78 | # Define full model, combining encoder and symbolic model 79 | model_apply = get_model_apply( 80 | encoder_apply, 81 | symder_apply, 82 | hidden_transform=lambda z: jnp.concatenate((jnp.real(z), jnp.imag(z)), axis=-1), 83 | get_dzdt=get_dzdt, 84 | ) 85 | model_init = {"encoder": encoder.init, "sym_model": sym_model.init} 86 | 87 | return model_apply, model_init, {"pad": pad} 88 | 89 | 90 | def train( 91 | n_steps, 92 | model_apply, 93 | params, 94 | scaled_data, 95 | loss_fn_args={}, 96 | data_args={}, 97 | optimizers={}, 98 | sparse_thres=None, 99 | sparse_interval=None, 100 | key_seq=hk.PRNGSequence(42), 101 | multi_gpu=False, 102 | ): 103 | 104 | # JIT compile/PMAP gradient function 105 | loss_fn_apply = partial(loss_fn_weighted, model_apply, **loss_fn_args) 106 | if multi_gpu: 107 | # Take mean of gradient across multiple devices 108 | def grad_loss(params, batch, target): 109 | grad_out = jax.grad(loss_fn_apply, has_aux=True)(params, batch, target) 110 | return lax.pmean(grad_out, axis_name="devices") 111 | 112 | grad_loss = jax.pmap(grad_loss, axis_name="devices") 113 | else: 114 | grad_loss = jax.jit(jax.grad(loss_fn_apply, has_aux=True)) 115 | # grad_loss = jax.grad(loss_fn_apply, has_aux=True) 116 | 117 | # Initialize sparse mask 118 | sparsify = sparse_thres is not None and sparse_interval is not None 119 | if multi_gpu: 120 | sparse_mask = jax.tree_map( 121 | jax.pmap(lambda x: jnp.ones_like(x, dtype=bool)), params["sym_model"] 122 | ) 123 | else: 124 | sparse_mask = jax.tree_map( 125 | lambda x: jnp.ones_like(x, dtype=bool), params["sym_model"] 126 | ) 127 | 128 | # Initialize optimizers 129 | update_params, opt_state = init_optimizers(params, optimizers, sparsify, multi_gpu) 130 | if multi_gpu: 131 | update_params = jax.pmap(update_params, axis_name="devices") 132 | else: 133 | update_params = jax.jit(update_params) 134 | 135 | # Get batch and target 136 | # TODO: replace this with call to a data generator/data loader 137 | if multi_gpu: 138 | n_devices = jax.device_count() 139 | pad = data_args["pad"] 140 | time_size = (scaled_data.shape[0] - 2 * pad) // n_devices 141 | batch = [] 142 | target = [] 143 | 
for i in range(n_devices): 144 | start, end = i * time_size, (i + 1) * time_size + 2 * pad 145 | if loss_fn_args["reg_dzdt"] is not None: 146 | # batch, time, mesh, num_visible, 2 147 | batch.append(scaled_data[None, start:end, :, :, :2]) 148 | else: 149 | # batch, time, mesh, num_visible 150 | batch.append(scaled_data[None, start:end, :, :, 0]) 151 | # batch, time, mesh, num_visible, num_der 152 | target.append(scaled_data[None, start + pad : end - pad, :, :, 1:]) 153 | 154 | batch = jax.device_put_sharded(batch, jax.devices()) 155 | target = jax.device_put_sharded(target, jax.devices()) 156 | 157 | else: 158 | if loss_fn_args["reg_dzdt"] is not None: 159 | # batch, time, mesh, mesh, num_visible, 2 160 | batch = scaled_data[None, :, :, :, :2] 161 | else: 162 | # batch, time, mesh, num_visible 163 | batch = scaled_data[None, :, :, :, 0] 164 | 165 | # batch, time, mesh, num_visible, num_der 166 | target = scaled_data[None, :, :, :, 1:] 167 | 168 | batch = jnp.asarray(batch) 169 | target = jnp.asarray(target) 170 | 171 | # Training loop 172 | if multi_gpu: 173 | print(f"Training for {n_steps} steps on {n_devices} devices...") 174 | else: 175 | print(f"Training for {n_steps} steps...") 176 | 177 | best_loss = np.float("inf") 178 | best_params = None 179 | 180 | def thres_fn(x): 181 | return jnp.abs(x) > sparse_thres 182 | 183 | if multi_gpu: 184 | thres_fn = jax.pmap(thres_fn) 185 | 186 | for step in range(n_steps): 187 | # Compute gradients and losses 188 | # weight = jnp.exp(-0.05 * jnp.arange(batch.shape[1]))[:, None, None, None] 189 | weight = jnp.ones_like(batch[..., [0]]) # 1 / (0.1 + batch[..., [0]]) 190 | grads, loss_list = grad_loss(params, batch, target, weight) 191 | 192 | # Save best params if loss is lower than best_loss 193 | loss = loss_list[0][0] if multi_gpu else loss_list[0] 194 | if loss < best_loss: 195 | best_loss = loss 196 | best_params = jax.tree_map(lambda x: x.copy(), params) 197 | 198 | # Update sparse_mask based on a threshold 199 | if sparsify and step > 0 and step % sparse_interval == 0: 200 | sparse_mask = jax.tree_map(thres_fn, best_params["sym_model"]) 201 | 202 | # Update params based on optimizers 203 | params, opt_state, sparse_mask = update_params( 204 | grads, opt_state, params, sparse_mask 205 | ) 206 | 207 | # Normalize encoder params 208 | flat_params, tree = jax.tree_flatten(params["encoder"]) 209 | flat_params[0] /= jnp.linalg.norm(flat_params[0], axis=-1, keepdims=True) 210 | params["encoder"] = jax.tree_unflatten(tree, flat_params) 211 | 212 | # Print loss 213 | if step % 1000 == 0: 214 | loss, mse, reg_dzdt, reg_l1_sparse = loss_list 215 | if multi_gpu: 216 | (loss, mse, reg_dzdt, reg_l1_sparse) = ( 217 | loss[0], 218 | mse[0], 219 | reg_dzdt[0], 220 | reg_l1_sparse[0], 221 | ) 222 | print( 223 | f"Loss[{step}] = {loss}, MSE = {mse}, " 224 | f"Reg. dz/dt = {reg_dzdt}, Reg. 
L1 Sparse = {reg_l1_sparse}" 225 | ) 226 | if multi_gpu: 227 | print(jax.tree_map(lambda x: x[0], params["sym_model"])) 228 | else: 229 | print(params["sym_model"]) 230 | 231 | save_pytree( 232 | os.path.join(args.output, f"{step:06}.pt"), 233 | {"params": best_params, "sparse_mask": sparse_mask}, 234 | overwrite=True, 235 | ) 236 | 237 | if multi_gpu: 238 | best_params = jax.tree_map(lambda x: x[0], best_params) 239 | sparse_mask = jax.tree_map(lambda x: x[0], sparse_mask) 240 | 241 | print("\nBest loss:", best_loss) 242 | print("Best sym_model params:", best_params["sym_model"]) 243 | return best_loss, best_params, sparse_mask 244 | 245 | 246 | if __name__ == "__main__": 247 | 248 | parser = argparse.ArgumentParser( 249 | description="Run SymDer model on 1D nonlinear Schrödinger data." 250 | ) 251 | parser.add_argument( 252 | "-o", 253 | "--output", 254 | type=str, 255 | default="./nlse_1d_run0/", 256 | help="Output folder path. Default: ./nlse_1d_run0/", 257 | ) 258 | parser.add_argument( 259 | "-d", 260 | "--dataset", 261 | type=str, 262 | default="./data/nlse_1d_raw.npz", 263 | help=( 264 | "Path to 1D nonlinear Schrödinger dataset (generated and saved " 265 | "if it does not exist). Default: ./data/nlse_1d.npz" 266 | ), 267 | ) 268 | args = parser.parse_args() 269 | 270 | # Seed random number generator 271 | key_seq = hk.PRNGSequence(42) 272 | 273 | # Set SymDer parameters 274 | num_visible = 1 275 | num_hidden = 2 276 | num_der = 2 277 | 278 | # Set dataset parameters and load/generate dataset 279 | sys_size = 2 * np.pi 280 | mesh = 64 281 | dt = 1e-3 282 | tspan = (0, 0.5 + 4 * dt) 283 | squared = False 284 | scaled_data, scale, raw_sol = get_dataset( 285 | args.dataset, 286 | generate_dataset, 287 | get_raw_sol=True, 288 | sys_size=sys_size, 289 | mesh=mesh, 290 | dt=dt, 291 | tspan=tspan, 292 | num_der=num_der, 293 | squared=squared, 294 | ) 295 | print("raw_sol shape ", raw_sol.shape) 296 | 297 | # Set training hyperparameters 298 | n_steps = 100000 299 | sparse_thres = 1e-3 300 | sparse_interval = 10000 301 | multi_gpu = False 302 | 303 | # Define optimizers 304 | optimizers = { 305 | "encoder": optax.adabelief(1e-4, eps=1e-16), 306 | "sym_model": optax.adabelief(1e-4, eps=1e-16), 307 | } 308 | 309 | # Set loss function hyperparameters 310 | loss_fn_args = { 311 | "scale": jnp.array(scale), 312 | "deriv_weight": jnp.array([1.0, 1.0]), 313 | "reg_dzdt": 1e3, 314 | "reg_dzdt_var_norm": False, 315 | "reg_l1_sparse": 0, 316 | } 317 | get_dzdt = loss_fn_args["reg_dzdt"] is not None 318 | 319 | # Check dataset shapes 320 | assert scaled_data.shape[-2] == num_visible 321 | assert scaled_data.shape[-1] == num_der + 1 322 | assert scale.shape[0] == num_visible 323 | assert scale.shape[1] == num_der + 1 324 | 325 | # Define model 326 | model_apply, model_init, model_args = get_model( 327 | num_visible, 328 | num_hidden, 329 | num_der, 330 | mesh, 331 | sys_size / mesh, 332 | dt, 333 | scale, 334 | squared=squared, 335 | get_dzdt=get_dzdt, 336 | ) 337 | 338 | # Initialize parameters 339 | params = {} 340 | params["encoder"] = model_init["encoder"]( 341 | next(key_seq), jnp.ones([1, scaled_data.shape[0], mesh, num_visible]) 342 | ) 343 | params["sym_model"] = model_init["sym_model"]( 344 | next(key_seq), jnp.ones([1, 1, mesh, num_hidden // 2], dtype=jnp.complex64), 0.0 345 | ) 346 | 347 | if multi_gpu: 348 | for name in params.keys(): 349 | params[name] = jax.device_put_replicated(params[name], jax.devices()) 350 | 351 | # Train 352 | best_loss, best_params, sparse_mask = train( 353 | 
n_steps, 354 | model_apply, 355 | params, 356 | scaled_data, 357 | loss_fn_args=loss_fn_args, 358 | data_args={"pad": model_args["pad"]}, 359 | optimizers=optimizers, 360 | sparse_thres=sparse_thres, 361 | sparse_interval=sparse_interval, 362 | key_seq=key_seq, 363 | multi_gpu=multi_gpu, 364 | ) 365 | 366 | # Save model parameters and sparse mask 367 | print(f"Saving best model parameters in output folder: {args.output}") 368 | save_pytree( 369 | os.path.join(args.output, "best.pt"), 370 | {"params": best_params, "sparse_mask": sparse_mask}, 371 | ) 372 | -------------------------------------------------------------------------------- /reac_diff_2d_model.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from jax import lax 4 | import numpy as np 5 | 6 | import haiku as hk 7 | import optax 8 | 9 | import os.path 10 | import argparse 11 | from functools import partial 12 | 13 | # from tqdm.auto import tqdm 14 | 15 | from data.utils import get_dataset 16 | from data.reac_diff_2d import generate_dataset 17 | 18 | from encoder.utils import append_dzdt, concat_visible 19 | from symder.sym_models import SymModel, Quadratic, SpatialDerivative2D, rescale_z 20 | from symder.symder import get_symder_apply, get_model_apply 21 | 22 | from utils import loss_fn, init_optimizers, save_pytree # , load_pytree 23 | 24 | 25 | def get_model(num_visible, num_hidden, num_der, mesh, dx, dt, scale, get_dzdt=False): 26 | 27 | # Define encoder 28 | hidden_size = 64 29 | pad = 2 30 | 31 | def encoder(x): 32 | return hk.Sequential( 33 | [ 34 | lambda x: jnp.pad( 35 | x, ((0, 0), (0, 0), (pad, pad), (pad, pad), (0, 0)), "wrap" 36 | ), 37 | hk.Conv3D(hidden_size, kernel_shape=5, padding="VALID"), 38 | jax.nn.relu, 39 | hk.Conv3D(hidden_size, kernel_shape=1), 40 | jax.nn.relu, 41 | hk.Conv3D(num_hidden, kernel_shape=1), 42 | ] 43 | )(x) 44 | 45 | encoder = hk.without_apply_rng(hk.transform(encoder)) 46 | encoder_apply = append_dzdt(encoder.apply) if get_dzdt else encoder.apply 47 | encoder_apply = concat_visible( 48 | encoder_apply, visible_transform=lambda x: x[:, pad:-pad] 49 | ) 50 | 51 | # Define symbolic model 52 | n_dims = num_visible + num_hidden 53 | scale_vec = jnp.concatenate((scale[:, 0], jnp.ones(num_hidden))) 54 | 55 | @partial(rescale_z, scale_vec=scale_vec) 56 | def sym_model(z, t): 57 | return SymModel( 58 | 1e1 * dt, 59 | ( 60 | SpatialDerivative2D(mesh, np.sqrt(1e1) * dx, init=jnp.zeros), 61 | hk.Linear(n_dims, w_init=jnp.zeros, b_init=jnp.zeros), 62 | Quadratic(n_dims, init=jnp.zeros), 63 | ), 64 | )(z, t) 65 | 66 | sym_model = hk.without_apply_rng(hk.transform(sym_model)) 67 | 68 | # Define SymDer function which automatically computes 69 | # higher order time derivatives of symbolic model 70 | symder_apply = get_symder_apply( 71 | sym_model.apply, 72 | num_der=num_der, 73 | transform=lambda z: z[..., :num_visible], 74 | get_dzdt=get_dzdt, 75 | ) 76 | 77 | # Define full model, combining encoder and symbolic model 78 | model_apply = get_model_apply( 79 | encoder_apply, 80 | symder_apply, 81 | hidden_transform=lambda z: z[..., -num_hidden:], 82 | get_dzdt=get_dzdt, 83 | ) 84 | model_init = {"encoder": encoder.init, "sym_model": sym_model.init} 85 | 86 | return model_apply, model_init, {"pad": pad} 87 | 88 | 89 | def train( 90 | n_steps, 91 | model_apply, 92 | params, 93 | scaled_data, 94 | loss_fn_args={}, 95 | data_args={}, 96 | optimizers={}, 97 | sparse_thres=None, 98 | sparse_interval=None, 99 | 
key_seq=hk.PRNGSequence(42), 100 | multi_gpu=False, 101 | ): 102 | 103 | # JIT compile/PMAP gradient function 104 | loss_fn_apply = partial(loss_fn, model_apply, **loss_fn_args) 105 | if multi_gpu: 106 | # Take mean of gradient across multiple devices 107 | def grad_loss(params, batch, target): 108 | grad_out = jax.grad(loss_fn_apply, has_aux=True)(params, batch, target) 109 | return lax.pmean(grad_out, axis_name="devices") 110 | 111 | grad_loss = jax.pmap(grad_loss, axis_name="devices") 112 | else: 113 | grad_loss = jax.jit(jax.grad(loss_fn_apply, has_aux=True)) 114 | 115 | # Initialize sparse mask 116 | sparsify = sparse_thres is not None and sparse_interval is not None 117 | if multi_gpu: 118 | sparse_mask = jax.tree_map( 119 | jax.pmap(lambda x: jnp.ones_like(x, dtype=bool)), params["sym_model"] 120 | ) 121 | else: 122 | sparse_mask = jax.tree_map( 123 | lambda x: jnp.ones_like(x, dtype=bool), params["sym_model"] 124 | ) 125 | 126 | # Initialize optimizers 127 | update_params, opt_state = init_optimizers(params, optimizers, sparsify, multi_gpu) 128 | if multi_gpu: 129 | update_params = jax.pmap(update_params, axis_name="devices") 130 | else: 131 | update_params = jax.jit(update_params) 132 | 133 | # Get batch and target 134 | # TODO: replace this with call to a data generator/data loader 135 | if multi_gpu: 136 | n_devices = jax.device_count() 137 | pad = data_args["pad"] 138 | time_size = (scaled_data.shape[0] - 2 * pad) // n_devices 139 | batch = [] 140 | target = [] 141 | for i in range(n_devices): 142 | start, end = i * time_size, (i + 1) * time_size + 2 * pad 143 | if loss_fn_args["reg_dzdt"] is not None: 144 | # batch, time, mesh, mesh, num_visible, 2 145 | batch.append(scaled_data[None, start:end, :, :, :, :2]) 146 | else: 147 | # batch, time, mesh, mesh, num_visible 148 | batch.append(scaled_data[None, start:end, :, :, :, 0]) 149 | # batch, time, mesh, mesh, num_visible, num_der 150 | target.append(scaled_data[None, start + pad : end - pad, :, :, :, 1:]) 151 | 152 | batch = jax.device_put_sharded(batch, jax.devices()) 153 | target = jax.device_put_sharded(target, jax.devices()) 154 | 155 | else: 156 | if loss_fn_args["reg_dzdt"] is not None: 157 | # batch, time, mesh, mesh, num_visible, 2 158 | batch = scaled_data[None, :, :, :, :, :2] 159 | else: 160 | # batch, time, mesh, mesh, num_visible 161 | batch = scaled_data[None, :, :, :, :, 0] 162 | pad = data_args["pad"] 163 | # batch, time, mesh, mesh, num_visible, num_der 164 | target = scaled_data[None, pad:-pad, :, :, :, 1:] 165 | 166 | batch = jnp.asarray(batch) 167 | target = jnp.asarray(target) 168 | 169 | # Training loop 170 | if multi_gpu: 171 | print(f"Training for {n_steps} steps on {n_devices} devices...") 172 | else: 173 | print(f"Training for {n_steps} steps...") 174 | 175 | best_loss = np.float("inf") 176 | best_params = None 177 | 178 | def thres_fn(x): 179 | return jnp.abs(x) > sparse_thres 180 | 181 | if multi_gpu: 182 | thres_fn = jax.pmap(thres_fn) 183 | 184 | for step in range(n_steps): 185 | # Compute gradients and losses 186 | grads, loss_list = grad_loss(params, batch, target) 187 | 188 | # Save best params if loss is lower than best_loss 189 | loss = loss_list[0][0] if multi_gpu else loss_list[0] 190 | if loss < best_loss: 191 | best_loss = loss 192 | best_params = jax.tree_map(lambda x: x.copy(), params) 193 | 194 | # Update sparse_mask based on a threshold 195 | if sparsify and step > 0 and step % sparse_interval == 0: 196 | sparse_mask = jax.tree_map(thres_fn, best_params["sym_model"]) 197 | 198 | # 
Update params based on optimizers 199 | params, opt_state, sparse_mask = update_params( 200 | grads, opt_state, params, sparse_mask 201 | ) 202 | 203 | # Print loss 204 | if step % 100 == 0: 205 | loss, mse, reg_dzdt, reg_l1_sparse = loss_list 206 | if multi_gpu: 207 | (loss, mse, reg_dzdt, reg_l1_sparse) = ( 208 | loss[0], 209 | mse[0], 210 | reg_dzdt[0], 211 | reg_l1_sparse[0], 212 | ) 213 | print( 214 | f"Loss[{step}] = {loss}, MSE = {mse}, " 215 | f"Reg. dz/dt = {reg_dzdt}, Reg. L1 Sparse = {reg_l1_sparse}" 216 | ) 217 | if multi_gpu: 218 | print(jax.tree_map(lambda x: x[0], params["sym_model"])) 219 | else: 220 | print(params["sym_model"]) 221 | 222 | if multi_gpu: 223 | best_params = jax.tree_map(lambda x: x[0], best_params) 224 | sparse_mask = jax.tree_map(lambda x: x[0], sparse_mask) 225 | 226 | print("\nBest loss:", best_loss) 227 | print("Best sym_model params:", best_params["sym_model"]) 228 | return best_loss, best_params, sparse_mask 229 | 230 | 231 | if __name__ == "__main__": 232 | 233 | parser = argparse.ArgumentParser( 234 | description="Run SymDer model on 2D reaction-diffusion data." 235 | ) 236 | parser.add_argument( 237 | "-o", 238 | "--output", 239 | type=str, 240 | default="./reac_diff_2d_run1/", 241 | help="Output folder path. Default: ./reac_diff_2d_run0/", 242 | ) 243 | parser.add_argument( 244 | "-d", 245 | "--dataset", 246 | type=str, 247 | default="./data/reac_diff_2d.npz", 248 | help=( 249 | "Path to 2D reaction-diffusion dataset (generated and saved " 250 | "if it does not exist). Default: ./data/reac_diff_2d.npz" 251 | ), 252 | ) 253 | args = parser.parse_args() 254 | 255 | # Seed random number generator 256 | key_seq = hk.PRNGSequence(42) 257 | 258 | # Set SymDer parameters 259 | num_visible = 1 260 | num_hidden = 1 261 | num_der = 2 262 | 263 | # Set dataset parameters and load/generate dataset 264 | sys_size = 64 265 | mesh = 64 266 | dt = 5e-2 267 | tspan = (0, 50 + 2 * dt) 268 | scaled_data, scale, raw_sol = get_dataset( 269 | args.dataset, 270 | generate_dataset, 271 | get_raw_sol=True, 272 | sys_size=sys_size, 273 | mesh=mesh, 274 | dt=dt, 275 | tspan=tspan, 276 | num_der=num_der, 277 | ) 278 | 279 | # Set training hyperparameters 280 | n_steps = 100000 281 | sparse_thres = 2e-3 282 | sparse_interval = 1000 283 | multi_gpu = True 284 | 285 | # Define optimizers 286 | optimizers = { 287 | "encoder": optax.adabelief(1e-3, eps=1e-16), 288 | "sym_model": optax.adabelief(1e-3, eps=1e-16), 289 | } 290 | 291 | # Set loss function hyperparameters 292 | loss_fn_args = { 293 | "scale": jnp.array(scale), 294 | "deriv_weight": jnp.array([1.0, 1.0]), 295 | "reg_dzdt": 0, 296 | "reg_l1_sparse": 0, 297 | } 298 | get_dzdt = loss_fn_args["reg_dzdt"] is not None 299 | 300 | # Check dataset shapes 301 | assert scaled_data.shape[-2] == num_visible 302 | assert scaled_data.shape[-1] == num_der + 1 303 | assert scale.shape[0] == num_visible 304 | assert scale.shape[1] == num_der + 1 305 | 306 | # Define model 307 | model_apply, model_init, model_args = get_model( 308 | num_visible, 309 | num_hidden, 310 | num_der, 311 | mesh, 312 | sys_size / mesh, 313 | dt, 314 | scale, 315 | get_dzdt=get_dzdt, 316 | ) 317 | 318 | # Initialize parameters 319 | params = {} 320 | params["encoder"] = model_init["encoder"]( 321 | next(key_seq), jnp.ones([1, scaled_data.shape[0], mesh, mesh, num_visible]) 322 | ) 323 | params["sym_model"] = model_init["sym_model"]( 324 | next(key_seq), jnp.ones([1, 1, mesh, mesh, num_visible + num_hidden]), 0.0 325 | ) 326 | if multi_gpu: 327 | for name in 
params.keys(): 328 | params[name] = jax.device_put_replicated(params[name], jax.devices()) 329 | 330 | # Train 331 | best_loss, best_params, sparse_mask = train( 332 | n_steps, 333 | model_apply, 334 | params, 335 | scaled_data, 336 | loss_fn_args=loss_fn_args, 337 | data_args={"pad": model_args["pad"]}, 338 | optimizers=optimizers, 339 | sparse_thres=sparse_thres, 340 | sparse_interval=sparse_interval, 341 | key_seq=key_seq, 342 | multi_gpu=multi_gpu, 343 | ) 344 | 345 | # Save model parameters and sparse mask 346 | print(f"Saving best model parameters in output folder: {args.output}") 347 | save_pytree( 348 | os.path.join(args.output, "best.pt"), 349 | {"params": best_params, "sparse_mask": sparse_mask}, 350 | ) 351 | -------------------------------------------------------------------------------- /rossler_model.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import numpy as np 4 | 5 | import haiku as hk 6 | import optax 7 | 8 | import os.path 9 | import argparse 10 | from functools import partial 11 | 12 | # from tqdm.auto import tqdm 13 | 14 | from data.utils import get_dataset 15 | from data.rossler import generate_dataset 16 | 17 | from encoder.utils import append_dzdt, concat_visible 18 | from symder.sym_models import SymModel, Quadratic, rescale_z 19 | from symder.symder import get_symder_apply, get_model_apply 20 | 21 | from utils import loss_fn, init_optimizers, save_pytree # , load_pytree 22 | 23 | 24 | def get_model(num_visible, num_hidden, num_der, dt, scale, get_dzdt=False): 25 | 26 | # Define encoder 27 | hidden_size = 128 28 | pad = 4 29 | 30 | def encoder(x): 31 | return hk.Sequential( 32 | [ 33 | hk.Conv1D(hidden_size, kernel_shape=9, padding="VALID"), 34 | jax.nn.relu, 35 | hk.Conv1D(hidden_size, kernel_shape=1), 36 | jax.nn.relu, 37 | hk.Conv1D(num_hidden, kernel_shape=1), 38 | ] 39 | )(x) 40 | 41 | encoder = hk.without_apply_rng(hk.transform(encoder)) 42 | encoder_apply = append_dzdt(encoder.apply) if get_dzdt else encoder.apply 43 | encoder_apply = concat_visible( 44 | encoder_apply, visible_transform=lambda x: x[:, pad:-pad] 45 | ) 46 | 47 | # Define symbolic model 48 | n_dims = num_visible + num_hidden 49 | scale_vec = jnp.concatenate((scale[:, 0], jnp.ones(num_hidden))) 50 | 51 | @partial(rescale_z, scale_vec=scale_vec) 52 | def sym_model(z, t): 53 | return SymModel( 54 | 1e2 * dt, 55 | ( 56 | hk.Linear(n_dims, w_init=jnp.zeros, b_init=jnp.zeros), 57 | Quadratic(n_dims, init=jnp.zeros), 58 | ), 59 | )(z, t) 60 | 61 | sym_model = hk.without_apply_rng(hk.transform(sym_model)) 62 | 63 | # Define SymDer function which automatically computes 64 | # higher order time derivatives of symbolic model 65 | symder_apply = get_symder_apply( 66 | sym_model.apply, 67 | num_der=num_der, 68 | transform=lambda z: z[..., :num_visible], 69 | get_dzdt=get_dzdt, 70 | ) 71 | 72 | # Define full model, combining encoder and symbolic model 73 | model_apply = get_model_apply( 74 | encoder_apply, 75 | symder_apply, 76 | hidden_transform=lambda z: z[..., -num_hidden:], 77 | get_dzdt=get_dzdt, 78 | ) 79 | model_init = {"encoder": encoder.init, "sym_model": sym_model.init} 80 | 81 | return model_apply, model_init, {"pad": pad} 82 | 83 | 84 | def train( 85 | n_steps, 86 | model_apply, 87 | params, 88 | scaled_data, 89 | loss_fn_args={}, 90 | data_args={}, 91 | optimizers={}, 92 | sparse_thres=None, 93 | sparse_interval=None, 94 | key_seq=hk.PRNGSequence(42), 95 | ): 96 | 97 | # JIT compile gradient function 98 | 
loss_fn_apply = partial(loss_fn, model_apply, **loss_fn_args) 99 | grad_loss = jax.jit(jax.grad(loss_fn_apply, has_aux=True)) 100 | 101 | # Initialize sparse mask 102 | sparsify = sparse_thres is not None and sparse_interval is not None 103 | sparse_mask = jax.tree_map( 104 | lambda x: jnp.ones_like(x, dtype=bool), params["sym_model"] 105 | ) 106 | 107 | # Initialize optimizers 108 | update_params, opt_state = init_optimizers(params, optimizers, sparsify) 109 | update_params = jax.jit(update_params) 110 | 111 | # Get batch and target 112 | # TODO: replace this with call to a data generator/data loader 113 | if loss_fn_args["reg_dzdt"] is not None: 114 | batch = scaled_data[None, :, :, :2] # batch, time, num_visible, 2 115 | else: 116 | batch = scaled_data[None, :, :, 0] # batch, time, num_visible 117 | pad = data_args["pad"] 118 | # batch, time, num_visible, num_der 119 | target = scaled_data[None, pad:-pad, :, 1:] 120 | 121 | batch = jnp.asarray(batch) 122 | target = jnp.asarray(target) 123 | 124 | # Training loop 125 | print(f"Training for {n_steps} steps...") 126 | 127 | best_loss = np.float("inf") 128 | best_params = None 129 | 130 | for step in range(n_steps): 131 | 132 | # Compute gradients and losses 133 | grads, loss_list = grad_loss(params, batch, target) 134 | 135 | # Save best params if loss is lower than best_loss 136 | loss = loss_list[0] 137 | if loss < best_loss: 138 | best_loss = loss 139 | best_params = jax.tree_map(lambda x: x.copy(), params) 140 | 141 | # Update sparse_mask based on a threshold 142 | if sparsify and step > 0 and step % sparse_interval == 0: 143 | sparse_mask = jax.tree_map( 144 | lambda x: jnp.abs(x) > sparse_thres, best_params["sym_model"] 145 | ) 146 | 147 | # Update params based on optimizers 148 | params, opt_state, sparse_mask = update_params( 149 | grads, opt_state, params, sparse_mask 150 | ) 151 | 152 | # Print loss 153 | if step % 1000 == 0: 154 | loss, mse, reg_dzdt, reg_l1_sparse = loss_list 155 | print( 156 | f"Loss[{step}] = {loss}, MSE = {mse}, " 157 | f"Reg. dz/dt = {reg_dzdt}, Reg. L1 Sparse = {reg_l1_sparse}" 158 | ) 159 | print(params["sym_model"]) 160 | 161 | print("\nBest loss:", best_loss) 162 | print("Best sym_model params:", best_params["sym_model"]) 163 | return best_loss, best_params, sparse_mask 164 | 165 | 166 | if __name__ == "__main__": 167 | 168 | parser = argparse.ArgumentParser( 169 | description="Run SymDer model on Rossler system data." 170 | ) 171 | parser.add_argument( 172 | "-o", 173 | "--output", 174 | type=str, 175 | default="./rossler_run0/", 176 | help="Output folder path. Default: ./rossler_run0/", 177 | ) 178 | parser.add_argument( 179 | "-d", 180 | "--dataset", 181 | type=str, 182 | default="./data/rossler.npz", 183 | help=( 184 | "Path to Rossler system dataset (generated and saved " 185 | "if it does not exist). 
Default: ./data/rossler.npz" 186 | ), 187 | ) 188 | args = parser.parse_args() 189 | 190 | # Seed random number generator 191 | key_seq = hk.PRNGSequence(42) 192 | 193 | # Set SymDer parameters 194 | num_visible = 2 195 | num_hidden = 1 196 | num_der = 2 197 | 198 | # Set dataset parameters and load/generate dataset 199 | dt = 1e-2 200 | tmax = 100 + 2 * dt 201 | scaled_data, scale, raw_sol = get_dataset( 202 | args.dataset, 203 | generate_dataset, 204 | get_raw_sol=True, 205 | dt=dt, 206 | tmax=tmax, 207 | num_visible=num_visible, 208 | num_der=num_der, 209 | ) 210 | 211 | # Set training hyperparameters 212 | n_steps = 50000 213 | sparse_thres = 1e-3 214 | sparse_interval = 5000 215 | 216 | # Define optimizers 217 | optimizers = { 218 | "encoder": optax.adabelief(1e-3, eps=1e-16), 219 | "sym_model": optax.adabelief(1e-3, eps=1e-16), 220 | } 221 | 222 | # Set loss function hyperparameters 223 | loss_fn_args = { 224 | "scale": jnp.array(scale), 225 | "deriv_weight": jnp.array([1.0, 1.0]), 226 | "reg_dzdt": 0, 227 | "reg_l1_sparse": 0, 228 | } 229 | get_dzdt = loss_fn_args["reg_dzdt"] is not None 230 | 231 | # Check dataset shapes 232 | assert scaled_data.shape[-2] == num_visible 233 | assert scaled_data.shape[-1] == num_der + 1 234 | assert scale.shape[0] == num_visible 235 | assert scale.shape[1] == num_der + 1 236 | 237 | # Define model 238 | model_apply, model_init, model_args = get_model( 239 | num_visible, num_hidden, num_der, dt, scale, get_dzdt=get_dzdt 240 | ) 241 | 242 | # Initialize parameters 243 | params = {} 244 | params["encoder"] = model_init["encoder"]( 245 | next(key_seq), jnp.ones([1, scaled_data.shape[0], num_visible]) 246 | ) 247 | params["sym_model"] = model_init["sym_model"]( 248 | next(key_seq), jnp.ones([1, 1, num_visible + num_hidden]), 0.0 249 | ) 250 | 251 | # Train 252 | best_loss, best_params, sparse_mask = train( 253 | n_steps, 254 | model_apply, 255 | params, 256 | scaled_data, 257 | loss_fn_args=loss_fn_args, 258 | data_args={"pad": model_args["pad"]}, 259 | optimizers=optimizers, 260 | sparse_thres=sparse_thres, 261 | sparse_interval=sparse_interval, 262 | key_seq=key_seq, 263 | ) 264 | 265 | # Save model parameters and sparse mask 266 | print(f"Saving best model parameters in output folder: {args.output}") 267 | save_pytree( 268 | os.path.join(args.output, "best.pt"), 269 | {"params": best_params, "sparse_mask": sparse_mask}, 270 | ) 271 | -------------------------------------------------------------------------------- /symder/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["odeint_zero", "symder", "sym_models"] 2 | from . 
import * 3 | -------------------------------------------------------------------------------- /symder/odeint_zero.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from functools import partial 4 | 5 | __all__ = ["odeint_zero", "dfunc", "d_dt"] 6 | 7 | 8 | @partial(jax.custom_jvp, nondiff_argnums=(0,)) 9 | def odeint_zero(func, y0, t, *args): 10 | return y0 11 | 12 | 13 | @odeint_zero.defjvp 14 | def odeint_zero_jvp(func, primals, tangents): 15 | y0, t, *args = primals 16 | dy0, dt, *dargs = tangents 17 | 18 | # Wrap y0 using `odeint_zero` to obtain y 19 | y = odeint_zero(func, y0, t, *args) 20 | 21 | # Define time derivative dy/dt = func(y, t) 22 | dydt = func(y, t, *args) 23 | 24 | # Compute JVP: dy = dy/dy0 * dy0 + dy/dt * dt, where dy/dy0 = 1 25 | dy = jax.tree_multimap( 26 | lambda dy0_, dydt_: dy0_ + dydt_ * jnp.broadcast_to(dt, dydt_.shape), dy0, dydt 27 | ) 28 | return y, dy 29 | 30 | 31 | def dfunc(func, order, transform=None): 32 | func0 = ( 33 | partial(odeint_zero, func) 34 | if transform is None 35 | else lambda y0, t, *args: transform(odeint_zero(func, y0, t, *args)) 36 | ) 37 | 38 | # TODO: Can potentially replace this with jax.experimental.jet 39 | # for more efficient third or higher order derivatives 40 | out = [func0] 41 | for _ in range(order): 42 | out.append(d_dt(out[-1])) 43 | return out 44 | 45 | 46 | def d_dt(func): 47 | def dfunc_dt(y0, t, *args): 48 | dy0 = jax.tree_map(jnp.zeros_like, y0) 49 | dt = jax.tree_map(jnp.ones_like, t) 50 | dargs = jax.tree_map(jnp.zeros_like, args) 51 | return jax.jvp(func, (y0, t, *args), (dy0, dt, *dargs))[1] 52 | 53 | return dfunc_dt 54 | -------------------------------------------------------------------------------- /symder/sym_models.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from jax import lax 3 | import numpy as np 4 | 5 | import haiku as hk 6 | 7 | __all__ = [ 8 | "Quadratic", 9 | "PointwisePolynomial", 10 | "SpatialDerivative1D", 11 | "SpatialDerivative1D_FiniteDiff", 12 | "SpatialDerivative2D", 13 | "SymModel", 14 | "rescale_z", 15 | ] 16 | 17 | 18 | class Quadratic(hk.Module): 19 | def __init__(self, n_dims, init=jnp.zeros): 20 | super().__init__() 21 | self.n_dims = n_dims 22 | 23 | ind = np.arange(n_dims) 24 | mesh = np.stack(np.meshgrid(ind, ind), -1) 25 | self.mask = jnp.array(mesh[..., 0] >= mesh[..., 1]) 26 | 27 | self.init = lambda *args: self.mask * init(*args) 28 | 29 | def __call__(self, z, t=None): 30 | weights = self.mask * hk.get_parameter( 31 | "w", (self.n_dims, self.n_dims, self.n_dims), init=self.init 32 | ) 33 | out = (weights * z[..., None, None, :] * z[..., None, :, None]).sum((-2, -1)) 34 | return out 35 | 36 | 37 | class Cubic(hk.Module): 38 | def __init__(self, n_dims, init=jnp.zeros): 39 | super().__init__() 40 | self.n_dims = n_dims 41 | 42 | ind = np.arange(n_dims) 43 | mesh = np.stack(np.meshgrid(ind, ind, ind), -1) 44 | self.mask = jnp.array(mesh[..., 0] >= mesh[..., 1]) * jnp.array( 45 | mesh[..., 1] >= mesh[..., 2] 46 | ) 47 | 48 | self.init = lambda *args: self.mask * init(*args) 49 | 50 | def __call__(self, z, t=None): 51 | weights = self.mask * hk.get_parameter( 52 | "w", (self.n_dims, self.n_dims, self.n_dims, self.n_dims), init=self.init 53 | ) 54 | out = ( 55 | weights 56 | * z[..., None, None, None, :] 57 | * z[..., None, None, :, None] 58 | * z[..., None, :, None, None] 59 | ).sum((-3, -2, -1)) 60 | return out 61 | 62 | 63 | class 
PointwisePolynomial(hk.Module): 64 | def __init__( 65 | self, poly_terms=(2, 4), init=jnp.zeros, name="pointwise_polynomial", 66 | ): 67 | super().__init__(name=name) 68 | self.init = init 69 | self.poly_terms = poly_terms 70 | 71 | def __call__(self, z, t=None): 72 | w = hk.get_parameter("w", (len(self.poly_terms),), init=self.init) 73 | terms = jnp.stack([z ** n for n in self.poly_terms], axis=-1) 74 | return jnp.sum(w * terms, axis=-1) 75 | 76 | 77 | class SpatialDerivative1D(hk.Module): 78 | def __init__( 79 | self, 80 | mesh, 81 | dx, 82 | deriv_orders=(1, 2), 83 | init=jnp.zeros, 84 | name="spatial_derivative_1d", 85 | ): 86 | super().__init__(name=name) 87 | self.init = init 88 | 89 | k = 2 * np.pi * np.fft.fftfreq(mesh, d=dx)[:, None] 90 | 91 | # for use in odd derivatives 92 | k_1 = k.copy() 93 | if mesh % 2 == 0: 94 | k_1[int(mesh / 2), :] = 0 95 | 96 | self.ik_vec = jnp.stack( 97 | [(1j * k) ** n if (n % 2 == 0) else (1j * k_1) ** n for n in deriv_orders], 98 | axis=-1, 99 | ) 100 | 101 | def __call__(self, u, t=None): 102 | v = jnp.fft.fft(u, axis=-2) 103 | w = hk.get_parameter("w", (u.shape[-1], self.ik_vec.shape[-1]), init=self.init) 104 | L = jnp.sum(w * self.ik_vec, axis=-1) 105 | du = jnp.fft.ifft(L * v, axis=-2) 106 | return jnp.real(du) if jnp.isrealobj(u) else du 107 | 108 | 109 | class SpatialDerivative1D_FiniteDiff(hk.Module): 110 | def __init__( 111 | self, 112 | mesh, 113 | dx, 114 | deriv_orders=(1, 2), 115 | init=jnp.zeros, 116 | name="spatial_derivative_1d_finite_diff", 117 | ): 118 | super().__init__(name=name) 119 | self.init = init 120 | self.kernels = jnp.take( 121 | self.generate_diff_kernels(max(deriv_orders)), deriv_orders, axis=0 122 | ) 123 | 124 | def generate_diff_kernels(self, order): 125 | self.pad = int(np.floor((order + 1) / 2)) 126 | 127 | rev_d1 = np.array((0.5, 0.0, -0.5)) 128 | d2 = np.array((1.0, -2.0, 1.0)) 129 | 130 | even_kernels = [np.pad(np.array((1.0,)), (self.pad,))] 131 | for i in range(order // 2): 132 | even_kernels.append(np.convolve(even_kernels[-1], d2, mode="same")) 133 | 134 | even_kernels = np.stack(even_kernels) 135 | odd_kernels = lax.conv( 136 | even_kernels[:, None, :], rev_d1[None, None, :], (1,), "SAME" 137 | ).squeeze(1) 138 | 139 | kernels = jnp.stack((even_kernels, odd_kernels), axis=1).reshape( 140 | -1, 2 * self.pad + 1 141 | ) 142 | if order % 2 == 0: 143 | kernels = kernels[:-1] 144 | 145 | return kernels 146 | 147 | def __call__(self, u, t=None): 148 | u_shape = u.shape 149 | n_terms = self.kernels.shape[0] 150 | u = ( 151 | u.reshape(-1, u_shape[-2], u_shape[-1]) 152 | .transpose((0, 2, 1)) 153 | .reshape(-1, 1, u_shape[-2]) 154 | ) 155 | u = jnp.pad(u, ((0, 0), (0, 0), (self.pad, self.pad)), "wrap") 156 | w = hk.get_parameter("w", (u_shape[-1], n_terms), init=self.init) 157 | terms = lax.conv(u, self.kernels[:, None, :].astype(u.dtype), (1,), "VALID") 158 | du = jnp.sum( 159 | jnp.expand_dims(w, -1) 160 | * terms.reshape(-1, u_shape[-1], n_terms, u_shape[-2]), 161 | axis=-2, 162 | ) 163 | return du.transpose((0, 2, 1)).reshape(u_shape) 164 | 165 | 166 | class SpatialDerivative2D(hk.Module): 167 | def __init__(self, mesh, dx, init=jnp.zeros, name="spatial_derivative_2d"): 168 | super().__init__(name=name) 169 | self.init = init 170 | 171 | kx = 2 * np.pi * np.fft.fftfreq(mesh, d=dx)[:, None, None] 172 | ky = 2 * np.pi * np.fft.fftfreq(mesh, d=dx)[None, :, None] 173 | 174 | # for use in odd derivatives 175 | kx_1 = kx.copy() 176 | ky_1 = ky.copy() 177 | if mesh % 2 == 0: 178 | kx_1[int(mesh / 2), :, :] = 0 179 | 
ky_1[:, int(mesh / 2), :] = 0 180 | 181 | kx = jnp.broadcast_to(kx, (mesh, mesh, 1)) 182 | ky = jnp.broadcast_to(ky, (mesh, mesh, 1)) 183 | kx_1 = jnp.broadcast_to(kx_1, (mesh, mesh, 1)) 184 | ky_1 = jnp.broadcast_to(ky_1, (mesh, mesh, 1)) 185 | 186 | self.ik_vec = jnp.stack( 187 | [ 188 | 1j * kx_1, 189 | 1j * ky_1, 190 | (1j * kx) ** 2, 191 | (1j * ky) ** 2, 192 | (1j * kx_1) * (1j * ky_1), 193 | ], 194 | axis=-1, 195 | ) 196 | 197 | def __call__(self, u, t=None): 198 | v = jnp.fft.fft2(u, axes=(-3, -2)) 199 | w = hk.get_parameter("w", (u.shape[-1], self.ik_vec.shape[-1]), init=self.init) 200 | L = jnp.sum(w * self.ik_vec, axis=-1) 201 | du = jnp.fft.ifft2(L * v, axes=(-3, -2)) 202 | return jnp.real(du) if jnp.isrealobj(u) else du 203 | 204 | 205 | class SymModel(hk.Module): 206 | def __init__(self, dt, module_list, time_dependence=False, name="sym_model"): 207 | super().__init__(name=name) 208 | self.dt = dt 209 | self.module_list = tuple(module_list) 210 | self.time_dependence = time_dependence 211 | 212 | def __call__(self, z, t): 213 | if self.time_dependence: 214 | dz = self.dt * sum(module(z, t) for module in self.module_list) 215 | else: 216 | dz = self.dt * sum(module(z) for module in self.module_list) 217 | return dz 218 | 219 | 220 | def rescale_z(sym_model_apply, scale_vec=None): 221 | if scale_vec is None: 222 | return sym_model_apply 223 | 224 | def rescaled_sym_model(z, t, *args, **kwargs): 225 | return sym_model_apply(z * scale_vec, t, *args, **kwargs) / scale_vec 226 | 227 | return rescaled_sym_model 228 | -------------------------------------------------------------------------------- /symder/symder.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from .odeint_zero import dfunc 3 | 4 | __all__ = ["get_symder_apply", "get_model_apply"] 5 | 6 | 7 | def get_symder_apply(sym_model_apply, num_der=1, transform=None, get_dzdt=False): 8 | def func(z, t, params): 9 | return sym_model_apply(params, z, t) 10 | 11 | dfuncs = dfunc(func, num_der, transform=transform)[1:] 12 | 13 | def symder_apply(params, z, t=0.0): 14 | sym_deriv_x = jnp.stack(tuple(map(lambda f: f(z, t, params), dfuncs)), axis=-1) 15 | if get_dzdt: 16 | sym_dzdt = func(z, t, params) 17 | return sym_deriv_x, sym_dzdt 18 | else: 19 | return sym_deriv_x 20 | 21 | return symder_apply 22 | 23 | 24 | def get_model_apply( 25 | encoder_apply, 26 | symder_apply, 27 | encoder_name="encoder", 28 | sym_model_name="sym_model", 29 | hidden_transform=None, 30 | get_dzdt=False, 31 | ): 32 | if get_dzdt: 33 | 34 | def model_apply(params, x, dxdt): 35 | z, dzdt_hidden = encoder_apply(params[encoder_name], x, dxdt) 36 | sym_deriv_x, sym_dzdt = symder_apply(params[sym_model_name], z) 37 | if hidden_transform is not None: 38 | z_hidden = hidden_transform(z) 39 | sym_dzdt_hidden = hidden_transform(sym_dzdt) 40 | else: 41 | z_hidden = z 42 | sym_dzdt_hidden = sym_dzdt 43 | return sym_deriv_x, z_hidden, dzdt_hidden, sym_dzdt_hidden 44 | 45 | else: 46 | 47 | def model_apply(params, x): 48 | out = encoder_apply(params[encoder_name], x) 49 | z = out[0] if (isinstance(out, list) or isinstance(out, tuple)) else out 50 | sym_deriv_x = symder_apply(params[sym_model_name], z) 51 | z_hidden = hidden_transform(z) if hidden_transform is not None else z 52 | return sym_deriv_x, z_hidden 53 | 54 | return model_apply 55 | -------------------------------------------------------------------------------- /utils.py: 
-------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from jax import lax 4 | from jax.tree_util import pytree 5 | 6 | import optax 7 | 8 | import pickle 9 | from pathlib import Path 10 | from typing import Union 11 | 12 | __all__ = ["save_pytree", "load_pytree", "loss_fn", "init_optimizers"] 13 | 14 | suffix = ".pt" 15 | 16 | 17 | def save_pytree(path: Union[str, Path], data: pytree, overwrite: bool = False): 18 | path = Path(path) 19 | if path.suffix != suffix: 20 | path = path.with_suffix(suffix) 21 | path.parent.mkdir(parents=True, exist_ok=True) 22 | if path.exists(): 23 | if overwrite: 24 | path.unlink() 25 | else: 26 | raise FileExistsError(f"File {path} already exists.") 27 | with open(path, "wb") as file: 28 | pickle.dump(data, file) 29 | 30 | 31 | def load_pytree(path: Union[str, Path]) -> pytree: 32 | path = Path(path) 33 | if not path.is_file(): 34 | raise ValueError(f"Not a file: {path}") 35 | if path.suffix != suffix: 36 | raise ValueError(f"Not a {suffix} file: {path}") 37 | with open(path, "rb") as file: 38 | data = pickle.load(file) 39 | return data 40 | 41 | 42 | def loss_fn( 43 | model_apply, 44 | params, 45 | batch, 46 | target, 47 | scale=None, 48 | deriv_weight=None, 49 | reg_dzdt=None, 50 | reg_dzdt_var_norm=True, 51 | reg_l1_sparse=None, 52 | sym_model_name="sym_model", 53 | ): 54 | num_der = target.shape[-1] 55 | 56 | if scale is None: 57 | scale = jnp.ones(1, num_der + 1) 58 | if deriv_weight is None: 59 | deriv_weight = jnp.ones(num_der) 60 | 61 | if reg_dzdt is not None: 62 | x = batch[..., 0] 63 | dxdt = batch[..., 1] * scale[:, 1] / scale[:, 0] 64 | sym_deriv_x, z_hidden, dzdt_hidden, sym_dzdt_hidden = model_apply( 65 | params, x, dxdt 66 | ) 67 | else: 68 | sym_deriv_x, z_hidden = model_apply(params, batch) 69 | 70 | # scale to normed derivatives 71 | scaled_sym_deriv_x = sym_deriv_x * scale[:, [0]] / scale[:, 1:] 72 | 73 | # MSE loss 74 | mse_loss = jnp.sum( 75 | deriv_weight 76 | * jnp.mean(((target - scaled_sym_deriv_x) ** 2).reshape(-1, num_der), axis=0) 77 | ) 78 | loss_list = [mse_loss] 79 | 80 | # dz/dt regularization loss 81 | if reg_dzdt is not None: 82 | num_hidden = dzdt_hidden.shape[-1] 83 | if reg_dzdt_var_norm: 84 | reg_dzdt_loss = reg_dzdt * jnp.mean( 85 | (dzdt_hidden - sym_dzdt_hidden) ** 2 86 | / jnp.var(z_hidden.reshape(-1, num_hidden), axis=0) 87 | ) 88 | else: 89 | reg_dzdt_loss = reg_dzdt * jnp.mean((dzdt_hidden - sym_dzdt_hidden) ** 2) 90 | loss_list.append(reg_dzdt_loss) 91 | 92 | # L1 sparse regularization loss 93 | if reg_l1_sparse is not None: 94 | reg_l1_sparse_loss = reg_l1_sparse * jax.tree_util.tree_reduce( 95 | lambda x, y: x + jnp.abs(y).sum(), params[sym_model_name], 0.0 96 | ) 97 | loss_list.append(reg_l1_sparse_loss) 98 | 99 | loss = sum(loss_list) 100 | loss_list.insert(0, loss) 101 | return loss, loss_list 102 | 103 | 104 | def loss_fn_weighted( 105 | model_apply, 106 | params, 107 | batch, 108 | target, 109 | weight, 110 | scale=None, 111 | deriv_weight=None, 112 | reg_dzdt=None, 113 | reg_dzdt_var_norm=True, 114 | reg_l1_sparse=None, 115 | sym_model_name="sym_model", 116 | ): 117 | num_der = target.shape[-1] 118 | 119 | if scale is None: 120 | scale = jnp.ones(1, num_der + 1) 121 | if deriv_weight is None: 122 | deriv_weight = jnp.ones(num_der) 123 | 124 | if reg_dzdt is not None: 125 | x = batch[..., 0] 126 | dxdt = batch[..., 1] * scale[:, 1] / scale[:, 0] 127 | sym_deriv_x, z_hidden, dzdt_hidden, sym_dzdt_hidden = model_apply( 128 | params, x, 
dxdt 129 | ) 130 | else: 131 | sym_deriv_x, z_hidden = model_apply(params, batch) 132 | 133 | # scale to normed derivatives 134 | scaled_sym_deriv_x = sym_deriv_x * scale[:, [0]] / scale[:, 1:] 135 | 136 | # MSE loss 137 | mse_loss = jnp.sum( 138 | deriv_weight 139 | * jnp.mean( 140 | (weight * (target - scaled_sym_deriv_x) ** 2).reshape(-1, num_der), axis=0 141 | ) 142 | ) 143 | loss_list = [mse_loss] 144 | 145 | # dz/dt regularization loss 146 | if reg_dzdt is not None: 147 | num_hidden = dzdt_hidden.shape[-1] 148 | if reg_dzdt_var_norm: 149 | reg_dzdt_loss = reg_dzdt * jnp.mean( 150 | (dzdt_hidden - sym_dzdt_hidden) ** 2 151 | / jnp.var(z_hidden.reshape(-1, num_hidden), axis=0) 152 | ) 153 | else: 154 | reg_dzdt_loss = reg_dzdt * jnp.mean( 155 | weight[..., 0] * (dzdt_hidden - sym_dzdt_hidden) ** 2 156 | ) 157 | loss_list.append(reg_dzdt_loss) 158 | 159 | # L1 sparse regularization loss 160 | if reg_l1_sparse is not None: 161 | reg_l1_sparse_loss = reg_l1_sparse * jax.tree_util.tree_reduce( 162 | lambda x, y: x + jnp.abs(y).sum(), params[sym_model_name], 0.0 163 | ) 164 | loss_list.append(reg_l1_sparse_loss) 165 | 166 | loss = sum(loss_list) 167 | loss_list.insert(0, loss) 168 | return loss, loss_list 169 | 170 | 171 | def init_optimizers( 172 | params, 173 | optimizers, 174 | sparsify=False, 175 | multi_gpu=False, 176 | sym_model_name="sym_model", 177 | pmap_axis_name="devices", 178 | ): 179 | # Initialize optimizers 180 | opt_init, opt_update, opt_state = {}, {}, {} 181 | for name in params.keys(): 182 | opt_init[name], opt_update[name] = optimizers[name] 183 | if multi_gpu: 184 | opt_state[name] = jax.pmap(opt_init[name])(params[name]) 185 | else: 186 | opt_state[name] = opt_init[name](params[name]) 187 | 188 | # Define update function 189 | def update_params(grads, opt_state, params, sparse_mask): 190 | if sparsify: 191 | grads[sym_model_name] = jax.tree_multimap( 192 | jnp.multiply, sparse_mask, grads[sym_model_name] 193 | ) 194 | 195 | updates = {} 196 | for name in params.keys(): 197 | updates[name], opt_state[name] = opt_update[name]( 198 | grads[name], opt_state[name], params[name] 199 | ) 200 | params = optax.apply_updates(params, updates) 201 | 202 | if sparsify: 203 | params[sym_model_name] = jax.tree_multimap( 204 | jnp.multiply, sparse_mask, params[sym_model_name] 205 | ) 206 | 207 | # TODO: This may not be necessary or can at least be reduced in frequency 208 | if multi_gpu: 209 | # Ensure params, opt_state, sparse_mask are the same across all devices 210 | params = lax.pmean(params, axis_name=pmap_axis_name) 211 | opt_state, sparse_mask = lax.pmax( 212 | (opt_state, sparse_mask), axis_name=pmap_axis_name 213 | ) 214 | 215 | return params, opt_state, sparse_mask 216 | 217 | return update_params, opt_state 218 | --------------------------------------------------------------------------------
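Usage sketch for the derivative trick in symder/odeint_zero.py: dfunc(f, order) builds a list of callables [z, dz/dt, d2z/dt2, ...] evaluated at (z0, t), which is how get_symder_apply obtains higher-order time derivatives of the symbolic model. The snippet below is a minimal, self-contained sketch, assuming the repository root is on the Python path; the toy dynamics f and the numbers are purely illustrative and are not one of the systems above.

import jax.numpy as jnp
from symder.odeint_zero import dfunc

def f(z, t):
    # Illustrative linear dynamics: dz/dt = -z
    return -z

z0 = jnp.array([2.0])
t0 = 0.0

# dfunc returns [z(.), dz/dt(.), d2z/dt2(.)]; get_symder_apply drops the 0th entry
funcs = dfunc(f, 2)
z, dzdt, d2zdt2 = (g(z0, t0) for g in funcs)
# Expected: z = [2.], dzdt = [-2.], d2zdt2 = [2.]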
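The training scripts above save the best parameters and sparse mask with save_pytree but do not show how to reload them (load_pytree is imported but commented out). A minimal sketch of restoring a checkpoint and applying the sparse mask to the symbolic coefficients, assuming a run saved to the default ./rossler_run0/ output folder:

import jax
import jax.numpy as jnp
from utils import load_pytree

ckpt = load_pytree("./rossler_run0/best.pt")
best_params, sparse_mask = ckpt["params"], ckpt["sparse_mask"]

# Zero out coefficients pruned by the sparsity threshold,
# mirroring the masking applied inside init_optimizers
masked = jax.tree_multimap(jnp.multiply, sparse_mask, best_params["sym_model"])
print(masked)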