├── README.md ├── examples ├── 0_pathwise_conditioning.ipynb ├── 1_model_api.ipynb ├── 2_minimizing_sample_paths.ipynb ├── 3_training_via_sampling.ipynb └── 4_autoencoding_mnist.ipynb ├── gpflow_sampling ├── __init__.py ├── bases │ ├── __init__.py │ ├── core.py │ ├── dispatch.py │ ├── fourier_bases.py │ └── fourier_initializers.py ├── covariances │ ├── Kfus.py │ ├── Kufs.py │ ├── Kuus.py │ └── __init__.py ├── inducing_variables.py ├── kernels.py ├── models.py ├── sampling │ ├── __init__.py │ ├── core.py │ ├── decoupled_samplers.py │ ├── priors │ │ ├── __init__.py │ │ └── fourier_priors.py │ └── updates │ │ ├── __init__.py │ │ ├── cg_updates.py │ │ ├── exact_updates.py │ │ └── linear_updates.py └── utils │ ├── __init__.py │ ├── array_ops.py │ ├── conv_ops.py │ ├── gpflow_ops.py │ └── linalg.py ├── setup.py └── tests ├── __init__.py ├── kernels ├── __init__.py └── test_conv2d.py └── sampling ├── __init__.py ├── priors ├── __init__.py ├── test_fourier_conv2d.py └── test_fourier_dense.py └── updates ├── __init__.py ├── common.py ├── test_cg.py ├── test_exact.py └── test_linear.py /README.md: -------------------------------------------------------------------------------- 1 | # GPflowSampling 2 | Companion code for [Efficiently Sampling Functions from Gaussian Process Posteriors](https://arxiv.org/abs/2002.09309) and [Pathwise Conditioning of Gaussian processes](https://arxiv.org/abs/2011.04026). 3 | ## Overview 4 | Software provided here revolves around Matheron's update rule 5 | 6 | (f | y)(·) = f(·) + k(·, X) K⁻¹ (y − f(X)), where K = k(X, X), 7 | 8 | which allows us to represent a GP posterior as the sum of a prior random function and a data-driven update term. Thinking about conditioning at the level of random functions (rather than marginal distributions) enables us to accurately sample GP posteriors in linear time. 9 | 10 | Please see `examples` for tutorials and (hopefully) illustrative use cases. 11 | 12 | ## Installation 13 | ``` 14 | git clone git@github.com:j-wilson/GPflowSampling.git 15 | cd GPflowSampling 16 | pip install -e . 17 | ``` 18 | To install the dependencies needed to run `examples`, use `pip install -e .[examples]`. 19 | 20 | 21 | ## Citing Us 22 | If our work helps you in a way that you feel warrants reference, please cite the following paper: 23 | ``` 24 | @inproceedings{wilson2020efficiently, 25 | title={Efficiently sampling functions from Gaussian process posteriors}, 26 | author={James T.
Wilson 27 | and Viacheslav Borovitskiy 28 | and Alexander Terenin 29 | and Peter Mostowsky 30 | and Marc Peter Deisenroth}, 31 | booktitle={International Conference on Machine Learning}, 32 | year={2020}, 33 | url={https://arxiv.org/abs/2002.09309} 34 | } 35 | ``` 36 | -------------------------------------------------------------------------------- /gpflow_sampling/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from gpflow_sampling import utils, kernels, covariances, bases, sampling, models 5 | -------------------------------------------------------------------------------- /gpflow_sampling/bases/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from gpflow_sampling.bases.dispatch import ( 5 | kernel_basis as kernel, 6 | fourier_basis as fourier 7 | ) 8 | -------------------------------------------------------------------------------- /gpflow_sampling/bases/core.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # ============================================== 5 | # Preamble 6 | # ============================================== 7 | # ---- Imports 8 | from abc import abstractmethod 9 | from typing import Union 10 | from tensorflow import Module 11 | from gpflow.base import TensorType 12 | from gpflow.kernels import Kernel, SharedIndependent, SeparateIndependent 13 | from gpflow.inducing_variables import InducingVariables 14 | from gpflow_sampling.utils import get_inducing_shape 15 | from gpflow_sampling.covariances import Kfu as Kfu_dispatch 16 | 17 | # ---- Exports 18 | __all__ = ('AbstractBasis', 'KernelBasis') 19 | 20 | 21 | # ============================================== 22 | # core 23 | # ============================================== 24 | class AbstractBasis(Module): 25 | def __init__(self, initialized: bool = False, name: str = None): 26 | super().__init__(name=name) 27 | self.initialized = initialized 28 | 29 | @abstractmethod 30 | def __call__(self, *args, **kwargs): 31 | raise NotImplementedError 32 | 33 | @abstractmethod 34 | def initialize(self, *args, **kwargs): 35 | pass 36 | 37 | def _maybe_initialize(self, *args, **kwargs): 38 | if not self.initialized: 39 | self.initialize(*args, **kwargs) 40 | self.initialized = True 41 | 42 | @property 43 | @abstractmethod 44 | def num_bases(self): 45 | raise NotImplementedError 46 | 47 | 48 | class KernelBasis(AbstractBasis): 49 | def __init__(self, 50 | kernel: Kernel, 51 | centers: Union[TensorType, InducingVariables], 52 | name: str = None, 53 | **default_kwargs): 54 | 55 | super().__init__(name=name) 56 | self.kernel = kernel 57 | self.centers = centers 58 | self.default_kwargs = default_kwargs 59 | 60 | def __call__(self, x, **kwargs): 61 | _kwargs = {**self.default_kwargs, **kwargs} # resolve keyword arguments 62 | self._maybe_initialize(x, **_kwargs) 63 | if isinstance(self.centers, InducingVariables): 64 | return Kfu_dispatch(self.centers, self.kernel, x, **_kwargs) 65 | 66 | if isinstance(self.kernel, (SharedIndependent, SeparateIndependent)): 67 | # TODO: Improve handling of "full_output_cov". Here, we're imitating 68 | # the behavior of gpflow.covariances.Kuf. 
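# Note: KernelBasis implements the canonical basis phi(.) = k(., Z) used by
# pathwise update terms. A hedged usage sketch (the kernel choice and shapes
# below are illustrative assumptions, not part of this module):
#   kern = gpflow.kernels.SquaredExponential()
#   basis = KernelBasis(kernel=kern, centers=Z)  # Z: [M, D]
#   feats = basis(X)  # k(X, Z), shape [N, M]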
69 | _kwargs.setdefault('full_output_cov', False) 70 | 71 | return self.kernel.K(x, self.centers, **_kwargs) 72 | 73 | @property 74 | def num_bases(self): 75 | """ 76 | TODO: Edge-cases? 77 | """ 78 | if isinstance(self.centers, InducingVariables): 79 | return get_inducing_shape(self.centers)[-1] 80 | return self.centers.shape[-1] 81 | -------------------------------------------------------------------------------- /gpflow_sampling/bases/dispatch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # ============================================== 5 | # Preamble 6 | # ============================================== 7 | # ---- Imports 8 | from typing import Union 9 | from gpflow import kernels as gpflow_kernels 10 | from gpflow.base import TensorType 11 | from gpflow.utilities import Dispatcher 12 | from gpflow.inducing_variables import InducingVariables 13 | from gpflow_sampling import kernels 14 | from gpflow_sampling.bases import fourier_bases 15 | from gpflow_sampling.bases.core import KernelBasis 16 | 17 | 18 | # ---- Exports 19 | __all__ = ( 20 | 'kernel_basis', 21 | 'fourier_basis', 22 | ) 23 | 24 | kernel_basis = Dispatcher("kernel_basis") 25 | fourier_basis = Dispatcher("fourier_basis") 26 | 27 | 28 | # ============================================== 29 | # dispatch 30 | # ============================================== 31 | @kernel_basis.register(gpflow_kernels.Kernel) 32 | def _kernel_fallback(kern: gpflow_kernels.Kernel, 33 | centers: Union[TensorType, InducingVariables], 34 | **kwargs): 35 | return KernelBasis(kernel=kern, centers=centers, **kwargs) 36 | 37 | 38 | @fourier_basis.register(gpflow_kernels.Stationary) 39 | def _fourier_stationary(kern: gpflow_kernels.Stationary, **kwargs): 40 | return fourier_bases.Dense(kernel=kern, **kwargs) 41 | 42 | 43 | @fourier_basis.register(gpflow_kernels.MultioutputKernel) 44 | def _fourier_multioutput(kern: gpflow_kernels.MultioutputKernel, **kwargs): 45 | return fourier_bases.MultioutputDense(kernel=kern, **kwargs) 46 | 47 | 48 | @fourier_basis.register(kernels.Conv2d) 49 | def _fourier_conv2d(kern: kernels.Conv2d, **kwargs): 50 | return fourier_bases.Conv2d(kernel=kern, **kwargs) 51 | 52 | 53 | @fourier_basis.register(kernels.Conv2dTranspose) 54 | def _fourier_conv2d_transposed(kern: kernels.Conv2dTranspose, **kwargs): 55 | return fourier_bases.Conv2dTranspose(kernel=kern, **kwargs) 56 | 57 | 58 | @fourier_basis.register(kernels.DepthwiseConv2d) 59 | def _fourier_depthwise_conv2d(kern: kernels.DepthwiseConv2d, **kwargs): 60 | return fourier_bases.DepthwiseConv2d(kernel=kern, **kwargs) 61 | -------------------------------------------------------------------------------- /gpflow_sampling/bases/fourier_bases.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # ============================================== 5 | # Preamble 6 | # ============================================== 7 | # ---- Imports 8 | import tensorflow as tf 9 | 10 | from typing import Any 11 | from gpflow import kernels as gpflow_kernels 12 | from gpflow.base import TensorType 13 | from gpflow.inducing_variables import InducingVariables 14 | from gpflow_sampling import kernels, inducing_variables 15 | from gpflow_sampling.bases.core import AbstractBasis 16 | from gpflow_sampling.utils import (move_axis, 17 | expand_to, 18 | batch_tensordot, 19 | inducing_to_tensor) 20 | from 
gpflow_sampling.bases.fourier_initializers import (bias_initializer, 21 | weight_initializer) 22 | 23 | 24 | # ---- Exports 25 | __all__ = ( 26 | 'Dense', 27 | 'MultioutputDense', 28 | 'Conv2d', 29 | 'Conv2dTranspose', 30 | 'DepthwiseConv2d', 31 | ) 32 | 33 | 34 | # ============================================== 35 | # fourier_bases 36 | # ============================================== 37 | class AbstractFourierBasis(AbstractBasis): 38 | def __init__(self, 39 | kernel: gpflow_kernels.Kernel, 40 | num_bases: int, 41 | initialized: bool = False, 42 | name: Any = None): 43 | super().__init__(initialized=initialized, name=name) 44 | self.kernel = kernel 45 | self._num_bases = num_bases 46 | 47 | @property 48 | def num_bases(self): 49 | return self._num_bases 50 | 51 | 52 | class Dense(AbstractFourierBasis): 53 | def __init__(self, 54 | kernel: gpflow_kernels.Stationary, 55 | num_bases: int, 56 | weights: tf.Tensor = None, 57 | biases: tf.Tensor = None, 58 | name: str = None): 59 | super().__init__(name=name, 60 | kernel=kernel, 61 | num_bases=num_bases) 62 | self._weights = weights 63 | self._biases = biases 64 | 65 | def __call__(self, x: TensorType, **kwargs) -> tf.Tensor: 66 | self._maybe_initialize(x, **kwargs) 67 | if isinstance(x, InducingVariables): # TODO: Allow this behavior? 68 | x = inducing_to_tensor(x) 69 | 70 | proj = tf.tensordot(x, self.weights, axes=[-1, -1]) # [..., B] 71 | feat = tf.cos(proj + self.biases) 72 | return self.output_scale * feat 73 | 74 | def initialize(self, x: TensorType, dtype: Any = None): 75 | if isinstance(x, InducingVariables): 76 | x = inducing_to_tensor(x) 77 | 78 | if dtype is None: 79 | dtype = x.dtype 80 | 81 | self._biases = bias_initializer(self.kernel, self.num_bases, dtype=dtype) 82 | self._weights = weight_initializer(self.kernel, x.shape[-1], 83 | batch_shape=[self.num_bases], 84 | dtype=dtype) 85 | 86 | @property 87 | def weights(self): 88 | if self._weights is None: 89 | return None 90 | return tf.math.reciprocal(self.kernel.lengthscales) * self._weights 91 | 92 | @property 93 | def biases(self): 94 | return self._biases 95 | 96 | @property 97 | def output_scale(self): 98 | return tf.sqrt(2 * self.kernel.variance / self.num_bases) 99 | 100 | 101 | class MultioutputDense(Dense): 102 | def __call__(self, x: TensorType, multioutput_axis: int = None, **kwargs): 103 | self._maybe_initialize(x, **kwargs) 104 | if isinstance(x, InducingVariables): # TODO: Allow this behavior? 
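# Each latent kernel g_l gets its own random Fourier basis; per the
# initializers used below, features take the (sketched) form
#   phi_l(x) = sqrt(2 * var_l / B) * cos(x @ W_l^T + b_l),
# with W_l drawn from the kernel's spectral density and b_l ~ U[0, 2*pi],
# so the returned features have shape [L, N, B].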
105 | x = inducing_to_tensor(x) 106 | 107 | # Compute (batch) tensor dot product 108 | batch_axes = None if (multioutput_axis is None) else [0, multioutput_axis] 109 | proj = move_axis(batch_tensordot(self.weights, 110 | x, 111 | axes=[-1, -1], 112 | batch_axes=batch_axes), 1, -1) 113 | 114 | ndims = proj.shape.ndims 115 | feat = tf.cos(proj + expand_to(self.biases, axis=1, ndims=ndims)) 116 | return expand_to(self.output_scale, axis=1, ndims=ndims) * feat # [L, N, B] 117 | 118 | def initialize(self, x: TensorType, dtype: Any = None): 119 | if isinstance(x, InducingVariables): 120 | x = inducing_to_tensor(x) 121 | 122 | if dtype is None: 123 | dtype = x.dtype 124 | 125 | biases = [] 126 | weights = [] 127 | for kernel in self.kernel.latent_kernels: 128 | biases.append( 129 | bias_initializer(kernel, self.num_bases, dtype=dtype)) 130 | 131 | weights.append( 132 | weight_initializer(kernel, x.shape[-1], 133 | batch_shape=[self.num_bases], 134 | dtype=dtype)) 135 | 136 | self._biases = tf.stack(biases, axis=0) # [L, B] 137 | self._weights = tf.stack(weights, axis=0) # [L, B, D] 138 | 139 | @property 140 | def weights(self): 141 | if self._weights is None: 142 | return None 143 | 144 | num_lengthscales = None 145 | for kernel in self.kernel.latent_kernels: 146 | if kernel.ard: 147 | ls = kernel.lengthscales 148 | assert ls.shape.ndims == 1 149 | if num_lengthscales is None: 150 | num_lengthscales = ls.shape[0] 151 | else: 152 | assert num_lengthscales == ls.shape[0] 153 | 154 | inv_lengthscales = [] 155 | for kernel in self.kernel.latent_kernels: 156 | inv_ls = tf.math.reciprocal(kernel.lengthscales) 157 | if not kernel.ard and num_lengthscales is not None: 158 | inv_ls = tf.fill([num_lengthscales], inv_ls) 159 | inv_lengthscales.append(inv_ls) 160 | 161 | # [L, 1, D] or [L, 1, 1] 162 | inv_lengthscales = expand_to(arr=tf.stack(inv_lengthscales), 163 | axis=1, 164 | ndims=self._weights.shape.ndims) 165 | 166 | return inv_lengthscales * self._weights 167 | 168 | @property 169 | def output_scale(self): 170 | variances = tf.stack([k.variance for k in self.kernel.latent_kernels]) 171 | return tf.sqrt(2 * variances / self.num_bases) # [L] 172 | 173 | 174 | class Conv2d(AbstractFourierBasis): 175 | def __init__(self, 176 | kernel: kernels.Conv2d, 177 | num_bases: int, 178 | filters: tf.Tensor = None, 179 | biases: tf.Tensor = None, 180 | name: str = None): 181 | 182 | super().__init__(name=name, 183 | kernel=kernel, 184 | num_bases=num_bases) 185 | self._filters = filters 186 | self._biases = biases 187 | 188 | def __call__(self, x: TensorType) -> tf.Tensor: 189 | self._maybe_initialize(x) 190 | if isinstance(x, InducingVariables) or len(x.shape) == 4: 191 | conv = self.convolve(x) 192 | elif len(x.shape) > 4: # allow for higher order batches 193 | x_4d = tf.reshape(x, [-1] + list(x.shape[-3:])) 194 | conv = self.convolve(x_4d) 195 | conv = tf.reshape(conv, list(x.shape[:-3]) + list(conv.shape[1:])) 196 | else: 197 | raise NotImplementedError 198 | return self.output_scale * tf.cos(conv + self.biases) 199 | 200 | def convolve(self, x: TensorType) -> tf.Tensor: 201 | if isinstance(x, inducing_variables.InducingImages): 202 | return tf.nn.conv2d(input=x.as_images, 203 | filters=self.filters, 204 | strides=(1, 1, 1, 1), 205 | padding="VALID") 206 | return self.kernel.convolve(input=x, filters=self.filters) 207 | 208 | def initialize(self, x, dtype: Any = None): 209 | if isinstance(x, inducing_variables.InducingImages): 210 | x = x.as_images 211 | 212 | if dtype is None: 213 | dtype = x.dtype 214 | 215 
| self._biases = bias_initializer(self.kernel.kernel, 216 | self.num_bases, 217 | dtype=dtype) 218 | 219 | patch_size = (self.kernel.channels_in 220 | * self.kernel.patch_shape[0] 221 | * self.kernel.patch_shape[1]) 222 | 223 | weights = weight_initializer(self.kernel.kernel, patch_size, 224 | batch_shape=[self.num_bases], 225 | dtype=dtype) 226 | 227 | shape = self.kernel.patch_shape + [self.kernel.channels_in, self.num_bases] 228 | self._filters = tf.reshape(move_axis(weights, -1, 0), shape) 229 | 230 | @property 231 | def filters(self): 232 | if self._filters is None: 233 | return None 234 | 235 | shape = list(self.kernel.patch_shape) + [self.kernel.channels_in, 1] 236 | inv_ls = tf.math.reciprocal(self.kernel.kernel.lengthscales) 237 | if self.kernel.kernel.ard: 238 | coeffs = tf.reshape(inv_ls, shape) 239 | else: 240 | coeffs = tf.fill(shape, inv_ls) 241 | 242 | return coeffs * self._filters 243 | 244 | @property 245 | def biases(self): 246 | return self._biases 247 | 248 | @property 249 | def output_scale(self): 250 | return tf.sqrt(2 * self.kernel.kernel.variance / self.num_bases) 251 | 252 | 253 | class Conv2dTranspose(Conv2d): 254 | pass 255 | 256 | 257 | class DepthwiseConv2d(Conv2d): 258 | def convolve(self, x: TensorType) -> tf.Tensor: 259 | if isinstance(x, inducing_variables.DepthwiseInducingImages): 260 | return tf.nn.depthwise_conv2d(input=x.as_images, 261 | filter=self.filters, 262 | strides=(1, 1, 1, 1), 263 | padding="VALID") 264 | 265 | return self.kernel.convolve(input=x, filters=self.filters) 266 | 267 | def initialize(self, x, dtype: Any = None): 268 | if isinstance(x, inducing_variables.InducingImages): 269 | x = x.as_images 270 | 271 | if dtype is None: 272 | dtype = x.dtype 273 | 274 | channels_out = self.kernel.channels_in * self.num_bases 275 | self._biases = bias_initializer(self.kernel.kernel, 276 | channels_out, 277 | dtype=dtype) 278 | 279 | patch_size = self.kernel.patch_shape[0] * self.kernel.patch_shape[1] 280 | batch_shape = [self.kernel.channels_in, self.num_bases] 281 | weights = weight_initializer(self.kernel.kernel, patch_size, 282 | batch_shape=batch_shape, 283 | dtype=dtype) 284 | 285 | self._filters = tf.reshape(move_axis(weights, -1, 0), 286 | self.kernel.patch_shape + batch_shape) 287 | 288 | @property 289 | def filters(self): 290 | if self._filters is None: 291 | return None 292 | 293 | shape = list(self.kernel.patch_shape) + [self.kernel.channels_in, 1] 294 | inv_ls = tf.math.reciprocal(self.kernel.kernel.lengthscales) 295 | if self.kernel.kernel.ard: 296 | coeffs = tf.reshape(tf.transpose(inv_ls), shape) 297 | else: 298 | coeffs = tf.fill(shape, inv_ls) 299 | 300 | return coeffs * self._filters 301 | 302 | @property 303 | def output_scale(self): 304 | num_features_out = self.num_bases * self.kernel.channels_in 305 | return tf.sqrt(2 * self.kernel.kernel.variance / num_features_out) 306 | -------------------------------------------------------------------------------- /gpflow_sampling/bases/fourier_initializers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # ============================================== 5 | # Preamble 6 | # ============================================== 7 | # ---- Imports 8 | import numpy as np 9 | import tensorflow as tf 10 | 11 | from typing import Any, List 12 | from gpflow import kernels 13 | from gpflow.config import default_float 14 | from gpflow.utilities import Dispatcher 15 | 16 | # ---- Exports 17 | __all__ = 
('bias_initializer', 'weight_initializer') 18 | 19 | 20 | # ============================================== 21 | # initializers 22 | # ============================================== 23 | MaternKernel = kernels.Matern12, kernels.Matern32, kernels.Matern52 24 | bias_initializer = Dispatcher("bias_initializer") 25 | weight_initializer = Dispatcher("weight_initializer") 26 | 27 | 28 | @bias_initializer.register(kernels.Stationary, int) 29 | def _bias_initializer_fallback(kern: kernels.Stationary, 30 | ndims: int, 31 | *, 32 | batch_shape: List = None, 33 | dtype: Any = None, 34 | maxval: float = 2 * np.pi) -> tf.Tensor: 35 | if dtype is None: 36 | dtype = default_float() 37 | 38 | shape = [ndims] if batch_shape is None else list(batch_shape) + [ndims] 39 | return tf.random.uniform(shape=shape, maxval=maxval, dtype=dtype) 40 | 41 | 42 | @weight_initializer.register(kernels.SquaredExponential, int) 43 | def _weight_initializer_squaredExp(kern: kernels.SquaredExponential, 44 | ndims: int, 45 | *, 46 | batch_shape: List = None, 47 | dtype: Any = None, 48 | normal_rvs: tf.Tensor = None) -> tf.Tensor: 49 | if dtype is None: 50 | dtype = default_float() 51 | 52 | if batch_shape is None: 53 | batch_shape = [] 54 | 55 | shape = [ndims] if batch_shape is None else list(batch_shape) + [ndims] 56 | if normal_rvs is None: 57 | return tf.random.normal(shape, dtype=dtype) 58 | 59 | assert tuple(normal_rvs.shape) == tuple(shape) 60 | return tf.convert_to_tensor(normal_rvs, dtype=dtype) 61 | 62 | 63 | @weight_initializer.register(MaternKernel, int) 64 | def _weight_initializer_matern(kern: MaternKernel, 65 | ndims: int, 66 | *, 67 | batch_shape: List = None, 68 | dtype: Any = None, 69 | normal_rvs: tf.Tensor = None, 70 | gamma_rvs: tf.Tensor = None) -> tf.Tensor: 71 | if dtype is None: 72 | dtype = default_float() 73 | 74 | if isinstance(kern, kernels.Matern12): 75 | smoothness = 1/2 76 | elif isinstance(kern, kernels.Matern32): 77 | smoothness = 3/2 78 | elif isinstance(kern, kernels.Matern52): 79 | smoothness = 5/2 80 | else: 81 | raise NotImplementedError 82 | 83 | batch_shape = [] if batch_shape is None else list(batch_shape) 84 | if normal_rvs is None: 85 | normal_rvs = tf.random.normal(shape=batch_shape + [ndims], dtype=dtype) 86 | else: 87 | assert tuple(normal_rvs.shape) == tuple(batch_shape + [ndims]) 88 | normal_rvs = tf.convert_to_tensor(normal_rvs, dtype=dtype) 89 | 90 | if gamma_rvs is None: 91 | gamma_rvs = tf.random.gamma(shape=batch_shape + [1], 92 | alpha=smoothness, 93 | beta=smoothness, 94 | dtype=dtype) 95 | else: 96 | assert tuple(gamma_rvs.shape) == tuple(batch_shape + [1]) 97 | gamma_rvs = tf.convert_to_tensor(gamma_rvs, dtype=dtype) 98 | 99 | # Return draws from a multivariate-t distribution 100 | return tf.math.rsqrt(gamma_rvs) * normal_rvs 101 | -------------------------------------------------------------------------------- /gpflow_sampling/covariances/Kfus.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # ============================================== 5 | # Preamble 6 | # ============================================== 7 | # ---- Imports 8 | import tensorflow as tf 9 | 10 | from gpflow import kernels 11 | from gpflow.base import TensorLike 12 | from gpflow.utilities import Dispatcher 13 | from gpflow.inducing_variables import (InducingVariables, 14 | SharedIndependentInducingVariables) 15 | from gpflow.covariances.dispatch import Kuf as Kuf_dispatch 16 | from gpflow_sampling.kernels 
import Conv2d, DepthwiseConv2d 17 | from gpflow_sampling.utils import move_axis, get_inducing_shape 18 | from gpflow_sampling.inducing_variables import (InducingImages, 19 | DepthwiseInducingImages) 20 | 21 | 22 | # ============================================== 23 | # Kfus 24 | # ============================================== 25 | Kfu = Dispatcher("Kfu") 26 | 27 | 28 | @Kfu.register(InducingVariables, kernels.Kernel, TensorLike) 29 | def _Kfu_fallback(Z, kern, X, **kwargs): 30 | Kuf = Kuf_dispatch(Z, kern, X, **kwargs) 31 | 32 | # Assume features of x and z are 1-dimensional 33 | ndims_x = X.shape.ndims - 1 # assume x lives in 1d space 34 | ndims_z = len(get_inducing_shape(Z)) - 1 35 | assert ndims_x + ndims_z == Kuf.shape.ndims 36 | 37 | # Swap the batch axes of x and z 38 | axes = list(range(ndims_x + ndims_z)) 39 | perm = axes[ndims_z: ndims_z + ndims_x] + axes[:ndims_z] 40 | return tf.transpose(Kuf, perm) 41 | 42 | 43 | @Kfu.register(InducingVariables, kernels.MultioutputKernel, TensorLike) 44 | def _Kfu_fallback_multioutput(Z, kern, X, **kwargs): 45 | Kuf = Kuf_dispatch(Z, kern, X, **kwargs) 46 | 47 | # Assume features of x and z are 1-dimensional 48 | ndims_x = X.shape.ndims - 1 # assume x lives in 1d space 49 | ndims_z = 1 # shared Z live in 1d space, separate Z are 2d but 1-to-1 with L 50 | assert ndims_x + ndims_z == Kuf.shape.ndims - 1 51 | 52 | # Swap the batch axes of x and z 53 | axes = list(range(1, ndims_x + ndims_z + 1)) # keep L output-features first 54 | perm = [0] + axes[ndims_z: ndims_z + ndims_x] + axes[:ndims_z] 55 | return tf.transpose(Kuf, perm) 56 | 57 | 58 | @Kfu.register(SharedIndependentInducingVariables, 59 | kernels.SharedIndependent, 60 | TensorLike) 61 | def _Kfu_fallback_shared(Z, kern, X, **kwargs): 62 | return _Kfu_fallback(Z, kern, X, **kwargs) # Edge-case where L is supressed 63 | 64 | 65 | def _Kfu_conv2d_fallback(feat: InducingImages, 66 | kern: Conv2d, 67 | Xnew: tf.Tensor, 68 | full_spatial: bool = False): 69 | 70 | Zp = feat.as_patches # [M, patch_len] 71 | Xp = kern.get_patches(Xnew, full_spatial=full_spatial) 72 | Kxz = kern.kernel.K(Xp, Zp) # [N, H * W, M] or [N, H, W, M] 73 | if full_spatial: # convert to 4D image format 74 | spatial_out = kern.get_spatial_out(Xnew.shape[-3:-1]) # to [N, H, W, M] 75 | return tf.reshape(Kxz, list(Kxz.shape[:-2]) + spatial_out + [Kxz.shape[-1]]) 76 | 77 | if kern.weights is None: 78 | return tf.reduce_mean(Kxz, axis=-2) 79 | 80 | return tf.tensordot(Kxz, kern.weights, [-2, -1]) 81 | 82 | 83 | @Kfu.register(InducingImages, Conv2d, object) 84 | def _Kfu_conv2d(feat: InducingImages, 85 | kern: Conv2d, 86 | Xnew: tf.Tensor, 87 | full_spatial: bool = False): 88 | 89 | if not isinstance(kern.kernel, kernels.Stationary): 90 | return _Kfu_conv2d_fallback(feat, kern, Xnew, full_spatial) 91 | 92 | # Compute (squared) Mahalanobis distances between patches 93 | patch_shape = list(kern.patch_shape) 94 | channels_in = Xnew.shape[-3 if kern.data_format == "NCHW" else -1] 95 | precis = tf.square(tf.math.reciprocal(kern.kernel.lengthscales)) 96 | 97 | # Construct lengthscale filters [h, w, channels_in, 1] 98 | if kern.kernel.ard: 99 | filters = tf.reshape(precis, patch_shape + [channels_in, 1]) 100 | else: 101 | filters = tf.fill(patch_shape + [channels_in, 1], precis) 102 | 103 | r2 = tf.transpose(tf.nn.conv2d(input=tf.square(feat.as_images), 104 | filters=filters, 105 | strides=[1, 1], 106 | padding="VALID")) 107 | 108 | X = tf.reshape(Xnew, [-1] + list(Xnew.shape)[-3:]) # stack as 4d images 109 | r2 += 
kern.convolve(tf.square(X), filters) # [N, height_out, width_out, M] 110 | 111 | filters *= feat.as_filters # [h, w, channels_in, M] 112 | r2 -= 2 * kern.convolve(X, filters) 113 | 114 | Kxz = kern.kernel.K_r2(r2) 115 | if not full_spatial: 116 | Kxz = tf.reshape(Kxz, list(Kxz.shape[:-3]) + [-1, len(feat)]) # [N, P, M] 117 | if kern.weights is None: 118 | Kxz = tf.reduce_mean(Kxz, axis=-2) 119 | else: 120 | Kxz = tf.tensordot(Kxz, kern.weights, [-2, -1]) 121 | 122 | # Undo stacking of Xnew as 4d images X 123 | return tf.reshape(Kxz, list(Xnew.shape[:-3]) + list(Kxz.shape[1:])) 124 | 125 | 126 | def _Kfu_depthwise_conv2d_fallback(feat: DepthwiseInducingImages, 127 | kern: DepthwiseConv2d, 128 | Xnew: tf.Tensor, 129 | full_spatial: bool = False): 130 | 131 | Zp = feat.as_patches # [M, channels_in, patch_len] 132 | Xp = kern.get_patches(Xnew, full_spatial=full_spatial) 133 | r2 = tf.reduce_sum(tf.math.squared_difference( # compute square distances 134 | tf.expand_dims(kern.kernel.scale(Xp), -Zp.shape.ndims), 135 | kern.kernel.scale(Zp)), axis=-1) 136 | 137 | Kxz = kern.kernel.K_r2(r2) 138 | if full_spatial: # convert to 4D image format as [N, H, W, channels_in * M] 139 | return tf.reshape(move_axis(Kxz, -1, -2), list(Kxz.shape[:-2]) + [-1]) 140 | 141 | if kern.weights is None: # reduce over channels and patches 142 | return tf.reduce_mean(Kxz, axis=[-3, -1]) 143 | 144 | return tf.tensordot(kern.weights, Kxz, axes=[(0, 1), (-3, -1)]) 145 | 146 | 147 | @Kfu.register(DepthwiseInducingImages, DepthwiseConv2d, object) 148 | def _Kfu_depthwise_conv2d(feat: DepthwiseInducingImages, 149 | kern: DepthwiseConv2d, 150 | Xnew: tf.Tensor, 151 | full_spatial: bool = False): 152 | 153 | if not isinstance(kern.kernel, kernels.Stationary): 154 | return _Kfu_depthwise_conv2d_fallback(feat, kern, Xnew, full_spatial) 155 | 156 | # Compute (squared) Mahalanobis distances between patches 157 | patch_shape = list(kern.patch_shape) 158 | channels_in = Xnew.shape[-3 if kern.data_format == "NCHW" else -1] 159 | channels_out = len(feat) * channels_in 160 | precis = tf.square(tf.math.reciprocal(kern.kernel.lengthscales)) 161 | 162 | # Construct lengthscale filters [h, w, channels_in, 1] 163 | if kern.kernel.ard: # notice the transpose! 
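# As in _Kfu_conv2d above, the stationary fast path expands the scaled
# squared distance between image patches x_p and inducing patches z as
#   r2 = sum(precis * z^2) + sum(precis * x_p^2) - 2 * sum(precis * x_p * z),
# so each term is computed with a single (depthwise) convolution rather
# than materializing the patches explicitly.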
164 | assert tuple(precis.shape) == (channels_in, tf.reduce_prod(patch_shape)) 165 | filters = tf.reshape(tf.transpose(precis), patch_shape + [channels_in, 1]) 166 | else: 167 | filters = tf.fill(patch_shape + [channels_in, 1], precis) 168 | 169 | ZZ = tf.nn.depthwise_conv2d(input=tf.square(feat.as_images), 170 | filter=filters, 171 | strides=[1, 1, 1, 1], 172 | padding="VALID") # [M, 1, 1, channels_in] 173 | 174 | r2 = tf.reshape(move_axis(ZZ, 0, -1), [1, 1, 1, channels_out]) 175 | 176 | X = tf.reshape(Xnew, [-1] + list(Xnew.shape)[-3:]) # stack as 4d images 177 | r2 += tf.repeat(kern.convolve(tf.square(X), filters), len(feat), axis=-1) 178 | 179 | filters *= feat.as_filters # [h, w, channels_in, M] 180 | r2 -= 2 * kern.convolve(X, filters) # [N, height_out, width_out, chan_out] 181 | 182 | Kxz = kern.kernel.K_r2(r2) 183 | if full_spatial: 184 | Kxz = tf.reduce_mean( 185 | tf.reshape(Kxz, list(Kxz.shape[:-1]) + [channels_in, -1]), 186 | axis=-2) # average over input channels 187 | else: 188 | Kxz = tf.reshape(Kxz, list(Kxz.shape[:-3]) + [-1, len(feat)]) # [N, P, M] 189 | if kern.weights is None: 190 | Kxz = tf.reduce_mean(Kxz, axis=-2) 191 | else: 192 | div = tf.cast(1/channels_in, Kxz.dtype) 193 | Kxz = div * tf.tensordot(Kxz, tf.reshape(kern.weights, [-1]), [-2, -1]) 194 | 195 | # Undo stacking of Xnew as 4d images X 196 | return tf.reshape(Kxz, list(Xnew.shape[:-3]) + list(Kxz.shape[1:])) 197 | -------------------------------------------------------------------------------- /gpflow_sampling/covariances/Kufs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # ============================================== 5 | # Preamble 6 | # ============================================== 7 | # ---- Imports 8 | import tensorflow as tf 9 | 10 | from gpflow.base import TensorLike 11 | from gpflow.covariances.dispatch import Kuf 12 | from gpflow_sampling.utils import swap_axes 13 | from gpflow_sampling.kernels import Conv2d 14 | from gpflow_sampling.covariances.Kfus import Kfu as Kfu_dispatch 15 | from gpflow_sampling.inducing_variables import InducingImages 16 | 17 | 18 | # ============================================== 19 | # Kufs 20 | # ============================================== 21 | @Kuf.register(InducingImages, Conv2d, TensorLike) 22 | def _Kuf_conv2d_fallback(Z, kernel, X, full_spatial: bool = False, **kwargs): 23 | Kfu = Kfu_dispatch(Z, kernel, X, full_spatial=full_spatial, **kwargs) 24 | 25 | ndims_x = X.shape.ndims - 3 # assume x lives in 3d image space 26 | ndims_z = Z.as_images.shape.ndims - 3 27 | 28 | if full_spatial: 29 | assert Kfu.shape.ndims == ndims_x + ndims_z + 2 30 | return swap_axes(Kfu, -4, -1) # TODO: this is a hack 31 | 32 | # Swap the batch axes of x and z 33 | assert Kfu.shape.ndims == ndims_x + ndims_z 34 | axes = list(range(ndims_x + ndims_z)) 35 | perm = axes[ndims_x: ndims_x + ndims_z] + axes[:ndims_x] 36 | return tf.transpose(Kfu, perm) 37 | -------------------------------------------------------------------------------- /gpflow_sampling/covariances/Kuus.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # ============================================== 5 | # Preamble 6 | # ============================================== 7 | # ---- Imports 8 | import tensorflow as tf 9 | 10 | from gpflow.utilities.ops import square_distance 11 | from gpflow.covariances.dispatch import Kuu 12 | from 
gpflow_sampling.utils import move_axis 13 | from gpflow_sampling.kernels import Conv2d, DepthwiseConv2d 14 | from gpflow_sampling.inducing_variables import (InducingImages, 15 | DepthwiseInducingImages) 16 | 17 | # ============================================== 18 | # Kuus 19 | # ============================================== 20 | @Kuu.register(InducingImages, Conv2d) 21 | def _Kuu_conv2d(feat: InducingImages, 22 | kern: Conv2d, 23 | jitter: float = 0.0): 24 | _Kuu = kern.kernel.K(feat.as_patches) 25 | return tf.linalg.set_diag(_Kuu, tf.linalg.diag_part(_Kuu) + jitter) 26 | 27 | 28 | @Kuu.register(DepthwiseInducingImages, DepthwiseConv2d) 29 | def _Kuu_depthwise_conv2d(feat: DepthwiseInducingImages, 30 | kern: DepthwiseConv2d, 31 | jitter: float = 0.0): 32 | 33 | # Prepare scaled inducing patches; shape(Zp) = [channels_in, M, patch_len] 34 | Zp = move_axis(kern.kernel.scale(feat.as_patches), -2, 0) 35 | r2 = square_distance(Zp, None) 36 | _Kuu = tf.reduce_mean(kern.kernel.K_r2(r2), axis=0) # [M, M] 37 | return tf.linalg.set_diag(_Kuu, tf.linalg.diag_part(_Kuu) + jitter) 38 | -------------------------------------------------------------------------------- /gpflow_sampling/covariances/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from gpflow_sampling.covariances.Kuus import Kuu 5 | from gpflow_sampling.covariances.Kufs import Kuf 6 | from gpflow_sampling.covariances.Kfus import Kfu -------------------------------------------------------------------------------- /gpflow_sampling/inducing_variables.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # ============================================== 5 | # Preamble 6 | # ============================================== 7 | # ---- Imports 8 | import tensorflow as tf 9 | from typing import Optional 10 | from gpflow import inducing_variables 11 | from gpflow.base import TensorData, Parameter 12 | from gpflow.config import default_float 13 | from gpflow_sampling.utils import move_axis 14 | 15 | 16 | # ---- Exports 17 | __all__ = ( 18 | 'InducingImages', 19 | 'SharedInducingImages', 20 | 'DepthwiseInducingImages', 21 | 'SharedDepthwiseInducingImages', 22 | ) 23 | 24 | 25 | # ============================================== 26 | # inducing_variables 27 | # ============================================== 28 | class InducingImages(inducing_variables.InducingVariables): 29 | def __init__(self, images: TensorData, name: Optional[str] = None): 30 | """ 31 | :param images: initial values of inducing locations in image form. 32 | 33 | The shape of the inducing variables varies by representation: 34 | - as Z: [M, height * width * channels_in] 35 | - as images: [M, height, width, channels_in] 36 | - as patches: [M, height * width * channels_in] 37 | - as filters: [height, width, channels_in, M] 38 | 39 | TODO: 40 | - Generalize to allow for inducing image with multiple patches? 41 | - Work on naming convention? The term 'image' is a bit too general. 42 | Patch works, however this term usually refers to a vectorized form 43 | and (for now) overlaps with GPflow's own inducing class. 
Alternatives 44 | include: filter, window, glimpse 45 | """ 46 | super().__init__(name=name) 47 | self._images = Parameter(images, dtype=default_float()) 48 | 49 | def __len__(self) -> int: 50 | return self._images.shape[0] 51 | 52 | @property 53 | def Z(self) -> tf.Tensor: 54 | return tf.reshape(self._images, [len(self), -1]) 55 | 56 | @property 57 | def as_patches(self) -> tf.Tensor: 58 | return tf.reshape(self.as_images, [len(self), -1]) 59 | 60 | @property 61 | def as_filters(self) -> tf.Tensor: 62 | return move_axis(self.as_images, 0, -1) 63 | 64 | @property 65 | def as_images(self) -> tf.Tensor: 66 | return tf.convert_to_tensor(self._images, dtype=self._images.dtype) 67 | 68 | 69 | class SharedInducingImages(InducingImages): 70 | def __init__(self, 71 | images: TensorData, 72 | channels_in: int, 73 | name: Optional[str] = None): 74 | """ 75 | :param images: initial values of inducing locations in image form. 76 | :param channels_in: number of input channels to share across 77 | 78 | Same as but with the same single-channel inducing 79 | images shared across all input channels. 80 | 81 | The shape of the inducing variables varies by representation: 82 | - as Z: [M, height * width] (new!) 83 | - as images: [M, height, width, channels_in] 84 | - as patches [M, channels_in, height * width] 85 | - as filters: [height, width, channels_in, M] 86 | """ 87 | assert images.shape.ndims == 4 and images.shape[-1] == 1 88 | self.channels_in = channels_in 89 | super().__init__(images, name=name) 90 | 91 | @property 92 | def as_images(self) -> tf.Tensor: 93 | return tf.tile(self._images, [1, 1, 1, self.channels_in]) 94 | 95 | 96 | class DepthwiseInducingImages(InducingImages): 97 | """ 98 | Same as but for depthwise convolutions. 99 | 100 | The shape of the inducing variables varies by representation: 101 | - as Z: [M, height * width * channels_in] 102 | - as images: [M, height, width, channels_in] 103 | - as patches [M, channels_in, height * width] (new!) 104 | - as filters: [height, width, channels_in, M] 105 | """ 106 | @property 107 | def as_patches(self) -> tf.Tensor: 108 | images = self.as_images 109 | patches = tf.reshape(images, [images.shape[0], -1, images.shape[-1]]) 110 | return tf.transpose(patches, [0, 2, 1]) # [M, channels_in, patch_len] 111 | 112 | 113 | class SharedDepthwiseInducingImages(SharedInducingImages, 114 | DepthwiseInducingImages): 115 | def __init__(self, 116 | images: TensorData, 117 | channels_in: int, 118 | name: Optional[str] = None): 119 | """ 120 | :param images: initial values of inducing locations in image form. 121 | :param channels_in: number of input channels to share across. 
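A hedged construction sketch (shapes follow the docstrings above; the
sizes are illustrative assumptions):

    # 32 single-channel 3x3 inducing images shared across 8 input channels
    Z0 = tf.random.normal([32, 3, 3, 1], dtype=default_float())
    Z = SharedDepthwiseInducingImages(Z0, channels_in=8)
    # Z.as_images: [32, 3, 3, 8]; Z.as_patches: [32, 8, 9]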
122 | """ 123 | SharedInducingImages.__init__(self, 124 | name=name, 125 | images=images, 126 | channels_in=channels_in) 127 | -------------------------------------------------------------------------------- /gpflow_sampling/kernels.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # ============================================== 5 | # Preamble 6 | # ============================================== 7 | # ---- Imports 8 | import tensorflow as tf 9 | 10 | from typing import List 11 | from warnings import warn 12 | from gpflow import kernels 13 | from gpflow.base import TensorType, Parameter 14 | from gpflow.config import default_float 15 | from gpflow_sampling.utils import (conv_ops, 16 | swap_axes, 17 | move_axis, 18 | batch_tensordot) 19 | from tensorflow.python.keras.utils.conv_utils import (conv_output_length, 20 | deconv_output_length) 21 | 22 | # ---- Exports 23 | __all__ = ( 24 | 'Conv2d', 25 | 'Conv2dTranspose', 26 | 'DepthwiseConv2d', 27 | ) 28 | 29 | 30 | # ============================================== 31 | # kernels 32 | # ============================================== 33 | class Conv2d(kernels.MultioutputKernel): 34 | def __init__(self, 35 | kernel: kernels.Kernel, 36 | image_shape: List, 37 | patch_shape: List, 38 | channels_in: int = 1, 39 | channels_out: int = 1, 40 | weights: TensorType = "default", 41 | strides: List = None, 42 | padding: str = "VALID", 43 | dilations: List = None, 44 | data_format: str = "NHWC"): 45 | 46 | strides = list((1, 1) if strides is None else strides) 47 | dilations = list((1, 1) if dilations is None else dilations) 48 | 49 | # Sanity checks 50 | assert len(strides) == 2 51 | assert len(dilations) == 2 52 | assert padding in ("VALID", "SAME") 53 | assert data_format in ("NHWC", "NCHW") 54 | 55 | if isinstance(weights, str) and weights == "default": # TODO: improve me 56 | spatial_out = self.get_spatial_out(spatial_in=image_shape, 57 | filter_shape=patch_shape, 58 | strides=strides, 59 | padding=padding, 60 | dilations=dilations) 61 | 62 | weights = tf.ones([tf.reduce_prod(spatial_out)], dtype=default_float()) 63 | 64 | super().__init__() 65 | self.kernel = kernel 66 | self.image_shape = image_shape 67 | self.patch_shape = patch_shape 68 | self.channels_in = channels_in 69 | self.channels_out = channels_out 70 | 71 | self.strides = strides 72 | self.padding = padding 73 | self.dilations = dilations 74 | self.data_format = data_format 75 | self._weights = None if (weights is None) else Parameter(weights) 76 | 77 | def __call__(self, 78 | X: TensorType, 79 | X2: TensorType=None, 80 | *, 81 | full_cov: bool = False, 82 | full_spatial: bool = False, 83 | presliced: bool = False): 84 | 85 | if not presliced: 86 | X, X2 = self.slice(X, X2) 87 | 88 | if not full_cov and X2 is not None: 89 | raise ValueError( 90 | "Ambiguous inputs: passing in `X2` is not compatible with `full_cov=False`." 91 | ) 92 | 93 | if not full_cov: 94 | return self.K_diag(X, full_spatial=full_spatial) 95 | return self.K(X, X2, full_spatial=full_spatial) 96 | 97 | def K(self, X: tf.Tensor, X2: tf.Tensor = None, full_spatial: bool = False): 98 | """ 99 | TODO: For stationary kernels, implement this using convolutions? 
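With patch-response kernel k, patches x[p], and patch weights w, this
computes

    K(x, x2) = sum_{p, q} w[p] * w[q] * k(x[p], x2[q]),

reducing to a plain average over patches when weights is None.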
100 | """ 101 | P = self.get_patches(X, full_spatial=full_spatial) 102 | P2 = P if X2 is None else self.get_patches(X2, full_spatial=full_spatial) 103 | K = self.kernel.K(P, P2) 104 | if full_spatial: 105 | return K # [N, H1, W1, N2, H2, W2] 106 | 107 | # At this point, shape(K) = [N, H1 * W1, N2, H2 * W2] 108 | if self.weights is None: 109 | return tf.reduce_mean(K, axis=[-3, -1]) 110 | return tf.tensordot(tf.linalg.matvec(K, self.weights), 111 | self.weights, 112 | axes=[-2, 0]) 113 | 114 | def K_diag(self, X: tf.Tensor, full_spatial: bool = False): 115 | P = self.get_patches(X, full_spatial=full_spatial) 116 | K = self.kernel.K_diag(P) 117 | if full_spatial: 118 | return K # [N, H1, W1] 119 | 120 | # At this point, shape(K) = [N, H1 * W1] 121 | if self.weights is None: 122 | return tf.reduce_mean(K, axis=[-2, -1]) 123 | return tf.linalg.matvec(tf.linalg.matvec(K, self.weights), self.weights) 124 | 125 | def convolve(self, 126 | input, 127 | filters, 128 | strides: List = None, 129 | padding: str = None, 130 | dilations: List = None, 131 | data_format: str = None): 132 | 133 | if strides is None: 134 | strides = self.strides 135 | 136 | if padding is None: 137 | padding = self.padding 138 | 139 | if dilations is None: 140 | dilations = self.dilations 141 | 142 | if data_format is None: 143 | data_format = self.data_format 144 | 145 | return tf.nn.conv2d(input=input, 146 | filters=filters, 147 | strides=strides, 148 | padding=padding, 149 | dilations=dilations, 150 | data_format=data_format) 151 | 152 | def get_patches(self, X: TensorType, full_spatial: bool = True): 153 | # Extract image patches 154 | X_nchw = conv_ops.reformat_data(X, self.data_format, "NHWC") 155 | patches = tf.image.extract_patches(images=X_nchw, 156 | sizes=[1] + self.patch_shape + [1], 157 | strides=[1] + self.strides + [1], 158 | rates=[1] + self.dilations + [1], 159 | padding=self.padding) 160 | 161 | if full_spatial: 162 | output_shape = list(X.shape[:-3]) + list(patches.shape[-3:]) 163 | else: 164 | output_shape = list(X.shape[:-3]) + [-1, patches.shape[-1]] 165 | 166 | return tf.reshape(patches, output_shape) 167 | 168 | def get_shape_out(self, 169 | shape_in: List, 170 | filter_shape: List, 171 | strides: List = None, 172 | dilations: List = None, 173 | data_format: str = None) -> List: 174 | 175 | if data_format is None: 176 | data_format = self.data_format 177 | 178 | if data_format == "NHWC": 179 | *batch, height, width, _ = list(shape_in) 180 | else: 181 | *batch, _, height, width = list(shape_in) 182 | 183 | spatial_out = self.get_spatial_out(spatial_in=[height, width], 184 | filter_shape=filter_shape[:2], 185 | strides=strides, 186 | dilations=dilations) 187 | 188 | nhwc_out = batch + spatial_out + [filter_shape[-1]] 189 | return conv_ops.reformat_shape(nhwc_out, "NHWC", data_format) 190 | 191 | def get_spatial_out(self, 192 | spatial_in: List = None, 193 | filter_shape: List = None, 194 | strides: List = None, 195 | padding: str = None, 196 | dilations: List = None) -> List: 197 | 198 | if spatial_in is None: 199 | spatial_in = self.image_shape 200 | 201 | if filter_shape is None: 202 | filter_shape = self.patch_shape 203 | else: 204 | assert len(filter_shape) == 2 205 | 206 | if strides is None: 207 | strides = self.strides 208 | 209 | if padding is None: 210 | padding = self.padding 211 | 212 | if dilations is None: 213 | dilations = self.dilations 214 | 215 | return [conv_output_length(input_length=spatial_in[i], 216 | filter_size=filter_shape[i], 217 | stride=strides[i], 218 | 
padding=padding.lower(), 219 | dilation=dilations[i]) for i in range(2)] 220 | 221 | @property 222 | def num_patches(self): 223 | return tf.reduce_prod(self.get_spatial_out()) 224 | 225 | @property 226 | def weights(self): 227 | if self._weights is None: 228 | return None 229 | return tf.cast(1/self.num_patches, self._weights.dtype) * self._weights 230 | 231 | @property 232 | def num_latent_gps(self): 233 | return self.channels_out 234 | 235 | @property 236 | def latent_kernels(self): 237 | return self.kernel, 238 | 239 | 240 | class Conv2dTranspose(Conv2d): 241 | def __init__(self, 242 | *args, 243 | strides: List = None, 244 | dilations: List = None, 245 | **kwargs): 246 | 247 | strides = list((1, 1) if strides is None else strides) 248 | dilations = list((1, 1) if dilations is None else dilations) 249 | if strides != [1, 1] and dilations != [1, 1]: 250 | raise NotImplementedError('Tensorflow does not implement transposed' 251 | 'convolutions with strides and dilations.') 252 | 253 | super().__init__(*args, strides=strides, dilations=dilations, **kwargs) 254 | 255 | def convolve(self, 256 | input, 257 | filters, 258 | strides: List = None, 259 | padding: str = None, 260 | dilations: List = None, 261 | data_format: str = None): 262 | 263 | if strides is None: 264 | strides = self.strides 265 | 266 | if padding is None: 267 | padding = self.padding 268 | 269 | if dilations is None: 270 | dilations = self.dilations 271 | 272 | if data_format is None: 273 | data_format = self.data_format 274 | 275 | shape_out = self.get_shape_out(shape_in=input.shape, 276 | filter_shape=filters.shape, 277 | strides=strides, 278 | dilations=dilations, 279 | data_format=data_format) 280 | 281 | _filters = swap_axes(filters, -2, -1) 282 | conv_kwargs = dict(filters=_filters, 283 | padding=padding, 284 | output_shape=shape_out) 285 | 286 | if dilations != [1, 1]: # TODO: improve me 287 | assert data_format == 'NHWC' 288 | assert list(strides) == [1, 1] 289 | assert len(dilations) == 2 and dilations[0] == dilations[1] 290 | return tf.nn.atrous_conv2d_transpose(input, 291 | rate=dilations[0], 292 | **conv_kwargs) 293 | 294 | return tf.nn.conv2d_transpose(input, 295 | strides=strides, 296 | dilations=dilations, 297 | data_format=data_format, 298 | **conv_kwargs) 299 | 300 | def get_patches(self, X, full_spatial: bool = False): 301 | """ 302 | Returns the patches used by a 2d transposed convolution. 
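The transposed convolution is emulated by ordinary patch extraction on
a dilated input: (stride - 1) zeros are inserted between pixels, the
result is padded, and each patch is reversed per channel. For example,
with strides=[2, 2] a row of pixels [a, b, c] first becomes
[a, 0, b, 0, c] before patches are taken.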
303 | """ 304 | spatial_in = X.shape[-3: -1] 305 | 306 | # Pad X with (stride - 1) zeros in between each pixel 307 | if any(stride != 1 for stride in self.strides): 308 | shape = list(X.shape[:-3]) 309 | terms = [tf.range(size) for size in shape] 310 | for i, stride in enumerate(self.strides): 311 | size = X.shape[i - 3] 312 | shape.append(stride * (size - 1) + 1) 313 | terms.append(tf.range(stride * size, delta=stride)) 314 | shape.append(X.shape[-1]) 315 | grid = tf.meshgrid(*terms, indexing='ij') 316 | X = tf.scatter_nd(tf.stack(grid, -1), X, shape) 317 | 318 | # Prepare padding info 319 | if self.padding == "VALID": 320 | h_pad = 2 * [self.dilations[0] * (self.patch_shape[0] - 1)] 321 | w_pad = 2 * [self.dilations[1] * (self.patch_shape[1] - 1)] 322 | elif self.padding == "SAME": 323 | height_out, width_out = self.get_spatial_out(spatial_in) 324 | extra = height_out - X.shape[-3] 325 | h_pad = list(map(lambda x: tf.cast(x, tf.int64), 326 | (tf.math.ceil(0.5 * extra), tf.math.floor(0.5 * extra)))) 327 | 328 | extra = width_out - X.shape[-2] 329 | w_pad = list(map(lambda x: tf.cast(x, tf.int64), 330 | (tf.math.ceil(0.5 * extra), tf.math.floor(0.5 * extra)))) 331 | 332 | # Extract (flipped) image patches 333 | X_pad = tf.pad(X, [2 * [0], h_pad, w_pad, 2 * [0]]) 334 | patches = tf.image.extract_patches(images=X_pad, 335 | sizes=[1] + self.patch_shape + [1], 336 | strides=[1, 1, 1, 1], 337 | rates=[1] + self.dilations + [1], 338 | padding=self.padding) 339 | 340 | if full_spatial: 341 | output_shape = list(X.shape[:-3]) + list(patches.shape[-3:]) 342 | else: 343 | output_shape = list(X.shape[:-3]) + [-1, patches.shape[-1]] 344 | 345 | return tf.reshape(tf.reverse( # reverse channel-wise patches and reshape 346 | tf.reshape(patches, list(patches.shape[:-1]) + [-1, X.shape[-1]]), 347 | axis=[-2]), output_shape) 348 | 349 | def get_spatial_out(self, 350 | spatial_in: List = None, 351 | filter_shape: List = None, 352 | strides: List = None, 353 | padding: str = None, 354 | dilations: List = None) -> List: 355 | 356 | if spatial_in is None: 357 | spatial_in = self.image_shape 358 | 359 | if filter_shape is None: 360 | filter_shape = self.patch_shape 361 | else: 362 | assert len(filter_shape) == 2 363 | 364 | if strides is None: 365 | strides = self.strides 366 | 367 | if padding is None: 368 | padding = self.padding 369 | 370 | if dilations is None: 371 | dilations = self.dilations 372 | 373 | return [deconv_output_length(input_length=spatial_in[i], 374 | filter_size=filter_shape[i], 375 | stride=strides[i], 376 | padding=padding.lower(), 377 | dilation=dilations[i]) for i in range(2)] 378 | 379 | 380 | class DepthwiseConv2d(Conv2d): 381 | def __init__(self, 382 | kernel: kernels.Kernel, 383 | image_shape: List, 384 | patch_shape: List, 385 | channels_in: int = 1, 386 | channels_out: int = 1, 387 | weights: TensorType = "default", 388 | strides: List = None, 389 | padding: str = "VALID", 390 | dilations: List = None, 391 | data_format: str = "NHWC", 392 | **kwargs): 393 | 394 | strides = list((1, 1) if strides is None else strides) 395 | dilations = list((1, 1) if dilations is None else dilations) 396 | if strides != [1, 1] and dilations != [1, 1]: 397 | warn(f"{self.__class__} does not pass unit tests when strides != [1, 1]" 398 | f" and dilations != [1, 1] simultaneously.") 399 | 400 | if isinstance(weights, str) and weights == "default": # TODO: improve me 401 | spatial_out = self.get_spatial_out(spatial_in=image_shape, 402 | filter_shape=patch_shape, 403 | strides=strides, 404 | 
padding=padding, 405 | dilations=dilations) 406 | 407 | weights = tf.ones([tf.reduce_prod(spatial_out), channels_in], 408 | dtype=default_float()) 409 | 410 | super().__init__(kernel=kernel, 411 | image_shape=image_shape, 412 | patch_shape=patch_shape, 413 | channels_in=channels_in, 414 | channels_out=channels_out, 415 | weights=weights, 416 | strides=strides, 417 | padding=padding, 418 | dilations=dilations, 419 | data_format=data_format) 420 | 421 | def K(self, X: tf.Tensor, X2: tf.Tensor = None, full_spatial: bool = False): 422 | P = self.get_patches(X, full_spatial=full_spatial) 423 | P2 = P if X2 is None else self.get_patches(X2, full_spatial=full_spatial) 424 | 425 | # TODO: Temporary hack, use of self.kernel should be deprecated 426 | K = move_axis( 427 | tf.linalg.diag_part( 428 | move_axis(self.kernel.K(P, P2), P.shape.ndims - 2, -2)), -1, 0) 429 | 430 | if full_spatial: 431 | return K # [channels_in, N, H1, W1, N2, H2, W2] 432 | 433 | # At this point, shape(K) = [N, num_patches, N2, num_patches] 434 | if self.weights is None: 435 | return tf.reduce_mean(K, axis=[0, -3, -1]) 436 | 437 | K = batch_tensordot(K, self.weights, axes=[-1, 0], batch_axes=[0, 1]) 438 | K = batch_tensordot(K, self.weights, axes=[-2, 0], batch_axes=[0, 1]) 439 | return tf.reduce_mean(K, axis=0) 440 | 441 | def K_diag(self, X: tf.Tensor, full_spatial: bool = False): 442 | raise NotImplementedError 443 | 444 | P = self.get_patches(X, full_spatial=full_spatial) 445 | K = tf.reduce_mean(self.kernel.K(P), axis=-2) # average over channels 446 | if full_spatial: 447 | return K # [num_channels, N, H1, W1, H1, W1] 448 | 449 | # At this point, K has shape # [num_channels, N, num_patches, num_patches] 450 | if self.weights is None: 451 | return tf.reduce_mean(K, axis=[-2, -1]) 452 | 453 | K = batch_tensordot(K, self.weights, axes=[-1, 0], batch_axes=[0, 1]) 454 | K = batch_tensordot(K, self.weights, axes=[-1, 0], batch_axes=[0, 1]) 455 | return tf.reduce_mean(K, axis=0) 456 | 457 | def convolve(self, 458 | input, 459 | filters, 460 | strides: List = None, 461 | padding: str = None, 462 | dilations: List = None, 463 | data_format: str = None): 464 | 465 | if strides is None: 466 | strides = self.strides 467 | 468 | if padding is None: 469 | padding = self.padding 470 | 471 | if dilations is None: 472 | dilations = self.dilations 473 | 474 | if data_format is None: 475 | data_format = self.data_format 476 | 477 | return tf.nn.depthwise_conv2d(input=input, 478 | filter=filters, 479 | strides=[1] + strides + [1], 480 | padding=padding, 481 | dilations=dilations, 482 | data_format=data_format) 483 | 484 | def get_patches(self, X: TensorType, full_spatial: bool = False): 485 | """ 486 | Returns the patches used by a 2d depthwise convolution. 
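Unlike Conv2d.get_patches, the channel axis is kept separate: with
full_spatial=False the returned patches have shape
[..., num_patches, channels_in, patch_len], so that each input channel
is paired with its own slice of the inducing patches.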
487 | """ 488 | patches = super().get_patches(X, full_spatial=full_spatial) 489 | channels_in = X.shape[-3 if self.data_format == "NCHW" else -1] 490 | depthwise_patches = tf.reshape(patches, 491 | list(patches.shape[:-1]) + [-1, channels_in]) 492 | return move_axis(depthwise_patches, -2, -1) 493 | 494 | def get_shape_out(self, 495 | shape_in: List, 496 | filter_shape: List, 497 | strides: List = None, 498 | dilations: List = None, 499 | data_format: str = None) -> List: 500 | 501 | if data_format is None: 502 | data_format = self.data_format 503 | 504 | if data_format == "NHWC": 505 | *batch, height, width, _ = list(shape_in) 506 | else: 507 | *batch, _, height, width = list(shape_in) 508 | 509 | spatial_out = self.get_spatial_out(spatial_in=[height, width], 510 | filter_shape=filter_shape[:2], 511 | strides=strides, 512 | dilations=dilations) 513 | 514 | nhwc_out = batch + spatial_out + [filter_shape[-2] * filter_shape[-1]] 515 | return conv_ops.reformat_shape(nhwc_out, "NHWC", data_format) 516 | -------------------------------------------------------------------------------- /gpflow_sampling/models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # ============================================== 5 | # Preamble 6 | # ============================================== 7 | # ---- Imports 8 | import tensorflow as tf 9 | 10 | from abc import abstractmethod 11 | from typing import Optional 12 | from contextlib import contextmanager 13 | from gpflow.base import TensorLike 14 | from gpflow.config import default_float, default_jitter 15 | from gpflow.models import GPModel, SVGP, GPR 16 | from gpflow_sampling import sampling, covariances 17 | from gpflow_sampling.sampling.core import AbstractSampler, CompositeSampler 18 | 19 | 20 | # ---- Exports 21 | __all__ = ('PathwiseGPModel', 'PathwiseGPR', 'PathwiseSVGP') 22 | 23 | 24 | # ============================================== 25 | # models 26 | # ============================================== 27 | class PathwiseGPModel(GPModel): 28 | def __init__(self, *args, paths: AbstractSampler = None, **kwargs): 29 | super().__init__(*args, **kwargs) 30 | self._paths = paths 31 | 32 | @abstractmethod 33 | def generate_paths(self, *args, **kwargs) -> AbstractSampler: 34 | raise NotImplementedError 35 | 36 | def predict_f_samples(self, 37 | Xnew: TensorLike, 38 | num_samples: Optional[int] = None, 39 | full_cov: bool = True, 40 | full_output_cov: bool = True, 41 | **kwargs) -> tf.Tensor: 42 | 43 | assert full_cov and full_output_cov, NotImplementedError 44 | if self.paths is None: 45 | raise RuntimeError("Paths were not initialized.") 46 | 47 | if num_samples is not None: 48 | assert num_samples == self.paths.sample_shape,\ 49 | ValueError("Requested number of samples does not match path count.") 50 | 51 | return self.paths(Xnew, **kwargs) 52 | 53 | @contextmanager 54 | def temporary_paths(self, *args, **kwargs): 55 | try: 56 | init_paths = self.paths 57 | temp_paths = self.generate_paths(*args, **kwargs) 58 | self.set_paths(temp_paths) 59 | yield temp_paths 60 | finally: 61 | self.set_paths(init_paths) 62 | 63 | def set_paths(self, paths) -> AbstractSampler: 64 | self._paths = paths 65 | return paths 66 | 67 | @contextmanager 68 | def set_temporary_paths(self, paths): 69 | try: 70 | init_paths = self.paths 71 | self.set_paths(paths) 72 | yield paths 73 | finally: 74 | self.set_paths(init_paths) 75 | 76 | @property 77 | def paths(self) -> AbstractSampler: 78 | return 
self._paths 79 | 80 | 81 | class PathwiseGPR(GPR, PathwiseGPModel): 82 | def __init__(self, *args, paths: AbstractSampler = None, **kwargs): 83 | GPR.__init__(self, *args, **kwargs) 84 | self._paths = paths 85 | 86 | def generate_paths(self, 87 | num_samples: int, 88 | num_bases: int = None, 89 | prior: AbstractSampler = None, 90 | sample_axis: int = None, 91 | **kwargs) -> CompositeSampler: 92 | 93 | if prior is None: 94 | prior = sampling.priors.random_fourier(self.kernel, 95 | num_bases=num_bases, 96 | sample_shape=[num_samples], 97 | sample_axis=sample_axis) 98 | elif num_bases is not None: 99 | assert prior.sample_shape == [num_samples] 100 | 101 | diag = tf.convert_to_tensor(self.likelihood.variance) 102 | return sampling.decoupled(self.kernel, 103 | prior, 104 | *self.data, 105 | mean_function=self.mean_function, 106 | diag=diag, 107 | sample_axis=sample_axis, 108 | **kwargs) 109 | 110 | 111 | class PathwiseSVGP(SVGP, PathwiseGPModel): 112 | def __init__(self, *args, paths: AbstractSampler = None, **kwargs): 113 | SVGP.__init__(self, *args, **kwargs) 114 | self._paths = paths 115 | 116 | def generate_paths(self, 117 | num_samples: int, 118 | num_bases: int = None, 119 | prior: AbstractSampler = None, 120 | sample_axis: int = None, 121 | **kwargs) -> CompositeSampler: 122 | 123 | if prior is None: 124 | prior = sampling.priors.random_fourier(self.kernel, 125 | num_bases=num_bases, 126 | sample_shape=[num_samples], 127 | sample_axis=sample_axis) 128 | elif num_bases is not None: 129 | assert prior.sample_shape == [num_samples] 130 | 131 | return sampling.decoupled(self.kernel, 132 | prior, 133 | self.inducing_variable, 134 | self._generate_u(num_samples), 135 | mean_function=self.mean_function, 136 | sample_axis=sample_axis, 137 | **kwargs) 138 | 139 | def _generate_u(self, num_samples: int, L: tf.Tensor = None): 140 | """ 141 | Returns samples $u ~ q(u)$. 
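Samples are generated by location-scale transform,

    u = q_mu + L_q @ eps,  eps ~ N(0, I),  L_q = tril(q_sqrt),

and, when whiten=True, additionally premultiplied by
L = cholesky(Kuu(Z)).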
142 | """ 143 | q_sqrt = tf.linalg.band_part(self.q_sqrt, -1, 0) 144 | shape = self.num_latent_gps, q_sqrt.shape[-1], num_samples 145 | rvs = tf.random.normal(shape, dtype=default_float()) # [L, M, S] 146 | uT = q_sqrt @ rvs + tf.transpose(self.q_mu)[..., None] 147 | if self.whiten: 148 | if L is None: 149 | Z = self.inducing_variable 150 | K = covariances.Kuu(Z, self.kernel, jitter=default_jitter()) 151 | L = tf.linalg.cholesky(K) 152 | uT = L @ uT 153 | return tf.transpose(uT) # [S, M, L] 154 | -------------------------------------------------------------------------------- /gpflow_sampling/sampling/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | from gpflow_sampling.sampling import core 6 | from gpflow_sampling.sampling import priors, updates 7 | from gpflow_sampling.sampling.decoupled_samplers import decoupled 8 | 9 | -------------------------------------------------------------------------------- /gpflow_sampling/sampling/core.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # ============================================== 5 | # Preamble 6 | # ============================================== 7 | # ---- Imports 8 | import tensorflow as tf 9 | 10 | from abc import abstractmethod 11 | from typing import List, Callable, Union 12 | from gpflow.inducing_variables import InducingVariables 13 | from gpflow_sampling.utils import (move_axis, 14 | normalize_axis, 15 | batch_tensordot, 16 | get_inducing_shape) 17 | 18 | # ---- Exports 19 | __all__ = ( 20 | 'AbstractSampler', 21 | 'DenseSampler', 22 | 'MultioutputDenseSampler', 23 | 'CompositeSampler', 24 | ) 25 | 26 | 27 | # ============================================== 28 | # core 29 | # ============================================== 30 | class AbstractSampler(tf.Module): 31 | @abstractmethod 32 | def __call__(self, *args, **kwargs): 33 | raise NotImplementedError 34 | 35 | @property 36 | def sample_shape(self): 37 | raise NotImplementedError 38 | 39 | 40 | class CompositeSampler(AbstractSampler): 41 | def __init__(self, 42 | join_rule: Callable, 43 | samplers: List[Callable], 44 | mean_function: Callable = None, 45 | name: str = None): 46 | """ 47 | Combine base samples via a specified join rule. 48 | """ 49 | super().__init__(name=name) 50 | self._join_rule = join_rule 51 | self._samplers = samplers 52 | self.mean_function = mean_function 53 | 54 | def __call__(self, x: tf.Tensor, **kwargs) -> tf.Tensor: 55 | samples = [sampler(x, **kwargs) for sampler in self.samplers] 56 | vals = self.join_rule(samples) 57 | return vals if self.mean_function is None else vals + self.mean_function(x) 58 | 59 | @property 60 | def join_rule(self) -> Callable: 61 | return self._join_rule 62 | 63 | @property 64 | def samplers(self): 65 | return self._samplers 66 | 67 | @property 68 | def sample_shape(self): 69 | for i, sampler in enumerate(self.samplers): 70 | if i == 0: 71 | sample_shape = sampler.sample_shape 72 | else: 73 | assert sample_shape == sampler.sample_shape 74 | return sample_shape 75 | 76 | 77 | class DenseSampler(AbstractSampler): 78 | def __init__(self, 79 | weights: Union[tf.Tensor, tf.Variable], 80 | basis: Callable = None, 81 | mean_function: Callable = None, 82 | sample_axis: int = None, 83 | name: str = None): 84 | """ 85 | Return samples as weighted sums of features. 
86 | """ 87 | assert weights.shape.ndims > 1 88 | super().__init__(name=name) 89 | self.weights = weights 90 | self.basis = basis 91 | self.mean_function = mean_function 92 | self.sample_axis = sample_axis 93 | 94 | def __call__(self, x: tf.Tensor, sample_axis: int = "default", **kwargs): 95 | """ 96 | :param sample_axis: Specify an axis of inputs x as corresponding 1-to-1 with 97 | sample-specific slices of weight tensor w when computing tensor dot 98 | products. 99 | """ 100 | if sample_axis == "default": 101 | sample_axis = self.sample_axis 102 | 103 | feat = x if self.basis is None else self.basis(x, **kwargs) 104 | if sample_axis is None: 105 | batch_axes = None 106 | else: 107 | assert len(self.sample_shape), "Received sample_axis but self.weights has" \ 108 | " no dedicated axis for samples; this" \ 109 | " usually implies that sample_shape=[]." 110 | 111 | ndims_x = len(get_inducing_shape(x) if 112 | isinstance(x, InducingVariables) else x.shape) 113 | 114 | batch_axes = [-3, normalize_axis(sample_axis, ndims_x) - ndims_x] 115 | 116 | # Batch-axes notwithstanding, shape(vals) = [N, S] (after move_axis) 117 | vals = move_axis(batch_tensordot(self.weights, 118 | feat, 119 | axes=[-1, -1], 120 | batch_axes=batch_axes), 121 | len(self.sample_shape), # axis=-2 houses scalar output 1 122 | -1) 123 | 124 | return vals if self.mean_function is None else vals + self.mean_function(x) 125 | 126 | @property 127 | def sample_shape(self): 128 | w_shape = list(self.weights.shape) 129 | if len(w_shape) == 2: 130 | return [] 131 | return w_shape[:-2] 132 | 133 | 134 | class MultioutputDenseSampler(DenseSampler): 135 | def __init__(self, 136 | weights: Union[tf.Tensor, tf.Variable], 137 | *args, 138 | multioutput_axis: int = None, 139 | **kwargs): 140 | """ 141 | :param multioutput_axis: Specify an axis of inputs x (or features thereof) 142 | as corresponding 1-to-1 with output-feature-specific slices of weight 143 | tensor w when computing tensor dot products. 144 | """ 145 | self.multioutput_axis = multioutput_axis 146 | super().__init__(weights, *args, **kwargs) 147 | 148 | def __call__(self, x: tf.Tensor, sample_axis: int = "default", **kwargs): 149 | """ 150 | :param sample_axis: Specify an axis of inputs x as corresponding 1-to-1 with 151 | sample-specific slices of weight tensor w when computing tensor dot 152 | products. 153 | 154 | TODO: Improve hard-coding of multioutput-/sample-axis of weights. 155 | """ 156 | if sample_axis == "default": 157 | sample_axis = self.sample_axis 158 | 159 | if self.multioutput_axis is None: 160 | batch_w = [] # batch axes for w 161 | batch_x = [] # batch axes for x 162 | else: 163 | batch_w = [-2] 164 | batch_x = [self.multioutput_axis] 165 | 166 | if sample_axis is not None: 167 | assert len(self.sample_shape), "Received sample_axis but self.weights has" \ 168 | " no dedicated axis for samples; this" \ 169 | " usually implies that sample_shape=[]." 170 | batch_w.append(-3) 171 | 172 | # TODO: If basis(x) grows the rank of x, it should only do so from the 173 | # left, such that the negative i-th axis (i > 1) remains the same. 
174 | ndims_x = len(get_inducing_shape(x) if 175 | isinstance(x, InducingVariables) else x.shape) 176 | batch_x.append(normalize_axis(sample_axis, ndims_x) - ndims_x) 177 | 178 | feat = x if self.basis is None else self.basis(x, **kwargs) 179 | vals = move_axis(batch_tensordot(self.weights, # output features go last 180 | feat, 181 | axes=[-1, -1], 182 | batch_axes=[batch_w, batch_x]), 183 | len(self.sample_shape), # axis=-2 houses multioutputs L 184 | -1) 185 | 186 | return vals if self.mean_function is None else vals + self.mean_function(x) 187 | -------------------------------------------------------------------------------- /gpflow_sampling/sampling/decoupled_samplers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # ============================================== 5 | # Preamble 6 | # ============================================== 7 | # ---- Imports 8 | import tensorflow as tf 9 | 10 | from typing import List, Callable 11 | from gpflow import inducing_variables 12 | from gpflow.base import TensorLike 13 | from gpflow.utilities import Dispatcher 14 | from gpflow.kernels import Kernel, MultioutputKernel, LinearCoregionalization 15 | from gpflow_sampling.sampling.updates import exact as exact_update 16 | from gpflow_sampling.sampling.core import AbstractSampler, CompositeSampler 17 | from gpflow_sampling.kernels import Conv2d 18 | from gpflow_sampling.inducing_variables import InducingImages 19 | 20 | # ---- Exports 21 | __all__ = ('decoupled',) 22 | decoupled = Dispatcher("decoupled") 23 | 24 | 25 | # ============================================== 26 | # decoupled_samplers 27 | # ============================================== 28 | @decoupled.register(Kernel, AbstractSampler, TensorLike, TensorLike) 29 | def _decoupled_fallback(kern: Kernel, 30 | prior: AbstractSampler, 31 | Z: TensorLike, 32 | u: TensorLike, 33 | *, 34 | mean_function: Callable = None, 35 | update_rule: Callable = exact_update, 36 | join_rule: Callable = sum, 37 | **kwargs): 38 | 39 | f = prior(Z, sample_axis=None) # [S, M, L] 40 | update = update_rule(kern, Z, u, f, **kwargs) 41 | return CompositeSampler(samplers=[prior, update], 42 | join_rule=join_rule, 43 | mean_function=mean_function) 44 | 45 | 46 | @decoupled.register(MultioutputKernel, AbstractSampler, TensorLike, TensorLike) 47 | def _decoupled_multioutput(kern: MultioutputKernel, 48 | prior: AbstractSampler, 49 | Z: TensorLike, 50 | u: TensorLike, 51 | *, 52 | mean_function: Callable = None, 53 | update_rule: Callable = exact_update, 54 | join_rule: Callable = sum, 55 | multioutput_axis_Z: int = "default", 56 | **kwargs): 57 | 58 | # Determine whether to evaluate Z pathwise (per output feature) 59 | # TODO: Ugly. This argument is actually being passed to the prior's basis. 60 | # Disallow non-inducing-variable Z for multioutput cases of decoupled?
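# Assumed convention worth noting: separate/independent multioutput inducing
# variables carry a leading output-specific axis, so the prior's basis is told
# (below) to pair that axis 1-to-1 with the latent GPs; shared inducing
# variables and plain tensors have no such axis, so no pairing is requested.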
61 | if multioutput_axis_Z == "default": 62 | if isinstance(Z, inducing_variables.MultioutputInducingVariables) and not\ 63 | isinstance(Z, inducing_variables.SharedIndependentInducingVariables): 64 | multioutput_axis_Z = 0 65 | else: 66 | multioutput_axis_Z = None 67 | 68 | f = prior(Z, sample_axis=None, multioutput_axis=multioutput_axis_Z) 69 | update = update_rule(kern, Z, u, f, **kwargs) 70 | return CompositeSampler(samplers=[prior, update], 71 | join_rule=join_rule, 72 | mean_function=mean_function) 73 | 74 | 75 | @decoupled.register(LinearCoregionalization, AbstractSampler, TensorLike, TensorLike) 76 | def _decoupled_lcm(kern: LinearCoregionalization, 77 | prior: AbstractSampler, 78 | Z: TensorLike, 79 | u: TensorLike, 80 | *, 81 | join_rule: Callable = None, 82 | **kwargs): 83 | if join_rule is None: 84 | def join_rule(terms: List[tf.Tensor]) -> tf.Tensor: 85 | return tf.tensordot(kern.W, sum(terms), axes=[-1, 0]) 86 | return _decoupled_multioutput(kern, prior, Z, u, join_rule=join_rule, **kwargs) 87 | 88 | 89 | @decoupled.register(Conv2d, AbstractSampler, InducingImages, TensorLike) 90 | def _decoupled_conv(kern: Conv2d, 91 | prior: AbstractSampler, 92 | Z: InducingImages, 93 | u: TensorLike, 94 | *, 95 | mean_function: Callable = None, 96 | update_rule: Callable = exact_update, 97 | join_rule: Callable = sum, 98 | **kwargs): 99 | 100 | f = tf.squeeze(prior(Z, sample_axis=None), axis=[-3, -2]) # [S, M, L] 101 | update = update_rule(kern, Z, u, f, **kwargs) 102 | return CompositeSampler(samplers=[prior, update], 103 | join_rule=join_rule, 104 | mean_function=mean_function) 105 | -------------------------------------------------------------------------------- /gpflow_sampling/sampling/priors/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from gpflow_sampling.sampling.priors.fourier_priors import random_fourier 5 | 6 | -------------------------------------------------------------------------------- /gpflow_sampling/sampling/priors/fourier_priors.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # ============================================== 5 | # Preamble 6 | # ============================================== 7 | # ---- Imports 8 | import tensorflow as tf 9 | 10 | from typing import Any, List, Callable 11 | from gpflow.config import default_float 12 | from gpflow.kernels import Kernel, MultioutputKernel 13 | from gpflow.utilities import Dispatcher 14 | from gpflow_sampling.bases import fourier as fourier_basis 15 | from gpflow_sampling.sampling.core import DenseSampler, MultioutputDenseSampler 16 | from gpflow_sampling.kernels import Conv2d, DepthwiseConv2d 17 | 18 | 19 | # ---- Exports 20 | __all__ = ('random_fourier',) 21 | random_fourier = Dispatcher("random_fourier") 22 | 23 | 24 | # ============================================== 25 | # fourier_priors 26 | # ============================================== 27 | @random_fourier.register(Kernel) 28 | def _random_fourier(kernel: Kernel, 29 | sample_shape: List, 30 | num_bases: int, 31 | basis: Callable = None, 32 | dtype: Any = None, 33 | name: str = None, 34 | **kwargs): 35 | 36 | if dtype is None: 37 | dtype = default_float() 38 | 39 | if basis is None: 40 | basis = fourier_basis(kernel, num_bases=num_bases) 41 | 42 | weights = tf.random.normal(list(sample_shape) + [1, num_bases], dtype=dtype) 43 | return 
DenseSampler(weights=weights, basis=basis, name=name, **kwargs) 44 | 45 | 46 | @random_fourier.register(MultioutputKernel) 47 | def _random_fourier_multioutput(kernel: MultioutputKernel, 48 | sample_shape: List, 49 | num_bases: int, 50 | basis: Callable = None, 51 | dtype: Any = None, 52 | name: str = None, 53 | multioutput_axis: int = 0, 54 | **kwargs): 55 | if dtype is None: 56 | dtype = default_float() 57 | 58 | if basis is None: 59 | basis = fourier_basis(kernel, num_bases=num_bases) 60 | 61 | shape = list(sample_shape) + [kernel.num_latent_gps, num_bases] 62 | weights = tf.random.normal(shape, dtype=dtype) 63 | return MultioutputDenseSampler(name=name, 64 | basis=basis, 65 | weights=weights, 66 | multioutput_axis=multioutput_axis, 67 | **kwargs) 68 | 69 | 70 | @random_fourier.register(Conv2d) 71 | def _random_fourier_conv(kernel: Conv2d, 72 | sample_shape: List, 73 | num_bases: int, 74 | basis: Callable = None, 75 | dtype: Any = None, 76 | name: str = None, 77 | **kwargs): 78 | 79 | if dtype is None: 80 | dtype = default_float() 81 | 82 | if basis is None: 83 | basis = fourier_basis(kernel, num_bases=num_bases) 84 | 85 | shape = list(sample_shape) + [kernel.num_latent_gps, num_bases] 86 | weights = tf.random.normal(shape, dtype=dtype) 87 | return MultioutputDenseSampler(weights=weights, 88 | basis=basis, 89 | name=name, 90 | **kwargs) 91 | 92 | 93 | @random_fourier.register(DepthwiseConv2d) 94 | def _random_fourier_depthwise_conv(kernel: DepthwiseConv2d, 95 | sample_shape: List, 96 | num_bases: int, 97 | basis: Callable = None, 98 | dtype: Any = None, 99 | name: str = None, 100 | **kwargs): 101 | 102 | if dtype is None: 103 | dtype = default_float() 104 | 105 | if basis is None: 106 | basis = fourier_basis(kernel, num_bases=num_bases) 107 | 108 | channels_out = num_bases * kernel.channels_in 109 | shape = list(sample_shape) + [kernel.num_latent_gps, channels_out] 110 | weights = tf.random.normal(shape, dtype=dtype) 111 | return MultioutputDenseSampler(weights=weights, 112 | basis=basis, 113 | name=name, 114 | **kwargs) 115 | -------------------------------------------------------------------------------- /gpflow_sampling/sampling/updates/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from gpflow_sampling.sampling.updates.linear_updates import linear 5 | from gpflow_sampling.sampling.updates.exact_updates import exact 6 | from gpflow_sampling.sampling.updates.cg_updates import cg 7 | -------------------------------------------------------------------------------- /gpflow_sampling/sampling/updates/cg_updates.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # ============================================== 5 | # Preamble 6 | # ============================================== 7 | # ---- Imports 8 | import tensorflow as tf 9 | 10 | from warnings import warn 11 | from gpflow.base import TensorLike 12 | from gpflow.config import default_jitter 13 | from gpflow.utilities import Dispatcher 14 | from gpflow import kernels, inducing_variables 15 | from gpflow_sampling import covariances, kernels as kernels_ext 16 | from gpflow_sampling.bases import kernel as kernel_basis 17 | from gpflow_sampling.bases.core import AbstractBasis 18 | from gpflow_sampling.sampling.core import DenseSampler, MultioutputDenseSampler 19 | from gpflow_sampling.inducing_variables import InducingImages 20 | from 
gpflow_sampling.utils import get_default_preconditioner 21 | 22 | 23 | # ============================================== 24 | # cg_updates 25 | # ============================================== 26 | cg = Dispatcher("cg_updates") 27 | 28 | 29 | @cg.register(kernels.Kernel, TensorLike, TensorLike, TensorLike) 30 | def _cg_fallback(kern: kernels.Kernel, 31 | Z: TensorLike, 32 | u: TensorLike, 33 | f: TensorLike, 34 | *, 35 | diag: TensorLike = None, 36 | basis: AbstractBasis = None, 37 | preconditioner: tf.linalg.LinearOperator = "default", 38 | tol: float = 1e-3, 39 | max_iter: int = 100, 40 | **kwargs): 41 | """ 42 | Return pathwise updates of a prior process $f$ subject to the 43 | condition $p(f | u) = N(f | u, diag)$ on $f = f(Z)$. 44 | """ 45 | u_shape = tuple(u.shape) 46 | f_shape = tuple(f.shape) 47 | assert u_shape[-1] == 1, "Received multiple output features" 48 | assert u_shape == f_shape[-len(u_shape):], "Incompatible shapes detected" 49 | if basis is None: # finite-dimensional basis used to express the update 50 | basis = kernel_basis(kern, centers=Z) 51 | 52 | if diag is None: 53 | diag = default_jitter() 54 | 55 | # Prepare linear system for CG solver 56 | if isinstance(Z, inducing_variables.InducingVariables): 57 | Kff = covariances.Kuu(Z, kern, jitter=0.0) 58 | else: 59 | Kff = kern(Z, full_cov=True) 60 | Kuu = tf.linalg.set_diag(Kff, tf.linalg.diag_part(Kff) + diag) 61 | operator = tf.linalg.LinearOperatorFullMatrix(Kuu, 62 | is_non_singular=True, 63 | is_self_adjoint=True, 64 | is_positive_definite=True, 65 | is_square=True) 66 | 67 | if preconditioner == "default": 68 | preconditioner = get_default_preconditioner(Kff, diag=diag) 69 | 70 | # Compute error term and CG initializer 71 | err = tf.linalg.adjoint(u - f) # [S, 1, M] 72 | err -= (diag ** 0.5) * tf.random.normal(err.shape, dtype=err.dtype) 73 | if preconditioner is None: 74 | initializer = None 75 | else: 76 | initializer = preconditioner.matvec(err) 77 | 78 | # Approximately solve for $Cov(u, u)^{-1}(u - f(Z))$ using linear CG 79 | res = tf.linalg.experimental.conjugate_gradient(operator=operator, 80 | rhs=err, 81 | preconditioner=preconditioner, 82 | x=initializer, 83 | tol=tol, 84 | max_iter=max_iter) 85 | 86 | weights = res.x 87 | if tf.math.count_nonzero(tf.math.is_nan(weights)): 88 | warn("One or more update weights returned by CG are NaN") 89 | 90 | return DenseSampler(basis=basis, weights=weights, **kwargs) 91 | 92 | 93 | @cg.register((kernels.SharedIndependent, 94 | kernels.SeparateIndependent, 95 | kernels.LinearCoregionalization), 96 | TensorLike, 97 | TensorLike, 98 | TensorLike) 99 | def _cg_independent(kern: kernels.MultioutputKernel, 100 | Z: TensorLike, 101 | u: TensorLike, 102 | f: TensorLike, 103 | *, 104 | diag: TensorLike = None, 105 | basis: AbstractBasis = None, 106 | preconditioner: tf.linalg.LinearOperator = "default", 107 | tol: float = 1e-3, 108 | max_iter: int = 100, 109 | multioutput_axis: int = 0, 110 | **kwargs): 111 | """ 112 | Return (independent) pathwise updates for each of the latent prior processes 113 | $f$ subject to the condition $p(f | u) = N(f | u, diag)$ on $f = f(Z)$. 114 | """ 115 | u_shape = tuple(u.shape) 116 | f_shape = tuple(f.shape) 117 | assert u_shape[-1] == kern.num_latent_gps, "Num. outputs != num.
latent GPs" 118 | assert u_shape == f_shape[-len(u_shape):], "Incompatible shapes detected" 119 | if basis is None: # finite-dimensional basis used to express the update 120 | basis = kernel_basis(kern, centers=Z) 121 | 122 | if diag is None: 123 | diag = default_jitter() 124 | 125 | # Prepare linear system for CG solver 126 | if isinstance(Z, inducing_variables.InducingVariables): 127 | Kff = covariances.Kuu(Z, kern, jitter=0.0) 128 | else: 129 | Kff = kern(Z, full_cov=True, full_output_cov=False) 130 | Kuu = tf.linalg.set_diag(Kff, tf.linalg.diag_part(Kff) + diag) 131 | operator = tf.linalg.LinearOperatorFullMatrix(Kuu, 132 | is_non_singular=True, 133 | is_self_adjoint=True, 134 | is_positive_definite=True, 135 | is_square=True) 136 | 137 | if preconditioner == "default": 138 | preconditioner = get_default_preconditioner(Kff, diag=diag) 139 | 140 | err = tf.linalg.adjoint(u - f) # [S, L, M] 141 | err -= (diag ** 0.5) * tf.random.normal(err.shape, dtype=err.dtype) 142 | if preconditioner is None: 143 | initializer = None 144 | else: 145 | initializer = preconditioner.matvec(err) 146 | 147 | # Approximately solve for $Cov(u, u)^{-1}(u - f(Z))$ using linear CG 148 | res = tf.linalg.experimental.conjugate_gradient(operator=operator, 149 | rhs=err, 150 | preconditioner=preconditioner, 151 | x=initializer, 152 | tol=tol, 153 | max_iter=max_iter) 154 | 155 | weights = res.x 156 | if tf.math.count_nonzero(tf.math.is_nan(weights)): 157 | warn("One or more update weights returned by CG are NaN") 158 | 159 | return MultioutputDenseSampler(basis=basis, 160 | weights=weights, 161 | multioutput_axis=multioutput_axis, 162 | **kwargs) 163 | 164 | 165 | @cg.register(kernels.SharedIndependent, 166 | inducing_variables.SharedIndependentInducingVariables, 167 | TensorLike, 168 | TensorLike) 169 | def _cg_shared(kern, Z, u, f, *, multioutput_axis=None, **kwargs): 170 | """ 171 | Edge-case where the multioutput axis gets suppressed. 
172 | """ 173 | return _cg_independent(kern, 174 | Z, 175 | u, 176 | f, 177 | multioutput_axis=multioutput_axis, 178 | **kwargs) 179 | 180 | 181 | @cg.register(kernels_ext.Conv2d, InducingImages, TensorLike, TensorLike) 182 | def _cg_conv2d(kern, Z, u, f, *, basis=None, multioutput_axis=None, **kwargs): 183 | if basis is None: # finite-dimensional basis used to express the update 184 | basis = kernel_basis(kern, centers=Z, full_spatial=True) 185 | return _cg_independent(kern, 186 | Z, 187 | u, 188 | f, 189 | basis=basis, 190 | multioutput_axis=multioutput_axis, 191 | **kwargs) 192 | -------------------------------------------------------------------------------- /gpflow_sampling/sampling/updates/exact_updates.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # ============================================== 5 | # Preamble 6 | # ============================================== 7 | # ---- Imports 8 | import tensorflow as tf 9 | 10 | from gpflow import kernels, inducing_variables 11 | from gpflow.base import TensorLike 12 | from gpflow.config import default_jitter 13 | from gpflow.utilities import Dispatcher 14 | from gpflow_sampling import covariances, kernels as kernels_ext 15 | from gpflow_sampling.utils import swap_axes, move_axis 16 | from gpflow_sampling.bases import kernel as kernel_basis 17 | from gpflow_sampling.bases.core import AbstractBasis 18 | from gpflow_sampling.sampling.core import DenseSampler, MultioutputDenseSampler 19 | from gpflow_sampling.inducing_variables import InducingImages 20 | 21 | 22 | # ============================================== 23 | # exact_updates 24 | # ============================================== 25 | exact = Dispatcher("exact_updates") 26 | 27 | 28 | @exact.register(kernels.Kernel, TensorLike, TensorLike, TensorLike) 29 | def _exact_fallback(kern: kernels.Kernel, 30 | Z: TensorLike, 31 | u: TensorLike, 32 | f: TensorLike, 33 | *, 34 | L : TensorLike = None, 35 | diag: TensorLike = None, 36 | basis: AbstractBasis = None, 37 | **kwargs): 38 | """ 39 | Return pathwise updates of a prior processes $f$ subject to the 40 | condition $p(f | u) = N(f | u, diag)$ on $f = f(Z)$. 
41 | """ 42 | u_shape = tuple(u.shape) 43 | f_shape = tuple(f.shape) 44 | assert u_shape[-1] == 1, "Recieved multiple output features" 45 | assert u_shape == f_shape[-len(u_shape):], "Incompatible shapes detected" 46 | if basis is None: # finite-dimensional basis used to express the update 47 | basis = kernel_basis(kern, centers=Z) 48 | 49 | # Prepare diagonal term 50 | if diag is None: 51 | diag = default_jitter() 52 | if isinstance(diag, float): 53 | diag = tf.convert_to_tensor(diag, dtype=f.dtype) 54 | diag = tf.expand_dims(diag, axis=-1) # [M, 1] or [1, 1] or [1] 55 | 56 | # Compute error term and matrix square root $Cov(u, u)^{1/2}$ 57 | err = u - f # [S, M, 1] 58 | err -= tf.sqrt(diag) * tf.random.normal(err.shape, dtype=err.dtype) 59 | if L is None: 60 | if isinstance(Z, inducing_variables.InducingVariables): 61 | K = covariances.Kuu(Z, kern, jitter=0.0) 62 | else: 63 | K = kern(Z, full_cov=True) 64 | K = tf.linalg.set_diag(K, tf.linalg.diag_part(K) + diag[..., 0]) 65 | L = tf.linalg.cholesky(K) 66 | 67 | # Solve for $Cov(u, u)^{-1}(u - f(Z))$ 68 | weights = tf.linalg.adjoint(tf.linalg.cholesky_solve(L, err)) 69 | return DenseSampler(basis=basis, weights=weights, **kwargs) 70 | 71 | 72 | @exact.register((kernels.SharedIndependent, 73 | kernels.SeparateIndependent, 74 | kernels.LinearCoregionalization), 75 | TensorLike, 76 | TensorLike, 77 | TensorLike) 78 | def _exact_independent(kern: kernels.MultioutputKernel, 79 | Z: TensorLike, 80 | u: TensorLike, 81 | f: TensorLike, 82 | *, 83 | L: TensorLike = None, 84 | diag: TensorLike = None, 85 | basis: AbstractBasis = None, 86 | multioutput_axis: int = 0, 87 | **kwargs): 88 | """ 89 | Return (independent) pathwise updates for each of the latent prior processes 90 | $f$ subject to the condition $p(f | u) = N(f | u, diag)$ on $f = f(Z)$. 91 | """ 92 | u_shape = tuple(u.shape) 93 | f_shape = tuple(f.shape) 94 | assert u_shape[-1] == kern.num_latent_gps, "Num. outputs != num. latent GPs" 95 | assert u_shape == f_shape[-len(u_shape):], "Incompatible shapes detected" 96 | if basis is None: # finite-dimensional basis used to express the update 97 | basis = kernel_basis(kern, centers=Z) 98 | 99 | # Prepare diagonal term 100 | if diag is None: # used by 101 | diag = default_jitter() 102 | if isinstance(diag, float): 103 | diag = tf.convert_to_tensor(diag, dtype=f.dtype) 104 | diag = tf.expand_dims(diag, axis=-1) # ([L] or []) + ([M] or []) + [1] 105 | 106 | # Compute error term and matrix square root $Cov(u, u)^{1/2}$ 107 | err = swap_axes(u - f, -3, -1) # [L, M, S] 108 | err -= tf.sqrt(diag) * tf.random.normal(err.shape, dtype=err.dtype) 109 | if L is None: 110 | if isinstance(Z, inducing_variables.InducingVariables): 111 | K = covariances.Kuu(Z, kern, jitter=0.0) 112 | else: 113 | K = kern(Z, full_cov=True, full_output_cov=False) 114 | K = tf.linalg.set_diag(K, tf.linalg.diag_part(K) + diag[..., 0]) 115 | L = tf.linalg.cholesky(K) 116 | 117 | # Solve for $Cov(u, u)^{-1}(u - f(Z))$ 118 | weights = move_axis(tf.linalg.cholesky_solve(L, err), -1, -3) # [S, L, M] 119 | return MultioutputDenseSampler(basis=basis, 120 | weights=weights, 121 | multioutput_axis=multioutput_axis, 122 | **kwargs) 123 | 124 | 125 | @exact.register(kernels.SharedIndependent, 126 | inducing_variables.SharedIndependentInducingVariables, 127 | TensorLike, 128 | TensorLike) 129 | def _exact_shared(kern, Z, u, f, *, multioutput_axis=None, **kwargs): 130 | """ 131 | Edge-case where the multioutput axis gets suppressed. 
132 | """ 133 | return _exact_independent(kern, 134 | Z, 135 | u, 136 | f, 137 | multioutput_axis=multioutput_axis, 138 | **kwargs) 139 | 140 | 141 | @exact.register(kernels_ext.Conv2d, InducingImages, TensorLike, TensorLike) 142 | def _exact_conv2d(kern, Z, u, f, *, basis=None, multioutput_axis=None, **kwargs): 143 | if basis is None: # finite-dimensional basis used to express the update 144 | basis = kernel_basis(kern, centers=Z, full_spatial=True) 145 | return _exact_independent(kern, 146 | Z, 147 | u, 148 | f, 149 | basis=basis, 150 | multioutput_axis=multioutput_axis, 151 | **kwargs) 152 | -------------------------------------------------------------------------------- /gpflow_sampling/sampling/updates/linear_updates.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # ============================================== 5 | # Preamble 6 | # ============================================== 7 | """ 8 | Pathwise updates for Gaussian processes with linear 9 | kernels in some explicit, finite-dimensional basis. 10 | """ 11 | # ---- Imports 12 | import tensorflow as tf 13 | 14 | from gpflow import inducing_variables 15 | from gpflow.base import TensorLike 16 | from gpflow.config import default_jitter 17 | from gpflow.utilities import Dispatcher 18 | from gpflow_sampling.utils import swap_axes, move_axis, inducing_to_tensor 19 | from gpflow_sampling.bases.core import AbstractBasis 20 | from gpflow_sampling.sampling.core import DenseSampler, MultioutputDenseSampler 21 | 22 | 23 | # ============================================== 24 | # linear_update 25 | # ============================================== 26 | linear = Dispatcher("linear_updates") 27 | 28 | 29 | @linear.register(TensorLike, TensorLike, TensorLike) 30 | def _linear_fallback(Z: TensorLike, 31 | u: TensorLike, 32 | f: TensorLike, 33 | *, 34 | L : TensorLike = None, 35 | diag: TensorLike = None, 36 | basis: AbstractBasis = None, 37 | **kwargs): 38 | 39 | u_shape = tuple(u.shape) 40 | f_shape = tuple(f.shape) 41 | assert u_shape[-1] == 1, "Recieved multiple output features" 42 | assert u_shape == f_shape[-len(u_shape):], "Incompatible shapes detected" 43 | 44 | # Prepare diagonal term 45 | if diag is None: # used by 46 | diag = default_jitter() 47 | if isinstance(diag, float): 48 | diag = tf.convert_to_tensor(diag, dtype=f.dtype) 49 | diag = tf.expand_dims(diag, axis=-1) # [M, 1] or [1, 1] or [1] 50 | 51 | # Extract "features" of Z 52 | if basis is None: 53 | if isinstance(Z, inducing_variables.InducingVariables): 54 | feat = inducing_to_tensor(Z) # [M, D] 55 | else: 56 | feat = Z 57 | else: 58 | feat = basis(Z) # [M, D] (maybe a different "D" than above) 59 | 60 | # Compute error term and matrix square root $Cov(u, u)^{1/2}$ 61 | err = swap_axes(u - f, -3, -1) # [1, M, S] 62 | err -= tf.sqrt(diag) * tf.random.normal(err.shape, dtype=err.dtype) 63 | M, D = feat.shape[-2:] 64 | if L is None: 65 | if D < M: 66 | feat_iDiag = feat * tf.math.reciprocal(diag) 67 | S = tf.matmul(feat_iDiag, feat, transpose_a=True) # [D, D] 68 | L = tf.linalg.cholesky(S + tf.eye(S.shape[-1], dtype=S.dtype)) 69 | else: 70 | K = tf.matmul(feat, feat, transpose_b=True) # [M, M] 71 | K = tf.linalg.set_diag(K, tf.linalg.diag_part(K) + diag[..., 0]) 72 | L = tf.linalg.cholesky(K) 73 | else: 74 | assert L.shape[-1] == min(M, D) # TODO: improve me 75 | 76 | # Solve for $Cov(u, u)^{-1}(u - f(Z))$ 77 | if D < M: 78 | feat_iDiag = feat * tf.math.reciprocal(diag) 79 | weights = 
tf.linalg.adjoint(tf.linalg.cholesky_solve(L, 80 | tf.matmul(feat_iDiag, err, transpose_a=True))) 81 | else: 82 | iK_err = tf.linalg.cholesky_solve(L, err) # [S, M, 1] 83 | weights = tf.matmul(iK_err, feat, transpose_a=True) # [S, 1, D] 84 | 85 | return DenseSampler(basis=basis, weights=move_axis(weights, -2, -3), **kwargs) 86 | 87 | 88 | @linear.register(inducing_variables.MultioutputInducingVariables, 89 | TensorLike, 90 | TensorLike) 91 | def _linear_multioutput(Z: inducing_variables.MultioutputInducingVariables, 92 | u: TensorLike, 93 | f: TensorLike, 94 | *, 95 | L: TensorLike = None, 96 | diag: TensorLike = None, 97 | basis: AbstractBasis = None, 98 | multioutput_axis: int = "default", 99 | **kwargs): 100 | assert tuple(u.shape) == tuple(f.shape) 101 | if multioutput_axis == "default": 102 | multioutput_axis = None if (basis is None) else 0 103 | 104 | # Prepare diagonal term 105 | if diag is None: # default to a small jitter term 106 | diag = default_jitter() 107 | if isinstance(diag, float): 108 | diag = tf.convert_to_tensor(diag, dtype=f.dtype) 109 | diag = tf.expand_dims(diag, axis=-1) # ([L] or []) + ([M] or []) + [1] 110 | 111 | # Extract "features" of Z 112 | if basis is None: 113 | if isinstance(Z, inducing_variables.InducingVariables): 114 | feat = inducing_to_tensor(Z) # [L, M, D] or [M, D] 115 | else: 116 | feat = Z 117 | elif isinstance(Z, inducing_variables.SharedIndependentInducingVariables): 118 | feat = basis(Z) 119 | else: 120 | feat = basis(Z, multioutput_axis=0) # first axis of Z is output-specific 121 | 122 | # Compute error term and matrix square root $Cov(u, u)^{1/2}$ 123 | err = swap_axes(u - f, -3, -1) # [L, M, S] 124 | err -= tf.sqrt(diag) * tf.random.normal(err.shape, dtype=err.dtype) 125 | M, D = feat.shape[-2:] 126 | if L is None: 127 | if D < M: 128 | feat_iDiag = feat * tf.math.reciprocal(diag) 129 | S = tf.matmul(feat_iDiag, feat, transpose_a=True) # [L, D, D] or [D, D] 130 | L = tf.linalg.cholesky(S + tf.eye(S.shape[-1], dtype=S.dtype)) 131 | else: 132 | K = tf.matmul(feat, feat, transpose_b=True) # [L, M, M] or [M, M] 133 | K = tf.linalg.set_diag(K, tf.linalg.diag_part(K) + diag[..., 0]) 134 | L = tf.linalg.cholesky(K) 135 | else: 136 | assert L.shape[-1] == min(M, D) # TODO: improve me 137 | 138 | # Solve for $Cov(u, u)^{-1}(u - f(Z))$ 139 | if D < M: 140 | feat_iDiag = feat * tf.math.reciprocal(diag) 141 | weights = tf.linalg.adjoint(tf.linalg.cholesky_solve(L, 142 | tf.matmul(feat_iDiag, err, transpose_a=True))) 143 | else: 144 | iK_err = tf.linalg.cholesky_solve(L, err) # [L, S, M] 145 | weights = tf.matmul(iK_err, feat, transpose_a=True) # [L, S, D] 146 | 147 | return MultioutputDenseSampler(basis=basis, 148 | weights=swap_axes(weights, -3, -2), # [S, L, D] 149 | multioutput_axis=multioutput_axis, 150 | **kwargs) 151 | -------------------------------------------------------------------------------- /gpflow_sampling/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from gpflow_sampling.utils.array_ops import * 5 | from gpflow_sampling.utils.linalg import * 6 | from gpflow_sampling.utils.gpflow_ops import * 7 | -------------------------------------------------------------------------------- /gpflow_sampling/utils/array_ops.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # ============================================== 5 | # Preamble 6 |
============================================== 7 | # ---- Imports 8 | import numpy as np 9 | import tensorflow as tf 10 | 11 | # ---- Exports 12 | __all__ = ( 13 | 'normalize_axis', 14 | 'swap_axes', 15 | 'move_axis', 16 | 'expand_n', 17 | 'expand_to', 18 | ) 19 | 20 | 21 | # ============================================== 22 | # misc 23 | # ============================================== 24 | def normalize_axis(axis, ndims, minval=None, maxval=None): 25 | if minval is None: 26 | minval = -ndims 27 | 28 | if maxval is None: 29 | maxval = ndims - 1 30 | 31 | assert maxval >= axis >= minval 32 | return ndims + axis if (axis < 0) else axis 33 | 34 | 35 | def move_axis(arr: tf.Tensor, src: int, dest: int): 36 | ndims = len(arr.shape) 37 | src = ndims + src if (src < 0) else src 38 | dest = ndims + dest if (dest < 0) else dest 39 | 40 | src = normalize_axis(src, ndims) 41 | dest = normalize_axis(dest, ndims) 42 | 43 | perm = list(range(ndims)) 44 | perm.insert(dest, perm.pop(src)) 45 | return tf.transpose(arr, perm) 46 | 47 | 48 | def swap_axes(arr: tf.Tensor, a: int, b: int): 49 | ndims = len(arr.shape) 50 | a = normalize_axis(a, ndims) 51 | b = normalize_axis(b, ndims) 52 | 53 | perm = list(range(ndims)) 54 | perm[a] = b 55 | perm[b] = a 56 | return tf.transpose(arr, perm) 57 | 58 | 59 | def expand_n(arr: tf.Tensor, axis: int, n: int): 60 | ndims = len(arr.shape) 61 | axis = normalize_axis(axis, ndims, maxval=ndims) 62 | return arr[axis * (slice(None),) + n * (np.newaxis,)] 63 | 64 | 65 | def expand_to(arr: tf.Tensor, axis: int, ndims: int): 66 | _ndims = len(arr.shape) 67 | if _ndims == ndims: 68 | return tf.identity(arr) 69 | assert ndims > _ndims 70 | return expand_n(arr, axis, ndims - _ndims) 71 | -------------------------------------------------------------------------------- /gpflow_sampling/utils/conv_ops.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # ============================================== 5 | # Preamble 6 | # ============================================== 7 | # ---- Imports 8 | import tensorflow as tf 9 | from typing import List 10 | from gpflow.base import TensorType 11 | from gpflow_sampling.utils.array_ops import move_axis 12 | 13 | 14 | # ---- Exports 15 | __all__ = ( 16 | 'reformat_shape', 17 | 'reformat_data', 18 | ) 19 | 20 | 21 | # ============================================== 22 | # conv_ops 23 | # ============================================== 24 | def reformat_shape(shape: List, 25 | input_format: str, 26 | output_format: str) -> List: 27 | """ 28 | Helper method for converting shapes between NHWC and NCHW formats. 29 | """ 30 | if input_format == output_format: 31 | return shape # noop 32 | 33 | if input_format == 'NHWC': 34 | assert output_format == "NCHW" 35 | return shape[:-3] + shape[-1:] + shape[-3: -1] 36 | 37 | if input_format == 'NCHW': 38 | assert output_format == "NHWC" 39 | return shape[:-3] + shape[-2:] + [shape[-3]] 40 | 41 | raise NotImplementedError 42 | 43 | 44 | def reformat_data(x: TensorType, 45 | format_in: str, 46 | format_out: str): 47 | """ 48 | Helper method for converting image data between NHWC and NCHW formats.
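For instance, an NHWC batch of shape [N, H, W, C] becomes an NCHW batch of shape [N, C, H, W] by moving the channel axis: `move_axis(x, -1, -3)`.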
49 | """ 50 | if format_in == format_out: 51 | return x # noop 52 | 53 | if format_in == "NHWC": 54 | assert format_out == "NCHW" 55 | return move_axis(x, -1, -3) 56 | 57 | if format_in == "NCHW": 58 | assert format_out == "NHWC" 59 | return move_axis(x, -3, -1) 60 | 61 | raise NotImplementedError 62 | -------------------------------------------------------------------------------- /gpflow_sampling/utils/gpflow_ops.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # ============================================== 5 | # Preamble 6 | # ============================================== 7 | # ---- Imports 8 | import tensorflow as tf 9 | from gpflow.inducing_variables import (InducingVariables, 10 | MultioutputInducingVariables, 11 | SharedIndependentInducingVariables, 12 | SeparateIndependentInducingVariables) 13 | 14 | from gpflow.utilities import Dispatcher 15 | 16 | 17 | # ---- Exports 18 | __all__ = ('get_inducing_shape', 'inducing_to_tensor') 19 | 20 | 21 | # ============================================== 22 | # gpflow_utils 23 | # ============================================== 24 | get_inducing_shape = Dispatcher("get_inducing_shape") 25 | inducing_to_tensor = Dispatcher("inducing_to_tensor") 26 | 27 | @get_inducing_shape.register(InducingVariables) 28 | def _getter(x): 29 | assert not isinstance(InducingVariables, MultioutputInducingVariables) 30 | return list(x.Z.shape) 31 | 32 | 33 | @get_inducing_shape.register(SharedIndependentInducingVariables) 34 | def _getter(x: SharedIndependentInducingVariables): 35 | assert len(x.inducing_variables) == 1 36 | return get_inducing_shape(x.inducing_variables[0]) 37 | 38 | 39 | @get_inducing_shape.register(SeparateIndependentInducingVariables) 40 | def _getter(x: SeparateIndependentInducingVariables): 41 | for n, z in enumerate(x.inducing_variables): 42 | if n == 0: 43 | shape = get_inducing_shape(z) 44 | else: 45 | assert shape == get_inducing_shape(z) 46 | return [n + 1] + shape 47 | 48 | 49 | @inducing_to_tensor.register(InducingVariables) 50 | def _convert(x: InducingVariables, **kwargs): 51 | assert not isinstance(InducingVariables, MultioutputInducingVariables) 52 | return tf.convert_to_tensor(x.Z, **kwargs) 53 | 54 | 55 | @inducing_to_tensor.register(SharedIndependentInducingVariables) 56 | def _convert(x: InducingVariables, **kwargs): 57 | assert len(x.inducing_variables) == 1 58 | return inducing_to_tensor(x.inducing_variables[0], **kwargs) 59 | 60 | 61 | @inducing_to_tensor.register(SeparateIndependentInducingVariables) 62 | def _convert(x: InducingVariables, **kwargs): 63 | return tf.stack([inducing_to_tensor(z, **kwargs) 64 | for z in x.inducing_variables], axis=0) 65 | 66 | -------------------------------------------------------------------------------- /gpflow_sampling/utils/linalg.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # ============================================== 5 | # Preamble 6 | # ============================================== 7 | # ---- Imports 8 | import numpy as np 9 | import tensorflow as tf 10 | from typing import List, Union 11 | from string import ascii_lowercase 12 | from collections.abc import Iterable 13 | from gpflow.config import default_jitter 14 | from gpflow_sampling.utils.array_ops import normalize_axis 15 | from tensorflow_probability.python.math import pivoted_cholesky 16 | 17 | # ---- Exports 18 | __all__ = ( 
19 | 'batch_tensordot', 20 | 'get_default_preconditioner', 21 | ) 22 | 23 | 24 | # ============================================== 25 | # linalg 26 | # ============================================== 27 | def batch_tensordot(a: tf.Tensor, 28 | b: tf.Tensor, 29 | axes: List, 30 | batch_axes: List = None) -> tf.Tensor: 31 | """ 32 | Computes tensor dot products like tf.tensordot, but with: 33 | - Support for batch dimensions; 1-to-1 rather than Cartesian product 34 | - Broadcasting of batch dimensions 35 | 36 | Example: 37 | a = tf.random.uniform([5, 4, 3, 2]) 38 | b = tf.random.uniform([6, 4, 2, 1]) 39 | c = batch_tensordot(a, b, [-1, -2], [[-3, -2], [-3, -1]]) # [5, 4, 3, 6] 40 | """ 41 | ndims_a = len(a.shape) 42 | ndims_b = len(b.shape) 43 | assert min(ndims_a, ndims_b) > 0 44 | assert max(ndims_a, ndims_b) < 26 45 | 46 | # Prepare batch and contraction axes 47 | def parse_axes(axes): 48 | if axes is None: 49 | return [], [] 50 | 51 | assert len(axes) == 2 52 | axes_a, axes_b = axes 53 | 54 | a_is_int = isinstance(axes_a, int) 55 | b_is_int = isinstance(axes_b, int) 56 | assert a_is_int == b_is_int 57 | if a_is_int: 58 | return [normalize_axis(axes_a, ndims_a)], \ 59 | [normalize_axis(axes_b, ndims_b)] 60 | 61 | assert isinstance(axes_a, Iterable) 62 | assert isinstance(axes_b, Iterable) 63 | axes_a = list(axes_a) 64 | axes_b = list(axes_b) 65 | length = len(axes_a) 66 | assert length == len(axes_b) 67 | if length == 0: 68 | return [], [] 69 | 70 | axes_a = [normalize_axis(ax, ndims_a) for ax in axes_a] 71 | axes_b = [normalize_axis(ax, ndims_b) for ax in axes_b] 72 | return map(list, zip(*sorted(zip(axes_a, axes_b)))) # sort according to a 73 | 74 | reduce_a, reduce_b = parse_axes(axes) # defines the tensor contraction 75 | batch_a, batch_b = parse_axes(batch_axes) # group these together 1-to-1 76 | 77 | # Construct left-hand side einsum conventions 78 | active_a = reduce_a + batch_a 79 | active_b = reduce_b + batch_b 80 | assert len(active_a) == len(set(active_a)) # check for duplicates 81 | assert len(active_b) == len(set(active_b)) 82 | 83 | lhs_a = list(ascii_lowercase[:ndims_a]) 84 | lhs_b = list(ascii_lowercase[ndims_a: ndims_a + ndims_b]) 85 | for (pos_a, pos_b) in zip(active_a, active_b): 86 | lhs_b[pos_b] = lhs_a[pos_a] 87 | 88 | # Construct right-hand side einsum convention 89 | rhs = [] 90 | for i, char_a in enumerate(lhs_a): 91 | if i not in reduce_a: 92 | rhs.append(char_a) 93 | 94 | for i, char_b in enumerate(lhs_b): 95 | if i not in active_b: 96 | rhs.append(char_b) 97 | 98 | # Enable broadcasting by eliminating singleton dimensions 99 | for (pos_a, pos_b) in zip(batch_a, batch_b): 100 | if a.shape[pos_a] == b.shape[pos_b]: 101 | continue # TODO: test for edge cases 102 | 103 | if a.shape[pos_a] == 1: 104 | a = tf.squeeze(a, axis=pos_a) 105 | del lhs_a[pos_a] 106 | 107 | if b.shape[pos_b] == 1: 108 | b = tf.squeeze(b, axis=pos_b) 109 | del lhs_b[pos_b] 110 | 111 | # Compute einsum 112 | return tf.einsum(f"{''.join(lhs_a)},{''.join(lhs_b)}->{''.join(rhs)}", a, b) 113 | 114 | 115 | def get_default_preconditioner(arr: tf.Tensor, 116 | diag: Union[tf.Tensor, tf.linalg.LinearOperator], 117 | max_rank: int = 100, 118 | diag_rtol: float = None): 119 | """ 120 | Returns a preconditioner representing 121 | $(D + LL^{T})^{-1}$ where $D$ is a diagonal matrix and $L$ is the 122 | partial pivoted Cholesky factor of a symmetric PSD matrix $A$.
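The result is the (lazy) inverse of a LinearOperatorLowRankUpdate, so matrix-vector products are applied via the Woodbury identity rather than a dense inverse. A minimal usage sketch, assuming a PSD matrix `K` and noise variance `sigma2`:

    P = get_default_preconditioner(K, diag=sigma2, max_rank=100)
    z = P.matvec(r)  # approximately applies (K + sigma2*I)^{-1} to r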
123 | """ 124 | if diag_rtol is None: 125 | diag_rtol = default_jitter() 126 | 127 | N = arr.shape[-1] 128 | if not isinstance(diag, tf.linalg.LinearOperator): 129 | diag = tf.convert_to_tensor(diag, dtype=arr.dtype) 130 | if N == 1 or (diag.shape.ndims and diag.shape[-1] > 1): 131 | diag = tf.linalg.LinearOperatorDiag(diag) 132 | else: 133 | diag = tf.linalg.LinearOperatorScaledIdentity(N, diag) 134 | 135 | piv_chol = pivoted_cholesky(arr, max_rank=max_rank, diag_rtol=diag_rtol) 136 | low_rank = tf.linalg.LinearOperatorLowRankUpdate(diag, piv_chol) 137 | return low_rank.inverse() 138 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | requirements = ( 4 | 'numpy>=1.18.0', 5 | 'tensorflow>=2.2.0', 6 | 'tensorflow-probability>=0.9.0', 7 | 'gpflow>=2.0.3', 8 | ) 9 | 10 | extra_requirements = { 11 | 'examples': ( 12 | 'matplotlib', 13 | 'seaborn', 14 | 'tqdm', 15 | 'tensorflow-datasets', 16 | ), 17 | } 18 | 19 | setup(name='gpflow_sampling', 20 | version='0.2', 21 | license='Creative Commons Attribution-Noncommercial-Share Alike license', 22 | packages=find_packages(exclude=["examples*"]), 23 | python_requires='>=3.6', 24 | install_requires=requirements, 25 | extras_require=extra_requirements) 26 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/j-wilson/GPflowSampling/bc9c553fe4d8f522726f002c18df9965246df345/tests/__init__.py -------------------------------------------------------------------------------- /tests/kernels/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/j-wilson/GPflowSampling/bc9c553fe4d8f522726f002c18df9965246df345/tests/kernels/__init__.py -------------------------------------------------------------------------------- /tests/kernels/test_conv2d.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # ============================================== 5 | # Preamble 6 | # ============================================== 7 | # ---- Imports 8 | import tensorflow as tf 9 | 10 | from numpy import allclose 11 | from typing import Any, List, NamedTuple 12 | from gpflow import config as gpflow_config, kernels as gpflow_kernels 13 | from gpflow.config import default_float as floatx 14 | from gpflow_sampling import kernels, covariances 15 | from gpflow_sampling.inducing_variables import * 16 | from gpflow_sampling.covariances.Kfus import (_Kfu_conv2d_fallback, 17 | _Kfu_depthwise_conv2d_fallback) 18 | 19 | SupportedBaseKernels = (gpflow_kernels.Matern12, 20 | gpflow_kernels.Matern32, 21 | gpflow_kernels.Matern52, 22 | gpflow_kernels.SquaredExponential,) 23 | 24 | 25 | # ============================================== 26 | # test_conv2d 27 | # ============================================== 28 | class ConfigConv2d(NamedTuple): 29 | seed: int = 1 30 | floatx: Any = 'float64' 31 | jitter: float = 1e-16 32 | 33 | num_test: int = 16 34 | num_cond: int = 32 35 | kernel_variance: float = 1.0 36 | rel_lengthscales_min: float = 0.1 37 | rel_lengthscales_max: float = 1.0 38 | 39 | 40 | channels_in: int = 3 41 | channels_out: int = 2 42 | image_shape: List = [28, 28] 43 | patch_shape: List = 
[5, 5] 44 | strides: List = [2, 2] 45 | 46 | padding: str = "SAME" 47 | dilations: List = [1, 1] 48 | 49 | 50 | def test_conv2d(config: ConfigConv2d = None): 51 | if config is None: 52 | config = ConfigConv2d() 53 | 54 | tf.random.set_seed(config.seed) 55 | gpflow_config.set_default_float(config.floatx) 56 | gpflow_config.set_default_jitter(config.jitter) 57 | 58 | X_shape = [config.num_test] + config.image_shape + [config.channels_in] 59 | X = tf.reshape(tf.range(tf.reduce_prod(X_shape), dtype=floatx()), X_shape) 60 | X /= tf.reduce_max(X) 61 | 62 | patch_len = config.channels_in * int(tf.reduce_prod(config.patch_shape)) 63 | for cls in SupportedBaseKernels: 64 | minval = config.rel_lengthscales_min * (patch_len ** 0.5) 65 | maxval = config.rel_lengthscales_max * (patch_len ** 0.5) 66 | lenscales = tf.random.uniform(shape=[patch_len], 67 | minval=minval, 68 | maxval=maxval, 69 | dtype=floatx()) 70 | 71 | base = cls(lengthscales=lenscales, variance=config.kernel_variance) 72 | kern = kernels.Conv2d(kernel=base, 73 | image_shape=config.image_shape, 74 | patch_shape=config.patch_shape, 75 | channels_in=config.channels_in, 76 | channels_out=config.channels_out, 77 | dilations=config.dilations, 78 | padding=config.padding, 79 | strides=config.strides) 80 | 81 | kern._weights = tf.random.normal(kern._weights.shape, dtype=floatx()) 82 | 83 | # Test full and shared inducing images 84 | Z_shape = [config.num_cond] + config.patch_shape + [config.channels_in] 85 | Zsrc = tf.random.normal(Z_shape, dtype=floatx()) 86 | for Z in (InducingImages(Zsrc), 87 | SharedInducingImages(Zsrc[..., :1], config.channels_in)): 88 | 89 | test = _Kfu_conv2d_fallback(Z, kern, X) 90 | assert allclose(covariances.Kfu(Z, kern, X), test) 91 | 92 | 93 | def test_conv2d_transpose(config: ConfigConv2d = None): 94 | if config is None: 95 | config = ConfigConv2d() 96 | 97 | tf.random.set_seed(config.seed) 98 | gpflow_config.set_default_float(config.floatx) 99 | gpflow_config.set_default_jitter(config.jitter) 100 | 101 | X_shape = [config.num_test] + config.image_shape + [config.channels_in] 102 | X = tf.reshape(tf.range(tf.reduce_prod(X_shape), dtype=floatx()), X_shape) 103 | X /= tf.reduce_max(X) 104 | 105 | patch_len = config.channels_in * int(tf.reduce_prod(config.patch_shape)) 106 | for cls in SupportedBaseKernels: 107 | minval = config.rel_lengthscales_min * (patch_len ** 0.5) 108 | maxval = config.rel_lengthscales_max * (patch_len ** 0.5) 109 | lenscales = tf.random.uniform(shape=[patch_len], 110 | minval=minval, 111 | maxval=maxval, 112 | dtype=floatx()) 113 | 114 | base = cls(lengthscales=lenscales, variance=config.kernel_variance) 115 | kern = kernels.Conv2dTranspose(kernel=base, 116 | image_shape=config.image_shape, 117 | patch_shape=config.patch_shape, 118 | channels_in=config.channels_in, 119 | channels_out=config.channels_out, 120 | dilations=config.dilations, 121 | padding=config.padding, 122 | strides=config.strides) 123 | 124 | kern._weights = tf.random.normal(kern._weights.shape, dtype=floatx()) 125 | 126 | # Test full and shared inducing images 127 | Z_shape = [config.num_cond] + config.patch_shape + [config.channels_in] 128 | Zsrc = tf.random.normal(Z_shape, dtype=floatx()) 129 | for Z in (InducingImages(Zsrc), 130 | SharedInducingImages(Zsrc[..., :1], config.channels_in)): 131 | 132 | test = _Kfu_conv2d_fallback(Z, kern, X) 133 | assert allclose(covariances.Kfu(Z, kern, X), test) 134 | 135 | 136 | def test_depthwise_conv2d(config: ConfigConv2d = None): 137 | if config is None: 138 | config = ConfigConv2d() 139 | 140 |
tf.random.set_seed(config.seed) 141 | gpflow_config.set_default_float(config.floatx) 142 | gpflow_config.set_default_jitter(config.jitter) 143 | 144 | X_shape = [config.num_test] + config.image_shape + [config.channels_in] 145 | X = tf.reshape(tf.range(tf.reduce_prod(X_shape), dtype=floatx()), X_shape) 146 | X /= tf.reduce_max(X) 147 | 148 | patch_len = int(tf.reduce_prod(config.patch_shape)) 149 | for cls in SupportedBaseKernels: 150 | minval = config.rel_lengthscales_min * (patch_len ** 0.5) 151 | maxval = config.rel_lengthscales_max * (patch_len ** 0.5) 152 | lenscales = tf.random.uniform(shape=[config.channels_in, patch_len], 153 | minval=minval, 154 | maxval=maxval, 155 | dtype=floatx()) 156 | 157 | base = cls(lengthscales=lenscales, variance=config.kernel_variance) 158 | kern = kernels.DepthwiseConv2d(kernel=base, 159 | image_shape=config.image_shape, 160 | patch_shape=config.patch_shape, 161 | channels_in=config.channels_in, 162 | channels_out=config.channels_out, 163 | dilations=config.dilations, 164 | padding=config.padding, 165 | strides=config.strides) 166 | 167 | kern._weights = tf.random.normal(kern._weights.shape, dtype=floatx()) 168 | 169 | # Test full and shared inducing images 170 | Z_shape = [config.num_cond] + config.patch_shape + [config.channels_in] 171 | Zsrc = tf.random.normal(Z_shape, dtype=floatx()) 172 | for Z in (DepthwiseInducingImages(Zsrc), 173 | SharedDepthwiseInducingImages(Zsrc[..., :1], config.channels_in)): 174 | 175 | test = _Kfu_depthwise_conv2d_fallback(Z, kern, X) 176 | assert allclose(covariances.Kfu(Z, kern, X), test) 177 | 178 | -------------------------------------------------------------------------------- /tests/sampling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/j-wilson/GPflowSampling/bc9c553fe4d8f522726f002c18df9965246df345/tests/sampling/__init__.py -------------------------------------------------------------------------------- /tests/sampling/priors/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/j-wilson/GPflowSampling/bc9c553fe4d8f522726f002c18df9965246df345/tests/sampling/priors/__init__.py -------------------------------------------------------------------------------- /tests/sampling/priors/test_fourier_conv2d.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # ============================================== 5 | # Preamble 6 | # ============================================== 7 | # ---- Imports 8 | import tensorflow as tf 9 | 10 | from numpy import allclose 11 | from typing import Any, List, NamedTuple 12 | from gpflow import config as gpflow_config, kernels as gpflow_kernels 13 | from gpflow.config import default_float as floatx 14 | from gpflow_sampling import kernels, covariances, inducing_variables 15 | from gpflow_sampling.sampling.priors import random_fourier 16 | from gpflow_sampling.bases import fourier as fourier_basis 17 | from gpflow_sampling.utils import batch_tensordot, swap_axes 18 | 19 | SupportedBaseKernels = (gpflow_kernels.Matern12, 20 | gpflow_kernels.Matern32, 21 | gpflow_kernels.Matern52, 22 | gpflow_kernels.SquaredExponential,) 23 | 24 | 25 | # ============================================== 26 | # test_fourier_conv2d 27 | # ============================================== 28 | class ConfigFourierConv2d(NamedTuple): 29 | seed: int = 1 30 | floatx: Any = 'float64' 31 |
jitter: float = 1e-6 32 | 33 | num_test: int = 16 34 | num_cond: int = 5 35 | num_bases: int = 4096 36 | num_samples: int = 16384 37 | shard_size: int = 1024 38 | 39 | kernel_variance: float = 0.9 # keep this near 1 since it impacts MC error 40 | rel_lengthscales_min: float = 0.5 41 | rel_lengthscales_max: float = 2.0 42 | num_latent_gps: int = 3 43 | 44 | # Convolutional settings 45 | channels_in: int = 2 46 | image_shape: List = [5, 5] 47 | patch_shape: List = [3, 3] 48 | strides: List = [1, 1] 49 | dilations: List = [1, 1] 50 | 51 | 52 | def _avg_spatial_inner_product(a, b=None, batch_dims: int = 0): 53 | _a = tf.reshape(a, list(a.shape[:-3]) + [-1, a.shape[-1]]) 54 | if b is None: 55 | _b = _a 56 | else: 57 | _b = tf.reshape(b, list(b.shape[:-3]) + [-1, b.shape[-1]]) 58 | batch_axes = 2 * [list(range(batch_dims))] 59 | 60 | prod = batch_tensordot(_a, _b, axes=[-1, -1], batch_axes=batch_axes) 61 | return tf.reduce_mean(prod, [-3, -1]) 62 | 63 | 64 | def _test_fourier_conv2d_common(config, kern, X, Z): 65 | # Use closed-form evaluations as ground truth 66 | Kuu = covariances.Kuu(Z, kern) 67 | Kfu = covariances.Kfu(Z, kern, X) 68 | Kff = kern(X, full_cov=True) 69 | 70 | # Test Fourier-feature-based kernel approximator 71 | basis = fourier_basis(kern, num_bases=config.num_bases) 72 | feat_x = basis(X) # [N, B] or [N, L, B] 73 | feat_z = basis(Z) 74 | 75 | tol = 3 * config.num_bases ** -0.5 76 | assert allclose(_avg_spatial_inner_product(feat_x, feat_x), Kff, tol, tol) 77 | assert allclose(_avg_spatial_inner_product(feat_x, feat_z), Kfu, tol, tol) 78 | assert allclose(_avg_spatial_inner_product(feat_z, feat_z), Kuu, tol, tol) 79 | del feat_x, feat_z 80 | 81 | # Test covariance of functions drawn from the approximate prior 82 | fx = [] 83 | fz = [] 84 | count = 0 85 | while count < config.num_samples: 86 | size = min(config.shard_size, config.num_samples - count) 87 | funcs = random_fourier(kern, 88 | basis=basis, 89 | num_bases=config.num_bases, 90 | sample_shape=[size]) 91 | 92 | fx.append(funcs(X)) 93 | fz.append(funcs(Z)) 94 | count += size 95 | 96 | fx = swap_axes(tf.concat(fx, axis=0), 0, -1) # [L, N, H, W, S] 97 | fz = swap_axes(tf.concat(fz, axis=0), 0, -1) # [L, M, 1, 1, S] 98 | nb = fx.shape.ndims - 4 # num. of batch dimensions 99 | tol += 3 * config.num_samples ** -0.5 100 | frac = 1 / config.num_samples 101 | 102 | assert allclose(frac * _avg_spatial_inner_product(fx, fx, nb), Kff, tol, tol) 103 | assert allclose(frac * _avg_spatial_inner_product(fx, fz, nb), Kfu, tol, tol) 104 | assert allclose(frac * _avg_spatial_inner_product(fz, fz, nb), Kuu, tol, tol) 105 | 106 | 107 | def test_conv2d(config: ConfigFourierConv2d = None): 108 | """ 109 | TODO: Consider separating out the test for Conv2dTranspose since it only 110 | supports a subset of strides/dilations.
111 | """ 112 | if config is None: 113 | config = ConfigFourierConv2d() 114 | 115 | tf.random.set_seed(config.seed) 116 | gpflow_config.set_default_float(config.floatx) 117 | gpflow_config.set_default_jitter(config.jitter) 118 | 119 | X_shape = [config.num_test] + config.image_shape + [config.channels_in] 120 | X = tf.reshape(tf.range(tf.reduce_prod(X_shape), dtype=floatx()), X_shape) 121 | X /= tf.reduce_max(X) 122 | 123 | Z_shape = [config.num_cond] + config.patch_shape + [config.channels_in] 124 | Zsrc = tf.random.normal(Z_shape, dtype=floatx()) 125 | Z = inducing_variables.InducingImages(Zsrc) 126 | 127 | patch_len = config.channels_in * config.patch_shape[0] * config.patch_shape[1] 128 | for base_cls in SupportedBaseKernels: 129 | minval = config.rel_lengthscales_min * (patch_len ** 0.5) 130 | maxval = config.rel_lengthscales_max * (patch_len ** 0.5) 131 | lenscales = tf.random.uniform(shape=[patch_len], 132 | minval=minval, 133 | maxval=maxval, 134 | dtype=floatx()) 135 | 136 | base = base_cls(lengthscales=lenscales, variance=config.kernel_variance) 137 | for cls in (kernels.Conv2d, kernels.Conv2dTranspose): 138 | kern = cls(kernel=base, 139 | image_shape=config.image_shape, 140 | patch_shape=config.patch_shape, 141 | channels_in=config.channels_in, 142 | channels_out=config.num_latent_gps, 143 | dilations=config.dilations, 144 | strides=config.strides) 145 | 146 | _test_fourier_conv2d_common(config, kern, X, Z) 147 | 148 | 149 | def test_depthwise_conv2d(config: ConfigFourierConv2d = None): 150 | if config is None: 151 | config = ConfigFourierConv2d() 152 | 153 | assert config.num_bases % config.channels_in == 0 154 | tf.random.set_seed(config.seed) 155 | gpflow_config.set_default_float(config.floatx) 156 | gpflow_config.set_default_jitter(config.jitter) 157 | 158 | X_shape = [config.num_test] + config.image_shape + [config.channels_in] 159 | X = tf.random.uniform(X_shape, dtype=floatx()) 160 | 161 | img_shape = [config.num_cond] + config.patch_shape + [config.channels_in] 162 | Zsrc = tf.random.normal(img_shape, dtype=floatx()) 163 | Z = inducing_variables.DepthwiseInducingImages(Zsrc) 164 | 165 | patch_len = config.patch_shape[0] * config.patch_shape[1] 166 | for base_cls in SupportedBaseKernels: 167 | minval = config.rel_lengthscales_min * (patch_len ** 0.5) 168 | maxval = config.rel_lengthscales_max * (patch_len ** 0.5) 169 | lenscales = tf.random.uniform(shape=[config.channels_in, patch_len], 170 | minval=minval, 171 | maxval=maxval, 172 | dtype=floatx()) 173 | 174 | base = base_cls(lengthscales=lenscales, variance=config.kernel_variance) 175 | for cls in (kernels.DepthwiseConv2d,): 176 | kern = cls(kernel=base, 177 | image_shape=config.image_shape, 178 | patch_shape=config.patch_shape, 179 | channels_in=config.channels_in, 180 | channels_out=config.num_latent_gps, 181 | dilations=config.dilations, 182 | strides=config.strides) 183 | 184 | _test_fourier_conv2d_common(config, kern, X, Z) 185 | -------------------------------------------------------------------------------- /tests/sampling/priors/test_fourier_dense.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # ============================================== 5 | # Preamble 6 | # ============================================== 7 | # ---- Imports 8 | import tensorflow as tf 9 | 10 | from numpy import allclose 11 | from typing import Any, NamedTuple 12 | from gpflow import kernels, config as gpflow_config 13 | from 
gpflow.inducing_variables import (InducingPoints, 14 | MultioutputInducingVariables, 15 | SharedIndependentInducingVariables, 16 | SeparateIndependentInducingVariables) 17 | from gpflow.config import default_float as floatx 18 | from gpflow_sampling import covariances 19 | from gpflow_sampling.bases import fourier as fourier_basis 20 | from gpflow_sampling.sampling.priors import random_fourier 21 | 22 | SupportedBaseKernels = (kernels.Matern12, 23 | kernels.Matern32, 24 | kernels.Matern52, 25 | kernels.SquaredExponential) 26 | 27 | 28 | # ============================================== 29 | # test_fourier_dense 30 | # ============================================== 31 | class ConfigFourierDense(NamedTuple): 32 | seed: int = 1 33 | floatx: Any = 'float64' 34 | jitter: float = 1e-6 35 | 36 | num_test: int = 16 37 | num_cond: int = 8 38 | num_bases: int = 4096 39 | num_samples: int = 16384 40 | shard_size: int = 1024 41 | input_dims: int = 3 42 | 43 | kernel_variance: float = 0.9 # keep this near 1 since it impacts MC error 44 | lengthscales_min: float = 0.1 45 | lengthscales_max: float = 1.0 46 | 47 | 48 | def _test_fourier_dense_common(config, kern, X, Z): 49 | # Test Fourier-feature-based kernel approximator 50 | Kuu = covariances.Kuu(Z, kern) 51 | Kfu = covariances.Kfu(Z, kern, X) 52 | basis = fourier_basis(kern, num_bases=config.num_bases) 53 | Z_opt = dict() # options used when evaluating basis/prior at Z 54 | if isinstance(Z, MultioutputInducingVariables): 55 | Kff = kern(X, full_cov=True, full_output_cov=False) 56 | if not isinstance(Z, SharedIndependentInducingVariables): 57 | # Handling for non-shared multioutput inducing variables. 58 | # We need to indicate that Z's outermost axis should be 59 | # evaluated 1-to-1 with the L latent GPs 60 | Z_opt.setdefault("multioutput_axis", 0) 61 | 62 | feat_x = basis(X) # [L, N, B] 63 | feat_z = basis(Z, **Z_opt) # [L, M, B] 64 | else: 65 | Kff = kern(X, full_cov=True) 66 | feat_x = basis(X) 67 | feat_z = basis(Z) 68 | 69 | tol = 3 * config.num_bases ** -0.5 70 | assert allclose(tf.matmul(feat_x, feat_x, transpose_b=True), Kff, tol, tol) 71 | assert allclose(tf.matmul(feat_x, feat_z, transpose_b=True), Kfu, tol, tol) 72 | assert allclose(tf.matmul(feat_z, feat_z, transpose_b=True), Kuu, tol, tol) 73 | del feat_x, feat_z 74 | 75 | # Test covariance of functions drawn from approximate prior 76 | fx = [] 77 | fz = [] 78 | count = 0 79 | while count < config.num_samples: 80 | size = min(config.shard_size, config.num_samples - count) 81 | funcs = random_fourier(kern, 82 | basis=basis, 83 | num_bases=config.num_bases, 84 | sample_shape=[size]) 85 | 86 | fx.append(funcs(X)) 87 | fz.append(funcs(Z, **Z_opt)) 88 | count += size 89 | 90 | fx = tf.transpose(tf.concat(fx, axis=0)) # [L, N, S] 91 | fz = tf.transpose(tf.concat(fz, axis=0)) # [L, M, S] 92 | tol += 3 * config.num_samples ** -0.5 93 | frac = 1 / config.num_samples 94 | assert allclose(frac * tf.matmul(fx, fx, transpose_b=True), Kff, tol, tol) 95 | assert allclose(frac * tf.matmul(fx, fz, transpose_b=True), Kfu, tol, tol) 96 | assert allclose(frac * tf.matmul(fz, fz, transpose_b=True), Kuu, tol, tol) 97 | 98 | 99 | def test_dense(config: ConfigFourierDense = None): 100 | if config is None: 101 | config = ConfigFourierDense() 102 | 103 | tf.random.set_seed(config.seed) 104 | gpflow_config.set_default_float(config.floatx) 105 | gpflow_config.set_default_jitter(config.jitter) 106 | 107 | X = tf.random.uniform([config.num_test, config.input_dims], dtype=floatx()) 108 | Z = 
tf.random.uniform([config.num_cond, config.input_dims], dtype=floatx()) 109 | Z = InducingPoints(Z) 110 | for cls in SupportedBaseKernels: 111 | lenscales = tf.random.uniform(shape=[config.input_dims], 112 | minval=config.lengthscales_min, 113 | maxval=config.lengthscales_max, 114 | dtype=floatx()) 115 | 116 | kern = cls(lengthscales=lenscales, variance=config.kernel_variance) 117 | _test_fourier_dense_common(config, kern, X, Z) 118 | 119 | 120 | def test_dense_shared(config: ConfigFourierDense = None, output_dim: int = 2): 121 | if config is None: 122 | config = ConfigFourierDense() 123 | 124 | tf.random.set_seed(config.seed) 125 | gpflow_config.set_default_float(config.floatx) 126 | gpflow_config.set_default_jitter(config.jitter) 127 | 128 | X = tf.random.uniform([config.num_test, config.input_dims], dtype=floatx()) 129 | Z = tf.random.uniform([config.num_cond, config.input_dims], dtype=floatx()) 130 | Z = SharedIndependentInducingVariables(InducingPoints(Z)) 131 | for cls in SupportedBaseKernels: 132 | lenscales = tf.random.uniform(shape=[config.input_dims], 133 | minval=config.lengthscales_min, 134 | maxval=config.lengthscales_max, 135 | dtype=floatx()) 136 | 137 | base = cls(lengthscales=lenscales, variance=config.kernel_variance) 138 | kern = kernels.SharedIndependent(base, output_dim=output_dim) 139 | _test_fourier_dense_common(config, kern, X, Z) 140 | 141 | 142 | def test_dense_separate(config: ConfigFourierDense = None): 143 | if config is None: 144 | config = ConfigFourierDense() 145 | 146 | tf.random.set_seed(config.seed) 147 | gpflow_config.set_default_float(config.floatx) 148 | gpflow_config.set_default_jitter(config.jitter) 149 | 150 | allZ = [] 151 | allK = [] 152 | for cls in SupportedBaseKernels: 153 | lenscales = tf.random.uniform(shape=[config.input_dims], 154 | minval=config.lengthscales_min, 155 | maxval=config.lengthscales_max, 156 | dtype=floatx()) 157 | 158 | rel_variance = tf.random.uniform(shape=[], 159 | minval=0.9, 160 | maxval=1.1, 161 | dtype=floatx()) 162 | 163 | allK.append(cls(lengthscales=lenscales, 164 | variance=config.kernel_variance * rel_variance)) 165 | 166 | allZ.append(InducingPoints( 167 | tf.random.uniform([config.num_cond, config.input_dims], dtype=floatx()))) 168 | 169 | kern = kernels.SeparateIndependent(allK) 170 | Z = SeparateIndependentInducingVariables(allZ) 171 | X = tf.random.uniform([config.num_test, config.input_dims], dtype=floatx()) 172 | _test_fourier_dense_common(config, kern, X, Z) 173 | 174 | 175 | if __name__ == "__main__": 176 | test_dense() 177 | test_dense_shared() 178 | test_dense_separate() 179 | -------------------------------------------------------------------------------- /tests/sampling/updates/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/j-wilson/GPflowSampling/bc9c553fe4d8f522726f002c18df9965246df345/tests/sampling/updates/__init__.py -------------------------------------------------------------------------------- /tests/sampling/updates/common.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # ============================================== 5 | # Preamble 6 | # ============================================== 7 | # ---- Imports 8 | import tensorflow as tf 9 | from numpy import allclose 10 | from typing import Union, NamedTuple 11 | from functools import partial, update_wrapper 12 | 13 | from gpflow import config as gpflow_config 14 | from 
gpflow import kernels, mean_functions 15 | from gpflow.base import TensorLike 16 | from gpflow.config import default_jitter, default_float as floatx 17 | from gpflow.models import GPR, SVGP 18 | from gpflow.kernels import MultioutputKernel, SharedIndependent 19 | from gpflow.utilities import Dispatcher 20 | from gpflow.inducing_variables import (InducingPoints, 21 | InducingVariables, 22 | SharedIndependentInducingVariables, 23 | SeparateIndependentInducingVariables) 24 | from gpflow_sampling import covariances, kernels as kernels_ext 25 | from gpflow_sampling.utils import batch_tensordot 26 | from gpflow_sampling.inducing_variables import * 27 | 28 | SupportedBaseKernels = (kernels.Matern12, 29 | kernels.Matern32, 30 | kernels.Matern52, 31 | kernels.SquaredExponential) 32 | 33 | # ---- Exports 34 | __all__ = ('sample_joint', 35 | 'avg_spatial_inner_product', 36 | 'test_update_dense', 'test_update_sparse', 37 | 'test_update_sparse_shared', 38 | 'test_update_sparse_separate', 39 | 'test_update_conv2d') 40 | 41 | # ============================================== 42 | # common 43 | # ============================================== 44 | sample_joint = Dispatcher("sample_joint") 45 | 46 | 47 | @sample_joint.register(kernels.Kernel, TensorLike, TensorLike) 48 | def _sample_joint_fallback(kern, 49 | X, 50 | Xnew, 51 | num_samples: int, 52 | L: TensorLike = None, 53 | diag: TensorLike = None): 54 | """ 55 | Sample from the joint distribution of $f(X), f(X_{new})$ via a 56 | location-scale transform. 57 | """ 58 | if diag is None: 59 | diag = default_jitter() 60 | 61 | if L is None: 62 | K = kern(tf.concat([X, Xnew], axis=-2), full_cov=True) 63 | K = tf.linalg.set_diag(K, tf.linalg.diag_part(K) + diag) 64 | L = tf.linalg.cholesky(K) 65 | 66 | # Draw samples using a location-scale transform 67 | rvs = tf.random.normal(list(L.shape[:-1]) + [num_samples], dtype=floatx()) 68 | draws = tf.expand_dims(L @ rvs, 0) # [1, N + T, S] 69 | return tf.split(tf.transpose(draws), [-1, Xnew.shape[0]], axis=-2), L 70 | 71 | 72 | @sample_joint.register(kernels.Kernel, InducingVariables, TensorLike) 73 | def _sample_joint_inducing(kern, 74 | Z, 75 | Xnew, 76 | num_samples: int, 77 | L: TensorLike = None, 78 | diag: Union[float, tf.Tensor] = None): 79 | """ 80 | Sample from the joint distribution of $f(X), g(Z)$ via a 81 | location-scale transform.
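Concretely, the joint covariance K = [[Kuu, Kuf], [Kuf^T, Kff]] is assembled from closed-form blocks, `diag` (defaulting to the configured jitter) is added to its diagonal, and draws are generated as L @ rvs with rvs ~ N(0, I), where L is the Cholesky factor of K.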
82 | """ 83 | if diag is None: 84 | diag = default_jitter() 85 | 86 | # Construct joint covariance and compute matrix square root 87 | has_multiple_outputs = isinstance(kern, MultioutputKernel) 88 | if L is None: 89 | if has_multiple_outputs: 90 | Kff = kern(Xnew, full_cov=True, full_output_cov=False) 91 | else: 92 | Kff = kern(Xnew, full_cov=True) 93 | Kuu = covariances.Kuu(Z, kern, jitter=0.0) 94 | Kuf = covariances.Kuf(Z, kern, Xnew) 95 | if isinstance(kern, SharedIndependent) and \ 96 | isinstance(Z, SharedIndependentInducingVariables): 97 | Kuu = tf.tile(Kuu[None], [Kff.shape[0], 1, 1]) 98 | Kuf = tf.tile(Kuf[None], [Kff.shape[0], 1, 1]) 99 | 100 | K = tf.concat([tf.concat([Kuu, Kuf], axis=-1), 101 | tf.concat([tf.linalg.adjoint(Kuf), Kff], axis=-1)], axis=-2) 102 | 103 | K = tf.linalg.set_diag(K, tf.linalg.diag_part(K) + diag) 104 | L = tf.linalg.cholesky(K) 105 | 106 | # Draw samples using a location-scale transform 107 | rvs = tf.random.normal(list(L.shape[:-1]) + [num_samples], dtype=floatx()) 108 | draws = L @ rvs # [L, M + N, S] or [M + N, S] 109 | if not has_multiple_outputs: 110 | draws = tf.expand_dims(draws, 0) 111 | 112 | return tf.split(tf.transpose(draws), [-1, Xnew.shape[0]], axis=-2), L 113 | 114 | 115 | @sample_joint.register(kernels_ext.Conv2d, InducingImages, TensorLike) 116 | def _sample_joint_conv2d(kern, 117 | Z, 118 | Xnew, 119 | num_samples: int, 120 | L: TensorLike = None, 121 | diag: Union[float, tf.Tensor] = None): 122 | """ 123 | Sample from the joint distribution of $f(X), g(Z)$ via a 124 | location-scale transform. 125 | """ 126 | if diag is None: 127 | diag = default_jitter() 128 | 129 | # Construct joint covariance and compute matrix square root 130 | if L is None: 131 | Zp = Z.as_patches # [M, patch_len] 132 | Xp = kern.get_patches(Xnew, full_spatial=False) 133 | P = tf.concat([Zp, tf.reshape(Xp, [-1, Xp.shape[-1]])], axis=0) 134 | K = kern.kernel(P, full_cov=True) 135 | K = tf.linalg.set_diag(K, tf.linalg.diag_part(K) + diag) 136 | L = tf.linalg.cholesky(K) 137 | L = tf.tile(L[None], [kern.channels_out, 1, 1]) # TODO: Improve me 138 | 139 | # Draw samples using a location-scale transform 140 | spatial_in = Xnew.shape[-3:-1] 141 | spatial_out = kern.get_spatial_out(spatial_in) 142 | rvs = tf.random.normal(list(L.shape[:-1]) + [num_samples], dtype=floatx()) 143 | draws = tf.transpose(L @ rvs) # [S, M + P, L] 144 | fz, fx = tf.split(draws, [len(Z), -1], axis=1) 145 | 146 | # Reorganize $f(X)$ as a 3d feature map 147 | fx_shape = [num_samples, Xnew.shape[0]] + spatial_out + [kern.channels_out] 148 | fx = tf.reshape(fx, fx_shape) 149 | return (fz, fx), L 150 | 151 | 152 | def avg_spatial_inner_product(a, b=None, batch_dims: int = 0): 153 | """ 154 | Used to compute covariances of functions defined as sums over 155 | patch response functions in 4D image format [N, H, W, C] 156 | """ 157 | _a = tf.reshape(a, list(a.shape[:-3]) + [-1, a.shape[-1]]) 158 | if b is None: 159 | _b = _a 160 | else: 161 | _b = tf.reshape(b, list(b.shape[:-3]) + [-1, b.shape[-1]]) 162 | batch_axes = 2 * [list(range(batch_dims))] 163 | prod = batch_tensordot(_a, _b, axes=[-1, -1], batch_axes=batch_axes) 164 | return tf.reduce_mean(prod, [-3, -1]) 165 | 166 | 167 | def test_update_dense(default_config: NamedTuple = None): 168 | def decorator(subroutine): 169 | def main(config): 170 | assert config is not None, ValueError 171 | tf.random.set_seed(config.seed) 172 | gpflow_config.set_default_float(config.floatx) 173 | gpflow_config.set_default_jitter(config.jitter) 174 | 175 | X = 
tf.random.uniform([config.num_cond, config.input_dims], dtype=floatx()) 176 | Xnew = tf.random.uniform([config.num_test, config.input_dims], dtype=floatx()) 177 | for cls in SupportedBaseKernels: 178 | minval = config.rel_lengthscales_min * (config.input_dims ** 0.5) 179 | maxval = config.rel_lengthscales_max * (config.input_dims ** 0.5) 180 | lenscales = tf.random.uniform(shape=[config.input_dims], 181 | minval=minval, 182 | maxval=maxval, 183 | dtype=floatx()) 184 | 185 | kern = cls(lengthscales=lenscales, variance=config.kernel_variance) 186 | const = tf.random.normal([1], dtype=floatx()) 187 | 188 | K = kern(X, full_cov=True) 189 | K = tf.linalg.set_diag(K, tf.linalg.diag_part(K) + config.noise_variance) 190 | L = tf.linalg.cholesky(K) 191 | y = L @ tf.random.normal([L.shape[-1], 1], dtype=floatx()) + const 192 | 193 | model = GPR(kernel=kern, 194 | noise_variance=config.noise_variance, 195 | data=(X, y), 196 | mean_function=mean_functions.Constant(c=const)) 197 | 198 | mf, Sff = subroutine(config, model, Xnew) 199 | mg, Sgg = model.predict_f(Xnew, full_cov=True) 200 | 201 | tol = config.error_tol 202 | assert allclose(mf, mg, tol, tol) 203 | assert allclose(Sff, Sgg, tol, tol) 204 | 205 | return update_wrapper(partial(main, config=default_config), subroutine) 206 | return decorator 207 | 208 | 209 | def test_update_sparse(default_config: NamedTuple = None): 210 | def decorator(subroutine): 211 | def main(config): 212 | assert config is not None, ValueError 213 | tf.random.set_seed(config.seed) 214 | gpflow_config.set_default_float(config.floatx) 215 | gpflow_config.set_default_jitter(config.jitter) 216 | 217 | X = tf.random.uniform([config.num_test,config.input_dims], dtype=floatx()) 218 | Z_shape = config.num_cond, config.input_dims 219 | for cls in SupportedBaseKernels: 220 | minval = config.rel_lengthscales_min * (config.input_dims ** 0.5) 221 | maxval = config.rel_lengthscales_max * (config.input_dims ** 0.5) 222 | lenscales = tf.random.uniform(shape=[config.input_dims], 223 | minval=minval, 224 | maxval=maxval, 225 | dtype=floatx()) 226 | 227 | q_sqrt = tf.zeros([1] + 2 * [config.num_cond], dtype=floatx()) 228 | kern = cls(lengthscales=lenscales, variance=config.kernel_variance) 229 | Z = InducingPoints(tf.random.uniform(Z_shape, dtype=floatx())) 230 | 231 | const = tf.random.normal([1], dtype=floatx()) 232 | model = SVGP(kernel=kern, 233 | likelihood=None, 234 | inducing_variable=Z, 235 | mean_function=mean_functions.Constant(c=const), 236 | q_sqrt=q_sqrt) 237 | 238 | mf, Sff = subroutine(config, model, X) 239 | mg, Sgg = model.predict_f(X, full_cov=True) 240 | 241 | tol = config.error_tol 242 | assert allclose(mf, mg, tol, tol) 243 | assert allclose(Sff, Sgg, tol, tol) 244 | 245 | return update_wrapper(partial(main, config=default_config), subroutine) 246 | return decorator 247 | 248 | 249 | def test_update_sparse_shared(default_config: NamedTuple = None): 250 | def decorator(subroutine): 251 | def main(config): 252 | assert config is not None, ValueError 253 | tf.random.set_seed(config.seed) 254 | gpflow_config.set_default_float(config.floatx) 255 | gpflow_config.set_default_jitter(config.jitter) 256 | 257 | X = tf.random.uniform([config.num_test,config.input_dims], dtype=floatx()) 258 | Z_shape = config.num_cond, config.input_dims 259 | for cls in SupportedBaseKernels: 260 | minval = config.rel_lengthscales_min * (config.input_dims ** 0.5) 261 | maxval = config.rel_lengthscales_max * (config.input_dims ** 0.5) 262 | lenscales = tf.random.uniform(shape=[config.input_dims], 263 
| minval=minval, 264 | maxval=maxval, 265 | dtype=floatx()) 266 | 267 | base = cls(lengthscales=lenscales, variance=config.kernel_variance) 268 | kern = kernels.SharedIndependent(base, output_dim=2) 269 | 270 | Z = SharedIndependentInducingVariables( 271 | InducingPoints(tf.random.uniform(Z_shape, dtype=floatx()))) 272 | Kuu = covariances.Kuu(Z, kern, jitter=gpflow_config.default_jitter()) 273 | q_sqrt = tf.stack([tf.zeros(2 * [config.num_cond], dtype=floatx()), 274 | tf.linalg.cholesky(Kuu)]) 275 | 276 | const = tf.random.normal([2], dtype=floatx()) 277 | model = SVGP(kernel=kern, 278 | likelihood=None, 279 | inducing_variable=Z, 280 | mean_function=mean_functions.Constant(c=const), 281 | q_sqrt=q_sqrt, 282 | whiten=False, 283 | num_latent_gps=2) 284 | 285 | mf, Sff = subroutine(config, model, X) 286 | mg, Sgg = model.predict_f(X, full_cov=True) 287 | tol = config.error_tol 288 | assert allclose(mf, mg, tol, tol) 289 | assert allclose(Sff, Sgg, tol, tol) 290 | return update_wrapper(partial(main, config=default_config), subroutine) 291 | return decorator 292 | 293 | 294 | def test_update_sparse_separate(default_config: NamedTuple = None): 295 | def decorator(subroutine): 296 | def main(config): 297 | assert config is not None, ValueError 298 | tf.random.set_seed(config.seed) 299 | gpflow_config.set_default_float(config.floatx) 300 | gpflow_config.set_default_jitter(config.jitter) 301 | 302 | X = tf.random.uniform([config.num_test,config.input_dims], dtype=floatx()) 303 | allK = [] 304 | allZ = [] 305 | Z_shape = config.num_cond, config.input_dims 306 | for cls in SupportedBaseKernels: 307 | minval = config.rel_lengthscales_min * (config.input_dims ** 0.5) 308 | maxval = config.rel_lengthscales_max * (config.input_dims ** 0.5) 309 | lenscales = tf.random.uniform(shape=[config.input_dims], 310 | minval=minval, 311 | maxval=maxval, 312 | dtype=floatx()) 313 | 314 | rel_variance = tf.random.uniform(shape=[], 315 | minval=0.9, 316 | maxval=1.1, 317 | dtype=floatx()) 318 | 319 | allK.append(cls(lengthscales=lenscales, 320 | variance=config.kernel_variance * rel_variance)) 321 | 322 | allZ.append(InducingPoints(tf.random.uniform(Z_shape, dtype=floatx()))) 323 | 324 | kern = kernels.SeparateIndependent(allK) 325 | Z = SeparateIndependentInducingVariables(allZ) 326 | 327 | Kuu = covariances.Kuu(Z, kern, jitter=gpflow_config.default_jitter()) 328 | q_sqrt = tf.linalg.cholesky(Kuu)\ 329 | * tf.random.uniform(shape=[kern.num_latent_gps, 1, 1], 330 | minval=0.0, 331 | maxval=0.5, 332 | dtype=floatx()) 333 | 334 | const = tf.random.normal([len(kern.kernels)], dtype=floatx()) 335 | model = SVGP(kernel=kern, 336 | likelihood=None, 337 | inducing_variable=Z, 338 | mean_function=mean_functions.Constant(c=const), 339 | q_sqrt=q_sqrt, 340 | whiten=False, 341 | num_latent_gps=len(allK)) 342 | 343 | mf, Sff = subroutine(config, model, X) 344 | mg, Sgg = model.predict_f(X, full_cov=True) 345 | tol = config.error_tol 346 | assert allclose(mf, mg, tol, tol) 347 | assert allclose(Sff, Sgg, tol, tol) 348 | return update_wrapper(partial(main, config=default_config), subroutine) 349 | return decorator 350 | 351 | 352 | def test_update_conv2d(default_config: NamedTuple = None): 353 | def decorator(subroutine): 354 | def main(config): 355 | assert config is not None, ValueError 356 | tf.random.set_seed(config.seed) 357 | gpflow_config.set_default_float(config.floatx) 358 | gpflow_config.set_default_jitter(config.jitter) 359 | 360 | X_shape = [config.num_test] + config.image_shape + [config.channels_in] 361 | X = 
tf.reshape(tf.range(tf.reduce_prod(X_shape), dtype=floatx()), X_shape) 362 | X /= tf.reduce_max(X) 363 | 364 | patch_len = config.channels_in * int(tf.reduce_prod(config.patch_shape)) 365 | for base_cls in SupportedBaseKernels: 366 | minval = config.rel_lengthscales_min * (patch_len ** 0.5) 367 | maxval = config.rel_lengthscales_max * (patch_len ** 0.5) 368 | lenscales = tf.random.uniform(shape=[patch_len], 369 | minval=minval, 370 | maxval=maxval, 371 | dtype=floatx()) 372 | 373 | base = base_cls(lengthscales=lenscales, variance=config.kernel_variance) 374 | Z_shape = [config.num_cond] + config.patch_shape + [config.channels_in] 375 | for cls in (kernels_ext.Conv2d, kernels_ext.Conv2dTranspose): 376 | kern = cls(kernel=base, 377 | image_shape=config.image_shape, 378 | patch_shape=config.patch_shape, 379 | channels_in=config.channels_in, 380 | channels_out=config.num_latent_gps, 381 | strides=config.strides, 382 | padding=config.padding, 383 | dilations=config.dilations) 384 | 385 | Z = InducingImages(tf.random.uniform(Z_shape, dtype=floatx())) 386 | q_sqrt = tf.linalg.cholesky(covariances.Kuu(Z, kern)) 387 | q_sqrt *= tf.random.uniform([config.num_latent_gps, 1, 1], 388 | minval=0.0, 389 | maxval=0.5, 390 | dtype=floatx()) 391 | 392 | # TODO: GPflow's SVGP class is not setup to support outputs defined 393 | # as spatial feature maps. For now, we content ourselves with 394 | # the following hack... 395 | const = tf.random.normal([config.num_latent_gps], dtype=floatx()) 396 | mean_function = lambda x: const 397 | 398 | model = SVGP(kernel=kern, 399 | likelihood=None, 400 | mean_function=mean_function, 401 | inducing_variable=Z, 402 | q_sqrt=q_sqrt, 403 | whiten=False, 404 | num_latent_gps=config.num_latent_gps) 405 | 406 | mf, Sff = subroutine(config, model, X) 407 | mg, Sgg = model.predict_f(X, full_cov=True) 408 | 409 | tol = config.error_tol 410 | assert allclose(mf, mg, tol, tol) 411 | assert allclose(Sff, Sgg, tol, tol) 412 | 413 | return update_wrapper(partial(main, config=default_config), subroutine) 414 | return decorator 415 | -------------------------------------------------------------------------------- /tests/sampling/updates/test_cg.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # ============================================== 5 | # Preamble 6 | # ============================================== 7 | # ---- Imports 8 | import tensorflow as tf 9 | 10 | from . 
import common 11 | from typing import Any, List, NamedTuple 12 | from gpflow import covariances 13 | from gpflow.models import GPR, SVGP 14 | from gpflow.config import default_jitter, default_float as floatx 15 | from gpflow_sampling.utils import swap_axes 16 | from gpflow_sampling.utils.linalg import get_default_preconditioner 17 | from gpflow_sampling.sampling.updates.cg_updates import cg as cg_update 18 | 19 | 20 | # ============================================== 21 | # test_cg 22 | # ============================================== 23 | class ConfigDense(NamedTuple): 24 | seed: int = 0 25 | floatx: Any = 'float64' 26 | jitter: float = 1e-6 27 | 28 | num_test: int = 128 29 | num_cond: int = 32 30 | num_samples: int = 16384 31 | shard_size: int = 1024 32 | input_dims: int = 3 33 | 34 | kernel_variance: float = 0.9 # keep this near 1 since it impacts MC error 35 | rel_lengthscales_min: float = 0.05 36 | rel_lengthscales_max: float = 0.5 37 | noise_variance: float = 1e-2 # only used by GPR test 38 | 39 | @property 40 | def error_tol(self): 41 | return 4 * (self.num_samples ** -0.5) 42 | 43 | 44 | class ConfigConv2d(NamedTuple): 45 | seed: int = 1 46 | floatx: Any = 'float64' 47 | jitter: float = 1e-5 48 | 49 | num_test: int = 16 50 | num_cond: int = 32 51 | num_samples: int = 16384 52 | shard_size: int = 1024 53 | 54 | kernel_variance: float = 0.9 55 | rel_lengthscales_min: float = 0.5 56 | rel_lengthscales_max: float = 2.0 57 | num_latent_gps: int = 3 58 | 59 | # Convolutional settings 60 | channels_in: int = 2 61 | image_shape: List = [3, 3] # Keep these small, since common.sample_joint 62 | patch_shape: List = [2, 2] # becomes very expensive for Conv2dTranspose! 63 | strides: List = [1, 1] 64 | padding: str = "VALID" 65 | dilations: List = [1, 1] 66 | 67 | @property 68 | def error_tol(self): 69 | return 4 * (self.num_samples ** -0.5) 70 | 71 | 72 | def _test_cg_gpr(config: ConfigDense, 73 | model: GPR, 74 | Xnew: tf.Tensor) -> tf.Tensor: 75 | """ 76 | Sample generation subroutine common to each unit test 77 | """ 78 | # Prepare preconditioner for CG 79 | X, y = model.data 80 | Kff = model.kernel(X, full_cov=True) 81 | max_rank = config.num_cond//(2 if config.num_cond > 1 else 1) 82 | preconditioner = get_default_preconditioner(Kff, 83 | diag=model.likelihood.variance, 84 | max_rank=max_rank) 85 | 86 | count = 0 87 | L_joint = None 88 | samples = [] 89 | while count < config.num_samples: 90 | # Sample $u ~ N(q_mu, q_sqrt q_sqrt^{T})$ 91 | size = min(config.shard_size, config.num_samples - count) 92 | 93 | # Generate draws from the joint distribution $p(f(X), f(Xnew))$ 94 | (f, fnew), L_joint = common.sample_joint(model.kernel, 95 | X, 96 | Xnew, 97 | num_samples=size, 98 | L=L_joint) 99 | 100 | # Solve for update functions 101 | update_fns = cg_update(model.kernel, 102 | X, 103 | y, 104 | f + model.mean_function(X), 105 | tol=1e-6, 106 | diag=model.likelihood.variance, 107 | max_iter=config.num_cond, 108 | preconditioner=preconditioner) 109 | 110 | samples.append(fnew + update_fns(Xnew)) 111 | count += size 112 | 113 | samples = tf.concat(samples, axis=0) 114 | if model.mean_function is not None: 115 | samples += model.mean_function(Xnew) 116 | return samples 117 | 118 | 119 | def _test_cg_svgp(config: ConfigDense, 120 | model: SVGP, 121 | Xnew: tf.Tensor) -> tf.Tensor: 122 | """ 123 | Sample generation subroutine common to each unit test 124 | """ 125 | # Prepare preconditioner for CG 126 | Z = model.inducing_variable 127 | Kff = covariances.Kuu(Z, model.kernel, jitter=0) 128 | max_rank = 
config.num_cond//(2 if config.num_cond > 1 else 1) 129 | preconditioner = get_default_preconditioner(Kff, 130 | diag=default_jitter(), 131 | max_rank=max_rank) 132 | 133 | count = 0 134 | samples = [] 135 | L_joint = None 136 | while count < config.num_samples: 137 | # Sample $u ~ N(q_mu, q_sqrt q_sqrt^{T})$ 138 | size = min(config.shard_size, config.num_samples - count) 139 | shape = model.num_latent_gps, config.num_cond, size 140 | rvs = tf.random.normal(shape=shape, dtype=floatx()) 141 | u = tf.transpose(model.q_sqrt @ rvs) 142 | 143 | # Generate draws from the joint distribution $p(f(X), g(Z))$ 144 | (f, fnew), L_joint = common.sample_joint(model.kernel, 145 | Z, 146 | Xnew, 147 | num_samples=size, 148 | L=L_joint) 149 | 150 | # Solve for update functions 151 | update_fns = cg_update(model.kernel, 152 | Z, 153 | u, 154 | f, 155 | tol=1e-6, 156 | max_iter=config.num_cond, 157 | preconditioner=preconditioner) 158 | 159 | samples.append(fnew + update_fns(Xnew)) 160 | count += size 161 | 162 | samples = tf.concat(samples, axis=0) 163 | if model.mean_function is not None: 164 | samples += model.mean_function(Xnew) 165 | return samples 166 | 167 | 168 | @common.test_update_dense(default_config=ConfigDense()) 169 | def test_cg_dense(*args, **kwargs): 170 | f = _test_cg_gpr(*args, **kwargs) # [S, N, 1] 171 | mf = tf.reduce_mean(f, axis=0) 172 | res = tf.squeeze(f - mf, axis=-1) 173 | Sff = tf.matmul(res, res, transpose_a=True)/f.shape[0] 174 | return mf, Sff 175 | 176 | 177 | @common.test_update_sparse(default_config=ConfigDense()) 178 | def test_cg_sparse(*args, **kwargs): 179 | f = _test_cg_svgp(*args, **kwargs) # [S, N, 1] 180 | mf = tf.reduce_mean(f, axis=0) 181 | res = tf.squeeze(f - mf, axis=-1) 182 | Sff = tf.matmul(res, res, transpose_a=True)/f.shape[0] 183 | return mf, Sff 184 | 185 | 186 | @common.test_update_sparse_shared(default_config=ConfigDense()) 187 | def test_cg_sparse_shared(*args, **kwargs): 188 | f = _test_cg_svgp(*args, **kwargs) # [S, N, L] 189 | mf = tf.reduce_mean(f, axis=0) 190 | res = tf.transpose(f - mf) 191 | Sff = tf.matmul(res, res, transpose_b=True)/f.shape[0] 192 | return mf, Sff 193 | 194 | 195 | @common.test_update_sparse_separate(default_config=ConfigDense()) 196 | def test_cg_sparse_separate(*args, **kwargs): 197 | f = _test_cg_svgp(*args, **kwargs) # [S, N, L] 198 | mf = tf.reduce_mean(f, axis=0) 199 | res = tf.transpose(f - mf) 200 | Sff = tf.matmul(res, res, transpose_b=True)/f.shape[0] 201 | return mf, Sff 202 | 203 | 204 | @common.test_update_conv2d(default_config=ConfigConv2d()) 205 | def test_cg_conv2d(*args, **kwargs): 206 | f = swap_axes(_test_cg_svgp(*args, **kwargs), 0, -1) # [L, N, H, W, S] 207 | mf = tf.transpose(tf.reduce_mean(f, [-3, -2, -1])) # [N, L] 208 | res = f - tf.reduce_mean(f, axis=-1, keepdims=True) 209 | Sff = common.avg_spatial_inner_product(res, batch_dims=f.shape.ndims - 4) 210 | return mf, Sff/f.shape[-1] # [L, N, N] 211 | -------------------------------------------------------------------------------- /tests/sampling/updates/test_exact.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # ============================================== 5 | # Preamble 6 | # ============================================== 7 | # ---- Imports 8 | import tensorflow as tf 9 | 10 | from . 
import common 11 | from typing import Any, List, Union, NamedTuple 12 | from gpflow.config import default_jitter, default_float as floatx 13 | from gpflow.models import GPR, SVGP 14 | from gpflow_sampling import covariances 15 | from gpflow_sampling.utils import swap_axes 16 | from gpflow_sampling.sampling.updates import exact as exact_update 17 | 18 | 19 | # ============================================== 20 | # test_exact 21 | # ============================================== 22 | class ConfigDense(NamedTuple): 23 | seed: int = 0 24 | floatx: Any = 'float64' 25 | jitter: float = 1e-6 26 | 27 | num_test: int = 128 28 | num_cond: int = 32 29 | num_samples: int = 16384 30 | shard_size: int = 1024 31 | input_dims: int = 3 32 | 33 | kernel_variance: float = 0.9 # keep this near 1 since it impacts MC error 34 | rel_lengthscales_min: float = 0.05 35 | rel_lengthscales_max: float = 0.5 36 | noise_variance: float = 1e-2 # only used by GPR test 37 | 38 | @property 39 | def error_tol(self): 40 | return 4 * (self.num_samples ** -0.5) 41 | 42 | 43 | class ConfigConv2d(NamedTuple): 44 | seed: int = 1 45 | floatx: Any = 'float64' 46 | jitter: float = 1e-5 47 | 48 | num_test: int = 128 49 | num_cond: int = 32 50 | num_samples: int = 16384 51 | shard_size: int = 1024 52 | kernel_variance: float = 0.9 53 | rel_lengthscales_min: float = 0.5 54 | rel_lengthscales_max: float = 2.0 55 | num_latent_gps: int = 3 56 | 57 | # Convolutional settings 58 | channels_in: int = 2 59 | image_shape: List = [3, 3] # Keep these small, since common.sample_joint 60 | patch_shape: List = [2, 2] # becomes very expensive for Conv2dTranspose! 61 | strides: List = [1, 1] 62 | padding: str = "VALID" 63 | dilations: List = [1, 1] 64 | 65 | @property 66 | def error_tol(self): 67 | return 4 * (self.num_samples ** -0.5) 68 | 69 | 70 | def _test_exact_gpr(config: ConfigDense, 71 | model: GPR, 72 | Xnew: tf.Tensor) -> tf.Tensor: 73 | """ 74 | Sample generation subroutine common to each unit test 75 | """ 76 | # Precompute Cholesky factor (optional) 77 | X, y = model.data 78 | Kyy = model.kernel(X, full_cov=True) 79 | Kyy = tf.linalg.set_diag(Kyy, 80 | tf.linalg.diag_part(Kyy) + model.likelihood.variance) 81 | Lyy = tf.linalg.cholesky(Kyy) 82 | 83 | count = 0 84 | L_joint = None 85 | samples = [] 86 | while count < config.num_samples: 87 | # Sample $u ~ N(q_mu, q_sqrt q_sqrt^{T})$ 88 | size = min(config.shard_size, config.num_samples - count) 89 | 90 | # Generate draws from the joint distribution $p(f(X), f(Xnew))$ 91 | (f, fnew), L_joint = common.sample_joint(model.kernel, 92 | X, 93 | Xnew, 94 | num_samples=size, 95 | L=L_joint) 96 | 97 | # Solve for update functions 98 | update_fns = exact_update(model.kernel, 99 | X, 100 | y, 101 | f + model.mean_function(X), 102 | L=Lyy, 103 | diag=model.likelihood.variance) 104 | 105 | samples.append(fnew + update_fns(Xnew)) 106 | count += size 107 | 108 | samples = tf.concat(samples, axis=0) 109 | if model.mean_function is not None: 110 | samples += model.mean_function(Xnew) 111 | return samples 112 | 113 | 114 | def _test_exact_svgp(config: Union[ConfigDense, ConfigConv2d], 115 | model: SVGP, 116 | Xnew: tf.Tensor) -> tf.Tensor: 117 | """ 118 | Sample generation subroutine common to each unit test 119 | """ 120 | # Precompute Cholesky factor (optional) 121 | Z = model.inducing_variable 122 | Kuu = covariances.Kuu(Z, model.kernel, jitter=default_jitter()) 123 | Luu = tf.linalg.cholesky(Kuu) 124 | 125 | count = 0 126 | L_joint = None 127 | samples = [] 128 | while count < config.num_samples: 129 | # 
Sample $u ~ N(q_mu, q_sqrt q_sqrt^{T})$ 130 | size = min(config.shard_size, config.num_samples - count) 131 | shape = model.num_latent_gps, config.num_cond, size 132 | rvs = tf.random.normal(shape=shape, dtype=floatx()) 133 | u = tf.transpose(model.q_sqrt @ rvs) 134 | 135 | # Generate draws from the joint distribution $p(f(X), g(Z))$ 136 | (f, fnew), L_joint = common.sample_joint(model.kernel, 137 | Z, 138 | Xnew, 139 | num_samples=size, 140 | L=L_joint) 141 | 142 | # Solve for update functions 143 | update_fns = exact_update(model.kernel, Z, u, f, L=Luu) 144 | samples.append(fnew + update_fns(Xnew)) 145 | count += size 146 | 147 | samples = tf.concat(samples, axis=0) 148 | if model.mean_function is not None: 149 | samples += model.mean_function(Xnew) 150 | return samples 151 | 152 | 153 | @common.test_update_dense(default_config=ConfigDense()) 154 | def test_exact_dense(*args, **kwargs): 155 | f = _test_exact_gpr(*args, **kwargs) # [S, N, 1] 156 | mf = tf.reduce_mean(f, axis=0) 157 | res = tf.squeeze(f - mf, axis=-1) 158 | Sff = tf.matmul(res, res, transpose_a=True)/f.shape[0] 159 | return mf, Sff 160 | 161 | 162 | @common.test_update_sparse(default_config=ConfigDense()) 163 | def test_exact_sparse(*args, **kwargs): 164 | f = _test_exact_svgp(*args, **kwargs) # [S, N, 1] 165 | mf = tf.reduce_mean(f, axis=0) 166 | res = tf.squeeze(f - mf, axis=-1) 167 | Sff = tf.matmul(res, res, transpose_a=True)/f.shape[0] 168 | return mf, Sff 169 | 170 | 171 | @common.test_update_sparse_shared(default_config=ConfigDense()) 172 | def test_exact_sparse_shared(*args, **kwargs): 173 | f = _test_exact_svgp(*args, **kwargs) # [S, N, L] 174 | mf = tf.reduce_mean(f, axis=0) 175 | res = tf.transpose(f - mf) 176 | Sff = tf.matmul(res, res, transpose_b=True)/f.shape[0] 177 | return mf, Sff 178 | 179 | 180 | @common.test_update_sparse_separate(default_config=ConfigDense()) 181 | def test_exact_sparse_separate(*args, **kwargs): 182 | f = _test_exact_svgp(*args, **kwargs) # [S, N, L] 183 | mf = tf.reduce_mean(f, axis=0) 184 | res = tf.transpose(f - mf) 185 | Sff = tf.matmul(res, res, transpose_b=True)/f.shape[0] 186 | return mf, Sff 187 | 188 | 189 | @common.test_update_conv2d(default_config=ConfigConv2d()) 190 | def test_exact_conv2d(*args, **kwargs): 191 | f = swap_axes(_test_exact_svgp(*args, **kwargs), 0, -1) # [L, N, H, W, S] 192 | mf = tf.transpose(tf.reduce_mean(f, [-3, -2, -1])) # [N, L] 193 | res = f - tf.reduce_mean(f, axis=-1, keepdims=True) 194 | Sff = common.avg_spatial_inner_product(res, batch_dims=f.shape.ndims - 4) 195 | return mf, Sff/f.shape[-1] # [L, N, N] 196 | -------------------------------------------------------------------------------- /tests/sampling/updates/test_linear.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # ============================================== 5 | # Preamble 6 | # ============================================== 7 | # ---- Imports 8 | import tensorflow as tf 9 | 10 | from . 
import common 11 | from typing import Any, NamedTuple 12 | from gpflow.config import default_float as floatx 13 | from gpflow.models import SVGP 14 | from gpflow_sampling.bases import fourier as fourier_basis 15 | from gpflow_sampling.sampling.updates import linear as linear_update 16 | 17 | 18 | # ============================================== 19 | # test_linear 20 | # ============================================== 21 | class ConfigDense(NamedTuple): 22 | seed: int = 0 23 | floatx: Any = 'float64' 24 | jitter: float = 1e-6 25 | 26 | num_test: int = 128 27 | num_cond: int = 32 28 | num_bases: int = 4096 29 | num_samples: int = 16384 30 | shard_size: int = 1024 31 | input_dims: int = 3 32 | 33 | kernel_variance: float = 0.9 # keep this near 1 since it impacts MC error 34 | rel_lengthscales_min: float = 0.05 35 | rel_lengthscales_max: float = 0.5 36 | 37 | @property 38 | def error_tol(self): 39 | return 4 * (self.num_samples ** -0.5 + self.num_bases ** -0.5) 40 | 41 | 42 | def _test_linear_svgp(config: ConfigDense, 43 | model: SVGP, 44 | Xnew: tf.Tensor) -> tf.Tensor: 45 | """ 46 | Sample generation subroutine common to each unit test 47 | """ 48 | Z = model.inducing_variable 49 | count = 0 50 | basis = fourier_basis(model.kernel, num_bases=config.num_bases) 51 | L_joint = None 52 | samples = [] 53 | while count < config.num_samples: 54 | # Sample $u ~ N(q_mu, q_sqrt q_sqrt^{T})$ 55 | size = min(config.shard_size, config.num_samples - count) 56 | shape = model.num_latent_gps, config.num_cond, size 57 | rvs = tf.random.normal(shape=shape, dtype=floatx()) 58 | u = tf.transpose(model.q_sqrt @ rvs) 59 | 60 | # Generate draws from the joint distribution $p(f(X), g(Z))$ 61 | (f, fnew), L_joint = common.sample_joint(model.kernel, 62 | Z, 63 | Xnew, 64 | num_samples=size, 65 | L=L_joint) 66 | 67 | # Solve for update functions 68 | update_fns = linear_update(Z, u, f, basis=basis) 69 | samples.append(fnew + update_fns(Xnew)) 70 | count += size 71 | 72 | samples = tf.concat(samples, axis=0) 73 | if model.mean_function is not None: 74 | samples += model.mean_function(Xnew) 75 | return samples 76 | 77 | 78 | @common.test_update_sparse(default_config=ConfigDense()) 79 | def test_linear_sparse(*args, **kwargs): 80 | f = _test_linear_svgp(*args, **kwargs) # [S, N, 1] 81 | mf = tf.reduce_mean(f, axis=0) 82 | res = tf.squeeze(f - mf, axis=-1) 83 | Sff = tf.matmul(res, res, transpose_a=True)/f.shape[0] 84 | return mf, Sff 85 | 86 | 87 | @common.test_update_sparse_shared(default_config=ConfigDense()) 88 | def test_linear_sparse_shared(*args, **kwargs): 89 | f = _test_linear_svgp(*args, **kwargs) # [S, N, L] 90 | mf = tf.reduce_mean(f, axis=0) 91 | res = tf.transpose(f - mf) 92 | Sff = tf.matmul(res, res, transpose_b=True)/f.shape[0] 93 | return mf, Sff 94 | 95 | 96 | @common.test_update_sparse_separate(default_config=ConfigDense()) 97 | def test_linear_sparse_separate(*args, **kwargs): 98 | f = _test_linear_svgp(*args, **kwargs) # [S, N, L] 99 | mf = tf.reduce_mean(f, axis=0) 100 | res = tf.transpose(f - mf) 101 | Sff = tf.matmul(res, res, transpose_b=True)/f.shape[0] 102 | return mf, Sff 103 | --------------------------------------------------------------------------------
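Taken together, these tests exercise the decoupled (pathwise) sampling pipeline end to end: Fourier-feature priors, sampled inducing outputs, and data-driven updates. The sketch below shows how the same calls compose into posterior sample paths. It is illustrative only, mirroring the pattern of `_test_exact_svgp` above: `kern`, `Z` (inducing variables), `q_sqrt` (shape `[L, M, M]`), and the query points `Xnew` are assumed to exist already, and mean functions are omitted.

```
import tensorflow as tf
from gpflow.config import default_float as floatx
from gpflow_sampling.bases import fourier as fourier_basis
from gpflow_sampling.sampling.priors import random_fourier
from gpflow_sampling.sampling.updates import exact as exact_update

num_samples, num_bases = 32, 1024
basis = fourier_basis(kern, num_bases=num_bases)

# Stage 1: draw approximate prior functions from the Fourier basis
funcs = random_fourier(kern,
                       basis=basis,
                       num_bases=num_bases,
                       sample_shape=[num_samples])

# Sample u ~ N(0, q_sqrt q_sqrt^{T}) at the inducing points
rvs = tf.random.normal([q_sqrt.shape[0], q_sqrt.shape[-1], num_samples],
                       dtype=floatx())
u = tf.transpose(q_sqrt @ rvs)

# Stage 2: pathwise update conditions the prior draws on u
update_fns = exact_update(kern, Z, u, funcs(Z))

# Posterior sample paths: prior draw plus data-driven update
samples = funcs(Xnew) + update_fns(Xnew)
```

`test_linear.py` exercises the same pattern with `linear_update(Z, u, f, basis=basis)` in place of the exact solve, and `test_cg.py` swaps in a preconditioned conjugate-gradient solve.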