├── .github └── workflows │ └── python-package-conda.yml ├── 1D_linear_Poisson ├── 1D_linear_poisson_HMC.py └── 1D_linear_poisson_SVGD.py ├── 1D_nonlinear_Poisson ├── 1D_nonlinear_poisson_HMC.py ├── 1D_nonlinear_poisson_SVGD_norm.py └── 1D_nonlinear_poisson_rPINN.py ├── 2D_GWF_problem ├── 2D_gwf_BPINN_inverse_HMC.py ├── 2D_gwf_inverse_rPINN.py ├── Nk_40_Nh_40_randseed_111_idxh.out ├── Nk_40_Nh_40_randseed_111_idxk.out ├── coord_ref.out ├── h_ref_05.out └── k_ref_05.out ├── LICENSE ├── README.md └── rPINN_schematization.PNG /.github/workflows/python-package-conda.yml: -------------------------------------------------------------------------------- 1 | name: Python Package using Conda 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build-linux: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | max-parallel: 5 10 | 11 | steps: 12 | - uses: actions/checkout@v3 13 | - name: Set up Python 3.10 14 | uses: actions/setup-python@v3 15 | with: 16 | python-version: '3.10' 17 | - name: Add conda to system path 18 | run: | 19 | # $CONDA is an environment variable pointing to the root of the miniconda directory 20 | echo $CONDA/bin >> $GITHUB_PATH 21 | - name: Install dependencies 22 | run: | 23 | conda env update --file environment.yml --name base 24 | - name: Lint with flake8 25 | run: | 26 | conda install flake8 27 | # stop the build if there are Python syntax errors or undefined names 28 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 29 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 30 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 31 | - name: Test with pytest 32 | run: | 33 | conda install pytest 34 | pytest 35 | -------------------------------------------------------------------------------- /1D_linear_Poisson/1D_linear_poisson_HMC.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Mon Jun 26 11:31:38 2023 5 | 6 | @author: yifeizong 7 | """ 8 | 9 | import jax 10 | import os 11 | import jax.numpy as jnp 12 | from jax import random, grad, vmap, jit 13 | from jax.flatten_util import ravel_pytree 14 | from jax.example_libraries import optimizers 15 | 16 | import numpy as np 17 | import matplotlib.pyplot as plt 18 | import matplotlib as mpl 19 | import scipy.linalg as spl 20 | import seaborn as sns 21 | import pandas as pd 22 | from tensorflow_probability.substrates import jax as tfp 23 | tfd = tfp.distributions 24 | 25 | import itertools 26 | import argparse 27 | from functools import partial 28 | from tqdm import trange 29 | 30 | #command line argument parser 31 | parser = argparse.ArgumentParser(description="1D Linear Poisson with HMC") 32 | parser.add_argument( 33 | "--rand_seed", 34 | type=int, 35 | default=42, 36 | help="random seed") 37 | parser.add_argument( 38 | "--sigma", 39 | type=float, 40 | default=0.01, 41 | help="Data uncertainty") 42 | parser.add_argument( 43 | "--sigma_r", 44 | type=float, 45 | default=0.01, 46 | help="Aleatoric uncertainty of the residual") 47 | parser.add_argument( 48 | "--sigma_d", 49 | type=float, 50 | default=0.004, 51 | help="Aleatoric uncertainty of the data") 52 | parser.add_argument( 53 | "--sigma_p", 54 | type=float, 55 | default=0.1452, 56 | help="Prior std") 57 | parser.add_argument( 58 | "--Nres", 59 | type=int, 60 | default=128, 61 | help="Number of residual points") 62 | parser.add_argument( 63 | "--Nsamples", 64 | type=int, 65 | default=5000, 66 | help="Number of 
Posterior samples") 67 | parser.add_argument( 68 | "--Nburn", 69 | type=int, 70 | default=50000, 71 | help="Number of burn-in samples") 72 | parser.add_argument( 73 | "--data_load", 74 | type=bool, 75 | default=False, 76 | help="Whether to load data") 77 | args = parser.parse_args() 78 | 79 | #Define parameters 80 | layers_u = [1, 50, 50, 1] 81 | lbt = np.array([-1.]) 82 | ubt = np.array([1.]) 83 | k = -1/np.pi**2 84 | dataset = dict() 85 | rand_seed = args.rand_seed 86 | Nres = args.Nres 87 | sigma = args.sigma 88 | sigma_r = args.sigma_r 89 | sigma_b = args.sigma_d 90 | sigma_p = args.sigma_p 91 | Nsamples = args.Nsamples 92 | Nburn = args.Nburn 93 | Nchains = 9 94 | path_f = f'1D_linear_poisson_HMC_Nres_{Nres}_sigma_{sigma}_Nsamples_{Nsamples}_Nburn_{Nburn}_{Nchains}_chains' 95 | path_fig = os.path.join(path_f,'figures') 96 | if not os.path.exists(path_f): 97 | os.makedirs(path_f) 98 | if not os.path.exists(path_fig): 99 | os.makedirs(path_fig) 100 | f_rec = open(os.path.join(path_f,'record.out'), 'a+') 101 | 102 | def u(x): 103 | return jnp.sin(jnp.pi*x) 104 | 105 | def f(x): 106 | return jnp.sin(jnp.pi*x) 107 | 108 | rl2e = lambda yest, yref : spl.norm(yest - yref, 2) / spl.norm(yref, 2) 109 | infe = lambda yest, yref : spl.norm(yest - yref, np.inf) 110 | lpp = lambda h, href, sigma: np.sum( -(h - href)**2/(2*sigma**2) - 1/2*np.log( 2*np.pi) - 2*np.log(sigma)) 111 | 112 | #create noisy boundary data 113 | np.random.seed(rand_seed) 114 | x_data = np.array([lbt[0], ubt[0]])[:,np.newaxis] 115 | y_data = np.array([u(lbt[0]), u(ubt[0])])[:,np.newaxis].astype(np.float32) + np.random.normal(0,sigma,(2,1)).astype(np.float32) 116 | data = jnp.concatenate([x_data,y_data], axis=1) 117 | dataset.update({'data': data}) 118 | 119 | #create noisy forcing sampling 120 | X_r = np.linspace(lbt[0], ubt[0], Nres) 121 | X_r = jnp.sort(X_r, axis = 0)[:,np.newaxis] 122 | y_r = f(X_r) + np.random.normal(0,sigma,(Nres,1)) 123 | Dres = jnp.asarray(jnp.concatenate([X_r,y_r], axis=1)) 124 | dataset.update({'res': Dres}) 125 | 126 | # Define FNN 127 | def FNN(layers, activation=jnp.tanh): 128 | 129 | def init(prng_key): #return a list of (W,b) tuples 130 | def init_layer(key, d_in, d_out): 131 | key1, key2 = random.split(key) 132 | glorot_stddev = 1.0 / jnp.sqrt((d_in + d_out) / 2.)
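# Glorot (Xavier) normal scaling: 1.0/sqrt((d_in + d_out)/2) equals sqrt(2/(d_in + d_out)),
# which keeps the variance of the tanh pre-activations roughly constant across layers; biases start at zero.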
133 | W = glorot_stddev * random.normal(key1, (d_in, d_out)) 134 | b = jnp.zeros(d_out) 135 | return W, b 136 | key, *keys = random.split(prng_key, len(layers)) 137 | params = list(map(init_layer, keys, layers[:-1], layers[1:])) 138 | return params 139 | 140 | def forward(params, inputs): 141 | Z = inputs 142 | for W, b in params[:-1]: 143 | outputs = jnp.dot(Z, W) + b 144 | Z = activation(outputs) 145 | W, b = params[-1] 146 | outputs = jnp.dot(Z, W) + b 147 | return outputs 148 | 149 | return init, forward 150 | 151 | # Define the model 152 | class PINN(): 153 | def __init__(self, key, layers, dataset, lbt, ubt, k, sigma_r, sigma_b, sigma_p): 154 | 155 | self.lbt = lbt #domain lower corner 156 | self.ubt = ubt #domain upper corner 157 | self.k = k 158 | self.scale_coe = 0.5 159 | self.scale = 2 * self.scale_coe / (self.ubt-self.lbt) 160 | self.sigma_r = sigma_r 161 | self.sigma_b = sigma_b 162 | self.sigma_p = sigma_p 163 | 164 | # Prepare normalized training data 165 | self.dataset = dataset 166 | self.X_res, self.y_res = self.normalize(dataset['res'][:,0:1]), dataset['res'][:,1:2] 167 | self.X_data, self.y_data = self.normalize(dataset['data'][:,0:1]), dataset['data'][:,1:2] 168 | 169 | # Initalize the network 170 | self.init, self.forward = FNN(layers, activation=jnp.tanh) 171 | self.params = self.init(key) 172 | _, self.unravel = ravel_pytree(self.params) 173 | self.num_params = ravel_pytree(self.params)[0].shape[0] 174 | 175 | # Evaluate the network and the residual over the grid 176 | self.u_pred_map = vmap(self.predict_u, (None, 0)) 177 | self.f_pred_map = vmap(self.predict_res, (None, 0)) 178 | 179 | self.itercount = itertools.count() 180 | self.loss_log = [] 181 | self.loss_likelihood_log = [] 182 | self.loss_dbc_log = [] 183 | self.loss_res_log = [] 184 | 185 | lr = optimizers.exponential_decay(1e-3, decay_steps=5000, decay_rate=0.9) 186 | self.opt_init, \ 187 | self.opt_update, \ 188 | self.get_params = optimizers.adam(lr) 189 | self.opt_state = self.opt_init(self.params) 190 | 191 | 192 | def normalize(self, X): 193 | if X.shape[1] == 1: 194 | return 2.0 * self.scale_coe * (X - self.lbt[0:1])/(self.ubt[0:1] - self.lbt[0:1]) - self.scale_coe 195 | if X.shape[1] == 2: 196 | return 2.0 * self.scale_coe * (X - self.lbt[0:2])/(self.ubt[0:2] - self.lbt[0:2]) - self.scale_coe 197 | if X.shape[1] == 3: 198 | return 2.0 * self.scale_coe * (X - self.lbt)/(self.ubt - self.lbt) - self.scale_coe 199 | 200 | @partial(jit, static_argnums=(0,)) 201 | def u_net(self, params, x): 202 | inputs = jnp.hstack([x]) 203 | outputs = self.forward(params, inputs) 204 | return outputs[0] 205 | 206 | @partial(jit, static_argnums=(0,)) 207 | def res_net(self, params, x): 208 | u_xx = grad(grad(self.u_net, argnums=1), argnums=1)(params, x)*self.scale[0]**2 209 | return self.k*u_xx 210 | 211 | @partial(jit, static_argnums=(0,)) 212 | def predict_u(self, params, x): 213 | # Normalize input first, and then predict 214 | x = 2.0 * self.scale_coe * (x - self.lbt[0])/(self.ubt[0] - self.lbt[0]) - self.scale_coe 215 | return self.u_net(params, x) 216 | 217 | @partial(jit, static_argnums=(0,)) 218 | def predict_res(self, params, x): 219 | # Normalize input first, and then predict 220 | x = 2.0 * self.scale_coe * (x - self.lbt[0])/(self.ubt[0] - self.lbt[0]) - self.scale_coe 221 | return self.res_net(params, x) 222 | 223 | @partial(jit, static_argnums=(0,)) 224 | def u_pred_vector(self, params): 225 | # For HMC 226 | u_pred_vec = vmap(self.u_net, (None, 0))(self.unravel(params), self.X_data[:,0]) 227 | return 
u_pred_vec 228 | 229 | @partial(jit, static_argnums=(0,)) 230 | def f_pred_vector(self, params): 231 | # For HMC 232 | f_pred_vec = vmap(self.res_net, (None, 0))(self.unravel(params), self.X_res[:,0]) 233 | return f_pred_vec 234 | 235 | @partial(jit, static_argnums=(0,)) 236 | def loss_dbc(self, params, beta): 237 | u_pred = vmap(self.u_net, (None, 0))(params, self.X_data[:,0]) 238 | loss_bc = jnp.sum((u_pred.flatten() - self.y_data.flatten()- beta)**2) 239 | return loss_bc 240 | 241 | @partial(jit, static_argnums=(0,)) 242 | def loss_res(self, params, alpha): 243 | f_pred = vmap(self.res_net, (None, 0))(params, self.X_res[:,0]) 244 | loss_res = jnp.sum((f_pred.flatten() - self.y_res.flatten() - alpha)**2) 245 | return loss_res 246 | 247 | @partial(jit, static_argnums=(0,)) 248 | def l2_regularizer(self, params, omega): 249 | return jnp.sum((ravel_pytree(params)[0] - omega)**2) 250 | 251 | @partial(jit, static_argnums=(0,)) 252 | def loss(self, params, alpha, beta, omega): 253 | return 1/self.sigma_r**2*self.loss_res(params, alpha) + 1/self.sigma_b**2*self.loss_dbc(params, beta) + \ 254 | 1/self.sigma_p**2*self.l2_regularizer(params, omega) 255 | 256 | @partial(jit, static_argnums=(0,)) 257 | def step(self, i, opt_state, alpha, beta, omega): 258 | params = self.get_params(opt_state) 259 | g = grad(self.loss, argnums=0)(params, alpha, beta, omega) 260 | 261 | return self.opt_update(i, g, opt_state) 262 | 263 | def train(self, nIter, num_print, alpha, beta, omega): 264 | pbar = trange(nIter) 265 | # Main training loop 266 | for it in pbar: 267 | self.current_count = next(self.itercount) 268 | self.opt_state = self.step(self.current_count, self.opt_state, alpha, beta, omega) 269 | 270 | if it % num_print == 0: 271 | params = self.get_params(self.opt_state) 272 | 273 | loss_value = self.loss(params, alpha, beta, omega) 274 | loss_res_value = self.loss_res(params, alpha) 275 | loss_dbc_value = self.loss_dbc(params, beta) 276 | loss_reg_value = self.l2_regularizer(params, omega) 277 | 278 | 279 | pbar.set_postfix({'Loss': loss_value, 280 | 'Loss_res': loss_res_value, 281 | 'Loss_dbc': loss_dbc_value, 282 | 'Loss_reg': loss_reg_value}) 283 | self.loss_log.append(loss_value) 284 | self.loss_likelihood_log.append(loss_res_value + loss_dbc_value) 285 | self.loss_res_log.append(loss_res_value) 286 | self.loss_dbc_log.append(loss_dbc_value) 287 | 288 | 289 | key1, key2 = random.split(random.PRNGKey(0), 2) 290 | pinn = PINN(key1, layers_u, dataset, lbt, ubt, k, sigma_r, sigma_b, sigma_p) 291 | num_params = ravel_pytree(pinn.params)[0].shape[0] 292 | 293 | # def target_log_prob_fn(theta): 294 | # prior = prior_dist.log_prob(theta) 295 | # r_likelihood = jnp.sum( -jnp.log(sigma_r) - jnp.log(2*jnp.pi)/2 -(y_r.ravel() - pinn.f_pred_vector(theta))**2/(2*sigma_r**2)) 296 | # u_likelihood = jnp.sum( -jnp.log(sigma_b) - jnp.log(2*jnp.pi)/2 -(y_data.ravel() - pinn.u_pred_vector(theta))**2/(2*sigma_b**2)) 297 | # return prior + r_likelihood + u_likelihood 298 | 299 | def target_log_prob_fn(theta): #same 300 | prior = jnp.sum(-(theta)**2/(2*sigma_p**2)) 301 | r_likelihood = jnp.sum(-(y_r.ravel() - pinn.f_pred_vector(theta))**2/(2*sigma_r**2)) 302 | u_likelihood = jnp.sum(-(y_data.ravel() - pinn.u_pred_vector(theta))**2/(2*sigma_b**2)) 303 | return prior + r_likelihood + u_likelihood 304 | 305 | # key3, key4, key5 = random.split(key1, 3) 306 | # init_state = jnp.zeros((1, num_params)) 307 | # init_state = jnp.concatenate([init_state,random.normal(key4 ,(1, num_params))], axis=0) 308 | # init_state = 
jnp.concatenate([init_state,3 + random.normal(key5 ,(1, num_params))], axis=0) 309 | 310 | new_key, *subkeys = random.split(key1, Nchains + 1) 311 | init_state = jnp.zeros((1, num_params)) 312 | for key in subkeys[:-1]: 313 | init_state = jnp.concatenate([init_state,random.normal(key ,(1, num_params))], axis=0) 314 | 315 | nuts_kernel = tfp.mcmc.NoUTurnSampler( 316 | target_log_prob_fn = target_log_prob_fn, step_size = 0.0005, max_tree_depth=10, max_energy_diff=1000.0, 317 | unrolled_leapfrog_steps=1, parallel_iterations=30) 318 | 319 | kernel = tfp.mcmc.DualAveragingStepSizeAdaptation( 320 | inner_kernel=nuts_kernel, num_adaptation_steps=int(Nburn * 0.75)) 321 | 322 | def run_chain(init_state, key): 323 | samples, trace = tfp.mcmc.sample_chain( 324 | num_results= Nsamples, 325 | num_burnin_steps= Nburn, 326 | current_state= init_state, 327 | kernel= kernel, 328 | seed=key, 329 | trace_fn= lambda _,pkr: [pkr.inner_results.log_accept_ratio, 330 | pkr.inner_results.target_log_prob, 331 | pkr.inner_results.step_size] 332 | ) 333 | return samples, trace 334 | 335 | 336 | print('\nStart HMC Sampling') 337 | #states, trace = jit(run_chain)(init_state, key3) 338 | #states, trace = jax.pmap(run_chain, in_axes=(0, None))(init_state, key3) 339 | states, trace = jit(vmap(run_chain, in_axes=(0, None)))(init_state, new_key) 340 | print('\nFinish HMC Sampling') 341 | np.save(os.path.join(path_f,'chains'), states) 342 | 343 | # ============================================================================= 344 | # Post-processing HMC results 345 | # ============================================================================= 346 | 347 | accept_ratio = np.exp(trace[0]) 348 | target_log_prob = trace[1] 349 | step_size = trace[2] 350 | print(f'Average accept ratio for each chain: {np.mean(accept_ratio, axis = 1)}') 351 | print(f'Average step size for each chain: {np.mean(step_size, axis = 1)}') 352 | print(f'Average accept ratio for each chain: {np.mean(accept_ratio, axis = 1)}', file = f_rec) 353 | print(f'Average step size for each chain: {np.mean(step_size, axis = 1)}', file = f_rec) 354 | 355 | mark = [None, 'o', None] 356 | linestyle = ['solid', 'dotted', 'dashed'] 357 | fig, ax = plt.subplots(dpi = 300, figsize = (4,4)) 358 | for i, mark in enumerate(mark): 359 | ax.plot(np.arange(Nsamples)[::10], target_log_prob[i,::10], marker = mark, markersize = 2, markevery= 100, markerfacecolor='None', linestyle = 'dashed', label = f'chain {i + 1}', alpha = 0.8) 360 | ax.set_xlabel('Sample index', fontsize = 15) 361 | ax.set_ylabel('Negative log prob', fontsize = 15) 362 | ax.tick_params(axis='both', which = 'major', labelsize=12) 363 | ax.set_xlim(0,Nsamples) 364 | ax.legend(fontsize=6) 365 | plt.savefig(os.path.join(path_fig,'target_log_prob.png')) 366 | plt.show() 367 | 368 | Npred = 201 369 | x_pred_index = jnp.linspace(-1,1,Npred) 370 | f_ref = f(x_pred_index) 371 | u_ref = u(x_pred_index) 372 | samples = states 373 | 374 | @jit 375 | def get_u_pred(sample): 376 | return pinn.u_pred_map(pinn.unravel(sample),x_pred_index) 377 | 378 | @jit 379 | def get_f_pred(sample): 380 | return pinn.f_pred_map(pinn.unravel(sample),x_pred_index) 381 | 382 | u_pred_ens = np.array([vmap(get_u_pred)(samples[i,:,:]) for i in range(samples.shape[0])]) # (3, 10000, 201) 383 | f_pred_ens = np.array([vmap(get_f_pred)(samples[i,:,:]) for i in range(samples.shape[0])]) 384 | 385 | u_pred_ens_mean = np.mean(u_pred_ens, axis = 1) #(3, 201) 386 | u_pred_ens_std = np.std(u_pred_ens, axis = 1) 387 | f_pred_ens_mean = np.mean(f_pred_ens, 
axis = 1) 388 | f_pred_ens_std = np.std(f_pred_ens, axis = 1) 389 | 390 | u_env = np.logical_and( (u_pred_ens_mean < u_ref + 2*u_pred_ens_std), (u_pred_ens_mean > u_ref - 2*u_pred_ens_std) ) 391 | f_env = np.logical_and( (f_pred_ens_mean < f_ref + 2*f_pred_ens_std), (f_pred_ens_mean > f_ref - 2*f_pred_ens_std) ) 392 | 393 | # ============================================================================= 394 | # Posterior Statistics 395 | # ============================================================================= 396 | 397 | for i in range(Nchains): 398 | rl2e_u = rl2e(u_pred_ens_mean[i, :], u_ref) 399 | infe_u = infe(u_pred_ens_mean[i, :], u_ref) 400 | lpp_u = lpp(u_pred_ens_mean[i, :], u_ref, u_pred_ens_std[i, :]) 401 | rl2e_f = rl2e(f_pred_ens_mean[i, :], f_ref) 402 | infe_f = infe(f_pred_ens_mean[i, :], f_ref) 403 | lpp_f = lpp(f_pred_ens_mean[i, :], f_ref, f_pred_ens_std[i, :]) 404 | 405 | print(f'chain {i}:\n') 406 | print('u prediction:\n') 407 | print('Relative RL2 error: {}'.format(rl2e_u)) 408 | print('Absolute inf error: {}'.format(infe_u)) 409 | print('Average standard deviation: {}'.format(np.mean(u_pred_ens_std[i, :]))) 410 | print('log predictive probability: {}'.format(lpp_u)) 411 | print('Percentage of coverage:{}\n'.format(np.sum(u_env[i, :])/Npred)) 412 | 413 | print('f prediction:\n') 414 | print('Relative RL2 error: {}'.format(rl2e_f)) 415 | print('Absolute inf error: {}'.format(infe_f)) 416 | print('Average standard deviation: {}'.format(np.mean(f_pred_ens_std[i, :]))) 417 | print('log predictive probability: {}'.format(lpp_f)) 418 | print('Percentage of coverage:{}\n'.format(np.sum(f_env[i, :])/Npred)) 419 | 420 | print(f'chain {i}:\n', file = f_rec) 421 | print('u prediction:\n', file = f_rec) 422 | print('Relative RL2 error: {}'.format(rl2e_u), file = f_rec) 423 | print('Absolute inf error: {}'.format(infe_u), file = f_rec) 424 | print('Average standard deviation: {}'.format(np.mean(u_pred_ens_std[i, :])), file = f_rec) 425 | print('log predictive probability: {}'.format(lpp_u), file = f_rec) 426 | print('Percentage of coverage:{}\n'.format(np.sum(u_env[i, :])/Npred), file = f_rec) 427 | 428 | print('f prediction:\n', file = f_rec) 429 | print('Relative RL2 error: {}'.format(rl2e_f), file = f_rec) 430 | print('Absolute inf error: {}'.format(infe_f), file = f_rec) 431 | print('Average standard deviation: {}'.format(np.mean(f_pred_ens_std[i, :])), file = f_rec) 432 | print('log predictive probability: {}'.format(lpp_f), file = f_rec) 433 | print('Percentage of coverage:{}\n'.format(np.sum(f_env[i, :])/Npred), file = f_rec) 434 | 435 | 436 | rhat = tfp.mcmc.diagnostic.potential_scale_reduction(states.transpose((1,0,2)), independent_chain_ndims=1) 437 | ess = tfp.mcmc.effective_sample_size(states[0], filter_beyond_positive_pairs=True) 438 | 439 | fig, ax = plt.subplots(1, 1, figsize=(4, 4), sharex='col', sharey='col', dpi = 300) 440 | g = sns.histplot(rhat, bins = 50, kde=True, kde_kws = {'gridsize':5000}) 441 | g.tick_params(labelsize=16) 442 | g.set_xlabel("$\hat{r}$", fontsize=18) 443 | g.set_ylabel("Count", fontsize=18) 444 | fig.tight_layout() 445 | plt.savefig(os.path.join(path_fig,'rhat.png')) 446 | plt.show() 447 | 448 | fig, ax = plt.subplots(1, 1, figsize=(4, 4), sharex='col', sharey='col', dpi = 300) 449 | g = sns.histplot(ess, bins = 50, kde=True, kde_kws = {'gridsize':5000}) 450 | g.tick_params(labelsize=16) 451 | g.set_xlabel("ESS", fontsize=18) 452 | g.set_ylabel("Count", fontsize=18) 453 | fig.tight_layout() 454 | 
plt.savefig(os.path.join(path_fig,'ess.png')) 455 | plt.show() 456 | 457 | idx_low = np.argmin(rhat) 458 | idx_high = np.argmax(rhat) 459 | samples1 = states[:,:,idx_low] 460 | df1 = pd.DataFrame({'chains': 'chain1', 'indice':np.arange(0, samples1[0,].shape[0], 5), 'trace':samples1[0, ::5]}) 461 | df2 = pd.DataFrame({'chains': 'chain2', 'indice':np.arange(0, samples1[1,].shape[0], 5), 'trace':samples1[1, ::5]}) 462 | df3 = pd.DataFrame({'chains': 'chain3', 'indice':np.arange(0, samples1[2,].shape[0], 5), 'trace':samples1[2, ::5]}) 463 | df = pd.concat([df1, df2, df3], ignore_index=True) 464 | plt.figure(figsize=(4,4)) 465 | g = sns.jointplot(data=df, x='indice', y='trace', xlim=(0, 5000), ylim=(-4, 4), hue='chains', joint_kws={'alpha': 0.6}) 466 | g.ax_joint.tick_params(labelsize=18) 467 | g.ax_joint.set_xlabel("Index", fontsize=24) 468 | g.ax_joint.set_ylabel("Trace", fontsize=24) 469 | g.ax_joint.legend(fontsize=16) 470 | g.ax_marg_x.remove() 471 | #plt.title('Trace plot for parameter with lowest $\hat{r}$') 472 | plt.gcf().set_dpi(300) 473 | plt.savefig(os.path.join(path_fig,'trace_plot_rhat_lowest.png')) 474 | fig.tight_layout() 475 | plt.show() 476 | 477 | samples2 = states[:,:,idx_high] 478 | df1 = pd.DataFrame({'chains': 'chain1', 'indice':np.arange(0, samples2[0, ::].shape[0], 5), 'trace':samples2[0, ::5]}) 479 | df2 = pd.DataFrame({'chains': 'chain2', 'indice':np.arange(0, samples2[1, ::].shape[0], 5), 'trace':samples2[1, ::5]}) 480 | df3 = pd.DataFrame({'chains': 'chain3', 'indice':np.arange(0, samples2[2, ::].shape[0], 5), 'trace':samples2[2, ::5]}) 481 | df = pd.concat([df1,df2, df3], ignore_index=True) 482 | plt.figure(figsize=(4,4)) 483 | #g = sns.jointplot(data=df, x='indice', y='trace', xlim=(0, 5000), ylim=(-4, 4), hue='chains', joint_kws={'alpha': 1}) 484 | g = sns.jointplot(data=df, x='indice', y='trace', xlim=(0, 5000), ylim=(-4, 4), hue='chains', joint_kws={'alpha': 0.6}) 485 | g.ax_joint.tick_params(labelsize=18) 486 | g.ax_joint.set_xlabel("Index", fontsize=24) 487 | g.ax_joint.set_ylabel("Trace", fontsize=24) 488 | g.ax_joint.legend(fontsize=16) 489 | g.ax_marg_x.remove() 490 | #plt.title('Trace plot for parameter with highest $\hat{r}$') 491 | plt.gcf().set_dpi(300) 492 | plt.savefig(os.path.join(path_fig,'trace_plot_rhat_highest.png')) 493 | fig.tight_layout() 494 | plt.show() 495 | 496 | 497 | # ============================================================================= 498 | # Compute Hessian at the mean 499 | # ============================================================================= 500 | 501 | chain0 = states[0] 502 | chain1 = states[1] 503 | chain2 = states[2] 504 | chain0_m = np.mean(chain0, axis = 0) 505 | chain1_m = np.mean(chain1, axis = 0) 506 | chain2_m = np.mean(chain2, axis = 0) 507 | hess = jax.hessian(target_log_prob_fn) 508 | hess_chain0 = hess(chain0_m) 509 | _, s0, _ = jax.scipy.linalg.svd(jax.scipy.linalg.inv(hess_chain0)) 510 | hess_chain1 = hess(chain1_m) 511 | _, s1, _ = jax.scipy.linalg.svd(jax.scipy.linalg.inv(hess_chain1)) 512 | hess_chain2 = hess(chain2_m) 513 | _, s2, _ = jax.scipy.linalg.svd(jax.scipy.linalg.inv(hess_chain2)) 514 | 515 | s = np.concatenate((s0[np.newaxis, :], s1[np.newaxis, :], s2[np.newaxis, :]), axis = 0) 516 | np.savetxt(os.path.join(path_f,'singular_values_posterior_hessian.out'), s) 517 | 518 | fig, ax = plt.subplots(dpi = 300, figsize = (4,4)) 519 | #mark = [None, 'o', None] 520 | linestyle = ['solid', 'dotted', 'dashed'] 521 | for i, ls in enumerate(linestyle): 522 | ax.plot(s[i], linestyle = ls, marker = 
None, markersize = 2, markevery= 100, markerfacecolor='None', label=f'chain{i+1}', alpha = 0.8) 523 | ax.set_xlabel('Index', fontsize=16) 524 | ax.set_ylabel('Eigenvalues', fontsize=16) 525 | plt.yscale('log') 526 | ax.tick_params(axis='both', which = 'major', labelsize=13) 527 | ax.legend(fontsize=8) 528 | plt.savefig(os.path.join(path_fig,'singular_values_posterior_hessian.png')) 529 | plt.show() 530 | 531 | fig, ax = plt.subplots(1, 1, figsize=(4, 4), sharex='col', sharey='col', dpi = 300) 532 | g = sns.histplot(chain0_m, bins = 50, kde=True, kde_kws = {'gridsize':5000}) 533 | g = sns.histplot(chain1_m, bins = 50, kde=True, kde_kws = {'gridsize':5000}) 534 | g = sns.histplot(chain2_m, bins = 50, kde=True, kde_kws = {'gridsize':5000}) 535 | g.tick_params(labelsize=16) 536 | g.set_xlabel("Weight", fontsize=18) 537 | g.set_ylabel("Count", fontsize=18) 538 | fig.tight_layout() 539 | plt.savefig(os.path.join(path_fig,'weight.png')) 540 | plt.show() 541 | 542 | # ============================================================================= 543 | # Plot posterior space 544 | # ============================================================================= 545 | 546 | 547 | # class RandomCoordinates(object): 548 | # # randomly choose some directions 549 | # def __init__(self, origin): 550 | # self.origin = origin # (num_params,) 551 | # self.v0 = self.normalize( 552 | # random.normal(key = random.PRNGKey(88), shape = self.origin.shape), 553 | # self.origin) 554 | # self.v1 = self.normalize( 555 | # random.normal(key = random.PRNGKey(66), shape = self.origin.shape), 556 | # self.origin) 557 | 558 | # def __call__(self, a, b): 559 | # return a*self.v0 + b * self.v1 + self.origin 560 | 561 | # def normalize(self, weights, origin): 562 | # return weights * jnp.abs(origin)/ jnp.abs(weights) # 563 | 564 | 565 | # class LossSurface(object): 566 | # def __init__(self, loss_fn, coords): 567 | # self.loss_fn = loss_fn 568 | # self.coords = coords 569 | 570 | # def compile(self, range, num_points): 571 | # loss_fn_0d = lambda x, y: self.loss_fn(self.coords(x,y)) 572 | # loss_fn_1d = jax.vmap(loss_fn_0d, in_axes = (0,0), out_axes = 0) 573 | # loss_fn_2d = jax.vmap(loss_fn_1d, in_axes = (0,0), out_axes = 0) 574 | 575 | # self.a_grid = jnp.linspace(-1.0, 1.0, num=num_points) ** 3 * range #(-5, 5) power rate 576 | # self.b_grid = jnp.linspace(-1.0, 1.0, num=num_points) ** 3 * range 577 | # self.aa, self.bb = jnp.meshgrid(self.a_grid, self.b_grid) 578 | # self.loss_grid = loss_fn_2d(self.aa, self.bb) 579 | 580 | # def project_points(self, points): 581 | # x = jax.vmap(lambda x: jnp.dot(x, self.coords.v0)/jnp.linalg.norm(self.coords.v0), 0, 0)(points) 582 | # y = jax.vmap(lambda y: jnp.dot(y, self.coords.v1)/jnp.linalg.norm(self.coords.v1), 0, 0)(points) 583 | # return x, y 584 | 585 | # def plot(self, levels=30, points = None, ax=None, **kwargs): 586 | # xs = self.a_grid 587 | # ys = self.b_grid 588 | # zs = self.loss_grid 589 | # if ax is None: 590 | # fig, ax = plt.subplots(dpi = 300, **kwargs) 591 | # ax.set_title("Loss Surface") 592 | # ax.set_aspect("equal") 593 | 594 | # # Set Levels 595 | # min_loss = zs.min() 596 | # max_loss = zs.max() 597 | # levels = jnp.linspace( 598 | # max_loss, min_loss, num=levels 599 | # )[::-1] 600 | 601 | # # levels = jnp.exp( 602 | # # jnp.log(min_loss) + 603 | # # jnp.linspace(0., 1.0, num=levels) ** 3 * (jnp.log(max_loss))- jnp.log(min_loss)) 604 | 605 | # # Create Contour Plot 606 | # CS = ax.contourf( 607 | # xs, 608 | # ys, 609 | # zs, 610 | # levels=levels, 611 | # 
cmap= 'magma', 612 | # linewidths=0.75, 613 | # norm = mpl.colors.Normalize(vmin = min_loss, vmax = max_loss), 614 | # ) 615 | # for i in points: 616 | # #origin_x, origin_y = self.project_points(self.coords.origin) 617 | # point_x, point_y = self.project_points(i) 618 | # ax.scatter(point_x, point_y, s = 20) 619 | # #ax.scatter(origin_x, origin_y, s = 1, c = 'r', marker = 'x') 620 | # ax.clabel(CS, fontsize=8, fmt="%1.2f") 621 | # #plt.colorbar(CS) 622 | # plt.show() 623 | # return ax 624 | 625 | # coords = RandomCoordinates(chain0_m) 626 | # loss_surface = LossSurface(target_log_prob_fn, coords) 627 | # loss_surface.compile(range = 5, num_points= 500) 628 | # ax = loss_surface.plot(levels = 15, points = [chain1_m, chain2_m]) 629 | 630 | # ============================================================================= 631 | # Plot different chains 632 | # ============================================================================= 633 | 634 | u_pred_ens = np.array([vmap(get_u_pred)(samples[i,:,:]) for i in range(samples.shape[0])]) 635 | f_pred_ens = np.array([vmap(get_f_pred)(samples[i,:,:]) for i in range(samples.shape[0])]) 636 | 637 | u_pred_ens_mean = np.mean(u_pred_ens, axis = 1) 638 | u_pred_ens_std = np.std(u_pred_ens, axis = 1) 639 | f_pred_ens_mean = np.mean(f_pred_ens, axis = 1) 640 | f_pred_ens_std = np.std(f_pred_ens, axis = 1) 641 | 642 | fig, ax = plt.subplots(dpi=300, figsize=(4,4)) 643 | ax.plot(x_pred_index, u_ref, 'k-', label='Exact', zorder=5) # Higher zorder to ensure the line is on top 644 | color = ['#2ca02c', '#ff7f0e', '#1f77b4'] # green, orange, blue 645 | linestyle = ['solid', 'dashdot', 'dashed'] 646 | # Adjust the zorder for fill_between 647 | zorders_fill = [1, 2, 3] # blue highest, then orange, then green 648 | # Plot lines and fill regions 649 | for i, (c, ls, z) in enumerate(zip(color, linestyle, zorders_fill)): 650 | ax.plot(x_pred_index, u_pred_ens_mean[i, :], color=c, linestyle=ls, markersize=1, markevery=2, markerfacecolor='None', label=f'Chain {i}', alpha=0.8, zorder=z+1) 651 | ax.fill_between(x_pred_index, u_pred_ens_mean[i,:] + 2 * u_pred_ens_std[i,:], u_pred_ens_mean[i,:] - 2 * u_pred_ens_std[i,:], color=color[i], alpha=0.4, zorder=z) 652 | ax.scatter(x_data, y_data, label='Obs' , s = 20, facecolors='none', edgecolors='b', zorder=6) # Higher zorder to ensure the scatter is on top 653 | ax.set_xlabel('$x$', fontsize=16) 654 | ax.set_ylabel('$u(x)$', fontsize=16) 655 | ax.set_xlim(-1.02,1.02) 656 | ax.set_ylim(-1.5,1.5) 657 | ax.tick_params(axis='both', which='major', labelsize=13) 658 | ax.legend(fontsize=10, loc='upper left') 659 | fig.tight_layout() 660 | plt.savefig(os.path.join(path_fig,f'1D_linear_poisson_HMC_Nres_{Nres}_sigma_{sigma}_Nsamples_{Nsamples}_upred.png')) 661 | plt.show() 662 | 663 | fig, ax = plt.subplots(dpi=300, figsize=(4,4)) 664 | ax.plot(x_pred_index, f_ref, 'k-', label='Exact', zorder=5) # Higher zorder to ensure the line is on top 665 | color = ['#2ca02c', '#ff7f0e', '#1f77b4'] # green, orange, blue 666 | linestyle = ['solid', 'dashdot', 'dashed'] 667 | # Adjust the zorder for fill_between 668 | zorders_fill = [1, 2, 3] # blue highest, then orange, then green 669 | # Plot lines and fill regions 670 | for i, (c, ls, z) in enumerate(zip(color, linestyle, zorders_fill)): 671 | ax.plot(x_pred_index, f_pred_ens_mean[i, :], color=c, linestyle=ls, markersize=1, markevery=2, markerfacecolor='None', label=f'Chain {i}', alpha=0.8, zorder=z+1) 672 | ax.fill_between(x_pred_index, f_pred_ens_mean[i,:] + 2 * f_pred_ens_std[i,:], 
f_pred_ens_mean[i,:] - 2 * f_pred_ens_std[i,:], color=color[i], alpha=0.4, zorder=z) 673 | ax.scatter(X_r, y_r, label='Obs', s=20, facecolors='none', edgecolors='b', zorder=6) # Higher zorder to ensure the scatter is on top 674 | ax.set_xlabel('$x$', fontsize=16) 675 | ax.set_ylabel('$f(x)$', fontsize=16) 676 | ax.set_xlim(-1.02,1.02) 677 | ax.set_ylim(-1.5,1.5) 678 | ax.tick_params(axis='both', which='major', labelsize=13) 679 | ax.legend(fontsize=10, loc='upper left') 680 | fig.tight_layout() 681 | plt.savefig(os.path.join(path_fig,f'1D_linear_poisson_HMC_Nres_{Nres}_sigma_{sigma}_Nsamples_{Nsamples}_fpred.png')) 682 | plt.show() 683 | 684 | f_rec.close() 685 | 686 | -------------------------------------------------------------------------------- /1D_linear_Poisson/1D_linear_poisson_SVGD.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Sun Jan 28 15:11:25 2024 5 | 6 | @author: yifeizong 7 | """ 8 | 9 | import jax 10 | import os 11 | import jax.numpy as jnp 12 | from jax import random, grad, vmap, jit 13 | from jax.flatten_util import ravel_pytree 14 | from jax.example_libraries import optimizers 15 | 16 | import numpy as np 17 | import matplotlib.pyplot as plt 18 | import matplotlib as mpl 19 | import scipy.linalg as spl 20 | from scipy.spatial.distance import pdist, squareform 21 | import seaborn as sns 22 | import pandas as pd 23 | from tensorflow_probability.substrates import jax as tfp 24 | tfd = tfp.distributions 25 | 26 | import itertools 27 | import argparse 28 | from functools import partial 29 | from tqdm import trange 30 | from time import perf_counter 31 | 32 | #command line argument parser 33 | parser = argparse.ArgumentParser(description="1D Linear Poisson with Stein Variational Gradient Descent") 34 | parser.add_argument( 35 | "--rand_seed", 36 | type=int, 37 | default=42, 38 | help="random seed") 39 | parser.add_argument( 40 | "--sigma", 41 | type=float, 42 | default=0.01, 43 | help="Data uncertainty") 44 | parser.add_argument( 45 | "--sigma_r", 46 | type=float, 47 | default=0.01, 48 | help="Aleatoric uncertainty of the residual") 49 | parser.add_argument( 50 | "--sigma_d", 51 | type=float, 52 | default=0.004, 53 | help="Aleatoric uncertainty of the data") 54 | parser.add_argument( 55 | "--sigma_p", 56 | type=float, 57 | default=0.145, 58 | help="Prior std") 59 | parser.add_argument( 60 | "--Nres", 61 | type=int, 62 | default=128, 63 | help="Number of residual points") 64 | parser.add_argument( 65 | "--Nsamples", 66 | type=int, 67 | default=1000, 68 | help="Number of Posterior samples") 69 | parser.add_argument( 70 | "--nIter", 71 | type=int, 72 | default=5300, 73 | help="Number of SVGD iterations") 74 | args = parser.parse_args() 75 | 76 | #Define parameters 77 | layers_u = [1, 50, 50, 1] 78 | lbt = np.array([-1.]) 79 | ubt = np.array([1.]) 80 | k = -1/np.pi**2 81 | dataset = dict() 82 | rand_seed = args.rand_seed 83 | Nres = args.Nres 84 | sigma = args.sigma 85 | sigma_r = args.sigma_r 86 | sigma_b = args.sigma_d 87 | sigma_p = args.sigma_p 88 | Nsamples = args.Nsamples 89 | nIter = args.nIter 90 | num_print = 20 91 | #path_f = f'1D_linear_poisson_SVGD_Nres_{Nres}_sigma_{sigma}_test' 92 | path_f = f'1D_linear_poisson_SVGD_Nres_{Nres}_sigma_{sigma}_nIter_{nIter}_lr_1e-3_median_trick' 93 | path_fig = os.path.join(path_f,'figures') 94 | if not os.path.exists(path_f): 95 | os.makedirs(path_f) 96 | if not os.path.exists(path_fig): 97 | os.makedirs(path_fig) 98 | 
f_rec = open(os.path.join(path_f,'record.out'), 'a+') 99 | 100 | def u(x): 101 | return jnp.sin(jnp.pi*x) 102 | 103 | def f(x): 104 | return jnp.sin(jnp.pi*x) 105 | 106 | rl2e = lambda yest, yref : spl.norm(yest - yref, 2) / spl.norm(yref, 2) 107 | infe = lambda yest, yref : spl.norm(yest - yref, np.inf) 108 | lpp = lambda h, href, sigma: np.sum( -(h - href)**2/(2*sigma**2) - 1/2*np.log( 2*np.pi) - 2*np.log(sigma)) 109 | 110 | #create noisy boundary data 111 | np.random.seed(rand_seed) 112 | x_data = np.array([lbt[0], ubt[0]])[:,np.newaxis] 113 | y_data = np.array([u(lbt[0]), u(ubt[0])])[:,np.newaxis].astype(np.float32) + np.random.normal(0,sigma,(2,1)).astype(np.float32) 114 | data = jnp.concatenate([x_data,y_data], axis=1) 115 | dataset.update({'data': data}) 116 | 117 | #create noisy forcing sampling 118 | X_r = np.linspace(lbt[0], ubt[0], Nres) 119 | X_r = jnp.sort(X_r, axis = 0)[:,np.newaxis] 120 | y_r = f(X_r) + np.random.normal(0,sigma,(Nres,1)) 121 | Dres = jnp.asarray(jnp.concatenate([X_r,y_r], axis=1)) 122 | dataset.update({'res': Dres}) 123 | 124 | # Define FNN 125 | def FNN(layers, activation=jnp.tanh): 126 | 127 | def init(prng_key): #return a list of (W,b) tuples 128 | def init_layer(key, d_in, d_out): 129 | key1, key2 = random.split(key) 130 | glorot_stddev = 1.0 / jnp.sqrt((d_in + d_out) / 2.) 131 | W = glorot_stddev * random.normal(key1, (d_in, d_out)) 132 | b = jnp.zeros(d_out) 133 | return W, b 134 | key, *keys = random.split(prng_key, len(layers)) 135 | params = list(map(init_layer, keys, layers[:-1], layers[1:])) 136 | return params 137 | 138 | def forward(params, inputs): 139 | Z = inputs 140 | for W, b in params[:-1]: 141 | outputs = jnp.dot(Z, W) + b 142 | Z = activation(outputs) 143 | W, b = params[-1] 144 | outputs = jnp.dot(Z, W) + b 145 | return outputs 146 | 147 | return init, forward 148 | 149 | # Define the model 150 | class PINN(): 151 | def __init__(self, key, layers, dataset, lbt, ubt, k, sigma_r, sigma_b, sigma_p): 152 | 153 | self.lbt = lbt #domain lower corner 154 | self.ubt = ubt #domain upper corner 155 | self.k = k 156 | self.scale_coe = 0.5 157 | self.scale = 2 * self.scale_coe / (self.ubt-self.lbt) 158 | self.sigma_r = sigma_r 159 | self.sigma_b = sigma_b 160 | self.sigma_p = sigma_p 161 | 162 | # Prepare normalized training data 163 | self.dataset = dataset 164 | self.X_res, self.y_res = self.normalize(dataset['res'][:,0:1]), dataset['res'][:,1:2] 165 | self.X_data, self.y_data = self.normalize(dataset['data'][:,0:1]), dataset['data'][:,1:2] 166 | 167 | # Initalize the network 168 | self.init, self.forward = FNN(layers, activation=jnp.tanh) 169 | self.params = self.init(key) 170 | flat_params, self.unravel = ravel_pytree(self.params) 171 | self.num_params = flat_params.shape[0] 172 | self.log_prob_rec = [] 173 | self.u_rl2e_log = [] 174 | self.u_lpp_log = [] 175 | self.u_std_log = [] 176 | self.f_rl2e_log = [] 177 | self.f_lpp_log = [] 178 | self.f_std_log = [] 179 | 180 | self.itercount = itertools.count() 181 | 182 | # Evaluate the network and the residual over the grid 183 | self.u_pred_map = vmap(self.predict_u, (None, 0)) 184 | self.f_pred_map = vmap(self.predict_res, (None, 0)) 185 | self.rl2e = lambda yest, yref : spl.norm(yest - yref, 2) / spl.norm(yref, 2) 186 | self.lpp = lambda h, href, sigma: np.sum(-(h - href)**2/(2*sigma**2) - 1/2*np.log( 2*np.pi) - 2*np.log(sigma)) 187 | 188 | def normalize(self, X): 189 | if X.shape[1] == 1: 190 | return 2.0 * self.scale_coe * (X - self.lbt[0:1])/(self.ubt[0:1] - self.lbt[0:1]) - self.scale_coe 191 
| if X.shape[1] == 2: 192 | return 2.0 * self.scale_coe * (X - self.lbt[0:2])/(self.ubt[0:2] - self.lbt[0:2]) - self.scale_coe 193 | if X.shape[1] == 3: 194 | return 2.0 * self.scale_coe * (X - self.lbt)/(self.ubt - self.lbt) - self.scale_coe 195 | 196 | @partial(jit, static_argnums=(0,)) 197 | def u_net(self, params, x): 198 | inputs = jnp.hstack([x]) 199 | outputs = self.forward(params, inputs) 200 | return outputs[0] 201 | 202 | @partial(jit, static_argnums=(0,)) 203 | def res_net(self, params, x): 204 | u_xx = grad(grad(self.u_net, argnums=1), argnums=1)(params, x)*self.scale[0]**2 205 | return self.k*u_xx 206 | 207 | @partial(jit, static_argnums=(0,)) 208 | def predict_u(self, params, x): 209 | x = 2.0 * self.scale_coe * (x - self.lbt[0])/(self.ubt[0] - self.lbt[0]) - self.scale_coe 210 | return self.u_net(params, x) 211 | 212 | @partial(jit, static_argnums=(0,)) 213 | def predict_res(self, params, x): 214 | x = 2.0 * self.scale_coe * (x - self.lbt[0])/(self.ubt[0] - self.lbt[0]) - self.scale_coe 215 | return self.res_net(params, x) 216 | 217 | @partial(jit, static_argnums=(0,)) 218 | def u_pred_vector(self, theta): 219 | params = self.unravel(theta) 220 | u_pred_vec = vmap(self.u_net, (None, 0))(params, self.X_data[:,0]) 221 | return u_pred_vec 222 | 223 | @partial(jit, static_argnums=(0,)) 224 | def f_pred_vector(self, theta): 225 | params = self.unravel(theta) 226 | f_pred_vec = vmap(self.res_net, (None, 0))(params, self.X_res[:,0]) 227 | return f_pred_vec 228 | 229 | @partial(jit, static_argnums=(0,)) 230 | def target_log_prob(self, theta): 231 | prior = jnp.sum(-(theta)**2/(2*self.sigma_p**2)) 232 | r_likelihood = jnp.sum(-(self.y_res.ravel() - self.f_pred_vector(theta))**2/(2*self.sigma_r**2)) 233 | u_likelihood = jnp.sum(-(self.y_data.ravel() - self.u_pred_vector(theta))**2/(2*self.sigma_b**2)) 234 | return prior + r_likelihood + u_likelihood 235 | 236 | @partial(jit, static_argnums=(0,)) 237 | def grad_log_prob(self, theta): 238 | return jax.value_and_grad(self.target_log_prob, argnums = 0)(theta)[1] 239 | 240 | # @partial(jit, static_argnums=(0,)) 241 | # def median_trick_h(self, theta): 242 | # ''' 243 | # Compute the correlation length for the kernel using median trick 244 | # input: parameter batch (Nsample, Nparams) 245 | # output: h 246 | # ''' 247 | # #pairwise_dists = -((theta[:, None, :] - theta)) #one way to compute pairwise distance 248 | # #another way, but faster in JAX 249 | # pairwise_dist = vmap(vmap(lambda x, y: (x - y), in_axes=(None, 0)), in_axes=(0, None))(theta, theta) 250 | # pairwise_dist_sq = (pairwise_dist ** 2).sum(axis=-1) 251 | # med_sq = jnp.median(pairwise_dist_sq) 252 | # h_sq = med_sq / jnp.log(theta.shape[0] + 1) 253 | # return jnp.sqrt(h_sq/2) 254 | 255 | def median_trick_h(self, theta): 256 | ''' 257 | The scipy one seems even faster and memory efficient 258 | 259 | ''' 260 | sq_dist = pdist(theta) 261 | pairwise_dists = squareform(sq_dist)**2 262 | h = np.median(pairwise_dists) 263 | h = np.sqrt(0.5 * h / np.log(theta.shape[0]+1)) 264 | return h 265 | 266 | @partial(jit, static_argnums=(0,)) 267 | def rbf_kernel(self, theta1, theta2, h): 268 | ''' 269 | Evaluate the rbf kernel k(x, x') = exp(-|x - x'|^2/(2h^2)) 270 | input: theta1, theta2 are 1d array of parameters, 271 | h is correlation length 272 | output: a scalar value of kernel evaluation 273 | ''' 274 | # here theta1 and theta2 are 1d-array of parameters 275 | return jnp.exp(-((theta1 - theta2)**2).sum(axis=-1) / (2 * h**2)) 276 | 277 | @partial(jit, static_argnums=(0,)) 278 | def 
compute_kernel_matrix(self, theta, h): 279 | return vmap(vmap(lambda x, y: self.rbf_kernel(x, y, h), in_axes=(None, 0)), in_axes=(0, None))(theta, theta) 280 | 281 | @partial(jit, static_argnums=(0,)) 282 | def kernel_and_grad(self, theta, h): 283 | ''' 284 | input theta: (Nsamples, Nparams) 285 | h is correlation length 286 | output: K: #(Nsamples, Nsamples) 287 | grad_K: #(Nsamples, Nparams) 288 | ''' 289 | K = self.compute_kernel_matrix(theta, h) #(Nsamples, Nsamples) 290 | grad_K = jnp.sum(jnp.einsum('ijk,ij->ijk', theta - theta[:, None, :], K), axis = 0)/ (h**2) 291 | return (K, grad_K) 292 | 293 | @partial(jit, static_argnums=(0,)) 294 | def svgd_step(self, i, opt_state, h): 295 | theta = self.get_params(opt_state) 296 | grad_logprob = vmap(self.grad_log_prob)(theta) 297 | K, grad_K = self.kernel_and_grad(theta, h) 298 | phi = -(jnp.einsum('ij, jk->ik', K, grad_logprob)/theta.shape[0] + grad_K) #(Nsamples, Nparams) 299 | return self.opt_update(i, phi, opt_state) 300 | 301 | def svgd_train(self, key, Nsamples, nIter, num_print, bandwidth, u_ref, f_ref): 302 | 303 | new_key, subkey = random.split(key, 2) 304 | init_state = random.normal(subkey ,(Nsamples, self.num_params)) 305 | 306 | x_test = jnp.linspace(-1,1,201) 307 | u_pred = vmap(lambda sample: self.u_pred_map(self.unravel(sample),x_test)) 308 | f_pred = vmap(lambda sample: self.f_pred_map(self.unravel(sample),x_test)) 309 | u = u_ref(x_test) 310 | f = f_ref(x_test) 311 | 312 | #lr = optimizers.exponential_decay(1e-3, decay_steps=20, decay_rate=0.9) 313 | #lr = optimizers.exponential_decay(1e-4, decay_steps=1000, decay_rate=0.9) 314 | self.opt_init, \ 315 | self.opt_update, \ 316 | self.get_params = optimizers.adam(1e-3) 317 | # self.get_params = optimizers.adagrad(lr, momentum=0.9) 318 | self.opt_state = self.opt_init(init_state) 319 | 320 | ts = perf_counter() 321 | pbar = trange(nIter) 322 | 323 | for it in pbar: 324 | self.current_count = next(self.itercount) 325 | theta = self.get_params(self.opt_state) 326 | h = bandwidth if bandwidth > 0 else self.median_trick_h(theta) 327 | self.opt_state = self.svgd_step(self.current_count, self.opt_state, h) 328 | 329 | if it % num_print == 0: 330 | 331 | log_prob = vmap(self.target_log_prob)(theta) 332 | u_test_ens = u_pred(theta) 333 | f_test_ens = f_pred(theta) 334 | u_test_mean = jnp.mean(u_test_ens, axis = 0) 335 | u_test_std = jnp.std(u_test_ens, axis = 0) 336 | rl2e_u = self.rl2e(u_test_mean, u) 337 | lpp_u = self.lpp(u_test_mean, u, u_test_std) 338 | 339 | f_test_mean = jnp.mean(f_test_ens, axis = 0) 340 | f_test_std = jnp.std(f_test_ens, axis = 0) 341 | rl2e_f = self.rl2e(f_test_mean, f) 342 | lpp_f = self.lpp(f_test_mean, f, f_test_std) 343 | 344 | self.log_prob_rec.append(jnp.mean(log_prob)) 345 | self.u_rl2e_log.append(rl2e_u) 346 | self.u_lpp_log.append(lpp_u) 347 | self.u_std_log.append(jnp.mean(u_test_std)) 348 | self.f_rl2e_log.append(rl2e_f) 349 | self.f_lpp_log.append(lpp_f) 350 | self.f_std_log.append(jnp.mean(f_test_std)) 351 | 352 | pbar.set_postfix({'Log prob': jnp.mean(log_prob), 353 | 'correlation': h, 354 | 'u_lpp':lpp_u, 355 | 'f_lpp': lpp_f}) 356 | 357 | return self.get_params(self.opt_state) 358 | 359 | key1, key2 = random.split(random.PRNGKey(0), 2) 360 | model = PINN(key1, layers_u, dataset, lbt, ubt, k, sigma_r, sigma_b, sigma_p) 361 | ts = perf_counter() 362 | samples = model.svgd_train(key2, Nsamples, nIter = nIter, num_print = num_print, bandwidth = -1, u_ref = u, f_ref = f) 363 | timings = perf_counter() - ts 364 | print(f"SVGD: {timings} s") 365 | 
print(f"SVGD: {timings} s", file = f_rec) 366 | np.savetxt(os.path.join(path_f,f'SVGD_posterior_samples_Nres_{Nres}_sigma_{sigma}_Nsamples_{Nsamples}_nIter_{nIter}.out'), samples) 367 | 368 | 369 | Npred = 201 370 | x_pred_index = jnp.linspace(-1,1,Npred) 371 | u_ref = u(x_pred_index) 372 | f_ref = f(x_pred_index) 373 | 374 | @jit 375 | def get_u_pred(sample): 376 | return model.u_pred_map(model.unravel(sample),x_pred_index) 377 | 378 | @jit 379 | def get_f_pred(sample): 380 | return model.f_pred_map(model.unravel(sample),x_pred_index) 381 | 382 | u_pred_ens = vmap(get_u_pred)(samples) 383 | f_pred_ens = vmap(get_f_pred)(samples) 384 | np.savetxt(os.path.join(path_f,'u_pred_ens.out'), u_pred_ens) 385 | np.savetxt(os.path.join(path_f,'f_pred_ens.out'), f_pred_ens) 386 | 387 | u_pred_ens_mean = np.mean(u_pred_ens, axis = 0) 388 | u_pred_ens_std = np.std(u_pred_ens, axis = 0) 389 | f_pred_ens_mean = np.mean(f_pred_ens, axis = 0) 390 | f_pred_ens_std = np.std(f_pred_ens, axis = 0) 391 | 392 | u_env = np.logical_and( (u_pred_ens_mean < u_ref + 2*u_pred_ens_std), (u_pred_ens_mean > u_ref - 2*u_pred_ens_std) ) 393 | f_env = np.logical_and( (f_pred_ens_mean < f_ref + 2*f_pred_ens_std), (f_pred_ens_mean > f_ref - 2*f_pred_ens_std) ) 394 | 395 | # ============================================================================= 396 | # Posterior Statistics 397 | # ============================================================================= 398 | 399 | rl2e_u = rl2e(u_pred_ens_mean, u_ref) 400 | infe_u = infe(u_pred_ens_mean, u_ref) 401 | lpp_u = lpp(u_pred_ens_mean, u_ref, u_pred_ens_std) 402 | rl2e_f = rl2e(f_pred_ens_mean, f_ref) 403 | infe_f = infe(f_pred_ens_mean, f_ref) 404 | lpp_f = lpp(f_pred_ens_mean, f_ref, f_pred_ens_std) 405 | 406 | print('u prediction:\n') 407 | print('Relative RL2 error: {}'.format(rl2e_u)) 408 | print('Absolute inf error: {}'.format(infe_u)) 409 | print('Average standard deviation: {}'.format(np.mean(u_pred_ens_std))) 410 | print('log predictive probability: {}'.format(lpp_u)) 411 | print('Percentage of coverage:{}\n'.format(np.sum(u_env)/Npred)) 412 | 413 | print('f prediction:\n') 414 | print('Relative RL2 error: {}'.format(rl2e_f)) 415 | print('Absolute inf error: {}'.format(infe_f)) 416 | print('Average standard deviation: {}'.format(np.mean(f_pred_ens_std))) 417 | print('log predictive probability: {}'.format(lpp_f)) 418 | print('Percentage of coverage:{}\n'.format(np.sum(f_env)/Npred)) 419 | 420 | print('u prediction:\n', file = f_rec) 421 | print('Relative RL2 error: {}'.format(rl2e_u), file = f_rec) 422 | print('Absolute inf error: {}'.format(infe_u), file = f_rec) 423 | print('Average standard deviation: {}'.format(np.mean(u_pred_ens_std)), file = f_rec) 424 | print('log predictive probability: {}'.format(lpp_u), file = f_rec) 425 | print('Percentage of coverage:{}\n'.format(np.sum(u_env)/Npred), file = f_rec) 426 | 427 | print('f prediction:\n', file = f_rec) 428 | print('Relative RL2 error: {}'.format(rl2e_f), file = f_rec) 429 | print('Absolute inf error: {}'.format(infe_f), file = f_rec) 430 | print('Average standard deviation: {}'.format(np.mean(f_pred_ens_std)), file = f_rec) 431 | print('log predictive probability: {}'.format(lpp_f), file = f_rec) 432 | print('Percentage of coverage:{}\n'.format(np.sum(f_env)/Npred), file = f_rec) 433 | 434 | f_rec.close() 435 | 436 | # ============================================================================= 437 | # Plot posterior predictions 438 | # 
============================================================================= 439 | 440 | fig, ax = plt.subplots(dpi = 300, figsize = (4,4)) 441 | ax.plot(x_pred_index, u_ref, 'k-', label='Exact') 442 | ax.plot(x_pred_index, u_pred_ens_mean, markersize = 1, markevery=2, markerfacecolor='None', label= 'SVGD mean', alpha = 0.8) 443 | ax.fill_between(x_pred_index, u_pred_ens_mean + 2 * u_pred_ens_std, u_pred_ens_mean - 2 * u_pred_ens_std, 444 | alpha = 0.3, label = r'$95 \% $ CI') 445 | ax.scatter(data[:,0], data[:,1], label='Obs' , s = 20, facecolors='none', edgecolors='b') 446 | ax.set_xlabel('$x$', fontsize=16) 447 | ax.set_ylabel('$u(x)$', fontsize=16) 448 | ax.set_xlim(-1.02,1.02) 449 | ax.set_ylim(-1.5,1.5) 450 | ax.tick_params(axis='both', which = 'major', labelsize=13) 451 | ax.legend(fontsize=10) 452 | fig.tight_layout() 453 | plt.savefig(os.path.join(path_fig,f'1D_linear_poisson_SVGD_Nres_{Nres}_sigma_{sigma}_Nsamples_{Nsamples}_nIter_{nIter}_upred.png')) 454 | plt.show() 455 | 456 | fig, ax = plt.subplots(dpi = 300, figsize = (4,4)) 457 | r_ref = f(x_pred_index) 458 | ax.plot(x_pred_index, r_ref, 'k-', label='Exact') 459 | ax.plot(x_pred_index, f_pred_ens_mean, markersize = 1, markevery=2, markerfacecolor='None', label='SVGD mean', alpha = 0.8) 460 | ax.fill_between(x_pred_index, f_pred_ens_mean + 2 * f_pred_ens_std, f_pred_ens_mean - 2 * f_pred_ens_std, 461 | alpha = 0.3, label = r'$95 \% $ CI') 462 | ax.scatter(Dres[:,0], Dres[:,1], label='Obs' , s = 20, facecolors='none', edgecolors='b') 463 | ax.set_xlabel('$x$', fontsize=16) 464 | ax.set_xlim(-1.02,1.02) 465 | ax.set_ylim(-1.5,1.5) 466 | ax.set_ylabel('$f(x)$', fontsize=16) 467 | ax.tick_params(axis='both', which = 'major', labelsize=13) 468 | ax.legend(fontsize=10, loc= 'upper left') 469 | fig.tight_layout() 470 | plt.savefig(os.path.join(path_fig,f'1D_linear_poisson_SVGD_Nres_{Nres}_sigma_{sigma}_Nsamples_{Nsamples}_nIter_{nIter}_fpred.png')) 471 | plt.show() 472 | 473 | # Loss plot 474 | t = np.arange(0, nIter, num_print) 475 | fig = plt.figure(constrained_layout=False, figsize=(4, 4), dpi = 300) 476 | ax = fig.add_subplot() 477 | ax.plot(t, -np.array(model.log_prob_rec), color='blue', label='Negative log prob') 478 | ax.set_yscale('log') 479 | ax.set_ylabel('Loss', fontsize = 16) 480 | ax.set_xlabel('Epochs', fontsize = 16) 481 | ax.legend(loc='upper right', fontsize = 14) 482 | fig.tight_layout() 483 | plt.savefig(os.path.join(path_fig,f'1D_linear_poisson_SVGD_loss.png')) 484 | plt.show() 485 | 486 | t = np.arange(0, nIter, num_print) 487 | fig = plt.figure(constrained_layout=False, figsize=(4, 4), dpi = 300) 488 | ax = fig.add_subplot() 489 | ax.plot(t, np.array(model.u_rl2e_log), label='u') 490 | ax.plot(t, np.array(model.f_rl2e_log), label='f') 491 | ax.set_yscale('log') 492 | ax.set_ylabel('relative L2 error', fontsize = 16) 493 | ax.set_xlabel('Epochs', fontsize = 16) 494 | ax.legend(loc='upper right', fontsize = 14) 495 | fig.tight_layout() 496 | fig.savefig(os.path.join(path_fig,'test_rl2e.png')) 497 | 498 | t = np.arange(0, nIter, num_print) 499 | fig = plt.figure(constrained_layout=False, figsize=(4, 4), dpi = 300) 500 | ax = fig.add_subplot() 501 | ax.plot(t, np.array(model.u_lpp_log), label='u') 502 | ax.plot(t, np.array(model.f_lpp_log), label='f') 503 | ax.set_ylabel('LPP', fontsize = 16) 504 | ax.set_xlabel('Epochs', fontsize = 16) 505 | ax.legend(loc='upper right', fontsize = 14) 506 | fig.tight_layout() 507 | fig.savefig(os.path.join(path_fig,'test_lpp.png')) 508 | 509 | t = np.arange(0, nIter, num_print) 
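# Evolution of the mean posterior standard deviation of u and f over SVGD iterations (log scale).
# If the raw diagnostics are needed later, they could also be written out next to the figures, e.g. (illustrative only):
# np.savetxt(os.path.join(path_f, 'u_std_log.out'), np.array(model.u_std_log))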
510 | fig = plt.figure(constrained_layout=False, figsize=(4, 4), dpi = 300) 511 | ax = fig.add_subplot() 512 | ax.plot(t, np.array(model.u_std_log), label='u') 513 | ax.plot(t, np.array(model.f_std_log), label='f') 514 | ax.set_ylabel('std', fontsize = 16) 515 | ax.set_xlabel('Epochs', fontsize = 16) 516 | ax.set_yscale('log') 517 | ax.legend(loc='upper right', fontsize = 14) 518 | fig.tight_layout() 519 | fig.savefig(os.path.join(path_fig,'test_std.png')) 520 | -------------------------------------------------------------------------------- /1D_nonlinear_Poisson/1D_nonlinear_poisson_HMC.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Mon Jun 26 11:31:38 2023 5 | 6 | @author: yifeizong 7 | """ 8 | 9 | import jax 10 | import os 11 | import jax.numpy as jnp 12 | from jax import random, grad, vmap, jit 13 | from jax.flatten_util import ravel_pytree 14 | from jax.example_libraries import optimizers 15 | 16 | import numpy as np 17 | import matplotlib.pyplot as plt 18 | import matplotlib as mpl 19 | import scipy.linalg as spl 20 | import seaborn as sns 21 | import pandas as pd 22 | from tensorflow_probability.substrates import jax as tfp 23 | tfd = tfp.distributions 24 | 25 | import itertools 26 | import argparse 27 | from functools import partial 28 | from tqdm import trange 29 | from time import perf_counter 30 | 31 | #command line argument parser 32 | parser = argparse.ArgumentParser(description="1D Nonlinear Poisson with HMC") 33 | parser.add_argument( 34 | "--rand_seed", 35 | type=int, 36 | default=42, 37 | help="random seed") 38 | parser.add_argument( 39 | "--sigma", 40 | type=float, 41 | default=0.1, 42 | help="Data uncertainty") 43 | parser.add_argument( 44 | "--sigma_r", 45 | type=float, 46 | default=0.1, 47 | help="Aleatoric uncertainty of the residual") 48 | parser.add_argument( 49 | "--sigma_d", 50 | type=float, 51 | default=0.04, 52 | help="Aleatoric uncertainty of the data") 53 | parser.add_argument( 54 | "--sigma_p", 55 | type=float, 56 | default=1.452, 57 | help="Prior std") 58 | parser.add_argument( 59 | "--Nres", 60 | type=int, 61 | default=128, 62 | help="Number of residual points") 63 | parser.add_argument( 64 | "--Nsamples", 65 | type=int, 66 | default=5000, 67 | help="Number of Posterior samples") 68 | parser.add_argument( 69 | "--Nburn", 70 | type=int, 71 | default=100000, 72 | help="Number of burn-in samples") 73 | parser.add_argument( 74 | "--data_load", 75 | type=bool, 76 | default=False, 77 | help="Whether to load data") 78 | args = parser.parse_args() 79 | 80 | #Define parameters 81 | layers_u = [1, 50, 50, 1] 82 | lbt = np.array([-0.7]) 83 | ubt = np.array([0.7]) 84 | lamb = 0.01 85 | k = 0.7 86 | dataset = dict() 87 | rand_seed = args.rand_seed 88 | Nres = args.Nres 89 | sigma = args.sigma 90 | sigma_r = args.sigma_r 91 | sigma_b = args.sigma_d 92 | sigma_p = args.sigma_p 93 | Nsamples = args.Nsamples 94 | Nburn = args.Nburn 95 | Nchains = 9 96 | dataset = dict() 97 | path_f = f'1D_nonlinear_poisson_HMC_Nres_{Nres}_sigma_{sigma}_Nsamples_{Nsamples}_Nburn_{Nburn}' 98 | path_fig = os.path.join(path_f,'figures') 99 | if not os.path.exists(path_f): 100 | os.makedirs(path_f) 101 | if not os.path.exists(path_fig): 102 | os.makedirs(path_fig) 103 | f_rec = open(os.path.join(path_f,'record.out'), 'a+') 104 | 105 | def u(x): 106 | return jnp.sin(6*x)**3 107 | 108 | def f(x): 109 | return lamb*(-108*jnp.sin(6*x)**3 + 216*jnp.sin(6*x)*jnp.cos(6*x)**2) + 
k*jnp.tanh(u(x)) 110 | 111 | rl2e = lambda yest, yref : spl.norm(yest - yref, 2) / spl.norm(yref, 2) 112 | infe = lambda yest, yref : spl.norm(yest - yref, np.inf) 113 | lpp = lambda h, href, sigma: np.sum( -(h - href)**2/(2*sigma**2) - 1/2*np.log( 2*np.pi) - 2*np.log(sigma)) 114 | 115 | # Create synthetic noisy data 116 | #create noisy boundary data 117 | np.random.seed(rand_seed) 118 | x_data = np.array([lbt[0], ubt[0]])[:,np.newaxis] 119 | y_data = np.array([u(lbt[0]), u(ubt[0])])[:,np.newaxis].astype(np.float32) + np.random.normal(0,sigma,(2,1)).astype(np.float32) 120 | data = jnp.concatenate([x_data,y_data], axis=1) 121 | dataset.update({'data': data}) 122 | 123 | #create noisy forcing sampling 124 | X_r = np.linspace(lbt[0], ubt[0], Nres) 125 | X_r = jnp.sort(X_r, axis = 0)[:,np.newaxis] 126 | y_r = f(X_r) + np.random.normal(0,sigma,(Nres,1)) 127 | Dres = jnp.asarray(jnp.concatenate([X_r,y_r], axis=1)) 128 | dataset.update({'res': Dres}) 129 | 130 | # Define FNN 131 | def FNN(layers, activation=jnp.tanh): 132 | 133 | def init(prng_key): #return a list of (W,b) tuples 134 | def init_layer(key, d_in, d_out): 135 | key1, key2 = random.split(key) 136 | glorot_stddev = 1.0 / jnp.sqrt((d_in + d_out) / 2.) 137 | W = glorot_stddev * random.normal(key1, (d_in, d_out)) 138 | b = jnp.zeros(d_out) 139 | return W, b 140 | key, *keys = random.split(prng_key, len(layers)) 141 | params = list(map(init_layer, keys, layers[:-1], layers[1:])) 142 | return params 143 | 144 | def forward(params, inputs): 145 | Z = inputs 146 | for W, b in params[:-1]: 147 | outputs = jnp.dot(Z, W) + b 148 | Z = activation(outputs) 149 | W, b = params[-1] 150 | outputs = jnp.dot(Z, W) + b 151 | return outputs 152 | 153 | return init, forward 154 | 155 | # Define the model 156 | class PINN(): 157 | def __init__(self, key, layers, dataset, lbt, ubt, lamb, k, sigma_r, sigma_b, sigma_p): 158 | 159 | self.lbt = lbt #domain lower corner 160 | self.ubt = ubt #domain upper corner 161 | self.k = k 162 | self.lamb = lamb 163 | self.sigma_r = sigma_r 164 | self.sigma_b = sigma_b 165 | self.sigma_p = sigma_p 166 | 167 | # Prepare normalized training data 168 | self.dataset = dataset 169 | self.X_res, self.y_res = dataset['res'][:,0:1], dataset['res'][:,1:2] 170 | self.X_data, self.y_data = dataset['data'][:,0:1], dataset['data'][:,1:2] 171 | 172 | # Initalize the network 173 | self.init, self.forward = FNN(layers, activation=jnp.tanh) 174 | self.params = self.init(key) 175 | _, self.unravel = ravel_pytree(self.params) 176 | self.num_params = ravel_pytree(self.params)[0].shape[0] 177 | 178 | # Evaluate the network and the residual over the grid 179 | self.u_pred_map = vmap(self.predict_u, (None, 0)) 180 | self.f_pred_map = vmap(self.predict_f, (None, 0)) 181 | 182 | self.itercount = itertools.count() 183 | self.loss_log = [] 184 | self.loss_likelihood_log = [] 185 | self.loss_dbc_log = [] 186 | self.loss_res_log = [] 187 | 188 | lr = optimizers.exponential_decay(1e-3, decay_steps=5000, decay_rate=0.9) 189 | self.opt_init, \ 190 | self.opt_update, \ 191 | self.get_params = optimizers.adam(lr) 192 | self.opt_state = self.opt_init(self.params) 193 | 194 | @partial(jit, static_argnums=(0,)) 195 | def u_net(self, params, x): 196 | inputs = jnp.hstack([x]) 197 | outputs = self.forward(params, inputs) 198 | return outputs[0] 199 | 200 | @partial(jit, static_argnums=(0,)) 201 | def res_net(self, params, x): 202 | u = self.u_net(params, x) 203 | u_xx = grad(grad(self.u_net, argnums=1), argnums=1)(params, x) 204 | return self.lamb*u_xx + 
self.k*jnp.tanh(u) 205 | 206 | @partial(jit, static_argnums=(0,)) 207 | def predict_u(self, params, x): 208 | return self.u_net(params, x) 209 | 210 | @partial(jit, static_argnums=(0,)) 211 | def predict_f(self, params, x): 212 | return self.res_net(params, x) 213 | 214 | @partial(jit, static_argnums=(0,)) 215 | def u_pred_vector(self, params): 216 | u_pred_vec = vmap(self.u_net, (None, 0))(self.unravel(params), self.X_data[:,0]) 217 | return u_pred_vec 218 | 219 | @partial(jit, static_argnums=(0,)) 220 | def f_pred_vector(self, params): 221 | f_pred_vec = vmap(self.res_net, (None, 0))(self.unravel(params), self.X_res[:,0]) 222 | return f_pred_vec 223 | 224 | 225 | key1, key2 = random.split(random.PRNGKey(0), 2) 226 | pinn = PINN(key2, layers_u, dataset, lbt, ubt, lamb, k, sigma_r, sigma_b, sigma_p) 227 | num_params = ravel_pytree(pinn.params)[0].shape[0] 228 | 229 | def target_log_prob_fn(theta): 230 | prior = -1/(2*sigma_p**2)*jnp.sum( theta**2 ) 231 | r_likelihood = -1/(2*sigma_r**2)*jnp.sum( (y_r.ravel() - pinn.f_pred_vector(theta))**2 ) 232 | u_likelihood = -1/(2*sigma_b**2)*jnp.sum( (y_data.ravel() - pinn.u_pred_vector(theta))**2 ) 233 | return prior + r_likelihood + u_likelihood 234 | 235 | new_key, *subkeys = random.split(key1, Nchains + 1) 236 | init_state = jnp.zeros((1, num_params)) 237 | for key in subkeys[:-1]: 238 | init_state = jnp.concatenate([init_state,random.normal(key ,(1, num_params))], axis=0) 239 | 240 | # key3, key4, key5 = random.split(key1, 3) 241 | # init_state = jnp.zeros((1, num_params)) 242 | # init_state = jnp.concatenate([init_state,random.normal(key4 ,(1, num_params))], axis=0) 243 | # init_state = jnp.concatenate([init_state,3 + random.normal(key5 ,(1, num_params))], axis=0) 244 | 245 | nuts_kernel = tfp.mcmc.NoUTurnSampler( 246 | target_log_prob_fn = target_log_prob_fn, step_size = 0.0005, max_tree_depth=10, max_energy_diff=1000.0, 247 | unrolled_leapfrog_steps=1, parallel_iterations=30) 248 | 249 | kernel = tfp.mcmc.DualAveragingStepSizeAdaptation( 250 | inner_kernel=nuts_kernel, num_adaptation_steps=int(Nburn * 0.75)) 251 | 252 | def run_chain(init_state, key): 253 | samples, trace = tfp.mcmc.sample_chain( 254 | num_results= Nsamples, 255 | num_burnin_steps= Nburn, 256 | current_state= init_state, 257 | kernel= kernel, 258 | seed=key, 259 | trace_fn= lambda _,pkr: [pkr.inner_results.log_accept_ratio, 260 | pkr.inner_results.target_log_prob, 261 | pkr.inner_results.step_size] 262 | ) 263 | return samples, trace 264 | 265 | # nuts_kernel = tfp.mcmc.NoUTurnSampler( 266 | # target_log_prob_fn = target_log_prob_fn, step_size = 0.0001, max_tree_depth=10, max_energy_diff=1000.0, 267 | # unrolled_leapfrog_steps=1, parallel_iterations=30) 268 | 269 | # def run_chain(init_state, key): 270 | # samples, trace = tfp.mcmc.sample_chain( 271 | # num_results= Nsamples, 272 | # num_burnin_steps= Nburn, 273 | # current_state= init_state, 274 | # kernel= nuts_kernel, 275 | # seed=key, 276 | # trace_fn= lambda _,pkr: [pkr.log_accept_ratio, 277 | # pkr.target_log_prob, 278 | # pkr.step_size] 279 | # ) 280 | # return samples, trace 281 | 282 | ts = perf_counter() 283 | print('\nStart HMC Sampling') 284 | #states, trace = jit(run_chain)(init_state, key3) 285 | states, trace = jit(vmap(run_chain, in_axes=(0, None)))(init_state, new_key) 286 | print('\nFinish HMC Sampling') 287 | timings = perf_counter() - ts 288 | print(f"HMC: {timings} s") 289 | np.save(os.path.join(path_f,'chains'), states) 290 | 291 | # 
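# -----------------------------------------------------------------------------
# Editor's note (added for readability; a hedged summary of the code above, not
# part of the original run logic): target_log_prob_fn is the unnormalized
# log-posterior that NUTS samples. With a zero-mean Gaussian prior on the
# flattened network weights theta and independent Gaussian likelihoods for the
# forcing and boundary observations, it reads (up to an additive constant)
#
#   log p(theta | D) = -||theta||^2 / (2*sigma_p^2)
#                      - ||y_r - f_theta(X_r)||^2 / (2*sigma_r^2)
#                      - ||y_data - u_theta(X_data)||^2 / (2*sigma_b^2).
#
# The jit(vmap(run_chain, in_axes=(0, None))) call runs Nchains independent
# chains, one per row of init_state, each with dual-averaging step-size
# adaptation over the first 75% of the burn-in steps.
# -----------------------------------------------------------------------------
# =============================================================================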
292 | # Post-processing HMC results 293 | # ============================================================================= 294 | 295 | accept_ratio = np.exp(trace[0]) 296 | target_log_prob = trace[1] 297 | step_size = trace[2] 298 | print(f'Average accept ratio for each chain: {np.mean(accept_ratio, axis = 1)}') 299 | print(f'Average step size for each chain: {np.mean(step_size, axis = 1)}') 300 | print(f'Average accept ratio for each chain: {np.mean(accept_ratio, axis = 1)}', file = f_rec) 301 | print(f'Average step size for each chain: {np.mean(step_size, axis = 1)}', file = f_rec) 302 | 303 | mark = [None, 'o', None] 304 | linestyle = ['solid', 'dotted', 'dashed'] 305 | fig, ax = plt.subplots(dpi = 300, figsize = (4,4)) 306 | for i, m in enumerate(mark): 307 | ax.plot(np.arange(Nsamples)[::10], target_log_prob[i,::10], marker = m, markersize = 2, markevery= 100, markerfacecolor='None', linestyle = 'dashed', label = f'chain {i + 1}', alpha = 0.8) 308 | ax.set_xlabel('Sample index', fontsize = 15) 309 | ax.set_ylabel('Target log prob', fontsize = 15) 310 | ax.tick_params(axis='both', which = 'major', labelsize=12) 311 | ax.set_xlim(0,Nsamples) 312 | ax.legend(fontsize=6) 313 | plt.savefig(os.path.join(path_fig,'target_log_prob.png')) 314 | plt.show() 315 | 316 | Npred = 201 317 | x_pred_index = jnp.linspace(-0.7,0.7,Npred) 318 | f_ref = f(x_pred_index) 319 | u_ref = u(x_pred_index) 320 | samples = states # shape (Nchains, Nsamples, num_params) 321 | 322 | 323 | @jit 324 | def get_u_pred(sample): 325 | return pinn.u_pred_map(pinn.unravel(sample),x_pred_index) 326 | 327 | @jit 328 | def get_f_pred(sample): 329 | return pinn.f_pred_map(pinn.unravel(sample),x_pred_index) 330 | 331 | u_pred_ens = np.array([vmap(get_u_pred)(samples[i,:,:]) for i in range(samples.shape[0])]) # (Nchains, Nsamples, Npred) 332 | f_pred_ens = np.array([vmap(get_f_pred)(samples[i,:,:]) for i in range(samples.shape[0])]) 333 | 334 | u_pred_ens_mean = np.mean(u_pred_ens, axis = 1) #(Nchains, Npred) 335 | u_pred_ens_std = np.std(u_pred_ens, axis = 1) 336 | f_pred_ens_mean = np.mean(f_pred_ens, axis = 1) 337 | f_pred_ens_std = np.std(f_pred_ens, axis = 1) 338 | 339 | u_env = np.logical_and( (u_pred_ens_mean < u_ref + 2*u_pred_ens_std), (u_pred_ens_mean > u_ref - 2*u_pred_ens_std) ) 340 | f_env = np.logical_and( (f_pred_ens_mean < f_ref + 2*f_pred_ens_std), (f_pred_ens_mean > f_ref - 2*f_pred_ens_std) ) 341 | 342 | # ============================================================================= 343 | # Posterior Statistics 344 | # ============================================================================= 345 | 346 | for i in range(Nchains): 347 | rl2e_u = rl2e(u_pred_ens_mean[i, :], u_ref) 348 | infe_u = infe(u_pred_ens_mean[i, :], u_ref) 349 | lpp_u = lpp(u_pred_ens_mean[i, :], u_ref, u_pred_ens_std[i, :]) 350 | rl2e_f = rl2e(f_pred_ens_mean[i, :], f_ref) 351 | infe_f = infe(f_pred_ens_mean[i, :], f_ref) 352 | lpp_f = lpp(f_pred_ens_mean[i, :], f_ref, f_pred_ens_std[i, :]) 353 | 354 | print(f'chain {i}:\n') 355 | print('u prediction:\n') 356 | print('Relative RL2 error: {}'.format(rl2e_u)) 357 | print('Absolute inf error: {}'.format(infe_u)) 358 | print('Average standard deviation: {}'.format(np.mean(u_pred_ens_std[i, :]))) 359 | print('log predictive probability: {}'.format(lpp_u)) 360 | print('Percentage of coverage:{}\n'.format(np.sum(u_env[i, :])/Npred)) 361 | 362 | print('f prediction:\n') 363 | print('Relative RL2 error: {}'.format(rl2e_f)) 364 | 
print('Absolute inf error: {}'.format(infe_f)) 365 | print('Average standard deviation: {}'.format(np.mean(f_pred_ens_std[i, :]))) 366 | print('log predictive probability: {}'.format(lpp_f)) 367 | print('Percentage of coverage:{}\n'.format(np.sum(f_env[i, :])/Npred)) 368 | 369 | print(f'chain {i}:\n', file = f_rec) 370 | print('u prediction:\n', file = f_rec) 371 | print('Relative RL2 error: {}'.format(rl2e_u), file = f_rec) 372 | print('Absolute inf error: {}'.format(infe_u), file = f_rec) 373 | print('Average standard deviation: {}'.format(np.mean(u_pred_ens_std[i, :])), file = f_rec) 374 | print('log predictive probability: {}'.format(lpp_u), file = f_rec) 375 | print('Percentage of coverage:{}\n'.format(np.sum(u_env[i, :])/Npred), file = f_rec) 376 | 377 | print('f prediction:\n', file = f_rec) 378 | print('Relative RL2 error: {}'.format(rl2e_f), file = f_rec) 379 | print('Absolute inf error: {}'.format(infe_f), file = f_rec) 380 | print('Average standard deviation: {}'.format(np.mean(f_pred_ens_std[i, :])), file = f_rec) 381 | print('log predictive probability: {}'.format(lpp_f), file = f_rec) 382 | print('Percentage of coverage:{}\n'.format(np.sum(f_env[i, :])/Npred), file = f_rec) 383 | 384 | # fig, ax = plt.subplots(dpi = 300, figsize = (4,4)) 385 | # ax.plot(x_pred_index, u_ref, 'k-', label='Exact') 386 | # mark = [None, None, None] 387 | # linestyle = ['solid', 'dotted', 'dashed'] 388 | # color = ['#2ca02c', '#ff7f0e', '#1f77b4'] # green , orange, blue 389 | # for i, c in enumerate(color): 390 | # ax.plot(x_pred_index, u_pred_ens_mean[i, :], color = c, markersize = 1, markevery=2, markerfacecolor='None', label=f'Chain {i}', alpha = 0.8) 391 | # #ax.scatter(x_pred_index, u_pred_ens_mean[i], marker = mark, label=f'Pred mean - chain{i+1}' , s = 15) 392 | # ax.fill_between(x_pred_index, u_pred_ens_mean[i,:] + 2 * u_pred_ens_std[i,:], u_pred_ens_mean[i,:] - 2 * u_pred_ens_std[i,:], alpha = 0.3) 393 | # #alpha = 0.3, label = f'$95\%$ CI') 394 | # ax.scatter(x_data, y_data, label='Obs' , s = 20, facecolors='none', edgecolors='b') 395 | # ax.set_xlabel('$x$', fontsize=16) 396 | # ax.set_ylabel('$u(x)$', fontsize=16) 397 | # ax.set_xlim(-0.72,0.72) 398 | # ax.set_ylim(-2.5,2.5) 399 | # ax.tick_params(axis='both', which = 'major', labelsize=13) 400 | # ax.legend(fontsize=10, loc = 'upper left') 401 | # fig.tight_layout() 402 | # plt.show() 403 | # plt.savefig(os.path.join(path_fig,'u_pred.png')) 404 | 405 | 406 | # fig, ax = plt.subplots(dpi = 300, figsize = (4,4)) 407 | # ax.plot(x_pred_index, f_ref, 'k-', label='Exact') 408 | # color = ['#2ca02c', '#ff7f0e', '#1f77b4'] # green , orange, blue 409 | # for i, c in enumerate(color): 410 | # ax.plot(x_pred_index, f_pred_ens_mean[i, :], color = c, markersize = 1, markevery=2, markerfacecolor='None', label=f'Chain {i}', alpha = 0.8) 411 | # #ax.scatter(x_pred_index, u_pred_ens_mean[i], marker = mark, label=f'Pred mean - chain{i+1}' , s = 15) 412 | # ax.fill_between(x_pred_index, f_pred_ens_mean[i,:] + 2 * f_pred_ens_std[i,:], f_pred_ens_mean[i,:] - 2 * f_pred_ens_std[i,:], alpha = 0.3) 413 | # #alpha = 0.3, label = f'$95\%$ CI') 414 | # ax.scatter(X_r, y_r, label='Obs' , s = 20, facecolors='none', edgecolors='b') 415 | # ax.set_xlabel('$x$', fontsize=16) 416 | # ax.set_ylabel('$f(x)$', fontsize=16) 417 | # ax.set_xlim(-0.72,0.72) 418 | # ax.set_ylim(-2.5,2.5) 419 | # ax.tick_params(axis='both', which = 'major', labelsize=13) 420 | # ax.legend(fontsize=10, loc = 'upper left') 421 | # fig.tight_layout() 422 | # plt.show() 423 | # 
plt.savefig(os.path.join(path_fig,'f_pred.png')) 424 | 425 | #rhat and ess 426 | rhat = tfp.mcmc.diagnostic.potential_scale_reduction(states.transpose((1,0,2)), independent_chain_ndims=1) 427 | ess = tfp.mcmc.effective_sample_size(states[0], filter_beyond_positive_pairs=True) 428 | rhat_u = tfp.mcmc.diagnostic.potential_scale_reduction(u_pred_ens.transpose((1,0,2)), independent_chain_ndims=1) 429 | rhat_f = tfp.mcmc.diagnostic.potential_scale_reduction(f_pred_ens.transpose((1,0,2)), independent_chain_ndims=1) 430 | 431 | fig, ax = plt.subplots(1, 1, figsize=(4, 4), sharex='col', sharey='col', dpi = 300) 432 | g = sns.histplot(rhat, bins = 50, kde=True, kde_kws = {'gridsize':5000}) 433 | g.tick_params(labelsize=16) 434 | g.set_xlabel("$\hat{r}$", fontsize=18) 435 | g.set_ylabel("Count", fontsize=18) 436 | fig.tight_layout() 437 | plt.savefig(os.path.join(path_fig,'rhat.png')) 438 | plt.show() 439 | 440 | fig, ax = plt.subplots(1, 1, figsize=(4, 4), sharex='col', sharey='col', dpi = 300) 441 | g = sns.histplot(rhat_u, bins = 25, kde=True, kde_kws = {'gridsize':5000}) 442 | g.tick_params(labelsize=16) 443 | g.set_xlabel("$\hat{r}_{u}$", fontsize=18) 444 | g.set_ylabel("Count", fontsize=18) 445 | fig.tight_layout() 446 | plt.savefig(os.path.join(path_fig,'rhat_u.png')) 447 | plt.show() 448 | 449 | fig, ax = plt.subplots(1, 1, figsize=(4, 4), sharex='col', sharey='col', dpi = 300) 450 | g = sns.histplot(rhat_f, bins = 25, kde=True, kde_kws = {'gridsize':5000}) 451 | g.tick_params(labelsize=16) 452 | g.set_xlabel("$\hat{r}_{f}$", fontsize=18) 453 | g.set_ylabel("Count", fontsize=18) 454 | fig.tight_layout() 455 | plt.savefig(os.path.join(path_fig,'rhat_f.png')) 456 | plt.show() 457 | 458 | fig, ax = plt.subplots(1, 1, figsize=(4, 4), sharex='col', sharey='col', dpi = 300) 459 | g = sns.histplot(ess, bins = 50, kde=True, kde_kws = {'gridsize':5000}) 460 | g.tick_params(labelsize=16) 461 | g.set_xlabel("ESS", fontsize=18) 462 | g.set_ylabel("Count", fontsize=18) 463 | fig.tight_layout() 464 | plt.savefig(os.path.join(path_fig,'ess.png')) 465 | plt.show() 466 | 467 | # trace plots 468 | idx_low = np.argmin(rhat) 469 | idx_high = np.argmax(rhat) 470 | samples1 = states[:,:,idx_low] 471 | df1 = pd.DataFrame({'chains': 'chain1', 'indice':np.arange(0, samples1[0,].shape[0], 5), 'trace':samples1[0, ::5]}) 472 | df2 = pd.DataFrame({'chains': 'chain2', 'indice':np.arange(0, samples1[1,].shape[0], 5), 'trace':samples1[1, ::5]}) 473 | df3 = pd.DataFrame({'chains': 'chain3', 'indice':np.arange(0, samples1[2,].shape[0], 5), 'trace':samples1[2, ::5]}) 474 | df = pd.concat([df1, df2, df3], ignore_index=True) 475 | plt.figure(figsize=(4,4)) 476 | g = sns.jointplot(data=df, x='indice', y='trace', xlim=(0, 5000), ylim=(-4, 4), hue='chains', joint_kws={'alpha': 0.6}) 477 | g.ax_joint.tick_params(labelsize=18) 478 | g.ax_joint.set_xlabel("Index", fontsize=24) 479 | g.ax_joint.set_ylabel("Trace", fontsize=24) 480 | g.ax_joint.legend(fontsize=16) 481 | g.ax_marg_x.remove() 482 | #plt.title('Trace plot for parameter with lowest $\hat{r}$') 483 | plt.gcf().set_dpi(300) 484 | plt.savefig(os.path.join(path_fig,'trace_plot_rhat_lowest.png')) 485 | fig.tight_layout() 486 | plt.show() 487 | 488 | samples2 = states[:,:,idx_high] 489 | df1 = pd.DataFrame({'chains': 'chain1', 'indice':np.arange(0, samples2[0, ::].shape[0], 5), 'trace':samples2[0, ::5]}) 490 | df2 = pd.DataFrame({'chains': 'chain2', 'indice':np.arange(0, samples2[1, ::].shape[0], 5), 'trace':samples2[1, ::5]}) 491 | df3 = pd.DataFrame({'chains': 'chain3', 
'indice':np.arange(0, samples2[2, ::].shape[0], 5), 'trace':samples2[2, ::5]}) 492 | df = pd.concat([df1,df2, df3], ignore_index=True) 493 | plt.figure(figsize=(4,4)) 494 | #g = sns.jointplot(data=df, x='indice', y='trace', xlim=(0, 5000), ylim=(-4, 4), hue='chains', joint_kws={'alpha': 1}) 495 | g = sns.jointplot(data=df, x='indice', y='trace', xlim=(0, 5000), ylim=(-4, 4), hue='chains', joint_kws={'alpha': 0.6}) 496 | g.ax_joint.tick_params(labelsize=18) 497 | g.ax_joint.set_xlabel("Index", fontsize=24) 498 | g.ax_joint.set_ylabel("Trace", fontsize=24) 499 | g.ax_joint.legend(fontsize=16) 500 | g.ax_marg_x.remove() 501 | #plt.title('Trace plot for parameter with highest $\hat{r}$') 502 | plt.gcf().set_dpi(300) 503 | plt.savefig(os.path.join(path_fig,'trace_plot_rhat_highest.png')) 504 | fig.tight_layout() 505 | plt.show() 506 | 507 | # ============================================================================= 508 | # Compute Hessian at the mean 509 | # ============================================================================= 510 | cov0 = np.cov(samples[0].T) 511 | cov0_diag = np.diag(cov0) 512 | cov1 = np.cov(samples[1].T) 513 | cov1_diag = np.diag(cov1) 514 | cov2 = np.cov(samples[2].T) 515 | cov2_diag = np.diag(cov2) 516 | 517 | fig, ax = plt.subplots(dpi = 300, figsize = (4,4)) 518 | g = ax.imshow(cov0[:100, :100], interpolation = 'nearest', cmap = 'hot') 519 | plt.colorbar(g) 520 | plt.show() 521 | 522 | fig, ax = plt.subplots(dpi = 300, figsize = (4,4)) 523 | g = ax.imshow(cov1[:100, :100], interpolation = 'nearest', cmap = 'hot') 524 | plt.colorbar(g) 525 | plt.show() 526 | 527 | chain0 = states[0] 528 | chain1 = states[1] 529 | chain2 = states[2] 530 | chain0_m = np.mean(chain0, axis = 0) 531 | chain1_m = np.mean(chain1, axis = 0) 532 | chain2_m = np.mean(chain2, axis = 0) 533 | hess = jax.hessian(target_log_prob_fn) 534 | hess_chain0 = hess(chain0_m) 535 | _, s0, _ = jax.scipy.linalg.svd(jax.scipy.linalg.inv(hess_chain0)) 536 | hess_chain1 = hess(chain1_m) 537 | _, s1, _ = jax.scipy.linalg.svd(jax.scipy.linalg.inv(hess_chain1)) 538 | hess_chain2 = hess(chain2_m) 539 | _, s2, _ = jax.scipy.linalg.svd(jax.scipy.linalg.inv(hess_chain2)) 540 | 541 | s = np.concatenate((s0[np.newaxis, :], s1[np.newaxis, :], s2[np.newaxis, :]), axis = 0) 542 | np.savetxt(os.path.join(path_f,'singular_values_posterior_hessian.out'), s) 543 | 544 | fig, ax = plt.subplots(dpi = 300, figsize = (4,4)) 545 | #mark = [None, 'o', None] 546 | linestyle = ['solid', 'dotted', 'dashed'] 547 | for i, ls in enumerate(linestyle): 548 | ax.plot(s[i], linestyle = ls, marker = None, markersize = 2, markevery= 100, markerfacecolor='None', label=f'chain{i+1}', alpha = 0.8) 549 | ax.set_xlabel('Index', fontsize=16) 550 | ax.set_ylabel('Eigenvalues', fontsize=16) 551 | plt.yscale('log') 552 | ax.tick_params(axis='both', which = 'major', labelsize=13) 553 | ax.legend(fontsize=8) 554 | plt.savefig(os.path.join(path_fig,'singular_values_posterior_hessian.png')) 555 | plt.show() 556 | 557 | fig, ax = plt.subplots(1, 1, figsize=(4, 4), sharex='col', sharey='col', dpi = 300) 558 | g = sns.histplot(chain0_m, bins = 50, kde=True, kde_kws = {'gridsize':5000}) 559 | g = sns.histplot(chain1_m, bins = 50, kde=True, kde_kws = {'gridsize':5000}) 560 | g = sns.histplot(chain2_m, bins = 50, kde=True, kde_kws = {'gridsize':5000}) 561 | g.tick_params(labelsize=16) 562 | g.set_xlabel("Weight", fontsize=18) 563 | g.set_ylabel("Count", fontsize=18) 564 | fig.tight_layout() 565 | plt.savefig(os.path.join(path_fig,'weight.png')) 566 | 
plt.show() 567 | 568 | # ============================================================================= 569 | # Plot posterior space 570 | # ============================================================================= 571 | 572 | 573 | # class RandomCoordinates(object): 574 | # # randomly choose some directions 575 | # def __init__(self, origin): 576 | # self.origin = origin # (num_params,) 577 | # self.v0 = self.normalize( 578 | # random.normal(key = random.PRNGKey(88), shape = self.origin.shape), 579 | # self.origin) 580 | # self.v1 = self.normalize( 581 | # random.normal(key = random.PRNGKey(66), shape = self.origin.shape), 582 | # self.origin) 583 | 584 | # def __call__(self, a, b): 585 | # return a*self.v0 + b * self.v1 + self.origin 586 | 587 | # def normalize(self, weights, origin): 588 | # return weights * jnp.abs(origin)/ jnp.abs(weights) # 589 | 590 | 591 | # class LossSurface(object): 592 | # def __init__(self, loss_fn, coords): 593 | # self.loss_fn = loss_fn 594 | # self.coords = coords 595 | 596 | # def compile(self, range, num_points): 597 | # loss_fn_0d = lambda x, y: self.loss_fn(self.coords(x,y)) 598 | # loss_fn_1d = jax.vmap(loss_fn_0d, in_axes = (0,0), out_axes = 0) 599 | # loss_fn_2d = jax.vmap(loss_fn_1d, in_axes = (0,0), out_axes = 0) 600 | 601 | # self.a_grid = jnp.linspace(-1.0, 1.0, num=num_points) ** 3 * range #(-5, 5) power rate 602 | # self.b_grid = jnp.linspace(-1.0, 1.0, num=num_points) ** 3 * range 603 | # self.aa, self.bb = jnp.meshgrid(self.a_grid, self.b_grid) 604 | # self.loss_grid = loss_fn_2d(self.aa, self.bb) 605 | 606 | # def project_points(self, points): 607 | # x = jax.vmap(lambda x: jnp.dot(x, self.coords.v0)/jnp.linalg.norm(self.coords.v0), 0, 0)(points) 608 | # y = jax.vmap(lambda y: jnp.dot(y, self.coords.v1)/jnp.linalg.norm(self.coords.v1), 0, 0)(points) 609 | # return x, y 610 | 611 | # def plot(self, levels=30, points = None, ax=None, **kwargs): 612 | # xs = self.a_grid 613 | # ys = self.b_grid 614 | # zs = self.loss_grid 615 | # if ax is None: 616 | # fig, ax = plt.subplots(dpi = 300, **kwargs) 617 | # ax.set_title("Loss Surface") 618 | # ax.set_aspect("equal") 619 | 620 | # # Set Levels 621 | # min_loss = zs.min() 622 | # max_loss = zs.max() 623 | # levels = jnp.linspace( 624 | # max_loss, min_loss, num=levels 625 | # )[::-1] 626 | 627 | # # levels = jnp.exp( 628 | # # jnp.log(min_loss) + 629 | # # jnp.linspace(0., 1.0, num=levels) ** 3 * (jnp.log(max_loss))- jnp.log(min_loss)) 630 | 631 | # # Create Contour Plot 632 | # CS = ax.contourf( 633 | # xs, 634 | # ys, 635 | # zs, 636 | # levels=levels, 637 | # cmap= 'magma', 638 | # linewidths=0.75, 639 | # norm = mpl.colors.Normalize(vmin = min_loss, vmax = max_loss), 640 | # ) 641 | # for i in points: 642 | # #origin_x, origin_y = self.project_points(self.coords.origin) 643 | # point_x, point_y = self.project_points(i) 644 | # ax.scatter(point_x, point_y, s = 20) 645 | # #ax.scatter(origin_x, origin_y, s = 1, c = 'r', marker = 'x') 646 | # ax.clabel(CS, fontsize=8, fmt="%1.2f") 647 | # #plt.colorbar(CS) 648 | # plt.show() 649 | # return ax 650 | 651 | # coords = RandomCoordinates(chain0_m) 652 | # loss_surface = LossSurface(target_log_prob_fn, coords) 653 | # loss_surface.compile(range = 5, num_points= 500) 654 | # ax = loss_surface.plot(levels = 15, points = [chain1_m, chain2_m]) 655 | # ============================================================================= 656 | # Plot different chains 657 | # ============================================================================= 658 | 659 | 
u_pred_ens = np.array([vmap(get_u_pred)(samples[i,:,:]) for i in range(samples.shape[0])]) 660 | f_pred_ens = np.array([vmap(get_f_pred)(samples[i,:,:]) for i in range(samples.shape[0])]) 661 | 662 | u_pred_ens_mean = np.mean(u_pred_ens, axis = 1) 663 | u_pred_ens_std = np.std(u_pred_ens, axis = 1) 664 | f_pred_ens_mean = np.mean(f_pred_ens, axis = 1) 665 | f_pred_ens_std = np.std(f_pred_ens, axis = 1) 666 | 667 | fig, ax = plt.subplots(dpi=300, figsize=(4,4)) 668 | ax.plot(x_pred_index, u_ref, 'k-', label='Exact', zorder=5) # Higher zorder to ensure the line is on top 669 | color = ['#2ca02c', '#ff7f0e', '#1f77b4'] # green, orange, blue 670 | linestyle = ['solid', 'dashdot', 'dashed'] 671 | # Adjust the zorder for fill_between 672 | zorders_fill = [1, 2, 3] # blue highest, then orange, then green 673 | # Plot lines and fill regions 674 | for i, (c, ls, z) in enumerate(zip(color, linestyle, zorders_fill)): 675 | ax.plot(x_pred_index, u_pred_ens_mean[i, :], color=c, linestyle=ls, markersize=1, markevery=2, markerfacecolor='None', label=f'Chain {i}', alpha=0.8, zorder=z+1) 676 | ax.fill_between(x_pred_index, u_pred_ens_mean[i,:] + 2 * u_pred_ens_std[i,:], u_pred_ens_mean[i,:] - 2 * u_pred_ens_std[i,:], color=color[i], alpha=0.4, zorder=z) 677 | ax.scatter(x_data, y_data, label='Obs' , s = 20, facecolors='none', edgecolors='b', zorder=6) # Higher zorder to ensure the scatter is on top 678 | ax.set_xlabel('$x$', fontsize=16) 679 | ax.set_ylabel('$u(x)$', fontsize=16) 680 | ax.set_xlim(-0.72,0.72) 681 | ax.set_ylim(-2.5,2.5) 682 | ax.tick_params(axis='both', which='major', labelsize=13) 683 | ax.legend(fontsize=10, loc='upper left') 684 | fig.tight_layout() 685 | plt.savefig(os.path.join(path_fig,f'1D_nonlinear_poisson_HMC_Nres_{Nres}_sigma_{sigma}_Nsamples_{Nsamples}_upred.png')) 686 | plt.show() 687 | 688 | fig, ax = plt.subplots(dpi=300, figsize=(4,4)) 689 | ax.plot(x_pred_index, f_ref, 'k-', label='Exact', zorder=5) # Higher zorder to ensure the line is on top 690 | color = ['#2ca02c', '#ff7f0e', '#1f77b4'] # green, orange, blue 691 | linestyle = ['solid', 'dashdot', 'dashed'] 692 | # Adjust the zorder for fill_between 693 | zorders_fill = [1, 2, 3] # blue highest, then orange, then green 694 | # Plot lines and fill regions 695 | for i, (c, ls, z) in enumerate(zip(color, linestyle, zorders_fill)): 696 | ax.plot(x_pred_index, f_pred_ens_mean[i, :], color=c, linestyle=ls, markersize=1, markevery=2, markerfacecolor='None', label=f'Chain {i}', alpha=0.8, zorder=z+1) 697 | ax.fill_between(x_pred_index, f_pred_ens_mean[i,:] + 2 * f_pred_ens_std[i,:], f_pred_ens_mean[i,:] - 2 * f_pred_ens_std[i,:], color=color[i], alpha=0.4, zorder=z) 698 | ax.scatter(X_r, y_r, label='Obs', s=20, facecolors='none', edgecolors='b', zorder=6) # Higher zorder to ensure the scatter is on top 699 | ax.set_xlabel('$x$', fontsize=16) 700 | ax.set_ylabel('$f(x)$', fontsize=16) 701 | ax.set_xlim(-0.72,0.72) 702 | ax.set_ylim(-2.5,2.5) 703 | ax.tick_params(axis='both', which='major', labelsize=13) 704 | ax.legend(fontsize=10, loc='upper left') 705 | fig.tight_layout() 706 | plt.savefig(os.path.join(path_fig,f'1D_nonlinear_poisson_HMC_Nres_{Nres}_sigma_{sigma}_Nsamples_{Nsamples}_fpred.png')) 707 | plt.show() 708 | 709 | 710 | # fig, ax = plt.subplots(dpi = 300, figsize = (4,4)) 711 | # ax.plot(x_pred_index, u_ref, 'k-', label='Exact') 712 | # ax.plot(x_pred_index, u_pred_ens_mean, markersize = 1, markevery=2, markerfacecolor='None', label= f'HMC mean', alpha = 0.8) 713 | # ax.fill_between(x_pred_index, u_pred_ens_mean + 2 * 
u_pred_ens_std, u_pred_ens_mean - 2 * u_pred_ens_std, 714 | # alpha = 0.3, label = r'$95 \% $ CI') 715 | # ax.scatter(data[:,0], data[:,1], label='Obs' , s = 20, facecolors='none', edgecolors='b') 716 | # ax.set_xlabel('$x$', fontsize=16) 717 | # ax.set_ylabel('$u(x)$', fontsize=16) 718 | # ax.set_xlim(-0.72,0.72) 719 | # ax.set_ylim(-2.5,2.5) 720 | # ax.tick_params(axis='both', which = 'major', labelsize=13) 721 | # ax.legend(fontsize=10) 722 | # fig.tight_layout() 723 | # plt.savefig(os.path.join(path_fig,f'1D_nonlinear_poisson_HMC_mean_Nres_{Nres}_sigma_{sigma}_Nsamples_{Nsamples}_upred.png')) 724 | # plt.show() 725 | 726 | # fig, ax = plt.subplots(dpi = 300, figsize = (4,4)) 727 | # r_ref = f(x_pred_index) 728 | # ax.plot(x_pred_index, r_ref, 'k-', label='Exact') 729 | # ax.plot(x_pred_index, f_pred_ens_mean, markersize = 1, markevery=2, markerfacecolor='None', label=f'HMC mean', alpha = 0.8) 730 | # ax.fill_between(x_pred_index, f_pred_ens_mean + 2 * f_pred_ens_std, f_pred_ens_mean - 2 * f_pred_ens_std, 731 | # alpha = 0.3, label = r'$95 \% $ CI') 732 | # ax.scatter(Dres[:,0], Dres[:,1], label='Obs' , s = 20, facecolors='none', edgecolors='b') 733 | # ax.set_xlabel('$x$', fontsize=16) 734 | # ax.set_xlim(-0.72,0.72) 735 | # ax.set_ylim(-2.5,2.5) 736 | # ax.set_ylabel('$f(x)$', fontsize=16) 737 | # ax.tick_params(axis='both', which = 'major', labelsize=13) 738 | # ax.legend(fontsize=10, loc= 'upper left') 739 | # fig.tight_layout() 740 | # plt.savefig(os.path.join(path_fig,f'1D_nonlinear_poisson_HMC_mean_Nres_{Nres}_sigma_{sigma}_Nsamples_{Nsamples}_fpred.png')) 741 | # plt.show() 742 | 743 | 744 | f_rec.close() 745 | 746 | 747 | sns.set(style="white", rc={"axes.facecolor": (0, 0, 0, 0)}) 748 | 749 | # Generating data for a normal distribution 750 | x = np.linspace(-4, 4, 1000) 751 | y = np.exp(-0.5 * x**2) / np.sqrt(2 * np.pi) 752 | 753 | # Creating the plot 754 | fig, ax = plt.subplots(dpi = 100) 755 | ax.plot(x, y, color='black') # black solid line for the normal distribution 756 | ax.fill_between(x, y, color='#C5E0B4') # filling below the line with yellow color 757 | 758 | # Removing axis 759 | ax.set_axis_off() 760 | 761 | plt.show() 762 | 763 | # x = np.linspace(-5, 5, 100) 764 | # y = np.linspace(-5, 5, 100) 765 | # X, Y = np.meshgrid(x, y) 766 | # Z = np.exp(-0.5 * ((X - 1)**2 + (Y - 1)**2) / 1) + np.exp(-0.5 * ((X + 2)**2 + (Y + 2)**2) / 0.5) + np.exp(-0.2 * ((X - 1.5)**2 + (Y - 2)**2) / 0.2) + np.exp(-0.7 * ((X + 3)**2 + (Y + 3)**2) / 1.2) 767 | 768 | # # Plotting 769 | # fig = plt.figure(dpi = 200) 770 | # ax = fig.add_subplot(111, projection='3d') 771 | # ax.plot_surface(X, Y, Z, rstride=1, cstride=1, color='#FFE699', edgecolor='none') 772 | 773 | # ax.set_axis_off() # Remove axes for clean visualization 774 | 775 | # plt.show() -------------------------------------------------------------------------------- /1D_nonlinear_Poisson/1D_nonlinear_poisson_SVGD_norm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Sun Jan 28 15:11:25 2024 5 | 6 | @author: yifeizong 7 | """ 8 | 9 | import jax 10 | import os 11 | import jax.numpy as jnp 12 | from jax import random, grad, vmap, jit 13 | from jax.flatten_util import ravel_pytree 14 | from jax.example_libraries import optimizers 15 | 16 | import numpy as np 17 | import matplotlib.pyplot as plt 18 | import matplotlib as mpl 19 | import scipy.linalg as spl 20 | from scipy.spatial.distance import pdist, squareform 21 | import 
seaborn as sns 22 | import pandas as pd 23 | from tensorflow_probability.substrates import jax as tfp 24 | tfd = tfp.distributions 25 | 26 | import itertools 27 | import argparse 28 | from functools import partial 29 | from tqdm import trange 30 | from time import perf_counter 31 | 32 | #command line argument parser 33 | parser = argparse.ArgumentParser(description="1D nonlinear Poisson with SVGD") 34 | parser.add_argument( 35 | "--rand_seed", 36 | type=int, 37 | default=8888, 38 | help="random seed") 39 | parser.add_argument( 40 | "--sigma", 41 | type=float, 42 | default=0.1, 43 | help="Data uncertainty") 44 | parser.add_argument( 45 | "--sigma_r", 46 | type=float, 47 | default=0.1, 48 | help="Aleatoric uncertainty of the residual") 49 | parser.add_argument( 50 | "--sigma_d", 51 | type=float, 52 | default=0.1118034, 53 | help="Aleatoric uncertainty of the data") 54 | parser.add_argument( 55 | "--sigma_p", 56 | type=float, 57 | default=4.107919, 58 | help="Prior std") 59 | parser.add_argument( 60 | "--Nres", 61 | type=int, 62 | default=32, 63 | help="Number of residual points") 64 | parser.add_argument( 65 | "--Nsamples", 66 | type=int, 67 | default=1000, 68 | help="Number of Posterior samples") 69 | parser.add_argument( 70 | "--nIter", 71 | type=int, 72 | default=50000, 73 | help="Number of SVGD training iterations") 74 | args = parser.parse_args() 75 | 76 | #Define parameters 77 | layers_u = [1, 50, 50, 1] 78 | lbt = np.array([-0.7]) 79 | ubt = np.array([0.7]) 80 | lamb = 0.01 81 | k = 0.7 82 | dataset = dict() 83 | rand_seed = args.rand_seed 84 | Nres = args.Nres 85 | sigma = args.sigma 86 | sigma_r = args.sigma_r 87 | sigma_d = args.sigma_d 88 | sigma_p = args.sigma_p 89 | # sigma = 0.01 90 | # sigma_r = 0.01 91 | # sigma_d = 0.01118 92 | # sigma_p = 0.41079 93 | Nsamples = args.Nsamples 94 | nIter = args.nIter 95 | num_print = 20 96 | bandwidth = -1 97 | path_f = f'1D_nonlinear_poisson_SVGD_Nres_{Nres}_sigma_{sigma}_Nsamples_{Nsamples}_nIter_{nIter}_bandwidth_{bandwidth}' 98 | path_fig = os.path.join(path_f,'figures') 99 | if not os.path.exists(path_f): 100 | os.makedirs(path_f) 101 | if not os.path.exists(path_fig): 102 | os.makedirs(path_fig) 103 | f_rec = open(os.path.join(path_f,'record.out'), 'a+') 104 | 105 | def u(x): 106 | return jnp.sin(6*x)**3 107 | 108 | def f(x): 109 | return lamb*(-108*jnp.sin(6*x)**3 + 216*jnp.sin(6*x)*jnp.cos(6*x)**2) + k*jnp.tanh(u(x)) 110 | 111 | rl2e = lambda yest, yref : spl.norm(yest - yref, 2) / spl.norm(yref, 2) 112 | infe = lambda yest, yref : spl.norm(yest - yref, np.inf) 113 | lpp = lambda h, href, sigma: np.sum( -(h - href)**2/(2*sigma**2) - 1/2*np.log( 2*np.pi) - 2*np.log(sigma)) 114 | 115 | #create noisy boundary data 116 | 117 | np.random.seed(rand_seed) 118 | x_data = np.array([lbt[0], ubt[0]])[:,np.newaxis] 119 | y_data = np.array([u(lbt[0]), u(ubt[0])])[:,np.newaxis].astype(np.float32) + np.random.normal(0,sigma,(2,1)).astype(np.float32) 120 | data = jnp.concatenate([x_data,y_data], axis=1) 121 | dataset.update({'data': data}) 122 | 123 | #create noisy forcing sampling 124 | X_r = np.linspace(lbt[0], ubt[0], Nres) 125 | X_r = jnp.sort(X_r, axis = 0)[:,np.newaxis] 126 | y_r = f(X_r) + np.random.normal(0,sigma,(Nres,1)) 127 | Dres = jnp.asarray(jnp.concatenate([X_r,y_r], axis=1)) 128 | dataset.update({'res': Dres}) 129 | 130 | # Define FNN 131 | def FNN(layers, activation=jnp.tanh): 132 | 133 | def init(prng_key): #return a list of (W,b) tuples 134 | def init_layer(key, d_in, d_out): 135 | key1, key2 = random.split(key) 136 | glorot_stddev = 1.0 
/ jnp.sqrt((d_in + d_out) / 2.) 137 | W = glorot_stddev * random.normal(key1, (d_in, d_out)) 138 | b = jnp.zeros(d_out) 139 | return W, b 140 | key, *keys = random.split(prng_key, len(layers)) 141 | params = list(map(init_layer, keys, layers[:-1], layers[1:])) 142 | return params 143 | 144 | def forward(params, inputs): 145 | Z = inputs 146 | for W, b in params[:-1]: 147 | outputs = jnp.dot(Z, W) + b 148 | Z = activation(outputs) 149 | W, b = params[-1] 150 | outputs = jnp.dot(Z, W) + b 151 | return outputs 152 | 153 | return init, forward 154 | 155 | # Define the model 156 | class PINN(): 157 | def __init__(self, key, layers, dataset, lbt, ubt, lamb, k, sigma_r, sigma_d, sigma_p): 158 | 159 | self.lbt = lbt #domain lower corner 160 | self.ubt = ubt #domain upper corner 161 | self.k = k 162 | self.lamb = lamb 163 | self.scale_coe = 0.5 164 | self.scale = 2 * self.scale_coe / (self.ubt - self.lbt) 165 | self.sigma_r = sigma_r 166 | self.sigma_d = sigma_d 167 | self.sigma_p = sigma_p 168 | 169 | # Prepare normalized training data 170 | self.dataset = dataset 171 | self.X_res, self.y_res = self.normalize(dataset['res'][:,0:1]), dataset['res'][:,1:2] 172 | self.X_data, self.y_data = self.normalize(dataset['data'][:,0:1]), dataset['data'][:,1:2] 173 | 174 | # Initalize the network 175 | self.init, self.forward = FNN(layers, activation=jnp.tanh) 176 | self.params = self.init(key) 177 | raveled_params, self.unravel = ravel_pytree(self.params) 178 | self.num_params = raveled_params.shape[0] 179 | 180 | self.itercount = itertools.count() 181 | self.log_prob_log = [] 182 | self.u_rl2e_log = [] 183 | self.u_lpp_log = [] 184 | self.f_rl2e_log = [] 185 | self.f_lpp_log = [] 186 | 187 | # Evaluate the network and the residual over the grid 188 | self.u_pred_map = vmap(self.predict_u, (None, 0)) 189 | self.f_pred_map = vmap(self.predict_f, (None, 0)) 190 | self.rl2e = lambda yest, yref : spl.norm(yest - yref, 2) / spl.norm(yref, 2) 191 | self.lpp = lambda h, href, sigma: np.sum(-(h - href)**2/(2*sigma**2) - 1/2*np.log( 2*np.pi) - 2*np.log(sigma)) 192 | 193 | def normalize(self, X): 194 | if X.shape[1] == 1: 195 | return 2.0 * self.scale_coe * (X - self.lbt[0:1])/(self.ubt[0:1] - self.lbt[0:1]) - self.scale_coe 196 | if X.shape[1] == 2: 197 | return 2.0 * self.scale_coe * (X - self.lbt[0:2])/(self.ubt[0:2] - self.lbt[0:2]) - self.scale_coe 198 | if X.shape[1] == 3: 199 | return 2.0 * self.scale_coe * (X - self.lbt)/(self.ubt - self.lbt) - self.scale_coe 200 | 201 | @partial(jit, static_argnums=(0,)) 202 | def u_net(self, params, x): 203 | inputs = jnp.hstack([x]) 204 | outputs = self.forward(params, inputs) 205 | return outputs[0] 206 | 207 | @partial(jit, static_argnums=(0,)) 208 | def res_net(self, params, x): 209 | u = self.u_net(params, x) 210 | u_xx = grad(grad(self.u_net, argnums=1), argnums=1)(params, x)*self.scale[0]**2 211 | return self.lamb*u_xx + self.k*jnp.tanh(u) 212 | 213 | @partial(jit, static_argnums=(0,)) 214 | def predict_u(self, params, x): 215 | # Normalize input first, and then predict 216 | x = 2.0 * self.scale_coe * (x - self.lbt[0])/(self.ubt[0] - self.lbt[0]) - self.scale_coe 217 | return self.u_net(params, x) 218 | 219 | @partial(jit, static_argnums=(0,)) 220 | def predict_f(self, params, x): 221 | # Normalize input first, and then predict 222 | x = 2.0 * self.scale_coe * (x - self.lbt[0])/(self.ubt[0] - self.lbt[0]) - self.scale_coe 223 | return self.res_net(params, x) 224 | 225 | @partial(jit, static_argnums=(0,)) 226 | def u_pred_vector(self, params): 227 | u_pred_vec = 
vmap(self.u_net, (None, 0))(self.unravel(params), self.X_data[:,0]) 228 | return u_pred_vec 229 | 230 | @partial(jit, static_argnums=(0,)) 231 | def f_pred_vector(self, params): 232 | f_pred_vec = vmap(self.res_net, (None, 0))(self.unravel(params), self.X_res[:,0]) 233 | return f_pred_vec 234 | 235 | @partial(jit, static_argnums=(0,)) 236 | def target_log_prob(self, theta): 237 | prior = jnp.sum(-(theta)**2/(2*self.sigma_p**2)) 238 | r_likelihood = jnp.sum(-(y_r.ravel() - self.f_pred_vector(theta))**2/(2*self.sigma_r**2)) 239 | u_likelihood = jnp.sum(-(y_data.ravel() - self.u_pred_vector(theta))**2/(2*self.sigma_d**2)) 240 | return prior + r_likelihood + u_likelihood 241 | 242 | @partial(jit, static_argnums=(0,)) 243 | def grad_log_prob(self, theta): 244 | return jax.value_and_grad(self.target_log_prob, argnums = 0)(theta)[1] 245 | 246 | def median_trick_h(self, theta): 247 | ''' 248 | The scipy one seems even faster and memory efficient 249 | 250 | ''' 251 | sq_dist = pdist(theta) 252 | pairwise_dists = squareform(sq_dist) 253 | h = np.median(pairwise_dists)**2 254 | h = np.sqrt(0.5 * h / np.log(theta.shape[0]+1)) 255 | return h 256 | 257 | @partial(jit, static_argnums=(0,)) 258 | def rbf_kernel(self, theta1, theta2, h): 259 | ''' 260 | Evaluate the rbf kernel k(x, x') = exp(-|x - x'|^2/(2h^2)) 261 | input: theta1, theta2 are 1d array of parameters, 262 | h is correlation length 263 | output: a scalar value of kernel evaluation 264 | ''' 265 | # here theta1 and theta2 are 1d-array of parameters 266 | return jnp.exp(-((theta1 - theta2)**2).sum(axis=-1) / (2 * h**2)) 267 | 268 | @partial(jit, static_argnums=(0,)) 269 | def compute_kernel_matrix(self, theta, h): 270 | return vmap(vmap(lambda x, y: self.rbf_kernel(x, y, h), in_axes=(None, 0)), in_axes=(0, None))(theta, theta) 271 | 272 | @partial(jit, static_argnums=(0,)) 273 | def kernel_and_grad(self, theta, h): 274 | ''' 275 | input theta: (Nsamples, Nparams) 276 | h is correlation length 277 | output: K: #(Nsamples, Nsamples) 278 | grad_K: #(Nsamples, Nparams) 279 | ''' 280 | K = self.compute_kernel_matrix(theta, h) #(Nsamples, Nsamples) 281 | grad_K = jnp.sum(jnp.einsum('ijk,ij->ijk', theta - theta[:, None, :], K), axis = 0)/ (h**2) 282 | return (K, grad_K) 283 | 284 | @partial(jit, static_argnums=(0,)) 285 | def svgd_step(self, i, opt_state, h): 286 | theta = self.get_params(opt_state) 287 | grad_logprob = vmap(self.grad_log_prob)(theta) 288 | K, grad_K = self.kernel_and_grad(theta, h) 289 | phi = -(jnp.einsum('ij, jk->ik', K, grad_logprob)/theta.shape[0] + grad_K) #(Nsamples, Nparams) 290 | return self.opt_update(i, phi, opt_state) 291 | 292 | def svgd_train(self, key, Nsamples, nIter, num_print, bandwidth, u_ref, f_ref): 293 | 294 | new_key, subkey = random.split(key, 2) 295 | init_state = random.normal(subkey , (Nsamples, self.num_params)) 296 | 297 | x_test = jnp.linspace(-0.7,0.7,101) 298 | u_pred = vmap(lambda sample: self.u_pred_map(self.unravel(sample),x_test)) 299 | f_pred = vmap(lambda sample: self.f_pred_map(self.unravel(sample),x_test)) 300 | u = u_ref(x_test) 301 | f = f_ref(x_test) 302 | 303 | lr = optimizers.exponential_decay(1e-4, decay_steps=1000, decay_rate=0.9) 304 | #lr = 1e-4 305 | self.opt_init, \ 306 | self.opt_update, \ 307 | self.get_params = optimizers.adam(lr) 308 | self.opt_state = self.opt_init(init_state) 309 | 310 | ts = perf_counter() 311 | pbar = trange(nIter) 312 | 313 | for it in pbar: 314 | self.current_count = next(self.itercount) 315 | theta = self.get_params(self.opt_state) 316 | h = bandwidth if 
bandwidth > 0 else self.median_trick_h(theta) 317 | self.opt_state = self.svgd_step(self.current_count, self.opt_state, h) 318 | 319 | if it % num_print == 0: 320 | 321 | log_prob = jnp.mean(vmap(self.target_log_prob)(theta)) 322 | u_test_ens = u_pred(theta) 323 | f_test_ens = f_pred(theta) 324 | u_test_mean = jnp.mean(u_test_ens, axis = 0) 325 | u_test_std = jnp.std(u_test_ens, axis = 0) 326 | rl2e_u = self.rl2e(u_test_mean, u) 327 | lpp_u = self.lpp(u_test_mean, u, u_test_std) 328 | 329 | f_test_mean = jnp.mean(f_test_ens, axis = 0) 330 | f_test_std = jnp.std(f_test_ens, axis = 0) 331 | rl2e_f = self.rl2e(f_test_mean, f) 332 | lpp_f = self.lpp(f_test_mean, f, f_test_std) 333 | 334 | self.log_prob_log.append(log_prob) 335 | self.u_rl2e_log.append(rl2e_u) 336 | self.u_lpp_log.append(lpp_u) 337 | self.f_rl2e_log.append(rl2e_f) 338 | self.f_lpp_log.append(lpp_f) 339 | 340 | pbar.set_postfix({'Log prob': log_prob, 341 | 'u_rl2e': rl2e_u, 342 | 'u_lpp':lpp_u, 343 | 'f_rl2e': rl2e_f, 344 | 'f_lpp': lpp_f}) 345 | 346 | timings = perf_counter() - ts 347 | print(f"SVGD: {timings} s") 348 | return self.get_params(self.opt_state) 349 | 350 | key1, key2 = random.split(random.PRNGKey(0), 2) 351 | model = PINN(key2, layers_u, dataset, lbt, ubt, lamb, k, sigma_r, sigma_d, sigma_p) 352 | ts = perf_counter() 353 | samples = model.svgd_train(key2, Nsamples, nIter = nIter, num_print = num_print, bandwidth = bandwidth, u_ref = u, f_ref = f) 354 | timings = perf_counter() - ts 355 | print(f"SVGD: {timings} s") 356 | print(f"SVGD: {timings} s", file = f_rec) 357 | np.savetxt(os.path.join(path_f,f'SVGD_samples_Nres_{Nres}_sigma_{sigma}_Nsamples_{Nsamples}_nIter_{nIter}.out'), samples) 358 | 359 | Npred = 201 360 | x_pred_index = jnp.linspace(-0.7,0.7,Npred) 361 | u_ref = u(x_pred_index) 362 | f_ref = f(x_pred_index) 363 | 364 | @jit 365 | def get_u_pred(sample): 366 | return model.u_pred_map(model.unravel(sample),x_pred_index) 367 | 368 | @jit 369 | def get_f_pred(sample): 370 | return model.f_pred_map(model.unravel(sample),x_pred_index) 371 | 372 | u_pred_ens = vmap(get_u_pred)(samples) 373 | f_pred_ens = vmap(get_f_pred)(samples) 374 | np.savetxt(os.path.join(path_f,'u_pred_ens.out'), u_pred_ens) 375 | np.savetxt(os.path.join(path_f,'f_pred_ens.out'), f_pred_ens) 376 | 377 | u_pred_ens_mean = np.mean(u_pred_ens, axis = 0) 378 | u_pred_ens_std = np.std(u_pred_ens, axis = 0) 379 | f_pred_ens_mean = np.mean(f_pred_ens, axis = 0) 380 | f_pred_ens_std = np.std(f_pred_ens, axis = 0) 381 | 382 | u_env = np.logical_and( (u_pred_ens_mean < u_ref + 2*u_pred_ens_std), (u_pred_ens_mean > u_ref - 2*u_pred_ens_std) ) 383 | f_env = np.logical_and( (f_pred_ens_mean < f_ref + 2*f_pred_ens_std), (f_pred_ens_mean > f_ref - 2*f_pred_ens_std) ) 384 | 385 | # ============================================================================= 386 | # Posterior Statistics 387 | # ============================================================================= 388 | 389 | rl2e_u = rl2e(u_pred_ens_mean, u_ref) 390 | infe_u = infe(u_pred_ens_mean, u_ref) 391 | lpp_u = lpp(u_pred_ens_mean, u_ref, u_pred_ens_std) 392 | rl2e_f = rl2e(f_pred_ens_mean, f_ref) 393 | infe_f = infe(f_pred_ens_mean, f_ref) 394 | lpp_f = lpp(f_pred_ens_mean, f_ref, f_pred_ens_std) 395 | 396 | print('u prediction:\n') 397 | print('Relative RL2 error: {}'.format(rl2e_u)) 398 | print('Absolute inf error: {}'.format(infe_u)) 399 | print('Average standard deviation: {}'.format(np.mean(u_pred_ens_std))) 400 | print('log predictive probability: {}'.format(lpp_u)) 401 | 
print('Percentage of coverage:{}\n'.format(np.sum(u_env)/Npred)) 402 | 403 | print('f prediction:\n') 404 | print('Relative RL2 error: {}'.format(rl2e_f)) 405 | print('Absolute inf error: {}'.format(infe_f)) 406 | print('Average standard deviation: {}'.format(np.mean(f_pred_ens_std))) 407 | print('log predictive probability: {}'.format(lpp_f)) 408 | print('Percentage of coverage:{}\n'.format(np.sum(f_env)/Npred)) 409 | 410 | print('u prediction:\n', file = f_rec) 411 | print('Relative RL2 error: {}'.format(rl2e_u), file = f_rec) 412 | print('Absolute inf error: {}'.format(infe_u), file = f_rec) 413 | print('Average standard deviation: {}'.format(np.mean(u_pred_ens_std)), file = f_rec) 414 | print('log predictive probability: {}'.format(lpp_u), file = f_rec) 415 | print('Percentage of coverage:{}\n'.format(np.sum(u_env)/Npred), file = f_rec) 416 | 417 | print('f prediction:\n', file = f_rec) 418 | print('Relative RL2 error: {}'.format(rl2e_f), file = f_rec) 419 | print('Absolute inf error: {}'.format(infe_f), file = f_rec) 420 | print('Average standard deviation: {}'.format(np.mean(f_pred_ens_std)), file = f_rec) 421 | print('log predictive probability: {}'.format(lpp_f), file = f_rec) 422 | print('Percentage of coverage:{}\n'.format(np.sum(f_env)/Npred), file = f_rec) 423 | 424 | f_rec.close() 425 | 426 | # ============================================================================= 427 | # Plot posterior predictions 428 | # ============================================================================= 429 | 430 | fig, ax = plt.subplots(dpi = 300, figsize = (4,4)) 431 | ax.plot(x_pred_index, u_ref, 'k-', label='Exact') 432 | ax.plot(x_pred_index, u_pred_ens_mean, markersize = 1, markevery=2, markerfacecolor='None', label= 'SVGD mean', alpha = 0.8) 433 | ax.fill_between(x_pred_index, u_pred_ens_mean + 2 * u_pred_ens_std, u_pred_ens_mean - 2 * u_pred_ens_std, 434 | alpha = 0.3, label = r'$95 \% $ CI') 435 | ax.scatter(data[:,0], data[:,1], label='Obs' , s = 20, facecolors='none', edgecolors='b') 436 | ax.set_xlabel('$x$', fontsize=16) 437 | ax.set_ylabel('$u(x)$', fontsize=16) 438 | ax.set_xlim(-0.72,0.72) 439 | ax.set_ylim(-2.5,2.5) 440 | ax.tick_params(axis='both', which = 'major', labelsize=13) 441 | ax.legend(fontsize=10) 442 | fig.tight_layout() 443 | plt.savefig(os.path.join(path_fig,f'1D_nonlinear_poisson_SVGD_Nres_{Nres}_sigma_{sigma}_Nsamples_{Nsamples}_nIter_{nIter}_upred.png')) 444 | plt.show() 445 | 446 | fig, ax = plt.subplots(dpi = 300, figsize = (4,4)) 447 | ax.plot(x_pred_index, f_ref, 'k-', label='Exact') 448 | ax.plot(x_pred_index, f_pred_ens_mean, markersize = 1, markevery=2, markerfacecolor='None', label='SVGD mean', alpha = 0.8) 449 | ax.fill_between(x_pred_index, f_pred_ens_mean + 2 * f_pred_ens_std, f_pred_ens_mean - 2 * f_pred_ens_std, 450 | alpha = 0.3, label = r'$95 \% $ CI') 451 | ax.scatter(Dres[:,0], Dres[:,1], label='Obs' , s = 20, facecolors='none', edgecolors='b') 452 | ax.set_xlabel('$x$', fontsize=16) 453 | ax.set_xlim(-0.72,0.72) 454 | ax.set_ylim(-2.5,2.5) 455 | ax.set_ylabel('$f(x)$', fontsize=16) 456 | ax.tick_params(axis='both', which = 'major', labelsize=13) 457 | ax.legend(fontsize=10, loc= 'upper left') 458 | fig.tight_layout() 459 | plt.savefig(os.path.join(path_fig,f'1D_nonlinear_poisson_SVGD_Nres_{Nres}_sigma_{sigma}_Nsamples_{Nsamples}_nIter_{nIter}_fpred.png')) 460 | plt.show() 461 | 462 | # Log prob plot 463 | t = np.arange(0, nIter, num_print) 464 | fig = plt.figure(constrained_layout=False, figsize=(4, 4), dpi = 300) 465 | ax = 
fig.add_subplot() 466 | ax.plot(t, -np.array(model.log_prob_log), color='blue', label='Negative Log prob') 467 | ax.set_yscale('log') 468 | ax.set_ylabel('Loss', fontsize = 16) 469 | ax.set_xlabel('Epochs', fontsize = 16) 470 | ax.legend(loc='upper right', fontsize = 14) 471 | fig.tight_layout() 472 | fig.savefig(os.path.join(path_fig,'loss.png')) 473 | 474 | t = np.arange(0, nIter, num_print) 475 | fig = plt.figure(constrained_layout=False, figsize=(4, 4), dpi = 300) 476 | ax = fig.add_subplot() 477 | ax.plot(t, np.array(model.u_rl2e_log), label='u') 478 | ax.plot(t, np.array(model.f_rl2e_log), label='f') 479 | ax.set_ylabel('relative L2 error', fontsize = 16) 480 | ax.set_xlabel('Epochs', fontsize = 16) 481 | ax.legend(loc='upper right', fontsize = 14) 482 | fig.tight_layout() 483 | fig.savefig(os.path.join(path_fig,'test_rl2e.png')) 484 | 485 | fig, ax = plt.subplots(dpi = 300, figsize = (4,4)) 486 | ax.plot(x_pred_index, u_ref, 'k-', label='Exact') 487 | for i in range(1, 1000, 50): 488 | ax.plot(x_pred_index, get_u_pred(samples[i]), alpha = 0.5) 489 | ax.fill_between(x_pred_index, u_pred_ens_mean + 2 * u_pred_ens_std, u_pred_ens_mean - 2 * u_pred_ens_std, 490 | alpha = 0.3, label = r'$95 \% $ CI') 491 | ax.scatter(data[:,0], data[:,1], label='Obs' , s = 20, facecolors='none', edgecolors='b') 492 | ax.set_xlabel('$x$', fontsize=16) 493 | ax.set_xlim(-0.72,0.72) 494 | ax.set_ylim(-2.5,2.5) 495 | ax.set_ylabel('$u(x)$', fontsize=16) 496 | ax.tick_params(axis='both', which = 'major', labelsize=13) 497 | ax.legend(fontsize=10, loc= 'upper left') 498 | fig.tight_layout() 499 | fig.savefig(os.path.join(path_fig,'u_realizations.png')) 500 | plt.show() 501 | 502 | fig, ax = plt.subplots(dpi = 300, figsize = (4,4)) 503 | ax.plot(x_pred_index, f_ref, 'k-', label='Exact') 504 | for i in range(1, 1000, 50): 505 | ax.plot(x_pred_index, get_f_pred(samples[i]), alpha = 0.5) 506 | ax.fill_between(x_pred_index, f_pred_ens_mean + 2 * f_pred_ens_std, f_pred_ens_mean - 2 * f_pred_ens_std, 507 | alpha = 0.3, label = r'$95 \% $ CI') 508 | ax.scatter(Dres[:,0], Dres[:,1], label='Obs' , s = 20, facecolors='none', edgecolors='b') 509 | ax.set_xlabel('$x$', fontsize=16) 510 | ax.set_xlim(-0.72,0.72) 511 | ax.set_ylim(-2.5,2.5) 512 | ax.set_ylabel('$f(x)$', fontsize=16) 513 | ax.tick_params(axis='both', which = 'major', labelsize=13) 514 | ax.legend(fontsize=10, loc= 'upper left') 515 | fig.tight_layout() 516 | fig.savefig(os.path.join(path_fig,'f_realizations.png')) 517 | plt.show() -------------------------------------------------------------------------------- /1D_nonlinear_Poisson/1D_nonlinear_poisson_rPINN.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Tue Jun 20 17:55:28 2023 5 | 6 | @author: yifeizong 7 | """ 8 | 9 | #Import dependencies 10 | import jax 11 | import os 12 | import jax.numpy as jnp 13 | from jax import random, grad, vmap, jit 14 | from jax.flatten_util import ravel_pytree 15 | from jax.example_libraries import optimizers 16 | 17 | import numpy as np 18 | import pandas as pd 19 | import seaborn as sns 20 | import matplotlib.pyplot as plt 21 | import matplotlib as mpl 22 | import scipy.linalg as spl 23 | from tensorflow_probability.substrates import jax as tfp 24 | tfd = tfp.distributions 25 | 26 | import itertools 27 | import argparse 28 | from functools import partial 29 | from tqdm import trange 30 | from time import perf_counter 31 | 32 | #command line argument parser 33 | # 
parser = argparse.ArgumentParser(description="1D nonLinear Poisson with randomized PINN") 34 | # parser.add_argument( 35 | # "--rand_seed", 36 | # type=int, 37 | # default=8888, 38 | # help="random seed") 39 | # parser.add_argument( 40 | # "--sigma", 41 | # type=float, 42 | # default=0.1, 43 | # help="Data noise level") 44 | # parser.add_argument( 45 | # "--sigma_r", 46 | # type=float, 47 | # default=0.1, 48 | # help="Aleotoric uncertainty to the residual") 49 | # parser.add_argument( 50 | # "--sigma_b", 51 | # type=float, 52 | # default=0.1, 53 | # help="Aleotoric uncertainty to the boundary data") 54 | # parser.add_argument( 55 | # "--sigma_p", 56 | # type=float, 57 | # default=1, 58 | # help="Prior std") 59 | # parser.add_argument( 60 | # "--Nres", 61 | # type=int, 62 | # default=512, 63 | # help="Number of reisudal points") 64 | # parser.add_argument( 65 | # "--Nsamples", 66 | # type=int, 67 | # default=100, 68 | # help="Number of posterior samples") 69 | # parser.add_argument( 70 | # "--nIter", 71 | # type=int, 72 | # default=8000, 73 | # help="Number of training epochs per realization") 74 | # parser.add_argument( 75 | # "--data_load", 76 | # type=bool, 77 | # default=False, 78 | # help="If to load data") 79 | # parser.add_argument( 80 | # "--method", 81 | # type=str, 82 | # default='rPINN', 83 | # help="Method for Bayesian training") 84 | # parser.add_argument( 85 | # "--model_load", 86 | # type=bool, 87 | # default=False, 88 | # help="If to load existing samples") 89 | # args = parser.parse_args() 90 | 91 | #Define parameters 92 | layers_u = [1, 50, 50, 1] 93 | lbt = np.array([-0.7]) 94 | ubt = np.array([0.7]) 95 | lamb = 0.01 96 | k = 0.7 97 | rand_seed = 8888 98 | Nres = 32 99 | Nb = 2 100 | sigma = 0.1 101 | lambda_r = 20*2700/Nres 102 | lambda_b = 1*2700/Nb 103 | lambda_p = 1 104 | gamma = sigma**2*lambda_r 105 | Nsamples = 5000 106 | nIter = 2000 107 | method = 'DE' 108 | model_load = False 109 | num_print = 200 110 | dataset = dict() 111 | 112 | path_f = f'1D_nonlinear_poisson_{method}_Nres_{Nres}_sigma_{sigma}_Nsamples_{Nsamples}_nIter_{nIter}' 113 | path_fig = os.path.join(path_f,'figures') 114 | if not os.path.exists(path_f): 115 | os.makedirs(path_f) 116 | if not os.path.exists(path_fig): 117 | os.makedirs(path_fig) 118 | f_rec = open(os.path.join(path_f,'record.out'), 'a+') 119 | 120 | def u(x): 121 | return jnp.sin(6*x)**3 122 | 123 | def f(x): 124 | return lamb*(-108*jnp.sin(6*x)**3 + 216*jnp.sin(6*x)*jnp.cos(6*x)**2) + k*jnp.tanh(u(x)) 125 | 126 | rl2e = lambda yest, yref : spl.norm(yest - yref, 2) / spl.norm(yref, 2) 127 | infe = lambda yest, yref : spl.norm(yest - yref, np.inf) 128 | lpp = lambda h, href, sigma: np.sum( -(h - href)**2/(2*sigma**2) - 1/2*np.log( 2*np.pi) - 2*np.log(sigma)) 129 | 130 | #create noisy boundary data 131 | np.random.seed(rand_seed) 132 | x_data = np.array([lbt[0], ubt[0]])[:,np.newaxis] 133 | y_data = np.array([u(lbt[0]), u(ubt[0])])[:,np.newaxis].astype(np.float32) + np.random.normal(0,sigma,(2,1)).astype(np.float32) 134 | data = jnp.concatenate([x_data,y_data], axis=1) 135 | dataset.update({'data': data}) 136 | 137 | #create noisy forcing sampling 138 | X_r = np.linspace(lbt[0], ubt[0], Nres) 139 | X_r = jnp.sort(X_r, axis = 0)[:,np.newaxis] 140 | y_r = f(X_r) + np.random.normal(0,sigma,(Nres,1)).astype(np.float32) 141 | Dres = jnp.asarray(jnp.concatenate([X_r,y_r], axis=1)) 142 | dataset.update({'res': Dres}) 143 | 144 | # Define FNN 145 | def FNN(layers, activation=jnp.tanh): 146 | 147 | def init(prng_key): #return a list of (W,b) 
tuples 148 | def init_layer(key, d_in, d_out): 149 | key1, key2 = random.split(key) 150 | glorot_stddev = 1.0 / jnp.sqrt((d_in + d_out) / 2.) 151 | W = glorot_stddev * random.normal(key1, (d_in, d_out)) 152 | b = jnp.zeros(d_out) 153 | return W, b 154 | key, *keys = random.split(prng_key, len(layers)) 155 | params = list(map(init_layer, keys, layers[:-1], layers[1:])) 156 | return params 157 | 158 | def forward(params, inputs): 159 | Z = inputs 160 | for W, b in params[:-1]: 161 | outputs = jnp.dot(Z, W) + b 162 | Z = activation(outputs) 163 | W, b = params[-1] 164 | outputs = jnp.dot(Z, W) + b 165 | return outputs 166 | 167 | return init, forward 168 | 169 | # Define the model 170 | class PINN(): 171 | def __init__(self, key, layers, dataset, lbt, ubt, lamb, k, sigma, lambda_r, lambda_b, lambda_p, gamma): 172 | 173 | self.lbt = lbt #domain lower corner 174 | self.ubt = ubt #domain upper corner 175 | self.k = k 176 | self.lamb = lamb 177 | self.scale_coe = 0.5 178 | self.scale = 2 * self.scale_coe / (self.ubt - self.lbt) 179 | 180 | # Prepare normalized training data 181 | self.dataset = dataset 182 | self.X_res, self.y_res = self.normalize(dataset['res'][:,0:1]), dataset['res'][:,1:2] 183 | self.X_data, self.y_data = self.normalize(dataset['data'][:,0:1]), dataset['data'][:,1:2] 184 | 185 | # Initalize the network 186 | self.init, self.forward = FNN(layers, activation=jnp.tanh) 187 | self.params = self.init(key) 188 | raveled, self.unravel = ravel_pytree(self.params) 189 | self.num_params = raveled.shape[0] 190 | 191 | # Evaluate the state and the residual over the grid 192 | self.u_pred_map = vmap(self.predict_u, (None, 0)) 193 | self.f_pred_map = vmap(self.predict_f, (None, 0)) 194 | 195 | self.itercount = itertools.count() 196 | self.loss_log = [] 197 | self.loss_likelihood_log = [] 198 | self.loss_dbc_log = [] 199 | self.loss_res_log = [] 200 | 201 | # Optimizer 202 | lr = optimizers.exponential_decay(1e-3, decay_steps=5000, decay_rate=0.9) 203 | self.opt_init, \ 204 | self.opt_update, \ 205 | self.get_params = optimizers.adam(lr) 206 | self.opt_state = self.opt_init(self.params) 207 | 208 | self.lambda_r = lambda_r 209 | self.lambda_b = lambda_b 210 | self.lambda_p = lambda_p 211 | self.gamma = gamma 212 | self.sigma_p = jnp.sqrt(self.gamma) 213 | self.sigma_r = jnp.sqrt(self.gamma/self.lambda_r) 214 | self.sigma_b = jnp.sqrt(self.gamma/self.lambda_b) 215 | 216 | #Define random noise distributions 217 | # for residual term 218 | self.alpha_dist = tfd.Independent(tfd.Normal(loc= jnp.zeros_like(self.y_res.ravel()), 219 | scale= self.sigma_r*jnp.ones_like(self.y_res.ravel())), reinterpreted_batch_ndims = 1) 220 | # for boundary term 221 | self.beta_dist = tfd.Independent(tfd.Normal(loc= jnp.zeros_like(self.y_data.ravel()), 222 | scale= self.sigma_b*jnp.ones_like(self.y_data.ravel())), reinterpreted_batch_ndims = 1) 223 | # for regularization term 224 | self.omega_dist = tfd.Independent(tfd.Normal(loc= jnp.zeros((self.num_params,)), 225 | scale= self.sigma_p*jnp.ones((self.num_params,))), reinterpreted_batch_ndims = 1) 226 | 227 | 228 | # normalize inputs of DNN to [-0.5, 0.5] 229 | def normalize(self, X): 230 | if X.shape[1] == 1: 231 | return 2.0 * self.scale_coe * (X - self.lbt[0:1])/(self.ubt[0:1] - self.lbt[0:1]) - self.scale_coe 232 | if X.shape[1] == 2: 233 | return 2.0 * self.scale_coe * (X - self.lbt[0:2])/(self.ubt[0:2] - self.lbt[0:2]) - self.scale_coe 234 | if X.shape[1] == 3: 235 | return 2.0 * self.scale_coe * (X - self.lbt)/(self.ubt - self.lbt) - self.scale_coe 236 | 237 
| @partial(jit, static_argnums=(0,)) 238 | def u_net(self, params, x): 239 | inputs = jnp.hstack([x]) 240 | outputs = self.forward(params, inputs) 241 | return outputs[0] 242 | 243 | @partial(jit, static_argnums=(0,)) 244 | def res_net(self, params, x): 245 | u = self.u_net(params, x) 246 | u_xx = grad(grad(self.u_net, argnums=1), argnums=1)(params, x)*self.scale[0]**2 247 | return self.lamb*u_xx + self.k*jnp.tanh(u) 248 | 249 | @partial(jit, static_argnums=(0,)) 250 | def predict_u(self, params, x): 251 | # Normalize input first, and then predict 252 | x = 2.0 * self.scale_coe * (x - self.lbt[0])/(self.ubt[0] - self.lbt[0]) - self.scale_coe 253 | return self.u_net(params, x) 254 | 255 | @partial(jit, static_argnums=(0,)) 256 | def predict_f(self, params, x): 257 | # Normalize input first, and then predict 258 | x = 2.0 * self.scale_coe * (x - self.lbt[0])/(self.ubt[0] - self.lbt[0]) - self.scale_coe 259 | return self.res_net(params, x) 260 | 261 | @partial(jit, static_argnums=(0,)) 262 | def loss_dbc(self, params, beta): 263 | u_pred = vmap(self.u_net, (None, 0))(params, self.X_data[:,0]) 264 | loss_bc = jnp.sum((u_pred.flatten() - self.y_data.flatten() - beta)**2) 265 | return loss_bc 266 | 267 | @partial(jit, static_argnums=(0,)) 268 | def loss_res(self, params, alpha): 269 | f_pred = vmap(self.res_net, (None, 0))(params, self.X_res[:,0]) 270 | loss_res = jnp.sum((f_pred.flatten() - self.y_res.flatten() - alpha)**2) 271 | return loss_res 272 | 273 | @partial(jit, static_argnums=(0,)) 274 | def l2_regularizer(self, params, omega): 275 | return jnp.sum((ravel_pytree(params)[0] - omega)**2) 276 | 277 | @partial(jit, static_argnums=(0,)) 278 | def loss(self, params, alpha, beta, omega): 279 | return 1/self.sigma_r**2*self.loss_res(params, alpha) + 1/self.sigma_b**2*self.loss_dbc(params, beta) + \ 280 | 1/self.sigma_p**2*self.l2_regularizer(params, omega) 281 | 282 | @partial(jit, static_argnums=(0,)) 283 | def step(self, i, opt_state, alpha, beta, omega): 284 | params = self.get_params(opt_state) 285 | g = grad(self.loss, argnums=0)(params, alpha, beta, omega) 286 | 287 | return self.opt_update(i, g, opt_state) 288 | 289 | def train(self, nIter, num_print, alpha, beta, omega): 290 | pbar = trange(nIter) 291 | # Main training loop 292 | for it in pbar: 293 | self.current_count = next(self.itercount) 294 | self.opt_state = self.step(self.current_count, self.opt_state, alpha, beta, omega) 295 | 296 | if it % num_print == 0: 297 | params = self.get_params(self.opt_state) 298 | 299 | loss_value = self.loss(params, alpha, beta, omega) 300 | loss_res_value = self.loss_res(params, alpha) 301 | loss_dbc_value = self.loss_dbc(params, beta) 302 | loss_reg_value = self.l2_regularizer(params, omega) 303 | 304 | 305 | pbar.set_postfix({'Loss': loss_value, 306 | 'Loss_res': loss_res_value, 307 | 'Loss_dbc': loss_dbc_value, 308 | 'Loss_reg': loss_reg_value}) 309 | self.loss_log.append(loss_value) 310 | self.loss_likelihood_log.append(loss_res_value + loss_dbc_value) 311 | self.loss_res_log.append(loss_res_value) 312 | self.loss_dbc_log.append(loss_dbc_value) 313 | 314 | def de_sample(self, Nsample, nIter, num_print, key): 315 | #Using deep ensemble methods 316 | params_sample = [] 317 | alpha = beta = omega = 0 318 | 319 | for it in range(Nsample): 320 | key, *keys = random.split(key, 2) 321 | 322 | lr = optimizers.exponential_decay(1e-3, decay_steps=5000, decay_rate=0.9) 323 | self.opt_init, \ 324 | self.opt_update, \ 325 | self.get_params = optimizers.adam(lr) 326 | self.params = self.init(keys[0]) 327 | 
self.opt_state = self.opt_init(self.params) 328 | self.itercount = itertools.count() 329 | self.train(nIter, num_print, alpha, beta, omega) 330 | 331 | params = self.get_params(self.opt_state) 332 | params_sample.append(ravel_pytree(params)[0]) 333 | print(f'{it}-th sample finished') 334 | return jnp.array(params_sample) 335 | 336 | def rpinn_sample(self, Nsample, nIter, num_print, key): 337 | #sample with randomized PINN 338 | params_sample = [] 339 | alpha_sample = [] 340 | beta_sample = [] 341 | omega_sample = [] 342 | 343 | for it in range(Nsample): 344 | key, *keys = random.split(key, 5) 345 | #key, *keys = random.split(key, 4) 346 | alpha = self.alpha_dist.sample(1, keys[0])[0] 347 | beta = self.beta_dist.sample(1, keys[1])[0] 348 | omega = self.omega_dist.sample(1, keys[2])[0] 349 | 350 | lr = optimizers.exponential_decay(1e-3, decay_steps=5000, decay_rate=0.9) 351 | self.opt_init, \ 352 | self.opt_update, \ 353 | self.get_params = optimizers.adam(lr) 354 | self.params = self.init(keys[3]) 355 | self.opt_state = self.opt_init(self.params) 356 | self.itercount = itertools.count() 357 | self.train(nIter, num_print, alpha, beta, omega) 358 | 359 | params = self.get_params(self.opt_state) 360 | params_sample.append(ravel_pytree(params)[0]) 361 | alpha_sample.append(alpha) 362 | beta_sample.append(beta) 363 | omega_sample.append(omega) 364 | print(f'{it}-th sample finished') 365 | return jnp.array(params_sample), jnp.array(alpha_sample),\ 366 | jnp.array(beta_sample), jnp.array(omega_sample) 367 | 368 | # def metropolis_ratio(self, theta_old, theta_new, alpha_old, alpha_new): 369 | # def fn(theta, alpha): 370 | # res = self.r_pred_vector(theta) 371 | # delta = res - alpha 372 | # prod = jnp.einsum('i,i->', res, delta) 373 | # return prod 374 | 375 | # jac_old = jax.jacfwd(self.r_pred_vector, argnums = 0)(theta_old) #jacobian dr/dtheta (Nres, num_params) 376 | # jac_new = jax.jacfwd(self.r_pred_vector, argnums = 0)(theta_new) #jacobian dr/dtheta 377 | # hess_old = jax.hessian(fn, argnums = 0)(theta_old, alpha_old) #hessian 378 | # hess_new = jax.hessian(fn, argnums = 0)(theta_new, alpha_new) 379 | 380 | # #logdet_old = np.linalg.slogdet(sigma_r**2*Sigma + hess.T + jnp.einsum('ik,kj->ij', jac.T, jac))[1] 381 | # #logdet_new = np.linalg.slogdet(sigma_r**2*Sigma + hess.T + jnp.einsum('ik,kj->ij', jac.T, jac))[1] 382 | # #ratio = np.sqrt(np.exp(logdet_new - logdet_old)) 383 | 384 | # det_old = np.linalg.det(sigma_r**2*self.Sigma + hess_old.T + jnp.einsum('ik,kj->ij', jac_old.T, jac_old)) 385 | # det_new = np.linalg.det(sigma_r**2*self.Sigma + hess_new.T + jnp.einsum('ik,kj->ij', jac_new.T, jac_new)) 386 | # ratio = np.sqrt(np.math.abs(det_new))/np.math.sqrt(np.math.abs(det_old)) 387 | 388 | # return ratio 389 | 390 | # def rpinn_sample_metro(self, Nsample, nIter, num_print, key): 391 | # params_sample = [] 392 | # alpha_sample = [] 393 | # beta_sample = [] 394 | # omega_sample = [] 395 | 396 | # for it in range(Nsample): 397 | # key, *keys = random.split(key, 4) 398 | # alpha = self.alpha_dist.sample(1, keys[0])[0] 399 | # beta = self.beta_dist.sample(1, keys[1])[0] 400 | # omega = self.omega_dist.sample(1, keys[2])[0] 401 | 402 | # lr = optimizers.exponential_decay(1e-3, decay_steps=5000, decay_rate=0.9) 403 | # self.opt_init, \ 404 | # self.opt_update, \ 405 | # self.get_params = optimizers.adam(lr) 406 | # #self.params = self.init(keys[3]) 407 | # self.opt_state = self.opt_init(self.params) 408 | # self.itercount = itertools.count() 409 | # self.train(nIter, num_print, alpha, beta, omega) 
410 | 411 | # theta_new = ravel_pytree(self.get_params(self.opt_state))[0] 412 | # ratio = np.abs() 413 | # u = np.random.uniform(low=0.0, high=1.0, size=()) 414 | # params_sample.append() 415 | # alpha_sample.append(alpha) 416 | # beta_sample.append(beta) 417 | # omega_sample.append(omega) 418 | # print(f'{it}-th sample finished') 419 | # return jnp.array(params_sample), jnp.array(alpha_sample),\ 420 | # jnp.array(beta_sample), jnp.array(omega_sample) 421 | 422 | 423 | key = random.PRNGKey(rand_seed) 424 | key, subkey = random.split(key, 2) 425 | model = PINN(key, layers_u, dataset, lbt, ubt, lamb, k, sigma, lambda_r, lambda_b, lambda_p, gamma) 426 | if model_load == False: 427 | if method == 'rPINN': 428 | 429 | ts = perf_counter() 430 | samples, alpha_ens, beta_ens, omega_ens = model.rpinn_sample(Nsamples, nIter = nIter, 431 | num_print = num_print, key = subkey) 432 | timings = perf_counter() - ts 433 | print(f"rPINN: {timings} s") 434 | print(f"rPINN: {timings} s", file = f_rec) 435 | 436 | elif method == 'rPINN_metro': 437 | 438 | ts = perf_counter() 439 | samples, alpha_ens, beta_ens, omega_ens = model.rpinn_sample_metro(Nsamples, nIter = nIter, 440 | num_print = num_print, key = subkey) 441 | timings = perf_counter() - ts 442 | print(f"rPINN-metro: {timings} s") 443 | print(f"rPINN-metro: {timings} s", file = f_rec) 444 | 445 | elif method == 'DE': 446 | 447 | ts = perf_counter() 448 | samples = model.de_sample(Nsamples, nIter = nIter, num_print = num_print, key = subkey) 449 | timings = perf_counter() - ts 450 | print(f"Deep ensemble: {timings} s") 451 | print(f"Deep ensemble: {timings} s", file = f_rec) 452 | 453 | else: 454 | samples, omega_ens = model.rms_sample(Nsamples, nIter = nIter, num_print = num_print, key = subkey) 455 | 456 | np.savetxt(os.path.join(path_f,f'{method}_posterior_samples_Nres_{Nres}_sigma_{sigma}_Nsamples_{Nsamples}_nIter_{nIter}.out'), samples) 457 | else: 458 | samples = np.loadtxt(os.path.join(path_f,f'{method}_posterior_samples_Nres_{Nres}_sigma_{sigma}_Nsamples_{Nsamples}_nIter_{nIter}.out')) 459 | 460 | Npred = 201 461 | x_pred_index = jnp.linspace(-0.7,0.7,Npred) 462 | u_ref = u(x_pred_index) 463 | f_ref = f(x_pred_index) 464 | 465 | @jit 466 | def get_u_pred(sample): 467 | return model.u_pred_map(model.unravel(sample),x_pred_index) 468 | 469 | @jit 470 | def get_f_pred(sample): 471 | return model.f_pred_map(model.unravel(sample),x_pred_index) 472 | 473 | u_pred_ens = vmap(get_u_pred)(samples) 474 | f_pred_ens = vmap(get_f_pred)(samples) 475 | np.savetxt(os.path.join(path_f,'u_pred_ens.out'), u_pred_ens) 476 | np.savetxt(os.path.join(path_f,'f_pred_ens.out'), f_pred_ens) 477 | 478 | u_pred_ens_mean = np.mean(u_pred_ens, axis = 0) 479 | u_pred_ens_std = np.std(u_pred_ens, axis = 0) 480 | f_pred_ens_mean = np.mean(f_pred_ens, axis = 0) 481 | f_pred_ens_std = np.std(f_pred_ens, axis = 0) 482 | 483 | u_env = np.logical_and( (u_pred_ens_mean < u_ref + 2*u_pred_ens_std), (u_pred_ens_mean > u_ref - 2*u_pred_ens_std) ) 484 | f_env = np.logical_and( (f_pred_ens_mean < f_ref + 2*f_pred_ens_std), (f_pred_ens_mean > f_ref - 2*f_pred_ens_std) ) 485 | 486 | # ============================================================================= 487 | # Posterior Statistics 488 | # ============================================================================= 489 | 490 | rl2e_u = rl2e(u_pred_ens_mean, u_ref) 491 | infe_u = infe(u_pred_ens_mean, u_ref) 492 | lpp_u = lpp(u_pred_ens_mean, u_ref, u_pred_ens_std) 493 | rl2e_f = rl2e(f_pred_ens_mean, f_ref) 494 | infe_f = 
infe(f_pred_ens_mean, f_ref) 495 | lpp_f = lpp(f_pred_ens_mean, f_ref, f_pred_ens_std) 496 | 497 | print('u prediction:\n') 498 | print('Relative RL2 error: {}'.format(rl2e_u)) 499 | print('Absolute inf error: {}'.format(infe_u)) 500 | print('Average standard deviation: {}'.format(np.mean(u_pred_ens_std))) 501 | print('log predictive probability: {}'.format(lpp_u)) 502 | print('Percentage of coverage:{}\n'.format(np.sum(u_env)/Npred)) 503 | 504 | print('f prediction:\n') 505 | print('Relative RL2 error: {}'.format(rl2e_f)) 506 | print('Absolute inf error: {}'.format(infe_f)) 507 | print('Average standard deviation: {}'.format(np.mean(f_pred_ens_std))) 508 | print('log predictive probability: {}'.format(lpp_f)) 509 | print('Percentage of coverage:{}\n'.format(np.sum(f_env)/Npred)) 510 | 511 | print('u prediction:\n', file = f_rec) 512 | print('Relative RL2 error: {}'.format(rl2e_u), file = f_rec) 513 | print('Absolute inf error: {}'.format(infe_u), file = f_rec) 514 | print('Average standard deviation: {}'.format(np.mean(u_pred_ens_std)), file = f_rec) 515 | print('log predictive probability: {}'.format(lpp_u), file = f_rec) 516 | print('Percentage of coverage:{}\n'.format(np.sum(u_env)/Npred), file = f_rec) 517 | 518 | print('f prediction:\n', file = f_rec) 519 | print('Relative RL2 error: {}'.format(rl2e_f), file = f_rec) 520 | print('Absolute inf error: {}'.format(infe_f), file = f_rec) 521 | print('Average standard deviation: {}'.format(np.mean(f_pred_ens_std)), file = f_rec) 522 | print('log predictive probability: {}'.format(lpp_f), file = f_rec) 523 | print('Percentage of coverage:{}\n'.format(np.sum(f_env)/Npred), file = f_rec) 524 | 525 | 526 | # ============================================================================= 527 | # Plot posterior predictions 528 | # ============================================================================= 529 | 530 | fig, ax = plt.subplots(dpi = 300, figsize = (4,4)) 531 | ax.plot(x_pred_index, u_ref, 'k-', label='Exact') 532 | # for i in range(1, 5000, 500): 533 | # ax.plot(x_pred_index, get_u_pred(samples[i])) 534 | ax.plot(x_pred_index, u_pred_ens_mean, markersize = 1, markevery=2, markerfacecolor='None', label= f'{method} mean', alpha = 0.8) 535 | ax.fill_between(x_pred_index, u_pred_ens_mean + 2 * u_pred_ens_std, u_pred_ens_mean - 2 * u_pred_ens_std, 536 | alpha = 0.3, label = r'$95 \% $ CI') 537 | ax.scatter(data[:,0], data[:,1], label='Obs' , s = 20, facecolors='none', edgecolors='b') 538 | ax.set_xlabel('$x$', fontsize=16) 539 | ax.set_ylabel('$u(x)$', fontsize=16) 540 | ax.set_xlim(-0.72,0.72) 541 | ax.set_ylim(-2.5,2.5) 542 | ax.tick_params(axis='both', which = 'major', labelsize=13) 543 | ax.legend(fontsize=10) 544 | fig.tight_layout() 545 | plt.savefig(os.path.join(path_fig,f'1D_nonlinear_poisson_{method}_Nres_{Nres}_sigma_{sigma}_Nsamples_{Nsamples}_nIter_{nIter}_upred.png')) 546 | plt.show() 547 | 548 | fig, ax = plt.subplots(dpi = 300, figsize = (4,4)) 549 | r_ref = f(x_pred_index) 550 | ax.plot(x_pred_index, r_ref, 'k-', label='Exact') 551 | # for i in range(1, 5000, 500): 552 | # ax.plot(x_pred_index, get_r_pred(samples[i])) 553 | ax.plot(x_pred_index, f_pred_ens_mean, markersize = 1, markevery=2, markerfacecolor='None', label=f'{method} mean', alpha = 0.8) 554 | ax.fill_between(x_pred_index, f_pred_ens_mean + 2 * f_pred_ens_std, f_pred_ens_mean - 2 * f_pred_ens_std, 555 | alpha = 0.3, label = r'$95 \% $ CI') 556 | ax.scatter(Dres[:,0], Dres[:,1], label='Obs' , s = 20, facecolors='none', edgecolors='b') 557 | 
ax.set_xlabel('$x$', fontsize=16) 558 | ax.set_xlim(-0.72,0.72) 559 | ax.set_ylim(-2.5,2.5) 560 | ax.set_ylabel('$f(x)$', fontsize=16) 561 | ax.tick_params(axis='both', which = 'major', labelsize=13) 562 | ax.legend(fontsize=10, loc= 'upper left') 563 | fig.tight_layout() 564 | plt.savefig(os.path.join(path_fig,f'1D_nonlinear_poisson_{method}_Nres_{Nres}_sigma_{sigma}_Nsamples_{Nsamples}_nIter_{nIter}_fpred.png')) 565 | plt.show() 566 | 567 | 568 | f_rec.close() 569 | 570 | # idx0, idx1, idx2 = np.argsort(theta_mean_hmc)[-3:] 571 | # theta0_hmc, theta1_hmc, theta2_hmc = hmc_samples[:,idx0], hmc_samples[:,idx1], hmc_samples[:,idx2] 572 | # theta0_rpinn, theta1_rpinn, theta2_rpinn = samples[:,idx0], samples[:,idx1], samples[:,idx2] 573 | # df1 = pd.dataFrame({'Method': 'HMC', r'$\theta_0$': theta0_hmc, r'$\theta_1$': theta1_hmc, r'$\xi_2$': theta2_hmc}) 574 | # df2 = pd.dataFrame({'Method': 'rPICKLE', r'$\theta_0$': theta0_rpinn, r'$\theta_1$':theta1_rpinn, r'$\xi_2$': theta2_rpinn}) 575 | # df = pd.concat([df1, df2], ignore_index=True) 576 | # # g = sns.jointplot(data=df, x='xi_0', y='xi_1', hue='method', joint_kws={'alpha': 0.5}, kind = 'kde') 577 | # # g.ax_joint.tick_params(labelsize=16) 578 | # # g.ax_joint.set_xlabel(r"$\xi_0$", fontsize=24) 579 | # # g.ax_joint.set_ylabel(r"$\xi_1$", fontsize=24) 580 | 581 | # plt.figure(figsize=(4,4)) 582 | # g = sns.PairGrid(df, hue='Method') 583 | # g.map_diag(sns.histplot) 584 | # g.map_offdiag(sns.kdeplot) 585 | # g.add_legend(fontsize=18) 586 | # plt.gcf().set_dpi(600) 587 | # g.tight_layout() 588 | # plt.show() 589 | 590 | 591 | # fig, ax = plt.subplots(dpi = 300, figsize = (4,4)) 592 | # r_ref = f(x_pred_index) 593 | # ax.plot(x_pred_index, r_ref, 'k-', label='Exact') 594 | # # for i in range(1, 5000, 500): 595 | # # ax.plot(x_pred_index, get_r_pred(samples[i])) 596 | # ax.plot(x_pred_index, r_pred_ens_mean_hmc, markersize = 1, markevery=2, markerfacecolor='None', label=r'HMC mean', alpha = 0.8) 597 | # ax.fill_between(x_pred_index, r_pred_ens_mean_hmc + 2 * r_pred_ens_std_hmc, r_pred_ens_mean_hmc - 2 * r_pred_ens_std_hmc, 598 | # alpha = 0.3, label = r'$95 \% $ CI') 599 | # #ax.scatter(Dres[:,0], Dres[:,1], label='Obs' , s = 20, facecolors='none', edgecolors='b') 600 | # ax.set_xlabel('$X$', fontsize=16) 601 | # ax.set_xlim(-1.02,1.02) 602 | # ax.set_ylim(-1.2,1.2) 603 | # ax.set_ylabel('$r(x)$', fontsize=16) 604 | # ax.tick_params(axis='both', which = 'major', labelsize=13) 605 | # ax.legend(fontsize=10, loc= 'upper left') 606 | # #plt.savefig(os.path.join(path_fig,'r_pred.png')) 607 | # plt.show() 608 | 609 | 610 | #Hessian specification 611 | 612 | # hess_fn = jax.hessian(target_log_prob_fn) 613 | # hessian0 = hess(np.mean(sample0, axis = 0)) 614 | # _, s0, _ = jax.scipy.linalg.svd(jax.scipy.linalg.inv(hessian0)) 615 | # fig, ax = plt.subplots(dpi = 300, figsize = (4,4)) 616 | # ax.plot(s[i], linestyle = linestyle[i], marker = mark, markersize = 2, markevery= 100, markerfacecolor='None', label=f'chain{i+1}', alpha = 0.8) 617 | # ax.set_xlabel('Index', fontsize=16) 618 | # ax.set_ylabel('Eigenvalues', fontsize=16) 619 | # plt.yscale('log') 620 | # ax.tick_params(axis='both', which = 'major', labelsize=13) 621 | # ax.legend(fontsize=8) 622 | # plt.show() 623 | 624 | # ============================================================================= 625 | # Loss landscape 626 | # ============================================================================= 627 | 628 | 629 | # def normalize_weights(weights, origin): 630 | # return weights* 
jnp.abs(origin)/ jnp.abs(weights) 631 | 632 | # class RandomCoordinates(object): 633 | # def __init__(self, origin): 634 | # self.origin = origin # (num_params,) 635 | # self.v0 = normalize_weights( 636 | # jax.random.normal(key = random.PRNGKey(88), shape = self.origin.shape), 637 | # origin) 638 | # self.v1 = normalize_weights( 639 | # jax.random.normal(key = random.PRNGKey(66), shape = self.origin.shape), 640 | # origin) 641 | 642 | # def __call__(self, a, b): 643 | # return a*self.v0 + b*self.v1 + self.origin 644 | 645 | 646 | # class LossSurface(object): 647 | # def __init__(self, loss_fn, coords): 648 | # self.loss_fn = loss_fn 649 | # self.coords = coords 650 | 651 | # def compile(self, range, num_points): 652 | # loss_fn_0d = lambda x, y: self.loss_fn(self.coords(x,y)) 653 | # loss_fn_1d = jax.vmap(loss_fn_0d, in_axes = (0,0), out_axes = 0) 654 | # loss_fn_2d = jax.vmap(loss_fn_1d, in_axes = (0,0), out_axes = 0) 655 | 656 | # self.a_grid = jnp.linspace(-1.0, 1.0, num=num_points) ** 3 * range 657 | # self.b_grid = jnp.linspace(-1.0, 1.0, num=num_points) ** 3 * range 658 | # self.aa, self.bb = jnp.meshgrid(self.a_grid, self.b_grid) 659 | # self.loss_grid = loss_fn_2d(self.aa, self.bb) 660 | 661 | # def project_points(self, points): 662 | # x = jax.vmap(lambda x: jnp.dot(x, self.coords.v0)/jnp.linalg.norm(self.coords.v0), 0, 0)(points) 663 | # y = jax.vmap(lambda y: jnp.dot(y, self.coords.v1)/jnp.linalg.norm(self.coords.v1), 0, 0)(points) 664 | # return x, y 665 | 666 | # def plot(self, levels=30, points = None, ax=None, **kwargs): 667 | # xs = self.a_grid 668 | # ys = self.b_grid 669 | # zs = self.loss_grid 670 | # if ax is None: 671 | # fig, ax = plt.subplots(dpi = 600, **kwargs) 672 | # ax.set_title("Loss Surface") 673 | # ax.set_aspect("equal") 674 | 675 | # # Set Levels 676 | # min_loss = zs.min() 677 | # max_loss = zs.max() 678 | # # levels = jnp.exp( 679 | # # jnp.linspace( 680 | # # jnp.log(min_loss), jnp.log(max_loss), num=levels 681 | # # ) 682 | # # ) 683 | # levels = jnp.exp( 684 | # jnp.log(min_loss) + 685 | # jnp.linspace(0., 1.0, num=levels) ** 3 * (jnp.log(max_loss))- jnp.log(min_loss)) 686 | # # Create Contour Plot 687 | # CS = ax.contour( 688 | # xs, 689 | # ys, 690 | # zs, 691 | # levels=levels, 692 | # cmap= 'magma', 693 | # linewidths=0.75, 694 | # norm = mpl.colors.LogNorm(vmin=min_loss, vmax=max_loss * 2.0), 695 | # ) 696 | # point_x, point_y = self.project_points(points) 697 | # origin_x, origin_y = self.project_points(self.coords.origin) 698 | # ax.scatter(point_x, point_y, s = 0.25, c = 'g', marker = 'o') 699 | # ax.scatter(origin_x, origin_y, s = 1, c = 'r', marker = 'x') 700 | # ax.clabel(CS, fontsize=8, fmt="%1.2f") 701 | # #plt.colorbar(CS) 702 | # plt.show() 703 | # return ax 704 | 705 | # theta_prior_m = jnp.zeros((pinn.num_params,)) 706 | # theta_prior_std = jnp.ones((pinn.num_params,)) 707 | # prior_dist = tfd.Independent(tfd.Normal(loc= theta_prior_m, scale= theta_prior_std), 708 | # reinterpreted_batch_ndims= 1) 709 | 710 | # def target_log_prob_fn(theta): 711 | # prior = prior_dist.log_prob(theta) 712 | # r_likelihood = jnp.sum( -jnp.log(sigma_r) - jnp.log(2*jnp.pi)/2 -(y_r.ravel() - pinn.r_pred_vector(theta))**2/(2*sigma_r**2)) 713 | # u_likelihood = jnp.sum( -jnp.log(sigma_b) - jnp.log(2*jnp.pi)/2 -(y_data.ravel() - pinn.u_pred_vector(theta))**2/(2*sigma_b**2)) 714 | # return -(prior + r_likelihood + u_likelihood) 715 | 716 | # # def pinn_loss_fn(theta): 717 | # # prior = 1/2*jnp.linalg.norm(theta)**2 718 | # # r_likelihood = jnp.sum( 
(y_r.ravel() - pinn.r_pred_vector(theta))**2/(2*sigma_r**2)) 719 | # # u_likelihood = jnp.sum( (y_data.ravel() - pinn.u_pred_vector(theta))**2/(2*sigma_b**2)) 720 | # # return prior + r_likelihood + u_likelihood 721 | 722 | # optim_params = pinn.deterministic_run(random.PRNGKey(999), 100000, 200) 723 | # coords = RandomCoordinates(optim_params) 724 | # loss_surface = LossSurface(target_log_prob_fn, coords) 725 | # loss_surface.compile(range = 5, num_points= 500) 726 | # ax = loss_surface.plot(levels = 30, points = hmc_samples) 727 | 728 | 729 | -------------------------------------------------------------------------------- /2D_GWF_problem/2D_gwf_BPINN_inverse_HMC.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Wed Jan 3 22:41:24 2024 5 | 6 | @author: yifei_linux 7 | """ 8 | 9 | #Import dependencies 10 | import jax 11 | import os 12 | import jax.numpy as jnp 13 | from jax import random, grad, vmap, jit 14 | from jax.flatten_util import ravel_pytree 15 | from jax.example_libraries import optimizers 16 | # from jax.lib import xla_bridge 17 | 18 | import numpy as np 19 | import pandas as pd 20 | import seaborn as sns 21 | import matplotlib.pyplot as plt 22 | import matplotlib as mpl 23 | import scipy.linalg as spl 24 | from tensorflow_probability.substrates import jax as tfp 25 | tfd = tfp.distributions 26 | 27 | import itertools 28 | import argparse 29 | from functools import partial 30 | from tqdm import trange 31 | from pyDOE import lhs 32 | from time import perf_counter 33 | 34 | #command line argument parser 35 | parser = argparse.ArgumentParser(description="2D Darcy with HMC") 36 | parser.add_argument( 37 | "--rand_seed", 38 | type=int, 39 | default=111, 40 | help="random seed") 41 | parser.add_argument( 42 | "--Nres", 43 | type=int, 44 | default=500, 45 | help="Number of residual points") 46 | parser.add_argument( 47 | "--Nchains", 48 | type=int, 49 | default=3, 50 | help="Number of Posterior chains") 51 | parser.add_argument( 52 | "--Nsamples", 53 | type=int, 54 | default=500, 55 | help="Number of Posterior samples") 56 | parser.add_argument( 57 | "--Nburn", 58 | type=int, 59 | default=100000, 60 | help="Number of burn-in samples") 61 | parser.add_argument( 62 | "--data_load", 63 | type=bool, 64 | default=False, 65 | help="Whether to load data") 66 | args = parser.parse_args() 67 | 68 | # print(xla_bridge.get_backend().platform) 69 | #jax.config.update('jax_platform_name', 'cpu') 70 | print(f'jax is using: {jax.devices()} \n') 71 | 72 | #Define parameters 73 | layers_k = [2, 60, 60, 60, 60, 1] 74 | layers_h = [2, 60, 60, 60, 60, 1] 75 | lbt = np.array([0., 0.]) 76 | ubt = np.array([1., 0.5]) 77 | num_print = 200 78 | Nk = 40 79 | Nh = 40 80 | sigma = 0.1 81 | sigma_r = 0.3536 82 | sigma_d = 0.1 83 | sigma_nbl = 0.0632 84 | sigma_nbb = 0.089 85 | sigma_nbt = 0.089 86 | sigma_dbr = 0.0632 87 | sigma_p = 2.369 88 | rand_seed = args.rand_seed 89 | Nres = args.Nres 90 | Nchains = args.Nchains 91 | Nsamples = args.Nsamples 92 | Nburn = args.Nburn 93 | model_load = args.data_load 94 | dataset = dict() 95 | x = np.linspace(lbt[0], ubt[0], 256) 96 | y = np.linspace(lbt[1], ubt[1], 128) 97 | XX, YY = np.meshgrid(x,y) 98 | 99 | #Load data 100 | k_ref = np.loadtxt('k_ref_05.out', dtype=float) #(32768,) 101 | h_ref = np.loadtxt('h_ref_05.out', dtype=float) #(32768,) 102 | y_ref = np.log(k_ref) 103 | k_ref = y_ref 104 | h_ref = h_ref - h_ref.min(0) 105 | x_ref = 
np.loadtxt('coord_ref.out', dtype=float) #(32768, 2) 106 | N = k_ref.shape[0] 107 | 108 | #Create dataset 109 | #k, h measurements and residual points 110 | np.random.seed(rand_seed) 111 | #idx_k = np.random.choice(N, Nk, replace= False) 112 | idx_k = np.loadtxt('Nk_40_Nh_40_randseed_111_idxk.out').astype(np.int64) 113 | y_k, x_k = k_ref[idx_k][:,np.newaxis] + np.random.normal(0,sigma,(Nk,1)).astype(np.float32), x_ref[idx_k, :] 114 | k_data = jnp.concatenate([x_k,y_k], axis=1) 115 | #idx_h = np.random.choice(N, Nh, replace= False) 116 | idx_h = np.loadtxt('Nk_40_Nh_40_randseed_111_idxh.out').astype(np.int64) 117 | y_h, x_h = h_ref[idx_h][:,np.newaxis] + np.random.normal(0,sigma,(Nh,1)).astype(np.float32), x_ref[idx_h, :] 118 | h_data = jnp.concatenate([x_h,y_h], axis=1) 119 | 120 | x_nor = lhs(2,200000)[:Nres,:] 121 | #x_nor = np.loadtxt(f'Nk_40_Nh_40_Nres_{Nres}_randseed_111_xnor.out') 122 | x_res = lbt + (ubt -lbt) * x_nor 123 | y_res= np.zeros((Nres,1)) + np.random.normal(0,sigma,(Nres,1)).astype(np.float32) 124 | res = jnp.concatenate([x_res, y_res],axis=1) 125 | 126 | # Dirichlet BC at right 127 | x2_dbr = np.linspace(lbt[1],ubt[1],16)[:,np.newaxis] 128 | x1_dbr = ubt[0]*jnp.ones_like(x2_dbr) 129 | y_dbr = jnp.zeros_like(x2_dbr) + np.random.normal(0,sigma,(16,1)).astype(np.float32) 130 | dbr = jnp.concatenate([x1_dbr,x2_dbr,y_dbr],axis=1) 131 | 132 | # Neumann BC at left 133 | x2_nbl = np.linspace(lbt[1],ubt[1],16)[:,np.newaxis] 134 | x1_nbl = lbt[0]*jnp.ones_like(x2_nbl) 135 | y_nbl = jnp.ones_like(x2_nbl) + np.random.normal(0,sigma,(16,1)).astype(np.float32) 136 | nbl = jnp.concatenate([x1_nbl,x2_nbl,y_nbl],axis=1) 137 | 138 | # Neumann BC at top 139 | x1_nbt = np.linspace(lbt[0],ubt[0],32)[:,np.newaxis] 140 | x2_nbt = ubt[1]*jnp.ones_like(x1_nbt) 141 | y_nbt = jnp.zeros_like(x1_nbt) + np.random.normal(0,sigma,(32,1)).astype(np.float32) 142 | nbt = jnp.concatenate([x1_nbt,x2_nbt,y_nbt],axis=1) 143 | 144 | # Neumann BC at below 145 | x1_nbb = np.linspace(lbt[0],ubt[0],32)[:,np.newaxis] 146 | x2_nbb = lbt[1]*jnp.ones_like(x1_nbb) 147 | y_nbb = jnp.zeros_like(x1_nbb) + np.random.normal(0,sigma,(32,1)).astype(np.float32) 148 | nbb = jnp.concatenate([x1_nbb,x2_nbb,y_nbb],axis=1) 149 | 150 | dataset.update({'k_data': k_data}) 151 | dataset.update({'h_data': h_data}) 152 | dataset.update({'res': res}) 153 | dataset.update({'dbr': dbr}) 154 | dataset.update({'nbl': nbl}) 155 | dataset.update({'nbt': nbt}) 156 | dataset.update({'nbb': nbb}) 157 | 158 | path_f = f'2D_Nk_{Nk}_Nh_{Nh}_Nres_{Nres}_sigma_{sigma}_Nburn_{Nburn}_Nsamples_{Nsamples}_HMC' 159 | path_fig = os.path.join(path_f,'figures') 160 | if not os.path.exists(path_f): 161 | os.makedirs(path_f) 162 | if not os.path.exists(path_fig): 163 | os.makedirs(path_fig) 164 | f_rec = open(os.path.join(path_f,'record.out'), 'a+') 165 | 166 | print(f'method:HMC rand_seed:{rand_seed}', file = f_rec) 167 | print(f'layers_k:{layers_k} layers_h:{layers_h}', file = f_rec) 168 | print(f'Nk:{Nk} Nh:{Nh} Nres:{Nres}\n', file = f_rec) 169 | print(f'sigma:{sigma} sigma_r:{sigma_r} sigma_nbl:{sigma_nbl} sigma_nbb:{sigma_nbb} sigma_nbt:{sigma_nbt} sigma_dbr:{sigma_dbr} sigma_p:{sigma_p}\n') 170 | print(f'sigma:{sigma} sigma_r:{sigma_r} sigma_nbl:{sigma_nbl} sigma_nbb:{sigma_nbb} sigma_nbt:{sigma_nbt} sigma_dbr:{sigma_dbr} sigma_p:{sigma_p}\n', file = f_rec) 171 | 172 | rl2e = lambda yest, yref : spl.norm(yest - yref, 2) / spl.norm(yref, 2) 173 | infe = lambda yest, yref : spl.norm(yest - yref, np.inf) 174 | lpp = lambda h, href, sigma: np.sum( -(h - 
href)**2/(2*sigma**2) - 1/2*np.log( 2*np.pi) - 2*np.log(sigma)) 175 | 176 | def pcolormesh(XX, YY, Z, points = None, title = None, savefig = None, cmap='jet', vmax = None, vmin = None): 177 | fig, ax = plt.subplots(dpi = 300, figsize = (6,4)) 178 | if vmax is not None: 179 | c = ax.pcolormesh(XX, YY, Z, vmin = vmin, vmax = vmax, cmap=cmap) 180 | else: 181 | c = ax.pcolormesh(XX, YY, Z, vmin = np.min(Z), vmax = np.max(Z), cmap=cmap) 182 | if points is not None: 183 | plt.plot(points[:,0], points[:,1], 'ko', markersize = 1.0) 184 | fig.colorbar(c, ax=ax, fraction= 0.05, pad= 0.05) 185 | ax.tick_params(axis='both', which = 'major', labelsize=16) 186 | ax.set_xlabel('$x_1$', fontsize=20) 187 | ax.set_ylabel('$x_2$', fontsize=20) 188 | if title is not None: 189 | ax.set_title(title, fontsize=14) 190 | fig.tight_layout() 191 | #ax.set_aspect('equal') 192 | plt.show() 193 | if savefig is not None: 194 | plt.savefig(os.path.join(path_fig,f'{savefig}.png')) 195 | 196 | # Define FNN 197 | def FNN(layers, activation=jnp.tanh): 198 | 199 | def init(prng_key): #return a list of (W,b) tuples 200 | def init_layer(key, d_in, d_out): 201 | key1, key2 = random.split(key) 202 | glorot_stddev = 1.0 / jnp.sqrt((d_in + d_out) / 2.) 203 | W = glorot_stddev * random.normal(key1, (d_in, d_out)) 204 | b = jnp.zeros(d_out) 205 | return W, b 206 | key, *keys = random.split(prng_key, len(layers)) 207 | params = list(map(init_layer, keys, layers[:-1], layers[1:])) 208 | return params 209 | 210 | def forward(params, inputs): 211 | Z = inputs 212 | for W, b in params[:-1]: 213 | outputs = jnp.dot(Z, W) + b 214 | Z = activation(outputs) 215 | W, b = params[-1] 216 | outputs = jnp.dot(Z, W) + b 217 | return outputs 218 | 219 | return init, forward 220 | 221 | # Define the model 222 | class PINN(): 223 | def __init__(self, key, layers_k, layers_h, dataset, lbt, ubt, sigma_r, sigma_d, sigma_nbl, sigma_nbb, sigma_nbt, sigma_dbr, sigma_p): 224 | 225 | self.lbt = lbt #domain lower corner 226 | self.ubt = ubt #domain upper corner 227 | self.sigma_r = sigma_r #residual term 228 | self.sigma_d = sigma_d #data term 229 | self.sigma_nbl = sigma_nbl 230 | self.sigma_nbb = sigma_nbb 231 | self.sigma_nbt = sigma_nbt 232 | self.sigma_dbr = sigma_dbr 233 | self.sigma_p = sigma_p #prior term 234 | self.itercount = itertools.count() 235 | 236 | # Prepare normalized training data 237 | self.dataset = dataset 238 | self.x_res, self.y_res = dataset['res'][:,0:2], dataset['res'][:,2:3] 239 | self.x_dbr, self.y_dbr = dataset['dbr'][:,0:2], dataset['dbr'][:,2:3] 240 | self.x_nbl, self.y_nbl = dataset['nbl'][:,0:2], dataset['nbl'][:,2:3] 241 | self.x_nbt, self.y_nbt = dataset['nbt'][:,0:2], dataset['nbt'][:,2:3] 242 | self.x_nbb, self.y_nbb = dataset['nbb'][:,0:2], dataset['nbb'][:,2:3] 243 | self.x_h, self.y_h = dataset['h_data'][:,0:2], dataset['h_data'][:,2:3] 244 | self.x_k, self.y_k = dataset['k_data'][:,0:2], dataset['k_data'][:,2:3] 245 | 246 | # Initalize the network 247 | key, *keys = random.split(key, num = 3) 248 | self.init_k, self.forward_k = FNN(layers_k, activation=jnp.tanh) 249 | self.params_k = self.init_k(keys[0]) 250 | raveled_k, self.unravel_k = ravel_pytree(self.params_k) 251 | self.num_params_k = raveled_k.shape[0] 252 | 253 | self.init_h, self.forward_h = FNN(layers_h, activation=jnp.tanh) 254 | self.params_h = self.init_h(keys[1]) 255 | raveled_h, self.unravel_h = ravel_pytree(self.params_h) 256 | self.num_params_h = raveled_h.shape[0] 257 | self.num_params = self.num_params_k + self.num_params_h 258 | 259 | # Evaluate the 
state, parameter and the residual over the grid 260 | self.h_pred_map = vmap(self.h_net, (None, 0, 0)) 261 | self.k_pred_map = vmap(self.k_net, (None, 0, 0)) 262 | self.r_pred_map = vmap(self.res_net, (None, 0, 0)) 263 | 264 | # Optimizer 265 | lr = optimizers.exponential_decay(1e-4, decay_steps=5000, decay_rate=0.9) 266 | self.opt_init, \ 267 | self.opt_update, \ 268 | self.get_params = optimizers.adam(lr) 269 | 270 | self.opt_state_k = self.opt_init(self.params_k) 271 | self.opt_state_h = self.opt_init(self.params_h) 272 | 273 | @partial(jit, static_argnums=(0,)) 274 | def h_net(self, params, x1, x2): #no problem 275 | inputs = jnp.hstack([x1, x2]) 276 | outputs = self.forward_h(params[1], inputs) 277 | return outputs[0] 278 | 279 | @partial(jit, static_argnums=(0,)) 280 | def k_net(self, params, x1, x2): #no problem 281 | inputs = jnp.hstack([x1, x2]) 282 | outputs = self.forward_k(params[0], inputs) 283 | return outputs[0] 284 | 285 | @partial(jit, static_argnums=(0,)) 286 | def qx(self, params, x1, x2): 287 | k = jnp.exp(self.k_net(params, x1, x2)) 288 | #k = self.k_net(params, x1, x2) 289 | dhdx = grad(self.h_net, argnums=1)(params, x1, x2) 290 | return -k*dhdx 291 | 292 | @partial(jit, static_argnums=(0,)) 293 | def qy(self, params, x1, x2): 294 | k = jnp.exp(self.k_net(params, x1, x2)) 295 | #k = self.k_net(params, x1, x2) 296 | dhdy = grad(self.h_net, argnums=2)(params, x1, x2) 297 | return -k*dhdy 298 | 299 | @partial(jit, static_argnums=(0,)) 300 | def res_net(self, params, x1, x2): 301 | dhdx2 = grad(self.qx, argnums=1)(params, x1, x2) 302 | dhdy2 = grad(self.qy, argnums=2)(params, x1, x2) 303 | return dhdx2 + dhdy2 304 | 305 | @partial(jit, static_argnums=(0,)) 306 | def h_pred_vector(self, flat_params): 307 | # For HMC 308 | params_k = self.unravel_k(flat_params[:self.num_params_k]) 309 | params_h = self.unravel_h(flat_params[self.num_params_k:]) 310 | params = [params_k, params_h] 311 | h_pred_vec = vmap(self.h_net, (None, 0, 0))(params, self.x_h[:,0], self.x_h[:,1]) 312 | return h_pred_vec 313 | 314 | @partial(jit, static_argnums=(0,)) 315 | def k_pred_vector(self, flat_params): 316 | # For HMC 317 | params_k = self.unravel_k(flat_params[:self.num_params_k]) 318 | params_h = self.unravel_h(flat_params[self.num_params_k:]) 319 | params = [params_k, params_h] 320 | k_pred_vec = vmap(self.k_net, (None, 0, 0))(params, self.x_k[:,0], self.x_k[:,1]) 321 | return k_pred_vec 322 | 323 | @partial(jit, static_argnums=(0,)) 324 | def r_pred_vector(self, flat_params): 325 | # For HMC 326 | params_k = self.unravel_k(flat_params[:self.num_params_k]) 327 | params_h = self.unravel_h(flat_params[self.num_params_k:]) 328 | params = [params_k, params_h] 329 | r_pred_vec = vmap(self.res_net, (None, 0, 0 ))(params, self.x_res[:,0], self.x_res[:,1]) 330 | return r_pred_vec 331 | 332 | @partial(jit, static_argnums=(0,)) 333 | def dbr_pred_vector(self, flat_params): 334 | # For HMC 335 | params_k = self.unravel_k(flat_params[:self.num_params_k]) 336 | params_h = self.unravel_h(flat_params[self.num_params_k:]) 337 | params = [params_k, params_h] 338 | dbr_pred_vec = vmap(self.h_net, (None, 0, 0))(params, self.x_dbr[:,0], self.x_dbr[:,1]) 339 | return dbr_pred_vec 340 | 341 | @partial(jit, static_argnums=(0,)) 342 | def nbl_pred_vector(self, flat_params): 343 | params_k = self.unravel_k(flat_params[:self.num_params_k]) 344 | params_h = self.unravel_h(flat_params[self.num_params_k:]) 345 | params = [params_k, params_h] 346 | nbl_pred_vec = vmap(self.qx, (None, 0, 0))(params, self.x_nbl[:,0], 
self.x_nbl[:,1]) 347 | return nbl_pred_vec 348 | 349 | @partial(jit, static_argnums=(0,)) 350 | def nbt_pred_vector(self, flat_params): 351 | params_k = self.unravel_k(flat_params[:self.num_params_k]) 352 | params_h = self.unravel_h(flat_params[self.num_params_k:]) 353 | params = [params_k, params_h] 354 | nbt_pred_vec = vmap(self.qy, (None, 0, 0))(params, self.x_nbt[:,0], self.x_nbt[:,1]) 355 | return nbt_pred_vec 356 | 357 | @partial(jit, static_argnums=(0,)) 358 | def nbb_pred_vector(self, flat_params): 359 | params_k = self.unravel_k(flat_params[:self.num_params_k]) 360 | params_h = self.unravel_h(flat_params[self.num_params_k:]) 361 | params = [params_k, params_h] 362 | nbb_pred_vec = vmap(self.qy, (None, 0, 0))(params, self.x_nbb[:,0], self.x_nbb[:,1]) 363 | return nbb_pred_vec 364 | 365 | @partial(jit, static_argnums=(0,)) 366 | def target_log_prob_fn(self, theta): 367 | prior = -1/(2*self.sigma_p**2) * jnp.sum((theta)**2) 368 | r_likelihood = -1/(2*self.sigma_r**2) * jnp.sum((self.y_res.ravel() - self.r_pred_vector(theta))**2) 369 | k_likelihood = -1/(2*self.sigma_d**2) * jnp.sum((self.y_k.ravel() - self.k_pred_vector(theta))**2) 370 | h_likelihood = -1/(2*self.sigma_d**2) * jnp.sum((self.y_h.ravel() - self.h_pred_vector(theta))**2) 371 | dbr_likelihood = -1/(2*self.sigma_dbr**2) * jnp.sum((self.y_dbr.ravel() - self.dbr_pred_vector(theta))**2) 372 | nbl_likelihood = -1/(2*self.sigma_nbl**2) * jnp.sum((self.y_nbl.ravel() - self.nbl_pred_vector(theta))**2) 373 | nbb_likelihood = -1/(2*self.sigma_nbb**2) * jnp.sum((self.y_nbb.ravel() - self.nbb_pred_vector(theta))**2) 374 | nbt_likelihood = -1/(2*self.sigma_nbt**2) * jnp.sum((self.y_nbt.ravel() - self.nbt_pred_vector(theta))**2) 375 | return prior + r_likelihood + k_likelihood + h_likelihood + dbr_likelihood + nbl_likelihood + nbb_likelihood + nbt_likelihood 376 | 377 | key1, key2 = random.split(random.PRNGKey(0), 2) 378 | pinn = PINN(key2, layers_k, layers_h, dataset, lbt, ubt, sigma_r, sigma_d, sigma_nbl, sigma_nbb, sigma_nbt, sigma_dbr, sigma_p) 379 | 380 | new_key, *subkeys = random.split(key1, Nchains + 1) 381 | init_state = jnp.zeros((1, pinn.num_params)) 382 | for key in subkeys: 383 | init_state = jnp.concatenate([init_state,random.normal(key ,(1, pinn.num_params))], axis=0) 384 | 385 | nuts_kernel = tfp.mcmc.NoUTurnSampler( 386 | target_log_prob_fn = pinn.target_log_prob_fn, step_size = 0.0005, max_tree_depth=10, max_energy_diff=1000.0, 387 | unrolled_leapfrog_steps=1, parallel_iterations=30) 388 | 389 | kernel = tfp.mcmc.DualAveragingStepSizeAdaptation( 390 | inner_kernel=nuts_kernel, num_adaptation_steps=int(Nburn * 0.75)) 391 | 392 | def run_chain(init_state, key): 393 | samples, trace = tfp.mcmc.sample_chain( 394 | num_results= Nsamples, 395 | num_burnin_steps= Nburn, 396 | current_state= init_state, 397 | kernel= kernel, 398 | seed=key, 399 | trace_fn= lambda _,pkr: [pkr.inner_results.log_accept_ratio, 400 | pkr.inner_results.target_log_prob, 401 | pkr.inner_results.step_size] 402 | ) 403 | return samples, trace 404 | 405 | ts = perf_counter() 406 | print('\nStart HMC Sampling') 407 | states, trace = jit(vmap(run_chain, in_axes=(0, None)))(init_state, new_key) 408 | print('\nFinish HMC Sampling') 409 | np.save(os.path.join(path_f,'chains'), states) 410 | #states = np.load(os.path.join(path_f,'chains.npy')) 411 | timings = perf_counter() - ts 412 | print(f"HMC: {timings} s") 413 | print(f"HMC: {timings} s", file = f_rec) 414 | # ============================================================================= 415 | # 
Post-processing HMC results 416 | # ============================================================================= 417 | 418 | accept_ratio = np.exp(trace[0]) 419 | target_log_prob = trace[1] 420 | step_size = trace[2] 421 | print(f'Average accept ratio for each chain: {np.mean(accept_ratio, axis = 1)}') 422 | print(f'Average step size for each chain: {np.mean(step_size, axis = 1)}') 423 | print(f'Average accept ratio for each chain: {np.mean(accept_ratio, axis = 1)}', file = f_rec) 424 | print(f'Average step size for each chain: {np.mean(step_size, axis = 1)}', file = f_rec) 425 | 426 | samples = states #(Nchains, Nsamples, Nparams) 427 | 428 | @jit 429 | def get_h_pred(sample): 430 | params_k = pinn.unravel_k(sample[:pinn.num_params_k]) 431 | params_h = pinn.unravel_h(sample[pinn.num_params_k:]) 432 | params = [params_k, params_h] 433 | return pinn.h_pred_map(params,x_ref[:,0], x_ref[:,1]) 434 | 435 | @jit 436 | def get_k_pred(sample): 437 | params_k = pinn.unravel_k(sample[:pinn.num_params_k]) 438 | params_h = pinn.unravel_h(sample[pinn.num_params_k:]) 439 | params = [params_k, params_h] 440 | return pinn.k_pred_map(params,x_ref[:,0], x_ref[:,1]) 441 | 442 | @jit 443 | def get_r_pred(sample): 444 | params_k = pinn.unravel_k(sample[:pinn.num_params_k]) 445 | params_h = pinn.unravel_h(sample[pinn.num_params_k:]) 446 | params = [params_k, params_h] 447 | return pinn.r_pred_map(params,x_ref[:,0], x_ref[:,1]) 448 | 449 | h_pred_ens = np.array([vmap(get_h_pred)(samples[i,:,:]) for i in range(samples.shape[0])]) #(Nchains, Nsamples, 32768) 450 | k_pred_ens = np.array([vmap(get_k_pred)(samples[i,:,:]) for i in range(samples.shape[0])]) 451 | 452 | # h_pred_ens = vmap(get_h_pred)(samples[:,:,:]) 453 | # k_pred_ens = vmap(get_k_pred)(samples[:,:,:]) 454 | 455 | h_pred_ens_mean = np.mean(h_pred_ens, axis = 0) #(Nchains, 32768) 456 | h_pred_ens_std = np.std(h_pred_ens, axis = 0) 457 | k_pred_ens_mean = np.mean(k_pred_ens, axis = 0) 458 | k_pred_ens_std = np.std(k_pred_ens, axis = 0) 459 | 460 | h_env = np.logical_and( (h_pred_ens_mean < h_ref.ravel() + 2*h_pred_ens_std), (h_pred_ens_mean > h_ref.ravel() - 2*h_pred_ens_std) ) 461 | k_env = np.logical_and( (k_pred_ens_mean < k_ref.ravel() + 2*k_pred_ens_std), (k_pred_ens_mean > k_ref.ravel() - 2*k_pred_ens_std) ) 462 | 463 | for i in range(Nchains): 464 | rl2e_h = rl2e(h_pred_ens_mean[i, :], h_ref) 465 | infe_h = infe(h_pred_ens_mean[i, :], h_ref) 466 | lpp_h = lpp(h_pred_ens_mean[i, :], h_ref, h_pred_ens_std[i, :]) 467 | rl2e_k = rl2e(k_pred_ens_mean[i, :], k_ref) 468 | infe_k = infe(k_pred_ens_mean[i, :], k_ref) 469 | lpp_k = lpp(k_pred_ens_mean[i, :], k_ref, k_pred_ens_std[i, :]) 470 | 471 | print(f'chains:{i}\n') 472 | print('h prediction:\n') 473 | print('Relative RL2 error: {}'.format(rl2e_h)) 474 | print('Absolute inf error: {}'.format(infe_h)) 475 | print('Average standard deviation: {}'.format(np.mean(h_pred_ens_std[i, :]))) 476 | print('log predictive probability: {}'.format(lpp_h)) 477 | print('Percentage of coverage:{}\n'.format(np.sum(h_env)/32768)) 478 | 479 | print('k prediction:\n') 480 | print('Relative RL2 error: {}'.format(rl2e_k)) 481 | print('Absolute inf error: {}'.format(infe_k)) 482 | print('Average standard deviation: {}'.format(np.mean(k_pred_ens_std[i, :]))) 483 | print('log predictive probability: {}'.format(lpp_k)) 484 | print('Percentage of coverage:{}\n'.format(np.sum(k_env)/32768)) 485 | 486 | print(f'chains:{i}\n', file = f_rec) 487 | print('h prediction:\n', file = f_rec) 488 | print('Relative RL2 error: 
{}'.format(rl2e_h), file = f_rec) 489 | print('Absolute inf error: {}'.format(infe_h), file = f_rec) 490 | print('Average standard deviation: {}'.format(np.mean(h_pred_ens_std[i, :])), file = f_rec) 491 | print('log predictive probability: {}'.format(lpp_h), file = f_rec) 492 | print('Percentage of coverage:{}\n'.format(np.sum(h_env)/32768), file = f_rec) 493 | 494 | print('k prediction:\n', file = f_rec) 495 | print('Relative RL2 error: {}'.format(rl2e_k), file = f_rec) 496 | print('Absolute inf error: {}'.format(infe_k), file = f_rec) 497 | print('Average standard deviation: {}'.format(np.mean(k_pred_ens_std[i, :])), file = f_rec) 498 | print('log predictive probability: {}'.format(lpp_k), file = f_rec) 499 | print('Percentage of coverage:{}\n'.format(np.sum(k_env)/32768), file = f_rec) 500 | 501 | 502 | #Plot of k field 503 | pcolormesh(XX, YY, k_ref.reshape(128,256), points = None, title = None, savefig = 'y_ref') 504 | savefig = f'2D_gwf_rPINN_sigmar_{sigma_r}_Nsamples_{Nsamples}_ypred_mean' 505 | pcolormesh(XX, YY, k_pred_ens_mean.reshape(128,256), points = None, title = None, savefig = savefig) 506 | savefig = f'2D_gwf_rPINN_sigmar_{sigma_r}_Nsamples_{Nsamples}_ypred_std' 507 | pcolormesh(XX, YY, k_pred_ens_std.reshape(128,256), points = x_k, title = None, savefig = savefig) 508 | savefig = f'2D_gwf_rPINN_sigmar_{sigma_r}_Nsamples_{Nsamples}_ypred_diff' 509 | pcolormesh(XX, YY, np.abs(k_pred_ens_mean.reshape(128,256) - k_ref.reshape(128,256)), points = x_k, title = None, savefig = savefig) 510 | savefig = f'2D_gwf_rPINN_sigmar_{sigma_r}_Nsamples_{Nsamples}_ypred_env' 511 | pcolormesh(XX, YY, k_env.reshape(128,256), points = x_k, title = None, savefig = savefig) 512 | 513 | #Plot of h field 514 | pcolormesh(XX, YY, h_ref.reshape(128,256), points = None, title = None, savefig = 'h_ref') 515 | savefig = f'2D_gwf_rPINN_sigmar_{sigma_r}_Nsamples_{Nsamples}_hpred_mean' 516 | pcolormesh(XX, YY, h_pred_ens_mean.reshape(128,256), points = None, title = None, savefig = savefig ) 517 | savefig = f'2D_gwf_rPINN_sigmar_{sigma_r}_Nsamples_{Nsamples}_hpred_std' 518 | pcolormesh(XX, YY, h_pred_ens_std.reshape(128,256), points = x_h, title = None, savefig = savefig ) 519 | savefig = f'2D_gwf_rPINN_sigmar_{sigma_r}_Nsamples_{Nsamples}_hpred_diff' 520 | pcolormesh(XX, YY, np.abs(h_pred_ens_mean.reshape(128,256) - h_ref.reshape(128,256)), points = x_h, title = None, savefig = savefig ) 521 | savefig = f'2D_gwf_rPINN_sigmar_{sigma_r}_Nsamples_{Nsamples}_hpred_env' 522 | pcolormesh(XX, YY, h_env.reshape(128,256), points = x_h, title = None, savefig = savefig, cmap='jet', vmax = 1, vmin = 0) 523 | 524 | 525 | rhat = tfp.mcmc.diagnostic.potential_scale_reduction(states.transpose((1,0,2)), independent_chain_ndims=1) 526 | ess = tfp.mcmc.effective_sample_size(states[0], filter_beyond_positive_pairs=True) 527 | 528 | fig, ax = plt.subplots(1, 1, figsize=(4, 4), sharex='col', sharey='col', dpi = 300) 529 | g = sns.histplot(rhat, bins = 50, kde=True, kde_kws = {'gridsize':5000}) 530 | g.tick_params(labelsize=16) 531 | g.set_xlabel("$\hat{r}$", fontsize=18) 532 | g.set_ylabel("Count", fontsize=18) 533 | fig.tight_layout() 534 | plt.savefig(os.path.join(path_fig,'rhat.png')) 535 | plt.show() 536 | 537 | fig, ax = plt.subplots(1, 1, figsize=(4, 4), sharex='col', sharey='col', dpi = 300) 538 | g = sns.histplot(ess, bins = 50, kde=True, kde_kws = {'gridsize':5000}) 539 | g.tick_params(labelsize=16) 540 | g.set_xlabel("ESS", fontsize=18) 541 | g.set_ylabel("Count", fontsize=18) 542 | fig.tight_layout() 543 | 
plt.savefig(os.path.join(path_fig,'ess.png')) 544 | plt.show() 545 | 546 | idx_low = np.argmin(rhat) 547 | idx_high = np.argmax(rhat) 548 | samples1 = states[:,:,idx_low] 549 | df1 = pd.DataFrame({'chains': 'chain1', 'indice':np.arange(0, samples1[0,].shape[0], 5), 'trace':samples1[0, ::5]}) 550 | df2 = pd.DataFrame({'chains': 'chain2', 'indice':np.arange(0, samples1[1,].shape[0], 5), 'trace':samples1[1, ::5]}) 551 | df3 = pd.DataFrame({'chains': 'chain3', 'indice':np.arange(0, samples1[2,].shape[0], 5), 'trace':samples1[2, ::5]}) 552 | df = pd.concat([df1, df2, df3], ignore_index=True) 553 | plt.figure(figsize=(4,4)) 554 | g = sns.jointplot(data=df, x='indice', y='trace', xlim=(0, 500), ylim=(-6, 6), hue='chains', joint_kws={'alpha': 0.6}) 555 | g.ax_joint.tick_params(labelsize=18) 556 | g.ax_joint.set_xlabel("Index", fontsize=24) 557 | g.ax_joint.set_ylabel("Trace", fontsize=24) 558 | g.ax_joint.legend(fontsize=16) 559 | g.ax_marg_x.remove() 560 | #plt.title('Trace plot for parameter with lowest $\hat{r}$') 561 | plt.gcf().set_dpi(300) 562 | plt.savefig(os.path.join(path_fig,'trace_plot_rhat_lowest.png')) 563 | fig.tight_layout() 564 | plt.show() 565 | 566 | samples2 = states[:,:,idx_high] 567 | df1 = pd.DataFrame({'chains': 'chain1', 'indice':np.arange(0, samples2[0, ::].shape[0], 5), 'trace':samples2[0, ::5]}) 568 | df2 = pd.DataFrame({'chains': 'chain2', 'indice':np.arange(0, samples2[1, ::].shape[0], 5), 'trace':samples2[1, ::5]}) 569 | df3 = pd.DataFrame({'chains': 'chain3', 'indice':np.arange(0, samples2[2, ::].shape[0], 5), 'trace':samples2[2, ::5]}) 570 | df = pd.concat([df1,df2, df3], ignore_index=True) 571 | plt.figure(figsize=(4,4)) 572 | #g = sns.jointplot(data=df, x='indice', y='trace', xlim=(0, 5000), ylim=(-4, 4), hue='chains', joint_kws={'alpha': 1}) 573 | g = sns.jointplot(data=df, x='indice', y='trace', xlim=(0, 500), ylim=(-6, 6), hue='chains', joint_kws={'alpha': 0.6}) 574 | g.ax_joint.tick_params(labelsize=18) 575 | g.ax_joint.set_xlabel("Index", fontsize=24) 576 | g.ax_joint.set_ylabel("Trace", fontsize=24) 577 | g.ax_joint.legend(fontsize=16) 578 | g.ax_marg_x.remove() 579 | #plt.title('Trace plot for parameter with highest $\hat{r}$') 580 | plt.gcf().set_dpi(300) 581 | plt.savefig(os.path.join(path_fig,'trace_plot_rhat_highest.png')) 582 | fig.tight_layout() 583 | plt.show() 584 | 585 | mark = [None, 'o', None] 586 | linestyle = ['solid', 'dotted', 'dashed'] 587 | fig, ax = plt.subplots(dpi = 300, figsize = (4,4)) 588 | for i, mark in enumerate(mark): 589 | ax.plot(np.arange(Nsamples)[::10], target_log_prob[i,::10], marker = mark, markersize = 2, markevery= 100, markerfacecolor='None', linestyle = 'dashed', label = f'chain {i + 1}', alpha = 0.8) 590 | ax.set_xlabel('Sample index', fontsize = 15) 591 | ax.set_ylabel('Negative log prob', fontsize = 15) 592 | ax.tick_params(axis='both', which = 'major', labelsize=12) 593 | ax.set_xlim(0,Nsamples) 594 | ax.legend(fontsize=10) 595 | plt.savefig(os.path.join(path_fig,'target_log_prob.png')) 596 | plt.show() 597 | 598 | # chain0 = states[0] 599 | # chain1 = states[1] 600 | # chain2 = states[2] 601 | # chain0_m = np.mean(chain0, axis = 0) 602 | # chain1_m = np.mean(chain1, axis = 0) 603 | # chain2_m = np.mean(chain2, axis = 0) 604 | # hess = jax.hessian(pinn.target_log_prob_fn) 605 | # hess_chain0 = hess(chain0_m) 606 | # _, s0, _ = jax.scipy.linalg.svd(jax.scipy.linalg.inv(hess_chain0)) 607 | # hess_chain1 = hess(chain1_m) 608 | # _, s1, _ = jax.scipy.linalg.svd(jax.scipy.linalg.inv(hess_chain1)) 609 | # hess_chain2 = 
hess(chain2_m) 610 | # _, s2, _ = jax.scipy.linalg.svd(jax.scipy.linalg.inv(hess_chain2)) 611 | 612 | # s = np.concatenate((s0[np.newaxis, :], s1[np.newaxis, :], s2[np.newaxis, :]), axis = 0) 613 | # np.savetxt(os.path.join(path_f,'singular_values_posterior_hessian.out'), s) 614 | 615 | # fig, ax = plt.subplots(dpi = 300, figsize = (4,4)) 616 | # #mark = [None, 'o', None] 617 | # linestyle = ['solid', 'dotted', 'dashed'] 618 | # for i, ls in enumerate(linestyle): 619 | # ax.plot(s[i], linestyle = ls, marker = None, markersize = 2, markevery= 100, markerfacecolor='None', label=f'chain{i+1}', alpha = 0.8) 620 | # ax.set_xlabel('Index', fontsize=16) 621 | # ax.set_ylabel('Eigenvalues', fontsize=16) 622 | # plt.yscale('log') 623 | # ax.tick_params(axis='both', which = 'major', labelsize=13) 624 | # ax.legend(fontsize=8) 625 | # plt.savefig(os.path.join(path_fig,'singular_values_posterior_hessian.png')) 626 | # plt.show() 627 | 628 | # fig, ax = plt.subplots(1, 1, figsize=(4, 4), sharex='col', sharey='col', dpi = 300) 629 | # g = sns.histplot(chain0_m, bins = 50, kde=True, kde_kws = {'gridsize':5000}) 630 | # g = sns.histplot(chain1_m, bins = 50, kde=True, kde_kws = {'gridsize':5000}) 631 | # g = sns.histplot(chain2_m, bins = 50, kde=True, kde_kws = {'gridsize':5000}) 632 | # g.tick_params(labelsize=16) 633 | # g.set_xlabel("Weight", fontsize=18) 634 | # g.set_ylabel("Count", fontsize=18) 635 | # fig.tight_layout() 636 | # plt.savefig(os.path.join(path_fig,'weight.png')) 637 | # plt.show() 638 | -------------------------------------------------------------------------------- /2D_GWF_problem/2D_gwf_inverse_rPINN.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Fri Jan 5 14:09:03 2024 5 | 6 | @author: yifei_linux 7 | """ 8 | 9 | #Import dependencies 10 | import jax 11 | import os 12 | import jax.numpy as jnp 13 | from jax import random, grad, vmap, jit 14 | from jax.flatten_util import ravel_pytree 15 | from jax.example_libraries import optimizers 16 | # from jax.lib import xla_bridge 17 | 18 | import numpy as np 19 | import matplotlib.pyplot as plt 20 | import matplotlib as mpl 21 | import scipy.linalg as spl 22 | from tensorflow_probability.substrates import jax as tfp 23 | tfd = tfp.distributions 24 | 25 | import itertools 26 | import argparse 27 | from functools import partial 28 | from tqdm import trange 29 | from pyDOE import lhs 30 | 31 | #command line argument parser 32 | parser = argparse.ArgumentParser(description="2D Darcy with rPINN") 33 | parser.add_argument( 34 | "--rand_seed", 35 | type=int, 36 | default=111, 37 | help="random seed") 38 | parser.add_argument( 39 | "--sigma", 40 | type=float, 41 | default=1, 42 | help="Measurement noise level") 43 | parser.add_argument( 44 | "--lambda_r", 45 | type=float, 46 | default=1, 47 | help="residual weight for the PINN") 48 | parser.add_argument( 49 | "--lambda_nbl", 50 | type=float, 51 | default=1, 52 | help="left neumann boundary weight for the PINN") 53 | parser.add_argument( 54 | "--lambda_nbb", 55 | type=float, 56 | default=1, 57 | help="lower neumann boundary weight for the PINN") 58 | parser.add_argument( 59 | "--lambda_nbt", 60 | type=float, 61 | default=1, 62 | help="top neumann boundary weight for the PINN") 63 | parser.add_argument( 64 | "--lambda_db", 65 | type=float, 66 | default=1, 67 | help="dirichlet boundary weight for the PINN") 68 | parser.add_argument( 69 | "--lambda_d", 70 | type=float, 71 | default=1, 72 | 
help="Data weight for the PINN") 73 | parser.add_argument( 74 | "--lambda_p", 75 | type=float, 76 | default=1, 77 | help="L2 reg weight") 78 | parser.add_argument( 79 | "--Nres", 80 | type=int, 81 | default=500, 82 | help="Number of reisudal points") 83 | parser.add_argument( 84 | "--Nsamples", 85 | type=int, 86 | default=500, 87 | help="Number of posterior samples") 88 | parser.add_argument( 89 | "--nIter", 90 | type=int, 91 | default=50000, 92 | help="Number of training epochs per realization") 93 | parser.add_argument( 94 | "--data_load", 95 | type=bool, 96 | default=False, 97 | help="If to load data") 98 | parser.add_argument( 99 | "--method", 100 | type=str, 101 | default='DE', 102 | help="Method for Bayesian training") 103 | args = parser.parse_args() 104 | 105 | print(f'jax is using: {jax.devices()} \n') 106 | 107 | #Define parameters 108 | layers_k = [2, 60, 60, 60, 60, 1] 109 | layers_h = [2, 60, 60, 60, 60, 1] 110 | num_params = 22442 111 | lbt = np.array([0., 0.]) 112 | ubt = np.array([1., 0.5]) 113 | num_print = 200 114 | Nk = 40 115 | Nh = 40 116 | rand_seed = args.rand_seed 117 | sigma = args.sigma 118 | lambda_r = args.lambda_r 119 | lambda_d = args.lambda_d 120 | lambda_nbl = args.lambda_nbl 121 | lambda_nbb = args.lambda_nbb 122 | lambda_nbt = args.lambda_nbt 123 | lambda_db = args.lambda_db 124 | lambda_p = args.lambda_p 125 | Nres = args.Nres 126 | Nsamples = args.Nsamples 127 | nIter = args.nIter 128 | model_load = args.data_load 129 | method = args.method 130 | dataset = dict() 131 | x = np.linspace(lbt[0], ubt[0], 256) 132 | y = np.linspace(lbt[1], ubt[1], 128) 133 | XX, YY = np.meshgrid(x,y) 134 | 135 | #Load data 136 | k_ref = np.loadtxt('k_ref_05.out', dtype=float) #(32768,) 137 | h_ref = np.loadtxt('h_ref_05.out', dtype=float) #(32768,) 138 | y_ref = np.log(k_ref) 139 | k_ref = y_ref 140 | h_ref = h_ref - h_ref.min(0) 141 | x_ref = np.loadtxt('coord_ref.out', dtype=float) #(32768, 2) 142 | N = k_ref.shape[0] 143 | 144 | #Create dataset 145 | #k, h measurements and residual points 146 | np.random.seed(rand_seed) 147 | #idx_k = np.random.choice(N, Nk, replace= False) 148 | idx_k = np.loadtxt('Nk_40_Nh_40_randseed_111_idxk.out').astype(np.int64) 149 | y_k, x_k = k_ref[idx_k][:,np.newaxis] + np.random.normal(0,sigma,(Nk,1)).astype(np.float32), x_ref[idx_k, :] 150 | k_data = jnp.concatenate([x_k,y_k], axis=1) 151 | #idx_h = np.random.choice(N, Nh, replace= False) 152 | idx_h = np.loadtxt('Nk_40_Nh_40_randseed_111_idxk.out').astype(np.int64) 153 | y_h, x_h = h_ref[idx_h][:,np.newaxis] + np.random.normal(0,sigma,(Nh,1)).astype(np.float32), x_ref[idx_h, :] 154 | h_data = jnp.concatenate([x_h,y_h], axis=1) 155 | 156 | x_nor = lhs(2,200000)[:Nres,:] 157 | x_res = lbt + (ubt -lbt) * x_nor 158 | y_res= np.zeros((Nres,1)) + np.random.normal(0,sigma,(Nres,1)).astype(np.float32) 159 | res = jnp.concatenate([x_res, y_res],axis=1) 160 | 161 | # Dirichlet BC at right 162 | x2_dbr = np.linspace(lbt[1],ubt[1],16)[:,np.newaxis] 163 | x1_dbr = ubt[0]*jnp.ones_like(x2_dbr) 164 | y_dbr = jnp.zeros_like(x2_dbr) + np.random.normal(0,sigma,(16,1)).astype(np.float32) 165 | dbr = jnp.concatenate([x1_dbr,x2_dbr,y_dbr],axis=1) 166 | 167 | # Neumann BC at lefth 168 | x2_nbl = np.linspace(lbt[1],ubt[1],16)[:,np.newaxis] 169 | x1_nbl = lbt[0]*jnp.ones_like(x2_nbl) 170 | y_nbl = jnp.ones_like(x2_nbl) + np.random.normal(0,sigma,(16,1)).astype(np.float32) 171 | nbl = jnp.concatenate([x1_nbl,x2_nbl,y_nbl],axis=1) 172 | 173 | # Neumann BC at top 174 | x1_nbt = 
np.linspace(lbt[0],ubt[0],32)[:,np.newaxis] 175 | x2_nbt = ubt[1]*jnp.ones_like(x1_nbt) 176 | y_nbt = jnp.zeros_like(x1_nbt) + np.random.normal(0,sigma,(32,1)).astype(np.float32) 177 | nbt = jnp.concatenate([x1_nbt,x2_nbt,y_nbt],axis=1) 178 | 179 | # Neumann BC at below 180 | x1_nbb = np.linspace(lbt[0],ubt[0],32)[:,np.newaxis] 181 | x2_nbb = lbt[1]*jnp.ones_like(x1_nbb) 182 | y_nbb = jnp.zeros_like(x1_nbb) + np.random.normal(0,sigma,(32,1)).astype(np.float32) 183 | nbb = jnp.concatenate([x1_nbb,x2_nbb,y_nbb],axis=1) 184 | 185 | dataset.update({'k_data': k_data}) 186 | dataset.update({'h_data': h_data}) 187 | dataset.update({'res': res}) 188 | dataset.update({'dbr': dbr}) 189 | dataset.update({'nbl': nbl}) 190 | dataset.update({'nbt': nbt}) 191 | dataset.update({'nbb': nbb}) 192 | 193 | path_f = f'2D_inverse_Nk_{Nk}_Nh_{Nh}_Nres_{Nres}_nIter_{nIter}_sigma_{sigma}_Nsamples_{Nsamples}_{method}_measurenorm' 194 | path_fig = os.path.join(path_f,'figures') 195 | if not os.path.exists(path_f): 196 | os.makedirs(path_f) 197 | if not os.path.exists(path_fig): 198 | os.makedirs(path_fig) 199 | f_rec = open(os.path.join(path_f,'record.out'), 'a+') 200 | 201 | print(f'method:{method} rand_seed:{rand_seed} nIter:{nIter}', file = f_rec) 202 | print(f'layers_k:{layers_k} layers_h:{layers_h}', file = f_rec) 203 | print(f'Nk:{Nk} Nh:{Nh} Nres:{Nres}\n', file = f_rec) 204 | print(f'sigma:{sigma} lambda_r:{lambda_r} lambda_nbl:{lambda_nbl} lambda_nbb:{lambda_nbb} lambda_nbt:{lambda_nbt} lambda_db:{lambda_db} lambda_p:{lambda_p}\n') 205 | print(f'sigma:{sigma} lambda_r:{lambda_r} lambda_nbl:{lambda_nbl} lambda_nbb:{lambda_nbb} lambda_nbt:{lambda_nbt} lambda_db:{lambda_db} lambda_p:{lambda_p}\n', file = f_rec) 206 | lambda_r = lambda_r*num_params/res.shape[0] 207 | lambda_d = lambda_d*num_params/k_data.shape[0] 208 | lambda_nbl = lambda_nbl*num_params/nbl.shape[0] 209 | lambda_nbb = lambda_nbb*num_params/nbb.shape[0] 210 | lambda_nbt = lambda_nbt*num_params/nbt.shape[0] 211 | lambda_db = lambda_db*num_params/dbr.shape[0] 212 | lambda_p = lambda_p 213 | print(f'Normalized PINN weights: lambda_r:{lambda_r} lambda_nbl:{lambda_nbl} lambda_nbb:{lambda_nbb} lambda_nbt:{lambda_nbt} lambda_db:{lambda_db} lambda_p:{lambda_p}\n') 214 | print(f'Normalized PINN weights: lambda_r:{lambda_r} lambda_nbl:{lambda_nbl} lambda_nbb:{lambda_nbb} lambda_nbt:{lambda_nbt} lambda_db:{lambda_db} lambda_p:{lambda_p}\n', file = f_rec) 215 | 216 | rl2e = lambda yest, yref : spl.norm(yest - yref, 2) / spl.norm(yref, 2) 217 | infe = lambda yest, yref : spl.norm(yest - yref, np.inf) 218 | lpp = lambda h, href, sigma: np.sum( -(h - href)**2/(2*sigma**2) - 1/2*np.log( 2*np.pi) - 2*np.log(sigma)) 219 | 220 | def pcolormesh(XX, YY, Z, points = None, title = None, savefig = None, cmap='jet'): 221 | fig, ax = plt.subplots(dpi = 300, figsize = (6,4)) 222 | c = ax.pcolormesh(XX, YY, Z, vmin = np.min(Z), vmax = np.max(Z), cmap=cmap) 223 | if points is not None: 224 | plt.plot(points[:,0], points[:,1], 'ko', markersize = 1.0) 225 | fig.colorbar(c, ax=ax, fraction= 0.05, pad= 0.05) 226 | ax.tick_params(axis='both', which = 'major', labelsize=16) 227 | ax.set_xlabel('$x_1$', fontsize=20) 228 | ax.set_ylabel('$x_2$', fontsize=20) 229 | if title is not None: 230 | ax.set_title(title, fontsize=14) 231 | fig.tight_layout() 232 | #ax.set_aspect('equal') 233 | if savefig is not None: 234 | plt.savefig(os.path.join(path_fig,f'{savefig}.png')) 235 | plt.show() 236 | 237 | # Define FNN 238 | def FNN(layers, activation=jnp.tanh): 239 | 240 | def init(prng_key): 
#return a list of (W,b) tuples 241 | def init_layer(key, d_in, d_out): 242 | key1, key2 = random.split(key) 243 | glorot_stddev = 1.0 / jnp.sqrt((d_in + d_out) / 2.) 244 | W = glorot_stddev * random.normal(key1, (d_in, d_out)) 245 | b = jnp.zeros(d_out) 246 | return W, b 247 | key, *keys = random.split(prng_key, len(layers)) 248 | params = list(map(init_layer, keys, layers[:-1], layers[1:])) 249 | return params 250 | 251 | def forward(params, inputs): 252 | Z = inputs 253 | for W, b in params[:-1]: 254 | outputs = jnp.dot(Z, W) + b 255 | Z = activation(outputs) 256 | W, b = params[-1] 257 | outputs = jnp.dot(Z, W) + b 258 | return outputs 259 | 260 | return init, forward 261 | 262 | # Define the model 263 | class PINN(): 264 | def __init__(self, key, layers_k, layers_h, dataset, lbt, ubt, sigma, lambda_r, lambda_d, 265 | lambda_nbl, lambda_nbb, lambda_nbt, lambda_db, lambda_p): 266 | 267 | self.lbt = lbt #domain lower corner 268 | self.ubt = ubt #domain upper corner 269 | self.sigma = sigma 270 | self.lambda_r = lambda_r 271 | self.lambda_d = lambda_d 272 | self.lambda_nbl = lambda_nbl 273 | self.lambda_nbb = lambda_nbb 274 | self.lambda_nbt = lambda_nbt 275 | self.lambda_db = lambda_db 276 | self.lambda_p = lambda_p 277 | 278 | self.gamma = sigma**2*lambda_d 279 | self.sigma_r = jnp.sqrt(self.gamma/self.lambda_r) 280 | self.sigma_d = jnp.sqrt(self.gamma/self.lambda_d) 281 | self.sigma_nbl = jnp.sqrt(self.gamma/self.lambda_nbl) 282 | self.sigma_nbb = jnp.sqrt(self.gamma/self.lambda_nbb) 283 | self.sigma_nbt = jnp.sqrt(self.gamma/self.lambda_nbt) 284 | self.sigma_db = jnp.sqrt(self.gamma/self.lambda_db) 285 | self.sigma_p = jnp.sqrt(self.gamma) 286 | 287 | # Prepare normalized training data 288 | self.dataset = dataset 289 | self.x_res, self.y_res = dataset['res'][:,0:2], dataset['res'][:,2:3] 290 | self.x_dbr, self.y_dbr = dataset['dbr'][:,0:2], dataset['dbr'][:,2:3] 291 | self.x_nbl, self.y_nbl = dataset['nbl'][:,0:2], dataset['nbl'][:,2:3] 292 | self.x_nbt, self.y_nbt = dataset['nbt'][:,0:2], dataset['nbt'][:,2:3] 293 | self.x_nbb, self.y_nbb = dataset['nbb'][:,0:2], dataset['nbb'][:,2:3] 294 | self.x_h, self.y_h = dataset['h_data'][:,0:2], dataset['h_data'][:,2:3] 295 | self.x_k, self.y_k = dataset['k_data'][:,0:2], dataset['k_data'][:,2:3] 296 | self.y_nb = np.hstack((self.y_nbl.ravel(), self.y_nbt.ravel(), self.y_nbb.ravel())) 297 | 298 | # Initalize the network 299 | key, *keys = random.split(key, num = 3) 300 | self.init_k, self.forward_k = FNN(layers_k, activation=jnp.tanh) 301 | self.params_k = self.init_k(keys[0]) 302 | raveled_k, self.unravel_k = ravel_pytree(self.params_k) 303 | self.num_params_k = raveled_k.shape[0] 304 | 305 | self.init_h, self.forward_h = FNN(layers_h, activation=jnp.tanh) 306 | self.params_h = self.init_h(keys[1]) 307 | raveled_h, self.unravel_h = ravel_pytree(self.params_h) 308 | self.num_params_h = raveled_h.shape[0] 309 | self.num_params = self.num_params_k + self.num_params_h 310 | 311 | # Evaluate the state, parameter and the residual over the grid 312 | self.h_pred_map = vmap(self.h_net, (None, 0, 0)) 313 | self.k_pred_map = vmap(self.k_net, (None, 0, 0)) 314 | self.r_pred_map = vmap(self.res_net, (None, 0, 0)) 315 | 316 | # Optimizer 317 | self.itercount = itertools.count() 318 | lr = optimizers.exponential_decay(1e-4, decay_steps=5000, decay_rate=0.9) 319 | self.opt_init, \ 320 | self.opt_update, \ 321 | self.get_params = optimizers.adam(lr) 322 | 323 | self.opt_state_k = self.opt_init(self.params_k) 324 | self.opt_state_h = 
self.opt_init(self.params_h) 325 | 326 | #Define random noise distributions 327 | #for h measurements 328 | self.h_dist = tfd.Independent(tfd.Normal(loc= jnp.zeros_like(self.y_h.ravel()), 329 | scale= self.sigma_d*jnp.ones_like(self.y_h.ravel())), reinterpreted_batch_ndims = 1) 330 | #for k measurements 331 | self.k_dist = tfd.Independent(tfd.Normal(loc= jnp.zeros_like(self.y_k.ravel()), 332 | scale= self.sigma_d*jnp.ones_like(self.y_k.ravel())), reinterpreted_batch_ndims = 1) 333 | # for residual term 334 | self.r_dist = tfd.Independent(tfd.Normal(loc= jnp.zeros_like(self.y_res.ravel()), 335 | scale= self.sigma_r*jnp.ones_like(self.y_res.ravel())), reinterpreted_batch_ndims = 1) 336 | #for bd measurements 337 | self.db_dist = tfd.Independent(tfd.Normal(loc= jnp.zeros_like(self.y_dbr.ravel()), 338 | scale= self.sigma_db*jnp.ones_like(self.y_dbr.ravel())), reinterpreted_batch_ndims = 1) 339 | # for Neumann boundary term 340 | self.nbl_dist = tfd.Independent(tfd.Normal(loc= jnp.zeros_like(self.y_nbl.ravel()), 341 | scale= self.sigma_nbl*jnp.ones_like(self.y_nbl.ravel())), reinterpreted_batch_ndims = 1) 342 | self.nbb_dist = tfd.Independent(tfd.Normal(loc= jnp.zeros_like(self.y_nbb.ravel()), 343 | scale= self.sigma_nbb*jnp.ones_like(self.y_nbb.ravel())), reinterpreted_batch_ndims = 1) 344 | self.nbt_dist = tfd.Independent(tfd.Normal(loc= jnp.zeros_like(self.y_nbt.ravel()), 345 | scale= self.sigma_nbt*jnp.ones_like(self.y_nbt.ravel())), reinterpreted_batch_ndims = 1) 346 | # for regularization term 347 | self.p_dist = tfd.Independent(tfd.Normal(loc= jnp.zeros((self.num_params,)), 348 | scale= self.sigma_p*jnp.ones((self.num_params,))), reinterpreted_batch_ndims = 1) 349 | 350 | @partial(jit, static_argnums=(0,)) 351 | def h_net(self, params, x1, x2): #no problem 352 | inputs = jnp.hstack([x1, x2]) 353 | outputs = self.forward_h(params[1], inputs) 354 | return outputs[0] 355 | 356 | @partial(jit, static_argnums=(0,)) 357 | def k_net(self, params, x1, x2): #no problem 358 | inputs = jnp.hstack([x1, x2]) 359 | outputs = self.forward_k(params[0], inputs) 360 | return outputs[0] 361 | 362 | @partial(jit, static_argnums=(0,)) 363 | def qx(self, params, x1, x2): 364 | k = jnp.exp(self.k_net(params, x1, x2)) 365 | #k = self.k_net(params, x1, x2) 366 | dhdx = grad(self.h_net, argnums=1)(params, x1, x2) 367 | return -k*dhdx 368 | 369 | @partial(jit, static_argnums=(0,)) 370 | def qy(self, params, x1, x2): 371 | k = jnp.exp(self.k_net(params, x1, x2)) 372 | #k = self.k_net(params, x1, x2) 373 | dhdy = grad(self.h_net, argnums=2)(params, x1, x2) 374 | return -k*dhdy 375 | 376 | @partial(jit, static_argnums=(0,)) 377 | def res_net(self, params, x1, x2): 378 | dhdx2 = grad(self.qx, argnums=1)(params, x1, x2) 379 | dhdy2 = grad(self.qy, argnums=2)(params, x1, x2) 380 | return dhdx2 + dhdy2 381 | 382 | #loss function 383 | @partial(jit, static_argnums=(0,)) 384 | def loss_r(self, params, r_noise): 385 | r_pred = vmap(self.res_net, (None, 0, 0))(params, self.x_res[:,0], self.x_res[:,1]) 386 | loss_res = jnp.sum((r_pred.flatten() - self.y_res.flatten() - r_noise)**2) 387 | return loss_res 388 | 389 | @partial(jit, static_argnums=(0,)) 390 | def loss_k(self, params, k_noise): 391 | k_pred = vmap(self.k_net, (None, 0, 0))(params, self.x_k[:,0], self.x_k[:,1]) 392 | loss_k = jnp.sum((k_pred.flatten() - self.y_k.flatten() - k_noise)**2) 393 | return loss_k 394 | 395 | @partial(jit, static_argnums=(0,)) 396 | def loss_h(self, params, h_noise): 397 | h_pred = vmap(self.h_net, (None, 0, 0 ))(params, self.x_h[:,0], 
self.x_h[:,1]) 398 | loss_h = jnp.sum((h_pred.flatten() - self.y_h.flatten() - h_noise)**2) 399 | return loss_h 400 | 401 | @partial(jit, static_argnums=(0,)) 402 | def loss_db(self, params, db_noise): 403 | h_pred = vmap(self.h_net, (None, 0, 0))(params, self.x_dbr[:,0], self.x_dbr[:,1]) 404 | loss_db = jnp.sum((self.y_dbr.flatten() - h_pred.flatten() - db_noise)**2) 405 | return loss_db 406 | 407 | @partial(jit, static_argnums=(0,)) 408 | def loss_nbl(self, params, nbl_noise): 409 | loss_nbl = jnp.sum((vmap(self.qx, (None, 0, 0))(params, self.x_nbl[:,0], self.x_nbl[:,1]).flatten() - self.y_nbl.flatten() - nbl_noise)**2) 410 | return loss_nbl 411 | 412 | @partial(jit, static_argnums=(0,)) 413 | def loss_nbb(self, params, nbb_noise): 414 | loss_nbb = jnp.sum((vmap(self.qy, (None, 0, 0))(params, self.x_nbb[:,0], self.x_nbb[:,1]).flatten() - self.y_nbb.flatten() - nbb_noise)**2) 415 | return loss_nbb 416 | 417 | @partial(jit, static_argnums=(0,)) 418 | def loss_nbt(self, params, nbt_noise): 419 | loss_nbt = jnp.sum((vmap(self.qy, (None, 0, 0))(params, self.x_nbt[:,0], self.x_nbt[:,1]).flatten() - self.y_nbt.flatten() - nbt_noise)**2) 420 | return loss_nbt 421 | 422 | @partial(jit, static_argnums=(0,)) 423 | def l2_reg(self, params, p_noise): 424 | return jnp.sum((ravel_pytree(params)[0] - p_noise)**2) 425 | 426 | @partial(jit, static_argnums=(0,)) 427 | def loss(self, params, k_noise, h_noise, r_noise, db_noise, nbl_noise, nbb_noise, nbt_noise, p_noise): 428 | return 1/self.sigma_r**2*self.loss_r(params, r_noise) + \ 429 | 1/self.sigma_d**2*self.loss_k(params, k_noise) + 1/self.sigma_d**2*self.loss_h(params, h_noise) +\ 430 | 1/self.sigma_db**2*self.loss_db(params, db_noise) + \ 431 | 1/self.sigma_nbl**2*self.loss_nbl(params, nbl_noise) + 1/self.sigma_nbb**2*self.loss_nbb(params, nbb_noise) + 1/self.sigma_nbt**2*self.loss_nbt(params, nbt_noise) +\ 432 | 1/self.sigma_p**2*self.l2_reg(params, p_noise) 433 | 434 | @partial(jit, static_argnums=(0,)) 435 | def step(self, i, opt_state_k, opt_state_h, k_noise, h_noise, r_noise, db_noise, nbl_noise, nbb_noise, nbt_noise, p_noise): 436 | params_k = self.get_params(opt_state_k) 437 | params_h = self.get_params(opt_state_h) 438 | params = [params_k, params_h] 439 | g = grad(self.loss)(params, k_noise, h_noise, r_noise, db_noise, nbl_noise, nbb_noise, nbt_noise, p_noise) 440 | 441 | return self.opt_update(i, g[0], opt_state_k), self.opt_update(i, g[1], opt_state_h) 442 | 443 | # def train(self, nIter, num_print, k_noise, h_noise, r_noise, db_noise, nbl_noise, nbb_noise, nbt_noise, p_noise): 444 | # pbar = trange(nIter) 445 | # # Main training loop 446 | # for it in pbar: 447 | # self.current_count = next(self.itercount) 448 | # self.opt_state_k, self.opt_state_h = self.step(self.current_count, self.opt_state_k, self.opt_state_h, \ 449 | # k_noise, h_noise, r_noise, db_noise, nbl_noise, nbb_noise, nbt_noise, p_noise) 450 | # if it % num_print == 0: 451 | # params_k = self.get_params(self.opt_state_k) 452 | # params_h = self.get_params(self.opt_state_h) 453 | # params = [params_k, params_h] 454 | 455 | # loss_value = self.loss(params,k_noise, h_noise, r_noise, db_noise, nbl_noise, nbb_noise, nbt_noise, p_noise) 456 | # loss_r_value = self.loss_r(params, r_noise) 457 | # loss_k_value = self.loss_k(params, k_noise) 458 | # loss_h_value = self.loss_h(params, h_noise) 459 | 460 | # #loss_reg_value = self.l2_reg(params[0]) + self.l2_reg(params[1]) 461 | 462 | # pbar.set_postfix({'Loss': loss_value, 463 | # 'Loss_r': loss_r_value, 464 | # 'Loss_k': 
loss_k_value, 465 | # 'Loss_h': loss_h_value 466 | # }) 467 | 468 | # return [self.get_params(self.opt_state_k), self.get_params(self.opt_state_h)] 469 | 470 | def train(self, nIter, num_print, k_noise, h_noise, r_noise, db_noise, nbl_noise, nbb_noise, nbt_noise, p_noise): 471 | pbar = trange(nIter) 472 | # Main training loop 473 | for it in pbar: 474 | self.current_count = next(self.itercount) 475 | self.opt_state_k, self.opt_state_h = self.step(self.current_count, self.opt_state_k, self.opt_state_h, \ 476 | k_noise, h_noise, r_noise, db_noise, nbl_noise, nbb_noise, nbt_noise, p_noise) 477 | if it % num_print == 0: 478 | params_k = self.get_params(self.opt_state_k) 479 | params_h = self.get_params(self.opt_state_h) 480 | params = [params_k, params_h] 481 | 482 | loss_value = self.loss(params, k_noise, h_noise, r_noise, db_noise, nbl_noise, nbb_noise, nbt_noise, p_noise) 483 | 484 | pbar.set_postfix({'Loss': loss_value}) 485 | 486 | return [self.get_params(self.opt_state_k), self.get_params(self.opt_state_h)] 487 | 488 | def rpinn_sample(self, Nsample, nIter, num_print, key): 489 | #sample with randomized PINN 490 | params_sample = [] 491 | # alpha_sample = [] 492 | # beta_sample = [] 493 | # omega_sample = [] 494 | 495 | for it in range(Nsample): 496 | key, *keys = random.split(key, 11) 497 | #key, *keys = random.split(key, 4) 498 | k_noise = self.k_dist.sample(1, keys[0])[0] 499 | h_noise = self.h_dist.sample(1, keys[1])[0] 500 | r_noise = self.r_dist.sample(1, keys[2])[0] 501 | db_noise = self.db_dist.sample(1, keys[3])[0] 502 | nbl_noise = self.nbl_dist.sample(1, keys[4])[0] 503 | nbb_noise = self.nbb_dist.sample(1, keys[5])[0] 504 | nbt_noise = self.nbt_dist.sample(1, keys[6])[0] 505 | p_noise = self.p_dist.sample(1, keys[7])[0] 506 | 507 | lr = optimizers.exponential_decay(1e-4, decay_steps=5000, decay_rate=0.9) 508 | self.opt_init, \ 509 | self.opt_update, \ 510 | self.get_params = optimizers.adam(lr) 511 | self.params_k = self.init_k(keys[8]) 512 | self.params_h = self.init_h(keys[9]) #re-initialize the h-network with its own initializer 513 | self.opt_state_k = self.opt_init(self.params_k) 514 | self.opt_state_h = self.opt_init(self.params_h) 515 | 516 | self.itercount = itertools.count() 517 | params = self.train(nIter, num_print, k_noise, h_noise, r_noise, db_noise, nbl_noise, nbb_noise, nbt_noise, p_noise) 518 | 519 | params_sample.append(ravel_pytree(params)[0]) 520 | print(f'{it}-th sample finished') 521 | return jnp.array(params_sample) 522 | 523 | def de_sample(self, Nsample, nIter, num_print, key): 524 | #sample with deep ensemble 525 | params_sample = [] 526 | 527 | for it in range(Nsample): 528 | key, *keys = random.split(key, 3) 529 | k_noise = 0 530 | h_noise = 0 531 | r_noise = 0 532 | db_noise = 0 533 | nbl_noise = 0 534 | nbb_noise = 0 535 | nbt_noise = 0 536 | p_noise = 0 537 | 538 | lr = optimizers.exponential_decay(1e-4, decay_steps=5000, decay_rate=0.9) 539 | self.opt_init, \ 540 | self.opt_update, \ 541 | self.get_params = optimizers.adam(lr) 542 | self.params_k = self.init_k(keys[0]) 543 | self.params_h = self.init_h(keys[1]) #re-initialize the h-network with its own initializer 544 | self.opt_state_k = self.opt_init(self.params_k) 545 | self.opt_state_h = self.opt_init(self.params_h) 546 | 547 | self.itercount = itertools.count() 548 | params = self.train(nIter, num_print, k_noise, h_noise, r_noise, db_noise, nbl_noise, nbb_noise, nbt_noise, p_noise) 549 | 550 | params_sample.append(ravel_pytree(params)[0]) 551 | print(f'{it}-th sample finished') 552 | return jnp.array(params_sample) 553 | 554 | 555 | key = random.PRNGKey(rand_seed) 556 | key, subkey = 
random.split(key, 2) 557 | pinn = PINN(key, layers_k, layers_h, dataset, lbt, ubt, sigma, lambda_r, lambda_d, 558 | lambda_nbl, lambda_nbb, lambda_nbt, lambda_db, lambda_p) 559 | if model_load == False: 560 | if method == 'rPINN': 561 | samples = pinn.rpinn_sample(Nsamples, nIter = nIter, 562 | num_print = num_print, key = subkey) 563 | elif method == 'rPINN_metro': 564 | samples = pinn.rpinn_sample_metro(Nsamples, nIter = nIter, 565 | num_print = num_print, key = subkey) 566 | else: 567 | samples = pinn.de_sample(Nsamples, nIter = nIter, num_print = num_print, key = subkey) 568 | 569 | np.savetxt(os.path.join(path_f,f'{method}_posterior_samples_Nres_{Nres}_sigma_{sigma}_Nsamples_{Nsamples}_nIter_{nIter}.out'), samples) 570 | else: 571 | samples = np.loadtxt(os.path.join(path_f,f'{method}_posterior_samples_Nres_{Nres}_sigma_{sigma}_Nsamples_{Nsamples}_nIter_{nIter}.out')) 572 | 573 | print(f'sigma_r:{pinn.sigma_r} sigma_d:{pinn.sigma_d} sigma_nbl:{pinn.sigma_nbl} sigma_nbb:{pinn.sigma_nbb} sigma_nbt:{pinn.sigma_nbt} sigma_db:{pinn.sigma_db} sigma_p:{pinn.sigma_p}\n') 574 | print(f'sigma_r:{pinn.sigma_r} sigma_d:{pinn.sigma_d} sigma_nbl:{pinn.sigma_nbl} sigma_nbb:{pinn.sigma_nbb} sigma_nbt:{pinn.sigma_nbt} sigma_db:{pinn.sigma_db} sigma_p:{pinn.sigma_p}', file = f_rec) 575 | 576 | @jit 577 | def get_h_pred(sample): 578 | params_k = pinn.unravel_k(sample[:pinn.num_params_k]) 579 | params_h = pinn.unravel_h(sample[pinn.num_params_k:]) 580 | params = [params_k, params_h] 581 | return pinn.h_pred_map(params,x_ref[:,0], x_ref[:,1]) 582 | 583 | @jit 584 | def get_k_pred(sample): 585 | params_k = pinn.unravel_k(sample[:pinn.num_params_k]) 586 | params_h = pinn.unravel_h(sample[pinn.num_params_k:]) 587 | params = [params_k, params_h] 588 | return pinn.k_pred_map(params,x_ref[:,0], x_ref[:,1]) 589 | 590 | @jit 591 | def get_r_pred(sample): 592 | params_k = pinn.unravel_k(sample[:pinn.num_params_k]) 593 | params_h = pinn.unravel_h(sample[pinn.num_params_k:]) 594 | params = [params_k, params_h] 595 | return pinn.r_pred_map(params,x_ref[:,0], x_ref[:,1]) 596 | 597 | # h_pred_ens = np.array([vmap(get_h_pred)(samples[i,:,:]) for i in range(samples.shape[0])]) #(Nchains, Nsamples, 32768) 598 | # k_pred_ens = np.array([vmap(get_k_pred)(samples[i,:,:]) for i in range(samples.shape[0])]) 599 | 600 | h_pred_ens = vmap(get_h_pred)(samples[::,:]) 601 | k_pred_ens = vmap(get_k_pred)(samples[::,:]) 602 | 603 | h_pred_ens_mean = np.mean(h_pred_ens, axis = 0) #(Nchains, 32768) 604 | h_pred_ens_std = np.std(h_pred_ens, axis = 0) 605 | k_pred_ens_mean = np.mean(k_pred_ens, axis = 0) 606 | k_pred_ens_std = np.std(k_pred_ens, axis = 0) 607 | 608 | h_env = np.logical_and( (h_pred_ens_mean < h_ref.ravel() + 2*h_pred_ens_std), (h_pred_ens_mean > h_ref.ravel() - 2*h_pred_ens_std) ) 609 | k_env = np.logical_and( (k_pred_ens_mean < k_ref.ravel() + 2*k_pred_ens_std), (k_pred_ens_mean > k_ref.ravel() - 2*k_pred_ens_std) ) 610 | 611 | rl2e_h = rl2e(h_pred_ens_mean, h_ref) 612 | infe_h = infe(h_pred_ens_mean, h_ref) 613 | lpp_h = lpp(h_pred_ens_mean, h_ref, h_pred_ens_std) 614 | rl2e_k = rl2e(k_pred_ens_mean, k_ref) 615 | infe_k = infe(k_pred_ens_mean, k_ref) 616 | lpp_k = lpp(k_pred_ens_mean, k_ref, k_pred_ens_std) 617 | 618 | print('h prediction:\n') 619 | print('Relative RL2 error: {}'.format(rl2e_h)) 620 | print('Absolute inf error: {}'.format(infe_h)) 621 | print('Average standard deviation: {}'.format(np.mean(h_pred_ens_std))) 622 | print('log predictive probability: {}'.format(lpp_h)) 623 | print('Percentage of 
coverage:{}\n'.format(np.sum(h_env)/32768)) 624 | 625 | print('k prediction:\n') 626 | print('Relative RL2 error: {}'.format(rl2e_k)) 627 | print('Absolute inf error: {}'.format(infe_k)) 628 | print('Average standard deviation: {}'.format(np.mean(k_pred_ens_std))) 629 | print('log predictive probability: {}'.format(lpp_k)) 630 | print('Percentage of coverage:{}\n'.format(np.sum(k_env)/32768)) 631 | 632 | print('h prediction:\n', file = f_rec) 633 | print('Relative RL2 error: {}'.format(rl2e_h), file = f_rec) 634 | print('Absolute inf error: {}'.format(infe_h), file = f_rec) 635 | print('Average standard deviation: {}'.format(np.mean(h_pred_ens_std)), file = f_rec) 636 | print('log predictive probability: {}'.format(lpp_h), file = f_rec) 637 | print('Percentage of coverage:{}\n'.format(np.sum(h_env)/32768), file = f_rec) 638 | 639 | print('k prediction:\n', file = f_rec) 640 | print('Relative RL2 error: {}'.format(rl2e_k), file = f_rec) 641 | print('Absolute inf error: {}'.format(infe_k), file = f_rec) 642 | print('Average standard deviation: {}'.format(np.mean(k_pred_ens_std)), file = f_rec) 643 | print('log predictive probability: {}'.format(lpp_k), file = f_rec) 644 | print('Percentage of coverage:{}\n'.format(np.sum(k_env)/32768), file = f_rec) 645 | 646 | #Plot of k field 647 | pcolormesh(XX, YY, k_ref.reshape(128,256), points = None, title = None, savefig = 'y_ref') 648 | savefig = f'2D_gwf_rPINN_sigma_{sigma}_Nsamples_{Nsamples}_ypred_mean' 649 | pcolormesh(XX, YY, k_pred_ens_mean.reshape(128,256), points = None, title = None, savefig = savefig) 650 | savefig = f'2D_gwf_rPINN_sigma_{sigma}_Nsamples_{Nsamples}_ypred_std' 651 | pcolormesh(XX, YY, k_pred_ens_std.reshape(128,256), points = x_k, title = None, savefig = savefig) 652 | savefig = f'2D_gwf_rPINN_sigma_{sigma}_Nsamples_{Nsamples}_ypred_diff' 653 | pcolormesh(XX, YY, np.abs(k_pred_ens_mean.reshape(128,256) - k_ref.reshape(128,256)), points = x_k, title = None, savefig = savefig) 654 | savefig = f'2D_gwf_rPINN_sigma_{sigma}_Nsamples_{Nsamples}_ypred_env' 655 | pcolormesh(XX, YY, k_env.reshape(128,256), points = x_k, title = None, savefig = savefig) 656 | 657 | #Plot of h field 658 | pcolormesh(XX, YY, h_ref.reshape(128,256), points = None, title = None, savefig = 'h_ref') 659 | savefig = f'2D_gwf_rPINN_sigma_{sigma}_Nsamples_{Nsamples}_hpred_mean' 660 | pcolormesh(XX, YY, h_pred_ens_mean.reshape(128,256), points = None, title = None, savefig = savefig ) 661 | savefig = f'2D_gwf_rPINN_sigma_{sigma}_Nsamples_{Nsamples}_hpred_std' 662 | pcolormesh(XX, YY, h_pred_ens_std.reshape(128,256), points = x_h, title = None, savefig = savefig ) 663 | savefig = f'2D_gwf_rPINN_sigma_{sigma}_Nsamples_{Nsamples}_hpred_diff' 664 | pcolormesh(XX, YY, np.abs(h_pred_ens_mean.reshape(128,256) - h_ref.reshape(128,256)), points = x_h, title = None, savefig = savefig ) 665 | savefig = f'2D_gwf_rPINN_sigma_{sigma}_Nsamples_{Nsamples}_hpred_env' 666 | pcolormesh(XX, YY, h_env.reshape(128,256), points = x_h, title = None, savefig = savefig ) 667 | 668 | f_rec.close() -------------------------------------------------------------------------------- /2D_GWF_problem/Nk_40_Nh_40_randseed_111_idxh.out: -------------------------------------------------------------------------------- 1 | 1.446500000000000000e+04 2 | 1.556200000000000000e+04 3 | 1.932200000000000000e+04 4 | 1.803700000000000000e+04 5 | 1.553100000000000000e+04 6 | 7.811000000000000000e+03 7 | 2.740100000000000000e+04 8 | 2.190800000000000000e+04 9 | 2.944000000000000000e+03 10 | 
2.846800000000000000e+04 11 | 2.095900000000000000e+04 12 | 2.348500000000000000e+04 13 | 8.834000000000000000e+03 14 | 2.893600000000000000e+04 15 | 8.292000000000000000e+03 16 | 1.991200000000000000e+04 17 | 1.042600000000000000e+04 18 | 6.820000000000000000e+03 19 | 3.176700000000000000e+04 20 | 1.295500000000000000e+04 21 | 1.502000000000000000e+03 22 | 2.298200000000000000e+04 23 | 1.578100000000000000e+04 24 | 4.358000000000000000e+03 25 | 1.139000000000000000e+04 26 | 1.264000000000000000e+04 27 | 2.831600000000000000e+04 28 | 6.432000000000000000e+03 29 | 2.699100000000000000e+04 30 | 2.716000000000000000e+04 31 | 1.952000000000000000e+03 32 | 2.046200000000000000e+04 33 | 2.541200000000000000e+04 34 | 2.568600000000000000e+04 35 | 2.177000000000000000e+03 36 | 1.741500000000000000e+04 37 | 2.005000000000000000e+03 38 | 1.383000000000000000e+04 39 | 2.605100000000000000e+04 40 | 3.122000000000000000e+04 41 | -------------------------------------------------------------------------------- /2D_GWF_problem/Nk_40_Nh_40_randseed_111_idxk.out: -------------------------------------------------------------------------------- 1 | 1.290000000000000000e+03 2 | 1.781500000000000000e+04 3 | 8.607000000000000000e+03 4 | 1.330000000000000000e+03 5 | 1.359900000000000000e+04 6 | 2.952500000000000000e+04 7 | 1.576700000000000000e+04 8 | 3.267700000000000000e+04 9 | 5.655000000000000000e+03 10 | 3.901000000000000000e+03 11 | 6.662000000000000000e+03 12 | 1.372700000000000000e+04 13 | 1.148800000000000000e+04 14 | 6.140000000000000000e+03 15 | 1.188700000000000000e+04 16 | 2.959600000000000000e+04 17 | 2.178700000000000000e+04 18 | 4.667000000000000000e+03 19 | 2.136900000000000000e+04 20 | 2.220000000000000000e+03 21 | 2.618800000000000000e+04 22 | 9.147000000000000000e+03 23 | 2.339600000000000000e+04 24 | 2.256900000000000000e+04 25 | 2.616000000000000000e+03 26 | 1.795800000000000000e+04 27 | 1.493600000000000000e+04 28 | 2.899800000000000000e+04 29 | 1.893600000000000000e+04 30 | 1.604300000000000000e+04 31 | 1.609800000000000000e+04 32 | 1.559000000000000000e+03 33 | 8.383000000000000000e+03 34 | 1.781700000000000000e+04 35 | 2.547400000000000000e+04 36 | 2.177500000000000000e+04 37 | 2.467000000000000000e+04 38 | 1.545200000000000000e+04 39 | 2.762900000000000000e+04 40 | 3.307000000000000000e+03 41 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Randomized Physics-informed Neural Networks (rPINN) 2 | 3 | Implementations of the "randomize-then-optimize" approach for sampling Physics-informed Neural Network (PINN) posteriors in the Bayesian framework, as described in the paper: 4 | 5 | > **Randomized physics-informed neural networks for Bayesian data assimilation** 6 | > *Computer Methods in Applied Mechanics and Engineering, 2025* 7 | > [https://doi.org/10.1016/j.cma.2024.117670](https://doi.org/10.1016/j.cma.2024.117670) 8 | 9 | --- 10 | 11 | ## Overview 12 | This repo contains JAX-based implementations of randomized PINNs (rPINNs), which inject randomness into the PINN loss function to generate posterior samples for inverse problems governed by partial differential equations (PDEs). This method provides a scalable alternative to traditional Bayesian inference techniques for physics-constrained problems. (A minimal illustrative sketch of the randomize-then-optimize loop is given at the end of this document.) 13 | 14 | Key features: 15 | - Efficient and scalable uncertainty quantification in PDE-constrained inverse problems with data assimilation 16 | - Outperforms Hamiltonian Monte Carlo (HMC) and Stein Variational Gradient Descent (SVGD), which suffer from the curse of dimensionality and from ill-conditioned posterior covariance structures; rPINN is also highly parallelizable. 17 | - A weighted-likelihood Bayesian PINN formulation that balances the contributions of the different loss terms (e.g., PDE, IC, and BC residuals, and measurements). 18 | 19 | 20 | The following inverse PDE problems are included: 21 | - 1D Linear Poisson Equation 22 | - 1D Non-Linear Poisson Equation 23 | - 2D Diffusion Equation with Spatially Varying Coefficient 24 | 25 | If you find this code useful for your research, please cite the following paper: 26 | ```bibtex 27 | @article{zong2025randomized, 28 | title={Randomized physics-informed neural networks for Bayesian data assimilation}, 29 | author={Zong, Yifei and Barajas-Solano, David and Tartakovsky, Alexandre M}, 30 | journal={Computer Methods in Applied Mechanics and Engineering}, 31 | volume={436}, 32 | pages={117670}, 33 | year={2025}, 34 | publisher={Elsevier} 35 | } 36 | ``` 37 | 38 | We have also published a companion paper on the randomized physics-informed conditional Karhunen-Loève expansion (rPICKLE) method for uncertainty quantification in high-dimensional PDE-constrained inverse problems, with a real-world application to the Hanford Site subsurface flow problem. 39 | ```bibtex 40 | @article{zong2024randomized, 41 | title={Randomized physics-informed machine learning for uncertainty quantification in high-dimensional inverse problems}, 42 | author={Zong, Yifei and Barajas-Solano, David and Tartakovsky, Alexandre M}, 43 | journal={Journal of Computational Physics}, 44 | volume={519}, 45 | pages={113395}, 46 | year={2024}, 47 | publisher={Elsevier} 48 | } 49 | ``` 50 | -------------------------------------------------------------------------------- /rPINN_schematization.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geekyifei/randomized-PINN/b1d735cd3cf35e741db68e45cb4fca99e282da66/rPINN_schematization.PNG --------------------------------------------------------------------------------
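The repository scripts above implement the randomize-then-optimize idea described in the README for specific PDE-constrained inverse problems. As a reading aid, the following minimal JAX sketch shows the core loop on a toy 1D regression problem; it is not part of the repository, and the names `toy_forward`, `perturbed_loss`, and `rto_sample` are illustrative only. Each posterior draw is obtained by perturbing the measurements and the prior mean with random noise, re-initializing the network, and minimizing the perturbed regularized loss.

```python
# Illustrative randomize-then-optimize sketch (hypothetical names; not repository code).
import jax
import jax.numpy as jnp
from jax import random, value_and_grad
from jax.example_libraries import optimizers

def init_params(key, layers=(1, 20, 1)):
    # Fully connected network parameters as a list of (W, b) tuples
    params = []
    for d_in, d_out in zip(layers[:-1], layers[1:]):
        key, sub = random.split(key)
        params.append((random.normal(sub, (d_in, d_out)) / jnp.sqrt(d_in), jnp.zeros(d_out)))
    return params

def toy_forward(params, x):
    z = x
    for W, b in params[:-1]:
        z = jnp.tanh(z @ W + b)
    W, b = params[-1]
    return z @ W + b

def perturbed_loss(params, x, y, data_noise, prior_noise, sigma_d, sigma_p):
    # Likelihood term fitted to randomly shifted measurements
    misfit = jnp.sum((toy_forward(params, x) - y - data_noise) ** 2) / sigma_d**2
    # Prior (L2) term centered at a random draw from the prior
    flat = jnp.concatenate([jnp.ravel(p) for layer in params for p in layer])
    return misfit + jnp.sum((flat - prior_noise) ** 2) / sigma_p**2

def rto_sample(key, x, y, n_samples=10, n_steps=2000, sigma_d=0.1, sigma_p=1.0):
    opt_init, opt_update, get_params = optimizers.adam(1e-3)
    n_params = sum(p.size for layer in init_params(key) for p in layer)
    samples = []
    for _ in range(n_samples):
        key, k1, k2, k3 = random.split(key, 4)
        data_noise = sigma_d * random.normal(k1, y.shape)        # perturb the data
        prior_noise = sigma_p * random.normal(k2, (n_params,))   # perturb the prior mean
        state = opt_init(init_params(k3))                        # fresh network initialization

        @jax.jit
        def step(i, state):
            loss, g = value_and_grad(perturbed_loss)(
                get_params(state), x, y, data_noise, prior_noise, sigma_d, sigma_p)
            return opt_update(i, g, state), loss

        for i in range(n_steps):
            state, _ = step(i, state)
        samples.append(get_params(state))  # one approximate posterior draw
    return samples

if __name__ == "__main__":
    key = random.PRNGKey(0)
    x = jnp.linspace(-1.0, 1.0, 32)[:, None]
    y = jnp.sin(jnp.pi * x) + 0.1 * random.normal(key, x.shape)
    draws = rto_sample(key, x, y, n_samples=3, n_steps=500)
    preds = jnp.stack([toy_forward(p, x) for p in draws])
    print("predictive mean/std at x = 0:", preds.mean(0)[16], preds.std(0)[16])
```

In the repository scripts, the same loop additionally draws independent perturbations for the PDE residual and for each boundary-condition term, with the per-term noise scales obtained from the weighted-likelihood normalization reported at the start of each run.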