├── Human motion ├── HM_GP_NODE.py ├── HM_GP_NODE_MCMC_convergence.py ├── data │ ├── A_W.npy │ ├── A_W_2.npy │ ├── A_hyp.npy │ ├── A_hyp_2.npy │ ├── A_noise.npy │ ├── A_noise_2.npy │ ├── A_par.npy │ ├── A_par_2.npy │ ├── B_W.npy │ ├── B_W_2.npy │ ├── B_hyp.npy │ ├── B_hyp_2.npy │ ├── B_noise.npy │ ├── B_noise_2.npy │ ├── B_par.npy │ ├── B_par_2.npy │ ├── X.mat │ ├── Y.mat │ ├── u_pca.mat │ └── v_pca.mat ├── npODE.zip └── plots │ ├── A_autocorrelation_1.png │ ├── A_autocorrelation_2.png │ ├── A_box_plot.png │ ├── A_geweke.png │ ├── A_x_1.png │ ├── A_x_2.png │ ├── A_x_3.png │ ├── A_y_27.png │ ├── A_y_34.png │ ├── A_y_37.png │ ├── A_y_39.png │ ├── A_y_42.png │ ├── A_y_48.png │ ├── B_autocorrelation_1.png │ ├── B_autocorrelation_2.png │ ├── B_box_plot.png │ ├── B_geweke.png │ ├── B_x_1.png │ ├── B_x_2.png │ ├── B_x_3.png │ ├── B_y_27.png │ ├── B_y_34.png │ ├── B_y_37.png │ ├── B_y_39.png │ ├── B_y_42.png │ └── B_y_48.png ├── Predator-prey ├── PP_GP_NODE.py ├── PP_GP_NODE_MCMC_convergence.py ├── PP_sindy.py ├── data │ ├── IC.npy │ ├── IC_2.npy │ ├── W.npy │ ├── W_2.npy │ ├── hyp.npy │ ├── hyp_2.npy │ ├── noise.npy │ ├── noise_2.npy │ ├── par.npy │ └── par_2.npy ├── plots │ ├── autocorrelation.png │ ├── box_plot.png │ ├── box_plot_x0.png │ ├── geweke.png │ ├── x_1.png │ └── x_2.png └── plots_sindy │ ├── case_1_as_GP_NODE_x1.png │ ├── case_1_as_GP_NODE_x2.png │ ├── case_2_fine_dt_x1.png │ ├── case_2_fine_dt_x2.png │ ├── case_3_fine_dt_no_noise_x1.png │ ├── case_3_fine_dt_no_noise_x2.png │ ├── case_4_no_t_gap_large_dt_x1.png │ ├── case_4_no_t_gap_large_dt_x2.png │ ├── case_5_no_t_gap_small_dt_x1.png │ └── case_5_no_t_gap_small_dt_x2.png ├── README.md ├── Yeast-Glycolysis ├── YG_GP_NODE.py ├── YG_GP_NODE_MCMC_convergence.py ├── data │ ├── W.npy │ ├── W_2.npy │ ├── hyp.npy │ ├── hyp_2.npy │ ├── noise.npy │ ├── noise_2.npy │ ├── par_and_IC.npy │ └── par_and_IC_2.npy └── plots │ ├── autocorrelation_1.png │ ├── autocorrelation_2.png │ ├── autocorrelation_3.png │ ├── 
box_plot.png │ ├── geweke.png │ ├── random_x0_x_1.png │ ├── random_x0_x_2.png │ ├── random_x0_x_3.png │ ├── random_x0_x_4.png │ ├── random_x0_x_5.png │ ├── random_x0_x_6.png │ ├── random_x0_x_7.png │ ├── x_1.png │ ├── x_2.png │ ├── x_3.png │ ├── x_4.png │ ├── x_5.png │ ├── x_6.png │ └── x_7.png └── readme.txt /Human motion/HM_GP_NODE.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Tue Sep 29 17:43:50 2020 5 | 6 | @author: mohamedazizbhouri 7 | """ 8 | 9 | import math 10 | 11 | import jax.numpy as np 12 | import jax.random as random 13 | from jax import vmap, jit 14 | from jax.experimental.ode import odeint 15 | from jax.config import config 16 | config.update("jax_enable_x64", True) 17 | 18 | from numpyro import sample 19 | import numpyro.distributions as dist 20 | from numpyro.infer import MCMC, NUTS 21 | 22 | import numpy as onp 23 | import matplotlib.pyplot as plt 24 | from functools import partial 25 | import time 26 | 27 | from scipy import io 28 | 29 | class ODE_GP: 30 | # Initialize the class 31 | def __init__(self, t, i_t, X_train_ff, x0, dxdt, case): 32 | self.t = t 33 | self.x0 = x0 34 | self.i_t = i_t 35 | self.dxdt = dxdt 36 | self.jitter = 1e-8 37 | self.ind = ind 38 | 39 | self.max_t = t.max(0) 40 | 41 | self.max_X = [] 42 | self.X = [] 43 | self.N = [] 44 | self.D = len(i_t) 45 | self.t_t = [] 46 | for i in range(len(i_t)): 47 | self.max_X.append(np.abs(X_train_ff[i]).max(0)) 48 | self.X.append(X_train_ff[i] / self.max_X[i]) 49 | self.N.append(X_train_ff[i].shape[0]) 50 | self.t_t.append(t[self.i_t[i]]/self.max_t) 51 | 52 | if case == 'A_': 53 | self.M = 9*3 54 | if case == 'B_': 55 | self.M = 10*3 56 | 57 | @partial(jit, static_argnums=(0,)) 58 | def RBF(self,x1, x2, params): 59 | diffs = (x1 / params).T - x2 / params 60 | return np.exp(-0.5 * diffs**2) 61 | 62 | def model(self, t, X): 63 | noise = sample('noise', dist.LogNormal(0.0, 
1.0), sample_shape=(self.D,)) 64 | 65 | hyp = sample('hyp', dist.Gamma(1.0, 0.5), sample_shape=(self.D,)) 66 | W = sample('W', dist.LogNormal(0.0, 1.0), sample_shape=(self.D,)) 67 | 68 | m0 = self.M-1 69 | sigma = 1 70 | 71 | tau0 = (m0/(self.M-m0) * (sigma/np.sqrt(1.0*sum(self.N)))) 72 | tau_tilde = sample('tau_tilde', dist.HalfCauchy(1.), sample_shape=(self.D,)) 73 | tau = np.repeat(tau0 * tau_tilde,self.M//self.D) 74 | 75 | slab_scale=1 76 | slab_scale2 = slab_scale**2 77 | 78 | slab_df=1 79 | half_slab_df = slab_df/2 80 | c2_tilde = sample('c2_tilde', dist.InverseGamma(half_slab_df, half_slab_df)) 81 | c2 = slab_scale2 * c2_tilde 82 | 83 | lambd = sample('lambd', dist.HalfCauchy(1.), sample_shape=(self.M,)) 84 | lambd_tilde = tau**2 * c2 * lambd**2 / (c2 + tau**2 * lambd**2) 85 | par = sample('par', dist.MultivariateNormal(np.zeros(self.M,), np.diag(lambd_tilde))) 86 | 87 | # compute kernel 88 | K_11 = W[0]*self.RBF(self.t_t[0], self.t_t[0], hyp[0]) + np.eye(self.N[0])*(noise[0] + self.jitter) 89 | K_22 = W[1]*self.RBF(self.t_t[1], self.t_t[1], hyp[1]) + np.eye(self.N[1])*(noise[1] + self.jitter) 90 | K_33 = W[2]*self.RBF(self.t_t[2], self.t_t[2], hyp[2]) + np.eye(self.N[2])*(noise[2] + self.jitter) 91 | K = np.concatenate([np.concatenate([K_11, np.zeros((self.N[0], self.N[1])), np.zeros((self.N[0], self.N[2]))], axis = 1), 92 | np.concatenate([np.zeros((self.N[1], self.N[0])), K_22, np.zeros((self.N[1], self.N[2]))], axis = 1), 93 | np.concatenate([np.zeros((self.N[2], self.N[0])), np.zeros((self.N[2], self.N[1])), K_33], axis = 1)], axis = 0) 94 | 95 | # compute mean 96 | mut = odeint(self.dxdt, self.x0, self.t.flatten(), par) 97 | mu1 = mut[self.i_t[0],ind[0]] / self.max_X[0] 98 | mu2 = mut[self.i_t[1],ind[1]] / self.max_X[1] 99 | mu3 = mut[self.i_t[2],ind[2]] / self.max_X[2] 100 | mu = np.concatenate((mu1,mu2,mu3),axis=0) 101 | mu = mu.flatten('F') 102 | 103 | X = np.concatenate((self.X[0],self.X[1],self.X[2]),axis=0) 104 | X = X.flatten('F') 105 | 106 | # 
sample X according to the standard gaussian process formula 107 | sample("X", dist.MultivariateNormal(loc=mu, covariance_matrix=K), obs=X) 108 | 109 | # helper function for doing hmc inference 110 | def train(self, settings, rng_key): 111 | start = time.time() 112 | kernel = NUTS(self.model, 113 | target_accept_prob = settings['target_accept_prob']) 114 | mcmc = MCMC(kernel, 115 | num_warmup = settings['num_warmup'], 116 | num_samples = settings['num_samples'], 117 | num_chains = settings['num_chains'], 118 | progress_bar=True, 119 | jit_model_args=True) 120 | mcmc.run(rng_key, self.t, self.X) 121 | mcmc.print_summary() 122 | elapsed = time.time() - start 123 | print('\nMCMC elapsed time: %.2f seconds' % (elapsed)) 124 | return mcmc.get_samples() 125 | 126 | @partial(jit, static_argnums=(0,)) 127 | def predict(self, t_star, par): 128 | X = odeint(self.dxdt, self.x0, t_star, par) 129 | return X 130 | 131 | plt.rcParams.update(plt.rcParamsDefault) 132 | plt.rc('font', family='serif') 133 | plt.rcParams.update({'font.size': 16, 134 | 'lines.linewidth': 2, 135 | 'axes.labelsize': 20, # fontsize for x and y labels 136 | 'axes.titlesize': 20, 137 | 'xtick.labelsize': 16, 138 | 'ytick.labelsize': 16, 139 | 'legend.fontsize': 20, 140 | 'axes.linewidth': 2, 141 | "pgf.texsystem": "pdflatex", # change this if using xetex or lautex 142 | "text.usetex": True, # use LaTeX to write all text 143 | }) 144 | 145 | # load data 146 | mat = io.loadmat('data/X.mat') 147 | dat = mat['X'] 148 | dat_max = np.abs(dat).max(0) 149 | dat = dat / dat_max 150 | 151 | l_dat = io.loadmat('data/v_pca.mat') 152 | v_pca = l_dat['v_pca'] 153 | l_dat = io.loadmat('data/u_pca.mat') 154 | u_pca = l_dat['u_pca'] 155 | l_dat = io.loadmat('data/Y.mat') 156 | Y = l_dat['Y'] 157 | 158 | case = 'B_' # 'A_' or 'B_' 159 | 160 | if case =='A_': 161 | def robot_dict(x, t, par): 162 | x1, x2, x3 = x 163 | dxdt = [par[0] + par[1]*x1 + par[2]*x2 + par[3]*x3 + par[4]*x1*x2 + par[5]*x1*x3 + par[6]*x2*x3, 164 | par[7] 
+ par[8]*x1 + par[9]*x2 + par[10]*x3 + par[11]*x1*x2 + par[12]*x1*x3 + par[13]*x2*x3, 165 | par[14] + par[15]*x1 + par[16]*x2 + par[17]*x3 + par[18]*x1*x2 + par[19]*x1*x3 + par[20]*x2*x3] 166 | return dxdt 167 | if case =='B_': 168 | def robot_dict(x, t, par): 169 | x1, x2, x3 = x 170 | dxdt = [par[18] + par[0]*x1 + par[1]*x2 + par[2]*x3 + par[3]*x1*x2 + par[4]*x1*x3 + par[5]*x2*x3 + par[21]*x1**2 + par[22]*x2**2 + par[23]*x3**2, 171 | par[19] + par[6]*x1 + par[7]*x2 + par[8]*x3 + par[9]*x1*x2 + par[10]*x1*x3 + par[11]*x2*x3 + par[24]*x1**2 + par[25]*x2**2 + par[26]*x3**2, 172 | par[20] + par[12]*x1 + par[13]*x2 + par[14]*x3 + par[15]*x1*x2 + par[16]*x1*x3 + par[17]*x2*x3 + par[27]*x1**2 + par[28]*x2**2 + par[29]*x3**2] 173 | return dxdt 174 | 175 | key = random.PRNGKey(1234) 176 | D = 3 177 | 178 | Nt_star = 82 179 | 180 | x0 = dat[0,:] 181 | 182 | # Test data 183 | t_star = np.linspace(0, 1, Nt_star) 184 | X_star = dat 185 | 186 | # Training data 187 | Nt = 62 188 | t = t_star 189 | 190 | i1 = 33 191 | i2 = 48 192 | ind_t = np.concatenate([np.arange(0,i1)[:,None],np.arange(i2,82)[:,None]]) 193 | ind_t = ind_t[:,0] 194 | 195 | ind_t_test = np.arange(i1,i2) 196 | 197 | i_t = [] 198 | i_t.append(ind_t) 199 | i_t.append(ind_t) 200 | i_t.append(ind_t) 201 | 202 | X_train = dat 203 | 204 | ind = [0,1,2] 205 | 206 | X1_train = X_train[i_t[0],ind[0]] 207 | X2_train = X_train[i_t[1],ind[1]] 208 | X3_train = X_train[i_t[2],ind[2]] 209 | 210 | X_train_ff = [] 211 | X_train_ff.append(X1_train) 212 | X_train_ff.append(X2_train) 213 | X_train_ff.append(X3_train) 214 | 215 | model = ODE_GP(t[:,None], i_t, X_train_ff, x0, robot_dict, case) 216 | rng_key_train, rng_key_predict = random.split(random.PRNGKey(0)) 217 | 218 | num_warmup = 4000 219 | num_samples = 8000 220 | num_chains = 1 221 | target_accept_prob = 0.85 222 | settings = {'num_warmup': num_warmup, 223 | 'num_samples': num_samples, 224 | 'num_chains': num_chains, 225 | 'target_accept_prob': target_accept_prob} 226 | 
samples = model.train(settings, rng_key_train) 227 | 228 | np.save('data/'+case+'par',np.array(samples['par'])) 229 | np.save('data/'+case+'noise',np.array(samples['noise'])) 230 | np.save('data/'+case+'hyp',np.array(samples['hyp'])) 231 | np.save('data/'+case+'W',np.array(samples['W'])) 232 | 233 | def RBF(x1, x2, params): 234 | diffs = (x1 / params).T - x2 / params 235 | return np.exp(-0.5 * diffs**2) 236 | def Ksin(x, xp, period, lengthscale): 237 | K = np.exp(-2.0*np.sin(np.pi*np.abs(x.T-xp)/period)**2/lengthscale**2) 238 | return K 239 | 240 | N_fine = 100 241 | t_test = np.linspace(0, 1, Nt_star) 242 | Nt_test = t_test.shape[0] 243 | 244 | t_tr = t[:,None] /model.max_t 245 | t_te = t_test[:,None] /model.max_t 246 | 247 | vmap_args = (samples['par'],) 248 | pred_X_tr_i = lambda a: model.predict(t, a) 249 | X_tr_i = vmap(pred_X_tr_i)(*vmap_args) 250 | 251 | pred_X_ode_i = lambda a: model.predict(t_test, a) 252 | X_ode_i = vmap(pred_X_ode_i)(*vmap_args) 253 | 254 | X_pred_GP = [] 255 | Npred_GP_f = 0 256 | 257 | ind_PCA = [26,33,36,38,41,47] 258 | Y_PCA_GP = [] 259 | Y_PCA = np.matmul( np.matmul( dat_max*X_star ,np.diag(np.sqrt(v_pca[:,0]))) ,u_pca ) 260 | 261 | for i in range(num_samples): 262 | if i % 500 == 0: 263 | print(i) 264 | K1_tr = samples['W'][i,0]*RBF(model.t_t[0], model.t_t[0], samples['hyp'][i,0]) + np.eye(model.N[0])*(samples['noise'][i,0] + model.jitter) 265 | K2_tr = samples['W'][i,1]*RBF(model.t_t[1], model.t_t[1], samples['hyp'][i,1]) + np.eye(model.N[1])*(samples['noise'][i,1] + model.jitter) 266 | K3_tr = samples['W'][i,2]*RBF(model.t_t[2], model.t_t[2], samples['hyp'][i,2]) + np.eye(model.N[2])*(samples['noise'][i,2] + model.jitter) 267 | K_tr = np.concatenate([np.concatenate([K1_tr, np.zeros((model.N[0], model.N[1])), np.zeros((model.N[0], model.N[2]))], axis = 1), 268 | np.concatenate([np.zeros((model.N[1], model.N[0])), K2_tr, np.zeros((model.N[1], model.N[2]))], axis = 1), 269 | np.concatenate([np.zeros((model.N[2], model.N[0])), 
np.zeros((model.N[2], model.N[1])), K3_tr], axis = 1)], axis = 0) 270 | K1_trte = samples['W'][i,0]*RBF(t_te, model.t_t[0], samples['hyp'][i,0]) 271 | K2_trte = samples['W'][i,1]*RBF(t_te, model.t_t[1], samples['hyp'][i,1]) 272 | K3_trte = samples['W'][i,2]*RBF(t_te, model.t_t[2], samples['hyp'][i,2]) 273 | K_trte = np.concatenate([np.concatenate([K1_trte, np.zeros((model.N[0],Nt_test)), np.zeros((model.N[0],Nt_test))], axis = 1), 274 | np.concatenate([np.zeros((model.N[1],Nt_test)), K2_trte, np.zeros((model.N[1],Nt_test))], axis = 1), 275 | np.concatenate([np.zeros((model.N[2],Nt_test)), np.zeros((model.N[2],Nt_test)), K3_trte], axis = 1)], axis = 0) 276 | K1_te = samples['W'][i,0]*RBF(t_te, t_te, samples['hyp'][i,0]) 277 | K2_te = samples['W'][i,1]*RBF(t_te, t_te, samples['hyp'][i,1]) 278 | K3_te = samples['W'][i,2]*RBF(t_te, t_te, samples['hyp'][i,2]) 279 | K_te = np.concatenate([np.concatenate([K1_te, np.zeros((Nt_test,Nt_test)), np.zeros((Nt_test,Nt_test))], axis = 1), 280 | np.concatenate([np.zeros((Nt_test,Nt_test)), K2_te, np.zeros((Nt_test,Nt_test))], axis = 1), 281 | np.concatenate([np.zeros((Nt_test,Nt_test)), np.zeros((Nt_test,Nt_test)), K3_te], axis = 1)], axis = 0) 282 | X_tr1 = X_tr_i[i,i_t[0],ind[0]] / model.max_X[0] 283 | X_tr2 = X_tr_i[i,i_t[1],ind[1]] / model.max_X[1] 284 | X_tr3 = X_tr_i[i,i_t[2],ind[2]] / model.max_X[2] 285 | X_tr = np.concatenate((X_tr1,X_tr2,X_tr3),axis=0) 286 | 287 | L = np.linalg.cholesky(K_tr) 288 | X_train_f = np.concatenate((model.X[0],model.X[1],model.X[2]),axis=0) 289 | X_train_f = X_train_f.flatten('F') 290 | 291 | dX = np.matmul( K_trte.T, np.linalg.solve(np.transpose(L), np.linalg.solve(L,X_train_f.flatten('F')-X_tr.flatten('F'))) ) 292 | X_ode1 = X_ode_i[i,:,ind[0]] / model.max_X[0] 293 | X_ode2 = X_ode_i[i,:,ind[1]] / model.max_X[1] 294 | X_ode3 = X_ode_i[i,:,ind[2]] / model.max_X[2] 295 | X_ode = np.concatenate((X_ode1,X_ode2,X_ode3),axis=0) 296 | 297 | mu = X_ode.flatten('F') + dX 298 | K = K_te - 
np.matmul(K_trte.T, np.linalg.solve(np.transpose(L), np.linalg.solve(L,K_trte))) 299 | pred = onp.random.multivariate_normal(mu, K) 300 | if not math.isnan( np.amax(np.abs(pred)) ): 301 | Npred_GP_f += 1 302 | X_pred_GP.append( pred.reshape((D, Nt_test)).T ) 303 | Y_PCA_GP.append( np.matmul( np.matmul( dat_max*np.array(model.max_X)*pred.reshape((D, Nt_test)).T,np.diag(np.sqrt(v_pca[:,0]))) ,u_pca ) ) 304 | 305 | X_pred_GP = np.array(X_pred_GP) 306 | mean_prediction_GP, std_prediction_GP = np.mean(X_pred_GP, axis=0), np.std(X_pred_GP, axis=0) 307 | lower_GP = mean_prediction_GP - 2.0*std_prediction_GP 308 | upper_GP = mean_prediction_GP + 2.0*std_prediction_GP 309 | 310 | for i in range(D): 311 | plt.figure(figsize = (12,6.5)) 312 | plt.xticks(fontsize=22) 313 | plt.yticks(fontsize=22) 314 | plt.plot(t_star, X_star[:,i], 'r-', label = "True Trajectory of $x_"+str(i+1)+"(t)$") 315 | 316 | plt.plot(t[i_t[i]], X_train[i_t[i],i],'ro', label = "Training data of $x_"+str(i+1)+"(t)$") 317 | plt.plot(t_test, mean_prediction_GP[:,i], 'g--', label = "MAP Trajectory of $x_"+str(i+1)+"(t)$") 318 | plt.fill_between(t_test, lower_GP[:,i], upper_GP[:,i], facecolor='orange', alpha=0.5, label="Two std band") 319 | plt.axvspan(t_star[i1-1], t_star[i2], alpha=0.1, color='blue') 320 | plt.xlabel('$t$',fontsize=26) 321 | plt.ylabel('$x_'+str(i+1)+'(t)$',fontsize=26) 322 | plt.legend(loc='upper right', frameon=False, prop={'size': 20}) 323 | plt.ylim(top= 2.2*upper_GP[:,i].max(0)) 324 | tt = 'plots/' + case + 'x_' + str(i+1) + ".png" 325 | plt.savefig(tt, dpi = 300) 326 | 327 | Y_PCA_GP = np.array(Y_PCA_GP) 328 | mean_prediction_Y, std_prediction_Y = np.mean(Y_PCA_GP, axis=0), np.std(Y_PCA_GP, axis=0) 329 | lower_Y = mean_prediction_Y - 2.0*std_prediction_Y 330 | upper_Y = mean_prediction_Y + 2.0*std_prediction_Y 331 | 332 | for i in range(len(ind_PCA)): 333 | plt.figure(figsize = (12,6.5)) 334 | plt.xticks(fontsize=22) 335 | plt.yticks(fontsize=22) 336 | plt.plot(t_star, 
Y_PCA[:,ind_PCA[i]], 'r-', label = "True Trajectory of PCA-recovered $y_{"+str(ind_PCA[i]+1)+"}(t)$") 337 | plt.plot(t_test, mean_prediction_Y[:,ind_PCA[i]], 'g--', label = "MAP Trajectory PCA-recovered $y_{"+str(ind_PCA[i]+1)+"}(t)$") 338 | plt.fill_between(t_test, lower_Y[:,ind_PCA[i]], upper_Y[:,ind_PCA[i]], facecolor='orange', alpha=0.5, label="Two std band") 339 | plt.axvspan(t_star[i1-1], t_star[i2], alpha=0.1, color='blue') 340 | plt.xlabel('$t$',fontsize=26) 341 | plt.ylabel('PCA-recovered $y_{'+str(ind_PCA[i]+1)+'}(t)$',fontsize=26) 342 | plt.legend(loc='upper right', frameon=False, prop={'size': 20}) 343 | plt.ylim(top= 2.2*upper_Y[:,ind_PCA[i]].max(0)) 344 | tt = 'plots/' + case + 'y_' + str(ind_PCA[i]+1) + ".png" 345 | plt.savefig(tt, dpi = 300) 346 | 347 | err_obs = [] 348 | err_miss = [] 349 | for i in range(num_samples): 350 | err_obs.append( np.sqrt( np.sum( (Y_PCA_GP[i,ind_t,:]-Y[ind_t,:])**2 ) / (Y[ind_t,:].shape[0]*Y[ind_t,:].shape[1]) ) ) 351 | err_miss.append( np.sqrt( np.sum( (Y_PCA_GP[i,ind_t_test,:]-Y[ind_t_test,:])**2 ) / (Y[ind_t_test,:].shape[0]*Y[ind_t_test,:].shape[1]) ) ) 352 | 353 | err_obs = np.array(err_obs) 354 | err_miss = np.array(err_miss) 355 | print("Error for fitting observed data:",np.mean(err_obs),np.std(err_obs)) 356 | print("Error for forecasting missing data:",np.mean(err_miss),np.std(err_miss)) 357 | 358 | print(Npred_GP_f) 359 | 360 | import matplotlib as mpl 361 | 362 | def figsize(scale, nplots = 1): 363 | fig_width_pt = 390.0 # Get this from LaTeX using \the\textwidth 364 | inches_per_pt = 1.0/72.27 # Convert pt to inch 365 | golden_mean = (np.sqrt(5.0)-1.0)/2.0 # Aesthetic ratio (you could change this) 366 | fig_width = fig_width_pt*inches_per_pt*scale # width in inches 367 | fig_height = nplots*fig_width*golden_mean # height in inches 368 | fig_size = [fig_width,fig_height] 369 | return fig_size 370 | 371 | pgf_with_latex = { # setup matplotlib to use latex for output 372 | "pgf.texsystem": "pdflatex", # change 
this if using xetex or lautex 373 | "text.usetex": True, # use LaTeX to write all text 374 | "font.family": "serif", 375 | "font.serif": [], # blank entries should cause plots to inherit fonts from the document 376 | "font.sans-serif": [], 377 | "font.monospace": [], 378 | "axes.labelsize": 20, # LaTeX default is 10pt font. 379 | "axes.titlesize": 20, 380 | "axes.linewidth": 2, 381 | "font.size": 16, 382 | "lines.linewidth": 2, 383 | "legend.fontsize": 20, # Make the legend/label fonts a little smaller 384 | "xtick.labelsize": 24, 385 | "ytick.labelsize": 24, 386 | "figure.figsize": figsize(1.0), # default fig size of 0.9 textwidth 387 | "pgf.preamble": [ 388 | r"\usepackage[utf8x]{inputenc}", # use utf8 fonts becasue your computer can handle it :) 389 | r"\usepackage[T1]{fontenc}", # plots will be generated using this preamble 390 | ] 391 | } 392 | mpl.rcParams.update(pgf_with_latex) 393 | 394 | def newfig(width, nplots = 1): 395 | fig = plt.figure(figsize=figsize(width, nplots)) 396 | ax = fig.add_subplot(111) 397 | return fig, ax 398 | 399 | def savefig(filename, crop = True): 400 | if crop == True: 401 | plt.savefig('{}.png'.format(filename), bbox_inches='tight', pad_inches=0 , dpi = 300) 402 | else: 403 | plt.savefig('{}.png'.format(filename) , dpi = 300) 404 | 405 | if case == 'A_': # case A 406 | 407 | a11 = samples['par'][:,0] 408 | a12 = samples['par'][:,1] 409 | a13 = samples['par'][:,2] 410 | a14 = samples['par'][:,3] 411 | a15 = samples['par'][:,4] 412 | a16 = samples['par'][:,5] 413 | a17 = samples['par'][:,6] 414 | a21 = samples['par'][:,7] 415 | a22 = samples['par'][:,8] 416 | a23 = samples['par'][:,9] 417 | a24 = samples['par'][:,10] 418 | a25 = samples['par'][:,11] 419 | a26 = samples['par'][:,12] 420 | a27 = samples['par'][:,13] 421 | a31 = samples['par'][:,14] 422 | a32 = samples['par'][:,15] 423 | a33 = samples['par'][:,16] 424 | a34 = samples['par'][:,17] 425 | a35 = samples['par'][:,18] 426 | a36 = samples['par'][:,19] 427 | a37 = 
samples['par'][:,20] 428 | 429 | Data = [a11,a12,a13,a14,a15,a16,a17,a21,a22,a23,a24,a25,a26,a27,a31,a32,a33,a34,a35,a36,a37] 430 | 431 | name = [r'$a_{11}$',r'$a_{12}$',r'$a_{13}$',r'$a_{14}$',r'$a_{15}$',r'$a_{16}$',r'$a_{17}$',r'$a_{21}$',r'$a_{22}$',r'$a_{23}$',r'$a_{24}$',r'$a_{25}$',r'$a_{26}$',r'$a_{27}$',r'$a_{31}$',r'$a_{32}$',r'$a_{33}$',r'$a_{34}$',r'$a_{35}$',r'$a_{36}$',r'$a_{37}$'] 432 | 433 | if case == 'B_': # case B 434 | 435 | a11 = samples['par'][:,18] 436 | a12 = samples['par'][:,0] 437 | a13 = samples['par'][:,1] 438 | a14 = samples['par'][:,2] 439 | a15 = samples['par'][:,3] 440 | a16 = samples['par'][:,4] 441 | a17 = samples['par'][:,5] 442 | a18 = samples['par'][:,21] 443 | a19 = samples['par'][:,22] 444 | a110 = samples['par'][:,23] 445 | 446 | a21 = samples['par'][:,19] 447 | a22 = samples['par'][:,6] 448 | a23 = samples['par'][:,7] 449 | a24 = samples['par'][:,8] 450 | a25 = samples['par'][:,9] 451 | a26 = samples['par'][:,10] 452 | a27 = samples['par'][:,11] 453 | a28 = samples['par'][:,24] 454 | a29 = samples['par'][:,25] 455 | a210 = samples['par'][:,26] 456 | 457 | a31 = samples['par'][:,20] 458 | a32 = samples['par'][:,12] 459 | a33 = samples['par'][:,13] 460 | a34 = samples['par'][:,14] 461 | a35 = samples['par'][:,15] 462 | a36 = samples['par'][:,16] 463 | a37 = samples['par'][:,17] 464 | a38 = samples['par'][:,27] 465 | a39 = samples['par'][:,28] 466 | a310 = samples['par'][:,29] 467 | 468 | Data = [a11,a12,a13,a14,a15,a16,a17,a18,a19,a110,a21,a22,a23,a24,a25,a26,a27,a28,a29,a210,a31,a32,a33,a34,a35,a36,a37,a38,a39,a310] 469 | 470 | name = [r'$a_{11}$',r'$a_{12}$',r'$a_{13}$',r'$a_{14}$',r'$a_{15}$',r'$a_{16}$',r'$a_{17}$',r'$a_{18}$',r'$a_{19}$',r'$a_{110}$',r'$a_{21}$',r'$a_{22}$',r'$a_{23}$',r'$a_{24}$',r'$a_{25}$',r'$a_{26}$',r'$a_{27}$',r'$a_{28}$',r'$a_{29}$',r'$a_{210}$',r'$a_{31}$',r'$a_{32}$',r'$a_{33}$',r'$a_{34}$',r'$a_{35}$',r'$a_{36}$',r'$a_{37}$',r'$a_{38}$',r'$a_{39}$',r'$a_{310}$'] 471 | 472 | fig7, ax7 = 
plt.subplots(figsize=(20, 10)) 473 | ax7.boxplot(Data, showfliers=False, labels=name) 474 | savefig('plots/' + case + 'box_plot', True) 475 | -------------------------------------------------------------------------------- /Human motion/HM_GP_NODE_MCMC_convergence.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import pymc3 as pm3 4 | 5 | plt.rcParams.update(plt.rcParamsDefault) 6 | plt.rc('font', family='serif') 7 | plt.rcParams.update({'font.size': 16, 8 | 'lines.linewidth': 2, 9 | 'axes.labelsize': 20, # fontsize for x and y labels 10 | 'axes.titlesize': 20, 11 | 'xtick.labelsize': 16, 12 | 'ytick.labelsize': 16, 13 | 'legend.fontsize': 20, 14 | 'axes.linewidth': 2, 15 | "pgf.texsystem": "pdflatex", # change this if using xetex or lautex 16 | "text.usetex": True, # use LaTeX to write all text 17 | }) 18 | 19 | case = 'A_' # 'A_' or 'B_' 20 | N_var = 10 21 | i0 = 0 22 | ie = 8000 23 | step = 8 24 | if case == 'A_': 25 | # A dictionary 26 | data2 = np.load("data/A_par.npy",allow_pickle=True) 27 | data3 = np.load("data/A_par_2.npy",allow_pickle=True) 28 | a13_1 = data2[i0:ie:step,2:3] 29 | a14_1 = data2[i0:ie:step,3:4] 30 | a16_1 = data2[i0:ie:step,5:6] 31 | a17_1 = data2[i0:ie:step,6:7] 32 | a22_1 = data2[i0:ie:step,8:9] 33 | a25_1 = data2[i0:ie:step,11:12] 34 | a26_1 = data2[i0:ie:step,12:13] 35 | a34_1 = data2[i0:ie:step,17:18] 36 | a35_1 = data2[i0:ie:step,18:19] 37 | a36_1 = data2[i0:ie:step,19:20] 38 | 39 | a13_2 = data3[i0:ie:step,2:3] 40 | a14_2 = data3[i0:ie:step,3:4] 41 | a16_2 = data3[i0:ie:step,5:6] 42 | a17_2 = data3[i0:ie:step,6:7] 43 | a22_2 = data3[i0:ie:step,8:9] 44 | a25_2 = data3[i0:ie:step,11:12] 45 | a26_2 = data3[i0:ie:step,12:13] 46 | a34_2 = data3[i0:ie:step,17:18] 47 | a35_2 = data3[i0:ie:step,18:19] 48 | a36_2 = data3[i0:ie:step,19:20] 49 | 50 | else: 51 | # case B 52 | data2 = np.load("data/B_par.npy",allow_pickle=True) 53 | data3 = 
np.load("data/B_par_2.npy",allow_pickle=True) 54 | a13_1 = data2[i0:ie:step,1:2] 55 | a14_1 = data2[i0:ie:step,2:3] 56 | a16_1 = data2[i0:ie:step,4:5] 57 | a17_1 = data2[i0:ie:step,5:6] 58 | a22_1 = data2[i0:ie:step,6:7] 59 | a25_1 = data2[i0:ie:step,9:10] 60 | a26_1 = data2[i0:ie:step,10:11] 61 | a34_1 = data2[i0:ie:step,14:15] 62 | a35_1 = data2[i0:ie:step,15:16] 63 | a36_1 = data2[i0:ie:step,16:17] 64 | 65 | a13_2 = data3[i0:ie:step,1:2] 66 | a14_2 = data3[i0:ie:step,2:3] 67 | a16_2 = data3[i0:ie:step,4:5] 68 | a17_2 = data3[i0:ie:step,5:6] 69 | a22_2 = data3[i0:ie:step,6:7] 70 | a25_2 = data3[i0:ie:step,9:10] 71 | a26_2 = data3[i0:ie:step,10:11] 72 | a34_2 = data3[i0:ie:step,14:15] 73 | a35_2 = data3[i0:ie:step,15:16] 74 | a36_2 = data3[i0:ie:step,16:17] 75 | 76 | names = [r'$a_{13}$',r'$a_{14}$',r'$a_{16}$',r'$a_{17}$',r'$a_{22}$',r'$a_{25}$',r'$a_{26}$',r'$a_{34}$',r'$a_{35}$',r'$a_{36}$'] 77 | 78 | data_chain1 = np.concatenate((a13_1,a14_1,a16_1,a17_1,a22_1,a25_1,a26_1,a34_1,a35_1,a36_1),axis=-1) # 2000 x 5 79 | data_chain2 = np.concatenate((a13_2,a14_2,a16_2,a17_2,a22_2,a25_2,a26_2,a34_2,a35_2,a36_2),axis=-1) # 2000 x 5 80 | 81 | N = a13_1.shape[0] 82 | iteration = np.arange(0,N) 83 | 84 | N_per_block = 5 85 | data_traceplot1_1 = {} 86 | data_traceplot1_2 = {} 87 | data_traceplot2_1 = {} 88 | data_traceplot2_2 = {} 89 | 90 | j = 0 91 | for i,name in enumerate(names[j:j+N_per_block]): 92 | data_traceplot1_1[name] = data_chain1[:,j+i] 93 | data_traceplot2_1[name] = data_chain2[:,j+i] 94 | 95 | j = N_per_block 96 | for i,name in enumerate(names[j:j+N_per_block]): 97 | data_traceplot1_2[name] = data_chain1[:,j+i] 98 | data_traceplot2_2[name] = data_chain2[:,j+i] 99 | 100 | for i in range(N_var): 101 | 102 | chain1 = data_chain1[:,i:i+1] 103 | chain2 = data_chain2[:,i:i+1] 104 | 105 | burn_in = 0 106 | length = (ie-i0)//step 107 | 108 | n = chain1[burn_in:burn_in+length].shape[0] 109 | 110 | W = (chain1[burn_in:burn_in+length].std()**2 + 
chain2[burn_in:burn_in+length].std()**2)/2 111 | mean1 = chain1[burn_in:burn_in+length].mean() 112 | mean2 = chain2[burn_in:burn_in+length].mean() 113 | mean = (mean1 + mean2)/2 114 | B = n * ((mean1 - mean)**2 + (mean2 - mean)**2) 115 | var_theta = (1 - 1/n) * W + 1/n*B 116 | print("Gelman-Rubin Diagnostic: ", np.sqrt(var_theta/W)) 117 | 118 | j = 0 119 | corr_plot1_1 = pm3.autocorrplot(data_traceplot1_1,var_names=names[j:j+N_per_block],grid=(1,N_per_block),figsize=(12,6.5),textsize=18,combined=True) 120 | corr_plot1_1 = corr_plot1_1[None,:] 121 | for i in range(N_per_block): 122 | corr_plot1_1[0, i].set_xlabel('Lag Index',fontsize=26) 123 | corr_plot1_1[0, 0].set_ylabel('Autocorrelation Value',fontsize=26) 124 | plt.savefig("plots/"+case+"autocorrelation_1.png", bbox_inches='tight', pad_inches=0.01) 125 | 126 | j = N_per_block 127 | corr_plot1_2 = pm3.autocorrplot(data_traceplot1_2,var_names=names[j:j+N_per_block],grid=(1,N_per_block),figsize=(12,6.5),textsize=18,combined=True) 128 | corr_plot1_2 = corr_plot1_2[None,:] 129 | for i in range(N_per_block): 130 | corr_plot1_2[0, i].set_xlabel('Lag Index',fontsize=26) 131 | corr_plot1_2[0, 0].set_ylabel('Autocorrelation Value',fontsize=26) 132 | plt.savefig("plots/"+case+"autocorrelation_2.png", bbox_inches='tight', pad_inches=0.01) 133 | 134 | plt.figure(figsize=(12,7)) 135 | 136 | for i in range(data_chain1.shape[1]): 137 | gw_plot = pm3.geweke(data_chain1[:,i],.1,.5,20) 138 | plt.scatter(gw_plot[:,0],gw_plot[:,1],label="%s"%names[i]) 139 | plt.axhline(-1.98, c='r') 140 | plt.axhline(1.98, c='r') 141 | plt.xticks(fontsize=22) 142 | plt.yticks(fontsize=22) 143 | plt.xlabel("Subchain sample number",fontsize=26) 144 | plt.ylabel("Geweke z-score",fontsize=26) 145 | plt.title('Geweke Plot Comparing first 10$\%$ and Slices of the Last 50$\%$ of Chain') 146 | 147 | plt.legend(bbox_to_anchor=(1.0, 1.0), loc='upper left') 148 | plt.tight_layout() 149 | plt.show() 150 | plt.savefig("plots/"+case+"geweke.png", pad_inches=0.01) 
151 | 152 | -------------------------------------------------------------------------------- /Human motion/data/A_W.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/data/A_W.npy -------------------------------------------------------------------------------- /Human motion/data/A_W_2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/data/A_W_2.npy -------------------------------------------------------------------------------- /Human motion/data/A_hyp.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/data/A_hyp.npy -------------------------------------------------------------------------------- /Human motion/data/A_hyp_2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/data/A_hyp_2.npy -------------------------------------------------------------------------------- /Human motion/data/A_noise.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/data/A_noise.npy -------------------------------------------------------------------------------- /Human motion/data/A_noise_2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human 
motion/data/A_noise_2.npy -------------------------------------------------------------------------------- /Human motion/data/A_par.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/data/A_par.npy -------------------------------------------------------------------------------- /Human motion/data/A_par_2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/data/A_par_2.npy -------------------------------------------------------------------------------- /Human motion/data/B_W.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/data/B_W.npy -------------------------------------------------------------------------------- /Human motion/data/B_W_2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/data/B_W_2.npy -------------------------------------------------------------------------------- /Human motion/data/B_hyp.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/data/B_hyp.npy -------------------------------------------------------------------------------- /Human motion/data/B_hyp_2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human 
motion/data/B_hyp_2.npy -------------------------------------------------------------------------------- /Human motion/data/B_noise.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/data/B_noise.npy -------------------------------------------------------------------------------- /Human motion/data/B_noise_2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/data/B_noise_2.npy -------------------------------------------------------------------------------- /Human motion/data/B_par.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/data/B_par.npy -------------------------------------------------------------------------------- /Human motion/data/B_par_2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/data/B_par_2.npy -------------------------------------------------------------------------------- /Human motion/data/X.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/data/X.mat -------------------------------------------------------------------------------- /Human motion/data/Y.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human 
motion/data/Y.mat -------------------------------------------------------------------------------- /Human motion/data/u_pca.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/data/u_pca.mat -------------------------------------------------------------------------------- /Human motion/data/v_pca.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/data/v_pca.mat -------------------------------------------------------------------------------- /Human motion/npODE.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/npODE.zip -------------------------------------------------------------------------------- /Human motion/plots/A_autocorrelation_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/plots/A_autocorrelation_1.png -------------------------------------------------------------------------------- /Human motion/plots/A_autocorrelation_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/plots/A_autocorrelation_2.png -------------------------------------------------------------------------------- /Human motion/plots/A_box_plot.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/plots/A_box_plot.png -------------------------------------------------------------------------------- /Human motion/plots/A_geweke.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/plots/A_geweke.png -------------------------------------------------------------------------------- /Human motion/plots/A_x_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/plots/A_x_1.png -------------------------------------------------------------------------------- /Human motion/plots/A_x_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/plots/A_x_2.png -------------------------------------------------------------------------------- /Human motion/plots/A_x_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/plots/A_x_3.png -------------------------------------------------------------------------------- /Human motion/plots/A_y_27.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/plots/A_y_27.png -------------------------------------------------------------------------------- /Human motion/plots/A_y_34.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/plots/A_y_34.png -------------------------------------------------------------------------------- /Human motion/plots/A_y_37.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/plots/A_y_37.png -------------------------------------------------------------------------------- /Human motion/plots/A_y_39.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/plots/A_y_39.png -------------------------------------------------------------------------------- /Human motion/plots/A_y_42.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/plots/A_y_42.png -------------------------------------------------------------------------------- /Human motion/plots/A_y_48.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/plots/A_y_48.png -------------------------------------------------------------------------------- /Human motion/plots/B_autocorrelation_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/plots/B_autocorrelation_1.png -------------------------------------------------------------------------------- 
/Human motion/plots/B_autocorrelation_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/plots/B_autocorrelation_2.png -------------------------------------------------------------------------------- /Human motion/plots/B_box_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/plots/B_box_plot.png -------------------------------------------------------------------------------- /Human motion/plots/B_geweke.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/plots/B_geweke.png -------------------------------------------------------------------------------- /Human motion/plots/B_x_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/plots/B_x_1.png -------------------------------------------------------------------------------- /Human motion/plots/B_x_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/plots/B_x_2.png -------------------------------------------------------------------------------- /Human motion/plots/B_x_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/plots/B_x_3.png 
-------------------------------------------------------------------------------- /Human motion/plots/B_y_27.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/plots/B_y_27.png -------------------------------------------------------------------------------- /Human motion/plots/B_y_34.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/plots/B_y_34.png -------------------------------------------------------------------------------- /Human motion/plots/B_y_37.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/plots/B_y_37.png -------------------------------------------------------------------------------- /Human motion/plots/B_y_39.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/plots/B_y_39.png -------------------------------------------------------------------------------- /Human motion/plots/B_y_42.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human motion/plots/B_y_42.png -------------------------------------------------------------------------------- /Human motion/plots/B_y_48.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Human 
motion/plots/B_y_48.png -------------------------------------------------------------------------------- /Predator-prey/PP_GP_NODE.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Tue Sep 29 15:42:18 2020 5 | 6 | @author: mohamedazizbhouri 7 | """ 8 | 9 | import math 10 | 11 | import jax.numpy as np 12 | import jax.random as random 13 | from jax import vmap, jit 14 | from jax.experimental.ode import odeint 15 | from jax.config import config 16 | config.update("jax_enable_x64", True) 17 | 18 | from numpyro import sample 19 | import numpyro.distributions as dist 20 | from numpyro.infer import MCMC, NUTS 21 | 22 | import numpy as onp 23 | import matplotlib.pyplot as plt 24 | from functools import partial 25 | import time 26 | 27 | class ODE_GP: 28 | # Initialize the class 29 | def __init__(self, t, i_t, X_train_ff, x0, dxdt, ind): 30 | # Normalization 31 | self.t = t 32 | self.x0 = x0 33 | self.i_t = i_t 34 | self.dxdt = dxdt 35 | self.jitter = 1e-8 36 | self.ind = ind 37 | 38 | self.max_t = t.max(0) 39 | 40 | self.max_X = [] 41 | self.X = [] 42 | self.N = [] 43 | self.D = len(i_t) 44 | self.t_t = [] 45 | for i in range(len(i_t)): 46 | self.max_X.append(np.abs(X_train_ff[i]).max(0)) 47 | self.X.append(X_train_ff[i] / self.max_X[i]) 48 | self.N.append(X_train_ff[i].shape[0]) 49 | self.t_t.append(t[self.i_t[i]]/self.max_t) 50 | 51 | @partial(jit, static_argnums=(0,)) 52 | def RBF(self,x1, x2, params): 53 | diffs = (x1 / params).T - x2 / params 54 | return np.exp(-0.5 * diffs**2) 55 | 56 | def model(self, t, X): 57 | 58 | noise = sample('noise', dist.LogNormal(0.0, 1.0), sample_shape=(self.D,)) 59 | hyp = sample('hyp', dist.Gamma(1.0, 0.5), sample_shape=(self.D,)) 60 | W = sample('W', dist.LogNormal(0.0, 1.0), sample_shape=(self.D,)) 61 | 62 | M = 7*2 63 | m0 = 13 64 | sigma = 1 65 | 66 | tau0 = (m0/(M-m0) * (sigma/np.sqrt(1.0*self.D*sum(self.N)))) 67 | 68 | 
tau_tilde = sample('tau_tilde', dist.HalfCauchy(1.), sample_shape=(self.D,)) 69 | tau = np.repeat(tau0 * tau_tilde,M//self.D) 70 | 71 | slab_scale=1 72 | slab_scale2 = slab_scale**2 73 | 74 | slab_df=1 75 | half_slab_df = slab_df/2 76 | c2_tilde = sample('c2_tilde', dist.InverseGamma(half_slab_df, half_slab_df)) 77 | c2 = slab_scale2 * c2_tilde 78 | 79 | lambd = sample('lambd', dist.HalfCauchy(1.), sample_shape=(M,)) 80 | lambd_tilde = tau**2 * c2 * lambd**2 / (c2 + tau**2 * lambd**2) 81 | par = sample('par', dist.MultivariateNormal(np.zeros(M,), np.diag(lambd_tilde))) 82 | 83 | IC = sample('IC', dist.Uniform(4, 6)) 84 | 85 | # compute kernel 86 | K_11 = W[0]*self.RBF(self.t_t[0], self.t_t[0], hyp[0]) + np.eye(self.N[0])*(noise[0] + self.jitter) 87 | K_22 = W[1]*self.RBF(self.t_t[1], self.t_t[1], hyp[1]) + np.eye(self.N[1])*(noise[1] + self.jitter) 88 | K = np.concatenate([np.concatenate([K_11, np.zeros((self.N[0], self.N[1]))], axis = 1), 89 | np.concatenate([np.zeros((self.N[1], self.N[0])), K_22], axis = 1)], axis = 0) 90 | 91 | # compute mean 92 | x0 = np.array([IC, 5.0]) 93 | mut = odeint(self.dxdt, x0, self.t.flatten(), par) 94 | mu1 = mut[self.i_t[0],ind[0]] / self.max_X[0] 95 | mu2 = mut[self.i_t[1],ind[1]] / self.max_X[1] 96 | mu = np.concatenate((mu1,mu2),axis=0) 97 | mu = mu.flatten('F') 98 | 99 | X = np.concatenate((self.X[0],self.X[1]),axis=0) 100 | X = X.flatten('F') 101 | 102 | # sample X according to the standard gaussian process formula 103 | sample("X", dist.MultivariateNormal(loc=mu, covariance_matrix=K), obs=X) 104 | 105 | # helper function for doing hmc inference 106 | def train(self, settings, rng_key): 107 | start = time.time() 108 | kernel = NUTS(self.model, 109 | target_accept_prob = settings['target_accept_prob']) 110 | mcmc = MCMC(kernel, 111 | num_warmup = settings['num_warmup'], 112 | num_samples = settings['num_samples'], 113 | num_chains = settings['num_chains'], 114 | progress_bar=True, 115 | jit_model_args=True) 116 | 
mcmc.run(rng_key, self.t, self.X) 117 | mcmc.print_summary() 118 | elapsed = time.time() - start 119 | print('\nMCMC elapsed time: %.2f seconds' % (elapsed)) 120 | return mcmc.get_samples() 121 | 122 | @partial(jit, static_argnums=(0,)) 123 | def predict(self, t_star, par, IC): 124 | x0_l = np.array([IC, 5.0]) 125 | X = odeint(self.dxdt, x0_l, t_star, par) 126 | return X 127 | 128 | plt.rcParams.update(plt.rcParamsDefault) 129 | plt.rc('font', family='serif') 130 | plt.rcParams.update({'font.size': 16, 131 | 'lines.linewidth': 2, 132 | 'axes.labelsize': 20, 133 | 'axes.titlesize': 20, 134 | 'xtick.labelsize': 16, 135 | 'ytick.labelsize': 16, 136 | 'legend.fontsize': 20, 137 | 'axes.linewidth': 2, 138 | "pgf.texsystem": "pdflatex", # change this if using xetex or lautex 139 | "text.usetex": True, # use LaTeX to write all text 140 | }) 141 | 142 | def LV(x, t, alpha, beta, gamma, delta): 143 | x1, x2 = x 144 | dxdt = [alpha*x1+beta*x1*x2, delta*x1*x2+gamma*x2] 145 | return dxdt 146 | 147 | def LV_dict(x, t, par): 148 | x1, x2 = x 149 | dxdt = [par[1]*x1 + par[2]*x2 + par[3]*x1*x2 + par[4]*x1**2 + par[5]*x2**2 + par[6]*x1**3 + par[0]*x2**3, 150 | par[9]*x1 + par[10]*x2 + par[11]*x1*x2 + par[12]*x1**2 + par[13]*x2**2 + par[8]*x1**3 + par[7]*x2**3] 151 | return dxdt 152 | 153 | key = random.PRNGKey(123) 154 | D = 2 155 | alpha = 1.0 156 | beta = -0.1 157 | gamma = -1.5 158 | delta = 0.75 159 | IC = 5.0 160 | noise = 0.1 161 | N = 110 162 | N_fine = 1100 163 | Tf = 16.5 164 | Tf_test = 30 165 | x0 = np.array([5.0, IC]) 166 | 167 | # Training data 168 | t_fine = np.linspace(0, Tf, N_fine+1) 169 | X_fine = odeint(LV, x0, t_fine, alpha, beta, gamma, delta) 170 | X_fine_noise = X_fine + noise*X_fine.std(0)*random.normal(key, X_fine.shape) 171 | 172 | t = t_fine[[list(range(0, N_fine+1, N_fine//N))]] 173 | X_train = X_fine_noise[list(range(0, N_fine+1, N_fine//N)),:] 174 | 175 | # Test data 176 | t_star = np.linspace(0, Tf_test, 2*N_fine+1) 177 | X_star = odeint(LV, x0, 
t_star, alpha, beta, gamma, delta) 178 | 179 | gap = 9 180 | ind_t = np.array([0]) 181 | ind_t = np.concatenate([ind_t[:,None],np.arange(gap+1,N+1)[:,None]]) 182 | ind_t = ind_t[:,0] 183 | 184 | j1 = list(range(0, ind_t.shape[0]+1, 2)) 185 | j2 = list([0]) + list( range(1, ind_t.shape[0]+1, 2) ) 186 | i1 = ind_t[j2] 187 | i2 = ind_t[j1] 188 | i1 = i1[1:] 189 | 190 | i_t_plot = [] 191 | i_t_plot.append(i1) 192 | i_t_plot.append(i2) 193 | 194 | i_t = [] 195 | i_t.append(i1) 196 | i_t.append(i2) 197 | 198 | ind = [0,1] 199 | 200 | X1_train = X_train[i_t[0],ind[0]] 201 | X2_train = X_train[i_t[1],ind[1]] 202 | 203 | X_train_ff = [] 204 | X_train_ff.append(X1_train) 205 | X_train_ff.append(X2_train) 206 | 207 | model = ODE_GP(t[:,None], i_t, X_train_ff, x0, LV_dict, ind) 208 | rng_key_train, rng_key_predict = random.split(random.PRNGKey(123)) 209 | 210 | num_warmup = 4000 211 | num_samples = 8000 212 | num_chains = 1 213 | target_accept_prob = 0.85 214 | settings = {'num_warmup': num_warmup, 215 | 'num_samples': num_samples, 216 | 'num_chains': num_chains, 217 | 'target_accept_prob': target_accept_prob} 218 | samples = model.train(settings, rng_key_train) 219 | 220 | print('True values: alpha = %f, beta = %f, gamma = %f, delta = %f' % (alpha, beta, gamma, delta)) 221 | 222 | np.save('data/par',np.array(samples['par'])) 223 | np.save('data/IC',np.array(samples['IC'])) 224 | np.save('data/noise',np.array(samples['noise'])) 225 | np.save('data/hyp',np.array(samples['hyp'])) 226 | np.save('data/W',np.array(samples['W'])) 227 | 228 | def RBF(x1, x2, params): 229 | diffs = (x1 / params).T - x2 / params 230 | return np.exp(-0.5 * diffs**2) 231 | 232 | Nt = N+1 233 | N_fine = 100 234 | t_test = np.linspace(0, Tf_test, 2*N_fine+1) 235 | Nt_test = t_test.shape[0] 236 | 237 | t_tr = t[:,None] /model.max_t 238 | t_te = t_test[:,None] /model.max_t 239 | 240 | vmap_args = (samples['par'], samples['IC']) 241 | pred_X_tr_i = lambda b, c: model.predict(t, b, c) 242 | X_tr_i = 
vmap(pred_X_tr_i)(*vmap_args) 243 | 244 | pred_X_ode_i = lambda b, c: model.predict(t_test, b, c) 245 | X_ode_i = vmap(pred_X_ode_i)(*vmap_args) 246 | 247 | X_pred_GP = [] 248 | Npred_GP_f = 0 249 | for i in range(num_samples): 250 | if i % 500 == 0: 251 | print(i) 252 | K1_tr = samples['W'][i,0]*RBF(model.t_t[0], model.t_t[0], samples['hyp'][i,0]) + np.eye(model.N[0])*(samples['noise'][i,0] + model.jitter) 253 | K2_tr = samples['W'][i,1]*RBF(model.t_t[1], model.t_t[1], samples['hyp'][i,1]) + np.eye(model.N[1])*(samples['noise'][i,1] + model.jitter) 254 | K_tr = np.concatenate([np.concatenate([K1_tr, np.zeros((model.N[0], model.N[1]))], axis = 1), 255 | np.concatenate([np.zeros((model.N[1], model.N[0])), K2_tr], axis = 1)], axis = 0) 256 | K1_trte = samples['W'][i,0]*RBF(t_te, model.t_t[0], samples['hyp'][i,0]) 257 | K2_trte = samples['W'][i,1]*RBF(t_te, model.t_t[1], samples['hyp'][i,1]) 258 | K_trte = np.concatenate([np.concatenate([K1_trte, np.zeros((model.N[0],Nt_test))], axis = 1), 259 | np.concatenate([np.zeros((model.N[1],Nt_test)), K2_trte], axis = 1)], axis = 0) 260 | K1_te = samples['W'][i,0]*RBF(t_te, t_te, samples['hyp'][i,0]) 261 | K2_te = samples['W'][i,1]*RBF(t_te, t_te, samples['hyp'][i,1]) 262 | K_te = np.concatenate([np.concatenate([K1_te, np.zeros((Nt_test,Nt_test))], axis = 1), 263 | np.concatenate([np.zeros((Nt_test,Nt_test)), K2_te], axis = 1)], axis = 0) 264 | X_tr1 = X_tr_i[i,i_t[0],ind[0]] / model.max_X[0] 265 | X_tr2 = X_tr_i[i,i_t[1],ind[1]] / model.max_X[1] 266 | X_tr = np.concatenate((X_tr1,X_tr2),axis=0) 267 | 268 | L = np.linalg.cholesky(K_tr) 269 | X_train_f = np.concatenate((model.X[0],model.X[1]),axis=0) 270 | X_train_f = X_train_f.flatten('F') 271 | dX = np.matmul( K_trte.T, np.linalg.solve(np.transpose(L), np.linalg.solve(L,X_train_f.flatten('F')-X_tr.flatten('F'))) ) 272 | X_ode1 = X_ode_i[i,:,ind[0]] / model.max_X[0] 273 | X_ode2 = X_ode_i[i,:,ind[1]] / model.max_X[1] 274 | X_ode = np.concatenate((X_ode1,X_ode2),axis=0) 275 | 
276 | mu = X_ode.flatten('F') + dX 277 | K = K_te - np.matmul(K_trte.T, np.linalg.solve(np.transpose(L), np.linalg.solve(L,K_trte))) 278 | pred0 = onp.random.multivariate_normal(mu, K) 279 | pred0 = pred0.reshape((len(ind), Nt_test)).T 280 | pred0[:,0:1] = model.max_X[0] * pred0[:,0:1] 281 | pred0[:,1:2] = model.max_X[1] * pred0[:,1:2] 282 | pred = onp.array(X_ode_i[i,:,:]) 283 | pred[:,ind] = pred0 284 | if not math.isnan( np.amax(np.abs(pred)) ): 285 | Npred_GP_f += 1 286 | X_pred_GP.append( pred ) 287 | 288 | X_pred_GP = np.array(X_pred_GP) 289 | 290 | mean_prediction_GP, std_prediction_GP = np.mean(X_pred_GP, axis=0), np.std(X_pred_GP, axis=0) 291 | lower_GP = mean_prediction_GP - 2.0*std_prediction_GP 292 | upper_GP = mean_prediction_GP + 2.0*std_prediction_GP 293 | 294 | i_tr = -1 295 | for i in range(D): 296 | plt.figure(figsize = (12,6.5)) 297 | plt.xticks(fontsize=22) 298 | plt.yticks(fontsize=22) 299 | plt.plot(t_star, X_star[:,i], 'r-', label = "True Trajectory of $x_"+str(i+1)+"(t)$") 300 | if i in ind : 301 | i_tr += 1 302 | plt.plot(t[i_t_plot[i_tr]], X_train[i_t_plot[i_tr],i], 'ro', label = "Training data of $x_"+str(i+1)+"(t)$") 303 | else: 304 | plt.plot(t[0], X_train[0,i], 'ro', label = "Training data of $x_"+str(i+1)+"(t)$") 305 | plt.plot(t_test, mean_prediction_GP[:,i], 'g--', label = "MAP Trajectory of $x_"+str(i+1)+"(t)$") 306 | plt.fill_between(t_test, lower_GP[:,i], upper_GP[:,i], facecolor='orange', alpha=0.5, label="Two std band") 307 | plt.xlabel('$t$',fontsize=26) 308 | plt.ylabel('$x_'+str(i+1)+'(t)$',fontsize=26) 309 | plt.legend(loc='upper right', frameon=False, prop={'size': 20}) 310 | plt.ylim(top= 1.7*X_star[:,i].max(0)) 311 | tt = 'plots/x_' + str(i+1) + ".png" 312 | plt.savefig(tt, dpi = 300) 313 | 314 | print(Npred_GP_f) 315 | 316 | import matplotlib as mpl 317 | 318 | def figsize(scale, nplots = 1): 319 | fig_width_pt = 390.0 # Get this from LaTeX using \the\textwidth 320 | inches_per_pt = 1.0/72.27 # Convert pt to inch 321 | 
golden_mean = (np.sqrt(5.0)-1.0)/2.0 # Aesthetic ratio (you could change this) 322 | fig_width = fig_width_pt*inches_per_pt*scale # width in inches 323 | fig_height = nplots*fig_width*golden_mean # height in inches 324 | fig_size = [fig_width,fig_height] 325 | return fig_size 326 | 327 | pgf_with_latex = { # setup matplotlib to use latex for output 328 | "pgf.texsystem": "pdflatex", # change this if using xetex or lautex 329 | "text.usetex": True, # use LaTeX to write all text 330 | "font.family": "serif", 331 | "font.serif": [], # blank entries should cause plots to inherit fonts from the document 332 | "font.sans-serif": [], 333 | "font.monospace": [], 334 | "axes.labelsize": 20, # LaTeX default is 10pt font. 335 | "axes.titlesize": 20, 336 | "axes.linewidth": 2, 337 | "font.size": 16, 338 | "lines.linewidth": 2, 339 | "legend.fontsize": 20, # Make the legend/label fonts a little smaller 340 | "xtick.labelsize": 24, 341 | "ytick.labelsize": 24, 342 | "figure.figsize": figsize(1.0), # default fig size of 0.9 textwidth 343 | "pgf.preamble": [ 344 | r"\usepackage[utf8x]{inputenc}", # use utf8 fonts becasue your computer can handle it :) 345 | r"\usepackage[T1]{fontenc}", # plots will be generated using this preamble 346 | ] 347 | } 348 | mpl.rcParams.update(pgf_with_latex) 349 | 350 | def newfig(width, nplots = 1): 351 | fig = plt.figure(figsize=figsize(width, nplots)) 352 | ax = fig.add_subplot(111) 353 | return fig, ax 354 | 355 | def savefig(filename, crop = True): 356 | if crop == True: 357 | plt.savefig('{}.png'.format(filename), bbox_inches='tight', pad_inches=0 , dpi = 100) 358 | else: 359 | plt.savefig('{}.png'.format(filename) , dpi = 100) 360 | 361 | a11 = samples['par'][:,1] 362 | a12 = samples['par'][:,2] 363 | a13 = samples['par'][:,3] 364 | a14 = samples['par'][:,4] 365 | a15 = samples['par'][:,5] 366 | a16 = samples['par'][:,6] 367 | a17 = samples['par'][:,0] 368 | a21 = samples['par'][:,9] 369 | a22 = samples['par'][:,10] 370 | a23 = 
samples['par'][:,11] 371 | a24 = samples['par'][:,12] 372 | a25 = samples['par'][:,13] 373 | a26 = samples['par'][:,8] 374 | a27 = samples['par'][:,7] 375 | 376 | Data = [a11,a12,a13,a14,a15,a16,a17,a21,a22,a23,a24,a25,a26,a27] 377 | 378 | true = [1.0, 0.0, -0.1, 0.0, 0.0, 0.0, 0.0, 0.0, -1.5, 0.75, 0.0, 0.0, 0.0, 0.0] 379 | name = [r'$a_{11} (\alpha)$',r'$a_{12}$',r'$a_{13} (\beta)$',r'$a_{14}$',r'$a_{15}$',r'$a_{16}$',r'$a_{17}$',r'$a_{21}$',r'$a_{22} (\gamma)$',r'$a_{23} (\delta)$',r'$a_{24}$',r'$a_{25}$',r'$a_{26}$',r'$a_{27}$'] 380 | 381 | fig7, ax7 = plt.subplots(figsize=(20, 10)) 382 | ax7.boxplot(Data, showfliers=False, labels=name) 383 | for i in range(len(name)): 384 | X = i + np.linspace(0.76, 1.24, 100)[:,None] 385 | plt.plot(X, true[i]*np.ones(X.shape[0]), 'b-') 386 | savefig('plots/box_plot', True) 387 | 388 | name = [r'$x_{1,0}$'] 389 | fig8, ax8 = plt.subplots(figsize=(5, 10)) 390 | ax8.boxplot(samples['IC'], showfliers=False, labels=name) 391 | X = np.linspace(0.93, 1.07, 100)[:,None] 392 | plt.plot(X, 5.0*np.ones(X.shape[0]), 'b-') 393 | plt.show() 394 | savefig('plots/box_plot_x0', True) 395 | -------------------------------------------------------------------------------- /Predator-prey/PP_GP_NODE_MCMC_convergence.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import pymc3 as pm3 4 | 5 | plt.rcParams.update(plt.rcParamsDefault) 6 | plt.rc('font', family='serif') 7 | plt.rcParams.update({'font.size': 16, 8 | 'lines.linewidth': 2, 9 | 'axes.labelsize': 20, # fontsize for x and y labels (was 10) 10 | 'axes.titlesize': 20, 11 | 'xtick.labelsize': 16, 12 | 'ytick.labelsize': 16, 13 | 'legend.fontsize': 20, 14 | 'axes.linewidth': 2, 15 | "pgf.texsystem": "pdflatex", # change this if using xetex or lautex 16 | "text.usetex": True, # use LaTeX to write all text 17 | }) 18 | 19 | data2 = np.load("data/par.npy",allow_pickle=True) 20 | data2_IC = 
np.load("data/IC.npy",allow_pickle=True) 21 | data3 = np.load("data/par_2.npy",allow_pickle=True) 22 | data3_IC = np.load("data/IC_2.npy",allow_pickle=True) 23 | 24 | N_var = 5 25 | i0 = 0 26 | ie = 8000 27 | step = 8 28 | 29 | a11_1 = data2[i0:ie:step,1:2] # alpha = 1 30 | a13_1 = data2[i0:ie:step,3:4] # beta = -0.1 31 | a22_1 = data2[i0:ie:step,10:11] # gamma = -1.5 32 | a23_1 = data2[i0:ie:step,11:12] # delta = 0.75 33 | IC_1 = data2_IC[i0:ie:step] # 5.0 34 | IC_1 = IC_1[:,None] 35 | 36 | a11_2 = data3[i0:ie:step,1:2] # alpha 37 | a13_2 = data3[i0:ie:step,3:4] # beta 38 | a22_2 = data3[i0:ie:step,10:11] # gamma 39 | a23_2 = data3[i0:ie:step,11:12] # delta 40 | IC_2 = data3_IC[i0:ie:step] 41 | IC_2 = IC_2[:,None] 42 | 43 | names = [r'$a_{11} \ (\alpha)$',r'$a_{13} \ (\beta)$',r'$a_{22} \ (\gamma)$',r'$a_{23} \ (\delta)$',r'$x_{1,0}$'] 44 | 45 | N = a11_1.shape[0] 46 | iteration = np.arange(0,N) 47 | 48 | 49 | data_chain1 = np.concatenate((a11_1,a13_1,a22_1,a23_1,IC_1),axis=-1) 50 | data_chain2 = np.concatenate((a11_2,a13_2,a22_2,a23_2,IC_2),axis=-1) 51 | 52 | data_traceplot1 = {} 53 | data_traceplot2 = {} 54 | for i,name in enumerate(names): 55 | data_traceplot1[name] = data_chain1[:,i] 56 | data_traceplot2[name] = data_chain2[:,i] 57 | 58 | for i in range(N_var): 59 | 60 | chain1 = data_chain1[:,i:i+1] 61 | chain2 = data_chain2[:,i:i+1] 62 | 63 | burn_in = 0 64 | length = (ie-i0)//step 65 | 66 | n = chain1[burn_in:burn_in+length].shape[0] 67 | 68 | W = (chain1[burn_in:burn_in+length].std()**2 + chain2[burn_in:burn_in+length].std()**2)/2 69 | mean1 = chain1[burn_in:burn_in+length].mean() 70 | mean2 = chain2[burn_in:burn_in+length].mean() 71 | mean = (mean1 + mean2)/2 72 | B = n * ((mean1 - mean)**2 + (mean2 - mean)**2) 73 | var_theta = (1 - 1/n) * W + 1/n*B 74 | print("Gelman-Rubin Diagnostic: ", np.sqrt(var_theta/W)) 75 | 76 | corr_plot1 = pm3.autocorrplot(data_traceplot1,var_names=names,grid=(1,N_var),figsize=(12,6.5),textsize=18,combined=True) 77 | corr_plot1 
= corr_plot1[None,:] 78 | for i in range(N_var): 79 | corr_plot1[0, i].set_xlabel('Lag Index',fontsize=26) 80 | corr_plot1[0, 0].set_ylabel('Autocorrelation Value',fontsize=26) 81 | plt.savefig("plots/autocorrelation.png", bbox_inches='tight', pad_inches=0.01) 82 | 83 | plt.figure(figsize=(12,6.5)) 84 | for i in range(data_chain1.shape[1]): 85 | gw_plot = pm3.geweke(data_chain1[:,i],.1,.5,20) 86 | plt.scatter(gw_plot[:,0],gw_plot[:,1],label="%s"%names[i]) 87 | plt.axhline(-1.98, c='r') 88 | plt.axhline(1.98, c='r') 89 | plt.xticks(fontsize=22) 90 | plt.yticks(fontsize=22) 91 | plt.xlabel("Subchain sample number",fontsize=26) 92 | plt.ylabel("Geweke z-score",fontsize=26) 93 | plt.title('Geweke Plot Comparing first 10$\%$ and Slices of the Last 50$\%$ of Chain') 94 | 95 | plt.legend(bbox_to_anchor=(1.05, 1.0), loc='upper left') 96 | plt.tight_layout() 97 | plt.show() 98 | plt.savefig("plots/geweke.png", pad_inches=0.01) 99 | 100 | -------------------------------------------------------------------------------- /Predator-prey/PP_sindy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Sat Feb 6 00:03:19 2021 5 | 6 | @author: mohamedazizbhouri 7 | """ 8 | 9 | ################################################################################ 10 | 11 | import numpy as onp 12 | import matplotlib.pyplot as plt 13 | 14 | import jax.numpy as np 15 | import jax.random as random 16 | from jax.experimental.ode import odeint as odeint_jax 17 | from jax.config import config 18 | config.update("jax_enable_x64", True) 19 | 20 | import pysindy as ps 21 | 22 | onp.random.seed(1234) 23 | 24 | plt.rcParams.update(plt.rcParamsDefault) 25 | plt.rc('font', family='serif') 26 | plt.rcParams.update({'font.size': 16, 27 | 'lines.linewidth': 2, 28 | 'axes.labelsize': 20, # fontsize for x and y labels 29 | 'axes.titlesize': 20, 30 | 'xtick.labelsize': 16, 31 | 'ytick.labelsize': 16, 32 
| 'legend.fontsize': 20, 33 | 'axes.linewidth': 2, 34 | "pgf.texsystem": "pdflatex", # change this if using xetex or lautex 35 | "text.usetex": True, # use LaTeX to write all text 36 | }) 37 | dpiv = 100 38 | 39 | ################################################################################ 40 | 41 | def LV(x, t, alpha, beta, gamma, delta): 42 | x1, x2 = x 43 | dxdt = [alpha*x1+beta*x1*x2, delta*x1*x2+gamma*x2] 44 | return dxdt 45 | 46 | key = random.PRNGKey(123) 47 | D = 2 48 | alpha = 1.0 49 | beta = -0.1 50 | gamma = -1.5 51 | delta = 0.75 52 | IC = 5.0 53 | 54 | noise = 0.1 55 | N_fine = 1100 56 | Tf = 16.5 57 | Tf_test = 30 58 | x0_onp = [5.0, IC] 59 | x0 = np.array(x0_onp) 60 | 61 | # Test data 62 | t_star = np.linspace(0, Tf_test, 2*N_fine+1) 63 | t_grid_test = onp.array(t_star) 64 | data_test = onp.array( odeint_jax(LV, x0, t_grid_test, alpha, beta, gamma, delta) ) 65 | 66 | library_functions = [ 67 | lambda x : x, 68 | lambda x,y : x*y, 69 | lambda x : x**2, 70 | lambda x : x**3 71 | ] 72 | 73 | library_function_names = [ 74 | lambda x : x, 75 | lambda x,y : x + '.' 
+ y, 76 | lambda x : x + '^2', 77 | lambda x : x + '^3' 78 | ] 79 | custom_library = ps.CustomLibrary( 80 | library_functions=library_functions, function_names=library_function_names 81 | ) 82 | 83 | ################################################################################ 84 | 85 | N = 55 86 | N_fine = 1100 87 | 88 | # Training data 89 | t_fine = np.linspace(0, Tf, N_fine+1) 90 | X_fine = odeint_jax(LV, x0, t_fine, alpha, beta, gamma, delta) 91 | X_fine_noise = X_fine + noise*X_fine.std(0)*random.normal(key, X_fine.shape) 92 | t = t_fine[onp.array( list(range(0, N_fine+1, N_fine//N)) )] 93 | X_train = X_fine_noise[list(range(0, N_fine+1, N_fine//N)),:] 94 | 95 | gap = 4 96 | ind_t = np.array([0]) 97 | ind_t = np.concatenate([ind_t[:,None],np.arange(gap+1,N+1)[:,None]]) 98 | ind_t = ind_t[:,0] 99 | 100 | t_grid = onp.array(t[ind_t]) 101 | print('case_1_as_GP_NODE',t_grid,t_grid.shape) 102 | model_GP_ODE = ps.SINDy(feature_library=custom_library) 103 | model_GP_ODE.fit(onp.array(X_train[ind_t,:]), t=t_grid) 104 | print('case_1_as_GP_NODE:') 105 | model_GP_ODE.print() 106 | x_test_sim = model_GP_ODE.simulate(x0_onp, t_grid_test) 107 | 108 | plt.figure(figsize=(12,6.5)) 109 | plt.xticks(fontsize=22) 110 | plt.yticks(fontsize=22) 111 | plt.plot(t_grid_test, data_test[:,0],'r-', label = "True trajectory of $x_1(t)$") 112 | plt.plot(t_grid, X_train[ind_t,0], 'ro', label = "Training data of $x_1(t)$") 113 | plt.plot(t_grid_test, x_test_sim[:,0],'g--', label = "SINDy prediction of $x_1(t)$") 114 | plt.xlabel('$t$',fontsize=26) 115 | plt.ylabel('$x_1(t)$',fontsize=26) 116 | plt.ylim((0.0, 8.0)) 117 | plt.legend(loc='upper right', frameon=False, prop={'size': 20}) 118 | plt.savefig('plots_sindy/case_1_as_GP_NODE_x1.png', dpi = dpiv) 119 | 120 | plt.figure(figsize=(12,6.5)) 121 | plt.xticks(fontsize=22) 122 | plt.yticks(fontsize=22) 123 | plt.plot(t_grid_test, data_test[:,1],'r-', label = "True trajectory of $x_2(t)$") 124 | plt.plot(t_grid, X_train[ind_t,1], 'ro', 
label = "Training data of $x_2(t)$") 125 | plt.plot(t_grid_test, x_test_sim[:,1],'g--', label = "SINDy prediction of $x_2(t)$") 126 | plt.xlabel('$t$',fontsize=26) 127 | plt.ylabel('$x_2(t)$',fontsize=26) 128 | plt.ylim((0.0, 45.0)) 129 | plt.legend(loc='upper right', frameon=False, prop={'size': 20}) 130 | plt.savefig('plots_sindy/case_1_as_GP_NODE_x2.png', dpi = dpiv) 131 | 132 | ################################################################################ 133 | 134 | N = 7040 # dt = 0.00234375 135 | N_fine = 35200 136 | 137 | # Training data 138 | t_fine = np.linspace(0, Tf, N_fine+1) 139 | X_fine = odeint_jax(LV, x0, t_fine, alpha, beta, gamma, delta) 140 | X_fine_noise = X_fine + noise*X_fine.std(0)*random.normal(key, X_fine.shape) 141 | t = t_fine[onp.array( list(range(0, N_fine+1, N_fine//N)) )] 142 | X_train = X_fine_noise[list(range(0, N_fine+1, N_fine//N)),:] 143 | 144 | gap = 639 145 | ind_t = np.array([0]) 146 | ind_t = np.concatenate([ind_t[:,None],np.arange(gap+1,N+1)[:,None]]) 147 | ind_t = ind_t[:,0] 148 | 149 | t_grid = onp.array(t[ind_t]) 150 | print('case_2_fine_dt',t_grid,t_grid.shape) 151 | model_GP_ODE_fine_dt = ps.SINDy(feature_library=custom_library) 152 | model_GP_ODE_fine_dt.fit(onp.array(X_train[ind_t,:]), t=t_grid) 153 | print('case_2_fine_dt:') 154 | model_GP_ODE_fine_dt.print() 155 | x_test_sim = model_GP_ODE_fine_dt.simulate(x0_onp, t_grid_test) 156 | 157 | plt.figure(figsize=(12,6.5)) 158 | plt.xticks(fontsize=22) 159 | plt.yticks(fontsize=22) 160 | plt.plot(t_grid_test, data_test[:,0],'r-', label = "True trajectory of $x_1(t)$") 161 | plt.plot(t_grid, X_train[ind_t,0], 'ro', label = "Training data of $x_1(t)$") 162 | plt.plot(t_grid_test, x_test_sim[:,0],'g--', label = "SINDy prediction of $x_1(t)$") 163 | plt.xlabel('$t$',fontsize=26) 164 | plt.ylabel('$x_1(t)$',fontsize=26) 165 | plt.ylim((0.0, 8.0)) 166 | plt.legend(loc='upper right', frameon=False, prop={'size': 20}) 167 | plt.savefig('plots_sindy/case_2_fine_dt_x1.png', dpi 
= dpiv) 168 | 169 | plt.figure(figsize=(12,6.5)) 170 | plt.xticks(fontsize=22) 171 | plt.yticks(fontsize=22) 172 | plt.plot(t_grid_test, data_test[:,1],'r-', label = "True trajectory of $x_2(t)$") 173 | plt.plot(t_grid, X_train[ind_t,1], 'ro', label = "Training data of $x_2(t)$") 174 | plt.plot(t_grid_test, x_test_sim[:,1],'g--', label = "SINDy prediction of $x_2(t)$") 175 | plt.xlabel('$t$',fontsize=26) 176 | plt.ylabel('$x_2(t)$',fontsize=26) 177 | plt.ylim((0.0, 45.0)) 178 | plt.legend(loc='upper right', frameon=False, prop={'size': 20}) 179 | plt.savefig('plots_sindy/case_2_fine_dt_x2.png', dpi = dpiv) 180 | 181 | ################################################################################ 182 | 183 | N = 7040 # dt = 0.00234375 184 | N_fine = 35200 185 | 186 | # Training data 187 | t_fine = np.linspace(0, Tf, N_fine+1) 188 | X_fine = odeint_jax(LV, x0, t_fine, alpha, beta, gamma, delta) 189 | X_fine_noise = X_fine + noise*X_fine.std(0)*random.normal(key, X_fine.shape) 190 | 191 | t = t_fine[onp.array( list(range(0, N_fine+1, N_fine//N)) )] 192 | X_train = X_fine_noise[list(range(0, N_fine+1, N_fine//N)),:] 193 | 194 | gap = 639 195 | ind_t = np.array([0]) 196 | ind_t = np.concatenate([ind_t[:,None],np.arange(gap+1,N+1)[:,None]]) 197 | ind_t = ind_t[:,0] 198 | 199 | t_grid = onp.array(t[ind_t]) 200 | print('case_3_fine_dt_no_noise',t_grid,t_grid.shape) 201 | model_GP_ODE_fine_dt_no_noise = ps.SINDy(feature_library=custom_library) 202 | model_GP_ODE_fine_dt_no_noise.fit(onp.array(X_train[ind_t,:]), t=t_grid) 203 | print('case_3_fine_dt_no_noise:') 204 | model_GP_ODE_fine_dt_no_noise.print() 205 | x_test_sim = model_GP_ODE_fine_dt_no_noise.simulate(x0_onp, t_grid_test) 206 | 207 | plt.figure(figsize=(12,6.5)) 208 | plt.xticks(fontsize=22) 209 | plt.yticks(fontsize=22) 210 | plt.plot(t_grid_test, data_test[:,0],'r-', label = "True trajectory of $x_1(t)$") 211 | plt.plot(t_grid, X_train[ind_t,0], 'ro', label = "Training data of $x_1(t)$") 212 | 
plt.plot(t_grid_test, x_test_sim[:,0],'g--', label = "SINDy prediction of $x_1(t)$") 213 | plt.xlabel('$t$',fontsize=26) 214 | plt.ylabel('$x_1(t)$',fontsize=26) 215 | plt.ylim((0.0, 8.0)) 216 | plt.legend(loc='upper right', frameon=False, prop={'size': 20}) 217 | plt.savefig('plots_sindy/case_3_fine_dt_no_noise_x1.png', dpi = dpiv) 218 | 219 | plt.figure(figsize=(12,6.5)) 220 | plt.xticks(fontsize=22) 221 | plt.yticks(fontsize=22) 222 | plt.plot(t_grid_test, data_test[:,1],'r-', label = "True trajectory of $x_2(t)$") 223 | plt.plot(t_grid, X_train[ind_t,1], 'ro', label = "Training data of $x_2(t)$") 224 | plt.plot(t_grid_test, x_test_sim[:,1],'g--', label = "SINDy prediction of $x_2(t)$") 225 | plt.xlabel('$t$',fontsize=26) 226 | plt.ylabel('$x_2(t)$',fontsize=26) 227 | plt.ylim((0.0, 45.0)) 228 | plt.legend(loc='upper right', frameon=False, prop={'size': 20}) 229 | plt.savefig('plots_sindy/case_3_fine_dt_no_noise_x2.png', dpi = dpiv) 230 | 231 | ################################################################################ 232 | 233 | N = 110 234 | t_grid = onp.linspace(0, Tf, N//2+1) 235 | data = odeint_jax(LV, x0, t_grid, alpha, beta, gamma, delta) 236 | data = onp.array( data + noise*data.std(0)*random.normal(key, data.shape) ) 237 | 238 | print('case_4_no_t_gap_large_dt',t_grid,t_grid.shape) 239 | model_no_t_gap_large_dt = ps.SINDy(feature_library=custom_library) 240 | model_no_t_gap_large_dt.fit(data, t=t_grid) # data Nt x D 241 | print('case_4_no_t_gap_large_dt:') 242 | model_no_t_gap_large_dt.print() 243 | x_test_sim = model_no_t_gap_large_dt.simulate(x0_onp, t_grid_test) 244 | 245 | plt.figure(figsize=(12,6.5)) 246 | plt.xticks(fontsize=22) 247 | plt.yticks(fontsize=22) 248 | plt.plot(t_grid_test, data_test[:,0],'r-', label = "True trajectory of $x_1(t)$") 249 | plt.plot(t_grid, data[:,0], 'ro', label = "Training data of $x_1(t)$") 250 | plt.plot(t_grid_test, x_test_sim[:,0],'g--', label = "SINDy prediction of $x_1(t)$") 251 | 
plt.xlabel('$t$',fontsize=26) 252 | plt.ylabel('$x_1(t)$',fontsize=26) 253 | plt.ylim((0.0, 8.0)) 254 | plt.legend(loc='upper right', frameon=False, prop={'size': 20}) 255 | plt.savefig('plots_sindy/case_4_no_t_gap_large_dt_x1.png', dpi = dpiv) 256 | 257 | plt.figure(figsize=(12,6.5)) 258 | plt.xticks(fontsize=22) 259 | plt.yticks(fontsize=22) 260 | plt.plot(t_grid_test, data_test[:,1],'r-', label = "True trajectory of $x_2(t)$") 261 | plt.plot(t_grid, data[:,1], 'ro', label = "Training data of $x_2(t)$") 262 | plt.plot(t_grid_test, x_test_sim[:,1],'g--', label = "SINDy prediction of $x_2(t)$") 263 | plt.xlabel('$t$',fontsize=26) 264 | plt.ylabel('$x_2(t)$',fontsize=26) 265 | plt.ylim((0.0, 45.0)) 266 | plt.legend(loc='upper right', frameon=False, prop={'size': 20}) 267 | plt.savefig('plots_sindy/case_4_no_t_gap_large_dt_x2.png', dpi = dpiv) 268 | 269 | ################################################################################ 270 | ################################################################################ 271 | ################################################################################ 272 | ################################################################################ 273 | 274 | N = 110 275 | t_grid = onp.linspace(0, Tf, 4*N+1) # N//2+1) # 441 dt = 0.0375 276 | data = odeint_jax(LV, x0, t_grid, alpha, beta, gamma, delta) 277 | data = onp.array( data + noise*data.std(0)*random.normal(key, data.shape) ) 278 | 279 | print('case_5_no_t_gap_small_dt',t_grid,t_grid.shape) 280 | model_no_t_gap_small_dt = ps.SINDy(feature_library=custom_library) 281 | model_no_t_gap_small_dt.fit(data, t=t_grid) # data Nt x D 282 | print('case_5_no_t_gap_small_dt:') 283 | model_no_t_gap_small_dt.print() 284 | x_test_sim = model_no_t_gap_small_dt.simulate(x0_onp, t_grid_test) 285 | 286 | plt.figure(figsize=(12,6.5)) 287 | plt.xticks(fontsize=22) 288 | plt.yticks(fontsize=22) 289 | plt.plot(t_grid_test, data_test[:,0],'r-', label = "True trajectory of $x_1(t)$") 290 | 
plt.plot(t_grid, data[:,0], 'ro', label = "Training data of $x_1(t)$") 291 | plt.plot(t_grid_test, x_test_sim[:,0],'g--', label = "SINDy prediction of $x_1(t)$") 292 | plt.xlabel('$t$',fontsize=26) 293 | plt.ylabel('$x_1(t)$',fontsize=26) 294 | plt.ylim((0.0, 8.)) 295 | plt.legend(loc='upper right', frameon=False, prop={'size': 20}) 296 | plt.savefig('plots_sindy/case_5_no_t_gap_small_dt_x1.png', dpi = dpiv) 297 | 298 | plt.figure(figsize=(12,6.5)) 299 | plt.xticks(fontsize=22) 300 | plt.yticks(fontsize=22) 301 | plt.plot(t_grid_test, data_test[:,1],'r-', label = "True trajectory of $x_2(t)$") 302 | plt.plot(t_grid, data[:,1], 'ro', label = "Training data of $x_2(t)$") 303 | plt.plot(t_grid_test, x_test_sim[:,1],'g--', label = "SINDy prediction of $x_2(t)$") 304 | plt.xlabel('$t$',fontsize=26) 305 | plt.ylabel('$x_2(t)$',fontsize=26) 306 | plt.ylim((0.0, 45.0)) 307 | plt.legend(loc='upper right', frameon=False, prop={'size': 20}) 308 | plt.savefig('plots_sindy/case_5_no_t_gap_small_dt_x2.png', dpi = dpiv) 309 | -------------------------------------------------------------------------------- /Predator-prey/data/IC.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Predator-prey/data/IC.npy -------------------------------------------------------------------------------- /Predator-prey/data/IC_2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Predator-prey/data/IC_2.npy -------------------------------------------------------------------------------- /Predator-prey/data/W.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Predator-prey/data/W.npy -------------------------------------------------------------------------------- /Predator-prey/data/W_2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Predator-prey/data/W_2.npy -------------------------------------------------------------------------------- /Predator-prey/data/hyp.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Predator-prey/data/hyp.npy -------------------------------------------------------------------------------- /Predator-prey/data/hyp_2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Predator-prey/data/hyp_2.npy -------------------------------------------------------------------------------- /Predator-prey/data/noise.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Predator-prey/data/noise.npy -------------------------------------------------------------------------------- /Predator-prey/data/noise_2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Predator-prey/data/noise_2.npy -------------------------------------------------------------------------------- /Predator-prey/data/par.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Predator-prey/data/par.npy -------------------------------------------------------------------------------- /Predator-prey/data/par_2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Predator-prey/data/par_2.npy -------------------------------------------------------------------------------- /Predator-prey/plots/autocorrelation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Predator-prey/plots/autocorrelation.png -------------------------------------------------------------------------------- /Predator-prey/plots/box_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Predator-prey/plots/box_plot.png -------------------------------------------------------------------------------- /Predator-prey/plots/box_plot_x0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Predator-prey/plots/box_plot_x0.png -------------------------------------------------------------------------------- /Predator-prey/plots/geweke.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Predator-prey/plots/geweke.png -------------------------------------------------------------------------------- /Predator-prey/plots/x_1.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Predator-prey/plots/x_1.png -------------------------------------------------------------------------------- /Predator-prey/plots/x_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Predator-prey/plots/x_2.png -------------------------------------------------------------------------------- /Predator-prey/plots_sindy/case_1_as_GP_NODE_x1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Predator-prey/plots_sindy/case_1_as_GP_NODE_x1.png -------------------------------------------------------------------------------- /Predator-prey/plots_sindy/case_1_as_GP_NODE_x2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Predator-prey/plots_sindy/case_1_as_GP_NODE_x2.png -------------------------------------------------------------------------------- /Predator-prey/plots_sindy/case_2_fine_dt_x1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Predator-prey/plots_sindy/case_2_fine_dt_x1.png -------------------------------------------------------------------------------- /Predator-prey/plots_sindy/case_2_fine_dt_x2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Predator-prey/plots_sindy/case_2_fine_dt_x2.png -------------------------------------------------------------------------------- /Predator-prey/plots_sindy/case_3_fine_dt_no_noise_x1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Predator-prey/plots_sindy/case_3_fine_dt_no_noise_x1.png -------------------------------------------------------------------------------- /Predator-prey/plots_sindy/case_3_fine_dt_no_noise_x2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Predator-prey/plots_sindy/case_3_fine_dt_no_noise_x2.png -------------------------------------------------------------------------------- /Predator-prey/plots_sindy/case_4_no_t_gap_large_dt_x1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Predator-prey/plots_sindy/case_4_no_t_gap_large_dt_x1.png -------------------------------------------------------------------------------- /Predator-prey/plots_sindy/case_4_no_t_gap_large_dt_x2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Predator-prey/plots_sindy/case_4_no_t_gap_large_dt_x2.png -------------------------------------------------------------------------------- /Predator-prey/plots_sindy/case_5_no_t_gap_small_dt_x1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Predator-prey/plots_sindy/case_5_no_t_gap_small_dt_x1.png -------------------------------------------------------------------------------- /Predator-prey/plots_sindy/case_5_no_t_gap_small_dt_x2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Predator-prey/plots_sindy/case_5_no_t_gap_small_dt_x2.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Gaussian processes meet NeuralODEs 2 | 3 | Code and data accompanying the manuscript titled "Gaussian processes meet NeuralODEs: A Bayesian framework for learning the dynamics of partially observed systems from scarce and noisy data", authored by Mohamed Aziz Bhouri and Paris Perdikaris. 4 | 5 | ## Abstract 6 | 7 | This paper presents a machine learning framework for Bayesian systems identification from partial, noisy and irregular observations of nonlinear dynamical systems. The proposed method takes advantage of recent developments in differentiable programming to propagate gradient information through ordinary differential equation solvers and perform Bayesian inference with respect to unknown model parameters using Markov Chain Monte Carlo and Gaussian Process priors over the observed system states. This allows us to exploit temporal correlations in the observed data, and efficiently infer posterior distributions over plausible models with quantified uncertainty. Moreover, the use of sparsity-promoting priors such as the Finnish Horseshoe for free model parameters enables the discovery of interpretable and parsimonious representations for the underlying latent dynamics. 
A series of numerical studies is presented to demonstrate the effectiveness of the proposed methods including predator-prey systems, systems biology, and a 50-dimensional human motion dynamical system. Taken together, our findings put forth a novel, flexible and robust workflow for data-driven model discovery under uncertainty. 8 | 9 | ## Citation 10 | 11 | @article{Bhouri2022GPNode, 12 | author = {Bhouri, Mohamed Aziz and Perdikaris, Paris }, 13 | title = {Gaussian processes meet NeuralODEs: a Bayesian framework for learning the dynamics of partially observed systems from scarce and noisy data}, 14 | journal = {Philosophical Transactions of the Royal Society A: Mathematical, Physical and Engineering Sciences}, 15 | volume = {380}, 16 | number = {2229}, 17 | pages = {20210201}, 18 | year = {2022}, 19 | doi = {10.1098/rsta.2021.0201}, 20 | URL = {https://royalsocietypublishing.org/doi/abs/10.1098/rsta.2021.0201}, 21 | eprint = {https://royalsocietypublishing.org/doi/pdf/10.1098/rsta.2021.0201} 22 | } 23 | -------------------------------------------------------------------------------- /Yeast-Glycolysis/YG_GP_NODE.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Tue Sep 29 15:42:18 2020 5 | 6 | @author: mohamedazizbhouri 7 | """ 8 | 9 | import math 10 | 11 | import jax.numpy as np 12 | import jax.random as random 13 | from jax import vmap, jit 14 | from jax.experimental.ode import odeint 15 | from jax.config import config 16 | config.update("jax_enable_x64", True) 17 | 18 | from numpyro import sample 19 | import numpyro.distributions as dist 20 | from numpyro.infer import MCMC, NUTS 21 | 22 | import numpy as onp 23 | import matplotlib.pyplot as plt 24 | from functools import partial 25 | import time 26 | 27 | class ODE_GP: 28 | # Initialize the class 29 | def __init__(self, t, i_t, X_train_ff, x0, dxdt, ind): 30 | # Normalization 31 | self.t = t 32 | 
self.x0 = x0 33 | self.i_t = i_t 34 | self.dxdt = dxdt 35 | self.jitter = 1e-8 36 | self.ind = ind 37 | 38 | self.max_t = t.max(0) 39 | 40 | self.max_X = [] 41 | self.X = [] 42 | self.N = [] 43 | self.D = len(i_t) 44 | self.t_t = [] 45 | for i in range(len(i_t)): 46 | self.max_X.append(np.abs(X_train_ff[i]).max(0)) 47 | self.X.append(X_train_ff[i] / self.max_X[i]) 48 | self.N.append(X_train_ff[i].shape[0]) 49 | self.t_t.append(t[self.i_t[i]]/self.max_t) 50 | 51 | @partial(jit, static_argnums=(0,)) 52 | def RBF(self,x1, x2, params): 53 | diffs = (x1 / params).T - x2 / params 54 | return np.exp(-0.5 * diffs**2) 55 | 56 | def model(self, t, X): 57 | 58 | noise = sample('noise', dist.LogNormal(0.0, 1.0), sample_shape=(self.D,)) 59 | hyp = sample('hyp', dist.Gamma(1.0, 0.5), sample_shape=(self.D,)) 60 | W = sample('W', dist.LogNormal(0.0, 1.0), sample_shape=(self.D,)) 61 | 62 | J0 = sample('J0', dist.Uniform(1.0, 10.0)) # 2.5 63 | k1 = sample('k1', dist.Uniform(80., 120.0)) # 100. 64 | k2 = sample('k2', dist.Uniform(1., 10.0)) # 6. 65 | k3 = sample('k3', dist.Uniform(2., 20.0)) # 16. 66 | k4 = sample('k4', dist.Uniform(80., 120.0)) # 100. 67 | k5 = sample('k5', dist.Uniform(0.1, 2.0)) # 1.28 68 | k6 = sample('k6', dist.Uniform(2., 20.0)) # 12. 69 | k = sample('k', dist.Uniform(0.1, 2.0)) # 1.8 70 | ka = sample('ka', dist.Uniform(2., 20.0)) # 13. 71 | q = sample('q', dist.Uniform(1., 10.0)) # 4. 72 | KI = sample('KI', dist.Uniform(0.1, 2.0)) # 0.52 73 | phi = sample('phi', dist.Uniform(0.05, 1.0)) # 0.1 74 | Np = sample('Np', dist.Uniform(0.1, 2.0)) # 1. 75 | A = sample('A', dist.Uniform(1., 10.0)) #4. 
76 | 77 | IC = sample('IC', dist.Uniform(0, 1)) 78 | 79 | # compute kernel 80 | K_11 = W[0]*self.RBF(self.t_t[0], self.t_t[0], hyp[0]) + np.eye(self.N[0])*(noise[0] + self.jitter) 81 | K_22 = W[1]*self.RBF(self.t_t[1], self.t_t[1], hyp[1]) + np.eye(self.N[1])*(noise[1] + self.jitter) 82 | K_33 = W[2]*self.RBF(self.t_t[2], self.t_t[2], hyp[2]) + np.eye(self.N[2])*(noise[2] + self.jitter) 83 | K = np.concatenate([np.concatenate([K_11, np.zeros((self.N[0], self.N[1])), np.zeros((self.N[0], self.N[2]))], axis = 1), 84 | np.concatenate([np.zeros((self.N[1], self.N[0])), K_22, np.zeros((self.N[1], self.N[2]))], axis = 1), 85 | np.concatenate([np.zeros((self.N[2], self.N[0])), np.zeros((self.N[2], self.N[1])), K_33], axis = 1)], axis = 0) 86 | 87 | # compute mean 88 | x0 = np.array([0.5, 1.9, 0.18, 0.15, IC, 0.1, 0.064]) 89 | mut = odeint(self.dxdt, x0, self.t.flatten(), J0, k1, k2, k3, k4, k5, k6, k, ka, q, KI, phi, Np, A) 90 | mu1 = mut[self.i_t[0],ind[0]] / self.max_X[0] 91 | mu2 = mut[self.i_t[1],ind[1]] / self.max_X[1] 92 | mu3 = mut[self.i_t[2],ind[2]] / self.max_X[2] 93 | mu = np.concatenate((mu1,mu2,mu3),axis=0) 94 | 95 | # Concat data 96 | mu = mu.flatten('F') 97 | X = np.concatenate((self.X[0],self.X[1],self.X[2]),axis=0) 98 | X = X.flatten('F') 99 | 100 | # sample X according to the standard gaussian process formula 101 | sample("X", dist.MultivariateNormal(loc=mu, covariance_matrix=K), obs=X) 102 | 103 | # helper function for doing hmc inference 104 | def train(self, settings, rng_key): 105 | start = time.time() 106 | kernel = NUTS(self.model, 107 | target_accept_prob = settings['target_accept_prob']) 108 | mcmc = MCMC(kernel, 109 | num_warmup = settings['num_warmup'], 110 | num_samples = settings['num_samples'], 111 | num_chains = settings['num_chains'], 112 | progress_bar=True, 113 | jit_model_args=True) 114 | mcmc.run(rng_key, self.t, self.X) 115 | mcmc.print_summary() 116 | elapsed = time.time() - start 117 | print('\nMCMC elapsed time: %.2f seconds' % 
(elapsed)) 118 | return mcmc.get_samples() 119 | 120 | @partial(jit, static_argnums=(0,)) 121 | def predict(self, t_star, J0, k1, k2, k3, k4, k5, k6, k, ka, q, KI, phi, Np, A, N20): 122 | x0_l = np.array([0.5, 1.9, 0.18, 0.15, N20, 0.1, 0.064]) 123 | X = odeint(self.dxdt, x0_l, t_star, J0, k1, k2, k3, k4, k5, k6, k, ka, q, KI, phi, Np, A) 124 | return X 125 | 126 | plt.rcParams.update(plt.rcParamsDefault) 127 | plt.rc('font', family='serif') 128 | plt.rcParams.update({'font.size': 16, 129 | 'lines.linewidth': 2, 130 | 'axes.labelsize': 20, 131 | 'axes.titlesize': 20, 132 | 'xtick.labelsize': 16, 133 | 'ytick.labelsize': 16, 134 | 'legend.fontsize': 20, 135 | 'axes.linewidth': 2, 136 | "pgf.texsystem": "pdflatex", # change this if using xetex or lautex 137 | "text.usetex": True, # use LaTeX to write all text 138 | }) 139 | 140 | def glyc(x, t, J0, k1, k2, k3, k4, k5, k6, k, ka, q, KI, phi, Np, A): 141 | S1, S2, S3, S4, N2, A3, S4ex = x 142 | J = ka*(S4-S4ex) 143 | N1 = Np-N2 144 | A2 = A-A3 145 | v1 = k1*S1*A3/(1+(A3/KI)**q) 146 | v2 = k2*S2*N1 147 | v3 = k3*S3*A2 148 | v4 = k4*S4*N2 149 | v5 = k5*A3 150 | v6 = k6*S2*N2 151 | v7 = k*S4ex 152 | dxdt = [J0-v1, 2*v1-v2-v6, v2-v3, v3-v4-J, v2-v4-v6, -2*v1+2*v3-v5, phi*J-v7] 153 | return dxdt 154 | 155 | key = random.PRNGKey(1234) 156 | D = 7 157 | J0 = 2.5 158 | k1 = 100. 159 | k2 = 6. 160 | k3 = 16. 161 | k4 = 100. 162 | k5 = 1.28 163 | k6 = 12. 164 | k = 1.8 165 | ka = 13. 166 | q = 4. 167 | KI = 0.52 168 | phi = 0.1 169 | Np = 1. 170 | A = 4. 
171 | IC = 0.16 172 | 173 | noise = 0.1 174 | N = 120 175 | N_fine = 1200 176 | Tf = 3 177 | Tf_test = 6 178 | x0 = np.array([0.5, 1.9, 0.18, 0.15, IC, 0.1, 0.064]) 179 | 180 | # Training data 181 | t_fine = np.linspace(0, Tf, N_fine+1) 182 | X_fine = odeint(glyc, x0, t_fine, J0, k1, k2, k3, k4, k5, k6, k, ka, q, KI, phi, Np, A) 183 | X_fine_noise = X_fine + noise*X_fine.std(0)*random.normal(key, X_fine.shape) 184 | 185 | t = t_fine[[list(range(0, N_fine+1, N_fine//N))]] 186 | X_train = X_fine_noise[list(range(0, N_fine+1, N_fine//N)),:] 187 | 188 | # Test data 189 | t_star = np.linspace(0, Tf_test, 2*N_fine+1) 190 | X_star = odeint(glyc, x0, t_star, J0, k1, k2, k3, k4, k5, k6, k, ka, q, KI, phi, Np, A) 191 | 192 | gap = 19 193 | ind_t = np.array([0]) 194 | ind_t = np.concatenate([ind_t[:,None],np.arange(gap+1,N+1)[:,None]]) 195 | ind_t = ind_t[:,0] 196 | 197 | j1 = list(range(0, ind_t.shape[0]+1, 2)) 198 | j2 = list([0]) + list( range(1, ind_t.shape[0]+1, 2) ) 199 | 200 | i1 = ind_t[j1] 201 | i1 = i1[1:] 202 | i2 = ind_t[j2] 203 | i3 = ind_t[j1] 204 | 205 | i_t = [] 206 | i_t.append(i1) 207 | i_t.append(i2) 208 | i_t.append(i3) 209 | 210 | ind = [4,5,6] 211 | 212 | X1_train = X_train[i_t[0],ind[0]] 213 | X2_train = X_train[i_t[1],ind[1]] 214 | X3_train = X_train[i_t[2],ind[2]] 215 | 216 | X_train_ff = [] 217 | X_train_ff.append(X1_train) 218 | X_train_ff.append(X2_train) 219 | X_train_ff.append(X3_train) 220 | 221 | model = ODE_GP(t[:,None], i_t, X_train_ff, x0, glyc, ind) 222 | rng_key_train, rng_key_predict = random.split(random.PRNGKey(0)) 223 | 224 | num_warmup = 4000 225 | num_samples = 8000 226 | num_chains = 1 227 | target_accept_prob = 0.85 228 | settings = {'num_warmup': num_warmup, 229 | 'num_samples': num_samples, 230 | 'num_chains': num_chains, 231 | 'target_accept_prob': target_accept_prob} 232 | samples = model.train(settings, rng_key_train) 233 | print('True values: J0 = %f, k1 = %f, k2 = %f, k3 = %f, k4 = %f, k5 = %f, k6 = %f, k = %f, ka = %f, q = 
%f, KI = %f, phi = %f, Np = %f, A = %f, IC = %f' % (J0, k1, k2, k3, k4, k5, k6, k, ka, q, KI, phi, Np, A, IC)) 234 | 235 | vmap_args = (samples['J0'], samples['k1'], samples['k2'], samples['k3'], samples['k4'], samples['k5'], samples['k6'], samples['k'], samples['ka'], samples['q'], samples['KI'], samples['phi'], samples['Np'], samples['A'], samples['IC']) 236 | 237 | np.save('data/par_and_IC',np.array(vmap_args)) 238 | np.save('data/noise',np.array(samples['noise'])) 239 | np.save('data/hyp',np.array(samples['hyp'])) 240 | np.save('data/W',np.array(samples['W'])) 241 | 242 | def RBF(x1, x2, params): 243 | diffs = (x1 / params).T - x2 / params 244 | return np.exp(-0.5 * diffs**2) 245 | 246 | Nt = N+1 247 | N_fine = 100 248 | t_test = np.linspace(0, Tf_test, 2*N_fine+1) 249 | Nt_test = t_test.shape[0] 250 | 251 | t_tr = t[:,None] /model.max_t 252 | t_te = t_test[:,None] /model.max_t 253 | 254 | pred_X_tr_i = lambda J0, k1, k2, k3, k4, k5, k6, k, ka, q, KI, phi, Np, A, N20: model.predict(t, J0, k1, k2, k3, k4, k5, k6, k, ka, q, KI, phi, Np, A, N20) 255 | X_tr_i = vmap(pred_X_tr_i)(*vmap_args) 256 | 257 | pred_X_ode_i = lambda J0, k1, k2, k3, k4, k5, k6, k, ka, q, KI, phi, Np, A, N20: model.predict(t_test, J0, k1, k2, k3, k4, k5, k6, k, ka, q, KI, phi, Np, A, N20) 258 | X_ode_i = vmap(pred_X_ode_i)(*vmap_args) 259 | 260 | X_pred_GP = [] 261 | Npred_GP_f = 0 262 | for i in range(num_samples): 263 | if i % 500 == 0: 264 | print(i) 265 | K1_tr = samples['W'][i,0]*RBF(model.t_t[0], model.t_t[0], samples['hyp'][i,0]) + np.eye(model.N[0])*(samples['noise'][i,0] + model.jitter) 266 | K2_tr = samples['W'][i,1]*RBF(model.t_t[1], model.t_t[1], samples['hyp'][i,1]) + np.eye(model.N[1])*(samples['noise'][i,1] + model.jitter) 267 | K3_tr = samples['W'][i,2]*RBF(model.t_t[2], model.t_t[2], samples['hyp'][i,2]) + np.eye(model.N[2])*(samples['noise'][i,2] + model.jitter) 268 | K_tr = np.concatenate([np.concatenate([K1_tr, np.zeros((model.N[0], model.N[1])), np.zeros((model.N[0], 
model.N[2]))], axis = 1), 269 | np.concatenate([np.zeros((model.N[1], model.N[0])), K2_tr, np.zeros((model.N[1], model.N[2]))], axis = 1), 270 | np.concatenate([np.zeros((model.N[2], model.N[0])), np.zeros((model.N[2], model.N[1])), K3_tr], axis = 1)], axis = 0) 271 | K1_trte = samples['W'][i,0]*RBF(t_te, model.t_t[0], samples['hyp'][i,0]) 272 | K2_trte = samples['W'][i,1]*RBF(t_te, model.t_t[1], samples['hyp'][i,1]) 273 | K3_trte = samples['W'][i,2]*RBF(t_te, model.t_t[2], samples['hyp'][i,2]) 274 | K_trte = np.concatenate([np.concatenate([K1_trte, np.zeros((model.N[0],Nt_test)), np.zeros((model.N[0],Nt_test))], axis = 1), 275 | np.concatenate([np.zeros((model.N[1],Nt_test)), K2_trte, np.zeros((model.N[1],Nt_test))], axis = 1), 276 | np.concatenate([np.zeros((model.N[2],Nt_test)), np.zeros((model.N[2],Nt_test)), K3_trte], axis = 1)], axis = 0) 277 | K1_te = samples['W'][i,0]*RBF(t_te, t_te, samples['hyp'][i,0]) 278 | K2_te = samples['W'][i,1]*RBF(t_te, t_te, samples['hyp'][i,1]) 279 | K3_te = samples['W'][i,2]*RBF(t_te, t_te, samples['hyp'][i,2]) 280 | K_te = np.concatenate([np.concatenate([K1_te, np.zeros((Nt_test,Nt_test)), np.zeros((Nt_test,Nt_test))], axis = 1), 281 | np.concatenate([np.zeros((Nt_test,Nt_test)), K2_te, np.zeros((Nt_test,Nt_test))], axis = 1), 282 | np.concatenate([np.zeros((Nt_test,Nt_test)), np.zeros((Nt_test,Nt_test)), K3_te], axis = 1)], axis = 0) 283 | x0_l = np.array([0.5, 1.9, 0.18, 0.15, samples['IC'][i], 0.1, 0.064]) 284 | # X_tr_i = odeint(glyc, x0_l, t, samples['J0'][i], samples['k1'][i], samples['k2'][i], samples['k3'][i], samples['k4'][i], samples['k5'][i], samples['k6'][i], samples['k'][i], samples['ka'][i], samples['q'][i], samples['KI'][i], samples['phi'][i], samples['Np'][i], samples['A'][i]) 285 | X_tr1 = X_tr_i[i,i_t[0],ind[0]] / model.max_X[0] 286 | X_tr2 = X_tr_i[i,i_t[1],ind[1]] / model.max_X[1] 287 | X_tr3 = X_tr_i[i,i_t[2],ind[2]] / model.max_X[2] 288 | X_tr = np.concatenate((X_tr1,X_tr2,X_tr3),axis=0) 289 | 290 | L = 
np.linalg.cholesky(K_tr) 291 | X_train_f = np.concatenate((model.X[0],model.X[1],model.X[2]),axis=0) 292 | X_train_f = X_train_f.flatten('F') 293 | dX = np.matmul( K_trte.T, np.linalg.solve(np.transpose(L), np.linalg.solve(L,X_train_f.flatten('F')-X_tr.flatten('F'))) ) 294 | # X_ode_i = odeint(glyc, x0_l, t_test, samples['J0'][i], samples['k1'][i], samples['k2'][i], samples['k3'][i], samples['k4'][i], samples['k5'][i], samples['k6'][i], samples['k'][i], samples['ka'][i], samples['q'][i], samples['KI'][i], samples['phi'][i], samples['Np'][i], samples['A'][i]) 295 | X_ode1 = X_ode_i[i,:,ind[0]] / model.max_X[0] 296 | X_ode2 = X_ode_i[i,:,ind[1]] / model.max_X[1] 297 | X_ode3 = X_ode_i[i,:,ind[2]] / model.max_X[2] 298 | X_ode = np.concatenate((X_ode1,X_ode2,X_ode3),axis=0) 299 | 300 | mu = X_ode.flatten('F') + dX 301 | K = K_te - np.matmul(K_trte.T, np.linalg.solve(np.transpose(L), np.linalg.solve(L,K_trte))) 302 | pred0 = onp.random.multivariate_normal(mu, K) 303 | pred0 = pred0.reshape((len(ind), Nt_test)).T 304 | pred0[:,0:1] = model.max_X[0] * pred0[:,0:1] 305 | pred0[:,1:2] = model.max_X[1] * pred0[:,1:2] 306 | pred0[:,2:3] = model.max_X[2] * pred0[:,2:3] 307 | pred = onp.array(X_ode_i[i,:,:]) 308 | pred[:,ind] = pred0 309 | if not math.isnan( np.amax(np.abs(pred)) ): 310 | Npred_GP_f += 1 311 | X_pred_GP.append( pred ) 312 | 313 | X_pred_GP = np.array(X_pred_GP) 314 | 315 | mean_prediction_GP, std_prediction_GP = np.mean(X_pred_GP, axis=0), np.std(X_pred_GP, axis=0) 316 | lower_GP = mean_prediction_GP - 2.0*std_prediction_GP 317 | upper_GP = mean_prediction_GP + 2.0*std_prediction_GP 318 | 319 | var_n = ['S_1','S_2','S_3','S_4','N_2','A_3','S_4^{ex}'] 320 | i_tr = -1 321 | for i in range(D): 322 | plt.figure(figsize = (12,6)) 323 | plt.plot(t_star, X_star[:,i], 'r-', label = "True Trajectory of $"+var_n[i]+"(t)$") 324 | if i in ind : 325 | i_tr += 1 326 | plt.plot(t[i_t[i_tr]], X_train[i_t[i_tr],i], 'ro', label = "Training data of $"+var_n[i]+"(t)$") 327 | else: 
328 | plt.plot(t[0], X_train[0,i], 'ro', label = "Training data of $"+var_n[i]+"(t)$") 329 | plt.plot(t_test, mean_prediction_GP[:,i], 'g--', label = "MAP Trajectory of $"+var_n[i]+"(t)$") 330 | plt.fill_between(t_test, lower_GP[:,i], upper_GP[:,i], facecolor='orange', alpha=0.5, label="Two std band") 331 | plt.xlabel('$t$',fontsize=26) 332 | plt.ylabel('$'+var_n[i]+'(t)$',fontsize=26) 333 | plt.legend(loc='upper right', frameon=False, prop={'size': 20}) 334 | if i==1: 335 | plt.ylim(top= 1.8*X_star[:,i].max(0)) 336 | if i==2 or i==5: 337 | plt.ylim(top= 2.0*X_star[:,i].max(0)) 338 | if i==4: 339 | plt.ylim(top= 1.9*X_star[:,i].max(0)) 340 | if i==0 or i==3: 341 | plt.ylim(top= 1.7*X_star[:,i].max(0)) 342 | if i==6: 343 | plt.ylim(top= 1.5*X_star[:,i].max(0)) 344 | tt = 'plots/x_' + str(i+1) + ".png" 345 | plt.savefig(tt, dpi = 100) 346 | print(Npred_GP_f) 347 | 348 | x0 = np.array([onp.random.uniform(0.15,1.60),onp.random.uniform(0.19,2.10),onp.random.uniform(0.04,0.20),onp.random.uniform(0.10,0.35),onp.random.uniform(0.08,0.30),onp.random.uniform(0.14,2.67),onp.random.uniform(0.05,0.10)]) 349 | X_star = odeint(glyc, x0, t_star, J0, k1, k2, k3, k4, k5, k6, k, ka, q, KI, phi, Np, A) 350 | vmap_args_diff_x0 = (samples['J0'], samples['k1'], samples['k2'], samples['k3'], samples['k4'], samples['k5'], samples['k6'], samples['k'], samples['ka'], samples['q'], samples['KI'], samples['phi'], samples['Np'], samples['A']) 351 | 352 | @partial(jit, static_argnums=(0,)) 353 | def predict_diff_x0( t_star, J0, k1, k2, k3, k4, k5, k6, k, ka, q, KI, phi, Np, A): 354 | X = odeint(glyc, x0, t_star, J0, k1, k2, k3, k4, k5, k6, k, ka, q, KI, phi, Np, A) 355 | return X 356 | 357 | pred_X_tr_i = lambda J0, k1, k2, k3, k4, k5, k6, k, ka, q, KI, phi, Np, A: predict_diff_x0(t, J0, k1, k2, k3, k4, k5, k6, k, ka, q, KI, phi, Np, A) 358 | X_tr_i = vmap(pred_X_tr_i)(*vmap_args_diff_x0) 359 | 360 | pred_X_ode_i = lambda J0, k1, k2, k3, k4, k5, k6, k, ka, q, KI, phi, Np, A: 
predict_diff_x0(t_test, J0, k1, k2, k3, k4, k5, k6, k, ka, q, KI, phi, Np, A) 361 | X_ode_i = vmap(pred_X_ode_i)(*vmap_args_diff_x0) 362 | 363 | X_pred_GP = [] 364 | Npred_GP_f = 0 365 | for i in range(num_samples): 366 | if i % 500 == 0: 367 | print(i) 368 | K1_tr = samples['W'][i,0]*RBF(model.t_t[0], model.t_t[0], samples['hyp'][i,0]) + np.eye(model.N[0])*(samples['noise'][i,0] + model.jitter) 369 | K2_tr = samples['W'][i,1]*RBF(model.t_t[1], model.t_t[1], samples['hyp'][i,1]) + np.eye(model.N[1])*(samples['noise'][i,1] + model.jitter) 370 | K3_tr = samples['W'][i,2]*RBF(model.t_t[2], model.t_t[2], samples['hyp'][i,2]) + np.eye(model.N[2])*(samples['noise'][i,2] + model.jitter) 371 | K_tr = np.concatenate([np.concatenate([K1_tr, np.zeros((model.N[0], model.N[1])), np.zeros((model.N[0], model.N[2]))], axis = 1), 372 | np.concatenate([np.zeros((model.N[1], model.N[0])), K2_tr, np.zeros((model.N[1], model.N[2]))], axis = 1), 373 | np.concatenate([np.zeros((model.N[2], model.N[0])), np.zeros((model.N[2], model.N[1])), K3_tr], axis = 1)], axis = 0) 374 | K1_trte = samples['W'][i,0]*RBF(t_te, model.t_t[0], samples['hyp'][i,0]) 375 | K2_trte = samples['W'][i,1]*RBF(t_te, model.t_t[1], samples['hyp'][i,1]) 376 | K3_trte = samples['W'][i,2]*RBF(t_te, model.t_t[2], samples['hyp'][i,2]) 377 | K_trte = np.concatenate([np.concatenate([K1_trte, np.zeros((model.N[0],Nt_test)), np.zeros((model.N[0],Nt_test))], axis = 1), 378 | np.concatenate([np.zeros((model.N[1],Nt_test)), K2_trte, np.zeros((model.N[1],Nt_test))], axis = 1), 379 | np.concatenate([np.zeros((model.N[2],Nt_test)), np.zeros((model.N[2],Nt_test)), K3_trte], axis = 1)], axis = 0) 380 | K1_te = samples['W'][i,0]*RBF(t_te, t_te, samples['hyp'][i,0]) 381 | K2_te = samples['W'][i,1]*RBF(t_te, t_te, samples['hyp'][i,1]) 382 | K3_te = samples['W'][i,2]*RBF(t_te, t_te, samples['hyp'][i,2]) 383 | K_te = np.concatenate([np.concatenate([K1_te, np.zeros((Nt_test,Nt_test)), np.zeros((Nt_test,Nt_test))], axis = 1), 384 | 
np.concatenate([np.zeros((Nt_test,Nt_test)), K2_te, np.zeros((Nt_test,Nt_test))], axis = 1), 385 | np.concatenate([np.zeros((Nt_test,Nt_test)), np.zeros((Nt_test,Nt_test)), K3_te], axis = 1)], axis = 0) 386 | X_tr1 = X_tr_i[i,i_t[0],ind[0]] / model.max_X[0] 387 | X_tr2 = X_tr_i[i,i_t[1],ind[1]] / model.max_X[1] 388 | X_tr3 = X_tr_i[i,i_t[2],ind[2]] / model.max_X[2] 389 | X_tr = np.concatenate((X_tr1,X_tr2,X_tr3),axis=0) 390 | 391 | L = np.linalg.cholesky(K_tr) 392 | X_train_f = np.concatenate((model.X[0],model.X[1],model.X[2]),axis=0) 393 | X_train_f = X_train_f.flatten('F') 394 | dX = np.matmul( K_trte.T, np.linalg.solve(np.transpose(L), np.linalg.solve(L,X_train_f.flatten('F')-X_tr.flatten('F'))) ) 395 | X_ode1 = X_ode_i[i,:,ind[0]] / model.max_X[0] 396 | X_ode2 = X_ode_i[i,:,ind[1]] / model.max_X[1] 397 | X_ode3 = X_ode_i[i,:,ind[2]] / model.max_X[2] 398 | X_ode = np.concatenate((X_ode1,X_ode2,X_ode3),axis=0) 399 | 400 | mu = X_ode.flatten('F') + dX 401 | K = K_te - np.matmul(K_trte.T, np.linalg.solve(np.transpose(L), np.linalg.solve(L,K_trte))) 402 | pred0 = onp.random.multivariate_normal(mu, K) 403 | pred0 = pred0.reshape((len(ind), Nt_test)).T 404 | pred0[:,0:1] = model.max_X[0] * pred0[:,0:1] 405 | pred0[:,1:2] = model.max_X[1] * pred0[:,1:2] 406 | pred0[:,2:3] = model.max_X[2] * pred0[:,2:3] 407 | pred = onp.array(X_ode_i[i,:,:]) 408 | pred[:,ind] = pred0 409 | if not math.isnan( np.amax(np.abs(pred)) ): 410 | Npred_GP_f += 1 411 | X_pred_GP.append( pred ) 412 | 413 | X_pred_GP = np.array(X_pred_GP) 414 | mean_prediction_GP, std_prediction_GP = np.mean(X_pred_GP, axis=0), np.std(X_pred_GP, axis=0) 415 | lower_GP = mean_prediction_GP - 2.0*std_prediction_GP 416 | upper_GP = mean_prediction_GP + 2.0*std_prediction_GP 417 | 418 | var_n = ['S_1','S_2','S_3','S_4','N_2','A_3','S_4^{ex}'] 419 | i_tr = -1 420 | for i in range(D): 421 | plt.figure(figsize = (12,6)) 422 | plt.plot(t_star, X_star[:,i], 'r-', label = "True Trajectory of $"+var_n[i]+"(t)$") 423 | 
plt.plot(t_test, mean_prediction_GP[:,i], 'g--', label = "MAP Trajectory of $"+var_n[i]+"(t)$") 424 | plt.fill_between(t_test, lower_GP[:,i], upper_GP[:,i], facecolor='orange', alpha=0.5, label="Two std band") 425 | plt.xlabel('$t$',fontsize=26) 426 | plt.ylabel('$'+var_n[i]+'(t)$',fontsize=26) 427 | plt.legend(loc='upper right', frameon=False, prop={'size': 20}) 428 | if i==1: 429 | plt.ylim(top= 1.8*X_star[:,i].max(0)) 430 | if i==2 or i==5: 431 | plt.ylim(top= 2.0*X_star[:,i].max(0)) 432 | if i==4: 433 | plt.ylim(top= 1.9*X_star[:,i].max(0)) 434 | if i==0 or i==3: 435 | plt.ylim(top= 1.7*X_star[:,i].max(0)) 436 | if i==6: 437 | plt.ylim(top= 1.5*X_star[:,i].max(0)) 438 | tt = 'plots/random_x0_x_' + str(i+1) + ".png" 439 | plt.savefig(tt, dpi = 100) 440 | 441 | print(Npred_GP_f) 442 | 443 | import matplotlib as mpl 444 | 445 | def figsize(scale, nplots = 1): 446 | fig_width_pt = 390.0 # Get this from LaTeX using \the\textwidth 447 | inches_per_pt = 1.0/72.27 # Convert pt to inch 448 | golden_mean = (np.sqrt(5.0)-1.0)/2.0 # Aesthetic ratio (you could change this) 449 | fig_width = fig_width_pt*inches_per_pt*scale # width in inches 450 | fig_height = nplots*fig_width*golden_mean # height in inches 451 | fig_size = [fig_width,fig_height] 452 | return fig_size 453 | 454 | pgf_with_latex = { # setup matplotlib to use latex for output 455 | "pgf.texsystem": "pdflatex", # change this if using xetex or lautex 456 | "text.usetex": True, # use LaTeX to write all text 457 | "font.family": "serif", 458 | "font.serif": [], # blank entries should cause plots to inherit fonts from the document 459 | "font.sans-serif": [], 460 | "font.monospace": [], 461 | "axes.labelsize": 20, # LaTeX default is 10pt font. 
462 | "axes.titlesize": 20, 463 | "axes.linewidth": 2, 464 | "font.size": 16, 465 | "lines.linewidth": 2, 466 | "legend.fontsize": 20, # Make the legend/label fonts a little smaller 467 | "xtick.labelsize": 16, 468 | "ytick.labelsize": 16, 469 | "figure.figsize": figsize(1.0), # default fig size of 0.9 textwidth 470 | "pgf.preamble": [ 471 | r"\usepackage[utf8x]{inputenc}", # use utf8 fonts becasue your computer can handle it :) 472 | r"\usepackage[T1]{fontenc}", # plots will be generated using this preamble 473 | ] 474 | } 475 | mpl.rcParams.update(pgf_with_latex) 476 | 477 | def newfig(width, nplots = 1): 478 | fig = plt.figure(figsize=figsize(width, nplots)) 479 | ax = fig.add_subplot(111) 480 | return fig, ax 481 | 482 | def savefig(filename, crop = True): 483 | if crop == True: 484 | plt.savefig('{}.png'.format(filename), bbox_inches='tight', pad_inches=0 , dpi = 100) 485 | else: 486 | plt.savefig('{}.png'.format(filename) , dpi = 100) 487 | 488 | Data = [samples['J0'], samples['k1'], samples['k2'], samples['k3'], samples['k4'], samples['k5'], samples['k6'], samples['k'], samples['ka'], samples['q'], samples['KI'], samples['phi'], samples['Np'], samples['A'], samples['IC']] 489 | 490 | true = [2.5, 100, 6, 16, 100, 1.28, 12, 1.8, 13, 4, 0.52, 0.1, 1, 4, 0.16] 491 | name = [r'$J_0$', r'$k_1$', r'$k_2$', r'$k_3$', r'$k_4$', r'$k_5$', r'$k_6$', r'$k$',r'$\kappa$', r'$q$', r'$K_1$', r'$\varphi$', r'$N$', r'$A$', r'$N_{2,0}$'] 492 | 493 | X = np.linspace(0.94, 1.06, 100)[:,None] 494 | fig = plt.figure() 495 | fig.set_size_inches(15, 10) 496 | for i in range(2): 497 | for j in range(7+i): 498 | ind = i*7 + j 499 | ax = plt.subplot2grid((2,8), (i,j)) 500 | ax.boxplot([Data[ind]], showfliers=False) 501 | ax.plot(X, true[ind]*np.ones(X.shape[0]), 'b-') 502 | ax.set_xticklabels([]) 503 | ax.set_xticks([]) 504 | plt.title(name[ind], y=-0.12) 505 | 506 | fig.tight_layout() 507 | 508 | savefig('plots/box_plot', False) 509 | 
# --------------------------------------------------------------------------------
# /Yeast-Glycolysis/YG_GP_NODE_MCMC_convergence.py:
# --------------------------------------------------------------------------------
"""MCMC convergence diagnostics for the Yeast-Glycolysis GP-NODE posterior.

Reads two independent chains of the 15 inferred quantities (the 14 physical
parameters followed by the initial condition N_{2,0}) written by
``YG_GP_NODE.py`` and

* prints the Gelman-Rubin statistic of every quantity,
* saves autocorrelation plots of the first chain in three blocks of five
  quantities ("plots/autocorrelation_1.png" ... "plots/autocorrelation_3.png"),
* saves a Geweke z-score plot of the first chain ("plots/geweke.png").
"""
import numpy as np

