├── Antiderivative ├── DeepONet │ └── DeepONet_Antiderivative.py ├── FNO │ ├── FNOAntiderivative.py │ └── utilities3.py └── LOCA │ └── LOCAAntiderivative.py ├── Climate_Modeling ├── DeepONet │ └── DeepONet_Weatherg.py ├── FNO │ ├── Adam.py │ ├── utilities3.py │ └── weather_FNO.py └── LOCA │ └── LOCAWeather.py ├── Darcy ├── DeepONet │ └── DeepONet_Darcy.py ├── FNO │ ├── FNODarcy.py │ └── utilities3.py └── LOCA │ └── LOCADarcy.py ├── MMNIST ├── DeepONet │ └── DeepONet_MNIST.py ├── FNO │ ├── Adam.py │ ├── FNOMMNIST.py │ └── utilities3.py └── LOCA │ └── LOCAMMNIST.py ├── PushForward ├── DeepONet │ └── DeepONet_Pushforward.py └── LOCA │ ├── LOCAPushforward.py │ └── LOCA_closetoDON.py ├── README.md └── ShallowWaters ├── DeepONet └── DeepOnet_SW.py ├── FNO ├── Adam.py └── FNOSW.py └── LOCA └── LOCAShallowWater.py /Antiderivative/DeepONet/DeepONet_Antiderivative.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from numpy.polynomial import polyutils 5 | 6 | from jax.experimental.stax import Dense, Gelu 7 | from jax.experimental import stax 8 | import os 9 | 10 | from scipy.integrate import solve_ivp 11 | 12 | import timeit 13 | 14 | from jax.experimental import optimizers 15 | 16 | from absl import app 17 | import jax 18 | import jax.numpy as jnp 19 | import numpy as np 20 | from jax.numpy.linalg import norm 21 | 22 | from jax import random, grad, vmap, jit, vjp 23 | from functools import partial 24 | 25 | from torch.utils import data 26 | 27 | from tqdm import trange 28 | 29 | import itertools 30 | 31 | import scipy.signal as signal 32 | from kymatio.numpy import Scattering1D 33 | 34 | from jax.experimental.ode import odeint 35 | from jax.config import config 36 | from numpy.polynomial.legendre import leggauss 37 | 38 | def get_freer_gpu(): 39 | os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp') 40 | memory_available = [int(x.split()[2]) for x in open('tmp', 'r').readlines()] 41 | return str(np.argmax(memory_available)) 42 | 43 | os.environ['CUDA_VISIBLE_DEVICES']= get_freer_gpu() 44 | os.environ['XLA_PYTHON_CLIENT_PREALLOCATE']="False" 45 | 46 | class DataGenerator(data.Dataset): 47 | def __init__(self, u, y, s, 48 | batch_size=100, rng_key=random.PRNGKey(1234)): 49 | 'Initialization' 50 | self.u = u 51 | self.y = y 52 | self.s = s 53 | self.N = u.shape[0] 54 | self.batch_size = batch_size 55 | self.key = rng_key 56 | 57 | # @partial(jit, static_argnums=(0,)) 58 | def __getitem__(self, index): 59 | 'Generate one batch of data' 60 | self.key, subkey = random.split(self.key) 61 | inputs,outputs = self.__data_generation(subkey) 62 | return inputs, outputs 63 | 64 | @partial(jit, static_argnums=(0,)) 65 | def __data_generation(self, key): 66 | 'Generates data containing batch_size samples' 67 | idx = random.choice(key, self.N, (self.batch_size,), replace=False) 68 | s = self.s[idx,:,:] 69 | u = self.u[idx,:,:] 70 | y = self.y[idx,:,:] 71 | inputs = (u, y) 72 | return inputs, s 73 | 74 | class PositionalEncodingY: 75 | def __init__(self, Y, d_model, max_len = 100,H=20): 76 | self.d_model = d_model 77 | self.Y = Y 78 | self.max_len = max_len 79 | self.H = H 80 | 81 | @partial(jit, static_argnums=(0,)) 82 | def forward(self, x): 83 | self.pe = np.zeros((x.shape[0], self.max_len, self.H)) 84 | T = jnp.asarray(self.Y[:,:,0:1]) 85 | position = jnp.tile(T,(1,1,self.H)) 86 | div_term = 2**jnp.arange(0,int(self.H/2),1)*jnp.pi 87 | 
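        # The two updates below interleave cos/sin features of the query coordinate,
        # one frequency (2**k * pi) per channel pair: even channels of self.pe get
        # cos(y * 2**k * pi), odd channels get sin(y * 2**k * pi), and the H extra
        # channels are concatenated onto the query before it enters the trunk network.
        # Note: jax.ops.index_update has been removed in newer JAX releases; a sketch
        # of the equivalent functional update (assuming a recent jax.numpy) would be
        #     pe = jnp.zeros((x.shape[0], self.max_len, self.H))
        #     pe = pe.at[:, :, 0::2].set(jnp.cos(position[:, :, 0::2] * div_term))
        #     pe = pe.at[:, :, 1::2].set(jnp.sin(position[:, :, 1::2] * div_term))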
self.pe = jax.ops.index_update(self.pe, jax.ops.index[:,:,0::2], jnp.cos(position[:,:,0::2] * div_term)) 88 | self.pe = jax.ops.index_update(self.pe, jax.ops.index[:,:,1::2], jnp.sin(position[:,:,1::2] * div_term)) 89 | x = jnp.concatenate([x, self.pe],axis=-1) 90 | return x 91 | 92 | class PositionalEncodingU: 93 | def __init__(self, Y, d_model, max_len = 100,H=20): 94 | self.d_model = d_model 95 | self.Y = Y 96 | self.max_len = max_len 97 | self.H = H 98 | 99 | @partial(jit, static_argnums=(0,)) 100 | def forward(self, x): 101 | self.pe = np.zeros((x.shape[0], self.max_len, self.H)) 102 | T = jnp.asarray(self.Y[:,:,0:1]) 103 | position = jnp.tile(T,(1,1,self.H)) 104 | div_term = 2**jnp.arange(0,int(self.H/2),1)*jnp.pi 105 | self.pe = jax.ops.index_update(self.pe, jax.ops.index[:,:,0::2], jnp.cos(position[:,:,0::2] * div_term)) 106 | self.pe = jax.ops.index_update(self.pe, jax.ops.index[:,:,1::2], jnp.sin(position[:,:,1::2] * div_term)) 107 | x = jnp.concatenate([x, self.pe],axis=-1) 108 | return x 109 | 110 | class DON: 111 | def __init__(self,branch_layers, trunk_layers , m=100, P=100, mn=None, std=None): 112 | # Network initialization and evaluation functions 113 | 114 | seed = np.random.randint(10000) 115 | self.branch_init, self.branch_apply = self.init_NN(branch_layers, activation=Gelu) 116 | self.in_shape = (-1, branch_layers[0]) 117 | self.out_shape, branch_params = self.branch_init(random.PRNGKey(seed), self.in_shape) 118 | 119 | seed = np.random.randint(10000) 120 | self.trunk_init, self.trunk_apply = self.init_NN(trunk_layers, activation=Gelu) 121 | self.in_shape = (-1, trunk_layers[0]) 122 | self.out_shape, trunk_params = self.trunk_init(random.PRNGKey(seed), self.in_shape) 123 | 124 | params = (trunk_params, branch_params) 125 | # Use optimizers to set optimizer initialization and update functions 126 | self.opt_init,self.opt_update,self.get_params = optimizers.adam(optimizers.exponential_decay(1e-3, 127 | decay_steps=100, 128 | decay_rate=0.99)) 129 | self.opt_state = self.opt_init(params) 130 | # Logger 131 | self.itercount = itertools.count() 132 | self.loss_log = [] 133 | self.mean = mn 134 | self.std = std 135 | 136 | 137 | def init_NN(self, Q, activation=Gelu): 138 | layers = [] 139 | num_layers = len(Q) 140 | if num_layers < 2: 141 | net_init, net_apply = stax.serial() 142 | else: 143 | for i in range(0, num_layers-1): 144 | layers.append(Dense(Q[i+1])) 145 | layers.append(activation) 146 | layers.append(Dense(Q[-1])) 147 | net_init, net_apply = stax.serial(*layers) 148 | return net_init, net_apply 149 | 150 | @partial(jax.jit, static_argnums=0) 151 | def DON(self, params, inputs, ds=1): 152 | trunk_params, branch_params = params 153 | inputsxu, inputsy = inputs 154 | t = self.trunk_apply(trunk_params, inputsy).reshape(inputsy.shape[0], inputsy.shape[1], ds, int(100/ds)) 155 | b = self.branch_apply(branch_params, inputsxu.reshape(inputsxu.shape[0],1,inputsxu.shape[1]*inputsxu.shape[2])) 156 | b = b.reshape(b.shape[0],int(b.shape[2]/ds),ds) 157 | Guy = jnp.einsum("ijkl,ilk->ijk", t,b) 158 | return Guy 159 | 160 | @partial(jax.jit, static_argnums=0) 161 | def loss(self, params, batch): 162 | inputs, y = batch 163 | y_pred = self.DON(params,inputs) 164 | y = y*self.std + self.mean 165 | y_pred = y_pred*self.std + self.mean 166 | loss = np.mean((y.flatten() - y_pred.flatten())**2) 167 | return loss 168 | 169 | @partial(jax.jit, static_argnums=0) 170 | def lossT(self, params, batch): 171 | inputs, outputs = batch 172 | y_pred = self.DON(params,inputs) 173 | y_pred = 
y_pred*self.std + self.mean 174 | loss = np.mean((outputs.flatten() - y_pred.flatten())**2) 175 | return loss 176 | 177 | @partial(jax.jit, static_argnums=0) 178 | def L2errorT(self, params, batch): 179 | inputs, y = batch 180 | y_pred = self.DON(params,inputs) 181 | y_pred = y_pred*self.std + self.mean 182 | return norm(y.flatten() - y_pred.flatten(), 2)/norm(y.flatten(),2) 183 | 184 | @partial(jax.jit, static_argnums=0) 185 | def L2error(self, params, batch): 186 | inputs, y = batch 187 | y_pred = self.DON(params,inputs) 188 | y = y*self.std + self.mean 189 | y_pred = y_pred*self.std + self.mean 190 | return norm(y.flatten() - y_pred.flatten(), 2)/norm(y.flatten(),2) 191 | 192 | 193 | @partial(jit, static_argnums=(0,)) 194 | def step(self, i, opt_state, batch): 195 | params = self.get_params(opt_state) 196 | g = grad(self.loss)(params, batch) 197 | return self.opt_update(i, g, opt_state) 198 | 199 | def train(self, train_dataset, test_dataset, nIter = 10000): 200 | train_data = iter(train_dataset) 201 | test_data = iter(test_dataset) 202 | 203 | pbar = trange(nIter) 204 | for it in pbar: 205 | train_batch = next(train_data) 206 | test_batch = next(test_data) 207 | 208 | self.opt_state = self.step(next(self.itercount), self.opt_state, train_batch) 209 | 210 | if it % 100 == 0: 211 | params = self.get_params(self.opt_state) 212 | 213 | loss_train = self.loss(params, train_batch) 214 | loss_test = self.lossT(params, test_batch) 215 | 216 | errorTrain = self.L2error(params, train_batch) 217 | errorTest = self.L2errorT(params, test_batch) 218 | 219 | self.loss_log.append(loss_train) 220 | 221 | pbar.set_postfix({'Training loss': loss_train, 222 | 'Testing loss' : loss_test, 223 | 'Test error': errorTest, 224 | 'Train error': errorTrain}) 225 | 226 | @partial(jit, static_argnums=(0,)) 227 | def predict(self, params, inputs): 228 | s_pred = self.DON(params,inputs) 229 | return s_pred 230 | 231 | @partial(jit, static_argnums=(0,)) 232 | def predictT(self, params, inputs): 233 | s_pred = self.DON(params,inputs) 234 | return s_pred 235 | 236 | def ravel_list(self, *lst): 237 | return jnp.concatenate([jnp.ravel(elt) for elt in lst]) if lst else jnp.array([]) 238 | 239 | def ravel_pytree(self, pytree): 240 | leaves, treedef = jax.tree_util.tree_flatten(pytree) 241 | flat, unravel_list = vjp(self.ravel_list, *leaves) 242 | unravel_pytree = lambda flat: jax.tree_util.tree_unflatten(treedef, unravel_list(flat)) 243 | return flat, unravel_pytree 244 | 245 | def count_params(self, params): 246 | trunk_params, branch_params = params 247 | blv, _ = self.ravel_pytree(branch_params) 248 | tlv, _ = self.ravel_pytree(trunk_params) 249 | print("The number of model parameters is:",blv.shape[0]+tlv.shape[0]) 250 | 251 | 252 | # Define RBF kernel 253 | def RBF(x1, x2, params): 254 | output_scale, lengthscales = params 255 | diffs = jnp.expand_dims(x1 / lengthscales, 1) - \ 256 | jnp.expand_dims(x2 / lengthscales, 0) 257 | r2 = jnp.sum(diffs**2, axis=2) 258 | return output_scale * jnp.exp(-0.5 * r2) 259 | 260 | # Geneate training data corresponding to one input sample 261 | def generate_one_training_data(key, m=100, P=1): 262 | # Sample GP prior at a fine grid 263 | N = 512 264 | length_scale = 0.9 265 | gp_params = (1.0, length_scale) 266 | # key1, key2 = random.split(key,num=2) 267 | # z = random.uniform(key1, minval=-2, maxval=2) 268 | # output_scale = 10**z 269 | # z = random.uniform(key2, minval=-2, maxval=0) 270 | # length_scale = 10**z 271 | # gp_params = (output_scale, length_scale) 272 | jitter = 1e-10 273 
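    # The lines below draw one sample path u ~ GP(0, K) on a fine grid of N points:
    # K is the RBF Gram matrix, K + jitter*I = L L^T is its Cholesky factorisation,
    # and gp_sample = L @ z with z ~ N(0, I), so that Cov(gp_sample) = K.
    # The target s is then the antiderivative of u: odeint integrates ds/dt = u(t)
    # from s(0) = 0 and evaluates it at the P output locations y.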
| X = jnp.linspace(0, 1, N)[:,None] 274 | K = RBF(X, X, gp_params) 275 | L = jnp.linalg.cholesky(K + jitter*jnp.eye(N)) 276 | gp_sample = jnp.dot(L, random.normal(key, (N,))) 277 | 278 | # Create a callable interpolation function 279 | u_fn = lambda x, t: jnp.interp(t, X.flatten(), gp_sample) 280 | 281 | # Ijnput sensor locations and measurements 282 | x = jnp.linspace(0, 1, m) 283 | u = vmap(u_fn, in_axes=(None,0))(0.0, x) 284 | 285 | # Output sensor locations and measurements 286 | y = jnp.linspace(0, 1, P) 287 | s = odeint(u_fn, 0.0, y) 288 | return u, y, s 289 | 290 | # Geneate test data corresponding to one input sample 291 | def generate_one_test_data(key, m=100, P=100): 292 | # Sample GP prior at a fine grid 293 | N = 512 294 | length_scale = 0.1 295 | gp_params = (1.0, length_scale) 296 | # key1, key2 = random.split(key,num=2) 297 | # z = random.uniform(key1, minval=-2, maxval=2) 298 | # output_scale = 10**z 299 | # z = random.uniform(key2, minval=-2, maxval=0) 300 | # length_scale = 10**z 301 | # gp_params = (output_scale, length_scale) 302 | jitter = 1e-10 303 | X = jnp.linspace(0, 1, N)[:,None] 304 | K = RBF(X, X, gp_params) 305 | L = jnp.linalg.cholesky(K + jitter*jnp.eye(N)) 306 | gp_sample = jnp.dot(L, random.normal(key, (N,))) 307 | # Create a callable interpolation function 308 | u_fn = lambda x, t: jnp.interp(t, X.flatten(), gp_sample) 309 | # Input sensor locations and measurements 310 | x = jnp.linspace(0, 1, m) 311 | u = vmap(u_fn, in_axes=(None,0))(0.0, x) 312 | # Output sensor locations and measurements 313 | y = jnp.linspace(0, 1, P) 314 | s = odeint(u_fn, 0.0, y) 315 | return u, y, s 316 | 317 | # Geneate training data corresponding to N input sample 318 | def generate_training_data(key, N, m, P): 319 | config.update("jax_enable_x64", True) 320 | keys = random.split(key, N) 321 | gen_fn = jit(lambda key: generate_one_training_data(key, m, P)) 322 | u_train, y_train, s_train = vmap(gen_fn)(keys) 323 | config.update("jax_enable_x64", False) 324 | return u_train, y_train, s_train 325 | 326 | # Geneate test data corresponding to N input sample 327 | def generate_test_data(key, N, m, P): 328 | config.update("jax_enable_x64", True) 329 | keys = random.split(key, N) 330 | gen_fn = jit(lambda key: generate_one_test_data(key, m, P)) 331 | u, y, s = vmap(gen_fn)(keys) 332 | config.update("jax_enable_x64", False) 333 | return u, y, s 334 | 335 | TRAINING_ITERATIONS = 50000 336 | P = 100 337 | m = 1000 338 | num_train = 1000 339 | num_test = 1000 340 | training_batch_size = 100 341 | du = 1 342 | dy = 1 343 | ds = 1 344 | n_hat = 100 345 | Nx = P 346 | index = 9 347 | length_scale = 0.9 348 | H_y = 2 349 | H_u = 2 350 | 351 | # Create the dataset 352 | key_train = random.PRNGKey(0) 353 | U_train, y_train, s_train = generate_training_data(key_train, num_train, m, Nx) 354 | key_test = random.PRNGKey(12345) 355 | U_test, y_test, s_test = generate_test_data(key_test, num_test, m, Nx) 356 | 357 | # Make all array to be jax numpy format 358 | y_train = jnp.asarray(y_train) 359 | s_train = jnp.asarray(s_train) 360 | U_train = jnp.asarray(U_train) 361 | 362 | y_test = jnp.asarray(y_test) 363 | s_test = jnp.asarray(s_test) 364 | U_test = jnp.asarray(U_test) 365 | 366 | U_train = jnp.reshape(U_train,(num_test,m,du)) 367 | y_train = jnp.reshape(y_train,(num_train,P,dy)) 368 | s_train = jnp.reshape(s_train,(num_train,P,ds)) 369 | 370 | U_test = jnp.reshape(U_test,(num_test,m,du)) 371 | y_test = jnp.reshape(y_test,(num_test,P,dy)) 372 | s_test = jnp.reshape(s_test,(num_test,P,ds)) 373 | 374 
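# Shapes at this point: U_train is (num_train, m, du), y_train is (num_train, P, dy)
# and s_train is (num_train, P, ds), with the test arrays shaped analogously.
# The positional encodings applied next append H_y (for y) and H_u (for U) sinusoidal
# channels to the last axis, which is why the layer widths defined further down are
# branch_layers[0] = m*(du*H_u + du) and trunk_layers[0] = H_y*dy + dy.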
| pos_encodingy = PositionalEncodingY(y_train,int(y_train.shape[1]*y_train.shape[2]), max_len = P, H=H_y) 375 | y_train = pos_encodingy.forward(y_train) 376 | del pos_encodingy 377 | 378 | pos_encodingyt = PositionalEncodingY(y_test,int(y_test.shape[1]*y_test.shape[2]), max_len = P, H=H_y) 379 | y_test = pos_encodingyt.forward(y_test) 380 | del pos_encodingyt 381 | 382 | pos_encodingy = PositionalEncodingU(U_train,int(U_train.shape[1]*U_train.shape[2]), max_len = m, H=H_u) 383 | U_train = pos_encodingy.forward(U_train) 384 | del pos_encodingy 385 | 386 | pos_encodingyt = PositionalEncodingU(U_test,int(U_test.shape[1]*U_test.shape[2]), max_len = m, H=H_u) 387 | U_test = pos_encodingyt.forward(U_test) 388 | del pos_encodingyt 389 | 390 | s_train_mean = jnp.mean(s_train,axis=0) 391 | s_train_std = jnp.std(s_train,axis=0) + 1e-03 392 | 393 | s_train = (s_train - s_train_mean)/s_train_std 394 | 395 | # Perform the scattering transform for the inputs yh 396 | train_dataset = DataGenerator(U_train, y_train, s_train, training_batch_size) 397 | train_dataset = iter(train_dataset) 398 | 399 | test_dataset = DataGenerator(U_test, y_test, s_test, training_batch_size) 400 | test_dataset = iter(test_dataset) 401 | 402 | branch_layers = [m*(du*H_u+du), 512, 512, ds*n_hat] 403 | trunk_layers = [H_y*dy + dy, 512, 512, ds*n_hat] 404 | 405 | model = DON(branch_layers, trunk_layers, m=m, P=P, mn=s_train_mean, std=s_train_std) 406 | 407 | model.count_params(model.get_params(model.opt_state)) 408 | 409 | start_time = timeit.default_timer() 410 | model.train(train_dataset, test_dataset, nIter=TRAINING_ITERATIONS) 411 | elapsed = timeit.default_timer() - start_time 412 | print("The training wall-clock time is seconds is equal to %f seconds"%elapsed) 413 | 414 | params = model.get_params(model.opt_state) 415 | 416 | uCNN_test = model.predictT(params, (U_test, y_test)) 417 | test_error_u = [] 418 | for i in range(0,num_train): 419 | test_error_u.append(norm(s_test[i,:,0]- uCNN_test[i,:,0],2)/norm(s_test[i,:,0],2)) 420 | print("The average test u error is %e the standard deviation is %e the min error is %e and the max error is %e"%(np.mean(test_error_u),np.std(test_error_u),np.min(test_error_u),np.max(test_error_u))) 421 | 422 | uCNN_train = model.predict(params, (U_train, y_train)) 423 | train_error_u = [] 424 | for i in range(0,num_test): 425 | train_error_u.append(norm(s_train[i,:,0]- uCNN_train[i,:,0],2)/norm(s_train[i,:,0],2)) 426 | print("The average train u error is %e the standard deviation is %e the min error is %e and the max error is %e"%(np.mean(train_error_u),np.std(train_error_u),np.min(train_error_u),np.max(train_error_u))) 427 | 428 | 429 | np.savez_compressed("/scratch/gkissas/Antiderivative/DON/Antiderivative_test_P%d_m%d_ls%f_id%d_DON.npz"%(P,m,length_scale,index), uCNN_super_all_test=uCNN_test, U_test=U_test, s_all_test=s_test, test_error=test_error_u) -------------------------------------------------------------------------------- /Antiderivative/FNO/FNOAntiderivative.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Zongyi Li 3 | This file is the Fourier Neural Operator for 1D problem such as the (time-independent) Burgers equation discussed in Section 5.1 in the [paper](https://arxiv.org/pdf/2010.08895.pdf). 
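Here the same 1D architecture is applied to the antiderivative benchmark generated
below. The spectral layer applies a learned complex linear transform to the lowest
`modes1` modes of rfft(x), zeroes the remaining modes, and maps back with irfft; each
Fourier block then evaluates gelu(W*u + K*u), with W a pointwise Conv1d and K this
spectral convolution (the last block omits the activation).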
4 | """ 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.nn.parameter import Parameter 11 | import matplotlib.pyplot as plt 12 | 13 | import operator 14 | from functools import reduce 15 | from functools import partial 16 | from timeit import default_timer 17 | from utilities3 import * 18 | from jax import random, vmap, jit 19 | 20 | import jax.numpy as jnp 21 | from jax.experimental.ode import odeint 22 | from jax.config import config 23 | import argparse 24 | 25 | import os 26 | 27 | seed = np.random.randint(10000) 28 | torch.manual_seed(seed) 29 | np.random.seed(seed) 30 | 31 | ################################################################ 32 | # 1d fourier layer 33 | ################################################################ 34 | class SpectralConv1d(nn.Module): 35 | def __init__(self, in_channels, out_channels, modes1): 36 | super(SpectralConv1d, self).__init__() 37 | 38 | """ 39 | 1D Fourier layer. It does FFT, linear transform, and Inverse FFT. 40 | """ 41 | 42 | self.in_channels = in_channels 43 | self.out_channels = out_channels 44 | self.modes1 = modes1 #Number of Fourier modes to multiply, at most floor(N/2) + 1 45 | 46 | self.scale = (1 / (in_channels*out_channels)) 47 | self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, dtype=torch.cfloat)) 48 | 49 | # Complex multiplication 50 | def compl_mul1d(self, input, weights): 51 | # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x) 52 | return torch.einsum("bix,iox->box", input, weights) 53 | 54 | def forward(self, x): 55 | batchsize = x.shape[0] 56 | #Compute Fourier coeffcients up to factor of e^(- something constant) 57 | x_ft = torch.fft.rfft(x) 58 | 59 | # Multiply relevant Fourier modes 60 | out_ft = torch.zeros(batchsize, self.out_channels, x.size(-1)//2 + 1, device=x.device, dtype=torch.cfloat) 61 | out_ft[:, :, :self.modes1] = self.compl_mul1d(x_ft[:, :, :self.modes1], self.weights1) 62 | 63 | #Return to physical space 64 | x = torch.fft.irfft(out_ft, n=x.size(-1)) 65 | return x 66 | 67 | class FNO1d(nn.Module): 68 | def __init__(self, modes, width): 69 | super(FNO1d, self).__init__() 70 | 71 | """ 72 | The overall network. It contains 4 layers of the Fourier layer. 73 | 1. Lift the input to the desire channel dimension by self.fc0 . 74 | 2. 4 layers of the integral operators u' = (W + K)(u). 75 | W defined by self.w; K defined by self.conv . 76 | 3. Project from the channel space to the output space by self.fc1 and self.fc2 . 
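    The forward pass appends the grid coordinate to the input, pads the domain by
    self.padding (the input here is non-periodic), runs the four Fourier blocks in
    channel-first layout (batch, width, x), and crops the padding again before the
    final projection.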
77 | 78 | input: the solution of the initial condition and location (a(x), x) 79 | input shape: (batchsize, x=s, c=2) 80 | output: the solution of a later timestep 81 | output shape: (batchsize, x=s, c=1) 82 | """ 83 | 84 | self.modes1 = modes 85 | self.width = width 86 | self.padding = 2 # pad the domain if input is non-periodic 87 | self.fc0 = nn.Linear(2, self.width) # input channel is 2: (a(x), x) 88 | 89 | self.conv0 = SpectralConv1d(self.width, self.width, self.modes1) 90 | self.conv1 = SpectralConv1d(self.width, self.width, self.modes1) 91 | self.conv2 = SpectralConv1d(self.width, self.width, self.modes1) 92 | self.conv3 = SpectralConv1d(self.width, self.width, self.modes1) 93 | self.w0 = nn.Conv1d(self.width, self.width, 1) 94 | self.w1 = nn.Conv1d(self.width, self.width, 1) 95 | self.w2 = nn.Conv1d(self.width, self.width, 1) 96 | self.w3 = nn.Conv1d(self.width, self.width, 1) 97 | 98 | self.fc1 = nn.Linear(self.width, 128) 99 | self.fc2 = nn.Linear(128, 1) 100 | 101 | def forward(self, x): 102 | grid = self.get_grid(x.shape, x.device) 103 | x = torch.cat((x, grid), dim=-1) 104 | x = self.fc0(x) 105 | x = x.permute(0, 2, 1) 106 | x = F.pad(x, [0,self.padding]) # pad the domain if input is non-periodic 107 | 108 | x1 = self.conv0(x) 109 | x2 = self.w0(x) 110 | x = x1 + x2 111 | x = F.gelu(x) 112 | 113 | x1 = self.conv1(x) 114 | x2 = self.w1(x) 115 | x = x1 + x2 116 | x = F.gelu(x) 117 | 118 | x1 = self.conv2(x) 119 | x2 = self.w2(x) 120 | x = x1 + x2 121 | x = F.gelu(x) 122 | 123 | x1 = self.conv3(x) 124 | x2 = self.w3(x) 125 | x = x1 + x2 126 | 127 | x = x[..., :-self.padding] # pad the domain if input is non-periodic 128 | x = x.permute(0, 2, 1) 129 | x = self.fc1(x) 130 | x = F.gelu(x) 131 | x = self.fc2(x) 132 | return x 133 | 134 | def get_grid(self, shape, device): 135 | batchsize, size_x = shape[0], shape[1] 136 | gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float) 137 | gridx = gridx.reshape(1, size_x, 1).repeat([batchsize, 1, 1]) 138 | return gridx.to(device) 139 | 140 | # Define RBF kernel 141 | def RBF(x1, x2, params): 142 | output_scale, lengthscales = params 143 | diffs = jnp.expand_dims(x1 / lengthscales, 1) - \ 144 | jnp.expand_dims(x2 / lengthscales, 0) 145 | r2 = jnp.sum(diffs**2, axis=2) 146 | return output_scale * jnp.exp(-0.5 * r2) 147 | 148 | # Geneate training data corresponding to one input sample 149 | def generate_one_training_data(key, m=100, P=1, ls=1): 150 | # Sample GP prior at a fine grid 151 | N = 512 152 | # length_scale = ls 153 | # gp_params = (1.0, length_scale) 154 | key1, key2 = random.split(key,num=2) 155 | z = random.uniform(key1, minval=-2, maxval=2) 156 | output_scale = 10**z 157 | z = random.uniform(key2, minval=-2, maxval=0) 158 | length_scale = 10**z 159 | gp_params = (output_scale, length_scale) 160 | jitter = 1e-10 161 | X = jnp.linspace(0, 1, N)[:,None] 162 | K = RBF(X, X, gp_params) 163 | L = jnp.linalg.cholesky(K + jitter*jnp.eye(N)) 164 | gp_sample = jnp.dot(L, random.normal(key, (N,))) 165 | 166 | # Create a callable interpolation function 167 | u_fn = lambda x, t: jnp.interp(t, X.flatten(), gp_sample) 168 | 169 | # Ijnput sensor locations and measurements 170 | x = jnp.linspace(0, 1, m) 171 | u = vmap(u_fn, in_axes=(None,0))(0.0, x) 172 | 173 | # Output sensor locations and measurements 174 | y = jnp.linspace(0, 1, P) 175 | s = odeint(u_fn, 0.0, y) 176 | return u, y, s 177 | 178 | # Geneate test data corresponding to one input sample 179 | def generate_one_test_data(key, m=100, P=100, ls =0.1): 180 | # Sample GP 
prior at a fine grid 181 | N = 512 182 | # length_scale = ls 183 | # gp_params = (1.0, length_scale) 184 | key1, key2 = random.split(key,num=2) 185 | z = random.uniform(key1, minval=-2, maxval=2) 186 | output_scale = 10**z 187 | z = random.uniform(key2, minval=-2, maxval=0) 188 | length_scale = 10**z 189 | gp_params = (output_scale, length_scale) 190 | jitter = 1e-10 191 | X = jnp.linspace(0, 1, N)[:,None] 192 | K = RBF(X, X, gp_params) 193 | L = jnp.linalg.cholesky(K + jitter*jnp.eye(N)) 194 | gp_sample = jnp.dot(L, random.normal(key, (N,))) 195 | # Create a callable interpolation function 196 | u_fn = lambda x, t: jnp.interp(t, X.flatten(), gp_sample) 197 | # Input sensor locations and measurements 198 | x = jnp.linspace(0, 1, m) 199 | u = vmap(u_fn, in_axes=(None,0))(0.0, x) 200 | # Output sensor locations and measurements 201 | y = jnp.linspace(0, 1, P) 202 | s = odeint(u_fn, 0.0, y) 203 | return u, y, s 204 | 205 | # Geneate training data corresponding to N input sample 206 | def generate_training_data(key, N, m, P, ls): 207 | config.update("jax_enable_x64", True) 208 | keys = random.split(key, N) 209 | gen_fn = jit(lambda key: generate_one_training_data(key, m, P, ls)) 210 | u_train, y_train, s_train = vmap(gen_fn)(keys) 211 | config.update("jax_enable_x64", False) 212 | return u_train, y_train, s_train 213 | 214 | # Geneate test data corresponding to N input sample 215 | def generate_test_data(key, N, m, P, ls): 216 | config.update("jax_enable_x64", True) 217 | keys = random.split(key, N) 218 | gen_fn = jit(lambda key: generate_one_test_data(key, m, P, ls)) 219 | u, y, s = vmap(gen_fn)(keys) 220 | config.update("jax_enable_x64", False) 221 | return u, y, s 222 | 223 | 224 | ################################################################ 225 | # configurations 226 | ################################################################ 227 | def main(l,id): 228 | ntrain = 1000 229 | ntest = 1000 230 | m = 1000 231 | Nx = 1000 232 | 233 | 234 | h = 1000 235 | s = h 236 | 237 | batch_size = 100 238 | learning_rate = 0.001 239 | 240 | epochs = 500 241 | step_size = 100 242 | gamma = 0.5 243 | 244 | modes = 32 245 | width = 100 246 | length_scale = int(l) 247 | ind = id 248 | P = 100 249 | 250 | ################################################################ 251 | # read data 252 | ################################################################ 253 | 254 | # Data is of the shape (number of samples, grid size) 255 | print('The lengthscale is %.2f'%(0.1*l)) 256 | key_train = random.PRNGKey(0) 257 | U_train, y_train, s_train = generate_training_data(key_train, ntrain, m, Nx, 0.1*l) 258 | key_test = random.PRNGKey(12345) 259 | U_test, y_test, s_test = generate_test_data(key_test, ntest, m, Nx, 0.1) 260 | 261 | dtype_double = torch.FloatTensor 262 | cdtype_double = torch.cuda.DoubleTensor 263 | x_train = torch.from_numpy(np.asarray(U_train)).type(dtype_double).reshape(ntrain,s,1) 264 | y_train = torch.from_numpy(np.asarray(s_train)).type(dtype_double).reshape(ntrain,s,1) 265 | 266 | x_test = torch.from_numpy(np.asarray(U_test)).type(dtype_double).reshape(ntrain,s,1) 267 | y_test = torch.from_numpy(np.asarray(s_test)).type(dtype_double).reshape(ntrain,s,1) 268 | 269 | ind_train = torch.randint(s, (ntrain, P)) 270 | ind_test = torch.randint(s, (ntest, P)) 271 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train, ind_train), batch_size=batch_size, shuffle=True) 272 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test, 
ind_test), batch_size=batch_size, shuffle=True) 273 | 274 | ################################################################ 275 | # training and evaluation 276 | ################################################################ 277 | 278 | batch_ind = torch.arange(batch_size).reshape(-1, 1).repeat(1, P) 279 | 280 | # model 281 | model = FNO1d(modes, width).cuda() 282 | print(count_params(model)) 283 | 284 | ################################################################ 285 | # training and evaluation 286 | ################################################################ 287 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4) 288 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma) 289 | 290 | myloss = LpLoss(size_average=False) 291 | for ep in range(epochs): 292 | model.train() 293 | t1 = default_timer() 294 | train_mse = 0 295 | train_l2 = 0 296 | for x, y, idx in train_loader: 297 | x, y = x.cuda(), y.cuda() 298 | 299 | optimizer.zero_grad() 300 | out = model(x) 301 | 302 | y = y[batch_ind, idx] 303 | out = out[batch_ind, idx] 304 | l2 = myloss(out.view(batch_size,-1), y.view(batch_size,-1)) 305 | # l2.backward() 306 | 307 | mse = F.mse_loss(out.view(batch_size, -1), y.view(batch_size, -1), reduction='mean') 308 | mse.backward() 309 | 310 | optimizer.step() 311 | train_mse += mse.item() 312 | train_l2 += l2.item() 313 | 314 | scheduler.step() 315 | model.eval() 316 | test_l2 = 0.0 317 | with torch.no_grad(): 318 | for x, y, idx in test_loader: 319 | x, y = x.cuda(), y.cuda() 320 | 321 | out = model(x) 322 | y = y[batch_ind, idx] 323 | out = out[batch_ind, idx] 324 | 325 | test_l2 += myloss(out.view(batch_size, -1), y.view(batch_size, -1)).item() 326 | 327 | train_mse /= len(train_loader) 328 | train_l2 /= ntrain 329 | test_l2 /= ntest 330 | 331 | t2 = default_timer() 332 | print(ep, t2-t1, train_mse, train_l2, test_l2) 333 | 334 | x_test = torch.from_numpy(np.asarray(U_test)).type(dtype_double).reshape(ntrain,s,1) 335 | y_test = torch.from_numpy(np.asarray(s_test)).type(dtype_double).reshape(ntrain,s,1) 336 | 337 | pred_torch = torch.zeros(y_test.shape) 338 | baseline_torch = torch.zeros(y_test.shape) 339 | index = 0 340 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=1, shuffle=False) 341 | test_error_u = [] 342 | test_error_u_np = [] 343 | with torch.no_grad(): 344 | for x, y in test_loader: 345 | test_l2 = 0 346 | x, y = x.cuda(), y.cuda() 347 | 348 | out = model(x) 349 | pred_torch[index] = out 350 | baseline_torch[index,:,:] = y[:,:,:] 351 | 352 | test_l2 += myloss(out.view(1, -1), y.view(1, -1)).item() 353 | test_error_u.append(test_l2) 354 | test_error_u_np.append(np.linalg.norm(y.view(-1).cpu().numpy()- out.view(-1).cpu().numpy(),2)/np.linalg.norm(y.view(-1).cpu().numpy(),2)) 355 | # print(index, test_l2) 356 | index = index + 1 357 | print("The average test u error (no noise) is %e the standard deviation is %e the min error is %e and the max error is %e"%(np.mean(test_error_u),np.std(test_error_u),np.min(test_error_u),np.max(test_error_u))) 358 | print("The average test u error (no noise) is %e the standard deviation is %e the min error is %e and the max error is %e"%(np.mean(test_error_u_np),np.std(test_error_u_np),np.min(test_error_u_np),np.max(test_error_u_np))) 359 | 360 | # in_noise_test = 0.05*np.random.normal(loc=0.0, scale=1.0, size=(U_test.shape)) 361 | # U_test = U_test + in_noise_test 362 | 363 | x_test = 
torch.from_numpy(np.asarray(U_test)).type(dtype_double).reshape(ntrain,s,1) 364 | y_test = torch.from_numpy(np.asarray(s_test)).type(dtype_double).reshape(ntrain,s,1) 365 | 366 | pred_torch = torch.zeros(y_test.shape) 367 | baseline_torch = torch.zeros(y_test.shape) 368 | index = 0 369 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=1, shuffle=False) 370 | test_error_u = [] 371 | test_error_u_np = [] 372 | with torch.no_grad(): 373 | for x, y in test_loader: 374 | test_l2 = 0 375 | x, y = x.cuda(), y.cuda() 376 | 377 | out = model(x) 378 | pred_torch[index] = out 379 | baseline_torch[index,:,:] = y[:,:,:] 380 | 381 | test_l2 += myloss(out.view(1, -1), y.view(1, -1)).item() 382 | test_error_u.append(test_l2) 383 | test_error_u_np.append(np.linalg.norm(y.view(-1).cpu().numpy()- out.view(-1).cpu().numpy(),2)/np.linalg.norm(y.view(-1).cpu().numpy(),2)) 384 | # print(index, test_l2) 385 | index = index + 1 386 | print("The average test u error (noise) is %e the standard deviation is %e the min error is %e and the max error is %e"%(np.mean(test_error_u),np.std(test_error_u),np.min(test_error_u),np.max(test_error_u))) 387 | print("The average test u error (noise) is %e the standard deviation is %e the min error is %e and the max error is %e"%(np.mean(test_error_u_np),np.std(test_error_u_np),np.min(test_error_u_np),np.max(test_error_u_np))) 388 | 389 | if __name__ == "__main__": 390 | 391 | parser = argparse.ArgumentParser(description='Process model parameters.') 392 | parser.add_argument('l', metavar='l', type=int, nargs='+', help='Lenghtscale of test dataset') 393 | parser.add_argument('id', metavar='id', type=int, nargs='+', help='Index of the run') 394 | 395 | args = parser.parse_args() 396 | l = args.l[0] 397 | id = args.id[0] 398 | 399 | main(l,id) -------------------------------------------------------------------------------- /Antiderivative/FNO/utilities3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import scipy.io 4 | import h5py 5 | import torch.nn as nn 6 | 7 | import operator 8 | from functools import reduce 9 | from functools import partial 10 | import os 11 | 12 | ################################################# 13 | # 14 | # Utilities 15 | # 16 | ################################################# 17 | def get_freer_gpu(): 18 | os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp') 19 | memory_available = [int(x.split()[2]) for x in open('tmp', 'r').readlines()] 20 | return str(np.argmax(memory_available)) 21 | 22 | os.environ['CUDA_VISIBLE_DEVICES']= get_freer_gpu() 23 | 24 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 25 | 26 | # reading data 27 | class MatReader(object): 28 | def __init__(self, file_path, to_torch=True, to_cuda=False, to_float=True): 29 | super(MatReader, self).__init__() 30 | 31 | self.to_torch = to_torch 32 | self.to_cuda = to_cuda 33 | self.to_float = to_float 34 | 35 | self.file_path = file_path 36 | 37 | self.data = None 38 | self.old_mat = None 39 | self._load_file() 40 | 41 | def _load_file(self): 42 | try: 43 | self.data = scipy.io.loadmat(self.file_path) 44 | self.old_mat = True 45 | except: 46 | self.data = h5py.File(self.file_path) 47 | self.old_mat = False 48 | 49 | def load_file(self, file_path): 50 | self.file_path = file_path 51 | self._load_file() 52 | 53 | def read_field(self, field): 54 | x = self.data[field] 55 | 56 | if not self.old_mat: 57 | x = x[()] 58 | x = 
np.transpose(x, axes=range(len(x.shape) - 1, -1, -1)) 59 | 60 | if self.to_float: 61 | x = x.astype(np.float32) 62 | 63 | if self.to_torch: 64 | x = torch.from_numpy(x) 65 | 66 | if self.to_cuda: 67 | x = x.cuda() 68 | 69 | return x 70 | 71 | def set_cuda(self, to_cuda): 72 | self.to_cuda = to_cuda 73 | 74 | def set_torch(self, to_torch): 75 | self.to_torch = to_torch 76 | 77 | def set_float(self, to_float): 78 | self.to_float = to_float 79 | 80 | # normalization, pointwise gaussian 81 | class UnitGaussianNormalizer(object): 82 | def __init__(self, x, eps=0.00001): 83 | super(UnitGaussianNormalizer, self).__init__() 84 | 85 | # x could be in shape of ntrain*n or ntrain*T*n or ntrain*n*T 86 | self.mean = torch.mean(x, 0) 87 | self.std = torch.std(x, 0) 88 | self.eps = eps 89 | 90 | def encode(self, x): 91 | x = (x - self.mean) / (self.std + self.eps) 92 | return x 93 | 94 | def decode(self, x, sample_idx=None): 95 | if sample_idx is None: 96 | std = self.std + self.eps # n 97 | mean = self.mean 98 | else: 99 | if len(self.mean.shape) == len(sample_idx[0].shape): 100 | std = self.std[sample_idx] + self.eps # batch*n 101 | mean = self.mean[sample_idx] 102 | if len(self.mean.shape) > len(sample_idx[0].shape): 103 | std = self.std[:,sample_idx]+ self.eps # T*batch*n 104 | mean = self.mean[:,sample_idx] 105 | 106 | # x is in shape of batch*n or T*batch*n 107 | x = (x * std) + mean 108 | return x 109 | 110 | def cuda(self): 111 | self.mean = self.mean.cuda() 112 | self.std = self.std.cuda() 113 | 114 | def cpu(self): 115 | self.mean = self.mean.cpu() 116 | self.std = self.std.cpu() 117 | 118 | # normalization, Gaussian 119 | class GaussianNormalizer(object): 120 | def __init__(self, x, eps=0.00001): 121 | super(GaussianNormalizer, self).__init__() 122 | 123 | self.mean = torch.mean(x) 124 | self.std = torch.std(x) 125 | self.eps = eps 126 | 127 | def encode(self, x): 128 | x = (x - self.mean) / (self.std + self.eps) 129 | return x 130 | 131 | def decode(self, x, sample_idx=None): 132 | x = (x * (self.std + self.eps)) + self.mean 133 | return x 134 | 135 | def cuda(self): 136 | self.mean = self.mean.cuda() 137 | self.std = self.std.cuda() 138 | 139 | def cpu(self): 140 | self.mean = self.mean.cpu() 141 | self.std = self.std.cpu() 142 | 143 | 144 | # normalization, scaling by range 145 | class RangeNormalizer(object): 146 | def __init__(self, x, low=0.0, high=1.0): 147 | super(RangeNormalizer, self).__init__() 148 | mymin = torch.min(x, 0)[0].view(-1) 149 | mymax = torch.max(x, 0)[0].view(-1) 150 | 151 | self.a = (high - low)/(mymax - mymin) 152 | self.b = -self.a*mymax + high 153 | 154 | def encode(self, x): 155 | s = x.size() 156 | x = x.view(s[0], -1) 157 | x = self.a*x + self.b 158 | x = x.view(s) 159 | return x 160 | 161 | def decode(self, x): 162 | s = x.size() 163 | x = x.view(s[0], -1) 164 | x = (x - self.b)/self.a 165 | x = x.view(s) 166 | return x 167 | 168 | #loss function with rel/abs Lp loss 169 | class LpLoss(object): 170 | def __init__(self, d=2, p=2, size_average=True, reduction=True): 171 | super(LpLoss, self).__init__() 172 | 173 | #Dimension and Lp-norm type are postive 174 | assert d > 0 and p > 0 175 | 176 | self.d = d 177 | self.p = p 178 | self.reduction = reduction 179 | self.size_average = size_average 180 | 181 | def abs(self, x, y): 182 | num_examples = x.size()[0] 183 | 184 | #Assume uniform mesh 185 | h = 1.0 / (x.size()[1] - 1.0) 186 | 187 | all_norms = (h**(self.d/self.p))*torch.norm(x.view(num_examples,-1) - y.view(num_examples,-1), self.p, 1) 188 | 189 | if 
self.reduction: 190 | if self.size_average: 191 | return torch.mean(all_norms) 192 | else: 193 | return torch.sum(all_norms) 194 | 195 | return all_norms 196 | 197 | def rel(self, x, y): 198 | num_examples = x.size()[0] 199 | 200 | diff_norms = torch.norm(x.reshape(num_examples,-1) - y.reshape(num_examples,-1), self.p, 1) 201 | y_norms = torch.norm(y.reshape(num_examples,-1), self.p, 1) 202 | 203 | if self.reduction: 204 | if self.size_average: 205 | return torch.mean(diff_norms/y_norms) 206 | else: 207 | return torch.sum(diff_norms/y_norms) 208 | 209 | return diff_norms/y_norms 210 | 211 | def __call__(self, x, y): 212 | return self.rel(x, y) 213 | 214 | # Sobolev norm (HS norm) 215 | # where we also compare the numerical derivatives between the output and target 216 | class HsLoss(object): 217 | def __init__(self, d=2, p=2, k=1, a=None, group=False, size_average=True, reduction=True): 218 | super(HsLoss, self).__init__() 219 | 220 | #Dimension and Lp-norm type are postive 221 | assert d > 0 and p > 0 222 | 223 | self.d = d 224 | self.p = p 225 | self.k = k 226 | self.balanced = group 227 | self.reduction = reduction 228 | self.size_average = size_average 229 | 230 | if a == None: 231 | a = [1,] * k 232 | self.a = a 233 | 234 | def rel(self, x, y): 235 | num_examples = x.size()[0] 236 | diff_norms = torch.norm(x.reshape(num_examples,-1) - y.reshape(num_examples,-1), self.p, 1) 237 | y_norms = torch.norm(y.reshape(num_examples,-1), self.p, 1) 238 | if self.reduction: 239 | if self.size_average: 240 | return torch.mean(diff_norms/y_norms) 241 | else: 242 | return torch.sum(diff_norms/y_norms) 243 | return diff_norms/y_norms 244 | 245 | def __call__(self, x, y, a=None): 246 | nx = x.size()[1] 247 | ny = x.size()[2] 248 | k = self.k 249 | balanced = self.balanced 250 | a = self.a 251 | x = x.view(x.shape[0], nx, ny, -1) 252 | y = y.view(y.shape[0], nx, ny, -1) 253 | 254 | k_x = torch.cat((torch.arange(start=0, end=nx//2, step=1),torch.arange(start=-nx//2, end=0, step=1)), 0).reshape(nx,1).repeat(1,ny) 255 | k_y = torch.cat((torch.arange(start=0, end=ny//2, step=1),torch.arange(start=-ny//2, end=0, step=1)), 0).reshape(1,ny).repeat(nx,1) 256 | k_x = torch.abs(k_x).reshape(1,nx,ny,1).to(x.device) 257 | k_y = torch.abs(k_y).reshape(1,nx,ny,1).to(x.device) 258 | 259 | x = torch.fft.fftn(x, dim=[1, 2]) 260 | y = torch.fft.fftn(y, dim=[1, 2]) 261 | 262 | if balanced==False: 263 | weight = 1 264 | if k >= 1: 265 | weight += a[0]**2 * (k_x**2 + k_y**2) 266 | if k >= 2: 267 | weight += a[1]**2 * (k_x**4 + 2*k_x**2*k_y**2 + k_y**4) 268 | weight = torch.sqrt(weight) 269 | loss = self.rel(x*weight, y*weight) 270 | else: 271 | loss = self.rel(x, y) 272 | if k >= 1: 273 | weight = a[0] * torch.sqrt(k_x**2 + k_y**2) 274 | loss += self.rel(x*weight, y*weight) 275 | if k >= 2: 276 | weight = a[1] * torch.sqrt(k_x**4 + 2*k_x**2*k_y**2 + k_y**4) 277 | loss += self.rel(x*weight, y*weight) 278 | loss = loss / (k+1) 279 | 280 | return loss 281 | 282 | # A simple feedforward neural network 283 | class DenseNet(torch.nn.Module): 284 | def __init__(self, layers, nonlinearity, out_nonlinearity=None, normalize=False): 285 | super(DenseNet, self).__init__() 286 | 287 | self.n_layers = len(layers) - 1 288 | 289 | assert self.n_layers >= 1 290 | 291 | self.layers = nn.ModuleList() 292 | 293 | for j in range(self.n_layers): 294 | self.layers.append(nn.Linear(layers[j], layers[j+1])) 295 | 296 | if j != self.n_layers - 1: 297 | if normalize: 298 | self.layers.append(nn.BatchNorm1d(layers[j+1])) 299 | 300 | 
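                # every hidden Linear (plus the optional BatchNorm) is followed by the
                # activation; the output layer only gets out_nonlinearity when one is
                # passed. For example, a hypothetical DenseNet([2, 64, 1], nn.ReLU)
                # builds Linear(2, 64) -> ReLU() -> Linear(64, 1).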
self.layers.append(nonlinearity()) 301 | 302 | if out_nonlinearity is not None: 303 | self.layers.append(out_nonlinearity()) 304 | 305 | def forward(self, x): 306 | for _, l in enumerate(self.layers): 307 | x = l(x) 308 | 309 | return x 310 | 311 | 312 | # print the number of parameters 313 | def count_params(model): 314 | c = 0 315 | for p in list(model.parameters()): 316 | c += reduce(operator.mul, list(p.size())) 317 | return c 318 | -------------------------------------------------------------------------------- /Climate_Modeling/DeepONet/DeepONet_Weatherg.py: -------------------------------------------------------------------------------- 1 | from jax.core import as_named_shape 2 | from scipy import linalg, interpolate 3 | from sklearn import gaussian_process as gp 4 | from jax.example_libraries.stax import Dense, Gelu, Relu 5 | from jax.example_libraries import stax 6 | import os 7 | 8 | import timeit 9 | 10 | from jax.example_libraries import optimizers 11 | 12 | from absl import app 13 | from jax import vjp 14 | import jax 15 | import jax.numpy as jnp 16 | import numpy as np 17 | from jax.numpy.linalg import norm 18 | 19 | from jax import random, grad, jit 20 | from functools import partial 21 | 22 | from torch.utils import data 23 | 24 | from scipy import interpolate 25 | 26 | from tqdm import trange 27 | from math import sqrt 28 | 29 | import itertools 30 | 31 | 32 | def get_freer_gpu(): 33 | os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp') 34 | memory_available = [int(x.split()[2]) for x in open('tmp', 'r').readlines()] 35 | return str(np.argmax(memory_available)) 36 | 37 | os.environ['CUDA_VISIBLE_DEVICES']= get_freer_gpu() 38 | 39 | def output_construction(s,Y,P=100,ds=1, dy=2, N=1000,Nx=100,Ny=100): 40 | s = s.reshape(Nx,Ny) 41 | x = np.random.randint(Nx, size=P) 42 | y = np.random.randint(Ny, size=P) 43 | Y_all = np.hstack([x[:, None], y[:,None]]) * [1./(Nx - 1), 1./(Ny - 1)] 44 | s_all = s[x][range(P), y][:, None] 45 | return s_all, Y_all 46 | 47 | class DataGenerator(data.Dataset): 48 | def __init__(self, u, y, s, 49 | batch_size=100, rng_key=random.PRNGKey(1234)): 50 | 'Initialization' 51 | self.u = u 52 | self.y = y 53 | self.s = s 54 | 55 | self.N = u.shape[0] 56 | self.batch_size = batch_size 57 | self.key = rng_key 58 | 59 | def __getitem__(self, index): 60 | 'Generate one batch of data' 61 | self.key, subkey = random.split(self.key) 62 | inputs,outputs = self.__data_generation(subkey) 63 | return inputs, outputs 64 | 65 | @partial(jit, static_argnums=(0,)) 66 | def __data_generation(self, key): 67 | 'Generates data containing batch_size samples' 68 | idx = random.choice(key, self.N, (self.batch_size,), replace=False) 69 | s = self.s[idx,:,:] 70 | u = self.u[idx,:,:] 71 | y = self.y[idx,:,:] 72 | inputs = (u, y) 73 | return inputs, s 74 | 75 | class PositionalEncodingY: 76 | def __init__(self, Y, d_model, max_len = 100, H=20): 77 | self.d_model = int(np.ceil(d_model/4)*2) 78 | self.Y = Y 79 | self.max_len = max_len 80 | self.H = H 81 | 82 | @partial(jit, static_argnums=(0,)) 83 | def forward(self, x): 84 | pex = np.zeros((x.shape[0], self.max_len, self.H)) 85 | pey = np.zeros((x.shape[0], self.max_len, self.H)) 86 | X1 = jnp.take(self.Y, 0, axis=2)[:,:,None] 87 | X2 = jnp.take(self.Y, 1, axis=2)[:,:,None] 88 | positionX1 = jnp.tile(X1,(1,1,self.H)) 89 | positionX2 = jnp.tile(X2,(1,1,self.H)) 90 | div_term = 2**jnp.arange(0,int(self.H/2),1)*jnp.pi 91 | pex = jax.ops.index_update(pex, jax.ops.index[:,:,0::2], jnp.cos(positionX1[:,:,0::2] * div_term)) 92 
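        # Interleaved cos/sin features with frequencies 2**k * pi are built for each
        # of the two query coordinates (pex and pey, H channels each) and concatenated,
        # so the trunk input carries dy + 2*H channels (22 for dy = 2, H = 10), matching
        # trunk_layers[0] = H_y*dy + dy in main() below.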
| pex = jax.ops.index_update(pex, jax.ops.index[:,:,1::2], jnp.sin(positionX1[:,:,1::2] * div_term)) 93 | pey = jax.ops.index_update(pey, jax.ops.index[:,:,0::2], jnp.cos(positionX2[:,:,0::2] * div_term)) 94 | pey = jax.ops.index_update(pey, jax.ops.index[:,:,1::2], jnp.sin(positionX2[:,:,1::2] * div_term)) 95 | pos_embedding = jnp.concatenate((pex,pey),axis=-1) 96 | x = jnp.concatenate([x, pos_embedding], -1) 97 | return x 98 | 99 | class PositionalEncodingU: 100 | def __init__(self, U, d_model, max_len = 100, H=20): 101 | self.d_model = int(np.ceil(d_model/2)*2) 102 | self.U = U 103 | self.max_len = max_len 104 | self.H = H 105 | 106 | @partial(jit, static_argnums=(0,)) 107 | def forward(self, x): 108 | peu = np.zeros((x.shape[0], self.max_len, self.H)) 109 | U = jnp.take(self.U, 0, axis=2)[:,:,None] 110 | positionU = jnp.tile(U,(1,1,self.H)) 111 | div_term = 2**jnp.arange(0,int(self.H/2),1)*jnp.pi 112 | peu = jax.ops.index_update(peu, jax.ops.index[:,:,0::2], jnp.cos(positionU[:,:,0::2] * div_term)) 113 | peu = jax.ops.index_update(peu, jax.ops.index[:,:,1::2], jnp.sin(positionU[:,:,1::2] * div_term)) 114 | x = jnp.concatenate([x, peu], -1) 115 | return x 116 | 117 | 118 | class DON: 119 | def __init__(self,branch_layers, trunk_layers , m=100, P=100, mn=None, std=None): 120 | # Network initialization and evaluation functions 121 | 122 | self.branch_init, self.branch_apply = self.init_NN(branch_layers, activation=Gelu) 123 | self.in_shape = (-1, branch_layers[0]) 124 | self.out_shape, branch_params = self.branch_init(random.PRNGKey(10000), self.in_shape) 125 | 126 | self.trunk_init, self.trunk_apply = self.init_NN(trunk_layers, activation=Gelu) 127 | self.in_shape = (-1, trunk_layers[0]) 128 | self.out_shape, trunk_params = self.trunk_init(random.PRNGKey(10000), self.in_shape) 129 | 130 | params = (trunk_params, branch_params) 131 | # Use optimizers to set optimizer initialization and update functions 132 | self.opt_init,self.opt_update,self.get_params = optimizers.adam(optimizers.exponential_decay(1e-3, 133 | decay_steps=100, 134 | decay_rate=0.95)) 135 | self.opt_state = self.opt_init(params) 136 | # Logger 137 | self.itercount = itertools.count() 138 | self.loss_log = [] 139 | self.mean = mn 140 | self.std = std 141 | 142 | 143 | def init_NN(self, Q, activation=Gelu): 144 | layers = [] 145 | num_layers = len(Q) 146 | if num_layers < 2: 147 | net_init, net_apply = stax.serial() 148 | else: 149 | for i in range(0, num_layers-2): 150 | layers.append(Dense(Q[i+1])) 151 | layers.append(activation) 152 | layers.append(Dense(Q[-1])) 153 | net_init, net_apply = stax.serial(*layers) 154 | return net_init, net_apply 155 | 156 | @partial(jax.jit, static_argnums=0) 157 | def DON(self, params, inputs, ds=1): 158 | trunk_params, branch_params = params 159 | inputsxu, inputsy = inputs 160 | t = self.trunk_apply(trunk_params, inputsy).reshape(inputsy.shape[0], inputsy.shape[1], ds, int(100/ds)) 161 | b = self.branch_apply(branch_params, inputsxu.reshape(inputsxu.shape[0],1,inputsxu.shape[1]*inputsxu.shape[2])) 162 | b = b.reshape(b.shape[0],int(b.shape[2]/ds),ds) 163 | Guy = jnp.einsum("ijkl,ilk->ijk", t,b) 164 | return Guy 165 | 166 | @partial(jax.jit, static_argnums=0) 167 | def loss(self, params, batch): 168 | inputs, y = batch 169 | y_pred = self.DON(params,inputs) 170 | y = y*self.std + self.mean 171 | y_pred = y_pred*self.std + self.mean 172 | loss = np.mean((y.flatten() - y_pred.flatten())**2) 173 | return loss 174 | 175 | @partial(jax.jit, static_argnums=0) 176 | def lossT(self, params, 
batch): 177 | inputs, outputs = batch 178 | y_pred = self.DON(params,inputs) 179 | y_pred = y_pred*self.std + self.mean 180 | loss = np.mean((outputs.flatten() - y_pred.flatten())**2) 181 | return loss 182 | 183 | @partial(jax.jit, static_argnums=0) 184 | def L2errorT(self, params, batch): 185 | inputs, y = batch 186 | y_pred = self.DON(params,inputs) 187 | y_pred = y_pred*self.std + self.mean 188 | return norm(y.flatten() - y_pred.flatten(), 2)/norm(y.flatten(),2) 189 | 190 | @partial(jax.jit, static_argnums=0) 191 | def L2error(self, params, batch): 192 | inputs, y = batch 193 | y_pred = self.DON(params,inputs) 194 | y = y*self.std + self.mean 195 | y_pred = y_pred*self.std + self.mean 196 | return norm(y.flatten() - y_pred.flatten(), 2)/norm(y.flatten(),2) 197 | 198 | @partial(jit, static_argnums=(0,)) 199 | def step(self, i, opt_state, batch): 200 | params = self.get_params(opt_state) 201 | g = grad(self.loss)(params, batch) 202 | return self.opt_update(i, g, opt_state) 203 | 204 | def train(self, train_dataset, test_dataset, nIter = 10000): 205 | train_data = iter(train_dataset) 206 | test_data = iter(test_dataset) 207 | 208 | pbar = trange(nIter) 209 | for it in pbar: 210 | train_batch = next(train_data) 211 | test_batch = next(test_data) 212 | 213 | self.opt_state = self.step(next(self.itercount), self.opt_state, train_batch) 214 | 215 | if it % 100 == 0: 216 | params = self.get_params(self.opt_state) 217 | 218 | loss_train = self.loss(params, train_batch) 219 | loss_test = self.lossT(params, test_batch) 220 | 221 | errorTrain = self.L2error(params, train_batch) 222 | errorTest = self.L2errorT(params, test_batch) 223 | 224 | self.loss_log.append(loss_train) 225 | 226 | pbar.set_postfix({'Training loss': loss_train, 227 | 'Testing loss' : loss_test, 228 | 'Test error': errorTest, 229 | 'Train error': errorTrain}) 230 | 231 | 232 | @partial(jit, static_argnums=(0,)) 233 | def predict(self, params, inputs): 234 | s_pred = self.DON(params,inputs) 235 | return s_pred*self.std + self.mean 236 | 237 | def ravel_list(self, *lst): 238 | return jnp.concatenate([jnp.ravel(elt) for elt in lst]) if lst else jnp.array([]) 239 | 240 | def ravel_pytree(self, pytree): 241 | leaves, treedef = jax.tree_util.tree_flatten(pytree) 242 | flat, unravel_list = vjp(self.ravel_list, *leaves) 243 | unravel_pytree = lambda flat: jax.tree_util.tree_unflatten(treedef, unravel_list(flat)) 244 | return flat, unravel_pytree 245 | 246 | def count_params(self, params): 247 | trunk_params, branch_params = params 248 | blv, _ = self.ravel_pytree(branch_params) 249 | tlv, _ = self.ravel_pytree(trunk_params) 250 | print("The number of model parameters is:",blv.shape[0]+tlv.shape[0]) 251 | 252 | 253 | def predict_function(U_in,Y_in, model=None, params= None, H=10): 254 | y = np.expand_dims(Y_in,axis=0) 255 | y = np.tile(y,(U_in.shape[0],1,1)) 256 | inputs_trainxu = jnp.asarray(U_in) 257 | pos_encodingy = PositionalEncodingY(y,int(y.shape[1]*y.shape[2]), max_len = Y_in.shape[0], H=H) 258 | y = pos_encodingy.forward(y) 259 | del pos_encodingy 260 | uCNN_super_all = model.predict(params, (inputs_trainxu, y)) 261 | return uCNN_super_all, y[:,:,1:2], y[:,:,0:1] 262 | 263 | def error_full_resolution(uCNN_super_all, s_all,tag='train', num_train=1000, Nx=32, Ny=32): 264 | test_error_u = [] 265 | z = uCNN_super_all.reshape(num_train,Nx,Ny) 266 | s = s_all.reshape(num_train,Nx,Ny) 267 | s = np.swapaxes(s,1,2) 268 | for i in range(0,num_train): 269 | test_error_u.append(norm(s[i,:,0]- z[i,:,0], 2)/norm(s[i,:,0], 2)) 270 | print("The 
average "+tag+" u error for the super resolution is %e, the standard deviation %e, the minimum error is %e and the maximum error is %e"%(np.mean(test_error_u),np.std(test_error_u),np.min(test_error_u),np.max(test_error_u))) 271 | absolute_error = np.abs(z-s) 272 | return absolute_error, np.mean(test_error_u), test_error_u 273 | 274 | def minmax(a, mean): 275 | minpos = a.index(min(a)) 276 | maxpos = a.index(max(a)) 277 | meanpos = min(range(len(a)), key=lambda i: abs(a[i]-mean)) 278 | 279 | print("The maximum is at position", maxpos) 280 | print("The minimum is at position", minpos) 281 | print("The mean is at position", meanpos) 282 | return minpos,maxpos,meanpos 283 | 284 | 285 | def main(_): 286 | TRAINING_ITERATIONS = 100000 287 | P = 144 288 | m = int(72*72) 289 | num_train = 1825 290 | num_test = 1825 291 | training_batch_size = 100 292 | du = 1 293 | dy = 2 294 | ds = 1 295 | n_hat = 100 296 | Nx = 72 297 | Ny = 72 298 | H_y = 10 299 | H_u = 10 300 | 301 | d = np.load("../Data/weather_dataset.npz") 302 | u_train = d["U_train"][:num_train,:] 303 | S_train = d["S_train"][:num_train,:]/1000. 304 | Y_train = d["Y_train"] 305 | 306 | d = np.load("../Data/weather_dataset.npz") 307 | u_test = d["U_train"][-num_test:,:] 308 | S_test = d["S_train"][-num_test:,:]/1000. 309 | Y_test = d["Y_train"] 310 | 311 | Y_train_in = Y_train 312 | Y_test_in = Y_test 313 | 314 | s_all_test = S_test 315 | s_all_train = S_train 316 | 317 | s_train = np.zeros((num_train,P,ds)) 318 | y_train = np.zeros((num_train,P,dy)) 319 | U_train = np.zeros((num_train,m,du)) 320 | 321 | s_test = np.zeros((num_test,P,ds)) 322 | y_test = np.zeros((num_test,P,dy)) 323 | U_test = np.zeros((num_test,m,du)) 324 | 325 | for i in range(0,num_train): 326 | s_train[i,:,:], y_train[i,:,:] = output_construction(S_train[i,:], Y_train, Nx=Nx, Ny=Ny, P=P, ds=ds) 327 | U_train[i,:,:] = u_train[i,:][:,None] 328 | 329 | for i in range(num_test): 330 | s_test[i,:,:], y_test[i,:,:] = output_construction(S_test[i,:], Y_test, Nx=Nx, Ny=Ny, P=P, ds=ds) 331 | U_test[i,:,:] = u_test[i,:][:,None] 332 | 333 | U_train = jnp.asarray(U_train) 334 | y_train = jnp.asarray(y_train) 335 | s_train = jnp.asarray(s_train) 336 | 337 | U_test = jnp.asarray(U_test) 338 | y_test = jnp.asarray(y_test) 339 | s_test = jnp.asarray(s_test) 340 | 341 | U_train = jnp.reshape(U_train,(num_train,m,du)) 342 | y_train = jnp.reshape(y_train,(num_train,P,dy)) 343 | s_train = jnp.reshape(s_train,(num_train,P,ds)) 344 | 345 | U_test = jnp.reshape(U_test,(num_test,m,du)) 346 | y_test = jnp.reshape(y_test,(num_test,P,dy)) 347 | s_test = jnp.reshape(s_test,(num_test,P,ds)) 348 | 349 | pos_encodingy = PositionalEncodingY(y_train,int(y_train.shape[1]*y_train.shape[2]), max_len = P, H=H_y) 350 | y_train = pos_encodingy.forward(y_train) 351 | del pos_encodingy 352 | 353 | pos_encodingyt = PositionalEncodingY(y_test,int(y_test.shape[1]*y_test.shape[2]), max_len = P, H=H_y) 354 | y_test = pos_encodingyt.forward(y_test) 355 | del pos_encodingyt 356 | 357 | pos_encodingy = PositionalEncodingU(U_train,int(U_train.shape[1]*U_train.shape[2]), max_len = m, H=H_u) 358 | U_train = pos_encodingy.forward(U_train) 359 | del pos_encodingy 360 | 361 | print(U_test[0,0:20,:]) 362 | 363 | pos_encodingyt = PositionalEncodingU(U_test,int(U_test.shape[1]*U_test.shape[2]), max_len = m, H=H_u) 364 | U_test = pos_encodingyt.forward(U_test) 365 | del pos_encodingyt 366 | 367 | s_train_mean = jnp.mean(s_train,axis=0) 368 | s_train_std = jnp.std(s_train,axis=0) 369 | 370 | s_train = (s_train - 
s_train_mean)/s_train_std 371 | 372 | train_dataset = DataGenerator(U_train, y_train, s_train, training_batch_size) 373 | train_dataset = iter(train_dataset) 374 | 375 | test_dataset = DataGenerator(U_test, y_test, s_test, training_batch_size) 376 | test_dataset = iter(test_dataset) 377 | 378 | branch_layers = [m*(du*H_u+du), 100, 100, 100, 100, ds*n_hat] 379 | trunk_layers = [H_y*dy + dy, 100, 100, 100, 100, ds*n_hat] 380 | 381 | model = DON(branch_layers, trunk_layers, m=m, P=P, mn=s_train_mean, std=s_train_std) 382 | model.count_params(model.get_params(model.opt_state)) 383 | 384 | start_time = timeit.default_timer() 385 | model.train(train_dataset, test_dataset, nIter=TRAINING_ITERATIONS) 386 | elapsed = timeit.default_timer() - start_time 387 | print("The training wall-clock time is seconds is equal to %f seconds"%elapsed) 388 | 389 | params = model.get_params(model.opt_state) 390 | 391 | print("Predicting the solution for the full resolution") 392 | uCNN_super_all_test = np.zeros_like(s_all_test).reshape(num_test, Nx*Ny, ds) 393 | for i in range(0, Nx*Ny, P): 394 | idx = i + np.arange(0,P) 395 | uCNN_super_all_test[:,idx,:], _, _ = predict_function(U_test , Y_test_in[idx,:], model=model, params=params, H=H_y) 396 | 397 | uCNN_super_all_train = np.zeros_like(s_all_train).reshape(num_train, Nx*Ny, ds) 398 | for i in range(0, Nx*Ny, P): 399 | idx = i + np.arange(0,P) 400 | uCNN_super_all_train[:,idx,:], _, _ = predict_function(U_train , Y_train_in[idx,:], model=model, params=params, H=H_y) 401 | 402 | absolute_error_test, mean_test_error, test_error = error_full_resolution(uCNN_super_all_test, s_all_test, tag='test', num_train=num_train,Nx=Nx, Ny=Ny) 403 | absolute_error_train, mean_train_error, train_error = error_full_resolution(uCNN_super_all_train, s_all_train, tag='train',num_train=num_test ,Nx=Nx, Ny=Ny) 404 | 405 | print(np.max(absolute_error_test), np.max(absolute_error_train)) 406 | np.savez_compressed("Error_Weather_DeepONet_P%d"%(P), test_error = test_error) 407 | 408 | 409 | if __name__ == '__main__': 410 | app.run(main) -------------------------------------------------------------------------------- /Climate_Modeling/FNO/Adam.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import Tensor 4 | from typing import List, Optional 5 | from torch.optim.optimizer import Optimizer 6 | 7 | import os 8 | os.environ['CUDA_VISIBLE_DEVICES']="3" 9 | 10 | def adam(params: List[Tensor], 11 | grads: List[Tensor], 12 | exp_avgs: List[Tensor], 13 | exp_avg_sqs: List[Tensor], 14 | max_exp_avg_sqs: List[Tensor], 15 | state_steps: List[int], 16 | *, 17 | amsgrad: bool, 18 | beta1: float, 19 | beta2: float, 20 | lr: float, 21 | weight_decay: float, 22 | eps: float): 23 | r"""Functional API that performs Adam algorithm computation. 24 | See :class:`~torch.optim.Adam` for details. 25 | """ 26 | 27 | for i, param in enumerate(params): 28 | 29 | grad = grads[i] 30 | exp_avg = exp_avgs[i] 31 | exp_avg_sq = exp_avg_sqs[i] 32 | step = state_steps[i] 33 | 34 | bias_correction1 = 1 - beta1 ** step 35 | bias_correction2 = 1 - beta2 ** step 36 | 37 | if weight_decay != 0: 38 | grad = grad.add(param, alpha=weight_decay) 39 | 40 | # Decay the first and second moment running average coefficient 41 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 42 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) 43 | if amsgrad: 44 | # Maintains the maximum of all 2nd moment running avg. 
till now 45 | torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i]) 46 | # Use the max. for normalizing running avg. of gradient 47 | denom = (max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(eps) 48 | else: 49 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) 50 | 51 | step_size = lr / bias_correction1 52 | 53 | param.addcdiv_(exp_avg, denom, value=-step_size) 54 | 55 | 56 | class Adam(Optimizer): 57 | r"""Implements Adam algorithm. 58 | It has been proposed in `Adam: A Method for Stochastic Optimization`_. 59 | The implementation of the L2 penalty follows changes proposed in 60 | `Decoupled Weight Decay Regularization`_. 61 | Args: 62 | params (iterable): iterable of parameters to optimize or dicts defining 63 | parameter groups 64 | lr (float, optional): learning rate (default: 1e-3) 65 | betas (Tuple[float, float], optional): coefficients used for computing 66 | running averages of gradient and its square (default: (0.9, 0.999)) 67 | eps (float, optional): term added to the denominator to improve 68 | numerical stability (default: 1e-8) 69 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 70 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 71 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 72 | (default: False) 73 | .. _Adam\: A Method for Stochastic Optimization: 74 | https://arxiv.org/abs/1412.6980 75 | .. _Decoupled Weight Decay Regularization: 76 | https://arxiv.org/abs/1711.05101 77 | .. _On the Convergence of Adam and Beyond: 78 | https://openreview.net/forum?id=ryQu7f-RZ 79 | """ 80 | 81 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 82 | weight_decay=0, amsgrad=False): 83 | if not 0.0 <= lr: 84 | raise ValueError("Invalid learning rate: {}".format(lr)) 85 | if not 0.0 <= eps: 86 | raise ValueError("Invalid epsilon value: {}".format(eps)) 87 | if not 0.0 <= betas[0] < 1.0: 88 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 89 | if not 0.0 <= betas[1] < 1.0: 90 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 91 | if not 0.0 <= weight_decay: 92 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 93 | defaults = dict(lr=lr, betas=betas, eps=eps, 94 | weight_decay=weight_decay, amsgrad=amsgrad) 95 | super(Adam, self).__init__(params, defaults) 96 | 97 | def __setstate__(self, state): 98 | super(Adam, self).__setstate__(state) 99 | for group in self.param_groups: 100 | group.setdefault('amsgrad', False) 101 | 102 | @torch.no_grad() 103 | def step(self, closure=None): 104 | """Performs a single optimization step. 105 | Args: 106 | closure (callable, optional): A closure that reevaluates the model 107 | and returns the loss. 
108 | """ 109 | loss = None 110 | if closure is not None: 111 | with torch.enable_grad(): 112 | loss = closure() 113 | 114 | for group in self.param_groups: 115 | params_with_grad = [] 116 | grads = [] 117 | exp_avgs = [] 118 | exp_avg_sqs = [] 119 | max_exp_avg_sqs = [] 120 | state_steps = [] 121 | beta1, beta2 = group['betas'] 122 | 123 | for p in group['params']: 124 | if p.grad is not None: 125 | params_with_grad.append(p) 126 | if p.grad.is_sparse: 127 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 128 | grads.append(p.grad) 129 | 130 | state = self.state[p] 131 | # Lazy state initialization 132 | if len(state) == 0: 133 | state['step'] = 0 134 | # Exponential moving average of gradient values 135 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 136 | # Exponential moving average of squared gradient values 137 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 138 | if group['amsgrad']: 139 | # Maintains max of all exp. moving avg. of sq. grad. values 140 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 141 | 142 | exp_avgs.append(state['exp_avg']) 143 | exp_avg_sqs.append(state['exp_avg_sq']) 144 | 145 | if group['amsgrad']: 146 | max_exp_avg_sqs.append(state['max_exp_avg_sq']) 147 | 148 | # update the steps for each param group update 149 | state['step'] += 1 150 | # record the step after step update 151 | state_steps.append(state['step']) 152 | 153 | adam(params_with_grad, 154 | grads, 155 | exp_avgs, 156 | exp_avg_sqs, 157 | max_exp_avg_sqs, 158 | state_steps, 159 | amsgrad=group['amsgrad'], 160 | beta1=beta1, 161 | beta2=beta2, 162 | lr=group['lr'], 163 | weight_decay=group['weight_decay'], 164 | eps=group['eps']) 165 | return loss -------------------------------------------------------------------------------- /Climate_Modeling/FNO/utilities3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import scipy.io 4 | import h5py 5 | import torch.nn as nn 6 | 7 | import operator 8 | from functools import reduce 9 | from functools import partial 10 | 11 | import os 12 | def get_freer_gpu(): 13 | os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp') 14 | memory_available = [int(x.split()[2]) for x in open('tmp', 'r').readlines()] 15 | return str(np.argmax(memory_available)) 16 | 17 | os.environ['CUDA_VISIBLE_DEVICES']= get_freer_gpu() 18 | 19 | ################################################# 20 | # 21 | # Utilities 22 | # 23 | ################################################# 24 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 25 | 26 | # reading data 27 | class MatReader(object): 28 | def __init__(self, file_path, to_torch=True, to_cuda=False, to_float=True): 29 | super(MatReader, self).__init__() 30 | 31 | self.to_torch = to_torch 32 | self.to_cuda = to_cuda 33 | self.to_float = to_float 34 | 35 | self.file_path = file_path 36 | 37 | self.data = None 38 | self.old_mat = None 39 | self._load_file() 40 | 41 | def _load_file(self): 42 | try: 43 | self.data = scipy.io.loadmat(self.file_path) 44 | self.old_mat = True 45 | except: 46 | self.data = h5py.File(self.file_path) 47 | self.old_mat = False 48 | 49 | def load_file(self, file_path): 50 | self.file_path = file_path 51 | self._load_file() 52 | 53 | def read_field(self, field): 54 | x = self.data[field] 55 | 56 | if not self.old_mat: 57 | x = x[()] 58 | x = np.transpose(x, 
axes=range(len(x.shape) - 1, -1, -1)) 59 | 60 | if self.to_float: 61 | x = x.astype(np.float32) 62 | 63 | if self.to_torch: 64 | x = torch.from_numpy(x) 65 | 66 | if self.to_cuda: 67 | x = x.cuda() 68 | 69 | return x 70 | 71 | def set_cuda(self, to_cuda): 72 | self.to_cuda = to_cuda 73 | 74 | def set_torch(self, to_torch): 75 | self.to_torch = to_torch 76 | 77 | def set_float(self, to_float): 78 | self.to_float = to_float 79 | 80 | # normalization, pointwise gaussian 81 | class UnitGaussianNormalizer(object): 82 | def __init__(self, x, eps=0.00001): 83 | super(UnitGaussianNormalizer, self).__init__() 84 | 85 | # x could be in shape of ntrain*n or ntrain*T*n or ntrain*n*T 86 | self.mean = torch.mean(x, 0) 87 | self.std = torch.std(x, 0) 88 | self.eps = eps 89 | 90 | def encode(self, x): 91 | x = (x - self.mean) / (self.std + self.eps) 92 | return x 93 | 94 | def decode(self, x, sample_idx=None): 95 | if sample_idx is None: 96 | std = self.std + self.eps # n 97 | mean = self.mean 98 | else: 99 | if len(self.mean.shape) == len(sample_idx[0].shape): 100 | std = self.std[sample_idx] + self.eps # batch*n 101 | mean = self.mean[sample_idx] 102 | if len(self.mean.shape) > len(sample_idx[0].shape): 103 | std = self.std[:,sample_idx]+ self.eps # T*batch*n 104 | mean = self.mean[:,sample_idx] 105 | 106 | # x is in shape of batch*n or T*batch*n 107 | x = (x * std) + mean 108 | return x 109 | 110 | def cuda(self): 111 | self.mean = self.mean.cuda() 112 | self.std = self.std.cuda() 113 | 114 | def cpu(self): 115 | self.mean = self.mean.cpu() 116 | self.std = self.std.cpu() 117 | 118 | # normalization, Gaussian 119 | class GaussianNormalizer(object): 120 | def __init__(self, x, eps=0.00001): 121 | super(GaussianNormalizer, self).__init__() 122 | 123 | self.mean = torch.mean(x) 124 | self.std = torch.std(x) 125 | self.eps = eps 126 | 127 | def encode(self, x): 128 | x = (x - self.mean) / (self.std + self.eps) 129 | return x 130 | 131 | def decode(self, x, sample_idx=None): 132 | x = (x * (self.std + self.eps)) + self.mean 133 | return x 134 | 135 | def cuda(self): 136 | self.mean = self.mean.cuda() 137 | self.std = self.std.cuda() 138 | 139 | def cpu(self): 140 | self.mean = self.mean.cpu() 141 | self.std = self.std.cpu() 142 | 143 | 144 | # normalization, scaling by range 145 | class RangeNormalizer(object): 146 | def __init__(self, x, low=0.0, high=1.0): 147 | super(RangeNormalizer, self).__init__() 148 | mymin = torch.min(x, 0)[0].view(-1) 149 | mymax = torch.max(x, 0)[0].view(-1) 150 | 151 | self.a = (high - low)/(mymax - mymin) 152 | self.b = -self.a*mymax + high 153 | 154 | def encode(self, x): 155 | s = x.size() 156 | x = x.view(s[0], -1) 157 | x = self.a*x + self.b 158 | x = x.view(s) 159 | return x 160 | 161 | def decode(self, x): 162 | s = x.size() 163 | x = x.view(s[0], -1) 164 | x = (x - self.b)/self.a 165 | x = x.view(s) 166 | return x 167 | 168 | #loss function with rel/abs Lp loss 169 | class LpLoss(object): 170 | def __init__(self, d=2, p=2, size_average=True, reduction=True): 171 | super(LpLoss, self).__init__() 172 | 173 | #Dimension and Lp-norm type are postive 174 | assert d > 0 and p > 0 175 | 176 | self.d = d 177 | self.p = p 178 | self.reduction = reduction 179 | self.size_average = size_average 180 | 181 | def abs(self, x, y): 182 | num_examples = x.size()[0] 183 | 184 | #Assume uniform mesh 185 | h = 1.0 / (x.size()[1] - 1.0) 186 | 187 | all_norms = (h**(self.d/self.p))*torch.norm(x.view(num_examples,-1) - y.view(num_examples,-1), self.p, 1) 188 | 189 | if self.reduction: 190 | if 
self.size_average: 191 | return torch.mean(all_norms) 192 | else: 193 | return torch.sum(all_norms) 194 | 195 | return all_norms 196 | 197 | def rel(self, x, y): 198 | num_examples = x.size()[0] 199 | diff_norms = torch.norm(x.reshape(num_examples,-1) - y.reshape(num_examples,-1), self.p, 1) 200 | y_norms = torch.norm(y.reshape(num_examples,-1), self.p, 1) 201 | if self.reduction: 202 | if self.size_average: 203 | return torch.mean(diff_norms/y_norms) 204 | else: 205 | return torch.sum(diff_norms/y_norms) 206 | 207 | return diff_norms/y_norms 208 | 209 | def __call__(self, x, y): 210 | return self.rel(x, y) 211 | 212 | # Sobolev norm (HS norm) 213 | # where we also compare the numerical derivatives between the output and target 214 | class HsLoss(object): 215 | def __init__(self, d=2, p=2, k=1, a=None, group=False, size_average=True, reduction=True): 216 | super(HsLoss, self).__init__() 217 | 218 | #Dimension and Lp-norm type are postive 219 | assert d > 0 and p > 0 220 | 221 | self.d = d 222 | self.p = p 223 | self.k = k 224 | self.balanced = group 225 | self.reduction = reduction 226 | self.size_average = size_average 227 | 228 | if a == None: 229 | a = [1,] * k 230 | self.a = a 231 | 232 | def rel(self, x, y): 233 | num_examples = x.size()[0] 234 | diff_norms = torch.norm(x.reshape(num_examples,-1) - y.reshape(num_examples,-1), self.p, 1) 235 | y_norms = torch.norm(y.reshape(num_examples,-1), self.p, 1) 236 | if self.reduction: 237 | if self.size_average: 238 | return torch.mean(diff_norms/y_norms) 239 | else: 240 | return torch.sum(diff_norms/y_norms) 241 | return diff_norms/y_norms 242 | 243 | def __call__(self, x, y, a=None): 244 | nx = x.size()[1] 245 | ny = x.size()[2] 246 | k = self.k 247 | balanced = self.balanced 248 | a = self.a 249 | x = x.view(x.shape[0], nx, ny, -1) 250 | y = y.view(y.shape[0], nx, ny, -1) 251 | 252 | k_x = torch.cat((torch.arange(start=0, end=nx//2, step=1),torch.arange(start=-nx//2, end=0, step=1)), 0).reshape(nx,1).repeat(1,ny) 253 | k_y = torch.cat((torch.arange(start=0, end=ny//2, step=1),torch.arange(start=-ny//2, end=0, step=1)), 0).reshape(1,ny).repeat(nx,1) 254 | k_x = torch.abs(k_x).reshape(1,nx,ny,1).to(x.device) 255 | k_y = torch.abs(k_y).reshape(1,nx,ny,1).to(x.device) 256 | 257 | x = torch.fft.fftn(x, dim=[1, 2]) 258 | y = torch.fft.fftn(y, dim=[1, 2]) 259 | 260 | if balanced==False: 261 | weight = 1 262 | if k >= 1: 263 | weight += a[0]**2 * (k_x**2 + k_y**2) 264 | if k >= 2: 265 | weight += a[1]**2 * (k_x**4 + 2*k_x**2*k_y**2 + k_y**4) 266 | weight = torch.sqrt(weight) 267 | loss = self.rel(x*weight, y*weight) 268 | else: 269 | loss = self.rel(x, y) 270 | if k >= 1: 271 | weight = a[0] * torch.sqrt(k_x**2 + k_y**2) 272 | loss += self.rel(x*weight, y*weight) 273 | if k >= 2: 274 | weight = a[1] * torch.sqrt(k_x**4 + 2*k_x**2*k_y**2 + k_y**4) 275 | loss += self.rel(x*weight, y*weight) 276 | loss = loss / (k+1) 277 | 278 | return loss 279 | 280 | # A simple feedforward neural network 281 | class DenseNet(torch.nn.Module): 282 | def __init__(self, layers, nonlinearity, out_nonlinearity=None, normalize=False): 283 | super(DenseNet, self).__init__() 284 | 285 | self.n_layers = len(layers) - 1 286 | 287 | assert self.n_layers >= 1 288 | 289 | self.layers = nn.ModuleList() 290 | 291 | for j in range(self.n_layers): 292 | self.layers.append(nn.Linear(layers[j], layers[j+1])) 293 | 294 | if j != self.n_layers - 1: 295 | if normalize: 296 | self.layers.append(nn.BatchNorm1d(layers[j+1])) 297 | 298 | self.layers.append(nonlinearity()) 299 | 300 | if 
out_nonlinearity is not None: 301 | self.layers.append(out_nonlinearity()) 302 | 303 | def forward(self, x): 304 | for _, l in enumerate(self.layers): 305 | x = l(x) 306 | 307 | return x 308 | 309 | 310 | # print the number of parameters 311 | def count_params(model): 312 | c = 0 313 | for p in list(model.parameters()): 314 | c += reduce(operator.mul, list(p.size())) 315 | return c -------------------------------------------------------------------------------- /Climate_Modeling/FNO/weather_FNO.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Zongyi Li 3 | This file is the Fourier Neural Operator for 2D problem such as the Darcy Flow discussed in Section 5.2 in the [paper](https://arxiv.org/pdf/2010.08895.pdf). 4 | """ 5 | 6 | import numpy as np 7 | from numpy.linalg import norm 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.nn.parameter import Parameter 12 | 13 | import matplotlib.pyplot as plt 14 | 15 | import operator 16 | from functools import reduce 17 | from functools import partial 18 | 19 | from timeit import default_timer 20 | from utilities3 import * 21 | 22 | torch.manual_seed(0) 23 | np.random.seed(0) 24 | 25 | import timeit 26 | 27 | 28 | ################################################################ 29 | # fourier layer 30 | ################################################################ 31 | class SpectralConv2d(nn.Module): 32 | def __init__(self, in_channels, out_channels, modes1, modes2): 33 | super(SpectralConv2d, self).__init__() 34 | 35 | """ 36 | 2D Fourier layer. It does FFT, linear transform, and Inverse FFT. 37 | """ 38 | 39 | self.in_channels = in_channels 40 | self.out_channels = out_channels 41 | self.modes1 = modes1 #Number of Fourier modes to multiply, at most floor(N/2) + 1 42 | self.modes2 = modes2 43 | 44 | self.scale = (1 / (in_channels * out_channels)) 45 | self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat)) 46 | self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat)) 47 | 48 | # Complex multiplication 49 | def compl_mul2d(self, input, weights): 50 | # (batch, in_channel, x,y ), (in_channel, out_channel, x,y) -> (batch, out_channel, x,y) 51 | return torch.einsum("bixy,ioxy->boxy", input, weights) 52 | 53 | def forward(self, x): 54 | batchsize = x.shape[0] 55 | #Compute Fourier coeffcients up to factor of e^(- something constant) 56 | x_ft = torch.fft.rfft2(x) 57 | 58 | # Multiply relevant Fourier modes 59 | out_ft = torch.zeros(batchsize, self.out_channels, x.size(-2), x.size(-1)//2 + 1, dtype=torch.cfloat, device=x.device) 60 | out_ft[:, :, :self.modes1, :self.modes2] = \ 61 | self.compl_mul2d(x_ft[:, :, :self.modes1, :self.modes2], self.weights1) 62 | out_ft[:, :, -self.modes1:, :self.modes2] = \ 63 | self.compl_mul2d(x_ft[:, :, -self.modes1:, :self.modes2], self.weights2) 64 | 65 | #Return to physical space 66 | x = torch.fft.irfft2(out_ft, s=(x.size(-2), x.size(-1))) 67 | return x 68 | 69 | class FNO2d(nn.Module): 70 | def __init__(self, modes1, modes2, width): 71 | super(FNO2d, self).__init__() 72 | 73 | """ 74 | The overall network. It contains 4 layers of the Fourier layer. 75 | 1. Lift the input to the desire channel dimension by self.fc0 . 76 | 2. 4 layers of the integral operators u' = (W + K)(u). 77 | W defined by self.w; K defined by self.conv . 78 | 3. 
Project from the channel space to the output space by self.fc1 and self.fc2 . 79 | 80 | input: the solution of the coefficient function and locations (a(x, y), x, y) 81 | input shape: (batchsize, x=s, y=s, c=3) 82 | output: the solution 83 | output shape: (batchsize, x=s, y=s, c=1) 84 | """ 85 | 86 | self.modes1 = modes1 87 | self.modes2 = modes2 88 | self.width = width 89 | self.fc0 = nn.Linear(3, self.width) # input channel is 3: (a(x, y), x, y) 90 | 91 | self.conv0 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2) 92 | self.conv1 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2) 93 | self.conv2 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2) 94 | self.conv3 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2) 95 | self.w0 = nn.Conv1d(self.width, self.width, 1) 96 | self.w1 = nn.Conv1d(self.width, self.width, 1) 97 | self.w2 = nn.Conv1d(self.width, self.width, 1) 98 | self.w3 = nn.Conv1d(self.width, self.width, 1) 99 | 100 | 101 | self.fc1 = nn.Linear(self.width, 128) 102 | self.fc2 = nn.Linear(128, 1) 103 | 104 | def forward(self, x): 105 | batchsize = x.shape[0] 106 | size_x, size_y = x.shape[1], x.shape[2] 107 | 108 | x = self.fc0(x) 109 | x = x.permute(0, 3, 1, 2) 110 | 111 | x1 = self.conv0(x) 112 | x2 = self.w0(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y) 113 | x = x1 + x2 114 | x = F.relu(x) 115 | 116 | x1 = self.conv1(x) 117 | x2 = self.w1(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y) 118 | x = x1 + x2 119 | x = F.relu(x) 120 | 121 | x1 = self.conv2(x) 122 | x2 = self.w2(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y) 123 | x = x1 + x2 124 | x = F.relu(x) 125 | 126 | x1 = self.conv3(x) 127 | x2 = self.w3(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y) 128 | x = x1 + x2 129 | 130 | x = x.permute(0, 2, 3, 1) 131 | x = self.fc1(x) 132 | x = F.relu(x) 133 | x = self.fc2(x) 134 | return x 135 | 136 | 137 | ################################################################ 138 | # configs 139 | ################################################################ 140 | ntrain = 1825 141 | ntest = 1825 142 | 143 | batch_size = 73 144 | learning_rate = 0.001 145 | 146 | epochs = 400 147 | step_size = 100 148 | gamma = 0.5 149 | 150 | modes = 12 151 | width = 32 152 | 153 | r = 1 154 | Nx = 72 155 | Ny = 72 156 | h = Nx 157 | s = h 158 | P = 144 159 | 160 | ################################################################ 161 | # load data and data normalization 162 | ################################################################ 163 | 164 | d = np.load("../Data/weather_dataset.npz") 165 | U_train = d["U_train"][:ntrain,:].reshape(ntrain,Nx,Ny) 166 | S_train = d["S_train"][:ntrain,:].reshape(ntrain,Nx,Ny)/1000. 167 | CX = d["X_train"] 168 | CY = d["Y_train"] 169 | 170 | d = np.load("../Data/weather_dataset.npz") 171 | U_test = d["U_train"][ntest:,:].reshape(ntrain,Nx,Ny) 172 | S_test = d["S_train"][ntest:,:].reshape(ntrain,Nx,Ny)/1000. 
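# Editorial note (added comment, not part of the original script): the /1000. factor
# rescales the target field before training, and the [ntest:, :] slicing assumes the
# dataset holds exactly ntrain + ntest samples so that the reshape to (ntrain, Nx, Ny)
# succeeds; the DeepONet variant of this example selects the same split via [-num_test:, :].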
173 | CX = d["X_train"] 174 | CY = d["Y_train"] 175 | 176 | dtype_double = torch.FloatTensor 177 | cdtype_double = torch.cuda.DoubleTensor 178 | U_train = torch.from_numpy(np.asarray(U_train)).type(dtype_double) 179 | S_train = torch.from_numpy(np.asarray(S_train)).type(dtype_double) 180 | 181 | U_test = torch.from_numpy(np.asarray(U_test)).type(dtype_double) 182 | S_test = torch.from_numpy(np.asarray(S_test)).type(dtype_double) 183 | 184 | x_train = U_train 185 | y_train = S_train 186 | 187 | x_test = U_test 188 | y_test = S_test 189 | 190 | grids = [] 191 | lontest = np.linspace(0,355,num=Nx)/360 192 | lattest = (np.linspace(90,-87.5,num=Ny) + 90.)/180. 193 | 194 | grids.append(lontest) 195 | grids.append(lattest) 196 | grid = np.vstack([xx.ravel() for xx in np.meshgrid(*grids)]).T 197 | grid = grid.reshape(1,s,s,2) 198 | grid = torch.tensor(grid, dtype=torch.float) 199 | x_train = torch.cat([x_train.reshape(ntrain,s,s,1), grid.repeat(ntrain,1,1,1)], dim=3) 200 | x_test = torch.cat([x_test.reshape(ntest,s,s,1), grid.repeat(ntest,1,1,1)], dim=3) 201 | 202 | ind_train = torch.randint(s*s, (ntrain, P)) 203 | ind_test = torch.randint(s*s, (ntest, P)) 204 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train, ind_train), batch_size=batch_size, shuffle=True) 205 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test, ind_test), batch_size=batch_size, shuffle=True) 206 | 207 | ################################################################ 208 | # training and evaluation 209 | ################################################################ 210 | 211 | batch_ind = torch.arange(batch_size).reshape(-1, 1).repeat(1, P) 212 | model = FNO2d(modes, modes, width).cuda() 213 | print(count_params(model)) 214 | 215 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4) 216 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma) 217 | 218 | myloss = LpLoss(size_average=False) 219 | start_time = timeit.default_timer() 220 | for ep in range(epochs): 221 | model.train() 222 | t1 = default_timer() 223 | train_l2 = 0 224 | for x, y, idx in train_loader: 225 | x, y = x.cuda(), y.cuda() 226 | 227 | optimizer.zero_grad() 228 | out = model(x).reshape(batch_size, s*s) 229 | y = y.reshape(batch_size, s*s) 230 | y = y[batch_ind, idx] 231 | out = out[batch_ind, idx] 232 | 233 | loss = myloss(out.view(batch_size,-1), y.view(batch_size,-1)) 234 | loss.backward() 235 | 236 | optimizer.step() 237 | train_l2 += loss.item() 238 | 239 | scheduler.step() 240 | 241 | model.eval() 242 | test_l2 = 0.0 243 | with torch.no_grad(): 244 | for x, y, idx in test_loader: 245 | x, y = x.cuda(), y.cuda() 246 | 247 | out = model(x).reshape(batch_size, s*s) 248 | y = y.reshape(batch_size, s*s,1) 249 | y = y[batch_ind, idx] 250 | out = out[batch_ind, idx] 251 | 252 | test_l2 += myloss(out.view(batch_size,-1), y.view(batch_size,-1)).item() 253 | 254 | train_l2/= ntrain 255 | test_l2 /= ntest 256 | 257 | t2 = default_timer() 258 | 259 | print(ep, t2-t1, train_l2, test_l2)#, np.mean(error_total)) 260 | 261 | elapsed = timeit.default_timer() - start_time 262 | print("The training wall-clock time is seconds is equal to %f seconds"%elapsed) 263 | 264 | pred_torch = torch.zeros(S_test.shape) 265 | baseline_torch = torch.zeros(S_test.shape) 266 | index = 0 267 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=1, shuffle=False) 268 | test_error_u = [] 269 | 
test_error_u_np = [] 270 | with torch.no_grad(): 271 | for x, y in test_loader: 272 | test_l2 = 0 273 | x, y = x.cuda(), y.cuda() 274 | 275 | out = model(x).reshape(1, s, s) 276 | pred_torch[index,:,:] = out[:,:,:] 277 | baseline_torch[index,:,:] = y[:,:,:] 278 | 279 | test_l2 += myloss(out.view(1, -1), y.view(1, -1)).item() 280 | test_error_u.append(test_l2) 281 | test_error_u_np.append(np.linalg.norm(out.cpu().numpy().reshape(S_test.shape[1]*S_test.shape[1])- y.cpu().numpy().reshape(S_test.shape[1]*S_test.shape[1]),2)/np.linalg.norm(out.cpu().numpy().reshape(S_test.shape[1]*S_test.shape[2]),2)) 282 | index = index + 1 283 | 284 | print("The average test u error is %e the standard deviation is %e the min error is %e and the max error is %e"%(np.mean(test_error_u),np.std(test_error_u),np.min(test_error_u),np.max(test_error_u))) 285 | print("The average test u error is %e the standard deviation is %e the min error is %e and the max error is %e"%(np.mean(test_error_u_np),np.std(test_error_u_np),np.min(test_error_u_np),np.max(test_error_u_np))) -------------------------------------------------------------------------------- /Darcy/FNO/FNODarcy.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Zongyi Li 3 | This file is the Fourier Neural Operator for 2D problem such as the Darcy Flow discussed in Section 5.2 in the [paper](https://arxiv.org/pdf/2010.08895.pdf). 4 | """ 5 | 6 | import numpy as np 7 | from numpy.linalg import norm 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.nn.parameter import Parameter 12 | 13 | import matplotlib.pyplot as plt 14 | 15 | import operator 16 | from functools import reduce 17 | from functools import partial 18 | 19 | from timeit import default_timer 20 | from utilities3 import * 21 | 22 | torch.manual_seed(0) 23 | np.random.seed(0) 24 | 25 | import timeit 26 | 27 | 28 | ################################################################ 29 | # fourier layer 30 | ################################################################ 31 | class SpectralConv2d(nn.Module): 32 | def __init__(self, in_channels, out_channels, modes1, modes2): 33 | super(SpectralConv2d, self).__init__() 34 | 35 | """ 36 | 2D Fourier layer. It does FFT, linear transform, and Inverse FFT. 
37 | """ 38 | 39 | self.in_channels = in_channels 40 | self.out_channels = out_channels 41 | self.modes1 = modes1 #Number of Fourier modes to multiply, at most floor(N/2) + 1 42 | self.modes2 = modes2 43 | 44 | self.scale = (1 / (in_channels * out_channels)) 45 | self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat)) 46 | self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat)) 47 | 48 | # Complex multiplication 49 | def compl_mul2d(self, input, weights): 50 | # (batch, in_channel, x,y ), (in_channel, out_channel, x,y) -> (batch, out_channel, x,y) 51 | return torch.einsum("bixy,ioxy->boxy", input, weights) 52 | 53 | def forward(self, x): 54 | batchsize = x.shape[0] 55 | #Compute Fourier coeffcients up to factor of e^(- something constant) 56 | x_ft = torch.fft.rfft2(x) 57 | 58 | # Multiply relevant Fourier modes 59 | out_ft = torch.zeros(batchsize, self.out_channels, x.size(-2), x.size(-1)//2 + 1, dtype=torch.cfloat, device=x.device) 60 | out_ft[:, :, :self.modes1, :self.modes2] = \ 61 | self.compl_mul2d(x_ft[:, :, :self.modes1, :self.modes2], self.weights1) 62 | out_ft[:, :, -self.modes1:, :self.modes2] = \ 63 | self.compl_mul2d(x_ft[:, :, -self.modes1:, :self.modes2], self.weights2) 64 | #Return to physical space 65 | x = torch.fft.irfft2(out_ft, s=(x.size(-2), x.size(-1))) 66 | return x 67 | 68 | class FNO2d(nn.Module): 69 | def __init__(self, modes1, modes2, width): 70 | super(FNO2d, self).__init__() 71 | 72 | """ 73 | The overall network. It contains 4 layers of the Fourier layer. 74 | 1. Lift the input to the desire channel dimension by self.fc0 . 75 | 2. 4 layers of the integral operators u' = (W + K)(u). 76 | W defined by self.w; K defined by self.conv . 77 | 3. Project from the channel space to the output space by self.fc1 and self.fc2 . 
78 | 79 | input: the solution of the coefficient function and locations (a(x, y), x, y) 80 | input shape: (batchsize, x=s, y=s, c=3) 81 | output: the solution 82 | output shape: (batchsize, x=s, y=s, c=1) 83 | """ 84 | 85 | self.modes1 = modes1 86 | self.modes2 = modes2 87 | self.width = width 88 | self.fc0 = nn.Linear(3, self.width) 89 | self.conv0 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2) 90 | self.conv1 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2) 91 | self.conv2 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2) 92 | self.conv3 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2) 93 | self.w0 = nn.Conv1d(self.width, self.width, 1) 94 | self.w1 = nn.Conv1d(self.width, self.width, 1) 95 | self.w2 = nn.Conv1d(self.width, self.width, 1) 96 | self.w3 = nn.Conv1d(self.width, self.width, 1) 97 | self.fc1 = nn.Linear(self.width, 128) 98 | self.fc2 = nn.Linear(128, 1) 99 | 100 | def forward(self, x): 101 | batchsize = x.shape[0] 102 | size_x, size_y = x.shape[1], x.shape[2] 103 | 104 | x = self.fc0(x) 105 | x = x.permute(0, 3, 1, 2) 106 | 107 | x1 = self.conv0(x) 108 | x2 = self.w0(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y) 109 | x = x1 + x2 110 | x = F.relu(x) 111 | 112 | x1 = self.conv1(x) 113 | x2 = self.w1(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y) 114 | x = x1 + x2 115 | x = F.relu(x) 116 | 117 | x1 = self.conv2(x) 118 | x2 = self.w2(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y) 119 | x = x1 + x2 120 | x = F.relu(x) 121 | 122 | x1 = self.conv3(x) 123 | x2 = self.w3(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y) 124 | x = x1 + x2 125 | 126 | x = x.permute(0, 2, 3, 1) 127 | x = self.fc1(x) 128 | x = F.relu(x) 129 | x = self.fc2(x) 130 | return x 131 | 132 | 133 | ################################################################ 134 | # configs 135 | ################################################################ 136 | ntrain = 1000 137 | ntest = 1000 138 | 139 | batch_size = 100 140 | learning_rate = 0.001 141 | 142 | epochs = 500 143 | step_size = 100 144 | gamma = 0.5 145 | 146 | modes = 8 147 | width = 32 148 | 149 | r = 1 150 | h = 32 151 | sub = 1 152 | s = 32 153 | P = 128 154 | 155 | ################################################################ 156 | # load data and data normalization 157 | ################################################################ 158 | d = np.load("../Data/train_darcy_dataset.npz") 159 | U_train = d["U_train"].reshape(ntrain,32,32)[:,::sub,::sub] 160 | X_train = d["X_train"] 161 | Y_train = d["Y_train"] 162 | S_train = d["s_train"].reshape(ntrain,32,32)[:,::sub,::sub] 163 | 164 | d = np.load("../Data/test_darcy_dataset.npz") 165 | U_test = d["U_test"].reshape(ntest,32,32)[:,::sub,::sub] 166 | X_test = d["X_test"] 167 | Y_test = d["Y_test"] 168 | S_test = d["s_test"].reshape(ntest,32,32)[:,::sub,::sub] 169 | 170 | dtype_double = torch.FloatTensor 171 | cdtype_double = torch.cuda.DoubleTensor 172 | U_train = torch.from_numpy(np.asarray(U_train)).type(dtype_double) 173 | S_train = torch.from_numpy(np.asarray(S_train)).type(dtype_double) 174 | 175 | U_test = torch.from_numpy(np.asarray(U_test)).type(dtype_double) 176 | S_test = torch.from_numpy(np.asarray(S_test)).type(dtype_double) 177 | 178 | x_train = U_train 179 | y_train = S_train 180 | 181 | x_test = U_test 182 | y_test = S_test 183 | 184 | grids = [] 185 | grids.append(np.linspace(0, 1, s)) 186 
| grids.append(np.linspace(0, 1, s)) 187 | grid = np.vstack([xx.ravel() for xx in np.meshgrid(*grids)]).T 188 | grid = grid.reshape(1,s,s,2) 189 | grid = torch.tensor(grid, dtype=torch.float) 190 | x_train = torch.cat([x_train.reshape(ntrain,s,s,1), grid.repeat(ntrain,1,1,1)], dim=3) 191 | x_test = torch.cat([x_test.reshape(ntest,s,s,1), grid.repeat(ntest,1,1,1)], dim=3) 192 | 193 | ind_train = torch.randint(s*s, (ntrain, P)) 194 | ind_test = torch.randint(s*s, (ntest, P)) 195 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train, ind_train), batch_size=batch_size, shuffle=True) 196 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test, ind_test), batch_size=batch_size, shuffle=True) 197 | 198 | ################################################################ 199 | # training and evaluation 200 | ################################################################ 201 | 202 | batch_ind = torch.arange(batch_size).reshape(-1, 1).repeat(1, P) 203 | model = FNO2d(modes, modes, width).cuda() 204 | print(count_params(model)) 205 | 206 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4) 207 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma) 208 | 209 | myloss = LpLoss(size_average=False) 210 | 211 | start_time = timeit.default_timer() 212 | for ep in range(epochs): 213 | model.train() 214 | t1 = default_timer() 215 | train_l2 = 0 216 | for x, y, idx in train_loader: 217 | x, y = x.cuda(), y.cuda() 218 | 219 | optimizer.zero_grad() 220 | out = model(x).reshape(batch_size, s*s) 221 | y = y.reshape(batch_size, s*s) 222 | y = y[batch_ind, idx] 223 | out = out[batch_ind, idx] 224 | 225 | loss = myloss(out.view(batch_size,-1), y.view(batch_size,-1)) 226 | loss.backward() 227 | 228 | optimizer.step() 229 | train_l2 += loss.item() 230 | 231 | scheduler.step() 232 | 233 | model.eval() 234 | test_l2 = 0.0 235 | with torch.no_grad(): 236 | for x, y, idx in test_loader: 237 | x, y = x.cuda(), y.cuda() 238 | 239 | out = model(x).reshape(batch_size, s*s) 240 | y = y.reshape(batch_size, s*s,1) 241 | y = y[batch_ind, idx] 242 | out = out[batch_ind, idx] 243 | 244 | test_l2 += myloss(out.view(batch_size,-1), y.view(batch_size,-1)).item() 245 | 246 | train_l2/= ntrain 247 | test_l2 /= ntest 248 | 249 | t2 = default_timer() 250 | 251 | print(ep, t2-t1, train_l2, test_l2) 252 | 253 | elapsed = timeit.default_timer() - start_time 254 | print("The training wall-clock time is seconds is equal to %f seconds"%elapsed) 255 | 256 | grids = [] 257 | grids.append(np.linspace(0, 1, s)) 258 | grids.append(np.linspace(0, 1, s)) 259 | grid = np.vstack([xx.ravel() for xx in np.meshgrid(*grids)]).T 260 | grid = grid.reshape(1,s,s,2) 261 | grid = torch.tensor(grid, dtype=torch.float) 262 | x_train = torch.cat([x_train, grid.repeat(ntrain,1,1,1)], dim=3) 263 | 264 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=1, shuffle=False) 265 | 266 | train_error_u_np = [] 267 | with torch.no_grad(): 268 | for x, y in train_loader: 269 | x, y = x.cuda(), y.cuda() 270 | 271 | out = model(x).reshape(1, s, s) 272 | train_error_u_np.append(np.linalg.norm(y.cpu().numpy().reshape(S_test.shape[1]*S_test.shape[1])- out.cpu().numpy().reshape(S_test.shape[1]*S_test.shape[1]),2)/np.linalg.norm(y.cpu().numpy().reshape(S_test.shape[1]*S_test.shape[2]),2)) 273 | 274 | print("The average train u error is %e the standard deviation is %e the min error is %e 
and the max error is %e"%(np.mean(train_error_u_np),np.std(train_error_u_np),np.min(train_error_u_np),np.max(train_error_u_np))) 275 | 276 | sub = 1 277 | s = 32 278 | d = np.load("data/test_darcy_dataset_FNO2.npz") 279 | U_test = d["U_test"].reshape(ntest,32,32)[:,::sub,::sub] 280 | X_test = d["X_test"] 281 | Y_test = d["Y_test"] 282 | S_test = d["s_test"].reshape(ntest,32,32)[:,::sub,::sub] 283 | 284 | dtype_double = torch.FloatTensor 285 | cdtype_double = torch.cuda.DoubleTensor 286 | 287 | U_test = torch.from_numpy(np.asarray(U_test)).type(dtype_double) 288 | S_test = torch.from_numpy(np.asarray(S_test)).type(dtype_double) 289 | 290 | x_test = U_test 291 | y_test = S_test 292 | 293 | grids = [] 294 | grids.append(np.linspace(0, 1, s)) 295 | grids.append(np.linspace(0, 1, s)) 296 | grid = np.vstack([xx.ravel() for xx in np.meshgrid(*grids)]).T 297 | grid = grid.reshape(1,s,s,2) 298 | grid = torch.tensor(grid, dtype=torch.float) 299 | x_test = torch.cat([x_test.reshape(ntest,s,s,1), grid.repeat(ntest,1,1,1)], dim=3) 300 | 301 | pred_torch = torch.zeros(S_test.shape) 302 | baseline_torch = torch.zeros(S_test.shape) 303 | index = 0 304 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=1, shuffle=False) 305 | test_error_u = [] 306 | test_error_u_np = [] 307 | with torch.no_grad(): 308 | for x, y in test_loader: 309 | test_l2 = 0 310 | x, y = x.cuda(), y.cuda() 311 | 312 | out = model(x).reshape(1, s, s) 313 | pred_torch[index,:,:] = out[:,:,:] 314 | baseline_torch[index,:,:] = y[:,:,:] 315 | 316 | test_l2 += myloss(out.view(1, -1), y.view(1, -1)).item() 317 | test_error_u.append(test_l2) 318 | test_error_u_np.append(np.linalg.norm(y.cpu().numpy().reshape(S_test.shape[1]*S_test.shape[1])- out.cpu().numpy().reshape(S_test.shape[1]*S_test.shape[1]),2)/np.linalg.norm(y.cpu().numpy().reshape(S_test.shape[1]*S_test.shape[2]),2)) 319 | index = index + 1 320 | 321 | print("The average test u error is %e the standard deviation is %e the min error is %e and the max error is %e"%(np.mean(test_error_u),np.std(test_error_u),np.min(test_error_u),np.max(test_error_u))) 322 | print("The average test u error is %e the standard deviation is %e the min error is %e and the max error is %e"%(np.mean(test_error_u_np),np.std(test_error_u_np),np.min(test_error_u_np),np.max(test_error_u_np))) -------------------------------------------------------------------------------- /Darcy/FNO/utilities3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import scipy.io 4 | import h5py 5 | import torch.nn as nn 6 | 7 | import operator 8 | from functools import reduce 9 | from functools import partial 10 | 11 | ################################################# 12 | # 13 | # Utilities 14 | # 15 | ################################################# 16 | import os 17 | def get_freer_gpu(): 18 | os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp') 19 | memory_available = [int(x.split()[2]) for x in open('tmp', 'r').readlines()] 20 | return str(np.argmax(memory_available)) 21 | 22 | os.environ['CUDA_VISIBLE_DEVICES']= get_freer_gpu() 23 | 24 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 25 | 26 | # reading data 27 | class MatReader(object): 28 | def __init__(self, file_path, to_torch=True, to_cuda=False, to_float=True): 29 | super(MatReader, self).__init__() 30 | 31 | self.to_torch = to_torch 32 | self.to_cuda = to_cuda 33 | self.to_float = to_float 34 | 35 | 
self.file_path = file_path 36 | 37 | self.data = None 38 | self.old_mat = None 39 | self._load_file() 40 | 41 | def _load_file(self): 42 | try: 43 | self.data = scipy.io.loadmat(self.file_path) 44 | self.old_mat = True 45 | except: 46 | self.data = h5py.File(self.file_path) 47 | self.old_mat = False 48 | 49 | def load_file(self, file_path): 50 | self.file_path = file_path 51 | self._load_file() 52 | 53 | def read_field(self, field): 54 | x = self.data[field] 55 | 56 | if not self.old_mat: 57 | x = x[()] 58 | x = np.transpose(x, axes=range(len(x.shape) - 1, -1, -1)) 59 | 60 | if self.to_float: 61 | x = x.astype(np.float32) 62 | 63 | if self.to_torch: 64 | x = torch.from_numpy(x) 65 | 66 | if self.to_cuda: 67 | x = x.cuda() 68 | 69 | return x 70 | 71 | def set_cuda(self, to_cuda): 72 | self.to_cuda = to_cuda 73 | 74 | def set_torch(self, to_torch): 75 | self.to_torch = to_torch 76 | 77 | def set_float(self, to_float): 78 | self.to_float = to_float 79 | 80 | # normalization, pointwise gaussian 81 | class UnitGaussianNormalizer(object): 82 | def __init__(self, x, eps=0.00001): 83 | super(UnitGaussianNormalizer, self).__init__() 84 | 85 | # x could be in shape of ntrain*n or ntrain*T*n or ntrain*n*T 86 | self.mean = torch.mean(x, 0) 87 | self.std = torch.std(x, 0) 88 | self.eps = eps 89 | 90 | def encode(self, x): 91 | x = (x - self.mean) / (self.std + self.eps) 92 | return x 93 | 94 | def decode(self, x, sample_idx=None): 95 | if sample_idx is None: 96 | std = self.std + self.eps # n 97 | mean = self.mean 98 | else: 99 | if len(self.mean.shape) == len(sample_idx[0].shape): 100 | std = self.std[sample_idx] + self.eps # batch*n 101 | mean = self.mean[sample_idx] 102 | if len(self.mean.shape) > len(sample_idx[0].shape): 103 | std = self.std[:,sample_idx]+ self.eps # T*batch*n 104 | mean = self.mean[:,sample_idx] 105 | 106 | # x is in shape of batch*n or T*batch*n 107 | x = (x * std) + mean 108 | return x 109 | 110 | def cuda(self): 111 | self.mean = self.mean.cuda() 112 | self.std = self.std.cuda() 113 | 114 | def cpu(self): 115 | self.mean = self.mean.cpu() 116 | self.std = self.std.cpu() 117 | 118 | # normalization, Gaussian 119 | class GaussianNormalizer(object): 120 | def __init__(self, x, eps=0.00001): 121 | super(GaussianNormalizer, self).__init__() 122 | 123 | self.mean = torch.mean(x) 124 | self.std = torch.std(x) 125 | self.eps = eps 126 | 127 | def encode(self, x): 128 | x = (x - self.mean) / (self.std + self.eps) 129 | return x 130 | 131 | def decode(self, x, sample_idx=None): 132 | x = (x * (self.std + self.eps)) + self.mean 133 | return x 134 | 135 | def cuda(self): 136 | self.mean = self.mean.cuda() 137 | self.std = self.std.cuda() 138 | 139 | def cpu(self): 140 | self.mean = self.mean.cpu() 141 | self.std = self.std.cpu() 142 | 143 | 144 | # normalization, scaling by range 145 | class RangeNormalizer(object): 146 | def __init__(self, x, low=0.0, high=1.0): 147 | super(RangeNormalizer, self).__init__() 148 | mymin = torch.min(x, 0)[0].view(-1) 149 | mymax = torch.max(x, 0)[0].view(-1) 150 | 151 | self.a = (high - low)/(mymax - mymin) 152 | self.b = -self.a*mymax + high 153 | 154 | def encode(self, x): 155 | s = x.size() 156 | x = x.view(s[0], -1) 157 | x = self.a*x + self.b 158 | x = x.view(s) 159 | return x 160 | 161 | def decode(self, x): 162 | s = x.size() 163 | x = x.view(s[0], -1) 164 | x = (x - self.b)/self.a 165 | x = x.view(s) 166 | return x 167 | 168 | #loss function with rel/abs Lp loss 169 | class LpLoss(object): 170 | def __init__(self, d=2, p=2, size_average=True, 
reduction=True): 171 | super(LpLoss, self).__init__() 172 | 173 | #Dimension and Lp-norm type are postive 174 | assert d > 0 and p > 0 175 | 176 | self.d = d 177 | self.p = p 178 | self.reduction = reduction 179 | self.size_average = size_average 180 | 181 | def abs(self, x, y): 182 | num_examples = x.size()[0] 183 | 184 | #Assume uniform mesh 185 | h = 1.0 / (x.size()[1] - 1.0) 186 | 187 | all_norms = (h**(self.d/self.p))*torch.norm(x.view(num_examples,-1) - y.view(num_examples,-1), self.p, 1) 188 | 189 | if self.reduction: 190 | if self.size_average: 191 | return torch.mean(all_norms) 192 | else: 193 | return torch.sum(all_norms) 194 | 195 | return all_norms 196 | 197 | def rel(self, x, y): 198 | num_examples = x.size()[0] 199 | 200 | diff_norms = torch.norm(x.reshape(num_examples,-1) - y.reshape(num_examples,-1), self.p, 1) 201 | y_norms = torch.norm(y.reshape(num_examples,-1), self.p, 1) 202 | 203 | if self.reduction: 204 | if self.size_average: 205 | return torch.mean(diff_norms/y_norms) 206 | else: 207 | return torch.sum(diff_norms/y_norms) 208 | 209 | return diff_norms/y_norms 210 | 211 | def __call__(self, x, y): 212 | return self.rel(x, y) 213 | 214 | # Sobolev norm (HS norm) 215 | # where we also compare the numerical derivatives between the output and target 216 | class HsLoss(object): 217 | def __init__(self, d=2, p=2, k=1, a=None, group=False, size_average=True, reduction=True): 218 | super(HsLoss, self).__init__() 219 | 220 | #Dimension and Lp-norm type are postive 221 | assert d > 0 and p > 0 222 | 223 | self.d = d 224 | self.p = p 225 | self.k = k 226 | self.balanced = group 227 | self.reduction = reduction 228 | self.size_average = size_average 229 | 230 | if a == None: 231 | a = [1,] * k 232 | self.a = a 233 | 234 | def rel(self, x, y): 235 | num_examples = x.size()[0] 236 | diff_norms = torch.norm(x.reshape(num_examples,-1) - y.reshape(num_examples,-1), self.p, 1) 237 | y_norms = torch.norm(y.reshape(num_examples,-1), self.p, 1) 238 | if self.reduction: 239 | if self.size_average: 240 | return torch.mean(diff_norms/y_norms) 241 | else: 242 | return torch.sum(diff_norms/y_norms) 243 | return diff_norms/y_norms 244 | 245 | def __call__(self, x, y, a=None): 246 | nx = x.size()[1] 247 | ny = x.size()[2] 248 | k = self.k 249 | balanced = self.balanced 250 | a = self.a 251 | x = x.view(x.shape[0], nx, ny, -1) 252 | y = y.view(y.shape[0], nx, ny, -1) 253 | 254 | k_x = torch.cat((torch.arange(start=0, end=nx//2, step=1),torch.arange(start=-nx//2, end=0, step=1)), 0).reshape(nx,1).repeat(1,ny) 255 | k_y = torch.cat((torch.arange(start=0, end=ny//2, step=1),torch.arange(start=-ny//2, end=0, step=1)), 0).reshape(1,ny).repeat(nx,1) 256 | k_x = torch.abs(k_x).reshape(1,nx,ny,1).to(x.device) 257 | k_y = torch.abs(k_y).reshape(1,nx,ny,1).to(x.device) 258 | 259 | x = torch.fft.fftn(x, dim=[1, 2]) 260 | y = torch.fft.fftn(y, dim=[1, 2]) 261 | 262 | if balanced==False: 263 | weight = 1 264 | if k >= 1: 265 | weight += a[0]**2 * (k_x**2 + k_y**2) 266 | if k >= 2: 267 | weight += a[1]**2 * (k_x**4 + 2*k_x**2*k_y**2 + k_y**4) 268 | weight = torch.sqrt(weight) 269 | loss = self.rel(x*weight, y*weight) 270 | else: 271 | loss = self.rel(x, y) 272 | if k >= 1: 273 | weight = a[0] * torch.sqrt(k_x**2 + k_y**2) 274 | loss += self.rel(x*weight, y*weight) 275 | if k >= 2: 276 | weight = a[1] * torch.sqrt(k_x**4 + 2*k_x**2*k_y**2 + k_y**4) 277 | loss += self.rel(x*weight, y*weight) 278 | loss = loss / (k+1) 279 | 280 | return loss 281 | 282 | # A simple feedforward neural network 283 | class 
DenseNet(torch.nn.Module): 284 | def __init__(self, layers, nonlinearity, out_nonlinearity=None, normalize=False): 285 | super(DenseNet, self).__init__() 286 | 287 | self.n_layers = len(layers) - 1 288 | 289 | assert self.n_layers >= 1 290 | 291 | self.layers = nn.ModuleList() 292 | 293 | for j in range(self.n_layers): 294 | self.layers.append(nn.Linear(layers[j], layers[j+1])) 295 | 296 | if j != self.n_layers - 1: 297 | if normalize: 298 | self.layers.append(nn.BatchNorm1d(layers[j+1])) 299 | 300 | self.layers.append(nonlinearity()) 301 | 302 | if out_nonlinearity is not None: 303 | self.layers.append(out_nonlinearity()) 304 | 305 | def forward(self, x): 306 | for _, l in enumerate(self.layers): 307 | x = l(x) 308 | 309 | return x 310 | 311 | 312 | # print the number of parameters 313 | def count_params(model): 314 | c = 0 315 | for p in list(model.parameters()): 316 | c += reduce(operator.mul, list(p.size())) 317 | return c 318 | -------------------------------------------------------------------------------- /MMNIST/FNO/Adam.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import Tensor 4 | from typing import List, Optional 5 | from torch.optim.optimizer import Optimizer 6 | 7 | import os 8 | os.environ['CUDA_VISIBLE_DEVICES']="3" 9 | 10 | def adam(params: List[Tensor], 11 | grads: List[Tensor], 12 | exp_avgs: List[Tensor], 13 | exp_avg_sqs: List[Tensor], 14 | max_exp_avg_sqs: List[Tensor], 15 | state_steps: List[int], 16 | *, 17 | amsgrad: bool, 18 | beta1: float, 19 | beta2: float, 20 | lr: float, 21 | weight_decay: float, 22 | eps: float): 23 | r"""Functional API that performs Adam algorithm computation. 24 | See :class:`~torch.optim.Adam` for details. 25 | """ 26 | 27 | for i, param in enumerate(params): 28 | 29 | grad = grads[i] 30 | exp_avg = exp_avgs[i] 31 | exp_avg_sq = exp_avg_sqs[i] 32 | step = state_steps[i] 33 | 34 | bias_correction1 = 1 - beta1 ** step 35 | bias_correction2 = 1 - beta2 ** step 36 | 37 | if weight_decay != 0: 38 | grad = grad.add(param, alpha=weight_decay) 39 | 40 | # Decay the first and second moment running average coefficient 41 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 42 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) 43 | if amsgrad: 44 | # Maintains the maximum of all 2nd moment running avg. till now 45 | torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i]) 46 | # Use the max. for normalizing running avg. of gradient 47 | denom = (max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(eps) 48 | else: 49 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) 50 | 51 | step_size = lr / bias_correction1 52 | 53 | param.addcdiv_(exp_avg, denom, value=-step_size) 54 | 55 | 56 | class Adam(Optimizer): 57 | r"""Implements Adam algorithm. 58 | It has been proposed in `Adam: A Method for Stochastic Optimization`_. 59 | The implementation of the L2 penalty follows changes proposed in 60 | `Decoupled Weight Decay Regularization`_. 
61 | Args: 62 | params (iterable): iterable of parameters to optimize or dicts defining 63 | parameter groups 64 | lr (float, optional): learning rate (default: 1e-3) 65 | betas (Tuple[float, float], optional): coefficients used for computing 66 | running averages of gradient and its square (default: (0.9, 0.999)) 67 | eps (float, optional): term added to the denominator to improve 68 | numerical stability (default: 1e-8) 69 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 70 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 71 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 72 | (default: False) 73 | .. _Adam\: A Method for Stochastic Optimization: 74 | https://arxiv.org/abs/1412.6980 75 | .. _Decoupled Weight Decay Regularization: 76 | https://arxiv.org/abs/1711.05101 77 | .. _On the Convergence of Adam and Beyond: 78 | https://openreview.net/forum?id=ryQu7f-RZ 79 | """ 80 | 81 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 82 | weight_decay=0, amsgrad=False): 83 | if not 0.0 <= lr: 84 | raise ValueError("Invalid learning rate: {}".format(lr)) 85 | if not 0.0 <= eps: 86 | raise ValueError("Invalid epsilon value: {}".format(eps)) 87 | if not 0.0 <= betas[0] < 1.0: 88 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 89 | if not 0.0 <= betas[1] < 1.0: 90 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 91 | if not 0.0 <= weight_decay: 92 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 93 | defaults = dict(lr=lr, betas=betas, eps=eps, 94 | weight_decay=weight_decay, amsgrad=amsgrad) 95 | super(Adam, self).__init__(params, defaults) 96 | 97 | def __setstate__(self, state): 98 | super(Adam, self).__setstate__(state) 99 | for group in self.param_groups: 100 | group.setdefault('amsgrad', False) 101 | 102 | @torch.no_grad() 103 | def step(self, closure=None): 104 | """Performs a single optimization step. 105 | Args: 106 | closure (callable, optional): A closure that reevaluates the model 107 | and returns the loss. 108 | """ 109 | loss = None 110 | if closure is not None: 111 | with torch.enable_grad(): 112 | loss = closure() 113 | 114 | for group in self.param_groups: 115 | params_with_grad = [] 116 | grads = [] 117 | exp_avgs = [] 118 | exp_avg_sqs = [] 119 | max_exp_avg_sqs = [] 120 | state_steps = [] 121 | beta1, beta2 = group['betas'] 122 | 123 | for p in group['params']: 124 | if p.grad is not None: 125 | params_with_grad.append(p) 126 | if p.grad.is_sparse: 127 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 128 | grads.append(p.grad) 129 | 130 | state = self.state[p] 131 | # Lazy state initialization 132 | if len(state) == 0: 133 | state['step'] = 0 134 | # Exponential moving average of gradient values 135 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 136 | # Exponential moving average of squared gradient values 137 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 138 | if group['amsgrad']: 139 | # Maintains max of all exp. moving avg. of sq. grad. 
values 140 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 141 | 142 | exp_avgs.append(state['exp_avg']) 143 | exp_avg_sqs.append(state['exp_avg_sq']) 144 | 145 | if group['amsgrad']: 146 | max_exp_avg_sqs.append(state['max_exp_avg_sq']) 147 | 148 | # update the steps for each param group update 149 | state['step'] += 1 150 | # record the step after step update 151 | state_steps.append(state['step']) 152 | 153 | adam(params_with_grad, 154 | grads, 155 | exp_avgs, 156 | exp_avg_sqs, 157 | max_exp_avg_sqs, 158 | state_steps, 159 | amsgrad=group['amsgrad'], 160 | beta1=beta1, 161 | beta2=beta2, 162 | lr=group['lr'], 163 | weight_decay=group['weight_decay'], 164 | eps=group['eps']) 165 | return loss -------------------------------------------------------------------------------- /MMNIST/FNO/FNOMMNIST.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Zongyi Li 3 | This file is the Fourier Neural Operator for 2D problem such as the Darcy Flow discussed in Section 5.2 in the [paper](https://arxiv.org/pdf/2010.08895.pdf). 4 | """ 5 | 6 | from jax._src.numpy.lax_numpy import arange 7 | import numpy as np 8 | from numpy.linalg import norm 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch.nn.parameter import Parameter 13 | import os 14 | 15 | from functools import reduce 16 | 17 | from timeit import default_timer 18 | from utilities3 import count_params, LpLoss 19 | 20 | 21 | import timeit 22 | 23 | 24 | seed = np.random.randint(10000) 25 | torch.manual_seed(seed) 26 | np.random.seed(seed) 27 | 28 | 29 | ################################################################ 30 | # fourier layer 31 | ################################################################ 32 | class SpectralConv2d(nn.Module): 33 | def __init__(self, in_channels, out_channels, modes1, modes2): 34 | super(SpectralConv2d, self).__init__() 35 | 36 | """ 37 | 2D Fourier layer. It does FFT, linear transform, and Inverse FFT. 
38 | """ 39 | 40 | self.in_channels = in_channels 41 | self.out_channels = out_channels 42 | self.modes1 = modes1 #Number of Fourier modes to multiply, at most floor(N/2) + 1 43 | self.modes2 = modes2 44 | 45 | self.scale = (1 / (in_channels * out_channels)) 46 | self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat)) 47 | self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat)) 48 | 49 | # Complex multiplication 50 | def compl_mul2d(self, input, weights): 51 | # (batch, in_channel, x,y ), (in_channel, out_channel, x,y) -> (batch, out_channel, x,y) 52 | return torch.einsum("bixy,ioxy->boxy", input, weights) 53 | 54 | def forward(self, x): 55 | batchsize = x.shape[0] 56 | #Compute Fourier coeffcients up to factor of e^(- something constant) 57 | x_ft = torch.fft.rfft2(x) 58 | 59 | # Multiply relevant Fourier modes 60 | out_ft = torch.zeros(batchsize, self.out_channels, x.size(-2), x.size(-1)//2 + 1, dtype=torch.cfloat, device=x.device) 61 | out_ft[:, :, :self.modes1, :self.modes2] = \ 62 | self.compl_mul2d(x_ft[:, :, :self.modes1, :self.modes2], self.weights1) 63 | out_ft[:, :, -self.modes1:, :self.modes2] = \ 64 | self.compl_mul2d(x_ft[:, :, -self.modes1:, :self.modes2], self.weights2) 65 | 66 | #Return to physical space 67 | x = torch.fft.irfft2(out_ft, s=(x.size(-2), x.size(-1))) 68 | return x 69 | 70 | class FNO2d(nn.Module): 71 | def __init__(self, modes1, modes2, width, indices=None): 72 | super(FNO2d, self).__init__() 73 | 74 | """ 75 | The overall network. It contains 4 layers of the Fourier layer. 76 | 1. Lift the input to the desire channel dimension by self.fc0 . 77 | 2. 4 layers of the integral operators u' = (W + K)(u). 78 | W defined by self.w; K defined by self.conv . 79 | 3. Project from the channel space to the output space by self.fc1 and self.fc2 . 
80 | 81 | input: the solution of the coefficient function and locations (a(x, y), x, y) 82 | input shape: (batchsize, x=s, y=s, c=3) 83 | output: the solution 84 | output shape: (batchsize, x=s, y=s, c=1) 85 | """ 86 | 87 | self.modes1 = modes1 88 | self.modes2 = modes2 89 | self.width = width 90 | self.fc0 = nn.Linear(4, self.width) 91 | 92 | self.conv0 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2) 93 | self.conv1 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2) 94 | self.conv2 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2) 95 | self.conv3 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2) 96 | self.w0 = nn.Conv1d(self.width, self.width, 1) 97 | self.w1 = nn.Conv1d(self.width, self.width, 1) 98 | self.w2 = nn.Conv1d(self.width, self.width, 1) 99 | self.w3 = nn.Conv1d(self.width, self.width, 1) 100 | 101 | 102 | self.fc1 = nn.Linear(self.width, 128) 103 | self.fc2 = nn.Linear(128, 2) 104 | 105 | self.indices = indices 106 | 107 | def forward(self, x): 108 | batchsize = x.shape[0] 109 | size_x, size_y = x.shape[1], x.shape[2] 110 | 111 | x = self.fc0(x) 112 | x = x.permute(0, 3, 1, 2) 113 | 114 | x1 = self.conv0(x) 115 | x2 = self.w0(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y) 116 | x = x1 + x2 117 | x = F.relu(x) 118 | 119 | x1 = self.conv1(x) 120 | x2 = self.w1(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y) 121 | x = x1 + x2 122 | x = F.relu(x) 123 | 124 | x1 = self.conv2(x) 125 | x2 = self.w2(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y) 126 | x = x1 + x2 127 | x = F.relu(x) 128 | 129 | x1 = self.conv3(x) 130 | x2 = self.w3(x.view(batchsize, self.width, -1)).view(batchsize, self.width, size_x, size_y) 131 | x = x1 + x2 132 | 133 | x = x.permute(0, 2, 3, 1) 134 | x = self.fc1(x) 135 | x = F.relu(x) 136 | x = self.fc2(x) 137 | return x 138 | 139 | 140 | ################################################################ 141 | # configs 142 | ################################################################ 143 | ntrain = 60000 144 | ntest = 10000 145 | 146 | batch_size = 100 147 | learning_rate = 0.001 148 | 149 | epochs = 200 150 | step_size = 100 151 | gamma = 0.5 152 | 153 | modes = 12 154 | width = 32 155 | h = 28 156 | 157 | r = 1 158 | sub = 1 159 | sub1 = 1 160 | s = h 161 | P = 56 162 | ind = 9 163 | ################################################################ 164 | # load data and data normalization 165 | ################################################################ 166 | idxT = [11] 167 | d = np.load("../Data/MMNIST_dataset_train.npz") 168 | dispx_allsteps_train = d["dispx_allsteps_train"][:ntrain,idxT,::sub,::sub,None][:,-1,:,:,:] 169 | dispy_allsteps_train = d["dispy_allsteps_train"][:ntrain,idxT,::sub,::sub,None][:,-1,:,:,:] 170 | u_trainx = d["dispx_allsteps_train"][:ntrain,7,::sub,::sub,None] 171 | u_trainy = d["dispy_allsteps_train"][:ntrain,7,::sub,::sub,None] 172 | 173 | d = np.load("../Data/MMNIST_dataset_test.npz") 174 | dispx_allsteps_test = d["dispx_allsteps_test"][:ntest,idxT,::sub,::sub,None][:,-1,:,:,:] 175 | dispy_allsteps_test = d["dispy_allsteps_test"][:ntest,idxT,::sub,::sub,None][:,-1,:,:,:] 176 | u_testx = d["dispx_allsteps_test"][:ntest,7,::sub,::sub,None] 177 | u_testy = d["dispy_allsteps_test"][:ntest,7,::sub,::sub,None] 178 | 179 | S_train = np.concatenate((dispx_allsteps_train,dispy_allsteps_train),axis=-1) 180 | S_test = 
np.concatenate((dispx_allsteps_test,dispy_allsteps_test),axis=-1) 181 | 182 | U_train = np.concatenate((u_trainx,u_trainy),axis=-1) 183 | U_test = np.concatenate((u_testx,u_testy),axis=-1) 184 | 185 | tag = "CN" 186 | # in_noise_train = 0.15*np.random.normal(loc=0.0, scale=1.0, size=(U_train.shape)) 187 | # U_train = U_train + in_noise_train 188 | 189 | dtype_double = torch.FloatTensor 190 | cdtype_double = torch.cuda.DoubleTensor 191 | U_train = torch.from_numpy(np.asarray(U_train)).type(dtype_double) 192 | S_train = torch.from_numpy(np.asarray(S_train)).type(dtype_double) 193 | 194 | U_test = torch.from_numpy(np.asarray(U_test)).type(dtype_double) 195 | S_test = torch.from_numpy(np.asarray(S_test)).type(dtype_double) 196 | 197 | x_train = U_train 198 | y_train = S_train 199 | 200 | x_test = U_test 201 | y_test = S_test 202 | 203 | ########################################### 204 | grids = [] 205 | grids.append(np.linspace(0, 1, s)) 206 | grids.append(np.linspace(0, 1, s)) 207 | grid = np.vstack([xx.ravel() for xx in np.meshgrid(*grids)]).T 208 | grid = grid.reshape(1,s,s,2) 209 | grid = torch.tensor(grid, dtype=torch.float) 210 | x_train = torch.cat([x_train.reshape(ntrain,s,s,2), grid.repeat(ntrain,1,1,1)], dim=3) 211 | x_test = torch.cat([x_test.reshape(ntest,s,s,2), grid.repeat(ntest,1,1,1)], dim=3) 212 | 213 | ind_train = torch.randint(s*s, (ntrain, P)) 214 | ind_test = torch.randint(s*s, (ntest, P)) 215 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train, ind_train), batch_size=batch_size, shuffle=True) 216 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test, ind_test), batch_size=batch_size, shuffle=True) 217 | 218 | ################################################################ 219 | # training and evaluation 220 | ################################################################ 221 | 222 | batch_ind = torch.arange(batch_size).reshape(-1, 1).repeat(1, P) 223 | model = FNO2d(modes, modes, width).cuda() 224 | print(count_params(model)) 225 | 226 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4) 227 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma) 228 | 229 | myloss = LpLoss(size_average=False) 230 | 231 | start_time = timeit.default_timer() 232 | for ep in range(epochs): 233 | model.train() 234 | t1 = default_timer() 235 | train_l2 = 0 236 | for x, y, idx in train_loader: 237 | x, y = x.cuda(), y.cuda() 238 | 239 | optimizer.zero_grad() 240 | out = model(x).reshape(batch_size, s*s,2) 241 | y = y.reshape(batch_size, s*s,2) 242 | y = y[batch_ind, idx] 243 | out = out[batch_ind, idx] 244 | 245 | loss = myloss(out.view(batch_size,-1), y.view(batch_size,-1)) 246 | loss.backward() 247 | 248 | optimizer.step() 249 | train_l2 += loss.item() 250 | 251 | scheduler.step() 252 | 253 | model.eval() 254 | test_l2 = 0.0 255 | with torch.no_grad(): 256 | for x, y, idx in test_loader: 257 | x, y = x.cuda(), y.cuda() 258 | 259 | out = model(x).reshape(batch_size, s*s,2) 260 | y = y.reshape(batch_size, s*s,2) 261 | y = y[batch_ind, idx] 262 | out = out[batch_ind, idx] 263 | 264 | test_l2 += myloss(out.view(batch_size,-1), y.view(batch_size,-1)).item() 265 | 266 | train_l2/= ntrain 267 | test_l2 /= ntest 268 | 269 | t2 = default_timer() 270 | 271 | print(ep, t2-t1, train_l2, test_l2) 272 | 273 | elapsed = timeit.default_timer() - start_time 274 | print("The training wall-clock time is seconds is equal to %f seconds"%elapsed) 275 | 276 | sub = 1 
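# Noise-robustness evaluation: the block below reloads the test split, corrupts the input
# displacement fields with additive Gaussian noise scaled by 0.15, and reports per-sample
# relative L2 errors for the u and v displacement components.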
277 | s = 28 278 | d = np.load("/scratch/gkissas/MMNIST_dataset_test.npz") 279 | dispx_allsteps_test = d["dispx_allsteps_test"][:ntest,idxT,::sub,::sub,None][:,-1,:,:,:] 280 | dispy_allsteps_test = d["dispy_allsteps_test"][:ntest,idxT,::sub,::sub,None][:,-1,:,:,:] 281 | u_testx = d["dispx_allsteps_test"][:ntest,7,::sub,::sub,None] 282 | u_testy = d["dispy_allsteps_test"][:ntest,7,::sub,::sub,None] 283 | 284 | S_test = np.concatenate((dispx_allsteps_test,dispy_allsteps_test),axis=-1) 285 | U_test = np.concatenate((u_testx,u_testy),axis=-1) 286 | 287 | in_noise_test = 0.15*np.random.normal(loc=0.0, scale=1.0, size=(U_test.shape)) 288 | U_test = U_test + in_noise_test 289 | 290 | dtype_double = torch.FloatTensor 291 | cdtype_double = torch.cuda.DoubleTensor 292 | 293 | U_test = torch.from_numpy(np.asarray(U_test)).type(dtype_double) 294 | S_test = torch.from_numpy(np.asarray(S_test)).type(dtype_double) 295 | 296 | x_test = U_test 297 | y_test = S_test 298 | 299 | grids = [] 300 | grids.append(np.linspace(0, 1, s)) 301 | grids.append(np.linspace(0, 1, s)) 302 | grid = np.vstack([xx.ravel() for xx in np.meshgrid(*grids)]).T 303 | grid = grid.reshape(1,s,s,2) 304 | grid = torch.tensor(grid, dtype=torch.float) 305 | x_test = torch.cat([x_test.reshape(ntest,s,s,2), grid.repeat(ntest,1,1,1)], dim=3) 306 | 307 | pred_torch = torch.zeros(S_test.shape) 308 | baseline_torch = torch.zeros(S_test.shape) 309 | index = 0 310 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=1, shuffle=True) 311 | test_error_u = [] 312 | test_error_u_np = [] 313 | test_error_v_np = [] 314 | with torch.no_grad(): 315 | for x, y in test_loader: 316 | test_l2 = 0 317 | x, y= x.cuda(), y.cuda() 318 | 319 | out = model(x).reshape(1, s, s,2) 320 | # out = y_normalizer.decode(out) 321 | pred_torch[index,:,:] = out[:,:,:,:] 322 | baseline_torch[index,:,:] = y[:,:,:,:] 323 | 324 | test_l2 += myloss(out.view(1, -1), y.view(1, -1)).item() 325 | test_error_u.append(test_l2) 326 | test_error_u_np.append(np.linalg.norm(y.cpu().numpy().reshape(S_test.shape[1]*S_test.shape[1],2)[:,0]- out.cpu().numpy().reshape(S_test.shape[1]*S_test.shape[1],2)[:,0],2)/np.linalg.norm(y.cpu().numpy().reshape(S_test.shape[1]*S_test.shape[2],2)[:,0],2)) 327 | test_error_v_np.append(np.linalg.norm(y.cpu().numpy().reshape(S_test.shape[1]*S_test.shape[1],2)[:,1]- out.cpu().numpy().reshape(S_test.shape[1]*S_test.shape[1],2)[:,1],2)/np.linalg.norm(y.cpu().numpy().reshape(S_test.shape[1]*S_test.shape[2],2)[:,1],2)) 328 | # print(index, test_l2) 329 | index = index + 1 330 | 331 | print("The average test u error is %e the standard deviation is %e the min error is %e and the max error is %e"%(np.mean(test_error_u_np),np.std(test_error_u_np),np.min(test_error_u_np),np.max(test_error_u_np))) 332 | print("The average test v error is %e the standard deviation is %e the min error is %e and the max error is %e"%(np.mean(test_error_v_np),np.std(test_error_v_np),np.min(test_error_v_np),np.max(test_error_v_np))) -------------------------------------------------------------------------------- /MMNIST/FNO/utilities3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import scipy.io 4 | import h5py 5 | import torch.nn as tnn 6 | 7 | import operator 8 | from functools import reduce 9 | from functools import partial 10 | import os 11 | 12 | ################################################ 13 | # 14 | # Utilities 15 | # 16 | 
################################################# 17 | def get_freer_gpu(): 18 | os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp') 19 | memory_available = [int(x.split()[2]) for x in open('tmp', 'r').readlines()] 20 | return str(np.argmax(memory_available)) 21 | 22 | os.environ['CUDA_VISIBLE_DEVICES']= get_freer_gpu() 23 | os.environ['CUDA_VISIBLE_DEVICES']= "6" 24 | 25 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 26 | 27 | # reading data 28 | class MatReader(object): 29 | def __init__(self, file_path, to_torch=True, to_cuda=False, to_float=True): 30 | super(MatReader, self).__init__() 31 | 32 | self.to_torch = to_torch 33 | self.to_cuda = to_cuda 34 | self.to_float = to_float 35 | 36 | self.file_path = file_path 37 | 38 | self.data = None 39 | self.old_mat = None 40 | self._load_file() 41 | 42 | def _load_file(self): 43 | try: 44 | self.data = scipy.io.loadmat(self.file_path) 45 | self.old_mat = True 46 | except: 47 | self.data = h5py.File(self.file_path) 48 | self.old_mat = False 49 | 50 | def load_file(self, file_path): 51 | self.file_path = file_path 52 | self._load_file() 53 | 54 | def read_field(self, field): 55 | x = self.data[field] 56 | 57 | if not self.old_mat: 58 | x = x[()] 59 | x = np.transpose(x, axes=range(len(x.shape) - 1, -1, -1)) 60 | 61 | if self.to_float: 62 | x = x.astype(np.float32) 63 | 64 | if self.to_torch: 65 | x = torch.from_numpy(x) 66 | 67 | if self.to_cuda: 68 | x = x.cuda() 69 | 70 | return x 71 | 72 | def set_cuda(self, to_cuda): 73 | self.to_cuda = to_cuda 74 | 75 | def set_torch(self, to_torch): 76 | self.to_torch = to_torch 77 | 78 | def set_float(self, to_float): 79 | self.to_float = to_float 80 | 81 | # normalization, pointwise gaussian 82 | class UnitGaussianNormalizer(object): 83 | def __init__(self, x, eps=0.00001): 84 | super(UnitGaussianNormalizer, self).__init__() 85 | 86 | # x could be in shape of ntrain*n or ntrain*T*n or ntrain*n*T 87 | self.mean = torch.mean(x, 0) 88 | self.std = torch.std(x, 0) 89 | self.eps = eps 90 | 91 | def encode(self, x): 92 | x = (x - self.mean) / (self.std + self.eps) 93 | return x 94 | 95 | def decode(self, x, sample_idx=None): 96 | if sample_idx is None: 97 | std = self.std + self.eps # n 98 | mean = self.mean 99 | else: 100 | if len(self.mean.shape) == len(sample_idx[0].shape): 101 | std = self.std[sample_idx] + self.eps # batch*n 102 | mean = self.mean[sample_idx] 103 | if len(self.mean.shape) > len(sample_idx[0].shape): 104 | std = self.std[:,sample_idx]+ self.eps # T*batch*n 105 | mean = self.mean[:,sample_idx] 106 | 107 | # x is in shape of batch*n or T*batch*n 108 | x = (x * std) + mean 109 | return x 110 | 111 | def cuda(self): 112 | self.mean = self.mean.cuda() 113 | self.std = self.std.cuda() 114 | 115 | def cpu(self): 116 | self.mean = self.mean.cpu() 117 | self.std = self.std.cpu() 118 | 119 | # normalization, Gaussian 120 | class GaussianNormalizer(object): 121 | def __init__(self, x, eps=0.00001): 122 | super(GaussianNormalizer, self).__init__() 123 | 124 | self.mean = torch.mean(x) 125 | self.std = torch.std(x) 126 | self.eps = eps 127 | 128 | def encode(self, x): 129 | x = (x - self.mean) / (self.std + self.eps) 130 | return x 131 | 132 | def decode(self, x, sample_idx=None): 133 | x = (x * (self.std + self.eps)) + self.mean 134 | return x 135 | 136 | def cuda(self): 137 | self.mean = self.mean.cuda() 138 | self.std = self.std.cuda() 139 | 140 | def cpu(self): 141 | self.mean = self.mean.cpu() 142 | self.std = self.std.cpu() 143 | 144 | 145 | # normalization, 
scaling by range 146 | class RangeNormalizer(object): 147 | def __init__(self, x, low=0.0, high=1.0): 148 | super(RangeNormalizer, self).__init__() 149 | mymin = torch.min(x, 0)[0].view(-1) 150 | mymax = torch.max(x, 0)[0].view(-1) 151 | 152 | self.a = (high - low)/(mymax - mymin) 153 | self.b = -self.a*mymax + high 154 | 155 | def encode(self, x): 156 | s = x.size() 157 | x = x.view(s[0], -1) 158 | x = self.a*x + self.b 159 | x = x.view(s) 160 | return x 161 | 162 | def decode(self, x): 163 | s = x.size() 164 | x = x.view(s[0], -1) 165 | x = (x - self.b)/self.a 166 | x = x.view(s) 167 | return x 168 | 169 | #loss function with rel/abs Lp loss 170 | class LpLoss(object): 171 | def __init__(self, d=2, p=2, size_average=True, reduction=True): 172 | super(LpLoss, self).__init__() 173 | 174 | #Dimension and Lp-norm type are postive 175 | assert d > 0 and p > 0 176 | 177 | self.d = d 178 | self.p = p 179 | self.reduction = reduction 180 | self.size_average = size_average 181 | 182 | print(self.d, self.p, self.reduction, self.size_average) 183 | 184 | def abs(self, x, y): 185 | num_examples = x.size()[0] 186 | 187 | 188 | 189 | #Assume uniform mesh 190 | h = 1.0 / (x.size()[1] - 1.0) 191 | 192 | all_norms = (h**(self.d/self.p))*torch.norm(x.view(num_examples,-1) - y.view(num_examples,-1), self.p, 1) 193 | 194 | if self.reduction: 195 | if self.size_average: 196 | return torch.mean(all_norms) 197 | else: 198 | return torch.sum(all_norms) 199 | 200 | return all_norms 201 | 202 | def rel(self, x, y): 203 | num_examples = x.size()[0] # 100 x 64*64 x 10 204 | diff_norms = torch.norm(x.reshape(num_examples,-1) - y.reshape(num_examples,-1), self.p, 1) 205 | y_norms = torch.norm(y.reshape(num_examples,-1), self.p, 1) 206 | if self.reduction: 207 | if self.size_average: 208 | return torch.mean(diff_norms/y_norms) 209 | else: 210 | return torch.sum(diff_norms/y_norms) 211 | return diff_norms/y_norms 212 | 213 | def __call__(self, x, y): 214 | return self.rel(x, y) 215 | 216 | # Sobolev norm (HS norm) 217 | # where we also compare the numerical derivatives between the output and target 218 | class HsLoss(object): 219 | def __init__(self, d=2, p=2, k=1, a=None, group=False, size_average=True, reduction=True): 220 | super(HsLoss, self).__init__() 221 | 222 | #Dimension and Lp-norm type are postive 223 | assert d > 0 and p > 0 224 | 225 | self.d = d 226 | self.p = p 227 | self.k = k 228 | self.balanced = group 229 | self.reduction = reduction 230 | self.size_average = size_average 231 | 232 | if a == None: 233 | a = [1,] * k 234 | self.a = a 235 | 236 | def rel(self, x, y): 237 | num_examples = x.size()[0] 238 | diff_norms = torch.norm(x.reshape(num_examples,-1) - y.reshape(num_examples,-1), self.p, 1) 239 | y_norms = torch.norm(y.reshape(num_examples,-1), self.p, 1) 240 | if self.reduction: 241 | if self.size_average: 242 | return torch.mean(diff_norms/y_norms) 243 | else: 244 | return torch.sum(diff_norms/y_norms) 245 | return diff_norms/y_norms 246 | 247 | def __call__(self, x, y, a=None): 248 | nx = x.size()[1] 249 | ny = x.size()[2] 250 | k = self.k 251 | balanced = self.balanced 252 | a = self.a 253 | x = x.view(x.shape[0], nx, ny, -1) 254 | y = y.view(y.shape[0], nx, ny, -1) 255 | 256 | k_x = torch.cat((torch.arange(start=0, end=nx//2, step=1),torch.arange(start=-nx//2, end=0, step=1)), 0).reshape(nx,1).repeat(1,ny) 257 | k_y = torch.cat((torch.arange(start=0, end=ny//2, step=1),torch.arange(start=-ny//2, end=0, step=1)), 0).reshape(1,ny).repeat(nx,1) 258 | k_x = 
torch.abs(k_x).reshape(1,nx,ny,1).to(x.device) 259 | k_y = torch.abs(k_y).reshape(1,nx,ny,1).to(x.device) 260 | 261 | x = torch.fft.fftn(x, dim=[1, 2]) 262 | y = torch.fft.fftn(y, dim=[1, 2]) 263 | 264 | if balanced==False: 265 | weight = 1 266 | if k >= 1: 267 | weight += a[0]**2 * (k_x**2 + k_y**2) 268 | if k >= 2: 269 | weight += a[1]**2 * (k_x**4 + 2*k_x**2*k_y**2 + k_y**4) 270 | weight = torch.sqrt(weight) 271 | loss = self.rel(x*weight, y*weight) 272 | else: 273 | loss = self.rel(x, y) 274 | if k >= 1: 275 | weight = a[0] * torch.sqrt(k_x**2 + k_y**2) 276 | loss += self.rel(x*weight, y*weight) 277 | if k >= 2: 278 | weight = a[1] * torch.sqrt(k_x**4 + 2*k_x**2*k_y**2 + k_y**4) 279 | loss += self.rel(x*weight, y*weight) 280 | loss = loss / (k+1) 281 | 282 | return loss 283 | 284 | # A simple feedforward neural network 285 | class DenseNet(torch.nn.Module): 286 | def __init__(self, layers, nonlinearity, out_nonlinearity=None, normalize=False): 287 | super(DenseNet, self).__init__() 288 | 289 | self.n_layers = len(layers) - 1 290 | 291 | assert self.n_layers >= 1 292 | 293 | self.layers = tnn.ModuleList() 294 | 295 | for j in range(self.n_layers): 296 | self.layers.append(tnn.Linear(layers[j], layers[j+1])) 297 | 298 | if j != self.n_layers - 1: 299 | if normalize: 300 | self.layers.append(tnn.BatchNorm1d(layers[j+1])) 301 | 302 | self.layers.append(nonlinearity()) 303 | 304 | if out_nonlinearity is not None: 305 | self.layers.append(out_nonlinearity()) 306 | 307 | def forward(self, x): 308 | for _, l in enumerate(self.layers): 309 | x = l(x) 310 | 311 | return x 312 | 313 | 314 | # print the number of parameters 315 | def count_params(model): 316 | c = 0 317 | for p in list(model.parameters()): 318 | c += reduce(operator.mul, list(p.size())) 319 | return c -------------------------------------------------------------------------------- /PushForward/DeepONet/DeepONet_Pushforward.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from numpy.polynomial import polyutils 5 | 6 | from jax.experimental.stax import Dense, Gelu 7 | from jax.experimental import stax 8 | import os 9 | 10 | from scipy.integrate import solve_ivp 11 | 12 | import timeit 13 | 14 | from jax.experimental import optimizers 15 | 16 | from absl import app 17 | import jax 18 | import jax.numpy as jnp 19 | import numpy as np 20 | from jax.numpy.linalg import norm 21 | 22 | from jax import random, grad, vmap, jit, vjp 23 | from functools import partial 24 | 25 | from torch.utils import data 26 | 27 | from tqdm import trange 28 | 29 | import itertools 30 | 31 | def get_freer_gpu(): 32 | os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp') 33 | memory_available = [int(x.split()[2]) for x in open('tmp', 'r').readlines()] 34 | return str(np.argmax(memory_available)) 35 | 36 | os.environ['CUDA_VISIBLE_DEVICES']= get_freer_gpu() 37 | os.environ['XLA_PYTHON_CLIENT_PREALLOCATE']="False" 38 | 39 | class DataGenerator(data.Dataset): 40 | def __init__(self, u, y, s, 41 | batch_size=100, rng_key=random.PRNGKey(1234)): 42 | 'Initialization' 43 | self.u = u 44 | self.y = y 45 | self.s = s 46 | self.N = u.shape[0] 47 | self.batch_size = batch_size 48 | self.key = rng_key 49 | 50 | # @partial(jit, static_argnums=(0,)) 51 | def __getitem__(self, index): 52 | 'Generate one batch of data' 53 | self.key, subkey = random.split(self.key) 54 | inputs,outputs = 
self.__data_generation(subkey) 55 | return inputs, outputs 56 | 57 | @partial(jit, static_argnums=(0,)) 58 | def __data_generation(self, key): 59 | 'Generates data containing batch_size samples' 60 | idx = random.choice(key, self.N, (self.batch_size,), replace=False) 61 | s = self.s[idx,:,:] 62 | u = self.u[idx,:,:] 63 | y = self.y[idx,:,:] 64 | inputs = (u, y) 65 | return inputs, s 66 | 67 | class PositionalEncodingY: 68 | def __init__(self, Y, d_model, max_len = 100,H=20): 69 | self.d_model = d_model 70 | self.Y = Y 71 | self.max_len = max_len 72 | self.H = H 73 | 74 | @partial(jit, static_argnums=(0,)) 75 | def forward(self, x): 76 | self.pe = np.zeros((x.shape[0], self.max_len, self.H)) 77 | T = jnp.asarray(self.Y[:,:,0:1]) 78 | position = jnp.tile(T,(1,1,self.H)) 79 | div_term = 2**jnp.arange(0,int(self.H/2),1)*jnp.pi 80 | self.pe = jax.ops.index_update(self.pe, jax.ops.index[:,:,0::2], jnp.cos(position[:,:,0::2] * div_term)) 81 | self.pe = jax.ops.index_update(self.pe, jax.ops.index[:,:,1::2], jnp.sin(position[:,:,1::2] * div_term)) 82 | x = jnp.concatenate([x, self.pe],axis=-1) 83 | return x 84 | 85 | class PositionalEncodingU: 86 | def __init__(self, Y, d_model, max_len = 100,H=20): 87 | self.d_model = d_model 88 | self.Y = Y 89 | self.max_len = max_len 90 | self.H = H 91 | 92 | @partial(jit, static_argnums=(0,)) 93 | def forward(self, x): 94 | self.pe = np.zeros((x.shape[0], self.max_len, self.H)) 95 | T = jnp.asarray(self.Y[:,:,0:1]) 96 | position = jnp.tile(T,(1,1,self.H)) 97 | div_term = 2**jnp.arange(0,int(self.H/2),1)*jnp.pi 98 | self.pe = jax.ops.index_update(self.pe, jax.ops.index[:,:,0::2], jnp.cos(position[:,:,0::2] * div_term)) 99 | self.pe = jax.ops.index_update(self.pe, jax.ops.index[:,:,1::2], jnp.sin(position[:,:,1::2] * div_term)) 100 | x = jnp.concatenate([x, self.pe],axis=-1) 101 | return x 102 | 103 | class DON: 104 | def __init__(self,branch_layers, trunk_layers , m=100, P=100, mn=None, std=None): 105 | # Network initialization and evaluation functions 106 | 107 | seed = np.random.randint(10000) 108 | self.branch_init, self.branch_apply = self.init_NN(branch_layers, activation=Gelu) 109 | self.in_shape = (-1, branch_layers[0]) 110 | self.out_shape, branch_params = self.branch_init(random.PRNGKey(seed), self.in_shape) 111 | 112 | seed = np.random.randint(10000) 113 | self.trunk_init, self.trunk_apply = self.init_NN(trunk_layers, activation=Gelu) 114 | self.in_shape = (-1, trunk_layers[0]) 115 | self.out_shape, trunk_params = self.trunk_init(random.PRNGKey(seed), self.in_shape) 116 | 117 | params = (trunk_params, branch_params) 118 | # Use optimizers to set optimizer initialization and update functions 119 | self.opt_init,self.opt_update,self.get_params = optimizers.adam(optimizers.exponential_decay(1e-3, 120 | decay_steps=100, 121 | decay_rate=0.99)) 122 | self.opt_state = self.opt_init(params) 123 | # Logger 124 | self.itercount = itertools.count() 125 | self.loss_log = [] 126 | self.mean = mn 127 | self.std = std 128 | 129 | 130 | def init_NN(self, Q, activation=Gelu): 131 | layers = [] 132 | num_layers = len(Q) 133 | if num_layers < 2: 134 | net_init, net_apply = stax.serial() 135 | else: 136 | for i in range(0, num_layers-1): 137 | layers.append(Dense(Q[i+1])) 138 | layers.append(activation) 139 | layers.append(Dense(Q[-1])) 140 | net_init, net_apply = stax.serial(*layers) 141 | return net_init, net_apply 142 | 143 | @partial(jax.jit, static_argnums=0) 144 | def DON(self, params, inputs, ds=1): 145 | trunk_params, branch_params = params 146 | u, y = inputs 
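        # Evaluate the trunk network on the (positionally encoded) query points y and the branch
        # network on the flattened input function u, then contract the two factorized outputs:
        # Guy[i, j, k] = sum_l t[i, j, k, l] * b[i, l, k].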
147 | print(u.shape, y.shape) 148 | t = self.trunk_apply(trunk_params, y).reshape(y.shape[0], y.shape[1], ds, int(100/ds)) 149 | b = self.branch_apply(branch_params, u.reshape(u.shape[0],1,u.shape[1]*u.shape[2])) 150 | b = b.reshape(b.shape[0],int(b.shape[2]/ds),ds) 151 | Guy = jnp.einsum("ijkl,ilk->ijk", t,b) 152 | return Guy 153 | 154 | @partial(jax.jit, static_argnums=0) 155 | def loss(self, params, batch): 156 | inputs, y = batch 157 | y_pred = self.DON(params,inputs) 158 | y = y*self.std + self.mean 159 | y_pred = y_pred*self.std + self.mean 160 | loss = np.mean((y.flatten() - y_pred.flatten())**2) 161 | return loss 162 | 163 | @partial(jax.jit, static_argnums=0) 164 | def lossT(self, params, batch): 165 | inputs, outputs = batch 166 | y_pred = self.DON(params,inputs) 167 | y_pred = y_pred*self.std + self.mean 168 | loss = np.mean((outputs.flatten() - y_pred.flatten())**2) 169 | return loss 170 | 171 | @partial(jax.jit, static_argnums=0) 172 | def L2errorT(self, params, batch): 173 | inputs, y = batch 174 | y_pred = self.DON(params,inputs) 175 | y_pred = y_pred*self.std + self.mean 176 | return norm(y.flatten() - y_pred.flatten(), 2)/norm(y.flatten(),2) 177 | 178 | @partial(jax.jit, static_argnums=0) 179 | def L2error(self, params, batch): 180 | inputs, y = batch 181 | y_pred = self.DON(params,inputs) 182 | y = y*self.std + self.mean 183 | y_pred = y_pred*self.std + self.mean 184 | return norm(y.flatten() - y_pred.flatten(), 2)/norm(y.flatten(),2) 185 | 186 | 187 | @partial(jit, static_argnums=(0,)) 188 | def step(self, i, opt_state, batch): 189 | params = self.get_params(opt_state) 190 | g = grad(self.loss)(params, batch) 191 | return self.opt_update(i, g, opt_state) 192 | 193 | def train(self, train_dataset, test_dataset, nIter = 10000): 194 | train_data = iter(train_dataset) 195 | test_data = iter(test_dataset) 196 | 197 | pbar = trange(nIter) 198 | for it in pbar: 199 | train_batch = next(train_data) 200 | test_batch = next(test_data) 201 | 202 | self.opt_state = self.step(next(self.itercount), self.opt_state, train_batch) 203 | 204 | if it % 100 == 0: 205 | params = self.get_params(self.opt_state) 206 | 207 | loss_train = self.loss(params, train_batch) 208 | loss_test = self.lossT(params, test_batch) 209 | 210 | errorTrain = self.L2error(params, train_batch) 211 | errorTest = self.L2errorT(params, test_batch) 212 | 213 | self.loss_log.append(loss_train) 214 | 215 | pbar.set_postfix({'Training loss': loss_train, 216 | 'Testing loss' : loss_test, 217 | 'Test error': errorTest, 218 | 'Train error': errorTrain}) 219 | 220 | @partial(jit, static_argnums=(0,)) 221 | def predict(self, params, inputs): 222 | s_pred = self.DON(params,inputs) 223 | return s_pred 224 | 225 | @partial(jit, static_argnums=(0,)) 226 | def predictT(self, params, inputs): 227 | s_pred = self.DON(params,inputs) 228 | return s_pred 229 | 230 | def ravel_list(self, *lst): 231 | return jnp.concatenate([jnp.ravel(elt) for elt in lst]) if lst else jnp.array([]) 232 | 233 | def ravel_pytree(self, pytree): 234 | leaves, treedef = jax.tree_util.tree_flatten(pytree) 235 | flat, unravel_list = vjp(self.ravel_list, *leaves) 236 | unravel_pytree = lambda flat: jax.tree_util.tree_unflatten(treedef, unravel_list(flat)) 237 | return flat, unravel_pytree 238 | 239 | def count_params(self, params): 240 | trunk_params, branch_params = params 241 | blv, _ = self.ravel_pytree(branch_params) 242 | tlv, _ = self.ravel_pytree(trunk_params) 243 | print("The number of model parameters is:",blv.shape[0]+tlv.shape[0]) 244 | 245 | 
TRAINING_ITERATIONS = 50000 246 | P = 300 247 | m = 300 248 | num_train = 1000 249 | num_test = 1000 250 | training_batch_size = 100 251 | du = 1 252 | dy = 1 253 | ds = 1 254 | n_hat = 100 255 | Nx = P 256 | index = 9 257 | length_scale = 0.9 258 | H_y = 10 259 | H_u = 10 260 | 261 | d = np.load("../Data/train_pushforward.npz") 262 | U_train = d["U_train"] 263 | x_train = d["x_train"] 264 | y_train = d["y_train"] 265 | s_train = d["s_train"] 266 | 267 | d = np.load("../Data/test_pushforward.npz") 268 | U_test = d["U_test"] 269 | x_test = d["x_test"] 270 | y_test = d["y_test"] 271 | s_test = d["s_test"] 272 | 273 | y_train = jnp.asarray(y_train) 274 | s_train = jnp.asarray(s_train) 275 | U_train = jnp.asarray(U_train) 276 | 277 | y_test = jnp.asarray(y_test) 278 | s_test = jnp.asarray(s_test) 279 | U_test = jnp.asarray(U_test) 280 | 281 | U_train = jnp.reshape(U_train,(num_test,m,du)) 282 | y_train = jnp.reshape(y_train,(num_train,P,dy)) 283 | s_train = jnp.reshape(s_train,(num_train,P,ds)) 284 | 285 | U_test = jnp.reshape(U_test,(num_test,m,du)) 286 | y_test = jnp.reshape(y_test,(num_test,P,dy)) 287 | s_test = jnp.reshape(s_test,(num_test,P,ds)) 288 | 289 | plot=False 290 | if plot: 291 | import matplotlib.pyplot as plt 292 | pltN = 10 293 | for i in range(0,pltN-1): 294 | plt.plot(y_train[i,:,0], s_train[i,:,0], 'r-') 295 | plt.plot(y_test[i,:,0], s_test[i,:,0], 'b-') 296 | 297 | plt.plot(y_train[pltN,:,0], s_train[pltN,:,0], 'r-', label="Training output") 298 | plt.plot(y_test[pltN,:,0], s_test[pltN,:,0], 'b-', label="Testing output") 299 | plt.legend() 300 | plt.show() 301 | 302 | x = jnp.linspace(0,1,num=m) 303 | pltN = 10 304 | for i in range(0,pltN-1): 305 | plt.plot(x, np.asarray(U_train)[i,:,0], 'y-') 306 | plt.plot(x, np.asarray(U_test)[i,:,0], 'g-') 307 | 308 | plt.plot(x, np.asarray(U_train)[pltN,:,0], 'y-', label="Training input") 309 | plt.plot(x, np.asarray(U_test)[pltN,:,0], 'g-', label="Testing input") 310 | plt.legend() 311 | 312 | pos_encodingy = PositionalEncodingY(y_train,int(y_train.shape[1]*y_train.shape[2]), max_len = P, H=H_y) 313 | y_train = pos_encodingy.forward(y_train) 314 | del pos_encodingy 315 | 316 | pos_encodingyt = PositionalEncodingY(y_test,int(y_test.shape[1]*y_test.shape[2]), max_len = P, H=H_y) 317 | y_test = pos_encodingyt.forward(y_test) 318 | del pos_encodingyt 319 | 320 | pos_encodingy = PositionalEncodingU(U_train,int(U_train.shape[1]*U_train.shape[2]), max_len = m, H=H_u) 321 | U_train = pos_encodingy.forward(U_train) 322 | del pos_encodingy 323 | 324 | pos_encodingyt = PositionalEncodingU(U_test,int(U_test.shape[1]*U_test.shape[2]), max_len = m, H=H_u) 325 | U_test = pos_encodingyt.forward(U_test) 326 | del pos_encodingyt 327 | 328 | s_train_mean = 0.#jnp.mean(s_train,axis=0) 329 | s_train_std = 1.#jnp.std(s_train,axis=0) + 1e-03 330 | 331 | s_train = (s_train - s_train_mean)/s_train_std 332 | 333 | # Perform the scattering transform for the inputs yh 334 | train_dataset = DataGenerator(U_train, y_train, s_train, training_batch_size) 335 | train_dataset = iter(train_dataset) 336 | 337 | test_dataset = DataGenerator(U_test, y_test, s_test, training_batch_size) 338 | test_dataset = iter(test_dataset) 339 | 340 | branch_layers = [m*(du*H_u+du), 512, 512, ds*n_hat] 341 | trunk_layers = [H_y*dy + dy, 512, 512, ds*n_hat] 342 | 343 | # branch_layers = [m*du, 512, 512, ds*n_hat] 344 | # trunk_layers = [dy, 512, 512, ds*n_hat] 345 | 346 | 347 | model = DON(branch_layers, trunk_layers, m=m, P=P, mn=s_train_mean, std=s_train_std) 348 | 349 | 
model.count_params(model.get_params(model.opt_state)) 350 | 351 | start_time = timeit.default_timer() 352 | model.train(train_dataset, test_dataset, nIter=TRAINING_ITERATIONS) 353 | elapsed = timeit.default_timer() - start_time 354 | print("The training wall-clock time is seconds is equal to %f seconds"%elapsed) 355 | 356 | params = model.get_params(model.opt_state) 357 | 358 | uCNN_test = model.predictT(params, (U_test, y_test)) 359 | test_error_u = [] 360 | for i in range(0,num_train): 361 | test_error_u.append(norm(s_test[i,:,0]- uCNN_test[i,:,0],2)/norm(s_test[i,:,0],2)) 362 | print("The average test u error is %e the standard deviation is %e the min error is %e and the max error is %e"%(np.mean(test_error_u),np.std(test_error_u),np.min(test_error_u),np.max(test_error_u))) 363 | 364 | uCNN_train = model.predict(params, (U_train, y_train)) 365 | train_error_u = [] 366 | for i in range(0,num_test): 367 | train_error_u.append(norm(s_train[i,:,0]- uCNN_train[i,:,0],2)/norm(s_train[i,:,0],2)) 368 | print("The average train u error is %e the standard deviation is %e the min error is %e and the max error is %e"%(np.mean(train_error_u),np.std(train_error_u),np.min(train_error_u),np.max(train_error_u))) 369 | 370 | trunk_params, branch_params = params 371 | t = model.trunk_apply(trunk_params, y_test).reshape(y_test.shape[0], y_test.shape[1], ds, int(n_hat/ds)) 372 | 373 | def minmax(a): 374 | minpos = a.index(min(a)) 375 | print("The minimum is at position", minpos) 376 | return minpos 377 | minpos = minmax(train_error_u) 378 | 379 | np.savez_compressed("eigenfunctions_DON3.npz", efuncs=t[minpos,:,0,:]) 380 | -------------------------------------------------------------------------------- /PushForward/LOCA/LOCAPushforward.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | from jax.example_libraries.stax import Dense, Gelu 9 | from jax.example_libraries import stax 10 | from jax.example_libraries import optimizers 11 | import os 12 | 13 | import timeit 14 | import numpy as np 15 | from jax.numpy.linalg import norm 16 | 17 | from jax import random, grad, vmap, jit, vjp 18 | from functools import partial 19 | 20 | from torch.utils import data 21 | 22 | from tqdm import trange 23 | 24 | import itertools 25 | 26 | from kymatio.numpy import Scattering1D 27 | from jax.flatten_util import ravel_pytree 28 | 29 | from numpy.polynomial.legendre import leggauss 30 | 31 | def get_freer_gpu(): 32 | os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Used >tmp') 33 | memory_available = [int(x.split()[2]) for x in open('tmp', 'r').readlines()] 34 | return str(np.argmin(memory_available)) 35 | 36 | os.environ['CUDA_VISIBLE_DEVICES']= get_freer_gpu() 37 | os.environ['XLA_PYTHON_CLIENT_PREALLOCATE']="False" 38 | 39 | def pairwise_distances(dist,**arg): 40 | return jit(vmap(vmap(partial(dist,**arg),in_axes=(None,0)),in_axes=(0,None))) 41 | 42 | def euclid_distance(x,y): 43 | XX=jnp.dot(x,x) 44 | YY=jnp.dot(y,y) 45 | XY=jnp.dot(x,y) 46 | return XX+YY-2*XY 47 | 48 | class DataGenerator(data.Dataset): 49 | def __init__(self, inputsxuy, inputsxu, y, s, z, w, 50 | batch_size=100, rng_key=random.PRNGKey(1234)): 51 | 'Initialization' 52 | self.inputsxuy = inputsxuy 53 | self.inputsxu = inputsxu 54 | self.y = y 55 | self.s = s 56 | self.z = z 57 | self.w = w 58 | 59 | self.N = inputsxu.shape[0] 60 | 
self.batch_size = batch_size 61 | self.key = rng_key 62 | 63 | # @partial(jit, static_argnums=(0,)) 64 | def __getitem__(self, index): 65 | 'Generate one batch of data' 66 | self.key, subkey = random.split(self.key) 67 | inputs,outputs = self.__data_generation(subkey) 68 | return inputs, outputs 69 | 70 | @partial(jit, static_argnums=(0,)) 71 | def __data_generation(self, key): 72 | 'Generates data containing batch_size samples' 73 | idx = random.choice(key, self.N, (self.batch_size,), replace=False) 74 | s = self.s[idx,:,:] 75 | inputsxu = self.inputsxu[idx,:,:] 76 | y = self.y[idx,:,:] 77 | z = self.z[idx,:,:] 78 | w = self.w[idx,:,:] 79 | inputs = (inputsxu, y, z, w) 80 | return inputs, s 81 | 82 | class PositionalEncodingY: 83 | def __init__(self, Y, d_model, max_len = 100,H=20): 84 | self.d_model = d_model 85 | self.Y = Y 86 | self.max_len = max_len 87 | self.H = H 88 | 89 | @partial(jit, static_argnums=(0,)) 90 | def forward(self, x): 91 | self.pe = np.zeros((x.shape[0], self.max_len, self.H)) 92 | T = jnp.asarray(self.Y[:,:,0:1]) 93 | position = jnp.tile(T,(1,1,self.H)) 94 | div_term = 2**jnp.arange(0,int(self.H/2),1)*jnp.pi 95 | self.pe = jax.ops.index_update(self.pe, jax.ops.index[:,:,0::2], jnp.cos(position[:,:,0::2] * div_term)) 96 | self.pe = jax.ops.index_update(self.pe, jax.ops.index[:,:,1::2], jnp.sin(position[:,:,1::2] * div_term)) 97 | x = jnp.concatenate([x, self.pe],axis=-1) 98 | return x 99 | 100 | def scatteringTransform(sig, l=100, m=100, training_batch_size = 100): 101 | J = 1 102 | Q = 8 103 | T = sig.shape[1] 104 | scattering = Scattering1D(J, T, Q) 105 | sig = np.asarray(sig) 106 | sctcoef = np.zeros((training_batch_size, 1200, 1)) 107 | for i in range(0,training_batch_size): 108 | sctcoef[i,:,:] = scattering(sig[i,:,0]).flatten()[:,None] 109 | return sctcoef 110 | 111 | class LpLoss(object): 112 | def __init__(self, d=2, p=2): 113 | super(LpLoss, self).__init__() 114 | 115 | self.d = d 116 | self.p = p 117 | 118 | def rel(self, y, x): 119 | num_examples = x.shape[0] 120 | diff_norms = jnp.linalg.norm(y.reshape(num_examples,-1) - x.reshape(num_examples,-1), self.p, 1) 121 | y_norms = jnp.linalg.norm(y.reshape(num_examples,-1), self.p, 1) 122 | return jnp.mean(diff_norms/y_norms) 123 | 124 | def __call__(self, y, x): 125 | return self.rel(y, x) 126 | 127 | class LOCA: 128 | def __init__(self, q_layers, g_layers, v_layers , m=100, P=100, jac_det=None): 129 | # Network initialization and evaluation functions 130 | seed = np.random.randint(10000) 131 | self.q_init, self.q_apply = self.init_NN(q_layers, activation=Gelu) 132 | self.in_shape = (-1, q_layers[0]) 133 | self.out_shape, q_params = self.q_init(random.PRNGKey(seed), self.in_shape) 134 | 135 | seed = np.random.randint(10000) 136 | self.v_init, self.v_apply = self.init_NN(v_layers, activation=Gelu) 137 | self.in_shape = (-1, v_layers[0]) 138 | self.out_shape, v_params = self.v_init(random.PRNGKey(seed), self.in_shape) 139 | 140 | seed = np.random.randint(10000) 141 | self.g_init, self.g_apply = self.init_NN(g_layers, activation=Gelu) 142 | self.in_shape = (-1, g_layers[0]) 143 | self.out_shape, g_params = self.g_init(random.PRNGKey(seed), self.in_shape) 144 | 145 | # RBF kernel parameters 146 | beta = [1.] 147 | gamma = [1.] 
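        # beta and gamma are the kernel hyperparameters used by the coupled attention weights;
        # they are packed into params below and therefore optimized jointly with the network weights.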
148 | # Model parameters 149 | params = (beta, gamma,q_params, g_params, v_params) 150 | 151 | self.opt_init,self.opt_update,self.get_params = optimizers.adam(optimizers.exponential_decay(1e-3, 152 | decay_steps=100, 153 | decay_rate=0.99)) 154 | self.opt_state = self.opt_init(params) 155 | self.itercount = itertools.count() 156 | self.loss_log = [] 157 | 158 | self.l2loss = LpLoss() 159 | 160 | self.jac_det = jac_det 161 | self.vdistance_function = vmap(pairwise_distances(euclid_distance)) 162 | 163 | def init_NN(self, Q, activation=Gelu): 164 | layers = [] 165 | num_layers = len(Q) 166 | if num_layers < 2: 167 | net_init, net_apply = stax.serial() 168 | else: 169 | for i in range(0, num_layers-2): 170 | layers.append(Dense(Q[i+1])) 171 | layers.append(activation) 172 | layers.append(Dense(Q[-1])) 173 | net_init, net_apply = stax.serial(*layers) 174 | return net_init, net_apply 175 | 176 | @partial(jax.jit, static_argnums=0) 177 | def RBF(self, X, Y, gamma, beta): 178 | d = self.vdistance_function(X, Y) 179 | return beta[0]*jnp.exp(-gamma[0]*d) 180 | 181 | @partial(jax.jit, static_argnums=0) 182 | def Matern_32(self, X, Y, gamma=[0.5], beta=[0.5]): 183 | d = self.vdistance_function(X, Y) 184 | return (1 + (jnp.sqrt(3)*gamma[0])*d)*beta[0]*jnp.exp(-(jnp.sqrt(3)*gamma[0])*d) 185 | 186 | @partial(jax.jit, static_argnums=0) 187 | def Matern_52(self, X, Y, gamma=[0.5], beta=[0.5]): 188 | d = self.vdistance_function(X, Y) 189 | return (1 + (jnp.sqrt(5)*gamma[0])*d + (5/3*gamma[0]**2)*d**2)*beta[0]*jnp.exp(-(jnp.sqrt(5)*gamma[0])*d) 190 | 191 | @partial(jax.jit, static_argnums=0) 192 | def periodic(self, X, Y, gamma=[0.5], beta=[0.5], p=0.7): 193 | d = self.vdistance_function(X, Y) 194 | return beta[0]*jnp.exp(-2.0*jnp.sin(jnp.pi*d/p)**2*gamma[0]) 195 | 196 | @partial(jax.jit, static_argnums=0) 197 | def RQK(self, X, Y, gamma, beta): 198 | d = self.vdistance_function(X, Y) 199 | return beta[0]*(1 + (1./(3.*0.1*2))*gamma[0]*d)**(gamma[0]) 200 | 201 | @partial(jax.jit, static_argnums=0) 202 | def local_periodic(self, X, Y, gamma, beta): 203 | return self.periodic(X, Y, gamma, beta)*self.RBF(X, Y, gamma, beta) 204 | 205 | @partial(jax.jit, static_argnums=0) 206 | def LOCA_net(self, params, inputs, ds=1): 207 | beta, gamma, q_params, g_params, v_params = params 208 | u, y, z, w = inputs 209 | y = self.q_apply(q_params,y) 210 | z = self.q_apply(q_params,z) 211 | 212 | K = self.periodic(z, z, gamma, beta) 213 | Kzz = jnp.sqrt(self.jac_det*jnp.einsum("ijk,ikl->ijl",K,w)) 214 | 215 | K = self.periodic(y, z, gamma, beta) 216 | Kyz = jnp.sqrt(self.jac_det*jnp.einsum("ijk,ikl->ijl",K,w)) 217 | 218 | mean_K = jnp.matmul(Kyz, jnp.swapaxes(Kzz,1,2)) 219 | K = jnp.divide(K,mean_K) 220 | 221 | g = self.g_apply(g_params,z) 222 | g = self.jac_det*jnp.einsum("ijk,iklm,ik->ijlm",K,g.reshape(g.shape[0],g.shape[1], ds, int(g.shape[-1]/ds)),w[:,:,-1]) 223 | g = jax.nn.softmax(g, axis=-1) 224 | 225 | v = self.v_apply(v_params, u.reshape(u.shape[0],1,u.shape[1]*u.shape[2])) 226 | v = v.reshape(v.shape[0],int(v.shape[2]/ds),ds) 227 | Guy = jnp.einsum("ijkl,ilk->ijk", g,v) 228 | return Guy 229 | 230 | @partial(jax.jit, static_argnums=0) 231 | def loss(self, params, batch): 232 | inputs, outputs = batch 233 | y_pred = self.LOCA_net(params,inputs) 234 | loss = np.mean((outputs.flatten() - y_pred.flatten())**2) 235 | return loss 236 | 237 | @partial(jax.jit, static_argnums=0) 238 | def L2error(self, params, batch): 239 | inputs, outputs = batch 240 | y_pred = self.LOCA_net(params,inputs) 241 | return norm(outputs.flatten() - 
y_pred.flatten(), 2)/norm(outputs.flatten(),2) 242 | 243 | @partial(jit, static_argnums=(0,)) 244 | def step(self, i, opt_state, batch): 245 | params = self.get_params(opt_state) 246 | g = grad(self.loss)(params, batch) 247 | return self.opt_update(i, g, opt_state) 248 | 249 | def train(self, train_dataset, test_dataset, nIter = 10000): 250 | train_data = iter(train_dataset) 251 | test_data = iter(test_dataset) 252 | 253 | pbar = trange(nIter) 254 | for it in pbar: 255 | train_batch = next(train_data) 256 | test_batch = next(test_data) 257 | 258 | self.opt_state = self.step(next(self.itercount), self.opt_state, train_batch) 259 | 260 | if it % 100 == 0: 261 | params = self.get_params(self.opt_state) 262 | 263 | loss_train = self.loss(params, train_batch) 264 | loss_test = self.loss(params, test_batch) 265 | 266 | errorTrain = self.L2error(params, train_batch) 267 | errorTest = self.L2error(params, test_batch) 268 | 269 | self.loss_log.append(loss_train) 270 | 271 | pbar.set_postfix({'Training loss': loss_train, 272 | 'Testing loss' : loss_test, 273 | 'Test error': errorTest, 274 | 'Train error': errorTrain}) 275 | 276 | @partial(jit, static_argnums=(0,)) 277 | def predict(self, params, inputs): 278 | s_pred = self.LOCA_net(params,inputs) 279 | return s_pred 280 | 281 | def count_params(self, params): 282 | params_flat, _ = ravel_pytree(params) 283 | print("The number of model parameters is:",params_flat.shape[0]) 284 | 285 | 286 | TRAINING_ITERATIONS = 50000 287 | P = 300 288 | m = 300 289 | L = 1 290 | T = 1 291 | N_hat = 1 292 | num_train = 1000 293 | num_test = 1000 294 | training_batch_size = 100 295 | du = 1 296 | dy = 1 297 | ds = 1 298 | n_hat = 100 299 | l = 100 300 | Nx = P 301 | H = 20 302 | 303 | # Number of GLL quadrature points, coordinates and weights 304 | polypoints = 20 305 | z, w = leggauss(polypoints) 306 | lb = np.array([0.0]) 307 | ub = np.array([1.0]) 308 | 309 | # Map [-1,1] -> [0,1] 310 | z = 0.5*(ub - lb)*(z + 1.0) + lb 311 | jac_det = 0.5*(ub-lb) 312 | 313 | # Reshape both weights and coordinates. 
We need them to have shape: (num_train, N, dy) 314 | z = np.tile(np.expand_dims(z,0),(num_train,1))[:,:,None] 315 | w = np.tile(np.expand_dims(w,0),(num_train,1))[:,:,None] 316 | 317 | # Create the dataset 318 | d = np.load("../Data/train_pushforward.npz") 319 | U_train = d["U_train"] 320 | x_train = d["x_train"] 321 | y_train = d["y_train"] 322 | s_train = d["s_train"] 323 | 324 | d = np.load("../Data/test_pushforward.npz") 325 | U_test = d["U_test"] 326 | x_test = d["x_test"] 327 | y_test = d["y_test"] 328 | s_test = d["s_test"] 329 | 330 | # Make all array to be jax numpy format 331 | y_train = jnp.asarray(y_train) 332 | s_train = jnp.asarray(s_train) 333 | 334 | y_test = jnp.asarray(y_test) 335 | s_test = jnp.asarray(s_test) 336 | 337 | z = jnp.asarray(z) 338 | w = jnp.asarray(w) 339 | 340 | U_train = np.reshape(U_train,(num_test,m,du)) 341 | y_train = jnp.reshape(y_train,(num_train,P,dy)) 342 | s_train = jnp.reshape(s_train,(num_train,P,ds)) 343 | 344 | U_test = np.reshape(U_test,(num_test,m,du)) 345 | y_test = jnp.reshape(y_test,(num_test,P,dy)) 346 | s_test = jnp.reshape(s_test,(num_test,P,ds)) 347 | 348 | z = jnp.reshape(z,(num_test,polypoints,dy)) 349 | w = jnp.reshape(w,(num_test,polypoints,dy)) 350 | 351 | plot=False 352 | if plot: 353 | import matplotlib.pyplot as plt 354 | pltN = 10 355 | for i in range(0,pltN-1): 356 | plt.plot(y_train[i,:,0], s_train[i,:,0], 'r-') 357 | plt.plot(y_test[i,:,0], s_test[i,:,0], 'b-') 358 | 359 | plt.plot(y_train[pltN,:,0], s_train[pltN,:,0], 'r-', label="Training output") 360 | plt.plot(y_test[pltN,:,0], s_test[pltN,:,0], 'b-', label="Testing output") 361 | plt.legend() 362 | plt.show() 363 | 364 | x = jnp.linspace(0,1,num=m) 365 | pltN = 10 366 | for i in range(0,pltN-1): 367 | plt.plot(x, np.asarray(U_train)[i,:,0], 'y-') 368 | plt.plot(x, np.asarray(U_test)[i,:,0], 'g-') 369 | 370 | plt.plot(x, np.asarray(U_train)[pltN,:,0], 'y-', label="Training input") 371 | plt.plot(x, np.asarray(U_test)[pltN,:,0], 'g-', label="Testing input") 372 | plt.legend() 373 | plt.show() 374 | 375 | y_train_pos = y_train 376 | 377 | pos_encodingy = PositionalEncodingY(y_train,int(y_train.shape[1]*y_train.shape[2]), max_len = P, H=H) 378 | y_train = pos_encodingy.forward(y_train) 379 | del pos_encodingy 380 | 381 | pos_encodingy = PositionalEncodingY(z,int(z.shape[1]*z.shape[2]), max_len = polypoints, H=H) 382 | z = pos_encodingy.forward(z) 383 | del pos_encodingy 384 | 385 | pos_encodingyt = PositionalEncodingY(y_test,int(y_test.shape[1]*y_test.shape[2]), max_len = P, H=H) 386 | y_test = pos_encodingyt.forward(y_test) 387 | del pos_encodingyt 388 | 389 | start_time = timeit.default_timer() 390 | inputs_trainxu = jnp.asarray(scatteringTransform(U_train, l=l, m=m, training_batch_size=num_train)) 391 | inputs_testxu = jnp.asarray(scatteringTransform(U_test , l=l, m=m, training_batch_size=num_test)) 392 | 393 | # inputs_trainxu = jnp.asarray(U_train) 394 | # inputs_testxu = jnp.asarray(U_test ) 395 | elapsed = timeit.default_timer() - start_time 396 | print("The wall-clock time for for loop is seconds is equal to %f seconds"%elapsed) 397 | print(inputs_trainxu.shape, inputs_testxu.shape) 398 | 399 | train_dataset = DataGenerator(inputs_trainxu, inputs_trainxu, y_train, s_train, z, w, training_batch_size) 400 | train_dataset = iter(train_dataset) 401 | 402 | test_dataset = DataGenerator(inputs_testxu, inputs_testxu, y_test, s_test, z, w, training_batch_size) 403 | test_dataset = iter(test_dataset) 404 | 405 | q_layers = [L*dy+H*dy, 100, 100, l] 406 | v_layers = 
[1200*du, 1024, ds*n_hat] 407 | g_layers = [l, 100, 100, ds*n_hat] 408 | 409 | model = LOCA(q_layers, g_layers, v_layers, m=m, P=P, jac_det=jac_det) 410 | 411 | model.count_params(model.get_params(model.opt_state)) 412 | 413 | start_time = timeit.default_timer() 414 | model.train(train_dataset, test_dataset, nIter=TRAINING_ITERATIONS) 415 | elapsed = timeit.default_timer() - start_time 416 | print("The training wall-clock time is seconds is equal to %f seconds"%elapsed) 417 | 418 | params = model.get_params(model.opt_state) 419 | 420 | uCNN_test = model.predict(params, (inputs_testxu,y_test, z, w)) 421 | test_error_u = [] 422 | for i in range(0,s_test.shape[0]): 423 | test_error_u.append(jnp.linalg.norm(s_test[i,:,-1] - uCNN_test[i,:,-1], 2)/jnp.linalg.norm(s_test[i,:,-1], 2)) 424 | print("The average test u error is %e the standard deviation is %e the min error is %e and the max error is %e"%(np.mean(test_error_u),np.std(test_error_u),np.min(test_error_u),np.max(test_error_u))) 425 | 426 | uCNN_train = model.predict(params, (inputs_trainxu, y_train, z, w)) 427 | 428 | train_error_u = [] 429 | for i in range(0,s_test.shape[0]): 430 | train_error_u.append(jnp.linalg.norm(s_train[i,:,-1] - uCNN_train[i,:,-1], 2)/jnp.linalg.norm(s_train[i,:,-1], 2)) 431 | print("The average train u error is %e the standard deviation is %e the min error is %e and the max error is %e"%(np.mean(train_error_u),np.std(train_error_u),np.min(train_error_u),np.max(train_error_u))) 432 | 433 | beta, gamma, q_params, g_params, v_params = params 434 | y = y_test 435 | y = model.q_apply(q_params,y) 436 | z = model.q_apply(q_params,z) 437 | 438 | K = model.periodic(z, z, gamma, beta) 439 | Kzz = jnp.sqrt(model.jac_det*jnp.einsum("ijk,ikl->ijl",K,w)) 440 | 441 | K = model.periodic(y, z, gamma, beta) 442 | Kyz = jnp.sqrt(model.jac_det*jnp.einsum("ijk,ikl->ijl",K,w)) 443 | 444 | mean_K = jnp.matmul(Kyz, jnp.swapaxes(Kzz,1,2)) 445 | K = jnp.divide(K,mean_K) 446 | 447 | g = model.g_apply(g_params,z) 448 | g = model.jac_det*jnp.einsum("ijk,iklm,ik->ijlm",K,g.reshape(g.shape[0],g.shape[1], ds, int(g.shape[-1]/ds)),w[:,:,-1]) 449 | g = jax.nn.softmax(g, axis=-1) 450 | 451 | def minmax(a): 452 | minpos = a.index(min(a)) 453 | print("The minimum is at position", minpos) 454 | return minpos 455 | minpos = minmax(train_error_u) 456 | 457 | np.savez_compressed("eigenfunctions_KCAlocalper.npz", efuncs=g[minpos,:,0,:]) 458 | -------------------------------------------------------------------------------- /PushForward/LOCA/LOCA_closetoDON.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | from jax.example_libraries.stax import Dense, Gelu 9 | from jax.example_libraries import stax 10 | from jax.example_libraries import optimizers 11 | import os 12 | 13 | import timeit 14 | import numpy as np 15 | from jax.numpy.linalg import norm 16 | 17 | from jax import random, grad, vmap, jit, vjp 18 | from functools import partial 19 | 20 | from torch.utils import data 21 | 22 | from tqdm import trange 23 | 24 | import itertools 25 | 26 | from kymatio.numpy import Scattering1D 27 | from jax.flatten_util import ravel_pytree 28 | 29 | from numpy.polynomial.legendre import leggauss 30 | 31 | def get_freer_gpu(): 32 | os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Used >tmp') 33 | memory_available = [int(x.split()[2]) for x in open('tmp', 
'r').readlines()] 34 | return str(np.argmin(memory_available)) 35 | 36 | os.environ['CUDA_VISIBLE_DEVICES']= get_freer_gpu() 37 | os.environ['XLA_PYTHON_CLIENT_PREALLOCATE']="False" 38 | 39 | def pairwise_distances(dist,**arg): 40 | return jit(vmap(vmap(partial(dist,**arg),in_axes=(None,0)),in_axes=(0,None))) 41 | 42 | def euclid_distance(x,y): 43 | XX=jnp.dot(x,x) 44 | YY=jnp.dot(y,y) 45 | XY=jnp.dot(x,y) 46 | return XX+YY-2*XY 47 | 48 | class DataGenerator(data.Dataset): 49 | def __init__(self, inputsxuy, inputsxu, y, s, z, w, 50 | batch_size=100, rng_key=random.PRNGKey(1234)): 51 | 'Initialization' 52 | self.inputsxuy = inputsxuy 53 | self.inputsxu = inputsxu 54 | self.y = y 55 | self.s = s 56 | self.z = z 57 | self.w = w 58 | 59 | self.N = inputsxu.shape[0] 60 | self.batch_size = batch_size 61 | self.key = rng_key 62 | 63 | # @partial(jit, static_argnums=(0,)) 64 | def __getitem__(self, index): 65 | 'Generate one batch of data' 66 | self.key, subkey = random.split(self.key) 67 | inputs,outputs = self.__data_generation(subkey) 68 | return inputs, outputs 69 | 70 | @partial(jit, static_argnums=(0,)) 71 | def __data_generation(self, key): 72 | 'Generates data containing batch_size samples' 73 | idx = random.choice(key, self.N, (self.batch_size,), replace=False) 74 | s = self.s[idx,:,:] 75 | inputsxu = self.inputsxu[idx,:,:] 76 | y = self.y[idx,:,:] 77 | z = self.z[idx,:,:] 78 | w = self.w[idx,:,:] 79 | inputs = (inputsxu, y, z, w) 80 | return inputs, s 81 | 82 | class PositionalEncodingY: 83 | def __init__(self, Y, d_model, max_len = 100,H=20): 84 | self.d_model = d_model 85 | self.Y = Y 86 | self.max_len = max_len 87 | self.H = H 88 | 89 | @partial(jit, static_argnums=(0,)) 90 | def forward(self, x): 91 | self.pe = np.zeros((x.shape[0], self.max_len, self.H)) 92 | T = jnp.asarray(self.Y[:,:,0:1]) 93 | position = jnp.tile(T,(1,1,self.H)) 94 | div_term = 2**jnp.arange(0,int(self.H/2),1)*jnp.pi 95 | self.pe = jax.ops.index_update(self.pe, jax.ops.index[:,:,0::2], jnp.cos(position[:,:,0::2] * div_term)) 96 | self.pe = jax.ops.index_update(self.pe, jax.ops.index[:,:,1::2], jnp.sin(position[:,:,1::2] * div_term)) 97 | x = jnp.concatenate([x, self.pe],axis=-1) 98 | return x 99 | 100 | def scatteringTransform(sig, l=100, m=100, training_batch_size = 100): 101 | J = 1 102 | Q = 8 103 | T = sig.shape[1] 104 | scattering = Scattering1D(J, T, Q) 105 | sig = np.asarray(sig) 106 | sctcoef = np.zeros((training_batch_size, 1200, 1)) 107 | for i in range(0,training_batch_size): 108 | sctcoef[i,:,:] = scattering(sig[i,:,0]).flatten()[:,None] 109 | return sctcoef 110 | 111 | class LpLoss(object): 112 | def __init__(self, d=2, p=2): 113 | super(LpLoss, self).__init__() 114 | 115 | self.d = d 116 | self.p = p 117 | 118 | def rel(self, y, x): 119 | num_examples = x.shape[0] 120 | diff_norms = jnp.linalg.norm(y.reshape(num_examples,-1) - x.reshape(num_examples,-1), self.p, 1) 121 | y_norms = jnp.linalg.norm(y.reshape(num_examples,-1), self.p, 1) 122 | return jnp.mean(diff_norms/y_norms) 123 | 124 | def __call__(self, y, x): 125 | return self.rel(y, x) 126 | 127 | class LOCA: 128 | def __init__(self, g_layers, v_layers , m=100, P=100, jac_det=None): 129 | # Network initialization and evaluation functions 130 | seed = np.random.randint(10000) 131 | self.v_init, self.v_apply = self.init_NN(v_layers, activation=Gelu) 132 | self.in_shape = (-1, v_layers[0]) 133 | self.out_shape, v_params = self.v_init(random.PRNGKey(seed), self.in_shape) 134 | 135 | seed = np.random.randint(10000) 136 | self.g_init, self.g_apply = 
self.init_NN(g_layers, activation=Gelu) 137 | self.in_shape = (-1, g_layers[0]) 138 | self.out_shape, g_params = self.g_init(random.PRNGKey(seed), self.in_shape) 139 | 140 | # RBF kernel parameters 141 | beta = [1.] 142 | gamma = [1.] 143 | # Model parameters 144 | params = (beta, gamma, g_params, v_params) 145 | 146 | self.opt_init,self.opt_update,self.get_params = optimizers.adam(optimizers.exponential_decay(1e-3, 147 | decay_steps=100, 148 | decay_rate=0.99)) 149 | self.opt_state = self.opt_init(params) 150 | self.itercount = itertools.count() 151 | self.loss_log = [] 152 | 153 | self.l2loss = LpLoss() 154 | 155 | self.jac_det = jac_det 156 | self.vdistance_function = vmap(pairwise_distances(euclid_distance)) 157 | 158 | def init_NN(self, Q, activation=Gelu): 159 | layers = [] 160 | num_layers = len(Q) 161 | if num_layers < 2: 162 | net_init, net_apply = stax.serial() 163 | else: 164 | for i in range(0, num_layers-2): 165 | layers.append(Dense(Q[i+1])) 166 | layers.append(activation) 167 | layers.append(Dense(Q[-1])) 168 | net_init, net_apply = stax.serial(*layers) 169 | return net_init, net_apply 170 | 171 | @partial(jax.jit, static_argnums=0) 172 | def RBF(self, X, Y, gamma, beta): 173 | d = self.vdistance_function(X, Y) 174 | return beta[0]*jnp.exp(-gamma[0]*d) 175 | 176 | @partial(jax.jit, static_argnums=0) 177 | def Matern_32(self, X, Y, gamma=[0.5], beta=[0.5]): 178 | d = self.vdistance_function(X, Y) 179 | return (1 + (jnp.sqrt(3)*gamma[0])*d)*beta[0]*jnp.exp(-(jnp.sqrt(3)*gamma[0])*d) 180 | 181 | @partial(jax.jit, static_argnums=0) 182 | def Matern_52(self, X, Y, gamma=[0.5], beta=[0.5]): 183 | d = self.vdistance_function(X, Y) 184 | return (1 + (jnp.sqrt(5)*gamma[0])*d + (5/3*gamma[0]**2)*d**2)*beta[0]*jnp.exp(-(jnp.sqrt(5)*gamma[0])*d) 185 | 186 | @partial(jax.jit, static_argnums=0) 187 | def periodic(self, X, Y, gamma=[0.5], beta=[0.5], p=0.7): 188 | d = self.vdistance_function(X, Y) 189 | return beta[0]*jnp.exp(-2.0*jnp.sin(jnp.pi*d/p)**2*gamma[0]) 190 | 191 | @partial(jax.jit, static_argnums=0) 192 | def RQK(self, X, Y, gamma, beta): 193 | d = self.vdistance_function(X, Y) 194 | return beta[0]*(1 + (1./(3.*0.1*2))*gamma[0]*d)**(gamma[0]) 195 | 196 | @partial(jax.jit, static_argnums=0) 197 | def local_periodic(self, X, Y, gamma, beta): 198 | return self.periodic(X, Y, gamma, beta)*self.RBF(X, Y, gamma, beta) 199 | 200 | @partial(jax.jit, static_argnums=0) 201 | def LOCA_net(self, params, inputs, ds=1): 202 | beta, gamma, g_params, v_params = params 203 | u, y, z, w = inputs 204 | K = self.RBF(z, z, gamma, beta) 205 | Kzz = jnp.sqrt(self.jac_det*jnp.einsum("ijk,ikl->ijl",K,w)) 206 | 207 | K = self.RBF(y, z, gamma, beta) 208 | Kyz = jnp.sqrt(self.jac_det*jnp.einsum("ijk,ikl->ijl",K,w)) 209 | 210 | mean_K = jnp.matmul(Kyz, jnp.swapaxes(Kzz,1,2)) 211 | K = jnp.divide(K,mean_K) 212 | 213 | g = self.g_apply(g_params,z) 214 | g = self.jac_det*jnp.einsum("ijk,iklm,ik->ijlm",K,g.reshape(g.shape[0],g.shape[1], ds, int(g.shape[-1]/ds)),w[:,:,-1]) 215 | g = jax.nn.softmax(g, axis=-1) 216 | 217 | v = self.v_apply(v_params, u.reshape(u.shape[0],1,u.shape[1]*u.shape[2])) 218 | v = v.reshape(v.shape[0],int(v.shape[2]/ds),ds) 219 | Guy = jnp.einsum("ijkl,ilk->ijk", g,v) 220 | return Guy 221 | 222 | @partial(jax.jit, static_argnums=0) 223 | def loss(self, params, batch): 224 | inputs, outputs = batch 225 | y_pred = self.LOCA_net(params,inputs) 226 | loss = np.mean((outputs.flatten() - y_pred.flatten())**2) 227 | return loss 228 | 229 | @partial(jax.jit, static_argnums=0) 230 | def 
L2error(self, params, batch): 231 | inputs, outputs = batch 232 | y_pred = self.LOCA_net(params,inputs) 233 | return norm(outputs.flatten() - y_pred.flatten(), 2)/norm(outputs.flatten(),2) 234 | 235 | @partial(jit, static_argnums=(0,)) 236 | def step(self, i, opt_state, batch): 237 | params = self.get_params(opt_state) 238 | g = grad(self.loss)(params, batch) 239 | return self.opt_update(i, g, opt_state) 240 | 241 | def train(self, train_dataset, test_dataset, nIter = 10000): 242 | train_data = iter(train_dataset) 243 | test_data = iter(test_dataset) 244 | 245 | pbar = trange(nIter) 246 | for it in pbar: 247 | train_batch = next(train_data) 248 | test_batch = next(test_data) 249 | 250 | self.opt_state = self.step(next(self.itercount), self.opt_state, train_batch) 251 | 252 | if it % 100 == 0: 253 | params = self.get_params(self.opt_state) 254 | 255 | loss_train = self.loss(params, train_batch) 256 | loss_test = self.loss(params, test_batch) 257 | 258 | errorTrain = self.L2error(params, train_batch) 259 | errorTest = self.L2error(params, test_batch) 260 | 261 | self.loss_log.append(loss_train) 262 | 263 | pbar.set_postfix({'Training loss': loss_train, 264 | 'Testing loss' : loss_test, 265 | 'Test error': errorTest, 266 | 'Train error': errorTrain}) 267 | 268 | @partial(jit, static_argnums=(0,)) 269 | def predict(self, params, inputs): 270 | s_pred = self.LOCA_net(params,inputs) 271 | return s_pred 272 | 273 | def count_params(self, params): 274 | params_flat, _ = ravel_pytree(params) 275 | print("The number of model parameters is:",params_flat.shape[0]) 276 | 277 | 278 | TRAINING_ITERATIONS = 50000 279 | P = 300 280 | m = 300 281 | L = 1 282 | T = 1 283 | N_hat = 1 284 | num_train = 1000 285 | num_test = 1000 286 | training_batch_size = 100 287 | du = 1 288 | dy = 1 289 | ds = 1 290 | n_hat = 100 291 | l = 100 292 | Nx = P 293 | H = 20 294 | 295 | # Number of GLL quadrature points, coordinates and weights 296 | polypoints = 20 297 | z, w = leggauss(polypoints) 298 | lb = np.array([0.0]) 299 | ub = np.array([1.0]) 300 | 301 | # Map [-1,1] -> [0,1] 302 | z = 0.5*(ub - lb)*(z + 1.0) + lb 303 | jac_det = 0.5*(ub-lb) 304 | 305 | # Reshape both weights and coordinates. 
We need them to have shape: (num_train, N, dy) 306 | z = np.tile(np.expand_dims(z,0),(num_train,1))[:,:,None] 307 | w = np.tile(np.expand_dims(w,0),(num_train,1))[:,:,None] 308 | 309 | # Create the dataset 310 | d = np.load("../Data/train_pushforward.npz") 311 | U_train = d["U_train"] 312 | x_train = d["x_train"] 313 | y_train = d["y_train"] 314 | s_train = d["s_train"] 315 | 316 | d = np.load("../Data/test_pushforward.npz") 317 | U_test = d["U_test"] 318 | x_test = d["x_test"] 319 | y_test = d["y_test"] 320 | s_test = d["s_test"] 321 | 322 | # Convert all arrays to jax.numpy format 323 | y_train = jnp.asarray(y_train) 324 | s_train = jnp.asarray(s_train) 325 | 326 | y_test = jnp.asarray(y_test) 327 | s_test = jnp.asarray(s_test) 328 | 329 | z = jnp.asarray(z) 330 | w = jnp.asarray(w) 331 | 332 | U_train = np.reshape(U_train,(num_train,m,du)) 333 | y_train = jnp.reshape(y_train,(num_train,P,dy)) 334 | s_train = jnp.reshape(s_train,(num_train,P,ds)) 335 | 336 | U_test = np.reshape(U_test,(num_test,m,du)) 337 | y_test = jnp.reshape(y_test,(num_test,P,dy)) 338 | s_test = jnp.reshape(s_test,(num_test,P,ds)) 339 | 340 | z = jnp.reshape(z,(num_test,polypoints,dy)) 341 | w = jnp.reshape(w,(num_test,polypoints,dy)) 342 | 343 | plot=False 344 | if plot: 345 | import matplotlib.pyplot as plt 346 | pltN = 10 347 | for i in range(0,pltN-1): 348 | plt.plot(y_train[i,:,0], s_train[i,:,0], 'r-') 349 | plt.plot(y_test[i,:,0], s_test[i,:,0], 'b-') 350 | 351 | plt.plot(y_train[pltN,:,0], s_train[pltN,:,0], 'r-', label="Training output") 352 | plt.plot(y_test[pltN,:,0], s_test[pltN,:,0], 'b-', label="Testing output") 353 | plt.legend() 354 | plt.show() 355 | 356 | x = jnp.linspace(0,1,num=m) 357 | pltN = 10 358 | for i in range(0,pltN-1): 359 | plt.plot(x, np.asarray(U_train)[i,:,0], 'y-') 360 | plt.plot(x, np.asarray(U_test)[i,:,0], 'g-') 361 | 362 | plt.plot(x, np.asarray(U_train)[pltN,:,0], 'y-', label="Training input") 363 | plt.plot(x, np.asarray(U_test)[pltN,:,0], 'g-', label="Testing input") 364 | plt.legend() 365 | plt.show() 366 | 367 | y_train_pos = y_train 368 | 369 | pos_encodingy = PositionalEncodingY(y_train,int(y_train.shape[1]*y_train.shape[2]), max_len = P, H=H) 370 | y_train = pos_encodingy.forward(y_train) 371 | del pos_encodingy 372 | 373 | pos_encodingy = PositionalEncodingY(z,int(z.shape[1]*z.shape[2]), max_len = polypoints, H=H) 374 | z = pos_encodingy.forward(z) 375 | del pos_encodingy 376 | 377 | pos_encodingyt = PositionalEncodingY(y_test,int(y_test.shape[1]*y_test.shape[2]), max_len = P, H=H) 378 | y_test = pos_encodingyt.forward(y_test) 379 | del pos_encodingyt 380 | 381 | start_time = timeit.default_timer() 382 | 383 | inputs_trainxu = jnp.asarray(U_train) 384 | inputs_testxu = jnp.asarray(U_test) 385 | elapsed = timeit.default_timer() - start_time 386 | print("The wall-clock time for data conversion is equal to %f seconds"%elapsed) 387 | print(inputs_trainxu.shape, inputs_testxu.shape) 388 | 389 | train_dataset = DataGenerator(inputs_trainxu, inputs_trainxu, y_train, s_train, z, w, training_batch_size) 390 | train_dataset = iter(train_dataset) 391 | 392 | test_dataset = DataGenerator(inputs_testxu, inputs_testxu, y_test, s_test, z, w, training_batch_size) 393 | test_dataset = iter(test_dataset) 394 | 395 | v_layers = [m*du, 512, 512, ds*n_hat] 396 | g_layers = [L*dy+H*dy, 512, 512, ds*n_hat] 397 | 398 | model = LOCA(g_layers, v_layers, m=m, P=P, jac_det=jac_det) 399 | 400 | model.count_params(model.get_params(model.opt_state)) 401 | 402 | start_time = timeit.default_timer()
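# Illustrative sketch (not part of the original script): the Gauss-Legendre rule returned by
# leggauss() lives on [-1, 1]; the affine map above rescales the nodes to [0, 1], and
# jac_det = 0.5*(ub - lb) is the Jacobian of that map, so jac_det*sum(w_i*f(z_i)) approximates
# the integral of f over [0, 1]. This is the same rule the kernel integrals in LOCA_net apply
# through jac_det*jnp.einsum(..., w). The underscore names below are hypothetical and chosen
# so they do not clash with the script's own variables.
import numpy as np
from numpy.polynomial.legendre import leggauss as _leggauss
_zq, _wq = _leggauss(20)
_zq = 0.5*(1.0 - 0.0)*(_zq + 1.0) + 0.0   # nodes mapped from [-1, 1] to [0, 1]
_jac = 0.5*(1.0 - 0.0)                    # change-of-variables factor
assert abs(_jac*np.sum(_wq*np.exp(_zq)) - (np.e - 1.0)) < 1e-12   # integral of exp over [0, 1]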
403 | model.train(train_dataset, test_dataset, nIter=TRAINING_ITERATIONS) 404 | elapsed = timeit.default_timer() - start_time 405 | print("The training wall-clock time is equal to %f seconds"%elapsed) 406 | 407 | params = model.get_params(model.opt_state) 408 | 409 | uCNN_test = model.predict(params, (inputs_testxu,y_test, z, w)) 410 | test_error_u = [] 411 | for i in range(0,s_test.shape[0]): 412 | test_error_u.append(jnp.linalg.norm(s_test[i,:,-1] - uCNN_test[i,:,-1], 2)/jnp.linalg.norm(s_test[i,:,-1], 2)) 413 | print("The average test u error is %e the standard deviation is %e the min error is %e and the max error is %e"%(np.mean(test_error_u),np.std(test_error_u),np.min(test_error_u),np.max(test_error_u))) 414 | 415 | uCNN_train = model.predict(params, (inputs_trainxu, y_train, z, w)) 416 | 417 | train_error_u = [] 418 | for i in range(0,s_train.shape[0]): 419 | train_error_u.append(jnp.linalg.norm(s_train[i,:,-1] - uCNN_train[i,:,-1], 2)/jnp.linalg.norm(s_train[i,:,-1], 2)) 420 | print("The average train u error is %e the standard deviation is %e the min error is %e and the max error is %e"%(np.mean(train_error_u),np.std(train_error_u),np.min(train_error_u),np.max(train_error_u))) 421 | 422 | 423 | beta, gamma, g_params, v_params = params 424 | 425 | y = y_test 426 | 427 | K = model.RBF(z, z, gamma, beta) 428 | Kzz = jnp.sqrt(model.jac_det*jnp.einsum("ijk,ikl->ijl",K,w)) 429 | 430 | K = model.RBF(y, z, gamma, beta) 431 | Kyz = jnp.sqrt(model.jac_det*jnp.einsum("ijk,ikl->ijl",K,w)) 432 | 433 | mean_K = jnp.matmul(Kyz, jnp.swapaxes(Kzz,1,2)) 434 | K = jnp.divide(K,mean_K) 435 | 436 | g = model.g_apply(g_params,z) 437 | g = model.jac_det*jnp.einsum("ijk,iklm,ik->ijlm",K,g.reshape(g.shape[0],g.shape[1], ds, int(g.shape[-1]/ds)),w[:,:,-1]) 438 | g = jax.nn.softmax(g, axis=-1) 439 | 440 | def minmax(a): 441 | minpos = a.index(min(a)) 442 | print("The minimum is at position", minpos) 443 | return minpos 444 | minpos = minmax(train_error_u) 445 | 446 | np.savez_compressed("eigenfunctions_KCAlocalper_closetoDON2.npz", efuncs=g[minpos,:,0,:]) 447 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning Operators with Coupled Attention 2 | 3 | ![Fig2 resized](https://user-images.githubusercontent.com/24652388/182936051-613aaa15-e743-4d7c-9ff6-e2d4093750fd.png) 4 | 5 | Code and data accompanying the manuscript titled "Learning Operators with Coupled Attention", authored by Georgios Kissas*, Jacob H. Seidman*, Leonardo Ferreira Guilhoto, Victor M. Preciado, George J. Pappas and Paris Perdikaris. 6 | 7 | \* These authors contributed equally. 8 | 9 | # Abstract 10 | 11 | Supervised operator learning is an emerging machine learning paradigm with applications to modeling the evolution maps of spatio-temporal dynamical systems and approximating general black-box relationships between functional data. We propose a novel operator learning method, LOCA (Learning Operators with Coupled Attention), motivated by the attention mechanism. The input functions are mapped to a finite set of features which are then averaged with attention weights that depend on the output query locations. By coupling these attention weights together with an integral transform, LOCA is able to explicitly learn correlations in the target output functions, enabling us to approximate nonlinear operators even when the number of output function measurements is very small.
Our formulation is accompanied by rigorous approximation theoretic guarantees on the expressiveness of the proposed model. Empirically, we evaluate the performance of LOCA on several operator learning scenarios involving systems governed by ordinary and partial differential equations, as well as a black-box climate prediction problem. Through these scenarios, we demonstrate state-of-the-art accuracy, robustness with respect to noisy input data, and a consistently small spread of errors over testing data sets, even for out-of-distribution prediction tasks. 12 | 13 | 14 | # Citation 15 | 16 | @article{JMLR:v23:21-1521, 17 | author = {Georgios Kissas and Jacob H. Seidman and Leonardo Ferreira Guilhoto and Victor M. Preciado and George J. Pappas and Paris Perdikaris}, 18 | title = {Learning Operators with Coupled Attention}, 19 | journal = {Journal of Machine Learning Research}, 20 | year = {2022}, 21 | volume = {23}, 22 | number = {215}, 23 | pages = {1--63}, 24 | url = {http://jmlr.org/papers/v23/21-1521.html} 25 | } 26 | 27 | 28 | The repository contains all the necessary code and data to reproduce the results in the paper. 29 | 30 | You can find a LOCA tutorial with an explanation of the Darcy flow example [here](https://colab.research.google.com/drive/1axxLGhgwipCSw9WQVMBklvQdW_K99E1D?usp=sharing). 31 | 32 | The training and testing data sets accompanying the manuscript can be found [here](https://drive.google.com/file/d/1UyjKnsL15FUHESO4rUt3Hg_CkFlxoCPz/view?usp=sharing), and the code to plot the results, as well as the data to reproduce the figures in the manuscript, can be found [here](https://drive.google.com/file/d/1iwCwXLufPYg0j4dn3d_eZjnnf25fiLfy/view?usp=sharing). 33 | 34 | You can find the code for LOCA, DeepONet and FNO used for each example in this paper under the respective folder names. 35 | 36 | 37 | 38 | ## ⚠️ The LOCA methodology and code cannot be used for commercial purposes (protected by a patent at the University of Pennsylvania). ⚠️ 39 | 40 | -------------------------------------------------------------------------------- /ShallowWaters/FNO/Adam.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import Tensor 4 | from typing import List, Optional 5 | from torch.optim.optimizer import Optimizer 6 | 7 | import os 8 | os.environ['CUDA_VISIBLE_DEVICES']="3" 9 | 10 | def adam(params: List[Tensor], 11 | grads: List[Tensor], 12 | exp_avgs: List[Tensor], 13 | exp_avg_sqs: List[Tensor], 14 | max_exp_avg_sqs: List[Tensor], 15 | state_steps: List[int], 16 | *, 17 | amsgrad: bool, 18 | beta1: float, 19 | beta2: float, 20 | lr: float, 21 | weight_decay: float, 22 | eps: float): 23 | r"""Functional API that performs Adam algorithm computation. 24 | See :class:`~torch.optim.Adam` for details. 25 | """ 26 | 27 | for i, param in enumerate(params): 28 | 29 | grad = grads[i] 30 | exp_avg = exp_avgs[i] 31 | exp_avg_sq = exp_avg_sqs[i] 32 | step = state_steps[i] 33 | 34 | bias_correction1 = 1 - beta1 ** step 35 | bias_correction2 = 1 - beta2 ** step 36 | 37 | if weight_decay != 0: 38 | grad = grad.add(param, alpha=weight_decay) 39 | 40 | # Decay the first and second moment running average coefficient 41 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 42 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) 43 | if amsgrad: 44 | # Maintains the maximum of all 2nd moment running avg. till now 45 | torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i]) 46 | # Use the max.
for normalizing running avg. of gradient 47 | denom = (max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(eps) 48 | else: 49 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) 50 | 51 | step_size = lr / bias_correction1 52 | 53 | param.addcdiv_(exp_avg, denom, value=-step_size) 54 | 55 | 56 | class Adam(Optimizer): 57 | r"""Implements Adam algorithm. 58 | It has been proposed in `Adam: A Method for Stochastic Optimization`_. 59 | The implementation of the L2 penalty follows changes proposed in 60 | `Decoupled Weight Decay Regularization`_. 61 | Args: 62 | params (iterable): iterable of parameters to optimize or dicts defining 63 | parameter groups 64 | lr (float, optional): learning rate (default: 1e-3) 65 | betas (Tuple[float, float], optional): coefficients used for computing 66 | running averages of gradient and its square (default: (0.9, 0.999)) 67 | eps (float, optional): term added to the denominator to improve 68 | numerical stability (default: 1e-8) 69 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 70 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 71 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 72 | (default: False) 73 | .. _Adam\: A Method for Stochastic Optimization: 74 | https://arxiv.org/abs/1412.6980 75 | .. _Decoupled Weight Decay Regularization: 76 | https://arxiv.org/abs/1711.05101 77 | .. _On the Convergence of Adam and Beyond: 78 | https://openreview.net/forum?id=ryQu7f-RZ 79 | """ 80 | 81 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 82 | weight_decay=0, amsgrad=False): 83 | if not 0.0 <= lr: 84 | raise ValueError("Invalid learning rate: {}".format(lr)) 85 | if not 0.0 <= eps: 86 | raise ValueError("Invalid epsilon value: {}".format(eps)) 87 | if not 0.0 <= betas[0] < 1.0: 88 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 89 | if not 0.0 <= betas[1] < 1.0: 90 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 91 | if not 0.0 <= weight_decay: 92 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 93 | defaults = dict(lr=lr, betas=betas, eps=eps, 94 | weight_decay=weight_decay, amsgrad=amsgrad) 95 | super(Adam, self).__init__(params, defaults) 96 | 97 | def __setstate__(self, state): 98 | super(Adam, self).__setstate__(state) 99 | for group in self.param_groups: 100 | group.setdefault('amsgrad', False) 101 | 102 | @torch.no_grad() 103 | def step(self, closure=None): 104 | """Performs a single optimization step. 105 | Args: 106 | closure (callable, optional): A closure that reevaluates the model 107 | and returns the loss. 
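            Returns: the loss evaluated by ``closure`` if a closure is provided, otherwise ``None``.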
108 | """ 109 | loss = None 110 | if closure is not None: 111 | with torch.enable_grad(): 112 | loss = closure() 113 | 114 | for group in self.param_groups: 115 | params_with_grad = [] 116 | grads = [] 117 | exp_avgs = [] 118 | exp_avg_sqs = [] 119 | max_exp_avg_sqs = [] 120 | state_steps = [] 121 | beta1, beta2 = group['betas'] 122 | 123 | for p in group['params']: 124 | if p.grad is not None: 125 | params_with_grad.append(p) 126 | if p.grad.is_sparse: 127 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 128 | grads.append(p.grad) 129 | 130 | state = self.state[p] 131 | # Lazy state initialization 132 | if len(state) == 0: 133 | state['step'] = 0 134 | # Exponential moving average of gradient values 135 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 136 | # Exponential moving average of squared gradient values 137 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 138 | if group['amsgrad']: 139 | # Maintains max of all exp. moving avg. of sq. grad. values 140 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 141 | 142 | exp_avgs.append(state['exp_avg']) 143 | exp_avg_sqs.append(state['exp_avg_sq']) 144 | 145 | if group['amsgrad']: 146 | max_exp_avg_sqs.append(state['max_exp_avg_sq']) 147 | 148 | # update the steps for each param group update 149 | state['step'] += 1 150 | # record the step after step update 151 | state_steps.append(state['step']) 152 | 153 | adam(params_with_grad, 154 | grads, 155 | exp_avgs, 156 | exp_avg_sqs, 157 | max_exp_avg_sqs, 158 | state_steps, 159 | amsgrad=group['amsgrad'], 160 | beta1=beta1, 161 | beta2=beta2, 162 | lr=group['lr'], 163 | weight_decay=group['weight_decay'], 164 | eps=group['eps']) 165 | return loss -------------------------------------------------------------------------------- /ShallowWaters/FNO/FNOSW.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Zongyi Li 3 | This file is the Fourier Neural Operator for 3D problem such as the Navier-Stokes equation discussed in Section 5.3 in the [paper](https://arxiv.org/pdf/2010.08895.pdf), 4 | which takes the 2D spatial + 1D temporal equation directly as a 3D problem 5 | """ 6 | 7 | 8 | import torch 9 | import numpy as np 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from utilities3 import * 13 | from timeit import default_timer 14 | import timeit 15 | 16 | 17 | ################################################################ 18 | # 3d fourier layers 19 | ################################################################ 20 | 21 | class SpectralConv3d(nn.Module): 22 | def __init__(self, in_channels, out_channels, modes1, modes2, modes3): 23 | super(SpectralConv3d, self).__init__() 24 | 25 | """ 26 | 3D Fourier layer. It does FFT, linear transform, and Inverse FFT. 
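        Only the lowest Fourier modes are kept: modes1 and modes2 frequencies from each end of the first two frequency axes, and the first modes3 entries of the real-FFT (last) axis, are multiplied by learned complex weights; every other mode is zeroed before the inverse FFT.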
27 | """ 28 | 29 | self.in_channels = in_channels 30 | self.out_channels = out_channels 31 | self.modes1 = modes1 #Number of Fourier modes to multiply, at most floor(N/2) + 1 32 | self.modes2 = modes2 33 | self.modes3 = modes3 34 | 35 | self.scale = (1 / (in_channels * out_channels)) 36 | self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, dtype=torch.cfloat)) 37 | self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, dtype=torch.cfloat)) 38 | self.weights3 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, dtype=torch.cfloat)) 39 | self.weights4 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, dtype=torch.cfloat)) 40 | 41 | # Complex multiplication 42 | def compl_mul3d(self, input, weights): 43 | # (batch, in_channel, x,y,t ), (in_channel, out_channel, x,y,t) -> (batch, out_channel, x,y,t) 44 | return torch.einsum("bixyz,ioxyz->boxyz", input, weights) 45 | 46 | def forward(self, x): 47 | batchsize = x.shape[0] 48 | #Compute Fourier coeffcients up to factor of e^(- something constant) 49 | x_ft = torch.fft.rfftn(x, dim=[-3,-2,-1]) 50 | 51 | # Multiply relevant Fourier modes 52 | out_ft = torch.zeros(batchsize, self.out_channels, x.size(-3), x.size(-2), x.size(-1)//2 + 1, dtype=torch.cfloat, device=x.device) 53 | out_ft[:, :, :self.modes1, :self.modes2, :self.modes3] = \ 54 | self.compl_mul3d(x_ft[:, :, :self.modes1, :self.modes2, :self.modes3], self.weights1) 55 | out_ft[:, :, -self.modes1:, :self.modes2, :self.modes3] = \ 56 | self.compl_mul3d(x_ft[:, :, -self.modes1:, :self.modes2, :self.modes3], self.weights2) 57 | out_ft[:, :, :self.modes1, -self.modes2:, :self.modes3] = \ 58 | self.compl_mul3d(x_ft[:, :, :self.modes1, -self.modes2:, :self.modes3], self.weights3) 59 | out_ft[:, :, -self.modes1:, -self.modes2:, :self.modes3] = \ 60 | self.compl_mul3d(x_ft[:, :, -self.modes1:, -self.modes2:,:self.modes3], self.weights4) 61 | 62 | #Return to physical space 63 | x = torch.fft.irfftn(out_ft, s=(x.size(-3), x.size(-2), x.size(-1))) 64 | return x 65 | 66 | class FNO3d(nn.Module): 67 | def __init__(self, modes1, modes2, modes3, width): 68 | super(FNO3d, self).__init__() 69 | 70 | """ 71 | The overall network. It contains 4 layers of the Fourier layer. 72 | 1. Lift the input to the desire channel dimension by self.fc0 . 73 | 2. 4 layers of the integral operators u' = (W + K)(u). 74 | W defined by self.w; K defined by self.conv . 75 | 3. Project from the channel space to the output space by self.fc1 and self.fc2 . 76 | 77 | input: the solution of the first 10 timesteps + 3 locations (u(1, x, y), ..., u(10, x, y), x, y, t). It's a constant function in time, except for the last index. 
78 | input shape: (batchsize, x=64, y=64, t=40, c=13) 79 | output: the solution of the next 40 timesteps 80 | output shape: (batchsize, x=64, y=64, t=40, c=1) 81 | """ 82 | 83 | self.modes1 = modes1 84 | self.modes2 = modes2 85 | self.modes3 = modes3 86 | self.width = width 87 | self.padding = 12 # pad the domain if input is non-periodic 88 | self.fc0 = nn.Linear(6, self.width) 89 | # input channel is 12: the solution of the first 10 timesteps + 3 locations (u(1, x, y), ..., u(10, x, y), x, y, t) 90 | 91 | self.conv0 = SpectralConv3d(self.width, self.width, self.modes1, self.modes2, self.modes3) 92 | self.conv1 = SpectralConv3d(self.width, self.width, self.modes1, self.modes2, self.modes3) 93 | self.conv2 = SpectralConv3d(self.width, self.width, self.modes1, self.modes2, self.modes3) 94 | self.conv3 = SpectralConv3d(self.width, self.width, self.modes1, self.modes2, self.modes3) 95 | self.w0 = nn.Conv3d(self.width, self.width, 1) 96 | self.w1 = nn.Conv3d(self.width, self.width, 1) 97 | self.w2 = nn.Conv3d(self.width, self.width, 1) 98 | self.w3 = nn.Conv3d(self.width, self.width, 1) 99 | self.bn0 = torch.nn.BatchNorm3d(self.width) 100 | self.bn1 = torch.nn.BatchNorm3d(self.width) 101 | self.bn2 = torch.nn.BatchNorm3d(self.width) 102 | self.bn3 = torch.nn.BatchNorm3d(self.width) 103 | 104 | self.fc1 = nn.Linear(self.width, 128) 105 | self.fc2 = nn.Linear(128, 3) 106 | 107 | def forward(self, x): 108 | grid = self.get_grid(x.shape, x.device) 109 | x = torch.cat((x, grid), dim=-1) 110 | x = self.fc0(x) 111 | x = x.permute(0, 4, 1, 2, 3) 112 | x = F.pad(x, [0,self.padding]) # pad the domain if input is non-periodic 113 | 114 | x1 = self.conv0(x) 115 | x2 = self.w0(x) 116 | x = x1 + x2 117 | x = F.gelu(x) 118 | 119 | x1 = self.conv1(x) 120 | x2 = self.w1(x) 121 | x = x1 + x2 122 | x = F.gelu(x) 123 | 124 | x1 = self.conv2(x) 125 | x2 = self.w2(x) 126 | x = x1 + x2 127 | x = F.gelu(x) 128 | 129 | x1 = self.conv3(x) 130 | x2 = self.w3(x) 131 | x = x1 + x2 132 | 133 | x = x[..., :-self.padding] 134 | x = x.permute(0, 2, 3, 4, 1) # pad the domain if input is non-periodic 135 | x = self.fc1(x) 136 | x = F.gelu(x) 137 | x = self.fc2(x) 138 | return x 139 | 140 | def get_grid(self, shape, device): 141 | batchsize, size_x, size_y, size_z = shape[0], shape[1], shape[2], shape[3] 142 | gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float) 143 | gridx = gridx.reshape(1, size_x, 1, 1, 1).repeat([batchsize, 1, size_y, size_z, 1]) 144 | gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float) 145 | gridy = gridy.reshape(1, 1, size_y, 1, 1).repeat([batchsize, size_x, 1, size_z, 1]) 146 | gridz = torch.tensor(np.linspace(0, 1, size_z), dtype=torch.float) 147 | gridz = gridz.reshape(1, 1, 1, size_z, 1).repeat([batchsize, size_x, size_y, 1, 1]) 148 | return torch.cat((gridx, gridy, gridz), dim=-1).to(device) 149 | 150 | ################################################################ 151 | # configs 152 | ################################################################ 153 | ntrain = 1000 154 | ntest = 1000 155 | 156 | modes = 8 157 | width = 25 158 | 159 | batch_size = 100 160 | batch_size2 = batch_size 161 | 162 | epochs = 400 163 | learning_rate = 0.001 164 | scheduler_step = 100 165 | scheduler_gamma = 0.5 166 | 167 | print(epochs, learning_rate, scheduler_step, scheduler_gamma) 168 | 169 | runtime = np.zeros(2,) 170 | t1 = default_timer() 171 | 172 | 173 | sub = 1 174 | S = 32 // sub 175 | T_in = 1 176 | T = 5 177 | par = 3 178 | P = 128 179 | 
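# Illustrative sketch (not part of the original script): a quick sanity check of the Fourier
# layer defined above. SpectralConv3d keeps only the lowest modes1 x modes2 x modes3 frequency
# blocks, multiplies them by learned complex weights, and returns a tensor with the same shape
# as its input. The underscore names are hypothetical, chosen so nothing in the script is overwritten.
_layer = SpectralConv3d(in_channels=width, out_channels=width, modes1=modes, modes2=modes, modes3=modes)
_x = torch.randn(2, width, 32, 32, 16)    # (batch, channel, x, y, t)
assert _layer(_x).shape == _x.shape       # shape preserved; only frequency content is truncated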
################################################################ 180 | # load data 181 | ################################################################ 182 | idxT = [10,15,20,25,30] 183 | d = np.load("/scratch/gkissas/all_train_SW_Nx32_Ny32_numtrain1000.npz") 184 | U_train = d["U_train"][:,:,:,:] 185 | S_train = np.swapaxes(d["s_train"][:,idxT,:,:,None,:],4,1)[:,-1,:,:,:,:] 186 | TT = d["T_train"][idxT] 187 | CX = d["X_train"] 188 | CY = d["Y_train"] 189 | X_sim_train = d["XX_train"] 190 | Y_sim_train = d["YY_train"] 191 | 192 | d = np.load("/scratch/gkissas/all_test_SW_Nx32_Ny32_numtest1000.npz") 193 | U_test = d["U_test"][:,:,:,:] 194 | S_test = np.swapaxes(d["s_test"][:,idxT,:,:,None,:],4,1)[:,-1,:,:,:,:] 195 | TT = d["T_test"][idxT] 196 | CX = d["X_test"] 197 | CY = d["Y_test"] 198 | X_sim_test = d["XX_test"] 199 | Y_sim_test = d["YY_test"] 200 | 201 | dtype_double = torch.FloatTensor 202 | cdtype_double = torch.cuda.DoubleTensor 203 | train_a = torch.from_numpy(np.asarray(U_train)).type(dtype_double) 204 | train_u = torch.from_numpy(np.asarray(S_train)).type(dtype_double) 205 | 206 | test_a = torch.from_numpy(np.asarray(U_test)).type(dtype_double) 207 | test_u = torch.from_numpy(np.asarray(S_test)).type(dtype_double) 208 | 209 | print(train_u.shape, train_a.shape) 210 | print(test_u.shape, test_a.shape) 211 | assert (S == train_u.shape[-3]) 212 | assert (T == train_u.shape[-2]) 213 | assert (par == train_u.shape[-1]) 214 | 215 | train_a = train_a.reshape(ntrain,S,S,1,par).repeat([1,1,1,T,1]) 216 | test_a = test_a.reshape(ntest,S,S,1,par).repeat([1,1,1,T,1]) 217 | 218 | ind_train = torch.randint(S*S*T, (ntrain, P)) 219 | ind_test = torch.randint(S*S*T, (ntest, P)) 220 | 221 | train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(train_a, train_u, ind_train), batch_size=batch_size, shuffle=True) 222 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u, ind_test), batch_size=batch_size, shuffle=False) 223 | 224 | t2 = default_timer() 225 | 226 | print('preprocessing finished, time used:', t2-t1) 227 | device = torch.device('cuda') 228 | 229 | ################################################################ 230 | # training and evaluation 231 | ################################################################ 232 | 233 | batch_ind = torch.arange(batch_size).reshape(-1, 1).repeat(1, P) 234 | model = FNO3d(modes, modes, modes, width).cuda() 235 | print(count_params(model)) 236 | 237 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4) # the weight decay is 1e-4 originally 238 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma) 239 | 240 | 241 | myloss = LpLoss(size_average=False) 242 | start_time = timeit.default_timer() 243 | for ep in range(epochs): 244 | model.train() 245 | t1 = default_timer() 246 | train_mse = 0 247 | train_l2 = 0 248 | for x, y, idx in train_loader: 249 | x, y = x.cuda(), y.cuda() 250 | 251 | optimizer.zero_grad() 252 | out = model(x).view(batch_size, S*S*T, par) 253 | y = y.reshape(batch_size, S*S*T, par) 254 | y = y[batch_ind, idx] 255 | out = out[batch_ind, idx] 256 | 257 | mse = F.mse_loss(out, y, reduction='mean') 258 | l2 = myloss(out.view(batch_size, -1), y.view(batch_size, -1)) 259 | l2.backward() 260 | 261 | optimizer.step() 262 | train_mse += mse.item() 263 | train_l2 += l2.item() 264 | 265 | scheduler.step() 266 | 267 | model.eval() 268 | test_l2 = 0.0 269 | with torch.no_grad(): 270 | for x, y, idx in 
test_loader: 271 | x, y = x.cuda(), y.cuda() 272 | 273 | out = model(x).view(batch_size, S*S*T, par) 274 | y = y.reshape(batch_size, S*S*T,par) 275 | y = y[batch_ind, idx] 276 | out = out[batch_ind, idx] 277 | test_l2 += myloss(out.view(batch_size, -1), y.view(batch_size, -1)).item() 278 | 279 | train_mse /= len(train_loader) 280 | train_l2 /= ntrain 281 | test_l2 /= ntest 282 | 283 | t2 = default_timer() 284 | print(ep, t2-t1, train_mse, train_l2, test_l2) 285 | elapsed = timeit.default_timer() - start_time 286 | print("The training wall-clock time is seconds is equal to %f seconds"%elapsed) 287 | 288 | pred = torch.zeros(test_u.shape) 289 | index = 0 290 | test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u), batch_size=1, shuffle=False) 291 | test_error_u = [] 292 | test_error_rho_np = [] 293 | test_error_u_np = [] 294 | test_error_v_np = [] 295 | with torch.no_grad(): 296 | for x, y in test_loader: 297 | test_l2 = 0 298 | x, y = x.cuda(), y.cuda() 299 | 300 | out = model(x).view(S, S, T, par) 301 | pred[index,:,:,:] = out 302 | test_l2 += myloss(out.view(1, -1), y.view(1, -1)).item() 303 | test_error_u.append(test_l2) 304 | test_error_rho_np.append(np.linalg.norm(y.cpu().numpy().reshape(test_u.shape[2]*test_u.shape[2]*T,par)[:,0]- out.cpu().numpy().reshape(test_u.shape[2]*test_u.shape[2]*T,par)[:,0],2)/np.linalg.norm(y.cpu().numpy().reshape(test_u.shape[2]*test_u.shape[2]*T,par)[:,0],2)) 305 | test_error_u_np.append(np.linalg.norm(y.cpu().numpy().reshape(test_u.shape[2]*test_u.shape[2]*T,par)[:,1]- out.cpu().numpy().reshape(test_u.shape[2]*test_u.shape[2]*T,par)[:,1],2)/np.linalg.norm(y.cpu().numpy().reshape(test_u.shape[2]*test_u.shape[2]*T,par)[:,1],2)) 306 | test_error_v_np.append(np.linalg.norm(y.cpu().numpy().reshape(test_u.shape[2]*test_u.shape[2]*T,par)[:,2]- out.cpu().numpy().reshape(test_u.shape[2]*test_u.shape[2]*T,par)[:,2],2)/np.linalg.norm(y.cpu().numpy().reshape(test_u.shape[2]*test_u.shape[2]*T,par)[:,2],2)) 307 | index = index + 1 308 | 309 | print("The average test rho error is %e the standard deviation is %e the min error is %e and the max error is %e"%(np.mean(test_error_rho_np),np.std(test_error_rho_np),np.min(test_error_rho_np),np.max(test_error_rho_np))) 310 | print("The average test u error is %e the standard deviation is %e the min error is %e and the max error is %e"%(np.mean(test_error_u_np),np.std(test_error_u_np),np.min(test_error_u_np),np.max(test_error_u_np))) 311 | print("The average test v error is %e the standard deviation is %e the min error is %e and the max error is %e"%(np.mean(test_error_v_np),np.std(test_error_v_np),np.min(test_error_v_np),np.max(test_error_v_np))) --------------------------------------------------------------------------------
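Both the LOCA evaluation earlier in this repository and the FNO script above report per-sample relative L2 errors and summarize them by mean, standard deviation, minimum and maximum. The snippet below is a minimal, self-contained sketch of that metric; the helper name and array shapes are illustrative and not part of the repository.

import numpy as np

def relative_l2_summary(pred, target):
    """Per-sample relative L2 error, summarized as (mean, std, min, max)."""
    # pred, target: arrays of shape (num_samples, num_points)
    err = np.linalg.norm(pred - target, axis=1) / np.linalg.norm(target, axis=1)
    return np.mean(err), np.std(err), np.min(err), np.max(err)

# Example usage on random data:
# stats = relative_l2_summary(np.random.rand(8, 128), np.random.rand(8, 128) + 1.0)
# print("mean %e std %e min %e max %e" % stats)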