├── Advection
│   ├── Plot_results
│   │   ├── analytic_solution.py
│   │   ├── error_plot.py
│   │   └── eval_decay.py
│   └── Train_model
│       └── train_advection.py
├── Antiderivative
│   ├── Plot_results
│   │   ├── error_lineplots.py
│   │   ├── eval_decay.py
│   │   └── plot_spiral.py
│   └── Train_model
│       └── train_antiderivative.py
├── README.md
└── Shallow Water
    ├── Plot_results
    │   └── plot_results.py
    └── Train_model
        └── train_SW.py
/Advection/Plot_results/analytic_solution.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.sparse import diags 3 | import matplotlib.pyplot as plt 4 | 5 | import jax.numpy as jnp 6 | import numpy as onp 7 | import matplotlib.pyplot as plt 8 | from jax import random, jit, vmap 9 | 10 | plt.rcParams.update(plt.rcParamsDefault) 11 | plt.rc('font', family='serif') 12 | plt.rcParams.update({ 13 | "text.usetex": True, 14 | "font.family": "serif", 15 | 'text.latex.preamble': r'\usepackage{amsmath}', 16 | 'font.size': 20, 17 | 'lines.linewidth': 3, 18 | 'axes.labelsize': 22, # fontsize for x and y labels (was 10) 19 | 'axes.titlesize': 24, 20 | 'xtick.labelsize': 20, 21 | 'ytick.labelsize': 20, 22 | 'legend.fontsize': 20, 23 | 'axes.linewidth': 2}) 24 | 25 | fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(18,8)) 26 | 27 | @jit 28 | def analytical_solution(key, x, t): 29 | mu = random.uniform(key, minval=0.05, maxval=1.0) 30 | return vmap(initial_condition,in_axes=(None,0,None))(x,t,mu), initial_condition(x,0,mu) 31 | 32 | def initial_condition(x, t, mu): 33 | x = x-c*t 34 | denom = 1./jnp.sqrt(0.0002*jnp.pi) 35 | return denom*jnp.exp(-(1./0.0002)*(x-mu)**2) 36 | 37 | lb_x = 0. 38 | ub_x = 2 39 | 40 | lb_t = 0 41 | ub_t = 1 42 | 43 | Nt = 1024 44 | Nx = 1024 45 | N = 2000 46 | 47 | x = jnp.linspace(0,2,num=Nx) 48 | t = jnp.linspace(0,1,num=Nt) 49 | grid = jnp.meshgrid(x, t) 50 | c = 1 51 | 52 | keys = random.split(random.PRNGKey(1000),num=N) 53 | T_exact_all, u_exact_all = vmap(analytical_solution,in_axes=(0,None,None))(keys,x,t) 54 | 55 | 56 | from matplotlib import pyplot as plt 57 | from mpl_toolkits.mplot3d import Axes3D 58 | 59 | sm = plt.cm.ScalarMappable(cmap=plt.cm.cool, norm=plt.Normalize(vmin=0, vmax=10)) 60 | 61 | fig = plt.figure(figsize=(6,5)) 62 | ax = fig.add_subplot(111, projection='3d') 63 | 64 | ax.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0)) 65 | ax.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0)) 66 | ax.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0)) 67 | ax.xaxis._axinfo["grid"]['color'] = (1,1,1,0) 68 | ax.yaxis._axinfo["grid"]['color'] = (1,1,1,0) 69 | ax.zaxis._axinfo["grid"]['color'] = (1,1,1,0) 70 | 71 | idx=33 72 | ax.plot(x, 0*t, T_exact_all[idx,0,:], 'k') 73 | ax.plot_wireframe(grid[0], grid[1], T_exact_all[idx,:,:], rstride=64, cstride=0, color='k', alpha=0.4) 74 | 75 | ax.set_xlabel(r'$x$') 76 | ax.set_ylabel(r'$t$') 77 | ax.set_zlabel(r'$s(x,t)$') 78 | ax.xaxis.labelpad = 10 79 | ax.yaxis.labelpad = 10 80 | ax.zaxis.labelpad = 10 81 | 82 | plt.tight_layout() 83 | plt.savefig("advection_solution.png", bbox_inches='tight', dpi=600) -------------------------------------------------------------------------------- /Advection/Plot_results/error_plot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from matplotlib import pyplot as plt 3 | from scipy import stats 4 | 5 | plt.rcParams.update(plt.rcParamsDefault) 6 | plt.rc('font', family='serif') 7 | plt.rcParams.update({ 8 | "text.usetex": True, 9 | "font.family": "serif", 10 | 'text.latex.preamble':
r'\usepackage{amsmath}', 11 | 'font.size': 20, 12 | 'lines.linewidth': 3, 13 | 'axes.labelsize': 22, # fontsize for x and y labels (was 10) 14 | 'axes.titlesize': 24, 15 | 'xtick.labelsize': 20, 16 | 'ytick.labelsize': 20, 17 | 'legend.fontsize': 20, 18 | 'axes.linewidth': 2}) 19 | 20 | # load the dataset 21 | n_hat = [2, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100] 22 | iterations = [1, 2, 3, 4, 5, 6, 7] 23 | test_error_DON_linear = np.zeros((len(n_hat), len(iterations))) 24 | test_error_DON_nonlinear = np.zeros((len(n_hat), len(iterations))) 25 | 26 | fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(18,8)) 27 | for i in range(len(n_hat)): 28 | n = n_hat[i] 29 | for j in range(len(iterations)): 30 | it = iterations[j] 31 | d = np.load("../Error_Vectors/Error_Advection_DeepONet_nhat%d_iteration%d_linear.npz"%(n,it)) 32 | test_error_DON_linear[i,j] = np.mean(d["test_error"]) 33 | 34 | d = np.load("../Error_Vectors/Error_Advection_DeepONet_nhat%d_iteration%d_nonlinear.npz"%(n,it)) 35 | test_error_DON_nonlinear[i,j] = np.mean(d["test_error"]) 36 | 37 | lin_mu, lin_std = np.median(test_error_DON_linear, axis = 1), stats.median_abs_deviation(test_error_DON_linear, axis = 1) 38 | nonlin_mu, nonlin_std = np.median(test_error_DON_nonlinear, axis = 1), stats.median_abs_deviation(test_error_DON_nonlinear, axis = 1) 39 | 40 | dispersion_scale = 1.0 41 | lin_lower = np.log10(np.clip(lin_mu - dispersion_scale*lin_std, a_min=0., a_max = np.inf) + 1e-8) 42 | lin_upper = np.log10(lin_mu + dispersion_scale*lin_std + 1e-8) 43 | 44 | nonlin_lower = np.log10(np.clip(nonlin_mu - dispersion_scale*nonlin_std, a_min=0., a_max = np.inf) + 1e-8) 45 | nonlin_upper = np.log10(nonlin_mu + dispersion_scale*nonlin_std + 1e-8) 46 | 47 | fig = plt.figure(figsize=(6,5)) 48 | plt.plot(np.array(n_hat), np.log10(lin_mu), 'k', label='Linear decoder') 49 | plt.fill_between(np.array(n_hat), lin_lower, lin_upper, 50 | facecolor='black', alpha=0.5) 51 | 52 | plt.plot(np.array(n_hat), np.log10(nonlin_mu), 'm', label='NOMAD') 53 | plt.fill_between(np.array(n_hat), nonlin_lower, nonlin_upper, 54 | facecolor='magenta', alpha=0.5) 55 | plt.legend(frameon=False) 56 | plt.xlabel(r'Latent dimension $n$') 57 | plt.ylabel(r'Relative $\mathcal{L}_2$ error ($\log_{10}$)') 58 | plt.xticks(n_hat) 59 | plt.savefig("advection_errors.png", bbox_inches='tight', dpi=600) 60 | 61 | -------------------------------------------------------------------------------- /Advection/Plot_results/eval_decay.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.sparse import diags 3 | import matplotlib.pyplot as plt 4 | 5 | import jax.numpy as jnp 6 | import numpy as onp 7 | import matplotlib.pyplot as plt 8 | from jax import random, jit, vmap 9 | 10 | plt.rcParams.update(plt.rcParamsDefault) 11 | plt.rc('font', family='serif') 12 | plt.rcParams.update({ 13 | "text.usetex": True, 14 | "font.family": "serif", 15 | 'text.latex.preamble': r'\usepackage{amsmath}', 16 | 'font.size': 20, 17 | 'lines.linewidth': 3, 18 | 'axes.labelsize': 22, # fontsize for x and y labels (was 10) 19 | 'axes.titlesize': 24, 20 | 'xtick.labelsize': 20, 21 | 'ytick.labelsize': 20, 22 | 'legend.fontsize': 20, 23 | 'axes.linewidth': 2}) 24 | 25 | Nt = 100 26 | Nx = 256 27 | N= 2000 28 | 29 | d = np.load("../Data/pure_advection_traintest.npz") 30 | curve_fs = d["solution"][:,:,:].reshape(N,Nt*Nx,1) 31 | 32 | curve_fs = curve_fs - np.mean(curve_fs,axis=0) 33 | 34 | def fPCA_eig(functions, gram_cov): 35 | gramevals, gramevecs = 
np.linalg.eigh(gram_cov) 36 | efuncs = np.matmul(gramevecs, functions.T).T 37 | evals, efuncs = np.flip(gramevals, axis=-1), np.flip(efuncs, axis=-1) 38 | return evals, efuncs 39 | 40 | from sklearn.metrics.pairwise import pairwise_distances 41 | gram_mat2 = (1./(Nt*Nx))*pairwise_distances(curve_fs[:,:,0],metric=np.dot) 42 | 43 | print(gram_mat2.shape) 44 | evals_rho, efuncs_rho = fPCA_eig(curve_fs[:,:,0].T, gram_mat2[:,:]) 45 | 46 | fig = plt.figure(figsize=(6,5)) 47 | plt.plot(evals_rho, 'k') 48 | plt.xlabel(r'Dimension index') 49 | plt.ylabel(r'Eigenvalue') 50 | plt.yscale('log') 51 | plt.xscale('log') 52 | plt.tight_layout() 53 | plt.savefig("advection_decay.png", bbox_inches='tight', dpi=600) -------------------------------------------------------------------------------- /Advection/Train_model/train_advection.py: -------------------------------------------------------------------------------- 1 | from jax.flatten_util import ravel_pytree 2 | from jax.example_libraries.stax import Dense, Gelu 3 | from jax.example_libraries import stax, optimizers 4 | import os 5 | import timeit 6 | 7 | import jax 8 | import jax.numpy as jnp 9 | import numpy as np 10 | from jax.numpy.linalg import norm 11 | from jax import random, grad, jit 12 | from functools import partial 13 | from torch.utils import data 14 | from tqdm import trange 15 | import itertools 16 | import argparse 17 | 18 | os.environ['XLA_PYTHON_CLIENT_PREALLOCATE']="False" 19 | 20 | class DataGenerator(data.Dataset): 21 | def __init__(self, u, y, s, 22 | batch_size=100, rng_key=random.PRNGKey(1234)): 23 | 'Initialization' 24 | self.u = u 25 | self.y = y 26 | self.s = s 27 | 28 | self.N = u.shape[0] 29 | self.batch_size = batch_size 30 | self.key = rng_key 31 | 32 | def __getitem__(self, index): 33 | 'Generate one batch of data' 34 | self.key, subkey = random.split(self.key) 35 | inputs,outputs = self.__data_generation(subkey) 36 | return inputs, outputs 37 | 38 | @partial(jit, static_argnums=(0,)) 39 | def __data_generation(self, key): 40 | 'Generates data containing batch_size samples' 41 | idx = random.choice(key, self.N, (self.batch_size,), replace=False) 42 | s = self.s[idx,:,:] 43 | u = self.u[idx,:,:] 44 | y = self.y[idx,:,:] 45 | inputs = (u, y) 46 | return inputs, s 47 | 48 | class operator_model: 49 | def __init__(self,branch_layers, trunk_layers , m=100, P=100,n=None, decoder=None, ds=None): 50 | 51 | seed = np.random.randint(low=0, high=100000) 52 | self.branch_init, self.branch_apply = self.init_NN(branch_layers, activation=Gelu) 53 | self.in_shape = (-1, branch_layers[0]) 54 | self.out_shape, branch_params = self.branch_init(random.PRNGKey(seed), self.in_shape) 55 | 56 | seed = np.random.randint(low=0, high=100000) 57 | self.trunk_init, self.trunk_apply = self.init_NN(trunk_layers, activation=Gelu) 58 | self.in_shape = (-1, trunk_layers[0]) 59 | self.out_shape, trunk_params = self.trunk_init(random.PRNGKey(seed), self.in_shape) 60 | 61 | params = (trunk_params, branch_params) 62 | # Use optimizers to set optimizer initialization and update functions 63 | self.opt_init,self.opt_update,self.get_params = optimizers.adam(optimizers.exponential_decay(1e-3, 64 | decay_steps=100, 65 | decay_rate=0.99)) 66 | self.opt_state = self.opt_init(params) 67 | # Logger 68 | self.itercount = itertools.count() 69 | self.loss_log = [] 70 | 71 | if decoder=="nonlinear": 72 | self.fwd = self.NOMAD 73 | if decoder=="linear": 74 | self.fwd = self.DeepONet 75 | 76 | self.n = n 77 | self.ds = ds 78 | 79 | 80 | def init_NN(self, Q, activation=Gelu): 
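        # Build a fully connected stax network from the width list Q: a Dense + Gelu pair
        # for every hidden width in Q[1:-1], followed by a final Dense(Q[-1]) with no
        # activation; returns the (init_fun, apply_fun) pair produced by stax.serial.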
81 | layers = [] 82 | num_layers = len(Q) 83 | if num_layers < 2: 84 | net_init, net_apply = stax.serial() 85 | else: 86 | for i in range(0, num_layers-2): 87 | layers.append(Dense(Q[i+1])) 88 | layers.append(activation) 89 | layers.append(Dense(Q[-1])) 90 | net_init, net_apply = stax.serial(*layers) 91 | return net_init, net_apply 92 | 93 | @partial(jax.jit, static_argnums=0) 94 | def NOMAD(self, params, inputs): 95 | trunk_params, branch_params = params 96 | inputsu, inputsy = inputs 97 | b = self.branch_apply(branch_params, inputsu.reshape(inputsu.shape[0], 1, inputsu.shape[1])) 98 | b = jnp.tile(b, (1,inputsy.shape[1],1)) 99 | inputs_recon = jnp.concatenate((jnp.tile(inputsy,(1,1,b.shape[-1]//inputsy.shape[-1])), b), axis=-1) 100 | out = self.trunk_apply(trunk_params, inputs_recon) 101 | return out 102 | 103 | @partial(jax.jit, static_argnums=0) 104 | def DeepONet(self, params, inputs): 105 | trunk_params, branch_params = params 106 | inputsxu, inputsy = inputs 107 | t = self.trunk_apply(trunk_params, inputsy).reshape(inputsy.shape[0], inputsy.shape[1], self.ds, self.n) 108 | b = self.branch_apply(branch_params, inputsxu.reshape(inputsxu.shape[0],1,inputsxu.shape[1]*inputsxu.shape[2])) 109 | b = b.reshape(b.shape[0],int(b.shape[2]/self.ds),self.ds) 110 | Guy = jnp.einsum("ijkl,ilk->ijk", t,b) 111 | return Guy 112 | 113 | @partial(jax.jit, static_argnums=0) 114 | def loss(self, params, batch): 115 | inputs, y = batch 116 | y_pred = self.fwd(params,inputs) 117 | loss = np.mean((y.flatten() - y_pred.flatten())**2) 118 | return loss 119 | 120 | @partial(jax.jit, static_argnums=0) 121 | def L2error(self, params, batch): 122 | inputs, y = batch 123 | y_pred = self.fwd(params,inputs) 124 | return norm(y.flatten() - y_pred.flatten(), 2)/norm(y.flatten(),2) 125 | 126 | @partial(jit, static_argnums=(0,)) 127 | def step(self, i, opt_state, batch): 128 | params = self.get_params(opt_state) 129 | g = grad(self.loss)(params, batch) 130 | return self.opt_update(i, g, opt_state) 131 | 132 | def train(self, train_dataset, test_dataset, nIter = 10000): 133 | train_data = iter(train_dataset) 134 | test_data = iter(test_dataset) 135 | 136 | pbar = trange(nIter) 137 | for it in pbar: 138 | train_batch = next(train_data) 139 | test_batch = next(test_data) 140 | 141 | self.opt_state = self.step(next(self.itercount), self.opt_state, train_batch) 142 | 143 | if it % 100 == 0: 144 | params = self.get_params(self.opt_state) 145 | 146 | loss_train = self.loss(params, train_batch) 147 | loss_test = self.loss(params, test_batch) 148 | 149 | errorTrain = self.L2error(params, train_batch) 150 | errorTest = self.L2error(params, test_batch) 151 | 152 | self.loss_log.append(loss_train) 153 | 154 | pbar.set_postfix({'Training loss': loss_train, 155 | 'Testing loss' : loss_test, 156 | 'Test error': errorTest, 157 | 'Train error': errorTrain}) 158 | 159 | @partial(jit, static_argnums=(0,)) 160 | def predict(self, params, inputs): 161 | s_pred = self.fwd(params,inputs) 162 | return s_pred 163 | 164 | def count_params(self): 165 | params = self.get_params(self.opt_state) 166 | params_flat, _ = ravel_pytree(params) 167 | print("The number of model parameters is:",params_flat.shape[0]) 168 | 169 | def main(n, decoder): 170 | TRAINING_ITERATIONS = 20000 171 | P = 25600 172 | m = 256 173 | num_train = 1000 174 | num_test = 1000 175 | training_batch_size = 100 176 | du = 1 177 | dy = 2 178 | ds = 1 179 | Nx = 256 180 | Nt = 100 181 | 182 | d = np.load("../Data/pure_advection_traintest.npz") 183 | U_train = d["ic"][:num_train] 184 
| s_train = d["solution"][:num_train] 185 | U_test = d["ic"][-num_test:] 186 | s_test = d["solution"][-num_test:] 187 | 188 | x = np.linspace(0,2,num=Nx) 189 | t = np.linspace(0,1,num=Nt) 190 | 191 | TT, XX = np.meshgrid(t,x,indexing='ij') 192 | 193 | y_train = jnp.tile(np.concatenate((TT.flatten()[:,None],XX.flatten()[:,None]),axis=-1)[None,...],(num_train,1,1)) 194 | y_test = jnp.tile(np.concatenate((TT.flatten()[:,None],XX.flatten()[:,None]),axis=-1)[None,...],(num_test,1,1)) 195 | 196 | del d 197 | 198 | U_train = jnp.asarray(U_train) 199 | y_train = jnp.asarray(y_train) 200 | s_train = jnp.asarray(s_train) 201 | 202 | U_test = jnp.asarray(U_test) 203 | y_test = jnp.asarray(y_test) 204 | s_test = jnp.asarray(s_test) 205 | 206 | U_train = jnp.reshape(U_train,(num_train,m,du)) 207 | y_train = jnp.reshape(y_train,(num_train,P,dy)) 208 | s_train = jnp.reshape(s_train,(num_train,P,ds)) 209 | 210 | U_test = jnp.reshape(U_test,(num_test,m,du)) 211 | y_test = jnp.reshape(y_test,(num_test,P,dy)) 212 | s_test = jnp.reshape(s_test,(num_test,P,ds)) 213 | 214 | 215 | train_dataset = DataGenerator(U_train, y_train, s_train, training_batch_size) 216 | train_dataset = iter(train_dataset) 217 | 218 | test_dataset = DataGenerator(U_test, y_test, s_test, training_batch_size) 219 | test_dataset = iter(test_dataset) 220 | 221 | if decoder=="nonlinear": 222 | branch_layers = [m, 100, 100, 100, 100, 100, ds*n] 223 | trunk_layers = [ds*n*2, 100, 100, 100, 100, 100, ds] 224 | elif decoder=="linear": 225 | branch_layers = [m, 100, 100, 100, 100, 100, ds*n] 226 | trunk_layers = [dy, 100, 100, 100, 100, 100, ds*n] 227 | 228 | model = operator_model(branch_layers, trunk_layers, m=m, P=P,n=n, decoder=decoder, ds=ds) 229 | model.count_params() 230 | 231 | start_time = timeit.default_timer() 232 | model.train(train_dataset, test_dataset, nIter=TRAINING_ITERATIONS) 233 | elapsed = timeit.default_timer() - start_time 234 | print("The training wall-clock time is seconds is equal to %f seconds"%elapsed) 235 | 236 | del train_dataset, test_dataset 237 | 238 | params = model.get_params(model.opt_state) 239 | 240 | s_pred_test = np.zeros_like(s_test) 241 | for i in range(0,num_test,100): 242 | idx = i + np.arange(0,100) 243 | s_pred_test[idx] = model.predict(params, (U_test[idx], y_test[idx])) 244 | test_error_u = [] 245 | for i in range(0,num_train): 246 | test_error_u.append(norm(s_test[i,:,0]- s_pred_test[i,:,0],2)/norm(s_test[i,:,0],2)) 247 | print("The average test u error is %e the standard deviation is %e the min error is %e and the max error is %e"%(np.mean(test_error_u),np.std(test_error_u),np.min(test_error_u),np.max(test_error_u))) 248 | 249 | s_pred_train = np.zeros_like(s_train) 250 | for i in range(0,num_train,100): 251 | idx = i + np.arange(0,100) 252 | s_pred_train[idx] = model.predict(params, (U_train[idx], y_train[idx])) 253 | train_error_u = [] 254 | for i in range(0,num_test): 255 | train_error_u.append(norm(s_train[i,:,0]- s_pred_train[i,:,0],2)/norm(s_train[i,:,0],2)) 256 | print("The average train u error is %e"%(np.mean(train_error_u))) 257 | 258 | if __name__ == "__main__": 259 | parser = argparse.ArgumentParser(description='Process model parameters.') 260 | parser.add_argument('n', metavar='n', type=int, nargs='+', help='Latent dimension of the solution manifold') 261 | parser.add_argument('decoder', metavar='decoder', type=str, nargs='+', help='Type of decoder. 
Choices a)"linear" b)"nonlinear"') 262 | 263 | args = parser.parse_args() 264 | n = args.n[0] 265 | decoder = args.decoder[0] 266 | main(n,decoder) 267 | -------------------------------------------------------------------------------- /Antiderivative/Plot_results/error_lineplots.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import seaborn as sns 4 | import pandas as pd 5 | import matplotlib.pyplot as plt 6 | 7 | plt.rcParams.update(plt.rcParamsDefault) 8 | plt.rc('font', family='serif') 9 | plt.rcParams.update({ 10 | "text.usetex": True, 11 | "font.family": "serif", 12 | 'text.latex.preamble': r'\usepackage{amsmath}', 13 | 'font.size': 20, 14 | 'lines.linewidth': 3, 15 | 'axes.labelsize': 22, 16 | 'axes.titlesize': 24, 17 | 'xtick.labelsize': 20, 18 | 'ytick.labelsize': 20, 19 | 'legend.fontsize': 20, 20 | 'axes.linewidth': 2}) 21 | 22 | 23 | n_hat = [1, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100] 24 | iterations = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 25 | test_error_DON_linear = np.zeros((len(n_hat), len(iterations))) 26 | test_error_DON_nonlinear = np.zeros((len(n_hat), len(iterations))) 27 | 28 | fig = plt.figure(figsize=(15,4)) 29 | 30 | ax3 = fig.add_subplot(1, 1, 1) 31 | for i in range(len(n_hat)): 32 | n = n_hat[i] 33 | for j in range(len(iterations)): 34 | it = iterations[j] 35 | d = np.load("../Error_Vectors/Error_Antiderivative_DeepONet_nhat%d_iteration%d_linear.npz"%(n,it)) 36 | test_error_DON_linear[i,j] = np.mean(d["test_error"]) 37 | 38 | d = np.load("../Error_Vectors/Error_Antiderivative_DeepONet_nhat%d_iteration%d_nonlinear.npz"%(n,it)) 39 | test_error_DON_nonlinear[i,j] = np.mean(d["test_error"]) 40 | 41 | DON_linear = np.tile(np.array(["Linear Decoder"])[None,:],(1000,1)) 42 | DON_nonlinear = np.tile(np.array(["Non-linear Decoder"])[None,:],(1000,1)) 43 | 44 | position = np.tile(np.array(['1','10','20','30','40','50', '60', '70', '80', '90', '100'])[None,:],(1000,1)) 45 | 46 | DON_all_linear = list(zip(test_error_DON_linear.T.flatten(), DON_linear.flatten(), position.flatten())) 47 | DON_all_nonlinear = list(zip(test_error_DON_nonlinear.T.flatten(), DON_nonlinear.flatten(), position.flatten())) 48 | 49 | all_data = DON_all_linear + DON_all_nonlinear 50 | 51 | df_OKA = pd.DataFrame(all_data, columns = ["Relative $\mathcal{L}_2$ error","Method", "Latent dimension size"]) 52 | flierprops = dict(markerfacecolor='0.75', markersize=0.5, 53 | linestyle='none') 54 | ax3 = sns.lineplot(x="Latent dimension size", y="Relative $\mathcal{L}_2$ error", hue="Method", data=df_OKA, palette="Set1") 55 | ax3.legend(loc='upper center', bbox_to_anchor=(0.6, 1), 56 | fancybox=True, shadow=False, ncol=1) 57 | fig.tight_layout() 58 | 59 | plt.savefig("lineplots.jpg", bbox_inches='tight', pad_inches=0,dpi=300) -------------------------------------------------------------------------------- /Antiderivative/Plot_results/eval_decay.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | import jax.numpy as jnp 5 | from functools import partial 6 | from jax import vmap 7 | from numpy.polynomial.legendre import leggauss 8 | 9 | plt.rcParams.update(plt.rcParamsDefault) 10 | plt.rc('font', family='serif') 11 | plt.rcParams.update({ 12 | "text.usetex": True, 13 | "font.family": "serif", 14 | 'text.latex.preamble': r'\usepackage{amsmath}', 15 | 'font.size': 20, 16 | 'lines.linewidth': 3, 17 | 'axes.labelsize': 22, # 
fontsize for x and y labels (was 10) 18 | 'axes.titlesize': 24, 19 | 'xtick.labelsize': 20, 20 | 'ytick.labelsize': 20, 21 | 'legend.fontsize': 20, 22 | 'axes.linewidth': 2}) 23 | 24 | 25 | fig = plt.figure(figsize=(15,4)) 26 | 27 | def f(x, t): 28 | return jnp.sin(2.0*jnp.pi*t*x) 29 | 30 | def exact_eigenpairs(x, n, alpha=2.0, tau=0.1): 31 | idx = jnp.arange(n)+1 32 | evals = jnp.power((2.0 * jnp.pi * idx)**2 + tau**2, -alpha) 33 | efuns = jnp.sqrt(2.0) * jnp.sin(2.0 * jnp.pi * idx * x) 34 | return evals, efuns 35 | 36 | @partial(vmap, in_axes=(1, None)) 37 | @partial(vmap, in_axes=(None, 1)) 38 | def gram(f, g): 39 | inner_product = lambda phi_i,phi_j: jnp.einsum('ij,i,j', jnp.diag(w), phi_i, phi_j) 40 | return inner_product(f, g) 41 | 42 | def fPCA_eig(functions, gram_cov): 43 | gramevals, gramevecs = jnp.linalg.eigh(gram_cov) 44 | efuncs = jnp.matmul(gramevecs, functions.T).T 45 | evals, efuncs = jnp.flip(gramevals, axis=-1), jnp.flip(efuncs, axis=-1) 46 | return evals, efuncs 47 | 48 | # returns quadrature nodes and weights 49 | def legendre_quadrature_1d(n_quad, bounds=(-1.0,1.0)): 50 | lb, ub = bounds 51 | # GLL nodes and weights in [-1,1] 52 | x, w = leggauss(n_quad) 53 | x = 0.5*(ub - lb)*(x + 1.0) + lb 54 | x = np.array(x[:,None]) 55 | jac_det = 0.5*(ub-lb) 56 | w = np.array(w*jac_det) 57 | return x, w 58 | 59 | 60 | n_quad = 500 61 | bounds = (0, 1) 62 | x, w = legendre_quadrature_1d(n_quad, bounds) 63 | inner_product = lambda phi_i,phi_j: jnp.einsum('ij,i,j', jnp.diag(w), phi_i, phi_j) 64 | 65 | 66 | t0 = 0 67 | t1 = 10 68 | num_funcs = 1000 69 | ts = np.linspace(t0, t1, num_funcs) 70 | curve_fs = vmap(lambda t: f(x, t))(ts)[...,0] 71 | 72 | mean = np.mean(curve_fs, axis=0) 73 | mean_curves = curve_fs - mean 74 | 75 | curve_fs = mean_curves.T 76 | gram_mat = gram(curve_fs, curve_fs) 77 | evals, efuncs = fPCA_eig(curve_fs, gram_mat) 78 | 79 | ax2 = fig.add_subplot(1, 1, 1) 80 | ax2.loglog(evals[:50], 'b-', alpha=1.0, label='Eigenvalue Decay') 81 | ax2.set_xlabel(r'Index') 82 | ax2.set_ylabel(r'Eigenvalue') 83 | ax2.legend(loc='upper center', bbox_to_anchor=(0.60, 1), 84 | fancybox=True, shadow=False, ncol=1) 85 | ax2.autoscale(tight=True) 86 | 87 | plt.savefig("eval_decay.jpg", bbox_inches='tight', pad_inches=0,dpi=300) -------------------------------------------------------------------------------- /Antiderivative/Plot_results/plot_spiral.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import jax.numpy as jnp 4 | from functools import partial 5 | 6 | import numpy as np 7 | from functools import partial 8 | 9 | from jax import vmap 10 | from numpy.polynomial.legendre import leggauss 11 | 12 | plt.rcParams.update(plt.rcParamsDefault) 13 | plt.rc('font', family='serif') 14 | plt.rcParams.update({ 15 | "text.usetex": True, 16 | "font.family": "serif", 17 | 'text.latex.preamble': r'\usepackage{amsmath}', 18 | 'font.size': 20, 19 | 'lines.linewidth': 3, 20 | 'axes.labelsize': 22, # fontsize for x and y labels (was 10) 21 | 'axes.titlesize': 24, 22 | 'xtick.labelsize': 20, 23 | 'ytick.labelsize': 20, 24 | 'legend.fontsize': 20, 25 | 'axes.linewidth': 2}) 26 | 27 | 28 | fig = plt.figure(figsize=(15,4)) 29 | 30 | def f(x, t): 31 | return jnp.sin(2.0*jnp.pi*t*x) 32 | 33 | # Define pairwise inner product function 34 | # Takes two lists of functions and returns gram matrix of inner products 35 | @partial(vmap, in_axes=(1, None)) 36 | @partial(vmap, in_axes=(None, 1)) 37 | def gram(f, g): 38 
| inner_product = lambda phi_i,phi_j: jnp.einsum('ij,i,j', jnp.diag(w), phi_i, phi_j) 39 | return inner_product(f, g) 40 | 41 | # given a list of functions and gram matrix of pairwise inner products, gets 42 | # principal component functions and associated eigenvalues 43 | def fPCA_eig(functions, gram_cov): 44 | gramevals, gramevecs = jnp.linalg.eigh(gram_cov) 45 | efuncs = jnp.matmul(gramevecs, functions.T).T 46 | evals, efuncs = jnp.flip(gramevals, axis=-1), jnp.flip(efuncs, axis=-1) 47 | return evals, efuncs 48 | 49 | # returns quadrature nodes and weights 50 | def legendre_quadrature_1d(n_quad, bounds=(-1.0,1.0)): 51 | lb, ub = bounds 52 | # GLL nodes and weights in [-1,1] 53 | x, w = leggauss(n_quad) 54 | # Rescale nodes to [lb,ub] 55 | x = 0.5*(ub - lb)*(x + 1.0) + lb 56 | 57 | x = np.array(x[:,None]) 58 | # Determinant of Jacobian of mapping [lb,ub]-->[-1,1] 59 | jac_det = 0.5*(ub-lb) 60 | w = np.array(w*jac_det) 61 | return x, w 62 | 63 | 64 | n_quad = 500 65 | bounds = (0, 1) 66 | x, w = legendre_quadrature_1d(n_quad, bounds) 67 | inner_product = lambda phi_i,phi_j: jnp.einsum('ij,i,j', jnp.diag(w), phi_i, phi_j) 68 | 69 | t0 = 0 70 | t1 = 10 71 | num_funcs = 1000 72 | ts = np.linspace(t0, t1, num_funcs) 73 | curve_fs = vmap(lambda t: f(x, t))(ts)[...,0] 74 | 75 | mean = np.mean(curve_fs, axis=0) 76 | mean_curves = curve_fs - mean 77 | 78 | curve_fs = mean_curves.T 79 | gram_mat = gram(curve_fs, curve_fs) 80 | evals, efuncs = fPCA_eig(curve_fs, gram_mat) 81 | 82 | projs = vmap(lambda idx: vmap(lambda ft: inner_product(ft, efuncs[:,idx]))(curve_fs.T))(np.arange(num_funcs)) 83 | print(projs.shape) 84 | 85 | first_proj = projs[0] 86 | second_proj = projs[1] 87 | third_proj = projs[2] 88 | 89 | ax1 = fig.add_subplot(1, 3, 1, projection='3d') 90 | ax1.zaxis.set_rotate_label(False) # disable automatic rotation 91 | ax1.set_zlabel(r'z', rotation=90) 92 | for i in range(num_funcs-1): 93 | ax1.plot(first_proj[i:i+2], second_proj[i:i+2], third_proj[i:i+2], color=plt.cm.cool(i/num_funcs)) 94 | ax1.set_xlabel(r'x') 95 | ax1.set_ylabel(r'y') 96 | ax1.set_title(r'Projection onto principal components') 97 | 98 | plt.savefig("spiral.jpg", bbox_inches='tight', pad_inches=0,dpi=300) -------------------------------------------------------------------------------- /Antiderivative/Train_model/train_antiderivative.py: -------------------------------------------------------------------------------- 1 | 2 | from jax.flatten_util import ravel_pytree 3 | from jax.example_libraries.stax import Dense, Gelu 4 | from jax.example_libraries import stax, optimizers 5 | import os 6 | import timeit 7 | 8 | import jax 9 | import jax.numpy as jnp 10 | import numpy as np 11 | from jax.numpy.linalg import norm 12 | from jax import random, grad, jit, vmap 13 | from jax.experimental.ode import odeint 14 | from functools import partial 15 | from torch.utils import data 16 | from tqdm import trange 17 | import itertools 18 | import argparse 19 | 20 | os.environ['XLA_PYTHON_CLIENT_PREALLOCATE']="False" 21 | 22 | class DataGenerator(data.Dataset): 23 | def __init__(self, u, y, s, 24 | batch_size=100, rng_key=random.PRNGKey(1234)): 25 | 'Initialization' 26 | self.u = u 27 | self.y = y 28 | self.s = s 29 | 30 | self.N = u.shape[0] 31 | self.batch_size = batch_size 32 | self.key = rng_key 33 | 34 | def __getitem__(self, index): 35 | 'Generate one batch of data' 36 | self.key, subkey = random.split(self.key) 37 | inputs,outputs = self.__data_generation(subkey) 38 | return inputs, outputs 39 | 40 | @partial(jit, 
static_argnums=(0,)) 41 | def __data_generation(self, key): 42 | 'Generates data containing batch_size samples' 43 | idx = random.choice(key, self.N, (self.batch_size,), replace=False) 44 | s = self.s[idx,:,:] 45 | u = self.u[idx,:,:] 46 | y = self.y[idx,:,:] 47 | inputs = (u, y) 48 | return inputs, s 49 | 50 | # Geneate training data corresponding to one input sample 51 | def generate_one_datum(freq, m=500, P=500): 52 | X = jnp.linspace(0,1, num=m) 53 | u = 2*jnp.pi*freq*jnp.cos(2*jnp.pi*freq*X) 54 | u_fn = lambda x, t: jnp.interp(t, X.flatten(), u) 55 | u = vmap(u_fn, in_axes=(None,0))(0.0, X) 56 | y_train = jnp.linspace(0, 1, P) 57 | s_train = odeint(u_fn, 0.0, y_train) 58 | return u, y_train, s_train 59 | 60 | # Geneate training data corresponding to N input sample 61 | def generate_data(freqs, N, m): 62 | gen_fn = jit(lambda freq: generate_one_datum(freq, m)) 63 | u_train, y_train, s_train = vmap(gen_fn)(freqs) 64 | return u_train, y_train, s_train 65 | 66 | class operator_model: 67 | def __init__(self,branch_layers, trunk_layers , m=100, P=100,n=None, decoder=None, ds=None): 68 | 69 | seed = np.random.randint(low=0, high=100000) 70 | self.branch_init, self.branch_apply = self.init_NN(branch_layers, activation=Gelu) 71 | self.in_shape = (-1, branch_layers[0]) 72 | self.out_shape, branch_params = self.branch_init(random.PRNGKey(seed), self.in_shape) 73 | 74 | seed = np.random.randint(low=0, high=100000) 75 | self.trunk_init, self.trunk_apply = self.init_NN(trunk_layers, activation=Gelu) 76 | self.in_shape = (-1, trunk_layers[0]) 77 | self.out_shape, trunk_params = self.trunk_init(random.PRNGKey(seed), self.in_shape) 78 | 79 | params = (trunk_params, branch_params) 80 | # Use optimizers to set optimizer initialization and update functions 81 | self.opt_init,self.opt_update,self.get_params = optimizers.adam(optimizers.exponential_decay(1e-3, 82 | decay_steps=100, 83 | decay_rate=0.99)) 84 | self.opt_state = self.opt_init(params) 85 | # Logger 86 | self.itercount = itertools.count() 87 | self.loss_log = [] 88 | 89 | if decoder=="nonlinear": 90 | self.fwd = self.NOMAD 91 | if decoder=="linear": 92 | self.fwd = self.DeepONet 93 | 94 | self.n = n 95 | self.ds = ds 96 | 97 | 98 | def init_NN(self, Q, activation=Gelu): 99 | layers = [] 100 | num_layers = len(Q) 101 | if num_layers < 2: 102 | net_init, net_apply = stax.serial() 103 | else: 104 | for i in range(0, num_layers-2): 105 | layers.append(Dense(Q[i+1])) 106 | layers.append(activation) 107 | layers.append(Dense(Q[-1])) 108 | net_init, net_apply = stax.serial(*layers) 109 | return net_init, net_apply 110 | 111 | @partial(jax.jit, static_argnums=0) 112 | def NOMAD(self, params, inputs): 113 | trunk_params, branch_params = params 114 | inputsu, inputsy = inputs 115 | b = self.branch_apply(branch_params, inputsu.reshape(inputsu.shape[0], 1, inputsu.shape[1])) 116 | b = jnp.tile(b, (1,inputsy.shape[1],1)) 117 | inputs_recon = jnp.concatenate((jnp.tile(inputsy,(1,1,b.shape[-1]//inputsy.shape[-1])), b), axis=-1) 118 | out = self.trunk_apply(trunk_params, inputs_recon) 119 | return out 120 | 121 | @partial(jax.jit, static_argnums=0) 122 | def DeepONet(self, params, inputs): 123 | trunk_params, branch_params = params 124 | inputsxu, inputsy = inputs 125 | t = self.trunk_apply(trunk_params, inputsy).reshape(inputsy.shape[0], inputsy.shape[1], self.ds, self.n) 126 | b = self.branch_apply(branch_params, inputsxu.reshape(inputsxu.shape[0],1,inputsxu.shape[1]*inputsxu.shape[2])) 127 | b = b.reshape(b.shape[0],int(b.shape[2]/self.ds),self.ds) 128 | 
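        # Linear decoder: contract the trunk basis t of shape (batch, P, ds, n) with the
        # branch coefficients b of shape (batch, n, ds) over the latent dimension n,
        # giving outputs of shape (batch, P, ds).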
Guy = jnp.einsum("ijkl,ilk->ijk", t,b) 129 | return Guy 130 | 131 | @partial(jax.jit, static_argnums=0) 132 | def loss(self, params, batch): 133 | inputs, y = batch 134 | y_pred = self.fwd(params,inputs) 135 | loss = np.mean((y.flatten() - y_pred.flatten())**2) 136 | return loss 137 | 138 | @partial(jax.jit, static_argnums=0) 139 | def L2error(self, params, batch): 140 | inputs, y = batch 141 | y_pred = self.fwd(params,inputs) 142 | return norm(y.flatten() - y_pred.flatten(), 2)/norm(y.flatten(),2) 143 | 144 | @partial(jit, static_argnums=(0,)) 145 | def step(self, i, opt_state, batch): 146 | params = self.get_params(opt_state) 147 | g = grad(self.loss)(params, batch) 148 | return self.opt_update(i, g, opt_state) 149 | 150 | def train(self, train_dataset, test_dataset, nIter = 10000): 151 | train_data = iter(train_dataset) 152 | test_data = iter(test_dataset) 153 | 154 | pbar = trange(nIter) 155 | for it in pbar: 156 | train_batch = next(train_data) 157 | test_batch = next(test_data) 158 | 159 | self.opt_state = self.step(next(self.itercount), self.opt_state, train_batch) 160 | 161 | if it % 100 == 0: 162 | params = self.get_params(self.opt_state) 163 | 164 | loss_train = self.loss(params, train_batch) 165 | loss_test = self.loss(params, test_batch) 166 | 167 | errorTrain = self.L2error(params, train_batch) 168 | errorTest = self.L2error(params, test_batch) 169 | 170 | self.loss_log.append(loss_train) 171 | 172 | pbar.set_postfix({'Training loss': loss_train, 173 | 'Testing loss' : loss_test, 174 | 'Test error': errorTest, 175 | 'Train error': errorTrain}) 176 | 177 | @partial(jit, static_argnums=(0,)) 178 | def predict(self, params, inputs): 179 | s_pred = self.fwd(params,inputs) 180 | return s_pred 181 | 182 | def count_params(self): 183 | params = self.get_params(self.opt_state) 184 | params_flat, _ = ravel_pytree(params) 185 | print("The number of model parameters is:",params_flat.shape[0]) 186 | 187 | def main(n, decoder): 188 | TRAINING_ITERATIONS = 20000 189 | P = 500 190 | m = 500 191 | num_train = 1000 192 | num_test = 1000 193 | training_batch_size = 100 194 | du = 1 195 | dy = 1 196 | ds = 1 197 | Nx = 100 198 | 199 | key_train = random.PRNGKey(0) 200 | _, subkey = random.split(key_train) 201 | train_freqs = random.uniform(subkey, minval=0, maxval=10, shape=(num_train,)) 202 | U_train, y_train, s_train = generate_data(train_freqs, num_train, m) 203 | key_test = random.PRNGKey(12345) 204 | _, subkey = random.split(key_test) 205 | test_freqs = random.uniform(subkey, minval=0, maxval=10, shape=(num_test,)) 206 | U_test, y_test, s_test = generate_data(test_freqs, num_test, m) 207 | 208 | U_train = jnp.asarray(U_train) 209 | y_train = jnp.asarray(y_train) 210 | s_train = jnp.asarray(s_train) 211 | 212 | U_test = jnp.asarray(U_test) 213 | y_test = jnp.asarray(y_test) 214 | s_test = jnp.asarray(s_test) 215 | 216 | U_train = jnp.reshape(U_train,(num_train,m,du)) 217 | y_train = jnp.reshape(y_train,(num_train,P,dy)) 218 | s_train = jnp.reshape(s_train,(num_train,P,ds)) 219 | 220 | U_test = jnp.reshape(U_test,(num_test,m,du)) 221 | y_test = jnp.reshape(y_test,(num_test,P,dy)) 222 | s_test = jnp.reshape(s_test,(num_test,P,ds)) 223 | 224 | 225 | train_dataset = DataGenerator(U_train, y_train, s_train, training_batch_size) 226 | train_dataset = iter(train_dataset) 227 | test_dataset = DataGenerator(U_test, y_test, s_test, training_batch_size) 228 | test_dataset = iter(test_dataset) 229 | 230 | if decoder=="nonlinear": 231 | branch_layers = [m, 100, 100, 100, 100, 100, ds*n] 232 | 
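        # NOMAD trunk input width is 2*ds*n: in the nonlinear decoder the scalar query
        # coordinate y is tiled up to the width ds*n of the branch output and the two
        # are concatenated before being passed to the trunk (see NOMAD above).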
trunk_layers = [ds*n*2, 100, 100, 100, 100, 100, ds] 233 | elif decoder=="linear": 234 | branch_layers = [m, 100, 100, 100, 100, 100, ds*n] 235 | trunk_layers = [dy, 100, 100, 100, 100, 100, ds*n] 236 | 237 | model = operator_model(branch_layers, trunk_layers, m=m, P=P,n=n, decoder=decoder, ds=ds) 238 | model.count_params() 239 | 240 | start_time = timeit.default_timer() 241 | model.train(train_dataset, test_dataset, nIter=TRAINING_ITERATIONS) 242 | elapsed = timeit.default_timer() - start_time 243 | print("The training wall-clock time is seconds is equal to %f seconds"%elapsed) 244 | 245 | del train_dataset, test_dataset 246 | 247 | params = model.get_params(model.opt_state) 248 | 249 | s_pred_test = np.zeros_like(s_test) 250 | for i in range(0,num_test,100): 251 | idx = i + np.arange(0,100) 252 | s_pred_test[idx] = model.predict(params, (U_test[idx], y_test[idx])) 253 | test_error_u = [] 254 | for i in range(0,num_train): 255 | test_error_u.append(norm(s_test[i,:,0]- s_pred_test[i,:,0],2)/norm(s_test[i,:,0],2)) 256 | print("The average test u error is %e the standard deviation is %e the min error is %e and the max error is %e"%(np.mean(test_error_u),np.std(test_error_u),np.min(test_error_u),np.max(test_error_u))) 257 | 258 | s_pred_train = np.zeros_like(s_train) 259 | for i in range(0,num_train,100): 260 | idx = i + np.arange(0,100) 261 | s_pred_train[idx] = model.predict(params, (U_train[idx], y_train[idx])) 262 | train_error_u = [] 263 | for i in range(0,num_test): 264 | train_error_u.append(norm(s_train[i,:,0]- s_pred_train[i,:,0],2)/norm(s_train[i,:,0],2)) 265 | print("The average train u error is %e"%(np.mean(train_error_u))) 266 | 267 | if __name__ == "__main__": 268 | parser = argparse.ArgumentParser(description='Process model parameters.') 269 | parser.add_argument('n', metavar='n', type=int, nargs='+', help='Latent dimension of the solution manifold') 270 | parser.add_argument('decoder', metavar='decoder', type=str, nargs='+', help='Type of decoder. Choices a)"linear" b)"nonlinear"') 271 | 272 | args = parser.parse_args() 273 | n = args.n[0] 274 | decoder = args.decoder[0] 275 | main(n,decoder) 276 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NOMAD 2 | ## Nonlinear Manifold Decoders for Operator Learning 3 | 4 | ![master_figure-2](https://user-images.githubusercontent.com/3844367/195421218-164f4be9-f258-4bed-acba-484d67cae7a3.png) 5 | 6 | This repository contains code and data accompanying the manuscript titled "Nonlinear Manifold Decoders for Operator Learning", authored by Jacob Seidman*, Georgios Kissas*, Paris Perdikaris, and George Pappas. 7 | 8 | ## Abstract 9 | 10 | Supervised learning on function spaces is an emerging area of machine learning research with applications to the prediction of complex physical systems such as fluid flows, solid mechanics, and climate modelling. By directly learning maps (operators) between infinite dimensional function spaces, these models are able to learn discretization invariant representations of target functions. A common approach is to represent such target functions as linear combinations of basis elements learned from data. However, there are simple scenarios where, even though the target functions form a low dimensional submanifold, a very large number of basis elements is needed for an accurate linear representation. 
Here we propose a novel operator learning framework capable of learning finite-dimensional coordinate representations for nonlinear submanifolds in function spaces. We show this method is able to accurately learn low dimensional representations of solution manifolds to partial differential equations while outperforming linear models of larger size. Additionally, we compare to state-of-the-art operator learning methods on a complex fluid dynamics benchmark and achieve competitive performance with a significantly smaller model size and training cost. 11 | 12 | ## Code Repository 13 | 14 | The repository contains three folders, one for each example presented in the manuscript: the Antiderivative, the Advection, and the Shallow Water 15 | folders. 16 | 17 | Each example folder should contain 4 subfolders: 18 | 19 | - a) "Train_model": containing the Python file used for training the model 20 | - b) "Error_Vectors": containing the error vectors resulting from multiple model runs, used to generate the figures in the manuscript 21 | - c) "Data": containing the data sets used for training and testing the model (the Antiderivative example does not contain this folder 22 | because its data sets are generated on the fly) 23 | - d) "Plot_results": containing the Python scripts that can be used to reproduce the figures presented in the paper. 24 | 25 | Folders a) and d) can be found in this GitHub repository and contain only Python script files. 26 | 27 | All the data and code required to reproduce the results can be downloaded from the following direct Google Drive link: 28 | 29 | https://drive.google.com/file/d/1xEzD2swxBcBR5FdHZfc9m7o0Fhe7Z3jB/view?usp=sharing 30 | 31 | ## Code usage 32 | 33 | To train the advection model, run the file in the "Train_model" folder with two arguments: n, the size of the latent dimension, and the type of decoder, "linear" or "nonlinear". For example, running "python train_advection.py 10 nonlinear" trains the model for the advection case using a solution manifold latent dimension of size 10 and a nonlinear decoder, while "python train_advection.py 10 linear" repeats the process using a linear decoder. 34 | 35 | For the Antiderivative case, run the file in the "Train_model" folder in the same manner, i.e. "python train_antiderivative.py 10 nonlinear". The same logic applies to the Shallow Water equations: run "python train_SW.py 10 nonlinear" in the "Train_model" folder. 36 | 37 | To make the plots, run the Python files in the "Plot_results" folders without any arguments, e.g. for the advection case "python analytic_solution.py".
38 | 39 | 40 | ## Citation 41 | 42 | @article{seidman2022nomad, 43 | title={NOMAD: Nonlinear Manifold Decoders for Operator Learning}, 44 | author={Seidman, Jacob H and Kissas, Georgios and Perdikaris, Paris and Pappas, George J}, 45 | journal={arXiv preprint arXiv:2206.03551}, 46 | year={2022} 47 | } 48 | -------------------------------------------------------------------------------- /Shallow Water /Plot_results/plot_results.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from matplotlib import pyplot as plt 3 | from scipy import stats 4 | 5 | plt.rcParams.update(plt.rcParamsDefault) 6 | plt.rc('font', family='serif') 7 | plt.rcParams.update({ 8 | "text.usetex": True, 9 | "font.family": "serif", 10 | 'text.latex.preamble': r'\usepackage{amsmath}', 11 | 'font.size': 20, 12 | 'lines.linewidth': 3, 13 | 'axes.labelsize': 22, # fontsize for x and y labels (was 10) 14 | 'axes.titlesize': 24, 15 | 'xtick.labelsize': 20, 16 | 'ytick.labelsize': 20, 17 | 'legend.fontsize': 20, 18 | 'axes.linewidth': 2}) 19 | 20 | # load the dataset 21 | n_hat = [1, 2, 5, 10, 30, 50, 70, 100] 22 | iterations = [0,1, 2, 3, 4, 5, 6, 7,8,9] 23 | par = 0 24 | test_error_DON_linear = np.zeros((len(n_hat), len(iterations))) 25 | test_error_DON_nonlinear = np.zeros((len(n_hat), len(iterations))) 26 | 27 | for i in range(len(n_hat)): 28 | n = n_hat[i] 29 | for j in range(len(iterations)): 30 | it = iterations[j] 31 | d = np.load("../Error_vectors/Error_SW_DeepONet_nhat%d_iteration%d_linear.npz"%(n,it)) 32 | test_error_DON_linear[i,j] = np.mean(d["test_error"],axis=1)[...,par] 33 | 34 | d = np.load("../Error_vectors/Error_SW_DeepONet_nhat%d_iteration%d_nonlinear.npz"%(n,it)) 35 | test_error_DON_nonlinear[i,j] = np.mean(d["test_error"],axis=1)[...,par] 36 | 37 | lin_mu, lin_std = np.median(test_error_DON_linear, axis = 1), stats.median_abs_deviation(test_error_DON_linear, axis = 1) 38 | nonlin_mu, nonlin_std = np.median(test_error_DON_nonlinear, axis = 1), stats.median_abs_deviation(test_error_DON_nonlinear, axis = 1) 39 | 40 | dispersion_scale = 1.0 41 | lin_lower = np.log10(np.clip(lin_mu - dispersion_scale*lin_std, a_min=0., a_max = np.inf) + 1e-8) 42 | lin_upper = np.log10(lin_mu + dispersion_scale*lin_std + 1e-8) 43 | 44 | nonlin_lower = np.log10(np.clip(nonlin_mu - dispersion_scale*nonlin_std, a_min=0., a_max = np.inf) + 1e-8) 45 | nonlin_upper = np.log10(nonlin_mu + dispersion_scale*nonlin_std + 1e-8) 46 | 47 | fig = plt.figure(figsize=(8,7)) 48 | ax1 = fig.add_subplot(3,1,1) 49 | ax1.plot(np.array(n_hat), np.log10(lin_mu), 'k', label='Linear Decoder') 50 | ax1.fill_between(np.array(n_hat), lin_lower, lin_upper, 51 | facecolor='black', alpha=0.5), 100 52 | 53 | ax1.plot(np.array(n_hat), np.log10(nonlin_mu), 'm', label='NOMAD') 54 | ax1.fill_between(np.array(n_hat), nonlin_lower, nonlin_upper, 55 | facecolor='magenta', alpha=0.5) 56 | ax1.legend(frameon=False) 57 | ax1.set_xticks([1,5,10,30, 50, 70, 100]) 58 | axR1 = fig.add_subplot(3,1,1, sharex=ax1, frameon=False) 59 | axR1.yaxis.tick_right() 60 | axR1.yaxis.set_label_position("right") 61 | axR1.axes.yaxis.set_ticklabels([]) 62 | axR1.set_ylabel(r'$\rho$') 63 | 64 | test_error_DON_linear = np.zeros((len(n_hat), len(iterations))) 65 | test_error_DON_nonlinear = np.zeros((len(n_hat), len(iterations))) 66 | par = 1 67 | for i in range(len(n_hat)): 68 | n = n_hat[i] 69 | for j in range(len(iterations)): 70 | it = iterations[j] 71 | d = 
np.load("../Error_vectors/Error_SW_DeepONet_nhat%d_iteration%d_linear.npz"%(n,it)) 72 | test_error_DON_linear[i,j] = np.mean(d["test_error"],axis=1)[...,par] 73 | 74 | d = np.load("../Error_vectors/Error_SW_DeepONet_nhat%d_iteration%d_nonlinear.npz"%(n,it)) 75 | test_error_DON_nonlinear[i,j] = np.mean(d["test_error"],axis=1)[...,par] 76 | 77 | lin_mu, lin_std = np.median(test_error_DON_linear, axis = 1), stats.median_abs_deviation(test_error_DON_linear, axis = 1) 78 | nonlin_mu, nonlin_std = np.median(test_error_DON_nonlinear, axis = 1), stats.median_abs_deviation(test_error_DON_nonlinear, axis = 1) 79 | 80 | dispersion_scale = 1.0 81 | lin_lower = np.log10(np.clip(lin_mu - dispersion_scale*lin_std, a_min=0., a_max = np.inf) + 1e-8) 82 | lin_upper = np.log10(lin_mu + dispersion_scale*lin_std + 1e-8) 83 | 84 | nonlin_lower = np.log10(np.clip(nonlin_mu - dispersion_scale*nonlin_std, a_min=0., a_max = np.inf) + 1e-8) 85 | nonlin_upper = np.log10(nonlin_mu + dispersion_scale*nonlin_std + 1e-8) 86 | 87 | ax2 = fig.add_subplot(3,1,2) 88 | ax2.plot(np.array(n_hat), np.log10(lin_mu), 'k', label='Linear Decoder') 89 | ax2.fill_between(np.array(n_hat), lin_lower, lin_upper, 90 | facecolor='black', alpha=0.5), 100 91 | 92 | ax2.plot(np.array(n_hat), np.log10(nonlin_mu), 'm', label='NOMAD') 93 | ax2.fill_between(np.array(n_hat), nonlin_lower, nonlin_upper, 94 | facecolor='magenta', alpha=0.5) 95 | ax2.legend(frameon=False) 96 | ax2.set_ylabel(r'Relative $\mathcal{L}_2$ error ($\log_{10}$)') 97 | ax2.set_xticks([1,5,10,30, 50, 70, 100]) 98 | axR2 = fig.add_subplot(3,1,2, sharex=ax2, frameon=False) 99 | axR2.yaxis.tick_right() 100 | axR2.yaxis.set_label_position("right") 101 | axR2.axes.yaxis.set_ticklabels([]) 102 | axR2.set_ylabel(r'$v_1$') 103 | 104 | test_error_DON_linear = np.zeros((len(n_hat), len(iterations))) 105 | test_error_DON_nonlinear = np.zeros((len(n_hat), len(iterations))) 106 | par = 2 107 | 108 | for i in range(len(n_hat)): 109 | n = n_hat[i] 110 | for j in range(len(iterations)): 111 | it = iterations[j] 112 | d = np.load("../Error_vectors/Error_SW_DeepONet_nhat%d_iteration%d_linear.npz"%(n,it)) 113 | test_error_DON_linear[i,j] = np.mean(d["test_error"],axis=1)[...,par] 114 | 115 | d = np.load("../Error_vectors/Error_SW_DeepONet_nhat%d_iteration%d_nonlinear.npz"%(n,it)) 116 | test_error_DON_nonlinear[i,j] = np.mean(d["test_error"],axis=1)[...,par] 117 | 118 | lin_mu, lin_std = np.median(test_error_DON_linear, axis = 1), stats.median_abs_deviation(test_error_DON_linear, axis = 1) 119 | nonlin_mu, nonlin_std = np.median(test_error_DON_nonlinear, axis = 1), stats.median_abs_deviation(test_error_DON_nonlinear, axis = 1) 120 | 121 | dispersion_scale = 1.0 122 | lin_lower = np.log10(np.clip(lin_mu - dispersion_scale*lin_std, a_min=0., a_max = np.inf) + 1e-8) 123 | lin_upper = np.log10(lin_mu + dispersion_scale*lin_std + 1e-8) 124 | 125 | nonlin_lower = np.log10(np.clip(nonlin_mu - dispersion_scale*nonlin_std, a_min=0., a_max = np.inf) + 1e-8) 126 | nonlin_upper = np.log10(nonlin_mu + dispersion_scale*nonlin_std + 1e-8) 127 | 128 | ax3 = fig.add_subplot(3,1,3) 129 | ax3.plot(np.array(n_hat), np.log10(lin_mu), 'k', label='Linear Decoder') 130 | ax3.fill_between(np.array(n_hat), lin_lower, lin_upper, 131 | facecolor='black', alpha=0.5), 100 132 | 133 | ax3.plot(np.array(n_hat), np.log10(nonlin_mu), 'm', label='NOMAD') 134 | ax3.fill_between(np.array(n_hat), nonlin_lower, nonlin_upper, 135 | facecolor='magenta', alpha=0.5) 136 | ax3.legend(frameon=False) 137 | ax3.set_xlabel(r'Latent dimension 
$n$') 138 | ax3.set_xticks([1,5,10,30, 50, 70, 100]) 139 | axR3 = fig.add_subplot(3,1,3, sharex=ax3, frameon=False) 140 | axR3.yaxis.tick_right() 141 | axR3.yaxis.set_label_position("right") 142 | axR3.axes.yaxis.set_ticklabels([]) 143 | axR3.set_ylabel(r'$v_2$') 144 | 145 | plt.savefig("SW_errors.png", bbox_inches='tight', dpi=600) 146 | 147 | -------------------------------------------------------------------------------- /Shallow Water /Train_model/train_SW.py: -------------------------------------------------------------------------------- 1 | from jax.flatten_util import ravel_pytree 2 | from jax.example_libraries.stax import Dense, Gelu 3 | from jax.example_libraries import stax, optimizers 4 | import os 5 | import timeit 6 | 7 | import jax 8 | import jax.numpy as jnp 9 | import numpy as np 10 | from jax.numpy.linalg import norm 11 | from jax import random, grad, jit 12 | from functools import partial 13 | from torch.utils import data 14 | from tqdm import trange 15 | import itertools 16 | import argparse 17 | 18 | def output_construction(Ux,t_his, cx, cy, P=1000, ds=3, Nx=32, Ny=32, Nt=100): 19 | U_all = np.zeros((P,ds)) 20 | Y_all = np.zeros((P,ds)) 21 | it = np.random.randint(Nt, size=P) 22 | x = np.random.randint(Nx, size=P) 23 | y = np.random.randint(Ny, size=P) 24 | T, X, Y = np.meshgrid(t_his,cx,cy,indexing="ij") 25 | Y_all[:,:] = np.concatenate((T[it,x][range(P),y][:,None], X[it,x][range(P),y][:,None], Y[it,x][range(P),y][:,None]),axis=-1) 26 | U_all[:,:] = Ux[it,x][range(P),y] 27 | return U_all, Y_all 28 | 29 | class DataGenerator(data.Dataset): 30 | def __init__(self, u, y, s, 31 | batch_size=100, rng_key=random.PRNGKey(1234)): 32 | 'Initialization' 33 | self.u = u 34 | self.y = y 35 | self.s = s 36 | 37 | self.N = u.shape[0] 38 | self.batch_size = batch_size 39 | self.key = rng_key 40 | 41 | def __getitem__(self, index): 42 | 'Generate one batch of data' 43 | self.key, subkey = random.split(self.key) 44 | inputs,outputs = self.__data_generation(subkey) 45 | return inputs, outputs 46 | 47 | @partial(jit, static_argnums=(0,)) 48 | def __data_generation(self, key): 49 | 'Generates data containing batch_size samples' 50 | idx = random.choice(key, self.N, (self.batch_size,), replace=False) 51 | s = self.s[idx,:,:] 52 | u = self.u[idx,:,:] 53 | y = self.y[idx,:,:] 54 | inputs = (u, y) 55 | return inputs, s 56 | 57 | class operator_model: 58 | def __init__(self,branch_layers, trunk_layers , n=None, decoder=None, ds=None): 59 | 60 | seed = np.random.randint(low=0, high=100000) 61 | self.branch_init, self.branch_apply = self.init_NN(branch_layers, activation=Gelu) 62 | self.in_shape = (-1, branch_layers[0]) 63 | self.out_shape, branch_params = self.branch_init(random.PRNGKey(seed), self.in_shape) 64 | 65 | seed = np.random.randint(low=0, high=100000) 66 | self.trunk_init, self.trunk_apply = self.init_NN(trunk_layers, activation=Gelu) 67 | self.in_shape = (-1, trunk_layers[0]) 68 | self.out_shape, trunk_params = self.trunk_init(random.PRNGKey(seed), self.in_shape) 69 | 70 | params = (trunk_params, branch_params) 71 | # Use optimizers to set optimizer initialization and update functions 72 | self.opt_init,self.opt_update,self.get_params = optimizers.adam(optimizers.exponential_decay(1e-3, 73 | decay_steps=100, 74 | decay_rate=0.99)) 75 | self.opt_state = self.opt_init(params) 76 | # Logger 77 | self.itercount = itertools.count() 78 | self.loss_log = [] 79 | 80 | if decoder=="nonlinear": 81 | self.fwd = self.NOMAD 82 | if decoder=="linear": 83 | self.fwd = self.DeepONet 84 | 85 | 
self.n = n 86 | self.ds = ds 87 | 88 | 89 | def init_NN(self, Q, activation=Gelu): 90 | layers = [] 91 | num_layers = len(Q) 92 | if num_layers < 2: 93 | net_init, net_apply = stax.serial() 94 | else: 95 | for i in range(0, num_layers-2): 96 | layers.append(Dense(Q[i+1])) 97 | layers.append(activation) 98 | layers.append(Dense(Q[-1])) 99 | net_init, net_apply = stax.serial(*layers) 100 | return net_init, net_apply 101 | 102 | @partial(jax.jit, static_argnums=0) 103 | def NOMAD(self, params, inputs): 104 | trunk_params, branch_params = params 105 | inputsu, inputsy = inputs 106 | b = self.branch_apply(branch_params, inputsu.reshape(inputsu.shape[0], 1, self.ds*inputsu.shape[1])) 107 | b = jnp.tile(b, (1,inputsy.shape[1],1)) 108 | inputs_recon = jnp.concatenate((jnp.tile(inputsy,(1,1,b.shape[-1]//inputsy.shape[-1])), b), axis=-1) 109 | out = self.trunk_apply(trunk_params, inputs_recon) 110 | return out 111 | 112 | @partial(jax.jit, static_argnums=0) 113 | def DeepONet(self, params, inputs): 114 | trunk_params, branch_params = params 115 | inputsxu, inputsy = inputs 116 | t = self.trunk_apply(trunk_params, inputsy).reshape(inputsy.shape[0], inputsy.shape[1], self.ds, self.n) 117 | b = self.branch_apply(branch_params, inputsxu.reshape(inputsxu.shape[0],1,inputsxu.shape[1]*inputsxu.shape[2])) 118 | b = b.reshape(b.shape[0],int(b.shape[2]/self.ds),self.ds) 119 | Guy = jnp.einsum("ijkl,ilk->ijk", t,b) 120 | return Guy 121 | 122 | @partial(jax.jit, static_argnums=0) 123 | def loss(self, params, batch): 124 | inputs, y = batch 125 | y_pred = self.fwd(params,inputs) 126 | loss = np.mean((y.flatten() - y_pred.flatten())**2) 127 | return loss 128 | 129 | @partial(jax.jit, static_argnums=0) 130 | def L2error(self, params, batch): 131 | inputs, y = batch 132 | y_pred = self.fwd(params,inputs) 133 | return norm(y.flatten() - y_pred.flatten(), 2)/norm(y.flatten(),2) 134 | 135 | @partial(jit, static_argnums=(0,)) 136 | def step(self, i, opt_state, batch): 137 | params = self.get_params(opt_state) 138 | g = grad(self.loss)(params, batch) 139 | return self.opt_update(i, g, opt_state) 140 | 141 | def train(self, train_dataset, test_dataset, nIter = 10000): 142 | train_data = iter(train_dataset) 143 | test_data = iter(test_dataset) 144 | 145 | pbar = trange(nIter) 146 | for it in pbar: 147 | train_batch = next(train_data) 148 | test_batch = next(test_data) 149 | 150 | self.opt_state = self.step(next(self.itercount), self.opt_state, train_batch) 151 | 152 | if it % 100 == 0: 153 | params = self.get_params(self.opt_state) 154 | loss_train = self.loss(params, train_batch) 155 | loss_test = self.loss(params, test_batch) 156 | errorTrain = self.L2error(params, train_batch) 157 | errorTest = self.L2error(params, test_batch) 158 | self.loss_log.append(loss_train) 159 | 160 | pbar.set_postfix({'Training loss': loss_train, 161 | 'Testing loss' : loss_test, 162 | 'Test error': errorTest, 163 | 'Train error': errorTrain}) 164 | 165 | @partial(jit, static_argnums=(0,)) 166 | def predict(self, params, inputs): 167 | s_pred = self.fwd(params,inputs) 168 | return s_pred 169 | 170 | def count_params(self): 171 | params = self.get_params(self.opt_state) 172 | params_flat, _ = ravel_pytree(params) 173 | print("The number of model parameters is:",params_flat.shape[0]) 174 | 175 | def main(n, decoder): 176 | TRAINING_ITERATIONS = 100000 177 | P = 128 178 | m = 1024 179 | num_train = 1000 180 | num_test = 1000 181 | training_batch_size = 100 182 | du = 3 183 | dy = 3 184 | ds = 3 185 | Nx = 32 186 | Ny = 32 187 | Nt = 5 188 | 189 
| d = np.load("../Data/SW/train_SW.npz") 190 | u_train = d["u_train"] 191 | S_train = d["S_train"] 192 | T = d["T"] 193 | CX = d["CX"] 194 | CY = d["CY"] 195 | 196 | d = np.load("../Data/SW/test_SW.npz") 197 | u_test = d["u_test"] 198 | S_test = d["S_test"] 199 | T = d["T"] 200 | CX = d["CX"] 201 | CY = d["CY"] 202 | 203 | 204 | s_train = np.zeros((num_train,P,ds)) 205 | y_train = np.zeros((num_train,P,dy)) 206 | s_test = np.zeros((num_test,P,ds)) 207 | y_test = np.zeros((num_test,P,dy)) 208 | 209 | U_train = u_train.reshape(num_train,Nx*Ny,du) 210 | U_test = u_test.reshape(num_test,Nx*Ny,du) 211 | 212 | 213 | for i in range(0,num_train): 214 | s_train[i ,:,:], y_train[i,:,:] = output_construction(S_train[i,:,:,:,:], T, CX, CY, P=P,Nt=Nt) 215 | 216 | for i in range(0,num_test): 217 | s_test[i,:,:], y_test[i,:,:] = output_construction(S_test[i,:,:,:,:], T, CX, CY, P=P,Nt=Nt) 218 | 219 | U_train = jnp.asarray(U_train) 220 | y_train = jnp.asarray(y_train) 221 | s_train = jnp.asarray(s_train) 222 | 223 | U_test = jnp.asarray(U_test) 224 | y_test = jnp.asarray(y_test) 225 | s_test = jnp.asarray(s_test) 226 | 227 | U_train = jnp.reshape(U_train,(num_train,m,du)) 228 | y_train = jnp.reshape(y_train,(num_train,P,dy)) 229 | s_train = jnp.reshape(s_train,(num_train,P,ds)) 230 | 231 | U_test = jnp.reshape(U_test,(num_test,m,du)) 232 | y_test = jnp.reshape(y_test,(num_test,P,dy)) 233 | s_test = jnp.reshape(s_test,(num_test,P,ds)) 234 | 235 | train_dataset = DataGenerator(U_train, y_train, s_train, training_batch_size) 236 | train_dataset = iter(train_dataset) 237 | 238 | test_dataset = DataGenerator(U_test, y_test, s_test, training_batch_size) 239 | test_dataset = iter(test_dataset) 240 | 241 | if decoder=="nonlinear": 242 | branch_layers = [m*du, 100, 100, 100, 100, 100, ds*n] 243 | trunk_layers = [ds*n*2, 100, 100, 100, 100, 100, ds] 244 | elif decoder=="linear": 245 | branch_layers = [m*du,100, 100, 100, 100, 100, ds*n] 246 | trunk_layers = [dy, 100, 100, 100, 100, 100, ds*n] 247 | 248 | model = operator_model(branch_layers, trunk_layers, n=n, decoder=decoder, ds=ds) 249 | model.count_params() 250 | 251 | start_time = timeit.default_timer() 252 | model.train(train_dataset, test_dataset, nIter=TRAINING_ITERATIONS) 253 | elapsed = timeit.default_timer() - start_time 254 | print("The training wall-clock time is seconds is equal to %f seconds"%elapsed) 255 | 256 | params = model.get_params(model.opt_state) 257 | 258 | 259 | T, X, Y = np.meshgrid(T, CX, CY,indexing="ij") 260 | Y_train = jnp.tile(jnp.concatenate((T.flatten()[:,None], X.flatten()[:,None], Y.flatten()[:,None]),axis=-1)[None,:,:],(num_train, 1, 1)) 261 | Y_test = jnp.tile(jnp.concatenate((T.flatten()[:,None], X.flatten()[:,None], Y.flatten()[:,None]),axis=-1)[None,:,:],(num_test, 1, 1)) 262 | 263 | S_test = S_test.reshape(num_test,Nt*Nx*Ny,ds) 264 | s_pred_test = np.zeros_like(S_test) 265 | 266 | idx = np.arange(0,100) 267 | for i in range(0,num_test,100): 268 | idx = i + np.arange(0,100) 269 | s_pred_test[idx] = model.predict(params, (U_test[idx], Y_test[idx])) 270 | test_error_rho = [] 271 | test_error_u = [] 272 | test_error_v = [] 273 | for i in range(0,num_train): 274 | test_error_rho.append(norm(S_test[i,:,0]- s_pred_test[i,:,0],2)/norm(S_test[i,:,0],2)) 275 | test_error_u.append(norm(S_test[i,:,1]- s_pred_test[i,:,1],2)/norm(S_test[i,:,1],2)) 276 | test_error_v.append(norm(S_test[i,:,2]- s_pred_test[i,:,2],2)/norm(S_test[i,:,2],2)) 277 | print("The average test rho error is %e the standard deviation is %e the min error is %e and the 
max error is %e"%(np.mean(test_error_rho),np.std(test_error_rho),np.min(test_error_rho),np.max(test_error_rho))) 278 | print("The average test u error is %e the standard deviation is %e the min error is %e and the max error is %e"%(np.mean(test_error_u),np.std(test_error_u),np.min(test_error_u),np.max(test_error_u))) 279 | print("The average test v error is %e the standard deviation is %e the min error is %e and the max error is %e"%(np.mean(test_error_v),np.std(test_error_v),np.min(test_error_v),np.max(test_error_v))) 280 | 281 | S_train = S_train.reshape(num_train,Nt*Nx*Ny,ds) 282 | s_pred_train = np.zeros_like(S_train) 283 | for i in range(0,num_train,100): 284 | idx = i + np.arange(0,100) 285 | s_pred_train[idx] = model.predict(params, (U_train[idx], Y_train[idx])) 286 | train_error_rho = [] 287 | train_error_u = [] 288 | train_error_v = [] 289 | for i in range(0,num_train): 290 | train_error_rho.append(norm(S_train[i,:,0]- s_pred_train[i,:,0],2)/norm(S_train[i,:,0],2)) 291 | train_error_u.append(norm(S_train[i,:,1]- s_pred_train[i,:,1],2)/norm(S_train[i,:,1],2)) 292 | train_error_v.append(norm(S_train[i,:,2]- s_pred_train[i,:,2],2)/norm(S_train[i,:,2],2)) 293 | print("The average train rho error is %e the standard deviation is %e the min error is %e and the max error is %e"%(np.mean(train_error_rho),np.std(train_error_rho),np.min(train_error_rho),np.max(train_error_rho))) 294 | print("The average train u error is %e the standard deviation is %e the min error is %e and the max error is %e"%(np.mean(train_error_u),np.std(train_error_u),np.min(train_error_u),np.max(train_error_u))) 295 | print("The average train v error is %e the standard deviation is %e the min error is %e and the max error is %e"%(np.mean(train_error_v),np.std(train_error_v),np.min(train_error_v),np.max(train_error_v))) 296 | 297 | 298 | 299 | if __name__ == "__main__": 300 | parser = argparse.ArgumentParser(description='Process model parameters.') 301 | parser.add_argument('n', metavar='n', type=int, nargs='+', help='Latent dimension of the solution manifold') 302 | parser.add_argument('decoder', metavar='decoder', type=str, nargs='+', help='Type of decoder. Choices a)"linear" b)"nonlinear"') 303 | 304 | args = parser.parse_args() 305 | n = args.n[0] 306 | decoder = args.decoder[0] 307 | main(n,decoder) 308 | --------------------------------------------------------------------------------