├── allen_cahn
│   ├── ac_solution_dirichlet.pkl
│   └── CausalAllenCahnModel.py
├── README.md
├── utils.py
├── models.py
├── poisson_2d
│   └── PoissonModel.py
├── helmholtz_2d
│   └── HelmholtzModel.py
├── archs.py
└── LICENSE

/allen_cahn/ac_solution_dirichlet.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PredictiveIntelligenceLab/ActNet/HEAD/allen_cahn/ac_solution_dirichlet.pkl
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ActNet
2 | 
3 | Repository for some of the experiments presented in the paper "Deep Learning Alternatives of the Kolmogorov Superposition Theorem", accepted as a Spotlight Paper at ICLR 2025. (arXiv: https://arxiv.org/abs/2410.01990)
4 | 
5 | This code requires common libraries from the JAX ecosystem, such as Flax (for neural network design) and Optax/JaxOpt (for training and optimization). Plotting is done using Matplotlib.
6 | 
7 | Experiments comparing against the state of the art require integration with JaxPi, which is an open-source library. The code for those experiments can now be found on the ActNet branch of JaxPi: https://github.com/PredictiveIntelligenceLab/jaxpi/tree/ActNet
8 | 
9 | FILES:
10 | * archs.py : includes the architectures used in the paper, including JAX implementations of ActNet, KAN and Siren.
11 | * models.py : includes a training model for regression that can be used with any of the architectures.
12 | * utils.py : includes useful code for sampling batches.
13 | * poisson_2d/ : directory containing minimal code to run the 2D Poisson problem.
14 |   * PoissonModel.py : flexible training model for the 2D Poisson problem that can be used with any desired architecture.
15 |   * prediction_plots.ipynb : Jupyter notebook tutorial showing how to run the Poisson problem and produce plots.
16 | * helmholtz_2d/ : directory containing minimal code to run the 2D Helmholtz problem.
17 |   * HelmholtzModel.py : flexible training model for the 2D Helmholtz problem that can be used with any desired architecture.
18 |   * prediction_plots.ipynb : Jupyter notebook tutorial showing how to run the Helmholtz problem and produce plots.
19 | * allen_cahn/ : directory containing minimal code to run the Allen-Cahn problem.
20 |   * ac_solution_dirichlet.pkl : pickle file containing the reference solution for the Allen-Cahn problem, obtained using a classical solver.
21 |   * CausalAllenCahnModel.py : flexible training model for the Allen-Cahn problem that can be used with any desired architecture.
22 |   * prediction_plots.ipynb : Jupyter notebook tutorial showing how to run the Allen-Cahn problem and produce plots.
23 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from jax import random
4 | from jax import jit, vmap, grad
5 | 
6 | 
7 | from functools import partial
8 | import torch.utils.data as data
9 | 
10 | 
11 | # Dataset loader
12 | class BatchedDataset(data.Dataset):
13 |     ''' A data loader for creating mini-batches.
14 | 
15 |     Attributes:
16 |         raw_data: full dataset to be used. This should be a tuple of length 3,
17 |             formatted as (inputs, targets, weights).
18 |         key: a PRNG key used as the random key.
19 |         batch_size: the size of each mini-batch. If None, uses full batch
20 |             (default: None).
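
    Example (a minimal usage sketch; `X`, `y` and `w` stand for pre-existing
    jnp arrays with the same number of rows, at least `batch_size` of them):
        dataset = BatchedDataset((X, y, w), random.PRNGKey(0), batch_size=128)
        inputs, targets, weights = next(iter(dataset))  # one random mini-batch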
21 | ''' 22 | 23 | def __init__(self, raw_data, key, batch_size=None): 24 | super().__init__() 25 | self.inputs = raw_data[0] 26 | self.targets = raw_data[1] 27 | self.weights = raw_data[2] 28 | self.size = len(self.weights) 29 | self.key = key 30 | if batch_size is None: # Will use full batch 31 | self.batch_size = self.size 32 | else: 33 | self.batch_size = batch_size 34 | 35 | def __len__(self): 36 | return self.size 37 | 38 | def __getitem__(self, idx): 39 | self.key, subkey = random.split(self.key) 40 | batch_inputs, batch_targets, batched_weights = self.__select_batch(subkey) 41 | return batch_inputs, batch_targets, batched_weights 42 | 43 | @partial(jit, static_argnums=(0,)) 44 | def __select_batch(self, key): 45 | idx = random.choice(key, self.size, (self.batch_size,), replace=False) 46 | batch_inputs = self.inputs[idx] 47 | batch_targets = self.targets[idx] 48 | batched_weights = self.weights[idx] 49 | return batch_inputs, batch_targets, batched_weights 50 | 51 | 52 | class SquareDataset(data.Dataset): 53 | ''' A data loader for creating mini-batches of uniformly samples points on the 54 | inside and on the boundary of a [-1,1]^2 square. Generates a pair of vectors 55 | (interior_batch, border_batch) with iid points on the interior and border of 56 | squre, respectively. 57 | 58 | Attributes: 59 | key: a PRNG key used as the random key. 60 | batch_size: the size of each mini-batch. Should be a tuple of lenght 2 in 61 | the format (inside_batch_size, border_batch_size). 62 | ''' 63 | def __init__(self, key, batch_size=(10_000, 800)): 64 | super().__init__() 65 | self.size = batch_size[0] 66 | self.key = key 67 | self.batch_size = batch_size 68 | 69 | def __len__(self): 70 | return self.size 71 | 72 | def __getitem__(self, idx): 73 | self.key, subkey1, subkey2 = random.split(self.key, 3) 74 | interior_batch, border_batch = self.__select_batch(subkey1, subkey2) 75 | return interior_batch, border_batch 76 | 77 | @partial(jit, static_argnums=(0,)) 78 | def __select_batch(self, subkey1, subkey2): 79 | interior_batch = random.uniform(subkey1, shape=(self.batch_size[0], 2), 80 | minval=-1, maxval=1) 81 | border_batch = random.uniform(subkey2, shape=(self.batch_size[1],1), 82 | minval=-1, maxval=1) 83 | aux = jnp.split(border_batch, 4) 84 | border_batch = jnp.concatenate([ 85 | jnp.hstack([-jnp.ones_like(aux[0]), aux[0]]), 86 | jnp.hstack([jnp.ones_like(aux[1]), aux[1]]), 87 | jnp.hstack([aux[2], -jnp.ones_like(aux[2])]), 88 | jnp.hstack([aux[3], jnp.ones_like(aux[3])]), 89 | ], axis=0) 90 | return interior_batch, border_batch 91 | 92 | # alias for SquareDataset 93 | Poisson2DDataset = SquareDataset -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from jax import random 4 | from jax import jit, vmap, grad, value_and_grad 5 | 6 | import numpy as onp 7 | import optax 8 | from optax._src import linear_algebra 9 | import jaxopt 10 | 11 | from functools import partial 12 | import itertools 13 | from tqdm import trange 14 | import matplotlib.pyplot as plt 15 | 16 | 17 | 18 | # not used for the ActNet paper, but maybe useful for the future 19 | class RegressionModel: 20 | """ A model for training/evaluating a neural network using regression. 21 | 22 | Attributes: 23 | arch (nn.Module): a Flax module of the desired architecute. 
24 | batch: an initial batch used for initializing parameters and computing 25 | normalization factors. 26 | optimizer: the optimizer to be used when running gradient descent. 27 | normalize_inputs: whether to normalize inputs before passing them to the 28 | architecture (default: True). 29 | normalize_outputs: whether to normalize outputs of the architecture 30 | (default: True). 31 | key: a PRNG key used as the random key for initialization. 32 | steps_per_check (int): how many training steps to use between logging 33 | and displaying losses (default: 100). 34 | """ 35 | def __init__(self, arch, batch, optimizer=None, normalize_inputs=True, 36 | normalize_outputs=True, key=random.PRNGKey(43), 37 | steps_per_check=100) -> None: 38 | # Define model 39 | self.arch = arch 40 | self.key = key 41 | self.steps_per_check = steps_per_check 42 | 43 | # Initialize parameters 44 | inputs, outputs, _ = batch 45 | self.params = self.arch.init(self.key, inputs) 46 | 47 | # Tabulate function for checking network architecture 48 | self.tabulate = lambda : \ 49 | self.arch.tabulate(self.key, inputs, console_kwargs={'width':110}) 50 | 51 | # Vectorized functions 52 | self.normalize_inputs = normalize_inputs 53 | self.normalize_outputs = normalize_outputs 54 | self.normalize_data = (normalize_inputs or normalize_outputs) 55 | if self.normalize_data: 56 | mu_x = inputs.mean(0, keepdims=True) 57 | sig_x = inputs.std(0, keepdims=True) 58 | mu_y = outputs.mean(0, keepdims=True) 59 | sig_y = outputs.std(0, keepdims=True) 60 | self.norm_stats = ((mu_x, sig_x), (mu_y, sig_y)) 61 | if self.normalize_inputs: 62 | if self.normalize_outputs: 63 | self.apply = lambda params, x : \ 64 | mu_y + sig_y*self.arch.apply(params, 65 | (x-mu_x)/(sig_x + 0.01)) 66 | else: 67 | self.apply = lambda params, x : \ 68 | self.arch.apply(params, (x-mu_x)/(sig_x + 0.01)) 69 | else: 70 | self.apply = lambda params, x : \ 71 | mu_y + sig_y*self.arch.apply(params, x) 72 | 73 | else: 74 | self.norm_stats = None 75 | self.apply = self.arch.apply 76 | # jits apply function for numerical consistency (sometimes jitted 77 | # version behaves slightly differently than non-jitted one) 78 | self.apply = jit(self.apply) 79 | 80 | # Optimizer 81 | if optimizer is None: 82 | lr = optax.exponential_decay(1e-3, transition_steps=1000, 83 | decay_rate=0.9, end_value=1e-5) 84 | self.optimizer = optax.adam(learning_rate=lr) 85 | else: 86 | self.optimizer = optimizer 87 | self.opt_state = self.optimizer.init(self.params) 88 | 89 | # Optimizer LBFGS 90 | self.optimizer_lbfgs = jaxopt.LBFGS(self.loss) 91 | self.opt_state_lbfgs = self.optimizer_lbfgs.init_state(self.params, 92 | batch) 93 | self.optimizer_update_lbfgs = jit(self.optimizer_lbfgs.update) 94 | 95 | # Logger 96 | self.itercount = itertools.count() 97 | self.loss_log = [] 98 | self.grad_norm_log = [] 99 | 100 | def recon_loss(self, params, u, s, w): 101 | outputs = self.apply(params, u) # shape (batch_dim, out_dim) 102 | loss = jnp.mean(w*(s-outputs)**2, axis=(-1)) # shape (batch_dim,) 103 | return loss 104 | 105 | @partial(jit, static_argnums=(0,)) 106 | def loss(self, params, batch): 107 | inputs, targets, weights = batch 108 | u = inputs 109 | s = targets 110 | w = weights 111 | return self.recon_loss(params, u, s, w).mean() # scalar 112 | 113 | 114 | # Define a compiled update step 115 | @partial(jit, static_argnums=(0,)) 116 | def step(self, params, opt_state, batch): 117 | grads = grad(self.loss)(params, batch) 118 | updates, opt_state = self.optimizer.update(grads, opt_state, params) 119 | 
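        # optax optimizers are pure gradient transformations: `update` above only
        # turns the raw gradients into parameter increments and advances
        # `opt_state`; `apply_updates` on the next line is what actually adds
        # those increments to the current parameters.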
params = optax.apply_updates(params, updates) 120 | return params, opt_state, grads 121 | 122 | # Optimize parameters in a loop 123 | def train(self, dataset, nIter = 10000): 124 | """ Trains the neural network for nIter steps using data loader. 125 | 126 | Args: 127 | dataset (BatchedDataset): data loader for training. 128 | nIter (int): number of training iterations. 129 | """ 130 | data = iter(dataset) 131 | pbar = trange(nIter) 132 | # Main training loop 133 | for it in pbar: 134 | batch = next(data) 135 | self.params, self.opt_state, grads = self.step(self.params, 136 | self.opt_state, 137 | batch) 138 | # Logger 139 | if it % self.steps_per_check == 0: 140 | l = self.loss(self.params, batch) 141 | g_norm = linear_algebra.global_norm(grads).squeeze() 142 | self.loss_log.append(l) 143 | self.grad_norm_log.append(g_norm) 144 | pbar.set_postfix({ 145 | 'loss': l, 146 | 'grad_norm': jnp.mean(jnp.array(g_norm)) 147 | }) 148 | 149 | # Define a compiled update step 150 | @partial(jit, static_argnums=(0,)) 151 | def step_lbfgs(self, params, opt_state, batch): 152 | new_params, opt_state = self.optimizer_update_lbfgs(params, 153 | opt_state, 154 | batch) 155 | return new_params, opt_state 156 | 157 | # Optimize parameters in a loop 158 | def train_lbfgs(self, dataset, nIter = 10000): 159 | """ Trains the neural network using LBFGS optimizer for nIter steps 160 | using data loader. 161 | 162 | Args: 163 | dataset (BatchedDataset): data loader for training. 164 | nIter (int): number of training iterations. 165 | """ 166 | data = iter(dataset) 167 | pbar = trange(nIter) 168 | batch = next(data) 169 | self.opt_state_lbfgs = self.optimizer_lbfgs.init_state(self.params, 170 | batch) 171 | # Main training loop 172 | for it in pbar: 173 | batch = next(data) 174 | # Logger 175 | if it % self.steps_per_check == 0: 176 | l = self.loss(self.params, batch) 177 | self.loss_log.append(l) 178 | grads = grad(self.loss)(self.params, batch) 179 | g_norm = linear_algebra.global_norm(grads).squeeze() 180 | self.grad_norm_log.append(g_norm) 181 | pbar.set_postfix({ 182 | 'loss': l, 183 | 'grad_norm': jnp.mean(jnp.array(g_norm)) 184 | }) 185 | # optimization step 186 | self.params, self.opt_state_lbfgs = self.step_lbfgs(self.params, 187 | self.opt_state_lbfgs, 188 | batch) 189 | 190 | def plot_logs(self, window=None) -> None: 191 | """ Plots logs of training losses and gradient norms through training. 192 | 193 | Args: 194 | window: desired window for computing moving averages (default: None) 195 | """ 196 | plot_logs(self.loss_log, self.grad_norm_log, window=window, 197 | steps_per_check=self.steps_per_check) 198 | 199 | def batched_apply(self, x, batch_size=2_048): 200 | '''Performs forward pass using smaller batches, then concatenates them 201 | together before returning predictions. Useful for avoiding OoM issues 202 | when input is large. 203 | 204 | Args: 205 | x: input to the model 206 | batch_size: maximum batch size for computation. 
207 | 208 | Returns: 209 | predictions of the model on input x 210 | ''' 211 | num_batches = int(jnp.ceil(len(x) / batch_size)) 212 | x_batches = jnp.split(x, 213 | batch_size*(1+jnp.arange(num_batches-1)), 214 | axis=0) 215 | pred_fn = lambda ins : self.apply(self.params, ins) 216 | y_pred = jnp.concatenate([pred_fn(ins) for ins in x_batches], axis=0) 217 | return y_pred 218 | 219 | def get_rmse(self, batch, batch_size=2_048): 220 | # Create predictions 221 | u, s_true, _ = batch 222 | if batch_size is None: # single forward pass 223 | s_pred = self.apply(self.params, u) 224 | else: # breaks prediction into smaller forward passes 225 | s_pred = self.batched_apply(u, batch_size=batch_size) 226 | error = s_pred - s_true 227 | rmse = jnp.sqrt(jnp.mean(error**2)) 228 | return rmse 229 | 230 | def plot_predictions(self, batch, return_pred=False, batch_size=2_048): 231 | """Computes and plots model predictions for a given batch of data. 232 | 233 | Args: 234 | batch: data for creating/plotting results. 235 | return_pred: whether to return predictions after plotting 236 | (default: False). 237 | batch_size: batch size for computations (to avoid OoM issues in the 238 | case of large datasets). (default: 2048) 239 | """ 240 | # Create predictions 241 | u, s_true, _ = batch 242 | if batch_size is None: # single forward pass 243 | s_pred = self.apply(self.params, u) 244 | else: # breaks prediction into smaller forward passes 245 | s_pred = self.batched_apply(u, batch_size=batch_size) 246 | 247 | error = s_pred - s_true 248 | rel_l2_error = jnp.sqrt(jnp.sum(error**2)/jnp.sum(s_true**2)) 249 | print('Relative L2 error: {:.2e}'.format(rel_l2_error)) 250 | print('RMSE: {:.2e}'.format(jnp.sqrt(jnp.mean(error**2)))) 251 | 252 | if u.shape[-1]== 1: # domain is 1D 253 | plt.figure(figsize=(15, 4)) 254 | 255 | # Ploting examples of reconstructions 256 | plt.subplot(131) 257 | plt.plot(u, s_true, lw=1) 258 | plt.plot(u, s_pred, '--', lw=1) 259 | plt.xlabel('$y$') 260 | plt.ylabel('$s$') 261 | plt.title('Prediction Vs Truth (predictions are dashed)') 262 | 263 | # Ploting error 264 | plt.subplot(132) 265 | plt.plot(u, s_pred-s_true, lw=1) 266 | plt.xlabel('$y$') 267 | plt.ylabel('$s$') 268 | plt.title('Error') 269 | 270 | # plotting histogram of errors 271 | plt.subplot(133) 272 | plt.hist(error.flatten(), bins=50) 273 | plt.title('Histogram of errors') 274 | 275 | plt.show() 276 | elif u.shape[-1] == 2: # domain is 2D 277 | plt.figure(figsize=(15, 4)) 278 | 279 | # Ploting examples of reconstructions 280 | plt.subplot(131) 281 | plt.scatter(u[:,0], u[:,1], c=s_pred) 282 | plt.colorbar() 283 | plt.xlabel('$y$') 284 | plt.ylabel('$s$') 285 | plt.title('Prediction') 286 | 287 | # Ploting true solution 288 | plt.subplot(132) 289 | plt.scatter(u[:,0], u[:,1], c=s_true) 290 | plt.colorbar() 291 | plt.xlabel('$y$') 292 | plt.ylabel('$s$') 293 | plt.title('True') 294 | 295 | # Ploting errors 296 | plt.subplot(133) 297 | plt.scatter(u[:,0], u[:,1], c=s_pred-s_true) 298 | plt.colorbar() 299 | plt.xlabel('$y$') 300 | plt.ylabel('$s$') 301 | plt.title('Error') 302 | 303 | plt.show() 304 | else: # domain is higher than 2D. 
Plot histogram of errors instead 305 | # plotting histogram of errors 306 | plt.hist(error.flatten(), bins=50) 307 | plt.title('Histogram of errors') 308 | plt.show() 309 | 310 | if return_pred: 311 | return s_pred 312 | 313 | # alias for RegressionModel 314 | SupervisedModel = RegressionModel 315 | 316 | 317 | 318 | # Functions to help plotting 319 | def plot_logs(loss_log, grad_norm_log, window=None, steps_per_check=100): 320 | """ Plots logs of training losses and gradient norms through training. 321 | 322 | Args: 323 | loss_log: sequence of training losses. 324 | grad_norm_log: sequence of parameter gradient norms. 325 | window: desired window for computing moving averages (default: None). 326 | steps_per_check: how many training steps were taken between each log. 327 | """ 328 | plt.figure(figsize=(12, 4)) 329 | 330 | # Plotting losses 331 | plt.subplot(121) 332 | if window is None: 333 | plt.plot(steps_per_check*jnp.arange(len(loss_log)), loss_log) 334 | else: 335 | assert type(window) is int , f'window must be an interger or None, not {type(window)}' 336 | plt.plot(steps_per_check*jnp.arange(len(loss_log) - window), 337 | [onp.mean(loss_log[i:i+window]) for i in range(len(loss_log) - window)]) 338 | plt.yscale('log') 339 | plt.title('Loss through iterations') 340 | 341 | # Plotting gradient norms 342 | plt.subplot(122) 343 | if window is None: 344 | plt.plot(steps_per_check*jnp.arange(len(grad_norm_log)), grad_norm_log) 345 | else: 346 | assert type(window) is int , f'window must be an interger or None, not {type(window)}' 347 | plt.plot(steps_per_check*jnp.arange(len(grad_norm_log) - window), 348 | [onp.mean(grad_norm_log[i:i+window]) for i in range(len(grad_norm_log) - window)]) 349 | plt.yscale('log') 350 | plt.title('Global gradient norm through iterations') 351 | plt.show() -------------------------------------------------------------------------------- /allen_cahn/CausalAllenCahnModel.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from jax import random 4 | from jax import vmap, jit, grad, value_and_grad 5 | 6 | import optax 7 | from optax._src import linear_algebra 8 | 9 | from functools import partial 10 | import itertools 11 | from tqdm import trange 12 | import torch.utils.data as data 13 | 14 | import matplotlib.pyplot as plt 15 | 16 | import sys 17 | sys.path.append('..') # makes modules in parent repository available to import 18 | from models import plot_logs 19 | 20 | 21 | class SquareDataset(data.Dataset): 22 | ''' A data loader for creating mini-batches of uniformly samples points on the 23 | inside of a square/rectangle. Generates iid points on the interior of the 24 | square/rectangle. These points are returned ordered by time (earliest first). 25 | 26 | Attributes: 27 | key: a PRNG key used as the random key. 28 | minvals: pair indicating minimum values for each dimension. 29 | minvals: pair indicating maximum values for each dimension. 30 | batch_size: the size of each mini-batch. 
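
    Example (a minimal usage sketch; for the Allen-Cahn setup, x is assumed to
    lie in [-1, 1] and t in [0, 1], with t as the last coordinate):
        dataset = SquareDataset(random.PRNGKey(0), minvals=(-1, 0),
                                maxvals=(1, 1), batch_size=4_096)
        xt_batch = next(iter(dataset))  # shape (4096, 2), sorted by increasing t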
31 | ''' 32 | def __init__(self, key, minvals=(-1,-1), maxvals=(1,1), batch_size=10_000): 33 | super().__init__() 34 | self.minvals = jnp.array(minvals) 35 | self.maxvals = jnp.array(maxvals) 36 | self.size = batch_size 37 | self.key = key 38 | self.batch_size = batch_size 39 | 40 | def __len__(self): 41 | return self.size 42 | 43 | def __getitem__(self, idx): 44 | self.key, subkey = random.split(self.key) 45 | interior_batch = self.__select_batch(subkey) 46 | return interior_batch 47 | 48 | @partial(jit, static_argnums=(0,)) 49 | def __select_batch(self, subkey): 50 | interior_batch = random.uniform(subkey, shape=(self.batch_size, 2), 51 | minval=self.minvals, maxval=self.maxvals) 52 | # time is last coordinate 53 | idx_sort = jnp.argsort(interior_batch[:,-1]) 54 | # returns batch with times ordered in increasing time 55 | return interior_batch[idx_sort] 56 | 57 | 58 | def get_ac_residual_fn(f, D=1e-4): 59 | ''' 60 | Computes the Allen-Cahn (AC) residual of a given funciton in a 61 | computationally efficient way by avoiding unecessary repetition of the 62 | computational graph. 63 | 64 | Args: 65 | f (Callable): a function with signature (params, x, t)->scalar 66 | D (float): the D constant parameter for the AC equation. 67 | Returns: 68 | (Callable): a function with signature (params, x, t)->scalar that is the 69 | AC residual of input function 70 | ''' 71 | _single_f = lambda params, x, t : f(params, x[None,:], t[None,:]).squeeze() 72 | def _scalar_grads(params, x, t): 73 | u, (ux, ut) = value_and_grad(_single_f, argnums=(1,2))(params, x, t) 74 | return ux.squeeze(), (u, ut.squeeze()) 75 | def _ac_aux(params, x, t): 76 | (u_x, (u, u_t)), u_xx = value_and_grad(_scalar_grads, 77 | argnums=1, 78 | has_aux=True)(params, x, t) 79 | return u, u_t, u_xx.squeeze() 80 | def _res_fn(params, xs, ts): 81 | u, ut, uxx = vmap(_ac_aux, in_axes=(None, 0, 0))(params, xs, ts) 82 | return ut - D*uxx + 5*(u**3 - u) 83 | return _res_fn 84 | 85 | 86 | 87 | class AllenCahnModel: 88 | """ A model for training/evaluating a neural network using physics-informed 89 | causal training on the Allen-Cahn problem. 90 | 91 | Attributes: 92 | arch (nn.Module): a Flax module of the desired architecute. 93 | batch: an initial batch used for initializing parameters and computing 94 | normalization factors. 95 | true_fun: the true ground_truth function, if known (default: None). 96 | D: the D constant parameter for the AC equation. 97 | optimizer: the optimizer to be used when running gradient descent. 98 | normalize_inputs: whether to normalize inputs before passing them to the 99 | architecture (default: True). 100 | key: a PRNG key used as the random key for initialization. 101 | exact_bd_condition: whether to exactly enforce boundary condition on 102 | [-1,1]^2 (see https://arxiv.org/abs/2104.08426). This implicitly 103 | makes it so that border losses are not computed (defaut: True). 104 | bdr_enforcer_order: order of the polynomial used for enforcing border 105 | exactly. Should be an even integer (default: 2). 106 | steps_per_check (int): how many training steps to use between logging 107 | and displaying losses (default: 50). 
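
    Example (a minimal usage sketch; `arch` stands for any Flax module mapping a
    (N, 2) batch of (x, t) points to a (N, 1) output, e.g. one of the
    architectures in archs.py; the causal_eps value is illustrative):
        dataset = SquareDataset(random.PRNGKey(0), minvals=(-1, 0), maxvals=(1, 1))
        model = AllenCahnModel(arch, next(iter(dataset)), D=1e-4)
        model.train(dataset, nIter=10_000, causal_eps=1.0)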
108 | """ 109 | def __init__(self, arch, batch, true_sol=None, D=1e-4, 110 | optimizer=None, normalize_inputs=True, key=random.PRNGKey(43), 111 | exact_bd_condition=True, bdr_enforcer_order=2, 112 | steps_per_check=50) -> None: 113 | # Define model 114 | self.arch = arch 115 | self.key = key 116 | self.steps_per_check = steps_per_check 117 | self.true_sol = true_sol 118 | self.D = D 119 | 120 | # Initialize parameters 121 | interior_batch = batch 122 | self.params = self.arch.init(self.key, interior_batch) 123 | 124 | # Tabulate function for checking network architecture 125 | self.tabulate = lambda : \ 126 | self.arch.tabulate(self.key, 127 | interior_batch, 128 | console_kwargs={'width':110}) 129 | 130 | # Vectorized functions 131 | self.normalize_inputs = normalize_inputs 132 | self.exact_bd_condition = exact_bd_condition 133 | self.bdr_enforcer_order = bdr_enforcer_order # should be an even number 134 | if normalize_inputs: 135 | mu_x = jnp.hstack(interior_batch).mean(0, keepdims=True) 136 | sig_x = jnp.hstack(interior_batch).std(0, keepdims=True) 137 | self.norm_stats = (mu_x, sig_x) 138 | _apply = lambda params, x, y : \ 139 | self.arch.apply(params, (jnp.hstack([x, y])-mu_x)/sig_x) 140 | if self.exact_bd_condition: 141 | _apply = lambda params, x, t : \ 142 | (1-t)*(x**2 * jnp.cos(jnp.pi*x)) \ 143 | + t*((1-x**self.bdr_enforcer_order)*self.arch.apply(params, 144 | (jnp.hstack([x, t])-mu_x)/sig_x) - 1) 145 | else: 146 | _apply = lambda params, x, y : \ 147 | self.arch.apply(params, (jnp.hstack([x, y])-mu_x)/sig_x) 148 | else: 149 | self.norm_stats = None 150 | if self.exact_bd_condition: 151 | _apply = lambda params, x, t : \ 152 | (1-t)*(x**2 * jnp.cos(jnp.pi*x)) \ 153 | + t*((1-x**self.bdr_enforcer_order)*self.arch.apply(params, 154 | jnp.hstack([x, t])) - 1) 155 | else: 156 | _apply = lambda params, x, y : \ 157 | self.arch.apply(params, jnp.hstack([x, y])) 158 | # jits apply function for numerical consistency (sometimes jitted 159 | # version behaves slightly differently than non-jitted one) 160 | self.apply = jit(_apply) 161 | 162 | # Vectorized derivatives. 
163 | # functions prefixed by '_single' take in a vector of shape (1,) and 164 | # output a scalar of shape (,) 165 | _single_f = lambda params, x, y : \ 166 | self.apply(params, x[None,:], y[None,:]).squeeze() 167 | # x derivatives 168 | _single_f_x = lambda params, x, y : \ 169 | grad(_single_f, argnums=1)(params, x, y).squeeze() # scalar 170 | self.f_x = vmap(_single_f_x, in_axes=(None, 0, 0)) 171 | _single_f_xx = lambda params, x, y : \ 172 | grad(_single_f_x, argnums=1)(params, x, y).squeeze() # scalar 173 | self.f_xx = vmap(_single_f_xx, in_axes=(None, 0, 0)) 174 | # y derivatives 175 | _single_f_y = lambda params, x, y : \ 176 | grad(_single_f, argnums=2)(params, x, y).squeeze() # scalar 177 | self.f_y = vmap(_single_f_y, in_axes=(None, 0, 0)) 178 | _single_f_yy = lambda params, x, y : \ 179 | grad(_single_f_y, argnums=2)(params, x, y).squeeze() # scalar 180 | self.f_yy = vmap(_single_f_yy, in_axes=(None, 0, 0)) 181 | # laplacian 182 | self.ac_residual = get_ac_residual_fn(self.apply, D=self.D) 183 | 184 | # Optimizer 185 | if optimizer is None: # use a standard optimizer 186 | lr = optax.exponential_decay(1e-3, transition_steps=1000, 187 | decay_rate=0.8, end_value=1e-7) 188 | self.optimizer = optax.chain( 189 | optax.adaptive_grad_clip(1e-2), 190 | optax.adam(learning_rate=lr), 191 | ) 192 | else: 193 | self.optimizer = optimizer 194 | self.opt_state = self.optimizer.init(self.params) 195 | 196 | # Logger 197 | self.itercount = itertools.count() 198 | self.loss_log = [] 199 | self.grad_norm_log = [] 200 | self.rel_l2_log = [] 201 | 202 | def residual_loss(self, params, x, t, causal_eps): 203 | res = self.ac_residual(params, x, t)[:,None] # shape (batch_dim,1) 204 | goal = jnp.zeros_like(res) 205 | res = jnp.mean((res-goal)**2, axis=-1) # shape (batch_dim,) 206 | if causal_eps is None: # no causal learning 207 | return res 208 | else: 209 | # compute causal weights 210 | ws = jax.lax.stop_gradient(jnp.exp(-causal_eps*(jnp.cumsum(res) - res))) # shape (num_ts,) 211 | # make it so that mean value of weights is 1 to maintain loss in the 212 | # same order of magnitude 213 | ws = ws/(ws.mean()+1e-3) 214 | assert ws.shape == res.shape, f"ws is shape {ws.shape} but res is shape {res.shape}" 215 | return ws*res 216 | 217 | 218 | 219 | def pinn_loss(self, params, interior_batch, causal_eps): 220 | r_loss = self.residual_loss(params, 221 | interior_batch[:,0][:,None], 222 | interior_batch[:,1][:,None], 223 | causal_eps) 224 | if self.exact_bd_condition: 225 | # no need to consider border loss, since it will be 0 when bdry 226 | # condition is exactly enforced 227 | return r_loss.mean() 228 | else: 229 | raise NotImplementedError 230 | # consider both residual loss initial condition loss and boundary condition loss 231 | #b_loss = self.border_loss(params, border_batch[:,0][:,None], border_batch[:,1][:,None]) 232 | #return self.pinn_weights[0]*r_loss.mean() + self.pinn_weights[1]*b_loss.mean() 233 | 234 | 235 | @partial(jit, static_argnums=(0,)) 236 | def loss(self, params, batch, causal_eps): 237 | interior_batch = batch 238 | return self.pinn_loss(params, interior_batch, causal_eps).mean() # scalar 239 | 240 | 241 | # Define a compiled update step 242 | @partial(jit, static_argnums=(0,)) 243 | def step(self, params, opt_state, batch, causal_eps): 244 | grads = grad(self.loss)(params, batch, causal_eps) 245 | updates, opt_state = self.optimizer.update(grads, opt_state, params) 246 | params = optax.apply_updates(params, updates) 247 | return params, opt_state, grads 248 | 249 | # Optimize 
parameters in a loop 250 | def train(self, dataset, nIter = 10_000, causal_eps=None): 251 | """ Trains the neural network for nIter steps using data loader. 252 | 253 | Args: 254 | dataset (SquareDataset): data loader for training. 255 | nIter (int): number of training iterations. 256 | causal_eps (None | float): epsilon for computing causal loss. If 257 | None, does not use causal learning (default: None). 258 | """ 259 | data = iter(dataset) 260 | pbar = trange(nIter) 261 | # Main training loop 262 | for it in pbar: 263 | batch = next(data) 264 | self.params, self.opt_state, grads = self.step(self.params, 265 | self.opt_state, 266 | batch, 267 | causal_eps) 268 | # Logger 269 | if it % self.steps_per_check == 0: 270 | l = self.loss(self.params, batch, causal_eps) 271 | g_norm = linear_algebra.global_norm(grads).squeeze() 272 | self.loss_log.append(l) 273 | self.grad_norm_log.append(g_norm) 274 | if self.true_sol is not None: 275 | pred = self.apply(self.params, 276 | self.true_sol[0][0], 277 | self.true_sol[0][1]) 278 | true = self.true_sol[1] 279 | rel_l2_error = jnp.sqrt(((pred-true)**2).mean() \ 280 | / ((true)**2).mean()) 281 | self.rel_l2_log.append(rel_l2_error) 282 | pbar.set_postfix_str(f"loss:{l : .3e}, rel_l2:{rel_l2_error : .2e}, 'grad_norm':{jnp.mean(jnp.array(g_norm)) : .2e}") 283 | else: 284 | pbar.set_postfix({ 285 | 'loss': l, 286 | 'grad_norm': jnp.mean(jnp.array(g_norm)), 287 | }) 288 | 289 | def plot_logs(self, window=None) -> None: 290 | """ Plots logs of training losses and gradient norms through training. 291 | 292 | Args: 293 | window: desired window for computing moving averages (default: None). 294 | """ 295 | plot_logs(self.loss_log, self.grad_norm_log, window=window, 296 | steps_per_check=self.steps_per_check) 297 | 298 | def batched_apply(self, x, batch_size=2_048): 299 | '''Performs forward pass using smaller batches, then concatenates them 300 | together before returning predictions. Useful for avoiding OoM issues 301 | when input is large. 302 | 303 | Args: 304 | x: input to the model 305 | batch_size: maximum batch size for computation. 306 | 307 | Returns: 308 | predictions of the model on input x 309 | ''' 310 | num_batches = int(jnp.ceil(len(x) / batch_size)) 311 | x_batches = jnp.split(x, 312 | batch_size*(1+jnp.arange(num_batches-1)), 313 | axis=0) 314 | pred_fn = jit(lambda ins : \ 315 | self.apply(self.params, 316 | ins[:,0][:,None], 317 | ins[:,1][:,None])) 318 | y_pred = jnp.concatenate([pred_fn(ins) for ins in x_batches], axis=0) 319 | return y_pred 320 | 321 | def get_rmse(self, batch, batch_size=2_048): 322 | # Create predictions 323 | u, s_true = batch 324 | if batch_size is None: # single forward pass 325 | s_pred = self.apply(self.params, u) 326 | else: # breaks prediction into smaller forward passes 327 | s_pred = self.batched_apply(u, batch_size=batch_size) 328 | error = s_pred - s_true 329 | rmse = jnp.sqrt(jnp.mean(error**2)) 330 | return rmse 331 | 332 | def plot_predictions(self, batch, return_pred=False, batch_size=2_048, 333 | num_levels = 500): 334 | """Computes and plots model predictions for a given batch of data. 335 | 336 | Args: 337 | batch: data for creating/plotting results. 338 | return_pred: whether to return predictions after plotting 339 | (default: False). 340 | batch_size: batch size for computations (to avoid OoM issues in the 341 | case of large datasets). (default: 2048) 342 | num_levels: number of levels for contour plot (default: 500). 
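
        Example (illustrative sketch; `coords` is assumed to be an (N, 2) array
        of (x, t) points and `u_ref` the matching reference solution, e.g. as
        loaded from ac_solution_dirichlet.pkl):
            model.plot_predictions((coords, u_ref), num_levels=200)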
343 | """ 344 | # Create predictions 345 | u, s_true = batch 346 | if batch_size is None: # single forward pass 347 | s_pred = self.apply(self.params, u) 348 | else: # breaks prediction into smaller forward passes 349 | s_pred = self.batched_apply(u, batch_size=batch_size) 350 | 351 | error = s_pred - s_true 352 | rel_l2_error = jnp.sqrt(jnp.sum(error**2)/jnp.sum(s_true**2)) 353 | print('Relative L2 error: {:.2e}'.format(rel_l2_error)) 354 | print('RMSE: {:.2e}'.format(jnp.sqrt(jnp.mean(error**2)))) 355 | 356 | plt.figure(figsize=(16, 4)) 357 | 358 | # Ploting examples of reconstructions 359 | plt.subplot(131) 360 | plt.tricontourf(u[:,1], u[:,0], 361 | s_pred.T.squeeze(), levels=num_levels, cmap='jet') 362 | plt.colorbar() 363 | plt.xlabel('$t$') 364 | plt.ylabel('$x$') 365 | plt.title('Prediction') 366 | 367 | # Ploting true solution 368 | plt.subplot(132) 369 | plt.tricontourf(u[:,1], u[:,0], 370 | s_true.T.squeeze(), levels=num_levels, cmap='jet') 371 | plt.colorbar() 372 | plt.xlabel('$t$') 373 | plt.ylabel('$x$') 374 | plt.title('True') 375 | 376 | # Ploting absolute 377 | plt.subplot(133) 378 | plt.tricontourf(u[:,1], u[:,0], 379 | abs(s_pred-s_true).T.squeeze(), 380 | levels=num_levels, cmap='jet') 381 | plt.colorbar() 382 | plt.xlabel('$t$') 383 | plt.ylabel('$x$') 384 | plt.title('Absolute Error') 385 | 386 | plt.show() 387 | 388 | plt.show() 389 | 390 | if return_pred: 391 | return s_pred -------------------------------------------------------------------------------- /poisson_2d/PoissonModel.py: -------------------------------------------------------------------------------- 1 | # Basic Library Importsk 2 | import jax 3 | import jax.numpy as jnp 4 | from jax import random 5 | from jax import vmap, jit, grad, value_and_grad 6 | 7 | import optax 8 | from optax._src import linear_algebra 9 | import jaxopt 10 | 11 | import matplotlib.pyplot as plt 12 | 13 | from functools import partial 14 | import itertools 15 | from tqdm import trange 16 | 17 | import sys 18 | sys.path.append('..') # makes modules in parent repository available to import 19 | from models import plot_logs 20 | 21 | 22 | 23 | def get_laplacian(f): 24 | '''Computes the 2D laplacian of a given funciton in a computationally 25 | efficient way by avoiding unecessary repetition of the computational graph. 26 | 27 | Args: 28 | f (Callable): a function with signature (params, x, y)->scalar 29 | Returns: 30 | (Callable): a function with signature (params, x, y)->scalar that is the 31 | laplacian of input function 32 | ''' 33 | _single_f = lambda params, x, y : f(params, x[None,:], y[None,:]).squeeze() 34 | def _scalar_grads(params, x, y): 35 | ux, uy = grad(_single_f, argnums=(1,2))(params, x, y) 36 | return ux.squeeze(), uy.squeeze() 37 | def _lapl_aux(params, x, y): 38 | (u_x, u_y), u_xx = value_and_grad(_scalar_grads, 39 | argnums=1, 40 | has_aux=True)(params, x, y) 41 | return u_y, u_xx.squeeze() 42 | def _lapl_aux_2(params, x, y): 43 | (u_y, u_xx), u_yy = value_and_grad(_lapl_aux, 44 | argnums=2, 45 | has_aux=True)(params, x, y) 46 | return u_xx, u_yy.squeeze() 47 | def _lapl(params, xs, ys): 48 | uxx, uyy = vmap(_lapl_aux_2, in_axes=(None, 0, 0))(params, xs, ys) 49 | return uxx + uyy 50 | return _lapl 51 | 52 | class PoissonModel: 53 | """ A model for training/evaluating a neural network using physics-informed 54 | training on the 2D Poisson problem. 55 | 56 | Attributes: 57 | arch (nn.Module): a Flax module of the desired architecute. 
58 | batch: an initial batch used for initializing parameters and computing 59 | normalization factors. 60 | true_fun: the true ground_truth function, if known (default: None). 61 | pde_res_fn: the target function for the PDE differential operator. 62 | optimizer: the optimizer to be used when running gradient descent. 63 | normalize_inputs: whether to normalize inputs before passing them to the 64 | architecture (default: True). 65 | key: a PRNG key used as the random key for initialization. 66 | pinn_weights: pair of weights for balacing residual/border losses. 67 | exact_bd_condition: whether to exactly enforce boundary condition on 68 | [-1,1]^2 (see https://arxiv.org/abs/2104.08426). This implicitly 69 | makes it so that border losses are not computed (defaut: True). 70 | bdr_enforcer_order: order of the polynomial used for enforcing border 71 | exactly. Should be an even integer (default: 2). 72 | steps_per_check (int): how many training steps to use between logging 73 | and displaying losses (default: 50). 74 | """ 75 | def __init__(self, arch, batch, true_fun=None, 76 | pde_res_fn=lambda x, y : \ 77 | -(jnp.pi**2)*(1+4*(y**2))*jnp.sin(jnp.pi*x)*jnp.sin(jnp.pi*(y**2)) \ 78 | + 2*jnp.pi*jnp.sin(jnp.pi*x)*jnp.cos(jnp.pi*(y**2)), 79 | optimizer=None, normalize_inputs=True, key=random.PRNGKey(43), 80 | pinn_weights=(0.001, 1.), exact_bd_condition=True, 81 | bdr_enforcer_order=2, steps_per_check=50) -> None: 82 | # Define model 83 | self.arch = arch 84 | self.key = key 85 | self.steps_per_check = steps_per_check 86 | self.pde_res_fn = pde_res_fn 87 | self.true_fun = true_fun 88 | self.pinn_weights = pinn_weights # not really used in the experiments (boundary is enforced exactly) 89 | 90 | # Initialize parameters 91 | interior_batch, border_batch = batch 92 | self.params = self.arch.init(self.key, interior_batch) 93 | 94 | # Tabulate function for checking network architecture 95 | self.tabulate = lambda : \ 96 | self.arch.tabulate(self.key, 97 | interior_batch, 98 | console_kwargs={'width':110}) 99 | 100 | # Vectorized functions 101 | self.normalize_inputs = normalize_inputs 102 | self.exact_bd_condition = exact_bd_condition 103 | self.bdr_enforcer_order = bdr_enforcer_order # should be an even number 104 | if normalize_inputs: 105 | mu_x = jnp.hstack(interior_batch).mean(0, keepdims=True) 106 | sig_x = jnp.hstack(interior_batch).std(0, keepdims=True) 107 | self.norm_stats = (mu_x, sig_x) 108 | _apply = lambda params, x, y : \ 109 | self.arch.apply(params, (jnp.hstack([x, y])-mu_x)/sig_x) 110 | if self.exact_bd_condition: 111 | _apply = lambda params, x, y : \ 112 | (1-x**self.bdr_enforcer_order)\ 113 | *(1-y**self.bdr_enforcer_order)\ 114 | *self.arch.apply(params, (jnp.hstack([x, y])-mu_x)/sig_x) 115 | else: 116 | _apply = lambda params, x, y : \ 117 | self.arch.apply(params, (jnp.hstack([x, y])-mu_x)/sig_x) 118 | else: 119 | self.norm_stats = None 120 | if self.exact_bd_condition: 121 | _apply = lambda params, x, y : \ 122 | (1-x**self.bdr_enforcer_order)\ 123 | *(1-y**self.bdr_enforcer_order)\ 124 | *self.arch.apply(params, jnp.hstack([x, y])) 125 | else: 126 | _apply = lambda params, x, y : \ 127 | self.arch.apply(params, jnp.hstack([x, y])) 128 | # jits apply function for numerical consistency (sometimes jitted 129 | # version behaves slightly differently than non-jitted one) 130 | self.apply = jit(_apply) 131 | 132 | # Vectorized derivatives. 
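        # (jax.grad only differentiates scalar-valued functions, so each helper
        # below first wraps `apply` into a function of single (1,)-shaped inputs,
        # nests grad calls to obtain second derivatives, and only then vmaps the
        # result over the batch dimension.)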
133 | # functions prefixed by '_single' take in a vector of shape (1,) and 134 | # output a scalar of shape (,) 135 | _single_f = lambda params, x, y : \ 136 | self.apply(params, x[None,:], y[None,:]).squeeze() 137 | # x derivatives 138 | _single_f_x = lambda params, x, y : \ 139 | grad(_single_f, argnums=1)(params, x, y).squeeze() # scalar 140 | self.f_x = vmap(_single_f_x, in_axes=(None, 0, 0)) 141 | _single_f_xx = lambda params, x, y : \ 142 | grad(_single_f_x, argnums=1)(params, x, y).squeeze() # scalar 143 | self.f_xx = vmap(_single_f_xx, in_axes=(None, 0, 0)) 144 | # y derivatives 145 | _single_f_y = lambda params, x, y : \ 146 | grad(_single_f, argnums=2)(params, x, y).squeeze() # scalar 147 | self.f_y = vmap(_single_f_y, in_axes=(None, 0, 0)) 148 | _single_f_yy = lambda params, x, y : \ 149 | grad(_single_f_y, argnums=2)(params, x, y).squeeze() # scalar 150 | self.f_yy = vmap(_single_f_yy, in_axes=(None, 0, 0)) 151 | # laplacian 152 | self.laplacian = get_laplacian(self.apply) 153 | 154 | # Optimizer 155 | if optimizer is None: # use a standard optimizer 156 | lr = optax.exponential_decay(1e-3, transition_steps=1000, 157 | decay_rate=0.8, end_value=1e-7) 158 | self.optimizer = optax.chain( 159 | optax.adaptive_grad_clip(1e-2), 160 | optax.adam(learning_rate=lr), 161 | ) 162 | else: 163 | self.optimizer = optimizer 164 | self.opt_state = self.optimizer.init(self.params) 165 | 166 | # Optimizer LBFGS 167 | self.optimizer_lbfgs = jaxopt.LBFGS(self.loss) 168 | self.opt_state_lbfgs = self.optimizer_lbfgs.init_state(self.params, 169 | batch) 170 | self.optimizer_update_lbfgs = jit(self.optimizer_lbfgs.update) 171 | 172 | # Logger 173 | self.itercount = itertools.count() 174 | self.loss_log = [] 175 | self.grad_norm_log = [] 176 | self.rel_l2_log = [] 177 | 178 | def residual_loss(self, params, x, y): 179 | res = self.laplacian(params, x, y)[:,None] # shape (batch_dim,1) 180 | goal = self.pde_res_fn(x, y) # shape (batch_dim,1) 181 | return jnp.mean((res-goal)**2, axis=-1) # shape (batch_dim,) 182 | 183 | def border_loss(self, params, x, y): 184 | # function should be zero at the boundary 185 | outputs = self.apply(params, x, y) # shape (batch_dim, out_dim) 186 | return jnp.mean(outputs**2, axis=-1) # shape (batch_dim,) 187 | 188 | def pinn_loss(self, params, interior_batch, border_batch): 189 | r_loss = self.residual_loss(params, 190 | interior_batch[:,0][:,None], 191 | interior_batch[:,1][:,None]) 192 | if self.exact_bd_condition: 193 | # no need to consider border loss, since it will be 0 when bdry 194 | # condition is exactly enforced 195 | return self.pinn_weights[0]*r_loss.mean() 196 | else: 197 | # consider both residual loss and boundary condition loss 198 | b_loss = self.border_loss(params, 199 | border_batch[:,0][:,None], 200 | border_batch[:,1][:,None]) 201 | return self.pinn_weights[0]*r_loss.mean() \ 202 | + self.pinn_weights[1]*b_loss.mean() 203 | 204 | 205 | @partial(jit, static_argnums=(0,)) 206 | def loss(self, params, batch): 207 | interior_batch, border_batch = batch 208 | return self.pinn_loss(params, interior_batch, border_batch).mean() # scalar 209 | 210 | 211 | # Define a compiled update step 212 | @partial(jit, static_argnums=(0,)) 213 | def step(self, params, opt_state, batch): 214 | grads = grad(self.loss)(params, batch) 215 | updates, opt_state = self.optimizer.update(grads, opt_state, params) 216 | params = optax.apply_updates(params, updates) 217 | return params, opt_state, grads 218 | 219 | # Optimize parameters in a loop 220 | def train(self, dataset, nIter = 
10_000): 221 | """ Trains the neural network for nIter steps using data loader. 222 | 223 | Args: 224 | dataset (SquareDataset): data loader for training. 225 | nIter (int): number of training iterations. 226 | """ 227 | data = iter(dataset) 228 | pbar = trange(nIter) 229 | # Main training loop 230 | for it in pbar: 231 | batch = next(data) 232 | self.params, self.opt_state, grads = self.step(self.params, 233 | self.opt_state, 234 | batch) 235 | # Logger 236 | if it % self.steps_per_check == 0: 237 | l = self.loss(self.params, batch) 238 | g_norm = linear_algebra.global_norm(grads).squeeze() 239 | self.loss_log.append(l) 240 | self.grad_norm_log.append(g_norm) 241 | if self.true_fun is not None: # true function is known 242 | interior_batch, border_batch = batch 243 | pred = self.apply(self.params, 244 | interior_batch[:,0][:,None], 245 | interior_batch[:,1][:,None]) 246 | true = self.true_fun(interior_batch[:,0][:,None], 247 | interior_batch[:,1][:,None]) 248 | rel_l2_error = jnp.sqrt(((pred-true)**2).mean() \ 249 | / ((true)**2).mean()) 250 | self.rel_l2_log.append(rel_l2_error) 251 | pbar.set_postfix_str( 252 | f"loss:{l : .3e}, rel_l2:{rel_l2_error : .2e}, 'grad_norm':{jnp.mean(jnp.array(g_norm)) : .2e}") 253 | else: # true function is unknown 254 | pbar.set_postfix({ 255 | 'loss': l, 256 | 'grad_norm': jnp.mean(jnp.array(g_norm))}) 257 | 258 | # Define a compiled update step 259 | @partial(jit, static_argnums=(0,)) 260 | def step_lbfgs(self, params, opt_state, batch): 261 | new_params, opt_state = self.optimizer_update_lbfgs(params, 262 | opt_state, 263 | batch) 264 | return new_params, opt_state 265 | 266 | # Optimize parameters in a loop 267 | def train_lbfgs(self, dataset, nIter = 10000): 268 | """ Trains the neural network using LBFGS optimizer for nIter steps 269 | using data loader. 270 | 271 | Args: 272 | dataset (SquareDataset): data loader for training. 273 | nIter (int): number of training iterations. 274 | """ 275 | data = iter(dataset) 276 | pbar = trange(nIter) 277 | batch = next(data) 278 | self.opt_state_lbfgs = self.optimizer_lbfgs.init_state(self.params, 279 | batch) 280 | # Main training loop 281 | for it in pbar: 282 | batch = next(data) 283 | # Logger 284 | if it % self.steps_per_check == 0: 285 | l = self.loss(self.params, batch) 286 | self.loss_log.append(l) 287 | grads = grad(self.loss)(self.params, batch) 288 | g_norm = linear_algebra.global_norm(grads).squeeze() 289 | self.grad_norm_log.append(g_norm) 290 | if self.true_fun is not None: 291 | interior_batch, border_batch = batch 292 | pred = self.apply(self.params, interior_batch[:,0][:,None], 293 | interior_batch[:,1][:,None]) 294 | true = self.true_fun(interior_batch[:,0][:,None], 295 | interior_batch[:,1][:,None]) 296 | rel_l2_error = jnp.sqrt(((pred-true)**2).mean() \ 297 | / ((true)**2).mean()) 298 | self.rel_l2_log.append(rel_l2_error) 299 | pbar.set_postfix_str( 300 | f"loss:{l : .3e}, rel_l2:{rel_l2_error : .2e}, 'grad_norm':{jnp.mean(jnp.array(g_norm)) : .2e}") 301 | else: 302 | pbar.set_postfix({ 303 | 'loss': l, 304 | 'grad_norm': jnp.mean(jnp.array(g_norm))}) 305 | # take step 306 | self.params, self.opt_state_lbfgs = self.step_lbfgs(self.params, 307 | self.opt_state_lbfgs, 308 | batch) 309 | 310 | 311 | def plot_logs(self, window=None) -> None: 312 | """ Plots logs of training losses and gradient norms through training. 313 | 314 | Args: 315 | window: desired window for computing moving averages (default: None). 
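
        Example (illustrative sketch):
            model.plot_logs(window=20)  # smooth both curves with a 20-point moving average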
316 | """ 317 | plot_logs(self.loss_log, self.grad_norm_log, window=window, 318 | steps_per_check=self.steps_per_check) 319 | 320 | def batched_apply(self, x, batch_size=2_048): 321 | '''Performs forward pass using smaller batches, then concatenates them 322 | together before returning predictions. Useful for avoiding OoM issues 323 | when input is large. 324 | 325 | Args: 326 | x: input to the model 327 | batch_size: maximum batch size for computation. 328 | 329 | Returns: 330 | predictions of the model on input x 331 | ''' 332 | num_batches = int(jnp.ceil(len(x) / batch_size)) 333 | x_batches = jnp.split(x, 334 | batch_size*(1+jnp.arange(num_batches-1)), 335 | axis=0) 336 | pred_fn = jit(lambda ins : self.apply(self.params, 337 | ins[:,0][:,None], 338 | ins[:,1][:,None])) 339 | y_pred = jnp.concatenate([pred_fn(ins) for ins in x_batches], axis=0) 340 | return y_pred 341 | 342 | def get_rmse(self, batch, batch_size=2_048): 343 | # Create predictions 344 | u, s_true = batch 345 | if batch_size is None: # single forward pass 346 | s_pred = self.apply(self.params, u) 347 | else: # breaks prediction into smaller forward passes 348 | s_pred = self.batched_apply(u, batch_size=batch_size) 349 | error = s_pred - s_true 350 | rmse = jnp.sqrt(jnp.mean(error**2)) 351 | return rmse 352 | 353 | def plot_predictions(self, batch, return_pred=False, batch_size=2_048, 354 | num_levels = 100): 355 | """Computes and plots model predictions for a given batch of data. 356 | 357 | Args: 358 | batch: data for creating/plotting results. 359 | return_pred: whether to return predictions after plotting 360 | (default: False). 361 | batch_size: batch size for computations (to avoid OoM issues in the 362 | case of large datasets). (default: 2048) 363 | num_levels: number of levels for contour plot (default: 100). 
364 | """ 365 | # Create predictions 366 | u, s_true = batch 367 | if batch_size is None: # single forward pass 368 | s_pred = self.apply(self.params, u) 369 | else: # breaks prediction into smaller forward passes 370 | s_pred = self.batched_apply(u, batch_size=batch_size) 371 | 372 | error = s_pred - s_true 373 | rel_l2_error = jnp.sqrt(jnp.sum(error**2)/jnp.sum(s_true**2)) 374 | print('Relative L2 error: {:.2e}'.format(rel_l2_error)) 375 | print('RMSE: {:.2e}'.format(jnp.sqrt(jnp.mean(error**2)))) 376 | 377 | plt.figure(figsize=(16, 4)) 378 | 379 | # Ploting examples of reconstructions 380 | plt.subplot(131) 381 | plt.tricontourf(u[:,0], u[:,1], 382 | s_pred.squeeze(), levels=num_levels) 383 | plt.colorbar() 384 | plt.xlabel('$x$') 385 | plt.ylabel('$y$') 386 | plt.title('Prediction') 387 | 388 | # Ploting true solution 389 | plt.subplot(132) 390 | plt.tricontourf(u[:,0], u[:,1], 391 | s_true.squeeze(), levels=num_levels) 392 | plt.colorbar() 393 | plt.xlabel('$x$') 394 | plt.ylabel('$y$') 395 | plt.title('True') 396 | 397 | # Ploting absolute 398 | plt.subplot(133) 399 | plt.tricontourf(u[:,0], u[:,1], 400 | abs(s_pred-s_true).squeeze(), levels=num_levels) 401 | plt.colorbar() 402 | plt.xlabel('$x$') 403 | plt.ylabel('$y$') 404 | plt.title('Absolute Error') 405 | 406 | plt.show() 407 | 408 | plt.show() 409 | 410 | if return_pred: 411 | return s_pred 412 | 413 | -------------------------------------------------------------------------------- /helmholtz_2d/HelmholtzModel.py: -------------------------------------------------------------------------------- 1 | # Basic Library Importsk 2 | import jax 3 | import jax.numpy as jnp 4 | from jax import random 5 | from jax import vmap, jit, grad, value_and_grad 6 | 7 | import optax 8 | from optax._src import linear_algebra 9 | import jaxopt 10 | 11 | import matplotlib.pyplot as plt 12 | 13 | from functools import partial 14 | import itertools 15 | from tqdm import trange 16 | 17 | import sys 18 | sys.path.append('..') # makes modules in parent repository available to import 19 | from models import plot_logs 20 | 21 | def get_value_and_laplacian(f): 22 | ''' 23 | Computes the 2D laplacian of a given funciton in a computationally efficient 24 | way by avoiding unecessary repetition of the computational graph. 25 | 26 | Args: 27 | f (Callable): a function with signature (params, x, y)->scalar 28 | Returns: 29 | (Callable): a function with signature (params, x, y)->scalar, scalar that is the 30 | laplacian of input function 31 | ''' 32 | _single_f = lambda params, x, y : f(params, x[None,:], y[None,:]).squeeze() 33 | def _scalar_grads(params, x, y): 34 | u, (ux, uy) = value_and_grad(_single_f, argnums=(1,2))(params, x, y) 35 | return ux.squeeze(), (u, uy.squeeze()) 36 | def _lapl_aux(params, x, y): 37 | (u_x, (u, u_y)), u_xx = value_and_grad(_scalar_grads, 38 | argnums=1, 39 | has_aux=True)(params, x, y) 40 | return u_y, (u, u_xx.squeeze()) 41 | def _lapl_aux_2(params, x, y): 42 | (u_y, (u, u_xx)), u_yy = value_and_grad(_lapl_aux, 43 | argnums=2, 44 | has_aux=True)(params, x, y) 45 | return u, u_xx, u_yy.squeeze() 46 | def _value_and_lapl(params, xs, ys): 47 | u, uxx, uyy = vmap(_lapl_aux_2, in_axes=(None, 0, 0))(params, xs, ys) 48 | return u, uxx + uyy 49 | return _value_and_lapl 50 | 51 | 52 | class HelmholtzModel: 53 | """ A model for training/evaluating a neural network using physics-informed 54 | training on the 2D Helmholtz problem. 55 | 56 | Attributes: 57 | arch (nn.Module): a Flax module of the desired architecute. 
58 | batch: an initial batch used for initializing parameters and computing 59 | normalization factors. 60 | true_fun: the true ground_truth function, if known (default: None). 61 | pde_res_fn: the target function for the PDE differential operator. 62 | optimizer: the optimizer to be used when running gradient descent. 63 | normalize_inputs: whether to normalize inputs before passing them to the 64 | architecture (default: True). 65 | key: a PRNG key used as the random key for initialization. 66 | pinn_weights: pair of weights for balacing residual/border losses. 67 | exact_bd_condition: whether to exactly enforce boundary condition on 68 | [-1,1]^2 (see https://arxiv.org/abs/2104.08426). This implicitly 69 | makes it so that border losses are not computed (defaut: True). 70 | bdr_enforcer_order: order of the polynomial used for enforcing border 71 | exactly. Should be an even integer (default: 2). 72 | steps_per_check (int): how many training steps to use between logging 73 | and displaying losses (default: 50). 74 | """ 75 | def __init__(self, arch, batch, true_fun=None, 76 | pde_res_fn=lambda x, y : \ 77 | -(jnp.pi**2)*(1+4*(y**2))*jnp.sin(jnp.pi*x)*jnp.sin(jnp.pi*(y**2)) \ 78 | + 2*jnp.pi*jnp.sin(jnp.pi*x)*jnp.cos(jnp.pi*(y**2)), 79 | kappa = 1, 80 | optimizer=None, normalize_inputs=True, key=random.PRNGKey(43), 81 | pinn_weights=(0.001, 1.), exact_bd_condition=False, 82 | bdr_enforcer_order=2, steps_per_check=50) -> None: 83 | # Define model 84 | self.arch = arch 85 | self.key = key 86 | self.steps_per_check = steps_per_check 87 | self.pde_res_fn = pde_res_fn 88 | self.kappa = kappa 89 | self.true_fun = true_fun 90 | self.pinn_weights = pinn_weights # not really used in the experiments (boundary is enforced exactly) 91 | 92 | # Initialize parameters 93 | interior_batch, border_batch = batch 94 | self.params = self.arch.init(self.key, interior_batch) 95 | 96 | # Tabulate function for checking network architecture 97 | self.tabulate = lambda : \ 98 | self.arch.tabulate(self.key, 99 | interior_batch, 100 | console_kwargs={'width':110}) 101 | 102 | # Vectorized functions 103 | self.normalize_inputs = normalize_inputs 104 | self.exact_bd_condition = exact_bd_condition 105 | self.bdr_enforcer_order = bdr_enforcer_order # should be an even number 106 | if normalize_inputs: 107 | mu_x = jnp.hstack(interior_batch).mean(0, keepdims=True) 108 | sig_x = jnp.hstack(interior_batch).std(0, keepdims=True) 109 | self.norm_stats = (mu_x, sig_x) 110 | _apply = lambda params, x, y : \ 111 | self.arch.apply(params, (jnp.hstack([x, y])-mu_x)/sig_x) 112 | if self.exact_bd_condition: 113 | _apply = lambda params, x, y : \ 114 | (1-x**self.bdr_enforcer_order)\ 115 | *(1-y**self.bdr_enforcer_order)\ 116 | *self.arch.apply(params, 117 | (jnp.hstack([x, y])-mu_x)/sig_x) 118 | else: 119 | _apply = lambda params, x, y : \ 120 | self.arch.apply(params, 121 | (jnp.hstack([x, y])-mu_x)/sig_x) 122 | else: 123 | self.norm_stats = None 124 | if self.exact_bd_condition: 125 | _apply = lambda params, x, y : \ 126 | (1-x**self.bdr_enforcer_order)\ 127 | *(1-y**self.bdr_enforcer_order)\ 128 | *self.arch.apply(params, 129 | jnp.hstack([x, y])) 130 | else: 131 | _apply = lambda params, x, y : \ 132 | self.arch.apply(params, 133 | jnp.hstack([x, y])) 134 | # jits apply function for numerical consistency (sometimes jitted 135 | # version behaves slightly differently than non-jitted one) 136 | self.apply = jit(_apply) 137 | 138 | # Vectorized derivatives. 
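        # (These per-coordinate helpers are exposed for convenience; the Helmholtz
        # residual below uses `value_and_laplacian` instead, which returns both u
        # and its Laplacian from one nested value_and_grad computation so the
        # forward pass is not repeated for each second derivative.)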
139 | # functions prefixed by '_single' take in a vector of shape (1,) and 140 | # output a scalar of shape (,) 141 | _single_f = lambda params, x, y : \ 142 | self.apply(params, x[None,:], y[None,:]).squeeze() 143 | # x derivatives 144 | _single_f_x = lambda params, x, y : \ 145 | grad(_single_f, argnums=1)(params, x, y).squeeze() # scalar 146 | self.f_x = vmap(_single_f_x, in_axes=(None, 0, 0)) 147 | _single_f_xx = lambda params, x, y : \ 148 | grad(_single_f_x, argnums=1)(params, x, y).squeeze() # scalar 149 | self.f_xx = vmap(_single_f_xx, in_axes=(None, 0, 0)) 150 | # y derivatives 151 | _single_f_y = lambda params, x, y : \ 152 | grad(_single_f, argnums=2)(params, x, y).squeeze() # scalar 153 | self.f_y = vmap(_single_f_y, in_axes=(None, 0, 0)) 154 | _single_f_yy = lambda params, x, y : \ 155 | grad(_single_f_y, argnums=2)(params, x, y).squeeze() # scalar 156 | self.f_yy = vmap(_single_f_yy, in_axes=(None, 0, 0)) 157 | # laplacian 158 | self.value_and_laplacian = get_value_and_laplacian(self.apply) 159 | 160 | # Optimizer 161 | if optimizer is None: # use a standard optimizer 162 | lr = optax.exponential_decay(1e-3, transition_steps=1000, 163 | decay_rate=0.8, end_value=1e-7) 164 | self.optimizer = optax.chain( 165 | optax.adaptive_grad_clip(1e-2), 166 | optax.adam(learning_rate=lr), 167 | ) 168 | else: 169 | self.optimizer = optimizer 170 | self.opt_state = self.optimizer.init(self.params) 171 | 172 | # Optimizer LBFGS 173 | self.optimizer_lbfgs = jaxopt.LBFGS(self.loss) 174 | self.opt_state_lbfgs = self.optimizer_lbfgs.init_state(self.params, 175 | batch) 176 | self.optimizer_update_lbfgs = jit(self.optimizer_lbfgs.update) 177 | 178 | # Logger 179 | self.itercount = itertools.count() 180 | self.loss_log = [] 181 | self.grad_norm_log = [] 182 | self.rel_l2_log = [] 183 | 184 | def residual_loss(self, params, x, y): 185 | #res = self.laplacian(params, x, y)[:,None] # shape (batch_dim,1) 186 | u, delta_u = self.value_and_laplacian(params, x, y) 187 | u, delta_u = u[:,None], delta_u[:,None] 188 | res = delta_u + (self.kappa**2)*u 189 | goal = self.pde_res_fn(x, y) # shape (batch_dim,1) 190 | return jnp.mean((res-goal)**2, axis=-1) # shape (batch_dim,) 191 | 192 | def border_loss(self, params, x, y): 193 | # function should be zero at the boundary 194 | outputs = self.apply(params, x, y) # shape (batch_dim, out_dim) 195 | return jnp.mean(outputs**2, axis=-1) # shape (batch_dim,) 196 | 197 | def pinn_loss(self, params, interior_batch, border_batch): 198 | r_loss = self.residual_loss(params, 199 | interior_batch[:,0][:,None], 200 | interior_batch[:,1][:,None]) 201 | if self.exact_bd_condition: 202 | # no need to consider border loss, since it will be 0 when bdry 203 | # condition is exactly enforced 204 | return self.pinn_weights[0]*r_loss.mean() 205 | else: 206 | # consider both residual loss and boundary condition loss 207 | b_loss = self.border_loss(params, 208 | border_batch[:,0][:,None], 209 | border_batch[:,1][:,None]) 210 | return self.pinn_weights[0]*r_loss.mean() \ 211 | + self.pinn_weights[1]*b_loss.mean() 212 | 213 | 214 | @partial(jit, static_argnums=(0,)) 215 | def loss(self, params, batch): 216 | interior_batch, border_batch = batch 217 | return self.pinn_loss(params, interior_batch, border_batch).mean() # scalar 218 | 219 | 220 | # Define a compiled update step 221 | @partial(jit, static_argnums=(0,)) 222 | def step(self, params, opt_state, batch): 223 | grads = grad(self.loss)(params, batch) 224 | updates, opt_state = self.optimizer.update(grads, opt_state, params) 225 | 
params = optax.apply_updates(params, updates) 226 | return params, opt_state, grads 227 | 228 | # Optimize parameters in a loop 229 | def train(self, dataset, nIter = 10_000): 230 | """ Trains the neural network for nIter steps using data loader. 231 | 232 | Args: 233 | dataset (SquareDataset): data loader for training. 234 | nIter (int): number of training iterations. 235 | """ 236 | data = iter(dataset) 237 | pbar = trange(nIter) 238 | # Main training loop 239 | for it in pbar: 240 | batch = next(data) 241 | self.params, self.opt_state, grads = self.step(self.params, 242 | self.opt_state, 243 | batch) 244 | # Logger 245 | if it % self.steps_per_check == 0: 246 | l = self.loss(self.params, batch) 247 | g_norm = linear_algebra.global_norm(grads).squeeze() 248 | self.loss_log.append(l) 249 | self.grad_norm_log.append(g_norm) 250 | if self.true_fun is not None: 251 | interior_batch, border_batch = batch 252 | pred = self.apply(self.params, 253 | interior_batch[:,0][:,None], 254 | interior_batch[:,1][:,None]) 255 | true = self.true_fun(interior_batch[:,0][:,None], 256 | interior_batch[:,1][:,None]) 257 | rel_l2_error = jnp.sqrt(((pred-true)**2).mean() \ 258 | / ((true)**2).mean()) 259 | self.rel_l2_log.append(rel_l2_error) 260 | pbar.set_postfix_str(f"loss:{l : .3e}, rel_l2:{rel_l2_error : .2e}, 'grad_norm':{jnp.mean(jnp.array(g_norm)) : .2e}") 261 | else: 262 | pbar.set_postfix({ 263 | 'loss': l, 264 | 'grad_norm': jnp.mean(jnp.array(g_norm)), 265 | }) 266 | 267 | # Define a compiled update step 268 | @partial(jit, static_argnums=(0,)) 269 | def step_lbfgs(self, params, opt_state, batch): 270 | new_params, opt_state = self.optimizer_update_lbfgs(params, 271 | opt_state, 272 | batch) 273 | return new_params, opt_state 274 | 275 | # Optimize parameters in a loop 276 | def train_lbfgs(self, dataset, nIter = 10000): 277 | """ Trains the neural network using LBFGS optimizer for nIter steps 278 | using data loader. 279 | 280 | Args: 281 | dataset (SquareDataset): data loader for training. 282 | nIter (int): number of training iterations. 283 | """ 284 | data = iter(dataset) 285 | pbar = trange(nIter) 286 | batch = next(data) 287 | self.opt_state_lbfgs = self.optimizer_lbfgs.init_state(self.params, 288 | batch) 289 | # Main training loop 290 | for it in pbar: 291 | batch = next(data) 292 | # Logger 293 | if it % self.steps_per_check == 0: 294 | l = self.loss(self.params, batch) 295 | self.loss_log.append(l) 296 | grads = grad(self.loss)(self.params, batch) 297 | g_norm = linear_algebra.global_norm(grads).squeeze() 298 | self.grad_norm_log.append(g_norm) 299 | if self.true_fun is not None: 300 | interior_batch, border_batch = batch 301 | pred = self.apply(self.params, 302 | interior_batch[:,0][:,None], 303 | interior_batch[:,1][:,None]) 304 | true = self.true_fun(interior_batch[:,0][:,None], 305 | interior_batch[:,1][:,None]) 306 | rel_l2_error = jnp.sqrt(((pred-true)**2).mean() \ 307 | / ((true)**2).mean()) 308 | self.rel_l2_log.append(rel_l2_error) 309 | pbar.set_postfix_str(f"loss:{l : .3e}, rel_l2:{rel_l2_error : .2e}, 'grad_norm':{jnp.mean(jnp.array(g_norm)) : .2e}") 310 | else: 311 | pbar.set_postfix({ 312 | 'loss': l, 313 | 'grad_norm': jnp.mean(jnp.array(g_norm)), 314 | }) 315 | # take step 316 | self.params, self.opt_state_lbfgs = self.step_lbfgs(self.params, 317 | self.opt_state_lbfgs, 318 | batch) 319 | 320 | 321 | def plot_logs(self, window=None) -> None: 322 | """ Plots logs of training losses and gradient norms through training. 
323 | 324 | Args: 325 | window: desired window for computing moving averages (default: None). 326 | """ 327 | plot_logs(self.loss_log, self.grad_norm_log, window=window, 328 | steps_per_check=self.steps_per_check) 329 | 330 | def batched_apply(self, x, batch_size=2_048): 331 | '''Performs forward pass using smaller batches, then concatenates them 332 | together before returning predictions. Useful for avoiding OoM issues 333 | when input is large. 334 | 335 | Args: 336 | x: input to the model 337 | batch_size: maximum batch size for computation. 338 | 339 | Returns: 340 | predictions of the model on input x 341 | ''' 342 | num_batches = int(jnp.ceil(len(x) / batch_size)) 343 | x_batches = jnp.split(x, 344 | batch_size*(1+jnp.arange(num_batches-1)), 345 | axis=0) 346 | pred_fn = jit(lambda ins : self.apply(self.params, 347 | ins[:,0][:,None], 348 | ins[:,1][:,None])) 349 | y_pred = jnp.concatenate([pred_fn(ins) for ins in x_batches], axis=0) 350 | return y_pred 351 | 352 | def get_rmse(self, batch, batch_size=2_048): 353 | # Create predictions 354 | u, s_true = batch 355 | if batch_size is None: # single forward pass 356 | s_pred = self.apply(self.params, u) 357 | else: # breaks prediction into smaller forward passes 358 | s_pred = self.batched_apply(u, batch_size=batch_size) 359 | error = s_pred - s_true 360 | rmse = jnp.sqrt(jnp.mean(error**2)) 361 | return rmse 362 | 363 | def plot_predictions(self, batch, return_pred=False, batch_size=2_048, 364 | num_levels = 100): 365 | """Computes and plots model predictions for a given batch of data. 366 | 367 | Args: 368 | batch: data for creating/plotting results. 369 | return_pred: whether to return predictions after plotting 370 | (default: False). 371 | batch_size: batch size for computations (to avoid OoM issues in the 372 | case of large datasets). (default: 2048) 373 | num_levels: number of levels for contour plot (default: 100). 
374 | """ 375 | # Create predictions 376 | u, s_true = batch 377 | if batch_size is None: # single forward pass 378 | s_pred = self.apply(self.params, u) 379 | else: # breaks prediction into smaller forward passes 380 | s_pred = self.batched_apply(u, batch_size=batch_size) 381 | 382 | error = s_pred - s_true 383 | rel_l2_error = jnp.sqrt(jnp.sum(error**2)/jnp.sum(s_true**2)) 384 | print('Relative L2 error: {:.2e}'.format(rel_l2_error)) 385 | print('RMSE: {:.2e}'.format(jnp.sqrt(jnp.mean(error**2)))) 386 | 387 | plt.figure(figsize=(16, 4)) 388 | 389 | # Ploting examples of reconstructions 390 | plt.subplot(131) 391 | plt.tricontourf(u[:,0], u[:,1], 392 | s_pred.squeeze(), levels=num_levels) 393 | plt.colorbar() 394 | plt.xlabel('$x$') 395 | plt.ylabel('$y$') 396 | plt.title('Prediction') 397 | 398 | # Ploting true solution 399 | plt.subplot(132) 400 | plt.tricontourf(u[:,0], u[:,1], 401 | s_true.squeeze(), levels=num_levels) 402 | plt.colorbar() 403 | plt.xlabel('$x$') 404 | plt.ylabel('$y$') 405 | plt.title('True') 406 | 407 | # Ploting absolute 408 | plt.subplot(133) 409 | plt.tricontourf(u[:,0], u[:,1], 410 | abs(s_pred-s_true).squeeze(), levels=num_levels) 411 | plt.colorbar() 412 | plt.xlabel('$x$') 413 | plt.ylabel('$y$') 414 | plt.title('Absolute Error') 415 | 416 | plt.show() 417 | 418 | plt.show() 419 | 420 | if return_pred: 421 | return s_pred -------------------------------------------------------------------------------- /archs.py: -------------------------------------------------------------------------------- 1 | # Basic Library Imports 2 | import jax 3 | import jax.numpy as jnp 4 | from jax import random 5 | from jax import vmap, jit 6 | 7 | from flax import linen as nn 8 | 9 | from typing import Any, Callable, Sequence, Tuple, Union 10 | 11 | # acceptable types for matmul precision in JAX 12 | PrecisionLike = Union[None, str, jax.lax.Precision, Tuple[str, str], 13 | Tuple[jax.lax.Precision, jax.lax.Precision]] 14 | # acceptable type for vector shapes 15 | Shape = Sequence[int] 16 | 17 | # identity function 18 | identity = lambda x : x 19 | 20 | 21 | ###################################################### 22 | #################### Initializers #################### 23 | ###################################################### 24 | 25 | # Siren Initialization 26 | def siren_initializer(key, shape, dtype=jnp.float32): 27 | """ 28 | Returns a random vector of desired shape using Siren's initialization. 29 | 30 | Args: 31 | key: a PRNG key used as the random key. 32 | shape: shape of weights. 33 | dtype: the dtype of the weights. 34 | 35 | Returns: 36 | A random Siren weight array with the specified shape and dtype. 37 | """ 38 | aux = jnp.sqrt(6. / shape[0]) 39 | return random.uniform(key, shape=shape, minval=-aux, maxval=aux, dtype=dtype) 40 | 41 | def siren_first_layer_initializer(key, shape, dtype): 42 | """ 43 | Returns a random vector of desired shape using Siren's initialization for the 44 | first layer. 45 | 46 | Args: 47 | key: a PRNG key used as the random key. 48 | shape: shape of weights. 49 | dtype: the dtype of the weights. 50 | 51 | Returns: 52 | A random Siren weight array (first layer) with the specified shape & dtype. 53 | """ 54 | aux = 1/shape[0] 55 | return random.uniform(key, shape, minval=-aux, maxval=aux, dtype=dtype) 56 | 57 | # Custom Initialization 58 | def kan_initializer(key, shape, dtype=jnp.float32, sigma_0=0.1): 59 | """ 60 | Returns a random vector of desired shape using KAN's initialization. 
61 | 62 | Args: 63 | key: a PRNG key used as the random key. 64 | shape: shape of weights. 65 | dtype: the dtype of the weights. 66 | sigma (float): sigma parameter for initialization as specified in KAN paper. 67 | 68 | Returns: 69 | A random KAN weight array with the specified shape and dtype. 70 | """ 71 | aux = sigma_0/jnp.sqrt(shape[0]) 72 | return aux*random.normal(key, shape=shape, dtype=dtype) 73 | 74 | def get_kan_initializer(sigma=0.1): 75 | """ 76 | Returns a KAN initializer with desired choice of sigma. 77 | 78 | Args: 79 | sigma (float): sigma parameter for initialization as specified in KAN paper. 80 | 81 | Returns: 82 | A KAN initializer function. 83 | """ 84 | return lambda key, shape, dtype=jnp.float32 : \ 85 | kan_initializer(key, shape, dtype=dtype, sigma_0=sigma) 86 | 87 | 88 | ############################################################### 89 | ######################## Architectures ######################## 90 | ############################################################### 91 | 92 | ############# 93 | #### MLP #### 94 | ############# 95 | 96 | class MLP(nn.Module): 97 | """A Multi-Layer Prerception network. 98 | 99 | Attributes: 100 | features: sequence of int detailing width of each layer. 101 | activation: activation function to be used in between layers (default: 102 | nn.gelu). 103 | output_activation: activation for last layer of network (default: identity). 104 | precision: numerical precision of the computation. See ``jax.lax.Precision`` 105 | for details. (default: None) 106 | """ 107 | features: Sequence[int] 108 | activation : Callable=nn.gelu 109 | output_activation : Callable=identity 110 | precision: PrecisionLike = None 111 | 112 | @nn.compact 113 | def __call__(self, x): 114 | """Forward pass of a MLP network. 115 | 116 | Args: 117 | x: The nd-array to be transformed. 118 | 119 | Returns: 120 | The transformed input x. 121 | """ 122 | for feat in self.features[:-1]: 123 | x = self.activation(nn.Dense(feat, precision=self.precision)(x)) 124 | x = nn.Dense(self.features[-1], precision=self.precision)(x) 125 | return self.output_activation(x) # different activation on output layer 126 | 127 | 128 | ############### 129 | #### Siren #### 130 | ############### 131 | 132 | # see https://arxiv.org/abs/2006.09661 for details about Siren, which is an MLP 133 | # with sine activation and a specific initialization pattern. See below for an 134 | # iteractive colab notebook provided by the authors: 135 | # https://colab.research.google.com/github/vsitzmann/siren/blob/master/explore_siren.ipynb 136 | 137 | class Siren(nn.Module): 138 | """A Siren network. 139 | 140 | Attributes: 141 | features: sequence of int detailing width of each layer. 142 | w0: frequency content parameter for mutiplying initial inputs. See Siren 143 | paper for more details. 144 | output_activation: activation for last layer of network (default: identity). 145 | precision: numerical precision of the computation. See ``jax.lax.Precision`` 146 | for details. (default: None) 147 | """ 148 | features: Sequence[int] 149 | w0 : float 150 | output_activation : Callable=identity 151 | precision: PrecisionLike = None 152 | 153 | @nn.compact 154 | def __call__(self, x): 155 | """Forward pass of a Siren network. 156 | 157 | Args: 158 | x: The nd-array to be transformed. 159 | 160 | Returns: 161 | The transformed input x. 
162 | """ 163 | x = x*self.w0 164 | x = jnp.sin(nn.Dense(self.features[0], 165 | kernel_init=siren_first_layer_initializer, 166 | precision=self.precision)(x)) 167 | for feat in self.features[1:-1]: 168 | x = jnp.sin(nn.Dense(feat, 169 | kernel_init=siren_initializer, 170 | precision=self.precision)(x)) 171 | x = nn.Dense(self.features[-1])(x) 172 | return self.output_activation(x) 173 | 174 | 175 | ################ 176 | #### ActNet #### 177 | ################ 178 | 179 | # from https://www.wolframalpha.com/input?i=E%5B%28sin%28wx%2Bp%29%29%5D+where+x+is+normally+distributed 180 | def _mean_transf(mu, sigma, w, p): 181 | """ Mean of the R.V. Y=sin(w*X+p) when X is normally distributed with mean mu 182 | and standard deviation sigma. 183 | 184 | Args: 185 | mu: mean of the R.V. X. 186 | sigma: standard deviation of the R.V. X. 187 | w: frequency of the sinusoidal transformation. 188 | p: phase of the sinusoidal transformation. 189 | 190 | Returns: 191 | The mean of the transformed R.V. Y. 192 | """ 193 | return jnp.exp(-0.5* (sigma*w)**2) * jnp.sin(p + mu*w) 194 | 195 | # from https://www.wolframalpha.com/input?i=E%5Bsin%28wx%2Bp%29%5E2%5D+where+x+is+normally+distributed 196 | def _var_transf(mu, sigma, w, p): 197 | """ Variance of the R.V. Y=sin(w*X+p) when X is normally distributed with 198 | mean mu and standard deviation sigma. 199 | 200 | Args: 201 | mu: mean of the R.V. X. 202 | sigma: standard deviation of the R.V. X. 203 | w: frequency of the sinusoidal transformation. 204 | p: phase of the sinusoidal transformation. 205 | 206 | Returns: 207 | The variance of the transformed R.V. Y. 208 | """ 209 | return 0.5 - 0.5*jnp.exp(-2 * ((sigma*w)**2))*jnp.cos(2*(p+mu*w)) \ 210 | -_mean_transf(mu, sigma, w, p)**2 211 | 212 | class ActLayer(nn.Module): 213 | """An ActLayer module. 214 | 215 | For further details on standard choices of initializers, please refer to 216 | Appendix D of the paper: https://arxiv.org/pdf/2410.01990 217 | 218 | Attributes: 219 | out_dim: output dimension of ActLayer. 220 | num_freqs: number of frequencies/basis functions of the ActLayer. 221 | use_bias: whether to add bias to the output (default: True). 222 | freqs_init: initializer for basis function frequencies. 223 | phases_init: initializer for basis function phases. 224 | beta_init: initializer for beta parameter. 225 | lamb_init: initializer for lambda parameter. 226 | bias_init: initializer for bias parameter. 227 | freeze_basis: whether to freeze gradients passing through basis 228 | functions (default: False). 229 | freq_scaling: whether to scale basis functions to ensure mean 0 and 230 | standard deviation 1 (default: True). 231 | freq_scaling_eps: small epsilon added to the denominator of frequency 232 | scaling for numerical stability (default: 1e-3). 233 | precision: numerical precision of the computation. See 234 | ``jax.lax.Precision`` for details. (default: None) 235 | """ 236 | out_dim : int 237 | num_freqs : int 238 | use_bias : bool=True 239 | # parameter initializers 240 | freqs_init : Callable=nn.initializers.normal(stddev=1.)
# normal w/ mean 0 241 | phases_init : Callable=nn.initializers.zeros 242 | beta_init : Callable=nn.initializers.variance_scaling(1., 243 | 'fan_in', 244 | distribution='uniform') 245 | lamb_init : Callable=nn.initializers.variance_scaling(1., 246 | 'fan_in', 247 | distribution='uniform') 248 | bias_init : Callable=nn.initializers.zeros 249 | # other configurations 250 | freeze_basis : bool=False 251 | freq_scaling : bool=True 252 | freq_scaling_eps : float=1e-3 253 | precision: PrecisionLike = None 254 | 255 | @nn.compact 256 | def __call__(self, x): 257 | """Forward pass of an ActLayer. 258 | 259 | Args: 260 | x: The nd-array to be transformed. 261 | 262 | Returns: 263 | The transformed input x. 264 | """ 265 | # x should initially be shape (batch, d) 266 | 267 | # initialize trainable parameters 268 | freqs = self.param('freqs', 269 | self.freqs_init, 270 | (1,1,self.num_freqs)) # shape (1, 1, num_freqs) 271 | phases = self.param('phases', 272 | self.phases_init, 273 | (1,1,self.num_freqs)) # shape (1, 1, num_freqs) 274 | beta = self.param('beta', 275 | self.beta_init, 276 | (self.num_freqs, self.out_dim)) # shape (num_freqs, out_dim) 277 | lamb = self.param('lamb', 278 | self.lamb_init, 279 | (x.shape[-1], self.out_dim)) # shape (d, out_dim) 280 | 281 | if self.freeze_basis: 282 | freqs = jax.lax.stop_gradient(freqs) 283 | phases = jax.lax.stop_gradient(phases) 284 | 285 | # perform basis expansion 286 | x = jnp.expand_dims(x, 2) # shape (batch, d, 1) 287 | x = jnp.sin(freqs*x + phases) # shape (batch_dim, d, num_freqs) 288 | if self.freq_scaling: 289 | x = (x - _mean_transf(0., 1., freqs, phases)) \ 290 | / (jnp.sqrt(self.freq_scaling_eps + _var_transf(0., 1., 291 | freqs, phases))) 292 | 293 | 294 | # efficiently computes eq 6 from https://arxiv.org/pdf/2410.01990 using 295 | # einsum. Depending on hardware and JAX/CUDA version, there may be 296 | # slightly faster alternatives, but we chose this one for the sake of 297 | # simplicity/clarity. 298 | x = jnp.einsum('...ij, jk, ik-> ...k', x, beta, lamb, 299 | precision=self.precision) 300 | 301 | # optionally add bias 302 | if self.use_bias: 303 | bias = self.param('bias', 304 | self.bias_init, 305 | (self.out_dim,)) 306 | x = x + bias # Shape (batch_size, out_dim) 307 | 308 | return x # Shape (batch_size, out_dim) 309 | 310 | 311 | class ActNet(nn.Module): 312 | """An ActNet module. 313 | 314 | Attributes: 315 | embed_dim: embedding dimension for ActLayers. 316 | num_layers: how many intermediate blocks are used. 317 | out_dim: output dimension of ActNet. 318 | num_freqs: number of frequencies/basis functions of the ActLayers. 319 | output_activation: activation for last layer of 320 | network (default: identity). 321 | op_order: order of operations contained in each intermediate block. This 322 | should be a string containing only 'A' (ActLayer), 'S' (Skip 323 | connection) or 'L' (LayerNorm) characters. (default: 'A') 324 | use_act_bias: whether to add bias to the output of ActLayers 325 | (default: True). 326 | freqs_init: initializer for basis function frequencies of ActLayers. 327 | phases_init: initializer for basis function phases of ActLayers. 328 | beta_init: initializer for beta parameter of ActLayers. 329 | lamb_init: initializer for lambda parameter of ActLayers. 330 | act_bias_init: initializer for bias parameter of ActLayers. 331 | proj_bias_init: initializer for bias parameter of initial projection 332 | layer. 333 | w0_init: initializer for w0 scale parameter.
334 | w0_fixed: if False, initializes w0 using w0_init. Otherwise uses given 335 | fixed w0 (default: False). 336 | freeze_basis: whether to freeze gradients passing through basis 337 | functions (default: False). 338 | freq_scaling: whether to scale basis functions to ensure mean 0 and 339 | standard deviation 1 (default: True). 340 | freq_scaling_eps: small epsilon added to the denominator of frequency 341 | scaling for numerical stability (default: 1e-3). 342 | precision: numerical precision of the computation. See 343 | ``jax.lax.Precision`` for details. (default: None) 344 | """ 345 | embed_dim : int 346 | num_layers : int # number of layers in the network 347 | out_dim : int # dimension of output vector 348 | num_freqs : int # how many frequencies/basis functions to use in ActLayers 349 | output_activation : Callable = identity 350 | op_order : str='A' 351 | # op_order should be a string containing only 'A' (ActLayer), 'S' (Skip 352 | # connection) or 'L' (LayerNorm) characters. This feature was used for 353 | # development/debugging, but is not used in any experiment of the paper. 354 | 355 | # parameter initializers 356 | freqs_init : Callable=nn.initializers.normal(stddev=1.) # normal w/ mean 0 357 | phases_init : Callable=nn.initializers.zeros 358 | beta_init : Callable=nn.initializers.variance_scaling(1., 'fan_in', 359 | distribution='uniform') 360 | lamb_init : Callable=nn.initializers.variance_scaling(1., 'fan_in', 361 | distribution='uniform') 362 | act_bias_init : Callable=nn.initializers.zeros 363 | proj_bias_init : Callable=lambda key, shape, dtype :\ 364 | random.uniform(key, shape, dtype, 365 | minval=-jnp.sqrt(3), maxval=jnp.sqrt(3)) 366 | 367 | w0_init : Callable=nn.initializers.constant(30.) # following SIREN strategy 368 | w0_fixed : Union[bool, float]=False # if False, initializes w0 as above. Otherwise uses given fixed w0 369 | 370 | # other ActLayer configurations 371 | use_act_bias : bool=True 372 | freeze_basis : bool=False 373 | freq_scaling : bool=True 374 | freq_scaling_eps : float=1e-3 375 | precision: PrecisionLike = None # numerical precision for matrix operations 376 | 377 | @nn.compact 378 | def __call__(self, x): 379 | """Forward pass of an ActNet. 380 | 381 | Args: 382 | x: The nd-array to be transformed. 383 | 384 | Returns: 385 | The transformed input x. 386 | """ 387 | # initialize w0 parameter 388 | if self.w0_fixed is False: 389 | # trainable scalar parameter 390 | w0 = self.param('w0', 391 | self.w0_init, 392 | ()) 393 | # use softplus to ensure w0 is positive and does not decay to zero 394 | # too fast (used only while debugging).
395 | w0 = nn.softplus(w0) 396 | else: # use user-specified value for w0 397 | w0 = self.w0_fixed 398 | # scale by w0 factor, then project to embeded dimension 399 | x = x*w0 400 | x = nn.Dense(self.embed_dim, bias_init=self.proj_bias_init, 401 | precision=self.precision)(x) 402 | 403 | for _ in range(self.num_layers): 404 | y = x # store initial value as x, do operations on y 405 | for char in self.op_order: 406 | if char == 'A': # ActLayer 407 | y = ActLayer( 408 | out_dim = self.embed_dim, 409 | num_freqs = self.num_freqs, 410 | use_bias = self.use_act_bias, 411 | freqs_init = self.freqs_init, 412 | phases_init = self.phases_init, 413 | beta_init = self.beta_init, 414 | lamb_init = self.lamb_init, 415 | bias_init = self.act_bias_init, 416 | freeze_basis = self.freeze_basis, 417 | freq_scaling = self.freq_scaling, 418 | freq_scaling_eps = self.freq_scaling_eps, 419 | precision=self.precision, 420 | )(y) 421 | elif char == 'S': # Skip connection 422 | y = y + x 423 | elif char == 'L': # LayerNorm 424 | y = nn.LayerNorm()(y) 425 | else: 426 | raise NotImplementedError(f"Could not recognize option '{char}'. Options for op_order should be 'A' (ActLayer), 'S' (Skip connection) or 'L' (LayerNorm).") 427 | x = y # update value of x after all operations are done 428 | 429 | # project to output dimension and potentially use output activation 430 | x = nn.Dense(self.out_dim, kernel_init=nn.initializers.he_uniform(), 431 | precision=self.precision)(x) 432 | x = self.output_activation(x) 433 | 434 | return x 435 | 436 | 437 | ############## 438 | #### KAN ##### 439 | ############## 440 | 441 | # Adapted to JAX from the "EfficientKAN" GitHub repository (PyTorch). Code was 442 | # altered as little as possible, for the sake of consistency/fairness. 443 | # https://github.com/Blealtan/efficient-kan/blob/master/src/efficient_kan/kan.py 444 | 445 | class KANLinear(nn.Module): 446 | in_features : int 447 | out_features : int 448 | grid_size : int=5 449 | spline_order: int=3 450 | scale_noise : float=0.1 451 | scale_base : float=1.0 452 | scale_spline : float=1.0 453 | enable_standalone_scale_spline : bool=True 454 | base_activation : Callable=nn.silu 455 | grid_eps : float=0.02 456 | grid_range : Sequence[Union[float, int]]=(-1,1) 457 | precision: PrecisionLike = None 458 | 459 | def setup(self): 460 | h = (self.grid_range[1] - self.grid_range[0]) / self.grid_size 461 | self.h = h 462 | grid = ( 463 | ( 464 | jnp.arange(start=-self.spline_order, stop=self.grid_size + self.spline_order + 1) * h 465 | + self.grid_range[0] 466 | ) 467 | ) 468 | self.grid = grid * jnp.ones((self.in_features, 1)) 469 | 470 | self.base_weight = self.param('base_weight', # parameter name 471 | nn.initializers.he_uniform(), # initialization funciton 472 | (self.out_features, self.in_features)) # shape info 473 | self.spline_weight = self.param('spline_weight', # parameter name 474 | nn.initializers.he_uniform(), # initialization funciton 475 | (self.out_features, self.in_features, self.grid_size+self.spline_order)) # shape info 476 | 477 | if self.enable_standalone_scale_spline: 478 | self.spline_scaler = self.param('spline_scaler', # parameter name 479 | nn.initializers.he_uniform(), # initialization funciton 480 | (self.out_features, self.in_features)) # shape info 481 | 482 | 483 | def b_splines(self, x: jax.Array): 484 | """ 485 | Compute the B-spline bases for the given input tensor. 486 | 487 | Args: 488 | x: Input tensor of shape (batch_size, in_features). 
489 | 490 | Returns: 491 | B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order). 492 | """ 493 | assert len(x.shape) == 2 and x.shape[1] == self.in_features 494 | 495 | # grid is shape (in_features, grid_size + 2 * spline_order + 1) 496 | grid = self.grid 497 | x = jnp.expand_dims(x, -1) 498 | bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])) 499 | for k in range(1, self.spline_order + 1): 500 | bases = ( 501 | (x - grid[:, : -(k + 1)]) 502 | / (grid[:, k:-1] - grid[:, : -(k + 1)]) 503 | * bases[:, :, :-1] 504 | ) + ( 505 | (grid[:, k + 1 :] - x) 506 | / (grid[:, k + 1 :] - grid[:, 1:(-k)]) 507 | * bases[:, :, 1:] 508 | ) 509 | 510 | assert bases.shape == ( 511 | x.shape[0], 512 | self.in_features, 513 | self.grid_size + self.spline_order, 514 | ) 515 | return bases 516 | 517 | @property 518 | def scaled_spline_weight(self): 519 | return self.spline_weight * ( 520 | jnp.expand_dims(self.spline_scaler, -1) 521 | if self.enable_standalone_scale_spline 522 | else 1.0 523 | ) 524 | 525 | def __call__(self, x: jax.Array): 526 | assert x.shape[-1] == self.in_features, f"x.shape[-1]={x.shape[-1]} should be equal to {self.in_features}" 527 | original_shape = x.shape 528 | x = x.reshape(-1, self.in_features) 529 | 530 | base_output = jnp.matmul(self.base_activation(x), self.base_weight.T, precision=self.precision) 531 | spline_output = jnp.matmul( 532 | self.b_splines(x).reshape(x.shape[0], -1), 533 | self.scaled_spline_weight.reshape(self.out_features, -1).T, 534 | precision=self.precision, 535 | ) 536 | output = base_output + spline_output 537 | 538 | output = output.reshape(*original_shape[:-1], self.out_features) 539 | return output 540 | 541 | 542 | class KAN(nn.Module): 543 | features : Sequence[int] 544 | output_activation : Callable=identity 545 | grid_size : int=5 546 | spline_order: int=3 547 | scale_noise : float=0.1 548 | scale_base : float=1.0 549 | scale_spline : float=1.0 550 | enable_standalone_scale_spline : bool=True 551 | base_activation : Callable=nn.silu 552 | grid_eps : float=0.02 553 | grid_range : Sequence[Union[float, int]]=(-1,1) 554 | precision: PrecisionLike = None 555 | 556 | def setup(self): 557 | self.layers = [KANLinear( 558 | self.features[i], 559 | self.features[i+1], 560 | grid_size=self.grid_size, 561 | spline_order=self.spline_order, 562 | scale_noise=self.scale_noise, 563 | scale_base=self.scale_base, 564 | scale_spline=self.scale_spline, 565 | enable_standalone_scale_spline=self.enable_standalone_scale_spline, 566 | base_activation=self.base_activation, 567 | grid_eps=self.grid_eps, 568 | grid_range=self.grid_range, 569 | precision=self.precision, 570 | ) for i in range(len(self.features) - 1)] 571 | 572 | def __call__(self, x): 573 | for l in self.layers: 574 | x = l(x) 575 | return self.output_activation(x) 576 | 577 | 578 | 579 | ############################################################ 580 | ################### Architecture Builder ################### 581 | ############################################################ 582 | 583 | def arch_from_config(arch_config): 584 | ''' Given a config file, outputs architecture object with given 585 | configurations. 586 | 587 | Args: 588 | arch_config: config file specifying architecture hyperparameters. 589 | 590 | Returns: 591 | Architecture as a Flax Linen nn.Module. 
592 | ''' 593 | if arch_config.arch_type == 'ActNet': 594 | arch = ActNet(**arch_config.hyperparams) 595 | return arch 596 | elif arch_config.arch_type == 'MLP': 597 | arch = MLP(**arch_config.hyperparams) 598 | return arch 599 | elif arch_config.arch_type == 'Siren': 600 | arch = Siren(**arch_config.hyperparams) 601 | return arch 602 | elif arch_config.arch_type == 'KAN': 603 | arch = KAN(**arch_config.hyperparams) 604 | return arch 605 | else: 606 | raise NotImplementedError(f"Cannot recognize arch_type {arch_config.arch_type}. So far, only 'ActNet', 'MLP', 'Siren' and 'KAN' are implemented") -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. 
This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 
117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 
174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 
234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 
296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. 
If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 
414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. 
The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. 
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 | <one line to give the program's name and a brief idea of what it does.>
635 | Copyright (C) <year> <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 | along with this program. If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 | <program> Copyright (C) <year> <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
675 | --------------------------------------------------------------------------------