├── Allen_Cahn └── AC.ipynb ├── KS ├── chaotic_KS.py └── regular_KS.py ├── LICENSE ├── Lorentz └── Causal_PINNs_lorentz.py ├── NS └── NS.py ├── README.md ├── animations ├── AC.mp4 ├── KS.mp4 ├── NS.mp4 └── Readme.md ├── data ├── AC.mat ├── NS.npy ├── ks_chaotic.mat └── ks_simple.mat └── requirements.txt /KS/chaotic_KS.py: -------------------------------------------------------------------------------- 1 | import numpy as onp 2 | import jax.numpy as np 3 | from jax import random, grad, vmap, jit, jacfwd, jacrev 4 | from jax.example_libraries import optimizers 5 | from jax.experimental.ode import odeint 6 | from jax.experimental.jet import jet 7 | from jax.nn import relu 8 | from jax.config import config 9 | from jax import lax 10 | from jax.flatten_util import ravel_pytree 11 | import itertools 12 | from functools import partial 13 | from torch.utils import data 14 | from tqdm import trange 15 | 16 | import scipy.io 17 | from scipy.interpolate import griddata 18 | import matplotlib.pyplot as plt 19 | 20 | 21 | # Define the neural net 22 | def modified_MLP(layers, L=1.0, M_t=1, M_x=1, activation=relu): 23 | def xavier_init(key, d_in, d_out): 24 | glorot_stddev = 1. / np.sqrt((d_in + d_out) / 2.) 25 | W = glorot_stddev * random.normal(key, (d_in, d_out)) 26 | b = np.zeros(d_out) 27 | return W, b 28 | 29 | # Define input encoding function 30 | def input_encoding(t, x): 31 | w = 2 * np.pi / L 32 | k_t = np.power(10, np.arange(-M_t//2, M_t//2)) 33 | k_x = np.arange(1, M_x + 1) 34 | 35 | out = np.hstack([k_t * t , 36 | 1, np.cos(k_x * w * x), np.sin(k_x * w * x)]) 37 | return out 38 | 39 | 40 | def init(rng_key): 41 | U1, b1 = xavier_init(random.PRNGKey(12345), layers[0], layers[1]) 42 | U2, b2 = xavier_init(random.PRNGKey(54321), layers[0], layers[1]) 43 | def init_layer(key, d_in, d_out): 44 | k1, k2 = random.split(key) 45 | W, b = xavier_init(k1, d_in, d_out) 46 | return W, b 47 | key, *keys = random.split(rng_key, len(layers)) 48 | params = list(map(init_layer, keys, layers[:-1], layers[1:])) 49 | return (params, U1, b1, U2, b2) 50 | 51 | def apply(params, inputs): 52 | params, U1, b1, U2, b2 = params 53 | 54 | t = inputs[0] 55 | x = inputs[1] 56 | inputs = input_encoding(t, x) 57 | U = activation(np.dot(inputs, U1) + b1) 58 | V = activation(np.dot(inputs, U2) + b2) 59 | for W, b in params[:-1]: 60 | outputs = activation(np.dot(inputs, W) + b) 61 | inputs = np.multiply(outputs, U) + np.multiply(1 - outputs, V) 62 | W, b = params[-1] 63 | outputs = np.dot(inputs, W) + b 64 | return outputs 65 | return init, apply 66 | 67 | 68 | class DataGenerator(data.Dataset): 69 | def __init__(self, t0, t1, n_t=10, n_x=64, rng_key=random.PRNGKey(1234)): 70 | 'Initialization' 71 | self.t0 = t0 72 | self.t1 = (1 + 0.01) * t1 73 | self.n_t = n_t 74 | self.n_x = n_x 75 | self.key = rng_key 76 | 77 | def __getitem__(self, index): 78 | 'Generate one batch of data' 79 | self.key, subkey = random.split(self.key) 80 | batch = self.__data_generation(subkey) 81 | return batch 82 | 83 | @partial(jit, static_argnums=(0,)) 84 | def __data_generation(self, key): 85 | 'Generates data containing batch_size samples' 86 | subkeys = random.split(key, 2) 87 | t_r = random.uniform(subkeys[0], shape=(self.n_t,), minval=self.t0, maxval=self.t1).sort() 88 | x_r = random.uniform(subkeys[1], shape=(self.n_x,), minval=0.0, maxval=2.0*np.pi) 89 | batch = (t_r, x_r) 90 | return batch 91 | 92 | 93 | # Define the model 94 | class PINN: 95 | def __init__(self, key, u_exact, arch, layers, M_t, M_x, state0, t0, t1, n_t, n_x, tol): 96 | 97 | 
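The `input_encoding` inside `modified_MLP` above embeds (t, x) as a handful of rescaled time features plus a truncated Fourier basis in x, so the network is exactly 2π-periodic in space and no separate boundary-condition loss is needed. A minimal standalone sketch of the same construction (using the values this script sets later, M_t = 6, M_x = 5, L = 2π) shows where the input width d0 = M_t + 1 + 2·M_x used in `PINN.__init__` comes from:

```python
import jax.numpy as jnp

def input_encoding(t, x, L=2.0 * jnp.pi, M_t=6, M_x=5):
    # Mirrors modified_MLP's input_encoding above: time is spread over powers of 10,
    # space enters through M_x Fourier modes, making the embedding periodic in x.
    w = 2.0 * jnp.pi / L
    k_t = jnp.power(10.0, jnp.arange(-M_t // 2, M_t // 2))
    k_x = jnp.arange(1, M_x + 1)
    return jnp.hstack([k_t * t, jnp.ones(1),
                       jnp.cos(k_x * w * x), jnp.sin(k_x * w * x)])

emb = input_encoding(0.3, 1.7)
assert emb.shape == (6 + 1 + 2 * 5,)                                   # = d0
assert jnp.allclose(emb, input_encoding(0.3, 1.7 + 2.0 * jnp.pi), atol=1e-5)
```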
self.u_exact = u_exact 98 | 99 | self.M_t = M_t 100 | self.M_x = M_x 101 | 102 | # grid 103 | self.n_t = n_t 104 | self.n_x = n_x 105 | 106 | self.t0 = t0 107 | self.t1 = t1 108 | eps = 0.01 * self.t1 109 | self.t_r = np.linspace(self.t0, self.t1 + eps, n_t) 110 | self.x_r = np.linspace(0, 2.0 * np.pi, n_x) 111 | 112 | # IC 113 | t_ic = np.zeros((x_star.shape[0], 1)) 114 | x_ic = x_star.reshape(-1, 1) 115 | self.X_ic = np.hstack([t_ic, x_ic]) 116 | self.Y_ic = state0 117 | 118 | # Weight matrix 119 | self.M = np.triu(np.ones((n_t, n_t)), k=1).T 120 | self.tol = tol 121 | 122 | 123 | d0 = 2 * M_x + M_t + 1 124 | layers = [d0] + layers 125 | self.init, self.apply = modified_MLP(layers, L=2.0*np.pi, M_t=self.M_t, M_x=self.M_x, activation=np.tanh) 126 | params = self.init(rng_key = key) 127 | 128 | # Use optimizers to set optimizer initialization and update functions 129 | self.opt_init, self.opt_update, self.get_params = optimizers.adam(optimizers.exponential_decay(1e-3, 130 | decay_steps=5000, 131 | decay_rate=0.9)) 132 | self.opt_state = self.opt_init(params) 133 | _, self.unravel = ravel_pytree(params) 134 | 135 | 136 | self.u_pred_fn = vmap(vmap(self.neural_net, (None, 0, None)), (None, None, 0)) # consistent with the dataset 137 | self.r_pred_fn = vmap(vmap(self.residual_net, (None, None, 0)), (None, 0, None)) 138 | 139 | # Logger 140 | self.itercount = itertools.count() 141 | 142 | self.l2_error_log = [] 143 | self.loss_log = [] 144 | self.loss_ics_log = [] 145 | self.loss_res_log = [] 146 | 147 | def neural_net(self, params, t, x): 148 | z = np.stack([t, x]) 149 | outputs = self.apply(params, z) 150 | return outputs[0] 151 | 152 | def residual_net(self, params, t, x): 153 | u = self.neural_net(params, t, x) 154 | u_t = grad(self.neural_net, argnums=1)(params, t, x) 155 | 156 | u_fn = lambda x: self.neural_net(params, t, x) 157 | _, (u_x, u_xx, u_xxx, u_xxxx) = jet(u_fn, (x, ), [[1.0, 0.0, 0.0, 0.0]]) 158 | 159 | return u_t + 100.0 / 16.0 * u * u_x + 100.0 / 16.0**2 * u_xx + 100.0 / 16.0**4 * u_xxxx 160 | 161 | 162 | @partial(jit, static_argnums=(0,)) 163 | def residuals_and_weights(self, params, batch, tol): 164 | t_r, x_r = batch 165 | L_0 = 1e4 * self.loss_ics(params) 166 | r_pred = self.r_pred_fn(params, t_r, x_r) 167 | L_t = np.mean(r_pred**2, axis=1) 168 | W = lax.stop_gradient(np.exp(- tol * (self.M @ L_t + L_0) )) 169 | return L_0, L_t, W 170 | 171 | @partial(jit, static_argnums=(0,)) 172 | def loss_ics(self, params): 173 | # Compute forward pass 174 | u_pred = vmap(self.neural_net, (None, 0, 0))(params, self.X_ic[:,0], self.X_ic[:,1]) 175 | # Compute loss 176 | loss_ics = np.mean((self.Y_ic.flatten() - u_pred.flatten())**2) 177 | return loss_ics 178 | 179 | 180 | @partial(jit, static_argnums=(0,)) 181 | def loss_res(self, params, batch): 182 | t_r, x_r = batch 183 | # Compute forward pass 184 | r_pred = self.r_pred_fn(params, t_r, x_r) 185 | # Compute loss 186 | loss_r = np.mean(r_pred**2) 187 | return loss_r 188 | 189 | @partial(jit, static_argnums=(0,)) 190 | def loss(self, params, batch): 191 | L_0, L_t, W = self.residuals_and_weights(params, batch, self.tol) 192 | # Compute loss 193 | loss = np.mean(W * L_t + L_0) 194 | return loss 195 | 196 | @partial(jit, static_argnums=(0,)) 197 | def compute_l2_error(self, params): 198 | u_pred = self.u_pred_fn(params, t_star[:num_step], x_star) 199 | l2_error = np.linalg.norm(u_pred - self.u_exact) / np.linalg.norm(self.u_exact) 200 | return l2_error 201 | 202 | # Define a compiled update step 203 | @partial(jit, static_argnums=(0,)) 
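`residual_net` above pulls all four spatial derivatives of u out of a single `jax.experimental.jet` call rather than nesting `grad` four times: seeding the Taylor series with the direction (1, 0, 0, 0) propagates a truncated expansion of x ↦ u(t, x) in one pass, and the returned terms are consumed directly as u_x, u_xx, u_xxx, u_xxxx in the Kuramoto-Sivashinsky residual. A toy illustration on a closed-form function (my own example, not part of the script):

```python
import jax.numpy as jnp
from jax import grad
from jax.experimental.jet import jet

f = lambda x: jnp.sin(2.0 * x)      # stand-in for x -> u(t, x) at a fixed t
x0 = 0.7

# One Taylor-mode pass along the unit x-direction, same call pattern as residual_net.
_, (f_x, f_xx, f_xxx, f_xxxx) = jet(f, (x0,), [[1.0, 0.0, 0.0, 0.0]])

assert jnp.allclose(f_x, grad(f)(x0))                 # first derivative matches grad
print(f_xxxx, grad(grad(grad(grad(f))))(x0))          # compare with four nested grads
```

For a fourth derivative this Taylor-mode pass is typically much cheaper than composing `grad` four times, which is why both KS scripts use it.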
204 | def step(self, i, opt_state, batch): 205 | params = self.get_params(opt_state) 206 | g = grad(self.loss)(params, batch) 207 | 208 | return self.opt_update(i, g, opt_state) 209 | 210 | # Optimize parameters in a loop 211 | def train(self, dataset, nIter = 10000): 212 | res_data = iter(dataset) 213 | pbar = trange(nIter) 214 | # Main training loop 215 | for it in pbar: 216 | batch = next(res_data) 217 | self.current_count = next(self.itercount) 218 | self.opt_state = self.step(self.current_count, self.opt_state, batch) 219 | 220 | if it % 1000 == 0: 221 | params = self.get_params(self.opt_state) 222 | 223 | 224 | l2_error_value = self.compute_l2_error(params) 225 | loss_value = self.loss(params, batch) 226 | 227 | loss_ics_value = self.loss_ics(params) 228 | loss_res_value = self.loss_res(params, batch) 229 | 230 | _, _, W_value = self.residuals_and_weights(params, batch, self.tol) 231 | 232 | self.l2_error_log.append(l2_error_value) 233 | self.loss_log.append(loss_value) 234 | self.loss_ics_log.append(loss_ics_value) 235 | self.loss_res_log.append(loss_res_value) 236 | 237 | pbar.set_postfix({'l2 error': l2_error_value, 238 | 'Loss': loss_value, 239 | 'loss_ics' : loss_ics_value, 240 | 'loss_res': loss_res_value, 241 | 'W_min' : W_value.min()}) 242 | 243 | if W_value.min() > 0.99: 244 | break 245 | 246 | # Evaluates predictions at test points 247 | @partial(jit, static_argnums=(0,)) 248 | def predict_u(self, params, X_star): 249 | u_pred = vmap(self.neural_net, (None, 0, 0))(params, X_star[:,0], X_star[:,1]) 250 | return u_pred 251 | 252 | 253 | data = scipy.io.loadmat('../data/ks_chaotic.mat') 254 | # Test data 255 | usol = data['usol'] 256 | 257 | t_star = data['t'][0] 258 | x_star = data['x'][0] 259 | TT, XX = np.meshgrid(t_star, x_star) 260 | X_star = np.hstack((TT.flatten()[:, None], XX.flatten()[:, None])) 261 | 262 | 263 | 264 | # Hyper-parameters 265 | key = random.PRNGKey(1234) 266 | M_t = 6 267 | M_x = 5 268 | layers = [128, 128, 128, 128, 128, 128, 128, 128, 1] 269 | num_step = 25 270 | t0 = 0.0 271 | t1 = t_star[num_step] 272 | n_t = 32 273 | n_x = 256 274 | 275 | tol = 1.0 276 | tol_list = [1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2] 277 | time_step = 0 278 | 279 | state0 = usol[:, time_step:time_step+1] 280 | t_star = data['t'][0][:num_step] 281 | x_star = data['x'][0] 282 | 283 | # Create data set 284 | dataset = DataGenerator(t0, t1, n_t, n_x) 285 | 286 | 287 | # arch = 'MLP' 288 | arch = 'modified_MLP' 289 | print('Arch:', arch) 290 | print('Alg: temporal reweighting, Random collocation points') 291 | 292 | 293 | N = 250 // num_step 294 | 295 | u_pred_list = [] 296 | params_list = [] 297 | losses_list = [] 298 | 299 | for k in range(N): 300 | # Initialize model 301 | u_exact = usol[:, time_step + k * num_step:time_step + (k+1) * num_step] # (512, num_step) 302 | print('Final Time: {}'.format(k + 1)) 303 | model = PINN(key, u_exact, arch, layers, M_t, M_x, state0, t0, t1, n_t, n_x, tol) 304 | 305 | # Train 306 | for tol in tol_list: 307 | model.tol = tol 308 | print('tol: ', tol) 309 | # Train 310 | model.train(dataset, nIter=200000) 311 | 312 | # Store 313 | params = model.get_params(model.opt_state) 314 | u_pred = model.u_pred_fn(params, t_star, x_star) 315 | u_pred_list.append(u_pred) 316 | flat_params, _ = ravel_pytree(params) 317 | params_list.append(flat_params) 318 | losses_list.append([model.loss_log, model.loss_ics_log, model.loss_res_log]) 319 | 320 | 321 | np.save(arch + '_u_pred_list.npy', u_pred_list) 322 | np.save(arch + '_params_list.npy', params_list) 323 | np.save(arch +
'_losses_list.npy', losses_list) 324 | 325 | u_preds = np.hstack(u_pred_list) 326 | error = np.linalg.norm(u_preds - usol[:, time_step:time_step + (k+1) * num_step]) / np.linalg.norm(usol[:, time_step:time_step + (k+1) * num_step]) 327 | print('Relative l2 error: {:.3e}'.format(error)) 328 | 329 | params = model.get_params(model.opt_state) 330 | u0_pred = vmap(model.neural_net, (None, None, 0))(params, t1, x_star) 331 | state0 = u0_pred 332 | 333 | -------------------------------------------------------------------------------- /KS/regular_KS.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as np 2 | from jax import random, grad, vmap, jit, jacfwd, jacrev 3 | from jax.example_libraries import optimizers 4 | from jax.experimental.jet import jet 5 | from jax.nn import relu 6 | from jax.config import config 7 | from jax import lax 8 | from jax.flatten_util import ravel_pytree 9 | import itertools 10 | from functools import partial 11 | from torch.utils import data 12 | from tqdm import trange 13 | 14 | import scipy.io 15 | 16 | 17 | 18 | # Define MLP 19 | def MLP(layers, L=1.0, M=1, activation=relu): 20 | # Define input encoding function 21 | def input_encoding(t, x): 22 | w = 2.0 * np.pi / L 23 | k = np.arange(1, M + 1) 24 | out = np.hstack([t, 1, 25 | np.cos(k * w * x), np.sin(k * w * x)]) 26 | return out 27 | 28 | def init(rng_key): 29 | def init_layer(key, d_in, d_out): 30 | k1, k2 = random.split(key) 31 | glorot_stddev = 1.0 / np.sqrt((d_in + d_out) / 2.) 32 | W = glorot_stddev * random.normal(k1, (d_in, d_out)) 33 | b = np.zeros(d_out) 34 | return W, b 35 | key, *keys = random.split(rng_key, len(layers)) 36 | params = list(map(init_layer, keys, layers[:-1], layers[1:])) 37 | return params 38 | def apply(params, inputs): 39 | t = inputs[0] 40 | x = inputs[1] 41 | H = input_encoding(t, x) 42 | for W, b in params[:-1]: 43 | outputs = np.dot(H, W) + b 44 | H = activation(outputs) 45 | W, b = params[-1] 46 | outputs = np.dot(H, W) + b 47 | return outputs 48 | return init, apply 49 | 50 | 51 | # Define modified MLP 52 | def modified_MLP(layers, L=1.0, M=1, activation=relu): 53 | def xavier_init(key, d_in, d_out): 54 | glorot_stddev = 1. / np.sqrt((d_in + d_out) / 2.) 
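Both `MLP` above and `modified_MLP` (continued below) follow the stateless init/apply pattern: the factory closes over the layer sizes and the Fourier input encoding, `init` draws the weights, and `apply` maps a single (t, x) pair to a prediction. A usage sketch, assuming the definitions above are in scope and using the sizes this script sets later (M_x = 5 modes on the domain [-1, 1], hidden widths 256):

```python
import jax.numpy as jnp
from jax import random, vmap

M_x = 5
d0 = 2 * M_x + 2                       # width of input_encoding: [t, 1, cos..., sin...]
init, apply = MLP([d0] + [256, 256, 256, 1], L=2.0, M=M_x, activation=jnp.tanh)

params = init(random.PRNGKey(0))
u = apply(params, jnp.array([0.05, -0.3]))        # prediction at (t, x); shape (1,)

# Whole-grid evaluation is just nested vmaps, as in PINN.u_pred_fn further down.
t = jnp.linspace(0.0, 0.1, 32)
x = jnp.linspace(-1.0, 1.0, 64)
u_grid = vmap(vmap(lambda t_, x_: apply(params, jnp.array([t_, x_]))[0],
                   (0, None)), (None, 0))(t, x)   # shape (64, 32): x-major, t-minor
print(u.shape, u_grid.shape)
```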
55 | W = glorot_stddev * random.normal(key, (d_in, d_out)) 56 | b = np.zeros(d_out) 57 | return W, b 58 | 59 | # Define input encoding function 60 | def input_encoding(t, x): 61 | w = 2 * np.pi / L 62 | k = np.arange(1, M + 1) 63 | out = np.hstack([t, 1, 64 | np.cos(k * w * x), np.sin(k * w * x)]) 65 | return out 66 | 67 | 68 | def init(rng_key): 69 | U1, b1 = xavier_init(random.PRNGKey(12345), layers[0], layers[1]) 70 | U2, b2 = xavier_init(random.PRNGKey(54321), layers[0], layers[1]) 71 | def init_layer(key, d_in, d_out): 72 | k1, k2 = random.split(key) 73 | W, b = xavier_init(k1, d_in, d_out) 74 | return W, b 75 | key, *keys = random.split(rng_key, len(layers)) 76 | params = list(map(init_layer, keys, layers[:-1], layers[1:])) 77 | return (params, U1, b1, U2, b2) 78 | 79 | def apply(params, inputs): 80 | params, U1, b1, U2, b2 = params 81 | 82 | t = inputs[0] 83 | x = inputs[1] 84 | inputs = input_encoding(t, x) 85 | U = activation(np.dot(inputs, U1) + b1) 86 | V = activation(np.dot(inputs, U2) + b2) 87 | for W, b in params[:-1]: 88 | outputs = activation(np.dot(inputs, W) + b) 89 | inputs = np.multiply(outputs, U) + np.multiply(1 - outputs, V) 90 | W, b = params[-1] 91 | outputs = np.dot(inputs, W) + b 92 | return outputs 93 | return init, apply 94 | 95 | 96 | class DataGenerator(data.Dataset): 97 | def __init__(self, t0, t1, n_t=10, n_x=64, rng_key=random.PRNGKey(1234)): 98 | 'Initialization' 99 | self.t0 = t0 100 | self.t1 = t1 101 | self.n_t = n_t 102 | self.n_x = n_x 103 | self.key = rng_key 104 | 105 | def __getitem__(self, index): 106 | 'Generate one batch of data' 107 | self.key, subkey = random.split(self.key) 108 | batch = self.__data_generation(subkey) 109 | return batch 110 | 111 | @partial(jit, static_argnums=(0,)) 112 | def __data_generation(self, key): 113 | 'Generates data containing batch_size samples' 114 | subkeys = random.split(key, 2) 115 | t_r = random.uniform(subkeys[0], shape=(self.n_t,), minval=self.t0, maxval=self.t1).sort() 116 | x_r = random.uniform(subkeys[1], shape=(self.n_x,), minval=-1.0, maxval=1.0) 117 | batch = (t_r, x_r) 118 | return batch 119 | 120 | 121 | 122 | # Define the model 123 | class PINN: 124 | def __init__(self, key, arch, layers, M_x, state0, t0, t1, n_t, n_x, tol=1.0): 125 | 126 | # grid 127 | eps = 0.01 * t1 128 | self.t_r = np.linspace(t0, t1 + eps, n_t) 129 | self.x_r = np.linspace(-1.0, 1.0, n_x) 130 | 131 | # IC 132 | t_ic = np.zeros((x_star.shape[0], 1)) 133 | x_ic = x_star.reshape(-1, 1) 134 | self.X_ic = np.hstack([t_ic, x_ic]) 135 | self.Y_ic = state0 136 | 137 | # Weight matrix and causal parameter 138 | self.M = np.triu(np.ones((n_t, n_t)), k=1).T 139 | self.tol = tol 140 | 141 | if arch == 'MLP': 142 | d0 = 2 * M_x + 2 143 | layers = [d0] + layers 144 | self.init, self.apply = MLP(layers, L=2.0, M=M_x, activation=np.tanh) 145 | params = self.init(rng_key = key) 146 | 147 | if arch == 'modified_MLP': 148 | d0 = 2 * M_x + 2 149 | layers = [d0] + layers 150 | self.init, self.apply = modified_MLP(layers, L=2.0, M=M_x, activation=np.tanh) 151 | params = self.init(rng_key = key) 152 | 153 | 154 | # Use optimizers to set optimizer initialization and update functions 155 | lr = optimizers.exponential_decay(1e-3, decay_steps=5000, decay_rate=0.9) 156 | self.opt_init, self.opt_update, self.get_params = optimizers.adam(lr) 157 | self.opt_state = self.opt_init(params) 158 | _, self.unravel = ravel_pytree(params) 159 | 160 | # Evaluate functions over a grid 161 | self.u_pred_fn = vmap(vmap(self.neural_net, (None, 0, None)), (None, None, 
0)) # consistent with the dataset 162 | self.r_pred_fn = vmap(vmap(self.residual_net, (None, None, 0)), (None, 0, None)) 163 | 164 | # Logger 165 | self.loss_log = [] 166 | self.loss_ics_log = [] 167 | self.loss_res_log = [] 168 | 169 | self.itercount = itertools.count() 170 | 171 | 172 | def neural_net(self, params, t, x): 173 | z = np.stack([t, x]) 174 | outputs = self.apply(params, z) 175 | return outputs[0] 176 | 177 | def residual_net(self, params, t, x): 178 | u = self.neural_net(params, t, x) 179 | u_t = grad(self.neural_net, argnums=1)(params, t, x) 180 | u_fn = lambda x: self.neural_net(params, t, x) # For using Taylor-mode AD 181 | _, (u_x, u_xx, u_xxx, u_xxxx) = jet(u_fn, (x, ), [[1.0, 0.0, 0.0, 0.0]]) # Taylor-mode AD 182 | return u_t + 5 * u * u_x + 0.5 * u_xx + 0.005 * u_xxxx 183 | 184 | # Compute the temporal weights 185 | @partial(jit, static_argnums=(0,)) 186 | def residuals_and_weights(self, params, batch, tol): 187 | t_r, x_r = batch 188 | L_0 = 1e3 * self.loss_ics(params) 189 | r_pred = self.r_pred_fn(params, t_r, x_r) 190 | L_t = np.mean(r_pred**2, axis=1) 191 | W = lax.stop_gradient(np.exp(- tol * (self.M @ L_t + L_0) )) 192 | return L_0, L_t, W 193 | 194 | # Initial condition loss 195 | @partial(jit, static_argnums=(0,)) 196 | def loss_ics(self, params): 197 | # Compute forward pass 198 | u_pred = vmap(self.neural_net, (None, 0, 0))(params, self.X_ic[:,0], self.X_ic[:,1]) 199 | # Compute loss 200 | loss_ics = np.mean((self.Y_ic.flatten() - u_pred.flatten())**2) 201 | return loss_ics 202 | 203 | # Residual loss 204 | @partial(jit, static_argnums=(0,)) 205 | def loss_res(self, params, batch): 206 | t_r, x_r = batch 207 | # Compute forward pass 208 | r_pred = self.r_pred_fn(params, t_r, x_r) 209 | # Compute loss 210 | loss_r = np.mean(r_pred**2) 211 | return loss_r 212 | 213 | # Total loss 214 | @partial(jit, static_argnums=(0,)) 215 | def loss(self, params, batch): 216 | L_0, L_t, W = self.residuals_and_weights(params, batch, self.tol) 217 | # Compute loss 218 | loss = np.mean(W * L_t + L_0) 219 | return loss 220 | 221 | # Define a compiled update step 222 | @partial(jit, static_argnums=(0,)) 223 | def step(self, i, opt_state, batch): 224 | params = self.get_params(opt_state) 225 | g = grad(self.loss)(params, batch) 226 | return self.opt_update(i, g, opt_state) 227 | 228 | # Optimize parameters in a loop 229 | def train(self, dataset, nIter = 10000): 230 | res_data = iter(dataset) 231 | pbar = trange(nIter) 232 | # Main training loop 233 | for it in pbar: 234 | # Get batch 235 | batch= next(res_data) 236 | self.current_count = next(self.itercount) 237 | self.opt_state = self.step(self.current_count, self.opt_state, batch) 238 | 239 | if it % 1000 == 0: 240 | params = self.get_params(self.opt_state) 241 | 242 | loss_value = self.loss(params, batch) 243 | loss_ics_value = self.loss_ics(params) 244 | loss_res_value = self.loss_res(params, batch) 245 | _, _, W_value = self.residuals_and_weights(params, batch, self.tol) 246 | 247 | self.loss_log.append(loss_value) 248 | self.loss_ics_log.append(loss_ics_value) 249 | self.loss_res_log.append(loss_res_value) 250 | 251 | pbar.set_postfix({'Loss': loss_value, 252 | 'loss_ics' : loss_ics_value, 253 | 'loss_res': loss_res_value, 254 | 'W_min' : W_value.min()}) 255 | 256 | if W_value.min() > 0.99: 257 | break 258 | 259 | 260 | # Load data 261 | data = scipy.io.loadmat('ks_simple.mat') 262 | # Test data 263 | usol = data['usol'] 264 | 265 | 266 | # Hpyer-parameters 267 | key = random.PRNGKey(1234) 268 | M_t = 2 269 | M_x = 5 270 | 
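The causal weighting itself is only a few lines: `M = np.triu(np.ones((n_t, n_t)), k=1).T` is strictly lower triangular, so `M @ L_t` is the running sum of the residual losses at all earlier time slices, and slice i gets weight W_i = exp(-tol·(Σ_{j<i} L_t[j] + L_0)). Later times are effectively switched off until everything before them (including the weighted initial-condition loss) is well resolved; the main loops then sweep `tol` upward through `tol_list`, and `train` exits a stage early once min_i W_i > 0.99. A small numeric illustration (made-up loss values):

```python
import jax.numpy as jnp
from jax import lax

n_t, tol = 6, 1.0
M = jnp.triu(jnp.ones((n_t, n_t)), k=1).T            # strictly lower-triangular ones

L_t = jnp.array([0.02, 0.05, 0.4, 1.5, 2.0, 2.5])    # per-slice residual losses (fake)
L_0 = 0.01                                           # weighted IC loss (fake)

# M @ L_t is an exclusive prefix sum over the earlier slices ...
assert jnp.allclose(M @ L_t, jnp.concatenate([jnp.zeros(1), jnp.cumsum(L_t)[:-1]]))

# ... so the weights decay monotonically until the early losses are driven down.
W = lax.stop_gradient(jnp.exp(-tol * (M @ L_t + L_0)))
print(W)     # approximately [0.99, 0.97, 0.92, 0.62, 0.14, 0.02]: later slices wait
```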
t0 = 0.0 271 | t1 = 0.1 272 | n_t = 32 273 | n_x = 64 274 | tol_list = [1e-2, 1e-1, 1e0, 1e1, 1e2] 275 | layers = [256, 256, 256, 1] # using Fourier embedding so it is not 1 276 | 277 | # Initial state 278 | state0 = usol[:, 0:1] 279 | dt = 1 / 250 280 | idx = int(t1 / dt) 281 | t_star = data['t'][0][:idx] 282 | x_star = data['x'][0] 283 | 284 | # Create data set 285 | dataset = DataGenerator(t0, t1, n_t, n_x) 286 | 287 | arch = 'modified_MLP' 288 | print('arch:', arch) 289 | 290 | N = 10 291 | u_pred_list = [] 292 | params_list = [] 293 | losses_list = [] 294 | 295 | 296 | # Time marching 297 | for k in range(N): 298 | # Initialize model 299 | print('Final Time: {}'.format((k + 1) * t1)) 300 | model = PINN(key, arch, layers, M_x, state0, t0, t1, n_t, n_x) 301 | 302 | # Train 303 | for tol in tol_list: 304 | model.tol = tol 305 | print("tol:", model.tol) 306 | # Train 307 | model.train(dataset, nIter=200000) 308 | 309 | # Store 310 | params = model.get_params(model.opt_state) 311 | u_pred = model.u_pred_fn(params, t_star, x_star) 312 | u_pred_list.append(u_pred) 313 | flat_params, _ = ravel_pytree(params) 314 | params_list.append(flat_params) 315 | losses_list.append([model.loss_log, model.loss_ics_log, model.loss_res_log]) 316 | 317 | 318 | np.save('u_pred_list.npy', u_pred_list) 319 | np.save('params_list.npy', params_list) 320 | np.save('losses_list.npy', losses_list) 321 | 322 | # error 323 | u_preds = np.hstack(u_pred_list) 324 | error = np.linalg.norm(u_preds - usol[:, :(k+1) * idx]) / np.linalg.norm(usol[:, :(k+1) * idx]) 325 | print('Relative l2 error: {:.3e}'.format(error)) 326 | 327 | params = model.get_params(model.opt_state) 328 | u0_pred = vmap(model.neural_net, (None, None, 0))(params, t1, x_star) 329 | state0 = u0_pred 330 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial-ShareAlike 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 
30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International 58 | Public License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial-ShareAlike 4.0 International Public License 63 | ("Public License"). To the extent this Public License may be 64 | interpreted as a contract, You are granted the Licensed Rights in 65 | consideration of Your acceptance of these terms and conditions, and the 66 | Licensor grants You such rights in consideration of benefits the 67 | Licensor receives from making the Licensed Material available under 68 | these terms and conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. BY-NC-SA Compatible License means a license listed at 88 | creativecommons.org/compatiblelicenses, approved by Creative 89 | Commons as essentially the equivalent of this Public License. 90 | 91 | d. 
Copyright and Similar Rights means copyright and/or similar rights 92 | closely related to copyright including, without limitation, 93 | performance, broadcast, sound recording, and Sui Generis Database 94 | Rights, without regard to how the rights are labeled or 95 | categorized. For purposes of this Public License, the rights 96 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 97 | Rights. 98 | 99 | e. Effective Technological Measures means those measures that, in the 100 | absence of proper authority, may not be circumvented under laws 101 | fulfilling obligations under Article 11 of the WIPO Copyright 102 | Treaty adopted on December 20, 1996, and/or similar international 103 | agreements. 104 | 105 | f. Exceptions and Limitations means fair use, fair dealing, and/or 106 | any other exception or limitation to Copyright and Similar Rights 107 | that applies to Your use of the Licensed Material. 108 | 109 | g. License Elements means the license attributes listed in the name 110 | of a Creative Commons Public License. The License Elements of this 111 | Public License are Attribution, NonCommercial, and ShareAlike. 112 | 113 | h. Licensed Material means the artistic or literary work, database, 114 | or other material to which the Licensor applied this Public 115 | License. 116 | 117 | i. Licensed Rights means the rights granted to You subject to the 118 | terms and conditions of this Public License, which are limited to 119 | all Copyright and Similar Rights that apply to Your use of the 120 | Licensed Material and that the Licensor has authority to license. 121 | 122 | j. Licensor means the individual(s) or entity(ies) granting rights 123 | under this Public License. 124 | 125 | k. NonCommercial means not primarily intended for or directed towards 126 | commercial advantage or monetary compensation. For purposes of 127 | this Public License, the exchange of the Licensed Material for 128 | other material subject to Copyright and Similar Rights by digital 129 | file-sharing or similar means is NonCommercial provided there is 130 | no payment of monetary compensation in connection with the 131 | exchange. 132 | 133 | l. Share means to provide material to the public by any means or 134 | process that requires permission under the Licensed Rights, such 135 | as reproduction, public display, public performance, distribution, 136 | dissemination, communication, or importation, and to make material 137 | available to the public including in ways that members of the 138 | public may access the material from a place and at a time 139 | individually chosen by them. 140 | 141 | m. Sui Generis Database Rights means rights other than copyright 142 | resulting from Directive 96/9/EC of the European Parliament and of 143 | the Council of 11 March 1996 on the legal protection of databases, 144 | as amended and/or succeeded, as well as other essentially 145 | equivalent rights anywhere in the world. 146 | 147 | n. You means the individual or entity exercising the Licensed Rights 148 | under this Public License. Your has a corresponding meaning. 149 | 150 | 151 | Section 2 -- Scope. 152 | 153 | a. License grant. 154 | 155 | 1. Subject to the terms and conditions of this Public License, 156 | the Licensor hereby grants You a worldwide, royalty-free, 157 | non-sublicensable, non-exclusive, irrevocable license to 158 | exercise the Licensed Rights in the Licensed Material to: 159 | 160 | a. 
reproduce and Share the Licensed Material, in whole or 161 | in part, for NonCommercial purposes only; and 162 | 163 | b. produce, reproduce, and Share Adapted Material for 164 | NonCommercial purposes only. 165 | 166 | 2. Exceptions and Limitations. For the avoidance of doubt, where 167 | Exceptions and Limitations apply to Your use, this Public 168 | License does not apply, and You do not need to comply with 169 | its terms and conditions. 170 | 171 | 3. Term. The term of this Public License is specified in Section 172 | 6(a). 173 | 174 | 4. Media and formats; technical modifications allowed. The 175 | Licensor authorizes You to exercise the Licensed Rights in 176 | all media and formats whether now known or hereafter created, 177 | and to make technical modifications necessary to do so. The 178 | Licensor waives and/or agrees not to assert any right or 179 | authority to forbid You from making technical modifications 180 | necessary to exercise the Licensed Rights, including 181 | technical modifications necessary to circumvent Effective 182 | Technological Measures. For purposes of this Public License, 183 | simply making modifications authorized by this Section 2(a) 184 | (4) never produces Adapted Material. 185 | 186 | 5. Downstream recipients. 187 | 188 | a. Offer from the Licensor -- Licensed Material. Every 189 | recipient of the Licensed Material automatically 190 | receives an offer from the Licensor to exercise the 191 | Licensed Rights under the terms and conditions of this 192 | Public License. 193 | 194 | b. Additional offer from the Licensor -- Adapted Material. 195 | Every recipient of Adapted Material from You 196 | automatically receives an offer from the Licensor to 197 | exercise the Licensed Rights in the Adapted Material 198 | under the conditions of the Adapter's License You apply. 199 | 200 | c. No downstream restrictions. You may not offer or impose 201 | any additional or different terms or conditions on, or 202 | apply any Effective Technological Measures to, the 203 | Licensed Material if doing so restricts exercise of the 204 | Licensed Rights by any recipient of the Licensed 205 | Material. 206 | 207 | 6. No endorsement. Nothing in this Public License constitutes or 208 | may be construed as permission to assert or imply that You 209 | are, or that Your use of the Licensed Material is, connected 210 | with, or sponsored, endorsed, or granted official status by, 211 | the Licensor or others designated to receive attribution as 212 | provided in Section 3(a)(1)(A)(i). 213 | 214 | b. Other rights. 215 | 216 | 1. Moral rights, such as the right of integrity, are not 217 | licensed under this Public License, nor are publicity, 218 | privacy, and/or other similar personality rights; however, to 219 | the extent possible, the Licensor waives and/or agrees not to 220 | assert any such rights held by the Licensor to the limited 221 | extent necessary to allow You to exercise the Licensed 222 | Rights, but not otherwise. 223 | 224 | 2. Patent and trademark rights are not licensed under this 225 | Public License. 226 | 227 | 3. To the extent possible, the Licensor waives any right to 228 | collect royalties from You for the exercise of the Licensed 229 | Rights, whether directly or through a collecting society 230 | under any voluntary or waivable statutory or compulsory 231 | licensing scheme. 
In all other cases the Licensor expressly 232 | reserves any right to collect such royalties, including when 233 | the Licensed Material is used other than for NonCommercial 234 | purposes. 235 | 236 | 237 | Section 3 -- License Conditions. 238 | 239 | Your exercise of the Licensed Rights is expressly made subject to the 240 | following conditions. 241 | 242 | a. Attribution. 243 | 244 | 1. If You Share the Licensed Material (including in modified 245 | form), You must: 246 | 247 | a. retain the following if it is supplied by the Licensor 248 | with the Licensed Material: 249 | 250 | i. identification of the creator(s) of the Licensed 251 | Material and any others designated to receive 252 | attribution, in any reasonable manner requested by 253 | the Licensor (including by pseudonym if 254 | designated); 255 | 256 | ii. a copyright notice; 257 | 258 | iii. a notice that refers to this Public License; 259 | 260 | iv. a notice that refers to the disclaimer of 261 | warranties; 262 | 263 | v. a URI or hyperlink to the Licensed Material to the 264 | extent reasonably practicable; 265 | 266 | b. indicate if You modified the Licensed Material and 267 | retain an indication of any previous modifications; and 268 | 269 | c. indicate the Licensed Material is licensed under this 270 | Public License, and include the text of, or the URI or 271 | hyperlink to, this Public License. 272 | 273 | 2. You may satisfy the conditions in Section 3(a)(1) in any 274 | reasonable manner based on the medium, means, and context in 275 | which You Share the Licensed Material. For example, it may be 276 | reasonable to satisfy the conditions by providing a URI or 277 | hyperlink to a resource that includes the required 278 | information. 279 | 3. If requested by the Licensor, You must remove any of the 280 | information required by Section 3(a)(1)(A) to the extent 281 | reasonably practicable. 282 | 283 | b. ShareAlike. 284 | 285 | In addition to the conditions in Section 3(a), if You Share 286 | Adapted Material You produce, the following conditions also apply. 287 | 288 | 1. The Adapter's License You apply must be a Creative Commons 289 | license with the same License Elements, this version or 290 | later, or a BY-NC-SA Compatible License. 291 | 292 | 2. You must include the text of, or the URI or hyperlink to, the 293 | Adapter's License You apply. You may satisfy this condition 294 | in any reasonable manner based on the medium, means, and 295 | context in which You Share Adapted Material. 296 | 297 | 3. You may not offer or impose any additional or different terms 298 | or conditions on, or apply any Effective Technological 299 | Measures to, Adapted Material that restrict exercise of the 300 | rights granted under the Adapter's License You apply. 301 | 302 | 303 | Section 4 -- Sui Generis Database Rights. 304 | 305 | Where the Licensed Rights include Sui Generis Database Rights that 306 | apply to Your use of the Licensed Material: 307 | 308 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 309 | to extract, reuse, reproduce, and Share all or a substantial 310 | portion of the contents of the database for NonCommercial purposes 311 | only; 312 | 313 | b. if You include all or a substantial portion of the database 314 | contents in a database in which You have Sui Generis Database 315 | Rights, then the database in which You have Sui Generis Database 316 | Rights (but not its individual contents) is Adapted Material, 317 | including for purposes of Section 3(b); and 318 | 319 | c. 
You must comply with the conditions in Section 3(a) if You Share 320 | all or a substantial portion of the contents of the database. 321 | 322 | For the avoidance of doubt, this Section 4 supplements and does not 323 | replace Your obligations under this Public License where the Licensed 324 | Rights include other Copyright and Similar Rights. 325 | 326 | 327 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 328 | 329 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 330 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 331 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 332 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 333 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 334 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 335 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 336 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 337 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 338 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 339 | 340 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 341 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 342 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 343 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 344 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 345 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 346 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 347 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 348 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 349 | 350 | c. The disclaimer of warranties and limitation of liability provided 351 | above shall be interpreted in a manner that, to the extent 352 | possible, most closely approximates an absolute disclaimer and 353 | waiver of all liability. 354 | 355 | 356 | Section 6 -- Term and Termination. 357 | 358 | a. This Public License applies for the term of the Copyright and 359 | Similar Rights licensed here. However, if You fail to comply with 360 | this Public License, then Your rights under this Public License 361 | terminate automatically. 362 | 363 | b. Where Your right to use the Licensed Material has terminated under 364 | Section 6(a), it reinstates: 365 | 366 | 1. automatically as of the date the violation is cured, provided 367 | it is cured within 30 days of Your discovery of the 368 | violation; or 369 | 370 | 2. upon express reinstatement by the Licensor. 371 | 372 | For the avoidance of doubt, this Section 6(b) does not affect any 373 | right the Licensor may have to seek remedies for Your violations 374 | of this Public License. 375 | 376 | c. For the avoidance of doubt, the Licensor may also offer the 377 | Licensed Material under separate terms or conditions or stop 378 | distributing the Licensed Material at any time; however, doing so 379 | will not terminate this Public License. 380 | 381 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 382 | License. 383 | 384 | 385 | Section 7 -- Other Terms and Conditions. 386 | 387 | a. The Licensor shall not be bound by any additional or different 388 | terms or conditions communicated by You unless expressly agreed. 389 | 390 | b. 
Any arrangements, understandings, or agreements regarding the 391 | Licensed Material not stated herein are separate from and 392 | independent of the terms and conditions of this Public License. 393 | 394 | 395 | Section 8 -- Interpretation. 396 | 397 | a. For the avoidance of doubt, this Public License does not, and 398 | shall not be interpreted to, reduce, limit, restrict, or impose 399 | conditions on any use of the Licensed Material that could lawfully 400 | be made without permission under this Public License. 401 | 402 | b. To the extent possible, if any provision of this Public License is 403 | deemed unenforceable, it shall be automatically reformed to the 404 | minimum extent necessary to make it enforceable. If the provision 405 | cannot be reformed, it shall be severed from this Public License 406 | without affecting the enforceability of the remaining terms and 407 | conditions. 408 | 409 | c. No term or condition of this Public License will be waived and no 410 | failure to comply consented to unless expressly agreed to by the 411 | Licensor. 412 | 413 | d. Nothing in this Public License constitutes or may be interpreted 414 | as a limitation upon, or waiver of, any privileges and immunities 415 | that apply to the Licensor or You, including from the legal 416 | processes of any jurisdiction or authority. 417 | 418 | ======================================================================= 419 | 420 | Creative Commons is not a party to its public 421 | licenses. Notwithstanding, Creative Commons may elect to apply one of 422 | its public licenses to material it publishes and in those instances 423 | will be considered the “Licensor.” The text of the Creative Commons 424 | public licenses is dedicated to the public domain under the CC0 Public 425 | Domain Dedication. Except for the limited purpose of indicating that 426 | material is shared under a Creative Commons public license or as 427 | otherwise permitted by the Creative Commons policies published at 428 | creativecommons.org/policies, Creative Commons does not authorize the 429 | use of the trademark "Creative Commons" or any other trademark or logo 430 | of Creative Commons without its prior written consent including, 431 | without limitation, in connection with any unauthorized modifications 432 | to any of its public licenses or any other arrangements, 433 | understandings, or agreements concerning use of licensed material. For 434 | the avoidance of doubt, this paragraph does not form part of the 435 | public licenses. 436 | 437 | Creative Commons may be contacted at creativecommons.org. 438 | -------------------------------------------------------------------------------- /Lorentz/Causal_PINNs_lorentz.py: -------------------------------------------------------------------------------- 1 | 2 | # Commented out IPython magic to ensure Python compatibility. 
3 | import numpy as onp 4 | import jax.numpy as np 5 | from jax import random, grad, vmap, jit, jacfwd, jacrev 6 | from jax.experimental import optimizers 7 | from jax.experimental.ode import odeint 8 | from jax.nn import relu 9 | from jax.config import config 10 | from jax import lax 11 | from jax.flatten_util import ravel_pytree 12 | import itertools 13 | from functools import partial 14 | from torch.utils import data 15 | from tqdm import trange 16 | 17 | import scipy.io 18 | from scipy.interpolate import griddata 19 | from scipy.linalg import lstsq 20 | from scipy.optimize import lsq_linear 21 | from sklearn.linear_model import RidgeCV 22 | import matplotlib.pyplot as plt 23 | import scipy.optimize 24 | from scipy.optimize import least_squares 25 | 26 | from scipy.integrate import odeint as scipy_odeint 27 | from mpl_toolkits.mplot3d import Axes3D 28 | 29 | # Define the neural net 30 | def init_layer(key, d_in, d_out): 31 | k1, k2 = random.split(key) 32 | glorot_stddev = 1. / np.sqrt((d_in + d_out) / 2.) 33 | W = glorot_stddev * random.normal(k1, (d_in, d_out)) 34 | b = np.zeros(d_out) 35 | return W, b 36 | 37 | def MLP(layers, activation=relu): 38 | ''' Vanilla MLP''' 39 | def init(rng_key): 40 | key, *keys = random.split(rng_key, len(layers)) 41 | params = list(map(init_layer, keys, layers[:-1], layers[1:])) 42 | return params 43 | def apply(params, inputs): 44 | for W, b in params[:-1]: 45 | outputs = np.dot(inputs, W) + b 46 | inputs = activation(outputs) 47 | W, b = params[-1] 48 | outputs = np.dot(inputs, W) + b 49 | return outputs 50 | return init, apply 51 | 52 | 53 | # Define the neural net 54 | def modified_MLP(layers, activation=relu): 55 | def xavier_init(key, d_in, d_out): 56 | glorot_stddev = 1. / np.sqrt((d_in + d_out) / 2.) 57 | W = glorot_stddev * random.normal(key, (d_in, d_out)) 58 | b = np.zeros(d_out) 59 | return W, b 60 | 61 | def init(rng_key): 62 | U1, b1 = xavier_init(random.PRNGKey(12345), layers[0], layers[1]) 63 | U2, b2 = xavier_init(random.PRNGKey(54321), layers[0], layers[1]) 64 | def init_layer(key, d_in, d_out): 65 | k1, k2 = random.split(key) 66 | W, b = xavier_init(k1, d_in, d_out) 67 | return W, b 68 | key, *keys = random.split(rng_key, len(layers)) 69 | params = list(map(init_layer, keys, layers[:-1], layers[1:])) 70 | return (params, U1, b1, U2, b2) 71 | 72 | def apply(params, inputs): 73 | params, U1, b1, U2, b2 = params 74 | U = activation(np.dot(inputs, U1) + b1) 75 | V = activation(np.dot(inputs, U2) + b2) 76 | for W, b in params[:-1]: 77 | outputs = activation(np.dot(inputs, W) + b) 78 | inputs = np.multiply(outputs, U) + np.multiply(1 - outputs, V) 79 | W, b = params[-1] 80 | outputs = np.dot(inputs, W) + b 81 | return outputs 82 | return init, apply 83 | 84 | # Define Fourier feature net 85 | def MLP_FF(layers, sigma=1.0, activation=relu): 86 | # Define input encoding function 87 | def input_encoding(x, w): 88 | out = np.hstack([np.sin(np.dot(x, w)), 89 | np.cos(np.dot(x, w))]) 90 | return out 91 | freqs = sigma * random.normal(random.PRNGKey(0), (layers[0], layers[1]//2)) 92 | def init(rng_key): 93 | def init_layer(key, d_in, d_out): 94 | k1, k2 = random.split(key) 95 | glorot_stddev = 1.0 / np.sqrt((d_in + d_out) / 2.) 
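`MLP_FF` above is a random Fourier-feature variant: a fixed frequency matrix drawn from a Gaussian with scale `sigma` maps the input through sin/cos before the dense layers, a common remedy for spectral bias when the target has high-frequency content. (The Lorenz model further down actually instantiates the plain `MLP`; the feature map itself is easy to isolate.) A sketch with hypothetical sizes:

```python
import jax.numpy as jnp
from jax import random

def fourier_features(x, freqs):
    # Same embedding as MLP_FF's input_encoding: project onto fixed random
    # frequencies, then take sin and cos, doubling the feature count.
    return jnp.hstack([jnp.sin(jnp.dot(x, freqs)), jnp.cos(jnp.dot(x, freqs))])

sigma = 1.0                                                  # frequency scale
freqs = sigma * random.normal(random.PRNGKey(0), (1, 64))    # (d_in, layers[1] // 2)

t = jnp.array([0.3])                                         # e.g. a scalar time input
print(fourier_features(t, freqs).shape)                      # (128,)
```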
96 | W = glorot_stddev * random.normal(k1, (d_in, d_out)) 97 | b = np.zeros(d_out) 98 | return W, b 99 | key, *keys = random.split(rng_key, len(layers)) 100 | params = list(map(init_layer, keys, layers[1:-1], layers[2:])) 101 | return params 102 | def apply(params, inputs): 103 | H = input_encoding(inputs, freqs) 104 | for W, b in params[:-1]: 105 | outputs = np.dot(H, W) + b 106 | H = activation(outputs) 107 | W, b = params[-1] 108 | outputs = np.dot(H, W) + b 109 | return outputs 110 | return init, apply 111 | 112 | 113 | # Define the model 114 | class PINN: 115 | def __init__(self, layers, states0, t0, t1, tol): 116 | 117 | self.states0 = states0 118 | self.t0 = t0 119 | self.t1 = t1 120 | 121 | # Grid 122 | n_t = 300 123 | eps = 0.1 * self.t1 124 | self.t = np.linspace(self.t0, self.t1 + eps, n_t) 125 | 126 | self.M = np.triu(np.ones((n_t, n_t)), k=1).T 127 | self.tol = tol 128 | 129 | self.rho = 28.0 130 | self.sigma = 10.0 131 | self.beta = 8.0 / 3.0 132 | 133 | self.init, self.apply = MLP(layers, activation=np.tanh) 134 | # self.init, self.apply = modified_MLP(layers, activation=np.tanh) 135 | params = self.init(random.PRNGKey(1234)) 136 | 137 | # Use optimizers to set optimizer initialization and update functions 138 | self.opt_init, \ 139 | self.opt_update, \ 140 | self.get_params = optimizers.adam(optimizers.exponential_decay(1e-3, 141 | decay_steps=5000, 142 | decay_rate=0.9)) 143 | self.opt_state = self.opt_init(params) 144 | _, self.unravel = ravel_pytree(params) 145 | 146 | # Logger 147 | self.itercount = itertools.count() 148 | 149 | self.loss_log = [] 150 | self.loss_ics_log = [] 151 | self.loss_res_log = [] 152 | 153 | def neural_net(self, params, t): 154 | t = np.stack([t]) 155 | outputs = self.apply(params, t) * t 156 | x = outputs[0] + self.states0[0] 157 | y = outputs[1] + self.states0[1] 158 | z = outputs[2] + self.states0[2] 159 | return x, y, z 160 | 161 | def x_fn(self, params, t): 162 | x, _, _ = self.neural_net(params, t) 163 | return x 164 | 165 | def y_fn(self, params, t): 166 | _, y, _ = self.neural_net(params, t) 167 | return y 168 | 169 | def z_fn(self, params, t): 170 | _, _, z = self.neural_net(params, t) 171 | return z 172 | 173 | def residual_net(self, params, t): 174 | x, y, z = self.neural_net(params, t) 175 | x_t = grad(self.x_fn, argnums=1)(params, t) 176 | y_t = grad(self.y_fn, argnums=1)(params, t) 177 | z_t = grad(self.z_fn, argnums=1)(params, t) 178 | 179 | res_1 = x_t - self.sigma * (y - x) 180 | res_2 = y_t - x * (self.rho - z) + y 181 | res_3 = z_t - x * y + self.beta * z 182 | 183 | return res_1, res_2, res_3 184 | 185 | def loss_ics(self, params): 186 | # Compute forward pass 187 | x_pred, y_pred, z_pred =self.neural_net(params, self.t0) 188 | # Compute loss 189 | 190 | loss_x_ic = np.mean((self.states0[0] - x_pred)**2) 191 | loss_y_ic = np.mean((self.states0[1] - y_pred)**2) 192 | loss_z_ic = np.mean((self.states0[2] - z_pred)**2) 193 | return loss_x_ic + loss_y_ic + loss_z_ic 194 | 195 | @partial(jit, static_argnums=(0,)) 196 | def residuals_and_weights(self, params, tol): 197 | r1_pred, r2_pred, r3_pred = vmap(self.residual_net, (None, 0))(params, self.t) 198 | W = lax.stop_gradient(np.exp(- tol * self.M @ (r1_pred**2 + r2_pred**2 + r3_pred**2))) 199 | return r1_pred, r2_pred, r3_pred, W 200 | 201 | @partial(jit, static_argnums=(0,)) 202 | def loss_res(self, params): 203 | # Compute forward pass 204 | r1_pred, r2_pred, r3_pred, W = self.residuals_and_weights(params, self.tol) 205 | # Compute loss 206 | loss_res = np.mean(W * (r1_pred**2 + 
r2_pred**2 + r3_pred**2)) 207 | return loss_res 208 | 209 | @partial(jit, static_argnums=(0,)) 210 | def loss(self, params): 211 | 212 | loss_res = self.loss_res(params) 213 | 214 | loss = loss_res 215 | return loss 216 | 217 | # Define a compiled update step 218 | @partial(jit, static_argnums=(0,)) 219 | def step(self, i, opt_state): 220 | params = self.get_params(opt_state) 221 | g = grad(self.loss)(params) 222 | return self.opt_update(i, g, opt_state) 223 | 224 | # Optimize parameters in a loop 225 | def train(self, nIter = 10000): 226 | pbar = trange(nIter) 227 | # Main training loop 228 | for it in pbar: 229 | self.current_count = next(self.itercount) 230 | self.opt_state = self.step(self.current_count, self.opt_state) 231 | 232 | if it % 1000 == 0: 233 | params = self.get_params(self.opt_state) 234 | 235 | loss_value = self.loss(params) 236 | loss_ics_value = self.loss_ics(params) 237 | loss_res_value = self.loss_res(params) 238 | _, _, _, W_value = self.residuals_and_weights(params, self.tol) 239 | 240 | self.loss_log.append(loss_value) 241 | self.loss_ics_log.append(loss_ics_value) 242 | self.loss_res_log.append(loss_res_value) 243 | 244 | pbar.set_postfix({'Loss': loss_value, 245 | 'loss_ics' : loss_ics_value, 246 | 'loss_res': loss_res_value, 247 | 'W_min': W_value.min()} ) 248 | 249 | if W_value.min() > 0.99: 250 | break 251 | 252 | # Evaluates predictions at test points 253 | @partial(jit, static_argnums=(0,)) 254 | def predict_u(self, params, t_star): 255 | x_pred, y_pred, z_pred = vmap(self.neural_net, (None, 0))(params, t_star) 256 | return x_pred, y_pred, z_pred 257 | 258 | def f(state, t): 259 | x, y, z = state # Unpack the state vector 260 | return sigma * (y - x), x * (rho - z) - y, x * y - beta * z # Derivatives 261 | 262 | rho = 28.0 263 | sigma = 10.0 264 | beta = 8.0 / 3.0 265 | 266 | state0 = [1.0, 1.0, 1.0] 267 | 268 | T = 30 269 | t_star = onp.arange(0, T, 0.01) 270 | states = scipy_odeint(f, state0, t_star) 271 | 272 | # Create PINNs model 273 | t0 = 0.0 274 | t1 = 0.5 275 | tol = 0.1 276 | 277 | tol_list = [1e-3, 1e-2, 1e-1, 1e0, 1e1] 278 | 279 | layers = [1, 512, 512, 512, 3] 280 | 281 | x_pred_list = [] 282 | y_pred_list = [] 283 | z_pred_list = [] 284 | params_list = [] 285 | losses_list = [] 286 | 287 | state0 = np.array([1.0, 1.0, 1.0]) 288 | t = np.arange(t0, t1, 0.01) 289 | for k in range(int(T / t1)): 290 | # Initialize model 291 | print('Final Time: {}'.format( (k+1) * t1)) 292 | model = PINN(layers, state0, t0, t1, tol) 293 | 294 | for tol in tol_list: 295 | model.tol = tol 296 | print('tol:', model.tol) 297 | # Train 298 | model.train(nIter=300000) 299 | 300 | params = model.get_params(model.opt_state) 301 | x_pred, y_pred, z_pred = model.predict_u(params, t) 302 | x0_pred, y0_pred, z0_pred = model.neural_net(params, model.t1) 303 | state0 = np.array([x0_pred, y0_pred, z0_pred]) 304 | 305 | # Store predictions 306 | x_pred_list.append(x_pred) 307 | y_pred_list.append(y_pred) 308 | z_pred_list.append(z_pred) 309 | losses_list.append([model.loss_ics_log, model.loss_res_log]) 310 | 311 | # Store params 312 | flat_params, _ = ravel_pytree(params) 313 | params_list.append(flat_params) 314 | 315 | np.save('x_pred_list.npy', x_pred_list) 316 | np.save('y_pred_list.npy', y_pred_list) 317 | np.save('z_pred_list.npy', z_pred_list) 318 | np.save('params_list.npy', params_list) 319 | np.save('losses_list.npy', losses_list) 320 | 321 | # Error 322 | t_star = onp.arange(t0, (k+1) * t1, 0.01) 323 | states = scipy_odeint(f, [1.0, 1.0, 1.0], t_star) 324 | 325 | 
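Two details make this Lorenz time-marching loop work. First, `neural_net` multiplies the raw network output by t and adds the window's initial state, so the initial condition holds exactly at t = 0 for any parameters; that is why `residuals_and_weights` has no L_0 term here and the loss is the weighted residual alone. Second, at the end of each window of length t1 = 0.5 the model is evaluated at t1 and that prediction becomes `state0` for the next window. A tiny standalone check of the hard-constraint ansatz (stand-in network, illustrative only):

```python
import jax.numpy as jnp
from jax import random

W = random.normal(random.PRNGKey(0), (1, 3))
raw_net = lambda params, t: jnp.tanh(params * t + 1.0) @ W   # hypothetical tiny "network"

def constrained_net(params, t, state0):
    # Mirrors PINN.neural_net above: scale the output by t and shift by state0,
    # so u(0) == state0 by construction rather than through a penalty term.
    return raw_net(params, jnp.stack([t])) * t + state0

state0 = jnp.array([1.0, 1.0, 1.0])
assert jnp.allclose(constrained_net(0.37, 0.0, state0), state0)   # exact IC, any params
```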
x_preds = np.hstack(x_pred_list) 326 | y_preds = np.hstack(y_pred_list) 327 | z_preds = np.hstack(z_pred_list) 328 | 329 | error_x = np.linalg.norm(x_preds - states[:, 0]) / np.linalg.norm(states[:, 0]) 330 | error_y = np.linalg.norm(y_preds - states[:, 1]) / np.linalg.norm(states[:, 1]) 331 | error_z = np.linalg.norm(z_preds - states[:, 2]) / np.linalg.norm(states[:, 2]) 332 | print('Relative l2 error x: {:.3e}'.format(error_x)) 333 | print('Relative l2 error y: {:.3e}'.format(error_y)) 334 | print('Relative l2 error z: {:.3e}'.format(error_z)) 335 | 336 | 337 | # np.save('x_pred_list.npy', x_pred_list) 338 | # np.save('y_pred_list.npy', y_pred_list) 339 | # np.save('z_pred_list.npy', z_pred_list) 340 | # np.save('params_list.npy', params_list) 341 | 342 | # x_preds = np.hstack(x_pred_list) 343 | # y_preds = np.hstack(y_pred_list) 344 | # z_preds = np.hstack(z_pred_list) 345 | 346 | # error_x = np.linalg.norm(x_preds - states[:, 0]) / np.linalg.norm(states[:, 0]) 347 | # error_y = np.linalg.norm(y_preds - states[:, 1]) / np.linalg.norm(states[:, 1]) 348 | # error_z = np.linalg.norm(z_preds - states[:, 2]) / np.linalg.norm(states[:, 2]) 349 | # print('Relative l2 error x: {:.3e}'.format(error_x)) 350 | # print('Relative l2 error y: {:.3e}'.format(error_y)) 351 | # print('Relative l2 error z: {:.3e}'.format(error_z)) 352 | 353 | 354 | 355 | 356 | -------------------------------------------------------------------------------- /NS/NS.py: -------------------------------------------------------------------------------- 1 | import numpy as onp 2 | import jax.numpy as np 3 | from jax import random, grad, vmap, jit, jacfwd, jacrev 4 | from jax.example_libraries import optimizers 5 | from jax.nn import relu 6 | from jax import lax 7 | from jax.flatten_util import ravel_pytree 8 | import itertools 9 | from functools import partial 10 | from torch.utils import data 11 | from tqdm import trange 12 | 13 | 14 | # Define the neural net 15 | def init_layer(key, d_in, d_out): 16 | k1, k2 = random.split(key) 17 | glorot_stddev = 1. / np.sqrt((d_in + d_out) / 2.) 18 | W = glorot_stddev * random.normal(k1, (d_in, d_out)) 19 | b = np.zeros(d_out) 20 | return W, b 21 | 22 | 23 | def MLP(layers, activation=relu): 24 | ''' Vanilla MLP''' 25 | 26 | def init(rng_key): 27 | key, *keys = random.split(rng_key, len(layers)) 28 | params = list(map(init_layer, keys, layers[:-1], layers[1:])) 29 | return params 30 | 31 | def apply(params, inputs): 32 | for W, b in params[:-1]: 33 | outputs = np.dot(inputs, W) + b 34 | inputs = activation(outputs) 35 | W, b = params[-1] 36 | outputs = np.dot(inputs, W) + b 37 | return outputs 38 | 39 | return init, apply 40 | 41 | 42 | # Define the neural net 43 | def modified_MLP_II(layers, L_x=1.0, L_y=1.0, M_t=1, M_x=1, M_y=1, activation=relu): 44 | def xavier_init(key, d_in, d_out): 45 | glorot_stddev = 1. / np.sqrt((d_in + d_out) / 2.)
46 | W = glorot_stddev * random.normal(key, (d_in, d_out)) 47 | b = np.zeros(d_out) 48 | return W, b 49 | 50 | w_x = 2.0 * np.pi / L_x 51 | w_y = 2.0 * np.pi / L_y 52 | k_x = np.arange(1, M_x + 1) 53 | k_y = np.arange(1, M_y + 1) 54 | k_xx, k_yy = np.meshgrid(k_x, k_y) 55 | k_xx = k_xx.flatten() 56 | k_yy = k_yy.flatten() 57 | 58 | # Define input encoding function 59 | def input_encoding(t, x, y): 60 | k_t = np.power(10.0, np.arange(0, M_t + 1)) 61 | out = np.hstack([1, k_t * t, 62 | np.cos(k_x * w_x * x), np.cos(k_y * w_y * y), 63 | np.sin(k_x * w_x * x), np.sin(k_y * w_y * y), 64 | np.cos(k_xx * w_x * x) * np.cos(k_yy * w_y * y), 65 | np.cos(k_xx * w_x * x) * np.sin(k_yy * w_y * y), 66 | np.sin(k_xx * w_x * x) * np.cos(k_yy * w_y * y), 67 | np.sin(k_xx * w_x * x) * np.sin(k_yy * w_y * y)]) 68 | return out 69 | 70 | def init(rng_key): 71 | U1, b1 = xavier_init(random.PRNGKey(12345), layers[0], layers[1]) 72 | U2, b2 = xavier_init(random.PRNGKey(54321), layers[0], layers[1]) 73 | 74 | def init_layer(key, d_in, d_out): 75 | k1, k2 = random.split(key) 76 | W, b = xavier_init(k1, d_in, d_out) 77 | return W, b 78 | 79 | key, *keys = random.split(rng_key, len(layers)) 80 | params = list(map(init_layer, keys, layers[:-1], layers[1:])) 81 | return (params, U1, b1, U2, b2) 82 | 83 | def apply(params, inputs): 84 | params, U1, b1, U2, b2 = params 85 | 86 | t = inputs[0] 87 | x = inputs[1] 88 | y = inputs[2] 89 | inputs = input_encoding(t, x, y) 90 | U = activation(np.dot(inputs, U1) + b1) 91 | V = activation(np.dot(inputs, U2) + b2) 92 | for W, b in params[:-1]: 93 | outputs = activation(np.dot(inputs, W) + b) 94 | inputs = np.multiply(outputs, U) + np.multiply(1 - outputs, V) 95 | W, b = params[-1] 96 | outputs = np.dot(inputs, W) + b 97 | return outputs 98 | 99 | return init, apply 100 | 101 | 102 | # Define the neural net 103 | def modified_MLP_III(layers, L_x=1.0, L_y=1.0, M_t=1, M_x=1, M_y=1, activation=relu): 104 | def xavier_init(key, d_in, d_out): 105 | glorot_stddev = 1. / np.sqrt((d_in + d_out) / 2.) 
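The Navier-Stokes embedding in `modified_MLP_II` combines M_t + 1 rescaled time features with a full 2-D Fourier basis: the single-axis cos/sin modes plus every cos/sin tensor product, for 1 + (M_t + 1) + 2·M_x + 2·M_y + 4·M_x·M_y features in total. That is the width `layers[0]` has to match, and it is the same count that `modified_MLP_III` below splits into its temporal and spatial encoders. A standalone shape check (the mode counts here are placeholders, since the script's hyperparameter block is not part of this excerpt):

```python
import jax.numpy as jnp

def ns_encoding(t, x, y, M_t=1, M_x=5, M_y=5, L_x=2.0 * jnp.pi, L_y=2.0 * jnp.pi):
    # Mirrors modified_MLP_II's input_encoding above.
    w_x, w_y = 2.0 * jnp.pi / L_x, 2.0 * jnp.pi / L_y
    k_x, k_y = jnp.arange(1, M_x + 1), jnp.arange(1, M_y + 1)
    k_xx, k_yy = [k.flatten() for k in jnp.meshgrid(k_x, k_y)]
    k_t = jnp.power(10.0, jnp.arange(0, M_t + 1))
    return jnp.hstack([jnp.ones(1), k_t * t,
                       jnp.cos(k_x * w_x * x), jnp.cos(k_y * w_y * y),
                       jnp.sin(k_x * w_x * x), jnp.sin(k_y * w_y * y),
                       jnp.cos(k_xx * w_x * x) * jnp.cos(k_yy * w_y * y),
                       jnp.cos(k_xx * w_x * x) * jnp.sin(k_yy * w_y * y),
                       jnp.sin(k_xx * w_x * x) * jnp.cos(k_yy * w_y * y),
                       jnp.sin(k_xx * w_x * x) * jnp.sin(k_yy * w_y * y)])

M_t, M_x, M_y = 1, 5, 5
d0 = 1 + (M_t + 1) + 2 * M_x + 2 * M_y + 4 * M_x * M_y   # layers[0] must equal this
assert ns_encoding(0.1, 0.4, 1.3, M_t, M_x, M_y).shape == (d0,)
```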
106 | W = glorot_stddev * random.normal(key, (d_in, d_out)) 107 | b = np.zeros(d_out) 108 | return W, b 109 | 110 | w_x = 2.0 * np.pi / L_x 111 | w_y = 2.0 * np.pi / L_y 112 | k_x = np.arange(1, M_x + 1) 113 | k_y = np.arange(1, M_y + 1) 114 | k_xx, k_yy = np.meshgrid(k_x, k_y) 115 | k_xx = k_xx.flatten() 116 | k_yy = k_yy.flatten() 117 | 118 | # Define input encoding function 119 | def spatial_encoding(x, y, M_x, M_y): 120 | out = np.hstack([1, np.cos(k_x * w_x * x), np.cos(k_y * w_y * y), 121 | np.sin(k_x * w_x * x), np.sin(k_y * w_y * y), 122 | np.cos(k_xx * w_x * x) * np.cos(k_yy * w_y * y), 123 | np.cos(k_xx * w_x * x) * np.sin(k_yy * w_y * y), 124 | np.sin(k_xx * w_x * x) * np.cos(k_yy * w_y * y), 125 | np.sin(k_xx * w_x * x) * np.sin(k_yy * w_y * y)]) 126 | return out 127 | 128 | def temporal_encoding(t, M_t): 129 | k = np.power(10.0, np.arange(0, M_t + 1)) 130 | out = k * t 131 | return out 132 | 133 | def init(rng_key): 134 | U1, b1 = xavier_init(random.PRNGKey(12345), M_t + 1, layers[1]) 135 | U2, b2 = xavier_init(random.PRNGKey(54321), 2 * M_x + 2 * M_y + 4 * M_x * M_y + 1, layers[1]) 136 | 137 | def init_layer(key, d_in, d_out): 138 | k1, k2 = random.split(key) 139 | W, b = xavier_init(k1, d_in, d_out) 140 | return W, b 141 | 142 | key, *keys = random.split(rng_key, len(layers)) 143 | params = list(map(init_layer, keys, layers[:-1], layers[1:])) 144 | return (params, U1, b1, U2, b2) 145 | 146 | def apply(params, inputs): 147 | params, U1, b1, U2, b2 = params 148 | t = inputs[0] 149 | x = inputs[1] 150 | y = inputs[2] 151 | H_t = temporal_encoding(t, M_t) 152 | H_x = spatial_encoding(x, y, M_x, M_y) 153 | inputs = np.hstack([H_t, H_x]) 154 | U = activation(np.dot(H_t, U1) + b1) 155 | V = activation(np.dot(H_x, U2) + b2) 156 | for W, b in params[:-1]: 157 | outputs = activation(np.dot(inputs, W) + b) 158 | inputs = np.multiply(outputs, U) + np.multiply(1 - outputs, V) 159 | W, b = params[-1] 160 | outputs = np.dot(inputs, W) + b 161 | return outputs 162 | 163 | return init, apply 164 | 165 | 166 | class DataGenerator(data.Dataset): 167 | def __init__(self, t0, t1, n_t=10, n_x=64, rng_key=random.PRNGKey(1234)): 168 | 'Initialization' 169 | self.t0 = t0 170 | self.t1 = t1 + 0.01 * t1 171 | self.n_t = n_t 172 | self.n_x = n_x 173 | self.key = rng_key 174 | 175 | def __getitem__(self, index): 176 | 'Generate one batch of data' 177 | self.key, subkey = random.split(self.key) 178 | batch = self.__data_generation(subkey) 179 | return batch 180 | 181 | @partial(jit, static_argnums=(0,)) 182 | def __data_generation(self, key): 183 | 'Generates data containing batch_size samples' 184 | subkeys = random.split(key, 2) 185 | t_r = random.uniform(subkeys[0], shape=(self.n_t,), minval=self.t0, maxval=self.t1).sort() 186 | x_r = random.uniform(subkeys[1], shape=(self.n_x, 2), minval=0.0, maxval=2.0 * np.pi) 187 | batch = (t_r, x_r) 188 | return batch 189 | 190 | 191 | # Define the model 192 | class PINN: 193 | def __init__(self, key, w_exact, layers, M_t, M_x, M_y, state0, t0, t1, n_t, x_star, y_star, tol): 194 | 195 | self.w_exact = w_exact 196 | 197 | self.M_t = M_t 198 | self.M_x = M_x 199 | self.M_y = M_y 200 | 201 | # grid 202 | self.n_t = n_t 203 | self.t0 = t0 204 | self.t1 = t1 205 | eps = 0.01 * t1 206 | self.t = np.linspace(self.t0, self.t1 + eps, n_t) 207 | self.x_star = x_star 208 | self.y_star = y_star 209 | 210 | # initial state 211 | self.state0 = state0 212 | 213 | self.tol = tol 214 | self.M = np.triu(np.ones((n_t, n_t)), k=1).T 215 | 216 | self.init, self.apply = 
modified_MLP_II(layers, L_x=2 * np.pi, L_y=2 * np.pi, M_t=M_t, M_x=M_x, M_y=M_y, 217 | activation=np.tanh) 218 | params = self.init(rng_key=key) 219 | 220 | # Use optimizers to set optimizer initialization and update functions 221 | self.opt_init, self.opt_update, self.get_params = optimizers.adam(optimizers.exponential_decay(1e-3, 222 | decay_steps=10000, 223 | decay_rate=0.9)) 224 | self.opt_state = self.opt_init(params) 225 | _, self.unravel = ravel_pytree(params) 226 | 227 | self.u0_pred_fn = vmap(vmap(self.u_net, (None, None, None, 0)), (None, None, 0, None)) 228 | self.v0_pred_fn = vmap(vmap(self.v_net, (None, None, None, 0)), (None, None, 0, None)) 229 | self.w0_pred_fn = vmap(vmap(self.vorticity_net, (None, None, None, 0)), (None, None, 0, None)) 230 | self.u_pred_fn = vmap(vmap(vmap(self.u_net, (None, None, None, 0)), (None, None, 0, None)), 231 | (None, 0, None, None)) 232 | self.v_pred_fn = vmap(vmap(vmap(self.v_net, (None, None, None, 0)), (None, None, 0, None)), 233 | (None, 0, None, None)) 234 | self.w_pred_fn = vmap(vmap(vmap(self.vorticity_net, (None, None, None, 0)), (None, None, 0, None)), 235 | (None, 0, None, None)) 236 | # self.r_pred_fn = vmap(vmap(vmap(self.residual_net, (None, None, None, 0)), (None, None, 0, None)), (None, 0, None, None)) 237 | self.r_pred_fn = vmap(vmap(self.residual_net, (None, None, 0, 0)), (None, 0, None, None)) 238 | 239 | # Logger 240 | self.itercount = itertools.count() 241 | 242 | self.loss_log = [] 243 | self.loss_ics_log = [] 244 | self.loss_u0_log = [] 245 | self.loss_v0_log = [] 246 | self.loss_w0_log = [] 247 | self.loss_bcs_log = [] 248 | self.loss_res_w_log = [] 249 | self.loss_res_c_log = [] 250 | self.l2_error_log = [] 251 | 252 | def neural_net(self, params, t, x, y): 253 | z = np.stack([t, x, y]) 254 | outputs = self.apply(params, z) 255 | u = outputs[0] 256 | v = outputs[1] 257 | return u, v 258 | 259 | def u_net(self, params, t, x, y): 260 | u, _ = self.neural_net(params, t, x, y) 261 | return u 262 | 263 | def v_net(self, params, t, x, y): 264 | _, v = self.neural_net(params, t, x, y) 265 | return v 266 | 267 | def vorticity_net(self, params, t, x, y): 268 | u_y = grad(self.u_net, argnums=3)(params, t, x, y) 269 | v_x = grad(self.v_net, argnums=2)(params, t, x, y) 270 | w = v_x - u_y 271 | return w 272 | 273 | def residual_net(self, params, t, x, y): 274 | 275 | u, v = self.neural_net(params, t, x, y) 276 | 277 | u_x = grad(self.u_net, argnums=2)(params, t, x, y) 278 | v_y = grad(self.v_net, argnums=3)(params, t, x, y) 279 | 280 | w_t = grad(self.vorticity_net, argnums=1)(params, t, x, y) 281 | w_x = grad(self.vorticity_net, argnums=2)(params, t, x, y) 282 | w_y = grad(self.vorticity_net, argnums=3)(params, t, x, y) 283 | 284 | w_xx = grad(grad(self.vorticity_net, argnums=2), argnums=2)(params, t, x, y) 285 | w_yy = grad(grad(self.vorticity_net, argnums=3), argnums=3)(params, t, x, y) 286 | 287 | res_w = w_t + u * w_x + v * w_y - nu * (w_xx + w_yy) 288 | res_c = u_x + v_y 289 | 290 | return res_w, res_c 291 | 292 | @partial(jit, static_argnums=(0,)) 293 | def residuals_and_weights(self, params, tol, batch): 294 | t_r, x_r = batch 295 | loss_u0, loss_v0, loss_w0 = self.loss_ics(params) 296 | L_0 = 1e5 * (loss_u0 + loss_v0 + loss_w0) 297 | res_w_pred, res_c_pred = self.r_pred_fn(params, t_r, x_r[:, 0], x_r[:, 1]) 298 | L_t = np.mean(res_w_pred ** 2 + 100 * res_c_pred ** 2, axis=1) 299 | W = lax.stop_gradient(np.exp(- tol * (self.M @ L_t + L_0))) 300 | return L_0, L_t, W 301 | 302 | @partial(jit, static_argnums=(0,)) 303 | def 
loss_ics(self, params): 304 | # Compute forward pass 305 | u0_pred = self.u0_pred_fn(params, 0.0, self.x_star, self.y_star) 306 | v0_pred = self.v0_pred_fn(params, 0.0, self.x_star, self.y_star) 307 | w0_pred = self.w0_pred_fn(params, 0.0, self.x_star, self.y_star) 308 | # Compute loss 309 | loss_u0 = np.mean((u0_pred - self.state0[0, :, :]) ** 2) 310 | loss_v0 = np.mean((v0_pred - self.state0[1, :, :]) ** 2) 311 | loss_w0 = np.mean((w0_pred - self.state0[2, :, :]) ** 2) 312 | return loss_u0, loss_v0, loss_w0 313 | 314 | @partial(jit, static_argnums=(0,)) 315 | def loss_res(self, params, batch): 316 | t_r, x_r = batch 317 | # Compute forward pass 318 | res_w_pred, res_c_pred = self.r_pred_fn(params, t_r, x_r[:, 0], x_r[:, 1]) 319 | # Compute loss 320 | loss_res_w = np.mean(res_w_pred ** 2) 321 | loss_res_c = np.mean(res_c_pred ** 2) 322 | return loss_res_w, loss_res_c 323 | 324 | @partial(jit, static_argnums=(0,)) 325 | def loss(self, params, batch): 326 | 327 | L_0, L_t, W = self.residuals_and_weights(params, self.tol, batch) 328 | # Compute loss 329 | loss = np.mean(W * L_t + L_0) 330 | return loss 331 | 332 | @partial(jit, static_argnums=(0,)) 333 | def compute_l2_error(self, params): 334 | w_pred = self.w_pred_fn(params, t_star[:num_step], x_star, y_star) 335 | l2_error = np.linalg.norm(w_pred - self.w_exact) / np.linalg.norm(self.w_exact) 336 | return l2_error 337 | 338 | # Define a compiled update step 339 | @partial(jit, static_argnums=(0,)) 340 | def step(self, i, opt_state, batch): 341 | params = self.get_params(opt_state) 342 | g = grad(self.loss)(params, batch) 343 | return self.opt_update(i, g, opt_state) 344 | 345 | # Optimize parameters in a loop 346 | def train(self, dataset, nIter=10000): 347 | res_data = iter(dataset) 348 | pbar = trange(nIter) 349 | # Main training loop 350 | for it in pbar: 351 | batch = next(res_data) 352 | self.current_count = next(self.itercount) 353 | self.opt_state = self.step(self.current_count, self.opt_state, batch) 354 | 355 | if it % 1000 == 0: 356 | params = self.get_params(self.opt_state) 357 | 358 | l2_error_value = self.compute_l2_error(params) 359 | 360 | loss_value = self.loss(params, batch) 361 | 362 | loss_u0_value, loss_v0_value, loss_w0_value = self.loss_ics(params) 363 | loss_res_w_value, loss_res_c_value = self.loss_res(params, batch) 364 | _, _, W_value = self.residuals_and_weights(params, self.tol, batch) 365 | 366 | self.l2_error_log.append(l2_error_value) 367 | self.loss_log.append(loss_value) 368 | self.loss_u0_log.append(loss_u0_value) 369 | self.loss_v0_log.append(loss_v0_value) 370 | self.loss_w0_log.append(loss_w0_value) 371 | self.loss_res_w_log.append(loss_res_w_value) 372 | self.loss_res_c_log.append(loss_res_c_value) 373 | 374 | pbar.set_postfix({'l2 error': l2_error_value, 375 | 'Loss': loss_value, 376 | 'loss_u0': loss_u0_value, 377 | 'loss_v0': loss_v0_value, 378 | 'loss_w0': loss_w0_value, 379 | 'loss_res_w': loss_res_w_value, 380 | 'loss_res_c': loss_res_c_value, 381 | 'W_min': W_value.min()}) 382 | 383 | if W_value.min() > 0.99: 384 | break 385 | 386 | 387 | data = np.load('../data/NS.npy', allow_pickle=True).item() 388 | # Test data 389 | sol = data['sol'] 390 | 391 | t_star = data['t'] 392 | x_star = data['x'] 393 | y_star = data['y'] 394 | nu = data['viscosity'] 395 | 396 | # downsampling (no-op here: the full-resolution grid is used) 397 | sol = sol 398 | x_star = x_star 399 | y_star = y_star 400 | 401 | # Create PINNs model 402 | key = random.PRNGKey(1234) 403 | 404 | u0 = data['u0'] 405 | v0 = data['v0'] 406 | w0 = data['w0'] 407 | state0 = np.stack([u0, v0, w0])
408 | M_t = 2 409 | M_x = 5 410 | M_y = 5 411 | d0 = 2 * M_x + 2 * M_y + 4 * M_x * M_y + M_t + 2 412 | layers = [d0, 128, 128, 128, 128, 2] 413 | 414 | num_step = 10 415 | t0 = 0.0 416 | t1 = t_star[num_step] 417 | n_t = 32 418 | tol = 1.0 419 | tol_list = [1e-3, 1e-2, 1e-1, 1e0] 420 | 421 | # Create data set 422 | n_x = 256 423 | dataset = DataGenerator(t0, t1, n_t, n_x) 424 | 425 | N = 20 426 | w_pred_list = [] 427 | params_list = [] 428 | losses_list = [] 429 | 430 | for k in range(N): 431 | # Initialize model 432 | print('Final Time: {}'.format(k + 1)) 433 | w_exact = sol[num_step * k: num_step * (k + 1), :, :] 434 | model = PINN(key, w_exact, layers, M_t, M_x, M_y, state0, t0, t1, n_t, x_star, y_star, tol) 435 | 436 | # Train 437 | for tol in tol_list: 438 | model.tol = tol 439 | print('tol:', model.tol) 440 | # Train 441 | model.train(dataset, nIter=100000) 442 | 443 | # Store 444 | params = model.get_params(model.opt_state) 445 | w_pred = model.w_pred_fn(params, t_star[:num_step], x_star, y_star) 446 | w_pred_list.append(w_pred) 447 | flat_params, _ = ravel_pytree(params) 448 | params_list.append(flat_params) 449 | losses_list.append([model.l2_error_log, 450 | model.loss_log, 451 | model.loss_u0_log, 452 | model.loss_v0_log, 453 | model.loss_w0_log, 454 | model.loss_res_w_log, 455 | model.loss_res_c_log, ]) 456 | 457 | np.save('causal_w_pred_list.npy', w_pred_list) 458 | np.save('causal_params_list.npy', params_list) 459 | np.save('causal_losses_list.npy', losses_list) 460 | 461 | # error 462 | w_preds = np.vstack(w_pred_list) 463 | error = np.linalg.norm(w_preds - sol[:num_step * (k + 1), :, :]) / np.linalg.norm(sol[:num_step * (k + 1), :, :]) 464 | print('Relative l2 error: {:.3e}'.format(error)) 465 | 466 | params = model.get_params(model.opt_state) 467 | u0_pred = model.u0_pred_fn(params, t1, x_star, y_star) 468 | v0_pred = model.v0_pred_fn(params, t1, x_star, y_star) 469 | w0_pred = model.w0_pred_fn(params, t1, x_star, y_star) 470 | state0 = np.stack([u0_pred, v0_pred, w0_pred]) 471 | 472 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Respecting causality is all you need for training physics-informed neural networks 2 | 3 | # ⚠️The proposed causal training algorithm cannot be used for commercial purposes (protected by a patent at the University of Pennsylvania).⚠️ 4 | 5 | Code and data (available upon request) accompanying the manuscript titled "[Respecting causality is all you need for training physics-informed neural networks](https://arxiv.org/abs/2203.07404)", authored by Sifan Wang, Shyam Sankaran, and Paris Perdikaris. 6 | 7 | # Abstract 8 | 9 | While the popularity of physics-informed neural networks (PINNs) is steadily rising, to this date PINNs have not been successful in simulating dynamical systems whose solution exhibits multi-scale, chaotic or turbulent behavior. In this work we attribute this shortcoming to the inability of existing PINNs formulations to respect the spatio-temporal causal structure that is inherent to the evolution of physical systems. We argue that this is a fundamental limitation and a key source of error that can ultimately steer PINN models to converge towards erroneous solutions. We address this pathology by proposing a simple re-formulation of PINNs loss functions that can explicitly account for physical causality during model training. 
We demonstrate that this simple modification alone is enough to introduce significant accuracy improvements, as well as a practical quantitative mechanism for assessing the convergence of a PINNs model. We provide state-of-the-art numerical results across a series of benchmarks for which existing PINNs formulations fail, including the chaotic Lorenz system, the Kuramoto–Sivashinsky equation in the chaotic regime, and the Navier-Stokes equations in the turbulent regime. To the best of our knowledge, this is the first time that PINNs have been successful in simulating such systems, introducing new opportunities for their applicability to problems of industrial complexity. 10 | 11 | # Citation 12 | 13 | @article{wang2024respecting, 14 | title={Respecting causality for training physics-informed neural networks}, 15 | author={Wang, Sifan and Sankaran, Shyam and Perdikaris, Paris}, 16 | journal={Computer Methods in Applied Mechanics and Engineering}, 17 | volume={421}, 18 | pages={116813}, 19 | year={2024}, 20 | publisher={Elsevier} 21 | } 22 | 23 | 24 | # Examples 25 | 26 | ### Allen–Cahn equation 27 | 28 | https://user-images.githubusercontent.com/70182613/160253357-7936e254-ba60-4a9d-abd6-de761e3075c9.mp4 29 | 30 | ### Kuramoto–Sivashinsky equation 31 | 32 | https://user-images.githubusercontent.com/3844367/152894380-3910ee92-6f9b-473b-9942-3d3919f2f22d.mp4 33 | 34 | ### Navier-Stokes equation 35 | 36 | https://user-images.githubusercontent.com/3844367/152894393-6fbc5e1e-f2b0-419e-aa74-3ecb17d0e23e.mp4 37 | 38 | # License 39 | 40 | Creative Commons License
This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License. 41 | -------------------------------------------------------------------------------- /animations/AC.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/CausalPINNs/0d1d83f87f6fa2d6756d0e35cac94a7be6183841/animations/AC.mp4 -------------------------------------------------------------------------------- /animations/KS.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/CausalPINNs/0d1d83f87f6fa2d6756d0e35cac94a7be6183841/animations/KS.mp4 -------------------------------------------------------------------------------- /animations/NS.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/CausalPINNs/0d1d83f87f6fa2d6756d0e35cac94a7be6183841/animations/NS.mp4 -------------------------------------------------------------------------------- /animations/Readme.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /data/AC.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/CausalPINNs/0d1d83f87f6fa2d6756d0e35cac94a7be6183841/data/AC.mat -------------------------------------------------------------------------------- /data/NS.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/CausalPINNs/0d1d83f87f6fa2d6756d0e35cac94a7be6183841/data/NS.npy -------------------------------------------------------------------------------- /data/ks_chaotic.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/CausalPINNs/0d1d83f87f6fa2d6756d0e35cac94a7be6183841/data/ks_chaotic.mat -------------------------------------------------------------------------------- /data/ks_simple.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PredictiveIntelligenceLab/CausalPINNs/0d1d83f87f6fa2d6756d0e35cac94a7be6183841/data/ks_simple.mat -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jax==0.4.5 2 | matplotlib==3.5.2 3 | numpy==1.21.5 4 | scikit_learn==1.0.2 5 | scipy==1.9.1 6 | torch==1.13.1 7 | tqdm==4.64.1 8 | --------------------------------------------------------------------------------
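The causal re-formulation described in the README abstract amounts to exponentially down-weighting the residual loss at later times until earlier times are resolved; `residuals_and_weights` and `loss` in `NS/NS.py` implement this directly. The snippet below is a minimal, self-contained sketch of that weighting in isolation, assuming only that the per-time-slice residual losses are already collected in a vector `L_t`; the helper names `causal_weights` and `causal_loss` are illustrative and do not appear in the repository.

    import jax.numpy as jnp
    from jax import lax

    def causal_weights(L_t, tol):
        # w_i = exp(-tol * sum_{k<i} L_k): time slice i only receives weight once
        # the residuals of all earlier slices are small.
        n_t = L_t.shape[0]
        M = jnp.triu(jnp.ones((n_t, n_t)), k=1).T  # strictly lower-triangular mask (k < i)
        # stop_gradient: the weights modulate the loss but are not optimized themselves
        return lax.stop_gradient(jnp.exp(-tol * (M @ L_t)))

    def causal_loss(L_t, tol):
        # Causally weighted residual loss (initial-condition terms omitted in this sketch).
        return jnp.mean(causal_weights(L_t, tol) * L_t)

In the training scripts above, the initial-condition loss is additionally folded into the exponent, `tol` is annealed over `tol_list`, and training at a given tolerance stops once the minimum weight exceeds 0.99.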