N_var = 15        # number of traced quantities (columns of a chain)
i0 = 0            # index of the first sample kept
ie = 8000         # end index (exclusive) of the kept samples
step = 8          # thinning stride
N_per_block = 5   # quantities shown per autocorrelation figure

# Column order of the traces, with the true values used to generate the data:
# J0 = 2.5, k1 = 100., k2 = 6., k3 = 16., k4 = 100., k5 = 1.28, k6 = 12.,
# k = 1.8, ka = 13., q = 4., KI = 0.52, phi = 0.1, Np = 1., A = 4., IC = 0.16
names = [r'$J_0$', r'$k_1$', r'$k_2$', r'$k_3$', r'$k_4$', r'$k_5$', r'$k_6$',
         r'$k$', r'$\kappa$', r'$q$', r'$K_1$', r'$\varphi$', r'$N$', r'$A$',
         r'$N_{2,0}$']


def gelman_rubin(*chains):
    """Return the Gelman-Rubin potential scale reduction factor (R-hat).

    Parameters
    ----------
    *chains : array_like
        Two or more chains of the same scalar quantity. Each chain is
        flattened, so column vectors of shape (n, 1) are accepted. All chains
        must have the same number of samples n >= 2.

    Returns
    -------
    numpy.float64
        ``sqrt(var_plus / W)`` where ``W`` is the mean within-chain sample
        variance, ``B = n * var(chain means)`` (both with ddof=1) and
        ``var_plus = (1 - 1/n) * W + B / n``.

    Raises
    ------
    ValueError
        If fewer than two chains, or fewer than two samples per chain, are given.
    """
    stacked = np.asarray([np.ravel(c) for c in chains], dtype=float)
    if stacked.ndim != 2 or stacked.shape[0] < 2 or stacked.shape[1] < 2:
        raise ValueError("gelman_rubin needs at least two equal-length chains "
                         "with at least two samples each")
    n = stacked.shape[1]
    # Unbiased (ddof=1) variances, as in the textbook definition of R-hat.
    W = stacked.var(axis=1, ddof=1).mean()
    B = n * stacked.mean(axis=1).var(ddof=1)
    var_theta = (1 - 1 / n) * W + B / n
    return np.sqrt(var_theta / W)


