├── allen_cahn
│   ├── ac_solution_dirichlet.pkl
│   └── CausalAllenCahnModel.py
├── README.md
├── utils.py
├── models.py
├── poisson_2d
│   └── PoissonModel.py
├── helmholtz_2d
│   └── HelmholtzModel.py
├── archs.py
└── LICENSE
/allen_cahn/ac_solution_dirichlet.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PredictiveIntelligenceLab/ActNet/HEAD/allen_cahn/ac_solution_dirichlet.pkl
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ActNet
2 |
3 | Repository for some of the experiments presented in the paper "Deep Learning Alternatives of the Kolmogorov Superposition Theorem", accepted as a Spotlight Paper at ICLR 2025 (arXiv: https://arxiv.org/abs/2410.01990).
4 |
5 | This code requires common libraries from the JAX ecosystem, such as Flax (for neural network design) and Optax/JaxOpt (for training and optimization). Plotting is done using Matplotlib.
6 |
7 | Experiments comparing against the state-of-the-art require integration with JaxPi, which is an open-source library. The code for those experiments can now be found on the ActNet branch of JaxPi: https://github.com/PredictiveIntelligenceLab/jaxpi/tree/ActNet
8 |
9 | FILES:
10 | * archs.py : contains the architectures used in the paper, including JAX implementations of ActNet, KAN and Siren.
11 | * models.py : includes a training model for regression that can be used with any of the architectures.
12 | * utils.py : includes useful code for sampling batches.
13 | * poisson_2d/ : directory containing minimal code to run the Poisson 2D problem.
14 |   * PoissonModel.py : flexible training model for the 2D Poisson problem that can be used with any desired architecture.
15 |   * prediction_plots.ipynb : Jupyter notebook tutorial showing how to run the Poisson problem and produce plots.
16 | * helmholtz_2d/ : directory containing minimal code to run the Helmholtz 2D problem.
17 |   * HelmholtzModel.py : flexible training model for the 2D Helmholtz problem that can be used with any desired architecture.
18 |   * prediction_plots.ipynb : Jupyter notebook tutorial showing how to run the Helmholtz problem and produce plots.
19 | * allen_cahn/ : directory containing minimal code to run the Allen-Cahn problem.
20 |   * ac_solution_dirichlet.pkl : pickle file containing the reference solution for the Allen-Cahn equation, obtained using a classical solver.
21 |   * CausalAllenCahnModel.py : flexible training model for the Allen-Cahn problem that can be used with any desired architecture.
22 |   * prediction_plots.ipynb : Jupyter notebook tutorial showing how to run the Allen-Cahn problem and produce plots.
23 |
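24 | QUICKSTART:
25 | 
26 | The sketch below shows how the pieces fit together for the 2D Poisson problem (run it from inside poisson_2d/). A small Flax MLP stands in for the architectures defined in archs.py; swap in ActNet, KAN or Siren with their own constructor arguments. See poisson_2d/prediction_plots.ipynb for the full tutorial.
27 | 
28 | ```python
29 | import sys; sys.path.append('..')
30 | import flax.linen as nn
31 | from jax import random
32 | 
33 | from utils import Poisson2DDataset
34 | from PoissonModel import PoissonModel
35 | 
36 | class MLP(nn.Module):  # stand-in for an architecture from archs.py
37 |     @nn.compact
38 |     def __call__(self, x):
39 |         x = nn.tanh(nn.Dense(64)(x))
40 |         return nn.Dense(1)(x)
41 | 
42 | dataset = Poisson2DDataset(random.PRNGKey(0))  # interior/boundary sampler
43 | model = PoissonModel(MLP(), dataset[0])
44 | model.train(dataset, nIter=10_000)
45 | model.plot_logs()
46 | ```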
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from jax import random
4 | from jax import jit, vmap, grad
5 |
6 |
7 | from functools import partial
8 | import torch.utils.data as data
9 |
10 |
11 | # Dataset loader
12 | class BatchedDataset(data.Dataset):
13 | ''' A data loader for creating mini-batches.
14 |
15 | Attributes:
16 |         raw_data: full dataset to be used. This should be a tuple of length 3,
17 |             formatted as (inputs, targets, weights).
18 | key: a PRNG key used as the random key.
19 | batch_size: the size of each mini-batch. If None, uses full batch
20 | (default: None).
21 | '''
22 |
23 | def __init__(self, raw_data, key, batch_size=None):
24 | super().__init__()
25 | self.inputs = raw_data[0]
26 | self.targets = raw_data[1]
27 | self.weights = raw_data[2]
28 | self.size = len(self.weights)
29 | self.key = key
30 | if batch_size is None: # Will use full batch
31 | self.batch_size = self.size
32 | else:
33 | self.batch_size = batch_size
34 |
35 | def __len__(self):
36 | return self.size
37 |
38 | def __getitem__(self, idx):
39 | self.key, subkey = random.split(self.key)
40 | batch_inputs, batch_targets, batched_weights = self.__select_batch(subkey)
41 | return batch_inputs, batch_targets, batched_weights
42 |
43 | @partial(jit, static_argnums=(0,))
44 | def __select_batch(self, key):
45 | idx = random.choice(key, self.size, (self.batch_size,), replace=False)
46 | batch_inputs = self.inputs[idx]
47 | batch_targets = self.targets[idx]
48 | batched_weights = self.weights[idx]
49 | return batch_inputs, batch_targets, batched_weights
50 |
51 |
52 | class SquareDataset(data.Dataset):
53 |     ''' A data loader for creating mini-batches of uniformly sampled points in the
54 |     interior and on the boundary of the [-1,1]^2 square. Generates a pair of vectors
55 |     (interior_batch, border_batch) with iid points in the interior and on the border
56 |     of the square, respectively.
57 |
58 | Attributes:
59 | key: a PRNG key used as the random key.
60 |         batch_size: the size of each mini-batch. Should be a tuple of length 2 in
61 | the format (inside_batch_size, border_batch_size).
62 | '''
63 | def __init__(self, key, batch_size=(10_000, 800)):
64 | super().__init__()
65 | self.size = batch_size[0]
66 | self.key = key
67 | self.batch_size = batch_size
68 |
69 | def __len__(self):
70 | return self.size
71 |
72 | def __getitem__(self, idx):
73 | self.key, subkey1, subkey2 = random.split(self.key, 3)
74 | interior_batch, border_batch = self.__select_batch(subkey1, subkey2)
75 | return interior_batch, border_batch
76 |
77 | @partial(jit, static_argnums=(0,))
78 | def __select_batch(self, subkey1, subkey2):
79 | interior_batch = random.uniform(subkey1, shape=(self.batch_size[0], 2),
80 | minval=-1, maxval=1)
81 | border_batch = random.uniform(subkey2, shape=(self.batch_size[1],1),
82 | minval=-1, maxval=1)
83 | aux = jnp.split(border_batch, 4)
84 | border_batch = jnp.concatenate([
85 | jnp.hstack([-jnp.ones_like(aux[0]), aux[0]]),
86 | jnp.hstack([jnp.ones_like(aux[1]), aux[1]]),
87 | jnp.hstack([aux[2], -jnp.ones_like(aux[2])]),
88 | jnp.hstack([aux[3], jnp.ones_like(aux[3])]),
89 | ], axis=0)
90 | return interior_batch, border_batch
91 |
92 | # alias for SquareDataset
93 | Poisson2DDataset = SquareDataset
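94 | 
95 | 
96 | # Minimal usage sketch (illustrative, not used by the experiments): draws one
97 | # mini-batch from each loader; the shapes follow from the definitions above.
98 | if __name__ == '__main__':
99 |     key_data, key_loader = random.split(random.PRNGKey(0))
100 | 
101 |     # supervised loader: mini-batches of (inputs, targets, weights)
102 |     inputs = random.normal(key_data, (100, 2))
103 |     targets = inputs.sum(-1, keepdims=True)
104 |     weights = jnp.ones(100)
105 |     loader = BatchedDataset((inputs, targets, weights), key_loader, batch_size=16)
106 |     xb, yb, wb = loader[0]
107 |     print(xb.shape, yb.shape, wb.shape)  # (16, 2) (16, 1) (16,)
108 | 
109 |     # PINN loader: 10_000 interior and 800 boundary points of [-1,1]^2
110 |     square = SquareDataset(random.PRNGKey(1))
111 |     interior, border = square[0]
112 |     print(interior.shape, border.shape)  # (10000, 2) (800, 2)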
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from jax import random
4 | from jax import jit, vmap, grad, value_and_grad
5 |
6 | import numpy as onp
7 | import optax
8 | from optax._src import linear_algebra
9 | import jaxopt
10 |
11 | from functools import partial
12 | import itertools
13 | from tqdm import trange
14 | import matplotlib.pyplot as plt
15 |
16 |
17 |
18 | # not used for the ActNet paper, but possibly useful in the future
19 | class RegressionModel:
20 | """ A model for training/evaluating a neural network using regression.
21 |
22 | Attributes:
23 |         arch (nn.Module): a Flax module of the desired architecture.
24 | batch: an initial batch used for initializing parameters and computing
25 | normalization factors.
26 | optimizer: the optimizer to be used when running gradient descent.
27 | normalize_inputs: whether to normalize inputs before passing them to the
28 | architecture (default: True).
29 | normalize_outputs: whether to normalize outputs of the architecture
30 | (default: True).
31 | key: a PRNG key used as the random key for initialization.
32 | steps_per_check (int): how many training steps to use between logging
33 | and displaying losses (default: 100).
34 | """
35 | def __init__(self, arch, batch, optimizer=None, normalize_inputs=True,
36 | normalize_outputs=True, key=random.PRNGKey(43),
37 | steps_per_check=100) -> None:
38 | # Define model
39 | self.arch = arch
40 | self.key = key
41 | self.steps_per_check = steps_per_check
42 |
43 | # Initialize parameters
44 | inputs, outputs, _ = batch
45 | self.params = self.arch.init(self.key, inputs)
46 |
47 | # Tabulate function for checking network architecture
48 | self.tabulate = lambda : \
49 | self.arch.tabulate(self.key, inputs, console_kwargs={'width':110})
50 |
51 | # Vectorized functions
52 | self.normalize_inputs = normalize_inputs
53 | self.normalize_outputs = normalize_outputs
54 | self.normalize_data = (normalize_inputs or normalize_outputs)
55 | if self.normalize_data:
56 | mu_x = inputs.mean(0, keepdims=True)
57 | sig_x = inputs.std(0, keepdims=True)
58 | mu_y = outputs.mean(0, keepdims=True)
59 | sig_y = outputs.std(0, keepdims=True)
60 | self.norm_stats = ((mu_x, sig_x), (mu_y, sig_y))
61 | if self.normalize_inputs:
62 | if self.normalize_outputs:
63 | self.apply = lambda params, x : \
64 | mu_y + sig_y*self.arch.apply(params,
65 | (x-mu_x)/(sig_x + 0.01))
66 | else:
67 | self.apply = lambda params, x : \
68 | self.arch.apply(params, (x-mu_x)/(sig_x + 0.01))
69 | else:
70 | self.apply = lambda params, x : \
71 | mu_y + sig_y*self.arch.apply(params, x)
72 |
73 | else:
74 | self.norm_stats = None
75 | self.apply = self.arch.apply
76 |         # jit the apply function for numerical consistency (the jitted version
77 |         # sometimes behaves slightly differently from the non-jitted one)
78 | self.apply = jit(self.apply)
79 |
80 | # Optimizer
81 | if optimizer is None:
82 | lr = optax.exponential_decay(1e-3, transition_steps=1000,
83 | decay_rate=0.9, end_value=1e-5)
84 | self.optimizer = optax.adam(learning_rate=lr)
85 | else:
86 | self.optimizer = optimizer
87 | self.opt_state = self.optimizer.init(self.params)
88 |
89 | # Optimizer LBFGS
90 | self.optimizer_lbfgs = jaxopt.LBFGS(self.loss)
91 | self.opt_state_lbfgs = self.optimizer_lbfgs.init_state(self.params,
92 | batch)
93 | self.optimizer_update_lbfgs = jit(self.optimizer_lbfgs.update)
94 |
95 | # Logger
96 | self.itercount = itertools.count()
97 | self.loss_log = []
98 | self.grad_norm_log = []
99 |
100 | def recon_loss(self, params, u, s, w):
101 | outputs = self.apply(params, u) # shape (batch_dim, out_dim)
102 | loss = jnp.mean(w*(s-outputs)**2, axis=(-1)) # shape (batch_dim,)
103 | return loss
104 |
105 | @partial(jit, static_argnums=(0,))
106 | def loss(self, params, batch):
107 | inputs, targets, weights = batch
108 | u = inputs
109 | s = targets
110 | w = weights
111 | return self.recon_loss(params, u, s, w).mean() # scalar
112 |
113 |
114 | # Define a compiled update step
115 | @partial(jit, static_argnums=(0,))
116 | def step(self, params, opt_state, batch):
117 | grads = grad(self.loss)(params, batch)
118 | updates, opt_state = self.optimizer.update(grads, opt_state, params)
119 | params = optax.apply_updates(params, updates)
120 | return params, opt_state, grads
121 |
122 | # Optimize parameters in a loop
123 | def train(self, dataset, nIter = 10000):
124 | """ Trains the neural network for nIter steps using data loader.
125 |
126 | Args:
127 | dataset (BatchedDataset): data loader for training.
128 | nIter (int): number of training iterations.
129 | """
130 | data = iter(dataset)
131 | pbar = trange(nIter)
132 | # Main training loop
133 | for it in pbar:
134 | batch = next(data)
135 | self.params, self.opt_state, grads = self.step(self.params,
136 | self.opt_state,
137 | batch)
138 | # Logger
139 | if it % self.steps_per_check == 0:
140 | l = self.loss(self.params, batch)
141 | g_norm = linear_algebra.global_norm(grads).squeeze()
142 | self.loss_log.append(l)
143 | self.grad_norm_log.append(g_norm)
144 | pbar.set_postfix({
145 | 'loss': l,
146 | 'grad_norm': jnp.mean(jnp.array(g_norm))
147 | })
148 |
149 | # Define a compiled update step
150 | @partial(jit, static_argnums=(0,))
151 | def step_lbfgs(self, params, opt_state, batch):
152 | new_params, opt_state = self.optimizer_update_lbfgs(params,
153 | opt_state,
154 | batch)
155 | return new_params, opt_state
156 |
157 | # Optimize parameters in a loop
158 | def train_lbfgs(self, dataset, nIter = 10000):
159 | """ Trains the neural network using LBFGS optimizer for nIter steps
160 | using data loader.
161 |
162 | Args:
163 | dataset (BatchedDataset): data loader for training.
164 | nIter (int): number of training iterations.
165 | """
166 | data = iter(dataset)
167 | pbar = trange(nIter)
168 | batch = next(data)
169 | self.opt_state_lbfgs = self.optimizer_lbfgs.init_state(self.params,
170 | batch)
171 | # Main training loop
172 | for it in pbar:
173 | batch = next(data)
174 | # Logger
175 | if it % self.steps_per_check == 0:
176 | l = self.loss(self.params, batch)
177 | self.loss_log.append(l)
178 | grads = grad(self.loss)(self.params, batch)
179 | g_norm = linear_algebra.global_norm(grads).squeeze()
180 | self.grad_norm_log.append(g_norm)
181 | pbar.set_postfix({
182 | 'loss': l,
183 | 'grad_norm': jnp.mean(jnp.array(g_norm))
184 | })
185 | # optimization step
186 | self.params, self.opt_state_lbfgs = self.step_lbfgs(self.params,
187 | self.opt_state_lbfgs,
188 | batch)
189 |
190 | def plot_logs(self, window=None) -> None:
191 | """ Plots logs of training losses and gradient norms through training.
192 |
193 | Args:
194 | window: desired window for computing moving averages (default: None)
195 | """
196 | plot_logs(self.loss_log, self.grad_norm_log, window=window,
197 | steps_per_check=self.steps_per_check)
198 |
199 | def batched_apply(self, x, batch_size=2_048):
200 | '''Performs forward pass using smaller batches, then concatenates them
201 | together before returning predictions. Useful for avoiding OoM issues
202 | when input is large.
203 |
204 | Args:
205 | x: input to the model
206 | batch_size: maximum batch size for computation.
207 |
208 | Returns:
209 | predictions of the model on input x
210 | '''
211 | num_batches = int(jnp.ceil(len(x) / batch_size))
212 | x_batches = jnp.split(x,
213 | batch_size*(1+jnp.arange(num_batches-1)),
214 | axis=0)
215 | pred_fn = lambda ins : self.apply(self.params, ins)
216 | y_pred = jnp.concatenate([pred_fn(ins) for ins in x_batches], axis=0)
217 | return y_pred
218 |
219 | def get_rmse(self, batch, batch_size=2_048):
220 | # Create predictions
221 | u, s_true, _ = batch
222 | if batch_size is None: # single forward pass
223 | s_pred = self.apply(self.params, u)
224 | else: # breaks prediction into smaller forward passes
225 | s_pred = self.batched_apply(u, batch_size=batch_size)
226 | error = s_pred - s_true
227 | rmse = jnp.sqrt(jnp.mean(error**2))
228 | return rmse
229 |
230 | def plot_predictions(self, batch, return_pred=False, batch_size=2_048):
231 | """Computes and plots model predictions for a given batch of data.
232 |
233 | Args:
234 | batch: data for creating/plotting results.
235 | return_pred: whether to return predictions after plotting
236 | (default: False).
237 | batch_size: batch size for computations (to avoid OoM issues in the
238 | case of large datasets). (default: 2048)
239 | """
240 | # Create predictions
241 | u, s_true, _ = batch
242 | if batch_size is None: # single forward pass
243 | s_pred = self.apply(self.params, u)
244 | else: # breaks prediction into smaller forward passes
245 | s_pred = self.batched_apply(u, batch_size=batch_size)
246 |
247 | error = s_pred - s_true
248 | rel_l2_error = jnp.sqrt(jnp.sum(error**2)/jnp.sum(s_true**2))
249 | print('Relative L2 error: {:.2e}'.format(rel_l2_error))
250 | print('RMSE: {:.2e}'.format(jnp.sqrt(jnp.mean(error**2))))
251 |
252 |         if u.shape[-1] == 1: # domain is 1D
253 | plt.figure(figsize=(15, 4))
254 |
255 |             # Plotting examples of reconstructions
256 | plt.subplot(131)
257 | plt.plot(u, s_true, lw=1)
258 | plt.plot(u, s_pred, '--', lw=1)
259 | plt.xlabel('$y$')
260 | plt.ylabel('$s$')
261 | plt.title('Prediction Vs Truth (predictions are dashed)')
262 |
263 |             # Plotting error
264 | plt.subplot(132)
265 | plt.plot(u, s_pred-s_true, lw=1)
266 | plt.xlabel('$y$')
267 | plt.ylabel('$s$')
268 | plt.title('Error')
269 |
270 | # plotting histogram of errors
271 | plt.subplot(133)
272 | plt.hist(error.flatten(), bins=50)
273 | plt.title('Histogram of errors')
274 |
275 | plt.show()
276 | elif u.shape[-1] == 2: # domain is 2D
277 | plt.figure(figsize=(15, 4))
278 |
279 |             # Plotting examples of reconstructions
280 | plt.subplot(131)
281 | plt.scatter(u[:,0], u[:,1], c=s_pred)
282 | plt.colorbar()
283 | plt.xlabel('$y$')
284 | plt.ylabel('$s$')
285 | plt.title('Prediction')
286 |
287 |             # Plotting true solution
288 | plt.subplot(132)
289 | plt.scatter(u[:,0], u[:,1], c=s_true)
290 | plt.colorbar()
291 | plt.xlabel('$y$')
292 | plt.ylabel('$s$')
293 | plt.title('True')
294 |
295 |             # Plotting errors
296 | plt.subplot(133)
297 | plt.scatter(u[:,0], u[:,1], c=s_pred-s_true)
298 | plt.colorbar()
299 | plt.xlabel('$y$')
300 | plt.ylabel('$s$')
301 | plt.title('Error')
302 |
303 | plt.show()
304 | else: # domain is higher than 2D. Plot histogram of errors instead
305 | # plotting histogram of errors
306 | plt.hist(error.flatten(), bins=50)
307 | plt.title('Histogram of errors')
308 | plt.show()
309 |
310 | if return_pred:
311 | return s_pred
312 |
313 | # alias for RegressionModel
314 | SupervisedModel = RegressionModel
315 |
316 |
317 |
318 | # Functions to help plotting
319 | def plot_logs(loss_log, grad_norm_log, window=None, steps_per_check=100):
320 | """ Plots logs of training losses and gradient norms through training.
321 |
322 | Args:
323 | loss_log: sequence of training losses.
324 | grad_norm_log: sequence of parameter gradient norms.
325 | window: desired window for computing moving averages (default: None).
326 | steps_per_check: how many training steps were taken between each log.
327 | """
328 | plt.figure(figsize=(12, 4))
329 |
330 | # Plotting losses
331 | plt.subplot(121)
332 | if window is None:
333 | plt.plot(steps_per_check*jnp.arange(len(loss_log)), loss_log)
334 | else:
335 |         assert type(window) is int, f'window must be an integer or None, not {type(window)}'
336 | plt.plot(steps_per_check*jnp.arange(len(loss_log) - window),
337 | [onp.mean(loss_log[i:i+window]) for i in range(len(loss_log) - window)])
338 | plt.yscale('log')
339 | plt.title('Loss through iterations')
340 |
341 | # Plotting gradient norms
342 | plt.subplot(122)
343 | if window is None:
344 | plt.plot(steps_per_check*jnp.arange(len(grad_norm_log)), grad_norm_log)
345 | else:
346 |         assert type(window) is int, f'window must be an integer or None, not {type(window)}'
347 | plt.plot(steps_per_check*jnp.arange(len(grad_norm_log) - window),
348 | [onp.mean(grad_norm_log[i:i+window]) for i in range(len(grad_norm_log) - window)])
349 | plt.yscale('log')
350 | plt.title('Global gradient norm through iterations')
351 | plt.show()
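352 | 
353 | 
354 | # Minimal usage sketch (illustrative, not part of the paper experiments):
355 | # fits a RegressionModel on a toy 1D task using BatchedDataset from utils.py,
356 | # with a small Flax MLP standing in for the architectures in archs.py.
357 | if __name__ == '__main__':
358 |     import flax.linen as nn
359 |     from utils import BatchedDataset
360 | 
361 |     class MLP(nn.Module):
362 |         @nn.compact
363 |         def __call__(self, x):
364 |             x = nn.tanh(nn.Dense(32)(x))
365 |             return nn.Dense(1)(x)
366 | 
367 |     xs = jnp.linspace(-1, 1, 512)[:, None]
368 |     ys = jnp.sin(jnp.pi * xs)
369 |     ws = jnp.ones(512)
370 |     dataset = BatchedDataset((xs, ys, ws), random.PRNGKey(0), batch_size=128)
371 | 
372 |     model = RegressionModel(MLP(), dataset[0])
373 |     model.train(dataset, nIter=2_000)
374 |     print('RMSE:', model.get_rmse((xs, ys, ws)))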
--------------------------------------------------------------------------------
/allen_cahn/CausalAllenCahnModel.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from jax import random
4 | from jax import vmap, jit, grad, value_and_grad
5 |
6 | import optax
7 | from optax._src import linear_algebra
8 |
9 | from functools import partial
10 | import itertools
11 | from tqdm import trange
12 | import torch.utils.data as data
13 |
14 | import matplotlib.pyplot as plt
15 |
16 | import sys
17 | sys.path.append('..') # makes modules in the parent directory available to import
18 | from models import plot_logs
19 |
20 |
21 | class SquareDataset(data.Dataset):
22 |     ''' A data loader for creating mini-batches of uniformly sampled points in the
23 |     interior of a square/rectangle. Generates iid points in the interior of the
24 |     square/rectangle. These points are returned ordered by time (earliest first).
25 |
26 | Attributes:
27 | key: a PRNG key used as the random key.
28 | minvals: pair indicating minimum values for each dimension.
29 |         maxvals: pair indicating maximum values for each dimension.
30 | batch_size: the size of each mini-batch.
31 | '''
32 | def __init__(self, key, minvals=(-1,-1), maxvals=(1,1), batch_size=10_000):
33 | super().__init__()
34 | self.minvals = jnp.array(minvals)
35 | self.maxvals = jnp.array(maxvals)
36 | self.size = batch_size
37 | self.key = key
38 | self.batch_size = batch_size
39 |
40 | def __len__(self):
41 | return self.size
42 |
43 | def __getitem__(self, idx):
44 | self.key, subkey = random.split(self.key)
45 | interior_batch = self.__select_batch(subkey)
46 | return interior_batch
47 |
48 | @partial(jit, static_argnums=(0,))
49 | def __select_batch(self, subkey):
50 | interior_batch = random.uniform(subkey, shape=(self.batch_size, 2),
51 | minval=self.minvals, maxval=self.maxvals)
52 | # time is last coordinate
53 | idx_sort = jnp.argsort(interior_batch[:,-1])
54 | # returns batch with times ordered in increasing time
55 | return interior_batch[idx_sort]
56 |
57 |
58 | def get_ac_residual_fn(f, D=1e-4):
59 | '''
60 |     Computes the Allen-Cahn (AC) residual of a given function in a
61 |     computationally efficient way by avoiding unnecessary duplication of the
62 |     computational graph.
63 | 
64 |     Args:
65 |         f (Callable): a function with signature (params, x, t)->scalar
66 |         D (float): the D constant parameter for the AC equation.
67 |     Returns:
68 |         (Callable): a function with signature (params, x, t)->scalar that is the
69 |             AC residual of the input function.
70 | '''
71 | _single_f = lambda params, x, t : f(params, x[None,:], t[None,:]).squeeze()
72 | def _scalar_grads(params, x, t):
73 | u, (ux, ut) = value_and_grad(_single_f, argnums=(1,2))(params, x, t)
74 | return ux.squeeze(), (u, ut.squeeze())
75 | def _ac_aux(params, x, t):
76 | (u_x, (u, u_t)), u_xx = value_and_grad(_scalar_grads,
77 | argnums=1,
78 | has_aux=True)(params, x, t)
79 | return u, u_t, u_xx.squeeze()
80 | def _res_fn(params, xs, ts):
81 | u, ut, uxx = vmap(_ac_aux, in_axes=(None, 0, 0))(params, xs, ts)
82 | return ut - D*uxx + 5*(u**3 - u)
83 | return _res_fn
84 |
85 |
86 |
87 | class AllenCahnModel:
88 | """ A model for training/evaluating a neural network using physics-informed
89 | causal training on the Allen-Cahn problem.
90 |
91 | Attributes:
92 |         arch (nn.Module): a Flax module of the desired architecture.
93 | batch: an initial batch used for initializing parameters and computing
94 | normalization factors.
95 |         true_sol: the reference solution, if known, as a tuple ((x, t), u) (default: None).
96 | D: the D constant parameter for the AC equation.
97 | optimizer: the optimizer to be used when running gradient descent.
98 | normalize_inputs: whether to normalize inputs before passing them to the
99 | architecture (default: True).
100 | key: a PRNG key used as the random key for initialization.
101 |         exact_bd_condition: whether to exactly enforce the initial and boundary
102 |             conditions (see https://arxiv.org/abs/2104.08426). This implicitly
103 |             makes it so that border losses are not computed (default: True).
104 | bdr_enforcer_order: order of the polynomial used for enforcing border
105 | exactly. Should be an even integer (default: 2).
106 | steps_per_check (int): how many training steps to use between logging
107 | and displaying losses (default: 50).
108 | """
109 | def __init__(self, arch, batch, true_sol=None, D=1e-4,
110 | optimizer=None, normalize_inputs=True, key=random.PRNGKey(43),
111 | exact_bd_condition=True, bdr_enforcer_order=2,
112 | steps_per_check=50) -> None:
113 | # Define model
114 | self.arch = arch
115 | self.key = key
116 | self.steps_per_check = steps_per_check
117 | self.true_sol = true_sol
118 | self.D = D
119 |
120 | # Initialize parameters
121 | interior_batch = batch
122 | self.params = self.arch.init(self.key, interior_batch)
123 |
124 | # Tabulate function for checking network architecture
125 | self.tabulate = lambda : \
126 | self.arch.tabulate(self.key,
127 | interior_batch,
128 | console_kwargs={'width':110})
129 |
130 | # Vectorized functions
131 | self.normalize_inputs = normalize_inputs
132 | self.exact_bd_condition = exact_bd_condition
133 | self.bdr_enforcer_order = bdr_enforcer_order # should be an even number
134 | if normalize_inputs:
135 | mu_x = jnp.hstack(interior_batch).mean(0, keepdims=True)
136 | sig_x = jnp.hstack(interior_batch).std(0, keepdims=True)
137 | self.norm_stats = (mu_x, sig_x)
138 | _apply = lambda params, x, y : \
139 | self.arch.apply(params, (jnp.hstack([x, y])-mu_x)/sig_x)
140 | if self.exact_bd_condition:
141 | _apply = lambda params, x, t : \
142 | (1-t)*(x**2 * jnp.cos(jnp.pi*x)) \
143 | + t*((1-x**self.bdr_enforcer_order)*self.arch.apply(params,
144 | (jnp.hstack([x, t])-mu_x)/sig_x) - 1)
145 | else:
146 | _apply = lambda params, x, y : \
147 | self.arch.apply(params, (jnp.hstack([x, y])-mu_x)/sig_x)
148 | else:
149 | self.norm_stats = None
150 | if self.exact_bd_condition:
151 | _apply = lambda params, x, t : \
152 | (1-t)*(x**2 * jnp.cos(jnp.pi*x)) \
153 | + t*((1-x**self.bdr_enforcer_order)*self.arch.apply(params,
154 | jnp.hstack([x, t])) - 1)
155 | else:
156 | _apply = lambda params, x, y : \
157 | self.arch.apply(params, jnp.hstack([x, y]))
158 |         # jit the apply function for numerical consistency (the jitted version
159 |         # sometimes behaves slightly differently from the non-jitted one)
160 | self.apply = jit(_apply)
161 |
162 | # Vectorized derivatives.
163 | # functions prefixed by '_single' take in a vector of shape (1,) and
164 | # output a scalar of shape (,)
165 | _single_f = lambda params, x, y : \
166 | self.apply(params, x[None,:], y[None,:]).squeeze()
167 | # x derivatives
168 | _single_f_x = lambda params, x, y : \
169 | grad(_single_f, argnums=1)(params, x, y).squeeze() # scalar
170 | self.f_x = vmap(_single_f_x, in_axes=(None, 0, 0))
171 | _single_f_xx = lambda params, x, y : \
172 | grad(_single_f_x, argnums=1)(params, x, y).squeeze() # scalar
173 | self.f_xx = vmap(_single_f_xx, in_axes=(None, 0, 0))
174 | # y derivatives
175 | _single_f_y = lambda params, x, y : \
176 | grad(_single_f, argnums=2)(params, x, y).squeeze() # scalar
177 | self.f_y = vmap(_single_f_y, in_axes=(None, 0, 0))
178 | _single_f_yy = lambda params, x, y : \
179 | grad(_single_f_y, argnums=2)(params, x, y).squeeze() # scalar
180 | self.f_yy = vmap(_single_f_yy, in_axes=(None, 0, 0))
181 |         # Allen-Cahn residual operator
182 | self.ac_residual = get_ac_residual_fn(self.apply, D=self.D)
183 |
184 | # Optimizer
185 | if optimizer is None: # use a standard optimizer
186 | lr = optax.exponential_decay(1e-3, transition_steps=1000,
187 | decay_rate=0.8, end_value=1e-7)
188 | self.optimizer = optax.chain(
189 | optax.adaptive_grad_clip(1e-2),
190 | optax.adam(learning_rate=lr),
191 | )
192 | else:
193 | self.optimizer = optimizer
194 | self.opt_state = self.optimizer.init(self.params)
195 |
196 | # Logger
197 | self.itercount = itertools.count()
198 | self.loss_log = []
199 | self.grad_norm_log = []
200 | self.rel_l2_log = []
201 |
202 | def residual_loss(self, params, x, t, causal_eps):
203 | res = self.ac_residual(params, x, t)[:,None] # shape (batch_dim,1)
204 | goal = jnp.zeros_like(res)
205 | res = jnp.mean((res-goal)**2, axis=-1) # shape (batch_dim,)
206 | if causal_eps is None: # no causal learning
207 | return res
208 | else:
209 | # compute causal weights
210 | ws = jax.lax.stop_gradient(jnp.exp(-causal_eps*(jnp.cumsum(res) - res))) # shape (num_ts,)
211 |             # normalize so that the mean weight is ~1, keeping the loss in the
212 |             # same order of magnitude
213 | ws = ws/(ws.mean()+1e-3)
214 | assert ws.shape == res.shape, f"ws is shape {ws.shape} but res is shape {res.shape}"
215 | return ws*res
216 |
217 |
218 |
219 | def pinn_loss(self, params, interior_batch, causal_eps):
220 | r_loss = self.residual_loss(params,
221 | interior_batch[:,0][:,None],
222 | interior_batch[:,1][:,None],
223 | causal_eps)
224 | if self.exact_bd_condition:
225 | # no need to consider border loss, since it will be 0 when bdry
226 | # condition is exactly enforced
227 | return r_loss.mean()
228 | else:
229 | raise NotImplementedError
230 |             # consider residual loss, initial condition loss, and boundary condition loss
231 | #b_loss = self.border_loss(params, border_batch[:,0][:,None], border_batch[:,1][:,None])
232 | #return self.pinn_weights[0]*r_loss.mean() + self.pinn_weights[1]*b_loss.mean()
233 |
234 |
235 | @partial(jit, static_argnums=(0,))
236 | def loss(self, params, batch, causal_eps):
237 | interior_batch = batch
238 | return self.pinn_loss(params, interior_batch, causal_eps).mean() # scalar
239 |
240 |
241 | # Define a compiled update step
242 | @partial(jit, static_argnums=(0,))
243 | def step(self, params, opt_state, batch, causal_eps):
244 | grads = grad(self.loss)(params, batch, causal_eps)
245 | updates, opt_state = self.optimizer.update(grads, opt_state, params)
246 | params = optax.apply_updates(params, updates)
247 | return params, opt_state, grads
248 |
249 | # Optimize parameters in a loop
250 | def train(self, dataset, nIter = 10_000, causal_eps=None):
251 | """ Trains the neural network for nIter steps using data loader.
252 |
253 | Args:
254 | dataset (SquareDataset): data loader for training.
255 | nIter (int): number of training iterations.
256 | causal_eps (None | float): epsilon for computing causal loss. If
257 | None, does not use causal learning (default: None).
258 | """
259 | data = iter(dataset)
260 | pbar = trange(nIter)
261 | # Main training loop
262 | for it in pbar:
263 | batch = next(data)
264 | self.params, self.opt_state, grads = self.step(self.params,
265 | self.opt_state,
266 | batch,
267 | causal_eps)
268 | # Logger
269 | if it % self.steps_per_check == 0:
270 | l = self.loss(self.params, batch, causal_eps)
271 | g_norm = linear_algebra.global_norm(grads).squeeze()
272 | self.loss_log.append(l)
273 | self.grad_norm_log.append(g_norm)
274 | if self.true_sol is not None:
275 | pred = self.apply(self.params,
276 | self.true_sol[0][0],
277 | self.true_sol[0][1])
278 | true = self.true_sol[1]
279 | rel_l2_error = jnp.sqrt(((pred-true)**2).mean() \
280 | / ((true)**2).mean())
281 | self.rel_l2_log.append(rel_l2_error)
282 |                     pbar.set_postfix_str(f"loss:{l : .3e}, rel_l2:{rel_l2_error : .2e}, grad_norm:{jnp.mean(jnp.array(g_norm)) : .2e}")
283 | else:
284 | pbar.set_postfix({
285 | 'loss': l,
286 | 'grad_norm': jnp.mean(jnp.array(g_norm)),
287 | })
288 |
289 | def plot_logs(self, window=None) -> None:
290 | """ Plots logs of training losses and gradient norms through training.
291 |
292 | Args:
293 | window: desired window for computing moving averages (default: None).
294 | """
295 | plot_logs(self.loss_log, self.grad_norm_log, window=window,
296 | steps_per_check=self.steps_per_check)
297 |
298 | def batched_apply(self, x, batch_size=2_048):
299 | '''Performs forward pass using smaller batches, then concatenates them
300 | together before returning predictions. Useful for avoiding OoM issues
301 | when input is large.
302 |
303 | Args:
304 | x: input to the model
305 | batch_size: maximum batch size for computation.
306 |
307 | Returns:
308 | predictions of the model on input x
309 | '''
310 | num_batches = int(jnp.ceil(len(x) / batch_size))
311 | x_batches = jnp.split(x,
312 | batch_size*(1+jnp.arange(num_batches-1)),
313 | axis=0)
314 | pred_fn = jit(lambda ins : \
315 | self.apply(self.params,
316 | ins[:,0][:,None],
317 | ins[:,1][:,None]))
318 | y_pred = jnp.concatenate([pred_fn(ins) for ins in x_batches], axis=0)
319 | return y_pred
320 |
321 | def get_rmse(self, batch, batch_size=2_048):
322 | # Create predictions
323 | u, s_true = batch
324 | if batch_size is None: # single forward pass
325 |             s_pred = self.apply(self.params, u[:,0][:,None], u[:,1][:,None])
326 | else: # breaks prediction into smaller forward passes
327 | s_pred = self.batched_apply(u, batch_size=batch_size)
328 | error = s_pred - s_true
329 | rmse = jnp.sqrt(jnp.mean(error**2))
330 | return rmse
331 |
332 | def plot_predictions(self, batch, return_pred=False, batch_size=2_048,
333 | num_levels = 500):
334 | """Computes and plots model predictions for a given batch of data.
335 |
336 | Args:
337 | batch: data for creating/plotting results.
338 | return_pred: whether to return predictions after plotting
339 | (default: False).
340 | batch_size: batch size for computations (to avoid OoM issues in the
341 | case of large datasets). (default: 2048)
342 | num_levels: number of levels for contour plot (default: 500).
343 | """
344 | # Create predictions
345 | u, s_true = batch
346 | if batch_size is None: # single forward pass
347 |             s_pred = self.apply(self.params, u[:,0][:,None], u[:,1][:,None])
348 | else: # breaks prediction into smaller forward passes
349 | s_pred = self.batched_apply(u, batch_size=batch_size)
350 |
351 | error = s_pred - s_true
352 | rel_l2_error = jnp.sqrt(jnp.sum(error**2)/jnp.sum(s_true**2))
353 | print('Relative L2 error: {:.2e}'.format(rel_l2_error))
354 | print('RMSE: {:.2e}'.format(jnp.sqrt(jnp.mean(error**2))))
355 |
356 | plt.figure(figsize=(16, 4))
357 |
358 |         # Plotting examples of reconstructions
359 | plt.subplot(131)
360 | plt.tricontourf(u[:,1], u[:,0],
361 | s_pred.T.squeeze(), levels=num_levels, cmap='jet')
362 | plt.colorbar()
363 | plt.xlabel('$t$')
364 | plt.ylabel('$x$')
365 | plt.title('Prediction')
366 |
367 |         # Plotting true solution
368 | plt.subplot(132)
369 | plt.tricontourf(u[:,1], u[:,0],
370 | s_true.T.squeeze(), levels=num_levels, cmap='jet')
371 | plt.colorbar()
372 | plt.xlabel('$t$')
373 | plt.ylabel('$x$')
374 | plt.title('True')
375 |
376 |         # Plotting absolute error
377 | plt.subplot(133)
378 | plt.tricontourf(u[:,1], u[:,0],
379 | abs(s_pred-s_true).T.squeeze(),
380 | levels=num_levels, cmap='jet')
381 | plt.colorbar()
382 | plt.xlabel('$t$')
383 | plt.ylabel('$x$')
384 | plt.title('Absolute Error')
385 |
386 |         plt.show()
387 | 
388 |         if return_pred:
389 |             return s_pred
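390 | 
391 | 
392 | # Sanity-check sketch (illustrative): residual_loss above weights collocation
393 | # points by w_i = exp(-eps * sum_{j<i} r_j), with points sorted by time in
394 | # SquareDataset, so later times are downweighted until the residual at earlier
395 | # times has been driven down. A tiny numeric example:
396 | if __name__ == '__main__':
397 |     res = jnp.array([1.0, 1.0, 1.0, 1.0])  # residuals, sorted by time
398 |     eps = 1.0
399 |     ws = jnp.exp(-eps * (jnp.cumsum(res) - res))
400 |     print(ws / (ws.mean() + 1e-3))  # weights decay for later times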
--------------------------------------------------------------------------------
/poisson_2d/PoissonModel.py:
--------------------------------------------------------------------------------
1 | # Basic Library Imports
2 | import jax
3 | import jax.numpy as jnp
4 | from jax import random
5 | from jax import vmap, jit, grad, value_and_grad
6 |
7 | import optax
8 | from optax._src import linear_algebra
9 | import jaxopt
10 |
11 | import matplotlib.pyplot as plt
12 |
13 | from functools import partial
14 | import itertools
15 | from tqdm import trange
16 |
17 | import sys
18 | sys.path.append('..') # makes modules in the parent directory available to import
19 | from models import plot_logs
20 |
21 |
22 |
23 | def get_laplacian(f):
24 |     '''Computes the 2D laplacian of a given function in a computationally
25 |     efficient way by avoiding unnecessary duplication of the computational graph.
26 |
27 | Args:
28 | f (Callable): a function with signature (params, x, y)->scalar
29 | Returns:
30 |         (Callable): a function with signature (params, x, y)->scalar that is the
31 |             laplacian of the input function.
32 | '''
33 | _single_f = lambda params, x, y : f(params, x[None,:], y[None,:]).squeeze()
34 | def _scalar_grads(params, x, y):
35 | ux, uy = grad(_single_f, argnums=(1,2))(params, x, y)
36 | return ux.squeeze(), uy.squeeze()
37 | def _lapl_aux(params, x, y):
38 | (u_x, u_y), u_xx = value_and_grad(_scalar_grads,
39 | argnums=1,
40 | has_aux=True)(params, x, y)
41 | return u_y, u_xx.squeeze()
42 | def _lapl_aux_2(params, x, y):
43 | (u_y, u_xx), u_yy = value_and_grad(_lapl_aux,
44 | argnums=2,
45 | has_aux=True)(params, x, y)
46 | return u_xx, u_yy.squeeze()
47 | def _lapl(params, xs, ys):
48 | uxx, uyy = vmap(_lapl_aux_2, in_axes=(None, 0, 0))(params, xs, ys)
49 | return uxx + uyy
50 | return _lapl
51 |
52 | class PoissonModel:
53 | """ A model for training/evaluating a neural network using physics-informed
54 | training on the 2D Poisson problem.
55 |
56 | Attributes:
57 |         arch (nn.Module): a Flax module of the desired architecture.
58 | batch: an initial batch used for initializing parameters and computing
59 | normalization factors.
60 |         true_fun: the ground-truth solution function, if known (default: None).
61 | pde_res_fn: the target function for the PDE differential operator.
62 | optimizer: the optimizer to be used when running gradient descent.
63 | normalize_inputs: whether to normalize inputs before passing them to the
64 | architecture (default: True).
65 | key: a PRNG key used as the random key for initialization.
66 |         pinn_weights: pair of weights for balancing residual/border losses.
67 | exact_bd_condition: whether to exactly enforce boundary condition on
68 | [-1,1]^2 (see https://arxiv.org/abs/2104.08426). This implicitly
69 |             makes it so that border losses are not computed (default: True).
70 | bdr_enforcer_order: order of the polynomial used for enforcing border
71 | exactly. Should be an even integer (default: 2).
72 | steps_per_check (int): how many training steps to use between logging
73 | and displaying losses (default: 50).
74 | """
75 | def __init__(self, arch, batch, true_fun=None,
76 | pde_res_fn=lambda x, y : \
77 | -(jnp.pi**2)*(1+4*(y**2))*jnp.sin(jnp.pi*x)*jnp.sin(jnp.pi*(y**2)) \
78 | + 2*jnp.pi*jnp.sin(jnp.pi*x)*jnp.cos(jnp.pi*(y**2)),
79 | optimizer=None, normalize_inputs=True, key=random.PRNGKey(43),
80 | pinn_weights=(0.001, 1.), exact_bd_condition=True,
81 | bdr_enforcer_order=2, steps_per_check=50) -> None:
82 | # Define model
83 | self.arch = arch
84 | self.key = key
85 | self.steps_per_check = steps_per_check
86 | self.pde_res_fn = pde_res_fn
87 | self.true_fun = true_fun
88 | self.pinn_weights = pinn_weights # not really used in the experiments (boundary is enforced exactly)
89 |
90 | # Initialize parameters
91 | interior_batch, border_batch = batch
92 | self.params = self.arch.init(self.key, interior_batch)
93 |
94 | # Tabulate function for checking network architecture
95 | self.tabulate = lambda : \
96 | self.arch.tabulate(self.key,
97 | interior_batch,
98 | console_kwargs={'width':110})
99 |
100 | # Vectorized functions
101 | self.normalize_inputs = normalize_inputs
102 | self.exact_bd_condition = exact_bd_condition
103 | self.bdr_enforcer_order = bdr_enforcer_order # should be an even number
104 | if normalize_inputs:
105 | mu_x = jnp.hstack(interior_batch).mean(0, keepdims=True)
106 | sig_x = jnp.hstack(interior_batch).std(0, keepdims=True)
107 | self.norm_stats = (mu_x, sig_x)
108 | _apply = lambda params, x, y : \
109 | self.arch.apply(params, (jnp.hstack([x, y])-mu_x)/sig_x)
110 | if self.exact_bd_condition:
111 | _apply = lambda params, x, y : \
112 | (1-x**self.bdr_enforcer_order)\
113 | *(1-y**self.bdr_enforcer_order)\
114 | *self.arch.apply(params, (jnp.hstack([x, y])-mu_x)/sig_x)
115 | else:
116 | _apply = lambda params, x, y : \
117 | self.arch.apply(params, (jnp.hstack([x, y])-mu_x)/sig_x)
118 | else:
119 | self.norm_stats = None
120 | if self.exact_bd_condition:
121 | _apply = lambda params, x, y : \
122 | (1-x**self.bdr_enforcer_order)\
123 | *(1-y**self.bdr_enforcer_order)\
124 | *self.arch.apply(params, jnp.hstack([x, y]))
125 | else:
126 | _apply = lambda params, x, y : \
127 | self.arch.apply(params, jnp.hstack([x, y]))
128 |         # jit the apply function for numerical consistency (the jitted version
129 |         # sometimes behaves slightly differently from the non-jitted one)
130 | self.apply = jit(_apply)
131 |
132 | # Vectorized derivatives.
133 | # functions prefixed by '_single' take in a vector of shape (1,) and
134 | # output a scalar of shape (,)
135 | _single_f = lambda params, x, y : \
136 | self.apply(params, x[None,:], y[None,:]).squeeze()
137 | # x derivatives
138 | _single_f_x = lambda params, x, y : \
139 | grad(_single_f, argnums=1)(params, x, y).squeeze() # scalar
140 | self.f_x = vmap(_single_f_x, in_axes=(None, 0, 0))
141 | _single_f_xx = lambda params, x, y : \
142 | grad(_single_f_x, argnums=1)(params, x, y).squeeze() # scalar
143 | self.f_xx = vmap(_single_f_xx, in_axes=(None, 0, 0))
144 | # y derivatives
145 | _single_f_y = lambda params, x, y : \
146 | grad(_single_f, argnums=2)(params, x, y).squeeze() # scalar
147 | self.f_y = vmap(_single_f_y, in_axes=(None, 0, 0))
148 | _single_f_yy = lambda params, x, y : \
149 | grad(_single_f_y, argnums=2)(params, x, y).squeeze() # scalar
150 | self.f_yy = vmap(_single_f_yy, in_axes=(None, 0, 0))
151 | # laplacian
152 | self.laplacian = get_laplacian(self.apply)
153 |
154 | # Optimizer
155 | if optimizer is None: # use a standard optimizer
156 | lr = optax.exponential_decay(1e-3, transition_steps=1000,
157 | decay_rate=0.8, end_value=1e-7)
158 | self.optimizer = optax.chain(
159 | optax.adaptive_grad_clip(1e-2),
160 | optax.adam(learning_rate=lr),
161 | )
162 | else:
163 | self.optimizer = optimizer
164 | self.opt_state = self.optimizer.init(self.params)
165 |
166 | # Optimizer LBFGS
167 | self.optimizer_lbfgs = jaxopt.LBFGS(self.loss)
168 | self.opt_state_lbfgs = self.optimizer_lbfgs.init_state(self.params,
169 | batch)
170 | self.optimizer_update_lbfgs = jit(self.optimizer_lbfgs.update)
171 |
172 | # Logger
173 | self.itercount = itertools.count()
174 | self.loss_log = []
175 | self.grad_norm_log = []
176 | self.rel_l2_log = []
177 |
178 | def residual_loss(self, params, x, y):
179 | res = self.laplacian(params, x, y)[:,None] # shape (batch_dim,1)
180 | goal = self.pde_res_fn(x, y) # shape (batch_dim,1)
181 | return jnp.mean((res-goal)**2, axis=-1) # shape (batch_dim,)
182 |
183 | def border_loss(self, params, x, y):
184 | # function should be zero at the boundary
185 | outputs = self.apply(params, x, y) # shape (batch_dim, out_dim)
186 | return jnp.mean(outputs**2, axis=-1) # shape (batch_dim,)
187 |
188 | def pinn_loss(self, params, interior_batch, border_batch):
189 | r_loss = self.residual_loss(params,
190 | interior_batch[:,0][:,None],
191 | interior_batch[:,1][:,None])
192 | if self.exact_bd_condition:
193 | # no need to consider border loss, since it will be 0 when bdry
194 | # condition is exactly enforced
195 | return self.pinn_weights[0]*r_loss.mean()
196 | else:
197 | # consider both residual loss and boundary condition loss
198 | b_loss = self.border_loss(params,
199 | border_batch[:,0][:,None],
200 | border_batch[:,1][:,None])
201 | return self.pinn_weights[0]*r_loss.mean() \
202 | + self.pinn_weights[1]*b_loss.mean()
203 |
204 |
205 | @partial(jit, static_argnums=(0,))
206 | def loss(self, params, batch):
207 | interior_batch, border_batch = batch
208 | return self.pinn_loss(params, interior_batch, border_batch).mean() # scalar
209 |
210 |
211 | # Define a compiled update step
212 | @partial(jit, static_argnums=(0,))
213 | def step(self, params, opt_state, batch):
214 | grads = grad(self.loss)(params, batch)
215 | updates, opt_state = self.optimizer.update(grads, opt_state, params)
216 | params = optax.apply_updates(params, updates)
217 | return params, opt_state, grads
218 |
219 | # Optimize parameters in a loop
220 | def train(self, dataset, nIter = 10_000):
221 | """ Trains the neural network for nIter steps using data loader.
222 |
223 | Args:
224 | dataset (SquareDataset): data loader for training.
225 | nIter (int): number of training iterations.
226 | """
227 | data = iter(dataset)
228 | pbar = trange(nIter)
229 | # Main training loop
230 | for it in pbar:
231 | batch = next(data)
232 | self.params, self.opt_state, grads = self.step(self.params,
233 | self.opt_state,
234 | batch)
235 | # Logger
236 | if it % self.steps_per_check == 0:
237 | l = self.loss(self.params, batch)
238 | g_norm = linear_algebra.global_norm(grads).squeeze()
239 | self.loss_log.append(l)
240 | self.grad_norm_log.append(g_norm)
241 | if self.true_fun is not None: # true function is known
242 | interior_batch, border_batch = batch
243 | pred = self.apply(self.params,
244 | interior_batch[:,0][:,None],
245 | interior_batch[:,1][:,None])
246 | true = self.true_fun(interior_batch[:,0][:,None],
247 | interior_batch[:,1][:,None])
248 | rel_l2_error = jnp.sqrt(((pred-true)**2).mean() \
249 | / ((true)**2).mean())
250 | self.rel_l2_log.append(rel_l2_error)
251 | pbar.set_postfix_str(
252 | f"loss:{l : .3e}, rel_l2:{rel_l2_error : .2e}, 'grad_norm':{jnp.mean(jnp.array(g_norm)) : .2e}")
253 | else: # true function is unknown
254 | pbar.set_postfix({
255 | 'loss': l,
256 | 'grad_norm': jnp.mean(jnp.array(g_norm))})
257 |
258 | # Define a compiled update step
259 | @partial(jit, static_argnums=(0,))
260 | def step_lbfgs(self, params, opt_state, batch):
261 | new_params, opt_state = self.optimizer_update_lbfgs(params,
262 | opt_state,
263 | batch)
264 | return new_params, opt_state
265 |
266 | # Optimize parameters in a loop
267 | def train_lbfgs(self, dataset, nIter = 10000):
268 | """ Trains the neural network using LBFGS optimizer for nIter steps
269 | using data loader.
270 |
271 | Args:
272 | dataset (SquareDataset): data loader for training.
273 | nIter (int): number of training iterations.
274 | """
275 | data = iter(dataset)
276 | pbar = trange(nIter)
277 | batch = next(data)
278 | self.opt_state_lbfgs = self.optimizer_lbfgs.init_state(self.params,
279 | batch)
280 | # Main training loop
281 | for it in pbar:
282 | batch = next(data)
283 | # Logger
284 | if it % self.steps_per_check == 0:
285 | l = self.loss(self.params, batch)
286 | self.loss_log.append(l)
287 | grads = grad(self.loss)(self.params, batch)
288 | g_norm = linear_algebra.global_norm(grads).squeeze()
289 | self.grad_norm_log.append(g_norm)
290 | if self.true_fun is not None:
291 | interior_batch, border_batch = batch
292 | pred = self.apply(self.params, interior_batch[:,0][:,None],
293 | interior_batch[:,1][:,None])
294 | true = self.true_fun(interior_batch[:,0][:,None],
295 | interior_batch[:,1][:,None])
296 | rel_l2_error = jnp.sqrt(((pred-true)**2).mean() \
297 | / ((true)**2).mean())
298 | self.rel_l2_log.append(rel_l2_error)
299 | pbar.set_postfix_str(
300 | f"loss:{l : .3e}, rel_l2:{rel_l2_error : .2e}, 'grad_norm':{jnp.mean(jnp.array(g_norm)) : .2e}")
301 | else:
302 | pbar.set_postfix({
303 | 'loss': l,
304 | 'grad_norm': jnp.mean(jnp.array(g_norm))})
305 | # take step
306 | self.params, self.opt_state_lbfgs = self.step_lbfgs(self.params,
307 | self.opt_state_lbfgs,
308 | batch)
309 |
310 |
311 | def plot_logs(self, window=None) -> None:
312 | """ Plots logs of training losses and gradient norms through training.
313 |
314 | Args:
315 | window: desired window for computing moving averages (default: None).
316 | """
317 | plot_logs(self.loss_log, self.grad_norm_log, window=window,
318 | steps_per_check=self.steps_per_check)
319 |
320 | def batched_apply(self, x, batch_size=2_048):
321 | '''Performs forward pass using smaller batches, then concatenates them
322 | together before returning predictions. Useful for avoiding OoM issues
323 | when input is large.
324 |
325 | Args:
326 | x: input to the model
327 | batch_size: maximum batch size for computation.
328 |
329 | Returns:
330 | predictions of the model on input x
331 | '''
332 | num_batches = int(jnp.ceil(len(x) / batch_size))
333 | x_batches = jnp.split(x,
334 | batch_size*(1+jnp.arange(num_batches-1)),
335 | axis=0)
336 | pred_fn = jit(lambda ins : self.apply(self.params,
337 | ins[:,0][:,None],
338 | ins[:,1][:,None]))
339 | y_pred = jnp.concatenate([pred_fn(ins) for ins in x_batches], axis=0)
340 | return y_pred
341 |
342 | def get_rmse(self, batch, batch_size=2_048):
343 | # Create predictions
344 | u, s_true = batch
345 | if batch_size is None: # single forward pass
346 |             s_pred = self.apply(self.params, u[:,0][:,None], u[:,1][:,None])
347 | else: # breaks prediction into smaller forward passes
348 | s_pred = self.batched_apply(u, batch_size=batch_size)
349 | error = s_pred - s_true
350 | rmse = jnp.sqrt(jnp.mean(error**2))
351 | return rmse
352 |
353 | def plot_predictions(self, batch, return_pred=False, batch_size=2_048,
354 | num_levels = 100):
355 | """Computes and plots model predictions for a given batch of data.
356 |
357 | Args:
358 | batch: data for creating/plotting results.
359 | return_pred: whether to return predictions after plotting
360 | (default: False).
361 | batch_size: batch size for computations (to avoid OoM issues in the
362 | case of large datasets). (default: 2048)
363 | num_levels: number of levels for contour plot (default: 100).
364 | """
365 | # Create predictions
366 | u, s_true = batch
367 | if batch_size is None: # single forward pass
368 |             s_pred = self.apply(self.params, u[:,0][:,None], u[:,1][:,None])
369 | else: # breaks prediction into smaller forward passes
370 | s_pred = self.batched_apply(u, batch_size=batch_size)
371 |
372 | error = s_pred - s_true
373 | rel_l2_error = jnp.sqrt(jnp.sum(error**2)/jnp.sum(s_true**2))
374 | print('Relative L2 error: {:.2e}'.format(rel_l2_error))
375 | print('RMSE: {:.2e}'.format(jnp.sqrt(jnp.mean(error**2))))
376 |
377 | plt.figure(figsize=(16, 4))
378 |
379 |         # Plotting examples of reconstructions
380 | plt.subplot(131)
381 | plt.tricontourf(u[:,0], u[:,1],
382 | s_pred.squeeze(), levels=num_levels)
383 | plt.colorbar()
384 | plt.xlabel('$x$')
385 | plt.ylabel('$y$')
386 | plt.title('Prediction')
387 |
388 |         # Plotting true solution
389 | plt.subplot(132)
390 | plt.tricontourf(u[:,0], u[:,1],
391 | s_true.squeeze(), levels=num_levels)
392 | plt.colorbar()
393 | plt.xlabel('$x$')
394 | plt.ylabel('$y$')
395 | plt.title('True')
396 |
397 |         # Plotting absolute error
398 | plt.subplot(133)
399 | plt.tricontourf(u[:,0], u[:,1],
400 | abs(s_pred-s_true).squeeze(), levels=num_levels)
401 | plt.colorbar()
402 | plt.xlabel('$x$')
403 | plt.ylabel('$y$')
404 | plt.title('Absolute Error')
405 |
406 |         plt.show()
407 | 
408 |         if return_pred:
409 |             return s_pred
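410 | 
411 | 
412 | # Sanity-check sketch (illustrative): the default pde_res_fn in PoissonModel is
413 | # the analytic Laplacian of u(x, y) = sin(pi*x)*sin(pi*y^2), which vanishes on
414 | # the boundary of [-1,1]^2, so that u can be passed as true_fun to track the
415 | # relative L2 error during training. The check below compares it against the
416 | # Laplacian computed by get_laplacian via automatic differentiation.
417 | if __name__ == '__main__':
418 |     u = lambda params, x, y: jnp.sin(jnp.pi*x) * jnp.sin(jnp.pi*(y**2))
419 |     lapl = get_laplacian(u)
420 |     xs = jnp.linspace(-0.9, 0.9, 5)[:, None]
421 |     ys = jnp.linspace(-0.8, 0.8, 5)[:, None]
422 |     analytic = -(jnp.pi**2)*(1+4*(ys**2))*jnp.sin(jnp.pi*xs)*jnp.sin(jnp.pi*(ys**2)) \
423 |                + 2*jnp.pi*jnp.sin(jnp.pi*xs)*jnp.cos(jnp.pi*(ys**2))
424 |     print(jnp.abs(lapl(None, xs, ys) - analytic.squeeze()).max())  # ~0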
--------------------------------------------------------------------------------
/helmholtz_2d/HelmholtzModel.py:
--------------------------------------------------------------------------------
1 | # Basic Library Imports
2 | import jax
3 | import jax.numpy as jnp
4 | from jax import random
5 | from jax import vmap, jit, grad, value_and_grad
6 |
7 | import optax
8 | from optax._src import linear_algebra
9 | import jaxopt
10 |
11 | import matplotlib.pyplot as plt
12 |
13 | from functools import partial
14 | import itertools
15 | from tqdm import trange
16 |
17 | import sys
18 | sys.path.append('..') # makes modules in the parent directory available to import
19 | from models import plot_logs
20 |
21 | def get_value_and_laplacian(f):
22 |     '''
23 |     Computes the value and 2D laplacian of a given function in a computationally
24 |     efficient way by avoiding unnecessary duplication of the computational graph.
25 | 
26 |     Args:
27 |         f (Callable): a function with signature (params, x, y)->scalar
28 |     Returns:
29 |         (Callable): a function with signature (params, x, y)->(scalar, scalar)
30 |             returning the value and the laplacian of the input function.
31 |     '''
32 | _single_f = lambda params, x, y : f(params, x[None,:], y[None,:]).squeeze()
33 | def _scalar_grads(params, x, y):
34 | u, (ux, uy) = value_and_grad(_single_f, argnums=(1,2))(params, x, y)
35 | return ux.squeeze(), (u, uy.squeeze())
36 | def _lapl_aux(params, x, y):
37 | (u_x, (u, u_y)), u_xx = value_and_grad(_scalar_grads,
38 | argnums=1,
39 | has_aux=True)(params, x, y)
40 | return u_y, (u, u_xx.squeeze())
41 | def _lapl_aux_2(params, x, y):
42 | (u_y, (u, u_xx)), u_yy = value_and_grad(_lapl_aux,
43 | argnums=2,
44 | has_aux=True)(params, x, y)
45 | return u, u_xx, u_yy.squeeze()
46 | def _value_and_lapl(params, xs, ys):
47 | u, uxx, uyy = vmap(_lapl_aux_2, in_axes=(None, 0, 0))(params, xs, ys)
48 | return u, uxx + uyy
49 | return _value_and_lapl
50 |
51 |
52 | class HelmholtzModel:
53 | """ A model for training/evaluating a neural network using physics-informed
54 | training on the 2D Helmholtz problem.
55 |
56 | Attributes:
57 |         arch (nn.Module): a Flax module of the desired architecture.
58 | batch: an initial batch used for initializing parameters and computing
59 | normalization factors.
60 |         true_fun: the ground-truth solution function, if known (default: None).
61 | pde_res_fn: the target function for the PDE differential operator.
62 | optimizer: the optimizer to be used when running gradient descent.
63 | normalize_inputs: whether to normalize inputs before passing them to the
64 | architecture (default: True).
65 | key: a PRNG key used as the random key for initialization.
66 |         pinn_weights: pair of weights for balancing residual/border losses.
67 | exact_bd_condition: whether to exactly enforce boundary condition on
68 | [-1,1]^2 (see https://arxiv.org/abs/2104.08426). This implicitly
69 |             makes it so that border losses are not computed (default: False).
70 | bdr_enforcer_order: order of the polynomial used for enforcing border
71 | exactly. Should be an even integer (default: 2).
72 | steps_per_check (int): how many training steps to use between logging
73 | and displaying losses (default: 50).
74 | """
75 | def __init__(self, arch, batch, true_fun=None,
76 | pde_res_fn=lambda x, y : \
77 | -(jnp.pi**2)*(1+4*(y**2))*jnp.sin(jnp.pi*x)*jnp.sin(jnp.pi*(y**2)) \
78 | + 2*jnp.pi*jnp.sin(jnp.pi*x)*jnp.cos(jnp.pi*(y**2)),
79 | kappa = 1,
80 | optimizer=None, normalize_inputs=True, key=random.PRNGKey(43),
81 | pinn_weights=(0.001, 1.), exact_bd_condition=False,
82 | bdr_enforcer_order=2, steps_per_check=50) -> None:
83 | # Define model
84 | self.arch = arch
85 | self.key = key
86 | self.steps_per_check = steps_per_check
87 | self.pde_res_fn = pde_res_fn
88 | self.kappa = kappa
89 | self.true_fun = true_fun
90 | self.pinn_weights = pinn_weights # not really used in the experiments (boundary is enforced exactly)
91 |
92 | # Initialize parameters
93 | interior_batch, border_batch = batch
94 | self.params = self.arch.init(self.key, interior_batch)
95 |
96 | # Tabulate function for checking network architecture
97 | self.tabulate = lambda : \
98 | self.arch.tabulate(self.key,
99 | interior_batch,
100 | console_kwargs={'width':110})
101 |
102 | # Vectorized functions
103 | self.normalize_inputs = normalize_inputs
104 | self.exact_bd_condition = exact_bd_condition
105 | self.bdr_enforcer_order = bdr_enforcer_order # should be an even number
106 | if normalize_inputs:
107 | mu_x = jnp.hstack(interior_batch).mean(0, keepdims=True)
108 | sig_x = jnp.hstack(interior_batch).std(0, keepdims=True)
109 | self.norm_stats = (mu_x, sig_x)
110 | _apply = lambda params, x, y : \
111 | self.arch.apply(params, (jnp.hstack([x, y])-mu_x)/sig_x)
112 | if self.exact_bd_condition:
113 | _apply = lambda params, x, y : \
114 | (1-x**self.bdr_enforcer_order)\
115 | *(1-y**self.bdr_enforcer_order)\
116 | *self.arch.apply(params,
117 | (jnp.hstack([x, y])-mu_x)/sig_x)
118 | else:
119 | _apply = lambda params, x, y : \
120 | self.arch.apply(params,
121 | (jnp.hstack([x, y])-mu_x)/sig_x)
122 | else:
123 | self.norm_stats = None
124 | if self.exact_bd_condition:
125 | _apply = lambda params, x, y : \
126 | (1-x**self.bdr_enforcer_order)\
127 | *(1-y**self.bdr_enforcer_order)\
128 | *self.arch.apply(params,
129 | jnp.hstack([x, y]))
130 | else:
131 | _apply = lambda params, x, y : \
132 | self.arch.apply(params,
133 | jnp.hstack([x, y]))
134 |         # jit the apply function for numerical consistency (the jitted version
135 |         # sometimes behaves slightly differently from the non-jitted one)
136 | self.apply = jit(_apply)
137 |
138 | # Vectorized derivatives.
139 | # functions prefixed by '_single' take in a vector of shape (1,) and
140 | # output a scalar of shape (,)
141 | _single_f = lambda params, x, y : \
142 | self.apply(params, x[None,:], y[None,:]).squeeze()
143 | # x derivatives
144 | _single_f_x = lambda params, x, y : \
145 | grad(_single_f, argnums=1)(params, x, y).squeeze() # scalar
146 | self.f_x = vmap(_single_f_x, in_axes=(None, 0, 0))
147 | _single_f_xx = lambda params, x, y : \
148 | grad(_single_f_x, argnums=1)(params, x, y).squeeze() # scalar
149 | self.f_xx = vmap(_single_f_xx, in_axes=(None, 0, 0))
150 | # y derivatives
151 | _single_f_y = lambda params, x, y : \
152 | grad(_single_f, argnums=2)(params, x, y).squeeze() # scalar
153 | self.f_y = vmap(_single_f_y, in_axes=(None, 0, 0))
154 | _single_f_yy = lambda params, x, y : \
155 | grad(_single_f_y, argnums=2)(params, x, y).squeeze() # scalar
156 | self.f_yy = vmap(_single_f_yy, in_axes=(None, 0, 0))
157 |         # value and laplacian
158 | self.value_and_laplacian = get_value_and_laplacian(self.apply)
159 |
160 | # Optimizer
161 | if optimizer is None: # use a standard optimizer
162 | lr = optax.exponential_decay(1e-3, transition_steps=1000,
163 | decay_rate=0.8, end_value=1e-7)
164 | self.optimizer = optax.chain(
165 | optax.adaptive_grad_clip(1e-2),
166 | optax.adam(learning_rate=lr),
167 | )
168 | else:
169 | self.optimizer = optimizer
170 | self.opt_state = self.optimizer.init(self.params)
171 |
172 | # Optimizer LBFGS
173 | self.optimizer_lbfgs = jaxopt.LBFGS(self.loss)
174 | self.opt_state_lbfgs = self.optimizer_lbfgs.init_state(self.params,
175 | batch)
176 | self.optimizer_update_lbfgs = jit(self.optimizer_lbfgs.update)
177 |
178 | # Logger
179 | self.itercount = itertools.count()
180 | self.loss_log = []
181 | self.grad_norm_log = []
182 | self.rel_l2_log = []
183 |
184 | def residual_loss(self, params, x, y):
185 |         # Helmholtz residual: Delta(u) + kappa^2 * u
186 | u, delta_u = self.value_and_laplacian(params, x, y)
187 | u, delta_u = u[:,None], delta_u[:,None]
188 | res = delta_u + (self.kappa**2)*u
189 | goal = self.pde_res_fn(x, y) # shape (batch_dim,1)
190 | return jnp.mean((res-goal)**2, axis=-1) # shape (batch_dim,)
191 |
192 | def border_loss(self, params, x, y):
193 | # function should be zero at the boundary
194 | outputs = self.apply(params, x, y) # shape (batch_dim, out_dim)
195 | return jnp.mean(outputs**2, axis=-1) # shape (batch_dim,)
196 |
197 | def pinn_loss(self, params, interior_batch, border_batch):
198 | r_loss = self.residual_loss(params,
199 | interior_batch[:,0][:,None],
200 | interior_batch[:,1][:,None])
201 | if self.exact_bd_condition:
202 | # no need to consider border loss, since it will be 0 when bdry
203 | # condition is exactly enforced
204 | return self.pinn_weights[0]*r_loss.mean()
205 | else:
206 | # consider both residual loss and boundary condition loss
207 | b_loss = self.border_loss(params,
208 | border_batch[:,0][:,None],
209 | border_batch[:,1][:,None])
210 | return self.pinn_weights[0]*r_loss.mean() \
211 | + self.pinn_weights[1]*b_loss.mean()
212 |
213 |
214 | @partial(jit, static_argnums=(0,))
215 | def loss(self, params, batch):
216 | interior_batch, border_batch = batch
217 | return self.pinn_loss(params, interior_batch, border_batch).mean() # scalar
218 |
219 |
220 | # Define a compiled update step
221 | @partial(jit, static_argnums=(0,))
222 | def step(self, params, opt_state, batch):
223 | grads = grad(self.loss)(params, batch)
224 | updates, opt_state = self.optimizer.update(grads, opt_state, params)
225 | params = optax.apply_updates(params, updates)
226 | return params, opt_state, grads
227 |
228 | # Optimize parameters in a loop
229 | def train(self, dataset, nIter = 10_000):
230 | """ Trains the neural network for nIter steps using data loader.
231 |
232 | Args:
233 | dataset (SquareDataset): data loader for training.
234 | nIter (int): number of training iterations.
235 | """
236 | data = iter(dataset)
237 | pbar = trange(nIter)
238 | # Main training loop
239 | for it in pbar:
240 | batch = next(data)
241 | self.params, self.opt_state, grads = self.step(self.params,
242 | self.opt_state,
243 | batch)
244 | # Logger
245 | if it % self.steps_per_check == 0:
246 | l = self.loss(self.params, batch)
247 | g_norm = linear_algebra.global_norm(grads).squeeze()
248 | self.loss_log.append(l)
249 | self.grad_norm_log.append(g_norm)
250 | if self.true_fun is not None:
251 | interior_batch, border_batch = batch
252 | pred = self.apply(self.params,
253 | interior_batch[:,0][:,None],
254 | interior_batch[:,1][:,None])
255 | true = self.true_fun(interior_batch[:,0][:,None],
256 | interior_batch[:,1][:,None])
257 | rel_l2_error = jnp.sqrt(((pred-true)**2).mean() \
258 | / ((true)**2).mean())
259 | self.rel_l2_log.append(rel_l2_error)
260 | pbar.set_postfix_str(f"loss:{l : .3e}, rel_l2:{rel_l2_error : .2e}, grad_norm:{jnp.mean(jnp.array(g_norm)) : .2e}")
261 | else:
262 | pbar.set_postfix({
263 | 'loss': l,
264 | 'grad_norm': jnp.mean(jnp.array(g_norm)),
265 | })
266 |
267 | # Define a compiled update step
268 | @partial(jit, static_argnums=(0,))
269 | def step_lbfgs(self, params, opt_state, batch):
270 | new_params, opt_state = self.optimizer_update_lbfgs(params,
271 | opt_state,
272 | batch)
273 | return new_params, opt_state
274 |
275 | # Optimize parameters in a loop
276 | def train_lbfgs(self, dataset, nIter = 10000):
277 | """ Trains the neural network using LBFGS optimizer for nIter steps
278 | using data loader.
279 |
280 | Args:
281 | dataset (SquareDataset): data loader for training.
282 | nIter (int): number of training iterations.
283 | """
284 | data = iter(dataset)
285 | pbar = trange(nIter)
286 | batch = next(data)
287 | self.opt_state_lbfgs = self.optimizer_lbfgs.init_state(self.params,
288 | batch)
289 | # Main training loop
290 | for it in pbar:
291 | batch = next(data)
292 | # Logger
293 | if it % self.steps_per_check == 0:
294 | l = self.loss(self.params, batch)
295 | self.loss_log.append(l)
296 | grads = grad(self.loss)(self.params, batch)
297 | g_norm = linear_algebra.global_norm(grads).squeeze()
298 | self.grad_norm_log.append(g_norm)
299 | if self.true_fun is not None:
300 | interior_batch, border_batch = batch
301 | pred = self.apply(self.params,
302 | interior_batch[:,0][:,None],
303 | interior_batch[:,1][:,None])
304 | true = self.true_fun(interior_batch[:,0][:,None],
305 | interior_batch[:,1][:,None])
306 | rel_l2_error = jnp.sqrt(((pred-true)**2).mean() \
307 | / ((true)**2).mean())
308 | self.rel_l2_log.append(rel_l2_error)
309 | pbar.set_postfix_str(f"loss:{l : .3e}, rel_l2:{rel_l2_error : .2e}, grad_norm:{jnp.mean(jnp.array(g_norm)) : .2e}")
310 | else:
311 | pbar.set_postfix({
312 | 'loss': l,
313 | 'grad_norm': jnp.mean(jnp.array(g_norm)),
314 | })
315 | # take step
316 | self.params, self.opt_state_lbfgs = self.step_lbfgs(self.params,
317 | self.opt_state_lbfgs,
318 | batch)
319 |
320 |
321 | def plot_logs(self, window=None) -> None:
322 | """ Plots logs of training losses and gradient norms through training.
323 |
324 | Args:
325 | window: desired window for computing moving averages (default: None).
326 | """
327 | plot_logs(self.loss_log, self.grad_norm_log, window=window,
328 | steps_per_check=self.steps_per_check)
329 |
330 | def batched_apply(self, x, batch_size=2_048):
331 | '''Performs forward pass using smaller batches, then concatenates them
332 | together before returning predictions. Useful for avoiding OoM issues
333 | when input is large.
334 |
335 | Args:
336 | x: input to the model
337 | batch_size: maximum batch size for computation.
338 |
339 | Returns:
340 | predictions of the model on input x
341 | '''
342 | num_batches = int(jnp.ceil(len(x) / batch_size))
343 | x_batches = jnp.split(x,
344 | batch_size*(1+jnp.arange(num_batches-1)),
345 | axis=0)
346 | pred_fn = jit(lambda ins : self.apply(self.params,
347 | ins[:,0][:,None],
348 | ins[:,1][:,None]))
349 | y_pred = jnp.concatenate([pred_fn(ins) for ins in x_batches], axis=0)
350 | return y_pred
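# Illustrative use of batched_apply (hypothetical grid, not part of the
# training scripts), assuming `model` is a trained instance of this class:
#
# xs = jnp.linspace(-1., 1., 201)
# X, Y = jnp.meshgrid(xs, xs)
# pts = jnp.column_stack([X.ravel(), Y.ravel()])   # shape (201*201, 2)
# preds = model.batched_apply(pts)                 # shape (201*201, 1)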
351 |
352 | def get_rmse(self, batch, batch_size=2_048):
353 | # Create predictions
354 | u, s_true = batch
355 | if batch_size is None: # single forward pass
356 | s_pred = self.apply(self.params, u)
357 | else: # breaks prediction into smaller forward passes
358 | s_pred = self.batched_apply(u, batch_size=batch_size)
359 | error = s_pred - s_true
360 | rmse = jnp.sqrt(jnp.mean(error**2))
361 | return rmse
362 |
363 | def plot_predictions(self, batch, return_pred=False, batch_size=2_048,
364 | num_levels = 100):
365 | """Computes and plots model predictions for a given batch of data.
366 |
367 | Args:
368 | batch: data for creating/plotting results.
369 | return_pred: whether to return predictions after plotting
370 | (default: False).
371 | batch_size: batch size for computations (to avoid OoM issues in the
372 | case of large datasets). (default: 2048)
373 | num_levels: number of levels for contour plot (default: 100).
374 | """
375 | # Create predictions
376 | u, s_true = batch
377 | if batch_size is None: # single forward pass
378 | s_pred = self.apply(self.params, u)
379 | else: # breaks prediction into smaller forward passes
380 | s_pred = self.batched_apply(u, batch_size=batch_size)
381 |
382 | error = s_pred - s_true
383 | rel_l2_error = jnp.sqrt(jnp.sum(error**2)/jnp.sum(s_true**2))
384 | print('Relative L2 error: {:.2e}'.format(rel_l2_error))
385 | print('RMSE: {:.2e}'.format(jnp.sqrt(jnp.mean(error**2))))
386 |
387 | plt.figure(figsize=(16, 4))
388 |
389 | # Plotting predicted solution
390 | plt.subplot(131)
391 | plt.tricontourf(u[:,0], u[:,1],
392 | s_pred.squeeze(), levels=num_levels)
393 | plt.colorbar()
394 | plt.xlabel('$x$')
395 | plt.ylabel('$y$')
396 | plt.title('Prediction')
397 |
398 | # Plotting true solution
399 | plt.subplot(132)
400 | plt.tricontourf(u[:,0], u[:,1],
401 | s_true.squeeze(), levels=num_levels)
402 | plt.colorbar()
403 | plt.xlabel('$x$')
404 | plt.ylabel('$y$')
405 | plt.title('True')
406 |
407 | # Plotting absolute error
408 | plt.subplot(133)
409 | plt.tricontourf(u[:,0], u[:,1],
410 | abs(s_pred-s_true).squeeze(), levels=num_levels)
411 | plt.colorbar()
412 | plt.xlabel('$x$')
413 | plt.ylabel('$y$')
414 | plt.title('Absolute Error')
415 |
416 | plt.show()
417 |
419 |
420 | if return_pred:
421 | return s_pred
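# The constructor above relies on a helper `get_value_and_laplacian` defined
# or imported elsewhere in this module (not shown here). A minimal sketch of
# such a helper (an assumption for illustration, not the repository's
# implementation): given the jitted apply function, return batched
# (u, u_xx + u_yy) via nested reverse-mode derivatives.
def get_value_and_laplacian_sketch(apply_fn):
    def _single(params, x, y):
        # x, y have shape (1,); returns a scalar
        return apply_fn(params, x[None, :], y[None, :]).squeeze()
    _u_x = lambda p, x, y: grad(_single, argnums=1)(p, x, y).squeeze()
    _u_y = lambda p, x, y: grad(_single, argnums=2)(p, x, y).squeeze()
    def _single_val_lap(params, x, y):
        u = _single(params, x, y)
        lap = grad(_u_x, argnums=1)(params, x, y).squeeze() \
            + grad(_u_y, argnums=2)(params, x, y).squeeze()
        return u, lap
    # vmap over a batch of points with x, y of shape (batch_dim, 1)
    return vmap(_single_val_lap, in_axes=(None, 0, 0))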
--------------------------------------------------------------------------------
/archs.py:
--------------------------------------------------------------------------------
1 | # Basic Library Imports
2 | import jax
3 | import jax.numpy as jnp
4 | from jax import random
5 | from jax import vmap, jit
6 |
7 | from flax import linen as nn
8 |
9 | from typing import Any, Callable, Sequence, Tuple, Union
10 |
11 | # acceptable types for matmul precision in JAX
12 | PrecisionLike = Union[None, str, jax.lax.Precision, Tuple[str, str],
13 | Tuple[jax.lax.Precision, jax.lax.Precision]]
14 | # acceptable type for vector shapes
15 | Shape = Sequence[int]
16 |
17 | # identity function
18 | identity = lambda x : x
19 |
20 |
21 | ######################################################
22 | #################### Initializers ####################
23 | ######################################################
24 |
25 | # Siren Initialization
26 | def siren_initializer(key, shape, dtype=jnp.float32):
27 | """
28 | Returns a random vector of desired shape using Siren's initialization.
29 |
30 | Args:
31 | key: a PRNG key used as the random key.
32 | shape: shape of weights.
33 | dtype: the dtype of the weights.
34 |
35 | Returns:
36 | A random Siren weight array with the specified shape and dtype.
37 | """
38 | aux = jnp.sqrt(6. / shape[0])
39 | return random.uniform(key, shape=shape, minval=-aux, maxval=aux, dtype=dtype)
40 |
41 | def siren_first_layer_initializer(key, shape, dtype):
42 | """
43 | Returns a random vector of desired shape using Siren's initialization for the
44 | first layer.
45 |
46 | Args:
47 | key: a PRNG key used as the random key.
48 | shape: shape of weights.
49 | dtype: the dtype of the weights.
50 |
51 | Returns:
52 | A random Siren weight array (first layer) with the specified shape & dtype.
53 | """
54 | aux = 1/shape[0]
55 | return random.uniform(key, shape, minval=-aux, maxval=aux, dtype=dtype)
56 |
57 | # Custom Initialization
58 | def kan_initializer(key, shape, dtype=jnp.float32, sigma_0=0.1):
59 | """
60 | Returns a random vector of desired shape using KAN's initialization.
61 |
62 | Args:
63 | key: a PRNG key used as the random key.
64 | shape: shape of weights.
65 | dtype: the dtype of the weights.
66 | sigma_0 (float): sigma parameter for initialization as specified in KAN paper.
67 |
68 | Returns:
69 | A random KAN weight array with the specified shape and dtype.
70 | """
71 | aux = sigma_0/jnp.sqrt(shape[0])
72 | return aux*random.normal(key, shape=shape, dtype=dtype)
73 |
74 | def get_kan_initializer(sigma=0.1):
75 | """
76 | Returns a KAN initializer with desired choice of sigma.
77 |
78 | Args:
79 | sigma (float): sigma parameter for initialization as specified in KAN paper.
80 |
81 | Returns:
82 | A KAN initializer function.
83 | """
84 | return lambda key, shape, dtype=jnp.float32 : \
85 | kan_initializer(key, shape, dtype=dtype, sigma_0=sigma)
86 |
87 |
88 | ###############################################################
89 | ######################## Architectures ########################
90 | ###############################################################
91 |
92 | #############
93 | #### MLP ####
94 | #############
95 |
96 | class MLP(nn.Module):
97 | """A Multi-Layer Prerception network.
98 |
99 | Attributes:
100 | features: sequence of int detailing width of each layer.
101 | activation: activation function to be used in between layers (default:
102 | nn.gelu).
103 | output_activation: activation for last layer of network (default: identity).
104 | precision: numerical precision of the computation. See ``jax.lax.Precision``
105 | for details. (default: None)
106 | """
107 | features: Sequence[int]
108 | activation : Callable=nn.gelu
109 | output_activation : Callable=identity
110 | precision: PrecisionLike = None
111 |
112 | @nn.compact
113 | def __call__(self, x):
114 | """Forward pass of a MLP network.
115 |
116 | Args:
117 | x: The nd-array to be transformed.
118 |
119 | Returns:
120 | The transformed input x.
121 | """
122 | for feat in self.features[:-1]:
123 | x = self.activation(nn.Dense(feat, precision=self.precision)(x))
124 | x = nn.Dense(self.features[-1], precision=self.precision)(x)
125 | return self.output_activation(x) # different activation on output layer
126 |
127 |
128 | ###############
129 | #### Siren ####
130 | ###############
131 |
132 | # see https://arxiv.org/abs/2006.09661 for details about Siren, which is an MLP
133 | # with sine activation and a specific initialization pattern. See below for an
134 | # interactive colab notebook provided by the authors:
135 | # https://colab.research.google.com/github/vsitzmann/siren/blob/master/explore_siren.ipynb
136 |
137 | class Siren(nn.Module):
138 | """A Siren network.
139 |
140 | Attributes:
141 | features: sequence of int detailing width of each layer.
142 | w0: frequency content parameter for multiplying initial inputs. See Siren
143 | paper for more details.
144 | output_activation: activation for last layer of network (default: identity).
145 | precision: numerical precision of the computation. See ``jax.lax.Precision``
146 | for details. (default: None)
147 | """
148 | features: Sequence[int]
149 | w0 : float
150 | output_activation : Callable=identity
151 | precision: PrecisionLike = None
152 |
153 | @nn.compact
154 | def __call__(self, x):
155 | """Forward pass of a Siren network.
156 |
157 | Args:
158 | x: The nd-array to be transformed.
159 |
160 | Returns:
161 | The transformed input x.
162 | """
163 | x = x*self.w0
164 | x = jnp.sin(nn.Dense(self.features[0],
165 | kernel_init=siren_first_layer_initializer,
166 | precision=self.precision)(x))
167 | for feat in self.features[1:-1]:
168 | x = jnp.sin(nn.Dense(feat,
169 | kernel_init=siren_initializer,
170 | precision=self.precision)(x))
171 | x = nn.Dense(self.features[-1])(x)
172 | return self.output_activation(x)
173 |
174 |
175 | ################
176 | #### ActNet ####
177 | ################
178 |
179 | # from https://www.wolframalpha.com/input?i=E%5B%28sin%28wx%2Bp%29%29%5D+where+x+is+normally+distributed
180 | def _mean_transf(mu, sigma, w, p):
181 | """ Mean of the R.V. Y=sin(w*X+p) when X is normally distributed with mean mu
182 | and standard deviation sigma.
183 |
184 | Args:
185 | mu: mean of the R.V. X.
186 | sigma: standard deviation of the R.V. X.
187 | w: frequency of the sinusoidal transformation.
188 | p: phase of the sinusoidal transformation.
189 |
190 | Returns:
191 | The mean of the transformed R.V. Y.
192 | """
193 | return jnp.exp(-0.5* (sigma*w)**2) * jnp.sin(p + mu*w)
194 |
195 | # from https://www.wolframalpha.com/input?i=E%5Bsin%28wx%2Bp%29%5E2%5D+where+x+is+normally+distributed
196 | def _var_transf(mu, sigma, w, p):
197 | """ Variance of the R.V. Y=sin(w*X+p) when X is normally distributed with
198 | mean mu and standard deviation sigma.
199 |
200 | Args:
201 | mu: mean of the R.V. X.
202 | sigma: standard deviation of the R.V. X.
203 | w: frequency of the sinusoidal transformation.
204 | p: phase of the sinusoidal transformation.
205 |
206 | Returns:
207 | The variance of the transformed R.V. Y.
208 | """
209 | return 0.5 - 0.5*jnp.exp(-2 * ((sigma*w)**2))*jnp.cos(2*(p+mu*w)) \
210 | -_mean_transf(mu, sigma, w, p)**2
211 |
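# Quick Monte Carlo sanity check of the two closed forms above (illustrative
# only; not used anywhere in this repository):
#
# key = random.PRNGKey(0)
# samples = 0.3 + 0.7 * random.normal(key, (200_000,))     # X ~ N(0.3, 0.7^2)
# y = jnp.sin(2.0 * samples + 0.5)                         # Y = sin(2X + 0.5)
# print(y.mean(), _mean_transf(0.3, 0.7, 2.0, 0.5))        # should roughly agree
# print(y.var(),  _var_transf(0.3, 0.7, 2.0, 0.5))         # should roughly agree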
212 | class ActLayer(nn.Module):
213 | """A ActLayer module.
214 |
215 | For further details on standard choices of initializers, please refer to
216 | Appendix D of the paper: https://arxiv.org/pdf/2410.01990
217 |
218 | Attributes:
219 | out_dim: output dimension of ActLayer.
220 | num_freqs: number of frequencies/basis functions of the ActLayer.
221 | use_bias: whether to add a bias to the output (default: True).
222 | freqs_init: initializer for basis function frequencies.
223 | phases_init: initializer for basis function phases.
224 | beta_init: initializer for beta parameter.
225 | lamb_init: initializer for lambda parameter.
226 | bias_init: initializer for bias parameter.
227 | freeze_basis: whether to freeze gradients passing through basis
228 | functions (default: False).
229 | freq_scaling: whether to scale basis functions to ensure mean 0 and
230 | standard deviation 1 (default: True).
231 | freq_scaling_eps: small epsilon added to the denominator of frequency
232 | scaling for numerical stability (default: 1e-3).
233 | precision: numerical precision of the computation. See
234 | ``jax.lax.Precision`` for details. (default: None)
235 | """
236 | out_dim : int
237 | num_freqs : int
238 | use_bias : bool=True
239 | # parameter initializers
240 | freqs_init : Callable=nn.initializers.normal(stddev=1.) # normal w/ mean 0
241 | phases_init : Callable=nn.initializers.zeros
242 | beta_init : Callable=nn.initializers.variance_scaling(1.,
243 | 'fan_in',
244 | distribution='uniform')
245 | lamb_init : Callable=nn.initializers.variance_scaling(1.,
246 | 'fan_in',
247 | distribution='uniform')
248 | bias_init : Callable=nn.initializers.zeros
249 | # other configurations
250 | freeze_basis : bool=False
251 | freq_scaling : bool=True
252 | freq_scaling_eps : float=1e-3
253 | precision: PrecisionLike = None
254 |
255 | @nn.compact
256 | def __call__(self, x):
257 | """Forward pass of an ActLayer.
258 |
259 | Args:
260 | x: The nd-array to be transformed.
261 |
262 | Returns:
263 | The transformed input x.
264 | """
265 | # x should initially be shape (batch, d)
266 |
267 | # initialize trainable parameters
268 | freqs = self.param('freqs',
269 | self.freqs_init,
270 | (1,1,self.num_freqs)) # shape (1, 1, num_freqs)
271 | phases = self.param('phases',
272 | self.phases_init,
273 | (1,1,self.num_freqs)) # shape (1, 1, num_freqs)
274 | beta = self.param('beta',
275 | self.beta_init,
276 | (self.num_freqs, self.out_dim)) # shape (num_freqs, out_dim)
277 | lamb = self.param('lamb',
278 | self.lamb_init,
279 | (x.shape[-1], self.out_dim)) # shape (d, out_dim)
280 |
281 | if self.freeze_basis:
282 | freqs = jax.lax.stop_gradient(freqs)
283 | phases = jax.lax.stop_gradient(phases)
284 |
285 | # perform basis expansion
286 | x = jnp.expand_dims(x, 2) # shape (batch, d, 1)
287 | x = jnp.sin(freqs*x + phases) # shape (batch_dim, d, num_freqs)
288 | if self.freq_scaling:
289 | x = (x - _mean_transf(0., 1., freqs, phases)) \
290 | / (jnp.sqrt(self.freq_scaling_eps + _var_transf(0., 1.,
291 | freqs, phases)))
292 |
293 |
294 | # efficiently computes eq 6 from https://arxiv.org/pdf/2410.01990 using
295 | # einsum. Depending on hardware and JAX/CUDA version, there may be
296 | # slightly faster alternatives, but we chose this one for the sake of
297 | # simplicity/clarity.
298 | x = jnp.einsum('...ij, jk, ik-> ...k', x, beta, lamb,
299 | precision=self.precision)
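# Written out with explicit indices, the contraction above computes
#     out[b, k] = sum_{i, j} x[b, i, j] * beta[j, k] * lamb[i, k],
# i.e. each output unit k mixes the num_freqs basis responses of every input
# coordinate i with its own beta and lamb coefficients.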
300 |
301 | # optionally add bias
302 | if self.use_bias:
303 | bias = self.param('bias',
304 | self.bias_init,
305 | (self.out_dim,))
306 | x = x + bias # Shape (batch_size, out_dim)
307 |
308 | return x # Shape (batch_size, out_dim)
309 |
310 |
311 | class ActNet(nn.Module):
312 | """A ActNet module.
313 |
314 | Attributes:
315 | embed_dim: embedding dimension for ActLayers.
316 | num_layers: how many intermediate blocks are used.
317 | out_dim: output dimension of ActNet.
318 | num_freqs: number of frequencies/basis functions of the ActLayers.
319 | output_activation: activation for last layer of
320 | network (default: identity).
321 | op_order: order of operations contained in each intermediate block. This
322 | should be a string containing only 'A' (ActLayer), 'S' (Skip
323 | connection) or 'L' (LayerNorm) characters. (default: 'A')
324 | use_act_bias: whether to add a bias to the output of ActLayers
325 | (default: True).
326 | freqs_init: initializer for basis function frequencies of ActLayers.
327 | phases_init: initializer for basis function phases of ActLayers.
328 | beta_init: initializer for beta parameter of ActLayers.
329 | lamb_init: initializer for lambda parameter of ActLayers.
330 | act_bias_init: initializer for bias parameter of ActLayers.
331 | proj_bias_init: initializer for bias parameter of initial projection
332 | Layer.
333 | w0_init: initializer for w0 scale parameter.
334 | w0_fixed: if False, initializes w0 using w0_init. Otherwise uses given
335 | fixed w0 (default: False).
336 | freeze_basis: whether to freeze gradients passing through basis
337 | functions (default: False).
338 | freq_scaling: whether to scale basis functions to ensure mean 0 and
339 | standard deviation 1 (default: True).
340 | freq_scaling_eps: small epsilon added to the denominator of frequency
341 | scaling for numerical stability (default: 1e-3).
342 | precision: numerical precision of the computation. See
343 | ``jax.lax.Precision`` for details. (default: None)
344 | """
345 | embed_dim : int
346 | num_layers : int # number of layers in the network
347 | out_dim : int # dimension of output vector
348 | num_freqs : int # how many frequencies/basis functions to use in ActLayers
349 | output_activation : Callable = identity
350 | op_order : str='A'
351 | # op_order should be a string containing only 'A' (ActLayer), 'S' (Skip
352 | # connection) or 'L' (LayerNorm) characters. This feature was used for
353 | # development/debugging, but is not used in any experiment of the paper.
354 |
355 | # parameter initializers
356 | freqs_init : Callable=nn.initializers.normal(stddev=1.) # normal w/ mean 0
357 | phases_init : Callable=nn.initializers.zeros
358 | beta_init : Callable=nn.initializers.variance_scaling(1., 'fan_in',
359 | distribution='uniform')
360 | lamb_init : Callable=nn.initializers.variance_scaling(1., 'fan_in',
361 | distribution='uniform')
362 | act_bias_init : Callable=nn.initializers.zeros
363 | proj_bias_init : Callable=lambda key, shape, dtype :\
364 | random.uniform(key, shape, dtype,
365 | minval=-jnp.sqrt(3), maxval=jnp.sqrt(3))
366 |
367 | w0_init : Callable=nn.initializers.constant(30.) # following SIREN strategy
368 | w0_fixed : Union[bool, float]=False # if False, initializes w0 as above. Otherwise uses given fixed w0
369 |
370 | # other ActLayer configurations
371 | use_act_bias : bool=True
372 | freeze_basis : bool=False
373 | freq_scaling : bool=True
374 | freq_scaling_eps : float=1e-3
375 | precision: PrecisionLike = None # numerical precision for matrix operations
376 |
377 | @nn.compact
378 | def __call__(self, x):
379 | """Forward pass of an ActNet.
380 |
381 | Args:
382 | x: The nd-array to be transformed.
383 |
384 | Returns:
385 | The transformed input x.
386 | """
387 | # initialize w0 parameter
388 | if self.w0_fixed is False:
389 | # trainable scalar parameter
390 | w0 = self.param('w0',
391 | self.w0_init,
392 | ())
393 | # use softplus to ensure w0 is positive and does not decay to zero
394 | # too fast (used only while debugging).
395 | w0 = nn.softplus(w0)
396 | else: # use user-specified value for w0
397 | w0 = self.w0_fixed
398 | # scale by w0 factor, then project to embedded dimension
399 | x = x*w0
400 | x = nn.Dense(self.embed_dim, bias_init=self.proj_bias_init,
401 | precision=self.precision)(x)
402 |
403 | for _ in range(self.num_layers):
404 | y = x # store initial value as x, do operations on y
405 | for char in self.op_order:
406 | if char == 'A': # ActLayer
407 | y = ActLayer(
408 | out_dim = self.embed_dim,
409 | num_freqs = self.num_freqs,
410 | use_bias = self.use_act_bias,
411 | freqs_init = self.freqs_init,
412 | phases_init = self.phases_init,
413 | beta_init = self.beta_init,
414 | lamb_init = self.lamb_init,
415 | bias_init = self.act_bias_init,
416 | freeze_basis = self.freeze_basis,
417 | freq_scaling = self.freq_scaling,
418 | freq_scaling_eps = self.freq_scaling_eps,
419 | precision=self.precision,
420 | )(y)
421 | elif char == 'S': # Skip connection
422 | y = y + x
423 | elif char == 'L': # LayerNorm
424 | y = nn.LayerNorm()(y)
425 | else:
426 | raise NotImplementedError(f"Could not recognize option '{char}'. Options for op_order should be 'A' (ActLayer), 'S' (Skip connection) or 'L' (LayerNorm).")
427 | x = y # update value of x after all operations are done
428 |
429 | # project to output dimension and potentially use output activation
430 | x = nn.Dense(self.out_dim, kernel_init=nn.initializers.he_uniform(),
431 | precision=self.precision)(x)
432 | x = self.output_activation(x)
433 |
434 | return x
435 |
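# Minimal usage sketch of ActNet as a Flax module (the hyperparameters below
# are illustrative placeholders, not the settings used in the paper):
#
# net = ActNet(embed_dim=64, num_layers=3, out_dim=1, num_freqs=8)
# params = net.init(random.PRNGKey(0), jnp.ones((16, 2)))
# preds = net.apply(params, jnp.ones((16, 2)))    # shape (16, 1)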
436 |
437 | ##############
438 | #### KAN #####
439 | ##############
440 |
441 | # Adapted to JAX from the "EfficientKAN" GitHub repository (PyTorch). Code was
442 | # altered as little as possible, for the sake of consistency/fairness.
443 | # https://github.com/Blealtan/efficient-kan/blob/master/src/efficient_kan/kan.py
444 |
445 | class KANLinear(nn.Module):
446 | in_features : int
447 | out_features : int
448 | grid_size : int=5
449 | spline_order: int=3
450 | scale_noise : float=0.1
451 | scale_base : float=1.0
452 | scale_spline : float=1.0
453 | enable_standalone_scale_spline : bool=True
454 | base_activation : Callable=nn.silu
455 | grid_eps : float=0.02
456 | grid_range : Sequence[Union[float, int]]=(-1,1)
457 | precision: PrecisionLike = None
458 |
459 | def setup(self):
460 | h = (self.grid_range[1] - self.grid_range[0]) / self.grid_size
461 | self.h = h
462 | grid = (
463 | (
464 | jnp.arange(start=-self.spline_order, stop=self.grid_size + self.spline_order + 1) * h
465 | + self.grid_range[0]
466 | )
467 | )
468 | self.grid = grid * jnp.ones((self.in_features, 1))
469 |
470 | self.base_weight = self.param('base_weight', # parameter name
471 | nn.initializers.he_uniform(), # initialization function
472 | (self.out_features, self.in_features)) # shape info
473 | self.spline_weight = self.param('spline_weight', # parameter name
474 | nn.initializers.he_uniform(), # initialization function
475 | (self.out_features, self.in_features, self.grid_size+self.spline_order)) # shape info
476 |
477 | if self.enable_standalone_scale_spline:
478 | self.spline_scaler = self.param('spline_scaler', # parameter name
479 | nn.initializers.he_uniform(), # initialization function
480 | (self.out_features, self.in_features)) # shape info
481 |
482 |
483 | def b_splines(self, x: jax.Array):
484 | """
485 | Compute the B-spline bases for the given input tensor.
486 |
487 | Args:
488 | x: Input tensor of shape (batch_size, in_features).
489 |
490 | Returns:
491 | B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
492 | """
493 | assert len(x.shape) == 2 and x.shape[1] == self.in_features
494 |
495 | # grid is shape (in_features, grid_size + 2 * spline_order + 1)
496 | grid = self.grid
497 | x = jnp.expand_dims(x, -1)
498 | bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:]))
499 | for k in range(1, self.spline_order + 1):
500 | bases = (
501 | (x - grid[:, : -(k + 1)])
502 | / (grid[:, k:-1] - grid[:, : -(k + 1)])
503 | * bases[:, :, :-1]
504 | ) + (
505 | (grid[:, k + 1 :] - x)
506 | / (grid[:, k + 1 :] - grid[:, 1:(-k)])
507 | * bases[:, :, 1:]
508 | )
509 |
510 | assert bases.shape == (
511 | x.shape[0],
512 | self.in_features,
513 | self.grid_size + self.spline_order,
514 | )
515 | return bases
516 |
517 | @property
518 | def scaled_spline_weight(self):
519 | return self.spline_weight * (
520 | jnp.expand_dims(self.spline_scaler, -1)
521 | if self.enable_standalone_scale_spline
522 | else 1.0
523 | )
524 |
525 | def __call__(self, x: jax.Array):
526 | assert x.shape[-1] == self.in_features, f"x.shape[-1]={x.shape[-1]} should be equal to {self.in_features}"
527 | original_shape = x.shape
528 | x = x.reshape(-1, self.in_features)
529 |
530 | base_output = jnp.matmul(self.base_activation(x), self.base_weight.T, precision=self.precision)
531 | spline_output = jnp.matmul(
532 | self.b_splines(x).reshape(x.shape[0], -1),
533 | self.scaled_spline_weight.reshape(self.out_features, -1).T,
534 | precision=self.precision,
535 | )
536 | output = base_output + spline_output
537 |
538 | output = output.reshape(*original_shape[:-1], self.out_features)
539 | return output
540 |
541 |
542 | class KAN(nn.Module):
543 | features : Sequence[int]
544 | output_activation : Callable=identity
545 | grid_size : int=5
546 | spline_order: int=3
547 | scale_noise : float=0.1
548 | scale_base : float=1.0
549 | scale_spline : float=1.0
550 | enable_standalone_scale_spline : bool=True
551 | base_activation : Callable=nn.silu
552 | grid_eps : float=0.02
553 | grid_range : Sequence[Union[float, int]]=(-1,1)
554 | precision: PrecisionLike = None
555 |
556 | def setup(self):
557 | self.layers = [KANLinear(
558 | self.features[i],
559 | self.features[i+1],
560 | grid_size=self.grid_size,
561 | spline_order=self.spline_order,
562 | scale_noise=self.scale_noise,
563 | scale_base=self.scale_base,
564 | scale_spline=self.scale_spline,
565 | enable_standalone_scale_spline=self.enable_standalone_scale_spline,
566 | base_activation=self.base_activation,
567 | grid_eps=self.grid_eps,
568 | grid_range=self.grid_range,
569 | precision=self.precision,
570 | ) for i in range(len(self.features) - 1)]
571 |
572 | def __call__(self, x):
573 | for l in self.layers:
574 | x = l(x)
575 | return self.output_activation(x)
576 |
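# Minimal usage sketch of KAN (illustrative widths; note that, unlike MLP,
# the first entry of `features` is the input dimension):
#
# kan = KAN(features=(2, 16, 16, 1))
# params = kan.init(random.PRNGKey(0), jnp.ones((8, 2)))
# out = kan.apply(params, jnp.ones((8, 2)))    # shape (8, 1)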
577 |
578 |
579 | ############################################################
580 | ################### Architecture Builder ###################
581 | ############################################################
582 |
583 | def arch_from_config(arch_config):
584 | ''' Given an architecture config, returns the corresponding architecture
585 | object.
586 |
587 | Args:
588 | arch_config: config object specifying the architecture type and hyperparameters.
589 |
590 | Returns:
591 | Architecture as a Flax Linen nn.Module.
592 | '''
593 | if arch_config.arch_type == 'ActNet':
594 | arch = ActNet(**arch_config.hyperparams)
595 | return arch
596 | elif arch_config.arch_type == 'MLP':
597 | arch = MLP(**arch_config.hyperparams)
598 | return arch
599 | elif arch_config.arch_type == 'Siren':
600 | arch = Siren(**arch_config.hyperparams)
601 | return arch
602 | elif arch_config.arch_type == 'KAN':
603 | arch = KAN(**arch_config.hyperparams)
604 | return arch
605 | else:
606 | raise NotImplementedError(f"Cannot recognize arch_type {arch_config.arch_type}. So far, only 'ActNet', 'MLP', 'Siren' and 'KAN' are implemented")
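# Illustrative call (any object exposing `.arch_type` and a mapping
# `.hyperparams` works; SimpleNamespace is used here only to avoid assuming a
# specific config library):
#
# from types import SimpleNamespace
# cfg = SimpleNamespace(arch_type='MLP', hyperparams=dict(features=(64, 64, 1)))
# net = arch_from_config(cfg)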
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc.
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 | <one line to give the program's name and a brief idea of what it does.>
635 | Copyright (C) <year>  <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 | along with this program. If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 | <program>  Copyright (C) <year>  <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/philosophy/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------