├── .DS_Store ├── .gitignore ├── README.md ├── examples ├── .DS_Store ├── README.md ├── basic_dgp.py ├── basic_gp.py ├── basic_svgp.py ├── imgs │ ├── 1layer_deepgp.png │ └── 2layer_deepgp.png └── notebooks │ ├── .DS_Store │ ├── inducing_point_locs.gif │ └── inducing_var.py ├── ladax ├── __init__.py ├── distributions │ ├── __init__.py │ └── multivariate_normal.py ├── gaussian_processes │ ├── __init__.py │ ├── gaussian_process_layers.py │ ├── gaussian_processes.py │ └── inducing_variables.py ├── kernels │ ├── __init__.py │ ├── kernel_layers.py │ └── kernels.py ├── likelihoods │ ├── __init__.py │ └── gaussian_log_likelihood.py ├── losses │ ├── __init__.py │ └── gaussian_likelihood.py ├── models │ ├── __init__.py │ └── svgp.py └── utils.py ├── requirements.txt └── setup.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danieljtait/ladax/6046fd30278343a9204c1d8627f9ff53f3ab84cc/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | # mac os stuff 141 | .DS_Store 142 | 143 | # IDE stuff 144 | .idea -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LADAX: Layers of distributions using FLAX/JAX 2 | 3 | ## Introduction 4 | 5 | Small demonstration of using the FLAX package to create layers 6 | of distributions. Current demonstrations focus on using Gaussian 7 | processes. Why? Because once the work is done in creating the 8 | basic `GaussianProcessLayers` etc. we can use the FLAX functional 9 | layers API to 10 | 1. Easily combine simpler GPs to create DeepGPs 11 | 2. Easily slot GPs into other deep learning frameworks. 12 | 13 | 14 | Briefly the design envisions three components 15 | 16 | #### 1. Distributions 17 | A probability distribution, represented as a valid JAX type, 18 | this can be achieved by registering the object as a pytree 19 | node. This process is made convenient using the 20 | `struct.dataclass` decorator from FLAX. 21 | 22 | #### 2. Distribution layers 23 | These are instances of a `flax.nn.Module` objects which 24 | accept some input valid JAX type, and return an output 25 | in the form of a distribution. 26 | 27 | #### 3. Providers 28 | Like the above, only without an input! An example is 29 | the `RBFKernelProvider` which returns a `Kernel`, 30 | a `struct` decorated container of the exponentiated 31 | quadratic kernel function. Because these components 32 | subclass `flax.nn.Module` they are a convenient place 33 | to handle initialisation and storage of parameters. 34 | The motivation for this distinction is that it often 35 | easier to canonicalise the parameters of a distribution 36 | returned by a layer, and outsource subtleties and 37 | variations of these parameterisations in a seperate 38 | module. 39 | 40 | The following code snippet violates the three definitions 41 | above (WIP!), but gives an idea 42 | ```python 43 | class SVGP(nn.Module): 44 | def apply(self, x): 45 | kernel_fn = kernel_provider(x, **kernel_fn_kwargs) 46 | inducing_var = inducing_variable_provider(x, kernel_fn, **inducing_var_kwargs) 47 | vgp = SVGPLayer(x, mean_fn, kernel_fn, inducing_var) 48 | return vgp 49 | ``` 50 | in the above we have the following 51 | * A `GP` is canonicalised by a `mean_fn` and `kernel_fn`, we abstract away the 52 | specification and parameterisation of these objects to another module. 53 | 54 | ## ToDo 55 | 56 | * Remove `likelihoods` and put this functionality into `losses`, and make the 57 | layer loss functions in `losses` import and parameterise the objects in 58 | `distributions`. 59 | * Kernel algebra -- sums, products of kernels etc. 
60 | * Apply kernel providers only to slices of index points
61 | * Examples of deep GPs with multiple GPs per layer, perhaps create an `IndependentGP`
62 |   collection
63 | * More general multioutput GPs
64 | * Stop putting `index_points` through the kernel provider layers, just pass the number
65 |   of features
66 | * More losses -- Poisson etc. for count data
--------------------------------------------------------------------------------
/examples/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/danieljtait/ladax/6046fd30278343a9204c1d8627f9ff53f3ab84cc/examples/.DS_Store
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | ## Gaussian processes training example
2 | 
3 | Demonstrates using distribution-type objects inside NN layers.
4 | 
5 | For a simple example of GP regression
6 | inside a `flax.nn.Model`, run
7 | 
8 | ```shell script
9 | $ python basic_gp.py
10 | ```
11 | 
12 | For an example of fitting a variational Gaussian
13 | process, and to plot the resulting fit, run
14 | 
15 | ```shell script
16 | $ python basic_svgp.py --plot=True
17 | ```
18 | 
19 | Finally, to fit a Deep-GP to a step function with
20 | different numbers of layers, run
21 | 
22 | ```shell script
23 | $ python basic_dgp.py --plot=True --num_layers=2
24 | ```
25 | 
26 | 1-Layer "Deep"-GP | 2-Layer Deep GP
27 | :--------:|:----------------:
28 | ![](./imgs/1layer_deepgp.png) | ![](./imgs/2layer_deepgp.png)
29 | 
30 | ## Inducing variables
31 | 
32 | The sparse variational GP uses an additional set of 'pseudo'-index points,
33 | typically referred to as inducing points, to achieve increased computational
34 | efficiency. Specifying an `SVGP` model therefore requires the usual components
35 | for specifying a GP, along with a layer which provides the inducing variables.
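The snippet below is a condensed sketch of that composition, adapted from `SVGPModel` in
`basic_svgp.py`; the class name `SVGPSketch`, the zero mean function, and the choice of
five inducing points are purely illustrative:

```python
import jax.numpy as jnp
from flax import nn

from ladax import kernels, gaussian_processes
from ladax.gaussian_processes import inducing_variables


class SVGPSketch(nn.Module):
    def apply(self, x, inducing_locations_init):
        # Kernel provider: initialises and stores the RBF kernel hyper-parameters.
        kernel_fn = kernels.RBFKernelProvider(x, name='kernel_fn')
        # Inducing variable provider: the locations z and the variational q(u).
        inducing_var = inducing_variables.InducingPointsProvider(
            x, kernel_fn,
            num_inducing_points=5,
            inducing_locations_init=inducing_locations_init,
            name='inducing_var')
        # The sparse variational GP q(f) = ∫ p(f | u) q(u) du.
        return gaussian_processes.SVGPLayer(
            x, lambda x_: jnp.zeros(x_.shape[:-1]), kernel_fn, inducing_var,
            name='vgp')
```

Training then minimises the negative ELBO, i.e. the expected negative log-likelihood under
the returned `vgp` plus `vgp.prior_kl()`, as done in `basic_svgp.py`.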
36 | 37 | ![](./notebooks/inducing_point_locs.gif) | 38 | :----------:| 39 | Trajectories of the inducing point locations during training | -------------------------------------------------------------------------------- /examples/basic_dgp.py: -------------------------------------------------------------------------------- 1 | from jax.config import config; config.update("jax_enable_x64", True) 2 | 3 | from absl import app 4 | from absl import flags 5 | from absl import logging 6 | 7 | import jax 8 | import jax.numpy as jnp 9 | from flax import nn, optim 10 | from jax import random 11 | 12 | from ladax import kernels 13 | from ladax import distributions 14 | from ladax import gaussian_processes 15 | from ladax.gaussian_processes import inducing_variables 16 | 17 | FLAGS = flags.FLAGS 18 | 19 | flags.DEFINE_float( 20 | 'learning_rate', default=0.01, 21 | help=('The learning rate for the adam optimizer.')) 22 | 23 | flags.DEFINE_float( 24 | 'beta1', default=0.9, 25 | help=('The beta1 parameter of the adam optimizer.')) 26 | 27 | flags.DEFINE_integer( 28 | 'num_epochs', default=1000, 29 | help=('Number of training epochs.')) 30 | 31 | flags.DEFINE_integer( 32 | 'num_samples', default=10, 33 | help=('Number of samples to approximate the ELBO.')) 34 | 35 | flags.DEFINE_bool( 36 | 'plot', default=False, 37 | help=('Plot the results.',)) 38 | 39 | flags.DEFINE_integer( 40 | 'num_inducing_points', default=10, 41 | help=('Number of inducing points epochs.')) 42 | 43 | flags.DEFINE_boolean( 44 | 'whiten', default=True, 45 | help=('Apply the whitening transform to the inducing variable prior.', )) 46 | 47 | flags.DEFINE_integer( 48 | 'num_layers', default=2, 49 | help=('Number of layers for the deep GP model.', )) 50 | 51 | 52 | class LikelihoodProvider(nn.Module): 53 | def apply(self, 54 | x: jnp.ndarray) -> distributions.MultivariateNormalDiag: 55 | """ 56 | Args: 57 | x: nd-array 58 | Returns: 59 | ll: 60 | """ 61 | obs_noise_scale = jax.nn.softplus( 62 | self.param('observation_noise_scale', 63 | (1, ), 64 | lambda key, shape: 1.0e-1*jnp.ones([1]))) 65 | return distributions.MultivariateNormalDiag( 66 | mean=x[..., 0], scale_diag=jnp.ones(x.shape[:-1])*obs_noise_scale) 67 | 68 | 69 | class DeepGPModel(nn.Module): 70 | def apply(self, x, sample_key, **kwargs): 71 | """ 72 | Args: 73 | x: nd-array input index points for the Deep GP model. 74 | sample_key: random number generator for stochastic inference. 75 | **kwargs: additional kwargs passed to layers. 76 | Returns: 77 | loglik: The output observation model. 78 | vgps: Intermediate variational GP outputs for each layer 79 | """ 80 | vgps = {} 81 | 82 | mf = lambda x_: jnp.zeros(x_.shape[:-1]) # initial mean_fun 83 | for layer in range(1, FLAGS.num_layers+1): 84 | kf = kernels.RBFKernelProvider( 85 | x, 86 | name='kernel_fn_{}'.format(layer), 87 | **kwargs.get('kernel_fn_{}_kwargs'.format(layer), {})) 88 | 89 | inducing_var = inducing_variables.InducingPointsProvider( 90 | x, 91 | kf, 92 | name='inducing_var_{}'.format(layer), 93 | num_inducing_points=FLAGS.num_inducing_points, 94 | **kwargs.get('inducing_var_{}_kwargs'.format(layer), {})) 95 | 96 | vgp = gaussian_processes.SVGPLayer( 97 | x, mf, kf, 98 | inducing_var, 99 | name='vgp_{}'.format(layer)) 100 | 101 | x = vgp.marginal().sample(sample_key)[..., jnp.newaxis] 102 | vgps[layer] = vgp 103 | 104 | mf = lambda x_: x_[..., 0] # identity mean_fn for later layers. 
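            # Note on the loop above: `x = vgp.marginal().sample(sample_key)[..., jnp.newaxis]`
            # replaces the index points with a single draw from this layer's variational
            # marginal, so the next layer is built on a sample from the previous one.
            # Switching to the identity mean function means each subsequent layer is,
            # a priori, centred on its input rather than on zero.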
105 | 106 | loglik = LikelihoodProvider(x, name='loglik') 107 | 108 | return loglik, vgps 109 | 110 | 111 | def create_model(key, input_shape): 112 | 113 | def inducing_loc_init(key, shape): 114 | return jnp.linspace(-1.5, 1.5, FLAGS.num_inducing_points)[:, jnp.newaxis] 115 | 116 | kwargs = {} 117 | for i in range(1, FLAGS.num_layers + 1): 118 | kwargs['kernel_fn_{}_kwargs'.format(i)] = { 119 | 'amplitude_init': lambda key, shape: jnp.ones(shape), 120 | 'length_scale_init': lambda key, shape: jnp.ones(shape)} 121 | kwargs['inducing_var_{}_kwargs'.format(i)] = { 122 | 'fixed_locations': False, 123 | 'whiten': FLAGS.whiten, 124 | 'inducing_locations_init': inducing_loc_init} 125 | 126 | model_def = DeepGPModel.partial(**kwargs) 127 | 128 | with nn.stochastic(key): 129 | _, params = model_def.init_by_shape( 130 | key, 131 | [(input_shape, jnp.float64), ], 132 | nn.make_rng(), 133 | **kwargs) 134 | 135 | return nn.Model(model_def, params) 136 | 137 | 138 | def create_optimizer(model, learning_rate, beta1): 139 | optimizer_def = optim.Adam(learning_rate=learning_rate, beta1=beta1) 140 | optimizer = optimizer_def.create(model) 141 | return optimizer 142 | 143 | 144 | @jax.jit 145 | def multi_sample_train_step(optimizer, batch, sample_keys): 146 | 147 | def loss_fn(model): 148 | def single_sample_loss(key): 149 | ell, vgps = model(batch['index_points'], key) 150 | return -ell.log_prob(batch['y']) + jnp.sum([vgp.prior_kl() for _, vgp in vgps.items()]) 151 | return jnp.mean(jax.vmap(single_sample_loss)(sample_keys)) 152 | 153 | grad_fn = jax.value_and_grad(loss_fn, has_aux=False) 154 | loss, grad = grad_fn(optimizer.target) 155 | 156 | optimizer = optimizer.apply_gradient(grad) 157 | metrics = {'loss': loss} 158 | return optimizer, metrics 159 | 160 | 161 | def train_epoch(optimizer, train_ds, epoch, sample_key): 162 | """Train for a single epoch.""" 163 | optimizer, epoch_metrics = multi_sample_train_step(optimizer, train_ds, sample_key) 164 | epoch_metrics_np = jax.device_get(epoch_metrics) 165 | 166 | logging.info('train epoch: %d, loss: %.4f', 167 | epoch, 168 | epoch_metrics_np['loss']) 169 | 170 | return optimizer, epoch_metrics_np 171 | 172 | 173 | def train(train_ds): 174 | rng = random.PRNGKey(0) 175 | 176 | with nn.stochastic(rng): 177 | model = create_model(rng, train_ds['index_points'].shape) 178 | optimizer = create_optimizer(model, FLAGS.learning_rate, FLAGS.beta1) 179 | 180 | key = nn.make_rng() 181 | 182 | for epoch in range(1, FLAGS.num_epochs + 1): 183 | key = random.split(key, FLAGS.num_samples + 1) 184 | key, sample_key = (key[0], key[1:]) 185 | optimizer, metrics = train_epoch( 186 | optimizer, train_ds, epoch, sample_key) 187 | 188 | return optimizer 189 | 190 | 191 | def step_fn(x): 192 | if x <= 0.: 193 | return -1. 194 | else: 195 | return 1. 
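# The training loop above minimises a Monte-Carlo estimate of the negative ELBO
# of the deep GP (see `multi_sample_train_step`):
#
#   loss ≈ (1/S) * Σ_s [ -log N(y | f^{(s)}, σ² I) ] + Σ_l KL(q(u_l) || p(u_l)),
#
# where each f^{(s)} is obtained by propagating a sample through the layers
# (one per key in `sample_keys`) and the KL terms are the per-layer
# `vgp.prior_kl()` contributions.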
196 | 197 | 198 | def get_datasets(): 199 | rng = random.PRNGKey(123) 200 | index_points = jnp.linspace(-1.5, 1.5, 25) 201 | y = (jnp.array([step_fn(x) for x in index_points]) 202 | + 0.1*random.normal(rng, index_points.shape)) 203 | train_ds = {'index_points': index_points[..., jnp.newaxis], 'y': y} 204 | return train_ds 205 | 206 | 207 | def main(_): 208 | 209 | train_ds = get_datasets() 210 | optimizer = train(train_ds) 211 | 212 | if FLAGS.plot: 213 | import matplotlib.pyplot as plt 214 | 215 | xx_pred = jnp.linspace(-1.5, 1.5)[:, jnp.newaxis] 216 | 217 | num_samples = 50 218 | subkeys = random.split(random.PRNGKey(0), num=num_samples) 219 | 220 | def sample(skey): 221 | ll, vgps = optimizer.target(xx_pred, skey) 222 | return ll.mean 223 | 224 | samples = jax.vmap(sample)(subkeys) 225 | pred_m = jnp.mean(samples, axis=0) 226 | pred_sd = jnp.std(samples, axis=0) 227 | 228 | fig, ax = plt.subplots() 229 | 230 | ax.plot(xx_pred[:, 0], pred_m, 'C0-', 231 | label=r'$\mathbb{E}_{f \sim q(f)}[f(x)]$') 232 | ax.fill_between(xx_pred[:, 0], 233 | pred_m - 2 * pred_sd, 234 | pred_m + 2 * pred_sd, alpha=0.5, label=r'$\pm 2$ std. dev.') 235 | 236 | ax.step(xx_pred[:, 0], [step_fn(x) for x in xx_pred], 'k--', alpha=0.7, label='True function') 237 | ax.plot(train_ds['index_points'][:, 0], train_ds['y'], 238 | 'ks', label='observations') 239 | ax.legend() 240 | plt.show() 241 | 242 | 243 | if __name__ == '__main__': 244 | app.run(main) -------------------------------------------------------------------------------- /examples/basic_gp.py: -------------------------------------------------------------------------------- 1 | from jax.config import config; config.update("jax_enable_x64", True) 2 | 3 | from absl import app 4 | from absl import flags 5 | from absl import logging 6 | 7 | import jax 8 | from jax import random 9 | from jax.tree_util import tree_flatten, tree_unflatten 10 | import jax.numpy as jnp 11 | 12 | from typing import Callable, Tuple 13 | from flax import nn 14 | 15 | import scipy as oscipy 16 | 17 | from ladax import kernels, distributions, gaussian_processes, utils 18 | 19 | FLAGS = flags.FLAGS 20 | 21 | flags.DEFINE_bool( 22 | 'plot', default=False, 23 | help=('Plot the results.', )) 24 | 25 | 26 | class MarginalObservationModel(nn.Module): 27 | """ The observation model p(y|x, {hyper par}) = ∫p(y,f|x)df where f(x) ~ GP(m(x), k(x, x')). """ 28 | def apply(self, pf: distributions.MultivariateNormalTriL) -> distributions.MultivariateNormalFull: 29 | """ Applys the marginal observation model of the conditional 30 | Args: 31 | pf: distribution of the latent GP to be marginalised over, 32 | a `distribution.MultivariateNormal` object. 33 | Returns: 34 | py: the marginalised distribution of the observations, a 35 | `distributions.MultivariateNormal` object. 36 | """ 37 | obs_noise_scale = jax.nn.softplus( 38 | self.param('observation_noise_scale', 39 | (), jax.nn.initializers.ones)) 40 | 41 | covariance = pf.scale @ pf.scale.T 42 | covariance = utils.diag_shift(covariance, obs_noise_scale**2) 43 | 44 | return distributions.MultivariateNormalFull( 45 | pf.mean, covariance) 46 | 47 | 48 | class GaussianProcessLayer(nn.Module): 49 | """ Provides a Gaussian process. 50 | """ 51 | def apply(self, 52 | index_points: jnp.ndarray, 53 | kernel_fn: Callable, 54 | mean_fn: Callable = None, 55 | jitter: float =1e-4): 56 | """ 57 | Args: 58 | index_points: the nd-array of index points of the GP model 59 | kernel_fn: callable kernel function. 60 | mean_fn: callable mean function of the GP model. 
61 | (default: `None` is equivalent to lambda x: jnp.zeros(x.shape[:-1])) 62 | jitter: float `jitter` term to add to the diagonal of the covariance 63 | function before computing downstream Cholesky decompositions. 64 | Returns: 65 | p: `distributions.MultivariateNormalTriL` object. 66 | """ 67 | if mean_fn is None: 68 | mean_fn = lambda x: jnp.zeros(x.shape[:-1], dtype=index_points.dtype) 69 | 70 | return gaussian_processes.GaussianProcess(index_points, mean_fn, kernel_fn, jitter) 71 | 72 | 73 | class GPModel(nn.Module): 74 | """ Model for i.i.d noise observations from a GP with 75 | RBF kernel. """ 76 | 77 | def apply(self, x, dtype=jnp.float64) -> distributions.MultivariateNormalFull: 78 | """ 79 | Args: 80 | x: the nd-array of index points of the GP model. 81 | dtype: the data-type of the computation (default: float64) 82 | Returns: 83 | py_x: Distribution of the observations at the index points. 84 | """ 85 | kern_fn = kernels.RBFKernelProvider(x, name='kernel_fn') 86 | mean_fn = lambda x: nn.Dense(x, features=1, name='linear_mean_fn')[..., 0] 87 | gp_x = GaussianProcessLayer(x, kern_fn, mean_fn, name='gp_layer') 88 | py_x = MarginalObservationModel(gp_x.marginal(), name='observation_model') 89 | return py_x 90 | 91 | 92 | def build_par_pack_and_unpack(model): 93 | """ Build utility functions to pack and unpack paramater pytrees 94 | for the scipy optimizers. """ 95 | value_flat, value_tree = tree_flatten(model.params) 96 | section_shapes = [item.shape for item in value_flat] 97 | section_sizes = jnp.cumsum(jnp.array([item.size for item in value_flat])) 98 | 99 | def par_from_array(arr): 100 | value_flat = jnp.split(arr, section_sizes) 101 | value_flat = [x.reshape(s) 102 | for x, s in zip(value_flat, section_shapes)] 103 | 104 | params = tree_unflatten(value_tree, value_flat) 105 | return params 106 | 107 | def array_from_par(params): 108 | value_flat, value_tree = tree_flatten(params) 109 | return jnp.concatenate([item.ravel() for item in value_flat]) 110 | 111 | return par_from_array, array_from_par 112 | 113 | 114 | def get_datasets(sim_key: random.PRNGKey, true_obs_noise_scale: float = 0.5) -> Tuple[dict, dict]: 115 | """ Generate the datasets. """ 116 | index_points = jnp.linspace(-3., 3., 25)[..., jnp.newaxis] 117 | y = (-0.5 + .33 * index_points[:, 0] + 118 | + jnp.sin(index_points[:, 0]) 119 | + true_obs_noise_scale * random.normal(sim_key, index_points.shape[:-1])) 120 | 121 | test_index_points = jnp.linspace(-3., 3., 100)[:, jnp.newaxis] 122 | 123 | train_ds = {'index_points': index_points, 'y': y} 124 | test_ds = {'index_points': test_index_points, 125 | 'y': -0.5 + .33 * test_index_points[:, 0] + jnp.sin(test_index_points[:, 0])} 126 | return train_ds, test_ds 127 | 128 | 129 | def train(train_ds): 130 | """ Complete training of the GP-Model. 131 | Args: 132 | train_ds: Python `dict` with entries `index_points` and `y`. 133 | Returns: 134 | trained_model: A `GPModel` instance with trained hyper-parameters. 135 | """ 136 | rng = random.PRNGKey(0) 137 | 138 | # initialise the model 139 | py, params = GPModel.init(rng, train_ds['index_points']) 140 | model = nn.Model(GPModel, params) 141 | 142 | # utility functions for packing and unpacking param dicts 143 | par_from_array, array_from_par = build_par_pack_and_unpack(model) 144 | 145 | @jax.jit 146 | def loss_fun(model: GPModel, params: dict) -> float: 147 | """ This is clumsier than the usual FLAX loss_fn. 
""" 148 | py = model.module.call(params, train_ds['index_points']) 149 | return -py.log_prob(train_ds['y']) 150 | 151 | # wrap loss fun for scipy.optimize 152 | def wrapped_loss_fun(arr): 153 | params = par_from_array(arr) 154 | return loss_fun(model, params) 155 | 156 | @jax.jit 157 | def loss_and_grads(x): 158 | return jax.value_and_grad(wrapped_loss_fun)(x) 159 | 160 | res = oscipy.optimize.minimize( 161 | loss_and_grads, 162 | x0=array_from_par(params), 163 | jac=True, 164 | method='BFGS') 165 | 166 | logging.info('Optimisation message: {}'.format(res.message)) 167 | 168 | trained_model = model.replace(params=par_from_array(res.x)) 169 | return trained_model 170 | 171 | 172 | def main(_): 173 | train_ds, test_ds = get_datasets(random.PRNGKey(123)) 174 | trained_model = train(train_ds) 175 | 176 | if FLAGS.plot: 177 | import matplotlib.pyplot as plt 178 | 179 | obs_noise_scale = jax.nn.softplus( 180 | trained_model.params['observation_model']['observation_noise_scale']) 181 | 182 | def learned_kernel_fn(x1, x2): 183 | return kernels.RBFKernelProvider.call( 184 | trained_model.params['kernel_fn'], x1)(x1, x2) 185 | 186 | def learned_mean_fn(x): 187 | return nn.Dense.call(trained_model.params['linear_mean_fn'], x, features=1)[:, 0] 188 | 189 | # prior GP model at learned model parameters 190 | fitted_gp = gaussian_processes.GaussianProcess( 191 | train_ds['index_points'], 192 | learned_mean_fn, 193 | learned_kernel_fn, 1e-4 194 | ) 195 | posterior_gp = fitted_gp.posterior_gp( 196 | train_ds['y'], 197 | test_ds['index_points'], 198 | obs_noise_scale**2) 199 | 200 | pred_f_mean = posterior_gp.mean_function(test_ds['index_points']) 201 | pred_f_var = jnp.diag( 202 | posterior_gp.kernel_function(test_ds['index_points'], test_ds['index_points'])) 203 | 204 | fig, ax = plt.subplots() 205 | ax.fill_between(test_ds['index_points'][:, 0], 206 | pred_f_mean - 2*jnp.sqrt(pred_f_var), 207 | pred_f_mean + 2*jnp.sqrt(pred_f_var), alpha=0.5) 208 | 209 | ax.plot(test_ds['index_points'][:, 0], posterior_gp.mean_function(test_ds['index_points']), '-') 210 | ax.plot(train_ds['index_points'], train_ds['y'], 'ks') 211 | 212 | plt.show() 213 | 214 | 215 | if __name__ == '__main__': 216 | app.run(main) 217 | -------------------------------------------------------------------------------- /examples/basic_svgp.py: -------------------------------------------------------------------------------- 1 | from jax.config import config; 2 | 3 | config.update("jax_enable_x64", True) 4 | 5 | from absl import app 6 | from absl import flags 7 | from absl import logging 8 | 9 | import jax 10 | import jax.numpy as jnp 11 | from flax import nn, optim 12 | from jax import random 13 | 14 | from ladax import kernels 15 | from ladax.gaussian_processes import inducing_variables 16 | from ladax import gaussian_processes 17 | from ladax import likelihoods 18 | 19 | 20 | FLAGS = flags.FLAGS 21 | 22 | flags.DEFINE_float( 23 | 'learning_rate', default=0.001, 24 | help=('The learning rate for the momentum optimizer.')) 25 | 26 | flags.DEFINE_float( 27 | 'momentum', default=0.9, 28 | help=('The decay rate used for the momentum optimizer.')) 29 | 30 | flags.DEFINE_integer( 31 | 'num_epochs', default=1000, 32 | help=('Number of training epochs.')) 33 | 34 | flags.DEFINE_bool( 35 | 'plot', default=False, 36 | help=('Plot the results.',)) 37 | 38 | 39 | class LikelihoodProvider(nn.Module): 40 | def apply(self, 41 | vgp: gaussian_processes.VariationalGaussianProcess) -> likelihoods.GaussianLogLik: 42 | """ 43 | Args: 44 | vgp: variational 
Gaussian process regression model q(f). 45 | Returns: 46 | ll: log-likelihood model with method `variational_expectations` to 47 | compute ∫ log p(y|f) q(f) df 48 | """ 49 | obs_noise_scale = jax.nn.softplus( 50 | self.param('observation_noise_scale', 51 | (), 52 | jax.nn.initializers.ones)) 53 | variational_distribution = vgp.marginal() 54 | return likelihoods.GaussianLogLik( 55 | variational_distribution.mean, 56 | variational_distribution.scale, obs_noise_scale) 57 | 58 | 59 | class SVGPModel(nn.Module): 60 | def apply(self, x, inducing_locations_init, **kwargs): 61 | """ 62 | Args: 63 | x: the nd-array of index points of the GP model 64 | inducing_locations_init: initializer function for the inducing 65 | variable locations. 66 | Returns: 67 | ell: variational likelihood object. 68 | vgp: the variational GP q(f) = ∫p(f|u)q(u)du where 69 | `q(u) == inducing_var.variational_distribution`. 70 | """ 71 | kern_fun = kernels.RBFKernelProvider( 72 | x, name='kernel_fun', **kwargs.get('kernel_fun_kwargs', {})) 73 | inducing_var = inducing_variables.InducingPointsProvider( 74 | x, 75 | kern_fun, 76 | num_inducing_points=5, 77 | inducing_locations_init=inducing_locations_init, 78 | name='inducing_var') 79 | 80 | vgp = gaussian_processes.SVGPLayer(x, 81 | lambda x_: jnp.zeros(x_.shape[:-1]), 82 | kern_fun, 83 | inducing_var, 84 | name='vgp') 85 | 86 | ell = LikelihoodProvider(vgp, name='ell') 87 | 88 | return ell, vgp 89 | 90 | 91 | def create_model(key, input_shape): 92 | def inducing_loc_init(key, shape): 93 | return random.uniform(key, shape, minval=-3., maxval=3.) 94 | 95 | # pass initializers as kwargs 96 | kernel_fun_kwargs = { 97 | 'amplitude_init': lambda key, shape: jnp.ones(shape), 98 | 'length_scale_init': lambda key, shape: .5 * jnp.ones(shape)} 99 | kwargs = {'kernel_fun_kwargs': kernel_fun_kwargs} 100 | 101 | _, params = SVGPModel.init_by_shape( 102 | key, 103 | [(input_shape, jnp.float64), ], 104 | inducing_locations_init=inducing_loc_init, 105 | **kwargs) 106 | 107 | return nn.Model(SVGPModel, params) 108 | 109 | 110 | def create_optimizer(model, learning_rate, beta): 111 | optimizer_def = optim.Momentum(learning_rate=learning_rate, beta=beta) 112 | optimizer = optimizer_def.create(model) 113 | return optimizer 114 | 115 | 116 | @jax.jit 117 | def train_step(optimizer, batch): 118 | """Train for a single step.""" 119 | def inducing_loc_init(key, shape): 120 | return random.uniform(key, shape, minval=-3., maxval=3.) 
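    # `loss_fn` below returns the negative ELBO of the sparse variational GP:
    #   -∫ log p(y | f) q(f) df + KL(q(u) || p(u)),
    # with the expected log-likelihood computed in closed form by
    # `GaussianLogLik.variational_expectation` and the KL term by `vgp.prior_kl()`.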
121 | 122 | def loss_fn(model): 123 | ell, vgp = model(batch['index_points'], inducing_loc_init) 124 | return (-ell.variational_expectation(batch['y']) 125 | + vgp.prior_kl()) 126 | 127 | grad_fn = jax.value_and_grad(loss_fn, has_aux=False) 128 | loss, grad = grad_fn(optimizer.target) 129 | optimizer = optimizer.apply_gradient(grad) 130 | metrics = {'loss': loss} 131 | return optimizer, metrics 132 | 133 | 134 | def train_epoch(optimizer, train_ds, epoch): 135 | """Train for a single epoch.""" 136 | optimizer, epoch_metrics = train_step(optimizer, train_ds) 137 | epoch_metrics_np = jax.device_get(epoch_metrics) 138 | 139 | logging.info('train epoch: %d, loss: %.4f', 140 | epoch, 141 | epoch_metrics_np['loss']) 142 | 143 | return optimizer, epoch_metrics_np 144 | 145 | 146 | def train(train_ds): 147 | rng = random.PRNGKey(0) 148 | 149 | num_epochs = FLAGS.num_epochs 150 | 151 | model = create_model(rng, (15, 1)) 152 | optimizer = create_optimizer(model, FLAGS.learning_rate, FLAGS.momentum) 153 | 154 | for epoch in range(1, num_epochs + 1): 155 | optimizer, metrics = train_epoch( 156 | optimizer, train_ds, epoch) 157 | 158 | return optimizer 159 | 160 | 161 | def main(_): 162 | jnp.set_printoptions(precision=3, suppress=True) 163 | 164 | shape = (15, 1) 165 | index_points = jnp.linspace(-3., 3., shape[0])[:, None] 166 | 167 | rng = random.PRNGKey(123) 168 | 169 | y = (jnp.sin(index_points)[:, 0] 170 | + 0.33 * random.normal(rng, (15,))) 171 | 172 | train_ds = {'index_points': index_points, 'y': y} 173 | 174 | optimizer = train(train_ds) 175 | 176 | if FLAGS.plot: 177 | import matplotlib.pyplot as plt 178 | model = optimizer.target 179 | 180 | def inducing_loc_init(key, shape): 181 | return random.uniform(key, shape, minval=-3., maxval=3.) 182 | 183 | xx_pred = jnp.linspace(-3., 5.)[:, None] 184 | 185 | _, vgp = model(xx_pred, inducing_loc_init) 186 | 187 | pred_m = vgp.mean_function(xx_pred) 188 | pred_v = jnp.diag(vgp.kernel_function(xx_pred, xx_pred)) 189 | 190 | fig, ax = plt.subplots(figsize=(6, 4)) 191 | 192 | ax.fill_between( 193 | xx_pred[:, 0], 194 | pred_m - 2 * jnp.sqrt(pred_v), 195 | pred_m + 2 * jnp.sqrt(pred_v), alpha=0.5) 196 | ax.plot(xx_pred[:, 0], pred_m, '-', 197 | label=r'$\mathbb{E}_{f \sim q(f)}[f(x)]$') 198 | ax.plot(model.params['inducing_var']['locations'][:, 0], 199 | model.params['inducing_var']['mean'], '+', 200 | label=r'$E_{u \sim q(u)}[u]$') 201 | ax.plot(train_ds['index_points'][:, 0], train_ds['y'], 'ks', label='observations') 202 | ax.legend() 203 | plt.show() 204 | 205 | 206 | if __name__ == '__main__': 207 | app.run(main) -------------------------------------------------------------------------------- /examples/imgs/1layer_deepgp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danieljtait/ladax/6046fd30278343a9204c1d8627f9ff53f3ab84cc/examples/imgs/1layer_deepgp.png -------------------------------------------------------------------------------- /examples/imgs/2layer_deepgp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danieljtait/ladax/6046fd30278343a9204c1d8627f9ff53f3ab84cc/examples/imgs/2layer_deepgp.png -------------------------------------------------------------------------------- /examples/notebooks/.DS_Store: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/danieljtait/ladax/6046fd30278343a9204c1d8627f9ff53f3ab84cc/examples/notebooks/.DS_Store -------------------------------------------------------------------------------- /examples/notebooks/inducing_point_locs.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danieljtait/ladax/6046fd30278343a9204c1d8627f9ff53f3ab84cc/examples/notebooks/inducing_point_locs.gif -------------------------------------------------------------------------------- /examples/notebooks/inducing_var.py: -------------------------------------------------------------------------------- 1 | from jax.config import config; 2 | 3 | config.update("jax_enable_x64", True) 4 | 5 | from absl import app 6 | from absl import flags 7 | from absl import logging 8 | 9 | import jax 10 | import jax.numpy as jnp 11 | from flax import nn, optim 12 | from jax import random 13 | 14 | import ladax 15 | from ladax import kernels, losses 16 | from ladax.gaussian_processes import inducing_variables 17 | 18 | 19 | import matplotlib.pyplot as plt 20 | import numpy as onp 21 | from collections import namedtuple 22 | 23 | FLAGS = flags.FLAGS 24 | 25 | flags.DEFINE_integer( 26 | 'num_inducing_points', default=7, 27 | help=('Number of inducing points.', )) 28 | 29 | flags.DEFINE_float( 30 | 'learning_rate', default=0.0001, 31 | help=('The learning rate for the momentum optimizer.')) 32 | 33 | flags.DEFINE_float( 34 | 'momentum', default=0.9, 35 | help=('The decay rate used for the momentum optimizer.')) 36 | 37 | flags.DEFINE_integer( 38 | 'num_epochs', default=100, 39 | help=('Number of training epochs.')) 40 | 41 | flags.DEFINE_bool( 42 | 'plot', default=False, 43 | help=('Plot the results.',)) 44 | 45 | flags.DEFINE_integer( 46 | 'num_training_points', default=50, 47 | help=('The number of training points.', )) 48 | 49 | 50 | LossAndModel = namedtuple('LossAndModel', 'loss model') 51 | 52 | 53 | def create_model(key): 54 | 55 | inducing_var_kwargs = { 56 | 'num_inducing_points': FLAGS.num_inducing_points, 57 | 'inducing_locations_init': jax.nn.initializers.normal(stddev=1.), 58 | 'fixed_locations': False, 59 | 'whiten': False} 60 | 61 | svgp_layer_kwargs = {'jitter': 1.0e-4} 62 | 63 | clz = ladax.models.svgp_factory(kernels.RBFKernelProvider, 64 | inducing_variables.InducingPointsProvider, 65 | inducing_variable_kwargs=inducing_var_kwargs, 66 | svgp_layer_kwargs=svgp_layer_kwargs) 67 | 68 | vgp, params = clz.init_by_shape( 69 | key, [([FLAGS.num_training_points, 2], jnp.float32)], ) 70 | 71 | return nn.Model(clz, params) 72 | 73 | 74 | def create_loss(rng, model, train_ds): 75 | 76 | loss_clz = losses.VariationalGaussianLikelihoodLoss 77 | 78 | dist = model(train_ds['index_points']) 79 | _, params = loss_clz.init(rng, train_ds['y'], dist) 80 | return nn.Model(loss_clz, params) 81 | 82 | 83 | def create_optimizer(loss_and_model, learning_rate, beta): 84 | optimizer_def = optim.Momentum(learning_rate=learning_rate, beta=beta) 85 | optimizer = optimizer_def.create(loss_and_model) 86 | return optimizer 87 | 88 | 89 | def true_function(x, y): 90 | return jnp.cos(x) * y 91 | 92 | 93 | def get_datasets(): 94 | onp.random.seed(123) 95 | index_points = onp.random.normal(size=100).reshape((50, 2)) 96 | y = true_function(*index_points.T) 97 | y += .1*onp.random.randn(*y.shape) 98 | return {'index_points': index_points, 'y': y} 99 | 100 | 101 | @jax.jit 102 | def train_step(optimizer, batch): 103 | """Train for a single step.""" 104 | 105 | def 
loss_fn(loss_and_model): 106 | vgp = loss_and_model.model(batch['index_points']) 107 | negell = loss_and_model.loss(batch['y'], vgp) 108 | loss = vgp.prior_kl() + negell 109 | return loss, vgp.inducing_variable.locations 110 | 111 | grad_fn = jax.value_and_grad(loss_fn, has_aux=True) 112 | (loss, z), grad = grad_fn(optimizer.target) 113 | optimizer = optimizer.apply_gradient(grad) 114 | metrics = {'loss': loss, 'z': z} 115 | return optimizer, metrics 116 | 117 | 118 | def train_epoch(optimizer, train_ds, epoch): 119 | """Train for a single epoch.""" 120 | optimizer, epoch_metrics = train_step(optimizer, train_ds) 121 | epoch_metrics_np = jax.device_get(epoch_metrics) 122 | 123 | logging.info('train epoch: %d, loss: %.4f', epoch, epoch_metrics_np['loss']) 124 | 125 | return optimizer, epoch_metrics_np 126 | 127 | 128 | def train(train_ds): 129 | rng = random.PRNGKey(0) 130 | num_epochs = FLAGS.num_epochs 131 | 132 | model = create_model(rng) 133 | loss = create_loss(rng, model, train_ds) 134 | 135 | # we are going to collect the locations of the inducing points 136 | z = [model.params['1']['locations'], ] 137 | 138 | loss_and_model = LossAndModel(loss, model) 139 | optimizer = create_optimizer( 140 | loss_and_model, FLAGS.learning_rate, FLAGS.momentum) 141 | 142 | for epoch in range(1, num_epochs + 1): 143 | optimizer, metrics = train_epoch( 144 | optimizer, train_ds, epoch) 145 | z.append(metrics['z']) 146 | 147 | return optimizer.target, z 148 | 149 | 150 | def main(_): 151 | train_ds = get_datasets() 152 | trained_model_and_loss, z = train(train_ds) 153 | 154 | trained_model = trained_model_and_loss.model 155 | trained_loss = trained_model_and_loss.loss 156 | obs_noise_scale = jax.nn.softplus(trained_loss.params['observation_noise_scale']) 157 | print(obs_noise_scale) 158 | 159 | if FLAGS.plot: 160 | from matplotlib import animation 161 | 162 | vgp = trained_model(train_ds['index_points']) 163 | post_gp = vgp.posterior_gp( 164 | train_ds['y'], 165 | train_ds['index_points'], 166 | obs_noise_scale**2, 167 | jitter=1e-4) 168 | 169 | xmin, xmax = (-3., 3.) 170 | ymin, ymax = (-3., 3.) 171 | 172 | xx = onp.linspace(xmin, xmax, 50) 173 | x, y = onp.meshgrid(xx, xx) 174 | X = onp.column_stack((x.ravel(), y.ravel())) 175 | m = post_gp.mean_function(X) 176 | 177 | true_y = true_function(*X.T) 178 | 179 | vmin = min(m.min(), true_y.min()) 180 | vmax = max(m.max(), true_y.max()) 181 | fig, axes = plt.subplots(ncols=2) 182 | axes[0].contourf(x, y, m.reshape(x.shape), 183 | alpha=0.5, vmin=vmin, vmax=vmax) 184 | axes[0].plot(*z[0].T, 'C1o') 185 | axes[0].plot(*z[-1].T, 'C1^') 186 | 187 | z = jnp.array(z) 188 | 189 | lines = axes[0].plot(z[..., 0], z[..., 1], '-') 190 | 191 | axes[0].plot(*train_ds['index_points'].T, 'k+') 192 | axes[1].contourf(x, y, true_y.reshape(x.shape), 193 | alpha=0.5, vmin=vmin, vmax=vmax) 194 | 195 | for ax in axes: 196 | ax.set_xlim((xmin, xmax)) 197 | ax.set_ylim((ymin, ymax)) 198 | 199 | # animated plot 200 | fig, ax = plt.subplots(figsize=(4, 4)) 201 | 202 | zcol = 'k' 203 | 204 | ax.plot(z[..., 0], z[..., 1], '-', color=zcol) 205 | 206 | lines = [] 207 | for _ in range(FLAGS.num_inducing_points): 208 | line, = ax.plot([], [], 'D', color=zcol) 209 | lines.append(line) 210 | 211 | ax.contour(x, y, m.reshape(x.shape), 'k-', levels=10) 212 | ax.plot(*train_ds['index_points'].T, 'k+', label='obs. 
index points') 213 | 214 | # initialization function: plot the background of each frame 215 | def init(): 216 | for line in lines: 217 | line.set_data([], []) 218 | return lines 219 | 220 | nskip = 10 221 | def animate(i): 222 | for k in range(FLAGS.num_inducing_points): 223 | x = [z[i*nskip, k, 0], ] 224 | y = [z[i*nskip, k, 1], ] 225 | lines[k].set_data(x, y) 226 | return lines 227 | 228 | nframes = (FLAGS.num_epochs + 1) // nskip 229 | 230 | ax.plot([], [], 'D', color=zcol, label='inducing point locs.') 231 | ax.legend() 232 | 233 | anim = animation.FuncAnimation(fig, animate, init_func=init, 234 | frames=nframes, 235 | interval=20, blit=True, 236 | repeat=True) 237 | anim.save('inducing_point_locs.gif', writer='imagemagick', fps=30) 238 | 239 | plt.show() 240 | 241 | 242 | if __name__ == '__main__': 243 | app.run(main) 244 | -------------------------------------------------------------------------------- /ladax/__init__.py: -------------------------------------------------------------------------------- 1 | from . import distributions 2 | from . import kernels 3 | from . import gaussian_processes 4 | from . import likelihoods 5 | from . import utils 6 | from . import models 7 | from . import losses 8 | -------------------------------------------------------------------------------- /ladax/distributions/__init__.py: -------------------------------------------------------------------------------- 1 | from .multivariate_normal import (MultivariateNormalDiag, 2 | MultivariateNormalTriL, 3 | MultivariateNormalFull) 4 | -------------------------------------------------------------------------------- /ladax/distributions/multivariate_normal.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import jax.scipy as jscipy 3 | import abc 4 | from jax import random 5 | from flax import struct 6 | 7 | 8 | @struct.dataclass 9 | class MultivariateNormal: 10 | 11 | @abc.abstractmethod 12 | def log_prob(self, x): 13 | pass 14 | 15 | 16 | @struct.dataclass 17 | class MultivariateNormalDiag(MultivariateNormal): 18 | mean: jnp.ndarray 19 | scale_diag: jnp.ndarray 20 | 21 | def log_prob(self, x): 22 | return jnp.sum( 23 | jscipy.stats.norm.logpdf( 24 | x, loc=self.mean, scale=self.scale_diag)) 25 | 26 | def sample(self, key, shape=()): 27 | return random.normal(key, shape=shape) * self.scale_diag + self.mean 28 | 29 | 30 | @struct.dataclass 31 | class MultivariateNormalTriL(MultivariateNormal): 32 | mean: jnp.ndarray 33 | scale: jnp.ndarray 34 | 35 | def log_prob(self, x): 36 | dim = x.shape[-1] 37 | dev = x - self.mean 38 | maha = jnp.sum(dev * 39 | jscipy.linalg.cho_solve((self.scale, True), dev)) 40 | log_2_pi = jnp.log(2 * jnp.pi) 41 | log_det_cov = 2 * jnp.sum(jnp.log(jnp.diag(self.scale))) 42 | return -0.5 * (dim * log_2_pi + log_det_cov + maha) 43 | 44 | def sample(self, key, shape=()): 45 | full_shape = shape + self.mean.shape 46 | std_normals = random.normal(key, full_shape) 47 | return jnp.tensordot(std_normals, self.scale, [-1, 1]) + self.mean 48 | 49 | @property 50 | def covariance(self): 51 | return self.scale @ self.scale.T 52 | 53 | 54 | @struct.dataclass 55 | class MultivariateNormalFull(MultivariateNormal): 56 | mean: jnp.ndarray 57 | covariance: jnp.ndarray 58 | 59 | def log_prob(self, x): 60 | scale = jnp.linalg.cholesky(self.covariance) 61 | dim = x.shape[-1] 62 | dev = x - self.mean 63 | maha = jnp.sum(dev * 64 | jscipy.linalg.cho_solve((scale, True), dev)) 65 | log_2_pi = jnp.log(2 * jnp.pi) 66 | log_det_cov = 2 * 
jnp.sum(jnp.log(jnp.diag(scale))) 67 | return -0.5 * (dim * log_2_pi + log_det_cov + maha) 68 | 69 | def sample(self, key, shape=()): 70 | return random.multivariate_normal( 71 | key, self.mean, self.covariance, shape) 72 | -------------------------------------------------------------------------------- /ladax/gaussian_processes/__init__.py: -------------------------------------------------------------------------------- 1 | from .gaussian_processes import (GaussianProcess, 2 | VariationalGaussianProcess) 3 | 4 | from .inducing_variables import (InducingVariable, 5 | InducingPointsVariable, 6 | InducingPointsProvider) 7 | 8 | from .gaussian_process_layers import SVGPLayer 9 | -------------------------------------------------------------------------------- /ladax/gaussian_processes/gaussian_process_layers.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import jax.scipy as jscipy 3 | 4 | from flax import nn 5 | from ladax import kernels, utils 6 | 7 | from .gaussian_processes import VariationalGaussianProcess 8 | 9 | 10 | class SVGPLayer(nn.Module): 11 | def apply(self, 12 | index_points, 13 | mean_fn, 14 | kernel_fn, 15 | inducing_var, 16 | jitter=1e-4): 17 | """ 18 | Args: 19 | index_points: the nd-array of index points of the GP model. 20 | mean_fn: callable mean function of the GP model. 21 | kernel_fn: callable kernel function. 22 | inducing_var: inducing variables `inducing_variables.InducingPointsVariable`. 23 | jitter: float `jitter` term to add to the diagonal of the covariance 24 | function before computing Cholesky decompositions. 25 | Returns: 26 | svgp: A sparse Variational GP model. 27 | """ 28 | z = inducing_var.locations 29 | qu = inducing_var.variational_distribution 30 | qu_mean = qu.mean 31 | qu_scale = qu.scale 32 | 33 | # cholesky of the base kernel function applied at the inducing point 34 | # locations. 35 | kzz_chol = jnp.linalg.cholesky( 36 | utils.diag_shift(kernel_fn(z, z), jitter)) 37 | 38 | if inducing_var.whiten: 39 | qu_mean = kzz_chol @ qu_mean 40 | qu_scale = kzz_chol @ qu_scale 41 | 42 | z = inducing_var.locations 43 | 44 | var_kern = kernels.VariationalKernel( 45 | kernel_fn, z, qu_scale) 46 | 47 | def var_mean(x_): 48 | kxz = kernel_fn(x_, z) 49 | dev = (qu_mean - mean_fn(z))[..., None] 50 | return (mean_fn(x_)[..., None] 51 | + kxz @ jscipy.linalg.cho_solve( 52 | (kzz_chol, True), dev))[..., 0] 53 | 54 | return VariationalGaussianProcess( 55 | index_points, var_mean, var_kern, jitter, inducing_var) 56 | -------------------------------------------------------------------------------- /ladax/gaussian_processes/gaussian_processes.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import jax.scipy as jscipy 3 | from flax import struct, nn 4 | 5 | from typing import Any, Callable 6 | from ladax import distributions, kernels, utils 7 | 8 | 9 | def multivariate_gaussian_kl(q, p): 10 | """ KL-divergence between multivariate Gaussian distributions defined as 11 | ∫ N(q.mean, q.scale) log{ N(q.mean, q.scale) / N (p.mean, p.scale) }. 12 | Args: 13 | q: `MultivariateNormal` object 14 | p: `MultivariateNormal` object 15 | Returns: 16 | kl: Python `float` the KL-divergence between `q` and `p`. 
17 | """ 18 | m_diff = q.mean - p.mean 19 | return .5*(2*jnp.log(jnp.diag(p.scale)).sum() - 2*jnp.log(jnp.diag(q.scale)).sum() 20 | - q.mean.shape[-1] 21 | + jnp.trace(jscipy.linalg.cho_solve((p.scale, True), q.scale) @ q.scale.T) 22 | + jnp.sum(m_diff * jscipy.linalg.cho_solve((p.scale, True), m_diff))) 23 | 24 | 25 | @struct.dataclass 26 | class GaussianProcess: 27 | index_points: jnp.ndarray 28 | mean_function: Callable = struct.field(pytree_node=False) 29 | kernel_function: Callable = struct.field(pytree_node=False) 30 | jitter: float 31 | 32 | def marginal(self): 33 | kxx = self.kernel_function(self.index_points, self.index_points) 34 | chol_kxx = jnp.linalg.cholesky(utils.diag_shift(kxx, self.jitter)) 35 | mean = self.mean_function(self.index_points) 36 | return distributions.MultivariateNormalTriL(mean, chol_kxx) 37 | 38 | def posterior_gp(self, y, x_new, observation_noise_variance, jitter=None): 39 | """ Returns a new GP conditional on y. """ 40 | cond_kernel_fn, _ = kernels.SchurComplementKernelProvider.init( 41 | None, 42 | self.kernel_function, 43 | self.index_points, 44 | observation_noise_variance) 45 | 46 | marginal = self.marginal() 47 | 48 | def cond_mean_fn(x): 49 | k_xnew_x = self.kernel_function(x, self.index_points) 50 | return (self.mean_function(x) 51 | + k_xnew_x @ jscipy.linalg.cho_solve( 52 | (cond_kernel_fn.divisor_matrix_cholesky, True), 53 | y - marginal.mean)) 54 | 55 | jitter = jitter if jitter else self.jitter 56 | return GaussianProcess(x_new, 57 | cond_mean_fn, 58 | cond_kernel_fn, 59 | jitter) 60 | 61 | 62 | @struct.dataclass 63 | class VariationalGaussianProcess(GaussianProcess): 64 | """ ToDo(dan): ugly `Any` typing to avoid circular dependency with GP 65 | inside of inducing_variables. Ideally break this by lifting 66 | variational GPs into their own module. 67 | """ 68 | inducing_variable: Any 69 | 70 | def prior_kl(self): 71 | if self.inducing_variable.whiten: 72 | return self.prior_kl_whiten() 73 | else: 74 | qu = self.inducing_variable.variational_distribution 75 | pu = self.inducing_variable.prior_distribution 76 | return multivariate_gaussian_kl(qu, pu) 77 | 78 | def prior_kl_whiten(self): 79 | qu = self.inducing_variable.variational_distribution 80 | log_det = 2*jnp.sum(jnp.log(jnp.diag(qu.scale))) 81 | dim = qu.mean.shape[-1] 82 | return -.5*(log_det + 0.5*dim - jnp.sum(qu.mean**2) - jnp.sum(qu.scale**2)) 83 | -------------------------------------------------------------------------------- /ladax/gaussian_processes/inducing_variables.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from flax import struct, nn 4 | from jax import random 5 | from ladax.distributions import MultivariateNormalDiag, MultivariateNormalTriL 6 | from ladax.gaussian_processes import GaussianProcess 7 | from typing import Union, Callable 8 | 9 | 10 | @struct.dataclass 11 | class InducingVariable: 12 | variational_distribution: MultivariateNormalTriL 13 | prior_distribution: MultivariateNormalTriL 14 | 15 | 16 | @struct.dataclass 17 | class InducingPointsVariable(InducingVariable): 18 | locations: jnp.ndarray 19 | whiten: bool = False 20 | 21 | 22 | class InducingPointsProvider(nn.Module): 23 | """ Handles parameterisation of an inducing points variable. 
""" 24 | 25 | def apply(self, 26 | index_points: jnp.ndarray, 27 | kernel_fun: Callable, 28 | num_inducing_points: int, 29 | inducing_locations_init: Union[Callable, None] = None, 30 | fixed_locations: bool = False, 31 | whiten: bool = False, 32 | jitter: float = 1e-4, 33 | dtype: jnp.dtype = jnp.float64) -> InducingPointsVariable: 34 | """ 35 | Args: 36 | index_points: the nd-array of index points of the GP model. 37 | kernel_fun: callable kernel function. 38 | num_inducing_points: total number of inducing points. 39 | inducing_locations_init: initializer function for the inducing 40 | variable locations. 41 | fixed_locations: boolean specifying whether to optimise the inducing 42 | point locations (default True). 43 | whiten: boolean specifying whether to apply the whitening transformation. 44 | (default False) 45 | jitter: float `jitter` term to add to the diagonal of the covariance 46 | function of the GP prior of the inducing variable, only used if no 47 | whitening transform applied. 48 | dtype: the data-type of the computation (default: float64) 49 | Returns: 50 | inducing_var: inducing variables `inducing_variables.InducingPointsVariable` 51 | """ 52 | n_features = index_points.shape[-1] 53 | z_shape = (num_inducing_points, n_features) 54 | if inducing_locations_init is None: 55 | inducing_locations_init = lambda key, shape: random.normal(key, z_shape) 56 | 57 | if fixed_locations: 58 | _default_key = random.PRNGKey(0) 59 | z = inducing_locations_init(_default_key, z_shape) 60 | else: 61 | z = self.param('locations', 62 | (num_inducing_points, n_features), 63 | inducing_locations_init) 64 | 65 | qu_mean = self.param('mean', (num_inducing_points,), 66 | lambda key, shape: jax.nn.initializers.zeros( 67 | key, z_shape[0], dtype=dtype)) 68 | 69 | qu_scale = self.param( 70 | 'scale', 71 | (num_inducing_points, num_inducing_points), 72 | lambda key, shape: jnp.eye(num_inducing_points, dtype=dtype)) 73 | 74 | if whiten: 75 | prior = MultivariateNormalDiag(mean=jnp.zeros(index_points.shape[-1]), 76 | scale_diag=jnp.ones(index_points.shape[-2])) 77 | 78 | else: 79 | prior = GaussianProcess( 80 | z, 81 | lambda x_: jnp.zeros(x_.shape[:-1]), 82 | kernel_fun, 83 | jitter).marginal() 84 | 85 | return InducingPointsVariable( 86 | variational_distribution=MultivariateNormalTriL(qu_mean, jnp.tril(qu_scale)), 87 | prior_distribution=prior, 88 | locations=z, 89 | whiten=whiten) 90 | -------------------------------------------------------------------------------- /ladax/kernels/__init__.py: -------------------------------------------------------------------------------- 1 | from .kernels import Kernel, SchurComplementKernel, VariationalKernel 2 | from .kernel_layers import RBFKernelProvider, SchurComplementKernelProvider 3 | -------------------------------------------------------------------------------- /ladax/kernels/kernel_layers.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | import jax.numpy as jnp 3 | import jax 4 | from flax import nn 5 | 6 | from .kernels import Kernel, SchurComplementKernel 7 | from ladax import utils 8 | 9 | 10 | def rbf_kernel_fun(x, x2, amplitude, lengthscale): 11 | """ Functional definition of an RBF kernel. """ 12 | pwd_dists = (x[..., jnp.newaxis, :] - x2[..., jnp.newaxis, :, :]) / lengthscale 13 | kernel_matrix = jnp.exp(-.5 * jnp.sum(pwd_dists ** 2, axis=-1)) 14 | return amplitude**2 * kernel_matrix 15 | 16 | 17 | class RBFKernelProvider(nn.Module): 18 | """ Provides an RBF kernel function. 
19 | The role of a kernel provider is to handle initialisation, and 20 | parameter storage of a particular kernel function. Allowing 21 | functionally defined kernels to be slotted into more complex models 22 | built using the Flax functional api. 23 | """ 24 | def apply(self, 25 | index_points: jnp.ndarray, 26 | amplitude_init: Callable = jax.nn.initializers.ones, 27 | length_scale_init: Callable = jax.nn.initializers.ones) -> Callable: 28 | """ 29 | Args: 30 | index_points: The nd-array of index points to the kernel. Only used for 31 | feature shape finding. 32 | amplitude_init: initializer function for the amplitude parameter. 33 | length_scale_init: initializer function for the length-scale parameter. 34 | Returns: 35 | rbf_kernel_fun: Callable kernel function. 36 | """ 37 | amplitude = jax.nn.softplus( 38 | self.param('amplitude', 39 | (1,), 40 | amplitude_init)) + jnp.finfo(float).tiny 41 | 42 | length_scale = jax.nn.softplus( 43 | self.param('length_scale', 44 | (index_points.shape[-1],), 45 | length_scale_init)) + jnp.finfo(float).tiny 46 | 47 | return Kernel( 48 | lambda x_, y_: rbf_kernel_fun(x_, y_, amplitude, length_scale)) 49 | 50 | 51 | class SchurComplementKernelProvider(nn.Module): 52 | """ Provides a schur complement kernel. """ 53 | def apply(self, 54 | base_kernel_fun: Callable, 55 | fixed_index_points: jnp.ndarray, 56 | diag_shift: jnp.ndarray = jnp.zeros([1])) -> SchurComplementKernel: 57 | """ 58 | Args: 59 | kernel_fun: 60 | fixed_index_points: 61 | diag_shift: Python `float` 62 | Returns: 63 | """ 64 | # compute the "divisor-matrix" 65 | divisor_matrix = base_kernel_fun( 66 | fixed_index_points, fixed_index_points) 67 | 68 | divisor_matrix_cholesky = jnp.linalg.cholesky( 69 | utils.diag_shift(divisor_matrix, diag_shift)) 70 | 71 | return SchurComplementKernel(base_kernel_fun, 72 | fixed_index_points, 73 | divisor_matrix_cholesky) 74 | -------------------------------------------------------------------------------- /ladax/kernels/kernels.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import jax.scipy as jscipy 3 | from flax import struct 4 | from typing import Callable 5 | 6 | 7 | @struct.dataclass 8 | class Kernel: 9 | kernel_fn: Callable = struct.field(pytree_node=False) 10 | 11 | def apply(self, x, x2): 12 | return self.kernel_fn(x, x2) 13 | 14 | def __call__(self, x, x2=None): 15 | x2 = x if x2 is None else x2 16 | return self.apply(x, x2) 17 | 18 | 19 | @struct.dataclass 20 | class SchurComplementKernel(Kernel): 21 | fixed_inputs: jnp.ndarray 22 | divisor_matrix_cholesky: jnp.ndarray 23 | 24 | def apply(self, x1, x2): 25 | k12 = self.kernel_fn(x1, x2) 26 | k1z = self.kernel_fn(x1, self.fixed_inputs) 27 | kz2 = self.kernel_fn(self.fixed_inputs, x2) 28 | return (k12 29 | - k1z @ jscipy.linalg.cho_solve( 30 | (self.divisor_matrix_cholesky, True), kz2)) 31 | 32 | 33 | @struct.dataclass 34 | class VariationalKernel(Kernel): 35 | fixed_inputs: jnp.ndarray 36 | variational_scale: jnp.ndarray 37 | jitter: float = 1.0e-4 38 | 39 | def apply(self, x1, x2): 40 | z = self.fixed_inputs 41 | kxy = self.kernel_fn(x1, x2) 42 | kxz = self.kernel_fn(x1, z) 43 | kzy = self.kernel_fn(z, x2) 44 | kzz = self.kernel_fn(z, z) 45 | kzz_cholesky = jnp.linalg.cholesky( 46 | kzz + self.jitter * jnp.eye(z.shape[-2])) 47 | 48 | kzz_chol_qu_scale = jscipy.linalg.cho_solve( 49 | (kzz_cholesky, True), self.variational_scale) 50 | 51 | return (kxy 52 | - kxz @ jscipy.linalg.cho_solve((kzz_cholesky, True), kzy) 53 | + kxz @ 
(kzz_chol_qu_scale @ kzz_chol_qu_scale.T) @ kzy) 54 | -------------------------------------------------------------------------------- /ladax/likelihoods/__init__.py: -------------------------------------------------------------------------------- 1 | from .gaussian_log_likelihood import GaussianLogLik 2 | -------------------------------------------------------------------------------- /ladax/likelihoods/gaussian_log_likelihood.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from flax import struct 3 | 4 | 5 | @struct.dataclass 6 | class GaussianLogLik: 7 | qu_mean: jnp.ndarray 8 | qu_scale: jnp.ndarray 9 | observation_noise_scale: jnp.ndarray 10 | 11 | def variational_expectation(self, y): 12 | return -.5 * jnp.squeeze( 13 | (jnp.sum(jnp.square(self.qu_mean - y)) 14 | + jnp.trace(self.qu_scale @ self.qu_scale.T)) 15 | / self.observation_noise_scale ** 2 16 | + y.shape[-1] * jnp.log(self.observation_noise_scale ** 2) 17 | + jnp.log(2 * jnp.pi)) 18 | -------------------------------------------------------------------------------- /ladax/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .gaussian_likelihood import VariationalGaussianLikelihoodLoss 2 | -------------------------------------------------------------------------------- /ladax/losses/gaussian_likelihood.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from flax import nn 4 | from ladax.gaussian_processes import VariationalGaussianProcess 5 | 6 | 7 | class VariationalGaussianLikelihoodLoss(nn.Module): 8 | """ """ 9 | def apply(self, y, vgp: VariationalGaussianProcess): 10 | obs_noise_scale = jax.nn.softplus( 11 | self.param('observation_noise_scale', (), jax.nn.initializers.ones)) 12 | 13 | variational_distribution = vgp.marginal() 14 | qu_mean = variational_distribution.mean 15 | qu_scale = variational_distribution.scale 16 | 17 | # Expected value of iid gaussians under q(u) 18 | expected_gll_under_qu = -.5 * jnp.squeeze( 19 | (jnp.sum(jnp.square(qu_mean - y)) 20 | + jnp.trace(qu_scale @ qu_scale.T)) 21 | / obs_noise_scale ** 2 22 | + y.shape[-1] * jnp.log(obs_noise_scale ** 2) 23 | + jnp.log(2 * jnp.pi)) 24 | 25 | # flip sign to minimize the elbo 26 | return -expected_gll_under_qu 27 | -------------------------------------------------------------------------------- /ladax/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .svgp import svgp_factory 2 | -------------------------------------------------------------------------------- /ladax/models/svgp.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from flax.nn import Module 3 | from ladax.gaussian_processes import SVGPLayer 4 | 5 | 6 | def svgp_factory(kernel_provider, 7 | inducing_variable_provider, 8 | mean_fn=None, 9 | kernel_fn_kwargs=None, 10 | inducing_variable_kwargs=None, 11 | svgp_layer_kwargs=None): 12 | 13 | mean_fn = mean_fn if mean_fn else lambda x: jnp.zeros(x.shape[-2], x.dtype) 14 | kernel_fn_kwargs = {} if kernel_fn_kwargs is None else kernel_fn_kwargs 15 | inducing_variable_kwargs = {} if inducing_variable_kwargs is None else inducing_variable_kwargs 16 | svgp_layer_kwargs = {} if svgp_layer_kwargs is None else svgp_layer_kwargs 17 | 18 | class SVGP(Module): 19 | def apply(self, x): 20 | kernel_fn = kernel_provider(x, 
**kernel_fn_kwargs) 21 | inducing_var = inducing_variable_provider(x, kernel_fn, **inducing_variable_kwargs) 22 | vgp = SVGPLayer(x, mean_fn, kernel_fn, inducing_var, **svgp_layer_kwargs) 23 | return vgp 24 | 25 | return SVGP 26 | -------------------------------------------------------------------------------- /ladax/utils.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | from typing import Union 3 | import jax.ops as ops 4 | 5 | 6 | def diag_shift(mat: jnp.ndarray, 7 | val: Union[float, jnp.ndarray]) -> jnp.ndarray: 8 | """ Shifts the diagonal of mat by val. """ 9 | return ops.index_update( 10 | mat, 11 | jnp.diag_indices(mat.shape[-1], len(mat.shape)), 12 | jnp.diag(mat) + val) 13 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jax 2 | jaxlib 3 | flax 4 | numpy -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import find_packages 3 | from setuptools import setup 4 | 5 | here = os.path.abspath(os.path.dirname(__file__)) 6 | try: 7 | README = open(os.path.join(here, "README.md")).read() 8 | except IOError: 9 | README = "" 10 | 11 | version = "0.0.1" 12 | 13 | install_requires = [ 14 | "numpy>=1.12", 15 | "jaxlib>=0.1.41", 16 | "jax>=0.1.59", 17 | "flax", 18 | "matplotlib", 19 | "dataclasses", 20 | "msgpack", 21 | ] 22 | 23 | tests_require = [ 24 | ] 25 | 26 | setup( 27 | name="ladax", 28 | version=version, 29 | description="Ladax: layered distribution models using FLAX/JAX.", 30 | long_description="\n\n".join([README]), 31 | long_description_content_type='text/markdown', 32 | classifiers=[ 33 | "Development Status :: 3 - Alpha", 34 | "Intended Audience :: Developers", 35 | "Intended Audience :: Science/Research", 36 | "License :: OSI Approved :: MIT License", 37 | "Programming Language :: Python :: 3.7", 38 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 39 | ], 40 | keywords="", 41 | author="Dan Tait", 42 | author_email="tait.djk@gmail.com", 43 | url="https://github.com/danieljtait/ladax", 44 | license="Apache", 45 | packages=find_packages(), 46 | include_package_data=False, 47 | zip_safe=False, 48 | install_requires=install_requires, 49 | extras_require={ 50 | "testing": tests_require, 51 | }, 52 | ) 53 | --------------------------------------------------------------------------------