├── .gitignore ├── Models ├── hopf_nn.py ├── vanderpol_nn.py ├── ks_nn.py └── lorenz_nn.py ├── README.md ├── Systems ├── hopf.py ├── vanderpol.py ├── lorenz.py └── ks.py ├── Utils ├── Solvers.py ├── Chaos01Test.py └── Plotters.py ├── main.py ├── training_loop.py └── NODE └── NODE.py /.gitignore: -------------------------------------------------------------------------------- 1 | /venv 2 | /__pycache__ 3 | *.pyc 4 | *~ 5 | -------------------------------------------------------------------------------- /Models/hopf_nn.py: -------------------------------------------------------------------------------- 1 | from NODE.NODE import * 2 | 3 | 4 | class HopfNormalTrain(ODEF): 5 | """ 6 | neural network for learning the hopf normal form 7 | """ 8 | def __init__(self): 9 | super(HopfNormalTrain, self).__init__() 10 | self.lin = nn.Linear(3, 256) 11 | self.lin1 = nn.Linear(256, 3) 12 | self.relu = nn.ReLU() 13 | 14 | def forward(self, x): 15 | x = self.relu(self.lin(x)) 16 | x = self.lin1(x) 17 | x = x.view(-1,3) 18 | x[:, 2] = 0 19 | return x.unsqueeze(1) 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # K-NODE 2 | Knowledge-based learning of nonlinear dynamics and chaos 3 | 4 | ## Operating System 5 | Ubuntu 20.04.5 LTS 6 | 7 | ## Python Version 8 | Python 3.8.3 9 | 10 | ## Dependencies 11 | - contourpy 1.0.5 12 | - cycler 0.11.0 13 | - fonttools 4.37.4 14 | - kiwisolver 1.4.4 15 | - matplotlib 3.6.0 16 | - numpy 1.23.3 17 | - packaging 21.3 18 | - Pillow 9.2.0 19 | - pip 22.2.2 20 | - pyparsing 3.0.9 21 | - python-dateutil 2.8.2 22 | - scipy 1.9.1 23 | - setuptools 41.2.0 24 | - six 1.16.0 25 | - torch 1.12.1 26 | - typing_extensions 4.3.0 27 | 28 | ## To Run 29 | ```python main.py``` 30 | This will start learning the chaotic Lorenz system 31 | -------------------------------------------------------------------------------- /Models/vanderpol_nn.py: -------------------------------------------------------------------------------- 1 | class VanderPolTrain(ODEF): 2 | """ 3 | neural network for learning the stiff van der pol oscillator 4 | """ 5 | def __init__(self): 6 | super(VanderPolTrain, self).__init__() 7 | self.lin1 = nn.Linear(3, 512, bias=False) 8 | self.lin2 = nn.Linear(512, 32, bias=False) 9 | self.lin3 = nn.Linear(32, 2, bias=False) 10 | self.tanh = nn.Tanh() 11 | self.softplus = nn.Softplus(beta=1.8) 12 | 13 | def forward(self, t, x): 14 | if isinstance(x, np.ndarray): x = FloatTensor(x) 15 | x = x.float() 16 | x = self.tanh(self.lin1(x)) 17 | x = self.softplus(self.lin2(x)) 18 | x = self.lin3(x) 19 | 20 | x = x.view(1, -1) 21 | x = torch.cat([x, Tensor([[0]])], 1) 22 | 23 | return x -------------------------------------------------------------------------------- /Systems/hopf.py: -------------------------------------------------------------------------------- 1 | from NODE.NODE import * 2 | 3 | 4 | class HopfNormal(ODEF): 5 | """ 6 | Hopf normal form 7 | """ 8 | def __init__(self): 9 | super(HopfNormal, self).__init__() 10 | self.lin = nn.Linear(9, 3, bias=False) 11 | W = Tensor([[0, 1, 0, 1, 0, -1, -1, 0, 0], 12 | [-1, 0, 0, 0, 1, 0, 0, -1, -1], 13 | [0, 0, 0, 0, 0, 0, 0, 0, 0]]) 14 | self.lin.weight = nn.Parameter(W) 15 | 16 | def forward(self, x): 17 | y = torch.ones(1, 9) 18 | y[0][0] = x[0][0] 19 | y[0][1] = x[0][1] 20 | y[0][2] = x[0][2] 21 | y[0][3] = x[0][0] * x[0][2] 22 | y[0][4] = x[0][1] * x[0][2] 23 | y[0][5] = x[0][0] ** 3 24 | y[0][6] = 
x[0][0] * x[0][1] ** 2 25 | y[0][7] = x[0][1] ** 3 26 | y[0][8] = x[0][1] * x[0][0] ** 2 27 | y = self.lin(y) 28 | return y 29 | -------------------------------------------------------------------------------- /Systems/vanderpol.py: -------------------------------------------------------------------------------- 1 | from NODE.NODE import * 2 | 3 | class VanderPol(ODEF): 4 | """ 5 | The Van der Pol oscillator 6 | """ 7 | def __init__(self): 8 | super(VanderPol, self).__init__() 9 | self.lin = nn.Linear(4, 3, bias=False) 10 | W = Tensor([[0, 1, 0, 0], 11 | [-1, 0, 1, -1], 12 | [0, 0, 0, 0]]) 13 | self.lin.weight = nn.Parameter(W) 14 | 15 | def forward(self, t, x): 16 | try: 17 | y = torch.ones([1, 4]) 18 | y[0][0] = x[0][0] 19 | y[0][1] = x[0][1] 20 | y[0][2] = x[0][1] * x[0][2] 21 | y[0][3] = x[0][1] * x[0][0] ** 2 * x[0][2] 22 | y = self.lin(y) 23 | except: 24 | y = np.zeros(4) 25 | y[0] = x[0] 26 | y[1] = x[1] 27 | y[2] = x[1] * x[2] 28 | y[3] = x[1] * x[0] ** 2 * x[2] 29 | y = self.lin(Tensor(y)) 30 | y = y.view(1, -1) 31 | return y -------------------------------------------------------------------------------- /Utils/Solvers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | """## solvers""" 4 | 5 | def RK(z0, n_steps, f, h): 6 | ''' 7 | 4th Order Runge Kutta Numerical Solver 8 | Input: 9 | z0: initial condition 10 | t0: initial time (not actual time, but the index of time) 11 | n_steps: the number of steps to integrate 12 | f: vector field 13 | h: step size 14 | Return: 15 | z: the state after n_steps 16 | ''' 17 | z = z0 18 | for i in range(int(n_steps)): 19 | k1 = h * f(z) 20 | k2 = h * f(z + 0.5 * k1) 21 | k3 = h * f(z + 0.5 * k2) 22 | k4 = h * f(z + k3) 23 | 24 | z = z + (1.0 / 6.0)*(k1 + 2 * k2 + 2 * k3 + k4) 25 | return z 26 | 27 | 28 | def Euler(z0, n_steps, f, step_size): 29 | ''' 30 | Simplest Euler ODE initial value solver 31 | Input: 32 | z0: initial condition 33 | t0: initial time (not actual time, but the index of time) 34 | n_steps: the number of steps to integrate 35 | f: vector field 36 | h: step size 37 | Return: 38 | z: the state after n_steps 39 | ''' 40 | z = z0 41 | for i_step in range(int(n_steps)): 42 | z = z + step_size * f(z) 43 | return z 44 | -------------------------------------------------------------------------------- /Models/ks_nn.py: -------------------------------------------------------------------------------- 1 | from NODE.NODE import * 2 | 3 | 4 | class KS_conv64(ODEF): 5 | """ 6 | neural network for learning the KS equation 7 | """ 8 | def __init__(self): 9 | super(KS_conv64, self).__init__() 10 | # Encoder 11 | bias = False 12 | padding_mode = 'replicate' 13 | self.enc_conv1 = nn.Conv1d(1, 32, kernel_size=3, stride=1, padding=1, padding_mode=padding_mode, bias=bias) 14 | self.enc_conv2 = nn.Conv1d(32, 256, kernel_size=3, stride=1, padding=1, padding_mode=padding_mode, bias=bias) 15 | self.enc_conv3 = nn.Conv1d(256, 128, kernel_size=3, stride=1, padding=1, padding_mode=padding_mode, bias=bias) 16 | self.enc_conv4 = nn.Conv1d(128, 128, kernel_size=3, stride=1, padding=1, padding_mode=padding_mode, bias=bias) 17 | self.enc_conv5 = nn.Conv1d(128, 128, kernel_size=3, stride=2, padding=1, padding_mode=padding_mode, bias=bias) 18 | self.enc_conv6 = nn.Conv1d(128, 128, kernel_size=3, stride=2, padding=1, padding_mode=padding_mode, bias=bias) 19 | 20 | self.lin1 = nn.Linear(2048, 64, bias=bias) 21 | 22 | self.relu = nn.LeakyReLU(0.05) 23 | self.tanh = nn.Tanh() 24 | 25 | def forward(self, 
x, t): 26 | x = x.view(1, 1, -1) 27 | x = self.relu(self.enc_conv1(x)) 28 | x = self.relu(self.enc_conv2(x)) 29 | x = self.tanh(self.enc_conv3(x)) 30 | x = self.relu(self.enc_conv4(x)) 31 | x = self.relu(self.enc_conv5(x)) 32 | x = self.enc_conv6(x) 33 | 34 | x = x.view(-1) 35 | x = self.lin1(x) 36 | x = x.view(1, -1) 37 | return x -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from training_loop import sample_and_grow, sample_data 3 | from Models.lorenz_nn import * 4 | from Utils.Solvers import RK 5 | from Systems.lorenz import Lorenz 6 | from NODE.NODE import NeuralODE 7 | 8 | ################### Initialize Models ################## 9 | simulation_solver = RK 10 | simulation_step_size = 0.01 11 | training_solver = RK 12 | training_step_size = 0.01 13 | torch.manual_seed(0) 14 | ode_train = NeuralODE(LorenzModifiedKNODE(), training_solver, training_step_size) 15 | ode_true = NeuralODE(Lorenz(), simulation_solver, simulation_step_size) 16 | hybrid = False 17 | loss_arr = [] 18 | save_path = None 19 | 20 | ################### Generating Training Data ################### 21 | sampling_rate = simulation_step_size * 1 # seconds per instance i.e. 1/Hz, assumed to be lower than simulation rate 22 | SimICs = torch.tensor([[[-8., 7., 27.]]]) # initial condition for simulation 23 | t0 = 0 # start point (index of time) 24 | N_POINTS = 800 # Number of times the solver steps. total_time_span = N_POINTS * simulation_step_size 25 | NOISE_VAR = 0 # 0.316227766 # Variance of gaussian noise added to the observation. Assumed to be 0-mean 26 | times, obs_noiseless, t_v, x_v = sample_data(ode_true, t0, N_POINTS, SimICs, simulation_step_size, sampling_rate) 27 | torch.manual_seed(6) 28 | obs = obs_noiseless + torch.randn_like(obs_noiseless) * NOISE_VAR 29 | obs = obs.detach() # [len, 1, dim] 30 | times = times.detach() 31 | 32 | ################## Training ################### 33 | # Training Parameters 34 | EPOCHs = 2000 # No. 
of epochs to train 35 | LOOKAHEAD = 2 # lookahead 36 | name = "lookahead_" + str(LOOKAHEAD - 1) 37 | LR = 0.01 # learning rate 38 | sample_and_grow(ode_train, obs, times, EPOCHs, LR, hybrid, LOOKAHEAD, loss_arr, plot_freq=20) 39 | -------------------------------------------------------------------------------- /Systems/lorenz.py: -------------------------------------------------------------------------------- 1 | from NODE.NODE import * 2 | 3 | 4 | class Lorenz(ODEF): 5 | """ 6 | chaotic lorenz system 7 | """ 8 | def __init__(self): 9 | super(Lorenz, self).__init__() 10 | self.lin = nn.Linear(5, 3, bias=False) 11 | W = Tensor([[-10, 10, 0, 0, 0], 12 | [28, -1, 0, -1, 0], 13 | [0, 0, -8 / 3, 0, 1]]) 14 | self.lin.weight = nn.Parameter(W) 15 | 16 | def forward(self, x): 17 | bs, _, dim = x.shape 18 | y = y = torch.ones([bs, 5]) 19 | y[:, 0] = x[:, :, 0] 20 | y[:, 1] = x[:, :, 1] 21 | y[:, 2] = x[:, :, 2] 22 | y[:, 3] = x[:, :, 0] * x[:, :, 2] 23 | y[:, 4] = x[:, :, 0] * x[:, :, 1] 24 | x_dot = self.lin(y) 25 | return x_dot.view(bs, -1, dim) 26 | 27 | 28 | class LorenzLimitCycle(ODEF): 29 | """ 30 | modified lorenz system which forms a limit cycle 31 | """ 32 | def __init__(self): 33 | super(LorenzLimitCycle, self).__init__() 34 | self.lin = nn.Linear(5, 3, bias=False) 35 | W = Tensor([[-10, 10, 0, 0, 0], 36 | [-4.8, 7.2, 0, -1, 0], 37 | [0, 0, -8 / 3, 0, 1]]) 38 | self.lin.weight = nn.Parameter(W) 39 | 40 | def forward(self, x): 41 | y = y = torch.ones([1, 5]) 42 | y[0][0] = x[0][0] 43 | y[0][1] = x[0][1] 44 | y[0][2] = x[0][2] 45 | y[0][3] = x[0][0] * x[0][2] 46 | y[0][4] = x[0][0] * x[0][1] 47 | return self.lin(y) 48 | 49 | 50 | class LorenzSindy(ODEF): 51 | """ 52 | incorrectly identified lorenz system using SINDy 53 | """ 54 | def __init__(self): 55 | super(LorenzSindy, self).__init__() 56 | self.lin = nn.Linear(5, 3, bias=False) 57 | # system identified by SINDy using the correct nonlinearities 58 | W = Tensor([[-9.913, 9.913, 0, 0, 0], 59 | [27.212, -0.848, 0, -0.978, 0], 60 | [0, 0, -2.636, 0, 0.988]]) 61 | self.lin.weight = nn.Parameter(W) 62 | 63 | def forward(self, x): 64 | y = y = torch.ones([1, 5]) 65 | y[0][0] = x[0][0] 66 | y[0][1] = x[0][1] 67 | y[0][2] = x[0][2] 68 | y[0][3] = x[0][0] * x[0][2] 69 | y[0][4] = x[0][0] * x[0][1] 70 | return self.lin(y) 71 | 72 | -------------------------------------------------------------------------------- /Utils/Chaos01Test.py: -------------------------------------------------------------------------------- 1 | from scipy.optimize import curve_fit 2 | import numpy as np 3 | 4 | def plot_pq(p, q, save=None, figsize=(6, 6), title=''): 5 | """ 6 | plotting the p-q system driven by a trajectory 7 | :param p: p 8 | :param q: q 9 | :param save: path where the image gets saved 10 | :param figsize: figure size 11 | :param title: title of the plot 12 | """ 13 | mpl.rcParams.update({'font.size': 22}) 14 | fig = plt.figure(figsize=figsize) 15 | ax = fig.add_subplot(1, 1, 1) 16 | fig.set_facecolor('white') 17 | ax.set_facecolor('white') 18 | ax.set_title(title) 19 | ax.set_xlabel("$p_c$", labelpad=2) 20 | ax.set_ylabel("$q_c$", labelpad=2) 21 | 22 | for i in range(len(p)): 23 | ax.plot(p[i:i + 10], q[i:i + 10], color=plt.cm.jet(i / len(p) / 1.6), linewidth=2) 24 | 25 | if save is not None: 26 | plt.savefig(save + '.png', format='png', dpi=600, bbox_inches='tight', pad_inches=0) 27 | pass 28 | plt.show() 29 | 30 | def compute_pq(traj): 31 | """ 32 | Computing p and q for a 1D trajectory 33 | :param traj: 1D trajectory 34 | :return: numpy arrays of p and 
q 35 | """ 36 | c = 0.4 37 | p_list = [] 38 | q_list = [] 39 | p = traj[0]*np.cos(c) 40 | q = traj[0]*np.sin(c) 41 | p_list.append(p) 42 | q_list.append(q) 43 | for n in range(len(traj)-1): 44 | p = p + traj[n] * np.cos((n+1)*c) 45 | q = q + traj[n] * np.sin((n+1)*c) 46 | p_list.append(p) 47 | q_list.append(q) 48 | return np.array(p_list), np.array(q_list) 49 | 50 | def compute_M(p, q): 51 | """ 52 | Compute the mean square displacement of the 2D p-q system 53 | :param p: numpy array of p 54 | :param q: numpy array of q 55 | :return: numpy array of the mean square displacement 56 | """ 57 | M_list = [] 58 | n_cut = int(len(p)/10) 59 | N = len(p)-n_cut 60 | for n in range(n_cut): 61 | M = np.mean([(p[j+n] - p[j])**2 + (q[j+n] - q[j])**2 for j in range(N)]) 62 | M_list.append(M) 63 | print("Size of M:\n", len(M_list)) 64 | return np.array(M_list) 65 | 66 | def compute_Kc(pred): 67 | """ 68 | Computing Kc of the 0-1 test using the first dimension of the state 69 | :param pred: predicted trajectory of type torch tensor and size [len, bs, dim] 70 | :return: Kc 71 | """ 72 | dim = pred.size()[-1] 73 | data_for_test = pred[:][::5] 74 | traj = data_for_test.detach().numpy().reshape([-1,dim])[:,0] #only taking the first dimension 75 | p_traj, q_traj = compute_pq(traj) 76 | M = compute_M(p_traj, q_traj) 77 | def test(x, m, c): 78 | return m*x+c 79 | ns = np.log(np.arange(1,len(M)+1)) 80 | log_M = np.log(M+1) 81 | param, param_cov = curve_fit(test, ns, log_M) 82 | print("Kc is", param[0]) 83 | return param[0] -------------------------------------------------------------------------------- /Utils/Plotters.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from mpl_toolkits.mplot3d import Axes3D 3 | import numpy as np 4 | import time 5 | 6 | def plot_trajectories(fig, obs=None, noiseless_traj=None,times=None, trajs=None, save=None, title=''): 7 | plt.ion() 8 | ax = fig.add_subplot(1, 1, 1, projection='3d') 9 | if title is not None: 10 | ax.set_title('True Trajectory and Predicted Trajectory\n'+title) 11 | 12 | if noiseless_traj is not None: 13 | z = np.array([o.detach().numpy() for o in noiseless_traj]) 14 | z = np.reshape(z, [-1,3]) 15 | for i in range(len(z)): 16 | ax.plot(z[i:i+10, 0], z[i:i+10, 1], z[i:i+10, 2], color=plt.cm.jet(i/len(z)/1.6)) 17 | 18 | if obs is not None: 19 | z = np.array([o.detach().numpy() for o in obs]) 20 | z = np.reshape(z, [-1,3]) 21 | ax.scatter(z[:,0], z[:,1], z[:,2], marker='.', color='k', alpha=0.5, linewidths=0, s=45) 22 | 23 | if trajs is not None: 24 | z = np.array([o.detach().numpy() for o in trajs]) 25 | z = np.reshape(z, [-1,3]) 26 | for i in range(len(z)): 27 | ax.plot(z[i:i+10, 0], z[i:i+10, 1], z[i:i+10, 2], color='r', alpha=0.3) 28 | 29 | fig.canvas.draw() 30 | fig.canvas.flush_events() 31 | # time.sleep(0.1) 32 | plt.show() 33 | if save is not None: 34 | plt.savefig(save+'.png', format='png', dpi=400, bbox_inches ='tight', pad_inches = 0.1) 35 | pass 36 | 37 | 38 | ############## VISUALIZING KS ################ 39 | def make_color_map(DATA_TO_PLOT, title="Untitled", slice=False, save=None, figure_size=(10,3), save_eps=False): 40 | nsteps, ngrids,*a = DATA_TO_PLOT.shape 41 | offset = 0 # number of grids (starting from 0th) not to plot 42 | 43 | # creating mesh 44 | x = np.linspace(0,nsteps-1, num=nsteps)*0.25*0.089 45 | y = np.linspace(0,ngrids-offset-1, num=ngrids-offset) 46 | X, Y = np.meshgrid(x,y) 47 | 48 | # plotting color grid 49 | fig = plt.figure(figsize=figure_size) 50 | ax = 
plt.axes() 51 | c = ax.pcolormesh(X, Y, DATA_TO_PLOT.T.reshape(ngrids,nsteps)[offset:ngrids,:], cmap="jet") 52 | ax.set_title(title) 53 | ax.set_xlabel('$\Lambda_{max} t$') # adding axes labels changes the appearance of the color map 54 | ax.set_ylabel('space') 55 | fig.tight_layout() 56 | fig.colorbar(c) 57 | if save is not None: 58 | plt.savefig(save+'.png', format='png', dpi=600, bbox_inches ='tight', pad_inches = 0) 59 | if save_eps: 60 | plt.savefig(save+'.eps', format='eps', dpi=600, bbox_inches ='tight', pad_inches = 0) 61 | pass 62 | 63 | if slice: 64 | # plotting one single trajectory 65 | fig = plt.figure(figsize=figure_size) 66 | ax = fig.add_subplot(1, 1, 1) 67 | #for i in range(12): 68 | #i = i*10 69 | ax.plot(DATA_TO_PLOT[0,:]) 70 | ax.plot(DATA_TO_PLOT[1,:]) 71 | ax.set_title(title+'\nSingle Trajectory') 72 | ax.set_xlabel('$\Lambda_{max} t$') 73 | ax.set_ylabel('state') 74 | fig.tight_layout() 75 | -------------------------------------------------------------------------------- /Models/lorenz_nn.py: -------------------------------------------------------------------------------- 1 | from NODE.NODE import * 2 | 3 | 4 | class LorenzTrain(ODEF): 5 | """ 6 | neural network for learning the chaotic lorenz system 7 | """ 8 | def __init__(self): 9 | super(LorenzTrain, self).__init__() 10 | self.lin = nn.Linear(3, 256) 11 | self.lin3 = nn.Linear(256, 3) 12 | self.relu = nn.ReLU() 13 | 14 | def forward(self, x): 15 | x = self.relu(self.lin(x)) 16 | x = self.lin3(x) 17 | return x 18 | 19 | 20 | class LorenzSindyKNODE(ODEF): 21 | """ 22 | KNODE combining incorrectly SINDy-identified lorenz system and a neural network 23 | """ 24 | def __init__(self): 25 | super(LorenzSindyKNODE, self).__init__() 26 | self.lin_im = nn.Linear(6, 3, bias=False) 27 | # xy and xz are excluded from the library of functions 28 | self.W = Tensor([[-9.913, 9.913, 0, 0, 0, 0], 29 | [-7.175, 20.507, 0, -0.613, 0, 0], 30 | [0, 0, -3.05, 0, 0.504, 0.479]]) 31 | 32 | self.lin_im.weight = nn.Parameter(self.W) 33 | 34 | self.lin1 = nn.Linear(3, 32) 35 | self.lin2 = nn.Linear(32, 512) 36 | self.lin3 = nn.Linear(512, 32) 37 | self.lin4 = nn.Linear(32, 3) 38 | 39 | self.Mout = nn.Linear(6, 3) 40 | 41 | self.relu = nn.ReLU() 42 | 43 | def forward(self, x): 44 | x = x.view(-1, 1, 3) 45 | bs, _, _ = x.size() 46 | y = torch.zeros([bs, 1, 6]) 47 | y[:, :, 0] = x[:, :, 0] 48 | y[:, :, 1] = x[:, :, 1] 49 | y[:, :, 2] = x[:, :, 2] 50 | y[:, :, 3] = x[:, :, 1] * x[:, :, 2] 51 | y[:, :, 4] = x[:, :, 0] ** 2 52 | y[:, :, 5] = x[:, :, 1] ** 2 53 | 54 | y = self.lin_im(y) 55 | x = self.relu(self.lin1(x)) 56 | x = self.relu(self.lin2(x)) 57 | x = self.relu(self.lin3(x)) 58 | x = self.lin4(x) 59 | 60 | x = self.Mout(torch.cat([x, y], 2)) 61 | return x 62 | 63 | 64 | class LorenzModifiedKNODE(ODEF): 65 | """ 66 | KNODE combining limit cycle lorenz system and a neural network 67 | """ 68 | def __init__(self): 69 | super(LorenzModifiedKNODE, self).__init__() 70 | # imperfect model 71 | self.lin_im = nn.Linear(5, 3, bias=False) 72 | # Using inaccurate coefficients 73 | self.W = Tensor([[-10, 10, 0, 0, 0], 74 | [-4.8, 7.2, 0, -1, 0], 75 | [0, 0, -8 / 3, 0, 1]]) 76 | self.lin_im.weight = nn.Parameter(self.W) 77 | 78 | # neural network 79 | self.lin1 = nn.Linear(3, 32) 80 | self.lin2 = nn.Linear(32, 512) 81 | self.lin3 = nn.Linear(512, 32) 82 | self.lin4 = nn.Linear(32, 3) 83 | self.Mout = nn.Linear(6, 3) 84 | self.relu = nn.ReLU() 85 | 86 | def forward(self, x): 87 | x = x.view(-1, 1, 3) 88 | bs, _, _ = x.size() 89 | y = torch.zeros([bs, 1, 
5]) 90 | y[:, :, 0] = x[:, :, 0] 91 | y[:, :, 1] = x[:, :, 1] 92 | y[:, :, 2] = x[:, :, 2] 93 | y[:, :, 3] = x[:, :, 0] * x[:, :, 2] 94 | y[:, :, 4] = x[:, :, 0] * x[:, :, 1] 95 | y = self.lin_im(y) 96 | 97 | x = self.relu(self.lin1(x)) 98 | x = self.relu(self.lin2(x)) 99 | x = self.relu(self.lin3(x)) 100 | x = self.lin4(x) 101 | x = self.Mout(torch.cat([x, y], 2)) 102 | return x 103 | -------------------------------------------------------------------------------- /Systems/ks.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from scipy.fft import fft, ifft 3 | from Utils.Plotters import * 4 | """ 5 | Simulating the Kuramoto-Sivashinsky equation 6 | """ 7 | class KSSimulation(object): 8 | def __init__(self): 9 | ############## simulation parameters ############### 10 | self.n_steps = 5000 # total number of steps to simulate 11 | self.transient = 1000 # number of steps to discard as transient state 12 | self.d = 60 # periodicity length 13 | self.tau = 0.25 # time step 14 | self.N = 64 # number of grid points 15 | self.const = 0 # error value 16 | self.E, self.E2, self.Q, self.f1, self.f2, self.f3, self.g = self.precompute_KS() 17 | 18 | def KS_pred(self, u, param): 19 | v = fft(u) 20 | vv = np.zeros((self.N, self.n_steps), dtype=complex) 21 | vv[:, 0] = v 22 | for i in range(self.n_steps): 23 | Nv = self.g * fft(np.real(ifft(v)) ** 2) 24 | a = self.E2 * v + self.Q * Nv 25 | Na = self.g * fft(np.real(ifft(a)) ** 2) 26 | b = self.E2 * v + self.Q * Na 27 | Nb = self.g * fft(np.real(ifft(b)) ** 2) 28 | c = self.E2 * a + self.Q * (2 * Nb - Nv) 29 | Nc = self.g * fft(np.real(ifft(c)) ** 2) 30 | v = self.E * v + Nv * self.f1 + 2 * (Na + Nb) * self.f2 + Nc * self.f3 31 | vv[:, i] = v 32 | uu = np.real(vv) 33 | return uu 34 | 35 | def KS_forecast(self, u): 36 | v = fft(u) 37 | Nv = self.g * fft(np.real(ifft(v)) ** 2) 38 | a = self.E2 * v + self.Q * Nv 39 | Na = self.g * fft(np.real(ifft(a)) ** 2) 40 | b = self.E2 * v + self.Q * Na 41 | Nb = self.g * fft(np.real(ifft(b)) ** 2) 42 | c = self.E2 * a + self.Q * (2 * Nb - Nv) 43 | Nc = self.g * fft(np.real(ifft(c)) ** 2) 44 | v = self.E * v + Nv * self.f1 + 2 * (Na + Nb) * self.f2 + Nc * self.f3 45 | u = np.real(ifft(v)) 46 | return u 47 | 48 | def precompute_KS(self): 49 | k = np.concatenate((np.arange(0, self.N / 2), 50 | np.array([0]), 51 | np.arange(-self.N / 2 + 1, 0))).T * 2 * np.pi / self.d # wave numbers 52 | L = (1 + self.const) * k ** 2 - k ** 4 # Fourier multiplier 53 | E = np.exp(self.tau * L) 54 | E2 = np.exp(self.tau * L / 2) 55 | M = 16 # number of points for complex means 56 | r = np.exp(1j * np.pi * (np.arange(1, M + 1) - 0.5) / M) # roots of unity 57 | LR = self.tau * np.tile(L, (M, 1)).T + np.tile(r, (self.N, 1)) 58 | Q = self.tau * np.real(np.mean((np.exp(LR / 2) - 1) / LR, axis=1)) 59 | f1 = self.tau * np.real(np.mean((-4 - LR + np.exp(LR) * (4 - 3 * LR + LR ** 2)) / LR ** 3, axis=1)) 60 | f2 = self.tau * np.real(np.mean((2 + LR + np.exp(LR) * (-2 + LR)) / LR ** 3, axis=1)) 61 | f3 = self.tau * np.real(np.mean((-4 - 3 * LR - LR ** 2 + np.exp(LR) * (4 - LR)) / LR ** 3, axis=1)) 62 | g = -0.5j * k 63 | E = E.reshape(1, -1) 64 | E2 = E2.reshape(1, -1) 65 | Q = Q.reshape(1, -1) 66 | f1 = f1.reshape(1, -1) 67 | f2 = f2.reshape(1, -1) 68 | f3 = f3.reshape(1, -1) 69 | g = g.reshape(1, -1) 70 | return E, E2, Q, f1, f2, f3, g 71 | 72 | def generate_KS_data(self): 73 | np.random.seed(3) # seed 3 for data 74 | x = 10 * (-1 + 2 * np.random.rand(1, self.N)) 75 | data = [] 76 | for i in 
range(self.n_steps): 77 | x = self.KS_forecast(x) 78 | data.append(x) 79 | 80 | # discarding transient data 81 | truncated_data = np.array(data)[self.transient:].reshape([self.n_steps - self.transient, -1, 1]) 82 | return truncated_data 83 | 84 | KSsim = KSSimulation() 85 | KS_64 = KSsim.generate_KS_data() 86 | make_color_map(KS_64, figure_size=(15, 3), title="Kuramoto-Sivashinsky Equation 64 Grids") 87 | plt.show() -------------------------------------------------------------------------------- /training_loop.py: -------------------------------------------------------------------------------- 1 | from torch.nn import functional as F 2 | from Systems.lorenz import * 3 | from Utils.Plotters import * 4 | from Utils.Chaos01Test import * 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | def sample_data(func, t_0, n_points, ics, sim_h, sampling_r): 9 | """ 10 | Samples simulated trajectories to generate training data 11 | 12 | :param func: the function whose dynamics is to be simulated 13 | :param t_0: time stamp of the initial condition 14 | :param n_points: the total number of steps to simulate 15 | :param ics: initial conditions 16 | :param sim_h: simulation step size 17 | :param sampling_r: sampling rate 18 | :return: sampled time stamps (times) and trajectory (obs) 19 | :return: simulated time stamps (t) and trajectory (x) 20 | """ 21 | t = torch.from_numpy(np.arange(t_0, t_0 + n_points, 1)).to(ics) 22 | x = func(ics, t, return_whole_sequence=True) 23 | ratio = int(sampling_r / sim_h) 24 | obs = x[0::ratio] 25 | times = t[0::ratio] 26 | return times, obs.squeeze(1), t, x.squeeze(1) 27 | 28 | 29 | def sample_and_grow(ode_train, true_sampled_traj, true_sampled_times, epochs, 30 | lr, hybrid, lookahead, loss_arr, plot_freq=50, save_path=None): 31 | """ 32 | The main training loop 33 | 34 | :param ode_train: the ode to be trained 35 | :param true_sampled_traj: sampled observations (training data) 36 | :param true_sampled_times: sampled time stamps 37 | :param epochs: the total number of epochs to train 38 | :param lookahead: lookahead 39 | :param loss_arr: array where the training losses are stored 40 | :param plot_freq: frequency of which the trajectories are plotted 41 | :return: None 42 | """ 43 | plot_title = "Epoch: {0} Loss: {1:.3e} Sim Step: {2} \n No. 
of Points: {3} Lookahead: {4} LR: {5}" 44 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, ode_train.parameters()), lr=lr) 45 | n_segments = len(true_sampled_traj) 46 | fig = plt.figure() 47 | 48 | for i in range(epochs): 49 | # Train Neural ODE 50 | true_segments_list = [] 51 | all_init = [] 52 | for j in range(0, n_segments - lookahead + 1, 1): 53 | true_sampled_segment = true_sampled_traj[j:j + lookahead] 54 | true_segments_list.append(true_sampled_segment) 55 | 56 | all_init = true_sampled_traj[:n_segments-lookahead+1] # the initial condition for each segment 57 | true_sampled_time_segment = torch.tensor(np.arange(lookahead)) # the times step to predict 58 | 59 | # predicting 60 | z_ = ode_train(all_init, true_sampled_time_segment, return_whole_sequence=True) 61 | z_ = z_.view(-1, 3) 62 | obs_ = torch.cat(true_segments_list, 1) 63 | obs_ = obs_.view(-1, 3) 64 | 65 | # computing loss 66 | loss = F.mse_loss(z_, obs_) 67 | loss_arr.append(loss.item()) 68 | optimizer.zero_grad() 69 | loss.backward(retain_graph=True) 70 | if hybrid: 71 | ode_train.func.lin_im.weight.grad *= 0 72 | 73 | optimizer.step() 74 | 75 | if i % plot_freq == 0: 76 | # saving model 77 | if save_path is not None: 78 | CHECKPOINT_PATH = save_path + "Lorenz_" + name + ".pth" 79 | torch.save({'ode_train': ode_train, 'ode_true': ode_true, 'loss_arr': loss_arr}, 80 | CHECKPOINT_PATH) 81 | 82 | # computing trajectory using the current model 83 | z_p = ode_train(true_sampled_traj[0], true_sampled_times, return_whole_sequence=True) 84 | # plotting 85 | plot_trajectories(fig, obs=[true_sampled_traj], noiseless_traj=[true_sampled_traj], 86 | times=[true_sampled_times], trajs=[z_p[:int(true_sampled_times[-1])]], 87 | save=None, title=plot_title.format(i, loss.item(), ode_train.STEP_SIZE, n_segments, 88 | lookahead - 1, lr)) 89 | print(plot_title.format(i, loss.item(), ode_train.STEP_SIZE, n_segments, 90 | lookahead - 1, lr)) 91 | -------------------------------------------------------------------------------- /NODE/NODE.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch import nn 6 | from torch.nn import functional as F 7 | from torch.autograd import Variable 8 | 9 | 10 | class ODEF(nn.Module): 11 | def forward_with_grad(self, z, grad_outputs): 12 | """Compute f and a df/dz, a df/dp, a df/dt""" 13 | batch_size = z.shape[0] 14 | out = self.forward(z) 15 | 16 | a = grad_outputs 17 | adfdz, *adfdp = torch.autograd.grad( 18 | # concatenating tuples 19 | (out,), (z,) + tuple(self.parameters()), grad_outputs=(a), 20 | allow_unused=True, retain_graph=True 21 | ) 22 | # grad method automatically sums gradients for batch items, we have to expand them back 23 | if adfdp is not None: 24 | adfdp = torch.cat([p_grad.flatten() for p_grad in adfdp]).unsqueeze(0) # unsqueeze(0) add dimension 1 to the position 0 25 | adfdp = adfdp.expand(batch_size, -1) / batch_size # passing -1 does not change dimension in that position 26 | return out, adfdz, adfdp 27 | 28 | def flatten_parameters(self): 29 | p_shapes = [] 30 | flat_parameters = [] 31 | for p in self.parameters(): 32 | p_shapes.append(p.size()) 33 | flat_parameters.append(p.flatten()) 34 | return torch.cat(flat_parameters) 35 | 36 | class ODEAdjoint(torch.autograd.Function): 37 | @staticmethod 38 | def forward(ctx, z0, t, flat_parameters, func, ode_solve, STEP_SIZE): 39 | assert isinstance(func, ODEF) 40 | bs, *z_shape = z0.size() 41 | time_len = t.size(0) 42 
| 43 | with torch.no_grad(): 44 | # initialize z to len of time and type of z0 45 | z = torch.zeros(time_len, bs, *z_shape).to(z0) 46 | z[0] = z0 47 | # solving throughout time 48 | for i_t in range(time_len - 1): 49 | # z0 updated to next step 50 | z0 = ode_solve(z0, torch.abs(t[i_t+1]-t[i_t]), func, STEP_SIZE) 51 | z[i_t+1] = z0 52 | 53 | ctx.func = func 54 | ctx.save_for_backward(t, z.clone(), flat_parameters) 55 | ctx.ode_solve = ode_solve 56 | ctx.STEP_SIZE = STEP_SIZE 57 | return z 58 | 59 | @staticmethod 60 | def backward(ctx, dLdz): 61 | """ 62 | dLdz shape: time_len, batch_size, *z_shape 63 | """ 64 | func = ctx.func 65 | t, z, flat_parameters = ctx.saved_tensors 66 | time_len, bs, *z_shape = z.size() 67 | n_dim = np.prod(z_shape) 68 | n_params = flat_parameters.size(0) 69 | ode_solve = ctx.ode_solve 70 | STEP_SIZE = ctx.STEP_SIZE 71 | 72 | # Dynamics of augmented system to be calculated backwards in time 73 | def augmented_dynamics(aug_z_i): 74 | """ 75 | tensors here are temporal slices 76 | t_i - is tensor with size: bs, 1 77 | aug_z_i - is tensor with size: bs, n_dim*2 + n_params + 1 78 | """ 79 | z_i, a = aug_z_i[:, :n_dim], aug_z_i[:, n_dim:2*n_dim] # ignore parameters and time 80 | 81 | # Unflatten z and a 82 | z_i = z_i.view(bs, *z_shape) 83 | a = a.view(bs, *z_shape) 84 | with torch.set_grad_enabled(True): 85 | z_i = z_i.detach().requires_grad_(True) 86 | func_eval, adfdz, adfdp = func.forward_with_grad(z_i, grad_outputs=a) # bs, *z_shape 87 | adfdz = adfdz.to(z_i) if adfdz is not None else torch.zeros(bs, *z_shape).to(z_i) 88 | adfdp = adfdp.to(z_i) if adfdp is not None else torch.zeros(bs, n_params).to(z_i) 89 | 90 | # Flatten f and adfdz 91 | func_eval = func_eval.view(bs, n_dim) 92 | adfdz = adfdz.view(bs, n_dim) 93 | return torch.cat((func_eval, -adfdz, -adfdp), dim=1) 94 | 95 | dLdz = dLdz.view(time_len, bs, n_dim) # flatten dLdz for convenience 96 | with torch.no_grad(): 97 | ## Create placeholders for output gradients 98 | # Prev computed backwards adjoints to be adjusted by direct gradients 99 | adj_z = torch.zeros(bs, n_dim).to(dLdz) 100 | adj_p = torch.zeros(bs, n_params).to(dLdz) 101 | # In contrast to z and p we need to return gradients for all times 102 | adj_t = torch.zeros(time_len, bs, 1).to(dLdz) 103 | 104 | for i_t in range(time_len-1, 0, -1): 105 | z_i = z[i_t] 106 | t_i = t[i_t] 107 | f_i = func(z_i).view(bs, n_dim) 108 | # Compute direct gradients 109 | dLdz_i = dLdz[i_t] 110 | 111 | # Adjusting adjoints with direct gradients 112 | adj_z += dLdz_i 113 | 114 | # Pack augmented variable 115 | aug_z = torch.cat((z_i.view(bs, n_dim), adj_z, torch.zeros(bs, n_params).to(z)), dim=-1) 116 | 117 | # Solve augmented system backwards 118 | aug_ans = ode_solve(aug_z, torch.abs(t_i-t[i_t-1]), augmented_dynamics, -STEP_SIZE) 119 | 120 | # Unpack solved backwards augmented system 121 | adj_z[:] = aug_ans[:, n_dim:2*n_dim] 122 | adj_p[:] += aug_ans[:, 2*n_dim:2*n_dim + n_params] 123 | 124 | del aug_z, aug_ans 125 | 126 | ## Adjust 0 time adjoint with direct gradients 127 | # Compute direct gradients 128 | dLdz_0 = dLdz[0] 129 | 130 | # Adjust adjoints 131 | adj_z += dLdz_0 132 | #print("\nreturned grad:\n", adj_p) 133 | return adj_z.view(bs, *z_shape), None, adj_p, None, None, None 134 | 135 | class NeuralODE(nn.Module): 136 | def __init__(self, func, ode_solve, STEP_SIZE): 137 | super(NeuralODE, self).__init__() 138 | assert isinstance(func, ODEF) 139 | self.func = func 140 | self.ode_solve = ode_solve 141 | self.STEP_SIZE = STEP_SIZE 142 | 143 | def forward(self, 
z0, t=Tensor([0., 1.]), return_whole_sequence=False): 144 | t = t.to(z0) 145 | z = ODEAdjoint.apply(z0, t, self.func.flatten_parameters(), self.func, self.ode_solve, self.STEP_SIZE) 146 | if return_whole_sequence: 147 | return z 148 | else: 149 | return z[-1] 150 | 151 | --------------------------------------------------------------------------------
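The fixed-step integrators in Utils/Solvers.py advance a state z through n_steps increments of size h under a vector field f(z); the same RK4 update is used both for simulating the true systems and inside the adjoint solver. The sketches below are illustrations only, not files from the repository; this first one applies the two schemes to the classical Lorenz right-hand side in plain NumPy, with lorenz_rhs, rk4 and euler as illustrative names introduced here.

```python
import numpy as np

def lorenz_rhs(z, sigma=10.0, rho=28.0, beta=8.0 / 3.0):
    # Classic Lorenz vector field, matching the weights hard-coded in Systems/lorenz.py
    x, y, w = z
    return np.array([sigma * (y - x), x * (rho - w) - y, x * y - beta * w])

def rk4(z0, n_steps, f, h):
    # Same update rule as RK() in Utils/Solvers.py, written for NumPy arrays
    z = np.asarray(z0, dtype=float)
    for _ in range(int(n_steps)):
        k1 = h * f(z)
        k2 = h * f(z + 0.5 * k1)
        k3 = h * f(z + 0.5 * k2)
        k4 = h * f(z + k3)
        z = z + (k1 + 2 * k2 + 2 * k3 + k4) / 6.0
    return z

def euler(z0, n_steps, f, h):
    # Same update rule as Euler() in Utils/Solvers.py
    z = np.asarray(z0, dtype=float)
    for _ in range(int(n_steps)):
        z = z + h * f(z)
    return z

if __name__ == "__main__":
    z0 = np.array([-8.0, 7.0, 27.0])   # same initial condition as main.py
    print("RK4  :", rk4(z0, 100, lorenz_rhs, 0.01))
    print("Euler:", euler(z0, 100, lorenz_rhs, 0.01))
```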
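The classes in Systems/ (Lorenz, VanderPol, HopfNormal) all follow the same pattern: build a small library of monomial features from the state and push it through a fixed, hand-set linear layer so that the layer reproduces the analytic right-hand side. A minimal sketch of that idea with the Lorenz weights; lorenz_features is an illustrative helper, not repository code.

```python
import torch
from torch import nn, Tensor

# Hand-set linear layer over the feature library [x, y, z, x*z, x*y],
# mirroring the weight matrix W in Systems/lorenz.py.
lin = nn.Linear(5, 3, bias=False)
lin.weight = nn.Parameter(Tensor([[-10, 10, 0, 0, 0],
                                  [28, -1, 0, -1, 0],
                                  [0, 0, -8 / 3, 0, 1]]))

def lorenz_features(state):
    # Illustrative helper: monomial features of a 3-vector state.
    x, y, z = state
    return torch.stack([x, y, z, x * z, x * y])

state = torch.tensor([-8.0, 7.0, 27.0])
with torch.no_grad():
    x_dot = lin(lorenz_features(state))

# Analytic Lorenz derivative for comparison: the two prints should agree.
sigma, rho, beta = 10.0, 28.0, 8.0 / 3.0
x, y, z = state
expected = torch.stack([sigma * (y - x), x * (rho - z) - y, x * y - beta * z])
print(x_dot)
print(expected)
```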
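The KNODE models in Models/lorenz_nn.py combine an imperfect physics-based derivative (lin_im, e.g. the limit-cycle Lorenz coefficients) with a small neural correction and mix the two through a trainable output layer Mout. A compact sketch of that composition pattern with the same weights; TinyKNODE is an illustrative class, not the repository's LorenzModifiedKNODE.

```python
import torch
from torch import nn, Tensor

class TinyKNODE(nn.Module):
    """Sketch of the hybrid pattern: frozen imperfect physics + NN correction,
    mixed by a trainable output layer (cf. LorenzModifiedKNODE)."""
    def __init__(self):
        super().__init__()
        self.lin_im = nn.Linear(5, 3, bias=False)   # imperfect (limit-cycle) Lorenz model
        self.lin_im.weight = nn.Parameter(Tensor([[-10, 10, 0, 0, 0],
                                                  [-4.8, 7.2, 0, -1, 0],
                                                  [0, 0, -8 / 3, 0, 1]]),
                                          requires_grad=False)
        self.net = nn.Sequential(nn.Linear(3, 32), nn.ReLU(), nn.Linear(32, 3))
        self.mix = nn.Linear(6, 3)                  # learns how to combine both predictions

    def forward(self, x):                           # x: [batch, 3]
        feats = torch.stack([x[:, 0], x[:, 1], x[:, 2],
                             x[:, 0] * x[:, 2], x[:, 0] * x[:, 1]], dim=1)
        return self.mix(torch.cat([self.net(x), self.lin_im(feats)], dim=1))

model = TinyKNODE()
print(model(torch.tensor([[-8.0, 7.0, 27.0]])).shape)   # torch.Size([1, 3])
```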
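main.py generates training data by wrapping the true system in NeuralODE and rolling it forward over integer time indices (elapsed time = index × step size), which is what sample_data does internally. A usage sketch, assuming it is run from the repository root so the packages import as in main.py.

```python
import torch
from NODE.NODE import NeuralODE
from Systems.lorenz import Lorenz
from Utils.Solvers import RK

ode_true = NeuralODE(Lorenz(), RK, STEP_SIZE=0.01)
z0 = torch.tensor([[[-8.0, 7.0, 27.0]]])       # shape [batch, 1, dim], as in main.py
t = torch.arange(0, 100, dtype=torch.float32)  # integer time indices
traj = ode_true(z0, t, return_whole_sequence=True)
print(traj.shape)                              # torch.Size([100, 1, 1, 3])
```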
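Inside sample_and_grow (training_loop.py), the sampled trajectory is cut into overlapping windows of length `lookahead`; the first sample of each window is the initial condition, the whole batch of windows is integrated at once, and the MSE loss compares the rollout to the windows. A NumPy-only sketch of just that slicing; make_segments is an illustrative helper, not repository code.

```python
import numpy as np

def make_segments(traj, lookahead):
    """Every window of `lookahead` consecutive samples becomes a training
    segment; the first sample of each window is the initial condition."""
    n = len(traj)
    segments = np.stack([traj[j:j + lookahead] for j in range(n - lookahead + 1)])
    init_conds = traj[:n - lookahead + 1]
    return init_conds, segments

traj = np.arange(10.0).reshape(-1, 1)   # stand-in for a sampled 1-D trajectory
ics, segs = make_segments(traj, lookahead=2)
print(ics.shape, segs.shape)            # (9, 1) (9, 2, 1)
```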
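Utils/Chaos01Test.py implements the 0-1 test for chaos: the trajectory drives translation variables (p, q), the mean-square displacement M(n) of that 2-D walk is computed, and the growth rate Kc comes from a linear fit on log-log axes (Kc near 1 suggests chaos, near 0 regular dynamics). A standalone sketch of the same pipeline using np.polyfit instead of scipy's curve_fit; chaos01_K is an illustrative function and the two test signals are stand-ins.

```python
import numpy as np

def chaos01_K(traj, c=0.4):
    # Translation variables p, q driven by the trajectory (cf. compute_pq)
    n = np.arange(1, len(traj) + 1)
    p = np.cumsum(traj * np.cos(n * c))
    q = np.cumsum(traj * np.sin(n * c))
    # Mean-square displacement of the (p, q) walk (cf. compute_M)
    n_cut = len(traj) // 10
    N = len(traj) - n_cut
    M = np.array([np.mean((p[j:j + N] - p[:N]) ** 2 + (q[j:j + N] - q[:N]) ** 2)
                  for j in range(1, n_cut + 1)])
    # Growth rate K from a log-log linear fit (cf. compute_Kc)
    slope, _ = np.polyfit(np.log(np.arange(1, n_cut + 1)), np.log(M + 1), 1)
    return slope

np.random.seed(0)
t = np.arange(2000)
print("periodic signal:", chaos01_K(np.cos(0.01 * t)))       # expected close to 0
print("noisy signal   :", chaos01_K(np.random.randn(2000)))  # expected close to 1
```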
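make_color_map in Utils/Plotters.py renders any [n_steps, n_grids] array as a space-time colour map, which Systems/ks.py uses to visualise the Kuramoto-Sivashinsky field. A usage sketch with random stand-in data, assuming the repository root is on PYTHONPATH.

```python
import numpy as np
import matplotlib.pyplot as plt
from Utils.Plotters import make_color_map   # assumes the repository root is importable

np.random.seed(0)
fake_ks = np.random.randn(400, 64)          # stand-in for a KS trajectory [n_steps, n_grids]
make_color_map(fake_ks, title="Random stand-in data", figure_size=(10, 3))
plt.show()
```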