def main():
    """Load both chains, print the R-hat values and save the diagnostic plots."""
    # Plotting / diagnostics libraries are only needed when the script runs.
    import matplotlib.pyplot as plt
    import pymc3 as pm3

    plt.rcParams.update(plt.rcParamsDefault)
    plt.rc('font', family='serif')
    plt.rcParams.update({'font.size': 16,
                         'lines.linewidth': 2,
                         'axes.labelsize': 20,  # fontsize for x and y labels
                         'axes.titlesize': 20,
                         'xtick.labelsize': 16,
                         'ytick.labelsize': 16,
                         'legend.fontsize': 20,
                         'axes.linewidth': 2,
                         "pgf.texsystem": "pdflatex",  # change this if using xetex or luatex
                         "text.usetex": True,  # use LaTeX to write all text
                         })

    # YG_GP_NODE.py saves np.array(<tuple of 15 traces>); transpose so rows
    # are MCMC samples and columns follow the order of `names`.
    data2 = np.load("data/par_and_IC.npy", allow_pickle=True).T
    data3 = np.load("data/par_and_IC_2.npy", allow_pickle=True).T

    # Thinned chains: up to (ie - i0) // step rows by N_var columns.
    data_chain1 = data2[i0:ie:step, :N_var]
    data_chain2 = data3[i0:ie:step, :N_var]

    for i in range(N_var):
        print("Gelman-Rubin Diagnostic: ",
              gelman_rubin(data_chain1[:, i], data_chain2[:, i]))

    # Autocorrelation of the first chain, N_per_block quantities per figure.
    for block in range(N_var // N_per_block):
        j = block * N_per_block
        block_names = names[j:j + N_per_block]
        trace = {name: data_chain1[:, j + k] for k, name in enumerate(block_names)}
        axes = pm3.autocorrplot(trace, var_names=block_names, grid=(1, N_per_block),
                                figsize=(12, 6.5), textsize=18, combined=True)
        axes = axes[None, :]
        for k in range(N_per_block):
            axes[0, k].set_xlabel('Lag Index', fontsize=26)
        axes[0, 0].set_ylabel('Autocorrelation Value', fontsize=26)
        plt.savefig("plots/autocorrelation_%d.png" % (block + 1),
                    bbox_inches='tight', pad_inches=0.01)

    # Geweke z-scores: first 10% of the chain vs. 20 slices of the last 50%.
    plt.figure(figsize=(12, 7))
    for i in range(data_chain1.shape[1]):
        gw_plot = pm3.geweke(data_chain1[:, i], .1, .5, 20)
        plt.scatter(gw_plot[:, 0], gw_plot[:, 1], label="%s" % names[i])
    plt.axhline(-1.98, c='r')
    plt.axhline(1.98, c='r')
    plt.xticks(fontsize=22)
    plt.yticks(fontsize=22)
    plt.xlabel("Subchain sample number", fontsize=26)
    plt.ylabel("Geweke z-score", fontsize=26)
    # Raw string: `\%` is LaTeX, not a Python escape sequence.
    plt.title(r'Geweke Plot Comparing first 10$\%$ and Slices of the Last 50$\%$ of Chain')

    plt.legend(bbox_to_anchor=(1.0, 1.2), loc='upper left')
    plt.tight_layout()
    # Save before show(): once the window is closed the figure may be
    # destroyed and savefig() would write an empty image.
    plt.savefig("plots/geweke.png", pad_inches=0.01)
    plt.show()


if __name__ == "__main__":
    main()

# --------------------------------------------------------------------------------
# /Yeast-Glycolysis/data/W.npy:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Yeast-Glycolysis/data/W.npy
# --------------------------------------------------------------------------------
# /Yeast-Glycolysis/data/W_2.npy:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Yeast-Glycolysis/data/W_2.npy
# --------------------------------------------------------------------------------
# /Yeast-Glycolysis/data/hyp.npy:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Yeast-Glycolysis/data/hyp.npy
# --------------------------------------------------------------------------------
# /Yeast-Glycolysis/data/hyp_2.npy:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Yeast-Glycolysis/data/hyp_2.npy
# --------------------------------------------------------------------------------
# /Yeast-Glycolysis/data/noise.npy:
# --------------------------------------------------------------------------------
https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Yeast-Glycolysis/data/noise.npy -------------------------------------------------------------------------------- /Yeast-Glycolysis/data/noise_2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Yeast-Glycolysis/data/noise_2.npy -------------------------------------------------------------------------------- /Yeast-Glycolysis/data/par_and_IC.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Yeast-Glycolysis/data/par_and_IC.npy -------------------------------------------------------------------------------- /Yeast-Glycolysis/data/par_and_IC_2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Yeast-Glycolysis/data/par_and_IC_2.npy -------------------------------------------------------------------------------- /Yeast-Glycolysis/plots/autocorrelation_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Yeast-Glycolysis/plots/autocorrelation_1.png -------------------------------------------------------------------------------- /Yeast-Glycolysis/plots/autocorrelation_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Yeast-Glycolysis/plots/autocorrelation_2.png -------------------------------------------------------------------------------- 
/Yeast-Glycolysis/plots/autocorrelation_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Yeast-Glycolysis/plots/autocorrelation_3.png -------------------------------------------------------------------------------- /Yeast-Glycolysis/plots/box_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Yeast-Glycolysis/plots/box_plot.png -------------------------------------------------------------------------------- /Yeast-Glycolysis/plots/geweke.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Yeast-Glycolysis/plots/geweke.png -------------------------------------------------------------------------------- /Yeast-Glycolysis/plots/random_x0_x_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Yeast-Glycolysis/plots/random_x0_x_1.png -------------------------------------------------------------------------------- /Yeast-Glycolysis/plots/random_x0_x_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Yeast-Glycolysis/plots/random_x0_x_2.png -------------------------------------------------------------------------------- /Yeast-Glycolysis/plots/random_x0_x_3.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Yeast-Glycolysis/plots/random_x0_x_3.png -------------------------------------------------------------------------------- /Yeast-Glycolysis/plots/random_x0_x_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Yeast-Glycolysis/plots/random_x0_x_4.png -------------------------------------------------------------------------------- /Yeast-Glycolysis/plots/random_x0_x_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Yeast-Glycolysis/plots/random_x0_x_5.png -------------------------------------------------------------------------------- /Yeast-Glycolysis/plots/random_x0_x_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Yeast-Glycolysis/plots/random_x0_x_6.png -------------------------------------------------------------------------------- /Yeast-Glycolysis/plots/random_x0_x_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Yeast-Glycolysis/plots/random_x0_x_7.png -------------------------------------------------------------------------------- /Yeast-Glycolysis/plots/x_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Yeast-Glycolysis/plots/x_1.png -------------------------------------------------------------------------------- 
/Yeast-Glycolysis/plots/x_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Yeast-Glycolysis/plots/x_2.png -------------------------------------------------------------------------------- /Yeast-Glycolysis/plots/x_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Yeast-Glycolysis/plots/x_3.png -------------------------------------------------------------------------------- /Yeast-Glycolysis/plots/x_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Yeast-Glycolysis/plots/x_4.png -------------------------------------------------------------------------------- /Yeast-Glycolysis/plots/x_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Yeast-Glycolysis/plots/x_5.png -------------------------------------------------------------------------------- /Yeast-Glycolysis/plots/x_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Yeast-Glycolysis/plots/x_6.png -------------------------------------------------------------------------------- /Yeast-Glycolysis/plots/x_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/GP-NODEs/12b9364f1515e06f32703417de3a483474156802/Yeast-Glycolysis/plots/x_7.png 
-------------------------------------------------------------------------------- /readme.txt: -------------------------------------------------------------------------------- 1 | GP-NODE Code Guide 2 | @author: mohamedazizbhouri 3 | 4 | Rk1: The easiest way to download the code and data is to use the following Google Drive link to get a zip file of the whole repository: https://drive.google.com/file/d/1c3SsMf5hYpyKEDC7bbPmueI_5Ax0o03l/view?usp=sharing 5 | 6 | Rk2: The code was tested using jax version 0.1.73, jaxlib version 0.1.51, and numpyro version 0.3.0. 7 | 8 | ################################################### 9 | ############## Predator-prey problem ############## 10 | ################################################### 11 | 12 | The folder "Predator-prey" contains the implementation of the GP-NODE method for a predator-prey problem with dictionary learning as detailed in the paper "Gaussian processes meet NeuralODEs: A Bayesian framework for learning the dynamics of partially observed systems from scarce and noisy data". 13 | 14 | The code "PP_GP_NODE.py" contains this implementation. It generates the following 5 numpy files which are saved in the folder "data": 15 | * "par.npy" (of size: number of samples by length of the dictionary) contains the trace of samples of the inferred dictionary parameters, 16 | * "IC.npy" (of size: number of samples) contains the trace of samples of the inferred initial condition for the variable x_0 of the Predator-Prey system, 17 | * "noise.npy" (of size: number of samples by observable state dimension (2)) contains the trace of the samples of the Gaussian noise variance, 18 | * "hyp.npy" (of size: number of samples by observable state dimension (2)) contains the trace of the samples of the RBF kernel length-scale, 19 | * "W.npy" (of size: number of samples by observable state dimension (2)) contains the trace of the samples of the RBF kernel variance.
20 | 21 | The code "PP_GP_NODE.py" also generates the following 4 plots which are saved in the folder "plots": 22 | * "x_1.png" and "x_2.png" which show the learned dynamics versus the true dynamics and the training data of the variables x_1 and x_2 respectively, 23 | * "box_plot.png" and "box_plot_x0.png" which show the uncertainty estimation of the inferred dictionary parameters and initial condition respectively. 24 | 25 | The code "PP_GP_NODE_MCMC_convergence.py" performs: 26 | * the Gelman-Rubin tests for the non-zero dictionary parameters and the initial condition, 27 | * the Geweke diagnostic whose results are saved in the file "geweke.png" within the folder "plots", 28 | * the autocorrelation estimation as a function of the lag, and the corresponding results are saved in the file "autocorrelation.png" within the folder "plots". 29 | The Geweke diagnostic and the autocorrelation estimation are performed using the traces saved in the folder "data" by running the code "PP_GP_NODE.py". In order to perform the Gelman-Rubin tests, two chains need to be considered. For instance, the results obtained for the second chain can be saved in the folder "data" with the suffix "_2" appended to the file names, such that we would have "par_2.npy", "IC_2.npy", etc., as is done in this repository. 30 | 31 | The code "PP_sindy.py" contains the implementation of the SINDy framework for the Predator-prey problem. Five cases are considered depending on the available datasets as detailed in the paper "Gaussian processes meet NeuralODEs: A Bayesian framework for learning the dynamics of partially observed systems from scarce and noisy data". The code outputs the inferred dictionary parameters and generates 10 plots showing the SINDy predictions versus the true dynamics and the training data for each of the two variables and for each of the 5 cases considered. These plots are saved in the folder "plots_sindy".
32 | 33 | ################################################### 34 | ############# Yeast-Glycolysis problem ############ 35 | ################################################### 36 | 37 | The folder "Yeast-Glycolysis" contains the implementation of the GP-NODE method for the Yeast-Glycolysis system as detailed in the paper "Gaussian processes meet NeuralODEs: A Bayesian framework for learning the dynamics of partially observed systems from scarce and noisy data". 38 | 39 | The code "YG_GP_NODE.py" contains this implementation. It generates the following 4 numpy files which are saved in the folder "data": 40 | * "par_and_IC.npy" (of size: 15 by number of samples, i.e. the 14 physical parameters followed by the initial condition N_{2,0}) contains the trace of samples of the inferred physical parameters and initial condition, 41 | * "noise.npy" (of size: number of samples by observable state dimension (3)) contains the trace of the samples of the Gaussian noise variance, 42 | * "hyp.npy" (of size: number of samples by observable state dimension (3)) contains the trace of the samples of the RBF kernel length-scale, 43 | * "W.npy" (of size: number of samples by observable state dimension (3)) contains the trace of the samples of the RBF kernel variance. 44 | 45 | The code "YG_GP_NODE.py" also generates the following 15 plots which are saved in the folder "plots": 46 | * "x_1.png" ... "x_7.png" which show the learned dynamics versus the true dynamics and the training data of the variables x_1 ... x_7 respectively, 47 | * "random_x0_x_1.png" ... "random_x0_x_7.png" which show the future forecasts versus the true dynamics of the variables x_1 ... x_7 respectively for unseen initial conditions that are randomly sampled, 48 | * "box_plot.png" which shows the uncertainty estimation of the inferred physical parameters and initial condition.
49 | 50 | The code "YG_GP_NODE_MCMC_convergence.py" performs: 51 | * the Gelman Rubin tests for the physical parameters, 52 | * the Geweke diagnostic whose results are saved in the file "geweke.png" within the folder "plots", 53 | * the autocorrelation estimation as a function of the lag, and the corresponding results are saved in the files "autocorrelation_1.png", "autocorrelation_2.png" and "autocorrelation_3.png" within the folder "plots". 54 | The Geweke diagnostic and the autocorrelation estimation are performed using the traces saved in the folder "data" by running the code "YG_GP_NODE.py". In order to perform the Gelman Rubin tests, two chains need to be considered. For instance, the results obtained for the second chain can be saved in the folder "data" with the suffix "_2" in the file names, such that we would have "par_and_IC_2.npy" as in the proposed repository. 55 | 56 | ################################################### 57 | ######## Human motion capture data problem ######## 58 | ################################################### 59 | 60 | The folder "Human motion" contains the implementation of the GP-NODE method for the human motion capture data problem as detailed in the paper "Gaussian processes meet NeuralODEs: A Bayesian framework for learning the dynamics of partially observed systems from scarce and noisy data". 61 | 62 | First, extract the folder "npODE" from the zip file "npODE.zip", which is located in the folder "Human motion". Second, run the "demo_cmu_walking.m" code (located in the folder "Human motion/npODE/exps") in order to obtain the fitting and forecasting errors of the npODE method and generate the data in .mat format. The latter is saved in different files in the folder "Human motion/npODE/exps" under the names "X.mat", "Y.mat", "u_pca.mat" and "v_pca.mat". These files need to be copied to the folder "Human motion/data". 
63 | 64 | The code "HM_GP_NODE.py" contains the implementation of the GP-NODE method for the human motion capture data problem. First, specify the case to be considered by assigning the value 'A_' to the variable "case" in the code "HM_GP_NODE.py" if you want to consider case A, as detailed in the paper "Gaussian processes meet NeuralODEs: A Bayesian framework for learning the dynamics of partially observed systems from scarce and noisy data". Otherwise, assign the value 'B_' to the variable "case" in the code "HM_GP_NODE.py" if you want to consider case B. 65 | 66 | For each case (A or B), the code "HM_GP_NODE.py" generates the following 4 numpy files which are saved in the folder "data": 67 | * case+"par.npy" (of size: number of samples by length of the dictionary) contains the trace of samples of the inferred dictionary parameters, 68 | * case+"noise.npy" (of size: number of samples by observable state dimension (3)) contains the trace of the samples of the Gaussian noise variance, 69 | * case+"hyp.npy" (of size: number of samples by observable state dimension (3)) contains the trace of the samples of the RBF kernel length, 70 | * case+"W.npy" (of size: number of samples by observable state dimension (3)) contains the trace of the samples of the RBF kernel variance. 71 | 72 | For each case (A or B), the code "HM_GP_NODE.py" also generates the following 10 plots which are saved in the folder "plots": 73 | * case+"x_1.png", case+"x_2.png" and case+"x_3.png" which show the learned dynamics versus the true dynamics and the training data of the variables x_1, x_2 and x_3 respectively, 74 | * case+"y_27.png", case+"y_34.png", case+"y_37.png", case+"y_39.png", case+"y_42.png" and case+"y_48.png" which show the learned dynamics versus the true dynamics of PCA-recovered y_27, y_34, y_37, y_39, y_42 and y_48 respectively, 75 | * case+"box_plot.png" which shows the uncertainty estimation of the inferred dictionary parameters. 
76 | 77 | The code "HM_GP_NODE_MCMC_convergence.py" performs: 78 | * the Gelman Rubin tests for the most significant non-zero dictionary parameters, 79 | * the Geweke diagnostic whose results are saved in the file case+"geweke.png" within the folder "plots", 80 | * the autocorrelation estimation as a function of the lag, and the corresponding results are saved in the files case+"autocorrelation_1.png" and case+"autocorrelation_2.png" within the folder "plots". 81 | The Geweke diagnostic and the autocorrelation estimation are performed using the traces saved in the folder "data" by running the code "HM_GP_NODE.py". In order to perform the Gelman Rubin tests, two chains need to be considered. For instance, the results obtained for the second chain can be saved in the folder "data" with the suffix "_2" in the file names, such that we would have case+"par_2.npy" as in the proposed repository. 82 | Please specify the case to be considered by assigning the value 'A_' or 'B_' to the variable "case" in the code "HM_GP_NODE_MCMC_convergence.py". 83 | --------------------------------------------------------------------------------